diff --git a/pytorch3d/transforms/rotation_conversions.py b/pytorch3d/transforms/rotation_conversions.py index 98b22ca2f..d4c76916e 100644 --- a/pytorch3d/transforms/rotation_conversions.py +++ b/pytorch3d/transforms/rotation_conversions.py @@ -114,7 +114,7 @@ def matrix_to_quaternion(matrix: torch.Tensor) -> torch.Tensor: batch_dim = matrix.shape[:-2] m00, m01, m02, m10, m11, m12, m20, m21, m22 = torch.unbind( - matrix.reshape(*batch_dim, 9), dim=-1 + matrix.reshape(batch_dim + (9,)), dim=-1 ) q_abs = _sqrt_positive_part( @@ -142,14 +142,15 @@ def matrix_to_quaternion(matrix: torch.Tensor) -> torch.Tensor: # We floor here at 0.1 but the exact level is not important; if q_abs is small, # the candidate won't be picked. - quat_candidates = quat_by_rijk / (2.0 * q_abs[..., None].max(q_abs.new_tensor(0.1))) + flr = torch.tensor(0.1).to(dtype=q_abs.dtype, device=q_abs.device) + quat_candidates = quat_by_rijk / (2.0 * q_abs[..., None].max(flr)) # if not for numerical problems, quat_candidates[i] should be same (up to a sign), # forall i; we pick the best-conditioned one (with the largest denominator) return quat_candidates[ F.one_hot(q_abs.argmax(dim=-1), num_classes=4) > 0.5, : # pyre-ignore[16] - ].reshape(*batch_dim, 4) + ].reshape(batch_dim + (4,)) def _axis_angle_rotation(axis: str, angle): @@ -238,13 +239,14 @@ def _angle_from_tan( return torch.atan2(data[..., i2], -data[..., i1]) -def _index_from_letter(letter: str): +def _index_from_letter(letter: str) -> int: if letter == "X": return 0 if letter == "Y": return 1 if letter == "Z": return 2 + raise ValueError("letter must be either X, Y or Z.") def matrix_to_euler_angles(matrix, convention: str): @@ -573,4 +575,5 @@ def matrix_to_rotation_6d(matrix: torch.Tensor) -> torch.Tensor: IEEE Conference on Computer Vision and Pattern Recognition, 2019. Retrieved from http://arxiv.org/abs/1812.07035 """ - return matrix[..., :2, :].clone().reshape(*matrix.size()[:-2], 6) + batch_dim = matrix.size()[:-2] + return matrix[..., :2, :].clone().reshape(batch_dim + (6,)) diff --git a/tests/test_rotation_conversions.py b/tests/test_rotation_conversions.py index 85888a455..3a992eb06 100644 --- a/tests/test_rotation_conversions.py +++ b/tests/test_rotation_conversions.py @@ -8,6 +8,7 @@ import itertools import math import unittest +from distutils.version import LooseVersion from typing import Optional, Union import numpy as np @@ -264,6 +265,13 @@ def test_6d(self): torch.matmul(r, r.permute(0, 2, 1)), torch.eye(3).expand_as(r), atol=1e-6 ) + @unittest.skipIf(LooseVersion(torch.__version__) < "1.9", "recent torchscript only") + def test_scriptable(self): + torch.jit.script(matrix_to_axis_angle), + torch.jit.script(matrix_to_euler_angles), + torch.jit.script(matrix_to_quaternion), + torch.jit.script(matrix_to_rotation_6d), + def _assert_quaternions_close( self, input: Union[torch.Tensor, np.ndarray],