Commit 18df399

Author: Peter Yeh
Commit message: proper device query through rocm api
Parent: b0b16a0

File tree: 1 file changed (+123 −102 lines)


src/runtime/rocm/rocm_device_api.cc

Lines changed: 123 additions & 102 deletions
@@ -6,9 +6,9 @@
  * to you under the Apache License, Version 2.0 (the
  * "License"); you may not use this file except in compliance
  * with the License. You may obtain a copy of the License at
- *
+ *
  * http://www.apache.org/licenses/LICENSE-2.0
- *
+ *
  * Unless required by applicable law or agreed to in writing,
  * software distributed under the License is distributed on an
  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
@@ -22,14 +22,12 @@
  * \file rocm_device_api.cc
  * \brief GPU specific API
  */
-#include <tvm/runtime/device_api.h>
-
 #include <dmlc/logging.h>
 #include <dmlc/thread_local.h>
 #include <hip/hip_runtime_api.h>
 #include <hsa/hsa.h>
+#include <tvm/runtime/device_api.h>
 #include <tvm/runtime/registry.h>
-#include "../../../include/tvm/runtime/device_api.h"
 #include "rocm_common.h"
 
 namespace tvm {
@@ -55,130 +53,153 @@ class ROCMDeviceAPI final : public DeviceAPI {
         break;
       }
       case kMaxThreadsPerBlock: {
-        value = 1024;
+        ROCM_CALL(hipDeviceGetAttribute(
+            &value, hipDeviceAttributeMaxThreadsPerBlock, ctx.device_id));
         break;
       }
       case kWarpSize: {
-        value = 64;
+        ROCM_CALL(hipDeviceGetAttribute(&value, hipDeviceAttributeWarpSize,
+                                        ctx.device_id));
+        break;
+      }
+      case kMaxSharedMemoryPerBlock: {
+        ROCM_CALL(hipDeviceGetAttribute(
+            &value, hipDeviceAttributeMaxSharedMemoryPerBlock, ctx.device_id));
+        break;
+      }
+      case kComputeVersion: {
+        std::ostringstream os;
+        ROCM_CALL(hipDeviceGetAttribute(
+            &value, hipDeviceAttributeComputeCapabilityMajor, ctx.device_id));
+        os << value << ".";
+        ROCM_CALL(hipDeviceGetAttribute(
+            &value, hipDeviceAttributeComputeCapabilityMinor, ctx.device_id));
+        os << value;
+        *rv = os.str();
+        return;
+      }
+      case kDeviceName:
+        return;
+      case kMaxClockRate: {
+        ROCM_CALL(hipDeviceGetAttribute(&value, hipDeviceAttributeClockRate,
                                        ctx.device_id));
         break;
       }
-      case kMaxSharedMemoryPerBlock: return;
-      case kComputeVersion:
-      case kDeviceName: return;
-      case kMaxClockRate: return;
-      case kMultiProcessorCount: return;
-      case kMaxThreadDimensions: return;
+      case kMultiProcessorCount: {
+        ROCM_CALL(hipDeviceGetAttribute(
+            &value, hipDeviceAttributeMultiprocessorCount, ctx.device_id));
+        break;
+      }
+      case kMaxThreadDimensions: {
+        int dims[3];
+        ROCM_CALL(hipDeviceGetAttribute(
+            &dims[0], hipDeviceAttributeMaxBlockDimX, ctx.device_id));
+        ROCM_CALL(hipDeviceGetAttribute(
+            &dims[1], hipDeviceAttributeMaxBlockDimY, ctx.device_id));
+        ROCM_CALL(hipDeviceGetAttribute(
+            &dims[2], hipDeviceAttributeMaxBlockDimZ, ctx.device_id));
+
+        std::stringstream ss;
+        ss << "[" << dims[0] << ", " << dims[1] << ", " << dims[2] << "]";
+        *rv = ss.str();
+        return;
+      }
       case kGcnArch: {
         hipDeviceProp_t prop;
         ROCM_CALL(hipGetDeviceProperties(&prop, ctx.device_id));
         *rv = prop.gcnArch;
         return;
       }
+    *rv = value;
+  }
+  void* AllocDataSpace(TVMContext ctx, size_t nbytes, size_t alignment,
+                       TVMType type_hint) final {
+    ROCM_CALL(hipSetDevice(ctx.device_id));
+    CHECK_EQ(256 % alignment, 0U) << "ROCM space is aligned at 256 bytes";
+    void* ret;
+    ROCM_CALL(hipMalloc(&ret, nbytes));
+    return ret;
   }
-    *rv = value;
-  }
-  void* AllocDataSpace(TVMContext ctx,
-                       size_t nbytes,
-                       size_t alignment,
-                       TVMType type_hint) final {
-    ROCM_CALL(hipSetDevice(ctx.device_id));
-    CHECK_EQ(256 % alignment, 0U)
-        << "ROCM space is aligned at 256 bytes";
-    void *ret;
-    ROCM_CALL(hipMalloc(&ret, nbytes));
-    return ret;
-  }
 
-  void FreeDataSpace(TVMContext ctx, void* ptr) final {
-    ROCM_CALL(hipSetDevice(ctx.device_id));
-    ROCM_CALL(hipFree(ptr));
-  }
+  void FreeDataSpace(TVMContext ctx, void* ptr) final {
+    ROCM_CALL(hipSetDevice(ctx.device_id));
+    ROCM_CALL(hipFree(ptr));
+  }
 
-  void CopyDataFromTo(const void* from,
-                      size_t from_offset,
-                      void* to,
-                      size_t to_offset,
-                      size_t size,
-                      TVMContext ctx_from,
-                      TVMContext ctx_to,
-                      TVMType type_hint,
-                      TVMStreamHandle stream) final {
-    hipStream_t hip_stream = static_cast<hipStream_t>(stream);
-    from = static_cast<const char*>(from) + from_offset;
-    to = static_cast<char*>(to) + to_offset;
-    if (ctx_from.device_type == kDLROCM && ctx_to.device_type == kDLROCM) {
-      ROCM_CALL(hipSetDevice(ctx_from.device_id));
-      if (ctx_from.device_id == ctx_to.device_id) {
-        GPUCopy(from, to, size, hipMemcpyDeviceToDevice, hip_stream);
+  void CopyDataFromTo(const void* from, size_t from_offset, void* to,
+                      size_t to_offset, size_t size, TVMContext ctx_from,
+                      TVMContext ctx_to, TVMType type_hint,
+                      TVMStreamHandle stream) final {
+    hipStream_t hip_stream = static_cast<hipStream_t>(stream);
+    from = static_cast<const char*>(from) + from_offset;
+    to = static_cast<char*>(to) + to_offset;
+    if (ctx_from.device_type == kDLROCM && ctx_to.device_type == kDLROCM) {
+      ROCM_CALL(hipSetDevice(ctx_from.device_id));
+      if (ctx_from.device_id == ctx_to.device_id) {
+        GPUCopy(from, to, size, hipMemcpyDeviceToDevice, hip_stream);
+      } else {
+        hipMemcpyPeerAsync(to, ctx_to.device_id, from, ctx_from.device_id,
+                           size, hip_stream);
+      }
+    } else if (ctx_from.device_type == kDLROCM &&
+               ctx_to.device_type == kDLCPU) {
+      ROCM_CALL(hipSetDevice(ctx_from.device_id));
+      GPUCopy(from, to, size, hipMemcpyDeviceToHost, hip_stream);
+    } else if (ctx_from.device_type == kDLCPU &&
+               ctx_to.device_type == kDLROCM) {
+      ROCM_CALL(hipSetDevice(ctx_to.device_id));
+      GPUCopy(from, to, size, hipMemcpyHostToDevice, hip_stream);
     } else {
-      hipMemcpyPeerAsync(to, ctx_to.device_id,
-                         from, ctx_from.device_id,
-                         size, hip_stream);
+      LOG(FATAL) << "expect copy from/to GPU or between GPU";
     }
-    } else if (ctx_from.device_type == kDLROCM && ctx_to.device_type == kDLCPU) {
-      ROCM_CALL(hipSetDevice(ctx_from.device_id));
-      GPUCopy(from, to, size, hipMemcpyDeviceToHost, hip_stream);
-    } else if (ctx_from.device_type == kDLCPU && ctx_to.device_type == kDLROCM) {
-      ROCM_CALL(hipSetDevice(ctx_to.device_id));
-      GPUCopy(from, to, size, hipMemcpyHostToDevice, hip_stream);
-    } else {
-      LOG(FATAL) << "expect copy from/to GPU or between GPU";
     }
-  }
-
-  void StreamSync(TVMContext ctx, TVMStreamHandle stream) final {
-    ROCM_CALL(hipSetDevice(ctx.device_id));
-    ROCM_CALL(hipStreamSynchronize(static_cast<hipStream_t>(stream)));
-  }
 
-  void SetStream(TVMContext ctx, TVMStreamHandle stream) final {
-    ROCMThreadEntry::ThreadLocal()
-        ->stream = static_cast<hipStream_t>(stream);
-  }
+  void StreamSync(TVMContext ctx, TVMStreamHandle stream) final {
+    ROCM_CALL(hipSetDevice(ctx.device_id));
+    ROCM_CALL(hipStreamSynchronize(static_cast<hipStream_t>(stream)));
+  }
 
-  void* AllocWorkspace(TVMContext ctx, size_t size, TVMType type_hint) final {
-    return ROCMThreadEntry::ThreadLocal()->pool.AllocWorkspace(ctx, size);
-  }
+  void SetStream(TVMContext ctx, TVMStreamHandle stream) final {
+    ROCMThreadEntry::ThreadLocal()->stream = static_cast<hipStream_t>(stream);
+  }
 
-  void FreeWorkspace(TVMContext ctx, void* data) final {
-    ROCMThreadEntry::ThreadLocal()->pool.FreeWorkspace(ctx, data);
-  }
+  void* AllocWorkspace(TVMContext ctx, size_t size, TVMType type_hint) final {
+    return ROCMThreadEntry::ThreadLocal()->pool.AllocWorkspace(ctx, size);
+  }
 
-  static const std::shared_ptr<ROCMDeviceAPI>& Global() {
-    static std::shared_ptr<ROCMDeviceAPI> inst =
-        std::make_shared<ROCMDeviceAPI>();
-    return inst;
-  }
+  void FreeWorkspace(TVMContext ctx, void* data) final {
+    ROCMThreadEntry::ThreadLocal()->pool.FreeWorkspace(ctx, data);
+  }
 
- private:
-  static void GPUCopy(const void* from,
-                      void* to,
-                      size_t size,
-                      hipMemcpyKind kind,
-                      hipStream_t stream) {
-    if (stream != 0) {
-      ROCM_CALL(hipMemcpyAsync(to, from, size, kind, stream));
-    } else {
-      ROCM_CALL(hipMemcpy(to, from, size, kind));
+  static const std::shared_ptr<ROCMDeviceAPI>& Global() {
+    static std::shared_ptr<ROCMDeviceAPI> inst =
+        std::make_shared<ROCMDeviceAPI>();
+    return inst;
   }
-  }
-};
 
-typedef dmlc::ThreadLocalStore<ROCMThreadEntry> ROCMThreadStore;
+ private:
+  static void GPUCopy(const void* from, void* to, size_t size,
+                      hipMemcpyKind kind, hipStream_t stream) {
+    if (stream != 0) {
+      ROCM_CALL(hipMemcpyAsync(to, from, size, kind, stream));
+    } else {
+      ROCM_CALL(hipMemcpy(to, from, size, kind));
+    }
+  }
+};
 
-ROCMThreadEntry::ROCMThreadEntry()
-    : pool(kDLROCM, ROCMDeviceAPI::Global()) {
-}
+typedef dmlc::ThreadLocalStore<ROCMThreadEntry> ROCMThreadStore;
 
-ROCMThreadEntry* ROCMThreadEntry::ThreadLocal() {
-  return ROCMThreadStore::Get();
-}
+ROCMThreadEntry::ROCMThreadEntry() : pool(kDLROCM, ROCMDeviceAPI::Global()) {}
 
-TVM_REGISTER_GLOBAL("device_api.rocm")
-.set_body([](TVMArgs args, TVMRetValue* rv) {
-    DeviceAPI* ptr = ROCMDeviceAPI::Global().get();
-    *rv = static_cast<void*>(ptr);
-  });
+ROCMThreadEntry* ROCMThreadEntry::ThreadLocal() {
+  return ROCMThreadStore::Get();
+}
 
+TVM_REGISTER_GLOBAL("device_api.rocm")
+    .set_body([](TVMArgs args, TVMRetValue* rv) {
+      DeviceAPI* ptr = ROCMDeviceAPI::Global().get();
+      *rv = static_cast<void*>(ptr);
+    });
 } // namespace runtime
 } // namespace tvm
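
For context (not part of this commit), below is a minimal standalone HIP sketch of the hipDeviceGetAttribute query pattern the patch switches to, replacing the previously hardcoded values. The file name query_attrs.cc, the HIP_CHECK macro, and the choice of attributes queried are illustrative assumptions; it should build with hipcc on a working ROCm install.

// query_attrs.cc -- hypothetical standalone example, build: hipcc query_attrs.cc -o query_attrs
#include <hip/hip_runtime_api.h>

#include <cstdio>
#include <cstdlib>

// Local error-check helper, analogous in spirit to TVM's ROCM_CALL macro.
#define HIP_CHECK(cmd)                                                 \
  do {                                                                 \
    hipError_t err = (cmd);                                            \
    if (err != hipSuccess) {                                           \
      std::fprintf(stderr, "HIP error: %s\n", hipGetErrorString(err)); \
      std::exit(1);                                                    \
    }                                                                  \
  } while (0)

int main() {
  int device_id = 0;
  int value = 0;

  // Was hardcoded to 1024 in the old TVM runtime; now queried per device.
  HIP_CHECK(hipDeviceGetAttribute(&value, hipDeviceAttributeMaxThreadsPerBlock,
                                  device_id));
  std::printf("max threads per block: %d\n", value);

  // Was hardcoded to 64; the actual wavefront size can differ between GPUs.
  HIP_CHECK(hipDeviceGetAttribute(&value, hipDeviceAttributeWarpSize, device_id));
  std::printf("warp (wavefront) size: %d\n", value);

  int major = 0, minor = 0;
  HIP_CHECK(hipDeviceGetAttribute(&major, hipDeviceAttributeComputeCapabilityMajor,
                                  device_id));
  HIP_CHECK(hipDeviceGetAttribute(&minor, hipDeviceAttributeComputeCapabilityMinor,
                                  device_id));
  std::printf("compute version: %d.%d\n", major, minor);
  return 0;
}

Within TVM itself these values are surfaced through ROCMDeviceAPI::GetAttr and the "device_api.rocm" global registered in the diff above, rather than through a standalone program like this.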
