Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
117 changes: 70 additions & 47 deletions src/runtime/rocm/rocm_device_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,9 @@
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
*
* http://www.apache.org/licenses/LICENSE-2.0
*
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
Expand All @@ -22,14 +22,13 @@
* \file rocm_device_api.cc
* \brief GPU specific API
*/
#include <tvm/runtime/device_api.h>

#include <dmlc/logging.h>
#include <dmlc/thread_local.h>
#include <hip/hip_runtime_api.h>
#include <hsa/hsa.h>
#include <tvm/runtime/device_api.h>
#include <tvm/runtime/registry.h>
#include "../../../include/tvm/runtime/device_api.h"

#include "rocm_common.h"

namespace tvm {
Expand All @@ -55,19 +54,57 @@ class ROCMDeviceAPI final : public DeviceAPI {
break;
}
case kMaxThreadsPerBlock: {
value = 1024;
ROCM_CALL(hipDeviceGetAttribute(
&value, hipDeviceAttributeMaxThreadsPerBlock, ctx.device_id));
break;
}
case kWarpSize: {
value = 64;
ROCM_CALL(hipDeviceGetAttribute(&value, hipDeviceAttributeWarpSize,
ctx.device_id));
break;
}
case kMaxSharedMemoryPerBlock: return;
case kComputeVersion:
case kDeviceName: return;
case kMaxClockRate: return;
case kMultiProcessorCount: return;
case kMaxThreadDimensions: return;
case kMaxSharedMemoryPerBlock: {
ROCM_CALL(hipDeviceGetAttribute(
&value, hipDeviceAttributeMaxSharedMemoryPerBlock, ctx.device_id));
break;
}
case kComputeVersion: {
std::ostringstream os;
ROCM_CALL(hipDeviceGetAttribute(
&value, hipDeviceAttributeComputeCapabilityMajor, ctx.device_id));
os << value << ".";
ROCM_CALL(hipDeviceGetAttribute(
&value, hipDeviceAttributeComputeCapabilityMinor, ctx.device_id));
os << value;
*rv = os.str();
return;
}
case kDeviceName:
return;
case kMaxClockRate: {
ROCM_CALL(hipDeviceGetAttribute(&value, hipDeviceAttributeClockRate,
ctx.device_id));
break;
}
case kMultiProcessorCount: {
ROCM_CALL(hipDeviceGetAttribute(
&value, hipDeviceAttributeMultiprocessorCount, ctx.device_id));
break;
}
case kMaxThreadDimensions: {
int dims[3];
ROCM_CALL(hipDeviceGetAttribute(
&dims[0], hipDeviceAttributeMaxBlockDimX, ctx.device_id));
ROCM_CALL(hipDeviceGetAttribute(
&dims[1], hipDeviceAttributeMaxBlockDimY, ctx.device_id));
ROCM_CALL(hipDeviceGetAttribute(
&dims[2], hipDeviceAttributeMaxBlockDimZ, ctx.device_id));

std::stringstream ss;
ss << "[" << dims[0] << ", " << dims[1] << ", " << dims[2] << "]";
*rv = ss.str();
return;
}
case kGcnArch: {
hipDeviceProp_t prop;
ROCM_CALL(hipGetDeviceProperties(&prop, ctx.device_id));
Expand All @@ -77,14 +114,11 @@ class ROCMDeviceAPI final : public DeviceAPI {
}
*rv = value;
}
void* AllocDataSpace(TVMContext ctx,
size_t nbytes,
size_t alignment,
void* AllocDataSpace(TVMContext ctx, size_t nbytes, size_t alignment,
TVMType type_hint) final {
ROCM_CALL(hipSetDevice(ctx.device_id));
CHECK_EQ(256 % alignment, 0U)
<< "ROCM space is aligned at 256 bytes";
void *ret;
CHECK_EQ(256 % alignment, 0U) << "ROCM space is aligned at 256 bytes";
void* ret;
ROCM_CALL(hipMalloc(&ret, nbytes));
return ret;
}
Expand All @@ -94,14 +128,9 @@ class ROCMDeviceAPI final : public DeviceAPI {
ROCM_CALL(hipFree(ptr));
}

void CopyDataFromTo(const void* from,
size_t from_offset,
void* to,
size_t to_offset,
size_t size,
TVMContext ctx_from,
TVMContext ctx_to,
TVMType type_hint,
void CopyDataFromTo(const void* from, size_t from_offset, void* to,
size_t to_offset, size_t size, TVMContext ctx_from,
TVMContext ctx_to, TVMType type_hint,
TVMStreamHandle stream) final {
hipStream_t hip_stream = static_cast<hipStream_t>(stream);
from = static_cast<const char*>(from) + from_offset;
Expand All @@ -111,14 +140,15 @@ class ROCMDeviceAPI final : public DeviceAPI {
if (ctx_from.device_id == ctx_to.device_id) {
GPUCopy(from, to, size, hipMemcpyDeviceToDevice, hip_stream);
} else {
hipMemcpyPeerAsync(to, ctx_to.device_id,
from, ctx_from.device_id,
size, hip_stream);
hipMemcpyPeerAsync(to, ctx_to.device_id, from, ctx_from.device_id, size,
hip_stream);
}
} else if (ctx_from.device_type == kDLROCM && ctx_to.device_type == kDLCPU) {
} else if (ctx_from.device_type == kDLROCM &&
ctx_to.device_type == kDLCPU) {
ROCM_CALL(hipSetDevice(ctx_from.device_id));
GPUCopy(from, to, size, hipMemcpyDeviceToHost, hip_stream);
} else if (ctx_from.device_type == kDLCPU && ctx_to.device_type == kDLROCM) {
} else if (ctx_from.device_type == kDLCPU &&
ctx_to.device_type == kDLROCM) {
ROCM_CALL(hipSetDevice(ctx_to.device_id));
GPUCopy(from, to, size, hipMemcpyHostToDevice, hip_stream);
} else {
Expand All @@ -132,8 +162,7 @@ class ROCMDeviceAPI final : public DeviceAPI {
}

/*!
 * \brief Store the given stream in the entry returned by
 *        ROCMThreadEntry::ThreadLocal().
 * \param ctx Not used by this implementation.
 * \param stream Handle that is cast to hipStream_t before being stored.
 */
void SetStream(TVMContext ctx, TVMStreamHandle stream) final {
  ROCMThreadEntry* entry = ROCMThreadEntry::ThreadLocal();
  entry->stream = static_cast<hipStream_t>(stream);
}

void* AllocWorkspace(TVMContext ctx, size_t size, TVMType type_hint) final {
Expand All @@ -151,11 +180,8 @@ class ROCMDeviceAPI final : public DeviceAPI {
}

private:
static void GPUCopy(const void* from,
void* to,
size_t size,
hipMemcpyKind kind,
hipStream_t stream) {
static void GPUCopy(const void* from, void* to, size_t size,
hipMemcpyKind kind, hipStream_t stream) {
if (stream != 0) {
ROCM_CALL(hipMemcpyAsync(to, from, size, kind, stream));
} else {
Expand All @@ -166,19 +192,16 @@ class ROCMDeviceAPI final : public DeviceAPI {

/*! \brief dmlc thread-local store specialised for ROCMThreadEntry. */
using ROCMThreadStore = dmlc::ThreadLocalStore<ROCMThreadEntry>;

/*!
 * \brief Default constructor: initialises `pool` with the kDLROCM device type
 *        and the ROCMDeviceAPI::Global() instance.
 */
ROCMThreadEntry::ROCMThreadEntry()
    : pool(kDLROCM, ROCMDeviceAPI::Global()) {
}

/*!
 * \brief Fetch the ROCMThreadEntry held by ROCMThreadStore.
 * \return The pointer returned by ROCMThreadStore::Get().
 */
ROCMThreadEntry* ROCMThreadEntry::ThreadLocal() {
  ROCMThreadEntry* entry = ROCMThreadStore::Get();
  return entry;
}

// Registers a global function named "device_api.rocm". The callback ignores
// its arguments and returns the ROCMDeviceAPI::Global() instance as an
// untyped pointer.
TVM_REGISTER_GLOBAL("device_api.rocm")
    .set_body([](TVMArgs args, TVMRetValue* rv) {
      // Convert to the DeviceAPI base pointer first, then erase the type;
      // this keeps the two-step conversion order of the original code.
      DeviceAPI* api = ROCMDeviceAPI::Global().get();
      void* handle = static_cast<void*>(api);
      *rv = handle;
    });
} // namespace runtime
} // namespace tvm