Skip to content

Commit 1c2288e

Browse files
committed
provide overloads for compute_cell_shifts instead
1 parent 9ac309a commit 1c2288e

File tree

2 files changed

+15
-28
lines changed

2 files changed

+15
-28
lines changed

torch_sim/neighbors.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -692,9 +692,7 @@ def strict_nl(
692692
if cell is None:
693693
d2 = (positions[mapping[0]] - positions[mapping[1]]).square().sum(dim=1)
694694
else:
695-
cell_shifts = transforms.compute_cell_shifts_strict(
696-
cell, shifts_idx, system_mapping
697-
)
695+
cell_shifts = transforms.compute_cell_shifts(cell, shifts_idx, system_mapping)
698696
d2 = (
699697
(positions[mapping[0]] - positions[mapping[1]] - cell_shifts)
700698
.square()

torch_sim/transforms.py

Lines changed: 14 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77

88
from collections.abc import Callable
99
from functools import wraps
10+
from typing import overload, override
1011

1112
import torch
1213
from torch.types import _dtype
@@ -535,6 +536,18 @@ def compute_distances_with_cell_shifts(
535536
return dr.norm(p=2, dim=1)
536537

537538

539+
@overload
540+
def compute_cell_shifts(
541+
cell: torch.Tensor, shifts_idx: torch.Tensor, system_mapping: torch.Tensor
542+
) -> torch.Tensor: ...
543+
544+
545+
@overload
546+
def compute_cell_shifts(
547+
cell: None, shifts_idx: torch.Tensor, system_mapping: torch.Tensor
548+
) -> None: ...
549+
550+
538551
def compute_cell_shifts(
539552
cell: torch.Tensor | None, shifts_idx: torch.Tensor, system_mapping: torch.Tensor
540553
) -> torch.Tensor | None:
@@ -564,30 +577,6 @@ def compute_cell_shifts(
564577
return cell_shifts
565578

566579

567-
def compute_cell_shifts_strict(
568-
cell: torch.Tensor, shifts_idx: torch.Tensor, system_mapping: torch.Tensor
569-
) -> torch.Tensor:
570-
"""Compute the cell shifts based on the provided indices and cell matrix.
571-
572-
This function calculates the shifts to apply to positions based on the specified
573-
indices and the unit cell matrix. It is the same as compute_cell_shifts, but
574-
cell cannot be None.
575-
576-
Args:
577-
cell (torch.Tensor): A tensor of shape (n_cells, 3, 3)
578-
representing the unit cell matrices.
579-
shifts_idx (torch.Tensor): A tensor of shape (n_shifts, 3)
580-
representing the indices for shifts.
581-
system_mapping (torch.Tensor): A tensor of shape (n_systems,)
582-
that maps the shifts to the corresponding cells.
583-
584-
Returns:
585-
torch.Tensor: A tensor of shape (n_systems, 3) containing
586-
the computed cell shifts.
587-
"""
588-
return torch.einsum("jn,jnm->jm", shifts_idx, cell.view(-1, 3, 3)[system_mapping])
589-
590-
591580
def get_fully_connected_mapping(
592581
*,
593582
i_ids: torch.Tensor,
@@ -873,7 +862,7 @@ def linked_cell( # noqa: PLR0915
873862
shifts_idx, n_atom, dim=0, output_size=n_atom * n_cell_image
874863
)
875864
batch_image = torch.zeros((shifts_idx.shape[0]), dtype=torch.long)
876-
cell_shifts = compute_cell_shifts_strict(cell.view(-1, 3, 3), shifts_idx, batch_image)
865+
cell_shifts = compute_cell_shifts(cell.view(-1, 3, 3), shifts_idx, batch_image)
877866

878867
i_ids = torch.arange(n_atom, device=device, dtype=torch.long)
879868
i_ids = i_ids.repeat(n_cell_image)

0 commit comments

Comments
 (0)