Skip to content

Commit 57143c2

Browse files
committed
test(/tests/core/converters/): added 3 tests for linear, bilinear, trilinear ops. removed redundant tests
Signed-off-by: Abhiram Iyer <[email protected]> Signed-off-by: Abhiram Iyer <[email protected]>
1 parent bb46e70 commit 57143c2

File tree

1 file changed

+50
-18
lines changed

1 file changed

+50
-18
lines changed

tests/core/converters/test_interpolate.cpp

Lines changed: 50 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ TEST(Converters, ATenUpsampleNearest1dConvertsCorrectly) {
3333
ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt, 2e-6));
3434
}
3535

36-
TEST(Converters, ATenUpsampleNearest2dConvertsCorrectly1dOutputSize) {
36+
TEST(Converters, ATenUpsampleNearest2dConvertsCorrectly2dOutputSize) {
3737
const auto graph = R"IR(
3838
graph(%0 : Tensor):
3939
%2 : int = prim::Constant[value=10]()
@@ -62,21 +62,21 @@ TEST(Converters, ATenUpsampleNearest2dConvertsCorrectly1dOutputSize) {
6262
ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt, 2e-6));
6363
}
6464

65-
TEST(Converters, ATenUpsampleNearest2dConvertsCorrectly2dOutputSize) {
65+
TEST(Converters, ATenUpsampleNearest3dConvertsCorrectly3dOutputSize) {
6666
const auto graph = R"IR(
6767
graph(%0 : Tensor):
6868
%2 : int = prim::Constant[value=10]()
69-
%3 : int[] = prim::ListConstruct(%2, %2)
69+
%3 : int[] = prim::ListConstruct(%2, %2, %2)
7070
%4 : None = prim::Constant()
71-
%5 : Tensor = aten::upsample_nearest2d(%0, %3, %4, %4)
71+
%5 : Tensor = aten::upsample_nearest3d(%0, %3, %4, %4, %4)
7272
return (%5))IR";
7373

7474
auto g = std::make_shared<torch::jit::Graph>();
7575

7676
torch::jit::parseIR(graph, &*g);
7777

78-
// Input Tensor needs to be 4D for TensorRT upsample_nearest2d
79-
auto in = at::randint(1, 10, {10, 2, 2, 2}, {at::kCUDA});
78+
// Input Tensor needs to be 5D for TensorRT upsample_nearest3d
79+
auto in = at::randint(1, 10, {10, 2, 2, 2, 2}, {at::kCUDA});
8080

8181
auto jit_in = at::clone(in);
8282
auto params = trtorch::core::conversion::get_named_params(g->inputs(), {});
@@ -91,21 +91,22 @@ TEST(Converters, ATenUpsampleNearest2dConvertsCorrectly2dOutputSize) {
9191
ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt, 2e-6));
9292
}
9393

94-
TEST(Converters, ATenUpsampleNearest3dConvertsCorrectly1dOutputSize) {
94+
TEST(Converters, ATenUpsampleLinear1dConvertsCorrectly) {
9595
const auto graph = R"IR(
9696
graph(%0 : Tensor):
9797
%2 : int = prim::Constant[value=10]()
98-
%3 : int[] = prim::ListConstruct(%2, %2, %2)
99-
%4 : None = prim::Constant()
100-
%5 : Tensor = aten::upsample_nearest3d(%0, %3, %4, %4, %4)
101-
return (%5))IR";
98+
%3 : int[] = prim::ListConstruct(%2)
99+
%4 : bool = prim::Constant[value=1]()
100+
%5 : None = prim::Constant()
101+
%6 : Tensor = aten::upsample_linear1d(%0, %3, %4, %5)
102+
return (%6))IR";
102103

103104
auto g = std::make_shared<torch::jit::Graph>();
104105

105106
torch::jit::parseIR(graph, &*g);
106107

107-
// Input Tensor needs to be 5D for TensorRT upsample_nearest3d
108-
auto in = at::randint(1, 10, {10, 2, 2, 2, 2}, {at::kCUDA});
108+
// Input Tensor needs to be 3D for TensorRT upsample_linear1d
109+
auto in = at::randint(1, 10, {10, 2, 2}, {at::kCUDA});
109110

110111
auto jit_in = at::clone(in);
111112
auto params = trtorch::core::conversion::get_named_params(g->inputs(), {});
@@ -120,20 +121,51 @@ TEST(Converters, ATenUpsampleNearest3dConvertsCorrectly1dOutputSize) {
120121
ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt, 2e-6));
121122
}
122123

123-
TEST(Converters, ATenUpsampleNearest3dConvertsCorrectly3dOutputSize) {
124+
TEST(Converters, ATenUpsampleBilinear2dConvertsCorrectly2dOutputSize) {
125+
const auto graph = R"IR(
126+
graph(%0 : Tensor):
127+
%2 : int = prim::Constant[value=10]()
128+
%3 : int[] = prim::ListConstruct(%2, %2)
129+
%4 : bool = prim::Constant[value=1]()
130+
%5 : None = prim::Constant()
131+
%6 : Tensor = aten::upsample_bilinear2d(%0, %3, %4, %5, %5)
132+
return (%6))IR";
133+
134+
auto g = std::make_shared<torch::jit::Graph>();
135+
136+
torch::jit::parseIR(graph, &*g);
137+
138+
// Input Tensor needs to be 4D for TensorRT upsample_bilinear2d
139+
auto in = at::randint(1, 10, {10, 2, 2, 2}, {at::kCUDA});
140+
141+
auto jit_in = at::clone(in);
142+
auto params = trtorch::core::conversion::get_named_params(g->inputs(), {});
143+
auto jit_results = trtorch::tests::util::RunGraph(g, params, {jit_in});
144+
145+
auto trt_in = at::clone(in);
146+
params = trtorch::core::conversion::get_named_params(g->inputs(), {});
147+
auto trt_results = trtorch::tests::util::RunGraphEngine(g, params, {trt_in});
148+
149+
auto trt = trt_results[0].reshape(jit_results[0].sizes());
150+
151+
ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt, 2e-6));
152+
}
153+
154+
TEST(Converters, ATenUpsampleTrilinear3dConvertsCorrectly3dOutputSize) {
124155
const auto graph = R"IR(
125156
graph(%0 : Tensor):
126157
%2 : int = prim::Constant[value=10]()
127158
%3 : int[] = prim::ListConstruct(%2, %2, %2)
128-
%4 : None = prim::Constant()
129-
%5 : Tensor = aten::upsample_nearest3d(%0, %3, %4, %4, %4)
130-
return (%5))IR";
159+
%4 : bool = prim::Constant[value=1]()
160+
%5 : None = prim::Constant()
161+
%6 : Tensor = aten::upsample_trilinear3d(%0, %3, %4, %5, %5, %5)
162+
return (%6))IR";
131163

132164
auto g = std::make_shared<torch::jit::Graph>();
133165

134166
torch::jit::parseIR(graph, &*g);
135167

136-
// Input Tensor needs to be 5D for TensorRT upsample_nearest3d
168+
// Input Tensor needs to be 5D for TensorRT upsample_trilinear3d
137169
auto in = at::randint(1, 10, {10, 2, 2, 2, 2}, {at::kCUDA});
138170

139171
auto jit_in = at::clone(in);

0 commit comments

Comments
 (0)