diff --git a/ffi/CMakeLists.txt b/ffi/CMakeLists.txt index ce4f4d4e208a..b5823b76a7c9 100644 --- a/ffi/CMakeLists.txt +++ b/ffi/CMakeLists.txt @@ -73,6 +73,7 @@ if (TVM_FFI_USE_EXTRA_CXX_API) "${CMAKE_CURRENT_SOURCE_DIR}/src/ffi/extra/library_module.cc" "${CMAKE_CURRENT_SOURCE_DIR}/src/ffi/extra/library_module_system_lib.cc" "${CMAKE_CURRENT_SOURCE_DIR}/src/ffi/extra/library_module_dynamic_lib.cc" + "${CMAKE_CURRENT_SOURCE_DIR}/src/ffi/extra/stream_context.cc" ) endif() diff --git a/ffi/include/tvm/ffi/extra/c_env_api.h b/ffi/include/tvm/ffi/extra/c_env_api.h index 5d5d908f78ba..1211ab0eeb1b 100644 --- a/ffi/include/tvm/ffi/extra/c_env_api.h +++ b/ffi/include/tvm/ffi/extra/c_env_api.h @@ -29,6 +29,39 @@ extern "C" { #endif +// ---------------------------------------------------------------------------- +// Stream context +// Focusing on minimalistic thread-local context recording stream being used. +// We explicitly not handle allocation/de-allocation of stream here. +// ---------------------------------------------------------------------------- +typedef void* TVMFFIStreamHandle; + +/*! + * \brief FFI function to set the current stream for a device + * + * \param device_type The type of the device. + * \param device_id The id of the device. + * \param stream The stream to set. + * \param opt_out_original_stream Output original stream if the address is not nullptr. + * \note The stream is a weak reference that is cached/owned by the module. + * \return 0 when success, nonzero when failure happens + */ +TVM_FFI_DLL int TVMFFIEnvSetStream(int32_t device_type, int32_t device_id, + TVMFFIStreamHandle stream, + TVMFFIStreamHandle* opt_out_original_stream); + +/*! + * \brief FFI function to get the current stream for a device + * + * \param device_type The type of the device. + * \param device_id The id of the device. + * \return The current stream of the device. + */ +TVM_FFI_DLL TVMFFIStreamHandle TVMFFIEnvGetCurrentStream(int32_t device_type, int32_t device_id); + +// ---------------------------------------------------------------------------- +// Module symbol management +// ---------------------------------------------------------------------------- /*! * \brief FFI function to lookup a function from a module's imports. * diff --git a/ffi/src/ffi/extra/stream_context.cc b/ffi/src/ffi/extra/stream_context.cc new file mode 100644 index 000000000000..d063efdef579 --- /dev/null +++ b/ffi/src/ffi/extra/stream_context.cc @@ -0,0 +1,81 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +/* + * \file src/ffi/extra/stream_context.cc + * + * \brief A minimalistic stream context based on ffi values. + */ + +#include +#include + +#include + +namespace tvm { +namespace ffi { + +class StreamContext { + public: + void SetStream(int32_t device_type, int32_t device_id, TVMFFIStreamHandle stream, + TVMFFIStreamHandle* out_original_stream) { + if (static_cast(device_type) >= stream_table_.size()) { + stream_table_.resize(device_type + 1); + } + if (static_cast(device_id) >= stream_table_[device_type].size()) { + stream_table_[device_type].resize(device_id + 1, nullptr); + } + if (out_original_stream != nullptr) { + *out_original_stream = stream_table_[device_type][device_id]; + } + stream_table_[device_type][device_id] = stream; + } + + TVMFFIStreamHandle GetStream(int32_t device_type, int32_t device_id) { + if (static_cast(device_type) < stream_table_.size() && + static_cast(device_id) < stream_table_[device_type].size()) { + return stream_table_[device_type][device_id]; + } + return nullptr; + } + + static StreamContext* ThreadLocal() { + static thread_local StreamContext inst; + return &inst; + } + + private: + std::vector> stream_table_; +}; + +} // namespace ffi +} // namespace tvm + +int TVMFFIEnvSetStream(int32_t device_type, int32_t device_id, TVMFFIStreamHandle stream, + TVMFFIStreamHandle* out_original_stream) { + TVM_FFI_SAFE_CALL_BEGIN(); + tvm::ffi::StreamContext::ThreadLocal()->SetStream(device_type, device_id, stream, + out_original_stream); + TVM_FFI_SAFE_CALL_END(); +} + +TVMFFIStreamHandle TVMFFIEnvGetCurrentStream(int32_t device_type, int32_t device_id) { + TVM_FFI_LOG_EXCEPTION_CALL_BEGIN(); + return tvm::ffi::StreamContext::ThreadLocal()->GetStream(device_type, device_id); + TVM_FFI_LOG_EXCEPTION_CALL_END(TVMFFIEnvGetCurrentStream); +} diff --git a/include/tvm/runtime/device_api.h b/include/tvm/runtime/device_api.h index 7366b9895d5e..f14b22c57628 100644 --- a/include/tvm/runtime/device_api.h +++ b/include/tvm/runtime/device_api.h @@ -225,7 +225,7 @@ class TVM_DLL DeviceAPI { * \param dev The device to set stream. * \param stream The stream to be set. */ - virtual void SetStream(Device dev, TVMStreamHandle stream) {} + virtual void SetStream(Device dev, TVMStreamHandle stream); /*! * \brief Get the current stream * \param dev The device to get stream. diff --git a/python/tvm/contrib/cutlass/attention_operation.py b/python/tvm/contrib/cutlass/attention_operation.py index 4c876142d3d0..fe29cd59459b 100644 --- a/python/tvm/contrib/cutlass/attention_operation.py +++ b/python/tvm/contrib/cutlass/attention_operation.py @@ -147,8 +147,7 @@ def instantiate_attention_template(attrs): } CHECK(Attention::check_supported(p)); - auto func = tvm::ffi::Function::GetGlobalRequired("runtime.get_cuda_stream"); - cudaStream_t stream = static_cast(func().cast()); + cudaStream_t stream = static_cast(TVMFFIEnvGetCurrentStream(kDLCUDA, ${query}->device.device_id)); kernel_fn<<>>(p); @@ -186,8 +185,7 @@ def instantiate_flash_attention_template(attrs): int v_batch_stride = v_row_stride * ${num_keys}; int o_batch_stride = o_row_stride * ${num_queries}; - auto func = tvm::ffi::Function::GetGlobalRequired("runtime.get_cuda_stream"); - cudaStream_t stream = static_cast(func().cast()); + cudaStream_t stream = static_cast(TVMFFIEnvGetCurrentStream(kDLCUDA, ${query}->device.device_id)); flash_attn::flash_attention_forward( static_cast(${query}->data), @@ -237,8 +235,7 @@ def instantiate_flash_attention_template(attrs): int v_batch_stride = v_row_stride * ${num_keys}; int o_batch_stride = o_row_stride * ${num_queries}; - auto func = tvm::ffi::Function::GetGlobalRequired("runtime.get_cuda_stream"); - cudaStream_t stream = static_cast(func().cast()); + cudaStream_t stream = static_cast(TVMFFIEnvGetCurrentStream(kDLCUDA, ${query}->device.device_id)); flash_attn::flash_attention_forward( static_cast(${qkv}->data), @@ -294,8 +291,7 @@ def instantiate_flash_attention_var_len_template(attrs): int v_row_stride = v_head_stride * ${num_kv_heads}; int o_row_stride = o_head_stride * ${num_q_heads}; - auto func = tvm::ffi::Function::GetGlobalRequired("runtime.get_cuda_stream"); - cudaStream_t stream = static_cast(func().cast()); + cudaStream_t stream = static_cast(TVMFFIEnvGetCurrentStream(kDLCUDA, ${query}->device.device_id)); flash_attn::flash_attention_var_len_forward( static_cast(${query}->data), diff --git a/python/tvm/contrib/cutlass/conv2d_operation.py b/python/tvm/contrib/cutlass/conv2d_operation.py index 361bcb54e532..b0afdcdd6e84 100644 --- a/python/tvm/contrib/cutlass/conv2d_operation.py +++ b/python/tvm/contrib/cutlass/conv2d_operation.py @@ -424,8 +424,7 @@ def instantiate_conv2d_template(attrs): TVM_FFI_ICHECK(status == cutlass::Status::kSuccess); ${split_k_update} - auto func = tvm::ffi::Function::GetGlobalRequired("runtime.get_cuda_stream"); - cudaStream_t stream = static_cast(func().cast()); + cudaStream_t stream = static_cast(TVMFFIEnvGetCurrentStream(kDLCUDA, ${data_arg}->device.device_id)); status = conv2d_op(stream); TVM_FFI_ICHECK(status == cutlass::Status::kSuccess); diff --git a/python/tvm/contrib/cutlass/gemm_operation.py b/python/tvm/contrib/cutlass/gemm_operation.py index 65dc5da772c1..453839cc8130 100644 --- a/python/tvm/contrib/cutlass/gemm_operation.py +++ b/python/tvm/contrib/cutlass/gemm_operation.py @@ -345,8 +345,7 @@ def instantiate_gemm_template(attrs): status = gemm_op.initialize(arguments, workspace.get()); TVM_FFI_ICHECK(status == cutlass::Status::kSuccess); - auto func = tvm::ffi::Function::GetGlobalRequired("runtime.get_cuda_stream"); - cudaStream_t stream = static_cast(func().cast()); + cudaStream_t stream = static_cast(TVMFFIEnvGetCurrentStream(kDLCUDA, ${A_arg}->device.device_id)); status = gemm_op(stream); TVM_FFI_ICHECK(status == cutlass::Status::kSuccess); @@ -428,8 +427,8 @@ def emit_fp16A_intB_matmul(attrs): int n = ${B_arg}->shape[1] * ${float_per_int}; int k = ${B_arg}->shape[0]; - auto func = tvm::ffi::Function::GetGlobalRequired("runtime.get_cuda_stream"); - cudaStream_t stream = static_cast(func().cast()); + cudaStream_t stream = static_cast( + TVMFFIEnvGetCurrentStream(kDLCUDA, ${A_arg}->device.device_id)); """, attrs, ) @@ -447,12 +446,14 @@ def emit_fp16A_intB_matmul(attrs): template_residual = """ ${template_common} - gemm_fp16_int_bias_act_residual<${weight_dtype}, QuantOp>(static_cast(${A_arg}->data), + gemm_fp16_int_bias_act_residual<${weight_dtype}, QuantOp>( + static_cast(${A_arg}->data), static_cast<${weight_dtype}*>(${B_arg}->data), static_cast(${scales_arg}->data), ${bias}, static_cast(${residual_arg}->data), - static_cast(out0->data), "${activation}", "${binary_op}", "${unary_op}", + static_cast(out0->data), + "${activation}", "${binary_op}", "${unary_op}", m, n, k, ${group_size}, nullptr, 0, stream); """ diff --git a/python/tvm/contrib/cutlass/gen_tensor_op.py b/python/tvm/contrib/cutlass/gen_tensor_op.py index 6fa349b28e44..c594b3897a6c 100644 --- a/python/tvm/contrib/cutlass/gen_tensor_op.py +++ b/python/tvm/contrib/cutlass/gen_tensor_op.py @@ -487,7 +487,7 @@ def instantiate_template(func_name, annotations, func_args): if k in annotations: attrs[k] = annotations[k] - headers = ["tvm/ffi/function.h"] + headers = ["tvm/ffi/function.h", "tvm/ffi/extra/c_env_api.h"] if "relu" in func_name: headers.append("cutlass/epilogue/thread/linear_combination_bias_relu.h") diff --git a/python/tvm/contrib/cutlass/layer_norm_operation.py b/python/tvm/contrib/cutlass/layer_norm_operation.py index 74f397b39ad3..d2a031024475 100644 --- a/python/tvm/contrib/cutlass/layer_norm_operation.py +++ b/python/tvm/contrib/cutlass/layer_norm_operation.py @@ -39,8 +39,7 @@ def instantiate_layer_norm_template(attrs): cutlass::TensorRef _beta((data_type*)${beta}->data, layout_channels); cutlass::TensorRef _output((data_type*)out0->data, layout_2D); - auto func = tvm::ffi::Function::GetGlobalRequired("runtime.get_cuda_stream"); - cudaStream_t stream = static_cast(func().cast()); + cudaStream_t stream = static_cast(TVMFFIEnvGetCurrentStream(kDLCUDA, ${input}->device.device_id)); cutlass::layernorm(size, _output, _input, _gamma, _beta, stream); """ diff --git a/python/tvm/contrib/cutlass/rms_norm_operation.py b/python/tvm/contrib/cutlass/rms_norm_operation.py index 27e98fb251cf..51c18d4ae47b 100644 --- a/python/tvm/contrib/cutlass/rms_norm_operation.py +++ b/python/tvm/contrib/cutlass/rms_norm_operation.py @@ -38,8 +38,7 @@ def instantiate_rms_norm_template(attrs): cutlass::TensorRef _weight((data_type*)${weight}->data, layout_channels); cutlass::TensorRef _output((data_type*)out0->data, layout_2D); - auto func = tvm::ffi::Function::GetGlobalRequired("runtime.get_cuda_stream"); - cudaStream_t stream = static_cast(func().cast()); + cudaStream_t stream = static_cast(TVMFFIEnvGetCurrentStream(kDLCUDA, ${input}->device.device_id)); cutlass::rmsnorm(size, _output, _input, _weight, stream, ${rms_eps}); """ diff --git a/src/contrib/msc/plugin/tvm_codegen.cc b/src/contrib/msc/plugin/tvm_codegen.cc index a3861aabe75e..7410867aaf25 100644 --- a/src/contrib/msc/plugin/tvm_codegen.cc +++ b/src/contrib/msc/plugin/tvm_codegen.cc @@ -230,6 +230,7 @@ void TVMPluginCodeGen::CodeGenOpRuntime(const Plugin& plugin) { const auto& attr_name = MetaAttrCls(plugin); const auto& func_name = ComputeName(plugin); String device_cond = ""; + String device_index = ""; for (size_t i = 0; i < plugin->inputs.size(); i++) { String device_type = ""; if (plugin->inputs[i]->device == "cuda" || plugin->inputs[i]->device == "default") { @@ -381,7 +382,8 @@ void TVMPluginCodeGen::CodeGenCompute(const Plugin& plugin, const String& device ICHECK(plugin->buffers.size() == 0) << "Plugin with buffers is not supported in tvm"; compute_args.push_back("meta_attr"); if (device == "cuda") { - stack_.assign("stream", "runtime::CUDAThreadEntry::ThreadLocal()->stream", "auto"); + // TODO(tvm-team): update to support get stream from device id + stack_.assign("stream", "TVMFFIEnvGetCurrentStream(kDLCUDA, 0)", "auto"); compute_args.push_back("stream"); } CodeGenSafeCall(plugin->externs[device + "_compute"], compute_args); diff --git a/src/runtime/contrib/cublas/cublas.cc b/src/runtime/contrib/cublas/cublas.cc index d55e0535c228..13f958744e61 100644 --- a/src/runtime/contrib/cublas/cublas.cc +++ b/src/runtime/contrib/cublas/cublas.cc @@ -20,6 +20,7 @@ /*! * \file Use external cblas library call. */ +#include #include #include #include @@ -522,7 +523,7 @@ TVM_FFI_STATIC_INIT_BLOCK({ auto A = args[0].cast(); auto C = args[2].cast(); - CuBlasThreadEntry* entry_ptr = CuBlasThreadEntry::ThreadLocal(); + CuBlasThreadEntry* entry_ptr = CuBlasThreadEntry::ThreadLocal(A->device); CUBLASTryEnableTensorCore(entry_ptr->handle); @@ -549,15 +550,15 @@ TVM_FFI_STATIC_INIT_BLOCK({ "tvm.contrib.cublaslt.matmul", [](ffi::PackedArgs args, ffi::Any* ret) { auto A = args[0].cast(); - CuBlasThreadEntry* entry_ptr = CuBlasThreadEntry::ThreadLocal(); + CuBlasThreadEntry* entry_ptr = CuBlasThreadEntry::ThreadLocal(A->device); CUBLASTryEnableTensorCore(entry_ptr->handle); ICHECK(TypeMatch(A->dtype, kDLInt, 8)) << "Expects dtype to be int8\n"; cublasLtHandle_t ltHandle; CHECK_CUBLAS_ERROR(cublasLtCreate(<Handle)); - auto func = tvm::ffi::Function::GetGlobalRequired("runtime.get_cuda_stream"); - cudaStream_t stream = static_cast(func().cast()); + cudaStream_t stream = + static_cast(TVMFFIEnvGetCurrentStream(kDLCUDA, A->device.device_id)); CallLtIgemm(args, ret, ltHandle, stream); CHECK_CUBLAS_ERROR(cublasLtDestroy(ltHandle)); }); @@ -571,7 +572,7 @@ TVM_FFI_STATIC_INIT_BLOCK({ auto A = args[0].cast(); auto C = args[2].cast(); - CuBlasThreadEntry* entry_ptr = CuBlasThreadEntry::ThreadLocal(); + CuBlasThreadEntry* entry_ptr = CuBlasThreadEntry::ThreadLocal(A->device); CUBLASTryEnableTensorCore(entry_ptr->handle); if (TypeEqual(A->dtype, C->dtype)) { diff --git a/src/runtime/contrib/cublas/cublas_json_runtime.cc b/src/runtime/contrib/cublas/cublas_json_runtime.cc index 11fa3b0c4d49..0416391303ad 100644 --- a/src/runtime/contrib/cublas/cublas_json_runtime.cc +++ b/src/runtime/contrib/cublas/cublas_json_runtime.cc @@ -22,6 +22,7 @@ * \brief A simple JSON runtime for CUBLAS. */ +#include #include #include #include @@ -30,6 +31,7 @@ #include #include +#include "../../cuda/cuda_common.h" #include "../json/json_node.h" #include "../json/json_runtime.h" #include "cublas_utils.h" @@ -67,13 +69,8 @@ class CublasJSONRuntime : public JSONRuntimeBase { const char* kind() const override { return "cublas_json"; } // May be overridden void Run(ffi::PackedArgs args) { - auto* entry_ptr = tvm::contrib::CuBlasLtThreadEntry::ThreadLocal(); - - auto func = tvm::ffi::Function::GetGlobalRequired("runtime.get_cuda_stream"); - cudaStream_t stream = static_cast(func().cast()); - std::vector dl_tensors(NumEntries()); - + int device_id = -1; for (size_t i = 0; i < static_cast(args.size()); i++) { auto eid = i < input_var_eid_.size() ? input_var_eid_[i] : EntryID(outputs_[i - input_var_eid_.size()]); @@ -87,7 +84,14 @@ class CublasJSONRuntime : public JSONRuntimeBase { } dl_tensors[eid] = arg; + device_id = arg->device.device_id; + } + + if (device_id == -1) { + CUDA_CALL(cudaGetDevice(&device_id)); } + auto* entry_ptr = tvm::contrib::CuBlasLtThreadEntry::ThreadLocal(DLDevice{kDLCUDA, device_id}); + cudaStream_t stream = static_cast(TVMFFIEnvGetCurrentStream(kDLCUDA, device_id)); auto get_input = [this, &dl_tensors](const JSONGraphNode& node, int idx) { ICHECK_LT(idx, node.GetInputs().size()); diff --git a/src/runtime/contrib/cublas/cublas_utils.cc b/src/runtime/contrib/cublas/cublas_utils.cc index 53e00fe14199..0ba654c9ebc8 100644 --- a/src/runtime/contrib/cublas/cublas_utils.cc +++ b/src/runtime/contrib/cublas/cublas_utils.cc @@ -23,6 +23,7 @@ #include "cublas_utils.h" #include +#include #include #include "../../cuda/cuda_common.h" @@ -41,10 +42,11 @@ CuBlasThreadEntry::~CuBlasThreadEntry() { typedef dmlc::ThreadLocalStore CuBlasThreadStore; -CuBlasThreadEntry* CuBlasThreadEntry::ThreadLocal() { - auto stream = runtime::CUDAThreadEntry::ThreadLocal()->stream; +CuBlasThreadEntry* CuBlasThreadEntry::ThreadLocal(DLDevice curr_device) { CuBlasThreadEntry* retval = CuBlasThreadStore::Get(); - CHECK_CUBLAS_ERROR(cublasSetStream(retval->handle, static_cast(stream))); + cudaStream_t stream = static_cast( + TVMFFIEnvGetCurrentStream(curr_device.device_type, curr_device.device_id)); + CHECK_CUBLAS_ERROR(cublasSetStream(retval->handle, stream)); return retval; } @@ -71,7 +73,9 @@ CuBlasLtThreadEntry::~CuBlasLtThreadEntry() { typedef dmlc::ThreadLocalStore CuBlasLtThreadStore; -CuBlasLtThreadEntry* CuBlasLtThreadEntry::ThreadLocal() { return CuBlasLtThreadStore::Get(); } +CuBlasLtThreadEntry* CuBlasLtThreadEntry::ThreadLocal(DLDevice curr_device) { + return CuBlasLtThreadStore::Get(); +} } // namespace contrib } // namespace tvm diff --git a/src/runtime/contrib/cublas/cublas_utils.h b/src/runtime/contrib/cublas/cublas_utils.h index 3e9ded08deb1..12260a78ef6b 100644 --- a/src/runtime/contrib/cublas/cublas_utils.h +++ b/src/runtime/contrib/cublas/cublas_utils.h @@ -75,7 +75,7 @@ struct CuBlasThreadEntry { CuBlasThreadEntry(); ~CuBlasThreadEntry(); cublasHandle_t handle{nullptr}; - static CuBlasThreadEntry* ThreadLocal(); + static CuBlasThreadEntry* ThreadLocal(DLDevice curr_device); }; // CuBlasThreadEntry struct CuBlasLtThreadEntry { @@ -89,7 +89,7 @@ struct CuBlasLtThreadEntry { // https://docs.nvidia.com/cuda/cublas/index.html#cublassetworkspace. static constexpr const size_t workspace_size = 33554432; - static CuBlasLtThreadEntry* ThreadLocal(); + static CuBlasLtThreadEntry* ThreadLocal(DLDevice curr_device); }; // CuBlasLtThreadEntry inline cudaDataType_t GetCudaDataType(DLDataType type) { diff --git a/src/runtime/contrib/cudnn/conv_backward.cc b/src/runtime/contrib/cudnn/conv_backward.cc index 915f21bc7ca6..515263ef364e 100644 --- a/src/runtime/contrib/cudnn/conv_backward.cc +++ b/src/runtime/contrib/cudnn/conv_backward.cc @@ -35,7 +35,7 @@ using namespace runtime; void ConvolutionBackwardData(int mode, int format, int algo, int dims, int groups, const int pad[], const int stride[], const int dilation[], DLTensor* dy, DLTensor* w, DLTensor* dx, const std::string& conv_dtype) { - CuDNNThreadEntry* entry_ptr = CuDNNThreadEntry::ThreadLocal(); + CuDNNThreadEntry* entry_ptr = CuDNNThreadEntry::ThreadLocal(dy->device); // Set Mode entry_ptr->conv_entry.mode = static_cast(mode); SetConvDescriptors(entry_ptr, format, dims, groups, pad, stride, dilation, dx->shape, w->shape, @@ -65,7 +65,9 @@ void BackwardDataFindAlgo(int format, int dims, int groups, const int pad[], con const int dilation[], const int dy_dim[], const int w_dim[], const int dx_dim[], const std::string& data_dtype, const std::string& conv_dtype, bool verbose, ffi::Any* ret) { - CuDNNThreadEntry* entry_ptr = CuDNNThreadEntry::ThreadLocal(); + int device_id; + CUDA_CALL(cudaGetDevice(&device_id)); + CuDNNThreadEntry* entry_ptr = CuDNNThreadEntry::ThreadLocal(DLDevice{kDLCUDA, device_id}); const int full_dims = dims + 2; std::vector dy_dim_int64(full_dims); std::vector w_dim_int64(full_dims); @@ -112,7 +114,7 @@ void ConvolutionBackwardFilter(int mode, int format, int algo, int dims, int gro const int pad[], const int stride[], const int dilation[], DLTensor* dy, DLTensor* x, DLTensor* dw, const std::string& conv_dtype) { - CuDNNThreadEntry* entry_ptr = CuDNNThreadEntry::ThreadLocal(); + CuDNNThreadEntry* entry_ptr = CuDNNThreadEntry::ThreadLocal(x->device); // Set Mode entry_ptr->conv_entry.mode = static_cast(mode); SetConvDescriptors(entry_ptr, format, dims, groups, pad, stride, dilation, x->shape, dw->shape, @@ -142,7 +144,9 @@ void BackwardFilterFindAlgo(int format, int dims, int groups, const int pad[], c const int dilation[], const int dy_dim[], const int x_dim[], const int dw_dim[], const std::string& data_dtype, const std::string& conv_dtype, bool verbose, ffi::Any* ret) { - CuDNNThreadEntry* entry_ptr = CuDNNThreadEntry::ThreadLocal(); + int device_id; + CUDA_CALL(cudaGetDevice(&device_id)); + CuDNNThreadEntry* entry_ptr = CuDNNThreadEntry::ThreadLocal(DLDevice{kDLCUDA, device_id}); const int full_dims = dims + 2; std::vector x_dim_int64(full_dims); std::vector dy_dim_int64(full_dims); diff --git a/src/runtime/contrib/cudnn/conv_forward.cc b/src/runtime/contrib/cudnn/conv_forward.cc index a0a9edef9765..7a93e194ce3c 100644 --- a/src/runtime/contrib/cudnn/conv_forward.cc +++ b/src/runtime/contrib/cudnn/conv_forward.cc @@ -35,7 +35,7 @@ using namespace runtime; void ConvolutionForward(int mode, int format, int algo, int dims, int groups, const int pad[], const int stride[], const int dilation[], const DLTensor* x, const DLTensor* w, const DLTensor* y, const std::string& conv_dtype) { - CuDNNThreadEntry* entry_ptr = CuDNNThreadEntry::ThreadLocal(); + CuDNNThreadEntry* entry_ptr = CuDNNThreadEntry::ThreadLocal(x->device); // Set Mode entry_ptr->conv_entry.mode = static_cast(mode); SetConvDescriptors(entry_ptr, format, dims, groups, pad, stride, dilation, x->shape, w->shape, @@ -69,7 +69,7 @@ void ConvolutionBiasActivationForward(int mode, int format, int algo, int dims, const int dilation[], const DLTensor* x, const DLTensor* w, const DLTensor* y, const DLTensor* bias, const std::string& conv_dtype) { - CuDNNThreadEntry* entry_ptr = CuDNNThreadEntry::ThreadLocal(); + CuDNNThreadEntry* entry_ptr = CuDNNThreadEntry::ThreadLocal(x->device); // Set Mode entry_ptr->conv_entry.mode = static_cast(mode); CUDNN_CALL(cudnnSetActivationDescriptor(entry_ptr->conv_entry.activation_desc, @@ -110,7 +110,9 @@ void FindAlgo(int format, int dims, int groups, const int pad[], const int strid const int dilation[], const int x_dim[], const int w_dim[], const int y_dim[], const std::string& data_dtype, const std::string& conv_dtype, bool verbose, ffi::Any* ret) { - CuDNNThreadEntry* entry_ptr = CuDNNThreadEntry::ThreadLocal(); + int device_id; + CUDA_CALL(cudaGetDevice(&device_id)); + CuDNNThreadEntry* entry_ptr = CuDNNThreadEntry::ThreadLocal(DLDevice{kDLCUDA, device_id}); const int full_dims = dims + 2; std::vector x_dim_int64(full_dims); std::vector w_dim_int64(full_dims); diff --git a/src/runtime/contrib/cudnn/cudnn_frontend/attention.cc b/src/runtime/contrib/cudnn/cudnn_frontend/attention.cc index dffce6738907..fbde314bc6ae 100644 --- a/src/runtime/contrib/cudnn/cudnn_frontend/attention.cc +++ b/src/runtime/contrib/cudnn/cudnn_frontend/attention.cc @@ -98,13 +98,13 @@ void CuDNNSDPARunnerNode::Init(int64_t batch, int64_t seq_len, int64_t num_heads auto [o, stats] = graph_->sdpa(q, k, v, sdpa_options); CHECK(stats == nullptr); o->set_output(true).set_dim({batch, num_heads, seq_len, head_size_v}).set_stride(o_stride); - CuDNNThreadEntry* entry_ptr = CuDNNThreadEntry::ThreadLocal(); + int device_id; + CUDA_CALL(cudaGetDevice(&device_id)); + CuDNNThreadEntry* entry_ptr = CuDNNThreadEntry::ThreadLocal(DLDevice{kDLCUDA, device_id}); CUDNN_FRONTEND_CALL(graph_->build(entry_ptr->handle, {cudnn_frontend::HeurMode_t::A})); } void CuDNNSDPARunnerNode::Run(const DLTensor* qkv, DLTensor* workspace, DLTensor* out) { - CUDNN_CALL( - cudnnSetStream(CuDNNThreadEntry::ThreadLocal()->handle, tvm::runtime::GetCUDAStream())); auto* qkv_base = reinterpret_cast(qkv->data) + qkv->byte_offset; auto* q_ptr = reinterpret_cast(qkv_base) + offset_q_; auto* k_ptr = reinterpret_cast(qkv_base) + offset_k_; @@ -116,7 +116,7 @@ void CuDNNSDPARunnerNode::Run(const DLTensor* qkv, DLTensor* workspace, DLTensor std::unordered_map inputs = { {kTensorIDQ, q_ptr}, {kTensorIDK, k_ptr}, {kTensorIDV, v_ptr}, {kTensorIDOut, out_ptr}}; - CuDNNThreadEntry* entry_ptr = CuDNNThreadEntry::ThreadLocal(); + CuDNNThreadEntry* entry_ptr = CuDNNThreadEntry::ThreadLocal(qkv->device); CUDNN_FRONTEND_CALL(graph_->execute(entry_ptr->handle, inputs, workspace->data)); } diff --git a/src/runtime/contrib/cudnn/cudnn_json_runtime.cc b/src/runtime/contrib/cudnn/cudnn_json_runtime.cc index fd4fa68c783c..3888bca3df04 100644 --- a/src/runtime/contrib/cudnn/cudnn_json_runtime.cc +++ b/src/runtime/contrib/cudnn/cudnn_json_runtime.cc @@ -22,6 +22,7 @@ * \brief A simple JSON runtime for CUDNN. */ +#include #include #include #include @@ -100,7 +101,9 @@ class cuDNNJSONRuntime : public JSONRuntimeBase { } std::function GetConv2DExec(const JSONGraphNode& node) { - auto* entry_ptr = tvm::contrib::CuDNNThreadEntry::ThreadLocal(); + int device_id; + CUDA_CALL(cudaGetDevice(&device_id)); + auto* entry_ptr = tvm::contrib::CuDNNThreadEntry::ThreadLocal(DLDevice{kDLCUDA, device_id}); auto op_name = node.GetOpName(); std::vector input_dims, kernel_dims, output_dims; @@ -159,7 +162,10 @@ class cuDNNJSONRuntime : public JSONRuntimeBase { int algo = best_algo.cast(); std::function op_exec = [=]() { - auto stream = static_cast(GetCUDAStream()); + int device_id; + CUDA_CALL(cudaGetDevice(&device_id)); + cudaStream_t stream = + static_cast(TVMFFIEnvGetCurrentStream(kDLCUDA, device_id)); CUDNN_CALL(cudnnSetStream(entry_ptr->handle, stream)); auto get_inputs = [this](const JSONGraphNode& node, bool has_bias) { diff --git a/src/runtime/contrib/cudnn/cudnn_utils.cc b/src/runtime/contrib/cudnn/cudnn_utils.cc index f5bb56e089e1..acedf7a9e2dd 100644 --- a/src/runtime/contrib/cudnn/cudnn_utils.cc +++ b/src/runtime/contrib/cudnn/cudnn_utils.cc @@ -24,6 +24,7 @@ #include "cudnn_utils.h" #include +#include #include #include #include @@ -101,7 +102,6 @@ const void* CuDNNDataType::GetConst<1>(cudnnDataType_t type) { // CuDNNThreadEntry CuDNNThreadEntry::CuDNNThreadEntry() { - auto stream = runtime::CUDAThreadEntry::ThreadLocal()->stream; auto func = tvm::ffi::Function::GetGlobalRequired("device_api.cuda"); void* ret = func().cast(); cuda_api = static_cast(ret); @@ -116,8 +116,6 @@ CuDNNThreadEntry::CuDNNThreadEntry() { } CUDNN_CALL(create_res); } - - CUDNN_CALL(cudnnSetStream(handle, stream)); conv_entry.cuda_api = cuda_api; } @@ -125,12 +123,15 @@ CuDNNThreadEntry::~CuDNNThreadEntry() {} typedef dmlc::ThreadLocalStore CuDNNThreadStore; -CuDNNThreadEntry* CuDNNThreadEntry::ThreadLocal(bool check_exists) { +CuDNNThreadEntry* CuDNNThreadEntry::ThreadLocal(Device curr_device, bool check_exists) { auto* res = CuDNNThreadStore::Get(); if (check_exists) { ICHECK(res->exists()) << "CUDNN_STATUS_NOT_INITIALIZED"; } + cudaStream_t stream = static_cast( + TVMFFIEnvGetCurrentStream(curr_device.device_type, curr_device.device_id)); + CUDNN_CALL(cudnnSetStream(res->handle, stream)); return res; } @@ -268,8 +269,11 @@ SoftmaxEntry::~SoftmaxEntry() { CUDNN_CALL(cudnnDestroyTensorDescriptor(shape_de TVM_FFI_STATIC_INIT_BLOCK({ namespace refl = tvm::ffi::reflection; - refl::GlobalDef().def("tvm.contrib.cudnn.exists", - []() -> bool { return CuDNNThreadEntry::ThreadLocal(false)->exists(); }); + refl::GlobalDef().def("tvm.contrib.cudnn.exists", []() -> bool { + int device_id; + CUDA_CALL(cudaGetDevice(&device_id)); + return CuDNNThreadEntry::ThreadLocal(DLDevice{kDLCUDA, device_id}, false)->exists(); + }); }); } // namespace contrib diff --git a/src/runtime/contrib/cudnn/cudnn_utils.h b/src/runtime/contrib/cudnn/cudnn_utils.h index 902b61532353..499cc5d6c9e5 100644 --- a/src/runtime/contrib/cudnn/cudnn_utils.h +++ b/src/runtime/contrib/cudnn/cudnn_utils.h @@ -106,7 +106,7 @@ struct CuDNNThreadEntry { ConvEntry conv_entry; SoftmaxEntry softmax_entry; runtime::DeviceAPI* cuda_api{nullptr}; - static CuDNNThreadEntry* ThreadLocal(bool check_exists = true); + static CuDNNThreadEntry* ThreadLocal(Device curr_device, bool check_exists = true); }; // CuDNNThreadEntry void SetConvDescriptors(CuDNNThreadEntry* entry_ptr, int format, int dims, int groups, diff --git a/src/runtime/contrib/cudnn/softmax.cc b/src/runtime/contrib/cudnn/softmax.cc index f0fda4fd59d2..eb2fceb3d2db 100644 --- a/src/runtime/contrib/cudnn/softmax.cc +++ b/src/runtime/contrib/cudnn/softmax.cc @@ -40,8 +40,9 @@ void softmax_impl(cudnnSoftmaxAlgorithm_t alg, ffi::PackedArgs args, ffi::Any* r int64_t* shape = x->shape; if (axis < 0) axis += ndim; ICHECK(axis >= 0 && axis < ndim); - - CuDNNThreadEntry* entry_ptr = CuDNNThreadEntry::ThreadLocal(); + int device_id; + CUDA_CALL(cudaGetDevice(&device_id)); + CuDNNThreadEntry* entry_ptr = CuDNNThreadEntry::ThreadLocal(DLDevice{kDLCUDA, device_id}); entry_ptr->softmax_entry.data_type = CuDNNDataType::DLTypeToCuDNNType(x->dtype); // Set mode and shape descriptor diff --git a/src/runtime/contrib/cutlass/fp16_group_gemm.cuh b/src/runtime/contrib/cutlass/fp16_group_gemm.cuh index ebb8f58a6b18..a09051a86e79 100644 --- a/src/runtime/contrib/cutlass/fp16_group_gemm.cuh +++ b/src/runtime/contrib/cutlass/fp16_group_gemm.cuh @@ -36,7 +36,7 @@ void tvm_cutlass_group_gemm_impl(NDArray x, NDArray weight, NDArray indptr, NDAr NDArray out) { // Workspace is used for storing device-side group gemm arguments and cutlass internal workspace. // Recommened size is 4MB. - static auto func = tvm::ffi::Function::GetGlobalRequired("runtime.get_cuda_stream"); + cudaStream_t stream = static_cast(TVMFFIEnvGetCurrentStream(kDLCUDA, x->device.device_id)); CHECK_EQ(x->ndim, 2); CHECK_EQ(weight->ndim, 3); CHECK_EQ(indptr->ndim, 1); diff --git a/src/runtime/contrib/cutlass/fp8_gemm.cu b/src/runtime/contrib/cutlass/fp8_gemm.cu index befef1db936f..5cabd0ca7af2 100644 --- a/src/runtime/contrib/cutlass/fp8_gemm.cu +++ b/src/runtime/contrib/cutlass/fp8_gemm.cu @@ -19,6 +19,7 @@ #include #include +#include #include #include #include @@ -42,8 +43,8 @@ void tvm_cutlass_fp8_gemm(NDArray x, NDArray weight, NDArray workspace, NDArray NDArray out) { // Workspace is used for storing device-side gemm arguments and cutlass internal workspace. // Recommened size is 4MB. - static auto func = tvm::ffi::Function::GetGlobalRequired("runtime.get_cuda_stream"); - cudaStream_t stream = static_cast(func().cast()); + cudaStream_t stream = + static_cast(TVMFFIEnvGetCurrentStream(kDLCUDA, x->device.device_id)); CHECK_GE(x->ndim, 2); CHECK_EQ(weight->ndim, 2); @@ -68,7 +69,7 @@ void tvm_cutlass_fp8_gemm(NDArray x, NDArray weight, NDArray workspace, NDArray static_cast(alpha->data), beta, static_cast(out->data), stream); } else { tvm::contrib::CuBlasLtThreadEntry* cublas_entry = - tvm::contrib::CuBlasLtThreadEntry::ThreadLocal(); + tvm::contrib::CuBlasLtThreadEntry::ThreadLocal(x->device); tvm::contrib::CallCublasLt(cublas_entry->handle, stream, cublas_entry->matmul_pref_desc, x.operator->(), weight.operator->(), nullptr, alpha.operator->(), nullptr, out.operator->(), /*transa=*/false, /*transb=*/true, diff --git a/src/runtime/contrib/cutlass/fp8_group_gemm_sm90.cu b/src/runtime/contrib/cutlass/fp8_group_gemm_sm90.cu index f9f03fc4ed3c..150485b86822 100644 --- a/src/runtime/contrib/cutlass/fp8_group_gemm_sm90.cu +++ b/src/runtime/contrib/cutlass/fp8_group_gemm_sm90.cu @@ -19,6 +19,7 @@ #include #include +#include #include #include #include @@ -45,8 +46,8 @@ void tvm_cutlass_fp8_group_gemm(NDArray x, NDArray weight, NDArray indptr, NDArr NDArray alpha, NDArray out) { // Workspace is used for storing device-side group gemm arguments and cutlass internal workspace. // Recommened size is 4MB. - static auto func = tvm::ffi::Function::GetGlobalRequired("runtime.get_cuda_stream"); - cudaStream_t stream = static_cast(func().cast()); + cudaStream_t stream = + static_cast(TVMFFIEnvGetCurrentStream(kDLCUDA, x->device.device_id)); CHECK_EQ(x->ndim, 2); CHECK_EQ(weight->ndim, 3); CHECK_EQ(indptr->ndim, 1); diff --git a/src/runtime/contrib/cutlass/fp8_groupwise_scaled_gemm.cuh b/src/runtime/contrib/cutlass/fp8_groupwise_scaled_gemm.cuh index 4ecca5f1d8a9..0f688616d55e 100644 --- a/src/runtime/contrib/cutlass/fp8_groupwise_scaled_gemm.cuh +++ b/src/runtime/contrib/cutlass/fp8_groupwise_scaled_gemm.cuh @@ -20,6 +20,7 @@ #include #include #include +#include #include #include "cutlass/bfloat16.h" @@ -39,9 +40,7 @@ void tvm_cutlass_fp8_groupwise_scaled_gemm_impl(NDArray a, NDArray b, NDArray sc NDArray out) { // Workspace is used for storing device-side gemm arguments and cutlass internal workspace. // Recommened size is 4MB. - static tvm::ffi::Function get_stream_func = - tvm::ffi::Function::GetGlobalRequired("runtime.get_cuda_stream"); - cudaStream_t stream = static_cast(get_stream_func().cast()); + cudaStream_t stream = static_cast(TVMFFIEnvGetCurrentStream(kDLCUDA, a->device.device_id)); CHECK_GE(a->ndim, 2); CHECK_EQ(scales_a->ndim, a->ndim); @@ -107,9 +106,7 @@ void tvm_cutlass_fp8_groupwise_scaled_bmm_impl(NDArray a, NDArray b, NDArray sca NDArray out) { // Workspace is used for storing device-side gemm arguments and cutlass internal workspace. // Recommened size is 4MB. - static tvm::ffi::Function get_stream_func = - tvm::ffi::Function::GetGlobalRequired("runtime.get_cuda_stream"); - cudaStream_t stream = static_cast(get_stream_func().cast()); + cudaStream_t stream = static_cast(TVMFFIEnvGetCurrentStream(kDLCUDA, a->device.device_id)); CHECK_EQ(a->ndim, 3); CHECK_EQ(scales_a->ndim, 3); diff --git a/src/runtime/contrib/cutlass/fp8_groupwise_scaled_group_gemm_sm100.cu b/src/runtime/contrib/cutlass/fp8_groupwise_scaled_group_gemm_sm100.cu index 5b467c9bd504..2745c0b1fc03 100644 --- a/src/runtime/contrib/cutlass/fp8_groupwise_scaled_group_gemm_sm100.cu +++ b/src/runtime/contrib/cutlass/fp8_groupwise_scaled_group_gemm_sm100.cu @@ -37,8 +37,8 @@ void tvm_fp8_groupwise_scaled_group_gemm_sm100(NDArray a, NDArray b, NDArray sca NDArray out) { // Workspace is used for storing device-side group gemm arguments and cutlass internal workspace. // Recommended size is 4MB. - static auto func = tvm::ffi::Function::GetGlobalRequired("runtime.get_cuda_stream"); - cudaStream_t stream = static_cast(func().cast()); + cudaStream_t stream = + static_cast(TVMFFIEnvGetCurrentStream(kDLCUDA, a->device.device_id)); CHECK_EQ(a->ndim, 2); CHECK_EQ(b->ndim, 3); CHECK_EQ(indptr->ndim, 1); diff --git a/src/runtime/contrib/hipblas/hipblas.cc b/src/runtime/contrib/hipblas/hipblas.cc index 4e7a5c5d1037..628ffb5bdf8a 100644 --- a/src/runtime/contrib/hipblas/hipblas.cc +++ b/src/runtime/contrib/hipblas/hipblas.cc @@ -416,7 +416,7 @@ TVM_FFI_STATIC_INIT_BLOCK({ auto A = args[0].cast(); auto C = args[2].cast(); - HipBlasThreadEntry* entry_ptr = HipBlasThreadEntry::ThreadLocal(); + HipBlasThreadEntry* entry_ptr = HipBlasThreadEntry::ThreadLocal(A->device); if (TypeEqual(A->dtype, C->dtype)) { ICHECK(TypeMatch(A->dtype, kDLFloat, 16) || @@ -438,7 +438,7 @@ TVM_FFI_STATIC_INIT_BLOCK({ auto A = args[0].cast(); auto C = args[2].cast(); - HipBlasThreadEntry* entry_ptr = HipBlasThreadEntry::ThreadLocal(); + HipBlasThreadEntry* entry_ptr = HipBlasThreadEntry::ThreadLocal(A->device); if (TypeEqual(A->dtype, C->dtype)) { ICHECK(TypeMatch(A->dtype, kDLFloat, 16) || TypeMatch(A->dtype, kDLFloat, 32) || diff --git a/src/runtime/contrib/hipblas/hipblas_json_runtime.cc b/src/runtime/contrib/hipblas/hipblas_json_runtime.cc index 5750b91ab4ca..ab8545561be4 100644 --- a/src/runtime/contrib/hipblas/hipblas_json_runtime.cc +++ b/src/runtime/contrib/hipblas/hipblas_json_runtime.cc @@ -65,10 +65,7 @@ class HipblasJSONRuntime : public JSONRuntimeBase { const char* kind() const override { return "hipblas_json"; } // May be overridden void Run(ffi::PackedArgs args) { - auto* entry_ptr = tvm::contrib::HipBlasLtThreadEntry::ThreadLocal(); - static auto func = tvm::ffi::Function::GetGlobalRequired("runtime.get_rocm_stream"); - hipStream_t stream = static_cast(func().cast()); - + int device_id = -1; std::vector dl_tensors(NumEntries()); for (size_t i = 0; i < static_cast(args.size()); i++) { @@ -84,7 +81,13 @@ class HipblasJSONRuntime : public JSONRuntimeBase { } dl_tensors[eid] = arg; + device_id = arg->device.device_id; + } + if (device_id == -1) { + ROCM_CALL(hipGetDevice(&device_id)); } + auto* entry_ptr = tvm::contrib::HipBlasLtThreadEntry::ThreadLocal(Device(kDLROCM, device_id)); + hipStream_t stream = static_cast(TVMFFIEnvGetCurrentStream(kDLROCM, device_id)); auto get_input = [this, &dl_tensors](const JSONGraphNode& node, int idx) { ICHECK_LT(idx, node.GetInputs().size()); diff --git a/src/runtime/contrib/hipblas/hipblas_utils.cc b/src/runtime/contrib/hipblas/hipblas_utils.cc index 6facbb232b2c..454ab7a3707e 100644 --- a/src/runtime/contrib/hipblas/hipblas_utils.cc +++ b/src/runtime/contrib/hipblas/hipblas_utils.cc @@ -41,9 +41,10 @@ HipBlasThreadEntry::~HipBlasThreadEntry() { typedef dmlc::ThreadLocalStore HipBlasThreadStore; -HipBlasThreadEntry* HipBlasThreadEntry::ThreadLocal() { - auto stream = runtime::ROCMThreadEntry::ThreadLocal()->stream; +HipBlasThreadEntry* HipBlasThreadEntry::ThreadLocal(Device curr_device) { HipBlasThreadEntry* retval = HipBlasThreadStore::Get(); + TVMFFIStreamHandle stream = + TVMFFIEnvGetCurrentStream(curr_device.device_type, curr_device.device_id); CHECK_HIPBLAS_ERROR(hipblasSetStream(retval->handle, static_cast(stream))); return retval; } @@ -71,7 +72,9 @@ HipBlasLtThreadEntry::~HipBlasLtThreadEntry() { typedef dmlc::ThreadLocalStore HipBlasLtThreadStore; -HipBlasLtThreadEntry* HipBlasLtThreadEntry::ThreadLocal() { return HipBlasLtThreadStore::Get(); } +HipBlasLtThreadEntry* HipBlasLtThreadEntry::ThreadLocal(Device curr_device) { + return HipBlasLtThreadStore::Get(); +} } // namespace contrib diff --git a/src/runtime/contrib/miopen/conv_forward.cc b/src/runtime/contrib/miopen/conv_forward.cc index 53eba8e9c420..2c8a70aa6b34 100644 --- a/src/runtime/contrib/miopen/conv_forward.cc +++ b/src/runtime/contrib/miopen/conv_forward.cc @@ -59,8 +59,10 @@ TVM_FFI_STATIC_INIT_BLOCK({ const int w_dim3 = args[15].cast(); const int n_group = args[16].cast(); void* out_shape = args[17].cast(); - - MIOpenThreadEntry* entry_ptr = MIOpenThreadEntry::ThreadLocal(); + int device_id = -1; + ROCM_CALL(hipGetDevice(&device_id)); + MIOpenThreadEntry* entry_ptr = + MIOpenThreadEntry::ThreadLocal(Device{kDLROCM, device_id}); assert(n_group > 0 && "Group Size > 0 is expected"); if (n_group > 1) assert(mode > 1 && "Group /Depthwise Conv mode when num of groups > 1"); @@ -168,7 +170,7 @@ TVM_FFI_STATIC_INIT_BLOCK({ const auto w = args[10].cast(); const auto y = args[11].cast(); - MIOpenThreadEntry* entry_ptr = MIOpenThreadEntry::ThreadLocal(); + MIOpenThreadEntry* entry_ptr = MIOpenThreadEntry::ThreadLocal(x->device); entry_ptr->conv_entry.fwd_algo = static_cast(algo); // Set Mode entry_ptr->conv_entry.mode = static_cast(mode); diff --git a/src/runtime/contrib/miopen/miopen_utils.cc b/src/runtime/contrib/miopen/miopen_utils.cc index bb091fdf7aa1..e860ba8ea7f2 100644 --- a/src/runtime/contrib/miopen/miopen_utils.cc +++ b/src/runtime/contrib/miopen/miopen_utils.cc @@ -42,12 +42,10 @@ std::string miopenGetErrorString(int error_code) { // MiopenThreadEntry MIOpenThreadEntry::MIOpenThreadEntry() { - auto stream = runtime::ROCMThreadEntry::ThreadLocal()->stream; const auto get_rocm_api = tvm::ffi::Function::GetGlobalRequired("device_api.rocm"); void* ret = get_rocm_api(); rocm_api = static_cast(ret); MIOPEN_CALL(miopenCreate(&handle)); - MIOPEN_CALL(miopenSetStream(handle, stream)); conv_entry.rocm_api = rocm_api; } @@ -55,7 +53,14 @@ MIOpenThreadEntry::~MIOpenThreadEntry() { MIOPEN_CALL(miopenDestroy(handle)); } typedef dmlc::ThreadLocalStore MIOpenThreadStore; -MIOpenThreadEntry* MIOpenThreadEntry::ThreadLocal() { return MIOpenThreadStore::Get(); } +MIOpenThreadEntry* MIOpenThreadEntry::ThreadLocal(Device curr_device) { + // Need to update stream per fetch to avoid stream switching + MIOpenThreadEntry* res = MIOpenThreadStore::Get(); + TVMFFIStreamHandle stream = + TVMFFIEnvGetCurrentStream(curr_device.device_type, curr_device.device_id); + MIOPEN_CALL(miopenSetStream(res->handle, stream)); + return res; +} // ConvEntry diff --git a/src/runtime/contrib/miopen/softmax.cc b/src/runtime/contrib/miopen/softmax.cc index dfcde9e87915..5853cb2a7b11 100644 --- a/src/runtime/contrib/miopen/softmax.cc +++ b/src/runtime/contrib/miopen/softmax.cc @@ -45,7 +45,7 @@ void softmax_impl(ffi::PackedArgs args, ffi::Any* ret, miopenSoftmaxAlgorithm_t ICHECK(TypeMatch(x->dtype, kDLFloat, 32)); ICHECK(TypeMatch(y->dtype, kDLFloat, 32)); - MIOpenThreadEntry* entry_ptr = MIOpenThreadEntry::ThreadLocal(); + MIOpenThreadEntry* entry_ptr = MIOpenThreadEntry::ThreadLocal(x->device); miopenSoftmaxMode_t mode; if (axis == ndim - 1) { diff --git a/src/runtime/contrib/msc/tensorrt_runtime.cc b/src/runtime/contrib/msc/tensorrt_runtime.cc index e19c03d4fda5..37ae9f254895 100644 --- a/src/runtime/contrib/msc/tensorrt_runtime.cc +++ b/src/runtime/contrib/msc/tensorrt_runtime.cc @@ -123,15 +123,17 @@ class MSCTensorRTRuntime : public JSONRuntimeBase { const auto pf = tvm::ffi::Function::GetGlobal("msc_tool.callback_step"); ICHECK(pf.has_value()) << "Cannot find msc_tool.callback_step func."; Map input_datas; + int device_id = 0; for (const auto& pair : input_bindings_) { const auto& tensor_name = engine_->getBindingName(pair.first); input_datas.Set(tensor_name, device_buffers_[pair.first]); + device_id = data_entry_[pair.first]->device.device_id; } Map> context; context.Set("datas", input_datas); (*pf)(context, "before_forward", graph_name_, tool_tag_); } - auto tvm_stream = CUDAThreadEntry::ThreadLocal()->stream; + auto tvm_stream = TVMFFIEnvGetCurrentStream(kDLCUDA, device_id); #if TRT_VERSION_GE(6, 0, 1) ICHECK(context_->enqueueV2(bindings_.data(), tvm_stream, nullptr)) << "Running TensorRT failed."; diff --git a/src/runtime/cuda/cuda_common.h b/src/runtime/cuda/cuda_common.h index a378e53c54a5..fd032fc75bd1 100644 --- a/src/runtime/cuda/cuda_common.h +++ b/src/runtime/cuda/cuda_common.h @@ -54,8 +54,6 @@ namespace runtime { /*! \brief Thread local workspace */ class CUDAThreadEntry { public: - /*! \brief The cuda stream */ - cudaStream_t stream{nullptr}; /*! \brief thread local pool*/ WorkspacePool pool; /*! \brief constructor */ @@ -64,8 +62,6 @@ class CUDAThreadEntry { static CUDAThreadEntry* ThreadLocal(); }; -inline cudaStream_t GetCUDAStream() { return CUDAThreadEntry::ThreadLocal()->stream; } - } // namespace runtime } // namespace tvm #endif // TVM_RUNTIME_CUDA_CUDA_COMMON_H_ diff --git a/src/runtime/cuda/cuda_device_api.cc b/src/runtime/cuda/cuda_device_api.cc index 8a0da35c205e..451348afbf1a 100644 --- a/src/runtime/cuda/cuda_device_api.cc +++ b/src/runtime/cuda/cuda_device_api.cc @@ -24,6 +24,7 @@ #include #include #include +#include #include #include #include @@ -249,14 +250,6 @@ class CUDADeviceAPI final : public DeviceAPI { CUDA_CALL(cudaStreamSynchronize(static_cast(stream))); } - void SetStream(Device dev, TVMStreamHandle stream) final { - CUDAThreadEntry::ThreadLocal()->stream = static_cast(stream); - } - - TVMStreamHandle GetCurrentStream(Device dev) final { - return static_cast(CUDAThreadEntry::ThreadLocal()->stream); - } - void* AllocWorkspace(Device dev, size_t size, DLDataType type_hint) final { return CUDAThreadEntry::ThreadLocal()->pool.AllocWorkspace(dev, size); } @@ -306,9 +299,16 @@ class CUDATimerNode : public TimerNode { virtual void Start() { // This initial cudaEventRecord is sometimes pretty slow (~100us). Does // cudaEventRecord do some stream synchronization? - CUDA_CALL(cudaEventRecord(start_, CUDAThreadEntry::ThreadLocal()->stream)); + int device_id; + CUDA_CALL(cudaGetDevice(&device_id)); + stream_ = TVMFFIEnvGetCurrentStream(kDLCUDA, device_id); + CUDA_CALL(cudaEventRecord(start_, static_cast(stream_))); + } + virtual void Stop() { + int device_id; + CUDA_CALL(cudaGetDevice(&device_id)); + CUDA_CALL(cudaEventRecord(stop_, static_cast(stream_))); } - virtual void Stop() { CUDA_CALL(cudaEventRecord(stop_, CUDAThreadEntry::ThreadLocal()->stream)); } virtual int64_t SyncAndGetElapsedNanos() { CUDA_CALL(cudaEventSynchronize(stop_)); float milliseconds = 0; @@ -330,6 +330,7 @@ class CUDATimerNode : public TimerNode { private: cudaEvent_t start_; cudaEvent_t stop_; + TVMStreamHandle stream_; }; TVM_FFI_STATIC_INIT_BLOCK({ @@ -351,8 +352,13 @@ TVM_FFI_STATIC_INIT_BLOCK({ namespace refl = tvm::ffi::reflection; refl::GlobalDef() .def("runtime.GetCudaFreeMemory", GetCudaFreeMemory) - .def("runtime.get_cuda_stream", - []() { return static_cast(CUDAThreadEntry::ThreadLocal()->stream); }); + .def("runtime.get_cuda_stream", []() { + // TODO(tvm-team): remove once confirms all dep such as flashinfer + // migrated to TVMFFIEnvGetCurrentStream + int device_id; + CUDA_CALL(cudaGetDevice(&device_id)); + return static_cast(TVMFFIEnvGetCurrentStream(kDLCUDA, device_id)); + }); }); TVM_DLL int GetCudaDeviceCount() { diff --git a/src/runtime/cuda/cuda_module.cc b/src/runtime/cuda/cuda_module.cc index 5a4e682da8da..eb3bee4757bf 100644 --- a/src/runtime/cuda/cuda_module.cc +++ b/src/runtime/cuda/cuda_module.cc @@ -25,6 +25,7 @@ #include #include #include +#include #include #include @@ -198,7 +199,7 @@ class CUDAWrappedFunc { } } } - CUstream strm = static_cast(CUDAThreadEntry::ThreadLocal()->stream); + CUstream strm = static_cast(TVMFFIEnvGetCurrentStream(kDLCUDA, device_id)); CUresult result = cuLaunchKernel(fcache_[device_id], wl.grid_dim(0), wl.grid_dim(1), wl.grid_dim(2), wl.block_dim(0), wl.block_dim(1), wl.block_dim(2), wl.dyn_shmem_size, strm, void_args, nullptr); diff --git a/src/runtime/cuda/l2_cache_flush.cc b/src/runtime/cuda/l2_cache_flush.cc index 9427a6a3eeae..0c7f939181a2 100644 --- a/src/runtime/cuda/l2_cache_flush.cc +++ b/src/runtime/cuda/l2_cache_flush.cc @@ -19,6 +19,7 @@ #include "../../../3rdparty/nvbench/l2_cache_flush.h" #include +#include #include #include #include @@ -37,7 +38,9 @@ TVM_FFI_STATIC_INIT_BLOCK({ namespace refl = tvm::ffi::reflection; refl::GlobalDef().def_packed("l2_cache_flush_cuda", [](ffi::PackedArgs args, ffi::Any* rv) { ICHECK(L2Flush::ThreadLocal() != nullptr) << "L2Flush::ThreadLocal do not exist."; - cudaStream_t stream = CUDAThreadEntry::ThreadLocal()->stream; + int device_id; + CUDA_CALL(cudaGetDevice(&device_id)); + cudaStream_t stream = static_cast(TVMFFIEnvGetCurrentStream(kDLCUDA, device_id)); L2Flush::ThreadLocal()->Flush(stream); }); }); diff --git a/src/runtime/device_api.cc b/src/runtime/device_api.cc index ae85f9ce5384..31006069a26b 100644 --- a/src/runtime/device_api.cc +++ b/src/runtime/device_api.cc @@ -164,7 +164,13 @@ TVMStreamHandle DeviceAPI::CreateStream(Device dev) { return nullptr; } void DeviceAPI::FreeStream(Device dev, TVMStreamHandle stream) {} -TVMStreamHandle DeviceAPI::GetCurrentStream(Device dev) { return nullptr; } +void DeviceAPI::SetStream(Device dev, TVMStreamHandle stream) { + TVM_FFI_CHECK_SAFE_CALL(TVMFFIEnvSetStream(dev.device_type, dev.device_id, stream, nullptr)); +} + +TVMStreamHandle DeviceAPI::GetCurrentStream(Device dev) { + return TVMFFIEnvGetCurrentStream(dev.device_type, dev.device_id); +} void DeviceAPI::SyncStreamFromTo(Device dev, TVMStreamHandle event_src, TVMStreamHandle event_dst) { } diff --git a/src/runtime/metal/metal_common.h b/src/runtime/metal/metal_common.h index 138d312dd47c..f10489826a5a 100644 --- a/src/runtime/metal/metal_common.h +++ b/src/runtime/metal/metal_common.h @@ -168,8 +168,6 @@ class MetalWorkspace final : public DeviceAPI { TVMStreamHandle CreateStream(Device dev) final; void FreeStream(Device dev, TVMStreamHandle stream) final; void StreamSync(Device dev, TVMStreamHandle stream) final; - void SetStream(Device dev, TVMStreamHandle stream) final; - TVMStreamHandle GetCurrentStream(Device dev) final; void* AllocWorkspace(Device dev, size_t size, DLDataType type_hint) final; void FreeWorkspace(Device dev, void* data) final; void ReinitializeDefaultStreams(); diff --git a/src/runtime/metal/metal_device_api.mm b/src/runtime/metal/metal_device_api.mm index ba2f69b8e7cc..2a8544f6f17c 100644 --- a/src/runtime/metal/metal_device_api.mm +++ b/src/runtime/metal/metal_device_api.mm @@ -312,17 +312,6 @@ int GetWarpSize(id dev) { }; } -void MetalWorkspace::SetStream(Device dev, TVMStreamHandle stream) { - ICHECK_LT(dev.device_id, devices.size()) << "Invalid device id " << dev.device_id; - ICHECK(stream != nullptr); - MetalThreadEntry::ThreadLocal()->stream[dev.device_id] = stream; -} - -TVMStreamHandle MetalWorkspace::GetCurrentStream(Device dev) { - ICHECK_LT(dev.device_id, devices.size()) << "Invalid device id " << dev.device_id; - return MetalThreadEntry::ThreadLocal()->stream[dev.device_id]; -} - void* MetalWorkspace::AllocWorkspace(Device dev, size_t size, DLDataType type_hint) { return MetalThreadEntry::ThreadLocal()->pool.AllocWorkspace(dev, size); } diff --git a/src/runtime/rocm/rocm_device_api.cc b/src/runtime/rocm/rocm_device_api.cc index c8842f7f53ca..9692b811a40c 100644 --- a/src/runtime/rocm/rocm_device_api.cc +++ b/src/runtime/rocm/rocm_device_api.cc @@ -24,6 +24,7 @@ #include #include #include +#include #include #include #include @@ -214,14 +215,6 @@ class ROCMDeviceAPI final : public DeviceAPI { ROCM_CALL(hipStreamSynchronize(static_cast(stream))); } - void SetStream(Device dev, TVMStreamHandle stream) final { - ROCMThreadEntry::ThreadLocal()->stream = static_cast(stream); - } - - TVMStreamHandle GetCurrentStream(Device dev) final { - return static_cast(ROCMThreadEntry::ThreadLocal()->stream); - } - void* AllocWorkspace(Device dev, size_t size, DLDataType type_hint) final { return ROCMThreadEntry::ThreadLocal()->pool.AllocWorkspace(dev, size); } @@ -269,9 +262,16 @@ TVM_FFI_STATIC_INIT_BLOCK({ class ROCMTimerNode : public TimerNode { public: virtual void Start() { - ROCM_CALL(hipEventRecord(start_, ROCMThreadEntry::ThreadLocal()->stream)); + int device_id; + ROCM_CALL(hipGetDevice(&device_id)); + stream_ = TVMFFIEnvGetCurrentStream(kDLROCM, device_id); + ROCM_CALL(hipEventRecord(start_, static_cast(stream_))); + } + virtual void Stop() { + int device_id; + ROCM_CALL(hipGetDevice(&device_id)); + ROCM_CALL(hipEventRecord(stop_, static_cast(stream_))); } - virtual void Stop() { ROCM_CALL(hipEventRecord(stop_, ROCMThreadEntry::ThreadLocal()->stream)); } virtual int64_t SyncAndGetElapsedNanos() { ROCM_CALL(hipEventSynchronize(stop_)); float milliseconds = 0; @@ -293,14 +293,18 @@ class ROCMTimerNode : public TimerNode { private: hipEvent_t start_; hipEvent_t stop_; + TVMStreamHandle stream_; }; TVM_FFI_STATIC_INIT_BLOCK({ namespace refl = tvm::ffi::reflection; refl::GlobalDef() .def("profiling.timer.rocm", [](Device dev) { return Timer(make_object()); }) - .def("runtime.get_rocm_stream", - []() { return static_cast(ROCMThreadEntry::ThreadLocal()->stream); }); + .def("runtime.get_rocm_stream", []() { + int device_id; + ROCM_CALL(hipGetDevice(&device_id)); + return static_cast(TVMFFIEnvGetCurrentStream(kDLROCM, device_id)); + }); }); } // namespace runtime diff --git a/src/runtime/rocm/rocm_module.cc b/src/runtime/rocm/rocm_module.cc index 13b14e13e0e7..f6beaca210bc 100644 --- a/src/runtime/rocm/rocm_module.cc +++ b/src/runtime/rocm/rocm_module.cc @@ -24,6 +24,7 @@ #include #include +#include #include #include @@ -171,7 +172,7 @@ class ROCMWrappedFunc { fcache_[device_id] = m_->GetFunc(device_id, func_name_); } - hipStream_t strm = static_cast(ROCMThreadEntry::ThreadLocal()->stream); + hipStream_t strm = static_cast(TVMFFIEnvGetCurrentStream(kDLROCM, device_id)); ThreadWorkLoad wl = launch_param_config_.Extract(args); void* config[] = {HIP_LAUNCH_PARAM_BUFFER_POINTER, packed_args, HIP_LAUNCH_PARAM_BUFFER_SIZE, diff --git a/src/runtime/vm/cuda/cuda_graph_builtin.cc b/src/runtime/vm/cuda/cuda_graph_builtin.cc index 691246c3bf77..d7ccff66a046 100644 --- a/src/runtime/vm/cuda/cuda_graph_builtin.cc +++ b/src/runtime/vm/cuda/cuda_graph_builtin.cc @@ -23,6 +23,7 @@ */ #include +#include #include #include #include @@ -114,18 +115,20 @@ class ScopedCUDAStream { class CUDACaptureStream { public: - explicit CUDACaptureStream(cudaGraph_t* graph) - : prev_default_stream_(CUDAThreadEntry::ThreadLocal()->stream), output_graph_(graph) { - CUDAThreadEntry::ThreadLocal()->stream = capture_stream_; - + explicit CUDACaptureStream(cudaGraph_t* graph) : output_graph_(graph) { + CUDA_CALL(cudaGetDevice(&device_id_)); + TVM_FFI_CHECK_SAFE_CALL( + TVMFFIEnvSetStream(kDLCUDA, device_id_, capture_stream_, + reinterpret_cast(&prev_default_stream_))); CUDA_CALL(cudaStreamBeginCapture(capture_stream_, cudaStreamCaptureModeGlobal)); } - ~CUDACaptureStream() { + ~CUDACaptureStream() noexcept(false) { cudaStreamEndCapture(capture_stream_, output_graph_); - CUDAThreadEntry::ThreadLocal()->stream = prev_default_stream_; + TVM_FFI_CHECK_SAFE_CALL(TVMFFIEnvSetStream(kDLCUDA, device_id_, prev_default_stream_, nullptr)); } private: + int device_id_; cudaStream_t prev_default_stream_; ScopedCUDAStream capture_stream_; @@ -155,7 +158,10 @@ class CUDAGraphExtensionNode : public VMExtensionNode { if (auto it = capture_cache_.find(entry_key); it != capture_cache_.end()) { // Launch CUDA graph const auto& [states, exec] = it->second; - CUDA_CALL(cudaGraphLaunch(exec, CUDAThreadEntry::ThreadLocal()->stream)); + int device_id; + CUDA_CALL(cudaGetDevice(&device_id)); + CUDA_CALL(cudaGraphLaunch( + exec, static_cast(TVMFFIEnvGetCurrentStream(kDLCUDA, device_id)))); return states; } diff --git a/src/runtime/vulkan/vulkan_device_api.cc b/src/runtime/vulkan/vulkan_device_api.cc index 09c3d522b0a0..023d34e68bda 100644 --- a/src/runtime/vulkan/vulkan_device_api.cc +++ b/src/runtime/vulkan/vulkan_device_api.cc @@ -332,12 +332,6 @@ void VulkanDeviceAPI::StreamSync(Device dev, TVMStreamHandle stream) { device(dev.device_id).ThreadLocalStream().Synchronize(); } -void VulkanDeviceAPI::SetStream(Device dev, TVMStreamHandle stream) { - ICHECK_EQ(stream, static_cast(nullptr)); -} - -TVMStreamHandle VulkanDeviceAPI::GetCurrentStream(Device dev) { return nullptr; } - void VulkanDeviceAPI::CopyDataFromTo(const void* from, size_t from_offset, void* to, size_t to_offset, size_t size, Device dev_from, Device dev_to, DLDataType type_hint, TVMStreamHandle stream) { diff --git a/src/runtime/vulkan/vulkan_device_api.h b/src/runtime/vulkan/vulkan_device_api.h index 64ca0db701e8..5e9bfeb8c086 100644 --- a/src/runtime/vulkan/vulkan_device_api.h +++ b/src/runtime/vulkan/vulkan_device_api.h @@ -61,8 +61,6 @@ class VulkanDeviceAPI final : public DeviceAPI { void FreeStream(Device dev, TVMStreamHandle stream) final; void SyncStreamFromTo(Device dev, TVMStreamHandle event_src, TVMStreamHandle event_dst) final; void StreamSync(Device dev, TVMStreamHandle stream) final; - void SetStream(Device dev, TVMStreamHandle stream) final; - TVMStreamHandle GetCurrentStream(Device dev) final; protected: void CopyDataFromTo(const void* from, size_t from_offset, void* to, size_t to_offset, size_t size, diff --git a/web/emcc/webgpu_runtime.cc b/web/emcc/webgpu_runtime.cc index cd50bc067983..eb14a7b7d7ee 100644 --- a/web/emcc/webgpu_runtime.cc +++ b/web/emcc/webgpu_runtime.cc @@ -118,10 +118,6 @@ class WebGPUDeviceAPI : public DeviceAPI { (*func)(); } - void SetStream(Device dev, TVMStreamHandle stream) final { LOG(FATAL) << "Not implemented"; } - - TVMStreamHandle GetCurrentStream(Device dev) final { LOG(FATAL) << "Not implemented"; } - void* AllocWorkspace(Device dev, size_t size, DLDataType type_hint) final { return WebGPUThreadEntry::ThreadLocal()->pool.AllocWorkspace(dev, size); }