diff --git a/tests/models/conftest.py b/tests/models/conftest.py index 5dc23fab..0e481967 100644 --- a/tests/models/conftest.py +++ b/tests/models/conftest.py @@ -123,7 +123,7 @@ def test_model_calculator_consistency( return test_model_calculator_consistency -def make_validate_model_outputs_test( +def make_validate_model_outputs_test( # noqa: PLR0915 model_fixture_name: str, device: torch.device = DEVICE, dtype: torch.dtype = torch.float64, @@ -135,7 +135,7 @@ def make_validate_model_outputs_test( model_fixture_name: Name of the model fixture to validate """ - def test_model_output_validation(request: pytest.FixtureRequest) -> None: + def test_model_output_validation(request: pytest.FixtureRequest) -> None: # noqa: PLR0915 """Test that a model implementation follows the ModelInterface contract.""" # Get the model fixture dynamically model: ModelInterface = request.getfixturevalue(model_fixture_name) @@ -224,6 +224,15 @@ def test_model_output_validation(request: pytest.FixtureRequest) -> None: # atol=10e-3, # ) + # Test single system output + assert fe_model_output["energy"].shape == (1,) + # forces should be shape (n_atoms, 3) for n_atoms in the system + if force_computed: + assert fe_model_output["forces"].shape == (12, 3) + # stress should be shape (1, 3, 3) for 1 system + if stress_computed: + assert fe_model_output["stress"].shape == (1, 3, 3) + # Rename the function to include the test name test_model_output_validation.__name__ = f"test_{model_fixture_name}_output_validation" return test_model_output_validation diff --git a/tests/test_integrators.py b/tests/test_integrators.py index 31ccc75b..3e43f6f1 100644 --- a/tests/test_integrators.py +++ b/tests/test_integrators.py @@ -295,6 +295,399 @@ def test_nvt_langevin_multi_kt( assert torch.allclose(mean_temps, kT / MetalUnits.temperature, rtol=0.5) +def test_nvt_nose_hoover(ar_double_sim_state: ts.SimState, lj_model: LennardJonesModel): + dtype = torch.float64 + n_steps = 100 + dt = torch.tensor(0.001, dtype=dtype) + kT = torch.tensor(300, dtype=dtype) * MetalUnits.temperature + + # Run dynamics for several steps + state = ts.nvt_nose_hoover_init( + state=ar_double_sim_state, model=lj_model, dt=dt, kT=kT, seed=42 + ) + energies = [] + temperatures = [] + invariants = [] + for _step in range(n_steps): + state = ts.nvt_nose_hoover_step( + state=state, + model=lj_model, + dt=dt, + kT=kT, + ) + + # Calculate instantaneous temperature from kinetic energy + temp = ts.calc_kT( + masses=state.masses, momenta=state.momenta, system_idx=state.system_idx + ) + energies.append(state.energy) + temperatures.append(temp / MetalUnits.temperature) + invariants.append(ts.nvt_nose_hoover_invariant(state, kT)) + + # Convert temperatures list to tensor + temperatures_tensor = torch.stack(temperatures) + temperatures_list = [t.tolist() for t in temperatures_tensor.T] + assert torch.allclose( + temperatures_tensor[-1], + torch.tensor([299.9910, 299.6800], dtype=dtype), + ) + + energies_tensor = torch.stack(energies) + energies_list = [t.tolist() for t in energies_tensor.T] + + invariants_tensor = torch.stack(invariants) + + # Basic sanity checks + assert len(energies_list[0]) == n_steps + assert len(temperatures_list[0]) == n_steps + + # Check temperature is roughly maintained for each trajectory + mean_temps = torch.mean(temperatures_tensor, dim=0) # Mean temp for each trajectory + for mean_temp in mean_temps: + assert ( + abs(mean_temp - kT.item() / MetalUnits.temperature) < 100.0 + ) # Allow for thermal fluctuations + + # Check energy is stable for each trajectory + for traj in energies_list: + energy_std = torch.tensor(traj).std() + assert energy_std < 1.0 # Adjust threshold as needed + + # Check invariant conservation (should be roughly constant) + for traj_idx in range(invariants_tensor.shape[1]): + invariant_traj = invariants_tensor[:, traj_idx] + invariant_std = invariant_traj.std() + # Allow for some drift but should be relatively stable + # Less than 10% relative variation + assert invariant_std / invariant_traj.mean() < 0.1 + + # Check positions and momenta have correct shapes + n_atoms = 8 + + # Verify the two systems remain distinct + pos_diff = torch.norm( + state.positions[:n_atoms].mean(0) - state.positions[n_atoms:].mean(0) + ) + assert pos_diff > 0.0001 # Systems should remain separated + + +def test_nvt_nose_hoover_multi_equivalent_to_single( + mixed_double_sim_state: ts.SimState, lj_model: LennardJonesModel +): + """Test that nvt_nose_hoover with multiple identical kT values behaves like + running different single kT, assuming same initial state + (most importantly same momenta).""" + dtype = torch.float64 + n_steps = 100 + dt = torch.tensor(0.001, dtype=dtype) + kT = torch.tensor(300, dtype=dtype) * MetalUnits.temperature + + final_temperatures = [] + initial_momenta = [] + # Run dynamics for several steps + for i in range(mixed_double_sim_state.n_systems): + state = ts.nvt_nose_hoover_init( + state=mixed_double_sim_state[i], model=lj_model, dt=dt, kT=kT, seed=42 + ) + initial_momenta.append(state.momenta.clone()) + for _step in range(n_steps): + state = ts.nvt_nose_hoover_step( + state=state, + model=lj_model, + dt=dt, + kT=kT, + ) + + # Calculate instantaneous temperature from kinetic energy + temp = ts.calc_kT( + masses=state.masses, momenta=state.momenta, system_idx=state.system_idx + ) + final_temperatures.append(temp / MetalUnits.temperature) + + initial_momenta_tensor = torch.concat(initial_momenta) + final_temperatures = torch.concat(final_temperatures) + state = ts.nvt_nose_hoover_init( + state=mixed_double_sim_state, + model=lj_model, + dt=dt, + kT=kT, + seed=42, + momenta=initial_momenta_tensor, + ) + for _step in range(n_steps): + state = ts.nvt_nose_hoover_step(state=state, model=lj_model, dt=dt, kT=kT) + + # Calculate instantaneous temperature from kinetic energy + temp = ts.calc_kT( + masses=state.masses, momenta=state.momenta, system_idx=state.system_idx + ) + + assert torch.allclose(final_temperatures, temp / MetalUnits.temperature) + + +def test_nvt_nose_hoover_multi_kt( + ar_double_sim_state: ts.SimState, lj_model: LennardJonesModel +): + dtype = torch.float64 + n_steps = 200 + dt = torch.tensor(0.001, dtype=dtype) + kT = torch.tensor([300, 10_000], dtype=dtype) * MetalUnits.temperature + + # Run dynamics for several steps + state = ts.nvt_nose_hoover_init( + state=ar_double_sim_state, model=lj_model, dt=dt, kT=kT, seed=42 + ) + energies = [] + temperatures = [] + invariants = [] + for _step in range(n_steps): + state = ts.nvt_nose_hoover_step(state=state, model=lj_model, dt=dt, kT=kT) + + # Calculate instantaneous temperature from kinetic energy + temp = ts.calc_kT( + masses=state.masses, momenta=state.momenta, system_idx=state.system_idx + ) + energies.append(state.energy) + temperatures.append(temp / MetalUnits.temperature) + invariants.append(ts.nvt_nose_hoover_invariant(state, kT)) + + # Convert temperatures list to tensor + temperatures_tensor = torch.stack(temperatures) + temperatures_list = [t.tolist() for t in temperatures_tensor.T] + + energies_tensor = torch.stack(energies) + energies_list = [t.tolist() for t in energies_tensor.T] + + invariants_tensor = torch.stack(invariants) + + # Basic sanity checks + assert len(energies_list[0]) == n_steps + assert len(temperatures_list[0]) == n_steps + + # Check temperature is roughly maintained for each trajectory + mean_temps = torch.mean(temperatures_tensor, dim=0) # Mean temp for each trajectory + assert torch.allclose(mean_temps, kT / MetalUnits.temperature, rtol=0.5) + + # Check invariant conservation for each system + for traj_idx in range(invariants_tensor.shape[1]): + invariant_traj = invariants_tensor[:, traj_idx] + invariant_std = invariant_traj.std() + # Allow for some drift but should be relatively stable + # Less than 10% relative variation + assert invariant_std / invariant_traj.mean() < 0.1 + + +def test_npt_nose_hoover(ar_double_sim_state: ts.SimState, lj_model: LennardJonesModel): + dtype = torch.float64 + n_steps = 100 + dt = torch.tensor(0.001, dtype=dtype) + kT = torch.tensor(300, dtype=dtype) * MetalUnits.temperature + external_pressure = torch.tensor(0.0, dtype=dtype) * MetalUnits.pressure + + # Run dynamics for several steps + state = ts.npt_nose_hoover_init( + state=ar_double_sim_state, + model=lj_model, + dt=dt, + kT=kT, + external_pressure=external_pressure, + seed=42, + ) + energies = [] + temperatures = [] + invariants = [] + for _step in range(n_steps): + state = ts.npt_nose_hoover_step( + state=state, + model=lj_model, + dt=dt, + kT=kT, + external_pressure=external_pressure, + ) + + # Calculate instantaneous temperature from kinetic energy + temp = ts.calc_kT( + masses=state.masses, momenta=state.momenta, system_idx=state.system_idx + ) + energies.append(state.energy) + temperatures.append(temp / MetalUnits.temperature) + invariants.append(ts.npt_nose_hoover_invariant(state, kT, external_pressure)) + + # Convert temperatures list to tensor + temperatures_tensor = torch.stack(temperatures) + temperatures_list = [t.tolist() for t in temperatures_tensor.T] + assert torch.allclose( + temperatures_tensor[-1], + torch.tensor([297.8602, 297.5306], dtype=dtype), + ) + + energies_tensor = torch.stack(energies) + energies_list = [t.tolist() for t in energies_tensor.T] + + invariants_tensor = torch.stack(invariants) + + # Basic sanity checks + assert len(energies_list[0]) == n_steps + assert len(temperatures_list[0]) == n_steps + + # Check temperature is roughly maintained for each trajectory + mean_temps = torch.mean(temperatures_tensor, dim=0) # Mean temp for each trajectory + for mean_temp in mean_temps: + assert ( + abs(mean_temp - kT.item() / MetalUnits.temperature) < 100.0 + ) # Allow for thermal fluctuations + + # Check energy is stable for each trajectory (NPT allows energy fluctuations) + for traj in energies_list: + energy_std = torch.tensor(traj).std() + assert energy_std < 2.0 # Allow more fluctuation than NVT due to volume changes + + # Check invariant conservation (should be roughly constant) + for traj_idx in range(invariants_tensor.shape[1]): + invariant_traj = invariants_tensor[:, traj_idx] + invariant_std = invariant_traj.std() + # Allow for some drift but should be relatively stable + # Less than 15% relative variation (more lenient than NVT) + assert invariant_std / invariant_traj.mean() < 0.15 + + # Check positions and momenta have correct shapes + n_atoms = 8 + + # Verify the two systems remain distinct + pos_diff = torch.norm( + state.positions[:n_atoms].mean(0) - state.positions[n_atoms:].mean(0) + ) + assert pos_diff > 0.0001 # Systems should remain separated + + +def test_npt_nose_hoover_multi_equivalent_to_single( + mixed_double_sim_state: ts.SimState, lj_model: LennardJonesModel +): + """Test that nvt_nose_hoover with multiple identical kT values behaves like + running different single kT, assuming same initial state + (most importantly same momenta).""" + dtype = torch.float64 + n_steps = 100 + dt = torch.tensor(0.001, dtype=dtype) + kT = torch.tensor(300, dtype=dtype) * MetalUnits.temperature + external_pressure = torch.tensor(0.0, dtype=dtype) * MetalUnits.pressure + + final_temperatures = [] + initial_momenta = [] + # Run dynamics for several steps + for i in range(mixed_double_sim_state.n_systems): + state = ts.npt_nose_hoover_init( + state=mixed_double_sim_state[i], + model=lj_model, + dt=dt, + kT=kT, + external_pressure=external_pressure, + seed=42, + ) + initial_momenta.append(state.momenta.clone()) + for _step in range(n_steps): + state = ts.npt_nose_hoover_step( + state=state, + model=lj_model, + dt=dt, + kT=kT, + external_pressure=external_pressure, + ) + + # Calculate instantaneous temperature from kinetic energy + temp = ts.calc_kT( + masses=state.masses, momenta=state.momenta, system_idx=state.system_idx + ) + final_temperatures.append(temp / MetalUnits.temperature) + + initial_momenta_tensor = torch.concat(initial_momenta) + final_temperatures = torch.concat(final_temperatures) + state = ts.npt_nose_hoover_init( + state=mixed_double_sim_state, + model=lj_model, + dt=dt, + kT=kT, + external_pressure=external_pressure, + seed=42, + momenta=initial_momenta_tensor, + ) + for _step in range(n_steps): + state = ts.npt_nose_hoover_step( + state=state, model=lj_model, dt=dt, kT=kT, external_pressure=external_pressure + ) + + # Calculate instantaneous temperature from kinetic energy + temp = ts.calc_kT( + masses=state.masses, momenta=state.momenta, system_idx=state.system_idx + ) + + assert torch.allclose(final_temperatures, temp / MetalUnits.temperature) + + +def test_npt_nose_hoover_multi_kt( + ar_double_sim_state: ts.SimState, lj_model: LennardJonesModel +): + dtype = torch.float64 + n_steps = 200 + dt = torch.tensor(0.001, dtype=dtype) + kT = torch.tensor([300, 10_000], dtype=dtype) * MetalUnits.temperature + external_pressure = torch.tensor(0.0, dtype=dtype) * MetalUnits.pressure + + # Run dynamics for several steps + state = ts.npt_nose_hoover_init( + state=ar_double_sim_state, + model=lj_model, + dt=dt, + kT=kT, + external_pressure=external_pressure, + seed=42, + ) + energies = [] + temperatures = [] + invariants = [] + for _step in range(n_steps): + state = ts.npt_nose_hoover_step( + state=state, + model=lj_model, + dt=dt, + kT=kT, + external_pressure=external_pressure, + ) + + # Calculate instantaneous temperature from kinetic energy + temp = ts.calc_kT( + masses=state.masses, momenta=state.momenta, system_idx=state.system_idx + ) + energies.append(state.energy) + temperatures.append(temp / MetalUnits.temperature) + invariants.append(ts.npt_nose_hoover_invariant(state, kT, external_pressure)) + + # Convert temperatures list to tensor + temperatures_tensor = torch.stack(temperatures) + temperatures_list = [t.tolist() for t in temperatures_tensor.T] + + energies_tensor = torch.stack(energies) + energies_list = [t.tolist() for t in energies_tensor.T] + + invariants_tensor = torch.stack(invariants) + + # Basic sanity checks + assert len(energies_list[0]) == n_steps + assert len(temperatures_list[0]) == n_steps + + # Check temperature is roughly maintained for each trajectory + mean_temps = torch.mean(temperatures_tensor, dim=0) # Mean temp for each trajectory + assert torch.allclose(mean_temps, kT / MetalUnits.temperature, rtol=0.5) + + # Check invariant conservation for each system + for traj_idx in range(invariants_tensor.shape[1]): + invariant_traj = invariants_tensor[:, traj_idx] + invariant_std = invariant_traj.std() + # Allow for some drift but should be relatively stable + # Less than 15% relative variation (more lenient than NVT) + assert invariant_std / invariant_traj.mean() < 0.15 + + def test_nve(ar_double_sim_state: ts.SimState, lj_model: LennardJonesModel): n_steps = 100 dt = torch.tensor(0.001, dtype=DTYPE) diff --git a/torch_sim/integrators/md.py b/torch_sim/integrators/md.py index b968e18f..589b0412 100644 --- a/torch_sim/integrators/md.py +++ b/torch_sim/integrators/md.py @@ -208,13 +208,13 @@ class NoseHooverChain: in the chain has its own positions, momenta, and masses. Attributes: - positions: Positions of the chain thermostats. Shape: [chain_length] - momenta: Momenta of the chain thermostats. Shape: [chain_length] - masses: Masses of the chain thermostats. Shape: [chain_length] + positions: Positions of the chain thermostats. Shape: [n_systems, chain_length] + momenta: Momenta of the chain thermostats. Shape: [n_systems, chain_length] + masses: Masses of the chain thermostats. Shape: [n_systems, chain_length] tau: Thermostat relaxation time. Longer values give better stability - but worse temperature control. Shape: scalar - kinetic_energy: Current kinetic energy of the coupled system. Shape: scalar - degrees_of_freedom: Number of degrees of freedom in the coupled system + but worse temperature control. Shape: [n_systems] or scalar + kinetic_energy: Current kinetic energy of the coupled system. Shape: [n_systems] + degrees_of_freedom: Number of degrees of freedom per system. Shape: [n_systems] """ positions: torch.Tensor @@ -222,7 +222,8 @@ class NoseHooverChain: masses: torch.Tensor tau: torch.Tensor kinetic_energy: torch.Tensor - degrees_of_freedom: int + degrees_of_freedom: torch.Tensor + system_idx: torch.Tensor | None = None @dataclass @@ -285,7 +286,7 @@ class NoseHooverChainFns: } -def construct_nose_hoover_chain( +def construct_nose_hoover_chain( # noqa: C901 PLR0915 dt: torch.Tensor, chain_length: int, chain_steps: int, @@ -324,14 +325,14 @@ def construct_nose_hoover_chain( """ def init_fn( - degrees_of_freedom: int, KE: torch.Tensor, kT: torch.Tensor + degrees_of_freedom: torch.Tensor, KE: torch.Tensor, kT: torch.Tensor ) -> NoseHooverChain: """Initialize a Nose-Hoover chain state. Args: - degrees_of_freedom: Number of degrees of freedom in coupled system - KE: Initial kinetic energy of the system - kT: Target temperature in energy units + degrees_of_freedom: Number of degrees of freedom per system, shape [n_systems] + KE: Initial kinetic energy per system, shape [n_systems] + kT: Target temperature in energy units, shape [n_systems] or scalar Returns: Initial NoseHooverChain state @@ -339,16 +340,40 @@ def init_fn( device = KE.device dtype = KE.dtype - xi = torch.zeros(chain_length, dtype=dtype, device=device) - p_xi = torch.zeros(chain_length, dtype=dtype, device=device) + # Ensure n_systems is determined from KE shape + n_systems = KE.shape[0] if KE.dim() > 0 else 1 - Q = kT * torch.square(tau) * torch.ones(chain_length, dtype=dtype, device=device) - Q[0] *= degrees_of_freedom + # Initialize chain variables with proper batch dimensions + xi = torch.zeros((n_systems, chain_length), dtype=dtype, device=device) + p_xi = torch.zeros((n_systems, chain_length), dtype=dtype, device=device) - return NoseHooverChain(xi, p_xi, Q, tau, KE, degrees_of_freedom) + # Broadcast tau to match n_systems + if isinstance(tau, torch.Tensor): + tau_batched = tau.expand(n_systems) if tau.dim() == 0 else tau + else: + tau_batched = torch.full((n_systems,), tau, dtype=dtype, device=device) + + # Ensure kT has proper batch dimension + if isinstance(kT, torch.Tensor): + kT_batched = kT.expand(n_systems) if kT.dim() == 0 else kT + else: + kT_batched = torch.full((n_systems,), kT, dtype=dtype, device=device) + + Q = ( + kT_batched.unsqueeze(-1) + * torch.square(tau_batched).unsqueeze(-1) ** 2 + * torch.ones((n_systems, chain_length), dtype=dtype, device=device) + ) + Q[:, 0] *= degrees_of_freedom + + return NoseHooverChain(xi, p_xi, Q, tau_batched, KE, degrees_of_freedom) def substep_fn( - delta: torch.Tensor, P: torch.Tensor, state: NoseHooverChain, kT: torch.Tensor + delta: torch.Tensor, + P: torch.Tensor, + state: NoseHooverChain, + kT: torch.Tensor, + system_idx: torch.Tensor, ) -> tuple[torch.Tensor, NoseHooverChain, torch.Tensor]: """Perform single update of chain parameters and rescale velocities. @@ -357,6 +382,7 @@ def substep_fn( P: System momenta to be rescaled state: Current chain state kT: Target temperature + system_idx: Index of the system being evolved Returns: Tuple of (rescaled momenta, updated chain state, temperature) @@ -376,40 +402,53 @@ def substep_fn( M = chain_length - 1 + # Ensure kT has proper batch dimension + if isinstance(kT, torch.Tensor): + kT_batched = kT.expand(KE.shape[0]) if kT.dim() == 0 else kT + else: + kT_batched = torch.full_like(KE, kT) + # Update chain momenta backwards - G = torch.square(p_xi[M - 1]) / Q[M - 1] - kT - p_xi[M] += delta_4 * G + if M > 0: + G = torch.square(p_xi[:, M - 1]) / Q[:, M - 1] - kT_batched + p_xi[:, M] += delta_4 * G for m in range(M - 1, 0, -1): - G = torch.square(p_xi[m - 1]) / Q[m - 1] - kT - scale = torch.exp(-delta_8 * p_xi[m + 1] / Q[m + 1]) - p_xi[m] = scale * (scale * p_xi[m] + delta_4 * G) + G = torch.square(p_xi[:, m - 1]) / Q[:, m - 1] - kT_batched + scale = torch.exp(-delta_8 * p_xi[:, m + 1] / Q[:, m + 1]) + p_xi[:, m] = scale * (scale * p_xi[:, m] + delta_4 * G) # Update system coupling - G = 2.0 * KE - DOF * kT - scale = torch.exp(-delta_8 * p_xi[1] / Q[1]) - p_xi[0] = scale * (scale * p_xi[0] + delta_4 * G) + G = 2.0 * KE - DOF * kT_batched + scale = torch.exp(-delta_8 * p_xi[:, 1] / Q[:, 1]) if M > 0 else 1.0 + p_xi[:, 0] = scale * (scale * p_xi[:, 0] + delta_4 * G) # Rescale system momenta - scale = torch.exp(-delta_2 * p_xi[0] / Q[0]) + scale = torch.exp(-delta_2 * p_xi[:, 0] / Q[:, 0]) KE = KE * torch.square(scale) - P = P * scale + + # Apply scale to momenta - need to map from system to atom indices + atom_scale = scale[system_idx].unsqueeze(-1) + P = P * atom_scale # Update positions xi = xi + delta_2 * p_xi / Q # Update chain momenta forwards - G = 2.0 * KE - DOF * kT + G = 2.0 * KE - DOF * kT_batched for m in range(M): - scale = torch.exp(-delta_8 * p_xi[m + 1] / Q[m + 1]) - p_xi[m] = scale * (scale * p_xi[m] + delta_4 * G) - G = torch.square(p_xi[m]) / Q[m] - kT - p_xi[M] += delta_4 * G + scale = torch.exp(-delta_8 * p_xi[:, m + 1] / Q[:, m + 1]) + p_xi[:, m] = scale * (scale * p_xi[:, m] + delta_4 * G) + G = torch.square(p_xi[:, m]) / Q[:, m] - kT_batched + p_xi[:, M] += delta_4 * G - return P, NoseHooverChain(xi, p_xi, Q, _tau, KE, DOF), kT + return P, NoseHooverChain(xi, p_xi, Q, _tau, KE, DOF), kT_batched def half_step_chain_fn( - P: torch.Tensor, state: NoseHooverChain, kT: torch.Tensor + P: torch.Tensor, + state: NoseHooverChain, + kT: torch.Tensor, + system_idx: torch.Tensor, ) -> tuple[torch.Tensor, NoseHooverChain]: """Evolve chain for half timestep using multi-timestep integration. @@ -417,12 +456,13 @@ def half_step_chain_fn( P: System momenta to be rescaled state: Current chain state kT: Target temperature + system_idx: Index of the system being evolved Returns: Tuple of (rescaled momenta, updated chain state) """ if chain_steps == 1 and sy_steps == 1: - P, state, _ = substep_fn(dt, P, state, kT) + P, state, _ = substep_fn(dt, P, state, kT, system_idx) return P, state delta = dt / chain_steps @@ -430,7 +470,7 @@ def half_step_chain_fn( for step in range(chain_steps * sy_steps): d = delta * weights[step % sy_steps] - P, state, _ = substep_fn(d, P, state, kT) + P, state, _ = substep_fn(d, P, state, kT, system_idx) return P, state @@ -449,12 +489,21 @@ def update_chain_mass_fn( device = chain_state.positions.device dtype = chain_state.positions.dtype + # Get number of systems + n_systems = chain_state.kinetic_energy.shape[0] + + # Ensure kT has proper batch dimension + if isinstance(kT, torch.Tensor): + kT_batched = kT.expand(n_systems) if kT.dim() == 0 else kT + else: + kT_batched = torch.full((n_systems,), kT, dtype=dtype, device=device) + Q = ( - kT - * torch.square(chain_state.tau) - * torch.ones(chain_length, dtype=dtype, device=device) + kT_batched.unsqueeze(-1) + * torch.square(chain_state.tau).unsqueeze(-1) + * torch.ones((n_systems, chain_length), dtype=dtype, device=device) ) - Q[0] *= chain_state.degrees_of_freedom + Q[:, 0] *= chain_state.degrees_of_freedom return NoseHooverChain( chain_state.positions, diff --git a/torch_sim/integrators/npt.py b/torch_sim/integrators/npt.py index 1ab4e7c3..7fb07a1f 100644 --- a/torch_sim/integrators/npt.py +++ b/torch_sim/integrators/npt.py @@ -1073,7 +1073,7 @@ def _npt_nose_hoover_exp_iL2( # noqa: N802 Args: state (NPTNoseHooverState): Current simulation state for batch mapping - alpha (torch.Tensor): Cell scaling parameter + alpha (torch.Tensor): Cell scaling parameter with shape [n_systems] momenta (torch.Tensor): Current particle momenta [n_particles, n_dimensions] forces (torch.Tensor): Forces on particles [n_particles, n_dimensions] cell_velocity (torch.Tensor): Cell velocity with shape [n_systems] @@ -1242,13 +1242,14 @@ def _npt_nose_hoover_inner_step( ) # Update cell momentum and particle momenta - cell_momentum = cell_momentum + dt_2 * cell_force_val + cell_momentum = cell_momentum + dt_2 * cell_force_val.unsqueeze(-1) + cell_velocities = cell_momentum.squeeze(-1) / cell_mass momenta = _npt_nose_hoover_exp_iL2( - state, alpha, momenta, forces, cell_momentum / cell_mass, dt_2 + state, alpha, momenta, forces, cell_velocities, dt_2 ) # Full step: Update positions - cell_position = cell_position + cell_momentum / cell_mass * dt + cell_position = cell_position + cell_velocities * dt # Update state with new cell_position before calling functions that depend on it state.cell_position = cell_position @@ -1258,16 +1259,14 @@ def _npt_nose_hoover_inner_step( cell = volume_to_cell(volume) # Update particle positions and forces - positions = _npt_nose_hoover_exp_iL1( - state, state.velocities, cell_momentum / cell_mass, dt - ) + positions = _npt_nose_hoover_exp_iL1(state, state.velocities, cell_velocities, dt) state.positions = positions state.cell = cell model_output = model(state) # Second half step: Update momenta momenta = _npt_nose_hoover_exp_iL2( - state, alpha, momenta, model_output["forces"], cell_momentum / cell_mass, dt_2 + state, alpha, momenta, model_output["forces"], cell_velocities, dt_2 ) cell_force_val = _npt_nose_hoover_compute_cell_force( alpha=alpha, @@ -1279,7 +1278,7 @@ def _npt_nose_hoover_inner_step( external_pressure=external_pressure, system_idx=state.system_idx, ) - cell_momentum = cell_momentum + dt_2 * cell_force_val + cell_momentum = cell_momentum + dt_2 * cell_force_val.unsqueeze(-1) # Return updated state state.positions = positions @@ -1375,8 +1374,9 @@ def npt_nose_hoover_init( atomic_numbers = kwargs.get("atomic_numbers", state.atomic_numbers) # Initialize cell variables with proper system dimensions + # cell_momentum: [n_systems, 1] for compatibility with half_step cell_position = torch.zeros(n_systems, device=device, dtype=dtype) - cell_momentum = torch.zeros(n_systems, device=device, dtype=dtype) + cell_momentum = torch.zeros(n_systems, 1, device=device, dtype=dtype) # Convert kT to tensor if it's not already one if not isinstance(kT, torch.Tensor): @@ -1391,12 +1391,20 @@ def npt_nose_hoover_init( cell_mass = cell_mass.to(device=device, dtype=dtype) # Calculate cell kinetic energy (using first system for initialization) - KE_cell = ts.calc_kinetic_energy(masses=cell_mass[:1], momenta=cell_momentum[:1]) + dof_barostat = torch.ones(n_systems, device=device, dtype=dtype) + KE_cell = (cell_momentum.squeeze(-1) ** 2) / (2 * cell_mass) + + # Initialize momenta + momenta = kwargs.get( + "momenta", + calculate_momenta(state.positions, state.masses, state.system_idx, kT, seed), + ) # Compute total DOF for thermostat initialization and a zero KE placeholder dof_per_system = torch.bincount(state.system_idx, minlength=n_systems) * dim - total_dof = int(dof_per_system.sum().item()) - KE_zero = torch.tensor(0.0, device=device, dtype=dtype) + KE_thermostat = ts.calc_kinetic_energy( + masses=state.masses, momenta=momenta, system_idx=state.system_idx + ) # Ensure reference_cell has proper system dimensions if state.cell.ndim == 2: @@ -1420,9 +1428,9 @@ def npt_nose_hoover_init( energy = model_output["energy"] # Create initial state - npt_state = NPTNoseHooverState( + return NPTNoseHooverState( positions=state.positions, - momenta=torch.zeros_like(state.positions), + momenta=momenta, energy=energy, forces=forces, masses=state.masses, @@ -1434,25 +1442,12 @@ def npt_nose_hoover_init( cell_position=cell_position, cell_momentum=cell_momentum, cell_mass=cell_mass, - barostat=barostat_fns.initialize(1, KE_cell, kT), - thermostat=thermostat_fns.initialize(total_dof, KE_zero, kT), + barostat=barostat_fns.initialize(dof_barostat, KE_cell, kT), + thermostat=thermostat_fns.initialize(dof_per_system, KE_thermostat, kT), barostat_fns=barostat_fns, thermostat_fns=thermostat_fns, ) - # Initialize momenta - momenta = kwargs.get( - "momenta", - calculate_momenta( - npt_state.positions, npt_state.masses, npt_state.system_idx, kT, seed - ), - ) - - # Initialize thermostat - npt_state.momenta = momenta - - return npt_state - def npt_nose_hoover_step( state: NPTNoseHooverState, @@ -1493,11 +1488,12 @@ def npt_nose_hoover_step( state = _npt_nose_hoover_update_cell_mass(state, kT, device, dtype) # First half step of thermostat chains + cell_system_idx = torch.arange(state.n_systems, device=device) state.cell_momentum, state.barostat = state.barostat_fns.half_step( - state.cell_momentum, state.barostat, kT + state.cell_momentum, state.barostat, kT, cell_system_idx ) state.momenta, state.thermostat = state.thermostat_fns.half_step( - state.momenta, state.thermostat, kT + state.momenta, state.thermostat, kT, state.system_idx ) # Perform inner NPT step @@ -1509,19 +1505,50 @@ def npt_nose_hoover_step( ) state.thermostat.kinetic_energy = KE - KE_cell = ts.calc_kinetic_energy(masses=state.cell_mass, momenta=state.cell_momentum) + KE_cell = (torch.square(state.cell_momentum.squeeze(-1))) / (2 * state.cell_mass) state.barostat.kinetic_energy = KE_cell # Second half step of thermostat chains state.momenta, state.thermostat = state.thermostat_fns.half_step( - state.momenta, state.thermostat, kT + state.momenta, state.thermostat, kT, state.system_idx ) state.cell_momentum, state.barostat = state.barostat_fns.half_step( - state.cell_momentum, state.barostat, kT + state.cell_momentum, state.barostat, kT, cell_system_idx ) return state +def _compute_chain_energy( + chain: NoseHooverChain, kT: torch.Tensor, e_tot: torch.Tensor, dof: torch.Tensor +) -> torch.Tensor: + """Compute energy contribution from a Nose-Hoover chain. + + Args: + chain: The Nose-Hoover chain state + kT: Target temperature + e_tot: Current total energy for broadcasting + dof: Degrees of freedom (only used for first chain element) + + Returns: + Total chain energy contribution + """ + chain_energy = torch.zeros_like(e_tot) + + # First chain element with DOF weighting + ke_0 = torch.square(chain.momenta[:, 0]) / (2 * chain.masses[:, 0]) + pe_0 = dof * kT * chain.positions[:, 0] + + chain_energy += ke_0 + pe_0 + + # Remaining chain elements + for i in range(1, chain.positions.shape[1]): + ke = torch.square(chain.momenta[:, i]) / (2 * chain.masses[:, i]) + pe = kT * chain.positions[:, i] + chain_energy += ke + pe + + return chain_energy + + def npt_nose_hoover_invariant( state: NPTNoseHooverState, kT: torch.Tensor, @@ -1534,17 +1561,17 @@ def npt_nose_hoover_invariant( NPT simulations. The conserved quantity includes: - - Potential energy of the system + - Potential energy of the systems - Kinetic energy of the particles - - Energy contributions from thermostat chains - - Energy contributions from barostat chains + - Energy contributions from thermostat chains (per system) + - Energy contributions from barostat chains (per system) - PV work term - Cell kinetic energy Args: state: Current state of the NPT simulation system. Must contain position, momentum, cell, cell_momentum, cell_mass, thermostat, - and barostat. + and barostat with proper batching for multiple systems. external_pressure: Target external pressure of the system. kT: Target thermal energy (Boltzmann constant x temperature). @@ -1563,72 +1590,25 @@ def npt_nose_hoover_invariant( ) # Calculate degrees of freedom per system - n_atoms_per_system = torch.bincount(state.system_idx) - DOF_per_system = ( - n_atoms_per_system * state.positions.shape[-1] - ) # n_atoms * n_dimensions + n_atoms_per_system = torch.bincount(state.system_idx, minlength=state.n_systems) + dof_per_system = n_atoms_per_system * state.positions.shape[-1] # n_atoms * n_dim # Initialize total energy with PE + KE - if isinstance(e_pot, torch.Tensor) and e_pot.ndim > 0: - e_tot = e_pot + e_kin_per_system # [n_systems] - else: - e_tot = e_pot + e_kin_per_system # [n_systems] - - # Add thermostat chain contributions - # Note: These are global thermostat variables, so we add them to each system - # Start thermostat_energy as a tensor with the right shape - thermostat_energy = torch.zeros_like(e_tot) - thermostat_energy += torch.square(state.thermostat.momenta[0]) / ( - 2 * state.thermostat.masses[0] - ) - - # Ensure kT can broadcast properly with DOF_per_system - if isinstance(kT, torch.Tensor) and kT.ndim == 0: - # Scalar kT - expand to match DOF_per_system shape - kT_expanded = kT.expand_as(DOF_per_system) - else: - kT_expanded = kT + e_tot = e_pot + e_kin_per_system - thermostat_energy += DOF_per_system * kT_expanded * state.thermostat.positions[0] + # Add thermostat chain contributions (batched per system, DOF = n_atoms * 3) + e_tot += _compute_chain_energy(state.thermostat, kT, e_tot, dof_per_system) - # Add remaining thermostat terms - for pos, momentum, mass in zip( - state.thermostat.positions[1:], - state.thermostat.momenta[1:], - state.thermostat.masses[1:], - strict=True, - ): - if isinstance(kT, torch.Tensor) and kT.ndim == 0: - # Scalar kT case - thermostat_energy += torch.square(momentum) / (2 * mass) + kT * pos - else: - # Batched kT case - thermostat_energy += torch.square(momentum) / (2 * mass) + kT_expanded * pos - - e_tot = e_tot + thermostat_energy - - # Add barostat chain contributions - barostat_energy = torch.zeros_like(e_tot) - for pos, momentum, mass in zip( - state.barostat.positions, - state.barostat.momenta, - state.barostat.masses, - strict=True, - ): - if isinstance(kT, torch.Tensor) and kT.ndim == 0: - # Scalar kT case - barostat_energy += torch.square(momentum) / (2 * mass) + kT * pos - else: - # Batched kT case - barostat_energy += torch.square(momentum) / (2 * mass) + kT_expanded * pos - - e_tot = e_tot + barostat_energy + # Add barostat chain contributions (batched per system, DOF = 1) + barostat_dof = torch.ones_like(dof_per_system) # 1 DOF per system for barostat + e_tot += _compute_chain_energy(state.barostat, kT, e_tot, barostat_dof) # Add PV term and cell kinetic energy (both are per system) e_tot += external_pressure * volume - e_tot += torch.square(state.cell_momentum) / (2 * state.cell_mass) - # Return scalar if single system, otherwise return per-system values - if state.n_systems == 1: - return e_tot.squeeze() + # Ensure cell_momentum has the right shape [n_systems] + cell_momentum = state.cell_momentum.squeeze() + + e_tot += torch.square(cell_momentum) / (2 * state.cell_mass) + return e_tot diff --git a/torch_sim/integrators/nvt.py b/torch_sim/integrators/nvt.py index 4021ce2b..478abfa9 100644 --- a/torch_sim/integrators/nvt.py +++ b/torch_sim/integrators/nvt.py @@ -310,10 +310,6 @@ def nvt_nose_hoover_init( n_atoms_per_system * state.positions.shape[-1] ) # n_atoms * n_dimensions - # For now, sum the per-system DOF as chain expects a single int - # This is a limitation that should be addressed in the chain implementation - total_dof = int(dof_per_system.sum().item()) - # Initialize state return NVTNoseHooverState( positions=state.positions, @@ -325,7 +321,7 @@ def nvt_nose_hoover_init( pbc=state.pbc, atomic_numbers=atomic_numbers, system_idx=state.system_idx, - chain=chain_fns.initialize(total_dof, KE, kT), + chain=chain_fns.initialize(dof_per_system, KE, kT), _chain_fns=chain_fns, # Store the chain functions ) @@ -368,7 +364,7 @@ def nvt_nose_hoover_step( chain = chain_fns.update_mass(chain, kT) # First half-step of chain evolution - momenta, chain = chain_fns.half_step(state.momenta, chain, kT) + momenta, chain = chain_fns.half_step(state.momenta, chain, kT, state.system_idx) state.momenta = momenta # Full velocity Verlet step @@ -381,7 +377,7 @@ def nvt_nose_hoover_step( chain.kinetic_energy = KE # Second half-step of chain evolution - momenta, chain = chain_fns.half_step(state.momenta, chain, kT) + momenta, chain = chain_fns.half_step(state.momenta, chain, kT, state.system_idx) state.momenta = momenta state.chain = chain @@ -433,8 +429,8 @@ def nvt_nose_hoover_invariant( # Add first thermostat term c = state.chain # Ensure chain momenta and masses broadcast correctly with batch dimensions - chain_ke_0 = torch.square(c.momenta[0]) / (2 * c.masses[0]) - chain_pe_0 = dof * kT * c.positions[0] + chain_ke_0 = torch.square(c.momenta[:, 0]) / (2 * c.masses[:, 0]) + chain_pe_0 = dof * kT * c.positions[:, 0] # If chain variables are scalars but we have batches, broadcast them if chain_ke_0.numel() == 1 and e_tot.numel() > 1: @@ -445,9 +441,11 @@ def nvt_nose_hoover_invariant( e_tot = e_tot + chain_ke_0 + chain_pe_0 # Add remaining chain terms - for pos, momentum, mass in zip( - c.positions[1:], c.momenta[1:], c.masses[1:], strict=True - ): + for i in range(1, c.positions.shape[1]): + pos = c.positions[:, i] + momentum = c.momenta[:, i] + mass = c.masses[:, i] + chain_ke = momentum**2 / (2 * mass) chain_pe = kT * pos