From 71a92c3004f243dd1c5e6cad73bcc36d66d597b6 Mon Sep 17 00:00:00 2001 From: Ruoqian Guo Date: Thu, 12 Jan 2023 03:38:31 +0000 Subject: [PATCH 1/3] feat: lower aten::pad to aten::constant_pad_nd/aten::reflection_padXd/aten::replication_padXd Signed-off-by: Ruoqian Guo --- core/lowering/lowering.cpp | 1 + core/lowering/passes/BUILD | 1 + core/lowering/passes/passes.h | 1 + core/lowering/passes/replace_aten_pad.cpp | 139 +++++++++++ tests/core/lowering/BUILD | 5 + .../lowering/test_replace_aten_pad_pass.cpp | 224 ++++++++++++++++++ 6 files changed, 371 insertions(+) create mode 100644 core/lowering/passes/replace_aten_pad.cpp create mode 100644 tests/core/lowering/test_replace_aten_pad_pass.cpp diff --git a/core/lowering/lowering.cpp b/core/lowering/lowering.cpp index cf57e7c83c..b1406446f1 100644 --- a/core/lowering/lowering.cpp +++ b/core/lowering/lowering.cpp @@ -144,6 +144,7 @@ void LowerGraph(std::shared_ptr& g, std::vector& graph, std::str void UnpackAndCastNumToTensor(std::shared_ptr& graph, std::string target_device_name); void UnpackAndCastFull(std::shared_ptr& graph, std::string target_device_name); void ReplaceScalarImplicit(std::shared_ptr& graph); +void ReplaceAtenPad(std::shared_ptr& graph); // utility functions exposed for testing std::string unmangle_cls_name(const std::string& name); diff --git a/core/lowering/passes/replace_aten_pad.cpp b/core/lowering/passes/replace_aten_pad.cpp new file mode 100644 index 0000000000..d43e0d8613 --- /dev/null +++ b/core/lowering/passes/replace_aten_pad.cpp @@ -0,0 +1,139 @@ +#include + +#include "core/util/prelude.h" + +namespace torch_tensorrt { +namespace core { +namespace lowering { +namespace passes { + +void ReplaceAtenPad(std::shared_ptr& graph) { + for (auto it = graph->block()->nodes().begin(), end = graph->block()->nodes().end(); it != end; ++it) { + if (it->kind() == c10::Symbol::fromQualString("aten::pad")) { + // aten::pad(Tensor self, int[] pad, str mode='constant', float? value=None) -> (Tensor) + auto mode = it->inputs()[2]; + if(mode->type()->isSubtypeOf(c10::StringType::get())){ + std::string mode_str = torch::jit::toIValue(mode)->to(); + if(mode_str == "reflect") { + auto pad = it->inputs()[1]; + c10::List pad_list = torch::jit::toIValue(pad)->to>(); + if(pad_list.size() == 2) + { + // aten::reflection_pad1d(Tensor self, int[2] padding) -> (Tensor) + torch::jit::Node* new_node; + new_node = graph->create( + c10::Symbol::fromQualString("aten::reflection_pad1d"), + torch::jit::ArrayRef({it->inputs()[0], it->inputs()[1]}), + 1); + new_node->insertAfter(*it); + new_node->outputs()[0]->setType(c10::TensorType::get()); + it->outputs()[0]->replaceAllUsesWith(new_node->outputs()[0]); + auto pre = --it; + ++it; + it->destroy(); + it = pre; + } + else if(pad_list.size() == 4) + { + //aten::reflection_pad2d(Tensor self, int[4] padding) -> (Tensor) + torch::jit::Node* new_node; + new_node = graph->create( + c10::Symbol::fromQualString("aten::reflection_pad2d"), + torch::jit::ArrayRef({it->inputs()[0], it->inputs()[1]}), + 1); + new_node->insertAfter(*it); + new_node->outputs()[0]->setType(c10::TensorType::get()); + it->outputs()[0]->replaceAllUsesWith(new_node->outputs()[0]); + auto pre = --it; + ++it; + it->destroy(); + it = pre; + } + else if(pad_list.size() == 6) + { + LOG_ERROR("Torch-TRT doesn't support aten::reflection_pad3d currently."); + } + + } + else if(mode_str == "replicate") { + auto pad = it->inputs()[1]; + c10::List pad_list = torch::jit::toIValue(pad)->to>(); + if(pad_list.size() == 2) + { + // aten::replication_pad1d(Tensor self, int[2] padding) -> (Tensor) + torch::jit::Node* new_node; + new_node = graph->create( + c10::Symbol::fromQualString("aten::replication_pad1d"), + torch::jit::ArrayRef({it->inputs()[0], it->inputs()[1]}), + 1); + new_node->insertAfter(*it); + new_node->outputs()[0]->setType(c10::TensorType::get()); + it->outputs()[0]->replaceAllUsesWith(new_node->outputs()[0]); + auto pre = --it; + ++it; + it->destroy(); + it = pre; + } + else if(pad_list.size() == 4) + { + // aten::replication_pad2d(Tensor self, int[4] padding) -> (Tensor) + torch::jit::Node* new_node; + new_node = graph->create( + c10::Symbol::fromQualString("aten::replication_pad2d"), + torch::jit::ArrayRef({it->inputs()[0], it->inputs()[1]}), + 1); + new_node->insertAfter(*it); + new_node->outputs()[0]->setType(c10::TensorType::get()); + it->outputs()[0]->replaceAllUsesWith(new_node->outputs()[0]); + auto pre = --it; + ++it; + it->destroy(); + it = pre; + } + else if(pad_list.size() == 6) + { + // aten::replication_pad3d(Tensor self, int[6] padding) -> (Tensor) + torch::jit::Node* new_node; + new_node = graph->create( + c10::Symbol::fromQualString("aten::replication_pad3d"), + torch::jit::ArrayRef({it->inputs()[0], it->inputs()[1]}), + 1); + new_node->insertAfter(*it); + new_node->outputs()[0]->setType(c10::TensorType::get()); + it->outputs()[0]->replaceAllUsesWith(new_node->outputs()[0]); + auto pre = --it; + ++it; + it->destroy(); + it = pre; + } + + } + else if(mode_str == "constant") { + // aten::constant_pad_nd(Tensor self, int[] pad, Scalar value=0) -> (Tensor) + torch::jit::Node* new_node; + new_node = graph->create( + c10::Symbol::fromQualString("aten::constant_pad_nd"), + torch::jit::ArrayRef({it->inputs()[0], it->inputs()[1], it->inputs()[3]}), + 1); + new_node->insertAfter(*it); + new_node->outputs()[0]->setType(c10::TensorType::get()); + it->outputs()[0]->replaceAllUsesWith(new_node->outputs()[0]); + auto pre = --it; + ++it; + it->destroy(); + it = pre; + } else if(mode_str == "circular"){ + LOG_ERROR("Torch-TRT doesn't support circular padding currently."); + } + } + + } + + } + LOG_GRAPH("Post map aten::pad -> aten::constant_pad_nd/aten::reflection_padXd/aten::replication_padXd: " << *graph); +} + +} // namespace passes +} // namespace lowering +} // namespace core +} // namespace torch_tensorrt diff --git a/tests/core/lowering/BUILD b/tests/core/lowering/BUILD index 7f4e53d8a6..4fdcb27e1d 100644 --- a/tests/core/lowering/BUILD +++ b/tests/core/lowering/BUILD @@ -95,6 +95,10 @@ lowering_test( name = "test_rewrite_inputs_with_params", ) +lowering_test( + name = "test_replace_aten_pad_pass", +) + test_suite( name = "lowering_tests", tests = [ @@ -116,5 +120,6 @@ test_suite( ":test_unpack_hardswish", ":test_unpack_reduce_ops", ":test_view_to_reshape_pass", + ":test_replace_aten_pad_pass", ], ) diff --git a/tests/core/lowering/test_replace_aten_pad_pass.cpp b/tests/core/lowering/test_replace_aten_pad_pass.cpp new file mode 100644 index 0000000000..ddad5efb16 --- /dev/null +++ b/tests/core/lowering/test_replace_aten_pad_pass.cpp @@ -0,0 +1,224 @@ +#include +#include "core/compiler.h" +#include "core/lowering/passes/passes.h" +#include "gtest/gtest.h" +#include "tests/util/util.h" +#include "torch/csrc/jit/ir/irparser.h" +#include "torch/csrc/jit/ir/subgraph_matcher.h" + +TEST(LoweringPasses, AtenPadConstantCorrectly) { + const auto source_graph = R"IR( + graph(%0 : Tensor): + %2 : str = prim::Constant[value="constant"]() + %1 : int[] = prim::Constant[value=[2, 3, 2, 3]]() + %3 : float = prim::Constant[value=0.0]() + %4 : Tensor = aten::pad(%0, %1, %2, %3) + return (%4))IR"; + + const auto target_graph = R"IR( + graph(%0 : Tensor): + %1 : int[] = prim::Constant[value=[2, 3, 2, 3]]() + %2 : Scalar = prim::Constant[value=0.0]() + %3 : Tensor = aten::constant_pad_nd(%0, %1, %2) + return (%3))IR"; + + torch_tensorrt::core::util::logging::get_logger().set_reportable_log_level( + torch_tensorrt::core::util::logging::LogLevel::kGRAPH); + auto sg = std::make_shared(); + torch::jit::parseIR(source_graph, &*sg); + torch_tensorrt::core::lowering::passes::ReplaceAtenPad(sg); + + auto tg = std::make_shared(); + torch::jit::parseIR(target_graph, &*tg); + + auto in = at::randint(1, 10, {1, 3, 4, 5}, {at::kCUDA}); + + auto trt_in = at::clone(in); + auto params = torch_tensorrt::core::ir::get_static_params(sg->inputs(), {}); + auto trt_results_sg = torch_tensorrt::tests::util::RunGraphEngine(sg, params, {trt_in}); + + params = torch_tensorrt::core::ir::get_static_params(tg->inputs(), {}); + auto trt_results_tg = torch_tensorrt::tests::util::RunGraphEngine(tg, params, {trt_in}); + + ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(trt_results_sg[0], trt_results_tg[0], 2e-6)); +} + +TEST(LoweringPasses, AtenPadReflect1dCorrectly) { + const auto source_graph = R"IR( + graph(%0 : Tensor): + %2 : str = prim::Constant[value="reflect"]() + %1 : int[] = prim::Constant[value=[2, 3]]() + %3 : float = prim::Constant[value=0.0]() + %4 : Tensor = aten::pad(%0, %1, %2, %3) + return (%4))IR"; + + const auto target_graph = R"IR( + graph(%0 : Tensor): + %1 : int[] = prim::Constant[value=[2, 3]]() + %3 : Tensor = aten::reflection_pad1d(%0, %1) + return (%3))IR"; + + torch_tensorrt::core::util::logging::get_logger().set_reportable_log_level( + torch_tensorrt::core::util::logging::LogLevel::kGRAPH); + auto sg = std::make_shared(); + torch::jit::parseIR(source_graph, &*sg); + torch_tensorrt::core::lowering::passes::ReplaceAtenPad(sg); + + auto tg = std::make_shared(); + torch::jit::parseIR(target_graph, &*tg); + + auto in = at::randint(1, 10, {1, 3, 4}, {at::kCUDA}); + + auto trt_in = at::clone(in); + auto params = torch_tensorrt::core::ir::get_static_params(sg->inputs(), {}); + auto trt_results_sg = torch_tensorrt::tests::util::RunGraphEngine(sg, params, {trt_in}); + + params = torch_tensorrt::core::ir::get_static_params(tg->inputs(), {}); + auto trt_results_tg = torch_tensorrt::tests::util::RunGraphEngine(tg, params, {trt_in}); + + ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(trt_results_sg[0], trt_results_tg[0], 2e-6)); +} + +TEST(LoweringPasses, AtenPadReflect2dCorrectly) { + const auto source_graph = R"IR( + graph(%0 : Tensor): + %2 : str = prim::Constant[value="reflect"]() + %1 : int[] = prim::Constant[value=[2, 3, 2, 3]]() + %3 : float = prim::Constant[value=0.0]() + %4 : Tensor = aten::pad(%0, %1, %2, %3) + return (%4))IR"; + + const auto target_graph = R"IR( + graph(%0 : Tensor): + %1 : int[] = prim::Constant[value=[2, 3, 2, 3]]() + %3 : Tensor = aten::reflection_pad2d(%0, %1) + return (%3))IR"; + + torch_tensorrt::core::util::logging::get_logger().set_reportable_log_level( + torch_tensorrt::core::util::logging::LogLevel::kGRAPH); + auto sg = std::make_shared(); + torch::jit::parseIR(source_graph, &*sg); + torch_tensorrt::core::lowering::passes::ReplaceAtenPad(sg); + + auto tg = std::make_shared(); + torch::jit::parseIR(target_graph, &*tg); + + auto in = at::randint(1, 10, {1, 3, 4, 5}, {at::kCUDA}); + + auto trt_in = at::clone(in); + auto params = torch_tensorrt::core::ir::get_static_params(sg->inputs(), {}); + auto trt_results_sg = torch_tensorrt::tests::util::RunGraphEngine(sg, params, {trt_in}); + + params = torch_tensorrt::core::ir::get_static_params(tg->inputs(), {}); + auto trt_results_tg = torch_tensorrt::tests::util::RunGraphEngine(tg, params, {trt_in}); + + ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(trt_results_sg[0], trt_results_tg[0], 2e-6)); +} + +TEST(LoweringPasses, AtenPadReplicate1dCorrectly) { + const auto source_graph = R"IR( + graph(%0 : Tensor): + %2 : str = prim::Constant[value="replicate"]() + %1 : int[] = prim::Constant[value=[2, 3]]() + %3 : float = prim::Constant[value=0.0]() + %4 : Tensor = aten::pad(%0, %1, %2, %3) + return (%4))IR"; + + const auto target_graph = R"IR( + graph(%0 : Tensor): + %1 : int[] = prim::Constant[value=[2, 3]]() + %3 : Tensor = aten::replication_pad1d(%0, %1) + return (%3))IR"; + + torch_tensorrt::core::util::logging::get_logger().set_reportable_log_level( + torch_tensorrt::core::util::logging::LogLevel::kGRAPH); + auto sg = std::make_shared(); + torch::jit::parseIR(source_graph, &*sg); + torch_tensorrt::core::lowering::passes::ReplaceAtenPad(sg); + + auto tg = std::make_shared(); + torch::jit::parseIR(target_graph, &*tg); + + auto in = at::randint(1, 10, {1, 3, 4}, {at::kCUDA}); + + auto trt_in = at::clone(in); + auto params = torch_tensorrt::core::ir::get_static_params(sg->inputs(), {}); + auto trt_results_sg = torch_tensorrt::tests::util::RunGraphEngine(sg, params, {trt_in}); + + params = torch_tensorrt::core::ir::get_static_params(tg->inputs(), {}); + auto trt_results_tg = torch_tensorrt::tests::util::RunGraphEngine(tg, params, {trt_in}); + + ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(trt_results_sg[0], trt_results_tg[0], 2e-6)); +} + +TEST(LoweringPasses, AtenPadReplicate2dCorrectly) { + const auto source_graph = R"IR( + graph(%0 : Tensor): + %2 : str = prim::Constant[value="replicate"]() + %1 : int[] = prim::Constant[value=[2, 3, 2, 3]]() + %3 : float = prim::Constant[value=0.0]() + %4 : Tensor = aten::pad(%0, %1, %2, %3) + return (%4))IR"; + + const auto target_graph = R"IR( + graph(%0 : Tensor): + %1 : int[] = prim::Constant[value=[2, 3, 2, 3]]() + %3 : Tensor = aten::replication_pad2d(%0, %1) + return (%3))IR"; + + torch_tensorrt::core::util::logging::get_logger().set_reportable_log_level( + torch_tensorrt::core::util::logging::LogLevel::kGRAPH); + auto sg = std::make_shared(); + torch::jit::parseIR(source_graph, &*sg); + torch_tensorrt::core::lowering::passes::ReplaceAtenPad(sg); + + auto tg = std::make_shared(); + torch::jit::parseIR(target_graph, &*tg); + + auto in = at::randint(1, 10, {1, 3, 4, 5}, {at::kCUDA}); + + auto trt_in = at::clone(in); + auto params = torch_tensorrt::core::ir::get_static_params(sg->inputs(), {}); + auto trt_results_sg = torch_tensorrt::tests::util::RunGraphEngine(sg, params, {trt_in}); + + params = torch_tensorrt::core::ir::get_static_params(tg->inputs(), {}); + auto trt_results_tg = torch_tensorrt::tests::util::RunGraphEngine(tg, params, {trt_in}); + + ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(trt_results_sg[0], trt_results_tg[0], 2e-6)); +} + +TEST(LoweringPasses, AtenPadReplicate3dCorrectly) { + const auto source_graph = R"IR( + graph(%0 : Tensor): + %2 : str = prim::Constant[value="replicate"]() + %1 : int[] = prim::Constant[value=[2, 3, 2, 3, 1, 4]]() + %3 : float = prim::Constant[value=0.0]() + %4 : Tensor = aten::pad(%0, %1, %2, %3) + return (%4))IR"; + + const auto target_graph = R"IR( + graph(%0 : Tensor): + %1 : int[] = prim::Constant[value=[2, 3, 2, 3, 1, 4]]() + %3 : Tensor = aten::replication_pad3d(%0, %1) + return (%3))IR"; + + torch_tensorrt::core::util::logging::get_logger().set_reportable_log_level( + torch_tensorrt::core::util::logging::LogLevel::kGRAPH); + auto sg = std::make_shared(); + torch::jit::parseIR(source_graph, &*sg); + torch_tensorrt::core::lowering::passes::ReplaceAtenPad(sg); + + auto tg = std::make_shared(); + torch::jit::parseIR(target_graph, &*tg); + + auto in = at::randint(1, 10, {1, 3, 4, 5, 3}, {at::kCUDA}); + + auto trt_in = at::clone(in); + auto params = torch_tensorrt::core::ir::get_static_params(sg->inputs(), {}); + auto trt_results_sg = torch_tensorrt::tests::util::RunGraphEngine(sg, params, {trt_in}); + + params = torch_tensorrt::core::ir::get_static_params(tg->inputs(), {}); + auto trt_results_tg = torch_tensorrt::tests::util::RunGraphEngine(tg, params, {trt_in}); + + ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(trt_results_sg[0], trt_results_tg[0], 2e-6)); +} \ No newline at end of file From 3d5c8b61ae1d25084916f3bcd16bb05c300a9a6a Mon Sep 17 00:00:00 2001 From: Ruoqian Guo Date: Mon, 30 Jan 2023 06:58:21 +0000 Subject: [PATCH 2/3] style: run pre-commit Signed-off-by: Ruoqian Guo --- core/lowering/passes/BUILD | 2 +- core/lowering/passes/replace_aten_pad.cpp | 216 ++++++++++------------ tests/core/lowering/BUILD | 2 +- 3 files changed, 103 insertions(+), 117 deletions(-) diff --git a/core/lowering/passes/BUILD b/core/lowering/passes/BUILD index 73cd8bba72..498ce2ed55 100644 --- a/core/lowering/passes/BUILD +++ b/core/lowering/passes/BUILD @@ -28,6 +28,7 @@ cc_library( "remove_dropout.cpp", "remove_nops.cpp", "remove_unnecessary_casts.cpp", + "replace_aten_pad.cpp", "rewrite_inputs_with_params.cpp", "silu_to_sigmoid_multiplication.cpp", "unpack_addmm.cpp", @@ -39,7 +40,6 @@ cc_library( "unpack_std.cpp", "unpack_var.cpp", "view_to_reshape.cpp", - "replace_aten_pad.cpp", ], hdrs = [ "passes.h", diff --git a/core/lowering/passes/replace_aten_pad.cpp b/core/lowering/passes/replace_aten_pad.cpp index d43e0d8613..f99a0349c1 100644 --- a/core/lowering/passes/replace_aten_pad.cpp +++ b/core/lowering/passes/replace_aten_pad.cpp @@ -10,125 +10,111 @@ namespace passes { void ReplaceAtenPad(std::shared_ptr& graph) { for (auto it = graph->block()->nodes().begin(), end = graph->block()->nodes().end(); it != end; ++it) { if (it->kind() == c10::Symbol::fromQualString("aten::pad")) { - // aten::pad(Tensor self, int[] pad, str mode='constant', float? value=None) -> (Tensor) - auto mode = it->inputs()[2]; - if(mode->type()->isSubtypeOf(c10::StringType::get())){ - std::string mode_str = torch::jit::toIValue(mode)->to(); - if(mode_str == "reflect") { - auto pad = it->inputs()[1]; - c10::List pad_list = torch::jit::toIValue(pad)->to>(); - if(pad_list.size() == 2) - { - // aten::reflection_pad1d(Tensor self, int[2] padding) -> (Tensor) - torch::jit::Node* new_node; - new_node = graph->create( - c10::Symbol::fromQualString("aten::reflection_pad1d"), - torch::jit::ArrayRef({it->inputs()[0], it->inputs()[1]}), - 1); - new_node->insertAfter(*it); - new_node->outputs()[0]->setType(c10::TensorType::get()); - it->outputs()[0]->replaceAllUsesWith(new_node->outputs()[0]); - auto pre = --it; - ++it; - it->destroy(); - it = pre; - } - else if(pad_list.size() == 4) - { - //aten::reflection_pad2d(Tensor self, int[4] padding) -> (Tensor) - torch::jit::Node* new_node; - new_node = graph->create( - c10::Symbol::fromQualString("aten::reflection_pad2d"), - torch::jit::ArrayRef({it->inputs()[0], it->inputs()[1]}), - 1); - new_node->insertAfter(*it); - new_node->outputs()[0]->setType(c10::TensorType::get()); - it->outputs()[0]->replaceAllUsesWith(new_node->outputs()[0]); - auto pre = --it; - ++it; - it->destroy(); - it = pre; - } - else if(pad_list.size() == 6) - { - LOG_ERROR("Torch-TRT doesn't support aten::reflection_pad3d currently."); - } + // aten::pad(Tensor self, int[] pad, str mode='constant', float? value=None) -> (Tensor) + auto mode = it->inputs()[2]; + if (mode->type()->isSubtypeOf(c10::StringType::get())) { + std::string mode_str = torch::jit::toIValue(mode)->to(); + if (mode_str == "reflect") { + auto pad = it->inputs()[1]; + c10::List pad_list = torch::jit::toIValue(pad)->to>(); + if (pad_list.size() == 2) { + // aten::reflection_pad1d(Tensor self, int[2] padding) -> (Tensor) + torch::jit::Node* new_node; + new_node = graph->create( + c10::Symbol::fromQualString("aten::reflection_pad1d"), + torch::jit::ArrayRef({it->inputs()[0], it->inputs()[1]}), + 1); + new_node->insertAfter(*it); + new_node->outputs()[0]->setType(c10::TensorType::get()); + it->outputs()[0]->replaceAllUsesWith(new_node->outputs()[0]); + auto pre = --it; + ++it; + it->destroy(); + it = pre; + } else if (pad_list.size() == 4) { + // aten::reflection_pad2d(Tensor self, int[4] padding) -> (Tensor) + torch::jit::Node* new_node; + new_node = graph->create( + c10::Symbol::fromQualString("aten::reflection_pad2d"), + torch::jit::ArrayRef({it->inputs()[0], it->inputs()[1]}), + 1); + new_node->insertAfter(*it); + new_node->outputs()[0]->setType(c10::TensorType::get()); + it->outputs()[0]->replaceAllUsesWith(new_node->outputs()[0]); + auto pre = --it; + ++it; + it->destroy(); + it = pre; + } else if (pad_list.size() == 6) { + LOG_ERROR("Torch-TRT doesn't support aten::reflection_pad3d currently."); + } - } - else if(mode_str == "replicate") { - auto pad = it->inputs()[1]; - c10::List pad_list = torch::jit::toIValue(pad)->to>(); - if(pad_list.size() == 2) - { - // aten::replication_pad1d(Tensor self, int[2] padding) -> (Tensor) - torch::jit::Node* new_node; - new_node = graph->create( - c10::Symbol::fromQualString("aten::replication_pad1d"), - torch::jit::ArrayRef({it->inputs()[0], it->inputs()[1]}), - 1); - new_node->insertAfter(*it); - new_node->outputs()[0]->setType(c10::TensorType::get()); - it->outputs()[0]->replaceAllUsesWith(new_node->outputs()[0]); - auto pre = --it; - ++it; - it->destroy(); - it = pre; - } - else if(pad_list.size() == 4) - { - // aten::replication_pad2d(Tensor self, int[4] padding) -> (Tensor) - torch::jit::Node* new_node; - new_node = graph->create( - c10::Symbol::fromQualString("aten::replication_pad2d"), - torch::jit::ArrayRef({it->inputs()[0], it->inputs()[1]}), - 1); - new_node->insertAfter(*it); - new_node->outputs()[0]->setType(c10::TensorType::get()); - it->outputs()[0]->replaceAllUsesWith(new_node->outputs()[0]); - auto pre = --it; - ++it; - it->destroy(); - it = pre; - } - else if(pad_list.size() == 6) - { - // aten::replication_pad3d(Tensor self, int[6] padding) -> (Tensor) - torch::jit::Node* new_node; - new_node = graph->create( - c10::Symbol::fromQualString("aten::replication_pad3d"), - torch::jit::ArrayRef({it->inputs()[0], it->inputs()[1]}), - 1); - new_node->insertAfter(*it); - new_node->outputs()[0]->setType(c10::TensorType::get()); - it->outputs()[0]->replaceAllUsesWith(new_node->outputs()[0]); - auto pre = --it; - ++it; - it->destroy(); - it = pre; - } + } else if (mode_str == "replicate") { + auto pad = it->inputs()[1]; + c10::List pad_list = torch::jit::toIValue(pad)->to>(); + if (pad_list.size() == 2) { + // aten::replication_pad1d(Tensor self, int[2] padding) -> (Tensor) + torch::jit::Node* new_node; + new_node = graph->create( + c10::Symbol::fromQualString("aten::replication_pad1d"), + torch::jit::ArrayRef({it->inputs()[0], it->inputs()[1]}), + 1); + new_node->insertAfter(*it); + new_node->outputs()[0]->setType(c10::TensorType::get()); + it->outputs()[0]->replaceAllUsesWith(new_node->outputs()[0]); + auto pre = --it; + ++it; + it->destroy(); + it = pre; + } else if (pad_list.size() == 4) { + // aten::replication_pad2d(Tensor self, int[4] padding) -> (Tensor) + torch::jit::Node* new_node; + new_node = graph->create( + c10::Symbol::fromQualString("aten::replication_pad2d"), + torch::jit::ArrayRef({it->inputs()[0], it->inputs()[1]}), + 1); + new_node->insertAfter(*it); + new_node->outputs()[0]->setType(c10::TensorType::get()); + it->outputs()[0]->replaceAllUsesWith(new_node->outputs()[0]); + auto pre = --it; + ++it; + it->destroy(); + it = pre; + } else if (pad_list.size() == 6) { + // aten::replication_pad3d(Tensor self, int[6] padding) -> (Tensor) + torch::jit::Node* new_node; + new_node = graph->create( + c10::Symbol::fromQualString("aten::replication_pad3d"), + torch::jit::ArrayRef({it->inputs()[0], it->inputs()[1]}), + 1); + new_node->insertAfter(*it); + new_node->outputs()[0]->setType(c10::TensorType::get()); + it->outputs()[0]->replaceAllUsesWith(new_node->outputs()[0]); + auto pre = --it; + ++it; + it->destroy(); + it = pre; + } - } - else if(mode_str == "constant") { - // aten::constant_pad_nd(Tensor self, int[] pad, Scalar value=0) -> (Tensor) - torch::jit::Node* new_node; - new_node = graph->create( - c10::Symbol::fromQualString("aten::constant_pad_nd"), - torch::jit::ArrayRef({it->inputs()[0], it->inputs()[1], it->inputs()[3]}), - 1); - new_node->insertAfter(*it); - new_node->outputs()[0]->setType(c10::TensorType::get()); - it->outputs()[0]->replaceAllUsesWith(new_node->outputs()[0]); - auto pre = --it; - ++it; - it->destroy(); - it = pre; - } else if(mode_str == "circular"){ - LOG_ERROR("Torch-TRT doesn't support circular padding currently."); - } + } else if (mode_str == "constant") { + // aten::constant_pad_nd(Tensor self, int[] pad, Scalar value=0) -> (Tensor) + torch::jit::Node* new_node; + new_node = graph->create( + c10::Symbol::fromQualString("aten::constant_pad_nd"), + torch::jit::ArrayRef({it->inputs()[0], it->inputs()[1], it->inputs()[3]}), + 1); + new_node->insertAfter(*it); + new_node->outputs()[0]->setType(c10::TensorType::get()); + it->outputs()[0]->replaceAllUsesWith(new_node->outputs()[0]); + auto pre = --it; + ++it; + it->destroy(); + it = pre; + } else if (mode_str == "circular") { + LOG_ERROR("Torch-TRT doesn't support circular padding currently."); } - + } } - } LOG_GRAPH("Post map aten::pad -> aten::constant_pad_nd/aten::reflection_padXd/aten::replication_padXd: " << *graph); } diff --git a/tests/core/lowering/BUILD b/tests/core/lowering/BUILD index 4fdcb27e1d..8cc2c3a1e9 100644 --- a/tests/core/lowering/BUILD +++ b/tests/core/lowering/BUILD @@ -115,11 +115,11 @@ test_suite( ":test_remove_detach_pass", ":test_remove_dropout_pass", ":test_remove_unnecessary_casts", + ":test_replace_aten_pad_pass", ":test_rewrite_inputs_with_params", ":test_unpack_hardsigmoid", ":test_unpack_hardswish", ":test_unpack_reduce_ops", ":test_view_to_reshape_pass", - ":test_replace_aten_pad_pass", ], ) From b117d697337b110f612f4b2418d14e07a2fa5a8c Mon Sep 17 00:00:00 2001 From: Ruoqian Guo Date: Tue, 31 Jan 2023 03:30:45 +0000 Subject: [PATCH 3/3] chore: add the new file to CMake system Signed-off-by: Ruoqian Guo --- core/lowering/passes/CMakeLists.txt | 1 + 1 file changed, 1 insertion(+) diff --git a/core/lowering/passes/CMakeLists.txt b/core/lowering/passes/CMakeLists.txt index 4c9ebc7efa..bdd299867d 100644 --- a/core/lowering/passes/CMakeLists.txt +++ b/core/lowering/passes/CMakeLists.txt @@ -15,6 +15,7 @@ target_sources(${lib_name} "${CMAKE_CURRENT_SOURCE_DIR}/remove_nops.cpp" "${CMAKE_CURRENT_SOURCE_DIR}/remove_set_attrs.cpp" "${CMAKE_CURRENT_SOURCE_DIR}/remove_unnecessary_casts.cpp" + "${CMAKE_CURRENT_SOURCE_DIR}/replace_aten_pad.cpp" "${CMAKE_CURRENT_SOURCE_DIR}/silu_to_sigmoid_multiplication.cpp" "${CMAKE_CURRENT_SOURCE_DIR}/unpack_addmm.cpp" "${CMAKE_CURRENT_SOURCE_DIR}/unpack_batch_norm.cpp"