A professional implementation of experience replay for continual learning with support for multiple datasets, models, and comprehensive evaluation metrics.
- Professional Architecture: Clean, modular codebase with proper separation of concerns
- Configuration Management: YAML-based configuration with inheritance support
- Multiple Models: Support for ResNet backbones with multi-head architecture
- Comprehensive Evaluation: Detailed metrics including forgetting and forward transfer
- Unlearning Experiments: Advanced unlearning capabilities for selective knowledge removal
- CLI Interface: User-friendly command-line interface with rich output formatting
- Extensive Testing: Comprehensive test suite with proper fixtures
- Type Safety: Full type annotations with mypy support
- Logging: Professional logging with file and console output
```bash
# Clone the repository
git clone https://github.com/Fatemerjn/Replay-Based-Continual-Learning-PyTorch.git
cd Replay-Based-Continual-Learning-PyTorch

# Create virtual environment
python -m venv venv
source venv/bin/activate  # On Windows: venv\Scripts\activate

# Install package
pip install -e .

# For development
make install-dev
```

Once installed, experiments can be run from the CLI:

```bash
# Run continual learning experiment
continual-learning train --config configs/default.yaml
# Run unlearning experiment
continual-learning unlearn --config configs/unlearning.yaml --task 2
# Validate configuration
continual-learning validate-config configs/default.yaml
```

The same workflow is available from Python:

```python
from experience_replay import Config, MultiHeadResNet, ExperienceReplay
from experience_replay.training import ContinualTrainer
from experience_replay.utils import setup_device
# Load configuration
config = Config.from_yaml("configs/default.yaml")
# Setup device and model
device = setup_device(config.system.device)
model = MultiHeadResNet(
    num_tasks=config.dataset.num_tasks,
    num_classes_per_task=config.dataset.classes_per_task,
).to(device)
# Initialize strategy and trainer
strategy = ExperienceReplay(buffer_size_per_task=config.strategy.buffer_size_per_task)
trainer = ContinualTrainer(model, strategy, config, device)
# Run experiment
results = trainer.train()
```
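The `ExperienceReplay` strategy above encapsulates the rehearsal mechanism. As a rough, framework-independent sketch of the idea (illustrative only, not the package's actual internals), a replay buffer keeps a bounded number of examples per completed task and mixes them into batches for the current task:

```python
# Illustrative sketch of rehearsal with a per-task buffer. Names and structure
# here are hypothetical and do not mirror the package's internal implementation.
import random
from typing import List, Tuple

import torch


class SimpleReplayBuffer:
    """Stores at most `capacity_per_task` (input, label) pairs per finished task."""

    def __init__(self, capacity_per_task: int) -> None:
        self.capacity_per_task = capacity_per_task
        self.task_buffers: List[List[Tuple[torch.Tensor, int]]] = []

    def add_task(self, examples: List[Tuple[torch.Tensor, int]]) -> None:
        # After finishing a task, keep a random subset of its training data.
        kept = random.sample(examples, min(self.capacity_per_task, len(examples)))
        self.task_buffers.append(kept)

    def sample(self, n: int) -> List[Tuple[torch.Tensor, int]]:
        # Draw replay examples uniformly from everything stored so far.
        pool = [pair for buf in self.task_buffers for pair in buf]
        return random.sample(pool, min(n, len(pool)))


if __name__ == "__main__":
    buffer = SimpleReplayBuffer(capacity_per_task=50)
    fake_task = [(torch.randn(3, 32, 32), 0) for _ in range(200)]
    buffer.add_task(fake_task)
    # During training on the next task, each batch would be extended with
    # replayed pairs, e.g. batch = current_batch + buffer.sample(16)
    print(len(buffer.sample(16)))
```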
The repository is organized as follows:

```
├── src/experience_replay/     # Main package
│   ├── config/                 # Configuration management
│   ├── data/                   # Dataset utilities
│   ├── models/                 # Neural network models
│   ├── strategies/             # Continual learning strategies
│   ├── training/               # Training and evaluation
│   ├── cli.py                  # Command-line interface
│   └── utils.py                # Utility functions
├── configs/                    # Configuration files
├── tests/                      # Test suite
├── requirements.txt            # Core dependencies
├── requirements-dev.txt        # Development dependencies
├── setup.py                    # Package setup
├── pyproject.toml              # Project configuration
└── Makefile                    # Development commands
```
The project uses YAML configuration files with inheritance support:
```yaml
# configs/default.yaml
model:
  name: "multi_head_resnet"
  backbone: "resnet18"
  pretrained: true

dataset:
  name: "cifar100"
  num_tasks: 10
  classes_per_task: 10

training:
  batch_size: 64
  epochs_per_task: 10
  learning_rate: 0.01

strategy:
  name: "experience_replay"
  buffer_size_per_task: 50
```
Experiments can also be launched with explicit output and seed options:

```bash
continual-learning train \
    --config configs/default.yaml \
    --output-dir ./results/experiment_1 \
    --seed 42

continual-learning unlearn \
    --config configs/unlearning.yaml \
    --task 2 \
    --output-dir ./results/unlearning_experiment
```

The framework provides comprehensive evaluation metrics (an illustrative computation follows the list):
- Average Accuracy: Mean accuracy across all tasks
- Backward Transfer (Forgetting): Performance degradation on previous tasks
- Forward Transfer: Performance improvement on future tasks from previous learning
- Task-wise Analysis: Detailed per-task performance breakdown
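As an illustration of how these quantities relate, they can be computed from a matrix of task accuracies, where `acc[i, j]` is the accuracy on task `j` measured after training on task `i`. This is one common formulation; the framework's exact definitions may differ:

```python
# Toy example: metrics from a 3-task accuracy matrix (numbers are made up).
# acc[i, j] = accuracy on task j measured right after training on task i.
import numpy as np

acc = np.array([
    [0.90, 0.10, 0.10],
    [0.80, 0.88, 0.12],
    [0.75, 0.82, 0.86],
])
num_tasks = acc.shape[0]

# Average accuracy: mean accuracy over all tasks after the final task.
average_accuracy = acc[-1].mean()

# Backward transfer (negative values indicate forgetting): change in accuracy
# on earlier tasks between when they were learned and the end of training.
backward_transfer = np.mean([acc[-1, j] - acc[j, j] for j in range(num_tasks - 1)])

# Forward transfer: accuracy on a task just before training on it, relative to
# an untrained baseline (0.1 here, assuming 10 classes per task as in the config).
baseline = 0.1
forward_transfer = np.mean([acc[j - 1, j] - baseline for j in range(1, num_tasks)])

print(average_accuracy, backward_transfer, forward_transfer)
```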
Development tasks are wrapped in the Makefile:

```bash
# For development setup
make install-dev

# Format code
make format

# Run linting
make lint

# Run tests
make test

# Standard training
make train

# Unlearning experiment
make unlearn

# Validate configuration
make validate-config
```

Tests can also be run directly with pytest:

```bash
# Run all tests
pytest
# Run with coverage
pytest --cov=experience_replay --cov-report=html
# Run specific test categories
pytest -m "not slow" # Skip slow tests
pytest -m integration    # Only integration tests
```

Results are automatically saved to the specified output directory:
```
results/
├── experiment_logs.log
├── final_results.json
├── task_accuracies.csv
└── model_checkpoints/
```
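For post-hoc analysis, the saved files can be loaded with standard tooling. The snippet below is a hypothetical example; it assumes nothing about the exact schemas beyond the JSON and CSV formats, so inspect the files before relying on specific fields:

```python
# Hypothetical post-hoc inspection of an output directory; column names and
# JSON structure are not specified by the README, so only generic loading is shown.
import json
from pathlib import Path

import pandas as pd

results_dir = Path("results")

with open(results_dir / "final_results.json") as f:
    final_results = json.load(f)
print(final_results)

task_accuracies = pd.read_csv(results_dir / "task_accuracies.csv")
print(task_accuracies.head())
```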