diff --git a/tests/models/test_einstein.py b/tests/models/test_einstein.py new file mode 100644 index 00000000..68a9b5f6 --- /dev/null +++ b/tests/models/test_einstein.py @@ -0,0 +1,333 @@ +"""Tests for the Einstein model implementation.""" + +import pytest +import torch +from ase.build import bulk + +import torch_sim as ts +from torch_sim.models.einstein import EinsteinModel + + +class TestEinsteinModel: + """Test Einstein model implementation.""" + + @pytest.fixture + def simple_system(self): + """Create a simple test system.""" + device = torch.device("cpu") + dtype = torch.float64 + + # Create a simple 2-atom system + positions = torch.tensor( + [[0.0, 0.0, 0.0], [1.0, 0.0, 0.0]], dtype=dtype, device=device + ) + masses = torch.tensor([1.0, 1.0], dtype=dtype, device=device) + cell = torch.eye(3, dtype=dtype, device=device) * 10.0 # Large cell + atomic_numbers = torch.tensor([1, 1], dtype=torch.int64, device=device) + + state = ts.SimState( + positions=positions, + masses=masses, + cell=cell.unsqueeze(0), + pbc=True, + atomic_numbers=atomic_numbers, + ) + + return state, device, dtype + + @pytest.fixture + def batched_system(self): + """Create a batched test system.""" + device = torch.device("cpu") + dtype = torch.float64 + + # Create two different 2-atom systems + si_atoms = bulk("Si", "diamond", a=5.43, cubic=True) + fe_atoms = bulk("Fe", "bcc", a=2.87, cubic=True) + + state = ts.io.atoms_to_state([si_atoms, fe_atoms], device, dtype) + return state, device, dtype + + def test_einstein_model_creation( + self, simple_system: tuple[ts.SimState, torch.device, torch.dtype] + ): + """Test basic Einstein model creation.""" + state, device, dtype = simple_system + + # Create equilibrium positions and frequencies + equilibrium_pos = state.positions + frequencies = torch.ones(2, dtype=dtype, device=device) * 0.1 # [N] + + model = EinsteinModel( + equilibrium_position=equilibrium_pos, + frequencies=frequencies, + masses=state.masses, + device=device, + dtype=dtype, + ) + + assert model.device == device + assert model.dtype == dtype + assert model.compute_forces is True + + def test_einstein_model_forward_single_system( + self, simple_system: tuple[ts.SimState, torch.device, torch.dtype] + ): + """Test forward pass with single system.""" + state, device, dtype = simple_system + + equilibrium_pos = state.positions + frequencies = torch.ones(2, dtype=dtype, device=device) * 0.1 + + model = EinsteinModel( + equilibrium_position=equilibrium_pos, + frequencies=frequencies, + masses=state.masses, + device=device, + dtype=dtype, + ) + + # Displace atoms slightly from equilibrium + displaced_state = state.clone() + displaced_state.positions += 0.1 + + results = model(displaced_state) + + assert "energy" in results + assert "forces" in results + assert results["energy"].shape == (1,) + assert results["forces"].shape == (2, 3) # [N_atoms, 3] + + # Forces should point back toward equilibrium + expected_force_direction = -(displaced_state.positions - equilibrium_pos) + force_directions = results["forces"] + # Check that forces point in the right direction (dot product > 0) + for i in range(2): + dot_product = torch.dot(force_directions[i], expected_force_direction[i]) + assert dot_product > 0 + + def test_einstein_model_forward_batched_system( + self, batched_system: tuple[ts.SimState, torch.device, torch.dtype] + ): + """Test forward pass with batched system.""" + state, device, dtype = batched_system + + # Create equilibrium positions for the batched system + n_atoms = state.n_atoms + equilibrium_pos = state.positions + frequencies = torch.ones(n_atoms, dtype=dtype, device=device) * 0.05 + + model = EinsteinModel( + equilibrium_position=equilibrium_pos, + frequencies=frequencies, + masses=state.masses, + device=device, + dtype=dtype, + ) + + # Displace atoms slightly + displaced_state = state.clone() + displaced_state.positions += 0.05 + + results = model(displaced_state) + + assert "energy" in results + assert "forces" in results + assert results["energy"].shape == (2,) # [n_systems] + assert results["forces"].shape == (n_atoms, 3) # [total_atoms, 3] + + def test_einstein_model_from_frequencies(self): + """Test creation from frequencies class method.""" + device = torch.device("cpu") + dtype = torch.float64 + + # Create ASE atoms + atoms = bulk("Si", "diamond", a=5.43, cubic=True) + state = ts.io.atoms_to_state([atoms], device, dtype) + + frequencies = torch.ones(len(atoms), dtype=dtype, device=device) * 0.05 + + model = EinsteinModel.from_atom_and_frequencies( + atom=state, + frequencies=frequencies, + reference_energy=1.0, + device=device, + dtype=dtype, + ) + + assert torch.allclose(model.reference_energy, torch.tensor(1.0, dtype=dtype)) + assert model.frequencies.shape[0] == len(atoms) + + def test_periodic_boundary_conditions( + self, simple_system: tuple[ts.SimState, torch.device, torch.dtype] + ): + """Test that PBC are handled correctly.""" + state, device, dtype = simple_system + + # Create model with equilibrium at origin + equilibrium_pos = torch.zeros((2, 3), dtype=dtype, device=device) + frequencies = torch.ones(2, dtype=dtype, device=device) * 1 + + model = EinsteinModel( + equilibrium_position=equilibrium_pos, + frequencies=frequencies, + device=device, + dtype=dtype, + ) + + # Place atoms near opposite faces of the cell + test_state = state.clone() + test_state.positions = torch.tensor( + [ + [0.1, 0.0, 0.0], # Near one face + [9.9, 0.0, 0.0], # Near opposite face + ], + dtype=dtype, + device=device, + ) + + results = model(test_state) + + # Should handle PBC correctly - both atoms far from origin + # but should compute minimum image distances + assert torch.isfinite(results["energy"]) + assert torch.isfinite(results["forces"]).all() + + spring = frequencies**2 # since mass=1 + target_energies = 0.5 * spring * (torch.tensor([0.1, -0.1]) ** 2) + target_forces = -spring[:, None] * torch.tensor( + [[0.1, 0.0, 0.0], [-0.1, 0.0, 0.0]], dtype=dtype, device=device + ) + assert torch.allclose(results["energy"], target_energies.sum(), atol=1e-6) + assert torch.allclose(results["forces"], target_forces, atol=1e-6) + + def test_energy_force_consistency( + self, simple_system: tuple[ts.SimState, torch.device, torch.dtype] + ): + """Test that forces are consistent with energy gradients.""" + state, device, dtype = simple_system + + equilibrium_pos = state.positions.clone() + frequencies = torch.ones(2, dtype=dtype, device=device) * 0.1 + + model = EinsteinModel( + equilibrium_position=equilibrium_pos, + frequencies=frequencies, + masses=state.masses, + device=device, + dtype=dtype, + ) + + # Create a displaced state with gradients enabled + test_positions = state.positions.clone() + 0.1 + test_positions.requires_grad_(requires_grad=True) + + test_state = state.clone() + test_state.positions = test_positions + + results = model(test_state) + energy = results["energy"] + + # Compute forces from gradients + forces_from_grad = -torch.autograd.grad( + energy, test_positions, create_graph=False + )[0] + forces_direct = results["forces"] + + # Forces should match (within numerical precision) + torch.testing.assert_close(forces_direct, forces_from_grad, atol=1e-6, rtol=1e-6) + + def test_get_free_energy( + self, simple_system: tuple[ts.SimState, torch.device, torch.dtype] + ): + """Test free energy calculation.""" + state, device, dtype = simple_system + + equilibrium_pos = state.positions + frequencies = torch.ones(2, dtype=dtype, device=device) * 0.1 # THz + + model = EinsteinModel( + equilibrium_position=equilibrium_pos, + frequencies=frequencies, + masses=state.masses, + device=device, + dtype=dtype, + ) + + temperature = 300.0 # K + results = model.get_free_energy(temperature) + + # Check that result is a dictionary with free energy + assert isinstance(results, dict) + assert "free_energy" in results + free_energy = results["free_energy"] + assert isinstance(free_energy, torch.Tensor) + assert free_energy.shape == (1,) # Single system + + # Free energy should be finite + assert torch.isfinite(free_energy).all() + + # At higher temperature, free energy should be lower (more negative) + results_high = model.get_free_energy(600.0) + free_energy_high = results_high["free_energy"] + assert free_energy_high < free_energy + + def test_get_free_energy_batched( + self, batched_system: tuple[ts.SimState, torch.device, torch.dtype] + ): + """Test free energy calculation for batched systems.""" + state, device, dtype = batched_system + + n_atoms = state.n_atoms + equilibrium_pos = state.positions + frequencies = torch.ones(n_atoms, dtype=dtype, device=device) * 0.05 + + model = EinsteinModel( + equilibrium_position=equilibrium_pos, + frequencies=frequencies, + masses=state.masses, + system_idx=state.system_idx, + device=device, + dtype=dtype, + ) + + temperature = 300.0 + results = model.get_free_energy(temperature) + + # Check result format and shape + assert isinstance(results, dict) + assert "free_energy" in results + free_energy = results["free_energy"] + + # Should have one free energy per system + assert free_energy.shape == (2,) # Two systems + assert torch.isfinite(free_energy).all() + + def test_sample_method( + self, simple_system: tuple[ts.SimState, torch.device, torch.dtype] + ): + """Test sampling from Einstein model.""" + state, device, dtype = simple_system + + equilibrium_pos = state.positions + frequencies = torch.ones(2, dtype=dtype, device=device) * 0.1 + + model = EinsteinModel( + equilibrium_position=equilibrium_pos, + frequencies=frequencies, + masses=state.masses, + device=device, + dtype=dtype, + ) + + temperature = 300.0 + sampled_state = model.sample(state, temperature) + + # Check that sampled state has correct shape and type + assert isinstance(sampled_state, ts.SimState) + assert sampled_state.positions.shape == state.positions.shape + assert sampled_state.positions.dtype == dtype + assert sampled_state.positions.device == device + + # Sampled positions should have finite values + assert torch.isfinite(sampled_state.positions).all() diff --git a/tests/test_transforms.py b/tests/test_transforms.py index c9317cdf..b9563e4f 100644 --- a/tests/test_transforms.py +++ b/tests/test_transforms.py @@ -911,6 +911,162 @@ def test_minimum_image_displacement( torch.testing.assert_close(result, torch.tensor(expected)) +@pytest.mark.parametrize( + ("dr", "cell", "system_idx", "pbc", "expected"), + [ + # Single system case - should match minimum_image_displacement results + ( + [[1.5, 1.5, 1.5], [-1.5, -1.5, -1.5]], + torch.stack([torch.eye(3) * 3.0]), + torch.tensor([0, 0]), + False, + [[1.5, 1.5, 1.5], [-1.5, -1.5, -1.5]], + ), + ( + [[1.5, 1.5, 1.5], [-1.5, -1.5, -1.5]], + torch.stack([torch.eye(3) * 3.0]), + torch.tensor([0, 0]), + True, + [[1.5, 1.5, 1.5], [-1.5, -1.5, -1.5]], + ), + ( + [[2.2, 0.0, 0.0], [0.0, 2.2, 0.0], [0.0, 0.0, 2.2]], + torch.stack([torch.eye(3) * 2.0]), + torch.tensor([0, 0, 0]), + True, + [[0.2, 0.0, 0.0], [0.0, 0.2, 0.0], [0.0, 0.0, 0.2]], + ), + # Multi-system case - different cells for different systems + ( + [[2.2, 0.0, 0.0], [0.0, 3.2, 0.0]], # Different displacements + torch.stack([torch.eye(3) * 2.0, torch.eye(3) * 3.0]), # Different cells + torch.tensor([0, 1]), # Different systems + True, + # Both wrapped with their respective cells + [[0.2, 0.0, 0.0], [0.0, 0.2, 0.0]], + ), + # Multi-system with mixed cases + ( + [[1.5, 0.0, 0.0], [3.2, 0.0, 0.0], [0.0, 2.5, 0.0]], + torch.stack([torch.eye(3) * 3.0, torch.eye(3) * 4.0]), + torch.tensor([0, 1, 1]), # First atom in system 0, others in system 1 + True, + # Different wrapping per system + [[1.5, 0.0, 0.0], [-0.8, 0.0, 0.0], [0.0, -1.5, 0.0]], + ), + # No PBC case with multiple systems + ( + [[1.5, 1.5, 1.5], [2.2, 0.0, 0.0]], + torch.stack([torch.eye(3) * 3.0, torch.eye(3) * 2.0]), + torch.tensor([0, 1]), + False, + [[1.5, 1.5, 1.5], [2.2, 0.0, 0.0]], # No wrapping when pbc=False + ), + ], +) +def test_minimum_image_displacement_batched( + *, + dr: list[list[float]], + cell: torch.Tensor, + system_idx: torch.Tensor, + pbc: bool, + expected: list[list[float]], +) -> None: + """Test minimum_image_displacement_batched with various inputs. + + Tests function with single and multiple systems, with and without PBC. + Reuses test cases from minimum_image_displacement to ensure consistency. + """ + result = tst.minimum_image_displacement_batched( + dr=torch.tensor(dr), cell=cell, system_idx=system_idx, pbc=pbc + ) + torch.testing.assert_close(result, torch.tensor(expected)) + + +def test_minimum_image_displacement_batched_consistency() -> None: + """Test that batched version matches single system calls for single systems.""" + # Test parameters from the single system test + dr = torch.tensor([[2.2, 0.0, 0.0], [0.0, 2.2, 0.0], [0.0, 0.0, 2.2]]) + cell_single = torch.eye(3) * 2.0 + cell_batched = torch.stack([cell_single]) + system_idx = torch.tensor([0, 0, 0]) + + # Single system result + result_single = tst.minimum_image_displacement(dr=dr, cell=cell_single, pbc=True) + + # Batched result with single system + result_batched = tst.minimum_image_displacement_batched( + dr=dr, cell=cell_batched, system_idx=system_idx, pbc=True + ) + + torch.testing.assert_close(result_single, result_batched) + + +def test_minimum_image_displacement_batched_triclinic() -> None: + """Test minimum_image_displacement_batched with triclinic cells.""" + # Define different triclinic cells for different systems + cell1 = torch.tensor([[2.0, 0.5, 0.0], [0.0, 2.0, 0.0], [0.0, 0.3, 2.0]]) + cell2 = torch.tensor([[3.0, 0.0, 0.5], [0.3, 3.0, 0.0], [0.0, 0.0, 3.0]]) + + cell_batched = torch.stack([cell1, cell2]) + + # Create displacement vectors that need wrapping + dr = torch.tensor( + [ + [2.5, 2.5, 2.5], # System 0 + [3.5, 3.5, 3.5], # System 1 + ] + ) + system_idx = torch.tensor([0, 1]) + + result = tst.minimum_image_displacement_batched( + dr=dr, cell=cell_batched, system_idx=system_idx, pbc=True + ) + + # Verify results by computing expected values manually + # For system 0 with cell1 + expected_0 = tst.minimum_image_displacement(dr=dr[0:1], cell=cell1, pbc=True) + # For system 1 with cell2 + expected_1 = tst.minimum_image_displacement(dr=dr[1:2], cell=cell2, pbc=True) + + torch.testing.assert_close(result[0:1], expected_0) + torch.testing.assert_close(result[1:2], expected_1) + + +def test_minimum_image_displacement_batched_invalid_inputs() -> None: + """Test error handling for invalid inputs in batched minimum image displacement.""" + dr = torch.ones(4, 3) + cell = torch.stack([torch.eye(3)] * 2) + system_idx = torch.tensor([0, 0, 1, 1]) + + # Test integer tensors + with pytest.raises(TypeError): + tst.minimum_image_displacement_batched( + dr=torch.ones(4, 3, dtype=torch.int64), + cell=cell, + system_idx=system_idx, + pbc=True, + ) + + # Test dimension mismatch - displacement vectors + with pytest.raises(ValueError): + tst.minimum_image_displacement_batched( + dr=torch.ones(4, 2), # Wrong dimension (2 instead of 3) + cell=cell, + system_idx=system_idx, + pbc=True, + ) + + # Test mismatch between system indices and cell + with pytest.raises(ValueError): + tst.minimum_image_displacement_batched( + dr=dr, + cell=torch.stack([torch.eye(3)] * 3), # 3 cells but only 2 systems + system_idx=system_idx, + pbc=True, + ) + + @pytest.mark.parametrize( ("positions", "cell", "pbc", "pairs", "shifts", "expected_dr", "expected_distance"), [ diff --git a/torch_sim/models/einstein.py b/torch_sim/models/einstein.py new file mode 100644 index 00000000..072a0e75 --- /dev/null +++ b/torch_sim/models/einstein.py @@ -0,0 +1,276 @@ +"""Einstein model where each atom is treated as an independent 3D harmonic oscillator. + +Contrary to other models, the model energies depend on an absolute reference position, +so the model can only be used on systems that the model was initialized with. +As a analytical model, it can provide its Helmholtz free energy and can also generate +samples from the Boltzmann distribution at a given temperature. +""" + +import torch + +import torch_sim as ts +from torch_sim import SimState, units +from torch_sim.models.interface import ModelInterface + + +class EinsteinModel(ModelInterface): + """Einstein model where each atom is treated as an independent 3D harmonic oscillator. + Each atom has its own frequency. + + For this model: + E = sum_i 0.5 * k_i * (x_i - x0_i)^2 + F = -k_i * (x_i - x0_i) + k_i = m_i * omega_i^2 + + For best results, frequencies should be in the range of typical phonon frequencies. + They can be set for each atom type individually following energy balance from + a NVT simulation. From equipartition theorem: + = 3/2 k_B T + => omega = sqrt(3 k_B T / m ) + """ + + def __init__( + self, + equilibrium_position: torch.Tensor, # shape [N, 3] + frequencies: torch.Tensor, # shape [N] + system_idx: torch.Tensor | None = None, # shape [N] or None + masses: torch.Tensor | None = None, # shape [N] or None + reference_energy: float = 0.0, # reference energy value + *, + device: torch.device | None = None, + dtype: torch.dtype = torch.float32, + compute_forces: bool = True, + compute_stress: bool = False, + ) -> None: + """Initialize the Einstein model. + + Args: + equilibrium_position: Tensor of shape [N, 3] with equilibrium positions. + frequencies: Tensor of shape [N] with frequencies for each atom + (same frequency in all 3 directions). + system_idx: Optional tensor of shape [N] with system indices for each atom. + If None, all atoms are assumed to belong to the same system. + masses: Optional tensor of shape [N] with masses for each atom. + If None, all masses are set to 1. + reference_energy: Reference energy value to add to the computed energy. + device: Device to use for the model (default: CPU). + dtype: Data type for the model (default: torch.float32). + compute_forces: Whether to compute forces in the model. + compute_stress: Whether to compute stress in the model. + + """ + super().__init__() + self._device = device or torch.device("cpu") + self._dtype = dtype + self._compute_forces = compute_forces + self._compute_stress = compute_stress + + equilibrium_position = torch.as_tensor( + equilibrium_position, device=self._device, dtype=self._dtype + ) + frequencies = torch.as_tensor( + frequencies, device=self._device, dtype=self._dtype + ) # [N, 3] + + if frequencies.shape[0] != equilibrium_position.shape[0]: + raise ValueError("frequencies shape must match equilibrium_position shape") + if frequencies.min() < 0: + raise ValueError("frequencies must be non-negative") + if frequencies.ndim == 0: + frequencies = frequencies.unsqueeze(0) + if frequencies.ndim != 1: + raise ValueError("frequencies must be a 1D tensor") + + if masses is None: + masses = torch.ones( + equilibrium_position.shape[0], dtype=self._dtype, device=self._device + ) + else: + masses = masses.to(self._device, self._dtype) + + if system_idx is not None: + system_idx = system_idx.to(self._device) + else: + system_idx = torch.zeros( + equilibrium_position.shape[0], dtype=torch.long, device=self._device + ) + + self.register_buffer("system_idx", system_idx.to(self._device)) + self.register_buffer("masses", masses) # [N] + self.register_buffer("x0", equilibrium_position) # [N, 3] + self.register_buffer("frequencies", frequencies) # [N] + self.register_buffer( + "reference_energy", + torch.tensor(reference_energy, dtype=self._dtype, device=self._device), + ) + + @classmethod + def from_atom_and_frequencies( + cls, + atom: SimState, + frequencies: torch.Tensor | float, + *, + reference_energy: float = 0.0, + compute_forces: bool = True, + compute_stress: bool = False, + device: torch.device | None = None, + dtype: torch.dtype = torch.float32, + ) -> "EinsteinModel": + """Create an EinsteinModel from an ASE Atoms object and frequencies. + + Args: + atom: ASE Atoms object containing the reference structure. + frequencies: Tensor of shape [N] with frequencies for each atom + (same frequency in all 3 directions) or a scalar. + reference_energy: Reference energy value. + compute_forces: Whether to compute forces in the model. + compute_stress: Whether to compute stress in the model. + device: Device to use for the model (default: CPU). + dtype: Data type for the model (default: torch.float32). + + Returns: + EinsteinModel: An instance of the EinsteinModel. + """ + # Get equilibrium positions from the atoms object + equilibrium_position = atom.positions.clone().to(dtype=dtype, device=device) + + frequencies = torch.as_tensor(frequencies, dtype=dtype, device=device) + if frequencies.ndim == 0: + frequencies = frequencies.repeat(atom.positions.shape[0]) + if frequencies.shape[0] != atom.positions.shape[0]: + raise ValueError( + "frequencies must be a scalar or a tensor of shape [N] " + "where N is the number of atoms" + ) + + # Create and return an instance of EinsteinModel + return cls( + equilibrium_position=equilibrium_position, + frequencies=frequencies, + masses=atom.masses, + system_idx=atom.system_idx, + reference_energy=reference_energy, + compute_forces=compute_forces, + compute_stress=compute_stress, + device=device, + dtype=dtype, + ) + + def forward(self, state: ts.SimState) -> dict[str, torch.Tensor]: + """Calculate energies and forces for the Einstein model. + + Args: + state: SimState or StateDict containing positions, cell, etc. + + Returns: + Dictionary containing energy, forces + """ + pos = state.positions.to(self._dtype) # [N, 3] + cell = state.cell.to(self._dtype) + + if cell.ndim == 2: + cell = cell.unsqueeze(0) # [1, 3, 3] + + # Get model parameters + x0 = torch.as_tensor(self.x0, dtype=self._dtype, device=self._device) + frequencies = torch.as_tensor( + self.frequencies, dtype=self._dtype, device=self._device + ) + masses = torch.as_tensor(self.masses, dtype=self._dtype, device=self._device) + + # Calculate displacements using periodic boundary conditions + if cell.shape[0] == 1: + disp = ts.transforms.minimum_image_displacement( + dr=pos - x0, cell=cell[0], pbc=state.pbc + ) + else: + disp = ts.transforms.minimum_image_displacement_batched( + pos - x0, cell, system_idx=state.system_idx, pbc=state.pbc + ) + + # Spring constants: k = m * omega^2 + spring_constants = masses * (frequencies**2) # [N] + + # Energy: E = 0.5 * k * x^2 + energies_per_mode = 0.5 * spring_constants * ((disp**2).sum(dim=1)) # [N] + total_energy = torch.zeros( + state.n_systems, dtype=self._dtype, device=self._device + ) + total_energy.scatter_add_(0, state.system_idx, energies_per_mode) + total_energy += self.reference_energy + + # Forces: F = -k * x + forces = -spring_constants.unsqueeze(-1) * disp # [N, 3] + + results = { + "energy": total_energy, + "forces": forces, + } + # Stress is not implemented for this model + if self._compute_stress: + results["stress"] = torch.zeros( + (state.n_systems, 3, 3), dtype=self._dtype, device=self._device + ) + + return results + + def get_free_energy(self, temperature: float) -> dict[str, torch.Tensor]: + """Compute free energy at a given temperature using Einstein model. + + Args: + temperature: Temperature in Kelvin. + + Returns: + Dictionary containing heat capacity, entropy, and free energy. + """ + # Boltzmann constant in eV/K + kB = units.BaseConstant.k_B / units.UnitConversion.eV_to_J + T = temperature + # Reduced Planck constant in eV*s + hbar = units.BaseConstant.h_planck / (2 * units.pi * units.UnitConversion.eV_to_J) + + frequencies_tensor = ( + torch.as_tensor(self.frequencies).clone() + * torch.as_tensor( + units.UnitConversion.eV_to_J / units.BaseConstant.amu + ).sqrt() + / units.UnitConversion.Ang_to_met + ) # Convert to rad/s + free_energy_per_atom = ( + -3 * kB * T * torch.log(kB * T / (hbar * frequencies_tensor)) + ) + + n_systems = self.system_idx.max().item() + 1 + free_energy_per_system = torch.zeros( + n_systems, dtype=self._dtype, device=self._device + ) + free_energy_per_system.scatter_add_(0, self.system_idx, free_energy_per_atom) + + return {"free_energy": free_energy_per_system} + + def sample(self, state: SimState, temperature: float) -> SimState: + """Generate samples from the Einstein model at a given temperature. + + Args: + state: Initial simulation state to sample from. + temperature: Temperature in Kelvin. + + Returns: + SimState containing sampled positions and velocities. + + The Boltzmann distribution for a harmonic oscillator leads to Gaussian + distributions + for both positions and velocities. + """ + N = self.x0.shape[0] + kB = units.BaseConstant.k_B / units.UnitConversion.eV_to_J + beta = 1.0 / (kB * temperature) # Inverse temperature in 1/eV + + # Sample positions from a normal distribution around equilibrium positions + stddev = torch.sqrt(1.0 / (self.masses * (self.frequencies**2) * beta)).unsqueeze( + -1 + ) + sampled_positions = self.x0 + torch.randn(N, 3, device=self._device) * stddev + state = state.clone() + state.positions = sampled_positions + return state diff --git a/torch_sim/models/orb.py b/torch_sim/models/orb.py index fd65b23f..132f6d5c 100644 --- a/torch_sim/models/orb.py +++ b/torch_sim/models/orb.py @@ -416,7 +416,7 @@ def forward(self, state: ts.SimState | StateDict) -> dict[str, torch.Tensor]: if not model_has_direct_heads and prop == "stress": continue _property = "energy" if prop == "free_energy" else prop - results[prop] = predictions[_property].squeeze() + results[prop] = predictions[_property] if self.conservative: results["forces"] = results[self.model.grad_forces_name] diff --git a/torch_sim/steered_md.py b/torch_sim/steered_md.py new file mode 100644 index 00000000..0636d008 --- /dev/null +++ b/torch_sim/steered_md.py @@ -0,0 +1,732 @@ +"""Thermodynamic integration module for molecular dynamics simulations.""" + +import os +from collections.abc import Callable +from dataclasses import dataclass +from typing import Any + +import torch +from tqdm import tqdm + +import torch_sim as ts +from torch_sim.autobatching import BinningAutoBatcher +from torch_sim.integrators.md import ( + MDState, + calculate_momenta, + momentum_step, + position_step, +) +from torch_sim.models.interface import ModelInterface +from torch_sim.runners import _configure_batches_iterator, _configure_reporter +from torch_sim.state import SimState, concatenate_states, initialize_state +from torch_sim.typing import StateDict +from torch_sim.units import UnitSystem + + +def linear_lambda_schedule(step: int, n_steps: int) -> float: + """Linear lambda schedule: λ(t) = t/T.""" + return step / n_steps + + +def quadratic_lambda_schedule(step: int, n_steps: int) -> float: + """Quadratic lambda schedule: λ(t) = -1*(1-t/T)² + 1.""" + t_normalized = step / n_steps + return -1 * (1 - t_normalized) ** 2 + 1 + + +def cubic_lambda_schedule(step: int, n_steps: int) -> float: + """Cubic lambda schedule: λ(t) = -1*(1-t/T)³ + 1.""" + t_normalized = step / n_steps + return -1 * (1 - t_normalized) ** 3 + 1 + + +def lammps_lambda_schedule(step: int, n_steps: int) -> float: + """Lambda schedule used in LAMMPS paper: λ(t) = 0.5*(1 - cos(π*t/T)).""" + t = step / n_steps + return t**5 * (70 * t**4 - 315 * t**3 + 540 * t**2 - 420 * t + 126) + + +LAMBDA_SCHEDULES = { + "linear": linear_lambda_schedule, + "quadratic": quadratic_lambda_schedule, + "lammps": lammps_lambda_schedule, + "cubic": cubic_lambda_schedule, +} + + +@dataclass +class ThermodynamicIntegrationMDState(MDState): + """Custom state for thermodynamic integration in MD simulations. + + This state can hold additional properties like lambda_ for TI. + """ + + lambda_: torch.Tensor + energy_difference: torch.Tensor + energy1: torch.Tensor + energy2: torch.Tensor + + _system_attributes = MDState._system_attributes | { # noqa: SLF001 + "lambda_", + "energy_difference", + "energy1", + "energy2", + } + + +class MixedModel(ModelInterface): + """A model that mixes two models for thermodynamic integration. + + This class implements a linear combination of two models based on a lambda + parameter, which is used for thermodynamic integration calculations to + compute free energy differences. + """ + + def __init__( + self, + model1: ModelInterface, + model2: ModelInterface, + device: torch.device | str | None = None, + dtype: torch.dtype = torch.float32, + *, + compute_stress: bool = False, + compute_forces: bool = True, + ) -> None: + """Initialize the mixed model. + + Args: + model1: First model in the mixture + model2: Second model in the mixture + device: Device to run computations on + dtype: Data type for computations + compute_stress: Whether to compute stress + compute_forces: Whether to compute forces + """ + super().__init__() + self.model1 = model1 + self.model2 = model2 + self._device = device or torch.device( + "cuda" if torch.cuda.is_available() else "cpu" + ) + if isinstance(self._device, str): + self._device = torch.device(self._device) + + self._dtype = dtype + self._compute_stress = compute_stress + self._compute_forces = compute_forces + + def forward(self, state: ts.SimState | StateDict) -> dict[str, torch.Tensor]: + """Forward pass through the mixed model. + + Args: + state: Simulation state containing positions, masses, etc. + + Returns: + Dictionary with mixed energies and forces + """ + if "lambda_" not in state.__dict__: + lambda_ = 1 + lambda_per_atom = torch.tensor(1.0, device=self._device, dtype=self._dtype) + else: + lambda_ = state.lambda_ + lambda_per_atom = lambda_[state.system_idx] + out1 = self.model1(state) + out2 = self.model2(state) + + # Combine matching keys + output = {} + output["energy"] = (1 - lambda_) * out1["energy"] + lambda_ * out2["energy"] + output["forces"] = (1 - lambda_per_atom).view(-1, 1) * out1["forces"] + ( + lambda_per_atom + ).view(-1, 1) * out2["forces"] + output["energy_difference"] = out2["energy"] - out1["energy"] + output["energy1"] = out1["energy"] + output["energy2"] = out2["energy"] + return output + + +def nvt_langevin_thermodynamic_integration( # noqa: C901 + model: torch.nn.Module, + *, + dt: torch.Tensor, + kT: torch.Tensor, + gamma: torch.Tensor | None = None, + seed: int | None = None, +) -> tuple[ + Callable[[SimState | StateDict, torch.Tensor], MDState], + Callable[[MDState, torch.Tensor], MDState], +]: + """Initialize and return an NVT (canonical) integrator using Langevin dynamics. + + This function sets up integration in the NVT ensemble, where particle number (N), + volume (V), and temperature (T) are conserved. It returns both an initial state + and an update function for time evolution. + + It uses Langevin dynamics with stochastic noise and friction to maintain constant + temperature. The integration scheme combines deterministic velocity Verlet steps with + stochastic Ornstein-Uhlenbeck processes following the BAOAB splitting scheme. + + Args: + model (torch.nn.Module): Neural network model that computes energies and forces. + Must return a dict with 'energy' and 'forces' keys. + dt (torch.Tensor): Integration timestep, either scalar or with shape [n_batches] + kT (torch.Tensor): Target temperature in energy units, either scalar or + with shape [n_batches] + gamma (torch.Tensor, optional): Friction coefficient for Langevin thermostat, + either scalar or with shape [n_batches]. Defaults to 1/(100*dt). + seed (int, optional): Random seed for reproducibility. Defaults to None. + + Returns: + tuple: + - callable: Function to initialize the MDState from input data + with signature: init_fn(state, kT=kT, seed=seed) -> MDState + - callable: Update function that evolves system by one timestep + with signature: update_fn(state, dt=dt, kT=kT, gamma=gamma) -> MDState + + Notes: + - Uses BAOAB splitting scheme for Langevin dynamics + - Preserves detailed balance for correct NVT sampling + - Handles periodic boundary conditions if enabled in state + - Friction coefficient gamma controls the thermostat coupling strength + - Weak coupling (small gamma) preserves dynamics but with slower thermalization + - Strong coupling (large gamma) faster thermalization but may distort dynamics + """ + device, dtype = model.device, model.dtype + + if gamma is None: + gamma = 1 / (100 * dt) + + if isinstance(gamma, float): + gamma = torch.tensor(gamma, device=device, dtype=dtype) + + if isinstance(dt, float): + dt = torch.tensor(dt, device=device, dtype=dtype) + + def ou_step( + state: ThermodynamicIntegrationMDState, + dt: torch.Tensor, + kT: torch.Tensor, + gamma: torch.Tensor, + ) -> ThermodynamicIntegrationMDState: + """Apply stochastic noise and friction for Langevin dynamics. + + This function implements the Ornstein-Uhlenbeck process for Langevin dynamics, + applying random noise and friction forces to particle momenta. The noise amplitude + is chosen to satisfy the fluctuation-dissipation theorem, ensuring proper + sampling of the canonical ensemble at temperature kT. + + Args: + state (ThermodynamicIntegrationMDState): Current system state containing + positions, momenta, etc. + dt (torch.Tensor): Integration timestep, either scalar or shape [n_batches] + kT (torch.Tensor): Target temperature in energy units, either scalar or + with shape [n_batches] + gamma (torch.Tensor): Friction coefficient controlling noise strength, + either scalar or with shape [n_batches] + + Returns: + ThermodynamicIntegrationMDState: Updated state with new momenta + after stochastic step + + Notes: + - Implements the "O" step in the BAOAB Langevin integration scheme + - Uses Ornstein-Uhlenbeck process for correct thermal sampling + - Noise amplitude scales with sqrt(mass) for equipartition + - Preserves detailed balance through fluctuation-dissipation relation + - The equation implemented is: + p(t+dt) = c1*p(t) + c2*sqrt(m)*N(0,1) + where c1 = exp(-gamma*dt) and c2 = sqrt(kT*(1-c1²)) + """ + c1 = torch.exp(-gamma * dt) + + if isinstance(kT, torch.Tensor) and len(kT.shape) > 0: + # kT is a tensor with shape (n_batches,) + kT = kT[state.system_idx] + + # Index c1 and c2 with state.system_idx to align shapes with state.momenta + if isinstance(c1, torch.Tensor) and len(c1.shape) > 0: + c1 = c1[state.system_idx] + + c2 = torch.sqrt(kT * (1 - c1**2)).unsqueeze(-1) + + # Generate random noise from normal distribution + noise = torch.randn_like(state.momenta, device=state.device, dtype=state.dtype) + new_momenta = ( + c1.unsqueeze(-1) * state.momenta + + c2 * torch.sqrt(state.masses).unsqueeze(-1) * noise + ) + state.momenta = new_momenta + return state + + def langevin_init( + state: SimState | StateDict, + lambda_: torch.Tensor, + kT: torch.Tensor = kT, + seed: int | None = seed, + ) -> ThermodynamicIntegrationMDState: + """Initialize an NVT state from input data for Langevin dynamics. + + Creates an initial state for NVT molecular dynamics by computing initial + energies and forces, and sampling momenta from a Maxwell-Boltzmann distribution + at the specified temperature. + + Args: + state (SimState | StateDict): Either a SimState object or a dictionary + containing positions, masses, cell, pbc, and other required state vars + lambda_ (torch.Tensor): Initial lambda values for each system in the batch + kT (torch.Tensor): Temperature in energy units for initializing momenta, + either scalar or with shape [n_batches] + seed (int, optional): Random seed for reproducibility + + Returns: + MDState: Initialized state for NVT integration containing positions, + momenta, forces, energy, and other required attributes + + Notes: + The initial momenta are sampled from a Maxwell-Boltzmann distribution + at the specified temperature. This provides a proper thermal initial + state for the subsequent Langevin dynamics. + """ + if not isinstance(state, SimState): + state = SimState(**state) + model_output = model(state) + momenta = getattr( + state, + "momenta", + calculate_momenta(state.positions, state.masses, state.system_idx, kT, seed), + ) + + initial_state = ThermodynamicIntegrationMDState( + positions=state.positions, + momenta=momenta, + energy=model_output["energy"], + forces=model_output["forces"], + energy_difference=model_output["energy_difference"], + energy1=model_output["energy1"], + energy2=model_output["energy2"], + lambda_=lambda_, + masses=state.masses, + cell=state.cell, + pbc=state.pbc, + system_idx=state.system_idx, + atomic_numbers=state.atomic_numbers, + ) + + return initial_state # noqa: RET504 + + def langevin_update( + state: ThermodynamicIntegrationMDState, + dt: torch.Tensor = dt, + kT: torch.Tensor = kT, + gamma: torch.Tensor = gamma, + ) -> ThermodynamicIntegrationMDState: + """Perform one complete Langevin dynamics integration step. + + This function implements the BAOAB splitting scheme for Langevin dynamics, + which provides accurate sampling of the canonical ensemble. The integration + sequence is: + 1. Half momentum update using forces (B step) + 2. Half position update using updated momenta (A step) + 3. Full stochastic update with noise and friction (O step) + 4. Half position update using updated momenta (A step) + 5. Half momentum update using new forces (B step) + + Args: + state (ThermodynamicIntegrationMDState): Current system state + containing positions, momenta, forces + dt (torch.Tensor): Integration timestep, either scalar or shape [n_batches] + kT (torch.Tensor): Target temperature in energy units, either scalar or + with shape [n_batches] + gamma (torch.Tensor): Friction coefficient for Langevin thermostat, + either scalar or with shape [n_batches] + + Returns: + ThermodynamicIntegrationMDState: Updated state after one complete Langevin + step with new positions, momenta, forces, and energy + """ + # if isinstance(gamma, float): + # gamma = torch.tensor(gamma, device=device, dtype=dtype) + + if isinstance(dt, float): + dt = torch.tensor(dt, device=device, dtype=dtype) + state = momentum_step(state, dt / 2) + state = position_step(state, dt / 2) + state = ou_step(state, dt, kT, gamma) + state = position_step(state, dt / 2) + model_output = model(state) + state.energy = model_output["energy"] + state.forces = model_output["forces"] + state.energy_difference = model_output["energy_difference"] + state.energy1 = model_output["energy1"] + state.energy2 = model_output["energy2"] + + return momentum_step(state, dt / 2) + + return langevin_init, langevin_update + + +def run_non_equilibrium_md( # noqa: C901 PLR0915 + system: Any, + model_a: torch.nn.Module, + model_b: torch.nn.Module, + save_dir: str, + integrator: Callable, + *, + n_steps: int = 1000, + lambda_schedule: str | Callable = "linear", + reverse: bool = False, + temperature: float = 300.0, + timestep: float = 0.002, + pbar: bool | dict[str, Any] = False, + trajectory_reporter: ts.TrajectoryReporter | None = None, + step_frequency: int = 1, + autobatcher: bool = False, + state_frequency: int = 50, + **integrator_kwargs, +) -> ts.SimState: + """Run non-equilibrium molecular dynamics simulation. + + Args: + system: Initial system state, possibly batched + model_a: First model for thermodynamic integration + model_b: Second model for thermodynamic integration + save_dir: Directory to save trajectory files + integrator: Integration function + n_steps: Number of simulation steps + lambda_schedule: Lambda schedule type ("linear", "quadratic", "paper") + reverse: Reverse the Lambda schedule for backward TI + for non symmetric lambda paths + temperature: Temperature for simulation + timestep: Integration timestep + pbar (bool | dict[str, Any], optional): Show a progress bar. + Only works with an autobatcher in interactive shell. If a dict is given, + it's passed to `tqdm` as kwargs. + trajectory_reporter: Reporter for trajectory data + step_frequency: Frequency for reporting steps + autobatcher: Whether to use automatic batching + state_frequency: Frequency for state reporting + **integrator_kwargs: Additional integrator arguments + + Returns: + Final simulation state + """ + unit_system = UnitSystem.metal + + # Validate lambda schedule + if isinstance(lambda_schedule, str): + if lambda_schedule not in LAMBDA_SCHEDULES: + raise ValueError( + f"Unknown lambda schedule: {lambda_schedule}. " + f"Available: {list(LAMBDA_SCHEDULES.keys())}" + ) + schedule_fn = LAMBDA_SCHEDULES[lambda_schedule] + + if isinstance(lambda_schedule, Callable): + schedule_fn = lambda_schedule + + def lambda_schedule(step: int) -> float: + if reverse: + return schedule_fn(n_steps - 1 - step, n_steps - 1) + return schedule_fn(step, n_steps - 1) + + # Ensure system is a single system (not batched) + if isinstance(system, list): + raise TypeError("system should be a single system, not a list. ") + + model = MixedModel( + model1=model_a, + model2=model_b, + device=model_b.device, + dtype=model_b.dtype, + ) + state: SimState = initialize_state(system, model.device, model.dtype) + dtype, device = state.dtype, state.device + kT = ( + torch.as_tensor(temperature, dtype=dtype, device=device) * unit_system.temperature + ) + + # Create filenames for trajectory files + filenames = [ + os.path.join(save_dir, f"trajectory_steered_{replica_idx}.h5") + for replica_idx in range(state.n_systems) + ] + + trajectory_reporter = ts.TrajectoryReporter( + filenames=filenames, + state_frequency=state_frequency, + prop_calculators={ + step_frequency: { + "energy_diff": lambda state: state.energy_difference, + "energy": lambda state: state.energy, + "energy1": lambda state: state.energy1, + "energy2": lambda state: state.energy2, + "lambda_": lambda state: state.lambda_, + }, + 10: { + "temperature": lambda state: ts.quantities.calc_temperature( + masses=state.masses, + momenta=state.momenta, + system_idx=state.system_idx, + ) + }, + }, + ) + + if not kT.ndim == 0: + raise TypeError("temperature must be a single float value.") + + init_fn, update_fn = integrator( + model=model, + kT=kT, + dt=torch.tensor(timestep * unit_system.time, dtype=dtype, device=device), + **integrator_kwargs, + ) + + # batch_iterator will be a list if autobatcher is False + batch_iterator = _configure_batches_iterator(model, state, autobatcher) + trajectory_reporter = _configure_reporter( + trajectory_reporter, + properties=["kinetic_energy", "potential_energy", "temperature"], + ) + + final_states: list[SimState] = [] + log_filenames = trajectory_reporter.filenames if trajectory_reporter else None + + tqdm_pbar = None + if pbar and autobatcher: + pbar_kwargs = pbar if isinstance(pbar, dict) else {} + pbar_kwargs.setdefault("desc", "Integrate") + pbar_kwargs.setdefault("disable", None) + tqdm_pbar = tqdm(total=state.n_systems, **pbar_kwargs) + + for state, batch_indices in batch_iterator: + # Initialize lambda values based on batch indices + lambda_values = torch.ones( + state.n_systems, dtype=dtype, device=device + ) * lambda_schedule(0) + state = init_fn(state, lambda_=lambda_values, kT=kT) + + # set up trajectory reporters + if autobatcher and trajectory_reporter: + # we must remake the trajectory reporter for each batch + trajectory_reporter.load_new_trajectories( + filenames=[log_filenames[i] for i in batch_indices] + ) + + # Thermodynamic integration phase + ti_bar = tqdm( + range(1, n_steps + 1), + desc="TI Integration", + disable=not pbar, + mininterval=0.5, + ) + + for step in ti_bar: + # Calculate lambda using the selected schedule + lambda_value = lambda_schedule(step - 1) + + # Update lambda values + if len(batch_indices) > 0: + new_lambdas = torch.full_like( + batch_indices, lambda_value, dtype=dtype, device=device + ) + else: + new_lambdas = torch.full( + (state.n_systems,), lambda_value, dtype=dtype, device=device + ) + + state.lambda_ = new_lambdas + + # Update state + state = update_fn(state, kT=kT) + + if trajectory_reporter: + trajectory_reporter.report(state, step, model=model) + + # finish the trajectory reporter + final_states.append(state) + if tqdm_pbar: + tqdm_pbar.update(state.n_batches) + + if trajectory_reporter: + trajectory_reporter.finish() + + if isinstance(batch_iterator, BinningAutoBatcher): + reordered_states = batch_iterator.restore_original_order(final_states) + return concatenate_states(reordered_states) + + return state + + +def run_equilibrium_md( # noqa: C901 + system: Any, + model_a: torch.nn.Module, + model_b: torch.nn.Module, + lambdas: torch.Tensor, + save_dir: str, + integrator: Callable, + *, + n_steps: int = 1000, + temperature: float = 300.0, + timestep: float = 0.002, + pbar: bool | dict[str, Any] = False, + trajectory_reporter: ts.TrajectoryReporter | None = None, + step_frequency: int = 1, + filenames: str | None = None, + autobatcher: bool = False, + state_frequency: int = 50, + **integrator_kwargs, +) -> ts.SimState: + """Run equilibrium molecular dynamics simulation. + + Args: + system: Initial system state, possibly batched + model_a: First model for thermodynamic integration + model_b: Second model for thermodynamic integration + lambdas: Tensor of lambda values for each system in the batch + save_dir: Directory to save trajectory files + integrator: Integration function + n_steps: Number of simulation steps + reverse: Reverse the Lambda schedule for backward TI for + non symmetric lambda paths + temperature: Temperature for simulation + timestep: Integration timestep + pbar (bool | dict[str, Any], optional): Show a progress bar. + Only works with an autobatcher in interactive shell. If a dict is given, + it's passed to `tqdm` as kwargs. + trajectory_reporter: Reporter for trajectory data + step_frequency: Frequency for reporting steps + filenames: List of filenames for trajectory files. If None, defaults will be used. + Useful when running sequential thermodynamic integration + autobatcher: Whether to use automatic batching + state_frequency: Frequency for state reporting + **integrator_kwargs: Additional integrator arguments + + Returns: + Final simulation state + """ + unit_system = UnitSystem.metal + + if lambdas.ndim == 0: + lambdas = lambdas.unsqueeze(0) + + # Ensure system is a single system (not batched) + if isinstance(system, list): + raise TypeError("system should be a single system, not a list. ") + if len(lambdas) != len(lambdas.unique()): + raise ValueError( + "Lambda list must be unique.Batch of different systems is not supported yet." + ) + + model = MixedModel( + model1=model_a, + model2=model_b, + device=model_b.device, + dtype=model_b.dtype, + ) + state: SimState = initialize_state(system, model.device, model.dtype) + dtype, device = state.dtype, state.device + kT = ( + torch.as_tensor(temperature, dtype=dtype, device=device) * unit_system.temperature + ) + if state.n_systems != len(lambdas): + raise ValueError( + f"Number of systems in state ({state.n_systems}) must match " + f"number of lambda values ({len(lambdas)})." + ) + + # Create filenames for trajectory files + if filenames is None: + filenames = [ + os.path.join(save_dir, f"trajectory_lambda_{replica_idx}.h5") + for replica_idx in range(len(lambdas)) + ] + else: + filenames = [os.path.join(save_dir, filename) for filename in filenames] + + trajectory_reporter = ts.TrajectoryReporter( + filenames=filenames, + state_frequency=state_frequency, + prop_calculators={ + step_frequency: { + "energy_diff": lambda state: state.energy_difference, + "energy": lambda state: state.energy, + # "energy1": lambda state: state.energy1, + # "energy2": lambda state: state.energy2, + "lambda_": lambda state: state.lambda_, + }, + 10: { + "temperature": lambda state: ts.quantities.calc_temperature( + masses=state.masses, + momenta=state.momenta, + system_idx=state.system_idx, + ) + }, + }, + ) + + if not kT.ndim == 0: + raise TypeError("temperature must be a single float value.") + + init_fn, update_fn = integrator( + model=model, + kT=kT, + dt=torch.tensor(timestep * unit_system.time, dtype=dtype, device=device), + **integrator_kwargs, + ) + + # batch_iterator will be a list if autobatcher is False + batch_iterator = _configure_batches_iterator(model, state, autobatcher) + trajectory_reporter = _configure_reporter( + trajectory_reporter, + properties=["kinetic_energy", "potential_energy", "temperature"], + ) + + final_states: list[SimState] = [] + log_filenames = trajectory_reporter.filenames if trajectory_reporter else None + + tqdm_pbar = None + if pbar and autobatcher: + pbar_kwargs = pbar if isinstance(pbar, dict) else {} + pbar_kwargs.setdefault("desc", "Integrate") + pbar_kwargs.setdefault("disable", None) + tqdm_pbar = tqdm(total=state.n_systems, **pbar_kwargs) + + for state, batch_indices in batch_iterator: + state = init_fn(state, lambda_=lambdas, kT=kT) + + # set up trajectory reporters + if autobatcher and trajectory_reporter: + # we must remake the trajectory reporter for each batch + trajectory_reporter.load_new_trajectories( + filenames=[log_filenames[i] for i in batch_indices] + ) + + # Thermodynamic integration phase + ti_bar = tqdm( + range(1, n_steps + 1), + desc="TI Integration", + disable=not pbar, + mininterval=0.5, + ) + + for step in ti_bar: + # Update state + state = update_fn(state, kT=kT) + + if trajectory_reporter: + trajectory_reporter.report(state, step, model=model) + + # finish the trajectory reporter + final_states.append(state) + if tqdm_pbar: + tqdm_pbar.update(state.n_batches) + + if trajectory_reporter: + trajectory_reporter.finish() + + if isinstance(batch_iterator, BinningAutoBatcher): + reordered_states = batch_iterator.restore_original_order(final_states) + return concatenate_states(reordered_states) + + return state diff --git a/torch_sim/transforms.py b/torch_sim/transforms.py index 65947e8d..6e6e6b83 100644 --- a/torch_sim/transforms.py +++ b/torch_sim/transforms.py @@ -242,6 +242,64 @@ def minimum_image_displacement( return torch.einsum("ij,...j->...i", cell, dr_frac) +def minimum_image_displacement_batched( + dr: torch.Tensor, cell: torch.Tensor, system_idx: torch.Tensor, *, pbc: bool = True +) -> torch.Tensor: + """Apply minimum image convention to displacement vectors with batched systems. + + Args: + dr (torch.Tensor): Displacement vectors [n_atoms, 3] or [n_atoms, N, 3]. + cell (torch.Tensor): Unit cell matrix [n_systems, 3, 3]. + system_idx (torch.Tensor): Tensor of shape (n_atoms,) + containing system indices for each atom. + pbc (bool): Whether to apply periodic boundary conditions. + + Returns: + torch.Tensor: Minimum image displacement vectors with same shape as input. + """ + if not pbc: + return dr + + # Validate inputs + if not torch.is_floating_point(dr) or not torch.is_floating_point(cell): + raise TypeError( + "Displacement vectors \ + and lattice vectors must be floating point tensors." + ) + + if dr.shape[-1] != cell.shape[-1]: + raise ValueError("Displacement dimensionality must match lattice vectors.") + + # Get unique system indices and counts + unique_systems = torch.unique(system_idx) + n_systems = len(unique_systems) + + if n_systems != cell.shape[0]: + raise ValueError( + f"Number of unique systems ({n_systems}) doesn't " + f"match number of cells ({cell.shape[0]})" + ) + + # Efficient approach without explicit loops + # Get the cell for each atom based on its system index + B = torch.linalg.inv(cell) # Shape: (n_systems, 3, 3) + B_per_atom = B[system_idx] # Shape: (n_atoms, 3, 3) + + # Transform to fractional coordinates: f = B·r + # For each atom, multiply its position by its system's inverse cell matrix + frac_coords = torch.bmm(B_per_atom, dr.unsqueeze(2)).squeeze(2) + + # Round to nearest integer to apply minimum image convention + wrapped_frac = frac_coords - torch.round(frac_coords) + + # Transform back to real space: r = A·f + # Get the cell for each atom based on its system index + cell_per_atom = cell[system_idx] # Shape: (n_atoms, 3, 3) + + # For each atom, multiply its wrapped fractional coords by its system's cell matrix + return torch.bmm(cell_per_atom, wrapped_frac.unsqueeze(2)).squeeze(2) + + def get_pair_displacements( *, positions: torch.Tensor, @@ -1172,3 +1230,46 @@ def safe_mask( """ masked = torch.where(mask, operand, torch.zeros_like(operand)) return torch.where(mask, fn(masked), torch.full_like(operand, placeholder)) + + +def unwrap_positions(pos: torch.Tensor, box: torch.Tensor) -> torch.Tensor: + """Unwrap wrapped positions into continuous coordinates. + + Parameters + ---------- + pos : (T, N, 3) tensor + Wrapped cartesian positions + box : (3,3) or (T,3,3) tensor + Box matrix (orthorhombic: diagonal; triclinic: full 3x3). + If constant, pass shape (3,3). If time-dependent, pass (T,3,3). + + Returns: + ------- + unwrapped_pos : (T, N, 3) tensor + Unwrapped cartesian positions + """ + box = box.squeeze() + if box.ndim == 2: # constant box + inv_box = torch.inverse(box) # (3,3) + frac = torch.matmul(pos, inv_box.T) # (T,N,3) + dfrac = frac[1:] - frac[:-1] # (T-1,N,3) + dfrac_corrected = dfrac - torch.round(dfrac) + dcart = torch.matmul(dfrac_corrected, box.T) # (T-1,N,3) + + elif box.ndim == 3: # time-dependent box + inv_box = torch.inverse(box) # (T,3,3) + # fractional coords: (T,N,3) = (T,3,3) @ (T,N,3) + frac = torch.einsum("tij,tnj->tni", inv_box, pos) + dfrac = frac[1:] - frac[:-1] # (T-1,N,3) + dfrac_corrected = dfrac - torch.round(dfrac) + dcart = torch.einsum("tij,tnj->tni", box[:-1], dfrac_corrected) + + else: + raise ValueError("box must be shape (3,3) or (T,3,3)") + + # cumulative reconstruction + unwrapped = torch.empty_like(pos) + unwrapped[0] = pos[0] + unwrapped[1:] = torch.cumsum(dcart, dim=0) + unwrapped[0] + + return unwrapped diff --git a/torch_sim/workflows/free_energy_prediction.py b/torch_sim/workflows/free_energy_prediction.py new file mode 100644 index 00000000..1e5a033b --- /dev/null +++ b/torch_sim/workflows/free_energy_prediction.py @@ -0,0 +1,639 @@ +"""Free energy prediction workflows using thermodynamic integration. + +This module provides complete workflows for free energy prediction using +thermodynamic integration methods with Einstein model references. +""" + +import logging +import os +from collections.abc import Callable + +import torch + +import torch_sim as ts +from torch_sim.integrators.nvt import nvt_langevin +from torch_sim.models.einstein import EinsteinModel +from torch_sim.models.interface import ModelInterface +from torch_sim.state import SimState +from torch_sim.steered_md import ( + nvt_langevin_thermodynamic_integration, + run_equilibrium_md, + run_non_equilibrium_md, +) +from torch_sim.units import BaseConstant, UnitConversion + + +logger = logging.getLogger(__name__) + + +def compute_free_energy_jarzynski(work: torch.Tensor, temperature: float) -> torch.Tensor: + r"""Compute free energy difference using Jarzynski equality. + + \Delta_F = -kT ln < exp(-W/kT) > = -kT ln (1/N sum_i exp(-W_i/kT)) + + Uses logsumexp for numerical stability. + + Args: + work: Tensor of shape [n_trajectories] with work values. + temperature: Temperature in K. + + Returns: + free_energy: Tensor with free energy difference at each step. + """ + kB = BaseConstant.k_B / UnitConversion.eV_to_J # Boltzmann constant in eV/K + beta = 1 / (kB * temperature) + n_traj = torch.tensor(work.shape[0], device=work.device) + return -torch.logsumexp(-beta * work, dim=0) / beta + torch.log(n_traj) / beta + + +def compute_work_steered_md( + energy_difference: torch.Tensor, lamdbas: torch.Tensor +) -> torch.Tensor: + """Compute work done during steered MD. + + Args: + energy_difference: Tensor of shape [n_steps] with energy differences. + lamdbas: Tensor of shape [n_steps] with lambda values. + + Returns: + work: Tensor of shape [n_trajectories] with work values. + """ + delta_lambda = lamdbas[1:] - lamdbas[:-1] + return torch.sum(energy_difference[:-1] * delta_lambda, dim=-1) + + +# TODO: modify to output one frequency per atom type instead of per atom +def compute_einstein_frequencies_from_nvt( + system: SimState, + model: ModelInterface, + temperature: float, + save_dir: str, + n_steps: int = 1000, + timestep: float = 0.001, + filename: str = "nvt_frequency_calc.h5", +) -> torch.Tensor: + """Compute Einstein model frequencies from NVT simulation square deviations. + + Uses equipartition theorem: omega = sqrt(3 k_B T / (m * )) + + Args: + system: Simulation state. + model: Model to use for forces and energy. + temperature: Temperature in K. + save_dir: Directory to save results. + n_steps: Number of NVT steps for frequency calculation. + timestep: Timestep for NVT simulation. + filename: Name of the file to save results. + + Returns: + frequencies: Tensor with Einstein frequencies for each atom. + """ + filename = os.path.join(save_dir, filename) + reporter = ts.TrajectoryReporter( + filenames=[filename], + state_frequency=10, + ) + + # Run NVT simulation + _ = ts.integrate( + system, + model, + integrator=nvt_langevin, + n_steps=n_steps, + timestep=timestep, + temperature=temperature, + trajectory_reporter=reporter, + ) + + # Load trajectory and compute square deviations + with ts.TorchSimTrajectory(filename, mode="r") as traj: + positions = traj.get_array("positions") # [n_frames, n_atoms, 3] + positions = torch.from_numpy(positions).to( + device=system.device, dtype=system.dtype + ) + + # Compute average square deviation from mean position + unwrapped_positions = ts.transforms.unwrap_positions(positions, system.cell[0]) + square_deviations = unwrapped_positions.var(dim=0).sum(dim=-1) # (n_particles) + + # Compute frequencies using equipartition theorem + # omega = sqrt(3/2 k_B T / (1/2 * m * )) + kB = BaseConstant.k_B / UnitConversion.eV_to_J # Boltzmann constant in eV/K + + return torch.sqrt(3 * kB * temperature / (system.masses * square_deviations)) + + +# Complete workflow implementations for free energy prediction using +# thermodynamic integration +# +# Two main approaches are implemented: +# +# 1. Forward TI with Einstein reference: +# - Compute Einstein model frequencies from NVT simulation of reference model +# - Run multiple forward TI trajectories from Einstein model to target model +# - Compute free energy difference using Jarzynski equality +# Reference: +# - Jarzynski equality: +# 2. Forward-backward TI: +# - Run forward TI from model_a to model_b +# - Equilibrate at model_b +# - Run backward TI from model_b to model_a (with reverse=True) +# - Compute free energy using Jarzynski equality and adiabatic switching method +# Reference: +# - de Koning, Maurice, and A. Antonelli. +# "Adiabatic switching applied to realistic crystalline solids: Vacancy-formation +# free energy in copper." Physical Review B 55.2 (1997): 735. +# +# Inspiration and sources: +# - https://calorine.materialsmodeling.org/get_started/free_energy_tutorial.html +# - Freitas, Rodrigo, Mark Asta, and Maurice De Koning. +# "Nonequilibrium free-energy calculation of solids using LAMMPS." +# Computational Materials Science 112 (2016): 333-341. +# +# Example usage: +# +# # Forward TI with Einstein reference +# from torch_sim.workflows.free_energy_workflows import ( +# run_forward_ti_with_einstein_workflow +# ) +# results = run_forward_ti_with_einstein_workflow( +# system=my_system, +# model_a=reference_model, # Used to compute Einstein frequencies +# model_b=target_model, +# temperature=300.0, +# save_dir="./ti_results", +# n_trajectories=10, +# n_ti_steps=1000 +# ) +# +# # Forward-backward TI +# from torch_sim.workflows.free_energy_workflows import run_forward_backward_ti_workflow +# results = run_forward_backward_ti_workflow( +# system=my_system, +# model_a=model_a, +# model_b=model_b, +# temperature=300.0, +# save_dir="./ti_results", +# n_trajectories=10, +# n_ti_steps=1000, +# n_equil_steps=500 +# ) +# +# Both workflows return dictionaries with: +# - free energy differences +# - work distributions +# - trajectory data +# - statistical analysis (mean, std dev) +# + + +def run_forward_ti_workflow_from_einstein( + system: SimState, + model: ModelInterface, + temperature: float, + save_dir: str, + *, + n_trajectories: int = 10, + n_steps_frequency: int = 1000, + n_ti_steps: int = 1000, + timestep: float = 0.002, + lambda_schedule: str | Callable = "linear", +) -> dict[str, torch.Tensor]: + """Run standard forward thermodynamic integration workflow. + + Args: + system: Initial system state. + model: Target model. + temperature: Temperature in K. + save_dir: Directory to save trajectory files. + n_trajectories: Number of TI trajectories. + n_steps_frequency: Number of NVT steps to compute frequencies for Einstein model. + n_ti_steps: Number of TI steps per trajectory. + timestep: Integration timestep for TI. + lambda_schedule: Lambda schedule ("linear" or "quadratic") or a custom function. + + Returns: + Dictionary with free energy results and trajectory data. + """ + os.makedirs(save_dir, exist_ok=True) + + if system.n_systems != 1: + raise NotImplementedError("Only single system input is supported.") + + # Find frequencies for Einstein reference model + frequencies = compute_einstein_frequencies_from_nvt( + system=system, + model=model, + temperature=temperature, + save_dir=save_dir, + n_steps=n_steps_frequency, + timestep=timestep, + ) + + logger.info( + "Einstein frequencies (eV^(0.5)/A/amu^(0.5)): %s", frequencies.cpu().numpy() + ) + + # Prepare batched systems for multiple trajectories + systems = [system.clone() for _ in range(n_trajectories)] + batched_system = ts.state.concatenate_states(systems) + + # Define reference Einstein model + einstein_model = EinsteinModel.from_atom_and_frequencies( + atom=batched_system, + frequencies=frequencies.repeat(n_trajectories), + device=batched_system.device, + dtype=batched_system.dtype, + ) + einstein_model.compile() + + batched_system = einstein_model.sample(batched_system, temperature) + + # Run forward TI + _ = run_non_equilibrium_md( + system=batched_system, + model_a=einstein_model, + model_b=model, + save_dir=save_dir, + integrator=nvt_langevin_thermodynamic_integration, + n_steps=n_ti_steps, + lambda_schedule=lambda_schedule, + reverse=False, + temperature=temperature, + timestep=timestep, + pbar=True, + ) + + # Load trajectory data and compute work values + work_values = [] + energy_differences = [] + lambdas = [] + + for i in range(n_trajectories): + filename = os.path.join(save_dir, f"trajectory_steered_{i}.h5") + with ts.TorchSimTrajectory(filename, mode="r") as traj: + energy_diff = ( + torch.from_numpy(traj.get_array("energy_diff")) + .to(batched_system.device) + .squeeze() + ) + lambda_vals = ( + torch.from_numpy(traj.get_array("lambda_")) + .to(batched_system.device) + .squeeze() + ) + + energy_diff_per_atom = energy_diff / batched_system.n_atoms_per_system[i] + + # Compute work for this trajectory + work = compute_work_steered_md(energy_diff_per_atom, lambda_vals).item() + + work_values.append(work) + energy_differences.append(energy_diff_per_atom) + lambdas.append(lambda_vals) + + work_tensor = torch.tensor(work_values).to(batched_system.device) + + # Compute free energy using Jarzynski equality + free_energy = compute_free_energy_jarzynski(work_tensor, temperature) + free_energy_einstein = ( + einstein_model.get_free_energy(temperature)["free_energy"][0] + / batched_system.n_atoms_per_system[i] + ) + + return { + "free_energy": free_energy + free_energy_einstein, + "free_energy_einstein": free_energy_einstein, + "free_energy_difference": free_energy, + "work_values": work_tensor, + "mean_work": work_tensor.mean(), + "std_work": work_tensor.std(), + "energy_differences": torch.stack(energy_differences), + "lambda_values": torch.stack(lambdas), + } + + +def run_forward_backward_ti_workflow_from_einstein( + system: SimState, + model: ModelInterface, + temperature: float, + save_dir: str, + *, + n_trajectories: int = 10, + n_steps_frequency: int = 1000, + n_ti_steps: int = 1000, + n_equil_steps: int = 500, + timestep: float = 0.002, + lambda_schedule: str = "linear", +) -> dict[str, torch.Tensor]: + """Run forward-backward thermodynamic integration workflow. + + Workflow: + 1. Run forward TI from Einstein to model + 2. Equilibrate at model for n_equil_steps + 3. Run backward TI from model back to Einstein + 4. Compute free energy using linear response theory + + Args: + system: Initial system state. + model: Target model. + temperature: Temperature in K. + save_dir: Directory to save trajectory files. + n_trajectories: Number of TI trajectories. + n_steps_frequency: Number of NVT steps to compute frequencies for Einstein model. + n_ti_steps: Number of TI steps per trajectory. + n_equil_steps: Number of equilibration steps at model_b. + timestep: Integration timestep. + lambda_schedule: Lambda schedule: + "linear", "quadratic", "cubic", "lammps" or a custom function. + + Returns: + Dictionary with free energy results and trajectory data. + """ + os.makedirs(save_dir, exist_ok=True) + + if system.n_systems != 1: + raise NotImplementedError("Only single system input is supported.") + + # Find frequencies for Einstein reference model + frequencies = compute_einstein_frequencies_from_nvt( + system=system, + model=model, + temperature=temperature, + save_dir=save_dir, + n_steps=n_steps_frequency, + timestep=timestep, + ) + + # Prepare batched systems for multiple trajectories + systems = [system.clone() for _ in range(n_trajectories)] + batched_system = ts.state.concatenate_states(systems) + + # Define reference Einstein model + einstein_model = EinsteinModel.from_atom_and_frequencies( + atom=batched_system, + frequencies=frequencies.repeat(n_trajectories), + device=batched_system.device, + dtype=batched_system.dtype, + ) + + batched_system = einstein_model.sample(batched_system, temperature) + + logger.info("Running %d forward TI trajectories...", n_trajectories) + + # Create separate save directories for forward and backward + forward_dir = os.path.join(save_dir, "forward") + backward_dir = os.path.join(save_dir, "backward") + os.makedirs(forward_dir, exist_ok=True) + os.makedirs(backward_dir, exist_ok=True) + + # Run forward TI (A -> B) + forward_final_state = run_non_equilibrium_md( + system=batched_system, + model_a=einstein_model, + model_b=model, + save_dir=forward_dir, + integrator=nvt_langevin_thermodynamic_integration, + n_steps=n_ti_steps, + lambda_schedule=lambda_schedule, + reverse=False, + temperature=temperature, + timestep=timestep, + pbar=True, + ) + + # Equilibrate at model_b + logger.info("Equilibrating at target model for %d steps...", n_equil_steps) + equilibrated_state = ts.integrate( + system=forward_final_state, + model=model, + integrator=nvt_langevin, + n_steps=n_equil_steps, + temperature=temperature, + timestep=timestep, + ) + + logger.info("Running %d backward TI trajectories...", n_trajectories) + + # Run backward TI (B -> A) + _ = run_non_equilibrium_md( + system=equilibrated_state, + model_a=einstein_model, + model_b=model, + save_dir=backward_dir, + integrator=nvt_langevin_thermodynamic_integration, + n_steps=n_ti_steps, + lambda_schedule=lambda_schedule, + reverse=True, # This is the key difference + temperature=temperature, + timestep=timestep, + pbar=True, + ) + + # Compute work values for both directions + forward_work_values = [] + backward_work_values = [] + + for i in range(n_trajectories): + # Forward work + forward_filename = os.path.join(forward_dir, f"trajectory_steered_{i}.h5") + with ts.TorchSimTrajectory(forward_filename, mode="r") as traj: + energy_diff = ( + torch.from_numpy(traj.get_array("energy_diff")) + .to(batched_system.device) + .squeeze() + ) + lambda_vals = ( + torch.from_numpy(traj.get_array("lambda_")) + .to(batched_system.device) + .squeeze() + ) + + energy_diff_per_atom = energy_diff / batched_system.n_atoms_per_system[i] + + forward_work = compute_work_steered_md( + energy_diff_per_atom, lambda_vals + ).item() + forward_work_values.append(forward_work) + + # Backward work + backward_filename = os.path.join(backward_dir, f"trajectory_steered_{i}.h5") + with ts.TorchSimTrajectory(backward_filename, mode="r") as traj: + energy_diff = ( + torch.from_numpy(traj.get_array("energy_diff")) + .to(batched_system.device) + .squeeze() + ) + lambda_vals = ( + torch.from_numpy(traj.get_array("lambda_")) + .to(batched_system.device) + .squeeze() + ) + + energy_diff_per_atom = energy_diff / batched_system.n_atoms_per_system[i] + + backward_work = compute_work_steered_md( + energy_diff_per_atom, lambda_vals + ).item() + backward_work_values.append(backward_work) + + forward_work = torch.tensor(forward_work_values).to(batched_system.device) + backward_work = torch.tensor(backward_work_values).to(batched_system.device) + + forward_free_energy = compute_free_energy_jarzynski(forward_work, temperature) + backward_free_energy = -compute_free_energy_jarzynski(backward_work, temperature) + + free_energy_difference = (forward_work.mean() - backward_work.mean()) / 2 + free_energy_einstein = ( + einstein_model.get_free_energy(temperature)["free_energy"][0] + / batched_system.n_atoms_per_system[i] + ) + + return { + "forward_free_energy": forward_free_energy + free_energy_einstein, + "backward_free_energy": backward_free_energy + free_energy_einstein, + "free_energy": free_energy_difference + free_energy_einstein, + "free_energy_difference": free_energy_difference, + "free_energy_einstein": free_energy_einstein, + "forward_work": forward_work, + "backward_work": backward_work, + "forward_mean_work": forward_work.mean(), + "forward_std_work": forward_work.std(), + "backward_mean_work": backward_work.mean(), + "backward_std_work": backward_work.std(), + } + + +def run_thermodynamic_integration_from_einstein( + system: SimState, + model: ModelInterface, + temperature: float, + save_dir: str, + *, + run_parallel: bool = False, + lambdas: torch.Tensor, + n_steps_frequency: int = 1000, + n_ti_steps: int = 1000, + timestep: float = 0.002, +) -> dict[str, torch.Tensor]: + """Run standard forward thermodynamic integration workflow. + + Args: + system: Initial system state. + model: Target model. + temperature: Temperature in K. + save_dir: Directory to save trajectory files. + run_parallel: Whether to run all trajectories in parallel from same initial state + or sequentially, using final state of previous trajectory as initial state + for next trajectory. + lambdas: tensor with lambda values to use for TI. + n_steps_frequency: Number of NVT steps to compute frequencies for Einstein model. + n_ti_steps: Number of TI steps per trajectory. + timestep: Integration timestep for TI. + + Returns: + Dictionary with free energy results and trajectory data. + """ + os.makedirs(save_dir, exist_ok=True) + + if system.n_systems != 1: + raise NotImplementedError("Only single system input is supported.") + + # Find frequencies for Einstein reference model + frequencies = compute_einstein_frequencies_from_nvt( + system=system, + model=model, + temperature=temperature, + save_dir=save_dir, + n_steps=n_steps_frequency, + timestep=timestep, + ) + + if run_parallel: + # Prepare batched systems for multiple trajectories + systems = [system.clone() for _ in range(len(lambdas))] + batched_system = ts.state.concatenate_states(systems) + + # Define reference Einstein model + einstein_model = EinsteinModel.from_atom_and_frequencies( + atom=batched_system, + frequencies=frequencies.repeat(len(lambdas)), + device=batched_system.device, + dtype=batched_system.dtype, + ) + + # Run forward TI + _ = run_equilibrium_md( + system=batched_system, + model_a=einstein_model, + model_b=model, + lambdas=lambdas, + save_dir=save_dir, + integrator=nvt_langevin_thermodynamic_integration, + n_steps=n_ti_steps, + temperature=temperature, + timestep=timestep, + pbar=True, + ) + else: + # Define reference Einstein model + einstein_model = EinsteinModel.from_atom_and_frequencies( + atom=system, + frequencies=frequencies, + device=system.device, + dtype=system.dtype, + ) + batched_system = system.clone() + for i, lambda_ in enumerate(lambdas): + # Run forward TI + batched_system = run_equilibrium_md( + system=batched_system, + model_a=einstein_model, + model_b=model, + lambdas=lambda_, + save_dir=save_dir, + integrator=nvt_langevin_thermodynamic_integration, + n_steps=n_ti_steps, + filenames=[f"trajectory_lambda_{i}.h5"], + temperature=temperature, + timestep=timestep, + pbar=True, + ) + + # Load trajectory data and compute work values + work_values = [] + + for i in range(len(lambdas)): + filename = os.path.join(save_dir, f"trajectory_lambda_{i}.h5") + with ts.TorchSimTrajectory(filename, mode="r") as traj: + energy_diff = ( + torch.from_numpy(traj.get_array("energy_diff")) + .to(batched_system.device) + .squeeze() + ) + + energy_diff_per_atom = energy_diff / batched_system.n_atoms_per_system[0] + + # Compute work for this trajectory + work = energy_diff_per_atom.mean().item() + + work_values.append(work) + + work_tensor = torch.tensor(work_values).to(batched_system.device) + + # integrate work values over lambda, e.g. using trapezoidal rule + free_energy_difference = torch.trapezoid(work_tensor, lambdas).item() + + # Compute free energy using Jarzynski equality + free_energy_einstein = ( + einstein_model.get_free_energy(temperature)["free_energy"][0] + / batched_system.n_atoms_per_system[0] + ) + + return { + "free_energy": free_energy_difference + free_energy_einstein, + "free_energy_einstein": free_energy_einstein, + "free_energy_difference": free_energy_difference, + "lambda_values": lambdas, + "work_values": work_tensor, + }