diff --git a/tests/test_quantities.py b/tests/test_quantities.py index 7513b6bd..ee17a6fd 100644 --- a/tests/test_quantities.py +++ b/tests/test_quantities.py @@ -1,13 +1,158 @@ +"""Tests for quantities module functions.""" + import pytest import torch +from numpy.testing import assert_allclose from torch._tensor import Tensor -from torch_sim import quantities +from torch_sim.quantities import ( + calc_heat_flux, + calc_kinetic_energy, + calc_kT, + calc_temperature, +) from torch_sim.units import MetalUnits -DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu") -DTYPE = torch.double +DTYPE = torch.float64 +DEVICE = torch.device("cpu") + + +class TestHeatFlux: + """Test suite for heat flux calculations.""" + + @pytest.fixture + def mock_simple_system(self) -> dict[str, torch.Tensor]: + """Simple system with known values.""" + return { + "velocities": torch.tensor( + [ + [1.0, 0.0, 0.0], + [0.0, 2.0, 0.0], + [0.0, 0.0, 3.0], + ], + device=DEVICE, + ), + "energies": torch.tensor([1.0, 2.0, 3.0], device=DEVICE), + "stress": torch.tensor( + [ + [1.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 2.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 3.0, 0.0, 0.0, 0.0], + ], + device=DEVICE, + ), + "masses": torch.ones(3, device=DEVICE), + } + + def test_unbatched_total_flux( + self, mock_simple_system: dict[str, torch.Tensor] + ) -> None: + """Test total heat flux calculation for unbatched case.""" + flux = calc_heat_flux( + momenta=None, + masses=mock_simple_system["masses"], + velocities=mock_simple_system["velocities"], + energies=mock_simple_system["energies"], + stresses=mock_simple_system["stress"], + is_virial_only=False, + ) + + # Heat flux parts should cancel out + expected = torch.zeros(3, device=flux.device) + assert_allclose(flux.cpu().numpy(), expected.cpu().numpy()) + + def test_unbatched_virial_only( + self, mock_simple_system: dict[str, torch.Tensor] + ) -> None: + """Test virial-only heat flux calculation for unbatched case.""" + virial = calc_heat_flux( + momenta=None, + masses=mock_simple_system["masses"], + velocities=mock_simple_system["velocities"], + energies=mock_simple_system["energies"], + stresses=mock_simple_system["stress"], + is_virial_only=True, + ) + + expected = -torch.tensor([1.0, 4.0, 9.0], device=virial.device) + assert_allclose(virial.cpu().numpy(), expected.cpu().numpy()) + + def test_batched_calculation(self) -> None: + """Test heat flux calculation with batched data.""" + velocities = torch.tensor( + [ + [1.0, 0.0, 0.0], + [0.0, 2.0, 0.0], + [0.0, 0.0, 3.0], + ], + device=DEVICE, + ) + energies = torch.tensor([1.0, 2.0, 3.0], device=DEVICE) + stress = torch.tensor( + [ + [1.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 2.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 3.0, 0.0, 0.0, 0.0], + ], + device=DEVICE, + ) + batch = torch.tensor([0, 0, 1], device=DEVICE) + + flux = calc_heat_flux( + momenta=None, + masses=torch.ones(3, device=DEVICE), + velocities=velocities, + energies=energies, + stresses=stress, + batch=batch, + ) + + # Each batch should cancel heat flux parts + expected = torch.zeros((2, 3), device=DEVICE) + assert_allclose(flux.cpu().numpy(), expected.cpu().numpy()) + + def test_centroid_stress(self) -> None: + """Test heat flux with centroid stress formulation.""" + velocities = torch.tensor([[1.0, 1.0, 1.0]], device=DEVICE) + energies = torch.tensor([1.0], device=DEVICE) + + # Symmetric cross-terms + stress = torch.tensor( + [[1.0, 1.0, 1.0, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5]], device=DEVICE + ) + + flux = calc_heat_flux( + momenta=None, + masses=torch.ones(1, device=DEVICE), + velocities=velocities, + energies=energies, + stresses=stress, + is_centroid_stress=True, + ) + + # Heatflux should be [-1,-1,-1] + expected = torch.full((3,), -1.0, device=DEVICE) + assert_allclose(flux.cpu().numpy(), expected.cpu().numpy()) + + def test_momenta_input(self) -> None: + """Test heat flux calculation using momenta instead.""" + momenta = torch.tensor([[1.0, 0.0, 0.0]], device=DEVICE) + masses = torch.tensor([2.0], device=DEVICE) + energies = torch.tensor([1.0], device=DEVICE) + stress = torch.tensor([[1.0, 0.0, 0.0, 0.0, 0.0, 0.0]], device=DEVICE) + + flux = calc_heat_flux( + momenta=momenta, + masses=masses, + velocities=None, + energies=energies, + stresses=stress, + ) + + # Heat flux terms should cancel out + expected = torch.zeros(3, device=DEVICE) + assert_allclose(flux.cpu().numpy(), expected.cpu().numpy()) @pytest.fixture @@ -46,14 +191,14 @@ def batched_system_data() -> dict[str, Tensor]: def test_calc_kinetic_energy_single_system(single_system_data: dict[str, Tensor]) -> None: # With velocities - ke_vel = quantities.calc_kinetic_energy( + ke_vel = calc_kinetic_energy( masses=single_system_data["masses"], velocities=single_system_data["velocities"], ) assert torch.allclose(ke_vel, single_system_data["ke"]) # With momenta - ke_mom = quantities.calc_kinetic_energy( + ke_mom = calc_kinetic_energy( masses=single_system_data["masses"], momenta=single_system_data["momenta"] ) assert torch.allclose(ke_mom, single_system_data["ke"]) @@ -63,7 +208,7 @@ def test_calc_kinetic_energy_batched_system( batched_system_data: dict[str, Tensor], ) -> None: # With velocities - ke_vel = quantities.calc_kinetic_energy( + ke_vel = calc_kinetic_energy( masses=batched_system_data["masses"], velocities=batched_system_data["velocities"], system_idx=batched_system_data["system_idx"], @@ -71,7 +216,7 @@ def test_calc_kinetic_energy_batched_system( assert torch.allclose(ke_vel, batched_system_data["ke"]) # With momenta - ke_mom = quantities.calc_kinetic_energy( + ke_mom = calc_kinetic_energy( masses=batched_system_data["masses"], momenta=batched_system_data["momenta"], system_idx=batched_system_data["system_idx"], @@ -81,26 +226,26 @@ def test_calc_kinetic_energy_batched_system( def test_calc_kinetic_energy_errors(single_system_data: dict[str, Tensor]) -> None: with pytest.raises(ValueError, match="Must pass either one of momenta or velocities"): - quantities.calc_kinetic_energy( + calc_kinetic_energy( masses=single_system_data["masses"], momenta=single_system_data["momenta"], velocities=single_system_data["velocities"], ) with pytest.raises(ValueError, match="Must pass either one of momenta or velocities"): - quantities.calc_kinetic_energy(masses=single_system_data["masses"]) + calc_kinetic_energy(masses=single_system_data["masses"]) def test_calc_kt_single_system(single_system_data: dict[str, Tensor]) -> None: # With velocities - kt_vel = quantities.calc_kT( + kt_vel = calc_kT( masses=single_system_data["masses"], velocities=single_system_data["velocities"], ) assert torch.allclose(kt_vel, single_system_data["kt"]) # With momenta - kt_mom = quantities.calc_kT( + kt_mom = calc_kT( masses=single_system_data["masses"], momenta=single_system_data["momenta"] ) assert torch.allclose(kt_mom, single_system_data["kt"]) @@ -108,7 +253,7 @@ def test_calc_kt_single_system(single_system_data: dict[str, Tensor]) -> None: def test_calc_kt_batched_system(batched_system_data: dict[str, Tensor]) -> None: # With velocities - kt_vel = quantities.calc_kT( + kt_vel = calc_kT( masses=batched_system_data["masses"], velocities=batched_system_data["velocities"], system_idx=batched_system_data["system_idx"], @@ -116,7 +261,7 @@ def test_calc_kt_batched_system(batched_system_data: dict[str, Tensor]) -> None: assert torch.allclose(kt_vel, batched_system_data["kt"]) # With momenta - kt_mom = quantities.calc_kT( + kt_mom = calc_kT( masses=batched_system_data["masses"], momenta=batched_system_data["momenta"], system_idx=batched_system_data["system_idx"], @@ -125,11 +270,11 @@ def test_calc_kt_batched_system(batched_system_data: dict[str, Tensor]) -> None: def test_calc_temperature(single_system_data: dict[str, Tensor]) -> None: - temp = quantities.calc_temperature( + temp = calc_temperature( masses=single_system_data["masses"], velocities=single_system_data["velocities"], ) - kt = quantities.calc_kT( + kt = calc_kT( masses=single_system_data["masses"], velocities=single_system_data["velocities"], ) diff --git a/torch_sim/quantities.py b/torch_sim/quantities.py index a1ac0811..404d07f3 100644 --- a/torch_sim/quantities.py +++ b/torch_sim/quantities.py @@ -1,26 +1,11 @@ """Functions for computing physical quantities.""" -from typing import cast - import torch from torch_sim.state import SimState from torch_sim.units import MetalUnits -# @torch.jit.script -def count_dof(tensor: torch.Tensor) -> int: - """Count the degrees of freedom in the system. - - Args: - tensor: Tensor to count the degrees of freedom in - - Returns: - Number of degrees of freedom - """ - return tensor.numel() - - # @torch.jit.script def calc_kT( # noqa: N802 *, @@ -44,17 +29,18 @@ def calc_kT( # noqa: N802 if not ((momenta is not None) ^ (velocities is not None)): raise ValueError("Must pass either one of momenta or velocities") - if momenta is None: + if momenta is None and velocities is not None: # If velocity provided, calculate mv^2 - velocities = cast("torch.Tensor", velocities) - squared_term = (velocities**2) * masses.unsqueeze(-1) - else: + squared_term = torch.square(velocities) * masses.unsqueeze(-1) + elif momenta is not None and velocities is None: # If momentum provided, calculate v^2 = p^2/m^2 - squared_term = (momenta**2) / masses.unsqueeze(-1) + squared_term = torch.square(momenta) / masses.unsqueeze(-1) + else: + raise ValueError("Must pass either one of momenta or velocities") if system_idx is None: # Count total degrees of freedom - dof = count_dof(squared_term) + dof = squared_term.numel() return torch.sum(squared_term) / dof # Sum squared terms for each system flattened_squared = torch.sum(squared_term, dim=-1) @@ -121,10 +107,12 @@ def calc_kinetic_energy( if not ((momenta is not None) ^ (velocities is not None)): raise ValueError("Must pass either one of momenta or velocities") - if momenta is None: # Using velocities - squared_term = (velocities**2) * masses.unsqueeze(-1) - else: # Using momenta - squared_term = (momenta**2) / masses.unsqueeze(-1) + if momenta is None and velocities is not None: # Using velocities + squared_term = torch.square(velocities) * masses.unsqueeze(-1) + elif momenta is not None and velocities is None: # Using momenta + squared_term = torch.square(momenta) / masses.unsqueeze(-1) + else: + raise ValueError("Must pass either one of momenta or velocities") if system_idx is None: return 0.5 * torch.sum(squared_term) @@ -135,7 +123,10 @@ def calc_kinetic_energy( def get_pressure( - stress: torch.Tensor, kinetic_energy: torch.Tensor, volume: torch.Tensor, dim: int = 3 + stress: torch.Tensor, + kinetic_energy: float | torch.Tensor, + volume: torch.Tensor, + dim: int = 3, ) -> torch.Tensor: """Compute the pressure from the stress tensor. @@ -145,6 +136,132 @@ def get_pressure( return 1 / dim * ((2 * kinetic_energy / volume) - torch.einsum("...ii", stress)) +def calc_heat_flux( + momenta: torch.Tensor | None, + masses: torch.Tensor, + velocities: torch.Tensor | None, + energies: torch.Tensor, + stresses: torch.Tensor, + batch: torch.Tensor | None = None, + *, # Force keyword arguments for booleans + is_centroid_stress: bool = False, + is_virial_only: bool = False, +) -> torch.Tensor: + r"""Calculate the heat flux vector. + + Computes the microscopic heat flux, :math:`\mathbf{J}` + defined as: + + .. math:: + \mathbf{J} = \mathbf{J}^c + \mathbf{J}^v + + where the convective part :math:`\mathbf{J}^c` and virial part + :math:`\mathbf{J}^v` are: + + .. math:: + \mathbf{J}^c &= \sum_i \epsilon_i \mathbf{v}_i \\ + \mathbf{J}^v &= \sum_i \sum_j \mathbf{S}_{ij} \cdot \mathbf{v}_j + + where :math:`\epsilon_i` is the per-atom energy (p.e. + k.e.), + :math:`\mathbf{v}_i` is velocity, and :math:`\mathbf{S}_{ij}` is the + per-atom stress tensor. + + Args: + momenta: Particle momenta, shape (n_particles, n_dim) + masses: Particle masses, shape (n_particles,) + velocities: Particle velocities, shape (n_particles, n_dim) + energies: Per-atom energies (p.e. + k.e.), shape (n_particles,) + stresses: Per-atom stress tensor components: + - If is_centroid_stress=False: shape (n_particles, 6) for + :math:`[\sigma_{xx}, \sigma_{yy}, \sigma_{zz}, + \sigma_{xy}, \sigma_{xz}, \sigma_{yz}]` + - If is_centroid_stress=True: shape (n_particles, 9) for + :math:`[\mathbf{r}_{ix}f_{ix}, \mathbf{r}_{iy}f_{iy}, + \mathbf{r}_{iz}f_{iz}, \mathbf{r}_{ix}f_{iy}, + \mathbf{r}_{ix}f_{iz}, \mathbf{r}_{iy}f_{iz}, + \mathbf{r}_{iy}f_{ix}, \mathbf{r}_{iz}f_{ix}, + \mathbf{r}_{iz}f_{iy}]` + batch: Optional tensor indicating system membership + is_centroid_stress: Whether stress uses centroid formulation + is_virial_only: If True, returns only virial part :math:`\mathbf{J}^v` + + Returns: + Heat flux vector of shape (3,) or (n_systems, 3) + """ + if momenta is not None and velocities is not None: + raise ValueError("Must pass either momenta or velocities, not both") + if momenta is None and velocities is None: + raise ValueError("Must pass either momenta or velocities") + + # Deduce velocities + if velocities is None: + velocities = momenta / masses.unsqueeze(-1) + + convective_flux = energies.unsqueeze(-1) * velocities + + # Calculate virial flux + if is_centroid_stress: + # Centroid formulation: r_i[x,y,z] . f_i[x,y,z] + virial_x = -( + stresses[:, 0] * velocities[:, 0] # r_ix.f_ix.v_x + + stresses[:, 3] * velocities[:, 1] # r_ix.f_iy.v_y + + stresses[:, 4] * velocities[:, 2] # r_ix.f_iz.v_z + ) + virial_y = -( + stresses[:, 6] * velocities[:, 0] # r_iy.f_ix.v_x + + stresses[:, 1] * velocities[:, 1] # r_iy.f_iy.v_y + + stresses[:, 5] * velocities[:, 2] # r_iy.f_iz.v_z + ) + virial_z = -( + stresses[:, 7] * velocities[:, 0] # r_iz.f_ix.v_x + + stresses[:, 8] * velocities[:, 1] # r_iz.f_iy.v_y + + stresses[:, 2] * velocities[:, 2] # r_iz.f_iz.v_z + ) + else: + # Standard stress tensor components + virial_x = -( + stresses[:, 0] * velocities[:, 0] # s_xx.v_x + + stresses[:, 3] * velocities[:, 1] # s_xy.v_y + + stresses[:, 4] * velocities[:, 2] # s_xz.v_z + ) + virial_y = -( + stresses[:, 3] * velocities[:, 0] # s_xy.v_x + + stresses[:, 1] * velocities[:, 1] # s_yy.v_y + + stresses[:, 5] * velocities[:, 2] # s_yz.v_z + ) + virial_z = -( + stresses[:, 4] * velocities[:, 0] # s_xz.v_x + + stresses[:, 5] * velocities[:, 1] # s_yz.v_y + + stresses[:, 2] * velocities[:, 2] # s_zz.v_z + ) + + virial_flux = torch.stack([virial_x, virial_y, virial_z], dim=-1) + + if batch is None: + # All atoms + virial_sum = torch.sum(virial_flux, dim=0) + if is_virial_only: + return virial_sum + conv_sum = torch.sum(convective_flux, dim=0) + return conv_sum + virial_sum + + # All atoms in each system + n_systems = int(torch.max(batch) + 1) + virial_sum = torch.zeros( + (n_systems, 3), device=velocities.device, dtype=velocities.dtype + ) + virial_sum.scatter_add_(0, batch.unsqueeze(-1).expand(-1, 3), virial_flux) + + if is_virial_only: + return virial_sum + + conv_sum = torch.zeros( + (n_systems, 3), device=velocities.device, dtype=velocities.dtype + ) + conv_sum.scatter_add_(0, batch.unsqueeze(-1).expand(-1, 3), convective_flux) + return conv_sum + virial_sum + + def systemwise_max_force(state: SimState) -> torch.Tensor: """Compute the maximum force per system.