Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion .github/PULL_REQUEST_TEMPLATE.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,8 @@
Before a pull request can be merged, the following items must be checked:

* [ ] Doc strings have been added in the [Google docstring format](https://sphinxcontrib-napoleon.readthedocs.io/en/latest/example_google.html#example-google).
Run [ruff](https://beta.ruff.rs/docs/rules/#pydocstyle-d) on your code.
* [ ] Run [ruff](https://beta.ruff.rs/docs/rules/#pydocstyle-d) on your code.
* [ ] Run `uvx ty check` on the repo.
* [ ] Tests have been added for any new functionality or bug fixes.

We highly recommended installing the pre-commit hooks running in CI locally to speedup the development process. Simply run `pip install pre-commit && pre-commit install` to install the hooks which will check your code before each commit.
Original file line number Diff line number Diff line change
Expand Up @@ -89,8 +89,8 @@
positions=positions,
masses=masses,
cell=cell.unsqueeze(0),
pbc=True,
atomic_numbers=atomic_numbers,
pbc=True,
)

# Run initial simulation and get results
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -80,8 +80,8 @@
positions=positions,
masses=masses,
cell=cell.unsqueeze(0),
pbc=True,
atomic_numbers=atomic_numbers,
pbc=True,
)

# Initialize the Soft Sphere model
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -97,8 +97,8 @@
positions=positions,
masses=masses,
cell=cell.unsqueeze(0),
pbc=True,
atomic_numbers=atomic_numbers,
pbc=True,
)
# Run initial simulation and get results
results = model(state)
Expand Down Expand Up @@ -148,11 +148,11 @@


stress = model(state)["stress"]
calc_kinetic_energy = calc_kinetic_energy(
kinetic_energy = calc_kinetic_energy(
masses=state.masses, momenta=state.momenta, system_idx=state.system_idx
)
volume = torch.linalg.det(state.cell)
pressure = get_pressure(stress, calc_kinetic_energy, volume)
pressure = get_pressure(stress, kinetic_energy, volume)
pressure = pressure.item() / Units.pressure
print(f"Final {pressure=:.4f}")
print(stress * UnitConversion.eV_per_Ang3_to_GPa)
6 changes: 1 addition & 5 deletions examples/scripts/3_Dynamics/3.1_Lennard_Jones_NVE.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,11 +78,7 @@
masses = torch.full((positions.shape[0],), 39.948, device=device, dtype=dtype)

state = ts.SimState(
positions=positions,
masses=masses,
cell=cell,
pbc=True,
atomic_numbers=atomic_numbers,
positions=positions, masses=masses, cell=cell, atomic_numbers=atomic_numbers, pbc=True
)
# Initialize the Lennard-Jones model
# Parameters:
Expand Down
6 changes: 1 addition & 5 deletions examples/scripts/3_Dynamics/3.2_MACE_NVE.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,11 +60,7 @@
)

state = ts.SimState(
positions=positions,
masses=masses,
cell=cell,
pbc=True,
atomic_numbers=atomic_numbers,
positions=positions, masses=masses, cell=cell, atomic_numbers=atomic_numbers, pbc=True
)

# Run initial inference
Expand Down
6 changes: 1 addition & 5 deletions examples/scripts/3_Dynamics/3.4_MACE_NVT_Langevin.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,11 +59,7 @@
)

state = ts.SimState(
positions=positions,
masses=masses,
cell=cell,
pbc=True,
atomic_numbers=atomic_numbers,
positions=positions, masses=masses, cell=cell, atomic_numbers=atomic_numbers, pbc=True
)

dt = 0.002 * Units.time # Timestep (ps)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -96,8 +96,8 @@
positions=positions,
masses=masses,
cell=cell.unsqueeze(0),
pbc=True,
atomic_numbers=atomic_numbers,
pbc=True,
)
# Run initial simulation and get results
results = model(state)
Expand Down
4 changes: 3 additions & 1 deletion examples/scripts/4_High_level_api/4.1_high_level_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,9 @@
prop_calculators = {
10: {"potential_energy": lambda state: state.energy},
20: {
"kinetic_energy": lambda state: calc_kinetic_energy(state.momenta, state.masses)
"kinetic_energy": lambda state: calc_kinetic_energy(
momenta=state.momenta, masses=state.masses
)
},
}

Expand Down
7 changes: 4 additions & 3 deletions examples/scripts/6_Phonons/6.2_QuasiHarmonic_MACE.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,12 +24,13 @@
from phonopy.structure.atoms import PhonopyAtoms

import torch_sim as ts
from torch_sim.models.interface import ModelInterface
from torch_sim.models.mace import MaceModel, MaceUrls


