@@ -638,7 +638,7 @@ def vesin_nl(
638638def  strict_nl (
639639    cutoff : float ,
640640    positions : torch .Tensor ,
641-     cell : torch .Tensor ,
641+     cell : torch .Tensor   |   None ,
642642    mapping : torch .Tensor ,
643643    system_mapping : torch .Tensor ,
644644    shifts_idx : torch .Tensor ,
@@ -656,7 +656,7 @@ def strict_nl(
656656            is used to filter the neighbor pairs based on their distances. 
657657        positions (torch.Tensor): A tensor of shape (n_atoms, 3) representing 
658658            the positions of the atoms. 
659-         cell (torch.Tensor): Unit cell vectors according to the row vector convention, 
659+         cell (torch.Tensor) | None : Unit cell vectors according to the row vector convention, 
660660            i.e. `[[a1, a2, a3], [b1, b2, b3], [c1, c2, c3]]`. 
661661        mapping (torch.Tensor): 
662662            A tensor of shape (2, n_pairs) that specifies pairs of indices in `positions` 
@@ -687,10 +687,10 @@ def strict_nl(
687687    References: 
688688        - https://github.com/felixmusil/torch_nl 
689689    """ 
690-     cell_shifts  =  transforms .compute_cell_shifts (cell , shifts_idx , system_mapping )
691-     if  cell_shifts  is  None :
690+     if  cell  is  None :
692691        d2  =  (positions [mapping [0 ]] -  positions [mapping [1 ]]).square ().sum (dim = 1 )
693692    else :
693+         cell_shifts  =  transforms .compute_cell_shifts (cell , shifts_idx , system_mapping )
694694        d2  =  (
695695            (positions [mapping [0 ]] -  positions [mapping [1 ]] -  cell_shifts )
696696            .square ()
0 commit comments