-
Notifications
You must be signed in to change notification settings - Fork 3.7k
[Relax][PyTorch] Fix output shape of torch.nn.functional.scaled_dot_product_attention
#17379
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
torch.nn.functional.scaled_dot_product_attention
|
we can transpose to get the expected result. Thanks for the effort! |
yongwww
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
overall looks good to me
456c72e to
185d28c
Compare
|
MSC E2E test is failing. Seems like we also need to change something other than relax frontend. Link to the ci log: https://ci.tlcpack.ai/blue/organizations/jenkins/tvm-unity/detail/PR-17379/6/pipeline/ tests/python/contrib/test_msc/test_translate_torch.py::test_attention FAILED
|
185d28c to
43268e1
Compare
43268e1 to
a783823
Compare
a783823 to
a2b29c0
Compare
torch.nn.functional.scaled_dot_product_attention outputs in the shape of
(N, ..., L, E_v)butrelax.op.nn.attentiondoes(N, L, ..., E_v)so the output should also be transposed.Maybe we should add E2E tests in
tests/python/nightly/to check the relax torch frontend.cc: @yongwww