diff --git a/docs/source/sft_trainer.mdx b/docs/source/sft_trainer.mdx
index b8ed60c6bc1..498a03732ba 100644
--- a/docs/source/sft_trainer.mdx
+++ b/docs/source/sft_trainer.mdx
@@ -327,13 +327,75 @@ trainer = SFTTrainer(
     model,
     train_dataset=dataset,
     dataset_text_field="text",
-    torch_dtype=torch.bfloat16,
     peft_config=peft_config,
 )
 
 trainer.train()
 ```
 
+### Using Flash attention
+
+You can easily make the model use Flash attention for more memory-efficient training. Install the latest `optimum` library and simply add `use_flash_attn=True` when initializing `SFTTrainer`.
+The implementation is based on the `BetterTransformer` API of `optimum`, which you can read more about [here](https://huggingface.co/docs/optimum/bettertransformer/overview). That API converts supported models so that the attention operation is dispatched to [`torch.nn.functional.scaled_dot_product_attention`](https://pytorch.org/docs/master/generated/torch.nn.functional.scaled_dot_product_attention.html).
+The [`SFTTrainer`] force-dispatches this operation to the Flash attention backend (a minimal sketch of that dispatch is shown at the end of this section). Note that you can combine this feature with PEFT, as well as with quantized models (bitsandbytes 4-bit and 8-bit):
+
+```python
+...
+
+peft_config = LoraConfig(
+    r=16,
+    lora_alpha=32,
+    lora_dropout=0.05,
+    bias="none",
+    task_type="CAUSAL_LM",
+)
+
+model = AutoModelForCausalLM.from_pretrained(
+    "EleutherAI/gpt-neo-125m",
+    load_in_8bit=True,
+    device_map="auto",
+)
+
+trainer = SFTTrainer(
+    model,
+    train_dataset=dataset,
+    dataset_text_field="text",
+    peft_config=peft_config,
+    use_flash_attn=True,
+    packing=True,
+)
+
+trainer.train()
+```
+
+This feature currently requires `torch` nightlies; we ran our experiments with version `2.1.0.dev20230802+cu118`. To install the nightlies, follow the instructions on the [official PyTorch website](https://pytorch.org/get-started/locally/). Below is the command to install them for CUDA 11.8:
+
+```bash
+pip3 install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu118
+```
+
+Make sure that your model is loaded in half-precision (if it is not quantized) for this feature to work correctly. Also make sure that you train in `packing` mode with the default data collator (just set `data_collator` to `None`), as the current Flash attention kernels for SDPA do not support attention masks. Note that Flash attention lets you train faster, and also with a longer context length.
+
+Below is a table that summarizes our experiments with and without Flash attention:
+
+```md
+| use_flash_attn | model_name        | max_seq_len | batch_size | time per training step |
+|----------------|-------------------|-------------|------------|------------------------|
+| x              | facebook/opt-350m | 2048        | 8          | ~59.1s                 |
+|                | facebook/opt-350m | 2048        | 8          | **OOM**                |
+| x              | facebook/opt-350m | 2048        | 4          | ~30.3s                 |
+|                | facebook/opt-350m | 2048        | 4          | ~148.9s                |
+```
+
+As you can see, the Flash attention integration of `SFTTrainer` lets you train faster and with longer sequences.
+The command we used to benchmark is the following (on an NVIDIA T4 GPU on GCP):
+
+```bash
+python sft_trainer.py --packing --seq_length 2048 --batch_size 4 --use_peft --load_in_4bit --bnb_4bit_compute_dtype float16
+```
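+
+For reference, the force-dispatch that [`SFTTrainer`] performs at training time is roughly equivalent to the following minimal sketch (here `model` and `batch` are placeholders for a model converted with `to_bettertransformer()` and a tokenized batch already on the GPU; the actual logic lives in `SFTTrainer.train`):
+
+```python
+import torch
+
+# Only allow the Flash attention kernel for scaled_dot_product_attention
+# (math and memory-efficient fallbacks disabled) and run the forward pass in autocast,
+# mirroring what `SFTTrainer.train` does internally.
+with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=False, enable_mem_efficient=False):
+    with torch.cuda.amp.autocast():
+        outputs = model(**batch)  # placeholder forward pass; the trainer runs its full training loop here
+```
+
+Disabling the math and memory-efficient backends means SDPA raises an error instead of silently falling back to a slower kernel whenever Flash attention cannot be used for a given input.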
+
 ## Best practices
 
 Pay attention to the following best practices when training a model with that trainer:
@@ -342,6 +404,7 @@ Pay attention to the following best practices when training a model with that trainer:
 - For training adapters in 8bit, you might need to tweak the arguments of the `prepare_model_for_int8_training` method from PEFT, hence we advise users to use `prepare_in_int8_kwargs` field, or create the `PeftModel` outside the [`SFTTrainer`] and pass it.
 - For a more memory-efficient training using adapters, you can load the base model in 8bit, for that simply add `load_in_8bit` argument when creating the [`SFTTrainer`], or create a base model in 8bit outside the trainer and pass it.
 - If you create a model outside the trainer, make sure to not pass to the trainer any additional keyword arguments that are relative to `from_pretrained()` method.
+- For more memory-efficient training (faster, and with a longer context length), you can benefit from Flash attention by simply passing `use_flash_attn=True` to the [`SFTTrainer`]. This is also supported for quantized models.
 
 ## SFTTrainer
diff --git a/examples/scripts/sft_trainer.py b/examples/scripts/sft_trainer.py
index 79a59b60056..37b276747c5 100644
--- a/examples/scripts/sft_trainer.py
+++ b/examples/scripts/sft_trainer.py
@@ -48,6 +48,7 @@ class ScriptArguments:
     )
     load_in_8bit: Optional[bool] = field(default=False, metadata={"help": "load the model in 8 bits precision"})
     load_in_4bit: Optional[bool] = field(default=False, metadata={"help": "load the model in 4 bits precision"})
+    bnb_4bit_compute_dtype: Optional[str] = field(default="float32", metadata={"help": "the compute dtype for 4 bits"})
     use_peft: Optional[bool] = field(default=False, metadata={"help": "Wether to use PEFT or not to train adapters"})
     trust_remote_code: Optional[bool] = field(default=True, metadata={"help": "Enable `trust_remote_code`"})
     output_dir: Optional[str] = field(default="output", metadata={"help": "the output directory"})
@@ -57,6 +58,8 @@ class ScriptArguments:
     use_auth_token: Optional[bool] = field(default=True, metadata={"help": "Use HF auth token to access the model"})
     num_train_epochs: Optional[int] = field(default=3, metadata={"help": "the number of training epochs"})
    max_steps: Optional[int] = field(default=-1, metadata={"help": "the number of training steps"})
+    use_flash_attn: Optional[bool] = field(default=False, metadata={"help": "Use Flash Attention"})
+    packing: Optional[bool] = field(default=False, metadata={"help": "Use packing"})
     save_steps: Optional[int] = field(
         default=100, metadata={"help": "Number of updates steps before two checkpoint saves"}
     )
@@ -68,12 +71,17 @@ class ScriptArguments:
 parser = HfArgumentParser(ScriptArguments)
 script_args = parser.parse_args_into_dataclasses()[0]
 
+if script_args.use_flash_attn and not script_args.packing:
+    raise ValueError("You can't use Flash Attention without packing")
+
 # Step 1: Load the model
 if script_args.load_in_8bit and script_args.load_in_4bit:
     raise ValueError("You can't load the model in 8 bits and 4 bits at the same time")
 elif script_args.load_in_8bit or script_args.load_in_4bit:
     quantization_config = BitsAndBytesConfig(
-        load_in_8bit=script_args.load_in_8bit, load_in_4bit=script_args.load_in_4bit
+        load_in_8bit=script_args.load_in_8bit,
+        load_in_4bit=script_args.load_in_4bit,
+        bnb_4bit_compute_dtype=getattr(torch, script_args.bnb_4bit_compute_dtype),
     )
     # This means: fit the entire model on the GPU:0
     device_map = {"": 0}
@@ -130,6 +138,8 @@ class ScriptArguments:
     train_dataset=dataset,
     dataset_text_field=script_args.dataset_text_field,
     peft_config=peft_config,
+    packing=script_args.packing,
+    use_flash_attn=script_args.use_flash_attn,
 )
 
 trainer.train()
diff --git a/trl/trainer/sft_trainer.py b/trl/trainer/sft_trainer.py
index cd21a8df835..e96888a0bb2 100644
--- a/trl/trainer/sft_trainer.py
+++ b/trl/trainer/sft_trainer.py
@@ -11,6 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import functools
 import warnings
 from typing import Callable, Dict, List, Optional, Tuple, Union
 
@@ -29,6 +30,7 @@
 )
 from transformers.trainer_callback import TrainerCallback
 from transformers.trainer_utils import EvalPrediction
+from transformers.utils import ContextManagers
 
 from ..import_utils import is_peft_available
 from .utils import ConstantLengthDataset, DataCollatorForCompletionOnlyLM, PeftSavingCallback
@@ -119,6 +121,7 @@ def __init__(
         chars_per_token: Optional[float] = 3.6,
         dataset_num_proc: Optional[int] = None,
         dataset_batch_size: int = 1000,
+        use_flash_attn: Optional[bool] = False,
     ):
         if isinstance(model, str):
             warnings.warn(
@@ -131,6 +134,21 @@ def __init__(
                 "You passed a `DataCollatorForCompletionOnlyLM` to the SFTTrainer. This is not compatible with the `packing` argument."
             )
 
+        if (not packing or data_collator is not None) and use_flash_attn:
+            raise ValueError(
+                "You passed `use_flash_attn=True` to the SFTTrainer, but you also passed `packing=False` or a custom data collator. This is not supported."
+                " You need to pass `packing=True` to the SFTTrainer if you want to use the `use_flash_attn` argument because flash attention"
+                " training does not support passing attention masks."
+            )
+
+        if use_flash_attn and not torch.cuda.is_available():
+            raise ValueError(
+                "You passed `use_flash_attn=True` to the SFTTrainer, but you don't have a GPU available. This is not supported."
+            )
+
+        if isinstance(model, PreTrainedModel) and use_flash_attn:
+            model = model.to_bettertransformer()
+
         if is_peft_available() and peft_config is not None:
             if not isinstance(peft_config, PeftConfig):
                 raise ValueError(
@@ -148,6 +166,11 @@ def __init__(
                 model = prepare_model_for_int8_training(model)
 
             model = get_peft_model(model, peft_config)
+        elif use_flash_attn and isinstance(model, PeftModel):
+            raise ValueError(
+                "You passed a `PeftModel` to the SFTTrainer, but you also passed `use_flash_attn=True`. This is not supported."
+                " You need to first create a transformers model and pass a peft_config to the SFTTrainer rather than directly passing a PeftModel."
+            )
 
         if callbacks is None:
             callbacks = [PeftSavingCallback]
@@ -169,6 +192,8 @@ def __init__(
         self.dataset_num_proc = dataset_num_proc
         self.dataset_batch_size = dataset_batch_size
 
+        self.use_flash_attn = use_flash_attn
+
         if not packing:
             if dataset_text_field is None and formatting_func is None:
                 raise ValueError(
@@ -313,3 +338,16 @@ def tokenize(element):
         )
 
         return tokenized_dataset
+
+    @functools.wraps(Trainer.train)
+    def train(self, *args, **kwargs):
+        training_contexts = []
+
+        if self.use_flash_attn:
+            training_contexts.append(
+                torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=False, enable_mem_efficient=False)
+            )
+            training_contexts.append(torch.cuda.amp.autocast())
+
+        with ContextManagers(training_contexts):
+            return super().train(*args, **kwargs)