@@ -27,19 +27,20 @@
 from datasets import load_dataset, load_metric

 import transformers
+from sparseml_utils import GLUEModuleExporter, SparseMLGLUETrainer
 from transformers import (
     AutoConfig,
     AutoModelForSequenceClassification,
     AutoTokenizer,
     DataCollatorWithPadding,
     EvalPrediction,
     HfArgumentParser,
     PretrainedConfig,
-    Trainer,
     TrainingArguments,
     default_data_collator,
     set_seed,
 )
+from transformers.sparse import export_model, load_recipe, preprocess_state_dict
 from transformers.trainer_utils import get_last_checkpoint
 from transformers.utils import check_min_version

@@ -72,6 +73,19 @@ class DataTrainingArguments:
     the command line.
     """

+    recipe: Optional[str] = field(
+        default=None,
+        metadata={
+            "help": "Path to a SparseML sparsification recipe, see https://github.com/neuralmagic/sparseml "
+            "for more information"
+        },
+    )
+    onnx_export_path: Optional[str] = field(
+        default=None, metadata={"help": "The filename and path where the exported ONNX model will be saved"}
+    )
+    num_exported_samples: Optional[int] = field(
+        default=20, metadata={"help": "Number of samples to export, defaults to 20"}
+    )
     task_name: Optional[str] = field(
         default=None,
         metadata={"help": "The name of the task to train on: " + ", ".join(task_to_keys.keys())},
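
The --recipe flag above points at a SparseML recipe file. For orientation, a recipe is a YAML document listing sparsification modifiers; the sketch below (kept in a Python string, with a placeholder modifier and values that are not part of this patch; see the SparseML repository for the real schema) shows its general shape:

# A rough sketch of what the file passed via --recipe might contain. The
# modifier name and values are illustrative placeholders; consult the SparseML
# repository for the actual recipe schema.
EXAMPLE_RECIPE = """
modifiers:
    - !GMPruningModifier
        init_sparsity: 0.05
        final_sparsity: 0.85
        start_epoch: 0.0
        end_epoch: 10.0
        params: __ALL_PRUNABLE__
"""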
@@ -155,6 +169,9 @@ class ModelArguments:
     model_name_or_path: str = field(
         metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"}
     )
+    distill_teacher: Optional[str] = field(
+        default=None, metadata={"help": "Path to a trained text classification model to use as the distillation teacher"}
+    )
     config_name: Optional[str] = field(
         default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"}
     )
@@ -305,6 +322,12 @@ def main():
     #
     # In distributed training, the .from_pretrained methods guarantee that only one local process can concurrently
     # download model & vocab.
+
+    # Load and preprocess the state dict if the model already exists (in this case we continue to
+    # train or evaluate the model). The preprocessing step restores the names of parameters
+    # changed by the QAT process.
+    state_dict = preprocess_state_dict(model_args.model_name_or_path)
+
     config = AutoConfig.from_pretrained(
         model_args.config_name if model_args.config_name else model_args.model_name_or_path,
         num_labels=num_labels,
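
To make the comment above concrete: quantization-aware training typically wraps modules, so checkpoint keys can stop matching the vanilla model's parameter names. A minimal sketch of the kind of renaming preprocess_state_dict has to perform (the ".module." pattern and the helper below are illustrative assumptions, not the actual transformers.sparse implementation):

def restore_param_names(state_dict):
    # Map e.g. "classifier.module.weight" back to "classifier.weight" so that
    # AutoModelForSequenceClassification.from_pretrained can load the weights.
    return {name.replace(".module.", "."): tensor for name, tensor in state_dict.items()}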
@@ -327,8 +350,19 @@
         cache_dir=model_args.cache_dir,
         revision=model_args.model_revision,
         use_auth_token=True if model_args.use_auth_token else None,
+        state_dict=state_dict,
     )

+    teacher_model = None
+    if model_args.distill_teacher is not None:
+        teacher_model = AutoModelForSequenceClassification.from_pretrained(
+            model_args.distill_teacher,
+            from_tf=bool(".ckpt" in model_args.distill_teacher),
+            cache_dir=model_args.cache_dir,
+        )
+        teacher_model_parameters = filter(lambda p: p.requires_grad, teacher_model.parameters())
+        params = sum([np.prod(p.size()) for p in teacher_model_parameters])
+        logger.info("Teacher Model has %s parameters", params)
     # Preprocessing the datasets
     if data_args.task_name is not None:
         sentence1_key, sentence2_key = task_to_keys[data_args.task_name]
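
As a side note on the parameter count in the added lines above, the same number can be computed without NumPy, since Tensor.numel() already multiplies out the size tuple:

# Equivalent count of the teacher's trainable parameters, torch-only.
params = sum(p.numel() for p in teacher_model.parameters() if p.requires_grad)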
@@ -445,17 +479,29 @@ def compute_metrics(p: EvalPrediction):
     else:
         data_collator = None

+    # Load a possibly existing recipe and the new one passed in on the command line
+    existing_recipe = load_recipe(model_args.model_name_or_path)
+    new_recipe = data_args.recipe
+
     # Initialize our Trainer
-    trainer = Trainer(
+    trainer = SparseMLGLUETrainer(
+        model_args.model_name_or_path,
+        [existing_recipe, new_recipe],
+        teacher=teacher_model,
         model=model,
         args=training_args,
         train_dataset=train_dataset if training_args.do_train else None,
         eval_dataset=eval_dataset if training_args.do_eval else None,
-        compute_metrics=compute_metrics,
         tokenizer=tokenizer,
         data_collator=data_collator,
+        compute_metrics=compute_metrics,
     )

+    # Apply the recipes to the model. This is necessary because sparsification
+    # methods such as QAT modify the model graph with their own learnable
+    # parameters, which are also restored/loaded into the model here.
+    trainer.apply_recipes()
+
     # Training
     if training_args.do_train:
         checkpoint = None
@@ -536,6 +582,12 @@ def compute_metrics(p: EvalPrediction):

         trainer.push_to_hub(**kwargs)

+    if data_args.onnx_export_path:
+        logger.info("*** Export to ONNX ***")
+        eval_dataloader = trainer.get_eval_dataloader(eval_dataset)
+        exporter = GLUEModuleExporter(model, output_dir=data_args.onnx_export_path)
+        export_model(exporter, eval_dataloader, data_args.onnx_export_path, data_args.num_exported_samples)
+
 def _mp_fn(index):
     # For xla_spawn (TPUs)
     main()
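
Once the export has run, a quick sanity check with onnxruntime might look like the following (assuming the exporter writes a model.onnx file under the directory given by --onnx_export_path; the exact output filename is not shown in this diff):

import onnxruntime as ort

# "onnx_out/model.onnx" is a placeholder path for the exported model.
session = ort.InferenceSession("onnx_out/model.onnx")
# Inspect the graph inputs that the exported samples are expected to match.
print([inp.name for inp in session.get_inputs()])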