@@ -377,3 +377,99 @@ ATEN_INTERPOLATE_STATIC_ONLY_TEST(
377377 %7 : Tensor = aten::upsample_trilinear3d(%0, %3, %4, %6)
378378 return (%7))IR" ,
379379 std::vector<int64_t >({10 , 2 , 2 , 2 , 2 }));
380+
381+ TEST (Converters, GridSampleConvertsCorrectly) {
382+ const auto graph = R"IR(
383+ graph(%input : Tensor, %grid : Tensor):
384+ %5 : int = prim::Constant[value=2]()
385+ %6 : int = prim::Constant[value=2]()
386+ %7 : bool = prim::Constant[value=1]()
387+ %8 : Tensor = aten::grid_sampler(%input, %grid, %5, %6, %7)
388+ return (%8))IR" ;
389+ auto g = std::make_shared<torch::jit::Graph>();
390+
391+ torch::jit::parseIR (graph, g.get ());
392+
393+ auto input = at::arange (16 ).view ({1 , 1 , 4 , 4 }).to (at::kFloat ).to (at::kCUDA );
394+ auto d = at::linspace (-1 , 1 , 8 );
395+ auto mesh = at::meshgrid ({d, d});
396+ auto mesh_x = mesh[0 ];
397+ auto mesh_y = mesh[1 ];
398+ auto grid = at::stack ({mesh_x, mesh_y}, 2 ).unsqueeze (0 ).to (at::kCUDA );
399+
400+ auto trt_input = input.clone ();
401+ auto trt_grid = grid.clone ();
402+
403+ auto params = torch_tensorrt::core::ir::get_static_params (g->inputs (), {});
404+ auto jit_results = torch_tensorrt::tests::util::RunGraph (g, params, {input, grid});
405+
406+ auto trt_results = torch_tensorrt::tests::util::RunGraphEngine (g, params, {trt_input, trt_grid});
407+
408+ for (size_t i = 0 ; i < jit_results.size (); i++) {
409+ ASSERT_TRUE (torch_tensorrt::tests::util::almostEqual (jit_results[i], trt_results[i], 2e-6 ));
410+ }
411+ }
412+
413+ TEST (Converters, GridSampleOptions1ConvertsCorrectly) {
414+ const auto graph = R"IR(
415+ graph(%input : Tensor, %grid : Tensor):
416+ %5 : int = prim::Constant[value=1]()
417+ %6 : int = prim::Constant[value=1]()
418+ %7 : bool = prim::Constant[value=0]()
419+ %8 : Tensor = aten::grid_sampler(%input, %grid, %5, %6, %7)
420+ return (%8))IR" ;
421+ auto g = std::make_shared<torch::jit::Graph>();
422+
423+ torch::jit::parseIR (graph, g.get ());
424+
425+ auto input = at::arange (16 ).view ({1 , 1 , 4 , 4 }).to (at::kFloat ).to (at::kCUDA );
426+ auto d = at::linspace (-1 , 1 , 8 );
427+ auto mesh = at::meshgrid ({d, d});
428+ auto mesh_x = mesh[0 ];
429+ auto mesh_y = mesh[1 ];
430+ auto grid = at::stack ({mesh_x, mesh_y}, 2 ).unsqueeze (0 ).to (at::kCUDA );
431+
432+ auto trt_input = input.clone ();
433+ auto trt_grid = grid.clone ();
434+
435+ auto params = torch_tensorrt::core::ir::get_static_params (g->inputs (), {});
436+ auto jit_results = torch_tensorrt::tests::util::RunGraph (g, params, {input, grid});
437+
438+ auto trt_results = torch_tensorrt::tests::util::RunGraphEngine (g, params, {trt_input, trt_grid});
439+
440+ for (size_t i = 0 ; i < jit_results.size (); i++) {
441+ ASSERT_TRUE (torch_tensorrt::tests::util::almostEqual (jit_results[i], trt_results[i], 2e-6 ));
442+ }
443+ }
444+
445+ TEST (Converters, GridSampleOptions2ConvertsCorrectly) {
446+ const auto graph = R"IR(
447+ graph(%input : Tensor, %grid : Tensor):
448+ %5 : int = prim::Constant[value=0]()
449+ %6 : int = prim::Constant[value=0]()
450+ %7 : bool = prim::Constant[value=0]()
451+ %8 : Tensor = aten::grid_sampler(%input, %grid, %5, %6, %7)
452+ return (%8))IR" ;
453+ auto g = std::make_shared<torch::jit::Graph>();
454+
455+ torch::jit::parseIR (graph, g.get ());
456+
457+ auto input = at::arange (16 ).view ({1 , 1 , 4 , 4 }).to (at::kFloat ).to (at::kCUDA );
458+ auto d = at::linspace (-1 , 1 , 8 );
459+ auto mesh = at::meshgrid ({d, d});
460+ auto mesh_x = mesh[0 ];
461+ auto mesh_y = mesh[1 ];
462+ auto grid = at::stack ({mesh_x, mesh_y}, 2 ).unsqueeze (0 ).to (at::kCUDA );
463+
464+ auto trt_input = input.clone ();
465+ auto trt_grid = grid.clone ();
466+
467+ auto params = torch_tensorrt::core::ir::get_static_params (g->inputs (), {});
468+ auto jit_results = torch_tensorrt::tests::util::RunGraph (g, params, {input, grid});
469+
470+ auto trt_results = torch_tensorrt::tests::util::RunGraphEngine (g, params, {trt_input, trt_grid});
471+
472+ for (size_t i = 0 ; i < jit_results.size (); i++) {
473+ ASSERT_TRUE (torch_tensorrt::tests::util::almostEqual (jit_results[i], trt_results[i], 2e-6 ));
474+ }
475+ }
0 commit comments