[V1] Logits processors extensibility #19912
Merged
Changes from all commits (441 commits)
a14d3a4
Merge branch 'logitsprocs' into lp_ext_merge
afeldman-nm 1716f07
Merge branch 'lp_ext' into lp_ext_py
afeldman-nm b429c10
Merge branch 'main' into logitsprocs_merge
afeldman-nm fbdb595
comment Re: output tokens list ref
afeldman-nm e3dc71e
Merge branch 'logitsprocs' into logitsprocs_merge
afeldman-nm aa4c519
Merge branch 'main' into logitsprocs_merge
afeldman-nm 3ae8a6b
Merge branch 'logitsprocs' into lp_ext_merge
afeldman-nm d58bf24
Merge branch 'lp_ext' into lp_ext_py
afeldman-nm 77bba48
refactor
afeldman-nm 890a9cd
refactor
afeldman-nm 6b3ea9f
Update vllm/v1/sample/logits_processor.py
afeldman-nm 8a8f9c2
wip
afeldman-nm 070d71d
Merge branch 'main' into logitsprocs_merge
afeldman-nm 5384732
feedback
afeldman-nm 9aebc9f
Update vllm/v1/sample/sampler.py
afeldman-nm 8bb6bf0
revert some changes
afeldman-nm 0a88e16
refactor
afeldman-nm 18721da
Merge branch 'logitsprocs' of https://github.com/neuralmagic/vllm int…
afeldman-nm dc0b23a
refactor
afeldman-nm 21ad212
Merge branch 'main' into logitsprocs_merge
afeldman-nm 2f0de77
argmax_invariant
afeldman-nm 8d97a7c
batch update builder impl
afeldman-nm 2abd24d
Merge branch 'main' into logitsprocs_merge
afeldman-nm d1c6607
refactor
afeldman-nm 9fe0bc3
wip dict removal
afeldman-nm aa18e8f
Merge branch 'main' into logitsprocs_merge
afeldman-nm f7a162c
Merge branch 'main' into logitsprocs_merge
afeldman-nm de81e42
updated unit tests
afeldman-nm 20928f0
refactor
afeldman-nm a0e5398
iterators
afeldman-nm d4704d7
refactor
afeldman-nm 729729d
reorg
afeldman-nm 9948fd3
Merge branch 'main' into logitsprocs_merge
afeldman-nm bc48f38
Merge branch 'main' into logitsprocs_merge
afeldman-nm 9eeea03
feedback
afeldman-nm 1078a24
Merge branch 'main' into logitsprocs_merge
afeldman-nm cd766a4
feedback
afeldman-nm 2628f98
Merge branch 'main' into logitsprocs_merge
afeldman-nm 2ecb37d
Merge branch 'main' into logitsprocs_merge
afeldman-nm 64ac2cf
input batch tests
afeldman-nm 4da82cc
Merge branch 'main' into logitsprocs_merge
afeldman-nm bd62df4
refactor
afeldman-nm 8455bb6
Merge branch 'main' into logitsprocs_merge
afeldman-nm a6dc218
attempted fmt fix
afeldman-nm a870259
wip
afeldman-nm 072ee00
wip
afeldman-nm 55fd6e7
fixed cancellation bug
afeldman-nm 6d4e073
Merge branch 'logitsprocs' into lp_ext_merge
afeldman-nm ab3a985
Merge branch 'lp_ext' into lp_ext_py
afeldman-nm b55f88e
Merge branch 'main' into logitsprocs_merge
afeldman-nm 348a100
Merge branch 'logitsprocs' into lp_ext
afeldman-nm c397e24
Merge branch 'lp_ext' into lp_ext_py
afeldman-nm 1217b74
wip
afeldman-nm 402d012
Update vllm/v1/worker/gpu_model_runner.py
afeldman-nm 7c15b43
CLI
afeldman-nm 06fc926
pr feedback
afeldman-nm 8d229ed
Merge branch 'main' into logitsprocs_merge
afeldman-nm 4d0b612
Merge branch 'logitsprocs' into lp_ext
afeldman-nm 4b1884b
Merge branch 'lp_ext' into lp_ext_py
afeldman-nm 99c0c18
skeleton of example
afeldman-nm aabd1dd
fixes
afeldman-nm 63b640c
wip
afeldman-nm 45dade4
mem util
afeldman-nm d377a6b
Merge branch 'main' into logitsprocs_merge
afeldman-nm 6ae7574
memory util
afeldman-nm 5203324
Merge branch 'main' into logitsprocs_merge
afeldman-nm 68aab25
Merge branch 'main' into logitsprocs_merge
afeldman-nm 066736d
merge'
afeldman-nm 31597e9
Merge branch 'logitsprocs' into lp_ext
afeldman-nm 957bd86
Merge branch 'lp_ext' into lp_ext_py
afeldman-nm 3a5564d
refactor
afeldman-nm 663bff1
refactor
afeldman-nm 69c2a0d
merge
afeldman-nm 538c378
Merge branch 'main' into lp_ext
afeldman-nm 270b184
wip
afeldman-nm 195f651
Merge branch 'main' into lp_ext_merge
afeldman-nm f9df850
Merge branch 'lp_ext' into lp_ext_py
afeldman-nm fc9c308
py llm plumbing
afeldman-nm 3aa383e
wip lp example
afeldman-nm b420aac
wip
afeldman-nm a475fe9
Merge branch 'main' into lp_ext_merge
afeldman-nm 01d640c
Merge branch 'lp_ext' into lp_ext_py
afeldman-nm 138dc07
Merge branch 'main' into lp_ext_merge
afeldman-nm 699768a
Merge branch 'lp_ext' into lp_ext_py
afeldman-nm ee88fdf
Merge branch 'main' into lp_ext_py
afeldman-nm 6a405ab
first pass at lp loading system
afeldman-nm 0de1e73
wip
afeldman-nm ef51732
Merge branch 'main' into lp_ext_py
afeldman-nm c8e8671
loading logitsprocs
afeldman-nm 52146dc
refactor
afeldman-nm e79f9ad
lp tests passing
afeldman-nm 4c16135
Merge branch 'main' into lp_ext_py
afeldman-nm 2e330e1
refactor
afeldman-nm 7a60363
logitsprocs
afeldman-nm 18129b4
example w/ dummy logitproc
afeldman-nm e73c00c
refactor
afeldman-nm f612fcf
Merge branch 'main' into lp_ext
afeldman-nm 4af5159
entrypoint example
afeldman-nm be7177a
cli arg
afeldman-nm 0ad8b1c
removed regex
afeldman-nm c21a2ec
fqn/entrypoint examples
afeldman-nm 4730d7a
cli tests
afeldman-nm 1617747
Merge branch 'main' into lp_ext
afeldman-nm 1784079
Merge branch 'main' into lp_ext_merge
afeldman-nm f078ce7
tail end of merge
afeldman-nm 129479a
Merge branch 'main' into lp_ext_merge
afeldman-nm ac1509f
refactor
afeldman-nm 5b85255
wip
afeldman-nm d7499db
Merge branch 'main' into lp_ext_merge
afeldman-nm ee74904
all lp plugins are loaded; can pass lp types to LLM; refactor
afeldman-nm be9e750
Merge branch 'main' into lp_ext_merge
afeldman-nm 4f80bee
unit test fix
afeldman-nm 683b99f
typo
afeldman-nm f7ee5ee
refactor
afeldman-nm ad08c45
refactor
afeldman-nm bbabe50
abstract __init__
abf149 9e88f37
fqn
abf149 ae55b2e
merge
abf149 d08a89d
type checking
abf149 e7cb8e1
merge
afeldman-nm aae02a0
small fix
afeldman-nm 2d807af
fixes
afeldman-nm c0cdd27
merge
afeldman-nm 2a79d4a
merge
afeldman-nm 7c49fe1
Merge branch 'main' into lp_ext_merge
afeldman-nm 84112c0
Merge branch 'lp_ext_merge' into lp_ext
afeldman-nm bd243c9
refactor
afeldman-nm 07d6056
fix
afeldman-nm 83eca33
fix test bug
afeldman-nm 6a4597b
cli test works again
afeldman-nm 9d2156c
Merge branch 'main' into lp_ext_merge
afeldman-nm 8c2d16c
LLM entrypoint testing
afeldman-nm 35999c8
Merge branch 'main' into lp_ext_merge
afeldman-nm e6123e7
cli entrypoints test
afeldman-nm da8aa76
fixed example
afeldman-nm 12d48a7
adding prompt tokens to added requests
afeldman-nm c4a76be
Merge branch 'main' into lp_ext_merge
afeldman-nm 4a57631
initial feedback
afeldman-nm 6697a30
Merge branch 'main' into lp_ext_merge
afeldman-nm 5c350a3
wip
afeldman-nm 4aa5c86
merge load.py into __init__.py
afeldman-nm d3099b4
refactor
afeldman-nm f6cbcad
resetting test.txt
afeldman-nm ffbc6f2
logitsprocs in input batch
afeldman-nm ea3c970
Merge branch 'main' into lp_ext_merge
afeldman-nm 90472d8
merge
afeldman-nm 961c863
wip feedback
afeldman-nm c50b51a
Merge branch 'main' into lp_ext_merge
afeldman-nm 89618f1
prompt token ids ordering
afeldman-nm bb00b9a
Merge branch 'main' into lp_ext_merge
abf149 eb18dc1
merge
abf149 85bc835
refactor
abf149 4e740c5
merge
abf149 c60c39d
Merge branch 'main' into lp_ext_merge
abf149 beec9cb
reorg imports
abf149 d1eade5
reorg API interface
abf149 a6af577
Merge branch 'main' into lp_ext_merge
abf149 f004aff
Merge branch 'lp_ext' into lp_ext_fb
abf149 5420fff
Merge branch 'main' into lp_ext_merge
abf149 02a837a
merge
afeldman-nm c171aa7
disabled logitsprocs for pooling
afeldman-nm 2941f7c
Merge branch 'lp_ext' into lp_ext_merge
afeldman-nm 191e083
Merge branch 'main' into lp_ext_merge
afeldman-nm ad9ce9e
merge
afeldman-nm 304d77d
merge
abf149 62578b6
Merge branch 'lp_ext' of https://github.com/neuralmagic/vllm into lp_ext
abf149 69232fd
Merge branch 'main' into lp_ext_merge
abf149 38134f8
Merge branch 'main' into lp_ext_merge
afeldman-nm d627759
fix
afeldman-nm 86a7492
pooling compat
afeldman-nm 998d9a1
Merge branch 'main' into lp_ext_merge
afeldman-nm bc2f8af
Merge branch 'main' into lp_ext_merge
afeldman-nm 635fb6c
CLI argument Union type support
afeldman-nm 86b7bd0
Union type support
afeldman-nm 622c306
unit test
afeldman-nm fc4bce6
cleanup
afeldman-nm b743042
removed separate union test; added test cases
afeldman-nm 470bcc9
ValueError
afeldman-nm 1311f41
Merge branch 'cli_union_type' into lp_ext
afeldman-nm ccc7bc9
Fix `get_kwargs` for case where type hint is `list[Union[str, type]]`
hmellor 2035018
reverting my fix
afeldman-nm b539c5a
Merge branch 'fix-list-of-union' into lp_ext
afeldman-nm 8330bcf
cleanup
afeldman-nm e278f9e
import fixes
afeldman-nm dd7b316
Merge branch 'main' into lp_ext_merge
afeldman-nm fd3bbea
refactor
afeldman-nm 7601057
online test refactor
afeldman-nm 66b38df
merge
abf149 22fa931
feedback
abf149 729b960
linting fix
abf149 ae38cfa
linting
abf149 0e0a997
Merge branch 'main' into lp_ext_merge
afeldman-nm 2b45e0f
added request structure
afeldman-nm ea38c26
dummy lp
afeldman-nm cbd33b9
Merge branch 'main' into lp_ext_merge
afeldman-nm ee633bf
always be refreshing
afeldman-nm 16616c7
Exception if custom logitsprocs are provided to pooling model
afeldman-nm 66f6254
offline pooling/logitproc incompat test
afeldman-nm 04d91de
wip
afeldman-nm c65bde5
Merge branch 'main' into lp_ext_merge
afeldman-nm 3f9444a
Merge branch 'main' into lp_ext_merge
afeldman-nm 5f91ae9
revert online test changes
afeldman-nm ae00852
Merge branch 'main' into lp_ext_merge
afeldman-nm 482c43f
refactor
afeldman-nm 38c156e
wip
afeldman-nm 38a8d50
wip
afeldman-nm c6680b6
merge
afeldman-nm 632b6fc
Merge branch 'lp_ext' into lp_ext_entrypoint
afeldman-nm 1a98864
fixes
afeldman-nm de43449
DummyLogitsProcessor early exist if there are no applicable requests
afeldman-nm aed6746
Merge branch 'main' into lp_ext_merge
afeldman-nm b6ecdc5
Merge branch 'lp_ext' into lp_ext_entrypoint
afeldman-nm 0efab23
wip
afeldman-nm 1cfd5c4
wip
afeldman-nm 712bdb7
swap fix
afeldman-nm 343f454
Merge branch 'lp_ext' into lp_ext_fork
afeldman-nm 08c2ded
wip
afeldman-nm 50ab032
Merge branch 'main' into lp_ext_merge
afeldman-nm aec7944
Merge branch 'main' into lp_ext_merge
afeldman-nm 3a4597d
working multiprocessing.Process solution
afeldman-nm fa73c7b
refactor
afeldman-nm 9ad80b7
Merge branch 'lp_ext' into lp_ext_mp
afeldman-nm 0bee33c
refactor
afeldman-nm 4a54ec7
refactor
afeldman-nm 12ada08
wip
afeldman-nm 9f97241
online/offline working; fixed entrypoints bug
afeldman-nm de3adaf
Merge branch 'main' into lp_ext_merge
afeldman-nm 1c9e839
refactor
afeldman-nm d8e2cb5
linting
afeldman-nm cdf4a76
ci bugs
afeldman-nm 6e920ee
Merge branch 'main' into lp_ext_merge
afeldman-nm 27cd95c
simplify handling of pooling models in gpu_model_runner / gpu_input_b…
njhill bc92fc8
small fix
afeldman-nm 5bd6a77
Merge branch 'main' into lp_ext_merge
afeldman-nm 97a8f2b
refactor
afeldman-nm b283a0d
Merge branch 'main' into lp_ext_merge
afeldman-nm 29c6544
Merge branch 'main' into lp_ext_merge
afeldman-nm 5a2cb53
refactor
afeldman-nm 35f10ed
online tests passing
afeldman-nm 1457af0
offline tests pass
afeldman-nm d1e94f9
refactor
afeldman-nm 5751bde
refactor
afeldman-nm fa76da6
Merge branch 'main' into lp_ext_merge
afeldman-nm f9cb99f
Merge branch 'njhill-lp_ext' into lp_ext_nick
afeldman-nm 4e624bd
Merge branch 'main' into lp_ext_merge
afeldman-nm a53e3da
linting fixes
afeldman-nm 37a340b
mypy block
afeldman-nm 507b4f8
Merge branch 'main' into lp_ext_merge
afeldman-nm 27de87a
revert input_batch.make_sampling_metadata back to private naming
njhill 7cdd389
fix test isolation
njhill File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
@@ -0,0 +1,147 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""This example demonstrates instantiating vLLM with a custom logits processor
class object.

For a basic example of implementing a custom logits processor, see
the `DummyLogitsProcessor` implementation in `vllm/test_utils.py`.

For testing purposes, a dummy logits processor is employed which, if
`target_token` is passed as a keyword argument to `SamplingParams.extra_args`,
will mask out all tokens except `target_token`.

A batch is constructed with `temperature=0.0` and 50% of requests specifying
`target_token`, and for these requests - and *only* these requests - we
expect the `target_token` to be decoded in each step, yielding an output
similar to that shown below:

Generated Outputs:
------------------------------------------------------------
Prompt: 'Hello, my name is'
Output: " ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' '"
------------------------------------------------------------
Prompt: 'The president of the United States is'
Output: " not a racist. He is a racist.\nHe's a racist because he"
------------------------------------------------------------
Prompt: 'The capital of France is'
Output: ' also also also also also also also also also also also also also
also also also'
------------------------------------------------------------
Prompt: 'The future of AI is'
Output: ' in the hands of the people.\n\nThe future of AI is in the'
------------------------------------------------------------
"""

from typing import Optional

import torch

from vllm import LLM, SamplingParams
from vllm.config import VllmConfig
from vllm.v1.sample.logits_processor import (
    BatchUpdate,
    LogitsProcessor,
    MoveDirectionality,
)


# Hypothetical custom logits processor
class DummyLogitsProcessor(LogitsProcessor):
    """Fake logits processor to support unit testing and examples"""

    def __init__(
        self, vllm_config: VllmConfig, device: torch.device, is_pin_memory: bool
    ):
        # Map batch row index -> target token id
        self.req_info: dict[int, int] = {}

    def is_argmax_invariant(self) -> bool:
        """Masking non-target tokens can change the argmax, so this
        processor is not argmax-invariant and may affect greedy sampling."""
        return False

    def update_state(self, batch_update: Optional[BatchUpdate]):
        if not batch_update:
            return

        # Process added requests.
        for index, params, _, _ in batch_update.added:
            assert params is not None
            if params.extra_args and (
                target_token := params.extra_args.get("target_token")
            ):
                self.req_info[index] = target_token

        if self.req_info:
            # Process removed requests.
            for index in batch_update.removed:
                self.req_info.pop(index, None)

            # Process moved requests, unidirectional move (a->b) and swap
            # (a<->b)
            for adx, bdx, direct in batch_update.moved:
                a_val = self.req_info.pop(adx, None)
                b_val = self.req_info.pop(bdx, None)
                if a_val is not None:
                    self.req_info[bdx] = a_val
                if direct == MoveDirectionality.SWAP and b_val is not None:
                    self.req_info[adx] = b_val

    def apply(self, logits: torch.Tensor) -> torch.Tensor:
        if not self.req_info:
            return logits

        # Save target values before modification
        rows_list = list(self.req_info.keys())
        cols = torch.tensor(
            [self.req_info[i] for i in rows_list],
            dtype=torch.long,
            device=logits.device,
        )
        rows = torch.tensor(rows_list, dtype=torch.long, device=logits.device)
        values_to_keep = logits[rows, cols].clone()

        # Mask all but target tokens
        logits[rows] = float("-inf")
        logits[rows, cols] = values_to_keep

        return logits


# Sample prompts.
prompts = [
    "Hello, my name is",
    "The president of the United States is",
    "The capital of France is",
    "The future of AI is",
]
# Create a mixture of requests which do and don't utilize the dummy logitproc
sampling_params_list = [
    SamplingParams(temperature=0.0, extra_args={"target_token": 128}),
    SamplingParams(temperature=0.0),
    SamplingParams(temperature=0.0, extra_args={"target_token": 67}),
    SamplingParams(temperature=0.0),
]


def main():
    # Create an LLM.
    llm = LLM(
        model="facebook/opt-125m",
        logits_processors=[DummyLogitsProcessor],
    )
    # Generate texts from the prompts.
    # The output is a list of RequestOutput objects
    # that contain the prompt, generated text, and other information.
    outputs = llm.generate(prompts, sampling_params_list)
    # Print the outputs.
    print("\nGenerated Outputs:\n" + "-" * 60)
    for output in outputs:
        prompt = output.prompt
        generated_text = output.outputs[0].text
        print(f"Prompt: {prompt!r}")
        print(f"Output: {generated_text!r}")
        print("-" * 60)


if __name__ == "__main__":
    main()
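The core masking step in `apply` can be checked in isolation. The sketch below is a pure-Python mirror of that logic using nested lists in place of tensors (no torch or vLLM dependency); the function name `mask_all_but_target` and the list-based representation are illustrative stand-ins, not part of the vLLM API.

```python
NEG_INF = float("-inf")


def mask_all_but_target(logits, req_info):
    """Pure-Python mirror of DummyLogitsProcessor.apply: for each batch
    row listed in req_info, keep only the target token's logit and set
    every other entry in that row to -inf. Rows not in req_info pass
    through unchanged."""
    out = [row[:] for row in logits]  # copy so the input is untouched
    for row_idx, target_token in req_info.items():
        keep = out[row_idx][target_token]          # save target logit
        out[row_idx] = [NEG_INF] * len(out[row_idx])  # mask whole row
        out[row_idx][target_token] = keep          # restore target
    return out


logits = [
    [0.1, 0.9, 0.5],  # request 0: target_token=2
    [0.3, 0.2, 0.4],  # request 1: no target, left unchanged
]
masked = mask_all_but_target(logits, {0: 2})
print(masked)
```

With `temperature=0.0`, greedy decoding then always picks the surviving target token for masked rows, which is why the example output above repeats a single token for requests that set `target_token`.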
Empty file.
Review comment: Since we assign `self.req_info[index] = target_token`, should that type hint be `int` instead of `SamplingParams`?
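The bookkeeping that `update_state` performs on that `dict[int, int]` (adds, removals, and moves/swaps of persistent-batch row indices) can also be exercised without vLLM. The sketch below uses minimal stand-ins for `BatchUpdate` and `MoveDirectionality` - `FakeBatchUpdate` is a simplified illustration (e.g. its `added` entries are 2-tuples rather than the real 4-tuples), not the actual vLLM types.

```python
from dataclasses import dataclass, field
from enum import Enum


class MoveDirectionality(Enum):  # stand-in for vLLM's enum
    UNIDIRECTIONAL = 1
    SWAP = 2


@dataclass
class FakeBatchUpdate:  # simplified stand-in for vLLM's BatchUpdate
    added: list = field(default_factory=list)    # (index, target_token)
    removed: list = field(default_factory=list)  # index
    moved: list = field(default_factory=list)    # (a, b, directionality)


def update_state(req_info: dict[int, int], batch_update: FakeBatchUpdate):
    """Mirrors DummyLogitsProcessor.update_state on a dict[int, int]."""
    for index, target_token in batch_update.added:
        req_info[index] = target_token
    for index in batch_update.removed:
        req_info.pop(index, None)
    for adx, bdx, direct in batch_update.moved:
        a_val = req_info.pop(adx, None)
        b_val = req_info.pop(bdx, None)
        if a_val is not None:
            req_info[bdx] = a_val
        if direct is MoveDirectionality.SWAP and b_val is not None:
            req_info[adx] = b_val


req_info: dict[int, int] = {}
update_state(req_info, FakeBatchUpdate(added=[(0, 128), (2, 67)]))
# Swap rows 0 and 1: row 0 had a target, row 1 did not, so after the
# swap only row 1 carries the target token.
update_state(req_info, FakeBatchUpdate(moved=[(0, 1, MoveDirectionality.SWAP)]))
print(req_info)  # {2: 67, 1: 128}
```

Popping both endpoints before reinserting is what keeps a one-directional move (a->b) from leaving a stale entry at `b`, and the `SWAP` branch restores `b`'s value at `a` only when one existed.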