Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
19 commits
Select commit Hold shift + click to select a range
2bb3bd5
Fix memory scaling in `determine_max_batch_size` (#212)
t-reents Jul 4, 2025
317985c
Rename batch to system (#217)
curtischong Jul 16, 2025
6c79893
update metatomic checkpoint to fix tests (#223)
curtischong Jul 25, 2025
c0e3137
Update citation.cff (#225)
CompRhys Jul 25, 2025
5371cb6
add new states when the max_memory_scaler is updated (#222)
kianpu34593 Jul 25, 2025
16bf8f8
fix broken code block in low level tutorial (#226)
CompRhys Jul 25, 2025
926e043
Rename more batch to system (#233)
curtischong Aug 6, 2025
88abcff
Make system_idx non-optional in `SimState` [1/2] (#231)
curtischong Aug 7, 2025
f6cd006
Improve Typing of ModelInterface (#215)
curtischong Aug 8, 2025
e90b272
Initial fix for concatenation of states in `InFlightAutoBatcher` (#219)
t-reents Aug 8, 2025
71e1d41
Fix simstate concatenation [2/2] (#232)
curtischong Aug 8, 2025
fd865cc
Update readme plot (#236)
orionarcher Aug 11, 2025
a563523
Fix typo in README (#237)
orionarcher Aug 11, 2025
cde91e9
Define attribute scopes in SimStates (#228)
curtischong Aug 13, 2025
af3a269
Merge remote-tracking branch 'origin/main' into fairchem-v2
CompRhys Aug 13, 2025
cd0b0a0
MAINT: update pins in MACE phonons example. Remove misleading ty from…
CompRhys Aug 14, 2025
1552978
Merge remote-tracking branch 'origin/main' into fairchem-v2-patch
CompRhys Aug 14, 2025
de295df
Merge remote-tracking branch 'origin/main' into fairchem-v2-patch
CompRhys Aug 14, 2025
9c4b5ca
fix: replace batch with system_idx
CompRhys Aug 14, 2025
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/PULL_REQUEST_TEMPLATE.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
Before a pull request can be merged, the following items must be checked:

* [ ] Doc strings have been added in the [Google docstring format](https://sphinxcontrib-napoleon.readthedocs.io/en/latest/example_google.html#example-google).
Run [ruff](https://beta.ruff.rs/docs/rules/#pydocstyle-d) on your code.
* [ ] Run [ruff](https://beta.ruff.rs/docs/rules/#pydocstyle-d) on your code.
* [ ] Tests have been added for any new functionality or bug fixes.

We highly recommended installing the pre-commit hooks running in CI locally to speedup the development process. Simply run `pip install pre-commit && pre-commit install` to install the hooks which will check your code before each commit.
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -34,3 +34,6 @@ coverage.xml

# env
uv.lock

# IDE
.vscode/
9 changes: 5 additions & 4 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -92,14 +92,15 @@ print(relaxed_state.energy)

TorchSim achieves up to 100x speedup compared to ASE with popular MLIPs.

![Speedup comparison](/docs/_static/speedup_plot.svg)
<img src="/docs/_static/speedup_plot.svg" alt="Speedup comparison" width="100%">

This figure compares the time per atom of ASE and `torch_sim`. Time per atom is defined
as the number of atoms / total time. While ASE can only run a single system of `n_atoms`
(on the $x$ axis), `torch_sim` can run as many systems as will fit in memory. On an H100 80 GB card,
the max atoms that could fit in memory was ~8,000 for [GemNet](https://github.com/FAIR-Chem/fairchem), ~10,000 for [MACE](https://github.com/ACEsuit/mace), and ~2,500
for [SevenNet](https://github.com/MDIL-SNU/SevenNet). This metric describes model performance by capturing speed and memory
usage simultaneously.
the max atoms that could fit in memory was ~8,000 for [EGIP](https://github.com/FAIR-Chem/fairchem),
~10,000 for [MACE-MPA-0](https://github.com/ACEsuit/mace), ~22,000 for [Mattersim V1 1M](https://github.com/microsoft/mattersim),
~2,500 for [SevenNet](https://github.com/MDIL-SNU/SevenNet), and ~9000 for [PET-MAD](https://github.com/lab-cosmo/pet-mad).
This metric describes model performance by capturing speed and memory usage simultaneously.

## Installation

Expand Down
6 changes: 2 additions & 4 deletions citation.cff
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
cff-version: 1.2.0
title: Torch-Sim
title: TorchSim
message: If you use this software, please cite it as below.
authors:
- family-names: Gangan
Expand All @@ -17,8 +17,6 @@ authors:
license: MIT
license-url: https://github.com/Radical-AI/torch-sim/blob/main/LICENSE
repository-code: https://github.com/Radical-AI/torch-sim
type: software
url: https://github.com/Radical-AI/torch-sim
doi: 10.5281/zenodo.7486816
version: 0.1.0
type: software
date-released: 2025-04-02
3 changes: 2 additions & 1 deletion docs/_static/draw_pkg_treemap.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,8 @@

# /// script
# dependencies = [
# "pymatviz @ git+https://github.com/janosh/pymatviz",
# "pymatviz>=0.17.1",
# "plotly>=6.3.0",
# ]
# ///

Expand Down
181 changes: 180 additions & 1 deletion docs/_static/speedup_plot.svg
100755 → 100644
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
16 changes: 8 additions & 8 deletions examples/scripts/1_Introduction/1.2_MACE.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,38 +63,38 @@
cell = torch.tensor(cell_numpy, device=device, dtype=dtype)
atomic_numbers = torch.tensor(atomic_numbers_numpy, device=device, dtype=torch.int)

# create batch index array of shape (16,) which is 0 for first 8 atoms, 1 for last 8 atoms
atoms_per_batch = torch.tensor(
# create system idx array of shape (16,) which is 0 for first 8 atoms, 1 for last 8 atoms
atoms_per_system = torch.tensor(
[len(atoms) for atoms in atoms_list], device=device, dtype=torch.int
)
batch = torch.repeat_interleave(
torch.arange(len(atoms_per_batch), device=device), atoms_per_batch
system_idx = torch.repeat_interleave(
torch.arange(len(atoms_per_system), device=device), atoms_per_system
)

# You can see their shapes are as expected
print(f"Positions: {positions.shape}")
print(f"Cell: {cell.shape}")
print(f"Atomic numbers: {atomic_numbers.shape}")
print(f"Batch: {batch.shape}")
print(f"System indices: {system_idx.shape}")

# Now we can pass them to the model
results = batched_model(
dict(
positions=positions,
cell=cell,
atomic_numbers=atomic_numbers,
batch=batch,
system_idx=system_idx,
pbc=True,
)
)

# The energy has shape (n_batches,) as the structures in a batch
# The energy has shape (n_systems,) as the structures in a batch
print(f"Energy: {results['energy'].shape}")

# The forces have shape (n_atoms, 3) same as positions
print(f"Forces: {results['forces'].shape}")

# The stress has shape (n_batches, 3, 3) same as cell
# The stress has shape (n_systems, 3, 3) same as cell
print(f"Stress: {results['stress'].shape}")

# Check if the energy, forces, and stress are the same for the Si system across the batch
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -89,8 +89,8 @@
positions=positions,
masses=masses,
cell=cell.unsqueeze(0),
pbc=True,
atomic_numbers=atomic_numbers,
pbc=True,
)

# Run initial simulation and get results
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -80,8 +80,8 @@
positions=positions,
masses=masses,
cell=cell.unsqueeze(0),
pbc=True,
atomic_numbers=atomic_numbers,
pbc=True,
)

# Initialize the Soft Sphere model
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -93,20 +93,20 @@
masses_numpy = np.concatenate([atoms.get_masses() for atoms in atoms_list])
masses = torch.tensor(masses_numpy, device=device, dtype=dtype)

# Create batch indices tensor for scatter operations
atoms_per_batch = torch.tensor(
# Create system indices tensor for scatter operations
atoms_per_system = torch.tensor(
[len(atoms) for atoms in atoms_list], device=device, dtype=torch.int
)
batch_indices = torch.repeat_interleave(
torch.arange(len(atoms_per_batch), device=device), atoms_per_batch
system_indices = torch.repeat_interleave(
torch.arange(len(atoms_per_system), device=device), atoms_per_system
)
"""

state = ts.io.atoms_to_state(atoms_list, device=device, dtype=dtype)

print(f"Positions shape: {state.positions.shape}")
print(f"Cell shape: {state.cell.shape}")
print(f"Batch indices shape: {state.batch.shape}")
print(f"System indices shape: {state.system_idx.shape}")

# Run initial inference
results = batched_model(state)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@
# Initialize unit cell gradient descent optimizer
gd_init, gd_update = unit_cell_gradient_descent(
model=model,
cell_factor=None, # Will default to atoms per batch
cell_factor=None, # Will default to atoms per system
hydrostatic_strain=False,
constant_volume=False,
scalar_pressure=0.0,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@
# Initialize unit cell gradient descent optimizer
fire_init, fire_update = unit_cell_fire(
model=model,
cell_factor=None, # Will default to atoms per batch
cell_factor=None, # Will default to atoms per system
hydrostatic_strain=False,
constant_volume=False,
scalar_pressure=0.0,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@
# Initialize unit cell gradient descent optimizer
fire_init, fire_update = ts.optimizers.frechet_cell_fire(
model=model,
cell_factor=None, # Will default to atoms per batch
cell_factor=None, # Will default to atoms per system
hydrostatic_strain=False,
constant_volume=False,
scalar_pressure=0.0,
Expand Down
5 changes: 4 additions & 1 deletion examples/scripts/3_Dynamics/3.10_Hybrid_swap_mc.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,9 @@ class HybridSwapMCState(MDState):
"""

last_permutation: torch.Tensor
_atom_attributes = (
MDState._atom_attributes | {"last_permutation"} # noqa: SLF001
)


nvt_init, nvt_step = nvt_langevin(model=model, dt=0.002, kT=kT)
Expand All @@ -86,7 +89,7 @@ class HybridSwapMCState(MDState):
hybrid_state = HybridSwapMCState(
**vars(md_state),
last_permutation=torch.zeros(
md_state.n_batches, device=md_state.device, dtype=torch.bool
md_state.n_systems, device=md_state.device, dtype=torch.bool
),
)

Expand Down
16 changes: 9 additions & 7 deletions examples/scripts/3_Dynamics/3.11_Lennard_Jones_NPT_Langevin.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,8 +97,8 @@
positions=positions,
masses=masses,
cell=cell.unsqueeze(0),
pbc=True,
atomic_numbers=atomic_numbers,
pbc=True,
)
# Run initial simulation and get results
results = model(state)
Expand All @@ -119,13 +119,15 @@
for step in range(N_steps):
if step % 50 == 0:
temp = (
calc_kT(masses=state.masses, momenta=state.momenta, batch=state.batch)
calc_kT(
masses=state.masses, momenta=state.momenta, system_idx=state.system_idx
)
/ Units.temperature
)
pressure = get_pressure(
model(state)["stress"],
calc_kinetic_energy(
masses=state.masses, momenta=state.momenta, batch=state.batch
masses=state.masses, momenta=state.momenta, system_idx=state.system_idx
),
torch.linalg.det(state.cell),
)
Expand All @@ -139,18 +141,18 @@
state = npt_update(state, kT=kT, external_pressure=target_pressure)

temp = (
calc_kT(masses=state.masses, momenta=state.momenta, batch=state.batch)
calc_kT(masses=state.masses, momenta=state.momenta, system_idx=state.system_idx)
/ Units.temperature
)
print(f"Final temperature: {temp.item():.4f}")


stress = model(state)["stress"]
calc_kinetic_energy = calc_kinetic_energy(
masses=state.masses, momenta=state.momenta, batch=state.batch
kinetic_energy = calc_kinetic_energy(
masses=state.masses, momenta=state.momenta, system_idx=state.system_idx
)
volume = torch.linalg.det(state.cell)
pressure = get_pressure(stress, calc_kinetic_energy, volume)
pressure = get_pressure(stress, kinetic_energy, volume)
pressure = pressure.item() / Units.pressure
print(f"Final {pressure=:.4f}")
print(stress * UnitConversion.eV_per_Ang3_to_GPa)
16 changes: 11 additions & 5 deletions examples/scripts/3_Dynamics/3.12_MACE_NPT_Langevin.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,9 @@
for step in range(N_steps_nvt):
if step % 10 == 0:
temp = (
calc_kT(masses=state.masses, momenta=state.momenta, batch=state.batch)
calc_kT(
masses=state.masses, momenta=state.momenta, system_idx=state.system_idx
)
/ Units.temperature
)
invariant = float(nvt_nose_hoover_invariant(state, kT=kT))
Expand All @@ -83,7 +85,9 @@
for step in range(N_steps_npt):
if step % 10 == 0:
temp = (
calc_kT(masses=state.masses, momenta=state.momenta, batch=state.batch)
calc_kT(
masses=state.masses, momenta=state.momenta, system_idx=state.system_idx
)
/ Units.temperature
)
stress = model(state)["stress"]
Expand All @@ -92,7 +96,9 @@
get_pressure(
stress,
calc_kinetic_energy(
masses=state.masses, momenta=state.momenta, batch=state.batch
masses=state.masses,
momenta=state.momenta,
system_idx=state.system_idx,
),
volume,
).item()
Expand All @@ -107,7 +113,7 @@
state = npt_update(state, kT=kT, external_pressure=target_pressure)

final_temp = (
calc_kT(masses=state.masses, momenta=state.momenta, batch=state.batch)
calc_kT(masses=state.masses, momenta=state.momenta, system_idx=state.system_idx)
/ Units.temperature
)
print(f"Final temperature: {final_temp.item():.4f} K")
Expand All @@ -117,7 +123,7 @@
get_pressure(
final_stress,
calc_kinetic_energy(
masses=state.masses, momenta=state.momenta, batch=state.batch
masses=state.masses, momenta=state.momenta, system_idx=state.system_idx
),
final_volume,
).item()
Expand Down
2 changes: 1 addition & 1 deletion examples/scripts/3_Dynamics/3.13_MACE_NVE_non_pbc.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@
start_time = time.perf_counter()
for step in range(N_steps):
total_energy = state.energy + calc_kinetic_energy(
masses=state.masses, momenta=state.momenta, batch=state.batch
masses=state.masses, momenta=state.momenta, system_idx=state.system_idx
)
if step % 10 == 0:
print(f"Step {step}: Total energy: {total_energy.item():.4f} eV")
Expand Down
6 changes: 1 addition & 5 deletions examples/scripts/3_Dynamics/3.1_Lennard_Jones_NVE.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,11 +78,7 @@
masses = torch.full((positions.shape[0],), 39.948, device=device, dtype=dtype)

state = ts.SimState(
positions=positions,
masses=masses,
cell=cell,
pbc=True,
atomic_numbers=atomic_numbers,
positions=positions, masses=masses, cell=cell, atomic_numbers=atomic_numbers, pbc=True
)
# Initialize the Lennard-Jones model
# Parameters:
Expand Down
8 changes: 2 additions & 6 deletions examples/scripts/3_Dynamics/3.2_MACE_NVE.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,11 +60,7 @@
)

state = ts.SimState(
positions=positions,
masses=masses,
cell=cell,
pbc=True,
atomic_numbers=atomic_numbers,
positions=positions, masses=masses, cell=cell, atomic_numbers=atomic_numbers, pbc=True
)

# Run initial inference
Expand All @@ -88,7 +84,7 @@
start_time = time.perf_counter()
for step in range(N_steps):
total_energy = state.energy + calc_kinetic_energy(
masses=state.masses, momenta=state.momenta, batch=state.batch
masses=state.masses, momenta=state.momenta, system_idx=state.system_idx
)
if step % 10 == 0:
print(f"Step {step}: Total energy: {total_energy.item():.4f} eV")
Expand Down
2 changes: 1 addition & 1 deletion examples/scripts/3_Dynamics/3.3_MACE_NVE_cueq.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@
start_time = time.perf_counter()
for step in range(N_steps):
total_energy = state.energy + calc_kinetic_energy(
masses=state.masses, momenta=state.momenta, batch=state.batch
masses=state.masses, momenta=state.momenta, system_idx=state.system_idx
)
if step % 10 == 0:
print(f"Step {step}: Total energy: {total_energy.item():.4f} eV")
Expand Down
12 changes: 5 additions & 7 deletions examples/scripts/3_Dynamics/3.4_MACE_NVT_Langevin.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,11 +59,7 @@
)

state = ts.SimState(
positions=positions,
masses=masses,
cell=cell,
pbc=True,
atomic_numbers=atomic_numbers,
positions=positions, masses=masses, cell=cell, atomic_numbers=atomic_numbers, pbc=True
)

dt = 0.002 * Units.time # Timestep (ps)
Expand All @@ -83,14 +79,16 @@
for step in range(N_steps):
if step % 10 == 0:
temp = (
calc_kT(masses=state.masses, momenta=state.momenta, batch=state.batch)
calc_kT(
masses=state.masses, momenta=state.momenta, system_idx=state.system_idx
)
/ Units.temperature
)
print(f"{step=}: Temperature: {temp.item():.4f}")
state = langevin_update(state=state, kT=kT)

final_temp = (
calc_kT(masses=state.masses, momenta=state.momenta, batch=state.batch)
calc_kT(masses=state.masses, momenta=state.momenta, system_idx=state.system_idx)
/ Units.temperature
)
print(f"Final temperature: {final_temp.item():.4f}")
Loading