@@ -1,15 +1,13 @@
 import argparse
 from typing import Tuple, Union
 
-
 import tensorrt as trt
 import tensorrt.plugin as trtp
 import torch
 import torch_tensorrt
 import triton
 import triton.language as tl
 
-
 trt_logger = trt.Logger(trt.Logger.VERBOSE)
 
 
@@ -25,9 +23,7 @@ def add_one_kernel(x_ptr, n_elements, y_ptr, BLOCK_SIZE: tl.constexpr):
 
 
 @torch.library.custom_op("my::add_one", mutates_args=())  # type: ignore[misc]
-def add_one(
-    X: torch.Tensor
-) -> torch.Tensor:
+def add_one(X: torch.Tensor) -> torch.Tensor:
     # Ensure the tensors are on the GPU
     assert X.is_cuda
 
@@ -51,63 +47,58 @@ def _(X: torch.Tensor) -> torch.Tensor:
     return X
 
 
-# torch_tensorrt.dynamo.conversion.plugins.generate_plugin(
-#     "my::add_one"
-# )
-
 @trtp.register("my::add_one")
 def add_plugin_desc(X: trtp.TensorDesc) -> Tuple[trtp.TensorDesc]:
     return X.like()
 
-@trtp.aot_impl("my::add_one")
-def add_plugin_aot_impl(
-    X: trtp.TensorDesc, outputs: Tuple[trtp.TensorDesc], tactic: int
-) -> Tuple[Union[str, bytes], Union[str, bytes], trtp.KernelLaunchParams, trtp.SymExprs]:
-
-
-    type_str = "fp32" if X.dtype == trt.float32 else "fp16"
-
-    block_size = 256
-    src = triton.compiler.ASTSource(
-        fn=add_one_kernel,
-        signature={
-            "x_ptr": f"*{type_str}",
-            "n_elements": "i32",
-            "y_ptr": f"*{type_str}",
-            "BLOCK_SIZE": "constexpr",
-        },
-        constants={
-            "BLOCK_SIZE": block_size,
-        },
-    )
-
-    compiled_kernel = triton.compile(src)
-
-    N = X.shape_expr.numel()
-    launch_params = trtp.KernelLaunchParams()
 
-    # grid dims
-    launch_params.grid_x = trtp.cdiv(N, block_size)
-    # block dims
-    launch_params.block_x = compiled_kernel.metadata.num_warps * 32
-    # shared memory
-    launch_params.shared_mem = compiled_kernel.metadata.shared
-
-    extra_args = trtp.SymIntExprs(1)
-    extra_args[0] = trtp.SymInt32(N)
-
-    return (
-        compiled_kernel.metadata.name,
-        compiled_kernel.asm["ptx"],
-        launch_params,
-        extra_args,
-    )
+# @trtp.aot_impl("my::add_one")
+# def add_plugin_aot_impl(
+#     X: trtp.TensorDesc, outputs: Tuple[trtp.TensorDesc], tactic: int
+# ) -> Tuple[Union[str, bytes], Union[str, bytes], trtp.KernelLaunchParams, trtp.SymExprs]:
+#     type_str = "fp32" if X.dtype == trt.float32 else "fp16"
+
+#     block_size = 256
+#     src = triton.compiler.ASTSource(
+#         fn=add_one_kernel,
+#         signature={
+#             "x_ptr": f"*{type_str}",
+#             "n_elements": "i32",
+#             "y_ptr": f"*{type_str}",
+#             "BLOCK_SIZE": "constexpr",
+#         },
+#         constants={
+#             "BLOCK_SIZE": block_size,
+#         },
+#     )
+
+#     compiled_kernel = triton.compile(src)
+
+#     N = X.shape_expr.numel()
+#     launch_params = trtp.KernelLaunchParams()
+
+#     # grid dims
+#     launch_params.grid_x = trtp.cdiv(N, block_size)
+#     # block dims
+#     launch_params.block_x = compiled_kernel.metadata.num_warps * 32
+#     # shared memory
+#     launch_params.shared_mem = compiled_kernel.metadata.shared
+
+#     extra_args = trtp.SymIntExprs(1)
+#     extra_args[0] = trtp.SymInt32(N)
+
+#     return (
+#         compiled_kernel.metadata.name,
+#         compiled_kernel.asm["ptx"],
+#         launch_params,
+#         extra_args,
+#     )
 
 torch_tensorrt.dynamo.conversion.plugins.generate_plugin_converter(
     "my::add_one",
     supports_dynamic_shapes=False,
     requires_output_allocator=False,
-    aot=True,
+    use_aot_if_available=True,
 )
 
 
@@ -129,15 +120,12 @@ def forward(self, X: torch.Tensor) -> torch.Tensor:
     )
     args = parser.parse_args()
 
-
-
     my_model = MyModel().to("cuda")
     m = torch.full((64, 64), 2, device="cuda", dtype=torch.float)
 
     # This works!
     assert my_model(X=m)[0][0] == 3.0
 
-
     with torch_tensorrt.logging.debug():
         trt_inputs = [m]
         model_trt = torch_tensorrt.compile(
@@ -153,4 +141,4 @@ def forward(self, X: torch.Tensor) -> torch.Tensor:
     assert torch.allclose(res, my_model(m)), "Results do not match!"
 
     print("Inference successful!")
-    print(res)
+    print(res)
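
Note: with the @trtp.aot_impl registration commented out, use_aot_if_available=True presumably leaves the converter on the just-in-time plugin path, which expects a @trtp.impl function registered for my::add_one. Below is a minimal sketch of such a JIT fallback, assuming the tensorrt.plugin trtp.impl decorator and DLPack interop via torch.as_tensor; the name add_plugin_impl and the 256-thread block size are illustrative, not part of this diff.

@trtp.impl("my::add_one")
def add_plugin_impl(X: trtp.Tensor, outputs: Tuple[trtp.Tensor], stream: int) -> None:
    # Hypothetical JIT fallback: view the TensorRT-owned device buffers as
    # torch tensors (zero-copy, via DLPack) and run the Triton kernel eagerly.
    x = torch.as_tensor(X, device="cuda")
    y = torch.as_tensor(outputs[0], device="cuda")
    n = x.numel()

    block_size = 256
    # Launch on the stream TensorRT provides rather than torch's current stream.
    with torch.cuda.stream(torch.cuda.ExternalStream(stream)):
        add_one_kernel[(triton.cdiv(n, block_size),)](x, n, y, BLOCK_SIZE=block_size)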