[Bug]: failed to test tests/lora/test_layers.py::test_embeddings[True-512-cuda:1-1] #9794

@NaNAGISaSA

Your current environment

The output of `python collect_env.py`
Collecting environment information...
WARNING 10-29 04:15:30 _custom_ops.py:19] Failed to import from vllm._C with ModuleNotFoundError("No module named 'vllm._C'")
PyTorch version: 2.4.0+cu121
Is debug build: False
CUDA used to build PyTorch: 12.1
ROCM used to build PyTorch: N/A

OS: Ubuntu 22.04.4 LTS (x86_64)
GCC version: (Ubuntu 11.4.0-1ubuntu1~22.04) 11.4.0
Clang version: Could not collect
CMake version: Could not collect
Libc version: glibc-2.35

Python version: 3.12.7 (main, Oct  1 2024, 08:52:12) [GCC 11.4.0] (64-bit runtime)
Python platform: Linux-5.15.0-105-generic-x86_64-with-glibc2.35
Is CUDA available: True
CUDA runtime version: Could not collect
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration: 
GPU 0: NVIDIA A100-SXM4-40GB
GPU 1: NVIDIA A100-SXM4-40GB

Nvidia driver version: 535.183.06
cuDNN version: Could not collect
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

CPU:
Architecture:                       x86_64
CPU op-mode(s):                     32-bit, 64-bit
Address sizes:                      43 bits physical, 48 bits virtual
Byte Order:                         Little Endian
CPU(s):                             256
On-line CPU(s) list:                0-255
Vendor ID:                          AuthenticAMD
Model name:                         AMD EPYC 7742 64-Core Processor
CPU family:                         23
Model:                              49
Thread(s) per core:                 2
Core(s) per socket:                 64
Socket(s):                          2
Stepping:                           0
Frequency boost:                    enabled
CPU max MHz:                        2250.0000
CPU min MHz:                        1500.0000
BogoMIPS:                           4491.84
Flags:                              fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ht syscall nx mmxext fxsr_opt pdpe1gb rdtscp lm constant_tsc rep_good nopl nonstop_tsc cpuid extd_apicid aperfmperf rapl pni pclmulqdq monitor ssse3 fma cx16 sse4_1 sse4_2 x2apic movbe popcnt aes xsave avx f16c rdrand lahf_lm cmp_legacy svm extapic cr8_legacy abm sse4a misalignsse 3dnowprefetch osvw ibs skinit wdt tce topoext perfctr_core perfctr_nb bpext perfctr_llc mwaitx cpb cat_l3 cdp_l3 hw_pstate ssbd mba ibrs ibpb stibp vmmcall fsgsbase bmi1 avx2 smep bmi2 cqm rdt_a rdseed adx smap clflushopt clwb sha_ni xsaveopt xsavec xgetbv1 cqm_llc cqm_occup_llc cqm_mbm_total cqm_mbm_local clzero irperf xsaveerptr rdpru wbnoinvd amd_ppin arat npt lbrv svm_lock nrip_save tsc_scale vmcb_clean flushbyasid decodeassists pausefilter pfthreshold avic v_vmsave_vmload vgif v_spec_ctrl umip rdpid overflow_recov succor smca sme sev sev_es
Virtualization:                     AMD-V
L1d cache:                          4 MiB (128 instances)
L1i cache:                          4 MiB (128 instances)
L2 cache:                           64 MiB (128 instances)
L3 cache:                           512 MiB (32 instances)
NUMA node(s):                       2
NUMA node0 CPU(s):                  0-63,128-191
NUMA node1 CPU(s):                  64-127,192-255
Vulnerability Gather data sampling: Not affected
Vulnerability Itlb multihit:        Not affected
Vulnerability L1tf:                 Not affected
Vulnerability Mds:                  Not affected
Vulnerability Meltdown:             Not affected
Vulnerability Mmio stale data:      Not affected
Vulnerability Retbleed:             Mitigation; untrained return thunk; SMT enabled with STIBP protection
Vulnerability Spec rstack overflow: Mitigation; safe RET
Vulnerability Spec store bypass:    Mitigation; Speculative Store Bypass disabled via prctl and seccomp
Vulnerability Spectre v1:           Mitigation; usercopy/swapgs barriers and __user pointer sanitization
Vulnerability Spectre v2:           Mitigation; Retpolines, IBPB conditional, STIBP always-on, RSB filling, PBRSB-eIBRS Not affected
Vulnerability Srbds:                Not affected
Vulnerability Tsx async abort:      Not affected

Versions of relevant libraries:
[pip3] flashinfer==0.1.6+cu121torch2.4
[pip3] numpy==1.26.4
[pip3] nvidia-cublas-cu12==12.1.3.1
[pip3] nvidia-cuda-cupti-cu12==12.1.105
[pip3] nvidia-cuda-nvrtc-cu12==12.1.105
[pip3] nvidia-cuda-runtime-cu12==12.1.105
[pip3] nvidia-cudnn-cu12==9.1.0.70
[pip3] nvidia-cufft-cu12==11.0.2.54
[pip3] nvidia-curand-cu12==10.3.2.106
[pip3] nvidia-cusolver-cu12==11.4.5.107
[pip3] nvidia-cusparse-cu12==12.1.0.106
[pip3] nvidia-ml-py==12.560.30
[pip3] nvidia-nccl-cu12==2.20.5
[pip3] nvidia-nvjitlink-cu12==12.6.77
[pip3] nvidia-nvtx-cu12==12.1.105
[pip3] pyzmq==26.2.0
[pip3] torch==2.4.0
[pip3] torchvision==0.19.0
[pip3] transformers==4.45.2
[pip3] triton==3.0.0
[conda] Could not collect
ROCM Version: Could not collect
Neuron SDK Version: N/A
vLLM Version: 0.6.3
vLLM Build Flags:
CUDA Archs: Not Set; ROCm: Disabled; Neuron: Disabled
GPU Topology:
GPU0	GPU1	CPU Affinity	NUMA Affinity	GPU NUMA ID
GPU0	 X 	NV12	64-127,192-255	1		N/A
GPU1	NV12	 X 	64-127,192-255	1		N/A

Legend:

  X    = Self
  SYS  = Connection traversing PCIe as well as the SMP interconnect between NUMA nodes (e.g., QPI/UPI)
  NODE = Connection traversing PCIe as well as the interconnect between PCIe Host Bridges within a NUMA node
  PHB  = Connection traversing PCIe as well as a PCIe Host Bridge (typically the CPU)
  PXB  = Connection traversing multiple PCIe bridges (without traversing the PCIe Host Bridge)
  PIX  = Connection traversing at most a single PCIe bridge
  NV#  = Connection traversing a bonded set of # NVLinks

Model Input Dumps

None

🐛 Describe the bug

  • pytest tests/lora/test_layers.py::test_embeddings[True-512-cuda:0-1]: passed
  • pytest tests/lora/test_layers.py::test_embeddings[True-512-cuda:1-1]: failed

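It seems that DummyLoRAManager().init_random_lora places the LoRA weights on the wrong device: in the traceback below, lora.lora_a (the `weight` argument) is on cuda:0 while the input is on cuda:1. Error message: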

=========================================================================================== test session starts ===========================================================================================
platform linux -- Python 3.12.7, pytest-8.3.3, pluggy-1.5.0
rootdir: /root/vllm
configfile: pyproject.toml
plugins: anyio-4.6.2.post1
collected 1 item                                                                                                                                                                                          

tests/lora/test_layers.py F                                                                                                                                                                         [100%]

================================================================================================ FAILURES =================================================================================================
___________________________________________________________________________________ test_embeddings[True-512-cuda:1-1] ____________________________________________________________________________________

dist_init = None, num_loras = 1, device = 'cuda:1', vocab_size = 512, stage = True

    @torch.inference_mode()
    @pytest.mark.parametrize("num_loras", [1, 2, 4, 8])
    @pytest.mark.parametrize("device", CUDA_DEVICES)
    @pytest.mark.parametrize("vocab_size", [512, 32000, 64000, 128000])
    @pytest.mark.parametrize("stage", STAGES)
    def test_embeddings(dist_init, num_loras, device, vocab_size, stage) -> None:
    
        torch.set_default_device(device)
        max_loras = 8
        punica_wrapper = PunicaWrapper(8192, 256, device)
        lora_config = LoRAConfig(max_loras=max_loras,
                                 max_lora_rank=8,
                                 lora_dtype=torch.float16)
    
        def create_random_embedding_layer():
            embedding = VocabParallelEmbedding(vocab_size, 256)
            embedding.weight.data = torch.rand_like(embedding.weight.data)
            embedding.weight.data[vocab_size:, :] = 0
            lora_embedding = VocabParallelEmbeddingWithLoRA(embedding)
            lora_embedding.create_lora_weights(max_loras, lora_config)
    
            return embedding, lora_embedding
    
        for i in range(10):
            set_random_seed(i)
    
            id_to_index = get_random_id_to_index(num_loras, max_loras)
            embedding, lora_embedding = create_random_embedding_layer()
            lora_embedding.set_mapping(punica_wrapper)
            lora_dict, _ = populate_loras(
                id_to_index,
                layer=lora_embedding,
                layer_weights=embedding.weight.T,
            )
    
            inputs, index_mapping, prompt_mapping = create_random_inputs(
                active_lora_ids=list(lora_dict.keys()),
                num_inputs=num_loras * 3,
                input_size=(200, ),
                input_range=(1, vocab_size),
            )
            lora_mapping = LoRAMapping(index_mapping,
                                       prompt_mapping,
                                       is_prefill=stage)
            punica_wrapper.update_metadata(lora_mapping, id_to_index, max_loras,
                                           vocab_size,
                                           lora_config.lora_extra_vocab_size)
    
            lora_result = lora_embedding(torch.cat(inputs))
    
            expected_results: List[torch.Tensor] = []
            for input_, lora_id in zip(inputs, prompt_mapping):
                lora = lora_dict[lora_id]
                result = embedding(input_)
>               after_a = F.embedding(
                    input_,
                    lora.lora_a,
                )

tests/lora/test_layers.py:242: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
/usr/local/lib/python3.12/dist-packages/torch/nn/functional.py:2236: in embedding
    return handle_torch_function(
/usr/local/lib/python3.12/dist-packages/torch/overrides.py:1630: in handle_torch_function
    result = mode.__torch_function__(public_api, types, args, kwargs)
/usr/local/lib/python3.12/dist-packages/torch/utils/_device.py:79: in __torch_function__
    return func(*args, **kwargs)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _

input = tensor([ 36, 331, 238, 230, 382, 423, 240, 180, 416, 450, 480, 129, 485, 348,
        223, 312, 257, 249,  30, 111, 41...    324, 503, 414,  97,  23, 229, 426, 194, 113, 366, 147, 411, 326, 236,
        374, 254, 448,  55], device='cuda:1')
weight = tensor([[0.3990, 0.5167, 0.0249,  ..., 0.7967, 0.4150, 0.8203],
        [0.2290, 0.9096, 0.1183,  ..., 0.9601, 0.2093,... ..., 0.2047, 0.2683, 0.8661],
        [0.9411, 0.3439, 0.2431,  ..., 0.2671, 0.1570, 0.2273]],
       device='cuda:0')
padding_idx = -1, max_norm = None, norm_type = 2.0, scale_grad_by_freq = False, sparse = False

    def embedding(
        input: Tensor,
        weight: Tensor,
        padding_idx: Optional[int] = None,
        max_norm: Optional[float] = None,
        norm_type: float = 2.0,
        scale_grad_by_freq: bool = False,
        sparse: bool = False,
    ) -> Tensor:
        r"""Generate a simple lookup table that looks up embeddings in a fixed dictionary and size.
    
        This module is often used to retrieve word embeddings using indices.
        The input to the module is a list of indices, and the embedding matrix,
        and the output is the corresponding word embeddings.
    
        See :class:`torch.nn.Embedding` for more details.
    
        .. note::
            Note that the analytical gradients of this function with respect to
            entries in :attr:`weight` at the row specified by :attr:`padding_idx`
            are expected to differ from the numerical ones.
    
        .. note::
            Note that `:class:`torch.nn.Embedding` differs from this function in
            that it initializes the row of :attr:`weight` specified by
            :attr:`padding_idx` to all zeros on construction.
    
        Args:
            input (LongTensor): Tensor containing indices into the embedding matrix
            weight (Tensor): The embedding matrix with number of rows equal to the maximum possible index + 1,
                and number of columns equal to the embedding size
            padding_idx (int, optional): If specified, the entries at :attr:`padding_idx` do not contribute to the gradient;
                                         therefore, the embedding vector at :attr:`padding_idx` is not updated during training,
                                         i.e. it remains as a fixed "pad".
            max_norm (float, optional): If given, each embedding vector with norm larger than :attr:`max_norm`
                                        is renormalized to have norm :attr:`max_norm`.
                                        Note: this will modify :attr:`weight` in-place.
            norm_type (float, optional): The p of the p-norm to compute for the :attr:`max_norm` option. Default ``2``.
            scale_grad_by_freq (bool, optional): If given, this will scale gradients by the inverse of frequency of
                                                    the words in the mini-batch. Default ``False``.
            sparse (bool, optional): If ``True``, gradient w.r.t. :attr:`weight` will be a sparse tensor. See Notes under
                                     :class:`torch.nn.Embedding` for more details regarding sparse gradients.
    
        Shape:
            - Input: LongTensor of arbitrary shape containing the indices to extract
            - Weight: Embedding matrix of floating point type with shape `(V, embedding_dim)`,
              where V = maximum index + 1 and embedding_dim = the embedding size
            - Output: `(*, embedding_dim)`, where `*` is the input shape
    
        Examples::
    
            >>> # a batch of 2 samples of 4 indices each
            >>> input = torch.tensor([[1, 2, 4, 5], [4, 3, 2, 9]])
            >>> # an embedding matrix containing 10 tensors of size 3
            >>> embedding_matrix = torch.rand(10, 3)
            >>> # xdoctest: +IGNORE_WANT("non-deterministic")
            >>> F.embedding(input, embedding_matrix)
            tensor([[[ 0.8490,  0.9625,  0.6753],
                     [ 0.9666,  0.7761,  0.6108],
                     [ 0.6246,  0.9751,  0.3618],
                     [ 0.4161,  0.2419,  0.7383]],
    
                    [[ 0.6246,  0.9751,  0.3618],
                     [ 0.0237,  0.7794,  0.0528],
                     [ 0.9666,  0.7761,  0.6108],
                     [ 0.3385,  0.8612,  0.1867]]])
    
            >>> # example with padding_idx
            >>> weights = torch.rand(10, 3)
            >>> weights[0, :].zero_()
            >>> embedding_matrix = weights
            >>> input = torch.tensor([[0, 2, 0, 5]])
            >>> F.embedding(input, embedding_matrix, padding_idx=0)
            tensor([[[ 0.0000,  0.0000,  0.0000],
                     [ 0.5609,  0.5384,  0.8720],
                     [ 0.0000,  0.0000,  0.0000],
                     [ 0.6262,  0.2438,  0.7471]]])
        """
        if has_torch_function_variadic(input, weight):
            return handle_torch_function(
                embedding,
                (input, weight),
                input,
                weight,
                padding_idx=padding_idx,
                max_norm=max_norm,
                norm_type=norm_type,
                scale_grad_by_freq=scale_grad_by_freq,
                sparse=sparse,
            )
        if padding_idx is not None:
            if padding_idx > 0:
                assert padding_idx < weight.size(0), "Padding_idx must be within num_embeddings"
            elif padding_idx < 0:
                assert padding_idx >= -weight.size(0), "Padding_idx must be within num_embeddings"
                padding_idx = weight.size(0) + padding_idx
        else:
            padding_idx = -1
        if max_norm is not None:
            # Note [embedding_renorm contiguous]
            # `embedding_renorm_` will call .contiguous() on input anyways, so we
            # call it here and take advantage of the improved locality in the
            # `embedding` call below too.
            input = input.contiguous()
            # Note [embedding_renorm set_grad_enabled]
            # XXX: equivalent to
            # with torch.no_grad():
            #   torch.embedding_renorm_
            # remove once script supports set_grad_enabled
            _no_grad_embedding_renorm_(weight, input, max_norm, norm_type)
>       return torch.embedding(weight, input, padding_idx, scale_grad_by_freq, sparse)
E       RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cuda:1! (when checking argument for argument index in method wrapper_CUDA__index_select)

/usr/local/lib/python3.12/dist-packages/torch/nn/functional.py:2267: RuntimeError
------------------------------------------------------------------------------------------ Captured stdout call -------------------------------------------------------------------------------------------
Created lora_id_to_index mapping: [None, None, None, None, None, 1, None, None].
========================================================================================= short test summary info =========================================================================================
FAILED tests/lora/test_layers.py::test_embeddings[True-512-cuda:1-1] - RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cuda:1! (when checking argument for argument index in method wrapper_CUDA__index_select)
============================================================================================ 1 failed in 1.86s ============================================================================================
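
To narrow down which tensors end up on the wrong device, a small check could be run in the test just before the failing `F.embedding` call. The following is only a sketch: `find_device_mismatches` is a hypothetical helper, not part of vLLM or PyTorch, and it assumes nothing beyond each object exposing a `.device` attribute whose string form looks like `cuda:1`. The stand-in objects mirror the devices reported in the traceback above so the snippet runs without a GPU.

```python
from types import SimpleNamespace
from typing import Any, Dict, Mapping


def find_device_mismatches(expected: str,
                           tensors: Mapping[str, Any]) -> Dict[str, str]:
    """Return {name: actual_device} for every tensor not on `expected`.

    Hypothetical debugging helper: works with anything that has a
    `.device` attribute (for example torch.Tensor).
    """
    mismatches = {}
    for name, tensor in tensors.items():
        actual = str(tensor.device)
        if actual != expected:
            mismatches[name] = actual
    return mismatches


# Stand-ins for the tensors in the traceback: the input is on cuda:1,
# while the LoRA A matrix was created on cuda:0.
input_ = SimpleNamespace(device="cuda:1")
lora_a = SimpleNamespace(device="cuda:0")

report = find_device_mismatches("cuda:1", {
    "input_": input_,
    "lora.lora_a": lora_a,
})
print(report)  # {'lora.lora_a': 'cuda:0'}
```

In the real test, the same call with `device`, `input_`, `lora.lora_a` and `lora.lora_b` would show whether the weights returned by `populate_loras` follow the parametrized device or fall back to cuda:0.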

Before submitting a new issue...

  • Make sure you already searched for relevant issues, and asked the chatbot living at the bottom right corner of the documentation page, which can answer lots of frequently asked questions.
