This repository was archived by the owner on Jun 4, 2025. It is now read-only.

Commit 543a34e

Load and save SparseML QAT recipes (#2)
* Load and save SparseML QAT recipes
* Conditionally load state_dict after applying recipes; code clean up
1 parent b7df172 commit 543a34e

File tree

3 files changed: +129 -12 lines changed


examples/pytorch/question-answering/run_qa.py

Lines changed: 24 additions & 2 deletions
@@ -28,7 +28,12 @@
 from datasets import load_dataset, load_metric
 
 import transformers
-from sparseml_utils import SparseMLQATrainer, export_model
+from sparseml_utils import (
+    SparseMLQATrainer,
+    export_model,
+    preprocess_state_dict,
+    load_recipe
+)
 from transformers import (
     AutoConfig,
     AutoModelForQuestionAnswering,
@@ -311,13 +316,20 @@ def main():
         revision=model_args.model_revision,
         use_auth_token=True if model_args.use_auth_token else None,
     )
+
+    # Load and preprocess the state dict if the model already exists (in this case we continue to
+    # train or evaluate the model). The preprocessing step restores the names of parameters
+    # changed by the QAT process.
+    state_dict = preprocess_state_dict(model_args.model_name_or_path)
+
     model = AutoModelForQuestionAnswering.from_pretrained(
         model_args.model_name_or_path,
         from_tf=bool(".ckpt" in model_args.model_name_or_path),
         config=config,
         cache_dir=model_args.cache_dir,
         revision=model_args.model_revision,
         use_auth_token=True if model_args.use_auth_token else None,
+        state_dict=state_dict
     )
 
     teacher_model = None
@@ -573,9 +585,14 @@ def post_processing_function(examples, features, predictions, stage="eval"):
     def compute_metrics(p: EvalPrediction):
         return metric.compute(predictions=p.predictions, references=p.label_ids)
 
+    # Load a possibly existing recipe and the new one passed in through the command-line argument
+    existing_recipe = load_recipe(model_args.model_name_or_path)
+    new_recipe = data_args.recipe
+
     # Initialize our Trainer
     trainer = SparseMLQATrainer(
-        data_args.recipe,
+        model_args.model_name_or_path,
+        [existing_recipe, new_recipe],
         teacher=teacher_model,
         distill_hardness=model_args.distill_hardness,
         distill_temperature=model_args.distill_temperature,
@@ -590,6 +607,11 @@ def compute_metrics(p: EvalPrediction):
         compute_metrics=compute_metrics,
     )
 
+    # Apply recipes to the model. This is necessary because sparsification
+    # methods such as QAT modify the model graph with their own learnable
+    # parameters, which also have to be restored/loaded into the model.
+    trainer.apply_recipes()
+
     # Training
     if training_args.do_train:
         checkpoint = None
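Taken together, the run_qa.py changes chain four steps: preprocess the checkpoint's state dict, load the model from it, hand both the stored recipe and the newly requested one to the trainer, and apply the recipes before training starts. Below is a minimal sketch of that flow, assuming model_args, data_args, and the remaining Trainer keyword arguments are built as elsewhere in the script (the helper name build_sparseml_trainer is hypothetical):

def build_sparseml_trainer(model_args, data_args, **trainer_kwargs):
    from transformers import AutoModelForQuestionAnswering

    from sparseml_utils import SparseMLQATrainer, load_recipe, preprocess_state_dict

    # 1. Rename QAT-wrapped checkpoint keys so from_pretrained can match them.
    state_dict = preprocess_state_dict(model_args.model_name_or_path)
    model = AutoModelForQuestionAnswering.from_pretrained(
        model_args.model_name_or_path, state_dict=state_dict
    )

    # 2. Chain the recipe saved with the checkpoint (if any) with the newly
    #    passed one; the trainer drops None/empty entries itself.
    trainer = SparseMLQATrainer(
        model_args.model_name_or_path,
        [load_recipe(model_args.model_name_or_path), data_args.recipe],
        model=model,
        **trainer_kwargs,
    )

    # 3. Re-create recipe-added modules (e.g. QAT observers) and restore their weights.
    trainer.apply_recipes()
    return trainer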

examples/pytorch/question-answering/sparseml_utils.py

Lines changed: 104 additions & 10 deletions
@@ -1,8 +1,8 @@
-import inspect
 import collections
+import inspect
 import math
 import os
-from typing import Any
+from typing import Any, Optional
 
 import numpy
 import torch
@@ -13,6 +13,7 @@
 from sparseml.pytorch.optim.optimizer import ScheduledOptimizer
 from sparseml.pytorch.utils import ModuleExporter, logger
 from trainer_qa import QuestionAnsweringTrainer
+from transformers.file_utils import RECIPE_NAME, WEIGHTS_NAME
 from transformers.modeling_outputs import QuestionAnsweringModelOutput
 from transformers.models.bert.modeling_bert import BertForQuestionAnswering
 
@@ -28,36 +29,74 @@ class SparseMLQATrainer(QuestionAnsweringTrainer):
     :param args, kwargs: arguments passed into parent class
     """
 
-    def __init__(self, recipe, teacher=None, distill_hardness=0.5, distill_temperature=2.0, *args, **kwargs):
+    def __init__(
+        self, model_name_or_path, recipes, teacher=None, distill_hardness=0.5, distill_temperature=2.0, *args, **kwargs
+    ):
         super().__init__(*args, **kwargs)
-        self.recipe = recipe
+        self.model_name_or_path = str(model_name_or_path)
+        self.recipes = [recipe for recipe in recipes if recipe]
         self.teacher = teacher
         self.distill_hardness = distill_hardness
         self.distill_temperature = distill_temperature
         self.criterion = torch.nn.CrossEntropyLoss()
 
-        self.manager = None
+        manager = None
+        modifiers = []
+        for recipe in self.recipes:
+            manager = ScheduledModifierManager.from_yaml(recipe, modifiers)
+            modifiers = manager.modifiers
+        self.manager = manager
+
         self.loggers = None
-        if self.recipe is not None:
+        if self.recipes is not None:
             loggers = []
             if "wandb" in self.args.report_to:
                 loggers.append(logger.WANDBLogger())
             self.loggers = loggers
 
+    def apply_recipes(self, epoch=0.0):
+        """
+        Apply recipes and sparsification-related parameters to the model
+        """
+        if self.manager is not None:
+            org_state_dict = self.model.state_dict()
+            self.manager.initialize(self.model, epoch=epoch, loggers=self.loggers)
+            new_state_dict = self.model.state_dict()
+            new_params = [p for p in new_state_dict.keys() if p not in org_state_dict]
+
+            if os.path.isdir(self.model_name_or_path):
+                if os.path.isfile(os.path.join(self.model_name_or_path, WEIGHTS_NAME)):
+                    archive_file = os.path.join(self.model_name_or_path, WEIGHTS_NAME)
+                    state_dict = torch.load(archive_file, map_location="cpu")
+                    new_params_to_init = [p for p in new_params if p in state_dict.keys()]
+                    if new_params_to_init:
+                        # If we're here, the assumption is that all the new parameters introduced
+                        # by the recipes are available to be restored from the checkpoint---this is
+                        # the case when evaluating pruned or pruned-quantized models.
+                        # Otherwise, we're in a use case such as quantizing a block-pruned model, in
+                        # which new parameters need to be initialized and trained during the QAT process.
+                        _, missing_keys, unexpected_keys, _ = BertForQuestionAnswering._load_state_dict_into_model(
+                            self.model, state_dict, self.model_name_or_path, _fast_init=False
+                        )
+                        if missing_keys or unexpected_keys:
+                            raise RuntimeError(
+                                "Unexpected or missing keys detected when applying recipes to models\n"
+                                f"Missing keys: {missing_keys}\n"
+                                f"Unexpected keys: {unexpected_keys}\n"
+                            )
+
     def create_optimizer(self):
         """
         Create optimizer customized using SparseML
         """
         super().create_optimizer()
-        if self.recipe is None:
+        if not self.recipes:
            return
        steps_per_epoch = math.ceil(
            len(self.train_dataset) / (self.args.per_device_train_batch_size * self.args._n_gpu)
        )
-        self.manager = ScheduledModifierManager.from_yaml(self.recipe)
        self.args.num_train_epochs = float(self.manager.max_epochs)
        if hasattr(self, "scaler"):
-            self.manager.initialize(self.model, epoch=0.0, loggers=self.loggers)
            self.scaler = self.manager.modify(
                self.model, self.optimizer, steps_per_epoch=steps_per_epoch, wrap_optim=self.scaler
            )
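The loop added to __init__ above composes several recipes into a single ScheduledModifierManager: each from_yaml call parses one recipe and carries the previously accumulated modifiers forward, so the final manager holds the union of all of them. Here is a stubbed, sparseml-free sketch of just that composition logic (FakeManager and the recipe names are illustrative, not real API):

class FakeManager:
    """Stand-in for ScheduledModifierManager, holding only a modifier list."""

    def __init__(self, modifiers):
        self.modifiers = modifiers


def from_yaml(recipe, add_modifiers):
    # Mimics ScheduledModifierManager.from_yaml(recipe, modifiers): the resulting
    # manager contains the carried-over modifiers plus the new recipe's own.
    return FakeManager(list(add_modifiers) + [f"modifiers of {recipe}"])


manager, modifiers = None, []
for recipe in ["pruning.yaml", "quantization.yaml"]:
    manager = from_yaml(recipe, modifiers)
    modifiers = manager.modifiers

print(manager.modifiers)
# ['modifiers of pruning.yaml', 'modifiers of quantization.yaml']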
@@ -70,7 +109,7 @@ def compute_loss(self, model, inputs, return_outputs=False):
         """
         Computing loss using teacher/student distillation
         """
-        if self.recipe is None or self.teacher is None:
+        if not self.recipes or self.teacher is None:
             return super().compute_loss(model, inputs, return_outputs=return_outputs)
 
         outputs = model(**inputs)
@@ -114,11 +153,25 @@ def compute_loss(self, model, inputs, return_outputs=False):
         loss = ((1 - self.distill_hardness) * label_loss) + (self.distill_hardness * teacher_loss)
         return (loss, outputs) if return_outputs else loss
 
+    def save_model(self, output_dir: Optional[str] = None):
+        """
+        Save model during or after training. The sparsification recipe will also be saved.
+        """
+        super().save_model(output_dir=output_dir)
+        if self.manager is not None:
+            self._save_recipe(output_dir=output_dir)
+
+    def _save_recipe(self, output_dir: Optional[str] = None):
+        output_dir = output_dir if output_dir is not None else self.args.output_dir
+        output_recipe_file = os.path.join(output_dir, RECIPE_NAME)
+        self.manager.save(output_recipe_file)
+
 
 class QuestionAnsweringModuleExporter(ModuleExporter):
     """
     Module exporter class for Question Answering
     """
+
     @classmethod
     def get_output_names(self, out: Any):
         if not isinstance(out, QuestionAnsweringModelOutput):
@@ -173,3 +226,44 @@ def export_model(model, dataloader, output_dir, num_exported_samples):
             num_samples += 1
             if num_samples >= num_exported_samples:
                 return
+
+
+def preprocess_state_dict(pretrained_model_name_or_path):
+    """
+    Restore original parameter names that were changed by the QAT process
+    """
+    state_dict = None
+    if pretrained_model_name_or_path is not None:
+        pretrained_model_name_or_path = str(pretrained_model_name_or_path)
+        if os.path.isdir(pretrained_model_name_or_path):
+            if os.path.isfile(os.path.join(pretrained_model_name_or_path, RECIPE_NAME)):
+                recipe = os.path.join(pretrained_model_name_or_path, RECIPE_NAME)
+                manager = ScheduledModifierManager.from_yaml(recipe)
+                modifiers = [m.__class__.__name__ for m in manager.modifiers]
+                is_qat_recipe = "QuantizationModifier" in modifiers
+                if os.path.isfile(os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)):
+                    archive_file = os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)
+                    state_dict = torch.load(archive_file, map_location="cpu")
+                    removed_keys = (
+                        [key for key in state_dict if (key.endswith(".module.weight") or key.endswith(".module.bias"))]
+                        if is_qat_recipe
+                        else []
+                    )
+                    for key in removed_keys:
+                        new_key = key.replace(".module", "")
+                        state_dict[new_key] = state_dict[key]
+                        state_dict.pop(key)
+    return state_dict
+
+
+def load_recipe(pretrained_model_name_or_path):
+    """
+    Load the recipe from the model directory
+    """
+    recipe = None
+    if pretrained_model_name_or_path is not None:
+        pretrained_model_name_or_path = str(pretrained_model_name_or_path)
+        if os.path.isdir(pretrained_model_name_or_path):
+            if os.path.isfile(os.path.join(pretrained_model_name_or_path, RECIPE_NAME)):
+                recipe = os.path.join(pretrained_model_name_or_path, RECIPE_NAME)
+    return recipe
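The ".module" renaming in preprocess_state_dict undoes SparseML's QAT wrapping, which nests quantized layers one level deeper in the module tree, so their checkpoint keys no longer match the vanilla model. A self-contained toy run of just the renaming step (keys and values are made up):

# Checkpoint keys as a QAT run would store them; integers stand in for tensors.
state_dict = {
    "qa_outputs.module.weight": 0,
    "qa_outputs.module.bias": 1,
    "bert.encoder.layer.0.attention.self.query.weight": 2,  # untouched key
}

# Strip the extra ".module" level, as preprocess_state_dict does for QAT recipes.
for key in [k for k in state_dict if k.endswith((".module.weight", ".module.bias"))]:
    state_dict[key.replace(".module", "")] = state_dict.pop(key)

print(sorted(state_dict))
# ['bert.encoder.layer.0.attention.self.query.weight', 'qa_outputs.bias', 'qa_outputs.weight']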

src/transformers/file_utils.py

Lines changed: 1 addition & 0 deletions
@@ -220,6 +220,7 @@
 CONFIG_NAME = "config.json"
 FEATURE_EXTRACTOR_NAME = "preprocessor_config.json"
 MODEL_CARD_NAME = "modelcard.json"
+RECIPE_NAME = "recipe.yaml"
 
 SENTENCEPIECE_UNDERLINE = "▁"
 SPIECE_UNDERLINE = SENTENCEPIECE_UNDERLINE  # Kept for backward compatibility
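With RECIPE_NAME defined alongside the other well-known file names, a checkpoint directory written by SparseMLQATrainer.save_model carries the combined recipe next to the weights, and load_recipe probes for exactly this path. A small sketch, assuming this fork of transformers is installed (the directory name is hypothetical):

import os

from transformers.file_utils import RECIPE_NAME, WEIGHTS_NAME

checkpoint = "./qat-output"  # e.g. produced by trainer.save_model("./qat-output")
print(os.path.join(checkpoint, WEIGHTS_NAME))  # ./qat-output/pytorch_model.bin
print(os.path.join(checkpoint, RECIPE_NAME))   # ./qat-output/recipe.yaml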
