@@ -46,32 +46,40 @@ class Legalizer : public ExprMutator {
4646 Expr new_e = ExprMutator::VisitExpr_ (call_node);
4747 Call new_call = Downcast<Call>(new_e);
4848
49+ // Check if the string is registered in the OpRegistry.
50+ if (!Op::HasAttr (legalize_map_attr_name_)) {
51+ return new_e;
52+ }
53+
4954 // Collect the registered legalize function.
5055 auto fop_legalize = Op::GetAttr<FTVMLegalize>(legalize_map_attr_name_);
51- Op op = Downcast<Op>(call_node->op );
52-
53- if (fop_legalize.count (op)) {
54- // Collect the new_args.
55- tvm::Array<Expr> call_args = new_call->args ;
56-
57- // Collect input and output dtypes to pass on to Legalize API.
58- tvm::Array<tvm::relay::Type> types;
59- for (auto arg : call_node->args ) {
60- types.push_back (arg->checked_type ());
61- }
62- types.push_back (call_node->checked_type ());
63-
64- // Transform the op by calling the registered legalize function.
65- Expr legalized_value = fop_legalize[op](call_node->attrs , call_args, types);
66-
67- // Reassign new_e if the transformation succeeded.
68- if (legalized_value.defined ()) {
69- // Check that the returned Expr from legalize is CallNode.
70- const CallNode* legalized_call_node = legalized_value.as <CallNode>();
71- CHECK (legalized_call_node)
72- << " Can only replace the original operator with another call node" ;
73-
74- new_e = legalized_value;
56+ auto call_op = call_node->op ;
57+ if (call_op.as <OpNode>()) {
58+ Op op = Downcast<Op>(call_node->op );
59+
60+ if (fop_legalize.count (op)) {
61+ // Collect the new_args.
62+ tvm::Array<Expr> call_args = new_call->args ;
63+
64+ // Collect input and output dtypes to pass on to Legalize API.
65+ tvm::Array<tvm::relay::Type> types;
66+ for (auto arg : call_node->args ) {
67+ types.push_back (arg->checked_type ());
68+ }
69+ types.push_back (call_node->checked_type ());
70+
71+ // Transform the op by calling the registered legalize function.
72+ Expr legalized_value = fop_legalize[op](call_node->attrs , call_args, types);
73+
74+ // Reassign new_e if the transformation succeeded.
75+ if (legalized_value.defined ()) {
76+ // Check that the returned Expr from legalize is CallNode.
77+ const CallNode* legalized_call_node = legalized_value.as <CallNode>();
78+ CHECK (legalized_call_node)
79+ << " Can only replace the original operator with another call node" ;
80+
81+ new_e = legalized_value;
82+ }
7583 }
7684 }
7785
@@ -95,7 +103,7 @@ Pass Legalize(const std::string& legalize_map_attr_name) {
95103 [=](Function f, Module m, PassContext pc) {
96104 return Downcast<Function>(relay::legalize::Legalize (f, legalize_map_attr_name));
97105 };
98- return CreateFunctionPass (pass_func, 3 , " Legalize" , {ir::StringImm::make (" InferType" )});
106+ return CreateFunctionPass (pass_func, 0 , " Legalize" , {ir::StringImm::make (" InferType" )});
99107}
100108
101109TVM_REGISTER_API (" relay._transform.Legalize" ).set_body_typed(Legalize);
0 commit comments