-
Notifications
You must be signed in to change notification settings - Fork 46
Open
Labels
Description
I'm looking at the nvt_nose_hoover method and I see this intriguing comment:
# For now, sum the per-system DOF as chain expects a single int
# This is a limitation that should be addressed in the chain implementation
total_dof = int(dof_per_system.sum().item())
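This seems quite odd if one wants to run an NVT MD simulation with a batched system: summing the per-system DOF hands the chain the total degrees of freedom of the whole batch rather than those of each individual system. In fact, looking at NoseHooverChain, nothing indicates that it handles a batched system at all.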
When I run script 3.5 with a batched system (as defined below), the code actually fails.
"""NVT simulation with MACE and Nose-Hoover thermostat."""
# /// script
# dependencies = [
# "mace-torch>=0.3.12",
# ]
# ///
import os
import torch
from ase.build import bulk
import torch_sim as ts
from torch_sim.integrators.nvt import nvt_nose_hoover, nvt_nose_hoover_invariant
from orb_models.forcefield import pretrained
from torch_sim.models.orb import OrbModel
from torch_sim.quantities import calc_kT
from torch_sim.units import MetalUnits as Units
from torch_sim.state import concatenate_states
# Reproducer: run a short NVT Nose-Hoover trajectory on a batch of two
# identical silicon systems and report the temperature along the way.

# Set device and data type
device = "cuda" if torch.cuda.is_available() else "cpu"
dtype = torch.float32

# Number of steps to run
N_steps = 20

# Create diamond cubic Silicon
si_dc = bulk("Si", "diamond", a=5.43, cubic=True).repeat((2, 2, 2))

orbff = pretrained.orb_v3_conservative_inf_omat(
    device=device,
    precision="float32-highest",
    compile=False,
)
orb_model = OrbModel(model=orbff, dtype=torch.get_default_dtype())

state = ts.io.atoms_to_state(si_dc, device=device, dtype=dtype)
state = concatenate_states([state, state])  # Create a batch of 2 systems

# Run initial inference
results = orb_model(state)

dt = 0.002 * Units.time  # Timestep (ps)
kT = (
    torch.tensor(1000, device=device, dtype=dtype) * Units.temperature
)  # Initial temperature (K)

nvt_init, nvt_update = nvt_nose_hoover(model=orb_model, kT=kT, dt=dt)
state = nvt_init(state=state, kT=kT, seed=1)

for step in range(N_steps):
    # Report every 10th step; the integrator itself advances on every step.
    if step % 10 == 0:
        temp = (
            calc_kT(
                masses=state.masses, momenta=state.momenta, system_idx=state.system_idx
            )
            / Units.temperature
        )
        invariant = nvt_nose_hoover_invariant(state, kT=kT)
        print(f"{step}: Temperature: {temp}: {invariant}")
    state = nvt_update(state=state, kT=kT)

final_temp = (
    calc_kT(masses=state.masses, momenta=state.momenta, system_idx=state.system_idx)
    / Units.temperature
)
# NOTE(review): calc_kT is called with system_idx, so it presumably returns one
# value per system (two here) - confirm. Tensor.item() only accepts a
# single-element tensor, so format every entry instead; for a single system the
# printed line is unchanged.
final_temp_text = ", ".join(f"{value:.4f}" for value in final_temp.flatten().tolist())
print(f"Final temperature: {final_temp_text}")