@@ -49,5 +49,52 @@ TEST(Converters, ATenCatDiffTensorConvertsCorrectly) {
4949 params = trtorch::core::conversion::get_named_params (g->inputs (), {in2});
5050 auto trt_results = trtorch::tests::util::RunGraphEngine (g, params, {in1});
5151
52+ ASSERT_TRUE (trtorch::tests::util::almostEqual (jit_results[0 ], trt_results[0 ].reshape_as (jit_results[0 ]), 2e-6 ));
53+ }
54+ TEST (Converters, ATenCatPureTensorNegDimConvertsCorrectly) {
55+ const auto graph = R"IR(
56+ graph(%0 : Tensor,
57+ %1 : Tensor):
58+ %2 : Tensor[] = prim::ListConstruct(%0, %1)
59+ %3 : int = prim::Constant[value=-1]()
60+ %4 : Tensor = aten::cat(%2, %3)
61+ return (%4))IR" ;
62+
63+ auto g = std::make_shared<torch::jit::Graph>();
64+ torch::jit::parseIR (graph, g.get ());
65+
66+ auto in1 = at::randint (1 , 10 , {5 , 5 }, {at::kCUDA });
67+ auto in2 = at::randint (1 , 10 , {5 , 5 }, {at::kCUDA });
68+
69+ auto params = trtorch::core::conversion::get_named_params (g->inputs (), {});
70+ auto jit_results = trtorch::tests::util::RunGraph (g, params, {in1, in2});
71+
72+ params = trtorch::core::conversion::get_named_params (g->inputs (), {});
73+ auto trt_results = trtorch::tests::util::RunGraphEngine (g, params, {in1, in2});
74+
75+ ASSERT_TRUE (trtorch::tests::util::almostEqual (jit_results[0 ], trt_results[0 ].reshape_as (jit_results[0 ]), 2e-6 ));
76+ }
77+
78+ TEST (Converters, ATenCatDiffTensorNegDimConvertsCorrectly) {
79+ const auto graph = R"IR(
80+ graph(%0 : Tensor,
81+ %1 : Float(5)):
82+ %2 : Tensor[] = prim::ListConstruct(%0, %1)
83+ %3 : int = prim::Constant[value=-1]()
84+ %4 : Tensor = aten::cat(%2, %3)
85+ return (%4))IR" ;
86+
87+ auto g = std::make_shared<torch::jit::Graph>();
88+ torch::jit::parseIR (graph, g.get ());
89+
90+ auto in1 = at::randint (1 , 10 , {5 , 5 }, {at::kCUDA });
91+ auto in2 = at::randint (1 , 10 , {5 , 5 }, {at::kCUDA });
92+
93+ auto params = trtorch::core::conversion::get_named_params (g->inputs (), {in2});
94+ auto jit_results = trtorch::tests::util::RunGraph (g, params, {in1});
95+
96+ params = trtorch::core::conversion::get_named_params (g->inputs (), {in2});
97+ auto trt_results = trtorch::tests::util::RunGraphEngine (g, params, {in1});
98+
5299 ASSERT_TRUE (trtorch::tests::util::almostEqual (jit_results[0 ], trt_results[0 ].reshape_as (jit_results[0 ]), 2e-6 ));
53100}
0 commit comments