Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 37 additions & 0 deletions examples/scripts/1_Introduction/1.1_Lennard_Jones.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

import torch

from torch_sim.models.lennard_jones import LennardJonesModel
from torch_sim.unbatched.models.lennard_jones import UnbatchedLennardJonesModel


Expand Down Expand Up @@ -69,6 +70,8 @@
dtype=dtype,
compute_forces=True,
compute_stress=True,
per_atom_energies=True,
per_atom_stresses=True,
)

# Print system information
Expand All @@ -88,3 +91,37 @@
print(f"Energy: {results['energy']}")
print(f"Forces: {results['forces']}")
print(f"Stress: {results['stress']}")
print(f"Energies: {results['energies']}")
print(f"Stresses: {results['stresses']}")

# Batched model
batched_model = LennardJonesModel(
use_neighbor_list=True,
cutoff=2.5 * 3.405,
sigma=3.405,
epsilon=0.0104,
device=device,
dtype=dtype,
compute_forces=True,
compute_stress=True,
per_atom_energies=True,
per_atom_stresses=True,
)

# Batched state
state = dict(
positions=positions,
cell=cell.unsqueeze(0),
atomic_numbers=atomic_numbers,
pbc=True,
)

# Run the simulation and get results
results = batched_model(state)

# Print the results
print(f"Energy: {results['energy']}")
print(f"Forces: {results['forces']}")
print(f"Stress: {results['stress']}")
print(f"Energies: {results['energies']}")
print(f"Stresses: {results['stresses']}")
19 changes: 19 additions & 0 deletions tests/models/test_lennard_jones.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,8 @@ def models(
"dtype": torch.float64,
"compute_forces": True,
"compute_stress": True,
"per_atom_energies": True,
"per_atom_stresses": True,
}

cutoff = 2.5 * 3.405 # Standard LJ cutoff * sigma
Expand All @@ -178,6 +180,14 @@ def test_energy_match(
assert torch.allclose(results_nl["energy"], results_direct["energy"], rtol=1e-10)


def test_per_atom_energy_match(
models: tuple[dict[str, torch.Tensor], dict[str, torch.Tensor]],
) -> None:
"""Test that per-atom energy matches between neighbor list and direct calculations."""
results_nl, results_direct = models
assert torch.allclose(results_nl["energies"], results_direct["energies"], rtol=1e-10)


def test_forces_match(
models: tuple[dict[str, torch.Tensor], dict[str, torch.Tensor]],
) -> None:
Expand All @@ -194,6 +204,15 @@ def test_stress_match(
assert torch.allclose(results_nl["stress"], results_direct["stress"], rtol=1e-10)


def test_per_atom_stress_match(
models: tuple[dict[str, torch.Tensor], dict[str, torch.Tensor]],
) -> None:
"""Test that per-atom stress tensors match between neighbor list
and direct calculations."""
results_nl, results_direct = models
assert torch.allclose(results_nl["stresses"], results_direct["stresses"], rtol=1e-10)


def test_force_conservation(
models: tuple[dict[str, torch.Tensor], dict[str, torch.Tensor]],
) -> None:
Expand Down
9 changes: 7 additions & 2 deletions torch_sim/models/lennard_jones.py
Original file line number Diff line number Diff line change
Expand Up @@ -293,7 +293,10 @@
compute_forces=True)
- "stress": Stress tensor with shape [n_batches, 3, 3] (if
compute_stress=True)
- May include additional outputs based on configuration
- "energies": Per-atom energies with shape [n_atoms] (if
per_atom_energies=True)
- "stresses": Per-atom stresses with shape [n_atoms, 3, 3] (if
per_atom_stresses=True)

Raises:
ValueError: If batch cannot be inferred for multi-cell systems.
Expand All @@ -307,6 +310,8 @@
energy = results["energy"] # Shape: [n_batches]
forces = results["forces"] # Shape: [n_atoms, 3]
stress = results["stress"] # Shape: [n_batches, 3, 3]
energies = results["energies"] # Shape: [n_atoms]
stresses = results["stresses"] # Shape: [n_atoms, 3, 3]
"""
if isinstance(state, dict):
state = SimState(**state, masses=torch.ones_like(state["positions"]))
Expand All @@ -324,7 +329,7 @@
for key in ("stress", "energy"):
if key in properties:
results[key] = torch.stack([out[key] for out in outputs])
for key in ("forces",):
for key in ("forces", "energies", "stresses"):

Check warning on line 332 in torch_sim/models/lennard_jones.py

View check run for this annotation

Codecov / codecov/patch

torch_sim/models/lennard_jones.py#L332

Added line #L332 was not covered by tests
if key in properties:
results[key] = torch.cat([out[key] for out in outputs], dim=0)

Expand Down