Commit b16b541

fixing errors after merging main

1 parent 4028273 commit b16b541

8 files changed, +41 -44 lines changed

vllm/executor/gpu_executor.py

Lines changed: 0 additions & 8 deletions
@@ -89,10 +89,6 @@ def execute_model(
         self, execute_model_req: ExecuteModelRequest
     ) -> List[Union[SamplerOutput, PoolerOutput]]:
         output = self.driver_worker.execute_model(execute_model_req)
-
-        if not isinstance(output[0], (SamplerOutput, PoolerOutput)):
-            output = [sampler_output for sampler_output, _ in output]
-
         return output

     def add_lora(self, lora_request: LoRARequest) -> bool:
@@ -120,8 +116,4 @@ async def execute_model_async(
     ) -> List[Union[SamplerOutput, PoolerOutput]]:
         output = await make_async(self.driver_worker.execute_model
                                   )(execute_model_req=execute_model_req, )
-
-        if not isinstance(output[0], (SamplerOutput, PoolerOutput)):
-            output = [sampler_output for sampler_output, _ in output]
-
         return output
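
The deleted branches unpacked (sampler_output, extra_tensor_data) tuples that the workers no longer emit. A minimal runnable sketch of the interface this deletion assumes (stand-in types, not the real vLLM classes): extras now travel on the SamplerOutput itself, so output[0] is already a SamplerOutput.

from typing import List, NamedTuple, Optional

class SamplerOutput(NamedTuple):  # stand-in for vllm.sequence.SamplerOutput
    token_ids: List[int]
    extra_tensor_data: Optional[dict] = None  # extras ride on the output now

def driver_execute_model() -> List[SamplerOutput]:
    # Before this commit the worker returned [(SamplerOutput, extras)] pairs;
    # now it returns SamplerOutput objects directly.
    return [SamplerOutput(token_ids=[42],
                          extra_tensor_data={"hidden_states": None})]

output = driver_execute_model()
assert isinstance(output[0], SamplerOutput)  # no tuple unwrapping required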

vllm/spec_decode/batch_expansion.py

Lines changed: 4 additions & 3 deletions
@@ -43,7 +43,7 @@ def score_proposals(
         self,
         execute_model_req: ExecuteModelRequest,
         proposals: SpeculativeProposals,
-    ) -> Tuple[SpeculativeScores, Optional[ExtraTensorData]]:
+    ) -> SpeculativeScores:
         """Score the proposed tokens via the scorer model.

         This converts each input sequence to a set of k+1 target sequences. The
@@ -82,7 +82,7 @@ def score_proposals(
             execute_model_req=execute_model_req.clone(
                 seq_group_metadata_list=target_seq_group_metadata_list, ))
         assert len(target_sampler_output) == 1, "expected single-step output"
-        target_sampler_output, _ = target_sampler_output[0]
+        target_sampler_output = target_sampler_output[0]

         (all_tokens, all_probs, spec_logprobs,
          all_extra_output_data) = self._contract_batch(
@@ -99,7 +99,8 @@ def score_proposals(
             probs=all_probs,
             token_ids=all_tokens,
             logprobs=spec_logprobs,
-        ), all_extra_output_data
+            extra_tensor_data=all_extra_output_data,
+        )

     def _expand_batch(
         self,

vllm/spec_decode/interfaces.py

Lines changed: 7 additions & 3 deletions
@@ -1,6 +1,6 @@
 from abc import ABC, abstractmethod
 from dataclasses import dataclass
-from typing import Optional, Tuple
+from typing import Optional

 import torch

@@ -47,10 +47,14 @@ class SpeculativeScores:
     # tokens and also non-speculative normal decoding.
     token_ids: torch.Tensor

+    # Extra data output by the model
+    extra_tensor_data: Optional[ExtraTensorData]
+
     def __repr__(self):
         return (f"SpeculativeScores("
                 f"probs={self.probs.shape}, "
-                f"token_ids={self.token_ids.shape})")
+                f"token_ids={self.token_ids.shape}, "
+                f"extra_tensor_data={self.extra_tensor_data})")


 class SpeculativeProposer(ABC):
@@ -70,5 +74,5 @@ def score_proposals(
         self,
         execute_model_req: ExecuteModelRequest,
         proposals: SpeculativeProposals,
-    ) -> Tuple[SpeculativeScores, Optional[ExtraTensorData]]:
+    ) -> SpeculativeScores:
         raise NotImplementedError
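
A hedged sketch of the updated dataclass in one piece: ExtraTensorData is assumed here to be a dict-like mapping of names to tensors, and the field list is abridged to what this diff and the batch_expansion.py call site show.

from dataclasses import dataclass
from typing import Optional
import torch

ExtraTensorData = dict  # stand-in; the real type lives in vllm.sequence

@dataclass
class SpeculativeScores:  # abridged sketch, not the full class
    probs: torch.Tensor
    logprobs: torch.Tensor
    token_ids: torch.Tensor
    # Extra data output by the model, now carried on the scores object
    # rather than returned as a second tuple element.
    extra_tensor_data: Optional[ExtraTensorData]

    def __repr__(self):
        return (f"SpeculativeScores("
                f"probs={self.probs.shape}, "
                f"token_ids={self.token_ids.shape}, "
                f"extra_tensor_data={self.extra_tensor_data})")

scores = SpeculativeScores(probs=torch.ones(2, 4),
                           logprobs=torch.zeros(2, 4),
                           token_ids=torch.zeros(2, dtype=torch.long),
                           extra_tensor_data=None)
print(scores)  # SpeculativeScores(probs=torch.Size([2, 4]), ...)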

vllm/spec_decode/multi_head_worker.py

Lines changed: 1 addition & 1 deletion
@@ -48,7 +48,7 @@ def sampler_output(
         execute_model_req: ExecuteModelRequest,
         sample_len: int,
     ) -> Tuple[List[SamplerOutput], bool]:
-        model_outputs, _ = super().execute_model(
+        model_outputs = super().execute_model(
             execute_model_req=execute_model_req)[0]
         return model_outputs, False

vllm/spec_decode/spec_decode_worker.py

Lines changed: 9 additions & 11 deletions
@@ -278,6 +278,9 @@ def execute_model(
         self._maybe_disable_speculative_tokens(
             disable_all_speculation, execute_model_req.seq_group_metadata_list)

+        execute_model_req.extra_outputs = self.proposer_worker.extra_inputs\
+            .copy()
+
         # If no spec tokens, call the proposer and scorer workers normally.
         # Used for prefill.
         if num_lookahead_slots == 0 or len(
@@ -327,15 +330,12 @@ def _run_no_spec(self, execute_model_req: ExecuteModelRequest,
         not called, meaning that the kv-cache in proposer for requests is not
         updated, so they cannot enable spec decode in the rest decoding.
         """
-        execute_model_req.extra_outputs = self.proposer_worker.extra_inputs\
-            .copy()
-        model_outputs = self.scorer_worker.execute_model(execute_model_req)
-        assert len(model_outputs) == 1
-
-        sampler_output, prefill_extra_tensor_data = model_outputs[0]
+        sampler_output = self.scorer_worker.execute_model(execute_model_req)
+        assert len(sampler_output) == 1
+        sampler_output = sampler_output[0]

         execute_model_req.extra_outputs.clear()
-        execute_model_req.extra_inputs = prefill_extra_tensor_data
+        execute_model_req.extra_inputs = sampler_output.extra_tensor_data

         if not skip_proposer:
             self.proposer_worker.execute_model(execute_model_req)
@@ -389,9 +389,7 @@ def _run_speculative_decoding_step(
         # Generate proposals using draft worker.
         proposals = self.proposer_worker.get_spec_proposals(execute_model_req)

-        execute_model_req.extra_outputs = self.proposer_worker.extra_inputs\
-            .copy()
-        proposal_scores, extra_tensor_data = self.scorer.score_proposals(
+        proposal_scores = self.scorer.score_proposals(
             execute_model_req,
             proposals,
         )
@@ -405,7 +403,7 @@ def _run_speculative_decoding_step(
             accepted_token_ids,
             target_logprobs=target_logprobs,
             k=execute_model_req.num_lookahead_slots,
-            extra_tensor_data=extra_tensor_data)
+            extra_tensor_data=proposal_scores.extra_tensor_data)

     @nvtx_range("spec_decode_worker._verify_tokens")
     def _verify_tokens(
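
The net effect of these hunks is a hoist: _run_no_spec and _run_speculative_decoding_step previously each set execute_model_req.extra_outputs themselves; the fix moves that assignment into execute_model, once, ahead of the branch. A toy sketch of the pattern (placeholder classes, not the vLLM ones):

class Proposer:  # placeholder for the proposer worker
    extra_inputs = {"hidden_states"}

class Req:  # placeholder for ExecuteModelRequest
    extra_outputs = None

def _run_no_spec(req):
    return "no_spec", req.extra_outputs

def _run_spec_step(req):
    return "spec", req.extra_outputs

def execute_model(req, proposer, num_lookahead_slots):
    # Hoisted shared setup: set exactly once, visible on both paths.
    req.extra_outputs = proposer.extra_inputs.copy()
    if num_lookahead_slots == 0:
        return _run_no_spec(req)
    return _run_spec_step(req)

assert execute_model(Req(), Proposer(), 0) == ("no_spec", {"hidden_states"})
assert execute_model(Req(), Proposer(), 4) == ("spec", {"hidden_states"})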

vllm/transformers_utils/config.py

Lines changed: 2 additions & 2 deletions
@@ -22,8 +22,8 @@
     "medusa": MedusaConfig,
 }

-with contextlib.suppress(ValueError):
-    for name, cls in _CONFIG_REGISTRY.items():
+for name, cls in _CONFIG_REGISTRY.items():
+    with contextlib.suppress(ValueError):
         AutoConfig.register(name, cls)

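
This swap changes behavior, not just style: with contextlib.suppress(ValueError) wrapped around the whole loop, the first failing registration (e.g. a name AutoConfig already knows) aborts the loop and silently skips every later entry; moving the suppress inside the loop isolates each registration. A standalone sketch of the difference, with a fake registrar standing in for AutoConfig.register and illustrative entry names:

import contextlib

def register_outside(registry, register):
    # Old pattern: the first ValueError ends the whole loop.
    with contextlib.suppress(ValueError):
        for name, cls in registry.items():
            register(name, cls)

def register_inside(registry, register):
    # New pattern: each failure is suppressed independently.
    for name, cls in registry.items():
        with contextlib.suppress(ValueError):
            register(name, cls)

registered = []

def fake_register(name, cls):
    # Mimics AutoConfig.register raising for an already-taken name.
    if name == "medusa":
        raise ValueError(f"'{name}' is already used")
    registered.append(name)

registry = {"medusa": object, "eagle": object}
register_outside(registry, fake_register)
assert registered == []         # "eagle" was never even attempted
register_inside(registry, fake_register)
assert registered == ["eagle"]  # later entries still get registered
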
vllm/worker/model_runner.py

Lines changed: 13 additions & 12 deletions
@@ -722,7 +722,7 @@ def execute_model(
         kv_caches: List[torch.Tensor],
         extra_inputs: ExtraTensorData = None,
         extra_outputs: Optional[Set[str]] = None,
-    ) -> Tuple[Optional[SamplerOutput], ExtraTensorData]:
+    ) -> Optional[SamplerOutput]:
         (input_tokens, input_positions, attn_metadata, sampling_metadata,
          lora_requests, lora_mapping, multi_modal_input, prepared_extra_inputs
          ) = self.prepare_input_tensors(seq_group_metadata_list)
@@ -771,19 +771,19 @@ def execute_model(
         logits = self.model.compute_logits(hidden_states, sampling_metadata)

         # Only perform sampling in the driver worker.
-        if self.is_driver_worker:
-            # Sample the next token.
-            output = self.model.sample(
-                logits=logits,
-                sampling_metadata=sampling_metadata,
-            )
+        if not self.is_driver_worker:
+            return None
+
+        # Sample the next token.
+        output = self.model.sample(
+            logits=logits,
+            sampling_metadata=sampling_metadata,
+        )

+        if extra_outputs:
             sampled_extra_tensor_data = extra_tensor_data.index_select(
                 0, sampling_metadata.selected_token_indices)
-        else:
-            output = None

-        if extra_outputs:
             if prefill_meta is not None:
                 for k in extra_tensor_data:
                     extra_tensor_data[k] = extra_tensor_data[k].roll(shifts=1,
@@ -794,12 +794,13 @@
             if output is not None:
                 _move_extra_tensor_data_to_seq_outputs(
                     output, sampled_extra_tensor_data, sampling_metadata)
+
+                output.extra_tensor_data = extra_tensor_data
         else:
-            extra_tensor_data.clear()
             if output is not None:
                 output.extra_tensor_data = sampled_extra_tensor_data

-        return output, extra_tensor_data
+        return output

     @torch.inference_mode()
     def profile_run(self) -> None:
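
Two things change here: workers that are not the driver now exit with an early return None instead of threading output = None through the extra-tensor bookkeeping, and the function returns a single Optional[SamplerOutput] with any extras attached to it rather than an (output, extras) tuple. A control-flow sketch under those assumptions (placeholder values, no real model):

from typing import Optional, Set

def execute_model(is_driver_worker: bool,
                  extra_outputs: Optional[Set[str]] = None) -> Optional[dict]:
    # ...forward pass and logits computation run on every worker...
    if not is_driver_worker:
        return None  # guard clause replaces the old if/else around sampling

    output = {"token_ids": [42]}  # stands in for self.model.sample(...)
    if extra_outputs:
        # Extras are attached to the output instead of being returned
        # alongside it as a second tuple element.
        output["extra_tensor_data"] = {name: None for name in extra_outputs}
    return output

assert execute_model(False, {"hidden_states"}) is None
assert execute_model(True) == {"token_ids": [42]}
assert "extra_tensor_data" in execute_model(True, {"hidden_states"})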

vllm/worker/worker.py

Lines changed: 5 additions & 4 deletions
@@ -15,8 +15,7 @@
     set_custom_all_reduce)
 from vllm.lora.request import LoRARequest
 from vllm.model_executor import set_random_seed
-from vllm.sequence import (ExecuteModelRequest, ExtraTensorData, PoolerOutput,
-                           SamplerOutput)
+from vllm.sequence import ExecuteModelRequest, PoolerOutput, SamplerOutput
 from vllm.worker.cache_engine import CacheEngine
 from vllm.worker.embedding_model_runner import EmbeddingModelRunner
 from vllm.worker.model_runner import ModelRunner
@@ -273,8 +272,10 @@ def execute_model(
         output = self.model_runner.execute_model(
             seq_group_metadata_list,
             self.gpu_cache,
-            extra_inputs=execute_model_req.extra_inputs,
-            extra_outputs=execute_model_req.extra_outputs)
+            extra_inputs=None
+            if execute_model_req is None else execute_model_req.extra_inputs,
+            extra_outputs=None
+            if execute_model_req is None else execute_model_req.extra_outputs)

         # Worker only supports single-step execution. Wrap the output in a list
         # to conform to interface.
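
The new conditionals guard against execute_model_req being None, which is presumably how non-driver ranks invoke execute_model here; the unconditional attribute access was the post-merge error. A self-contained sketch of the pattern (the dataclass is a stand-in for ExecuteModelRequest):

from dataclasses import dataclass
from typing import Optional, Set

@dataclass
class Req:  # stand-in for ExecuteModelRequest
    extra_inputs: Optional[dict] = None
    extra_outputs: Optional[Set[str]] = None

def resolve(execute_model_req: Optional[Req]):
    extra_inputs = (None if execute_model_req is None
                    else execute_model_req.extra_inputs)
    extra_outputs = (None if execute_model_req is None
                     else execute_model_req.extra_outputs)
    return extra_inputs, extra_outputs

assert resolve(None) == (None, None)  # no AttributeError on a None request
assert resolve(Req({"h": 1}, {"hidden_states"})) == ({"h": 1}, {"hidden_states"})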
