7 changes: 7 additions & 0 deletions vllm/envs.py
@@ -96,6 +96,7 @@
VLLM_RAY_BUNDLE_INDICES: str = ""
VLLM_CUDART_SO_PATH: Optional[str] = None
VLLM_USE_HPU_CONTIGUOUS_CACHE_FETCH: bool = True
VLLM_HPU_USE_DELAYED_SAMPLING: bool = False
VLLM_DP_RANK: int = 0
VLLM_DP_RANK_LOCAL: int = -1
VLLM_DP_SIZE: int = 1
@@ -639,6 +640,12 @@ def maybe_convert_int(value: Optional[str]) -> Optional[int]:
lambda: os.environ.get("VLLM_CONTIGUOUS_PA", "true").lower() in
("1", "true"),

# Use delayed sampling for HPU to reduce host CPU overhead
# between each step.
"VLLM_HPU_USE_DELAYED_SAMPLING":
lambda: os.environ.get("VLLM_DELAYED_SAMPLING", "false").lower() in
("1", "true"),

# Rank of the process in the data parallel setting
"VLLM_DP_RANK":
lambda: int(os.getenv("VLLM_DP_RANK", "0")),
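For context, a minimal usage sketch (not part of this diff) of how the new flag would be turned on, assuming it is read from an environment variable with the same name as the registered key:

# Hedged sketch: enable delayed sampling before launching vLLM on HPU.
import os
os.environ["VLLM_HPU_USE_DELAYED_SAMPLING"] = "true"

import vllm.envs as envs

# vllm.envs resolves registered keys lazily through the lambdas above,
# so this prints True once the variable is set.
print(envs.VLLM_HPU_USE_DELAYED_SAMPLING)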
140 changes: 133 additions & 7 deletions vllm/worker/hpu_model_runner.py
@@ -74,6 +74,8 @@

LORA_WARMUP_RANK = 8

DUMMY_TOKEN_ID = -1


class Singleton(type):
_instances: Dict[type, object] = {}
@@ -668,6 +670,9 @@ def __init__(

# For multi-step scheduling
self.cached_step_outputs: List[torch.Tensor] = []
# For delayed sampling
self.cached_step_inputs: List[
ModelInputForHPUWithSamplingMetadata] = []

def _set_gc_threshold(self) -> None:
# Read https://docs.python.org/3/library/gc.html#gc.set_threshold
@@ -771,6 +776,12 @@ def load_model(self) -> None:
msg = f"Loading model weights took in total {m.get_summary_string()}"
logger.info(msg)

def _maybe_wrap_in_hpu_graph(self, *args, **kwargs):
return htorch.hpu.wrap_in_hpu_graph(
HpuModelAdapter(*args, **kwargs), disable_tensor_cache=True
) if htorch.utils.internal.is_lazy() else HpuModelAdapter(
*args, **kwargs)

def get_model(self) -> nn.Module:
return self.model

@@ -2020,6 +2031,21 @@ def create_lora_mask(self, input_tokens: torch.Tensor, lora_ids: List[int],

return lora_mask, lora_logits_mask

def _get_seq_ids(self, model_input):
return ([
sg.seq_ids[0] for sg in model_input.sampling_metadata.seq_groups
])

def _pad_to_max_num_seqs(self, tensor, value):
padding_needed = self.max_num_seqs - tensor.size(0)
if padding_needed:
padding = torch.full((padding_needed, *tensor.shape[1:]),
value,
device=tensor.device,
dtype=tensor.dtype)
tensor = torch.cat([tensor, padding])
return tensor
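
As an aside, a self-contained sketch of what the padding helper above does (the batch contents and a max_num_seqs of 4 are made up for illustration):

import torch

DUMMY_TOKEN_ID = -1
max_num_seqs = 4

sampled = torch.tensor([[101], [102]])  # 2 live sequences this step
padding_needed = max_num_seqs - sampled.size(0)
padding = torch.full((padding_needed, *sampled.shape[1:]), DUMMY_TOKEN_ID,
                     dtype=sampled.dtype)
padded = torch.cat([sampled, padding])
print(padded.shape)  # torch.Size([4, 1]); rows 2 and 3 hold DUMMY_TOKEN_ID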

@torch.inference_mode()
def execute_model(
self,
@@ -2030,6 +2056,37 @@
warmup_mode=False,
seqs=None,
) -> Optional[Union[List[SamplerOutput], IntermediateTensors]]:
VLLM_DELAYED_SAMPLING = envs.VLLM_HPU_USE_DELAYED_SAMPLING
use_delayed_sampling = VLLM_DELAYED_SAMPLING and not warmup_mode
assert not (use_delayed_sampling and num_steps != 1), \
'Delayed sampling is not compatible with MSS!'
assert model_input.input_tokens is not None
if use_delayed_sampling and not model_input.is_prompt and \
self.is_driver_worker:
num_cached = len(self.cached_step_outputs)
assert num_cached > 0
cur_seq_ids = self._get_seq_ids(model_input)
cur_seq_id_pos = {
sid: idx
for idx, sid in enumerate(cur_seq_ids) if sid >= 0
}
htorch.core.mark_step()
for i in range(num_cached):
prev_seq_ids = self._get_seq_ids(self.cached_step_inputs[i])
target_indices = [
cur_seq_id_pos.get(psi, -1) for psi in prev_seq_ids
]
padding = self.cached_step_outputs[i].size(0) - len(
target_indices)
target_indices.extend([-1] * padding)
target_indices = torch.tensor(
target_indices,
device=model_input.input_tokens.device,
dtype=model_input.input_tokens.dtype)
model_input.input_tokens.index_copy_(
0, target_indices, self.cached_step_outputs[i])
htorch.core.mark_step()
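
A simplified illustration of the scatter performed by the loop above (toy tensors and sequence ids; every previous sequence is assumed to still be alive, so no -1 padding indices appear): tokens cached from the previous step are written into whichever rows the same sequence ids occupy in the current step's input_tokens.

import torch

prev_seq_ids = [7, 9]  # row order in the cached step
cur_seq_ids = [9, 7]   # row order in the current step
cur_seq_id_pos = {sid: idx for idx, sid in enumerate(cur_seq_ids)}

cached_output = torch.tensor([[101], [102]])  # sampled for seq 7 and seq 9
input_tokens = torch.zeros((2, 1), dtype=torch.int64)

target_indices = torch.tensor(
    [cur_seq_id_pos[psi] for psi in prev_seq_ids])  # -> [1, 0]
input_tokens.index_copy_(0, target_indices, cached_output)
print(input_tokens.squeeze(-1).tolist())  # [102, 101]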

if not model_input.is_first_multi_step:
if not model_input.is_last_step:
# not first or last multi-step
@@ -2045,7 +2102,21 @@
assert model_input.lora_mapping is not None
self.set_active_loras(model_input.lora_requests,
model_input.lora_mapping)
input_tokens = model_input.input_tokens
# Workers with rank != 0 have is_prompt == None
if use_delayed_sampling and not model_input.is_prompt and \
model_input.input_tokens.size(1) == 1:
if self.is_driver_worker:
model_kwargs_broadcast_data = {
"input_tokens": model_input.input_tokens
}
broadcast_tensor_dict(model_kwargs_broadcast_data, src=0)
input_tokens = model_input.input_tokens

else:
model_kwargs_broadcast_data = broadcast_tensor_dict(src=0)
input_tokens = model_kwargs_broadcast_data["input_tokens"]
else:
input_tokens = model_input.input_tokens
input_positions = model_input.input_positions
attn_metadata = model_input.attn_metadata
sampling_metadata = model_input.sampling_metadata
@@ -2092,7 +2163,7 @@
f"graphs{'T' if use_graphs else 'F'}")
else:
model_event_name = 'model_executable'
if num_steps > 1:
if num_steps > 1 or use_delayed_sampling:
# in case of multi-step scheduling
# we only want to pythonize in the last step
sampling_metadata.skip_sampler_cpu_output = True
@@ -2152,9 +2223,9 @@ def try_revert_dummy_output_tokens():
if not self.is_driver_worker:
continue

if model_input.async_callback is not None:
model_input.async_callback()
# Sample the next token.
if use_delayed_sampling:
fake_output = self._delayed_sampler_outputs(model_input)

with self.profiler.record_event(
'internal', ('sample_'
f'{"prompt" if is_prompt else "decode"}_'
@@ -2166,9 +2237,16 @@ def try_revert_dummy_output_tokens():
)
if num_steps > 1:
output = output.sampled_token_ids
self.cached_step_outputs.append(
output.detach().clone())
self.cached_step_outputs.append(output)
if use_delayed_sampling and self.is_driver_worker:
self._patch_prev_output()
output = self._pad_to_max_num_seqs(
output.sampled_token_ids, DUMMY_TOKEN_ID)
self.cached_step_outputs.append(output)
self.cached_step_inputs.append(model_input)
htorch.core.mark_step()
if model_input.async_callback is not None:
model_input.async_callback()
if i < num_steps - 1:
if i == 0:
if model_input.async_callback is not None:
@@ -2241,11 +2319,30 @@ def try_revert_dummy_output_tokens():
is_prompt=is_prompt)
self.profiler.record_counter(self.event_start, counters)
if num_steps == 1:
if self.return_hidden_states:
# we only need to pass hidden states of most recent token
assert model_input.sampling_metadata is not None
if model_input.is_prompt:
output.prefill_hidden_states = hidden_states
output.hidden_states = hidden_states
if use_delayed_sampling:
if self.is_driver_worker:
return [fake_output]
else:
return []

return [output] if self.is_driver_worker else []
else:
return []
return output if type(output) is list else [output]

def _delayed_sampler_outputs(self, model_input):
next_token_ids = [[DUMMY_TOKEN_ID]] * len(
model_input.sampling_metadata.seq_groups)
sampler_output = self._make_decode_output(
next_token_ids, model_input.sampling_metadata.seq_groups)
return sampler_output

def _decode_sampler_outputs(self, model_input):
use_async_out_proc = model_input.async_callback is not None
sampler_outputs = []
@@ -2312,3 +2409,32 @@ def shutdown_inc(self):

def __del__(self):
self.shutdown_inc()

def _patch_prev_output(self):
assert len(self.cached_step_inputs) == len(self.cached_step_outputs), \
f'''Inputs and outputs are out of sync!
{len(self.cached_step_inputs)} vs {len(self.cached_step_outputs)}'''
if len(self.cached_step_inputs) == 0:
return
model_input = self.cached_step_inputs.pop(0)
delayed_output = self.cached_step_outputs.pop(0).cpu().squeeze(
-1).tolist()
ctx = model_input.async_callback.keywords["ctx"] # type: ignore
# If there's no output to patch with, return early; this is usually
# the case when we're starting a new request after all previous
# requests have completed.
if len(ctx.output_queue) == 0:
return
assert len(
ctx.output_queue) == 1, 'There should be exactly 1 output waiting!'
output_data = ctx.output_queue[0]
assert len(output_data.outputs) == 1
for fake_out, real_out in zip(output_data.outputs[0], delayed_output):
fake_out.samples[0].output_token = real_out
for sg, real_out in zip(output_data.seq_group_metadata_list,
delayed_output):
assert len(sg.seq_data) == 1
seq_data = list(sg.seq_data.values())[0]
# This is a hack. Assigning output_token_ids triggers
# a cache recomputation and we only need to update the last token
seq_data.output_token_ids_array[-1] = real_out
seq_data._cached_all_token_ids[-1] = real_out
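
Finally, a conceptual sketch of what _patch_prev_output achieves (plain Python lists stand in for vLLM's sequence data; token values are invented): the DUMMY_TOKEN_ID placeholder emitted for step N is overwritten with the real token once step N+1 has it on the host.

DUMMY_TOKEN_ID = -1

# step N: a placeholder token was appended so the engine could keep scheduling
output_token_ids = [11, 22, DUMMY_TOKEN_ID]

# step N+1: the real sample from step N has finally been copied to the CPU
real_token = 33
output_token_ids[-1] = real_token  # only the last token needs patching
assert output_token_ids == [11, 22, 33]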