diff --git a/experiments/mnist_backprop_cmp.py b/experiments/mnist_backprop_cmp.py index ba93513..3ad0538 100644 --- a/experiments/mnist_backprop_cmp.py +++ b/experiments/mnist_backprop_cmp.py @@ -1,6 +1,7 @@ # model based off https://medium.com/data-science/going-beyond-99-mnist-handwritten-digits-recognition-cfff96337392 import logging import os +import pathlib import time import typing from typing import Callable @@ -75,12 +76,13 @@ def train_mnist( lr: float = 1e-3, batch_size: int = 512, step_count: int = 1_000, + norm_cls: typing.Callable = nn.BatchNorm, + model_weights_filepath: pathlib.Path | None = None, ): X_train, Y_train, X_test, Y_test = mnist(fashion=getenv("FASHION")) - model = Model(norm_cls=nn.BatchNorm) + model = Model(norm_cls=norm_cls) - model_weights_filepath = os.environ.get("MODEL_WEIGHTS") if model_weights_filepath is not None: logger.info("Loading model weights from %s", model_weights_filepath) model_state = safe_load(model_weights_filepath) @@ -154,27 +156,19 @@ def get_test_acc() -> Tensor: if __name__ == "__main__": - step_count = 1_000 - exp_id = ensure_experiment("Backprop Comparison V4") - for optimizer_type in ["adam", "muon", "sgd"]: - if optimizer_type == "adam": - with mlflow.start_run( - run_name=f"backprop-adam", - experiment_id=exp_id, - log_system_metrics=True, - ): - train_mnist(optimizer_type=optimizer_type, step_count=step_count) - else: - for lr_base in [1e-2, 1e-3, 1e-4]: - for lr in list(map(lambda x: x * 1e-3, range(1, 10))): - with mlflow.start_run( - run_name=f"backprop-{optimizer_type}-lr-{lr}", - experiment_id=exp_id, - log_system_metrics=True, - ): - train_mnist( - optimizer_type=optimizer_type, lr=lr, step_count=step_count - ) + step_count = 2_000 + exp_id = ensure_experiment("Switch Training Method") + + with mlflow.start_run( + run_name="all-adam", + experiment_id=exp_id, + log_system_metrics=True, + ): + train_mnist( + optimizer_type="adam", step_count=step_count, norm_cls=nn.InstanceNorm + ) + + checkpoint_filepath = pathlib.Path("trained-with-marketplace.safetensors") with mlflow.start_run( run_name="marketplace-v2", experiment_id=exp_id, @@ -189,4 +183,16 @@ def get_test_acc() -> Tensor: probe_scale=1e-1, marketplace=marketplace, manual_seed=42, + checkpoint_filepath=pathlib.Path("trained-with-marketplace.safetensors"), + ) + with mlflow.start_run( + run_name=f"marketplace-then-adam", + experiment_id=exp_id, + log_system_metrics=True, + ): + train_mnist( + optimizer_type="adam", + step_count=step_count, + norm_cls=nn.InstanceNorm, + model_weights_filepath=checkpoint_filepath, )