From a139bca6dd027853cec3c312b49b46218a353516 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Niklas=20H=C3=B6lter?= <83964137+niklashoelter@users.noreply.github.com> Date: Tue, 7 Oct 2025 13:33:50 +0200 Subject: [PATCH 1/4] Fixing model loading logic (names and cache dir) for fairchem models MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit previously: model (str | Path | None) was passed to fairchem's get_predict_unit - although this is expecting a model name. When loading a cached model, the cache_dir parameter needs to be set in addition to a valid model_name (one of the available ones). This is fixed here. Signed-off-by: Niklas Hölter <83964137+niklashoelter@users.noreply.github.com> --- torch_sim/models/fairchem.py | 30 ++++++++++++++---------------- 1 file changed, 14 insertions(+), 16 deletions(-) diff --git a/torch_sim/models/fairchem.py b/torch_sim/models/fairchem.py index 3aae3f3a..03260c52 100644 --- a/torch_sim/models/fairchem.py +++ b/torch_sim/models/fairchem.py @@ -70,10 +70,10 @@ class FairChemModel(ModelInterface): def __init__( self, - model: str | Path | None, + model_name: str, neighbor_list_fn: Callable | None = None, *, # force remaining arguments to be keyword-only - model_name: str | None = None, + model_cache_dir: str | Path | None = None, cpu: bool = False, dtype: torch.dtype | None = None, compute_stress: bool = False, @@ -82,10 +82,10 @@ def __init__( """Initialize the FairChem model. 
Args: - model (str | Path | None): Path to model checkpoint file + model_name (str): Name of pretrained model to load neighbor_list_fn (Callable | None): Function to compute neighbor lists (not currently supported) - model_name (str | None): Name of pretrained model to load + model_cache_dir (str | Path | None): Path where to save the model cpu (bool): Whether to use CPU instead of GPU for computation dtype (torch.dtype | None): Data type to use for computation compute_stress (bool): Whether to compute stress tensor @@ -111,16 +111,8 @@ def __init__( "Custom neighbor list is not supported for FairChemModel." ) - if model_name is not None: - if model is not None: - raise RuntimeError( - "model_name and checkpoint_path were both specified, " - "please use only one at a time" - ) - model = model_name - - if model is None: - raise ValueError("Either model or model_name must be provided") + if model_name is None: + raise ValueError("Valid fairchem model name needs to be specified") # Convert task_name to UMATask if it's a string (only for UMA models) if isinstance(task_name, str): @@ -132,8 +124,14 @@ def __init__( self.task_name = task_name # Create efficient batch predictor for fast inference - self.predictor = pretrained_mlip.get_predict_unit(str(model), device=device_str) - + if model_cache_dir: + if model_cache_dir.exists(): + self.predictor = pretrained_mlip.get_predict_unit(str(model_name), device=device_str, cache_dir=model_cache_dir) + else: + raise ValueError("Specified cache dir does not exist!") + else: + self.predictor = pretrained_mlip.get_predict_unit(str(model_name), device=device_str) + # Determine implemented properties # This is a simplified approach - in practice you might want to # inspect the model configuration more carefully From 9ff6812c2d4e5f2f325bb1bcbded55b4805bfacc Mon Sep 17 00:00:00 2001 From: niklashoelter Date: Wed, 8 Oct 2025 10:53:14 +0200 Subject: [PATCH 2/4] adapted model loading procedure according to fairchem --- 
torch_sim/models/fairchem.py | 25 +++++++++++++++++-------- 1 file changed, 17 insertions(+), 8 deletions(-) diff --git a/torch_sim/models/fairchem.py b/torch_sim/models/fairchem.py index 03260c52..df69b9bd 100644 --- a/torch_sim/models/fairchem.py +++ b/torch_sim/models/fairchem.py @@ -70,7 +70,7 @@ class FairChemModel(ModelInterface): def __init__( self, - model_name: str, + model: str, neighbor_list_fn: Callable | None = None, *, # force remaining arguments to be keyword-only model_cache_dir: str | Path | None = None, @@ -111,8 +111,8 @@ def __init__( "Custom neighbor list is not supported for FairChemModel." ) - if model_name is None: - raise ValueError("Valid fairchem model name needs to be specified") + if model is None: + raise ValueError("Valid fairchem model name or path needs to be specified") # Convert task_name to UMATask if it's a string (only for UMA models) if isinstance(task_name, str): @@ -124,13 +124,22 @@ def __init__( self.task_name = task_name # Create efficient batch predictor for fast inference - if model_cache_dir: - if model_cache_dir.exists(): - self.predictor = pretrained_mlip.get_predict_unit(str(model_name), device=device_str, cache_dir=model_cache_dir) + if model in pretrained_mlip.available_models: + if model_cache_dir and model_cache_dir.exists(): + self.predictor = pretrained_mlip.get_predict_unit( + model, + device=device_str, + cache_dir=model_cache_dir + ) else: - raise ValueError("Specified cache dir does not exist!") + self.predictor = pretrained_mlip.get_predict_unit(model, device=device_str) + elif os.path.isfile(model): + self.predictor = pretrained_mlip.load_predict_unit( + model, + device=device_str + ) else: - self.predictor = pretrained_mlip.get_predict_unit(str(model_name), device=device_str) + raise ValueError(f"Invalid model name or checkpoint path: {model}. 
Available pretrained models are: {pretrained_mlip.available_models}") # Determine implemented properties # This is a simplified approach - in practice you might want to From 37a10ba7a9246bdf193a9465c51061915cccd94a Mon Sep 17 00:00:00 2001 From: Rhys Goodall Date: Wed, 8 Oct 2025 17:18:42 -0400 Subject: [PATCH 3/4] fix the loading from a checkpoint and add extra tests --- tests/models/test_fairchem.py | 25 +++++++++++++++++++++++ torch_sim/models/fairchem.py | 37 +++++++++++++++++++++++------------ 2 files changed, 49 insertions(+), 13 deletions(-) diff --git a/tests/models/test_fairchem.py b/tests/models/test_fairchem.py index 3c6034c2..2518e8d9 100644 --- a/tests/models/test_fairchem.py +++ b/tests/models/test_fairchem.py @@ -10,6 +10,9 @@ from collections.abc import Callable from ase.build import bulk, fcc100, molecule + from fairchem.core.calculate.pretrained_mlip import ( + pretrained_checkpoint_path_from_name, + ) from huggingface_hub.utils._auth import get_token import torch_sim as ts @@ -222,6 +225,28 @@ def test_empty_batch_error() -> None: model(ts.io.atoms_to_state([], device="cpu", dtype=torch.float32)) +@pytest.mark.skipif( + get_token() is None, reason="Requires HuggingFace authentication for UMA model access" +) +def test_load_from_checkpoint_path(device: torch.device, dtype: torch.dtype) -> None: + """Test loading model from a saved checkpoint file path.""" + checkpoint_path = pretrained_checkpoint_path_from_name("uma-s-1") + loaded_model = FairChemModel( + model=str(checkpoint_path), task_name="omat", cpu=device.type == "cpu" + ) + + # Verify the loaded model works + system = bulk("Si", "diamond", a=5.43) + state = ts.io.atoms_to_state([system], device=device, dtype=dtype) + results = loaded_model(state) + + assert "energy" in results + assert "forces" in results + assert results["energy"].shape == (1,) + assert torch.isfinite(results["energy"]).all() + assert torch.isfinite(results["forces"]).all() + + test_fairchem_uma_model_outputs = 
pytest.mark.skipif( get_token() is None, reason="Requires HuggingFace authentication for UMA model access", diff --git a/torch_sim/models/fairchem.py b/torch_sim/models/fairchem.py index df69b9bd..5aa441b3 100644 --- a/torch_sim/models/fairchem.py +++ b/torch_sim/models/fairchem.py @@ -8,6 +8,7 @@ from __future__ import annotations +import os import traceback import typing import warnings @@ -70,9 +71,10 @@ class FairChemModel(ModelInterface): def __init__( self, - model: str, + model: str | Path | None, neighbor_list_fn: Callable | None = None, *, # force remaining arguments to be keyword-only + model_name: str | None = None, model_cache_dir: str | Path | None = None, cpu: bool = False, dtype: torch.dtype | None = None, @@ -82,9 +84,10 @@ def __init__( """Initialize the FairChem model. Args: - model_name (str): Name of pretrained model to load + model (str | Path): Path to model checkpoint file neighbor_list_fn (Callable | None): Function to compute neighbor lists (not currently supported) + model_name (str | None): Name of pretrained model to load model_cache_dir (str | Path | None): Path where to save the model cpu (bool): Whether to use CPU instead of GPU for computation dtype (torch.dtype | None): Data type to use for computation @@ -111,8 +114,16 @@ def __init__( "Custom neighbor list is not supported for FairChemModel." 
) + if model_name is not None: + if model is not None: + raise RuntimeError( + "model_name and checkpoint_path were both specified, " + "please use only one at a time" + ) + model = model_name + if model is None: - raise ValueError("Valid fairchem model name or path needs to be specified") + raise ValueError("Either model or model_name must be provided") # Convert task_name to UMATask if it's a string (only for UMA models) if isinstance(task_name, str): @@ -127,20 +138,20 @@ def __init__( if model in pretrained_mlip.available_models: if model_cache_dir and model_cache_dir.exists(): self.predictor = pretrained_mlip.get_predict_unit( - model, - device=device_str, - cache_dir=model_cache_dir + model, device=device_str, cache_dir=model_cache_dir ) else: - self.predictor = pretrained_mlip.get_predict_unit(model, device=device_str) + self.predictor = pretrained_mlip.get_predict_unit( + model, device=device_str + ) elif os.path.isfile(model): - self.predictor = pretrained_mlip.load_predict_unit( - model, - device=device_str - ) + self.predictor = pretrained_mlip.load_predict_unit(model, device=device_str) else: - raise ValueError(f"Invalid model name or checkpoint path: {model}. Available pretrained models are: {pretrained_mlip.available_models}") - + raise ValueError( + f"Invalid model name or checkpoint path: {model}. 
" + f"Available pretrained models are: {pretrained_mlip.available_models}" + ) + # Determine implemented properties # This is a simplified approach - in practice you might want to # inspect the model configuration more carefully From 45c67220c46610970063d42f98f95723d2c2bd4c Mon Sep 17 00:00:00 2001 From: Rhys Goodall Date: Wed, 8 Oct 2025 17:21:06 -0400 Subject: [PATCH 4/4] doc: fix docstring --- torch_sim/models/fairchem.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torch_sim/models/fairchem.py b/torch_sim/models/fairchem.py index 5aa441b3..3d7648f8 100644 --- a/torch_sim/models/fairchem.py +++ b/torch_sim/models/fairchem.py @@ -84,7 +84,7 @@ def __init__( """Initialize the FairChem model. Args: - model (str | Path): Path to model checkpoint file + model (str | Path | None): Path to model checkpoint file neighbor_list_fn (Callable | None): Function to compute neighbor lists (not currently supported) model_name (str | None): Name of pretrained model to load