Skip to content

Commit 3ad55a2

Browse files
authored
Fix fine-tuning callback test (#5643)
* fix * batch size
1 parent 2008d77 commit 3ad55a2

File tree

1 file changed

+3
-3
lines changed

1 file changed

+3
-3
lines changed

tests/helpers/boring_model.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -96,7 +96,7 @@ def step(self, x):
9696
return out
9797

9898
def training_step(self, batch, batch_idx):
99-
output = self.layer(batch)
99+
output = self(batch)
100100
loss = self.loss(batch, output)
101101
return {"loss": loss}
102102

@@ -107,15 +107,15 @@ def training_epoch_end(self, outputs) -> None:
107107
torch.stack([x["loss"] for x in outputs]).mean()
108108

109109
def validation_step(self, batch, batch_idx):
110-
output = self.layer(batch)
110+
output = self(batch)
111111
loss = self.loss(batch, output)
112112
return {"x": loss}
113113

114114
def validation_epoch_end(self, outputs) -> None:
115115
torch.stack([x['x'] for x in outputs]).mean()
116116

117117
def test_step(self, batch, batch_idx):
118-
output = self.layer(batch)
118+
output = self(batch)
119119
loss = self.loss(batch, output)
120120
return {"y": loss}
121121

0 commit comments

Comments
 (0)