diff --git a/tests/models/decoder_only/language/test_gguf.py b/tests/models/decoder_only/language/test_gguf.py index 81b93ebdf0fc..38cea2462b44 100644 --- a/tests/models/decoder_only/language/test_gguf.py +++ b/tests/models/decoder_only/language/test_gguf.py @@ -66,12 +66,20 @@ def gguf_model(self): gguf_filename="starcoder2-3b.Q6_K.gguf", ) +DOLPHIN_CONFIG = GGUFTestConfig( + # Test VocabParallelEmbedding sharding issue. + original_model="cognitivecomputations/TinyDolphin-2.8-1.1b", + gguf_repo="tsunemoto/TinyDolphin-2.8-1.1b-GGUF", + gguf_filename="tinydolphin-2.8-1.1b.Q6_K.gguf", +) + MODELS = [ LLAMA_CONFIG, QWEN2_CONFIG, PHI3_CONFIG, GPT2_CONFIG, STABLELM_CONFIG, + DOLPHIN_CONFIG, # STARCODER_CONFIG, # broken ] @@ -107,6 +115,7 @@ def test_models( # Run unquantized model. with vllm_runner(model_name=model.original_model, + enforce_eager=True, # faster tests dtype=dtype, max_model_len=MAX_MODEL_LEN, tensor_parallel_size=tp_size) as original_model: @@ -115,6 +124,7 @@ def test_models( # Run gguf model. 
with vllm_runner(model_name=model.gguf_model, + enforce_eager=True, tokenizer_name=model.original_model, dtype=dtype, max_model_len=MAX_MODEL_LEN, diff --git a/vllm/model_executor/layers/vocab_parallel_embedding.py b/vllm/model_executor/layers/vocab_parallel_embedding.py index 65920aa61ba1..3eb5c39ccf58 100644 --- a/vllm/model_executor/layers/vocab_parallel_embedding.py +++ b/vllm/model_executor/layers/vocab_parallel_embedding.py @@ -355,7 +355,7 @@ def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor): elif isinstance(param, UninitializedParameter): shape = list(loaded_weight.shape) if output_dim is not None: - shape[output_dim] = shape[output_dim] // self.tp_size + shape[output_dim] = self.num_embeddings_per_partition param.materialize(tuple(shape), dtype=loaded_weight.dtype) # If parameter does not have output dim, then it should @@ -381,7 +381,7 @@ def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor): else: assert loaded_weight.shape[output_dim] == self.org_vocab_size - # Copy the data. + # Copy the data. Select chunk corresponding to current shard. loaded_weight = loaded_weight.narrow(output_dim, start_idx, shard_size) if current_platform.is_hpu():