@@ -437,3 +437,155 @@ TEST(LoweringPasses, RemoveSingleUse0DTensorsFloorDivFloatValuesAgree) {
   ASSERT_TRUE(
       torch_tensorrt::tests::util::almostEqual(jit_pre_results[0].toTensor(), jit_post_results[0].toTensor(), 2e-6));
 }
+
+TEST(LoweringPasses, RemoveAtenIntTensorValuesAgree) {
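+  // Source graph: scalar constants are wrapped with prim::NumToTensor, combined with
+  // tensor ops, and read back with aten::Int; the ReplaceAtenInt lowering pass should
+  // reduce this to the plain scalar arithmetic of the target graph, with one trailing
+  // prim::NumToTensor so both graphs still return a Tensor.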
+  std::string source_graph_no_inputs = R"IR(
+    graph():
+      %0: int = prim::Constant[value=2]()
+      %11: int = prim::Constant[value=7]()
+      %3: Tensor = prim::NumToTensor(%0)
+      %1: Tensor = prim::NumToTensor(%11)
+      %4: Tensor = aten::floor_divide(%1, %3)
+      %7: Tensor = aten::mul(%3, %4)
+      %8: Tensor = aten::mul(%7, %1)
+      %50: int = aten::Int(%8)
+      %5: Tensor = prim::NumToTensor(%50)
+      return (%5))IR";
+  std::string target_graph_no_inputs = R"IR(
+    graph():
+      %0: int = prim::Constant[value=2]()
+      %1: int = prim::Constant[value=7]()
+      %4: int = aten::floordiv(%1, %0)
+      %7: int = aten::mul(%0, %4)
+      %40: int = aten::mul(%7, %1)
+      %5: Tensor = prim::NumToTensor(%40)
+      return (%5))IR";
+
+  auto g_in = std::make_shared<torch::jit::Graph>();
+  auto g_out = std::make_shared<torch::jit::Graph>();
+
+  torch::jit::parseIR(source_graph_no_inputs, g_in.get());
+  torch::jit::parseIR(target_graph_no_inputs, g_out.get());
+
+  auto jit_pre_results = torch_tensorrt::tests::util::EvaluateGraphJIT(g_in, {});
+  auto jit_post_results = torch_tensorrt::tests::util::EvaluateGraphJIT(g_out, {});
+
+  ASSERT_TRUE(
+      torch_tensorrt::tests::util::almostEqual(jit_pre_results[0].toTensor(), jit_post_results[0].toTensor(), 2e-6));
+
+  // Ensure the lowering pass transforms the first graph into the second
+  torch_tensorrt::core::util::logging::get_logger().set_reportable_log_level(
+      torch_tensorrt::core::util::logging::LogLevel::kGRAPH);
+  auto sg = std::make_shared<torch::jit::Graph>();
+  torch::jit::parseIR(source_graph_no_inputs, sg.get());
+
+  torch_tensorrt::core::lowering::passes::ReplaceAtenInt(sg);
+
+  auto tg = std::make_shared<torch::jit::Graph>();
+  torch::jit::parseIR(target_graph_no_inputs, tg.get());
+
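+  // findPatternMatches searches for the target graph (tg) as a pattern inside the
+  // lowered source graph (sg); a non-empty result means the rewrite produced the
+  // expected scalar-only form.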
+  ASSERT_TRUE(!torch::jit::findPatternMatches(*tg, *sg).empty());
+}
+
+TEST(LoweringPasses, RemoveAtenIntSizeTensorValuesAgree) {
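+  // Same pattern as the previous test, except the first operand comes from
+  // aten::size at runtime instead of a constant, so the rewritten graph still has
+  // to query the input shape before doing the scalar arithmetic.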
+  std::string source_graph = R"IR(
+    graph(%x.0: Tensor):
+      %10: int = prim::Constant[value=0]()
+      %100: int = aten::size(%x.0, %10)
+      %0: Tensor = prim::NumToTensor(%100)
+      %11: int = prim::Constant[value=9]()
+      %1: Tensor = prim::NumToTensor(%11)
+      %4: Tensor = aten::floor_divide(%1, %0)
+      %7: Tensor = aten::mul(%0, %4)
+      %8: Tensor = aten::mul(%7, %1)
+      %50: int = aten::Int(%8)
+      %5: Tensor = prim::NumToTensor(%50)
+      return (%5))IR";
+  std::string target_graph = R"IR(
+    graph(%x.0: Tensor):
+      %10: int = prim::Constant[value=0]()
+      %0: int = aten::size(%x.0, %10)
+      %1: int = prim::Constant[value=9]()
+      %4: int = aten::floordiv(%1, %0)
+      %7: int = aten::mul(%0, %4)
+      %40: int = aten::mul(%7, %1)
+      %5: Tensor = prim::NumToTensor(%40)
+      return (%5))IR";
+
+  auto g_in = std::make_shared<torch::jit::Graph>();
+  auto g_out = std::make_shared<torch::jit::Graph>();
+
+  auto in_0 = at::rand({2, 3, 5, 5}, {at::kCUDA});
+
+  torch::jit::parseIR(source_graph, g_in.get());
+  torch::jit::parseIR(target_graph, g_out.get());
+
+  auto jit_pre_results = torch_tensorrt::tests::util::EvaluateGraphJIT(g_in, {in_0});
+  auto jit_post_results = torch_tensorrt::tests::util::EvaluateGraphJIT(g_out, {in_0});
+
+  ASSERT_TRUE(
+      torch_tensorrt::tests::util::almostEqual(jit_pre_results[0].toTensor(), jit_post_results[0].toTensor(), 2e-6));
+
+  // Ensure the lowering pass transforms the first graph into the second
+  torch_tensorrt::core::util::logging::get_logger().set_reportable_log_level(
+      torch_tensorrt::core::util::logging::LogLevel::kGRAPH);
+  auto sg = std::make_shared<torch::jit::Graph>();
+  torch::jit::parseIR(source_graph, sg.get());
+
+  torch_tensorrt::core::lowering::passes::ReplaceAtenInt(sg);
+
+  auto tg = std::make_shared<torch::jit::Graph>();
+  torch::jit::parseIR(target_graph, tg.get());
+
+  ASSERT_TRUE(!torch::jit::findPatternMatches(*tg, *sg).empty());
+}
+
+TEST(LoweringPasses, RemoveAtenIntConstTensorValuesAgree) {
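+  // Here aten::Int reads the result of a floor_divide between a scalar input
+  // (wrapped by prim::NumToTensor) and a 0-D tensor constant; the pass should
+  // collapse the chain to a single aten::floordiv on ints, as in the target graph.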
+  // Ensure the lowering pass transforms the first graph into the second
+  std::string source_graph = R"IR(
+    graph(%0: int):
+      %1: Tensor = prim::Constant[value=[8]]()
+      %3: Tensor = prim::NumToTensor(%0)
+      %4: Tensor = aten::floor_divide(%3, %1)
+      %5: int = aten::Int(%4)
+      return (%5))IR";
+
+  std::string target_graph = R"IR(
+    graph(%0: int):
+      %1: Tensor = prim::Constant[value=[8]]()
+      %2: int = prim::Constant[value=8]()
+      %3: int = aten::floordiv(%0, %2)
+      return (%3))IR";
+
+  auto sg = std::make_shared<torch::jit::Graph>();
+  torch::jit::parseIR(source_graph, &*sg);
+
+  // Manually enter 0d tensor const for source
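+  // (swap the parsed constant for a true 0-D tensor built with c10::scalar_to_tensor(8),
+  // rewire its uses, and drop the original node; the same substitution is applied to
+  // the target graph below so both graphs carry identical constants)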
+  auto first_op_sg = *(sg->block()->nodes().begin());
+  torch::jit::Value* r_sg = sg->insertConstant(c10::scalar_to_tensor(8), c10::nullopt, first_op_sg->scope());
+  r_sg->copyMetadata(first_op_sg->output());
+  r_sg->setType(c10::TensorType::get());
+  first_op_sg->output()->replaceAllUsesWith(r_sg);
+  first_op_sg->destroy();
+
+  torch_tensorrt::core::lowering::passes::ReplaceAtenInt(sg);
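+  // ConstantPooling merges duplicate constants and Canonicalize renumbers values
+  // deterministically, so the final string comparison is insensitive to naming.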
+  torch::jit::ConstantPooling(sg);
+  sg = torch::jit::Canonicalize(sg, false);
+
+  auto tg = std::make_shared<torch::jit::Graph>();
+  torch::jit::parseIR(target_graph, &*tg);
+
+  // Manually enter 0d tensor const for target
+  auto first_op_tg = *(tg->block()->nodes().begin());
+  torch::jit::Value* r_tg = tg->insertConstant(c10::scalar_to_tensor(8), c10::nullopt, first_op_tg->scope());
+  r_tg->copyMetadata(first_op_tg->output());
+  r_tg->setType(c10::TensorType::get());
+  first_op_tg->output()->replaceAllUsesWith(r_tg);
+  first_op_tg->destroy();
+
+  torch::jit::ConstantPooling(tg);
+  tg = torch::jit::Canonicalize(tg, false);
+
+  // Validate identical graphs after pooling constants and canonicalizing
+  ASSERT_TRUE((tg->toString() == sg->toString()));
+}