1
+ #include < string>
2
+ #include " gtest/gtest.h"
3
+ #include " torch/csrc/jit/ir/irparser.h"
4
+ #include " tests/util/util.h"
5
+ #include " core/compiler.h"
6
+
7
+ TEST (Converters, ATenReshapeConvertsCorrectly) {
8
+ const auto graph = R"IR(
9
+ graph(%0 : Tensor):
10
+ %1 : int = prim::Constant[value=3]()
11
+ %2 : int = prim::Constant[value=2]()
12
+ %3 : int[] = prim::ListConstruct(%1, %2)
13
+ %4 : Tensor = aten::reshape(%0, %3)
14
+ return (%4))IR" ;
15
+
16
+ auto g = std::make_shared<torch::jit::Graph>();
17
+ torch::jit::parseIR (graph, &*g);
18
+
19
+ auto in = at::randint (0 , 5 , {2 , 3 }, {at::kCUDA });
20
+ auto params = trtorch::core::conversion::get_named_params (g->inputs (), {});
21
+ auto jit_results = trtorch::tests::util::RunGraph (g, params, {in});
22
+
23
+ in = at::clone (in);
24
+ params = trtorch::core::conversion::get_named_params (g->inputs (), {});
25
+ auto trt_results = trtorch::tests::util::RunGraphEngine (g, params, {in});
26
+ auto trt = trt_results[0 ].reshape_as (jit_results[0 ]);
27
+
28
+ ASSERT_TRUE (trtorch::tests::util::almostEqual (jit_results[0 ], trt, 2e-6 ));
29
+ }
0 commit comments