66 * to you under the Apache License, Version 2.0 (the
77 * "License"); you may not use this file except in compliance
88 * with the License. You may obtain a copy of the License at
9- *
9+ *
1010 * http://www.apache.org/licenses/LICENSE-2.0
11- *
11+ *
1212 * Unless required by applicable law or agreed to in writing,
1313 * software distributed under the License is distributed on an
1414 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
2222 * \file rocm_device_api.cc
2323 * \brief GPU specific API
2424 */
25- #include < tvm/runtime/device_api.h>
26-
2725#include < dmlc/logging.h>
2826#include < dmlc/thread_local.h>
29- #include < tvm/runtime/registry.h>
3027#include < hip/hip_runtime_api.h>
3128#include < hsa/hsa.h>
29+ #include < tvm/runtime/device_api.h>
30+ #include < tvm/runtime/registry.h>
31+
3232#include " rocm_common.h"
3333
3434namespace tvm {
3535namespace runtime {
3636
3737class ROCMDeviceAPI final : public DeviceAPI {
3838 public:
39- void SetDevice (TVMContext ctx) final {
40- ROCM_CALL (hipSetDevice (ctx.device_id ));
41- }
39+ void SetDevice (TVMContext ctx) final { ROCM_CALL (hipSetDevice (ctx.device_id )); }
4240 void GetAttr (TVMContext ctx, DeviceAttrKind kind, TVMRetValue* rv) final {
4341 int value = 0 ;
4442 switch (kind) {
@@ -54,35 +52,59 @@ class ROCMDeviceAPI final : public DeviceAPI {
5452 break ;
5553 }
5654 case kMaxThreadsPerBlock : {
57- value = 1024 ;
55+ ROCM_CALL (
56+ hipDeviceGetAttribute (&value, hipDeviceAttributeMaxThreadsPerBlock, ctx.device_id ));
5857 break ;
5958 }
6059 case kWarpSize : {
61- value = 64 ;
60+ ROCM_CALL (hipDeviceGetAttribute (&value, hipDeviceAttributeWarpSize, ctx.device_id ));
61+ break ;
62+ }
63+ case kMaxSharedMemoryPerBlock : {
64+ ROCM_CALL (hipDeviceGetAttribute (&value, hipDeviceAttributeMaxSharedMemoryPerBlock,
65+ ctx.device_id ));
6266 break ;
6367 }
64- case kMaxSharedMemoryPerBlock : return ;
6568 case kComputeVersion : {
66- hipDeviceProp_t prop;
67- ROCM_CALL (hipGetDeviceProperties (&prop, ctx.device_id ));
68- *rv = prop.gcnArch ;
69+ std::ostringstream os;
70+ ROCM_CALL (
71+ hipDeviceGetAttribute (&value, hipDeviceAttributeComputeCapabilityMajor, ctx.device_id ));
72+ os << value << " ." ;
73+ ROCM_CALL (
74+ hipDeviceGetAttribute (&value, hipDeviceAttributeComputeCapabilityMinor, ctx.device_id ));
75+ os << value;
76+ *rv = os.str ();
77+ return ;
78+ }
79+ case kDeviceName :
80+ return ;
81+ case kMaxClockRate : {
82+ ROCM_CALL (hipDeviceGetAttribute (&value, hipDeviceAttributeClockRate, ctx.device_id ));
83+ break ;
84+ }
85+ case kMultiProcessorCount : {
86+ ROCM_CALL (
87+ hipDeviceGetAttribute (&value, hipDeviceAttributeMultiprocessorCount, ctx.device_id ));
88+ break ;
89+ }
90+ case kMaxThreadDimensions : {
91+ int dims[3 ];
92+ ROCM_CALL (hipDeviceGetAttribute (&dims[0 ], hipDeviceAttributeMaxBlockDimX, ctx.device_id ));
93+ ROCM_CALL (hipDeviceGetAttribute (&dims[1 ], hipDeviceAttributeMaxBlockDimY, ctx.device_id ));
94+ ROCM_CALL (hipDeviceGetAttribute (&dims[2 ], hipDeviceAttributeMaxBlockDimZ, ctx.device_id ));
95+
96+ std::stringstream ss;
97+ ss << " [" << dims[0 ] << " , " << dims[1 ] << " , " << dims[2 ] << " ]" ;
98+ *rv = ss.str ();
6999 return ;
70100 }
71- case kDeviceName : return ;
72- case kMaxClockRate : return ;
73- case kMultiProcessorCount : return ;
74- case kMaxThreadDimensions : return ;
75101 }
76102 *rv = value;
77103 }
78- void * AllocDataSpace (TVMContext ctx,
79- size_t nbytes,
80- size_t alignment,
81- TVMType type_hint) final {
104+ void * AllocDataSpace (TVMContext ctx, size_t nbytes, size_t alignment, TVMType type_hint) final {
82105 ROCM_CALL (hipSetDevice (ctx.device_id ));
83- CHECK_EQ (256 % alignment, 0U )
84- << " ROCM space is aligned at 256 bytes" ;
85- void *ret;
106+ CHECK_EQ (256 % alignment, 0U ) << " ROCM space is aligned at 256 bytes" ;
107+ void * ret;
86108 ROCM_CALL (hipMalloc (&ret, nbytes));
87109 return ret;
88110 }
@@ -92,14 +114,8 @@ class ROCMDeviceAPI final : public DeviceAPI {
92114 ROCM_CALL (hipFree (ptr));
93115 }
94116
95- void CopyDataFromTo (const void * from,
96- size_t from_offset,
97- void * to,
98- size_t to_offset,
99- size_t size,
100- TVMContext ctx_from,
101- TVMContext ctx_to,
102- TVMType type_hint,
117+ void CopyDataFromTo (const void * from, size_t from_offset, void * to, size_t to_offset, size_t size,
118+ TVMContext ctx_from, TVMContext ctx_to, TVMType type_hint,
103119 TVMStreamHandle stream) final {
104120 hipStream_t hip_stream = static_cast <hipStream_t>(stream);
105121 from = static_cast <const char *>(from) + from_offset;
@@ -109,9 +125,7 @@ class ROCMDeviceAPI final : public DeviceAPI {
109125 if (ctx_from.device_id == ctx_to.device_id ) {
110126 GPUCopy (from, to, size, hipMemcpyDeviceToDevice, hip_stream);
111127 } else {
112- hipMemcpyPeerAsync (to, ctx_to.device_id ,
113- from, ctx_from.device_id ,
114- size, hip_stream);
128+ hipMemcpyPeerAsync (to, ctx_to.device_id , from, ctx_from.device_id , size, hip_stream);
115129 }
116130 } else if (ctx_from.device_type == kDLROCM && ctx_to.device_type == kDLCPU ) {
117131 ROCM_CALL (hipSetDevice (ctx_from.device_id ));
@@ -130,8 +144,7 @@ class ROCMDeviceAPI final : public DeviceAPI {
130144 }
131145
132146 void SetStream (TVMContext ctx, TVMStreamHandle stream) final {
133- ROCMThreadEntry::ThreadLocal ()
134- ->stream = static_cast <hipStream_t>(stream);
147+ ROCMThreadEntry::ThreadLocal ()->stream = static_cast <hipStream_t>(stream);
135148 }
136149
137150 void * AllocWorkspace (TVMContext ctx, size_t size, TVMType type_hint) final {
@@ -143,16 +156,12 @@ class ROCMDeviceAPI final : public DeviceAPI {
143156 }
144157
145158 static const std::shared_ptr<ROCMDeviceAPI>& Global () {
146- static std::shared_ptr<ROCMDeviceAPI> inst =
147- std::make_shared<ROCMDeviceAPI>();
159+ static std::shared_ptr<ROCMDeviceAPI> inst = std::make_shared<ROCMDeviceAPI>();
148160 return inst;
149161 }
150162
151163 private:
152- static void GPUCopy (const void * from,
153- void * to,
154- size_t size,
155- hipMemcpyKind kind,
164+ static void GPUCopy (const void * from, void * to, size_t size, hipMemcpyKind kind,
156165 hipStream_t stream) {
157166 if (stream != 0 ) {
158167 ROCM_CALL (hipMemcpyAsync (to, from, size, kind, stream));
@@ -164,19 +173,14 @@ class ROCMDeviceAPI final : public DeviceAPI {
164173
165174typedef dmlc::ThreadLocalStore<ROCMThreadEntry> ROCMThreadStore;
166175
167- ROCMThreadEntry::ROCMThreadEntry ()
168- : pool(kDLROCM , ROCMDeviceAPI::Global()) {
169- }
176+ ROCMThreadEntry::ROCMThreadEntry () : pool(kDLROCM , ROCMDeviceAPI::Global()) {}
170177
171- ROCMThreadEntry* ROCMThreadEntry::ThreadLocal () {
172- return ROCMThreadStore::Get ();
173- }
178+ ROCMThreadEntry* ROCMThreadEntry::ThreadLocal () { return ROCMThreadStore::Get (); }
174179
175- TVM_REGISTER_GLOBAL (" device_api.rocm" )
176- .set_body([](TVMArgs args, TVMRetValue* rv) {
177- DeviceAPI* ptr = ROCMDeviceAPI::Global ().get ();
178- *rv = static_cast <void *>(ptr);
179- });
180+ TVM_REGISTER_GLOBAL (" device_api.rocm" ).set_body([](TVMArgs args, TVMRetValue* rv) {
181+ DeviceAPI* ptr = ROCMDeviceAPI::Global ().get ();
182+ *rv = static_cast <void *>(ptr);
183+ });
180184
181185} // namespace runtime
182186} // namespace tvm
0 commit comments