10 changes: 8 additions & 2 deletions tensorrt_llm/_torch/models/modeling_deepseekv3.py
@@ -279,10 +279,16 @@ def __init__(
         self.routed_scaling_factor = routed_scaling_factor
         self.is_fused = is_fused

-    def noaux_tc(self, logits, e_score_correction_bias):
-        n_group = self.n_group
+    @torch.compile(options={"max-autotune": True})
+    def get_scores(self, logits, e_score_correction_bias):
         scores = F.sigmoid(logits)
         scores_with_bias = scores + e_score_correction_bias
+        return scores, scores_with_bias
+
+    def noaux_tc(self, logits, e_score_correction_bias):
+        n_group = self.n_group
+        scores, scores_with_bias = self.get_scores(logits,
+                                                   e_score_correction_bias)
         scores_shape = list(scores_with_bias.shape)

         if enable_llm_debug():
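The hunk above isolates the element-wise scoring math in a small get_scores helper and compiles it, so the sigmoid and bias add can be fused into a single kernel while the rest of noaux_tc stays eager. A minimal standalone sketch of that pattern, illustrative only and with hypothetical tensor sizes rather than the repo's class:

import torch
import torch.nn.functional as F


@torch.compile(options={"max-autotune": True})
def get_scores(logits, e_score_correction_bias):
    # sigmoid + bias add; compilation fuses the two element-wise ops
    scores = F.sigmoid(logits)
    scores_with_bias = scores + e_score_correction_bias
    return scores, scores_with_bias


# hypothetical sizes: 8 routed tokens, 256 experts, CUDA device assumed
logits = torch.randn(8, 256, device="cuda")
bias = torch.randn(256, device="cuda")
scores, scores_with_bias = get_scores(logits, bias)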
28 changes: 1 addition & 27 deletions tensorrt_llm/_torch/modules/attention.py
@@ -555,33 +555,7 @@ def fp8_block_scaling_bmm_out(
         torch.ops.trtllm.fp8_block_scaling_bmm_out(mat1_fp8, mat2_fp8,
                                                    mat1_scale, mat2_scale, out)
     elif sm_version == 100:
-        output = torch.bmm(mat1.transpose(0, 1), mat2_dequant.transpose(1, 2))
-        out.copy_(output)
-
-        # low_latency = True
-        # use_deep_seek_fp8 = True
-        # tile_size = 8
-        # epilogue_tile_m = 64 if use_deep_seek_fp8 else 128
-        # m_size = mat1.shape[0]
-        # if m_size % tile_size != 0:
-        #     tiled_shape = ((m_size + tile_size - 1) // tile_size) * tile_size
-        #     mat1 = torch.nn.functional.pad(
-        #         mat1, (0, 0, 0, 0, 0, tiled_shape - m_size), "constant", 0)
-
-        # mat1_fp8, mat1_scale = torch.ops.trtllm.fp8_batched_quantize_1x128_permute102(
-        #     mat1)
-        # output, output_sf = torch.ops.trtllm.fp8_batched_gemm_trtllmgen(
-        #     mat1_fp8,
-        #     mat2_fp8,
-        #     tile_size=tile_size,
-        #     epilogue_tile_m=epilogue_tile_m,
-        #     use_deep_seek_fp8=use_deep_seek_fp8,
-        #     low_latency=low_latency,
-        #     dq_sfs_a=mat1_scale.reshape(mat1.shape[-1] // 128, -1),
-        #     dq_sfs_b=mat2_scale,
-        #     out_dtype=out.dtype,
-        # )
-        # out.copy_(output[:, :m_size])
+        torch.bmm(mat1.transpose(0, 1), mat2_dequant.transpose(1, 2), out=out)
     else:
         raise NotImplementedError(f"SM{sm_version} is not supported")

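The one-line replacement above hands the preallocated output buffer to torch.bmm via out= instead of materializing a temporary result and copying it, which drops an extra allocation and a device-to-device copy on the SM100 path, and it also deletes the stale commented-out trtllmgen experiment. An illustrative sketch of the difference, with hypothetical shapes rather than the repo's tensors:

import torch

mat1 = torch.randn(4, 8, 16)            # hypothetical [m, b, k]
mat2_dequant = torch.randn(8, 32, 16)   # hypothetical [b, n, k]
out = torch.empty(8, 4, 32)             # preallocated [b, m, n]

# before: temporary result, then an explicit copy into `out`
tmp = torch.bmm(mat1.transpose(0, 1), mat2_dequant.transpose(1, 2))
out.copy_(tmp)

# after: bmm writes directly into the preallocated buffer
torch.bmm(mat1.transpose(0, 1), mat2_dequant.transpose(1, 2), out=out)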
92 changes: 72 additions & 20 deletions tensorrt_llm/_torch/modules/fused_moe/fused_moe_deepgemm.py
@@ -1,7 +1,6 @@
 from typing import Dict, List, Optional, Union

 import torch
-import torch.nn.functional as F
 import triton
 import triton.language as tl

@@ -216,30 +215,83 @@ def triton_masked_index_gather(output, input, start_offsets, row_indices):
     return


-@nvtx_range("[DG] act")
-@torch.compile(dynamic=True)
-def swiglu_fused_moe(x):
-    x, gate = x.chunk(2, dim=-1)
-    return F.silu(gate) * x
-
-
-@nvtx_range("[DG] indexing")
-@torch.compile(dynamic=True)
-def indexing(x, mask):
-    return x[mask > 0, :].contiguous()
+@triton.jit
+def _preprocess_after_permute_kernel(
+    expert_offsets_ptr,
+    masked_m_ptr,
+    token_map_ptr,
+    TOTAL_TOKENS: tl.constexpr,
+    NUM_EXPERTS: tl.constexpr,
+    BLOCK_SIZE_M: tl.constexpr,
+):
+    pid_x = tl.program_id(0)
+    pid_y = tl.program_id(1)
+    if pid_y == 0:
+        token_offsets = pid_x * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
+        token_mask = token_offsets < TOTAL_TOKENS
+        # get expert_id for each token in the block
+        expert_ids = tl.full((BLOCK_SIZE_M, ), NUM_EXPERTS - 1, dtype=tl.int32)
+        found_mask = tl.zeros((BLOCK_SIZE_M, ), dtype=tl.int1)
+        for i in tl.static_range(NUM_EXPERTS):
+            boundary = tl.load(expert_offsets_ptr + i + 1)
+            cond = (token_offsets < boundary) & ~found_mask
+            expert_ids = tl.where(cond, i, expert_ids)
+            found_mask = found_mask | cond
+        tl.store(token_map_ptr + token_offsets,
+                 expert_ids.to(tl.int64),
+                 mask=token_mask)
+    elif pid_y == 1:
+        # get num_tokens for each expert
+        expert_mask = pid_x < NUM_EXPERTS
+        next_offset = tl.load(expert_offsets_ptr + pid_x + 1,
+                              mask=expert_mask,
+                              other=0)
+        current_offset = tl.load(expert_offsets_ptr + pid_x,
+                                 mask=expert_mask,
+                                 other=0)
+        tokens_per_expert = next_offset - current_offset
+        tl.store(masked_m_ptr + pid_x,
+                 tokens_per_expert.to(tl.int32),
+                 mask=expert_mask)


 @nvtx_range("[DG] preprocess_after_permute")
 def preprocess_after_permute(expert_first_token_offset_tensor,
                              permuted_data_tensor):
-    # get tokens per expert
-    masked_m = expert_first_token_offset_tensor[
-        1:] - expert_first_token_offset_tensor[:-1]
-    token_to_expert_map = torch.searchsorted(
-        expert_first_token_offset_tensor[1:],
-        torch.arange(permuted_data_tensor.shape[0], device='cuda'),
-        right=True)
-    return masked_m.to(torch.int32), token_to_expert_map
+    """
+    Python wrapper that launches a single fused kernel to get the token-to-expert map
+    and the number of tokens per expert.
+    """
+    total_tokens = permuted_data_tensor.shape[0]
+    num_experts = expert_first_token_offset_tensor.shape[0] - 1
+
+    # create output tensors
+    masked_m = torch.empty(num_experts, dtype=torch.int32, device='cuda')
+    token_to_expert_map = torch.empty(total_tokens,
+                                      dtype=torch.int64,
+                                      device='cuda')
+
+    # calculate the grid size
+    DEFAULT_BLOCK_SIZE_M = 256
+    grid_m_size = triton.cdiv(total_tokens, DEFAULT_BLOCK_SIZE_M)
+    if grid_m_size >= num_experts:
+        BLOCK_SIZE_M = DEFAULT_BLOCK_SIZE_M
+        grid = (grid_m_size, 2)
+    else:
+        block_size_m = triton.cdiv(total_tokens, num_experts)
+        BLOCK_SIZE_M = triton.next_power_of_2(block_size_m)
+        grid = (num_experts, 2)
+
+    # launch the kernel
+    _preprocess_after_permute_kernel[grid](
+        expert_first_token_offset_tensor,
+        masked_m,
+        token_to_expert_map,
+        TOTAL_TOKENS=total_tokens,
+        NUM_EXPERTS=num_experts,
+        BLOCK_SIZE_M=BLOCK_SIZE_M,
+    )
+    return masked_m, token_to_expert_map


 @nvtx_range("[DG]")
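The new Triton kernel replaces two separate PyTorch ops with a single launch over a 2D grid: program rows with pid_y == 0 build the token-to-expert map, and rows with pid_y == 1 compute the per-expert token counts. The removed searchsorted/subtraction code doubles as a handy reference; a sketch of how the two outputs could be cross-checked, assuming a CUDA device and hypothetical offsets for 3 experts and 6 permuted tokens:

import torch

# hypothetical prefix offsets: expert 0 owns tokens [0, 2), expert 1 none, expert 2 owns [2, 6)
expert_first_token_offset = torch.tensor([0, 2, 2, 6], device="cuda")
num_tokens = int(expert_first_token_offset[-1])

# reference tokens-per-expert (what the pid_y == 1 branch computes)
ref_masked_m = (expert_first_token_offset[1:] -
                expert_first_token_offset[:-1]).to(torch.int32)

# reference token -> expert map (what the pid_y == 0 branch computes)
ref_map = torch.searchsorted(expert_first_token_offset[1:],
                             torch.arange(num_tokens, device="cuda"),
                             right=True)

# with a real permuted_data_tensor holding num_tokens rows:
# masked_m, token_map = preprocess_after_permute(expert_first_token_offset,
#                                                permuted_data_tensor)
# torch.testing.assert_close(masked_m, ref_masked_m)
# torch.testing.assert_close(token_map, ref_map)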
2 changes: 1 addition & 1 deletion tensorrt_llm/quantization/utils/fp8_utils.py
@@ -476,7 +476,7 @@ def per_token_quant_and_transform(
     scale_k = ceil_div(k, quant_group_size)
     m_padded = align(m, alignment)
     scale_k_padded = align(scale_k, alignment)
-    output_scale = torch.zeros((scale_k_padded // 4, m_padded),
+    output_scale = torch.empty((scale_k_padded // 4, m_padded),
                                dtype=torch.int32,
                                device='cuda')

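torch.empty only allocates the scale buffer, while torch.zeros also zero-fills it; dropping the fill saves a memset per call, on the assumption that every element of output_scale read downstream is first written by the quantization kernel. A tiny illustration of the difference, not taken from the repo:

import torch

a = torch.zeros((4, 8), dtype=torch.int32, device="cuda")  # allocate + zero-fill
b = torch.empty((4, 8), dtype=torch.int32, device="cuda")  # allocate only; contents are undefined
b.zero_()  # an explicit fill would still be needed if any element could be read before being written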