@@ -43,6 +43,43 @@ TEST(LoweringPasses, AtenPadConstantCorrectly) {
4343 ASSERT_TRUE (torch_tensorrt::tests::util::almostEqual (trt_results_sg[0 ], trt_results_tg[0 ], 2e-6 ));
4444}
4545
46+ TEST (LoweringPasses, AtenPadConstantNoneValueCorrectly) {
47+ const auto source_graph = R"IR(
48+ graph(%0 : Tensor):
49+ %2 : str = prim::Constant[value="constant"]()
50+ %1 : int[] = prim::Constant[value=[2, 3, 2, 3]]()
51+ %3 : NoneType = prim::Constant()
52+ %4 : Tensor = aten::pad(%0, %1, %2, %3)
53+ return (%4))IR" ;
54+
55+ const auto target_graph = R"IR(
56+ graph(%0 : Tensor):
57+ %1 : int[] = prim::Constant[value=[2, 3, 2, 3]]()
58+ %2 : Scalar = prim::Constant[value=0.0]()
59+ %3 : Tensor = aten::constant_pad_nd(%0, %1, %2)
60+ return (%3))IR" ;
61+
62+ torch_tensorrt::core::util::logging::get_logger ().set_reportable_log_level (
63+ torch_tensorrt::core::util::logging::LogLevel::kGRAPH );
64+ auto sg = std::make_shared<torch::jit::Graph>();
65+ torch::jit::parseIR (source_graph, &*sg);
66+ torch_tensorrt::core::lowering::passes::ReplaceAtenPad (sg);
67+
68+ auto tg = std::make_shared<torch::jit::Graph>();
69+ torch::jit::parseIR (target_graph, &*tg);
70+
71+ auto in = at::randint (1 , 10 , {1 , 3 , 4 , 5 }, {at::kCUDA });
72+
73+ auto trt_in = at::clone (in);
74+ auto params = torch_tensorrt::core::ir::get_static_params (sg->inputs (), {});
75+ auto trt_results_sg = torch_tensorrt::tests::util::RunGraphEngine (sg, params, {trt_in});
76+
77+ params = torch_tensorrt::core::ir::get_static_params (tg->inputs (), {});
78+ auto trt_results_tg = torch_tensorrt::tests::util::RunGraphEngine (tg, params, {trt_in});
79+
80+ ASSERT_TRUE (torch_tensorrt::tests::util::almostEqual (trt_results_sg[0 ], trt_results_tg[0 ], 2e-6 ));
81+ }
82+
4683TEST (LoweringPasses, AtenPadReflect1dCorrectly) {
4784 const auto source_graph = R"IR(
4885 graph(%0 : Tensor):
@@ -221,4 +258,4 @@ TEST(LoweringPasses, AtenPadReplicate3dCorrectly) {
221258 auto trt_results_tg = torch_tensorrt::tests::util::RunGraphEngine (tg, params, {trt_in});
222259
223260 ASSERT_TRUE (torch_tensorrt::tests::util::almostEqual (trt_results_sg[0 ], trt_results_tg[0 ], 2e-6 ));
224- }
261+ }
0 commit comments