Status: Open
Labels: bug (Something isn't working), checkpointing (Related to checkpointing), strategy: ddp (DistributedDataParallel), ver: 2.4.x
Bug description
When running Lightning with a multi-device training strategy (e.g. with DDP), using the OnExceptionCheckpoint callback:
- silently swallows exceptions, which makes it challenging to identify the cause of errors
- results in an NCCL timeout
This is due to the following:
- When we catch an exception, it gets handled by `_call_and_handle_interrupt`, which calls into `_interrupt`: `_interrupt(trainer, exception)`. We are supposed to re-raise the original exception at the end of this function, but we never get there because...
- In `_interrupt`, we call `_call_callback_hooks`, which calls the `on_exception` callbacks: `_call_callback_hooks(trainer, "on_exception", exception)`.
- If `OnExceptionCheckpoint` is enabled, we then call that callback. However, we never finish executing it, because the callback calls `trainer.save_checkpoint`: https://github.com/Lightning-AI/pytorch-lightning/blob/master/src/lightning/pytorch/callbacks/on_exception_checkpoint.py#L67
- The `trainer.save_checkpoint` method saves the checkpoint and then calls `self.strategy.barrier("Trainer.save_checkpoint")`, which waits for the other processes to reach that barrier. However, the processes that haven't raised an exception will never hit this code path, so we never advance beyond that barrier (until it times out).
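The deadlock can be reproduced without Lightning or NCCL at all. Below is a minimal sketch using Python threads in place of DDP ranks; the `worker` function, the rank count, and the 0.5 s timeout are all illustrative stand-ins, not Lightning APIs (NCCL's actual watchdog timeout defaults to many minutes):

```python
import threading

NUM_RANKS = 2
# A short timeout stands in for NCCL's watchdog timeout.
barrier = threading.Barrier(NUM_RANKS, timeout=0.5)
results = {}

def worker(rank: int) -> None:
    try:
        if rank == 0:
            # Only rank 0 hits the original training error.
            raise RuntimeError("original training error")
        # Rank 1 "trains" normally and exits without ever calling the barrier.
        results[rank] = "finished"
    except RuntimeError:
        # Mirrors OnExceptionCheckpoint -> trainer.save_checkpoint ->
        # strategy.barrier("Trainer.save_checkpoint") on the failing rank only.
        try:
            barrier.wait()  # blocks: rank 1 never arrives
            results[rank] = "checkpoint saved"
        except threading.BrokenBarrierError:
            # What you actually observe: a timeout instead of the real error.
            results[rank] = "barrier timed out"

threads = [threading.Thread(target=worker, args=(r,)) for r in range(NUM_RANKS)]
for t in threads:
    t.start()
for t in threads:
    t.join()
print(results)  # rank 0 times out at the barrier; the original error is lost
```

Rank 0 reports "barrier timed out" and the original `RuntimeError` is swallowed, matching the behavior described above.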
As described in the docstring for Trainer.save_checkpoint:
This method needs to be called on all processes in case the selected strategy is handling distributed checkpointing.
In practice, this means that our jobs eventually time out with an NCCL error, and don't print the original exception.
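To illustrate the docstring's requirement, here is a small thread-based sketch (plain Python threads standing in for DDP ranks; names and the 0.5 s timeout are illustrative, not Lightning code): when every rank participates in the barrier, the collective completes promptly instead of timing out.

```python
import threading

NUM_RANKS = 2
barrier = threading.Barrier(NUM_RANKS, timeout=0.5)
results = {}

def worker(rank: int) -> None:
    # Every rank reaches the (simulated) save_checkpoint barrier, whether or
    # not it hit an exception -- the invariant the docstring demands.
    try:
        barrier.wait()
        results[rank] = "passed barrier"
    except threading.BrokenBarrierError:
        results[rank] = "barrier timed out"

threads = [threading.Thread(target=worker, args=(r,)) for r in range(NUM_RANKS)]
for t in threads:
    t.start()
for t in threads:
    t.join()
print(results)  # both ranks pass the barrier; no timeout
```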
What version are you seeing the problem on?
v2.4
How to reproduce the bug
No response
Error messages and logs
Environment
Current environment
#- PyTorch Lightning Version (e.g., 2.4.0):
#- PyTorch Version (e.g., 2.4):
#- Python version (e.g., 3.12):
#- OS (e.g., Linux):
#- CUDA/cuDNN version:
#- GPU models and configuration:
#- How you installed Lightning (`conda`, `pip`, source):
More info
No response