3 changes: 2 additions & 1 deletion examples/pytorch/question-answering/run_qa.py
@@ -46,6 +46,8 @@
from transformers.utils import check_min_version
from utils_qa import postprocess_qa_predictions

import wandb
wandb.init(project="sparse-transfer-downstream-qa-daniel")

# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.7.0.dev0")
@@ -217,7 +219,6 @@ def __post_init__(self):
extension = self.test_file.split(".")[-1]
assert extension in ["csv", "json"], "`test_file` should be a csv or a json file."


def main():
# See all possible arguments in src/transformers/training_args.py
# or by passing the --help flag to this script.
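The run_qa.py hunk above starts a Weights & Biases run unconditionally at import time with a hard-coded project name. As a sketch only (the environment variables and the rank check are assumptions, not part of this diff), a more guarded initialization could look like:

import os

import wandb

# Hypothetical guard: only start a W&B run when explicitly enabled, and only on the
# main process so distributed workers do not each open their own run.
if os.environ.get("WANDB_ENABLED", "0") == "1" and int(os.environ.get("LOCAL_RANK", "0")) == 0:
    wandb.init(project="sparse-transfer-downstream-qa-daniel")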
38 changes: 25 additions & 13 deletions examples/pytorch/question-answering/sparseml_utils.py
@@ -20,32 +20,44 @@ class SparseMLQATrainer(SparseMLTrainer, QuestionAnsweringTrainer):
:param distill_temperature: temperature for distillation
:param args, kwargs: arguments passed into parent class
"""

def compute_loss(self, model, inputs, return_outputs=False):
"""
Computing loss using teacher/student distillation
"""
if not self.recipes or self.teacher is None:
if not self.recipes and self.teachers is None:
return super().compute_loss(model, inputs, return_outputs=return_outputs)

outputs = model(**inputs)
if self.teacher is None:
if self.teachers is None:
loss = outputs["loss"]
else:
input_device = inputs["input_ids"].device
self.teacher = self.teacher.to(input_device)
start_logits_student = outputs["start_logits"]
end_logits_student = outputs["end_logits"]
start_logits_label = inputs["start_positions"]
end_logits_label = inputs["end_positions"]
with torch.no_grad():
teacher_output = self.teacher(
input_ids=inputs["input_ids"],
token_type_ids=inputs["token_type_ids"],
attention_mask=inputs["attention_mask"],
)
start_logits_teacher = teacher_output["start_logits"]
end_logits_teacher = teacher_output["end_logits"]
if self.multi_gpu:
input_ids = torch.split(inputs['input_ids'], int(inputs['input_ids'].shape[0]/self.num_gpus))
# Accumulate the gathered teacher logits as floats, matching the teacher outputs.
start_logits_teacher = torch.empty((0, inputs['input_ids'].shape[1]), dtype=torch.float32, device='cuda')
end_logits_teacher = torch.empty((0, inputs['input_ids'].shape[1]), dtype=torch.float32, device='cuda')
for i in range(self.num_gpus):
with torch.no_grad():
input_device = self.teachers[i].device
teacher_output = self.teachers[i](input_ids[i].to(input_device))
start_logits_teacher = torch.cat((start_logits_teacher, teacher_output["start_logits"].to('cuda')), dim=0)
end_logits_teacher = torch.cat((end_logits_teacher, teacher_output["end_logits"].to('cuda')), dim=0)
else: # CPU or single GPU
input_device = inputs["input_ids"].device
self.teachers = self.teachers.to(input_device)
with torch.no_grad():
teacher_output = self.teachers(
input_ids=inputs["input_ids"],
token_type_ids=inputs["token_type_ids"],
attention_mask=inputs["attention_mask"],
)
start_logits_teacher = teacher_output["start_logits"]
end_logits_teacher = teacher_output["end_logits"]

loss_start = (
F.kl_div(
input=F.log_softmax(start_logits_student / self.distill_temperature, dim=-1),
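The sparseml_utils.py hunk above is truncated right after the start-logits KL term. For reference, a self-contained sketch of how the start/end distillation terms are typically combined with the hard loss via distill_temperature and distill_hardness; the averaging of the two terms and the helper name are assumptions that mirror the diff, not a verbatim copy of the hidden lines:

import torch
import torch.nn.functional as F

def qa_distillation_loss(start_student, end_student, start_teacher, end_teacher,
                         hard_loss, temperature=2.0, hardness=1.0):
    # Temperature-softened KL divergence for each logit head, scaled by T^2
    # as in standard knowledge distillation.
    loss_start = F.kl_div(
        input=F.log_softmax(start_student / temperature, dim=-1),
        target=F.softmax(start_teacher / temperature, dim=-1),
        reduction="batchmean",
    ) * (temperature ** 2)
    loss_end = F.kl_div(
        input=F.log_softmax(end_student / temperature, dim=-1),
        target=F.softmax(end_teacher / temperature, dim=-1),
        reduction="batchmean",
    ) * (temperature ** 2)
    teacher_loss = (loss_start + loss_end) / 2.0
    # Blend the soft (teacher) loss with the ordinary label loss.
    return (1 - hardness) * hard_loss + hardness * teacher_loss

# Toy example: batch of 4, sequence length 128.
student = (torch.randn(4, 128), torch.randn(4, 128))
teacher = (torch.randn(4, 128), torch.randn(4, 128))
print(qa_distillation_loss(*student, *teacher, hard_loss=torch.tensor(1.5)))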
69 changes: 63 additions & 6 deletions examples/pytorch/text-classification/run_glue.py
@@ -23,6 +23,7 @@
from dataclasses import dataclass, field
from typing import Optional

import wandb
import numpy as np
from datasets import load_dataset, load_metric

@@ -40,13 +41,13 @@
default_data_collator,
set_seed,
)

from sparseml_utils import GLUEModuleExporter
from transformers.sparse import export_model, SparseMLTrainer, load_recipe, preprocess_state_dict
from transformers.trainer_utils import get_last_checkpoint
from transformers.utils import check_min_version


# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.7.0.dev0")

task_to_keys = {
"cola": ("sentence", None),
"mnli": ("premise", "hypothesis"),
@@ -71,7 +72,17 @@ class DataTrainingArguments:
into argparse arguments to be able to specify them on
the command line.
"""

recipe: Optional[str] = field(
default=None,
metadata={"help": "Path to a SparseML sparsification recipe, see https://github.com/neuralmagic/sparseml "
"for more information"},
)
onnx_export_path: Optional[str] = field(
default=None, metadata={"help": "File path where the exported ONNX model will be written"}
)
num_exported_samples: Optional[int] = field(
default=20, metadata={"help": "Number of samples to export, defaults to 20"}
)
task_name: Optional[str] = field(
default=None,
metadata={"help": "The name of the task to train on: " + ", ".join(task_to_keys.keys())},
@@ -155,6 +166,15 @@ class ModelArguments:
model_name_or_path: str = field(
metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"}
)
distill_teacher: Optional[str] = field(
default=None, metadata={"help": "Teacher model for distillation; must be a trained text classification model"}
)
distill_temperature: Optional[float] = field(
default=2.0, metadata={"help": "Temperature applied to teacher softmax for distillation."}
)
distill_hardness: Optional[float] = field(
default=1.0, metadata={"help": "Proportion of loss coming from teacher model."}
)
config_name: Optional[str] = field(
default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"}
)
@@ -305,6 +325,13 @@ def main():
#
# In distributed training, the .from_pretrained methods guarantee that only one local process can concurrently
# download model & vocab.

# Load and preprocess the state dict if a trained model already exists (in this case we continue to
# train or evaluate the model). The preprocessing step restores the names of parameters changed by
# the QAT process.
state_dict = preprocess_state_dict(model_args.model_name_or_path)


config = AutoConfig.from_pretrained(
model_args.config_name if model_args.config_name else model_args.model_name_or_path,
num_labels=num_labels,
@@ -327,8 +354,19 @@ def main():
cache_dir=model_args.cache_dir,
revision=model_args.model_revision,
use_auth_token=True if model_args.use_auth_token else None,
state_dict=state_dict,
)

teacher_model = None
if model_args.distill_teacher is not None:
teacher_model = AutoModelForSequenceClassification.from_pretrained(
model_args.distill_teacher,
from_tf=bool(".ckpt" in model_args.distill_teacher),
cache_dir=model_args.cache_dir,
)
teacher_model_parameters = filter(lambda p: p.requires_grad, teacher_model.parameters())
params = sum([np.prod(p.size()) for p in teacher_model_parameters])
logger.info("Teacher Model has %s parameters", params)
# Preprocessing the datasets
if data_args.task_name is not None:
sentence1_key, sentence2_key = task_to_keys[data_args.task_name]
@@ -445,17 +483,31 @@ def compute_metrics(p: EvalPrediction):
else:
data_collator = None

# Load any recipe already saved with the model and the new one passed in via the command-line argument
existing_recipe = load_recipe(model_args.model_name_or_path)
new_recipe = data_args.recipe

# Initialize our Trainer
trainer = Trainer(
trainer = SparseMLTrainer(
model_args.model_name_or_path,
[existing_recipe, new_recipe],
teacher=teacher_model,
distill_hardness=model_args.distill_hardness,
distill_temperature=model_args.distill_temperature,
model=model,
args=training_args,
train_dataset=train_dataset if training_args.do_train else None,
eval_dataset=eval_dataset if training_args.do_eval else None,
compute_metrics=compute_metrics,
tokenizer=tokenizer,
data_collator=data_collator,
compute_metrics=compute_metrics,
)

# Apply recipes to the model. This is necessary because sparsification methods such as QAT
# modify the model graph and add their own learnable parameters, which are also
# restored/loaded into the model here.
trainer.apply_recipes()

# Training
if training_args.do_train:
checkpoint = None
@@ -536,6 +588,11 @@ def compute_metrics(p: EvalPrediction):

trainer.push_to_hub(**kwargs)

if data_args.onnx_export_path:
logger.info("*** Export to ONNX ***")
eval_dataloader = trainer.get_eval_dataloader(eval_dataset)
export_model(model, eval_dataloader, data_args.onnx_export_path, data_args.num_exported_samples)


def _mp_fn(index):
# For xla_spawn (TPUs)
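Once run_glue.py has exported the model, the file can be smoke-tested with onnxruntime (already imported by src/transformers/sparse.py below). A minimal sketch, assuming the export wrote model.onnx under the chosen onnx_export_path and that the graph exposes the usual BERT-style inputs:

import numpy as np
import onnxruntime

# Hypothetical location; in this script the directory comes from data_args.onnx_export_path.
session = onnxruntime.InferenceSession("onnx_export/model.onnx")

batch_size, seq_len = 1, 128
feed = {
    "input_ids": np.zeros((batch_size, seq_len), dtype=np.int64),
    "attention_mask": np.ones((batch_size, seq_len), dtype=np.int64),
    "token_type_ids": np.zeros((batch_size, seq_len), dtype=np.int64),
}
# Keep only the inputs the exported graph actually declares.
graph_inputs = {i.name for i in session.get_inputs()}
feed = {name: value for name, value in feed.items() if name in graph_inputs}

logits = session.run(None, feed)[0]
print(logits.shape)  # expected (batch_size, num_labels)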
23 changes: 23 additions & 0 deletions examples/pytorch/text-classification/sparseml_utils.py
@@ -0,0 +1,23 @@
from typing import Any

import numpy
import torch
import torch.nn.functional as F

from sparseml.pytorch.utils import ModuleExporter

from transformers.modeling_outputs import SequenceClassifierOutput

class GLUEModuleExporter(ModuleExporter):
"""
Module exporter class for Sequence Classification
"""

@classmethod
def get_output_names(cls, out: Any):
if not isinstance(out, SequenceClassifierOutput):
raise ValueError(f"Expected SequenceClassifierOutput, got {type(out)}")
expected = ["logits"]
if any(name not in out for name in expected):
raise ValueError(f"Expected output names {expected} not found in model output")
return expected
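A small usage sketch of the exporter's output-name check (the tensor shape is arbitrary; SequenceClassifierOutput behaves like an ordered dict, which is why the membership test on names works):

import torch
from transformers.modeling_outputs import SequenceClassifierOutput

out = SequenceClassifierOutput(logits=torch.randn(2, 3))
print(GLUEModuleExporter.get_output_names(out))  # ["logits"]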
59 changes: 58 additions & 1 deletion src/transformers/sparse.py
@@ -6,6 +6,7 @@

import copy

import numpy
import torch
import torch.nn.functional as F

import onnxruntime
from sparseml.pytorch.optim.manager import ScheduledModifierManager
@@ -32,7 +33,14 @@ def __init__(
super().__init__(*args, **kwargs)
self.model_name_or_path = str(model_name_or_path)
self.recipes = [recipe for recipe in recipes if recipe]
self.teacher = teacher
self.teachers = teacher
self.multi_gpu = False
if torch.cuda.device_count() > 1 and teacher is not None:
self.multi_gpu = True
self.num_gpus = torch.cuda.device_count()
# nn.Module.to() moves a module in place, so each GPU needs its own deep copy of the teacher.
self.teachers = [copy.deepcopy(teacher).to(i) for i in range(self.num_gpus)]
self.distill_hardness = distill_hardness
self.distill_temperature = distill_temperature
self.criterion = torch.nn.CrossEntropyLoss()
@@ -82,6 +90,54 @@ def apply_recipes(self, epoch=0.0):
f"Unexpected keys: {unexpected_keys}\n"
)

def compute_loss(self, model, inputs, return_outputs=False):
"""
Computing loss using teacher/student distillation
"""
if not self.recipes and self.teachers is None:
return super().compute_loss(model, inputs, return_outputs=return_outputs)

outputs = model(**inputs)
if self.teachers is None:
loss = outputs["loss"]
else:
logits_student = outputs["logits"]
if self.multi_gpu:
input_ids = torch.split(inputs['input_ids'], int(inputs['input_ids'].shape[0]/self.num_gpus))
token_type_ids = torch.split(inputs['token_type_ids'], int(inputs['token_type_ids'].shape[0]/self.num_gpus))
attention_mask = torch.split(inputs['attention_mask'], int(inputs['attention_mask'].shape[0]/self.num_gpus))
teacher_logits_chunks = []  # gather per-GPU teacher logits; shapes and dtypes come from the teacher outputs
for i in range(self.num_gpus):
with torch.no_grad():
input_device = self.teachers[i].device
teacher_output = self.teachers[i](
input_ids=input_ids[i].to(input_device),
token_type_ids=token_type_ids[i].to(input_device),
attention_mask=attention_mask[i].to(input_device)
)
teacher_logits_chunks.append(teacher_output["logits"].to('cuda'))
logits_teacher = torch.cat(teacher_logits_chunks, dim=0)
else: # CPU or single GPU
input_device = inputs["input_ids"].device
self.teachers = self.teachers.to(input_device)
with torch.no_grad():
teacher_output = self.teachers(
input_ids=inputs["input_ids"],
token_type_ids=inputs["token_type_ids"],
attention_mask=inputs["attention_mask"],
)
logits_teacher = teacher_output["start_logits"]

teacher_loss = (F.kl_div(
input=F.log_softmax(logits_student / self.distill_temperature, dim=-1),
target=F.softmax(logits_teacher / self.distill_temperature, dim=-1),
reduction="batchmean",
)
* (self.distill_temperature ** 2)
)

loss = ((1 - self.distill_hardness) * outputs["loss"]) + (self.distill_hardness * teacher_loss)
return (loss, outputs) if return_outputs else loss

def create_optimizer(self):
"""
Create optimizer customized using SparseML
@@ -102,6 +158,7 @@ def create_optimizer(self):
self.optimizer, self.model, self.manager, steps_per_epoch=steps_per_epoch, loggers=self.loggers
)


def create_scheduler(self, num_training_steps: int):
"""
Override LR scheduler if the SparseML manager has LR modifiers, otherwise
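The constructor change in sparse.py replicates the teacher once per visible GPU, and compute_loss then splits each batch across those copies before concatenating the logits. A toy, CPU-safe sketch of that replicate-split-gather pattern, with a small linear layer standing in for the teacher (not the repository's own code):

import copy

import torch
import torch.nn as nn

teacher = nn.Linear(8, 2)   # stand-in for the distillation teacher
batch = torch.randn(6, 8)   # stand-in for a tokenized input batch

num_replicas = max(torch.cuda.device_count(), 1)
# nn.Module.to() moves a module in place, so each replica must be a deep copy.
teachers = [copy.deepcopy(teacher) for _ in range(num_replicas)]
if torch.cuda.is_available():
    teachers = [t.to(i) for i, t in enumerate(teachers)]

chunks = torch.chunk(batch, num_replicas)
with torch.no_grad():
    logits = torch.cat(
        [t(c.to(next(t.parameters()).device)).cpu() for t, c in zip(teachers, chunks)],
        dim=0,
    )
print(logits.shape)  # (6, 2)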
1 change: 1 addition & 0 deletions src/transformers/trainer_utils.py
@@ -59,6 +59,7 @@ def set_seed(seed: int):
if is_torch_available():
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
torch.backends.cudnn.deterministic = True
# ^^ safe to call this function even if cuda is not available
if is_tf_available():
tf.random.set_seed(seed)
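The trainer_utils.py change flips cudnn.deterministic inside set_seed. For context, a self-contained sketch of a fully deterministic seeding helper is below; set_seed already covers the RNG seeding shown here, and the cudnn.benchmark = False line is the usual companion setting (an addition of this sketch, not of the diff):

import random

import numpy as np
import torch

def set_full_determinism(seed: int):
    # Seed every RNG the training loop may touch.
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)  # safe to call even without CUDA
    # Deterministic cuDNN kernels, and no benchmark autotuning that could
    # pick different kernels from run to run.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

set_full_determinism(42)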