@@ -1,8 +1,8 @@
-import inspect
 import collections
+import inspect
 import math
 import os
-from typing import Any
+from typing import Any, Optional

 import numpy
 import torch
@@ -13,6 +13,7 @@
 from sparseml.pytorch.optim.optimizer import ScheduledOptimizer
 from sparseml.pytorch.utils import ModuleExporter, logger
 from trainer_qa import QuestionAnsweringTrainer
+from transformers.file_utils import RECIPE_NAME, WEIGHTS_NAME
 from transformers.modeling_outputs import QuestionAnsweringModelOutput
 from transformers.models.bert.modeling_bert import BertForQuestionAnswering

@@ -28,36 +29,74 @@ class SparseMLQATrainer(QuestionAnsweringTrainer):
     :param args, kwargs: arguments passed into parent class
     """

-    def __init__(self, recipe, teacher=None, distill_hardness=0.5, distill_temperature=2.0, *args, **kwargs):
+    def __init__(
+        self, model_name_or_path, recipes, teacher=None, distill_hardness=0.5, distill_temperature=2.0, *args, **kwargs
+    ):
         super().__init__(*args, **kwargs)
-        self.recipe = recipe
+        self.model_name_or_path = str(model_name_or_path)
+        self.recipes = [recipe for recipe in recipes if recipe]
         self.teacher = teacher
         self.distill_hardness = distill_hardness
         self.distill_temperature = distill_temperature
         self.criterion = torch.nn.CrossEntropyLoss()

-        self.manager = None
+        manager = None
+        modifiers = []
+        for recipe in self.recipes:
+            manager = ScheduledModifierManager.from_yaml(recipe, modifiers)
+            modifiers = manager.modifiers
+        self.manager = manager
+
         self.loggers = None
-        if self.recipe is not None:
+        if self.recipes:
             loggers = []
             if "wandb" in self.args.report_to:
                 loggers.append(logger.WANDBLogger())
             self.loggers = loggers

+    def apply_recipes(self, epoch=0.0):
+        """
+        Apply recipes and sparsification-related parameters to the model
+        """
+        if self.manager is not None:
+            org_state_dict = self.model.state_dict()
+            self.manager.initialize(self.model, epoch=epoch, loggers=self.loggers)
+            new_state_dict = self.model.state_dict()
+            new_params = [p for p in new_state_dict.keys() if p not in org_state_dict]
+
+            if os.path.isdir(self.model_name_or_path):
+                if os.path.isfile(os.path.join(self.model_name_or_path, WEIGHTS_NAME)):
+                    archive_file = os.path.join(self.model_name_or_path, WEIGHTS_NAME)
+                    state_dict = torch.load(archive_file, map_location="cpu")
+                    new_params_to_init = [p for p in new_params if p in state_dict.keys()]
+                    if new_params_to_init:
+                        # If we're here, the assumption is that all the new parameters introduced
+                        # by the recipes are available to be restored from the checkpoint; this is
+                        # the case when evaluating pruned or pruned-quantized models.
+                        # Otherwise, we're in a use case such as quantizing a block-pruned model, where
+                        # new parameters need to be initialized and trained during the QAT process.
+                        _, missing_keys, unexpected_keys, _ = BertForQuestionAnswering._load_state_dict_into_model(
+                            self.model, state_dict, self.model_name_or_path, _fast_init=False
+                        )
+                        if missing_keys or unexpected_keys:
+                            raise RuntimeError(
+                                "Unexpected or missing keys detected when applying recipes to models\n"
+                                f"Missing keys: {missing_keys}\n"
+                                f"Unexpected keys: {unexpected_keys}\n"
+                            )
+
     def create_optimizer(self):
         """
         Create optimizer customized using SparseML
         """
         super().create_optimizer()
-        if self.recipe is None:
+        if not self.recipes:
             return
         steps_per_epoch = math.ceil(
             len(self.train_dataset) / (self.args.per_device_train_batch_size * self.args._n_gpu)
         )
-        self.manager = ScheduledModifierManager.from_yaml(self.recipe)
         self.args.num_train_epochs = float(self.manager.max_epochs)
         if hasattr(self, "scaler"):
-            self.manager.initialize(self.model, epoch=0.0, loggers=self.loggers)
             self.scaler = self.manager.modify(
                 self.model, self.optimizer, steps_per_epoch=steps_per_epoch, wrap_optim=self.scaler
             )
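Taken together, the constructor now composes all supplied recipes into a single ScheduledModifierManager by threading the accumulated modifiers through successive from_yaml calls, and apply_recipes must run before training so that recipe-introduced parameters (e.g., quantization wrappers) exist on the model. A minimal usage sketch; the paths, student_model, teacher_model, and training_args below are hypothetical, not part of this change:

trainer = SparseMLQATrainer(
    "models/bert-base-squad",                         # hypothetical model_name_or_path
    ["recipes/prune.yaml", "recipes/quantize.yaml"],  # composed in order; later recipes see earlier modifiers
    teacher=teacher_model,        # optional distillation teacher
    distill_hardness=0.5,         # weight of the teacher term in the loss blend
    distill_temperature=2.0,      # softmax temperature for distillation
    model=student_model,          # forwarded to QuestionAnsweringTrainer
    args=training_args,
)
trainer.apply_recipes()  # initializes the manager; restores recipe-introduced params from the checkpoint when present
trainer.train()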
@@ -70,7 +109,7 @@ def compute_loss(self, model, inputs, return_outputs=False):
70109 """
71110 Computing loss using teacher/student distillation
72111 """
73- if self .recipe is None or self .teacher is None :
112+ if not self .recipes or self .teacher is None :
74113 return super ().compute_loss (model , inputs , return_outputs = return_outputs )
75114
76115 outputs = model (** inputs )
@@ -114,11 +153,25 @@ def compute_loss(self, model, inputs, return_outputs=False):
         loss = ((1 - self.distill_hardness) * label_loss) + (self.distill_hardness * teacher_loss)
         return (loss, outputs) if return_outputs else loss

+    def save_model(self, output_dir: Optional[str] = None):
+        """
+        Save model during or after training. The sparsification recipe will also be saved.
+        """
+        super().save_model(output_dir=output_dir)
+        if self.manager is not None:
+            self._save_recipe(output_dir=output_dir)
+
+    def _save_recipe(self, output_dir: Optional[str] = None):
+        output_dir = output_dir if output_dir is not None else self.args.output_dir
+        output_recipe_file = os.path.join(output_dir, RECIPE_NAME)
+        self.manager.save(output_recipe_file)
+

 class QuestionAnsweringModuleExporter(ModuleExporter):
     """
     Module exporter class for Question Answering
     """
+
     @classmethod
     def get_output_names(self, out: Any):
         if not isinstance(out, QuestionAnsweringModelOutput):
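The hunk above shows only the tail of the distillation path: the final blend weights the hard-label loss against a teacher loss by distill_hardness. The teacher term itself is computed outside this hunk, so the following is an illustrative sketch of the standard temperature-scaled soft-target loss, not the committed code:

import torch.nn.functional as F

def soft_target_loss(student_logits, teacher_logits, temperature):
    # KL divergence between temperature-softened distributions, scaled by T**2
    # so gradient magnitudes stay comparable across temperatures.
    return F.kl_div(
        F.log_softmax(student_logits / temperature, dim=-1),
        F.softmax(teacher_logits / temperature, dim=-1),
        reduction="batchmean",
    ) * (temperature ** 2)

# Blended as in compute_loss above:
# loss = (1 - distill_hardness) * label_loss + distill_hardness * teacher_loss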
@@ -173,3 +226,44 @@ def export_model(model, dataloader, output_dir, num_exported_samples):
             num_samples += 1
             if num_samples >= num_exported_samples:
                 return
+
+
+def preprocess_state_dict(pretrained_model_name_or_path):
+    """
+    Restore original parameter names that were changed by the QAT process
+    """
+    state_dict = None
+    if pretrained_model_name_or_path is not None:
+        pretrained_model_name_or_path = str(pretrained_model_name_or_path)
+        if os.path.isdir(pretrained_model_name_or_path):
+            if os.path.isfile(os.path.join(pretrained_model_name_or_path, RECIPE_NAME)):
+                recipe = os.path.join(pretrained_model_name_or_path, RECIPE_NAME)
+                manager = ScheduledModifierManager.from_yaml(recipe)
+                modifiers = [m.__class__.__name__ for m in manager.modifiers]
+                is_qat_recipe = "QuantizationModifier" in modifiers
+                if os.path.isfile(os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)):
+                    archive_file = os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)
+                    state_dict = torch.load(archive_file, map_location="cpu")
+                    removed_keys = (
+                        [key for key in state_dict if (key.endswith(".module.weight") or key.endswith(".module.bias"))]
+                        if is_qat_recipe
+                        else []
+                    )
+                    for key in removed_keys:
+                        new_key = key.replace(".module", "")
+                        state_dict[new_key] = state_dict[key]
+                        state_dict.pop(key)
+    return state_dict
+
+
+def load_recipe(pretrained_model_name_or_path):
+    """
+    Load the recipe from the model directory
+    """
+    recipe = None
+    if pretrained_model_name_or_path is not None:
+        pretrained_model_name_or_path = str(pretrained_model_name_or_path)
+        if os.path.isdir(pretrained_model_name_or_path):
+            if os.path.isfile(os.path.join(pretrained_model_name_or_path, RECIPE_NAME)):
+                recipe = os.path.join(pretrained_model_name_or_path, RECIPE_NAME)
+    return recipe