diff --git a/py/torch_tensorrt/dynamo/conversion/_ConversionContext.py b/py/torch_tensorrt/dynamo/conversion/_ConversionContext.py new file mode 100644 index 0000000000..37581f76cd --- /dev/null +++ b/py/torch_tensorrt/dynamo/conversion/_ConversionContext.py @@ -0,0 +1,19 @@ +from dataclasses import dataclass, field + +from torch_tensorrt.dynamo._settings import CompilationSettings +from torch_tensorrt.fx.types import TRTNetwork + + +@dataclass +class ConversionContext: + """Class representing the context for conversion of a particular network + + Args: + net: TensorRT Network being built + compilation_settings: Settings selected by the user for compilation + """ + + net: TRTNetwork + compilation_settings: CompilationSettings = field( + default_factory=CompilationSettings + ) diff --git a/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py b/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py index 9f3dc5deb9..206636a637 100644 --- a/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py +++ b/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py @@ -13,7 +13,13 @@ from torch.fx.passes.shape_prop import TensorMetadata from torch.utils._python_dispatch import _disable_current_modes from torch_tensorrt._Input import Input -from torch_tensorrt.dynamo.conversion.converter_utils import get_node_name +from torch_tensorrt.dynamo._settings import CompilationSettings +from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext +from torch_tensorrt.dynamo.conversion.converter_registry import CallingConvention +from torch_tensorrt.dynamo.conversion.converter_utils import ( + get_node_name, + get_trt_tensor, +) from torch_tensorrt.fx.observer import Observer from torch_tensorrt.fx.utils import Frameworks, unified_dtype_converter @@ -46,6 +52,7 @@ def __init__( input_specs: List[Input], logger_level: trt.ILogger.Severity = trt.ILogger.Severity.WARNING, output_dtypes: Optional[List[torch.dtype]] = None, + compilation_settings: CompilationSettings = CompilationSettings(), ): super().__init__(module) @@ -59,7 +66,9 @@ def __init__( EXPLICIT_BATCH = 1 << (int)(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH) flag |= EXPLICIT_BATCH - self.network = self.builder.create_network(flag) + self.ctx = ConversionContext( + self.builder.create_network(flag), compilation_settings + ) missing_ops = self.validate_conversion() if missing_ops: @@ -95,14 +104,14 @@ def validate_conversion(self) -> Set[str]: missing_converters: Set[str] = set() for node in self.module.graph.nodes: - if node.op == "call_function" and not CONVERTERS.get(node): + if node.op == "call_function" and CONVERTERS.get(node) is None: missing_converters.add(f"{node.op} {_get_qualified_name(node.target)}") - elif node.op == "call_method" and not CONVERTERS.get(node): + elif node.op == "call_method" and CONVERTERS.get(node) is None: missing_converters.add(f"{node.op} torch.Tensor.{node.target}") elif node.op == "call_module": submod = self.fetch_attr(node.target) submod_type = getattr(submod, "_base_class_origin", type(submod)) - if not CONVERTERS.get(node): + if CONVERTERS.get(node) is None: missing_converters.add(f"{node.op} {torch.typename(submod_type)}") return missing_converters @@ -221,7 +230,7 @@ def run( if tactic_sources is not None: builder_config.set_tactic_sources(tactic_sources=tactic_sources) - engine = self.builder.build_engine(self.network, builder_config) + engine = self.builder.build_engine(self.ctx.net, builder_config) assert engine serialized_cache = ( @@ -291,7 +300,7 @@ def placeholder(self, target: 
str, args: Any, kwargs: Any) -> trt.ITensor: f"Unable to access shape spec for input: {target} (got: {current_input})" ) - return self.network.add_input( + return self.ctx.net.add_input( name=target, shape=tuple(shape), dtype=unified_dtype_converter(current_input.torch_dtype, Frameworks.TRT), @@ -303,30 +312,40 @@ def call_module( assert isinstance(target, str) submod = self.fetch_attr(target) submod_type = getattr(submod, "_base_class_origin", type(submod)) - converter = CONVERTERS.get(self._cur_node) + converter_packet = CONVERTERS.get(self._cur_node) - if not converter: + if converter_packet is None: raise UnsupportedOperatorException( f"Conversion of module of type {submod_type} not currently supported!" ) + converter, calling_convention = converter_packet + assert self._cur_node_name is not None - return converter(self.network, submod, args, kwargs, self._cur_node_name) + if calling_convention is CallingConvention.LEGACY: + return converter(self.ctx.net, submod, args, kwargs, self._cur_node_name) + else: + return converter(self.ctx, submod, args, kwargs, self._cur_node_name) def call_function(self, target: str, args: Any, kwargs: Any) -> Any: # TODO: Why is this stateful? We should be able to take in the inputs - converter = CONVERTERS.get(self._cur_node) - if not converter: + converter_packet = CONVERTERS.get(self._cur_node) + if converter_packet is None: raise UnsupportedOperatorException( f"Conversion of function {torch.typename(target)} not currently supported!" ) + converter, calling_convention = converter_packet + assert self._cur_node_name is not None - return converter(self.network, target, args, kwargs, self._cur_node_name) + if calling_convention is CallingConvention.LEGACY: + return converter(self.ctx.net, target, args, kwargs, self._cur_node_name) + else: + return converter(self.ctx, target, args, kwargs, self._cur_node_name) def get_attr(self, target: str, args: Any, kwargs: Any) -> np.ndarray: with _disable_current_modes(): - from torch_tensorrt.fx.converters import to_numpy + from torch_tensorrt.dynamo.conversion.converter_utils import to_numpy frozen_attr = self.fetch_attr(target) @@ -341,15 +360,19 @@ def get_attr(self, target: str, args: Any, kwargs: Any) -> np.ndarray: def call_method(self, target: str, args: Any, kwargs: Any) -> Any: assert isinstance(target, str) - converter = CONVERTERS.get(self._cur_node) + converter_packet = CONVERTERS.get(self._cur_node) - if not converter: + if converter_packet is None: raise UnsupportedOperatorException( f"Conversion of method {target} not currently supported!" 
) + converter, calling_convention = converter_packet assert self._cur_node_name is not None - return converter(self.network, target, args, kwargs, self._cur_node_name) + if calling_convention is CallingConvention.LEGACY: + return converter(self.ctx.net, target, args, kwargs, self._cur_node_name) + else: + return converter(self.ctx, target, args, kwargs, self._cur_node_name) def output(self, target: str, args: Any, kwargs: Any) -> List[Any]: assert len(args) == 1 @@ -361,12 +384,10 @@ def output(self, target: str, args: Any, kwargs: Any) -> List[Any]: outputs = (args[0],) for output_idx in range(len(outputs)): - from torch_tensorrt.dynamo.conversion.converter_utils import get_trt_tensor - output = outputs[output_idx] if not isinstance(output, trt.tensorrt.ITensor): - new_output = get_trt_tensor(self.network, output, target) + new_output = get_trt_tensor(self.ctx, output, target) outputs = ( outputs[:output_idx] + (new_output,) + outputs[output_idx + 1 :] ) @@ -400,7 +421,7 @@ def output(self, target: str, args: Any, kwargs: Any) -> List[Any]: output_bool = False name = f"output{i}" output.name = name - self.network.mark_output(output) + self.ctx.net.mark_output(output) if output_bool: output.dtype = trt.bool elif self.output_dtypes is not None: diff --git a/py/torch_tensorrt/dynamo/conversion/__init__.py b/py/torch_tensorrt/dynamo/conversion/__init__.py index 9cbfff950e..3fabb1bb45 100644 --- a/py/torch_tensorrt/dynamo/conversion/__init__.py +++ b/py/torch_tensorrt/dynamo/conversion/__init__.py @@ -1,3 +1,4 @@ +from ._ConversionContext import ConversionContext from ._TRTInterpreter import * # noqa: F403 from .aten_ops_converters import * # noqa: F403 from .conversion import * # noqa: F403 diff --git a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py index dd18be9151..d844fe1995 100644 --- a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py +++ b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py @@ -1,16 +1,21 @@ import logging from typing import Any, Callable, Dict, Optional, Sequence, Tuple, Union +import numpy as np import torch from torch.fx.node import Argument, Node, Target from torch_tensorrt.dynamo._SourceIR import SourceIR from torch_tensorrt.dynamo.conversion import impl +from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext +from torch_tensorrt.dynamo.conversion.converter_registry import ( + dynamo_tensorrt_converter, +) from torch_tensorrt.dynamo.conversion.converter_utils import ( + enforce_tensor_types, is_only_operator_on_placeholder, ) -from torch_tensorrt.fx.types import TRTNetwork, TRTTensor +from torch_tensorrt.fx.types import TRTTensor -from .converter_registry import dynamo_tensorrt_converter from .converter_utils import dynamic_unsupported_with_args _LOGGER: logging.Logger = logging.getLogger(__name__) @@ -24,14 +29,14 @@ def args_bounds_check( @dynamo_tensorrt_converter(torch.ops.aten.batch_norm) # type: ignore[misc] def aten_ops_batch_norm( - network: TRTNetwork, + ctx: ConversionContext, target: Target, args: Tuple[Argument, ...], kwargs: Dict[str, Argument], name: str, ) -> Union[TRTTensor, Sequence[TRTTensor]]: return impl.normalization.batch_norm( - network, + ctx, target, SourceIR.ATEN, name, @@ -67,14 +72,14 @@ def embedding_param_validator(embedding_node: Node) -> bool: torch.ops.aten.embedding.default, capability_validator=embedding_param_validator ) # type: ignore[misc] def aten_ops_embedding( - network: TRTNetwork, + ctx: ConversionContext, 
target: Target, args: Tuple[Argument, ...], kwargs: Dict[str, Argument], name: str, ) -> Union[TRTTensor, Sequence[TRTTensor]]: return impl.embedding.embedding( - network, + ctx, target, SourceIR.ATEN, name, @@ -89,25 +94,25 @@ def aten_ops_embedding( @dynamo_tensorrt_converter(torch.ops.aten.fmod.Scalar) # type: ignore[misc] @dynamo_tensorrt_converter(torch.ops.aten.fmod.Tensor) # type: ignore[misc] def aten_ops_fmod( - network: TRTNetwork, + ctx: ConversionContext, target: Target, args: Tuple[Argument, ...], kwargs: Dict[str, Argument], name: str, ) -> Union[TRTTensor, Sequence[TRTTensor]]: - return impl.elementwise.fmod(network, target, SourceIR.ATEN, name, args[0], args[1]) + return impl.elementwise.fmod(ctx, target, SourceIR.ATEN, name, args[0], args[1]) @dynamo_tensorrt_converter(torch.ops.aten.relu.default) # type: ignore[misc] def aten_ops_relu( - network: TRTNetwork, + ctx: ConversionContext, target: Target, args: Tuple[Argument, ...], kwargs: Dict[str, Argument], name: str, ) -> Union[TRTTensor, Sequence[TRTTensor]]: return impl.activation.relu( - network, + ctx, target, SourceIR.ATEN, name, @@ -117,14 +122,14 @@ def aten_ops_relu( @dynamo_tensorrt_converter(torch.ops.aten.sigmoid.default) # type: ignore[misc] def aten_ops_sigmoid( - network: TRTNetwork, + ctx: ConversionContext, target: Target, args: Tuple[Argument, ...], kwargs: Dict[str, Argument], name: str, ) -> Union[TRTTensor, Sequence[TRTTensor]]: return impl.activation.sigmoid( - network, + ctx, target, SourceIR.ATEN, name, @@ -134,14 +139,14 @@ def aten_ops_sigmoid( @dynamo_tensorrt_converter(torch.ops.aten.tanh.default) # type: ignore[misc] def aten_ops_tanh( - network: TRTNetwork, + ctx: ConversionContext, target: Target, args: Tuple[Argument, ...], kwargs: Dict[str, Argument], name: str, ) -> Union[TRTTensor, Sequence[TRTTensor]]: return impl.activation.tanh( - network, + ctx, target, SourceIR.ATEN, name, @@ -151,14 +156,14 @@ def aten_ops_tanh( @dynamo_tensorrt_converter(torch.ops.aten.leaky_relu.default) # type: ignore[misc] def aten_ops_leaky_relu( - network: TRTNetwork, + ctx: ConversionContext, target: Target, args: Tuple[Argument, ...], kwargs: Dict[str, Argument], name: str, ) -> Union[TRTTensor, Sequence[TRTTensor]]: return impl.activation.leaky_relu( - network, + ctx, target, SourceIR.ATEN, name, @@ -169,14 +174,14 @@ def aten_ops_leaky_relu( @dynamo_tensorrt_converter(torch.ops.aten.elu.default) # type: ignore[misc] def aten_ops_elu( - network: TRTNetwork, + ctx: ConversionContext, target: Target, args: Tuple[Argument, ...], kwargs: Dict[str, Argument], name: str, ) -> Union[TRTTensor, Sequence[TRTTensor]]: return impl.activation.elu( - network, + ctx, target, SourceIR.ATEN, name, @@ -188,14 +193,14 @@ def aten_ops_elu( @dynamo_tensorrt_converter(torch.ops.aten.softplus.default) # type: ignore[misc] def aten_ops_softplus( - network: TRTNetwork, + ctx: ConversionContext, target: Target, args: Tuple[Argument, ...], kwargs: Dict[str, Argument], name: str, ) -> Union[TRTTensor, Sequence[TRTTensor]]: return impl.activation.softplus( - network, + ctx, target, SourceIR.ATEN, name, @@ -206,14 +211,14 @@ def aten_ops_softplus( @dynamo_tensorrt_converter(torch.ops.aten.clip.default) # type: ignore[misc] def aten_ops_clip( - network: TRTNetwork, + ctx: ConversionContext, target: Target, args: Tuple[Argument, ...], kwargs: Dict[str, Argument], name: str, ) -> Union[TRTTensor, Sequence[TRTTensor]]: return impl.activation.clip( - network, + ctx, target, SourceIR.ATEN, name, @@ -225,14 +230,14 @@ def aten_ops_clip( 
@dynamo_tensorrt_converter(torch.ops.aten.hardsigmoid.default) # type: ignore[misc] def aten_ops_hard_sigmoid( - network: TRTNetwork, + ctx: ConversionContext, target: Target, args: Tuple[Argument, ...], kwargs: Dict[str, Argument], name: str, ) -> Union[TRTTensor, Sequence[TRTTensor]]: return impl.activation.hard_sigmoid( - network, + ctx, target, SourceIR.ATEN, name, @@ -247,14 +252,14 @@ def aten_ops_hard_sigmoid( @dynamo_tensorrt_converter(torch.ops.aten.mv.default) # type: ignore[misc] @dynamo_tensorrt_converter(torch.ops.aten.bmm.default) # type: ignore[misc] def aten_ops_matmul( - network: TRTNetwork, + ctx: ConversionContext, target: Target, args: Tuple[Argument, ...], kwargs: Dict[str, Argument], name: str, ) -> Union[TRTTensor, Sequence[TRTTensor]]: return impl.matmul.matrix_multiply( - network, + ctx, target, SourceIR.ATEN, name, @@ -265,14 +270,14 @@ def aten_ops_matmul( @dynamo_tensorrt_converter(torch.ops.aten.layer_norm.default) # type: ignore[misc] def aten_ops_layernorm( - network: TRTNetwork, + ctx: ConversionContext, target: Target, args: Tuple[Argument, ...], kwargs: Dict[str, Argument], name: str, ) -> Union[TRTTensor, Sequence[TRTTensor]]: return impl.normalization.layer_norm( - network, + ctx, target, SourceIR.ATEN, name, @@ -286,14 +291,14 @@ def aten_ops_layernorm( @dynamo_tensorrt_converter(torch.ops.aten.rsqrt.default) # type: ignore[misc] def aten_ops_rsqrt( - network: TRTNetwork, + ctx: ConversionContext, target: Target, args: Tuple[Argument, ...], kwargs: Dict[str, Argument], name: str, ) -> Union[TRTTensor, Sequence[TRTTensor]]: return impl.elementwise.rsqrt( - network, + ctx, target, SourceIR.ATEN, name, @@ -303,14 +308,14 @@ def aten_ops_rsqrt( @dynamo_tensorrt_converter(torch.ops.aten.neg.default) # type: ignore[misc] def aten_ops_neg( - network: TRTNetwork, + ctx: ConversionContext, target: Target, args: Tuple[Argument, ...], kwargs: Dict[str, Argument], name: str, ) -> Union[TRTTensor, Sequence[TRTTensor]]: return impl.unary.neg( - network, + ctx, target, SourceIR.ATEN, name, @@ -321,25 +326,25 @@ def aten_ops_neg( @dynamo_tensorrt_converter(torch.ops.aten.squeeze.dim) # type: ignore[misc] @dynamo_tensorrt_converter(torch.ops.aten.squeeze.dims) # type: ignore[misc] def aten_ops_squeeze( - network: TRTNetwork, + ctx: ConversionContext, target: Target, args: Tuple[Argument, ...], kwargs: Dict[str, Argument], name: str, ) -> Union[TRTTensor, Sequence[TRTTensor]]: - return impl.squeeze.squeeze(network, target, SourceIR.ATEN, name, args[0], args[1]) + return impl.squeeze.squeeze(ctx, target, SourceIR.ATEN, name, args[0], args[1]) @dynamo_tensorrt_converter(torch.ops.aten.erf.default) # type: ignore[misc] def aten_ops_erf( - network: TRTNetwork, + ctx: ConversionContext, target: Target, args: Tuple[Argument, ...], kwargs: Dict[str, Argument], name: str, ) -> Union[TRTTensor, Sequence[TRTTensor]]: return impl.unary.erf( - network, + ctx, target, SourceIR.ATEN, name, @@ -349,27 +354,27 @@ def aten_ops_erf( @dynamo_tensorrt_converter(torch.ops.aten.unsqueeze.default) # type: ignore[misc] def aten_ops_unsqueeze( - network: TRTNetwork, + ctx: ConversionContext, target: Target, args: Tuple[Argument, ...], kwargs: Dict[str, Argument], name: str, ) -> Union[TRTTensor, Sequence[TRTTensor]]: return impl.unsqueeze.unsqueeze( - network, target, SourceIR.ATEN, name, input_t=args[0], dim=args[1] + ctx, target, SourceIR.ATEN, name, input_t=args[0], dim=args[1] ) @dynamo_tensorrt_converter(torch.ops.aten._softmax.default) # type: ignore[misc] def aten_ops_softmax( - network: 
TRTNetwork, + ctx: ConversionContext, target: Target, args: Tuple[Argument, ...], kwargs: Dict[str, Argument], name: str, ) -> Union[TRTTensor, Sequence[TRTTensor]]: return impl.normalization.softmax( - network, target, SourceIR.ATEN, name, args[0], args[1] + ctx, target, SourceIR.ATEN, name, args[0], args[1] ) @@ -384,14 +389,14 @@ def aten_ops_softmax( capability_validator=dynamic_unsupported_with_args([1]), ) # type: ignore[misc] def aten_ops_split( - network: TRTNetwork, + ctx: ConversionContext, target: Target, args: Tuple[Argument, ...], kwargs: Dict[str, Argument], name: str, ) -> Union[TRTTensor, Sequence[TRTTensor]]: return impl.split.split( - network, + ctx, target, SourceIR.ATEN, name, @@ -403,14 +408,14 @@ def aten_ops_split( @dynamo_tensorrt_converter(torch.ops.aten.where.self) # type: ignore[misc] def aten_ops_where( - network: TRTNetwork, + ctx: ConversionContext, target: Target, args: Tuple[Argument, ...], kwargs: Dict[str, Argument], name: str, ) -> Union[TRTTensor, Sequence[TRTTensor]]: return impl.condition.where( - network, + ctx, target, SourceIR.ATEN, name, @@ -422,14 +427,14 @@ def aten_ops_where( @dynamo_tensorrt_converter(torch.ops.aten.clamp.default) # type: ignore[misc] def aten_ops_clamp( - network: TRTNetwork, + ctx: ConversionContext, target: Target, args: Tuple[Argument, ...], kwargs: Dict[str, Argument], name: str, ) -> Union[TRTTensor, Sequence[TRTTensor]]: return impl.elementwise.clamp( - network, + ctx, target, SourceIR.ATEN, name, @@ -441,27 +446,27 @@ def aten_ops_clamp( @dynamo_tensorrt_converter(torch.ops.aten.select.int) # type: ignore[misc] def aten_ops_select( - network: TRTNetwork, + ctx: ConversionContext, target: Target, args: Tuple[Argument, ...], kwargs: Dict[str, Argument], name: str, ) -> Union[TRTTensor, Sequence[TRTTensor]]: return impl.select.select( - network, target, SourceIR.ATEN, name, args[0], args[1], args[2] + ctx, target, SourceIR.ATEN, name, args[0], args[1], args[2] ) @dynamo_tensorrt_converter(torch.ops.aten.slice.Tensor) # type: ignore[misc] def aten_ops_slice( - network: TRTNetwork, + ctx: ConversionContext, target: Target, args: Tuple[Argument, ...], kwargs: Dict[str, Argument], name: str, ) -> Union[TRTTensor, Sequence[TRTTensor]]: return impl.slice.slice_op( - network, + ctx, target, SourceIR.ATEN, name, @@ -474,15 +479,20 @@ def aten_ops_slice( @dynamo_tensorrt_converter(torch.ops.aten.permute.default) # type: ignore[misc] +@enforce_tensor_types( + { + 0: (TRTTensor,), + } +) # type: ignore[misc] def aten_ops_permute( - network: TRTNetwork, + ctx: ConversionContext, target: Target, args: Tuple[Argument, ...], kwargs: Dict[str, Argument], name: str, ) -> Union[TRTTensor, Sequence[TRTTensor]]: return impl.permutation.permute( - network, + ctx, target, SourceIR.ATEN, name, @@ -544,14 +554,14 @@ def validator(to_copy_node: Node) -> bool: capability_validator=to_copy_dtype_validator(placeholder_only=False), ) # type: ignore[misc] def aten_ops_clone_copy_dtype( - network: TRTNetwork, + ctx: ConversionContext, target: Target, args: Tuple[Argument, ...], kwargs: Dict[str, Argument], name: str, ) -> Union[TRTTensor, Sequence[TRTTensor]]: return impl.cast.to_copy( - network, + ctx, target, SourceIR.ATEN, name, @@ -570,7 +580,7 @@ def aten_ops_clone_copy_dtype( capability_validator=to_copy_dtype_validator(placeholder_only=True), ) # type: ignore[misc] def aten_ops_clone_copy_placeholder( - network: TRTNetwork, + ctx: ConversionContext, target: Target, args: Tuple[Argument, ...], kwargs: Dict[str, Argument], @@ -580,7 +590,7 @@ def 
aten_ops_clone_copy_placeholder( # we need to force cast to ensure a layer is added to the TRT engine # since TRT engine inputs cannot also be TRT engine outputs return impl.cast.to_copy( - network, + ctx, target, SourceIR.ATEN, name, @@ -592,14 +602,14 @@ def aten_ops_clone_copy_placeholder( @dynamo_tensorrt_converter(torch.ops.aten.expand.default) # type: ignore[misc] def aten_ops_expand( - network: TRTNetwork, + ctx: ConversionContext, target: Target, args: Tuple[Argument, ...], kwargs: Dict[str, Argument], name: str, ) -> Union[TRTTensor, Sequence[TRTTensor]]: return impl.slice.expand( - network, + ctx, target, SourceIR.ATEN, name, @@ -622,14 +632,14 @@ def amax_param_validator(amax_node: Node) -> bool: torch.ops.aten.amax.default, capability_validator=amax_param_validator ) # type: ignore[misc] def aten_ops_amax( - network: TRTNetwork, + ctx: ConversionContext, target: Target, args: Tuple[Argument, ...], kwargs: Dict[str, Argument], name: str, ) -> Union[TRTTensor, Sequence[TRTTensor]]: return impl.reduce.amax( - network, + ctx, target, SourceIR.ATEN, name, @@ -642,14 +652,14 @@ def aten_ops_amax( @dynamo_tensorrt_converter(torch.ops.aten.sum.default) # type: ignore[misc] @dynamo_tensorrt_converter(torch.ops.aten.sum.dim_IntList) # type: ignore[misc] def aten_ops_sum( - network: TRTNetwork, + ctx: ConversionContext, target: Target, args: Tuple[Argument, ...], kwargs: Dict[str, Argument], name: str, ) -> Union[TRTTensor, Sequence[TRTTensor]]: return impl.reduce.sum( - network, + ctx, target, SourceIR.ATEN, name, @@ -661,14 +671,14 @@ def aten_ops_sum( @dynamo_tensorrt_converter(torch.ops.aten.exp.default) # type: ignore[misc] def aten_ops_exp( - network: TRTNetwork, + ctx: ConversionContext, target: Target, args: Tuple[Argument, ...], kwargs: Dict[str, Argument], name: str, ) -> Union[TRTTensor, Sequence[TRTTensor]]: return impl.unary.exp( - network, + ctx, target, SourceIR.ATEN, name, @@ -678,14 +688,14 @@ def aten_ops_exp( @dynamo_tensorrt_converter(torch.ops.aten.log.default) # type: ignore[misc] def aten_ops_log( - network: TRTNetwork, + ctx: ConversionContext, target: Target, args: Tuple[Argument, ...], kwargs: Dict[str, Argument], name: str, ) -> Union[TRTTensor, Sequence[TRTTensor]]: return impl.unary.log( - network, + ctx, target, SourceIR.ATEN, name, @@ -695,14 +705,14 @@ def aten_ops_log( @dynamo_tensorrt_converter(torch.ops.aten.sqrt.default) # type: ignore[misc] def aten_ops_sqrt( - network: TRTNetwork, + ctx: ConversionContext, target: Target, args: Tuple[Argument, ...], kwargs: Dict[str, Argument], name: str, ) -> Union[TRTTensor, Sequence[TRTTensor]]: return impl.unary.sqrt( - network, + ctx, target, SourceIR.ATEN, name, @@ -712,14 +722,14 @@ def aten_ops_sqrt( @dynamo_tensorrt_converter(torch.ops.aten.reciprocal.default) # type: ignore[misc] def aten_ops_recip( - network: TRTNetwork, + ctx: ConversionContext, target: Target, args: Tuple[Argument, ...], kwargs: Dict[str, Argument], name: str, ) -> Union[TRTTensor, Sequence[TRTTensor]]: return impl.unary.recip( - network, + ctx, target, SourceIR.ATEN, name, @@ -729,14 +739,14 @@ def aten_ops_recip( @dynamo_tensorrt_converter(torch.ops.aten.abs.default) # type: ignore[misc] def aten_ops_abs( - network: TRTNetwork, + ctx: ConversionContext, target: Target, args: Tuple[Argument, ...], kwargs: Dict[str, Argument], name: str, ) -> Union[TRTTensor, Sequence[TRTTensor]]: return impl.unary.abs( - network, + ctx, target, SourceIR.ATEN, name, @@ -746,14 +756,14 @@ def aten_ops_abs( 
@dynamo_tensorrt_converter(torch.ops.aten.sin.default) # type: ignore[misc] def aten_ops_sin( - network: TRTNetwork, + ctx: ConversionContext, target: Target, args: Tuple[Argument, ...], kwargs: Dict[str, Argument], name: str, ) -> Union[TRTTensor, Sequence[TRTTensor]]: return impl.unary.sin( - network, + ctx, target, SourceIR.ATEN, name, @@ -763,14 +773,14 @@ def aten_ops_sin( @dynamo_tensorrt_converter(torch.ops.aten.cos.default) # type: ignore[misc] def aten_ops_cos( - network: TRTNetwork, + ctx: ConversionContext, target: Target, args: Tuple[Argument, ...], kwargs: Dict[str, Argument], name: str, ) -> Union[TRTTensor, Sequence[TRTTensor]]: return impl.unary.cos( - network, + ctx, target, SourceIR.ATEN, name, @@ -780,14 +790,14 @@ def aten_ops_cos( @dynamo_tensorrt_converter(torch.ops.aten.tan.default) # type: ignore[misc] def aten_ops_tan( - network: TRTNetwork, + ctx: ConversionContext, target: Target, args: Tuple[Argument, ...], kwargs: Dict[str, Argument], name: str, ) -> Union[TRTTensor, Sequence[TRTTensor]]: return impl.unary.tan( - network, + ctx, target, SourceIR.ATEN, name, @@ -797,14 +807,14 @@ def aten_ops_tan( @dynamo_tensorrt_converter(torch.ops.aten.sinh.default) # type: ignore[misc] def aten_ops_sinh( - network: TRTNetwork, + ctx: ConversionContext, target: Target, args: Tuple[Argument, ...], kwargs: Dict[str, Argument], name: str, ) -> Union[TRTTensor, Sequence[TRTTensor]]: return impl.unary.sinh( - network, + ctx, target, SourceIR.ATEN, name, @@ -814,14 +824,14 @@ def aten_ops_sinh( @dynamo_tensorrt_converter(torch.ops.aten.cosh.default) # type: ignore[misc] def aten_ops_cosh( - network: TRTNetwork, + ctx: ConversionContext, target: Target, args: Tuple[Argument, ...], kwargs: Dict[str, Argument], name: str, ) -> Union[TRTTensor, Sequence[TRTTensor]]: return impl.unary.cosh( - network, + ctx, target, SourceIR.ATEN, name, @@ -831,14 +841,14 @@ def aten_ops_cosh( @dynamo_tensorrt_converter(torch.ops.aten.asin.default) # type: ignore[misc] def aten_ops_asin( - network: TRTNetwork, + ctx: ConversionContext, target: Target, args: Tuple[Argument, ...], kwargs: Dict[str, Argument], name: str, ) -> Union[TRTTensor, Sequence[TRTTensor]]: return impl.unary.asin( - network, + ctx, target, SourceIR.ATEN, name, @@ -848,14 +858,14 @@ def aten_ops_asin( @dynamo_tensorrt_converter(torch.ops.aten.acos.default) # type: ignore[misc] def aten_ops_acos( - network: TRTNetwork, + ctx: ConversionContext, target: Target, args: Tuple[Argument, ...], kwargs: Dict[str, Argument], name: str, ) -> Union[TRTTensor, Sequence[TRTTensor]]: return impl.unary.acos( - network, + ctx, target, SourceIR.ATEN, name, @@ -865,14 +875,14 @@ def aten_ops_acos( @dynamo_tensorrt_converter(torch.ops.aten.atan.default) # type: ignore[misc] def aten_ops_atan( - network: TRTNetwork, + ctx: ConversionContext, target: Target, args: Tuple[Argument, ...], kwargs: Dict[str, Argument], name: str, ) -> Union[TRTTensor, Sequence[TRTTensor]]: return impl.unary.atan( - network, + ctx, target, SourceIR.ATEN, name, @@ -882,14 +892,14 @@ def aten_ops_atan( @dynamo_tensorrt_converter(torch.ops.aten.asinh.default) # type: ignore[misc] def aten_ops_asinh( - network: TRTNetwork, + ctx: ConversionContext, target: Target, args: Tuple[Argument, ...], kwargs: Dict[str, Argument], name: str, ) -> Union[TRTTensor, Sequence[TRTTensor]]: return impl.unary.asinh( - network, + ctx, target, SourceIR.ATEN, name, @@ -899,14 +909,14 @@ def aten_ops_asinh( @dynamo_tensorrt_converter(torch.ops.aten.acosh.default) # type: ignore[misc] def aten_ops_acosh( 
- network: TRTNetwork, + ctx: ConversionContext, target: Target, args: Tuple[Argument, ...], kwargs: Dict[str, Argument], name: str, ) -> Union[TRTTensor, Sequence[TRTTensor]]: return impl.unary.acosh( - network, + ctx, target, SourceIR.ATEN, name, @@ -916,14 +926,14 @@ def aten_ops_acosh( @dynamo_tensorrt_converter(torch.ops.aten.atanh.default) # type: ignore[misc] def aten_ops_atanh( - network: TRTNetwork, + ctx: ConversionContext, target: Target, args: Tuple[Argument, ...], kwargs: Dict[str, Argument], name: str, ) -> Union[TRTTensor, Sequence[TRTTensor]]: return impl.unary.atanh( - network, + ctx, target, SourceIR.ATEN, name, @@ -933,14 +943,14 @@ def aten_ops_atanh( @dynamo_tensorrt_converter(torch.ops.aten.ceil.default) # type: ignore[misc] def aten_ops_ceil( - network: TRTNetwork, + ctx: ConversionContext, target: Target, args: Tuple[Argument, ...], kwargs: Dict[str, Argument], name: str, ) -> Union[TRTTensor, Sequence[TRTTensor]]: return impl.unary.ceil( - network, + ctx, target, SourceIR.ATEN, name, @@ -950,14 +960,14 @@ def aten_ops_ceil( @dynamo_tensorrt_converter(torch.ops.aten.floor.default) # type: ignore[misc] def aten_ops_floor( - network: TRTNetwork, + ctx: ConversionContext, target: Target, args: Tuple[Argument, ...], kwargs: Dict[str, Argument], name: str, ) -> Union[TRTTensor, Sequence[TRTTensor]]: return impl.unary.floor( - network, + ctx, target, SourceIR.ATEN, name, @@ -967,14 +977,14 @@ def aten_ops_floor( @dynamo_tensorrt_converter(torch.ops.aten.logical_not.default) # type: ignore[misc] def aten_ops_logical_not( - network: TRTNetwork, + ctx: ConversionContext, target: Target, args: Tuple[Argument, ...], kwargs: Dict[str, Argument], name: str, ) -> Union[TRTTensor, Sequence[TRTTensor]]: return impl.unary.logical_not( - network, + ctx, target, SourceIR.ATEN, name, @@ -984,14 +994,14 @@ def aten_ops_logical_not( @dynamo_tensorrt_converter(torch.ops.aten.sign.default) # type: ignore[misc] def aten_ops_sign( - network: TRTNetwork, + ctx: ConversionContext, target: Target, args: Tuple[Argument, ...], kwargs: Dict[str, Argument], name: str, ) -> Union[TRTTensor, Sequence[TRTTensor]]: return impl.unary.sign( - network, + ctx, target, SourceIR.ATEN, name, @@ -1001,14 +1011,14 @@ def aten_ops_sign( @dynamo_tensorrt_converter(torch.ops.aten.round.default) # type: ignore[misc] def aten_ops_round( - network: TRTNetwork, + ctx: ConversionContext, target: Target, args: Tuple[Argument, ...], kwargs: Dict[str, Argument], name: str, ) -> Union[TRTTensor, Sequence[TRTTensor]]: return impl.unary.round( - network, + ctx, target, SourceIR.ATEN, name, @@ -1018,14 +1028,14 @@ def aten_ops_round( @dynamo_tensorrt_converter(torch.ops.aten.isinf.default) # type: ignore[misc] def aten_ops_isinf( - network: TRTNetwork, + ctx: ConversionContext, target: Target, args: Tuple[Argument, ...], kwargs: Dict[str, Argument], name: str, ) -> Union[TRTTensor, Sequence[TRTTensor]]: return impl.unary.isinf( - network, + ctx, target, SourceIR.ATEN, name, @@ -1036,7 +1046,7 @@ def aten_ops_isinf( @dynamo_tensorrt_converter(torch.ops.aten.add.Tensor) # type: ignore[misc] @dynamo_tensorrt_converter(torch.ops.aten.add.Scalar) # type: ignore[misc] def aten_ops_add( - network: TRTNetwork, + ctx: ConversionContext, target: Target, args: Tuple[Argument, ...], kwargs: Dict[str, Argument], @@ -1047,7 +1057,7 @@ def aten_ops_add( if alpha != 1: other = impl.elementwise.mul( - network, + ctx, target, SourceIR.ATEN, name, @@ -1056,7 +1066,7 @@ def aten_ops_add( ) return impl.elementwise.add( - network, + ctx, target, 
SourceIR.ATEN, name, @@ -1068,14 +1078,14 @@ def aten_ops_add( @dynamo_tensorrt_converter(torch.ops.aten.mul.Tensor) # type: ignore[misc] @dynamo_tensorrt_converter(torch.ops.aten.mul.Scalar) # type: ignore[misc] def aten_ops_mul( - network: TRTNetwork, + ctx: ConversionContext, target: Target, args: Tuple[Argument, ...], kwargs: Dict[str, Argument], name: str, ) -> Union[TRTTensor, Sequence[TRTTensor]]: return impl.elementwise.mul( - network, + ctx, target, SourceIR.ATEN, name, @@ -1086,14 +1096,14 @@ def aten_ops_mul( @dynamo_tensorrt_converter(torch.ops.aten.maximum.default) # type: ignore[misc] def aten_ops_max( - network: TRTNetwork, + ctx: ConversionContext, target: Target, args: Tuple[Argument, ...], kwargs: Dict[str, Argument], name: str, ) -> Union[TRTTensor, Sequence[TRTTensor]]: return impl.elementwise.max( - network, + ctx, target, SourceIR.ATEN, name, @@ -1104,14 +1114,14 @@ def aten_ops_max( @dynamo_tensorrt_converter(torch.ops.aten.minimum.default) # type: ignore[misc] def aten_ops_min( - network: TRTNetwork, + ctx: ConversionContext, target: Target, args: Tuple[Argument, ...], kwargs: Dict[str, Argument], name: str, ) -> Union[TRTTensor, Sequence[TRTTensor]]: return impl.elementwise.min( - network, + ctx, target, SourceIR.ATEN, name, @@ -1123,7 +1133,7 @@ def aten_ops_min( @dynamo_tensorrt_converter(torch.ops.aten.sub.Tensor) # type: ignore[misc] @dynamo_tensorrt_converter(torch.ops.aten.sub.Scalar) # type: ignore[misc] def aten_ops_sub( - network: TRTNetwork, + ctx: ConversionContext, target: Target, args: Tuple[Argument, ...], kwargs: Dict[str, Argument], @@ -1134,7 +1144,7 @@ def aten_ops_sub( if alpha != 1: other = impl.elementwise.mul( - network, + ctx, target, SourceIR.ATEN, name, @@ -1143,7 +1153,7 @@ def aten_ops_sub( ) return impl.elementwise.sub( - network, + ctx, target, SourceIR.ATEN, name, @@ -1157,7 +1167,7 @@ def aten_ops_sub( @dynamo_tensorrt_converter(torch.ops.aten.div.Scalar) # type: ignore[misc] @dynamo_tensorrt_converter(torch.ops.aten.div.Scalar_mode) # type: ignore[misc] def aten_ops_div( - network: TRTNetwork, + ctx: ConversionContext, target: Target, args: Tuple[Argument, ...], kwargs: Dict[str, Argument], @@ -1167,7 +1177,7 @@ def aten_ops_div( if rounding_mode is None: return impl.elementwise.div( - network, + ctx, target, SourceIR.ATEN, name, @@ -1176,7 +1186,7 @@ def aten_ops_div( ) elif rounding_mode == "floor": return impl.elementwise.floor_divide( - network, + ctx, target, SourceIR.ATEN, name, @@ -1185,7 +1195,7 @@ def aten_ops_div( ) elif rounding_mode == "trunc": return impl.elementwise.trunc_div( - network, + ctx, target, SourceIR.ATEN, name, @@ -1202,14 +1212,14 @@ def aten_ops_div( @dynamo_tensorrt_converter(torch.ops.aten.pow.Scalar) # type: ignore[misc] @dynamo_tensorrt_converter(torch.ops.aten.pow.Tensor_Scalar) # type: ignore[misc] def aten_ops_pow( - network: TRTNetwork, + ctx: ConversionContext, target: Target, args: Tuple[Argument, ...], kwargs: Dict[str, Argument], name: str, ) -> Union[TRTTensor, Sequence[TRTTensor]]: return impl.elementwise.pow( - network, + ctx, target, SourceIR.ATEN, name, @@ -1221,14 +1231,14 @@ def aten_ops_pow( @dynamo_tensorrt_converter(torch.ops.aten.floor_divide.default) # type: ignore[misc] @dynamo_tensorrt_converter(torch.ops.aten.floor_divide.Scalar) # type: ignore[misc] def aten_ops_floor_div( - network: TRTNetwork, + ctx: ConversionContext, target: Target, args: Tuple[Argument, ...], kwargs: Dict[str, Argument], name: str, ) -> Union[TRTTensor, Sequence[TRTTensor]]: return 
impl.elementwise.floor_divide( - network, + ctx, target, SourceIR.ATEN, name, @@ -1239,14 +1249,14 @@ def aten_ops_floor_div( @dynamo_tensorrt_converter(torch.ops.aten.logical_and.default) # type: ignore[misc] def aten_ops_logical_and( - network: TRTNetwork, + ctx: ConversionContext, target: Target, args: Tuple[Argument, ...], kwargs: Dict[str, Argument], name: str, ) -> Union[TRTTensor, Sequence[TRTTensor]]: return impl.elementwise.logical_and( - network, + ctx, target, SourceIR.ATEN, name, @@ -1257,14 +1267,14 @@ def aten_ops_logical_and( @dynamo_tensorrt_converter(torch.ops.aten.logical_or.default) # type: ignore[misc] def aten_ops_logical_or( - network: TRTNetwork, + ctx: ConversionContext, target: Target, args: Tuple[Argument, ...], kwargs: Dict[str, Argument], name: str, ) -> Union[TRTTensor, Sequence[TRTTensor]]: return impl.elementwise.logical_or( - network, + ctx, target, SourceIR.ATEN, name, @@ -1275,14 +1285,14 @@ def aten_ops_logical_or( @dynamo_tensorrt_converter(torch.ops.aten.logical_xor.default) # type: ignore[misc] def aten_ops_logical_xor( - network: TRTNetwork, + ctx: ConversionContext, target: Target, args: Tuple[Argument, ...], kwargs: Dict[str, Argument], name: str, ) -> Union[TRTTensor, Sequence[TRTTensor]]: return impl.elementwise.logical_xor( - network, + ctx, target, SourceIR.ATEN, name, @@ -1294,14 +1304,14 @@ def aten_ops_logical_xor( @dynamo_tensorrt_converter(torch.ops.aten.eq.Tensor) # type: ignore[misc] @dynamo_tensorrt_converter(torch.ops.aten.eq.Scalar) # type: ignore[misc] def aten_ops_equal( - network: TRTNetwork, + ctx: ConversionContext, target: Target, args: Tuple[Argument, ...], kwargs: Dict[str, Argument], name: str, ) -> Union[TRTTensor, Sequence[TRTTensor]]: return impl.elementwise.eq( - network, + ctx, target, SourceIR.ATEN, name, @@ -1313,14 +1323,14 @@ def aten_ops_equal( @dynamo_tensorrt_converter(torch.ops.aten.gt.Tensor) # type: ignore[misc] @dynamo_tensorrt_converter(torch.ops.aten.gt.Scalar) # type: ignore[misc] def aten_ops_greater( - network: TRTNetwork, + ctx: ConversionContext, target: Target, args: Tuple[Argument, ...], kwargs: Dict[str, Argument], name: str, ) -> Union[TRTTensor, Sequence[TRTTensor]]: return impl.elementwise.gt( - network, + ctx, target, SourceIR.ATEN, name, @@ -1332,14 +1342,14 @@ def aten_ops_greater( @dynamo_tensorrt_converter(torch.ops.aten.lt.Tensor) # type: ignore[misc] @dynamo_tensorrt_converter(torch.ops.aten.lt.Scalar) # type: ignore[misc] def aten_ops_less( - network: TRTNetwork, + ctx: ConversionContext, target: Target, args: Tuple[Argument, ...], kwargs: Dict[str, Argument], name: str, ) -> Union[TRTTensor, Sequence[TRTTensor]]: return impl.elementwise.lt( - network, + ctx, target, SourceIR.ATEN, name, @@ -1355,8 +1365,15 @@ def conv_param_validator(conv_node: Node) -> bool: @dynamo_tensorrt_converter( torch.ops.aten.convolution.default, capability_validator=conv_param_validator ) # type: ignore[misc] +@enforce_tensor_types( + { + 0: (TRTTensor,), + 1: (np.ndarray, torch.Tensor, TRTTensor), + 2: (np.ndarray, torch.Tensor, TRTTensor), + } +) # type: ignore[misc] def aten_ops_convolution( - network: TRTNetwork, + ctx: ConversionContext, target: Target, args: Tuple[Argument, ...], kwargs: Dict[str, Argument], @@ -1365,7 +1382,7 @@ def aten_ops_convolution( is_transposed = args[6] if not is_transposed: return impl.conv.convNd( - network, + ctx, target, source_ir=SourceIR.ATEN, name=name, @@ -1380,7 +1397,7 @@ def aten_ops_convolution( ) else: return impl.deconv.deconvNd( - network, + ctx, target, 
source_ir=SourceIR.ATEN, name=name, @@ -1398,14 +1415,14 @@ def aten_ops_convolution( @dynamo_tensorrt_converter(torch.ops.aten.linear.default) # type: ignore[misc] @dynamo_tensorrt_converter(torch.ops.aten.linear) # type: ignore[misc] def aten_ops_linear( - network: TRTNetwork, + ctx: ConversionContext, target: Target, args: Tuple[Argument, ...], kwargs: Dict[str, Argument], name: str, ) -> Union[TRTTensor, Sequence[TRTTensor]]: return impl.linear.linear( - network, + ctx, target, SourceIR.ATEN, name, @@ -1439,14 +1456,14 @@ def avg_pool_param_validator(pool_node: Node) -> bool: @dynamo_tensorrt_converter(torch.ops.aten.avg_pool2d.default, capability_validator=avg_pool_param_validator) # type: ignore[misc] @dynamo_tensorrt_converter(torch.ops.aten.avg_pool3d.default, capability_validator=avg_pool_param_validator) # type: ignore[misc] def aten_ops_avg_pool( - network: TRTNetwork, + ctx: ConversionContext, target: Target, args: Tuple[Argument, ...], kwargs: Dict[str, Argument], name: str, ) -> Union[TRTTensor, Sequence[TRTTensor]]: return impl.pool.avg_poolNd( - network, + ctx, target, source_ir=SourceIR.ATEN, name=name, @@ -1482,14 +1499,14 @@ def max_pool_param_validator(pool_node: Node) -> bool: @dynamo_tensorrt_converter(torch.ops.aten.max_pool2d.default, capability_validator=max_pool_param_validator) # type: ignore[misc] @dynamo_tensorrt_converter(torch.ops.aten.max_pool3d.default, capability_validator=max_pool_param_validator) # type: ignore[misc] def aten_ops_max_pool( - network: TRTNetwork, + ctx: ConversionContext, target: Target, args: Tuple[Argument, ...], kwargs: Dict[str, Argument], name: str, ) -> Union[TRTTensor, Sequence[TRTTensor]]: return impl.pool.max_poolNd( - network, + ctx, target, source_ir=SourceIR.ATEN, name=name, diff --git a/py/torch_tensorrt/dynamo/conversion/conversion.py b/py/torch_tensorrt/dynamo/conversion/conversion.py index 787a6d6c25..5555686e77 100644 --- a/py/torch_tensorrt/dynamo/conversion/conversion.py +++ b/py/torch_tensorrt/dynamo/conversion/conversion.py @@ -39,6 +39,7 @@ def convert_module( Input.from_tensors(inputs, disable_memory_format_check=True), logger_level=(trt.Logger.VERBOSE if settings.debug else trt.Logger.WARNING), output_dtypes=output_dtypes, + compilation_settings=settings, ) interpreter_result = interpreter.run( workspace_size=settings.workspace_size, diff --git a/py/torch_tensorrt/dynamo/conversion/converter_registry.py b/py/torch_tensorrt/dynamo/conversion/converter_registry.py index 58913c0d54..45445f0f89 100644 --- a/py/torch_tensorrt/dynamo/conversion/converter_registry.py +++ b/py/torch_tensorrt/dynamo/conversion/converter_registry.py @@ -18,12 +18,13 @@ from torch._ops import OpOverloadPacket from torch.fx.node import Argument, Node, Target, _get_qualified_name +from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext from torch_tensorrt.fx.converter_registry import CONVERTERS from torch_tensorrt.fx.types import TRTNetwork, TRTTensor logger = logging.getLogger(__name__) -ConverterImplSignature = Callable[ +LegacyConverterImplSignature = Callable[ [ TRTNetwork, Target, @@ -34,6 +35,21 @@ Union[TRTTensor, Sequence[TRTTensor]], ] +DynamoConverterImplSignature = Callable[ + [ + ConversionContext, + Target, + Tuple[Argument, ...], + Dict[str, Argument], + str, + ], + Union[TRTTensor, Sequence[TRTTensor]], +] + +ConverterImplSignature = Union[ + LegacyConverterImplSignature, DynamoConverterImplSignature +] + class ConverterPriority(Enum): """Enum to set a converter's priority in the registry""" @@ -42,6 
+58,13 @@ class ConverterPriority(Enum): HIGH = auto() +class CallingConvention(Enum): + """Enum representing a converter's calling convention""" + + LEGACY = auto() # Legacy FX converters + CTX = auto() # New Dynamo converters + + @dataclass(frozen=True) class ConverterSupport: """Class representing a converter implementation and support function @@ -67,7 +90,7 @@ def dynamo_tensorrt_converter( enabled: bool = True, capability_validator: Optional[Callable[[Node], bool]] = None, priority: ConverterPriority = ConverterPriority.STANDARD, -) -> Callable[[Any], Union[TRTTensor, Sequence[TRTTensor]]]: +) -> Callable[[ConverterImplSignature], ConverterImplSignature]: """Decorator for Dynamo TensorRT Converter Registers the decorated function in the DYNAMO_ATEN_CONVERTERS registry @@ -156,7 +179,10 @@ class ConverterRegistry: registries: List of dictionaries representing converter registries. The order of the provided dictionaries is the order in which they will be traversed. This is only significant when using non-validated - methods. + methods + registry_names: Optional list of names for each registry + registry_calling_conventions: Optional list of calling conventions + for each registry """ def __init__( @@ -165,6 +191,7 @@ def __init__( Dict[Target, Union[Callable[..., Any], Sequence[ConverterSupport]]] ], registry_names: Optional[Sequence[str]] = None, + registry_calling_conventions: Optional[Sequence[CallingConvention]] = None, ): # Copy reference to each dictionary object into attribute list self.registries = list(registries) @@ -177,6 +204,14 @@ def __init__( f"Registry {i + 1}" for i in range(len(self.registries)) ] + if registry_calling_conventions is not None: + assert len(self.registries) == len(registry_calling_conventions) + self.registry_calling_conventions = list(registry_calling_conventions) + else: + self.registry_calling_conventions = [ + CallingConvention.CTX for _ in range(len(self.registries)) + ] + self.validate_invariants() def validate_invariants(self) -> None: @@ -202,12 +237,13 @@ def validate_invariants(self) -> None: def __getitem_without_validation__( self, key: Target - ) -> ( - Any - ): # TODO: Narrow to ConverterImplSignature this when we can remove FX converters + ) -> Tuple[ + Any, CallingConvention + ]: # TODO: Narrow to ConverterImplSignature this when we can remove FX converters """Get the first-found converter in any registry - Searches all registries in order and returns the first converter encountered + Searches all registries in order and returns the first converter encountered, + along with the calling convention of the registry the converter was sourced from """ if isinstance(key, Node): raise KeyError( @@ -218,26 +254,29 @@ def __getitem_without_validation__( self.validate_invariants() # Iterate over all registries and return the first converter found - for registry in self.registries: + for registry, calling_convention in zip( + self.registries, self.registry_calling_conventions + ): if key in registry: converters = registry[key] if isinstance(converters, (list, tuple)): - return converters[0].converter_implementation + return converters[0].converter_implementation, calling_convention else: - return converters + return converters, calling_convention raise KeyError(f"None of the converter registries have an entry for {key}") def __getitem__( self, node: Node - ) -> ( - Any - ): # TODO: Narrow to ConverterImplSignature this when we can remove FX converters + ) -> Tuple[ + Any, CallingConvention + ]: # TODO: Narrow to ConverterImplSignature this when we 
can remove FX converters """Get the first-found validated converter in any registry - Searches all registries in order and returns the first converter - which passes validation on the input node + Searches all registries in order and returns the first converter which passes + validation on the input node, along with the calling convention of the + registry the converter was sourced from """ if not isinstance(node, Node): raise KeyError( @@ -251,16 +290,21 @@ def __getitem__( # Iterate over all registries, validating the converter on the input node # If no capability_validator function is found, assume full coverage - for registry in self.registries: + for registry, calling_convention in zip( + self.registries, self.registry_calling_conventions + ): if key in registry: converters = registry[key] if isinstance(converters, (list, tuple)): for candidate in converters: if candidate.capability_validator(node): - return candidate.converter_implementation + return ( + candidate.converter_implementation, + calling_convention, + ) else: - return converters + return converters, calling_convention raise KeyError( f"None of the converter registries have a validated entry for {key}, with node {node}" @@ -272,9 +316,9 @@ def keys(self) -> Set[Target]: def get_unvalidated( self, key: Target, value: Optional[ConverterImplSignature] = None - ) -> ( - Any - ): # TODO: Narrow to ConverterImplSignature this when we can remove FX converters + ) -> Union[ + Any, Tuple[Any, CallingConvention] + ]: # TODO: Narrow to ConverterImplSignature this when we can remove FX converters """Get unvalidated converter for input target with a default return""" try: return self.__getitem_without_validation__(key) @@ -283,9 +327,9 @@ def get_unvalidated( def get( self, node: Node, value: Optional[ConverterImplSignature] = None - ) -> ( - Any - ): # TODO: Narrow to ConverterImplSignature this when we can remove FX converters + ) -> Union[ + Any, Tuple[Any, CallingConvention] + ]: # TODO: Narrow to ConverterImplSignature this when we can remove FX converters """Get validated converter for input node with a default return""" try: return self.__getitem__(node) @@ -398,5 +442,6 @@ def display_all_available_converters(self) -> str: # Note the Dynamo registry is listed first, for precedence DYNAMO_CONVERTERS: ConverterRegistry = ConverterRegistry( [DYNAMO_ATEN_CONVERTERS, CONVERTERS], # type: ignore[list-item] - ["Dynamo ATen Converters Registry", "FX ATen Converters Registry"], + ["Dynamo ATen Converters Registry", "FX Legacy ATen Converters Registry"], + [CallingConvention.CTX, CallingConvention.LEGACY], ) diff --git a/py/torch_tensorrt/dynamo/conversion/converter_utils.py b/py/torch_tensorrt/dynamo/conversion/converter_utils.py index 12c11bc9f1..d1d94d40cc 100644 --- a/py/torch_tensorrt/dynamo/conversion/converter_utils.py +++ b/py/torch_tensorrt/dynamo/conversion/converter_utils.py @@ -1,23 +1,25 @@ import functools import logging import re -from typing import Any, Callable, List, Optional, Sequence, Tuple, Union, overload +from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union, overload import numpy as np import tensorrt as trt import torch from torch import SymBool, SymFloat, SymInt -from torch.fx.node import Target +from torch.fx.node import Argument, Target +from torch_tensorrt.dynamo._SourceIR import SourceIR +from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext +from torch_tensorrt.dynamo.conversion.converter_registry import ( + ConverterRegistry, + 
DynamoConverterImplSignature, +) from torch_tensorrt.fx.converters.converter_utils import ( Frameworks, get_axes_for_reduce_op, - to_numpy, unified_dtype_converter, ) -from torch_tensorrt.fx.types import TRTDataType, TRTNetwork, TRTTensor - -from .._SourceIR import SourceIR -from .converter_registry import ConverterRegistry +from torch_tensorrt.fx.types import TRTDataType, TRTTensor _LOGGER: logging.Logger = logging.getLogger(__name__) @@ -117,7 +119,7 @@ def _is_subnode_dynamic(subnode: torch.fx.Node) -> bool: def cast_trt_tensor( - network: TRTNetwork, + ctx: ConversionContext, input_val: TRTTensor, dtype: TRTDataType, name: str, @@ -131,7 +133,7 @@ def cast_trt_tensor( input unchanged Args: - network (TRTNetwork): A TensorRT network + ctx (ConversionContext): A ConversionContext containing the TensorRT network input_val (TRTTensor): A TRT Tensor to cast to a new data type dtype (TRTDataType, torch.dtype, np.dtype): The data type to cast the input Tensor to name (str): Name of the calling layer @@ -147,7 +149,7 @@ def cast_trt_tensor( target_str = ConverterRegistry.qualified_name_or_str(target) target_name = f"{source_ir}_ops{('.' + target_str) if target_str else ''}" - identity_layer = network.add_identity(input_val) + identity_layer = ctx.net.add_identity(input_val) identity_layer.set_output_type(0, trt_dtype) identity_layer.name = f"Cast ITensor {input_val.name} from {input_val.dtype} to {trt_dtype} - [{target_name}]-[{name}]" return identity_layer.get_output(0) @@ -156,7 +158,7 @@ def cast_trt_tensor( def cast_int_int_div_trt_tensor( - network: TRTNetwork, + ctx: ConversionContext, lhs_val: TRTTensor, rhs_val: TRTTensor, name: str, @@ -164,7 +166,7 @@ def cast_int_int_div_trt_tensor( """ Given two `int` data type TRT Tensor to div operation, cast the TRT Tensor to float type Args: - network (TRTNetwork): A TensorRT network + ctx (ConversionContext): A ConversionContext object lhs_val (TRTTensor): A TRT Tensor numerator rhs_val (TRTTensor): A TRT Tensor numerator name (str): Name of calling layer @@ -172,8 +174,8 @@ def cast_int_int_div_trt_tensor( A list of lhs_val and rhs_val casted to the approriate datatype """ if lhs_val.dtype == trt.int32 and rhs_val.dtype == trt.int32: - lhs_val = cast_trt_tensor(network, lhs_val, trt.float32, name) - rhs_val = cast_trt_tensor(network, rhs_val, trt.float32, name) + lhs_val = cast_trt_tensor(ctx, lhs_val, trt.float32, name) + rhs_val = cast_trt_tensor(ctx, rhs_val, trt.float32, name) return [lhs_val, rhs_val] @@ -240,26 +242,26 @@ def extend_attr_to_tuple( def cast_int_or_float_to_bool( - network: TRTNetwork, name: str, tensor: TRTTensor + ctx: ConversionContext, name: str, tensor: TRTTensor ) -> TRTTensor: if tensor.dtype != trt.bool: - return cast_trt_tensor(network, tensor, trt.bool, name) + return cast_trt_tensor(ctx, tensor, trt.bool, name) return tensor def create_constant( - network: TRTNetwork, - value: Union[int, float, np.ndarray, torch.Tensor], + ctx: ConversionContext, + value: Union[int, float, bool, np.ndarray, torch.Tensor], name: str, dtype: Optional[Union[torch.dtype, np.dtype, TRTDataType]], ) -> TRTTensor: """ - Add a TensorRT constant layer whose value is `value` to `network`. + Add a TensorRT constant layer whose value is `value` to `ctx.net`. Args: - network (TRTNetwork): A TensorRT network to which we want to add + ctx (ConversionContext): A TensorRT ConversionContext to which we want to add a constant layer. 
- value (Union[int, float, np.ndarray, torch.Tensor]): A literal value, Numpy array, + value (Union[int, float, bool, np.ndarray, torch.Tensor]): A literal value, Numpy array, or a PyTorch tensor that will be used as value of the added TensorRT Constant layer. name (str): Name of the added TensorRT Constant layer. dtype (Optional[Union[torch.dtype, np.dtype, TRTDataType]]): @@ -267,16 +269,17 @@ def create_constant( Returns: A TensorRT ITensor that represents the given value. """ - constant = network.add_constant( - (1,) if isinstance(value, (int, float)) else value.shape, - to_numpy(value, dtype).copy(), + numpy_value = to_numpy(value, dtype) + constant = ctx.net.add_constant( + (1,) if isinstance(value, (int, float, bool)) else value.shape, + numpy_value.copy() if isinstance(numpy_value, np.ndarray) else numpy_value, ) constant.name = name return constant.get_output(0) def get_trt_tensor( - network: TRTNetwork, + ctx: ConversionContext, input_val: Any, name: str, dtype: Optional[Union[torch.dtype, np.dtype, TRTDataType]] = None, @@ -285,7 +288,7 @@ def get_trt_tensor( Given a value of random type, we try to convert it to a TensorRT ITensor. An runtime error is raised if we're not able to do that. Args: - network (TRTNetwork): A TensorRT network. If we want to + ctx (ConversionContext): A TensorRT ConversionContext. If we want to add a TensorRT Constant layer, we will add it to this network. input_val (Any): An value that we want to convert to a TensorRT ITensor. name (str): The name of the created TensorRT Constant layer if there's @@ -295,21 +298,26 @@ def get_trt_tensor( Returns: A TensorRT ITensor that represents the given value. """ - # TRT can not add constant for bool type. We do a work around to 1) cast it to int and 2)cast to bool later - # This is useful for logical operations which require input to be bool type - if isinstance(input_val, bool): - input_val = int(input_val) - elif isinstance(input_val, torch.Tensor) and ( - input_val.dtype == torch.bool or input_val.dtype == torch.int64 + # If the input is 64-bit, cast it to 32-bit for TRT freezing + if ( + isinstance(input_val, torch.Tensor) + and ctx.compilation_settings.truncate_long_and_double ): - input_val = input_val.to(torch.int32) - elif isinstance(input_val, np.ndarray) and ( - input_val.dtype == np.bool_ or input_val.dtype == np.int64 + if input_val.dtype == torch.int64: + input_val = input_val.to(torch.int32) + elif input_val.dtype == torch.float64: + input_val = input_val.to(torch.float32) + elif ( + isinstance(input_val, np.ndarray) + and ctx.compilation_settings.truncate_long_and_double ): - input_val = input_val.astype(np.int32) + if input_val.dtype == np.int64: + input_val = input_val.astype(np.int32) + elif input_val.dtype == np.float64: + input_val = input_val.astype(np.float32) - if isinstance(input_val, (torch.Tensor, np.ndarray, int, float)): - return create_constant(network, input_val, name, dtype) + if isinstance(input_val, (torch.Tensor, np.ndarray, int, float, bool)): + return create_constant(ctx, input_val, name, dtype) elif isinstance(input_val, TRTTensor): return input_val else: @@ -352,3 +360,154 @@ def positive_dim(d: int) -> int: if isinstance(dim, int) else tuple(positive_dim(d) for d in dim) ) + + +def enforce_tensor_types( + type_dictionary: Dict[Union[int, str], Tuple[Union[TRTTensor, np.ndarray], ...]], + promote: bool = True, +) -> Callable[[DynamoConverterImplSignature], DynamoConverterImplSignature]: + """Decorator to enforce tensor types for input arguments to converters + + Keys in the 
type dictionary must be integers if they refer to a positional + argument in args, or strings if they refer to a keyword argument in kwargs + + Values must be tuples of data types denoting the approved data types for a given position + The approved types are TRTTensor, np.ndarray, and torch.Tensor. + + Note: torch.Tensor cannot be present without np.ndarray + + The promote argument controls whether tensors will be promoted if they are of the + incorrect format + """ + assert all( + isinstance(key, (int, str)) for key in type_dictionary + ), "Invalid key for type enforcement" + assert all( + ( + isinstance(val, tuple) + and not (torch.Tensor in val and np.ndarray not in val) + and all((dtype in (TRTTensor, np.ndarray, torch.Tensor)) for dtype in val) + ) + for val in type_dictionary.values() + ), ( + "Invalid value(s) specified in type enforcement." + "Note that torch.Tensor cannot be present as a type without np.ndarray." + ) + + def wrapper(func: DynamoConverterImplSignature) -> DynamoConverterImplSignature: + @functools.wraps(func) + def convert_with_type_enforcement( + ctx: ConversionContext, + target: Target, + args: Tuple[Argument, ...], + kwargs: Dict[str, Argument], + name: str, + ) -> Union[TRTTensor, Sequence[TRTTensor]]: + new_args = args + new_kwargs = {**kwargs} + new_value = None + + # Go through type dictionary and promote types accordingly + for index, approved_dtypes in type_dictionary.items(): + # Referencing an arg + if isinstance(index, int): + candidate = args[index] + # Referencing a kwarg + elif isinstance(index, str): + candidate = kwargs[index] + + # If the candidate Tensor is already an approved type, do nothing + if isinstance(candidate, approved_dtypes): + continue + # If the candidate Tensor is not an approved type, but promotion is disabled, error + elif not promote: + raise AssertionError( + f"Detected argument at index {index} had type {type(candidate)} " + f"which is not one of the approved types {approved_dtypes}" + ) + + promoted = False + + # Type-promotion preference order depends on tuple order + for dtype in approved_dtypes: + # Currently, we do not cast to Torch tensor, due to issues with such casts + # in FakeTensor contexts + if dtype == np.ndarray and not isinstance(candidate, TRTTensor): + new_value = to_numpy(candidate) + promoted = True + break + # As a fallback, freeze tensors to IConstantLayers if they cannot be handled as Numpy arrays + elif dtype == TRTTensor: + _LOGGER.debug( + f"Freezing tensor {name}_constant_{index} to TRT IConstantLayer" + ) + new_value = get_trt_tensor( + ctx, candidate, name + f"_constant_{index}" + ) + promoted = True + break + + if not promoted: + raise AssertionError( + f"Argument {candidate} at index {index} was not able to be " + f"converted to one of the following types: {approved_dtypes}" + ) + + # Reassemble args or kwargs if the value was modified + if isinstance(index, int): + new_args = new_args[:index] + (new_value,) + new_args[index + 1 :] + elif isinstance(index, str): + new_kwargs[index] = new_value + + return func(ctx, target, new_args, new_kwargs, name) + + return convert_with_type_enforcement + + return wrapper + + +def to_numpy( + value: Optional[Union[torch.Tensor, np.ndarray, int, float, bool]], + dtype: Optional[Union[torch.dtype, np.dtype, TRTDataType]] = None, +) -> Optional[np.ndarray]: + """ + Convert a PyTorch Tensor, Numpy array, or scalar to a Numpy Array. If the tensor is + quantized it will be dequantized first. 
+ Args: + value (Optional[Union[torch.Tensor, np.ndarray, int, float, bool]]): + A PyTorch tensor, Numpy array, int, float, or bool + dtype (Optional[Union[torch.dtype, np.dtype, TRTDataType]]): + If a dtype is given, we will convert the type of the given `value` to this dtype. + Returns: + A Numpy array or None, if the input was None. + """ + output = None + + if value is None or isinstance(value, np.ndarray): + output = value + + elif isinstance(value, torch.Tensor): + if value.is_quantized: + value = value.dequantize() + + output = value.cpu().detach().contiguous().numpy() + + elif isinstance(value, int): + output = np.array([value], dtype=np.int32) + + elif isinstance(value, float): + output = np.array([value], dtype=np.float32) + + elif isinstance(value, bool): + output = np.array([value], dtype=np.bool_) + + if isinstance(output, np.ndarray) or output is None: + return ( + output + if (dtype is None or output is None) + else output.astype(unified_dtype_converter(dtype, Frameworks.NUMPY)) + ) + else: + raise AssertionError( + f"to_numpy can only be called on None, bool, int, float, np.ndarray, or torch.Tensor, got: {value}" + ) diff --git a/py/torch_tensorrt/dynamo/conversion/impl/activation/base.py b/py/torch_tensorrt/dynamo/conversion/impl/activation/base.py index f2157dbdbd..f726a1c500 100644 --- a/py/torch_tensorrt/dynamo/conversion/impl/activation/base.py +++ b/py/torch_tensorrt/dynamo/conversion/impl/activation/base.py @@ -3,15 +3,16 @@ import tensorrt as trt from torch.fx.node import Target from torch_tensorrt.dynamo._SourceIR import SourceIR +from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext from torch_tensorrt.fx.converters.converter_utils import ( mark_as_int8_layer, set_layer_name, ) -from torch_tensorrt.fx.types import TRTNetwork, TRTTensor +from torch_tensorrt.fx.types import TRTTensor def convert_activation( - network: TRTNetwork, + ctx: ConversionContext, target: Target, source_ir: Optional[SourceIR], name: str, @@ -19,7 +20,7 @@ def convert_activation( input_val: TRTTensor, alpha: Optional[Any] = None, beta: Optional[Any] = None, - dyn_range_fn: Optional[Callable[[float, float], Any]] = None, + dyn_range_fn: Optional[Callable[[Any], Any]] = None, ) -> TRTTensor: """ Add a TensorRT Activation layer to `network`. @@ -29,14 +30,14 @@ def convert_activation( f"{operation_type} received input {input_val} that is not part " "of the TensorRT region!" 
) - layer = network.add_activation(input_val, operation_type) + layer = ctx.net.add_activation(input_val, operation_type) if alpha is not None: layer.alpha = alpha if beta is not None: layer.beta = beta set_layer_name(layer, target, name, source_ir) - if input_val.dynamic_range is not None: + if input_val.dynamic_range is not None and dyn_range_fn is not None: dyn_range = dyn_range_fn(input_val.dynamic_range) mark_as_int8_layer(layer, dyn_range) return layer.get_output(0) diff --git a/py/torch_tensorrt/dynamo/conversion/impl/activation/ops.py b/py/torch_tensorrt/dynamo/conversion/impl/activation/ops.py index e39e781dd2..ac77f790cb 100644 --- a/py/torch_tensorrt/dynamo/conversion/impl/activation/ops.py +++ b/py/torch_tensorrt/dynamo/conversion/impl/activation/ops.py @@ -1,28 +1,29 @@ -from typing import Any, Optional +from typing import Any, Optional, Tuple import numpy as np import tensorrt as trt import torch from torch.fx.node import Target from torch_tensorrt.dynamo._SourceIR import SourceIR +from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext from torch_tensorrt.dynamo.conversion.impl.activation.base import convert_activation -from torch_tensorrt.fx.types import TRTNetwork, TRTTensor +from torch_tensorrt.fx.types import TRTTensor def relu( - network: TRTNetwork, + ctx: ConversionContext, target: Target, source_ir: Optional[SourceIR], name: str, input_val: TRTTensor, -): +) -> TRTTensor: operation_type = trt.ActivationType.RELU - def relu_dyn_range_fn(dyn_range): + def relu_dyn_range_fn(dyn_range: Tuple[float, float]) -> Tuple[float, float]: return max(0, dyn_range[0]), max(0, dyn_range[1]) return convert_activation( - network, + ctx, target, source_ir, name, @@ -33,22 +34,22 @@ def relu_dyn_range_fn(dyn_range): def sigmoid( - network: TRTNetwork, + ctx: ConversionContext, target: Target, source_ir: Optional[SourceIR], name: str, input_val: TRTTensor, -): +) -> TRTTensor: operation_type = trt.ActivationType.SIGMOID - def sigmoid_dyn_range_fn(dyn_range): - def sigmoid_fn(x): + def sigmoid_dyn_range_fn(dyn_range: Tuple[float, float]) -> Tuple[float, float]: + def sigmoid_fn(x: float) -> Any: return 1 / (1 + np.exp(-x)) return sigmoid_fn(dyn_range[0]), sigmoid_fn(dyn_range[1]) return convert_activation( - network, + ctx, target, source_ir, name, @@ -59,22 +60,22 @@ def sigmoid_fn(x): def tanh( - network: TRTNetwork, + ctx: ConversionContext, target: Target, source_ir: Optional[SourceIR], name: str, input_val: TRTTensor, -): +) -> TRTTensor: operation_type = trt.ActivationType.TANH - def tanh_dyn_range_fn(dyn_range): - def tanh_fn(x): + def tanh_dyn_range_fn(dyn_range: Tuple[float, float]) -> Tuple[float, float]: + def tanh_fn(x: float) -> Any: return (np.exp(x) - np.exp(-x)) / (np.exp(x) + np.exp(-x)) return tanh_fn(dyn_range[0]), tanh_fn(dyn_range[1]) return convert_activation( - network, + ctx, target, source_ir, name, @@ -85,23 +86,23 @@ def tanh_fn(x): def leaky_relu( - network: TRTNetwork, + ctx: ConversionContext, target: Target, source_ir: Optional[SourceIR], name: str, input_val: TRTTensor, - alpha: Optional[Any] = 0.01, -): + alpha: float = 0.01, +) -> TRTTensor: operation_type = trt.ActivationType.LEAKY_RELU - def leaky_relu_dyn_range_fn(dyn_range): - def leaky_relu_fn(x): + def leaky_relu_dyn_range_fn(dyn_range: Tuple[float, float]) -> Tuple[float, float]: + def leaky_relu_fn(x: float) -> float: return max(0, x) + alpha * min(0, x) return leaky_relu_fn(dyn_range[0]), leaky_relu_fn(dyn_range[1]) return convert_activation( - network, + ctx, target, 
source_ir, name, @@ -113,14 +114,14 @@ def leaky_relu_fn(x): def elu( - network: TRTNetwork, + ctx: ConversionContext, target: Target, source_ir: Optional[SourceIR], name: str, input_val: TRTTensor, - alpha: Optional[Any] = 1.0, - beta: Optional[Any] = None, -): + alpha: float = 1.0, + beta: Optional[float] = None, +) -> TRTTensor: EPS = 1e-4 # actually call selu() if ( @@ -129,19 +130,19 @@ def elu( and abs(beta - 1.0507009873554805) < EPS ): print("Selu is called but re-uses elu function!") - return selu(network, target, source_ir, name, input_val) + return selu(ctx, target, source_ir, name, input_val) else: operation_type = trt.ActivationType.ELU - def elu_dyn_range_fn(dyn_range): + def elu_dyn_range_fn(dyn_range: Tuple[float, float]) -> Tuple[float, float]: return ( torch.nn.functional.elu(dyn_range[0], alpha), torch.nn.functional.elu(dyn_range[1], alpha), ) return convert_activation( - network, + ctx, target, source_ir, name, @@ -153,22 +154,22 @@ def elu_dyn_range_fn(dyn_range): def selu( - network: TRTNetwork, + ctx: ConversionContext, target: Target, source_ir: Optional[SourceIR], name: str, input_val: TRTTensor, -): +) -> TRTTensor: operation_type = trt.ActivationType.SELU - def selu_dyn_range_fn(dyn_range): + def selu_dyn_range_fn(dyn_range: Tuple[float, float]) -> Tuple[float, float]: return ( torch.nn.functional.selu(dyn_range[0]), torch.nn.functional.selu(dyn_range[1]), ) return convert_activation( - network, + ctx, target, source_ir, name, @@ -180,22 +181,22 @@ def selu_dyn_range_fn(dyn_range): # no corresponding function in aten/native_functions def softsign( - network: TRTNetwork, + ctx: ConversionContext, target: Target, source_ir: Optional[SourceIR], name: str, input_val: TRTTensor, -): +) -> TRTTensor: operation_type = trt.ActivationType.SOFTSIGN - def softsign_dyn_range_fn(dyn_range): + def softsign_dyn_range_fn(dyn_range: Tuple[float, float]) -> Tuple[float, float]: return ( torch.nn.functional.softsign(dyn_range[0]), torch.nn.functional.softsign(dyn_range[1]), ) return convert_activation( - network, + ctx, target, source_ir, name, @@ -206,23 +207,23 @@ def softsign_dyn_range_fn(dyn_range): def softplus( - network: TRTNetwork, + ctx: ConversionContext, target: Target, source_ir: Optional[SourceIR], name: str, input_val: TRTTensor, - beta: Optional[Any] = 1, -): + beta: float = 1, +) -> TRTTensor: operation_type = trt.ActivationType.SOFTPLUS - def softplus_dyn_range_fn(dyn_range): + def softplus_dyn_range_fn(dyn_range: Tuple[float, float]) -> Tuple[float, float]: return ( torch.nn.functional.softplus(dyn_range[0], beta), torch.nn.functional.softplus(dyn_range[1], beta), ) return convert_activation( - network, + ctx, target, source_ir, name, @@ -235,24 +236,24 @@ def softplus_dyn_range_fn(dyn_range): def clip( - network: TRTNetwork, + ctx: ConversionContext, target: Target, source_ir: Optional[SourceIR], name: str, input_val: TRTTensor, - alpha: Optional[Any], - beta: Optional[Any], -): + alpha: float, + beta: float, +) -> TRTTensor: operation_type = trt.ActivationType.CLIP - def clip_dyn_range_fn(dyn_range): - def clip_fn(x): + def clip_dyn_range_fn(dyn_range: Tuple[float, float]) -> Tuple[float, float]: + def clip_fn(x: float) -> float: return max(alpha, min(beta, x)) return clip_fn(dyn_range[0]), clip_fn(dyn_range[1]) return convert_activation( - network, + ctx, target, source_ir, name, @@ -265,24 +266,26 @@ def clip_fn(x): def hard_sigmoid( - network: TRTNetwork, + ctx: ConversionContext, target: Target, source_ir: Optional[SourceIR], name: str, input_val: TRTTensor, 
- alpha: Optional[Any], - beta: Optional[Any], -): + alpha: float, + beta: float, +) -> TRTTensor: operation_type = trt.ActivationType.HARD_SIGMOID - def hard_sigmoid_dyn_range_fn(dyn_range): - def hard_sigmoid_fn(x): + def hard_sigmoid_dyn_range_fn( + dyn_range: Tuple[float, float] + ) -> Tuple[float, float]: + def hard_sigmoid_fn(x: float) -> float: return max(0, min(1, alpha * x + beta)) return hard_sigmoid_fn(dyn_range[0]), hard_sigmoid_fn(dyn_range[1]) return convert_activation( - network, + ctx, target, source_ir, name, @@ -296,24 +299,24 @@ def hard_sigmoid_fn(x): # no corresponding function in aten/native_functions def scaled_tanh( - network: TRTNetwork, + ctx: ConversionContext, target: Target, source_ir: Optional[SourceIR], name: str, input_val: TRTTensor, - alpha: Optional[Any], - beta: Optional[Any], -): + alpha: float, + beta: float, +) -> TRTTensor: operation_type = trt.ActivationType.SCALED_TANH - def scaled_tanh_dyn_range_fn(dyn_range): - def scaled_tanh_fn(x): + def scaled_tanh_dyn_range_fn(dyn_range: Tuple[float, float]) -> Tuple[float, float]: + def scaled_tanh_fn(x: float) -> Any: return alpha * torch.nn.functional.tanh(beta * x) return scaled_tanh_fn(dyn_range[0]), scaled_tanh_fn(dyn_range[1]) return convert_activation( - network, + ctx, target, source_ir, name, @@ -327,23 +330,25 @@ def scaled_tanh_fn(x): # no corresponding function in aten/native_functions def thresholded_relu( - network: TRTNetwork, + ctx: ConversionContext, target: Target, source_ir: Optional[SourceIR], name: str, input_val: TRTTensor, - alpha: Optional[Any], -): + alpha: float, +) -> TRTTensor: operation_type = trt.ActivationType.THRESHOLDED_RELU - def thresholded_relu_dyn_range_fn(dyn_range): - def thresholded_relu_fn(x): + def thresholded_relu_dyn_range_fn( + dyn_range: Tuple[float, float] + ) -> Tuple[float, float]: + def thresholded_relu_fn(x: float) -> float: return x if x > alpha else 0 return thresholded_relu_fn(dyn_range[0]), thresholded_relu_fn(dyn_range[1]) return convert_activation( - network, + ctx, target, source_ir, name, diff --git a/py/torch_tensorrt/dynamo/conversion/impl/cast.py b/py/torch_tensorrt/dynamo/conversion/impl/cast.py index f31fd9a396..790f0d6f60 100644 --- a/py/torch_tensorrt/dynamo/conversion/impl/cast.py +++ b/py/torch_tensorrt/dynamo/conversion/impl/cast.py @@ -3,19 +3,20 @@ from torch.fx.node import Target from torch_tensorrt.dynamo._SourceIR import SourceIR +from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext from torch_tensorrt.dynamo.conversion.converter_registry import ConverterRegistry from torch_tensorrt.dynamo.conversion.converter_utils import cast_trt_tensor from torch_tensorrt.fx.converters.converter_utils import ( Frameworks, unified_dtype_converter, ) -from torch_tensorrt.fx.types import TRTDataType, TRTNetwork, TRTTensor +from torch_tensorrt.fx.types import TRTDataType, TRTTensor LOGGER: logging.Logger = logging.getLogger(__name__) def to_copy( - network: TRTNetwork, + ctx: ConversionContext, target: Target, source_ir: Optional[SourceIR], name: str, @@ -36,10 +37,10 @@ def to_copy( target_str = ConverterRegistry.qualified_name_or_str(target) target_name = f"{source_ir}_ops{('.' 
+ target_str) if target_str else ''}" - identity_layer = network.add_identity(input) + identity_layer = ctx.net.add_identity(input) identity_layer.set_output_type(0, trt_dtype) identity_layer.name = f"Forced Cast ITensor {input.name} from {input.dtype} to {trt_dtype} - [{target_name}]-[{name}]" return identity_layer.get_output(0) else: - casted_tensor = cast_trt_tensor(network, input, dtype, name, target, source_ir) + casted_tensor = cast_trt_tensor(ctx, input, dtype, name, target, source_ir) return casted_tensor diff --git a/py/torch_tensorrt/dynamo/conversion/impl/condition/ops.py b/py/torch_tensorrt/dynamo/conversion/impl/condition/ops.py index 9c225357b5..981c13397f 100644 --- a/py/torch_tensorrt/dynamo/conversion/impl/condition/ops.py +++ b/py/torch_tensorrt/dynamo/conversion/impl/condition/ops.py @@ -4,17 +4,18 @@ import torch from torch.fx.node import Target from torch_tensorrt.dynamo._SourceIR import SourceIR +from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext from torch_tensorrt.dynamo.conversion.converter_utils import ( broadcastable, get_trt_tensor, ) from torch_tensorrt.dynamo.conversion.impl.slice import expand from torch_tensorrt.fx.converters.converter_utils import broadcast, set_layer_name -from torch_tensorrt.fx.types import TRTNetwork, TRTTensor +from torch_tensorrt.fx.types import TRTTensor def where( - network: TRTNetwork, + ctx: ConversionContext, target: Target, source_ir: Optional[SourceIR], name: str, @@ -39,10 +40,10 @@ def where( # purpose of this is to bring input and other rank same as # output_shape to input it to the add_expand operation # condition will have dimension of either input or other - input, other = broadcast(network, input, other, f"{name}_x", f"{name}_y") + input, other = broadcast(ctx.net, input, other, f"{name}_x", f"{name}_y") if len(tuple(condition.shape)) != len(tuple(input.shape)): condition, input = broadcast( - network, condition, input, f"{name}_condition", f"{name}_x" + ctx.net, condition, input, f"{name}_condition", f"{name}_x" ) x_shape = list(input.shape) @@ -56,8 +57,8 @@ def where( if condition_shape != output_shape: condition.expand(output_shape) condition = condition.to(torch.int32) - condition_const = get_trt_tensor(network, condition, f"{name}_condition") - condition_layer = network.add_identity(condition_const) + condition_const = get_trt_tensor(ctx, condition, f"{name}_condition") + condition_layer = ctx.net.add_identity(condition_const) condition_layer.set_output_type(0, trt.bool) set_layer_name(condition_layer, target, f"{name}_condition") condition_val = condition_layer.get_output(0) @@ -65,7 +66,7 @@ def where( assert condition.dtype == trt.bool, "mask dtype is not bool!" 
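# Eager-mode reference (illustrative only) for the broadcasting that the `where`
# converter here reproduces with expand/broadcast before add_select:
import torch
cond = torch.tensor([True, False, True])               # shape (3,)
x = torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])   # shape (2, 3)
y = torch.zeros(1)                                      # shape (1,)
print(torch.where(cond, x, y))
# tensor([[1., 0., 3.],
#         [4., 0., 6.]])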
if len(condition_shape) != condition_dim: condition_val = expand( - network, target, source_ir, f"{name}_expand", condition, output_shape + ctx, target, source_ir, f"{name}_expand", condition, output_shape ) else: condition_val = condition @@ -76,12 +77,12 @@ def where( if len(input.shape) == 0: input = input.unsqueeze(0) input = input.expand(output_shape) - x_val = get_trt_tensor(network, input, f"{name}_x") + x_val = get_trt_tensor(ctx, input, f"{name}_x") else: x_val = input if x_shape != output_shape: x_val = expand( - network, target, source_ir, f"{name}_x_expand", input, output_shape + ctx, target, source_ir, f"{name}_x_expand", input, output_shape ) if type(other) != TRTTensor: @@ -90,15 +91,15 @@ def where( if len(other.shape) == 0: other = other.unsqueeze(0) other = other.expand(output_shape) - y_val = get_trt_tensor(network, other, f"{name}_y") + y_val = get_trt_tensor(ctx, other, f"{name}_y") else: y_val = other if y_shape != output_shape: y_val = expand( - network, target, source_ir, f"{name}_y_expand", y_val, output_shape + ctx, target, source_ir, f"{name}_y_expand", y_val, output_shape ) - select_layer = network.add_select(condition_val, x_val, y_val) + select_layer = ctx.net.add_select(condition_val, x_val, y_val) set_layer_name(select_layer, target, f"{name}_select") diff --git a/py/torch_tensorrt/dynamo/conversion/impl/conv.py b/py/torch_tensorrt/dynamo/conversion/impl/conv.py index ebe4e37c9e..33b5fcbd87 100644 --- a/py/torch_tensorrt/dynamo/conversion/impl/conv.py +++ b/py/torch_tensorrt/dynamo/conversion/impl/conv.py @@ -7,23 +7,24 @@ import torch from torch.fx.node import Target from torch_tensorrt.dynamo.conversion import impl +from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext from torch_tensorrt.dynamo.conversion.converter_utils import ( + SourceIR, extend_attr_to_tuple, get_trt_tensor, + to_numpy, ) from torch_tensorrt.fx.converters.converter_utils import ( - SourceIR, get_dyn_range, has_dynamic_shape, mark_as_int8_layer, set_layer_name, - to_numpy, ) -from torch_tensorrt.fx.types import TRTNetwork, TRTTensor +from torch_tensorrt.fx.types import TRTTensor def convNd( - network: TRTNetwork, + ctx: ConversionContext, target: Union[Target, str], source_ir: Optional[SourceIR], name: str, @@ -44,7 +45,7 @@ def convNd( if is_conv1d: # Apply an unsqueeze operation to transform the conv1d problem into conv2d input = impl.unsqueeze.unsqueeze( - network, target, source_ir, name + "_unsqueeze_conv1d", input, -1 + ctx, target, source_ir, name + "_unsqueeze_conv1d", input, -1 ) # Process bias terms @@ -53,7 +54,7 @@ def convNd( bias = to_numpy(bias) elif isinstance(bias, TRTTensor): - bias = get_trt_tensor(network, bias, f"{name}_bias") + bias = get_trt_tensor(ctx, bias, f"{name}_bias") elif bias is not None: raise RuntimeError( @@ -61,12 +62,12 @@ def convNd( ) # Process weight terms - if network.has_explicit_precision or isinstance(weight, TRTTensor): - weight = get_trt_tensor(network, weight, f"{name}_weight") + if ctx.net.has_explicit_precision or isinstance(weight, TRTTensor): + weight = get_trt_tensor(ctx, weight, f"{name}_weight") # Append new dimension (unsqueeze) if the convolution is 1d if is_conv1d: input = impl.unsqueeze.unsqueeze( - network, target, source_ir, name + "_unsqueeze_weight", weight, -1 + ctx, target, source_ir, name + "_unsqueeze_weight", weight, -1 ) elif isinstance(weight, (torch.Tensor, np.ndarray)): @@ -83,7 +84,7 @@ def convNd( ) # add conv layer - conv_layer = network.add_convolution_nd( + conv_layer = 
ctx.net.add_convolution_nd( input=input, num_output_maps=weight.shape[0], kernel_shape=weight.shape[2:], @@ -134,7 +135,7 @@ def convNd( if is_conv1d: # Apply a squeeze operation to transform the conv2d problem back into conv1d result = impl.squeeze.squeeze( - network, target, source_ir, name + "_squeeze_conv1d", result, -1 + ctx, target, source_ir, name + "_squeeze_conv1d", result, -1 ) return result diff --git a/py/torch_tensorrt/dynamo/conversion/impl/deconv.py b/py/torch_tensorrt/dynamo/conversion/impl/deconv.py index e0f5844bd7..ebb9b1bec2 100644 --- a/py/torch_tensorrt/dynamo/conversion/impl/deconv.py +++ b/py/torch_tensorrt/dynamo/conversion/impl/deconv.py @@ -7,9 +7,11 @@ import torch from torch.fx.node import Target from torch_tensorrt.dynamo.conversion import impl +from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext from torch_tensorrt.dynamo.conversion.converter_utils import ( extend_attr_to_tuple, get_trt_tensor, + to_numpy, ) from torch_tensorrt.fx.converters.converter_utils import ( SourceIR, @@ -17,13 +19,12 @@ has_dynamic_shape, mark_as_int8_layer, set_layer_name, - to_numpy, ) -from torch_tensorrt.fx.types import TRTNetwork, TRTTensor +from torch_tensorrt.fx.types import TRTTensor def deconvNd( - network: TRTNetwork, + ctx: ConversionContext, target: Union[Target, str], source_ir: Optional[SourceIR], name: str, @@ -44,7 +45,7 @@ def deconvNd( if is_deconv1d: # Apply an unsqueeze operation to transform the deconv1d problem into deconv2d input = impl.unsqueeze.unsqueeze( - network, target, source_ir, name + "_unsqueeze_deconv1d", input, -1 + ctx, target, source_ir, name + "_unsqueeze_deconv1d", input, -1 ) # Process bias terms @@ -53,7 +54,7 @@ def deconvNd( bias = to_numpy(bias) elif isinstance(bias, TRTTensor): - bias = get_trt_tensor(network, bias, f"{name}_bias") + bias = get_trt_tensor(ctx, bias, f"{name}_bias") elif bias is not None: raise RuntimeError( @@ -61,12 +62,12 @@ def deconvNd( ) # Process weight terms - if network.has_explicit_precision or isinstance(weight, TRTTensor): - weight = get_trt_tensor(network, weight, f"{name}_weight") + if ctx.net.has_explicit_precision or isinstance(weight, TRTTensor): + weight = get_trt_tensor(ctx, weight, f"{name}_weight") # Append new dimension (unsqueeze) if the deconvolution is 1d if is_deconv1d: input = impl.unsqueeze.unsqueeze( - network, target, source_ir, name + "_unsqueeze_weight", weight, -1 + ctx, target, source_ir, name + "_unsqueeze_weight", weight, -1 ) elif isinstance(weight, (torch.Tensor, np.ndarray)): @@ -83,7 +84,7 @@ def deconvNd( ) # add deconv layer - deconv_layer = network.add_deconvolution_nd( + deconv_layer = ctx.net.add_deconvolution_nd( input=input, num_output_maps=weight.shape[0], kernel_shape=weight.shape[2:], @@ -134,7 +135,7 @@ def deconvNd( if is_deconv1d: # Apply a squeeze operation to transform the deconv2d problem back into deconv1d result = impl.squeeze.squeeze( - network, target, source_ir, name + "_squeeze_deconv1d", result, -1 + ctx, target, source_ir, name + "_squeeze_deconv1d", result, -1 ) return result diff --git a/py/torch_tensorrt/dynamo/conversion/impl/elementwise/base.py b/py/torch_tensorrt/dynamo/conversion/impl/elementwise/base.py index b2176653d1..3700242fe7 100644 --- a/py/torch_tensorrt/dynamo/conversion/impl/elementwise/base.py +++ b/py/torch_tensorrt/dynamo/conversion/impl/elementwise/base.py @@ -7,12 +7,13 @@ import torch from torch.fx.node import Target from torch_tensorrt.dynamo._SourceIR import SourceIR +from 
torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext from torch_tensorrt.dynamo.conversion.converter_utils import ( cast_trt_tensor, get_trt_tensor, ) from torch_tensorrt.fx.converters.converter_utils import broadcast, set_layer_name -from torch_tensorrt.fx.types import TRTElementWiseOp, TRTNetwork, TRTTensor +from torch_tensorrt.fx.types import TRTElementWiseOp, TRTTensor from torch_tensorrt.fx.utils import Frameworks, unified_dtype_converter @@ -52,7 +53,7 @@ def get_python_op_from_trt_elementwise_op( def convert_binary_elementwise( - network: TRTNetwork, + ctx: ConversionContext, target: Target, source_ir: Optional[SourceIR], name: str, @@ -73,7 +74,7 @@ def convert_binary_elementwise( tensor are not allowed to have larger ranks than the trt tensor operand. Args: - network (TRTNetwork): TensorRT network object. + ctx (ConversionContext): TensorRT ConversionContext object. target (Target): Target of fx node. source_ir (SourceIR): The IR that is calling the function. name (str): The name we want to assign to the created TensorRT layer. @@ -128,8 +129,8 @@ def convert_binary_elementwise( [lhs_val], dtype=unified_dtype_converter(rhs_dtype, Frameworks.NUMPY) ) - lhs_val = get_trt_tensor(network, lhs_val, f"{name}_lhs", lhs_dtype) - rhs_val = get_trt_tensor(network, rhs_val, f"{name}_rhs", rhs_dtype) + lhs_val = get_trt_tensor(ctx, lhs_val, f"{name}_lhs", lhs_dtype) + rhs_val = get_trt_tensor(ctx, rhs_val, f"{name}_rhs", rhs_dtype) promoted_type = torch.promote_types( unified_dtype_converter(lhs_val.dtype, Frameworks.TORCH), @@ -139,15 +140,15 @@ def convert_binary_elementwise( if trt_promoted_type != lhs_val.dtype: lhs_val = cast_trt_tensor( - network, lhs_val, trt_promoted_type, name, target, source_ir + ctx, lhs_val, trt_promoted_type, name, target, source_ir ) if trt_promoted_type != rhs_val.dtype: rhs_val = cast_trt_tensor( - network, rhs_val, trt_promoted_type, name, target, source_ir + ctx, rhs_val, trt_promoted_type, name, target, source_ir ) # Check the limitation in the doc string. 
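# Worked example (illustrative) of the dtype promotion performed just above: mixing
# an int32 operand with a float32 operand promotes both to float32 before the
# elementwise layer is added.
import torch
assert torch.promote_types(torch.int32, torch.float32) == torch.float32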
- if network.has_implicit_batch_dimension: + if ctx.net.has_implicit_batch_dimension: if is_lhs_trt_tensor and not is_rhs_trt_tensor: assert len(lhs_val.shape) >= len( rhs_val.shape @@ -158,9 +159,9 @@ def convert_binary_elementwise( ), f"{rhs_val.shape} >= {lhs_val.shape}" lhs_val, rhs_val = broadcast( - network, lhs_val, rhs_val, f"{name}_lhs", f"{name}_rhs" + ctx.net, lhs_val, rhs_val, f"{name}_lhs", f"{name}_rhs" ) - layer = network.add_elementwise(lhs_val, rhs_val, op_type) + layer = ctx.net.add_elementwise(lhs_val, rhs_val, op_type) set_layer_name(layer, target, name, source_ir) output = layer.get_output(0) kind: str = str(target.__name__) if callable(target) else target diff --git a/py/torch_tensorrt/dynamo/conversion/impl/elementwise/ops.py b/py/torch_tensorrt/dynamo/conversion/impl/elementwise/ops.py index 75ff33f26f..9f1143959f 100644 --- a/py/torch_tensorrt/dynamo/conversion/impl/elementwise/ops.py +++ b/py/torch_tensorrt/dynamo/conversion/impl/elementwise/ops.py @@ -4,6 +4,7 @@ import tensorrt as trt from torch.fx.node import Target from torch_tensorrt.dynamo._SourceIR import SourceIR +from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext from torch_tensorrt.dynamo.conversion.converter_utils import ( cast_int_int_div_trt_tensor, cast_int_or_float_to_bool, @@ -15,12 +16,12 @@ from torch_tensorrt.dynamo.conversion.impl.unary import sign from torch_tensorrt.dynamo.conversion.impl.unary.base import convert_unary from torch_tensorrt.fx.converters.converter_utils import set_layer_name, squeeze_left -from torch_tensorrt.fx.types import TRTNetwork, TRTTensor +from torch_tensorrt.fx.types import TRTTensor from torch_tensorrt.fx.utils import Frameworks, unified_dtype_converter def trunc_div( - network: TRTNetwork, + ctx: ConversionContext, target: Target, source_ir: Optional[SourceIR], name: str, @@ -33,7 +34,7 @@ def trunc_div( it will be ceil round. Example: [2.1, 0.8, -3.2] -> [2, 0, -3]. Args: - network: INetworkDefinition. + ctx: ConversionContext. target: node target source_ir (SourceIR): Source IR calling the function. name: namespace for the op @@ -44,7 +45,7 @@ def trunc_div( A TensorRT tensor represent the result of trunc divide. 
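# Eager-mode reference (illustrative) for the trunc-division semantics in the
# docstring above, matching its [2.1, 0.8, -3.2] -> [2, 0, -3] example:
import torch
x = torch.tensor([2.1, 0.8, -3.2])
print(torch.div(x, torch.ones_like(x), rounding_mode="trunc"))  # tensor([ 2.,  0., -3.])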
""" prod_output = convert_binary_elementwise( - network, + ctx, target, source_ir, f"{name}_prod", @@ -54,7 +55,7 @@ def trunc_div( ) sign_output = sign( - network, + ctx, target, source_ir, name, @@ -63,17 +64,17 @@ def trunc_div( # Convert constant input into ITensor for UnaryOperation if not isinstance(input, trt.tensorrt.ITensor): - input = get_trt_tensor(network, input, f"{name}_input") + input = get_trt_tensor(ctx, input, f"{name}_input") if not isinstance(other, trt.tensorrt.ITensor): other = get_trt_tensor( - network, + ctx, other, f"{name}_other", dtype=unified_dtype_converter(input.dtype, Frameworks.TORCH), ) abs_input_output = convert_unary( - network, + ctx, target, source_ir, f"{name}_abs_input", @@ -81,7 +82,7 @@ def trunc_div( input, ) abs_other_output = convert_unary( - network, + ctx, target, source_ir, f"{name}_abs_other", @@ -89,7 +90,7 @@ def trunc_div( other, ) abs_floor_output = convert_binary_elementwise( - network, + ctx, target, source_ir, f"{name}_floor_div", @@ -98,7 +99,7 @@ def trunc_div( abs_other_output, ) output = convert_binary_elementwise( - network, + ctx, target, source_ir, f"{name}_output", @@ -111,14 +112,14 @@ def trunc_div( def rsqrt( - network: TRTNetwork, + ctx: ConversionContext, target: Target, source_ir: Optional[SourceIR], name: str, input: TRTTensor, ) -> TRTTensor: sqrt_trt_output = convert_unary( - network, + ctx, target, source_ir, f"{name}_sqrt", @@ -127,7 +128,7 @@ def rsqrt( ) output = convert_binary_elementwise( - network, + ctx, target, source_ir, f"{name}_output", @@ -140,7 +141,7 @@ def rsqrt( def fmod( - network: TRTNetwork, + ctx: ConversionContext, target: Target, source_ir: Optional[SourceIR], name: str, @@ -149,7 +150,7 @@ def fmod( ) -> TRTTensor: # NOTE: TRT doesnt currently implement fmod so we need multiple operations to perform it trunc_div_value = trunc_div( - network, + ctx, target, source_ir, name + "_trunc_div", @@ -157,7 +158,7 @@ def fmod( other, ) prod_value = convert_binary_elementwise( - network, + ctx, target, source_ir, name + "_prod", @@ -166,7 +167,7 @@ def fmod( other, ) sub_value = convert_binary_elementwise( - network, + ctx, target, SourceIR.ACC, name + "_sub", @@ -178,7 +179,7 @@ def fmod( def clamp( - network: TRTNetwork, + ctx: ConversionContext, target: Target, source_ir: Optional[SourceIR], name: str, @@ -193,7 +194,7 @@ def clamp( ) def _add_layer( - network: TRTNetwork, + ctx: ConversionContext, input: TRTTensor, val: float, op: trt.ElementWiseOperation, @@ -204,7 +205,7 @@ def _add_layer( if not len(input.shape): # clamping scalar acc_ops_clamp_trt = get_trt_tensor( - network, + ctx, squeeze_left( np.array( [val], @@ -220,21 +221,21 @@ def _add_layer( val, dtype=unified_dtype_converter(input.dtype, Frameworks.NUMPY), ) - acc_ops_clamp_trt = network.add_constant( + acc_ops_clamp_trt = ctx.net.add_constant( acc_ops_clamp_shape, acc_ops_clamp_tensor ).get_output(0) - layer = network.add_elementwise(input, acc_ops_clamp_trt, op) + layer = ctx.net.add_elementwise(input, acc_ops_clamp_trt, op) return layer if min_val is not None: clamp_min_layer = _add_layer( - network, input_val, min_val, trt.ElementWiseOperation.MAX, name + ctx, input_val, min_val, trt.ElementWiseOperation.MAX, name ) set_layer_name(clamp_min_layer, target, f"{name}_clamp_min") input_val = clamp_min_layer.get_output(0) if max_val is not None: clamp_max_layer = _add_layer( - network, input_val, max_val, trt.ElementWiseOperation.MIN, name + ctx, input_val, max_val, trt.ElementWiseOperation.MIN, name ) set_layer_name(clamp_max_layer, 
target, f"{name}_clamp_max") input_val = clamp_max_layer.get_output(0) @@ -243,7 +244,7 @@ def _add_layer( def add( - network: TRTNetwork, + ctx: ConversionContext, target: Target, source_ir: Optional[SourceIR], name: str, @@ -251,12 +252,12 @@ def add( rhs_val: Union[TRTTensor, int, float], ) -> TRTTensor: return convert_binary_elementwise( - network, target, source_ir, name, trt.ElementWiseOperation.SUM, lhs_val, rhs_val + ctx, target, source_ir, name, trt.ElementWiseOperation.SUM, lhs_val, rhs_val ) def mul( - network: TRTNetwork, + ctx: ConversionContext, target: Target, source_ir: Optional[SourceIR], name: str, @@ -264,7 +265,7 @@ def mul( rhs_val: Union[TRTTensor, int, float], ) -> TRTTensor: return convert_binary_elementwise( - network, + ctx, target, source_ir, name, @@ -275,7 +276,7 @@ def mul( def max( - network: TRTNetwork, + ctx: ConversionContext, target: Target, source_ir: Optional[SourceIR], name: str, @@ -283,12 +284,12 @@ def max( rhs_val: Union[TRTTensor, int, float], ) -> TRTTensor: return convert_binary_elementwise( - network, target, source_ir, name, trt.ElementWiseOperation.MAX, lhs_val, rhs_val + ctx, target, source_ir, name, trt.ElementWiseOperation.MAX, lhs_val, rhs_val ) def min( - network: TRTNetwork, + ctx: ConversionContext, target: Target, source_ir: Optional[SourceIR], name: str, @@ -296,12 +297,12 @@ def min( rhs_val: Union[TRTTensor, int, float], ) -> TRTTensor: return convert_binary_elementwise( - network, target, source_ir, name, trt.ElementWiseOperation.MIN, lhs_val, rhs_val + ctx, target, source_ir, name, trt.ElementWiseOperation.MIN, lhs_val, rhs_val ) def sub( - network: TRTNetwork, + ctx: ConversionContext, target: Target, source_ir: Optional[SourceIR], name: str, @@ -309,12 +310,12 @@ def sub( rhs_val: Union[TRTTensor, int, float], ) -> TRTTensor: return convert_binary_elementwise( - network, target, source_ir, name, trt.ElementWiseOperation.SUB, lhs_val, rhs_val + ctx, target, source_ir, name, trt.ElementWiseOperation.SUB, lhs_val, rhs_val ) def div( - network: TRTNetwork, + ctx: ConversionContext, target: Target, source_ir: Optional[SourceIR], name: str, @@ -322,15 +323,15 @@ def div( rhs_val: Union[TRTTensor, int, float], ) -> TRTTensor: if isinstance(lhs_val, TRTTensor) and isinstance(rhs_val, TRTTensor): - lhs_val, rhs_val = cast_int_int_div_trt_tensor(network, lhs_val, rhs_val, name) + lhs_val, rhs_val = cast_int_int_div_trt_tensor(ctx, lhs_val, rhs_val, name) return convert_binary_elementwise( - network, target, source_ir, name, trt.ElementWiseOperation.DIV, lhs_val, rhs_val + ctx, target, source_ir, name, trt.ElementWiseOperation.DIV, lhs_val, rhs_val ) def pow( - network: TRTNetwork, + ctx: ConversionContext, target: Target, source_ir: Optional[SourceIR], name: str, @@ -338,15 +339,15 @@ def pow( rhs_val: Union[TRTTensor, int, float], ) -> TRTTensor: if isinstance(lhs_val, TRTTensor) and isinstance(rhs_val, TRTTensor): - lhs_val, rhs_val = cast_int_int_div_trt_tensor(network, lhs_val, rhs_val, name) + lhs_val, rhs_val = cast_int_int_div_trt_tensor(ctx, lhs_val, rhs_val, name) return convert_binary_elementwise( - network, target, source_ir, name, trt.ElementWiseOperation.POW, lhs_val, rhs_val + ctx, target, source_ir, name, trt.ElementWiseOperation.POW, lhs_val, rhs_val ) def floor_divide( - network: TRTNetwork, + ctx: ConversionContext, target: Target, source_ir: Optional[SourceIR], name: str, @@ -354,7 +355,7 @@ def floor_divide( rhs_val: Union[TRTTensor, int, float], ) -> TRTTensor: return convert_binary_elementwise( - network, + ctx, 
target, source_ir, name, @@ -365,7 +366,7 @@ def floor_divide( def logical_and( - network: TRTNetwork, + ctx: ConversionContext, target: Target, source_ir: Optional[SourceIR], name: str, @@ -373,18 +374,18 @@ def logical_and( rhs_val: Union[TRTTensor, int, float], ) -> TRTTensor: if isinstance(lhs_val, TRTTensor): - lhs_val = cast_int_or_float_to_bool(network, name, lhs_val) + lhs_val = cast_int_or_float_to_bool(ctx, name, lhs_val) if isinstance(rhs_val, TRTTensor): - rhs_val = cast_int_or_float_to_bool(network, name, rhs_val) + rhs_val = cast_int_or_float_to_bool(ctx, name, rhs_val) return convert_binary_elementwise( - network, target, source_ir, name, trt.ElementWiseOperation.AND, lhs_val, rhs_val + ctx, target, source_ir, name, trt.ElementWiseOperation.AND, lhs_val, rhs_val ) def logical_or( - network: TRTNetwork, + ctx: ConversionContext, target: Target, source_ir: Optional[SourceIR], name: str, @@ -392,18 +393,18 @@ def logical_or( rhs_val: Union[TRTTensor, int, float], ) -> TRTTensor: if isinstance(lhs_val, TRTTensor): - lhs_val = cast_int_or_float_to_bool(network, name, lhs_val) + lhs_val = cast_int_or_float_to_bool(ctx, name, lhs_val) if isinstance(rhs_val, TRTTensor): - rhs_val = cast_int_or_float_to_bool(network, name, rhs_val) + rhs_val = cast_int_or_float_to_bool(ctx, name, rhs_val) return convert_binary_elementwise( - network, target, source_ir, name, trt.ElementWiseOperation.OR, lhs_val, rhs_val + ctx, target, source_ir, name, trt.ElementWiseOperation.OR, lhs_val, rhs_val ) def logical_xor( - network: TRTNetwork, + ctx: ConversionContext, target: Target, source_ir: Optional[SourceIR], name: str, @@ -411,18 +412,18 @@ def logical_xor( rhs_val: Union[TRTTensor, int, float], ) -> TRTTensor: if isinstance(lhs_val, TRTTensor): - lhs_val = cast_int_or_float_to_bool(network, name, lhs_val) + lhs_val = cast_int_or_float_to_bool(ctx, name, lhs_val) if isinstance(rhs_val, TRTTensor): - rhs_val = cast_int_or_float_to_bool(network, name, rhs_val) + rhs_val = cast_int_or_float_to_bool(ctx, name, rhs_val) return convert_binary_elementwise( - network, target, source_ir, name, trt.ElementWiseOperation.XOR, lhs_val, rhs_val + ctx, target, source_ir, name, trt.ElementWiseOperation.XOR, lhs_val, rhs_val ) def eq( - network: TRTNetwork, + ctx: ConversionContext, target: Target, source_ir: Optional[SourceIR], name: str, @@ -430,7 +431,7 @@ def eq( rhs_val: Union[TRTTensor, int, float], ) -> TRTTensor: return convert_binary_elementwise( - network, + ctx, target, source_ir, name, @@ -441,7 +442,7 @@ def eq( def gt( - network: TRTNetwork, + ctx: ConversionContext, target: Target, source_ir: Optional[SourceIR], name: str, @@ -449,7 +450,7 @@ def gt( rhs_val: Union[TRTTensor, int, float], ) -> TRTTensor: return convert_binary_elementwise( - network, + ctx, target, source_ir, name, @@ -460,7 +461,7 @@ def gt( def lt( - network: TRTNetwork, + ctx: ConversionContext, target: Target, source_ir: Optional[SourceIR], name: str, @@ -468,7 +469,7 @@ def lt( rhs_val: Union[TRTTensor, int, float], ) -> TRTTensor: return convert_binary_elementwise( - network, + ctx, target, source_ir, name, diff --git a/py/torch_tensorrt/dynamo/conversion/impl/embedding.py b/py/torch_tensorrt/dynamo/conversion/impl/embedding.py index 8ddfdf015f..b7795ea1f3 100644 --- a/py/torch_tensorrt/dynamo/conversion/impl/embedding.py +++ b/py/torch_tensorrt/dynamo/conversion/impl/embedding.py @@ -3,13 +3,14 @@ import torch from torch.fx.node import Target from torch_tensorrt.dynamo._SourceIR import SourceIR +from 
torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext from torch_tensorrt.dynamo.conversion.converter_utils import get_trt_tensor from torch_tensorrt.fx.converters.converter_utils import set_layer_name -from torch_tensorrt.fx.types import TRTNetwork, TRTTensor +from torch_tensorrt.fx.types import TRTTensor def embedding( - network: TRTNetwork, + ctx: ConversionContext, target: Target, source_ir: Optional[SourceIR], name: str, @@ -24,10 +25,8 @@ def embedding( raise RuntimeError( "The `embedding` op has indices_tensor dtype=int64. This is incorrect since it has to be int32 to run on TRT." ) - indices_tensor = get_trt_tensor(network, indices_tensor, f"{name}_indices_tensor") - embedding_tensor = get_trt_tensor( - network, embedding_tensor, f"{name}_embedding_tensor" - ) + indices_tensor = get_trt_tensor(ctx, indices_tensor, f"{name}_indices_tensor") + embedding_tensor = get_trt_tensor(ctx, embedding_tensor, f"{name}_embedding_tensor") # unsupported parameters # ignore padding_idx since it is meaningful for training only @@ -40,6 +39,6 @@ def embedding( raise RuntimeError("Currently we don't support sparse gradient.") # Implement embedding lookup with gather layer - gather_layer = network.add_gather(embedding_tensor, indices_tensor, axis=0) + gather_layer = ctx.net.add_gather(embedding_tensor, indices_tensor, axis=0) set_layer_name(gather_layer, target, name + "_gather", source_ir) return gather_layer.get_output(0) diff --git a/py/torch_tensorrt/dynamo/conversion/impl/linear.py b/py/torch_tensorrt/dynamo/conversion/impl/linear.py index cad97a5c9a..69ef73964d 100644 --- a/py/torch_tensorrt/dynamo/conversion/impl/linear.py +++ b/py/torch_tensorrt/dynamo/conversion/impl/linear.py @@ -5,12 +5,13 @@ import torch from torch.fx.node import Target from torch_tensorrt.dynamo.conversion import impl +from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext from torch_tensorrt.dynamo.conversion.converter_utils import SourceIR, get_trt_tensor -from torch_tensorrt.fx.types import TRTNetwork, TRTTensor +from torch_tensorrt.fx.types import TRTTensor def linear( - network: TRTNetwork, + ctx: ConversionContext, target: Union[Target, str], source_ir: Optional[SourceIR], name: str, @@ -24,7 +25,7 @@ def linear( f"Linear layer {name} has weight of type {type(weight)}, Expect Union[TRTTensor, torch.Tensor, np.ndarray]," ) elif isinstance(weight, (torch.Tensor, np.ndarray)): - weight = get_trt_tensor(network, weight, f"{name}_weight") + weight = get_trt_tensor(ctx, weight, f"{name}_weight") # Process bias terms if bias is not None and not isinstance(bias, (TRTTensor, torch.Tensor, np.ndarray)): @@ -32,11 +33,11 @@ def linear( f"Linear layer {name} has bias of type {type(bias)}, Expect Union[TRTTensor, torch.Tensor, np.ndarray]," ) elif isinstance(bias, (torch.Tensor, np.ndarray)): - bias = get_trt_tensor(network, bias, f"{name}_bias") + bias = get_trt_tensor(ctx, bias, f"{name}_bias") # add IMatrixMultiplyLayer out = impl.matmul.matrix_multiply( - network, + ctx, target, source_ir, name, @@ -48,6 +49,6 @@ def linear( if bias is not None: # add bias - out = impl.elementwise.add(network, target, source_ir, name, out, bias) + out = impl.elementwise.add(ctx, target, source_ir, name, out, bias) return out diff --git a/py/torch_tensorrt/dynamo/conversion/impl/matmul.py b/py/torch_tensorrt/dynamo/conversion/impl/matmul.py index a62d24121f..a50ec3c434 100644 --- a/py/torch_tensorrt/dynamo/conversion/impl/matmul.py +++ b/py/torch_tensorrt/dynamo/conversion/impl/matmul.py @@ 
-3,14 +3,15 @@ import tensorrt as trt from torch.fx.node import Target from torch_tensorrt.dynamo._SourceIR import SourceIR +from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext from torch_tensorrt.dynamo.conversion.converter_utils import get_trt_tensor from torch_tensorrt.fx.converters.converter_utils import broadcast, set_layer_name -from torch_tensorrt.fx.types import TRTNetwork, TRTTensor +from torch_tensorrt.fx.types import TRTTensor from torch_tensorrt.fx.utils import Frameworks, unified_dtype_converter def matrix_multiply( - network: TRTNetwork, + ctx: ConversionContext, target: Target, source_ir: Optional[SourceIR], name: str, @@ -20,10 +21,10 @@ def matrix_multiply( other_matrix_op: trt.MatrixOperation = trt.MatrixOperation.NONE, ) -> TRTTensor: if not isinstance(input, trt.tensorrt.ITensor): - input = get_trt_tensor(network, input, f"{name}_input") + input = get_trt_tensor(ctx, input, f"{name}_input") if not isinstance(other, trt.tensorrt.ITensor): other = get_trt_tensor( - network, + ctx, other, f"{name}_other", dtype=unified_dtype_converter(input.dtype, Frameworks.TORCH), @@ -40,8 +41,8 @@ def matrix_multiply( other_matrix_op = trt.MatrixOperation.VECTOR input, other = broadcast( - network, input, other, f"{name}_input", f"{name}_other", preset_diff + ctx.net, input, other, f"{name}_input", f"{name}_other", preset_diff ) - layer = network.add_matrix_multiply(input, input_matrix_op, other, other_matrix_op) + layer = ctx.net.add_matrix_multiply(input, input_matrix_op, other, other_matrix_op) set_layer_name(layer, target, name, source_ir) return layer.get_output(0) diff --git a/py/torch_tensorrt/dynamo/conversion/impl/normalization/ops.py b/py/torch_tensorrt/dynamo/conversion/impl/normalization/ops.py index 44209de2f0..81bd88cd4f 100644 --- a/py/torch_tensorrt/dynamo/conversion/impl/normalization/ops.py +++ b/py/torch_tensorrt/dynamo/conversion/impl/normalization/ops.py @@ -6,7 +6,8 @@ import torch from torch.fx.node import Target from torch_tensorrt.dynamo._SourceIR import SourceIR -from torch_tensorrt.dynamo.conversion.converter_utils import get_positive_dim +from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext +from torch_tensorrt.dynamo.conversion.converter_utils import get_positive_dim, to_numpy from torch_tensorrt.dynamo.conversion.impl.elementwise.base import ( convert_binary_elementwise, ) @@ -15,16 +16,15 @@ get_trt_plugin, has_dynamic_shape, set_layer_name, - to_numpy, ) -from torch_tensorrt.fx.types import TRTNetwork, TRTTensor +from torch_tensorrt.fx.types import TRTTensor from torch_tensorrt.fx.utils import get_dynamic_dims _LOGGER: logging.Logger = logging.getLogger(__name__) def batch_norm( - network: TRTNetwork, + ctx: ConversionContext, target: Target, source_ir: Optional[SourceIR], name: str, @@ -55,11 +55,11 @@ def batch_norm( # For BatchNorm1d, reshape 1d to 2d output_shape = input.shape - if not network.has_implicit_batch_dimension and len(input.shape) < 4: + if not ctx.net.has_implicit_batch_dimension and len(input.shape) < 4: assert ( len(get_dynamic_dims(input.shape)) <= 1 ), "BatchNorm1D with more than one dynamic dims is not currently supported." 
- reshape_layer = network.add_shuffle(input) + reshape_layer = ctx.net.add_shuffle(input) if len(input.shape) == 2: reshape_layer.reshape_dims = (input.shape[0], input.shape[1], 1, 1) else: # len(input_val.shape) == 3 @@ -71,12 +71,12 @@ def batch_norm( ) set_layer_name(reshape_layer, target, f"{name}_reshape_2d") input = reshape_layer.get_output(0) - layer = network.add_scale(input, trt.ScaleMode.CHANNEL, bias, scale, power) + layer = ctx.net.add_scale(input, trt.ScaleMode.CHANNEL, bias, scale, power) set_layer_name(layer, target, name) # For BatchNorm1d, reshape output back to 1d - if not network.has_implicit_batch_dimension and len(output_shape) < 4: - reshape_output_layer = network.add_shuffle(layer.get_output(0)) + if not ctx.net.has_implicit_batch_dimension and len(output_shape) < 4: + reshape_output_layer = ctx.net.add_shuffle(layer.get_output(0)) reshape_output_layer.reshape_dims = tuple(output_shape) set_layer_name(reshape_output_layer, target, f"{name}_reshape_1d") layer = reshape_output_layer @@ -84,7 +84,7 @@ def batch_norm( def layer_norm( - network: TRTNetwork, + ctx: ConversionContext, target: Target, source_ir: Optional[SourceIR], name: str, @@ -127,7 +127,7 @@ def layer_norm( ) try: - if network.has_implicit_batch_dimension: + if ctx.net.has_implicit_batch_dimension: plugin = get_trt_plugin("layer_norm", field_collection, "1", "fx2trt") else: plugin = get_trt_plugin("LayerNormDynamic", field_collection, "1", "fx2trt") @@ -136,15 +136,15 @@ def layer_norm( "Unable to find layer norm plugin, fall back to TensorRT implementation." ) return layer_norm_no_plugin( - network, target, source_ir, name, input, normalized_shape, weight, bias, eps + ctx, target, source_ir, name, input, normalized_shape, weight, bias, eps ) - layer = network.add_plugin_v2([input], plugin) + layer = ctx.net.add_plugin_v2([input], plugin) layer.name = name return layer.get_output(0) def layer_norm_no_plugin( - network: TRTNetwork, + ctx: ConversionContext, target: Target, source_ir: Optional[SourceIR], name: str, @@ -170,14 +170,14 @@ def layer_norm_no_plugin( axes |= 1 << (len(input.shape) - d - 1) # E[x] - mean_expected_layer = network.add_reduce( + mean_expected_layer = ctx.net.add_reduce( input, trt.ReduceOperation.AVG, axes, keep_dims=True ) set_layer_name(mean_expected_layer, target, f"{name}_mean_expected") # X-E[x] sub_trt = convert_binary_elementwise( - network, + ctx, target, source_ir, f"{name}_sub", @@ -186,13 +186,13 @@ def layer_norm_no_plugin( mean_expected_layer.get_output(0), ) # Variance = mean(pow(x_sub_mean,2)) - pow_tensor = network.add_constant( + pow_tensor = ctx.net.add_constant( (1,) * len(input.shape), trt.Weights(np.ascontiguousarray([2.0], dtype=np.float32)), ) pow_tensor.name = f"{name}_power" pow_var = convert_binary_elementwise( - network, + ctx, target, source_ir, f"{name}_pow_var", @@ -200,18 +200,18 @@ def layer_norm_no_plugin( sub_trt, pow_tensor.get_output(0), ) - mean_trt_layer = network.add_reduce( + mean_trt_layer = ctx.net.add_reduce( pow_var, trt.ReduceOperation.AVG, axes, keep_dims=True ) set_layer_name(mean_trt_layer, target, f"{name}_mean") # Variance + eps - eps_tensor = network.add_constant( + eps_tensor = ctx.net.add_constant( (1,) * len(input.shape), trt.Weights(np.ascontiguousarray([eps], dtype=np.float32)), ) eps_tensor.name = f"{name}_eps" add_trt = convert_binary_elementwise( - network, + ctx, target, source_ir, f"{name}_add", @@ -221,7 +221,7 @@ def layer_norm_no_plugin( ) # SQRT((Var + eps)) sqrt_trt = convert_unary( - network, + ctx, target, 
source_ir, f"{name}_sqrt", @@ -230,7 +230,7 @@ def layer_norm_no_plugin( ) # (x - E[x]) / sqrt((var + eps)) div_trt = convert_binary_elementwise( - network, + ctx, target, source_ir, f"{name}_div_trt", @@ -240,18 +240,18 @@ def layer_norm_no_plugin( ) assert gamma is not None - gamma_tensor = network.add_constant( + gamma_tensor = ctx.net.add_constant( gamma.shape, trt.Weights(np.ascontiguousarray(gamma)) ) gamma_tensor.name = f"{name}_gamma" assert beta is not None - beta_tensor = network.add_constant( + beta_tensor = ctx.net.add_constant( gamma.shape, trt.Weights(np.ascontiguousarray(beta)) ) beta_tensor.name = f"{name}_beta" # y * gamma + beta scale_layer = convert_binary_elementwise( - network, + ctx, target, source_ir, f"{name}_scale", @@ -260,7 +260,7 @@ def layer_norm_no_plugin( gamma_tensor.get_output(0), ) return convert_binary_elementwise( - network, + ctx, target, source_ir, name, @@ -271,14 +271,14 @@ def layer_norm_no_plugin( def softmax( - network: TRTNetwork, + ctx: ConversionContext, target: Target, source_ir: Optional[SourceIR], name: str, input: TRTTensor, dim: Optional[Any] = None, ) -> Union[TRTTensor, Sequence[TRTTensor]]: - input_ranks = len(input.shape) + (1 if network.has_implicit_batch_dimension else 0) + input_ranks = len(input.shape) + (1 if ctx.net.has_implicit_batch_dimension else 0) if not isinstance(input, TRTTensor): raise RuntimeError( @@ -300,11 +300,11 @@ def get_softmax_dim(ndim: int) -> int: dim = cast(int, dim) dim = get_positive_dim(dim, input_ranks) - if network.has_implicit_batch_dimension: + if ctx.net.has_implicit_batch_dimension: assert dim != 0, "Can't apply softmax on batch dimension when it's implicit." dim -= 1 - layer = network.add_softmax(input) + layer = ctx.net.add_softmax(input) layer.axes = 1 << dim set_layer_name(layer, target, name) return layer.get_output(0) diff --git a/py/torch_tensorrt/dynamo/conversion/impl/permutation.py b/py/torch_tensorrt/dynamo/conversion/impl/permutation.py index ff1e98dbf5..bdd9b46314 100644 --- a/py/torch_tensorrt/dynamo/conversion/impl/permutation.py +++ b/py/torch_tensorrt/dynamo/conversion/impl/permutation.py @@ -2,13 +2,14 @@ from torch.fx.node import Target from torch_tensorrt.dynamo._SourceIR import SourceIR +from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext from torch_tensorrt.dynamo.conversion.converter_utils import get_positive_dim from torch_tensorrt.fx.converters.converter_utils import set_layer_name -from torch_tensorrt.fx.types import TRTNetwork, TRTTensor +from torch_tensorrt.fx.types import TRTTensor def permute( - network: TRTNetwork, + ctx: ConversionContext, target: Target, source_ir: Optional[SourceIR], name: str, @@ -22,7 +23,7 @@ def permute( permutation = get_positive_dim(permutation, len(input.shape)) - layer = network.add_shuffle(input) + layer = ctx.net.add_shuffle(input) layer.second_transpose = tuple(permutation) set_layer_name(layer, target, name, source_ir) return layer.get_output(0) diff --git a/py/torch_tensorrt/dynamo/conversion/impl/pool.py b/py/torch_tensorrt/dynamo/conversion/impl/pool.py index a84402ba89..13c8645a90 100644 --- a/py/torch_tensorrt/dynamo/conversion/impl/pool.py +++ b/py/torch_tensorrt/dynamo/conversion/impl/pool.py @@ -3,16 +3,17 @@ import tensorrt as trt from torch.fx.node import Target from torch_tensorrt.dynamo._SourceIR import SourceIR +from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext from torch_tensorrt.dynamo.conversion.converter_utils import extend_attr_to_tuple from 
torch_tensorrt.fx.converters.converter_utils import ( has_dynamic_shape, set_layer_name, ) -from torch_tensorrt.fx.types import TRTNetwork, TRTTensor +from torch_tensorrt.fx.types import TRTTensor def avg_poolNd( - network: TRTNetwork, + ctx: ConversionContext, target: Union[Target, str], source_ir: Optional[SourceIR], name: str, @@ -45,7 +46,7 @@ def avg_poolNd( padding = extend_attr_to_tuple(padding, dim) # add average pooling layer - pool_layer = network.add_pooling_nd( + pool_layer = ctx.net.add_pooling_nd( input=input, type=trt.PoolingType.AVERAGE, window_size=kernel_size, @@ -60,7 +61,7 @@ def avg_poolNd( def max_poolNd( - network: TRTNetwork, + ctx: ConversionContext, target: Union[Target, str], source_ir: Optional[SourceIR], name: str, @@ -92,7 +93,7 @@ def max_poolNd( padding = extend_attr_to_tuple(padding, dim) # add max pooling layer - pool_layer = network.add_pooling_nd( + pool_layer = ctx.net.add_pooling_nd( input=input, type=trt.PoolingType.MAX, window_size=kernel_size, diff --git a/py/torch_tensorrt/dynamo/conversion/impl/reduce.py b/py/torch_tensorrt/dynamo/conversion/impl/reduce.py index 1cb2559ae3..0357962be5 100644 --- a/py/torch_tensorrt/dynamo/conversion/impl/reduce.py +++ b/py/torch_tensorrt/dynamo/conversion/impl/reduce.py @@ -3,17 +3,18 @@ import tensorrt as trt from torch.fx.node import Target from torch_tensorrt.dynamo._SourceIR import SourceIR +from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext from torch_tensorrt.dynamo.conversion.converter_utils import ( cast_trt_tensor, get_axes_for_reduce_op, get_positive_dim, ) from torch_tensorrt.fx.converters.converter_utils import set_layer_name -from torch_tensorrt.fx.types import TRTNetwork, TRTTensor +from torch_tensorrt.fx.types import TRTTensor def amax( - network: TRTNetwork, + ctx: ConversionContext, target: Target, source_ir: Optional[SourceIR], name: str, @@ -24,9 +25,9 @@ def amax( if (isinstance(input_val, TRTTensor)) and ( input_val.dtype == trt.int8 or input_val.dtype == trt.int32 ): - input_val = cast_trt_tensor(network, input_val, trt.float32, name) + input_val = cast_trt_tensor(ctx, input_val, trt.float32, name) - layer = network.add_reduce( + layer = ctx.net.add_reduce( input_val, trt.ReduceOperation.MAX, axes=get_axes_for_reduce_op(get_positive_dim(dim, len(input_val.shape))), @@ -37,7 +38,7 @@ def amax( def sum( - network: TRTNetwork, + ctx: ConversionContext, target: Target, source_ir: Optional[SourceIR], name: str, @@ -48,11 +49,11 @@ def sum( if (isinstance(input_val, TRTTensor)) and ( input_val.dtype == trt.int8 or input_val.dtype == trt.int32 ): - input_val = cast_trt_tensor(network, input_val, trt.float32, name) + input_val = cast_trt_tensor(ctx, input_val, trt.float32, name) if dim is None: dim = tuple(range(len(input_val.shape))) - layer = network.add_reduce( + layer = ctx.net.add_reduce( input_val, trt.ReduceOperation.SUM, axes=get_axes_for_reduce_op(get_positive_dim(dim, len(input_val.shape))), diff --git a/py/torch_tensorrt/dynamo/conversion/impl/select.py b/py/torch_tensorrt/dynamo/conversion/impl/select.py index 9b65245dbe..20132fa460 100644 --- a/py/torch_tensorrt/dynamo/conversion/impl/select.py +++ b/py/torch_tensorrt/dynamo/conversion/impl/select.py @@ -3,14 +3,15 @@ import numpy as np from torch.fx.node import Target from torch_tensorrt.dynamo._SourceIR import SourceIR -from torch_tensorrt.dynamo.conversion.converter_utils import get_positive_dim +from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext +from 
torch_tensorrt.dynamo.conversion.converter_utils import get_positive_dim, to_numpy from torch_tensorrt.dynamo.conversion.impl.shape import get_shape_with_dynamic_shape -from torch_tensorrt.fx.converters.converter_utils import has_dynamic_shape, to_numpy -from torch_tensorrt.fx.types import Shape, TRTNetwork, TRTTensor +from torch_tensorrt.fx.converters.converter_utils import has_dynamic_shape +from torch_tensorrt.fx.types import Shape, TRTTensor def select( - network: TRTNetwork, + ctx: ConversionContext, target: Target, source_ir: Optional[SourceIR], name: str, @@ -24,10 +25,10 @@ def select( "of the TensorRT region!" ) - ranks = len(input.shape) + (1 if network.has_implicit_batch_dimension else 0) + ranks = len(input.shape) + (1 if ctx.net.has_implicit_batch_dimension else 0) dim = get_positive_dim(cast(int, dim), ranks) dynamic_shape = has_dynamic_shape(input.shape) - if network.has_implicit_batch_dimension: + if ctx.net.has_implicit_batch_dimension: if dim == 0: raise RuntimeError( f"We do not support slice_tensor at batch dim when it's implicit, got {dim}!" @@ -47,14 +48,14 @@ def select( output_shape[dim] = 1 if dynamic_shape > 0: output_shape = get_shape_with_dynamic_shape( - network, target, source_ir, name, output_shape, input + ctx, target, source_ir, name, output_shape, input ) index_value = np.array(index, dtype=np.int32) - indices_tensor = network.add_constant( + indices_tensor = ctx.net.add_constant( index_value.shape, to_numpy(index_value) ).get_output(0) - layer = network.add_gather(input, indices_tensor, dim) + layer = ctx.net.add_gather(input, indices_tensor, dim) out = layer.get_output(0) if len(out.shape) != 1: - layer = network.add_shuffle(out) + layer = ctx.net.add_shuffle(out) return layer.get_output(0) diff --git a/py/torch_tensorrt/dynamo/conversion/impl/shape.py b/py/torch_tensorrt/dynamo/conversion/impl/shape.py index 8bd137d991..ef30b186c1 100644 --- a/py/torch_tensorrt/dynamo/conversion/impl/shape.py +++ b/py/torch_tensorrt/dynamo/conversion/impl/shape.py @@ -3,20 +3,21 @@ from typing import List, Optional, Tuple import numpy as np +import tensorrt as trt import torch from torch.fx.node import Target from torch_tensorrt.dynamo._SourceIR import SourceIR +from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext +from torch_tensorrt.dynamo.conversion.converter_utils import to_numpy from torch_tensorrt.dynamo.conversion.impl.elementwise.base import ( convert_binary_elementwise, ) -from torch_tensorrt.fx.converters.converter_utils import set_layer_name, to_numpy -from torch_tensorrt.fx.types import TRTNetwork, TRTTensor - -import tensorrt as trt +from torch_tensorrt.fx.converters.converter_utils import set_layer_name +from torch_tensorrt.fx.types import TRTTensor def get_shape_with_dynamic_shape( - network: TRTNetwork, + ctx: ConversionContext, target: Target, source_ir: Optional[SourceIR], name: str, @@ -37,7 +38,7 @@ def get_shape_with_dynamic_shape( 5. output shape with actual batch_size as [2048, 128, 256] Args: - network (TRTNetwork): TensorRT network object. + ctx (ConversionContext): TensorRT ConversionContext object. shape: calculated shape of the expected output tensor input_val (TRTTensor): A TensorRT ITensor. target (Target): Target of fx node. 
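# NumPy sketch (illustrative) of the shape-resolution intent described in the
# docstring above: dimensions requested as -1 are filled in from the runtime shape,
# all others keep the requested value. The variable names and the `< 0` test are
# assumptions standing in for the elementwise-compare + select built in TensorRT.
import numpy as np
requested = np.array([-1, 128, 256], dtype=np.int32)    # "shape" argument with a dynamic batch dim
runtime = np.array([2048, 128, 256], dtype=np.int32)    # what IShapeLayer would report at runtime
resolved = np.where(requested < 0, runtime, requested)
print(resolved)  # [2048  128  256]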
@@ -46,22 +47,22 @@ def get_shape_with_dynamic_shape( TensorRT ITensors that represents the actual shape of the input_val """ # Ger real shape info for input_val - input_shape = network.add_shape(input_val).get_output(0) + input_shape = ctx.net.add_shape(input_val).get_output(0) - scale_layer = network.add_constant( + scale_layer = ctx.net.add_constant( input_shape.shape, np.ascontiguousarray(shape, dtype=np.int32) ) set_layer_name(scale_layer, target, f"{name}_scale") scale_res = scale_layer.get_output(0) length = input_shape.shape[0] - zero_layer = network.add_constant( + zero_layer = ctx.net.add_constant( input_shape.shape, to_numpy(torch.zeros((length), dtype=torch.int32)) ) set_layer_name(zero_layer, target, f"{name}_zeros") condition_val = convert_binary_elementwise( - network, + ctx, target, source_ir, f"{name}_shape", @@ -69,6 +70,6 @@ def get_shape_with_dynamic_shape( scale_res, zero_layer.get_output(0), ) - select_layer = network.add_select(condition_val, input_shape, scale_res) + select_layer = ctx.net.add_select(condition_val, input_shape, scale_res) set_layer_name(select_layer, target, f"{name}_select") return select_layer.get_output(0) diff --git a/py/torch_tensorrt/dynamo/conversion/impl/slice/base.py b/py/torch_tensorrt/dynamo/conversion/impl/slice/base.py index 57e72803a8..018ac63b8c 100644 --- a/py/torch_tensorrt/dynamo/conversion/impl/slice/base.py +++ b/py/torch_tensorrt/dynamo/conversion/impl/slice/base.py @@ -2,16 +2,17 @@ from torch.fx.node import Target from torch_tensorrt.dynamo._SourceIR import SourceIR +from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext from torch_tensorrt.dynamo.conversion.impl.shape import get_shape_with_dynamic_shape from torch_tensorrt.fx.converters.converter_utils import ( has_dynamic_shape, set_layer_name, ) -from torch_tensorrt.fx.types import Shape, TRTNetwork, TRTTensor +from torch_tensorrt.fx.types import Shape, TRTTensor def slice( - network: TRTNetwork, + ctx: ConversionContext, target: Target, source_ir: Optional[SourceIR], name: str, @@ -22,10 +23,8 @@ def slice( ) -> TRTTensor: dynamic_shape = has_dynamic_shape(input.shape) if dynamic_shape: - shape = get_shape_with_dynamic_shape( - network, target, source_ir, name, shape, input - ) - layer = network.add_slice( + shape = get_shape_with_dynamic_shape(ctx, target, source_ir, name, shape, input) + layer = ctx.net.add_slice( input, start=start, shape=[] if dynamic_shape else shape, diff --git a/py/torch_tensorrt/dynamo/conversion/impl/slice/ops.py b/py/torch_tensorrt/dynamo/conversion/impl/slice/ops.py index 8904e140cf..97ffdb728f 100644 --- a/py/torch_tensorrt/dynamo/conversion/impl/slice/ops.py +++ b/py/torch_tensorrt/dynamo/conversion/impl/slice/ops.py @@ -3,6 +3,7 @@ from torch.fx.node import Target from torch_tensorrt.dynamo._SourceIR import SourceIR +from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext from torch_tensorrt.dynamo.conversion.converter_utils import get_positive_dim from torch_tensorrt.dynamo.conversion.impl.slice.base import slice from torch_tensorrt.fx.converters.converter_utils import ( @@ -10,11 +11,11 @@ prepend_ones, set_layer_name, ) -from torch_tensorrt.fx.types import Shape, TRTNetwork, TRTTensor +from torch_tensorrt.fx.types import Shape, TRTTensor def slice_op( # TODO: This should be slice not whatever is in base - network: TRTNetwork, + ctx: ConversionContext, target: Target, source_ir: Optional[SourceIR], name: str, @@ -30,10 +31,10 @@ def slice_op( # TODO: This should be slice not whatever is 
in base "of the TensorRT region!" ) - ranks = len(input.shape) + (1 if network.has_implicit_batch_dimension else 0) + ranks = len(input.shape) + (1 if ctx.net.has_implicit_batch_dimension else 0) dim = get_positive_dim(dim, ranks) dynamic_shape = has_dynamic_shape(input.shape) - if network.has_implicit_batch_dimension: + if ctx.net.has_implicit_batch_dimension: if dim == 0: raise RuntimeError( f"We do not support slice_tensor at batch dim when it's implicit, got {dim}!" @@ -56,12 +57,12 @@ def slice_op( # TODO: This should be slice not whatever is in base output_shape[dim] = math.ceil((stop_int - start_int) / step_int) return slice( - network, target, source_ir, name, input, start_slice, output_shape, stride_slice + ctx, target, source_ir, name, input, start_slice, output_shape, stride_slice ) def expand( - network: TRTNetwork, + ctx: ConversionContext, target: Target, source_ir: Optional[SourceIR], name: str, @@ -79,7 +80,7 @@ def expand( # If the rank of the input tensor is less than the shape's rank, pad with ones if initial_tensor_rank < shape_rank: input_t = prepend_ones( - network, + ctx.net, input_t, name + "_expand_broadcast", shape_rank - initial_tensor_rank, @@ -105,6 +106,6 @@ def expand( stride = tuple( [int(i == o) for i, o in zip(input_tensor_shape, shape)] ) # stride == 1 if dimensions match, 0 otherwise - layer = network.add_slice(input_t, start=start, shape=shape, stride=stride) + layer = ctx.net.add_slice(input_t, start=start, shape=shape, stride=stride) set_layer_name(layer, target, name, source_ir) return layer.get_output(0) diff --git a/py/torch_tensorrt/dynamo/conversion/impl/split.py b/py/torch_tensorrt/dynamo/conversion/impl/split.py index 1785e454e5..0f07ceb7ab 100644 --- a/py/torch_tensorrt/dynamo/conversion/impl/split.py +++ b/py/torch_tensorrt/dynamo/conversion/impl/split.py @@ -1,21 +1,18 @@ -from typing import Any, Dict, List, Optional, Sequence, Tuple, Union +from typing import List, Optional, Sequence, Union -import numpy as np -import torch -import torch_tensorrt as trt -from torch import Tensor from torch.fx.node import Target from torch_tensorrt.dynamo._SourceIR import SourceIR +from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext from torch_tensorrt.dynamo.conversion.impl.shape import get_shape_with_dynamic_shape from torch_tensorrt.fx.converters.converter_utils import ( has_dynamic_shape, set_layer_name, ) -from torch_tensorrt.fx.types import TRTNetwork, TRTTensor +from torch_tensorrt.fx.types import TRTTensor def split( - network: TRTNetwork, + ctx: ConversionContext, target: Target, source_ir: Optional[SourceIR], name: str, @@ -51,7 +48,7 @@ def split( sum_split_sizes = sum(split_sizes) if sum_split_sizes != input.shape[dim]: raise RuntimeError( - f"split sizes don't add up to the tensor's size in the given dimension" + "split sizes don't add up to the tensor's size in the given dimension" ) if num_splits < 1: @@ -68,9 +65,9 @@ def split( start[dim] = offset if dynamic_shape: shape = get_shape_with_dynamic_shape( - network, target, source_ir, f"{name}_shape_{i}", shape, input + ctx, target, source_ir, f"{name}_shape_{i}", shape, input ) - layer = network.add_slice( + layer = ctx.net.add_slice( input, start=start, shape=[] if dynamic_shape else shape, stride=stride ) if dynamic_shape: diff --git a/py/torch_tensorrt/dynamo/conversion/impl/squeeze.py b/py/torch_tensorrt/dynamo/conversion/impl/squeeze.py index 1eb6c3c3aa..cde4fdd90d 100644 --- a/py/torch_tensorrt/dynamo/conversion/impl/squeeze.py +++ 
b/py/torch_tensorrt/dynamo/conversion/impl/squeeze.py @@ -2,14 +2,15 @@ from torch.fx.node import Target from torch_tensorrt.dynamo._SourceIR import SourceIR +from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext from torch_tensorrt.dynamo.conversion.converter_utils import get_positive_dim from torch_tensorrt.fx.converters.converter_utils import set_layer_name -from torch_tensorrt.fx.types import TRTNetwork, TRTTensor +from torch_tensorrt.fx.types import TRTTensor from torch_tensorrt.fx.utils import get_dynamic_dims def squeeze( - network: TRTNetwork, + ctx: ConversionContext, target: Target, source_ir: Optional[SourceIR], name: str, @@ -31,9 +32,9 @@ def squeeze( for dim in dims: dim = get_positive_dim( dim, - len(input.shape) + (1 if network.has_implicit_batch_dimension else 0), + len(input.shape) + (1 if ctx.net.has_implicit_batch_dimension else 0), ) - if network.has_implicit_batch_dimension: + if ctx.net.has_implicit_batch_dimension: assert dim != 0, "We don't support squeeze batch dim when it's implicit." dim -= 1 @@ -48,7 +49,7 @@ def squeeze( if (i in new_dims) and s == 1: continue output_shape.append(s) - layer = network.add_shuffle(input) + layer = ctx.net.add_shuffle(input) layer.reshape_dims = tuple(output_shape) set_layer_name(layer, target, name, source_ir) return layer.get_output(0) diff --git a/py/torch_tensorrt/dynamo/conversion/impl/unary/base.py b/py/torch_tensorrt/dynamo/conversion/impl/unary/base.py index 4c5011eeec..5da8bad252 100644 --- a/py/torch_tensorrt/dynamo/conversion/impl/unary/base.py +++ b/py/torch_tensorrt/dynamo/conversion/impl/unary/base.py @@ -1,15 +1,15 @@ from typing import Optional +import tensorrt as trt from torch.fx.node import Target from torch_tensorrt.dynamo._SourceIR import SourceIR +from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext from torch_tensorrt.fx.converters.converter_utils import set_layer_name -from torch_tensorrt.fx.types import TRTNetwork, TRTTensor - -import tensorrt as trt +from torch_tensorrt.fx.types import TRTTensor def convert_unary( - network: TRTNetwork, + ctx: ConversionContext, target: Target, source_ir: Optional[SourceIR], name: str, @@ -20,7 +20,7 @@ def convert_unary( Add a TensorRT Unary layer to `network`. Args: - network (TRTNetwork): TensorRT network object. + ctx (ConversionContext): TensorRT ConversionContext object. input_val (TRTTensor): Input to the unary op. Must be a TensorRT tensor. op_type (trt.ElementWiseOperation): Type of the TensorRT unary operation. target (Target): Target of fx node. @@ -34,7 +34,7 @@ def convert_unary( f"{operation_type} received input {input_val} that is not part " "of the TensorRT region!" 
) - layer = network.add_unary(input_val, operation_type) + layer = ctx.net.add_unary(input_val, operation_type) set_layer_name(layer, target, name, source_ir) output = layer.get_output(0) kind: str = str(target.__name__) if callable(target) else target diff --git a/py/torch_tensorrt/dynamo/conversion/impl/unary/ops.py b/py/torch_tensorrt/dynamo/conversion/impl/unary/ops.py index 1a52ae7dc6..58c5f6ff4a 100644 --- a/py/torch_tensorrt/dynamo/conversion/impl/unary/ops.py +++ b/py/torch_tensorrt/dynamo/conversion/impl/unary/ops.py @@ -3,13 +3,14 @@ import tensorrt as trt from torch.fx.node import Target from torch_tensorrt.dynamo._SourceIR import SourceIR +from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext from torch_tensorrt.dynamo.conversion.converter_utils import cast_trt_tensor from torch_tensorrt.dynamo.conversion.impl.unary.base import convert_unary -from torch_tensorrt.fx.types import TRTNetwork, TRTTensor +from torch_tensorrt.fx.types import TRTTensor def exp( - network: TRTNetwork, + ctx: ConversionContext, target: Target, source_ir: Optional[SourceIR], name: str, @@ -17,7 +18,7 @@ def exp( ) -> TRTTensor: """ Args: - network (TRTNetwork): TensorRT network object. + ctx (ConversionContext): TensorRT ConversionContext object. target (Target): fx node target. source_ir (SourceIR): Source IR calling the function name (str): Name of the fx node with optional suffix. @@ -29,15 +30,15 @@ def exp( if (isinstance(input_val, TRTTensor)) and ( input_val.dtype == trt.int8 or input_val.dtype == trt.int32 ): - input_val = cast_trt_tensor(network, input_val, trt.float32, name) + input_val = cast_trt_tensor(ctx, input_val, trt.float32, name) return convert_unary( - network, target, source_ir, name, trt.UnaryOperation.EXP, input_val + ctx, target, source_ir, name, trt.UnaryOperation.EXP, input_val ) def log( - network: TRTNetwork, + ctx: ConversionContext, target: Target, source_ir: Optional[SourceIR], name: str, @@ -46,15 +47,15 @@ def log( if (isinstance(input_val, TRTTensor)) and ( input_val.dtype == trt.int8 or input_val.dtype == trt.int32 ): - input_val = cast_trt_tensor(network, input_val, trt.float32, name) + input_val = cast_trt_tensor(ctx, input_val, trt.float32, name) return convert_unary( - network, target, source_ir, name, trt.UnaryOperation.LOG, input_val + ctx, target, source_ir, name, trt.UnaryOperation.LOG, input_val ) def sqrt( - network: TRTNetwork, + ctx: ConversionContext, target: Target, source_ir: Optional[SourceIR], name: str, @@ -63,15 +64,15 @@ def sqrt( if (isinstance(input_val, TRTTensor)) and ( input_val.dtype == trt.int8 or input_val.dtype == trt.int32 ): - input_val = cast_trt_tensor(network, input_val, trt.float32, name) + input_val = cast_trt_tensor(ctx, input_val, trt.float32, name) return convert_unary( - network, target, source_ir, name, trt.UnaryOperation.SQRT, input_val + ctx, target, source_ir, name, trt.UnaryOperation.SQRT, input_val ) def recip( - network: TRTNetwork, + ctx: ConversionContext, target: Target, source_ir: Optional[SourceIR], name: str, @@ -80,27 +81,27 @@ def recip( if (isinstance(input_val, TRTTensor)) and ( input_val.dtype == trt.int8 or input_val.dtype == trt.int32 ): - input_val = cast_trt_tensor(network, input_val, trt.float32, name) + input_val = cast_trt_tensor(ctx, input_val, trt.float32, name) return convert_unary( - network, target, source_ir, name, trt.UnaryOperation.RECIP, input_val + ctx, target, source_ir, name, trt.UnaryOperation.RECIP, input_val ) def abs( - network: TRTNetwork, + ctx: 
ConversionContext, target: Target, source_ir: Optional[SourceIR], name: str, input_val: TRTTensor, ) -> TRTTensor: return convert_unary( - network, target, source_ir, name, trt.UnaryOperation.ABS, input_val + ctx, target, source_ir, name, trt.UnaryOperation.ABS, input_val ) def sin( - network: TRTNetwork, + ctx: ConversionContext, target: Target, source_ir: Optional[SourceIR], name: str, @@ -109,15 +110,15 @@ def sin( if (isinstance(input_val, TRTTensor)) and ( input_val.dtype == trt.int8 or input_val.dtype == trt.int32 ): - input_val = cast_trt_tensor(network, input_val, trt.float32, name) + input_val = cast_trt_tensor(ctx, input_val, trt.float32, name) return convert_unary( - network, target, source_ir, name, trt.UnaryOperation.SIN, input_val + ctx, target, source_ir, name, trt.UnaryOperation.SIN, input_val ) def cos( - network: TRTNetwork, + ctx: ConversionContext, target: Target, source_ir: Optional[SourceIR], name: str, @@ -126,15 +127,15 @@ def cos( if (isinstance(input_val, TRTTensor)) and ( input_val.dtype == trt.int8 or input_val.dtype == trt.int32 ): - input_val = cast_trt_tensor(network, input_val, trt.float32, name) + input_val = cast_trt_tensor(ctx, input_val, trt.float32, name) return convert_unary( - network, target, source_ir, name, trt.UnaryOperation.COS, input_val + ctx, target, source_ir, name, trt.UnaryOperation.COS, input_val ) def tan( - network: TRTNetwork, + ctx: ConversionContext, target: Target, source_ir: Optional[SourceIR], name: str, @@ -143,15 +144,15 @@ def tan( if (isinstance(input_val, TRTTensor)) and ( input_val.dtype == trt.int8 or input_val.dtype == trt.int32 ): - input_val = cast_trt_tensor(network, input_val, trt.float32, name) + input_val = cast_trt_tensor(ctx, input_val, trt.float32, name) return convert_unary( - network, target, source_ir, name, trt.UnaryOperation.TAN, input_val + ctx, target, source_ir, name, trt.UnaryOperation.TAN, input_val ) def sinh( - network: TRTNetwork, + ctx: ConversionContext, target: Target, source_ir: Optional[SourceIR], name: str, @@ -160,15 +161,15 @@ def sinh( if (isinstance(input_val, TRTTensor)) and ( input_val.dtype == trt.int8 or input_val.dtype == trt.int32 ): - input_val = cast_trt_tensor(network, input_val, trt.float32, name) + input_val = cast_trt_tensor(ctx, input_val, trt.float32, name) return convert_unary( - network, target, source_ir, name, trt.UnaryOperation.SINH, input_val + ctx, target, source_ir, name, trt.UnaryOperation.SINH, input_val ) def cosh( - network: TRTNetwork, + ctx: ConversionContext, target: Target, source_ir: Optional[SourceIR], name: str, @@ -177,15 +178,15 @@ def cosh( if (isinstance(input_val, TRTTensor)) and ( input_val.dtype == trt.int8 or input_val.dtype == trt.int32 ): - input_val = cast_trt_tensor(network, input_val, trt.float32, name) + input_val = cast_trt_tensor(ctx, input_val, trt.float32, name) return convert_unary( - network, target, source_ir, name, trt.UnaryOperation.COSH, input_val + ctx, target, source_ir, name, trt.UnaryOperation.COSH, input_val ) def asin( - network: TRTNetwork, + ctx: ConversionContext, target: Target, source_ir: Optional[SourceIR], name: str, @@ -194,15 +195,15 @@ def asin( if (isinstance(input_val, TRTTensor)) and ( input_val.dtype == trt.int8 or input_val.dtype == trt.int32 ): - input_val = cast_trt_tensor(network, input_val, trt.float32, name) + input_val = cast_trt_tensor(ctx, input_val, trt.float32, name) return convert_unary( - network, target, source_ir, name, trt.UnaryOperation.ASIN, input_val + ctx, target, source_ir, name, 
trt.UnaryOperation.ASIN, input_val ) def acos( - network: TRTNetwork, + ctx: ConversionContext, target: Target, source_ir: Optional[SourceIR], name: str, @@ -211,15 +212,15 @@ def acos( if (isinstance(input_val, TRTTensor)) and ( input_val.dtype == trt.int8 or input_val.dtype == trt.int32 ): - input_val = cast_trt_tensor(network, input_val, trt.float32, name) + input_val = cast_trt_tensor(ctx, input_val, trt.float32, name) return convert_unary( - network, target, source_ir, name, trt.UnaryOperation.ACOS, input_val + ctx, target, source_ir, name, trt.UnaryOperation.ACOS, input_val ) def atan( - network: TRTNetwork, + ctx: ConversionContext, target: Target, source_ir: Optional[SourceIR], name: str, @@ -228,15 +229,15 @@ def atan( if (isinstance(input_val, TRTTensor)) and ( input_val.dtype == trt.int8 or input_val.dtype == trt.int32 ): - input_val = cast_trt_tensor(network, input_val, trt.float32, name) + input_val = cast_trt_tensor(ctx, input_val, trt.float32, name) return convert_unary( - network, target, source_ir, name, trt.UnaryOperation.ATAN, input_val + ctx, target, source_ir, name, trt.UnaryOperation.ATAN, input_val ) def asinh( - network: TRTNetwork, + ctx: ConversionContext, target: Target, source_ir: Optional[SourceIR], name: str, @@ -245,15 +246,15 @@ def asinh( if (isinstance(input_val, TRTTensor)) and ( input_val.dtype == trt.int8 or input_val.dtype == trt.int32 ): - input_val = cast_trt_tensor(network, input_val, trt.float32, name) + input_val = cast_trt_tensor(ctx, input_val, trt.float32, name) return convert_unary( - network, target, source_ir, name, trt.UnaryOperation.ASINH, input_val + ctx, target, source_ir, name, trt.UnaryOperation.ASINH, input_val ) def acosh( - network: TRTNetwork, + ctx: ConversionContext, target: Target, source_ir: Optional[SourceIR], name: str, @@ -262,15 +263,15 @@ def acosh( if (isinstance(input_val, TRTTensor)) and ( input_val.dtype == trt.int8 or input_val.dtype == trt.int32 ): - input_val = cast_trt_tensor(network, input_val, trt.float32, name) + input_val = cast_trt_tensor(ctx, input_val, trt.float32, name) return convert_unary( - network, target, source_ir, name, trt.UnaryOperation.ACOSH, input_val + ctx, target, source_ir, name, trt.UnaryOperation.ACOSH, input_val ) def atanh( - network: TRTNetwork, + ctx: ConversionContext, target: Target, source_ir: Optional[SourceIR], name: str, @@ -279,15 +280,15 @@ def atanh( if (isinstance(input_val, TRTTensor)) and ( input_val.dtype == trt.int8 or input_val.dtype == trt.int32 ): - input_val = cast_trt_tensor(network, input_val, trt.float32, name) + input_val = cast_trt_tensor(ctx, input_val, trt.float32, name) return convert_unary( - network, target, source_ir, name, trt.UnaryOperation.ATANH, input_val + ctx, target, source_ir, name, trt.UnaryOperation.ATANH, input_val ) def ceil( - network: TRTNetwork, + ctx: ConversionContext, target: Target, source_ir: Optional[SourceIR], name: str, @@ -296,15 +297,15 @@ def ceil( if (isinstance(input_val, TRTTensor)) and ( input_val.dtype == trt.int8 or input_val.dtype == trt.int32 ): - input_val = cast_trt_tensor(network, input_val, trt.float32, name) + input_val = cast_trt_tensor(ctx, input_val, trt.float32, name) return convert_unary( - network, target, source_ir, name, trt.UnaryOperation.CEIL, input_val + ctx, target, source_ir, name, trt.UnaryOperation.CEIL, input_val ) def floor( - network: TRTNetwork, + ctx: ConversionContext, target: Target, source_ir: Optional[SourceIR], name: str, @@ -313,30 +314,30 @@ def floor( if (isinstance(input_val, TRTTensor)) and 
( input_val.dtype == trt.int8 or input_val.dtype == trt.int32 ): - input_val = cast_trt_tensor(network, input_val, trt.float32, name) + input_val = cast_trt_tensor(ctx, input_val, trt.float32, name) return convert_unary( - network, target, source_ir, name, trt.UnaryOperation.FLOOR, input_val + ctx, target, source_ir, name, trt.UnaryOperation.FLOOR, input_val ) def logical_not( - network: TRTNetwork, + ctx: ConversionContext, target: Target, source_ir: Optional[SourceIR], name: str, input_val: TRTTensor, ) -> TRTTensor: if (isinstance(input_val, TRTTensor)) and input_val.dtype != trt.bool: - input_val = cast_trt_tensor(network, input_val, trt.bool, name) + input_val = cast_trt_tensor(ctx, input_val, trt.bool, name) return convert_unary( - network, target, source_ir, name, trt.UnaryOperation.NOT, input_val + ctx, target, source_ir, name, trt.UnaryOperation.NOT, input_val ) def sign( - network: TRTNetwork, + ctx: ConversionContext, target: Target, source_ir: Optional[SourceIR], name: str, @@ -345,15 +346,15 @@ def sign( if (isinstance(input_val, TRTTensor)) and ( input_val.dtype == trt.int8 or input_val.dtype == trt.int32 ): - input_val = cast_trt_tensor(network, input_val, trt.float32, name) + input_val = cast_trt_tensor(ctx, input_val, trt.float32, name) return convert_unary( - network, target, source_ir, name, trt.UnaryOperation.SIGN, input_val + ctx, target, source_ir, name, trt.UnaryOperation.SIGN, input_val ) def round( - network: TRTNetwork, + ctx: ConversionContext, target: Target, source_ir: Optional[SourceIR], name: str, @@ -362,15 +363,15 @@ def round( if (isinstance(input_val, TRTTensor)) and ( input_val.dtype == trt.int8 or input_val.dtype == trt.int32 ): - input_val = cast_trt_tensor(network, input_val, trt.float32, name) + input_val = cast_trt_tensor(ctx, input_val, trt.float32, name) return convert_unary( - network, target, source_ir, name, trt.UnaryOperation.ROUND, input_val + ctx, target, source_ir, name, trt.UnaryOperation.ROUND, input_val ) def isinf( - network: TRTNetwork, + ctx: ConversionContext, target: Target, source_ir: Optional[SourceIR], name: str, @@ -379,15 +380,15 @@ def isinf( if (isinstance(input_val, TRTTensor)) and ( input_val.dtype == trt.int8 or input_val.dtype == trt.int32 ): - input_val = cast_trt_tensor(network, input_val, trt.float32, name) + input_val = cast_trt_tensor(ctx, input_val, trt.float32, name) return convert_unary( - network, target, source_ir, name, trt.UnaryOperation.ISINF, input_val + ctx, target, source_ir, name, trt.UnaryOperation.ISINF, input_val ) def neg( - network: TRTNetwork, + ctx: ConversionContext, target: Target, source_ir: Optional[SourceIR], name: str, @@ -396,15 +397,15 @@ def neg( if (isinstance(input_val, TRTTensor)) and ( input_val.dtype == trt.int8 or input_val.dtype == trt.int32 ): - input_val = cast_trt_tensor(network, input_val, trt.float32, name) + input_val = cast_trt_tensor(ctx, input_val, trt.float32, name) return convert_unary( - network, target, source_ir, name, trt.UnaryOperation.NEG, input_val + ctx, target, source_ir, name, trt.UnaryOperation.NEG, input_val ) def erf( - network: TRTNetwork, + ctx: ConversionContext, target: Target, source_ir: Optional[SourceIR], name: str, @@ -413,8 +414,8 @@ def erf( if (isinstance(input_val, TRTTensor)) and ( input_val.dtype == trt.int8 or input_val.dtype == trt.int32 ): - input_val = cast_trt_tensor(network, input_val, trt.float32, name) + input_val = cast_trt_tensor(ctx, input_val, trt.float32, name) return convert_unary( - network, target, source_ir, name, 
trt.UnaryOperation.ERF, input_val + ctx, target, source_ir, name, trt.UnaryOperation.ERF, input_val ) diff --git a/py/torch_tensorrt/dynamo/conversion/impl/unsqueeze.py b/py/torch_tensorrt/dynamo/conversion/impl/unsqueeze.py index 4f84973d84..185a985e10 100644 --- a/py/torch_tensorrt/dynamo/conversion/impl/unsqueeze.py +++ b/py/torch_tensorrt/dynamo/conversion/impl/unsqueeze.py @@ -2,24 +2,25 @@ from torch.fx.node import Target from torch_tensorrt.dynamo._SourceIR import SourceIR +from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext from torch_tensorrt.dynamo.conversion.converter_utils import ( get_positive_dim, get_trt_tensor, ) from torch_tensorrt.fx.converters.converter_utils import set_layer_name -from torch_tensorrt.fx.types import Shape, TRTNetwork, TRTTensor +from torch_tensorrt.fx.types import Shape, TRTTensor from torch_tensorrt.fx.utils import get_dynamic_dims def unsqueeze( - network: TRTNetwork, + ctx: ConversionContext, target: Target, source_ir: Optional[SourceIR], name: str, input_t: TRTTensor, dim: Shape, ) -> TRTTensor: - input_val = get_trt_tensor(network, input_t, f"{name}_input_t") + input_val = get_trt_tensor(ctx, input_t, f"{name}_input_t") if not isinstance(input_val, TRTTensor): raise RuntimeError( f"unsqueeze received input {input_val} that is not part " @@ -30,19 +31,19 @@ def unsqueeze( input_shape_size = ( len(input_val.shape) + 1 - if network.has_implicit_batch_dimension + if ctx.net.has_implicit_batch_dimension else len(input_val.shape) ) dim = get_positive_dim(dim, input_shape_size + 1) - if network.has_implicit_batch_dimension: + if ctx.net.has_implicit_batch_dimension: assert dim != 0 dim -= 1 assert ( len(get_dynamic_dims(input_val.shape)) <= 1 ), "Currently we don't support unsqueeze with more than one dynamic dims." - layer = network.add_shuffle(input_val) + layer = ctx.net.add_shuffle(input_val) layer.reshape_dims = ( tuple(input_val.shape)[:dim] + (1,) + tuple(input_val.shape)[dim:] ) diff --git a/py/torch_tensorrt/dynamo/conversion/op_evaluators.py b/py/torch_tensorrt/dynamo/conversion/op_evaluators.py index a546e34305..08285762ce 100644 --- a/py/torch_tensorrt/dynamo/conversion/op_evaluators.py +++ b/py/torch_tensorrt/dynamo/conversion/op_evaluators.py @@ -3,7 +3,8 @@ from typing import Dict, Sequence, Tuple, Union from torch.fx.node import Argument, Node, Target -from torch_tensorrt.fx.types import TRTNetwork, TRTTensor +from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext +from torch_tensorrt.fx.types import TRTTensor from .converter_registry import ConverterRegistry, dynamo_tensorrt_converter @@ -20,7 +21,7 @@ def getitem_validator(getitem_node: Node) -> bool: # TODO: Subsequent evaluators should be registered here with their own validators @dynamo_tensorrt_converter(operator.getitem, capability_validator=getitem_validator) def generic_evaluator( - network: TRTNetwork, + ctx: ConversionContext, target: Target, args: Tuple[Argument, ...], kwargs: Dict[str, Argument], diff --git a/py/torch_tensorrt/dynamo/lowering/substitutions/maxpool1d.py b/py/torch_tensorrt/dynamo/lowering/substitutions/maxpool1d.py index 0fb8e89414..8bb44ac8e0 100644 --- a/py/torch_tensorrt/dynamo/lowering/substitutions/maxpool1d.py +++ b/py/torch_tensorrt/dynamo/lowering/substitutions/maxpool1d.py @@ -65,7 +65,7 @@ def maxpool1d_generic( # "bias": bias, # ... 
# -@register_substitution(torch.nn.MaxPool1d, torch.ops.tensorrt.maxpool1d) # type: ignore +@register_substitution(torch.nn.MaxPool1d, torch.ops.tensorrt.maxpool1d) # type: ignore[misc] def maxpool1d_insertion_fn( gm: torch.fx.GraphModule, node: torch.fx.Node, diff --git a/tests/py/dynamo/backend/test_specialized_models.py b/tests/py/dynamo/backend/test_specialized_models.py index d32171e48b..db9520ccd5 100644 --- a/tests/py/dynamo/backend/test_specialized_models.py +++ b/tests/py/dynamo/backend/test_specialized_models.py @@ -222,6 +222,7 @@ def forward(self, x): inputs, min_block_size=1, pass_through_build_failures=True, + truncate_long_and_double=True, ) optimized_model_results = optimized_model(*inputs).detach().cpu() torch_model_results = fx_graph(*inputs).detach().cpu() diff --git a/tests/py/dynamo/conversion/harness.py b/tests/py/dynamo/conversion/harness.py index c8ea5bb5c0..c997250b5f 100644 --- a/tests/py/dynamo/conversion/harness.py +++ b/tests/py/dynamo/conversion/harness.py @@ -8,6 +8,7 @@ from torch.fx.passes.infra.pass_base import PassResult from torch.testing._internal.common_utils import TestCase from torch_tensorrt import Input +from torch_tensorrt.dynamo._settings import CompilationSettings # Use interpreter, input spec, and test case from fx_ts_compat to test Dynamo Converter Registry from torch_tensorrt.dynamo.conversion import TRTInterpreter @@ -282,10 +283,15 @@ def run_test( pass_tracer = chain_passes(*apply_passes) mod = pass_tracer(mod, inputs) + # Previous instance of the interpreter auto-casted 64-bit inputs + # We replicate this behavior here + compilation_settings = CompilationSettings(truncate_long_and_double=True) + interp = TRTInterpreter( mod, Input.from_tensors(inputs), output_dtypes=output_dtypes, + compilation_settings=compilation_settings, ) super().run_test( mod, @@ -321,10 +327,15 @@ def run_test_with_dynamic_shape( disable_passes=disable_passes, ) + # Previous instance of the interpreter auto-casted 64-bit inputs + # We replicate this behavior here + compilation_settings = CompilationSettings(truncate_long_and_double=True) + interp = TRTInterpreter( mod, input_specs, output_dtypes=output_dtypes, + compilation_settings=compilation_settings, ) # Since the lowering is based on optimal shape. We need to test with # different shape(for ex. 
max shape) for testing dynamic shape diff --git a/tests/py/dynamo/conversion/test_converter_utils.py b/tests/py/dynamo/conversion/test_converter_utils.py new file mode 100644 index 0000000000..b4f1ff2f93 --- /dev/null +++ b/tests/py/dynamo/conversion/test_converter_utils.py @@ -0,0 +1,41 @@ +import numpy as np +import torch +from torch.testing._internal.common_utils import TestCase, run_tests +from torch_tensorrt.dynamo.conversion.converter_utils import enforce_tensor_types +from torch_tensorrt.fx.types import TRTTensor + +from ..testing_utilities import DECIMALS_OF_AGREEMENT, lower_graph_testing + + +class TestTensorTypeEnforcement(TestCase): + def test_valid_type_no_promotion(self): + @enforce_tensor_types({0: (np.ndarray, torch.Tensor)}, promote=False) + def fake_converter(network, target, args, kwargs, name): + self.assertIsInstance(args[0], np.ndarray) + return + + fake_converter(None, None, (np.ones((4, 4)),), {}, "fake") + + def test_different_type_no_promotion(self): + @enforce_tensor_types({0: (TRTTensor,)}, promote=False) + def fake_converter(network, target, args, kwargs, name): + return + + with self.assertRaises(AssertionError): + fake_converter(None, None, (np.ones((4, 4)),), {}, "fake") + + def test_different_type_with_promotion(self): + @enforce_tensor_types({"sample": (np.ndarray,)}, promote=True) + def fake_converter(network, target, args, kwargs, name): + self.assertIsInstance(kwargs["sample"], np.ndarray) + return + + fake_converter(None, None, tuple(), {"sample": torch.ones((4, 4))}, "fake") + + def test_invalid_invocation_type(self): + with self.assertRaises(AssertionError): + enforce_tensor_types({0: (int, bool)}) + + +if __name__ == "__main__": + run_tests()
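
Note: the new test file above is the only documentation of enforce_tensor_types in this change, so the following is a rough usage sketch inferred from those tests rather than an example taken from the patch itself. The decorator maps positional indices or keyword names to the tensor types a converter accepts; with promote=True a mismatched argument is converted to an accepted type before the converter body runs, and listing non-tensor types such as int or bool fails with an AssertionError at decoration time. The converter stub and the "weight" keyword below are hypothetical.

import numpy as np
import torch
from torch_tensorrt.dynamo.conversion.converter_utils import enforce_tensor_types

# Hypothetical converter stub: the "weight" kwarg must arrive as np.ndarray;
# promote=True converts a torch.Tensor passed by the caller, mirroring
# test_different_type_with_promotion above.
@enforce_tensor_types({"weight": (np.ndarray,)}, promote=True)
def fake_converter(network, target, args, kwargs, name):
    assert isinstance(kwargs["weight"], np.ndarray)
    return kwargs["weight"]

# Analogous to the tests: a torch.Tensor weight is promoted before the body runs
fake_converter(None, None, tuple(), {"weight": torch.ones((4, 4))}, "fake")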
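
More broadly, every impl function touched by this patch now takes a ConversionContext as its first argument and reaches the TensorRT network through ctx.net. As a rough illustration of how a converter registration could look against the new signatures, here is a minimal sketch modeled on generic_evaluator in op_evaluators.py and on impl/unary/ops.py; the choice of aten.exp as the target and the registration itself are illustrative assumptions, not part of this change.

from typing import Dict, Sequence, Tuple, Union

import torch
from torch.fx.node import Argument, Target
from torch_tensorrt.dynamo._SourceIR import SourceIR
from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext
from torch_tensorrt.dynamo.conversion.converter_registry import dynamo_tensorrt_converter
from torch_tensorrt.dynamo.conversion.impl.unary import ops as unary_ops
from torch_tensorrt.fx.types import TRTTensor

# Illustrative registration of an aten.exp converter using the context-based
# signature; the decorator call mirrors the one used for operator.getitem in
# op_evaluators.py.
@dynamo_tensorrt_converter(torch.ops.aten.exp.default)
def aten_ops_exp(
    ctx: ConversionContext,
    target: Target,
    args: Tuple[Argument, ...],
    kwargs: Dict[str, Argument],
    name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
    # unary_ops.exp adds the TensorRT unary layer through ctx.net internally
    return unary_ops.exp(ctx, target, SourceIR.ATEN, name, args[0])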