diff --git a/kernel-coder/nano_r1_script.py b/kernel-coder/nano_r1_script.py index 698b59a..febadc2 100644 --- a/kernel-coder/nano_r1_script.py +++ b/kernel-coder/nano_r1_script.py @@ -229,7 +229,7 @@ def extract_first_code(output_string: str, code_language_types: list[str]) -> st return None -def format_reward_func(completion: str, EOS_TOKEN: str) -> float: +def format_reward_func(completion: str, EOS_TOKEN: str) -> tuple[float, str | None]: """ Format: ...anything @@ -238,7 +238,8 @@ def format_reward_func(completion: str, EOS_TOKEN: str) -> float: EOS_TOKEN (str): End of sequence token Returns: - float: Reward score + Tuple containing the reward score and the first extracted code block if + available. """ code = None diff --git a/kernel-coder/utils.py b/kernel-coder/utils.py index 26e71e6..782dee3 100644 --- a/kernel-coder/utils.py +++ b/kernel-coder/utils.py @@ -187,7 +187,8 @@ def compute_token_log_probs( logits = outputs.logits / temperature # Shape: [batch_size, seq_len, vocab_size] shift_logits = logits[..., :-1, :] # Shape: [batch_size, seq_len-1, vocab_size] - shift_labels = inputs["labels"][..., 1:] # Shape: [batch_size, seq_len-1] + # Clone to avoid in-place modification of `inputs["labels"]` + shift_labels = inputs["labels"][..., 1:].clone() # Shape: [batch_size, seq_len-1] shift_labels_mask = inputs["labels_mask"][..., 1:] # Shape: [batch_size, seq_len-1] # Create mask for valid labels