def get_relaxed_structure(
struct: Atoms,
model: torch.nn.Module | None,
model: ModelInterface,
Nrelax: int = 300,
fmax: float = 1e-3,
*,
Expand Down Expand Up @@ -80,7 +81,7 @@ def get_relaxed_structure(
def get_qha_structures(
state: ts.state.SimState,
length_factors: np.ndarray,
model: torch.nn.Module | None,
model: ModelInterface,
Nmax: int = 300,
fmax: float = 1e-3,
*,
Expand Down Expand Up @@ -129,7 +130,7 @@ def get_qha_structures(

def get_qha_phonons(
scaled_structures: list[PhonopyAtoms],
model: torch.nn.Module | None,
model: ModelInterface,
supercell_matrix: np.ndarray | None,
displ: float = 0.05,
*,
Expand Down
2 changes: 1 addition & 1 deletion examples/tutorials/high_level_tutorial.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,7 @@
10: {"potential_energy": lambda state: state.energy},
20: {
"kinetic_energy": lambda state: ts.calc_kinetic_energy(
state.momenta, state.masses
momenta=state.momenta, masses=state.masses
)
},
}
Expand Down
3 changes: 2 additions & 1 deletion examples/tutorials/reporting_tutorial.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,6 +206,7 @@

# %%
from torch_sim.models.lennard_jones import LennardJonesModel
from torch_sim.models.interface import ModelInterface


# Define some property calculators
Expand All @@ -214,7 +215,7 @@ def calculate_com(state: ts.state.SimState) -> torch.Tensor:
return torch.mean(state.positions * state.masses.unsqueeze(1), dim=0)


def calculate_energy(state: ts.state.SimState, model: torch.nn.Module) -> torch.Tensor:
def calculate_energy(state: ts.state.SimState, model: ModelInterface) -> torch.Tensor:
"""Calculate energy - needs both state and model"""
return model(state)["energy"]

Expand Down
4 changes: 2 additions & 2 deletions tests/models/test_mattersim.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ def pretrained_mattersim_model(device: torch.device, model_name: str):

@pytest.fixture
def mattersim_model(
pretrained_mattersim_model: torch.nn.Module, device: torch.device
pretrained_mattersim_model: Potential, device: torch.device
) -> MatterSimModel:
"""Create an MatterSimModel wrapper for the pretrained model."""
return MatterSimModel(
Expand All @@ -66,7 +66,7 @@ def mattersim_calculator(


def test_mattersim_initialization(
pretrained_mattersim_model: torch.nn.Module, device: torch.device
pretrained_mattersim_model: Potential, device: torch.device
) -> None:
"""Test that the MatterSim model initializes correctly."""
model = MatterSimModel(
Expand Down
5 changes: 3 additions & 2 deletions tests/models/test_sevennet.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
try:
import sevenn.util
from sevenn.calculator import SevenNetCalculator
from sevenn.nn.sequential import AtomGraphSequential

from torch_sim.models.sevennet import SevenNetModel

Expand Down Expand Up @@ -50,7 +51,7 @@ def pretrained_sevenn_model(device: torch.device, model_name: str):

@pytest.fixture
def sevenn_model(
pretrained_sevenn_model: torch.nn.Module, device: torch.device, modal_name: str
pretrained_sevenn_model: AtomGraphSequential, device: torch.device, modal_name: str
) -> SevenNetModel:
"""Create an SevenNetModel wrapper for the pretrained model."""
return SevenNetModel(
Expand All @@ -69,7 +70,7 @@ def sevenn_calculator(


def test_sevennet_initialization(
pretrained_sevenn_model: torch.nn.Module, device: torch.device
pretrained_sevenn_model: AtomGraphSequential, device: torch.device
) -> None:
"""Test that the SevenNet model initializes correctly."""
model = SevenNetModel(
Expand Down
16 changes: 12 additions & 4 deletions tests/test_integrators.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,9 @@ def test_npt_langevin(ar_double_sim_state: ts.SimState, lj_model: LennardJonesMo
state = update_fn(state=state)

# Calculate instantaneous temperature from kinetic energy
temp = calc_kT(state.momenta, state.masses, system_idx=state.system_idx)
temp = calc_kT(
masses=state.masses, momenta=state.momenta, system_idx=state.system_idx
)
energies.append(state.energy)
temperatures.append(temp / MetalUnits.temperature)

Expand Down Expand Up @@ -172,7 +174,9 @@ def test_npt_langevin_multi_kt(
state = update_fn(state=state)

# Calculate instantaneous temperature from kinetic energy
temp = calc_kT(state.momenta, state.masses, system_idx=state.system_idx)
temp = calc_kT(
masses=state.masses, momenta=state.momenta, system_idx=state.system_idx
)
energies.append(state.energy)
temperatures.append(temp / MetalUnits.temperature)

Expand Down Expand Up @@ -213,7 +217,9 @@ def test_nvt_langevin(ar_double_sim_state: ts.SimState, lj_model: LennardJonesMo
state = update_fn(state=state)

# Calculate instantaneous temperature from kinetic energy
temp = calc_kT(state.momenta, state.masses, system_idx=state.system_idx)
temp = calc_kT(
masses=state.masses, momenta=state.momenta, system_idx=state.system_idx
)
energies.append(state.energy)
temperatures.append(temp / MetalUnits.temperature)

Expand Down Expand Up @@ -273,7 +279,9 @@ def test_nvt_langevin_multi_kt(
state = update_fn(state=state)

# Calculate instantaneous temperature from kinetic energy
temp = calc_kT(state.momenta, state.masses, system_idx=state.system_idx)
temp = calc_kT(
masses=state.masses, momenta=state.momenta, system_idx=state.system_idx
)
energies.append(state.energy)
temperatures.append(temp / MetalUnits.temperature)

Expand Down
3 changes: 2 additions & 1 deletion tests/test_monte_carlo.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from pymatgen.core import Structure

import torch_sim as ts
from torch_sim.models.interface import ModelInterface
from torch_sim.monte_carlo import (
SwapMCState,
generate_swaps,
Expand Down Expand Up @@ -112,7 +113,7 @@ def test_validate_permutation(batched_diverse_state: ts.SimState):

def test_monte_carlo(
batched_diverse_state: ts.SimState,
lj_model: torch.nn.Module,
lj_model: ModelInterface,
):
"""Test the monte_carlo function that returns a step function and initial state."""
# Call monte_carlo to get the initial state and step function
Expand Down
Loading