Closed
Labels: bug
Description
Your current environment
The output of `python collect_env.py`
Collecting environment information...
PyTorch version: 2.4.0+cu121
Is debug build: False
CUDA used to build PyTorch: 12.1
ROCM used to build PyTorch: N/A
OS: Ubuntu 22.04.3 LTS (x86_64)
GCC version: (Ubuntu 12.3.0-1ubuntu1~22.04) 12.3.0
Clang version: 14.0.0-1ubuntu1.1
CMake version: version 3.29.5
Libc version: glibc-2.35
Python version: 3.10.12 (main, Nov 20 2023, 15:14:05) [GCC 11.4.0] (64-bit runtime)
Python platform: Linux-6.8.0-40-generic-x86_64-with-glibc2.35
Is CUDA available: True
CUDA runtime version: 12.5.82
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration: 
GPU 0: NVIDIA RTX A4000
GPU 1: NVIDIA RTX A4000
GPU 2: NVIDIA RTX A4000
GPU 3: NVIDIA RTX A4000
Nvidia driver version: 555.42.06
cuDNN version: Probably one of the following:
/usr/lib/x86_64-linux-gnu/libcudnn.so.8.9.7
/usr/lib/x86_64-linux-gnu/libcudnn_adv_infer.so.8.9.7
/usr/lib/x86_64-linux-gnu/libcudnn_adv_train.so.8.9.7
/usr/lib/x86_64-linux-gnu/libcudnn_cnn_infer.so.8.9.7
/usr/lib/x86_64-linux-gnu/libcudnn_cnn_train.so.8.9.7
/usr/lib/x86_64-linux-gnu/libcudnn_ops_infer.so.8.9.7
/usr/lib/x86_64-linux-gnu/libcudnn_ops_train.so.8.9.7
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True
CPU:
Architecture:                         x86_64
CPU op-mode(s):                       32-bit, 64-bit
Address sizes:                        43 bits physical, 48 bits virtual
Byte Order:                           Little Endian
CPU(s):                               48
On-line CPU(s) list:                  0-47
Vendor ID:                            AuthenticAMD
Model name:                           AMD Ryzen Threadripper 3960X 24-Core Processor
CPU family:                           23
Model:                                49
Thread(s) per core:                   2
Core(s) per socket:                   24
Socket(s):                            1
Stepping:                             0
Frequency boost:                      enabled
CPU max MHz:                          4568.1641
CPU min MHz:                          2200.0000
BogoMIPS:                             7599.93
Flags:                                fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ht syscall nx mmxext fxsr_opt pdpe1gb rdtscp lm constant_tsc rep_good nopl nonstop_tsc cpuid extd_apicid aperfmperf rapl pni pclmulqdq monitor ssse3 fma cx16 sse4_1 sse4_2 movbe popcnt aes xsave avx f16c rdrand lahf_lm cmp_legacy svm extapic cr8_legacy abm sse4a misalignsse 3dnowprefetch osvw ibs skinit wdt tce topoext perfctr_core perfctr_nb bpext perfctr_llc mwaitx cpb cat_l3 cdp_l3 hw_pstate ssbd mba ibpb stibp vmmcall fsgsbase bmi1 avx2 smep bmi2 cqm rdt_a rdseed adx smap clflushopt clwb sha_ni xsaveopt xsavec xgetbv1 cqm_llc cqm_occup_llc cqm_mbm_total cqm_mbm_local clzero irperf xsaveerptr rdpru wbnoinvd amd_ppin arat npt lbrv svm_lock nrip_save tsc_scale vmcb_clean flushbyasid decodeassists pausefilter pfthreshold avic v_vmsave_vmload vgif v_spec_ctrl umip rdpid overflow_recov succor smca sev sev_es
Virtualization:                       AMD-V
L1d cache:                            768 KiB (24 instances)
L1i cache:                            768 KiB (24 instances)
L2 cache:                             12 MiB (24 instances)
L3 cache:                             128 MiB (8 instances)
NUMA node(s):                         1
NUMA node0 CPU(s):                    0-47
Vulnerability Gather data sampling:   Not affected
Vulnerability Itlb multihit:          Not affected
Vulnerability L1tf:                   Not affected
Vulnerability Mds:                    Not affected
Vulnerability Meltdown:               Not affected
Vulnerability Mmio stale data:        Not affected
Vulnerability Reg file data sampling: Not affected
Vulnerability Retbleed:               Mitigation; untrained return thunk; SMT enabled with STIBP protection
Vulnerability Spec rstack overflow:   Mitigation; Safe RET
Vulnerability Spec store bypass:      Mitigation; Speculative Store Bypass disabled via prctl
Vulnerability Spectre v1:             Mitigation; usercopy/swapgs barriers and __user pointer sanitization
Vulnerability Spectre v2:             Mitigation; Retpolines; IBPB conditional; STIBP always-on; RSB filling; PBRSB-eIBRS Not affected; BHI Not affected
Vulnerability Srbds:                  Not affected
Vulnerability Tsx async abort:        Not affected
Versions of relevant libraries:
[pip3] mypy==1.11.1
[pip3] mypy-extensions==1.0.0
[pip3] numpy==1.26.4
[pip3] nvidia-cublas-cu12==12.1.3.1
[pip3] nvidia-cuda-cupti-cu12==12.1.105
[pip3] nvidia-cuda-nvrtc-cu12==12.1.105
[pip3] nvidia-cuda-runtime-cu12==12.1.105
[pip3] nvidia-cudnn-cu12==9.1.0.70
[pip3] nvidia-cufft-cu12==11.0.2.54
[pip3] nvidia-curand-cu12==10.3.2.106
[pip3] nvidia-cusolver-cu12==11.4.5.107
[pip3] nvidia-cusparse-cu12==12.1.0.106
[pip3] nvidia-ml-py==12.555.43
[pip3] nvidia-nccl-cu12==2.20.5
[pip3] nvidia-nvjitlink-cu12==12.5.40
[pip3] nvidia-nvtx-cu12==12.1.105
[pip3] onnx==1.14.1
[pip3] onnxruntime==1.18.1
[pip3] pyzmq==26.0.3
[pip3] sentence-transformers==3.0.1
[pip3] torch==2.4.0
[pip3] torchvision==0.19.0
[pip3] transformers==4.43.4
[pip3] transformers-stream-generator==0.0.5
[pip3] triton==3.0.0
[conda] No relevant packages
ROCM Version: Could not collect
Neuron SDK Version: N/A
vLLM Version: 0.6.0@6f81d9d941997dfad3b1dde3e77984dedffd1022
vLLM Build Flags:
CUDA Archs: Not Set; ROCm: Disabled; Neuron: Disabled
GPU Topology:
GPU0    GPU1    GPU2    GPU3    CPU Affinity    NUMA Affinity   GPU NUMA ID
GPU0     X      SYS     SYS     SYS     0-47    0               N/A
GPU1    SYS      X      SYS     SYS     0-47    0               N/A
GPU2    SYS     SYS      X      PHB     0-47    0               N/A
GPU3    SYS     SYS     PHB      X      0-47    0               N/A
Legend:
  X    = Self
  SYS  = Connection traversing PCIe as well as the SMP interconnect between NUMA nodes (e.g., QPI/UPI)
  NODE = Connection traversing PCIe as well as the interconnect between PCIe Host Bridges within a NUMA node
  PHB  = Connection traversing PCIe as well as a PCIe Host Bridge (typically the CPU)
  PXB  = Connection traversing multiple PCIe bridges (without traversing the PCIe Host Bridge)
  PIX  = Connection traversing at most a single PCIe bridge
  NV#  = Connection traversing a bonded set of # NVLinks
🐛 Describe the bug
#7652 added support for decode-token logprobs and prompt logprobs with multi-step scheduling. However, setting prompt logprobs to any non-None value causes vLLM to crash.
The test below reproduces the crash. It is not on the main branch; see commit c6f703d on the PR branch: https://github.com/neuralmagic/vllm/blob/c6f703d9ea084042cd5002e08a253b8e1f161cd7/tests/multi_step/test_correctness_llm.py#L113
```python
from typing import Optional

import pytest

from ..models.utils import check_logprobs_close

# MODELS, NUM_SCHEDULER_STEPS and NUM_PROMPTS are module-level lists defined
# elsewhere in tests/multi_step/test_correctness_llm.py.


@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["half"])
@pytest.mark.parametrize("tp_size", [1])
@pytest.mark.parametrize("max_tokens", [5])
@pytest.mark.parametrize("enforce_eager", [True])
@pytest.mark.parametrize("num_scheduler_steps", NUM_SCHEDULER_STEPS)
@pytest.mark.parametrize("num_prompts", NUM_PROMPTS)
@pytest.mark.parametrize("num_logprobs,num_prompt_logprobs", [(5, 5)])
def test_multi_step_llm_w_prompt_logprobs(
    vllm_runner,
    example_prompts,
    model: str,
    dtype: str,
    tp_size: int,
    max_tokens: int,
    enforce_eager: bool,
    num_scheduler_steps: int,
    num_prompts: int,
    num_logprobs: Optional[int],
    num_prompt_logprobs: Optional[int],
) -> None:
    """Test vLLM engine with multi-step scheduling via sync LLM Engine.

    Set up a vLLM engine instance w/ multi-step scheduling and a second
    instance w/ single-step scheduling as a ground-truth reference.
    Prompt both with the same example prompts.

    Validate:
    * Generated logprobs are all very close

    Args:
      vllm_runner: vLLM model runner fixture
      example_prompts: test fixture providing example prompts
      model: model under test (same for single- and multi-step engines)
      dtype: tensor datatype for engine to utilize
      tp_size: degree of tensor-parallelism
      max_tokens: the maximum number of tokens to generate
      enforce_eager: disable CUDA graphs and run the model eagerly
      num_scheduler_steps: for multi-step scheduling, GPU-side steps per
                           GPU -> CPU output transfer
      num_prompts: number of example prompts under test
      num_logprobs: corresponds to the `logprobs` argument to the OpenAI
                    completions endpoint; `None` -> no logprobs
      num_prompt_logprobs: number of logprobs to return for each prompt
                           token; `None` -> no prompt logprobs
    """
    # Repeat the example prompts until there are at least num_prompts of
    # them, then truncate to exactly num_prompts.
    prompts = example_prompts
    if len(prompts) < num_prompts:
        prompts = prompts * ((num_prompts // len(prompts)) + 1)
    prompts = prompts[:num_prompts]
    assert len(prompts) == num_prompts

    # Engine under test: multi-step scheduling.
    with vllm_runner(
            model,
            dtype=dtype,
            enforce_eager=enforce_eager,
            gpu_memory_utilization=0.7,
            tensor_parallel_size=tp_size,
            use_v2_block_manager=True,
            num_scheduler_steps=num_scheduler_steps,
    ) as vllm_model:
        vllm_outputs = vllm_model.generate_greedy_logprobs(
            prompts,
            max_tokens,
            num_logprobs,
            num_prompt_logprobs=num_prompt_logprobs)

    # Reference engine: default single-step scheduling.
    with vllm_runner(
            model,
            dtype=dtype,
            enforce_eager=enforce_eager,
            gpu_memory_utilization=0.7,
            tensor_parallel_size=tp_size,
    ) as vllm_model:
        single_step_vllm_outputs = vllm_model.generate_greedy_logprobs(
            prompts,
            max_tokens,
            num_logprobs,
            num_prompt_logprobs=num_prompt_logprobs)

    check_logprobs_close(
        outputs_0_lst=single_step_vllm_outputs,
        outputs_1_lst=vllm_outputs,
        name_0="single_step_vllm",
        name_1="multi_step_vllm",
    )
```
Run the test:
```
pytest -s -v tests/multi_step/test_correctness_llm.py::test_multi_step_llm_w_prompt_logprobs[5-5-10-8-True-5-1-half-JackFram/llama-160m]
```
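The bracketed test ID encodes the parameter values of the failing case. Reading it assumes pytest's convention for stacked `parametrize` decorators, where the decorator closest to the function contributes the first field (the `num_logprobs,num_prompt_logprobs` pair contributes two). The helper `decode_test_id` below is a hypothetical illustration, not part of vLLM or pytest. It splits from the left a fixed number of times so that the model name, which itself contains `-`, stays intact.

```python
# Parameter names in the order we assume pytest joins them into the test ID:
# the decorator closest to the function comes first.
PARAM_ORDER = [
    "num_logprobs", "num_prompt_logprobs", "num_prompts",
    "num_scheduler_steps", "enforce_eager", "max_tokens",
    "tp_size", "dtype", "model",
]


def decode_test_id(test_id: str) -> dict:
    """Map a bracketed pytest ID back to parameter names (values stay strings).

    Splits at most len(PARAM_ORDER) - 1 times from the left, so the last
    field (the model name, e.g. 'JackFram/llama-160m') keeps its dashes.
    """
    parts = test_id.split("-", len(PARAM_ORDER) - 1)
    if len(parts) != len(PARAM_ORDER):
        raise ValueError(
            f"expected {len(PARAM_ORDER)} fields, got {len(parts)}")
    return dict(zip(PARAM_ORDER, parts))


if __name__ == "__main__":
    params = decode_test_id("5-5-10-8-True-5-1-half-JackFram/llama-160m")
    for name, value in params.items():
        print(f"{name} = {value}")
```

Under that reading the failing case runs 10 prompts with 8 scheduler steps, which is consistent with the first size (10) in the exception below, since `pinned_buffer` is sliced to `model_input.num_queries` rows.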
The resulting exception is:
```
    def _pythonize_sampler_output(
        model_input: StatefulModelInput,
        output: SamplerOutput,
        pinned_sampled_token_buffer: torch.Tensor,
        sampled_token_ids: torch.Tensor,
        logprobs_tensor: Optional[torch.Tensor],
        cache: Optional[PythonizationCache],
    ) -> None:
        """ This function is only called when the output tensors are ready.
        See :class:`ModelOutput`.

        Modifies `output.outputs` and `pinned_sampled_token_buffer` in-place,
        adding a Pythonized output data structure
        (:class:`CompletionSequenceGroupOutput`) for each :class:`SequenceGroup`.

        Args:
          model_input
          output: sampler output
          pinned_sampled_token_token_buffer: CPU-side pinned memory
                                             (receives copy of
                                             GPU-side token buffer.)
          sampled_token_ids: GPU-side token buffer
          logprobs_tensor: GPU-side tensor containing
                           logprobs computed during sampling
        """

        assert model_input.frozen_model_input is not None

        frozen_model_input = model_input.frozen_model_input
        assert frozen_model_input.sampling_metadata is not None
        # samples generation should have been skipped
        assert not output.outputs

        pinned_buffer = pinned_sampled_token_buffer[:model_input.num_queries]

        # CPU GPU sync
>       pinned_buffer = pinned_buffer.copy_(sampled_token_ids, non_blocking=False)
E       RuntimeError: The size of tensor a (10) must match the size of tensor b (240) at non-singleton dimension 0
vllm/worker/multi_step_model_runner.py:632: RuntimeError
```
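The error comes from PyTorch's broadcasting check in `Tensor.copy_`: the source must be broadcastable to the destination. Here the destination `pinned_buffer` was sliced to `model_input.num_queries` rows (10), while `sampled_token_ids` has 240 rows; the report does not say where the 240 rows come from. The sketch below is a hypothetical pure-Python mimic of that rule (`check_copy_broadcast` is our name, not a PyTorch or vLLM API). The trailing dimension of 1 in the example shapes is an assumption, since the error only reports dimension 0.

```python
from typing import Sequence


def check_copy_broadcast(dst_shape: Sequence[int],
                         src_shape: Sequence[int]) -> None:
    """Hypothetical mimic of the shape check behind dst.copy_(src).

    Aligning shapes from the right, every source dimension must either
    equal the destination dimension or be 1; the source may not have
    more dimensions than the destination.
    """
    if len(src_shape) > len(dst_shape):
        raise RuntimeError(
            f"source shape {tuple(src_shape)} has more dimensions than "
            f"destination shape {tuple(dst_shape)}")
    offset = len(dst_shape) - len(src_shape)
    for i, (a, b) in enumerate(zip(dst_shape[offset:], src_shape)):
        if b != 1 and a != b:
            # Same wording as the PyTorch error in the traceback above.
            raise RuntimeError(
                f"The size of tensor a ({a}) must match the size of "
                f"tensor b ({b}) at non-singleton dimension {i + offset}")


if __name__ == "__main__":
    check_copy_broadcast((10, 1), (10, 1))  # one sampled token per query: OK
    try:
        # Shapes from the failing run (second dimension assumed to be 1).
        check_copy_broadcast((10, 1), (240, 1))
    except RuntimeError as err:
        print(err)
```

A source row count of 1 would broadcast, but any other count that differs from `num_queries` fails at dimension 0, which is exactly the mismatch reported above.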