42 changes: 42 additions & 0 deletions tensorrt_llm/_torch/auto_deploy/custom_ops/README.md
@@ -0,0 +1,42 @@
## AutoDeploy Custom Operators

All AutoDeploy custom operators follow this naming convention:

`torch.ops.auto_deploy.<kernel_backend>_<op_category>_<op_name>`
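
For example, `torch_quant_fp8_linear` is composed of the backend `torch`, the category `quant`, and the name `fp8_linear`, and is exposed as `torch.ops.auto_deploy.torch_quant_fp8_linear`. The snippet below is a minimal, illustrative sketch of calling one of these ops; it assumes that importing the `custom_ops` package is what registers the operators, and the tensor shapes are arbitrary:

```python
import torch

# Importing the package registers the ops under torch.ops.auto_deploy
# (assumption for this sketch; adjust the import to your setup).
from tensorrt_llm._torch.auto_deploy import custom_ops  # noqa: F401

# Arbitrary shapes: a batch of 4 tokens with hidden size 64, projected to 128.
x = torch.randn(4, 64)
weight = torch.randn(128, 64)

# torch_linear_simple wraps the standard linear functional (see the table below).
y = torch.ops.auto_deploy.torch_linear_simple(x, weight, None)
print(y.shape)  # torch.Size([4, 128])
```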

The table below lists the available operators, ordered by backend.

### Available Custom Operators

| Operator Name | Description |
|--------------|-------------|
| `torch.ops.auto_deploy.flashinfer_attention_mha_with_cache` | FlashInfer attention with KV cache support |
| `torch.ops.auto_deploy.flashinfer_rope` | FlashInfer RoPE (Rotary Position Embedding) implementation |
| `torch.ops.auto_deploy.torch_attention_bsnd_grouped_sdpa` | Grouped SDPA (Scaled Dot Product Attention) with BSND format |
| `torch.ops.auto_deploy.torch_attention_deepseek_fused_mla` | DeepSeek fused MLA (Multi-head Latent Attention) |
| `torch.ops.auto_deploy.torch_attention_deepseek_mla` | DeepSeek MLA implementation |
| `torch.ops.auto_deploy.torch_attention_grouped_sdpa` | Grouped SDPA implementation |
| `torch.ops.auto_deploy.torch_attention_repeat_kv` | KV repetition for attention |
| `torch.ops.auto_deploy.torch_attention_sdpa` | Standard SDPA implementation |
| `torch.ops.auto_deploy.torch_dist_all_gather` | Distributed all-gather operation |
| `torch.ops.auto_deploy.torch_dist_all_reduce` | Distributed all-reduce operation |
| `torch.ops.auto_deploy.torch_linear_simple` | Simple linear layer implementation |
| `torch.ops.auto_deploy.torch_moe` | Mixture of Experts implementation |
| `torch.ops.auto_deploy.torch_moe_fused` | Fused Mixture of Experts implementation |
| `torch.ops.auto_deploy.torch_quant_fn` | Generic quantization function that scales, rounds, and clamps input values |
| `torch.ops.auto_deploy.torch_quant_fused_fp8_linear_all_reduce` | Fused FP8 linear layer followed by all-reduce operation |
| `torch.ops.auto_deploy.torch_quant_fp4_linear` | FP4 quantized linear layer |
| `torch.ops.auto_deploy.torch_quant_fp8_linear` | FP8 quantized linear layer |
| `torch.ops.auto_deploy.torch_rope_with_complex_freqs` | RoPE with complex frequencies |
| `torch.ops.auto_deploy.torch_rope_with_explicit_cos_sin` | RoPE with explicit cosine/sine |
| `torch.ops.auto_deploy.torch_rope_with_qk_interleaving` | RoPE with QK interleaving |
| `torch.ops.auto_deploy.triton_attention_fused_flattened_mha_with_cache` | Triton fused flattened MHA with cache |
| `torch.ops.auto_deploy.triton_attention_fused_flattened_mha_with_cache_rope_fusion` | Triton fused flattened MHA with cache and RoPE fusion |
| `torch.ops.auto_deploy.triton_attention_fused_mha_with_cache` | Triton fused MHA with cache |
| `torch.ops.auto_deploy.triton_attention_fused_mha_with_paged_cache` | Triton fused MHA with paged cache |
| `torch.ops.auto_deploy.triton_attention_flattened_mha_with_cache` | Triton flattened MHA with cache |
| `torch.ops.auto_deploy.triton_attention_fused_flattened_mla_with_cache` | Triton fused flattened Multi-head Latent Attention with cache support |
| `torch.ops.auto_deploy.triton_rope_on_flattened_inputs` | Triton RoPE on flattened inputs |
| `torch.ops.auto_deploy.triton_rope_with_input_pos` | Triton RoPE with input positions |
| `torch.ops.auto_deploy.trtllm_moe_fused` | TensorRT-LLM fused MoE implementation |
| `torch.ops.auto_deploy.trtllm_dist_fused_linear_all_reduce` | TensorRT-LLM fused linear layer followed by all-reduce operation |
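
For completeness, the sketch below shows the registration pattern these operators follow. The op `auto_deploy::torch_demo_scale` and its body are hypothetical and only illustrate the decorator usage; real ops additionally register a fake (meta) implementation so graphs can be traced without running the kernel:

```python
import torch

# Hypothetical op: backend "torch", category "demo", name "scale".
@torch.library.custom_op("auto_deploy::torch_demo_scale", mutates_args=())
def torch_demo_scale(x: torch.Tensor, factor: float) -> torch.Tensor:
    return x * factor

# Fake/meta implementation used during tracing and shape propagation.
@torch_demo_scale.register_fake
def _(x: torch.Tensor, factor: float) -> torch.Tensor:
    return torch.empty_like(x)

out = torch.ops.auto_deploy.torch_demo_scale(torch.ones(2, 3), 2.0)
```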
5 changes: 3 additions & 2 deletions tensorrt_llm/_torch/auto_deploy/custom_ops/__init__.py
@@ -4,11 +4,12 @@
from .dist import *
from .flashinfer_attention import *
from .flashinfer_rope import *
from .fused_moe import *
from .linear import *
from .mla import *
from .quant import *
from .rope import *
from .torch_attention import *
from .torch_moe import *
from .torch_rope import *
from .triton_attention import *
from .triton_rope import *
from .trtllm_moe import *
@@ -168,7 +168,9 @@ def _paged_context_mha(
)


@torch.library.custom_op("attention::fused_mha_with_paged_cache", mutates_args=())
@torch.library.custom_op(
"auto_deploy::triton_attention_fused_mha_with_paged_cache", mutates_args=()
)
def fused_mha_with_paged_cache(
q: torch.Tensor,
k: torch.Tensor,
@@ -210,10 +212,10 @@ def fused_mha_with_paged_cache(
if freqs_cis is not None:
if s == 1:
rope_args = (freqs_cis, input_pos, "bsnd")
fn_rope = torch.ops.rope.apply_rope_with_input_pos
fn_rope = torch.ops.auto_deploy.triton_rope_with_input_pos
else:
rope_args = (freqs_cis, input_pos, seq_len, seq_start)
fn_rope = torch.ops.rope.apply_rope_on_flattened_inputs
fn_rope = torch.ops.auto_deploy.triton_rope_on_flattened_inputs
q = fn_rope(q, *rope_args)
k = fn_rope(k, *rope_args)

@@ -416,7 +418,9 @@ def _flattened_context_mha_rope_fusion(
)


@torch.library.custom_op("attention::fused_flattened_mha_with_cache_rope_fusion", mutates_args=())
@torch.library.custom_op(
"auto_deploy::triton_attention_fused_flattened_mha_with_cache_rope_fusion", mutates_args=()
)
def fused_flattened_mha_with_cache_rope_fusion(
q: torch.Tensor,
k: torch.Tensor,
@@ -541,7 +545,7 @@ def _context_mha(
)


@torch.library.custom_op("attention::fused_mha_with_cache", mutates_args=())
@torch.library.custom_op("auto_deploy::triton_attention_fused_mha_with_cache", mutates_args=())
def fused_mha_with_cache(
q: torch.Tensor,
k: torch.Tensor,
@@ -563,8 +567,8 @@ def fused_mha_with_cache(

# rope embedding
if freqs_cis is not None:
q = torch.ops.rope.apply_rope_with_input_pos(q, freqs_cis, input_pos, "bsnd")
k = torch.ops.rope.apply_rope_with_input_pos(k, freqs_cis, input_pos, "bsnd")
q = torch.ops.auto_deploy.triton_rope_with_input_pos(q, freqs_cis, input_pos, "bsnd")
k = torch.ops.auto_deploy.triton_rope_with_input_pos(k, freqs_cis, input_pos, "bsnd")

# attention (assumed layout is bsnd)
y = torch.empty_like(q)
@@ -593,7 +597,9 @@ def fused_mha_fake(
return torch.empty_like(q.contiguous())


@torch.library.custom_op("attention::fused_flattened_mha_with_cache", mutates_args=())
@torch.library.custom_op(
"auto_deploy::triton_attention_fused_flattened_mha_with_cache", mutates_args=()
)
def fused_flattened_mha_with_cache(
# Q, K, V
q: torch.Tensor,
@@ -638,10 +644,10 @@ def fused_flattened_mha_with_cache(
if freqs_cis.numel() > 0:
if s == 1:
rope_args = (freqs_cis, input_pos, "bsnd")
fn_rope = torch.ops.rope.apply_rope_with_input_pos
fn_rope = torch.ops.auto_deploy.triton_rope_with_input_pos
else:
rope_args = (freqs_cis, input_pos, seq_len, seq_start)
fn_rope = torch.ops.rope.apply_rope_on_flattened_inputs
fn_rope = torch.ops.auto_deploy.triton_rope_on_flattened_inputs
q = fn_rope(q, *rope_args)
k = fn_rope(k, *rope_args)

4 changes: 2 additions & 2 deletions tensorrt_llm/_torch/auto_deploy/custom_ops/dist.py
@@ -8,7 +8,7 @@
from ..distributed import trtllm as trtllm_dist


@torch.library.custom_op("dist::all_gather", mutates_args=(), device_types="cuda")
@torch.library.custom_op("auto_deploy::torch_dist_all_gather", mutates_args=(), device_types="cuda")
def all_gather(
tensor: torch.Tensor, dim: int = 0, sizes: Optional[List[int]] = None
) -> torch.Tensor:
@@ -25,7 +25,7 @@ def all_gather_fake(tensor, dim=0):
return torch.cat([torch.empty_like(tensor) for _ in range(dist.get_world_size())], dim=dim)


@torch.library.custom_op("dist::all_reduce", mutates_args=(), device_types="cuda")
@torch.library.custom_op("auto_deploy::torch_dist_all_reduce", mutates_args=(), device_types="cuda")
def all_reduce(t: torch.Tensor) -> torch.Tensor:
"""All_reduce across the ranks. Reduction op is SUM.

@@ -153,7 +153,7 @@ def _plan_decode(wrapper: flashinfer.BatchDecodeWithPagedKVCacheWrapper):
_GlobalFlashInferPlanner = _FlashInferPlanner()


@torch.library.custom_op("attention::prepare_flashinfer_metadata", mutates_args=())
@torch.library.custom_op("auto_deploy::flashinfer_attention_prepare_metadata", mutates_args=())
def prepare_flashinfer_metadata(
input_ids: torch.Tensor,
position_ids: torch.Tensor,
@@ -228,7 +228,7 @@ def prepare_flashinfer_metadata_fake(
)


@torch.library.custom_op("attention::flashinfer_mha_with_cache", mutates_args=())
@torch.library.custom_op("auto_deploy::flashinfer_attention_mha_with_cache", mutates_args=())
def flashinfer_mha_with_cache(
# Q, K, V
q: torch.Tensor,
@@ -355,15 +355,15 @@ def get_num_qkv_args(cls) -> int:
@classmethod
def get_source_attention_op(cls) -> OpOverloadPacket:
"""Get the source attention op that we target for replacement."""
return torch.ops.attention.bsnd_grouped_sdpa
return torch.ops.auto_deploy.torch_attention_bsnd_grouped_sdpa

@classmethod
def get_cached_attention_op(cls) -> MHACallable:
return torch.ops.attention.flashinfer_mha_with_cache
return torch.ops.auto_deploy.flashinfer_attention_mha_with_cache

@classmethod
def get_prepare_metadata_op(cls) -> Tuple[PrepareMetadataCallable, int]:
return torch.ops.attention.prepare_flashinfer_metadata, 6
return torch.ops.auto_deploy.flashinfer_attention_prepare_metadata, 6

@classmethod
def get_cache_initializers(
@@ -4,7 +4,7 @@
import torch


@torch.library.custom_op("rope::flashinfer", mutates_args=())
@torch.library.custom_op("auto_deploy::flashinfer_rope", mutates_args=())
def apply_rope_with_input_pos_flashinfer(
q: torch.Tensor,
k: torch.Tensor,
6 changes: 4 additions & 2 deletions tensorrt_llm/_torch/auto_deploy/custom_ops/linear.py
@@ -8,7 +8,7 @@
from ..distributed import trtllm as trtllm_dist


@torch.library.custom_op("linear::simple", mutates_args=())
@torch.library.custom_op("auto_deploy::torch_linear_simple", mutates_args=())
def simple(input: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor]) -> torch.Tensor:
"""A wrapper for the linear functional to control how it is exposed.

@@ -30,7 +30,9 @@ def simple_fake(input, weight, bias):
return torch.ops.aten.linear(input, weight, bias)


@torch.library.custom_op("linear::fused_linear_all_reduce", mutates_args=(), device_types="cuda")
@torch.library.custom_op(
"auto_deploy::trtllm_dist_fused_linear_all_reduce", mutates_args=(), device_types="cuda"
)
def fused_linear_all_reduce(
input: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor]
) -> torch.Tensor:
16 changes: 10 additions & 6 deletions tensorrt_llm/_torch/auto_deploy/custom_ops/mla.py
@@ -22,7 +22,9 @@
Constant = Union[int, float, str, None]


@torch.library.custom_op("attention::fused_flattened_mla_with_cache", mutates_args=())
@torch.library.custom_op(
"auto_deploy::triton_attention_fused_flattened_mla_with_cache", mutates_args=()
)
def fused_flattened_mla_with_cache(
# Q, K, V
q_nope: torch.Tensor,
@@ -94,7 +96,7 @@ def fused_flattened_mla_with_cache(
q_slice = q_pe[start : start + length]
k_slice = k_pe[start : start + length]

q_rot, k_rot = torch.ops.rope.torch_apply_rope_with_qk_interleaving(
q_rot, k_rot = torch.ops.auto_deploy.torch_rope_with_qk_interleaving(
q_slice,
k_slice,
cos,
@@ -169,7 +171,9 @@ def fused_flattened_mla_with_cache_fake(
return torch.empty_like(kv[..., -v_head_dim:])


@torch.library.custom_op("attention::prepare_fused_mla_metadata", mutates_args=())
@torch.library.custom_op(
"auto_deploy::triton_attention_prepare_fused_mla_metadata", mutates_args=()
)
def prepare_fused_mla_metadata(
input_ids: torch.Tensor,
position_ids: torch.Tensor,
@@ -221,15 +225,15 @@ def get_num_qkv_args(cls) -> int:

@classmethod
def get_source_attention_op(cls) -> OpOverloadPacket:
return torch.ops.deepseek.fused_mla
return torch.ops.auto_deploy.torch_attention_deepseek_fused_mla

@classmethod
def get_cached_attention_op(cls) -> MHACallable:
return torch.ops.attention.fused_flattened_mla_with_cache
return torch.ops.auto_deploy.triton_attention_fused_flattened_mla_with_cache

@classmethod
def get_prepare_metadata_op(cls) -> Tuple[PrepareMetadataCallable, int]:
return torch.ops.attention.prepare_fused_mla_metadata, 4
return torch.ops.auto_deploy.triton_attention_prepare_fused_mla_metadata, 4

@classmethod
def get_cache_initializers(
25 changes: 16 additions & 9 deletions tensorrt_llm/_torch/auto_deploy/custom_ops/quant.py
@@ -16,7 +16,7 @@
TRTLLM_NVFP4_SCALING_VECTOR_SIZE = 16


@torch.library.custom_op("quant::quant_fn", mutates_args=())
@torch.library.custom_op("auto_deploy::torch_quant_fn", mutates_args=())
def quant_fn(x: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
scaled_x = x / scale
rounded_x = torch.round(scaled_x)
@@ -37,7 +37,7 @@ def __init__(self, scale):
self.register_buffer("scale", torch.tensor(scale))

def forward(self, x: torch.Tensor):
return torch.ops.quant.quant_fn(x, self.scale)
return torch.ops.auto_deploy.torch_quant_fn(x, self.scale)


FP8_MIN = torch.finfo(torch.float8_e4m3fn).min
@@ -50,7 +50,7 @@ def _to_fp8(x, scale):
return (x / scale).clamp(FP8_MIN, FP8_MAX).to(torch.float8_e4m3fn)


@torch.library.custom_op("quant::fp8_linear", mutates_args=())
@torch.library.custom_op("auto_deploy::torch_quant_fp8_linear", mutates_args=())
@torch.compile(dynamic=True)
def fp8_linear(
input: torch.Tensor,
@@ -105,7 +105,7 @@ def fp8_linear_fake(
return torch.ops.aten.linear(input, weight_fp8.to(input.dtype), bias)


@torch.library.custom_op("quant::fused_fp8_linear_all_reduce", mutates_args=())
@torch.library.custom_op("auto_deploy::torch_quant_fused_fp8_linear_all_reduce", mutates_args=())
@torch.compile(dynamic=True)
def fused_fp8_linear_all_reduce(
input: torch.Tensor,
@@ -114,7 +114,9 @@ def fused_fp8_linear_all_reduce(
input_scale: Optional[torch.Tensor] = None,
weight_scale: Optional[torch.Tensor] = None,
) -> torch.Tensor:
out = torch.ops.quant.fp8_linear(input, weight_fp8, bias, input_scale, weight_scale)
out = torch.ops.auto_deploy.torch_quant_fp8_linear(
input, weight_fp8, bias, input_scale, weight_scale
)
if trtllm_dist.is_trtllm_op_available():
return trtllm_dist.trtllm_allreduce(out, op=dist.ReduceOp.SUM)
dist.all_reduce(out, op=dist.ReduceOp.SUM)
@@ -129,7 +131,9 @@ def fused_fp8_linear_all_reduce_fake(
input_scale: Optional[torch.Tensor] = None,
weight_scale: Optional[torch.Tensor] = None,
) -> torch.Tensor:
return torch.ops.quant.fp8_linear(input, weight_fp8, bias, input_scale, weight_scale)
return torch.ops.auto_deploy.torch_quant_fp8_linear(
input, weight_fp8, bias, input_scale, weight_scale
)


class FP8Linear(nn.Linear):
@@ -146,12 +150,12 @@ def __init__(self, *args, **kwargs):
self.bias = nn.Parameter(self.bias.to(torch.half))

def forward(self, x):
return torch.ops.quant.fp8_linear(
return torch.ops.auto_deploy.torch_quant_fp8_linear(
x, self.weight, self.bias, self.input_scale, self.weight_scale
)


@torch.library.custom_op("quant::fp4_linear", mutates_args=())
@torch.library.custom_op("auto_deploy::torch_quant_fp4_linear", mutates_args=())
@torch.compile(dynamic=True)
def fp4_linear(
input: torch.Tensor,
@@ -218,4 +222,7 @@ def fp4_linear_fake(
return torch.ops.aten.linear(input, weight_fp4.repeat(1, 2).to(input.dtype), bias)


QUANT_OPS = [torch.ops.quant.fp8_linear, torch.ops.quant.fp4_linear]
QUANT_OPS = [
torch.ops.auto_deploy.torch_quant_fp8_linear,
torch.ops.auto_deploy.torch_quant_fp4_linear,
]