diff --git a/tensorrt_llm/_torch/models/modeling_deepseekv3.py b/tensorrt_llm/_torch/models/modeling_deepseekv3.py
index cf404042497..5fdcb43be3a 100644
--- a/tensorrt_llm/_torch/models/modeling_deepseekv3.py
+++ b/tensorrt_llm/_torch/models/modeling_deepseekv3.py
@@ -279,10 +279,16 @@ def __init__(
         self.routed_scaling_factor = routed_scaling_factor
         self.is_fused = is_fused
 
-    def noaux_tc(self, logits, e_score_correction_bias):
-        n_group = self.n_group
+    @torch.compile(options={"max-autotune": True})
+    def get_scores(self, logits, e_score_correction_bias):
         scores = F.sigmoid(logits)
         scores_with_bias = scores + e_score_correction_bias
+        return scores, scores_with_bias
+
+    def noaux_tc(self, logits, e_score_correction_bias):
+        n_group = self.n_group
+        scores, scores_with_bias = self.get_scores(logits,
+                                                   e_score_correction_bias)
         scores_shape = list(scores_with_bias.shape)
 
         if enable_llm_debug():
diff --git a/tensorrt_llm/_torch/modules/attention.py b/tensorrt_llm/_torch/modules/attention.py
index aa5183c1c14..166489b29b9 100644
--- a/tensorrt_llm/_torch/modules/attention.py
+++ b/tensorrt_llm/_torch/modules/attention.py
@@ -555,33 +555,7 @@ def fp8_block_scaling_bmm_out(
         torch.ops.trtllm.fp8_block_scaling_bmm_out(mat1_fp8, mat2_fp8,
                                                    mat1_scale, mat2_scale, out)
     elif sm_version == 100:
-        output = torch.bmm(mat1.transpose(0, 1), mat2_dequant.transpose(1, 2))
-        out.copy_(output)
-
-        # low_latency = True
-        # use_deep_seek_fp8 = True
-        # tile_size = 8
-        # epilogue_tile_m = 64 if use_deep_seek_fp8 else 128
-        # m_size = mat1.shape[0]
-        # if m_size % tile_size != 0:
-        #     tiled_shape = ((m_size + tile_size - 1) // tile_size) * tile_size
-        #     mat1 = torch.nn.functional.pad(
-        #         mat1, (0, 0, 0, 0, 0, tiled_shape - m_size), "constant", 0)
-
-        # mat1_fp8, mat1_scale = torch.ops.trtllm.fp8_batched_quantize_1x128_permute102(
-        #     mat1)
-        # output, output_sf = torch.ops.trtllm.fp8_batched_gemm_trtllmgen(
-        #     mat1_fp8,
-        #     mat2_fp8,
-        #     tile_size=tile_size,
-        #     epilogue_tile_m=epilogue_tile_m,
-        #     use_deep_seek_fp8=use_deep_seek_fp8,
-        #     low_latency=low_latency,
-        #     dq_sfs_a=mat1_scale.reshape(mat1.shape[-1] // 128, -1),
-        #     dq_sfs_b=mat2_scale,
-        #     out_dtype=out.dtype,
-        # )
-        # out.copy_(output[:, :m_size])
+        torch.bmm(mat1.transpose(0, 1), mat2_dequant.transpose(1, 2), out=out)
     else:
         raise NotImplementedError(f"SM{sm_version} is not supported")
 
diff --git a/tensorrt_llm/_torch/modules/fused_moe/fused_moe_deepgemm.py b/tensorrt_llm/_torch/modules/fused_moe/fused_moe_deepgemm.py
index 79af3cb7f2b..4b32f7f476b 100644
--- a/tensorrt_llm/_torch/modules/fused_moe/fused_moe_deepgemm.py
+++ b/tensorrt_llm/_torch/modules/fused_moe/fused_moe_deepgemm.py
@@ -1,7 +1,6 @@
 from typing import Dict, List, Optional, Union
 
 import torch
-import torch.nn.functional as F
 import triton
 import triton.language as tl
 
@@ -216,30 +215,83 @@ def triton_masked_index_gather(output, input, start_offsets, row_indices):
     return
 
 
-@nvtx_range("[DG] act")
-@torch.compile(dynamic=True)
-def swiglu_fused_moe(x):
-    x, gate = x.chunk(2, dim=-1)
-    return F.silu(gate) * x
-
-
-@nvtx_range("[DG] indexing")
-@torch.compile(dynamic=True)
-def indexing(x, mask):
-    return x[mask > 0, :].contiguous()
+@triton.jit
+def _preprocess_after_permute_kernel(
+    expert_offsets_ptr,
+    masked_m_ptr,
+    token_map_ptr,
+    TOTAL_TOKENS: tl.constexpr,
+    NUM_EXPERTS: tl.constexpr,
+    BLOCK_SIZE_M: tl.constexpr,
+):
+    pid_x = tl.program_id(0)
+    pid_y = tl.program_id(1)
+    if pid_y == 0:
+        token_offsets = pid_x * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
+        token_mask = token_offsets < TOTAL_TOKENS
+        # get expert_id for each token in the block
+        expert_ids = tl.full((BLOCK_SIZE_M, ), NUM_EXPERTS - 1, dtype=tl.int32)
+        found_mask = tl.zeros((BLOCK_SIZE_M, ), dtype=tl.int1)
+        for i in tl.static_range(NUM_EXPERTS):
+            boundary = tl.load(expert_offsets_ptr + i + 1)
+            cond = (token_offsets < boundary) & ~found_mask
+            expert_ids = tl.where(cond, i, expert_ids)
+            found_mask = found_mask | cond
+        tl.store(token_map_ptr + token_offsets,
+                 expert_ids.to(tl.int64),
+                 mask=token_mask)
+    elif pid_y == 1:
+        # get num_tokens for each expert
+        expert_mask = pid_x < NUM_EXPERTS
+        next_offset = tl.load(expert_offsets_ptr + pid_x + 1,
+                              mask=expert_mask,
+                              other=0)
+        current_offset = tl.load(expert_offsets_ptr + pid_x,
+                                 mask=expert_mask,
+                                 other=0)
+        tokens_per_expert = next_offset - current_offset
+        tl.store(masked_m_ptr + pid_x,
+                 tokens_per_expert.to(tl.int32),
+                 mask=expert_mask)
 
 
 @nvtx_range("[DG] preprocess_after_permute")
 def preprocess_after_permute(expert_first_token_offset_tensor,
                              permuted_data_tensor):
-    # get tokens per expert
-    masked_m = expert_first_token_offset_tensor[
-        1:] - expert_first_token_offset_tensor[:-1]
-    token_to_expert_map = torch.searchsorted(
-        expert_first_token_offset_tensor[1:],
-        torch.arange(permuted_data_tensor.shape[0], device='cuda'),
-        right=True)
-    return masked_m.to(torch.int32), token_to_expert_map
+    """
+    Python wrapper that launches a single fused kernel to get the token-to-expert map
+    and the number of tokens per expert.
+    """
+    total_tokens = permuted_data_tensor.shape[0]
+    num_experts = expert_first_token_offset_tensor.shape[0] - 1
+
+    # create output tensors
+    masked_m = torch.empty(num_experts, dtype=torch.int32, device='cuda')
+    token_to_expert_map = torch.empty(total_tokens,
+                                      dtype=torch.int64,
+                                      device='cuda')
+
+    # calculate the grid size
+    DEFAULT_BLOCK_SIZE_M = 256
+    grid_m_size = triton.cdiv(total_tokens, DEFAULT_BLOCK_SIZE_M)
+    if grid_m_size >= num_experts:
+        BLOCK_SIZE_M = DEFAULT_BLOCK_SIZE_M
+        grid = (grid_m_size, 2)
+    else:
+        block_size_m = triton.cdiv(total_tokens, num_experts)
+        BLOCK_SIZE_M = triton.next_power_of_2(block_size_m)
+        grid = (num_experts, 2)
+
+    # launch the kernel
+    _preprocess_after_permute_kernel[grid](
+        expert_first_token_offset_tensor,
+        masked_m,
+        token_to_expert_map,
+        TOTAL_TOKENS=total_tokens,
+        NUM_EXPERTS=num_experts,
+        BLOCK_SIZE_M=BLOCK_SIZE_M,
+    )
+    return masked_m, token_to_expert_map
 
 
 @nvtx_range("[DG]")
diff --git a/tensorrt_llm/quantization/utils/fp8_utils.py b/tensorrt_llm/quantization/utils/fp8_utils.py
index 4c486a15115..31b90bf69e9 100644
--- a/tensorrt_llm/quantization/utils/fp8_utils.py
+++ b/tensorrt_llm/quantization/utils/fp8_utils.py
@@ -476,7 +476,7 @@ def per_token_quant_and_transform(
     scale_k = ceil_div(k, quant_group_size)
     m_padded = align(m, alignment)
     scale_k_padded = align(scale_k, alignment)
-    output_scale = torch.zeros((scale_k_padded // 4, m_padded),
+    output_scale = torch.empty((scale_k_padded // 4, m_padded),
                                dtype=torch.int32,
                                device='cuda')
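For a quick numerical check of the fused preprocessing kernel, the sketch below compares the new `preprocess_after_permute` against the `torch.searchsorted` path it replaces. This is a minimal sketch under assumptions: `fused_moe_deepgemm` imports without extra setup, a CUDA device is available, and `reference_preprocess` is a hypothetical helper reproduced here from the removed code; it is not part of the patch.

```python
# Sanity-check sketch (not part of the patch): compare the fused Triton
# preprocessing against the removed torch.searchsorted reference.
import torch

from tensorrt_llm._torch.modules.fused_moe.fused_moe_deepgemm import \
    preprocess_after_permute


def reference_preprocess(expert_first_token_offset, total_tokens):
    # Hypothetical helper reproducing the implementation removed by this patch.
    masked_m = (expert_first_token_offset[1:] -
                expert_first_token_offset[:-1]).to(torch.int32)
    token_to_expert_map = torch.searchsorted(
        expert_first_token_offset[1:],
        torch.arange(total_tokens, device='cuda'),
        right=True)
    return masked_m, token_to_expert_map


num_experts = 8
# Random per-expert token counts -> cumulative "first token" offsets.
counts = torch.randint(1, 257, (num_experts, ),
                       dtype=torch.int64, device='cuda')
offsets = torch.zeros(num_experts + 1, dtype=torch.int64, device='cuda')
offsets[1:] = torch.cumsum(counts, dim=0)
total_tokens = int(offsets[-1])

# Only .shape[0] of the permuted tensor is consumed by the wrapper.
permuted = torch.empty(total_tokens, 128, device='cuda')

masked_m, token_map = preprocess_after_permute(offsets, permuted)
ref_masked_m, ref_token_map = reference_preprocess(offsets, total_tokens)
torch.testing.assert_close(masked_m, ref_masked_m)
torch.testing.assert_close(token_map, ref_token_map)
```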