This repository was archived by the owner on Jun 4, 2025. It is now read-only.

Commit 32ed5bd

Question answering trainer with SparseML integration (#1)
* Initial commit: QA and Distill QA with SparseML integration
* Overwrite the scaler's step if it exists (for AMP mode)
* Add wandb logger
* Remove distill script (to be unified with run_qa) and recipes (to be moved to sparseml)
* Include distillation in run_qa, clean up code
* Unify distill/non-distill trainer, simplify ONNX export
* Simplify variable names for distillation
1 parent 80d712f commit 32ed5bd

File tree

2 files changed: +162 lines added, −3 lines removed


examples/pytorch/question-answering/run_qa.py

Lines changed: 41 additions & 3 deletions
@@ -24,10 +24,11 @@
 from dataclasses import dataclass, field
 from typing import Optional
 
+import numpy
 from datasets import load_dataset, load_metric
 
 import transformers
-from trainer_qa import QuestionAnsweringTrainer
+from sparseml_utils import SparseMLQATrainer, export_model
 from transformers import (
     AutoConfig,
     AutoModelForQuestionAnswering,
@@ -56,10 +57,18 @@ class ModelArguments:
     """
     Arguments pertaining to which model/config/tokenizer we are going to fine-tune from.
     """
-
     model_name_or_path: str = field(
         metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"}
     )
+    distill_teacher: Optional[str] = field(
+        default=None, metadata={"help": "Teacher model which needs to be a trained QA model"}
+    )
+    distill_temperature: Optional[float] = field(
+        default=2.0, metadata={"help": "Temperature applied to teacher softmax for distillation."}
+    )
+    distill_hardness: Optional[float] = field(
+        default=1.0, metadata={"help": "Proportion of loss coming from teacher model."}
+    )
     config_name: Optional[str] = field(
         default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"}
     )
@@ -89,6 +98,14 @@ class DataTrainingArguments:
     Arguments pertaining to what data we are going to input our model for training and eval.
     """
 
+    recipe: Optional[str] = field(
+        default=None,
+        metadata={"help": "Path to a SparseML sparsification recipe, see https://github.com/neuralmagic/sparseml "
+        "for more information"},
+    )
+    onnx_export_path: Optional[str] = field(
+        default=None, metadata={"help": "The filename and path where the ONNX model will be output"}
+    )
     dataset_name: Optional[str] = field(
         default=None, metadata={"help": "The name of the dataset to use (via the datasets library)."}
     )
@@ -300,6 +317,18 @@ def main():
         use_auth_token=True if model_args.use_auth_token else None,
     )
 
+    teacher_model = None
+    if model_args.distill_teacher is not None:
+        teacher_model = AutoModelForQuestionAnswering.from_pretrained(
+            model_args.distill_teacher,
+            from_tf=bool(".ckpt" in model_args.distill_teacher),
+            config=config,
+            cache_dir=model_args.cache_dir,
+        )
+        teacher_model_parameters = filter(lambda p: p.requires_grad, teacher_model.parameters())
+        params = sum([numpy.prod(p.size()) for p in teacher_model_parameters])
+        logger.info("Teacher Model has %s parameters", params)
+
     # Tokenizer check: this script requires a fast tokenizer.
     if not isinstance(tokenizer, PreTrainedTokenizerFast):
         raise ValueError(
@@ -543,7 +572,11 @@ def compute_metrics(p: EvalPrediction):
         return metric.compute(predictions=p.predictions, references=p.label_ids)
 
     # Initialize our Trainer
-    trainer = QuestionAnsweringTrainer(
+    trainer = SparseMLQATrainer(
+        data_args.recipe,
+        teacher=teacher_model,
+        distill_hardness=model_args.distill_hardness,
+        distill_temperature=model_args.distill_temperature,
         model=model,
         args=training_args,
         train_dataset=train_dataset if training_args.do_train else None,
@@ -612,6 +645,11 @@ def compute_metrics(p: EvalPrediction):
 
         trainer.push_to_hub(**kwargs)
 
+    if data_args.onnx_export_path:
+        logger.info("*** Export to ONNX ***")
+        eval_dataloader = trainer.get_eval_dataloader(eval_dataset)
+        export_model(model, eval_dataloader, data_args.onnx_export_path)
+
 
 def _mp_fn(index):
     # For xla_spawn (TPUs)
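For reference, here is a minimal, self-contained sketch of what the new distill_temperature and distill_hardness arguments control, mirroring the compute_loss implementation added in sparseml_utils.py below. The tensor shapes, the random logits, and the values chosen are illustrative assumptions, not part of the commit.

import torch
import torch.nn.functional as F

temperature, hardness = 2.0, 0.5          # hypothetical distill_temperature / distill_hardness values
student_logits = torch.randn(2, 8)        # e.g. start_logits from the student (batch of 2, 8 positions)
teacher_logits = torch.randn(2, 8)        # e.g. start_logits from the teacher
labels = torch.randint(0, 8, (2,))        # gold start positions

# Soft-target loss: KL divergence between temperature-scaled distributions,
# rescaled by temperature**2 (the usual knowledge-distillation correction).
teacher_loss = F.kl_div(
    input=F.log_softmax(student_logits / temperature, dim=-1),
    target=F.softmax(teacher_logits / temperature, dim=-1),
    reduction="batchmean",
) * (temperature ** 2)

# Hard-label loss: plain cross-entropy against the annotated positions.
label_loss = F.cross_entropy(student_logits, labels)

# distill_hardness interpolates between the two terms; 1.0 relies only on the teacher.
loss = (1 - hardness) * label_loss + hardness * teacher_loss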
examples/pytorch/question-answering/sparseml_utils.py

Lines changed: 121 additions & 0 deletions
@@ -0,0 +1,121 @@
+import math
+
+import torch
+import torch.nn.functional as F
+
+from sparseml.pytorch.optim.manager import ScheduledModifierManager
+from sparseml.pytorch.optim.optimizer import ScheduledOptimizer
+from sparseml.pytorch.utils import ModuleExporter, logger
+from trainer_qa import QuestionAnsweringTrainer
+
+
+class SparseMLQATrainer(QuestionAnsweringTrainer):
+    """
+    Question Answering trainer with SparseML integration
+
+    :param recipe: recipe for model sparsification
+    :param teacher: teacher model for distillation
+    :param distill_hardness: ratio of loss by teacher targets (between 0 and 1)
+    :param distill_temperature: temperature for distillation
+    :param args, kwargs: arguments passed into parent class
+    """
+
+    def __init__(self, recipe, teacher=None, distill_hardness=0.5, distill_temperature=2.0, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.recipe = recipe
+        self.teacher = teacher
+        self.distill_hardness = distill_hardness
+        self.distill_temperature = distill_temperature
+        self.criterion = torch.nn.CrossEntropyLoss()
+
+        self.manager = None
+        self.loggers = None
+        if self.recipe is not None:
+            loggers = []
+            if "wandb" in self.args.report_to:
+                loggers.append(logger.WANDBLogger())
+            self.loggers = loggers
+
+    def create_optimizer(self):
+        """
+        Create optimizer customized using SparseML
+        """
+        super().create_optimizer()
+        if self.recipe is None:
+            return
+        steps_per_epoch = math.ceil(
+            len(self.train_dataset) / (self.args.per_device_train_batch_size * self.args._n_gpu)
+        )
+        self.manager = ScheduledModifierManager.from_yaml(self.recipe)
+        self.args.num_train_epochs = float(self.manager.max_epochs)
+        if hasattr(self, "scaler"):
+            self.manager.initialize(self.model, epoch=0.0, loggers=self.loggers)
+            self.scaler = self.manager.modify(
+                self.model, self.optimizer, steps_per_epoch=steps_per_epoch, wrap_optim=self.scaler
+            )
+        else:
+            self.optimizer = ScheduledOptimizer(
+                self.optimizer, self.model, self.manager, steps_per_epoch=steps_per_epoch, loggers=self.loggers
+            )
+
+    def compute_loss(self, model, inputs, return_outputs=False):
+        """
+        Computing loss using teacher/student distillation
+        """
+        if self.recipe is None or self.teacher is None:
+            return super().compute_loss(model, inputs, return_outputs=return_outputs)
+
+        outputs = model(**inputs)
+        if self.teacher is None:
+            loss = outputs["loss"]
+        else:
+            input_device = inputs["input_ids"].device
+            self.teacher = self.teacher.to(input_device)
+            start_logits_student = outputs["start_logits"]
+            end_logits_student = outputs["end_logits"]
+            start_logits_label = inputs["start_positions"]
+            end_logits_label = inputs["end_positions"]
+            with torch.no_grad():
+                teacher_output = self.teacher(
+                    input_ids=inputs["input_ids"],
+                    token_type_ids=inputs["token_type_ids"],
+                    attention_mask=inputs["attention_mask"],
+                )
+            start_logits_teacher = teacher_output["start_logits"]
+            end_logits_teacher = teacher_output["end_logits"]
+            loss_start = (
+                F.kl_div(
+                    input=F.log_softmax(start_logits_student / self.distill_temperature, dim=-1),
+                    target=F.softmax(start_logits_teacher / self.distill_temperature, dim=-1),
+                    reduction="batchmean",
+                )
+                * (self.distill_temperature ** 2)
+            )
+            loss_end = (
+                F.kl_div(
+                    input=F.log_softmax(end_logits_student / self.distill_temperature, dim=-1),
+                    target=F.softmax(end_logits_teacher / self.distill_temperature, dim=-1),
+                    reduction="batchmean",
+                )
+                * (self.distill_temperature ** 2)
+            )
+            teacher_loss = (loss_start + loss_end) / 2.0
+            loss_start = self.criterion(start_logits_student, start_logits_label)
+            loss_end = self.criterion(end_logits_student, end_logits_label)
+            label_loss = (loss_start + loss_end) / 2.0
+            loss = ((1 - self.distill_hardness) * label_loss) + (self.distill_hardness * teacher_loss)
+        return (loss, outputs) if return_outputs else loss
+
+
+def export_model(model, dataloader, output_dir):
+    """
+    Export a trained model to ONNX
+    :param model: trained model
+    :param dataloader: dataloader to get sample batch
+    :param output_dir: output directory for ONNX model
+    """
+    exporter = ModuleExporter(model, output_dir=output_dir)
+    for _, sample_batch in enumerate(dataloader):
+        sample_input = (sample_batch["input_ids"], sample_batch["attention_mask"], sample_batch["token_type_ids"])
+        exporter.export_onnx(sample_batch=sample_input, convert_qat=True)
+        break
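For completeness, a minimal sketch of calling the new export_model helper on its own, using a dummy one-batch dataloader. The bert-base-uncased checkpoint, the sequence length, the collate helper, and the output directory are assumptions for illustration and not part of the commit.

import torch
from torch.utils.data import DataLoader
from transformers import AutoModelForQuestionAnswering

from sparseml_utils import export_model

# A single illustrative batch carrying the three tensors export_model reads from the dataloader.
def collate(_batch):
    return {
        "input_ids": torch.ones(1, 32, dtype=torch.long),
        "attention_mask": torch.ones(1, 32, dtype=torch.long),
        "token_type_ids": torch.zeros(1, 32, dtype=torch.long),
    }

model = AutoModelForQuestionAnswering.from_pretrained("bert-base-uncased")  # hypothetical checkpoint
dummy_loader = DataLoader([0], collate_fn=collate)

# Traces one batch through SparseML's ModuleExporter and writes an ONNX file under onnx_export/.
export_model(model, dummy_loader, "onnx_export")

In run_qa.py the same call is made with trainer.get_eval_dataloader(eval_dataset), so the exported graph is traced with real tokenized evaluation features rather than dummy tensors.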
