2828#include < tvm/tir/op.h>
2929
3030#include " ../../../op/call/call.h"
31+ #include " tvm/tir/function.h"
3132
3233namespace tvm {
3334namespace relay {
3435namespace contrib {
3536namespace example_target_hooks {
3637
38+ namespace {
39+
40+ /* !
41+ * \brief An example mutator for a "RelayToTIR" custom pass. Replaces every call to a Relay
42+ * Function with "external_symbol" attribute of "replace_add_with_subtract" with a call to a
43+ * TIR PrimFunc implementing subtraction.
44+ *
45+ * Illustrates six aspects a custom 'lowering' style pass may need to account for:
46+ * - Lowerable functions can appear inline as calls ops, bound to let-bound variables, or as
47+ * global functions.
48+ * - Let-bound lowerable functions should be inlined on-the-fly since after processing the
49+ * let-binding is no longer required.
50+ * - There may be multiple calls to the same lowerable function. All calls need to be
51+ * rewritten, even though the function itself need be rewritten only once.
52+ * - GlobalVars must be shared between all calls and the new definition itself.
53+ * - Calls to lowered functions must use the "call_lowered" calling convention.
54+ * - The Target::Current() may hold an instance of the TargetKind from which the custom Pass
55+ * was extracted.
56+ *
57+ * Though not illustrated here, it is also valid for a "RelayToTIR" custom pass to add
58+ * runtime::Modules to the output IRModule's "external_mods" attribute.
59+ */
3760class ConvertAddToSubtract : public MixedModeMutator {
3861 public:
3962 explicit ConvertAddToSubtract (IRModule ir_module, Target host_target)
@@ -56,51 +79,105 @@ class ConvertAddToSubtract : public MixedModeMutator {
5679 return tir::BufferLoad (buffer, {index});
5780 }
5881
59- void ReplaceAddWithSubtractPrimFunc (const GlobalVar& new_global_var, const Function& func) {
60- tir::Buffer x_buffer = tir::decl_buffer ({8 }, DataType::Float (32 ), " x" );
61- tir::Buffer y_buffer = tir::decl_buffer ({8 }, DataType::Float (32 ), " y" );
62- tir::Buffer out_buffer = tir::decl_buffer ({8 }, DataType::Float (32 ));
82+ GlobalVar ReplaceAddWithSubtractPrimFunc (const Function& func) {
83+ auto func_name = func->GetAttr <String>(::tvm::attr::kGlobalSymbol );
84+ ICHECK (func_name.defined ());
85+
86+ // --------------------------------------------------------------------------------------------
87+ // Cases:
88+ // - Inline function:
89+ // - First encounter: create global var, rewrite to PrimFunc, add binding, replace call.
90+ // - Thereafter (via object sharing): discover global var already in module, replace call
91+ // - Global function:
92+ // - func_name == global_var->name_hint
93+ // - First encounter: rewrite to PrimFunc and update binding, replace call
94+ // - Thereafter (via global var): Just replace call
95+ // - func_name != global_var->name_hint
96+ // - First encounter: create global var, rewrite to PrimFunc, add binding, replace call
97+ // (The original Relay function should also be tagged as 'extern', ie given attribute
98+ // "ExternalSymbol".)
99+ // - Thereafter (via global var): discover global var already in module, replace call
100+ // --------------------------------------------------------------------------------------------
101+
102+ // If necessary, introduce a new global var to map the function to and copy the source type
103+ // over for InferType.
104+ GlobalVar global_var;
105+ bool need_rewriting;
106+ if (ir_module_->ContainGlobalVar (func_name.value ())) {
107+ global_var = ir_module_->GetGlobalVar (func_name.value ());
108+ // Only rewrite to a PrimFunc if the global definition is still a Relay function.
109+ need_rewriting = ir_module_->Lookup (global_var)->IsInstance <FunctionNode>();
110+ } else {
111+ global_var = GlobalVar (func_name.value ());
112+ global_var->checked_type_ = func->checked_type ();
113+ need_rewriting = true ;
114+ }
115+
116+ // For illustration only, check if the current target matches the example_target_hook kind,
117+ // and if so extract the example attribute value.
118+ int64_t example_attribute_value = 0 ;
119+ Optional<Target> opt_current_target = Target::Current ();
120+ if (opt_current_target.defined () &&
121+ opt_current_target.value ()->kind ->name == " example_target_hook" ) {
122+ example_attribute_value =
123+ opt_current_target.value ()->GetAttr <Integer>(" example_attribute" ).value ()->value ;
124+ }
63125
64- tir::Var x_var (" x" , DataType::Handle ());
65- tir::Var y_var (" y" , DataType::Handle ());
66- tir::Var out_var (" out" , DataType::Handle ());
126+ if (need_rewriting) {
127+ // The called function is still in Relay form. Convert to TIR.
128+ tir::Buffer x_buffer = tir::decl_buffer ({8 }, DataType::Float (32 ), " x" );
129+ tir::Buffer y_buffer = tir::decl_buffer ({8 }, DataType::Float (32 ), " y" );
130+ tir::Buffer out_buffer = tir::decl_buffer ({8 }, DataType::Float (32 ));
67131
68- Map<String, ObjectRef> dict_attrs ;
69- dict_attrs. Set ( " global_symbol " , new_global_var-> name_hint );
70- dict_attrs. Set ( " tir.noalias " , Bool ( true ));
132+ tir::Var x_var ( " x " , DataType::Handle ()) ;
133+ tir::Var y_var ( " y " , DataType::Handle () );
134+ tir::Var out_var ( " out " , DataType::Handle ( ));
71135
72- te::Var index (" index" , DataType::Int (32 ));
73- tir::Sub indexed_sub = tir::Sub (LoadIndex (x_buffer, index), LoadIndex (y_buffer, index));
74- tir::Stmt math_body = tir::BufferStore (out_buffer, indexed_sub, {index});
75- tir::Stmt math_loop = tir::For (index, 0 , 8 , tir::ForKind::kSerial , math_body);
136+ Map<String, ObjectRef> dict_attrs;
137+ dict_attrs.Set (" global_symbol" , global_var->name_hint );
138+ dict_attrs.Set (" tir.noalias" , Bool (true ));
76139
77- Map<tir::Var, tir::Buffer> buffer_map = {
78- {x_var, x_buffer},
79- {y_var, y_buffer},
80- {out_var, out_buffer},
81- };
140+ te::Var index (" index" , DataType::Int (32 ));
141+ tir::Sub indexed_sub = tir::Sub (LoadIndex (x_buffer, index), LoadIndex (y_buffer, index));
142+ if (example_attribute_value > 0 ) {
143+ // For illustration only, fold the example attribute into the result.
144+ indexed_sub = tir::Sub (indexed_sub, FloatImm (DataType::Float (32 ),
145+ static_cast <double >(example_attribute_value)));
146+ }
82147
83- tir::PrimFunc replacement_func = tir::PrimFunc ({x_var, y_var, out_var}, math_loop, VoidType (),
84- buffer_map, {}, DictAttrs (dict_attrs) );
148+ tir::Stmt math_body = tir::BufferStore (out_buffer, indexed_sub, {index});
149+ tir::Stmt math_loop = tir::For (index, 0 , 8 , tir::ForKind:: kSerial , math_body );
85150
86- // Switch to TIRToRuntime hook for testing
87- Bool tir_to_runtime = func->GetAttr <Bool>(" tir_to_runtime" ).value_or (Bool (false ));
88- if (tir_to_runtime) {
89- replacement_func = WithAttr (replacement_func, ::tvm::attr::kTarget , custom_target_);
90- } else {
91- replacement_func = WithAttr (replacement_func, ::tvm::attr::kTarget , host_target_);
151+ Map<tir::Var, tir::Buffer> buffer_map = {
152+ {x_var, x_buffer},
153+ {y_var, y_buffer},
154+ {out_var, out_buffer},
155+ };
156+
157+ tir::PrimFunc replacement_func = tir::PrimFunc ({x_var, y_var, out_var}, math_loop, VoidType (),
158+ buffer_map, {}, DictAttrs (dict_attrs));
159+
160+ // Switch to TIRToRuntime hook for testing
161+ Bool tir_to_runtime = func->GetAttr <Bool>(" tir_to_runtime" ).value_or (Bool (false ));
162+ if (tir_to_runtime) {
163+ replacement_func = WithAttr (replacement_func, ::tvm::attr::kTarget , custom_target_);
164+ } else {
165+ replacement_func = WithAttr (replacement_func, ::tvm::attr::kTarget , host_target_);
166+ }
167+
168+ ir_module_->Update (global_var, replacement_func); // Will Add if global_var is new.
92169 }
93170
94- ir_module_-> Add (new_global_var, replacement_func) ;
171+ return global_var ;
95172 }
96173
97174 Expr VisitExpr_ (const LetNode* op) final {
98175 auto pre_visit = [this ](const LetNode* op) {
99176 Expr var = this ->VisitExpr (op->var );
100177 Expr value = this ->VisitExpr (op->value );
101178
102- // Outlineable function no longer needs let binding
103- if ( this -> CanLowerExpr ( value)) {
179+ if ( AsLowerableFunction (value)) {
180+ // Inline on-the-fly if the let-bound value is lowerable.
104181 this ->memo_ [var] = value;
105182 }
106183 };
@@ -110,8 +187,8 @@ class ConvertAddToSubtract : public MixedModeMutator {
110187 Expr body = this ->VisitExpr (op->body );
111188 auto expr = GetRef<Expr>(op);
112189
113- // Drop the let binding
114- if ( this -> CanLowerExpr (value)) {
190+ if ( AsLowerableFunction (value)) {
191+ // The let binding is no longer needed since inlined on-the-fly above.
115192 this ->memo_ [expr] = this ->VisitExpr (op->body );
116193 } else {
117194 Var var = Downcast<Var>(this ->VisitExpr (op->var ));
@@ -126,39 +203,49 @@ class ConvertAddToSubtract : public MixedModeMutator {
126203 return memo_[GetRef<Expr>(op)];
127204 }
128205
129- bool CanLowerExpr (const Expr& expr) {
130- const auto * func = expr.as <FunctionNode>();
131- if (func == nullptr ) {
132- return false ;
133- }
134- auto func_name = func->GetAttr <String>(::tvm::attr::kGlobalSymbol );
135- if (!func_name.defined ()) {
136- return false ;
206+ const FunctionNode* AsLowerableFunction (const Expr& expr) {
207+ if (const auto * function_node = expr.as <FunctionNode>()) {
208+ auto func_name = function_node->GetAttr <String>(::tvm::attr::kGlobalSymbol );
209+ if (!func_name.defined ()) {
210+ return nullptr ;
211+ }
212+ if (func_name != " replace_add_with_subtract" ) {
213+ return nullptr ;
214+ }
215+ return function_node;
216+ } else if (const auto * global_var_node = expr.as <GlobalVarNode>()) {
217+ return AsLowerableFunction (ir_module_->Lookup (GetRef<GlobalVar>(global_var_node)));
218+ } else {
219+ return nullptr ;
137220 }
138- if (func_name != " replace_add_with_subtract" ) {
139- return false ;
221+ }
222+
223+ const GlobalVarNode* AsAlreadyLoweredFunction (const Expr& expr) {
224+ if (const auto * global_var_node = expr.as <GlobalVarNode>()) {
225+ if (ir_module_->Lookup (GetRef<GlobalVar>(global_var_node)).as <tir::PrimFuncNode>()) {
226+ return global_var_node;
227+ }
140228 }
141- return true ;
229+ return nullptr ;
142230 }
143231
144232 Expr Rewrite_ (const CallNode* pre , const Expr& post ) override {
145- if (const CallNode* call = post .as <CallNode>()) {
146- if (CanLowerExpr (call->op )) {
147- auto * func = call->op .as <FunctionNode>();
148- auto func_name = func->GetAttr <String>(::tvm::attr::kGlobalSymbol );
149-
150- // Introduce a new global var to map the function to and copy the source type
151- // over for InferType
152- GlobalVar new_global_var (func_name.value ());
153- new_global_var->checked_type_ = func->checked_type ();
154- ReplaceAddWithSubtractPrimFunc (new_global_var, GetRef<Function>(func));
155-
233+ if (const auto * call = post .as <CallNode>()) {
234+ GlobalVar new_op;
235+ if (const auto * function_node = AsLowerableFunction (call->op )) {
236+ // Add or replace the function with a PrimFunc.
237+ new_op = ReplaceAddWithSubtractPrimFunc (GetRef<Function>(function_node));
238+ } else if (const auto * global_var_node = AsAlreadyLoweredFunction (call->op )) {
239+ // The function has already been rewritten, so we just need to update the call.
240+ new_op = GetRef<GlobalVar>(global_var_node);
241+ }
242+ if (new_op.defined ()) {
156243 // Since we are replacing the Relay function with a call to a TIR function, we must use
157244 // the call_lowered op.
158245 CallLoweredAttrs attrs;
159246 attrs.metadata .Set (" relay_attrs" , call->attrs );
160247 ICHECK (call->type_args .empty ()) << " lowered functions cannot be polymorphic" ;
161- return CallLowered (std::move (new_global_var ), call->args , std::move (attrs), call->span );
248+ return CallLowered (std::move (new_op ), call->args , std::move (attrs), call->span );
162249 }
163250 }
164251
@@ -171,10 +258,12 @@ class ConvertAddToSubtract : public MixedModeMutator {
171258 Target custom_target_;
172259};
173260
261+ } // namespace
262+
174263transform::Pass RelayToTIR () {
175264 runtime::TypedPackedFunc<IRModule (IRModule, transform::PassContext)> pass_func =
176265 [=](IRModule ir_module, transform::PassContext pass_context) {
177- auto relay_to_tir = ConvertAddToSubtract ( ir_module, Target (" c" ));
266+ ConvertAddToSubtract relay_to_tir ( std::move ( ir_module) , Target (" c" ));
178267 return relay_to_tir.Mutate ();
179268 };
180269 return tvm::transform::CreateModulePass (pass_func, 0 , " RelayToTIR" , {});
0 commit comments