
[Bug]: torch.compile padding causes IMA on Hopper + DBO #25623

@ProExpertProg

Your current environment

The output of python collect_env.py
Collecting environment information...
uv is set
==============================
        System Info
==============================
OS                           : Ubuntu 22.04.4 LTS (x86_64)
GCC version                  : (Ubuntu 11.4.0-1ubuntu1~22.04) 11.4.0
Clang version                : Could not collect
CMake version                : Could not collect
Libc version                 : glibc-2.35

==============================
       PyTorch Info
==============================
PyTorch version              : 2.8.0+cu128
Is debug build               : False
CUDA used to build PyTorch   : 12.8
ROCM used to build PyTorch   : N/A

==============================
      Python Environment
==============================
Python version               : 3.10.12 (main, Jan 17 2025, 14:35:34) [GCC 11.4.0] (64-bit runtime)
Python platform              : Linux-5.15.0-113-generic-x86_64-with-glibc2.35

==============================
       CUDA / GPU Info
==============================
Is CUDA available            : True
CUDA runtime version         : 12.8.61
CUDA_MODULE_LOADING set to   : LAZY
GPU models and configuration : 
GPU 0: NVIDIA H100 80GB HBM3
GPU 1: NVIDIA H100 80GB HBM3
GPU 2: NVIDIA H100 80GB HBM3
GPU 3: NVIDIA H100 80GB HBM3
GPU 4: NVIDIA H100 80GB HBM3
GPU 5: NVIDIA H100 80GB HBM3
GPU 6: NVIDIA H100 80GB HBM3
GPU 7: NVIDIA H100 80GB HBM3

Nvidia driver version        : 570.86.10
cuDNN version                : Could not collect
HIP runtime version          : N/A
MIOpen runtime version       : N/A
Is XNNPACK available         : True

==============================
          CPU Info
==============================
Architecture:                       x86_64
CPU op-mode(s):                     32-bit, 64-bit
Address sizes:                      52 bits physical, 57 bits virtual
Byte Order:                         Little Endian
CPU(s):                             384
On-line CPU(s) list:                0-383
Vendor ID:                          AuthenticAMD
Model name:                         AMD EPYC 9654 96-Core Processor
CPU family:                         25
Model:                              17
Thread(s) per core:                 2
Core(s) per socket:                 96
Socket(s):                          2
Stepping:                           1
Frequency boost:                    enabled
CPU max MHz:                        3707.8120
CPU min MHz:                        1500.0000
BogoMIPS:                           4799.98
Flags:                              fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ht syscall nx mmxext fxsr_opt pdpe1gb rdtscp lm constant_tsc rep_good nopl nonstop_tsc cpuid extd_apicid aperfmperf rapl pni pclmulqdq monitor ssse3 fma cx16 pcid sse4_1 sse4_2 x2apic movbe popcnt aes xsave avx f16c rdrand lahf_lm cmp_legacy svm extapic cr8_legacy abm sse4a misalignsse 3dnowprefetch osvw ibs skinit wdt tce topoext perfctr_core perfctr_nb bpext perfctr_llc mwaitx cpb cat_l3 cdp_l3 invpcid_single hw_pstate ssbd mba ibrs ibpb stibp vmmcall fsgsbase bmi1 avx2 smep bmi2 erms invpcid cqm rdt_a avx512f avx512dq rdseed adx smap avx512ifma clflushopt clwb avx512cd sha_ni avx512bw avx512vl xsaveopt xsavec xgetbv1 xsaves cqm_llc cqm_occup_llc cqm_mbm_total cqm_mbm_local avx512_bf16 clzero irperf xsaveerptr rdpru wbnoinvd amd_ppin cppc arat npt lbrv svm_lock nrip_save tsc_scale vmcb_clean flushbyasid decodeassists pausefilter pfthreshold avic v_vmsave_vmload vgif v_spec_ctrl avx512vbmi umip pku ospke avx512_vbmi2 gfni vaes vpclmulqdq avx512_vnni avx512_bitalg avx512_vpopcntdq la57 rdpid overflow_recov succor smca fsrm flush_l1d
Virtualization:                     AMD-V
L1d cache:                          6 MiB (192 instances)
L1i cache:                          6 MiB (192 instances)
L2 cache:                           192 MiB (192 instances)
L3 cache:                           768 MiB (24 instances)
NUMA node(s):                       2
NUMA node0 CPU(s):                  0-95,192-287
NUMA node1 CPU(s):                  96-191,288-383
Vulnerability Gather data sampling: Not affected
Vulnerability Itlb multihit:        Not affected
Vulnerability L1tf:                 Not affected
Vulnerability Mds:                  Not affected
Vulnerability Meltdown:             Not affected
Vulnerability Mmio stale data:      Not affected
Vulnerability Retbleed:             Not affected
Vulnerability Spec rstack overflow: Mitigation; safe RET
Vulnerability Spec store bypass:    Mitigation; Speculative Store Bypass disabled via prctl and seccomp
Vulnerability Spectre v1:           Mitigation; usercopy/swapgs barriers and __user pointer sanitization
Vulnerability Spectre v2:           Mitigation; Retpolines; IBPB conditional; IBRS_FW; STIBP always-on; RSB filling; PBRSB-eIBRS Not affected; BHI Not affected
Vulnerability Srbds:                Not affected
Vulnerability Tsx async abort:      Not affected

==============================
Versions of relevant libraries
==============================
[pip3] numpy==2.2.6
[pip3] nvidia-cublas-cu12==12.8.4.1
[pip3] nvidia-cuda-cupti-cu12==12.8.90
[pip3] nvidia-cuda-nvrtc-cu12==12.8.93
[pip3] nvidia-cuda-runtime-cu12==12.8.90
[pip3] nvidia-cudnn-cu12==9.10.2.21
[pip3] nvidia-cufft-cu12==11.3.3.83
[pip3] nvidia-cufile-cu12==1.13.1.3
[pip3] nvidia-curand-cu12==10.3.9.90
[pip3] nvidia-cusolver-cu12==11.7.3.90
[pip3] nvidia-cusparse-cu12==12.5.8.93
[pip3] nvidia-cusparselt-cu12==0.7.1
[pip3] nvidia-nccl-cu12==2.27.3
[pip3] nvidia-nvjitlink-cu12==12.8.93
[pip3] nvidia-nvtx-cu12==12.8.90
[pip3] pyzmq==27.1.0
[pip3] torch==2.8.0+cu128
[pip3] torchaudio==2.8.0+cu128
[pip3] torchvision==0.23.0+cu128
[pip3] transformers==4.56.2
[pip3] triton==3.4.0
[conda] Could not collect

==============================
         vLLM Info
==============================
ROCM Version                 : Could not collect
vLLM Version                 : 0.11.0rc2.dev71+ge18b714b2 (git sha: e18b714b2)
vLLM Build Flags:
  CUDA Archs: Not Set; ROCm: Disabled
GPU Topology:
        GPU0    GPU1    GPU2    GPU3    GPU4    GPU5    GPU6    GPU7    NIC0    NIC1    NIC2    NIC3    NIC4    NIC5    NIC6    NIC7    CPU Affinity    NUMA Affinity   GPU NUMA ID
GPU0     X      NV18    NV18    NV18    NV18    NV18    NV18    NV18    NODE    NODE    PIX     NODE    SYS     SYS     SYS     SYS     0-95,192-287    0               N/A
GPU1    NV18     X      NV18    NV18    NV18    NV18    NV18    NV18    NODE    NODE    NODE    PIX     SYS     SYS     SYS     SYS     0-95,192-287    0               N/A
GPU2    NV18    NV18     X      NV18    NV18    NV18    NV18    NV18    NODE    PIX     NODE    NODE    SYS     SYS     SYS     SYS     0-95,192-287    0               N/A
GPU3    NV18    NV18    NV18     X      NV18    NV18    NV18    NV18    PIX     NODE    NODE    NODE    SYS     SYS     SYS     SYS     0-95,192-287    0               N/A
GPU4    NV18    NV18    NV18    NV18     X      NV18    NV18    NV18    SYS     SYS     SYS     SYS     NODE    NODE    PIX     NODE    96-191,288-383  1               N/A
GPU5    NV18    NV18    NV18    NV18    NV18     X      NV18    NV18    SYS     SYS     SYS     SYS     NODE    NODE    NODE    PIX     96-191,288-383  1               N/A
GPU6    NV18    NV18    NV18    NV18    NV18    NV18     X      NV18    SYS     SYS     SYS     SYS     NODE    PIX     NODE    NODE    96-191,288-383  1               N/A
GPU7    NV18    NV18    NV18    NV18    NV18    NV18    NV18     X      SYS     SYS     SYS     SYS     PIX     NODE    NODE    NODE    96-191,288-383  1               N/A
NIC0    NODE    NODE    NODE    PIX     SYS     SYS     SYS     SYS      X      NODE    NODE    NODE    SYS     SYS     SYS     SYS                             
NIC1    NODE    NODE    PIX     NODE    SYS     SYS     SYS     SYS     NODE     X      NODE    NODE    SYS     SYS     SYS     SYS                             
NIC2    PIX     NODE    NODE    NODE    SYS     SYS     SYS     SYS     NODE    NODE     X      NODE    SYS     SYS     SYS     SYS                             
NIC3    NODE    PIX     NODE    NODE    SYS     SYS     SYS     SYS     NODE    NODE    NODE     X      SYS     SYS     SYS     SYS                             
NIC4    SYS     SYS     SYS     SYS     NODE    NODE    NODE    PIX     SYS     SYS     SYS     SYS      X      NODE    NODE    NODE                            
NIC5    SYS     SYS     SYS     SYS     NODE    NODE    PIX     NODE    SYS     SYS     SYS     SYS     NODE     X      NODE    NODE                            
NIC6    SYS     SYS     SYS     SYS     PIX     NODE    NODE    NODE    SYS     SYS     SYS     SYS     NODE    NODE     X      NODE                            
NIC7    SYS     SYS     SYS     SYS     NODE    PIX     NODE    NODE    SYS     SYS     SYS     SYS     NODE    NODE    NODE     X                              

Legend:

  X    = Self
  SYS  = Connection traversing PCIe as well as the SMP interconnect between NUMA nodes (e.g., QPI/UPI)
  NODE = Connection traversing PCIe as well as the interconnect between PCIe Host Bridges within a NUMA node
  PHB  = Connection traversing PCIe as well as a PCIe Host Bridge (typically the CPU)
  PXB  = Connection traversing multiple PCIe bridges (without traversing the PCIe Host Bridge)
  PIX  = Connection traversing at most a single PCIe bridge
  NV#  = Connection traversing a bonded set of # NVLinks

NIC Legend:

  NIC0: mlx5_0
  NIC1: mlx5_1
  NIC2: mlx5_2
  NIC3: mlx5_3
  NIC4: mlx5_4
  NIC5: mlx5_5
  NIC6: mlx5_6
  NIC7: mlx5_7

==============================
     Environment Variables
==============================
LD_LIBRARY_PATH=/usr/local/cuda-12.8/lib64
CUDA_HOME=/usr/local/cuda
CUDA_HOME=/usr/local/cuda
PYTORCH_NVML_BASED_CUDA_CHECK=1
TORCHINDUCTOR_COMPILE_THREADS=1
CUDA_MODULE_LOADING=LAZY

🐛 Describe the bug

The padding kernel generated by torch.compile causes an illegal memory access (IMA) when it pads the inputs to block-quantized fp8 CUTLASS kernels on Hopper with DBO enabled. This bug was introduced in #24666.

(APIServer pid=1) (EngineCore_DP0 pid=276) ERROR 09-24 13:00:49 [core.py:708] RuntimeError: Failed: CUDA error /tmp/deepep/csrc/kernels/internode_ll.cu:391 'an illegal memory access was encountered'

Minimal repro instructions:

Credit to @ElizaWszola for finding this smaller repro.

VLLM_USE_DEEP_GEMM=0 VLLM_ALL2ALL_BACKEND=deepep_low_latency vllm serve \
                Qwen/Qwen3-30B-A3B-FP8 \
                --port 7557 \
                --disable-uvicorn-access-log \
                --trust-remote-code \
                --enable-expert-parallel \
                --data-parallel-hybrid-lb \
                --tensor-parallel-size 2 \
                --data-parallel-size 2 \
                --data-parallel-size-local 2 \
                --data-parallel-address localhost \
                --data-parallel-rpc-port 5555 \
                --data-parallel-start-rank 0 \
                --enable-eplb \
                --eplb-config '{"window_size":"1000",
                                "step_interval":"3000",
                                "num_redundant_experts":"32",
                                "log_balancedness":"False"}' \
                --enable-dbo \
                --dbo-decode-token-threshold 32 \
                --kv_transfer_config '{"kv_connector":"NixlConnector",
                                        "kv_role":"kv_both"}' \
                --max_num_seqs 256

Original repro instructions:

I'm deploying vLLM using the llm-d WideEP well-lit path.

See the decoder manifest here:
https://github.com/llm-d/llm-d/blob/4970c7c2703dc23605719491c4fb380973b13517/guides/wide-ep-lws/manifests/modelserver/base/decode.yaml

In particular, this is the vLLM launch command:

              exec vllm serve \
                deepseek-ai/DeepSeek-R1-0528 \
                --port 8200 \
                --disable-uvicorn-access-log \
                --trust-remote-code \
                --enable-expert-parallel \
                --data-parallel-hybrid-lb \
                --tensor-parallel-size $TP_SIZE \
                --data-parallel-size $((LWS_GROUP_SIZE * DP_SIZE_LOCAL)) \
                --data-parallel-size-local $DP_SIZE_LOCAL \
                --data-parallel-address ${LWS_LEADER_ADDRESS} \
                --data-parallel-rpc-port 5555 \
                --data-parallel-start-rank $START_RANK \
                --enable-eplb \
                --eplb-config '{"window_size":"1000",
                                "step_interval":"3000",
                                "num_redundant_experts":"32",
                                "log_balancedness":"False"}' \
                --enable-dbo \
                --dbo-decode-token-threshold 32 \
                --kv_transfer_config '{"kv_connector":"NixlConnector",
                                        "kv_role":"kv_both"}'

From investigations of @LucasWilkinson:

The weird part is that it is failing in triton_poi_fused__to_copy_add_constant_pad_nd_mean_mul_pow_rsqrt_2,
and DBO isn't even running.
The fishy thing is that torch.compile appears to be rounding the input up to a multiple of 4:

        triton_poi_fused__to_copy_add_constant_pad_nd_mean_mul_pow_rsqrt_2_xnumel = 7168*s72 + 7168*(((-1)*s72) % 4)
        stream0 = get_raw_stream(0)
        triton_poi_fused__to_copy_add_constant_pad_nd_mean_mul_pow_rsqrt_2.run(buf17, buf13, buf12, arg4_1, buf14, arg7_1, s72, triton_poi_fused__to_copy_add_constant_pad_nd_mean_mul_pow_rsqrt_2_xnumel, stream=stream0)
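The `xnumel` expression above can be checked with plain arithmetic: `7168*s72 + 7168*((-s72) % 4)` equals 7168 times `s72` rounded up to the next multiple of 4, so the fused kernel covers the padded row count rather than the raw token count. A minimal sketch, assuming `s72` is the symbolic token count and 7168 is the hidden size (it matches DeepSeek-R1's hidden size); the helper names are hypothetical:

```python
HIDDEN_SIZE = 7168  # assumption: DeepSeek-R1 hidden size, as seen in the generated code


def inductor_xnumel(s72: int) -> int:
    """Element count from the generated Triton launch, with s72 = number of tokens."""
    return HIDDEN_SIZE * s72 + HIDDEN_SIZE * ((-s72) % 4)


def round_up(n: int, multiple: int) -> int:
    """Smallest multiple of `multiple` that is >= n (ceiling division, then scale)."""
    return -(-n // multiple) * multiple


for s72 in (1, 3, 4, 5, 32, 33):
    rows = inductor_xnumel(s72) // HIDDEN_SIZE
    print(f"s72={s72:>2} -> kernel covers {rows} rows")
```

Only token counts that are not already a multiple of 4 get extra rows (for example, 5 tokens become 8 rows), which is why the failure depends on the batch size.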

Possibly related to:

if self.is_hopper:
    # We pad unconditionally (even if shape is already divisible by 4)
    # to support dynamic shape for input_2d.shape[0] in torch.compile
    x = torch.nn.functional.pad(input_2d,
                                (0, 0, 0, -input_2d.shape[0] % 4))
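For reference, `torch.nn.functional.pad` reads its pad tuple starting from the last dimension, so `(0, 0, 0, k)` leaves the hidden dimension alone and appends `k` zero rows after the last token. Because Python's `%` is non-negative for a positive divisor, `-input_2d.shape[0] % 4` is 0 when the row count is already a multiple of 4. The pure-Python sketch below mimics that on a list of rows; the final slice back to the original row count is an assumption about how the padded GEMM output is consumed, not code quoted from vLLM, and both function names are hypothetical:

```python
from typing import List

Row = List[float]


def pad_rows_to_multiple_of_4(rows: List[Row]) -> List[Row]:
    """Mimic F.pad(x, (0, 0, 0, -x.shape[0] % 4)) on a 2D list of rows."""
    num_pad = -len(rows) % 4  # 0 when len(rows) is already divisible by 4
    width = len(rows[0]) if rows else 0
    return rows + [[0.0] * width for _ in range(num_pad)]


def run_padded(rows: List[Row]) -> List[Row]:
    """Hypothetical pad -> kernel -> slice flow; the 'kernel' just doubles each value."""
    padded = pad_rows_to_multiple_of_4(rows)
    out = [[2.0 * v for v in row] for row in padded]  # stand-in for the fp8 GEMM
    return out[: len(rows)]  # drop the padding rows again


x = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]  # 3 tokens, hidden size 2
print(len(pad_rows_to_multiple_of_4(x)))  # 4
print(run_padded(x))  # [[2.0, 4.0], [6.0, 8.0], [10.0, 12.0]]
```

Any buffer that the padded tensor (or a kernel fused with the pad) writes into must therefore hold the rounded-up row count, not the original one.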

This is failing on a store:

Address Instruction
0x0000002ad1ea96c0 <+3520> PRMT R15, R8, 0x5410, R9
0x0000002ad1ea96d0 <+3536> @p2 EXIT
*> 0x0000002ad1ea96e0 <+3552> STG.E.128 desc[UR6][R26.64], R12
=> 0x0000002ad1ea96f0 <+3568> EXIT
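`STG.E.128` is a 128-bit (16-byte) global store, and the preceding `PRMT R15, R8, 0x5410, R9` packs the low 16 bits of two registers into one, which suggests each thread stores eight 16-bit elements at once. One hypothesis consistent with this trace (not confirmed in the issue) is that the destination was sized for the unpadded token count while the kernel iterates over the padded `xnumel`. The sketch below estimates how far such a write would run past the end of the allocation; the 2-byte element size, the 7168 hidden size, and the function names are all assumptions:

```python
STORE_BYTES = 16    # STG.E.128 writes 128 bits per thread
ELEM_BYTES = 2      # assumption: 16-bit (bf16/fp16) output elements
HIDDEN_SIZE = 7168  # assumption: DeepSeek-R1 hidden size


def elements_per_store() -> int:
    """How many 16-bit elements one 128-bit store covers."""
    return STORE_BYTES // ELEM_BYTES


def overflow_bytes(allocated_rows: int, num_tokens: int) -> int:
    """Bytes written past the end if the kernel covers round_up(num_tokens, 4) rows
    but the destination only holds `allocated_rows` rows (hypothetical scenario)."""
    written_rows = num_tokens + (-num_tokens % 4)
    return max(0, (written_rows - allocated_rows) * HIDDEN_SIZE * ELEM_BYTES)


print(elements_per_store())  # 8
print(overflow_bytes(5, 5))  # 43008: 3 padding rows * 7168 * 2 bytes
print(overflow_bytes(8, 5))  # 0: buffer already sized for the padded rows
```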

There is also another failure with DBO:

Both TP and DBO need to be enabled for the issue to be triggered:

$ vllm serve deepseek-ai/DeepSeek-V2-Lite --disable-uvicorn-access-log --trust-remote-code --enable-dbo --dbo-decode-token-threshold 32 --tensor-parallel 2
...
Failed: Cuda error /workspace/csrc/custom_all_reduce.cuh:455 'an illegal memory access was encountered'
Failed: Cuda error /workspace/csrc/custom_all_reduce.cuh:455 'an illegal memory access was encountered'
