From f7b7c0730ba2cba83a9604ca09951bafcbbafa03 Mon Sep 17 00:00:00 2001 From: jinc7461 Date: Tue, 18 Mar 2025 10:18:09 +0800 Subject: [PATCH] Modify UNet's ResNet implementation to resolve stride mismatch in Torch's DDP --- src/diffusers/models/resnet.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/models/resnet.py b/src/diffusers/models/resnet.py index 00b55cd9c9d6..260b4b8929b0 100644 --- a/src/diffusers/models/resnet.py +++ b/src/diffusers/models/resnet.py @@ -366,7 +366,7 @@ def forward(self, input_tensor: torch.Tensor, temb: torch.Tensor, *args, **kwarg hidden_states = self.conv2(hidden_states) if self.conv_shortcut is not None: - input_tensor = self.conv_shortcut(input_tensor) + input_tensor = self.conv_shortcut(input_tensor.contiguous()) output_tensor = (input_tensor + hidden_states) / self.output_scale_factor