Skip to content

Commit d27e4c1

Browse files
authored
Move rescale dtype recasting to match torchvision ToTensor (#25229)
Move dtype recasting to match torchvision ToTensor
1 parent 3170af7 commit d27e4c1

File tree

1 file changed

+3
-2
lines changed

1 file changed

+3
-2
lines changed

src/transformers/image_transforms.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -110,11 +110,12 @@ def rescale(
110110
if not isinstance(image, np.ndarray):
111111
raise ValueError(f"Input image must be of type np.ndarray, got {type(image)}")
112112

113-
image = image.astype(dtype)
114-
115113
rescaled_image = image * scale
116114
if data_format is not None:
117115
rescaled_image = to_channel_dimension_format(rescaled_image, data_format)
116+
117+
rescaled_image = rescaled_image.astype(dtype)
118+
118119
return rescaled_image
119120

120121

0 commit comments

Comments
 (0)