[float] document e2e training -> inference flow #2190
Conversation
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/2190

Note: Links to docs will display an error until the docs builds have been completed. ✅ No Failures as of commit 723088a with merge base cdced21. This comment was automatically generated by Dr. CI and updates every 15 minutes.
cc @andrewor14
andrewor14 left a comment
Tested this out locally too, works for me. Thanks!
(Might need to add a distributed checkpoint section, but we can do that in a separate PR)
# save the model
torch.save({
    'model': m,
In practice the model would live in some modeling file, and the training code and inference code would both import it separately, in order to avoid deserializing the Python model definition with torch.load(..., weights_only=False), which has some security risks.
However, I was aiming for these to be copy/paste-able, runnable standalone examples, which seemed to require this bad practice. Thoughts @andrewor14 @vkuzo?
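For context, the standalone pattern being discussed might look like the following (a minimal sketch, not the PR's actual example; the model and filename are made up). Because the whole module object is pickled, loading it back requires `weights_only=False`, which unpickles arbitrary Python objects and should only be used with checkpoints you trust:

```python
import torch
import torch.nn as nn

# toy model standing in for the real one; in practice this would be
# the quantized/converted model from the training flow
m = nn.Sequential(nn.Linear(8, 8), nn.ReLU(), nn.Linear(8, 1))

# save the model object itself (not just the state_dict), so a single
# standalone script can reload it without re-importing the definition
torch.save({"model": m}, "checkpoint.pt")

# inference side: deserializes the pickled module definition
# (weights_only=False is the security-relevant part)
loaded = torch.load("checkpoint.pt", weights_only=False)
m2 = loaded["model"]
m2.eval()
with torch.no_grad():
    out = m2(torch.randn(2, 8))
```

This is what makes the example self-contained at the cost of the security caveat mentioned above.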
I think a good way to do it is as follows:
1. create a reproducible model definition
2. create a new instance of (1), train it, and save the weights to a checkpoint
3. create a new instance of (1), load the weights from the checkpoint, and fine-tune it or run inference

There is no need to save the model definition with torch.save in the flow above.
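A minimal sketch of this three-step flow (model and file names are illustrative, not from the PR). Only the state_dict is serialized, so the inference side can load with the safe `weights_only=True` default:

```python
import torch
import torch.nn as nn

def build_model() -> nn.Module:
    # (1) reproducible model definition, importable by both the
    # training and inference scripts
    return nn.Sequential(nn.Linear(8, 8), nn.ReLU(), nn.Linear(8, 1))

# (2) training side: new instance, train, save weights only
m_train = build_model()
# ... training loop would go here ...
torch.save(m_train.state_dict(), "weights.pt")

# (3) inference side: fresh instance, load weights from the checkpoint
m_infer = build_model()
m_infer.load_state_dict(torch.load("weights.pt", weights_only=True))
m_infer.eval()
```

No pickled model definition ever touches disk, so there is no `weights_only=False` security concern.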
That's what I originally tried, actually, but it doesn't work because the weights in the serialized/checkpointed model from step (2) are registered under different names (prefixed with _orig_mod, since the model was wrapped by torch.compile) than in the freshly initialized model in step (3).
I solved this by saving the converted model definition directly with torch.save and loading the model state dict into that, but it's not ideal IMO. I'm curious how torchtitan/torchtune handle this as well.
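A hedged sketch of the name mismatch and one possible workaround (not necessarily what the PR ended up doing): `torch.compile` wraps the module in an `OptimizedModule`, so its state_dict keys typically gain an `_orig_mod.` prefix that a freshly built, uncompiled instance does not have. `consume_prefix_in_state_dict_if_present` can strip the prefix before loading:

```python
import torch
import torch.nn as nn
from torch.nn.modules.utils import consume_prefix_in_state_dict_if_present

base = nn.Linear(4, 4)
compiled = torch.compile(base)

# keys here are typically "_orig_mod.weight", "_orig_mod.bias"
sd = compiled.state_dict()

# strip the "_orig_mod." prefix in place so the keys match a plain
# (uncompiled) instance of the same model definition; this is a no-op
# if the prefix is absent
consume_prefix_in_state_dict_if_present(sd, "_orig_mod.")

fresh = nn.Linear(4, 4)
fresh.load_state_dict(sd)  # now loads cleanly
```

Another common alternative is to save `compiled._orig_mod.state_dict()` directly, which avoids the prefix in the first place.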
OK, going to merge this for now; we can discuss alternatives async if you want.
* document e2e training -> inference flow
* add save/load checkpoint
* update to how we load checkpoint
* remove debugging
* add more detail
* remove unused import
* lower lr to prevent large optimizer step into weight territory which produces inf
* use actual loss function
Summary
Document the E2E training => inference flow with examples.