Commit 15367de

Fixing model loading logic (names and cache dir) for fairchem models (#278)
Signed-off-by: Niklas Hölter <[email protected]>
1 parent a8a4bbb · commit 15367de

2 files changed: +44 −1 lines changed

tests/models/test_fairchem.py

Lines changed: 25 additions & 0 deletions
@@ -12,6 +12,9 @@
 from collections.abc import Callable

 from ase.build import bulk, fcc100, molecule
+from fairchem.core.calculate.pretrained_mlip import (
+    pretrained_checkpoint_path_from_name,
+)
 from huggingface_hub.utils._auth import get_token

 import torch_sim as ts
@@ -213,6 +216,28 @@ def test_empty_batch_error() -> None:
         model(ts.io.atoms_to_state([], device=torch.device("cpu"), dtype=torch.float32))


+@pytest.mark.skipif(
+    get_token() is None, reason="Requires HuggingFace authentication for UMA model access"
+)
+def test_load_from_checkpoint_path(device: torch.device, dtype: torch.dtype) -> None:
+    """Test loading model from a saved checkpoint file path."""
+    checkpoint_path = pretrained_checkpoint_path_from_name("uma-s-1")
+    loaded_model = FairChemModel(
+        model=str(checkpoint_path), task_name="omat", cpu=device.type == "cpu"
+    )
+
+    # Verify the loaded model works
+    system = bulk("Si", "diamond", a=5.43)
+    state = ts.io.atoms_to_state([system], device=device, dtype=dtype)
+    results = loaded_model(state)
+
+    assert "energy" in results
+    assert "forces" in results
+    assert results["energy"].shape == (1,)
+    assert torch.isfinite(results["energy"]).all()
+    assert torch.isfinite(results["forces"]).all()
+
+
 test_fairchem_uma_model_outputs = pytest.mark.skipif(
     get_token() is None,
     reason="Requires HuggingFace authentication for UMA model access",

torch_sim/models/fairchem.py

Lines changed: 19 additions & 1 deletion
@@ -8,6 +8,7 @@

 from __future__ import annotations

+import os
 import traceback
 import typing
 import warnings
@@ -74,6 +75,7 @@ def __init__(
         neighbor_list_fn: Callable | None = None,
         *,  # force remaining arguments to be keyword-only
         model_name: str | None = None,
+        model_cache_dir: str | Path | None = None,
         cpu: bool = False,
         dtype: torch.dtype | None = None,
         compute_stress: bool = False,
@@ -86,6 +88,7 @@ def __init__(
             neighbor_list_fn (Callable | None): Function to compute neighbor lists
                 (not currently supported)
             model_name (str | None): Name of pretrained model to load
+            model_cache_dir (str | Path | None): Path where to save the model
             cpu (bool): Whether to use CPU instead of GPU for computation
             dtype (torch.dtype | None): Data type to use for computation
             compute_stress (bool): Whether to compute stress tensor
@@ -132,7 +135,22 @@ def __init__(
         self.task_name = task_name

         # Create efficient batch predictor for fast inference
-        self.predictor = pretrained_mlip.get_predict_unit(str(model), device=device_str)
+        if model in pretrained_mlip.available_models:
+            if model_cache_dir and model_cache_dir.exists():
+                self.predictor = pretrained_mlip.get_predict_unit(
+                    model, device=device_str, cache_dir=model_cache_dir
+                )
+            else:
+                self.predictor = pretrained_mlip.get_predict_unit(
+                    model, device=device_str
+                )
+        elif os.path.isfile(model):
+            self.predictor = pretrained_mlip.load_predict_unit(model, device=device_str)
+        else:
+            raise ValueError(
+                f"Invalid model name or checkpoint path: {model}. "
+                f"Available pretrained models are: {pretrained_mlip.available_models}"
+            )

         # Determine implemented properties
         # This is a simplified approach - in practice you might want to
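With this change, the model argument dispatches on its value: a name listed in pretrained_mlip.available_models goes through get_predict_unit (optionally with a cache directory), an existing file path goes through load_predict_unit, and anything else raises a ValueError. A minimal usage sketch under those assumptions (paths are illustrative); note that the new branch calls model_cache_dir.exists(), so passing a Path that already exists is the safe choice:

# Sketch of the three dispatch paths after this commit.
# Model names come from the diff; file paths below are illustrative only.
from pathlib import Path

from torch_sim.models.fairchem import FairChemModel

# 1) Pretrained name + explicit cache dir: honored only if the directory
#    exists, since the new code checks model_cache_dir.exists() first.
cache_dir = Path("./fairchem_cache")
cache_dir.mkdir(parents=True, exist_ok=True)
model = FairChemModel(model="uma-s-1", task_name="omat", model_cache_dir=cache_dir)

# 2) Pretrained name with the default cache location.
model = FairChemModel(model="uma-s-1", task_name="omat")

# 3) Path to an already-downloaded checkpoint file (illustrative path).
model = FairChemModel(model="/path/to/uma-s-1.pt", task_name="omat")

# Any other value raises ValueError listing pretrained_mlip.available_models.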
