Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 8 additions & 2 deletions python/tvm/contrib/nvcc.py
Original file line number Diff line number Diff line change
Expand Up @@ -291,8 +291,14 @@ def get_target_compute_version(target=None):
# 2. Target.current()
target = target or Target.current()
if target and target.arch:
major, minor = target.arch.split("_")[1]
return major + "." + minor
arch = target.arch.split("_")[1]
if len(arch) == 2:
major, minor = arch
return major + "." + minor
elif len(arch) == 3:
# This is for arch like "sm_90a"
major, minor, suffix = arch
return major + "." + minor + "." + suffix

# 3. GPU compute version
if tvm.cuda(0).exist:
Expand Down
2 changes: 1 addition & 1 deletion src/target/tag.cc
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,7 @@ TVM_REGISTER_CUDA_TAG("nvidia/tesla-c2050", "sm_20", 49152, 32768);
TVM_REGISTER_CUDA_TAG("nvidia/tesla-c2070", "sm_20", 49152, 32768);
TVM_REGISTER_CUDA_TAG("nvidia/nvidia-a100", "sm_80", 49152, 65536)
.with_config("l2_cache_size_bytes", Integer(41943040));
TVM_REGISTER_CUDA_TAG("nvidia/nvidia-h100", "sm_90", 49152, 65536)
TVM_REGISTER_CUDA_TAG("nvidia/nvidia-h100", "sm_90a", 49152, 65536)
.with_config("l2_cache_size_bytes", Integer(52428800));
TVM_REGISTER_CUDA_TAG("nvidia/nvidia-a40", "sm_86", 49152, 65536);
TVM_REGISTER_CUDA_TAG("nvidia/nvidia-a30", "sm_80", 49152, 65536);
Expand Down
37 changes: 7 additions & 30 deletions src/target/target_kind.cc
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
#include <algorithm>

#include "../node/attr_registry.h"
#include "../support/utils.h"
#include "./parsers/cpu.h"

namespace tvm {
Expand Down Expand Up @@ -81,30 +82,6 @@ Optional<TargetKind> TargetKind::Get(const String& target_kind_name) {

/********** Utility functions **********/

/*!
* \brief Extract a number from the string with the given prefix.
* For example, when `str` is "sm_20" and `prefix` is "sm_".
* This function first checks if `str` starts with `prefix`,
* then return the integer 20 after the `prefix`
* \param str The string to be extracted
* \param prefix The prefix to be checked
* \return An integer, the extracted number. -1 if the check fails
*/
static int ExtractIntWithPrefix(const std::string& str, const std::string& prefix) {
if (str.substr(0, prefix.size()) != prefix) {
return -1;
}
int result = 0;
for (size_t i = prefix.size(); i < str.size(); ++i) {
char c = str[i];
if (!isdigit(c)) {
return -1;
}
result = result * 10 + c - '0';
}
return result;
}

/*!
* \brief Extract a string from the string with the given prefix.
* For example, when `str` is "sm_20" and `prefix` is "sm_".
Expand Down Expand Up @@ -168,14 +145,14 @@ void CheckOrSetAttr(Map<String, ObjectRef>* attrs, const String& name, const Str
*/
TargetJSON UpdateCUDAAttrs(TargetJSON target) {
// Update -arch=sm_xx
int archInt;
if (target.count("arch")) {
// If -arch has been specified, validate the correctness
String archStr = Downcast<String>(target.at("arch"));
archInt = ExtractIntWithPrefix(archStr, "sm_");
ICHECK(archInt != -1) << "ValueError: CUDA target gets an invalid CUDA arch: -arch=" << archStr;
ICHECK(support::StartsWith(archStr, "sm_"))
<< "ValueError: CUDA target gets an invalid CUDA arch: -arch=" << archStr;
} else {
// Use the compute version of the first CUDA GPU instead
int archInt;
TVMRetValue version;
if (!DetectDeviceFlag({kDLCUDA, 0}, runtime::kComputeVersion, &version)) {
LOG(WARNING) << "Unable to detect CUDA version, default to \"-arch=sm_50\" instead";
Expand All @@ -196,14 +173,14 @@ TargetJSON UpdateCUDAAttrs(TargetJSON target) {
TargetJSON UpdateNVPTXAttrs(TargetJSON target) {
CheckOrSetAttr(&target, "mtriple", "nvptx64-nvidia-cuda");
// Update -mcpu=sm_xx
int arch;
if (target.count("mcpu")) {
// If -mcpu has been specified, validate the correctness
String mcpu = Downcast<String>(target.at("mcpu"));
arch = ExtractIntWithPrefix(mcpu, "sm_");
ICHECK(arch != -1) << "ValueError: NVPTX target gets an invalid CUDA arch: -mcpu=" << mcpu;
ICHECK(support::StartsWith(mcpu, "sm_"))
<< "ValueError: NVPTX target gets an invalid CUDA arch: -mcpu=" << mcpu;
} else {
// Use the compute version of the first CUDA GPU instead
int arch;
TVMRetValue version;
if (!DetectDeviceFlag({kDLCUDA, 0}, runtime::kComputeVersion, &version)) {
LOG(WARNING) << "Unable to detect CUDA version, default to \"-mcpu=sm_50\" instead";
Expand Down