Skip to content

Commit 4ccee76

Browse files
committed
fixing types for creating cell shifts
1 parent 295f476 commit 4ccee76

File tree

2 files changed

+7
-11
lines changed

2 files changed

+7
-11
lines changed

torch_sim/neighbors.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -638,7 +638,7 @@ def vesin_nl(
638638
def strict_nl(
639639
cutoff: float,
640640
positions: torch.Tensor,
641-
cell: torch.Tensor,
641+
cell: torch.Tensor | None,
642642
mapping: torch.Tensor,
643643
system_mapping: torch.Tensor,
644644
shifts_idx: torch.Tensor,
@@ -656,7 +656,7 @@ def strict_nl(
656656
is used to filter the neighbor pairs based on their distances.
657657
positions (torch.Tensor): A tensor of shape (n_atoms, 3) representing
658658
the positions of the atoms.
659-
cell (torch.Tensor): Unit cell vectors according to the row vector convention,
659+
cell (torch.Tensor) | None: Unit cell vectors according to the row vector convention,
660660
i.e. `[[a1, a2, a3], [b1, b2, b3], [c1, c2, c3]]`.
661661
mapping (torch.Tensor):
662662
A tensor of shape (2, n_pairs) that specifies pairs of indices in `positions`
@@ -687,10 +687,10 @@ def strict_nl(
687687
References:
688688
- https://github.com/felixmusil/torch_nl
689689
"""
690-
cell_shifts = transforms.compute_cell_shifts(cell, shifts_idx, system_mapping)
691-
if cell_shifts is None:
690+
if cell is None:
692691
d2 = (positions[mapping[0]] - positions[mapping[1]]).square().sum(dim=1)
693692
else:
693+
cell_shifts = transforms.compute_cell_shifts(cell, shifts_idx, system_mapping)
694694
d2 = (
695695
(positions[mapping[0]] - positions[mapping[1]] - cell_shifts)
696696
.square()

torch_sim/transforms.py

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -554,13 +554,9 @@ def compute_cell_shifts(
554554
torch.Tensor: A tensor of shape (n_systems, 3) containing
555555
the computed cell shifts.
556556
"""
557-
if cell is None:
558-
cell_shifts = None
559-
else:
560-
cell_shifts = torch.einsum(
561-
"jn,jnm->jm", shifts_idx, cell.view(-1, 3, 3)[system_mapping]
562-
)
563-
return cell_shifts
557+
return torch.einsum(
558+
"jn,jnm->jm", shifts_idx, cell.view(-1, 3, 3)[system_mapping]
559+
)
564560

565561

566562
def get_fully_connected_mapping(

0 commit comments

Comments
 (0)