 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \file rocm_device_api.cc
 * \brief GPU specific API
 */
25- #include < tvm/runtime/device_api.h>
26-
2725#include < dmlc/logging.h>
2826#include < dmlc/thread_local.h>
2927#include < hip/hip_runtime_api.h>
3028#include < hsa/hsa.h>
29+ #include < tvm/runtime/device_api.h>
3130#include < tvm/runtime/registry.h>
32- #include " ../../../include/tvm/runtime/device_api.h"
3331#include " rocm_common.h"
3432
3533namespace tvm {
@@ -55,130 +53,153 @@ class ROCMDeviceAPI final : public DeviceAPI {
5553 break ;
5654 }
5755 case kMaxThreadsPerBlock : {
58- value = 1024 ;
56+ ROCM_CALL (hipDeviceGetAttribute (
57+ &value, hipDeviceAttributeMaxThreadsPerBlock, ctx.device_id ));
5958 break ;
6059 }
6160 case kWarpSize : {
62- value = 64 ;
61+ ROCM_CALL (hipDeviceGetAttribute (&value, hipDeviceAttributeWarpSize,
62+ ctx.device_id ));
63+ break ;
64+ }
65+ case kMaxSharedMemoryPerBlock : {
66+ ROCM_CALL (hipDeviceGetAttribute (
67+ &value, hipDeviceAttributeMaxSharedMemoryPerBlock, ctx.device_id ));
68+ break ;
69+ }
70+ case kComputeVersion : {
71+ std::ostringstream os;
72+ ROCM_CALL (hipDeviceGetAttribute (
73+ &value, hipDeviceAttributeComputeCapabilityMajor, ctx.device_id ));
74+ os << value << " ." ;
75+ ROCM_CALL (hipDeviceGetAttribute (
76+ &value, hipDeviceAttributeComputeCapabilityMinor, ctx.device_id ));
77+ os << value;
78+ *rv = os.str ();
79+ return ;
80+ }
81+ case kDeviceName :
82+ return ;
83+ case kMaxClockRate : {
84+ ROCM_CALL (hipDeviceGetAttribute (&value, hipDeviceAttributeClockRate,
85+ ctx.device_id ));
6386 break ;
6487 }
65- case kMaxSharedMemoryPerBlock : return ;
66- case kComputeVersion :
67- case kDeviceName : return ;
68- case kMaxClockRate : return ;
69- case kMultiProcessorCount : return ;
70- case kMaxThreadDimensions : return ;
88+ case kMultiProcessorCount : {
89+ ROCM_CALL (hipDeviceGetAttribute (
90+ &value, hipDeviceAttributeMultiprocessorCount, ctx.device_id ));
91+ break ;
92+ }
93+ case kMaxThreadDimensions : {
94+ int dims[3 ];
95+ ROCM_CALL (hipDeviceGetAttribute (
96+ &dims[0 ], hipDeviceAttributeMaxBlockDimX, ctx.device_id ));
97+ ROCM_CALL (hipDeviceGetAttribute (
98+ &dims[1 ], hipDeviceAttributeMaxBlockDimY, ctx.device_id ));
99+ ROCM_CALL (hipDeviceGetAttribute (
100+ &dims[2 ], hipDeviceAttributeMaxBlockDimZ, ctx.device_id ));
101+
102+ std::stringstream ss;
103+ ss << " [" << dims[0 ] << " , " << dims[1 ] << " , " << dims[2 ] << " ]" ;
104+ *rv = ss.str ();
105+ return ;
106+ }
71107 case kGcnArch : {
72108 hipDeviceProp_t prop;
73109 ROCM_CALL (hipGetDeviceProperties (&prop, ctx.device_id ));
74110 *rv = prop.gcnArch ;
75111 return ;
76112 }
113+ *rv = value;
114+ }
115+ void * AllocDataSpace (TVMContext ctx, size_t nbytes, size_t alignment,
116+ TVMType type_hint) final {
117+ ROCM_CALL (hipSetDevice (ctx.device_id ));
118+ CHECK_EQ (256 % alignment, 0U ) << " ROCM space is aligned at 256 bytes" ;
119+ void * ret;
120+ ROCM_CALL (hipMalloc (&ret, nbytes));
121+ return ret;
77122 }
78- *rv = value;
79- }
80- void * AllocDataSpace (TVMContext ctx,
81- size_t nbytes,
82- size_t alignment,
83- TVMType type_hint) final {
84- ROCM_CALL (hipSetDevice (ctx.device_id ));
85- CHECK_EQ (256 % alignment, 0U )
86- << " ROCM space is aligned at 256 bytes" ;
87- void *ret;
88- ROCM_CALL (hipMalloc (&ret, nbytes));
89- return ret;
90- }
91123
92- void FreeDataSpace (TVMContext ctx, void * ptr) final {
93- ROCM_CALL (hipSetDevice (ctx.device_id ));
94- ROCM_CALL (hipFree (ptr));
95- }
124+ void FreeDataSpace (TVMContext ctx, void * ptr) final {
125+ ROCM_CALL (hipSetDevice (ctx.device_id ));
126+ ROCM_CALL (hipFree (ptr));
127+ }
96128
97- void CopyDataFromTo (const void * from,
98- size_t from_offset,
99- void * to,
100- size_t to_offset,
101- size_t size,
102- TVMContext ctx_from,
103- TVMContext ctx_to,
104- TVMType type_hint,
105- TVMStreamHandle stream) final {
106- hipStream_t hip_stream = static_cast <hipStream_t>(stream);
107- from = static_cast <const char *>(from) + from_offset;
108- to = static_cast <char *>(to) + to_offset;
109- if (ctx_from.device_type == kDLROCM && ctx_to.device_type == kDLROCM ) {
110- ROCM_CALL (hipSetDevice (ctx_from.device_id ));
111- if (ctx_from.device_id == ctx_to.device_id ) {
112- GPUCopy (from, to, size, hipMemcpyDeviceToDevice, hip_stream);
129+ void CopyDataFromTo (const void * from, size_t from_offset, void * to,
130+ size_t to_offset, size_t size, TVMContext ctx_from,
131+ TVMContext ctx_to, TVMType type_hint,
132+ TVMStreamHandle stream) final {
133+ hipStream_t hip_stream = static_cast <hipStream_t>(stream);
134+ from = static_cast <const char *>(from) + from_offset;
135+ to = static_cast <char *>(to) + to_offset;
136+ if (ctx_from.device_type == kDLROCM && ctx_to.device_type == kDLROCM ) {
137+ ROCM_CALL (hipSetDevice (ctx_from.device_id ));
138+ if (ctx_from.device_id == ctx_to.device_id ) {
139+ GPUCopy (from, to, size, hipMemcpyDeviceToDevice, hip_stream);
140+ } else {
141+ hipMemcpyPeerAsync (to, ctx_to.device_id , from, ctx_from.device_id ,
142+ size, hip_stream);
143+ }
144+ } else if (ctx_from.device_type == kDLROCM &&
145+ ctx_to.device_type == kDLCPU ) {
146+ ROCM_CALL (hipSetDevice (ctx_from.device_id ));
147+ GPUCopy (from, to, size, hipMemcpyDeviceToHost, hip_stream);
148+ } else if (ctx_from.device_type == kDLCPU &&
149+ ctx_to.device_type == kDLROCM ) {
150+ ROCM_CALL (hipSetDevice (ctx_to.device_id ));
151+ GPUCopy (from, to, size, hipMemcpyHostToDevice, hip_stream);
113152 } else {
114- hipMemcpyPeerAsync (to, ctx_to.device_id ,
115- from, ctx_from.device_id ,
116- size, hip_stream);
153+ LOG (FATAL) << " expect copy from/to GPU or between GPU" ;
117154 }
118- } else if (ctx_from.device_type == kDLROCM && ctx_to.device_type == kDLCPU ) {
119- ROCM_CALL (hipSetDevice (ctx_from.device_id ));
120- GPUCopy (from, to, size, hipMemcpyDeviceToHost, hip_stream);
121- } else if (ctx_from.device_type == kDLCPU && ctx_to.device_type == kDLROCM ) {
122- ROCM_CALL (hipSetDevice (ctx_to.device_id ));
123- GPUCopy (from, to, size, hipMemcpyHostToDevice, hip_stream);
124- } else {
125- LOG (FATAL) << " expect copy from/to GPU or between GPU" ;
126155 }
127- }
128-
129- void StreamSync (TVMContext ctx, TVMStreamHandle stream) final {
130- ROCM_CALL (hipSetDevice (ctx.device_id ));
131- ROCM_CALL (hipStreamSynchronize (static_cast <hipStream_t>(stream)));
132- }
133156
134- void SetStream (TVMContext ctx, TVMStreamHandle stream) final {
135- ROCMThreadEntry::ThreadLocal ()
136- -> stream = static_cast <hipStream_t>(stream);
137- }
157+ void StreamSync (TVMContext ctx, TVMStreamHandle stream) final {
158+ ROCM_CALL ( hipSetDevice (ctx. device_id ));
159+ ROCM_CALL ( hipStreamSynchronize ( static_cast <hipStream_t>(stream)) );
160+ }
138161
139- void * AllocWorkspace (TVMContext ctx, size_t size, TVMType type_hint ) final {
140- return ROCMThreadEntry::ThreadLocal ()->pool . AllocWorkspace (ctx, size );
141- }
162+ void SetStream (TVMContext ctx, TVMStreamHandle stream ) final {
163+ ROCMThreadEntry::ThreadLocal ()->stream = static_cast <hipStream_t>(stream );
164+ }
142165
143- void FreeWorkspace (TVMContext ctx, void * data ) final {
144- ROCMThreadEntry::ThreadLocal ()->pool .FreeWorkspace (ctx, data );
145- }
166+ void * AllocWorkspace (TVMContext ctx, size_t size, TVMType type_hint ) final {
167+ return ROCMThreadEntry::ThreadLocal ()->pool .AllocWorkspace (ctx, size );
168+ }
146169
147- static const std::shared_ptr<ROCMDeviceAPI>& Global () {
148- static std::shared_ptr<ROCMDeviceAPI> inst =
149- std::make_shared<ROCMDeviceAPI>();
150- return inst;
151- }
170+ void FreeWorkspace (TVMContext ctx, void * data) final {
171+ ROCMThreadEntry::ThreadLocal ()->pool .FreeWorkspace (ctx, data);
172+ }
152173
153- private:
154- static void GPUCopy (const void * from,
155- void * to,
156- size_t size,
157- hipMemcpyKind kind,
158- hipStream_t stream) {
159- if (stream != 0 ) {
160- ROCM_CALL (hipMemcpyAsync (to, from, size, kind, stream));
161- } else {
162- ROCM_CALL (hipMemcpy (to, from, size, kind));
174+ static const std::shared_ptr<ROCMDeviceAPI>& Global () {
175+ static std::shared_ptr<ROCMDeviceAPI> inst =
176+ std::make_shared<ROCMDeviceAPI>();
177+ return inst;
163178 }
164- }
165- };
166179
167- typedef dmlc::ThreadLocalStore<ROCMThreadEntry> ROCMThreadStore;
180+ private:
181+ static void GPUCopy (const void * from, void * to, size_t size,
182+ hipMemcpyKind kind, hipStream_t stream) {
183+ if (stream != 0 ) {
184+ ROCM_CALL (hipMemcpyAsync (to, from, size, kind, stream));
185+ } else {
186+ ROCM_CALL (hipMemcpy (to, from, size, kind));
187+ }
188+ }
189+ };
168190
169- ROCMThreadEntry::ROCMThreadEntry ()
170- : pool(kDLROCM , ROCMDeviceAPI::Global()) {
171- }
191+ typedef dmlc::ThreadLocalStore<ROCMThreadEntry> ROCMThreadStore;
172192
173- ROCMThreadEntry* ROCMThreadEntry::ThreadLocal () {
174- return ROCMThreadStore::Get ();
175- }
193+ ROCMThreadEntry::ROCMThreadEntry () : pool(kDLROCM , ROCMDeviceAPI::Global()) {}
176194
177- TVM_REGISTER_GLOBAL (" device_api.rocm" )
178- .set_body([](TVMArgs args, TVMRetValue* rv) {
179- DeviceAPI* ptr = ROCMDeviceAPI::Global ().get ();
180- *rv = static_cast <void *>(ptr);
181- });
195+ ROCMThreadEntry* ROCMThreadEntry::ThreadLocal () {
196+ return ROCMThreadStore::Get ();
197+ }
182198
199+ TVM_REGISTER_GLOBAL (" device_api.rocm" )
200+ .set_body([](TVMArgs args, TVMRetValue* rv) {
201+ DeviceAPI* ptr = ROCMDeviceAPI::Global ().get ();
202+ *rv = static_cast <void *>(ptr);
203+ });
183204} // namespace runtime
184205} // namespace tvm