129 changes: 86 additions & 43 deletions functionary/train/train.py
@@ -1,17 +1,31 @@
import json
import math
import os
import pathlib
import sys
from dataclasses import dataclass, field
from typing import List, Optional
from typing import Optional

import numpy as np
import torch
import torch.distributed
import transformers
from aenum import extend_enum
from torch.optim.lr_scheduler import LambdaLR
from training_utils import (
compute_metrics,
create_data_loader,
create_distributed_data_loader,
dynamic_batch_size,
initialize_tokenizer,
preprocess_logits_for_metrics,
print_rank0,
print_some_examples,
tokenize_and_cache,
)
from transformers import Trainer

from functionary.prompt_template import get_prompt_template_by_version
from functionary.train import training_utils
from functionary.train.custom_datasets import read_dataset

extend_enum(
transformers.trainer_utils.SchedulerType,
@@ -47,17 +61,8 @@ def lr_lambda(current_step):
get_scheduler
)
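The extend_enum call above is collapsed by the diff view, so the registered scheduler itself is not visible here. For reference, a minimal sketch of the general pattern, registering a hypothetical scheduler name with transformers so it can be selected via lr_scheduler_type; the name, schedule shape, and wiring below are assumptions, not this PR's code:

import transformers
from aenum import extend_enum
from torch.optim.lr_scheduler import LambdaLR

# Hypothetical scheduler name; the real one is hidden in the collapsed lines.
extend_enum(transformers.trainer_utils.SchedulerType, "CUSTOM", "custom")

def get_custom_scheduler(optimizer, num_warmup_steps: int, num_training_steps: int):
    def lr_lambda(current_step: int):
        # Illustrative schedule only: linear warmup, then constant LR.
        if current_step < num_warmup_steps:
            return current_step / max(1, num_warmup_steps)
        return 1.0

    return LambdaLR(optimizer, lr_lambda)

# In recent transformers versions, get_scheduler resolves scheduler names
# through this mapping, so the new enum member must point at the factory above.
transformers.optimization.TYPE_TO_SCHEDULER_FUNCTION[
    transformers.trainer_utils.SchedulerType.CUSTOM
] = get_custom_scheduler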

from torch.nn import CrossEntropyLoss
from torch.utils.data import DataLoader
from transformers import AutoConfig, AutoTokenizer, Trainer

sys.path.append(os.path.join(os.path.dirname(__file__), "../.."))
from typing import Union

from functionary.prompt_template import PromptTemplate, get_prompt_template_by_version
from functionary.train.custom_datasets import read_dataset
from functionary.train import training_utils
from training_utils import print_rank0

LOCAL_RANK = int(os.getenv("LOCAL_RANK", "0"))

@@ -139,6 +144,23 @@ def trainer_save_model_safe(trainer: transformers.Trainer):
trainer.save_model()


"""
Below is the updated train() function from LEVENT OZBEK.
Reviewer comment (Collaborator): Authorship tracking is a responsibility of git. We should remove all authorship info from the code.

Most of the changes are identical to those in train_lora.py; I simply applied the changes to the utility code in training_utils.py and commented out the original train() function.

- training_utils.tokenize_and_cache() is used for both the training and evaluation datasets to avoid repetition.
- dynamic_batch_size() automatically adjusts batch sizes based on token counts (sketched below). I did not implement this in train_lora.py, since LoRAs are trained on smaller datasets and it felt less necessary there.
- DataLoaders are constructed with a BatchSampler to adjust the batch size dynamically per epoch.
- A distributed DataLoader is used if local_rank != -1.
- Updated to use the optimized preprocess_logits_for_metrics and compute_metrics from training_utils.py.

Advantages of these changes:
- handles datasets with varying sequence lengths dynamically
- supports both single-GPU and distributed setups
"""


Reviewer comment (Collaborator) on lines +149 to +163: Updates like these would be better placed in the PR description, not in the code.

def train():
"""Training loop"""

@@ -186,10 +208,9 @@ def train():
torch_dtype=compute_dtype,
config=config,
cache_dir=training_args.cache_dir,
attn_implementation="flash_attention_2",  # use_flash_attention_2 is replaced by this since version 4.36.0
attn_implementation="flash_attention_2",
)
model.config.use_cache = False
# Activate computing load balancing loss in MixtralForCausalLM
if hasattr(model.config, "output_router_logits"):
setattr(model.config, "output_router_logits", True)
print_rank0("Activate computing load balancing loss")
@@ -199,7 +220,7 @@ def train():
training_args.prompt_template_version
)

tokenizer = training_utils.initialize_tokenizer(
tokenizer = initialize_tokenizer(
model=model,
model_name_or_path=model_args.model_name_or_path,
prompt_template=prompt_template,
@@ -213,7 +234,6 @@ def train():

tokenizer.save_pretrained(training_args.output_dir)

# get id of added tokens to compute the accuracy of predicting the token
id2token = {
tokenizer.encode(token)[-1]: token
for token in prompt_template.get_additional_tokens()
@@ -222,57 +242,81 @@

assert data_args.train_data_path is not None, "Please provide a training data file."

train_dataset = read_dataset(
# Cache and tokenize training data
raw_train_dataset = read_dataset(
model_args.model_name_or_path, data_args, training_args, tokenizer, "train"
)
train_dataset = tokenize_and_cache(
raw_train_dataset, tokenizer, training_args.cache_dir
)
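tokenize_and_cache(raw_dataset, tokenizer, cache_dir) is also imported from training_utils.py, which this diff does not show. A plausible sketch of such a helper; the cache key, the pickle format, and the "text" field are assumptions:

import hashlib
import os
import pickle

def tokenize_and_cache(raw_dataset, tokenizer, cache_dir):
    """Tokenize once, then reload the result from disk on later runs."""
    cache_dir = cache_dir or "./.tokenize_cache"
    os.makedirs(cache_dir, exist_ok=True)
    # Illustrative cache key: dataset size plus tokenizer identity.
    key = hashlib.md5(f"{len(raw_dataset)}-{tokenizer.name_or_path}".encode()).hexdigest()
    cache_path = os.path.join(cache_dir, f"tokenized_{key}.pkl")
    if os.path.exists(cache_path):
        with open(cache_path, "rb") as f:
            return pickle.load(f)
    # Assumes each raw example carries a "text" field.
    tokenized = [tokenizer(example["text"]) for example in raw_dataset]
    with open(cache_path, "wb") as f:
        pickle.dump(tokenized, f)
    return tokenized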

if torch.distributed.get_rank() == 0:
print(f"Training Data Loaded: #{len(train_dataset)}")

if training_args.do_eval:
eval_dataset = read_dataset(
# Cache and tokenize evaluation data
raw_eval_dataset = read_dataset(
model_args.model_name_or_path,
data_args,
training_args,
tokenizer,
"validation",
)

eval_dataset = tokenize_and_cache(
raw_eval_dataset, tokenizer, training_args.cache_dir
)
if torch.distributed.get_rank() == 0:
print(f"Eval Data Loaded: #{len(eval_dataset)}")

print_rank0("***** HERE ARE SOME EXAMPLES FROM TRAINING ****")
training_utils.print_some_examples(train_dataset, tokenizer)
print_some_examples(train_dataset, tokenizer)

if training_args.do_eval:
print_rank0("***** HERE ARE SOME EXAMPLES FROM EVALUATION ***")
training_utils.print_some_examples(eval_dataset, tokenizer)

def preprocess_logits_for_metrics(logits, labels):
return training_utils.preprocess_logits_for_metrics(
logits, labels, len(tokenizer)
)
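The removed wrapper above forwarded to training_utils.preprocess_logits_for_metrics(logits, labels, len(tokenizer)). The usual purpose of this Trainer hook is to keep evaluation memory bounded; a common minimal form, offered here as an assumption about what the optimized helper does:

import torch

def preprocess_logits_for_metrics(logits, labels, vocab_size):
    # vocab_size is accepted for signature compatibility; this sketch
    # does not need it.
    # Some models return a tuple (logits, past_key_values, ...).
    if isinstance(logits, tuple):
        logits = logits[0]
    # Keep only the argmax prediction per position so the Trainer accumulates
    # [batch, seq] token ids instead of [batch, seq, vocab_size] logits.
    return torch.argmax(logits, dim=-1)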
print_some_examples(eval_dataset, tokenizer)

def compute_metrics(eval_preds):
return training_utils.compute_metrics(eval_preds, id2token, tokenizer)
# Dynamic batch size based on max tokens per batch
Reviewer comment (Collaborator): While dynamic batch sizes might seem a good idea for dealing with memory issues, they would cause instabilities in training due to the difference in gradient updates per conversation. I would prefer stability over memory efficiency: stable updates at a higher GPU cost should be preferred over cheaper and faster training.
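To make the preference above concrete: the conventional stable alternative is a fixed per-device batch size combined with gradient accumulation, so every optimizer step averages gradients over the same effective batch. The values below are illustrative only:

from transformers import TrainingArguments

training_args = TrainingArguments(
    output_dir="out",
    per_device_train_batch_size=2,  # fixed, independent of sequence length
    gradient_accumulation_steps=8,  # effective batch = 2 * 8 * num_gpus
)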

max_tokens_per_batch = 2048 # You can adjust this as needed
train_batch_sizes = dynamic_batch_size(
train_dataset, max_tokens_per_batch, tokenizer
)
print_rank0(f"Dynamic train batch sizes: {train_batch_sizes}")

if training_args.do_eval:
trainer = Trainer(
model=model,
tokenizer=tokenizer,
args=training_args,
train_dataset=train_dataset,
eval_dataset=eval_dataset,
compute_metrics=compute_metrics,
preprocess_logits_for_metrics=preprocess_logits_for_metrics,
eval_batch_sizes = dynamic_batch_size(
eval_dataset, max_tokens_per_batch, tokenizer
)
else:
trainer = Trainer(
model=model,
tokenizer=tokenizer,
args=training_args,
train_dataset=train_dataset,
print_rank0(f"Dynamic eval batch sizes: {eval_batch_sizes}")

# DataLoaders with dynamic batch sizes
if training_args.local_rank == -1: # Single-GPU
train_loader = create_data_loader(
train_dataset, batch_size=max(train_batch_sizes), num_workers=4
)
if training_args.do_eval:
eval_loader = create_data_loader(
eval_dataset, batch_size=max(eval_batch_sizes), num_workers=4
)
else: # Multi-GPU
train_loader = create_distributed_data_loader(
train_dataset, batch_size=max(train_batch_sizes)
)
if training_args.do_eval:
eval_loader = create_distributed_data_loader(
eval_dataset, batch_size=max(eval_batch_sizes)
)
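create_data_loader and create_distributed_data_loader come from training_utils.py as well. Given the call sites above and the BatchSampler mentioned in the docstring, they presumably resemble the following sketch; the sampler choices are assumptions:

from torch.utils.data import BatchSampler, DataLoader, SequentialSampler
from torch.utils.data.distributed import DistributedSampler

def create_data_loader(dataset, batch_size, num_workers=0):
    # Single-process loader; an explicit BatchSampler makes the batch size
    # easy to change between epochs.
    batch_sampler = BatchSampler(SequentialSampler(dataset), batch_size, drop_last=False)
    return DataLoader(dataset, batch_sampler=batch_sampler, num_workers=num_workers)

def create_distributed_data_loader(dataset, batch_size):
    # DistributedSampler gives each rank a disjoint shard of the dataset.
    return DataLoader(dataset, batch_size=batch_size, sampler=DistributedSampler(dataset))

Note that the Trainer below receives train_loader.dataset rather than the loader itself, so unless get_train_dataloader is overridden elsewhere, the Trainer builds its own DataLoaders and the custom batching above is effectively bypassed.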

trainer = Trainer(
model=model,
tokenizer=tokenizer,
args=training_args,
train_dataset=train_loader.dataset,
eval_dataset=eval_loader.dataset if training_args.do_eval else None,
compute_metrics=compute_metrics if training_args.do_eval else None,
preprocess_logits_for_metrics=(
preprocess_logits_for_metrics if training_args.do_eval else None
),
)

if list(pathlib.Path(training_args.output_dir).glob("checkpoint-*")):
trainer.train(resume_from_checkpoint=True)
@@ -281,7 +325,6 @@ def compute_metrics(eval_preds):

trainer.save_state()

# FSDP requires state_dict_type=FULL_STATE_DICT in order to save the model weights in .bin format
if trainer.is_fsdp_enabled:
trainer_save_model_safe(trainer=trainer)
else:
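trainer_save_model_safe is truncated at the top of this diff; the removed comment about FULL_STATE_DICT suggests the widely used FSDP save pattern below, a sketch rather than necessarily this repo's exact code:

import transformers
from torch.distributed.fsdp import FullStateDictConfig, StateDictType
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

def trainer_save_model_safe(trainer: transformers.Trainer):
    # Gather the full, unsharded state dict (offloaded to CPU, rank 0 only)
    # so the checkpoint is written as ordinary weight files.
    save_policy = FullStateDictConfig(offload_to_cpu=True, rank0_only=True)
    with FSDP.state_dict_type(trainer.model, StateDictType.FULL_STATE_DICT, save_policy):
        trainer.save_model()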