@yaoyaoding yaoyaoding commented Sep 12, 2025

This PR updates the API of tvm_ffi.cpp.load_inline:

Before this PR: the functions parameter specifies which functions defined in cpp_sources are exported. To export a function defined in cuda_sources, we also had to declare it in cpp_sources.

After this PR: cpp_sources is optional. When it is not given, the functions listed in the functions parameter are exported directly from cuda_sources.
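To make the difference concrete, here is a minimal sketch of the two call shapes, written as keyword-argument dictionaries so it can be read without a CUDA toolchain. The helper names `load_inline_kwargs` and `load` are hypothetical, and `CUDA_SOURCE` is a shortened stand-in for a real CUDA source. The forward declaration in `CPP_DECLARATION` is an assumption about what the pre-PR declaration would look like, based on the description above.

```python
# Shortened stand-in for the CUDA source; the real kernel and launch code are elided.
CUDA_SOURCE = r"""
void add_one_cuda(DLTensor* x, DLTensor* y) { /* launch the kernel here */ }
"""

# Assumed shape of the forward declaration that had to live in cpp_sources before the PR.
CPP_DECLARATION = "void add_one_cuda(DLTensor* x, DLTensor* y);"


def load_inline_kwargs(pre_pr: bool) -> dict:
    """Build the keyword arguments for load_inline (hypothetical helper)."""
    kwargs = {
        "name": "hello",
        "cuda_sources": CUDA_SOURCE,
        "functions": ["add_one_cuda"],
    }
    if pre_pr:
        # Before the PR: `functions` referred to cpp_sources, so the CUDA
        # function had to be declared there in order to be exported.
        kwargs["cpp_sources"] = CPP_DECLARATION
    return kwargs


def load(pre_pr: bool):
    """Compile and load the module; requires tvm_ffi and a CUDA toolchain."""
    import tvm_ffi.cpp  # imported lazily so the sketch above stays importable

    return tvm_ffi.cpp.load_inline(**load_inline_kwargs(pre_pr))
```

With `pre_pr=False` the dictionary has no `cpp_sources` entry at all, which is the case this PR enables.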

Example:

    import tvm_ffi.cpp
    from tvm_ffi.module import Module

    mod: Module = tvm_ffi.cpp.load_inline(
        name="hello",
        cuda_sources=r"""
            __global__ void AddOneKernel(float* x, float* y, int n) {
              int idx = blockIdx.x * blockDim.x + threadIdx.x;
              if (idx < n) {
                y[idx] = x[idx] + 1;
              }
            }

            void add_one_cuda(DLTensor* x, DLTensor* y) {
              // implementation of a library function
              TVM_FFI_ICHECK(x->ndim == 1) << "x must be a 1D tensor";
              DLDataType f32_dtype{kDLFloat, 32, 1};
              TVM_FFI_ICHECK(x->dtype == f32_dtype) << "x must be a float tensor";
              TVM_FFI_ICHECK(y->ndim == 1) << "y must be a 1D tensor";
              TVM_FFI_ICHECK(y->dtype == f32_dtype) << "y must be a float tensor";
              TVM_FFI_ICHECK(x->shape[0] == y->shape[0]) << "x and y must have the same shape";

              int64_t n = x->shape[0];
              int64_t nthread_per_block = 256;
              int64_t nblock = (n + nthread_per_block - 1) / nthread_per_block;
              // Obtain the current stream from the environment
              // it will be set to torch.cuda.current_stream() when calling the function
              // with torch.Tensors
              cudaStream_t stream = static_cast<cudaStream_t>(
                  TVMFFIEnvGetCurrentStream(x->device.device_type, x->device.device_id));
              // launch the kernel
              AddOneKernel<<<nblock, nthread_per_block, 0, stream>>>(static_cast<float*>(x->data),
                                                                     static_cast<float*>(y->data), n);
            }
        """,
        functions=["add_one_cuda"],
    )
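The launch configuration in `add_one_cuda` rounds the block count up so that every element gets a thread, and the `idx < n` guard in the kernel discards the surplus threads in the last block. The pure-Python sketch below mimics that indexing on the CPU to show the arithmetic. The function names `launch_config` and `simulate_add_one` are hypothetical and are not part of tvm_ffi.

```python
def launch_config(n, threads_per_block=256):
    """Ceil-divide n by the block size, mirroring how nblock is computed."""
    nblock = (n + threads_per_block - 1) // threads_per_block
    return nblock, threads_per_block


def simulate_add_one(x, threads_per_block=256):
    """Sequentially replay the kernel's per-thread work on a Python list."""
    n = len(x)
    nblock, nthread = launch_config(n, threads_per_block)
    y = [None] * n
    for block_idx in range(nblock):
        for thread_idx in range(nthread):
            # Same global index as blockIdx.x * blockDim.x + threadIdx.x
            idx = block_idx * nthread + thread_idx
            if idx < n:  # surplus threads in the last block do nothing
                y[idx] = x[idx] + 1
    return y


print(launch_config(1000))  # (4, 256)
print(simulate_add_one([1.0, 2.0, 3.0], threads_per_block=2))  # [2.0, 3.0, 4.0]
```

For n = 1000 and 256 threads per block, this gives 4 blocks (1024 threads), so the last 24 threads fail the bounds check and write nothing.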

@tqchen tqchen merged commit 0c9e7cd into apache:main Sep 12, 2025
11 checks passed
@yaoyaoding yaoyaoding deleted the update-load-inline branch September 12, 2025 20:58
tqchen pushed a commit to tqchen/tvm that referenced this pull request Sep 13, 2025
update load_inline interface

2 participants