From 2bb3bd5f27b3c5621ae546e6d6207d59600fae35 Mon Sep 17 00:00:00 2001 From: Timo Reents <77727843+t-reents@users.noreply.github.com> Date: Fri, 4 Jul 2025 20:56:15 +0200 Subject: [PATCH 01/16] Fix memory scaling in `determine_max_batch_size` (#212) * Fix memory scaling in `determine_max_batch_size` The current version results in an infinite loop when `scale_factor < 1.5` due to the rounding. This is fixed by increasing the batch size by at least `+1`. * add `test_autobatching.py` check to ensure `determine_max_batch_size` does not regress to infinite loop * remove outdated pymatviz extras 'export-figs' in `6.1_Phonons_MACE.py` and `6.2_QuasiHarmonic_MACE.py` * pin plotly!=6.2.0 --------- Co-authored-by: Janosh Riebesell --- docs/_static/draw_pkg_treemap.py | 1 + .../scripts/6_Phonons/6.1_Phonons_MACE.py | 3 ++- .../6_Phonons/6.2_QuasiHarmonic_MACE.py | 3 ++- tests/test_autobatching.py | 26 +++++++++++++++++++ torch_sim/autobatching.py | 4 +-- 5 files changed, 33 insertions(+), 4 deletions(-) diff --git a/docs/_static/draw_pkg_treemap.py b/docs/_static/draw_pkg_treemap.py index 3e2775a3..f339a604 100644 --- a/docs/_static/draw_pkg_treemap.py +++ b/docs/_static/draw_pkg_treemap.py @@ -6,6 +6,7 @@ # /// script # dependencies = [ # "pymatviz @ git+https://github.com/janosh/pymatviz", +# "plotly!=6.2.0", # TODO remove pin pending https://github.com/plotly/plotly.py/issues/5253#issuecomment-3016615635 # ] # /// diff --git a/examples/scripts/6_Phonons/6.1_Phonons_MACE.py b/examples/scripts/6_Phonons/6.1_Phonons_MACE.py index 785c67d5..f88fd351 100644 --- a/examples/scripts/6_Phonons/6.1_Phonons_MACE.py +++ b/examples/scripts/6_Phonons/6.1_Phonons_MACE.py @@ -4,9 +4,10 @@ # dependencies = [ # "mace-torch>=0.3.12", # "phonopy>=2.35", -# "pymatviz[export-figs]>=0.15.1", +# "pymatviz>=0.16", # "seekpath", # "ase", +# "plotly!=6.2.0", # TODO remove pin pending https://github.com/plotly/plotly.py/issues/5253#issuecomment-3016615635 # ] # /// diff --git 
a/examples/scripts/6_Phonons/6.2_QuasiHarmonic_MACE.py b/examples/scripts/6_Phonons/6.2_QuasiHarmonic_MACE.py index 3abf8323..0fdea6b4 100644 --- a/examples/scripts/6_Phonons/6.2_QuasiHarmonic_MACE.py +++ b/examples/scripts/6_Phonons/6.2_QuasiHarmonic_MACE.py @@ -6,7 +6,8 @@ # dependencies = [ # "mace-torch>=0.3.12", # "phonopy>=2.35", -# "pymatviz[export-figs]>=0.15.1", +# "pymatviz>=0.16", +# "plotly!=6.2.0", # TODO remove pin pending https://github.com/plotly/plotly.py/issues/5253#issuecomment-3016615635 # ] # /// diff --git a/tests/test_autobatching.py b/tests/test_autobatching.py index 2c974ef6..5d15d4a1 100644 --- a/tests/test_autobatching.py +++ b/tests/test_autobatching.py @@ -376,6 +376,32 @@ def mock_measure(*_args: Any, **_kwargs: Any) -> float: assert max_size == 8 +@pytest.mark.parametrize("scale_factor", [1.1, 1.4]) +def test_determine_max_batch_size_small_scale_factor_no_infinite_loop( + si_sim_state: ts.SimState, + lj_model: LennardJonesModel, + monkeypatch: pytest.MonkeyPatch, + scale_factor: float, +) -> None: + """Test determine_max_batch_size doesn't infinite loop with small scale factors.""" + monkeypatch.setattr( + "torch_sim.autobatching.measure_model_memory_forward", lambda *_: 0.1 + ) + + max_size = determine_max_batch_size( + si_sim_state, lj_model, max_atoms=20, scale_factor=scale_factor + ) + assert 0 < max_size <= 20 + + # Verify sequence is strictly increasing (prevents infinite loop) + sizes = [1] + while (next_size := max(round(sizes[-1] * scale_factor), sizes[-1] + 1)) < 20: + sizes.append(next_size) + + assert all(sizes[idx] > sizes[idx - 1] for idx in range(1, len(sizes))) + assert max_size == sizes[-1] + + def test_in_flight_auto_batcher_restore_order( si_sim_state: ts.SimState, fe_supercell_sim_state: ts.SimState, diff --git a/torch_sim/autobatching.py b/torch_sim/autobatching.py index f2eb32c7..d436076a 100644 --- a/torch_sim/autobatching.py +++ b/torch_sim/autobatching.py @@ -268,7 +268,7 @@ def determine_max_batch_size( 
Defaults to 500,000. start_size (int): Initial batch size to test. Defaults to 1. scale_factor (float): Factor to multiply batch size by in each iteration. - Defaults to 1.3. + Defaults to 1.6. Returns: int: Maximum number of batches that fit in GPU memory. @@ -289,7 +289,7 @@ def determine_max_batch_size( """ # Create a geometric sequence of batch sizes sizes = [start_size] - while (next_size := round(sizes[-1] * scale_factor)) < max_atoms: + while (next_size := max(round(sizes[-1] * scale_factor), sizes[-1] + 1)) < max_atoms: sizes.append(next_size) for i in range(len(sizes)): From 317985c731170aad578673ebe69a9334f5abe5be Mon Sep 17 00:00:00 2001 From: Curtis Chong Date: Wed, 16 Jul 2025 12:44:02 -0400 Subject: [PATCH 02/16] Rename batch to system (#217) --- .gitignore | 3 + examples/scripts/1_Introduction/1.2_MACE.py | 16 +- .../2.3_MACE_Gradient_Descent.py | 10 +- ....5_MACE_UnitCellFilter_Gradient_Descent.py | 2 +- .../2.6_MACE_UnitCellFilter_FIRE.py | 2 +- .../2.7_MACE_FrechetCellFilter_FIRE.py | 2 +- .../scripts/3_Dynamics/3.10_Hybrid_swap_mc.py | 2 +- .../3.11_Lennard_Jones_NPT_Langevin.py | 10 +- .../3_Dynamics/3.12_MACE_NPT_Langevin.py | 16 +- .../3_Dynamics/3.13_MACE_NVE_non_pbc.py | 2 +- examples/scripts/3_Dynamics/3.2_MACE_NVE.py | 2 +- .../scripts/3_Dynamics/3.3_MACE_NVE_cueq.py | 2 +- .../3_Dynamics/3.4_MACE_NVT_Langevin.py | 6 +- .../3_Dynamics/3.5_MACE_NVT_Nose_Hoover.py | 6 +- .../3.6_MACE_NVT_Nose_Hoover_temp_profile.py | 2 +- .../3.7_Lennard_Jones_NPT_Nose_Hoover.py | 12 +- .../3_Dynamics/3.8_MACE_NPT_Nose_Hoover.py | 16 +- .../3.9_MACE_NVT_staggered_stress.py | 2 +- .../4_High_level_api/4.2_auto_batching_api.py | 2 +- .../scripts/5_Workflow/5.2_In_Flight_WBM.py | 2 +- .../7_Others/7.3_Batched_neighbor_list.py | 8 +- .../7_Others/7.6_Compare_ASE_to_VV_FIRE.py | 26 +- examples/tutorials/autobatching_tutorial.py | 6 +- examples/tutorials/high_level_tutorial.py | 2 +- examples/tutorials/hybrid_swap_tutorial.py | 2 +- 
examples/tutorials/low_level_tutorial.py | 2 +- examples/tutorials/state_tutorial.py | 22 +- pyproject.toml | 2 +- tests/models/conftest.py | 4 +- tests/test_autobatching.py | 28 +- tests/test_correlations.py | 10 +- tests/test_integrators.py | 14 +- tests/test_io.py | 25 +- tests/test_monte_carlo.py | 30 +- tests/test_neighbors.py | 12 +- tests/test_optimizers.py | 28 +- tests/test_runners.py | 46 +- tests/test_state.py | 154 +++--- tests/test_trajectory.py | 50 +- tests/test_transforms.py | 92 ++-- torch_sim/autobatching.py | 36 +- torch_sim/integrators/md.py | 42 +- torch_sim/integrators/npt.py | 434 +++++++++-------- torch_sim/integrators/nve.py | 12 +- torch_sim/integrators/nvt.py | 64 +-- torch_sim/io.py | 102 ++-- torch_sim/math.py | 2 +- torch_sim/models/fairchem.py | 8 +- torch_sim/models/graphpes.py | 4 +- torch_sim/models/interface.py | 18 +- torch_sim/models/lennard_jones.py | 18 +- torch_sim/models/mace.py | 60 +-- torch_sim/models/metatomic.py | 4 +- torch_sim/models/morse.py | 14 +- torch_sim/models/orb.py | 6 +- torch_sim/models/particle_life.py | 12 +- torch_sim/models/sevennet.py | 6 +- torch_sim/models/soft_sphere.py | 34 +- torch_sim/monte_carlo.py | 68 +-- torch_sim/neighbors.py | 6 +- torch_sim/optimizers.py | 460 +++++++++--------- torch_sim/quantities.py | 54 +- torch_sim/runners.py | 16 +- torch_sim/state.py | 422 +++++++++------- torch_sim/trajectory.py | 43 +- torch_sim/transforms.py | 82 ++-- torch_sim/typing.py | 2 +- torch_sim/workflows/a2c.py | 4 +- 68 files changed, 1432 insertions(+), 1281 deletions(-) diff --git a/.gitignore b/.gitignore index 9c028c81..2a0bbdf2 100644 --- a/.gitignore +++ b/.gitignore @@ -34,3 +34,6 @@ coverage.xml # env uv.lock + +# IDE +.vscode/ diff --git a/examples/scripts/1_Introduction/1.2_MACE.py b/examples/scripts/1_Introduction/1.2_MACE.py index e417ed52..f627bb5d 100644 --- a/examples/scripts/1_Introduction/1.2_MACE.py +++ b/examples/scripts/1_Introduction/1.2_MACE.py @@ -63,19 +63,19 @@ cell = 
torch.tensor(cell_numpy, device=device, dtype=dtype) atomic_numbers = torch.tensor(atomic_numbers_numpy, device=device, dtype=torch.int) -# create batch index array of shape (16,) which is 0 for first 8 atoms, 1 for last 8 atoms -atoms_per_batch = torch.tensor( +# create system idx array of shape (16,) which is 0 for first 8 atoms, 1 for last 8 atoms +atoms_per_system = torch.tensor( [len(atoms) for atoms in atoms_list], device=device, dtype=torch.int ) -batch = torch.repeat_interleave( - torch.arange(len(atoms_per_batch), device=device), atoms_per_batch +system_idx = torch.repeat_interleave( + torch.arange(len(atoms_per_system), device=device), atoms_per_system ) # You can see their shapes are as expected print(f"Positions: {positions.shape}") print(f"Cell: {cell.shape}") print(f"Atomic numbers: {atomic_numbers.shape}") -print(f"Batch: {batch.shape}") +print(f"System indices: {system_idx.shape}") # Now we can pass them to the model results = batched_model( @@ -83,18 +83,18 @@ positions=positions, cell=cell, atomic_numbers=atomic_numbers, - batch=batch, + system_idx=system_idx, pbc=True, ) ) -# The energy has shape (n_batches,) as the structures in a batch +# The energy has shape (n_systems,) as the structures in a batch print(f"Energy: {results['energy'].shape}") # The forces have shape (n_atoms, 3) same as positions print(f"Forces: {results['forces'].shape}") -# The stress has shape (n_batches, 3, 3) same as cell +# The stress has shape (n_systems, 3, 3) same as cell print(f"Stress: {results['stress'].shape}") # Check if the energy, forces, and stress are the same for the Si system across the batch diff --git a/examples/scripts/2_Structural_optimization/2.3_MACE_Gradient_Descent.py b/examples/scripts/2_Structural_optimization/2.3_MACE_Gradient_Descent.py index 0ae9f9e0..819ec6a4 100644 --- a/examples/scripts/2_Structural_optimization/2.3_MACE_Gradient_Descent.py +++ b/examples/scripts/2_Structural_optimization/2.3_MACE_Gradient_Descent.py @@ -93,12 +93,12 @@ 
masses_numpy = np.concatenate([atoms.get_masses() for atoms in atoms_list]) masses = torch.tensor(masses_numpy, device=device, dtype=dtype) -# Create batch indices tensor for scatter operations -atoms_per_batch = torch.tensor( +# Create system indices tensor for scatter operations +atoms_per_system = torch.tensor( [len(atoms) for atoms in atoms_list], device=device, dtype=torch.int ) -batch_indices = torch.repeat_interleave( - torch.arange(len(atoms_per_batch), device=device), atoms_per_batch +system_indices = torch.repeat_interleave( + torch.arange(len(atoms_per_system), device=device), atoms_per_system ) """ @@ -106,7 +106,7 @@ print(f"Positions shape: {state.positions.shape}") print(f"Cell shape: {state.cell.shape}") -print(f"Batch indices shape: {state.batch.shape}") +print(f"System indices shape: {state.system_idx.shape}") # Run initial inference results = batched_model(state) diff --git a/examples/scripts/2_Structural_optimization/2.5_MACE_UnitCellFilter_Gradient_Descent.py b/examples/scripts/2_Structural_optimization/2.5_MACE_UnitCellFilter_Gradient_Descent.py index f92a7185..5417724c 100644 --- a/examples/scripts/2_Structural_optimization/2.5_MACE_UnitCellFilter_Gradient_Descent.py +++ b/examples/scripts/2_Structural_optimization/2.5_MACE_UnitCellFilter_Gradient_Descent.py @@ -85,7 +85,7 @@ # Initialize unit cell gradient descent optimizer gd_init, gd_update = unit_cell_gradient_descent( model=model, - cell_factor=None, # Will default to atoms per batch + cell_factor=None, # Will default to atoms per system hydrostatic_strain=False, constant_volume=False, scalar_pressure=0.0, diff --git a/examples/scripts/2_Structural_optimization/2.6_MACE_UnitCellFilter_FIRE.py b/examples/scripts/2_Structural_optimization/2.6_MACE_UnitCellFilter_FIRE.py index 8125e296..85a7bd13 100644 --- a/examples/scripts/2_Structural_optimization/2.6_MACE_UnitCellFilter_FIRE.py +++ b/examples/scripts/2_Structural_optimization/2.6_MACE_UnitCellFilter_FIRE.py @@ -81,7 +81,7 @@ # 
Initialize unit cell gradient descent optimizer fire_init, fire_update = unit_cell_fire( model=model, - cell_factor=None, # Will default to atoms per batch + cell_factor=None, # Will default to atoms per system hydrostatic_strain=False, constant_volume=False, scalar_pressure=0.0, diff --git a/examples/scripts/2_Structural_optimization/2.7_MACE_FrechetCellFilter_FIRE.py b/examples/scripts/2_Structural_optimization/2.7_MACE_FrechetCellFilter_FIRE.py index 1799ba13..ba06f850 100644 --- a/examples/scripts/2_Structural_optimization/2.7_MACE_FrechetCellFilter_FIRE.py +++ b/examples/scripts/2_Structural_optimization/2.7_MACE_FrechetCellFilter_FIRE.py @@ -80,7 +80,7 @@ # Initialize unit cell gradient descent optimizer fire_init, fire_update = ts.optimizers.frechet_cell_fire( model=model, - cell_factor=None, # Will default to atoms per batch + cell_factor=None, # Will default to atoms per system hydrostatic_strain=False, constant_volume=False, scalar_pressure=0.0, diff --git a/examples/scripts/3_Dynamics/3.10_Hybrid_swap_mc.py b/examples/scripts/3_Dynamics/3.10_Hybrid_swap_mc.py index 38278198..6fcb50e2 100644 --- a/examples/scripts/3_Dynamics/3.10_Hybrid_swap_mc.py +++ b/examples/scripts/3_Dynamics/3.10_Hybrid_swap_mc.py @@ -86,7 +86,7 @@ class HybridSwapMCState(MDState): hybrid_state = HybridSwapMCState( **vars(md_state), last_permutation=torch.zeros( - md_state.n_batches, device=md_state.device, dtype=torch.bool + md_state.n_systems, device=md_state.device, dtype=torch.bool ), ) diff --git a/examples/scripts/3_Dynamics/3.11_Lennard_Jones_NPT_Langevin.py b/examples/scripts/3_Dynamics/3.11_Lennard_Jones_NPT_Langevin.py index 4932d432..8c87f5a6 100644 --- a/examples/scripts/3_Dynamics/3.11_Lennard_Jones_NPT_Langevin.py +++ b/examples/scripts/3_Dynamics/3.11_Lennard_Jones_NPT_Langevin.py @@ -119,13 +119,15 @@ for step in range(N_steps): if step % 50 == 0: temp = ( - calc_kT(masses=state.masses, momenta=state.momenta, batch=state.batch) + calc_kT( + masses=state.masses, 
momenta=state.momenta, system_idx=state.system_idx + ) / Units.temperature ) pressure = get_pressure( model(state)["stress"], calc_kinetic_energy( - masses=state.masses, momenta=state.momenta, batch=state.batch + masses=state.masses, momenta=state.momenta, system_idx=state.system_idx ), torch.linalg.det(state.cell), ) @@ -139,7 +141,7 @@ state = npt_update(state, kT=kT, external_pressure=target_pressure) temp = ( - calc_kT(masses=state.masses, momenta=state.momenta, batch=state.batch) + calc_kT(masses=state.masses, momenta=state.momenta, system_idx=state.system_idx) / Units.temperature ) print(f"Final temperature: {temp.item():.4f}") @@ -147,7 +149,7 @@ stress = model(state)["stress"] calc_kinetic_energy = calc_kinetic_energy( - masses=state.masses, momenta=state.momenta, batch=state.batch + masses=state.masses, momenta=state.momenta, system_idx=state.system_idx ) volume = torch.linalg.det(state.cell) pressure = get_pressure(stress, calc_kinetic_energy, volume) diff --git a/examples/scripts/3_Dynamics/3.12_MACE_NPT_Langevin.py b/examples/scripts/3_Dynamics/3.12_MACE_NPT_Langevin.py index a02ecaff..a07a4d74 100644 --- a/examples/scripts/3_Dynamics/3.12_MACE_NPT_Langevin.py +++ b/examples/scripts/3_Dynamics/3.12_MACE_NPT_Langevin.py @@ -68,7 +68,9 @@ for step in range(N_steps_nvt): if step % 10 == 0: temp = ( - calc_kT(masses=state.masses, momenta=state.momenta, batch=state.batch) + calc_kT( + masses=state.masses, momenta=state.momenta, system_idx=state.system_idx + ) / Units.temperature ) invariant = float(nvt_nose_hoover_invariant(state, kT=kT)) @@ -83,7 +85,9 @@ for step in range(N_steps_npt): if step % 10 == 0: temp = ( - calc_kT(masses=state.masses, momenta=state.momenta, batch=state.batch) + calc_kT( + masses=state.masses, momenta=state.momenta, system_idx=state.system_idx + ) / Units.temperature ) stress = model(state)["stress"] @@ -92,7 +96,9 @@ get_pressure( stress, calc_kinetic_energy( - masses=state.masses, momenta=state.momenta, batch=state.batch + 
masses=state.masses, + momenta=state.momenta, + system_idx=state.system_idx, ), volume, ).item() @@ -107,7 +113,7 @@ state = npt_update(state, kT=kT, external_pressure=target_pressure) final_temp = ( - calc_kT(masses=state.masses, momenta=state.momenta, batch=state.batch) + calc_kT(masses=state.masses, momenta=state.momenta, system_idx=state.system_idx) / Units.temperature ) print(f"Final temperature: {final_temp.item():.4f} K") @@ -117,7 +123,7 @@ get_pressure( final_stress, calc_kinetic_energy( - masses=state.masses, momenta=state.momenta, batch=state.batch + masses=state.masses, momenta=state.momenta, system_idx=state.system_idx ), final_volume, ).item() diff --git a/examples/scripts/3_Dynamics/3.13_MACE_NVE_non_pbc.py b/examples/scripts/3_Dynamics/3.13_MACE_NVE_non_pbc.py index c7ec1273..07fcb4c8 100644 --- a/examples/scripts/3_Dynamics/3.13_MACE_NVE_non_pbc.py +++ b/examples/scripts/3_Dynamics/3.13_MACE_NVE_non_pbc.py @@ -77,7 +77,7 @@ start_time = time.perf_counter() for step in range(N_steps): total_energy = state.energy + calc_kinetic_energy( - masses=state.masses, momenta=state.momenta, batch=state.batch + masses=state.masses, momenta=state.momenta, system_idx=state.system_idx ) if step % 10 == 0: print(f"Step {step}: Total energy: {total_energy.item():.4f} eV") diff --git a/examples/scripts/3_Dynamics/3.2_MACE_NVE.py b/examples/scripts/3_Dynamics/3.2_MACE_NVE.py index f9e6bf53..10cc1dc4 100644 --- a/examples/scripts/3_Dynamics/3.2_MACE_NVE.py +++ b/examples/scripts/3_Dynamics/3.2_MACE_NVE.py @@ -88,7 +88,7 @@ start_time = time.perf_counter() for step in range(N_steps): total_energy = state.energy + calc_kinetic_energy( - masses=state.masses, momenta=state.momenta, batch=state.batch + masses=state.masses, momenta=state.momenta, system_idx=state.system_idx ) if step % 10 == 0: print(f"Step {step}: Total energy: {total_energy.item():.4f} eV") diff --git a/examples/scripts/3_Dynamics/3.3_MACE_NVE_cueq.py b/examples/scripts/3_Dynamics/3.3_MACE_NVE_cueq.py 
index f09b5bc0..2dd507fe 100644 --- a/examples/scripts/3_Dynamics/3.3_MACE_NVE_cueq.py +++ b/examples/scripts/3_Dynamics/3.3_MACE_NVE_cueq.py @@ -71,7 +71,7 @@ start_time = time.perf_counter() for step in range(N_steps): total_energy = state.energy + calc_kinetic_energy( - masses=state.masses, momenta=state.momenta, batch=state.batch + masses=state.masses, momenta=state.momenta, system_idx=state.system_idx ) if step % 10 == 0: print(f"Step {step}: Total energy: {total_energy.item():.4f} eV") diff --git a/examples/scripts/3_Dynamics/3.4_MACE_NVT_Langevin.py b/examples/scripts/3_Dynamics/3.4_MACE_NVT_Langevin.py index 3ef1b93f..d7846c7e 100644 --- a/examples/scripts/3_Dynamics/3.4_MACE_NVT_Langevin.py +++ b/examples/scripts/3_Dynamics/3.4_MACE_NVT_Langevin.py @@ -83,14 +83,16 @@ for step in range(N_steps): if step % 10 == 0: temp = ( - calc_kT(masses=state.masses, momenta=state.momenta, batch=state.batch) + calc_kT( + masses=state.masses, momenta=state.momenta, system_idx=state.system_idx + ) / Units.temperature ) print(f"{step=}: Temperature: {temp.item():.4f}") state = langevin_update(state=state, kT=kT) final_temp = ( - calc_kT(masses=state.masses, momenta=state.momenta, batch=state.batch) + calc_kT(masses=state.masses, momenta=state.momenta, system_idx=state.system_idx) / Units.temperature ) print(f"Final temperature: {final_temp.item():.4f}") diff --git a/examples/scripts/3_Dynamics/3.5_MACE_NVT_Nose_Hoover.py b/examples/scripts/3_Dynamics/3.5_MACE_NVT_Nose_Hoover.py index d1c88893..97cb06e9 100644 --- a/examples/scripts/3_Dynamics/3.5_MACE_NVT_Nose_Hoover.py +++ b/examples/scripts/3_Dynamics/3.5_MACE_NVT_Nose_Hoover.py @@ -68,7 +68,9 @@ for step in range(N_steps): if step % 10 == 0: temp = ( - calc_kT(masses=state.masses, momenta=state.momenta, batch=state.batch) + calc_kT( + masses=state.masses, momenta=state.momenta, system_idx=state.system_idx + ) / Units.temperature ) invariant = float(nvt_nose_hoover_invariant(state, kT=kT)) @@ -76,7 +78,7 @@ state = 
nvt_update(state=state, kT=kT) final_temp = ( - calc_kT(masses=state.masses, momenta=state.momenta, batch=state.batch) + calc_kT(masses=state.masses, momenta=state.momenta, system_idx=state.system_idx) / Units.temperature ) print(f"Final temperature: {final_temp.item():.4f}") diff --git a/examples/scripts/3_Dynamics/3.6_MACE_NVT_Nose_Hoover_temp_profile.py b/examples/scripts/3_Dynamics/3.6_MACE_NVT_Nose_Hoover_temp_profile.py index a94337f1..aafc28ac 100644 --- a/examples/scripts/3_Dynamics/3.6_MACE_NVT_Nose_Hoover_temp_profile.py +++ b/examples/scripts/3_Dynamics/3.6_MACE_NVT_Nose_Hoover_temp_profile.py @@ -163,7 +163,7 @@ def get_kT( # Calculate current temperature and save data temp = ( - calc_kT(masses=state.masses, momenta=state.momenta, batch=state.batch) + calc_kT(masses=state.masses, momenta=state.momenta, system_idx=state.system_idx) / Units.temperature ) actual_temps[step] = temp diff --git a/examples/scripts/3_Dynamics/3.7_Lennard_Jones_NPT_Nose_Hoover.py b/examples/scripts/3_Dynamics/3.7_Lennard_Jones_NPT_Nose_Hoover.py index e85379d9..6375bc61 100644 --- a/examples/scripts/3_Dynamics/3.7_Lennard_Jones_NPT_Nose_Hoover.py +++ b/examples/scripts/3_Dynamics/3.7_Lennard_Jones_NPT_Nose_Hoover.py @@ -123,14 +123,16 @@ for step in range(N_steps): if step % 50 == 0: temp = ( - calc_kT(masses=state.masses, momenta=state.momenta, batch=state.batch) + calc_kT( + masses=state.masses, momenta=state.momenta, system_idx=state.system_idx + ) / Units.temperature ) invariant = float( npt_nose_hoover_invariant(state, kT=kT, external_pressure=target_pressure) ) e_kin = calc_kinetic_energy( - masses=state.masses, momenta=state.momenta, batch=state.batch + masses=state.masses, momenta=state.momenta, system_idx=state.system_idx ) pressure = get_pressure( model(state)["stress"], e_kin, torch.det(state.current_cell) @@ -145,14 +147,16 @@ state = npt_update(state, kT=kT, external_pressure=target_pressure) temp = ( - calc_kT(masses=state.masses, momenta=state.momenta, 
batch=state.batch) + calc_kT(masses=state.masses, momenta=state.momenta, system_idx=state.system_idx) / Units.temperature ) print(f"Final temperature: {temp.item():.4f}") pressure = get_pressure( model(state)["stress"], - calc_kinetic_energy(masses=state.masses, momenta=state.momenta, batch=state.batch), + calc_kinetic_energy( + masses=state.masses, momenta=state.momenta, system_idx=state.system_idx + ), torch.det(state.current_cell), ) pressure = pressure.item() / Units.pressure diff --git a/examples/scripts/3_Dynamics/3.8_MACE_NPT_Nose_Hoover.py b/examples/scripts/3_Dynamics/3.8_MACE_NPT_Nose_Hoover.py index bfa45313..9dbc402a 100644 --- a/examples/scripts/3_Dynamics/3.8_MACE_NPT_Nose_Hoover.py +++ b/examples/scripts/3_Dynamics/3.8_MACE_NPT_Nose_Hoover.py @@ -69,7 +69,9 @@ for step in range(N_steps_nvt): if step % 10 == 0: temp = ( - calc_kT(masses=state.masses, momenta=state.momenta, batch=state.batch) + calc_kT( + masses=state.masses, momenta=state.momenta, system_idx=state.system_idx + ) / Units.temperature ) invariant = float( @@ -86,7 +88,9 @@ for step in range(N_steps_npt): if step % 10 == 0: temp = ( - calc_kT(masses=state.masses, momenta=state.momenta, batch=state.batch) + calc_kT( + masses=state.masses, momenta=state.momenta, system_idx=state.system_idx + ) / Units.temperature ) invariant = float( @@ -95,7 +99,7 @@ stress = model(state)["stress"] volume = torch.det(state.current_cell) e_kin = calc_kinetic_energy( - masses=state.masses, momenta=state.momenta, batch=state.batch + masses=state.masses, momenta=state.momenta, system_idx=state.system_idx ) pressure = float(get_pressure(stress, e_kin, volume)) xx, yy, zz = torch.diag(state.current_cell[0]) @@ -107,7 +111,7 @@ state = npt_update(state, kT=kT, external_pressure=target_pressure) final_temp = ( - calc_kT(masses=state.masses, momenta=state.momenta, batch=state.batch) + calc_kT(masses=state.masses, momenta=state.momenta, system_idx=state.system_idx) / Units.temperature ) print(f"Final temperature: 
{final_temp.item():.4f}") @@ -115,7 +119,9 @@ final_volume = torch.det(state.current_cell) final_pressure = get_pressure( final_stress, - calc_kinetic_energy(masses=state.masses, momenta=state.momenta, batch=state.batch), + calc_kinetic_energy( + masses=state.masses, momenta=state.momenta, system_idx=state.system_idx + ), final_volume, ) print(f"Final pressure: {final_pressure.item():.4f}") diff --git a/examples/scripts/3_Dynamics/3.9_MACE_NVT_staggered_stress.py b/examples/scripts/3_Dynamics/3.9_MACE_NVT_staggered_stress.py index 9d7e02a6..0106f3b8 100644 --- a/examples/scripts/3_Dynamics/3.9_MACE_NVT_staggered_stress.py +++ b/examples/scripts/3_Dynamics/3.9_MACE_NVT_staggered_stress.py @@ -65,7 +65,7 @@ stress = torch.zeros(N_steps // 10, 3, 3, device=device, dtype=dtype) for step in range(N_steps): temp = ( - calc_kT(masses=state.masses, momenta=state.momenta, batch=state.batch) + calc_kT(masses=state.masses, momenta=state.momenta, system_idx=state.system_idx) / Units.temperature ) diff --git a/examples/scripts/4_High_level_api/4.2_auto_batching_api.py b/examples/scripts/4_High_level_api/4.2_auto_batching_api.py index 7a3387ed..66463d98 100644 --- a/examples/scripts/4_High_level_api/4.2_auto_batching_api.py +++ b/examples/scripts/4_High_level_api/4.2_auto_batching_api.py @@ -76,7 +76,7 @@ all_completed_states, convergence_tensor, state = [], None, None while (result := batcher.next_batch(state, convergence_tensor))[0] is not None: state, completed_states = result - print(f"Starting new batch of {state.n_batches} states.") + print(f"Starting new batch of {state.n_systems} states.") all_completed_states.extend(completed_states) print("Total number of completed states", len(all_completed_states)) diff --git a/examples/scripts/5_Workflow/5.2_In_Flight_WBM.py b/examples/scripts/5_Workflow/5.2_In_Flight_WBM.py index 78a1019f..1003f636 100644 --- a/examples/scripts/5_Workflow/5.2_In_Flight_WBM.py +++ b/examples/scripts/5_Workflow/5.2_In_Flight_WBM.py @@ -83,7 +83,7 @@ 
all_completed_states, convergence_tensor, state = [], None, None while (result := batcher.next_batch(state, convergence_tensor))[0] is not None: state, completed_states = result - print(f"Starting new batch of {state.n_batches} states.") + print(f"Starting new batch of {state.n_systems} states.") all_completed_states.extend(completed_states) print("Total number of completed states", len(all_completed_states)) diff --git a/examples/scripts/7_Others/7.3_Batched_neighbor_list.py b/examples/scripts/7_Others/7.3_Batched_neighbor_list.py index a21ddeaf..91141fd3 100644 --- a/examples/scripts/7_Others/7.3_Batched_neighbor_list.py +++ b/examples/scripts/7_Others/7.3_Batched_neighbor_list.py @@ -18,15 +18,15 @@ atoms_list = [bulk("Si", "diamond", a=5.43), bulk("Ge", "diamond", a=5.65)] state = ts.io.atoms_to_state(atoms_list, device="cpu", dtype=torch.float32) pos, cell, pbc = state.positions, state.cell, state.pbc -batch, n_atoms = state.batch, state.n_atoms +system_idx, n_atoms = state.system_idx, state.n_atoms cutoff = 4.0 self_interaction = False -# Fix: Ensure pbc has the correct shape [n_batches, 3] +# Fix: Ensure pbc has the correct shape [n_systems, 3] pbc_tensor = torch.tensor([[pbc] * 3] * len(atoms_list), dtype=torch.bool) mapping, mapping_batch, shifts_idx = torch_nl_linked_cell( - cutoff, pos, cell, pbc_tensor, batch, self_interaction + cutoff, pos, cell, pbc_tensor, system_idx, self_interaction ) cell_shifts = transforms.compute_cell_shifts(cell, shifts_idx, mapping_batch) dds = transforms.compute_distances_with_cell_shifts(pos, mapping, cell_shifts) @@ -38,7 +38,7 @@ print(dds.shape) mapping_n2, mapping_batch_n2, shifts_idx_n2 = torch_nl_n2( - cutoff, pos, cell, pbc_tensor, batch, self_interaction + cutoff, pos, cell, pbc_tensor, system_idx, self_interaction ) cell_shifts_n2 = transforms.compute_cell_shifts(cell, shifts_idx_n2, mapping_batch_n2) dds_n2 = transforms.compute_distances_with_cell_shifts(pos, mapping_n2, cell_shifts_n2) diff --git 
a/examples/scripts/7_Others/7.6_Compare_ASE_to_VV_FIRE.py b/examples/scripts/7_Others/7.6_Compare_ASE_to_VV_FIRE.py index 82fdf8f0..1eed7d9f 100644 --- a/examples/scripts/7_Others/7.6_Compare_ASE_to_VV_FIRE.py +++ b/examples/scripts/7_Others/7.6_Compare_ASE_to_VV_FIRE.py @@ -141,7 +141,7 @@ def run_optimization_ts( # noqa: PLR0915 start_time = time.perf_counter() print("Initial cell parameters (Torch-Sim):") - for k_idx in range(initial_state.n_batches): + for k_idx in range(initial_state.n_systems): cell_tensor_k = initial_state.cell[k_idx].cpu().numpy() ase_cell_k = Cell(cell_tensor_k) params_str = ", ".join([f"{p:.2f}" for p in ase_cell_k.cellpar()]) @@ -168,7 +168,7 @@ def run_optimization_ts( # noqa: PLR0915 ) batcher.load_states(opt_state) - total_structures = opt_state.n_batches + total_structures = opt_state.n_systems convergence_steps = torch.full( (total_structures,), -1, dtype=torch.long, device=device ) @@ -219,7 +219,7 @@ def run_optimization_ts( # noqa: PLR0915 if global_step % 50 == 0: total_converged_frac = converged_tensor_global.sum().item() / total_structures - active_structures = opt_state.n_batches if opt_state else 0 + active_structures = opt_state.n_systems if opt_state else 0 print( f"{global_step=}: Active structures: {active_structures}, " f"Total converged: {total_converged_frac:.2%}" @@ -230,7 +230,7 @@ def run_optimization_ts( # noqa: PLR0915 if final_state_concatenated is not None and hasattr(final_state_concatenated, "cell"): print("Final cell parameters (Torch-Sim):") - for k_idx in range(final_state_concatenated.n_batches): + for k_idx in range(final_state_concatenated.n_systems): cell_tensor_k = final_state_concatenated.cell[k_idx].cpu().numpy() ase_cell_k = Cell(cell_tensor_k) params_str = ", ".join([f"{p:.2f}" for p in ase_cell_k.cellpar()]) @@ -329,12 +329,12 @@ def run_optimization_ase( # noqa: C901, PLR0915 all_masses = [] all_atomic_numbers = [] all_cells = [] - all_batches_for_gd = [] + all_systems_for_gd = [] 
final_energies_ase = [] final_forces_ase_tensors = [] current_atom_offset = 0 - for batch_idx, ats_final in enumerate(final_ase_atoms_list): + for system_idx, ats_final in enumerate(final_ase_atoms_list): all_positions.append( torch.tensor(ats_final.get_positions(), device=device, dtype=dtype) ) @@ -350,9 +350,9 @@ def run_optimization_ase( # noqa: C901, PLR0915 ) num_atoms_in_current = len(ats_final) - all_batches_for_gd.append( + all_systems_for_gd.append( torch.full( - (num_atoms_in_current,), batch_idx, device=device, dtype=torch.long + (num_atoms_in_current,), system_idx, device=device, dtype=torch.long ) ) current_atom_offset += num_atoms_in_current @@ -361,7 +361,7 @@ def run_optimization_ase( # noqa: C901, PLR0915 if ats_final.calc is None: print( "Re-attaching ASE calculator for final energy/forces for " - f"structure {batch_idx}." + f"structure {system_idx}." ) temp_calc = mace_mp_calculator_for_ase( model=MaceUrls.mace_mpa_medium, @@ -375,7 +375,7 @@ def run_optimization_ase( # noqa: C901, PLR0915 ) except Exception as e: # noqa: BLE001 print( - f"Could not get final energy/forces for an ASE structure {batch_idx}: {e}" + f"Couldn't get final energy/forces for an ASE structure {system_idx}: {e}" ) final_energies_ase.append(float("nan")) if all_positions and len(all_positions[-1]) > 0: @@ -393,8 +393,8 @@ def run_optimization_ase( # noqa: C901, PLR0915 concatenated_positions = torch.cat(all_positions, dim=0) concatenated_masses = torch.cat(all_masses, dim=0) concatenated_atomic_numbers = torch.cat(all_atomic_numbers, dim=0) - concatenated_cells = torch.stack(all_cells, dim=0) # Cells are (N_batch, 3, 3) - concatenated_batch_indices = torch.cat(all_batches_for_gd, dim=0) + concatenated_cells = torch.stack(all_cells, dim=0) # Cells are (n_systems, 3, 3) + concatenated_system_indices = torch.cat(all_systems_for_gd, dim=0) concatenated_energies = torch.tensor(final_energies_ase, device=device, dtype=dtype) concatenated_forces = 
torch.cat(final_forces_ase_tensors, dim=0) @@ -413,7 +413,7 @@ def run_optimization_ase( # noqa: C901, PLR0915 cell=concatenated_cells, pbc=initial_state.pbc, atomic_numbers=concatenated_atomic_numbers, - batch=concatenated_batch_indices, + system_idx=concatenated_system_indices, energy=concatenated_energies, forces=concatenated_forces, ) diff --git a/examples/tutorials/autobatching_tutorial.py b/examples/tutorials/autobatching_tutorial.py index c1647190..fec39945 100644 --- a/examples/tutorials/autobatching_tutorial.py +++ b/examples/tutorials/autobatching_tutorial.py @@ -249,7 +249,7 @@ def process_batch(batch): fire_state.positions = ( fire_state.positions + torch.randn_like(fire_state.positions) * 0.05 ) -total_states = fire_state.n_batches +total_states = fire_state.n_systems # Define a convergence function that checks the force on each atom is less than 5e-1 convergence_fn = ts.generate_force_convergence_fn(5e-1) @@ -279,11 +279,11 @@ def process_batch(batch): assert len(final_states) == total_states # Note that the fire_state has been modified in place -assert fire_state.n_batches == 0 +assert fire_state.n_systems == 0 # %% -fire_state.n_batches +fire_state.n_systems # %% [markdown] diff --git a/examples/tutorials/high_level_tutorial.py b/examples/tutorials/high_level_tutorial.py index 9d6c4ab4..c26d7479 100644 --- a/examples/tutorials/high_level_tutorial.py +++ b/examples/tutorials/high_level_tutorial.py @@ -375,7 +375,7 @@ def mock_determine_max_batch_size(*args, **kwargs): convergence function are `state` and `last_energy`. The `state` is a `SimState` object that contains the current state of the system and the `last_energy` is the energy of the previous step. The convergence function should return a boolean tensor of length -`n_batches`. +`n_systems`. 
This is how we'd manually define the default `convergence_fn`: """ diff --git a/examples/tutorials/hybrid_swap_tutorial.py b/examples/tutorials/hybrid_swap_tutorial.py index 28c73fc0..d8adbbf8 100644 --- a/examples/tutorials/hybrid_swap_tutorial.py +++ b/examples/tutorials/hybrid_swap_tutorial.py @@ -133,7 +133,7 @@ class HybridSwapMCState(ts.integrators.MDState): hybrid_state = HybridSwapMCState( **vars(md_state), last_permutation=torch.zeros( - md_state.n_batches, device=md_state.device, dtype=torch.bool + md_state.n_systems, device=md_state.device, dtype=torch.bool ), ) diff --git a/examples/tutorials/low_level_tutorial.py b/examples/tutorials/low_level_tutorial.py index f12893c3..d863ca09 100644 --- a/examples/tutorials/low_level_tutorial.py +++ b/examples/tutorials/low_level_tutorial.py @@ -228,7 +228,7 @@ state = nvt_langevin_update_fn(state=state, kT=initial_kT * (1 + step / 30)) if step % 5 == 0: temp_E_units = ts.calc_kT( - masses=state.masses, momenta=state.momenta, batch=state.batch + masses=state.masses, momenta=state.momenta, system_idx=state.system_idx ) temp = temp_E_units / MetalUnits.temperature print(f"{step=}: Temperature: {temp}") diff --git a/examples/tutorials/state_tutorial.py b/examples/tutorials/state_tutorial.py index 1abf5d28..0bc9e341 100644 --- a/examples/tutorials/state_tutorial.py +++ b/examples/tutorials/state_tutorial.py @@ -31,7 +31,7 @@ * Unit cell parameters * Periodic boundary conditions * Atomic numbers (elements) -* Batch indices (for processing multiple systems simultaneously) +* System indices (for processing multiple systems simultaneously) """ @@ -58,7 +58,7 @@ # Convert to SimState si_state = ts.initialize_state(si_atoms, device=torch.device("cpu"), dtype=torch.float64) -print(f"State has {si_state.n_atoms} atoms and {si_state.n_batches} batches") +print(f"State has {si_state.n_atoms} atoms and {si_state.n_systems} systems") # here we print all the attributes of the SimState print(f"Positions shape: 
{si_state.positions.shape}") @@ -66,7 +66,7 @@ print(f"Atomic numbers shape: {si_state.atomic_numbers.shape}") print(f"Masses shape: {si_state.masses.shape}") print(f"PBC: {si_state.pbc}") -print(f"Batch indices shape: {si_state.batch.shape}") +print(f"System indices shape: {si_state.system_idx.shape}") # %% [markdown] @@ -75,7 +75,7 @@ * Atomwise attributes are tensors with shape (n_atoms, ...), these are `positions`, `masses`, `atomic_numbers`, and `batch`. Names are plural. -* Batchwise attributes are tensors with shape (n_batches, ...), this is just `cell` for +* Batchwise attributes are tensors with shape (n_systems, ...), this is just `cell` for the base SimState. Names are singular. * Global attributes have any other shape or type, just `pbc` here. Names are singular. @@ -109,14 +109,14 @@ ) print( - f"Multi-state has {multi_state.n_atoms} total atoms across {multi_state.n_batches} batches" + f"Multi-state has {multi_state.n_atoms} total atoms across {multi_state.n_systems} systems" ) # we can see how the shapes of batchwise, atomwise, and global properties change print(f"Positions shape: {multi_state.positions.shape}") print(f"Cell shape: {multi_state.cell.shape}") print(f"PBC: {multi_state.pbc}") -print(f"Batch indices shape: {multi_state.batch.shape}") +print(f"System indices shape: {multi_state.system_idx.shape}") # %% [markdown] @@ -148,18 +148,18 @@ # %% we can copy the state with the clone method multi_state_copy = multi_state.clone() -print(f"This state has {multi_state_copy.n_batches} batches") +print(f"This state has {multi_state_copy.n_systems} systems") # we can pop states off while modifying the original state popped_states = multi_state_copy.pop([0, 2]) print( f"We popped {len(popped_states)} states, leaving us with " - f"{multi_state_copy.n_batches} batch in the original state" + f"{multi_state_copy.n_systems} systems in the original state" ) # we can put them back together with concatenate multi_state_full = 
ts.concatenate_states([*popped_states, multi_state_copy]) -print(f"Again we have {multi_state_full.n_batches} batches in the full state") +print(f"Again we have {multi_state_full.n_systems} systems in the full state") # or if we don't want to modify the original state, we can instead index into it # negative indexing @@ -253,14 +253,14 @@ **asdict(si_state), # Copy all SimState properties momenta=torch.zeros_like(si_state.positions), # Initial 0 momenta forces=torch.zeros_like(si_state.positions), # Initial 0 forces - energy=torch.zeros((si_state.n_batches,), device=si_state.device), # Initial 0 energy + energy=torch.zeros((si_state.n_systems,), device=si_state.device), # Initial 0 energy ) print("MDState properties:") scope = infer_property_scope(md_state) print("Global properties:", scope["global"]) print("Per-atom properties:", scope["per_atom"]) -print("Per-batch properties:", scope["per_batch"]) +print("Per-system properties:", scope["per_system"]) # %% [markdown] diff --git a/pyproject.toml b/pyproject.toml index f71feb13..723edafb 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "torch_sim_atomistic" -version = "0.2.2" +version = "0.3.0" description = "A pytorch toolkit for calculating material properties using MLIPs" authors = [ { name = "Abhijeet Gangan", email = "abhijeetgangan@g.ucla.edu" }, diff --git a/tests/models/conftest.py b/tests/models/conftest.py index 923f34f7..7ef84304 100644 --- a/tests/models/conftest.py +++ b/tests/models/conftest.py @@ -169,7 +169,7 @@ def test_model_output_validation( og_positions = sim_state.positions.clone() og_cell = sim_state.cell.clone() - og_batch = sim_state.batch.clone() + og_batch = sim_state.system_idx.clone() og_atomic_numbers = sim_state.atomic_numbers.clone() model_output = model.forward(sim_state) @@ -177,7 +177,7 @@ def test_model_output_validation( # assert model did not mutate the input assert torch.allclose(og_positions, sim_state.positions) assert torch.allclose(og_cell, 
sim_state.cell) - assert torch.allclose(og_batch, sim_state.batch) + assert torch.allclose(og_batch, sim_state.system_idx) assert torch.allclose(og_atomic_numbers, sim_state.atomic_numbers) # assert model output has the correct keys diff --git a/tests/test_autobatching.py b/tests/test_autobatching.py index 5d15d4a1..7be28997 100644 --- a/tests/test_autobatching.py +++ b/tests/test_autobatching.py @@ -115,10 +115,10 @@ def test_split_state(si_double_sim_state: ts.SimState) -> None: # Check each state has the correct properties for state in enumerate(split_states): - assert state[1].n_batches == 1 + assert state[1].n_systems == 1 assert torch.all( - state[1].batch == 0 - ) # Each split state should have batch indices reset to 0 + state[1].system_idx == 0 + ) # Each split state should have system indices reset to 0 assert state[1].n_atoms == si_double_sim_state.n_atoms // 2 assert state[1].positions.shape[0] == si_double_sim_state.n_atoms // 2 assert state[1].cell.shape[0] == 1 @@ -472,14 +472,14 @@ def test_in_flight_with_fire( batcher.load_states(fire_states) def convergence_fn(state: ts.SimState) -> bool: - batch_wise_max_force = torch.zeros( - state.n_batches, device=state.device, dtype=torch.float64 + system_wise_max_force = torch.zeros( + state.n_systems, device=state.device, dtype=torch.float64 ) max_forces = state.forces.norm(dim=1) - batch_wise_max_force = batch_wise_max_force.scatter_reduce( - dim=0, index=state.batch, src=max_forces, reduce="amax" + system_wise_max_force = system_wise_max_force.scatter_reduce( + dim=0, index=state.system_idx, src=max_forces, reduce="amax" ) - return batch_wise_max_force < 5e-1 + return system_wise_max_force < 5e-1 all_completed_states, convergence_tensor = [], None while True: @@ -514,7 +514,7 @@ def test_binning_auto_batcher_with_fire( batch_lengths = [state.n_atoms for state in fire_states] optimal_batches = to_constant_volume_bins(batch_lengths, 400) - optimal_n_batches = len(optimal_batches) + optimal_n_systems = 
len(optimal_batches) batcher = BinningAutoBatcher( model=lj_model, memory_scales_with="n_atoms", max_memory_scaler=400 @@ -522,9 +522,9 @@ def test_binning_auto_batcher_with_fire( batcher.load_states(fire_states) finished_states = [] - n_batches = 0 + n_systems = 0 for batch in batcher: - n_batches += 1 + n_systems += 1 for _ in range(5): batch = fire_update(batch) @@ -535,7 +535,7 @@ def test_binning_auto_batcher_with_fire( for restored, original in zip(restored_states, fire_states, strict=True): assert torch.all(restored.atomic_numbers == original.atomic_numbers) # analytically determined to be optimal - assert n_batches == optimal_n_batches + assert n_systems == optimal_n_systems def test_in_flight_max_iterations( @@ -561,7 +561,7 @@ def test_in_flight_max_iterations( state, [] = batcher.next_batch(None, None) # Create a convergence tensor that never converges - convergence_tensor = torch.zeros(state.n_batches, dtype=torch.bool) + convergence_tensor = torch.zeros(state.n_systems, dtype=torch.bool) all_completed_states = [] iteration_count = 0 @@ -574,7 +574,7 @@ def test_in_flight_max_iterations( # Update convergence tensor for next iteration (still all False) if state is not None: - convergence_tensor = torch.zeros(state.n_batches, dtype=torch.bool) + convergence_tensor = torch.zeros(state.n_systems, dtype=torch.bool) if iteration_count > max_attempts + 4: raise ValueError("Should have terminated by now") diff --git a/tests/test_correlations.py b/tests/test_correlations.py index f14612de..31624819 100644 --- a/tests/test_correlations.py +++ b/tests/test_correlations.py @@ -31,12 +31,14 @@ def __init__(self, velocities: torch.Tensor, device: torch.device) -> None: self.velocities = velocities self.device = device # Required for TrajectoryReporter - self.n_batches = 1 - self.batch = torch.zeros(velocities.shape[0], device=device, dtype=torch.int64) + self.n_systems = 1 + self.system_idx = torch.zeros( + velocities.shape[0], device=device, dtype=torch.int64 + ) 
def split(self) -> list["MockState"]: - """Split state into batches.""" - # Just return self since 1 batch + """Split state into multiple systems.""" + # Just return self since 1 system return [self] diff --git a/tests/test_integrators.py b/tests/test_integrators.py index bad80122..b6923aa5 100644 --- a/tests/test_integrators.py +++ b/tests/test_integrators.py @@ -109,7 +109,7 @@ def test_npt_langevin(ar_double_sim_state: ts.SimState, lj_model: LennardJonesMo state = update_fn(state=state) # Calculate instantaneous temperature from kinetic energy - temp = calc_kT(state.momenta, state.masses, batch=state.batch) + temp = calc_kT(state.momenta, state.masses, system_idx=state.system_idx) energies.append(state.energy) temperatures.append(temp / MetalUnits.temperature) @@ -172,7 +172,7 @@ def test_npt_langevin_multi_kt( state = update_fn(state=state) # Calculate instantaneous temperature from kinetic energy - temp = calc_kT(state.momenta, state.masses, batch=state.batch) + temp = calc_kT(state.momenta, state.masses, system_idx=state.system_idx) energies.append(state.energy) temperatures.append(temp / MetalUnits.temperature) @@ -213,7 +213,7 @@ def test_nvt_langevin(ar_double_sim_state: ts.SimState, lj_model: LennardJonesMo state = update_fn(state=state) # Calculate instantaneous temperature from kinetic energy - temp = calc_kT(state.momenta, state.masses, batch=state.batch) + temp = calc_kT(state.momenta, state.masses, system_idx=state.system_idx) energies.append(state.energy) temperatures.append(temp / MetalUnits.temperature) @@ -273,7 +273,7 @@ def test_nvt_langevin_multi_kt( state = update_fn(state=state) # Calculate instantaneous temperature from kinetic energy - temp = calc_kT(state.momenta, state.masses, batch=state.batch) + temp = calc_kT(state.momenta, state.masses, system_idx=state.system_idx) energies.append(state.energy) temperatures.append(temp / MetalUnits.temperature) @@ -372,8 +372,8 @@ def test_compare_single_vs_batched_integrators( 
torch.testing.assert_close(single_state.energy, final_state.energy) -def test_compute_cell_force_atoms_per_batch(): - """Test that compute_cell_force correctly scales by number of atoms per batch. +def test_compute_cell_force_atoms_per_system(): + """Test that compute_cell_force correctly scales by number of atoms per system. Covers fix in https://github.com/Radical-AI/torch-sim/pull/153.""" from torch_sim.integrators.npt import _compute_cell_force @@ -389,7 +389,7 @@ def test_compute_cell_force_atoms_per_batch(): masses=torch.ones(72), cell=torch.eye(3).repeat(2, 1, 1), pbc=True, - batch=torch.cat([s1, s2]), + system_idx=torch.cat([s1, s2]), atomic_numbers=torch.ones(72, dtype=torch.long), stress=torch.zeros((2, 3, 3)), reference_cell=torch.eye(3).repeat(2, 1, 1), diff --git a/tests/test_io.py b/tests/test_io.py index 84f665d5..c5043763 100644 --- a/tests/test_io.py +++ b/tests/test_io.py @@ -48,9 +48,10 @@ def test_multiple_structures_to_state( assert state.cell.shape == (2, 3, 3) assert state.pbc assert state.atomic_numbers.shape == (16,) - assert state.batch.shape == (16,) + assert state.system_idx.shape == (16,) assert torch.all( - state.batch == torch.repeat_interleave(torch.tensor([0, 1], device=device), 8) + state.system_idx + == torch.repeat_interleave(torch.tensor([0, 1], device=device), 8) ) @@ -65,8 +66,8 @@ def test_single_atoms_to_state(si_atoms: Atoms, device: torch.device) -> None: assert state.cell.shape == (1, 3, 3) assert state.pbc assert state.atomic_numbers.shape == (8,) - assert state.batch.shape == (8,) - assert torch.all(state.batch == 0) + assert state.system_idx.shape == (8,) + assert torch.all(state.system_idx == 0) def test_multiple_atoms_to_state(si_atoms: Atoms, device: torch.device) -> None: @@ -80,9 +81,10 @@ def test_multiple_atoms_to_state(si_atoms: Atoms, device: torch.device) -> None: assert state.cell.shape == (2, 3, 3) assert state.pbc assert state.atomic_numbers.shape == (16,) - assert state.batch.shape == (16,) + assert 
state.system_idx.shape == (16,) assert torch.all( - state.batch == torch.repeat_interleave(torch.tensor([0, 1], device=device), 8), + state.system_idx + == torch.repeat_interleave(torch.tensor([0, 1], device=device), 8), ) @@ -171,9 +173,10 @@ def test_multiple_phonopy_to_state(si_phonopy_atoms: Any, device: torch.device) assert state.cell.shape == (2, 3, 3) assert state.pbc assert state.atomic_numbers.shape == (16,) - assert state.batch.shape == (16,) + assert state.system_idx.shape == (16,) assert torch.all( - state.batch == torch.repeat_interleave(torch.tensor([0, 1], device=device), 8), + state.system_idx + == torch.repeat_interleave(torch.tensor([0, 1], device=device), 8), ) @@ -235,11 +238,11 @@ def test_state_round_trip( # Get the sim_state fixture dynamically using the name sim_state: ts.SimState = request.getfixturevalue(sim_state_name) to_format_fn, from_format_fn = conversion_functions - unique_batches = torch.unique(sim_state.batch) + unique_systems = torch.unique(sim_state.system_idx) # Convert to intermediate format intermediate_format = to_format_fn(sim_state) - assert len(intermediate_format) == len(unique_batches) + assert len(intermediate_format) == len(unique_systems) # Convert back to state round_trip_state: ts.SimState = from_format_fn(intermediate_format, device, dtype) @@ -248,7 +251,7 @@ def test_state_round_trip( assert torch.allclose(sim_state.positions, round_trip_state.positions) assert torch.allclose(sim_state.cell, round_trip_state.cell) assert torch.all(sim_state.atomic_numbers == round_trip_state.atomic_numbers) - assert torch.all(sim_state.batch == round_trip_state.batch) + assert torch.all(sim_state.system_idx == round_trip_state.system_idx) assert sim_state.pbc == round_trip_state.pbc if isinstance(intermediate_format[0], Atoms): diff --git a/tests/test_monte_carlo.py b/tests/test_monte_carlo.py index 880adefb..3be7787d 100644 --- a/tests/test_monte_carlo.py +++ b/tests/test_monte_carlo.py @@ -50,7 +50,7 @@ def 
test_generate_permutation( ): swaps = generate_swaps(batched_diverse_state, generator=generator) permutation = swaps_to_permutation(swaps, batched_diverse_state.n_atoms) - validate_permutation(permutation, batched_diverse_state.batch) + validate_permutation(permutation, batched_diverse_state.system_idx) def test_generate_swaps(batched_diverse_state: ts.SimState, generator: torch.Generator): @@ -64,9 +64,9 @@ def test_generate_swaps(batched_diverse_state: ts.SimState, generator: torch.Gen assert torch.all(swaps >= 0) assert torch.all(swaps < batched_diverse_state.n_atoms) - # Check swaps are within same batch - batch = batched_diverse_state.batch - assert torch.all(batch[swaps[:, 0]] == batch[swaps[:, 1]]) + # Check swaps are within same system + system_idx = batched_diverse_state.system_idx + assert torch.all(system_idx[swaps[:, 0]] == system_idx[swaps[:, 1]]) def test_swaps_to_permutation( @@ -95,7 +95,9 @@ def test_validate_permutation(batched_diverse_state: ts.SimState): # Valid permutation swaps = generate_swaps(batched_diverse_state) permutation = swaps_to_permutation(swaps, batched_diverse_state.n_atoms) - validate_permutation(permutation, batched_diverse_state.batch) # Should not raise + validate_permutation( + permutation, batched_diverse_state.system_idx + ) # Should not raise # Invalid permutation (swap between batches) invalid_perm = permutation.clone() @@ -105,7 +107,7 @@ def test_validate_permutation(batched_diverse_state: ts.SimState): invalid_perm[batched_diverse_state.n_atoms - 1] = 0 with pytest.raises(ValueError, match="Swaps must be between"): - validate_permutation(invalid_perm, batched_diverse_state.batch) + validate_permutation(invalid_perm, batched_diverse_state.system_idx) def test_monte_carlo( @@ -147,17 +149,17 @@ def test_monte_carlo( # Verify the state has changed after multiple steps assert not torch.allclose(current_state.positions, initial_positions) - # Verify batch assignments remain unchanged - assert torch.all(current_state.batch 
== batched_diverse_state.batch) + # Verify system_idx assignments remain unchanged + assert torch.all(current_state.system_idx == batched_diverse_state.system_idx) - # Verify atomic numbers distribution remains the same per batch - for batch_idx in torch.unique(current_state.batch): - batch_mask_orig = batched_diverse_state.batch == batch_idx - batch_mask_result = current_state.batch == batch_idx + # Verify atomic numbers distribution remains the same per system + for idx in torch.unique(current_state.system_idx): + system_mask_orig = batched_diverse_state.system_idx == idx + system_mask_result = current_state.system_idx == idx orig_counts = torch.bincount( - batched_diverse_state.atomic_numbers[batch_mask_orig] + batched_diverse_state.atomic_numbers[system_mask_orig] ) - result_counts = torch.bincount(current_state.atomic_numbers[batch_mask_result]) + result_counts = torch.bincount(current_state.atomic_numbers[system_mask_result]) assert torch.all(orig_counts == result_counts) diff --git a/tests/test_neighbors.py b/tests/test_neighbors.py index eebc7391..205b626a 100644 --- a/tests/test_neighbors.py +++ b/tests/test_neighbors.py @@ -37,7 +37,7 @@ def ase_to_torch_batch( - pos: Tensor of atomic positions. - cell: Tensor of unit cell vectors. - pbc: Tensor indicating periodic boundary conditions. - - batch: Tensor indicating the batch index for each atom. + - system_idx: Tensor indicating the system index for each atom. - n_atoms: Tensor containing the number of atoms in each structure. 
""" n_atoms = torch.tensor([len(atoms) for atoms in atoms_list], dtype=torch.long) @@ -49,17 +49,17 @@ def ase_to_torch_batch( pbc = torch.cat([torch.from_numpy(atoms.get_pbc()) for atoms in atoms_list]) stride = torch.cat((torch.tensor([0]), n_atoms.cumsum(0))) - batch = torch.zeros(pos.shape[0], dtype=torch.long) + system_idx = torch.zeros(pos.shape[0], dtype=torch.long) for ii, (st, end) in enumerate( zip(stride[:-1], stride[1:], strict=True) # noqa: RUF007 ): - batch[st:end] = ii + system_idx[st:end] = ii n_atoms = torch.Tensor(n_atoms[1:]).to(dtype=torch.long) return ( pos.to(dtype=dtype, device=device), cell.to(dtype=dtype, device=device), pbc.to(device=device), - batch.to(device=device), + system_idx.to(device=device), n_atoms.to(device=device), ) @@ -556,11 +556,11 @@ def test_neighbor_lists_time_and_memory( start_time = time.perf_counter() if nl_fn in [neighbors.torch_nl_n2, neighbors.torch_nl_linked_cell]: - batch = torch.zeros(n_atoms, dtype=torch.long, device=device) + system_idx = torch.zeros(n_atoms, dtype=torch.long, device=device) # Fix pbc tensor shape pbc = torch.tensor([[True, True, True]], device=device) mapping, mapping_batch, shifts_idx = nl_fn( - cutoff, pos, cell, pbc, batch, self_interaction=False + cutoff, pos, cell, pbc, system_idx, self_interaction=False ) else: mapping, shifts = nl_fn(positions=pos, cell=cell, pbc=True, cutoff=cutoff) diff --git a/tests/test_optimizers.py b/tests/test_optimizers.py index 87ff8697..a5bfa675 100644 --- a/tests/test_optimizers.py +++ b/tests/test_optimizers.py @@ -128,7 +128,7 @@ def test_fire_optimization( cell=ar_supercell_sim_state.cell.clone(), pbc=ar_supercell_sim_state.pbc, atomic_numbers=ar_supercell_sim_state.atomic_numbers.clone(), - batch=ar_supercell_sim_state.batch.clone(), + system_idx=ar_supercell_sim_state.system_idx.clone(), ) initial_state_positions = current_sim_state.positions.clone() @@ -229,7 +229,7 @@ def test_fire_ase_negative_power_branch( state = init_fn(ar_supercell_sim_state) # 
Save parameters from initial state - initial_dt_batch = state.dt.clone() # per-batch dt + initial_dt_batch = state.dt.clone() # per-system dt # Manipulate state to ensure P < 0 for the update_fn step # Ensure forces are non-trivial @@ -262,10 +262,10 @@ def test_fire_ase_negative_power_branch( # Assertions for velocity update in ASE P < 0 case: # v_after_mixing_is_0, then v_final = dt_new * F_at_power_calc expected_final_velocities = ( - expected_dt_val * forces_at_power_calc[updated_state.batch == 0] + expected_dt_val * forces_at_power_calc[updated_state.system_idx == 0] ) assert torch.allclose( - updated_state.velocities[updated_state.batch == 0], + updated_state.velocities[updated_state.system_idx == 0], expected_final_velocities, atol=1e-6, ) @@ -317,8 +317,8 @@ def test_fire_vv_negative_power_branch( # If P<0 branch was taken, velocities should be zeroed assert torch.allclose( - updated_state.velocities[updated_state.batch == 0], - torch.zeros_like(updated_state.velocities[updated_state.batch == 0]), + updated_state.velocities[updated_state.system_idx == 0], + torch.zeros_like(updated_state.velocities[updated_state.system_idx == 0]), atol=1e-7, ) @@ -345,7 +345,7 @@ def test_unit_cell_fire_optimization( cell=current_cell, pbc=ar_supercell_sim_state.pbc, atomic_numbers=ar_supercell_sim_state.atomic_numbers.clone(), - batch=ar_supercell_sim_state.batch.clone(), + system_idx=ar_supercell_sim_state.system_idx.clone(), ) initial_state_positions = current_sim_state.positions.clone() @@ -429,7 +429,7 @@ def test_cell_optimizer_init_with_dict_and_cell_factor( assert opt_state.forces is not None assert opt_state.stress is not None expected_cf_tensor = torch.full( - (opt_state.n_batches, 1, 1), + (opt_state.n_systems, 1, 1), float(cell_factor_val), # Ensure float for comparison if int is passed device=lj_model.device, dtype=lj_model.dtype, @@ -452,11 +452,11 @@ def test_cell_optimizer_init_cell_factor_none( ) -> None: """Test cell optimizer init_fn with 
cell_factor=None.""" init_fn, _ = optimizer_fn(model=lj_model, cell_factor=None) - # Ensure n_batches > 0 for cell_factor calculation from counts - assert ar_supercell_sim_state.n_batches > 0 + # Ensure n_systems > 0 for cell_factor calculation from counts + assert ar_supercell_sim_state.n_systems > 0 opt_state = init_fn(ar_supercell_sim_state) # Uses ts.SimState directly assert isinstance(opt_state, expected_state_type) - _, counts = torch.unique(ar_supercell_sim_state.batch, return_counts=True) + _, counts = torch.unique(ar_supercell_sim_state.system_idx, return_counts=True) expected_cf_tensor = counts.to(dtype=lj_model.dtype).view(-1, 1, 1) assert torch.allclose(opt_state.cell_factor, expected_cf_tensor) assert opt_state.energy is not None @@ -525,7 +525,7 @@ def test_frechet_cell_fire_optimization( cell=current_cell, pbc=ar_supercell_sim_state.pbc, atomic_numbers=ar_supercell_sim_state.atomic_numbers.clone(), - batch=ar_supercell_sim_state.batch.clone(), + system_idx=ar_supercell_sim_state.system_idx.clone(), ) initial_state_positions = current_sim_state.positions.clone() @@ -874,6 +874,6 @@ def energy_converged(current_energy: float, prev_energy: float) -> bool: # position only optimizations for step, energy_unit_cell in enumerate(individual_energies_unit_cell): assert abs(energy_unit_cell - individual_energies_fire[step]) < 1e-4, ( - f"Energy for batch {step} doesn't match position only optimization: " - f"batch={energy_unit_cell}, individual={individual_energies_fire[step]}" + f"Energy for system {step} doesn't match position only optimization: " + f"system={energy_unit_cell}, individual={individual_energies_fire[step]}" ) diff --git a/tests/test_runners.py b/tests/test_runners.py index 5d7201c9..cd1ff3db 100644 --- a/tests/test_runners.py +++ b/tests/test_runners.py @@ -149,7 +149,7 @@ def test_integrate_many_nvt( lj_model.dtype, ) trajectory_files = [ - tmp_path / f"nvt_{batch}.h5md" for batch in range(triple_state.n_batches) + tmp_path / 
f"nvt_{batch}.h5md" for batch in range(triple_state.n_systems) ] reporter = TrajectoryReporter( filenames=trajectory_files, @@ -231,7 +231,7 @@ def test_integrate_with_autobatcher_and_reporting( max_memory_scaler=260, ) trajectory_files = [ - tmp_path / f"nvt_{batch}.h5md" for batch in range(triple_state.n_batches) + tmp_path / f"nvt_{batch}.h5md" for batch in range(triple_state.n_systems) ] reporter = TrajectoryReporter( filenames=trajectory_files, @@ -340,7 +340,7 @@ def test_batched_optimize_fire( ) -> None: """Test batched FIRE optimization with LJ potential.""" trajectory_files = [ - tmp_path / f"nvt_{idx}.h5md" for idx in range(ar_double_sim_state.n_batches) + tmp_path / f"nvt_{idx}.h5md" for idx in range(ar_double_sim_state.n_systems) ] reporter = TrajectoryReporter( filenames=trajectory_files, @@ -414,7 +414,7 @@ def test_optimize_with_autobatcher_and_reporting( ) trajectory_files = [ - tmp_path / f"opt_{batch}.h5md" for batch in range(triple_state.n_batches) + tmp_path / f"opt_{batch}.h5md" for batch in range(triple_state.n_systems) ] reporter = TrajectoryReporter( filenames=trajectory_files, @@ -798,7 +798,7 @@ def test_readme_example(lj_model: LennardJonesModel, tmp_path: Path) -> None: # autobatcher=True, # disabled for CPU-based LJ model in test ) - assert relaxed_state.energy.shape == (final_state.n_batches,) + assert relaxed_state.energy.shape == (final_state.n_systems,) @pytest.fixture @@ -806,21 +806,21 @@ def mock_state() -> Callable: """Create a mock state for testing convergence functions.""" device = torch.device("cpu") dtype = torch.float64 - n_batches, n_atoms = 2, 8 + n_systems, n_atoms = 2, 8 torch.manual_seed(0) # deterministic forces class MockState: def __init__(self, *, include_cell_forces: bool = True) -> None: self.forces = torch.randn(n_atoms, 3, device=device, dtype=dtype) - self.batch = torch.repeat_interleave( - torch.arange(n_batches), n_atoms // n_batches + self.system_idx = torch.repeat_interleave( + torch.arange(n_systems), 
n_atoms // n_systems ) self.device = device self.dtype = dtype - self.n_batches = n_batches + self.n_systems = n_systems if include_cell_forces: self.cell_forces = torch.randn( - n_batches, 3, 3, device=device, dtype=dtype + n_systems, 3, 3, device=device, dtype=dtype ) return MockState @@ -858,7 +858,7 @@ def test_generate_force_convergence_fn( if has_cell_forces: ar_supercell_sim_state.cell_forces = torch.randn( - ar_supercell_sim_state.n_batches, + ar_supercell_sim_state.n_systems, 3, 3, device=ar_supercell_sim_state.device, @@ -877,7 +877,7 @@ def test_generate_force_convergence_fn( result = convergence_fn(state) assert isinstance(result, torch.Tensor) assert result.dtype == torch.bool - assert result.shape == (state.n_batches,) + assert result.shape == (state.n_systems,) def test_generate_force_convergence_fn_tolerance_ordering( @@ -888,7 +888,7 @@ def test_generate_force_convergence_fn_tolerance_ordering( ar_supercell_sim_state.forces = model_output["forces"] ar_supercell_sim_state.energy = model_output["energy"] ar_supercell_sim_state.cell_forces = torch.randn( - ar_supercell_sim_state.n_batches, + ar_supercell_sim_state.n_systems, 3, 3, device=ar_supercell_sim_state.device, @@ -926,26 +926,26 @@ def test_generate_force_convergence_fn_logic( ) -> None: """Test convergence logic with controlled force values.""" device, dtype = torch.device("cpu"), torch.float64 - n_batches, n_atoms = len(atomic_forces), 8 + n_systems, n_atoms = len(atomic_forces), 8 class ControlledMockState: def __init__(self) -> None: - self.n_batches = n_batches + self.n_systems = n_systems self.device, self.dtype = device, dtype - self.batch = torch.repeat_interleave( - torch.arange(n_batches), n_atoms // n_batches + self.system_idx = torch.repeat_interleave( + torch.arange(n_systems), n_atoms // n_systems ) - # Set specific force magnitudes per batch + # Set specific force magnitudes per system self.forces = torch.zeros(n_atoms, 3, device=device, dtype=dtype) - self.cell_forces = 
torch.zeros(n_batches, 3, 3, device=device, dtype=dtype) + self.cell_forces = torch.zeros(n_systems, 3, 3, device=device, dtype=dtype) - for batch_idx, (atomic_force, cell_force) in enumerate( + for system_idx, (atomic_force, cell_force) in enumerate( zip(atomic_forces, cell_forces, strict=False) ): - batch_mask = self.batch == batch_idx - self.forces[batch_mask, 0] = atomic_force - self.cell_forces[batch_idx, 0, 0] = cell_force + system_mask = self.system_idx == system_idx + self.forces[system_mask, 0] = atomic_force + self.cell_forces[system_idx, 0, 0] = cell_force state = ControlledMockState() convergence_fn = ts.generate_force_convergence_fn( diff --git a/tests/test_state.py b/tests/test_state.py index 1e5f325b..ea57dd3a 100644 --- a/tests/test_state.py +++ b/tests/test_state.py @@ -9,7 +9,7 @@ from torch_sim.state import ( DeformGradMixin, SimState, - _normalize_batch_indices, + _normalize_system_indices, _pop_states, _slice_state, concatenate_states, @@ -28,8 +28,13 @@ def test_infer_sim_state_property_scope(si_sim_state: ts.SimState) -> None: """Test inference of property scope.""" scope = infer_property_scope(si_sim_state) assert set(scope["global"]) == {"pbc"} - assert set(scope["per_atom"]) == {"positions", "masses", "atomic_numbers", "batch"} - assert set(scope["per_batch"]) == {"cell"} + assert set(scope["per_atom"]) == { + "positions", + "masses", + "atomic_numbers", + "system_idx", + } + assert set(scope["per_system"]) == {"cell"} def test_infer_md_state_property_scope(si_sim_state: ts.SimState) -> None: @@ -46,19 +51,19 @@ def test_infer_md_state_property_scope(si_sim_state: ts.SimState) -> None: "positions", "masses", "atomic_numbers", - "batch", + "system_idx", "forces", "momenta", } - assert set(scope["per_batch"]) == {"cell", "energy"} + assert set(scope["per_system"]) == {"cell", "energy"} def test_slice_substate( si_double_sim_state: ts.SimState, si_sim_state: ts.SimState ) -> None: """Test slicing a substate from the SimState.""" - for 
batch_index in range(2): - substate = _slice_state(si_double_sim_state, [batch_index]) + for system_index in range(2): + substate = _slice_state(si_double_sim_state, [system_index]) assert isinstance(substate, SimState) assert substate.positions.shape == (8, 3) assert substate.masses.shape == (8,) @@ -67,7 +72,7 @@ def test_slice_substate( assert torch.allclose(substate.masses, si_sim_state.masses) assert torch.allclose(substate.cell, si_sim_state.cell) assert torch.allclose(substate.atomic_numbers, si_sim_state.atomic_numbers) - assert torch.allclose(substate.batch, torch.zeros_like(substate.batch)) + assert torch.allclose(substate.system_idx, torch.zeros_like(substate.system_idx)) def test_slice_md_substate(si_double_sim_state: ts.SimState) -> None: @@ -77,8 +82,8 @@ def test_slice_md_substate(si_double_sim_state: ts.SimState) -> None: energy=torch.zeros((2,), device=si_double_sim_state.device), forces=torch.randn_like(si_double_sim_state.positions), ) - for batch_index in range(2): - substate = _slice_state(state, [batch_index]) + for system_index in range(2): + substate = _slice_state(state, [system_index]) assert isinstance(substate, MDState) assert substate.positions.shape == (8, 3) assert substate.masses.shape == (8,) @@ -101,10 +106,10 @@ def test_concatenate_two_si_states( assert concatenated.masses.shape == si_double_sim_state.masses.shape assert concatenated.cell.shape == si_double_sim_state.cell.shape assert concatenated.atomic_numbers.shape == si_double_sim_state.atomic_numbers.shape - assert concatenated.batch.shape == si_double_sim_state.batch.shape + assert concatenated.system_idx.shape == si_double_sim_state.system_idx.shape - # Check batch indices - expected_batch = torch.cat( + # Check system indices + expected_system_indices = torch.cat( [ torch.zeros( si_sim_state.n_atoms, dtype=torch.int64, device=si_sim_state.device @@ -114,12 +119,12 @@ def test_concatenate_two_si_states( ), ] ) - assert torch.all(concatenated.batch == expected_batch) + 
assert torch.all(concatenated.system_idx == expected_system_indices) - # Check that positions match (accounting for batch indices) - for batch_idx in range(2): - mask_concat = concatenated.batch == batch_idx - mask_double = si_double_sim_state.batch == batch_idx + # Check that positions match (accounting for system indices) + for system_idx in range(2): + mask_concat = concatenated.system_idx == system_idx + mask_double = si_double_sim_state.system_idx == system_idx assert torch.allclose( concatenated.positions[mask_concat], si_double_sim_state.positions[mask_double], @@ -143,22 +148,22 @@ def test_concatenate_si_and_fe_states( concatenated.masses.shape[0] == si_sim_state.masses.shape[0] + fe_supercell_sim_state.masses.shape[0] ) - assert concatenated.cell.shape[0] == 2 # One cell per batch + assert concatenated.cell.shape[0] == 2 # One cell per system - # Check batch indices + # Check system indices si_atoms = si_sim_state.n_atoms fe_atoms = fe_supercell_sim_state.n_atoms - expected_batch = torch.cat( + expected_system_indices = torch.cat( [ torch.zeros(si_atoms, dtype=torch.int64, device=si_sim_state.device), torch.ones(fe_atoms, dtype=torch.int64, device=fe_supercell_sim_state.device), ] ) - assert torch.all(concatenated.batch == expected_batch) + assert torch.all(concatenated.system_idx == expected_system_indices) - # check n_atoms_per_batch + # check n_atoms_per_system assert torch.all( - concatenated.n_atoms_per_batch + concatenated.n_atoms_per_system == torch.tensor( [si_sim_state.n_atoms, fe_supercell_sim_state.n_atoms], device=concatenated.device, @@ -192,22 +197,22 @@ def test_concatenate_double_si_and_fe_states( ) assert ( concatenated.cell.shape[0] == 3 - ) # One cell for each original batch (2 Si + 1 Ar) + ) # One cell for each original system (2 Si + 1 Ar) - # Check batch indices + # Check system indices fe_atoms = fe_supercell_sim_state.n_atoms - # The double Si state already has batches 0 and 1, so Ar should be batch 2 - expected_batch = torch.cat( 
+ # The double Si state already has systems 0 and 1, so Ar should be system 2 + expected_system_indices = torch.cat( [ - si_double_sim_state.batch, + si_double_sim_state.system_idx, torch.full( (fe_atoms,), 2, dtype=torch.int64, device=fe_supercell_sim_state.device ), ] ) - assert torch.all(concatenated.batch == expected_batch) - assert torch.unique(concatenated.batch).shape[0] == 3 + assert torch.all(concatenated.system_idx == expected_system_indices) + assert torch.unique(concatenated.system_idx).shape[0] == 3 # Check that we can slice back to the original states si_slice_0 = concatenated[0] @@ -223,14 +228,14 @@ def test_concatenate_double_si_and_fe_states( def test_split_state(si_double_sim_state: ts.SimState) -> None: """Test splitting a state into a list of states.""" states = si_double_sim_state.split() - assert len(states) == si_double_sim_state.n_batches + assert len(states) == si_double_sim_state.n_systems for state in states: assert isinstance(state, ts.SimState) assert state.positions.shape == (8, 3) assert state.masses.shape == (8,) assert state.cell.shape == (1, 3, 3) assert state.atomic_numbers.shape == (8,) - assert torch.allclose(state.batch, torch.zeros_like(state.batch)) + assert torch.allclose(state.system_idx, torch.zeros_like(state.system_idx)) def test_split_many_states( @@ -248,7 +253,7 @@ def test_split_many_states( assert torch.allclose(sub_state.masses, state.masses) assert torch.allclose(sub_state.cell, state.cell) assert torch.allclose(sub_state.atomic_numbers, state.atomic_numbers) - assert torch.allclose(sub_state.batch, state.batch) + assert torch.allclose(sub_state.system_idx, state.system_idx) assert len(states) == 3 @@ -276,7 +281,7 @@ def test_pop_states( assert kept_state.masses.shape == (len_kept,) assert kept_state.cell.shape == (2, 3, 3) assert kept_state.atomic_numbers.shape == (len_kept,) - assert kept_state.batch.shape == (len_kept,) + assert kept_state.system_idx.shape == (len_kept,) def 
test_initialize_state_from_structure( @@ -337,8 +342,8 @@ def test_state_pop_method( assert torch.allclose(popped_states[0].positions, ar_supercell_sim_state.positions) # Verify the original state was modified - assert concatenated.n_batches == 2 - assert torch.unique(concatenated.batch).tolist() == [0, 1] + assert concatenated.n_systems == 2 + assert torch.unique(concatenated.system_idx).tolist() == [0, 1] # Test popping multiple batches multi_state = concatenate_states(states) @@ -348,8 +353,8 @@ def test_state_pop_method( assert torch.allclose(popped_multi[1].positions, fe_supercell_sim_state.positions) # Verify the original multi-state was modified - assert multi_state.n_batches == 1 - assert torch.unique(multi_state.batch).tolist() == [0] + assert multi_state.n_systems == 1 + assert torch.unique(multi_state.system_idx).tolist() == [0] assert torch.allclose(multi_state.positions, ar_supercell_sim_state.positions) @@ -367,19 +372,19 @@ def test_state_getitem( single_state = concatenated[1] assert isinstance(single_state, SimState) assert torch.allclose(single_state.positions, ar_supercell_sim_state.positions) - assert single_state.n_batches == 1 + assert single_state.n_systems == 1 # Test list indexing multi_state = concatenated[[0, 2]] assert isinstance(multi_state, SimState) - assert multi_state.n_batches == 2 + assert multi_state.n_systems == 2 assert torch.allclose(multi_state[0].positions, si_sim_state.positions) assert torch.allclose(multi_state[1].positions, fe_supercell_sim_state.positions) # Test slice indexing slice_state = concatenated[1:3] assert isinstance(slice_state, SimState) - assert slice_state.n_batches == 2 + assert slice_state.n_systems == 2 assert torch.allclose(slice_state[0].positions, ar_supercell_sim_state.positions) assert torch.allclose(slice_state[1].positions, fe_supercell_sim_state.positions) @@ -391,67 +396,67 @@ def test_state_getitem( # Test step in slice step_state = concatenated[::2] assert isinstance(step_state, SimState) - 
assert step_state.n_batches == 2 + assert step_state.n_systems == 2 assert torch.allclose(step_state[0].positions, si_sim_state.positions) assert torch.allclose(step_state[1].positions, fe_supercell_sim_state.positions) full_state = concatenated[:] assert torch.allclose(full_state.positions, concatenated.positions) # Verify original state is unchanged - assert concatenated.n_batches == 3 + assert concatenated.n_systems == 3 -def test_normalize_batch_indices(si_double_sim_state: ts.SimState) -> None: - """Test the _normalize_batch_indices utility method.""" +def test_normalize_system_indices(si_double_sim_state: ts.SimState) -> None: + """Test the _normalize_system_indices utility method.""" state = si_double_sim_state # State with 2 batches - n_batches = state.n_batches + n_systems = state.n_systems device = state.device # Test integer indexing - assert _normalize_batch_indices(0, n_batches, device).tolist() == [0] - assert _normalize_batch_indices(1, n_batches, device).tolist() == [1] + assert _normalize_system_indices(0, n_systems, device).tolist() == [0] + assert _normalize_system_indices(1, n_systems, device).tolist() == [1] # Test negative integer indexing - assert _normalize_batch_indices(-1, n_batches, device).tolist() == [1] - assert _normalize_batch_indices(-2, n_batches, device).tolist() == [0] + assert _normalize_system_indices(-1, n_systems, device).tolist() == [1] + assert _normalize_system_indices(-2, n_systems, device).tolist() == [0] # Test list indexing - assert _normalize_batch_indices([0, 1], n_batches, device).tolist() == [0, 1] + assert _normalize_system_indices([0, 1], n_systems, device).tolist() == [0, 1] # Test list with negative indices - assert _normalize_batch_indices([0, -1], n_batches, device).tolist() == [0, 1] - assert _normalize_batch_indices([-2, -1], n_batches, device).tolist() == [0, 1] + assert _normalize_system_indices([0, -1], n_systems, device).tolist() == [0, 1] + assert _normalize_system_indices([-2, -1], n_systems, 
device).tolist() == [0, 1] # Test slice indexing - indices = _normalize_batch_indices(slice(0, 2), n_batches, device) + indices = _normalize_system_indices(slice(0, 2), n_systems, device) assert isinstance(indices, torch.Tensor) assert torch.all(indices == torch.tensor([0, 1], device=state.device)) # Test slice with negative indices - indices = _normalize_batch_indices(slice(-2, None), n_batches, device) + indices = _normalize_system_indices(slice(-2, None), n_systems, device) assert isinstance(indices, torch.Tensor) assert torch.all(indices == torch.tensor([0, 1], device=state.device)) # Test slice with step - indices = _normalize_batch_indices(slice(0, 2, 2), n_batches, device) + indices = _normalize_system_indices(slice(0, 2, 2), n_systems, device) assert isinstance(indices, torch.Tensor) assert torch.all(indices == torch.tensor([0], device=state.device)) # Test tensor indexing tensor_indices = torch.tensor([0, 1], device=state.device) - indices = _normalize_batch_indices(tensor_indices, n_batches, device) + indices = _normalize_system_indices(tensor_indices, n_systems, device) assert isinstance(indices, torch.Tensor) assert torch.all(indices == tensor_indices) # Test tensor with negative indices tensor_indices = torch.tensor([0, -1], device=state.device) - indices = _normalize_batch_indices(tensor_indices, n_batches, device) + indices = _normalize_system_indices(tensor_indices, n_systems, device) assert isinstance(indices, torch.Tensor) assert torch.all(indices == torch.tensor([0, 1], device=state.device)) # Test error for unsupported type try: - _normalize_batch_indices((0, 1), n_batches, device) # Tuple is not supported + _normalize_system_indices((0, 1), n_systems, device) # Tuple is not supported raise ValueError("Should have raised TypeError") except TypeError: pass @@ -601,7 +606,9 @@ def test_deform_grad_batched(device: torch.device) -> None: atomic_numbers=torch.ones(n_atoms * batch_size, device=device, dtype=torch.long), velocities=torch.randn(n_atoms 
* batch_size, 3, device=device), reference_cell=reference_cell, - batch=torch.repeat_interleave(torch.arange(batch_size, device=device), n_atoms), + system_idx=torch.repeat_interleave( + torch.arange(batch_size, device=device), n_atoms + ), ) deform_grad = state.deform_grad() @@ -611,3 +618,28 @@ def test_deform_grad_batched(device: torch.device) -> None: for i in range(batch_size): expected = expected_factors[i] * torch.eye(3, device=device) assert torch.allclose(deform_grad[i], expected) + + +def test_deprecated_batch_properties_equal_to_new_system_properties( + device: torch.device, +) -> None: + """Test that deprecated batch properties are equal to new system properties. + + This tests that the rename from batch to system is not breaking anything.""" + state = SimState( + positions=torch.randn(10, 3, device=device), + masses=torch.ones(10, device=device), + cell=torch.eye(3, device=device).unsqueeze(0).repeat(2, 1, 1), + pbc=True, + atomic_numbers=torch.ones(10, device=device, dtype=torch.long), + system_idx=torch.repeat_interleave(torch.arange(2, device=device), 5), + ) + assert state.batch is state.system_idx + assert state.n_batches == state.n_systems + assert torch.allclose(state.n_atoms_per_batch, state.n_atoms_per_system) + + # now test that assigning the old .batch property behaves the same + new_system_idx = torch.arange(4, device=device) + state.batch = new_system_idx + assert torch.allclose(state.system_idx, new_system_idx) + assert torch.allclose(state.batch, new_system_idx) diff --git a/tests/test_trajectory.py b/tests/test_trajectory.py index 3a662242..e003f7a7 100644 --- a/tests/test_trajectory.py +++ b/tests/test_trajectory.py @@ -32,7 +32,7 @@ def random_state() -> MDState: ), cell=torch.unsqueeze(torch.eye(3) * 10.0, 0), atomic_numbers=torch.ones(10, dtype=torch.int32), - batch=torch.zeros(10, dtype=torch.int32), + system_idx=torch.zeros(10, dtype=torch.int32), pbc=True, ) @@ -678,9 +678,9 @@ def test_multi_batch_reporter( # Check that each 
trajectory has the correct number of atoms # (should be half of the total in the double state) - atoms_per_batch = si_double_sim_state.positions.shape[0] // 2 - assert traj0.get_array("positions").shape[1] == atoms_per_batch - assert traj1.get_array("positions").shape[1] == atoms_per_batch + atoms_per_system = si_double_sim_state.positions.shape[0] // 2 + assert traj0.get_array("positions").shape[1] == atoms_per_system + assert traj1.get_array("positions").shape[1] == atoms_per_system # Check property data assert "ones" in traj0.array_registry @@ -698,11 +698,11 @@ def test_property_model_consistency( """Test property models are consistent for single and multi-batch cases.""" # Create reporters for single and multi-batch cases single_reporters = [] - for batch_idx in range(2): + for system_idx in range(2): # Extract single batch states - single_state = si_double_sim_state[batch_idx] + single_state = si_double_sim_state[system_idx] reporter = TrajectoryReporter( - tmp_path / f"single_{batch_idx}.hdf5", + tmp_path / f"single_{system_idx}.hdf5", state_frequency=1, prop_calculators=prop_calculators, ) @@ -710,7 +710,7 @@ def test_property_model_consistency( reporter.report(single_state, 0) reporter.close() single_reporters.append( - TorchSimTrajectory(tmp_path / f"single_{batch_idx}.hdf5", mode="r") + TorchSimTrajectory(tmp_path / f"single_{system_idx}.hdf5", mode="r") ) # Create multi-batch reporter @@ -727,14 +727,14 @@ def test_property_model_consistency( TorchSimTrajectory(tmp_path / "multi_1.hdf5", mode="r"), ] - # Compare property values between single and multi-batch approaches - for batch_idx in range(2): - single_ke = single_reporters[batch_idx].get_array("ones")[0] - multi_ke = multi_trajectories[batch_idx].get_array("ones")[0] + # Compare property values between single and multi-system approaches + for system_idx in range(2): + single_ke = single_reporters[system_idx].get_array("ones")[0] + multi_ke = multi_trajectories[system_idx].get_array("ones")[0] 
assert torch.allclose(torch.tensor(single_ke), torch.tensor(multi_ke)) - single_com = single_reporters[batch_idx].get_array("center_of_mass")[0] - multi_com = multi_trajectories[batch_idx].get_array("center_of_mass")[0] + single_com = single_reporters[system_idx].get_array("center_of_mass")[0] + multi_com = multi_trajectories[system_idx].get_array("center_of_mass")[0] assert torch.allclose(torch.tensor(single_com), torch.tensor(multi_com)) # Close all trajectories @@ -767,12 +767,12 @@ def energy_calculator(state: ts.SimState, model: torch.nn.Module) -> torch.Tenso reporter.close() # Verify properties were returned - assert len(props) == 2 # One dict per batch - for batch_props in props: - assert set(batch_props) == {"energy"} - assert isinstance(batch_props["energy"], torch.Tensor) - assert batch_props["energy"].shape == (1,) - assert batch_props["energy"] == pytest.approx(49.4150) + assert len(props) == 2 # One dict per system + for system_props in props: + assert set(system_props) == {"energy"} + assert isinstance(system_props["energy"], torch.Tensor) + assert system_props["energy"].shape == (1,) + assert system_props["energy"] == pytest.approx(49.4150) # Verify property was calculated correctly trajectories = [ @@ -780,21 +780,21 @@ def energy_calculator(state: ts.SimState, model: torch.nn.Module) -> torch.Tenso TorchSimTrajectory(tmp_path / "model_1.hdf5", mode="r"), ] - for batch_idx, trajectory in enumerate(trajectories): + for system_idx, trajectory in enumerate(trajectories): # Get the property value from file file_energy = trajectory.get_array("energy")[0] - batch_props = props[batch_idx] + system_props = props[system_idx] # Calculate expected value - substate = si_double_sim_state[batch_idx] + substate = si_double_sim_state[system_idx] expected = lj_model(substate)["energy"] # Compare file contents with expected np.testing.assert_allclose(file_energy, expected) # Compare returned properties with expected - 
np.testing.assert_allclose(batch_props["energy"], expected) + np.testing.assert_allclose(system_props["energy"], expected) # Compare returned properties with file contents - np.testing.assert_allclose(batch_props["energy"], file_energy) + np.testing.assert_allclose(system_props["energy"], file_energy) trajectory.close() diff --git a/tests/test_transforms.py b/tests/test_transforms.py index 879a70a4..4c05e658 100644 --- a/tests/test_transforms.py +++ b/tests/test_transforms.py @@ -250,8 +250,8 @@ def test_pbc_wrap_batched_orthorhombic(si_double_sim_state: ts.SimState) -> None # Modify a specific atom's position in each batch to be outside the cell # Get the first atom in each batch - batch_0_mask = state.batch == 0 - batch_1_mask = state.batch == 1 + batch_0_mask = state.system_idx == 0 + batch_1_mask = state.system_idx == 1 # Get current cell size (assume cubic for simplicity) cell_size = state.cell[0, 0, 0] @@ -268,7 +268,9 @@ def test_pbc_wrap_batched_orthorhombic(si_double_sim_state: ts.SimState) -> None test_positions[idx1, 0] = -0.5 # Apply wrapping - wrapped = tst.pbc_wrap_batched(test_positions, cell=state.cell, batch=state.batch) + wrapped = tst.pbc_wrap_batched( + test_positions, cell=state.cell, system_idx=state.system_idx + ) # Check first modified atom is properly wrapped assert wrapped[idx0, 0] < cell_size @@ -317,7 +319,7 @@ def test_pbc_wrap_batched_triclinic(device: torch.device) -> None: cell = torch.stack([cell1, cell2]) # Apply wrapping - wrapped = tst.pbc_wrap_batched(positions, cell=cell, batch=batch) + wrapped = tst.pbc_wrap_batched(positions, cell=cell, system_idx=batch) # Calculate expected result for first atom (using original algorithm for verification) expected1 = tst.pbc_wrap_general(positions[0:1], cell1) @@ -343,11 +345,11 @@ def test_pbc_wrap_batched_edge_case(device: torch.device) -> None: device=device, ) - # Create batch indices - batch = torch.tensor([0, 1], device=device) + # Create system indices + system_idx = torch.tensor([0, 
1], device=device) # Apply wrapping - wrapped = tst.pbc_wrap_batched(positions, cell=cell, batch=batch) + wrapped = tst.pbc_wrap_batched(positions, cell=cell, system_idx=system_idx) # Expected results (wrapping to 0.0 rather than 2.0) expected = torch.tensor( @@ -367,12 +369,12 @@ def test_pbc_wrap_batched_invalid_inputs(device: torch.device) -> None: # Valid inputs for reference positions = torch.ones(4, 3, device=device) cell = torch.stack([torch.eye(3, device=device)] * 2) - batch = torch.tensor([0, 0, 1, 1], device=device) + system_idx = torch.tensor([0, 0, 1, 1], device=device) # Test integer tensors with pytest.raises(TypeError): tst.pbc_wrap_batched( - torch.ones(4, 3, dtype=torch.int64, device=device), cell, batch + torch.ones(4, 3, dtype=torch.int64, device=device), cell, system_idx ) # Test dimension mismatch - positions @@ -380,15 +382,15 @@ def test_pbc_wrap_batched_invalid_inputs(device: torch.device) -> None: tst.pbc_wrap_batched( torch.ones(4, 2, device=device), # Wrong dimension (2 instead of 3) cell, - batch, + system_idx, ) - # Test mismatch between batch indices and cell + # Test mismatch between system indices and cell with pytest.raises(ValueError): tst.pbc_wrap_batched( positions, torch.stack([torch.eye(3, device=device)] * 3), # 3 cell but only 2 batches - batch, + system_idx, ) @@ -399,40 +401,42 @@ def test_pbc_wrap_batched_multi_atom(si_double_sim_state: ts.SimState) -> None: # Get a copy of positions to modify test_positions = state.positions.clone() - # Move all atoms of the first batch outside the cell in +x - batch_0_mask = state.batch == 0 + # Move all atoms of the first system outside the cell in +x + system_0_mask = state.system_idx == 0 cell_size_x = state.cell[0, 0, 0].item() - test_positions[batch_0_mask, 0] += cell_size_x + test_positions[system_0_mask, 0] += cell_size_x - # Move all atoms of the second batch outside the cell in -y - batch_1_mask = state.batch == 1 + # Move all atoms of the second system outside the cell in -y + 
system_1_mask = state.system_idx == 1 cell_size_y = state.cell[0, 1, 1].item() - test_positions[batch_1_mask, 1] -= cell_size_y + test_positions[system_1_mask, 1] -= cell_size_y # Apply wrapping - wrapped = tst.pbc_wrap_batched(test_positions, cell=state.cell, batch=state.batch) + wrapped = tst.pbc_wrap_batched( + test_positions, cell=state.cell, system_idx=state.system_idx + ) # Check all positions are within the cell boundaries - for b in range(2): # For each batch - batch_mask = state.batch == b + for b in range(2): # For each system + system_mask = state.system_idx == b # Check x coordinates - assert torch.all(wrapped[batch_mask, 0] >= 0) - assert torch.all(wrapped[batch_mask, 0] < state.cell[b, 0, 0]) + assert torch.all(wrapped[system_mask, 0] >= 0) + assert torch.all(wrapped[system_mask, 0] < state.cell[b, 0, 0]) # Check y coordinates - assert torch.all(wrapped[batch_mask, 1] >= 0) - assert torch.all(wrapped[batch_mask, 1] < state.cell[b, 1, 1]) + assert torch.all(wrapped[system_mask, 1] >= 0) + assert torch.all(wrapped[system_mask, 1] < state.cell[b, 1, 1]) # Check z coordinates - assert torch.all(wrapped[batch_mask, 2] >= 0) - assert torch.all(wrapped[batch_mask, 2] < state.cell[b, 2, 2]) + assert torch.all(wrapped[system_mask, 2] >= 0) + assert torch.all(wrapped[system_mask, 2] < state.cell[b, 2, 2]) def test_pbc_wrap_batched_preserves_relative_positions( si_double_sim_state: ts.SimState, ) -> None: - """Test that relative positions within each batch are preserved after wrapping.""" + """Test that relative positions within each system are preserved after wrapping.""" state = si_double_sim_state # Get a copy of positions @@ -443,20 +447,22 @@ def test_pbc_wrap_batched_preserves_relative_positions( test_positions += torch.tensor([10.0, 15.0, 20.0], device=state.device) # Apply wrapping - wrapped = tst.pbc_wrap_batched(test_positions, cell=state.cell, batch=state.batch) + wrapped = tst.pbc_wrap_batched( + test_positions, cell=state.cell, 
system_idx=state.system_idx + ) - # Check that relative positions within each batch are preserved + # Check that relative positions within each system are preserved for b in range(2): # For each batch - batch_mask = state.batch == b + system_idx_mask = state.system_idx == b # Calculate pairwise distances before wrapping - atoms_in_batch = torch.sum(batch_mask).item() + atoms_in_batch = torch.sum(system_idx_mask).item() for n_atoms in range(atoms_in_batch - 1): for j in range(n_atoms + 1, atoms_in_batch): # Get the indices of atoms i and j in this batch - batch_indices = torch.where(batch_mask)[0] - idx_i = batch_indices[n_atoms] - idx_j = batch_indices[j] + system_indices = torch.where(system_idx_mask)[0] + idx_i = system_indices[n_atoms] + idx_j = system_indices[j] # Original vector from i to j orig_vec = ( @@ -839,11 +845,11 @@ def test_get_fractional_coordinates_batched() -> None: [[1.0, 1.0, 1.0], [2.0, 0.0, 0.0]], device=device, dtype=dtype ) - # Test single batch case (should work) - cell_single_batch = torch.tensor( + # Test single system case (should work) + cell_single_system = torch.tensor( [[[4.0, 0.0, 0.0], [0.0, 4.0, 0.0], [0.0, 0.0, 4.0]]], device=device, dtype=dtype ) - frac_batched = tst.get_fractional_coordinates(positions, cell_single_batch) + frac_batched = tst.get_fractional_coordinates(positions, cell_single_system) # Compare with 2D case cell_2d = torch.tensor( @@ -852,11 +858,11 @@ def test_get_fractional_coordinates_batched() -> None: frac_2d = tst.get_fractional_coordinates(positions, cell_2d) assert torch.allclose(frac_batched, frac_2d), ( - "Single batch case should produce same result as 2D case" + "Single system case should produce same result as 2D case" ) - # Test multi-batch case (should raise NotImplementedError) - cell_multi_batch = torch.tensor( + # Test multi-system case (should raise NotImplementedError) + cell_multi_system = torch.tensor( [ [[4.0, 0.0, 0.0], [0.0, 4.0, 0.0], [0.0, 0.0, 4.0]], [[5.0, 0.0, 0.0], [0.0, 5.0, 0.0], 
[0.0, 0.0, 5.0]], @@ -865,8 +871,8 @@ def test_get_fractional_coordinates_batched() -> None: dtype=dtype, ) - with pytest.raises(NotImplementedError, match="Multiple batched cell tensors"): - tst.get_fractional_coordinates(positions, cell_multi_batch) + with pytest.raises(NotImplementedError, match="Multiple system cell tensors"): + tst.get_fractional_coordinates(positions, cell_multi_system) @pytest.mark.parametrize( diff --git a/torch_sim/autobatching.py b/torch_sim/autobatching.py index d436076a..9cd55673 100644 --- a/torch_sim/autobatching.py +++ b/torch_sim/autobatching.py @@ -234,7 +234,7 @@ def measure_model_memory_forward(state: SimState, model: ModelInterface) -> floa print( # noqa: T201 "Model Memory Estimation: Running forward pass on state with " - f"{state.n_atoms} atoms and {state.n_batches} batches.", + f"{state.n_atoms} atoms and {state.n_systems} systems.", ) # Clear GPU memory torch.cuda.synchronize() @@ -293,8 +293,8 @@ def determine_max_batch_size( sizes.append(next_size) for i in range(len(sizes)): - n_batches = sizes[i] - concat_state = concatenate_states([state] * n_batches) + n_systems = sizes[i] + concat_state = concatenate_states([state] * n_systems) try: measure_model_memory_forward(concat_state, model) @@ -343,7 +343,7 @@ def calculate_memory_scaler( # Calculate memory scaling factor based on atom count and density metric = calculate_memory_scaler(state, memory_scales_with="n_atoms_x_density") """ - if state.n_batches > 1: + if state.n_systems > 1: return sum(calculate_memory_scaler(s, memory_scales_with) for s in state.split()) if memory_scales_with == "n_atoms": return state.n_atoms @@ -405,8 +405,8 @@ def estimate_max_memory_scaler( print( # noqa: T201 "Model Memory Estimation: Estimating memory from worst case of " f"largest and smallest system. 
Largest system has {max_state.n_atoms} atoms " - f"and {max_state.n_batches} batches, and smallest system has " - f"{min_state.n_atoms} atoms and {min_state.n_batches} batches.", + f"and {max_state.n_systems} batches, and smallest system has " + f"{min_state.n_atoms} atoms and {min_state.n_systems} batches.", ) min_state_max_batches = determine_max_batch_size(min_state, model, **kwargs) max_state_max_batches = determine_max_batch_size(max_state, model, **kwargs) @@ -428,7 +428,7 @@ class BinningAutoBatcher: Attributes: model (ModelInterface): Model used for memory estimation and processing. memory_scales_with (str): Metric type used for memory estimation. - max_memory_scaler (float): Maximum memory metric allowed per batch. + max_memory_scaler (float): Maximum memory metric allowed per system. max_atoms_to_try (int): Maximum number of atoms to try when estimating memory. return_indices (bool): Whether to return original indices with batches. state_slices (list[SimState]): Individual states to be batched. @@ -477,7 +477,7 @@ def __init__( - "n_atoms": Uses only atom count - "n_atoms_x_density": Uses atom count multiplied by number density Defaults to "n_atoms_x_density". - max_memory_scaler (float | None): Maximum metric value allowed per batch. If + max_memory_scaler (float | None): Maximum metric value allowed per system. If None, will be automatically estimated. Defaults to None. return_indices (bool): Whether to return original indices along with batches. Defaults to False. @@ -716,7 +716,7 @@ class InFlightAutoBatcher: Attributes: model (ModelInterface): Model used for memory estimation and processing. memory_scales_with (str): Metric type used for memory estimation. - max_memory_scaler (float): Maximum memory metric allowed per batch. + max_memory_scaler (float): Maximum memory metric allowed per system. max_atoms_to_try (int): Maximum number of atoms to try when estimating memory. return_indices (bool): Whether to return original indices with batches. 
max_iterations (int | None): Maximum number of iterations per state. @@ -776,7 +776,7 @@ def __init__( - "n_atoms": Uses only atom count - "n_atoms_x_density": Uses atom count multiplied by number density Defaults to "n_atoms_x_density". - max_memory_scaler (float | None): Maximum metric value allowed per batch. + max_memory_scaler (float | None): Maximum metric value allowed per system. If None, will be automatically estimated. Defaults to None. return_indices (bool): Whether to return original indices along with batches. Defaults to False. @@ -933,13 +933,13 @@ def _get_first_batch(self) -> SimState: # if max_metric is not set, estimate it has_max_metric = bool(self.max_memory_scaler) if not has_max_metric: - n_batches = determine_max_batch_size( + n_systems = determine_max_batch_size( first_state, self.model, max_atoms=self.max_atoms_to_try, scale_factor=self.memory_scaling_factor, ) - self.max_memory_scaler = n_batches * first_metric * 0.8 + self.max_memory_scaler = n_systems * first_metric * 0.8 states = self._get_next_states() @@ -971,7 +971,7 @@ def next_batch( # noqa: C901 for the first call. Contains shape information specific to the SimState instance. convergence_tensor (torch.Tensor | None): Boolean tensor with shape - [n_batches] indicating which states have converged (True) or not + [n_systems] indicating which states have converged (True) or not (False). Should be None only for the first call. 
Returns: @@ -1019,14 +1019,14 @@ def next_batch( # noqa: C901 # assert statements helpful for debugging, should be moved to validate fn # the first two are most important - if len(convergence_tensor) != updated_state.n_batches: - raise ValueError(f"{len(convergence_tensor)=} != {updated_state.n_batches=}") + if len(convergence_tensor) != updated_state.n_systems: + raise ValueError(f"{len(convergence_tensor)=} != {updated_state.n_systems=}") if len(self.current_idx) != len(self.current_scalers): raise ValueError(f"{len(self.current_idx)=} != {len(self.current_scalers)=}") if len(convergence_tensor.shape) != 1: raise ValueError(f"{len(convergence_tensor.shape)=} != 1") - if updated_state.n_batches <= 0: - raise ValueError(f"{updated_state.n_batches=} <= 0") + if updated_state.n_systems <= 0: + raise ValueError(f"{updated_state.n_systems=} <= 0") # Increment attempt counters and check for max attempts in a single loop for cur_idx, abs_idx in enumerate(self.current_idx): @@ -1057,7 +1057,7 @@ def next_batch( # noqa: C901 ) # concatenate remaining state with next states - if updated_state.n_batches > 0: + if updated_state.n_systems > 0: next_states = [updated_state, *next_states] next_batch = concatenate_states(next_states) diff --git a/torch_sim/integrators/md.py b/torch_sim/integrators/md.py index 8640d965..ce15877d 100644 --- a/torch_sim/integrators/md.py +++ b/torch_sim/integrators/md.py @@ -21,17 +21,17 @@ class MDState(SimState): Attributes: positions (torch.Tensor): Particle positions [n_particles, n_dim] momenta (torch.Tensor): Particle momenta [n_particles, n_dim] - energy (torch.Tensor): Total energy of the system [n_batches] + energy (torch.Tensor): Total energy of the system [n_systems] forces (torch.Tensor): Forces on particles [n_particles, n_dim] masses (torch.Tensor): Particle masses [n_particles] - cell (torch.Tensor): Simulation cell matrix [n_batches, n_dim, n_dim] + cell (torch.Tensor): Simulation cell matrix [n_systems, n_dim, n_dim] pbc (bool): 
Whether to use periodic boundary conditions - batch (torch.Tensor): Batch indices [n_particles] + system_idx (torch.Tensor): System indices [n_particles] atomic_numbers (torch.Tensor): Atomic numbers [n_particles] Properties: velocities (torch.Tensor): Particle velocities [n_particles, n_dim] - n_batches (int): Number of independent systems in the batch + n_systems (int): Number of independent systems in the batch device (torch.device): Device on which tensors are stored dtype (torch.dtype): Data type of tensors """ @@ -51,7 +51,7 @@ def velocities(self) -> torch.Tensor: def calculate_momenta( positions: torch.Tensor, masses: torch.Tensor, - batch: torch.Tensor, + system_idx: torch.Tensor, kT: torch.Tensor | float, seed: int | None = None, ) -> torch.Tensor: @@ -64,8 +64,8 @@ def calculate_momenta( Args: positions (torch.Tensor): Particle positions [n_particles, n_dim] masses (torch.Tensor): Particle masses [n_particles] - batch (torch.Tensor): Batch indices [n_particles] - kT (torch.Tensor): Temperature in energy units [n_batches] + system_idx (torch.Tensor): System indices [n_particles] + kT (torch.Tensor): Temperature in energy units [n_systems] seed (int, optional): Random seed for reproducibility. Defaults to None. 
Returns: @@ -79,32 +79,32 @@ def calculate_momenta( generator.manual_seed(seed) if isinstance(kT, torch.Tensor) and len(kT.shape) > 0: - # kT is a tensor with shape (n_batches,) - kT = kT[batch] + # kT is a tensor with shape (n_systems,) + kT = kT[system_idx] # Generate random momenta from normal distribution momenta = torch.randn( positions.shape, device=device, dtype=dtype, generator=generator ) * torch.sqrt(masses * kT).unsqueeze(-1) - batchwise_momenta = torch.zeros( - (batch[-1] + 1, momenta.shape[1]), device=device, dtype=dtype + systemwise_momenta = torch.zeros( + (system_idx[-1] + 1, momenta.shape[1]), device=device, dtype=dtype ) - # create 3 copies of batch - batch_3 = batch.view(-1, 1).repeat(1, 3) - bincount = torch.bincount(batch) + # create 3 copies of system_idx + system_idx_3 = system_idx.view(-1, 1).repeat(1, 3) + bincount = torch.bincount(system_idx) mean_momenta = torch.scatter_reduce( - batchwise_momenta, + systemwise_momenta, dim=0, - index=batch_3, + index=system_idx_3, src=momenta, reduce="sum", ) / bincount.view(-1, 1) return torch.where( torch.repeat_interleave(bincount > 1, bincount).view(-1, 1), - momenta - mean_momenta[batch], + momenta - mean_momenta[system_idx], momenta, ) @@ -118,7 +118,7 @@ def momentum_step(state: MDState, dt: torch.Tensor) -> MDState: Args: state (MDState): Current system state containing forces and momenta - dt (torch.Tensor): Integration timestep, either scalar or with shape [n_batches] + dt (torch.Tensor): Integration timestep, either scalar or with shape [n_systems] Returns: MDState: Updated state with new momenta after force application @@ -138,7 +138,7 @@ def position_step(state: MDState, dt: torch.Tensor) -> MDState: Args: state (MDState): Current system state containing positions and velocities - dt (torch.Tensor): Integration timestep, either scalar or with shape [n_batches] + dt (torch.Tensor): Integration timestep, either scalar or with shape [n_systems] Returns: MDState: Updated state with new positions 
after propagation @@ -147,9 +147,9 @@ def position_step(state: MDState, dt: torch.Tensor) -> MDState: new_positions = state.positions + state.velocities * dt if state.pbc: - # Split positions and cells by batch + # Split positions and cells by system new_positions = transforms.pbc_wrap_batched( - new_positions, state.cell, state.batch + new_positions, state.cell, state.system_idx ) state.positions = new_positions diff --git a/torch_sim/integrators/npt.py b/torch_sim/integrators/npt.py index eb7b1f4d..27b78c8b 100644 --- a/torch_sim/integrators/npt.py +++ b/torch_sim/integrators/npt.py @@ -31,25 +31,25 @@ class NPTLangevinState(SimState): Attributes: positions (torch.Tensor): Particle positions [n_particles, n_dim] velocities (torch.Tensor): Particle velocities [n_particles, n_dim] - energy (torch.Tensor): Energy of the system [n_batches] + energy (torch.Tensor): Energy of the system [n_systems] forces (torch.Tensor): Forces on particles [n_particles, n_dim] masses (torch.Tensor): Particle masses [n_particles] - cell (torch.Tensor): Simulation cell matrix [n_batches, n_dim, n_dim] + cell (torch.Tensor): Simulation cell matrix [n_systems, n_dim, n_dim] pbc (bool): Whether to use periodic boundary conditions - batch (torch.Tensor): Batch indices [n_particles] + system_idx (torch.Tensor): System indices [n_particles] atomic_numbers (torch.Tensor): Atomic numbers [n_particles] - stress (torch.Tensor): Stress tensor [n_batches, n_dim, n_dim] + stress (torch.Tensor): Stress tensor [n_systems, n_dim, n_dim] reference_cell (torch.Tensor): Original cell vectors used as reference for - scaling [n_batches, n_dim, n_dim] - cell_positions (torch.Tensor): Cell positions [n_batches, n_dim, n_dim] - cell_velocities (torch.Tensor): Cell velocities [n_batches, n_dim, n_dim] + scaling [n_systems, n_dim, n_dim] + cell_positions (torch.Tensor): Cell positions [n_systems, n_dim, n_dim] + cell_velocities (torch.Tensor): Cell velocities [n_systems, n_dim, n_dim] cell_masses (torch.Tensor): 
Masses associated with the cell degrees of freedom - shape [n_batches] + shape [n_systems] Properties: momenta (torch.Tensor): Particle momenta calculated as velocities*masses with shape [n_particles, n_dimensions] - n_batches (int): Number of independent systems in the batch + n_systems (int): Number of independent systems in the batch device (torch.device): Device on which tensors are stored dtype (torch.dtype): Data type of tensors """ @@ -88,12 +88,12 @@ def _compute_cell_force( Args: state (NPTLangevinState): Current NPT state external_pressure (torch.Tensor): Target external pressure, either scalar or - tensor with shape [n_batches, n_dimensions, n_dimensions] + tensor with shape [n_systems, n_dimensions, n_dimensions] kT (torch.Tensor): Temperature in energy units, either scalar or - shape [n_batches] + shape [n_systems] Returns: - torch.Tensor: Force acting on the cell [n_batches, n_dim, n_dim] + torch.Tensor: Force acting on the cell [n_systems, n_dim, n_dim] """ # Convert external_pressure to tensor if it's not already one if not isinstance(external_pressure, torch.Tensor): @@ -106,10 +106,10 @@ def _compute_cell_force( kT = torch.tensor(kT, device=state.device, dtype=state.dtype) # Get current volumes for each batch - volumes = torch.linalg.det(state.cell) # shape: (n_batches,) + volumes = torch.linalg.det(state.cell) # shape: (n_systems,) # Reshape for broadcasting - volumes = volumes.view(-1, 1, 1) # shape: (n_batches, 1, 1) + volumes = volumes.view(-1, 1, 1) # shape: (n_systems, 1, 1) # Create pressure tensor (diagonal with external pressure) if external_pressure.ndim == 0: @@ -117,9 +117,9 @@ def _compute_cell_force( pressure_tensor = external_pressure * torch.eye( 3, device=state.device, dtype=state.dtype ) - pressure_tensor = pressure_tensor.unsqueeze(0).expand(state.n_batches, -1, -1) + pressure_tensor = pressure_tensor.unsqueeze(0).expand(state.n_systems, -1, -1) else: - # Already a tensor with shape compatible with n_batches + # Already a tensor 
with shape compatible with n_systems pressure_tensor = external_pressure # Calculate virials from stress and external pressure @@ -129,14 +129,14 @@ def _compute_cell_force( # Add kinetic contribution (kT * Identity) batch_kT = kT if kT.ndim == 0: - batch_kT = kT.expand(state.n_batches) + batch_kT = kT.expand(state.n_systems) e_kin_per_atom = batch_kT.view(-1, 1, 1) * torch.eye( 3, device=state.device, dtype=state.dtype ).unsqueeze(0) - # Correct implementation with scaling by n_atoms_per_batch - return virial + e_kin_per_atom * state.n_atoms_per_batch.view(-1, 1, 1) + # Correct implementation with scaling by n_atoms_per_system + return virial + e_kin_per_atom * state.n_atoms_per_system.view(-1, 1, 1) def npt_langevin( # noqa: C901, PLR0915 @@ -164,18 +164,18 @@ def npt_langevin( # noqa: C901, PLR0915 Args: model (torch.nn.Module): Neural network model that computes energies, forces, and stress. Must return a dict with 'energy', 'forces', and 'stress' keys. - dt (torch.Tensor): Integration timestep, either scalar or shape [n_batches] + dt (torch.Tensor): Integration timestep, either scalar or shape [n_systems] kT (torch.Tensor): Target temperature in energy units, either scalar or - with shape [n_batches] + with shape [n_systems] external_pressure (torch.Tensor): Target pressure to maintain, either scalar - or shape [n_batches, n_dim, n_dim] for anisotropic pressure + or shape [n_systems, n_dim, n_dim] for anisotropic pressure alpha (torch.Tensor, optional): Friction coefficient for particle Langevin - thermostat, either scalar or shape [n_batches]. Defaults to 1/(100*dt). + thermostat, either scalar or shape [n_systems]. Defaults to 1/(100*dt). cell_alpha (torch.Tensor, optional): Friction coefficient for cell Langevin - thermostat, either scalar or shape [n_batches]. Defaults to same as alpha. + thermostat, either scalar or shape [n_systems]. Defaults to same as alpha. 
b_tau (torch.Tensor, optional): Barostat time constant controlling how quickly the system responds to pressure differences, either scalar or shape - [n_batches]. Defaults to 1/(1000*dt). + [n_systems]. Defaults to 1/(1000*dt). seed (int, optional): Random seed for reproducibility. Defaults to None. Returns: @@ -229,24 +229,24 @@ def beta( Args: state (NPTLangevinState): Current NPT state alpha (torch.Tensor): Friction coefficient, either scalar or - shape [n_batches] + shape [n_systems] kT (torch.Tensor): Temperature in energy units, either scalar or - shape [n_batches] - dt (torch.Tensor): Integration timestep, either scalar or shape [n_batches] + shape [n_systems] + dt (torch.Tensor): Integration timestep, either scalar or shape [n_systems] Returns: torch.Tensor: Random noise term for force calculation [n_particles, n_dim] """ - # Generate batch-specific noise with correct shape + # Generate system-specific noise with correct shape noise = torch.randn_like(state.velocities) - # Calculate the thermal noise amplitude by batch + # Calculate the thermal noise amplitude by system batch_kT = kT if kT.ndim == 0: - batch_kT = kT.expand(state.n_batches) + batch_kT = kT.expand(state.n_systems) - # Map batch kT to atoms - atom_kT = batch_kT[state.batch] + # Map system kT to atoms + atom_kT = batch_kT[state.system_idx] # Calculate the prefactor for each atom # The standard deviation should be sqrt(2*alpha*kB*T*dt) @@ -269,29 +269,29 @@ def cell_beta( Args: state (NPTLangevinState): Current NPT state cell_alpha (torch.Tensor): Cell friction coefficient, either scalar or - with shape [n_batches] + with shape [n_systems] kT (torch.Tensor): System temperature in energy units, either scalar or - with shape [n_batches] - dt (torch.Tensor): Integration timestep, either scalar or shape [n_batches] + with shape [n_systems] + dt (torch.Tensor): Integration timestep, either scalar or shape [n_systems] Returns: torch.Tensor: Scaled random noise for cell dynamics with shape - [n_batches, 
n_dimensions, n_dimensions] + [n_systems, n_dimensions, n_dimensions] """ # Generate standard normal distribution (zero mean, unit variance) noise = torch.randn_like(state.cell_positions, device=device, dtype=dtype) # Ensure cell_alpha and kT have batch dimension if they're scalars if cell_alpha.ndim == 0: - cell_alpha = cell_alpha.expand(state.n_batches) + cell_alpha = cell_alpha.expand(state.n_systems) if kT.ndim == 0: - kT = kT.expand(state.n_batches) + kT = kT.expand(state.n_systems) # Reshape for broadcasting - cell_alpha = cell_alpha.view(-1, 1, 1) # shape: (n_batches, 1, 1) - kT = kT.view(-1, 1, 1) # shape: (n_batches, 1, 1) + cell_alpha = cell_alpha.view(-1, 1, 1) # shape: (n_systems, 1, 1) + kT = kT.view(-1, 1, 1) # shape: (n_systems, 1, 1) if dt.ndim == 0: - dt = dt.expand(state.n_batches).view(-1, 1, 1) + dt = dt.expand(state.n_systems).view(-1, 1, 1) else: dt = dt.view(-1, 1, 1) @@ -316,12 +316,12 @@ def compute_cell_force( Args: state (NPTLangevinState): Current NPT state external_pressure (torch.Tensor): Target external pressure, either scalar or - tensor with shape [n_batches, n_dimensions, n_dimensions] + tensor with shape [n_systems, n_dimensions, n_dimensions] kT (torch.Tensor): Temperature in energy units, either scalar or - shape [n_batches] + shape [n_systems] Returns: - torch.Tensor: Force acting on the cell [n_batches, n_dim, n_dim] + torch.Tensor: Force acting on the cell [n_systems, n_dim, n_dim] """ return _compute_cell_force(state, external_pressure, kT) @@ -340,25 +340,25 @@ def cell_position_step( Args: state (NPTLangevinState): Current NPT state - dt (torch.Tensor): Integration timestep, either scalar or shape [n_batches] + dt (torch.Tensor): Integration timestep, either scalar or shape [n_systems] pressure_force (torch.Tensor): Pressure force for barostat - [n_batches, n_dim, n_dim] + [n_systems, n_dim, n_dim] kT (torch.Tensor): Target temperature in energy units, either scalar or - with shape [n_batches] + with shape [n_systems] 
cell_alpha (torch.Tensor): Cell friction coefficient, either scalar or - with shape [n_batches] + with shape [n_systems] Returns: NPTLangevinState: Updated state with new cell positions """ # Calculate effective mass term - Q_2 = 2 * state.cell_masses.view(-1, 1, 1) # shape: (n_batches, 1, 1) + Q_2 = 2 * state.cell_masses.view(-1, 1, 1) # shape: (n_systems, 1, 1) # Ensure parameters have batch dimension if dt.ndim == 0: - dt = dt.expand(state.n_batches) + dt = dt.expand(state.n_systems) if cell_alpha.ndim == 0: - cell_alpha = cell_alpha.expand(state.n_batches) + cell_alpha = cell_alpha.expand(state.n_systems) # Reshape for broadcasting dt_expanded = dt.view(-1, 1, 1) @@ -403,34 +403,34 @@ def cell_velocity_step( Args: state (NPTLangevinState): Current NPT state F_p_n (torch.Tensor): Initial pressure force with shape - [n_batches, n_dimensions, n_dimensions] - dt (torch.Tensor): Integration timestep, either scalar or shape [n_batches] + [n_systems, n_dimensions, n_dimensions] + dt (torch.Tensor): Integration timestep, either scalar or shape [n_systems] pressure_force (torch.Tensor): Final pressure force - shape [n_batches, n_dim, n_dim] + shape [n_systems, n_dim, n_dim] cell_alpha (torch.Tensor): Cell friction coefficient, either scalar or - shape [n_batches] + shape [n_systems] kT (torch.Tensor): Temperature in energy units, either scalar or - shape [n_batches] + shape [n_systems] Returns: NPTLangevinState: Updated state with new cell velocities """ # Ensure parameters have batch dimension if dt.ndim == 0: - dt = dt.expand(state.n_batches) + dt = dt.expand(state.n_systems) if cell_alpha.ndim == 0: - cell_alpha = cell_alpha.expand(state.n_batches) + cell_alpha = cell_alpha.expand(state.n_systems) if kT.ndim == 0: - kT = kT.expand(state.n_batches) + kT = kT.expand(state.n_systems) # Reshape for broadcasting - need to maintain 3x3 dimensions - dt_expanded = dt.view(-1, 1, 1) # shape: (n_batches, 1, 1) - cell_alpha_expanded = cell_alpha.view(-1, 1, 1) # shape: 
(n_batches, 1, 1) + dt_expanded = dt.view(-1, 1, 1) # shape: (n_systems, 1, 1) + cell_alpha_expanded = cell_alpha.view(-1, 1, 1) # shape: (n_systems, 1, 1) - # Calculate cell masses per batch - reshape to match 3x3 cell matrices + # Calculate cell masses per system - reshape to match 3x3 cell matrices cell_masses_expanded = state.cell_masses.view( -1, 1, 1 - ) # shape: (n_batches, 1, 1) + ) # shape: (n_systems, 1, 1) # These factors come from the Langevin integration scheme a = (1 - (cell_alpha_expanded * dt_expanded) / cell_masses_expanded) / ( @@ -439,13 +439,13 @@ def cell_velocity_step( b = 1 / (1 + (cell_alpha_expanded * dt_expanded) / cell_masses_expanded) # Calculate the three terms for velocity update - # a will broadcast from (n_batches, 1, 1) to (n_batches, 3, 3) + # a will broadcast from (n_systems, 1, 1) to (n_systems, 3, 3) c_1 = a * state.cell_velocities # Damped old velocity # Force contribution (average of initial and final forces) c_2 = dt_expanded * ((a * F_p_n) + pressure_force) / (2 * cell_masses_expanded) - # Generate batch-specific cell noise with correct shape (n_batches, 3, 3) + # Generate system-specific cell noise with correct shape (n_systems, 3, 3) cell_noise = torch.randn_like(state.cell_velocities) # Calculate thermal noise amplitude @@ -463,7 +463,7 @@ def cell_velocity_step( def langevin_position_step( state: NPTLangevinState, - L_n: torch.Tensor, # This should be shape (n_batches,) + L_n: torch.Tensor, # This should be shape (n_systems,) dt: torch.Tensor, kT: torch.Tensor, ) -> NPTLangevinState: @@ -476,42 +476,42 @@ def langevin_position_step( Args: state (NPTLangevinState): Current NPT state - L_n (torch.Tensor): Previous cell length scale with shape [n_batches] - dt: Integration timestep, either scalar or with shape [n_batches] + L_n (torch.Tensor): Previous cell length scale with shape [n_systems] + dt: Integration timestep, either scalar or with shape [n_systems] kT (torch.Tensor): Target temperature in energy units, either 
scalar or - with shape [n_batches] + with shape [n_systems] Returns: NPTLangevinState: Updated state with new positions """ - # Calculate effective mass term by batch + # Calculate effective mass term by system # Map masses to have batch dimension M_2 = 2 * state.masses.unsqueeze(-1) # shape: (n_atoms, 1) # Calculate new cell length scale (cube root of volume for isotropic scaling) L_n_new = torch.pow( - state.cell_positions.reshape(state.n_batches, -1)[:, 0], 1 / 3 - ) # shape: (n_batches,) + state.cell_positions.reshape(state.n_systems, -1)[:, 0], 1 / 3 + ) # shape: (n_systems,) - # Map batch-specific L_n and L_n_new to atom-level using batch indices - # Make sure L_n is the right shape (n_batches,) before indexing - if L_n.ndim != 1 or L_n.shape[0] != state.n_batches: + # Map system-specific L_n and L_n_new to atom-level using system indices + # Make sure L_n is the right shape (n_systems,) before indexing + if L_n.ndim != 1 or L_n.shape[0] != state.n_systems: # If L_n has wrong shape, calculate it again to ensure correct shape L_n = torch.pow( - state.cell_positions.reshape(state.n_batches, -1)[:, 0], 1 / 3 + state.cell_positions.reshape(state.n_systems, -1)[:, 0], 1 / 3 ) - # Map batch values to atoms using batch indices - L_n_atoms = L_n[state.batch] # shape: (n_atoms,) - L_n_new_atoms = L_n_new[state.batch] # shape: (n_atoms,) + # Map system-specific values to atoms using system indices + L_n_atoms = L_n[state.system_idx] # shape: (n_atoms,) + L_n_new_atoms = L_n_new[state.system_idx] # shape: (n_atoms,) # Calculate damping factor alpha_atoms = alpha if alpha.ndim > 0: - alpha_atoms = alpha[state.batch] + alpha_atoms = alpha[state.system_idx] dt_atoms = dt if dt.ndim > 0: - dt_atoms = dt[state.batch] + dt_atoms = dt[state.system_idx] b = 1 / (1 + ((alpha_atoms * dt_atoms) / M_2)) @@ -529,8 +529,8 @@ def langevin_position_step( noise = torch.randn_like(state.velocities) batch_kT = kT if kT.ndim == 0: - batch_kT = kT.expand(state.n_batches) - atom_kT = 
batch_kT[state.batch] + batch_kT = kT.expand(state.n_systems) + atom_kT = batch_kT[state.system_idx] # Calculate noise prefactor according to fluctuation-dissipation theorem noise_prefactor = torch.sqrt(2 * alpha_atoms * atom_kT * dt_atoms) @@ -549,7 +549,7 @@ def langevin_position_step( # Apply periodic boundary conditions if needed if state.pbc: state.positions = ts.transforms.pbc_wrap_batched( - state.positions, state.cell, state.batch + state.positions, state.cell, state.system_idx ) return state @@ -569,9 +569,9 @@ def langevin_velocity_step( Args: state (NPTLangevinState): Current NPT state forces: Forces on particles - dt: Integration timestep, either scalar or with shape [n_batches] + dt: Integration timestep, either scalar or with shape [n_systems] kT: Target temperature in energy units, either scalar or - with shape [n_batches] + with shape [n_systems] Returns: NPTLangevinState: Updated state with new velocities @@ -582,10 +582,10 @@ def langevin_velocity_step( # Map batch parameters to atom level alpha_atoms = alpha if alpha.ndim > 0: - alpha_atoms = alpha[state.batch] + alpha_atoms = alpha[state.system_idx] dt_atoms = dt if dt.ndim > 0: - dt_atoms = dt[state.batch] + dt_atoms = dt[state.system_idx] # Calculate damping factors for Langevin integration a = (1 - (alpha_atoms * dt_atoms) / M_2) / (1 + (alpha_atoms * dt_atoms) / M_2) @@ -601,8 +601,8 @@ def langevin_velocity_step( noise = torch.randn_like(state.velocities) batch_kT = kT if kT.ndim == 0: - batch_kT = kT.expand(state.n_batches) - atom_kT = batch_kT[state.batch] + batch_kT = kT.expand(state.n_systems) + atom_kT = batch_kT[state.system_idx] # Calculate noise prefactor according to fluctuation-dissipation theorem noise_prefactor = torch.sqrt(2 * alpha_atoms * atom_kT * dt_atoms) @@ -647,7 +647,7 @@ def npt_init( momenta = getattr( state, "momenta", - calculate_momenta(state.positions, state.masses, state.batch, kT, seed), + calculate_momenta(state.positions, state.masses, state.system_idx, kT, 
seed), ) # Initialize cell parameters @@ -656,20 +656,20 @@ def npt_init( # Calculate initial cell_positions (volume) cell_positions = ( torch.linalg.det(state.cell).unsqueeze(-1).unsqueeze(-1) - ) # shape: (n_batches, 1, 1) + ) # shape: (n_systems, 1, 1) # Initialize cell velocities to zero - cell_velocities = torch.zeros((state.n_batches, 3, 3), device=device, dtype=dtype) + cell_velocities = torch.zeros((state.n_systems, 3, 3), device=device, dtype=dtype) # Calculate cell masses based on system size and temperature # This follows standard NPT barostat mass scaling - n_atoms_per_batch = torch.bincount(state.batch) + n_atoms_per_system = torch.bincount(state.system_idx) batch_kT = ( - kT.expand(state.n_batches) + kT.expand(state.n_systems) if isinstance(kT, torch.Tensor) and kT.ndim == 0 else kT ) - cell_masses = (n_atoms_per_batch + 1) * batch_kT * b_tau * b_tau + cell_masses = (n_atoms_per_system + 1) * batch_kT * b_tau * b_tau # Create the initial state return NPTLangevinState( @@ -681,7 +681,7 @@ def npt_init( masses=state.masses, cell=state.cell, pbc=state.pbc, - batch=state.batch, + system_idx=state.system_idx, atomic_numbers=state.atomic_numbers, reference_cell=reference_cell, cell_positions=cell_positions, @@ -706,15 +706,15 @@ def npt_update( Args: state (NPTLangevinState): Current NPT state with particle and cell variables - dt (torch.Tensor): Integration timestep, either scalar or shape [n_batches] + dt (torch.Tensor): Integration timestep, either scalar or shape [n_systems] kT (torch.Tensor): Target temperature in energy units, either scalar or - shape [n_batches] + shape [n_systems] external_pressure (torch.Tensor): Target external pressure, either scalar or - tensor with shape [n_batches, n_dim, n_dim] + tensor with shape [n_systems, n_dim, n_dim] alpha (torch.Tensor): Position friction coefficient, either scalar or - shape [n_batches] + shape [n_systems] cell_alpha (torch.Tensor): Cell friction coefficient, either scalar or - shape [n_batches] + 
shape [n_systems] Returns: NPTLangevinState: Updated NPT state after one timestep with new positions, @@ -731,12 +731,12 @@ def npt_update( dt = torch.tensor(dt, device=device, dtype=dtype) # Make sure parameters have batch dimension if they're scalars - batch_kT = kT.expand(state.n_batches) if kT.ndim == 0 else kT + batch_kT = kT.expand(state.n_systems) if kT.ndim == 0 else kT # Update barostat mass based on current temperature # This ensures proper coupling between system and barostat - n_atoms_per_batch = torch.bincount(state.batch) - state.cell_masses = (n_atoms_per_batch + 1) * batch_kT * b_tau * b_tau + n_atoms_per_system = torch.bincount(state.system_idx) + state.cell_masses = (n_atoms_per_system + 1) * batch_kT * b_tau * b_tau # Compute model output for current state model_output = model(state) @@ -749,24 +749,24 @@ def npt_update( state=state, external_pressure=external_pressure, kT=kT ) L_n = torch.pow( - state.cell_positions.reshape(state.n_batches, -1)[:, 0], 1 / 3 - ) # shape: (n_batches,) + state.cell_positions.reshape(state.n_systems, -1)[:, 0], 1 / 3 + ) # shape: (n_systems,) # Step 1: Update cell position state = cell_position_step(state=state, dt=dt, pressure_force=F_p_n, kT=kT) # Update cell (currently only isotropic fluctuations) dim = state.positions.shape[1] # Usually 3 for 3D - # V_0 and V are shape: (n_batches,) + # V_0 and V are shape: (n_systems,) V_0 = torch.linalg.det(state.reference_cell) - V = state.cell_positions.reshape(state.n_batches, -1)[:, 0] + V = state.cell_positions.reshape(state.n_systems, -1)[:, 0] # Scale cell uniformly in all dimensions - scaling = (V / V_0) ** (1.0 / dim) # shape: (n_batches,) + scaling = (V / V_0) ** (1.0 / dim) # shape: (n_systems,) # Apply scaling to reference cell to get new cell new_cell = torch.zeros_like(state.cell) - for b in range(state.n_batches): + for b in range(state.n_systems): new_cell[b] = scaling[b] * state.reference_cell[b] state.cell = new_cell @@ -822,14 +822,14 @@ class 
NPTNoseHooverState(MDState): forces (torch.Tensor): Forces on particles with shape [n_particles, n_dims] masses (torch.Tensor): Particle masses with shape [n_particles] reference_cell (torch.Tensor): Reference simulation cell matrix with shape - [n_batches, n_dimensions, n_dimensions]. Used to measure relative volume + [n_systems, n_dimensions, n_dimensions]. Used to measure relative volume changes. - cell_position (torch.Tensor): Logarithmic cell coordinate with shape [n_batches]. + cell_position (torch.Tensor): Logarithmic cell coordinate with shape [n_systems]. Represents (1/d)ln(V/V_0) where V is current volume and V_0 is reference volume. cell_momentum (torch.Tensor): Cell momentum (velocity) conjugate to cell_position - with shape [n_batches]. Controls volume changes. - cell_mass (torch.Tensor): Mass parameter for cell dynamics with shape [n_batches]. + with shape [n_systems]. Controls volume changes. + cell_mass (torch.Tensor): Mass parameter for cell dynamics with shape [n_systems]. Controls coupling between volume fluctuations and pressure. barostat (NoseHooverChain): Chain thermostat coupled to cell dynamics for pressure control @@ -842,7 +842,7 @@ class NPTNoseHooverState(MDState): velocities (torch.Tensor): Particle velocities computed as momenta divided by masses. Shape: [n_particles, n_dimensions] current_cell (torch.Tensor): Current simulation cell matrix derived from - cell_position. Shape: [n_batches, n_dimensions, n_dimensions] + cell_position. 
Shape: [n_systems, n_dimensions, n_dimensions] Notes: - The cell parameterization ensures volume positivity @@ -853,10 +853,10 @@ class NPTNoseHooverState(MDState): """ # Cell variables - now with batch dimensions - reference_cell: torch.Tensor # [n_batches, 3, 3] - cell_position: torch.Tensor # [n_batches] - cell_momentum: torch.Tensor # [n_batches] - cell_mass: torch.Tensor # [n_batches] + reference_cell: torch.Tensor # [n_systems, 3, 3] + cell_position: torch.Tensor # [n_systems] + cell_momentum: torch.Tensor # [n_systems] + cell_mass: torch.Tensor # [n_systems] # Thermostat variables thermostat: NoseHooverChain @@ -885,13 +885,13 @@ def current_cell(self) -> torch.Tensor: Returns: torch.Tensor: Current simulation cell matrix with shape - [n_batches, n_dimensions, n_dimensions] + [n_systems, n_dimensions, n_dimensions] """ dim = self.positions.shape[1] - V_0 = torch.det(self.reference_cell) # [n_batches] - V = V_0 * torch.exp(dim * self.cell_position) # [n_batches] - scale = (V / V_0) ** (1.0 / dim) # [n_batches] - # Expand scale to [n_batches, 1, 1] for broadcasting + V_0 = torch.det(self.reference_cell) # [n_systems] + V = V_0 * torch.exp(dim * self.cell_position) # [n_systems] + scale = (V / V_0) ** (1.0 / dim) # [n_systems] + # Expand scale to [n_systems, 1, 1] for broadcasting scale = scale.unsqueeze(-1).unsqueeze(-1) return scale * self.reference_cell @@ -952,9 +952,9 @@ def _npt_cell_info( Returns: tuple: - - torch.Tensor: Current system volume with shape [n_batches] - - callable: Function that takes a volume tensor [n_batches] and returns - the corresponding cell matrix [n_batches, n_dimensions, n_dimensions] + - torch.Tensor: Current system volume with shape [n_systems] + - callable: Function that takes a volume tensor [n_systems] and returns + the corresponding cell matrix [n_systems, n_dimensions, n_dimensions] Notes: - Uses logarithmic cell coordinate parameterization @@ -963,21 +963,21 @@ def _npt_cell_info( - Supports batched operations """ dim = 
state.positions.shape[1] - ref = state.reference_cell # [n_batches, dim, dim] - V_0 = torch.det(ref) # [n_batches] - Reference volume - V = V_0 * torch.exp(dim * state.cell_position) # [n_batches] - Current volume + ref = state.reference_cell # [n_systems, dim, dim] + V_0 = torch.det(ref) # [n_systems] - Reference volume + V = V_0 * torch.exp(dim * state.cell_position) # [n_systems] - Current volume def volume_to_cell(V: torch.Tensor) -> torch.Tensor: """Compute cell matrix for given volumes. Args: - V (torch.Tensor): Volumes with shape [n_batches] + V (torch.Tensor): Volumes with shape [n_systems] Returns: - torch.Tensor: Cell matrices with shape [n_batches, dim, dim] + torch.Tensor: Cell matrices with shape [n_systems, dim, dim] """ - scale = (V / V_0) ** (1.0 / dim) # [n_batches] - # Expand scale to [n_batches, 1, 1] for broadcasting + scale = (V / V_0) ** (1.0 / dim) # [n_systems] + # Expand scale to [n_systems, 1, 1] for broadcasting scale = scale.unsqueeze(-1).unsqueeze(-1) return scale * ref @@ -996,7 +996,7 @@ def update_cell_mass( Args: state (NPTNoseHooverState): Current state of the NPT system kT (torch.Tensor): Target temperature in energy units, either scalar or - shape [n_batches] + shape [n_systems] Returns: NPTNoseHooverState: Updated state with new cell mass @@ -1014,11 +1014,11 @@ def update_cell_mass( kT = torch.tensor(kT, device=device, dtype=dtype) # Handle both scalar and batched kT - kT_batch = kT.expand(state.n_batches) if kT.ndim == 0 else kT + kT_system = kT.expand(state.n_systems) if kT.ndim == 0 else kT - # Calculate cell masses for each batch - n_atoms_per_batch = torch.bincount(state.batch, minlength=state.n_batches) - cell_mass = dim * (n_atoms_per_batch + 1) * kT_batch * state.barostat.tau**2 + # Calculate cell masses for each system + n_atoms_per_system = torch.bincount(state.system_idx, minlength=state.n_systems) + cell_mass = dim * (n_atoms_per_system + 1) * kT_system * state.barostat.tau**2 # Update state with new cell masses 
state.cell_mass = cell_mass.to(device=device, dtype=dtype) @@ -1072,7 +1072,7 @@ def exp_iL1( # noqa: N802 Args: state (NPTNoseHooverState): Current simulation state velocities (torch.Tensor): Particle velocities [n_particles, n_dimensions] - cell_velocity (torch.Tensor): Cell velocity with shape [n_batches] + cell_velocity (torch.Tensor): Cell velocity with shape [n_systems] dt (torch.Tensor): Integration timestep Returns: @@ -1083,10 +1083,10 @@ def exp_iL1( # noqa: N802 - Properly handles cell scaling through cell_velocity - Maintains time-reversibility of the integration scheme - Applies periodic boundary conditions if state.pbc is True - - Supports batched operations with proper atom-to-batch mapping + - Supports batched operations with proper atom-to-system mapping """ - # Map batch-level cell velocities to atom level using batch indices - cell_velocity_atoms = cell_velocity[state.batch] # [n_atoms] + # Map system-level cell velocities to atom level using system indices + cell_velocity_atoms = cell_velocity[state.system_idx] # [n_atoms] # Compute cell velocity terms per atom x = cell_velocity_atoms * dt # [n_atoms] @@ -1110,7 +1110,7 @@ def exp_iL1( # noqa: N802 # Apply periodic boundary conditions if needed if state.pbc: return ts.transforms.pbc_wrap_batched( - new_positions, state.current_cell, state.batch + new_positions, state.current_cell, state.system_idx ) return new_positions @@ -1137,7 +1137,7 @@ def exp_iL2( # noqa: N802 alpha (torch.Tensor): Cell scaling parameter momenta (torch.Tensor): Current particle momenta [n_particles, n_dimensions] forces (torch.Tensor): Forces on particles [n_particles, n_dimensions] - cell_velocity (torch.Tensor): Cell velocity with shape [n_batches] + cell_velocity (torch.Tensor): Cell velocity with shape [n_systems] dt_2 (torch.Tensor): Half timestep (dt/2) Returns: @@ -1148,10 +1148,10 @@ def exp_iL2( # noqa: N802 - Properly handles cell velocity scaling effects - Maintains time-reversibility of the integration scheme 
- Part of the NPT integration algorithm - - Supports batched operations with proper atom-to-batch mapping + - Supports batched operations with proper atom-to-system mapping """ - # Map batch-level cell velocities to atom level using batch indices - cell_velocity_atoms = cell_velocity[state.batch] # [n_atoms] + # Map system-level cell velocities to atom level using system indices + cell_velocity_atoms = cell_velocity[state.system_idx] # [n_atoms] # Compute scaling terms per atom x = alpha * cell_velocity_atoms * dt_2 # [n_atoms] @@ -1178,7 +1178,7 @@ def compute_cell_force( masses: torch.Tensor, stress: torch.Tensor, external_pressure: torch.Tensor, - batch: torch.Tensor, + system_idx: torch.Tensor, ) -> torch.Tensor: """Compute the force on the cell degree of freedom in NPT dynamics. @@ -1190,16 +1190,16 @@ def compute_cell_force( Args: alpha (torch.Tensor): Cell scaling parameter - volume (torch.Tensor): Current system volume with shape [n_batches] + volume (torch.Tensor): Current system volume with shape [n_systems] positions (torch.Tensor): Particle positions [n_particles, n_dimensions] momenta (torch.Tensor): Particle momenta [n_particles, n_dimensions] masses (torch.Tensor): Particle masses [n_particles] - stress (torch.Tensor): Stress tensor [n_batches, n_dimensions, n_dimensions] + stress (torch.Tensor): Stress tensor [n_systems, n_dimensions, n_dimensions] external_pressure (torch.Tensor): Target external pressure - batch (torch.Tensor): Batch indices for atoms [n_particles] + system_idx (torch.Tensor): System indices for atoms [n_particles] Returns: - torch.Tensor: Force on the cell degree of freedom with shape [n_batches] + torch.Tensor: Force on the cell degree of freedom with shape [n_systems] Notes: - Force drives volume changes to maintain target pressure @@ -1209,34 +1209,34 @@ def compute_cell_force( - Supports batched operations """ N, dim = positions.shape - n_batches = len(volume) + n_systems = len(volume) - # Compute kinetic energy contribution 
per batch - # Split momenta and masses by batch - KE_per_batch = torch.zeros( - n_batches, device=positions.device, dtype=positions.dtype + # Compute kinetic energy contribution per system + # Split momenta and masses by system + KE_per_system = torch.zeros( + n_systems, device=positions.device, dtype=positions.dtype ) - for b in range(n_batches): - batch_mask = batch == b - if batch_mask.any(): - batch_momenta = momenta[batch_mask] - batch_masses = masses[batch_mask] - KE_per_batch[b] = calc_kinetic_energy(batch_momenta, batch_masses) - - # Get stress tensor and compute trace per batch + for b in range(n_systems): + system_mask = system_idx == b + if system_mask.any(): + system_momenta = momenta[system_mask] + system_masses = masses[system_mask] + KE_per_system[b] = calc_kinetic_energy(system_momenta, system_masses) + + # Get stress tensor and compute trace per system # Handle stress tensor with batch dimension if stress.ndim == 3: internal_pressure = torch.diagonal(stress, dim1=-2, dim2=-1).sum( dim=-1 - ) # [n_batches] + ) # [n_systems] else: - # Single batch case - expand to batch dimension - internal_pressure = torch.trace(stress).unsqueeze(0).expand(n_batches) + # Single system case - expand to batch dimension + internal_pressure = torch.trace(stress).unsqueeze(0).expand(n_systems) - # Compute force on cell coordinate per batch + # Compute force on cell coordinate per system # F = alpha * KE - dU/dV - P*V*d return ( - (alpha * KE_per_batch) + (alpha * KE_per_system) - (internal_pressure * volume) - (external_pressure * volume * dim) ) @@ -1270,9 +1270,9 @@ def npt_inner_step( momenta = state.momenta masses = state.masses forces = state.forces - cell_position = state.cell_position # [n_batches] - cell_momentum = state.cell_momentum # [n_batches] - cell_mass = state.cell_mass # [n_batches] + cell_position = state.cell_position # [n_systems] + cell_momentum = state.cell_momentum # [n_systems] + cell_mass = state.cell_mass # [n_systems] n_particles, dim = 
positions.shape @@ -1285,8 +1285,8 @@ def npt_inner_step( model_output = model(state) # First half step: Update momenta - n_atoms_per_batch = torch.bincount(state.batch, minlength=state.n_batches) - alpha = 1 + 1 / n_atoms_per_batch # [n_batches] + n_atoms_per_system = torch.bincount(state.system_idx, minlength=state.n_systems) + alpha = 1 + 1 / n_atoms_per_system # [n_systems] cell_force_val = compute_cell_force( alpha=alpha, @@ -1296,7 +1296,7 @@ def npt_inner_step( masses=masses, stress=model_output["stress"], external_pressure=external_pressure, - batch=state.batch, + system_idx=state.system_idx, ) # Update cell momentum and particle momenta @@ -1331,7 +1331,7 @@ def npt_inner_step( masses=masses, stress=model_output["stress"], external_pressure=external_pressure, - batch=state.batch, + system_idx=state.system_idx, ) cell_momentum = cell_momentum + dt_2 * cell_force_val @@ -1410,32 +1410,32 @@ def npt_nose_hoover_init( state = SimState(**state) n_particles, dim = state.positions.shape - n_batches = state.n_batches + n_systems = state.n_systems atomic_numbers = kwargs.get("atomic_numbers", state.atomic_numbers) - # Initialize cell variables with proper batch dimensions - cell_position = torch.zeros(n_batches, device=device, dtype=dtype) - cell_momentum = torch.zeros(n_batches, device=device, dtype=dtype) + # Initialize cell variables with proper system dimensions + cell_position = torch.zeros(n_systems, device=device, dtype=dtype) + cell_momentum = torch.zeros(n_systems, device=device, dtype=dtype) # Convert kT to tensor if it's not already one if not isinstance(kT, torch.Tensor): kT = torch.tensor(kT, device=device, dtype=dtype) # Handle both scalar and batched kT - kT_batch = kT.expand(n_batches) if kT.ndim == 0 else kT + kT_system = kT.expand(n_systems) if kT.ndim == 0 else kT - # Calculate cell masses for each batch - n_atoms_per_batch = torch.bincount(state.batch, minlength=n_batches) - cell_mass = dim * (n_atoms_per_batch + 1) * kT_batch * b_tau**2 + # 
Calculate cell masses for each system + n_atoms_per_system = torch.bincount(state.system_idx, minlength=n_systems) + cell_mass = dim * (n_atoms_per_system + 1) * kT_system * b_tau**2 cell_mass = cell_mass.to(device=device, dtype=dtype) - # Calculate cell kinetic energy (using first batch for initialization) + # Calculate cell kinetic energy (using first system for initialization) KE_cell = calc_kinetic_energy(cell_momentum[:1], cell_mass[:1]) - # Ensure reference_cell has proper batch dimensions + # Ensure reference_cell has proper system dimensions if state.cell.ndim == 2: # Single cell matrix - expand to batch dimension - reference_cell = state.cell.unsqueeze(0).expand(n_batches, -1, -1).clone() + reference_cell = state.cell.unsqueeze(0).expand(n_systems, -1, -1).clone() else: # Already has batch dimension reference_cell = state.cell.clone() @@ -1445,7 +1445,7 @@ def npt_nose_hoover_init( state.cell, int | float ): cell_matrix = torch.eye(dim, device=device, dtype=dtype) * state.cell - reference_cell = cell_matrix.unsqueeze(0).expand(n_batches, -1, -1).clone() + reference_cell = cell_matrix.unsqueeze(0).expand(n_systems, -1, -1).clone() state.cell = reference_cell # Get model output @@ -1463,7 +1463,7 @@ def npt_nose_hoover_init( atomic_numbers=atomic_numbers, cell=state.cell, pbc=state.pbc, - batch=state.batch, + system_idx=state.system_idx, reference_cell=reference_cell, cell_position=cell_position, cell_momentum=cell_momentum, @@ -1478,14 +1478,14 @@ def npt_nose_hoover_init( momenta = kwargs.get( "momenta", calculate_momenta( - npt_state.positions, npt_state.masses, npt_state.batch, kT, seed + npt_state.positions, npt_state.masses, npt_state.system_idx, kT, seed ), ) # Initialize thermostat npt_state.momenta = momenta KE = calc_kinetic_energy( - npt_state.momenta, npt_state.masses, batch=npt_state.batch + npt_state.momenta, npt_state.masses, system_idx=npt_state.system_idx ) npt_state.thermostat = thermostat_fns.initialize( npt_state.positions.numel(), KE, kT 
@@ -1542,7 +1542,7 @@ def npt_nose_hoover_update( ) # Update kinetic energies for thermostats - KE = calc_kinetic_energy(state.momenta, state.masses, batch=state.batch) + KE = calc_kinetic_energy(state.momenta, state.masses, system_idx=state.system_idx) state.thermostat.kinetic_energy = KE KE_cell = calc_kinetic_energy(state.cell_momentum, state.cell_mass) @@ -1588,44 +1588,46 @@ def npt_nose_hoover_invariant( Returns: torch.Tensor: The conserved quantity (extended Hamiltonian) of the NPT system. - Returns a scalar for single batch or tensor with shape [n_batches] for - multiple batches. + Returns a scalar for a single system or tensor with shape [n_systems] for + multiple systems. """ # Calculate volume and potential energy - volume = torch.det(state.current_cell) # [n_batches] - e_pot = state.energy # Should be scalar or [n_batches] + volume = torch.det(state.current_cell) # [n_systems] + e_pot = state.energy # Should be scalar or [n_systems] - # Calculate kinetic energy of particles per batch - e_kin_per_batch = calc_kinetic_energy(state.momenta, state.masses, batch=state.batch) + # Calculate kinetic energy of particles per system + e_kin_per_system = calc_kinetic_energy( + state.momenta, state.masses, system_idx=state.system_idx + ) - # Calculate degrees of freedom per batch - n_atoms_per_batch = torch.bincount(state.batch) - DOF_per_batch = ( - n_atoms_per_batch * state.positions.shape[-1] + # Calculate degrees of freedom per system + n_atoms_per_system = torch.bincount(state.system_idx) + DOF_per_system = ( + n_atoms_per_system * state.positions.shape[-1] ) # n_atoms * n_dimensions # Initialize total energy with PE + KE if isinstance(e_pot, torch.Tensor) and e_pot.ndim > 0: - e_tot = e_pot + e_kin_per_batch # [n_batches] + e_tot = e_pot + e_kin_per_system # [n_systems] else: - e_tot = e_pot + e_kin_per_batch # [n_batches] + e_tot = e_pot + e_kin_per_system # [n_systems] # Add thermostat chain contributions - # Note: These are global thermostat variables, so 
we add them to each batch + # Note: These are global thermostat variables, so we add them to each system # Start thermostat_energy as a tensor with the right shape thermostat_energy = torch.zeros_like(e_tot) thermostat_energy += (state.thermostat.momenta[0] ** 2) / ( 2 * state.thermostat.masses[0] ) - # Ensure kT can broadcast properly with DOF_per_batch + # Ensure kT can broadcast properly with DOF_per_system if isinstance(kT, torch.Tensor) and kT.ndim == 0: - # Scalar kT - expand to match DOF_per_batch shape - kT_expanded = kT.expand_as(DOF_per_batch) + # Scalar kT - expand to match DOF_per_system shape + kT_expanded = kT.expand_as(DOF_per_system) else: kT_expanded = kT - thermostat_energy += DOF_per_batch * kT_expanded * state.thermostat.positions[0] + thermostat_energy += DOF_per_system * kT_expanded * state.thermostat.positions[0] # Add remaining thermostat terms for pos, momentum, mass in zip( @@ -1660,11 +1662,11 @@ def npt_nose_hoover_invariant( e_tot = e_tot + barostat_energy - # Add PV term and cell kinetic energy (both are per batch) + # Add PV term and cell kinetic energy (both are per system) e_tot += external_pressure * volume e_tot += (state.cell_momentum**2) / (2 * state.cell_mass) - # Return scalar if single batch, otherwise return per-batch values - if state.n_batches == 1: + # Return scalar if single system, otherwise return per-system values + if state.n_systems == 1: return e_tot.squeeze() return e_tot diff --git a/torch_sim/integrators/nve.py b/torch_sim/integrators/nve.py index f10e2f73..f17e59f6 100644 --- a/torch_sim/integrators/nve.py +++ b/torch_sim/integrators/nve.py @@ -37,9 +37,9 @@ def nve( Args: model (torch.nn.Module): Neural network model that computes energies and forces. Must return a dict with 'energy' and 'forces' keys. 
- dt (torch.Tensor): Integration timestep, either scalar or with shape [n_batches] + dt (torch.Tensor): Integration timestep, either scalar or with shape [n_systems] kT (torch.Tensor): Temperature in energy units for initializing momenta, - either scalar or with shape [n_batches] + either scalar or with shape [n_systems] seed (int, optional): Random seed for reproducibility. Defaults to None. Returns: @@ -72,7 +72,7 @@ def nve_init( containing positions, masses, cell, pbc, and other required state variables kT (torch.Tensor): Temperature in energy units for initializing momenta, - scalar or with shape [n_batches] + scalar or with shape [n_systems] seed (int, optional): Random seed for reproducibility Returns: @@ -88,7 +88,7 @@ def nve_init( momenta = getattr( state, "momenta", - calculate_momenta(state.positions, state.masses, state.batch, kT, seed), + calculate_momenta(state.positions, state.masses, state.system_idx, kT, seed), ) initial_state = MDState( @@ -99,7 +99,7 @@ def nve_init( masses=state.masses, cell=state.cell, pbc=state.pbc, - batch=state.batch, + system_idx=state.system_idx, atomic_numbers=state.atomic_numbers, ) return initial_state # noqa: RET504 @@ -116,7 +116,7 @@ def nve_update(state: MDState, dt: torch.Tensor = dt, **_) -> MDState: Args: state (MDState): Current system state containing positions, momenta, forces - dt (torch.Tensor): Integration timestep, either scalar or shape [n_batches] + dt (torch.Tensor): Integration timestep, either scalar or shape [n_systems] **_: Additional unused keyword arguments (for compatibility) Returns: diff --git a/torch_sim/integrators/nvt.py b/torch_sim/integrators/nvt.py index e446929d..ff9b7b4b 100644 --- a/torch_sim/integrators/nvt.py +++ b/torch_sim/integrators/nvt.py @@ -45,11 +45,11 @@ def nvt_langevin( # noqa: C901 Args: model (torch.nn.Module): Neural network model that computes energies and forces. Must return a dict with 'energy' and 'forces' keys. 
- dt (torch.Tensor): Integration timestep, either scalar or with shape [n_batches] + dt (torch.Tensor): Integration timestep, either scalar or with shape [n_systems] kT (torch.Tensor): Target temperature in energy units, either scalar or - with shape [n_batches] + with shape [n_systems] gamma (torch.Tensor, optional): Friction coefficient for Langevin thermostat, - either scalar or with shape [n_batches]. Defaults to 1/(100*dt). + either scalar or with shape [n_systems]. Defaults to 1/(100*dt). seed (int, optional): Random seed for reproducibility. Defaults to None. Returns: @@ -93,11 +93,11 @@ def ou_step( Args: state (MDState): Current system state containing positions, momenta, etc. - dt (torch.Tensor): Integration timestep, either scalar or shape [n_batches] + dt (torch.Tensor): Integration timestep, either scalar or shape [n_systems] kT (torch.Tensor): Target temperature in energy units, either scalar or - with shape [n_batches] + with shape [n_systems] gamma (torch.Tensor): Friction coefficient controlling noise strength, - either scalar or with shape [n_batches] + either scalar or with shape [n_systems] Returns: MDState: Updated state with new momenta after stochastic step @@ -114,12 +114,12 @@ def ou_step( c1 = torch.exp(-gamma * dt) if isinstance(kT, torch.Tensor) and len(kT.shape) > 0: - # kT is a tensor with shape (n_batches,) - kT = kT[state.batch] + # kT is a tensor with shape (n_systems,) + kT = kT[state.system_idx] - # Index c1 and c2 with state.batch to align shapes with state.momenta + # Index c1 and c2 with state.system_idx to align shapes with state.momenta if isinstance(c1, torch.Tensor) and len(c1.shape) > 0: - c1 = c1[state.batch] + c1 = c1[state.system_idx] c2 = torch.sqrt(kT * (1 - c1**2)).unsqueeze(-1) @@ -147,7 +147,7 @@ def langevin_init( state (SimState | StateDict): Either a SimState object or a dictionary containing positions, masses, cell, pbc, and other required state vars kT (torch.Tensor): Temperature in energy units for 
initializing momenta, - either scalar or with shape [n_batches] + either scalar or with shape [n_systems] seed (int, optional): Random seed for reproducibility Returns: @@ -167,7 +167,7 @@ def langevin_init( momenta = getattr( state, "momenta", - calculate_momenta(state.positions, state.masses, state.batch, kT, seed), + calculate_momenta(state.positions, state.masses, state.system_idx, kT, seed), ) initial_state = MDState( @@ -178,7 +178,7 @@ def langevin_init( masses=state.masses, cell=state.cell, pbc=state.pbc, - batch=state.batch, + system_idx=state.system_idx, atomic_numbers=state.atomic_numbers, ) return initial_state # noqa: RET504 @@ -202,11 +202,11 @@ def langevin_update( Args: state (MDState): Current system state containing positions, momenta, forces - dt (torch.Tensor): Integration timestep, either scalar or shape [n_batches] + dt (torch.Tensor): Integration timestep, either scalar or shape [n_systems] kT (torch.Tensor): Target temperature in energy units, either scalar or - with shape [n_batches] + with shape [n_systems] gamma (torch.Tensor): Friction coefficient for Langevin thermostat, - either scalar or with shape [n_batches] + either scalar or with shape [n_systems] Returns: MDState: Updated state after one complete Langevin step with new positions, @@ -363,21 +363,21 @@ def nvt_nose_hoover_init( model_output = model(state) momenta = kwargs.get( "momenta", - calculate_momenta(state.positions, state.masses, state.batch, kT, seed), + calculate_momenta(state.positions, state.masses, state.system_idx, kT, seed), ) - # Calculate initial kinetic energy per batch - KE = calc_kinetic_energy(momenta, state.masses, batch=state.batch) + # Calculate initial kinetic energy per system + KE = calc_kinetic_energy(momenta, state.masses, system_idx=state.system_idx) - # Calculate degrees of freedom per batch - n_atoms_per_batch = torch.bincount(state.batch) - dof_per_batch = ( - n_atoms_per_batch * state.positions.shape[-1] + # Calculate degrees of freedom per system 
+ n_atoms_per_system = torch.bincount(state.system_idx) + dof_per_system = ( + n_atoms_per_system * state.positions.shape[-1] ) # n_atoms * n_dimensions - # For now, sum the per-batch DOF as chain expects a single int + # For now, sum the per-system DOF as chain expects a single int # This is a limitation that should be addressed in the chain implementation - total_dof = int(dof_per_batch.sum().item()) + total_dof = int(dof_per_system.sum().item()) # Initialize state state = NVTNoseHooverState( @@ -431,8 +431,8 @@ def nvt_nose_hoover_update( # Full velocity Verlet step state = velocity_verlet(state=state, dt=dt, model=model) - # Update chain kinetic energy per batch - KE = calc_kinetic_energy(state.momenta, state.masses, batch=state.batch) + # Update chain kinetic energy per system + KE = calc_kinetic_energy(state.momenta, state.masses, system_idx=state.system_idx) chain.kinetic_energy = KE # Second half-step of chain evolution @@ -474,13 +474,13 @@ def nvt_nose_hoover_invariant( - Includes both physical and thermostat degrees of freedom - Useful for debugging thermostat behavior """ - # Calculate system energy terms per batch + # Calculate system energy terms per system e_pot = state.energy - e_kin = calc_kinetic_energy(state.momenta, state.masses, batch=state.batch) + e_kin = calc_kinetic_energy(state.momenta, state.masses, system_idx=state.system_idx) - # Get system degrees of freedom per batch - n_atoms_per_batch = torch.bincount(state.batch) - dof = n_atoms_per_batch * state.positions.shape[-1] # n_atoms * n_dimensions + # Get system degrees of freedom per system + n_atoms_per_system = torch.bincount(state.system_idx) + dof = n_atoms_per_system * state.positions.shape[-1] # n_atoms * n_dimensions # Start with system energy e_tot = e_pot + e_kin diff --git a/torch_sim/io.py b/torch_sim/io.py index f9505062..8f61bd49 100644 --- a/torch_sim/io.py +++ b/torch_sim/io.py @@ -33,7 +33,7 @@ def state_to_atoms(state: "ts.SimState") -> list["Atoms"]: state (SimState): 
Batched state containing positions, cell, and atomic numbers Returns: - list[Atoms]: ASE Atoms objects, one per batch + list[Atoms]: ASE Atoms objects, one per system Raises: ImportError: If ASE is not installed @@ -50,22 +50,22 @@ def state_to_atoms(state: "ts.SimState") -> list["Atoms"]: # Convert tensors to numpy arrays on CPU positions = state.positions.detach().cpu().numpy() - cell = state.cell.detach().cpu().numpy() # Shape: (n_batches, 3, 3) + cell = state.cell.detach().cpu().numpy() # Shape: (n_systems, 3, 3) atomic_numbers = state.atomic_numbers.detach().cpu().numpy() - batch = state.batch.detach().cpu().numpy() + system_idx = state.system_idx.detach().cpu().numpy() atoms_list = [] - for batch_idx in np.unique(batch): - mask = batch == batch_idx - batch_positions = positions[mask] - batch_numbers = atomic_numbers[mask] - batch_cell = cell[batch_idx].T # Transpose for ASE convention + for idx in np.unique(system_idx): + mask = system_idx == idx + system_positions = positions[mask] + system_numbers = atomic_numbers[mask] + system_cell = cell[idx].T # Transpose for ASE convention # Convert atomic numbers to chemical symbols - symbols = [chemical_symbols[z] for z in batch_numbers] + symbols = [chemical_symbols[z] for z in system_numbers] atoms = Atoms( - symbols=symbols, positions=batch_positions, cell=batch_cell, pbc=state.pbc + symbols=symbols, positions=system_positions, cell=system_cell, pbc=state.pbc ) atoms_list.append(atoms) @@ -79,7 +79,7 @@ def state_to_structures(state: "ts.SimState") -> list["Structure"]: state (SimState): Batched state containing positions, cell, and atomic numbers Returns: - list[Structure]: Pymatgen Structure objects, one per batch + list[Structure]: Pymatgen Structure objects, one per system Raises: ImportError: If Pymatgen is not installed @@ -98,29 +98,29 @@ def state_to_structures(state: "ts.SimState") -> list["Structure"]: # Convert tensors to numpy arrays on CPU positions = state.positions.detach().cpu().numpy() - cell = 
state.cell.detach().cpu().numpy() # Shape: (n_batches, 3, 3) + cell = state.cell.detach().cpu().numpy() # Shape: (n_systems, 3, 3) atomic_numbers = state.atomic_numbers.detach().cpu().numpy() - batch = state.batch.detach().cpu().numpy() + system_idx = state.system_idx.detach().cpu().numpy() - # Get unique batch indices and counts - unique_batches = np.unique(batch) + # Get unique system indices and counts + unique_systems = np.unique(system_idx) structures = [] - for batch_idx in unique_batches: - # Get mask for current batch - mask = batch == batch_idx - batch_positions = positions[mask] - batch_numbers = atomic_numbers[mask] - batch_cell = cell[batch_idx].T # Transpose for conventional form + for unique_system_idx in unique_systems: + # Get mask for current system + mask = system_idx == unique_system_idx + system_positions = positions[mask] + system_numbers = atomic_numbers[mask] + system_cell = cell[unique_system_idx].T # Transpose for conventional form # Create species list from atomic numbers - species = [Element.from_Z(z) for z in batch_numbers] + species = [Element.from_Z(z) for z in system_numbers] - # Create structure for this batch + # Create structure for this system struct = Structure( - lattice=Lattice(batch_cell), + lattice=Lattice(system_cell), species=species, - coords=batch_positions, + coords=system_positions, coords_are_cartesian=True, ) structures.append(struct) @@ -135,7 +135,7 @@ def state_to_phonopy(state: "ts.SimState") -> list["PhonopyAtoms"]: state (SimState): Batched state containing positions, cell, and atomic numbers Returns: - list[PhonopyAtoms]: PhonopyAtoms objects, one per batch + list[PhonopyAtoms]: PhonopyAtoms objects, one per system Raises: ImportError: If Phonopy is not installed @@ -152,24 +152,24 @@ def state_to_phonopy(state: "ts.SimState") -> list["PhonopyAtoms"]: # Convert tensors to numpy arrays on CPU positions = state.positions.detach().cpu().numpy() - cell = state.cell.detach().cpu().numpy() # Shape: (n_batches, 3, 3) 
+ cell = state.cell.detach().cpu().numpy() # Shape: (n_systems, 3, 3) atomic_numbers = state.atomic_numbers.detach().cpu().numpy() - batch = state.batch.detach().cpu().numpy() + system_idx = state.system_idx.detach().cpu().numpy() phonopy_atoms_list = [] - for batch_idx in np.unique(batch): - mask = batch == batch_idx - batch_positions = positions[mask] - batch_numbers = atomic_numbers[mask] - batch_cell = cell[batch_idx].T # Transpose for Phonopy convention + for idx in np.unique(system_idx): + mask = system_idx == idx + system_positions = positions[mask] + system_numbers = atomic_numbers[mask] + system_cell = cell[idx].T # Transpose for Phonopy convention # Convert atomic numbers to chemical symbols - symbols = [chemical_symbols[z] for z in batch_numbers] + symbols = [chemical_symbols[z] for z in system_numbers] phonopy_atoms_list.append( PhonopyAtoms( symbols=symbols, - positions=batch_positions, - cell=batch_cell, + positions=system_positions, + cell=system_cell, pbc=state.pbc, ) ) @@ -225,10 +225,10 @@ def atoms_to_state( np.stack([a.cell.array.T for a in atoms_list]), dtype=dtype, device=device ) - # Create batch indices using repeat_interleave - atoms_per_batch = torch.tensor([len(a) for a in atoms_list], device=device) - batch = torch.repeat_interleave( - torch.arange(len(atoms_list), device=device), atoms_per_batch + # Create system indices using repeat_interleave + atoms_per_system = torch.tensor([len(a) for a in atoms_list], device=device) + system_idx = torch.repeat_interleave( + torch.arange(len(atoms_list), device=device), atoms_per_system ) # Verify consistent pbc @@ -241,7 +241,7 @@ def atoms_to_state( cell=cell, pbc=all(atoms_list[0].pbc), atomic_numbers=atomic_numbers, - batch=batch, + system_idx=system_idx, ) @@ -297,10 +297,10 @@ def structures_to_state( device=device, ) - # Create batch indices - atoms_per_batch = torch.tensor([len(s) for s in struct_list], device=device) - batch = torch.repeat_interleave( - torch.arange(len(struct_list), 
device=device), atoms_per_batch + # Create system indices + atoms_per_system = torch.tensor([len(s) for s in struct_list], device=device) + system_idx = torch.repeat_interleave( + torch.arange(len(struct_list), device=device), atoms_per_system ) return ts.SimState( @@ -309,7 +309,7 @@ def structures_to_state( cell=cell, pbc=True, # Structures are always periodic atomic_numbers=atomic_numbers, - batch=batch, + system_idx=system_idx, ) @@ -368,10 +368,10 @@ def phonopy_to_state( np.stack([a.cell.T for a in phonopy_atoms_list]), dtype=dtype, device=device ) - # Create batch indices using repeat_interleave - atoms_per_batch = torch.tensor([len(a) for a in phonopy_atoms_list], device=device) - batch = torch.repeat_interleave( - torch.arange(len(phonopy_atoms_list), device=device), atoms_per_batch + # Create system indices using repeat_interleave + atoms_per_system = torch.tensor([len(a) for a in phonopy_atoms_list], device=device) + system_idx = torch.repeat_interleave( + torch.arange(len(phonopy_atoms_list), device=device), atoms_per_system ) """ @@ -387,5 +387,5 @@ def phonopy_to_state( cell=cell, pbc=True, atomic_numbers=atomic_numbers, - batch=batch, + system_idx=system_idx, ) diff --git a/torch_sim/math.py b/torch_sim/math.py index 40228ba4..c45d34c6 100644 --- a/torch_sim/math.py +++ b/torch_sim/math.py @@ -1000,7 +1000,7 @@ def batched_vdot( batch_indices: Tensor of shape [N_total_entities] indicating batch membership. Returns: - Tensor: shape [n_batches] where each element is the sum(x_i * y_i) + Tensor: shape [n_systems] where each element is the sum(x_i * y_i) for entities belonging to that batch, summed over all components D and all entities in the batch. 
""" diff --git a/torch_sim/models/fairchem.py b/torch_sim/models/fairchem.py index 3197c6aa..b911668e 100644 --- a/torch_sim/models/fairchem.py +++ b/torch_sim/models/fairchem.py @@ -349,8 +349,8 @@ def forward(self, state: ts.SimState | StateDict) -> dict: if state.device != self._device: state = state.to(self._device) - if state.batch is None: - state.batch = torch.zeros(state.positions.shape[0], dtype=torch.int) + if state.system_idx is None: + state.system_idx = torch.zeros(state.positions.shape[0], dtype=torch.int) if self.pbc != state.pbc: raise ValueError( @@ -358,8 +358,8 @@ def forward(self, state: ts.SimState | StateDict) -> dict: "For FairChemModel PBC needs to be defined in the model class." ) - natoms = torch.bincount(state.batch) - fixed = torch.zeros((state.batch.size(0), natoms.sum()), dtype=torch.int) + natoms = torch.bincount(state.system_idx) + fixed = torch.zeros((state.system_idx.size(0), natoms.sum()), dtype=torch.int) data_list = [] for i, (n, c) in enumerate( zip(natoms, torch.cumsum(natoms, dim=0), strict=False) diff --git a/torch_sim/models/graphpes.py b/torch_sim/models/graphpes.py index af88e94d..fe51fe01 100644 --- a/torch_sim/models/graphpes.py +++ b/torch_sim/models/graphpes.py @@ -67,8 +67,8 @@ def state_to_atomic_graph(state: ts.SimState, cutoff: torch.Tensor) -> AtomicGra """ graphs = [] - for i in range(state.n_batches): - batch_mask = state.batch == i + for i in range(state.n_systems): + batch_mask = state.system_idx == i R = state.positions[batch_mask] Z = state.atomic_numbers[batch_mask] cell = state.row_vector_cell[i] diff --git a/torch_sim/models/interface.py b/torch_sim/models/interface.py index a6ba406b..a70736ab 100644 --- a/torch_sim/models/interface.py +++ b/torch_sim/models/interface.py @@ -65,9 +65,9 @@ class ModelInterface(ABC): output = model(sim_state) # Access computed properties - energy = output["energy"] # Shape: [n_batches] + energy = output["energy"] # Shape: [n_systems] forces = output["forces"] # Shape: 
[n_atoms, 3] - stress = output["stress"] # Shape: [n_batches, 3, 3] + stress = output["stress"] # Shape: [n_systems, 3, 3] ``` """ @@ -174,16 +174,16 @@ def forward(self, state: SimState | StateDict, **kwargs) -> dict[str, torch.Tens dictionary is dependent on the model but typically must contain the following keys: - "positions": Atomic positions with shape [n_atoms, 3] - - "cell": Unit cell vectors with shape [n_batches, 3, 3] - - "batch": Batch indices for each atom with shape [n_atoms] + - "cell": Unit cell vectors with shape [n_systems, 3, 3] + - "system_idx": System indices for each atom with shape [n_atoms] - "atomic_numbers": Atomic numbers with shape [n_atoms] (optional) **kwargs: Additional model-specific parameters. Returns: dict[str, torch.Tensor]: Computed properties: - - "energy": Potential energy with shape [n_batches] + - "energy": Potential energy with shape [n_systems] - "forces": Atomic forces with shape [n_atoms, 3] - - "stress": Stress tensor with shape [n_batches, 3, 3] (if + - "stress": Stress tensor with shape [n_systems, 3, 3] (if compute_stress=True) - May include additional model-specific outputs @@ -256,7 +256,7 @@ def validate_model_outputs( # noqa: C901, PLR0915 og_positions = sim_state.positions.clone() og_cell = sim_state.cell.clone() - og_batch = sim_state.batch.clone() + og_system_idx = sim_state.system_idx.clone() og_atomic_numbers = sim_state.atomic_numbers.clone() model_output = model.forward(sim_state) @@ -266,8 +266,8 @@ def validate_model_outputs( # noqa: C901, PLR0915 raise ValueError(f"{og_positions=} != {sim_state.positions=}") if not torch.allclose(og_cell, sim_state.cell): raise ValueError(f"{og_cell=} != {sim_state.cell=}") - if not torch.allclose(og_batch, sim_state.batch): - raise ValueError(f"{og_batch=} != {sim_state.batch=}") + if not torch.allclose(og_system_idx, sim_state.system_idx): + raise ValueError(f"{og_system_idx=} != {sim_state.system_idx=}") if not torch.allclose(og_atomic_numbers, 
sim_state.atomic_numbers): raise ValueError(f"{og_atomic_numbers=} != {sim_state.atomic_numbers=}") diff --git a/torch_sim/models/lennard_jones.py b/torch_sim/models/lennard_jones.py index 2e98ad22..3b7a5b81 100644 --- a/torch_sim/models/lennard_jones.py +++ b/torch_sim/models/lennard_jones.py @@ -357,7 +357,7 @@ def forward(self, state: ts.SimState | StateDict) -> dict[str, torch.Tensor]: """Compute Lennard-Jones energies, forces, and stresses for a system. Main entry point for Lennard-Jones calculations that handles batched states by - dispatching each batch to the unbatched implementation and combining results. + dispatching each system to the unbatched implementation and combining results. Args: state (SimState | StateDict): Input state containing atomic positions, @@ -366,10 +366,10 @@ def forward(self, state: ts.SimState | StateDict) -> dict[str, torch.Tensor]: Returns: dict[str, torch.Tensor]: Computed properties: - - "energy": Potential energy with shape [n_batches] + - "energy": Potential energy with shape [n_systems] - "forces": Atomic forces with shape [n_atoms, 3] (if compute_forces=True) - - "stress": Stress tensor with shape [n_batches, 3, 3] (if + - "stress": Stress tensor with shape [n_systems, 3, 3] (if compute_stress=True) - "energies": Per-atom energies with shape [n_atoms] (if per_atom_energies=True) @@ -377,7 +377,7 @@ def forward(self, state: ts.SimState | StateDict) -> dict[str, torch.Tensor]: per_atom_stresses=True) Raises: - ValueError: If batch cannot be inferred for multi-cell systems. + ValueError: If system cannot be inferred for multi-cell systems. 
Example:: @@ -385,19 +385,19 @@ def forward(self, state: ts.SimState | StateDict) -> dict[str, torch.Tensor]: model = LennardJonesModel(compute_stress=True) results = model(sim_state) - energy = results["energy"] # Shape: [n_batches] + energy = results["energy"] # Shape: [n_systems] forces = results["forces"] # Shape: [n_atoms, 3] - stress = results["stress"] # Shape: [n_batches, 3, 3] + stress = results["stress"] # Shape: [n_systems, 3, 3] energies = results["energies"] # Shape: [n_atoms] stresses = results["stresses"] # Shape: [n_atoms, 3, 3] """ if isinstance(state, dict): state = ts.SimState(**state, masses=torch.ones_like(state["positions"])) - if state.batch is None and state.cell.shape[0] > 1: - raise ValueError("Batch can only be inferred for batch size 1.") + if state.system_idx is None and state.cell.shape[0] > 1: + raise ValueError("System can only be inferred for batch size 1.") - outputs = [self.unbatched_forward(state[i]) for i in range(state.n_batches)] + outputs = [self.unbatched_forward(state[i]) for i in range(state.n_systems)] properties = outputs[0] # we always return tensors diff --git a/torch_sim/models/mace.py b/torch_sim/models/mace.py index d9d92b9b..4ec3ec3f 100644 --- a/torch_sim/models/mace.py +++ b/torch_sim/models/mace.py @@ -91,7 +91,7 @@ class MaceModel(torch.nn.Module, ModelInterface): model (torch.nn.Module): The underlying MACE neural network model. neighbor_list_fn (Callable): Function used to compute neighbor lists. atomic_numbers (torch.Tensor): Atomic numbers with shape [n_atoms]. - batch (torch.Tensor): Batch indices with shape [n_atoms]. + system_idx (torch.Tensor): System indices with shape [n_atoms]. n_systems (int): Number of systems in the batch. n_atoms_per_system (list[int]): Number of atoms in each system. 
ptr (torch.Tensor): Pointers to the start of each system in the batch with @@ -112,13 +112,13 @@ def __init__( compute_stress: bool = True, enable_cueq: bool = False, atomic_numbers: torch.Tensor | None = None, - batch: torch.Tensor | None = None, + system_idx: torch.Tensor | None = None, ) -> None: """Initialize the MACE model for energy and force calculations. Sets up the MACE model for energy, force, and stress calculations within the TorchSim framework. The model can be initialized with atomic numbers - and batch indices, or these can be provided during the forward pass. + and system indices, or these can be provided during the forward pass. Args: model (str | Path | torch.nn.Module | None): The MACE neural network model, @@ -129,9 +129,9 @@ def __init__( Defaults to torch.float64. atomic_numbers (torch.Tensor | None): Atomic numbers with shape [n_atoms]. If provided at initialization, cannot be provided again during forward. - batch (torch.Tensor | None): Batch indices with shape [n_atoms] indicating - which system each atom belongs to. If not provided with atomic_numbers, - all atoms are assumed to be in the same system. + system_idx (torch.Tensor | None): System indices with shape [n_atoms] + indicating which system each atom belongs to. If not provided with + atomic_numbers, all atoms are assumed to be in the same system. neighbor_list_fn (Callable): Function to compute neighbor lists. Defaults to vesin_nl_ts. compute_forces (bool): Whether to compute forces. Defaults to True. 
@@ -186,38 +186,40 @@ def __init__( # Set up batch information if atomic numbers are provided if atomic_numbers is not None: - if batch is None: + if system_idx is None: # If batch is not provided, assume all atoms belong to same system - batch = torch.zeros( + system_idx = torch.zeros( len(atomic_numbers), dtype=torch.long, device=self.device ) - self.setup_from_batch(atomic_numbers, batch) + self.setup_from_batch(atomic_numbers, system_idx) - def setup_from_batch(self, atomic_numbers: torch.Tensor, batch: torch.Tensor) -> None: - """Set up internal state from atomic numbers and batch indices. + def setup_from_batch( + self, atomic_numbers: torch.Tensor, system_idx: torch.Tensor + ) -> None: + """Set up internal state from atomic numbers and system indices. - Processes the atomic numbers and batch indices to prepare the model for + Processes the atomic numbers and system indices to prepare the model for forward pass calculations. Creates the necessary data structures for batched processing of multiple systems. Args: atomic_numbers (torch.Tensor): Atomic numbers tensor with shape [n_atoms]. - batch (torch.Tensor): Batch indices tensor with shape [n_atoms] indicating - which system each atom belongs to. + system_idx (torch.Tensor): System indices tensor with shape [n_atoms] + indicating which system each atom belongs to. 
""" self.atomic_numbers = atomic_numbers - self.batch = batch + self.system_idx = system_idx # Determine number of systems and atoms per system - self.n_systems = batch.max().item() + 1 + self.n_systems = system_idx.max().item() + 1 - # Create ptr tensor for batch boundaries + # Create ptr tensor for system boundaries self.n_atoms_per_system = [] ptr = [0] - for b in range(self.n_systems): - batch_mask = batch == b - n_atoms = batch_mask.sum().item() + for i in range(self.n_systems): + system_mask = system_idx == i + n_atoms = system_mask.sum().item() self.n_atoms_per_system.append(n_atoms) ptr.append(ptr[-1] + n_atoms) @@ -260,7 +262,7 @@ def forward( # noqa: C901 Raises: ValueError: If atomic numbers are not provided either in the constructor or in the forward pass, or if provided in both places. - ValueError: If batch indices are not provided when needed. + ValueError: If system indices are not provided when needed. """ # Extract required data from input if isinstance(state, dict): @@ -276,13 +278,13 @@ def forward( # noqa: C901 "Atomic numbers cannot be provided in both the constructor and forward." 
) - # Use batch from init if not provided - if state.batch is None: - if not hasattr(self, "batch"): + # Use system_idx from init if not provided + if state.system_idx is None: + if not hasattr(self, "system_idx"): raise ValueError( - "Batch indices must be provided if not set during initialization" + "System indices must be provided if not set during initialization" ) - state.batch = self.batch + state.system_idx = self.system_idx # Update batch information if new atomic numbers are provided if ( @@ -293,7 +295,7 @@ def forward( # noqa: C901 getattr(self, "atomic_numbers", torch.zeros(0, device=self.device)), ) ): - self.setup_from_batch(state.atomic_numbers, state.batch) + self.setup_from_batch(state.atomic_numbers, state.system_idx) # Process each system's neighbor list separately edge_indices = [] @@ -303,7 +305,7 @@ def forward( # noqa: C901 # TODO (AG): Currently doesn't work for batched neighbor lists for b in range(self.n_systems): - batch_mask = state.batch == b + batch_mask = state.system_idx == b # Calculate neighbor list for this system edge_idx, shifts_idx = self.neighbor_list_fn( positions=state.positions[batch_mask], @@ -332,7 +334,7 @@ def forward( # noqa: C901 dict( ptr=self.ptr, node_attrs=self.node_attrs, - batch=state.batch, + batch=state.system_idx, pbc=state.pbc, cell=state.row_vector_cell, positions=state.positions, diff --git a/torch_sim/models/metatomic.py b/torch_sim/models/metatomic.py index 31d26966..7aafb8d0 100644 --- a/torch_sim/models/metatomic.py +++ b/torch_sim/models/metatomic.py @@ -74,7 +74,7 @@ def __init__( Sets up a metatomic model for energy, force, and stress calculations within the TorchSim framework. The model can be initialized with atomic numbers - and batch indices, or these can be provided during the forward pass. + and system indices, or these can be provided during the forward pass. 
Args: model (str | Path | None): Path to the metatomic model file or a @@ -200,7 +200,7 @@ def forward( # noqa: C901, PLR0915 systems: list[System] = [] strains = [] for b in range(len(cell)): - system_mask = state.batch == b + system_mask = state.system_idx == b system_positions = positions[system_mask] system_cell = cell[b] system_pbc = torch.tensor( diff --git a/torch_sim/models/morse.py b/torch_sim/models/morse.py index 8b446a3c..97564361 100644 --- a/torch_sim/models/morse.py +++ b/torch_sim/models/morse.py @@ -356,10 +356,10 @@ def forward(self, state: ts.SimState | StateDict) -> dict[str, torch.Tensor]: Returns: dict[str, torch.Tensor]: Computed properties: - - "energy": Potential energy with shape [n_batches] + - "energy": Potential energy with shape [n_systems] - "forces": Atomic forces with shape [n_atoms, 3] (if compute_forces=True) - - "stress": Stress tensor with shape [n_batches, 3, 3] + - "stress": Stress tensor with shape [n_systems, 3, 3] (if compute_stress=True) - May include additional outputs based on configuration @@ -372,17 +372,19 @@ def forward(self, state: ts.SimState | StateDict) -> dict[str, torch.Tensor]: model = MorseModel(compute_forces=True) results = model(sim_state) - energy = results["energy"] # Shape: [n_batches] + energy = results["energy"] # Shape: [n_systems] forces = results["forces"] # Shape: [n_atoms, 3] ``` """ if isinstance(state, dict): state = ts.SimState(**state, masses=torch.ones_like(state["positions"])) - if state.batch is None and state.cell.shape[0] > 1: - raise ValueError("Batch can only be inferred for batch size 1.") + if state.system_idx is None and state.cell.shape[0] > 1: + raise ValueError( + "system_idx can only be inferred if there is only one system." 
+ ) - outputs = [self.unbatched_forward(state[i]) for i in range(state.n_batches)] + outputs = [self.unbatched_forward(state[i]) for i in range(state.n_systems)] properties = outputs[0] # we always return tensors diff --git a/torch_sim/models/orb.py b/torch_sim/models/orb.py index 551a9b01..cf015fb2 100644 --- a/torch_sim/models/orb.py +++ b/torch_sim/models/orb.py @@ -102,7 +102,7 @@ def state_to_atom_graphs( # noqa: PLR0915 system_config = SystemConfig(radius=6.0, max_num_neighbors=20) # Handle batch information if present - n_node = torch.bincount(state.batch) + n_node = torch.bincount(state.system_idx) # Set default dtype if not provided output_dtype = torch.get_default_dtype() if output_dtype is None else output_dtype @@ -143,7 +143,7 @@ def state_to_atom_graphs( # noqa: PLR0915 if wrap and (torch.any(row_vector_cell != 0) and torch.any(pbc)): positions = feat_util.batch_map_to_pbc_cell(positions, row_vector_cell, n_node) - n_systems = state.batch.max().item() + 1 + n_systems = state.system_idx.max().item() + 1 # Prepare lists to collect data from each system all_edges = [] @@ -157,7 +157,7 @@ def state_to_atom_graphs( # noqa: PLR0915 # Process each system in a single loop offset = 0 for i in range(n_systems): - batch_mask = state.batch == i + batch_mask = state.system_idx == i positions_per_system = positions[batch_mask] atomic_numbers_per_system = atomic_numbers[batch_mask] atomic_numbers_embedding_per_system = atomic_numbers_embedding[batch_mask] diff --git a/torch_sim/models/particle_life.py b/torch_sim/models/particle_life.py index 48110a86..39c5606c 100644 --- a/torch_sim/models/particle_life.py +++ b/torch_sim/models/particle_life.py @@ -223,10 +223,10 @@ def forward(self, state: ts.SimState | StateDict) -> dict[str, torch.Tensor]: Returns: dict[str, torch.Tensor]: Computed properties: - - "energy": Potential energy with shape [n_batches] + - "energy": Potential energy with shape [n_systems] - "forces": Atomic forces with shape [n_atoms, 3] (if 
compute_forces=True) - - "stress": Stress tensor with shape [n_batches, 3, 3] (if + - "stress": Stress tensor with shape [n_systems, 3, 3] (if compute_stress=True) - "energies": Per-atom energies with shape [n_atoms] (if per_atom_energies=True) @@ -239,10 +239,12 @@ def forward(self, state: ts.SimState | StateDict) -> dict[str, torch.Tensor]: if isinstance(state, dict): state = ts.SimState(**state, masses=torch.ones_like(state["positions"])) - if state.batch is None and state.cell.shape[0] > 1: - raise ValueError("Batch can only be inferred for batch size 1.") + if state.system_idx is None and state.cell.shape[0] > 1: + raise ValueError( + "system_idx can only be inferred if there is only one system." + ) - outputs = [self.unbatched_forward(state[i]) for i in range(state.n_batches)] + outputs = [self.unbatched_forward(state[i]) for i in range(state.n_systems)] properties = outputs[0] # we always return tensors diff --git a/torch_sim/models/sevennet.py b/torch_sim/models/sevennet.py index 45e59f9b..c4e3d96b 100644 --- a/torch_sim/models/sevennet.py +++ b/torch_sim/models/sevennet.py @@ -181,8 +181,8 @@ def forward(self, state: ts.SimState | StateDict) -> dict[str, torch.Tensor]: state = state.clone() data_list = [] - for b in range(state.batch.max().item() + 1): - batch_mask = state.batch == b + for b in range(state.system_idx.max().item() + 1): + batch_mask = state.system_idx == b pos = state.positions[batch_mask] # SevenNet uses row vector cell convention for neighbor list @@ -245,7 +245,7 @@ def forward(self, state: ts.SimState | StateDict) -> dict[str, torch.Tensor]: results["energy"] = energy.detach() else: results["energy"] = torch.zeros( - state.batch.max().item() + 1, device=self.device + state.system_idx.max().item() + 1, device=self.device ) forces = output[key.PRED_FORCE] diff --git a/torch_sim/models/soft_sphere.py b/torch_sim/models/soft_sphere.py index 8cbc0e7f..4397b2eb 100644 --- a/torch_sim/models/soft_sphere.py +++ b/torch_sim/models/soft_sphere.py 
@@ -381,7 +381,7 @@ def forward(self, state: ts.SimState | StateDict) -> dict[str, torch.Tensor]: """Compute soft sphere potential energies, forces, and stresses for a system. Main entry point for soft sphere potential calculations that handles batched - states by dispatching each batch to the unbatched implementation and combining + states by dispatching each system to the unbatched implementation and combining results. Args: @@ -391,15 +391,15 @@ def forward(self, state: ts.SimState | StateDict) -> dict[str, torch.Tensor]: Returns: dict[str, torch.Tensor]: Computed properties: - - "energy": Potential energy with shape [n_batches] + - "energy": Potential energy with shape [n_systems] - "forces": Atomic forces with shape [n_atoms, 3] (if compute_forces=True) - - "stress": Stress tensor with shape [n_batches, 3, 3] + - "stress": Stress tensor with shape [n_systems, 3, 3] (if compute_stress=True) - May include additional outputs based on configuration Raises: - ValueError: If batch cannot be inferred for multi-cell systems. + ValueError: If system indices cannot be inferred for multi-cell systems. 
Examples: ```py @@ -407,18 +407,20 @@ def forward(self, state: ts.SimState | StateDict) -> dict[str, torch.Tensor]: model = SoftSphereModel(compute_forces=True) results = model(sim_state) - energy = results["energy"] # Shape: [n_batches] + energy = results["energy"] # Shape: [n_systems] forces = results["forces"] # Shape: [n_atoms, 3] ``` """ if isinstance(state, dict): state = ts.SimState(**state, masses=torch.ones_like(state["positions"])) - # Handle batch indices if not provided - if state.batch is None and state.cell.shape[0] > 1: - raise ValueError("Batch can only be inferred for batch size 1.") + # Handle System indices if not provided + if state.system_idx is None and state.cell.shape[0] > 1: + raise ValueError( + "system_idx can only be inferred if there is only one system" + ) - outputs = [self.unbatched_forward(state[i]) for i in range(state.n_batches)] + outputs = [self.unbatched_forward(state[i]) for i in range(state.n_systems)] properties = outputs[0] # Combine results @@ -816,10 +818,10 @@ def forward(self, state: ts.SimState | StateDict) -> dict[str, torch.Tensor]: Returns: dict[str, torch.Tensor]: Computed properties: - - "energy": Potential energy with shape [n_batches] + - "energy": Potential energy with shape [n_systems] - "forces": Atomic forces with shape [n_atoms, 3] (if compute_forces=True) - - "stress": Stress tensor with shape [n_batches, 3, 3] + - "stress": Stress tensor with shape [n_systems, 3, 3] (if compute_stress=True) - May include additional outputs based on configuration @@ -854,11 +856,13 @@ def forward(self, state: ts.SimState | StateDict) -> dict[str, torch.Tensor]: elif state.pbc != self.pbc: raise ValueError("PBC mismatch between model and state") - # Handle batch indices if not provided - if state.batch is None and state.cell.shape[0] > 1: - raise ValueError("Batch can only be inferred for batch size 1.") + # Handle system indices if not provided + if state.system_idx is None and state.cell.shape[0] > 1: + raise ValueError( + 
"system_idx can only be inferred if there is only one system" + ) - outputs = [self.unbatched_forward(state[i]) for i in range(state.n_batches)] + outputs = [self.unbatched_forward(state[i]) for i in range(state.n_systems)] properties = outputs[0] # Combine results diff --git a/torch_sim/monte_carlo.py b/torch_sim/monte_carlo.py index 389ce687..90929068 100644 --- a/torch_sim/monte_carlo.py +++ b/torch_sim/monte_carlo.py @@ -41,7 +41,7 @@ def generate_swaps( ) -> torch.Tensor: """Generate atom swaps for a given batched system. - Generates proposed swaps between atoms of different types within the same batch. + Generates proposed swaps between atoms of different types within the same system. The function ensures that swaps only occur between atoms with different atomic numbers. @@ -51,48 +51,48 @@ def generate_swaps( reproducibility. Defaults to None. Returns: - torch.Tensor: A tensor of proposed swaps with shape [n_batches, 2], + torch.Tensor: A tensor of proposed swaps with shape [n_systems, 2], where each row contains indices of atoms to be swapped """ - batch = state.batch + system = state.system_idx atomic_numbers = state.atomic_numbers - batch_lengths = batch.bincount() + system_lengths = system.bincount() - # change batch_lengths to batch - batch = torch.repeat_interleave( - torch.arange(len(batch_lengths), device=batch.device), batch_lengths + # change system_lengths to system + system = torch.repeat_interleave( + torch.arange(len(system_lengths), device=system.device), system_lengths ) # Create ragged weights tensor without loops - max_length = torch.max(batch_lengths).item() - n_batches = len(batch_lengths) + max_length = torch.max(system_lengths).item() + n_systems = len(system_lengths) - # Create a range tensor for each batch - range_tensor = torch.arange(max_length, device=batch.device).expand( - n_batches, max_length + # Create a range tensor for each system + range_tensor = torch.arange(max_length, device=system.device).expand( + n_systems, max_length 
) - # Create a mask where values are less than the batch length - batch_lengths_expanded = batch_lengths.unsqueeze(1).expand(n_batches, max_length) - weights = (range_tensor < batch_lengths_expanded).float() + # Create a mask where values are less than the max system length + system_lengths_expanded = system_lengths.unsqueeze(1).expand(n_systems, max_length) + weights = (range_tensor < system_lengths_expanded).float() first_index = torch.multinomial(weights, 1, replacement=False, generator=generator) - # Process each batch - we need this loop because of ragged batches - batch_starts = batch_lengths.cumsum(dim=0) - batch_lengths[0] + # Process each system - we need this loop because of ragged systems + system_starts = system_lengths.cumsum(dim=0) - system_lengths[0] - for b in range(n_batches): + for b in range(n_systems): # Get global index of selected atom - first_idx = first_index[b, 0].item() + batch_starts[b].item() + first_idx = first_index[b, 0].item() + system_starts[b].item() first_type = atomic_numbers[first_idx] - # Get indices of atoms in this batch - batch_start = batch_starts[b].item() - batch_end = batch_start + batch_lengths[b].item() + # Get indices of atoms in this system + system_start = system_starts[b].item() + system_end = system_start + system_lengths[b].item() # Create mask for same-type atoms - same_type = atomic_numbers[batch_start:batch_end] == first_type + same_type = atomic_numbers[system_start:system_end] == first_type # Zero out weights for same-type atoms (accounting for padding) weights[b, : len(same_type)][same_type] = 0.0 @@ -100,7 +100,7 @@ def generate_swaps( second_index = torch.multinomial(weights, 1, replacement=False, generator=generator) zeroed_swaps = torch.concatenate([first_index, second_index], dim=1) - return zeroed_swaps + (batch_lengths.cumsum(dim=0) - batch_lengths[0]).unsqueeze(1) + return zeroed_swaps + (system_lengths.cumsum(dim=0) - system_lengths[0]).unsqueeze(1) def swaps_to_permutation(swaps: torch.Tensor, 
n_atoms: int) -> torch.Tensor: @@ -124,21 +124,21 @@ def swaps_to_permutation(swaps: torch.Tensor, n_atoms: int) -> torch.Tensor: return permutation -def validate_permutation(permutation: torch.Tensor, batch: torch.Tensor) -> None: - """Validate that permutations only swap atoms within the same batch. +def validate_permutation(permutation: torch.Tensor, system_idx: torch.Tensor) -> None: + """Validate that permutations only swap atoms within the same system. - Confirms that no swaps are attempted between atoms in different batches, + Confirms that no swaps are attempted between atoms in different systems, which would lead to physically invalid configurations. Args: permutation (torch.Tensor): Permutation tensor of shape [n_atoms] - batch (torch.Tensor): Batch assignments for each atom of shape [n_atoms] + system_idx (torch.Tensor): system_idx for each atom of shape [n_atoms] Raises: - ValueError: If any swaps are between atoms in different batches + ValueError: If any swaps are between atoms in different systems """ - if not torch.all(batch == batch[permutation]): - raise ValueError("Swaps must be between atoms in the same batch") + if not torch.all(system_idx == system_idx[permutation]): + raise ValueError("Swaps must be between atoms in the same system") def metropolis_criterion( @@ -233,7 +233,7 @@ def init_swap_mc_state(state: SimState) -> SwapMCState: cell=state.cell, pbc=state.pbc, atomic_numbers=state.atomic_numbers, - batch=state.batch, + system_idx=state.system_idx, energy=model_output["energy"], last_permutation=torch.arange(state.n_atoms, device=state.device), ) @@ -260,12 +260,12 @@ def swap_monte_carlo_step( Notes: The function handles batched systems and ensures that swaps only occur - within the same batch. + within the same system. 
""" swaps = generate_swaps(state, generator=generator) permutation = swaps_to_permutation(swaps, state.n_atoms) - validate_permutation(permutation, state.batch) + validate_permutation(permutation, state.system_idx) energies_old = state.energy.clone() state.positions = state.positions[permutation].clone() diff --git a/torch_sim/neighbors.py b/torch_sim/neighbors.py index cd4c439d..9c997eac 100644 --- a/torch_sim/neighbors.py +++ b/torch_sim/neighbors.py @@ -766,7 +766,7 @@ def torch_nl_linked_cell( positions: torch.Tensor, cell: torch.Tensor, pbc: torch.Tensor, - batch: torch.Tensor, + system_idx: torch.Tensor, self_interaction: bool = False, # noqa: FBT001, FBT002 ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: """Compute the neighbor list for a set of atomic structures using the linked @@ -784,7 +784,7 @@ def torch_nl_linked_cell( pbc (torch.Tensor [n_structure, 3] bool): A tensor indicating the periodic boundary conditions to apply. Partial PBC are not supported yet. - batch (torch.Tensor [n_atom,] torch.long): + system_idx (torch.Tensor [n_atom,] torch.long): A tensor containing the index of the structure to which each atom belongs. self_interaction (bool, optional): A flag to indicate whether to keep the center atoms as their own neighbors. 
@@ -806,7 +806,7 @@ def torch_nl_linked_cell( References: - https://github.com/felixmusil/torch_nl """ - n_atoms = torch.bincount(batch) + n_atoms = torch.bincount(system_idx) mapping, batch_mapping, shifts_idx = transforms.build_linked_cell_neighborhood( positions, cell, pbc, cutoff, n_atoms, self_interaction ) diff --git a/torch_sim/optimizers.py b/torch_sim/optimizers.py index 5c98dafe..98a83d64 100644 --- a/torch_sim/optimizers.py +++ b/torch_sim/optimizers.py @@ -44,12 +44,12 @@ class GDState(SimState): Attributes: positions (torch.Tensor): Atomic positions with shape [n_atoms, 3] masses (torch.Tensor): Atomic masses with shape [n_atoms] - cell (torch.Tensor): Unit cell vectors with shape [n_batches, 3, 3] + cell (torch.Tensor): Unit cell vectors with shape [n_systems, 3, 3] pbc (bool): Whether to use periodic boundary conditions atomic_numbers (torch.Tensor): Atomic numbers with shape [n_atoms] - batch (torch.Tensor): Batch indices with shape [n_atoms] + system_idx (torch.Tensor): System indices with shape [n_atoms] forces (torch.Tensor): Forces acting on atoms with shape [n_atoms, 3] - energy (torch.Tensor): Potential energy with shape [n_batches] + energy (torch.Tensor): Potential energy with shape [n_systems] """ forces: torch.Tensor @@ -68,8 +68,8 @@ def gradient_descent( Args: model (torch.nn.Module): Model that computes energies and forces lr (torch.Tensor | float): Learning rate(s) for optimization. 
Can be a single - float applied to all batches or a tensor with shape [n_batches] for - batch-specific rates + float applied to all systems or a tensor with shape [n_systems] for + system-specific rates Returns: tuple: A pair of functions: @@ -113,7 +113,7 @@ def gd_init( cell=state.cell, pbc=state.pbc, atomic_numbers=atomic_numbers, - batch=state.batch, + system_idx=state.system_idx, ) def gd_step(state: GDState, lr: torch.Tensor = lr) -> GDState: @@ -129,9 +129,9 @@ def gd_step(state: GDState, lr: torch.Tensor = lr) -> GDState: """ # Get per-atom learning rates by mapping batch learning rates to atoms if isinstance(lr, float): - lr = torch.full((state.n_batches,), lr, device=device, dtype=dtype) + lr = torch.full((state.n_systems,), lr, device=device, dtype=dtype) - atom_lr = lr[state.batch].unsqueeze(-1) # shape: (total_atoms, 1) + atom_lr = lr[state.system_idx].unsqueeze(-1) # shape: (total_atoms, 1) # Update positions using forces and per-atom learning rates state.positions = state.positions + atom_lr * state.forces @@ -160,25 +160,25 @@ class UnitCellGDState(GDState, DeformGradMixin): # Inherited from GDState positions (torch.Tensor): Atomic positions with shape [n_atoms, 3] masses (torch.Tensor): Atomic masses with shape [n_atoms] - cell (torch.Tensor): Unit cell vectors with shape [n_batches, 3, 3] + cell (torch.Tensor): Unit cell vectors with shape [n_systems, 3, 3] pbc (bool): Whether to use periodic boundary conditions atomic_numbers (torch.Tensor): Atomic numbers with shape [n_atoms] - batch (torch.Tensor): Batch indices with shape [n_atoms] + system_idx (torch.Tensor): System indices with shape [n_atoms] forces (torch.Tensor): Forces acting on atoms with shape [n_atoms, 3] - energy (torch.Tensor): Potential energy with shape [n_batches] + energy (torch.Tensor): Potential energy with shape [n_systems] # Additional attributes for cell optimization - stress (torch.Tensor): Stress tensor with shape [n_batches, 3, 3] + stress (torch.Tensor): Stress tensor 
with shape [n_systems, 3, 3] reference_cell (torch.Tensor): Reference unit cells with shape - [n_batches, 3, 3] + [n_systems, 3, 3] cell_factor (torch.Tensor): Scaling factor for cell optimization with shape - [n_batches, 1, 1] + [n_systems, 1, 1] hydrostatic_strain (bool): Whether to only allow hydrostatic deformation constant_volume (bool): Whether to maintain constant volume - pressure (torch.Tensor): Applied pressure tensor with shape [n_batches, 3, 3] - cell_positions (torch.Tensor): Cell positions with shape [n_batches, 3, 3] - cell_forces (torch.Tensor): Cell forces with shape [n_batches, 3, 3] - cell_masses (torch.Tensor): Cell masses with shape [n_batches, 3] + pressure (torch.Tensor): Applied pressure tensor with shape [n_systems, 3, 3] + cell_positions (torch.Tensor): Cell positions with shape [n_systems, 3, 3] + cell_forces (torch.Tensor): Cell forces with shape [n_systems, 3, 3] + cell_masses (torch.Tensor): Cell masses with shape [n_systems, 3] """ # Required attributes not in BatchedGDState @@ -224,7 +224,7 @@ def unit_cell_gradient_descent( # noqa: PLR0915, C901 is 0.01. cell_lr (float): Learning rate for unit cell optimization. Default is 0.1. cell_factor (float | torch.Tensor | None): Scaling factor for cell - optimization. If None, defaults to number of atoms per batch + optimization. If None, defaults to number of atoms per system hydrostatic_strain (bool): Whether to only allow hydrostatic deformation (isotropic scaling). Default is False. 
constant_volume (bool): Whether to maintain constant volume during optimization @@ -270,25 +270,25 @@ def gd_init( if not isinstance(state, SimState): state = SimState(**state) - n_batches = state.n_batches + n_systems = state.n_systems # Setup cell_factor if cell_factor is None: - # Count atoms per batch - _, counts = torch.unique(state.batch, return_counts=True) + # Count atoms per system + _, counts = torch.unique(state.system_idx, return_counts=True) cell_factor = counts.to(dtype=dtype) if isinstance(cell_factor, int | float): - # Use same factor for all batches + # Use same factor for all systems cell_factor = torch.full( - (state.n_batches,), cell_factor, device=device, dtype=dtype + (state.n_systems,), cell_factor, device=device, dtype=dtype ) - # Reshape to (n_batches, 1, 1) for broadcasting - cell_factor = cell_factor.view(n_batches, 1, 1) + # Reshape to (n_systems, 1, 1) for broadcasting + cell_factor = cell_factor.view(n_systems, 1, 1) scalar_pressure = torch.full( - (state.n_batches, 1, 1), scalar_pressure, device=device, dtype=dtype + (state.n_systems, 1, 1), scalar_pressure, device=device, dtype=dtype ) # Setup pressure tensor pressure = scalar_pressure * torch.eye(3, device=device) @@ -297,11 +297,11 @@ def gd_init( model_output = model(state) energy = model_output["energy"] forces = model_output["forces"] - stress = model_output["stress"] # Already shape: (n_batches, 3, 3) + stress = model_output["stress"] # Already shape: (n_systems, 3, 3) # Create cell masses cell_masses = torch.ones( - (state.n_batches, 3), device=device, dtype=dtype + (state.n_systems, 3), device=device, dtype=dtype ) # One mass per cell DOF # Get current deformation gradient @@ -311,27 +311,27 @@ def gd_init( # Calculate cell positions cell_factor_expanded = cell_factor.expand( - state.n_batches, 3, 1 - ) # shape: (n_batches, 3, 1) + state.n_systems, 3, 1 + ) # shape: (n_systems, 3, 1) cell_positions = ( - cur_deform_grad.reshape(state.n_batches, 3, 3) * cell_factor_expanded - 
) # shape: (n_batches, 3, 3) + cur_deform_grad.reshape(state.n_systems, 3, 3) * cell_factor_expanded + ) # shape: (n_systems, 3, 3) # Calculate virial - volumes = torch.linalg.det(state.cell).view(n_batches, 1, 1) + volumes = torch.linalg.det(state.cell).view(n_systems, 1, 1) virial = -volumes * (stress + pressure) if hydrostatic_strain: diag_mean = torch.diagonal(virial, dim1=1, dim2=2).mean(dim=1, keepdim=True) virial = diag_mean.unsqueeze(-1) * torch.eye(3, device=device).unsqueeze( 0 - ).expand(state.n_batches, -1, -1) + ).expand(state.n_systems, -1, -1) if constant_volume: diag_mean = torch.diagonal(virial, dim1=1, dim2=2).mean(dim=1, keepdim=True) virial = virial - diag_mean.unsqueeze(-1) * torch.eye( 3, device=device - ).unsqueeze(0).expand(state.n_batches, -1, -1) + ).unsqueeze(0).expand(state.n_systems, -1, -1) return UnitCellGDState( positions=state.positions, @@ -347,7 +347,7 @@ def gd_init( constant_volume=constant_volume, pressure=pressure, atomic_numbers=state.atomic_numbers, - batch=state.batch, + system_idx=state.system_idx, cell_positions=cell_positions, cell_forces=virial / cell_factor, cell_masses=cell_masses, @@ -371,29 +371,29 @@ def gd_step( Updated UnitCellGDState after one optimization step """ # Get dimensions - n_batches = state.n_batches + n_systems = state.n_systems - # Get per-atom learning rates by mapping batch learning rates to atoms + # Get per-atom learning rates by mapping system learning rates to atoms if isinstance(positions_lr, float): positions_lr = torch.full( - (state.n_batches,), positions_lr, device=device, dtype=dtype + (state.n_systems,), positions_lr, device=device, dtype=dtype ) if isinstance(cell_lr, float): - cell_lr = torch.full((state.n_batches,), cell_lr, device=device, dtype=dtype) + cell_lr = torch.full((state.n_systems,), cell_lr, device=device, dtype=dtype) # Get current deformation gradient cur_deform_grad = state.deform_grad() # Calculate cell positions from deformation gradient - cell_factor_expanded = 
state.cell_factor.expand(n_batches, 3, 1) + cell_factor_expanded = state.cell_factor.expand(n_systems, 3, 1) cell_positions = ( - cur_deform_grad.reshape(n_batches, 3, 3) * cell_factor_expanded - ) # shape: (n_batches, 3, 3) + cur_deform_grad.reshape(n_systems, 3, 3) * cell_factor_expanded + ) # shape: (n_systems, 3, 3) # Get per-atom and per-cell learning rates - atom_wise_lr = positions_lr[state.batch].unsqueeze(-1) - cell_wise_lr = cell_lr.view(n_batches, 1, 1) # shape: (n_batches, 1, 1) + atom_wise_lr = positions_lr[state.system_idx].unsqueeze(-1) + cell_wise_lr = cell_lr.view(n_systems, 1, 1) # shape: (n_systems, 1, 1) # Update atomic and cell positions atomic_positions_new = state.positions + atom_wise_lr * state.forces @@ -415,18 +415,18 @@ def gd_step( state.stress = model_output["stress"] # Calculate virial for cell forces - volumes = torch.linalg.det(new_row_vector_cell).view(n_batches, 1, 1) + volumes = torch.linalg.det(new_row_vector_cell).view(n_systems, 1, 1) virial = -volumes * (state.stress + state.pressure) if state.hydrostatic_strain: diag_mean = torch.diagonal(virial, dim1=1, dim2=2).mean(dim=1, keepdim=True) virial = diag_mean.unsqueeze(-1) * torch.eye(3, device=device).unsqueeze( 0 - ).expand(n_batches, -1, -1) + ).expand(n_systems, -1, -1) if state.constant_volume: diag_mean = torch.diagonal(virial, dim1=1, dim2=2).mean(dim=1, keepdim=True) virial = virial - diag_mean.unsqueeze(-1) * torch.eye( 3, device=device - ).unsqueeze(0).expand(n_batches, -1, -1) + ).unsqueeze(0).expand(n_systems, -1, -1) # Update cell forces state.cell_positions = cell_positions_new @@ -450,21 +450,21 @@ class FireState(SimState): # Inherited from SimState positions (torch.Tensor): Atomic positions with shape [n_atoms, 3] masses (torch.Tensor): Atomic masses with shape [n_atoms] - cell (torch.Tensor): Unit cell vectors with shape [n_batches, 3, 3] + cell (torch.Tensor): Unit cell vectors with shape [n_systems, 3, 3] pbc (bool): Whether to use periodic boundary 
conditions atomic_numbers (torch.Tensor): Atomic numbers with shape [n_atoms] - batch (torch.Tensor): Batch indices with shape [n_atoms] + system_idx (torch.Tensor): System indices with shape [n_atoms] # Atomic quantities forces (torch.Tensor): Forces on atoms with shape [n_atoms, 3] velocities (torch.Tensor): Atomic velocities with shape [n_atoms, 3] - energy (torch.Tensor): Energy per batch with shape [n_batches] + energy (torch.Tensor): Energy per system with shape [n_systems] # FIRE optimization parameters - dt (torch.Tensor): Current timestep per batch with shape [n_batches] - alpha (torch.Tensor): Current mixing parameter per batch with shape [n_batches] - n_pos (torch.Tensor): Number of positive power steps per batch with shape - [n_batches] + dt (torch.Tensor): Current timestep per system with shape [n_systems] + alpha (torch.Tensor): Current mixing parameter per system with shape [n_systems] + n_pos (torch.Tensor): Number of positive power steps per system with shape + [n_systems] Properties: momenta (torch.Tensor): Atomwise momenta of the system with shape [n_atoms, 3], @@ -558,8 +558,8 @@ def fire_init( Args: state: Input state as SimState object or state parameter dict - dt_start: Initial timestep per batch - alpha_start: Initial mixing parameter per batch + dt_start: Initial timestep per system + alpha_start: Initial mixing parameter per system Returns: FireState with initialized optimization tensors @@ -568,18 +568,18 @@ def fire_init( state = SimState(**state) # Get dimensions - n_batches = state.n_batches + n_systems = state.n_systems # Get initial forces and energy from model model_output = model(state) - energy = model_output["energy"] # [n_batches] + energy = model_output["energy"] # [n_systems] forces = model_output["forces"] # [n_total_atoms, 3] # Setup parameters - dt_start = torch.full((n_batches,), dt_start, device=device, dtype=dtype) - alpha_start = torch.full((n_batches,), alpha_start, device=device, dtype=dtype) - n_pos = 
torch.zeros((n_batches,), device=device, dtype=torch.int32) + dt_start = torch.full((n_systems,), dt_start, device=device, dtype=dtype) + alpha_start = torch.full((n_systems,), alpha_start, device=device, dtype=dtype) + n_pos = torch.zeros((n_systems,), device=device, dtype=torch.int32) return FireState( # Create initial state # Copy SimState attributes @@ -587,7 +587,7 @@ def fire_init( masses=state.masses.clone(), cell=state.cell.clone(), atomic_numbers=state.atomic_numbers.clone(), - batch=state.batch.clone(), + system_idx=state.system_idx.clone(), pbc=state.pbc, velocities=None, forces=forces, @@ -630,36 +630,36 @@ class UnitCellFireState(SimState, DeformGradMixin): # Inherited from SimState positions (torch.Tensor): Atomic positions with shape [n_atoms, 3] masses (torch.Tensor): Atomic masses with shape [n_atoms] - cell (torch.Tensor): Unit cell vectors with shape [n_batches, 3, 3] + cell (torch.Tensor): Unit cell vectors with shape [n_systems, 3, 3] pbc (bool): Whether to use periodic boundary conditions atomic_numbers (torch.Tensor): Atomic numbers with shape [n_atoms] - batch (torch.Tensor): Batch indices with shape [n_atoms] + system_idx (torch.Tensor): System indices with shape [n_atoms] # Atomic quantities forces (torch.Tensor): Forces on atoms with shape [n_atoms, 3] velocities (torch.Tensor): Atomic velocities with shape [n_atoms, 3] - energy (torch.Tensor): Energy per batch with shape [n_batches] - stress (torch.Tensor): Stress tensor with shape [n_batches, 3, 3] + energy (torch.Tensor): Energy per system with shape [n_systems] + stress (torch.Tensor): Stress tensor with shape [n_systems, 3, 3] # Cell quantities - cell_positions (torch.Tensor): Cell positions with shape [n_batches, 3, 3] - cell_velocities (torch.Tensor): Cell velocities with shape [n_batches, 3, 3] - cell_forces (torch.Tensor): Cell forces with shape [n_batches, 3, 3] - cell_masses (torch.Tensor): Cell masses with shape [n_batches, 3] + cell_positions (torch.Tensor): Cell positions 
with shape [n_systems, 3, 3] + cell_velocities (torch.Tensor): Cell velocities with shape [n_systems, 3, 3] + cell_forces (torch.Tensor): Cell forces with shape [n_systems, 3, 3] + cell_masses (torch.Tensor): Cell masses with shape [n_systems, 3] # Cell optimization parameters - reference_cell (torch.Tensor): Original unit cells with shape [n_batches, 3, 3] + reference_cell (torch.Tensor): Original unit cells with shape [n_systems, 3, 3] cell_factor (torch.Tensor): Cell optimization scaling factor with shape - [n_batches, 1, 1] - pressure (torch.Tensor): Applied pressure tensor with shape [n_batches, 3, 3] + [n_systems, 1, 1] + pressure (torch.Tensor): Applied pressure tensor with shape [n_systems, 3, 3] hydrostatic_strain (bool): Whether to only allow hydrostatic deformation constant_volume (bool): Whether to maintain constant volume # FIRE optimization parameters - dt (torch.Tensor): Current timestep per batch with shape [n_batches] - alpha (torch.Tensor): Current mixing parameter per batch with shape [n_batches] - n_pos (torch.Tensor): Number of positive power steps per batch with shape - [n_batches] + dt (torch.Tensor): Current timestep per system with shape [n_systems] + alpha (torch.Tensor): Current mixing parameter per system with shape [n_systems] + n_pos (torch.Tensor): Number of positive power steps per system with shape + [n_systems] Properties: momenta (torch.Tensor): Atomwise momenta of the system with shape [n_atoms, 3], @@ -728,7 +728,7 @@ def unit_cell_fire( alpha_start (float): Initial velocity mixing parameter f_alpha (float): Factor for mixing parameter decrease cell_factor (float | None): Scaling factor for cell optimization. 
- If None, defaults to number of atoms per batch + If None, defaults to number of atoms per system hydrostatic_strain (bool): Whether to only allow hydrostatic deformation (isotropic scaling) constant_volume (bool): Whether to maintain constant volume during optimization @@ -782,11 +782,11 @@ def fire_init( Args: state: Input state as SimState object or state parameter dict - cell_factor: Cell optimization scaling factor. If None, uses atoms per batch. - Single value or tensor of shape [n_batches]. + cell_factor: Cell optimization scaling factor. If None, uses atoms per system. + Single value or tensor of shape [n_systems]. scalar_pressure: Applied pressure in energy units - dt_start: Initial timestep per batch - alpha_start: Initial mixing parameter per batch + dt_start: Initial timestep per system + alpha_start: Initial mixing parameter per system Returns: UnitCellFireState with initialized optimization tensors @@ -795,64 +795,64 @@ def fire_init( state = SimState(**state) # Get dimensions - n_batches = state.n_batches + n_systems = state.n_systems # Setup cell_factor if cell_factor is None: - # Count atoms per batch - _, counts = torch.unique(state.batch, return_counts=True) + # Count atoms per system + _, counts = torch.unique(state.system_idx, return_counts=True) cell_factor = counts.to(dtype=dtype) if isinstance(cell_factor, int | float): - # Use same factor for all batches + # Use same factor for all systems cell_factor = torch.full( - (state.n_batches,), cell_factor, device=device, dtype=dtype + (state.n_systems,), cell_factor, device=device, dtype=dtype ) - # Reshape to (n_batches, 1, 1) for broadcasting - cell_factor = cell_factor.view(n_batches, 1, 1) + # Reshape to (n_systems, 1, 1) for broadcasting + cell_factor = cell_factor.view(n_systems, 1, 1) # Setup pressure tensor pressure = scalar_pressure * torch.eye(3, device=device, dtype=dtype) - pressure = pressure.unsqueeze(0).expand(n_batches, -1, -1) + pressure = pressure.unsqueeze(0).expand(n_systems, 
-1, -1) # Get initial forces and energy from model model_output = model(state) - energy = model_output["energy"] # [n_batches] + energy = model_output["energy"] # [n_systems] forces = model_output["forces"] # [n_total_atoms, 3] - stress = model_output["stress"] # [n_batches, 3, 3] + stress = model_output["stress"] # [n_systems, 3, 3] - volumes = torch.linalg.det(state.cell).view(n_batches, 1, 1) + volumes = torch.linalg.det(state.cell).view(n_systems, 1, 1) virial = -volumes * (stress + pressure) # P is P_ext * I if hydrostatic_strain: diag_mean = torch.diagonal(virial, dim1=1, dim2=2).mean(dim=1, keepdim=True) virial = diag_mean.unsqueeze(-1) * torch.eye(3, device=device).unsqueeze( 0 - ).expand(n_batches, -1, -1) + ).expand(n_systems, -1, -1) if constant_volume: diag_mean = torch.diagonal(virial, dim1=1, dim2=2).mean(dim=1, keepdim=True) virial = virial - diag_mean.unsqueeze(-1) * torch.eye( 3, device=device - ).unsqueeze(0).expand(n_batches, -1, -1) + ).unsqueeze(0).expand(n_systems, -1, -1) cell_forces = virial / cell_factor - # Sum masses per batch using segment_reduce + # Sum masses per system using segment_reduce # TODO (AG): check this - batch_counts = torch.bincount(state.batch) + system_counts = torch.bincount(state.system_idx) cell_masses = torch.segment_reduce( - state.masses, reduce="sum", lengths=batch_counts - ) # shape: (n_batches,) - cell_masses = cell_masses.unsqueeze(-1).expand(-1, 3) # shape: (n_batches, 3) + state.masses, reduce="sum", lengths=system_counts + ) # shape: (n_systems,) + cell_masses = cell_masses.unsqueeze(-1).expand(-1, 3) # shape: (n_systems, 3) # Setup parameters - dt_start = torch.full((n_batches,), dt_start, device=device, dtype=dtype) - alpha_start = torch.full((n_batches,), alpha_start, device=device, dtype=dtype) - n_pos = torch.zeros((n_batches,), device=device, dtype=torch.int32) + dt_start = torch.full((n_systems,), dt_start, device=device, dtype=dtype) + alpha_start = torch.full((n_systems,), alpha_start, 
device=device, dtype=dtype) + n_pos = torch.zeros((n_systems,), device=device, dtype=torch.int32) return UnitCellFireState( # Create initial state # Copy SimState attributes @@ -860,14 +860,14 @@ def fire_init( masses=state.masses.clone(), cell=state.cell.clone(), atomic_numbers=state.atomic_numbers.clone(), - batch=state.batch.clone(), + system_idx=state.system_idx.clone(), pbc=state.pbc, velocities=None, forces=forces, energy=energy, stress=stress, # Cell attributes - cell_positions=torch.zeros(n_batches, 3, 3, device=device, dtype=dtype), + cell_positions=torch.zeros(n_systems, 3, 3, device=device, dtype=dtype), cell_velocities=None, cell_forces=cell_forces, cell_masses=cell_masses, @@ -913,37 +913,37 @@ class FrechetCellFIREState(SimState, DeformGradMixin): # Inherited from SimState positions (torch.Tensor): Atomic positions with shape [n_atoms, 3] masses (torch.Tensor): Atomic masses with shape [n_atoms] - cell (torch.Tensor): Unit cell vectors with shape [n_batches, 3, 3] + cell (torch.Tensor): Unit cell vectors with shape [n_systems, 3, 3] pbc (bool): Whether to use periodic boundary conditions atomic_numbers (torch.Tensor): Atomic numbers with shape [n_atoms] - batch (torch.Tensor): Batch indices with shape [n_atoms] + system_idx (torch.Tensor): System indices with shape [n_atoms] # Additional atomic quantities forces (torch.Tensor): Forces on atoms with shape [n_atoms, 3] - energy (torch.Tensor): Energy per batch with shape [n_batches] + energy (torch.Tensor): Energy per system with shape [n_systems] velocities (torch.Tensor): Atomic velocities with shape [n_atoms, 3] - stress (torch.Tensor): Stress tensor with shape [n_batches, 3, 3] + stress (torch.Tensor): Stress tensor with shape [n_systems, 3, 3] # Optimization-specific attributes - reference_cell (torch.Tensor): Original unit cell with shape [n_batches, 3, 3] + reference_cell (torch.Tensor): Original unit cell with shape [n_systems, 3, 3] cell_factor (torch.Tensor): Scaling factor for cell 
optimization with shape - [n_batches, 1, 1] - pressure (torch.Tensor): Applied pressure tensor with shape [n_batches, 3, 3] + [n_systems, 1, 1] + pressure (torch.Tensor): Applied pressure tensor with shape [n_systems, 3, 3] hydrostatic_strain (bool): Whether to only allow hydrostatic deformation constant_volume (bool): Whether to maintain constant volume # Cell attributes using log parameterization cell_positions (torch.Tensor): Cell positions using log parameterization with - shape [n_batches, 3, 3] - cell_velocities (torch.Tensor): Cell velocities with shape [n_batches, 3, 3] - cell_forces (torch.Tensor): Cell forces with shape [n_batches, 3, 3] - cell_masses (torch.Tensor): Cell masses with shape [n_batches, 3] + shape [n_systems, 3, 3] + cell_velocities (torch.Tensor): Cell velocities with shape [n_systems, 3, 3] + cell_forces (torch.Tensor): Cell forces with shape [n_systems, 3, 3] + cell_masses (torch.Tensor): Cell masses with shape [n_systems, 3] # FIRE algorithm parameters - dt (torch.Tensor): Current timestep per batch with shape [n_batches] - alpha (torch.Tensor): Current mixing parameter per batch with shape [n_batches] - n_pos (torch.Tensor): Number of positive power steps per batch with shape - [n_batches] + dt (torch.Tensor): Current timestep per system with shape [n_systems] + alpha (torch.Tensor): Current mixing parameter per system with shape [n_systems] + n_pos (torch.Tensor): Number of positive power steps per system with shape + [n_systems] Properties: momenta (torch.Tensor): Atomwise momenta of the system with shape [n_atoms, 3], @@ -1013,7 +1013,7 @@ def frechet_cell_fire( alpha_start (float): Initial velocity mixing parameter f_alpha (float): Factor for mixing parameter decrease cell_factor (float | None): Scaling factor for cell optimization. 
- If None, defaults to number of atoms per batch + If None, defaults to number of atoms per system hydrostatic_strain (bool): Whether to only allow hydrostatic deformation (isotropic scaling) constant_volume (bool): Whether to maintain constant volume during optimization @@ -1067,11 +1067,11 @@ def fire_init( Args: state: Input state as SimState object or state parameter dict - cell_factor: Cell optimization scaling factor. If None, uses atoms per batch. - Single value or tensor of shape [n_batches]. + cell_factor: Cell optimization scaling factor. If None, uses atoms per system. + Single value or tensor of shape [n_systems]. scalar_pressure: Applied pressure in energy units - dt_start: Initial timestep per batch - alpha_start: Initial mixing parameter per batch + dt_start: Initial timestep per system + alpha_start: Initial mixing parameter per system Returns: FrechetCellFIREState with initialized optimization tensors @@ -1080,78 +1080,78 @@ def fire_init( state = SimState(**state) # Get dimensions - n_batches = state.n_batches + n_systems = state.n_systems # Setup cell_factor if cell_factor is None: - # Count atoms per batch - _, counts = torch.unique(state.batch, return_counts=True) + # Count atoms per system + _, counts = torch.unique(state.system_idx, return_counts=True) cell_factor = counts.to(dtype=dtype) if isinstance(cell_factor, int | float): - # Use same factor for all batches + # Use same factor for all systems cell_factor = torch.full( - (state.n_batches,), cell_factor, device=device, dtype=dtype + (state.n_systems,), cell_factor, device=device, dtype=dtype ) - # Reshape to (n_batches, 1, 1) for broadcasting - cell_factor = cell_factor.view(n_batches, 1, 1) + # Reshape to (n_systems, 1, 1) for broadcasting + cell_factor = cell_factor.view(n_systems, 1, 1) # Setup pressure tensor pressure = scalar_pressure * torch.eye(3, device=device, dtype=dtype) - pressure = pressure.unsqueeze(0).expand(n_batches, -1, -1) + pressure = 
pressure.unsqueeze(0).expand(n_systems, -1, -1) # Get initial forces and energy from model model_output = model(state) - energy = model_output["energy"] # [n_batches] + energy = model_output["energy"] # [n_systems] forces = model_output["forces"] # [n_total_atoms, 3] - stress = model_output["stress"] # [n_batches, 3, 3] + stress = model_output["stress"] # [n_systems, 3, 3] # Calculate initial cell positions using matrix logarithm # Calculate current deformation gradient (identity matrix at start) cur_deform_grad = DeformGradMixin._deform_grad( # noqa: SLF001 state.row_vector_cell, state.row_vector_cell - ) # shape: (n_batches, 3, 3) + ) # shape: (n_systems, 3, 3) # For identity matrix, logm gives zero matrix # Initialize cell positions to zeros - cell_positions = torch.zeros((n_batches, 3, 3), device=device, dtype=dtype) + cell_positions = torch.zeros((n_systems, 3, 3), device=device, dtype=dtype) # Calculate virial for cell forces - volumes = torch.linalg.det(state.cell).view(n_batches, 1, 1) + volumes = torch.linalg.det(state.cell).view(n_systems, 1, 1) virial = -volumes * (stress + pressure) # P is P_ext * I if hydrostatic_strain: diag_mean = torch.diagonal(virial, dim1=1, dim2=2).mean(dim=1, keepdim=True) virial = diag_mean.unsqueeze(-1) * torch.eye(3, device=device).unsqueeze( 0 - ).expand(n_batches, -1, -1) + ).expand(n_systems, -1, -1) if constant_volume: diag_mean = torch.diagonal(virial, dim1=1, dim2=2).mean(dim=1, keepdim=True) virial = virial - diag_mean.unsqueeze(-1) * torch.eye( 3, device=device - ).unsqueeze(0).expand(n_batches, -1, -1) + ).unsqueeze(0).expand(n_systems, -1, -1) # Calculate UCF-style cell gradient ucf_cell_grad = torch.zeros_like(virial) - for b in range(n_batches): + for b in range(n_systems): ucf_cell_grad[b] = virial[b] @ torch.linalg.inv(cur_deform_grad[b].T) # Calculate cell forces using Frechet derivative approach (all zeros for identity) cell_forces = ucf_cell_grad / cell_factor - # Sum masses per batch - batch_counts = 
torch.bincount(state.batch) + # Sum masses per system + system_counts = torch.bincount(state.system_idx) cell_masses = torch.segment_reduce( - state.masses, reduce="sum", lengths=batch_counts - ) # shape: (n_batches,) - cell_masses = cell_masses.unsqueeze(-1).expand(-1, 3) # shape: (n_batches, 3) + state.masses, reduce="sum", lengths=system_counts + ) # shape: (n_systems,) + cell_masses = cell_masses.unsqueeze(-1).expand(-1, 3) # shape: (n_systems, 3) # Setup parameters - dt_start = torch.full((n_batches,), dt_start, device=device, dtype=dtype) - alpha_start = torch.full((n_batches,), alpha_start, device=device, dtype=dtype) - n_pos = torch.zeros((n_batches,), device=device, dtype=torch.int32) + dt_start = torch.full((n_systems,), dt_start, device=device, dtype=dtype) + alpha_start = torch.full((n_systems,), alpha_start, device=device, dtype=dtype) + n_pos = torch.zeros((n_systems,), device=device, dtype=torch.int32) return FrechetCellFIREState( # Create initial state # Copy SimState attributes @@ -1159,7 +1159,7 @@ def fire_init( masses=state.masses, cell=state.cell, atomic_numbers=state.atomic_numbers, - batch=state.batch, + system_idx=state.system_idx, pbc=state.pbc, velocities=None, forces=forces, @@ -1239,7 +1239,7 @@ def _vv_fire_step( # noqa: C901, PLR0915 Returns: Updated state after performing one VV-FIRE step. """ - n_batches = state.n_batches + n_systems = state.n_systems device = state.positions.device dtype = state.positions.dtype deform_grad_new: torch.Tensor | None = None @@ -1252,14 +1252,14 @@ def _vv_fire_step( # noqa: C901, PLR0915 f"Cell optimization requires one of {get_args(AnyFireCellState)}." 
) state.cell_velocities = torch.zeros( - (n_batches, 3, 3), device=device, dtype=dtype + (n_systems, 3, 3), device=device, dtype=dtype ) - alpha_start_batch = torch.full( - (n_batches,), alpha_start.item(), device=device, dtype=dtype + alpha_start_system = torch.full( + (n_systems,), alpha_start.item(), device=device, dtype=dtype ) - atom_wise_dt = state.dt[state.batch].unsqueeze(-1) + atom_wise_dt = state.dt[state.system_idx].unsqueeze(-1) state.velocities += 0.5 * atom_wise_dt * state.forces / state.masses.unsqueeze(-1) if is_cell_optimization: @@ -1271,13 +1271,13 @@ def _vv_fire_step( # noqa: C901, PLR0915 state.positions = state.positions + atom_wise_dt * state.velocities if is_cell_optimization: - cell_factor_reshaped = state.cell_factor.view(n_batches, 1, 1) + cell_factor_reshaped = state.cell_factor.view(n_systems, 1, 1) if is_frechet: if not isinstance(state, expected_cls := FrechetCellFIREState): raise ValueError(f"{type(state)=} must be a {expected_cls.__name__}") cur_deform_grad = state.deform_grad() deform_grad_log = torch.zeros_like(cur_deform_grad) - for b in range(n_batches): + for b in range(n_systems): deform_grad_log[b] = tsm.matrix_log_33(cur_deform_grad[b]) cell_positions_log_scaled = deform_grad_log * cell_factor_reshaped @@ -1295,9 +1295,9 @@ def _vv_fire_step( # noqa: C901, PLR0915 if not isinstance(state, expected_cls := UnitCellFireState): raise ValueError(f"{type(state)=} must be a {expected_cls.__name__}") cur_deform_grad = state.deform_grad() - cell_factor_expanded = state.cell_factor.expand(n_batches, 3, 1) + cell_factor_expanded = state.cell_factor.expand(n_systems, 3, 1) current_cell_positions_scaled = ( - cur_deform_grad.view(n_batches, 3, 3) * cell_factor_expanded + cur_deform_grad.view(n_systems, 3, 3) * cell_factor_expanded ) cell_positions_scaled_new = ( @@ -1316,19 +1316,19 @@ def _vv_fire_step( # noqa: C901, PLR0915 if is_cell_optimization: state.stress = results["stress"] - volumes = 
torch.linalg.det(state.cell).view(n_batches, 1, 1) + volumes = torch.linalg.det(state.cell).view(n_systems, 1, 1) virial = -volumes * (state.stress + state.pressure) if state.hydrostatic_strain: diag_mean = torch.diagonal(virial, dim1=1, dim2=2).mean(dim=1, keepdim=True) virial = diag_mean.unsqueeze(-1) * torch.eye( 3, device=device, dtype=dtype - ).unsqueeze(0).expand(n_batches, -1, -1) + ).unsqueeze(0).expand(n_systems, -1, -1) if state.constant_volume: diag_mean = torch.diagonal(virial, dim1=1, dim2=2).mean(dim=1, keepdim=True) virial = virial - diag_mean.unsqueeze(-1) * torch.eye( 3, device=device, dtype=dtype - ).unsqueeze(0).expand(n_batches, -1, -1) + ).unsqueeze(0).expand(n_systems, -1, -1) if is_frechet: if not isinstance(state, expected_cls := FrechetCellFIREState): @@ -1341,7 +1341,7 @@ def _vv_fire_step( # noqa: C901, PLR0915 directions[idx, mu, nu] = 1.0 new_cell_forces = torch.zeros_like(ucf_cell_grad) - for b in range(n_batches): + for b in range(n_systems): expm_derivs = torch.stack( [ tsm.expm_frechet( @@ -1366,49 +1366,51 @@ def _vv_fire_step( # noqa: C901, PLR0915 0.5 * cell_wise_dt * state.cell_forces / state.cell_masses.unsqueeze(-1) ) - batch_power = tsm.batched_vdot(state.forces, state.velocities, state.batch) + system_power = tsm.batched_vdot(state.forces, state.velocities, state.system_idx) if is_cell_optimization: - batch_power += (state.cell_forces * state.cell_velocities).sum(dim=(1, 2)) + system_power += (state.cell_forces * state.cell_velocities).sum(dim=(1, 2)) # 2. 
Update dt, alpha, n_pos - pos_mask_batch = batch_power > 0.0 - neg_mask_batch = ~pos_mask_batch + pos_mask_system = system_power > 0.0 + neg_mask_system = ~pos_mask_system - state.n_pos[pos_mask_batch] += 1 - inc_mask = (state.n_pos > n_min) & pos_mask_batch + state.n_pos[pos_mask_system] += 1 + inc_mask = (state.n_pos > n_min) & pos_mask_system state.dt[inc_mask] = torch.minimum(state.dt[inc_mask] * f_inc, dt_max) state.alpha[inc_mask] *= f_alpha - state.dt[neg_mask_batch] *= f_dec - state.alpha[neg_mask_batch] = alpha_start_batch[neg_mask_batch] - state.n_pos[neg_mask_batch] = 0 + state.dt[neg_mask_system] *= f_dec + state.alpha[neg_mask_system] = alpha_start_system[neg_mask_system] + state.n_pos[neg_mask_system] = 0 - v_scaling_batch = tsm.batched_vdot(state.velocities, state.velocities, state.batch) - f_scaling_batch = tsm.batched_vdot(state.forces, state.forces, state.batch) + v_scaling_system = tsm.batched_vdot( + state.velocities, state.velocities, state.system_idx + ) + f_scaling_system = tsm.batched_vdot(state.forces, state.forces, state.system_idx) if is_cell_optimization: - v_scaling_batch += state.cell_velocities.pow(2).sum(dim=(1, 2)) - f_scaling_batch += state.cell_forces.pow(2).sum(dim=(1, 2)) + v_scaling_system += state.cell_velocities.pow(2).sum(dim=(1, 2)) + f_scaling_system += state.cell_forces.pow(2).sum(dim=(1, 2)) - v_scaling_cell = torch.sqrt(v_scaling_batch.view(n_batches, 1, 1)) - f_scaling_cell = torch.sqrt(f_scaling_batch.view(n_batches, 1, 1)) + v_scaling_cell = torch.sqrt(v_scaling_system.view(n_systems, 1, 1)) + f_scaling_cell = torch.sqrt(f_scaling_system.view(n_systems, 1, 1)) v_mixing_cell = state.cell_forces / (f_scaling_cell + eps) * v_scaling_cell - alpha_cell_bc = state.alpha.view(n_batches, 1, 1) + alpha_cell_bc = state.alpha.view(n_systems, 1, 1) state.cell_velocities = torch.where( - pos_mask_batch.view(n_batches, 1, 1), + pos_mask_system.view(n_systems, 1, 1), (1.0 - alpha_cell_bc) * state.cell_velocities + alpha_cell_bc * 
v_mixing_cell, torch.zeros_like(state.cell_velocities), ) - v_scaling_atom = torch.sqrt(v_scaling_batch[state.batch].unsqueeze(-1)) - f_scaling_atom = torch.sqrt(f_scaling_batch[state.batch].unsqueeze(-1)) + v_scaling_atom = torch.sqrt(v_scaling_system[state.system_idx].unsqueeze(-1)) + f_scaling_atom = torch.sqrt(f_scaling_system[state.system_idx].unsqueeze(-1)) v_mixing_atom = state.forces * (v_scaling_atom / (f_scaling_atom + eps)) - alpha_atom = state.alpha[state.batch].unsqueeze(-1) # per-atom alpha + alpha_atom = state.alpha[state.system_idx].unsqueeze(-1) # per-atom alpha state.velocities = torch.where( - pos_mask_batch[state.batch].unsqueeze(-1), + pos_mask_system[state.system_idx].unsqueeze(-1), (1.0 - alpha_atom) * state.velocities + alpha_atom * v_mixing_atom, torch.zeros_like(state.velocities), ) @@ -1455,7 +1457,7 @@ def _ase_fire_step( # noqa: C901, PLR0915 Updated state after performing one ASE-FIRE step. """ device, dtype = state.positions.device, state.positions.dtype - n_batches = state.n_batches + n_systems = state.n_systems cur_deform_grad = None # Initialize cur_deform_grad to prevent UnboundLocalError @@ -1468,92 +1470,92 @@ def _ase_fire_step( # noqa: C901, PLR0915 f"Cell optimization requires one of {get_args(AnyFireCellState)}." ) state.cell_velocities = torch.zeros( - (n_batches, 3, 3), device=device, dtype=dtype + (n_systems, 3, 3), device=device, dtype=dtype ) cur_deform_grad = state.deform_grad() else: - alpha_start_batch = torch.full( - (n_batches,), alpha_start.item(), device=device, dtype=dtype + alpha_start_system = torch.full( + (n_systems,), alpha_start.item(), device=device, dtype=dtype ) if is_cell_optimization: cur_deform_grad = state.deform_grad() forces = torch.bmm( - state.forces.unsqueeze(1), cur_deform_grad[state.batch] + state.forces.unsqueeze(1), cur_deform_grad[state.system_idx] ).squeeze(1) else: forces = state.forces - # 1. 
Current power (F·v) per batch (atoms + cell) - batch_power = tsm.batched_vdot(forces, state.velocities, state.batch) + # 1. Current power (F·v) per system (atoms + cell) + system_power = tsm.batched_vdot(forces, state.velocities, state.system_idx) if is_cell_optimization: - batch_power += (state.cell_forces * state.cell_velocities).sum(dim=(1, 2)) + system_power += (state.cell_forces * state.cell_velocities).sum(dim=(1, 2)) # 2. Update dt, alpha, n_pos - pos_mask_batch = batch_power > 0.0 - neg_mask_batch = ~pos_mask_batch + pos_mask_system = system_power > 0.0 + neg_mask_system = ~pos_mask_system - inc_mask = (state.n_pos > n_min) & pos_mask_batch + inc_mask = (state.n_pos > n_min) & pos_mask_system state.dt[inc_mask] = torch.minimum(state.dt[inc_mask] * f_inc, dt_max) state.alpha[inc_mask] *= f_alpha - state.n_pos[pos_mask_batch] += 1 + state.n_pos[pos_mask_system] += 1 - state.dt[neg_mask_batch] *= f_dec - state.alpha[neg_mask_batch] = alpha_start_batch[neg_mask_batch] - state.n_pos[neg_mask_batch] = 0 + state.dt[neg_mask_system] *= f_dec + state.alpha[neg_mask_system] = alpha_start_system[neg_mask_system] + state.n_pos[neg_mask_system] = 0 # 3. 
Velocity mixing BEFORE acceleration (ASE ordering) - v_scaling_batch = tsm.batched_vdot( - state.velocities, state.velocities, state.batch + v_scaling_system = tsm.batched_vdot( + state.velocities, state.velocities, state.system_idx ) - f_scaling_batch = tsm.batched_vdot(forces, forces, state.batch) + f_scaling_system = tsm.batched_vdot(forces, forces, state.system_idx) if is_cell_optimization: - v_scaling_batch += state.cell_velocities.pow(2).sum(dim=(1, 2)) - f_scaling_batch += state.cell_forces.pow(2).sum(dim=(1, 2)) + v_scaling_system += state.cell_velocities.pow(2).sum(dim=(1, 2)) + f_scaling_system += state.cell_forces.pow(2).sum(dim=(1, 2)) - v_scaling_cell = torch.sqrt(v_scaling_batch.view(n_batches, 1, 1)) - f_scaling_cell = torch.sqrt(f_scaling_batch.view(n_batches, 1, 1)) + v_scaling_cell = torch.sqrt(v_scaling_system.view(n_systems, 1, 1)) + f_scaling_cell = torch.sqrt(f_scaling_system.view(n_systems, 1, 1)) v_mixing_cell = state.cell_forces / (f_scaling_cell + eps) * v_scaling_cell - alpha_cell_bc = state.alpha.view(n_batches, 1, 1) + alpha_cell_bc = state.alpha.view(n_systems, 1, 1) state.cell_velocities = torch.where( - pos_mask_batch.view(n_batches, 1, 1), + pos_mask_system.view(n_systems, 1, 1), (1.0 - alpha_cell_bc) * state.cell_velocities + alpha_cell_bc * v_mixing_cell, torch.zeros_like(state.cell_velocities), ) - v_scaling_atom = torch.sqrt(v_scaling_batch[state.batch].unsqueeze(-1)) - f_scaling_atom = torch.sqrt(f_scaling_batch[state.batch].unsqueeze(-1)) + v_scaling_atom = torch.sqrt(v_scaling_system[state.system_idx].unsqueeze(-1)) + f_scaling_atom = torch.sqrt(f_scaling_system[state.system_idx].unsqueeze(-1)) v_mixing_atom = forces * (v_scaling_atom / (f_scaling_atom + eps)) - alpha_atom = state.alpha[state.batch].unsqueeze(-1) # per-atom alpha + alpha_atom = state.alpha[state.system_idx].unsqueeze(-1) # per-atom alpha state.velocities = torch.where( - pos_mask_batch[state.batch].unsqueeze(-1), + 
pos_mask_system[state.system_idx].unsqueeze(-1), (1.0 - alpha_atom) * state.velocities + alpha_atom * v_mixing_atom, torch.zeros_like(state.velocities), ) # 4. Acceleration (single forward-Euler, no mass for ASE FIRE) - state.velocities += forces * state.dt[state.batch].unsqueeze(-1) - dr_atom = state.velocities * state.dt[state.batch].unsqueeze(-1) - dr_scaling_batch = tsm.batched_vdot(dr_atom, dr_atom, state.batch) + state.velocities += forces * state.dt[state.system_idx].unsqueeze(-1) + dr_atom = state.velocities * state.dt[state.system_idx].unsqueeze(-1) + dr_scaling_system = tsm.batched_vdot(dr_atom, dr_atom, state.system_idx) if is_cell_optimization: - state.cell_velocities += state.cell_forces * state.dt.view(n_batches, 1, 1) - dr_cell = state.cell_velocities * state.dt.view(n_batches, 1, 1) + state.cell_velocities += state.cell_forces * state.dt.view(n_systems, 1, 1) + dr_cell = state.cell_velocities * state.dt.view(n_systems, 1, 1) - dr_scaling_batch += dr_cell.pow(2).sum(dim=(1, 2)) - dr_scaling_cell = torch.sqrt(dr_scaling_batch).view(n_batches, 1, 1) + dr_scaling_system += dr_cell.pow(2).sum(dim=(1, 2)) + dr_scaling_cell = torch.sqrt(dr_scaling_system).view(n_systems, 1, 1) dr_cell = torch.where( dr_scaling_cell > max_step, max_step * dr_cell / (dr_scaling_cell + eps), dr_cell, ) - dr_scaling_atom = torch.sqrt(dr_scaling_batch)[state.batch].unsqueeze(-1) + dr_scaling_atom = torch.sqrt(dr_scaling_system)[state.system_idx].unsqueeze(-1) dr_atom = torch.where( dr_scaling_atom > max_step, max_step * dr_atom / (dr_scaling_atom + eps), dr_atom @@ -1562,7 +1564,7 @@ def _ase_fire_step( # noqa: C901, PLR0915 if is_cell_optimization: state.positions = ( torch.linalg.solve( - cur_deform_grad[state.batch], state.positions.unsqueeze(-1) + cur_deform_grad[state.system_idx], state.positions.unsqueeze(-1) ).squeeze(-1) + dr_atom ) @@ -1580,7 +1582,7 @@ def _ase_fire_step( # noqa: C901, PLR0915 if not isinstance(state, expected_cls := UnitCellFireState): raise 
ValueError(f"{type(state)=} must be a {expected_cls.__name__}") F_current = state.deform_grad() - cell_factor_exp_mult = state.cell_factor.expand(n_batches, 3, 1) + cell_factor_exp_mult = state.cell_factor.expand(n_systems, 3, 1) current_F_scaled = F_current * cell_factor_exp_mult F_new_scaled = current_F_scaled + dr_cell @@ -1590,7 +1592,7 @@ def _ase_fire_step( # noqa: C901, PLR0915 state.row_vector_cell = new_row_vector_cell state.positions = torch.bmm( - state.positions.unsqueeze(1), F_new[state.batch].mT + state.positions.unsqueeze(1), F_new[state.system_idx].mT ).squeeze(1) else: state.positions = state.positions + dr_atom @@ -1602,7 +1604,7 @@ def _ase_fire_step( # noqa: C901, PLR0915 if is_cell_optimization: state.stress = results["stress"] - volumes = torch.linalg.det(state.cell).view(n_batches, 1, 1) + volumes = torch.linalg.det(state.cell).view(n_systems, 1, 1) if torch.any(volumes <= 0): bad_indices = torch.where(volumes <= 0)[0].tolist() print( # noqa: T201 @@ -1616,13 +1618,13 @@ def _ase_fire_step( # noqa: C901, PLR0915 diag_mean = torch.diagonal(virial, dim1=1, dim2=2).mean(dim=1, keepdim=True) virial = diag_mean.unsqueeze(-1) * torch.eye( 3, device=device, dtype=dtype - ).unsqueeze(0).expand(n_batches, -1, -1) + ).unsqueeze(0).expand(n_systems, -1, -1) if state.constant_volume: diag_mean = torch.diagonal(virial, dim1=1, dim2=2).mean(dim=1, keepdim=True) virial = virial - diag_mean.unsqueeze(-1) * torch.eye( 3, device=device, dtype=dtype - ).unsqueeze(0).expand(n_batches, -1, -1) + ).unsqueeze(0).expand(n_systems, -1, -1) if is_frechet: if not isinstance(state, expected_cls := FrechetCellFIREState): @@ -1645,7 +1647,7 @@ def _ase_fire_step( # noqa: C901, PLR0915 directions[idx, mu, nu] = 1.0 new_cell_forces_log_space = torch.zeros_like(state.cell_forces) - for b_idx in range(n_batches): + for b_idx in range(n_systems): expm_derivs = torch.stack( [ tsm.expm_frechet(logm_F_new[b_idx], direction, compute_expm=False) diff --git a/torch_sim/quantities.py 
b/torch_sim/quantities.py index 2dbcc52b..5d9898c2 100644 --- a/torch_sim/quantities.py +++ b/torch_sim/quantities.py @@ -24,7 +24,7 @@ def calc_kT( # noqa: N802 momenta: torch.Tensor, masses: torch.Tensor, velocities: torch.Tensor | None = None, - batch: torch.Tensor | None = None, + system_idx: torch.Tensor | None = None, ) -> torch.Tensor: """Calculate temperature in energy units from momenta/velocities and masses. @@ -32,7 +32,7 @@ def calc_kT( # noqa: N802 momenta (torch.Tensor): Particle momenta, shape (n_particles, n_dim) masses (torch.Tensor): Particle masses, shape (n_particles,) velocities (torch.Tensor | None): Particle velocities, shape (n_particles, n_dim) - batch (torch.Tensor | None): Optional tensor indicating batch membership of + system_idx (torch.Tensor | None): Optional tensor indicating system membership of each particle Returns: @@ -51,29 +51,29 @@ def calc_kT( # noqa: N802 # If momentum provided, calculate v^2 = p^2/m^2 squared_term = (momenta**2) / masses.unsqueeze(-1) - if batch is None: + if system_idx is None: # Count total degrees of freedom dof = count_dof(squared_term) return torch.sum(squared_term) / dof - # Sum squared terms for each batch + # Sum squared terms for each system flattened_squared = torch.sum(squared_term, dim=-1) - # Count degrees of freedom per batch - batch_sizes = torch.bincount(batch) - dof_per_batch = batch_sizes * squared_term.shape[-1] # multiply by n_dimensions + # Count degrees of freedom per system + system_sizes = torch.bincount(system_idx) + dof_per_system = system_sizes * squared_term.shape[-1] # multiply by n_dimensions - # Calculate temperature per batch - batch_sums = torch.segment_reduce( - flattened_squared, reduce="sum", lengths=batch_sizes + # Calculate temperature per system + system_sums = torch.segment_reduce( + flattened_squared, reduce="sum", lengths=system_sizes ) - return batch_sums / dof_per_batch + return system_sums / dof_per_system def calc_temperature( momenta: torch.Tensor, masses: 
torch.Tensor, velocities: torch.Tensor | None = None, - batch: torch.Tensor | None = None, + system_idx: torch.Tensor | None = None, units: object = MetalUnits.temperature, ) -> torch.Tensor: """Calculate temperature from momenta/velocities and masses. @@ -82,14 +82,14 @@ def calc_temperature( momenta (torch.Tensor): Particle momenta, shape (n_particles, n_dim) masses (torch.Tensor): Particle masses, shape (n_particles,) velocities (torch.Tensor | None): Particle velocities, shape (n_particles, n_dim) - batch (torch.Tensor | None): Optional tensor indicating batch membership of + system_idx (torch.Tensor | None): Optional tensor indicating system membership of each particle units (object): Units to return the temperature in Returns: torch.Tensor: Temperature value in specified units """ - return calc_kT(momenta, masses, velocities, batch) / units + return calc_kT(momenta, masses, velocities, system_idx) / units # @torch.jit.script @@ -97,7 +97,7 @@ def calc_kinetic_energy( momenta: torch.Tensor, masses: torch.Tensor, velocities: torch.Tensor | None = None, - batch: torch.Tensor | None = None, + system_idx: torch.Tensor | None = None, ) -> torch.Tensor: """Computes the kinetic energy of a system. 
@@ -105,12 +105,12 @@ def calc_kinetic_energy( momenta (torch.Tensor): Particle momenta, shape (n_particles, n_dim) masses (torch.Tensor): Particle masses, shape (n_particles,) velocities (torch.Tensor | None): Particle velocities, shape (n_particles, n_dim) - batch (torch.Tensor | None): Optional tensor indicating batch membership of + system_idx (torch.Tensor | None): Optional tensor indicating system membership of each particle Returns: - If batch is None: Scalar tensor containing the total kinetic energy - If batch is provided: Tensor of kinetic energies per batch + If system_idx is None: Scalar tensor containing the total kinetic energy + If system_idx is provided: Tensor of kinetic energies per system """ if momenta is not None and velocities is not None: raise ValueError("Must pass either momenta or velocities, not both") @@ -122,11 +122,11 @@ def calc_kinetic_energy( else: # Using momenta squared_term = (momenta**2) / masses.unsqueeze(-1) - if batch is None: + if system_idx is None: return 0.5 * torch.sum(squared_term) flattened_squared = torch.sum(squared_term, dim=-1) return 0.5 * torch.segment_reduce( - flattened_squared, reduce="sum", lengths=torch.bincount(batch) + flattened_squared, reduce="sum", lengths=torch.bincount(system_idx) ) @@ -142,18 +142,18 @@ def get_pressure( def batchwise_max_force(state: SimState) -> torch.Tensor: - """Compute the maximum force per batch. + """Compute the maximum force per system. Args: - state (SimState): State to compute the maximum force per batch for. + state (SimState): State to compute the maximum force per system for. 
Returns: - torch.Tensor: Maximum forces per batch + torch.Tensor: Maximum forces per system """ - batch_wise_max_force = torch.zeros( - state.n_batches, device=state.device, dtype=state.dtype + system_wise_max_force = torch.zeros( + state.n_systems, device=state.device, dtype=state.dtype ) max_forces = state.forces.norm(dim=1) - return batch_wise_max_force.scatter_reduce( - dim=0, index=state.batch, src=max_forces, reduce="amax" + return system_wise_max_force.scatter_reduce( + dim=0, index=state.system_idx, src=max_forces, reduce="amax" ) diff --git a/torch_sim/runners.py b/torch_sim/runners.py index 291034b7..ef1dcb16 100644 --- a/torch_sim/runners.py +++ b/torch_sim/runners.py @@ -172,7 +172,7 @@ def integrate( pbar_kwargs = pbar if isinstance(pbar, dict) else {} pbar_kwargs.setdefault("desc", "Integrate") pbar_kwargs.setdefault("disable", None) - tqdm_pbar = tqdm(total=state.n_batches, **pbar_kwargs) + tqdm_pbar = tqdm(total=state.n_systems, **pbar_kwargs) for state, batch_indices in batch_iterator: state = init_fn(state) @@ -194,7 +194,7 @@ def integrate( # finish the trajectory reporter final_states.append(state) if tqdm_pbar: - tqdm_pbar.update(state.n_batches) + tqdm_pbar.update(state.n_systems) if trajectory_reporter: trajectory_reporter.finish() @@ -307,7 +307,7 @@ def convergence_fn( """Check if the system has converged. Returns: - torch.Tensor: Boolean tensor of shape (n_batches,) indicating + torch.Tensor: Boolean tensor of shape (n_systems,) indicating convergence status for each batch. """ force_conv = batchwise_max_force(state) < force_tol @@ -343,7 +343,7 @@ def convergence_fn( """Check if the system has converged. Returns: - torch.Tensor: Boolean tensor of shape (n_batches,) indicating + torch.Tensor: Boolean tensor of shape (n_systems,) indicating convergence status for each batch. 
""" return torch.abs(state.energy - last_energy) < energy_tol @@ -372,7 +372,7 @@ def optimize( # noqa: C901 model (ModelInterface): Neural network model module optimizer (Callable): Optimization algorithm function convergence_fn (Callable | None): Condition for convergence, should return a - boolean tensor of length n_batches + boolean tensor of length n_systems optimizer_kwargs: Additional keyword arguments for optimizer init function trajectory_reporter (TrajectoryReporter | dict | None): Optional reporter for tracking optimization trajectory. If a dict, will be passed to the @@ -434,7 +434,7 @@ def optimize( # noqa: C901 pbar_kwargs = pbar if isinstance(pbar, dict) else {} pbar_kwargs.setdefault("desc", "Optimize") pbar_kwargs.setdefault("disable", None) - tqdm_pbar = tqdm(total=state.n_batches, **pbar_kwargs) + tqdm_pbar = tqdm(total=state.n_systems, **pbar_kwargs) while (result := autobatcher.next_batch(state, convergence_tensor))[0] is not None: state, converged_states, batch_indices = result @@ -545,7 +545,7 @@ class StaticState(type(state)): pbar_kwargs = pbar if isinstance(pbar, dict) else {} pbar_kwargs.setdefault("desc", "Static") pbar_kwargs.setdefault("disable", None) - tqdm_pbar = tqdm(total=state.n_batches, **pbar_kwargs) + tqdm_pbar = tqdm(total=state.n_systems, **pbar_kwargs) for sub_state, batch_indices in batch_iterator: # set up trajectory reporters @@ -568,7 +568,7 @@ class StaticState(type(state)): all_props.extend(props) if tqdm_pbar: - tqdm_pbar.update(sub_state.n_batches) + tqdm_pbar.update(sub_state.n_systems) trajectory_reporter.finish() diff --git a/torch_sim/state.py b/torch_sim/state.py index f01c3ca5..ce21ef9b 100644 --- a/torch_sim/state.py +++ b/torch_sim/state.py @@ -29,48 +29,49 @@ class SimState: Contains the fundamental properties needed to describe an atomistic system: positions, masses, unit cell, periodic boundary conditions, and atomic numbers. 
Supports batched operations where multiple atomistic systems can be processed - simultaneously, managed through batch indices. + simultaneously, managed through system indices. States support slicing, cloning, splitting, popping, and movement to other data structures or devices. Slicing is supported through fancy indexing, e.g. `state[[0, 1, 2]]` will return a new state containing only the first three - batches. The other operations are available through the `pop`, `split`, `clone`, + systems. The other operations are available through the `pop`, `split`, `clone`, and `to` methods. Attributes: positions (torch.Tensor): Atomic positions with shape (n_atoms, 3) masses (torch.Tensor): Atomic masses with shape (n_atoms,) - cell (torch.Tensor): Unit cell vectors with shape (n_batches, 3, 3). + cell (torch.Tensor): Unit cell vectors with shape (n_systems, 3, 3). Note that we use a column vector convention, i.e. the cell vectors are stored as `[[a1, b1, c1], [a2, b2, c2], [a3, b3, c3]]` as opposed to the row vector convention `[[a1, a2, a3], [b1, b2, b3], [c1, c2, c3]]` used by ASE. pbc (bool): Boolean indicating whether to use periodic boundary conditions atomic_numbers (torch.Tensor): Atomic numbers with shape (n_atoms,) - batch (torch.Tensor, optional): Batch indices with shape (n_atoms,), - defaults to None, must be unique consecutive integers starting from 0 + system_idx (torch.Tensor, optional): Maps each atom index to its system index. 
+ Has shape (n_atoms,), defaults to None, must be unique consecutive + integers starting from 0 Properties: wrap_positions (torch.Tensor): Positions wrapped according to periodic boundary conditions device (torch.device): Device of the positions tensor dtype (torch.dtype): Data type of the positions tensor - n_atoms (int): Total number of atoms across all batches - n_batches (int): Number of unique batches in the system + n_atoms (int): Total number of atoms across all systems + n_systems (int): Number of unique systems in the system Notes: - positions, masses, and atomic_numbers must have shape (n_atoms, 3). - cell must be in the conventional matrix form. - - batch indices must be unique consecutive integers starting from 0. + - system indices must be unique consecutive integers starting from 0. Examples: >>> state = initialize_state( ... [ase_atoms_1, ase_atoms_2, ase_atoms_3], device, dtype ... ) - >>> state.n_batches + >>> state.n_systems 3 >>> new_state = state[[0, 1]] - >>> new_state.n_batches + >>> new_state.n_systems 2 >>> cloned_state = state.clone() """ @@ -80,11 +81,11 @@ class SimState: cell: torch.Tensor pbc: bool # TODO: do all calculators support mixed pbc? 
atomic_numbers: torch.Tensor - batch: torch.Tensor | None = field(default=None, kw_only=True) + system_idx: torch.Tensor | None = field(default=None, kw_only=True) def __post_init__(self) -> None: """Validate and process the state after initialization.""" - # data validation and fill batch + # data validation and fill system_idx # should make pbc a tensor here # if devices aren't all the same, raise an error, in a clean way devices = { @@ -106,23 +107,27 @@ def __post_init__(self) -> None: f"masses {shapes[1]}, atomic_numbers {shapes[2]}" ) - if self.cell.ndim != 3 and self.batch is None: + if self.cell.ndim != 3 and self.system_idx is None: self.cell = self.cell.unsqueeze(0) if self.cell.shape[-2:] != (3, 3): - raise ValueError("Cell must have shape (n_batches, 3, 3)") + raise ValueError("Cell must have shape (n_systems, 3, 3)") - if self.batch is None: - self.batch = torch.zeros(self.n_atoms, device=self.device, dtype=torch.int64) + if self.system_idx is None: + self.system_idx = torch.zeros( + self.n_atoms, device=self.device, dtype=torch.int64 + ) else: - # assert that batch indices are unique consecutive integers - _, counts = torch.unique_consecutive(self.batch, return_counts=True) - if not torch.all(counts == torch.bincount(self.batch)): - raise ValueError("Batch indices must be unique consecutive integers") - - if self.cell.shape[0] != self.n_batches: + # assert that system indices are unique consecutive integers + # TODO(curtis): I feel like this logic is not reliable. + # I'll come up with something better later. 
+ _, counts = torch.unique_consecutive(self.system_idx, return_counts=True) + if not torch.all(counts == torch.bincount(self.system_idx)): + raise ValueError("System indices must be unique consecutive integers") + + if self.cell.shape[0] != self.n_systems: raise ValueError( - f"Cell must have shape (n_batches, 3, 3), got {self.cell.shape}" + f"Cell must have shape (n_systems, 3, 3), got {self.cell.shape}" ) @property @@ -145,22 +150,78 @@ def dtype(self) -> torch.dtype: @property def n_atoms(self) -> int: - """Total number of atoms in the system across all batches.""" + """Total number of atoms in the system across all systems.""" return self.positions.shape[0] @property - def n_atoms_per_batch(self) -> torch.Tensor: - """Number of atoms per batch.""" + def n_atoms_per_system(self) -> torch.Tensor: + """Number of atoms per system.""" return ( - self.batch.bincount() - if self.batch is not None + self.system_idx.bincount() + if self.system_idx is not None else torch.tensor([self.n_atoms], device=self.device) ) + @property + def n_atoms_per_batch(self) -> torch.Tensor: + """Number of atoms per batch. + + deprecated:: + Use :attr:`n_atoms_per_system` instead. + """ + warnings.warn( + "n_atoms_per_batch is deprecated, use n_atoms_per_system instead", + DeprecationWarning, + stacklevel=2, + ) + return self.n_atoms_per_system + + @property + def batch(self) -> torch.Tensor | None: + """System indices. + + deprecated:: + Use :attr:`system_idx` instead. + """ + warnings.warn( + "batch is deprecated, use system_idx instead", + DeprecationWarning, + stacklevel=2, + ) + return self.system_idx + + @batch.setter + def batch(self, system_idx: torch.Tensor) -> None: + """Set the system indices from a batch index. + + deprecated:: + Use :attr:`system_idx` instead. 
+ """ + warnings.warn( + "Setting batch is deprecated, use system_idx instead", + DeprecationWarning, + stacklevel=2, + ) + self.system_idx = system_idx + @property def n_batches(self) -> int: - """Number of batches in the system.""" - return torch.unique(self.batch).shape[0] + """Number of batches in the system. + + deprecated:: + Use :attr:`n_systems` instead. + """ + warnings.warn( + "n_batches is deprecated, use n_systems instead", + DeprecationWarning, + stacklevel=2, + ) + return self.n_systems + + @property + def n_systems(self) -> int: + """Number of systems in the system.""" + return torch.unique(self.system_idx).shape[0] @property def volume(self) -> torch.Tensor: @@ -217,7 +278,7 @@ def to_atoms(self) -> list["Atoms"]: """Convert the SimState to a list of ASE Atoms objects. Returns: - list[Atoms]: A list of ASE Atoms objects, one per batch + list[Atoms]: A list of ASE Atoms objects, one per system """ return ts.io.state_to_atoms(self) @@ -225,7 +286,7 @@ def to_structures(self) -> list["Structure"]: """Convert the SimState to a list of pymatgen Structure objects. Returns: - list[Structure]: A list of pymatgen Structure objects, one per batch + list[Structure]: A list of pymatgen Structure objects, one per system """ return ts.io.state_to_structures(self) @@ -233,43 +294,43 @@ def to_phonopy(self) -> list["PhonopyAtoms"]: """Convert the SimState to a list of PhonopyAtoms objects. Returns: - list[PhonopyAtoms]: A list of PhonopyAtoms objects, one per batch + list[PhonopyAtoms]: A list of PhonopyAtoms objects, one per system """ return ts.io.state_to_phonopy(self) def split(self) -> list[Self]: - """Split the SimState into a list of single-batch SimStates. + """Split the SimState into a list of single-system SimStates. - Divides the current state into separate states, each containing a single batch, - preserving all properties appropriately for each batch. 
+ Divides the current state into separate states, each containing a single system, + preserving all properties appropriately for each system. Returns: - list[SimState]: A list of SimState objects, one per batch + list[SimState]: A list of SimState objects, one per system """ return _split_state(self) - def pop(self, batch_indices: int | list[int] | slice | torch.Tensor) -> list[Self]: - """Pop off states with the specified batch indices. + def pop(self, system_indices: int | list[int] | slice | torch.Tensor) -> list[Self]: + """Pop off states with the specified system indices. This method modifies the original state object by removing the specified - batches and returns the removed batches as separate SimState objects. + systems and returns the removed systems as separate SimState objects. Args: - batch_indices (int | list[int] | slice | torch.Tensor): The batch indices + system_indices (int | list[int] | slice | torch.Tensor): The system indices to pop Returns: - list[SimState]: Popped SimState objects, one per batch index + list[SimState]: Popped SimState objects, one per system index Notes: This method modifies the original SimState in-place. """ - batch_indices = _normalize_batch_indices( - batch_indices, self.n_batches, self.device + system_indices = _normalize_system_indices( + system_indices, self.n_systems, self.device ) # Get the modified state and popped states - modified_state, popped_states = _pop_states(self, batch_indices) + modified_state, popped_states = _pop_states(self, system_indices) # Update all attributes of self with the modified state's attributes for attr_name, attr_value in vars(modified_state).items(): @@ -293,23 +354,23 @@ def to( """ return state_to_device(self, device, dtype) - def __getitem__(self, batch_indices: int | list[int] | slice | torch.Tensor) -> Self: + def __getitem__(self, system_indices: int | list[int] | slice | torch.Tensor) -> Self: """Enable standard Python indexing syntax for slicing batches. 
Args: - batch_indices (int | list[int] | slice | torch.Tensor): The batch indices + system_indices (int | list[int] | slice | torch.Tensor): The system indices to include Returns: - SimState: A new SimState containing only the specified batches + SimState: A new SimState containing only the specified systems """ # TODO: need to document that slicing is supported # Reuse the existing slice method - batch_indices = _normalize_batch_indices( - batch_indices, self.n_batches, self.device + system_indices = _normalize_system_indices( + system_indices, self.n_systems, self.device ) - return _slice_state(self, batch_indices) + return _slice_state(self, system_indices) class DeformGradMixin: @@ -356,44 +417,44 @@ def deform_grad(self) -> torch.Tensor: return self._deform_grad(self.reference_row_vector_cell, self.row_vector_cell) -def _normalize_batch_indices( - batch_indices: int | list[int] | slice | torch.Tensor, - n_batches: int, +def _normalize_system_indices( + system_indices: int | list[int] | slice | torch.Tensor, + n_systems: int, device: torch.device, ) -> torch.Tensor: - """Normalize batch indices to handle negative indices and different input types. + """Normalize system indices to handle negative indices and different input types. - Converts various batch index representations to a consistent tensor format, + Converts various system index representations to a consistent tensor format, handling negative indices in the Python style (counting from the end). 
Args: - batch_indices (int | list[int] | slice | torch.Tensor): The batch indices to + system_indices (int | list[int] | slice | torch.Tensor): The system indices to normalize - n_batches (int): Total number of batches in the system + n_systems (int): Total number of systems in the system device (torch.device): Device to place the output tensor on Returns: - torch.Tensor: Normalized batch indices as a tensor + torch.Tensor: Normalized system indices as a tensor Raises: - TypeError: If batch_indices is of an unsupported type + TypeError: If system_indices is of an unsupported type """ - if isinstance(batch_indices, int): + if isinstance(system_indices, int): # Handle negative integer indexing - if batch_indices < 0: - batch_indices = n_batches + batch_indices - return torch.tensor([batch_indices], device=device) - if isinstance(batch_indices, list): + if system_indices < 0: + system_indices = n_systems + system_indices + return torch.tensor([system_indices], device=device) + if isinstance(system_indices, list): # Handle negative indices in lists - normalized = [idx if idx >= 0 else n_batches + idx for idx in batch_indices] + normalized = [idx if idx >= 0 else n_systems + idx for idx in system_indices] return torch.tensor(normalized, device=device) - if isinstance(batch_indices, slice): + if isinstance(system_indices, slice): # Let PyTorch handle the slice conversion with negative indices - return torch.arange(n_batches, device=device)[batch_indices] - if isinstance(batch_indices, torch.Tensor): + return torch.arange(n_systems, device=device)[system_indices] + if isinstance(system_indices, torch.Tensor): # Handle negative indices in tensors - return torch.where(batch_indices < 0, n_batches + batch_indices, batch_indices) - raise TypeError(f"Unsupported index type: {type(batch_indices)}") + return torch.where(system_indices < 0, n_systems + system_indices, system_indices) + raise TypeError(f"Unsupported index type: {type(system_indices)}") def state_to_device( @@ 
-435,8 +496,8 @@ def state_to_device( def infer_property_scope( state: SimState, ambiguous_handling: Literal["error", "globalize", "globalize_warn"] = "error", -) -> dict[Literal["global", "per_atom", "per_batch"], list[str]]: - """Infer whether a property is global, per-atom, or per-batch. +) -> dict[Literal["global", "per_atom", "per_system"], list[str]]: + """Infer whether a property is global, per-atom, or per-system. Analyzes the shapes of tensor attributes to determine their scope within the atomistic system representation. @@ -450,27 +511,27 @@ def infer_property_scope( - "globalize_warn": Treat ambiguous properties as global with a warning Returns: - dict[Literal["global", "per_atom", "per_batch"], list[str]]: Map of scope + dict[Literal["global", "per_atom", "per_system"], list[str]]: Map of scope category to list of property names Raises: - ValueError: If n_atoms equals n_batches (making scope inference ambiguous) or + ValueError: If n_atoms equals n_systems (making scope inference ambiguous) or if ambiguous_handling="error" and an ambiguous property is encountered """ # TODO: this cannot effectively resolve global properties with - # length of n_atoms or n_batches, they will be classified incorrectly, + # length of n_atoms or n_systems, they will be classified incorrectly, # no clear fix - if state.n_atoms == state.n_batches: + if state.n_atoms == state.n_systems: raise ValueError( - f"n_atoms ({state.n_atoms}) and n_batches ({state.n_batches}) are equal, " + f"n_atoms ({state.n_atoms}) and n_systems ({state.n_systems}) are equal, " "which means shapes cannot be inferred unambiguously." 
) scope = { "global": [], "per_atom": [], - "per_batch": [], + "per_system": [], } # Iterate through all attributes @@ -489,15 +550,15 @@ def infer_property_scope( # Vector/matrix with first dimension matching number of atoms elif shape[0] == state.n_atoms: scope["per_atom"].append(attr_name) - # Tensor with first dimension matching number of batches - elif shape[0] == state.n_batches: - scope["per_batch"].append(attr_name) + # Tensor with first dimension matching number of systems + elif shape[0] == state.n_systems: + scope["per_system"].append(attr_name) # Any other shape is ambiguous elif ambiguous_handling == "error": raise ValueError( f"Cannot categorize property '{attr_name}' with shape {shape}. " f"Expected first dimension to be either {state.n_atoms} (per-atom) or " - f"{state.n_batches} (per-batch), or a scalar (global)." + f"{state.n_systems} (per-system), or a scalar (global)." ) elif ambiguous_handling in ("globalize", "globalize_warn"): scope["global"].append(attr_name) @@ -516,10 +577,10 @@ def infer_property_scope( def _get_property_attrs( state: SimState, ambiguous_handling: Literal["error", "globalize"] = "error" ) -> dict[str, dict]: - """Get global, per-atom, and per-batch attributes from a state. + """Get global, per-atom, and per-system attributes from a state. Categorizes all attributes of the state based on their scope - (global, per-atom, or per-batch). + (global, per-atom, or per-system). 
Args: state (SimState): The state to extract attributes from @@ -527,12 +588,12 @@ def _get_property_attrs( properties Returns: - dict[str, dict]: Keys are 'global', 'per_atom', and 'per_batch', each + dict[str, dict]: Keys are 'global', 'per_atom', and 'per_system', each containing a dictionary of attribute names to values """ scope = infer_property_scope(state, ambiguous_handling=ambiguous_handling) - attrs = {"global": {}, "per_atom": {}, "per_batch": {}} + attrs = {"global": {}, "per_atom": {}, "per_system": {}} # Process global properties for attr_name in scope["global"]: @@ -542,9 +603,9 @@ def _get_property_attrs( for attr_name in scope["per_atom"]: attrs["per_atom"][attr_name] = getattr(state, attr_name) - # Process per-batch properties - for attr_name in scope["per_batch"]: - attrs["per_batch"][attr_name] = getattr(state, attr_name) + # Process per-system properties + for attr_name in scope["per_system"]: + attrs["per_system"][attr_name] = getattr(state, attr_name) return attrs @@ -552,19 +613,19 @@ def _get_property_attrs( def _filter_attrs_by_mask( attrs: dict[str, dict], atom_mask: torch.Tensor, - batch_mask: torch.Tensor, + system_mask: torch.Tensor, ) -> dict: - """Filter attributes by atom and batch masks. + """Filter attributes by atom and system masks. - Selects subsets of attributes based on boolean masks for atoms and batches. + Selects subsets of attributes based on boolean masks for atoms and systems. 
Args: - attrs (dict[str, dict]): Keys are 'global', 'per_atom', and 'per_batch', each + attrs (dict[str, dict]): Keys are 'global', 'per_atom', and 'per_system', each containing a dictionary of attribute names to values atom_mask (torch.Tensor): Boolean mask for atoms to include with shape (n_atoms,) - batch_mask (torch.Tensor): Boolean mask for batches to include with shape - (n_batches,) + system_mask (torch.Tensor): Boolean mask for systems to include with shape + (n_systems,) Returns: dict: Filtered attributes with appropriate handling for each scope @@ -576,31 +637,31 @@ def _filter_attrs_by_mask( # Filter per-atom attributes for attr_name, attr_value in attrs["per_atom"].items(): - if attr_name == "batch": - # Get the old batch indices for the selected atoms - old_batch = attr_value[atom_mask] + if attr_name == "system_idx": + # Get the old system indices for the selected atoms + old_system_idxs = attr_value[atom_mask] - # Get the batch indices that are kept + # Get the system indices that are kept kept_indices = torch.arange(attr_value.max() + 1, device=attr_value.device)[ - batch_mask + system_mask ] - # Create a mapping from old batch indices to new consecutive indices - batch_map = {idx.item(): i for i, idx in enumerate(kept_indices)} + # Create a mapping from old system indices to new consecutive indices + system_idx_map = {idx.item(): i for i, idx in enumerate(kept_indices)} - # Create new batch tensor with remapped indices - new_batch = torch.tensor( - [batch_map[b.item()] for b in old_batch], + # Create new system tensor with remapped indices + new_system_idxs = torch.tensor( + [system_idx_map[b.item()] for b in old_system_idxs], device=attr_value.device, dtype=attr_value.dtype, ) - filtered_attrs[attr_name] = new_batch + filtered_attrs[attr_name] = new_system_idxs else: filtered_attrs[attr_name] = attr_value[atom_mask] - # Filter per-batch attributes - for attr_name, attr_value in attrs["per_batch"].items(): - filtered_attrs[attr_name] = 
attr_value[batch_mask] + # Filter per-system attributes + for attr_name, attr_value in attrs["per_system"].items(): + filtered_attrs[attr_name] = attr_value[system_mask] return filtered_attrs @@ -609,10 +670,10 @@ def _split_state( state: SimState, ambiguous_handling: Literal["error", "globalize"] = "error", ) -> list[SimState]: - """Split a SimState into a list of states, each containing a single batch element. + """Split a SimState into a list of states, each containing a single system. - Divides a multi-batch state into individual single-batch states, preserving - appropriate properties for each batch. + Divides a multi-system state into individual single-system states, preserving + appropriate properties for each system. Args: state (SimState): The SimState to split @@ -623,37 +684,42 @@ def _split_state( Returns: list[SimState]: A list of SimState objects, each containing a single - batch element + system """ attrs = _get_property_attrs(state, ambiguous_handling) - batch_sizes = torch.bincount(state.batch).tolist() + system_sizes = torch.bincount(state.system_idx).tolist() - # Split per-atom attributes by batch + # Split per-atom attributes by system split_per_atom = {} for attr_name, attr_value in attrs["per_atom"].items(): - if attr_name == "batch": + if attr_name == "system_idx": continue - split_per_atom[attr_name] = torch.split(attr_value, batch_sizes, dim=0) + split_per_atom[attr_name] = torch.split(attr_value, system_sizes, dim=0) - # Split per-batch attributes into individual elements - split_per_batch = {} - for attr_name, attr_value in attrs["per_batch"].items(): - split_per_batch[attr_name] = torch.split(attr_value, 1, dim=0) + # Split per-system attributes into individual elements + split_per_system = {} + for attr_name, attr_value in attrs["per_system"].items(): + split_per_system[attr_name] = torch.split(attr_value, 1, dim=0) - # Create a state for each batch + # Create a state for each system states = [] - for i in range(state.n_batches): - 
batch_attrs = { - # Create a batch tensor with all zeros for this batch - "batch": torch.zeros(batch_sizes[i], device=state.device, dtype=torch.int64), + for i in range(state.n_systems): + system_attrs = { + # Create a system tensor with all zeros for this system + "system_idx": torch.zeros( + system_sizes[i], device=state.device, dtype=torch.int64 + ), # Add the split per-atom attributes **{attr_name: split_per_atom[attr_name][i] for attr_name in split_per_atom}, - # Add the split per-batch attributes - **{attr_name: split_per_batch[attr_name][i] for attr_name in split_per_batch}, + # Add the split per-system attributes + **{ + attr_name: split_per_system[attr_name][i] + for attr_name in split_per_system + }, # Add the global attributes **attrs["global"], } - states.append(type(state)(**batch_attrs)) + states.append(type(state)(**system_attrs)) return states @@ -665,11 +731,11 @@ def _pop_states( ) -> tuple[SimState, list[SimState]]: """Pop off the states with the specified indices. - Extracts and removes the specified batch indices from the state. + Extracts and removes the specified system indices from the state. Args: state (SimState): The SimState to modify - pop_indices (list[int] | torch.Tensor): The batch indices to extract and remove + pop_indices (list[int] | torch.Tensor): The system indices to extract and remove ambiguous_handling ("error" | "globalize"): How to handle ambiguous properties. If "error", an error is raised if a property has ambiguous scope. If "globalize", properties with ambiguous scope are treated as @@ -677,8 +743,8 @@ def _pop_states( Returns: tuple[SimState, list[SimState]]: A tuple containing: - - The modified original state with specified batches removed - - A list of the extracted SimStates, one per popped batch + - The modified original state with specified systems removed + - A list of the extracted SimStates, one per popped system Notes: Unlike the pop method, this function does not modify the input state. 
@@ -691,17 +757,17 @@ def _pop_states( attrs = _get_property_attrs(state, ambiguous_handling) - # Create masks for the atoms and batches to keep and pop - batch_range = torch.arange(state.n_batches, device=state.device) - pop_batch_mask = torch.isin(batch_range, pop_indices) - keep_batch_mask = ~pop_batch_mask + # Create masks for the atoms and systems to keep and pop + system_range = torch.arange(state.n_systems, device=state.device) + pop_system_mask = torch.isin(system_range, pop_indices) + keep_system_mask = ~pop_system_mask - pop_atom_mask = torch.isin(state.batch, pop_indices) + pop_atom_mask = torch.isin(state.system_idx, pop_indices) keep_atom_mask = ~pop_atom_mask # Filter attributes for keep and pop states - keep_attrs = _filter_attrs_by_mask(attrs, keep_atom_mask, keep_batch_mask) - pop_attrs = _filter_attrs_by_mask(attrs, pop_atom_mask, pop_batch_mask) + keep_attrs = _filter_attrs_by_mask(attrs, keep_atom_mask, keep_system_mask) + pop_attrs = _filter_attrs_by_mask(attrs, pop_atom_mask, pop_system_mask) # Create the keep state keep_state = type(state)(**keep_attrs) @@ -715,17 +781,17 @@ def _pop_states( def _slice_state( state: SimState, - batch_indices: list[int] | torch.Tensor, + system_indices: list[int] | torch.Tensor, ambiguous_handling: Literal["error", "globalize"] = "error", ) -> SimState: - """Slice a substate from the SimState containing only the specified batch indices. + """Slice a substate from the SimState containing only the specified system indices. - Creates a new SimState containing only the specified batches, preserving + Creates a new SimState containing only the specified systems, preserving all relevant properties. Args: state (SimState): The state to slice - batch_indices (list[int] | torch.Tensor): Batch indices to include in the + system_indices (list[int] | torch.Tensor): System indices to include in the sliced state ambiguous_handling ("error" | "globalize"): How to handle ambiguous properties. 
If "error", an error is raised if a property has ambiguous @@ -733,28 +799,28 @@ def _slice_state( global. Returns: - SimState: A new SimState object containing only the specified batches + SimState: A new SimState object containing only the specified systems Raises: - ValueError: If batch_indices is empty + ValueError: If system_indices is empty """ - if isinstance(batch_indices, list): - batch_indices = torch.tensor( - batch_indices, device=state.device, dtype=torch.int64 + if isinstance(system_indices, list): + system_indices = torch.tensor( + system_indices, device=state.device, dtype=torch.int64 ) - if len(batch_indices) == 0: - raise ValueError("batch_indices cannot be empty") + if len(system_indices) == 0: + raise ValueError("system_indices cannot be empty") attrs = _get_property_attrs(state, ambiguous_handling) - # Create masks for the atoms and batches to include - batch_range = torch.arange(state.n_batches, device=state.device) - batch_mask = torch.isin(batch_range, batch_indices) - atom_mask = torch.isin(state.batch, batch_indices) + # Create masks for the atoms and systems to include + system_range = torch.arange(state.n_systems, device=state.device) + system_mask = torch.isin(system_range, system_indices) + atom_mask = torch.isin(state.system_idx, system_indices) # Filter attributes - filtered_attrs = _filter_attrs_by_mask(attrs, atom_mask, batch_mask) + filtered_attrs = _filter_attrs_by_mask(attrs, atom_mask, system_mask) # Create the sliced state return type(state)(**filtered_attrs) @@ -765,8 +831,8 @@ def concatenate_states( ) -> SimState: """Concatenate a list of SimStates into a single SimState. - Combines multiple states into a single state with multiple batches. - Global properties are taken from the first state, and per-atom and per-batch + Combines multiple states into a single state with multiple systems. + Global properties are taken from the first state, and per-atom and per-system properties are concatenated. 
Args: @@ -775,7 +841,7 @@ def concatenate_states( Defaults to the device of the first state. Returns: - SimState: A new SimState containing all input states as separate batches + SimState: A new SimState containing all input states as separate systems Raises: ValueError: If states is empty @@ -796,20 +862,20 @@ def concatenate_states( target_device = device or first_state.device # Get property scopes from the first state to identify - # global/per-atom/per-batch properties + # global/per-atom/per-system properties first_scope = infer_property_scope(first_state) global_props = set(first_scope["global"]) per_atom_props = set(first_scope["per_atom"]) - per_batch_props = set(first_scope["per_batch"]) + per_system_props = set(first_scope["per_system"]) # Initialize result with global properties from first state concatenated = {prop: getattr(first_state, prop) for prop in global_props} # Pre-allocate lists for tensors to concatenate per_atom_tensors = {prop: [] for prop in per_atom_props} - per_batch_tensors = {prop: [] for prop in per_batch_props} - new_batch_indices = [] - batch_offset = 0 + per_system_tensors = {prop: [] for prop in per_system_props} + new_system_indices = [] + system_offset = 0 # Process all states in a single pass for state in states: @@ -822,28 +888,28 @@ def concatenate_states( # if hasattr(state, prop): per_atom_tensors[prop].append(getattr(state, prop)) - # Collect per-batch properties - for prop in per_batch_props: + # Collect per-system properties + for prop in per_system_props: # if hasattr(state, prop): - per_batch_tensors[prop].append(getattr(state, prop)) + per_system_tensors[prop].append(getattr(state, prop)) - # Update batch indices - num_batches = state.n_batches - new_indices = state.batch + batch_offset - new_batch_indices.append(new_indices) - batch_offset += num_batches + # Update system indices + num_systems = state.n_systems + new_indices = state.system_idx + system_offset + new_system_indices.append(new_indices) + system_offset 
+= num_systems # Concatenate collected tensors for prop, tensors in per_atom_tensors.items(): # if tensors: concatenated[prop] = torch.cat(tensors, dim=0) - for prop, tensors in per_batch_tensors.items(): + for prop, tensors in per_system_tensors.items(): # if tensors: concatenated[prop] = torch.cat(tensors, dim=0) - # Concatenate batch indices - concatenated["batch"] = torch.cat(new_batch_indices) + # Concatenate system indices + concatenated["system_idx"] = torch.cat(new_system_indices) # Create a new instance of the same class return state_class(**concatenated) @@ -877,10 +943,10 @@ def initialize_state( return state_to_device(system, device, dtype) if isinstance(system, list) and all(isinstance(s, SimState) for s in system): - if not all(state.n_batches == 1 for state in system): + if not all(state.n_systems == 1 for state in system): raise ValueError( "When providing a list of states, to the initialize_state function, " - "all states must have n_batches == 1. To fix this, you can split the " + "all states must have n_systems == 1. To fix this, you can split the " "states into individual states with the split_state function." ) return concatenate_states(system) diff --git a/torch_sim/trajectory.py b/torch_sim/trajectory.py index bc24de63..3150064a 100644 --- a/torch_sim/trajectory.py +++ b/torch_sim/trajectory.py @@ -208,11 +208,11 @@ def report( """Report a state and step to the trajectory files. Writes states and calculated properties to all trajectory files at the - specified frequencies. Splits multi-batch states across separate trajectory - files. The number of batches must match the number of trajectory files. + specified frequencies. Splits multi-system states across separate trajectory + files. The number of systems must match the number of trajectory files. 
Args: - state (SimState): Current system state with n_batches equal to + state (SimState): Current system state with n_systems equal to len(filenames) step (int): Current simulation step, setting step to 0 will write the state and all properties. @@ -224,27 +224,28 @@ def report( are being collected separately. Returns: - list[dict[str, torch.Tensor]]: Map of property names to tensors for each batch + list[dict[str, torch.Tensor]]: Map of property names to tensors for each + system. Raises: - ValueError: If number of batches doesn't match number of trajectory files + ValueError: If number of systems doesn't match number of trajectory files """ - # Get unique batch indices - batch_indices = range(state.n_batches) - # batch_indices = torch.unique(state.batch).cpu().tolist() + # Get unique system indices + system_indices = range(state.n_systems) + # system_indices = torch.unique(state.system_idx).cpu().tolist() # Ensure we have the right number of trajectories - if self.filenames is not None and len(batch_indices) != len(self.trajectories): + if self.filenames is not None and len(system_indices) != len(self.trajectories): raise ValueError( - f"Number of batches ({len(batch_indices)}) doesn't match " + f"Number of systems ({len(system_indices)}) doesn't match " f"number of trajectory files ({len(self.trajectories)})" ) split_states = state.split() all_props: list[dict[str, torch.Tensor]] = [] - # Process each batch separately + # Process each system separately for idx, substate in enumerate(split_states): - # Slice the state once to get only the data for this batch + # Slice the state once to get only the data for this system self.shape_warned = True # Write state to trajectory if it's time @@ -256,7 +257,7 @@ def report( self.trajectories[idx].write_state(substate, step, **self.state_kwargs) all_state_props = {} - # Process property calculators for this batch + # Process property calculators for this system for report_frequency, calculators in 
self.prop_calculators.items(): if step % report_frequency != 0 or report_frequency == 0: continue @@ -672,7 +673,7 @@ def write_state( # noqa: C901 self, state: SimState | list[SimState], steps: int | list[int], - batch_index: int | None = None, + system_index: int | None = None, *, save_velocities: bool = False, save_forces: bool = False, @@ -692,7 +693,7 @@ def write_state( # noqa: C901 Args: state (SimState | list[SimState]): SimState or list of SimStates to write steps (int | list[int]): Step number(s) for the frame(s) - batch_index (int, optional): Batch index to save. + system_index (int, optional): System index to save. save_velocities (bool, optional): Whether to save velocities. save_forces (bool, optional): Whether to save forces. variable_cell (bool, optional): Whether the cell varies between frames. @@ -712,15 +713,15 @@ def write_state( # noqa: C901 if isinstance(steps, int): steps = [steps] - if isinstance(batch_index, int): - batch_index = [batch_index] - sub_states = [state[batch_index] for state in state] - elif batch_index is None and torch.unique(state[0].batch) == 0: - batch_index = 0 + if isinstance(system_index, int): + system_index = [system_index] + sub_states = [state[system_index] for state in state] + elif system_index is None and torch.unique(state[0].system_idx) == 0: + system_index = 0 sub_states = state else: raise ValueError( - "Batch index must be specified if there are multiple batches" + "System index must be specified if there are multiple systems" ) if len(sub_states) != len(steps): diff --git a/torch_sim/transforms.py b/torch_sim/transforms.py index 44d27181..29b6fa9a 100644 --- a/torch_sim/transforms.py +++ b/torch_sim/transforms.py @@ -23,9 +23,9 @@ def get_fractional_coordinates( Args: positions (torch.Tensor): Atomic positions in Cartesian coordinates. - Shape: [..., 3] where ... represents optional batch dimensions. + Shape: [..., 3] where ... represents optional system dimensions. 
cell (torch.Tensor): Unit cell matrix with lattice vectors as rows. - Shape: [..., 3, 3] where ... matches positions' batch dimensions. + Shape: [..., 3, 3] where ... matches positions' system dimensions. Returns: torch.Tensor: Atomic positions in fractional coordinates with same shape as input @@ -42,21 +42,21 @@ def get_fractional_coordinates( """ if cell.ndim == 3: # Handle batched cell tensors # For batched cells, we need to determine if this is: - # 1. A single batch (n_batches=1) - can be squeezed and handled normally - # 2. Multiple batches - need proper batch handling + # 1. A single system (n_systems=1) - can be squeezed and handled normally + # 2. Multiple systems - need proper system handling if cell.shape[0] == 1: - # Single batch case - squeeze and use the 2D implementation + # Single system case - squeeze and use the 2D implementation cell_2d = cell.squeeze(0) # Remove batch dimension return torch.linalg.solve(cell_2d.mT, positions.mT).mT - # Multiple batches case - this would require batch indices to know which - # atoms belong to which batch. For now, this is not implemented. + # Multiple systems case - this would require system indices to know which + # atoms belong to which system. For now, this is not implemented. raise NotImplementedError( - f"Multiple batched cell tensors with shape {cell.shape} are not yet " - "supported in get_fractional_coordinates. For multiple batch systems, " - "you need to provide batch indices to determine which atoms belong to " - "which batch. For single batch systems, consider squeezing the batch " - "dimension or using individual calls per batch." + f"Multiple system cell tensors with shape {cell.shape} are not yet " + "supported in get_fractional_coordinates. For multiple system systems, " + "you need to provide system indices to determine which atoms belong to " + "which system. For single system systems, consider squeezing the batch " + "dimension or using individual calls per system." 
) # Original case for 2D cell matrix @@ -155,20 +155,20 @@ def pbc_wrap_general( def pbc_wrap_batched( - positions: torch.Tensor, cell: torch.Tensor, batch: torch.Tensor + positions: torch.Tensor, cell: torch.Tensor, system_idx: torch.Tensor ) -> torch.Tensor: """Apply periodic boundary conditions to batched systems. This function handles wrapping positions for multiple atomistic systems - (batches) in one operation. It uses the batch indices to determine which + (systems) in one operation. It uses the system indices to determine which atoms belong to which system and applies the appropriate cell vectors. Args: positions (torch.Tensor): Tensor of shape (n_atoms, 3) containing particle positions in real space. - cell (torch.Tensor): Tensor of shape (n_batches, 3, 3) containing + cell (torch.Tensor): Tensor of shape (n_systems, 3, 3) containing lattice vectors as column vectors. - batch (torch.Tensor): Tensor of shape (n_atoms,) containing batch + system_idx (torch.Tensor): Tensor of shape (n_atoms,) containing system indices for each atom. 
Returns: @@ -182,33 +182,33 @@ def pbc_wrap_batched( if positions.shape[-1] != cell.shape[-1]: raise ValueError("Position dimensionality must match lattice vectors.") - # Get unique batch indices and counts - unique_batches = torch.unique(batch) - n_batches = len(unique_batches) + # Get unique system indices and counts + unique_systems = torch.unique(system_idx) + n_systems = len(unique_systems) - if n_batches != cell.shape[0]: + if n_systems != cell.shape[0]: raise ValueError( - f"Number of unique batches ({n_batches}) doesn't " + f"Number of unique systems ({n_systems}) doesn't " f"match number of cells ({cell.shape[0]})" ) # Efficient approach without explicit loops - # Get the cell for each atom based on its batch index - B = torch.linalg.inv(cell) # Shape: (n_batches, 3, 3) - B_per_atom = B[batch] # Shape: (n_atoms, 3, 3) + # Get the cell for each atom based on its system index + B = torch.linalg.inv(cell) # Shape: (n_systems, 3, 3) + B_per_atom = B[system_idx] # Shape: (n_atoms, 3, 3) # Transform to fractional coordinates: f = B·r - # For each atom, multiply its position by its batch's inverse cell matrix + # For each atom, multiply its position by its system's inverse cell matrix frac_coords = torch.bmm(B_per_atom, positions.unsqueeze(2)).squeeze(2) # Wrap to reference cell [0,1) using modulo wrapped_frac = frac_coords % 1.0 # Transform back to real space: r = A·f - # Get the cell for each atom based on its batch index - cell_per_atom = cell[batch] # Shape: (n_atoms, 3, 3) + # Get the cell for each atom based on its system index + cell_per_atom = cell[system_idx] # Shape: (n_atoms, 3, 3) - # For each atom, multiply its wrapped fractional coords by its batch's cell matrix + # For each atom, multiply its wrapped fractional coords by its system's cell matrix return torch.bmm(cell_per_atom, wrapped_frac.unsqueeze(2)).squeeze(2) @@ -535,7 +535,7 @@ def compute_distances_with_cell_shifts( def compute_cell_shifts( - cell: torch.Tensor, shifts_idx: torch.Tensor, 
batch_mapping: torch.Tensor + cell: torch.Tensor, shifts_idx: torch.Tensor, system_mapping: torch.Tensor ) -> torch.Tensor: """Compute the cell shifts based on the provided indices and cell matrix. @@ -547,18 +547,18 @@ def compute_cell_shifts( representing the unit cell matrices. shifts_idx (torch.Tensor): A tensor of shape (n_shifts, 3) representing the indices for shifts. - batch_mapping (torch.Tensor): A tensor of shape (n_batches,) + system_mapping (torch.Tensor): A tensor of shape (n_systems,) that maps the shifts to the corresponding cells. Returns: - torch.Tensor: A tensor of shape (n_batches, 3) containing + torch.Tensor: A tensor of shape (n_systems, 3) containing the computed cell shifts. """ if cell is None: cell_shifts = None else: cell_shifts = torch.einsum( - "jn,jnm->jm", shifts_idx, cell.view(-1, 3, 3)[batch_mapping] + "jn,jnm->jm", shifts_idx, cell.view(-1, 3, 3)[system_mapping] ) return cell_shifts @@ -625,7 +625,7 @@ def build_naive_neighborhood( This function computes a neighborhood list of atoms within a specified cutoff distance, considering periodic boundary conditions defined by the unit cell. It returns the mapping of atom pairs, - the batch mapping for each structure, and the corresponding shifts. + the system mapping for each structure, and the corresponding shifts. Args: positions (torch.Tensor): A tensor of shape (n_atoms, 3) @@ -645,7 +645,7 @@ def build_naive_neighborhood( tuple[torch.Tensor, torch.Tensor, torch.Tensor]: A tuple containing: - mapping (torch.Tensor): A tensor of shape (n_pairs, 2) representing the pairs of indices for neighboring atoms. - - batch_mapping (torch.Tensor): A tensor of shape (n_pairs,) + - system_mapping (torch.Tensor): A tensor of shape (n_pairs,) indicating the structure index for each pair. 
- shifts_idx (torch.Tensor): A tensor of shape (n_pairs, 3) representing the shifts applied for periodic boundary @@ -659,7 +659,7 @@ def build_naive_neighborhood( stride = strides_of(n_atoms) ids = torch.arange(positions.shape[0], device=device, dtype=torch.long) - mapping, batch_mapping, shifts_idx_ = [], [], [] + mapping, system_mapping, shifts_idx_ = [], [], [] for i_structure in range(n_atoms.shape[0]): num_repeats = num_repeats_[i_structure] shifts_idx = get_cell_shift_idx(num_repeats, dtype) @@ -669,7 +669,7 @@ def build_naive_neighborhood( i_ids=i_ids, shifts_idx=shifts_idx, self_interaction=self_interaction ) mapping.append(s_mapping) - batch_mapping.append( + system_mapping.append( torch.full( (s_mapping.shape[0],), i_structure, @@ -680,7 +680,7 @@ def build_naive_neighborhood( shifts_idx_.append(shifts_idx) return ( torch.cat(mapping, dim=0).t(), - torch.cat(batch_mapping, dim=0), + torch.cat(system_mapping, dim=0), torch.cat(shifts_idx_, dim=0), ) @@ -998,7 +998,7 @@ def build_linked_cell_neighborhood( - mapping (torch.Tensor): A tensor containing pairs of indices where mapping[0] represents the central atom indices and mapping[1] represents their corresponding neighbor indices. - - batch_mapping (torch.Tensor): A tensor containing the structure indices + - system_mapping (torch.Tensor): A tensor containing the structure indices corresponding to each neighbor atom. 
- cell_shifts_idx (torch.Tensor): A tensor containing the cell shift indices for each neighbor atom, which are necessary for @@ -1014,7 +1014,7 @@ def build_linked_cell_neighborhood( stride = strides_of(n_atoms) - mapping, batch_mapping, cell_shifts_idx = [], [], [] + mapping, system_mapping, cell_shifts_idx = [], [], [] for i_structure in range(n_structure): # Compute the neighborhood with the linked cell algorithm neigh_atom, neigh_shift_idx = linked_cell( @@ -1025,7 +1025,7 @@ def build_linked_cell_neighborhood( self_interaction, ) - batch_mapping.append( + system_mapping.append( i_structure * torch.ones(neigh_atom.shape[1], dtype=torch.long, device=device) ) # Shift the mapping indices to access positions @@ -1034,7 +1034,7 @@ def build_linked_cell_neighborhood( return ( torch.cat(mapping, dim=1), - torch.cat(batch_mapping, dim=0), + torch.cat(system_mapping, dim=0), torch.cat(cell_shifts_idx, dim=0), ) diff --git a/torch_sim/typing.py b/torch_sim/typing.py index 13f0db94..94ec44ca 100644 --- a/torch_sim/typing.py +++ b/torch_sim/typing.py @@ -15,7 +15,7 @@ MemoryScaling = Literal["n_atoms_x_density", "n_atoms"] -StateKey = Literal["positions", "masses", "cell", "pbc", "atomic_numbers", "batch"] +StateKey = Literal["positions", "masses", "cell", "pbc", "atomic_numbers", "system_idx"] StateDict = dict[StateKey, torch.Tensor] SimStateVar = TypeVar("SimStateVar", bound="SimState") diff --git a/torch_sim/workflows/a2c.py b/torch_sim/workflows/a2c.py index 140c0031..e3abb8d5 100644 --- a/torch_sim/workflows/a2c.py +++ b/torch_sim/workflows/a2c.py @@ -730,9 +730,9 @@ def get_unit_cell_relaxed_structure( device, dtype = model.device, model.dtype logger = { - "energy": torch.zeros((max_iter, state.n_batches), device=device, dtype=dtype), + "energy": torch.zeros((max_iter, state.n_systems), device=device, dtype=dtype), "stress": torch.zeros( - (max_iter, state.n_batches, 3, 3), device=device, dtype=dtype + (max_iter, state.n_systems, 3, 3), device=device, dtype=dtype ), 
} From 6c79893a5d666cbd00194b6e8ddd489ecda01cce Mon Sep 17 00:00:00 2001 From: Curtis Chong Date: Thu, 24 Jul 2025 22:38:26 -0400 Subject: [PATCH 03/16] update metatomic checkpoint to fix tests (#223) --- docs/_static/draw_pkg_treemap.py | 2 +- examples/tutorials/metatomic_tutorial.py | 3 ++- examples/tutorials/using_graphpes_tutorial.py | 1 + tests/models/test_metatomic.py | 2 +- torch_sim/models/metatomic.py | 2 +- 5 files changed, 6 insertions(+), 4 deletions(-) diff --git a/docs/_static/draw_pkg_treemap.py b/docs/_static/draw_pkg_treemap.py index f339a604..44775ead 100644 --- a/docs/_static/draw_pkg_treemap.py +++ b/docs/_static/draw_pkg_treemap.py @@ -5,7 +5,7 @@ # /// script # dependencies = [ -# "pymatviz @ git+https://github.com/janosh/pymatviz", +# "pymatviz>=0.16.0", # "plotly!=6.2.0", # TODO remove pin pending https://github.com/plotly/plotly.py/issues/5253#issuecomment-3016615635 # ] # /// diff --git a/examples/tutorials/metatomic_tutorial.py b/examples/tutorials/metatomic_tutorial.py index 284ff2bf..54479439 100644 --- a/examples/tutorials/metatomic_tutorial.py +++ b/examples/tutorials/metatomic_tutorial.py @@ -4,7 +4,8 @@ # /// script # dependencies = [ # "metatrain[pet]==2025.7", -# "metatomic-torch>=0.1.1,<0.2" +# "metatomic-torch>=0.1.1,<0.2", +# "vesin-torch>=0.3.7", # ] # /// # diff --git a/examples/tutorials/using_graphpes_tutorial.py b/examples/tutorials/using_graphpes_tutorial.py index 7aea3531..7b84bf41 100644 --- a/examples/tutorials/using_graphpes_tutorial.py +++ b/examples/tutorials/using_graphpes_tutorial.py @@ -5,6 +5,7 @@ # dependencies = [ # "graph-pes>=0.0.30", # "torch==2.5", +# "vesin-torch>=0.3.7", # ] # /// # diff --git a/tests/models/test_metatomic.py b/tests/models/test_metatomic.py index 98bad1ff..f467e4e7 100644 --- a/tests/models/test_metatomic.py +++ b/tests/models/test_metatomic.py @@ -28,7 +28,7 @@ def metatomic_calculator(device: torch.device): """Load a pretrained metatomic model for testing.""" return 
ase_calculator.MetatomicCalculator( model=load_model( - "https://huggingface.co/lab-cosmo/pet-mad/resolve/main/models/pet-mad-latest.ckpt" + "https://huggingface.co/lab-cosmo/pet-mad/resolve/v1.1.0/models/pet-mad-v1.1.0.ckpt" ).export(), device=device, ) diff --git a/torch_sim/models/metatomic.py b/torch_sim/models/metatomic.py index 7aafb8d0..47fda077 100644 --- a/torch_sim/models/metatomic.py +++ b/torch_sim/models/metatomic.py @@ -103,7 +103,7 @@ def __init__( ) if model == "pet-mad": - path = "https://huggingface.co/lab-cosmo/pet-mad/resolve/main/models/pet-mad-latest.ckpt" + path = "https://huggingface.co/lab-cosmo/pet-mad/resolve/v1.1.0/models/pet-mad-v1.1.0.ckpt" self._model = load_model(path).export() elif model.endswith(".ckpt"): path = model From c0e3137e3e6b7f8b20b2bb18755deae1c4539916 Mon Sep 17 00:00:00 2001 From: Rhys Goodall Date: Thu, 24 Jul 2025 23:16:28 -0400 Subject: [PATCH 04/16] Update citation.cff (#225) --- citation.cff | 6 ++---- docs/_static/draw_pkg_treemap.py | 2 +- examples/scripts/6_Phonons/6.1_Phonons_MACE.py | 2 +- examples/scripts/6_Phonons/6.2_QuasiHarmonic_MACE.py | 2 +- 4 files changed, 5 insertions(+), 7 deletions(-) diff --git a/citation.cff b/citation.cff index c9898e03..79eb7956 100644 --- a/citation.cff +++ b/citation.cff @@ -1,5 +1,5 @@ cff-version: 1.2.0 -title: Torch-Sim +title: TorchSim message: If you use this software, please cite it as below. 
authors: - family-names: Gangan @@ -17,8 +17,6 @@ authors: license: MIT license-url: https://github.com/Radical-AI/torch-sim/blob/main/LICENSE repository-code: https://github.com/Radical-AI/torch-sim -type: software url: https://github.com/Radical-AI/torch-sim -doi: 10.5281/zenodo.7486816 -version: 0.1.0 +type: software date-released: 2025-04-02 diff --git a/docs/_static/draw_pkg_treemap.py b/docs/_static/draw_pkg_treemap.py index 44775ead..3a983387 100644 --- a/docs/_static/draw_pkg_treemap.py +++ b/docs/_static/draw_pkg_treemap.py @@ -5,7 +5,7 @@ # /// script # dependencies = [ -# "pymatviz>=0.16.0", +# "pymatviz==0.16.0", # "plotly!=6.2.0", # TODO remove pin pending https://github.com/plotly/plotly.py/issues/5253#issuecomment-3016615635 # ] # /// diff --git a/examples/scripts/6_Phonons/6.1_Phonons_MACE.py b/examples/scripts/6_Phonons/6.1_Phonons_MACE.py index f88fd351..38a50d60 100644 --- a/examples/scripts/6_Phonons/6.1_Phonons_MACE.py +++ b/examples/scripts/6_Phonons/6.1_Phonons_MACE.py @@ -4,7 +4,7 @@ # dependencies = [ # "mace-torch>=0.3.12", # "phonopy>=2.35", -# "pymatviz>=0.16", +# "pymatviz==0.16", # "seekpath", # "ase", # "plotly!=6.2.0", # TODO remove pin pending https://github.com/plotly/plotly.py/issues/5253#issuecomment-3016615635 diff --git a/examples/scripts/6_Phonons/6.2_QuasiHarmonic_MACE.py b/examples/scripts/6_Phonons/6.2_QuasiHarmonic_MACE.py index 0fdea6b4..fe1c041b 100644 --- a/examples/scripts/6_Phonons/6.2_QuasiHarmonic_MACE.py +++ b/examples/scripts/6_Phonons/6.2_QuasiHarmonic_MACE.py @@ -6,7 +6,7 @@ # dependencies = [ # "mace-torch>=0.3.12", # "phonopy>=2.35", -# "pymatviz>=0.16", +# "pymatviz==0.16", # "plotly!=6.2.0", # TODO remove pin pending https://github.com/plotly/plotly.py/issues/5253#issuecomment-3016615635 # ] # /// From 5371cb6a83cdc495f19ce43c17c9b01e7a83e3fd Mon Sep 17 00:00:00 2001 From: Kian Pu <59935793+kianpu34593@users.noreply.github.com> Date: Thu, 24 Jul 2025 23:17:14 -0400 Subject: [PATCH 05/16] add new states when 
the max_memory_scaler is updated (#222) --- torch_sim/autobatching.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/torch_sim/autobatching.py b/torch_sim/autobatching.py index 9cd55673..a691de36 100644 --- a/torch_sim/autobatching.py +++ b/torch_sim/autobatching.py @@ -952,6 +952,8 @@ def _get_first_batch(self) -> SimState: scale_factor=self.memory_scaling_factor, ) self.max_memory_scaler = self.max_memory_scaler * self.max_memory_padding + newer_states = self._get_next_states() + states = [*states, *newer_states] return concatenate_states([first_state, *states]) def next_batch( # noqa: C901 From 16bf8f83ad6764b3f900127f11b6b411ca1ff02e Mon Sep 17 00:00:00 2001 From: Rhys Goodall Date: Thu, 24 Jul 2025 23:34:49 -0400 Subject: [PATCH 06/16] fix broken code block in low level tutorial (#226) --- examples/tutorials/low_level_tutorial.py | 1 + 1 file changed, 1 insertion(+) diff --git a/examples/tutorials/low_level_tutorial.py b/examples/tutorials/low_level_tutorial.py index d863ca09..69d22e78 100644 --- a/examples/tutorials/low_level_tutorial.py +++ b/examples/tutorials/low_level_tutorial.py @@ -166,6 +166,7 @@ the course of the simulation can be passed to the `update_fn`. 
""" +# %% fire_init_fn, fire_update_fn = ts.unit_cell_fire( model=model, dt_max=0.1, From 926e043bd1186bfcc1f78f17e08103a1aff10cbf Mon Sep 17 00:00:00 2001 From: Curtis Chong Date: Wed, 6 Aug 2025 18:45:32 -0400 Subject: [PATCH 07/16] Rename more batch to system (#233) --- .../7_Others/7.3_Batched_neighbor_list.py | 12 ++--- examples/tutorials/low_level_tutorial.py | 6 +-- examples/tutorials/state_tutorial.py | 22 ++++----- tests/test_autobatching.py | 2 +- tests/test_integrators.py | 48 +++++++++---------- tests/test_neighbors.py | 14 +++--- tests/test_transforms.py | 14 +++--- torch_sim/models/graphpes.py | 6 +-- torch_sim/models/mace.py | 20 ++++---- torch_sim/models/orb.py | 12 ++--- torch_sim/models/sevennet.py | 6 +-- torch_sim/neighbors.py | 40 ++++++++-------- torch_sim/quantities.py | 2 +- torch_sim/runners.py | 34 ++++++------- 14 files changed, 119 insertions(+), 119 deletions(-) diff --git a/examples/scripts/7_Others/7.3_Batched_neighbor_list.py b/examples/scripts/7_Others/7.3_Batched_neighbor_list.py index 91141fd3..2b845c07 100644 --- a/examples/scripts/7_Others/7.3_Batched_neighbor_list.py +++ b/examples/scripts/7_Others/7.3_Batched_neighbor_list.py @@ -25,26 +25,26 @@ # Fix: Ensure pbc has the correct shape [n_systems, 3] pbc_tensor = torch.tensor([[pbc] * 3] * len(atoms_list), dtype=torch.bool) -mapping, mapping_batch, shifts_idx = torch_nl_linked_cell( +mapping, mapping_system, shifts_idx = torch_nl_linked_cell( cutoff, pos, cell, pbc_tensor, system_idx, self_interaction ) -cell_shifts = transforms.compute_cell_shifts(cell, shifts_idx, mapping_batch) +cell_shifts = transforms.compute_cell_shifts(cell, shifts_idx, mapping_system) dds = transforms.compute_distances_with_cell_shifts(pos, mapping, cell_shifts) print(mapping.shape) -print(mapping_batch.shape) +print(mapping_system.shape) print(shifts_idx.shape) print(cell_shifts.shape) print(dds.shape) -mapping_n2, mapping_batch_n2, shifts_idx_n2 = torch_nl_n2( +mapping_n2, mapping_system_n2, 
shifts_idx_n2 = torch_nl_n2( cutoff, pos, cell, pbc_tensor, system_idx, self_interaction ) -cell_shifts_n2 = transforms.compute_cell_shifts(cell, shifts_idx_n2, mapping_batch_n2) +cell_shifts_n2 = transforms.compute_cell_shifts(cell, shifts_idx_n2, mapping_system_n2) dds_n2 = transforms.compute_distances_with_cell_shifts(pos, mapping_n2, cell_shifts_n2) print(mapping_n2.shape) -print(mapping_batch_n2.shape) +print(mapping_system_n2.shape) print(shifts_idx_n2.shape) print(cell_shifts_n2.shape) print(dds_n2.shape) diff --git a/examples/tutorials/low_level_tutorial.py b/examples/tutorials/low_level_tutorial.py index 69d22e78..99c8702d 100644 --- a/examples/tutorials/low_level_tutorial.py +++ b/examples/tutorials/low_level_tutorial.py @@ -107,7 +107,7 @@ """ `SimState` objects can be passed directly to the model and it will compute the properties of the systems in the batch. The properties will be returned -either batchwise, like the energy, or atomwise, like the forces. +either systemwise, like the energy, or atomwise, like the forces. Note that the energy here refers to the potential energy of the system. 
""" @@ -116,9 +116,9 @@ model_outputs = model(state) print(f"Model outputs: {', '.join(list(model_outputs))}") -print(f"Energy is a batchwise property with shape: {model_outputs['energy'].shape}") +print(f"Energy is a systemwise property with shape: {model_outputs['energy'].shape}") print(f"Forces are an atomwise property with shape: {model_outputs['forces'].shape}") -print(f"Stress is a batchwise property with shape: {model_outputs['stress'].shape}") +print(f"Stress is a systemwise property with shape: {model_outputs['stress'].shape}") # %% [markdown] diff --git a/examples/tutorials/state_tutorial.py b/examples/tutorials/state_tutorial.py index 0bc9e341..0d3eca96 100644 --- a/examples/tutorials/state_tutorial.py +++ b/examples/tutorials/state_tutorial.py @@ -71,11 +71,11 @@ # %% [markdown] """ -SimState attributes fall into three categories: atomwise, batchwise, and global. +SimState attributes fall into three categories: atomwise, systemwise, and global. * Atomwise attributes are tensors with shape (n_atoms, ...), these are `positions`, - `masses`, `atomic_numbers`, and `batch`. Names are plural. -* Batchwise attributes are tensors with shape (n_systems, ...), this is just `cell` for + `masses`, `atomic_numbers`, and `system_idx`. Names are plural. +* Systemwise attributes are tensors with shape (n_systems, ...), this is just `cell` for the base SimState. Names are singular. * Global attributes have any other shape or type, just `pbc` here. Names are singular. 
@@ -112,7 +112,7 @@ f"Multi-state has {multi_state.n_atoms} total atoms across {multi_state.n_systems} systems" ) -# we can see how the shapes of batchwise, atomwise, and global properties change +# we can see how the shapes of atomwise, systemwise, and global properties change print(f"Positions shape: {multi_state.positions.shape}") print(f"Cell shape: {multi_state.cell.shape}") print(f"PBC: {multi_state.pbc}") @@ -142,7 +142,7 @@ SimState supports many convenience operations for manipulating batched states. Slicing is supported through fancy indexing, e.g. `state[[0, 1, 2]]` will return a new state -containing only the first three batches. The other operations are available through the +containing only the first three systems. The other operations are available through the `pop`, `split`, `clone`, and `to` methods. """ @@ -182,19 +182,19 @@ # %% [markdown] """ -You can extract specific batches from a batched state using Python's slicing syntax. +You can extract specific systems from a batched state using Python's slicing syntax. This is extremely useful for analyzing specific systems or for implementing complex workflows where different systems need separate processing: The slicing interface follows Python's standard indexing conventions, making it intuitive to use. Behind the scenes, TorchSim is creating a new SimState with only the -selected batches, maintaining all the necessary properties and relationships. +selected systems, maintaining all the necessary properties and relationships. 
Note the difference between these operations: -- `split()` returns all batches as separate states but doesn't modify the original -- `pop()` removes specified batches from the original state and returns them as +- `split()` returns all systems as separate states but doesn't modify the original +- `pop()` removes specified systems from the original state and returns them as separate states -- `__getitem__` (slicing) creates a new state with specified batches without modifying +- `__getitem__` (slicing) creates a new state with specified systems without modifying the original This flexibility allows you to structure your simulation workflows in the most @@ -203,7 +203,7 @@ ### Splitting and Popping Batches SimState provides methods to split a batched state into separate states or to remove -specific batches: +specific systems: """ # %% [markdown] diff --git a/tests/test_autobatching.py b/tests/test_autobatching.py index 7be28997..30544f64 100644 --- a/tests/test_autobatching.py +++ b/tests/test_autobatching.py @@ -149,7 +149,7 @@ def test_binning_auto_batcher( # Get batches until None is returned batches = list(batcher) - # Check we got the expected number of batches + # Check we got the expected number of systems assert len(batches) == len(batcher.batched_states) # Test restore_original_order diff --git a/tests/test_integrators.py b/tests/test_integrators.py index b6923aa5..ac7bf4b8 100644 --- a/tests/test_integrators.py +++ b/tests/test_integrators.py @@ -20,66 +20,66 @@ def test_calculate_momenta_basic(device: torch.device): seed = 42 dtype = torch.float64 - # Create test inputs for 3 batches with 2 atoms each + # Create test inputs for 3 systems with 2 atoms each n_atoms = 8 positions = torch.randn(n_atoms, 3, dtype=dtype, device=device) masses = torch.rand(n_atoms, dtype=dtype, device=device) + 0.5 - batch = torch.tensor( + system_idx = torch.tensor( [0, 0, 1, 1, 2, 2, 3, 3], device=device - ) # 3 batches with 2 atoms each + ) # 3 systems with 2 atoms each kT = 
torch.tensor([0.1, 0.2, 0.3, 0.4], dtype=dtype, device=device) # Run the function - momenta = calculate_momenta(positions, masses, batch, kT, seed=seed) + momenta = calculate_momenta(positions, masses, system_idx, kT, seed=seed) # Basic checks assert momenta.shape == positions.shape assert momenta.dtype == dtype assert momenta.device == device - # Check that each batch has zero center of mass momentum + # Check that each system has zero center of mass momentum for b in range(4): - batch_mask = batch == b - batch_momenta = momenta[batch_mask] - com_momentum = torch.mean(batch_momenta, dim=0) + system_mask = system_idx == b + system_momenta = momenta[system_mask] + com_momentum = torch.mean(system_momenta, dim=0) assert torch.allclose( com_momentum, torch.zeros(3, dtype=dtype, device=device), atol=1e-10 ) def test_calculate_momenta_single_atoms(device: torch.device): - """Test that calculate_momenta preserves momentum for batches with single atoms.""" + """Test that calculate_momenta preserves momentum for systems with single atoms.""" seed = 42 dtype = torch.float64 - # Create test inputs with some batches having single atoms + # Create test inputs with some systems having single atoms positions = torch.randn(5, 3, dtype=dtype, device=device) masses = torch.rand(5, dtype=dtype, device=device) + 0.5 - batch = torch.tensor( + system_idx = torch.tensor( [0, 1, 1, 2, 3], device=device - ) # Batches 0, 2, and 3 have single atoms + ) # systems 0, 2, and 3 have single atoms kT = torch.tensor([0.1, 0.2, 0.3, 0.4], dtype=dtype, device=device) # Generate momenta and save the raw values before COM correction generator = torch.Generator(device=device).manual_seed(seed) raw_momenta = torch.randn( positions.shape, device=device, dtype=dtype, generator=generator - ) * torch.sqrt(masses * kT[batch]).unsqueeze(-1) + ) * torch.sqrt(masses * kT[system_idx]).unsqueeze(-1) # Run the function - momenta = calculate_momenta(positions, masses, batch, kT, seed=seed) + momenta = 
calculate_momenta(positions, masses, system_idx, kT, seed=seed) - # Check that single-atom batches have unchanged momenta - for b in [0, 2, 3]: # Single atom batches - batch_mask = batch == b + # Check that single-atom systems have unchanged momenta + for b in [0, 2, 3]: # Single atom systems + system_mask = system_idx == b # The momentum should be exactly the same as the raw value for single atoms - assert torch.allclose(momenta[batch_mask], raw_momenta[batch_mask]) + assert torch.allclose(momenta[system_mask], raw_momenta[system_mask]) - # Check that multi-atom batches have zero COM - for b in [1]: # Multi-atom batches - batch_mask = batch == b - batch_momenta = momenta[batch_mask] - com_momentum = torch.mean(batch_momenta, dim=0) + # Check that multi-atom systems have zero COM + for b in [1]: # Multi-atom systems + system_mask = system_idx == b + system_momenta = momenta[system_mask] + com_momentum = torch.mean(system_momenta, dim=0) assert torch.allclose( com_momentum, torch.zeros(3, dtype=dtype, device=device), atol=1e-10 ) @@ -378,7 +378,7 @@ def test_compute_cell_force_atoms_per_system(): Covers fix in https://github.com/Radical-AI/torch-sim/pull/153.""" from torch_sim.integrators.npt import _compute_cell_force - # Setup minimal state with two batches having 8:1 atom ratio + # Setup minimal state with two systems having 8:1 atom ratio s1, s2 = torch.zeros(8, dtype=torch.long), torch.ones(64, dtype=torch.long) state = NPTLangevinState( diff --git a/tests/test_neighbors.py b/tests/test_neighbors.py index 205b626a..ac948acb 100644 --- a/tests/test_neighbors.py +++ b/tests/test_neighbors.py @@ -342,13 +342,13 @@ def test_torch_nl_implementations( ) # Get the neighbor list from the implementation being tested - mapping, mapping_batch, shifts_idx = nl_implementation( + mapping, mapping_system, shifts_idx = nl_implementation( cutoff, pos, row_vector_cell, pbc, batch, self_interaction ) # Calculate distances cell_shifts = transforms.compute_cell_shifts( - 
row_vector_cell, shifts_idx, mapping_batch + row_vector_cell, shifts_idx, mapping_system ) dds = transforms.compute_distances_with_cell_shifts(pos, mapping, cell_shifts) dds = np.sort(dds.numpy()) @@ -496,7 +496,7 @@ def test_strict_nl_edge_cases( # Test with no cell shifts mapping = torch.tensor([[0], [1]], device=device, dtype=torch.long) - batch_mapping = torch.tensor([0], device=device, dtype=torch.long) + system_mapping = torch.tensor([0], device=device, dtype=torch.long) shifts_idx = torch.zeros((1, 3), device=device, dtype=torch.long) new_mapping, new_batch, new_shifts = neighbors.strict_nl( @@ -504,14 +504,14 @@ def test_strict_nl_edge_cases( positions=pos, cell=cell, mapping=mapping, - batch_mapping=batch_mapping, + system_mapping=system_mapping, shifts_idx=shifts_idx, ) assert len(new_mapping[0]) > 0 # Should find neighbors # Test with different batch mappings mapping = torch.tensor([[0, 1], [1, 0]], device=device, dtype=torch.long) - batch_mapping = torch.tensor([0, 1], device=device, dtype=torch.long) + system_mapping = torch.tensor([0, 1], device=device, dtype=torch.long) shifts_idx = torch.zeros((2, 3), device=device, dtype=torch.long) new_mapping, new_batch, new_shifts = neighbors.strict_nl( @@ -519,7 +519,7 @@ def test_strict_nl_edge_cases( positions=pos, cell=cell, mapping=mapping, - batch_mapping=batch_mapping, + system_mapping=system_mapping, shifts_idx=shifts_idx, ) assert len(new_mapping[0]) > 0 # Should find neighbors @@ -559,7 +559,7 @@ def test_neighbor_lists_time_and_memory( system_idx = torch.zeros(n_atoms, dtype=torch.long, device=device) # Fix pbc tensor shape pbc = torch.tensor([[True, True, True]], device=device) - mapping, mapping_batch, shifts_idx = nl_fn( + mapping, mapping_system, shifts_idx = nl_fn( cutoff, pos, cell, pbc, system_idx, self_interaction=False ) else: diff --git a/tests/test_transforms.py b/tests/test_transforms.py index 4c05e658..ca965c69 100644 --- a/tests/test_transforms.py +++ b/tests/test_transforms.py @@ 
-1183,9 +1183,9 @@ def test_compute_cell_shifts_basic() -> None: """Test compute_cell_shifts function.""" cell = torch.eye(3).unsqueeze(0) * 2.0 shifts_idx = torch.tensor([[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]]) - batch_mapping = torch.tensor([0, 0]) + system_mapping = torch.tensor([0, 0]) - cell_shifts = tst.compute_cell_shifts(cell, shifts_idx, batch_mapping) + cell_shifts = tst.compute_cell_shifts(cell, shifts_idx, system_mapping) expected = torch.tensor([[2.0, 0.0, 0.0], [0.0, 2.0, 0.0]]) torch.testing.assert_close(cell_shifts, expected) @@ -1272,16 +1272,16 @@ def test_build_linked_cell_neighborhood_basic() -> None: cutoff = 1.5 n_atoms = torch.tensor([2, 2]) - mapping, batch_mapping, cell_shifts_idx = tst.build_linked_cell_neighborhood( + mapping, system_mapping, cell_shifts_idx = tst.build_linked_cell_neighborhood( positions, cell, pbc, cutoff, n_atoms, self_interaction=False ) # Check that atoms in the same structure are neighbors assert mapping.shape[1] >= 2 # At least 2 neighbor pairs - # Verify batch_mapping has correct length - assert batch_mapping.shape[0] == mapping.shape[1] + # Verify system_mapping has correct length + assert system_mapping.shape[0] == mapping.shape[1] # Verify that there are neighbors from both batches - assert torch.any(batch_mapping == 0) - assert torch.any(batch_mapping == 1) + assert torch.any(system_mapping == 0) + assert torch.any(system_mapping == 1) diff --git a/torch_sim/models/graphpes.py b/torch_sim/models/graphpes.py index fe51fe01..6ce52753 100644 --- a/torch_sim/models/graphpes.py +++ b/torch_sim/models/graphpes.py @@ -68,9 +68,9 @@ def state_to_atomic_graph(state: ts.SimState, cutoff: torch.Tensor) -> AtomicGra graphs = [] for i in range(state.n_systems): - batch_mask = state.system_idx == i - R = state.positions[batch_mask] - Z = state.atomic_numbers[batch_mask] + system_mask = state.system_idx == i + R = state.positions[system_mask] + Z = state.atomic_numbers[system_mask] cell = state.row_vector_cell[i] nl, shifts = 
vesin_nl_ts( R, diff --git a/torch_sim/models/mace.py b/torch_sim/models/mace.py index 4ec3ec3f..cfd34142 100644 --- a/torch_sim/models/mace.py +++ b/torch_sim/models/mace.py @@ -184,17 +184,17 @@ def __init__( # Store flag to track if atomic numbers were provided at init self.atomic_numbers_in_init = atomic_numbers is not None - # Set up batch information if atomic numbers are provided + # Set up system_idx information if atomic numbers are provided if atomic_numbers is not None: if system_idx is None: - # If batch is not provided, assume all atoms belong to same system + # If system_idx is not provided, assume all atoms belong to same system system_idx = torch.zeros( len(atomic_numbers), dtype=torch.long, device=self.device ) - self.setup_from_batch(atomic_numbers, system_idx) + self.setup_from_system_idx(atomic_numbers, system_idx) - def setup_from_batch( + def setup_from_system_idx( self, atomic_numbers: torch.Tensor, system_idx: torch.Tensor ) -> None: """Set up internal state from atomic numbers and system indices. 
@@ -286,7 +286,7 @@ def forward( # noqa: C901 ) state.system_idx = self.system_idx - # Update batch information if new atomic numbers are provided + # Update system_idx information if new atomic numbers are provided if ( state.atomic_numbers is not None and not self.atomic_numbers_in_init @@ -295,7 +295,7 @@ def forward( # noqa: C901 getattr(self, "atomic_numbers", torch.zeros(0, device=self.device)), ) ): - self.setup_from_batch(state.atomic_numbers, state.system_idx) + self.setup_from_system_idx(state.atomic_numbers, state.system_idx) # Process each system's neighbor list separately edge_indices = [] @@ -305,16 +305,16 @@ def forward( # noqa: C901 # TODO (AG): Currently doesn't work for batched neighbor lists for b in range(self.n_systems): - batch_mask = state.system_idx == b + system_mask = state.system_idx == b # Calculate neighbor list for this system edge_idx, shifts_idx = self.neighbor_list_fn( - positions=state.positions[batch_mask], + positions=state.positions[system_mask], cell=state.row_vector_cell[b], pbc=state.pbc, cutoff=self.r_max, ) - # Adjust indices for the batch + # Adjust indices for the system edge_idx = edge_idx + offset shifts = torch.mm(shifts_idx, state.row_vector_cell[b]) @@ -322,7 +322,7 @@ def forward( # noqa: C901 unit_shifts_list.append(shifts_idx) shifts_list.append(shifts) - offset += len(state.positions[batch_mask]) + offset += len(state.positions[system_mask]) # Combine all neighbor lists edge_index = torch.cat(edge_indices, dim=1) diff --git a/torch_sim/models/orb.py b/torch_sim/models/orb.py index cf015fb2..7b4bffd7 100644 --- a/torch_sim/models/orb.py +++ b/torch_sim/models/orb.py @@ -157,10 +157,10 @@ def state_to_atom_graphs( # noqa: PLR0915 # Process each system in a single loop offset = 0 for i in range(n_systems): - batch_mask = state.system_idx == i - positions_per_system = positions[batch_mask] - atomic_numbers_per_system = atomic_numbers[batch_mask] - atomic_numbers_embedding_per_system = 
atomic_numbers_embedding[batch_mask] + system_mask = state.system_idx == i + positions_per_system = positions[system_mask] + atomic_numbers_per_system = atomic_numbers[system_mask] + atomic_numbers_embedding_per_system = atomic_numbers_embedding[system_mask] cell_per_system = row_vector_cell[i] pbc_per_system = pbc @@ -223,7 +223,7 @@ def state_to_atom_graphs( # noqa: PLR0915 # Concatenate all the edge data edge_index = torch.cat(all_edges, dim=1) unit_shifts = torch.cat(all_unit_shifts, dim=0) - batch_num_edges = torch.tensor(num_edges, dtype=torch.int64, device=device) + system_num_edges = torch.tensor(num_edges, dtype=torch.int64, device=device) senders, receivers = edge_index[0], edge_index[1] @@ -232,7 +232,7 @@ def state_to_atom_graphs( # noqa: PLR0915 senders=senders, receivers=receivers, n_node=n_node, - n_edge=batch_num_edges, + n_edge=system_num_edges, node_features=_map_concat(node_feats_list), edge_features=_map_concat(edge_feats_list), system_features=_map_concat(graph_feats_list), diff --git a/torch_sim/models/sevennet.py b/torch_sim/models/sevennet.py index c4e3d96b..6156fc17 100644 --- a/torch_sim/models/sevennet.py +++ b/torch_sim/models/sevennet.py @@ -182,13 +182,13 @@ def forward(self, state: ts.SimState | StateDict) -> dict[str, torch.Tensor]: data_list = [] for b in range(state.system_idx.max().item() + 1): - batch_mask = state.system_idx == b + system_mask = state.system_idx == b - pos = state.positions[batch_mask] + pos = state.positions[system_mask] # SevenNet uses row vector cell convention for neighbor list row_vector_cell = state.row_vector_cell[b] pbc = state.pbc - atomic_numbers = state.atomic_numbers[batch_mask] + atomic_numbers = state.atomic_numbers[system_mask] edge_idx, shifts_idx = self.neighbor_list_fn( positions=pos, diff --git a/torch_sim/neighbors.py b/torch_sim/neighbors.py index 9c997eac..40091eee 100644 --- a/torch_sim/neighbors.py +++ b/torch_sim/neighbors.py @@ -642,7 +642,7 @@ def strict_nl( positions: torch.Tensor, 
cell: torch.Tensor, mapping: torch.Tensor, - batch_mapping: torch.Tensor, + system_mapping: torch.Tensor, shifts_idx: torch.Tensor, ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: """Apply a strict cutoff to the neighbor list defined in the mapping. @@ -663,7 +663,7 @@ def strict_nl( mapping (torch.Tensor): A tensor of shape (2, n_pairs) that specifies pairs of indices in `positions` for which to compute distances. - batch_mapping (torch.Tensor): + system_mapping (torch.Tensor): A tensor that maps the shifts to the corresponding cells, used in conjunction with `shifts_idx` to compute the correct periodic shifts. shifts_idx (torch.Tensor): @@ -675,8 +675,8 @@ def strict_nl( A tuple containing: - mapping (torch.Tensor): A filtered tensor of shape (2, n_filtered_pairs) with pairs of indices that are within the cutoff distance. - - mapping_batch (torch.Tensor): A tensor of shape (n_filtered_pairs,) - that maps the filtered pairs to their corresponding batches. + - mapping_system (torch.Tensor): A tensor of shape (n_filtered_pairs,) + that maps the filtered pairs to their corresponding systems. - shifts_idx (torch.Tensor): A tensor of shape (n_filtered_pairs, 3) containing the periodic shift indices for the filtered pairs. 
@@ -689,7 +689,7 @@ def strict_nl( References: - https://github.com/felixmusil/torch_nl """ - cell_shifts = transforms.compute_cell_shifts(cell, shifts_idx, batch_mapping) + cell_shifts = transforms.compute_cell_shifts(cell, shifts_idx, system_mapping) if cell_shifts is None: d2 = (positions[mapping[0]] - positions[mapping[1]]).square().sum(dim=1) else: @@ -701,9 +701,9 @@ def strict_nl( mask = d2 < cutoff * cutoff mapping = mapping[:, mask] - mapping_batch = batch_mapping[mask] + mapping_system = system_mapping[mask] shifts_idx = shifts_idx[mask] - return mapping, mapping_batch, shifts_idx + return mapping, mapping_system, shifts_idx @torch.jit.script @@ -712,7 +712,7 @@ def torch_nl_n2( positions: torch.Tensor, cell: torch.Tensor, pbc: torch.Tensor, - batch: torch.Tensor, + system_idx: torch.Tensor, self_interaction: bool = False, # noqa: FBT001, FBT002 ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: """Compute the neighbor list for a set of atomic structures using a @@ -729,7 +729,7 @@ def torch_nl_n2( pbc (torch.Tensor [n_structure, 3] bool): A tensor indicating the periodic boundary conditions to apply. Partial PBC are not supported yet. - batch (torch.Tensor [n_atom,] torch.long): + system_idx (torch.Tensor [n_atom,] torch.long): A tensor containing the index of the structure to which each atom belongs. self_interaction (bool, optional): A flag to indicate whether to keep the center atoms as their own neighbors. @@ -741,7 +741,7 @@ def torch_nl_n2( A tensor containing the indices of the neighbor list for the given positions array. `mapping[0]` corresponds to the central atom indices, and `mapping[1]` corresponds to the neighbor atom indices. - batch_mapping (torch.Tensor [n_neighbors]): + system_mapping (torch.Tensor [n_neighbors]): A tensor mapping the neighbor atoms to their respective structures. 
shifts_idx (torch.Tensor [n_neighbors, 3]): A tensor containing the cell shift indices used to reconstruct the @@ -750,14 +750,14 @@ def torch_nl_n2( References: - https://github.com/felixmusil/torch_nl """ - n_atoms = torch.bincount(batch) - mapping, batch_mapping, shifts_idx = transforms.build_naive_neighborhood( + n_atoms = torch.bincount(system_idx) + mapping, system_mapping, shifts_idx = transforms.build_naive_neighborhood( positions, cell, pbc, cutoff, n_atoms, self_interaction ) - mapping, mapping_batch, shifts_idx = strict_nl( - cutoff, positions, cell, mapping, batch_mapping, shifts_idx + mapping, mapping_system, shifts_idx = strict_nl( + cutoff, positions, cell, mapping, system_mapping, shifts_idx ) - return mapping, mapping_batch, shifts_idx + return mapping, mapping_system, shifts_idx @torch.jit.script @@ -797,7 +797,7 @@ def torch_nl_linked_cell( A tensor containing the indices of the neighbor list for the given positions array. `mapping[0]` corresponds to the central atom indices, and `mapping[1]` corresponds to the neighbor atom indices. - - batch_mapping (torch.Tensor [n_neighbors]): + - system_mapping (torch.Tensor [n_neighbors]): A tensor mapping the neighbor atoms to their respective structures. 
- shifts_idx (torch.Tensor [n_neighbors, 3]): A tensor containing the cell shift indices used to reconstruct the @@ -807,11 +807,11 @@ def torch_nl_linked_cell( - https://github.com/felixmusil/torch_nl """ n_atoms = torch.bincount(system_idx) - mapping, batch_mapping, shifts_idx = transforms.build_linked_cell_neighborhood( + mapping, system_mapping, shifts_idx = transforms.build_linked_cell_neighborhood( positions, cell, pbc, cutoff, n_atoms, self_interaction ) - mapping, mapping_batch, shifts_idx = strict_nl( - cutoff, positions, cell, mapping, batch_mapping, shifts_idx + mapping, mapping_system, shifts_idx = strict_nl( + cutoff, positions, cell, mapping, system_mapping, shifts_idx ) - return mapping, mapping_batch, shifts_idx + return mapping, mapping_system, shifts_idx diff --git a/torch_sim/quantities.py b/torch_sim/quantities.py index 5d9898c2..971b1b54 100644 --- a/torch_sim/quantities.py +++ b/torch_sim/quantities.py @@ -141,7 +141,7 @@ def get_pressure( return 1 / dim * ((2 * kinetic_energy / volume) - torch.einsum("...ii", stress)) -def batchwise_max_force(state: SimState) -> torch.Tensor: +def systemwise_max_force(state: SimState) -> torch.Tensor: """Compute the maximum force per system. 
Args: diff --git a/torch_sim/runners.py b/torch_sim/runners.py index ef1dcb16..3c83724c 100644 --- a/torch_sim/runners.py +++ b/torch_sim/runners.py @@ -22,7 +22,7 @@ UnitCellFireState, UnitCellGDState, ) -from torch_sim.quantities import batchwise_max_force, calc_kinetic_energy, calc_kT +from torch_sim.quantities import calc_kinetic_energy, calc_kT, systemwise_max_force from torch_sim.state import SimState, concatenate_states, initialize_state from torch_sim.trajectory import TrajectoryReporter from torch_sim.typing import StateLike @@ -174,14 +174,14 @@ def integrate( pbar_kwargs.setdefault("disable", None) tqdm_pbar = tqdm(total=state.n_systems, **pbar_kwargs) - for state, batch_indices in batch_iterator: + for state, system_indices in batch_iterator: state = init_fn(state) # set up trajectory reporters if autobatcher and trajectory_reporter: - # we must remake the trajectory reporter for each batch + # we must remake the trajectory reporter for each system trajectory_reporter.load_new_trajectories( - filenames=[og_filenames[i] for i in batch_indices] + filenames=[og_filenames[i] for i in system_indices] ) # run the simulation @@ -278,7 +278,7 @@ def _chunked_apply( autobatcher.load_states(states) initialized_states = [] - initialized_states = [fn(batch) for batch in autobatcher] + initialized_states = [fn(system) for system in autobatcher] ordered_states = autobatcher.restore_original_order(initialized_states) return concatenate_states(ordered_states) @@ -297,7 +297,7 @@ def generate_force_convergence_fn( Returns: Convergence function that takes a state and last energy and - returns a batchwise boolean function + returns a systemwise boolean function """ def convergence_fn( @@ -308,9 +308,9 @@ def convergence_fn( Returns: torch.Tensor: Boolean tensor of shape (n_systems,) indicating - convergence status for each batch. + convergence status for each system. 
""" - force_conv = batchwise_max_force(state) < force_tol + force_conv = systemwise_max_force(state) < force_tol if include_cell_forces: if (cell_forces := getattr(state, "cell_forces", None)) is None: @@ -333,7 +333,7 @@ def generate_energy_convergence_fn(energy_tol: float = 1e-3) -> Callable: Returns: Convergence function that takes a state and last energy and - returns a batchwise boolean function + returns a systemwise boolean function """ def convergence_fn( @@ -344,7 +344,7 @@ def convergence_fn( Returns: torch.Tensor: Boolean tensor of shape (n_systems,) indicating - convergence status for each batch. + convergence status for each system. """ return torch.abs(state.energy - last_energy) < energy_tol @@ -437,13 +437,13 @@ def optimize( # noqa: C901 tqdm_pbar = tqdm(total=state.n_systems, **pbar_kwargs) while (result := autobatcher.next_batch(state, convergence_tensor))[0] is not None: - state, converged_states, batch_indices = result + state, converged_states, system_indices = result all_converged_states.extend(converged_states) # need to update the trajectory reporter if any states have converged if trajectory_reporter and (step == 1 or len(converged_states) > 0): trajectory_reporter.load_new_trajectories( - filenames=[og_filenames[i] for i in batch_indices] + filenames=[og_filenames[i] for i in system_indices] ) for _step in range(steps_between_swaps): @@ -487,8 +487,8 @@ def static( """Run single point calculations on a batch of systems. Unlike the other runners, this function does not return a state. Instead, it - returns a list of dictionaries, one for each batch in the input state. Each - dictionary contains the properties calculated for that batch. It will also + returns a list of dictionaries, one for each system in the input state. Each + dictionary contains the properties calculated for that system. It will also modify the state in place with the "energy", "forces", and "stress" properties if they are present in the model output. 
@@ -547,12 +547,12 @@ class StaticState(type(state)): pbar_kwargs.setdefault("disable", None) tqdm_pbar = tqdm(total=state.n_systems, **pbar_kwargs) - for sub_state, batch_indices in batch_iterator: + for sub_state, system_indices in batch_iterator: # set up trajectory reporters if autobatcher and trajectory_reporter and og_filenames is not None: - # we must remake the trajectory reporter for each batch + # we must remake the trajectory reporter for each system trajectory_reporter.load_new_trajectories( - filenames=[og_filenames[idx] for idx in batch_indices] + filenames=[og_filenames[idx] for idx in system_indices] ) model_outputs = model(sub_state) From 88abcff6ef2c417042a1a63b947e1e2dba688635 Mon Sep 17 00:00:00 2001 From: Curtis Chong Date: Wed, 6 Aug 2025 21:43:31 -0400 Subject: [PATCH 08/16] Make system_idx non-optional in `SimState` [1/2] (#231) --- torch_sim/integrators/nvt.py | 1 + torch_sim/state.py | 55 ++++++++++++++++++++++++++---------- 2 files changed, 41 insertions(+), 15 deletions(-) diff --git a/torch_sim/integrators/nvt.py b/torch_sim/integrators/nvt.py index ff9b7b4b..19a811c8 100644 --- a/torch_sim/integrators/nvt.py +++ b/torch_sim/integrators/nvt.py @@ -389,6 +389,7 @@ def nvt_nose_hoover_init( cell=state.cell, pbc=state.pbc, atomic_numbers=atomic_numbers, + system_idx=state.system_idx, chain=chain_fns.initialize(total_dof, KE, kT), _chain_fns=chain_fns, # Store the chain functions ) diff --git a/torch_sim/state.py b/torch_sim/state.py index ce21ef9b..5240d094 100644 --- a/torch_sim/state.py +++ b/torch_sim/state.py @@ -7,7 +7,7 @@ import copy import importlib import warnings -from dataclasses import dataclass, field +from dataclasses import dataclass from typing import TYPE_CHECKING, Literal, Self import torch @@ -22,7 +22,7 @@ from pymatgen.core import Structure -@dataclass +@dataclass(init=False) class SimState: """State representation for atomistic systems with batched operations support. @@ -47,9 +47,8 @@ class SimState: used by ASE. 
pbc (bool): Boolean indicating whether to use periodic boundary conditions atomic_numbers (torch.Tensor): Atomic numbers with shape (n_atoms,) - system_idx (torch.Tensor, optional): Maps each atom index to its system index. - Has shape (n_atoms,), defaults to None, must be unique consecutive - integers starting from 0 + system_idx (torch.Tensor): Maps each atom index to its system index. + Has shape (n_atoms,), must be unique consecutive integers starting from 0. Properties: wrap_positions (torch.Tensor): Positions wrapped according to periodic boundary @@ -81,10 +80,35 @@ class SimState: cell: torch.Tensor pbc: bool # TODO: do all calculators support mixed pbc? atomic_numbers: torch.Tensor - system_idx: torch.Tensor | None = field(default=None, kw_only=True) + system_idx: torch.Tensor + + def __init__( + self, + positions: torch.Tensor, + masses: torch.Tensor, + cell: torch.Tensor, + pbc: bool, # noqa: FBT001 # TODO(curtis): maybe make the constructor be keyword-only (it can be easy to confuse positions vs masses, etc.) + atomic_numbers: torch.Tensor, + system_idx: torch.Tensor | None = None, + ) -> None: + """Initialize the SimState and validate the arguments. + + Args: + positions (torch.Tensor): Atomic positions with shape (n_atoms, 3) + masses (torch.Tensor): Atomic masses with shape (n_atoms,) + cell (torch.Tensor): Unit cell vectors with shape (n_systems, 3, 3). + pbc (bool): Boolean indicating whether to use periodic boundary conditions + atomic_numbers (torch.Tensor): Atomic numbers with shape (n_atoms,) + system_idx (torch.Tensor | None): Maps each atom index to its system index. + Has shape (n_atoms,), must be unique consecutive integers starting from 0. + If not provided, it is initialized to zeros. 
+ """ + self.positions = positions + self.masses = masses + self.cell = cell + self.pbc = pbc + self.atomic_numbers = atomic_numbers - def __post_init__(self) -> None: - """Validate and process the state after initialization.""" # data validation and fill system_idx # should make pbc a tensor here # if devices aren't all the same, raise an error, in a clean way @@ -107,17 +131,12 @@ def __post_init__(self) -> None: f"masses {shapes[1]}, atomic_numbers {shapes[2]}" ) - if self.cell.ndim != 3 and self.system_idx is None: - self.cell = self.cell.unsqueeze(0) - - if self.cell.shape[-2:] != (3, 3): - raise ValueError("Cell must have shape (n_systems, 3, 3)") - - if self.system_idx is None: + if system_idx is None: self.system_idx = torch.zeros( self.n_atoms, device=self.device, dtype=torch.int64 ) else: + self.system_idx = system_idx # assert that system indices are unique consecutive integers # TODO(curtis): I feel like this logic is not reliable. # I'll come up with something better later. 
@@ -125,6 +144,12 @@ def __post_init__(self) -> None: if not torch.all(counts == torch.bincount(self.system_idx)): raise ValueError("System indices must be unique consecutive integers") + if self.cell.ndim != 3 and system_idx is None: + self.cell = self.cell.unsqueeze(0) + + if self.cell.shape[-2:] != (3, 3): + raise ValueError("Cell must have shape (n_systems, 3, 3)") + if self.cell.shape[0] != self.n_systems: raise ValueError( f"Cell must have shape (n_systems, 3, 3), got {self.cell.shape}" From f6cd00659df2db971fb8a773e9d431122846a9b8 Mon Sep 17 00:00:00 2001 From: Curtis Chong Date: Thu, 7 Aug 2025 22:27:11 -0400 Subject: [PATCH 09/16] Improve Typing of ModelInterface (#215) Signed-off-by: Rhys Goodall Co-authored-by: Rhys Goodall --- .github/PULL_REQUEST_TEMPLATE.md | 3 +- .../2.1_Lennard_Jones_FIRE.py | 2 +- .../2.2_Soft_Sphere_FIRE.py | 2 +- .../3.11_Lennard_Jones_NPT_Langevin.py | 6 +- .../3_Dynamics/3.1_Lennard_Jones_NVE.py | 6 +- examples/scripts/3_Dynamics/3.2_MACE_NVE.py | 6 +- .../3_Dynamics/3.4_MACE_NVT_Langevin.py | 6 +- .../3.7_Lennard_Jones_NPT_Nose_Hoover.py | 2 +- .../4_High_level_api/4.1_high_level_api.py | 4 +- .../6_Phonons/6.2_QuasiHarmonic_MACE.py | 7 +- examples/tutorials/high_level_tutorial.py | 2 +- examples/tutorials/reporting_tutorial.py | 3 +- tests/models/test_mattersim.py | 4 +- tests/models/test_sevennet.py | 5 +- tests/test_integrators.py | 16 ++- tests/test_monte_carlo.py | 3 +- tests/test_optimizers.py | 31 ++-- tests/test_quantities.py | 136 ++++++++++++++++++ tests/test_runners.py | 30 +++- tests/test_state.py | 12 +- tests/test_trajectory.py | 3 +- tests/workflows/test_a2c.py | 5 +- torch_sim/elastic.py | 3 +- torch_sim/integrators/md.py | 3 +- torch_sim/integrators/npt.py | 27 ++-- torch_sim/integrators/nve.py | 3 +- torch_sim/integrators/nvt.py | 17 ++- torch_sim/models/fairchem.py | 4 +- torch_sim/models/graphpes.py | 4 +- torch_sim/models/interface.py | 39 +---- torch_sim/models/lennard_jones.py | 2 +- 
torch_sim/models/mace.py | 4 +- torch_sim/models/mattersim.py | 4 +- torch_sim/models/metatomic.py | 4 +- torch_sim/models/morse.py | 2 +- torch_sim/models/orb.py | 4 +- torch_sim/models/particle_life.py | 2 +- torch_sim/models/sevennet.py | 4 +- torch_sim/models/soft_sphere.py | 12 +- torch_sim/monte_carlo.py | 3 +- torch_sim/optimizers.py | 19 +-- torch_sim/quantities.py | 32 +++-- torch_sim/runners.py | 8 +- torch_sim/state.py | 41 +++--- torch_sim/trajectory.py | 9 +- torch_sim/workflows/a2c.py | 89 +++--------- 46 files changed, 364 insertions(+), 269 deletions(-) create mode 100644 tests/test_quantities.py diff --git a/.github/PULL_REQUEST_TEMPLATE.md b/.github/PULL_REQUEST_TEMPLATE.md index 381c356f..b71a5910 100644 --- a/.github/PULL_REQUEST_TEMPLATE.md +++ b/.github/PULL_REQUEST_TEMPLATE.md @@ -8,7 +8,8 @@ Before a pull request can be merged, the following items must be checked: * [ ] Doc strings have been added in the [Google docstring format](https://sphinxcontrib-napoleon.readthedocs.io/en/latest/example_google.html#example-google). - Run [ruff](https://beta.ruff.rs/docs/rules/#pydocstyle-d) on your code. +* [ ] Run [ruff](https://beta.ruff.rs/docs/rules/#pydocstyle-d) on your code. +* [ ] Run `uvx ty check` on the repo. * [ ] Tests have been added for any new functionality or bug fixes. We highly recommended installing the pre-commit hooks running in CI locally to speedup the development process. Simply run `pip install pre-commit && pre-commit install` to install the hooks which will check your code before each commit. 
diff --git a/examples/scripts/2_Structural_optimization/2.1_Lennard_Jones_FIRE.py b/examples/scripts/2_Structural_optimization/2.1_Lennard_Jones_FIRE.py index ff11162b..9ddd8ba7 100644 --- a/examples/scripts/2_Structural_optimization/2.1_Lennard_Jones_FIRE.py +++ b/examples/scripts/2_Structural_optimization/2.1_Lennard_Jones_FIRE.py @@ -89,8 +89,8 @@ positions=positions, masses=masses, cell=cell.unsqueeze(0), - pbc=True, atomic_numbers=atomic_numbers, + pbc=True, ) # Run initial simulation and get results diff --git a/examples/scripts/2_Structural_optimization/2.2_Soft_Sphere_FIRE.py b/examples/scripts/2_Structural_optimization/2.2_Soft_Sphere_FIRE.py index 7571f43d..3956c956 100644 --- a/examples/scripts/2_Structural_optimization/2.2_Soft_Sphere_FIRE.py +++ b/examples/scripts/2_Structural_optimization/2.2_Soft_Sphere_FIRE.py @@ -80,8 +80,8 @@ positions=positions, masses=masses, cell=cell.unsqueeze(0), - pbc=True, atomic_numbers=atomic_numbers, + pbc=True, ) # Initialize the Soft Sphere model diff --git a/examples/scripts/3_Dynamics/3.11_Lennard_Jones_NPT_Langevin.py b/examples/scripts/3_Dynamics/3.11_Lennard_Jones_NPT_Langevin.py index 8c87f5a6..6f81519e 100644 --- a/examples/scripts/3_Dynamics/3.11_Lennard_Jones_NPT_Langevin.py +++ b/examples/scripts/3_Dynamics/3.11_Lennard_Jones_NPT_Langevin.py @@ -97,8 +97,8 @@ positions=positions, masses=masses, cell=cell.unsqueeze(0), - pbc=True, atomic_numbers=atomic_numbers, + pbc=True, ) # Run initial simulation and get results results = model(state) @@ -148,11 +148,11 @@ stress = model(state)["stress"] -calc_kinetic_energy = calc_kinetic_energy( +kinetic_energy = calc_kinetic_energy( masses=state.masses, momenta=state.momenta, system_idx=state.system_idx ) volume = torch.linalg.det(state.cell) -pressure = get_pressure(stress, calc_kinetic_energy, volume) +pressure = get_pressure(stress, kinetic_energy, volume) pressure = pressure.item() / Units.pressure print(f"Final {pressure=:.4f}") print(stress * 
UnitConversion.eV_per_Ang3_to_GPa) diff --git a/examples/scripts/3_Dynamics/3.1_Lennard_Jones_NVE.py b/examples/scripts/3_Dynamics/3.1_Lennard_Jones_NVE.py index 92ea04a7..9506e9d4 100644 --- a/examples/scripts/3_Dynamics/3.1_Lennard_Jones_NVE.py +++ b/examples/scripts/3_Dynamics/3.1_Lennard_Jones_NVE.py @@ -78,11 +78,7 @@ masses = torch.full((positions.shape[0],), 39.948, device=device, dtype=dtype) state = ts.SimState( - positions=positions, - masses=masses, - cell=cell, - pbc=True, - atomic_numbers=atomic_numbers, + positions=positions, masses=masses, cell=cell, atomic_numbers=atomic_numbers, pbc=True ) # Initialize the Lennard-Jones model # Parameters: diff --git a/examples/scripts/3_Dynamics/3.2_MACE_NVE.py b/examples/scripts/3_Dynamics/3.2_MACE_NVE.py index 10cc1dc4..4151eb1a 100644 --- a/examples/scripts/3_Dynamics/3.2_MACE_NVE.py +++ b/examples/scripts/3_Dynamics/3.2_MACE_NVE.py @@ -60,11 +60,7 @@ ) state = ts.SimState( - positions=positions, - masses=masses, - cell=cell, - pbc=True, - atomic_numbers=atomic_numbers, + positions=positions, masses=masses, cell=cell, atomic_numbers=atomic_numbers, pbc=True ) # Run initial inference diff --git a/examples/scripts/3_Dynamics/3.4_MACE_NVT_Langevin.py b/examples/scripts/3_Dynamics/3.4_MACE_NVT_Langevin.py index d7846c7e..c998e950 100644 --- a/examples/scripts/3_Dynamics/3.4_MACE_NVT_Langevin.py +++ b/examples/scripts/3_Dynamics/3.4_MACE_NVT_Langevin.py @@ -59,11 +59,7 @@ ) state = ts.SimState( - positions=positions, - masses=masses, - cell=cell, - pbc=True, - atomic_numbers=atomic_numbers, + positions=positions, masses=masses, cell=cell, atomic_numbers=atomic_numbers, pbc=True ) dt = 0.002 * Units.time # Timestep (ps) diff --git a/examples/scripts/3_Dynamics/3.7_Lennard_Jones_NPT_Nose_Hoover.py b/examples/scripts/3_Dynamics/3.7_Lennard_Jones_NPT_Nose_Hoover.py index 6375bc61..0c1ffa58 100644 --- a/examples/scripts/3_Dynamics/3.7_Lennard_Jones_NPT_Nose_Hoover.py +++ 
b/examples/scripts/3_Dynamics/3.7_Lennard_Jones_NPT_Nose_Hoover.py @@ -96,8 +96,8 @@ positions=positions, masses=masses, cell=cell.unsqueeze(0), - pbc=True, atomic_numbers=atomic_numbers, + pbc=True, ) # Run initial simulation and get results results = model(state) diff --git a/examples/scripts/4_High_level_api/4.1_high_level_api.py b/examples/scripts/4_High_level_api/4.1_high_level_api.py index f09aa039..396ca035 100644 --- a/examples/scripts/4_High_level_api/4.1_high_level_api.py +++ b/examples/scripts/4_High_level_api/4.1_high_level_api.py @@ -54,7 +54,9 @@ prop_calculators = { 10: {"potential_energy": lambda state: state.energy}, 20: { - "kinetic_energy": lambda state: calc_kinetic_energy(state.momenta, state.masses) + "kinetic_energy": lambda state: calc_kinetic_energy( + momenta=state.momenta, masses=state.masses + ) }, } diff --git a/examples/scripts/6_Phonons/6.2_QuasiHarmonic_MACE.py b/examples/scripts/6_Phonons/6.2_QuasiHarmonic_MACE.py index fe1c041b..4b4edea4 100644 --- a/examples/scripts/6_Phonons/6.2_QuasiHarmonic_MACE.py +++ b/examples/scripts/6_Phonons/6.2_QuasiHarmonic_MACE.py @@ -24,12 +24,13 @@ from phonopy.structure.atoms import PhonopyAtoms import torch_sim as ts +from torch_sim.models.interface import ModelInterface from torch_sim.models.mace import MaceModel, MaceUrls def get_relaxed_structure( struct: Atoms, - model: torch.nn.Module | None, + model: ModelInterface, Nrelax: int = 300, fmax: float = 1e-3, *, @@ -80,7 +81,7 @@ def get_relaxed_structure( def get_qha_structures( state: ts.state.SimState, length_factors: np.ndarray, - model: torch.nn.Module | None, + model: ModelInterface, Nmax: int = 300, fmax: float = 1e-3, *, @@ -129,7 +130,7 @@ def get_qha_structures( def get_qha_phonons( scaled_structures: list[PhonopyAtoms], - model: torch.nn.Module | None, + model: ModelInterface, supercell_matrix: np.ndarray | None, displ: float = 0.05, *, diff --git a/examples/tutorials/high_level_tutorial.py b/examples/tutorials/high_level_tutorial.py 
index c26d7479..cf0debb4 100644 --- a/examples/tutorials/high_level_tutorial.py +++ b/examples/tutorials/high_level_tutorial.py @@ -132,7 +132,7 @@ 10: {"potential_energy": lambda state: state.energy}, 20: { "kinetic_energy": lambda state: ts.calc_kinetic_energy( - state.momenta, state.masses + momenta=state.momenta, masses=state.masses ) }, } diff --git a/examples/tutorials/reporting_tutorial.py b/examples/tutorials/reporting_tutorial.py index c47340fe..47477613 100644 --- a/examples/tutorials/reporting_tutorial.py +++ b/examples/tutorials/reporting_tutorial.py @@ -206,6 +206,7 @@ # %% from torch_sim.models.lennard_jones import LennardJonesModel +from torch_sim.models.interface import ModelInterface # Define some property calculators @@ -214,7 +215,7 @@ def calculate_com(state: ts.state.SimState) -> torch.Tensor: return torch.mean(state.positions * state.masses.unsqueeze(1), dim=0) -def calculate_energy(state: ts.state.SimState, model: torch.nn.Module) -> torch.Tensor: +def calculate_energy(state: ts.state.SimState, model: ModelInterface) -> torch.Tensor: """Calculate energy - needs both state and model""" return model(state)["energy"] diff --git a/tests/models/test_mattersim.py b/tests/models/test_mattersim.py index 44ca3237..a137ed78 100644 --- a/tests/models/test_mattersim.py +++ b/tests/models/test_mattersim.py @@ -48,7 +48,7 @@ def pretrained_mattersim_model(device: torch.device, model_name: str): @pytest.fixture def mattersim_model( - pretrained_mattersim_model: torch.nn.Module, device: torch.device + pretrained_mattersim_model: Potential, device: torch.device ) -> MatterSimModel: """Create an MatterSimModel wrapper for the pretrained model.""" return MatterSimModel( @@ -66,7 +66,7 @@ def mattersim_calculator( def test_mattersim_initialization( - pretrained_mattersim_model: torch.nn.Module, device: torch.device + pretrained_mattersim_model: Potential, device: torch.device ) -> None: """Test that the MatterSim model initializes correctly.""" model = 
MatterSimModel( diff --git a/tests/models/test_sevennet.py b/tests/models/test_sevennet.py index a17f8558..25bd310a 100644 --- a/tests/models/test_sevennet.py +++ b/tests/models/test_sevennet.py @@ -11,6 +11,7 @@ try: import sevenn.util from sevenn.calculator import SevenNetCalculator + from sevenn.nn.sequential import AtomGraphSequential from torch_sim.models.sevennet import SevenNetModel @@ -50,7 +51,7 @@ def pretrained_sevenn_model(device: torch.device, model_name: str): @pytest.fixture def sevenn_model( - pretrained_sevenn_model: torch.nn.Module, device: torch.device, modal_name: str + pretrained_sevenn_model: AtomGraphSequential, device: torch.device, modal_name: str ) -> SevenNetModel: """Create an SevenNetModel wrapper for the pretrained model.""" return SevenNetModel( @@ -69,7 +70,7 @@ def sevenn_calculator( def test_sevennet_initialization( - pretrained_sevenn_model: torch.nn.Module, device: torch.device + pretrained_sevenn_model: AtomGraphSequential, device: torch.device ) -> None: """Test that the SevenNet model initializes correctly.""" model = SevenNetModel( diff --git a/tests/test_integrators.py b/tests/test_integrators.py index ac7bf4b8..d5b210d1 100644 --- a/tests/test_integrators.py +++ b/tests/test_integrators.py @@ -109,7 +109,9 @@ def test_npt_langevin(ar_double_sim_state: ts.SimState, lj_model: LennardJonesMo state = update_fn(state=state) # Calculate instantaneous temperature from kinetic energy - temp = calc_kT(state.momenta, state.masses, system_idx=state.system_idx) + temp = calc_kT( + masses=state.masses, momenta=state.momenta, system_idx=state.system_idx + ) energies.append(state.energy) temperatures.append(temp / MetalUnits.temperature) @@ -172,7 +174,9 @@ def test_npt_langevin_multi_kt( state = update_fn(state=state) # Calculate instantaneous temperature from kinetic energy - temp = calc_kT(state.momenta, state.masses, system_idx=state.system_idx) + temp = calc_kT( + masses=state.masses, momenta=state.momenta, 
system_idx=state.system_idx + ) energies.append(state.energy) temperatures.append(temp / MetalUnits.temperature) @@ -213,7 +217,9 @@ def test_nvt_langevin(ar_double_sim_state: ts.SimState, lj_model: LennardJonesMo state = update_fn(state=state) # Calculate instantaneous temperature from kinetic energy - temp = calc_kT(state.momenta, state.masses, system_idx=state.system_idx) + temp = calc_kT( + masses=state.masses, momenta=state.momenta, system_idx=state.system_idx + ) energies.append(state.energy) temperatures.append(temp / MetalUnits.temperature) @@ -273,7 +279,9 @@ def test_nvt_langevin_multi_kt( state = update_fn(state=state) # Calculate instantaneous temperature from kinetic energy - temp = calc_kT(state.momenta, state.masses, system_idx=state.system_idx) + temp = calc_kT( + masses=state.masses, momenta=state.momenta, system_idx=state.system_idx + ) energies.append(state.energy) temperatures.append(temp / MetalUnits.temperature) diff --git a/tests/test_monte_carlo.py b/tests/test_monte_carlo.py index 3be7787d..479552c0 100644 --- a/tests/test_monte_carlo.py +++ b/tests/test_monte_carlo.py @@ -3,6 +3,7 @@ from pymatgen.core import Structure import torch_sim as ts +from torch_sim.models.interface import ModelInterface from torch_sim.monte_carlo import ( SwapMCState, generate_swaps, @@ -112,7 +113,7 @@ def test_validate_permutation(batched_diverse_state: ts.SimState): def test_monte_carlo( batched_diverse_state: ts.SimState, - lj_model: torch.nn.Module, + lj_model: ModelInterface, ): """Test the monte_carlo function that returns a step function and initial state.""" # Call monte_carlo to get the initial state and step function diff --git a/tests/test_optimizers.py b/tests/test_optimizers.py index a5bfa675..141bf5ee 100644 --- a/tests/test_optimizers.py +++ b/tests/test_optimizers.py @@ -6,6 +6,7 @@ import torch import torch_sim as ts +from torch_sim.models.interface import ModelInterface from torch_sim.optimizers import ( FireState, FrechetCellFIREState, @@ -23,7 
+24,7 @@ def test_gradient_descent_optimization( - ar_supercell_sim_state: ts.SimState, lj_model: torch.nn.Module + ar_supercell_sim_state: ts.SimState, lj_model: ModelInterface ) -> None: """Test that the Gradient Descent optimizer actually minimizes energy.""" # Add some random displacement to positions @@ -62,7 +63,7 @@ def test_gradient_descent_optimization( def test_unit_cell_gradient_descent_optimization( - ar_supercell_sim_state: ts.SimState, lj_model: torch.nn.Module + ar_supercell_sim_state: ts.SimState, lj_model: ModelInterface ) -> None: """Test that the Gradient Descent optimizer actually minimizes energy.""" # Add some random displacement to positions @@ -111,7 +112,7 @@ def test_unit_cell_gradient_descent_optimization( @pytest.mark.parametrize("md_flavor", get_args(MdFlavor)) def test_fire_optimization( - ar_supercell_sim_state: ts.SimState, lj_model: torch.nn.Module, md_flavor: MdFlavor + ar_supercell_sim_state: ts.SimState, lj_model: ModelInterface, md_flavor: MdFlavor ) -> None: """Test that the FIRE optimizer actually minimizes energy.""" # Add some random displacement to positions @@ -185,7 +186,7 @@ def test_simple_optimizer_init_with_dict( optimizer_fn: callable, expected_state_type: type, ar_supercell_sim_state: ts.SimState, - lj_model: torch.nn.Module, + lj_model: ModelInterface, ) -> None: """Test simple optimizer init_fn with a ts.SimState dictionary.""" state_dict = { @@ -201,7 +202,7 @@ def test_simple_optimizer_init_with_dict( @pytest.mark.parametrize("optimizer_func", [fire, unit_cell_fire, frechet_cell_fire]) def test_optimizer_invalid_md_flavor( - optimizer_func: callable, lj_model: torch.nn.Module + optimizer_func: callable, lj_model: ModelInterface ) -> None: """Test optimizer with an invalid md_flavor raises ValueError.""" with pytest.raises(ValueError, match="Unknown md_flavor"): @@ -209,7 +210,7 @@ def test_optimizer_invalid_md_flavor( def test_fire_ase_negative_power_branch( - ar_supercell_sim_state: ts.SimState, lj_model: 
torch.nn.Module + ar_supercell_sim_state: ts.SimState, lj_model: ModelInterface ) -> None: """Test that the ASE FIRE P<0 branch behaves as expected.""" f_dec = 0.5 # Default from fire optimizer @@ -272,7 +273,7 @@ def test_fire_ase_negative_power_branch( def test_fire_vv_negative_power_branch( - ar_supercell_sim_state: ts.SimState, lj_model: torch.nn.Module + ar_supercell_sim_state: ts.SimState, lj_model: ModelInterface ) -> None: """Attempt to trigger and test the VV FIRE P<0 branch.""" f_dec = 0.5 @@ -325,7 +326,7 @@ def test_fire_vv_negative_power_branch( @pytest.mark.parametrize("md_flavor", get_args(MdFlavor)) def test_unit_cell_fire_optimization( - ar_supercell_sim_state: ts.SimState, lj_model: torch.nn.Module, md_flavor: MdFlavor + ar_supercell_sim_state: ts.SimState, lj_model: ModelInterface, md_flavor: MdFlavor ) -> None: """Test that the Unit Cell FIRE optimizer actually minimizes energy.""" @@ -414,7 +415,7 @@ def test_cell_optimizer_init_with_dict_and_cell_factor( expected_state_type: type, cell_factor_val: float, ar_supercell_sim_state: ts.SimState, - lj_model: torch.nn.Module, + lj_model: ModelInterface, ) -> None: """Test cell optimizer init_fn with dict state and explicit cell_factor.""" state_dict = { @@ -448,7 +449,7 @@ def test_cell_optimizer_init_cell_factor_none( optimizer_fn: callable, expected_state_type: type, ar_supercell_sim_state: ts.SimState, - lj_model: torch.nn.Module, + lj_model: ModelInterface, ) -> None: """Test cell optimizer init_fn with cell_factor=None.""" init_fn, _ = optimizer_fn(model=lj_model, cell_factor=None) @@ -467,7 +468,7 @@ def test_cell_optimizer_init_cell_factor_none( @pytest.mark.filterwarnings("ignore:WARNING: Non-positive volume detected") def test_unit_cell_fire_ase_non_positive_volume_warning( ar_supercell_sim_state: ts.SimState, - lj_model: torch.nn.Module, + lj_model: ModelInterface, capsys: pytest.CaptureFixture, ) -> None: """Attempt to trigger non-positive volume warning in unit_cell_fire ASE.""" @@ -503,7 
+504,7 @@ def test_unit_cell_fire_ase_non_positive_volume_warning( @pytest.mark.parametrize("md_flavor", get_args(MdFlavor)) def test_frechet_cell_fire_optimization( - ar_supercell_sim_state: ts.SimState, lj_model: torch.nn.Module, md_flavor: MdFlavor + ar_supercell_sim_state: ts.SimState, lj_model: ModelInterface, md_flavor: MdFlavor ) -> None: """Test that the Frechet Cell FIRE optimizer actually minimizes energy for different md_flavors.""" @@ -592,7 +593,7 @@ def test_frechet_cell_fire_optimization( def test_optimizer_batch_consistency( optimizer_func: callable, ar_supercell_sim_state: ts.SimState, - lj_model: torch.nn.Module, + lj_model: ModelInterface, ) -> None: """Test batched optimizer is consistent with individual optimizations.""" generator = torch.Generator(device=ar_supercell_sim_state.device) @@ -707,7 +708,7 @@ def energy_converged(current_e: torch.Tensor, prev_e: torch.Tensor) -> bool: def test_unit_cell_fire_multi_batch( - ar_supercell_sim_state: ts.SimState, lj_model: torch.nn.Module + ar_supercell_sim_state: ts.SimState, lj_model: ModelInterface ) -> None: """Test FIRE optimization with multiple batches.""" # Create a multi-batch system by duplicating ar_fcc_state @@ -783,7 +784,7 @@ def test_unit_cell_fire_multi_batch( def test_fire_fixed_cell_unit_cell_consistency( # noqa: C901 - ar_supercell_sim_state: ts.SimState, lj_model: torch.nn.Module + ar_supercell_sim_state: ts.SimState, lj_model: ModelInterface ) -> None: """Test batched Frechet Fixed cell FIRE optimization is consistent with FIRE (position only) optimizations.""" diff --git a/tests/test_quantities.py b/tests/test_quantities.py new file mode 100644 index 00000000..7513b6bd --- /dev/null +++ b/tests/test_quantities.py @@ -0,0 +1,136 @@ +import pytest +import torch +from torch._tensor import Tensor + +from torch_sim import quantities +from torch_sim.units import MetalUnits + + +DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu") +DTYPE = torch.double + + 
+@pytest.fixture +def single_system_data() -> dict[str, Tensor]: + masses = torch.tensor([1.0, 2.0], device=DEVICE, dtype=DTYPE) + velocities = torch.tensor( + [[1.0, 1.0, 1.0], [2.0, 2.0, 2.0]], device=DEVICE, dtype=DTYPE + ) + momenta = velocities * masses.unsqueeze(-1) + return { + "masses": masses, + "velocities": velocities, + "momenta": momenta, + "ke": torch.tensor(13.5, device=DEVICE, dtype=DTYPE), + "kt": torch.tensor(4.5, device=DEVICE, dtype=DTYPE), + } + + +@pytest.fixture +def batched_system_data() -> dict[str, Tensor]: + masses = torch.tensor([1.0, 1.0, 2.0, 2.0], device=DEVICE, dtype=DTYPE) + velocities = torch.tensor( + [[1, 1, 1], [1, 1, 1], [2, 2, 2], [2, 2, 2]], device=DEVICE, dtype=DTYPE + ) + momenta = velocities * masses.unsqueeze(-1) + system_idx = torch.tensor([0, 0, 1, 1], device=DEVICE) + return { + "masses": masses, + "velocities": velocities, + "momenta": momenta, + "system_idx": system_idx, + "ke": torch.tensor([3.0, 24.0], device=DEVICE, dtype=DTYPE), + "kt": torch.tensor([1.0, 8.0], device=DEVICE, dtype=DTYPE), + } + + +def test_calc_kinetic_energy_single_system(single_system_data: dict[str, Tensor]) -> None: + # With velocities + ke_vel = quantities.calc_kinetic_energy( + masses=single_system_data["masses"], + velocities=single_system_data["velocities"], + ) + assert torch.allclose(ke_vel, single_system_data["ke"]) + + # With momenta + ke_mom = quantities.calc_kinetic_energy( + masses=single_system_data["masses"], momenta=single_system_data["momenta"] + ) + assert torch.allclose(ke_mom, single_system_data["ke"]) + + +def test_calc_kinetic_energy_batched_system( + batched_system_data: dict[str, Tensor], +) -> None: + # With velocities + ke_vel = quantities.calc_kinetic_energy( + masses=batched_system_data["masses"], + velocities=batched_system_data["velocities"], + system_idx=batched_system_data["system_idx"], + ) + assert torch.allclose(ke_vel, batched_system_data["ke"]) + + # With momenta + ke_mom = quantities.calc_kinetic_energy( + 
masses=batched_system_data["masses"], + momenta=batched_system_data["momenta"], + system_idx=batched_system_data["system_idx"], + ) + assert torch.allclose(ke_mom, batched_system_data["ke"]) + + +def test_calc_kinetic_energy_errors(single_system_data: dict[str, Tensor]) -> None: + with pytest.raises(ValueError, match="Must pass either one of momenta or velocities"): + quantities.calc_kinetic_energy( + masses=single_system_data["masses"], + momenta=single_system_data["momenta"], + velocities=single_system_data["velocities"], + ) + + with pytest.raises(ValueError, match="Must pass either one of momenta or velocities"): + quantities.calc_kinetic_energy(masses=single_system_data["masses"]) + + +def test_calc_kt_single_system(single_system_data: dict[str, Tensor]) -> None: + # With velocities + kt_vel = quantities.calc_kT( + masses=single_system_data["masses"], + velocities=single_system_data["velocities"], + ) + assert torch.allclose(kt_vel, single_system_data["kt"]) + + # With momenta + kt_mom = quantities.calc_kT( + masses=single_system_data["masses"], momenta=single_system_data["momenta"] + ) + assert torch.allclose(kt_mom, single_system_data["kt"]) + + +def test_calc_kt_batched_system(batched_system_data: dict[str, Tensor]) -> None: + # With velocities + kt_vel = quantities.calc_kT( + masses=batched_system_data["masses"], + velocities=batched_system_data["velocities"], + system_idx=batched_system_data["system_idx"], + ) + assert torch.allclose(kt_vel, batched_system_data["kt"]) + + # With momenta + kt_mom = quantities.calc_kT( + masses=batched_system_data["masses"], + momenta=batched_system_data["momenta"], + system_idx=batched_system_data["system_idx"], + ) + assert torch.allclose(kt_mom, batched_system_data["kt"]) + + +def test_calc_temperature(single_system_data: dict[str, Tensor]) -> None: + temp = quantities.calc_temperature( + masses=single_system_data["masses"], + velocities=single_system_data["velocities"], + ) + kt = quantities.calc_kT( + 
masses=single_system_data["masses"], + velocities=single_system_data["velocities"], + ) + assert torch.allclose(temp, kt / MetalUnits.temperature) diff --git a/tests/test_runners.py b/tests/test_runners.py index cd1ff3db..5c9862d0 100644 --- a/tests/test_runners.py +++ b/tests/test_runners.py @@ -23,7 +23,11 @@ def test_integrate_nve( filenames=traj_file, state_frequency=1, prop_calculators={ - 1: {"ke": lambda state: calc_kinetic_energy(state.momenta, state.masses)} + 1: { + "ke": lambda state: calc_kinetic_energy( + momenta=state.momenta, masses=state.masses + ) + } }, ) @@ -56,7 +60,11 @@ def test_integrate_single_nvt( filenames=traj_file, state_frequency=1, prop_calculators={ - 1: {"ke": lambda state: calc_kinetic_energy(state.momenta, state.masses)} + 1: { + "ke": lambda state: calc_kinetic_energy( + momenta=state.momenta, masses=state.masses + ) + } }, ) @@ -108,7 +116,11 @@ def test_integrate_double_nvt_with_reporter( filenames=trajectory_files, state_frequency=1, prop_calculators={ - 1: {"ke": lambda state: calc_kinetic_energy(state.momenta, state.masses)} + 1: { + "ke": lambda state: calc_kinetic_energy( + momenta=state.momenta, masses=state.masses + ) + } }, ) @@ -155,7 +167,11 @@ def test_integrate_many_nvt( filenames=trajectory_files, state_frequency=1, prop_calculators={ - 1: {"ke": lambda state: calc_kinetic_energy(state.momenta, state.masses)} + 1: { + "ke": lambda state: calc_kinetic_energy( + momenta=state.momenta, masses=state.masses + ) + } }, ) @@ -346,7 +362,11 @@ def test_batched_optimize_fire( filenames=trajectory_files, state_frequency=1, prop_calculators={ - 1: {"ke": lambda state: calc_kinetic_energy(state.momenta, state.masses)} + 1: { + "ke": lambda state: calc_kinetic_energy( + velocities=state.velocities, masses=state.masses + ) + } }, ) diff --git a/tests/test_state.py b/tests/test_state.py index ea57dd3a..af0bda7b 100644 --- a/tests/test_state.py +++ b/tests/test_state.py @@ -500,8 +500,8 @@ class DeformState(SimState, 
DeformGradMixin): def __init__( self, *args, - velocities: torch.Tensor | None = None, - reference_cell: torch.Tensor | None = None, + velocities: torch.Tensor, + reference_cell: torch.Tensor, **kwargs, ) -> None: super().__init__(*args, **kwargs) @@ -530,14 +530,6 @@ def deform_grad_state(device: torch.device) -> DeformState: ) -def test_deform_grad_momenta(deform_grad_state: DeformState) -> None: - """Test momenta calculation in DeformGradMixin.""" - expected_momenta = deform_grad_state.velocities * deform_grad_state.masses.unsqueeze( - -1 - ) - assert torch.allclose(deform_grad_state.momenta, expected_momenta) - - def test_deform_grad_reference_cell(deform_grad_state: DeformState) -> None: """Test reference cell getter/setter in DeformGradMixin.""" original_ref_cell = deform_grad_state.reference_cell.clone() diff --git a/tests/test_trajectory.py b/tests/test_trajectory.py index e003f7a7..67d9de1f 100644 --- a/tests/test_trajectory.py +++ b/tests/test_trajectory.py @@ -9,6 +9,7 @@ import torch_sim as ts from torch_sim.integrators import MDState +from torch_sim.models.interface import ModelInterface from torch_sim.models.lennard_jones import LennardJonesModel from torch_sim.trajectory import TorchSimTrajectory, TrajectoryReporter @@ -748,7 +749,7 @@ def test_reporter_with_model( """Test TrajectoryReporter with a model argument in property calculators.""" # Create a property calculator that uses the model - def energy_calculator(state: ts.SimState, model: torch.nn.Module) -> torch.Tensor: + def energy_calculator(state: ts.SimState, model: ModelInterface) -> torch.Tensor: output = model(state) # Calculate a property that depends on the model return output["energy"] diff --git a/tests/workflows/test_a2c.py b/tests/workflows/test_a2c.py index aaed7317..a95ce0e0 100644 --- a/tests/workflows/test_a2c.py +++ b/tests/workflows/test_a2c.py @@ -1,10 +1,12 @@ +from typing import cast + import pytest import torch from pymatgen.core.composition import Composition import 
torch_sim as ts from torch_sim.models.soft_sphere import SoftSphereModel -from torch_sim.optimizers import UnitCellFireState +from torch_sim.optimizers import FireState, UnitCellFireState from torch_sim.workflows import a2c @@ -155,6 +157,7 @@ def test_random_packed_structure_auto_diameter(device: torch.device) -> None: max_iter=3, device=device, ) + state = cast("FireState", state) # Just check that it ran without errors assert state.positions is not None diff --git a/torch_sim/elastic.py b/torch_sim/elastic.py index 55bb27a9..a067e03d 100644 --- a/torch_sim/elastic.py +++ b/torch_sim/elastic.py @@ -24,6 +24,7 @@ import torch +from torch_sim.models.interface import ModelInterface from torch_sim.state import SimState from torch_sim.typing import BravaisType @@ -1105,7 +1106,7 @@ def get_elastic_tensor_from_coeffs( # noqa: C901, PLR0915 def calculate_elastic_tensor( - model: torch.nn.Module, + model: ModelInterface, *, state: SimState, bravais_type: BravaisType = BravaisType.TRICLINIC, diff --git a/torch_sim/integrators/md.py b/torch_sim/integrators/md.py index ce15877d..069f62eb 100644 --- a/torch_sim/integrators/md.py +++ b/torch_sim/integrators/md.py @@ -6,6 +6,7 @@ import torch from torch_sim import transforms +from torch_sim.models.interface import ModelInterface from torch_sim.state import SimState @@ -156,7 +157,7 @@ def position_step(state: MDState, dt: torch.Tensor) -> MDState: return state -def velocity_verlet(state: MDState, dt: torch.Tensor, model: torch.nn.Module) -> MDState: +def velocity_verlet(state: MDState, dt: torch.Tensor, model: ModelInterface) -> MDState: """Perform one complete velocity Verlet integration step. 
This function implements the velocity Verlet algorithm, which provides diff --git a/torch_sim/integrators/npt.py b/torch_sim/integrators/npt.py index 27b78c8b..e1ac1d39 100644 --- a/torch_sim/integrators/npt.py +++ b/torch_sim/integrators/npt.py @@ -14,6 +14,7 @@ calculate_momenta, construct_nose_hoover_chain, ) +from torch_sim.models.interface import ModelInterface from torch_sim.quantities import calc_kinetic_energy from torch_sim.state import SimState from torch_sim.typing import StateDict @@ -140,7 +141,7 @@ def _compute_cell_force( def npt_langevin( # noqa: C901, PLR0915 - model: torch.nn.Module, + model: ModelInterface, *, dt: torch.Tensor, kT: torch.Tensor, @@ -162,7 +163,7 @@ def npt_langevin( # noqa: C901, PLR0915 maintain constant temperature. Args: - model (torch.nn.Module): Neural network model that computes energies, forces, + model (ModelInterface): Neural network model that computes energies, forces, and stress. Must return a dict with 'energy', 'forces', and 'stress' keys. dt (torch.Tensor): Integration timestep, either scalar or shape [n_systems] kT (torch.Tensor): Target temperature in energy units, either scalar or @@ -898,7 +899,7 @@ def current_cell(self) -> torch.Tensor: def npt_nose_hoover( # noqa: C901, PLR0915 *, - model: torch.nn.Module, + model: ModelInterface, kT: torch.Tensor, external_pressure: torch.Tensor, dt: torch.Tensor, @@ -915,7 +916,7 @@ def npt_nose_hoover( # noqa: C901, PLR0915 with Nose-Hoover chain thermostats for temperature and pressure control. 
Args: - model (torch.nn.Module): Model to compute forces and energies + model (ModelInterface): Model to compute forces and energies kT (torch.Tensor): Target temperature in energy units external_pressure (torch.Tensor): Target external pressure dt (torch.Tensor): Integration timestep @@ -1221,7 +1222,9 @@ def compute_cell_force( if system_mask.any(): system_momenta = momenta[system_mask] system_masses = masses[system_mask] - KE_per_system[b] = calc_kinetic_energy(system_momenta, system_masses) + KE_per_system[b] = calc_kinetic_energy( + masses=system_masses, momenta=system_momenta + ) # Get stress tensor and compute trace per system # Handle stress tensor with batch dimension @@ -1430,7 +1433,7 @@ def npt_nose_hoover_init( cell_mass = cell_mass.to(device=device, dtype=dtype) # Calculate cell kinetic energy (using first system for initialization) - KE_cell = calc_kinetic_energy(cell_momentum[:1], cell_mass[:1]) + KE_cell = calc_kinetic_energy(masses=cell_mass[:1], momenta=cell_momentum[:1]) # Ensure reference_cell has proper system dimensions if state.cell.ndim == 2: @@ -1485,7 +1488,9 @@ def npt_nose_hoover_init( # Initialize thermostat npt_state.momenta = momenta KE = calc_kinetic_energy( - npt_state.momenta, npt_state.masses, system_idx=npt_state.system_idx + momenta=npt_state.momenta, + masses=npt_state.masses, + system_idx=npt_state.system_idx, ) npt_state.thermostat = thermostat_fns.initialize( npt_state.positions.numel(), KE, kT @@ -1542,10 +1547,12 @@ def npt_nose_hoover_update( ) # Update kinetic energies for thermostats - KE = calc_kinetic_energy(state.momenta, state.masses, system_idx=state.system_idx) + KE = calc_kinetic_energy( + masses=state.masses, momenta=state.momenta, system_idx=state.system_idx + ) state.thermostat.kinetic_energy = KE - KE_cell = calc_kinetic_energy(state.cell_momentum, state.cell_mass) + KE_cell = calc_kinetic_energy(masses=state.cell_mass, momenta=state.cell_momentum) state.barostat.kinetic_energy = KE_cell # Second half step 
of thermostat chains @@ -1597,7 +1604,7 @@ def npt_nose_hoover_invariant( # Calculate kinetic energy of particles per system e_kin_per_system = calc_kinetic_energy( - state.momenta, state.masses, system_idx=state.system_idx + masses=state.masses, momenta=state.momenta, system_idx=state.system_idx ) # Calculate degrees of freedom per system diff --git a/torch_sim/integrators/nve.py b/torch_sim/integrators/nve.py index f17e59f6..c7e41390 100644 --- a/torch_sim/integrators/nve.py +++ b/torch_sim/integrators/nve.py @@ -10,12 +10,13 @@ momentum_step, position_step, ) +from torch_sim.models.interface import ModelInterface from torch_sim.state import SimState from torch_sim.typing import StateDict def nve( - model: torch.nn.Module, + model: ModelInterface, *, dt: torch.Tensor, kT: torch.Tensor, diff --git a/torch_sim/integrators/nvt.py b/torch_sim/integrators/nvt.py index 19a811c8..8309afa1 100644 --- a/torch_sim/integrators/nvt.py +++ b/torch_sim/integrators/nvt.py @@ -16,13 +16,14 @@ position_step, velocity_verlet, ) +from torch_sim.models.interface import ModelInterface from torch_sim.quantities import calc_kinetic_energy from torch_sim.state import SimState from torch_sim.typing import StateDict def nvt_langevin( # noqa: C901 - model: torch.nn.Module, + model: ModelInterface, *, dt: torch.Tensor, kT: torch.Tensor, @@ -275,7 +276,7 @@ def velocities(self) -> torch.Tensor: def nvt_nose_hoover( *, - model: torch.nn.Module, + model: ModelInterface, dt: torch.Tensor, kT: torch.Tensor, chain_length: int = 3, @@ -367,7 +368,9 @@ def nvt_nose_hoover_init( ) # Calculate initial kinetic energy per system - KE = calc_kinetic_energy(momenta, state.masses, system_idx=state.system_idx) + KE = calc_kinetic_energy( + masses=state.masses, momenta=momenta, system_idx=state.system_idx + ) # Calculate degrees of freedom per system n_atoms_per_system = torch.bincount(state.system_idx) @@ -433,7 +436,9 @@ def nvt_nose_hoover_update( state = velocity_verlet(state=state, dt=dt, model=model) 
# Update chain kinetic energy per system - KE = calc_kinetic_energy(state.momenta, state.masses, system_idx=state.system_idx) + KE = calc_kinetic_energy( + masses=state.masses, momenta=state.momenta, system_idx=state.system_idx + ) chain.kinetic_energy = KE # Second half-step of chain evolution @@ -477,7 +482,9 @@ def nvt_nose_hoover_invariant( """ # Calculate system energy terms per system e_pot = state.energy - e_kin = calc_kinetic_energy(state.momenta, state.masses, system_idx=state.system_idx) + e_kin = calc_kinetic_energy( + masses=state.masses, momenta=state.momenta, system_idx=state.system_idx + ) # Get system degrees of freedom per system n_atoms_per_system = torch.bincount(state.system_idx) diff --git a/torch_sim/models/fairchem.py b/torch_sim/models/fairchem.py index b911668e..77b1b0ba 100644 --- a/torch_sim/models/fairchem.py +++ b/torch_sim/models/fairchem.py @@ -45,7 +45,7 @@ except ImportError as exc: warnings.warn(f"FairChem import failed: {traceback.format_exc()}", stacklevel=2) - class FairChemModel(torch.nn.Module, ModelInterface): + class FairChemModel(ModelInterface): """FairChem model wrapper for torch_sim. This class is a placeholder for the FairChemModel class. @@ -70,7 +70,7 @@ def __init__(self, err: ImportError = exc, *_args: Any, **_kwargs: Any) -> None: } -class FairChemModel(torch.nn.Module, ModelInterface): +class FairChemModel(ModelInterface): """Computes atomistic energies, forces and stresses using a FairChem model. 
This class wraps a FairChem model to compute energies, forces, and stresses for diff --git a/torch_sim/models/graphpes.py b/torch_sim/models/graphpes.py index 6ce52753..a7b287e0 100644 --- a/torch_sim/models/graphpes.py +++ b/torch_sim/models/graphpes.py @@ -36,7 +36,7 @@ warnings.warn(f"GraphPES import failed: {traceback.format_exc()}", stacklevel=2) PropertyKey = str - class GraphPESWrapper(torch.nn.Module, ModelInterface): # type: ignore[reportRedeclaration] + class GraphPESWrapper(ModelInterface): # type: ignore[reportRedeclaration] """GraphPESModel wrapper for torch_sim. This class is a placeholder for the GraphPESWrapper class. @@ -99,7 +99,7 @@ def state_to_atomic_graph(state: ts.SimState, cutoff: torch.Tensor) -> AtomicGra return to_batch(graphs) -class GraphPESWrapper(torch.nn.Module, ModelInterface): +class GraphPESWrapper(ModelInterface): """Wrapper for GraphPESModel in TorchSim. This class provides a TorchSim wrapper around GraphPESModel instances, diff --git a/torch_sim/models/interface.py b/torch_sim/models/interface.py index a70736ab..27c03277 100644 --- a/torch_sim/models/interface.py +++ b/torch_sim/models/interface.py @@ -27,8 +27,6 @@ def forward(self, positions, cell, batch, atomic_numbers=None, **kwargs): """ from abc import ABC, abstractmethod -from pathlib import Path -from typing import Self import torch @@ -37,7 +35,7 @@ def forward(self, positions, cell, batch, atomic_numbers=None, **kwargs): from torch_sim.typing import MemoryScaling, StateDict -class ModelInterface(ABC): +class ModelInterface(torch.nn.Module, ABC): """Abstract base class for all simulation models in torchsim. This interface provides a common structure for all energy and force models, @@ -71,37 +69,10 @@ class ModelInterface(ABC): ``` """ - @abstractmethod - def __init__( - self, - model: str | Path | torch.nn.Module | None = None, - device: torch.device | None = None, - dtype: torch.dtype = torch.float64, - **kwargs, - ) -> Self: - """Initialize a model implementation. 
- - Implementations must set device, dtype and compute capability flags - to indicate what operations the model supports. Models may optionally - load parameters from a file or existing module. - - Args: - model (str | Path | torch.nn.Module | None): Model specification, which - can be: - - Path to a model checkpoint or model file - - Pre-configured torch.nn.Module - - None for default initialization - Defaults to None. - device (torch.device | None): Device where the model will run. If None, - a default device will be selected. Defaults to None. - dtype (torch.dtype): Data type for model calculations. Defaults to - torch.float64. - **kwargs: Additional model-specific parameters. - - Notes: - All implementing classes must set self._device, self._dtype, - self._compute_stress and self._compute_forces in their __init__ method. - """ + _device: torch.device + _dtype: torch.dtype + _compute_stress: bool + _compute_forces: bool @property def device(self) -> torch.device: diff --git a/torch_sim/models/lennard_jones.py b/torch_sim/models/lennard_jones.py index 3b7a5b81..2a05e2f8 100644 --- a/torch_sim/models/lennard_jones.py +++ b/torch_sim/models/lennard_jones.py @@ -119,7 +119,7 @@ def lennard_jones_pair_force( return torch.where(dr > 0, force, torch.zeros_like(force)) -class LennardJonesModel(torch.nn.Module, ModelInterface): +class LennardJonesModel(ModelInterface): """Lennard-Jones potential energy and force calculator. Implements the Lennard-Jones 12-6 potential for molecular dynamics simulations. diff --git a/torch_sim/models/mace.py b/torch_sim/models/mace.py index cfd34142..5ca2a629 100644 --- a/torch_sim/models/mace.py +++ b/torch_sim/models/mace.py @@ -38,7 +38,7 @@ except ImportError as exc: warnings.warn(f"MACE import failed: {traceback.format_exc()}", stacklevel=2) - class MaceModel(torch.nn.Module, ModelInterface): + class MaceModel(ModelInterface): """MACE model wrapper for torch_sim. This class is a placeholder for the MaceModel class. 
@@ -77,7 +77,7 @@ def to_one_hot( return oh.view(*shape) -class MaceModel(torch.nn.Module, ModelInterface): +class MaceModel(ModelInterface): """Computes energies for multiple systems using a MACE model. This class wraps a MACE model to compute energies, forces, and stresses for diff --git a/torch_sim/models/mattersim.py b/torch_sim/models/mattersim.py index 40ef0cc5..9b2efb23 100644 --- a/torch_sim/models/mattersim.py +++ b/torch_sim/models/mattersim.py @@ -21,7 +21,7 @@ except ImportError as exc: warnings.warn(f"MatterSim import failed: {traceback.format_exc()}", stacklevel=2) - class MatterSimModel(torch.nn.Module, ModelInterface): + class MatterSimModel(ModelInterface): """MatterSim model wrapper for torch_sim. This class is a placeholder for the MatterSimModel class. @@ -39,7 +39,7 @@ def __init__(self, err: ImportError = exc, *_args: Any, **_kwargs: Any) -> None: from torch_sim.typing import StateDict -class MatterSimModel(torch.nn.Module, ModelInterface): +class MatterSimModel(ModelInterface): """Computes atomistic energies, forces and stresses using an MatterSim model. This class wraps an MatterSim model to compute energies, forces, and stresses for diff --git a/torch_sim/models/metatomic.py b/torch_sim/models/metatomic.py index 47fda077..5655ed88 100644 --- a/torch_sim/models/metatomic.py +++ b/torch_sim/models/metatomic.py @@ -36,7 +36,7 @@ except ImportError as exc: warnings.warn(f"Metatomic import failed: {traceback.format_exc()}", stacklevel=2) - class MetatomicModel(torch.nn.Module, ModelInterface): + class MetatomicModel(ModelInterface): """Metatomic model wrapper for torch_sim. This class is a placeholder for the MetatomicModel class. @@ -48,7 +48,7 @@ def __init__(self, err: ImportError = exc, *_args: Any, **_kwargs: Any) -> None: raise err -class MetatomicModel(torch.nn.Module, ModelInterface): +class MetatomicModel(ModelInterface): """Computes energies for a list of systems using a metatomic model. 
This class wraps a metatomic model to compute energies, forces, and stresses for diff --git a/torch_sim/models/morse.py b/torch_sim/models/morse.py index 97564361..702dc41f 100644 --- a/torch_sim/models/morse.py +++ b/torch_sim/models/morse.py @@ -112,7 +112,7 @@ def morse_pair_force( return torch.where(dr > 0, force, torch.zeros_like(force)) -class MorseModel(torch.nn.Module, ModelInterface): +class MorseModel(ModelInterface): """Morse potential energy and force calculator. Implements the Morse potential for molecular dynamics simulations. This model diff --git a/torch_sim/models/orb.py b/torch_sim/models/orb.py index 7b4bffd7..fd65b23f 100644 --- a/torch_sim/models/orb.py +++ b/torch_sim/models/orb.py @@ -39,7 +39,7 @@ except ImportError as exc: warnings.warn(f"Orb import failed: {traceback.format_exc()}", stacklevel=2) - class OrbModel(torch.nn.Module, ModelInterface): + class OrbModel(ModelInterface): """ORB model wrapper for torch_sim. This class is a placeholder for the OrbModel class. @@ -247,7 +247,7 @@ def state_to_atom_graphs( # noqa: PLR0915 ).to(device=device, dtype=output_dtype) -class OrbModel(torch.nn.Module, ModelInterface): +class OrbModel(ModelInterface): """Computes atomistic energies, forces and stresses using an ORB model. This class wraps an ORB model to compute energies, forces, and stresses for diff --git a/torch_sim/models/particle_life.py b/torch_sim/models/particle_life.py index 39c5606c..3a13a333 100644 --- a/torch_sim/models/particle_life.py +++ b/torch_sim/models/particle_life.py @@ -83,7 +83,7 @@ def asymmetric_particle_pair_force_jit( return inner_forces + outer_forces -class ParticleLifeModel(torch.nn.Module, ModelInterface): +class ParticleLifeModel(ModelInterface): """Calculator for asymmetric particle interaction. 
This model implements an asymmetric interaction between particles based on diff --git a/torch_sim/models/sevennet.py b/torch_sim/models/sevennet.py index 6156fc17..cda8e183 100644 --- a/torch_sim/models/sevennet.py +++ b/torch_sim/models/sevennet.py @@ -32,7 +32,7 @@ except ImportError as exc: warnings.warn(f"SevenNet import failed: {traceback.format_exc()}", stacklevel=2) - class SevenNetModel(torch.nn.Module, ModelInterface): + class SevenNetModel(ModelInterface): """SevenNet model wrapper for torch_sim. This class is a placeholder for the SevenNetModel class. @@ -44,7 +44,7 @@ def __init__(self, err: ImportError = exc, *_args: Any, **_kwargs: Any) -> None: raise err -class SevenNetModel(torch.nn.Module, ModelInterface): +class SevenNetModel(ModelInterface): """Computes atomistic energies, forces and stresses using an SevenNet model. This class wraps an SevenNet model to compute energies, forces, and stresses for diff --git a/torch_sim/models/soft_sphere.py b/torch_sim/models/soft_sphere.py index 4397b2eb..75598aaf 100644 --- a/torch_sim/models/soft_sphere.py +++ b/torch_sim/models/soft_sphere.py @@ -130,7 +130,7 @@ def fn(dr: torch.Tensor) -> torch.Tensor: return transforms.safe_mask(mask, fn, dr) -class SoftSphereModel(torch.nn.Module, ModelInterface): +class SoftSphereModel(ModelInterface): """Calculator for soft sphere potential energies and forces. Implements a model for computing properties based on the soft sphere potential, @@ -435,7 +435,7 @@ def forward(self, state: ts.SimState | StateDict) -> dict[str, torch.Tensor]: return results -class SoftSphereMultiModel(torch.nn.Module): +class SoftSphereMultiModel(ModelInterface): """Calculator for systems with multiple particle types. Extends the basic soft sphere model to support multiple particle types with @@ -594,11 +594,11 @@ def __init__( with type 0). 
""" super().__init__() - self.device = device or torch.device("cpu") - self.dtype = dtype + self._device = device or torch.device("cpu") + self._dtype = dtype self.pbc = pbc - self.compute_forces = compute_forces - self.compute_stress = compute_stress + self._compute_forces = compute_forces + self._compute_stress = compute_stress self.per_atom_energies = per_atom_energies self.per_atom_stresses = per_atom_stresses self.use_neighbor_list = use_neighbor_list diff --git a/torch_sim/monte_carlo.py b/torch_sim/monte_carlo.py index 90929068..64aad3c7 100644 --- a/torch_sim/monte_carlo.py +++ b/torch_sim/monte_carlo.py @@ -15,6 +15,7 @@ import torch +from torch_sim.models.interface import ModelInterface from torch_sim.state import SimState @@ -183,7 +184,7 @@ def metropolis_criterion( def swap_monte_carlo( *, - model: torch.nn.Module, + model: ModelInterface, kT: float, seed: int | None = None, ) -> tuple[ diff --git a/torch_sim/optimizers.py b/torch_sim/optimizers.py index 98a83d64..44cb498e 100644 --- a/torch_sim/optimizers.py +++ b/torch_sim/optimizers.py @@ -25,6 +25,7 @@ import torch import torch_sim.math as tsm +from torch_sim.models.interface import ModelInterface from torch_sim.state import DeformGradMixin, SimState from torch_sim.typing import StateDict @@ -57,7 +58,7 @@ class GDState(SimState): def gradient_descent( - model: torch.nn.Module, *, lr: torch.Tensor | float = 0.01 + model: ModelInterface, *, lr: torch.Tensor | float = 0.01 ) -> tuple[Callable[[StateDict | SimState], GDState], Callable[[GDState], GDState]]: """Initialize a batched gradient descent optimization. 
@@ -196,7 +197,7 @@ class UnitCellGDState(GDState, DeformGradMixin): def unit_cell_gradient_descent( # noqa: PLR0915, C901 - model: torch.nn.Module, + model: ModelInterface, *, positions_lr: float = 0.01, cell_lr: float = 0.1, @@ -483,7 +484,7 @@ class FireState(SimState): def fire( - model: torch.nn.Module, + model: ModelInterface, *, dt_max: float = 1.0, dt_start: float = 0.1, @@ -692,7 +693,7 @@ class UnitCellFireState(SimState, DeformGradMixin): def unit_cell_fire( - model: torch.nn.Module, + model: ModelInterface, *, dt_max: float = 1.0, dt_start: float = 0.1, @@ -708,7 +709,7 @@ def unit_cell_fire( max_step: float = 0.2, md_flavor: MdFlavor = ase_fire_key, ) -> tuple[ - UnitCellFireState, + Callable[[SimState | StateDict], UnitCellFireState], Callable[[UnitCellFireState], UnitCellFireState], ]: """Initialize a batched FIRE optimization with unit cell degrees of freedom. @@ -976,7 +977,7 @@ class FrechetCellFIREState(SimState, DeformGradMixin): def frechet_cell_fire( - model: torch.nn.Module, + model: ModelInterface, *, dt_max: float = 1.0, dt_start: float = 0.1, @@ -992,7 +993,7 @@ def frechet_cell_fire( max_step: float = 0.2, md_flavor: MdFlavor = ase_fire_key, ) -> tuple[ - FrechetCellFIREState, + Callable[[SimState | StateDict], FrechetCellFIREState], Callable[[FrechetCellFIREState], FrechetCellFIREState], ]: """Initialize a batched FIRE optimization with Frechet cell parameterization. 
@@ -1204,7 +1205,7 @@ def fire_init( def _vv_fire_step( # noqa: C901, PLR0915 state: FireState | AnyFireCellState, - model: torch.nn.Module, + model: ModelInterface, *, dt_max: torch.Tensor, n_min: torch.Tensor, @@ -1420,7 +1421,7 @@ def _vv_fire_step( # noqa: C901, PLR0915 def _ase_fire_step( # noqa: C901, PLR0915 state: FireState | AnyFireCellState, - model: torch.nn.Module, + model: ModelInterface, *, dt_max: torch.Tensor, n_min: torch.Tensor, diff --git a/torch_sim/quantities.py b/torch_sim/quantities.py index 971b1b54..a1ac0811 100644 --- a/torch_sim/quantities.py +++ b/torch_sim/quantities.py @@ -1,5 +1,7 @@ """Functions for computing physical quantities.""" +from typing import cast + import torch from torch_sim.state import SimState @@ -21,8 +23,9 @@ def count_dof(tensor: torch.Tensor) -> int: # @torch.jit.script def calc_kT( # noqa: N802 - momenta: torch.Tensor, + *, masses: torch.Tensor, + momenta: torch.Tensor | None = None, velocities: torch.Tensor | None = None, system_idx: torch.Tensor | None = None, ) -> torch.Tensor: @@ -38,14 +41,12 @@ def calc_kT( # noqa: N802 Returns: torch.Tensor: Scalar temperature value """ - if momenta is not None and velocities is not None: - raise ValueError("Must pass either momenta or velocities, not both") - - if momenta is None and velocities is None: - raise ValueError("Must pass either momenta or velocities") + if not ((momenta is not None) ^ (velocities is not None)): + raise ValueError("Must pass either one of momenta or velocities") if momenta is None: # If velocity provided, calculate mv^2 + velocities = cast("torch.Tensor", velocities) squared_term = (velocities**2) * masses.unsqueeze(-1) else: # If momentum provided, calculate v^2 = p^2/m^2 @@ -70,11 +71,12 @@ def calc_kT( # noqa: N802 def calc_temperature( - momenta: torch.Tensor, + *, masses: torch.Tensor, + momenta: torch.Tensor | None = None, velocities: torch.Tensor | None = None, system_idx: torch.Tensor | None = None, - units: object = 
MetalUnits.temperature, + units: MetalUnits = MetalUnits.temperature, ) -> torch.Tensor: """Calculate temperature from momenta/velocities and masses. @@ -89,13 +91,17 @@ def calc_temperature( Returns: torch.Tensor: Temperature value in specified units """ - return calc_kT(momenta, masses, velocities, system_idx) / units + kT = calc_kT( + masses=masses, momenta=momenta, velocities=velocities, system_idx=system_idx + ) + return kT / units # @torch.jit.script def calc_kinetic_energy( - momenta: torch.Tensor, + *, masses: torch.Tensor, + momenta: torch.Tensor | None = None, velocities: torch.Tensor | None = None, system_idx: torch.Tensor | None = None, ) -> torch.Tensor: @@ -112,10 +118,8 @@ def calc_kinetic_energy( If system_idx is None: Scalar tensor containing the total kinetic energy If system_idx is provided: Tensor of kinetic energies per system """ - if momenta is not None and velocities is not None: - raise ValueError("Must pass either momenta or velocities, not both") - if momenta is None and velocities is None: - raise ValueError("Must pass either momenta or velocities") + if not ((momenta is not None) ^ (velocities is not None)): + raise ValueError("Must pass either one of momenta or velocities") if momenta is None: # Using velocities squared_term = (velocities**2) * masses.unsqueeze(-1) diff --git a/torch_sim/runners.py b/torch_sim/runners.py index 3c83724c..187cdd89 100644 --- a/torch_sim/runners.py +++ b/torch_sim/runners.py @@ -45,8 +45,12 @@ def _configure_reporter( "potential_energy": lambda state: state.energy, "forces": lambda state: state.forces, "stress": lambda state: state.stress, - "kinetic_energy": lambda state: calc_kinetic_energy(state.momenta, state.masses), - "temperature": lambda state: calc_kT(state.momenta, state.masses), + "kinetic_energy": lambda state: calc_kinetic_energy( + velocities=state.velocities, masses=state.masses + ), + "temperature": lambda state: calc_kT( + velocities=state.velocities, masses=state.masses + ), } 
prop_calculators = { diff --git a/torch_sim/state.py b/torch_sim/state.py index 5240d094..af4db6d1 100644 --- a/torch_sim/state.py +++ b/torch_sim/state.py @@ -8,12 +8,12 @@ import importlib import warnings from dataclasses import dataclass -from typing import TYPE_CHECKING, Literal, Self +from typing import TYPE_CHECKING, Literal, Self, cast import torch import torch_sim as ts -from torch_sim.typing import StateLike +from torch_sim.typing import SimStateVar, StateLike if TYPE_CHECKING: @@ -109,6 +109,7 @@ def __init__( self.pbc = pbc self.atomic_numbers = atomic_numbers + # Validate and process the state after initialization. # data validation and fill system_idx # should make pbc a tensor here # if devices aren't all the same, raise an error, in a clean way @@ -136,13 +137,13 @@ def __init__( self.n_atoms, device=self.device, dtype=torch.int64 ) else: - self.system_idx = system_idx # assert that system indices are unique consecutive integers # TODO(curtis): I feel like this logic is not reliable. # I'll come up with something better later. 
- _, counts = torch.unique_consecutive(self.system_idx, return_counts=True) - if not torch.all(counts == torch.bincount(self.system_idx)): + _, counts = torch.unique_consecutive(system_idx, return_counts=True) + if not torch.all(counts == torch.bincount(system_idx)): raise ValueError("System indices must be unique consecutive integers") + self.system_idx = system_idx if self.cell.ndim != 3 and system_idx is None: self.cell = self.cell.unsqueeze(0) @@ -251,7 +252,9 @@ def n_systems(self) -> int: @property def volume(self) -> torch.Tensor: """Volume of the system.""" - return torch.det(self.cell) if self.pbc else None + if not self.pbc: + raise ValueError("Volume is only defined for periodic systems") + return torch.det(self.cell) @property def column_vector_cell(self) -> torch.Tensor: @@ -361,7 +364,7 @@ def pop(self, system_indices: int | list[int] | slice | torch.Tensor) -> list[Se for attr_name, attr_value in vars(modified_state).items(): setattr(self, attr_name, attr_value) - return popped_states + return cast("list[Self]", popped_states) def to( self, device: torch.device | None = None, dtype: torch.dtype | None = None @@ -401,14 +404,8 @@ def __getitem__(self, system_indices: int | list[int] | slice | torch.Tensor) -> class DeformGradMixin: """Mixin for states that support deformation gradients.""" - @property - def momenta(self) -> torch.Tensor: - """Calculate momenta from velocities and masses. - - Returns: - The momenta of the particles - """ - return self.velocities * self.masses.unsqueeze(-1) + reference_cell: torch.Tensor + row_vector_cell: torch.Tensor @property def reference_row_vector_cell(self) -> torch.Tensor: @@ -483,10 +480,10 @@ def _normalize_system_indices( def state_to_device( - state: SimState, + state: SimStateVar, device: torch.device | None = None, dtype: torch.dtype | None = None, -) -> Self: +) -> SimStateVar: """Convert the SimState to a new device and dtype. 
Creates a new SimState with all tensors moved to the specified device and @@ -692,9 +689,9 @@ def _filter_attrs_by_mask( def _split_state( - state: SimState, + state: SimStateVar, ambiguous_handling: Literal["error", "globalize"] = "error", -) -> list[SimState]: +) -> list[SimStateVar]: """Split a SimState into a list of states, each containing a single system. Divides a multi-system state into individual single-system states, preserving @@ -805,10 +802,10 @@ def _pop_states( def _slice_state( - state: SimState, + state: SimStateVar, system_indices: list[int] | torch.Tensor, ambiguous_handling: Literal["error", "globalize"] = "error", -) -> SimState: +) -> SimStateVar: """Slice a substate from the SimState containing only the specified system indices. Creates a new SimState containing only the specified systems, preserving @@ -968,7 +965,7 @@ def initialize_state( return state_to_device(system, device, dtype) if isinstance(system, list) and all(isinstance(s, SimState) for s in system): - if not all(state.n_systems == 1 for state in system): + if not all(cast("SimState", state).n_systems == 1 for state in system): raise ValueError( "When providing a list of states, to the initialize_state function, " "all states must have n_systems == 1. To fix this, you can split the " diff --git a/torch_sim/trajectory.py b/torch_sim/trajectory.py index 3150064a..fb170754 100644 --- a/torch_sim/trajectory.py +++ b/torch_sim/trajectory.py @@ -38,6 +38,7 @@ import tables import torch +from torch_sim.models.interface import ModelInterface from torch_sim.state import SimState @@ -96,9 +97,9 @@ def __init__( state_frequency: int = 100, *, prop_calculators: dict[int, dict[str, Callable]] | None = None, - state_kwargs: dict | None = None, + state_kwargs: dict[str, Any] | None = None, metadata: dict[str, str] | None = None, - trajectory_kwargs: dict | None = None, + trajectory_kwargs: dict[str, Any] | None = None, ) -> None: """Initialize a TrajectoryReporter. 
@@ -203,7 +204,7 @@ def report( self, state: SimState, step: int, - model: torch.nn.Module | None = None, + model: ModelInterface | None = None, ) -> list[dict[str, torch.Tensor]]: """Report a state and step to the trajectory files. @@ -216,7 +217,7 @@ def report( len(filenames) step (int): Current simulation step, setting step to 0 will write the state and all properties. - model (torch.nn.Module, optional): Model used for simulation. + model (ModelInterface, optional): Model used for simulation. Defaults to None. Must be provided if any prop_calculators are provided. write_to_file (bool, optional): Whether to write the state to the trajectory diff --git a/torch_sim/workflows/a2c.py b/torch_sim/workflows/a2c.py index e3abb8d5..fb5bba0b 100644 --- a/torch_sim/workflows/a2c.py +++ b/torch_sim/workflows/a2c.py @@ -19,6 +19,7 @@ import torch_sim as ts from torch_sim import transforms +from torch_sim.models.interface import ModelInterface from torch_sim.models.soft_sphere import SoftSphereModel, SoftSphereMultiModel from torch_sim.optimizers import FireState, UnitCellFireState, fire from torch_sim.optimizers import unit_cell_fire as batched_unit_cell_fire @@ -228,7 +229,7 @@ def random_packed_structure( device: torch.device | None = None, dtype: torch.dtype = torch.float32, log: Any | None = None, -) -> FireState: +) -> FireState | tuple[FireState, list[np.ndarray]]: """Generates a random packed atomic structure and minimizes atomic overlaps. 
This function creates a random atomic structure within a given cell and optionally @@ -326,6 +327,7 @@ def random_packed_structure( if log is not None: return state, log + return state @@ -575,6 +577,13 @@ def get_subcells_to_crystallize( # Convert species list to numpy array for easier composition handling species_array = np.array(species) + if restrict_to_compositions is not None and restrict_to_compositions: + restrict_to_compositions: set[str] = { + Composition(comp).reduced_formula for comp in restrict_to_compositions + } + else: + restrict_to_compositions: set[str] = set() + # Generate allowed stoichiometries if max_coef is specified if max_coeff: if elements is None: @@ -583,17 +592,9 @@ def get_subcells_to_crystallize( stoichs = list(itertools.product(range(max_coeff + 1), repeat=len(elements))) stoichs.pop(0) # Remove the empty composition (0,0,...) # Convert stoichiometries to composition formulas - comps = [] for stoich in stoichs: comp = dict(zip(elements, stoich, strict=True)) - comps.append(Composition.from_dict(comp).reduced_formula) - restrict_to_compositions = set(comps) - - # Ensure compositions are in reduced formula form if provided - if restrict_to_compositions: - restrict_to_compositions = [ - Composition(comp).reduced_formula for comp in restrict_to_compositions - ] + restrict_to_compositions.add(Composition.from_dict(comp).reduced_formula) # Create orthorhombic grid for systematic subcell generation bins = int(1 / d_frac) @@ -610,7 +611,7 @@ def get_subcells_to_crystallize( .T ) - candidates = [] + candidates: list[tuple[torch.Tensor, torch.Tensor, torch.Tensor]] = [] # Iterate through all possible subcell boundary combinations for lb, ub in itertools.product(l_bound, u_bound): if torch.all(ub > lb): # Ensure valid subcell dimensions @@ -705,9 +706,9 @@ def get_target_temperature( def get_unit_cell_relaxed_structure( state: ts.SimState, - model: torch.nn.Module, + model: ModelInterface, max_iter: int = 200, -) -> tuple[UnitCellFireState, 
dict]: +) -> tuple[UnitCellFireState, dict[str, torch.Tensor], list[float], list[float]]: """Relax both atomic positions and cell parameters using FIRE algorithm. This function performs geometry optimization of both atomic positions and unit cell @@ -751,8 +752,8 @@ def get_unit_cell_relaxed_structure( state = unit_cell_fire_init(state) def step_fn( - step: int, state: UnitCellFireState, logger: dict - ) -> tuple[UnitCellFireState, dict]: + step: int, state: UnitCellFireState, logger: dict[str, torch.Tensor] + ) -> tuple[UnitCellFireState, dict[str, torch.Tensor]]: logger["energy"][step] = state.energy logger["stress"][step] = state.stress state = unit_cell_fire_update(state) @@ -772,61 +773,3 @@ def step_fn( f"Final pressure: {[f'{p:.4f}' for p in final_pressure]} eV/A^3" ) return state, logger, final_energy, final_pressure - - -def get_relaxed_structure( - state: ts.SimState, - model: torch.nn.Module, - max_iter: int = 200, -) -> tuple[FireState, dict]: - """Relax atomic positions at fixed cell parameters using FIRE algorithm. - - Does geometry optimization of atomic positions while keeping the unit cell fixed. - Uses the Fast Inertial Relaxation Engine (FIRE) algorithm to minimize forces on atoms. - - Args: - state: State containing positions, cell and atomic numbers - model: Model to compute energies, forces, and stresses - max_iter: Maximum number of FIRE iterations. Defaults to 200. 
- - Returns: - tuple containing: - - FIREState: Final state containing relaxed positions and other quantities - - dict: Logger with energy trajectory - - float: Final energy in eV - - float: Final pressure in eV/ų - """ - # Get device and dtype from model - device, dtype = model.device, model.dtype - - logger = {"energy": torch.zeros((max_iter, 1), device=device, dtype=dtype)} - - results = model(state) - Initial_energy = results["energy"] - print(f"Initial energy: {Initial_energy.item():.4f} eV") - - state_init_fn, fire_update = fire(model=model) - state = state_init_fn(state) - - def step_fn(idx: int, state: FireState, logger: dict) -> tuple[FireState, dict]: - logger["energy"][idx] = state.energy - state = fire_update(state) - return state, logger - - for idx in range(max_iter): - state, logger = step_fn(idx, state, logger) - - # Get final results - model.compute_stress = True - final_results = model( - positions=state.positions, cell=state.cell, atomic_numbers=state.atomic_numbers - ) - - final_energy = final_results["energy"].item() - final_stress = final_results["stress"] - final_pressure = (torch.trace(final_stress) / 3.0).item() - print( - f"Final energy: {final_energy:.4f} eV, " - f"Final pressure: {final_pressure:.4f} eV/A^3" - ) - return state, logger, final_energy, final_pressure From e90b272b3b6fd00771a774e173fbc356dbffd273 Mon Sep 17 00:00:00 2001 From: Timo Reents <77727843+t-reents@users.noreply.github.com> Date: Fri, 8 Aug 2025 15:58:42 +0200 Subject: [PATCH 10/16] Initial fix for concatenation of states in `InFlightAutoBatcher` (#219) --- tests/test_autobatching.py | 2 +- torch_sim/optimizers.py | 44 +++++++++++++++++++++++++++----------- 2 files changed, 32 insertions(+), 14 deletions(-) diff --git a/tests/test_autobatching.py b/tests/test_autobatching.py index 30544f64..898bead0 100644 --- a/tests/test_autobatching.py +++ b/tests/test_autobatching.py @@ -490,7 +490,7 @@ def convergence_fn(state: ts.SimState) -> bool: break # run 10 steps, 
arbitrary number - for _ in range(10): + for _ in range(5): state = fire_update(state) convergence_tensor = convergence_fn(state) diff --git a/torch_sim/optimizers.py b/torch_sim/optimizers.py index 44cb498e..8626ff03 100644 --- a/torch_sim/optimizers.py +++ b/torch_sim/optimizers.py @@ -590,7 +590,9 @@ def fire_init( atomic_numbers=state.atomic_numbers.clone(), system_idx=state.system_idx.clone(), pbc=state.pbc, - velocities=None, + velocities=torch.full( + state.positions.shape, torch.nan, device=device, dtype=dtype + ), forces=forces, energy=energy, # Optimization attributes @@ -863,13 +865,17 @@ def fire_init( atomic_numbers=state.atomic_numbers.clone(), system_idx=state.system_idx.clone(), pbc=state.pbc, - velocities=None, + velocities=torch.full( + state.positions.shape, torch.nan, device=device, dtype=dtype + ), forces=forces, energy=energy, stress=stress, # Cell attributes cell_positions=torch.zeros(n_systems, 3, 3, device=device, dtype=dtype), - cell_velocities=None, + cell_velocities=torch.full( + cell_forces.shape, torch.nan, device=device, dtype=dtype + ), cell_forces=cell_forces, cell_masses=cell_masses, # Optimization attributes @@ -1162,13 +1168,17 @@ def fire_init( atomic_numbers=state.atomic_numbers, system_idx=state.system_idx, pbc=state.pbc, - velocities=None, + velocities=torch.full( + state.positions.shape, torch.nan, device=device, dtype=dtype + ), forces=forces, energy=energy, stress=stress, # Cell attributes cell_positions=cell_positions, - cell_velocities=None, + cell_velocities=torch.full( + cell_forces.shape, torch.nan, device=device, dtype=dtype + ), cell_forces=cell_forces, cell_masses=cell_masses, # Optimization attributes @@ -1245,15 +1255,19 @@ def _vv_fire_step( # noqa: C901, PLR0915 dtype = state.positions.dtype deform_grad_new: torch.Tensor | None = None - if state.velocities is None: - state.velocities = torch.zeros_like(state.positions) + nan_velocities = state.velocities.isnan().any(dim=1) + if nan_velocities.any(): + 
state.velocities[nan_velocities] = torch.zeros_like( + state.positions[nan_velocities] + ) if is_cell_optimization: if not isinstance(state, AnyFireCellState): raise ValueError( f"Cell optimization requires one of {get_args(AnyFireCellState)}." ) - state.cell_velocities = torch.zeros( - (n_systems, 3, 3), device=device, dtype=dtype + nan_cell_velocities = state.cell_velocities.isnan().any(dim=(1, 2)) + state.cell_velocities[nan_cell_velocities] = torch.zeros_like( + state.cell_positions[nan_cell_velocities] ) alpha_start_system = torch.full( @@ -1462,16 +1476,20 @@ def _ase_fire_step( # noqa: C901, PLR0915 cur_deform_grad = None # Initialize cur_deform_grad to prevent UnboundLocalError - if state.velocities is None: - state.velocities = torch.zeros_like(state.positions) + nan_velocities = state.velocities.isnan().any(dim=1) + if nan_velocities.any(): + state.velocities[nan_velocities] = torch.zeros_like( + state.positions[nan_velocities] + ) forces = state.forces if is_cell_optimization: if not isinstance(state, AnyFireCellState): raise ValueError( f"Cell optimization requires one of {get_args(AnyFireCellState)}." 
) - state.cell_velocities = torch.zeros( - (n_systems, 3, 3), device=device, dtype=dtype + nan_cell_velocities = state.cell_velocities.isnan().any(dim=(1, 2)) + state.cell_velocities[nan_cell_velocities] = torch.zeros_like( + state.cell_positions[nan_cell_velocities] ) cur_deform_grad = state.deform_grad() else: From 71e1d41b38d06abd26a8a9f456fa66006bc8f631 Mon Sep 17 00:00:00 2001 From: Curtis Chong Date: Fri, 8 Aug 2025 10:25:38 -0400 Subject: [PATCH 11/16] Fix simstate concatenation [2/2] (#232) --- tests/test_autobatching.py | 13 +++++++++++-- tests/test_state.py | 13 +++++++++++++ torch_sim/optimizers.py | 4 ++-- torch_sim/runners.py | 4 ++-- torch_sim/state.py | 28 ++++++++++++++++++++++++++++ 5 files changed, 56 insertions(+), 6 deletions(-) diff --git a/tests/test_autobatching.py b/tests/test_autobatching.py index 898bead0..4bce4165 100644 --- a/tests/test_autobatching.py +++ b/tests/test_autobatching.py @@ -448,10 +448,20 @@ def test_in_flight_auto_batcher_restore_order( # batcher.restore_original_order([si_sim_state]) +@pytest.mark.parametrize( + "num_steps_per_batch", + [ + 5, # At 5 steps, not every state will converge before the next batch. + # This tests the merging of partially converged states with new states + # which has been a bug in the past. 
See https://github.com/Radical-AI/torch-sim/pull/219 + 10, # At 10 steps, all states will converge before the next batch + ], +) def test_in_flight_with_fire( si_sim_state: ts.SimState, fe_supercell_sim_state: ts.SimState, lj_model: LennardJonesModel, + num_steps_per_batch: int, ) -> None: fire_init, fire_update = unit_cell_fire(lj_model) @@ -489,8 +499,7 @@ def convergence_fn(state: ts.SimState) -> bool: if state is None: break - # run 10 steps, arbitrary number - for _ in range(5): + for _ in range(num_steps_per_batch): state = fire_update(state) convergence_tensor = convergence_fn(state) diff --git a/tests/test_state.py b/tests/test_state.py index af0bda7b..81109bf3 100644 --- a/tests/test_state.py +++ b/tests/test_state.py @@ -635,3 +635,16 @@ def test_deprecated_batch_properties_equal_to_new_system_properties( state.batch = new_system_idx assert torch.allclose(state.system_idx, new_system_idx) assert torch.allclose(state.batch, new_system_idx) + + +def test_derived_classes_trigger_init_subclass() -> None: + """Test that derived classes cannot have attributes that are "tensors | None".""" + + with pytest.raises(TypeError) as excinfo: + + class DerivedState(SimState): + invalid_attr: torch.Tensor | None = None + + assert "is not allowed to be of type 'torch.Tensor | None' because torch.cat" in str( + excinfo.value + ) diff --git a/torch_sim/optimizers.py b/torch_sim/optimizers.py index 8626ff03..9a41cd8c 100644 --- a/torch_sim/optimizers.py +++ b/torch_sim/optimizers.py @@ -475,7 +475,7 @@ class FireState(SimState): # Required attributes not in SimState forces: torch.Tensor energy: torch.Tensor - velocities: torch.Tensor | None + velocities: torch.Tensor # FIRE algorithm parameters dt: torch.Tensor @@ -972,7 +972,7 @@ class FrechetCellFIREState(SimState, DeformGradMixin): # Cell attributes cell_positions: torch.Tensor - cell_velocities: torch.Tensor | None + cell_velocities: torch.Tensor cell_forces: torch.Tensor cell_masses: torch.Tensor diff --git 
a/torch_sim/runners.py b/torch_sim/runners.py index 187cdd89..8f2917da 100644 --- a/torch_sim/runners.py +++ b/torch_sim/runners.py @@ -538,8 +538,8 @@ def static( @dataclass class StaticState(type(state)): energy: torch.Tensor - forces: torch.Tensor | None - stress: torch.Tensor | None + forces: torch.Tensor + stress: torch.Tensor all_props: list[dict[str, torch.Tensor]] = [] og_filenames = trajectory_reporter.filenames diff --git a/torch_sim/state.py b/torch_sim/state.py index af4db6d1..d2ec8351 100644 --- a/torch_sim/state.py +++ b/torch_sim/state.py @@ -6,6 +6,7 @@ import copy import importlib +import typing import warnings from dataclasses import dataclass from typing import TYPE_CHECKING, Literal, Self, cast @@ -400,6 +401,33 @@ def __getitem__(self, system_indices: int | list[int] | slice | torch.Tensor) -> return _slice_state(self, system_indices) + def __init_subclass__(cls, **kwargs) -> None: + """Enforce that all derived states cannot have tensor attributes that can also be + None. This is because torch.concatenate cannot concat between a tensor and a None. + See https://github.com/Radical-AI/torch-sim/pull/219 for more details. + """ + # We need to use get_type_hints to correctly inspect the types + type_hints = typing.get_type_hints(cls) + for attr_name, attr_typehint in type_hints.items(): + origin = typing.get_origin(attr_typehint) + + is_union = origin is typing.Union + if not is_union and origin is not None: + # For Python 3.10+ `|` syntax, origin is types.UnionType + # We check by name to be robust against module reloading/patching issues + is_union = origin.__module__ == "types" and origin.__name__ == "UnionType" + if is_union: + args = typing.get_args(attr_typehint) + if torch.Tensor in args and type(None) in args: + raise TypeError( + f"Attribute '{attr_name}' in class '{cls.__name__}' is not " + "allowed to be of type 'torch.Tensor | None' because torch.cat " + "cannot concatenate between a tensor and a None. 
Please default " + "the tensor with dummy values and track the 'None' case." + ) + + super().__init_subclass__(**kwargs) + class DeformGradMixin: """Mixin for states that support deformation gradients.""" From fd865ccb3edf8f7d93a81eca6d5ce306852e3716 Mon Sep 17 00:00:00 2001 From: Orion Cohen <27712051+orionarcher@users.noreply.github.com> Date: Mon, 11 Aug 2025 13:15:32 -0400 Subject: [PATCH 12/16] Update readme plot (#236) --- README.md | 9 +- docs/_static/speedup_plot.svg | 181 +++++++++++++++++++++++++++++++++- 2 files changed, 185 insertions(+), 5 deletions(-) mode change 100755 => 100644 docs/_static/speedup_plot.svg diff --git a/README.md b/README.md index 1fe0c16e..b633e084 100644 --- a/README.md +++ b/README.md @@ -92,14 +92,15 @@ print(relaxed_state.energy) TorchSim achieves up to 100x speedup compared to ASE with popular MLIPs. -![Speedup comparison](/docs/_static/speedup_plot.svg) +Speedup comparison This figure compares the time per atom of ASE and `torch_sim`. Time per atom is defined as the number of atoms / total time. While ASE can only run a single system of `n_atoms` (on the $x$ axis), `torch_sim` can run as many systems as will fit in memory. On an H100 80 GB card, -the max atoms that could fit in memory was ~8,000 for [GemNet](https://github.com/FAIR-Chem/fairchem), ~10,000 for [MACE](https://github.com/ACEsuit/mace), and ~2,500 -for [SevenNet](https://github.com/MDIL-SNU/SevenNet). This metric describes model performance by capturing speed and memory -usage simultaneously. +the max atoms that could fit in memory was ~8,000 for [EGIP](https://github.com/FAIR-Chem/fairchem), +~10,000 for [MACE-MPA-0](https://github.com/ACEsuit/mace), ~22,000 for [Mattersim V1 1M](https://github.com/microsoft/mattersim), +~2,500 for [SevenNet](https://github.com/MDIL-SNU/SevenNet), and ~9000 for [PET-MAD](https://github.com/lab-cosmo/pet-mad). +This metric describes model performance by capturing speed and memory usage simultaneously. 
## Installation diff --git a/docs/_static/speedup_plot.svg b/docs/_static/speedup_plot.svg old mode 100755 new mode 100644 index c6e2de2e..e0eca6b4 --- a/docs/_static/speedup_plot.svg +++ b/docs/_static/speedup_plot.svg @@ -1 +1,180 @@ -23.515.85.01.91.41.2101.967.218.19.87.13.924.312.64.02.11.51.116321082565008640x10x20x30x40x50x60x70x80x90x100xGemNetMACESevenNetSpeedup: Batched Integration with TorchSim vs Serial Integration with ASENumber of Atoms in Single SystemSpeedup Factor + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + From a5635236fff12dca6eb23e480d71d28e80a4ff03 Mon Sep 17 00:00:00 2001 From: Orion Cohen <27712051+orionarcher@users.noreply.github.com> Date: Mon, 11 Aug 2025 15:10:07 -0400 Subject: [PATCH 13/16] Fix typo in README (#237) --- docs/_static/speedup_plot.svg | 74 +++++++++++++++++------------------ 1 file changed, 37 insertions(+), 37 deletions(-) diff --git a/docs/_static/speedup_plot.svg b/docs/_static/speedup_plot.svg index e0eca6b4..7098fe5f 100644 --- a/docs/_static/speedup_plot.svg +++ b/docs/_static/speedup_plot.svg @@ -1,24 +1,24 @@ - + - + - + - + - + - + - + - + - + - + @@ -82,17 +82,17 @@ - + - + - + - + - + - + @@ -100,27 +100,27 @@ - + - + - + - + - + - + - + - + - + - + - + @@ -147,34 +147,34 @@ - + - + - + - + - + - + - + - + - + - + From cde91e98d93318a9a0a598701b3f18f03b068d88 Mon Sep 17 00:00:00 2001 From: Curtis Chong Date: Wed, 13 Aug 2025 16:16:20 -0400 Subject: [PATCH 14/16] Define attribute scopes in SimStates (#228) Co-authored-by: Rhys Goodall Co-authored-by: Orion Cohen <27712051+orionarcher@users.noreply.github.com> --- .../scripts/3_Dynamics/3.10_Hybrid_swap_mc.py | 3 + 
examples/tutorials/hybrid_swap_tutorial.py | 5 + examples/tutorials/state_tutorial.py | 27 +- tests/test_state.py | 71 +++-- torch_sim/integrators/md.py | 7 + torch_sim/integrators/npt.py | 31 ++ torch_sim/integrators/nvt.py | 4 + torch_sim/monte_carlo.py | 3 + torch_sim/optimizers.py | 75 ++++- torch_sim/runners.py | 15 +- torch_sim/state.py | 293 ++++++++---------- 11 files changed, 319 insertions(+), 215 deletions(-) diff --git a/examples/scripts/3_Dynamics/3.10_Hybrid_swap_mc.py b/examples/scripts/3_Dynamics/3.10_Hybrid_swap_mc.py index 6fcb50e2..fb3f7983 100644 --- a/examples/scripts/3_Dynamics/3.10_Hybrid_swap_mc.py +++ b/examples/scripts/3_Dynamics/3.10_Hybrid_swap_mc.py @@ -76,6 +76,9 @@ class HybridSwapMCState(MDState): """ last_permutation: torch.Tensor + _atom_attributes = ( + MDState._atom_attributes | {"last_permutation"} # noqa: SLF001 + ) nvt_init, nvt_step = nvt_langevin(model=model, dt=0.002, kT=kT) diff --git a/examples/tutorials/hybrid_swap_tutorial.py b/examples/tutorials/hybrid_swap_tutorial.py index d8adbbf8..08c3914b 100644 --- a/examples/tutorials/hybrid_swap_tutorial.py +++ b/examples/tutorials/hybrid_swap_tutorial.py @@ -34,9 +34,11 @@ """ # %% +from typing import ClassVar import torch import torch_sim as ts from mace.calculators.foundations_models import mace_mp +from torch_sim.integrators.md import MDState from torch_sim.models.mace import MaceModel # Initialize the mace model @@ -104,6 +106,9 @@ class HybridSwapMCState(ts.integrators.MDState): """ last_permutation: torch.Tensor + _atom_attributes = ( + MDState._atom_attributes | {"last_permutation"} # noqa: SLF001 + ) # %% [markdown] diff --git a/examples/tutorials/state_tutorial.py b/examples/tutorials/state_tutorial.py index 0d3eca96..5cd43b17 100644 --- a/examples/tutorials/state_tutorial.py +++ b/examples/tutorials/state_tutorial.py @@ -79,16 +79,28 @@ the base SimState. Names are singular. * Global attributes have any other shape or type, just `pbc` here. Names are singular. 
-You can use the `infer_property_scope` function to analyze a state's properties. This +For TorchSim to know which attributes are atomwise, systemwise, and global, each attribute's +name is explicitly defined in the `_atom_attributes`, `_system_attributes`, and `_global_attributes`: + +_atom_attributes = {"positions", "masses", "atomic_numbers", "system_idx"} +_system_attributes = {"cell"} +_global_attributes = {"pbc"} + +You can use the `get_attrs_for_scope` generator function to analyze a state's properties. This is mostly used internally but can be useful for debugging. """ # %% -from torch_sim.state import infer_property_scope +from torch_sim.state import get_attrs_for_scope -scope = infer_property_scope(si_state) -print(scope) +# loop through each attribute: +for attr_name, attr_value in get_attrs_for_scope(si_state, "per-atom"): + print(f"per-atom attribute: {attr_name}") + print(f"value: {attr_value}") +# or access the attributes via a dict: +print("Per-system attributes:", dict(get_attrs_for_scope(si_state, "per-system"))) # noqa: E501 +print("Global attributes:", dict(get_attrs_for_scope(si_state, "global"))) # %% [markdown] """ @@ -257,10 +269,9 @@ ) print("MDState properties:") -scope = infer_property_scope(md_state) -print("Global properties:", scope["global"]) -print("Per-atom properties:", scope["per_atom"]) -print("Per-system properties:", scope["per_system"]) +print("Per-atom attributes:", dict(get_attrs_for_scope(si_state, "per-atom"))) +print("Per-system attributes:", dict(get_attrs_for_scope(si_state, "per-system"))) +print("Global attributes:", dict(get_attrs_for_scope(si_state, "global"))) # %% [markdown] diff --git a/tests/test_state.py b/tests/test_state.py index 81109bf3..67a757fe 100644 --- a/tests/test_state.py +++ b/tests/test_state.py @@ -13,7 +13,7 @@ _pop_states, _slice_state, concatenate_states, - infer_property_scope, + get_attrs_for_scope, initialize_state, ) @@ -24,38 +24,52 @@ from pymatgen.core import Structure -def 
test_infer_sim_state_property_scope(si_sim_state: ts.SimState) -> None: - """Test inference of property scope.""" - scope = infer_property_scope(si_sim_state) - assert set(scope["global"]) == {"pbc"} - assert set(scope["per_atom"]) == { +def test_get_attrs_for_scope(si_sim_state: ts.SimState) -> None: + """Test getting attributes for a scope.""" + per_atom_attrs = dict(get_attrs_for_scope(si_sim_state, "per-atom")) + assert set(per_atom_attrs.keys()) == { "positions", "masses", "atomic_numbers", "system_idx", } - assert set(scope["per_system"]) == {"cell"} + per_system_attrs = dict(get_attrs_for_scope(si_sim_state, "per-system")) + assert set(per_system_attrs.keys()) == {"cell"} + global_attrs = dict(get_attrs_for_scope(si_sim_state, "global")) + assert set(global_attrs.keys()) == {"pbc"} -def test_infer_md_state_property_scope(si_sim_state: ts.SimState) -> None: - """Test inference of property scope.""" - state = MDState( - **asdict(si_sim_state), - momenta=torch.randn_like(si_sim_state.positions), - forces=torch.randn_like(si_sim_state.positions), - energy=torch.zeros((1,)), - ) - scope = infer_property_scope(state) - assert set(scope["global"]) == {"pbc"} - assert set(scope["per_atom"]) == { - "positions", - "masses", - "atomic_numbers", - "system_idx", - "forces", - "momenta", - } - assert set(scope["per_system"]) == {"cell", "energy"} +def test_all_attributes_must_be_specified_in_scopes() -> None: + """Test that an error is raised when we forget to specify the scope + for an attribute in a child SimState class.""" + with pytest.raises(TypeError) as excinfo: + + class ChildState(SimState): + attribute_specified_in_scopes: bool + attribute_not_specified_in_scopes: bool + + _atom_attributes = ( + SimState._atom_attributes | {"attribute_specified_in_scopes"} # noqa: SLF001 + ) + + assert "attribute_not_specified_in_scopes" in str(excinfo.value) + assert "attribute_specified_in_scopes" not in str(excinfo.value) + + +def test_no_duplicate_attributes_in_scopes() -> 
None: + """Test that no attributes are specified in multiple scopes.""" + + # Capture the exception information using "as excinfo" + with pytest.raises(TypeError) as excinfo: + + class ChildState(SimState): + duplicated_attribute: bool + + _system_attributes = SimState._system_attributes | {"duplicated_attribute"} # noqa: SLF001 + _global_attributes = SimState._global_attributes | {"duplicated_attribute"} # noqa: SLF001 + + assert "are declared multiple times" in str(excinfo.value) + assert "duplicated_attribute" in str(excinfo.value) def test_slice_substate( @@ -497,6 +511,11 @@ def test_column_vector_cell(si_sim_state: ts.SimState) -> None: class DeformState(SimState, DeformGradMixin): """Test class that combines SimState with DeformGradMixin.""" + _system_attributes = ( + SimState._system_attributes # noqa: SLF001 + | DeformGradMixin._system_attributes # noqa: SLF001 + ) + def __init__( self, *args, diff --git a/torch_sim/integrators/md.py b/torch_sim/integrators/md.py index 069f62eb..490e3528 100644 --- a/torch_sim/integrators/md.py +++ b/torch_sim/integrators/md.py @@ -41,6 +41,13 @@ class MDState(SimState): energy: torch.Tensor forces: torch.Tensor + _atom_attributes = ( + SimState._atom_attributes | {"momenta", "forces"} # noqa: SLF001 + ) + _system_attributes = ( + SimState._system_attributes | {"energy"} # noqa: SLF001 + ) + @property def velocities(self) -> torch.Tensor: """Velocities calculated from momenta and masses with shape diff --git a/torch_sim/integrators/npt.py b/torch_sim/integrators/npt.py index e1ac1d39..f5f23c5c 100644 --- a/torch_sim/integrators/npt.py +++ b/torch_sim/integrators/npt.py @@ -67,6 +67,18 @@ class NPTLangevinState(SimState): cell_velocities: torch.Tensor cell_masses: torch.Tensor + _atom_attributes = ( + SimState._atom_attributes | {"forces", "velocities"} # noqa: SLF001 + ) + _system_attributes = SimState._system_attributes | { # noqa: SLF001 + "stress", + "cell_positions", + "cell_velocities", + "cell_masses", + 
"reference_cell", + "energy", + } + @property def momenta(self) -> torch.Tensor: """Calculate momenta from velocities and masses.""" @@ -867,6 +879,25 @@ class NPTNoseHooverState(MDState): barostat: NoseHooverChain barostat_fns: NoseHooverChainFns + _system_attributes = ( + MDState._system_attributes # noqa: SLF001 + | { + "reference_cell", + "cell_position", + "cell_momentum", + "cell_mass", + } + ) + _global_attributes = ( + MDState._global_attributes # noqa: SLF001 + | { + "thermostat", + "barostat", + "thermostat_fns", + "barostat_fns", + } + ) + @property def velocities(self) -> torch.Tensor: """Calculate particle velocities from momenta and masses. diff --git a/torch_sim/integrators/nvt.py b/torch_sim/integrators/nvt.py index 8309afa1..18f0ae15 100644 --- a/torch_sim/integrators/nvt.py +++ b/torch_sim/integrators/nvt.py @@ -266,6 +266,10 @@ class NVTNoseHooverState(MDState): chain: NoseHooverChain _chain_fns: NoseHooverChainFns + _global_attributes = ( + MDState._global_attributes | {"chain", "_chain_fns"} # noqa: SLF001 + ) + @property def velocities(self) -> torch.Tensor: """Velocities calculated from momenta and masses with shape diff --git a/torch_sim/monte_carlo.py b/torch_sim/monte_carlo.py index 64aad3c7..be2d99a8 100644 --- a/torch_sim/monte_carlo.py +++ b/torch_sim/monte_carlo.py @@ -36,6 +36,9 @@ class SwapMCState(SimState): energy: torch.Tensor last_permutation: torch.Tensor + _atom_attributes = SimState._atom_attributes | {"last_permutation"} # noqa: SLF001 + _system_attributes = SimState._system_attributes | {"energy"} # noqa: SLF001 + def generate_swaps( state: SimState, generator: torch.Generator | None = None diff --git a/torch_sim/optimizers.py b/torch_sim/optimizers.py index 9a41cd8c..a88d40d9 100644 --- a/torch_sim/optimizers.py +++ b/torch_sim/optimizers.py @@ -33,6 +33,29 @@ MdFlavor = Literal["vv_fire", "ase_fire"] vv_fire_key, ase_fire_key = get_args(MdFlavor) +_md_atom_attributes = SimState._atom_attributes | {"forces", "velocities"} # 
noqa: SLF001 +_fire_system_attributes = ( + SimState._system_attributes # noqa: SLF001 + | DeformGradMixin._system_attributes # noqa: SLF001 + | { + "energy", + "stress", + "cell_positions", + "cell_velocities", + "cell_forces", + "cell_masses", + "cell_factor", + "pressure", + "dt", + "alpha", + "n_pos", + } +) +_fire_global_attributes = SimState._global_attributes | { # noqa: SLF001 + "hydrostatic_strain", + "constant_volume", +} + @dataclass class GDState(SimState): @@ -56,6 +79,9 @@ class GDState(SimState): forces: torch.Tensor energy: torch.Tensor + _atom_attributes = SimState._atom_attributes | {"forces"} # noqa: SLF001 + _system_attributes = SimState._system_attributes | {"energy"} # noqa: SLF001 + def gradient_descent( model: ModelInterface, *, lr: torch.Tensor | float = 0.01 @@ -149,7 +175,7 @@ def gd_step(state: GDState, lr: torch.Tensor = lr) -> GDState: return gd_init, gd_step -@dataclass +@dataclass(kw_only=True) class UnitCellGDState(GDState, DeformGradMixin): """State class for batched gradient descent optimization with unit cell. @@ -195,6 +221,22 @@ class UnitCellGDState(GDState, DeformGradMixin): cell_forces: torch.Tensor cell_masses: torch.Tensor + _system_attributes = ( + GDState._system_attributes # noqa: SLF001 + | DeformGradMixin._system_attributes # noqa: SLF001 + | { + "cell_forces", + "pressure", + "stress", + "cell_positions", + "cell_factor", + "cell_masses", + } + ) + _global_attributes = ( + GDState._global_attributes | {"hydrostatic_strain", "constant_volume"} # noqa: SLF001 + ) + def unit_cell_gradient_descent( # noqa: PLR0915, C901 model: ModelInterface, @@ -438,7 +480,7 @@ def gd_step( return gd_init, gd_step -@dataclass +@dataclass(kw_only=True) class FireState(SimState): """State information for batched FIRE optimization. 
@@ -482,6 +524,17 @@ class FireState(SimState): alpha: torch.Tensor n_pos: torch.Tensor + _atom_attributes = _md_atom_attributes + _system_attributes = ( + SimState._system_attributes # noqa: SLF001 + | { + "energy", + "dt", + "alpha", + "n_pos", + } + ) + def fire( model: ModelInterface, @@ -619,7 +672,7 @@ def fire_init( return fire_init, functools.partial(step_func, **step_func_kwargs) -@dataclass +@dataclass(kw_only=True) class UnitCellFireState(SimState, DeformGradMixin): """State information for batched FIRE optimization with unit cell degrees of freedom. @@ -682,7 +735,6 @@ class UnitCellFireState(SimState, DeformGradMixin): cell_masses: torch.Tensor # Optimization-specific attributes - reference_cell: torch.Tensor cell_factor: torch.Tensor pressure: torch.Tensor hydrostatic_strain: bool @@ -693,6 +745,10 @@ class UnitCellFireState(SimState, DeformGradMixin): alpha: torch.Tensor n_pos: torch.Tensor + _atom_attributes = _md_atom_attributes + _system_attributes = _fire_system_attributes + _global_attributes = _fire_global_attributes + def unit_cell_fire( model: ModelInterface, @@ -907,7 +963,7 @@ def fire_init( return fire_init, functools.partial(step_func, **step_func_kwargs) -@dataclass +@dataclass(kw_only=True) class FrechetCellFIREState(SimState, DeformGradMixin): """State class for batched FIRE optimization with Frechet cell derivatives. 
@@ -964,7 +1020,6 @@ class FrechetCellFIREState(SimState, DeformGradMixin): stress: torch.Tensor # Optimization-specific attributes - reference_cell: torch.Tensor cell_factor: torch.Tensor pressure: torch.Tensor hydrostatic_strain: bool @@ -981,6 +1036,10 @@ class FrechetCellFIREState(SimState, DeformGradMixin): alpha: torch.Tensor n_pos: torch.Tensor + _atom_attributes = _md_atom_attributes + _system_attributes = _fire_system_attributes + _global_attributes = _fire_global_attributes + def frechet_cell_fire( model: ModelInterface, @@ -1261,7 +1320,7 @@ def _vv_fire_step( # noqa: C901, PLR0915 state.positions[nan_velocities] ) if is_cell_optimization: - if not isinstance(state, AnyFireCellState): + if not isinstance(state, get_args(AnyFireCellState)): raise ValueError( f"Cell optimization requires one of {get_args(AnyFireCellState)}." ) @@ -1483,7 +1542,7 @@ def _ase_fire_step( # noqa: C901, PLR0915 ) forces = state.forces if is_cell_optimization: - if not isinstance(state, AnyFireCellState): + if not isinstance(state, get_args(AnyFireCellState)): raise ValueError( f"Cell optimization requires one of {get_args(AnyFireCellState)}." 
) diff --git a/torch_sim/runners.py b/torch_sim/runners.py index 8f2917da..14164375 100644 --- a/torch_sim/runners.py +++ b/torch_sim/runners.py @@ -541,6 +541,13 @@ class StaticState(type(state)): forces: torch.Tensor stress: torch.Tensor + _atom_attributes = ( + state._atom_attributes | {"forces"} # noqa: SLF001 + ) + _system_attributes = ( + state._system_attributes | {"energy", "stress"} # noqa: SLF001 + ) + all_props: list[dict[str, torch.Tensor]] = [] og_filenames = trajectory_reporter.filenames @@ -564,8 +571,12 @@ class StaticState(type(state)): sub_state = StaticState( **vars(sub_state), energy=model_outputs["energy"], - forces=model_outputs["forces"] if model.compute_forces else None, - stress=model_outputs["stress"] if model.compute_stress else None, + forces=model_outputs["forces"] + if model.compute_forces + else torch.full_like(sub_state.positions, fill_value=float("nan")), + stress=model_outputs["stress"] + if model.compute_stress + else torch.full_like(sub_state.cell, fill_value=float("nan")), ) props = trajectory_reporter.report(sub_state, 0, model=model) diff --git a/torch_sim/state.py b/torch_sim/state.py index d2ec8351..fa898e1c 100644 --- a/torch_sim/state.py +++ b/torch_sim/state.py @@ -8,8 +8,10 @@ import importlib import typing import warnings +from collections import defaultdict +from collections.abc import Generator from dataclasses import dataclass -from typing import TYPE_CHECKING, Literal, Self, cast +from typing import TYPE_CHECKING, Any, ClassVar, Literal, Self, cast import torch @@ -83,6 +85,15 @@ class SimState: atomic_numbers: torch.Tensor system_idx: torch.Tensor + _atom_attributes: ClassVar[set[str]] = { + "positions", + "masses", + "atomic_numbers", + "system_idx", + } + _system_attributes: ClassVar[set[str]] = {"cell"} + _global_attributes: ClassVar[set[str]] = {"pbc"} + def __init__( self, positions: torch.Tensor, @@ -204,7 +215,7 @@ def n_atoms_per_batch(self) -> torch.Tensor: return self.n_atoms_per_system @property - def 
batch(self) -> torch.Tensor | None: + def batch(self) -> torch.Tensor: """System indices. deprecated:: @@ -405,7 +416,16 @@ def __init_subclass__(cls, **kwargs) -> None: """Enforce that all derived states cannot have tensor attributes that can also be None. This is because torch.concatenate cannot concat between a tensor and a None. See https://github.com/Radical-AI/torch-sim/pull/219 for more details. + + Also enforce all of child classes's attributes are specified in _atom_attributes, + _system_attributes, or _global_attributes. """ + cls._assert_no_tensor_attributes_can_be_none() + cls._assert_all_attributes_have_defined_scope() + super().__init_subclass__(**kwargs) + + @classmethod + def _assert_no_tensor_attributes_can_be_none(cls) -> None: # We need to use get_type_hints to correctly inspect the types type_hints = typing.get_type_hints(cls) for attr_name, attr_typehint in type_hints.items(): @@ -426,14 +446,65 @@ def __init_subclass__(cls, **kwargs) -> None: "the tensor with dummy values and track the 'None' case." 
) - super().__init_subclass__(**kwargs) + @classmethod + def _assert_all_attributes_have_defined_scope(cls) -> None: + all_defined_attributes = ( + cls._atom_attributes | cls._system_attributes | cls._global_attributes + ) + # 1) assert that no attribute is defined twice in all_defined_attributes + duplicates = ( + (cls._atom_attributes & cls._system_attributes) + | (cls._atom_attributes & cls._global_attributes) + | (cls._system_attributes & cls._global_attributes) + ) + if duplicates: + raise TypeError( + f"Attributes {duplicates} are declared multiple times in {cls.__name__} " + "in _atom_attributes, _system_attributes, or _global_attributes" + ) + + # 2) assert that all attributes are defined in all_defined_attributes + all_annotations = {} + for c in cls.mro(): + if hasattr(c, "__annotations__"): + all_annotations.update(c.__annotations__) + + attributes_to_check = set(vars(cls).keys()) | set(all_annotations.keys()) + + for attr_name in attributes_to_check: + is_special_attribute = attr_name.startswith("__") + is_property = attr_name in vars(cls) and isinstance( + vars(cls).get(attr_name), property + ) + is_method = hasattr(cls, attr_name) and callable(getattr(cls, attr_name)) + is_class_variable = ( + # Note: _atom_attributes, _system_attributes, and _global_attributes + # are all class variables + typing.get_origin(all_annotations.get(attr_name)) is typing.ClassVar + ) + if is_special_attribute or is_property or is_method or is_class_variable: + continue + if attr_name not in all_defined_attributes: + raise TypeError( + f"Attribute '{attr_name}' is not defined in {cls.__name__} in any " + "of _atom_attributes, _system_attributes, or _global_attributes" + ) + + +@dataclass(kw_only=True) class DeformGradMixin: """Mixin for states that support deformation gradients.""" reference_cell: torch.Tensor - row_vector_cell: torch.Tensor + + _system_attributes: ClassVar[set[str]] = {"reference_cell"} + + if TYPE_CHECKING: + # define this under a TYPE_CHECKING block to 
avoid it being included in the + # dataclass __init__ during runtime + row_vector_cell: torch.Tensor @property def reference_row_vector_cell(self) -> torch.Tensor: @@ -543,125 +614,34 @@ def state_to_device( return type(state)(**attrs) -def infer_property_scope( - state: SimState, - ambiguous_handling: Literal["error", "globalize", "globalize_warn"] = "error", -) -> dict[Literal["global", "per_atom", "per_system"], list[str]]: - """Infer whether a property is global, per-atom, or per-system. - - Analyzes the shapes of tensor attributes to determine their scope within - the atomistic system representation. - - Args: - state (SimState): The state to analyze - ambiguous_handling ("error" | "globalize" | "globalize_warn"): How to - handle properties with ambiguous scope. Options: - - "error": Raise an error for ambiguous properties - - "globalize": Treat ambiguous properties as global - - "globalize_warn": Treat ambiguous properties as global with a warning - - Returns: - dict[Literal["global", "per_atom", "per_system"], list[str]]: Map of scope - category to list of property names - - Raises: - ValueError: If n_atoms equals n_systems (making scope inference ambiguous) or - if ambiguous_handling="error" and an ambiguous property is encountered - """ - # TODO: this cannot effectively resolve global properties with - # length of n_atoms or n_systems, they will be classified incorrectly, - # no clear fix - - if state.n_atoms == state.n_systems: - raise ValueError( - f"n_atoms ({state.n_atoms}) and n_systems ({state.n_systems}) are equal, " - "which means shapes cannot be inferred unambiguously." 
- ) - - scope = { - "global": [], - "per_atom": [], - "per_system": [], - } - - # Iterate through all attributes - for attr_name, attr_value in vars(state).items(): - # Handle scalar values (global properties) - if not isinstance(attr_value, torch.Tensor): - scope["global"].append(attr_name) - continue - - # Handle tensor properties based on shape - shape = attr_value.shape - - # Empty tensor case - if len(shape) == 0: - scope["global"].append(attr_name) - # Vector/matrix with first dimension matching number of atoms - elif shape[0] == state.n_atoms: - scope["per_atom"].append(attr_name) - # Tensor with first dimension matching number of systems - elif shape[0] == state.n_systems: - scope["per_system"].append(attr_name) - # Any other shape is ambiguous - elif ambiguous_handling == "error": - raise ValueError( - f"Cannot categorize property '{attr_name}' with shape {shape}. " - f"Expected first dimension to be either {state.n_atoms} (per-atom) or " - f"{state.n_systems} (per-system), or a scalar (global)." - ) - elif ambiguous_handling in ("globalize", "globalize_warn"): - scope["global"].append(attr_name) - - if ambiguous_handling == "globalize_warn": - warnings.warn( - f"Property '{attr_name}' with shape {shape} is ambiguous, " - "treating as global. This may lead to unexpected behavior " - "and suggests the State is not being used as intended.", - stacklevel=1, - ) - - return scope - - -def _get_property_attrs( - state: SimState, ambiguous_handling: Literal["error", "globalize"] = "error" -) -> dict[str, dict]: - """Get global, per-atom, and per-system attributes from a state. - - Categorizes all attributes of the state based on their scope - (global, per-atom, or per-system). +def get_attrs_for_scope( + state: SimState, scope: Literal["per-atom", "per-system", "global"] +) -> Generator[tuple[str, Any], None, None]: + """Get attributes for a given scope. 
Args: - state (SimState): The state to extract attributes from - ambiguous_handling ("error" | "globalize"): How to handle ambiguous - properties + state (SimState): The state to get attributes for + scope (Literal["per-atom", "per-system", "global"]): The scope to get + attributes for Returns: - dict[str, dict]: Keys are 'global', 'per_atom', and 'per_system', each - containing a dictionary of attribute names to values + Generator[tuple[str, Any], None, None]: A generator of attribute names and values """ - scope = infer_property_scope(state, ambiguous_handling=ambiguous_handling) - - attrs = {"global": {}, "per_atom": {}, "per_system": {}} - - # Process global properties - for attr_name in scope["global"]: - attrs["global"][attr_name] = getattr(state, attr_name) - - # Process per-atom properties - for attr_name in scope["per_atom"]: - attrs["per_atom"][attr_name] = getattr(state, attr_name) - - # Process per-system properties - for attr_name in scope["per_system"]: - attrs["per_system"][attr_name] = getattr(state, attr_name) - - return attrs + match scope: + case "per-atom": + attr_names = state._atom_attributes # noqa: SLF001 + case "per-system": + attr_names = state._system_attributes # noqa: SLF001 + case "global": + attr_names = state._global_attributes # noqa: SLF001 + case _: + raise ValueError(f"Unknown scope: {scope!r}") + for attr_name in attr_names: + yield attr_name, getattr(state, attr_name) def _filter_attrs_by_mask( - attrs: dict[str, dict], + state: SimState, atom_mask: torch.Tensor, system_mask: torch.Tensor, ) -> dict: @@ -670,8 +650,7 @@ def _filter_attrs_by_mask( Selects subsets of attributes based on boolean masks for atoms and systems. 
Args: - attrs (dict[str, dict]): Keys are 'global', 'per_atom', and 'per_system', each - containing a dictionary of attribute names to values + state (SimState): The state to filter atom_mask (torch.Tensor): Boolean mask for atoms to include with shape (n_atoms,) system_mask (torch.Tensor): Boolean mask for systems to include with shape @@ -680,13 +659,11 @@ def _filter_attrs_by_mask( Returns: dict: Filtered attributes with appropriate handling for each scope """ - filtered_attrs = {} - # Copy global attributes directly - filtered_attrs.update(attrs["global"]) + filtered_attrs = dict(get_attrs_for_scope(state, "global")) # Filter per-atom attributes - for attr_name, attr_value in attrs["per_atom"].items(): + for attr_name, attr_value in get_attrs_for_scope(state, "per-atom"): if attr_name == "system_idx": # Get the old system indices for the selected atoms old_system_idxs = attr_value[atom_mask] @@ -710,7 +687,7 @@ def _filter_attrs_by_mask( filtered_attrs[attr_name] = attr_value[atom_mask] # Filter per-system attributes - for attr_name, attr_value in attrs["per_system"].items(): + for attr_name, attr_value in get_attrs_for_scope(state, "per-system"): filtered_attrs[attr_name] = attr_value[system_mask] return filtered_attrs @@ -718,7 +695,6 @@ def _filter_attrs_by_mask( def _split_state( state: SimStateVar, - ambiguous_handling: Literal["error", "globalize"] = "error", ) -> list[SimStateVar]: """Split a SimState into a list of states, each containing a single system. @@ -727,33 +703,28 @@ def _split_state( Args: state (SimState): The SimState to split - ambiguous_handling ("error" | "globalize"): How to handle ambiguous - properties. If "error", an error is raised if a property has ambiguous - scope. If "globalize", properties with ambiguous scope are treated as - global. 
Returns: list[SimState]: A list of SimState objects, each containing a single system """ - attrs = _get_property_attrs(state, ambiguous_handling) system_sizes = torch.bincount(state.system_idx).tolist() - # Split per-atom attributes by system split_per_atom = {} - for attr_name, attr_value in attrs["per_atom"].items(): - if attr_name == "system_idx": - continue - split_per_atom[attr_name] = torch.split(attr_value, system_sizes, dim=0) + for attr_name, attr_value in get_attrs_for_scope(state, "per-atom"): + if attr_name != "system_idx": + split_per_atom[attr_name] = torch.split(attr_value, system_sizes, dim=0) - # Split per-system attributes into individual elements split_per_system = {} - for attr_name, attr_value in attrs["per_system"].items(): + for attr_name, attr_value in get_attrs_for_scope(state, "per-system"): split_per_system[attr_name] = torch.split(attr_value, 1, dim=0) + global_attrs = dict(get_attrs_for_scope(state, "global")) + # Create a state for each system states = [] - for i in range(state.n_systems): + n_systems = len(system_sizes) + for i in range(n_systems): system_attrs = { # Create a system tensor with all zeros for this system "system_idx": torch.zeros( @@ -767,7 +738,7 @@ def _split_state( for attr_name in split_per_system }, # Add the global attributes - **attrs["global"], + **global_attrs, } states.append(type(state)(**system_attrs)) @@ -777,7 +748,6 @@ def _split_state( def _pop_states( state: SimState, pop_indices: list[int] | torch.Tensor, - ambiguous_handling: Literal["error", "globalize"] = "error", ) -> tuple[SimState, list[SimState]]: """Pop off the states with the specified indices. @@ -786,10 +756,6 @@ def _pop_states( Args: state (SimState): The SimState to modify pop_indices (list[int] | torch.Tensor): The system indices to extract and remove - ambiguous_handling ("error" | "globalize"): How to handle ambiguous - properties. If "error", an error is raised if a property has ambiguous - scope. 
If "globalize", properties with ambiguous scope are treated as - global. Returns: tuple[SimState, list[SimState]]: A tuple containing: @@ -805,8 +771,6 @@ def _pop_states( if isinstance(pop_indices, list): pop_indices = torch.tensor(pop_indices, device=state.device, dtype=torch.int64) - attrs = _get_property_attrs(state, ambiguous_handling) - # Create masks for the atoms and systems to keep and pop system_range = torch.arange(state.n_systems, device=state.device) pop_system_mask = torch.isin(system_range, pop_indices) @@ -816,15 +780,15 @@ def _pop_states( keep_atom_mask = ~pop_atom_mask # Filter attributes for keep and pop states - keep_attrs = _filter_attrs_by_mask(attrs, keep_atom_mask, keep_system_mask) - pop_attrs = _filter_attrs_by_mask(attrs, pop_atom_mask, pop_system_mask) + keep_attrs = _filter_attrs_by_mask(state, keep_atom_mask, keep_system_mask) + pop_attrs = _filter_attrs_by_mask(state, pop_atom_mask, pop_system_mask) # Create the keep state keep_state = type(state)(**keep_attrs) # Create and split the pop state pop_state = type(state)(**pop_attrs) - pop_states = _split_state(pop_state, ambiguous_handling) + pop_states = _split_state(pop_state) return keep_state, pop_states @@ -832,7 +796,6 @@ def _pop_states( def _slice_state( state: SimStateVar, system_indices: list[int] | torch.Tensor, - ambiguous_handling: Literal["error", "globalize"] = "error", ) -> SimStateVar: """Slice a substate from the SimState containing only the specified system indices. @@ -843,10 +806,6 @@ def _slice_state( state (SimState): The state to slice system_indices (list[int] | torch.Tensor): System indices to include in the sliced state - ambiguous_handling ("error" | "globalize"): How to handle ambiguous - properties. If "error", an error is raised if a property has ambiguous - scope. If "globalize", properties with ambiguous scope are treated as - global. 
Returns: SimState: A new SimState object containing only the specified systems @@ -862,15 +821,13 @@ def _slice_state( if len(system_indices) == 0: raise ValueError("system_indices cannot be empty") - attrs = _get_property_attrs(state, ambiguous_handling) - # Create masks for the atoms and systems to include system_range = torch.arange(state.n_systems, device=state.device) system_mask = torch.isin(system_range, system_indices) atom_mask = torch.isin(state.system_idx, system_indices) # Filter attributes - filtered_attrs = _filter_attrs_by_mask(attrs, atom_mask, system_mask) + filtered_attrs = _filter_attrs_by_mask(state, atom_mask, system_mask) # Create the sliced state return type(state)(**filtered_attrs) @@ -911,19 +868,12 @@ def concatenate_states( # Use the target device or default to the first state's device target_device = device or first_state.device - # Get property scopes from the first state to identify - # global/per-atom/per-system properties - first_scope = infer_property_scope(first_state) - global_props = set(first_scope["global"]) - per_atom_props = set(first_scope["per_atom"]) - per_system_props = set(first_scope["per_system"]) - # Initialize result with global properties from first state - concatenated = {prop: getattr(first_state, prop) for prop in global_props} + concatenated = dict(get_attrs_for_scope(first_state, "global")) # Pre-allocate lists for tensors to concatenate - per_atom_tensors = {prop: [] for prop in per_atom_props} - per_system_tensors = {prop: [] for prop in per_system_props} + per_atom_tensors = defaultdict(list) + per_system_tensors = defaultdict(list) new_system_indices = [] system_offset = 0 @@ -934,14 +884,15 @@ def concatenate_states( state = state_to_device(state, target_device) # Collect per-atom properties - for prop in per_atom_props: - # if hasattr(state, prop): - per_atom_tensors[prop].append(getattr(state, prop)) + for prop, val in get_attrs_for_scope(state, "per-atom"): + if prop == "system_idx": + # skip 
system_idx, it will be handled below + continue + per_atom_tensors[prop].append(val) # Collect per-system properties - for prop in per_system_props: - # if hasattr(state, prop): - per_system_tensors[prop].append(getattr(state, prop)) + for prop, val in get_attrs_for_scope(state, "per-system"): + per_system_tensors[prop].append(val) # Update system indices num_systems = state.n_systems From cd0b0a09857d4ba6d91de5c0e82bc349cbad0573 Mon Sep 17 00:00:00 2001 From: Rhys Goodall Date: Wed, 13 Aug 2025 21:24:57 -0400 Subject: [PATCH 15/16] MAINT: update pins in MACE phonons example. Remove misleading ty from PR template (#239) --- .github/PULL_REQUEST_TEMPLATE.md | 1 - docs/_static/draw_pkg_treemap.py | 4 ++-- examples/scripts/6_Phonons/6.1_Phonons_MACE.py | 4 ++-- 3 files changed, 4 insertions(+), 5 deletions(-) diff --git a/.github/PULL_REQUEST_TEMPLATE.md b/.github/PULL_REQUEST_TEMPLATE.md index b71a5910..0194cc5e 100644 --- a/.github/PULL_REQUEST_TEMPLATE.md +++ b/.github/PULL_REQUEST_TEMPLATE.md @@ -9,7 +9,6 @@ Before a pull request can be merged, the following items must be checked: * [ ] Doc strings have been added in the [Google docstring format](https://sphinxcontrib-napoleon.readthedocs.io/en/latest/example_google.html#example-google). * [ ] Run [ruff](https://beta.ruff.rs/docs/rules/#pydocstyle-d) on your code. -* [ ] Run `uvx ty check` on the repo. * [ ] Tests have been added for any new functionality or bug fixes. We highly recommended installing the pre-commit hooks running in CI locally to speedup the development process. Simply run `pip install pre-commit && pre-commit install` to install the hooks which will check your code before each commit. 
diff --git a/docs/_static/draw_pkg_treemap.py b/docs/_static/draw_pkg_treemap.py index 3a983387..44e762e7 100644 --- a/docs/_static/draw_pkg_treemap.py +++ b/docs/_static/draw_pkg_treemap.py @@ -5,8 +5,8 @@ # /// script # dependencies = [ -# "pymatviz==0.16.0", -# "plotly!=6.2.0", # TODO remove pin pending https://github.com/plotly/plotly.py/issues/5253#issuecomment-3016615635 +# "pymatviz>=0.17.1", +# "plotly>=6.3.0", # ] # /// diff --git a/examples/scripts/6_Phonons/6.1_Phonons_MACE.py b/examples/scripts/6_Phonons/6.1_Phonons_MACE.py index 38a50d60..3968acda 100644 --- a/examples/scripts/6_Phonons/6.1_Phonons_MACE.py +++ b/examples/scripts/6_Phonons/6.1_Phonons_MACE.py @@ -4,10 +4,10 @@ # dependencies = [ # "mace-torch>=0.3.12", # "phonopy>=2.35", -# "pymatviz==0.16", +# "pymatviz>=0.17.1", +# "plotly>=6.3.0", # "seekpath", # "ase", -# "plotly!=6.2.0", # TODO remove pin pending https://github.com/plotly/plotly.py/issues/5253#issuecomment-3016615635 # ] # /// From 9c4b5ca8e8ae331ab740ed62c1ad57b17b35afc4 Mon Sep 17 00:00:00 2001 From: Rhys Goodall Date: Wed, 13 Aug 2025 21:40:21 -0400 Subject: [PATCH 16/16] fix: replace batch with system_idx --- torch_sim/models/fairchem.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torch_sim/models/fairchem.py b/torch_sim/models/fairchem.py index 1e36c20f..3aae3f3a 100644 --- a/torch_sim/models/fairchem.py +++ b/torch_sim/models/fairchem.py @@ -177,7 +177,7 @@ def forward(self, state: ts.SimState | StateDict) -> dict: # Convert SimState to AtomicData objects for efficient batch processing from ase import Atoms - natoms = torch.bincount(state.batch) + natoms = torch.bincount(state.system_idx) atomic_data_list = [] for i, (n, c) in enumerate(