@@ -36,6 +36,7 @@
 -
 
 """
+import os
 import torch
 import numpy as np
 from tvm import ffi as tvm_ffi
@@ -244,7 +245,7 @@ def bench_tvm_ffi_nop_autodlpack(name, x, y, z, repeat):
     print_speed(name, speed)
 
 
-def tvm_ffi_nop_autodlpack_from_torch(repeat, device="cpu"):
+def tvm_ffi_nop_autodlpack_from_torch(repeat, device="cpu", stream=False):
     """
     Measures overhead of running dlpack via auto convert by directly
     taking torch.Tensor as inputs.
@@ -253,7 +254,13 @@ def tvm_ffi_nop_autodlpack_from_torch(repeat, device="cpu"):
     x = torch.arange(1, device=device)
     y = torch.arange(1, device=device)
     z = torch.arange(1, device=device)
-    bench_tvm_ffi_nop_autodlpack(f"tvm.ffi.nop.autodlpack(torch[{device}])", x, y, z, repeat)
+    if stream:
+        with torch.cuda.stream(torch.cuda.Stream()):
+            bench_tvm_ffi_nop_autodlpack(
+                f"tvm.ffi.nop.autodlpack(torch[{device}][stream])", x, y, z, repeat
+            )
+    else:
+        bench_tvm_ffi_nop_autodlpack(f"tvm.ffi.nop.autodlpack(torch[{device}])", x, y, z, repeat)
 
 
 def tvm_ffi_nop_autodlpack_from_numpy(repeat):
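
Context for the `stream=True` branch above: `torch.cuda.stream(torch.cuda.Stream())` makes a freshly created non-default stream current for the body of the `with` block, so the auto-DLPack conversion has to look up and pass along that stream instead of the default one. A minimal sketch of that semantics, assuming a CUDA-enabled PyTorch build:

import torch

if torch.cuda.is_available():
    s = torch.cuda.Stream()
    # on the legacy default stream, the raw handle is typically 0
    print(torch.cuda.current_stream().cuda_stream)
    with torch.cuda.stream(s):
        # inside the context, current_stream() reports the new stream
        assert torch.cuda.current_stream().cuda_stream == s.cuda_stream
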
@@ -308,6 +315,50 @@ def bench_torch_utils_to_dlpack(repeat):
     print_speed("torch.utils.dlpack.to_dlpack", speed)
 
 
+def torch_get_cuda_stream_native(device_id):
+    return torch.cuda.current_stream(device_id).cuda_stream
+
+
+def load_torch_get_current_cuda_stream():
+    """Create a faster get_current_cuda_stream for torch through a C++ extension."""
+    from torch.utils import cpp_extension
+
+    source = """
+    #include <c10/cuda/CUDAStream.h>
+
+    int64_t get_current_cuda_stream(int device_id) {
+        at::cuda::CUDAStream stream = at::cuda::getCurrentCUDAStream(device_id);
+        // fast invariant, default stream is always 0
+        if (stream.id() == 0) return 0;
+        // convert to cudaStream_t
+        return reinterpret_cast<int64_t>(static_cast<cudaStream_t>(stream));
+    }
+    """
+    result = cpp_extension.load_inline(
+        name="get_current_cuda_stream",
+        cpp_sources=[source],
+        cuda_sources=[],
+        extra_cflags=["-O3"],
+        extra_include_paths=cpp_extension.include_paths("cuda"),
+        functions=["get_current_cuda_stream"],
+    )
+    return result.get_current_cuda_stream
+
+
+def bench_torch_get_current_stream(repeat, name, func):
+    """
+    Measures overhead of running torch.cuda.current_stream
+    """
+    x = torch.arange(1, device="cuda")
+    func(0)
+    start = time.time()
+    for i in range(repeat):
+        func(0)
+    end = time.time()
+    speed = (end - start) / repeat
+    print_speed(f"torch.cuda.current_stream[{name}]", speed)
+
+
 def main():
     repeat = 10000
     print("-----------------------------")
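
Why the inline C++ helper above can beat the Python API: each call to `torch.cuda.current_stream(device_id)` builds a `torch.cuda.Stream` wrapper object before its raw `cuda_stream` handle can be read, whereas the extension calls `at::cuda::getCurrentCUDAStream` directly and returns the handle as a plain integer, with an early return of 0 for the default stream. A hedged sanity check, assuming a CUDA device and a working compiler toolchain for `cpp_extension.load_inline` (the first call pays a one-time JIT build cost, which is cached afterwards):

get_stream_cpp = load_torch_get_current_cuda_stream()
# both getters should agree on the raw cudaStream_t value
assert get_stream_cpp(0) == torch.cuda.current_stream(0).cuda_stream
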
@@ -323,6 +374,8 @@ def main():
     tvm_ffi_nop_from_torch_utils_to_dlpack(repeat)
     tvm_ffi_nop_autodlpack_from_torch(repeat, "cpu")
     tvm_ffi_nop_autodlpack_from_torch(repeat, "cuda")
+    tvm_ffi_nop_autodlpack_from_torch(repeat, "cuda", stream=True)
+
     tvm_ffi_nop_autodlpack_from_numpy(repeat)
     print("-------------------------------")
     print("Benchmark x.__dlpack__ overhead")
@@ -339,6 +392,19 @@ def main():
     bench_to_dlpack_versioned(
         tvm_ffi.from_dlpack(torch.arange(1)), "tvm.__dlpack__(max_version=(1,1))", repeat
     )
+    print("---------------------------------------------------")
+    print("Benchmark torch.get_cuda_stream[default stream]")
+    print("---------------------------------------------------")
+    bench_torch_get_current_stream(repeat, "cpp-extension", load_torch_get_current_cuda_stream())
+    bench_torch_get_current_stream(repeat, "python", torch_get_cuda_stream_native)
+    print("---------------------------------------------------")
+    print("Benchmark torch.get_cuda_stream[non-default stream]")
+    print("---------------------------------------------------")
+    with torch.cuda.stream(torch.cuda.Stream()):
+        bench_torch_get_current_stream(
+            repeat, "cpp-extension", load_torch_get_current_cuda_stream()
+        )
+        bench_torch_get_current_stream(repeat, "python", torch_get_cuda_stream_native)
 
 
 if __name__ == "__main__":
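
Note that the default-stream and non-default-stream blocks added to main() exercise the two branches of the C++ helper: with the default stream current it returns through the `stream.id() == 0` fast path, while inside the `torch.cuda.stream(...)` context it performs the `cudaStream_t` conversion. A quick way to observe both branches, under the same CUDA and toolchain assumptions as above:

get_stream_cpp = load_torch_get_current_cuda_stream()
assert get_stream_cpp(0) == 0          # fast path: default stream
with torch.cuda.stream(torch.cuda.Stream()):
    assert get_stream_cpp(0) != 0      # conversion path: real stream handle
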