diff --git a/examples/dynamo/register_sdpa.py b/examples/dynamo/register_sdpa.py
new file mode 100644
index 0000000000..7436f31939
--- /dev/null
+++ b/examples/dynamo/register_sdpa.py
@@ -0,0 +1,119 @@
+import copy
+import logging
+import operator
+from typing import Callable, Sequence, Tuple
+
+import torch
+from sdpa_converter import *
+from torch_tensorrt.dynamo._settings import CompilationSettings
+from torch_tensorrt.dynamo.conversion.aten_ops_converters import args_bounds_check
+from torch_tensorrt.dynamo.lowering import TORCH_TRT_DECOMPOSITIONS
+from torch_tensorrt.dynamo.lowering.passes._aten_lowering_pass import (
+    _aten_lowering_pass,
+)
+from torch_tensorrt.dynamo.lowering.passes.pass_utils import (
+    clean_up_graph_after_modifications,
+)
+
+logger = logging.getLogger(__name__)
+
+# Remove decompositions for aten.scaled_dot_product_attention, aten._scaled_dot_product_efficient_attention, aten._scaled_dot_product_flash_attention
+# This is because we want to keep SDPA as a standalone operator in the graph and invoke the custom converter for it.
+TORCH_TRT_DECOMPOSITIONS.pop(torch.ops.aten.scaled_dot_product_attention.default)
+TORCH_TRT_DECOMPOSITIONS.pop(
+    torch.ops.aten._scaled_dot_product_efficient_attention.default
+)
+TORCH_TRT_DECOMPOSITIONS.pop(torch.ops.aten._scaled_dot_product_flash_attention.default)
+
+REPLACEABLE_ATEN_OPS = {
+    torch.ops.aten._scaled_dot_product_efficient_attention.default,
+    torch.ops.aten._scaled_dot_product_flash_attention.default,
+}
+
+
+@_aten_lowering_pass
+def replace_variants_of_sdpa(
+    gm: torch.fx.GraphModule, settings: CompilationSettings
+) -> torch.fx.GraphModule:
+    """Replace the scaled_dot_product_attention variants with an equivalent
+    implementation which can be accurately converted to TRT
+    """
+    attn_mask = None
+    is_causal = True
+    for node in gm.graph.nodes:
+        if node.op == "call_function" and node.target in REPLACEABLE_ATEN_OPS:
+            if (
+                node.target
+                == torch.ops.aten._scaled_dot_product_efficient_attention.default
+            ):
+                if len(node.args) == 7:
+                    (
+                        query,
+                        key,
+                        value,
+                        attn_bias,
+                        compute_log_sumexp,
+                        dropout_p,
+                        is_causal,
+                    ) = node.args
+                elif len(node.args) == 5:
+                    query, key, value, attn_mask, is_causal = node.args
+                    dropout_p = 0.0
+                else:
+                    raise ValueError(
+                        f"Unexpected number of arguments for {node.target} in the graph"
+                    )
+            elif (
+                node.target
+                == torch.ops.aten._scaled_dot_product_flash_attention.default
+            ):
+                if len(node.args) == 6:
+                    query, key, value, dropout_p, is_causal, return_debug_mask = (
+                        node.args
+                    )
+                elif len(node.args) == 3:
+                    query, key, value = node.args
+                    dropout_p = 0.0
+                    is_causal = True
+                else:
+                    raise ValueError(
+                        f"Unexpected number of arguments for {node.target} in the graph"
+                    )
+            if attn_mask is not None:
+                logger.warning(
+                    f"The current version of the SDPA converter does not support attn_mask for {node.target} in the graph. Ignoring it and using the is_causal=True configuration."
+                )
+
+            modified_input_args = (query, key, value, None, dropout_p, is_causal)
+
+            # Create a new node with torch.nn.functional.scaled_dot_product_attention
+            # The new args are (query, key, value, attn_mask=None, dropout_p, is_causal); scale is passed via kwargs
+            with gm.graph.inserting_after(node):
+                new_node = gm.graph.call_function(
+                    torch.nn.functional.scaled_dot_product_attention,
+                    args=modified_input_args,
+                    kwargs={"scale": node.kwargs.get("scale", None)},
+                )
+
+                # Deep copy encounters RuntimeError: Cannot access data pointer of Tensor (e.g. FakeTensor, FunctionalTensor). So we use a shallow copy instead.
+                new_node.meta = copy.copy(node.meta)
+                # Check if there's a getitem node following this attention node
+                for user in list(node.users):
+                    if user.op == "call_function" and user.target == operator.getitem:
+                        # If the getitem is extracting the first element (the output tensor)
+                        if user.args[1] == 0:
+                            # Replace all uses of the getitem with the new attention node
+                            user.replace_all_uses_with(new_node)
+                            new_node.meta["val"] = new_node.meta["val"][0]
+                # Replace all uses of the original node with the new node
+                node.replace_all_uses_with(new_node)
+
+            gm.graph.erase_node(node)
+
+    # Clean up the graph
+    clean_up_graph_after_modifications(gm)
+
+    logger.info(
+        "Replaced variants of scaled_dot_product_attention with torch.nn.functional.scaled_dot_product_attention"
+    )
+    return gm
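A minimal usage sketch (not part of this diff): importing register_sdpa is what pops the decompositions and registers the lowering pass, so the import must run before torch_tensorrt compilation is invoked. The module, tensor shapes, and precision below are illustrative, and the script is assumed to run from examples/dynamo so that register_sdpa and sdpa_converter are importable.

    import torch
    import torch.nn.functional as F
    import torch_tensorrt

    import register_sdpa  # noqa: F401  # side effect: pops decompositions, registers pass + converter


    class TinyAttention(torch.nn.Module):
        def forward(self, q, k, v):
            # Stays a single SDPA node in the lowered graph instead of being decomposed
            return F.scaled_dot_product_attention(q, k, v, is_causal=True)


    q = k = v = torch.randn(1, 8, 16, 64, device="cuda", dtype=torch.float16)
    ep = torch.export.export(TinyAttention().cuda().half(), (q, k, v))
    trt_mod = torch_tensorrt.dynamo.compile(
        ep, inputs=[q, k, v], enabled_precisions={torch.float16}
    )
    print(trt_mod(q, k, v).shape)
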
diff --git a/examples/dynamo/sdpa_converter.py b/examples/dynamo/sdpa_converter.py
new file mode 100644
index 0000000000..903324dff5
--- /dev/null
+++ b/examples/dynamo/sdpa_converter.py
@@ -0,0 +1,176 @@
+import logging
+import math
+from typing import Any, Dict, Optional, Tuple, Union
+
+import numpy as np
+import tensorrt as trt
+import torch
+import torch_tensorrt
+from torch.fx.node import Target
+from torch_tensorrt._enums import dtype
+from torch_tensorrt.dynamo.conversion import impl
+from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext
+from torch_tensorrt.dynamo.conversion.converter_utils import (
+    SourceIR,
+    cast_trt_tensor,
+    get_trt_tensor,
+)
+from torch_tensorrt.fx.types import TRTTensor
+
+logger = logging.getLogger(__name__)
+
+
+def tril(
+    ctx: ConversionContext,
+    target: Union[Target, str],
+    source_ir: Optional[SourceIR],
+    name: str,
+    row: TRTTensor,
+    col: TRTTensor,
+) -> TRTTensor:
+    row_arange_tensor = impl.arange.arange(
+        ctx, target, source_ir, name + "_arange_row", start=0, end=row, step=1
+    )
+    row_reshape_tensor = impl.shuffle.reshape(
+        ctx, target, source_ir, name + "_reshape_row", row_arange_tensor, [row, 1]
+    )
+
+    col_arange_tensor = impl.arange.arange(
+        ctx, target, source_ir, name + "_arange_col", start=0, end=col, step=1
+    )
+    col_reshape_tensor = impl.shuffle.reshape(
+        ctx, target, source_ir, name + "_reshape_col", col_arange_tensor, [1, col]
+    )
+
+    mask = impl.elementwise.ge(
+        ctx, target, source_ir, name + "_ge", row_reshape_tensor, col_reshape_tensor
+    )
+    return mask
+
+
+@torch_tensorrt.dynamo.conversion.dynamo_tensorrt_converter(
+    torch.nn.functional.scaled_dot_product_attention,
+    enabled=True,
+    supports_dynamic_shapes=True,
+)
+def scaled_dot_product_attention(
+    ctx: torch_tensorrt.dynamo.conversion.ConversionContext,
+    target: Target,
+    args: Tuple[Any, ...],
+    kwargs: Dict[str, Any],
+    name: str,
+) -> TRTTensor:
+    # TODO: Handle attn_mask and is_causal arguments in the future
+    query, key, value, attn_mask, dropout_p, is_causal = args
+    logger.info(
+        "Ignoring attn_mask and is_causal arguments provided by the original graph. "
+        "This converter expects is_causal to be an input to the graph. For the prefill phase, is_causal=True, "
+        "and for the generate phase, is_causal=False since we pass only one input token at a time."
+    )
+
+    # TODO: remove this once we have a better way to handle the causal mask
+    scale = kwargs.get("scale", None)
+    source_ir = SourceIR.ATEN
+    # Implementation as described here: https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html
+    mm = impl.matmul.matrix_multiply(
+        ctx,
+        target,
+        source_ir,
+        name + "_mm",
+        query,
+        key,
+        other_matrix_op=trt.MatrixOperation.TRANSPOSE,
+    )
+    if scale is None:
+        scale = query.shape[-1]
+        if scale < 0:
+            # dynamic shape
+            scale = impl.shape.shape(ctx, target, source_ir, name + "_shape", query, -1)
+            sqrt_scaled = impl.unary.sqrt(ctx, target, source_ir, name + "_sqrt", scale)
+        else:
+            # static shape
+            sqrt_scaled = math.sqrt(scale)
+        scaled = impl.elementwise.div(
+            ctx,
+            target,
+            source_ir,
+            name + "_scale",
+            mm,
+            sqrt_scaled,
+        )
+    else:
+        scaled = impl.elementwise.mul(
+            ctx,
+            target,
+            source_ir,
+            name + "_scale",
+            mm,
+            scale,
+        )
+
+    # If is_causal is True, we need to generate a causal mask
+    if is_causal:
+        L, S = query.shape[-2], key.shape[-2]
+        if L >= 0 and S >= 0:
+            # static shape
+            attn_bias = np.zeros((L, S), dtype=dtype._from(query.dtype).to(np.dtype))
+            temp_mask = np.logical_not(np.tril(np.ones((L, S), dtype=np.bool_), k=0))
+            attn_bias = np.ma.array(attn_bias, mask=temp_mask).filled(float("-inf"))
+            attn_bias = get_trt_tensor(ctx, attn_bias, name + "_attn_bias")
+        else:
+            # if either L or S is a dynamic shape
+            if L < 0:
+                L = impl.shape.shape(
+                    ctx, target, source_ir, name + "_shape_0", query, 2
+                )
+            if S < 0:
+                S = impl.shape.shape(ctx, target, source_ir, name + "_shape_1", key, 2)
+
+            # generate the mask tensor
+            tril_tensor = tril(ctx, target, source_ir, name + "_tril", L, S)
+
+            temp_mask = impl.unary.logical_not(
+                ctx, target, source_ir, name + "_logical_not", tril_tensor
+            )
+            temp_mask_casted = cast_trt_tensor(
+                ctx, temp_mask, trt.float32, name + "_casted_bool", target, source_ir
+            )
+            one_minus_temp_mask = impl.elementwise.sub(
+                ctx,
+                target,
+                source_ir,
+                name + "_one_minus_temp_mask",
+                1.0,
+                temp_mask_casted,
+            )
+            attn_bias = impl.unary.log(
+                ctx, target, source_ir, name + "_log", one_minus_temp_mask
+            )
+
+        scaled_add_attn_bias = impl.elementwise.add(
+            ctx, target, source_ir, name + "_attn_bias_add", scaled, attn_bias
+        )
+    else:
+        scaled_add_attn_bias = scaled
+
+    # If is_causal is a runtime tensor, add an if-conditional to select between the masked and unmasked branches
+    if isinstance(is_causal, TRTTensor):
+        if_layer = ctx.net.add_if_conditional()
+        condition, true_branch, false_branch = is_causal, scaled_add_attn_bias, scaled
+        if_layer.set_condition(condition)
+        output_layer = if_layer.add_output(true_branch, false_branch)
+        scaled_add_attn_bias = output_layer.get_output(0)
+
+    softmax = impl.normalization.softmax(
+        ctx, target, source_ir, name + "_softmax", scaled_add_attn_bias, -1, False
+    )
+    out = impl.matmul.matrix_multiply(
+        ctx,
+        target,
+        source_ir,
+        name + "_out",
+        softmax,
+        value,
+    )
+
+    return out
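For reference, the static-shape branch above follows the masking recipe from the PyTorch scaled_dot_product_attention documentation. The small NumPy-only reproduction below (with made-up sequence lengths, not part of the diff) shows the bias that gets baked into the engine:

    import numpy as np

    L, S = 4, 4  # illustrative query / key sequence lengths
    attn_bias = np.zeros((L, S), dtype=np.float32)
    temp_mask = np.logical_not(np.tril(np.ones((L, S), dtype=np.bool_), k=0))
    attn_bias = np.ma.array(attn_bias, mask=temp_mask).filled(float("-inf"))
    print(attn_bias)
    # [[  0. -inf -inf -inf]
    #  [  0.   0. -inf -inf]
    #  [  0.   0.   0. -inf]
    #  [  0.   0.   0.   0.]]

softmax(QK^T / sqrt(d) + attn_bias) therefore assigns zero weight to future positions. The dynamic-shape branch builds the same pattern at runtime: log(1 - logical_not(tril)) evaluates to log(1) = 0 for allowed positions and log(0) = -inf for masked ones.
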
diff --git a/examples/dynamo/torch_export_flux_dev.py b/examples/dynamo/torch_export_flux_dev.py
index 3891fcbb9a..32e75b06d8 100644
--- a/examples/dynamo/torch_export_flux_dev.py
+++ b/examples/dynamo/torch_export_flux_dev.py
@@ -19,6 +19,8 @@
 we demonstrate optimizing the ``transformer`` component of the model (which typically
 consumes >95% of the e2e diffusion latency)
 """
+import register_sdpa  # Register SDPA as a standalone operator
+
 # %%
 # Import the following libraries
 # -----------------------------
diff --git a/py/torch_tensorrt/dynamo/lowering/__init__.py b/py/torch_tensorrt/dynamo/lowering/__init__.py
index b276a7542d..bec5e407b5 100644
--- a/py/torch_tensorrt/dynamo/lowering/__init__.py
+++ b/py/torch_tensorrt/dynamo/lowering/__init__.py
@@ -1,4 +1,5 @@
 from ._decomposition_groups import (
+    TORCH_TRT_DECOMPOSITIONS,
     torch_disabled_decompositions,
     torch_enabled_decompositions,
 )
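Re-exporting TORCH_TRT_DECOMPOSITIONS is what lets scripts outside the library edit the decomposition table before compilation, exactly as register_sdpa.py does above. A minimal sketch of that pattern (the pop(..., None) default is a defensive variant, not part of the diff):

    import torch
    from torch_tensorrt.dynamo.lowering import TORCH_TRT_DECOMPOSITIONS

    # Drop the SDPA decomposition so the op reaches the custom converter as a single node;
    # the None default keeps this safe to re-run if the entry was already removed.
    TORCH_TRT_DECOMPOSITIONS.pop(torch.ops.aten.scaled_dot_product_attention.default, None)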