@sirejdua (Contributor) commented Jul 1, 2024

This PR adds support for a draft worker with TP==1 and a target worker with TP>1. Support for draft worker TP>1 will come in a 2nd PR.

This PR makes use of #5414 to wrap the MLPSpeculatorWorker with a SmallerTPProposerWorker.
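
As a rough illustration of the wrapping idea (a toy sketch, not the vLLM implementation): the class names `ToyProposer` and `ToySmallerTpWrapper` below are hypothetical, and the assumption is that only ranks inside the smaller draft TP group run the proposer while the remaining target ranks skip the draft step.

```python
from dataclasses import dataclass
from typing import List, Optional


@dataclass
class ToyProposer:
    """Hypothetical stand-in for a draft-model worker (not a vLLM class)."""
    num_speculative_tokens: int

    def propose(self, last_token: int) -> List[int]:
        # Pretend the draft model simply counts upward from the last token.
        return [last_token + i + 1 for i in range(self.num_speculative_tokens)]


class ToySmallerTpWrapper:
    """Hypothetical wrapper: runs the proposer only on draft-group ranks."""

    def __init__(self, proposer: ToyProposer, rank: int, draft_tp: int,
                 target_tp: int):
        if draft_tp > target_tp:
            raise ValueError("draft TP must not exceed target TP")
        self.proposer = proposer
        # Assumption: ranks [0, draft_tp) draft; the other ranks sit it out.
        self.is_active = rank < draft_tp

    def propose(self, last_token: int) -> Optional[List[int]]:
        if not self.is_active:
            return None  # non-draft ranks skip the draft forward pass
        return self.proposer.propose(last_token)


# Target TP=2, draft TP=1: only rank 0 produces proposals.
wrappers = [
    ToySmallerTpWrapper(ToyProposer(5), rank, draft_tp=1, target_tp=2)
    for rank in range(2)
]
proposals = [w.propose(last_token=10) for w in wrappers]
print(proposals)  # [[11, 12, 13, 14, 15], None]
```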

Adds a test case for
ibm-granite/granite-3b-code-instruct{-accelerator} to test_draft_model_tp_lt_target_model_tp2.

FIX #5809
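
A minimal sketch of the configuration this PR targets, using the engine kwargs that the new test case passes. The helper names `build_spec_decode_kwargs` and `make_llm` are hypothetical, and forwarding the kwargs through `vllm.LLM` is an assumption about the vLLM version this PR is based on; the test itself goes through an async engine.

```python
def build_spec_decode_kwargs(target_tp: int = 2, draft_tp: int = 1) -> dict:
    """Build engine kwargs for a TP>1 target with a smaller-TP draft model."""
    if draft_tp > target_tp:
        raise ValueError("draft TP must not exceed target TP")
    return {
        "model": "ibm-granite/granite-3b-code-instruct",
        "speculative_model":
        "ibm-granite/granite-3b-code-instruct-accelerator",
        "num_speculative_tokens": 5,
        "tensor_parallel_size": target_tp,
        "speculative_draft_tensor_parallel_size": draft_tp,
        # Required for spec decode in this version, per the test's comments.
        "use_v2_block_manager": True,
        # Skip CUDA graph capture to keep start-up fast.
        "enforce_eager": True,
    }


def make_llm(kwargs: dict):
    """Hypothetical helper; needs vLLM installed and at least 2 GPUs."""
    from vllm import LLM  # imported lazily so the sketch loads without vLLM
    return LLM(**kwargs)


kwargs = build_spec_decode_kwargs()
print(kwargs["tensor_parallel_size"],
      kwargs["speculative_draft_tensor_parallel_size"])
```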

@njhill (Member) left a comment

Thank you @sirejdua, very clean changes!

@njhill (Member) commented Jul 1, 2024

@sirejdua you need to merge in the latest main and resolve the conflicts

@njhill njhill mentioned this pull request Jul 1, 2024
@njhill njhill enabled auto-merge (squash) July 2, 2024 01:24
@cadedaniel (Collaborator) left a comment

awesome!

@sirejdua-db (Contributor) commented

The test failure here looks like a config issue. After fixing it locally, I still see a token-generation mismatch: the speculative output diverges from the baseline at index 11. Any idea what could be causing this? I will continue debugging in the morning.

============================= test session starts ==============================
platform linux -- Python 3.11.9, pytest-8.2.2, pluggy-1.5.0 -- /root/venv/bin/python3
cachedir: .pytest_cache
rootdir: /root/vllm-private
configfile: pyproject.toml
plugins: rerunfailures-14.0, shard-0.1.2, forked-1.6.0, anyio-4.4.0, asyncio-0.23.7
asyncio: mode=Mode.STRICT
collecting ... collected 4 items / 3 deselected / 1 selected
Running 1 items in this shard: tests/spec_decode/e2e/test_integration_dist_tp2.py::test_draft_model_tp_lt_target_model_tp2[1-2-per_test_common_llm_kwargs1-test_llm_kwargs1-baseline_llm_kwargs0-common_llm_kwargs0]

tests/spec_decode/e2e/test_integration_dist_tp2.py::test_draft_model_tp_lt_target_model_tp2[1-2-per_test_common_llm_kwargs1-test_llm_kwargs1-baseline_llm_kwargs0-common_llm_kwargs0] FAILED [100%]

=================================== FAILURES ===================================
_ test_draft_model_tp_lt_target_model_tp2[1-2-per_test_common_llm_kwargs1-test_llm_kwargs1-baseline_llm_kwargs0-common_llm_kwargs0] _

test_llm_generator = <function create_llm_generator.<locals>.generator_outer at 0x7be9af90c9a0>
baseline_llm_generator = <function create_llm_generator.<locals>.generator_outer at 0x7be9af90cb80>
batch_size = 2

    @pytest.mark.skipif(torch.cuda.device_count() < 2,
                        reason="Need at least 2 GPUs to run the test.")
    @pytest.mark.parametrize(
        "common_llm_kwargs",
        [{
            # Skip cuda graph recording for fast test.
            "enforce_eager": True,
    
            # Required for spec decode.
            "use_v2_block_manager": True,
            "tensor_parallel_size": 2,
    
            # Use AsyncLLM engine, so that the engine runs in its own process.
            # Otherwise, since vLLM does not follow true SPMD, the test runner
            # process will have both the engine and the rank0 worker. NCCL is not
            # cleaned up properly, and its server host thread leaks, causing the
            # second run of the test to fail with internal NCCL error.
            "use_async": True,
        }])
    @pytest.mark.parametrize("baseline_llm_kwargs", [{}])
    @pytest.mark.parametrize(
        "per_test_common_llm_kwargs, test_llm_kwargs",
        [(
            {
                # Use a small model for a fast test.
                # Note this is repeated in the test body; to initialize a tokenizer.
                "model": "JackFram/llama-68m",
            },
            {
                "speculative_model": "JackFram/llama-68m",
                "num_speculative_tokens": 5,
                "speculative_draft_tensor_parallel_size": 1,
            }),
            ({
                "model": "ibm-granite/granite-3b-code-instruct",
            },
            {
                "speculative_model":
                "ibm-granite/granite-3b-code-instruct-accelerator",
                "num_speculative_tokens": 5,
                "speculative_draft_tensor_parallel_size": 1,
            })
        ])
    @pytest.mark.parametrize("batch_size", [2])
    @pytest.mark.parametrize("seed", [1])
    def test_draft_model_tp_lt_target_model_tp2(test_llm_generator,
                                                baseline_llm_generator,
                                                batch_size: int):
        """Verify spec decode works well with smaller tp for draft models.
        """
>       run_greedy_equality_correctness_test(baseline_llm_generator,
                                             test_llm_generator,
                                             batch_size,
                                             max_output_len=32,
                                             force_output_len=True)

tests/spec_decode/e2e/test_integration_dist_tp2.py:118: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

baseline_llm_generator = <function create_llm_generator.<locals>.generator_outer at 0x7be9af90cb80>
test_llm_generator = <function create_llm_generator.<locals>.generator_outer at 0x7be9af90c9a0>
batch_size = 2, max_output_len = 32, force_output_len = True
print_tokens = False

    def run_greedy_equality_correctness_test(baseline_llm_generator,
                                             test_llm_generator,
                                             batch_size,
                                             max_output_len,
                                             force_output_len: bool,
                                             print_tokens: bool = False):
        """Helper method that compares the outputs of both the baseline LLM and
        the test LLM. It asserts greedy equality, e.g. that the outputs are exactly
        the same when temperature is zero.
        """
        temperature = 0.0
    
        prompts = [
            "Hello, my name is",
            "The president of the United States is",
            "The capital of France is",
            "The future of AI is",
            "San Francisco is know for its",
            "Facebook was created in 2004 by",
            "Curious George is a",
            "Python 3.11 brings improvements to its",
        ]
    
        prompts = [prompt for prompt, _ in zip(cycle(prompts), range(batch_size))]
    
        # If the test requires that we generated max_output_len tokens, then set the
        # sampling params to ignore eos token.
        ignore_eos = force_output_len
    
        sampling_params = SamplingParams(
            max_tokens=max_output_len,
            ignore_eos=ignore_eos,
            temperature=temperature,
        )
    
        spec_batch_tokens, spec_batch_token_ids = get_output_from_llm_generator(
            test_llm_generator, prompts, sampling_params)
    
        (baseline_batch_tokens,
         baseline_batch_token_ids) = get_output_from_llm_generator(
             baseline_llm_generator, prompts, sampling_params)
    
        assert len(baseline_batch_token_ids) == len(prompts)
        assert len(spec_batch_token_ids) == len(prompts)
    
        for i, (baseline_token_ids, baseline_tokens, spec_token_ids,
                spec_tokens) in enumerate(
                    zip(baseline_batch_token_ids, baseline_batch_tokens,
                        spec_batch_token_ids, spec_batch_tokens)):
            if print_tokens:
                print(f'{i=} {baseline_tokens=}')
                print(f'{i=}     {spec_tokens=}')
            print(f'{i=} {baseline_token_ids=}')
            print(f'{i=}     {spec_token_ids=}')
>           assert baseline_token_ids == spec_token_ids
E           AssertionError: assert [428, 10195, ...461, 439, ...] == [428, 10195, ...461, 439, ...]
E             
E             At index 11 diff: 997 != 79
E             
E             Full diff:
E               [
E                   428,
E                   10195,...
E             
E             ...Full output truncated (42 lines hidden), use '-vv' to show

tests/spec_decode/e2e/conftest.py:287: AssertionError
----------------------------- Captured stdout call -----------------------------
gpu memory used (GB): 0=0.55; 1=0.55; 2=0.55; 3=0.55; 4=0.55; 5=0.55; 6=0.55; 7=0.55; 
Done waiting for free GPU memory on devices devices=[0, 1, 2, 3, 4, 5, 6, 7] (threshold_bytes/2**30=2.0) dur_s=0.00
use_async=True
Creating baseline_or_test='test' LLM for test_name='test_draft_model_tp_lt_target_model_tp2[1-2-per_test_common_llm_kwargs1-test_llm_kwargs1-baseline_llm_kwargs0-common_llm_kwargs0]'. kwargs={'enforce_eager': True, 'use_v2_block_manager': True, 'tensor_parallel_size': 2, 'model': 'ibm-granite/granite-3b-code-instruct', 'speculative_model': 'ibm-granite/granite-3b-code-instruct-accelerator', 'num_speculative_tokens': 5, 'speculative_draft_tensor_parallel_size': 1}
WARNING 07-02 05:17:09 config.py:1378] Casting torch.float16 to torch.bfloat16.
WARNING 07-02 05:17:09 config.py:1431] The model's config.json does not contain any of the following keys to determine the original maximum length of the model: ['max_position_embeddings', 'n_positions', 'max_seq_len', 'seq_length', 'model_max_length', 'max_sequence_length', 'max_seq_length', 'seq_len']. Assuming the model's maximum length is 2048.
(_AsyncLLMEngine pid=401551) INFO 07-02 05:17:14 llm_engine.py:169] Initializing an LLM engine (v0.5.0.post1) with config: model='ibm-granite/granite-3b-code-instruct', speculative_config=SpeculativeConfig(draft_model='ibm-granite/granite-3b-code-instruct-accelerator', num_spec_tokens=5), tokenizer='ibm-granite/granite-3b-code-instruct', skip_tokenizer_init=False, tokenizer_mode=auto, revision=None, rope_scaling=None, rope_theta=None, tokenizer_revision=None, trust_remote_code=False, dtype=torch.bfloat16, max_seq_len=2048, download_dir=None, load_format=LoadFormat.AUTO, tensor_parallel_size=2, disable_custom_all_reduce=False, quantization=None, enforce_eager=True, kv_cache_dtype=auto, quantization_param_path=None, device_config=cuda, decoding_config=DecodingConfig(guided_decoding_backend='outlines'), observability_config=ObservabilityConfig(otlp_traces_endpoint=None), seed=0, served_model_name=ibm-granite/granite-3b-code-instruct)
INFO 07-02 05:17:15 async_llm_engine.py:592] Received request 370bd3e9ffd64db6b48c00b3222e5335: prompt: 'Hello, my name is', params: SamplingParams(n=1, best_of=1, presence_penalty=0.0, frequency_penalty=0.0, repetition_penalty=1.0, temperature=0.0, top_p=1.0, top_k=-1, min_p=0.0, seed=None, use_beam_search=False, length_penalty=1.0, early_stopping=False, stop=[], stop_token_ids=[], include_stop_str_in_output=False, ignore_eos=True, max_tokens=32, min_tokens=0, logprobs=None, prompt_logprobs=None, skip_special_tokens=True, spaces_between_special_tokens=True, truncate_prompt_tokens=None), prompt_token_ids: None, lora_request: None.
(_AsyncLLMEngine pid=401551) INFO 07-02 05:17:19 selector.py:186] Cannot use FlashAttention-2 backend for head size 80.
(_AsyncLLMEngine pid=401551) INFO 07-02 05:17:19 selector.py:53] Using XFormers backend.
(_AsyncLLMEngine pid=401551) INFO 07-02 05:17:20 smaller_tp_proposer_worker.py:39] Wrapping {<class 'vllm.spec_decode.mlp_speculator_worker.MLPSpeculatorWorker'>} in {<class 'vllm.spec_decode.smaller_tp_proposer_worker.SmallerTpProposerWorker'>}
(_AsyncLLMEngine pid=401551) INFO 07-02 05:17:20 spec_decode_worker.py:141] Configuring SpecDecodeWorker with proposer=<class 'vllm.spec_decode.smaller_tp_proposer_worker.SmallerTpProposerWorker'>
(_AsyncLLMEngine pid=401551) INFO 07-02 05:17:20 spec_decode_worker.py:155] Configuring SpecDecodeWorker with sampler=<class 'vllm.model_executor.layers.rejection_sampler.RejectionSampler'>
(_AsyncLLMEngine pid=401551) INFO 07-02 05:17:20 utils.py:719] Found nccl from library libnccl.so.2
(_AsyncLLMEngine pid=401551) INFO 07-02 05:17:20 pynccl.py:63] vLLM is using nccl==2.20.5
(_AsyncLLMEngine pid=401551) INFO 07-02 05:17:21 custom_all_reduce_utils.py:232] reading GPU P2P access cache from /root/.config/vllm/gpu_p2p_access_cache_for_0,1.json
(_AsyncLLMEngine pid=401551) INFO 07-02 05:17:22 weight_utils.py:218] Using model weights format ['*.safetensors']
(_AsyncLLMEngine pid=401551) INFO 07-02 05:17:23 model_runner.py:234] Loading model weights took 3.2752 GB
(_AsyncLLMEngine pid=401551) INFO 07-02 05:17:23 weight_utils.py:261] No model.safetensors.index.json found in remote.
(RayWorkerWrapper pid=401734) INFO 07-02 05:17:22 selector.py:186] Cannot use FlashAttention-2 backend for head size 80. [repeated 3x across cluster] (Ray deduplicates logs by default. Set RAY_DEDUP_LOGS=0 to disable log deduplication, or see https://docs.ray.io/en/master/ray-observability/user-guides/configure-logging.html#log-deduplication for more options.)
(RayWorkerWrapper pid=401734) INFO 07-02 05:17:22 selector.py:53] Using XFormers backend. [repeated 3x across cluster]
(RayWorkerWrapper pid=401734) INFO 07-02 05:17:20 smaller_tp_proposer_worker.py:39] Wrapping {<class 'vllm.spec_decode.mlp_speculator_worker.MLPSpeculatorWorker'>} in {<class 'vllm.spec_decode.smaller_tp_proposer_worker.SmallerTpProposerWorker'>}
(RayWorkerWrapper pid=401734) INFO 07-02 05:17:20 spec_decode_worker.py:141] Configuring SpecDecodeWorker with proposer=<class 'vllm.spec_decode.smaller_tp_proposer_worker.SmallerTpProposerWorker'>
(RayWorkerWrapper pid=401734) INFO 07-02 05:17:20 spec_decode_worker.py:155] Configuring SpecDecodeWorker with sampler=<class 'vllm.model_executor.layers.rejection_sampler.RejectionSampler'>
(RayWorkerWrapper pid=401734) INFO 07-02 05:17:20 utils.py:719] Found nccl from library libnccl.so.2
(RayWorkerWrapper pid=401734) INFO 07-02 05:17:20 pynccl.py:63] vLLM is using nccl==2.20.5
(RayWorkerWrapper pid=401734) INFO 07-02 05:17:21 custom_all_reduce_utils.py:232] reading GPU P2P access cache from /root/.config/vllm/gpu_p2p_access_cache_for_0,1.json
(_AsyncLLMEngine pid=401551) INFO 07-02 05:17:23 weight_utils.py:218] Using model weights format ['*.safetensors'] [repeated 2x across cluster]
(_AsyncLLMEngine pid=401551) INFO 07-02 05:17:29 distributed_gpu_executor.py:56] # GPU blocks: 25256, # CPU blocks: 1638
(_AsyncLLMEngine pid=401551) INFO 07-02 05:17:28 model_runner.py:234] Loading model weights took 3.8946 GB [repeated 2x across cluster]
(_AsyncLLMEngine pid=401551) WARNING 07-02 05:17:31 multi_step.py:57] Prompt logprob is not supported by multi step workers. (e.g., speculative decode uses multi step workers).
INFO 07-02 05:17:32 async_llm_engine.py:134] Finished request 370bd3e9ffd64db6b48c00b3222e5335.
INFO 07-02 05:17:32 async_llm_engine.py:50] Engine is gracefully shutting down.
INFO 07-02 05:17:32 async_llm_engine.py:592] Received request 2ff485e018b049379a4d64ab3138e767: prompt: 'The president of the United States is', params: SamplingParams(n=1, best_of=1, presence_penalty=0.0, frequency_penalty=0.0, repetition_penalty=1.0, temperature=0.0, top_p=1.0, top_k=-1, min_p=0.0, seed=None, use_beam_search=False, length_penalty=1.0, early_stopping=False, stop=[], stop_token_ids=[], include_stop_str_in_output=False, ignore_eos=True, max_tokens=32, min_tokens=0, logprobs=None, prompt_logprobs=None, skip_special_tokens=True, spaces_between_special_tokens=True, truncate_prompt_tokens=None), prompt_token_ids: None, lora_request: None.
INFO 07-02 05:17:32 async_llm_engine.py:134] Finished request 2ff485e018b049379a4d64ab3138e767.
INFO 07-02 05:17:32 async_llm_engine.py:50] Engine is gracefully shutting down.
gpu memory used (GB): 0=1.96; 1=67.69; 2=0.55; 3=0.55; 4=0.55; 5=0.55; 6=0.55; 7=0.55; 
gpu memory used (GB): 0=0.55; 1=0.55; 2=0.55; 3=0.55; 4=0.55; 5=0.55; 6=0.55; 7=0.55; 
Done waiting for free GPU memory on devices devices=[0, 1, 2, 3, 4, 5, 6, 7] (threshold_bytes/2**30=2.0) dur_s=5.00
use_async=True
Creating baseline_or_test='baseline' LLM for test_name='test_draft_model_tp_lt_target_model_tp2[1-2-per_test_common_llm_kwargs1-test_llm_kwargs1-baseline_llm_kwargs0-common_llm_kwargs0]'. kwargs={'enforce_eager': True, 'use_v2_block_manager': True, 'tensor_parallel_size': 2, 'model': 'ibm-granite/granite-3b-code-instruct'}
INFO 07-02 05:17:41 async_llm_engine.py:592] Received request 0835bb9de9f642b8a2a033d4baf6120d: prompt: 'Hello, my name is', params: SamplingParams(n=1, best_of=1, presence_penalty=0.0, frequency_penalty=0.0, repetition_penalty=1.0, temperature=0.0, top_p=1.0, top_k=-1, min_p=0.0, seed=None, use_beam_search=False, length_penalty=1.0, early_stopping=False, stop=[], stop_token_ids=[], include_stop_str_in_output=False, ignore_eos=True, max_tokens=32, min_tokens=0, logprobs=None, prompt_logprobs=None, skip_special_tokens=True, spaces_between_special_tokens=True, truncate_prompt_tokens=None), prompt_token_ids: None, lora_request: None.
(_AsyncLLMEngine pid=414703) INFO 07-02 05:17:43 llm_engine.py:169] Initializing an LLM engine (v0.5.0.post1) with config: model='ibm-granite/granite-3b-code-instruct', speculative_config=None, tokenizer='ibm-granite/granite-3b-code-instruct', skip_tokenizer_init=False, tokenizer_mode=auto, revision=None, rope_scaling=None, rope_theta=None, tokenizer_revision=None, trust_remote_code=False, dtype=torch.bfloat16, max_seq_len=2048, download_dir=None, load_format=LoadFormat.AUTO, tensor_parallel_size=2, disable_custom_all_reduce=False, quantization=None, enforce_eager=True, kv_cache_dtype=auto, quantization_param_path=None, device_config=cuda, decoding_config=DecodingConfig(guided_decoding_backend='outlines'), observability_config=ObservabilityConfig(otlp_traces_endpoint=None), seed=0, served_model_name=ibm-granite/granite-3b-code-instruct)
(_AsyncLLMEngine pid=414703) INFO 07-02 05:17:48 selector.py:186] Cannot use FlashAttention-2 backend for head size 80.
(_AsyncLLMEngine pid=414703) INFO 07-02 05:17:48 selector.py:53] Using XFormers backend.
(_AsyncLLMEngine pid=414703) INFO 07-02 05:17:49 utils.py:719] Found nccl from library libnccl.so.2
(_AsyncLLMEngine pid=414703) INFO 07-02 05:17:49 pynccl.py:63] vLLM is using nccl==2.20.5
(_AsyncLLMEngine pid=414703) INFO 07-02 05:17:50 custom_all_reduce_utils.py:232] reading GPU P2P access cache from /root/.config/vllm/gpu_p2p_access_cache_for_0,1.json
(_AsyncLLMEngine pid=414703) INFO 07-02 05:17:51 weight_utils.py:218] Using model weights format ['*.safetensors']
(_AsyncLLMEngine pid=414703) INFO 07-02 05:17:52 model_runner.py:234] Loading model weights took 3.2752 GB
(_AsyncLLMEngine pid=414703) INFO 07-02 05:17:53 distributed_gpu_executor.py:56] # GPU blocks: 26853, # CPU blocks: 1638
INFO 07-02 05:17:56 async_llm_engine.py:134] Finished request 0835bb9de9f642b8a2a033d4baf6120d.
INFO 07-02 05:17:56 async_llm_engine.py:50] Engine is gracefully shutting down.
INFO 07-02 05:17:56 async_llm_engine.py:592] Received request 6b8a3c3ad926480089eb72d055a8480d: prompt: 'The president of the United States is', params: SamplingParams(n=1, best_of=1, presence_penalty=0.0, frequency_penalty=0.0, repetition_penalty=1.0, temperature=0.0, top_p=1.0, top_k=-1, min_p=0.0, seed=None, use_beam_search=False, length_penalty=1.0, early_stopping=False, stop=[], stop_token_ids=[], include_stop_str_in_output=False, ignore_eos=True, max_tokens=32, min_tokens=0, logprobs=None, prompt_logprobs=None, skip_special_tokens=True, spaces_between_special_tokens=True, truncate_prompt_tokens=None), prompt_token_ids: None, lora_request: None.
INFO 07-02 05:17:56 async_llm_engine.py:134] Finished request 6b8a3c3ad926480089eb72d055a8480d.
INFO 07-02 05:17:56 async_llm_engine.py:50] Engine is gracefully shutting down.
(RayWorkerWrapper pid=414896) INFO 07-02 05:17:50 selector.py:186] Cannot use FlashAttention-2 backend for head size 80. [repeated 3x across cluster]
(RayWorkerWrapper pid=414896) INFO 07-02 05:17:50 selector.py:53] Using XFormers backend. [repeated 3x across cluster]
(RayWorkerWrapper pid=414896) INFO 07-02 05:17:49 utils.py:719] Found nccl from library libnccl.so.2
(RayWorkerWrapper pid=414896) INFO 07-02 05:17:49 pynccl.py:63] vLLM is using nccl==2.20.5
(RayWorkerWrapper pid=414896) INFO 07-02 05:17:50 custom_all_reduce_utils.py:232] reading GPU P2P access cache from /root/.config/vllm/gpu_p2p_access_cache_for_0,1.json
(RayWorkerWrapper pid=414896) INFO 07-02 05:17:51 weight_utils.py:218] Using model weights format ['*.safetensors']
(RayWorkerWrapper pid=414896) INFO 07-02 05:17:52 model_runner.py:234] Loading model weights took 3.2752 GB
i=0 baseline_token_ids=[428, 10195, 3270, 79, 461, 439, 3860, 312, 428, 10195, 11933, 997, 439, 3860, 3097, 285, 332, 2625, 428, 10195, 21694, 432, 1283, 1668, 1125, 997, 439, 3860, 5182, 9696, 436, 537]
i=0     spec_token_ids=[428, 10195, 3270, 79, 461, 439, 3860, 312, 428, 10195, 11933, 79, 821, 428, 10195, 16517, 997, 439, 3860, 3097, 285, 332, 2625, 428, 10195, 21694, 432, 1283, 1668, 1125, 79, 461]
----------------------------- Captured stderr call -----------------------------
2024-07-02 05:17:10,996	INFO worker.py:1771 -- Started a local Ray instance.
2024-07-02 05:17:40,330	INFO worker.py:1771 -- Started a local Ray instance.
=========================== short test summary info ============================
FAILED tests/spec_decode/e2e/test_integration_dist_tp2.py::test_draft_model_tp_lt_target_model_tp2[1-2-per_test_common_llm_kwargs1-test_llm_kwargs1-baseline_llm_kwargs0-common_llm_kwargs0]
======================= 1 failed, 3 deselected in 49.26s =======================
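
To make the mismatch easier to read, here is a small sketch that locates the first position where the two captured token-ID lists differ. The lists are copied from the `i=0` lines in the captured stdout; the helper name `first_divergence` is hypothetical.

```python
from itertools import zip_longest
from typing import Optional, Sequence


def first_divergence(a: Sequence[int], b: Sequence[int]) -> Optional[int]:
    """Return the first index where a and b differ, or None if identical."""
    for i, (x, y) in enumerate(zip_longest(a, b)):
        if x != y:
            return i
    return None


baseline_token_ids = [428, 10195, 3270, 79, 461, 439, 3860, 312, 428, 10195,
                      11933, 997, 439, 3860, 3097, 285, 332, 2625, 428, 10195,
                      21694, 432, 1283, 1668, 1125, 997, 439, 3860, 5182, 9696,
                      436, 537]
spec_token_ids = [428, 10195, 3270, 79, 461, 439, 3860, 312, 428, 10195,
                  11933, 79, 821, 428, 10195, 16517, 997, 439, 3860, 3097,
                  285, 332, 2625, 428, 10195, 21694, 432, 1283, 1668, 1125,
                  79, 461]

idx = first_divergence(baseline_token_ids, spec_token_ids)
print(idx, baseline_token_ids[idx], spec_token_ids[idx])  # 11 997 79
```

The first 11 tokens agree, so the two runs only part ways at a single greedy choice (997 in the baseline, 79 with speculation).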

auto-merge was automatically disabled July 2, 2024 05:45

Head branch was pushed to by a user without write access

@njhill (Member) commented Jul 2, 2024

@sirejdua looks like the tests are all passing now. The prior failure looks precision-related; if it happens again, we can consider running the affected tests with fp32.
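
If that fallback is needed, one possible shape is sketched below. It assumes the affected tests would override vLLM's `dtype` engine argument; the helper name `with_fp32` is hypothetical, and the base kwargs are the ones shown in the test's `common_llm_kwargs`.

```python
common_llm_kwargs = {
    "enforce_eager": True,
    "use_v2_block_manager": True,
    "tensor_parallel_size": 2,
    "use_async": True,
}


def with_fp32(kwargs: dict) -> dict:
    """Return a copy of the kwargs that asks the engine for float32."""
    updated = dict(kwargs)  # leave the shared dict untouched
    # Assumption: full precision reduces near-tie flips in greedy argmax.
    updated["dtype"] = "float32"
    return updated


fp32_llm_kwargs = with_fp32(common_llm_kwargs)
print(fp32_llm_kwargs["dtype"])  # float32
```

Both the baseline and the speculative run would need the same override so that the greedy-equality comparison stays like for like.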

@njhill njhill merged commit 15aba08 into vllm-project:main Jul 2, 2024
prashantgupta24 pushed a commit to opendatahub-io/vllm that referenced this pull request Jul 3, 2024
robertgshaw2-redhat pushed a commit to neuralmagic/nm-vllm that referenced this pull request Jul 7, 2024
xjpang pushed a commit to xjpang/vllm that referenced this pull request Jul 8, 2024
xjpang pushed a commit to xjpang/vllm that referenced this pull request Jul 24, 2024
Alvant pushed a commit to compressa-ai/vllm that referenced this pull request Oct 26, 2024
…llm-project#6050)

Co-authored-by: Sirej Dua <[email protected]>
Co-authored-by: Sirej Dua <Sirej Dua>
Signed-off-by: Alvant <[email protected]>
@962086838 commented

@sirejdua Great work! Is there a schedule for the second part, which adds support for a tensor-parallel MLPSpeculator?

LeiWang1999 pushed a commit to LeiWang1999/vllm-bitblas that referenced this pull request Mar 26, 2025
…llm-project#6050)

Co-authored-by: Sirej Dua <[email protected]>
Co-authored-by: Sirej Dua <Sirej Dua>
Signed-off-by: LeiWang1999 <[email protected]>
Development

Successfully merging this pull request may close these issues.

[Feature]: MLPSpeculator Tensor Parallel support

5 participants