diff --git a/tensorrt_llm/_torch/auto_deploy/custom_ops/README.md b/tensorrt_llm/_torch/auto_deploy/custom_ops/README.md
new file mode 100644
index 00000000000..6bef175199b
--- /dev/null
+++ b/tensorrt_llm/_torch/auto_deploy/custom_ops/README.md
@@ -0,0 +1,42 @@
+## AutoDeploy Custom Operators
+
+All AutoDeploy custom operators follow this naming convention:
+
+`torch.ops.auto_deploy.<backend>_<op_name>`
+
+The table below lists the available operators, grouped by backend.
+
+### Available Custom Operators
+
+| Operator Name | Description |
+|--------------|-------------|
+| `torch.ops.auto_deploy.flashinfer_attention_mha_with_cache` | FlashInfer attention with KV cache support |
+| `torch.ops.auto_deploy.flashinfer_rope` | FlashInfer RoPE (Rotary Position Embedding) implementation |
+| `torch.ops.auto_deploy.torch_attention_bsnd_grouped_sdpa` | Grouped SDPA (Scaled Dot Product Attention) with BSND format |
+| `torch.ops.auto_deploy.torch_attention_deepseek_fused_mla` | DeepSeek fused MLA (Multi-head Latent Attention) |
+| `torch.ops.auto_deploy.torch_attention_deepseek_mla` | DeepSeek MLA implementation |
+| `torch.ops.auto_deploy.torch_attention_grouped_sdpa` | Grouped SDPA implementation |
+| `torch.ops.auto_deploy.torch_attention_repeat_kv` | KV repetition for attention |
+| `torch.ops.auto_deploy.torch_attention_sdpa` | Standard SDPA implementation |
+| `torch.ops.auto_deploy.torch_dist_all_gather` | Distributed all-gather operation |
+| `torch.ops.auto_deploy.torch_dist_all_reduce` | Distributed all-reduce operation |
+| `torch.ops.auto_deploy.torch_linear_simple` | Simple linear layer implementation |
+| `torch.ops.auto_deploy.torch_moe` | Mixture of Experts implementation |
+| `torch.ops.auto_deploy.torch_moe_fused` | Fused Mixture of Experts implementation |
+| `torch.ops.auto_deploy.torch_quant_fn` | Generic quantization function that scales, rounds, and clamps input values |
+| `torch.ops.auto_deploy.torch_quant_fused_fp8_linear_all_reduce` | Fused FP8 linear layer followed by all-reduce operation |
+| `torch.ops.auto_deploy.torch_quant_fp4_linear` | FP4 quantized linear layer |
+| `torch.ops.auto_deploy.torch_quant_fp8_linear` | FP8 quantized linear layer |
+| `torch.ops.auto_deploy.torch_rope_with_complex_freqs` | RoPE with complex frequencies |
+| `torch.ops.auto_deploy.torch_rope_with_explicit_cos_sin` | RoPE with explicit cosine/sine |
+| `torch.ops.auto_deploy.torch_rope_with_qk_interleaving` | RoPE with QK interleaving |
+| `torch.ops.auto_deploy.triton_attention_fused_flattened_mha_with_cache` | Triton fused flattened MHA with cache |
+| `torch.ops.auto_deploy.triton_attention_fused_flattened_mha_with_cache_rope_fusion` | Triton fused flattened MHA with cache and RoPE fusion |
+| `torch.ops.auto_deploy.triton_attention_fused_mha_with_cache` | Triton fused MHA with cache |
+| `torch.ops.auto_deploy.triton_attention_fused_mha_with_paged_cache` | Triton fused MHA with paged cache |
+| `torch.ops.auto_deploy.triton_attention_flattened_mha_with_cache` | Triton flattened MHA with cache |
+| `torch.ops.auto_deploy.triton_attention_fused_flattened_mla_with_cache` | Triton fused flattened Multi-head Latent Attention with cache support |
+| `torch.ops.auto_deploy.triton_rope_on_flattened_inputs` | Triton RoPE on flattened inputs |
+| `torch.ops.auto_deploy.triton_rope_with_input_pos` | Triton RoPE with input positions |
+| `torch.ops.auto_deploy.trtllm_moe_fused` | TensorRT-LLM fused MoE implementation |
+| `torch.ops.auto_deploy.trtllm_dist_fused_linear_all_reduce` | TensorRT-LLM fused linear
layer followed by all-reduce operation | diff --git a/tensorrt_llm/_torch/auto_deploy/custom_ops/__init__.py b/tensorrt_llm/_torch/auto_deploy/custom_ops/__init__.py index 3236e0267c3..f80d1e5ca91 100644 --- a/tensorrt_llm/_torch/auto_deploy/custom_ops/__init__.py +++ b/tensorrt_llm/_torch/auto_deploy/custom_ops/__init__.py @@ -4,11 +4,12 @@ from .dist import * from .flashinfer_attention import * from .flashinfer_rope import * -from .fused_moe import * from .linear import * from .mla import * from .quant import * -from .rope import * from .torch_attention import * +from .torch_moe import * from .torch_rope import * from .triton_attention import * +from .triton_rope import * +from .trtllm_moe import * diff --git a/tensorrt_llm/_torch/auto_deploy/custom_ops/_triton_attention_internal.py b/tensorrt_llm/_torch/auto_deploy/custom_ops/_triton_attention_internal.py index 07b5d46aedb..18452d3b417 100644 --- a/tensorrt_llm/_torch/auto_deploy/custom_ops/_triton_attention_internal.py +++ b/tensorrt_llm/_torch/auto_deploy/custom_ops/_triton_attention_internal.py @@ -168,7 +168,9 @@ def _paged_context_mha( ) -@torch.library.custom_op("attention::fused_mha_with_paged_cache", mutates_args=()) +@torch.library.custom_op( + "auto_deploy::triton_attention_fused_mha_with_paged_cache", mutates_args=() +) def fused_mha_with_paged_cache( q: torch.Tensor, k: torch.Tensor, @@ -210,10 +212,10 @@ def fused_mha_with_paged_cache( if freqs_cis is not None: if s == 1: rope_args = (freqs_cis, input_pos, "bsnd") - fn_rope = torch.ops.rope.apply_rope_with_input_pos + fn_rope = torch.ops.auto_deploy.triton_rope_with_input_pos else: rope_args = (freqs_cis, input_pos, seq_len, seq_start) - fn_rope = torch.ops.rope.apply_rope_on_flattened_inputs + fn_rope = torch.ops.auto_deploy.triton_rope_on_flattened_inputs q = fn_rope(q, *rope_args) k = fn_rope(k, *rope_args) @@ -416,7 +418,9 @@ def _flattened_context_mha_rope_fusion( ) -@torch.library.custom_op("attention::fused_flattened_mha_with_cache_rope_fusion", mutates_args=()) +@torch.library.custom_op( + "auto_deploy::triton_attention_fused_flattened_mha_with_cache_rope_fusion", mutates_args=() +) def fused_flattened_mha_with_cache_rope_fusion( q: torch.Tensor, k: torch.Tensor, @@ -541,7 +545,7 @@ def _context_mha( ) -@torch.library.custom_op("attention::fused_mha_with_cache", mutates_args=()) +@torch.library.custom_op("auto_deploy::triton_attention_fused_mha_with_cache", mutates_args=()) def fused_mha_with_cache( q: torch.Tensor, k: torch.Tensor, @@ -563,8 +567,8 @@ def fused_mha_with_cache( # rope embedding if freqs_cis is not None: - q = torch.ops.rope.apply_rope_with_input_pos(q, freqs_cis, input_pos, "bsnd") - k = torch.ops.rope.apply_rope_with_input_pos(k, freqs_cis, input_pos, "bsnd") + q = torch.ops.auto_deploy.triton_rope_with_input_pos(q, freqs_cis, input_pos, "bsnd") + k = torch.ops.auto_deploy.triton_rope_with_input_pos(k, freqs_cis, input_pos, "bsnd") # attention (assumed layout is bsnd) y = torch.empty_like(q) @@ -593,7 +597,9 @@ def fused_mha_fake( return torch.empty_like(q.contiguous()) -@torch.library.custom_op("attention::fused_flattened_mha_with_cache", mutates_args=()) +@torch.library.custom_op( + "auto_deploy::triton_attention_fused_flattened_mha_with_cache", mutates_args=() +) def fused_flattened_mha_with_cache( # Q, K, V q: torch.Tensor, @@ -638,10 +644,10 @@ def fused_flattened_mha_with_cache( if freqs_cis.numel() > 0: if s == 1: rope_args = (freqs_cis, input_pos, "bsnd") - fn_rope = torch.ops.rope.apply_rope_with_input_pos + fn_rope = 
torch.ops.auto_deploy.triton_rope_with_input_pos else: rope_args = (freqs_cis, input_pos, seq_len, seq_start) - fn_rope = torch.ops.rope.apply_rope_on_flattened_inputs + fn_rope = torch.ops.auto_deploy.triton_rope_on_flattened_inputs q = fn_rope(q, *rope_args) k = fn_rope(k, *rope_args) diff --git a/tensorrt_llm/_torch/auto_deploy/custom_ops/dist.py b/tensorrt_llm/_torch/auto_deploy/custom_ops/dist.py index 755638852e4..d6f13fbedd7 100644 --- a/tensorrt_llm/_torch/auto_deploy/custom_ops/dist.py +++ b/tensorrt_llm/_torch/auto_deploy/custom_ops/dist.py @@ -8,7 +8,7 @@ from ..distributed import trtllm as trtllm_dist -@torch.library.custom_op("dist::all_gather", mutates_args=(), device_types="cuda") +@torch.library.custom_op("auto_deploy::torch_dist_all_gather", mutates_args=(), device_types="cuda") def all_gather( tensor: torch.Tensor, dim: int = 0, sizes: Optional[List[int]] = None ) -> torch.Tensor: @@ -25,7 +25,7 @@ def all_gather_fake(tensor, dim=0): return torch.cat([torch.empty_like(tensor) for _ in range(dist.get_world_size())], dim=dim) -@torch.library.custom_op("dist::all_reduce", mutates_args=(), device_types="cuda") +@torch.library.custom_op("auto_deploy::torch_dist_all_reduce", mutates_args=(), device_types="cuda") def all_reduce(t: torch.Tensor) -> torch.Tensor: """All_reduce across the ranks. Reduction op is SUM. diff --git a/tensorrt_llm/_torch/auto_deploy/custom_ops/flashinfer_attention.py b/tensorrt_llm/_torch/auto_deploy/custom_ops/flashinfer_attention.py index 8003b7f10a0..6682299a656 100644 --- a/tensorrt_llm/_torch/auto_deploy/custom_ops/flashinfer_attention.py +++ b/tensorrt_llm/_torch/auto_deploy/custom_ops/flashinfer_attention.py @@ -153,7 +153,7 @@ def _plan_decode(wrapper: flashinfer.BatchDecodeWithPagedKVCacheWrapper): _GlobalFlashInferPlanner = _FlashInferPlanner() -@torch.library.custom_op("attention::prepare_flashinfer_metadata", mutates_args=()) +@torch.library.custom_op("auto_deploy::flashinfer_attention_prepare_metadata", mutates_args=()) def prepare_flashinfer_metadata( input_ids: torch.Tensor, position_ids: torch.Tensor, @@ -228,7 +228,7 @@ def prepare_flashinfer_metadata_fake( ) -@torch.library.custom_op("attention::flashinfer_mha_with_cache", mutates_args=()) +@torch.library.custom_op("auto_deploy::flashinfer_attention_mha_with_cache", mutates_args=()) def flashinfer_mha_with_cache( # Q, K, V q: torch.Tensor, @@ -355,15 +355,15 @@ def get_num_qkv_args(cls) -> int: @classmethod def get_source_attention_op(cls) -> OpOverloadPacket: """Get the source attention op that we target for replacement.""" - return torch.ops.attention.bsnd_grouped_sdpa + return torch.ops.auto_deploy.torch_attention_bsnd_grouped_sdpa @classmethod def get_cached_attention_op(cls) -> MHACallable: - return torch.ops.attention.flashinfer_mha_with_cache + return torch.ops.auto_deploy.flashinfer_attention_mha_with_cache @classmethod def get_prepare_metadata_op(cls) -> Tuple[PrepareMetadataCallable, int]: - return torch.ops.attention.prepare_flashinfer_metadata, 6 + return torch.ops.auto_deploy.flashinfer_attention_prepare_metadata, 6 @classmethod def get_cache_initializers( diff --git a/tensorrt_llm/_torch/auto_deploy/custom_ops/flashinfer_rope.py b/tensorrt_llm/_torch/auto_deploy/custom_ops/flashinfer_rope.py index 4746e6fb124..dd65701ec48 100644 --- a/tensorrt_llm/_torch/auto_deploy/custom_ops/flashinfer_rope.py +++ b/tensorrt_llm/_torch/auto_deploy/custom_ops/flashinfer_rope.py @@ -4,7 +4,7 @@ import torch -@torch.library.custom_op("rope::flashinfer", mutates_args=()) 
+@torch.library.custom_op("auto_deploy::flashinfer_rope", mutates_args=()) def apply_rope_with_input_pos_flashinfer( q: torch.Tensor, k: torch.Tensor, diff --git a/tensorrt_llm/_torch/auto_deploy/custom_ops/linear.py b/tensorrt_llm/_torch/auto_deploy/custom_ops/linear.py index 26e5ce37124..fda48e4ba57 100644 --- a/tensorrt_llm/_torch/auto_deploy/custom_ops/linear.py +++ b/tensorrt_llm/_torch/auto_deploy/custom_ops/linear.py @@ -8,7 +8,7 @@ from ..distributed import trtllm as trtllm_dist -@torch.library.custom_op("linear::simple", mutates_args=()) +@torch.library.custom_op("auto_deploy::torch_linear_simple", mutates_args=()) def simple(input: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor]) -> torch.Tensor: """A wrapper for the linear functional to control how it is exposed. @@ -30,7 +30,9 @@ def simple_fake(input, weight, bias): return torch.ops.aten.linear(input, weight, bias) -@torch.library.custom_op("linear::fused_linear_all_reduce", mutates_args=(), device_types="cuda") +@torch.library.custom_op( + "auto_deploy::trtllm_dist_fused_linear_all_reduce", mutates_args=(), device_types="cuda" +) def fused_linear_all_reduce( input: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor] ) -> torch.Tensor: diff --git a/tensorrt_llm/_torch/auto_deploy/custom_ops/mla.py b/tensorrt_llm/_torch/auto_deploy/custom_ops/mla.py index c077d66b585..7104916feff 100644 --- a/tensorrt_llm/_torch/auto_deploy/custom_ops/mla.py +++ b/tensorrt_llm/_torch/auto_deploy/custom_ops/mla.py @@ -22,7 +22,9 @@ Constant = Union[int, float, str, None] -@torch.library.custom_op("attention::fused_flattened_mla_with_cache", mutates_args=()) +@torch.library.custom_op( + "auto_deploy::triton_attention_fused_flattened_mla_with_cache", mutates_args=() +) def fused_flattened_mla_with_cache( # Q, K, V q_nope: torch.Tensor, @@ -94,7 +96,7 @@ def fused_flattened_mla_with_cache( q_slice = q_pe[start : start + length] k_slice = k_pe[start : start + length] - q_rot, k_rot = torch.ops.rope.torch_apply_rope_with_qk_interleaving( + q_rot, k_rot = torch.ops.auto_deploy.torch_rope_with_qk_interleaving( q_slice, k_slice, cos, @@ -169,7 +171,9 @@ def fused_flattened_mla_with_cache_fake( return torch.empty_like(kv[..., -v_head_dim:]) -@torch.library.custom_op("attention::prepare_fused_mla_metadata", mutates_args=()) +@torch.library.custom_op( + "auto_deploy::triton_attention_prepare_fused_mla_metadata", mutates_args=() +) def prepare_fused_mla_metadata( input_ids: torch.Tensor, position_ids: torch.Tensor, @@ -221,15 +225,15 @@ def get_num_qkv_args(cls) -> int: @classmethod def get_source_attention_op(cls) -> OpOverloadPacket: - return torch.ops.deepseek.fused_mla + return torch.ops.auto_deploy.torch_attention_deepseek_fused_mla @classmethod def get_cached_attention_op(cls) -> MHACallable: - return torch.ops.attention.fused_flattened_mla_with_cache + return torch.ops.auto_deploy.triton_attention_fused_flattened_mla_with_cache @classmethod def get_prepare_metadata_op(cls) -> Tuple[PrepareMetadataCallable, int]: - return torch.ops.attention.prepare_fused_mla_metadata, 4 + return torch.ops.auto_deploy.triton_attention_prepare_fused_mla_metadata, 4 @classmethod def get_cache_initializers( diff --git a/tensorrt_llm/_torch/auto_deploy/custom_ops/quant.py b/tensorrt_llm/_torch/auto_deploy/custom_ops/quant.py index 30493605780..6183394533a 100644 --- a/tensorrt_llm/_torch/auto_deploy/custom_ops/quant.py +++ b/tensorrt_llm/_torch/auto_deploy/custom_ops/quant.py @@ -16,7 +16,7 @@ TRTLLM_NVFP4_SCALING_VECTOR_SIZE = 16 
-@torch.library.custom_op("quant::quant_fn", mutates_args=()) +@torch.library.custom_op("auto_deploy::torch_quant_fn", mutates_args=()) def quant_fn(x: torch.Tensor, scale: torch.Tensor) -> torch.Tensor: scaled_x = x / scale rounded_x = torch.round(scaled_x) @@ -37,7 +37,7 @@ def __init__(self, scale): self.register_buffer("scale", torch.tensor(scale)) def forward(self, x: torch.Tensor): - return torch.ops.quant.quant_fn(x, self.scale) + return torch.ops.auto_deploy.torch_quant_fn(x, self.scale) FP8_MIN = torch.finfo(torch.float8_e4m3fn).min @@ -50,7 +50,7 @@ def _to_fp8(x, scale): return (x / scale).clamp(FP8_MIN, FP8_MAX).to(torch.float8_e4m3fn) -@torch.library.custom_op("quant::fp8_linear", mutates_args=()) +@torch.library.custom_op("auto_deploy::torch_quant_fp8_linear", mutates_args=()) @torch.compile(dynamic=True) def fp8_linear( input: torch.Tensor, @@ -105,7 +105,7 @@ def fp8_linear_fake( return torch.ops.aten.linear(input, weight_fp8.to(input.dtype), bias) -@torch.library.custom_op("quant::fused_fp8_linear_all_reduce", mutates_args=()) +@torch.library.custom_op("auto_deploy::torch_quant_fused_fp8_linear_all_reduce", mutates_args=()) @torch.compile(dynamic=True) def fused_fp8_linear_all_reduce( input: torch.Tensor, @@ -114,7 +114,9 @@ def fused_fp8_linear_all_reduce( input_scale: Optional[torch.Tensor] = None, weight_scale: Optional[torch.Tensor] = None, ) -> torch.Tensor: - out = torch.ops.quant.fp8_linear(input, weight_fp8, bias, input_scale, weight_scale) + out = torch.ops.auto_deploy.torch_quant_fp8_linear( + input, weight_fp8, bias, input_scale, weight_scale + ) if trtllm_dist.is_trtllm_op_available(): return trtllm_dist.trtllm_allreduce(out, op=dist.ReduceOp.SUM) dist.all_reduce(out, op=dist.ReduceOp.SUM) @@ -129,7 +131,9 @@ def fused_fp8_linear_all_reduce_fake( input_scale: Optional[torch.Tensor] = None, weight_scale: Optional[torch.Tensor] = None, ) -> torch.Tensor: - return torch.ops.quant.fp8_linear(input, weight_fp8, bias, input_scale, weight_scale) + return torch.ops.auto_deploy.torch_quant_fp8_linear( + input, weight_fp8, bias, input_scale, weight_scale + ) class FP8Linear(nn.Linear): @@ -146,12 +150,12 @@ def __init__(self, *args, **kwargs): self.bias = nn.Parameter(self.bias.to(torch.half)) def forward(self, x): - return torch.ops.quant.fp8_linear( + return torch.ops.auto_deploy.torch_quant_fp8_linear( x, self.weight, self.bias, self.input_scale, self.weight_scale ) -@torch.library.custom_op("quant::fp4_linear", mutates_args=()) +@torch.library.custom_op("auto_deploy::torch_quant_fp4_linear", mutates_args=()) @torch.compile(dynamic=True) def fp4_linear( input: torch.Tensor, @@ -218,4 +222,7 @@ def fp4_linear_fake( return torch.ops.aten.linear(input, weight_fp4.repeat(1, 2).to(input.dtype), bias) -QUANT_OPS = [torch.ops.quant.fp8_linear, torch.ops.quant.fp4_linear] +QUANT_OPS = [ + torch.ops.auto_deploy.torch_quant_fp8_linear, + torch.ops.auto_deploy.torch_quant_fp4_linear, +] diff --git a/tensorrt_llm/_torch/auto_deploy/custom_ops/torch_attention.py b/tensorrt_llm/_torch/auto_deploy/custom_ops/torch_attention.py index ebf8dd685ec..6764ca3d91e 100644 --- a/tensorrt_llm/_torch/auto_deploy/custom_ops/torch_attention.py +++ b/tensorrt_llm/_torch/auto_deploy/custom_ops/torch_attention.py @@ -8,7 +8,7 @@ import torch.nn.functional as F -@torch.library.custom_op("attention::repeat_kv", mutates_args=()) +@torch.library.custom_op("auto_deploy::torch_attention_repeat_kv", mutates_args=()) def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: """ This is the 
equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch, @@ -31,7 +31,7 @@ def repeat_kv_fake(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: return torch.empty(replicated_shape, device=hidden_states.device, dtype=hidden_states.dtype) -@torch.library.custom_op("attention::scaled_dot_product_attention", mutates_args=()) +@torch.library.custom_op("auto_deploy::torch_attention_sdpa", mutates_args=()) def scaled_dot_product_attention( query: torch.Tensor, key: torch.Tensor, @@ -66,7 +66,7 @@ def scaled_dot_product_attention_fake( return query.new_empty(*query.shape[:-1], value.shape[-1]).contiguous() -@torch.library.custom_op("attention::grouped_sdpa", mutates_args=()) +@torch.library.custom_op("auto_deploy::torch_attention_grouped_sdpa", mutates_args=()) def grouped_sdpa( query: torch.Tensor, key: torch.Tensor, @@ -104,7 +104,7 @@ def grouped_sdpa_fake( return query.new_empty(*query.shape[:-1], value.shape[-1]).contiguous() -@torch.library.custom_op("attention::bsnd_grouped_sdpa", mutates_args=()) +@torch.library.custom_op("auto_deploy::torch_attention_bsnd_grouped_sdpa", mutates_args=()) def bsnd_grouped_sdpa( query: torch.Tensor, # layout: [b, n, s_q, d] key: torch.Tensor, # layout: [b, n, s_k, d] @@ -162,7 +162,7 @@ def update_kv_cache( ) -@torch.library.custom_op("attention::fused_mla_ref", mutates_args=()) +@torch.library.custom_op("auto_deploy::torch_attention_fused_mla_ref", mutates_args=()) def fused_mla_ref( q_nope: torch.Tensor, q_pe: torch.Tensor, @@ -215,7 +215,7 @@ def fused_mla_ref( q_slice = q_pe[start : start + length] k_slice = k_pe[start : start + length] - q_rot, k_rot = torch.ops.rope.torch_apply_rope_with_qk_interleaving( + q_rot, k_rot = torch.ops.auto_deploy.torch_rope_with_qk_interleaving( q_slice, k_slice, cos, @@ -315,7 +315,7 @@ def fused_mla_ref_fake( return torch.empty_like(kv[..., -v_head_dim:]) -@torch.library.custom_op("deepseek::fused_mla", mutates_args=()) +@torch.library.custom_op("auto_deploy::torch_attention_deepseek_fused_mla", mutates_args=()) def fused_mla( q_nope: torch.Tensor, q_pe: torch.Tensor, @@ -340,7 +340,7 @@ def fused_mla( cos = cos[position_ids] sin = sin[position_ids] - q_pe, k_pe = torch.ops.rope.torch_apply_rope_with_qk_interleaving(q_pe, k_pe, cos, sin) + q_pe, k_pe = torch.ops.auto_deploy.torch_rope_with_qk_interleaving(q_pe, k_pe, cos, sin) query_states = k_pe.new_empty(bs, num_heads, q_len, q_head_dim) query_states[:, :, :, :qk_nope_head_dim] = q_nope @@ -399,7 +399,7 @@ def fused_mla( return torch.empty_like(kv[..., -v_head_dim:]) -@torch.library.custom_op("deepseek::mla", mutates_args=()) +@torch.library.custom_op("auto_deploy::torch_attention_deepseek_mla", mutates_args=()) def mla( q_nope: torch.Tensor, # Down projected q_nope q_pe: torch.Tensor, # q_pe after applying rope diff --git a/tensorrt_llm/_torch/auto_deploy/custom_ops/fused_moe.py b/tensorrt_llm/_torch/auto_deploy/custom_ops/torch_moe.py similarity index 81% rename from tensorrt_llm/_torch/auto_deploy/custom_ops/fused_moe.py rename to tensorrt_llm/_torch/auto_deploy/custom_ops/torch_moe.py index 18e4d8bf649..f5e7373c47a 100644 --- a/tensorrt_llm/_torch/auto_deploy/custom_ops/fused_moe.py +++ b/tensorrt_llm/_torch/auto_deploy/custom_ops/torch_moe.py @@ -3,10 +3,8 @@ import torch import torch.nn.functional as F -from ...modules.fused_moe import MoE # noqa: F401 - -@torch.library.custom_op("moe::torch_moe", mutates_args=()) +@torch.library.custom_op("auto_deploy::torch_moe", mutates_args=()) def torch_moe( x: torch.Tensor, 
selected_experts: torch.Tensor, @@ -80,7 +78,7 @@ def torch_moe( return torch.empty_like(x) -@torch.library.custom_op("moe::torch_fused_moe", mutates_args=()) +@torch.library.custom_op("auto_deploy::torch_moe_fused", mutates_args=()) def torch_fused_moe( x: torch.Tensor, selected_experts: torch.Tensor, @@ -90,7 +88,6 @@ def torch_fused_moe( ) -> torch.Tensor: """ A reference implementation of a fused MoE layer computation. - Parameters: x (torch.Tensor): Input tensor of shape (B, H) or (B, S, H), where B is the batch size, S is the sequence length, and H is the hidden size. @@ -102,7 +99,6 @@ def torch_fused_moe( containing the fused weights for w3 and w1 for each expert. w2_stacked_weight (torch.Tensor): A tensor of shape (NUM_EXPERTS, HIDDEN_SIZE, INTERMEDIATE_SIZE) containing the weights for w2 for each expert. - Returns: torch.Tensor: Output tensor with the same shape as the input x. """ @@ -145,45 +141,3 @@ def torch_fused_moe( w2_stacked_weight: torch.Tensor, ) -> torch.Tensor: return torch.empty_like(x) - - -@torch.library.custom_op("moe::trtllm_fused_moe", mutates_args=()) -def trtllm_fused_moe( - x: torch.Tensor, - selected_experts: torch.Tensor, - routing_weights: torch.Tensor, - w3_w1_stacked_weight: torch.Tensor, - w2_stacked_weight: torch.Tensor, -) -> torch.Tensor: - x_shape = x.shape - x = x.view(-1, x_shape[-1]) - - routing_weights = routing_weights.to(torch.float32) - selected_experts = selected_experts.to(torch.int32) - quant_scales = [] - - return torch.ops.trtllm.fused_moe( - x, - selected_experts, - routing_weights, - w3_w1_stacked_weight, - w2_stacked_weight, - x.dtype, - quant_scales, - tp_size=1, - tp_rank=0, - ep_size=1, - ep_rank=0, - enable_alltoall=False, - )[0].view(x_shape) - - -@trtllm_fused_moe.register_fake -def trtllm_fused_moe( - x: torch.Tensor, - selected_experts: torch.Tensor, - routing_weights: torch.Tensor, - w3_w1_stacked_weight: torch.Tensor, - w2_stacked_weight: torch.Tensor, -) -> torch.Tensor: - return torch.empty_like(x) diff --git a/tensorrt_llm/_torch/auto_deploy/custom_ops/torch_rope.py b/tensorrt_llm/_torch/auto_deploy/custom_ops/torch_rope.py index 33286d5946c..da769158b6c 100644 --- a/tensorrt_llm/_torch/auto_deploy/custom_ops/torch_rope.py +++ b/tensorrt_llm/_torch/auto_deploy/custom_ops/torch_rope.py @@ -11,7 +11,7 @@ def rotate_half(x): return torch.cat((-x2, x1), dim=-1) -@torch.library.custom_op("rope::torch_apply_rope_with_explicit_cos_sin", mutates_args=()) +@torch.library.custom_op("auto_deploy::torch_rope_with_explicit_cos_sin", mutates_args=()) def torch_apply_rope_with_explicit_cos_sin( q: torch.Tensor, k: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor, unsqueeze_dim: int = 1 ) -> Tuple[torch.Tensor, torch.Tensor]: @@ -38,7 +38,7 @@ def torch_apply_rope_with_explicit_cos_sin_fake( return torch.empty_like(q), torch.empty_like(k) -@torch.library.custom_op("rope::torch_apply_rope_with_complex_freqs", mutates_args=()) +@torch.library.custom_op("auto_deploy::torch_rope_with_complex_freqs", mutates_args=()) def torch_apply_rope_with_complex_freqs( xq: torch.Tensor, xk: torch.Tensor, @@ -69,7 +69,7 @@ def torch_apply_rope_with_complex_freqs_fake( return torch.empty_like(xq), torch.empty_like(xk) -@torch.library.custom_op("rope::torch_apply_rope_with_qk_interleaving", mutates_args=()) +@torch.library.custom_op("auto_deploy::torch_rope_with_qk_interleaving", mutates_args=()) def torch_apply_rope_with_qk_interleaving( q: torch.Tensor, k: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor, unsqueeze_dim: int = 1 ) -> 
Tuple[torch.Tensor, torch.Tensor]: diff --git a/tensorrt_llm/_torch/auto_deploy/custom_ops/triton_attention.py b/tensorrt_llm/_torch/auto_deploy/custom_ops/triton_attention.py index 60a467b024c..c95e1c28547 100644 --- a/tensorrt_llm/_torch/auto_deploy/custom_ops/triton_attention.py +++ b/tensorrt_llm/_torch/auto_deploy/custom_ops/triton_attention.py @@ -169,7 +169,7 @@ def _flattened_context_mha( ) -@torch.library.custom_op("attention::flattened_mha_with_cache", mutates_args=()) +@torch.library.custom_op("auto_deploy::triton_attention_flattened_mha_with_cache", mutates_args=()) def flattened_mha_with_cache( # Q, K, V q: torch.Tensor, @@ -259,7 +259,9 @@ def flattened_mha_fake( return q.new_empty(*q.shape[:-1], v.shape[-1]).contiguous() -@torch.library.custom_op("attention::prepare_fused_mha_metadata", mutates_args=()) +@torch.library.custom_op( + "auto_deploy::triton_attention_prepare_fused_mha_metadata", mutates_args=() +) def prepare_fused_mha_metadata( input_ids: torch.Tensor, position_ids: torch.Tensor, @@ -314,15 +316,15 @@ def get_num_qkv_args(cls) -> int: @classmethod def get_source_attention_op(cls) -> OpOverloadPacket: - return torch.ops.attention.bsnd_grouped_sdpa + return torch.ops.auto_deploy.torch_attention_bsnd_grouped_sdpa @classmethod def get_cached_attention_op(cls) -> MHACallable: - return torch.ops.attention.flattened_mha_with_cache + return torch.ops.auto_deploy.triton_attention_flattened_mha_with_cache @classmethod def get_prepare_metadata_op(cls) -> Tuple[PrepareMetadataCallable, int]: - return torch.ops.attention.prepare_fused_mha_metadata, 4 + return torch.ops.auto_deploy.triton_attention_prepare_fused_mha_metadata, 4 @classmethod def get_cache_initializers( diff --git a/tensorrt_llm/_torch/auto_deploy/custom_ops/rope.py b/tensorrt_llm/_torch/auto_deploy/custom_ops/triton_rope.py similarity index 95% rename from tensorrt_llm/_torch/auto_deploy/custom_ops/rope.py rename to tensorrt_llm/_torch/auto_deploy/custom_ops/triton_rope.py index 11e0649f027..b9282d30383 100644 --- a/tensorrt_llm/_torch/auto_deploy/custom_ops/rope.py +++ b/tensorrt_llm/_torch/auto_deploy/custom_ops/triton_rope.py @@ -4,7 +4,7 @@ from .triton_kernels.rope import rope_fwd_flattened_kernel, rope_fwd_kernel -@torch.library.custom_op("rope::apply_rope_with_input_pos", mutates_args=()) +@torch.library.custom_op("auto_deploy::triton_rope_with_input_pos", mutates_args=()) def apply_rope_with_input_pos( x: torch.Tensor, freqs_cis: torch.Tensor, input_pos: torch.Tensor, layout: str ) -> torch.Tensor: @@ -77,7 +77,7 @@ def apply_rope_with_input_pos_fake(x, freqs_cis, input_pos, layout): return torch.empty_like(x) -@torch.library.custom_op("rope::apply_rope_on_flattened_inputs", mutates_args=()) +@torch.library.custom_op("auto_deploy::triton_rope_on_flattened_inputs", mutates_args=()) def apply_rope_on_flattened_inputs( x: torch.Tensor, freqs_cis: torch.Tensor, diff --git a/tensorrt_llm/_torch/auto_deploy/custom_ops/trtllm_moe.py b/tensorrt_llm/_torch/auto_deploy/custom_ops/trtllm_moe.py new file mode 100644 index 00000000000..7ed14c6afa5 --- /dev/null +++ b/tensorrt_llm/_torch/auto_deploy/custom_ops/trtllm_moe.py @@ -0,0 +1,43 @@ +import torch + + +@torch.library.custom_op("auto_deploy::trtllm_moe_fused", mutates_args=()) +def trtllm_fused_moe( + x: torch.Tensor, + selected_experts: torch.Tensor, + routing_weights: torch.Tensor, + w3_w1_stacked_weight: torch.Tensor, + w2_stacked_weight: torch.Tensor, +) -> torch.Tensor: + x_shape = x.shape + x = x.view(-1, x_shape[-1]) + + routing_weights = 
routing_weights.to(torch.float32) + selected_experts = selected_experts.to(torch.int32) + quant_scales = [] + + return torch.ops.trtllm.fused_moe( + x, + selected_experts, + routing_weights, + w3_w1_stacked_weight, + w2_stacked_weight, + x.dtype, + quant_scales, + tp_size=1, + tp_rank=0, + ep_size=1, + ep_rank=0, + enable_alltoall=False, + )[0].view(x_shape) + + +@trtllm_fused_moe.register_fake +def trtllm_fused_moe( + x: torch.Tensor, + selected_experts: torch.Tensor, + routing_weights: torch.Tensor, + w3_w1_stacked_weight: torch.Tensor, + w2_stacked_weight: torch.Tensor, +) -> torch.Tensor: + return torch.empty_like(x) diff --git a/tensorrt_llm/_torch/auto_deploy/models/deepseek.py b/tensorrt_llm/_torch/auto_deploy/models/deepseek.py index 38f1713c633..ae04bf6e592 100644 --- a/tensorrt_llm/_torch/auto_deploy/models/deepseek.py +++ b/tensorrt_llm/_torch/auto_deploy/models/deepseek.py @@ -53,7 +53,7 @@ def deepseek_v3_attention( # Use custom op to capture mla. This does not handle KV cache # as passing transformers Cache into a custom op is throwing an error. # Would not be an issue, cause we intend to replace mla op with our implementation further along the pipeline - attn_output = torch.ops.deepseek.fused_mla( + attn_output = torch.ops.auto_deploy.torch_attention_deepseek_fused_mla( q_nope, q_pe, kv, @@ -131,7 +131,7 @@ def deepseek_v3_moe(self, hidden_states): """DeepSeekV3MoE forward function rewritten in Mixtral style to enable torch export.""" selected_experts, routing_weights, *_ = self.gate(hidden_states) - final_hidden_states = torch.ops.moe.torch_moe( + final_hidden_states = torch.ops.auto_deploy.torch_moe( hidden_states, selected_experts, routing_weights, diff --git a/tensorrt_llm/_torch/auto_deploy/models/qwen3.py b/tensorrt_llm/_torch/auto_deploy/models/qwen3.py index dae68b1a2cf..c835eda66ed 100644 --- a/tensorrt_llm/_torch/auto_deploy/models/qwen3.py +++ b/tensorrt_llm/_torch/auto_deploy/models/qwen3.py @@ -17,7 +17,7 @@ def _forward_moe(self: Qwen3MoeSparseMoeBlock, hidden_states: torch.Tensor): # we cast back to the input dtype routing_weights = routing_weights.to(hidden_states.dtype) - final_hidden_states = torch.ops.moe.torch_moe( + final_hidden_states = torch.ops.auto_deploy.torch_moe( hidden_states, selected_experts, routing_weights, diff --git a/tensorrt_llm/_torch/auto_deploy/shim/ad_executor.py b/tensorrt_llm/_torch/auto_deploy/shim/ad_executor.py index c9624d8c339..4b88e611b71 100644 --- a/tensorrt_llm/_torch/auto_deploy/shim/ad_executor.py +++ b/tensorrt_llm/_torch/auto_deploy/shim/ad_executor.py @@ -13,7 +13,6 @@ from ....llmapi.llm_args import _AutoDeployLlmArgs from ....mapping import Mapping from ...distributed import MPIDist -from ...pyexecutor._util import create_torch_sampler_args from ...pyexecutor.config import PyTorchConfig from ...pyexecutor.model_engine import ModelEngine from ...pyexecutor.py_executor import PyExecutor @@ -268,9 +267,14 @@ def create_autodeploy_executor( ad_config: _AutoDeployLlmArgs = executor_config.pytorch_backend_config max_batch_size = ad_config.max_batch_size + max_num_sequences = ad_config.max_batch_size * dist_mapping.pp_size max_seq_len = ad_config.max_seq_len attn_page_size = ad_config.attn_page_size max_num_tokens = ad_config.max_num_tokens + max_draft_tokens = ( + 0 if ad_config.speculative_config is None else ad_config.speculative_config.max_draft_tokens + ) + ad_logger.info(f"{max_seq_len=}, {max_batch_size=}, {attn_page_size=}, {max_num_tokens=}") # initialize model engine @@ -311,22 +315,30 @@ def 
create_autodeploy_executor( scheduler = SimpleScheduler(capacitor_scheduler, mb_scheduler) # search sampler with speculative decoding - sampler_args = create_torch_sampler_args( - executor_config, dist_mapping, mixed_sampler=False, max_seq_len=max_seq_len + # TODO (lucaslie, fridah-nv): some models require mixed_sampler=True to have good outputs, see + # https://github.com/NVIDIA/TensorRT-LLM/issues/5254 + # We should expose mixed_sample to our build_and_run_ad script so we can configure this + # correctly for models as needed. + sampler_args = TorchSampler.Args( + max_seq_len=max_seq_len, + max_draft_tokens=max_draft_tokens, + max_num_sequences=max_num_sequences, + max_beam_width=executor_config.max_beam_width, + mixed_sampler=ad_config.mixed_sampler, ) sampler = TorchSampler(sampler_args) + + # creating the executor object py_executor = PyExecutor( resource_manager, scheduler, model_engine=engine, sampler=sampler, dist=mpi_dist, - max_num_sequences=ad_config.max_batch_size * dist_mapping.pp_size, + max_num_sequences=max_num_sequences, disable_overlap_scheduler=ad_config.disable_overlap_scheduler, max_input_len=ad_config.max_input_len, - max_batch_size=ad_config.max_batch_size, - max_draft_tokens=ad_config.speculative_config.max_draft_tokens - if ad_config.speculative_config is not None - else 0, + max_batch_size=max_batch_size, + max_draft_tokens=max_draft_tokens, ) return py_executor diff --git a/tensorrt_llm/_torch/auto_deploy/shim/demollm.py b/tensorrt_llm/_torch/auto_deploy/shim/demollm.py index 13c9a7374b0..83db1cda825 100644 --- a/tensorrt_llm/_torch/auto_deploy/shim/demollm.py +++ b/tensorrt_llm/_torch/auto_deploy/shim/demollm.py @@ -215,11 +215,12 @@ def _logits_to_probs( def _sample( cls, logits: torch.Tensor, sampling_params: SamplingParams ) -> Tuple[torch.Tensor, torch.Tensor]: - probs = cls._logits_to_probs( - logits, sampling_params.temperature, sampling_params.top_k - ) # [*logits.shape] - # idx_next shape is [*logits.shape[:-1]] - idx_next = cls._multinomial_sample_one_no_sync(probs) + from tensorrt_llm._torch.pyexecutor.sampler import top_k_sampling_batch + + logits_shape = logits.shape + logits = logits.view(-1, logits_shape[-1]) # top_k_sampling_batch expects 2D logits + idx_next, probs = top_k_sampling_batch(logits, sampling_params.top_k) + idx_next = idx_next.view(logits_shape[:-1]) return idx_next, probs def _decode_tokens( diff --git a/tensorrt_llm/_torch/auto_deploy/transformations/export.py b/tensorrt_llm/_torch/auto_deploy/transformations/export.py index 1b8f30e8038..a13d40f3aef 100644 --- a/tensorrt_llm/_torch/auto_deploy/transformations/export.py +++ b/tensorrt_llm/_torch/auto_deploy/transformations/export.py @@ -213,7 +213,7 @@ def _torch_where_patch(condition: torch.Tensor, *args, **kwargs): def _torch_linear_patch( input: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor] = None ) -> torch.Tensor: - return torch.ops.linear.simple(input, weight, bias) + return torch.ops.auto_deploy.torch_linear_simple(input, weight, bias) # TODO: remove once https://github.com/pytorch/pytorch/issues/142439 is resolved @@ -336,7 +336,7 @@ def torch_export_to_gm( # there is no guarantee how it is represented and we need to make sure it is easily identifiable # in the graph. sdpa_original = F.scaled_dot_product_attention - F.scaled_dot_product_attention = torch.ops.attention.scaled_dot_product_attention + F.scaled_dot_product_attention = torch.ops.auto_deploy.torch_attention_sdpa # We overwrite the linear functional as well. 
This basically avoids exporting the view ops # that are used to flatten/unflatten multiple batch dimensions of the input tensor. diff --git a/tensorrt_llm/_torch/auto_deploy/transformations/library/attention.py b/tensorrt_llm/_torch/auto_deploy/transformations/library/attention.py index 1983b29bfde..7e46bd652ce 100644 --- a/tensorrt_llm/_torch/auto_deploy/transformations/library/attention.py +++ b/tensorrt_llm/_torch/auto_deploy/transformations/library/attention.py @@ -18,7 +18,7 @@ def match_repeat_kv(gm: GraphModule) -> GraphModule: The pattern is: unsqueeze -> expand -> reshape -> [optional] contiguous - This is replaced with torch.ops.attention.repeat_kv. + This is replaced with torch.ops.auto_deploy.torch_attention_repeat_kv. """ graph = gm.graph @@ -49,7 +49,7 @@ def match_eager_attention(gm: GraphModule) -> GraphModule: The pattern is: transpose -> matmul -> mul -> (optional) add -> softmax -> to -> dropout -> matmul - This is replaced with torch.ops.attention.scaled_dot_product_attention. + This is replaced with torch.ops.auto_deploy.torch_attention_sdpa. """ graph = gm.graph @@ -82,7 +82,7 @@ def match_grouped_attention(gm: GraphModule) -> GraphModule: repeat_kv(v, n_rep) -> sdpa(q, repeated_k, repeated_v) - This is replaced with torch.ops.attention.grouped_sdpa. + This is replaced with torch.ops.auto_deploy.torch_attention_grouped_sdpa. """ graph = gm.graph @@ -92,7 +92,7 @@ def match_grouped_attention(gm: GraphModule) -> GraphModule: # Iterate through nodes in the graph for node in list(graph.nodes): # Look for SDPA nodes that could be part of our pattern - if is_op(node, torch.ops.attention.scaled_dot_product_attention): + if is_op(node, torch.ops.auto_deploy.torch_attention_sdpa): match_info = _match_grouped_attention_pattern(node) if match_info: ad_logger.debug(f"Found grouped attention pattern at {node}") @@ -126,8 +126,8 @@ def match_causal_attn_mask(gm: GraphModule) -> GraphModule: for node in list(graph.nodes): # Look for SDPA nodes or grouped SDPA nodes if not ( - is_op(node, torch.ops.attention.scaled_dot_product_attention) - or is_op(node, torch.ops.attention.grouped_sdpa) + is_op(node, torch.ops.auto_deploy.torch_attention_sdpa) + or is_op(node, torch.ops.auto_deploy.torch_attention_grouped_sdpa) ): continue @@ -437,7 +437,7 @@ def _match_grouped_attention_pattern(sdpa_node: Node) -> Optional[Dict[str, Node Returns a dictionary with information about the match or None if no match. 
""" # Check that sdpa_node is an SDPA operation - if not is_op(sdpa_node, torch.ops.attention.scaled_dot_product_attention): + if not is_op(sdpa_node, torch.ops.auto_deploy.torch_attention_sdpa): return None # SDPA should have query, key, value as its first three arguments @@ -447,8 +447,8 @@ def _match_grouped_attention_pattern(sdpa_node: Node) -> Optional[Dict[str, Node query, key_repeated, value_repeated = sdpa_node.args[0:3] # Key and value should come from repeat_kv operations - if not is_op(key_repeated, torch.ops.attention.repeat_kv) or not is_op( - value_repeated, torch.ops.attention.repeat_kv + if not is_op(key_repeated, torch.ops.auto_deploy.torch_attention_repeat_kv) or not is_op( + value_repeated, torch.ops.auto_deploy.torch_attention_repeat_kv ): return None @@ -487,7 +487,7 @@ def _replace_with_repeat_kv(graph, match_info: Dict[str, Node]) -> None: with graph.inserting_before(node_to_replace): repeat_kv_node = graph.call_function( - torch.ops.attention.repeat_kv, args=(input_tensor, n_rep) + torch.ops.auto_deploy.torch_attention_repeat_kv, args=(input_tensor, n_rep) ) # Preserve metadata from the original node @@ -502,7 +502,7 @@ def _replace_with_sdpa(graph, match_info: Dict[str, Node]) -> None: Replace the matched eager attention pattern with scaled_dot_product_attention. """ # retrieve the default op for scaled_dot_product_attention - sdpa_op = torch.ops.attention.scaled_dot_product_attention.default + sdpa_op = torch.ops.auto_deploy.torch_attention_sdpa.default # construct the args for the ops based on the match_info and the op's schema args = [] @@ -530,7 +530,7 @@ def _replace_with_sdpa(graph, match_info: Dict[str, Node]) -> None: def _replace_with_grouped_sdpa(graph, match_info: Dict[str, Node]) -> None: """ - Replace the matched grouped attention pattern with torch.ops.attention.grouped_sdpa. + Replace the matched grouped attention pattern with torch.ops.auto_deploy.torch_attention_grouped_sdpa. """ sdpa_node = match_info["sdpa_node"] query = match_info["query"] @@ -543,7 +543,7 @@ def _replace_with_grouped_sdpa(graph, match_info: Dict[str, Node]) -> None: with graph.inserting_before(sdpa_node): grouped_sdpa_node = graph.call_function( - torch.ops.attention.grouped_sdpa.default, args=args, kwargs=kwargs + torch.ops.auto_deploy.torch_attention_grouped_sdpa.default, args=args, kwargs=kwargs ) # Preserve metadata from the original node @@ -763,8 +763,8 @@ def match_attention_layout(gm: GraphModule, attention_op: Type[AttentionDescript # List of SDPA operations to look for sdpa_ops = { - torch.ops.attention.scaled_dot_product_attention, - torch.ops.attention.grouped_sdpa, + torch.ops.auto_deploy.torch_attention_sdpa, + torch.ops.auto_deploy.torch_attention_grouped_sdpa, } graph = gm.graph diff --git a/tensorrt_llm/_torch/auto_deploy/transformations/library/collectives.py b/tensorrt_llm/_torch/auto_deploy/transformations/library/collectives.py index d6199c8d4ed..bf6f804c427 100644 --- a/tensorrt_llm/_torch/auto_deploy/transformations/library/collectives.py +++ b/tensorrt_llm/_torch/auto_deploy/transformations/library/collectives.py @@ -22,14 +22,14 @@ def fuse_collectives(gm: GraphModule) -> GraphModule: # lookup for fused ops # TODO: avoid this hardcoded lookup, e.g., by generating fused ops on the fly. 
lookup = { - torch.ops.linear.simple: torch.ops.linear.fused_linear_all_reduce, - torch.ops.aten.linear: torch.ops.linear.fused_linear_all_reduce, - torch.ops.quant.fp8_linear: torch.ops.quant.fused_fp8_linear_all_reduce, + torch.ops.auto_deploy.torch_linear_simple: torch.ops.auto_deploy.trtllm_dist_fused_linear_all_reduce, + torch.ops.aten.linear: torch.ops.auto_deploy.trtllm_dist_fused_linear_all_reduce, + torch.ops.auto_deploy.torch_quant_fp8_linear: torch.ops.auto_deploy.torch_quant_fused_fp8_linear_all_reduce, } # go through all nodes and find all_reduce nodes for node in gm.graph.nodes: - if not is_op(node, torch.ops.dist.all_reduce): + if not is_op(node, torch.ops.auto_deploy.torch_dist_all_reduce): continue # check if args are as expected @@ -162,7 +162,7 @@ def trace_and_fuse(allreduce_node, graph): # Traverse all nodes for node in gm.graph.nodes: - if is_op(node, torch.ops.dist.all_reduce): + if is_op(node, torch.ops.auto_deploy.torch_dist_all_reduce): trace_and_fuse(allreduce_node=node, graph=gm.graph) gm = canonicalize_graph(gm) diff --git a/tensorrt_llm/_torch/auto_deploy/transformations/library/ep_sharding.py b/tensorrt_llm/_torch/auto_deploy/transformations/library/ep_sharding.py index 84761d442ab..acae157a6b7 100644 --- a/tensorrt_llm/_torch/auto_deploy/transformations/library/ep_sharding.py +++ b/tensorrt_llm/_torch/auto_deploy/transformations/library/ep_sharding.py @@ -38,7 +38,7 @@ def ep_shard(gm: GraphModule, rank: int, world_size: int) -> GraphModule: assert isinstance(gm, GraphModule), "Expecting GraphModule" num_moe_patterns = 0 for node in list(gm.graph.nodes): - if not is_op(node, torch.ops.moe.torch_moe): + if not is_op(node, torch.ops.auto_deploy.torch_moe): continue _insert_sharded_moe(gm, node, rank, world_size) num_moe_patterns += 1 @@ -123,6 +123,8 @@ def get_partition(lst, world_size, rank): # -- add an all_reduce node -- with gm.graph.inserting_after(node): - dist_node = gm.graph.call_function(torch.ops.dist.all_reduce, args=(node,)) + dist_node = gm.graph.call_function( + torch.ops.auto_deploy.torch_dist_all_reduce, args=(node,) + ) node.replace_all_uses_with(dist_node) dist_node.replace_input_with(dist_node, node) diff --git a/tensorrt_llm/_torch/auto_deploy/transformations/library/fused_moe.py b/tensorrt_llm/_torch/auto_deploy/transformations/library/fused_moe.py index 8122d5068a9..02e3e64e170 100644 --- a/tensorrt_llm/_torch/auto_deploy/transformations/library/fused_moe.py +++ b/tensorrt_llm/_torch/auto_deploy/transformations/library/fused_moe.py @@ -69,7 +69,7 @@ def match_moe_pattern(gm: GraphModule) -> GraphModule: w3_list = expert_weights["w3"] fused_moe_node = graph.call_function( - torch.ops.moe.torch_moe, + torch.ops.auto_deploy.torch_moe, args=( hidden_states, selected_experts, @@ -99,7 +99,7 @@ def match_moe_pattern(gm: GraphModule) -> GraphModule: def fuse_moe(gm: torch.fx.GraphModule) -> torch.fx.GraphModule: """ Scan the FX graph and replace all calls to torch.ops.moe.torch_moe with - torch.ops.moe.trtllm_fused_moe. + torch.ops.auto_deploy.trtllm_moe_fused. 
""" ad_logger.debug("Before MoE fusion: " + str(gm)) @@ -118,7 +118,7 @@ def _insert_fused_moe_ops(gm: GraphModule) -> int: graph = gm.graph for node in list(graph.nodes): - if not is_op(node, torch.ops.moe.torch_moe): + if not is_op(node, torch.ops.auto_deploy.torch_moe): continue ad_logger.debug(f"Found MoE op to fuse: {node} with args: {node.args}") @@ -146,7 +146,7 @@ def _insert_fused_moe_ops(gm: GraphModule) -> int: with graph.inserting_before(node): new_node = graph.call_function( - torch.ops.moe.trtllm_fused_moe, + torch.ops.auto_deploy.trtllm_moe_fused, args=( hidden_states, selected_experts, diff --git a/tensorrt_llm/_torch/auto_deploy/transformations/library/rope.py b/tensorrt_llm/_torch/auto_deploy/transformations/library/rope.py index 3a50b7e91a7..651d0730e55 100644 --- a/tensorrt_llm/_torch/auto_deploy/transformations/library/rope.py +++ b/tensorrt_llm/_torch/auto_deploy/transformations/library/rope.py @@ -1,6 +1,6 @@ """ This transformation defines two main RoPE (Rotary Positional Embedding) pattern matchers used -to identify and replace RoPE subgraphs with a custom op (`torch.ops.rope.flashinfer`). +to identify and replace RoPE subgraphs with a custom op (`torch.ops.auto_deploy.flashinfer_rope`). Supported RoPE variants: @@ -73,7 +73,7 @@ def _explicit_rope_pattern(q, k, cos, sin, unsqueeze_dim=1): def _explicit_rope_repl(q, k, cos, sin, unsqueeze_dim): - return torch.ops.rope.torch_apply_rope_with_explicit_cos_sin.default( + return torch.ops.auto_deploy.torch_rope_with_explicit_cos_sin.default( q, k, cos, sin, unsqueeze_dim ) @@ -91,7 +91,7 @@ def _interleaved_rope_pattern(q, k, cos, sin, unsqueeze_dim=1): def _interleaved_rope_repl(q, k, cos, sin, unsqueeze_dim): - return torch.ops.rope.torch_apply_rope_with_qk_interleaving.default( + return torch.ops.auto_deploy.torch_rope_with_qk_interleaving.default( q, k, cos, sin, unsqueeze_dim ) @@ -109,7 +109,7 @@ def _complex_rope_pattern(xq, xk, freqs_cis, unsqueeze_dim=1): def _complex_rope_repl(q, k, freqs_cis, unsqueeze_dim): - return torch.ops.rope.torch_apply_rope_with_complex_freqs.default( + return torch.ops.auto_deploy.torch_rope_with_complex_freqs.default( q, k, freqs_cis, unsqueeze_dim ) @@ -195,9 +195,9 @@ def match_rope_layout(gm: GraphModule, expected_layout: str = "bsnd") -> GraphMo graph = gm.graph rope_ops = { - torch.ops.rope.torch_apply_rope_with_explicit_cos_sin, - torch.ops.rope.torch_apply_rope_with_qk_interleaving, - torch.ops.rope.torch_apply_rope_with_complex_freqs, + torch.ops.auto_deploy.torch_rope_with_explicit_cos_sin, + torch.ops.auto_deploy.torch_rope_with_qk_interleaving, + torch.ops.auto_deploy.torch_rope_with_complex_freqs, } need_transpose = False @@ -206,7 +206,7 @@ def match_rope_layout(gm: GraphModule, expected_layout: str = "bsnd") -> GraphMo if not is_op(node, rope_ops): continue - if is_op(node, torch.ops.rope.torch_apply_rope_with_complex_freqs): + if is_op(node, torch.ops.auto_deploy.torch_rope_with_complex_freqs): q_node, k_node, freqs_node, unsq = extract_op_args( node, "xq", # argument name in schema @@ -257,7 +257,7 @@ def match_rope_layout(gm: GraphModule, expected_layout: str = "bsnd") -> GraphMo q_for_op_contig.meta["val"] = q_node.meta["val"].transpose(1, 2) k_for_op_contig.meta["val"] = k_node.meta["val"].transpose(1, 2) - if is_op(node, torch.ops.rope.torch_apply_rope_with_complex_freqs): + if is_op(node, torch.ops.auto_deploy.torch_rope_with_complex_freqs): new_args = ( q_for_op_contig, k_for_op_contig, @@ -309,9 +309,9 @@ def optimize_rope(gm: GraphModule) -> GraphModule: 
num_rope_optimizations = 0 for node in list(graph.nodes): - if is_op(node, torch.ops.rope.torch_apply_rope_with_explicit_cos_sin): + if is_op(node, torch.ops.auto_deploy.torch_rope_with_explicit_cos_sin): _optimize_explicit(graph, node, rope_flash_cache, rope_position_ids_cache) - elif is_op(node, torch.ops.rope.torch_apply_rope_with_complex_freqs): + elif is_op(node, torch.ops.auto_deploy.torch_rope_with_complex_freqs): _optimize_complex(graph, node, rope_flash_cache, rope_position_ids_cache) else: continue @@ -398,7 +398,7 @@ def _optimize_explicit( rope_position_ids_cache=pos_cache, ) flash_node = graph.call_function( - torch.ops.rope.flashinfer, + torch.ops.auto_deploy.flashinfer_rope, args=(q_node, k_node, position_ids, fused_cos_sin_to, True), ) @@ -478,7 +478,7 @@ def _optimize_complex( graph, q_node, batch_dim=0, seq_dim=1, rope_position_ids_cache=pos_cache ) flash_node = graph.call_function( - torch.ops.rope.flashinfer, + torch.ops.auto_deploy.flashinfer_rope, args=(q_node, k_node, position_ids, cos_sin_flash, False), ) diff --git a/tensorrt_llm/_torch/auto_deploy/transformations/library/sharding.py b/tensorrt_llm/_torch/auto_deploy/transformations/library/sharding.py index 0dedb3ada58..3afa7f5064f 100644 --- a/tensorrt_llm/_torch/auto_deploy/transformations/library/sharding.py +++ b/tensorrt_llm/_torch/auto_deploy/transformations/library/sharding.py @@ -16,6 +16,7 @@ happens automatically via the checkpoint loading hook added in step 2c. """ +import math import operator from collections import defaultdict from functools import partial @@ -71,7 +72,13 @@ def _load_hook_remove( def _insert_sharded_matmul( - gm: GraphModule, node: Node, dim: int, rank: int, world_size: int, add_dist: bool = False + gm: GraphModule, + node: Node, + dim: int, + rank: int, + world_size: int, + add_dist: bool = False, + min_local_shape: int = 1, ): """Replaces the matmul node with a new matmul node that accepts sharded weights. @@ -83,8 +90,21 @@ def _insert_sharded_matmul( quantization_impl = QuantizationImpl.create(node) def split_tensor( - t: torch.Tensor, d: int = dim, r: int = rank, ws: int = world_size + t: torch.Tensor, + d: int = dim, + r: int = rank, + ws: int = world_size, + min_d_shape: int = min_local_shape, ) -> torch.Tensor: + # The local tensor shape has to be divisible by min_d_shape + max_split_size = t.shape[d] // min_d_shape + if ws > max_split_size: + num_groups = math.ceil(ws / max_split_size) + ad_logger.debug( + f"World size {ws} is greater than the max split size {max_split_size}. " + + f"Splitting tensor to {num_groups} chunks" + ) + return torch.tensor_split(t, max_split_size, dim=d)[r // num_groups] return torch.tensor_split(t, ws, dim=d)[r] num_users = num_users_of_weight_node(node) @@ -168,8 +188,8 @@ def set_new_param(submod: nn.Module, param_key: str, remove: bool = False) -> to # figure out the right dist op dist_lookup = { - 0: (torch.ops.dist.all_gather, -1), - 1: (torch.ops.dist.all_reduce,), + 0: (torch.ops.auto_deploy.torch_dist_all_gather, -1), + 1: (torch.ops.auto_deploy.torch_dist_all_reduce,), } fn_dist, *dist_args = dist_lookup[dim] @@ -191,7 +211,10 @@ def _simple_shard( def column_row_shard( - gm: GraphModule, rank: int, world_size: int, simple_shard_only: bool = False + gm: GraphModule, + rank: int, + world_size: int, + simple_shard_only: bool = False, ) -> GraphModule: """A transformation to apply sharding to the model following tensor parallelism. @@ -205,6 +228,9 @@ def column_row_shard( **all** nodes in the subgraph. 
The subgraph here is defined as the region between the first linear node to the last linear node of an identified sharding region. # 5. Shard the GEMM nodes or skip accordingly. + + min_local_shape is the minimum size of the local tensor shard, to prevent TP parallelism + splitting, e.g., the individual heads into smaller shards. """ ad_logger.debug("Before sharding graph: " + str(gm)) @@ -232,9 +258,9 @@ def column_row_shard( # acceptable attention nodes between sharded GEMMs shardable_attention_nodes = { - torch.ops.attention.scaled_dot_product_attention, - torch.ops.attention.grouped_sdpa, - torch.ops.attention.bsnd_grouped_sdpa, + torch.ops.auto_deploy.torch_attention_sdpa, + torch.ops.auto_deploy.torch_attention_grouped_sdpa, + torch.ops.auto_deploy.torch_attention_bsnd_grouped_sdpa, } # This is a heuristic. Basically, we assume those are okay to shard if we also encounter an @@ -244,7 +270,7 @@ def column_row_shard( shardable_nodes_with_attention = { torch.ops.aten.view, torch.ops.aten.reshape, - torch.ops.rope.flashinfer, + torch.ops.auto_deploy.flashinfer_rope, operator.getitem, } @@ -327,9 +353,25 @@ def column_row_shard( # If we can account for all sharded nodes, we can do a two-way shard # --> row_split (dim 0) + col_split (dim 1) + all_reduce + + # check if we are sharding the attention block + if attention_nodes: + if len(attention_nodes) > 1: + # Column-row shard boundary region detection is probably wrong - there should be + # only one attention operation. Fall back to simple shard. + ad_logger.debug(f"More than one attention node: {unaccounted_nodes}") + _simple_shard(gm, nodes_linear, rank, world_size) + continue + # Extract head dimension. We cannot shard below the head_dim size. + # Assume that head_dim is the last (innermost) dimension of the tensor + min_local_shape = attention_nodes.pop().meta["val"].shape[-1] + else: + min_local_shape = 1 for i, group in enumerate(nodes_linear.values()): for n in group: - _insert_sharded_matmul(gm, n, i, rank, world_size, add_dist=i > 0) + _insert_sharded_matmul( + gm, n, i, rank, world_size, add_dist=i > 0, min_local_shape=min_local_shape + ) # canonicalize and return if num_shards: @@ -424,7 +466,7 @@ def slice_tensor(t: torch.Tensor) -> torch.Tensor: base_size = bmm_batch_size // world_size remainder = bmm_batch_size % world_size - # NOTE: our torch.ops.dist.all_gather doesn't support uneven splits at the moment. + # NOTE: our torch.ops.auto_deploy.torch_dist_all_gather doesn't support uneven splits at the moment. if remainder: ad_logger.warning( f"BMM batch size {bmm_batch_size} is not divisible by world size {world_size}. 
" @@ -451,7 +493,7 @@ def slice_tensor(t: torch.Tensor) -> torch.Tensor: # Add all_gather node after BMM to collect results with gm.graph.inserting_after(node): gather_node = gm.graph.call_function( - torch.ops.dist.all_gather, + torch.ops.auto_deploy.torch_dist_all_gather, args=(node, 0), # Gather along batch dimension (0) ) node.replace_all_uses_with(gather_node) diff --git a/tensorrt_llm/_torch/auto_deploy/transformations/library/visualization.py b/tensorrt_llm/_torch/auto_deploy/transformations/library/visualization.py index 71b196aeb21..d02cdecd4f2 100644 --- a/tensorrt_llm/_torch/auto_deploy/transformations/library/visualization.py +++ b/tensorrt_llm/_torch/auto_deploy/transformations/library/visualization.py @@ -68,11 +68,11 @@ def add_outputs_metadata(self, fx_node: torch.fx.node.Node, node: GraphNode): # TODO(yudong): make custom_ops configurable CUSTOM_OPS = ( - torch.ops.dist.all_reduce.default, + torch.ops.auto_deploy.torch_dist_all_reduce.default, torch.ops.aten.slice.Tensor, - torch.ops.attention.fused_mha_with_cache.default, - torch.ops.linear.fused_linear_all_reduce.default, - torch.ops.linear.simple.default, + torch.ops.auto_deploy.triton_attention_fused_mha_with_cache.default, + torch.ops.auto_deploy.trtllm_dist_fused_linear_all_reduce.default, + torch.ops.auto_deploy.torch_linear_simple.default, torch.ops.aten.split_with_sizes.default, ) diff --git a/tensorrt_llm/_torch/auto_deploy/utils/node_utils.py b/tensorrt_llm/_torch/auto_deploy/utils/node_utils.py index 88409dd98e7..ee0dcd0ab7d 100644 --- a/tensorrt_llm/_torch/auto_deploy/utils/node_utils.py +++ b/tensorrt_llm/_torch/auto_deploy/utils/node_utils.py @@ -222,7 +222,7 @@ def is_linear_op(node: Node, include_quantization: bool = False) -> bool: """ lin_ops = { torch.ops.aten.linear, - torch.ops.linear.simple, + torch.ops.auto_deploy.torch_linear_simple, } if include_quantization: @@ -233,8 +233,8 @@ def is_linear_op(node: Node, include_quantization: bool = False) -> bool: def is_dist_op(node: Node) -> bool: """Check if the node is a distributed op.""" dist_ops = { - torch.ops.dist.all_gather, - torch.ops.dist.all_reduce, + torch.ops.auto_deploy.torch_dist_all_gather, + torch.ops.auto_deploy.torch_dist_all_reduce, } return is_op(node, dist_ops) diff --git a/tensorrt_llm/_torch/auto_deploy/utils/quantization_utils.py b/tensorrt_llm/_torch/auto_deploy/utils/quantization_utils.py index 11b3c18eff1..c58c788810c 100644 --- a/tensorrt_llm/_torch/auto_deploy/utils/quantization_utils.py +++ b/tensorrt_llm/_torch/auto_deploy/utils/quantization_utils.py @@ -136,7 +136,7 @@ def fuse_linear_weights(weights, **kwargs) -> Tuple[torch.Tensor, Dict[str, torc class FP8QuantizationImpl(QuantizationImpl): @staticmethod def target_op(): - return torch.ops.quant.fp8_linear + return torch.ops.auto_deploy.torch_quant_fp8_linear @staticmethod def quantize_weight(original_weight: torch.Tensor) -> torch.Tensor: @@ -207,7 +207,7 @@ def _shard_fp4_weight_scale(weight_scale, sharded_uint8_weight_shape, dim, rank, class FP4QuantizationImpl(QuantizationImpl): @staticmethod def target_op(): - return torch.ops.quant.fp4_linear + return torch.ops.auto_deploy.torch_quant_fp4_linear @staticmethod def quantize_weight(original_weight: torch.Tensor) -> torch.Tensor: diff --git a/tests/unittest/_torch/auto_deploy/_utils_test/_model_test_utils.py b/tests/unittest/_torch/auto_deploy/_utils_test/_model_test_utils.py index cd17de77c6b..bb7c896b5a6 100644 --- a/tests/unittest/_torch/auto_deploy/_utils_test/_model_test_utils.py +++ 
b/tests/unittest/_torch/auto_deploy/_utils_test/_model_test_utils.py @@ -184,7 +184,7 @@ def __init__(self, hidden_size=32, intermediate_size=16, num_experts=4, top_k=2) def forward(self, x: torch.Tensor) -> torch.Tensor: """ x: Tensor of shape (batch, hidden_size) - Computes router logits via a gate, and then calls the MoE op via torch.moe.torch_moe. + Computes router logits via a gate, and then calls the MoE op via torch.ops.auto_deploy.torch_moe. """ router_logits = self.gate(x) @@ -197,7 +197,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: w2_list = [expert.w2 for expert in self.experts] w3_list = [expert.w3 for expert in self.experts] - out = torch.ops.moe.torch_moe( + out = torch.ops.auto_deploy.torch_moe( x, selected_experts, routing_weights, w1_list, w2_list, w3_list ) return out diff --git a/tests/unittest/_torch/auto_deploy/unit/multigpu/custom_ops/test_dist.py b/tests/unittest/_torch/auto_deploy/unit/multigpu/custom_ops/test_dist.py index 43a653a282f..d4c8091158a 100644 --- a/tests/unittest/_torch/auto_deploy/unit/multigpu/custom_ops/test_dist.py +++ b/tests/unittest/_torch/auto_deploy/unit/multigpu/custom_ops/test_dist.py @@ -9,14 +9,14 @@ def _run_all_reduce_test(rank, world_size): x = torch.ones(10, 10).to("cuda") - y = torch.ops.dist.all_reduce(x) + y = torch.ops.auto_deploy.torch_dist_all_reduce(x) assert torch.equal(x * world_size, y) def _run_all_gather_test(rank, world_size): x = torch.ones(10, 10).to("cuda") - y = torch.ops.dist.all_gather(x) + y = torch.ops.auto_deploy.torch_dist_all_gather(x) assert torch.sum(y) == world_size * torch.sum(x) assert y.shape == (world_size * x.shape[0], *x.shape[1:]) diff --git a/tests/unittest/_torch/auto_deploy/unit/multigpu/custom_ops/test_moe_ep.py b/tests/unittest/_torch/auto_deploy/unit/multigpu/custom_ops/test_moe_ep.py index ad7cfe006ea..0e8f84a6d15 100644 --- a/tests/unittest/_torch/auto_deploy/unit/multigpu/custom_ops/test_moe_ep.py +++ b/tests/unittest/_torch/auto_deploy/unit/multigpu/custom_ops/test_moe_ep.py @@ -74,7 +74,7 @@ def get_partition(t: torch.Tensor, world_size: int, rank: int) -> torch.Tensor: final_scales_local = final_scales * rank_mask - output_trt = torch.ops.moe.trtllm_fused_moe( + output_trt = torch.ops.auto_deploy.trtllm_moe_fused( x, selected_experts_local, final_scales_local, diff --git a/tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_allreduce_residual_rmsnorm_fusion.py b/tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_allreduce_residual_rmsnorm_fusion.py index 42317620208..b7a4b5a3668 100644 --- a/tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_allreduce_residual_rmsnorm_fusion.py +++ b/tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_allreduce_residual_rmsnorm_fusion.py @@ -36,7 +36,7 @@ def __init__(self, hidden_size, dtype): self.norm = RMSNorm(hidden_size, 1e-5, dtype) def forward(self, x, residual): - x = torch.ops.dist.all_reduce(x) + x = torch.ops.auto_deploy.torch_dist_all_reduce(x) y = x + residual normed = self.norm(y) return normed, y diff --git a/tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_bmm_sharding.py b/tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_bmm_sharding.py index 3544f1f814e..f6f48072049 100644 --- a/tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_bmm_sharding.py +++ b/tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_bmm_sharding.py @@ 
-64,7 +64,7 @@ def _get_expected_num_params(num_p_og: int) -> int: return num_params # now run the test - op_expected = getattr(torch.ops.dist, "all_gather") + op_expected = getattr(torch.ops.auto_deploy, "torch_dist_all_gather") run_test( model, x, diff --git a/tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_collective_fusion.py b/tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_collective_fusion.py index 7a1f4f8f844..4aa1a875c42 100644 --- a/tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_collective_fusion.py +++ b/tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_collective_fusion.py @@ -26,8 +26,8 @@ def __init__(self, in_features, out_features, bias, cls): self.linear2 = cls(4 * in_features, out_features, bias=bias) def forward(self, x): - y = F.relu(torch.ops.dist.all_reduce(self.linear1(x))) - return torch.ops.dist.all_reduce(self.linear2(y)) + y = F.relu(torch.ops.auto_deploy.torch_dist_all_reduce(self.linear1(x))) + return torch.ops.auto_deploy.torch_dist_all_reduce(self.linear2(y)) def _run_job( @@ -58,7 +58,7 @@ def _get_expected_num_params(num_p_og: int) -> int: def check_transformed_graph(gm): return any(is_op(n, op_expected) for n in gm.graph.nodes) and not any( - is_op(n, torch.ops.dist.all_reduce) for n in gm.graph.nodes + is_op(n, torch.ops.auto_deploy.torch_dist_all_reduce) for n in gm.graph.nodes ) # now run the test @@ -76,10 +76,10 @@ def check_transformed_graph(gm): @pytest.mark.parametrize( "linear_cls, dist_op_expected", ( - (nn.Linear, "linear.fused_linear_all_reduce"), + (nn.Linear, "auto_deploy.trtllm_dist_fused_linear_all_reduce"), pytest.param( FP8Linear, - "quant.fused_fp8_linear_all_reduce", + "auto_deploy.torch_quant_fused_fp8_linear_all_reduce", marks=pytest.mark.skipif(not fp8_compatible(), reason="Requires fp8 support"), ), ), diff --git a/tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_ep_sharding.py b/tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_ep_sharding.py index cfc48bb5d6b..66c76ec835a 100644 --- a/tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_ep_sharding.py +++ b/tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_ep_sharding.py @@ -33,7 +33,7 @@ def _get_expected_num_params(rank: int, world_size: int, num_p_og: int) -> int: expected_expert = num_experts_per_rank * hidden_size * intermediate_size * 3 return n_gate + expected_expert - op_expected = torch.ops.dist.all_reduce + op_expected = torch.ops.auto_deploy.torch_dist_all_reduce run_test( model, diff --git a/tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_graph_sharding.py b/tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_graph_sharding.py index 3ac02d8bd6b..45f673cfff9 100644 --- a/tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_graph_sharding.py +++ b/tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_graph_sharding.py @@ -15,6 +15,50 @@ from tensorrt_llm._torch.auto_deploy.utils.node_utils import is_op +class GQA_Block(nn.Module): + def __init__( + self, + num_attention_heads: int, + hidden_size: int, + num_key_value_heads: int, + ): + super().__init__() + self.num_attention_heads = num_attention_heads + self.hidden_size = hidden_size + self.head_dim = self.hidden_size // self.num_attention_heads + self.num_key_value_heads = num_key_value_heads + 
self.is_gqa = num_key_value_heads < num_attention_heads + assert self.hidden_size == self.num_attention_heads * self.head_dim + + # key, query, value, out projections + self.q_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=False) + self.k_proj = nn.Linear( + self.hidden_size, + self.head_dim * self.num_key_value_heads, + bias=False, + ) + self.v_proj = nn.Linear( + self.hidden_size, + self.head_dim * self.num_key_value_heads, + bias=False, + ) + + self.o_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=False) + + @torch.no_grad() + def forward(self, x: torch.Tensor) -> torch.Tensor: + b, s, _ = x.shape + + q = self.q_proj(x).view(b, s, -1, self.head_dim) + k = self.k_proj(x).view(b, s, -1, self.head_dim) + v = self.v_proj(x).view(b, s, -1, self.head_dim) + + y = torch.ops.auto_deploy.torch_attention_bsnd_grouped_sdpa(q, k, v, is_causal=True) + y = y.contiguous().view(b, s, -1) + + return self.o_proj(y) + + class MLP(nn.Module): def __init__(self, in_features, out_features, bias=False): super().__init__() @@ -37,27 +81,80 @@ def _run_job( ) -> None: # init model and input batch_size = 4 - num_features = 10 - model = model_cls(num_features, num_features, bias=bias).to(device="cuda", dtype=torch.float16) - x = torch.randn(batch_size, num_features, device="cuda", dtype=torch.float16) + sequence_len = 8 + num_features = 32 + + # GQA specific parameters + num_heads = 4 + num_key_value_heads = 1 + + if model_cls == GQA_Block: + model = model_cls( + num_attention_heads=num_heads, + hidden_size=num_features, + num_key_value_heads=num_key_value_heads, + ).to(device="cuda", dtype=torch.float16) + else: + model = model_cls(num_features, num_features, bias=bias).to( + device="cuda", dtype=torch.float16 + ) + x = torch.randn(batch_size, sequence_len, num_features, device="cuda", dtype=torch.float16) + + if model_cls == GQA_Block: + head_dim = num_features // num_heads + min_local_size = head_dim + else: + min_local_size = 1 def _get_expected_num_params(num_p_og: int) -> int: num_update = 0 - if bias and dist_op_expected == "all_reduce": + if bias and dist_op_expected == "torch_dist_all_reduce": num_p_og -= num_features num_update = num_features * (rank == world_size - 1) - num_params = num_p_og // world_size + num_update + if min_local_size > 1: + # We are in the GQA case: W_k and W_v are partially replicated, so we need to count + # the number of parameters manually. 
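The per-rank parameter arithmetic in `_get_expected_num_params` above is hard to follow in flattened diff form. Below is a standalone sketch of the same counting logic, assuming the test's scheme (q/o projections split across ranks, k/v projections sharded at most down to one KV head per rank and otherwise replicated); the helper name `expected_gqa_local_params` is illustrative and not part of the patch.

```python
def expected_gqa_local_params(
    hidden_size: int, num_heads: int, num_kv_heads: int, world_size: int
) -> int:
    """Per-rank weight count for a column/row-sharded GQA block without biases."""
    head_dim = hidden_size // num_heads
    w_q = hidden_size * hidden_size // world_size  # q_proj is column-sharded
    w_o = w_q                                      # o_proj is row-sharded
    # k/v projections shard at most down to one KV head per rank, otherwise replicate
    kv_heads_per_rank = max(num_kv_heads // world_size, 1)
    w_k = hidden_size * head_dim * kv_heads_per_rank
    w_v = w_k
    return w_q + w_k + w_v + w_o


# With the test's settings (hidden_size=32, num_heads=4, num_kv_heads=1, world_size=2):
# 512 (W_q) + 256 (W_k) + 256 (W_v) + 512 (W_o) = 1536 parameters per rank.
print(expected_gqa_local_params(32, 4, 1, 2))
```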
+ W_q_local_size = num_features * num_features // world_size + W_o_local_size = W_q_local_size + W_k_local_size = num_features * head_dim * max(num_key_value_heads // world_size, 1) + W_v_local_size = W_k_local_size + num_params = W_q_local_size + W_k_local_size + W_v_local_size + W_o_local_size + else: + num_params = num_p_og // world_size + num_update return num_params + def verify_local_weight_sizes(gm) -> bool: + """Verify that all weight tensors have first dimension >= min_local_size after sharding.""" + for name, param in gm.named_parameters(): + # Only check parameters that have at least 1 dimension and are weight matrices + if param.dim() >= 1 and "weight" in name: + if param.shape[0] < min_local_size: + print( + f"Weight {name} has shape {param.shape}, dim {param.shape[0]} < min_local_size {min_local_size}" + ) + return False + return True + # now run the test - op_expected = getattr(torch.ops.dist, dist_op_expected) + op_expected = getattr(torch.ops.auto_deploy, dist_op_expected) + + transform_func = partial(column_row_shard, rank=rank, world_size=world_size) + + def combined_graph_check(gm) -> bool: + # Check for expected distributed operations + has_expected_dist_ops = any(is_op(n, op_expected) for n in gm.graph.nodes) == ( + world_size > 1 + ) + # Check weight size constraints + weight_sizes_valid = verify_local_weight_sizes(gm) + return has_expected_dist_ops and weight_sizes_valid + run_test( model, x, - transform=partial(column_row_shard, rank=rank, world_size=world_size), - check_transformed_graph=lambda gm: any(is_op(n, op_expected) for n in gm.graph.nodes) - == (world_size > 1), + transform=transform_func, + check_transformed_graph=combined_graph_check, _get_expected_num_params=_get_expected_num_params, ) @@ -67,8 +164,9 @@ def _get_expected_num_params(num_p_og: int) -> int: @pytest.mark.parametrize( "model_cls, dist_op_expected", ( - (MLP, "all_reduce"), - (nn.Linear, "all_gather"), + (MLP, "torch_dist_all_reduce"), + (nn.Linear, "torch_dist_all_gather"), + (GQA_Block, "torch_dist_all_reduce"), ), ) def test_sharding(model_cls: Type[nn.Module], dist_op_expected: str, bias: bool, device_count: int): diff --git a/tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_ad_moe_op.py b/tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_ad_moe_op.py index 865fe807c40..116126dc925 100644 --- a/tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_ad_moe_op.py +++ b/tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_ad_moe_op.py @@ -49,7 +49,7 @@ def test_moe_op_run(dtype): fused_w2_weight.data[expert_id].copy_(w2) with torch.inference_mode(): - output_torch_moe = torch.ops.moe.torch_moe( + output_torch_moe = torch.ops.auto_deploy.torch_moe( x, selected_experts, final_scales, @@ -57,14 +57,14 @@ def test_moe_op_run(dtype): w2_weight, w3_weight, ) - output_torch_fused_moe = torch.ops.moe.torch_fused_moe( + output_torch_fused_moe = torch.ops.auto_deploy.torch_moe_fused( x, selected_experts, final_scales, fused_w3_w1_stacked_weight, fused_w2_weight, ) - output_trt_fused_moe = torch.ops.moe.trtllm_fused_moe( + output_trt_fused_moe = torch.ops.auto_deploy.trtllm_moe_fused( x, selected_experts, final_scales, diff --git a/tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_attention_op.py b/tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_attention_op.py index 25614a3b042..cfc5ac1891c 100644 --- a/tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_attention_op.py +++ 
b/tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_attention_op.py @@ -21,7 +21,7 @@ def test_attention_op(): q, k, v = (x.contiguous() for x in torch.split(qkv, 1, dim=1)) - output = torch.ops.attention.fused_mha_with_cache( + output = torch.ops.auto_deploy.triton_attention_fused_mha_with_cache( q, k, v, input_positions, k_cache, v_cache, None ) ref = torch.nn.functional.scaled_dot_product_attention( @@ -66,7 +66,7 @@ def test_gqa_op(device, dtype, n_heads, group_size, seq_len): v_cache = torch.randn(BATCH_SIZE, CACHE_SEQ_LEN, n_kv_heads, D_HEAD, dtype=dtype, device=device) # run custom op - output = torch.ops.attention.fused_mha_with_cache( + output = torch.ops.auto_deploy.triton_attention_fused_mha_with_cache( q, k, v, input_positions, k_cache, v_cache, None ) @@ -148,7 +148,7 @@ def test_flat_gqa_op( v = torch.randn(1, seq_len.sum(), n_kv_heads * D_HEAD, **dtype_kwargs) # run op - output = torch.ops.attention.flattened_mha_with_cache( + output = torch.ops.auto_deploy.triton_attention_flattened_mha_with_cache( # Q, K, V q, k, @@ -274,7 +274,7 @@ def test_flat_gqa_op_with_rope( source = 1 if source == 1: # call rope fusion kernels - output = torch.ops.attention.fused_flattened_mha_with_cache_rope_fusion( + output = torch.ops.auto_deploy.triton_attention_fused_flattened_mha_with_cache_rope_fusion( q, k, v, @@ -288,7 +288,7 @@ def test_flat_gqa_op_with_rope( ) else: # call stand-alone rope embedding kernel - output = torch.ops.attention.fused_flattened_mha_with_cache( + output = torch.ops.auto_deploy.triton_attention_fused_flattened_mha_with_cache( q, k, v, @@ -466,7 +466,7 @@ def test_paged_gqa_op( v = torch.randn(1, seq_len.sum(), n_kv_heads * D_HEAD, **dtype_kwargs) # run op - output = torch.ops.attention.fused_mha_with_paged_cache( + output = torch.ops.auto_deploy.triton_attention_fused_mha_with_paged_cache( q, k, v, diff --git a/tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_flashinfer_attention_op.py b/tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_flashinfer_attention_op.py index c1fb8b63de7..4872aef2210 100644 --- a/tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_flashinfer_attention_op.py +++ b/tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_flashinfer_attention_op.py @@ -88,7 +88,7 @@ def test_flashinfer_attention_op_context(seq_length, n_heads, batch_size, dtype, ), BATCH_SIZE * SEQ_LEN, ) - flashinfer_output = torch.ops.attention.flashinfer_mha_with_cache( + flashinfer_output = torch.ops.auto_deploy.flashinfer_attention_mha_with_cache( # Q, K, V q, k, @@ -213,7 +213,7 @@ def test_flashinfer_attention_op_decode( ), BATCH_SIZE * SEQ_LEN, ) - flashinfer_output = torch.ops.attention.flashinfer_mha_with_cache( + flashinfer_output = torch.ops.auto_deploy.flashinfer_attention_mha_with_cache( # Q, K, V q, k, @@ -329,7 +329,7 @@ def test_flashinfer_attention_context_and_generate( ), BATCH_SIZE * PREFILL_SEQ_LEN, ) - flashinfer_output_1 = torch.ops.attention.flashinfer_mha_with_cache( + flashinfer_output_1 = torch.ops.auto_deploy.flashinfer_attention_mha_with_cache( # Q, K, V q_1, k_1, @@ -404,7 +404,7 @@ def test_flashinfer_attention_context_and_generate( ), BATCH_SIZE * 1, ) - flashinfer_output_3 = torch.ops.attention.flashinfer_mha_with_cache( + flashinfer_output_3 = torch.ops.auto_deploy.flashinfer_attention_mha_with_cache( # Q, K, V q_3, k_3, @@ -513,7 +513,7 @@ def test_flashinfer_attention_op_context_input_pos(seq, batch_size, n_heads, dty ), BATCH_SIZE * SEQ_LEN, ) - flashinfer_output = 
torch.ops.attention.flashinfer_mha_with_cache( + flashinfer_output = torch.ops.auto_deploy.flashinfer_attention_mha_with_cache( # Q, K, V q, k, @@ -660,7 +660,7 @@ def test_flashinfer_attention_with_fp8_cache( ), BATCH_SIZE * SEQ_LEN, ) - flashinfer_output = torch.ops.attention.flashinfer_mha_with_cache( + flashinfer_output = torch.ops.auto_deploy.flashinfer_attention_mha_with_cache( # Q, K, V q, k, @@ -757,7 +757,7 @@ def test_flashinfer_attention_with_paged_kvcache(seq_lengths, n_heads, dtype, de ), BATCH_SIZE * SEQ_LEN, ) - flashinfer_output = torch.ops.attention.flashinfer_mha_with_cache( + flashinfer_output = torch.ops.auto_deploy.flashinfer_attention_mha_with_cache( # Q, K, V q, k, @@ -840,7 +840,7 @@ def test_flashinfer_attention_with_paged_kvcache(seq_lengths, n_heads, dtype, de ), BATCH_SIZE * 1, ) - flashinfer_output_gen = torch.ops.attention.flashinfer_mha_with_cache( + flashinfer_output_gen = torch.ops.auto_deploy.flashinfer_attention_mha_with_cache( # Q, K, V q_gen, k_gen, diff --git a/tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_quant.py b/tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_quant.py index bb496185a6b..ae2bfe6bf6a 100644 --- a/tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_quant.py +++ b/tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_quant.py @@ -20,7 +20,7 @@ def test_fp8_linear(bias): weight_scale = (torch.max(torch.abs(weight)) / 448).to("cuda") weight_fp8 = (weight / weight_scale).to(torch.float8_e4m3fn) - output_fp8_gemm = torch.ops.quant.fp8_linear( + output_fp8_gemm = torch.ops.auto_deploy.torch_quant_fp8_linear( input, weight_fp8, bias=bias, @@ -49,7 +49,7 @@ def test_fp4_linear(): weight, weight_scale_2, scaling_vector_size, False ) - output_fp4_gemm = torch.ops.quant.fp4_linear( + output_fp4_gemm = torch.ops.auto_deploy.torch_quant_fp4_linear( input, weight_fp4, bias=None, diff --git a/tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_rope_op_variants.py b/tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_rope_op_variants.py index ee7bca3f6bd..bf00ecf0acd 100644 --- a/tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_rope_op_variants.py +++ b/tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_rope_op_variants.py @@ -81,7 +81,7 @@ def test_flashinfer_custom_op_and_hf_impl(dtype, atol, rtol, head_dim): # Custom op call positions_flat = torch.arange(batch * seq_len, device=device) - custom_q, custom_k = torch.ops.rope.flashinfer( + custom_q, custom_k = torch.ops.auto_deploy.flashinfer_rope( query, key, positions_flat, cos_sin_cache_expand, True ) @@ -135,7 +135,7 @@ def test_flashinfer_custom_op_and_complex_impl(dtype, atol, rtol, head_dim): # q/k of llama4 rope is interleaved positions_flat = torch.arange(batch * seq_len, device=device) - custom_q, custom_k = torch.ops.rope.flashinfer( + custom_q, custom_k = torch.ops.auto_deploy.flashinfer_rope( query, key, positions_flat, cos_sin_cache_expand, False ) @@ -211,8 +211,8 @@ def test_triton_custom_op_and_hf_impl(layout, head_dim, dtype, atol, rtol): q_hf = q_f32.to(dtype) k_hf = k_f32.to(dtype) - q_out = torch.ops.rope.apply_rope_with_input_pos(q, cosin_cache, positions, layout) - k_out = torch.ops.rope.apply_rope_with_input_pos(k, cosin_cache, positions, layout) + q_out = torch.ops.auto_deploy.triton_rope_with_input_pos(q, cosin_cache, positions, layout) + k_out = torch.ops.auto_deploy.triton_rope_with_input_pos(k, cosin_cache, positions, layout) 
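The RoPE tests above compare the renamed custom ops (`torch.ops.auto_deploy.flashinfer_rope`, `torch.ops.auto_deploy.triton_rope_with_input_pos`) against an explicit cos/sin reference. For orientation only, here is a generic sketch of that standard formulation; it is not the tests' exact helper.

```python
import torch


def rotate_half(x: torch.Tensor) -> torch.Tensor:
    """Rotate the last dimension: (x1, x2) -> (-x2, x1)."""
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)


def rope_reference(q: torch.Tensor, k: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor):
    """Standard explicit cos/sin RoPE; cos/sin must broadcast against q and k."""
    q_out = q * cos + rotate_half(q) * sin
    k_out = k * cos + rotate_half(k) * sin
    return q_out, k_out
```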
torch.testing.assert_close(q_hf, q_out, atol=atol, rtol=rtol) torch.testing.assert_close(k_hf, k_out, atol=atol, rtol=rtol) diff --git a/tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/triton_kernels/test_rope.py b/tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/triton_kernels/test_rope.py index 461974d7ee5..f7a8a5972e3 100644 --- a/tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/triton_kernels/test_rope.py +++ b/tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/triton_kernels/test_rope.py @@ -34,7 +34,7 @@ def test_rope(d_head): y_ref = torch_rope_reference(x, freqs_cis, input_position) freqs_cis = freqs_cis.to("cuda") x_reshaped = x.unflatten(-1, (N_ELEM // 2, 2)).transpose(-1, -2).flatten(-2).contiguous() - y = torch.ops.rope.apply_rope_with_input_pos( + y = torch.ops.auto_deploy.triton_rope_with_input_pos( x_reshaped.to("cuda"), freqs_cis, input_position, "bsnd" ) y_reshaped = y.unflatten(-1, (2, N_ELEM // 2)).transpose(-2, -1).flatten(-2).contiguous() @@ -64,7 +64,7 @@ def test_rope_flattened(d_head): seq_start_indices = torch.zeros(len(SEQ_LENS), dtype=torch.int32, device="cuda") seq_start_indices[1:] = torch.cumsum(seq_lens[:-1], 0) - y = torch.ops.rope.apply_rope_on_flattened_inputs( + y = torch.ops.auto_deploy.triton_rope_on_flattened_inputs( x_reshaped.to("cuda"), freqs_cis, input_position, seq_lens, seq_start_indices ) y_reshaped = y.unflatten(-1, (2, N_ELEM // 2)).transpose(-2, -1).flatten(-2).contiguous() diff --git a/tests/unittest/_torch/auto_deploy/unit/singlegpu/models/test_sdpa_mla.py b/tests/unittest/_torch/auto_deploy/unit/singlegpu/models/test_sdpa_mla.py index df2de5ecbb0..ffa2594d908 100644 --- a/tests/unittest/_torch/auto_deploy/unit/singlegpu/models/test_sdpa_mla.py +++ b/tests/unittest/_torch/auto_deploy/unit/singlegpu/models/test_sdpa_mla.py @@ -55,7 +55,9 @@ def mla_attn(q_nope, q_pe, compressed_kv, k_pe, wkv_b, softmax_scale): q_nope_proj = torch.einsum("bhsd,hdc->bhsc", q_nope, wkv_b_weight[:, :qk_nope_head_dim]) # MLA ref operation - x = torch.ops.deepseek.mla(q_nope_proj, q_pe, compressed_kv, k_pe, None, softmax_scale) + x = torch.ops.auto_deploy.torch_attention_deepseek_mla( + q_nope_proj, q_pe, compressed_kv, k_pe, None, softmax_scale + ) # Up project attention scores x = torch.einsum("bshc,hdc->bshd", x, wkv_b_weight[:, -v_head_dim:]) diff --git a/tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_attention_matcher.py b/tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_attention_matcher.py index e14ad05d693..5e6133e581b 100644 --- a/tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_attention_matcher.py +++ b/tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_attention_matcher.py @@ -379,8 +379,8 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: # Manually apply repeat_kv to k and v if self.num_kv_heads != self.num_heads: - k = torch.ops.attention.repeat_kv(k, self.n_rep) - v = torch.ops.attention.repeat_kv(v, self.n_rep) + k = torch.ops.auto_deploy.torch_attention_repeat_kv(k, self.n_rep) + v = torch.ops.auto_deploy.torch_attention_repeat_kv(v, self.n_rep) # Create attention mask if needed attn_mask = None @@ -396,7 +396,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: ).masked_fill(mask, float("-inf")) # Apply scaled dot product attention - attn_output = torch.ops.attention.scaled_dot_product_attention( + attn_output = torch.ops.auto_deploy.torch_attention_sdpa( q, k, v, @@ -434,7 +434,9 
@@ def test_match_repeat_kv(num_heads, num_kv_heads, model_cls): expected_matches = 0 if num_heads == num_kv_heads else 2 def verify_matcher(gm): - repeat_kv_nodes = [n for n in gm.graph.nodes if is_op(n, torch.ops.attention.repeat_kv)] + repeat_kv_nodes = [ + n for n in gm.graph.nodes if is_op(n, torch.ops.auto_deploy.torch_attention_repeat_kv) + ] # Check that we have the expected number of replacements if len(repeat_kv_nodes) != expected_matches: @@ -549,7 +551,7 @@ def test_match_eager_attention(has_mask, use_division, dropout, rtol, atol, mode def verify_matcher(gm): sdpa_nodes = [ - n for n in gm.graph.nodes if is_op(n, torch.ops.attention.scaled_dot_product_attention) + n for n in gm.graph.nodes if is_op(n, torch.ops.auto_deploy.torch_attention_sdpa) ] # Check that we have the expected number of replacements @@ -655,8 +657,10 @@ def test_counter_example(): dynamic_shapes = model.get_dynamic_shapes() def verify_no_matches(gm): - # No nodes should be replaced with torch.ops.attention.repeat_kv - repeat_kv_nodes = [n for n in gm.graph.nodes if is_op(n, torch.ops.attention.repeat_kv)] + # No nodes should be replaced with torch.ops.auto_deploy.torch_attention_repeat_kv + repeat_kv_nodes = [ + n for n in gm.graph.nodes if is_op(n, torch.ops.auto_deploy.torch_attention_repeat_kv) + ] return len(repeat_kv_nodes) == 0 # Ensure the pattern matcher doesn't match our counter-examples @@ -693,7 +697,9 @@ def test_match_grouped_attention(num_heads, num_kv_heads, has_mask): def verify_matcher(gm): grouped_sdpa_nodes = [ - n for n in gm.graph.nodes if is_op(n, torch.ops.attention.grouped_sdpa) + n + for n in gm.graph.nodes + if is_op(n, torch.ops.auto_deploy.torch_attention_grouped_sdpa) ] # Check that we have the expected number of replacements @@ -790,8 +796,8 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: # For grouped attention, repeat k and v if self.use_grouped_sdpa and self.num_kv_heads != self.num_heads: n_rep = self.num_heads // self.num_kv_heads - k = torch.ops.attention.repeat_kv(k, n_rep) - v = torch.ops.attention.repeat_kv(v, n_rep) + k = torch.ops.auto_deploy.torch_attention_repeat_kv(k, n_rep) + v = torch.ops.auto_deploy.torch_attention_repeat_kv(v, n_rep) # Create attention mask based on mask_type if self.mask_type == "triu": @@ -830,7 +836,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: # Choose the appropriate attention implementation if self.use_grouped_sdpa: - attn_output = torch.ops.attention.grouped_sdpa( + attn_output = torch.ops.auto_deploy.torch_attention_grouped_sdpa( q, k, v, @@ -840,7 +846,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: scale=1.0 / (self.head_dim**0.5), ) else: - attn_output = torch.ops.attention.scaled_dot_product_attention( + attn_output = torch.ops.auto_deploy.torch_attention_sdpa( q, k, v, @@ -886,12 +892,14 @@ def test_match_causal_attention(mask_type, use_grouped_sdpa): def verify_matcher(gm): # Find attention operations if use_grouped_sdpa: - attn_nodes = [n for n in gm.graph.nodes if is_op(n, torch.ops.attention.grouped_sdpa)] - else: attn_nodes = [ n for n in gm.graph.nodes - if is_op(n, torch.ops.attention.scaled_dot_product_attention) + if is_op(n, torch.ops.auto_deploy.torch_attention_grouped_sdpa) + ] + else: + attn_nodes = [ + n for n in gm.graph.nodes if is_op(n, torch.ops.auto_deploy.torch_attention_sdpa) ] if len(attn_nodes) != 1: @@ -990,8 +998,8 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: # For grouped attention, repeat k and v if self.use_grouped_sdpa and self.num_kv_heads != self.num_heads: n_rep 
= self.num_heads // self.num_kv_heads - k = torch.ops.attention.repeat_kv(k, n_rep) - v = torch.ops.attention.repeat_kv(v, n_rep) + k = torch.ops.auto_deploy.torch_attention_repeat_kv(k, n_rep) + v = torch.ops.auto_deploy.torch_attention_repeat_kv(v, n_rep) # Create a llama-3.1 style causal mask # 1. Create a full tensor with a very negative value @@ -1026,7 +1034,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: # Choose the appropriate attention implementation if self.use_grouped_sdpa: - attn_output = torch.ops.attention.grouped_sdpa( + attn_output = torch.ops.auto_deploy.torch_attention_grouped_sdpa( q, k, v, @@ -1036,7 +1044,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: scale=1.0 / (self.head_dim**0.5), ) else: - attn_output = torch.ops.attention.scaled_dot_product_attention( + attn_output = torch.ops.auto_deploy.torch_attention_sdpa( q, k, v, @@ -1078,12 +1086,14 @@ def test_match_llama3_causal_attention(use_grouped_sdpa): def verify_matcher(gm): # Find attention operations if use_grouped_sdpa: - attn_nodes = [n for n in gm.graph.nodes if is_op(n, torch.ops.attention.grouped_sdpa)] - else: attn_nodes = [ n for n in gm.graph.nodes - if is_op(n, torch.ops.attention.scaled_dot_product_attention) + if is_op(n, torch.ops.auto_deploy.torch_attention_grouped_sdpa) + ] + else: + attn_nodes = [ + n for n in gm.graph.nodes if is_op(n, torch.ops.auto_deploy.torch_attention_sdpa) ] if len(attn_nodes) != 1: @@ -1129,7 +1139,7 @@ class MockAttentionDescriptor: """A mock class that mimics the AttentionDescriptor interface for testing.""" layout: str = "bnsd" - source_attention_op: Callable = torch.ops.attention.scaled_dot_product_attention + source_attention_op: Callable = torch.ops.auto_deploy.torch_attention_sdpa @classmethod def get_attention_layout(cls) -> str: @@ -1199,7 +1209,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: # Apply scaled dot product attention if self.use_grouped_sdpa: - attn_output = torch.ops.attention.grouped_sdpa( + attn_output = torch.ops.auto_deploy.torch_attention_grouped_sdpa( q, k, v, @@ -1209,7 +1219,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: scale=1.0 / (self.head_dim**0.5), ) else: - attn_output = torch.ops.attention.scaled_dot_product_attention( + attn_output = torch.ops.auto_deploy.torch_attention_sdpa( q, k, v, @@ -1246,7 +1256,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: attn_mask = self._get_attn_mask(x) if self.has_mask else None # Apply bsnd_grouped_sdpa directly - attn_output = torch.ops.attention.bsnd_grouped_sdpa.default( + attn_output = torch.ops.auto_deploy.torch_attention_bsnd_grouped_sdpa.default( q, k, v, @@ -1284,11 +1294,11 @@ def test_match_attention_layout(layout, model_config, has_mask): MockAttentionDescriptor.layout = layout if layout == "bnsd": if model_config.get("use_grouped_sdpa"): - source_op = torch.ops.attention.grouped_sdpa + source_op = torch.ops.auto_deploy.torch_attention_grouped_sdpa else: - source_op = torch.ops.attention.scaled_dot_product_attention + source_op = torch.ops.auto_deploy.torch_attention_sdpa else: - source_op = torch.ops.attention.bsnd_grouped_sdpa + source_op = torch.ops.auto_deploy.torch_attention_bsnd_grouped_sdpa MockAttentionDescriptor.source_attention_op = source_op # Create appropriate model based on model_config @@ -1319,18 +1329,24 @@ def verify_matcher(gm): if model_config["type"] == "standard": if model_config["use_grouped_sdpa"]: original_nodes = [ - n for n in gm.graph.nodes if is_op(n, torch.ops.attention.grouped_sdpa) + n + for n in gm.graph.nodes + 
if is_op(n, torch.ops.auto_deploy.torch_attention_grouped_sdpa) ] else: original_nodes = [ n for n in gm.graph.nodes - if is_op(n, torch.ops.attention.scaled_dot_product_attention) + if is_op(n, torch.ops.auto_deploy.torch_attention_sdpa) ] else: # already_bsnd original_nodes = [] - bsnd_nodes = [n for n in gm.graph.nodes if is_op(n, torch.ops.attention.bsnd_grouped_sdpa)] + bsnd_nodes = [ + n + for n in gm.graph.nodes + if is_op(n, torch.ops.auto_deploy.torch_attention_bsnd_grouped_sdpa) + ] transpose_nodes = [n for n in gm.graph.nodes if is_op(n, torch.ops.aten.transpose.int)] # Different expectations based on model type and layout diff --git a/tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_attention_matcher_hf.py b/tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_attention_matcher_hf.py index 51de9795fac..cff1fdbb094 100644 --- a/tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_attention_matcher_hf.py +++ b/tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_attention_matcher_hf.py @@ -31,7 +31,7 @@ def get_attention_layout(cls) -> str: @classmethod def get_source_attention_op(cls) -> Callable: - return torch.ops.attention.bsnd_grouped_sdpa + return torch.ops.auto_deploy.torch_attention_bsnd_grouped_sdpa class HFWrapper(nn.Module): @@ -83,9 +83,11 @@ def test_match_llama_attention(config: Dict[str, Any], attn_implementation: str) ) def verify_matcher(gm: GraphModule): - """Ensure that there is exactly one torch.ops.attention.bsnd_grouped_sdpa call in the graph.""" + """Ensure that there is exactly one torch.ops.auto_deploy.torch_attention_bsnd_grouped_sdpa + call in the graph. Also check that there is no repeat_kv pattern left. + """ nodes = gm.graph.find_nodes( - op="call_function", target=torch.ops.attention.bsnd_grouped_sdpa + op="call_function", target=torch.ops.auto_deploy.torch_attention_bsnd_grouped_sdpa ) assert len(nodes) == 1, "Expected exactly one bsnd_grouped_sdpa call in the graph" @@ -100,7 +102,9 @@ def verify_matcher(gm: GraphModule): assert attn_node.args[6] == scale # scale # TODO: check that there is no repeat_kv pattern left... 
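The matcher tests above all share one pattern: run the transformation, then scan the exported FX graph for the expected `torch.ops.auto_deploy.*` target. A minimal sketch of such a check is shown below; `count_op` is a hypothetical helper, whereas the real tests use `is_op` from `node_utils`, which additionally resolves op overloads.

```python
from torch.fx import GraphModule


def count_op(gm: GraphModule, target) -> int:
    """Count call_function nodes whose target is exactly the given op overload."""
    return sum(
        1
        for n in gm.graph.nodes
        if n.op == "call_function" and n.target == target
    )


# Example: after the attention matcher runs, exactly one fused SDPA node is expected,
# e.g. count_op(gm, torch.ops.auto_deploy.torch_attention_bsnd_grouped_sdpa.default) == 1
```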
- nodes = gm.graph.find_nodes(op="call_function", target=torch.ops.attention.repeat_kv) + nodes = gm.graph.find_nodes( + op="call_function", target=torch.ops.auto_deploy.torch_attention_repeat_kv + ) assert len(nodes) == 0, "Found repeat_kv pattern in the graph" return True diff --git a/tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_kv_cache.py b/tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_kv_cache.py index 3b0b3dd39e9..64c5cd127a5 100644 --- a/tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_kv_cache.py +++ b/tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_kv_cache.py @@ -53,7 +53,9 @@ def forward(self, x: torch.Tensor, position_ids: Optional[torch.Tensor] = None) v = v.view(b, s, self.num_kv_heads, self.head_dim) # Use grouped SDPA in bsnd layout - attn_output = torch.ops.attention.bsnd_grouped_sdpa(q, k, v, None, 0.0, True, None) + attn_output = torch.ops.auto_deploy.torch_attention_bsnd_grouped_sdpa( + q, k, v, None, 0.0, True, None + ) # SDPA output is already in [b, s, n, h_d] format # Reshape to [b, s, n*h_d] diff --git a/tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_moe_fusion.py b/tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_moe_fusion.py index 4e507312981..ece6788217f 100644 --- a/tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_moe_fusion.py +++ b/tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_moe_fusion.py @@ -100,7 +100,7 @@ def test_moe_matching(): model, x, match_moe_pattern, - lambda gm: any(is_op(n, torch.ops.moe.torch_moe) for n in gm.graph.nodes), + lambda gm: any(is_op(n, torch.ops.auto_deploy.torch_moe) for n in gm.graph.nodes), lambda num_p_og: num_p_og, atol=1e-3, rtol=1e-3, @@ -119,7 +119,9 @@ def test_moe_fusion(): x, fuse_moe, lambda gm: any( - is_op(n, {torch.ops.moe.torch_fused_moe, torch.ops.moe.trtllm_fused_moe}) + is_op( + n, {torch.ops.auto_deploy.torch_moe_fused, torch.ops.auto_deploy.trtllm_moe_fused} + ) for n in gm.graph.nodes ), lambda num_p_og: num_p_og, diff --git a/tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_quantization.py b/tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_quantization.py index ca49a0fc7a8..bea8c2a0cb3 100644 --- a/tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_quantization.py +++ b/tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_quantization.py @@ -14,7 +14,10 @@ def check_quantized(gm): - op_expected = {torch.ops.quant.fp8_linear, torch.ops.quant.fp4_linear} + op_expected = { + torch.ops.auto_deploy.torch_quant_fp8_linear, + torch.ops.auto_deploy.torch_quant_fp4_linear, + } return any(is_op(n, op_expected) for n in gm.graph.nodes) diff --git a/tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_rope_transformation.py b/tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_rope_transformation.py index 693ea61611f..6b6913311dc 100644 --- a/tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_rope_transformation.py +++ b/tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_rope_transformation.py @@ -91,7 +91,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: if self.mode == "match": q_out, k_out = apply_rotary_pos_emb_explicit(q, k, cos, sin, unsq_dim) else: # 
optimize - q_out, k_out = torch.ops.rope.torch_apply_rope_with_explicit_cos_sin( + q_out, k_out = torch.ops.auto_deploy.torch_rope_with_explicit_cos_sin( q, k, cos, sin, unsq_dim ) @@ -119,7 +119,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: if self.mode == "match": q_out, k_out = apply_rotary_pos_emb_complex(q, k, freqs, unsq_dim) else: - q_out, k_out = torch.ops.rope.torch_apply_rope_with_complex_freqs( + q_out, k_out = torch.ops.auto_deploy.torch_rope_with_complex_freqs( q, k, freqs, unsq_dim ) @@ -212,9 +212,9 @@ def test_rope_variants( if transformation == "match": fn = match_rope_pattern check_op = ( - torch.ops.rope.torch_apply_rope_with_explicit_cos_sin + torch.ops.auto_deploy.torch_rope_with_explicit_cos_sin if variant == "explicit" or variant == "explicit_pm" - else torch.ops.rope.torch_apply_rope_with_complex_freqs + else torch.ops.auto_deploy.torch_rope_with_complex_freqs ) def checker(gm): @@ -228,8 +228,8 @@ def checker(gm): if is_op( n, { - torch.ops.rope.torch_apply_rope_with_explicit_cos_sin, - torch.ops.rope.torch_apply_rope_with_complex_freqs, + torch.ops.auto_deploy.torch_rope_with_explicit_cos_sin, + torch.ops.auto_deploy.torch_rope_with_complex_freqs, }, ): q_arg, k_arg, *rest = n.args @@ -254,7 +254,7 @@ def checker(gm): fn = optimize_rope def checker(gm): - return any(is_op(n, torch.ops.rope.flashinfer) for n in gm.graph.nodes) + return any(is_op(n, torch.ops.auto_deploy.flashinfer_rope) for n in gm.graph.nodes) if transformation == "match_layout": _ = run_test( @@ -346,7 +346,7 @@ def forward(self, x): else: cos = cos[pos_ids] sin = sin[pos_ids] - q_out, k_out = torch.ops.rope.torch_apply_rope_with_qk_interleaving( + q_out, k_out = torch.ops.auto_deploy.torch_rope_with_qk_interleaving( q, k, cos, sin, unsq_dim ) if self.layout == "BNSD": @@ -387,7 +387,7 @@ def test_match_and_layout_deepseek(layout, num_heads, num_kv_heads, mode, target def checker(gm): return any( - is_op(n, torch.ops.rope.torch_apply_rope_with_qk_interleaving) + is_op(n, torch.ops.auto_deploy.torch_rope_with_qk_interleaving) for n in gm.graph.nodes ) @@ -396,7 +396,7 @@ def checker(gm): def checker(gm): for n in gm.graph.nodes: - if is_op(n, torch.ops.rope.torch_apply_rope_with_qk_interleaving): + if is_op(n, torch.ops.auto_deploy.torch_rope_with_qk_interleaving): q_arg, k_arg, *rest = n.args if not ( is_op(q_arg, torch.ops.aten.contiguous)
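Taken together, this patch applies a mechanical rename from the old per-backend namespaces to the unified `torch.ops.auto_deploy` namespace. The mapping below is a non-exhaustive convenience sketch, covering only renames visible in the hunks above, that out-of-tree call sites could use when migrating; it is not shipped with the patch.

```python
# Old op path -> new op path (only renames that appear in the hunks above).
AUTO_DEPLOY_OP_RENAMES = {
    "torch.ops.dist.all_reduce": "torch.ops.auto_deploy.torch_dist_all_reduce",
    "torch.ops.dist.all_gather": "torch.ops.auto_deploy.torch_dist_all_gather",
    "torch.ops.linear.simple": "torch.ops.auto_deploy.torch_linear_simple",
    "torch.ops.linear.fused_linear_all_reduce": "torch.ops.auto_deploy.trtllm_dist_fused_linear_all_reduce",
    "torch.ops.quant.fp8_linear": "torch.ops.auto_deploy.torch_quant_fp8_linear",
    "torch.ops.quant.fp4_linear": "torch.ops.auto_deploy.torch_quant_fp4_linear",
    "torch.ops.quant.fused_fp8_linear_all_reduce": "torch.ops.auto_deploy.torch_quant_fused_fp8_linear_all_reduce",
    "torch.ops.moe.torch_moe": "torch.ops.auto_deploy.torch_moe",
    "torch.ops.moe.torch_fused_moe": "torch.ops.auto_deploy.torch_moe_fused",
    "torch.ops.moe.trtllm_fused_moe": "torch.ops.auto_deploy.trtllm_moe_fused",
    "torch.ops.rope.flashinfer": "torch.ops.auto_deploy.flashinfer_rope",
    "torch.ops.rope.apply_rope_with_input_pos": "torch.ops.auto_deploy.triton_rope_with_input_pos",
    "torch.ops.rope.apply_rope_on_flattened_inputs": "torch.ops.auto_deploy.triton_rope_on_flattened_inputs",
    "torch.ops.rope.torch_apply_rope_with_explicit_cos_sin": "torch.ops.auto_deploy.torch_rope_with_explicit_cos_sin",
    "torch.ops.rope.torch_apply_rope_with_complex_freqs": "torch.ops.auto_deploy.torch_rope_with_complex_freqs",
    "torch.ops.rope.torch_apply_rope_with_qk_interleaving": "torch.ops.auto_deploy.torch_rope_with_qk_interleaving",
    "torch.ops.attention.scaled_dot_product_attention": "torch.ops.auto_deploy.torch_attention_sdpa",
    "torch.ops.attention.grouped_sdpa": "torch.ops.auto_deploy.torch_attention_grouped_sdpa",
    "torch.ops.attention.bsnd_grouped_sdpa": "torch.ops.auto_deploy.torch_attention_bsnd_grouped_sdpa",
    "torch.ops.attention.repeat_kv": "torch.ops.auto_deploy.torch_attention_repeat_kv",
    "torch.ops.attention.fused_mha_with_cache": "torch.ops.auto_deploy.triton_attention_fused_mha_with_cache",
    "torch.ops.attention.flashinfer_mha_with_cache": "torch.ops.auto_deploy.flashinfer_attention_mha_with_cache",
    "torch.ops.deepseek.mla": "torch.ops.auto_deploy.torch_attention_deepseek_mla",
}
```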