
Commit d794d13
fixes for passing format.sh
1 parent 6dd8d26

11 files changed: +48 additions, -43 deletions


vllm/config.py

Lines changed: 4 additions & 3 deletions
@@ -1,7 +1,8 @@
 import enum
 import json
 from dataclasses import dataclass, field, fields
-from typing import TYPE_CHECKING, ClassVar, Dict, List, Optional, Set, Tuple, Union
+from typing import (TYPE_CHECKING, ClassVar, Dict, List, Optional, Set, Tuple,
+                    Union)
 
 import torch
 from transformers import PretrainedConfig
@@ -98,7 +99,7 @@ def __init__(
         max_logprobs: int = 5,
         skip_tokenizer_init: bool = False,
         served_model_name: Optional[Union[str, List[str]]] = None,
-        extra_inputs: Set[str] = set(),
+        extra_inputs: Optional[Set[str]] = None,
     ) -> None:
         self.model = model
         self.tokenizer = tokenizer
@@ -132,7 +133,7 @@ def __init__(
 
         self.extra_inputs: Dict[str, Tuple[Tuple[int],
                                            Optional[torch.dtype]]] = {}
-        if "hidden_states" in extra_inputs:
+        if extra_inputs and "hidden_states" in extra_inputs:
            self.extra_inputs["hidden_states"] = ((
                self.hf_config.hidden_size, ), None)
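The Set[str] = set() default was a real bug, not just a style nit: Python evaluates default arguments once, at definition time, so every ModelConfig constructed without extra_inputs would share one set object. The Optional[...] = None sentinel is the standard fix, and it is why the second hunk grows the "extra_inputs and" guard, since the value may now be None at lookup time. A minimal sketch of the pitfall, with hypothetical function names rather than vLLM APIs:

def add_input(name, extra_inputs=set()):
    # BUG: this default set is created once and shared by every call.
    extra_inputs.add(name)
    return extra_inputs

def add_input_fixed(name, extra_inputs=None):
    if extra_inputs is None:
        extra_inputs = set()  # fresh set per call; None is an immutable sentinel
    extra_inputs.add(name)
    return extra_inputs

print(add_input("hidden_states"))  # {'hidden_states'}
print(add_input("logits"))         # both names: state leaked across calls
print(add_input_fixed("logits"))   # {'logits'} only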

vllm/engine/arg_utils.py

Lines changed: 3 additions & 4 deletions
@@ -488,9 +488,8 @@ def add_cli_args(
             '--extra-inputs-for-draft-model',
             type=nullable_str,
             default=EngineArgs.extra_inputs_for_draft_model,
-            help=
-            'Extra model inputs used by draft model. These should come as outputs from the target model.'
-        )
+            help='Extra model inputs used by draft model.'
+            'These should come as outputs from the target model.')
 
         parser.add_argument(
             '--num-speculative-tokens',
@@ -595,7 +594,7 @@ def create_engine_config(self, ) -> EngineConfig:
 
         try:
             extra_inputs = set(self.extra_inputs_for_draft_model.split(","))
-        except:
+        except Exception:
             extra_inputs = set()
 
         speculative_config = SpeculativeConfig.maybe_create_spec_config(
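The bare except is what pycodestyle rejects (E722): it catches BaseException, so even KeyboardInterrupt and SystemExit would be swallowed during engine construction. except Exception keeps the intended fallback (for example, the AttributeError raised by None.split(",") when the flag is unset) while letting interrupts propagate. One nit survives the reflowed help string above: the two adjacent literals concatenate with no separating space, so the rendered help text reads "draft model.These should come". A small sketch with a hypothetical helper, not the actual vLLM code path:

def parse_extra_inputs(raw):
    # Hypothetical stand-in for the create_engine_config parsing above.
    try:
        return set(raw.split(","))
    except Exception:
        # raw may be None (flag unset): the AttributeError lands here.
        # KeyboardInterrupt and SystemExit deliberately do not.
        return set()

print(parse_extra_inputs("hidden_states,logits"))  # {'hidden_states', 'logits'}
print(parse_extra_inputs(None))                    # set()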

vllm/executor/gpu_executor.py

Lines changed: 2 additions & 4 deletions
@@ -90,8 +90,7 @@ def execute_model(
     ) -> List[Union[SamplerOutput, PoolerOutput]]:
         output = self.driver_worker.execute_model(execute_model_req)
 
-        if not (isinstance(output[0], SamplerOutput)
-                or isinstance(output[0], PoolerOutput)):
+        if not isinstance(output[0], (SamplerOutput, PoolerOutput)):
            output = [sampler_output for sampler_output, _ in output]
 
         return output
@@ -122,8 +121,7 @@ async def execute_model_async(
         output = await make_async(self.driver_worker.execute_model
                                   )(execute_model_req=execute_model_req, )
 
-        if not (isinstance(output[0], SamplerOutput)
-                or isinstance(output[0], PoolerOutput)):
+        if not isinstance(output[0], (SamplerOutput, PoolerOutput)):
            output = [sampler_output for sampler_output, _ in output]
 
         return output
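Both hunks collapse an isinstance or-chain into isinstance's tuple form, which checks the value against every listed type in a single call and fits on one line. The equivalence, on a generic value:

value = 3.5

# Tuple form: True if value is an instance of any type in the tuple.
assert isinstance(value, (int, float))

# The longer or-chain it replaces behaves identically.
assert isinstance(value, int) or isinstance(value, float)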

vllm/sequence.py

Lines changed: 5 additions & 2 deletions
@@ -6,6 +6,7 @@
 from typing import TYPE_CHECKING, Dict, List, Optional, Set, Tuple, Union
 
 import torch
+
 from vllm.block import LogicalTokenBlock
 from vllm.lora.request import LoRARequest
 from vllm.pooling_params import PoolingParams
@@ -96,10 +97,12 @@ def stack(
         data: List[Optional["ExtraTensorData"]],
         dim: int = 0,
     ) -> Optional["ExtraTensorData"]:
-        if len(data) == 0: return None
+        if len(data) == 0:
+            return None
 
         for d in data:
-            if d is None: return None
+            if d is None:
+                return None
 
         assert isinstance(data[0], ExtraTensorData)
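The first hunk inserts the blank line isort expects between third-party and first-party imports. The stack changes fix pycodestyle E701, a statement on the same line as its colon; the one-liners were legal Python but fail format.sh. The same shape as a tiny standalone function:

def first_or_none(items):
    # Was the E701 form: "if not items: return None" on a single line.
    if not items:
        return None
    return items[0]

print(first_or_none([]))      # None
print(first_or_none([1, 2]))  # 1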

vllm/spec_decode/batch_expansion.py

Lines changed: 15 additions & 12 deletions
@@ -84,15 +84,16 @@ def score_proposals(
         assert len(target_sampler_output) == 1, "expected single-step output"
         target_sampler_output, _ = target_sampler_output[0]
 
-        all_tokens, all_probs, spec_logprobs, all_extra_output_data = self._contract_batch(
-            contracted_bs=len(execute_model_req.seq_group_metadata_list),
-            target_sampler_output=target_sampler_output,
-            proposals=proposals,
-            num_scoring_tokens=num_scoring_tokens,
-            non_spec_indices=non_spec_indices,
-            spec_indices=spec_indices,
-            k=execute_model_req.num_lookahead_slots,
-        )
+        (all_tokens, all_probs, spec_logprobs,
+         all_extra_output_data) = self._contract_batch(
+             contracted_bs=len(execute_model_req.seq_group_metadata_list),
+             target_sampler_output=target_sampler_output,
+             proposals=proposals,
+             num_scoring_tokens=num_scoring_tokens,
+             non_spec_indices=non_spec_indices,
+             spec_indices=spec_indices,
+             k=execute_model_req.num_lookahead_slots,
+         )
 
         return SpeculativeScores(
             probs=all_probs,
@@ -217,7 +218,8 @@ def _contract_batch(
         all_probs[non_spec_indices, :1, :] = non_spec_target_probs
         all_logprobs[non_spec_indices, :1, :] = non_spec_target_logprobs
 
-        if all_extra_output_data and non_spec_target_extra_output_data is not None:
+        if all_extra_output_data and \
+                non_spec_target_extra_output_data is not None:
             for k in all_extra_output_data:
                 all_extra_output_data[k][
                     non_spec_indices, :
@@ -382,8 +384,9 @@ def _split_scoring_output(
         if sampler_output.extra_tensor_data is None:
             spec_extra_output_data, no_spec_extra_output_data = (None, None)
         else:
-            spec_extra_output_data, no_spec_extra_output_data = sampler_output.extra_tensor_data.split(
-                split_sizes)
+            spec_extra_output_data, no_spec_extra_output_data = sampler_output\
+                .extra_tensor_data\
+                .split(split_sizes)
 
         # Convert scores to tensors.
         sampler_output.sampled_token_probs = spec_probs
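All three hunks are line-length fixes. The interesting one is the first: parenthesizing the left-hand tuple of a multi-target assignment enables implicit line continuation, so no backslashes are needed. A runnable sketch with made-up names and shapes:

def contract_batch(contracted_bs, k):
    # Stand-in that returns four values, like _contract_batch.
    return list(range(contracted_bs)), [0.5] * contracted_bs, [0.0] * k, {}

(all_tokens, all_probs, spec_logprobs,
 all_extra_output_data) = contract_batch(  # parentheses permit the break
     contracted_bs=4,
     k=2,
 )
print(all_tokens, spec_logprobs)  # [0, 1, 2, 3] [0.0, 0.0]

PEP 8 prefers this parenthesized style over backslashes; the attribute chain in the third hunk could likewise be wrapped inside parentheses instead of ending each line with a backslash, so the backslash form here is presumably just what the formatter settled on.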

vllm/spec_decode/spec_decode_worker.py

Lines changed: 5 additions & 5 deletions
@@ -12,8 +12,8 @@
 from vllm.spec_decode.interfaces import (SpeculativeProposals,
                                          SpeculativeScorer, SpeculativeScores)
 from vllm.spec_decode.metrics import AsyncMetricsCollector
-from vllm.spec_decode.multi_step_worker import MultiStepWorker
 from vllm.spec_decode.multi_head_worker import MultiHeadWorker
+from vllm.spec_decode.multi_step_worker import MultiStepWorker
 from vllm.spec_decode.ngram_worker import NGramWorker
 from vllm.spec_decode.util import (create_sequence_group_output,
                                    get_all_num_logprobs, get_all_seq_ids,
@@ -337,8 +337,8 @@ def _run_no_spec(self, execute_model_req: ExecuteModelRequest,
         not called, meaning that the kv-cache in proposer for requests is not
         updated, so they cannot enable spec decode in the rest decoding.
         """
-        execute_model_req.extra_outputs = self.proposer_worker.extra_inputs.copy(
-        )
+        execute_model_req.extra_outputs = self.proposer_worker.extra_inputs\
+            .copy()
         model_outputs = self.scorer_worker.execute_model(execute_model_req)
         assert len(model_outputs) == 1
 
@@ -392,8 +392,8 @@ def _run_speculative_decoding_step(
         # Generate proposals using draft worker.
         proposals = self.proposer_worker.get_spec_proposals(execute_model_req)
 
-        execute_model_req.extra_outputs = self.proposer_worker.extra_inputs.copy(
-        )
+        execute_model_req.extra_outputs = self.proposer_worker.extra_inputs\
+            .copy()
         proposal_scores, extra_tensor_data = self.scorer.score_proposals(
             execute_model_req,
             proposals,

vllm/spec_decode/util.py

Lines changed: 2 additions & 1 deletion
@@ -193,7 +193,8 @@ def sampler_output_to_torch(
         sampled_extra_output_data[k] = sampled_extra_output_data[
             k].transpose(0, 1)
 
-    return sampled_token_ids, sampled_token_probs, sampled_token_logprobs, sampled_extra_output_data
+    return (sampled_token_ids, sampled_token_probs, sampled_token_logprobs,
+            sampled_extra_output_data)
 
 
 def maybe_mock_device_tensors(sampler_output: SamplerOutput, batch_size: int,

vllm/transformers_utils/config.py

Lines changed: 5 additions & 6 deletions
@@ -1,11 +1,12 @@
+import contextlib
 from typing import Dict, Optional
 
 from transformers import AutoConfig, PretrainedConfig
 
 from vllm.logger import init_logger
 from vllm.transformers_utils.configs import (ChatGLMConfig, DbrxConfig,
-                                             JAISConfig, MPTConfig, RWConfig,
-                                             MedusaConfig)
+                                             JAISConfig, MedusaConfig,
+                                             MPTConfig, RWConfig)
 
 logger = init_logger(__name__)
 
@@ -21,11 +22,9 @@
     "medusa": MedusaConfig,
 }
 
-for name, cls in _CONFIG_REGISTRY.items():
-    try:
+with contextlib.suppress(ValueError):
+    for name, cls in _CONFIG_REGISTRY.items():
         AutoConfig.register(name, cls)
-    except:
-        pass
 
 
 def get_config(model: str,
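contextlib.suppress is the stdlib idiom for try/except/pass, and this rewrite also narrows what gets ignored: the old bare except swallowed everything, while the new code suppresses only the ValueError that AutoConfig.register raises for an already-registered name. Worth noting is the placement, though: with suppress around the whole loop, the first ValueError abandons the remaining registrations, whereas the old per-iteration try/except skipped the failing entry and kept going. A sketch of both placements with a toy register function (hypothetical, not the transformers API):

import contextlib

_REGISTERED = set()

def register(name):
    # Toy stand-in for AutoConfig.register: rejects duplicate names.
    if name in _REGISTERED:
        raise ValueError(f"{name!r} already registered")
    _REGISTERED.add(name)

names = ["chatglm", "chatglm", "medusa"]

# The commit's placement: suppress wraps the loop, so the duplicate
# "chatglm" aborts the loop and "medusa" is never registered.
with contextlib.suppress(ValueError):
    for name in names:
        register(name)
print(_REGISTERED)  # {'chatglm'}

# Per-iteration placement would preserve the old skip-and-continue behavior.
for name in names:
    with contextlib.suppress(ValueError):
        register(name)
print(_REGISTERED)  # {'chatglm', 'medusa'}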

vllm/transformers_utils/configs/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -5,8 +5,8 @@
 # `FalconConfig` class from the official HuggingFace transformers library.
 from vllm.transformers_utils.configs.falcon import RWConfig
 from vllm.transformers_utils.configs.jais import JAISConfig
-from vllm.transformers_utils.configs.mpt import MPTConfig
 from vllm.transformers_utils.configs.medusa import MedusaConfig
+from vllm.transformers_utils.configs.mpt import MPTConfig
 
 __all__ = [
     "ChatGLMConfig",

vllm/worker/embedding_model_runner.py

Lines changed: 2 additions & 1 deletion
@@ -12,7 +12,8 @@
 from vllm.lora.request import LoRARequest
 from vllm.model_executor.pooling_metadata import PoolingMetadata
 from vllm.pooling_params import PoolingParams
-from vllm.sequence import ExtraTensorData, PoolerOutput, SequenceData, SequenceGroupMetadata
+from vllm.sequence import (ExtraTensorData, PoolerOutput, SequenceData,
+                           SequenceGroupMetadata)
 from vllm.worker.model_runner import ModelRunner
 
 logger = init_logger(__name__)
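This file and the two import reorderings above cover the remaining lint categories in the commit: isort wants import lines and imported names alphabetized (hence medusa moving before mpt), and over-long from-import lists get wrapped in parentheses rather than running past the line limit. A stdlib example mirroring the vllm/config.py hunk:

# Too long for the line-length check:
# from typing import TYPE_CHECKING, ClassVar, Dict, List, Optional, Set, Tuple, Union

# Parenthesized continuation, the style used throughout this commit:
from typing import (TYPE_CHECKING, ClassVar, Dict, List, Optional, Set, Tuple,
                    Union)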
