diff --git a/src/lightning/pytorch/demos/transformer.py b/src/lightning/pytorch/demos/transformer.py index 3f5bcb696affe..ac83b5539f249 100644 --- a/src/lightning/pytorch/demos/transformer.py +++ b/src/lightning/pytorch/demos/transformer.py @@ -88,7 +88,7 @@ def forward(self, x: Tensor) -> Tensor: # TODO: Could make this a `nn.Parameter` with `requires_grad=False` self.pe = self._init_pos_encoding(device=x.device) - x + self.pe[: x.size(0), :] + x = x + self.pe[: x.size(0), :] return self.dropout(x) def _init_pos_encoding(self, device: torch.device) -> Tensor: