88
99from __future__ import annotations
1010
11+ import os
1112import traceback
1213import typing
1314import warnings
@@ -74,6 +75,7 @@ def __init__(
7475 neighbor_list_fn : Callable | None = None ,
7576 * , # force remaining arguments to be keyword-only
7677 model_name : str | None = None ,
78+ model_cache_dir : str | Path | None = None ,
7779 cpu : bool = False ,
7880 dtype : torch .dtype | None = None ,
7981 compute_stress : bool = False ,
@@ -86,6 +88,7 @@ def __init__(
8688 neighbor_list_fn (Callable | None): Function to compute neighbor lists
8789 (not currently supported)
8890 model_name (str | None): Name of pretrained model to load
91+ model_cache_dir (str | Path | None): Path where to save the model
8992 cpu (bool): Whether to use CPU instead of GPU for computation
9093 dtype (torch.dtype | None): Data type to use for computation
9194 compute_stress (bool): Whether to compute stress tensor
@@ -132,7 +135,22 @@ def __init__(
132135 self .task_name = task_name
133136
134137 # Create efficient batch predictor for fast inference
135- self .predictor = pretrained_mlip .get_predict_unit (str (model ), device = device_str )
138+ if model in pretrained_mlip .available_models :
139+ if model_cache_dir and model_cache_dir .exists ():
140+ self .predictor = pretrained_mlip .get_predict_unit (
141+ model , device = device_str , cache_dir = model_cache_dir
142+ )
143+ else :
144+ self .predictor = pretrained_mlip .get_predict_unit (
145+ model , device = device_str
146+ )
147+ elif os .path .isfile (model ):
148+ self .predictor = pretrained_mlip .load_predict_unit (model , device = device_str )
149+ else :
150+ raise ValueError (
151+ f"Invalid model name or checkpoint path: { model } . "
152+ f"Available pretrained models are: { pretrained_mlip .available_models } "
153+ )
136154
137155 # Determine implemented properties
138156 # This is a simplified approach - in practice you might want to
0 commit comments