Skip to content

Commit 86a89af

Browse files
Add per atom energies and stresses for batched LJ (#144)
* Add per atom energies and stresses for batched LJ * update changelog for v0.2.0 (#147) * update changelog for v0.2.0 * minor modification for PR template * formatting fixes * formatting and typos * remove contributors bc they aren't linked * Add tests * simplify results --------- Signed-off-by: Abhijeet Gangan <[email protected]> Co-authored-by: Orion Cohen <[email protected]>
1 parent 6cf3ea8 commit 86a89af

File tree

3 files changed

+63
-2
lines changed

3 files changed

+63
-2
lines changed

examples/scripts/1_Introduction/1.1_Lennard_Jones.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010

1111
import torch
1212

13+
from torch_sim.models.lennard_jones import LennardJonesModel
1314
from torch_sim.unbatched.models.lennard_jones import UnbatchedLennardJonesModel
1415

1516

@@ -69,6 +70,8 @@
6970
dtype=dtype,
7071
compute_forces=True,
7172
compute_stress=True,
73+
per_atom_energies=True,
74+
per_atom_stresses=True,
7275
)
7376

7477
# Print system information
@@ -88,3 +91,37 @@
8891
print(f"Energy: {results['energy']}")
8992
print(f"Forces: {results['forces']}")
9093
print(f"Stress: {results['stress']}")
94+
print(f"Energies: {results['energies']}")
95+
print(f"Stresses: {results['stresses']}")
96+
97+
# Batched model
98+
batched_model = LennardJonesModel(
99+
use_neighbor_list=True,
100+
cutoff=2.5 * 3.405,
101+
sigma=3.405,
102+
epsilon=0.0104,
103+
device=device,
104+
dtype=dtype,
105+
compute_forces=True,
106+
compute_stress=True,
107+
per_atom_energies=True,
108+
per_atom_stresses=True,
109+
)
110+
111+
# Batched state
112+
state = dict(
113+
positions=positions,
114+
cell=cell.unsqueeze(0),
115+
atomic_numbers=atomic_numbers,
116+
pbc=True,
117+
)
118+
119+
# Run the simulation and get results
120+
results = batched_model(state)
121+
122+
# Print the results
123+
print(f"Energy: {results['energy']}")
124+
print(f"Forces: {results['forces']}")
125+
print(f"Stress: {results['stress']}")
126+
print(f"Energies: {results['energies']}")
127+
print(f"Stresses: {results['stresses']}")

tests/models/test_lennard_jones.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -155,6 +155,8 @@ def models(
155155
"dtype": torch.float64,
156156
"compute_forces": True,
157157
"compute_stress": True,
158+
"per_atom_energies": True,
159+
"per_atom_stresses": True,
158160
}
159161

160162
cutoff = 2.5 * 3.405 # Standard LJ cutoff * sigma
@@ -178,6 +180,14 @@ def test_energy_match(
178180
assert torch.allclose(results_nl["energy"], results_direct["energy"], rtol=1e-10)
179181

180182

183+
def test_per_atom_energy_match(
184+
models: tuple[dict[str, torch.Tensor], dict[str, torch.Tensor]],
185+
) -> None:
186+
"""Test that per-atom energy matches between neighbor list and direct calculations."""
187+
results_nl, results_direct = models
188+
assert torch.allclose(results_nl["energies"], results_direct["energies"], rtol=1e-10)
189+
190+
181191
def test_forces_match(
182192
models: tuple[dict[str, torch.Tensor], dict[str, torch.Tensor]],
183193
) -> None:
@@ -194,6 +204,15 @@ def test_stress_match(
194204
assert torch.allclose(results_nl["stress"], results_direct["stress"], rtol=1e-10)
195205

196206

207+
def test_per_atom_stress_match(
208+
models: tuple[dict[str, torch.Tensor], dict[str, torch.Tensor]],
209+
) -> None:
210+
"""Test that per-atom stress tensors match between neighbor list
211+
and direct calculations."""
212+
results_nl, results_direct = models
213+
assert torch.allclose(results_nl["stresses"], results_direct["stresses"], rtol=1e-10)
214+
215+
197216
def test_force_conservation(
198217
models: tuple[dict[str, torch.Tensor], dict[str, torch.Tensor]],
199218
) -> None:

torch_sim/models/lennard_jones.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -293,7 +293,10 @@ def forward(self, state: SimState | StateDict) -> dict[str, torch.Tensor]:
293293
compute_forces=True)
294294
- "stress": Stress tensor with shape [n_batches, 3, 3] (if
295295
compute_stress=True)
296-
- May include additional outputs based on configuration
296+
- "energies": Per-atom energies with shape [n_atoms] (if
297+
per_atom_energies=True)
298+
- "stresses": Per-atom stresses with shape [n_atoms, 3, 3] (if
299+
per_atom_stresses=True)
297300
298301
Raises:
299302
ValueError: If batch cannot be inferred for multi-cell systems.
@@ -307,6 +310,8 @@ def forward(self, state: SimState | StateDict) -> dict[str, torch.Tensor]:
307310
energy = results["energy"] # Shape: [n_batches]
308311
forces = results["forces"] # Shape: [n_atoms, 3]
309312
stress = results["stress"] # Shape: [n_batches, 3, 3]
313+
energies = results["energies"] # Shape: [n_atoms]
314+
stresses = results["stresses"] # Shape: [n_atoms, 3, 3]
310315
"""
311316
if isinstance(state, dict):
312317
state = SimState(**state, masses=torch.ones_like(state["positions"]))
@@ -324,7 +329,7 @@ def forward(self, state: SimState | StateDict) -> dict[str, torch.Tensor]:
324329
for key in ("stress", "energy"):
325330
if key in properties:
326331
results[key] = torch.stack([out[key] for out in outputs])
327-
for key in ("forces",):
332+
for key in ("forces", "energies", "stresses"):
328333
if key in properties:
329334
results[key] = torch.cat([out[key] for out in outputs], dim=0)
330335

0 commit comments

Comments
 (0)