|
7 | 7 |
|
8 | 8 | from collections.abc import Callable |
9 | 9 | from functools import wraps |
| 10 | +from typing import overload, override |
10 | 11 |
|
11 | 12 | import torch |
12 | 13 | from torch.types import _dtype |
@@ -535,6 +536,18 @@ def compute_distances_with_cell_shifts( |
535 | 536 | return dr.norm(p=2, dim=1) |
536 | 537 |
|
537 | 538 |
|
| 539 | +@overload |
| 540 | +def compute_cell_shifts( |
| 541 | + cell: torch.Tensor, shifts_idx: torch.Tensor, system_mapping: torch.Tensor |
| 542 | +) -> torch.Tensor: ... |
| 543 | + |
| 544 | + |
| 545 | +@overload |
| 546 | +def compute_cell_shifts( |
| 547 | + cell: None, shifts_idx: torch.Tensor, system_mapping: torch.Tensor |
| 548 | +) -> None: ... |
| 549 | + |
| 550 | + |
538 | 551 | def compute_cell_shifts( |
539 | 552 | cell: torch.Tensor | None, shifts_idx: torch.Tensor, system_mapping: torch.Tensor |
540 | 553 | ) -> torch.Tensor | None: |
@@ -564,30 +577,6 @@ def compute_cell_shifts( |
564 | 577 | return cell_shifts |
565 | 578 |
|
566 | 579 |
|
567 | | -def compute_cell_shifts_strict( |
568 | | - cell: torch.Tensor, shifts_idx: torch.Tensor, system_mapping: torch.Tensor |
569 | | -) -> torch.Tensor: |
570 | | - """Compute the cell shifts based on the provided indices and cell matrix. |
571 | | -
|
572 | | - This function calculates the shifts to apply to positions based on the specified |
573 | | - indices and the unit cell matrix. It is the same as compute_cell_shifts, but |
574 | | - cell cannot be None. |
575 | | -
|
576 | | - Args: |
577 | | - cell (torch.Tensor): A tensor of shape (n_cells, 3, 3) |
578 | | - representing the unit cell matrices. |
579 | | - shifts_idx (torch.Tensor): A tensor of shape (n_shifts, 3) |
580 | | - representing the indices for shifts. |
581 | | - system_mapping (torch.Tensor): A tensor of shape (n_systems,) |
582 | | - that maps the shifts to the corresponding cells. |
583 | | -
|
584 | | - Returns: |
585 | | - torch.Tensor: A tensor of shape (n_systems, 3) containing |
586 | | - the computed cell shifts. |
587 | | - """ |
588 | | - return torch.einsum("jn,jnm->jm", shifts_idx, cell.view(-1, 3, 3)[system_mapping]) |
589 | | - |
590 | | - |
591 | 580 | def get_fully_connected_mapping( |
592 | 581 | *, |
593 | 582 | i_ids: torch.Tensor, |
@@ -873,7 +862,7 @@ def linked_cell( # noqa: PLR0915 |
873 | 862 | shifts_idx, n_atom, dim=0, output_size=n_atom * n_cell_image |
874 | 863 | ) |
875 | 864 | batch_image = torch.zeros((shifts_idx.shape[0]), dtype=torch.long) |
876 | | - cell_shifts = compute_cell_shifts_strict(cell.view(-1, 3, 3), shifts_idx, batch_image) |
| 865 | + cell_shifts = compute_cell_shifts(cell.view(-1, 3, 3), shifts_idx, batch_image) |
877 | 866 |
|
878 | 867 | i_ids = torch.arange(n_atom, device=device, dtype=torch.long) |
879 | 868 | i_ids = i_ids.repeat(n_cell_image) |
|
0 commit comments