Pytorch-lightning framework for knowledge distillation experiments with CNN.
Данный репозиторий содержит в себе решение вступительного испытания в VK lab.
Задание заключается в том, чтобы придумать и поставить эксперименты показывающие
работоспособность метода Knowledge Distillation.
Данную работу я решил сделать в виде цельного, расширяемого репозитория.
Мне это кажется важным, так как на мой взгляд многие
разработки VK lab имеют шанс быть использованными в продакшене. Для достижения этих целей я
использовал популярный фреймворки pytorch-lightning
и W&B. Также я настроил базовый CI пайплайн
с помощью Github Actions и сохранил большие файлы с checkpoint моделей
при помощи github-lfs.
Цель данной работы заключается в том, чтобы изучить некоторые существующие подходы
Knowledge Distillation, а также реализовать фреймворк, который можно
было бы в последствии расширить. Я решил использовать датасет cifar10
из-за того, что он есть в torchvision и он простой, все эксперименты
были сделаны над моделями Resnet. В качестве студента вступала модель Resnet18, а
в качестве учителя модель Resnet50
Работу можно разделить на несколько этапов:
Для начала я решил реализовать классы
BaseCifarModelиSingleCifarModel, а также моделиResnet18иResnet50и обучить их наcifar10. Каждую модель я обучил какfrom scratch, так и c заморозкой нескольких слоев. Первый мой эксперимент был в том, что я вResnetзаморозил все слои кроме классификатора, но по итогу я получил низкое значениеaccuracy(порядка 85%). Далее я наткнулся на статью на kaggle и воспользовался оттуда двумя советами:
- Увеличить размер входной картинки (поэкспериментировав, я остановился на
128x128). Так как исходная модель была обучена наImageNet(размер картинок в котором224x224), то такой размер будет для модели более приятным- Разморозить
BatchNormalization. Данное улучшение так же кажется логичным, так как свертки обучены наImageNetдостаточно хорошо, а вот масштабы и локализация объектов вcifar10может немного отличать отImageNet.
Далее я реализовал
DistillationCifarModelиlossфункцию из следующей статьи. Результаты получились хорошими не сразу, достаточно много времени пришлось потратить на подбор гиперпараметеров.
Я имею опыт работы с
GANи поэтому мне показалось хорошей идея использовать своего рода дискриминатор дляlogitsмоделей. Дискриминатор принимает в себяlogitsстудента и выдает вероятность того, что данныеlogitsполучены от учителя. При этом дискриминатор и студент учатся по очереди, я тестировал разные стратегии, к примеру следующую: 30 шагов учится дискриминатор, а следующие 170 студент. В качествеlossя использовал популярныйWasserstein lossРеализацию можно найти в классеLogitsDiscriminatorCifarModel.
Наконец, не получив ожидаемого результата с
GAN, я решил воспроизвести результат статьи RKD, к счастью тут я получил более позитивные результаты.
Затем мне стало интересно сравнить куда смотрят те или иные сети, когда делают свои предсказания и насколько сильно "взгляд" студента похож на взгляд учителя при использовании
Knowledge Distillation. Для этого я воспользовался подходом grad-cam, реализованным в библиотекеgradcam.
Запускать следующие команды необходимо из директории cnn-distillation
Перед запуском экспериментов необходимо установить все необходимые библиотеки, это можно сделать при помощи команды.
sh scripts/build.sh
Все эксперименты можно запустить, используя команду
sh scripts/train.sh
Если необходимо запустить обучения какого-то конкретного эксперимента необходимо выполнить команду
python train.py <experiment name>Также можно добавить флаг
--unfrozen, чтобы выбрать модель со всеми размороженными слоями
Названия эксперимнтов следующие:
- Обучение студента без учителя:
student- Обучение учителя:
teacher- Обучение студента c учителем с использованием KD Loss из статьи:
kd_distillation- Обучение студента c учителем с использованием RKD Distance Loss из статьи:
rkdd_distillation- Обучение студента c учителем с использованием Logits Discriminator Loss:
ld_distillation
Обученные модели сохранены в папке
models/checkpoints. Чтобы получить оценки качества всех моделей необходимо выполнить командуsh scripts/eval.shЕсли необходимы метрики какого-то конкретного эксперимента нужно выполнить команду
python eval.py <path to .ckpt file> <experiment name>В данном случае видов экспериментов всего три: [
teacher,student,distillation,discriminator], а соответствующие экспериментам.ckptназвания файлов представлены в таблице с результатами
Здесь представлена выжимка результатов, более подробный обзор доступен по ссылке
| Student | Teacher | Method | Pretrained | Freeze Encoder | Accuracy | .ckpt file |
|---|---|---|---|---|---|---|
| ResNet18 | ❌ | Cross Entropy | ✅ | ✅ | 93.05 | student.ckpt |
| ResNet18 | ❌ | Cross Entropy | ✅ | ❌ | 93.65 | student_unfrozen.ckpt |
| ResNet50 | ❌ | Cross Entropy | ✅ | ✅ | 95.71 | teacher.ckpt |
| ResNet50 | ❌ | Cross Entropy | ✅ | ❌ | 93.83 | teacher_unfrozen.ckpt |
| ResNet18 | ResNet50 | Default KD loss | ✅ | ✅ | 93.29 | distillation_kd.ckpt |
| ResNet18 | ResNet50 | Default KD loss | ✅ | ❌ | 94.26 | distillation_kd_unfrozen.ckpt |
| ResNet18 | ResNet50 | RKD Distance loss | ✅ | ✅ | 93.05 | distillation_rkdd.ckpt |
| ResNet18 | ResNet50 | RKD Distance loss | ✅ | ❌ | 94.43 | distillation_rkdd_unfrozen.ckpt |
| ResNet18 | ResNet50 | Logits Discriminator | ✅ | ✅ | 92.21 | distillation_ld.ckpt |
| ResNet18 | ResNet50 | Logits Discriminator | ✅ | ❌ | 93.46 | distillation_ld_unfrozen.ckpt |
Теперь сравним результаты работы grad-cam. Результаты представлены в следующем
порядке:
-
Исходное изображение
-
Студент
Resnet18с учителемResnet50 -
Учитель
Resnet50 -
Студент
Resnet18без учителя -
Все слои кроме
BatchNormи классификатора заморожены
- Все слои разморожены
По результатам можно видеть, что подход Knowledge Distillation действительно
работает, в связи с чем имеет смысл продолжать изучение данного метода. Гипотеза
о том, что введение своеобразного дискриминатора для logits поможет обучению
не подтвердилась во время экспериментов, однако возможно дело было в недостаточно
аккуратном подборе гиперпараметров сети

