from torch_tensorrt.dynamo.backend.lowering import module_substitution


# This file serves as an example and a tutorial for excluding custom modules from
# torch.compile tracing. Each required step is labeled with a number indicating the
# preferable implementation order.


# 1. The Placeholder
#
# Specify the schema and namespace of the operator, as well as a placeholder function
# representing the schema. The schema should be in torch JIT syntax, indicating input and
# output types. The namespace, such as tensorrt, will cause the op to be registered as
# torch.ops.tensorrt.your_op. Then, create a placeholder function with no operations, but
# with the same schema and naming as used in the decorator.
@custom_op(
    "(Tensor x, int[1] kernel_size, int[1] stride=[], int[1] padding=[], int[1] dilation=[], bool ceil_mode=False) -> Tensor",
    ns="tensorrt",
)
def maxpool1d(x, kernel_size, stride=None, padding=0, dilation=1, ceil_mode=False):
    ...


# 2. The Generic Implementation
#
# Define the default implementation of the operator in torch syntax. This is used for
# autograd and other tracing functionality. Generally, the torch.nn.functional analog of
# the operator to replace is desirable. If the operator to replace is a custom module
# you've written, then add its Torch implementation here. Note that the function header
# of the generic function can have specific arguments, as in the placeholder above.
@maxpool1d.impl("cpu")
@maxpool1d.impl("cuda")
def maxpool1d_generic(
    *args,
    **kwargs,
):
    # Defines an implementation for AOT Autograd to use for shape analysis/propagation
    return torch.nn.functional.max_pool1d(
        *args,
        **kwargs,
    )

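# A quick sanity check (illustrative only; assumes the registrations above have already
# run): a direct call to the custom op should dispatch to the generic implementation
# registered for the input tensor's device, matching the stock functional op.
#
#     x = torch.randn(1, 4, 8)
#     out = torch.ops.tensorrt.maxpool1d(x, [2])   # dispatches to maxpool1d_generic on CPU
#     ref = torch.nn.functional.max_pool1d(x, [2])
#     # out and ref are expected to match
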
# 3. The Module Substitution Function
#
# Define a function which can intercept a node of the kind to be replaced, extract
# the relevant data from that node/submodule, and then re-package the information
# for use by an accelerated implementation (to be implemented in step 4). This function
# should use the operator defined in step 1 (for example torch.ops.tensorrt.maxpool1d).
# It should refactor the args and kwargs as needed by the accelerated implementation.
#
# If the submodule has weights or other Tensor fields which the accelerated implementation
# needs, the function should insert the necessary nodes to access those weights. For example,
# if the weight Tensor of a submodule is needed, one could write:
#
#     weights = gm.graph.get_attr(n.target + ".weight", torch.Tensor)
#     bias = gm.graph.get_attr(n.target + ".bias", torch.Tensor)
#     ...
#     kwargs={"weight": weights,
#             "bias": bias,
#             ...
#
@module_substitution(torch.nn.MaxPool1d, torch.ops.tensorrt.maxpool1d)
def maxpool1d_insertion_fn(
    gm: torch.fx.GraphModule, submodule: torch.nn.Module, node: torch.fx.Node
) -> torch.fx.Node:
    # Defines insertion function for new node
    new_node = gm.graph.call_function(
        torch.ops.tensorrt.maxpool1d,
        args=node.args,
        kwargs={
            "kernel_size": submodule.kernel_size,
            "stride": submodule.stride,
            "padding": submodule.padding,
            "dilation": submodule.dilation,
            "ceil_mode": submodule.ceil_mode,
        },
    )

    return new_node

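# How the insertion function is consumed (a rough sketch; the actual lowering pass lives
# in torch_tensorrt and may differ in detail): a pre-AOT pass walks the graph, finds
# call_module nodes whose submodule type has a registered substitution, invokes the
# registered insertion function, and swaps the nodes. Conceptually:
#
#     for n in gm.graph.nodes:
#         if n.op == "call_module" and isinstance(gm.get_submodule(n.target), torch.nn.MaxPool1d):
#             with gm.graph.inserting_before(n):
#                 new_n = maxpool1d_insertion_fn(gm, gm.get_submodule(n.target), n)
#             n.replace_all_uses_with(new_n)
#             gm.graph.erase_node(n)
#     gm.recompile()
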
# 4. The Accelerated Implementation
#
# Define an accelerated implementation of the operator, and register it as necessary.
# This accelerated implementation should consume the args/kwargs specified in step 3.
# One should expect that torch.compile will compress all kwargs into the args field in
# the order specified in the schema written in step 1.
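#
# For example (an assumption about the packed ordering, which should follow the schema
# from step 1), the converter below can expect its positional args roughly as:
#
#     args = (x, kernel_size, stride, padding, dilation, ceil_mode)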
@tensorrt_converter(torch.ops.tensorrt.maxpool1d.default)
def aten_ops_maxpool1d(
    network: TRTNetwork,
    # ... (remainder of the converter implementation elided in this excerpt) ...
    )

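# End-to-end usage (an illustrative sketch; assumes the torch_tensorrt dynamo backend
# is available to torch.compile under the backend name "torch_tensorrt"): once all of
# the registrations above have run, compiling a model containing nn.MaxPool1d should
# route the pooling operation through the converter above rather than tracing into the
# module itself.
#
#     model = torch.nn.Sequential(
#         torch.nn.Conv1d(3, 8, 3), torch.nn.MaxPool1d(2)
#     ).cuda().eval()
#     compiled = torch.compile(model, backend="torch_tensorrt")
#     out = compiled(torch.randn(1, 3, 32).cuda())
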
# 5. Add Imports
#
# Add your accelerated module file to the __init__.py in this directory, to ensure
# all registrations are run. For instance, if the new module file is called new_mod.py,
# one should add `from .new_mod import *` to the __init__.py
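#
# A sketch of what that __init__.py might contain (file names are illustrative; this
# assumes the present example lives in a file named maxpool1d.py):
#
#     from .maxpool1d import *
#     from .new_mod import *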