2626#include < tvm/arith/analyzer.h>
2727#include < tvm/arith/int_solver.h>
2828#include < tvm/tir/stmt_functor.h>
29+ #include < tvm/tir/transform.h>
2930
3031#include < unordered_map>
3132#include < unordered_set>
@@ -90,13 +91,107 @@ Stmt MergeNest(const std::vector<std::vector<Stmt>>& nest, Stmt body) {
9091
9192class IRConvertSSA final : public StmtExprMutator {
9293 public:
93- PrimExpr VisitExpr_ (const VarNode* op) final {
94- if (scope_.count (op) && !scope_[op].empty ()) {
95- return scope_[op].back ();
96- } else {
97- return GetRef<PrimExpr>(op);
94+ PrimFunc VisitPrimFunc (PrimFunc func) {
95+ std::vector<ScopedRedefine> redefines;
96+
97+ // Remap parameters, if they were used in another function
98+ auto params = func->params .Map ([&](const tir::Var& var) -> tir::Var {
99+ if (defined_.count (var.get ())) {
100+ const ScopedRedefine& redefine = redefines.emplace_back (this , var);
101+ return redefine.new_var ;
102+ } else {
103+ defined_.insert (var.get ());
104+ return var;
105+ }
106+ });
107+
108+ // Remap implicitly defined buffer parameters
109+ {
110+ std::unordered_set<const VarNode*> defined_params;
111+ for (const auto & var : func->params ) {
112+ defined_params.insert (var.get ());
113+ }
114+ for (const auto & [var, buffer] : func->buffer_map ) {
115+ static_cast <void >(var); // gcc 7.x bug, https://gcc.gnu.org/bugzilla/show_bug.cgi?id=81767
116+ auto check_expr = [&](const PrimExpr& expr) {
117+ auto * var_ptr = expr.as <VarNode>();
118+ if (!var_ptr) return ;
119+ if (defined_params.count (var_ptr)) return ;
120+
121+ if (defined_.count (var_ptr)) {
122+ auto var = GetRef<Var>(var_ptr);
123+ redefines.emplace_back (this , var);
124+ } else {
125+ defined_.insert (var_ptr);
126+ }
127+ };
128+ for (const auto & dim : buffer->shape ) {
129+ check_expr (dim);
130+ }
131+ for (const auto & stride : buffer->strides ) {
132+ check_expr (stride);
133+ }
134+ check_expr (buffer->elem_offset );
135+ }
136+ }
137+
138+ // Update the buffer map, based on the redefined parameters
139+ auto buffer_map = [&]() {
140+ Map<Var, Buffer> buffer_map;
141+ bool made_change = false ;
142+ for (const auto & [var, buffer] : func->buffer_map ) {
143+ auto new_var = GetRemappedVar (var);
144+ auto new_buf = GetRemappedBuffer (buffer);
145+
146+ made_change = made_change || !var.same_as (new_var) || !buffer.same_as (new_buf);
147+ buffer_map.Set (new_var, new_buf);
148+ }
149+ if (made_change) {
150+ return buffer_map;
151+ } else {
152+ return func->buffer_map ;
153+ }
154+ }();
155+
156+ auto attrs = [&]() -> DictAttrs {
157+ Map<String, ObjectRef> dict;
158+ bool made_change = false ;
159+
160+ for (const auto & [key, old_value] : func->attrs ->dict ) {
161+ auto value = old_value;
162+ if (auto * expr = value.as <PrimExprNode>()) {
163+ value = VisitExpr (GetRef<PrimExpr>(expr));
164+ } else if (auto * stmt = value.as <StmtNode>()) {
165+ value = VisitStmt (GetRef<Stmt>(stmt));
166+ }
167+
168+ made_change = made_change || !value.same_as (old_value);
169+ dict.Set (key, value);
170+ }
171+
172+ if (made_change) {
173+ return DictAttrs (dict);
174+ } else {
175+ return func->attrs ;
176+ }
177+ }();
178+
179+ auto body = VisitStmt (func->body );
180+
181+ // If anything changed, update the returned function
182+ if (!params.same_as (func->params ) || !buffer_map.same_as (func->buffer_map ) ||
183+ !attrs.same_as (func->attrs ) || !body.same_as (func->body )) {
184+ func = PrimFunc (params, body, func->ret_type , buffer_map, attrs);
185+ }
186+
187+ // Pop the redefines in reverse order of creation
188+ while (redefines.size ()) {
189+ redefines.pop_back ();
98190 }
191+ return func;
99192 }
193+
194+ PrimExpr VisitExpr_ (const VarNode* op) final { return GetRemappedVar (GetRef<Var>(op)); }
100195 PrimExpr VisitExpr_ (const LetNode* op) final {
101196 const Var& v = op->var ;
102197 if (defined_.count (v.get ())) {
@@ -142,18 +237,27 @@ class IRConvertSSA final : public StmtExprMutator {
142237 return node;
143238 }
144239
240+ Var GetRemappedVar (Var var) {
241+ if (auto it = scope_.find (var.get ()); it != scope_.end () && it->second .size ()) {
242+ return it->second .back ();
243+ } else {
244+ return var;
245+ }
246+ }
247+
145248 Buffer GetRemappedBuffer (Buffer buf) {
146249 // Determine the buffer var that should be in the updated buffer,
147250 // given the current scope. If no redefines are present, then the
148251 // buffer var is unchanged.
149- Var new_buffer_var = buf->data ;
150- auto var_it = scope_. find (buf->data . get () );
151- if (var_it != scope_. end () && !var_it-> second . empty ()) {
152- new_buffer_var = var_it-> second . back ( );
153- }
252+ Var new_buffer_var = GetRemappedVar ( buf->data ) ;
253+ PrimExpr elem_offset = VisitExpr (buf->elem_offset );
254+ auto visit_expr = [ this ]( const PrimExpr& expr) { return VisitExpr (expr); };
255+ Array<PrimExpr> shape = buf-> shape . Map (visit_expr );
256+ Array<PrimExpr> strides = buf-> strides . Map (visit_expr);
154257
155258 // If no mapping is required, return the original buffer.
156- if (new_buffer_var.same_as (buf->data )) {
259+ if (new_buffer_var.same_as (buf->data ) && elem_offset.same_as (buf->elem_offset ) &&
260+ shape.same_as (buf->shape ) && strides.same_as (buf->strides )) {
157261 return buf;
158262 }
159263
@@ -169,9 +273,9 @@ class IRConvertSSA final : public StmtExprMutator {
169273 // new buffer, pushing it onto the scoped stack of existing
170274 // buffers. This will be popped when the new_buffer_var
171275 // redefinition is popped.
172- Buffer new_buf (new_buffer_var, buf->dtype , buf-> shape , buf-> strides , buf->elem_offset ,
173- buf->name , buf->data_alignment , buf->offset_factor , buf->buffer_type ,
174- buf->axis_separators , buf-> span );
276+ Buffer new_buf (new_buffer_var, buf->dtype , shape, strides, elem_offset, buf->name ,
277+ buf->data_alignment , buf->offset_factor , buf->buffer_type , buf->axis_separators ,
278+ buf->span );
175279 buffers.push_back (new_buf);
176280 return new_buf;
177281 }
@@ -239,16 +343,33 @@ class IRConvertSSA final : public StmtExprMutator {
239343 }
240344
241345 ~ScopedRedefine () {
242- parent->scope_ [old_var.get ()].pop_back ();
243- for (auto & kv : parent->buf_remap_ ) {
244- std::vector<Buffer>& buffers = kv.second ;
245- if (buffers.size () && (buffers.back ()->data .get () == new_var.get ())) {
246- buffers.pop_back ();
346+ if (parent) {
347+ parent->scope_ [old_var.get ()].pop_back ();
348+ for (auto & kv : parent->buf_remap_ ) {
349+ std::vector<Buffer>& buffers = kv.second ;
350+ if (buffers.size () && (buffers.back ()->data .get () == new_var.get ())) {
351+ buffers.pop_back ();
352+ }
247353 }
248354 }
249355 }
250356
251- IRConvertSSA* parent;
357+ ScopedRedefine& operator =(const ScopedRedefine&) = delete ;
358+ ScopedRedefine (const ScopedRedefine&) = delete ;
359+
360+ ScopedRedefine& operator =(ScopedRedefine&& other) {
361+ swap (other);
362+ return *this ;
363+ }
364+ ScopedRedefine (ScopedRedefine&& other) { swap (other); }
365+
366+ void swap (ScopedRedefine& other) {
367+ std::swap (parent, other.parent );
368+ std::swap (old_var, other.old_var );
369+ std::swap (new_var, other.new_var );
370+ }
371+
372+ IRConvertSSA* parent{nullptr };
252373 Var old_var;
253374 Var new_var;
254375 };
@@ -447,5 +568,30 @@ std::pair<PrimExpr, PrimExpr> GetAsyncWaitAttributes(const AttrStmtNode* op) {
447568 return std::make_pair (op->value , inner->value );
448569}
449570
571+ namespace transform {
572+ Pass ConvertSSA () {
573+ auto pass_func = [](IRModule mod, PassContext ctx) {
574+ tir::IRConvertSSA converter;
575+ Map<GlobalVar, BaseFunc> functions;
576+ bool made_change = false ;
577+ for (auto [gvar, base_func] : mod->functions ) {
578+ if (auto * ptr = base_func.as <tir::PrimFuncNode>()) {
579+ auto updated = converter.VisitPrimFunc (GetRef<tir::PrimFunc>(ptr));
580+ if (!updated.same_as (base_func)) {
581+ made_change = true ;
582+ base_func = updated;
583+ }
584+ }
585+ functions.Set (gvar, base_func);
586+ }
587+ if (made_change) {
588+ mod.CopyOnWrite ()->functions = std::move (functions);
589+ }
590+ return mod;
591+ };
592+ return tvm::transform::CreateModulePass (pass_func, 0 , " tir.ConvertSSA" , {});
593+ }
594+
595+ } // namespace transform
450596} // namespace tir
451597} // namespace tvm
0 commit comments