@@ -390,16 +390,16 @@ def transform_normals(self, normals) -> torch.Tensor:
390390 return normals_out
391391
392392 def translate (self , * args , ** kwargs ) -> "Transform3d" :
393- return self .compose (Translate (device = self .device , * args , ** kwargs ))
393+ return self .compose (Translate (device = self .device , dtype = self . dtype , * args , ** kwargs ))
394394
395395 def scale (self , * args , ** kwargs ) -> "Transform3d" :
396- return self .compose (Scale (device = self .device , * args , ** kwargs ))
396+ return self .compose (Scale (device = self .device , dtype = self . dtype , * args , ** kwargs ))
397397
398398 def rotate (self , * args , ** kwargs ) -> "Transform3d" :
399- return self .compose (Rotate (device = self .device , * args , ** kwargs ))
399+ return self .compose (Rotate (device = self .device , dtype = self . dtype , * args , ** kwargs ))
400400
401401 def rotate_axis_angle (self , * args , ** kwargs ) -> "Transform3d" :
402- return self .compose (RotateAxisAngle (device = self .device , * args , ** kwargs ))
402+ return self .compose (RotateAxisAngle (device = self .device , dtype = self . dtype , * args , ** kwargs ))
403403
404404 def clone (self ) -> "Transform3d" :
405405 """
@@ -488,7 +488,7 @@ def __init__(
488488 - A 1D torch tensor
489489 """
490490 xyz = _handle_input (x , y , z , dtype , device , "Translate" )
491- super ().__init__ (device = xyz .device )
491+ super ().__init__ (device = xyz .device , dtype = dtype )
492492 N = xyz .shape [0 ]
493493
494494 mat = torch .eye (4 , dtype = dtype , device = self .device )
@@ -532,7 +532,7 @@ def __init__(
532532 - 1D torch tensor
533533 """
534534 xyz = _handle_input (x , y , z , dtype , device , "scale" , allow_singleton = True )
535- super ().__init__ (device = xyz .device )
535+ super ().__init__ (device = xyz .device , dtype = dtype )
536536 N = xyz .shape [0 ]
537537
538538 # TODO: Can we do this all in one go somehow?
@@ -571,7 +571,7 @@ def __init__(
571571
572572 """
573573 device_ = get_device (R , device )
574- super ().__init__ (device = device_ )
574+ super ().__init__ (device = device_ , dtype = dtype )
575575 if R .dim () == 2 :
576576 R = R [None ]
577577 if R .shape [- 2 :] != (3 , 3 ):
@@ -629,7 +629,7 @@ def __init__(
629629 # is for transforming column vectors. Therefore we transpose this matrix.
630630 # R will always be of shape (N, 3, 3)
631631 R = _axis_angle_rotation (axis , angle ).transpose (1 , 2 )
632- super ().__init__ (device = angle .device , R = R )
632+ super ().__init__ (device = angle .device , R = R , dtype = dtype )
633633
634634
635635def _handle_coord (c , dtype : torch .dtype , device : torch .device ) -> torch .Tensor :
@@ -646,8 +646,8 @@ def _handle_coord(c, dtype: torch.dtype, device: torch.device) -> torch.Tensor:
646646 c = torch .tensor (c , dtype = dtype , device = device )
647647 if c .dim () == 0 :
648648 c = c .view (1 )
649- if c .device != device :
650- c = c .to (device = device )
649+ if c .device != device or c . dtype != dtype :
650+ c = c .to (device = device , dtype = dtype )
651651 return c
652652
653653
@@ -696,7 +696,7 @@ def _handle_input(
696696 if y is not None or z is not None :
697697 msg = "Expected y and z to be None (in %s)" % name
698698 raise ValueError (msg )
699- return x .to (device = device_ )
699+ return x .to (device = device_ , dtype = dtype )
700700
701701 if allow_singleton and y is None and z is None :
702702 y = x
0 commit comments