Skip to content

Commit e2a9c61

Browse files
ZihengJiangyongwww
authored andcommitted
[PASS] Shape lowering (apache#16)
* [PASS] Shape lowering. * Update to IRModule based. * TIR function generation. * Improve. * Improve. * Improve test. * Improve. * Address comment.
1 parent b97d398 commit e2a9c61

File tree

9 files changed

+304
-34
lines changed

9 files changed

+304
-34
lines changed

include/tvm/relax/expr_functor.h

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -208,22 +208,22 @@ class ExprMutator : public ExprFunctor<Expr(const Expr&)> {
208208
* visitor for types which transform them appropriately.
209209
*/
210210
virtual Type VisitType(const Type& t);
211-
virtual void VisitBinding(const Binding& binding);
212-
virtual Var VisitVarBinding(const VarBinding& binding, IRBuilder& ir_builder);
213-
virtual void VisitMatchShape(const MatchShape& binding, IRBuilder& ir_builder);
211+
virtual void VisitBinding(const Binding& binding, IRBuilder& builder);
212+
virtual Var VisitVarBinding(const VarBinding& binding, IRBuilder& builder);
213+
virtual void VisitMatchShape(const MatchShape& binding, IRBuilder& builder);
214214
virtual BindingBlock VisitBindingBlock(const BindingBlock& block);
215215
virtual BindingBlock VisitDataflowBlock(const DataflowBlock& block);
216216

217217
protected:
218-
LazyIRBuilder irbuilder_;
218+
IRBuilder builder_;
219219
};
220220

221221
/*! \brief Dataflow Graph Rewriting for Custom Rewriting Passes
222222
*/
223223
class DataflowMutator : public ExprMutator {
224224
public:
225225
virtual BindingBlock VisitDataflowBlock(const DataflowBlock& block);
226-
virtual Var VisitVarBinding(const VarBinding& binding, IRBuilder& ir_builder);
226+
virtual Var VisitVarBinding(const VarBinding& binding, IRBuilder& builder);
227227

228228
protected:
229229
/*! \brief Look up the value binded to a var. */

python/tvm/relax/transform/transform.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
# under the License.
1717
# pylint: disable=no-else-return
1818
# pylint: disable=unidiomatic-typecheck
19+
from tvm import IRModule
1920
from . import _ffi_api
2021

2122
def fma_rewrite(expr):
@@ -37,3 +38,13 @@ def explicit_memory_rewrite(expr):
3738
The input expression.
3839
"""
3940
return _ffi_api.explicit_memory_rewrite(expr)
41+
42+
def shape_lower(mod: IRModule) -> IRModule:
43+
"""Lower the shape expression in relax to shape heap and TIR functions.
44+
45+
Parameters
46+
----------
47+
expr : tvm.IRModule
48+
The input module.
49+
"""
50+
return _ffi_api.shape_lower(mod)

src/printer/relax_script_printer.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -550,4 +550,4 @@ String AsRelaxScript(const ObjectRef& mod) {
550550
TVM_REGISTER_GLOBAL("script.AsRelaxScript").set_body_typed(AsRelaxScript);
551551

552552
} // namespace relax
553-
} // namespace tvm
553+
} // namespace tvm

src/relax/ir/expr_functor.cc

Lines changed: 30 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -140,6 +140,9 @@ void ExprVisitor::VisitVarBinding(const VarBinding& binding) {
140140

141141
void ExprVisitor::VisitMatchShape(const MatchShape& binding) {
142142
this->VisitExpr(binding->value);
143+
// TODO(ziheng): should we change pattern from
144+
// Array<PrimExpr> to ShapeExpr?
145+
this->VisitExpr(ShapeExpr(binding->pattern));
143146
}
144147

145148
void ExprVisitor::VisitBindingBlock(const BindingBlock& block) {
@@ -321,50 +324,53 @@ Expr ExprMutator::VisitExpr_(const SeqExprNode* op) {
321324

322325
Type ExprMutator::VisitType(const Type& t) { return t; }
323326

324-
void ExprMutator::VisitBinding(const Binding& binding) {
327+
void ExprMutator::VisitBinding(const Binding& binding, IRBuilder& builder) {
325328
Binding new_binding;
326329
if (binding.as<VarBindingNode>()) {
327-
this->VisitVarBinding(Downcast<VarBinding>(binding), this->irbuilder_);
330+
this->VisitVarBinding(Downcast<VarBinding>(binding), builder);
328331
} else if (binding.as<MatchShapeNode>()) {
329-
this->VisitMatchShape(Downcast<MatchShape>(binding), this->irbuilder_);
332+
this->VisitMatchShape(Downcast<MatchShape>(binding), builder);
330333
} else {
331334
LOG(FATAL) << "Wrong type.";
332335
}
333336
}
334337

335-
Var ExprMutator::VisitVarBinding(const VarBinding& binding, IRBuilder& ir_builder) {
338+
Var ExprMutator::VisitVarBinding(const VarBinding& binding, IRBuilder& builder) {
336339
Expr new_value = this->Mutate(binding->value);
337340
if (!binding->var.as<DataflowVarNode>()) {
338-
return ir_builder->EmitOutput(new_value);
341+
return builder->EmitOutput(new_value);
339342
} else {
340-
return ir_builder->Emit(Downcast<Call>(new_value));
343+
return builder->Emit(Downcast<Call>(new_value));
341344
}
342345
}
343346

344-
void ExprMutator::VisitMatchShape(const MatchShape& binding, IRBuilder& ir_builder) {
347+
void ExprMutator::VisitMatchShape(const MatchShape& binding, IRBuilder& builder) {
345348
this->Mutate(binding->value);
349+
this->Mutate(ShapeExpr(binding->pattern));
346350
}
347351

348352
BindingBlock ExprMutator::VisitBindingBlock(const BindingBlock& block) {
349353
if (block.as<DataflowBlockNode>()) {
350354
return this->VisitDataflowBlock(Downcast<DataflowBlock>(block));
351355
} else{
352-
// TODO
353-
return block;
356+
this->builder_ = IRBuilderNode::Create();
357+
for (auto binding : block->bindings) {
358+
this->VisitBinding(binding, this->builder_);
359+
}
360+
auto blocks = this->builder_->GetBlocks();
361+
return blocks.back();
354362
}
355363
}
356364

357365
BindingBlock ExprMutator::VisitDataflowBlock(const DataflowBlock& block) {
358-
this->irbuilder_ = LazyIRBuilderNode::Create(block);
366+
this->builder_ = LazyIRBuilderNode::Create(block);
359367
{
360-
With<DataflowScope> scope(this->irbuilder_);
368+
With<DataflowScope> scope(this->builder_);
361369
for (auto binding : block->bindings) {
362-
if (binding.as<VarBindingNode>()) {
363-
this->VisitVarBinding(Downcast<VarBinding>(binding), this->irbuilder_);
364-
}
370+
this->VisitBinding(binding, this->builder_);
365371
}
366372
}
367-
return this->irbuilder_->GetBlocks().back();
373+
return this->builder_->GetBlocks().back();
368374
}
369375

370376
Expr ExprMutator::VisitExpr(const Expr& expr) {
@@ -377,27 +383,27 @@ Expr ExprMutator::VisitExpr(const Expr& expr) {
377383
// DataflowMutator
378384

379385
BindingBlock DataflowMutator::VisitDataflowBlock(const DataflowBlock& block) {
380-
this->irbuilder_ = LazyIRBuilderNode::Create(block);
386+
this->builder_ = LazyIRBuilderNode::Create(block);
381387
{
382-
With<DataflowScope> scope(this->irbuilder_);
388+
With<DataflowScope> scope(this->builder_);
383389
for (auto binding : block->bindings) {
384390
if (auto* var_binding = binding.as<VarBindingNode>()) {
385-
Var var = this->VisitVarBinding(Downcast<VarBinding>(binding), this->irbuilder_);
391+
Var var = this->VisitVarBinding(Downcast<VarBinding>(binding), this->builder_);
386392
this->pre_post_var_map_[var_binding->var] = var;
387393
}
388394
}
389395
}
390-
return this->irbuilder_->GetBlocks().back();
396+
return this->builder_->GetBlocks().back();
391397
}
392398

393-
Var DataflowMutator::VisitVarBinding(const VarBinding& binding, IRBuilder& ir_builder) {
399+
Var DataflowMutator::VisitVarBinding(const VarBinding& binding, IRBuilder& builder) {
394400
Expr new_value = this->Mutate(binding->value);
395401
Var new_var;
396402
if (new_value.as<CallNode>()) {
397-
new_var = ir_builder->Emit(Downcast<Call>(new_value));
403+
new_var = builder->Emit(Downcast<Call>(new_value));
398404
}
399405
if (!binding->var.as<DataflowVarNode>()) {
400-
new_var = ir_builder->EmitOutput(new_value);
406+
new_var = builder->EmitOutput(new_value);
401407
}
402408
pre_post_var_map_[binding->var] = new_var;
403409
return new_var;
@@ -406,9 +412,9 @@ Var DataflowMutator::VisitVarBinding(const VarBinding& binding, IRBuilder& ir_bu
406412
Expr DataflowMutator::LookupVar(Var var) {
407413
auto it = pre_post_var_map_.find(var);
408414
if (it != pre_post_var_map_.end()) {
409-
return irbuilder_->LookupVar(it->first);
415+
return builder_->LookupVar(it->first);
410416
} else {
411-
return irbuilder_->LookupVar(var);
417+
return builder_->LookupVar(var);
412418
}
413419
}
414420
} // namespace relax

src/relax/ir/ir_builder.cc

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -142,6 +142,7 @@ Var IRBuilderNode::EmitMatchShape(const Expr& value, const Array<PrimExpr>& patt
142142
}
143143

144144
Var IRBuilderNode::Emit(const VarBinding& binding) {
145+
// FIXME(yuchen or ziheng): consider binding in normal block)
145146
if (!binding->var.as<DataflowVarNode>()) {
146147
return EmitOutput(binding->value);
147148
} else {
@@ -192,9 +193,14 @@ Expr IRBuilderNode::LookupVar(const Var& var) {
192193
return it->second;
193194
}
194195

195-
Function IRBuilderNode::Get() { return this->func_.func; }
196+
Function IRBuilderNode::Get() {
197+
return this->func_.func;
198+
}
196199

197-
std::vector<BindingBlock> IRBuilderNode::GetBlocks() { return this->func_.binding_blocks; }
200+
std::vector<BindingBlock> IRBuilderNode::GetBlocks() {
201+
this->BuildBlock();
202+
return this->func_.binding_blocks;
203+
}
198204

199205
bool IRBuilderNode::CanProveShapeEqual(const Expr& lhs, const Expr& rhs) {
200206
if (lhs == rhs) {

src/relax/op/op.cc

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,8 @@ Expr MakeCallDPS(Expr shape, Expr func, Tuple args) {
5757
return Call(op, {shape, func, args}, {}, {});
5858
}
5959

60-
TVM_REGISTER_GLOBAL("relax.op.call_dps").set_body_typed(MakeCallDPS);
60+
TVM_REGISTER_GLOBAL("relax.op.call_dps")
61+
.set_body_typed(MakeCallDPS);
6162

6263
// shape_of
6364

0 commit comments

Comments
 (0)