diff --git a/kernel-coder/utils.py b/kernel-coder/utils.py
index 26e71e6..274e172 100644
--- a/kernel-coder/utils.py
+++ b/kernel-coder/utils.py
@@ -120,7 +120,7 @@ def log_softmax_and_gather(logits: torch.Tensor, index: torch.Tensor) -> torch.T
     torch compiled version of the common `log_softmax -> gather` operation.
 
-    The compiled version of this opration avoids the (significant) memory overhead of
+    The compiled version of this operation avoids the (significant) memory overhead of
     allocating a new (batch_size, seq_len, vocab_size) tensor to store the logprobs.
 
     Args:
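
For context, a minimal sketch of the function this hunk documents might look like the following. Only the docstring text appears in the diff; the `@torch.compile` arguments and the function body are assumptions here, not part of the change.

```python
import torch


# Sketch of the documented helper: torch.compile can fuse the log_softmax and
# gather so the intermediate (batch_size, seq_len, vocab_size) logprob tensor
# need not be materialized. Decorator arguments and body are assumptions.
@torch.compile(dynamic=True)
def log_softmax_and_gather(logits: torch.Tensor, index: torch.Tensor) -> torch.Tensor:
    # logits: (batch_size, seq_len, vocab_size); index: (batch_size, seq_len) token ids
    logprobs = torch.nn.functional.log_softmax(logits, dim=-1)
    # Gather the logprob of each selected token, returning shape (batch_size, seq_len).
    return torch.gather(logprobs, dim=-1, index=index.unsqueeze(-1)).squeeze(-1)
```

In use, `index` would typically hold the target token ids, so the call returns per-token log-probabilities with the same (batch_size, seq_len) shape as `index`.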