Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 11 additions & 2 deletions tests/models/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@ def test_model_calculator_consistency(
return test_model_calculator_consistency


def make_validate_model_outputs_test(
def make_validate_model_outputs_test( # noqa: PLR0915
model_fixture_name: str,
device: torch.device = DEVICE,
dtype: torch.dtype = torch.float64,
Expand All @@ -135,7 +135,7 @@ def make_validate_model_outputs_test(
model_fixture_name: Name of the model fixture to validate
"""

def test_model_output_validation(request: pytest.FixtureRequest) -> None:
def test_model_output_validation(request: pytest.FixtureRequest) -> None: # noqa: PLR0915
"""Test that a model implementation follows the ModelInterface contract."""
# Get the model fixture dynamically
model: ModelInterface = request.getfixturevalue(model_fixture_name)
Expand Down Expand Up @@ -224,6 +224,15 @@ def test_model_output_validation(request: pytest.FixtureRequest) -> None:
# atol=10e-3,
# )

# Test single system output
assert fe_model_output["energy"].shape == (1,)
# forces should be shape (n_atoms, 3) for n_atoms in the system
if force_computed:
assert fe_model_output["forces"].shape == (12, 3)
# stress should be shape (1, 3, 3) for 1 system
if stress_computed:
assert fe_model_output["stress"].shape == (1, 3, 3)

# Rename the function to include the test name
test_model_output_validation.__name__ = f"test_{model_fixture_name}_output_validation"
return test_model_output_validation
Loading
Loading