Conversation

@mshr-h (Contributor) commented Sep 16, 2024

torch.nn.functional.scaled_dot_product_attention produces output of shape (N, ..., L, E_v), whereas relax.op.nn.attention produces (N, L, ..., E_v), so the converter also needs to transpose the output back to the PyTorch layout.
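
For illustration, here is a minimal sketch of the layout issue at the Relax op level (this is not the PR's actual converter code; it only uses the public relax.op APIs). relax.op.nn.attention works in BSNH layout, i.e. (N, L, H, E), while torch SDPA works in (N, H, L, E), so the inputs are permuted on the way in and the output has to be permuted back:

```python
# Sketch only: shows the layout mismatch and the fix, not the frontend's real code.
from tvm import relax

N, H, L, E = 32, 8, 128, 64
q, k, v = (
    relax.Var(name, relax.TensorStructInfo((N, H, L, E), "float32"))
    for name in ("q", "k", "v")
)

def to_bsnh(x):
    # torch layout (N, H, L, E) -> relax attention layout (N, L, H, E)
    return relax.op.permute_dims(x, [0, 2, 1, 3])

out = relax.op.nn.attention(to_bsnh(q), to_bsnh(k), to_bsnh(v))
# nn.attention yields (N, L, H, E); transpose back to torch's (N, H, L, E)
out = relax.op.permute_dims(out, [0, 2, 1, 3])
```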

Maybe we should add E2E tests under tests/python/nightly/ to check the Relax PyTorch frontend numerically against PyTorch.
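
Something like the following could be a starting point. This is a hedged sketch, not part of this PR: the entry point from_fx, the LLVM target, the explicit LegalizeOps call, and the test location are assumptions that may need adjusting.

```python
# Sketch of a possible nightly E2E test: import with the Relax PyTorch frontend,
# compile with the Relax VM, and compare against torch's own SDPA result.
import torch
import torch.nn.functional as F
from torch import fx

import tvm
import tvm.testing
from tvm import relax
from tvm.relax.frontend.torch import from_fx


class SDPA(torch.nn.Module):
    def forward(self, q, k, v):
        return F.scaled_dot_product_attention(q, k, v)


def test_sdpa_e2e():
    shape, dtype = (32, 8, 128, 64), "float32"
    inputs = [torch.rand(shape) for _ in range(3)]
    expected = SDPA()(*inputs).numpy()

    graph_module = fx.symbolic_trace(SDPA())
    mod = from_fx(graph_module, [(shape, dtype)] * 3)
    mod = relax.transform.LegalizeOps()(mod)

    ex = relax.build(mod, target="llvm")
    vm = relax.VirtualMachine(ex, tvm.cpu())
    out = vm["main"](*[tvm.nd.array(x.numpy()) for x in inputs])

    tvm.testing.assert_allclose(out.numpy(), expected, rtol=1e-5, atol=1e-5)
```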

cc: @yongwww

@mshr-h mshr-h marked this pull request as ready for review September 16, 2024 14:25
@mshr-h mshr-h changed the title Fix torch sdpa converter [Relax][PyTorch] Fix output shape of torch.nn.functional.scaled_dot_product_attention Sep 16, 2024
@mshr-h mshr-h marked this pull request as draft September 17, 2024 02:38
@yongwww (Member) commented Sep 17, 2024

We can transpose to get the expected result. Thanks for the effort!

@yongwww (Member) left a comment

Overall looks good to me.

@mshr-h mshr-h marked this pull request as ready for review September 17, 2024 04:37
@mshr-h mshr-h force-pushed the fix-torch-sdpa-converter branch 2 times, most recently from 456c72e to 185d28c Compare September 17, 2024 07:33
@mshr-h (Contributor, Author) commented Sep 17, 2024

The MSC E2E test is failing. It seems we also need to change something outside the Relax frontend.
@Archermmt Do you have any ideas on how to fix the error?

Link to the CI log: https://ci.tlcpack.ai/blue/organizations/jenkins/tvm-unity/detail/PR-17379/6/pipeline/

tests/python/contrib/test_msc/test_translate_torch.py::test_attention FAILED

=================================== FAILURES ===================================
________________________________ test_attention ________________________________

    def test_attention():
        """test torch translator for attention"""

        # pylint: disable=import-outside-toplevel
        import torch.nn.functional as F

        class Attention1(Module):
            def forward(self, q_data, k_data, v_data):
                return F.scaled_dot_product_attention(q_data, k_data, v_data)

        class Attention2(Module):
            def forward(self, q_data, k_data, v_data):
                return F.scaled_dot_product_attention(q_data, k_data, v_data, is_causal=True)

        input_info = [
            ([32, 8, 128, 64], "float32"),
            ([32, 8, 128, 64], "float32"),
            ([32, 8, 128, 64], "float32"),
        ]
>       verify_model(Attention1(), input_info)

tests/python/contrib/test_msc/test_translate_torch.py:1127:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
tests/python/contrib/test_msc/test_translate_torch.py:52: in verify_model
    tvm.testing.assert_allclose(
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _

actual = array([[[[0.5253085 , 0.48211107, 0.52921075, ..., 0.518867  ,
          0.49926636, 0.48493868],
         [0.5294311 ...5],
         [0.47335747, 0.48579183, 0.5360674 , ..., 0.543607  ,
          0.5020893 , 0.47848547]]]], dtype=float32)
desired = array([[[[0.5253085 , 0.48211107, 0.52921075, ..., 0.518867  ,
          0.49926636, 0.48493868],
         [0.49697113... ],
         [0.47335747, 0.48579183, 0.5360674 , ..., 0.543607  ,
          0.5020893 , 0.47848547]]]], dtype=float32)
rtol = 1e-05, atol = 1e-05

    def assert_allclose(actual, desired, rtol=1e-7, atol=1e-7):
        """Version of np.testing.assert_allclose with atol and rtol fields set
        in reasonable defaults.

        Arguments actual and desired are not interchangeable, since the function
        compares the abs(actual-desired) with atol+rtol*abs(desired). Since we
        often allow desired to be close to zero, we generally want non-zero atol.
        """
        actual = np.asanyarray(actual)
        desired = np.asanyarray(desired)
>       np.testing.assert_allclose(actual.shape, desired.shape)
E       AssertionError:
E       Not equal to tolerance rtol=1e-07, atol=0
E
E       Mismatched elements: 2 / 4 (50%)
E       Max absolute difference: 120
E       Max relative difference: 15.
E       x: array([ 32,   8, 128,  64])
E       y: array([ 32, 128,   8,  64])

python/tvm/testing/utils.py:119: AssertionError
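
As a sanity check on the shapes in the traceback above (this snippet is illustrative, not from the CI log): torch's SDPA preserves the (N, H, L, E) layout of its inputs, so (32, 8, 128, 64) is the reference shape, and the (32, 128, 8, 64) side of the mismatch must come from the translated graph.

```python
# Quick torch-only check of the reference output shape for the failing test:
# SDPA keeps the (N, H, L, E) layout of its inputs.
import torch
import torch.nn.functional as F

q = k = v = torch.rand(32, 8, 128, 64)
out = F.scaled_dot_product_attention(q, k, v)
print(out.shape)  # torch.Size([32, 8, 128, 64])
```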

@mshr-h mshr-h force-pushed the fix-torch-sdpa-converter branch from 185d28c to 43268e1 Compare September 17, 2024 14:50
@mshr-h mshr-h force-pushed the fix-torch-sdpa-converter branch from 43268e1 to a783823 Compare September 19, 2024 05:25
@mshr-h mshr-h force-pushed the fix-torch-sdpa-converter branch from a783823 to a2b29c0 Compare September 19, 2024 09:06
@yongwww yongwww merged commit 85f2cc3 into apache:main Sep 20, 2024
@mshr-h mshr-h deleted the fix-torch-sdpa-converter branch September 20, 2024 04:30