23 changes: 19 additions & 4 deletions core/lowering/passes/unpack_scaled_dot_product_attention.cpp
@@ -12,12 +12,12 @@ namespace passes {
// https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html
void UnpackScaledDotProductAttention(std::shared_ptr<torch::jit::Graph>& graph) {
std::string sdpa_pattern = R"IR(
- graph(%query, %key, %value, %attn_mask, %dropout_p, %is_causal, %scale):
- %out: Tensor = aten::scaled_dot_product_attention(%query, %key, %value, %attn_mask, %dropout_p, %is_causal, %scale)
+ graph(%query, %key, %value, %attn_mask, %dropout_p, %is_causal, %scale, %enable_gqa):
+ %out: Tensor = aten::scaled_dot_product_attention(%query, %key, %value, %attn_mask, %dropout_p, %is_causal, %scale, %enable_gqa)
return (%out))IR";

std::string unpacked_sdpa_pattern = R"IR(
- graph(%query, %key, %value, %attn_mask, %dropout_p, %is_causal, %scale):
+ graph(%query, %key, %value, %attn_mask, %dropout_p, %is_causal, %scale, %enable_gqa):
%none : NoneType = prim::Constant()
%1 : int = prim::Constant[value=-1]()
%2 : int = prim::Constant[value=-2]()
@@ -33,7 +33,7 @@ void UnpackScaledDotProductAttention(std::shared_ptr<torch::jit::Graph>& graph)
return(%out))IR";

std::string unpacked_sdpa_attn_biased_pattern = R"IR(
- graph(%query, %key, %value, %attn_mask, %dropout_p, %is_causal, %scale):
+ graph(%query, %key, %value, %attn_mask, %dropout_p, %is_causal, %scale, %enable_gqa):
%none : NoneType = prim::Constant()
%0 : int = prim::Constant[value=1]()
%1 : int = prim::Constant[value=-1]()
@@ -69,6 +69,16 @@ void UnpackScaledDotProductAttention(std::shared_ptr<torch::jit::Graph>& graph)
if (attn_mask_node->kind() != at::prim::Constant || !attn_mask_node->mustBeNone()) {
return false;
}
+ auto enable_gqa_node = match.anchor->inputs().at(7)->node();
+ if (enable_gqa_node->kind() != at::prim::Constant) {
+ LOG_WARNING(
+ "Could not unpack scaled_dot_product_attention with non constant enable_gqa: " << *enable_gqa_node);
+ return false;
+ }
+ if (enable_gqa_node->i(at::attr::value) == 1) {
+ LOG_WARNING("Could not unpack scaled_dot_product_attention with enable_gqa = True: " << *enable_gqa_node);
+ return false;
+ }
return true;
});

@@ -83,6 +93,11 @@ void UnpackScaledDotProductAttention(std::shared_ptr<torch::jit::Graph>& graph)
// messages already written in first pass, do not write again
return false;
}
+ auto enable_gqa_node = match.anchor->inputs().at(7)->node();
+ if (enable_gqa_node->kind() != at::prim::Constant || enable_gqa_node->i(at::attr::value) == 1) {
+ // messages already written in first pass, do not write again
+ return false;
+ }
return true;
});
LOG_GRAPH("Post unpack scaled_dot_product_attention: " << *graph);
@@ -11,8 +11,9 @@ TEST(Converters, ATenScaledDotProductAttentionConvertsCorrectly) {
%none : NoneType = prim::Constant()
%0 : float = prim::Constant[value=0.]()
%scale : NoneType = prim::Constant()
+ %enable_gqa : bool = prim::Constant[value=0]()
%false : bool = prim::Constant[value=0]()
- %3 : Tensor = aten::scaled_dot_product_attention(%query, %key, %value, %none, %0, %false, %scale)
+ %3 : Tensor = aten::scaled_dot_product_attention(%query, %key, %value, %none, %0, %false, %scale, %enable_gqa)
return (%3))IR";

auto g = std::make_shared<torch::jit::Graph>();
@@ -38,7 +39,8 @@ TEST(Converters, ATenScaledDotProductAttnMaskFloatConvertsCorrectly) {
%0 : float = prim::Constant[value=0.]()
%false : bool = prim::Constant[value=0]()
%scale : NoneType = prim::Constant()
- %3 : Tensor = aten::scaled_dot_product_attention(%query, %key, %value, %attn_mask, %0, %false, %scale)
+ %enable_gqa : bool = prim::Constant[value=0]()
+ %3 : Tensor = aten::scaled_dot_product_attention(%query, %key, %value, %attn_mask, %0, %false, %scale, %enable_gqa)
return (%3))IR";

auto g = std::make_shared<torch::jit::Graph>();
@@ -59,13 +61,14 @@ TEST(Converters, ATenScaledDotProductAttnMaskFloatConvertsCorrectly) {
ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0]));
}

- TEST(Converters, ATenScaledDotProductAttnMaskBoolConvertsCorrectly) {
+ TEST(Converters, ATenScaledDotProductAttnMaskIntConvertsCorrectly) {
const auto graph = R"IR(
graph(%query : Tensor, %key : Tensor, %value : Tensor, %attn_mask : Tensor):
%0 : float = prim::Constant[value=0.]()
%false : bool = prim::Constant[value=0]()
%scale : NoneType = prim::Constant()
- %3 : Tensor = aten::scaled_dot_product_attention(%query, %key, %value, %attn_mask, %0, %false, %scale)
+ %enable_gqa : bool = prim::Constant[value=0]()
+ %3 : Tensor = aten::scaled_dot_product_attention(%query, %key, %value, %attn_mask, %0, %false, %scale, %enable_gqa)
return (%3))IR";

auto g = std::make_shared<torch::jit::Graph>();
@@ -74,7 +77,7 @@ TEST(Converters, ATenScaledDotProductAttnMaskBoolConvertsCorrectly) {
auto query = at::rand({32, 8, 128, 64}, {at::kCUDA});
auto key = at::rand({32, 8, 128, 64}, {at::kCUDA});
auto value = at::rand({32, 8, 128, 64}, {at::kCUDA});
- auto attn_mask = at::randint(0, 2, {32, 8, 128, 128}, at::kCUDA).to(at::kBool);
+ auto attn_mask = at::randint(0, 2, {32, 8, 128, 128}, {at::kCUDA});
auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {query, key, value, attn_mask});

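The two mask tests exercise the attention-bias path; per the reference semantics in the PyTorch documentation, a non-boolean attn_mask is added to the scaled scores before the softmax, which is presumably what the unpacked_sdpa_attn_biased_pattern above encodes. A minimal sketch under the same assumptions and made-up naming as the sketch after the first file:

// Masked variant: the (non-boolean) attn_mask acts as an additive bias on
// the scaled scores; everything else matches the earlier no-mask sketch.
torch::Tensor unpacked_sdpa_biased_sketch(
    const torch::Tensor& query,
    const torch::Tensor& key,
    const torch::Tensor& value,
    const torch::Tensor& attn_mask) {
  const double scale = 1.0 / std::sqrt(static_cast<double>(query.size(-1)));
  auto scores = torch::matmul(query, key.transpose(-2, -1)) * scale + attn_mask;
  auto attn = torch::softmax(scores, /*dim=*/-1);
  return torch::matmul(attn, value);
}

The renamed test also stops casting the random mask to kBool, so it now covers an integer-typed mask with the new %enable_gqa constant set to false.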