
Commit 6feb673

Merge pull request #1712 from pytorch/constant_pad_fix
fix: Handle nonetype pad value for Constant pad
2 parents fb42d42 + 1980786 commit 6feb673

File tree

2 files changed: +46 -2 lines changed

core/lowering/passes/replace_aten_pad.cpp

Lines changed: 8 additions & 1 deletion
@@ -99,10 +99,17 @@ void ReplaceAtenPad(std::shared_ptr<torch::jit::Graph>& graph) {
     } else if (mode_str == "constant") {
       // aten::constant_pad_nd(Tensor self, int[] pad, Scalar value=0) -> (Tensor)
       torch::jit::Node* new_node;
+      auto pad_value = it->inputs()[3];
+      auto is_pad_none = torch::jit::toIValue(it->inputs()[3])->isNone();
+      if (is_pad_none) {
+        pad_value = graph->insertConstant(0.0);
+      }
+
       new_node = graph->create(
           c10::Symbol::fromQualString("aten::constant_pad_nd"),
-          torch::jit::ArrayRef<torch::jit::Value*>({it->inputs()[0], it->inputs()[1], it->inputs()[3]}),
+          torch::jit::ArrayRef<torch::jit::Value*>({it->inputs()[0], it->inputs()[1], pad_value}),
           1);
+
       new_node->insertAfter(*it);
       new_node->outputs()[0]->setType(c10::TensorType::get());
       it->outputs()[0]->replaceAllUsesWith(new_node->outputs()[0]);
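
The change can be read in isolation: aten::pad takes an optional Scalar value that defaults to None, while aten::constant_pad_nd expects a concrete Scalar, so the pass now substitutes a 0.0 constant whenever the value input is a None constant. Below is a minimal standalone sketch of that substitution, assuming only the libtorch JIT headers; the helper name resolve_pad_value is hypothetical and not part of the pass, and it adds a defensive has_value() check that the pass itself does not need because it only visits constant inputs.

#include <memory>

#include <torch/csrc/jit/ir/constants.h>
#include <torch/csrc/jit/ir/ir.h>

// Hypothetical helper illustrating the fix: pick a usable Scalar value for
// aten::constant_pad_nd, replacing a None constant with 0.0 (the schema
// default of aten::pad).
torch::jit::Value* resolve_pad_value(
    std::shared_ptr<torch::jit::Graph>& graph,
    torch::jit::Node* pad_node) {
  torch::jit::Value* pad_value = pad_node->inputs()[3];
  // toIValue only yields a value for constants; aten::pad's optional value
  // input shows up as a NoneType constant when the caller omitted it.
  auto ivalue = torch::jit::toIValue(pad_value);
  if (ivalue.has_value() && ivalue->isNone()) {
    pad_value = graph->insertConstant(0.0);
  }
  return pad_value;
}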

tests/core/lowering/test_replace_aten_pad_pass.cpp

Lines changed: 38 additions & 1 deletion
@@ -43,6 +43,43 @@ TEST(LoweringPasses, AtenPadConstantCorrectly) {
   ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(trt_results_sg[0], trt_results_tg[0], 2e-6));
 }
 
+TEST(LoweringPasses, AtenPadConstantNoneValueCorrectly) {
+  const auto source_graph = R"IR(
+    graph(%0 : Tensor):
+      %2 : str = prim::Constant[value="constant"]()
+      %1 : int[] = prim::Constant[value=[2, 3, 2, 3]]()
+      %3 : NoneType = prim::Constant()
+      %4 : Tensor = aten::pad(%0, %1, %2, %3)
+      return (%4))IR";
+
+  const auto target_graph = R"IR(
+    graph(%0 : Tensor):
+      %1 : int[] = prim::Constant[value=[2, 3, 2, 3]]()
+      %2 : Scalar = prim::Constant[value=0.0]()
+      %3 : Tensor = aten::constant_pad_nd(%0, %1, %2)
+      return (%3))IR";
+
+  torch_tensorrt::core::util::logging::get_logger().set_reportable_log_level(
+      torch_tensorrt::core::util::logging::LogLevel::kGRAPH);
+  auto sg = std::make_shared<torch::jit::Graph>();
+  torch::jit::parseIR(source_graph, &*sg);
+  torch_tensorrt::core::lowering::passes::ReplaceAtenPad(sg);
+
+  auto tg = std::make_shared<torch::jit::Graph>();
+  torch::jit::parseIR(target_graph, &*tg);
+
+  auto in = at::randint(1, 10, {1, 3, 4, 5}, {at::kCUDA});
+
+  auto trt_in = at::clone(in);
+  auto params = torch_tensorrt::core::ir::get_static_params(sg->inputs(), {});
+  auto trt_results_sg = torch_tensorrt::tests::util::RunGraphEngine(sg, params, {trt_in});
+
+  params = torch_tensorrt::core::ir::get_static_params(tg->inputs(), {});
+  auto trt_results_tg = torch_tensorrt::tests::util::RunGraphEngine(tg, params, {trt_in});
+
+  ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(trt_results_sg[0], trt_results_tg[0], 2e-6));
+}
+
 TEST(LoweringPasses, AtenPadReflect1dCorrectly) {
   const auto source_graph = R"IR(
     graph(%0 : Tensor):
@@ -221,4 +258,4 @@ TEST(LoweringPasses, AtenPadReplicate3dCorrectly) {
   auto trt_results_tg = torch_tensorrt::tests::util::RunGraphEngine(tg, params, {trt_in});
 
   ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(trt_results_sg[0], trt_results_tg[0], 2e-6));
-}
+}
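
The new test exercises the full TensorRT engine path by running both the lowered source graph and the hand-written target graph through RunGraphEngine and comparing outputs. The lowering behavior on its own can be checked with a much smaller snippet; here is a minimal sketch, assuming ReplaceAtenPad is declared in "core/lowering/passes/passes.h" (the header path is an assumption, not confirmed by the diff), that runs only the pass and dumps the rewritten graph.

#include <memory>

#include <torch/csrc/jit/ir/ir.h>
#include <torch/csrc/jit/ir/irparser.h>

#include "core/lowering/passes/passes.h"  // assumed header declaring ReplaceAtenPad

int main() {
  // Same shape of graph as the test's source_graph: aten::pad with a
  // NoneType value argument.
  const auto source_graph = R"IR(
    graph(%0 : Tensor):
      %1 : int[] = prim::Constant[value=[2, 3, 2, 3]]()
      %2 : str = prim::Constant[value="constant"]()
      %3 : NoneType = prim::Constant()
      %4 : Tensor = aten::pad(%0, %1, %2, %3)
      return (%4))IR";

  auto g = std::make_shared<torch::jit::Graph>();
  torch::jit::parseIR(source_graph, &*g);
  torch_tensorrt::core::lowering::passes::ReplaceAtenPad(g);

  // The dumped graph should now call aten::constant_pad_nd with a 0.0
  // constant in place of the NoneType value.
  g->dump();
  return 0;
}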
