Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion vllm/worker/embedding_model_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,8 @@ def execute_model(
# a placeholder (it has wide hardware support).
kv_caches = [
torch.tensor([], dtype=torch.float32, device=self.device)
] * num_layers
for _ in range(num_layers)
]

execute_model_kwargs = {
"input_ids":
Expand Down
3 changes: 2 additions & 1 deletion vllm/worker/enc_dec_model_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -348,7 +348,8 @@ def profile_run(self) -> None:
# a placeholder (it has wide hardware support).
kv_caches = [
torch.tensor([], dtype=torch.float32, device=self.device)
] * num_layers
for _ in range(num_layers)
]
finished_requests_ids = [seq.request_id for seq in seqs]
model_input = self.prepare_model_input(
seqs, finished_requests_ids=finished_requests_ids)
Expand Down
6 changes: 5 additions & 1 deletion vllm/worker/model_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -1244,9 +1244,13 @@ def profile_run(self) -> None:
# it by reference, rather by specializing on the value ``None``.
# the `dtype` argument does not matter, and we use `float32` as
# a placeholder (it has wide hardware support).
# it is important to create tensors inside the loop, rather than
# multiplying the list, to avoid Dynamo from treating them as
# tensor aliasing.
kv_caches = [
torch.tensor([], dtype=torch.float32, device=self.device)
] * num_layers
for _ in range(num_layers)
]
finished_requests_ids = [seq.request_id for seq in seqs]
model_input = self.prepare_model_input(
seqs, finished_requests_ids=finished_requests_ids)
Expand Down
Loading