import math

import torch
import torch.nn.functional as F

from sparseml.pytorch.optim.manager import ScheduledModifierManager
from sparseml.pytorch.optim.optimizer import ScheduledOptimizer
from sparseml.pytorch.utils import ModuleExporter, logger
from trainer_qa import QuestionAnsweringTrainer


class SparseMLQATrainer(QuestionAnsweringTrainer):
    """
    Question Answering trainer with SparseML integration

    :param recipe: recipe for model sparsification
    :param teacher: teacher model for distillation
    :param distill_hardness: weight of the teacher (distillation) loss in the total loss, between 0 and 1
    :param distill_temperature: temperature for distillation
    :param args, kwargs: arguments passed to the parent class
    """

    def __init__(self, recipe, teacher=None, distill_hardness=0.5, distill_temperature=2.0, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.recipe = recipe
        self.teacher = teacher
        self.distill_hardness = distill_hardness
        self.distill_temperature = distill_temperature
        self.criterion = torch.nn.CrossEntropyLoss()

        self.manager = None
        self.loggers = None
        if self.recipe is not None:
            loggers = []
            if "wandb" in self.args.report_to:
                loggers.append(logger.WANDBLogger())
            self.loggers = loggers

    def create_optimizer(self):
        """
        Create an optimizer customized using SparseML
        """
        super().create_optimizer()
        if self.recipe is None:
            return
        # Use the public n_gpu property and guard against 0 GPUs so the
        # steps-per-epoch computation also works on CPU-only runs
        steps_per_epoch = math.ceil(
            len(self.train_dataset) / (self.args.per_device_train_batch_size * max(1, self.args.n_gpu))
        )
        self.manager = ScheduledModifierManager.from_yaml(self.recipe)
        # The recipe drives the sparsification schedule, so it also determines the number of epochs
        self.args.num_train_epochs = float(self.manager.max_epochs)
        if hasattr(self, "scaler"):
            # Mixed precision: initialize the manager and let it wrap the existing grad scaler
            self.manager.initialize(self.model, epoch=0.0, loggers=self.loggers)
            self.scaler = self.manager.modify(
                self.model, self.optimizer, steps_per_epoch=steps_per_epoch, wrap_optim=self.scaler
            )
        else:
            self.optimizer = ScheduledOptimizer(
                self.optimizer, self.model, self.manager, steps_per_epoch=steps_per_epoch, loggers=self.loggers
            )

    def compute_loss(self, model, inputs, return_outputs=False):
        """
        Compute loss using teacher/student distillation
        """
        if self.recipe is None or self.teacher is None:
            return super().compute_loss(model, inputs, return_outputs=return_outputs)

        outputs = model(**inputs)
        if self.teacher is None:
            loss = outputs["loss"]
        else:
            input_device = inputs["input_ids"].device
            self.teacher = self.teacher.to(input_device)
            start_logits_student = outputs["start_logits"]
            end_logits_student = outputs["end_logits"]
            start_logits_label = inputs["start_positions"]
            end_logits_label = inputs["end_positions"]
            # Teacher forward pass, no gradients needed
            with torch.no_grad():
                teacher_output = self.teacher(
                    input_ids=inputs["input_ids"],
                    token_type_ids=inputs["token_type_ids"],
                    attention_mask=inputs["attention_mask"],
                )
            start_logits_teacher = teacher_output["start_logits"]
            end_logits_teacher = teacher_output["end_logits"]
            # Soft (distillation) loss: temperature-scaled KL divergence between
            # student and teacher start/end logits, scaled by T^2
            loss_start = (
                F.kl_div(
                    input=F.log_softmax(start_logits_student / self.distill_temperature, dim=-1),
                    target=F.softmax(start_logits_teacher / self.distill_temperature, dim=-1),
                    reduction="batchmean",
                )
                * (self.distill_temperature ** 2)
            )
            loss_end = (
                F.kl_div(
                    input=F.log_softmax(end_logits_student / self.distill_temperature, dim=-1),
                    target=F.softmax(end_logits_teacher / self.distill_temperature, dim=-1),
                    reduction="batchmean",
                )
                * (self.distill_temperature ** 2)
            )
            teacher_loss = (loss_start + loss_end) / 2.0
            # Hard (label) loss: cross-entropy against the ground-truth positions
            loss_start = self.criterion(start_logits_student, start_logits_label)
            loss_end = self.criterion(end_logits_student, end_logits_label)
            label_loss = (loss_start + loss_end) / 2.0
            # loss = (1 - hardness) * label_loss + hardness * teacher_loss
            loss = ((1 - self.distill_hardness) * label_loss) + (self.distill_hardness * teacher_loss)
        return (loss, outputs) if return_outputs else loss


def export_model(model, dataloader, output_dir):
    """
    Export a trained model to ONNX

    :param model: trained model
    :param dataloader: dataloader to get a sample batch
    :param output_dir: output directory for the ONNX model
    """
    exporter = ModuleExporter(model, output_dir=output_dir)
    # Use a single batch from the dataloader as the sample input for export
    for _, sample_batch in enumerate(dataloader):
        sample_input = (sample_batch["input_ids"], sample_batch["attention_mask"], sample_batch["token_type_ids"])
        exporter.export_onnx(sample_batch=sample_input, convert_qat=True)
        break
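# --------------------------------------------------------------------------- #
# Usage sketch (illustrative only). The names below -- `model`, `training_args`,
# `train_dataset`, `eval_dataset`, `eval_examples`, `tokenizer`,
# `post_processing_function`, `teacher_model`, and "recipe.yaml" -- are
# assumptions standing in for objects normally built in a run_qa.py-style
# script; they are not defined in this module.
#
#   trainer = SparseMLQATrainer(
#       "recipe.yaml",                      # SparseML sparsification recipe
#       teacher=teacher_model,              # optional distillation teacher
#       distill_hardness=0.5,
#       distill_temperature=2.0,
#       model=model,
#       args=training_args,
#       train_dataset=train_dataset,
#       eval_dataset=eval_dataset,
#       eval_examples=eval_examples,
#       tokenizer=tokenizer,
#       post_process_function=post_processing_function,
#   )
#   trainer.train()
#   export_model(trainer.model, trainer.get_eval_dataloader(), training_args.output_dir)
# --------------------------------------------------------------------------- #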