
Replay-Based Continual Learning in PyTorch

Requires Python 3.8+ and PyTorch.

A professional PyTorch implementation of replay-based continual learning, in which models preserve past knowledge through stored exemplars or pseudo-samples. The framework implements Experience Replay (ER), Gradient Episodic Memory (GEM), and iCaRL, and provides modular dataset buffering, memory selection policies, support for multiple datasets and models, and comprehensive evaluation metrics for reproducible experiments on vision and NLP tasks.

Features

  • Professional Architecture: Clean, modular codebase with proper separation of concerns
  • Configuration Management: YAML-based configuration with inheritance support
  • Multiple Models: Support for ResNet backbones with a multi-head architecture (see the sketch after this list)
  • Comprehensive Evaluation: Detailed metrics including forgetting and forward transfer
  • Unlearning Experiments: Selective removal of a previously learned task's knowledge
  • CLI Interface: User-friendly command-line interface with rich output formatting
  • Extensive Testing: Comprehensive test suite with proper fixtures
  • Type Safety: Full type annotations with mypy support
  • Logging: Professional logging with file and console output
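
To make the multi-head design concrete: the model shares a single ResNet backbone and keeps a separate linear classification head per task, routing each forward pass through the head for the requested task. The sketch below illustrates the idea only and is not the package's actual MultiHeadResNet:

import torch
import torch.nn as nn
from torchvision.models import resnet18

class TinyMultiHead(nn.Module):
    """Illustrative multi-head model: shared backbone, one linear head per task."""

    def __init__(self, num_tasks: int, num_classes_per_task: int):
        super().__init__()
        backbone = resnet18(weights=None)       # torchvision >= 0.13 API
        feature_dim = backbone.fc.in_features
        backbone.fc = nn.Identity()             # strip the original classifier
        self.backbone = backbone
        self.heads = nn.ModuleList(
            nn.Linear(feature_dim, num_classes_per_task) for _ in range(num_tasks)
        )

    def forward(self, x: torch.Tensor, task_id: int) -> torch.Tensor:
        features = self.backbone(x)             # shared representation
        return self.heads[task_id](features)    # task-specific logits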

Quick Start

Installation

# Clone the repository
git clone https://github.com/Fatemerjn/Replay-Based-Continual-Learning-PyTorch.git
cd Replay-Based-Continual-Learning-PyTorch

# Create virtual environment
python -m venv venv
source venv/bin/activate  # On Windows: venv\Scripts\activate

# Install package
pip install -e .

# For development
make install-dev

Basic Usage

# Run continual learning experiment
continual-learning train --config configs/default.yaml

# Run unlearning experiment
continual-learning unlearn --config configs/unlearning.yaml --task 2

# Validate configuration
continual-learning validate-config configs/default.yaml

Python API

from experience_replay import Config, MultiHeadResNet, ExperienceReplay
from experience_replay.training import ContinualTrainer
from experience_replay.utils import setup_device

# Load configuration
config = Config.from_yaml("configs/default.yaml")

# Setup device and model
device = setup_device(config.system.device)
model = MultiHeadResNet(
    num_tasks=config.dataset.num_tasks,
    num_classes_per_task=config.dataset.classes_per_task
).to(device)

# Initialize strategy and trainer
strategy = ExperienceReplay(buffer_size_per_task=config.strategy.buffer_size_per_task)
trainer = ContinualTrainer(model, strategy, config, device)

# Run experiment
results = trainer.train()
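
Under the hood, experience replay keeps a small memory of past examples and mixes them into each new batch so gradients keep seeing old tasks. The sketch below shows the general idea with per-task reservoir sampling; it assumes nothing about the package's actual ExperienceReplay class beyond the buffer_size_per_task argument used above:

import random
from typing import Dict, List, Tuple

import torch

class TinyReplayBuffer:
    """Illustrative per-task exemplar memory using reservoir sampling."""

    def __init__(self, buffer_size_per_task: int):
        self.capacity = buffer_size_per_task
        self.buffers: Dict[int, List[Tuple[torch.Tensor, int]]] = {}
        self.seen: Dict[int, int] = {}

    def add(self, task_id: int, x: torch.Tensor, y: int) -> None:
        buf = self.buffers.setdefault(task_id, [])
        self.seen[task_id] = self.seen.get(task_id, 0) + 1
        if len(buf) < self.capacity:
            buf.append((x, y))
        else:
            # Reservoir sampling keeps a uniform sample of everything seen so far.
            j = random.randrange(self.seen[task_id])
            if j < self.capacity:
                buf[j] = (x, y)

    def sample(self, n: int) -> List[Tuple[torch.Tensor, int]]:
        # Draw replay examples uniformly across all stored tasks.
        pool = [item for buf in self.buffers.values() for item in buf]
        return random.sample(pool, min(n, len(pool)))

During training, each incoming batch would be concatenated with a sample() of stored exemplars before the backward pass, which is what counteracts catastrophic forgetting.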

Project Structure

├── src/experience_replay/          # Main package
│   ├── config/                     # Configuration management
│   ├── data/                       # Dataset utilities
│   ├── models/                     # Neural network models
│   ├── strategies/                 # Continual learning strategies
│   ├── training/                   # Training and evaluation
│   ├── cli.py                      # Command-line interface
│   └── utils.py                    # Utility functions
├── configs/                        # Configuration files
├── tests/                          # Test suite
├── requirements.txt                # Core dependencies
├── requirements-dev.txt            # Development dependencies
├── setup.py                        # Package setup
├── pyproject.toml                  # Project configuration
└── Makefile                        # Development commands

Configuration

The project uses YAML configuration files with inheritance support:

# configs/default.yaml
model:
  name: "multi_head_resnet"
  backbone: "resnet18"
  pretrained: true

dataset:
  name: "cifar100"
  num_tasks: 10
  classes_per_task: 10

training:
  batch_size: 64
  epochs_per_task: 10
  learning_rate: 0.01

strategy:
  name: "experience_replay"
  buffer_size_per_task: 50
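
Configs can also be loaded and inspected programmatically. This snippet reuses only names that appear in the YAML above and in the Python API example, assuming attribute access mirrors the YAML layout:

from experience_replay import Config

# Load the base configuration shipped with the repository.
config = Config.from_yaml("configs/default.yaml")

# Field access mirrors the YAML structure, as in the Python API example.
total_classes = config.dataset.num_tasks * config.dataset.classes_per_task
print(config.dataset.name, total_classes)  # cifar100 100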

Experiments

Standard Continual Learning

continual-learning train \
    --config configs/default.yaml \
    --output-dir ./results/experiment_1 \
    --seed 42

Unlearning Experiment

continual-learning unlearn \
    --config configs/unlearning.yaml \
    --task 2 \
    --output-dir ./results/unlearning_experiment
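
The precise unlearning procedure is defined by the package; purely as an illustration, one common replay-based recipe is to discard the forgotten task's exemplars from memory and fine-tune on what remains. A sketch of that variant, reusing the hypothetical TinyReplayBuffer from above and assuming a single-head model:

import torch
import torch.nn.functional as F

def unlearn_task(buffer, model, optimizer, forget_task_id: int, steps: int = 100) -> None:
    """Hypothetical unlearning: drop the task's exemplars, then rehearse the rest."""
    buffer.buffers.pop(forget_task_id, None)  # remove the task's stored exemplars
    buffer.seen.pop(forget_task_id, None)
    model.train()
    for _ in range(steps):
        batch = buffer.sample(32)
        if not batch:
            break  # nothing left to rehearse
        xs = torch.stack([x for x, _ in batch])
        ys = torch.tensor([y for _, y in batch])
        loss = F.cross_entropy(model(xs), ys)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()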

Evaluation Metrics

The framework provides comprehensive evaluation metrics (standard definitions are sketched after the list):

  • Average Accuracy: Mean accuracy across all tasks after the final task is learned
  • Backward Transfer (Forgetting): Degradation in performance on earlier tasks after learning later ones
  • Forward Transfer: Improvement on later tasks gained from earlier learning
  • Task-wise Analysis: Detailed per-task performance breakdown
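
These quantities are typically computed from an accuracy matrix R, where R[i][j] is the accuracy on task j measured after training on task i. A sketch using the standard GEM-style definitions (the package's exact formulas may differ slightly):

import numpy as np

def continual_metrics(R: np.ndarray):
    """R[i, j] = accuracy on task j after training on task i."""
    T = R.shape[0]
    avg_acc = float(R[-1].mean())  # Average Accuracy over all tasks at the end
    # Backward transfer: average change on earlier tasks; negative means forgetting.
    bwt = float(np.mean([R[-1, j] - R[j, j] for j in range(T - 1)]))
    # Forward transfer additionally needs the accuracy of a randomly
    # initialized model on each task as a baseline, omitted here.
    return avg_acc, bwt

# Toy 2-task run: task 0 drops from 0.90 to 0.70 after learning task 1.
R = np.array([[0.90, 0.10],
              [0.70, 0.80]])
print(continual_metrics(R))  # approximately (0.75, -0.2)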

Development

Setup Development Environment

make install-dev

Code Quality

# Format code
make format

# Run linting
make lint

# Run tests
make test

Running Experiments

# Standard training
make train

# Unlearning experiment
make unlearn

# Validate configuration
make validate-config

Testing

# Run all tests
pytest

# Run with coverage
pytest --cov=experience_replay --cov-report=html

# Run specific test categories
pytest -m "not slow"  # Skip slow tests
pytest -m integration  # Only integration tests
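
The -m filters rely on standard pytest markers. A hypothetical test in the style the suite presumably uses (the marker names slow and integration come from the commands above):

import pytest

@pytest.mark.slow
def test_full_experiment_runs():
    """Hypothetical test tagged `slow`, so `pytest -m "not slow"` skips it."""
    assert True  # a real test would run a tiny end-to-end experiment here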

Results

Results are automatically saved to the specified output directory:

results/
├── experiment_logs.log
├── final_results.json
├── task_accuracies.csv
└── model_checkpoints/
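
To inspect a finished run, the saved artifacts can be read back with the standard library. This assumes final_results.json holds a metrics dictionary and task_accuracies.csv has a header row; the exact schemas are defined by the trainer:

import csv
import json
from pathlib import Path

run_dir = Path("results")

# Summary metrics written at the end of training (schema assumed, not documented here).
metrics = json.loads((run_dir / "final_results.json").read_text())
print(metrics)

# Per-task accuracies, assumed to be a CSV with a header row.
with open(run_dir / "task_accuracies.csv", newline="") as f:
    for row in csv.DictReader(f):
        print(row)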
