Add DataStates-LLM: Asynchronous Checkpointing Engine Support #7166
Conversation
Hi @mauryaavinash95 - could you please run the pre-commit formatter? That should fix the formatting errors at least.
Thanks @mauryaavinash95 - the formatting checks look good, but DCO shows as failing. You can rebase to fix it with the command here, or, if that might cause issues given the complex git history, we can manually approve the DCO check if you let us know.
I tried using the DCO instructions, and this is how it appears in my git log. It would be very helpful if you could manually approve the DCO using my email, [email protected].
Based on the checks, it now looks like only the DCO part is pending @loadams. Please let me know if there's anything I can do to fix this more quickly than the DeepSpeed team manually approving the DCO.
@mauryaavinash95, thanks for this great contribution to DeepSpeed. Do you intend to add a tutorial to help users benefit from this feature? @saforem2, FYI
@tjruwase @saforem2: yes, we'd like to set up a tutorial for this. Currently, there is just a short snippet to enable it in deepspeed/runtime/checkpoint_engine/README.md. Could you please point us to a reference and repository that we can use for the tutorial?
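For readers landing here before a tutorial exists, a hypothetical sketch of what enabling the engine could look like is below. The config section and key names (`datastates_ckpt`, `host_cache_size`) are placeholders, not confirmed API; the PR's `deepspeed/runtime/checkpoint_engine/README.md` is the authoritative snippet.

```python
# Hypothetical sketch of enabling the asynchronous engine via the DeepSpeed
# config; section/key names below are placeholders, not the confirmed API.
import torch
import deepspeed

model = torch.nn.Linear(1024, 1024)

ds_config = {
    "train_batch_size": 8,
    "optimizer": {"type": "Adam", "params": {"lr": 1e-4}},
    "datastates_ckpt": {          # placeholder section name
        "host_cache_size": 16     # placeholder: pinned host cache size (assumed GB)
    },
}

# Launch with the deepspeed launcher so distributed init is handled.
engine, optimizer, _, _ = deepspeed.initialize(model=model,
                                               model_parameters=model.parameters(),
                                               config=ds_config)

# Checkpoints are then saved as usual; the engine overlaps the flush.
engine.save_checkpoint("./ckpt", tag="step_0")
```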
@mauryaavinash95, DeepSpeed tutorials appear on deepspeed.ai:
```python
# To wait in asynchronous checkpoint engines (e.g. DataStates-LLM) for the previous snapshot to finish
pass

def preserves_storage_sharing(self):
```
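For context, the hunk above adds two hooks to the base checkpoint engine. A minimal sketch of how they sit alongside the rest of the interface is below; the surrounding method names follow the engine API referenced elsewhere in this thread, and the default bodies are assumptions rather than the merged code.

```python
# Illustrative sketch of where the two new hooks sit in the base engine
# (default bodies are assumptions, not the merged implementation).
class CheckpointEngine:

    def __init__(self, config_params=None):
        self.config_params = config_params

    def save(self, state_dict, path):
        raise NotImplementedError

    def load(self, path, map_location=None):
        raise NotImplementedError

    def commit(self, tag):
        raise NotImplementedError

    def wait(self):
        # To wait in asynchronous checkpoint engines (e.g. DataStates-LLM)
        # for the previous snapshot to finish; synchronous engines simply return.
        pass

    def preserves_storage_sharing(self):
        # Whether save() keeps PyTorch's default behavior of serializing the
        # full underlying storage for tensor views (this API is later dropped
        # in the discussion below in favor of engine-internal debloating).
        return True
```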
Thanks for adding this API. But I think the meaning is inverted here, in the sense that preserves_storage_sharing is what leads to checkpoint bloat and requires cloning to fix. Please see the following torch docs; I think it would also be helpful to add the doc link here.
https://pytorch.org/docs/stable/notes/serialization.html#saving-and-loading-tensors-preserves-views
Further reading of the doc on my part makes me feel that preserves_tensor_views() might be a more descriptive name. I am curious to know your thoughts. Thanks!
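For concreteness, here is a small, self-contained demo of the behavior the linked doc describes (illustrative, not part of the PR): saving a view serializes the entire underlying storage, and clone() is the usual way to debloat before torch.save.

```python
# Saving a view writes the whole underlying storage; cloning first debloats it.
import os
import torch

base = torch.zeros(1_000_000)        # ~4 MB of float32 storage
view = base[:10]                     # tiny view over that storage

torch.save(view, "view.pt")          # writes the whole 4 MB storage
torch.save(view.clone(), "clone.pt") # writes only the 10 elements

print(os.path.getsize("view.pt"))    # roughly 4 MB
print(os.path.getsize("clone.pt"))   # a few hundred bytes
```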
@tjruwase: Thanks for pointing this out and for the helpful doc reference!
You're right that preserve_tensor_views aligns closely with PyTorch's terminology and the default serialization behavior. That said, I was considering whether the API name should emphasize the intent to avoid storage sharing (i.e., debloating checkpoints) rather than reflect the PyTorch mechanism directly.
If the broader goal is to clearly signal "avoid capturing shared storage/views," a name like shared_storage_capture or avoid_tensor_storage_bloat might better convey user intent. Alternatively, we could stick with preserve_tensor_views and clarify the expected effects in the docstring.
Curious to hear your thoughts, especially if you foresee other use cases for this API beyond debloating checkpoints.
@mauryaavinash95, thanks, I see your points. Based on this, I wonder if it is better to let the individual checkpoint engines handle the decision of whether/how to debloat. Although this will require more code changes (i.e., moving the existing clone_tensor_... calls into the torch and nebula engines), I think it would be a win in the long run. It simplifies DS code, avoids a new API, and restricts these torch-specific semantics to the torch-compatible checkpoint engines.
What do you think?
@tjruwase That sounds like a great direction; it would definitely make the codebase more modular and maintainable in the long run. Currently, clone_tensors_for_torch_save also handles the blocking GPU-to-CPU data movement during cloning.
So I was thinking: would it make sense to abstract this entire logic under the checkpoint_engine.save() method? That way, each engine could manage both debloating and device-transfer optimizations internally, giving more control to engine-specific implementations. Thoughts?
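A rough sketch of the direction being discussed, assuming the clone_tensors_for_torch_save helper mentioned above stays available (its import path and the base-class path here are assumptions): the torch-compatible engine debloats and moves tensors inside its own save(), so callers no longer need to clone.

```python
# Illustrative only: a torch-backed engine that debloats (and moves tensors to
# CPU) inside save(), per the discussion above. Import paths are assumptions
# based on names used in this thread, not confirmed locations.
import torch

from deepspeed.checkpoint.utils import clone_tensors_for_torch_save  # assumed path
from deepspeed.runtime.checkpoint_engine.checkpoint_engine import CheckpointEngine  # assumed path


class TorchCheckpointEngine(CheckpointEngine):

    def save(self, state_dict, path):
        # Clone views/shared storages and copy GPU tensors to host so the
        # serialized file contains only the data it actually needs.
        state_dict = clone_tensors_for_torch_save(state_dict)
        torch.save(state_dict, path)

    def load(self, path, map_location=None):
        return torch.load(path, map_location=map_location)

    def commit(self, tag):
        # Synchronous engine: nothing outstanding to flush.
        return True
```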
Awesome. I think we are aligned. Do you mind updating this PR accordingly? Thanks
@tjruwase: I have one more question about the way [...] works. Currently, both assume that checkpoints are synchronously flushed to stable storage by the time the function returns, and they immediately update the tracking files for the [...]. Do you have thoughts on how best to handle this? One idea could be to move this responsibility into the checkpointing engine itself, allowing it to manage the timing and semantics of when the [...] is updated.
@mauryaavinash95, good question. We handle this in our upcoming code release of FastPersist. The idea is to add a [...].
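One way to realize the idea in the two comments above is sketched below; it is hypothetical and not FastPersist or DeepSpeed code. The engine owns the `latest` tracking-file update and defers it until every outstanding asynchronous flush is durable on storage.

```python
# Hypothetical sketch (not FastPersist or DeepSpeed code): an asynchronous
# engine that owns the 'latest' tracking-file update and defers it until all
# outstanding snapshot flushes have completed.
import os
import threading
import torch


class DeferredCommitSketch:

    def __init__(self):
        self._pending = []

    def save(self, state_dict, path):
        # Flush asynchronously; training is not blocked on storage I/O.
        t = threading.Thread(target=torch.save, args=(state_dict, path))
        t.start()
        self._pending.append(t)

    def commit(self, tag, save_dir):
        # Update the tracking file only after every outstanding flush has
        # finished, so readers never see a 'latest' tag pointing at a
        # partially written checkpoint.
        for t in self._pending:
            t.join()
        self._pending.clear()
        with open(os.path.join(save_dir, "latest"), "w") as f:
            f.write(tag)
        return True
```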
@tjruwase Thanks for the feedback. I've updated the PR as per our discussion and moved the tensor-debloating logic inside the checkpointing engines.
@mauryaavinash95, can you please look into the CI failures? Also, it seems we are unable to update the branch.
@tjruwase Thanks for letting me know.
@mauryaavinash95 - is this ready to be merged?
@loadams: I think it is ready to be merged. The one pending thing we have is [...].
@mauryaavinash95 apologies for the delay on this. Since the FastPersist PR has been merged, do you want to resume this integration? Thanks!
We are a team at Argonne National Laboratory working on low-overhead asynchronous checkpointing approaches for LLMs and transformers. As part of these efforts, we have developed DataStates-LLM, a library that we would like to contribute to the DeepSpeed community:
https://github.com/datastates/datastates-llm
The key idea we leverage is to allow non-blocking tensor copies from the GPU to the host during the forward and backward pass. We block only if these copies have not finished by the start of the update phase. Meanwhile, the tensors are flushed asynchronously from host memory to durable storage (parallel file systems, local SSDs, etc.).
To enable this capability, our initial implementation makes the scheduler aware of checkpointing, calling a ckpt.wait() primitive before starting the update phase. We illustrated this with the pipeline scheduler. We are also considering a scheduler-independent solution that integrates with DeepSpeed/Megatron and provides a hook for the start of the update phase, which we can leverage to run ckpt.wait().
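To make the timing concrete, here is an engine-agnostic sketch of the overlap pattern described above; the class and method names are illustrative, not the DataStates-LLM API. Tensors are copied GPU-to-host on a side CUDA stream during forward/backward, the training loop waits on those copies just before the update phase, and a background thread flushes the host copies to storage.

```python
# Illustrative sketch of the overlap pattern described above (not the
# DataStates-LLM API): non-blocking GPU-to-host copies on a side CUDA stream
# during forward/backward, a wait() barrier before the update phase, and an
# asynchronous flush from host memory to storage. Requires a GPU.
import threading
import torch


class AsyncSnapshotSketch:

    def __init__(self):
        self.copy_stream = torch.cuda.Stream()
        self.host_copies = {}

    def snapshot(self, name, tensor):
        # Make the side stream wait for work already queued on the default
        # stream, then issue a non-blocking copy into pinned host memory so
        # compute is not stalled. (A real engine must also guard against the
        # source tensor being overwritten before the copy completes.)
        host_buf = torch.empty(tensor.shape, dtype=tensor.dtype,
                               device="cpu", pin_memory=True)
        self.copy_stream.wait_stream(torch.cuda.current_stream())
        with torch.cuda.stream(self.copy_stream):
            host_buf.copy_(tensor, non_blocking=True)
        self.host_copies[name] = host_buf

    def wait(self):
        # Called just before the update phase: block only if the copies
        # issued during forward/backward have not finished yet.
        self.copy_stream.synchronize()

    def flush(self, path):
        # Flush the host-side copies to durable storage in the background
        # while training proceeds with the next iteration.
        snapshot = dict(self.host_copies)
        threading.Thread(target=torch.save, args=(snapshot, path)).start()
```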
We appreciate your feedback and look forward to a collaboration in this space.