Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 29 additions & 23 deletions experiments/mnist_backprop_cmp.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# model based off https://medium.com/data-science/going-beyond-99-mnist-handwritten-digits-recognition-cfff96337392
import logging
import os
import pathlib
import time
import typing
from typing import Callable
Expand Down Expand Up @@ -75,12 +76,13 @@ def train_mnist(
lr: float = 1e-3,
batch_size: int = 512,
step_count: int = 1_000,
norm_cls: typing.Callable = nn.BatchNorm,
model_weights_filepath: pathlib.Path | None = None,
):
X_train, Y_train, X_test, Y_test = mnist(fashion=getenv("FASHION"))

model = Model(norm_cls=nn.BatchNorm)
model = Model(norm_cls=norm_cls)

model_weights_filepath = os.environ.get("MODEL_WEIGHTS")
if model_weights_filepath is not None:
logger.info("Loading model weights from %s", model_weights_filepath)
model_state = safe_load(model_weights_filepath)
Expand Down Expand Up @@ -154,27 +156,19 @@ def get_test_acc() -> Tensor:


if __name__ == "__main__":
step_count = 1_000
exp_id = ensure_experiment("Backprop Comparison V4")
for optimizer_type in ["adam", "muon", "sgd"]:
if optimizer_type == "adam":
with mlflow.start_run(
run_name=f"backprop-adam",
experiment_id=exp_id,
log_system_metrics=True,
):
train_mnist(optimizer_type=optimizer_type, step_count=step_count)
else:
for lr_base in [1e-2, 1e-3, 1e-4]:
for lr in list(map(lambda x: x * 1e-3, range(1, 10))):
with mlflow.start_run(
run_name=f"backprop-{optimizer_type}-lr-{lr}",
experiment_id=exp_id,
log_system_metrics=True,
):
train_mnist(
optimizer_type=optimizer_type, lr=lr, step_count=step_count
)
step_count = 2_000
exp_id = ensure_experiment("Switch Training Method")

with mlflow.start_run(
run_name="all-adam",
experiment_id=exp_id,
log_system_metrics=True,
):
train_mnist(
optimizer_type="adam", step_count=step_count, norm_cls=nn.InstanceNorm
)

checkpoint_filepath = pathlib.Path("trained-with-marketplace.safetensors")
with mlflow.start_run(
run_name="marketplace-v2",
experiment_id=exp_id,
Expand All @@ -189,4 +183,16 @@ def get_test_acc() -> Tensor:
probe_scale=1e-1,
marketplace=marketplace,
manual_seed=42,
checkpoint_filepath=pathlib.Path("trained-with-marketplace.safetensors"),
)
with mlflow.start_run(
run_name=f"marketplace-then-adam",
experiment_id=exp_id,
log_system_metrics=True,
):
train_mnist(
optimizer_type="adam",
step_count=step_count,
norm_cls=nn.InstanceNorm,
model_weights_filepath=checkpoint_filepath,
)