-
-
Couldn't load subscription status.
- Fork 10.8k
[XPU] Fix xpu model runner call torch.cuda APIs #25011
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[XPU] Fix xpu model runner call torch.cuda APIs #25011
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Code Review
This pull request aims to fix an issue for XPU devices by patching torch.cuda APIs to forward to torch.xpu. While the intention is correct, the implementation has a critical flaw: it performs monkey-patching on the global torch.cuda module without properly restoring the original attributes. This can lead to global state corruption and unpredictable behavior in other parts of the code. My review includes a suggestion to implement the patching in a safe, contained manner using a try...finally block that saves and restores the original attributes.
Signed-off-by: Kunshang Ji <[email protected]>
Signed-off-by: Kunshang Ji <[email protected]>
Signed-off-by: Kunshang Ji <[email protected]> Signed-off-by: charlifu <[email protected]>
Signed-off-by: Kunshang Ji <[email protected]> Signed-off-by: xuebwang-amd <[email protected]>
Signed-off-by: Kunshang Ji <[email protected]>
Signed-off-by: Kunshang Ji <[email protected]> Signed-off-by: xuebwang-amd <[email protected]>
Purpose
#23693 introduce torch.cuda.Stream in GPUModelRunner, this breaks xpu behavior. this PR fix this issue by forward
torch.cudaAPI totorch.xpu.Test Plan
CI
Test Result
Essential Elements of an Effective PR Description Checklist
supported_models.mdandexamplesfor a new model.