Skip to content

Commit 6a04876

Browse files
dario-cosciandem0
authored andcommitted
Update label_tensor.py cpu/gpu (#292)
* Update label_tensor.py cpu/gpu * Update test_adaptive_refinment_callbacks.py * Update test_optimizer_callbacks.py
1 parent 6c909e9 commit 6a04876

File tree

3 files changed

+4
-2
lines changed

3 files changed

+4
-2
lines changed

pina/label_tensor.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -176,7 +176,7 @@ def cuda(self, *args, **kwargs):
176176
tmp = super().cuda(*args, **kwargs)
177177
new = self.__class__.clone(self)
178178
new.data = tmp.data
179-
return tmp
179+
return new
180180

181181
def cpu(self, *args, **kwargs):
182182
"""
@@ -185,7 +185,7 @@ def cpu(self, *args, **kwargs):
185185
tmp = super().cpu(*args, **kwargs)
186186
new = self.__class__.clone(self)
187187
new.data = tmp.data
188-
return tmp
188+
return new
189189

190190
def extract(self, label_to_extract):
191191
"""

tests/test_callbacks/test_adaptive_refinment_callbacks.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,7 @@ def test_r3refinment_routine():
7171
# make the trainer
7272
trainer = Trainer(solver=solver,
7373
callbacks=[R3Refinement(sample_every=1)],
74+
accelerator='cpu',
7475
max_epochs=5)
7576
trainer.train()
7677

tests/test_callbacks/test_optimizer_callbacks.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -84,5 +84,6 @@ def test_switch_optimizer_routine():
8484
new_optimizers_kwargs={'lr': 0.01},
8585
epoch_switch=3)
8686
],
87+
accelerator='cpu',
8788
max_epochs=5)
8889
trainer.train()

0 commit comments

Comments
 (0)