This repository was archived by the owner on Jun 4, 2025. It is now read-only.

Commit 1803e48

spacemanidol authored and natuan committed
Setting up Text Classification for SparseML
1 parent 0e31979 commit 1803e48

3 files changed: 123 additions, 5 deletions

examples/pytorch/question-answering/run_qa.py

Lines changed: 1 addition & 2 deletions
@@ -69,7 +69,7 @@ class ModelArguments:
         default=2.0, metadata={"help": "Temperature applied to teacher softmax for distillation."}
     )
     distill_hardness: Optional[float] = field(
-        default=1.0, metadata={"help": "Proportion of loss coming from teacher model."}
+        default=0.5, metadata={"help": "Proportion of loss coming from teacher model."}
     )
     config_name: Optional[str] = field(
         default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"}
@@ -217,7 +217,6 @@ def __post_init__(self):
                 extension = self.test_file.split(".")[-1]
                 assert extension in ["csv", "json"], "`test_file` should be a csv or a json file."

-
 def main():
     # See all possible arguments in src/transformers/training_args.py
     # or by passing the --help flag to this script.
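
For context on the changed default above: distill_hardness weights the teacher-derived loss against the regular task loss during distillation. The snippet below is a minimal sketch of the usual temperature-scaled soft-target blending, not code from this commit; the function name and arguments are illustrative.

import torch.nn.functional as F

def blended_distillation_loss(student_loss, student_logits, teacher_logits, hardness=0.5, temperature=2.0):
    # Soften both distributions with the distillation temperature.
    soft_teacher = F.softmax(teacher_logits / temperature, dim=-1)
    log_soft_student = F.log_softmax(student_logits / temperature, dim=-1)
    # KL divergence between the softened student and teacher predictions, scaled by T^2 as is conventional.
    distill_loss = F.kl_div(log_soft_student, soft_teacher, reduction="batchmean") * (temperature ** 2)
    # hardness (distill_hardness) is the proportion of the loss that comes from the teacher.
    return hardness * distill_loss + (1.0 - hardness) * student_loss

With the new default of 0.5, the teacher targets and the ground-truth labels contribute equally to the training loss.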

examples/pytorch/text-classification/run_glue.py

Lines changed: 55 additions & 3 deletions
@@ -27,6 +27,7 @@
 from datasets import load_dataset, load_metric

 import transformers
+from sparseml_utils import GLUEModuleExporter, SparseMLGLUETrainer
 from transformers import (
     AutoConfig,
     AutoModelForSequenceClassification,
@@ -35,11 +36,11 @@
     EvalPrediction,
     HfArgumentParser,
     PretrainedConfig,
-    Trainer,
     TrainingArguments,
     default_data_collator,
     set_seed,
 )
+from transformers.sparse import export_model, load_recipe, preprocess_state_dict
 from transformers.trainer_utils import get_last_checkpoint
 from transformers.utils import check_min_version

@@ -72,6 +73,19 @@ class DataTrainingArguments:
     the command line.
     """

+    recipe: Optional[str] = field(
+        default=None,
+        metadata={
+            "help": "Path to a SparseML sparsification recipe, see https://github.com/neuralmagic/sparseml "
+            "for more information"
+        },
+    )
+    onnx_export_path: Optional[str] = field(
+        default=None, metadata={"help": "The filename and path where the ONNX model will be output"}
+    )
+    num_exported_samples: Optional[int] = field(
+        default=20, metadata={"help": "Number of exported samples, defaults to 20"}
+    )
     task_name: Optional[str] = field(
         default=None,
         metadata={"help": "The name of the task to train on: " + ", ".join(task_to_keys.keys())},
@@ -155,6 +169,9 @@ class ModelArguments:
     model_name_or_path: str = field(
         metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"}
     )
+    distill_teacher: Optional[str] = field(
+        default=None, metadata={"help": "Teacher model, which needs to be a trained text classification model"}
+    )
     config_name: Optional[str] = field(
         default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"}
     )
@@ -305,6 +322,12 @@ def main():
     #
     # In distributed training, the .from_pretrained methods guarantee that only one local process can concurrently
     # download model & vocab.
+
+    # Load and preprocess the state dict if the model already exists (in this case we continue to train or
+    # evaluate the model). The preprocessing step restores the names of parameters changed by
+    # the QAT process.
+    state_dict = preprocess_state_dict(model_args.model_name_or_path)
+
     config = AutoConfig.from_pretrained(
         model_args.config_name if model_args.config_name else model_args.model_name_or_path,
         num_labels=num_labels,
@@ -327,8 +350,19 @@
         cache_dir=model_args.cache_dir,
         revision=model_args.model_revision,
         use_auth_token=True if model_args.use_auth_token else None,
+        state_dict=state_dict,
     )

+    teacher_model = None
+    if model_args.distill_teacher is not None:
+        teacher_model = AutoModelForSequenceClassification.from_pretrained(
+            model_args.distill_teacher,
+            from_tf=bool(".ckpt" in model_args.distill_teacher),
+            cache_dir=model_args.cache_dir,
+        )
+        teacher_model_parameters = filter(lambda p: p.requires_grad, teacher_model.parameters())
+        params = sum([np.prod(p.size()) for p in teacher_model_parameters])
+        logger.info("Teacher Model has %s parameters", params)
     # Preprocessing the datasets
     if data_args.task_name is not None:
         sentence1_key, sentence2_key = task_to_keys[data_args.task_name]
@@ -445,17 +479,29 @@ def compute_metrics(p: EvalPrediction):
     else:
         data_collator = None

+    # Load a possible existing recipe and the new one passed in through the command-line argument
+    existing_recipe = load_recipe(model_args.model_name_or_path)
+    new_recipe = data_args.recipe
+
     # Initialize our Trainer
-    trainer = Trainer(
+    trainer = SparseMLGLUETrainer(
+        model_args.model_name_or_path,
+        [existing_recipe, new_recipe],
+        teacher=teacher_model,
         model=model,
         args=training_args,
         train_dataset=train_dataset if training_args.do_train else None,
         eval_dataset=eval_dataset if training_args.do_eval else None,
-        compute_metrics=compute_metrics,
         tokenizer=tokenizer,
         data_collator=data_collator,
+        compute_metrics=compute_metrics,
     )

+    # Apply recipes to the model. This is necessary because sparsification methods
+    # such as QAT modify the model graph with their own learnable parameters,
+    # which also need to be restored/loaded into the model.
+    trainer.apply_recipes()
+
     # Training
     if training_args.do_train:
         checkpoint = None
@@ -536,6 +582,12 @@ def compute_metrics(p: EvalPrediction):

         trainer.push_to_hub(**kwargs)

+    if data_args.onnx_export_path:
+        logger.info("*** Export to ONNX ***")
+        eval_dataloader = trainer.get_eval_dataloader(eval_dataset)
+        exporter = GLUEModuleExporter(model, output_dir=data_args.onnx_export_path)
+        export_model(exporter, eval_dataloader, data_args.onnx_export_path, data_args.num_exported_samples)
+

 def _mp_fn(index):
     # For xla_spawn (TPUs)
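
The new recipe, distill_teacher, onnx_export_path, and num_exported_samples fields are exposed through HfArgumentParser, so they become command-line flags alongside the existing run_glue.py arguments. A hypothetical invocation (the model name, task, and paths are placeholders, not values from this commit):

python run_glue.py \
  --model_name_or_path bert-base-uncased \
  --task_name sst2 \
  --do_train \
  --do_eval \
  --distill_teacher /path/to/trained-teacher \
  --recipe /path/to/recipe.yaml \
  --onnx_export_path ./onnx-export \
  --num_exported_samples 20 \
  --output_dir ./sparse-glue-output

With these flags, the script loads the teacher for distillation, applies the SparseML recipes before training, and exports the final model to ONNX when --onnx_export_path is set.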
examples/pytorch/text-classification/sparseml_utils.py

Lines changed: 67 additions & 0 deletions
@@ -0,0 +1,67 @@
+from typing import Any
+
+import numpy
+import torch
+
+from sparseml.pytorch.utils import ModuleExporter, device_of
+
+from transformers.modeling_outputs import SequenceClassifierOutput
+from transformers.sparse import SparseMLTrainer
+
+
+class SparseMLGLUETrainer(SparseMLTrainer):
+    """
+    GLUE trainer with SparseML integration
+
+    :param recipe: recipe for model sparsification
+    :param teacher: teacher model for distillation
+    :param distill_hardness: ratio of loss by teacher targets (between 0 and 1)
+    :param distill_temperature: temperature for distillation
+    :param args, kwargs: arguments passed into parent class
+    """
+
+    def compute_loss(self, model, inputs, return_outputs=False):
+        """
+        Computing loss using teacher/student distillation
+        """
+        if not self.recipes or self.teacher is None:
+            return super().compute_loss(model, inputs, return_outputs=return_outputs)
+
+        student_outputs = model(**inputs)
+        loss = student_outputs["loss"]
+
+        target_device = device_of(inputs)
+        self.teacher.to(target_device)
+        with torch.no_grad():
+            teacher_outputs = self.teacher(
+                input_ids=inputs["input_ids"],
+                token_type_ids=inputs["token_type_ids"],
+                attention_mask=inputs["attention_mask"],
+            )
+        steps_in_epoch = -1  # Unused
+        loss = self.manager.loss_update(
+            loss,
+            model,
+            self.optimizer,
+            self.state.epoch,
+            steps_in_epoch,
+            global_step=self.state.global_step,
+            student_outputs=student_outputs,
+            teacher_outputs=teacher_outputs,
+        )
+        return (loss, student_outputs) if return_outputs else loss
+
+
+class GLUEModuleExporter(ModuleExporter):
+    """
+    Module exporter class for Sequence Classification
+    """
+
+    @classmethod
+    def get_output_names(self, out: Any):
+        if not isinstance(out, SequenceClassifierOutput):
+            raise ValueError(f"Expected SequenceClassifierOutput, got {type(out)}")
+        expected = ["logits"]
+        if numpy.any([name for name in expected if name not in out]):
+            raise ValueError("Expected output names not found in model output")
+        return expected
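
GLUEModuleExporter.get_output_names only validates that the model's forward output is a SequenceClassifierOutput containing logits, and returns the expected output names for the exported graph. A small illustrative check, assuming torch and transformers are installed (the tensor shape is arbitrary):

import torch
from transformers.modeling_outputs import SequenceClassifierOutput

# Dummy classifier output: a batch of one example with two labels.
dummy_out = SequenceClassifierOutput(logits=torch.randn(1, 2))
print(GLUEModuleExporter.get_output_names(dummy_out))  # -> ["logits"]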
