File tree Expand file tree Collapse file tree 2 files changed +10
-4
lines changed Expand file tree Collapse file tree 2 files changed +10
-4
lines changed Original file line number Diff line number Diff line change @@ -251,6 +251,12 @@ def stage_layer(
251251 if not self ._window_active :
252252 return False
253253
254+ if _in_restricted_context ():
255+ logger .warning_once (
256+ "NWOR: Graph capture detected during staging; skipping staged writes."
257+ )
258+ return False
259+
254260 if not (_tensor_has_storage (key ) and _tensor_has_storage (value )):
255261 raise ShouldFallback ("kv_slice_without_storage" )
256262
Original file line number Diff line number Diff line change @@ -2731,10 +2731,10 @@ def _compute_nwor_acceptance(
27312731 draft_ids = draft_ids .to (dtype = sampled_token_ids .dtype , copy = False )
27322732
27332733 if return_mask :
2734- mask_work = torch .zeros (total_tokens , dtype = torch .bool , device = work_device )
2735- else :
2736- mask_work = None
2737- accepted_counts = []
2734+ mask_work = torch .zeros (total_tokens , dtype = torch .bool , device = work_device )
2735+ else :
2736+ mask_work = None
2737+ accepted_counts = []
27382738
27392739 if sampled_token_ids .ndim == 0 :
27402740 zero_counts = [0 for _ in num_draft_tokens ]
You can’t perform that action at this time.
0 commit comments