diff --git a/configs/regional_forecast_climax.yaml b/configs/regional_forecast_climax.yaml index fe38f95..82028f5 100644 --- a/configs/regional_forecast_climax.yaml +++ b/configs/regional_forecast_climax.yaml @@ -2,22 +2,21 @@ seed_everything: 42 # ---------------------------- TRAINER ------------------------------------------- trainer: - default_root_dir: ${oc.env:OUTPUT_DIR,/home/t-tungnguyen/ClimaX/exps/regional_forecast_climax} + default_root_dir: ${oc.env:OUTPUT_DIR,"/Users//ClimaX/experiments/regional_forecast_climax"} - precision: 16 - - gpus: null + # MPS on Apple Silicon: keep full precision for stability + precision: 32-true num_nodes: 1 - accelerator: gpu - strategy: ddp + accelerator: mps + devices: 1 + strategy: auto min_epochs: 1 max_epochs: 100 enable_progress_bar: true - - sync_batchnorm: True - enable_checkpointing: True - resume_from_checkpoint: null + sync_batchnorm: false + enable_checkpointing: true + # resume_from_checkpoint: null # (deprecated; pass ckpt_path to fit/test instead) # debugging fast_dev_run: false @@ -28,8 +27,8 @@ trainer: save_dir: ${trainer.default_root_dir}/logs name: null version: null - log_graph: False - default_hp_metric: True + log_graph: false + default_hp_metric: true prefix: "" callbacks: @@ -40,26 +39,12 @@ trainer: - class_path: pytorch_lightning.callbacks.ModelCheckpoint init_args: dirpath: "${trainer.default_root_dir}/checkpoints" - monitor: "val/w_rmse" # name of the logged metric which determines when model is improving - mode: "min" # "max" means higher metric value is better, can be also "min" - save_top_k: 1 # save k best models (determined by above metric) - save_last: True # additionaly always save model from last epoch - verbose: False + monitor: "val/w_rmse" + mode: "min" + save_top_k: 1 + save_last: true filename: "epoch_{epoch:03d}" - auto_insert_metric_name: False - - - class_path: pytorch_lightning.callbacks.EarlyStopping - init_args: - monitor: "val/w_rmse" # name of the logged metric which 
determines when model is improving - mode: "min" # "max" means higher metric value is better, can be also "min" - patience: 5 # how many validation epochs of not improving until training stops - min_delta: 0. # minimum change in the monitored metric needed to qualify as an improvement - - - class_path: pytorch_lightning.callbacks.RichModelSummary - init_args: - max_depth: -1 - - - class_path: pytorch_lightning.callbacks.RichProgressBar + auto_insert_metric_name: false # ---------------------------- MODEL ------------------------------------------- model: @@ -71,7 +56,7 @@ model: max_epochs: 100000 warmup_start_lr: 1e-8 eta_min: 1e-8 - pretrained_path: "" + pretrained_path: "" # override via CLI if needed net: class_path: climax.regional_forecast.arch.RegionalClimaX @@ -138,7 +123,8 @@ model: # ---------------------------- DATA ------------------------------------------- data: - root_dir: /datadrive/datasets/5.625deg_equally_np/ + # Use your mac path (update if different) + root_dir: /mnt/data/5.625deg_npz variables: [ "land_sea_mask", "orography", @@ -194,6 +180,7 @@ data: predict_range: 72 hrs_each_step: 1 buffer_size: 10000 - batch_size: 128 - num_workers: 1 - pin_memory: False + # Apple Silicon memory is tight; start small and scale up + batch_size: 1 + num_workers: 0 # macOS: avoid multiprocessing hiccups + pin_memory: false diff --git a/demo_compare_resolutions.py b/demo_compare_resolutions.py new file mode 100644 index 0000000..be08a4d --- /dev/null +++ b/demo_compare_resolutions.py @@ -0,0 +1,159 @@ +#!/usr/bin/env python3 +import argparse, sys, numpy as np, torch +import matplotlib.pyplot as plt +from typing import List, Tuple, Dict + +# Regional architecture (nn.Module, not Lightning) +from climax.regional_forecast.arch import RegionalClimaX + +# ---- Variables list (order must match training!) 
---- +DEFAULT_VARS = [ + "land_sea_mask","orography","lattitude", + "2m_temperature","10m_u_component_of_wind","10m_v_component_of_wind", + "geopotential_50","geopotential_250","geopotential_500", + "geopotential_600","geopotential_700","geopotential_850","geopotential_925", + "u_component_of_wind_50","u_component_of_wind_250","u_component_of_wind_500", + "u_component_of_wind_600","u_component_of_wind_700","u_component_of_wind_850","u_component_of_wind_925", + "v_component_of_wind_50","v_component_of_wind_250","v_component_of_wind_500", + "v_component_of_wind_600","v_component_of_wind_700","v_component_of_wind_850","v_component_of_wind_925", + "temperature_50","temperature_250","temperature_500","temperature_600", + "temperature_700","temperature_850","temperature_925", + "relative_humidity_50","relative_humidity_250","relative_humidity_500","relative_humidity_600", + "relative_humidity_700","relative_humidity_850","relative_humidity_925", + "specific_humidity_50","specific_humidity_250","specific_humidity_500","specific_humidity_600", + "specific_humidity_700","specific_humidity_850","specific_humidity_925", +] +DEFAULT_OUT_VARS = ["geopotential_500","temperature_850","2m_temperature"] + +def parse_args(): + p = argparse.ArgumentParser(description="ClimaX predict-only: compare resolutions side-by-side") + p.add_argument("--ckpt_5625", required=True, help="Path to 5.625° checkpoint") + p.add_argument("--ckpt_1406", required=True, help="Path to 1.40625° checkpoint") + p.add_argument("--device", choices=["mps","cpu"], default="mps") + p.add_argument("--out_vars", nargs="*", default=DEFAULT_OUT_VARS) + p.add_argument("--batch", type=int, default=1) + # Arch hyperparams (should match ckpts; defaults work for official ones) + p.add_argument("--embed_dim", type=int, default=1024) + p.add_argument("--depth", type=int, default=8) + p.add_argument("--decoder_depth", type=int, default=2) + p.add_argument("--num_heads", type=int, default=16) + p.add_argument("--mlp_ratio", 
type=float, default=4.0) + p.add_argument("--drop_path", type=float, default=0.1) + p.add_argument("--drop_rate", type=float, default=0.1) + return p.parse_args() + +def geo_for(res: str) -> Tuple[Tuple[int,int], int]: + return ((32,64), 2) if res == "5.625" else ((128,256), 4) + +def build_region_info(H: int, W: int, patch: int) -> Dict[str, np.ndarray]: + lat_vec = np.linspace(-90.0, 90.0, num=H, dtype=np.float32) + lon_vec = np.linspace(0.0, 360.0, num=W, endpoint=False, dtype=np.float32) + hp, wp = H // patch, W // patch + L = hp * wp + patch_ids = np.arange(L, dtype=np.int64) + return { + # full-frame ranges + "min_h": 0, "max_h": H, "min_w": 0, "max_w": W, + # per-pixel indices + "h_inds": np.arange(H, dtype=np.int64), + "w_inds": np.arange(W, dtype=np.int64), + # patch grid + "patch": patch, "hp": hp, "wp": wp, "patch_ids": patch_ids, + # metadata + "lat": lat_vec, "lon": lon_vec, + "north": 90.0, "south": -90.0, "west": 0.0, "east": 360.0, + "name": "full", + } + +def strip_prefixes(sd: dict) -> dict: + for p in ("model.","net.","module.","climax."): + if any(k.startswith(p) for k in sd): + return {(k[len(p):] if k.startswith(p) else k): v for k,v in sd.items()} # strip only keys that carry the prefix; leave the rest intact + return sd + +def load_checkpoint(path: str) -> dict: + ckpt = torch.load(path, map_location="cpu") + sd = ckpt.get("state_dict", ckpt) + return strip_prefixes(sd) + +class CaptureMetric: + """Metric that captures predictions and returns zero loss.""" + def __init__(self): + self.last_pred = None + def __call__(self, preds: torch.Tensor, y_true: torch.Tensor, out_vars, lat_vec: torch.Tensor): + self.last_pred = preds + return torch.tensor(0.0, device=preds.device, dtype=preds.dtype), {} + +def predict_one(ckpt_path: str, res: str, device: torch.device, out_vars: List[str], + embed_dim: int, depth: int, decoder_depth: int, num_heads: int, + mlp_ratio: float, drop_path: float, drop_rate: float, batch: int): + (H, W), patch = geo_for(res) + model = RegionalClimaX(DEFAULT_VARS, (H, W), patch, + embed_dim, depth, decoder_depth, 
num_heads, mlp_ratio, drop_path, drop_rate) + sd = load_checkpoint(ckpt_path) + missing, unexpected = model.load_state_dict(sd, strict=False) + print(f"[{res}] loaded | missing: {len(missing)} unexpected: {len(unexpected)}") + x = torch.randn(batch, len(DEFAULT_VARS), H, W, dtype=torch.float32, device=device) + y = torch.zeros(batch, len(out_vars), H, W, dtype=torch.float32, device=device) + lat = torch.linspace(-90.0, 90.0, steps=H, dtype=torch.float32, device=device) + lead_times = torch.tensor([1], dtype=torch.float32, device=device) + region_info = build_region_info(H, W, patch) + metric = CaptureMetric() + model.to(device).eval() + with torch.no_grad(): + _ = model(x, y, lead_times, DEFAULT_VARS, out_vars, [metric], lat, region_info) + preds = metric.last_pred # (B, V, H, W) + if preds is None: + raise RuntimeError("Metric didn’t receive predictions; forward signature mismatch.") + return preds.detach().cpu().numpy()[0], (H, W) # (V,H,W), (H,W) + +def plot_side_by_side(var_name: str, arr_lo: np.ndarray, arr_hi: np.ndarray, + shape_lo: Tuple[int,int], shape_hi: Tuple[int,int], out_path: str): + plt.figure(figsize=(10,4)) + plt.suptitle(f"{var_name} — ClimaX predictions (demo)") + + plt.subplot(1,2,1) + plt.title("5.625° (32×64)") + plt.imshow(arr_lo, origin="lower") # no explicit colormap/colors per your environment rules + plt.colorbar(shrink=0.8) + + plt.subplot(1,2,2) + plt.title("1.40625° (128×256)") + plt.imshow(arr_hi, origin="lower") + plt.colorbar(shrink=0.8) + + plt.tight_layout(rect=[0,0,1,0.95]) + plt.savefig(out_path, dpi=150) + plt.close() + +def main(): + args = parse_args() + device = torch.device("mps" if (args.device=="mps" and torch.backends.mps.is_available()) else "cpu") + + # Predict on both resolutions + preds_lo, shape_lo = predict_one( + args.ckpt_5625, "5.625", device, args.out_vars, + args.embed_dim, args.depth, args.decoder_depth, args.num_heads, + args.mlp_ratio, args.drop_path, args.drop_rate, args.batch + ) + preds_hi, shape_hi = 
predict_one( + args.ckpt_1406, "1.40625", device, args.out_vars, + args.embed_dim, args.depth, args.decoder_depth, args.num_heads, + args.mlp_ratio, args.drop_path, args.drop_rate, args.batch + ) + + # Save raw tensors (optional) + np.save("preds_5p625.npy", preds_lo) # (V, 32, 64) + np.save("preds_1p406.npy", preds_hi) # (V, 128, 256) + + # Side-by-side plots per variable + for i, v in enumerate(args.out_vars): + out_png = f"compare_{v.replace('/', '_')}.png" + plot_side_by_side(v, preds_lo[i], preds_hi[i], shape_lo, shape_hi, out_png) + print(f"saved {out_png}") + + print("Done.") + +if __name__ == "__main__": + main() diff --git a/output/README.md b/output/README.md new file mode 100644 index 0000000..6ee4248 --- /dev/null +++ b/output/README.md @@ -0,0 +1,129 @@ + +Steps: + +``` +brew install cmake protobuf rust python@3.10 git wget +pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu +pip install diffusers transformers accelerate safetensors +``` + + +https://dataserv.ub.tum.de/s/?dir=/5.625deg/2m_temperature + + +``` +python predict_only_climax.py \ + --ckpt ./data-ch/5.625deg.ckpt \ + --res 5.625 \ + --device mps +``` + + +``` +python demo_compare_resolutions.py \ + --ckpt_5625 ./data-ch/5.625deg.ckpt \ + --ckpt_1406 ./data-ch/1.40625deg.ckpt \ + --device mps \ + --out_vars geopotential_500 temperature_850 2m_temperature +``` + +Train: +``` +python src/climax/regional_forecast/train.py \ + --config configs/regional_forecast_climax.yaml \ + --trainer.callbacks=null \ + --trainer.accelerator=mps --trainer.devices=1 --trainer.precision=32-true \ + --data.root_dir=/mnt/data/5.625deg_npz \ + --data.region="NorthAmerica" \ + --data.predict_range=72 \ + --data.out_variables="['z_500','t_850','t2m']" \ + --data.batch_size=1 --data.num_workers=0 \ + --trainer.max_epochs=1 \ + --model.pretrained_path="./data-ch/1.40625deg.ckpt" \ + --model.lr=5e-7 --model.beta_1=0.9 --model.beta_2=0.99 \ + --model.weight_decay=1e-5 +``` + +Deps: 
+``` +absl-py==2.3.1 +aiohappyeyeballs==2.6.1 +aiohttp==3.12.15 +aiosignal==1.4.0 +antlr4-python3-runtime==4.9.3 +asciitree==0.3.3 +async-timeout==5.0.1 +attrs==25.3.0 +cdsapi==0.7.6 +certifi==2025.8.3 +cftime==1.6.4.post1 +charset-normalizer==3.4.3 +-e git+https://github.com/microsoft/ClimaX@6d5d354ffb4b91bb684f430b98e8f6f8af7c7f7c#egg=ClimaX +contourpy==1.3.2 +cycler==0.12.1 +docstring_parser==0.17.0 +ecmwf-datastores-client==0.4.0 +einops==0.8.1 +fasteners==0.20 +filelock==3.19.1 +fonttools==4.59.2 +frozenlist==1.7.0 +fsspec==2024.12.0 +grpcio==1.74.0 +hf-xet==1.1.9 +huggingface-hub==0.34.4 +hydra-core==1.3.2 +idna==3.10 +importlib_resources==6.5.2 +Jinja2==3.1.6 +jsonargparse==4.41.0 +kiwisolver==1.4.9 +lightning-utilities==0.15.2 +Markdown==3.9 +markdown-it-py==4.0.0 +MarkupSafe==3.0.2 +matplotlib==3.10.6 +mdurl==0.1.2 +mpmath==1.3.0 +multidict==6.6.4 +multiurl==0.3.7 +netCDF4==1.7.2 +networkx==3.4.2 +numcodecs==0.13.1 +numpy==1.26.4 +omegaconf==2.3.0 +packaging==24.2 +pillow==11.3.0 +propcache==0.3.2 +protobuf==6.32.0 +Pygments==2.19.2 +pyparsing==3.2.3 +python-dateutil==2.9.0.post0 +pytorch-lightning==2.1.4 +pytz==2025.2 +PyYAML==6.0.2 +requests==2.32.5 +rich==14.1.0 +safetensors==0.6.2 +scipy==1.15.3 +six==1.17.0 +sympy==1.14.0 +tensorboard==2.20.0 +tensorboard-data-server==0.7.2 +timm==0.6.13 +torch==2.2.2 +torchaudio==2.2.2 +torchdata==0.7.1 +torchmetrics==1.8.2 +torchvision==0.17.2 +tqdm==4.67.1 +typeshed_client==2.8.2 +typing_extensions==4.15.0 +tzdata==2025.2 +urllib3==2.5.0 +Werkzeug==3.1.3 +xarray==2025.6.1 +yarl==1.20.1 +zarr==2.18.3 + +``` diff --git a/output/compare_2m_temperature.png b/output/compare_2m_temperature.png new file mode 100644 index 0000000..ad956dd Binary files /dev/null and b/output/compare_2m_temperature.png differ diff --git a/output/compare_geopotential_500.png b/output/compare_geopotential_500.png new file mode 100644 index 0000000..f604a58 Binary files /dev/null and b/output/compare_geopotential_500.png differ diff --git 
a/output/compare_temperature_850.png b/output/compare_temperature_850.png new file mode 100644 index 0000000..0f4fc28 Binary files /dev/null and b/output/compare_temperature_850.png differ diff --git a/predict_only_climax.py b/predict_only_climax.py new file mode 100644 index 0000000..caabd16 --- /dev/null +++ b/predict_only_climax.py @@ -0,0 +1,149 @@ +#!/usr/bin/env python3 +import argparse, sys, numpy as np, torch +from typing import List + +# Regional architecture (nn.Module, not Lightning) +from climax.regional_forecast.arch import RegionalClimaX + +# ---- Variables list (order must match training!) ---- +DEFAULT_VARS = [ + "land_sea_mask","orography","lattitude", + "2m_temperature","10m_u_component_of_wind","10m_v_component_of_wind", + "geopotential_50","geopotential_250","geopotential_500", + "geopotential_600","geopotential_700","geopotential_850","geopotential_925", + "u_component_of_wind_50","u_component_of_wind_250","u_component_of_wind_500", + "u_component_of_wind_600","u_component_of_wind_700","u_component_of_wind_850","u_component_of_wind_925", + "v_component_of_wind_50","v_component_of_wind_250","v_component_of_wind_500", + "v_component_of_wind_600","v_component_of_wind_700","v_component_of_wind_850","v_component_of_wind_925", + "temperature_50","temperature_250","temperature_500","temperature_600", + "temperature_700","temperature_850","temperature_925", + "relative_humidity_50","relative_humidity_250","relative_humidity_500","relative_humidity_600", + "relative_humidity_700","relative_humidity_850","relative_humidity_925", + "specific_humidity_50","specific_humidity_250","specific_humidity_500","specific_humidity_600", + "specific_humidity_700","specific_humidity_850","specific_humidity_925", +] +DEFAULT_OUT_VARS = ["geopotential_500","temperature_850","2m_temperature"] + +def parse_args(): + p = argparse.ArgumentParser(description="ClimaX predict-only (capture preds via metric)") + p.add_argument("--ckpt", required=True) + p.add_argument("--res", 
choices=["5.625","1.40625"], default="5.625", + help="Picks img_size & patch_size to match ckpt") + p.add_argument("--device", choices=["mps","cpu"], default="cpu") + p.add_argument("--batch", type=int, default=1) + p.add_argument("--out_vars", nargs="*", default=DEFAULT_OUT_VARS) + # core arch hyperparams (must match ckpt) + p.add_argument("--embed_dim", type=int, default=1024) + p.add_argument("--depth", type=int, default=8) + p.add_argument("--decoder_depth", type=int, default=2) + p.add_argument("--num_heads", type=int, default=16) + p.add_argument("--mlp_ratio", type=float, default=4.0) + p.add_argument("--drop_path", type=float, default=0.1) + p.add_argument("--drop_rate", type=float, default=0.1) + return p.parse_args() + +def geometry_for_resolution(res: str): + if res == "5.625": # 32x64, patch=2 -> 512 tokens + return (32, 64), 2 + else: # 1.40625: 128x256, patch=4 -> 2048 tokens + return (128, 256), 4 + +def build_region_info(H: int, W: int, patch: int): + lat_vec = np.linspace(-90.0, 90.0, num=H, dtype=np.float32) + lon_vec = np.linspace(0.0, 360.0, num=W, endpoint=False, dtype=np.float32) + + hp, wp = H // patch, W // patch + L = hp * wp + patch_ids = np.arange(L, dtype=np.int64) # row-major [0..L-1] + + return { + # full-domain index ranges (pixel/grid space) + "min_h": 0, "max_h": H, # NOTE: max_* is exclusive in most forks + "min_w": 0, "max_w": W, + + # per-pixel indices (some code paths use these) + "h_inds": np.arange(H, dtype=np.int64), + "w_inds": np.arange(W, dtype=np.int64), + + # patch-level info (for token selection) + "patch": patch, + "hp": hp, "wp": wp, # patches along H and W + "patch_ids": patch_ids, # all patches in row-major order + + # geo metadata (often optional but safe to include) + "lat": lat_vec, + "lon": lon_vec, + "north": 90.0, "south": -90.0, "west": 0.0, "east": 360.0, + "name": "full", + } + + +def strip_prefixes(sd: dict) -> dict: + for p in ("model.","net.","module.","climax."): + if any(k.startswith(p) for k in sd): + 
return {(k[len(p):] if k.startswith(p) else k): v for k,v in sd.items()} # strip only keys that carry the prefix; leave the rest intact + return sd + +def load_checkpoint(ckpt_path: str) -> dict: + ckpt = torch.load(ckpt_path, map_location="cpu") + sd = ckpt.get("state_dict", ckpt) + return strip_prefixes(sd) + +class CaptureMetric: + """Metric callable that captures y_pred and returns zero loss.""" + def __init__(self): + self.last_pred = None + self.last_logs = {} + + def __call__(self, preds: torch.Tensor, y_true: torch.Tensor, + out_vars, lat_vec: torch.Tensor): + self.last_pred = preds + # return (loss, logs_dict) + return torch.tensor(0.0, device=preds.device, dtype=preds.dtype), {} + + +def main(): + args = parse_args() + + device = torch.device("mps" if (args.device=="mps" and torch.backends.mps.is_available()) else "cpu") + (H, W), patch = geometry_for_resolution(args.res) + + # Build model + model = RegionalClimaX( + DEFAULT_VARS, (H, W), patch, + args.embed_dim, args.depth, args.decoder_depth, + args.num_heads, args.mlp_ratio, args.drop_path, args.drop_rate, + ) + + # Load weights (non-strict is fine if heads differ) + sd = load_checkpoint(args.ckpt) + missing, unexpected = model.load_state_dict(sd, strict=False) + print(f"loaded state_dict | missing: {len(missing)} unexpected: {len(unexpected)}") + + # Dummy inputs (predict-only) + B = args.batch + x = torch.randn(B, len(DEFAULT_VARS), H, W, dtype=torch.float32, device=device) + y = torch.zeros(B, len(args.out_vars), H, W, dtype=torch.float32, device=device) # dummy target + lat = torch.linspace(-90.0, 90.0, steps=H, dtype=torch.float32, device=device) + lead_times = torch.tensor([1], dtype=torch.float32, device=device) # float avoids MPS Linear issues + region_info = build_region_info(H, W, patch) + + metric = CaptureMetric() + + model.to(device).eval() + with torch.no_grad(): + # Many RegionalClimaX versions accept region_info at the end of forward(...). 
+ # Signature we satisfy: (x, y, lead_times, variables, out_variables, metric, lat, region_info) + _ = model(x, y, lead_times, DEFAULT_VARS, args.out_vars, [metric], lat, region_info) + + + preds = metric.last_pred + if preds is None: + print("Metric did not receive predictions; your forward signature may differ.", file=sys.stderr) + sys.exit(2) + + print("preds:", tuple(preds.shape)) + print("OK") + +if __name__ == "__main__": + main() diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..f8f33ae --- /dev/null +++ b/requirements.txt @@ -0,0 +1,79 @@ +absl-py==2.3.1 +aiohappyeyeballs==2.6.1 +aiohttp==3.12.15 +aiosignal==1.4.0 +antlr4-python3-runtime==4.9.3 +asciitree==0.3.3 +async-timeout==5.0.1 +attrs==25.3.0 +cdsapi==0.7.6 +certifi==2025.8.3 +cftime==1.6.4.post1 +charset-normalizer==3.4.3 +-e git+https://github.com/microsoft/ClimaX@6d5d354ffb4b91bb684f430b98e8f6f8af7c7f7c#egg=ClimaX +contourpy==1.3.2 +cycler==0.12.1 +docstring_parser==0.17.0 +ecmwf-datastores-client==0.4.0 +einops==0.8.1 +fasteners==0.20 +filelock==3.19.1 +fonttools==4.59.2 +frozenlist==1.7.0 +fsspec==2024.12.0 +grpcio==1.74.0 +hf-xet==1.1.9 +huggingface-hub==0.34.4 +hydra-core==1.3.2 +idna==3.10 +importlib_resources==6.5.2 +Jinja2==3.1.6 +jsonargparse==4.41.0 +kiwisolver==1.4.9 +lightning-utilities==0.15.2 +Markdown==3.9 +markdown-it-py==4.0.0 +MarkupSafe==3.0.2 +matplotlib==3.10.6 +mdurl==0.1.2 +mpmath==1.3.0 +multidict==6.6.4 +multiurl==0.3.7 +netCDF4==1.7.2 +networkx==3.4.2 +numcodecs==0.13.1 +numpy==1.26.4 +omegaconf==2.3.0 +packaging==24.2 +pillow==11.3.0 +propcache==0.3.2 +protobuf==6.32.0 +Pygments==2.19.2 +pyparsing==3.2.3 +python-dateutil==2.9.0.post0 +pytorch-lightning==2.1.4 +pytz==2025.2 +PyYAML==6.0.2 +requests==2.32.5 +rich==14.1.0 +safetensors==0.6.2 +scipy==1.15.3 +six==1.17.0 +sympy==1.14.0 +tensorboard==2.20.0 +tensorboard-data-server==0.7.2 +timm==0.6.13 +torch==2.2.2 +torchaudio==2.2.2 +torchdata==0.7.1 +torchmetrics==1.8.2 +torchvision==0.17.2 
+tqdm==4.67.1 +typeshed_client==2.8.2 +typing_extensions==4.15.0 +tzdata==2025.2 +urllib3==2.5.0 +Werkzeug==3.1.3 +xarray==2025.6.1 +yarl==1.20.1 +zarr==2.18.3 diff --git a/src/climax/regional_forecast/train.py b/src/climax/regional_forecast/train.py index 4080449..1756027 100644 --- a/src/climax/regional_forecast/train.py +++ b/src/climax/regional_forecast/train.py @@ -2,25 +2,41 @@ # Licensed under the MIT license. import os +from inspect import signature from climax.regional_forecast.datamodule import RegionalForecastDataModule from climax.regional_forecast.module import RegionalForecastModule -from pytorch_lightning.cli import LightningCLI + +# Prefer pytorch_lightning (the config's callback class_paths target it; mixing in lightning.pytorch classes fails type checks); fall back only if it is absent +try: + from pytorch_lightning.cli import LightningCLI, SaveConfigCallback +except ImportError: # pragma: no cover + from lightning.pytorch.cli import LightningCLI, SaveConfigCallback def main(): - # Initialize Lightning with the model and data modules, and instruct it to parse the config yml - cli = LightningCLI( + # Build kwargs in a way that works across Lightning versions + cli_sig = signature(LightningCLI.__init__) + cli_kwargs = dict( model_class=RegionalForecastModule, datamodule_class=RegionalForecastDataModule, seed_everything_default=42, - save_config_overwrite=True, + # Replaces deprecated `save_config_overwrite=...` + save_config_callback=SaveConfigCallback, + save_config_kwargs={"overwrite": True}, run=False, - auto_registry=True, parser_kwargs={"parser_mode": "omegaconf", "error_handler": None}, + # trainer_defaults can stay as-is if you add them later ) + if "auto_registry" in cli_sig.parameters: + cli_kwargs["auto_registry"] = True + + # Initialize Lightning with the model and data modules + cli = LightningCLI(**cli_kwargs) + os.makedirs(cli.trainer.default_root_dir, exist_ok=True) + # Wire datamodule ↔ model shapes & normalization cli.datamodule.set_patch_size(cli.model.get_patch_size()) normalization = cli.datamodule.output_transforms @@ -32,10 +48,8 @@ def 
main(): cli.model.set_val_clim(cli.datamodule.val_clim) cli.model.set_test_clim(cli.datamodule.test_clim) - # fit() runs the training + # Train & test cli.trainer.fit(cli.model, datamodule=cli.datamodule) - - # test the trained model cli.trainer.test(cli.model, datamodule=cli.datamodule, ckpt_path="best") diff --git a/src/climax/utils/pos_embed.py b/src/climax/utils/pos_embed.py index 707ce4b..4537649 100644 --- a/src/climax/utils/pos_embed.py +++ b/src/climax/utils/pos_embed.py @@ -24,8 +24,8 @@ def get_2d_sincos_pos_embed(embed_dim, grid_size_h, grid_size_w, cls_token=False return: pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token) """ - grid_h = np.arange(grid_size_h, dtype=np.float32) - grid_w = np.arange(grid_size_w, dtype=np.float32) + grid_h = np.arange(grid_size_h, dtype=float) + grid_w = np.arange(grid_size_w, dtype=float) grid = np.meshgrid(grid_w, grid_h) # here w goes first grid = np.stack(grid, axis=0) @@ -54,7 +54,7 @@ def get_1d_sincos_pos_embed_from_grid(embed_dim, pos): out: (M, D) """ assert embed_dim % 2 == 0 - omega = np.arange(embed_dim // 2, dtype=np.float) + omega = np.arange(embed_dim // 2, dtype=float) omega /= embed_dim / 2.0 omega = 1.0 / 10000**omega # (D/2,)