
Commit 3e43bf4

[FFI] AutoDLPack compatible with torch stream context (apache#18217)
This PR updates the AutoDLPack path to automatically update the env stream so that it stays consistent with the torch stream context. The change makes FFI functions compatible with stream-based execution. We leverage torch's cpp_extension.load_inline to create an efficient stream-query function; the first load may take longer because the JIT module has to be built, but calls are fast afterwards once the torch JIT module is cached.
1 parent 02f4b84 commit 3e43bf4
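
To illustrate the behavior this enables: when user code runs inside a torch stream context, the AutoDLPack path now queries torch's current stream and sets the FFI env stream to match, so the FFI call is ordered on the same stream as the surrounding torch work. A minimal sketch; tvm_ffi_func stands for a hypothetical AutoDLPack-wrapped FFI function, not a name from this commit:

import torch

s = torch.cuda.Stream()
with torch.cuda.stream(s):
    x = torch.arange(1024, device="cuda")
    # Inside the context, torch.cuda.current_stream() reports `s`; with
    # this change the FFI env stream is updated to match, so the call
    # below executes on `s` rather than on the default stream.
    tvm_ffi_func(x)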


scripts/benchmark_dlpack.py

Lines changed: 68 additions & 2 deletions
@@ -36,6 +36,7 @@
 -
 
 """
+import os
 import torch
 import numpy as np
 from tvm import ffi as tvm_ffi
@@ -244,7 +245,7 @@ def bench_tvm_ffi_nop_autodlpack(name, x, y, z, repeat):
     print_speed(name, speed)
 
 
-def tvm_ffi_nop_autodlpack_from_torch(repeat, device="cpu"):
+def tvm_ffi_nop_autodlpack_from_torch(repeat, device="cpu", stream=False):
     """
     Measures overhead of running dlpack via auto convert by directly
     taking torch.Tensor as inputs.
@@ -253,7 +254,13 @@ def tvm_ffi_nop_autodlpack_from_torch(repeat, device="cpu"):
     x = torch.arange(1, device=device)
     y = torch.arange(1, device=device)
     z = torch.arange(1, device=device)
-    bench_tvm_ffi_nop_autodlpack(f"tvm.ffi.nop.autodlpack(torch[{device}])", x, y, z, repeat)
+    if stream:
+        with torch.cuda.stream(torch.cuda.Stream()):
+            bench_tvm_ffi_nop_autodlpack(
+                f"tvm.ffi.nop.autodlpack(torch[{device}][stream])", x, y, z, repeat
+            )
+    else:
+        bench_tvm_ffi_nop_autodlpack(f"tvm.ffi.nop.autodlpack(torch[{device}])", x, y, z, repeat)
 
 
 def tvm_ffi_nop_autodlpack_from_numpy(repeat):
@@ -308,6 +315,50 @@ def bench_torch_utils_to_dlpack(repeat):
     print_speed("torch.utils.dlpack.to_dlpack", speed)
 
 
+def torch_get_cuda_stream_native(device_id):
+    return torch.cuda.current_stream(device_id).cuda_stream
+
+
+def load_torch_get_current_cuda_stream():
+    """Create a faster get_current_cuda_stream for torch through cpp extension."""
+    from torch.utils import cpp_extension
+
+    source = """
+    #include <c10/cuda/CUDAStream.h>
+
+    int64_t get_current_cuda_stream(int device_id) {
+        at::cuda::CUDAStream stream = at::cuda::getCurrentCUDAStream(device_id);
+        // fast invariant, default stream is always 0
+        if (stream.id() == 0) return 0;
+        // convert to cudaStream_t
+        return reinterpret_cast<int64_t>(static_cast<cudaStream_t>(stream));
+    }
+    """
+    result = cpp_extension.load_inline(
+        name="get_current_cuda_stream",
+        cpp_sources=[source],
+        cuda_sources=[],
+        extra_cflags=["-O3"],
+        extra_include_paths=cpp_extension.include_paths("cuda"),
+        functions=["get_current_cuda_stream"],
+    )
+    return result.get_current_cuda_stream
+
+
+def bench_torch_get_current_stream(repeat, name, func):
+    """
+    Measures overhead of running torch.cuda.current_stream
+    """
+    x = torch.arange(1, device="cuda")
+    func(0)
+    start = time.time()
+    for i in range(repeat):
+        func(0)
+    end = time.time()
+    speed = (end - start) / repeat
+    print_speed(f"torch.cuda.current_stream[{name}]", speed)
+
+
 def main():
     repeat = 10000
     print("-----------------------------")
@@ -323,6 +374,8 @@ def main():
     tvm_ffi_nop_from_torch_utils_to_dlpack(repeat)
     tvm_ffi_nop_autodlpack_from_torch(repeat, "cpu")
     tvm_ffi_nop_autodlpack_from_torch(repeat, "cuda")
+    tvm_ffi_nop_autodlpack_from_torch(repeat, "cuda", stream=True)
+
     tvm_ffi_nop_autodlpack_from_numpy(repeat)
     print("-------------------------------")
     print("Benchmark x.__dlpack__ overhead")
@@ -339,6 +392,19 @@ def main():
     bench_to_dlpack_versioned(
         tvm_ffi.from_dlpack(torch.arange(1)), "tvm.__dlpack__(max_version=(1,1))", repeat
     )
+    print("---------------------------------------------------")
+    print("Benchmark torch.get_cuda_stream[default stream]")
+    print("---------------------------------------------------")
+    bench_torch_get_current_stream(repeat, "cpp-extension", load_torch_get_current_cuda_stream())
+    bench_torch_get_current_stream(repeat, "python", torch_get_cuda_stream_native)
+    print("---------------------------------------------------")
+    print("Benchmark torch.get_cuda_stream[non-default stream]")
+    print("---------------------------------------------------")
+    with torch.cuda.stream(torch.cuda.Stream()):
+        bench_torch_get_current_stream(
+            repeat, "cpp-extension", load_torch_get_current_cuda_stream()
+        )
+        bench_torch_get_current_stream(repeat, "python", torch_get_cuda_stream_native)
 
 
 if __name__ == "__main__":
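
As a quick sanity check (not part of the commit), the two stream-query paths added above should agree; a minimal sketch assuming a CUDA-enabled torch build:

import torch

fast_query = load_torch_get_current_cuda_stream()  # first call builds the JIT module

# On the default stream the extension should return 0 (the fast invariant).
assert fast_query(0) == 0

# On a non-default stream both paths should report the same cudaStream_t value.
with torch.cuda.stream(torch.cuda.Stream()):
    assert fast_query(0) == torch_get_cuda_stream_native(0)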
