@@ -11,8 +11,9 @@ TEST(Converters, ATenScaledDotProductAttentionConvertsCorrectly) {
        %none : NoneType = prim::Constant()
        %0 : float = prim::Constant[value=0.]()
        %scale : NoneType = prim::Constant()
+       %enable_gqa : bool = prim::Constant[value=0]()
        %false : bool = prim::Constant[value=0]()
-       %3 : Tensor = aten::scaled_dot_product_attention(%query, %key, %value, %none, %0, %false, %scale)
+       %3 : Tensor = aten::scaled_dot_product_attention(%query, %key, %value, %none, %0, %false, %scale, %enable_gqa)
        return (%3))IR";

  auto g = std::make_shared<torch::jit::Graph>();
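For context, the new trailing argument mirrors the `enable_gqa` flag that recent PyTorch releases added to the op's schema (`scaled_dot_product_attention(..., *, float? scale=None, bool enable_gqa=False)`). Below is a minimal eager-mode sketch of the equivalent call; it assumes a libtorch build whose generated `at::scaled_dot_product_attention` overload already exposes the flag, and the shapes are purely illustrative:

```cpp
#include <torch/torch.h>
#include <optional>

int main() {
  // Illustrative shapes: (batch, num_heads, seq_len, head_dim).
  auto query = torch::rand({32, 8, 128, 64}, torch::kCUDA);
  auto key   = torch::rand({32, 8, 128, 64}, torch::kCUDA);
  auto value = torch::rand({32, 8, 128, 64}, torch::kCUDA);

  // Same positional order as the updated IR call above:
  // (query, key, value, attn_mask, dropout_p, is_causal, scale, enable_gqa).
  auto out = at::scaled_dot_product_attention(
      query, key, value,
      /*attn_mask=*/std::nullopt,
      /*dropout_p=*/0.0,
      /*is_causal=*/false,
      /*scale=*/std::nullopt,
      /*enable_gqa=*/false);
  return 0;
}
```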
@@ -38,7 +39,8 @@ TEST(Converters, ATenScaledDotProductAttnMaskFloatConvertsCorrectly) {
        %0 : float = prim::Constant[value=0.]()
        %false : bool = prim::Constant[value=0]()
        %scale : NoneType = prim::Constant()
-       %3 : Tensor = aten::scaled_dot_product_attention(%query, %key, %value, %attn_mask, %0, %false, %scale)
+       %enable_gqa : bool = prim::Constant[value=0]()
+       %3 : Tensor = aten::scaled_dot_product_attention(%query, %key, %value, %attn_mask, %0, %false, %scale, %enable_gqa)
        return (%3))IR";

  auto g = std::make_shared<torch::jit::Graph>();
@@ -59,13 +61,14 @@ TEST(Converters, ATenScaledDotProductAttnMaskFloatConvertsCorrectly) {
  ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0]));
 }

-TEST(Converters, ATenScaledDotProductAttnMaskBoolConvertsCorrectly) {
+TEST(Converters, ATenScaledDotProductAttnMaskIntConvertsCorrectly) {
  const auto graph = R"IR(
    graph(%query : Tensor, %key : Tensor, %value : Tensor, %attn_mask : Tensor):
        %0 : float = prim::Constant[value=0.]()
        %false : bool = prim::Constant[value=0]()
        %scale : NoneType = prim::Constant()
-       %3 : Tensor = aten::scaled_dot_product_attention(%query, %key, %value, %attn_mask, %0, %false, %scale)
+       %enable_gqa : bool = prim::Constant[value=0]()
+       %3 : Tensor = aten::scaled_dot_product_attention(%query, %key, %value, %attn_mask, %0, %false, %scale, %enable_gqa)
        return (%3))IR";

  auto g = std::make_shared<torch::jit::Graph>();
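The hunk context ends right after the graph is allocated; in these tests the raw IR string is then parsed into it before anything runs. A sketch of that step, assuming the usual `torch::jit::parseIR` helper these converter tests rely on:

```cpp
#include <torch/csrc/jit/ir/irparser.h>

// Parse the raw-string IR into the freshly allocated graph so the same
// graph can be run through both the TorchScript interpreter and the
// TensorRT converter.
torch::jit::parseIR(graph, g.get());
```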
@@ -74,7 +77,7 @@ TEST(Converters, ATenScaledDotProductAttnMaskBoolConvertsCorrectly) {
  auto query = at::rand({32, 8, 128, 64}, {at::kCUDA});
  auto key = at::rand({32, 8, 128, 64}, {at::kCUDA});
  auto value = at::rand({32, 8, 128, 64}, {at::kCUDA});
- auto attn_mask = at::randint(0, 2, {32, 8, 128, 128}, at::kCUDA).to(at::kBool);
+ auto attn_mask = at::randint(0, 2, {32, 8, 128, 128}, {at::kCUDA});
  auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
  auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {query, key, value, attn_mask});
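The diff cuts off before each test's closing assertions; per the `ASSERT_TRUE` line visible in the third hunk, the pattern is to compare the interpreter output against the converted TensorRT engine. A sketch of that tail, reusing the test-util helpers already called in this file (the `RunGraphEngine` counterpart is assumed from the surrounding test suite):

```cpp
// Convert the graph to a TensorRT engine, run it on the same inputs,
// and require numerical agreement with the TorchScript results.
auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {query, key, value, attn_mask});
ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0]));
```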