44# This source code is licensed under the BSD-style license found in the
55# LICENSE file in the root directory of this source tree.
66
7- import functools
87from typing import Optional
98
109import torch
3938"""
4039
4140
42- def quaternion_to_matrix (quaternions ) :
41+ def quaternion_to_matrix (quaternions : torch . Tensor ) -> torch . Tensor :
4342 """
4443 Convert rotations given as quaternions to rotation matrices.
4544
@@ -70,7 +69,7 @@ def quaternion_to_matrix(quaternions):
7069 return o .reshape (quaternions .shape [:- 1 ] + (3 , 3 ))
7170
7271
73- def _copysign (a , b ) :
72+ def _copysign (a : torch . Tensor , b : torch . Tensor ) -> torch . Tensor :
7473 """
7574 Return a tensor where each element has the absolute value taken from the,
7675 corresponding element of a, with sign taken from the corresponding
@@ -114,7 +113,7 @@ def matrix_to_quaternion(matrix: torch.Tensor) -> torch.Tensor:
114113
115114 batch_dim = matrix .shape [:- 2 ]
116115 m00 , m01 , m02 , m10 , m11 , m12 , m20 , m21 , m22 = torch .unbind (
117- matrix .reshape (* batch_dim , 9 ), dim = - 1
116+ matrix .reshape (batch_dim + ( 9 ,) ), dim = - 1
118117 )
119118
120119 q_abs = _sqrt_positive_part (
@@ -142,17 +141,18 @@ def matrix_to_quaternion(matrix: torch.Tensor) -> torch.Tensor:
142141
143142 # We floor here at 0.1 but the exact level is not important; if q_abs is small,
144143 # the candidate won't be picked.
145- quat_candidates = quat_by_rijk / (2.0 * q_abs [..., None ].max (q_abs .new_tensor (0.1 )))
144+ flr = torch .tensor (0.1 ).to (dtype = q_abs .dtype , device = q_abs .device )
145+ quat_candidates = quat_by_rijk / (2.0 * q_abs [..., None ].max (flr ))
146146
147147 # if not for numerical problems, quat_candidates[i] should be same (up to a sign),
148148 # forall i; we pick the best-conditioned one (with the largest denominator)
149149
150150 return quat_candidates [
151151 F .one_hot (q_abs .argmax (dim = - 1 ), num_classes = 4 ) > 0.5 , : # pyre-ignore[16]
152- ].reshape (* batch_dim , 4 )
152+ ].reshape (batch_dim + ( 4 ,) )
153153
154154
155- def _axis_angle_rotation (axis : str , angle ) :
155+ def _axis_angle_rotation (axis : str , angle : torch . Tensor ) -> torch . Tensor :
156156 """
157157 Return the rotation matrices for one of the rotations about an axis
158158 of which Euler angles describe, for each value of the angle given.
@@ -172,15 +172,17 @@ def _axis_angle_rotation(axis: str, angle):
172172
173173 if axis == "X" :
174174 R_flat = (one , zero , zero , zero , cos , - sin , zero , sin , cos )
175- if axis == "Y" :
175+ elif axis == "Y" :
176176 R_flat = (cos , zero , sin , zero , one , zero , - sin , zero , cos )
177- if axis == "Z" :
177+ elif axis == "Z" :
178178 R_flat = (cos , - sin , zero , sin , cos , zero , zero , zero , one )
179+ else :
180+ raise ValueError ("letter must be either X, Y or Z." )
179181
180182 return torch .stack (R_flat , - 1 ).reshape (angle .shape + (3 , 3 ))
181183
182184
183- def euler_angles_to_matrix (euler_angles , convention : str ):
185+ def euler_angles_to_matrix (euler_angles : torch . Tensor , convention : str ) -> torch . Tensor :
184186 """
185187 Convert rotations given as Euler angles in radians to rotation matrices.
186188
@@ -201,13 +203,17 @@ def euler_angles_to_matrix(euler_angles, convention: str):
201203 for letter in convention :
202204 if letter not in ("X" , "Y" , "Z" ):
203205 raise ValueError (f"Invalid letter { letter } in convention string." )
204- matrices = map (_axis_angle_rotation , convention , torch .unbind (euler_angles , - 1 ))
205- return functools .reduce (torch .matmul , matrices )
206+ matrices = [
207+ _axis_angle_rotation (c , e )
208+ for c , e in zip (convention , torch .unbind (euler_angles , - 1 ))
209+ ]
210+ # return functools.reduce(torch.matmul, matrices)
211+ return torch .matmul (torch .matmul (matrices [0 ], matrices [1 ]), matrices [2 ])
206212
207213
208214def _angle_from_tan (
209215 axis : str , other_axis : str , data , horizontal : bool , tait_bryan : bool
210- ):
216+ ) -> torch . Tensor :
211217 """
212218 Extract the first or third Euler angle from the two members of
213219 the matrix which are positive constant times its sine and cosine.
@@ -238,16 +244,17 @@ def _angle_from_tan(
238244 return torch .atan2 (data [..., i2 ], - data [..., i1 ])
239245
240246
241- def _index_from_letter (letter : str ):
247+ def _index_from_letter (letter : str ) -> int :
242248 if letter == "X" :
243249 return 0
244250 if letter == "Y" :
245251 return 1
246252 if letter == "Z" :
247253 return 2
254+ raise ValueError ("letter must be either X, Y or Z." )
248255
249256
250- def matrix_to_euler_angles (matrix , convention : str ):
257+ def matrix_to_euler_angles (matrix : torch . Tensor , convention : str ) -> torch . Tensor :
251258 """
252259 Convert rotations given as rotation matrices to Euler angles in radians.
253260
@@ -291,7 +298,7 @@ def matrix_to_euler_angles(matrix, convention: str):
291298
292299def random_quaternions (
293300 n : int , dtype : Optional [torch .dtype ] = None , device : Optional [Device ] = None
294- ):
301+ ) -> torch . Tensor :
295302 """
296303 Generate random quaternions representing rotations,
297304 i.e. versors with nonnegative real part.
@@ -305,6 +312,8 @@ def random_quaternions(
305312 Returns:
306313 Quaternions as tensor of shape (N, 4).
307314 """
315+ if isinstance (device , str ):
316+ device = torch .device (device )
308317 o = torch .randn ((n , 4 ), dtype = dtype , device = device )
309318 s = (o * o ).sum (1 )
310319 o = o / _copysign (torch .sqrt (s ), o [:, 0 ])[:, None ]
@@ -313,7 +322,7 @@ def random_quaternions(
313322
314323def random_rotations (
315324 n : int , dtype : Optional [torch .dtype ] = None , device : Optional [Device ] = None
316- ):
325+ ) -> torch . Tensor :
317326 """
318327 Generate random rotations as 3x3 rotation matrices.
319328
@@ -332,7 +341,7 @@ def random_rotations(
332341
333342def random_rotation (
334343 dtype : Optional [torch .dtype ] = None , device : Optional [Device ] = None
335- ):
344+ ) -> torch . Tensor :
336345 """
337346 Generate a single random 3x3 rotation matrix.
338347
@@ -347,7 +356,7 @@ def random_rotation(
347356 return random_rotations (1 , dtype , device )[0 ]
348357
349358
350- def standardize_quaternion (quaternions ) :
359+ def standardize_quaternion (quaternions : torch . Tensor ) -> torch . Tensor :
351360 """
352361 Convert a unit quaternion to a standard form: one in which the real
353362 part is non negative.
@@ -362,7 +371,7 @@ def standardize_quaternion(quaternions):
362371 return torch .where (quaternions [..., 0 :1 ] < 0 , - quaternions , quaternions )
363372
364373
365- def quaternion_raw_multiply (a , b ) :
374+ def quaternion_raw_multiply (a : torch . Tensor , b : torch . Tensor ) -> torch . Tensor :
366375 """
367376 Multiply two quaternions.
368377 Usual torch rules for broadcasting apply.
@@ -383,7 +392,7 @@ def quaternion_raw_multiply(a, b):
383392 return torch .stack ((ow , ox , oy , oz ), - 1 )
384393
385394
386- def quaternion_multiply (a , b ) :
395+ def quaternion_multiply (a : torch . Tensor , b : torch . Tensor ) -> torch . Tensor :
387396 """
388397 Multiply two quaternions representing rotations, returning the quaternion
389398 representing their composition, i.e. the versor with nonnegative real part.
@@ -400,7 +409,7 @@ def quaternion_multiply(a, b):
400409 return standardize_quaternion (ab )
401410
402411
403- def quaternion_invert (quaternion ) :
412+ def quaternion_invert (quaternion : torch . Tensor ) -> torch . Tensor :
404413 """
405414 Given a quaternion representing rotation, get the quaternion representing
406415 its inverse.
@@ -413,10 +422,11 @@ def quaternion_invert(quaternion):
413422 The inverse, a tensor of quaternions of shape (..., 4).
414423 """
415424
416- return quaternion * quaternion .new_tensor ([1 , - 1 , - 1 , - 1 ])
425+ scaling = torch .tensor ([1 , - 1 , - 1 , - 1 ], device = quaternion .device )
426+ return quaternion * scaling
417427
418428
419- def quaternion_apply (quaternion , point ) :
429+ def quaternion_apply (quaternion : torch . Tensor , point : torch . Tensor ) -> torch . Tensor :
420430 """
421431 Apply the rotation given by a quaternion to a 3D point.
422432 Usual torch rules for broadcasting apply.
@@ -439,7 +449,7 @@ def quaternion_apply(quaternion, point):
439449 return out [..., 1 :]
440450
441451
442- def axis_angle_to_matrix (axis_angle ) :
452+ def axis_angle_to_matrix (axis_angle : torch . Tensor ) -> torch . Tensor :
443453 """
444454 Convert rotations given as axis/angle to rotation matrices.
445455
@@ -455,7 +465,7 @@ def axis_angle_to_matrix(axis_angle):
455465 return quaternion_to_matrix (axis_angle_to_quaternion (axis_angle ))
456466
457467
458- def matrix_to_axis_angle (matrix ) :
468+ def matrix_to_axis_angle (matrix : torch . Tensor ) -> torch . Tensor :
459469 """
460470 Convert rotations given as rotation matrices to axis/angle.
461471
@@ -471,7 +481,7 @@ def matrix_to_axis_angle(matrix):
471481 return quaternion_to_axis_angle (matrix_to_quaternion (matrix ))
472482
473483
474- def axis_angle_to_quaternion (axis_angle ) :
484+ def axis_angle_to_quaternion (axis_angle : torch . Tensor ) -> torch . Tensor :
475485 """
476486 Convert rotations given as axis/angle to quaternions.
477487
@@ -485,7 +495,7 @@ def axis_angle_to_quaternion(axis_angle):
485495 quaternions with real part first, as tensor of shape (..., 4).
486496 """
487497 angles = torch .norm (axis_angle , p = 2 , dim = - 1 , keepdim = True )
488- half_angles = 0.5 * angles
498+ half_angles = angles * 0.5
489499 eps = 1e-6
490500 small_angles = angles .abs () < eps
491501 sin_half_angles_over_angles = torch .empty_like (angles )
@@ -503,7 +513,7 @@ def axis_angle_to_quaternion(axis_angle):
503513 return quaternions
504514
505515
506- def quaternion_to_axis_angle (quaternions ) :
516+ def quaternion_to_axis_angle (quaternions : torch . Tensor ) -> torch . Tensor :
507517 """
508518 Convert rotations given as quaternions to axis/angle.
509519
@@ -573,4 +583,5 @@ def matrix_to_rotation_6d(matrix: torch.Tensor) -> torch.Tensor:
573583 IEEE Conference on Computer Vision and Pattern Recognition, 2019.
574584 Retrieved from http://arxiv.org/abs/1812.07035
575585 """
576- return matrix [..., :2 , :].clone ().reshape (* matrix .size ()[:- 2 ], 6 )
586+ batch_dim = matrix .size ()[:- 2 ]
587+ return matrix [..., :2 , :].clone ().reshape (batch_dim + (6 ,))
0 commit comments