Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
59 changes: 23 additions & 36 deletions configs/regional_forecast_climax.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -2,22 +2,21 @@ seed_everything: 42

# ---------------------------- TRAINER -------------------------------------------
trainer:
default_root_dir: ${oc.env:OUTPUT_DIR,/home/t-tungnguyen/ClimaX/exps/regional_forecast_climax}
default_root_dir: ${oc.env:OUTPUT_DIR,"/Users/<your_username>/ClimaX/experiments/regional_forecast_climax"}

precision: 16

gpus: null
# MPS on Apple Silicon: keep full precision for stability
precision: 32-true
num_nodes: 1
accelerator: gpu
strategy: ddp
accelerator: mps
devices: 1
strategy: auto

min_epochs: 1
max_epochs: 100
enable_progress_bar: true

sync_batchnorm: True
enable_checkpointing: True
resume_from_checkpoint: null
sync_batchnorm: false
enable_checkpointing: true
# resume_from_checkpoint: null # (deprecated; pass ckpt_path to fit/test instead)

# debugging
fast_dev_run: false
Expand All @@ -28,8 +27,8 @@ trainer:
save_dir: ${trainer.default_root_dir}/logs
name: null
version: null
log_graph: False
default_hp_metric: True
log_graph: false
default_hp_metric: true
prefix: ""

callbacks:
Expand All @@ -40,26 +39,12 @@ trainer:
- class_path: pytorch_lightning.callbacks.ModelCheckpoint
init_args:
dirpath: "${trainer.default_root_dir}/checkpoints"
monitor: "val/w_rmse" # name of the logged metric which determines when model is improving
mode: "min" # "max" means higher metric value is better, can be also "min"
save_top_k: 1 # save k best models (determined by above metric)
save_last: True # additionaly always save model from last epoch
verbose: False
monitor: "val/w_rmse"
mode: "min"
save_top_k: 1
save_last: true
filename: "epoch_{epoch:03d}"
auto_insert_metric_name: False

- class_path: pytorch_lightning.callbacks.EarlyStopping
init_args:
monitor: "val/w_rmse" # name of the logged metric which determines when model is improving
mode: "min" # "max" means higher metric value is better, can be also "min"
patience: 5 # how many validation epochs of not improving until training stops
min_delta: 0. # minimum change in the monitored metric needed to qualify as an improvement

- class_path: pytorch_lightning.callbacks.RichModelSummary
init_args:
max_depth: -1

- class_path: pytorch_lightning.callbacks.RichProgressBar
auto_insert_metric_name: false

# ---------------------------- MODEL -------------------------------------------
model:
Expand All @@ -71,7 +56,7 @@ model:
max_epochs: 100000
warmup_start_lr: 1e-8
eta_min: 1e-8
pretrained_path: ""
pretrained_path: "" # override via CLI if needed

net:
class_path: climax.regional_forecast.arch.RegionalClimaX
Expand Down Expand Up @@ -138,7 +123,8 @@ model:

