
Commit 88c304e

update apis
1 parent c3cd651 commit 88c304e

2 files changed: +58 -60 lines changed


examples/dynamo/aot_plugin.py

Lines changed: 44 additions & 56 deletions
@@ -1,15 +1,13 @@
 import argparse
 from typing import Tuple, Union

-
 import tensorrt as trt
 import tensorrt.plugin as trtp
 import torch
 import torch_tensorrt
 import triton
 import triton.language as tl

-
 trt_logger = trt.Logger(trt.Logger.VERBOSE)


@@ -25,9 +23,7 @@ def add_one_kernel(x_ptr, n_elements, y_ptr, BLOCK_SIZE: tl.constexpr):


 @torch.library.custom_op("my::add_one", mutates_args=())  # type: ignore[misc]
-def add_one(
-    X: torch.Tensor
-) -> torch.Tensor:
+def add_one(X: torch.Tensor) -> torch.Tensor:
     # Ensure the tensors are on the GPU
     assert X.is_cuda

@@ -51,63 +47,58 @@ def _(X: torch.Tensor) -> torch.Tensor:
     return X


-# torch_tensorrt.dynamo.conversion.plugins.generate_plugin(
-#     "my::add_one"
-# )
-
 @trtp.register("my::add_one")
 def add_plugin_desc(X: trtp.TensorDesc) -> Tuple[trtp.TensorDesc]:
     return X.like()

-@trtp.aot_impl("my::add_one")
-def add_plugin_aot_impl(
-    X: trtp.TensorDesc, outputs: Tuple[trtp.TensorDesc], tactic: int
-) -> Tuple[Union[str, bytes], Union[str, bytes], trtp.KernelLaunchParams, trtp.SymExprs]:
-
-
-    type_str = "fp32" if X.dtype == trt.float32 else "fp16"
-
-    block_size = 256
-    src = triton.compiler.ASTSource(
-        fn=add_one_kernel,
-        signature={
-            "x_ptr": f"*{type_str}",
-            "n_elements": "i32",
-            "y_ptr": f"*{type_str}",
-            "BLOCK_SIZE": "constexpr",
-        },
-        constants={
-            "BLOCK_SIZE": block_size,
-        },
-    )
-
-    compiled_kernel = triton.compile(src)
-
-    N = X.shape_expr.numel()
-    launch_params = trtp.KernelLaunchParams()

-    # grid dims
-    launch_params.grid_x = trtp.cdiv(N, block_size)
-    # block dims
-    launch_params.block_x = compiled_kernel.metadata.num_warps * 32
-    # shared memory
-    launch_params.shared_mem = compiled_kernel.metadata.shared
-
-    extra_args = trtp.SymIntExprs(1)
-    extra_args[0] = trtp.SymInt32(N)
-
-    return (
-        compiled_kernel.metadata.name,
-        compiled_kernel.asm["ptx"],
-        launch_params,
-        extra_args,
-    )
+# @trtp.aot_impl("my::add_one")
+# def add_plugin_aot_impl(
+#     X: trtp.TensorDesc, outputs: Tuple[trtp.TensorDesc], tactic: int
+# ) -> Tuple[Union[str, bytes], Union[str, bytes], trtp.KernelLaunchParams, trtp.SymExprs]:
+#     type_str = "fp32" if X.dtype == trt.float32 else "fp16"
+
+#     block_size = 256
+#     src = triton.compiler.ASTSource(
+#         fn=add_one_kernel,
+#         signature={
+#             "x_ptr": f"*{type_str}",
+#             "n_elements": "i32",
+#             "y_ptr": f"*{type_str}",
+#             "BLOCK_SIZE": "constexpr",
+#         },
+#         constants={
+#             "BLOCK_SIZE": block_size,
+#         },
+#     )
+
+#     compiled_kernel = triton.compile(src)
+
+#     N = X.shape_expr.numel()
+#     launch_params = trtp.KernelLaunchParams()
+
+#     # grid dims
+#     launch_params.grid_x = trtp.cdiv(N, block_size)
+#     # block dims
+#     launch_params.block_x = compiled_kernel.metadata.num_warps * 32
+#     # shared memory
+#     launch_params.shared_mem = compiled_kernel.metadata.shared
+
+#     extra_args = trtp.SymIntExprs(1)
+#     extra_args[0] = trtp.SymInt32(N)
+
+#     return (
+#         compiled_kernel.metadata.name,
+#         compiled_kernel.asm["ptx"],
+#         launch_params,
+#         extra_args,
+#     )

 torch_tensorrt.dynamo.conversion.plugins.generate_plugin_converter(
     "my::add_one",
     supports_dynamic_shapes=False,
     requires_output_allocator=False,
-    aot=True,
+    use_aot_if_available=True,
 )


@@ -129,15 +120,12 @@ def forward(self, X: torch.Tensor) -> torch.Tensor:
     )
     args = parser.parse_args()

-
-
     my_model = MyModel().to("cuda")
     m = torch.full((64, 64), 2, device="cuda", dtype=torch.float)

     # This works!
     assert my_model(X=m)[0][0] == 3.0

-
     with torch_tensorrt.logging.debug():
         trt_inputs = [m]
         model_trt = torch_tensorrt.compile(
@@ -153,4 +141,4 @@ def forward(self, X: torch.Tensor) -> torch.Tensor:
     assert torch.allclose(res, my_model(m)), "Results do not match!"

     print("Inference successful!")
-    print(res)
+    print(res)
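
Note on the example change: the @trtp.aot_impl implementation is commented out while the converter registration switches from aot=True to use_aot_if_available=True, so the example now exercises the fallback to the JIT plugin path when no AOT implementation is registered. A minimal sketch of the registration call after this commit (the call and its arguments are taken from the diff; the fallback behavior follows from the converter change below):

import torch_tensorrt

# AOT is requested only if an @trtp.aot_impl implementation exists for
# "my::add_one"; otherwise the converter falls back to the JIT plugin.
torch_tensorrt.dynamo.conversion.plugins.generate_plugin_converter(
    "my::add_one",
    supports_dynamic_shapes=False,
    requires_output_allocator=False,
    use_aot_if_available=True,
)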

py/torch_tensorrt/dynamo/conversion/plugins/_generate_plugin_converter.py

Lines changed: 14 additions & 4 deletions
@@ -31,7 +31,7 @@ def _generate_plugin_converter(
     priority: ConverterPriority = ConverterPriority.STANDARD,
     supports_dynamic_shapes: bool = False,
     requires_output_allocator: bool = False,
-    aot: bool = False,
+    use_aot_if_available: bool = False,
 ) -> DynamoConverterImplSignature:
     torch_target = getattr(getattr(torch.ops, namespace), op_name)
     overload_str = overload if overload else ""
@@ -42,6 +42,16 @@ def _generate_plugin_converter(
     ), f"Could not find a tensorrt plugin registered for op {namespace}::{op_name}, unable to generate converter"
     torch_schema = torch_target._schemas[overload_str]

+    use_aot_plugin = use_aot_if_available
+
+    if use_aot_if_available:
+        desc = QDP_REGISTRY[f"{namespace}::{op_name}"]
+        if desc.aot_impl_func is None:
+            use_aot_plugin = False
+            _LOGGER.debug(
+                f"AOT impl func not found for {namespace}::{op_name}, use JIT plugin instead"
+            )
+
     def custom_kernel_converter(
         ctx: ConversionContext,
         target: Target,
@@ -81,7 +91,7 @@ def custom_kernel_converter(
             if isinstance(v, torch.fx.immutable_collections.immutable_list):
                 kwargs[k] = np.array(v)

-        layer = ctx.net.add_plugin(plugin(*itensor_args, **kwargs), aot=aot)
+        layer = ctx.net.add_plugin(plugin(*itensor_args, **kwargs), aot=use_aot_plugin)
         assert layer, f"{namespace}::{name} plugin layer was not able to be created"
         _LOGGER.debug(
             f"Adding generated plugin for {namespace}::{name} to tensorrt network"
@@ -108,7 +118,7 @@ def generate_plugin_converter(
     priority: ConverterPriority = ConverterPriority.STANDARD,
     supports_dynamic_shapes: bool = False,
     requires_output_allocator: bool = False,
-    aot: bool = False,
+    use_aot_if_available: bool = False,
 ) -> DynamoConverterImplSignature:
     plugin_ns, plugin_name = plugin_id.split("::")
     return _generate_plugin_converter(
@@ -118,5 +128,5 @@ def generate_plugin_converter(
         priority=priority,
         supports_dynamic_shapes=supports_dynamic_shapes,
         requires_output_allocator=requires_output_allocator,
-        aot=aot,
+        use_aot_if_available=use_aot_if_available,
     )
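
For readability, the new selection logic can be summarized as the following standalone sketch; the helper name _resolve_use_aot is hypothetical, while QDP_REGISTRY and the aot_impl_func attribute are the ones referenced in the diff above:

# Hypothetical standalone form of the fallback decision added to
# _generate_plugin_converter: request an AOT plugin only when the registered
# quick-deployable plugin descriptor actually provides an AOT implementation.
def _resolve_use_aot(namespace: str, op_name: str, use_aot_if_available: bool) -> bool:
    if not use_aot_if_available:
        return False
    desc = QDP_REGISTRY[f"{namespace}::{op_name}"]  # registry used by the module above
    if desc.aot_impl_func is None:
        # No @trtp.aot_impl registered for this op: fall back to the JIT plugin.
        return False
    return True

The resulting flag is what the converter passes to ctx.net.add_plugin(..., aot=...), so use_aot_if_available=True is expected to degrade gracefully to a JIT plugin rather than requiring an AOT implementation to be present.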
