Skip to content

Commit fcb8bb5

Browse files
author
Peter Yeh
committed
proper device query through rocm api
1 parent 03a29da commit fcb8bb5

File tree

1 file changed

+59
-55
lines changed

1 file changed

+59
-55
lines changed

src/runtime/rocm/rocm_device_api.cc

Lines changed: 59 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,9 @@
66
* to you under the Apache License, Version 2.0 (the
77
* "License"); you may not use this file except in compliance
88
* with the License. You may obtain a copy of the License at
9-
*
9+
*
1010
* http://www.apache.org/licenses/LICENSE-2.0
11-
*
11+
*
1212
* Unless required by applicable law or agreed to in writing,
1313
* software distributed under the License is distributed on an
1414
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
@@ -22,23 +22,21 @@
2222
* \file rocm_device_api.cc
2323
* \brief GPU specific API
2424
*/
25-
#include <tvm/runtime/device_api.h>
26-
2725
#include <dmlc/logging.h>
2826
#include <dmlc/thread_local.h>
29-
#include <tvm/runtime/registry.h>
3027
#include <hip/hip_runtime_api.h>
3128
#include <hsa/hsa.h>
29+
#include <tvm/runtime/device_api.h>
30+
#include <tvm/runtime/registry.h>
31+
3232
#include "rocm_common.h"
3333

