Commit 7794e13

Disable H20
Signed-off-by: Dongfeng Yu <[email protected]>
1 parent 7e2b181 commit 7794e13

1 file changed: +9 -0 lines changed

tensorrt_llm/_torch/modules/fused_moe/fused_moe_triton.py

Lines changed: 9 additions & 0 deletions
@@ -643,6 +643,15 @@ def swizzle_weight_and_scale(w: torch.Tensor, w_scale: torch.Tensor):
         mx_axis=1)
     scale_layout, scale_layout_opts = layout.make_default_matmul_mxfp4_w_scale_layout(
         mx_axis=1, num_warps=num_warps)
+    # Swizzling path is broken for H20
+    if torch.cuda.get_device_name() == "NVIDIA H20":
+        from triton_kernels.tensor_details.layout_details.strided import \
+            StridedLayout
+        value_layout = StridedLayout
+        value_layout_opts = dict()
+        scale_layout = StridedLayout
+        scale_layout_opts = dict()
+
     opt = {"value_layout": value_layout, "value_layout_opts": value_layout_opts, \
            "scale_layout": scale_layout, "scale_layout_opts": scale_layout_opts}
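
The fallback in this change can be illustrated in isolation. The sketch below is not the code from `fused_moe_triton.py`: `select_layouts`, `UNSWIZZLED_DEVICES`, and the two placeholder classes are hypothetical names introduced only for illustration, and the device name is passed in as a plain string so the logic can be exercised without a GPU. It assumes the default (swizzled) layouts arrive as `(layout, opts)` pairs, as in the context lines of the diff.

```python
from typing import Any, Dict, Tuple


class StridedLayout:
    """Hypothetical stand-in for triton_kernels' StridedLayout (assumption)."""


class SwizzledLayout:
    """Hypothetical stand-in for a default swizzled layout (assumption)."""


# Devices that skip the swizzled path. Only "NVIDIA H20" comes from the
# commit; holding the names in a set is an illustrative choice.
UNSWIZZLED_DEVICES = frozenset({"NVIDIA H20"})


def select_layouts(
    device_name: str,
    default_value: Tuple[Any, Dict[str, Any]],
    default_scale: Tuple[Any, Dict[str, Any]],
) -> Dict[str, Any]:
    """Build the layout options, falling back to strided on listed devices."""
    value_layout, value_layout_opts = default_value
    scale_layout, scale_layout_opts = default_scale

    if device_name in UNSWIZZLED_DEVICES:
        # Mirror the commit: plain strided layouts with empty options.
        value_layout, value_layout_opts = StridedLayout, {}
        scale_layout, scale_layout_opts = StridedLayout, {}

    return {
        "value_layout": value_layout,
        "value_layout_opts": value_layout_opts,
        "scale_layout": scale_layout,
        "scale_layout_opts": scale_layout_opts,
    }
```

In the actual change the device name comes from `torch.cuda.get_device_name()` and is compared by exact string equality, so a device reporting a different name string would still take the swizzled path.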
