// const c10::optional<Tensor>& bias_opt /* optional */,
// const c10::optional<Tensor>& running_mean_opt /* optional */,
// const c10::optional<Tensor>& running_var_opt /* optional */,
-// bool use_input_stats, double momentum, double eps, bool cudnn_enabled)
+// bool use_input_stats, double momentum, double eps, bool cudnn_enabled)
constexpr auto graph = R"IR(
  graph(%input.1 : Tensor,
        %weight.1 : Tensor?,
@@ -28,23 +28,23 @@ constexpr auto graph = R"IR(
  return (%4)
)IR";

-
TEST(Converters, ATenInstanceNormConvertsCorrectly) {
  auto g = std::make_shared<torch::jit::Graph>();
  torch::jit::parseIR(graph, g.get());

  auto in = at::randint(1, 10, {1, 5, 5, 5}, {at::kCUDA});
  torch::jit::IValue weight, bias, mean, var; // NoneType
  // https://github.com/pytorch/pytorch/blob/79693bb86a3f601a5c0d3da52d99acec95bb48c1/torch/nn/modules/instancenorm.py#L59
-  const bool use_input_stats = true;
-
+  const bool use_input_stats = true;
+
  auto trt_in = at::clone(in);
  torch::jit::IValue trt_weight, trt_bias, trt_mean, trt_var;

  auto params = trtorch::core::conversion::get_named_params(g->inputs(), {weight, bias, mean, var, use_input_stats});
  auto jit_results = trtorch::tests::util::RunGraph(g, params, {in});

-  params = trtorch::core::conversion::get_named_params(g->inputs(), {trt_weight, trt_bias, trt_mean, trt_var, use_input_stats});
+  params = trtorch::core::conversion::get_named_params(
+      g->inputs(), {trt_weight, trt_bias, trt_mean, trt_var, use_input_stats});
  auto trt_results = trtorch::tests::util::RunGraphEngine(g, params, {trt_in});

  ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt_results[0].reshape_as(jit_results[0]), 2e-6));
@@ -58,8 +58,8 @@ TEST(Converters, ATenInstanceNormAffineConvertsCorrectly) {

  auto weight = at::randn({in.size(1)}).to(at::kCUDA);
  auto bias = at::randn({in.size(1)}).to(at::kCUDA);
-
-  torch::jit::IValue mean, var; // NoneType
+
+  torch::jit::IValue mean, var; // NoneType
  const bool use_input_stats = true;

  auto trt_in = at::clone(in);
@@ -70,7 +70,8 @@ TEST(Converters, ATenInstanceNormAffineConvertsCorrectly) {
  auto params = trtorch::core::conversion::get_named_params(g->inputs(), {weight, bias, mean, var, use_input_stats});
  auto jit_results = trtorch::tests::util::RunGraph(g, params, {in});

-  params = trtorch::core::conversion::get_named_params(g->inputs(), {trt_weight, trt_bias, trt_mean, trt_var, use_input_stats});
+  params = trtorch::core::conversion::get_named_params(
+      g->inputs(), {trt_weight, trt_bias, trt_mean, trt_var, use_input_stats});
  auto trt_results = trtorch::tests::util::RunGraphEngine(g, params, {trt_in});

  ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt_results[0].reshape_as(jit_results[0]), 2e-6));
@@ -81,12 +82,12 @@ TEST(Converters, ATenInstanceNormRunningStatsConvertsCorrectly) {
  torch::jit::parseIR(graph, g.get());

  auto in = at::randn({1, 5, 5, 5}, {at::kCUDA});
-
+
  torch::jit::IValue weight, bias;
  auto mean = at::zeros({in.size(1)}, {at::kCUDA});
  auto var = at::ones({in.size(1)}, {at::kCUDA});
  const bool use_input_stats = false;
-
+
  auto trt_in = at::clone(in);
  torch::jit::IValue trt_weight, trt_bias;
  auto trt_mean = at::clone(mean);
@@ -95,7 +96,8 @@ TEST(Converters, ATenInstanceNormRunningStatsConvertsCorrectly) {
  auto params = trtorch::core::conversion::get_named_params(g->inputs(), {weight, bias, mean, var, use_input_stats});
  auto jit_results = trtorch::tests::util::RunGraph(g, params, {in});

-  params = trtorch::core::conversion::get_named_params(g->inputs(), {trt_weight, trt_bias, trt_mean, trt_var, use_input_stats});
+  params = trtorch::core::conversion::get_named_params(
+      g->inputs(), {trt_weight, trt_bias, trt_mean, trt_var, use_input_stats});
  auto trt_results = trtorch::tests::util::RunGraphEngine(g, params, {trt_in});
  ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt_results[0].reshape_as(jit_results[0]), 2e-6));
}
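
The three tests above exercise aten::instance_norm in three configurations: no affine parameters, an affine weight and bias, and precomputed running statistics. For reference, below is a minimal eager-mode sketch of those same call patterns. It assumes the at::instance_norm overload whose tail matches the signature quoted in the comment at the top of the file; the helper name instance_norm_variants and the momentum (0.1) and eps (1e-5) values are illustrative defaults, not taken from the tests.

#include <ATen/ATen.h>

// Illustrative sketch only, not part of the change above.
// Assumed call shape: at::instance_norm(input, weight, bias, running_mean,
// running_var, use_input_stats, momentum, eps, cudnn_enabled).
void instance_norm_variants() {
  auto in = at::randn({1, 5, 5, 5}, {at::kCUDA});

  // 1. No affine parameters, no running stats: per-instance statistics are
  //    computed from the input itself (use_input_stats = true).
  auto out_plain = at::instance_norm(
      in, {}, {}, {}, {}, /*use_input_stats=*/true, 0.1, 1e-5, /*cudnn_enabled=*/false);

  // 2. Affine: per-channel weight and bias are applied after normalization.
  auto weight = at::randn({in.size(1)}, {at::kCUDA});
  auto bias = at::randn({in.size(1)}, {at::kCUDA});
  auto out_affine = at::instance_norm(
      in, weight, bias, {}, {}, /*use_input_stats=*/true, 0.1, 1e-5, false);

  // 3. Running statistics: normalization uses the supplied mean/var instead
  //    of statistics computed from the input (use_input_stats = false).
  auto mean = at::zeros({in.size(1)}, {at::kCUDA});
  auto var = at::ones({in.size(1)}, {at::kCUDA});
  auto out_stats = at::instance_norm(
      in, {}, {}, mean, var, /*use_input_stats=*/false, 0.1, 1e-5, false);
}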