From 4157a20cec17069de2da1a46d1ca6f25ef890a70 Mon Sep 17 00:00:00 2001 From: thomasloux Date: Thu, 18 Sep 2025 13:57:07 +0000 Subject: [PATCH 01/13] fix:orb squeeze incorrect energy shape --- torch_sim/models/orb.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torch_sim/models/orb.py b/torch_sim/models/orb.py index fd65b23f..132f6d5c 100644 --- a/torch_sim/models/orb.py +++ b/torch_sim/models/orb.py @@ -416,7 +416,7 @@ def forward(self, state: ts.SimState | StateDict) -> dict[str, torch.Tensor]: if not model_has_direct_heads and prop == "stress": continue _property = "energy" if prop == "free_energy" else prop - results[prop] = predictions[_property].squeeze() + results[prop] = predictions[_property] if self.conservative: results["forces"] = results[self.model.grad_forces_name] From 4042130f060ccab701814e0553f00f55a37f97cc Mon Sep 17 00:00:00 2001 From: thomasloux Date: Mon, 29 Sep 2025 17:10:19 +0000 Subject: [PATCH 02/13] =?UTF-8?q?feature:=20add=20batching=20for=20nvt=20n?= =?UTF-8?q?os=C3=A9=20hoover?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- torch_sim/integrators/md.py | 132 ++++++++++++++++++++++++----------- torch_sim/integrators/nvt.py | 26 +++---- 2 files changed, 106 insertions(+), 52 deletions(-) diff --git a/torch_sim/integrators/md.py b/torch_sim/integrators/md.py index 490e3528..312a6c85 100644 --- a/torch_sim/integrators/md.py +++ b/torch_sim/integrators/md.py @@ -208,13 +208,13 @@ class NoseHooverChain: in the chain has its own positions, momenta, and masses. Attributes: - positions: Positions of the chain thermostats. Shape: [chain_length] - momenta: Momenta of the chain thermostats. Shape: [chain_length] - masses: Masses of the chain thermostats. Shape: [chain_length] + positions: Positions of the chain thermostats. Shape: [n_systems, chain_length] + momenta: Momenta of the chain thermostats. Shape: [n_systems, chain_length] + masses: Masses of the chain thermostats. Shape: [n_systems, chain_length] tau: Thermostat relaxation time. Longer values give better stability - but worse temperature control. Shape: scalar - kinetic_energy: Current kinetic energy of the coupled system. Shape: scalar - degrees_of_freedom: Number of degrees of freedom in the coupled system + but worse temperature control. Shape: [n_systems] or scalar + kinetic_energy: Current kinetic energy of the coupled system. Shape: [n_systems] + degrees_of_freedom: Number of degrees of freedom per system. Shape: [n_systems] """ positions: torch.Tensor @@ -222,7 +222,8 @@ class NoseHooverChain: masses: torch.Tensor tau: torch.Tensor kinetic_energy: torch.Tensor - degrees_of_freedom: int + degrees_of_freedom: torch.Tensor + system_idx: torch.Tensor | None = None @dataclass @@ -267,7 +268,7 @@ class NoseHooverChainFns: } -def construct_nose_hoover_chain( +def construct_nose_hoover_chain( # noqa: C901 PLR0915 dt: torch.Tensor, chain_length: int, chain_steps: int, @@ -306,14 +307,14 @@ def construct_nose_hoover_chain( """ def init_fn( - degrees_of_freedom: int, KE: torch.Tensor, kT: torch.Tensor + degrees_of_freedom: torch.Tensor, KE: torch.Tensor, kT: torch.Tensor ) -> NoseHooverChain: """Initialize a Nose-Hoover chain state. Args: - degrees_of_freedom: Number of degrees of freedom in coupled system - KE: Initial kinetic energy of the system - kT: Target temperature in energy units + degrees_of_freedom: Number of degrees of freedom per system, shape [n_systems] + KE: Initial kinetic energy per system, shape [n_systems] + kT: Target temperature in energy units, shape [n_systems] or scalar Returns: Initial NoseHooverChain state @@ -321,16 +322,40 @@ def init_fn( device = KE.device dtype = KE.dtype - xi = torch.zeros(chain_length, dtype=dtype, device=device) - p_xi = torch.zeros(chain_length, dtype=dtype, device=device) - - Q = kT * tau**2 * torch.ones(chain_length, dtype=dtype, device=device) - Q[0] *= degrees_of_freedom + # Ensure n_systems is determined from KE shape + n_systems = KE.shape[0] if KE.dim() > 0 else 1 + + # Initialize chain variables with proper batch dimensions + xi = torch.zeros((n_systems, chain_length), dtype=dtype, device=device) + p_xi = torch.zeros((n_systems, chain_length), dtype=dtype, device=device) + + # Broadcast tau to match n_systems + if isinstance(tau, torch.Tensor): + tau_batched = tau.expand(n_systems) if tau.dim() == 0 else tau + else: + tau_batched = torch.full((n_systems,), tau, dtype=dtype, device=device) + + # Ensure kT has proper batch dimension + if isinstance(kT, torch.Tensor): + kT_batched = kT.expand(n_systems) if kT.dim() == 0 else kT + else: + kT_batched = torch.full((n_systems,), kT, dtype=dtype, device=device) + + Q = ( + kT_batched.unsqueeze(-1) + * tau_batched.unsqueeze(-1) ** 2 + * torch.ones((n_systems, chain_length), dtype=dtype, device=device) + ) + Q[:, 0] *= degrees_of_freedom - return NoseHooverChain(xi, p_xi, Q, tau, KE, degrees_of_freedom) + return NoseHooverChain(xi, p_xi, Q, tau_batched, KE, degrees_of_freedom) def substep_fn( - delta: torch.Tensor, P: torch.Tensor, state: NoseHooverChain, kT: torch.Tensor + delta: torch.Tensor, + P: torch.Tensor, + state: NoseHooverChain, + kT: torch.Tensor, + system_idx: torch.Tensor, ) -> tuple[torch.Tensor, NoseHooverChain, torch.Tensor]: """Perform single update of chain parameters and rescale velocities. @@ -339,6 +364,7 @@ def substep_fn( P: System momenta to be rescaled state: Current chain state kT: Target temperature + system_idx: Index of the system being evolved Returns: Tuple of (rescaled momenta, updated chain state, temperature) @@ -358,40 +384,52 @@ def substep_fn( M = chain_length - 1 + # Ensure kT has proper batch dimension + if isinstance(kT, torch.Tensor): + kT_batched = kT.expand(KE.shape[0]) if kT.dim() == 0 else kT + else: + kT_batched = torch.full_like(KE, kT) + # Update chain momenta backwards - G = p_xi[M - 1] ** 2 / Q[M - 1] - kT - p_xi[M] += delta_4 * G + G = p_xi[:, M - 1] ** 2 / Q[:, M - 1] - kT_batched + p_xi[:, M] += delta_4 * G for m in range(M - 1, 0, -1): - G = p_xi[m - 1] ** 2 / Q[m - 1] - kT - scale = torch.exp(-delta_8 * p_xi[m + 1] / Q[m + 1]) - p_xi[m] = scale * (scale * p_xi[m] + delta_4 * G) + G = p_xi[:, m - 1] ** 2 / Q[:, m - 1] - kT_batched + scale = torch.exp(-delta_8 * p_xi[:, m + 1] / Q[:, m + 1]) + p_xi[:, m] = scale * (scale * p_xi[:, m] + delta_4 * G) # Update system coupling - G = 2.0 * KE - DOF * kT - scale = torch.exp(-delta_8 * p_xi[1] / Q[1]) - p_xi[0] = scale * (scale * p_xi[0] + delta_4 * G) + G = 2.0 * KE - DOF * kT_batched + scale = torch.exp(-delta_8 * p_xi[:, 1] / Q[:, 1]) + p_xi[:, 0] = scale * (scale * p_xi[:, 0] + delta_4 * G) # Rescale system momenta - scale = torch.exp(-delta_2 * p_xi[0] / Q[0]) + scale = torch.exp(-delta_2 * p_xi[:, 0] / Q[:, 0]) KE = KE * scale**2 - P = P * scale + + # Apply scale to momenta - need to map from system to atom indices + atom_scale = scale[system_idx].unsqueeze(-1) + P = P * atom_scale # Update positions xi = xi + delta_2 * p_xi / Q # Update chain momenta forwards - G = 2.0 * KE - DOF * kT + G = 2.0 * KE - DOF * kT_batched for m in range(M): - scale = torch.exp(-delta_8 * p_xi[m + 1] / Q[m + 1]) - p_xi[m] = scale * (scale * p_xi[m] + delta_4 * G) - G = p_xi[m] ** 2 / Q[m] - kT - p_xi[M] += delta_4 * G + scale = torch.exp(-delta_8 * p_xi[:, m + 1] / Q[:, m + 1]) + p_xi[:, m] = scale * (scale * p_xi[:, m] + delta_4 * G) + G = p_xi[:, m] ** 2 / Q[:, m] - kT_batched + p_xi[:, M] += delta_4 * G - return P, NoseHooverChain(xi, p_xi, Q, _tau, KE, DOF), kT + return P, NoseHooverChain(xi, p_xi, Q, _tau, KE, DOF), kT_batched def half_step_chain_fn( - P: torch.Tensor, state: NoseHooverChain, kT: torch.Tensor + P: torch.Tensor, + state: NoseHooverChain, + kT: torch.Tensor, + system_idx: torch.Tensor, ) -> tuple[torch.Tensor, NoseHooverChain]: """Evolve chain for half timestep using multi-timestep integration. @@ -399,12 +437,13 @@ def half_step_chain_fn( P: System momenta to be rescaled state: Current chain state kT: Target temperature + system_idx: Index of the system being evolved Returns: Tuple of (rescaled momenta, updated chain state) """ if chain_steps == 1 and sy_steps == 1: - P, state, _ = substep_fn(dt, P, state, kT) + P, state, _ = substep_fn(dt, P, state, kT, system_idx) return P, state delta = dt / chain_steps @@ -412,7 +451,7 @@ def half_step_chain_fn( for step in range(chain_steps * sy_steps): d = delta * weights[step % sy_steps] - P, state, _ = substep_fn(d, P, state, kT) + P, state, _ = substep_fn(d, P, state, kT, system_idx) return P, state @@ -429,8 +468,21 @@ def update_chain_mass_fn(state: NoseHooverChain, kT: torch.Tensor) -> NoseHoover device = state.positions.device dtype = state.positions.dtype - Q = kT * state.tau**2 * torch.ones(chain_length, dtype=dtype, device=device) - Q[0] *= state.degrees_of_freedom + # Get number of systems + n_systems = state.kinetic_energy.shape[0] + + # Ensure kT has proper batch dimension + if isinstance(kT, torch.Tensor): + kT_batched = kT.expand(n_systems) if kT.dim() == 0 else kT + else: + kT_batched = torch.full((n_systems,), kT, dtype=dtype, device=device) + + Q = ( + kT_batched.unsqueeze(-1) + * state.tau.unsqueeze(-1) ** 2 + * torch.ones((n_systems, chain_length), dtype=dtype, device=device) + ) + Q[:, 0] *= state.degrees_of_freedom return NoseHooverChain( state.positions, diff --git a/torch_sim/integrators/nvt.py b/torch_sim/integrators/nvt.py index 18f0ae15..18c6ba16 100644 --- a/torch_sim/integrators/nvt.py +++ b/torch_sim/integrators/nvt.py @@ -6,7 +6,7 @@ import torch -from torch_sim.integrators.md import ( +from torch_sim.integrators.md_batched import ( MDState, NoseHooverChain, NoseHooverChainFns, @@ -377,14 +377,14 @@ def nvt_nose_hoover_init( ) # Calculate degrees of freedom per system - n_atoms_per_system = torch.bincount(state.system_idx) + n_atoms_per_system = state.n_atoms_per_system dof_per_system = ( n_atoms_per_system * state.positions.shape[-1] ) # n_atoms * n_dimensions # For now, sum the per-system DOF as chain expects a single int # This is a limitation that should be addressed in the chain implementation - total_dof = int(dof_per_system.sum().item()) + # total_dof = int(dof_per_system.sum().item()) # Initialize state state = NVTNoseHooverState( @@ -397,7 +397,7 @@ def nvt_nose_hoover_init( pbc=state.pbc, atomic_numbers=atomic_numbers, system_idx=state.system_idx, - chain=chain_fns.initialize(total_dof, KE, kT), + chain=chain_fns.initialize(dof_per_system, KE, kT), _chain_fns=chain_fns, # Store the chain functions ) return state # noqa: RET504 @@ -430,10 +430,10 @@ def nvt_nose_hoover_update( chain = state.chain # Update chain masses based on target temperature - chain = chain_fns.update_mass(chain, kT) + # chain = chain_fns.update_mass(chain, kT) # First half-step of chain evolution - momenta, chain = chain_fns.half_step(state.momenta, chain, kT) + momenta, chain = chain_fns.half_step(state.momenta, chain, kT, state.system_idx) state.momenta = momenta # Full velocity Verlet step @@ -446,7 +446,7 @@ def nvt_nose_hoover_update( chain.kinetic_energy = KE # Second half-step of chain evolution - momenta, chain = chain_fns.half_step(state.momenta, chain, kT) + momenta, chain = chain_fns.half_step(state.momenta, chain, kT, state.system_idx) state.momenta = momenta state.chain = chain @@ -500,8 +500,8 @@ def nvt_nose_hoover_invariant( # Add first thermostat term c = state.chain # Ensure chain momenta and masses broadcast correctly with batch dimensions - chain_ke_0 = c.momenta[0] ** 2 / (2 * c.masses[0]) - chain_pe_0 = dof * kT * c.positions[0] + chain_ke_0 = c.momenta[:, 0] ** 2 / (2 * c.masses[:, 0]) + chain_pe_0 = dof * kT * c.positions[:, 0] # If chain variables are scalars but we have batches, broadcast them if chain_ke_0.numel() == 1 and e_tot.numel() > 1: @@ -512,9 +512,11 @@ def nvt_nose_hoover_invariant( e_tot = e_tot + chain_ke_0 + chain_pe_0 # Add remaining chain terms - for pos, momentum, mass in zip( - c.positions[1:], c.momenta[1:], c.masses[1:], strict=True - ): + for i in range(1, c.positions.shape[1]): + pos = c.positions[:, i] + momentum = c.momenta[:, i] + mass = c.masses[:, i] + chain_ke = momentum**2 / (2 * mass) chain_pe = kT * pos From ee5ff165529ab72f8586926582f876d985d47f3c Mon Sep 17 00:00:00 2001 From: thomasloux Date: Mon, 29 Sep 2025 17:11:29 +0000 Subject: [PATCH 03/13] =?UTF-8?q?feature:=20add=20tests=20for=20nvt=20nos?= =?UTF-8?q?=C3=A9=20hoover?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- tests/test_integrators.py | 177 ++++++++++++++++++++++++++++++ torch_sim/integrators/__init__.py | 2 +- 2 files changed, 178 insertions(+), 1 deletion(-) diff --git a/tests/test_integrators.py b/tests/test_integrators.py index d5b210d1..43a141b4 100644 --- a/tests/test_integrators.py +++ b/tests/test_integrators.py @@ -8,6 +8,8 @@ npt_langevin, nve, nvt_langevin, + nvt_nose_hoover, + nvt_nose_hoover_invariant, ) from torch_sim.models.lennard_jones import LennardJonesModel from torch_sim.quantities import calc_kT @@ -301,6 +303,181 @@ def test_nvt_langevin_multi_kt( assert torch.allclose(mean_temps, kT / MetalUnits.temperature, rtol=0.5) +def test_nvt_nose_hoover(ar_double_sim_state: ts.SimState, lj_model: LennardJonesModel): + dtype = torch.float64 + n_steps = 100 + dt = torch.tensor(0.001, dtype=dtype) + kT = torch.tensor(300, dtype=dtype) * MetalUnits.temperature + + # Initialize integrator + init_fn, update_fn = nvt_nose_hoover( + model=lj_model, + dt=dt, + kT=kT, + ) + + # Run dynamics for several steps + state = init_fn(state=ar_double_sim_state, seed=42) + energies = [] + temperatures = [] + invariants = [] + for _step in range(n_steps): + state = update_fn(state=state) + + # Calculate instantaneous temperature from kinetic energy + temp = calc_kT( + masses=state.masses, momenta=state.momenta, system_idx=state.system_idx + ) + energies.append(state.energy) + temperatures.append(temp / MetalUnits.temperature) + invariants.append(nvt_nose_hoover_invariant(state, kT)) + + # Convert temperatures list to tensor + temperatures_tensor = torch.stack(temperatures) + temperatures_list = [t.tolist() for t in temperatures_tensor.T] + + energies_tensor = torch.stack(energies) + energies_list = [t.tolist() for t in energies_tensor.T] + + invariants_tensor = torch.stack(invariants) + + # Basic sanity checks + assert len(energies_list[0]) == n_steps + assert len(temperatures_list[0]) == n_steps + + # Check temperature is roughly maintained for each trajectory + mean_temps = torch.mean(temperatures_tensor, dim=0) # Mean temp for each trajectory + for mean_temp in mean_temps: + assert ( + abs(mean_temp - kT.item() / MetalUnits.temperature) < 100.0 + ) # Allow for thermal fluctuations + + # Check energy is stable for each trajectory + for traj in energies_list: + energy_std = torch.tensor(traj).std() + assert energy_std < 1.0 # Adjust threshold as needed + + # Check invariant conservation (should be roughly constant) + for traj_idx in range(invariants_tensor.shape[1]): + invariant_traj = invariants_tensor[:, traj_idx] + invariant_std = invariant_traj.std() + # Allow for some drift but should be relatively stable + # Less than 10% relative variation + assert invariant_std / invariant_traj.mean() < 0.1 + + # Check positions and momenta have correct shapes + n_atoms = 8 + + # Verify the two systems remain distinct + pos_diff = torch.norm( + state.positions[:n_atoms].mean(0) - state.positions[n_atoms:].mean(0) + ) + assert pos_diff > 0.0001 # Systems should remain separated + + +def test_nvt_nose_hoover_multi_equivalent_to_single( + mixed_double_sim_state: ts.SimState, lj_model: LennardJonesModel +): + """Test that nvt_nose_hoover with multiple identical kT values behaves like + running different single kT, assuming same initial state + (most importantly same momenta).""" + dtype = torch.float64 + n_steps = 100 + dt = torch.tensor(0.001, dtype=dtype) + kT = torch.tensor(300, dtype=dtype) * MetalUnits.temperature + + # Initialize integrator + init_fn, update_fn = nvt_nose_hoover( + model=lj_model, + dt=dt, + kT=kT, + ) + final_temperatures = [] + initial_momenta = [] + # Run dynamics for several steps + for i in range(mixed_double_sim_state.n_systems): + state = init_fn(state=mixed_double_sim_state[i], seed=42) + initial_momenta.append(state.momenta.clone()) + for _step in range(n_steps): + state = update_fn(state=state) + + # Calculate instantaneous temperature from kinetic energy + temp = calc_kT( + masses=state.masses, momenta=state.momenta, system_idx=state.system_idx + ) + final_temperatures.append(temp / MetalUnits.temperature) + + initial_momenta_tensor = torch.concat(initial_momenta) + final_temperatures = torch.concat(final_temperatures) + state = init_fn(state=mixed_double_sim_state, seed=42, momenta=initial_momenta_tensor) + for _step in range(n_steps): + state = update_fn(state=state) + + # Calculate instantaneous temperature from kinetic energy + temp = calc_kT( + masses=state.masses, momenta=state.momenta, system_idx=state.system_idx + ) + + assert torch.allclose(final_temperatures, temp / MetalUnits.temperature) + + +def test_nvt_nose_hoover_multi_kt( + ar_double_sim_state: ts.SimState, lj_model: LennardJonesModel +): + dtype = torch.float64 + n_steps = 200 + dt = torch.tensor(0.001, dtype=dtype) + kT = torch.tensor([300, 10_000], dtype=dtype) * MetalUnits.temperature + + # Initialize integrator + init_fn, update_fn = nvt_nose_hoover( + model=lj_model, + dt=dt, + kT=kT, + ) + + # Run dynamics for several steps + state = init_fn(state=ar_double_sim_state, seed=42) + energies = [] + temperatures = [] + invariants = [] + for _step in range(n_steps): + state = update_fn(state=state) + + # Calculate instantaneous temperature from kinetic energy + temp = calc_kT( + masses=state.masses, momenta=state.momenta, system_idx=state.system_idx + ) + energies.append(state.energy) + temperatures.append(temp / MetalUnits.temperature) + invariants.append(nvt_nose_hoover_invariant(state, kT)) + + # Convert temperatures list to tensor + temperatures_tensor = torch.stack(temperatures) + temperatures_list = [t.tolist() for t in temperatures_tensor.T] + + energies_tensor = torch.stack(energies) + energies_list = [t.tolist() for t in energies_tensor.T] + + invariants_tensor = torch.stack(invariants) + + # Basic sanity checks + assert len(energies_list[0]) == n_steps + assert len(temperatures_list[0]) == n_steps + + # Check temperature is roughly maintained for each trajectory + mean_temps = torch.mean(temperatures_tensor, dim=0) # Mean temp for each trajectory + assert torch.allclose(mean_temps, kT / MetalUnits.temperature, rtol=0.5) + + # Check invariant conservation for each system + for traj_idx in range(invariants_tensor.shape[1]): + invariant_traj = invariants_tensor[:, traj_idx] + invariant_std = invariant_traj.std() + # Allow for some drift but should be relatively stable + # Less than 10% relative variation + assert invariant_std / invariant_traj.mean() < 0.1 + + def test_nve(ar_double_sim_state: ts.SimState, lj_model: LennardJonesModel): dtype = torch.float64 n_steps = 100 diff --git a/torch_sim/integrators/__init__.py b/torch_sim/integrators/__init__.py index e38c1925..f0e040b1 100644 --- a/torch_sim/integrators/__init__.py +++ b/torch_sim/integrators/__init__.py @@ -24,4 +24,4 @@ from .md import MDState, calculate_momenta, momentum_step, position_step, velocity_verlet from .npt import NPTLangevinState, npt_langevin from .nve import nve -from .nvt import nvt_langevin +from .nvt import nvt_langevin, nvt_nose_hoover, nvt_nose_hoover_invariant From 7f24199a015e3273fcba07334fde3ca5d7c662e3 Mon Sep 17 00:00:00 2001 From: thomasloux Date: Tue, 30 Sep 2025 07:49:16 +0000 Subject: [PATCH 04/13] fix typo from development --- torch_sim/integrators/nvt.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torch_sim/integrators/nvt.py b/torch_sim/integrators/nvt.py index 18c6ba16..1716c1dc 100644 --- a/torch_sim/integrators/nvt.py +++ b/torch_sim/integrators/nvt.py @@ -6,7 +6,7 @@ import torch -from torch_sim.integrators.md_batched import ( +from torch_sim.integrators.md import ( MDState, NoseHooverChain, NoseHooverChainFns, From 347a5474e03d9ac5c6ca829fc0ed343d4b4bbe39 Mon Sep 17 00:00:00 2001 From: thomasloux Date: Mon, 6 Oct 2025 12:18:38 +0200 Subject: [PATCH 05/13] fix comments --- torch_sim/integrators/nvt.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/torch_sim/integrators/nvt.py b/torch_sim/integrators/nvt.py index 1716c1dc..81ace8e8 100644 --- a/torch_sim/integrators/nvt.py +++ b/torch_sim/integrators/nvt.py @@ -382,10 +382,6 @@ def nvt_nose_hoover_init( n_atoms_per_system * state.positions.shape[-1] ) # n_atoms * n_dimensions - # For now, sum the per-system DOF as chain expects a single int - # This is a limitation that should be addressed in the chain implementation - # total_dof = int(dof_per_system.sum().item()) - # Initialize state state = NVTNoseHooverState( positions=state.positions, @@ -430,7 +426,7 @@ def nvt_nose_hoover_update( chain = state.chain # Update chain masses based on target temperature - # chain = chain_fns.update_mass(chain, kT) + chain = chain_fns.update_mass(chain, kT) # First half-step of chain evolution momenta, chain = chain_fns.half_step(state.momenta, chain, kT, state.system_idx) From d7d3e78833ec6478868e5e5603265b51852109c1 Mon Sep 17 00:00:00 2001 From: thomasloux Date: Mon, 6 Oct 2025 12:37:24 +0200 Subject: [PATCH 06/13] add test for output shape if single system as input --- tests/models/conftest.py | 82 ++++++++++++++++++++++++++++++++++ tests/models/test_fairchem.py | 6 +++ tests/models/test_graphpes.py | 4 ++ tests/models/test_mace.py | 5 +++ tests/models/test_mattersim.py | 5 +++ tests/models/test_metatomic.py | 5 +++ tests/models/test_orb.py | 13 ++++++ tests/models/test_sevennet.py | 5 +++ 8 files changed, 125 insertions(+) diff --git a/tests/models/conftest.py b/tests/models/conftest.py index 7ef84304..ff67f115 100644 --- a/tests/models/conftest.py +++ b/tests/models/conftest.py @@ -228,3 +228,85 @@ def test_model_output_validation( # Rename the function to include the test name test_model_output_validation.__name__ = f"test_{model_fixture_name}_output_validation" return test_model_output_validation + + +def make_validate_single_system_model_outputs_test( + model_fixture_name: str, +): + """Factory function to create single system model output validation tests. + + Args: + model_fixture_name: Name of the model fixture to validate + """ + + def test_single_system_model_output_validation( + request: pytest.FixtureRequest, + device: torch.device, + dtype: torch.dtype, + ) -> None: + """Test that a model follows the ModelInterface contract for single systems.""" + # Get the model fixture dynamically + model: ModelInterface = request.getfixturevalue(model_fixture_name) + + from ase.build import bulk + + assert model.dtype is not None + assert model.device is not None + assert model.compute_stress is not None + assert model.compute_forces is not None + + try: + if not model.compute_stress: + model.compute_stress = True + stress_computed = True + except NotImplementedError: + stress_computed = False + + try: + if not model.compute_forces: + model.compute_forces = True + force_computed = True + except NotImplementedError: + force_computed = False + + # Use only a single Si system + si_atoms = bulk("Si", "diamond", a=5.43, cubic=True) + sim_state = ts.io.atoms_to_state([si_atoms], device, dtype) + + og_positions = sim_state.positions.clone() + og_cell = sim_state.cell.clone() + og_batch = sim_state.system_idx.clone() + og_atomic_numbers = sim_state.atomic_numbers.clone() + + model_output = model.forward(sim_state) + + # assert model did not mutate the input + assert torch.allclose(og_positions, sim_state.positions) + assert torch.allclose(og_cell, sim_state.cell) + assert torch.allclose(og_batch, sim_state.system_idx) + assert torch.allclose(og_atomic_numbers, sim_state.atomic_numbers) + + # assert model output has the correct keys + assert "energy" in model_output + assert "forces" in model_output if force_computed else True + assert "stress" in model_output if stress_computed else True + + # assert model output shapes are correct for single system + # energy should be shape (1,) for 1 system + assert model_output["energy"].shape == (1,) + # forces should be shape (n_atoms, 3) for n_atoms in the system + if force_computed: + assert model_output["forces"].shape == (sim_state.n_atoms, 3) + # stress should be shape (1, 3, 3) for 1 system + if stress_computed: + assert model_output["stress"].shape == (1, 3, 3) + + # Verify that energy is a scalar for this single system + energy_scalar = model_output["energy"][0] + assert energy_scalar.shape == () # Should be a scalar tensor + + # Rename the function to include the test name + test_single_system_model_output_validation.__name__ = ( + f"test_{model_fixture_name}_single_system_output_validation" + ) + return test_single_system_model_output_validation diff --git a/tests/models/test_fairchem.py b/tests/models/test_fairchem.py index b04ad9c8..2744dada 100644 --- a/tests/models/test_fairchem.py +++ b/tests/models/test_fairchem.py @@ -7,6 +7,7 @@ consistency_test_simstate_fixtures, make_model_calculator_consistency_test, make_validate_model_outputs_test, + make_validate_single_system_model_outputs_test ) @@ -95,3 +96,8 @@ def ocp_calculator(model_path_oc20: str) -> OCPCalculator: os.environ.get("HF_TOKEN") is None, reason="Issues in graph construction of older models", )(make_validate_model_outputs_test(model_fixture_name="eqv2_omat24_model_pbc")) + +test_fairchem_ocp_model_single_output = pytest.mark.skipif( + os.environ.get("HF_TOKEN") is None, + reason="Issues in graph construction of older models", +)(make_validate_single_system_model_outputs_test(model_fixture_name="eqv2_omat24_model_pbc")) diff --git a/tests/models/test_graphpes.py b/tests/models/test_graphpes.py index 1470d3be..e0709ea6 100644 --- a/tests/models/test_graphpes.py +++ b/tests/models/test_graphpes.py @@ -7,6 +7,7 @@ consistency_test_simstate_fixtures, make_model_calculator_consistency_test, make_validate_model_outputs_test, + make_validate_single_system_model_outputs_test, ) from torch_sim.models.graphpes import GraphPESWrapper @@ -176,6 +177,9 @@ def ase_mace_calculator(device: torch.device, dtype: torch.dtype): test_graphpes_mace_model_outputs = make_validate_model_outputs_test( model_fixture_name="ts_mace_model", ) +test_graphpes_mace_model_single_output = make_validate_single_system_model_outputs_test( + model_fixture_name="ts_mace_model", +) _lj_model = LennardJones(sigma=0.5) diff --git a/tests/models/test_mace.py b/tests/models/test_mace.py index 427ef064..5cd40266 100644 --- a/tests/models/test_mace.py +++ b/tests/models/test_mace.py @@ -7,6 +7,7 @@ consistency_test_simstate_fixtures, make_model_calculator_consistency_test, make_validate_model_outputs_test, + make_validate_single_system_model_outputs_test ) from torch_sim.models.mace import MaceUrls @@ -123,6 +124,10 @@ def torchsim_mace_off_model(device: torch.device, dtype: torch.dtype) -> MaceMod model_fixture_name="torchsim_mace_model" ) +test_mace_off_model_single_output = make_validate_single_system_model_outputs_test( + model_fixture_name="torchsim_mace_model" +) + @pytest.mark.parametrize("dtype", [torch.float32, torch.float64]) def test_mace_off_dtype_working( diff --git a/tests/models/test_mattersim.py b/tests/models/test_mattersim.py index a137ed78..8a292ac1 100644 --- a/tests/models/test_mattersim.py +++ b/tests/models/test_mattersim.py @@ -9,6 +9,7 @@ consistency_test_simstate_fixtures, make_model_calculator_consistency_test, make_validate_model_outputs_test, + make_validate_single_system_model_outputs_test, ) @@ -87,3 +88,7 @@ def test_mattersim_initialization( test_mattersim_model_outputs = make_validate_model_outputs_test( model_fixture_name="mattersim_model", ) + +test_mace_off_model_single_output = make_validate_single_system_model_outputs_test( + model_fixture_name="mattersim_model" +) \ No newline at end of file diff --git a/tests/models/test_metatomic.py b/tests/models/test_metatomic.py index f467e4e7..ddeeca2f 100644 --- a/tests/models/test_metatomic.py +++ b/tests/models/test_metatomic.py @@ -5,6 +5,7 @@ consistency_test_simstate_fixtures, make_model_calculator_consistency_test, make_validate_model_outputs_test, + make_validate_single_system_model_outputs_test, ) @@ -64,3 +65,7 @@ def test_metatomic_initialization(device: torch.device) -> None: test_metatomic_model_outputs = make_validate_model_outputs_test( model_fixture_name="metatomic_model", ) + +test_metatomic_model_single_output = make_validate_single_system_model_outputs_test( + model_fixture_name="metatomic_model", +) \ No newline at end of file diff --git a/tests/models/test_orb.py b/tests/models/test_orb.py index 5c75d4bd..2b390ba7 100644 --- a/tests/models/test_orb.py +++ b/tests/models/test_orb.py @@ -5,6 +5,7 @@ consistency_test_simstate_fixtures, make_model_calculator_consistency_test, make_validate_model_outputs_test, + make_validate_single_system_model_outputs_test, ) @@ -80,3 +81,15 @@ def orbv3_direct_20_omat_calculator(device: torch.device) -> ORBCalculator: test_validate_direct_model_outputs = make_validate_model_outputs_test( model_fixture_name="orbv3_direct_20_omat_model", ) + +test_validate_conservative_model_single_output = ( + make_validate_single_system_model_outputs_test( + model_fixture_name="orbv3_conservative_inf_omat_model", + ) +) + +test_validate_direct_model_single_output = ( + make_validate_single_system_model_outputs_test( + model_fixture_name="orbv3_direct_20_omat_model", + ) +) diff --git a/tests/models/test_sevennet.py b/tests/models/test_sevennet.py index 25bd310a..018f1935 100644 --- a/tests/models/test_sevennet.py +++ b/tests/models/test_sevennet.py @@ -5,6 +5,7 @@ consistency_test_simstate_fixtures, make_model_calculator_consistency_test, make_validate_model_outputs_test, + make_validate_single_system_model_outputs_test, ) @@ -96,3 +97,7 @@ def test_sevennet_initialization( test_sevennet_model_outputs = make_validate_model_outputs_test( model_fixture_name="sevenn_model", ) + +test_sevennet_model_single_output = make_validate_single_system_model_outputs_test( + model_fixture_name="sevenn_model", +) From 9585c7cad3dc0349559922484e4f7f2c8bcc29e1 Mon Sep 17 00:00:00 2001 From: thomasloux Date: Mon, 6 Oct 2025 12:38:47 +0200 Subject: [PATCH 07/13] make ruff happy --- tests/models/test_fairchem.py | 8 ++++++-- tests/models/test_mace.py | 2 +- tests/models/test_mattersim.py | 2 +- tests/models/test_metatomic.py | 2 +- tests/models/test_orb.py | 6 ++---- 5 files changed, 11 insertions(+), 9 deletions(-) diff --git a/tests/models/test_fairchem.py b/tests/models/test_fairchem.py index 2744dada..2d5ac757 100644 --- a/tests/models/test_fairchem.py +++ b/tests/models/test_fairchem.py @@ -7,7 +7,7 @@ consistency_test_simstate_fixtures, make_model_calculator_consistency_test, make_validate_model_outputs_test, - make_validate_single_system_model_outputs_test + make_validate_single_system_model_outputs_test, ) @@ -100,4 +100,8 @@ def ocp_calculator(model_path_oc20: str) -> OCPCalculator: test_fairchem_ocp_model_single_output = pytest.mark.skipif( os.environ.get("HF_TOKEN") is None, reason="Issues in graph construction of older models", -)(make_validate_single_system_model_outputs_test(model_fixture_name="eqv2_omat24_model_pbc")) +)( + make_validate_single_system_model_outputs_test( + model_fixture_name="eqv2_omat24_model_pbc" + ) +) diff --git a/tests/models/test_mace.py b/tests/models/test_mace.py index 5cd40266..358cf413 100644 --- a/tests/models/test_mace.py +++ b/tests/models/test_mace.py @@ -7,7 +7,7 @@ consistency_test_simstate_fixtures, make_model_calculator_consistency_test, make_validate_model_outputs_test, - make_validate_single_system_model_outputs_test + make_validate_single_system_model_outputs_test, ) from torch_sim.models.mace import MaceUrls diff --git a/tests/models/test_mattersim.py b/tests/models/test_mattersim.py index 8a292ac1..2a3e4cd9 100644 --- a/tests/models/test_mattersim.py +++ b/tests/models/test_mattersim.py @@ -91,4 +91,4 @@ def test_mattersim_initialization( test_mace_off_model_single_output = make_validate_single_system_model_outputs_test( model_fixture_name="mattersim_model" -) \ No newline at end of file +) diff --git a/tests/models/test_metatomic.py b/tests/models/test_metatomic.py index ddeeca2f..39746bb6 100644 --- a/tests/models/test_metatomic.py +++ b/tests/models/test_metatomic.py @@ -68,4 +68,4 @@ def test_metatomic_initialization(device: torch.device) -> None: test_metatomic_model_single_output = make_validate_single_system_model_outputs_test( model_fixture_name="metatomic_model", -) \ No newline at end of file +) diff --git a/tests/models/test_orb.py b/tests/models/test_orb.py index 2b390ba7..03d9162d 100644 --- a/tests/models/test_orb.py +++ b/tests/models/test_orb.py @@ -88,8 +88,6 @@ def orbv3_direct_20_omat_calculator(device: torch.device) -> ORBCalculator: ) ) -test_validate_direct_model_single_output = ( - make_validate_single_system_model_outputs_test( - model_fixture_name="orbv3_direct_20_omat_model", - ) +test_validate_direct_model_single_output = make_validate_single_system_model_outputs_test( + model_fixture_name="orbv3_direct_20_omat_model", ) From eba765791235fb710d539cd10857b1a11c382fa9 Mon Sep 17 00:00:00 2001 From: thomasloux Date: Tue, 7 Oct 2025 18:10:07 +0200 Subject: [PATCH 08/13] npt nose hoover is batchable --- tests/test_integrators.py | 183 ++++++++++++++++++++++++++ torch_sim/integrators/__init__.py | 2 +- torch_sim/integrators/npt.py | 205 +++++++++++++++--------------- 3 files changed, 283 insertions(+), 107 deletions(-) diff --git a/tests/test_integrators.py b/tests/test_integrators.py index 43a141b4..689ccd30 100644 --- a/tests/test_integrators.py +++ b/tests/test_integrators.py @@ -10,6 +10,8 @@ nvt_langevin, nvt_nose_hoover, nvt_nose_hoover_invariant, + npt_nose_hoover, + npt_nose_hoover_invariant, ) from torch_sim.models.lennard_jones import LennardJonesModel from torch_sim.quantities import calc_kT @@ -478,6 +480,187 @@ def test_nvt_nose_hoover_multi_kt( assert invariant_std / invariant_traj.mean() < 0.1 +def test_npt_nose_hoover(ar_double_sim_state: ts.SimState, lj_model: LennardJonesModel): + dtype = torch.float64 + n_steps = 100 + dt = torch.tensor(0.001, dtype=dtype) + kT = torch.tensor(300, dtype=dtype) * MetalUnits.temperature + external_pressure = torch.tensor(0.0, dtype=dtype) * MetalUnits.pressure + + # Initialize integrator + init_fn, update_fn = npt_nose_hoover( + model=lj_model, + dt=dt, + kT=kT, + external_pressure=external_pressure, + ) + + # Run dynamics for several steps + state = init_fn(state=ar_double_sim_state, seed=42) + energies = [] + temperatures = [] + invariants = [] + for _step in range(n_steps): + state = update_fn(state=state) + + # Calculate instantaneous temperature from kinetic energy + temp = calc_kT( + masses=state.masses, momenta=state.momenta, system_idx=state.system_idx + ) + energies.append(state.energy) + temperatures.append(temp / MetalUnits.temperature) + invariants.append(npt_nose_hoover_invariant(state, kT, external_pressure)) + + # Convert temperatures list to tensor + temperatures_tensor = torch.stack(temperatures) + temperatures_list = [t.tolist() for t in temperatures_tensor.T] + + energies_tensor = torch.stack(energies) + energies_list = [t.tolist() for t in energies_tensor.T] + + invariants_tensor = torch.stack(invariants) + + # Basic sanity checks + assert len(energies_list[0]) == n_steps + assert len(temperatures_list[0]) == n_steps + + # Check temperature is roughly maintained for each trajectory + mean_temps = torch.mean(temperatures_tensor, dim=0) # Mean temp for each trajectory + for mean_temp in mean_temps: + assert ( + abs(mean_temp - kT.item() / MetalUnits.temperature) < 100.0 + ) # Allow for thermal fluctuations + + # Check energy is stable for each trajectory (NPT allows energy fluctuations) + for traj in energies_list: + energy_std = torch.tensor(traj).std() + assert energy_std < 2.0 # Allow more fluctuation than NVT due to volume changes + + # Check invariant conservation (should be roughly constant) + for traj_idx in range(invariants_tensor.shape[1]): + invariant_traj = invariants_tensor[:, traj_idx] + invariant_std = invariant_traj.std() + # Allow for some drift but should be relatively stable + # Less than 15% relative variation (more lenient than NVT) + assert invariant_std / invariant_traj.mean() < 0.15 + + # Check positions and momenta have correct shapes + n_atoms = 8 + + # Verify the two systems remain distinct + pos_diff = torch.norm( + state.positions[:n_atoms].mean(0) - state.positions[n_atoms:].mean(0) + ) + assert pos_diff > 0.0001 # Systems should remain separated + + +def test_npt_nose_hoover_multi_equivalent_to_single( + mixed_double_sim_state: ts.SimState, lj_model: LennardJonesModel +): + """Test that nvt_nose_hoover with multiple identical kT values behaves like + running different single kT, assuming same initial state + (most importantly same momenta).""" + dtype = torch.float64 + n_steps = 100 + dt = torch.tensor(0.001, dtype=dtype) + kT = torch.tensor(300, dtype=dtype) * MetalUnits.temperature + external_pressure = torch.tensor(0.0, dtype=dtype) * MetalUnits.pressure + + # Initialize integrator + init_fn, update_fn = npt_nose_hoover( + model=lj_model, + dt=dt, + kT=kT, + external_pressure=external_pressure, + ) + final_temperatures = [] + initial_momenta = [] + # Run dynamics for several steps + for i in range(mixed_double_sim_state.n_systems): + state = init_fn(state=mixed_double_sim_state[i], seed=42) + initial_momenta.append(state.momenta.clone()) + for _step in range(n_steps): + state = update_fn(state=state) + + # Calculate instantaneous temperature from kinetic energy + temp = calc_kT( + masses=state.masses, momenta=state.momenta, system_idx=state.system_idx + ) + final_temperatures.append(temp / MetalUnits.temperature) + + initial_momenta_tensor = torch.concat(initial_momenta) + final_temperatures = torch.concat(final_temperatures) + state = init_fn(state=mixed_double_sim_state, seed=42, momenta=initial_momenta_tensor) + for _step in range(n_steps): + state = update_fn(state=state) + + # Calculate instantaneous temperature from kinetic energy + temp = calc_kT( + masses=state.masses, momenta=state.momenta, system_idx=state.system_idx + ) + + assert torch.allclose(final_temperatures, temp / MetalUnits.temperature) + + +def test_npt_nose_hoover_multi_kt( + ar_double_sim_state: ts.SimState, lj_model: LennardJonesModel +): + dtype = torch.float64 + n_steps = 200 + dt = torch.tensor(0.001, dtype=dtype) + kT = torch.tensor([300, 10_000], dtype=dtype) * MetalUnits.temperature + external_pressure = torch.tensor(0.0, dtype=dtype) * MetalUnits.pressure + + # Initialize integrator + init_fn, update_fn = npt_nose_hoover( + model=lj_model, + dt=dt, + kT=kT, + external_pressure=external_pressure, + ) + + # Run dynamics for several steps + state = init_fn(state=ar_double_sim_state, seed=42) + energies = [] + temperatures = [] + invariants = [] + for _step in range(n_steps): + state = update_fn(state=state) + + # Calculate instantaneous temperature from kinetic energy + temp = calc_kT( + masses=state.masses, momenta=state.momenta, system_idx=state.system_idx + ) + energies.append(state.energy) + temperatures.append(temp / MetalUnits.temperature) + invariants.append(npt_nose_hoover_invariant(state, kT, external_pressure)) + + # Convert temperatures list to tensor + temperatures_tensor = torch.stack(temperatures) + temperatures_list = [t.tolist() for t in temperatures_tensor.T] + + energies_tensor = torch.stack(energies) + energies_list = [t.tolist() for t in energies_tensor.T] + + invariants_tensor = torch.stack(invariants) + + # Basic sanity checks + assert len(energies_list[0]) == n_steps + assert len(temperatures_list[0]) == n_steps + + # Check temperature is roughly maintained for each trajectory + mean_temps = torch.mean(temperatures_tensor, dim=0) # Mean temp for each trajectory + assert torch.allclose(mean_temps, kT / MetalUnits.temperature, rtol=0.5) + + # Check invariant conservation for each system + for traj_idx in range(invariants_tensor.shape[1]): + invariant_traj = invariants_tensor[:, traj_idx] + invariant_std = invariant_traj.std() + # Allow for some drift but should be relatively stable + # Less than 15% relative variation (more lenient than NVT) + assert invariant_std / invariant_traj.mean() < 0.15 + + def test_nve(ar_double_sim_state: ts.SimState, lj_model: LennardJonesModel): dtype = torch.float64 n_steps = 100 diff --git a/torch_sim/integrators/__init__.py b/torch_sim/integrators/__init__.py index f0e040b1..feae4aa3 100644 --- a/torch_sim/integrators/__init__.py +++ b/torch_sim/integrators/__init__.py @@ -22,6 +22,6 @@ # ruff: noqa: F401 from .md import MDState, calculate_momenta, momentum_step, position_step, velocity_verlet -from .npt import NPTLangevinState, npt_langevin +from .npt import NPTLangevinState, npt_langevin, npt_nose_hoover, npt_nose_hoover_invariant from .nve import nve from .nvt import nvt_langevin, nvt_nose_hoover, nvt_nose_hoover_invariant diff --git a/torch_sim/integrators/npt.py b/torch_sim/integrators/npt.py index f5f23c5c..5fb3dd1a 100644 --- a/torch_sim/integrators/npt.py +++ b/torch_sim/integrators/npt.py @@ -1166,7 +1166,7 @@ def exp_iL2( # noqa: N802 Args: state (NPTNoseHooverState): Current simulation state for batch mapping - alpha (torch.Tensor): Cell scaling parameter + alpha (torch.Tensor): Cell scaling parameter with shape [n_systems] momenta (torch.Tensor): Current particle momenta [n_particles, n_dimensions] forces (torch.Tensor): Forces on particles [n_particles, n_dimensions] cell_velocity (torch.Tensor): Cell velocity with shape [n_systems] @@ -1182,11 +1182,12 @@ def exp_iL2( # noqa: N802 - Part of the NPT integration algorithm - Supports batched operations with proper atom-to-system mapping """ - # Map system-level cell velocities to atom level using system indices + # Map system-level values to atom level using system indices + alpha_atoms = alpha[state.system_idx] # [n_atoms] cell_velocity_atoms = cell_velocity[state.system_idx] # [n_atoms] # Compute scaling terms per atom - x = alpha * cell_velocity_atoms * dt_2 # [n_atoms] + x = alpha_atoms * cell_velocity_atoms * dt_2 # [n_atoms] x_2 = x / 2 # [n_atoms] # Compute sinh(x/2)/(x/2) using stable Taylor series @@ -1334,11 +1335,15 @@ def npt_inner_step( ) # Update cell momentum and particle momenta - cell_momentum = cell_momentum + dt_2 * cell_force_val - momenta = exp_iL2(state, alpha, momenta, forces, cell_momentum / cell_mass, dt_2) + # cell_force_val has shape [n_systems], need to expand to [n_systems, 1] + cell_momentum = cell_momentum + dt_2 * cell_force_val.unsqueeze(-1) + + # For cell velocity calculation, squeeze to [n_systems] + cell_velocity = cell_momentum.squeeze(-1) / cell_mass + momenta = exp_iL2(state, alpha, momenta, forces, cell_velocity, dt_2) # Full step: Update positions - cell_position = cell_position + cell_momentum / cell_mass * dt + cell_position = cell_position + (cell_momentum.squeeze(-1) / cell_mass) * dt # Update state with new cell_position before calling functions that depend on it state.cell_position = cell_position @@ -1348,14 +1353,15 @@ def npt_inner_step( cell = volume_to_cell(volume) # Update particle positions and forces - positions = exp_iL1(state, state.velocities, cell_momentum / cell_mass, dt) + positions = exp_iL1(state, state.velocities, cell_momentum.squeeze(-1) / cell_mass, dt) state.positions = positions state.cell = cell model_output = model(state) # Second half step: Update momenta momenta = exp_iL2( - state, alpha, momenta, model_output["forces"], cell_momentum / cell_mass, dt_2 + state, alpha, momenta, model_output["forces"], + cell_momentum.squeeze(-1) / cell_mass, dt_2 ) cell_force_val = compute_cell_force( alpha=alpha, @@ -1367,7 +1373,7 @@ def npt_inner_step( external_pressure=external_pressure, system_idx=state.system_idx, ) - cell_momentum = cell_momentum + dt_2 * cell_force_val + cell_momentum = cell_momentum + dt_2 * cell_force_val.unsqueeze(-1) # Return updated state state.positions = positions @@ -1449,7 +1455,7 @@ def npt_nose_hoover_init( # Initialize cell variables with proper system dimensions cell_position = torch.zeros(n_systems, device=device, dtype=dtype) - cell_momentum = torch.zeros(n_systems, device=device, dtype=dtype) + cell_momentum = torch.zeros(n_systems, 1, device=device, dtype=dtype) # [n_systems, 1] for compatibility with half_step # Convert kT to tensor if it's not already one if not isinstance(kT, torch.Tensor): @@ -1463,8 +1469,9 @@ def npt_nose_hoover_init( cell_mass = dim * (n_atoms_per_system + 1) * kT_system * b_tau**2 cell_mass = cell_mass.to(device=device, dtype=dtype) - # Calculate cell kinetic energy (using first system for initialization) - KE_cell = calc_kinetic_energy(masses=cell_mass[:1], momenta=cell_momentum[:1]) + # Calculate cell kinetic energy (per system for proper batching) + # For cell variables, each system has 1 DOF, so KE = p^2/(2m) for each system + KE_cell = (cell_momentum.squeeze(-1)**2) / (2 * cell_mass) # [n_systems] # Ensure reference_cell has proper system dimensions if state.cell.ndim == 2: @@ -1487,10 +1494,29 @@ def npt_nose_hoover_init( forces = model_output["forces"] energy = model_output["energy"] - # Create initial state - npt_state = NPTNoseHooverState( + # Initialize momenta first to calculate kinetic energies + momenta = kwargs.get( + "momenta", + calculate_momenta(state.positions, state.masses, state.system_idx, kT, seed), + ) + + # Calculate thermostat degrees of freedom and kinetic energy per system + thermostat_dof = n_atoms_per_system * dim # n_atoms * n_dimensions per system + KE_thermostat = calc_kinetic_energy( + masses=state.masses, momenta=momenta, system_idx=state.system_idx + ) + + # Initialize thermostat (batched per system with DOF = n_atoms * 3) + thermostat = thermostat_fns.initialize(thermostat_dof, KE_thermostat, kT) + + # Initialize barostat (batched per system with DOF = 1) + barostat_dof = torch.ones(n_systems, device=device, dtype=dtype) + barostat = barostat_fns.initialize(barostat_dof, KE_cell, kT) + + # Create and return initial state + return NPTNoseHooverState( positions=state.positions, - momenta=None, + momenta=momenta, energy=energy, forces=forces, masses=state.masses, @@ -1502,33 +1528,12 @@ def npt_nose_hoover_init( cell_position=cell_position, cell_momentum=cell_momentum, cell_mass=cell_mass, - barostat=barostat_fns.initialize(1, KE_cell, kT), - thermostat=None, + barostat=barostat, + thermostat=thermostat, barostat_fns=barostat_fns, thermostat_fns=thermostat_fns, ) - # Initialize momenta - momenta = kwargs.get( - "momenta", - calculate_momenta( - npt_state.positions, npt_state.masses, npt_state.system_idx, kT, seed - ), - ) - - # Initialize thermostat - npt_state.momenta = momenta - KE = calc_kinetic_energy( - momenta=npt_state.momenta, - masses=npt_state.masses, - system_idx=npt_state.system_idx, - ) - npt_state.thermostat = thermostat_fns.initialize( - npt_state.positions.numel(), KE, kT - ) - - return npt_state - def npt_nose_hoover_update( state: NPTNoseHooverState, dt: torch.Tensor = dt, @@ -1563,11 +1568,13 @@ def npt_nose_hoover_update( state = update_cell_mass(state, kT) # First half step of thermostat chains + # For cell momenta, create system index mapping each cell momentum to its system + cell_system_idx = torch.arange(state.n_systems, device=state.device) state.cell_momentum, state.barostat = state.barostat_fns.half_step( - state.cell_momentum, state.barostat, kT + state.cell_momentum, state.barostat, kT, cell_system_idx ) state.momenta, state.thermostat = state.thermostat_fns.half_step( - state.momenta, state.thermostat, kT + state.momenta, state.thermostat, kT, state.system_idx ) # Perform inner NPT step @@ -1583,21 +1590,53 @@ def npt_nose_hoover_update( ) state.thermostat.kinetic_energy = KE - KE_cell = calc_kinetic_energy(masses=state.cell_mass, momenta=state.cell_momentum) + KE_cell = (state.cell_momentum.squeeze(-1)**2) / (2 * state.cell_mass) # [n_systems] state.barostat.kinetic_energy = KE_cell # Second half step of thermostat chains state.momenta, state.thermostat = state.thermostat_fns.half_step( - state.momenta, state.thermostat, kT + state.momenta, state.thermostat, kT, state.system_idx ) + cell_system_idx = torch.arange(state.n_systems, device=state.device) state.cell_momentum, state.barostat = state.barostat_fns.half_step( - state.cell_momentum, state.barostat, kT + state.cell_momentum, state.barostat, kT, cell_system_idx ) return state return npt_nose_hoover_init, npt_nose_hoover_update +def _compute_chain_energy( + chain: NoseHooverChain, kT: torch.Tensor, e_tot: torch.Tensor, dof: torch.Tensor +) -> torch.Tensor: + """Compute energy contribution from a Nose-Hoover chain. + + Args: + chain: The Nose-Hoover chain state + kT: Target temperature + e_tot: Current total energy for broadcasting + dof: Degrees of freedom (only used for first chain element) + + Returns: + Total chain energy contribution + """ + chain_energy = torch.zeros_like(e_tot) + + # First chain element with DOF weighting + ke_0 = chain.momenta[:, 0] ** 2 / (2 * chain.masses[:, 0]) + pe_0 = dof * kT * chain.positions[:, 0] + + chain_energy += ke_0 + pe_0 + + # Remaining chain elements + for i in range(1, chain.positions.shape[1]): + ke = chain.momenta[:, i] ** 2 / (2 * chain.masses[:, i]) + pe = kT * chain.positions[:, i] + chain_energy += ke + pe + + return chain_energy + + def npt_nose_hoover_invariant( state: NPTNoseHooverState, kT: torch.Tensor, @@ -1612,15 +1651,15 @@ def npt_nose_hoover_invariant( The conserved quantity includes: - Potential energy of the system - Kinetic energy of the particles - - Energy contributions from thermostat chains - - Energy contributions from barostat chains + - Energy contributions from thermostat chains (per system) + - Energy contributions from barostat chains (per system) - PV work term - Cell kinetic energy Args: state: Current state of the NPT simulation system. Must contain position, momentum, cell, cell_momentum, cell_mass, thermostat, - and barostat. + and barostat with proper batching for multiple systems. external_pressure: Target external pressure of the system. kT: Target thermal energy (Boltzmann constant x temperature). @@ -1639,72 +1678,26 @@ def npt_nose_hoover_invariant( ) # Calculate degrees of freedom per system - n_atoms_per_system = torch.bincount(state.system_idx) - DOF_per_system = ( - n_atoms_per_system * state.positions.shape[-1] - ) # n_atoms * n_dimensions + n_atoms_per_system = torch.bincount(state.system_idx, minlength=state.n_systems) + dof_per_system = n_atoms_per_system * state.positions.shape[-1] # n_atoms * n_dim # Initialize total energy with PE + KE - if isinstance(e_pot, torch.Tensor) and e_pot.ndim > 0: - e_tot = e_pot + e_kin_per_system # [n_systems] - else: - e_tot = e_pot + e_kin_per_system # [n_systems] - - # Add thermostat chain contributions - # Note: These are global thermostat variables, so we add them to each system - # Start thermostat_energy as a tensor with the right shape - thermostat_energy = torch.zeros_like(e_tot) - thermostat_energy += (state.thermostat.momenta[0] ** 2) / ( - 2 * state.thermostat.masses[0] - ) + e_tot = e_pot + e_kin_per_system - # Ensure kT can broadcast properly with DOF_per_system - if isinstance(kT, torch.Tensor) and kT.ndim == 0: - # Scalar kT - expand to match DOF_per_system shape - kT_expanded = kT.expand_as(DOF_per_system) - else: - kT_expanded = kT - - thermostat_energy += DOF_per_system * kT_expanded * state.thermostat.positions[0] - - # Add remaining thermostat terms - for pos, momentum, mass in zip( - state.thermostat.positions[1:], - state.thermostat.momenta[1:], - state.thermostat.masses[1:], - strict=True, - ): - if isinstance(kT, torch.Tensor) and kT.ndim == 0: - # Scalar kT case - thermostat_energy += (momentum**2) / (2 * mass) + kT * pos - else: - # Batched kT case - thermostat_energy += (momentum**2) / (2 * mass) + kT_expanded * pos - - e_tot = e_tot + thermostat_energy - - # Add barostat chain contributions - barostat_energy = torch.zeros_like(e_tot) - for pos, momentum, mass in zip( - state.barostat.positions, - state.barostat.momenta, - state.barostat.masses, - strict=True, - ): - if isinstance(kT, torch.Tensor) and kT.ndim == 0: - # Scalar kT case - barostat_energy += (momentum**2) / (2 * mass) + kT * pos - else: - # Batched kT case - barostat_energy += (momentum**2) / (2 * mass) + kT_expanded * pos + # Add thermostat chain contributions (batched per system, DOF = n_atoms * 3) + e_tot += _compute_chain_energy(state.thermostat, kT, e_tot, dof_per_system) - e_tot = e_tot + barostat_energy + # Add barostat chain contributions (batched per system, DOF = 1) + barostat_dof = torch.ones_like(dof_per_system) # 1 DOF per system for barostat + e_tot += _compute_chain_energy(state.barostat, kT, e_tot, barostat_dof) # Add PV term and cell kinetic energy (both are per system) e_tot += external_pressure * volume - e_tot += (state.cell_momentum**2) / (2 * state.cell_mass) - # Return scalar if single system, otherwise return per-system values - if state.n_systems == 1: - return e_tot.squeeze() + # Ensure cell_momentum has the right shape [n_systems] + cell_momentum = state.cell_momentum + cell_momentum = cell_momentum.squeeze() + + e_tot += (cell_momentum**2) / (2 * state.cell_mass) + return e_tot From d91f94a0cc8d11a76abd952cf637638a309a4c62 Mon Sep 17 00:00:00 2001 From: thomasloux Date: Tue, 7 Oct 2025 18:24:06 +0200 Subject: [PATCH 09/13] add single system output shape test in test_model_output_validation --- tests/models/conftest.py | 85 +++------------------------------- tests/models/test_fairchem.py | 10 ---- tests/models/test_graphpes.py | 4 -- tests/models/test_mace.py | 5 -- tests/models/test_mattersim.py | 5 -- tests/models/test_metatomic.py | 5 -- tests/models/test_orb.py | 11 ----- tests/models/test_sevennet.py | 5 -- 8 files changed, 6 insertions(+), 124 deletions(-) diff --git a/tests/models/conftest.py b/tests/models/conftest.py index ff67f115..ceda039d 100644 --- a/tests/models/conftest.py +++ b/tests/models/conftest.py @@ -225,88 +225,15 @@ def test_model_output_validation( # atol=10e-3, # ) - # Rename the function to include the test name - test_model_output_validation.__name__ = f"test_{model_fixture_name}_output_validation" - return test_model_output_validation - - -def make_validate_single_system_model_outputs_test( - model_fixture_name: str, -): - """Factory function to create single system model output validation tests. - - Args: - model_fixture_name: Name of the model fixture to validate - """ - - def test_single_system_model_output_validation( - request: pytest.FixtureRequest, - device: torch.device, - dtype: torch.dtype, - ) -> None: - """Test that a model follows the ModelInterface contract for single systems.""" - # Get the model fixture dynamically - model: ModelInterface = request.getfixturevalue(model_fixture_name) - - from ase.build import bulk - - assert model.dtype is not None - assert model.device is not None - assert model.compute_stress is not None - assert model.compute_forces is not None - - try: - if not model.compute_stress: - model.compute_stress = True - stress_computed = True - except NotImplementedError: - stress_computed = False - - try: - if not model.compute_forces: - model.compute_forces = True - force_computed = True - except NotImplementedError: - force_computed = False - - # Use only a single Si system - si_atoms = bulk("Si", "diamond", a=5.43, cubic=True) - sim_state = ts.io.atoms_to_state([si_atoms], device, dtype) - - og_positions = sim_state.positions.clone() - og_cell = sim_state.cell.clone() - og_batch = sim_state.system_idx.clone() - og_atomic_numbers = sim_state.atomic_numbers.clone() - - model_output = model.forward(sim_state) - - # assert model did not mutate the input - assert torch.allclose(og_positions, sim_state.positions) - assert torch.allclose(og_cell, sim_state.cell) - assert torch.allclose(og_batch, sim_state.system_idx) - assert torch.allclose(og_atomic_numbers, sim_state.atomic_numbers) - - # assert model output has the correct keys - assert "energy" in model_output - assert "forces" in model_output if force_computed else True - assert "stress" in model_output if stress_computed else True - - # assert model output shapes are correct for single system - # energy should be shape (1,) for 1 system - assert model_output["energy"].shape == (1,) + # Test single system output + assert fe_model_output["energy"].shape == (1,) # forces should be shape (n_atoms, 3) for n_atoms in the system if force_computed: - assert model_output["forces"].shape == (sim_state.n_atoms, 3) + assert fe_model_output["forces"].shape == (12, 3) # stress should be shape (1, 3, 3) for 1 system if stress_computed: - assert model_output["stress"].shape == (1, 3, 3) - - # Verify that energy is a scalar for this single system - energy_scalar = model_output["energy"][0] - assert energy_scalar.shape == () # Should be a scalar tensor + assert fe_model_output["stress"].shape == (1, 3, 3) # Rename the function to include the test name - test_single_system_model_output_validation.__name__ = ( - f"test_{model_fixture_name}_single_system_output_validation" - ) - return test_single_system_model_output_validation + test_model_output_validation.__name__ = f"test_{model_fixture_name}_output_validation" + return test_model_output_validation diff --git a/tests/models/test_fairchem.py b/tests/models/test_fairchem.py index 2d5ac757..b04ad9c8 100644 --- a/tests/models/test_fairchem.py +++ b/tests/models/test_fairchem.py @@ -7,7 +7,6 @@ consistency_test_simstate_fixtures, make_model_calculator_consistency_test, make_validate_model_outputs_test, - make_validate_single_system_model_outputs_test, ) @@ -96,12 +95,3 @@ def ocp_calculator(model_path_oc20: str) -> OCPCalculator: os.environ.get("HF_TOKEN") is None, reason="Issues in graph construction of older models", )(make_validate_model_outputs_test(model_fixture_name="eqv2_omat24_model_pbc")) - -test_fairchem_ocp_model_single_output = pytest.mark.skipif( - os.environ.get("HF_TOKEN") is None, - reason="Issues in graph construction of older models", -)( - make_validate_single_system_model_outputs_test( - model_fixture_name="eqv2_omat24_model_pbc" - ) -) diff --git a/tests/models/test_graphpes.py b/tests/models/test_graphpes.py index e0709ea6..1470d3be 100644 --- a/tests/models/test_graphpes.py +++ b/tests/models/test_graphpes.py @@ -7,7 +7,6 @@ consistency_test_simstate_fixtures, make_model_calculator_consistency_test, make_validate_model_outputs_test, - make_validate_single_system_model_outputs_test, ) from torch_sim.models.graphpes import GraphPESWrapper @@ -177,9 +176,6 @@ def ase_mace_calculator(device: torch.device, dtype: torch.dtype): test_graphpes_mace_model_outputs = make_validate_model_outputs_test( model_fixture_name="ts_mace_model", ) -test_graphpes_mace_model_single_output = make_validate_single_system_model_outputs_test( - model_fixture_name="ts_mace_model", -) _lj_model = LennardJones(sigma=0.5) diff --git a/tests/models/test_mace.py b/tests/models/test_mace.py index 358cf413..427ef064 100644 --- a/tests/models/test_mace.py +++ b/tests/models/test_mace.py @@ -7,7 +7,6 @@ consistency_test_simstate_fixtures, make_model_calculator_consistency_test, make_validate_model_outputs_test, - make_validate_single_system_model_outputs_test, ) from torch_sim.models.mace import MaceUrls @@ -124,10 +123,6 @@ def torchsim_mace_off_model(device: torch.device, dtype: torch.dtype) -> MaceMod model_fixture_name="torchsim_mace_model" ) -test_mace_off_model_single_output = make_validate_single_system_model_outputs_test( - model_fixture_name="torchsim_mace_model" -) - @pytest.mark.parametrize("dtype", [torch.float32, torch.float64]) def test_mace_off_dtype_working( diff --git a/tests/models/test_mattersim.py b/tests/models/test_mattersim.py index 2a3e4cd9..a137ed78 100644 --- a/tests/models/test_mattersim.py +++ b/tests/models/test_mattersim.py @@ -9,7 +9,6 @@ consistency_test_simstate_fixtures, make_model_calculator_consistency_test, make_validate_model_outputs_test, - make_validate_single_system_model_outputs_test, ) @@ -88,7 +87,3 @@ def test_mattersim_initialization( test_mattersim_model_outputs = make_validate_model_outputs_test( model_fixture_name="mattersim_model", ) - -test_mace_off_model_single_output = make_validate_single_system_model_outputs_test( - model_fixture_name="mattersim_model" -) diff --git a/tests/models/test_metatomic.py b/tests/models/test_metatomic.py index 39746bb6..f467e4e7 100644 --- a/tests/models/test_metatomic.py +++ b/tests/models/test_metatomic.py @@ -5,7 +5,6 @@ consistency_test_simstate_fixtures, make_model_calculator_consistency_test, make_validate_model_outputs_test, - make_validate_single_system_model_outputs_test, ) @@ -65,7 +64,3 @@ def test_metatomic_initialization(device: torch.device) -> None: test_metatomic_model_outputs = make_validate_model_outputs_test( model_fixture_name="metatomic_model", ) - -test_metatomic_model_single_output = make_validate_single_system_model_outputs_test( - model_fixture_name="metatomic_model", -) diff --git a/tests/models/test_orb.py b/tests/models/test_orb.py index 03d9162d..5c75d4bd 100644 --- a/tests/models/test_orb.py +++ b/tests/models/test_orb.py @@ -5,7 +5,6 @@ consistency_test_simstate_fixtures, make_model_calculator_consistency_test, make_validate_model_outputs_test, - make_validate_single_system_model_outputs_test, ) @@ -81,13 +80,3 @@ def orbv3_direct_20_omat_calculator(device: torch.device) -> ORBCalculator: test_validate_direct_model_outputs = make_validate_model_outputs_test( model_fixture_name="orbv3_direct_20_omat_model", ) - -test_validate_conservative_model_single_output = ( - make_validate_single_system_model_outputs_test( - model_fixture_name="orbv3_conservative_inf_omat_model", - ) -) - -test_validate_direct_model_single_output = make_validate_single_system_model_outputs_test( - model_fixture_name="orbv3_direct_20_omat_model", -) diff --git a/tests/models/test_sevennet.py b/tests/models/test_sevennet.py index 018f1935..25bd310a 100644 --- a/tests/models/test_sevennet.py +++ b/tests/models/test_sevennet.py @@ -5,7 +5,6 @@ consistency_test_simstate_fixtures, make_model_calculator_consistency_test, make_validate_model_outputs_test, - make_validate_single_system_model_outputs_test, ) @@ -97,7 +96,3 @@ def test_sevennet_initialization( test_sevennet_model_outputs = make_validate_model_outputs_test( model_fixture_name="sevenn_model", ) - -test_sevennet_model_single_output = make_validate_single_system_model_outputs_test( - model_fixture_name="sevenn_model", -) From 3af5c0456a04fceb761c5b54d38f0dc410e3adba Mon Sep 17 00:00:00 2001 From: thomasloux Date: Wed, 8 Oct 2025 18:07:42 +0200 Subject: [PATCH 10/13] pass lint test --- tests/models/conftest.py | 4 ++-- tests/test_integrators.py | 4 ++-- torch_sim/integrators/__init__.py | 7 ++++++- torch_sim/integrators/npt.py | 22 ++++++++++++++++------ 4 files changed, 26 insertions(+), 11 deletions(-) diff --git a/tests/models/conftest.py b/tests/models/conftest.py index ceda039d..59bc909f 100644 --- a/tests/models/conftest.py +++ b/tests/models/conftest.py @@ -122,7 +122,7 @@ def test_model_calculator_consistency( return test_model_calculator_consistency -def make_validate_model_outputs_test( +def make_validate_model_outputs_test( # noqa: PLR0915 model_fixture_name: str, ): """Factory function to create model output validation tests. @@ -132,7 +132,7 @@ def make_validate_model_outputs_test( model_fixture_name: Name of the model fixture to validate """ - def test_model_output_validation( + def test_model_output_validation( # noqa: PLR0915 request: pytest.FixtureRequest, device: torch.device, dtype: torch.dtype, diff --git a/tests/test_integrators.py b/tests/test_integrators.py index 689ccd30..bdba744d 100644 --- a/tests/test_integrators.py +++ b/tests/test_integrators.py @@ -6,12 +6,12 @@ NPTLangevinState, calculate_momenta, npt_langevin, + npt_nose_hoover, + npt_nose_hoover_invariant, nve, nvt_langevin, nvt_nose_hoover, nvt_nose_hoover_invariant, - npt_nose_hoover, - npt_nose_hoover_invariant, ) from torch_sim.models.lennard_jones import LennardJonesModel from torch_sim.quantities import calc_kT diff --git a/torch_sim/integrators/__init__.py b/torch_sim/integrators/__init__.py index feae4aa3..f348b916 100644 --- a/torch_sim/integrators/__init__.py +++ b/torch_sim/integrators/__init__.py @@ -22,6 +22,11 @@ # ruff: noqa: F401 from .md import MDState, calculate_momenta, momentum_step, position_step, velocity_verlet -from .npt import NPTLangevinState, npt_langevin, npt_nose_hoover, npt_nose_hoover_invariant +from .npt import ( + NPTLangevinState, + npt_langevin, + npt_nose_hoover, + npt_nose_hoover_invariant, +) from .nve import nve from .nvt import nvt_langevin, nvt_nose_hoover, nvt_nose_hoover_invariant diff --git a/torch_sim/integrators/npt.py b/torch_sim/integrators/npt.py index 5fb3dd1a..958d68b5 100644 --- a/torch_sim/integrators/npt.py +++ b/torch_sim/integrators/npt.py @@ -1353,15 +1353,21 @@ def npt_inner_step( cell = volume_to_cell(volume) # Update particle positions and forces - positions = exp_iL1(state, state.velocities, cell_momentum.squeeze(-1) / cell_mass, dt) + positions = exp_iL1( + state, state.velocities, cell_momentum.squeeze(-1) / cell_mass, dt + ) state.positions = positions state.cell = cell model_output = model(state) # Second half step: Update momenta momenta = exp_iL2( - state, alpha, momenta, model_output["forces"], - cell_momentum.squeeze(-1) / cell_mass, dt_2 + state, + alpha, + momenta, + model_output["forces"], + cell_momentum.squeeze(-1) / cell_mass, + dt_2, ) cell_force_val = compute_cell_force( alpha=alpha, @@ -1455,7 +1461,9 @@ def npt_nose_hoover_init( # Initialize cell variables with proper system dimensions cell_position = torch.zeros(n_systems, device=device, dtype=dtype) - cell_momentum = torch.zeros(n_systems, 1, device=device, dtype=dtype) # [n_systems, 1] for compatibility with half_step + cell_momentum = torch.zeros( + n_systems, 1, device=device, dtype=dtype + ) # [n_systems, 1] for compatibility with half_step # Convert kT to tensor if it's not already one if not isinstance(kT, torch.Tensor): @@ -1471,7 +1479,7 @@ def npt_nose_hoover_init( # Calculate cell kinetic energy (per system for proper batching) # For cell variables, each system has 1 DOF, so KE = p^2/(2m) for each system - KE_cell = (cell_momentum.squeeze(-1)**2) / (2 * cell_mass) # [n_systems] + KE_cell = (cell_momentum.squeeze(-1) ** 2) / (2 * cell_mass) # [n_systems] # Ensure reference_cell has proper system dimensions if state.cell.ndim == 2: @@ -1590,7 +1598,9 @@ def npt_nose_hoover_update( ) state.thermostat.kinetic_energy = KE - KE_cell = (state.cell_momentum.squeeze(-1)**2) / (2 * state.cell_mass) # [n_systems] + KE_cell = (state.cell_momentum.squeeze(-1) ** 2) / ( + 2 * state.cell_mass + ) # [n_systems] state.barostat.kinetic_energy = KE_cell # Second half step of thermostat chains From e35160c766ac85c3a38a487bf6c1657da6669df3 Mon Sep 17 00:00:00 2001 From: thomasloux Date: Fri, 10 Oct 2025 13:24:30 +0200 Subject: [PATCH 11/13] check final temperature (algorithm preservation) --- tests/test_integrators.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/tests/test_integrators.py b/tests/test_integrators.py index 3657614a..3e43f6f1 100644 --- a/tests/test_integrators.py +++ b/tests/test_integrators.py @@ -327,6 +327,10 @@ def test_nvt_nose_hoover(ar_double_sim_state: ts.SimState, lj_model: LennardJone # Convert temperatures list to tensor temperatures_tensor = torch.stack(temperatures) temperatures_list = [t.tolist() for t in temperatures_tensor.T] + assert torch.allclose( + temperatures_tensor[-1], + torch.tensor([299.9910, 299.6800], dtype=dtype), + ) energies_tensor = torch.stack(energies) energies_list = [t.tolist() for t in energies_tensor.T] @@ -512,6 +516,10 @@ def test_npt_nose_hoover(ar_double_sim_state: ts.SimState, lj_model: LennardJone # Convert temperatures list to tensor temperatures_tensor = torch.stack(temperatures) temperatures_list = [t.tolist() for t in temperatures_tensor.T] + assert torch.allclose( + temperatures_tensor[-1], + torch.tensor([297.8602, 297.5306], dtype=dtype), + ) energies_tensor = torch.stack(energies) energies_list = [t.tolist() for t in energies_tensor.T] From 520fa5df097f08e1e78dfb441f026ea818738a7e Mon Sep 17 00:00:00 2001 From: thomasloux Date: Thu, 16 Oct 2025 15:09:39 +0200 Subject: [PATCH 12/13] solve nose hoover for chain_length=1 --- torch_sim/integrators/md.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torch_sim/integrators/md.py b/torch_sim/integrators/md.py index c2a9d593..f1bacfb4 100644 --- a/torch_sim/integrators/md.py +++ b/torch_sim/integrators/md.py @@ -401,7 +401,7 @@ def substep_fn( # Update system coupling G = 2.0 * KE - DOF * kT_batched - scale = torch.exp(-delta_8 * p_xi[:, 1] / Q[:, 1]) + scale = torch.exp(-delta_8 * p_xi[:, 1] / Q[:, 1]) if M > 0 else 1.0 p_xi[:, 0] = scale * (scale * p_xi[:, 0] + delta_4 * G) # Rescale system momenta From 2b87b67d268af1a8913d01831323fffae7e95673 Mon Sep 17 00:00:00 2001 From: thomasloux Date: Thu, 16 Oct 2025 15:29:51 +0200 Subject: [PATCH 13/13] 2nd part correction chain_length --- torch_sim/integrators/md.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/torch_sim/integrators/md.py b/torch_sim/integrators/md.py index f1bacfb4..aeb46678 100644 --- a/torch_sim/integrators/md.py +++ b/torch_sim/integrators/md.py @@ -391,8 +391,9 @@ def substep_fn( kT_batched = torch.full_like(KE, kT) # Update chain momenta backwards - G = torch.square(p_xi[:, M - 1]) / Q[:, M - 1] - kT_batched - p_xi[:, M] += delta_4 * G + if M > 0: + G = torch.square(p_xi[:, M - 1]) / Q[:, M - 1] - kT_batched + p_xi[:, M] += delta_4 * G for m in range(M - 1, 0, -1): G = torch.square(p_xi[:, m - 1]) / Q[:, m - 1] - kT_batched