# ---------------------------- DATA -------------------------------------------
data:
root_dir: /datadrive/datasets/5.625deg_equally_np/
# Use your mac path (update if different)
root_dir: /mnt/data/5.625deg_npz
variables: [
"land_sea_mask",
"orography",
Expand Down Expand Up @@ -194,6 +180,7 @@ data:
predict_range: 72
hrs_each_step: 1
buffer_size: 10000
batch_size: 128
num_workers: 1
pin_memory: False
# Apple Silicon memory is tight; start small and scale up
batch_size: 1
num_workers: 0 # macOS: avoid multiprocessing hiccups
pin_memory: false
159 changes: 159 additions & 0 deletions demo_compare_resolutions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,159 @@
#!/usr/bin/env python3
import argparse, sys, numpy as np, torch
import matplotlib.pyplot as plt
from typing import List, Tuple, Dict

# Regional architecture (nn.Module, not Lightning)
from climax.regional_forecast.arch import RegionalClimaX

# ---- Variables list (order must match training!) ----
DEFAULT_VARS = [
"land_sea_mask","orography","lattitude",
"2m_temperature","10m_u_component_of_wind","10m_v_component_of_wind",
"geopotential_50","geopotential_250","geopotential_500",
"geopotential_600","geopotential_700","geopotential_850","geopotential_925",
"u_component_of_wind_50","u_component_of_wind_250","u_component_of_wind_500",
"u_component_of_wind_600","u_component_of_wind_700","u_component_of_wind_850","u_component_of_wind_925",
"v_component_of_wind_50","v_component_of_wind_250","v_component_of_wind_500",
"v_component_of_wind_600","v_component_of_wind_700","v_component_of_wind_850","v_component_of_wind_925",
"temperature_50","temperature_250","temperature_500","temperature_600",
"temperature_700","temperature_850","temperature_925",
"relative_humidity_50","relative_humidity_250","relative_humidity_500","relative_humidity_600",
"relative_humidity_700","relative_humidity_850","relative_humidity_925",
"specific_humidity_50","specific_humidity_250","specific_humidity_500","specific_humidity_600",
"specific_humidity_700","specific_humidity_850","specific_humidity_925",
]
DEFAULT_OUT_VARS = ["geopotential_500","temperature_850","2m_temperature"]

def parse_args():
p = argparse.ArgumentParser(description="ClimaX predict-only: compare resolutions side-by-side")
p.add_argument("--ckpt_5625", required=True, help="Path to 5.625° checkpoint")
p.add_argument("--ckpt_1406", required=True, help="Path to 1.40625° checkpoint")
p.add_argument("--device", choices=["mps","cpu"], default="mps")
p.add_argument("--out_vars", nargs="*", default=DEFAULT_OUT_VARS)
p.add_argument("--batch", type=int, default=1)
# Arch hyperparams (should match ckpts; defaults work for official ones)
p.add_argument("--embed_dim", type=int, default=1024)
p.add_argument("--depth", type=int, default=8)
p.add_argument("--decoder_depth", type=int, default=2)
p.add_argument("--num_heads", type=int, default=16)
p.add_argument("--mlp_ratio", type=float, default=4.0)
p.add_argument("--drop_path", type=float, default=0.1)
p.add_argument("--drop_rate", type=float, default=0.1)
return p.parse_args()

def geo_for(res: str) -> Tuple[Tuple[int,int], int]:
return ((32,64), 2) if res == "5.625" else ((128,256), 4)

def build_region_info(H: int, W: int, patch: int) -> Dict[str, np.ndarray]:
lat_vec = np.linspace(-90.0, 90.0, num=H, dtype=np.float32)
lon_vec = np.linspace(0.0, 360.0, num=W, endpoint=False, dtype=np.float32)
hp, wp = H // patch, W // patch
L = hp * wp
patch_ids = np.arange(L, dtype=np.int64)
return {
# full-frame ranges
"min_h": 0, "max_h": H, "min_w": 0, "max_w": W,
# per-pixel indices
"h_inds": np.arange(H, dtype=np.int64),
"w_inds": np.arange(W, dtype=np.int64),
# patch grid
"patch": patch, "hp": hp, "wp": wp, "patch_ids": patch_ids,
# metadata
"lat": lat_vec, "lon": lon_vec,
"north": 90.0, "south": -90.0, "west": 0.0, "east": 360.0,
"name": "full",
}

def strip_prefixes(sd: dict) -> dict:
for p in ("model.","net.","module.","climax."):
if any(k.startswith(p) for k in sd):
return {k[len(p):]: v for k,v in sd.items()}
return sd

def load_checkpoint(path: str) -> dict:
ckpt = torch.load(path, map_location="cpu")
sd = ckpt.get("state_dict", ckpt)
return strip_prefixes(sd)

class CaptureMetric:
"""Metric that captures predictions and returns zero loss."""
def __init__(self):
self.last_pred = None
def __call__(self, preds: torch.Tensor, y_true: torch.Tensor, out_vars, lat_vec: torch.Tensor):
self.last_pred = preds
return torch.tensor(0.0, device=preds.device, dtype=preds.dtype), {}

def predict_one(ckpt_path: str, res: str, device: torch.device, out_vars: List[str],
embed_dim: int, depth: int, decoder_depth: int, num_heads: int,
mlp_ratio: float, drop_path: float, drop_rate: float, batch: int):
(H, W), patch = geo_for(res)
model = RegionalClimaX(DEFAULT_VARS, (H, W), patch,
embed_dim, depth, decoder_depth,
num_heads, mlp_ratio, drop_path, drop_rate)
sd = load_checkpoint(ckpt_path)
missing, unexpected = model.load_state_dict(sd, strict=False)
print(f"[{res}] loaded | missing: {len(missing)} unexpected: {len(unexpected)}")
x = torch.randn(batch, len(DEFAULT_VARS), H, W, dtype=torch.float32, device=device)
y = torch.zeros(batch, len(out_vars), H, W, dtype=torch.float32, device=device)
lat = torch.linspace(-90.0, 90.0, steps=H, dtype=torch.float32, device=device)
lead_times = torch.tensor([1], dtype=torch.float32, device=device)
region_info = build_region_info(H, W, patch)
metric = CaptureMetric()
model.to(device).eval()
with torch.no_grad():
_ = model(x, y, lead_times, DEFAULT_VARS, out_vars, [metric], lat, region_info)
preds = metric.last_pred # (B, V, H, W)
if preds is None:
raise RuntimeError("Metric didn’t receive predictions; forward signature mismatch.")
return preds.detach().cpu().numpy()[0], (H, W) # (V,H,W), (H,W)

def plot_side_by_side(var_name: str, arr_lo: np.ndarray, arr_hi: np.ndarray,
shape_lo: Tuple[int,int], shape_hi: Tuple[int,int], out_path: str):
plt.figure(figsize=(10,4))
plt.suptitle(f"{var_name} — ClimaX predictions (demo)")

plt.subplot(1,2,1)
plt.title("5.625° (32×64)")
plt.imshow(arr_lo, origin="lower") # no explicit colormap/colors per your environment rules
plt.colorbar(shrink=0.8)

plt.subplot(1,2,2)
plt.title("1.40625° (128×256)")
plt.imshow(arr_hi, origin="lower")
plt.colorbar(shrink=0.8)

plt.tight_layout(rect=[0,0,1,0.95])
plt.savefig(out_path, dpi=150)
plt.close()

def main():
args = parse_args()
device = torch.device("mps" if (args.device=="mps" and torch.backends.mps.is_available()) else "cpu")

# Predict on both resolutions
preds_lo, shape_lo = predict_one(
args.ckpt_5625, "5.625", device, args.out_vars,
args.embed_dim, args.depth, args.decoder_depth, args.num_heads,
args.mlp_ratio, args.drop_path, args.drop_rate, args.batch
)
preds_hi, shape_hi = predict_one(
args.ckpt_1406, "1.40625", device, args.out_vars,
args.embed_dim, args.depth, args.decoder_depth, args.num_heads,
args.mlp_ratio, args.drop_path, args.drop_rate, args.batch
)

# Save raw tensors (optional)
np.save("preds_5p625.npy", preds_lo) # (V, 32, 64)
np.save("preds_1p406.npy", preds_hi) # (V, 128, 256)

# Side-by-side plots per variable
for i, v in enumerate(args.out_vars):
out_png = f"compare_{v.replace('/', '_')}.png"
plot_side_by_side(v, preds_lo[i], preds_hi[i], shape_lo, shape_hi, out_png)
print(f"saved {out_png}")

print("Done.")

if __name__ == "__main__":
main()
129 changes: 129 additions & 0 deletions output/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,129 @@

Steps:

```
brew install cmake protobuf rust [email protected] git wget
pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu
pip install diffusers transformers accelerate safetensors
```


https://dataserv.ub.tum.de/s/?dir=/5.625deg/2m_temperature


```
python predict_only_climax.py \
--ckpt ./data-ch/5.625deg.ckpt \
--res 5.625 \
--device mps
```


```
python demo_compare_resolutions.py \
--ckpt_5625 ./data-ch/5.625deg.ckpt \
--ckpt_1406 ./data-ch/1.40625deg.ckpt \
--device mps \
--out_vars geopotential_500 temperature_850 2m_temperature
```

Train:
```
python src/climax/regional_forecast/train.py \
--config configs/regional_forecast_climax.yaml \
--trainer.callbacks=null \
--trainer.accelerator=mps --trainer.devices=1 --trainer.precision=32-true \
--data.root_dir=/mnt/data/5.625deg_npz \
--data.region="NorthAmerica" \
--data.predict_range=72 \
--data.out_variables="['z_500','t_850','t2m']" \
--data.batch_size=1 --data.num_workers=0 \
--trainer.max_epochs=1 \
--model.pretrained_path="./data-ch/1.40625deg.ckpt" \
--model.lr=5e-7 --model.beta_1=0.9 --model.beta_2=0.99 \
--model.weight_decay=1e-5
```

Deps:
```
absl-py==2.3.1
aiohappyeyeballs==2.6.1
aiohttp==3.12.15
aiosignal==1.4.0
antlr4-python3-runtime==4.9.3
asciitree==0.3.3
async-timeout==5.0.1
attrs==25.3.0
cdsapi==0.7.6
certifi==2025.8.3
cftime==1.6.4.post1
charset-normalizer==3.4.3
-e git+https://github.com/microsoft/ClimaX@6d5d354ffb4b91bb684f430b98e8f6f8af7c7f7c#egg=ClimaX
contourpy==1.3.2
cycler==0.12.1
docstring_parser==0.17.0
ecmwf-datastores-client==0.4.0
einops==0.8.1
fasteners==0.20
filelock==3.19.1
fonttools==4.59.2
frozenlist==1.7.0
fsspec==2024.12.0
grpcio==1.74.0
hf-xet==1.1.9
huggingface-hub==0.34.4
hydra-core==1.3.2
idna==3.10
importlib_resources==6.5.2
Jinja2==3.1.6
jsonargparse==4.41.0
kiwisolver==1.4.9
lightning-utilities==0.15.2
Markdown==3.9
markdown-it-py==4.0.0
MarkupSafe==3.0.2
matplotlib==3.10.6
mdurl==0.1.2
mpmath==1.3.0
multidict==6.6.4
multiurl==0.3.7
netCDF4==1.7.2
networkx==3.4.2
numcodecs==0.13.1
numpy==1.26.4
omegaconf==2.3.0
packaging==24.2
pillow==11.3.0
propcache==0.3.2
protobuf==6.32.0
Pygments==2.19.2
pyparsing==3.2.3
python-dateutil==2.9.0.post0
pytorch-lightning==2.1.4
pytz==2025.2
PyYAML==6.0.2
requests==2.32.5
rich==14.1.0
safetensors==0.6.2
scipy==1.15.3
six==1.17.0
sympy==1.14.0
tensorboard==2.20.0
tensorboard-data-server==0.7.2
timm==0.6.13
torch==2.2.2
torchaudio==2.2.2
torchdata==0.7.1
torchmetrics==1.8.2
torchvision==0.17.2
tqdm==4.67.1
typeshed_client==2.8.2
typing_extensions==4.15.0
tzdata==2025.2
urllib3==2.5.0
Werkzeug==3.1.3
xarray==2025.6.1
yarl==1.20.1
zarr==2.18.3

```
Binary file added output/compare_2m_temperature.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added output/compare_geopotential_500.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added output/compare_temperature_850.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading