diff --git a/pyproject.toml b/pyproject.toml index d3e6c9f9..eb463285 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -100,7 +100,7 @@ ignore = [ "FIX002", # Line contains TODO, consider resolving the issue "N803", # Variable name should be lowercase "N806", # Uppercase letters in variable names - "PLC0415", # import` should be at the top-level of a file + "PLC0415", # import should be at the top-level of a file "PLR0912", # too many branches "PLR0913", # too many function arguments "PLR2004", # Magic value used in comparison, consider replacing {value} with a constant variable diff --git a/tests/models/test_nequip_framework.py b/tests/models/test_nequip_framework.py index 33a43cc2..e45e9d35 100644 --- a/tests/models/test_nequip_framework.py +++ b/tests/models/test_nequip_framework.py @@ -1,16 +1,19 @@ import traceback -import urllib.request -from enum import StrEnum from pathlib import Path import pytest -from tests.conftest import DEVICE -from tests.models.conftest import make_model_calculator_consistency_test +from tests.conftest import DEVICE, DTYPE +from tests.models.conftest import ( + consistency_test_simstate_fixtures, + make_model_calculator_consistency_test, + make_validate_model_outputs_test, +) try: from nequip.ase import NequIPCalculator + from nequip.scripts.compile import main from torch_sim.models.nequip_framework import ( NequIPFrameworkModel, @@ -22,29 +25,34 @@ ) -class NequIPUrls(StrEnum): - """Checkpoint download URLs for NequIP models.""" - - Si = "https://github.com/abhijeetgangan/pt_model_checkpoints/raw/refs/heads/main/nequip/Si.nequip.pth" - - @pytest.fixture(scope="session") -def model_path_nequip(tmp_path_factory: pytest.TempPathFactory) -> Path: - tmp_path = tmp_path_factory.mktemp("nequip_checkpoints") - model_name = "Si.nequip.pth" - model_path = Path(tmp_path) / model_name - - if not model_path.is_file(): - urllib.request.urlretrieve(NequIPUrls.Si, model_path) # noqa: S310 +def compiled_nequip_model_path(tmp_path_factory: pytest.TempPathFactory) -> Path: + """Compile NequIP OAM-L model from nequip.net.""" + tmp_path = tmp_path_factory.mktemp("nequip_compiled") + output_model_name = "mir-group__NequIP-OAM-L__0.1.nequip.pt2" + output_path = Path(tmp_path) / output_model_name + + main( + args=[ + "nequip.net:mir-group/NequIP-OAM-L:0.1", + str(output_path), + "--mode", + "aotinductor", + "--device", + DEVICE.type, + "--target", + "ase", + ] + ) - return model_path + return output_path -@pytest.fixture -def nequip_model(model_path_nequip: Path) -> NequIPFrameworkModel: +@pytest.fixture(scope="session") +def nequip_model(compiled_nequip_model_path: Path) -> NequIPFrameworkModel: """Create an NequIPModel wrapper for the pretrained model.""" compiled_model, (r_max, type_names) = from_compiled_model( - model_path_nequip, device=DEVICE + compiled_nequip_model_path, device=DEVICE ) return NequIPFrameworkModel( model=compiled_model, @@ -54,16 +62,18 @@ def nequip_model(model_path_nequip: Path) -> NequIPFrameworkModel: ) -@pytest.fixture -def nequip_calculator(model_path_nequip: Path) -> NequIPCalculator: +@pytest.fixture(scope="session") +def nequip_calculator(compiled_nequip_model_path: Path) -> NequIPCalculator: """Create an NequIPCalculator for the pretrained model.""" - return NequIPCalculator.from_compiled_model(str(model_path_nequip), device=DEVICE) + return NequIPCalculator.from_compiled_model( + str(compiled_nequip_model_path), device=DEVICE + ) -def test_nequip_initialization(model_path_nequip: Path) -> None: +def test_nequip_initialization(compiled_nequip_model_path: Path) -> None: """Test that the NequIP model initializes correctly.""" compiled_model, (r_max, type_names) = from_compiled_model( - model_path_nequip, device=DEVICE + compiled_nequip_model_path, device=DEVICE ) model = NequIPFrameworkModel( model=compiled_model, @@ -78,7 +88,14 @@ def test_nequip_initialization(model_path_nequip: Path) -> None: test_name="nequip", model_fixture_name="nequip_model", calculator_fixture_name="nequip_calculator", - sim_state_names=("si_sim_state", "rattled_si_sim_state"), + sim_state_names=consistency_test_simstate_fixtures, + energy_atol=5e-5, + dtype=DTYPE, + device=DEVICE, ) -# TODO (AG): Test multi element models +test_nequip_model_outputs = make_validate_model_outputs_test( + model_fixture_name="nequip_model", + dtype=DTYPE, + device=DEVICE, +)