Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 2 additions & 3 deletions cpp/tensorrt_llm/thop/fusedQKNormRopeOp.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -75,9 +75,8 @@ void fused_qk_norm_rope(
TORCH_LIBRARY_FRAGMENT(trtllm, m)
{
    // Schema-only registration: the CUDA implementation is registered
    // separately (see the TORCH_LIBRARY_IMPL below this fragment).
    // `Tensor(a!)` marks `qkv` as mutated in place, which lets the
    // dispatcher/compiler (e.g. torch.compile) track the aliasing/mutation
    // instead of assuming the op is functional.
    m.def(
        "fused_qk_norm_rope(Tensor(a!) qkv, int num_heads_q, int num_heads_k, int num_heads_v, int head_dim, float "
        "eps, Tensor q_weight, Tensor k_weight, float base, bool is_neox, Tensor position_ids) -> ()");
}

// Register the CUDA implementation
Expand Down
2 changes: 1 addition & 1 deletion cpp/tensorrt_llm/thop/renormMoeRoutingOp.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ std::tuple<at::Tensor, at::Tensor> renorm_moe_routing_op(th::Tensor const& route
TORCH_LIBRARY_FRAGMENT(trtllm, m)
{
    // Schema-only registration; the implementation is bound elsewhere.
    // `topk` is declared as SymInt so the op accepts symbolic integers
    // under torch.compile / dynamic shapes instead of forcing a
    // concrete Python int (which would guard/specialize the graph).
    m.def(
        "renorm_moe_routing_op(Tensor router_logits, SymInt topk"
        ") -> (Tensor, Tensor)");
}

Expand Down
3 changes: 3 additions & 0 deletions tensorrt_llm/_torch/compilation/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,9 @@ def inplace_info():
},
torch.ops.trtllm.mla_custom_op_inplace.default: {
1: "output"
},
torch.ops.trtllm.fused_qk_norm_rope.default: {
1: "qkv"
}
}
return inplace_map
8 changes: 8 additions & 0 deletions tensorrt_llm/_torch/custom_ops/cpp_custom_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -523,3 +523,11 @@ def _(input, residual, norm_weight, expanded_idx_to_permuted_idx,
torch.empty_like(residual),
torch.empty_like(residual),
]

@torch.library.register_fake("trtllm::renorm_moe_routing_op")
def _(router_logits, topk):
    # Fake (meta) kernel: produce correctly-shaped/typed empty outputs for
    # tracing and shape inference; no routing computation happens here.
    out_shape = (router_logits.shape[0], topk)
    # NOTE(review): presumably (indices, weights) per token — confirm
    # against the CUDA implementation's output order.
    selected = router_logits.new_empty(out_shape, dtype=torch.int32)
    scores = router_logits.new_empty(out_shape, dtype=torch.float32)
    return selected, scores