44 changes: 41 additions & 3 deletions examples/pytorch/question-answering/run_qa.py
@@ -24,10 +24,11 @@
from dataclasses import dataclass, field
from typing import Optional

import numpy
from datasets import load_dataset, load_metric

import transformers
from trainer_qa import QuestionAnsweringTrainer
from sparseml_utils import SparseMLQATrainer, export_model
from transformers import (
    AutoConfig,
    AutoModelForQuestionAnswering,
@@ -56,10 +57,18 @@ class ModelArguments:
"""
Arguments pertaining to which model/config/tokenizer we are going to fine-tune from.
"""

    model_name_or_path: str = field(
        metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"}
    )
    distill_teacher: Optional[str] = field(
        default=None, metadata={"help": "Teacher model for distillation; must be an already trained QA model."}
    )
    distill_temperature: Optional[float] = field(
        default=2.0, metadata={"help": "Temperature applied to teacher softmax for distillation."}
    )
    distill_hardness: Optional[float] = field(
        default=1.0, metadata={"help": "Proportion of the loss coming from the teacher model."}
    )
    config_name: Optional[str] = field(
        default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"}
    )
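The three new fields above surface as ordinary command-line flags through the script's existing HfArgumentParser flow. A minimal sketch of parsing them, assuming ModelArguments as defined in this file; the model and teacher paths are placeholders, not values from this PR:

# Sketch only: parse the new distillation flags the same way run_qa.py's main() does.
from transformers import HfArgumentParser

parser = HfArgumentParser(ModelArguments)
(model_args,) = parser.parse_args_into_dataclasses(
    args=[
        "--model_name_or_path", "bert-base-uncased",
        "--distill_teacher", "path/to/teacher-qa-model",  # placeholder path
        "--distill_temperature", "2.0",
        "--distill_hardness", "1.0",
    ]
)
print(model_args.distill_teacher, model_args.distill_hardness)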
@@ -89,6 +98,14 @@ class DataTrainingArguments:
    Arguments pertaining to what data we are going to input our model for training and eval.
    """

    recipe: Optional[str] = field(
        default=None,
        metadata={"help": "Path to a SparseML sparsification recipe; see https://github.com/neuralmagic/sparseml "
                          "for more information"},
    )
    onnx_export_path: Optional[str] = field(
        default=None, metadata={"help": "Filename and path where the exported ONNX model will be saved"}
    )
    dataset_name: Optional[str] = field(
        default=None, metadata={"help": "The name of the dataset to use (via the datasets library)."}
    )
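The recipe flag names a SparseML YAML file; SparseMLQATrainer (added in sparseml_utils.py below) loads it with ScheduledModifierManager and lets the recipe's schedule override the number of training epochs. A short sketch of that loading step, with a placeholder recipe path:

# Sketch only: how a --recipe value is consumed downstream.
from sparseml.pytorch.optim.manager import ScheduledModifierManager

manager = ScheduledModifierManager.from_yaml("recipes/pruning.yaml")  # placeholder path
# The trainer replaces args.num_train_epochs with the recipe's schedule length:
print(float(manager.max_epochs))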
@@ -300,6 +317,18 @@ def main():
        use_auth_token=True if model_args.use_auth_token else None,
    )

    teacher_model = None
    if model_args.distill_teacher is not None:
        teacher_model = AutoModelForQuestionAnswering.from_pretrained(
            model_args.distill_teacher,
            from_tf=bool(".ckpt" in model_args.distill_teacher),
            config=config,
            cache_dir=model_args.cache_dir,
        )
        teacher_model_parameters = filter(lambda p: p.requires_grad, teacher_model.parameters())
        params = sum(numpy.prod(p.size()) for p in teacher_model_parameters)
        logger.info("Teacher model has %s parameters", params)

    # Tokenizer check: this script requires a fast tokenizer.
    if not isinstance(tokenizer, PreTrainedTokenizerFast):
        raise ValueError(
@@ -543,7 +572,11 @@ def compute_metrics(p: EvalPrediction):
        return metric.compute(predictions=p.predictions, references=p.label_ids)

    # Initialize our Trainer
    trainer = QuestionAnsweringTrainer(
    trainer = SparseMLQATrainer(
        data_args.recipe,
        teacher=teacher_model,
        distill_hardness=model_args.distill_hardness,
        distill_temperature=model_args.distill_temperature,
        model=model,
        args=training_args,
        train_dataset=train_dataset if training_args.do_train else None,
@@ -612,6 +645,11 @@ def compute_metrics(p: EvalPrediction):

        trainer.push_to_hub(**kwargs)

    if data_args.onnx_export_path:
        logger.info("*** Export to ONNX ***")
        eval_dataloader = trainer.get_eval_dataloader(eval_dataset)
        export_model(model, eval_dataloader, data_args.onnx_export_path)


def _mp_fn(index):
    # For xla_spawn (TPUs)
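After training, the --onnx_export_path branch above hands the model to export_model (defined in sparseml_utils.py below). A hedged follow-up check with onnxruntime; the "model.onnx" file name assumes ModuleExporter's default output name, which this PR does not confirm, and the path is a placeholder:

# Sketch only: load the exported file to verify it is a valid ONNX graph.
import onnxruntime

session = onnxruntime.InferenceSession("onnx_export_path/model.onnx")
print([i.name for i in session.get_inputs()])   # expect input_ids, attention_mask, token_type_ids
print([o.name for o in session.get_outputs()])  # expect start_logits, end_logits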
121 changes: 121 additions & 0 deletions examples/pytorch/question-answering/sparseml_utils.py
@@ -0,0 +1,121 @@
import math

import torch
import torch.nn.functional as F

from sparseml.pytorch.optim.manager import ScheduledModifierManager
from sparseml.pytorch.optim.optimizer import ScheduledOptimizer
from sparseml.pytorch.utils import ModuleExporter, logger
from trainer_qa import QuestionAnsweringTrainer


class SparseMLQATrainer(QuestionAnsweringTrainer):
"""
Question Answering trainer with SparseML integration

:param recipe: recipe for model sparsification
:param teacher: teacher model for distillation
:param distill_hardness: ratio of loss by teacher targets (between 0 and 1)
:param distill_temperature: temperature for distillation
:param args, kwargs: arguments passed into parent class
"""

    def __init__(self, recipe, teacher=None, distill_hardness=0.5, distill_temperature=2.0, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.recipe = recipe
        self.teacher = teacher
        self.distill_hardness = distill_hardness
        self.distill_temperature = distill_temperature
        self.criterion = torch.nn.CrossEntropyLoss()

        self.manager = None
        self.loggers = None
        if self.recipe is not None:
            loggers = []
            if "wandb" in self.args.report_to:
                loggers.append(logger.WANDBLogger())
            self.loggers = loggers

    def create_optimizer(self):
        """
        Create the optimizer, then wrap it with the SparseML sparsification schedule.
        """
        super().create_optimizer()
        if self.recipe is None:
            return
        # n_gpu can be 0 on CPU-only runs, so guard the division.
        steps_per_epoch = math.ceil(
            len(self.train_dataset) / (self.args.per_device_train_batch_size * max(1, self.args.n_gpu))
        )
        self.manager = ScheduledModifierManager.from_yaml(self.recipe)
        self.args.num_train_epochs = float(self.manager.max_epochs)
        if hasattr(self, "scaler"):
            # AMP: initialize the manager and let it wrap the gradient scaler
            # so modifiers run inside the scaled backward/step cycle.
            self.manager.initialize(self.model, epoch=0.0, loggers=self.loggers)
            self.scaler = self.manager.modify(
                self.model, self.optimizer, steps_per_epoch=steps_per_epoch, wrap_optim=self.scaler
            )
        else:
            # No AMP: wrap the optimizer itself with the sparsification schedule.
            self.optimizer = ScheduledOptimizer(
                self.optimizer, self.model, self.manager, steps_per_epoch=steps_per_epoch, loggers=self.loggers
            )

    def compute_loss(self, model, inputs, return_outputs=False):
        """
        Compute loss using teacher/student distillation.
        """
        if self.recipe is None or self.teacher is None:
            return super().compute_loss(model, inputs, return_outputs=return_outputs)

        outputs = model(**inputs)
        if self.teacher is None:
            loss = outputs["loss"]
        else:
            input_device = inputs["input_ids"].device
            self.teacher = self.teacher.to(input_device)
            start_logits_student = outputs["start_logits"]
            end_logits_student = outputs["end_logits"]
            start_logits_label = inputs["start_positions"]
            end_logits_label = inputs["end_positions"]
            # The teacher's forward pass needs no gradients; only its logits are used.
            with torch.no_grad():
                teacher_output = self.teacher(
                    input_ids=inputs["input_ids"],
                    token_type_ids=inputs["token_type_ids"],
                    attention_mask=inputs["attention_mask"],
                )
            start_logits_teacher = teacher_output["start_logits"]
            end_logits_teacher = teacher_output["end_logits"]
            loss_start = (
(Review thread on this hunk)
Member: we have a function in SparseML for distillation, would like to update that to match this and use that
Author: I'll need to update this part later
                F.kl_div(
                    input=F.log_softmax(start_logits_student / self.distill_temperature, dim=-1),
                    target=F.softmax(start_logits_teacher / self.distill_temperature, dim=-1),
                    reduction="batchmean",
                )
                * (self.distill_temperature ** 2)
            )
            loss_end = (
                F.kl_div(
                    input=F.log_softmax(end_logits_student / self.distill_temperature, dim=-1),
                    target=F.softmax(end_logits_teacher / self.distill_temperature, dim=-1),
                    reduction="batchmean",
                )
                * (self.distill_temperature ** 2)
            )
            teacher_loss = (loss_start + loss_end) / 2.0
            loss_start = self.criterion(start_logits_student, start_logits_label)
            loss_end = self.criterion(end_logits_student, end_logits_label)
            label_loss = (loss_start + loss_end) / 2.0
            loss = ((1 - self.distill_hardness) * label_loss) + (self.distill_hardness * teacher_loss)
        return (loss, outputs) if return_outputs else loss
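For reference, the objective above is the standard distillation blend: a KL divergence between temperature-softened student and teacher distributions, scaled by T^2, mixed with hard-label cross-entropy according to distill_hardness. A self-contained sketch of the same formula for a single logit tensor; the shapes below are assumptions, not taken from this PR:

# Sketch only: the per-head distillation loss used in compute_loss above.
import torch
import torch.nn.functional as F

def distillation_loss(student_logits, teacher_logits, labels, hardness=1.0, temperature=2.0):
    # Soft term: KL between temperature-softened distributions, scaled by T^2
    # so gradient magnitudes stay comparable across temperatures.
    soft = F.kl_div(
        input=F.log_softmax(student_logits / temperature, dim=-1),
        target=F.softmax(teacher_logits / temperature, dim=-1),
        reduction="batchmean",
    ) * (temperature ** 2)
    # Hard term: cross-entropy against the gold start/end positions.
    hard = F.cross_entropy(student_logits, labels)
    # hardness=1.0 (the script's default) trains purely against the teacher.
    return (1 - hardness) * hard + hardness * soft

# Toy example: batch of 4, sequence length 384 (a typical SQuAD max_seq_length).
student = torch.randn(4, 384)
teacher = torch.randn(4, 384)
labels = torch.randint(0, 384, (4,))
print(distillation_loss(student, teacher, labels))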


def export_model(model, dataloader, output_dir):
    """
    Export a trained model to ONNX.

    :param model: trained model
    :param dataloader: dataloader to pull a sample batch from
    :param output_dir: output directory for the ONNX model
    """
    exporter = ModuleExporter(model, output_dir=output_dir)
    # Only one batch is needed to trace the graph for export.
    sample_batch = next(iter(dataloader))
    sample_input = (sample_batch["input_ids"], sample_batch["attention_mask"], sample_batch["token_type_ids"])
    exporter.export_onnx(sample_batch=sample_input, convert_qat=True)
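Putting the pieces together, usage mirrors what run_qa.py does above. A sketch, assuming model, training_args, the train/eval datasets, and teacher_model were prepared as in that script; the recipe and export paths are placeholders:

# Sketch only: end-to-end wiring of the new helpers.
trainer = SparseMLQATrainer(
    "recipes/pruning.yaml",          # --recipe (placeholder path)
    teacher=teacher_model,           # may be None to skip distillation
    distill_hardness=1.0,
    distill_temperature=2.0,
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
)
trainer.train()
export_model(trainer.model, trainer.get_eval_dataloader(eval_dataset), "onnx_export")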