Merged
49 commits
f97e0ae
added example
afeldman-nm Aug 21, 2024
f969241
wip:
afeldman-nm Aug 21, 2024
642d31b
first working attempt at logprobs
afeldman-nm Aug 21, 2024
a0ca262
merge; format
afeldman-nm Aug 21, 2024
ed97288
passing test; dataclass
afeldman-nm Aug 21, 2024
861e1b9
refactoring
afeldman-nm Aug 21, 2024
8bc0765
Merge branch 'main' into logprobs_merge
afeldman-nm Aug 21, 2024
a34d1ac
refactoring
afeldman-nm Aug 21, 2024
4cda5c0
Merge branch 'logprobs' into logprobs_merge
afeldman-nm Aug 21, 2024
ac8a39a
Merge branch 'main' into logprobs_merge
afeldman-nm Aug 21, 2024
1284327
removing example
afeldman-nm Aug 21, 2024
a6c1207
removed example from build pipeline
afeldman-nm Aug 21, 2024
fe42995
fixed one docstring; embedded NUM_LOGPROBS
afeldman-nm Aug 21, 2024
9fb5bbe
test refactor
afeldman-nm Aug 21, 2024
046a8b1
incremental refactors
afeldman-nm Aug 21, 2024
fa86efd
remove unnecessary conftest change
afeldman-nm Aug 21, 2024
1c0ffb6
Update vllm/model_executor/layers/sampler.py
afeldman-nm Aug 21, 2024
3babadb
refactor
afeldman-nm Aug 21, 2024
f502029
Merge branch 'afeldman-nm/logprobs' of https://github.com/neuralmagic…
afeldman-nm Aug 21, 2024
1875b37
test_multi_step comment
afeldman-nm Aug 21, 2024
3760a95
utils function docstrings
afeldman-nm Aug 21, 2024
d43308c
docstring refactors
afeldman-nm Aug 21, 2024
54db498
merge
afeldman-nm Aug 21, 2024
dfbbaf0
passing tests & formatted
afeldman-nm Aug 21, 2024
5eebfca
Merge branch 'main' into logprobs_merge
afeldman-nm Aug 21, 2024
5e23d9a
Merge branch 'main' into logprobs_merge
afeldman-nm Aug 22, 2024
717efa3
merge; format
afeldman-nm Aug 22, 2024
e0d59ce
removed incorrect SamplerOutput imports
afeldman-nm Aug 22, 2024
102fd92
formatting
afeldman-nm Aug 22, 2024
948f4ef
Update tests/multi_step/test_correctness.py
afeldman-nm Aug 22, 2024
6e6711f
fixed comment
afeldman-nm Aug 22, 2024
f61163e
merge; format
afeldman-nm Aug 23, 2024
1cc93dd
rename
afeldman-nm Aug 23, 2024
4995204
Merge branch 'logprobs' into logprobs_merge
afeldman-nm Aug 23, 2024
da5826b
test modification
afeldman-nm Aug 26, 2024
d4fb430
merge; format
afeldman-nm Aug 26, 2024
b6752e0
merge
afeldman-nm Aug 27, 2024
1e42656
formatting
afeldman-nm Aug 27, 2024
cd0fdf9
disabled logprobs pythonization when logprobs are disabled
afeldman-nm Aug 27, 2024
3fecbc4
wip
afeldman-nm Aug 27, 2024
67bd035
skip logprobs processing entirely when logprobs are not enabled; form…
afeldman-nm Aug 27, 2024
419659d
multi-step output processing; formatting
afeldman-nm Aug 27, 2024
55eaab9
wip
afeldman-nm Aug 27, 2024
bae1fb9
small fixes
afeldman-nm Aug 27, 2024
fbb75b7
reverting to no prompt-logprobs support; merged in main
afeldman-nm Aug 28, 2024
63c5582
timeout increase
afeldman-nm Aug 28, 2024
8191571
refactoring
afeldman-nm Aug 28, 2024
9a708f8
Merge branch 'main' into logprobs_no_prompt
afeldman-nm Aug 28, 2024
e54606d
upstream merge
afeldman-nm Aug 29, 2024
43 changes: 26 additions & 17 deletions tests/models/utils.py
@@ -1,7 +1,7 @@
import warnings
from typing import Dict, List, Optional, Sequence, Tuple, Union

from vllm.sequence import SampleLogprobs
from vllm.sequence import Logprob, SampleLogprobs

TokensText = Tuple[List[int], str]

@@ -38,34 +38,39 @@ def check_outputs_equal(
float]],
SampleLogprobs]]]

# Allow tokens to be represented as strings rather than IDs
TextTextLogprobs = Tuple[List[str], str, Optional[Union[List[Dict[str, float]],
List[Dict[str,
Logprob]]]]]


def check_logprobs_close(
*,
outputs_0_lst: Sequence[TokensTextLogprobs],
outputs_1_lst: Sequence[TokensTextLogprobs],
outputs_0_lst: Sequence[Union[TokensTextLogprobs, TextTextLogprobs]],
outputs_1_lst: Sequence[Union[TokensTextLogprobs, TextTextLogprobs]],
name_0: str,
name_1: str,
num_outputs_0_skip_tokens: int = 0,
warn_on_mismatch: bool = True,
):
"""
Compare the logprobs of two sequences generated by different models,
always_check_logprobs: bool = False,
) -> None:
"""Compare the logprobs of two sequences generated by different models,
which should be similar but not necessarily equal.

Arguments:

* outputs_0_lst: First sequence to compare
* outputs_0_lst: Second sequence to compare
* name_0: sequence #0 name
* name_1: sequence #1 name
* num_outputs_0_skip_tokens: If > 0, specifies the number of initial
Args:
outputs_0_lst: First sequence to compare
outputs_1_lst: Second sequence to compare
name_0: sequence #0 name
name_1: sequence #1 name
num_outputs_0_skip_tokens: If > 0, specifies the number of initial
sequence #0 tokens & logprobs to discard
before comparison, i.e. all
of sequence #1 will be compared to
sequence #0 beginning at index
num_outputs_0_skip_tokens
* warn_on_mismatch: Issue a warning if there is token-wise or text-wise
warn_on_mismatch: Issue a warning if there is token-wise or text-wise
mismatch between the two sequences
always_check_logprobs: If true, check logprobs even when tokens match
"""
assert len(outputs_0_lst) == len(outputs_1_lst)

@@ -94,8 +99,12 @@ def check_logprobs_close(
for idx, (output_id_0,
output_id_1) in enumerate(zip(output_ids_0, output_ids_1)):

# If generated tokens don't match, then
if output_id_0 != output_id_1:
is_tok_mismatch = output_id_0 != output_id_1

# If the generated tokens don't match, or logprobs should
# always be checked, then compare the logprobs
if is_tok_mismatch or always_check_logprobs:
logprobs_elem_0 = logprobs_0[idx]
logprobs_elem_1 = logprobs_1[idx]

@@ -111,7 +120,7 @@ def check_logprobs_close(
assert output_id_0 in logprobs_elem_1, fail_msg
assert output_id_1 in logprobs_elem_0, fail_msg

if warn_on_mismatch:
if warn_on_mismatch and is_tok_mismatch:
with warnings.catch_warnings():
# This ensures that repeated warnings are shown
# in the output, not just the first occurrence
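
For illustration, a minimal sketch (not part of the PR) of how the updated helper can be called with `always_check_logprobs=True`; the token IDs, text, and logprob values are made up, and the import path assumes the vLLM repo root is on `sys.path`:

from tests.models.utils import check_logprobs_close

# Hypothetical outputs in TokensTextLogprobs form:
# (token ids, generated text, per-token {token_id: logprob} dicts).
outputs_ref = [([11, 42], "hi there", [{11: -0.10, 42: -2.30},
                                        {42: -0.20, 11: -1.90}])]
outputs_new = [([11, 42], "hi there", [{11: -0.12, 42: -2.25},
                                        {42: -0.22, 11: -1.85}])]

# With always_check_logprobs=True the logprob dicts are compared even
# though the sampled tokens are identical.
check_logprobs_close(
    outputs_0_lst=outputs_ref,
    outputs_1_lst=outputs_new,
    name_0="ref",
    name_1="new",
    always_check_logprobs=True,
)
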
99 changes: 69 additions & 30 deletions tests/multi_step/test_correctness_async_llm.py
@@ -1,10 +1,12 @@
# Test the AsyncLLMEngine with multi-step-decoding

from typing import List
from typing import List, Optional

import pytest

from ..utils import RemoteOpenAIServer
from ..models.utils import check_logprobs_close
from ..utils import (completions_with_server_args, get_client_text_generations,
get_client_text_logprob_generations)

MODELS = [
"JackFram/llama-160m",
@@ -23,22 +25,6 @@
]


async def completions_with_server_args(prompts: List[str], model_name: str,
server_cli_args: List[str]):

outputs = None
with RemoteOpenAIServer(model_name, server_cli_args) as server:
async with server.get_async_client() as client:
outputs = await client.completions.create(model=model_name,
prompt=prompts,
temperature=0,
stream=False,
max_tokens=5)
assert outputs is not None

return outputs


@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize(("tp_size, pp_size"), [
(1, 1),
@@ -47,12 +33,43 @@ async def completions_with_server_args(prompts: List[str], model_name: str,
@pytest.mark.parametrize("eager_mode", [False, True])
@pytest.mark.parametrize("num_scheduler_steps", NUM_SCHEDULER_STEPS)
@pytest.mark.parametrize("num_prompts", NUM_PROMPTS)
@pytest.mark.parametrize("num_logprobs", [None, 5])
@pytest.mark.parametrize("is_async", [False, True])
@pytest.mark.asyncio
async def test_multi_step(example_prompts, model: str, tp_size: int,
pp_size: int, eager_mode: int,
num_scheduler_steps: int, num_prompts: int,
is_async: bool):
async def test_multi_step(
example_prompts,
model: str,
tp_size: int,
pp_size: int,
eager_mode: int,
num_scheduler_steps: int,
num_prompts: int,
is_async: bool,
num_logprobs: Optional[int],
) -> None:
"""Test vLLM engine with multi-step scheduling in an OpenAI-protocol
client/server environment.

Set up an engine with single-step scheduling as a ground-truth reference.

Send a completions API request to both engines with the same prompts.

Validate:
* Generated tokens match
* Generated logprobs are all very close

Args:
example_prompts: test fixture providing example prompts
model: model under test (same for single- and multi-step engines)
tp_size: degree of tensor-parallelism
pp_size: degree of pipeline-parallelism
eager_mode: whether to run the engine with enforce-eager enabled (CUDA graphs disabled)
num_scheduler_steps: for multi-step scheduling, GPU-side steps per
GPU -> CPU output transfer
num_prompts: number of example prompts under test
num_logprobs: corresponds to the `logprobs` argument to the OpenAI
completions endpoint; `None` -> no logprobs
"""

prompts = example_prompts
if len(prompts) < num_prompts:
Expand All @@ -77,14 +94,36 @@ async def test_multi_step(example_prompts, model: str, tp_size: int,
str(pp_size),
]

# Spin up client/server & issue completion API requests.
# Default `max_wait_seconds` is 240, but it was empirically raised
# 3x to 720 *just for this test* due to observed timeouts in GHA CI
ref_completions = await completions_with_server_args(
prompts, model, server_args + distributed_args)
prompts,
model,
server_args + distributed_args,
num_logprobs,
max_wait_seconds=3 * 240)
test_completions = await completions_with_server_args(
prompts, model, ms_server_args + distributed_args)

def get_text_generations(completions):
return [x.text for x in completions.choices]

ref_generations = get_text_generations(ref_completions)
test_generations = get_text_generations(test_completions)
prompts,
model,
ms_server_args + distributed_args,
num_logprobs,
max_wait_seconds=3 * 240)

# Assert multi-step scheduling produces identical tokens
# to single-step scheduling.
ref_generations = get_client_text_generations(ref_completions)
test_generations = get_client_text_generations(test_completions)
assert ref_generations == test_generations

# Assert multi-step scheduling produces nearly-identical logprobs
# to single-step scheduling.
ref_text_logprobs = get_client_text_logprob_generations(ref_completions)
test_text_logprobs = get_client_text_logprob_generations(test_completions)
check_logprobs_close(
outputs_0_lst=ref_text_logprobs,
outputs_1_lst=test_text_logprobs,
name_0="hf",
name_1="vllm",
)
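
The inline `completions_with_server_args` helper removed above now lives in `tests/utils.py`, alongside the new `get_client_text_generations` / `get_client_text_logprob_generations` accessors. Below is a rough sketch of what those helpers might look like, inferred from the removed inline version and the new call sites; the `logprobs`/`max_wait_seconds` plumbing and exact signatures are assumptions, not copied from the PR:

from typing import List, Optional

from tests.utils import RemoteOpenAIServer  # assumed import location


async def completions_with_server_args(prompts: List[str],
                                        model_name: str,
                                        server_cli_args: List[str],
                                        num_logprobs: Optional[int],
                                        max_wait_seconds: int = 240):
    """Spin up a server, send one completions request, return the response."""
    outputs = None
    # Passing max_wait_seconds straight through to RemoteOpenAIServer is an
    # assumption made for this sketch.
    with RemoteOpenAIServer(model_name,
                            server_cli_args,
                            max_wait_seconds=max_wait_seconds) as server:
        async with server.get_async_client() as client:
            outputs = await client.completions.create(model=model_name,
                                                      prompt=prompts,
                                                      temperature=0,
                                                      stream=False,
                                                      max_tokens=5,
                                                      logprobs=num_logprobs)
    assert outputs is not None
    return outputs


def get_client_text_generations(completions) -> List[str]:
    """Extract the generated text of every choice in a completions response."""
    return [choice.text for choice in completions.choices]


def get_client_text_logprob_generations(completions):
    """Extract (token strings, text, per-token top-logprob dicts) per choice,
    matching the TextTextLogprobs layout accepted by check_logprobs_close."""
    return [(choice.logprobs.tokens, choice.text, choice.logprobs.top_logprobs)
            for choice in completions.choices]
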
95 changes: 74 additions & 21 deletions tests/multi_step/test_correctness_llm.py
@@ -1,8 +1,10 @@
# Test the LLMEngine with multi-step-decoding

from typing import Optional

import pytest

from ..models.utils import check_outputs_equal
from ..models.utils import check_logprobs_close, check_outputs_equal

MODELS = [
"JackFram/llama-160m",
@@ -18,32 +20,83 @@
@pytest.mark.parametrize("enforce_eager", [True])
@pytest.mark.parametrize("num_scheduler_steps", NUM_SCHEDULER_STEPS)
@pytest.mark.parametrize("num_prompts", NUM_PROMPTS)
def test_multi_step_llm(hf_runner, vllm_runner, example_prompts, model: str,
dtype: str, tp_size: int, max_tokens: int,
enforce_eager: int, num_scheduler_steps: int,
num_prompts: int) -> None:
@pytest.mark.parametrize("num_logprobs", [None, 5])
def test_multi_step_llm(
hf_runner,
vllm_runner,
example_prompts,
model: str,
dtype: str,
tp_size: int,
max_tokens: int,
enforce_eager: int,
num_scheduler_steps: int,
num_prompts: int,
num_logprobs: Optional[int],
) -> None:
"""Test vLLM engine with multi-step scheduling via sync LLM Engine.

Set up a HuggingFace (HF) transformers model as a ground-truth reference.

Prompt them with the same example prompts.

Validate:
* Generated tokens match
* Generated logprobs are all very close

Args:
hf_runner: HF transformers model runner fixture
vllm_runner: vLLM model runner fixture
example_prompts: test fixture providing example prompts
model: model under test (same for single- and multi-step engines)
dtype: tensor datatype for engine to utilize
tp_size: degree of tensor-parallelism
max_tokens: the maximum number of tokens to generate
enforce_eager: whether to enforce eager execution (disable CUDA graphs)
num_scheduler_steps: for multi-step scheduling, GPU-side steps per
GPU -> CPU output transfer
num_prompts: number of example prompts under test
num_logprobs: corresponds to the `logprobs` argument to the OpenAI
completions endpoint; `None` -> no logprobs
"""

prompts = example_prompts
if len(prompts) < num_prompts:
prompts = prompts * ((num_prompts // len(prompts)) + 1)
prompts = prompts[:num_prompts]
assert len(prompts) == num_prompts

with vllm_runner(model,
dtype=dtype,
enforce_eager=enforce_eager,
gpu_memory_utilization=0.7,
tensor_parallel_size=tp_size,
use_v2_block_manager=True,
num_scheduler_steps=num_scheduler_steps) as vllm_model:
vllm_outputs = vllm_model.generate_greedy(prompts, max_tokens)
with vllm_runner(
model,
dtype=dtype,
enforce_eager=enforce_eager,
gpu_memory_utilization=0.7,
tensor_parallel_size=tp_size,
use_v2_block_manager=True,
num_scheduler_steps=num_scheduler_steps,
) as vllm_model:
vllm_outputs = (vllm_model.generate_greedy(prompts, max_tokens)
if num_logprobs is None else
vllm_model.generate_greedy_logprobs(
prompts, max_tokens, num_logprobs))

with hf_runner(model, dtype=dtype) as hf_model:
hf_outputs = hf_model.generate_greedy(prompts, max_tokens)

check_outputs_equal(
outputs_0_lst=hf_outputs,
outputs_1_lst=vllm_outputs,
name_0="hf",
name_1="vllm",
)
hf_outputs = (hf_model.generate_greedy(prompts, max_tokens)
if num_logprobs is None else
hf_model.generate_greedy_logprobs_limit(
prompts, max_tokens, num_logprobs))

if num_logprobs is None:
check_outputs_equal(
outputs_0_lst=hf_outputs,
outputs_1_lst=vllm_outputs,
name_0="hf",
name_1="vllm",
)
else:
check_logprobs_close(
outputs_0_lst=hf_outputs,
outputs_1_lst=vllm_outputs,
name_0="hf",
name_1="vllm",
)
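
For context, a small sketch (again, not from the PR) of the two logprob layouts this test ends up comparing: the HF runner typically yields plain float dicts while the vLLM runner yields `vllm.sequence.Logprob` objects, and `check_logprobs_close` accepts either form; all values below are made up:

from vllm.sequence import Logprob

from tests.models.utils import check_logprobs_close

# Hypothetical single-prompt outputs in TokensTextLogprobs form.
hf_out = [([11, 42], "hi there", [{11: -0.10, 42: -2.30},
                                   {42: -0.20, 11: -1.90}])]
vllm_out = [([11, 42], "hi there", [{11: Logprob(-0.11), 42: Logprob(-2.28)},
                                     {42: Logprob(-0.21), 11: Logprob(-1.88)}])]

check_logprobs_close(
    outputs_0_lst=hf_out,
    outputs_1_lst=vllm_out,
    name_0="hf",
    name_1="vllm",
)
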
3 changes: 2 additions & 1 deletion tests/spec_decode/test_multi_step_worker.py
@@ -5,9 +5,10 @@
import pytest
import torch

from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.model_executor.utils import set_random_seed
from vllm.sequence import (ExecuteModelRequest, HiddenStates, Logprob,
SamplerOutput, get_all_seq_ids)
get_all_seq_ids)
from vllm.spec_decode.draft_model_runner import TP1DraftModelRunner
from vllm.spec_decode.multi_step_worker import MultiStepWorker
from vllm.spec_decode.top1_proposer import Top1Proposer
3 changes: 2 additions & 1 deletion tests/spec_decode/test_spec_decode_worker.py
@@ -7,8 +7,9 @@
import pytest
import torch

from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.model_executor.utils import set_random_seed
from vllm.sequence import ExecuteModelRequest, SamplerOutput, SequenceOutput
from vllm.sequence import ExecuteModelRequest, SequenceOutput
from vllm.spec_decode.interfaces import SpeculativeProposals
from vllm.spec_decode.metrics import (AsyncMetricsCollector,
SpecDecodeWorkerMetrics)
4 changes: 2 additions & 2 deletions tests/spec_decode/utils.py
@@ -8,12 +8,12 @@
import torch

from vllm.engine.arg_utils import EngineArgs
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.model_executor.utils import set_random_seed
from vllm.sampling_params import SamplingParams
from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE,
CompletionSequenceGroupOutput, Logprob,
SamplerOutput, SequenceData, SequenceGroupMetadata,
SequenceOutput)
SequenceData, SequenceGroupMetadata, SequenceOutput)
from vllm.utils import get_distributed_init_method, get_ip, get_open_port
from vllm.worker.cache_engine import CacheEngine
from vllm.worker.model_runner import ModelRunner
5 changes: 3 additions & 2 deletions tests/test_sequence.py
@@ -2,9 +2,10 @@

import pytest

from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE,
CompletionSequenceGroupOutput, SamplerOutput,
SequenceData, SequenceOutput)
CompletionSequenceGroupOutput, SequenceData,
SequenceOutput)

from .core.utils import create_dummy_prompt

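
The import-only changes in the last few files switch `SamplerOutput` from `vllm.sequence` to its `vllm.model_executor.layers.sampler` location, as shown in the hunks above:

# Old import (removed by this PR's updates to the tests):
#   from vllm.sequence import SamplerOutput
# New import used throughout these tests:
from vllm.model_executor.layers.sampler import SamplerOutput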