@@ -162,13 +162,15 @@ def __init__(
162162 raise ValueError (
163163 '"matrix" has to be a tensor of shape (minibatch, 4, 4)'
164164 )
165- # set the device from matrix
165+ # set dtype and device from matrix
166+ dtype = matrix .dtype
166167 device = matrix .device
167168 self ._matrix = matrix .view (- 1 , 4 , 4 )
168169
169170 self ._transforms = [] # store transforms to compose
170171 self ._lu = None
171172 self .device = make_device (device )
173+ self .dtype = dtype
172174
173175 def __len__ (self ):
174176 return self .get_matrix ().shape [0 ]
@@ -200,7 +202,7 @@ def compose(self, *others):
200202 Returns:
201203 A new Transform3d with the stored transforms
202204 """
203- out = Transform3d (device = self .device )
205+ out = Transform3d (dtype = self . dtype , device = self .device )
204206 out ._matrix = self ._matrix .clone ()
205207 for other in others :
206208 if not isinstance (other , Transform3d ):
@@ -259,7 +261,7 @@ def inverse(self, invert_composed: bool = False):
259261 transformation.
260262 """
261263
262- tinv = Transform3d (device = self .device )
264+ tinv = Transform3d (dtype = self . dtype , device = self .device )
263265
264266 if invert_composed :
265267 # first compose then invert
@@ -278,7 +280,7 @@ def inverse(self, invert_composed: bool = False):
278280 # right-multiplies by the inverse of self._matrix
279281 # at the end of the composition.
280282 tinv ._transforms = [t .inverse () for t in reversed (self ._transforms )]
281- last = Transform3d (device = self .device )
283+ last = Transform3d (dtype = self . dtype , device = self .device )
282284 last ._matrix = i_matrix
283285 tinv ._transforms .append (last )
284286 else :
@@ -291,7 +293,7 @@ def inverse(self, invert_composed: bool = False):
291293 def stack (self , * others ):
292294 transforms = [self ] + list (others )
293295 matrix = torch .cat ([t ._matrix for t in transforms ], dim = 0 )
294- out = Transform3d ()
296+ out = Transform3d (dtype = self . dtype , device = self . device )
295297 out ._matrix = matrix
296298 return out
297299
@@ -392,7 +394,7 @@ def clone(self):
392394 Returns:
393395 new Transforms object.
394396 """
395- other = Transform3d (device = self .device )
397+ other = Transform3d (dtype = self . dtype , device = self .device )
396398 if self ._lu is not None :
397399 other ._lu = [elem .clone () for elem in self ._lu ]
398400 other ._matrix = self ._matrix .clone ()
@@ -422,17 +424,22 @@ def to(
422424 Transform3d object.
423425 """
424426 device_ = make_device (device )
425- if not copy and self .device == device_ :
427+ dtype_ = self .dtype if dtype is None else dtype
428+ skip_to = self .device == device_ and self .dtype == dtype_
429+
430+ if not copy and skip_to :
426431 return self
427432
428433 other = self .clone ()
429- if self .device == device_ :
434+
435+ if skip_to :
430436 return other
431437
432438 other .device = device_
433- other ._matrix = self ._matrix .to (device = device_ , dtype = dtype )
439+ other .dtype = dtype_
440+ other ._matrix = other ._matrix .to (device = device_ , dtype = dtype_ )
434441 other ._transforms = [
435- t .to (device_ , copy = copy , dtype = dtype ) for t in other ._transforms
442+ t .to (device_ , copy = copy , dtype = dtype_ ) for t in other ._transforms
436443 ]
437444 return other
438445
0 commit comments