 import os
 import re
 from pathlib import Path
-from typing import Any, Callable, List, Optional, Union, cast
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union, cast

 import torch  # pyright: ignore[reportMissingImports]
+from accelerate import Accelerator  # pyright: ignore[reportMissingImports]
+from accelerate.utils.memory import should_reduce_batch_size  # pyright: ignore[reportMissingImports]
 from datasets.arrow_dataset import Dataset
 from sacremoses import MosesPunctNormalizer
 from torch import Tensor  # pyright: ignore[reportMissingImports]
+from torch.nn import Module  # pyright: ignore[reportMissingImports]
+from torch.optim.lr_scheduler import LambdaLR  # pyright: ignore[reportMissingImports]
+from torch.optim.optimizer import Optimizer  # pyright: ignore[reportMissingImports]
 from torch.utils.checkpoint import checkpoint  # pyright: ignore[reportMissingImports]  # noqa: F401
 from transformers import (
     AutoConfig,
     AutoModelForSeq2SeqLM,
     AutoTokenizer,
     DataCollatorForSeq2Seq,
+    EvalPrediction,
     M2M100ForConditionalGeneration,
     M2M100Tokenizer,
+    MBart50Tokenizer,
     MBart50TokenizerFast,
     MBartTokenizer,
     MBartTokenizerFast,
     NllbTokenizer,
     NllbTokenizerFast,
     PreTrainedModel,
     PreTrainedTokenizer,
+    PreTrainedTokenizerBase,
     PreTrainedTokenizerFast,
     Seq2SeqTrainer,
     Seq2SeqTrainingArguments,
     TrainerCallback,
     set_seed,
 )
-from transformers.models.mbart50 import MBart50Tokenizer
 from transformers.trainer_callback import TrainerControl, TrainerState
 from transformers.trainer_utils import get_last_checkpoint
 from transformers.training_args import TrainingArguments
@@ -315,7 +322,7 @@ def preprocess_function(examples):
             pad_to_multiple_of=8 if self._training_args.fp16 else None,
         )

-        self._trainer = Seq2SeqTrainer(
+        self._trainer = AutoGradientAccumulationStepsSeq2SeqTrainer(
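+            # Defined below: a Seq2SeqTrainer subclass that recovers from CUDA out-of-memory
+            # errors by halving the batch size and doubling the gradient accumulation steps.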
             model=model,
             args=self._training_args,
             train_dataset=cast(Any, train_dataset),
@@ -372,10 +379,12 @@ def __init__(
         max_steps: Optional[int],
         progress: Optional[Callable[[ProgressStatus], None]],
         check_canceled: Optional[Callable[[], None]],
+        update_frequency: Optional[int] = None,
     ) -> None:
-        self._max_steps = max_steps
+        self._max_steps = max_steps if max_steps is not None else 0
         self._progress = progress
         self._check_canceled = check_canceled
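+        # By default, report progress roughly once per 1% of max_steps (and at least every step).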
+        self._update_frequency = update_frequency if update_frequency is not None else max((self._max_steps // 100), 1)

     def on_train_begin(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs) -> None:
         if self._check_canceled is not None:
@@ -387,6 +396,9 @@ def on_train_begin(self, args: TrainingArguments, state: TrainerState, control:
             )

     def on_step_end(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs) -> None:
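+        # Throttle cancellation checks and progress updates to every _update_frequency steps.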
+        if (state.global_step % self._update_frequency) != 0:
+            return
+
         if self._check_canceled is not None:
             self._check_canceled()
 
@@ -398,6 +410,73 @@ def on_step_end(self, args: TrainingArguments, state: TrainerState, control: Tra
             )


+class AutoGradientAccumulationStepsSeq2SeqTrainer(Seq2SeqTrainer):
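+    """Seq2SeqTrainer that automatically recovers from CUDA out-of-memory errors.
+
+    When a batch does not fit in memory, the training loop is retried with the per-device
+    batch size halved and the gradient accumulation steps doubled, so the effective batch
+    size stays the same.
+    """
+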
+    def __init__(
+        self,
+        model: Union[PreTrainedModel, Module],
+        args: Seq2SeqTrainingArguments,
+        data_collator: Any,
+        train_dataset: Optional[Dataset] = None,
+        eval_dataset: Optional[Union[Dataset, Dict[str, Dataset]]] = None,
+        tokenizer: Optional[PreTrainedTokenizerBase] = None,
+        model_init: Optional[Callable[[], PreTrainedModel]] = None,
+        compute_metrics: Optional[Callable[[EvalPrediction], Dict]] = None,
+        callbacks: Optional[List[TrainerCallback]] = None,
+        optimizers: Tuple[Optional[Optimizer], Optional[LambdaLR]] = (None, None),
+        preprocess_logits_for_metrics: Optional[Callable[[Tensor, Tensor], Tensor]] = None,
+    ):
+        super().__init__(
+            model,
+            args,
+            data_collator,
+            train_dataset,  # type: ignore
+            eval_dataset,  # type: ignore
+            tokenizer,
+            model_init,
+            compute_metrics,
+            callbacks,
+            optimizers,  # type: ignore
+            preprocess_logits_for_metrics,
+        )
+
+    def _inner_training_loop(
+        self, batch_size=None, args=None, resume_from_checkpoint=None, trial=None, ignore_keys_for_eval=None
+    ):
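+        # Retry the parent training loop with progressively smaller batch sizes until a
+        # batch fits in memory (see find_executable_batch_size below).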
+        inner_training_loop = find_executable_batch_size(super()._inner_training_loop, batch_size, self.accelerator)
+        return inner_training_loop(
+            args=args,
+            resume_from_checkpoint=resume_from_checkpoint,
+            trial=trial,
+            ignore_keys_for_eval=ignore_keys_for_eval,
+        )
+
+
+def find_executable_batch_size(function: Callable, starting_batch_size, accelerator: Accelerator):
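+    """Return a wrapper around `function` that retries it after out-of-memory errors,
+    halving the batch size and doubling the gradient accumulation steps on each retry
+    so that the effective batch size stays the same."""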
+    batch_size = starting_batch_size
+
+    def decorator(*args, **kwargs):
+        nonlocal batch_size
+        gc.collect()
+        torch.cuda.empty_cache()
+
+        while True:
+            if batch_size == 0:
+                raise RuntimeError("No executable batch size found, reached zero.")
+            try:
+                return function(batch_size, *args, **kwargs)
+            except Exception as e:
+                if should_reduce_batch_size(e):
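+                    # Free cached memory, then retry with half the batch size and twice the
+                    # gradient accumulation steps (same effective batch size).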
+                    gc.collect()
+                    torch.cuda.empty_cache()
+                    batch_size //= 2
+                    accelerator.gradient_accumulation_steps = accelerator.gradient_accumulation_steps * 2
+                    kwargs["args"].gradient_accumulation_steps = accelerator.gradient_accumulation_steps
+                else:
+                    raise
+
+    return decorator
+
+
 def add_lang_code_to_tokenizer(tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast], lang_code: str):
     if isinstance(tokenizer, M2M100Tokenizer):
         lang_token = "__" + lang_code + "__"