65 changes: 64 additions & 1 deletion docs/source/sft_trainer.mdx
@@ -327,13 +327,75 @@ trainer = SFTTrainer(
    model,
    train_dataset=dataset,
    dataset_text_field="text",
    torch_dtype=torch.bfloat16,
    peft_config=peft_config,
)

trainer.train()
```

### Using Flash attention

You can easily make the model use Flash Attention for more memory-efficient training. To do so, install the latest `optimum` library and simply pass `use_flash_attn=True` when initializing the [`SFTTrainer`].
The implementation relies on the `BetterTransformer` API of optimum, which you can read more about [here](https://huggingface.co/docs/optimum/bettertransformer/overview). That API converts supported models so that the attention operation is dispatched through [`torch.nn.functional.scaled_dot_product_attention`](https://pytorch.org/docs/master/generated/torch.nn.functional.scaled_dot_product_attention.html).
The [`SFTTrainer`] will then force-dispatch this operation to the Flash Attention backend. Note that you can combine this feature with PEFT, as well as with quantized models (bitsandbytes 4-bit and 8-bit):

```python
...

peft_config = LoraConfig(
    r=16,
    lora_alpha=32,
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM",
)

model = AutoModelForCausalLM.from_pretrained(
    "EleutherAI/gpt-neo-125m",
    load_in_8bit=True,
    device_map="auto",
)

trainer = SFTTrainer(
    model,
    train_dataset=dataset,
    dataset_text_field="text",
    peft_config=peft_config,
    use_flash_attn=True,
    packing=True,
)

trainer.train()
```
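
Under the hood, this is roughly equivalent to converting the model with BetterTransformer and restricting the SDPA dispatcher to the Flash Attention kernel around the forward and backward passes. The snippet below is a minimal sketch of that mechanism (not the exact [`SFTTrainer`] internals); it assumes a CUDA GPU, a half-precision model, and `optimum` installed:

```python
import torch
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    "EleutherAI/gpt-neo-125m",
    torch_dtype=torch.float16,
).to("cuda")

# Convert supported modules so that attention goes through
# torch.nn.functional.scaled_dot_product_attention (optimum's BetterTransformer API)
model = model.to_bettertransformer()

input_ids = torch.randint(0, model.config.vocab_size, (1, 32), device="cuda")

# Restrict the SDPA dispatcher to the Flash Attention kernel only
with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=False, enable_mem_efficient=False):
    output = model(input_ids)
```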

This feature currently requires `torch` nightlies; we ran our experiments with version `'2.1.0.dev20230802+cu118'`. To install the nightlies, follow the instructions on the [official PyTorch website](https://pytorch.org/get-started/locally/). Below is the command to install them for CUDA 11.8:

```bash
pip3 install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu118
```

Make sure that your model is loaded in half precision (in case it is not quantized) for this feature to work correctly. Also make sure that you train in `packing` mode with the default data collator (simply leave `data_collator=None`), as the current Flash Attention kernels for SDPA do not support attention masks. Note that Flash Attention not only lets you train faster, but also with a longer context length.
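
For a non-quantized model, loading it in half precision yourself before passing it to the trainer could look like the following — a minimal sketch that reuses the `dataset` variable from the examples above:

```python
import torch
from transformers import AutoModelForCausalLM
from trl import SFTTrainer

# Load the base model in half precision, as required by the Flash Attention kernels
model = AutoModelForCausalLM.from_pretrained(
    "facebook/opt-350m",
    torch_dtype=torch.float16,
    device_map="auto",
)

trainer = SFTTrainer(
    model,
    train_dataset=dataset,
    dataset_text_field="text",
    packing=True,         # required: the Flash Attention kernels do not support attention masks
    data_collator=None,   # keep the default data collator
    use_flash_attn=True,
)

trainer.train()
```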

Below is a table that summarizes our experiments with and without Flash Attention:

```md
| use_flash_attn | model_name | max_seq_len | batch_size | time per training step |
|----------------|-------------------|-------------|------------|------------------------|
| x | facebook/opt-350m | 2048 | 8 | ~59.1s |
| | facebook/opt-350m | 2048 | 8 | **OOM** |
| x | facebook/opt-350m | 2048 | 4 | ~30.3s |
| | facebook/opt-350m | 2048 | 4 | ~148.9s |
```

As you can see, the Flash Attention integration of `SFTTrainer` lets you train faster and with longer sequences.
The command we used for the benchmark (on an NVIDIA T4 GPU on GCP) is the following:

```bash
python sft_trainer.py --packing --seq_length 2048 --batch_size 4 --use_peft --load_in_4bit --bnb_4bit_compute_dtype float16
```


## Best practices

Pay attention to the following best practices when training a model with this trainer:
@@ -342,6 +404,7 @@
- For training adapters in 8bit, you might need to tweak the arguments of the `prepare_model_for_int8_training` method from PEFT; hence we advise users to use the `prepare_in_int8_kwargs` field, or to create the `PeftModel` outside the [`SFTTrainer`] and pass it.
- For more memory-efficient training using adapters, you can load the base model in 8bit; to do so, simply add the `load_in_8bit` argument when creating the [`SFTTrainer`], or create a base model in 8bit outside the trainer and pass it.
- If you create a model outside the trainer, make sure not to pass to the trainer any additional keyword arguments that are related to the `from_pretrained()` method.
- For more memory-efficient training (faster, and with a longer context length), you can benefit from Flash Attention by simply passing `use_flash_attn=True` to the [`SFTTrainer`]. This is also supported for quantized models, as shown in the sketch below.
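
The quantized case could look like the following — a minimal sketch combining 4-bit loading with a half-precision compute dtype (mirroring the example script above; `dataset` and `peft_config` are assumed to be defined as in the earlier examples):

```python
import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig
from trl import SFTTrainer

# 4-bit quantization with a half-precision compute dtype for the Flash Attention kernels
quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.float16,
)

model = AutoModelForCausalLM.from_pretrained(
    "facebook/opt-350m",
    quantization_config=quantization_config,
    device_map="auto",
)

trainer = SFTTrainer(
    model,
    train_dataset=dataset,
    dataset_text_field="text",
    peft_config=peft_config,
    packing=True,
    use_flash_attn=True,
)

trainer.train()
```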

## SFTTrainer

12 changes: 11 additions & 1 deletion examples/scripts/sft_trainer.py
@@ -48,6 +48,7 @@ class ScriptArguments:
)
load_in_8bit: Optional[bool] = field(default=False, metadata={"help": "load the model in 8 bits precision"})
load_in_4bit: Optional[bool] = field(default=False, metadata={"help": "load the model in 4 bits precision"})
bnb_4bit_compute_dtype: Optional[str] = field(default="float32", metadata={"help": "the compute dtype for 4 bits"})
use_peft: Optional[bool] = field(default=False, metadata={"help": "Whether to use PEFT or not to train adapters"})
trust_remote_code: Optional[bool] = field(default=True, metadata={"help": "Enable `trust_remote_code`"})
output_dir: Optional[str] = field(default="output", metadata={"help": "the output directory"})
@@ -57,6 +58,8 @@ class ScriptArguments:
use_auth_token: Optional[bool] = field(default=True, metadata={"help": "Use HF auth token to access the model"})
num_train_epochs: Optional[int] = field(default=3, metadata={"help": "the number of training epochs"})
max_steps: Optional[int] = field(default=-1, metadata={"help": "the number of training steps"})
use_flash_attn: Optional[bool] = field(default=False, metadata={"help": "Use Flash Attention"})
packing: Optional[bool] = field(default=False, metadata={"help": "Use packing"})
save_steps: Optional[int] = field(
default=100, metadata={"help": "Number of updates steps before two checkpoint saves"}
)
@@ -68,12 +71,17 @@ class ScriptArguments:
parser = HfArgumentParser(ScriptArguments)
script_args = parser.parse_args_into_dataclasses()[0]

if script_args.use_flash_attn and not script_args.packing:
raise ValueError("You can't use Flash Attention without packing")

# Step 1: Load the model
if script_args.load_in_8bit and script_args.load_in_4bit:
raise ValueError("You can't load the model in 8 bits and 4 bits at the same time")
elif script_args.load_in_8bit or script_args.load_in_4bit:
quantization_config = BitsAndBytesConfig(
load_in_8bit=script_args.load_in_8bit, load_in_4bit=script_args.load_in_4bit
load_in_8bit=script_args.load_in_8bit,
load_in_4bit=script_args.load_in_4bit,
bnb_4bit_compute_dtype=getattr(torch, script_args.bnb_4bit_compute_dtype),
)
# This means: fit the entire model on the GPU:0
device_map = {"": 0}
@@ -130,6 +138,8 @@ class ScriptArguments:
train_dataset=dataset,
dataset_text_field=script_args.dataset_text_field,
peft_config=peft_config,
packing=script_args.packing,
use_flash_attn=script_args.use_flash_attn,
)

trainer.train()
38 changes: 38 additions & 0 deletions trl/trainer/sft_trainer.py
@@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import functools
import warnings
from typing import Callable, Dict, List, Optional, Tuple, Union

@@ -29,6 +30,7 @@
)
from transformers.trainer_callback import TrainerCallback
from transformers.trainer_utils import EvalPrediction
from transformers.utils import ContextManagers
> **Member** commented: I made this 😂
>
> **Contributor (author)** replied: hahahah nice!
from ..import_utils import is_peft_available
from .utils import ConstantLengthDataset, DataCollatorForCompletionOnlyLM, PeftSavingCallback
@@ -119,6 +121,7 @@ def __init__(
chars_per_token: Optional[float] = 3.6,
dataset_num_proc: Optional[int] = None,
dataset_batch_size: int = 1000,
use_flash_attn: Optional[bool] = False,
):
if isinstance(model, str):
warnings.warn(
@@ -131,6 +134,21 @@ def __init__(
"You passed a `DataCollatorForCompletionOnlyLM` to the SFTTrainer. This is not compatible with the `packing` argument."
)

if (not packing or data_collator is not None) and use_flash_attn:
raise ValueError(
"You passed `use_flash_attn=True` to the SFTTrainer, but you also passed `packing=False` or a custom data collator. This is not supported."
" You need to pass `packing=True` to the SFTTrainer if you want to use the `use_flash_attn` argument because flash attention"
" training does not support passing attention masks."
)

if use_flash_attn and not torch.cuda.is_available():
raise ValueError(
"You passed `use_flash_attn=True` to the SFTTrainer, but you don't have a GPU available. This is not supported."
)

if isinstance(model, PreTrainedModel) and use_flash_attn:
model = model.to_bettertransformer()

if is_peft_available() and peft_config is not None:
if not isinstance(peft_config, PeftConfig):
raise ValueError(
@@ -148,6 +166,11 @@ def __init__(
model = prepare_model_for_int8_training(model)

model = get_peft_model(model, peft_config)
elif use_flash_attn and isinstance(model, PeftModel):
raise ValueError(
"You passed a `PeftModel` to the SFTTrainer, but you also passed `use_flash_attn=True`. This is not supported."
" You need to first create a transformers model and pass a peft_config to the SFTTrainer rather than directly passing a PeftModel."
)

if callbacks is None:
callbacks = [PeftSavingCallback]
@@ -169,6 +192,8 @@ def __init__(

self.dataset_num_proc = dataset_num_proc
self.dataset_batch_size = dataset_batch_size
self.use_flash_attn = use_flash_attn

if not packing:
if dataset_text_field is None and formatting_func is None:
raise ValueError(
@@ -313,3 +338,16 @@ def tokenize(element):
)

return tokenized_dataset

    @functools.wraps(Trainer.train)
    def train(self, *args, **kwargs):
        training_contexts = []

        if self.use_flash_attn:
            # Restrict the SDPA dispatcher to the Flash Attention kernel and run the
            # training loop under autocast so that the kernel receives half-precision inputs.
            training_contexts.append(
                torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=False, enable_mem_efficient=False)
            )
            training_contexts.append(torch.cuda.amp.autocast())

        with ContextManagers(training_contexts):
            output = super().train(*args, **kwargs)

        return output
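
For reference, `ContextManagers` (imported above from `transformers.utils`) simply enters a list of context managers for the duration of a block, like a nested `with` statement. A minimal illustration of how the contexts are stacked — a sketch, not TRL code:

```python
import torch
from transformers.utils import ContextManagers

contexts = [
    torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=False, enable_mem_efficient=False),
    torch.cuda.amp.autocast(),
]

# Every context manager in the list is entered on `with`, and exited in reverse order afterwards
with ContextManagers(contexts):
    pass  # the forward/backward passes would run here
```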