Commit 44866a3

Guard NWOR staging from unexpected graph capture
1 parent d0ac344 · commit 44866a3

2 files changed (+10, -4 lines)


vllm/v1/kv_cache/deferred.py (6 additions, 0 deletions)

@@ -251,6 +251,12 @@ def stage_layer(
         if not self._window_active:
             return False

+        if _in_restricted_context():
+            logger.warning_once(
+                "NWOR: Graph capture detected during staging; skipping staged writes."
+            )
+            return False
+
         if not (_tensor_has_storage(key) and _tensor_has_storage(value)):
             raise ShouldFallback("kv_slice_without_storage")
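For reference, a minimal sketch of what a guard like _in_restricted_context() typically checks: whether a CUDA graph capture or a torch.compile trace is in progress when stage_layer() runs. Only the helper's name appears in the diff; the specific checks below are assumptions, not the code from this commit.

import torch

def _in_restricted_context() -> bool:
    # Assumed sketch, not the vLLM implementation.
    # True while the current CUDA stream is being captured into a CUDA graph;
    # KV writes staged here would be baked into the graph and replayed later.
    if torch.cuda.is_available() and torch.cuda.is_current_stream_capturing():
        return True
    # True while torch.compile / Dynamo is tracing the surrounding Python code.
    if torch.compiler.is_compiling():
        return True
    return False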

vllm/v1/worker/gpu_model_runner.py (4 additions, 4 deletions)

@@ -2731,10 +2731,10 @@ def _compute_nwor_acceptance(
         draft_ids = draft_ids.to(dtype=sampled_token_ids.dtype, copy=False)

         if return_mask:
-            mask_work = torch.zeros(total_tokens, dtype=torch.bool, device=work_device)
-        else:
-            mask_work = None
-        accepted_counts = []
+            mask_work = torch.zeros(total_tokens, dtype=torch.bool, device=work_device)
+        else:
+            mask_work = None
+        accepted_counts = []

         if sampled_token_ids.ndim == 0:
             zero_counts = [0 for _ in num_draft_tokens]
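For context, a heavily hedged sketch of the kind of acceptance computation that surrounds the lines above. It assumes draft_ids is a flat tensor of proposed tokens, sampled_token_ids holds one row of target tokens per request, and num_draft_tokens gives each request's draft length; none of these shapes are confirmed by the diff, and the helper below is illustrative only, not vLLM's _compute_nwor_acceptance.

import torch

def compute_acceptance(draft_ids, sampled_token_ids, num_draft_tokens, return_mask=False):
    # Illustrative sketch only, under the assumptions stated above.
    total_tokens = sum(num_draft_tokens)
    work_device = draft_ids.device
    if return_mask:
        mask_work = torch.zeros(total_tokens, dtype=torch.bool, device=work_device)
    else:
        mask_work = None
    accepted_counts = []

    offset = 0
    for req_idx, n in enumerate(num_draft_tokens):
        draft = draft_ids[offset:offset + n]
        target = sampled_token_ids[req_idx, :n]
        matches = draft == target
        # Accept only the longest prefix of draft tokens that matches the target.
        mismatches = (~matches).nonzero()
        accepted = int(mismatches[0].item()) if mismatches.numel() else n
        accepted_counts.append(accepted)
        if mask_work is not None:
            mask_work[offset:offset + accepted] = True
        offset += n
    return accepted_counts, mask_work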
