70 changes: 68 additions & 2 deletions ffi/scripts/benchmark_dlpack.py
@@ -36,6 +36,7 @@
-

"""
import os
import torch
import numpy as np
from tvm import ffi as tvm_ffi
@@ -244,7 +245,7 @@ def bench_tvm_ffi_nop_autodlpack(name, x, y, z, repeat):
print_speed(name, speed)


def tvm_ffi_nop_autodlpack_from_torch(repeat, device="cpu"):
def tvm_ffi_nop_autodlpack_from_torch(repeat, device="cpu", stream=False):
"""
    Measures the overhead of running dlpack via auto-convert by directly
    taking torch.Tensor as inputs.
@@ -253,7 +254,13 @@ def tvm_ffi_nop_autodlpack_from_torch(repeat, device="cpu"):
x = torch.arange(1, device=device)
y = torch.arange(1, device=device)
z = torch.arange(1, device=device)
bench_tvm_ffi_nop_autodlpack(f"tvm.ffi.nop.autodlpack(torch[{device}])", x, y, z, repeat)
if stream:
with torch.cuda.stream(torch.cuda.Stream()):
bench_tvm_ffi_nop_autodlpack(
f"tvm.ffi.nop.autodlpack(torch[{device}][stream])", x, y, z, repeat
)
else:
bench_tvm_ffi_nop_autodlpack(f"tvm.ffi.nop.autodlpack(torch[{device}])", x, y, z, repeat)


def tvm_ffi_nop_autodlpack_from_numpy(repeat):
@@ -308,6 +315,50 @@ def bench_torch_utils_to_dlpack(repeat):
print_speed("torch.utils.dlpack.to_dlpack", speed)


def torch_get_cuda_stream_native(device_id):
return torch.cuda.current_stream(device_id).cuda_stream


def load_torch_get_current_cuda_stream():
"""Create a faster get_current_cuda_stream for torch through cpp extension."""
from torch.utils import cpp_extension

source = """
#include <c10/cuda/CUDAStream.h>

int64_t get_current_cuda_stream(int device_id) {
at::cuda::CUDAStream stream = at::cuda::getCurrentCUDAStream(device_id);
    // fast invariant: the default stream is always 0
if (stream.id() == 0) return 0;
// convert to cudaStream_t
return reinterpret_cast<int64_t>(static_cast<cudaStream_t>(stream));
}
"""
result = cpp_extension.load_inline(
name="get_current_cuda_stream",
cpp_sources=[source],
cuda_sources=[],
extra_cflags=["-O3"],
extra_include_paths=cpp_extension.include_paths("cuda"),
functions=["get_current_cuda_stream"],
)
return result.get_current_cuda_stream


def bench_torch_get_current_stream(repeat, name, func):
"""
    Measures the overhead of querying the current CUDA stream via func.
"""
    x = torch.arange(1, device="cuda")  # touch the GPU so the CUDA context is initialized
    func(0)  # warm-up call
start = time.time()
for i in range(repeat):
func(0)
end = time.time()
speed = (end - start) / repeat
print_speed(f"torch.cuda.current_stream[{name}]", speed)


def main():
repeat = 10000
print("-----------------------------")
@@ -323,6 +374,8 @@ def main():
tvm_ffi_nop_from_torch_utils_to_dlpack(repeat)
tvm_ffi_nop_autodlpack_from_torch(repeat, "cpu")
tvm_ffi_nop_autodlpack_from_torch(repeat, "cuda")
tvm_ffi_nop_autodlpack_from_torch(repeat, "cuda", stream=True)

tvm_ffi_nop_autodlpack_from_numpy(repeat)
print("-------------------------------")
print("Benchmark x.__dlpack__ overhead")
@@ -339,6 +392,19 @@ def main():
bench_to_dlpack_versioned(
tvm_ffi.from_dlpack(torch.arange(1)), "tvm.__dlpack__(max_version=(1,1))", repeat
)
print("---------------------------------------------------")
print("Benchmark torch.get_cuda_stream[default stream]")
print("---------------------------------------------------")
bench_torch_get_current_stream(repeat, "cpp-extension", load_torch_get_current_cuda_stream())
bench_torch_get_current_stream(repeat, "python", torch_get_cuda_stream_native)
print("---------------------------------------------------")
print("Benchmark torch.get_cuda_stream[non-default stream]")
print("---------------------------------------------------")
with torch.cuda.stream(torch.cuda.Stream()):
bench_torch_get_current_stream(
repeat, "cpp-extension", load_torch_get_current_cuda_stream()
)
bench_torch_get_current_stream(repeat, "python", torch_get_cuda_stream_native)


if __name__ == "__main__":
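The two stream-query paths benchmarked above should agree on the raw stream handle. A minimal sanity sketch, assuming a CUDA-enabled torch build and the load_torch_get_current_cuda_stream helper defined in this script:

import torch

fast_query = load_torch_get_current_cuda_stream()
# On the default stream both paths report 0 (the null cudaStream_t).
assert fast_query(0) == torch.cuda.current_stream(0).cuda_stream
# On a non-default stream both should report the same raw pointer.
with torch.cuda.stream(torch.cuda.Stream()):
    assert fast_query(0) == torch.cuda.current_stream(0).cuda_stream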
8 changes: 8 additions & 0 deletions python/tvm/ffi/cython/base.pxi
@@ -205,6 +205,14 @@ cdef extern from "tvm/ffi/c_api.h":
DLTensor* TVMFFINDArrayGetDLTensorPtr(TVMFFIObjectHandle obj) nogil
DLDevice TVMFFIDLDeviceFromIntPair(int32_t device_type, int32_t device_id) nogil

cdef extern from "tvm/ffi/extra/c_env_api.h":
ctypedef void* TVMFFIStreamHandle

void* TVMFFIEnvGetCurrentStream(int32_t device_type, int32_t device_id) nogil
int TVMFFIEnvSetStream(int32_t device_type, int32_t device_id,
TVMFFIStreamHandle stream,
TVMFFIStreamHandle* opt_out_original_stream) nogil


cdef class ByteArrayArg:
cdef TVMFFIByteArray cdata
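These externs define a save/restore protocol: TVMFFIEnvSetStream installs a stream for a (device_type, device_id) pair and optionally hands back the previously active handle so the caller can restore it afterwards. A rough Python analogue of that contract, using torch streams purely for illustration (call_on_stream is a hypothetical name, not part of this API):

import torch

def call_on_stream(fn, stream, device_id=0):
    # analogue of TVMFFIEnvSetStream(dev_type, dev_id, stream, &prev_stream)
    prev = torch.cuda.current_stream(device_id)
    torch.cuda.set_stream(stream)
    try:
        return fn()
    finally:
        # analogue of TVMFFIEnvSetStream(dev_type, dev_id, prev_stream, NULL),
        # performed only when the stream actually changed
        if prev != stream:
            torch.cuda.set_stream(prev)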
92 changes: 86 additions & 6 deletions python/tvm/ffi/cython/function.pxi
@@ -18,11 +18,51 @@ import ctypes
from numbers import Real, Integral

try:
# optionally import torch and setup torch related utils
import torch
except ImportError:
torch = None


def load_torch_get_current_cuda_stream():
"""Create a faster get_current_cuda_stream for torch through cpp extension.
"""
from torch.utils import cpp_extension

source = """
#include <c10/cuda/CUDAStream.h>

int64_t get_current_cuda_stream(int device_id) {
at::cuda::CUDAStream stream = at::cuda::getCurrentCUDAStream(device_id);
    // fast invariant: the default stream is always 0
if (stream.id() == 0) return 0;
// convert to cudaStream_t
return reinterpret_cast<int64_t>(static_cast<cudaStream_t>(stream));
}
"""
    def fallback_get_current_cuda_stream(device_id):
        """Fallback using the python API."""
        return torch.cuda.current_stream(device_id).cuda_stream

    try:
result = cpp_extension.load_inline(
name="get_current_cuda_stream",
cpp_sources=[source],
cuda_sources=[],
extra_cflags=["-O3"],
extra_include_paths=cpp_extension.include_paths("cuda"),
functions=["get_current_cuda_stream"],
)
return result.get_current_cuda_stream
except Exception:
return fallback_get_current_cuda_stream

if torch is not None:
    # when torch is available, JIT-compile the get_current_cuda_stream helper;
    # torch caches the built extension, so subsequent loads are fast
torch_get_current_cuda_stream = load_torch_get_current_cuda_stream()


cdef inline object make_ret_small_str(TVMFFIAny result):
"""convert small string to return value."""
cdef TVMFFIByteArray bytes
@@ -76,9 +116,13 @@ cdef inline object make_ret(TVMFFIAny result):
raise ValueError("Unhandled type index %d" % type_index)


cdef inline int make_args(tuple py_args, TVMFFIAny* out, list temp_args) except -1:
cdef inline int make_args(tuple py_args, TVMFFIAny* out, list temp_args,
int* ctx_dev_type, int* ctx_dev_id, TVMFFIStreamHandle* ctx_stream) except -1:
"""Pack arguments into c args tvm call accept"""
cdef unsigned long long ptr
cdef unsigned long long temp_ptr
cdef DLTensor* temp_dltensor
cdef int is_cuda = 0

for i, arg in enumerate(py_args):
# clear the value to ensure zero padding on 32bit platforms
if sizeof(void*) != 8:
@@ -96,10 +140,18 @@ cdef inline int make_args(tuple py_args, TVMFFIAny* out, list temp_args) except
out[i].type_index = TVMFFIObjectGetTypeIndex((<Object>arg).chandle)
out[i].v_ptr = (<Object>arg).chandle
elif torch is not None and isinstance(arg, torch.Tensor):
is_cuda = arg.is_cuda
arg = from_dlpack(torch.utils.dlpack.to_dlpack(arg),
required_alignment=__dlpack_auto_import_required_alignment__)
out[i].type_index = kTVMFFINDArray
out[i].v_ptr = (<NDArray>arg).chandle
temp_dltensor = TVMFFINDArrayGetDLTensorPtr((<NDArray>arg).chandle)
# record the stream and device for torch context
if is_cuda and ctx_dev_type != NULL and ctx_dev_type[0] == -1:
ctx_dev_type[0] = temp_dltensor.device.device_type
ctx_dev_id[0] = temp_dltensor.device.device_id
temp_ptr = torch_get_current_cuda_stream(temp_dltensor.device.device_id)
ctx_stream[0] = <TVMFFIStreamHandle>temp_ptr
temp_args.append(arg)
elif hasattr(arg, "__dlpack__"):
arg = from_dlpack(arg, required_alignment=__dlpack_auto_import_required_alignment__)
@@ -177,12 +229,27 @@ cdef inline int FuncCall3(void* chandle,
    # fast path with stack allocation for at most 3 args
cdef TVMFFIAny[3] packed_args
cdef int nargs = len(args)
cdef int ctx_dev_type = -1
cdef int ctx_dev_id = 0
cdef TVMFFIStreamHandle ctx_stream = NULL
cdef TVMFFIStreamHandle prev_stream = NULL
temp_args = []
make_args(args, &packed_args[0], temp_args)
make_args(args, &packed_args[0], temp_args, &ctx_dev_type, &ctx_dev_id, &ctx_stream)
with nogil:
if ctx_dev_type != -1:
            # switch to the stream recorded from the torch arguments
c_api_ret_code[0] = TVMFFIEnvSetStream(ctx_dev_type, ctx_dev_id, ctx_stream, &prev_stream)
if c_api_ret_code[0] != 0:
return 0
c_api_ret_code[0] = TVMFFIFunctionCall(
chandle, &packed_args[0], nargs, result
)
        # restore the original stream if it differs from the context stream
        if ctx_dev_type != -1 and prev_stream != ctx_stream:
c_api_ret_code[0] = TVMFFIEnvSetStream(ctx_dev_type, ctx_dev_id, prev_stream, NULL)
if c_api_ret_code[0] != 0:
return 0
return 0


@@ -191,6 +258,10 @@ cdef inline int FuncCall(void* chandle,
TVMFFIAny* result,
int* c_api_ret_code) except -1:
cdef int nargs = len(args)
cdef int ctx_dev_type = -1
cdef int ctx_dev_id = 0
cdef TVMFFIStreamHandle ctx_stream = NULL
cdef TVMFFIStreamHandle prev_stream = NULL

if nargs <= 3:
FuncCall3(chandle, args, result, c_api_ret_code)
@@ -200,10 +271,19 @@
packed_args.resize(nargs)

temp_args = []
make_args(args, &packed_args[0], temp_args)
make_args(args, &packed_args[0], temp_args, &ctx_dev_type, &ctx_dev_id, &ctx_stream)

with nogil:
if ctx_dev_type != -1:
c_api_ret_code[0] = TVMFFIEnvSetStream(ctx_dev_type, ctx_dev_id, ctx_stream, &prev_stream)
if c_api_ret_code[0] != 0:
return 0
c_api_ret_code[0] = TVMFFIFunctionCall(chandle, &packed_args[0], nargs, result)
# restore the original stream if it is not the same as the context stream
if ctx_dev_type != -1 and prev_stream != ctx_stream:
c_api_ret_code[0] = TVMFFIEnvSetStream(ctx_dev_type, ctx_dev_id, prev_stream, NULL)
if c_api_ret_code[0] != 0:
return 0

return 0

@@ -274,7 +354,7 @@ cdef class FieldSetter:
cdef void* field_ptr = (<char*>(<Object>obj).chandle) + self.offset
cdef int nargs = 1
temp_args = []
make_args((value,), &packed_args[0], temp_args)
make_args((value,), &packed_args[0], temp_args, NULL, NULL, NULL)
c_api_ret_code = self.setter(field_ptr, &packed_args[0])
# NOTE: logic is same as check_call
# directly inline here to simplify traceback
@@ -412,7 +492,7 @@ cdef int tvm_ffi_callback(void* context,
return -1

temp_args = []
make_args((rv,), &temp_result, temp_args)
make_args((rv,), &temp_result, temp_args, NULL, NULL, NULL)
CHECK_CALL(TVMFFIAnyViewToOwnedAny(&temp_result, result))

return 0
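The user-visible effect of the FuncCall changes: when a CUDA torch.Tensor goes through the auto-dlpack path, the callee now runs on torch's current stream, and the previously active tvm-ffi stream is restored after the call. A minimal sketch, assuming a registered no-op function such as the testing.nop exercised by the benchmark above and a get_global_func accessor:

import torch
from tvm import ffi as tvm_ffi

nop = tvm_ffi.get_global_func("testing.nop")  # assumed registered, as in the benchmark
x = torch.arange(1, device="cuda")
with torch.cuda.stream(torch.cuda.Stream()):
    # make_args records (device, stream) from the tensor; FuncCall brackets the
    # call with TVMFFIEnvSetStream set/restore, so it executes on this stream.
    nop(x, x, x)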