Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
46 commits
Select commit Hold shift + click to select a range
8f682d8
add cute dsl fp8 blockwise gemm op
limin2021 Jul 31, 2025
8cb6c2c
add release_gemm code.
limin2021 Jul 31, 2025
5606444
add bmm custom op.
limin2021 Aug 4, 2025
733d253
Merge remote-tracking branch 'origin/main' into merge-cute-dsl-deepge…
limin2021 Aug 4, 2025
4db6d16
update linear.
limin2021 Aug 4, 2025
9dd89c9
update tests.
limin2021 Aug 4, 2025
899b1bf
update test name.
limin2021 Aug 4, 2025
3c8cca4
save.
limin2021 Aug 5, 2025
3270b8d
make fused_moe_cute_dsl work on blackwell.
limin2021 Aug 5, 2025
66e4a7d
rename.
limin2021 Aug 5, 2025
9cd1f27
remove print.
limin2021 Aug 5, 2025
d80cabe
refactor moe loading logics.
limin2021 Aug 5, 2025
0a74c99
minor.
limin2021 Aug 5, 2025
2376001
save.
limin2021 Aug 5, 2025
dacfca3
Merge branch 'merge-cute-dsl-blackwell-step-0' into merge-cute-dsl-de…
limin2021 Aug 5, 2025
5d97b07
use env to control cute dsl ops.
limin2021 Aug 5, 2025
e7d2120
minor.
limin2021 Aug 5, 2025
7924f93
update
limin2021 Aug 5, 2025
c1d9877
Merge remote-tracking branch 'origin' into merge-cute-dsl-blackwell-s…
limin2021 Aug 6, 2025
99b56a7
do not delete fused_moe_cutlass hopper test
limin2021 Aug 6, 2025
b59a9ad
recover fused_moe_cutlass test on hopper.
limin2021 Aug 6, 2025
eddf7ea
minor
limin2021 Aug 6, 2025
88e9e2e
Merge branch 'merge-cute-dsl-blackwell-step-0' into merge-cute-dsl-de…
limin2021 Aug 6, 2025
97ab168
Merge remote-tracking branch 'origin' into merge-cute-dsl-blackwell-s…
limin2021 Aug 7, 2025
4a06767
Merge remote-tracking branch 'origin' into merge-cute-dsl-blackwell-s…
limin2021 Aug 7, 2025
fc6acbd
Merge remote-tracking branch 'origin' into merge-cute-dsl-blackwell-s…
limin2021 Aug 8, 2025
6450dfc
Merge branch 'merge-cute-dsl-blackwell-step-0' into merge-cute-dsl-de…
limin2021 Aug 8, 2025
f3f402d
Merge remote-tracking branch 'origin' into merge-cute-dsl-deepgemm-bl…
limin2021 Aug 8, 2025
2f65803
add cute_dsl_group_gemm op.
limin2021 Aug 8, 2025
e826468
Merge remote-tracking branch 'origin' into merge-cute-dsl-deepgemm-bl…
limin2021 Aug 11, 2025
18eb06b
minor.
limin2021 Aug 11, 2025
5c4fe5a
recover llm.py
limin2021 Aug 13, 2025
a39d3d8
Merge remote-tracking branch 'origin/main' into merge-cute-dsl-deepge…
limin2021 Aug 13, 2025
73878f1
Merge remote-tracking branch 'origin' into merge-cute-dsl-deepgemm-bl…
limin2021 Aug 14, 2025
04d25b6
recover linear.py
limin2021 Aug 14, 2025
919e886
minor.
limin2021 Aug 14, 2025
94f7ea6
remove env and use configs.
limin2021 Aug 14, 2025
f4a127c
clean up codes.
limin2021 Aug 14, 2025
ccc78e1
clean custom op.
limin2021 Aug 14, 2025
4a85b63
remove useless cute dsl kernel
limin2021 Aug 14, 2025
d4daff8
minor.
limin2021 Aug 14, 2025
2cd2f00
remove useless code.
limin2021 Aug 14, 2025
3a9ccb7
add mm, bmm autotune.
limin2021 Aug 14, 2025
2012495
fix typo.
limin2021 Aug 14, 2025
88248b8
add comments.
limin2021 Aug 15, 2025
73b9085
Merge remote-tracking branch 'origin' into merge-cute-dsl-deepgemm-bl…
limin2021 Aug 15, 2025
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions cpp/tensorrt_llm/thop/moeUtilOp.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,10 @@ std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Te
int const num_experts_per_node = num_experts_on_rank;
auto stream = at::cuda::getCurrentCUDAStream(input.get_device());
int64_t num_moe_inputs = static_cast<int64_t>(experts_per_token * num_rows);
TORCH_CHECK(num_moe_inputs <= std::numeric_limits<int32_t>::max(),
"num_moe_inputs exceeds int32 range (because we use int32 for expert_first_token_offset_tensor output). "
"num_moe_inputs = ",
num_moe_inputs);

auto permuted_row_to_unpermuted_row_tensor
= torch::empty({num_moe_inputs}, torch::dtype(torch::kInt32).device(torch::kCUDA).requires_grad(false));
Expand Down Expand Up @@ -224,6 +228,8 @@ std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Te
"Invalid dtype, only supports input tensor with float32, float16 and bfloat16 dtype");
break;
}
expert_first_token_offset_tensor = expert_first_token_offset_tensor.to(torch::kInt32);

return std::make_tuple(permuted_row_to_unpermuted_row_tensor, permuted_token_selected_experts_tensor,
permuted_data_tensor, expert_first_token_offset_tensor, permuted_token_final_scales_tensor,
unpermuted_row_to_permuted_row_tensor);
Expand Down
10 changes: 10 additions & 0 deletions examples/llm-api/quickstart_advanced.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,14 @@ def add_llm_args(parser):
default=False,
action='store_true')
parser.add_argument('--logprobs', default=False, action='store_true')

# cute dsl op configs
parser.add_argument('--use_cute_dsl_blockscaling_mm',
default=False,
action='store_true')
parser.add_argument('--use_cute_dsl_blockscaling_bmm',
default=False,
action='store_true')
return parser


Expand Down Expand Up @@ -246,6 +254,8 @@ def setup_llm(args, **kwargs):
trust_remote_code=args.trust_remote_code,
gather_generation_logits=args.return_generation_logits,
max_beam_width=args.max_beam_width,
use_cute_dsl_blockscaling_mm=args.use_cute_dsl_blockscaling_mm,
use_cute_dsl_blockscaling_bmm=args.use_cute_dsl_blockscaling_bmm,
**kwargs)

use_beam_search = args.max_beam_width > 1
Expand Down
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -67,3 +67,4 @@ soundfile
triton==3.3.1; platform_machine == "x86_64"
tiktoken
blobfile
nvidia-cutlass-dsl
Loading