Skip to content

Commit 0094d0e

Browse files
committed
Use explicit named arg for clamp min
1 parent 1b68dd3 commit 0094d0e

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

src/transformers/generation/logits_process.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -493,7 +493,7 @@ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> to
493493

494494
# Remove tokens with cumulative mass above the threshold
495495
last_ind = (cumulative_probs < self.mass).sum(dim=1) - 1
496-
last_ind.clamp_(0)
496+
last_ind.clamp_(min=0)
497497
sorted_indices_to_remove = sorted_scores > sorted_scores.gather(1, last_ind.view(-1, 1))
498498
sorted_indices_to_remove[..., : self.min_tokens_to_keep] = 0
499499
indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)

0 commit comments

Comments
 (0)