
Commit 5a04839

fix: TS test_scaled_dot_product_attention (#3117)
1 parent ad1ae8a commit 5a04839
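
Background: newer PyTorch releases added a trailing enable_gqa argument to
aten::scaled_dot_product_attention, so any lowering pattern or test graph that
spells out the full argument list must now pass eight inputs. The upstream
schema is approximately:

```
aten::scaled_dot_product_attention(Tensor query, Tensor key, Tensor value,
    Tensor? attn_mask=None, float dropout_p=0.0, bool is_causal=False, *,
    float? scale=None, bool enable_gqa=False) -> Tensor
```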

File tree

  core/lowering/passes/unpack_scaled_dot_product_attention.cpp
  tests/core/conversion/converters/test_scaled_dot_product_attention.cpp

2 files changed: +27 -9 lines changed

core/lowering/passes/unpack_scaled_dot_product_attention.cpp

Lines changed: 19 additions & 4 deletions
@@ -12,12 +12,12 @@ namespace passes {
 // https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html
 void UnpackScaledDotProductAttention(std::shared_ptr<torch::jit::Graph>& graph) {
   std::string sdpa_pattern = R"IR(
-    graph(%query, %key, %value, %attn_mask, %dropout_p, %is_causal, %scale):
-      %out: Tensor = aten::scaled_dot_product_attention(%query, %key, %value, %attn_mask, %dropout_p, %is_causal, %scale)
+    graph(%query, %key, %value, %attn_mask, %dropout_p, %is_causal, %scale, %enable_gqa):
+      %out: Tensor = aten::scaled_dot_product_attention(%query, %key, %value, %attn_mask, %dropout_p, %is_causal, %scale, %enable_gqa)
       return (%out))IR";
 
   std::string unpacked_sdpa_pattern = R"IR(
-    graph(%query, %key, %value, %attn_mask, %dropout_p, %is_causal, %scale):
+    graph(%query, %key, %value, %attn_mask, %dropout_p, %is_causal, %scale, %enable_gqa):
       %none : NoneType = prim::Constant()
       %1 : int = prim::Constant[value=-1]()
       %2 : int = prim::Constant[value=-2]()
@@ -33,7 +33,7 @@ void UnpackScaledDotProductAttention(std::shared_ptr<torch::jit::Graph>& graph)
       return(%out))IR";
 
   std::string unpacked_sdpa_attn_biased_pattern = R"IR(
-    graph(%query, %key, %value, %attn_mask, %dropout_p, %is_causal, %scale):
+    graph(%query, %key, %value, %attn_mask, %dropout_p, %is_causal, %scale, %enable_gqa):
       %none : NoneType = prim::Constant()
       %0 : int = prim::Constant[value=1]()
       %1 : int = prim::Constant[value=-1]()
@@ -69,6 +69,16 @@ void UnpackScaledDotProductAttention(std::shared_ptr<torch::jit::Graph>& graph)
     if (attn_mask_node->kind() != at::prim::Constant || !attn_mask_node->mustBeNone()) {
       return false;
     }
+    auto enable_gqa_node = match.anchor->inputs().at(7)->node();
+    if (enable_gqa_node->kind() != at::prim::Constant) {
+      LOG_WARNING(
+          "Could not unpack scaled_dot_product_attention with non constant enable_gqa: " << *enable_gqa_node);
+      return false;
+    }
+    if (enable_gqa_node->i(at::attr::value) == 1) {
+      LOG_WARNING("Could not unpack scaled_dot_product_attention with enable_gqa = True: " << *enable_gqa_node);
+      return false;
+    }
     return true;
   });
 
@@ -83,6 +93,11 @@ void UnpackScaledDotProductAttention(std::shared_ptr<torch::jit::Graph>& graph)
       // messages already written in first pass, do not write again
       return false;
     }
+    auto enable_gqa_node = match.anchor->inputs().at(7)->node();
+    if (enable_gqa_node->kind() != at::prim::Constant || enable_gqa_node->i(at::attr::value) == 1) {
+      // messages already written in first pass, do not write again
+      return false;
+    }
     return true;
   });
   LOG_GRAPH("Post unpack scaled_dot_product_attention: " << *graph);

tests/core/conversion/converters/test_scaled_dot_product_attention.cpp

Lines changed: 8 additions & 5 deletions
@@ -11,8 +11,9 @@ TEST(Converters, ATenScaledDotProductAttentionConvertsCorrectly) {
       %none : NoneType = prim::Constant()
       %0 : float = prim::Constant[value=0.]()
       %scale : NoneType = prim::Constant()
+      %enable_gqa : bool = prim::Constant[value=0]()
       %false : bool = prim::Constant[value=0]()
-      %3 : Tensor = aten::scaled_dot_product_attention(%query, %key, %value, %none, %0, %false, %scale)
+      %3 : Tensor = aten::scaled_dot_product_attention(%query, %key, %value, %none, %0, %false, %scale, %enable_gqa)
       return (%3))IR";
 
   auto g = std::make_shared<torch::jit::Graph>();
@@ -38,7 +39,8 @@ TEST(Converters, ATenScaledDotProductAttnMaskFloatConvertsCorrectly) {
       %0 : float = prim::Constant[value=0.]()
       %false : bool = prim::Constant[value=0]()
       %scale : NoneType = prim::Constant()
-      %3 : Tensor = aten::scaled_dot_product_attention(%query, %key, %value, %attn_mask, %0, %false, %scale)
+      %enable_gqa : bool = prim::Constant[value=0]()
+      %3 : Tensor = aten::scaled_dot_product_attention(%query, %key, %value, %attn_mask, %0, %false, %scale, %enable_gqa)
       return (%3))IR";
 
   auto g = std::make_shared<torch::jit::Graph>();
@@ -59,13 +61,14 @@ TEST(Converters, ATenScaledDotProductAttnMaskFloatConvertsCorrectly) {
   ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0]));
 }
 
-TEST(Converters, ATenScaledDotProductAttnMaskBoolConvertsCorrectly) {
+TEST(Converters, ATenScaledDotProductAttnMaskIntConvertsCorrectly) {
   const auto graph = R"IR(
     graph(%query : Tensor, %key : Tensor, %value : Tensor, %attn_mask : Tensor):
       %0 : float = prim::Constant[value=0.]()
       %false : bool = prim::Constant[value=0]()
       %scale : NoneType = prim::Constant()
-      %3 : Tensor = aten::scaled_dot_product_attention(%query, %key, %value, %attn_mask, %0, %false, %scale)
+      %enable_gqa : bool = prim::Constant[value=0]()
+      %3 : Tensor = aten::scaled_dot_product_attention(%query, %key, %value, %attn_mask, %0, %false, %scale, %enable_gqa)
       return (%3))IR";
 
   auto g = std::make_shared<torch::jit::Graph>();
@@ -74,7 +77,7 @@ TEST(Converters, ATenScaledDotProductAttnMaskBoolConvertsCorrectly) {
   auto query = at::rand({32, 8, 128, 64}, {at::kCUDA});
   auto key = at::rand({32, 8, 128, 64}, {at::kCUDA});
   auto value = at::rand({32, 8, 128, 64}, {at::kCUDA});
-  auto attn_mask = at::randint(0, 2, {32, 8, 128, 128}, at::kCUDA).to(at::kBool);
+  auto attn_mask = at::randint(0, 2, {32, 8, 128, 128}, {at::kCUDA});
   auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
   auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {query, key, value, attn_mask});
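
For reference, in the configuration these tests pin down (zero dropout_p,
is_causal false, default scale, enable_gqa false), the unpacked pattern reduces
SDPA to plain matmul/softmax arithmetic. A minimal ATen sketch of that
computation for the no-mask case, illustrative rather than the converter's
actual code path:

```cpp
#include <ATen/ATen.h>
#include <cmath>

// out = softmax(Q K^T / sqrt(E)) V, where E is the head (last) dimension of Q.
at::Tensor unpacked_sdpa(const at::Tensor& query, const at::Tensor& key, const at::Tensor& value) {
  double scale = 1.0 / std::sqrt(static_cast<double>(query.size(-1)));
  at::Tensor scores = at::matmul(query, key.transpose(-2, -1)) * scale;
  return at::matmul(at::softmax(scores, -1), value);
}
```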
