Merged

Changes from all commits (441 commits)
a14d3a4
Merge branch 'logitsprocs' into lp_ext_merge
afeldman-nm Jun 25, 2025
1716f07
Merge branch 'lp_ext' into lp_ext_py
afeldman-nm Jun 25, 2025
b429c10
Merge branch 'main' into logitsprocs_merge
afeldman-nm Jun 25, 2025
fbdb595
comment Re: output tokens list ref
afeldman-nm Jun 25, 2025
e3dc71e
Merge branch 'logitsprocs' into logitsprocs_merge
afeldman-nm Jun 25, 2025
aa4c519
Merge branch 'main' into logitsprocs_merge
afeldman-nm Jun 25, 2025
3ae8a6b
Merge branch 'logitsprocs' into lp_ext_merge
afeldman-nm Jun 25, 2025
d58bf24
Merge branch 'lp_ext' into lp_ext_py
afeldman-nm Jun 25, 2025
77bba48
refactor
afeldman-nm Jun 25, 2025
890a9cd
refactor
afeldman-nm Jun 25, 2025
6b3ea9f
Update vllm/v1/sample/logits_processor.py
afeldman-nm Jun 25, 2025
8a8f9c2
wip
afeldman-nm Jun 25, 2025
070d71d
Merge branch 'main' into logitsprocs_merge
afeldman-nm Jun 25, 2025
5384732
feedback
afeldman-nm Jun 25, 2025
9aebc9f
Update vllm/v1/sample/sampler.py
afeldman-nm Jun 25, 2025
8bb6bf0
revert some changes
afeldman-nm Jun 25, 2025
0a88e16
refactor
afeldman-nm Jun 25, 2025
18721da
Merge branch 'logitsprocs' of https://github.com/neuralmagic/vllm int…
afeldman-nm Jun 25, 2025
dc0b23a
refactor
afeldman-nm Jun 25, 2025
21ad212
Merge branch 'main' into logitsprocs_merge
afeldman-nm Jun 25, 2025
2f0de77
argmax_invariant
afeldman-nm Jun 25, 2025
8d97a7c
batch update builder impl
afeldman-nm Jun 25, 2025
2abd24d
Merge branch 'main' into logitsprocs_merge
afeldman-nm Jun 25, 2025
d1c6607
refactor
afeldman-nm Jun 25, 2025
9fe0bc3
wip dict removal
afeldman-nm Jun 25, 2025
aa18e8f
Merge branch 'main' into logitsprocs_merge
afeldman-nm Jun 25, 2025
f7a162c
Merge branch 'main' into logitsprocs_merge
afeldman-nm Jun 26, 2025
de81e42
updated unit tests
afeldman-nm Jun 26, 2025
20928f0
refactor
afeldman-nm Jun 26, 2025
a0e5398
iterators
afeldman-nm Jun 26, 2025
d4704d7
refactor
afeldman-nm Jun 26, 2025
729729d
reorg
afeldman-nm Jun 27, 2025
9948fd3
Merge branch 'main' into logitsprocs_merge
afeldman-nm Jun 27, 2025
bc48f38
Merge branch 'main' into logitsprocs_merge
afeldman-nm Jun 27, 2025
9eeea03
feedback
afeldman-nm Jun 28, 2025
1078a24
Merge branch 'main' into logitsprocs_merge
afeldman-nm Jun 28, 2025
cd766a4
feedback
afeldman-nm Jun 28, 2025
2628f98
Merge branch 'main' into logitsprocs_merge
afeldman-nm Jun 30, 2025
2ecb37d
Merge branch 'main' into logitsprocs_merge
afeldman-nm Jun 30, 2025
64ac2cf
input batch tests
afeldman-nm Jul 1, 2025
4da82cc
Merge branch 'main' into logitsprocs_merge
afeldman-nm Jul 1, 2025
bd62df4
refactor
afeldman-nm Jul 1, 2025
8455bb6
Merge branch 'main' into logitsprocs_merge
afeldman-nm Jul 1, 2025
a6dc218
attempted fmt fix
afeldman-nm Jul 1, 2025
a870259
wip
afeldman-nm Jul 1, 2025
072ee00
wip
afeldman-nm Jul 1, 2025
55fd6e7
fixed cancellation bug
afeldman-nm Jul 1, 2025
6d4e073
Merge branch 'logitsprocs' into lp_ext_merge
afeldman-nm Jul 1, 2025
ab3a985
Merge branch 'lp_ext' into lp_ext_py
afeldman-nm Jul 1, 2025
b55f88e
Merge branch 'main' into logitsprocs_merge
afeldman-nm Jul 1, 2025
348a100
Merge branch 'logitsprocs' into lp_ext
afeldman-nm Jul 1, 2025
c397e24
Merge branch 'lp_ext' into lp_ext_py
afeldman-nm Jul 1, 2025
1217b74
wip
afeldman-nm Jul 1, 2025
402d012
Update vllm/v1/worker/gpu_model_runner.py
afeldman-nm Jul 1, 2025
7c15b43
CLI
afeldman-nm Jul 1, 2025
06fc926
pr feedback
afeldman-nm Jul 1, 2025
8d229ed
Merge branch 'main' into logitsprocs_merge
afeldman-nm Jul 1, 2025
4d0b612
Merge branch 'logitsprocs' into lp_ext
afeldman-nm Jul 1, 2025
4b1884b
Merge branch 'lp_ext' into lp_ext_py
afeldman-nm Jul 1, 2025
99c0c18
skeleton of example
afeldman-nm Jul 1, 2025
aabd1dd
fixes
afeldman-nm Jul 1, 2025
63b640c
wip
afeldman-nm Jul 1, 2025
45dade4
mem util
afeldman-nm Jul 1, 2025
d377a6b
Merge branch 'main' into logitsprocs_merge
afeldman-nm Jul 1, 2025
6ae7574
memory util
afeldman-nm Jul 1, 2025
5203324
Merge branch 'main' into logitsprocs_merge
afeldman-nm Jul 1, 2025
68aab25
Merge branch 'main' into logitsprocs_merge
afeldman-nm Jul 1, 2025
066736d
merge'
afeldman-nm Jul 2, 2025
31597e9
Merge branch 'logitsprocs' into lp_ext
afeldman-nm Jul 2, 2025
957bd86
Merge branch 'lp_ext' into lp_ext_py
afeldman-nm Jul 2, 2025
3a5564d
refactor
afeldman-nm Jul 2, 2025
663bff1
refactor
afeldman-nm Jul 3, 2025
69c2a0d
merge
afeldman-nm Jul 3, 2025
538c378
Merge branch 'main' into lp_ext
afeldman-nm Jul 3, 2025
270b184
wip
afeldman-nm Jul 3, 2025
195f651
Merge branch 'main' into lp_ext_merge
afeldman-nm Jul 3, 2025
f9df850
Merge branch 'lp_ext' into lp_ext_py
afeldman-nm Jul 3, 2025
fc9c308
py llm plumbing
afeldman-nm Jul 3, 2025
3aa383e
wip lp example
afeldman-nm Jul 3, 2025
b420aac
wip
afeldman-nm Jul 7, 2025
a475fe9
Merge branch 'main' into lp_ext_merge
afeldman-nm Jul 7, 2025
01d640c
Merge branch 'lp_ext' into lp_ext_py
afeldman-nm Jul 7, 2025
138dc07
Merge branch 'main' into lp_ext_merge
afeldman-nm Jul 7, 2025
699768a
Merge branch 'lp_ext' into lp_ext_py
afeldman-nm Jul 7, 2025
ee88fdf
Merge branch 'main' into lp_ext_py
afeldman-nm Jul 7, 2025
6a405ab
first pass at lp loading system
afeldman-nm Jul 7, 2025
0de1e73
wip
afeldman-nm Jul 8, 2025
ef51732
Merge branch 'main' into lp_ext_py
afeldman-nm Jul 8, 2025
c8e8671
loading logitsprocs
afeldman-nm Jul 8, 2025
52146dc
refactor
afeldman-nm Jul 8, 2025
e79f9ad
lp tests passing
afeldman-nm Jul 8, 2025
4c16135
Merge branch 'main' into lp_ext_py
afeldman-nm Jul 8, 2025
2e330e1
refactor
afeldman-nm Jul 8, 2025
7a60363
logitsprocs
afeldman-nm Jul 8, 2025
18129b4
example w/ dummy logitproc
afeldman-nm Jul 8, 2025
e73c00c
refactor
afeldman-nm Jul 8, 2025
f612fcf
Merge branch 'main' into lp_ext
afeldman-nm Jul 8, 2025
4af5159
entrypoint example
afeldman-nm Jul 8, 2025
be7177a
cli arg
afeldman-nm Jul 8, 2025
0ad8b1c
removed regex
afeldman-nm Jul 8, 2025
c21a2ec
fqn/entrypoint examples
afeldman-nm Jul 8, 2025
4730d7a
cli tests
afeldman-nm Jul 8, 2025
1617747
Merge branch 'main' into lp_ext
afeldman-nm Jul 8, 2025
1784079
Merge branch 'main' into lp_ext_merge
afeldman-nm Jul 10, 2025
f078ce7
tail end of merge
afeldman-nm Jul 10, 2025
129479a
Merge branch 'main' into lp_ext_merge
afeldman-nm Jul 10, 2025
ac1509f
refactor
afeldman-nm Jul 10, 2025
5b85255
wip
afeldman-nm Jul 10, 2025
d7499db
Merge branch 'main' into lp_ext_merge
afeldman-nm Jul 10, 2025
ee74904
all lp plugins are loaded; can pass lp types to LLM; refactor
afeldman-nm Jul 11, 2025
be9e750
Merge branch 'main' into lp_ext_merge
afeldman-nm Jul 11, 2025
4f80bee
unit test fix
afeldman-nm Jul 11, 2025
683b99f
typo
afeldman-nm Jul 11, 2025
f7ee5ee
refactor
afeldman-nm Jul 11, 2025
ad08c45
refactor
afeldman-nm Jul 11, 2025
bbabe50
abstract __init__
abf149 Jul 14, 2025
9e88f37
fqn
abf149 Jul 14, 2025
ae55b2e
merge
abf149 Jul 14, 2025
d08a89d
type checking
abf149 Jul 14, 2025
e7cb8e1
merge
afeldman-nm Jul 16, 2025
aae02a0
small fix
afeldman-nm Jul 16, 2025
2d807af
fixes
afeldman-nm Jul 16, 2025
c0cdd27
merge
afeldman-nm Jul 16, 2025
2a79d4a
merge
afeldman-nm Jul 16, 2025
7c49fe1
Merge branch 'main' into lp_ext_merge
afeldman-nm Jul 16, 2025
84112c0
Merge branch 'lp_ext_merge' into lp_ext
afeldman-nm Jul 16, 2025
bd243c9
refactor
afeldman-nm Jul 16, 2025
07d6056
fix
afeldman-nm Jul 16, 2025
83eca33
fix test bug
afeldman-nm Jul 16, 2025
6a4597b
cli test works again
afeldman-nm Jul 16, 2025
9d2156c
Merge branch 'main' into lp_ext_merge
afeldman-nm Jul 17, 2025
8c2d16c
LLM entrypoint testing
afeldman-nm Jul 17, 2025
35999c8
Merge branch 'main' into lp_ext_merge
afeldman-nm Jul 17, 2025
e6123e7
cli entrypoints test
afeldman-nm Jul 17, 2025
da8aa76
fixed example
afeldman-nm Jul 17, 2025
12d48a7
adding prompt tokens to added requests
afeldman-nm Jul 17, 2025
c4a76be
Merge branch 'main' into lp_ext_merge
afeldman-nm Jul 17, 2025
4a57631
initial feedback
afeldman-nm Jul 17, 2025
6697a30
Merge branch 'main' into lp_ext_merge
afeldman-nm Jul 17, 2025
5c350a3
wip
afeldman-nm Jul 17, 2025
4aa5c86
merge load.py into __init__.py
afeldman-nm Jul 17, 2025
d3099b4
refactor
afeldman-nm Jul 17, 2025
f6cbcad
resetting test.txt
afeldman-nm Jul 17, 2025
ffbc6f2
logitsprocs in input batch
afeldman-nm Jul 17, 2025
ea3c970
Merge branch 'main' into lp_ext_merge
afeldman-nm Jul 17, 2025
90472d8
merge
afeldman-nm Jul 22, 2025
961c863
wip feedback
afeldman-nm Jul 22, 2025
c50b51a
Merge branch 'main' into lp_ext_merge
afeldman-nm Jul 22, 2025
89618f1
prompt token ids ordering
afeldman-nm Jul 22, 2025
bb00b9a
Merge branch 'main' into lp_ext_merge
abf149 Jul 23, 2025
eb18dc1
merge
abf149 Jul 24, 2025
85bc835
refactor
abf149 Jul 24, 2025
4e740c5
merge
abf149 Jul 24, 2025
c60c39d
Merge branch 'main' into lp_ext_merge
abf149 Jul 25, 2025
beec9cb
reorg imports
abf149 Jul 25, 2025
d1eade5
reorg API interface
abf149 Jul 25, 2025
a6af577
Merge branch 'main' into lp_ext_merge
abf149 Jul 25, 2025
f004aff
Merge branch 'lp_ext' into lp_ext_fb
abf149 Jul 28, 2025
5420fff
Merge branch 'main' into lp_ext_merge
abf149 Jul 28, 2025
02a837a
merge
afeldman-nm Jul 28, 2025
c171aa7
disabled logitsprocs for pooling
afeldman-nm Jul 28, 2025
2941f7c
Merge branch 'lp_ext' into lp_ext_merge
afeldman-nm Jul 28, 2025
191e083
Merge branch 'main' into lp_ext_merge
afeldman-nm Jul 28, 2025
ad9ce9e
merge
afeldman-nm Jul 28, 2025
304d77d
merge
abf149 Jul 30, 2025
62578b6
Merge branch 'lp_ext' of https://github.com/neuralmagic/vllm into lp_ext
abf149 Jul 30, 2025
69232fd
Merge branch 'main' into lp_ext_merge
abf149 Jul 30, 2025
38134f8
Merge branch 'main' into lp_ext_merge
afeldman-nm Jul 31, 2025
d627759
fix
afeldman-nm Jul 31, 2025
86a7492
pooling compat
afeldman-nm Jul 31, 2025
998d9a1
Merge branch 'main' into lp_ext_merge
afeldman-nm Jul 31, 2025
bc2f8af
Merge branch 'main' into lp_ext_merge
afeldman-nm Jul 31, 2025
635fb6c
CLI argument Union type support
afeldman-nm Jul 31, 2025
86b7bd0
Union type support
afeldman-nm Jul 31, 2025
622c306
unit test
afeldman-nm Jul 31, 2025
fc4bce6
cleanup
afeldman-nm Jul 31, 2025
b743042
removed separate union test; added test cases
afeldman-nm Jul 31, 2025
470bcc9
ValueError
afeldman-nm Jul 31, 2025
1311f41
Merge branch 'cli_union_type' into lp_ext
afeldman-nm Jul 31, 2025
ccc7bc9
Fix `get_kwargs` for case where type hint is `list[Union[str, type]]`
hmellor Jul 31, 2025
2035018
reverting my fix
afeldman-nm Jul 31, 2025
b539c5a
Merge branch 'fix-list-of-union' into lp_ext
afeldman-nm Jul 31, 2025
8330bcf
cleanup
afeldman-nm Jul 31, 2025
e278f9e
import fixes
afeldman-nm Jul 31, 2025
dd7b316
Merge branch 'main' into lp_ext_merge
afeldman-nm Jul 31, 2025
fd3bbea
refactor
afeldman-nm Jul 31, 2025
7601057
online test refactor
afeldman-nm Jul 31, 2025
66b38df
merge
abf149 Aug 5, 2025
22fa931
feedback
abf149 Aug 5, 2025
729b960
linting fix
abf149 Aug 5, 2025
ae38cfa
linting
abf149 Aug 5, 2025
0e0a997
Merge branch 'main' into lp_ext_merge
afeldman-nm Aug 6, 2025
2b45e0f
added request structure
afeldman-nm Aug 6, 2025
ea38c26
dummy lp
afeldman-nm Aug 6, 2025
cbd33b9
Merge branch 'main' into lp_ext_merge
afeldman-nm Aug 6, 2025
ee633bf
always be refreshing
afeldman-nm Aug 6, 2025
16616c7
Exception if custom logitsprocs are provided to pooling model
afeldman-nm Aug 6, 2025
66f6254
offline pooling/logitproc incompat test
afeldman-nm Aug 6, 2025
04d91de
wip
afeldman-nm Aug 6, 2025
c65bde5
Merge branch 'main' into lp_ext_merge
afeldman-nm Aug 6, 2025
3f9444a
Merge branch 'main' into lp_ext_merge
afeldman-nm Aug 6, 2025
5f91ae9
revert online test changes
afeldman-nm Aug 6, 2025
ae00852
Merge branch 'main' into lp_ext_merge
afeldman-nm Aug 6, 2025
482c43f
refactor
afeldman-nm Aug 6, 2025
38c156e
wip
afeldman-nm Aug 6, 2025
38a8d50
wip
afeldman-nm Aug 12, 2025
c6680b6
merge
afeldman-nm Aug 12, 2025
632b6fc
Merge branch 'lp_ext' into lp_ext_entrypoint
afeldman-nm Aug 12, 2025
1a98864
fixes
afeldman-nm Aug 12, 2025
de43449
DummyLogitsProcessor early exist if there are no applicable requests
afeldman-nm Aug 12, 2025
aed6746
Merge branch 'main' into lp_ext_merge
afeldman-nm Aug 12, 2025
b6ecdc5
Merge branch 'lp_ext' into lp_ext_entrypoint
afeldman-nm Aug 12, 2025
0efab23
wip
afeldman-nm Aug 12, 2025
1cfd5c4
wip
afeldman-nm Aug 12, 2025
712bdb7
swap fix
afeldman-nm Aug 12, 2025
343f454
Merge branch 'lp_ext' into lp_ext_fork
afeldman-nm Aug 12, 2025
08c2ded
wip
afeldman-nm Aug 12, 2025
50ab032
Merge branch 'main' into lp_ext_merge
afeldman-nm Aug 12, 2025
aec7944
Merge branch 'main' into lp_ext_merge
afeldman-nm Aug 13, 2025
3a4597d
working multiprocessing.Process solution
afeldman-nm Aug 14, 2025
fa73c7b
refactor
afeldman-nm Aug 14, 2025
9ad80b7
Merge branch 'lp_ext' into lp_ext_mp
afeldman-nm Aug 14, 2025
0bee33c
refactor
afeldman-nm Aug 14, 2025
4a54ec7
refactor
afeldman-nm Aug 14, 2025
12ada08
wip
afeldman-nm Aug 14, 2025
9f97241
online/offline working; fixed entrypoints bug
afeldman-nm Aug 14, 2025
de3adaf
Merge branch 'main' into lp_ext_merge
afeldman-nm Aug 14, 2025
1c9e839
refactor
afeldman-nm Aug 14, 2025
d8e2cb5
linting
afeldman-nm Aug 14, 2025
cdf4a76
ci bugs
afeldman-nm Aug 14, 2025
6e920ee
Merge branch 'main' into lp_ext_merge
afeldman-nm Aug 14, 2025
27cd95c
simplify handling of pooling models in gpu_model_runner / gpu_input_b…
njhill Aug 14, 2025
bc92fc8
small fix
afeldman-nm Aug 15, 2025
5bd6a77
Merge branch 'main' into lp_ext_merge
afeldman-nm Aug 15, 2025
97a8f2b
refactor
afeldman-nm Aug 15, 2025
b283a0d
Merge branch 'main' into lp_ext_merge
afeldman-nm Aug 15, 2025
29c6544
Merge branch 'main' into lp_ext_merge
afeldman-nm Aug 15, 2025
5a2cb53
refactor
afeldman-nm Aug 15, 2025
35f10ed
online tests passing
afeldman-nm Aug 15, 2025
1457af0
offline tests pass
afeldman-nm Aug 15, 2025
d1e94f9
refactor
afeldman-nm Aug 15, 2025
5751bde
refactor
afeldman-nm Aug 15, 2025
fa76da6
Merge branch 'main' into lp_ext_merge
afeldman-nm Aug 15, 2025
f9cb99f
Merge branch 'njhill-lp_ext' into lp_ext_nick
afeldman-nm Aug 15, 2025
4e624bd
Merge branch 'main' into lp_ext_merge
afeldman-nm Aug 15, 2025
a53e3da
linting fixes
afeldman-nm Aug 15, 2025
37a340b
mypy block
afeldman-nm Aug 15, 2025
507b4f8
Merge branch 'main' into lp_ext_merge
afeldman-nm Aug 15, 2025
27de87a
revert input_batch.make_sampling_metadata back to private naming
njhill Aug 15, 2025
7cdd389
fix test isolation
njhill Aug 16, 2025
1 change: 1 addition & 0 deletions .buildkite/test-pipeline.yaml
@@ -253,6 +253,7 @@ steps:
- pytest -v -s v1/engine
- pytest -v -s v1/entrypoints
- pytest -v -s v1/sample
- pytest -v -s v1/logits_processors
- pytest -v -s v1/worker
- pytest -v -s v1/structured_output
- pytest -v -s v1/spec_decode
147 changes: 147 additions & 0 deletions examples/offline_inference/logits_processor.py
@@ -0,0 +1,147 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

"""This example demonstrates instantiating vLLM with a custom logits processor
class object.

For a basic example of implementing a custom logits processor, see
the `DummyLogitsProcessor` implementation in `vllm/test_utils.py`.

For testing purposes, a dummy logits processor is employed which, if
`target_token` is passed as a keyword argument to `SamplingParams.extra_args`,
will mask out all tokens except `target_token`.

A batch is constructed with `temperature=0.0` and 50% of requests specifying
`target_token`, and for these requests - and *only* these requests - we
expect the `target_token` to be decoded in each step, yielding an output
similar to that shown below:

Generated Outputs:
------------------------------------------------------------
Prompt: 'Hello, my name is'
Output: " ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' '"
------------------------------------------------------------
Prompt: 'The president of the United States is'
Output: " not a racist. He is a racist.\nHe's a racist because he"
------------------------------------------------------------
Prompt: 'The capital of France is'
Output: ' also also also also also also also also also also also also also
also also also'
------------------------------------------------------------
Prompt: 'The future of AI is'
Output: ' in the hands of the people.\n\nThe future of AI is in the'
------------------------------------------------------------
"""

from typing import Optional

import torch

from vllm import LLM, SamplingParams
from vllm.config import VllmConfig
from vllm.v1.sample.logits_processor import (
BatchUpdate,
LogitsProcessor,
MoveDirectionality,
)


# Hypothetical custom logits processor
class DummyLogitsProcessor(LogitsProcessor):
"""Fake logit processor to support unit testing and examples"""

def __init__(
self, vllm_config: VllmConfig, device: torch.device, is_pin_memory: bool
):
self.req_info: dict[int, SamplingParams] = {}
[Inline review comment — Contributor]
Since we assign self.req_info[index] = target_token, should that type hint be int instead of SamplingParams?

Suggested change:
-        self.req_info: dict[int, SamplingParams] = {}
+        self.req_info: dict[int, int] = {}

def is_argmax_invariant(self) -> bool:
"""Never impacts greedy sampling"""
return False

def update_state(self, batch_update: Optional[BatchUpdate]):
if not batch_update:
return

# Process added requests.
for index, params, _, _ in batch_update.added:
assert params is not None
if params.extra_args and (
target_token := params.extra_args.get("target_token")
):
self.req_info[index] = target_token

if self.req_info:
# Process removed requests.
for index in batch_update.removed:
self.req_info.pop(index, None)

# Process moved requests, unidirectional move (a->b) and swap
# (a<->b)
for adx, bdx, direct in batch_update.moved:
a_val = self.req_info.pop(adx, None)
b_val = self.req_info.pop(bdx, None)
if a_val is not None:
self.req_info[bdx] = a_val
if direct == MoveDirectionality.SWAP and b_val is not None:
self.req_info[adx] = b_val

def apply(self, logits: torch.Tensor) -> torch.Tensor:
if not self.req_info:
return logits

# Save target values before modification
rows_list = list(self.req_info.keys())
cols = torch.tensor(
[self.req_info[i] for i in rows_list],
[Inline review comment — Contributor]
That should work too, right?

Suggested change:
-            [self.req_info[i] for i in rows_list],
+            list(self.req_info.values()),

dtype=torch.long,
device=logits.device,
)
rows = torch.tensor(rows_list, dtype=torch.long, device=logits.device)
values_to_keep = logits[rows, cols].clone()

# Mask all but target tokens
logits[rows] = float("-inf")
logits[rows, cols] = values_to_keep

return logits


# Sample prompts.
prompts = [
"Hello, my name is",
"The president of the United States is",
"The capital of France is",
"The future of AI is",
]
# Create a mixture of requests which do and don't utilize the dummy logitproc
sampling_params_list = [
SamplingParams(temperature=0.0, extra_args={"target_token": 128}),
SamplingParams(temperature=0.0),
SamplingParams(temperature=0.0, extra_args={"target_token": 67}),
SamplingParams(temperature=0.0),
]


def main():
# Create an LLM.
llm = LLM(
model="facebook/opt-125m",
logits_processors=[DummyLogitsProcessor],
)
# Generate texts from the prompts.
# The output is a list of RequestOutput objects
# that contain the prompt, generated text, and other information.
outputs = llm.generate(prompts, sampling_params_list)
# Print the outputs.
print("\nGenerated Outputs:\n" + "-" * 60)
for output in outputs:
prompt = output.prompt
generated_text = output.outputs[0].text
print(f"Prompt: {prompt!r}")
print(f"Output: {generated_text!r}")
print("-" * 60)


if __name__ == "__main__":
main()
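
Commits in this PR (e.g. "fqn/entrypoint examples", "cli arg", "CLI argument Union type support") also add string-based loading of custom logits processors, so the class does not have to be imported directly. A minimal sketch of that variant, assuming `logits_processors` also accepts a fully-qualified class name string; the module path below is a hypothetical placeholder:

# Minimal sketch (assumption): reference the logits processor by its
# fully-qualified class name instead of passing the class object.
from vllm import LLM, SamplingParams

llm = LLM(
    model="facebook/opt-125m",
    # Hypothetical module path; point this at wherever the class lives.
    logits_processors=["my_package.my_logitsprocs.DummyLogitsProcessor"],
)
outputs = llm.generate(
    ["Hello, my name is"],
    SamplingParams(temperature=0.0, extra_args={"target_token": 128}),
)
print(outputs[0].outputs[0].text)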
79 changes: 66 additions & 13 deletions tests/utils.py
@@ -13,6 +13,7 @@
import time
import warnings
from contextlib import contextmanager, suppress
from multiprocessing import Process
from pathlib import Path
from typing import Any, Callable, Literal, Optional, Union

@@ -76,6 +77,23 @@ def _nvml():
class RemoteOpenAIServer:
DUMMY_API_KEY = "token-abc123" # vLLM's OpenAI server does not need API key

def _start_server(self, model: str, vllm_serve_args: list[str],
env_dict: Optional[dict[str, str]]) -> None:
"""Subclasses override this method to customize server process launch
"""
env = os.environ.copy()
# the current process might initialize cuda,
# to be safe, we should use spawn method
env['VLLM_WORKER_MULTIPROC_METHOD'] = 'spawn'
if env_dict is not None:
env.update(env_dict)
self.proc: subprocess.Popen = subprocess.Popen(
["vllm", "serve", model, *vllm_serve_args],
env=env,
stdout=sys.stdout,
stderr=sys.stderr,
)

def __init__(self,
model: str,
vllm_serve_args: list[str],
@@ -128,18 +146,7 @@ def __init__(self,
model_loader = get_model_loader(load_config)
model_loader.download_model(model_config)

env = os.environ.copy()
# the current process might initialize cuda,
# to be safe, we should use spawn method
env['VLLM_WORKER_MULTIPROC_METHOD'] = 'spawn'
if env_dict is not None:
env.update(env_dict)
self.proc = subprocess.Popen(
["vllm", "serve", model, *vllm_serve_args],
env=env,
stdout=sys.stdout,
stderr=sys.stderr,
)
self._start_server(model, vllm_serve_args, env_dict)
max_wait_seconds = max_wait_seconds or 240
self._wait_for_server(url=self.url_for("health"),
timeout=max_wait_seconds)
@@ -155,6 +162,10 @@ def __exit__(self, exc_type, exc_value, traceback):
# force kill if needed
self.proc.kill()

def _poll(self) -> Optional[int]:
"""Subclasses override this method to customize process polling"""
return self.proc.poll()

def _wait_for_server(self, *, url: str, timeout: float):
# run health check
start = time.time()
@@ -169,7 +180,7 @@ def _wait_for_server(self, *, url: str, timeout: float):
# which means the server is not ready yet.
# the stack trace is not useful, so we suppress it
# by using `raise from None`.
result = self.proc.poll()
result = self._poll()
if result is not None and result != 0:
raise RuntimeError("Server exited unexpectedly.") from None

@@ -205,6 +216,48 @@ def get_async_client(self, **kwargs):
**kwargs)


class RemoteOpenAIServerCustom(RemoteOpenAIServer):
"""Launch test server with custom child process"""

def _start_server(self, model: str, vllm_serve_args: list[str],
env_dict: Optional[dict[str, str]]) -> None:
self.proc: Process = Process(
target=self.child_process_fxn,
args=(env_dict, model,
vllm_serve_args)) # type: ignore[assignment]
self.proc.start()

def __init__(self,
model: str,
vllm_serve_args: list[str],
child_process_fxn: Callable[
[Optional[dict[str, str]], str, list[str]], None],
*,
env_dict: Optional[dict[str, str]] = None,
seed: Optional[int] = 0,
auto_port: bool = True,
max_wait_seconds: Optional[float] = None) -> None:
"""Store custom child process function then invoke superclass
constructor which will indirectly launch it."""
self.child_process_fxn = child_process_fxn
super().__init__(model=model,
vllm_serve_args=vllm_serve_args,
env_dict=env_dict,
seed=seed,
auto_port=auto_port,
max_wait_seconds=max_wait_seconds)

def _poll(self) -> Optional[int]:
return self.proc.exitcode

def __exit__(self, exc_type, exc_value, traceback):
self.proc.terminate()
self.proc.join(8)
if self.proc.is_alive():
# force kill if needed
self.proc.kill()


def _test_completion(
client: openai.OpenAI,
model: str,
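
RemoteOpenAIServerCustom lets a test drive the server from a caller-supplied function running in a multiprocessing.Process instead of a `vllm serve` subprocess. A minimal usage sketch, assuming the base class is used as a context manager and exposes get_client(); run_server_in_child and the serve arguments below are hypothetical placeholders:

# Minimal usage sketch (assumptions noted in comments).
import os
from typing import Optional

from tests.utils import RemoteOpenAIServerCustom


def run_server_in_child(env_dict: Optional[dict[str, str]], model: str,
                        vllm_serve_args: list[str]) -> None:
    # Hypothetical child entry point: apply env overrides, then start an
    # OpenAI-compatible server for `model` in this process.
    if env_dict:
        os.environ.update(env_dict)
    ...  # e.g. build and run vLLM's API server with `vllm_serve_args`


# Assumes RemoteOpenAIServer.__enter__ returns the server instance.
with RemoteOpenAIServerCustom("facebook/opt-125m",
                              ["--max-model-len", "2048"],
                              run_server_in_child) as server:
    client = server.get_client()  # get_client() assumed from the base class
    # ... issue test requests against the OpenAI-compatible endpoint ...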
Empty file.
@@ -9,11 +9,13 @@
import pytest
import torch

from tests.utils import create_new_process_for_each_test
from tests.v1.sample.utils import (LogitsprocsTestFakes, create_fake_logits,
create_penalty_tensor,
create_prompt_tokens_tensor,
fake_apply_logitsprocs,
fake_update_logitsprocs_state)
from vllm.config import VllmConfig
from vllm.platforms import current_platform
from vllm.sampling_params import SamplingParams
from vllm.utils import is_pin_memory_available
@@ -24,7 +26,7 @@
MinPLogitsProcessor,
MinTokensLogitsProcessor,
MoveDirectionality,
init_builtin_logitsprocs)
build_logitsprocs)
# yapf: enable
from vllm.v1.sample.metadata import SamplingMetadata

@@ -53,6 +55,7 @@ class LogitsProcsRequestParams:
workload_index: int
logitproc_type: LogitprocType # Logitproc enabled, specified by str id
out_tokens: list[int] # Output tokens required for min tokens test
prompt_tokens: list[int] # Dummy prompt tokens placeholder
params: SamplingParams # Settings customized for logitproc

def __init__(self, workload_index: int, logitproc_type: LogitprocType):
@@ -63,6 +66,7 @@ def __init__(self, workload_index: int, logitproc_type: LogitprocType):
# don't matter *for these tests* so use 0 as a dummy value
self.out_tokens = ([0] *
(MIN_TOKENS_LEN_THRESHOLD * random.randint(0, 2)))
self.prompt_tokens = []
self.params = _sampling_params_from_logitproc(logitproc_type)

def __str__(self):
@@ -88,11 +92,12 @@ def _generate_fake_sampling_metadata(
vocab_size,
size=np.random.randint(
1, MAX_NUM_PROMPT_TOKENS)).tolist())
logitsprocs = init_builtin_logitsprocs(
pin_memory_available=PIN_MEMORY_AVAILABLE,
max_num_reqs=MAX_NUM_REQS + 1,
device=device)

logitsprocs = build_logitsprocs(
vllm_config=VllmConfig(),
device=device,
is_pin_memory=PIN_MEMORY_AVAILABLE,
is_pooling_model=False,
)
fake_sampling_metadata = SamplingMetadata(
temperature=torch.full((batch_size, ), 0.0),
all_greedy=True,
@@ -462,15 +467,17 @@ def _generate_fake_step_update(
# Replace as many removed requests as possible with added requests
add_remove_idx = batch_update_builder.pop_removed()
batch_update_builder.added.append(
(add_remove_idx, add_req_params.params, add_req_params.out_tokens))
(add_remove_idx, add_req_params.params,
add_req_params.prompt_tokens, add_req_params.out_tokens))
persistent_batch[add_remove_idx] = add_req_params

# Append remaining added requests to end of batch
add_reqs_append = workload_params[(wdx +
num_step_add_replace):(wdx +
num_step_add)]
batch_update_builder.added.extend([
(adx + batch_size, add_req_params.params, add_req_params.out_tokens)
(adx + batch_size, add_req_params.params, add_req_params.prompt_tokens,
add_req_params.out_tokens)
for adx, add_req_params in enumerate(add_reqs_append)
])
persistent_batch.extend(add_reqs_append)
@@ -561,6 +568,7 @@ def _assert_valid(
step_idx=step_idx)


@create_new_process_for_each_test()
@pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.parametrize("reqs_per_logitproc", [REQS_PER_LOGITPROC])
@pytest.mark.parametrize("logitsprocs_under_test", _get_test_cases())