Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 8 additions & 8 deletions pytorch_lightning/core/lightning.py
Original file line number Diff line number Diff line change
Expand Up @@ -621,14 +621,14 @@ def validation_step(self, *args, **kwargs):
for val_batch in val_data:
out = validation_step(val_batch)
val_outs.append(out)
validation_epoch_end(val_outs)
validation_epoch_end(val_outs)

Args:
batch (:class:`~torch.Tensor` | (:class:`~torch.Tensor`, ...) | [:class:`~torch.Tensor`, ...]):
The output of your :class:`~torch.utils.data.DataLoader`. A tensor, tuple or list.
batch_idx (int): The index of this batch
dataloader_idx (int): The index of the dataloader that produced this batch
(only if multiple val datasets used)
(only if multiple val dataloaders used)

Return:
Any of.
Expand Down Expand Up @@ -677,11 +677,11 @@ def validation_step(self, batch, batch_idx):
# log the outputs!
self.log_dict({'val_loss': loss, 'val_acc': val_acc})

If you pass in multiple val datasets, validation_step will have an additional argument.
If you pass in multiple val dataloaders, :meth:`validation_step` will have an additional argument.

.. code-block:: python

# CASE 2: multiple validation datasets
# CASE 2: multiple validation dataloaders
def validation_step(self, batch, batch_idx, dataloader_idx):
# dataloader_idx tells you which dataset this is.

Expand Down Expand Up @@ -813,7 +813,7 @@ def test_step(self, *args, **kwargs):
The output of your :class:`~torch.utils.data.DataLoader`. A tensor, tuple or list.
batch_idx (int): The index of this batch.
dataloader_idx (int): The index of the dataloader that produced this batch
(only if multiple test datasets used).
(only if multiple test dataloaders used).

Return:
Any of.
Expand Down Expand Up @@ -853,17 +853,17 @@ def test_step(self, batch, batch_idx):
# log the outputs!
self.log_dict({'test_loss': loss, 'test_acc': test_acc})

If you pass in multiple validation datasets, :meth:`test_step` will have an additional
If you pass in multiple test dataloaders, :meth:`test_step` will have an additional
argument.

.. code-block:: python

# CASE 2: multiple test datasets
# CASE 2: multiple test dataloaders
def test_step(self, batch, batch_idx, dataloader_idx):
# dataloader_idx tells you which dataset this is.

Note:
If you don't need to validate you don't need to implement this method.
If you don't need to test you don't need to implement this method.

Note:
When the :meth:`test_step` is called, the model has been put in eval mode and
Expand Down