3434
namespace tvm {
3535
namespace runtime {
3636

3737
class ROCMDeviceAPI final : public DeviceAPI {
3838
public:
39-
void SetDevice(TVMContext ctx) final {
40-
ROCM_CALL(hipSetDevice(ctx.device_id));
41-
}
39+
void SetDevice(TVMContext ctx) final { ROCM_CALL(hipSetDevice(ctx.device_id)); }
4240
void GetAttr(TVMContext ctx, DeviceAttrKind kind, TVMRetValue* rv) final {
4341
int value = 0;
4442
switch (kind) {
@@ -54,35 +52,59 @@ class ROCMDeviceAPI final : public DeviceAPI {
5452
break;
5553
}
5654
case kMaxThreadsPerBlock: {
57-
value = 1024;
55+
ROCM_CALL(
56+
hipDeviceGetAttribute(&value, hipDeviceAttributeMaxThreadsPerBlock, ctx.device_id));
5857
break;
5958
}
6059
case kWarpSize: {
61-
value = 64;
60+
ROCM_CALL(hipDeviceGetAttribute(&value, hipDeviceAttributeWarpSize, ctx.device_id));
61+
break;
62+
}
63+
case kMaxSharedMemoryPerBlock: {
64+
ROCM_CALL(hipDeviceGetAttribute(&value, hipDeviceAttributeMaxSharedMemoryPerBlock,
65+
ctx.device_id));
6266
break;
6367
}
64-
case kMaxSharedMemoryPerBlock: return;
6568
case kComputeVersion: {
66-
hipDeviceProp_t prop;
67-
ROCM_CALL(hipGetDeviceProperties(&prop, ctx.device_id));
68-
*rv = prop.gcnArch;
69+
std::ostringstream os;
70+
ROCM_CALL(
71+
hipDeviceGetAttribute(&value, hipDeviceAttributeComputeCapabilityMajor, ctx.device_id));
72+
os << value << ".";
73+
ROCM_CALL(
74+
hipDeviceGetAttribute(&value, hipDeviceAttributeComputeCapabilityMinor, ctx.device_id));
75+
os << value;
76+
*rv = os.str();
77+
return;
78+
}
79+
case kDeviceName:
80+
return;
81+
case kMaxClockRate: {
82+
ROCM_CALL(hipDeviceGetAttribute(&value, hipDeviceAttributeClockRate, ctx.device_id));
83+
break;
84+
}
85+
case kMultiProcessorCount: {
86+
ROCM_CALL(
87+
hipDeviceGetAttribute(&value, hipDeviceAttributeMultiprocessorCount, ctx.device_id));
88+
break;
89+
}
90+
case kMaxThreadDimensions: {
91+
int dims[3];
92+
ROCM_CALL(hipDeviceGetAttribute(&dims[0], hipDeviceAttributeMaxBlockDimX, ctx.device_id));
93+
ROCM_CALL(hipDeviceGetAttribute(&dims[1], hipDeviceAttributeMaxBlockDimY, ctx.device_id));
94+
ROCM_CALL(hipDeviceGetAttribute(&dims[2], hipDeviceAttributeMaxBlockDimZ, ctx.device_id));
95+
96+
std::stringstream ss;
97+
ss << "[" << dims[0] << ", " << dims[1] << ", " << dims[2] << "]";
98+
*rv = ss.str();
6999
return;
70100
}
71-
case kDeviceName: return;
72-
case kMaxClockRate: return;
73-
case kMultiProcessorCount: return;
74-
case kMaxThreadDimensions: return;
75101
}
76102
*rv = value;
77103
}
78-
void* AllocDataSpace(TVMContext ctx,
79-
size_t nbytes,
80-
size_t alignment,
81-
TVMType type_hint) final {
104+
void* AllocDataSpace(TVMContext ctx, size_t nbytes, size_t alignment, TVMType type_hint) final {
82105
ROCM_CALL(hipSetDevice(ctx.device_id));
83-
CHECK_EQ(256 % alignment, 0U)
84-
<< "ROCM space is aligned at 256 bytes";
85-
void *ret;
106+
CHECK_EQ(256 % alignment, 0U) << "ROCM space is aligned at 256 bytes";
107+
void* ret;
86108
ROCM_CALL(hipMalloc(&ret, nbytes));
87109
return ret;
88110
}
@@ -92,14 +114,8 @@ class ROCMDeviceAPI final : public DeviceAPI {
92114
ROCM_CALL(hipFree(ptr));
93115
}
94116

95-
void CopyDataFromTo(const void* from,
96-
size_t from_offset,
97-
void* to,
98-
size_t to_offset,
99-
size_t size,
100-
TVMContext ctx_from,
101-
TVMContext ctx_to,
102-
TVMType type_hint,
117+
void CopyDataFromTo(const void* from, size_t from_offset, void* to, size_t to_offset, size_t size,
118+
TVMContext ctx_from, TVMContext ctx_to, TVMType type_hint,
103119
TVMStreamHandle stream) final {
104120
hipStream_t hip_stream = static_cast<hipStream_t>(stream);
105121
from = static_cast<const char*>(from) + from_offset;
@@ -109,9 +125,7 @@ class ROCMDeviceAPI final : public DeviceAPI {
109125
if (ctx_from.device_id == ctx_to.device_id) {
110126
GPUCopy(from, to, size, hipMemcpyDeviceToDevice, hip_stream);
111127
} else {
112-
hipMemcpyPeerAsync(to, ctx_to.device_id,
113-
from, ctx_from.device_id,
114-
size, hip_stream);
128+
hipMemcpyPeerAsync(to, ctx_to.device_id, from, ctx_from.device_id, size, hip_stream);
115129
}
116130
} else if (ctx_from.device_type == kDLROCM && ctx_to.device_type == kDLCPU) {
117131
ROCM_CALL(hipSetDevice(ctx_from.device_id));
@@ -130,8 +144,7 @@ class ROCMDeviceAPI final : public DeviceAPI {
130144
}
131145

132146
void SetStream(TVMContext ctx, TVMStreamHandle stream) final {
133-
ROCMThreadEntry::ThreadLocal()
134-
->stream = static_cast<hipStream_t>(stream);
147+
ROCMThreadEntry::ThreadLocal()->stream = static_cast<hipStream_t>(stream);
135148
}
136149

137150
void* AllocWorkspace(TVMContext ctx, size_t size, TVMType type_hint) final {
@@ -143,16 +156,12 @@ class ROCMDeviceAPI final : public DeviceAPI {
143156
}
144157

145158
static const std::shared_ptr<ROCMDeviceAPI>& Global() {
146-
static std::shared_ptr<ROCMDeviceAPI> inst =
147-
std::make_shared<ROCMDeviceAPI>();
159+
static std::shared_ptr<ROCMDeviceAPI> inst = std::make_shared<ROCMDeviceAPI>();
148160
return inst;
149161
}
150162

151163
private:
152-
static void GPUCopy(const void* from,
153-
void* to,
154-
size_t size,
155-
hipMemcpyKind kind,
164+
static void GPUCopy(const void* from, void* to, size_t size, hipMemcpyKind kind,
156165
hipStream_t stream) {
157166
if (stream != 0) {
158167
ROCM_CALL(hipMemcpyAsync(to, from, size, kind, stream));
@@ -164,19 +173,14 @@ class ROCMDeviceAPI final : public DeviceAPI {
164173

165174
typedef dmlc::ThreadLocalStore<ROCMThreadEntry> ROCMThreadStore;
166175

167-
ROCMThreadEntry::ROCMThreadEntry()
168-
: pool(kDLROCM, ROCMDeviceAPI::Global()) {
169-
}
176+
ROCMThreadEntry::ROCMThreadEntry() : pool(kDLROCM, ROCMDeviceAPI::Global()) {}
170177

171-
ROCMThreadEntry* ROCMThreadEntry::ThreadLocal() {
172-
return ROCMThreadStore::Get();
173-
}
178+
ROCMThreadEntry* ROCMThreadEntry::ThreadLocal() { return ROCMThreadStore::Get(); }
174179

175-
TVM_REGISTER_GLOBAL("device_api.rocm")
176-
.set_body([](TVMArgs args, TVMRetValue* rv) {
177-
DeviceAPI* ptr = ROCMDeviceAPI::Global().get();
178-
*rv = static_cast<void*>(ptr);
179-
});
180+
TVM_REGISTER_GLOBAL("device_api.rocm").set_body([](TVMArgs args, TVMRetValue* rv) {
181+
DeviceAPI* ptr = ROCMDeviceAPI::Global().get();
182+
*rv = static_cast<void*>(ptr);
183+
});
180184

181185
} // namespace runtime
182186
} // namespace tvm

0 commit comments

Comments
 (0)