@@ -219,6 +219,11 @@ def __init__(
         self.arange_np = np.arange(self.max_num_tokens, dtype=np.int32)
         self.num_reqs_paddings = _get_req_paddings(
             min_req_size=MIN_NUM_SEQS, max_req_size=self.max_num_reqs)
+
+        max_token_idx = self.max_num_reqs * self.max_model_len - 1
+        # If max token idx exceeds int32 max, use int64 to avoid overflow.
+        self.token_indices_dtype = np.int32 \
+            if max_token_idx <= np.iinfo(np.int32).max else np.int64
 
     def _update_num_xla_graphs(self, case_str):
         check_comp = self.check_recompilation and not self.enforce_eager
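The dtype choice matters because the flattened index space spans max_num_reqs * max_model_len entries. A minimal standalone sketch of the selection logic, using hypothetical sizes (the values below are illustrative, not vLLM defaults):

```python
import numpy as np

max_num_reqs = 1024          # hypothetical batch capacity
max_model_len = 4_194_304    # hypothetical 4M-token context window

# Largest flattened index into a (max_num_reqs, max_model_len) buffer.
max_token_idx = max_num_reqs * max_model_len - 1

# int32 holds values up to 2**31 - 1 = 2_147_483_647; anything larger
# must be indexed with int64.
dtype = np.int32 if max_token_idx <= np.iinfo(np.int32).max else np.int64
print(max_token_idx, dtype)  # 4294967295 <class 'numpy.int64'>
```

Computing this once in `__init__` keeps the hot path in `_prepare_inputs` free of per-step dtype checks.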
@@ -457,8 +462,10 @@ def _prepare_inputs(self, scheduler_output: "SchedulerOutput"):
         # E.g., [0, 1, 0, 1, 2, 3, 4, 0, 1, 2]
         # -> [0, 1, M, M + 1, M + 2, M + 3, M + 4, 2 * M, 2 * M + 1, 2 * M + 2]
         # where M is the max_model_len.
+        # For long context, may need to cast to int64 to avoid overflow.
         token_indices = (positions_np +
-                         req_indices * self.input_batch.token_ids_cpu.shape[1])
+                         req_indices.astype(self.token_indices_dtype) *
+                         self.input_batch.token_ids_cpu.shape[1])
 
         # NOTE(woosuk): We use torch.index_select instead of np.take here
         # because torch.index_select is much faster than np.take for large
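To make the overflow concrete, here is a small self-contained sketch of the flattened-index computation. The array values and M are made up to trigger the wraparound; only the casting pattern mirrors the diff:

```python
import numpy as np

M = 2_147_483_648 // 4       # hypothetical max_model_len (token_ids_cpu.shape[1])
positions_np = np.array([0, 1, 0, 1, 2], dtype=np.int32)
req_indices = np.array([0, 0, 4, 4, 4], dtype=np.int32)

# int32 arithmetic wraps silently: request 4's rows land at 4 * M, which
# exceeds 2**31 - 1 and comes out negative.
bad = positions_np + req_indices * np.int32(M)

# Casting req_indices up front keeps the products exact in int64.
good = positions_np + req_indices.astype(np.int64) * M

print(bad[-1], good[-1])  # -2147483646 vs. 2147483650
```

Casting only `req_indices` is enough: NumPy promotes the whole expression to int64 once one operand is int64, so `positions_np` can stay int32.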