#include <memory>
#include <string>

#include "gtest/gtest.h"
#include "torch/csrc/jit/ir/irparser.h"

#include "core/compiler.h"
#include "tests/util/util.h"

 | 7 | +TEST(Converters, ATenMaxDimConvertsCorrectly) {  | 
 | 8 | +  const auto graph = R"IR(  | 
 | 9 | +    graph(%x.1 : Tensor):  | 
 | 10 | +      %2 : int = prim::Constant[value=0]()  | 
 | 11 | +      %3 : bool = prim::Constant[value=0]()  | 
 | 12 | +      %4 : Tensor, %5 : Tensor = aten::max(%x.1, %2, %3)  | 
 | 13 | +      return (%4, %5))IR";  | 
 | 14 | + | 
 | 15 | +  auto g = std::make_shared<torch::jit::Graph>();  | 
 | 16 | +  torch::jit::parseIR(graph, g.get());  | 
 | 17 | + | 
 | 18 | +  auto in = at::rand({2, 3, 5, 5}, {at::kCUDA});  | 
 | 19 | + | 
 | 20 | +  auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});  | 
 | 21 | +  auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {in});  | 
 | 22 | + | 
 | 23 | +  params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});  | 
 | 24 | +  auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {in});  | 
 | 25 | + | 
 | 26 | +  ASSERT_TRUE(  | 
 | 27 | +      torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0].reshape_as(jit_results[0]), 2e-6));  | 
 | 28 | +  ASSERT_TRUE(  | 
 | 29 | +      torch_tensorrt::tests::util::almostEqual(jit_results[1], trt_results[1].reshape_as(jit_results[1]), 2e-6));  | 
 | 30 | +}  | 
 | 31 | + | 
 | 32 | +TEST(Converters, ATenMinDimConvertsCorrectly) {  | 
 | 33 | +  const auto graph = R"IR(  | 
 | 34 | +    graph(%x.1 : Tensor):  | 
 | 35 | +      %2 : int = prim::Constant[value=0]()  | 
 | 36 | +      %3 : bool = prim::Constant[value=0]()  | 
 | 37 | +      %4 : Tensor, %5 : Tensor = aten::min(%x.1, %2, %3)  | 
 | 38 | +      return (%4, %5))IR";  | 
 | 39 | + | 
 | 40 | +  auto g = std::make_shared<torch::jit::Graph>();  | 
 | 41 | +  torch::jit::parseIR(graph, g.get());  | 
 | 42 | + | 
 | 43 | +  auto in = at::rand({2, 3, 5, 5}, {at::kCUDA});  | 
 | 44 | + | 
 | 45 | +  auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});  | 
 | 46 | +  auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {in});  | 
 | 47 | + | 
 | 48 | +  params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});  | 
 | 49 | +  auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {in});  | 
 | 50 | + | 
 | 51 | +  ASSERT_TRUE(  | 
 | 52 | +      torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0].reshape_as(jit_results[0]), 2e-6));  | 
 | 53 | +  ASSERT_TRUE(  | 
 | 54 | +      torch_tensorrt::tests::util::almostEqual(jit_results[1], trt_results[1].reshape_as(jit_results[1]), 2e-6));  | 
 | 55 | +}  | 
 | 56 | + | 
 | 57 | +TEST(Converters, ATenArgMaxConvertsCorrectly) {  | 
 | 58 | +  const auto graph = R"IR(  | 
 | 59 | +    graph(%x.1 : Tensor):  | 
 | 60 | +      %2 : int = prim::Constant[value=0]()  | 
 | 61 | +      %3 : bool = prim::Constant[value=0]()  | 
 | 62 | +      %4 : Tensor = aten::argmax(%x.1, %2, %3)  | 
 | 63 | +      return (%4))IR";  | 
 | 64 | + | 
 | 65 | +  auto g = std::make_shared<torch::jit::Graph>();  | 
 | 66 | +  torch::jit::parseIR(graph, g.get());  | 
 | 67 | + | 
 | 68 | +  auto in = at::rand({2, 3, 5, 5}, {at::kCUDA});  | 
 | 69 | + | 
 | 70 | +  auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});  | 
 | 71 | +  auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {in});  | 
 | 72 | + | 
 | 73 | +  params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});  | 
 | 74 | +  auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {in});  | 
 | 75 | + | 
 | 76 | +  ASSERT_TRUE(  | 
 | 77 | +      torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0].reshape_as(jit_results[0]), 2e-6));  | 
 | 78 | +}  | 
 | 79 | + | 
 | 80 | +TEST(Converters, ATenArgMaxKeepdimConvertsCorrectly) {  | 
 | 81 | +  const auto graph = R"IR(  | 
 | 82 | +    graph(%x.1 : Tensor):  | 
 | 83 | +      %2 : int = prim::Constant[value=1]()  | 
 | 84 | +      %3 : bool = prim::Constant[value=1]()  | 
 | 85 | +      %4 : Tensor = aten::argmax(%x.1, %2, %3)  | 
 | 86 | +      return (%4))IR";  | 
 | 87 | + | 
 | 88 | +  auto g = std::make_shared<torch::jit::Graph>();  | 
 | 89 | +  torch::jit::parseIR(graph, g.get());  | 
 | 90 | + | 
 | 91 | +  auto in = at::rand({2, 3, 5, 5}, {at::kCUDA});  | 
 | 92 | + | 
 | 93 | +  auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});  | 
 | 94 | +  auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {in});  | 
 | 95 | + | 
 | 96 | +  params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});  | 
 | 97 | +  auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {in});  | 
 | 98 | + | 
 | 99 | +  ASSERT_TRUE(  | 
 | 100 | +      torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0].reshape_as(jit_results[0]), 2e-6));  | 
 | 101 | +}  | 
 | 102 | + | 
 | 103 | +TEST(Converters, ATenArgMinConvertsCorrectly) {  | 
 | 104 | +  const auto graph = R"IR(  | 
 | 105 | +    graph(%x.1 : Tensor):  | 
 | 106 | +      %2 : int = prim::Constant[value=0]()  | 
 | 107 | +      %3 : bool = prim::Constant[value=0]()  | 
 | 108 | +      %4 : Tensor = aten::argmin(%x.1, %2, %3)  | 
 | 109 | +      return (%4))IR";  | 
 | 110 | + | 
 | 111 | +  auto g = std::make_shared<torch::jit::Graph>();  | 
 | 112 | +  torch::jit::parseIR(graph, g.get());  | 
 | 113 | + | 
 | 114 | +  auto in = at::rand({2, 3, 5, 5}, {at::kCUDA});  | 
 | 115 | + | 
 | 116 | +  auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});  | 
 | 117 | +  auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {in});  | 
 | 118 | + | 
 | 119 | +  params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});  | 
 | 120 | +  auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {in});  | 
 | 121 | + | 
 | 122 | +  ASSERT_TRUE(  | 
 | 123 | +      torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0].reshape_as(jit_results[0]), 2e-6));  | 
 | 124 | +}  | 
 | 125 | + | 
 | 126 | +TEST(Converters, ATenArgMinKeepdimConvertsCorrectly) {  | 
 | 127 | +  const auto graph = R"IR(  | 
 | 128 | +    graph(%x.1 : Tensor):  | 
 | 129 | +      %2 : int = prim::Constant[value=1]()  | 
 | 130 | +      %3 : bool = prim::Constant[value=1]()  | 
 | 131 | +      %4 : Tensor = aten::argmin(%x.1, %2, %3)  | 
 | 132 | +      return (%4))IR";  | 
 | 133 | + | 
 | 134 | +  auto g = std::make_shared<torch::jit::Graph>();  | 
 | 135 | +  torch::jit::parseIR(graph, g.get());  | 
 | 136 | + | 
 | 137 | +  auto in = at::rand({2, 3, 5, 5}, {at::kCUDA});  | 
 | 138 | + | 
 | 139 | +  auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});  | 
 | 140 | +  auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {in});  | 
 | 141 | + | 
 | 142 | +  params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});  | 
 | 143 | +  auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {in});  | 
 | 144 | + | 
 | 145 | +  ASSERT_TRUE(  | 
 | 146 | +      torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0].reshape_as(jit_results[0]), 2e-6));  | 
 | 147 | +}  | 