@@ -37,7 +37,9 @@ void LowerGraph(std::shared_ptr<torch::jit::Graph>& g, LowerInfo lower_info) {
3737 torch::jit::EliminateCommonSubexpression (g);
3838 }
3939 torch::jit::EliminateDeadCode (g);
40- passes::MarkNodesForFallback (g, true );
40+ if (lower_info.forced_fallback_modules .size () > 0 ) {
41+ passes::MarkNodesForFallback (g, true );
42+ }
4143 passes::UnpackHardSwish (g);
4244 passes::EliminateExceptionOrPassPattern (g);
4345 passes::ReduceToOperation (g);
@@ -60,12 +62,13 @@ void LowerGraph(std::shared_ptr<torch::jit::Graph>& g, LowerInfo lower_info) {
6062 LOG_GRAPH (*g);
6163}
6264
63- torch::jit::Module LowerModule (
64- const torch::jit::Module& mod,
65- std::string method_name,
66- std::unordered_set<std::string> forced_fallback_modules) {
67- passes::NotateModuleForFallback (mod, " " , method_name, forced_fallback_modules);
68- LOG_GRAPH (" After MLF notation pass: " << *mod.get_method (method_name).graph ());
65+ torch::jit::Module LowerModule (const torch::jit::Module& mod, std::string method_name, const LowerInfo& lower_info) {
66+ std::unordered_set<std::string> forced_fallback_modules (
67+ lower_info.forced_fallback_modules .begin (), lower_info.forced_fallback_modules .end ());
68+ if (forced_fallback_modules.size () > 0 ) {
69+ passes::NotateModuleForFallback (mod, " " , method_name, forced_fallback_modules);
70+ LOG_GRAPH (" After MLF notation pass: " << *mod.get_method (method_name).graph ());
71+ }
6972 auto mod_ = torch::jit::freeze_module (mod);
7073 LOG_GRAPH (" After freeze: " << *mod_.get_method (method_name).graph ());
7174 return mod_;
@@ -77,9 +80,7 @@ std::pair<std::shared_ptr<torch::jit::Graph>, std::vector<torch::jit::IValue>> L
7780 const LowerInfo& lower_info) {
7881 LOG_DEBUG (lower_info);
7982 LOG_GRAPH (" Before lowering: " << *mod.get_method (method_name).graph ());
80- std::unordered_set<std::string> forced_fallback_modules (
81- lower_info.forced_fallback_modules .begin (), lower_info.forced_fallback_modules .end ());
82- auto lowered_mod = lower_info.unfreeze_module ? mod : LowerModule (mod, method_name, forced_fallback_modules);
83+ auto lowered_mod = lower_info.unfreeze_module ? mod : LowerModule (mod, method_name, lower_info);
8384 auto g = lowered_mod.get_method (method_name).graph ();
8485
8586 LOG_GRAPH (" LibTorch Lowering" );
0 commit comments