diff --git a/tests/models/test_fairchem.py b/tests/models/test_fairchem.py
index 3c6034c2..2518e8d9 100644
--- a/tests/models/test_fairchem.py
+++ b/tests/models/test_fairchem.py
@@ -10,6 +10,9 @@ from collections.abc import Callable
 
 from ase.build import bulk, fcc100, molecule
+from fairchem.core.calculate.pretrained_mlip import (
+    pretrained_checkpoint_path_from_name,
+)
 from huggingface_hub.utils._auth import get_token
 
 import torch_sim as ts
@@ -222,6 +225,28 @@ def test_empty_batch_error() -> None:
         model(ts.io.atoms_to_state([], device="cpu", dtype=torch.float32))
 
 
+@pytest.mark.skipif(
+    get_token() is None, reason="Requires HuggingFace authentication for UMA model access"
+)
+def test_load_from_checkpoint_path(device: torch.device, dtype: torch.dtype) -> None:
+    """Test loading a model from a saved checkpoint file path."""
+    checkpoint_path = pretrained_checkpoint_path_from_name("uma-s-1")
+    loaded_model = FairChemModel(
+        model=str(checkpoint_path), task_name="omat", cpu=device.type == "cpu"
+    )
+
+    # Verify the loaded model produces finite, correctly shaped outputs
+    system = bulk("Si", "diamond", a=5.43)
+    state = ts.io.atoms_to_state([system], device=device, dtype=dtype)
+    results = loaded_model(state)
+
+    assert "energy" in results
+    assert "forces" in results
+    assert results["energy"].shape == (1,)
+    assert torch.isfinite(results["energy"]).all()
+    assert torch.isfinite(results["forces"]).all()
+
+
 test_fairchem_uma_model_outputs = pytest.mark.skipif(
     get_token() is None,
     reason="Requires HuggingFace authentication for UMA model access",
diff --git a/torch_sim/models/fairchem.py b/torch_sim/models/fairchem.py
index 3aae3f3a..3d7648f8 100644
--- a/torch_sim/models/fairchem.py
+++ b/torch_sim/models/fairchem.py
@@ -8,6 +8,7 @@
 
 from __future__ import annotations
 
+import os
 import traceback
 import typing
 import warnings
@@ -74,6 +75,7 @@ def __init__(
         neighbor_list_fn: Callable | None = None,
         *,  # force remaining arguments to be keyword-only
         model_name: str | None = None,
+        model_cache_dir: str | Path | None = None,
         cpu: bool = False,
         dtype: torch.dtype | None = None,
         compute_stress: bool = False,
@@ -86,6 +88,7 @@
             neighbor_list_fn (Callable | None): Function to compute neighbor lists
                 (not currently supported)
             model_name (str | None): Name of pretrained model to load
+            model_cache_dir (str | Path | None): Directory for caching the downloaded model
             cpu (bool): Whether to use CPU instead of GPU for computation
             dtype (torch.dtype | None): Data type to use for computation
             compute_stress (bool): Whether to compute stress tensor
@@ -132,7 +135,22 @@
         self.task_name = task_name
 
         # Create efficient batch predictor for fast inference
-        self.predictor = pretrained_mlip.get_predict_unit(str(model), device=device_str)
+        if model in pretrained_mlip.available_models:
+            if model_cache_dir and os.path.isdir(model_cache_dir):
+                self.predictor = pretrained_mlip.get_predict_unit(
+                    model, device=device_str, cache_dir=model_cache_dir
+                )
+            else:
+                self.predictor = pretrained_mlip.get_predict_unit(
+                    model, device=device_str
+                )
+        elif os.path.isfile(model):
+            self.predictor = pretrained_mlip.load_predict_unit(model, device=device_str)
+        else:
+            raise ValueError(
+                f"Invalid model name or checkpoint path: {model}. "
+                f"Available pretrained models are: {pretrained_mlip.available_models}"
+            )
 
         # Determine implemented properties
         # This is a simplified approach - in practice you might want to
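
A minimal usage sketch of the two loading paths this change enables, assuming `FairChemModel` is imported from `torch_sim.models.fairchem`, HuggingFace authentication for the `uma-s-1` checkpoint (as in the tests above), and a purely illustrative cache directory:

```python
from pathlib import Path

import torch
from ase.build import bulk
from fairchem.core.calculate.pretrained_mlip import (
    pretrained_checkpoint_path_from_name,
)

import torch_sim as ts
from torch_sim.models.fairchem import FairChemModel

# Path 1: a pretrained model name. When model_cache_dir points at an
# existing directory, it is forwarded to fairchem's get_predict_unit as
# cache_dir; otherwise the default cache location is used.
cache_dir = Path.home() / ".cache" / "fairchem"  # illustrative location
model = FairChemModel(
    model="uma-s-1", task_name="omat", cpu=True, model_cache_dir=cache_dir
)

# Path 2: an explicit checkpoint file, routed through
# pretrained_mlip.load_predict_unit instead of get_predict_unit.
checkpoint_path = pretrained_checkpoint_path_from_name("uma-s-1")
model_from_file = FairChemModel(model=str(checkpoint_path), task_name="omat", cpu=True)

# Both models evaluate a batched state the same way.
state = ts.io.atoms_to_state(
    [bulk("Si", "diamond", a=5.43)], device=torch.device("cpu"), dtype=torch.float32
)
results = model_from_file(state)
print(results["energy"].shape, results["forces"].shape)  # (1,), (n_atoms, 3)

# Anything that is neither a pretrained name nor an existing file now
# raises a ValueError listing pretrained_mlip.available_models.
```

Note that passing a `model_cache_dir` that does not exist silently falls back to the default cache rather than raising, so the argument stays optional in behavior as well as in the signature.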