Closed as not planned
Labels: bug (Something isn't working), stale (Over 90 days of inactivity), torch.compile
Description
Your current environment
The output of `python collect_env.py`
PyTorch version: 2.7.0a0+git295f2ed
Is debug build: False
CUDA used to build PyTorch: N/A
ROCM used to build PyTorch: 6.3.42133-1b9c17779
OS: Ubuntu 22.04.5 LTS (x86_64)
GCC version: (Ubuntu 11.4.0-1ubuntu1~22.04) 11.4.0
Clang version: 18.0.0git (https://github.com/RadeonOpenCompute/llvm-project roc-6.3.1 24491 1e0fda770a2079fbd71e4b70974d74f62fd3af10)
CMake version: version 3.31.6
Libc version: glibc-2.35
Python version: 3.12.9 (main, Feb 5 2025, 08:49:00) [GCC 11.4.0] (64-bit runtime)
Python platform: Linux-5.15.0-116-generic-x86_64-with-glibc2.35
Is CUDA available: True
CUDA runtime version: Could not collect
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration: AMD Instinct MI300X (gfx942:sramecc+:xnack-)
Nvidia driver version: Could not collect
cuDNN version: Could not collect
HIP runtime version: 6.3.42133
MIOpen runtime version: 3.3.0
Is XNNPACK available: True
CPU:
Architecture: x86_64
CPU op-mode(s): 32-bit, 64-bit
Address sizes: 52 bits physical, 57 bits virtual
Byte Order: Little Endian
CPU(s): 192
On-line CPU(s) list: 0-191
Vendor ID: AuthenticAMD
Model name: AMD EPYC 9654 96-Core Processor
CPU family: 25
Model: 17
Thread(s) per core: 1
Core(s) per socket: 96
Socket(s): 2
Stepping: 1
Frequency boost: enabled
CPU max MHz: 3707.8120
CPU min MHz: 1500.0000
BogoMIPS: 4792.60
Flags: fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ht syscall nx mmxext fxsr_opt pdpe1gb rdtscp lm constant_tsc rep_good nopl nonstop_tsc cpuid extd_apicid aperfmperf rapl pni pclmulqdq monitor ssse3 fma cx16 pcid sse4_1 sse4_2 x2apic movbe popcnt aes xsave avx f16c rdrand lahf_lm cmp_legacy svm extapic cr8_legacy abm sse4a misalignsse 3dnowprefetch osvw ibs skinit wdt tce topoext perfctr_core perfctr_nb bpext perfctr_llc mwaitx cpb cat_l3 cdp_l3 invpcid_single hw_pstate ssbd mba ibrs ibpb stibp ibrs_enhanced vmmcall fsgsbase bmi1 avx2 smep bmi2 erms invpcid cqm rdt_a avx512f avx512dq rdseed adx smap avx512ifma clflushopt clwb avx512cd sha_ni avx512bw avx512vl xsaveopt xsavec xgetbv1 xsaves cqm_llc cqm_occup_llc cqm_mbm_total cqm_mbm_local avx512_bf16 clzero irperf xsaveerptr rdpru wbnoinvd amd_ppin cppc arat npt lbrv svm_lock nrip_save tsc_scale vmcb_clean flushbyasid decodeassists pausefilter pfthreshold avic v_vmsave_vmload vgif v_spec_ctrl avx512vbmi umip pku ospke avx512_vbmi2 gfni vaes vpclmulqdq avx512_vnni avx512_bitalg avx512_vpopcntdq la57 rdpid overflow_recov succor smca fsrm flush_l1d
Virtualization: AMD-V
L1d cache: 6 MiB (192 instances)
L1i cache: 6 MiB (192 instances)
L2 cache: 192 MiB (192 instances)
L3 cache: 768 MiB (24 instances)
NUMA node(s): 2
NUMA node0 CPU(s): 0-95
NUMA node1 CPU(s): 96-191
Vulnerability Gather data sampling: Not affected
Vulnerability Itlb multihit: Not affected
Vulnerability L1tf: Not affected
Vulnerability Mds: Not affected
Vulnerability Meltdown: Not affected
Vulnerability Mmio stale data: Not affected
Vulnerability Reg file data sampling: Not affected
Vulnerability Retbleed: Not affected
Vulnerability Spec rstack overflow: Mitigation; safe RET
Vulnerability Spec store bypass: Mitigation; Speculative Store Bypass disabled via prctl and seccomp
Vulnerability Spectre v1: Mitigation; usercopy/swapgs barriers and __user pointer sanitization
Vulnerability Spectre v2: Mitigation; Enhanced / Automatic IBRS; IBPB conditional; STIBP disabled; RSB filling; PBRSB-eIBRS Not affected; BHI Not affected
Vulnerability Srbds: Not affected
Vulnerability Tsx async abort: Not affected
Versions of relevant libraries:
[pip3] numpy==1.26.4
[pip3] pyzmq==26.4.0
[pip3] torch==2.7.0a0+git295f2ed
[pip3] torchvision==0.21.0+7af6987
[pip3] transformers==4.51.0
[pip3] triton==3.2.0+gite5be006a
[conda] Could not collect
ROCM Version: 6.3.42133-1b9c17779
Neuron SDK Version: N/A
vLLM Version: 0.8.3
vLLM Build Flags:
CUDA Archs: Not Set; ROCm: Disabled; Neuron: Disabled
GPU Topology:
============================ ROCm System Management Interface ============================
================================ Weight between two GPUs =================================
GPU0 GPU1 GPU2 GPU3 GPU4 GPU5 GPU6 GPU7
GPU0 0 15 15 15 15 15 15 15
GPU1 15 0 15 15 15 15 15 15
GPU2 15 15 0 15 15 15 15 15
GPU3 15 15 15 0 15 15 15 15
GPU4 15 15 15 15 0 15 15 15
GPU5 15 15 15 15 15 0 15 15
GPU6 15 15 15 15 15 15 0 15
GPU7 15 15 15 15 15 15 15 0
================================= Hops between two GPUs ==================================
GPU0 GPU1 GPU2 GPU3 GPU4 GPU5 GPU6 GPU7
GPU0 0 1 1 1 1 1 1 1
GPU1 1 0 1 1 1 1 1 1
GPU2 1 1 0 1 1 1 1 1
GPU3 1 1 1 0 1 1 1 1
GPU4 1 1 1 1 0 1 1 1
GPU5 1 1 1 1 1 0 1 1
GPU6 1 1 1 1 1 1 0 1
GPU7 1 1 1 1 1 1 1 0
=============================== Link Type between two GPUs ===============================
GPU0 GPU1 GPU2 GPU3 GPU4 GPU5 GPU6 GPU7
GPU0 0 XGMI XGMI XGMI XGMI XGMI XGMI XGMI
GPU1 XGMI 0 XGMI XGMI XGMI XGMI XGMI XGMI
GPU2 XGMI XGMI 0 XGMI XGMI XGMI XGMI XGMI
GPU3 XGMI XGMI XGMI 0 XGMI XGMI XGMI XGMI
GPU4 XGMI XGMI XGMI XGMI 0 XGMI XGMI XGMI
GPU5 XGMI XGMI XGMI XGMI XGMI 0 XGMI XGMI
GPU6 XGMI XGMI XGMI XGMI XGMI XGMI 0 XGMI
GPU7 XGMI XGMI XGMI XGMI XGMI XGMI XGMI 0
======================================= Numa Nodes =======================================
GPU[0] : (Topology) Numa Node: 0
GPU[0] : (Topology) Numa Affinity: 0
GPU[1] : (Topology) Numa Node: 0
GPU[1] : (Topology) Numa Affinity: 0
GPU[2] : (Topology) Numa Node: 0
GPU[2] : (Topology) Numa Affinity: 0
GPU[3] : (Topology) Numa Node: 0
GPU[3] : (Topology) Numa Affinity: 0
GPU[4] : (Topology) Numa Node: 1
GPU[4] : (Topology) Numa Affinity: 1
GPU[5] : (Topology) Numa Node: 1
GPU[5] : (Topology) Numa Affinity: 1
GPU[6] : (Topology) Numa Node: 1
GPU[6] : (Topology) Numa Affinity: 1
GPU[7] : (Topology) Numa Node: 1
GPU[7] : (Topology) Numa Affinity: 1
================================== End of ROCm SMI Log ===================================
PYTORCH_TUNABLEOP_TUNING=0
PYTORCH_TUNABLEOP_ENABLED=1
PYTORCH_ROCM_ARCH=gfx942
LD_LIBRARY_PATH=/opt/rocm/lib:/usr/local/lib:
PYTORCH_TUNABLEOP_FILENAME=/app/afo_tune_device_%d_full.csv
NCCL_CUMEM_ENABLE=0
PYTORCH_NVML_BASED_CUDA_CHECK=1
TORCHINDUCTOR_COMPILE_THREADS=1
CUDA_MODULE_LOADING=LAZY
🐛 Describe the bug
When running Llama-4-Scout-17B-16E-Instruct with the following combination of vars/args:
- VLLM_USE_V1=1
- quantization="fp8"
- enforce_eager=False
The model produces gibberish. To reproduce the issue, run the following script:
```bash
VLLM_WORKER_MULTIPROC_METHOD=spawn SAFETENSORS_FAST_GPU=1 VLLM_USE_V1=1 python example.py
```

```python
# example.py
from vllm import LLM, SamplingParams


def test():
    prompts = [
        "The color of the sky is blue but sometimes it can also be",
        "The capital of France is",
    ]
    sampling_params = SamplingParams(temperature=0.8,
                                     top_p=0.95,
                                     max_tokens=256)
    llm = LLM(
        model="/app/model/Llama-4-Scout-17B-16E-Instruct/",
        tensor_parallel_size=4,
        quantization="fp8",
        max_model_len=8192,
        enforce_eager=False,
    )
    outputs = llm.generate(prompts, sampling_params)
    # Print the outputs.
    for output in outputs:
        prompt = output.prompt
        generated_text = output.outputs[0].text
        print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")


if __name__ == "__main__":
    test()
```
This issue occurs only when `enforce_eager=False`. Setting `enforce_eager=True` generates reasonable output. Furthermore, removing the CUDA graph padding from `num_input_tokens` in `vllm/v1/worker/gpu_model_runner.py` seems to resolve the issue:
```python
# Remove CUDA graph padding
# if (self.use_cuda_graph
#         and num_scheduled_tokens <= self.cudagraph_batch_sizes[-1]):
#     # Use piecewise CUDA graphs.
#     # Add padding to the batch size.
#     num_input_tokens = self.vllm_config.pad_for_cudagraph(
#         num_scheduled_tokens)
# else:
#     # Eager mode.
#     num_input_tokens = num_scheduled_tokens
num_input_tokens = num_scheduled_tokens
attn_metadata.num_input_tokens = num_input_tokens
```
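The quoted branch pads `num_scheduled_tokens` up to a batch size for which a CUDA graph was captured. The sketch below models that rounding in plain Python, assuming `pad_for_cudagraph` returns the smallest captured size that is at least the requested token count. `CAPTURE_SIZES` is an illustrative list, not necessarily the sizes vLLM 0.8.3 captured in this run, and `pad_for_cudagraph` / `num_input_tokens_for` here are hypothetical stand-ins for the methods referenced in `gpu_model_runner.py`.

```python
import bisect

# Hypothetical capture sizes for illustration; the real list depends on
# the engine's compilation config.
CAPTURE_SIZES = [1, 2, 4] + list(range(8, 513, 8))


def pad_for_cudagraph(num_tokens: int, capture_sizes: list[int]) -> int:
    """Round num_tokens up to the smallest captured size >= num_tokens.

    Sketch only: assumes capture_sizes is sorted ascending and that
    num_tokens does not exceed the largest captured size.
    """
    idx = bisect.bisect_left(capture_sizes, num_tokens)
    return capture_sizes[idx]


def num_input_tokens_for(num_scheduled_tokens: int, use_cuda_graph: bool,
                         capture_sizes: list[int]) -> int:
    """Mirror the quoted branch: pad only when a captured graph can be used."""
    if use_cuda_graph and num_scheduled_tokens <= capture_sizes[-1]:
        return pad_for_cudagraph(num_scheduled_tokens, capture_sizes)
    # Eager mode, or a batch larger than any captured graph: no padding.
    return num_scheduled_tokens


if __name__ == "__main__":
    for n in (1, 3, 5, 9, 200, 600):
        padded = num_input_tokens_for(n, True, CAPTURE_SIZES)
        print(f"{n} scheduled -> {padded} input tokens ({padded - n} padding)")
```

With `enforce_eager=True`, `use_cuda_graph` is presumably false, so the eager branch is taken and no padding rows are added, which is consistent with the report that eager mode produces reasonable output.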
Sample outputs
With CUDA graph padding
Generated text: ' seen as sky blue if I have seen sunlight.\n\n• counters old, counters...'
Generated text: " given as well as capital France. \n\nThe port, int. The enclosure is a big..."
Without CUDA graph padding
Generated text: 'red, orange, or other colors depending on the time of day and atmospheric conditions. ...'
Generated text: "known as the City of Light. It is famous for its art, fashion, and culture. If you are planning to visit Paris,..."
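One way padding alone could change outputs, offered purely as a hypothetical illustration and not as the diagnosed root cause of this issue: assuming the runner keeps only the first `num_scheduled_tokens` rows after a padded forward pass, purely row-wise operations are unaffected by the extra rows, but any operation that reduces across the token dimension (for example, a dynamic per-tensor activation scale of the kind an FP8 linear layer may compute) also sees the padding rows. The toy below further assumes the padding slots can hold stale values; the helper names (`per_tensor_scale`, `fake_quantize`, `run_padded`) are made up, and the uniform rounding is a crude stand-in for a real FP8 cast.

```python
FP8_E4M3_MAX = 448.0  # largest finite value of float8 e4m3fn


def per_tensor_scale(rows):
    """Dynamic per-tensor scale: abs-max over every row, padding included."""
    amax = max(abs(v) for row in rows for v in row)
    return amax / FP8_E4M3_MAX if amax > 0 else 1.0


def fake_quantize(rows, scale):
    """Uniform rounding as a crude stand-in for a low-precision cast.

    Real FP8 is a floating-point format, not a uniform grid; this only
    shows that the result depends on the scale, which depends on all rows.
    """
    return [[round(v / scale) * scale for v in row] for row in rows]


def run_padded(real_rows, padding_rows):
    """Quantize real + padding rows together, then keep only the real rows."""
    batch = real_rows + padding_rows
    out = fake_quantize(batch, per_tensor_scale(batch))
    return out[: len(real_rows)]  # analogous to slicing to num_scheduled_tokens


real = [[0.5, -1.25], [2.0, 0.75]]
unpadded = run_padded(real, [])
zero_padded = run_padded(real, [[0.0, 0.0], [0.0, 0.0]])
# Hypothetical leftover data in the padding slots.
stale_padded = run_padded(real, [[900.0, -3.0], [0.1, 0.2]])

print(unpadded == zero_padded)   # True: zero rows do not change the abs-max
print(unpadded == stale_padded)  # False: stale values inflate the scale
print(stale_padded)              # small real values collapse to 0.0
```

Whether any kernel in the Llama 4 FP8 path behaves like this is not established by the report; the sketch only shows why comparing padded and unpadded runs is a meaningful experiment.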