@@ -30,17 +30,87 @@ TEST(Converters, ATenFlattenConvertsCorrectly) {
3030
3131// TODO: IR Parser doesnt work well with neg numbers
3232TEST (Converters, ATenFlattenOtherDimsConvertsCorrectly) {
33- const auto graph = R"IR(
34- graph(%0 : Tensor):
35- %1 : int = prim::Constant[value=1]()
36- %2 : int = prim::Constant[value=2]()
37- %3 : Tensor = aten::flatten(%0, %1, %2)
38- return (%3))IR" ;
33+ const auto graph = R"IR(
34+ graph(%0 : Tensor):
35+ %1 : int = prim::Constant[value=1]()
36+ %2 : int = prim::Constant[value=2]()
37+ %3 : Tensor = aten::flatten(%0, %1, %2)
38+ return (%3))IR" ;
39+
40+ auto g = std::make_shared<torch::jit::Graph>();
41+ torch::jit::parseIR (graph, &*g);
42+
43+ auto in = at::randint (0 , 5 , {2 , 3 , 3 }, {at::kCUDA });
44+ auto params = trtorch::core::conversion::get_named_params (g->inputs (), {});
45+ auto jit_results = trtorch::tests::util::RunGraph (g, params, {in});
46+
47+ in = at::clone (in);
48+ params = trtorch::core::conversion::get_named_params (g->inputs (), {});
49+ auto trt_results = trtorch::tests::util::RunGraphEngine (g, params, {in});
50+ auto trt = trt_results[0 ].reshape_as (jit_results[0 ]);
51+
52+ ASSERT_TRUE (trtorch::tests::util::almostEqual (jit_results[0 ], trt, 2e-6 ));
53+ }
3954
40- auto g = std::make_shared<torch::jit::Graph>();
55+ TEST (Converters, ATenReshapeConvertsCorrectly) {
56+ const auto graph = R"IR(
57+ graph(%0 : Tensor):
58+ %1 : int = prim::Constant[value=3]()
59+ %2 : int = prim::Constant[value=2]()
60+ %3 : int[] = prim::ListConstruct(%1, %2)
61+ %4 : Tensor = aten::reshape(%0, %3)
62+ return (%4))IR" ;
63+
64+ auto g = std::make_shared<torch::jit::Graph>();
65+ torch::jit::parseIR (graph, &*g);
66+
67+ auto in = at::randint (0 , 5 , {2 , 3 }, {at::kCUDA });
68+ auto params = trtorch::core::conversion::get_named_params (g->inputs (), {});
69+ auto jit_results = trtorch::tests::util::RunGraph (g, params, {in});
70+
71+ in = at::clone (in);
72+ params = trtorch::core::conversion::get_named_params (g->inputs (), {});
73+ auto trt_results = trtorch::tests::util::RunGraphEngine (g, params, {in});
74+ auto trt = trt_results[0 ].reshape_as (jit_results[0 ]);
75+
76+ ASSERT_TRUE (trtorch::tests::util::almostEqual (jit_results[0 ], trt, 2e-6 ));
77+ }
78+
79+ TEST (Converters, ATenViewConvertsCorrectly) {
80+ const auto graph = R"IR(
81+ graph(%0 : Tensor):
82+ %1 : int = prim::Constant[value=3]()
83+ %2 : int = prim::Constant[value=2]()
84+ %3 : int[] = prim::ListConstruct(%1, %2)
85+ %4 : Tensor = aten::view(%0, %3)
86+ return (%4))IR" ;
87+
88+ auto g = std::make_shared<torch::jit::Graph>();
89+ torch::jit::parseIR (graph, &*g);
90+
91+ auto in = at::randint (0 , 5 , {2 , 3 }, {at::kCUDA });
92+ auto params = trtorch::core::conversion::get_named_params (g->inputs (), {});
93+ auto jit_results = trtorch::tests::util::RunGraph (g, params, {in});
94+
95+ in = at::clone (in);
96+ params = trtorch::core::conversion::get_named_params (g->inputs (), {});
97+ auto trt_results = trtorch::tests::util::RunGraphEngine (g, params, {in});
98+ auto trt = trt_results[0 ].reshape_as (jit_results[0 ]);
99+
100+ ASSERT_TRUE (trtorch::tests::util::almostEqual (jit_results[0 ], trt, 2e-6 ));
101+ }
102+
103+ TEST (Converters, ATenPermuteConvertsCorrectly) {
104+ const auto graph = R"IR(
105+ graph(%x.1 : Tensor):
106+ %2 : int[] = prim::Constant[value=[3, 0, 1, 2]]()
107+ %3 : Tensor = aten::permute(%x.1, %2)
108+ return (%3))IR" ;
109+
110+ auto g = std::make_shared<torch::jit::Graph>();
41111 torch::jit::parseIR (graph, &*g);
42112
43- auto in = at::randint (0 , 5 , {2 , 3 , 3 }, {at::kCUDA });
113+ auto in = at::randint (0 , 5 , {2 , 3 , 2 , 3 }, {at::kCUDA });
44114 auto params = trtorch::core::conversion::get_named_params (g->inputs (), {});
45115 auto jit_results = trtorch::tests::util::RunGraph (g, params, {in});
46116
@@ -52,19 +122,17 @@ TEST(Converters, ATenFlattenOtherDimsConvertsCorrectly) {
52122 ASSERT_TRUE (trtorch::tests::util::almostEqual (jit_results[0 ], trt, 2e-6 ));
53123}
54124
55- TEST (Converters, ATenReshapeConvertsCorrectly) {
56- const auto graph = R"IR(
57- graph(%0 : Tensor):
58- %1 : int = prim::Constant[value=3]()
59- %2 : int = prim::Constant[value=2]()
60- %3 : int[] = prim::ListConstruct(%1, %2)
61- %4 : Tensor = aten::reshape(%0, %3)
62- return (%4))IR" ;
125+ TEST (Converters, ATenPermute3DConvertsCorrectly) {
126+ const auto graph = R"IR(
127+ graph(%x.1 : Tensor):
128+ %2 : int[] = prim::Constant[value=[0, 2, 1]]()
129+ %3 : Tensor = aten::permute(%x.1, %2)
130+ return (%3))IR" ;
63131
64- auto g = std::make_shared<torch::jit::Graph>();
132+ auto g = std::make_shared<torch::jit::Graph>();
65133 torch::jit::parseIR (graph, &*g);
66134
67- auto in = at::randint (0 , 5 , {2 , 3 }, {at::kCUDA });
135+ auto in = at::randint (0 , 5 , {2 , 2 , 3 }, {at::kCUDA });
68136 auto params = trtorch::core::conversion::get_named_params (g->inputs (), {});
69137 auto jit_results = trtorch::tests::util::RunGraph (g, params, {in});
70138
@@ -76,19 +144,17 @@ TEST(Converters, ATenReshapeConvertsCorrectly) {
76144 ASSERT_TRUE (trtorch::tests::util::almostEqual (jit_results[0 ], trt, 2e-6 ));
77145}
78146
79- TEST (Converters, ATenViewConvertsCorrectly) {
80- const auto graph = R"IR(
81- graph(%0 : Tensor):
82- %1 : int = prim::Constant[value=3]()
83- %2 : int = prim::Constant[value=2]()
84- %3 : int[] = prim::ListConstruct(%1, %2)
85- %4 : Tensor = aten::view(%0, %3)
86- return (%4))IR" ;
147+ TEST (Converters, ATenPermute5DConvertsCorrectly) {
148+ const auto graph = R"IR(
149+ graph(%x.1 : Tensor):
150+ %2 : int[] = prim::Constant[value=[3, 4, 0, 2, 1]]()
151+ %3 : Tensor = aten::permute(%x.1, %2)
152+ return (%3))IR" ;
87153
88- auto g = std::make_shared<torch::jit::Graph>();
154+ auto g = std::make_shared<torch::jit::Graph>();
89155 torch::jit::parseIR (graph, &*g);
90156
91- auto in = at::randint (0 , 5 , {2 , 3 }, {at::kCUDA });
157+ auto in = at::randint (0 , 5 , {2 , 2 , 1 , 2 , 3 }, {at::kCUDA });
92158 auto params = trtorch::core::conversion::get_named_params (g->inputs (), {});
93159 auto jit_results = trtorch::tests::util::RunGraph (g, params, {in});
94160
0 commit comments