
Commit 2f3a421

Fix other PyTorch models
1 parent: d531979

File tree

templates/adding_a_new_model/modeling_xxx.py
transformers/modeling_distilbert.py

2 files changed: 7 additions, 3 deletions

templates/adding_a_new_model/modeling_xxx.py
4 additions, 2 deletions

@@ -309,10 +309,12 @@ def forward(self, input_ids=None, attention_mask=None, token_type_ids=None, posi
         else:
             raise ValueError("You have to specify either input_ids or inputs_embeds")

+        device = input_ids.device if input_ids is not None else inputs_embeds.device
+
         if attention_mask is None:
-            attention_mask = torch.ones(input_shape)
+            attention_mask = torch.ones(input_shape, device=device)
         if token_type_ids is None:
-            token_type_ids = torch.zeros(input_shape, dtype=torch.long)
+            token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)

         # We create a 3D attention mask from a 2D tensor mask.
         # Sizes are [batch_size, 1, 1, to_seq_length]
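Both hunks in this commit apply the same fix: when the caller omits attention_mask or token_type_ids, the default tensors must be created on the same device as the model inputs, otherwise a CPU-allocated mask collides with CUDA inputs later in the forward pass. A minimal standalone sketch of the pattern (the helper name make_default_masks is hypothetical, not part of transformers):

import torch

def make_default_masks(input_ids=None, inputs_embeds=None):
    # Derive the sequence shape from whichever input was provided,
    # mirroring the surrounding code in the model's forward().
    if input_ids is not None:
        input_shape = input_ids.size()
    elif inputs_embeds is not None:
        input_shape = inputs_embeds.size()[:-1]
    else:
        raise ValueError("You have to specify either input_ids or inputs_embeds")

    # The fix: read the device off the inputs so the defaults are
    # allocated on the same device (CPU or CUDA) as the rest of the batch.
    device = input_ids.device if input_ids is not None else inputs_embeds.device
    attention_mask = torch.ones(input_shape, device=device)
    token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)
    return attention_mask, token_type_ids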

transformers/modeling_distilbert.py
3 additions, 1 deletion

@@ -450,8 +450,10 @@ def forward(self,
         else:
             raise ValueError("You have to specify either input_ids or inputs_embeds")

+        device = input_ids.device if input_ids is not None else inputs_embeds.device
+
         if attention_mask is None:
-            attention_mask = torch.ones(input_shape)  # (bs, seq_length)
+            attention_mask = torch.ones(input_shape, device=device)  # (bs, seq_length)

         # Prepare head mask if needed
         # 1.0 in head_mask indicate we keep the head
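For context, a minimal reproduction of the failure mode the DistilBERT hunk fixes (hypothetical tensors; assumes a CUDA device is available):

import torch

# Inputs living on the GPU, as in a typical forward pass of a model
# that was moved there with .to("cuda").
inputs_embeds = torch.randn(2, 8, 768, device="cuda")
input_shape = inputs_embeds.size()[:-1]  # (bs, seq_length)

mask_cpu = torch.ones(input_shape)                               # old behaviour: CPU tensor
mask_gpu = torch.ones(input_shape, device=inputs_embeds.device)  # behaviour after this commit

# Mixing mask_cpu with the CUDA embeddings raises a device-mismatch
# RuntimeError; mask_gpu broadcasts cleanly.
out = inputs_embeds * mask_gpu.unsqueeze(-1)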
