diff --git a/.github/PULL_REQUEST_TEMPLATE.md b/.github/PULL_REQUEST_TEMPLATE.md index 381c356f..b71a5910 100644 --- a/.github/PULL_REQUEST_TEMPLATE.md +++ b/.github/PULL_REQUEST_TEMPLATE.md @@ -8,7 +8,8 @@ Before a pull request can be merged, the following items must be checked: * [ ] Doc strings have been added in the [Google docstring format](https://sphinxcontrib-napoleon.readthedocs.io/en/latest/example_google.html#example-google). - Run [ruff](https://beta.ruff.rs/docs/rules/#pydocstyle-d) on your code. +* [ ] Run [ruff](https://beta.ruff.rs/docs/rules/#pydocstyle-d) on your code. +* [ ] Run `uvx ty check` on the repo. * [ ] Tests have been added for any new functionality or bug fixes. We highly recommended installing the pre-commit hooks running in CI locally to speedup the development process. Simply run `pip install pre-commit && pre-commit install` to install the hooks which will check your code before each commit. diff --git a/examples/scripts/2_Structural_optimization/2.1_Lennard_Jones_FIRE.py b/examples/scripts/2_Structural_optimization/2.1_Lennard_Jones_FIRE.py index ff11162b..9ddd8ba7 100644 --- a/examples/scripts/2_Structural_optimization/2.1_Lennard_Jones_FIRE.py +++ b/examples/scripts/2_Structural_optimization/2.1_Lennard_Jones_FIRE.py @@ -89,8 +89,8 @@ positions=positions, masses=masses, cell=cell.unsqueeze(0), - pbc=True, atomic_numbers=atomic_numbers, + pbc=True, ) # Run initial simulation and get results diff --git a/examples/scripts/2_Structural_optimization/2.2_Soft_Sphere_FIRE.py b/examples/scripts/2_Structural_optimization/2.2_Soft_Sphere_FIRE.py index 7571f43d..3956c956 100644 --- a/examples/scripts/2_Structural_optimization/2.2_Soft_Sphere_FIRE.py +++ b/examples/scripts/2_Structural_optimization/2.2_Soft_Sphere_FIRE.py @@ -80,8 +80,8 @@ positions=positions, masses=masses, cell=cell.unsqueeze(0), - pbc=True, atomic_numbers=atomic_numbers, + pbc=True, ) # Initialize the Soft Sphere model diff --git a/examples/scripts/3_Dynamics/3.11_Lennard_Jones_NPT_Langevin.py b/examples/scripts/3_Dynamics/3.11_Lennard_Jones_NPT_Langevin.py index 8c87f5a6..6f81519e 100644 --- a/examples/scripts/3_Dynamics/3.11_Lennard_Jones_NPT_Langevin.py +++ b/examples/scripts/3_Dynamics/3.11_Lennard_Jones_NPT_Langevin.py @@ -97,8 +97,8 @@ positions=positions, masses=masses, cell=cell.unsqueeze(0), - pbc=True, atomic_numbers=atomic_numbers, + pbc=True, ) # Run initial simulation and get results results = model(state) @@ -148,11 +148,11 @@ stress = model(state)["stress"] -calc_kinetic_energy = calc_kinetic_energy( +kinetic_energy = calc_kinetic_energy( masses=state.masses, momenta=state.momenta, system_idx=state.system_idx ) volume = torch.linalg.det(state.cell) -pressure = get_pressure(stress, calc_kinetic_energy, volume) +pressure = get_pressure(stress, kinetic_energy, volume) pressure = pressure.item() / Units.pressure print(f"Final {pressure=:.4f}") print(stress * UnitConversion.eV_per_Ang3_to_GPa) diff --git a/examples/scripts/3_Dynamics/3.1_Lennard_Jones_NVE.py b/examples/scripts/3_Dynamics/3.1_Lennard_Jones_NVE.py index 92ea04a7..9506e9d4 100644 --- a/examples/scripts/3_Dynamics/3.1_Lennard_Jones_NVE.py +++ b/examples/scripts/3_Dynamics/3.1_Lennard_Jones_NVE.py @@ -78,11 +78,7 @@ masses = torch.full((positions.shape[0],), 39.948, device=device, dtype=dtype) state = ts.SimState( - positions=positions, - masses=masses, - cell=cell, - pbc=True, - atomic_numbers=atomic_numbers, + positions=positions, masses=masses, cell=cell, atomic_numbers=atomic_numbers, pbc=True ) # Initialize the Lennard-Jones model # Parameters: diff --git a/examples/scripts/3_Dynamics/3.2_MACE_NVE.py b/examples/scripts/3_Dynamics/3.2_MACE_NVE.py index 10cc1dc4..4151eb1a 100644 --- a/examples/scripts/3_Dynamics/3.2_MACE_NVE.py +++ b/examples/scripts/3_Dynamics/3.2_MACE_NVE.py @@ -60,11 +60,7 @@ ) state = ts.SimState( - positions=positions, - masses=masses, - cell=cell, - pbc=True, - atomic_numbers=atomic_numbers, + positions=positions, masses=masses, cell=cell, atomic_numbers=atomic_numbers, pbc=True ) # Run initial inference diff --git a/examples/scripts/3_Dynamics/3.4_MACE_NVT_Langevin.py b/examples/scripts/3_Dynamics/3.4_MACE_NVT_Langevin.py index d7846c7e..c998e950 100644 --- a/examples/scripts/3_Dynamics/3.4_MACE_NVT_Langevin.py +++ b/examples/scripts/3_Dynamics/3.4_MACE_NVT_Langevin.py @@ -59,11 +59,7 @@ ) state = ts.SimState( - positions=positions, - masses=masses, - cell=cell, - pbc=True, - atomic_numbers=atomic_numbers, + positions=positions, masses=masses, cell=cell, atomic_numbers=atomic_numbers, pbc=True ) dt = 0.002 * Units.time # Timestep (ps) diff --git a/examples/scripts/3_Dynamics/3.7_Lennard_Jones_NPT_Nose_Hoover.py b/examples/scripts/3_Dynamics/3.7_Lennard_Jones_NPT_Nose_Hoover.py index 6375bc61..0c1ffa58 100644 --- a/examples/scripts/3_Dynamics/3.7_Lennard_Jones_NPT_Nose_Hoover.py +++ b/examples/scripts/3_Dynamics/3.7_Lennard_Jones_NPT_Nose_Hoover.py @@ -96,8 +96,8 @@ positions=positions, masses=masses, cell=cell.unsqueeze(0), - pbc=True, atomic_numbers=atomic_numbers, + pbc=True, ) # Run initial simulation and get results results = model(state) diff --git a/examples/scripts/4_High_level_api/4.1_high_level_api.py b/examples/scripts/4_High_level_api/4.1_high_level_api.py index f09aa039..396ca035 100644 --- a/examples/scripts/4_High_level_api/4.1_high_level_api.py +++ b/examples/scripts/4_High_level_api/4.1_high_level_api.py @@ -54,7 +54,9 @@ prop_calculators = { 10: {"potential_energy": lambda state: state.energy}, 20: { - "kinetic_energy": lambda state: calc_kinetic_energy(state.momenta, state.masses) + "kinetic_energy": lambda state: calc_kinetic_energy( + momenta=state.momenta, masses=state.masses + ) }, } diff --git a/examples/scripts/6_Phonons/6.2_QuasiHarmonic_MACE.py b/examples/scripts/6_Phonons/6.2_QuasiHarmonic_MACE.py index fe1c041b..4b4edea4 100644 --- a/examples/scripts/6_Phonons/6.2_QuasiHarmonic_MACE.py +++ b/examples/scripts/6_Phonons/6.2_QuasiHarmonic_MACE.py @@ -24,12 +24,13 @@ from phonopy.structure.atoms import PhonopyAtoms import torch_sim as ts +from torch_sim.models.interface import ModelInterface from torch_sim.models.mace import MaceModel, MaceUrls def get_relaxed_structure( struct: Atoms, - model: torch.nn.Module | None, + model: ModelInterface, Nrelax: int = 300, fmax: float = 1e-3, *, @@ -80,7 +81,7 @@ def get_relaxed_structure( def get_qha_structures( state: ts.state.SimState, length_factors: np.ndarray, - model: torch.nn.Module | None, + model: ModelInterface, Nmax: int = 300, fmax: float = 1e-3, *, @@ -129,7 +130,7 @@ def get_qha_structures( def get_qha_phonons( scaled_structures: list[PhonopyAtoms], - model: torch.nn.Module | None, + model: ModelInterface, supercell_matrix: np.ndarray | None, displ: float = 0.05, *, diff --git a/examples/tutorials/high_level_tutorial.py b/examples/tutorials/high_level_tutorial.py index c26d7479..cf0debb4 100644 --- a/examples/tutorials/high_level_tutorial.py +++ b/examples/tutorials/high_level_tutorial.py @@ -132,7 +132,7 @@ 10: {"potential_energy": lambda state: state.energy}, 20: { "kinetic_energy": lambda state: ts.calc_kinetic_energy( - state.momenta, state.masses + momenta=state.momenta, masses=state.masses ) }, } diff --git a/examples/tutorials/reporting_tutorial.py b/examples/tutorials/reporting_tutorial.py index c47340fe..47477613 100644 --- a/examples/tutorials/reporting_tutorial.py +++ b/examples/tutorials/reporting_tutorial.py @@ -206,6 +206,7 @@ # %% from torch_sim.models.lennard_jones import LennardJonesModel +from torch_sim.models.interface import ModelInterface # Define some property calculators @@ -214,7 +215,7 @@ def calculate_com(state: ts.state.SimState) -> torch.Tensor: return torch.mean(state.positions * state.masses.unsqueeze(1), dim=0) -def calculate_energy(state: ts.state.SimState, model: torch.nn.Module) -> torch.Tensor: +def calculate_energy(state: ts.state.SimState, model: ModelInterface) -> torch.Tensor: """Calculate energy - needs both state and model""" return model(state)["energy"] diff --git a/tests/models/test_mattersim.py b/tests/models/test_mattersim.py index 44ca3237..a137ed78 100644 --- a/tests/models/test_mattersim.py +++ b/tests/models/test_mattersim.py @@ -48,7 +48,7 @@ def pretrained_mattersim_model(device: torch.device, model_name: str): @pytest.fixture def mattersim_model( - pretrained_mattersim_model: torch.nn.Module, device: torch.device + pretrained_mattersim_model: Potential, device: torch.device ) -> MatterSimModel: """Create an MatterSimModel wrapper for the pretrained model.""" return MatterSimModel( @@ -66,7 +66,7 @@ def mattersim_calculator( def test_mattersim_initialization( - pretrained_mattersim_model: torch.nn.Module, device: torch.device + pretrained_mattersim_model: Potential, device: torch.device ) -> None: """Test that the MatterSim model initializes correctly.""" model = MatterSimModel( diff --git a/tests/models/test_sevennet.py b/tests/models/test_sevennet.py index a17f8558..25bd310a 100644 --- a/tests/models/test_sevennet.py +++ b/tests/models/test_sevennet.py @@ -11,6 +11,7 @@ try: import sevenn.util from sevenn.calculator import SevenNetCalculator + from sevenn.nn.sequential import AtomGraphSequential from torch_sim.models.sevennet import SevenNetModel @@ -50,7 +51,7 @@ def pretrained_sevenn_model(device: torch.device, model_name: str): @pytest.fixture def sevenn_model( - pretrained_sevenn_model: torch.nn.Module, device: torch.device, modal_name: str + pretrained_sevenn_model: AtomGraphSequential, device: torch.device, modal_name: str ) -> SevenNetModel: """Create an SevenNetModel wrapper for the pretrained model.""" return SevenNetModel( @@ -69,7 +70,7 @@ def sevenn_calculator( def test_sevennet_initialization( - pretrained_sevenn_model: torch.nn.Module, device: torch.device + pretrained_sevenn_model: AtomGraphSequential, device: torch.device ) -> None: """Test that the SevenNet model initializes correctly.""" model = SevenNetModel( diff --git a/tests/test_integrators.py b/tests/test_integrators.py index ac7bf4b8..d5b210d1 100644 --- a/tests/test_integrators.py +++ b/tests/test_integrators.py @@ -109,7 +109,9 @@ def test_npt_langevin(ar_double_sim_state: ts.SimState, lj_model: LennardJonesMo state = update_fn(state=state) # Calculate instantaneous temperature from kinetic energy - temp = calc_kT(state.momenta, state.masses, system_idx=state.system_idx) + temp = calc_kT( + masses=state.masses, momenta=state.momenta, system_idx=state.system_idx + ) energies.append(state.energy) temperatures.append(temp / MetalUnits.temperature) @@ -172,7 +174,9 @@ def test_npt_langevin_multi_kt( state = update_fn(state=state) # Calculate instantaneous temperature from kinetic energy - temp = calc_kT(state.momenta, state.masses, system_idx=state.system_idx) + temp = calc_kT( + masses=state.masses, momenta=state.momenta, system_idx=state.system_idx + ) energies.append(state.energy) temperatures.append(temp / MetalUnits.temperature) @@ -213,7 +217,9 @@ def test_nvt_langevin(ar_double_sim_state: ts.SimState, lj_model: LennardJonesMo state = update_fn(state=state) # Calculate instantaneous temperature from kinetic energy - temp = calc_kT(state.momenta, state.masses, system_idx=state.system_idx) + temp = calc_kT( + masses=state.masses, momenta=state.momenta, system_idx=state.system_idx + ) energies.append(state.energy) temperatures.append(temp / MetalUnits.temperature) @@ -273,7 +279,9 @@ def test_nvt_langevin_multi_kt( state = update_fn(state=state) # Calculate instantaneous temperature from kinetic energy - temp = calc_kT(state.momenta, state.masses, system_idx=state.system_idx) + temp = calc_kT( + masses=state.masses, momenta=state.momenta, system_idx=state.system_idx + ) energies.append(state.energy) temperatures.append(temp / MetalUnits.temperature) diff --git a/tests/test_monte_carlo.py b/tests/test_monte_carlo.py index 3be7787d..479552c0 100644 --- a/tests/test_monte_carlo.py +++ b/tests/test_monte_carlo.py @@ -3,6 +3,7 @@ from pymatgen.core import Structure import torch_sim as ts +from torch_sim.models.interface import ModelInterface from torch_sim.monte_carlo import ( SwapMCState, generate_swaps, @@ -112,7 +113,7 @@ def test_validate_permutation(batched_diverse_state: ts.SimState): def test_monte_carlo( batched_diverse_state: ts.SimState, - lj_model: torch.nn.Module, + lj_model: ModelInterface, ): """Test the monte_carlo function that returns a step function and initial state.""" # Call monte_carlo to get the initial state and step function diff --git a/tests/test_optimizers.py b/tests/test_optimizers.py index a5bfa675..141bf5ee 100644 --- a/tests/test_optimizers.py +++ b/tests/test_optimizers.py @@ -6,6 +6,7 @@ import torch import torch_sim as ts +from torch_sim.models.interface import ModelInterface from torch_sim.optimizers import ( FireState, FrechetCellFIREState, @@ -23,7 +24,7 @@ def test_gradient_descent_optimization( - ar_supercell_sim_state: ts.SimState, lj_model: torch.nn.Module + ar_supercell_sim_state: ts.SimState, lj_model: ModelInterface ) -> None: """Test that the Gradient Descent optimizer actually minimizes energy.""" # Add some random displacement to positions @@ -62,7 +63,7 @@ def test_gradient_descent_optimization( def test_unit_cell_gradient_descent_optimization( - ar_supercell_sim_state: ts.SimState, lj_model: torch.nn.Module + ar_supercell_sim_state: ts.SimState, lj_model: ModelInterface ) -> None: """Test that the Gradient Descent optimizer actually minimizes energy.""" # Add some random displacement to positions @@ -111,7 +112,7 @@ def test_unit_cell_gradient_descent_optimization( @pytest.mark.parametrize("md_flavor", get_args(MdFlavor)) def test_fire_optimization( - ar_supercell_sim_state: ts.SimState, lj_model: torch.nn.Module, md_flavor: MdFlavor + ar_supercell_sim_state: ts.SimState, lj_model: ModelInterface, md_flavor: MdFlavor ) -> None: """Test that the FIRE optimizer actually minimizes energy.""" # Add some random displacement to positions @@ -185,7 +186,7 @@ def test_simple_optimizer_init_with_dict( optimizer_fn: callable, expected_state_type: type, ar_supercell_sim_state: ts.SimState, - lj_model: torch.nn.Module, + lj_model: ModelInterface, ) -> None: """Test simple optimizer init_fn with a ts.SimState dictionary.""" state_dict = { @@ -201,7 +202,7 @@ def test_simple_optimizer_init_with_dict( @pytest.mark.parametrize("optimizer_func", [fire, unit_cell_fire, frechet_cell_fire]) def test_optimizer_invalid_md_flavor( - optimizer_func: callable, lj_model: torch.nn.Module + optimizer_func: callable, lj_model: ModelInterface ) -> None: """Test optimizer with an invalid md_flavor raises ValueError.""" with pytest.raises(ValueError, match="Unknown md_flavor"): @@ -209,7 +210,7 @@ def test_optimizer_invalid_md_flavor( def test_fire_ase_negative_power_branch( - ar_supercell_sim_state: ts.SimState, lj_model: torch.nn.Module + ar_supercell_sim_state: ts.SimState, lj_model: ModelInterface ) -> None: """Test that the ASE FIRE P<0 branch behaves as expected.""" f_dec = 0.5 # Default from fire optimizer @@ -272,7 +273,7 @@ def test_fire_ase_negative_power_branch( def test_fire_vv_negative_power_branch( - ar_supercell_sim_state: ts.SimState, lj_model: torch.nn.Module + ar_supercell_sim_state: ts.SimState, lj_model: ModelInterface ) -> None: """Attempt to trigger and test the VV FIRE P<0 branch.""" f_dec = 0.5 @@ -325,7 +326,7 @@ def test_fire_vv_negative_power_branch( @pytest.mark.parametrize("md_flavor", get_args(MdFlavor)) def test_unit_cell_fire_optimization( - ar_supercell_sim_state: ts.SimState, lj_model: torch.nn.Module, md_flavor: MdFlavor + ar_supercell_sim_state: ts.SimState, lj_model: ModelInterface, md_flavor: MdFlavor ) -> None: """Test that the Unit Cell FIRE optimizer actually minimizes energy.""" @@ -414,7 +415,7 @@ def test_cell_optimizer_init_with_dict_and_cell_factor( expected_state_type: type, cell_factor_val: float, ar_supercell_sim_state: ts.SimState, - lj_model: torch.nn.Module, + lj_model: ModelInterface, ) -> None: """Test cell optimizer init_fn with dict state and explicit cell_factor.""" state_dict = { @@ -448,7 +449,7 @@ def test_cell_optimizer_init_cell_factor_none( optimizer_fn: callable, expected_state_type: type, ar_supercell_sim_state: ts.SimState, - lj_model: torch.nn.Module, + lj_model: ModelInterface, ) -> None: """Test cell optimizer init_fn with cell_factor=None.""" init_fn, _ = optimizer_fn(model=lj_model, cell_factor=None) @@ -467,7 +468,7 @@ def test_cell_optimizer_init_cell_factor_none( @pytest.mark.filterwarnings("ignore:WARNING: Non-positive volume detected") def test_unit_cell_fire_ase_non_positive_volume_warning( ar_supercell_sim_state: ts.SimState, - lj_model: torch.nn.Module, + lj_model: ModelInterface, capsys: pytest.CaptureFixture, ) -> None: """Attempt to trigger non-positive volume warning in unit_cell_fire ASE.""" @@ -503,7 +504,7 @@ def test_unit_cell_fire_ase_non_positive_volume_warning( @pytest.mark.parametrize("md_flavor", get_args(MdFlavor)) def test_frechet_cell_fire_optimization( - ar_supercell_sim_state: ts.SimState, lj_model: torch.nn.Module, md_flavor: MdFlavor + ar_supercell_sim_state: ts.SimState, lj_model: ModelInterface, md_flavor: MdFlavor ) -> None: """Test that the Frechet Cell FIRE optimizer actually minimizes energy for different md_flavors.""" @@ -592,7 +593,7 @@ def test_frechet_cell_fire_optimization( def test_optimizer_batch_consistency( optimizer_func: callable, ar_supercell_sim_state: ts.SimState, - lj_model: torch.nn.Module, + lj_model: ModelInterface, ) -> None: """Test batched optimizer is consistent with individual optimizations.""" generator = torch.Generator(device=ar_supercell_sim_state.device) @@ -707,7 +708,7 @@ def energy_converged(current_e: torch.Tensor, prev_e: torch.Tensor) -> bool: def test_unit_cell_fire_multi_batch( - ar_supercell_sim_state: ts.SimState, lj_model: torch.nn.Module + ar_supercell_sim_state: ts.SimState, lj_model: ModelInterface ) -> None: """Test FIRE optimization with multiple batches.""" # Create a multi-batch system by duplicating ar_fcc_state @@ -783,7 +784,7 @@ def test_unit_cell_fire_multi_batch( def test_fire_fixed_cell_unit_cell_consistency( # noqa: C901 - ar_supercell_sim_state: ts.SimState, lj_model: torch.nn.Module + ar_supercell_sim_state: ts.SimState, lj_model: ModelInterface ) -> None: """Test batched Frechet Fixed cell FIRE optimization is consistent with FIRE (position only) optimizations.""" diff --git a/tests/test_quantities.py b/tests/test_quantities.py new file mode 100644 index 00000000..7513b6bd --- /dev/null +++ b/tests/test_quantities.py @@ -0,0 +1,136 @@ +import pytest +import torch +from torch._tensor import Tensor + +from torch_sim import quantities +from torch_sim.units import MetalUnits + + +DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu") +DTYPE = torch.double + + +@pytest.fixture +def single_system_data() -> dict[str, Tensor]: + masses = torch.tensor([1.0, 2.0], device=DEVICE, dtype=DTYPE) + velocities = torch.tensor( + [[1.0, 1.0, 1.0], [2.0, 2.0, 2.0]], device=DEVICE, dtype=DTYPE + ) + momenta = velocities * masses.unsqueeze(-1) + return { + "masses": masses, + "velocities": velocities, + "momenta": momenta, + "ke": torch.tensor(13.5, device=DEVICE, dtype=DTYPE), + "kt": torch.tensor(4.5, device=DEVICE, dtype=DTYPE), + } + + +@pytest.fixture +def batched_system_data() -> dict[str, Tensor]: + masses = torch.tensor([1.0, 1.0, 2.0, 2.0], device=DEVICE, dtype=DTYPE) + velocities = torch.tensor( + [[1, 1, 1], [1, 1, 1], [2, 2, 2], [2, 2, 2]], device=DEVICE, dtype=DTYPE + ) + momenta = velocities * masses.unsqueeze(-1) + system_idx = torch.tensor([0, 0, 1, 1], device=DEVICE) + return { + "masses": masses, + "velocities": velocities, + "momenta": momenta, + "system_idx": system_idx, + "ke": torch.tensor([3.0, 24.0], device=DEVICE, dtype=DTYPE), + "kt": torch.tensor([1.0, 8.0], device=DEVICE, dtype=DTYPE), + } + + +def test_calc_kinetic_energy_single_system(single_system_data: dict[str, Tensor]) -> None: + # With velocities + ke_vel = quantities.calc_kinetic_energy( + masses=single_system_data["masses"], + velocities=single_system_data["velocities"], + ) + assert torch.allclose(ke_vel, single_system_data["ke"]) + + # With momenta + ke_mom = quantities.calc_kinetic_energy( + masses=single_system_data["masses"], momenta=single_system_data["momenta"] + ) + assert torch.allclose(ke_mom, single_system_data["ke"]) + + +def test_calc_kinetic_energy_batched_system( + batched_system_data: dict[str, Tensor], +) -> None: + # With velocities + ke_vel = quantities.calc_kinetic_energy( + masses=batched_system_data["masses"], + velocities=batched_system_data["velocities"], + system_idx=batched_system_data["system_idx"], + ) + assert torch.allclose(ke_vel, batched_system_data["ke"]) + + # With momenta + ke_mom = quantities.calc_kinetic_energy( + masses=batched_system_data["masses"], + momenta=batched_system_data["momenta"], + system_idx=batched_system_data["system_idx"], + ) + assert torch.allclose(ke_mom, batched_system_data["ke"]) + + +def test_calc_kinetic_energy_errors(single_system_data: dict[str, Tensor]) -> None: + with pytest.raises(ValueError, match="Must pass either one of momenta or velocities"): + quantities.calc_kinetic_energy( + masses=single_system_data["masses"], + momenta=single_system_data["momenta"], + velocities=single_system_data["velocities"], + ) + + with pytest.raises(ValueError, match="Must pass either one of momenta or velocities"): + quantities.calc_kinetic_energy(masses=single_system_data["masses"]) + + +def test_calc_kt_single_system(single_system_data: dict[str, Tensor]) -> None: + # With velocities + kt_vel = quantities.calc_kT( + masses=single_system_data["masses"], + velocities=single_system_data["velocities"], + ) + assert torch.allclose(kt_vel, single_system_data["kt"]) + + # With momenta + kt_mom = quantities.calc_kT( + masses=single_system_data["masses"], momenta=single_system_data["momenta"] + ) + assert torch.allclose(kt_mom, single_system_data["kt"]) + + +def test_calc_kt_batched_system(batched_system_data: dict[str, Tensor]) -> None: + # With velocities + kt_vel = quantities.calc_kT( + masses=batched_system_data["masses"], + velocities=batched_system_data["velocities"], + system_idx=batched_system_data["system_idx"], + ) + assert torch.allclose(kt_vel, batched_system_data["kt"]) + + # With momenta + kt_mom = quantities.calc_kT( + masses=batched_system_data["masses"], + momenta=batched_system_data["momenta"], + system_idx=batched_system_data["system_idx"], + ) + assert torch.allclose(kt_mom, batched_system_data["kt"]) + + +def test_calc_temperature(single_system_data: dict[str, Tensor]) -> None: + temp = quantities.calc_temperature( + masses=single_system_data["masses"], + velocities=single_system_data["velocities"], + ) + kt = quantities.calc_kT( + masses=single_system_data["masses"], + velocities=single_system_data["velocities"], + ) + assert torch.allclose(temp, kt / MetalUnits.temperature) diff --git a/tests/test_runners.py b/tests/test_runners.py index cd1ff3db..5c9862d0 100644 --- a/tests/test_runners.py +++ b/tests/test_runners.py @@ -23,7 +23,11 @@ def test_integrate_nve( filenames=traj_file, state_frequency=1, prop_calculators={ - 1: {"ke": lambda state: calc_kinetic_energy(state.momenta, state.masses)} + 1: { + "ke": lambda state: calc_kinetic_energy( + momenta=state.momenta, masses=state.masses + ) + } }, ) @@ -56,7 +60,11 @@ def test_integrate_single_nvt( filenames=traj_file, state_frequency=1, prop_calculators={ - 1: {"ke": lambda state: calc_kinetic_energy(state.momenta, state.masses)} + 1: { + "ke": lambda state: calc_kinetic_energy( + momenta=state.momenta, masses=state.masses + ) + } }, ) @@ -108,7 +116,11 @@ def test_integrate_double_nvt_with_reporter( filenames=trajectory_files, state_frequency=1, prop_calculators={ - 1: {"ke": lambda state: calc_kinetic_energy(state.momenta, state.masses)} + 1: { + "ke": lambda state: calc_kinetic_energy( + momenta=state.momenta, masses=state.masses + ) + } }, ) @@ -155,7 +167,11 @@ def test_integrate_many_nvt( filenames=trajectory_files, state_frequency=1, prop_calculators={ - 1: {"ke": lambda state: calc_kinetic_energy(state.momenta, state.masses)} + 1: { + "ke": lambda state: calc_kinetic_energy( + momenta=state.momenta, masses=state.masses + ) + } }, ) @@ -346,7 +362,11 @@ def test_batched_optimize_fire( filenames=trajectory_files, state_frequency=1, prop_calculators={ - 1: {"ke": lambda state: calc_kinetic_energy(state.momenta, state.masses)} + 1: { + "ke": lambda state: calc_kinetic_energy( + velocities=state.velocities, masses=state.masses + ) + } }, ) diff --git a/tests/test_state.py b/tests/test_state.py index ea57dd3a..af0bda7b 100644 --- a/tests/test_state.py +++ b/tests/test_state.py @@ -500,8 +500,8 @@ class DeformState(SimState, DeformGradMixin): def __init__( self, *args, - velocities: torch.Tensor | None = None, - reference_cell: torch.Tensor | None = None, + velocities: torch.Tensor, + reference_cell: torch.Tensor, **kwargs, ) -> None: super().__init__(*args, **kwargs) @@ -530,14 +530,6 @@ def deform_grad_state(device: torch.device) -> DeformState: ) -def test_deform_grad_momenta(deform_grad_state: DeformState) -> None: - """Test momenta calculation in DeformGradMixin.""" - expected_momenta = deform_grad_state.velocities * deform_grad_state.masses.unsqueeze( - -1 - ) - assert torch.allclose(deform_grad_state.momenta, expected_momenta) - - def test_deform_grad_reference_cell(deform_grad_state: DeformState) -> None: """Test reference cell getter/setter in DeformGradMixin.""" original_ref_cell = deform_grad_state.reference_cell.clone() diff --git a/tests/test_trajectory.py b/tests/test_trajectory.py index e003f7a7..67d9de1f 100644 --- a/tests/test_trajectory.py +++ b/tests/test_trajectory.py @@ -9,6 +9,7 @@ import torch_sim as ts from torch_sim.integrators import MDState +from torch_sim.models.interface import ModelInterface from torch_sim.models.lennard_jones import LennardJonesModel from torch_sim.trajectory import TorchSimTrajectory, TrajectoryReporter @@ -748,7 +749,7 @@ def test_reporter_with_model( """Test TrajectoryReporter with a model argument in property calculators.""" # Create a property calculator that uses the model - def energy_calculator(state: ts.SimState, model: torch.nn.Module) -> torch.Tensor: + def energy_calculator(state: ts.SimState, model: ModelInterface) -> torch.Tensor: output = model(state) # Calculate a property that depends on the model return output["energy"] diff --git a/tests/workflows/test_a2c.py b/tests/workflows/test_a2c.py index aaed7317..a95ce0e0 100644 --- a/tests/workflows/test_a2c.py +++ b/tests/workflows/test_a2c.py @@ -1,10 +1,12 @@ +from typing import cast + import pytest import torch from pymatgen.core.composition import Composition import torch_sim as ts from torch_sim.models.soft_sphere import SoftSphereModel -from torch_sim.optimizers import UnitCellFireState +from torch_sim.optimizers import FireState, UnitCellFireState from torch_sim.workflows import a2c @@ -155,6 +157,7 @@ def test_random_packed_structure_auto_diameter(device: torch.device) -> None: max_iter=3, device=device, ) + state = cast("FireState", state) # Just check that it ran without errors assert state.positions is not None diff --git a/torch_sim/elastic.py b/torch_sim/elastic.py index 55bb27a9..a067e03d 100644 --- a/torch_sim/elastic.py +++ b/torch_sim/elastic.py @@ -24,6 +24,7 @@ import torch +from torch_sim.models.interface import ModelInterface from torch_sim.state import SimState from torch_sim.typing import BravaisType @@ -1105,7 +1106,7 @@ def get_elastic_tensor_from_coeffs( # noqa: C901, PLR0915 def calculate_elastic_tensor( - model: torch.nn.Module, + model: ModelInterface, *, state: SimState, bravais_type: BravaisType = BravaisType.TRICLINIC, diff --git a/torch_sim/integrators/md.py b/torch_sim/integrators/md.py index ce15877d..069f62eb 100644 --- a/torch_sim/integrators/md.py +++ b/torch_sim/integrators/md.py @@ -6,6 +6,7 @@ import torch from torch_sim import transforms +from torch_sim.models.interface import ModelInterface from torch_sim.state import SimState @@ -156,7 +157,7 @@ def position_step(state: MDState, dt: torch.Tensor) -> MDState: return state -def velocity_verlet(state: MDState, dt: torch.Tensor, model: torch.nn.Module) -> MDState: +def velocity_verlet(state: MDState, dt: torch.Tensor, model: ModelInterface) -> MDState: """Perform one complete velocity Verlet integration step. This function implements the velocity Verlet algorithm, which provides diff --git a/torch_sim/integrators/npt.py b/torch_sim/integrators/npt.py index 27b78c8b..e1ac1d39 100644 --- a/torch_sim/integrators/npt.py +++ b/torch_sim/integrators/npt.py @@ -14,6 +14,7 @@ calculate_momenta, construct_nose_hoover_chain, ) +from torch_sim.models.interface import ModelInterface from torch_sim.quantities import calc_kinetic_energy from torch_sim.state import SimState from torch_sim.typing import StateDict @@ -140,7 +141,7 @@ def _compute_cell_force( def npt_langevin( # noqa: C901, PLR0915 - model: torch.nn.Module, + model: ModelInterface, *, dt: torch.Tensor, kT: torch.Tensor, @@ -162,7 +163,7 @@ def npt_langevin( # noqa: C901, PLR0915 maintain constant temperature. Args: - model (torch.nn.Module): Neural network model that computes energies, forces, + model (ModelInterface): Neural network model that computes energies, forces, and stress. Must return a dict with 'energy', 'forces', and 'stress' keys. dt (torch.Tensor): Integration timestep, either scalar or shape [n_systems] kT (torch.Tensor): Target temperature in energy units, either scalar or @@ -898,7 +899,7 @@ def current_cell(self) -> torch.Tensor: def npt_nose_hoover( # noqa: C901, PLR0915 *, - model: torch.nn.Module, + model: ModelInterface, kT: torch.Tensor, external_pressure: torch.Tensor, dt: torch.Tensor, @@ -915,7 +916,7 @@ def npt_nose_hoover( # noqa: C901, PLR0915 with Nose-Hoover chain thermostats for temperature and pressure control. Args: - model (torch.nn.Module): Model to compute forces and energies + model (ModelInterface): Model to compute forces and energies kT (torch.Tensor): Target temperature in energy units external_pressure (torch.Tensor): Target external pressure dt (torch.Tensor): Integration timestep @@ -1221,7 +1222,9 @@ def compute_cell_force( if system_mask.any(): system_momenta = momenta[system_mask] system_masses = masses[system_mask] - KE_per_system[b] = calc_kinetic_energy(system_momenta, system_masses) + KE_per_system[b] = calc_kinetic_energy( + masses=system_masses, momenta=system_momenta + ) # Get stress tensor and compute trace per system # Handle stress tensor with batch dimension @@ -1430,7 +1433,7 @@ def npt_nose_hoover_init( cell_mass = cell_mass.to(device=device, dtype=dtype) # Calculate cell kinetic energy (using first system for initialization) - KE_cell = calc_kinetic_energy(cell_momentum[:1], cell_mass[:1]) + KE_cell = calc_kinetic_energy(masses=cell_mass[:1], momenta=cell_momentum[:1]) # Ensure reference_cell has proper system dimensions if state.cell.ndim == 2: @@ -1485,7 +1488,9 @@ def npt_nose_hoover_init( # Initialize thermostat npt_state.momenta = momenta KE = calc_kinetic_energy( - npt_state.momenta, npt_state.masses, system_idx=npt_state.system_idx + momenta=npt_state.momenta, + masses=npt_state.masses, + system_idx=npt_state.system_idx, ) npt_state.thermostat = thermostat_fns.initialize( npt_state.positions.numel(), KE, kT @@ -1542,10 +1547,12 @@ def npt_nose_hoover_update( ) # Update kinetic energies for thermostats - KE = calc_kinetic_energy(state.momenta, state.masses, system_idx=state.system_idx) + KE = calc_kinetic_energy( + masses=state.masses, momenta=state.momenta, system_idx=state.system_idx + ) state.thermostat.kinetic_energy = KE - KE_cell = calc_kinetic_energy(state.cell_momentum, state.cell_mass) + KE_cell = calc_kinetic_energy(masses=state.cell_mass, momenta=state.cell_momentum) state.barostat.kinetic_energy = KE_cell # Second half step of thermostat chains @@ -1597,7 +1604,7 @@ def npt_nose_hoover_invariant( # Calculate kinetic energy of particles per system e_kin_per_system = calc_kinetic_energy( - state.momenta, state.masses, system_idx=state.system_idx + masses=state.masses, momenta=state.momenta, system_idx=state.system_idx ) # Calculate degrees of freedom per system diff --git a/torch_sim/integrators/nve.py b/torch_sim/integrators/nve.py index f17e59f6..c7e41390 100644 --- a/torch_sim/integrators/nve.py +++ b/torch_sim/integrators/nve.py @@ -10,12 +10,13 @@ momentum_step, position_step, ) +from torch_sim.models.interface import ModelInterface from torch_sim.state import SimState from torch_sim.typing import StateDict def nve( - model: torch.nn.Module, + model: ModelInterface, *, dt: torch.Tensor, kT: torch.Tensor, diff --git a/torch_sim/integrators/nvt.py b/torch_sim/integrators/nvt.py index 19a811c8..8309afa1 100644 --- a/torch_sim/integrators/nvt.py +++ b/torch_sim/integrators/nvt.py @@ -16,13 +16,14 @@ position_step, velocity_verlet, ) +from torch_sim.models.interface import ModelInterface from torch_sim.quantities import calc_kinetic_energy from torch_sim.state import SimState from torch_sim.typing import StateDict def nvt_langevin( # noqa: C901 - model: torch.nn.Module, + model: ModelInterface, *, dt: torch.Tensor, kT: torch.Tensor, @@ -275,7 +276,7 @@ def velocities(self) -> torch.Tensor: def nvt_nose_hoover( *, - model: torch.nn.Module, + model: ModelInterface, dt: torch.Tensor, kT: torch.Tensor, chain_length: int = 3, @@ -367,7 +368,9 @@ def nvt_nose_hoover_init( ) # Calculate initial kinetic energy per system - KE = calc_kinetic_energy(momenta, state.masses, system_idx=state.system_idx) + KE = calc_kinetic_energy( + masses=state.masses, momenta=momenta, system_idx=state.system_idx + ) # Calculate degrees of freedom per system n_atoms_per_system = torch.bincount(state.system_idx) @@ -433,7 +436,9 @@ def nvt_nose_hoover_update( state = velocity_verlet(state=state, dt=dt, model=model) # Update chain kinetic energy per system - KE = calc_kinetic_energy(state.momenta, state.masses, system_idx=state.system_idx) + KE = calc_kinetic_energy( + masses=state.masses, momenta=state.momenta, system_idx=state.system_idx + ) chain.kinetic_energy = KE # Second half-step of chain evolution @@ -477,7 +482,9 @@ def nvt_nose_hoover_invariant( """ # Calculate system energy terms per system e_pot = state.energy - e_kin = calc_kinetic_energy(state.momenta, state.masses, system_idx=state.system_idx) + e_kin = calc_kinetic_energy( + masses=state.masses, momenta=state.momenta, system_idx=state.system_idx + ) # Get system degrees of freedom per system n_atoms_per_system = torch.bincount(state.system_idx) diff --git a/torch_sim/models/fairchem.py b/torch_sim/models/fairchem.py index b911668e..77b1b0ba 100644 --- a/torch_sim/models/fairchem.py +++ b/torch_sim/models/fairchem.py @@ -45,7 +45,7 @@ except ImportError as exc: warnings.warn(f"FairChem import failed: {traceback.format_exc()}", stacklevel=2) - class FairChemModel(torch.nn.Module, ModelInterface): + class FairChemModel(ModelInterface): """FairChem model wrapper for torch_sim. This class is a placeholder for the FairChemModel class. @@ -70,7 +70,7 @@ def __init__(self, err: ImportError = exc, *_args: Any, **_kwargs: Any) -> None: } -class FairChemModel(torch.nn.Module, ModelInterface): +class FairChemModel(ModelInterface): """Computes atomistic energies, forces and stresses using a FairChem model. This class wraps a FairChem model to compute energies, forces, and stresses for diff --git a/torch_sim/models/graphpes.py b/torch_sim/models/graphpes.py index 6ce52753..a7b287e0 100644 --- a/torch_sim/models/graphpes.py +++ b/torch_sim/models/graphpes.py @@ -36,7 +36,7 @@ warnings.warn(f"GraphPES import failed: {traceback.format_exc()}", stacklevel=2) PropertyKey = str - class GraphPESWrapper(torch.nn.Module, ModelInterface): # type: ignore[reportRedeclaration] + class GraphPESWrapper(ModelInterface): # type: ignore[reportRedeclaration] """GraphPESModel wrapper for torch_sim. This class is a placeholder for the GraphPESWrapper class. @@ -99,7 +99,7 @@ def state_to_atomic_graph(state: ts.SimState, cutoff: torch.Tensor) -> AtomicGra return to_batch(graphs) -class GraphPESWrapper(torch.nn.Module, ModelInterface): +class GraphPESWrapper(ModelInterface): """Wrapper for GraphPESModel in TorchSim. This class provides a TorchSim wrapper around GraphPESModel instances, diff --git a/torch_sim/models/interface.py b/torch_sim/models/interface.py index a70736ab..27c03277 100644 --- a/torch_sim/models/interface.py +++ b/torch_sim/models/interface.py @@ -27,8 +27,6 @@ def forward(self, positions, cell, batch, atomic_numbers=None, **kwargs): """ from abc import ABC, abstractmethod -from pathlib import Path -from typing import Self import torch @@ -37,7 +35,7 @@ def forward(self, positions, cell, batch, atomic_numbers=None, **kwargs): from torch_sim.typing import MemoryScaling, StateDict -class ModelInterface(ABC): +class ModelInterface(torch.nn.Module, ABC): """Abstract base class for all simulation models in torchsim. This interface provides a common structure for all energy and force models, @@ -71,37 +69,10 @@ class ModelInterface(ABC): ``` """ - @abstractmethod - def __init__( - self, - model: str | Path | torch.nn.Module | None = None, - device: torch.device | None = None, - dtype: torch.dtype = torch.float64, - **kwargs, - ) -> Self: - """Initialize a model implementation. - - Implementations must set device, dtype and compute capability flags - to indicate what operations the model supports. Models may optionally - load parameters from a file or existing module. - - Args: - model (str | Path | torch.nn.Module | None): Model specification, which - can be: - - Path to a model checkpoint or model file - - Pre-configured torch.nn.Module - - None for default initialization - Defaults to None. - device (torch.device | None): Device where the model will run. If None, - a default device will be selected. Defaults to None. - dtype (torch.dtype): Data type for model calculations. Defaults to - torch.float64. - **kwargs: Additional model-specific parameters. - - Notes: - All implementing classes must set self._device, self._dtype, - self._compute_stress and self._compute_forces in their __init__ method. - """ + _device: torch.device + _dtype: torch.dtype + _compute_stress: bool + _compute_forces: bool @property def device(self) -> torch.device: diff --git a/torch_sim/models/lennard_jones.py b/torch_sim/models/lennard_jones.py index 3b7a5b81..2a05e2f8 100644 --- a/torch_sim/models/lennard_jones.py +++ b/torch_sim/models/lennard_jones.py @@ -119,7 +119,7 @@ def lennard_jones_pair_force( return torch.where(dr > 0, force, torch.zeros_like(force)) -class LennardJonesModel(torch.nn.Module, ModelInterface): +class LennardJonesModel(ModelInterface): """Lennard-Jones potential energy and force calculator. Implements the Lennard-Jones 12-6 potential for molecular dynamics simulations. diff --git a/torch_sim/models/mace.py b/torch_sim/models/mace.py index cfd34142..5ca2a629 100644 --- a/torch_sim/models/mace.py +++ b/torch_sim/models/mace.py @@ -38,7 +38,7 @@ except ImportError as exc: warnings.warn(f"MACE import failed: {traceback.format_exc()}", stacklevel=2) - class MaceModel(torch.nn.Module, ModelInterface): + class MaceModel(ModelInterface): """MACE model wrapper for torch_sim. This class is a placeholder for the MaceModel class. @@ -77,7 +77,7 @@ def to_one_hot( return oh.view(*shape) -class MaceModel(torch.nn.Module, ModelInterface): +class MaceModel(ModelInterface): """Computes energies for multiple systems using a MACE model. This class wraps a MACE model to compute energies, forces, and stresses for diff --git a/torch_sim/models/mattersim.py b/torch_sim/models/mattersim.py index 40ef0cc5..9b2efb23 100644 --- a/torch_sim/models/mattersim.py +++ b/torch_sim/models/mattersim.py @@ -21,7 +21,7 @@ except ImportError as exc: warnings.warn(f"MatterSim import failed: {traceback.format_exc()}", stacklevel=2) - class MatterSimModel(torch.nn.Module, ModelInterface): + class MatterSimModel(ModelInterface): """MatterSim model wrapper for torch_sim. This class is a placeholder for the MatterSimModel class. @@ -39,7 +39,7 @@ def __init__(self, err: ImportError = exc, *_args: Any, **_kwargs: Any) -> None: from torch_sim.typing import StateDict -class MatterSimModel(torch.nn.Module, ModelInterface): +class MatterSimModel(ModelInterface): """Computes atomistic energies, forces and stresses using an MatterSim model. This class wraps an MatterSim model to compute energies, forces, and stresses for diff --git a/torch_sim/models/metatomic.py b/torch_sim/models/metatomic.py index 47fda077..5655ed88 100644 --- a/torch_sim/models/metatomic.py +++ b/torch_sim/models/metatomic.py @@ -36,7 +36,7 @@ except ImportError as exc: warnings.warn(f"Metatomic import failed: {traceback.format_exc()}", stacklevel=2) - class MetatomicModel(torch.nn.Module, ModelInterface): + class MetatomicModel(ModelInterface): """Metatomic model wrapper for torch_sim. This class is a placeholder for the MetatomicModel class. @@ -48,7 +48,7 @@ def __init__(self, err: ImportError = exc, *_args: Any, **_kwargs: Any) -> None: raise err -class MetatomicModel(torch.nn.Module, ModelInterface): +class MetatomicModel(ModelInterface): """Computes energies for a list of systems using a metatomic model. This class wraps a metatomic model to compute energies, forces, and stresses for diff --git a/torch_sim/models/morse.py b/torch_sim/models/morse.py index 97564361..702dc41f 100644 --- a/torch_sim/models/morse.py +++ b/torch_sim/models/morse.py @@ -112,7 +112,7 @@ def morse_pair_force( return torch.where(dr > 0, force, torch.zeros_like(force)) -class MorseModel(torch.nn.Module, ModelInterface): +class MorseModel(ModelInterface): """Morse potential energy and force calculator. Implements the Morse potential for molecular dynamics simulations. This model diff --git a/torch_sim/models/orb.py b/torch_sim/models/orb.py index 7b4bffd7..fd65b23f 100644 --- a/torch_sim/models/orb.py +++ b/torch_sim/models/orb.py @@ -39,7 +39,7 @@ except ImportError as exc: warnings.warn(f"Orb import failed: {traceback.format_exc()}", stacklevel=2) - class OrbModel(torch.nn.Module, ModelInterface): + class OrbModel(ModelInterface): """ORB model wrapper for torch_sim. This class is a placeholder for the OrbModel class. @@ -247,7 +247,7 @@ def state_to_atom_graphs( # noqa: PLR0915 ).to(device=device, dtype=output_dtype) -class OrbModel(torch.nn.Module, ModelInterface): +class OrbModel(ModelInterface): """Computes atomistic energies, forces and stresses using an ORB model. This class wraps an ORB model to compute energies, forces, and stresses for diff --git a/torch_sim/models/particle_life.py b/torch_sim/models/particle_life.py index 39c5606c..3a13a333 100644 --- a/torch_sim/models/particle_life.py +++ b/torch_sim/models/particle_life.py @@ -83,7 +83,7 @@ def asymmetric_particle_pair_force_jit( return inner_forces + outer_forces -class ParticleLifeModel(torch.nn.Module, ModelInterface): +class ParticleLifeModel(ModelInterface): """Calculator for asymmetric particle interaction. This model implements an asymmetric interaction between particles based on diff --git a/torch_sim/models/sevennet.py b/torch_sim/models/sevennet.py index 6156fc17..cda8e183 100644 --- a/torch_sim/models/sevennet.py +++ b/torch_sim/models/sevennet.py @@ -32,7 +32,7 @@ except ImportError as exc: warnings.warn(f"SevenNet import failed: {traceback.format_exc()}", stacklevel=2) - class SevenNetModel(torch.nn.Module, ModelInterface): + class SevenNetModel(ModelInterface): """SevenNet model wrapper for torch_sim. This class is a placeholder for the SevenNetModel class. @@ -44,7 +44,7 @@ def __init__(self, err: ImportError = exc, *_args: Any, **_kwargs: Any) -> None: raise err -class SevenNetModel(torch.nn.Module, ModelInterface): +class SevenNetModel(ModelInterface): """Computes atomistic energies, forces and stresses using an SevenNet model. This class wraps an SevenNet model to compute energies, forces, and stresses for diff --git a/torch_sim/models/soft_sphere.py b/torch_sim/models/soft_sphere.py index 4397b2eb..75598aaf 100644 --- a/torch_sim/models/soft_sphere.py +++ b/torch_sim/models/soft_sphere.py @@ -130,7 +130,7 @@ def fn(dr: torch.Tensor) -> torch.Tensor: return transforms.safe_mask(mask, fn, dr) -class SoftSphereModel(torch.nn.Module, ModelInterface): +class SoftSphereModel(ModelInterface): """Calculator for soft sphere potential energies and forces. Implements a model for computing properties based on the soft sphere potential, @@ -435,7 +435,7 @@ def forward(self, state: ts.SimState | StateDict) -> dict[str, torch.Tensor]: return results -class SoftSphereMultiModel(torch.nn.Module): +class SoftSphereMultiModel(ModelInterface): """Calculator for systems with multiple particle types. Extends the basic soft sphere model to support multiple particle types with @@ -594,11 +594,11 @@ def __init__( with type 0). """ super().__init__() - self.device = device or torch.device("cpu") - self.dtype = dtype + self._device = device or torch.device("cpu") + self._dtype = dtype self.pbc = pbc - self.compute_forces = compute_forces - self.compute_stress = compute_stress + self._compute_forces = compute_forces + self._compute_stress = compute_stress self.per_atom_energies = per_atom_energies self.per_atom_stresses = per_atom_stresses self.use_neighbor_list = use_neighbor_list diff --git a/torch_sim/monte_carlo.py b/torch_sim/monte_carlo.py index 90929068..64aad3c7 100644 --- a/torch_sim/monte_carlo.py +++ b/torch_sim/monte_carlo.py @@ -15,6 +15,7 @@ import torch +from torch_sim.models.interface import ModelInterface from torch_sim.state import SimState @@ -183,7 +184,7 @@ def metropolis_criterion( def swap_monte_carlo( *, - model: torch.nn.Module, + model: ModelInterface, kT: float, seed: int | None = None, ) -> tuple[ diff --git a/torch_sim/optimizers.py b/torch_sim/optimizers.py index 98a83d64..44cb498e 100644 --- a/torch_sim/optimizers.py +++ b/torch_sim/optimizers.py @@ -25,6 +25,7 @@ import torch import torch_sim.math as tsm +from torch_sim.models.interface import ModelInterface from torch_sim.state import DeformGradMixin, SimState from torch_sim.typing import StateDict @@ -57,7 +58,7 @@ class GDState(SimState): def gradient_descent( - model: torch.nn.Module, *, lr: torch.Tensor | float = 0.01 + model: ModelInterface, *, lr: torch.Tensor | float = 0.01 ) -> tuple[Callable[[StateDict | SimState], GDState], Callable[[GDState], GDState]]: """Initialize a batched gradient descent optimization. @@ -196,7 +197,7 @@ class UnitCellGDState(GDState, DeformGradMixin): def unit_cell_gradient_descent( # noqa: PLR0915, C901 - model: torch.nn.Module, + model: ModelInterface, *, positions_lr: float = 0.01, cell_lr: float = 0.1, @@ -483,7 +484,7 @@ class FireState(SimState): def fire( - model: torch.nn.Module, + model: ModelInterface, *, dt_max: float = 1.0, dt_start: float = 0.1, @@ -692,7 +693,7 @@ class UnitCellFireState(SimState, DeformGradMixin): def unit_cell_fire( - model: torch.nn.Module, + model: ModelInterface, *, dt_max: float = 1.0, dt_start: float = 0.1, @@ -708,7 +709,7 @@ def unit_cell_fire( max_step: float = 0.2, md_flavor: MdFlavor = ase_fire_key, ) -> tuple[ - UnitCellFireState, + Callable[[SimState | StateDict], UnitCellFireState], Callable[[UnitCellFireState], UnitCellFireState], ]: """Initialize a batched FIRE optimization with unit cell degrees of freedom. @@ -976,7 +977,7 @@ class FrechetCellFIREState(SimState, DeformGradMixin): def frechet_cell_fire( - model: torch.nn.Module, + model: ModelInterface, *, dt_max: float = 1.0, dt_start: float = 0.1, @@ -992,7 +993,7 @@ def frechet_cell_fire( max_step: float = 0.2, md_flavor: MdFlavor = ase_fire_key, ) -> tuple[ - FrechetCellFIREState, + Callable[[SimState | StateDict], FrechetCellFIREState], Callable[[FrechetCellFIREState], FrechetCellFIREState], ]: """Initialize a batched FIRE optimization with Frechet cell parameterization. @@ -1204,7 +1205,7 @@ def fire_init( def _vv_fire_step( # noqa: C901, PLR0915 state: FireState | AnyFireCellState, - model: torch.nn.Module, + model: ModelInterface, *, dt_max: torch.Tensor, n_min: torch.Tensor, @@ -1420,7 +1421,7 @@ def _vv_fire_step( # noqa: C901, PLR0915 def _ase_fire_step( # noqa: C901, PLR0915 state: FireState | AnyFireCellState, - model: torch.nn.Module, + model: ModelInterface, *, dt_max: torch.Tensor, n_min: torch.Tensor, diff --git a/torch_sim/quantities.py b/torch_sim/quantities.py index 971b1b54..a1ac0811 100644 --- a/torch_sim/quantities.py +++ b/torch_sim/quantities.py @@ -1,5 +1,7 @@ """Functions for computing physical quantities.""" +from typing import cast + import torch from torch_sim.state import SimState @@ -21,8 +23,9 @@ def count_dof(tensor: torch.Tensor) -> int: # @torch.jit.script def calc_kT( # noqa: N802 - momenta: torch.Tensor, + *, masses: torch.Tensor, + momenta: torch.Tensor | None = None, velocities: torch.Tensor | None = None, system_idx: torch.Tensor | None = None, ) -> torch.Tensor: @@ -38,14 +41,12 @@ def calc_kT( # noqa: N802 Returns: torch.Tensor: Scalar temperature value """ - if momenta is not None and velocities is not None: - raise ValueError("Must pass either momenta or velocities, not both") - - if momenta is None and velocities is None: - raise ValueError("Must pass either momenta or velocities") + if not ((momenta is not None) ^ (velocities is not None)): + raise ValueError("Must pass either one of momenta or velocities") if momenta is None: # If velocity provided, calculate mv^2 + velocities = cast("torch.Tensor", velocities) squared_term = (velocities**2) * masses.unsqueeze(-1) else: # If momentum provided, calculate v^2 = p^2/m^2 @@ -70,11 +71,12 @@ def calc_kT( # noqa: N802 def calc_temperature( - momenta: torch.Tensor, + *, masses: torch.Tensor, + momenta: torch.Tensor | None = None, velocities: torch.Tensor | None = None, system_idx: torch.Tensor | None = None, - units: object = MetalUnits.temperature, + units: MetalUnits = MetalUnits.temperature, ) -> torch.Tensor: """Calculate temperature from momenta/velocities and masses. @@ -89,13 +91,17 @@ def calc_temperature( Returns: torch.Tensor: Temperature value in specified units """ - return calc_kT(momenta, masses, velocities, system_idx) / units + kT = calc_kT( + masses=masses, momenta=momenta, velocities=velocities, system_idx=system_idx + ) + return kT / units # @torch.jit.script def calc_kinetic_energy( - momenta: torch.Tensor, + *, masses: torch.Tensor, + momenta: torch.Tensor | None = None, velocities: torch.Tensor | None = None, system_idx: torch.Tensor | None = None, ) -> torch.Tensor: @@ -112,10 +118,8 @@ def calc_kinetic_energy( If system_idx is None: Scalar tensor containing the total kinetic energy If system_idx is provided: Tensor of kinetic energies per system """ - if momenta is not None and velocities is not None: - raise ValueError("Must pass either momenta or velocities, not both") - if momenta is None and velocities is None: - raise ValueError("Must pass either momenta or velocities") + if not ((momenta is not None) ^ (velocities is not None)): + raise ValueError("Must pass either one of momenta or velocities") if momenta is None: # Using velocities squared_term = (velocities**2) * masses.unsqueeze(-1) diff --git a/torch_sim/runners.py b/torch_sim/runners.py index 3c83724c..187cdd89 100644 --- a/torch_sim/runners.py +++ b/torch_sim/runners.py @@ -45,8 +45,12 @@ def _configure_reporter( "potential_energy": lambda state: state.energy, "forces": lambda state: state.forces, "stress": lambda state: state.stress, - "kinetic_energy": lambda state: calc_kinetic_energy(state.momenta, state.masses), - "temperature": lambda state: calc_kT(state.momenta, state.masses), + "kinetic_energy": lambda state: calc_kinetic_energy( + velocities=state.velocities, masses=state.masses + ), + "temperature": lambda state: calc_kT( + velocities=state.velocities, masses=state.masses + ), } prop_calculators = { diff --git a/torch_sim/state.py b/torch_sim/state.py index 5240d094..af4db6d1 100644 --- a/torch_sim/state.py +++ b/torch_sim/state.py @@ -8,12 +8,12 @@ import importlib import warnings from dataclasses import dataclass -from typing import TYPE_CHECKING, Literal, Self +from typing import TYPE_CHECKING, Literal, Self, cast import torch import torch_sim as ts -from torch_sim.typing import StateLike +from torch_sim.typing import SimStateVar, StateLike if TYPE_CHECKING: @@ -109,6 +109,7 @@ def __init__( self.pbc = pbc self.atomic_numbers = atomic_numbers + # Validate and process the state after initialization. # data validation and fill system_idx # should make pbc a tensor here # if devices aren't all the same, raise an error, in a clean way @@ -136,13 +137,13 @@ def __init__( self.n_atoms, device=self.device, dtype=torch.int64 ) else: - self.system_idx = system_idx # assert that system indices are unique consecutive integers # TODO(curtis): I feel like this logic is not reliable. # I'll come up with something better later. - _, counts = torch.unique_consecutive(self.system_idx, return_counts=True) - if not torch.all(counts == torch.bincount(self.system_idx)): + _, counts = torch.unique_consecutive(system_idx, return_counts=True) + if not torch.all(counts == torch.bincount(system_idx)): raise ValueError("System indices must be unique consecutive integers") + self.system_idx = system_idx if self.cell.ndim != 3 and system_idx is None: self.cell = self.cell.unsqueeze(0) @@ -251,7 +252,9 @@ def n_systems(self) -> int: @property def volume(self) -> torch.Tensor: """Volume of the system.""" - return torch.det(self.cell) if self.pbc else None + if not self.pbc: + raise ValueError("Volume is only defined for periodic systems") + return torch.det(self.cell) @property def column_vector_cell(self) -> torch.Tensor: @@ -361,7 +364,7 @@ def pop(self, system_indices: int | list[int] | slice | torch.Tensor) -> list[Se for attr_name, attr_value in vars(modified_state).items(): setattr(self, attr_name, attr_value) - return popped_states + return cast("list[Self]", popped_states) def to( self, device: torch.device | None = None, dtype: torch.dtype | None = None @@ -401,14 +404,8 @@ def __getitem__(self, system_indices: int | list[int] | slice | torch.Tensor) -> class DeformGradMixin: """Mixin for states that support deformation gradients.""" - @property - def momenta(self) -> torch.Tensor: - """Calculate momenta from velocities and masses. - - Returns: - The momenta of the particles - """ - return self.velocities * self.masses.unsqueeze(-1) + reference_cell: torch.Tensor + row_vector_cell: torch.Tensor @property def reference_row_vector_cell(self) -> torch.Tensor: @@ -483,10 +480,10 @@ def _normalize_system_indices( def state_to_device( - state: SimState, + state: SimStateVar, device: torch.device | None = None, dtype: torch.dtype | None = None, -) -> Self: +) -> SimStateVar: """Convert the SimState to a new device and dtype. Creates a new SimState with all tensors moved to the specified device and @@ -692,9 +689,9 @@ def _filter_attrs_by_mask( def _split_state( - state: SimState, + state: SimStateVar, ambiguous_handling: Literal["error", "globalize"] = "error", -) -> list[SimState]: +) -> list[SimStateVar]: """Split a SimState into a list of states, each containing a single system. Divides a multi-system state into individual single-system states, preserving @@ -805,10 +802,10 @@ def _pop_states( def _slice_state( - state: SimState, + state: SimStateVar, system_indices: list[int] | torch.Tensor, ambiguous_handling: Literal["error", "globalize"] = "error", -) -> SimState: +) -> SimStateVar: """Slice a substate from the SimState containing only the specified system indices. Creates a new SimState containing only the specified systems, preserving @@ -968,7 +965,7 @@ def initialize_state( return state_to_device(system, device, dtype) if isinstance(system, list) and all(isinstance(s, SimState) for s in system): - if not all(state.n_systems == 1 for state in system): + if not all(cast("SimState", state).n_systems == 1 for state in system): raise ValueError( "When providing a list of states, to the initialize_state function, " "all states must have n_systems == 1. To fix this, you can split the " diff --git a/torch_sim/trajectory.py b/torch_sim/trajectory.py index 3150064a..fb170754 100644 --- a/torch_sim/trajectory.py +++ b/torch_sim/trajectory.py @@ -38,6 +38,7 @@ import tables import torch +from torch_sim.models.interface import ModelInterface from torch_sim.state import SimState @@ -96,9 +97,9 @@ def __init__( state_frequency: int = 100, *, prop_calculators: dict[int, dict[str, Callable]] | None = None, - state_kwargs: dict | None = None, + state_kwargs: dict[str, Any] | None = None, metadata: dict[str, str] | None = None, - trajectory_kwargs: dict | None = None, + trajectory_kwargs: dict[str, Any] | None = None, ) -> None: """Initialize a TrajectoryReporter. @@ -203,7 +204,7 @@ def report( self, state: SimState, step: int, - model: torch.nn.Module | None = None, + model: ModelInterface | None = None, ) -> list[dict[str, torch.Tensor]]: """Report a state and step to the trajectory files. @@ -216,7 +217,7 @@ def report( len(filenames) step (int): Current simulation step, setting step to 0 will write the state and all properties. - model (torch.nn.Module, optional): Model used for simulation. + model (ModelInterface, optional): Model used for simulation. Defaults to None. Must be provided if any prop_calculators are provided. write_to_file (bool, optional): Whether to write the state to the trajectory diff --git a/torch_sim/workflows/a2c.py b/torch_sim/workflows/a2c.py index e3abb8d5..fb5bba0b 100644 --- a/torch_sim/workflows/a2c.py +++ b/torch_sim/workflows/a2c.py @@ -19,6 +19,7 @@ import torch_sim as ts from torch_sim import transforms +from torch_sim.models.interface import ModelInterface from torch_sim.models.soft_sphere import SoftSphereModel, SoftSphereMultiModel from torch_sim.optimizers import FireState, UnitCellFireState, fire from torch_sim.optimizers import unit_cell_fire as batched_unit_cell_fire @@ -228,7 +229,7 @@ def random_packed_structure( device: torch.device | None = None, dtype: torch.dtype = torch.float32, log: Any | None = None, -) -> FireState: +) -> FireState | tuple[FireState, list[np.ndarray]]: """Generates a random packed atomic structure and minimizes atomic overlaps. This function creates a random atomic structure within a given cell and optionally @@ -326,6 +327,7 @@ def random_packed_structure( if log is not None: return state, log + return state @@ -575,6 +577,13 @@ def get_subcells_to_crystallize( # Convert species list to numpy array for easier composition handling species_array = np.array(species) + if restrict_to_compositions is not None and restrict_to_compositions: + restrict_to_compositions: set[str] = { + Composition(comp).reduced_formula for comp in restrict_to_compositions + } + else: + restrict_to_compositions: set[str] = set() + # Generate allowed stoichiometries if max_coef is specified if max_coeff: if elements is None: @@ -583,17 +592,9 @@ def get_subcells_to_crystallize( stoichs = list(itertools.product(range(max_coeff + 1), repeat=len(elements))) stoichs.pop(0) # Remove the empty composition (0,0,...) # Convert stoichiometries to composition formulas - comps = [] for stoich in stoichs: comp = dict(zip(elements, stoich, strict=True)) - comps.append(Composition.from_dict(comp).reduced_formula) - restrict_to_compositions = set(comps) - - # Ensure compositions are in reduced formula form if provided - if restrict_to_compositions: - restrict_to_compositions = [ - Composition(comp).reduced_formula for comp in restrict_to_compositions - ] + restrict_to_compositions.add(Composition.from_dict(comp).reduced_formula) # Create orthorhombic grid for systematic subcell generation bins = int(1 / d_frac) @@ -610,7 +611,7 @@ def get_subcells_to_crystallize( .T ) - candidates = [] + candidates: list[tuple[torch.Tensor, torch.Tensor, torch.Tensor]] = [] # Iterate through all possible subcell boundary combinations for lb, ub in itertools.product(l_bound, u_bound): if torch.all(ub > lb): # Ensure valid subcell dimensions @@ -705,9 +706,9 @@ def get_target_temperature( def get_unit_cell_relaxed_structure( state: ts.SimState, - model: torch.nn.Module, + model: ModelInterface, max_iter: int = 200, -) -> tuple[UnitCellFireState, dict]: +) -> tuple[UnitCellFireState, dict[str, torch.Tensor], list[float], list[float]]: """Relax both atomic positions and cell parameters using FIRE algorithm. This function performs geometry optimization of both atomic positions and unit cell @@ -751,8 +752,8 @@ def get_unit_cell_relaxed_structure( state = unit_cell_fire_init(state) def step_fn( - step: int, state: UnitCellFireState, logger: dict - ) -> tuple[UnitCellFireState, dict]: + step: int, state: UnitCellFireState, logger: dict[str, torch.Tensor] + ) -> tuple[UnitCellFireState, dict[str, torch.Tensor]]: logger["energy"][step] = state.energy logger["stress"][step] = state.stress state = unit_cell_fire_update(state) @@ -772,61 +773,3 @@ def step_fn( f"Final pressure: {[f'{p:.4f}' for p in final_pressure]} eV/A^3" ) return state, logger, final_energy, final_pressure - - -def get_relaxed_structure( - state: ts.SimState, - model: torch.nn.Module, - max_iter: int = 200, -) -> tuple[FireState, dict]: - """Relax atomic positions at fixed cell parameters using FIRE algorithm. - - Does geometry optimization of atomic positions while keeping the unit cell fixed. - Uses the Fast Inertial Relaxation Engine (FIRE) algorithm to minimize forces on atoms. - - Args: - state: State containing positions, cell and atomic numbers - model: Model to compute energies, forces, and stresses - max_iter: Maximum number of FIRE iterations. Defaults to 200. - - Returns: - tuple containing: - - FIREState: Final state containing relaxed positions and other quantities - - dict: Logger with energy trajectory - - float: Final energy in eV - - float: Final pressure in eV/ų - """ - # Get device and dtype from model - device, dtype = model.device, model.dtype - - logger = {"energy": torch.zeros((max_iter, 1), device=device, dtype=dtype)} - - results = model(state) - Initial_energy = results["energy"] - print(f"Initial energy: {Initial_energy.item():.4f} eV") - - state_init_fn, fire_update = fire(model=model) - state = state_init_fn(state) - - def step_fn(idx: int, state: FireState, logger: dict) -> tuple[FireState, dict]: - logger["energy"][idx] = state.energy - state = fire_update(state) - return state, logger - - for idx in range(max_iter): - state, logger = step_fn(idx, state, logger) - - # Get final results - model.compute_stress = True - final_results = model( - positions=state.positions, cell=state.cell, atomic_numbers=state.atomic_numbers - ) - - final_energy = final_results["energy"].item() - final_stress = final_results["stress"] - final_pressure = (torch.trace(final_stress) / 3.0).item() - print( - f"Final energy: {final_energy:.4f} eV, " - f"Final pressure: {final_pressure:.4f} eV/A^3" - ) - return state, logger, final_energy, final_pressure