
Commit 2fc33eb

Track the number of tokens seen to metrics (#27274)
* Add tokens seen
* Address comments, add to TrainingArgs
* Update log
* Apply suggestions from code review
* Use self.args
* Fix docstring

Co-authored-by: amyeroberts <[email protected]>
1 parent 303c1d6 commit 2fc33eb

File tree: 3 files changed (+29, -0 lines)


src/transformers/trainer.py

Lines changed: 13 additions & 0 deletions
@@ -1838,6 +1838,17 @@ def _inner_training_loop(
             step = -1
             for step, inputs in enumerate(epoch_iterator):
                 total_batched_samples += 1
+
+                if self.args.include_num_input_tokens_seen:
+                    main_input_name = getattr(self.model, "main_input_name", "input_ids")
+                    if main_input_name not in inputs:
+                        logger.warning(
+                            "Tried to track the number of tokens seen, however the current model is "
+                            "not configured properly to know what item is the input. To fix this, add "
+                            "a `main_input_name` attribute to the model class you are using."
+                        )
+                    else:
+                        self.state.num_input_tokens_seen += self.accelerator.gather(inputs[main_input_name]).numel()
                 if rng_to_sync:
                     self._load_rng_state(resume_from_checkpoint)
                     rng_to_sync = False
@@ -2640,6 +2651,8 @@ def log(self, logs: Dict[str, float]) -> None:
         """
         if self.state.epoch is not None:
             logs["epoch"] = round(self.state.epoch, 2)
+        if self.args.include_num_input_tokens_seen:
+            logs["num_input_tokens_seen"] = self.state.num_input_tokens_seen

         output = {**logs, **{"step": self.state.global_step}}
         self.state.log_history.append(output)
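
For intuition, the counting logic added above boils down to `Tensor.numel()` on the batch's main input: every element is counted, so padding tokens are included, and in distributed runs `self.accelerator.gather` first concatenates the per-process batches so each rank records the global total. A minimal single-process sketch (plain PyTorch, no Accelerate, illustrative values only):

    import torch

    # Two toy batches; numel() counts every element, padding included.
    batches = [
        {"input_ids": torch.tensor([[101, 7592, 102, 0], [101, 2088, 999, 102]])},  # 2 x 4 = 8
        {"input_ids": torch.tensor([[101, 102]])},                                  # 1 x 2 = 2
    ]

    num_input_tokens_seen = 0
    main_input_name = "input_ids"  # same default the Trainer falls back to
    for inputs in batches:
        num_input_tokens_seen += inputs[main_input_name].numel()

    print(num_input_tokens_seen)  # 10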

src/transformers/trainer_callback.py

Lines changed: 3 additions & 0 deletions
@@ -59,6 +59,8 @@ class TrainerState:
             Run an evaluation every X steps.
         save_steps (`int`, *optional*, defaults to 500):
             Save checkpoint every X updates steps.
+        num_input_tokens_seen (`int`, *optional*, defaults to 0):
+            The number of tokens seen during training (number of input tokens, not the number of prediction tokens).
         total_flos (`float`, *optional*, defaults to 0):
             The total number of floating operations done by the model since the beginning of training (stored as floats
             to avoid overflow).
@@ -87,6 +89,7 @@ class TrainerState:
     eval_steps: int = 500
     save_steps: int = 500
     num_train_epochs: int = 0
+    num_input_tokens_seen: int = 0
     total_flos: float = 0
     log_history: List[Dict[str, float]] = None
     best_metric: Optional[float] = None
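
Once the field exists on `TrainerState`, any callback can read it. A hedged sketch of one possible consumer, a hypothetical `TokenBudgetCallback` that stops training after a fixed token budget (the `on_log` hook and `control.should_training_stop` flag are existing transformers APIs; the callback itself is not part of this commit):

    from transformers import TrainerCallback

    class TokenBudgetCallback(TrainerCallback):
        """Hypothetical: stop training once a budget of input tokens is exhausted."""

        def __init__(self, max_tokens: int):
            self.max_tokens = max_tokens

        def on_log(self, args, state, control, logs=None, **kwargs):
            # Requires TrainingArguments(include_num_input_tokens_seen=True);
            # otherwise state.num_input_tokens_seen stays at its default of 0.
            if state.num_input_tokens_seen >= self.max_tokens:
                control.should_training_stop = True
            return control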

src/transformers/training_args.py

Lines changed: 13 additions & 0 deletions
@@ -637,6 +637,12 @@ class TrainingArguments:
             This will iterate over the entire training dataloader once beforehand,

             and will slow down the entire process.
+
+        include_num_input_tokens_seen (`bool`, *optional*):
+            Whether or not to track the number of input tokens seen throughout training.
+
+            May be slower in distributed training as gather operations must be called.
+
         neftune_noise_alpha (`Optional[float]`):
             If not `None`, this will activate NEFTune noise embeddings. This can drastically improve model performance
             for instruction fine-tuning. Check out the [original paper](https://arxiv.org/abs/2310.05914) and the
@@ -1258,6 +1264,13 @@ class TrainingArguments:
         metadata={"help": "If set to `True`, the speed metrics will include `tgs` (tokens per second per device)."},
     )

+    include_num_input_tokens_seen: Optional[bool] = field(
+        default=False,
+        metadata={
+            "help": "If set to `True`, will track the number of input tokens seen throughout training. (May be slower in distributed training)"
+        },
+    )
+
     neftune_noise_alpha: float = field(
         default=None,
         metadata={
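
Putting the three files together, enabling the feature is a single flag on `TrainingArguments`; the running total then appears as `num_input_tokens_seen` in each log entry and on `trainer.state`. A hedged usage sketch, assuming `model` and `train_dataset` are defined elsewhere:

    from transformers import Trainer, TrainingArguments

    args = TrainingArguments(
        output_dir="out",
        include_num_input_tokens_seen=True,  # adds "num_input_tokens_seen" to every log entry
    )
    trainer = Trainer(model=model, args=args, train_dataset=train_dataset)  # assumed defined
    trainer.train()
    print(trainer.state.num_input_tokens_seen)  # cumulative input tokens, padding included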
