Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 25 additions & 0 deletions tests/models/test_fairchem.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,9 @@
from collections.abc import Callable

from ase.build import bulk, fcc100, molecule
from fairchem.core.calculate.pretrained_mlip import (
pretrained_checkpoint_path_from_name,
)
from huggingface_hub.utils._auth import get_token

import torch_sim as ts
Expand Down Expand Up @@ -222,6 +225,28 @@ def test_empty_batch_error() -> None:
model(ts.io.atoms_to_state([], device="cpu", dtype=torch.float32))


@pytest.mark.skipif(
get_token() is None, reason="Requires HuggingFace authentication for UMA model access"
)
def test_load_from_checkpoint_path(device: torch.device, dtype: torch.dtype) -> None:
"""Test loading model from a saved checkpoint file path."""
checkpoint_path = pretrained_checkpoint_path_from_name("uma-s-1")
loaded_model = FairChemModel(
model=str(checkpoint_path), task_name="omat", cpu=device.type == "cpu"
)

# Verify the loaded model works
system = bulk("Si", "diamond", a=5.43)
state = ts.io.atoms_to_state([system], device=device, dtype=dtype)
results = loaded_model(state)

assert "energy" in results
assert "forces" in results
assert results["energy"].shape == (1,)
assert torch.isfinite(results["energy"]).all()
assert torch.isfinite(results["forces"]).all()


test_fairchem_uma_model_outputs = pytest.mark.skipif(
get_token() is None,
reason="Requires HuggingFace authentication for UMA model access",
Expand Down
20 changes: 19 additions & 1 deletion torch_sim/models/fairchem.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

from __future__ import annotations

import os
import traceback
import typing
import warnings
Expand Down Expand Up @@ -74,6 +75,7 @@ def __init__(
neighbor_list_fn: Callable | None = None,
*, # force remaining arguments to be keyword-only
model_name: str | None = None,
model_cache_dir: str | Path | None = None,
cpu: bool = False,
dtype: torch.dtype | None = None,
compute_stress: bool = False,
Expand All @@ -86,6 +88,7 @@ def __init__(
neighbor_list_fn (Callable | None): Function to compute neighbor lists
(not currently supported)
model_name (str | None): Name of pretrained model to load
model_cache_dir (str | Path | None): Path where to save the model
cpu (bool): Whether to use CPU instead of GPU for computation
dtype (torch.dtype | None): Data type to use for computation
compute_stress (bool): Whether to compute stress tensor
Expand Down Expand Up @@ -132,7 +135,22 @@ def __init__(
self.task_name = task_name

# Create efficient batch predictor for fast inference
self.predictor = pretrained_mlip.get_predict_unit(str(model), device=device_str)
if model in pretrained_mlip.available_models:
if model_cache_dir and model_cache_dir.exists():
self.predictor = pretrained_mlip.get_predict_unit(
model, device=device_str, cache_dir=model_cache_dir
)
else:
self.predictor = pretrained_mlip.get_predict_unit(
model, device=device_str
)
elif os.path.isfile(model):
self.predictor = pretrained_mlip.load_predict_unit(model, device=device_str)
else:
raise ValueError(
f"Invalid model name or checkpoint path: {model}. "
f"Available pretrained models are: {pretrained_mlip.available_models}"
)

# Determine implemented properties
# This is a simplified approach - in practice you might want to
Expand Down
Loading