Commit b75b0ba

End2End Lowering (apache#23)

* call_dps lowering.
* Improve shape lowering.
* Support alloc_storage for dynamic shape.
* Implement ToNonDF to transform the program to non-dataflow format.
* Fix the mutator issue.
* Update build API; an issue occurred.
* VM tests can pass.
* Support shape tuple in executable serialization.
* Fix for test.
* Minor fixes.
* Address comments.
* Add mutate binding var back.
* Visit binding var and fix tests.

Co-authored-by: YuchenJin <[email protected]>

1 parent 6acf69e commit b75b0ba

File tree

14 files changed: +568 -185 lines changed

include/tvm/relax/expr_functor.h

Lines changed: 2 additions & 4 deletions

@@ -193,10 +193,7 @@ class ExprMutator : public ExprFunctor<Expr(const Expr&)> {
    * \return expr.
    */
   Expr Mutate(const Expr& expr) {
-    if (memo_.count(expr) == 0) {
-      memo_[expr] = this->VisitExpr(expr);
-    }
-    return Downcast<Expr>(memo_[expr]);
+    return this->VisitExpr(expr);
   }

   Expr VisitExpr(const Expr& expr) override;
@@ -226,6 +223,7 @@ class ExprMutator : public ExprFunctor<Expr(const Expr&)> {
   virtual void VisitBinding(const Binding& binding);
   virtual Var VisitVarBinding(const VarBinding& binding);
   virtual void VisitMatchShape(const MatchShape& binding);
+
   virtual BindingBlock VisitBindingBlock(const BindingBlock& block);
   virtual BindingBlock VisitDataflowBlock(const DataflowBlock& block);
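Dropping the memo table means Mutate now re-dispatches on every call instead of returning a cached node, which is what allows context-dependent rewrites such as per-binding var remapping. A hedged Python analogy of the behavioral change (illustrative only, not a TVM API):

# Old behavior: results were memoized, so revisiting a node could never
# produce a different rewrite than the first visit did.
class MemoizedMutator:
    def __init__(self, visit):
        self.visit = visit
        self.memo = {}

    def mutate(self, expr):
        if expr not in self.memo:
            self.memo[expr] = self.visit(expr)
        return self.memo[expr]

# New behavior: every call re-visits, so the same node can be rewritten
# differently depending on the surrounding context.
class PlainMutator:
    def __init__(self, visit):
        self.visit = visit

    def mutate(self, expr):
        return self.visit(expr)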

python/tvm/relax/exec_builder.py

Lines changed: 6 additions & 1 deletion

@@ -20,6 +20,7 @@
 import tvm
 from tvm._ffi._ctypes.packed_func import TVMRetValueHandle
 from tvm.runtime import Object
+from tvm.runtime.container import ShapeTuple
 from tvm._ffi.base import _LIB, check_call
 from . vm import Executable
 from . import _ffi_api
@@ -89,7 +90,11 @@ def emit_call(
             dst = SpecialReg.VOID_ARG
         args_ = []
         for arg in args:
-            if isinstance(arg, tvm.nd.NDArray) or isinstance(arg, tvm.DataType):
+            if isinstance(arg, tuple):
+                shape_tuple = ShapeTuple(arg)
+                new_arg = self.emit_constant(shape_tuple)
+                args_.append(new_arg)
+            elif isinstance(arg, (tvm.nd.NDArray, tvm.DataType, ShapeTuple)):
                 new_arg = self.emit_constant(arg)
                 args_.append(new_arg)
             else:
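With this change, a plain Python tuple passed to emit_call is serialized as a ShapeTuple constant. A minimal sketch in the style of this branch's VM tests; "test.shape_id" is a hypothetical packed function used only to show the calling convention:

import tvm
from tvm import relax as rx

ib = rx.ExecBuilder()
with ib.function("shape_example", num_inputs=0):
    # (2, 3) is wrapped into tvm.runtime.container.ShapeTuple and emitted
    # as a constant before the call instruction is serialized.
    ib.emit_call("test.shape_id", args=[(2, 3)], dst=ib.r(0))
    ib.emit_ret(ib.r(0))
ex = ib.get()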

python/tvm/relax/transform/transform.py

Lines changed: 31 additions & 7 deletions

@@ -16,9 +16,10 @@
 # under the License.
 # pylint: disable=no-else-return
 # pylint: disable=unidiomatic-typecheck
-from tvm import IRModule
+from tvm import IRModule
 from . import _ffi_api

+
 def fma_rewrite(expr):
     """Perform fused multiply add rewriting in dataflow blocks.

@@ -29,22 +30,45 @@ def fma_rewrite(expr):
     """
     return _ffi_api.fma_rewrite(expr)

-def explicit_memory_rewrite(expr):
-    """Perform explicit memory allocation for call_dps in dataflow blocks.
+def to_non_dataflow(mod: IRModule) -> IRModule:
+    """Transform all dataflow structure to non-dataflow version.

     Parameters
     ----------
-    expr : tvm.relay.Expr
-        The input expression.
+    mod : tvm.IRModule
+        The input module.
     """
-    return _ffi_api.explicit_memory_rewrite(expr)
+    return _ffi_api.to_non_dataflow(mod)
+
+
+def call_dps_rewrite(mod: IRModule) -> IRModule:
+    """Perform explicit memory allocation for call_dps.
+
+    Parameters
+    ----------
+    mod : tvm.IRModule
+        The input module.
+    """
+    return _ffi_api.call_dps_rewrite(mod)
+
+
+def memory_lower(mod: IRModule) -> IRModule:
+    """Perform memory lowering. Lower the relax.builtin.alloc_tensor op to VM builtin functions.
+
+    Parameters
+    ----------
+    mod : tvm.IRModule
+        The input module.
+    """
+    return _ffi_api.memory_lower(mod)


 def shape_lower(mod: IRModule) -> IRModule:
     """Lower the shape expression in relax to shape heap and TIR functions.

     Parameters
     ----------
-    expr : tvm.IRModule
+    mod : tvm.IRModule
         The input module.
     """
     return _ffi_api.shape_lower(mod)
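Together these wrappers form the lowering pipeline that relax.vm.build now applies; they can also be run one stage at a time to inspect the IR in between. A sketch, assuming mod is a relax IRModule:

from tvm.relax import transform

mod = transform.to_non_dataflow(mod)   # dataflow blocks -> plain binding blocks
mod = transform.call_dps_rewrite(mod)  # call_dps -> explicit alloc_tensor + packed call
mod = transform.memory_lower(mod)      # alloc_tensor -> vm.builtin.alloc_storage/alloc_tensor
mod = transform.shape_lower(mod)       # shape exprs -> shape heap + TIR shape functions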

python/tvm/relax/vm.py

Lines changed: 6 additions & 1 deletion

@@ -20,6 +20,7 @@
 from tvm.runtime import Object, Device, Module, PackedFunc
 from tvm._ffi.base import _LIB, check_call
 from . import _ffi_api
+from . import transform
 from ..rpc.base import RPC_SESS_MASK


@@ -164,5 +165,9 @@ def build(mod: tvm.IRModule,
     lib: tvm.runtime.Module
         A runtime module that contains generated code.
     """
-    ex, lib = _ffi_api.VMBuild(mod, target, target_host)
+    new_mod = transform.to_non_dataflow(mod)
+    new_mod = transform.call_dps_rewrite(new_mod)
+    new_mod = transform.memory_lower(new_mod)
+    new_mod = transform.shape_lower(new_mod)
+    ex, lib = _ffi_api.VMBuild(new_mod, target, target_host)
     return ex, lib
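With the passes folded into build, compiling and running stays a two-step affair. A hedged usage sketch following the test style of this branch; the module mod, the function name "foo", and the input shape are placeholders:

import numpy as np
import tvm
from tvm import relax

target = tvm.target.Target("llvm")
ex, lib = relax.vm.build(mod, target, target_host="llvm")

vm = relax.VirtualMachine(ex, tvm.cpu(), mod=lib)
inp = tvm.nd.array(np.random.rand(3, 4).astype("float32"))
res = vm["foo"](inp)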

src/relax/ir/expr_functor.cc

Lines changed: 21 additions & 3 deletions

@@ -120,13 +120,19 @@ void ExprVisitor::VisitBinding(const Binding& binding) {
   }
 }

-void ExprVisitor::VisitVarBinding(const VarBinding& binding) { this->VisitExpr(binding->value); }
+void ExprVisitor::VisitVarBinding(const VarBinding& binding) {
+  this->VisitExpr(binding->value);
+  this->VisitExpr(binding->var);
+}

 void ExprVisitor::VisitMatchShape(const MatchShape& binding) {
   this->VisitExpr(binding->value);
   // TODO(ziheng): should we change pattern from
   // Array<PrimExpr> to ShapeExpr?
   this->VisitExpr(ShapeExpr(binding->pattern));
+  if (binding->var.defined()) {
+    this->VisitExpr(binding->var);
+  }
 }

 void ExprVisitor::VisitBindingBlock(const BindingBlock& block) {
@@ -214,6 +220,10 @@ Expr ExprMutator::VisitExpr_(const VarNode* op) {
 }

 Expr ExprMutator::VisitExpr_(const DataflowVarNode* op) {
+  auto it = var_remap_.find(GetRef<Var>(op));
+  if (it != var_remap_.end()) {
+    return it->second;
+  }
   if (op->type_annotation.defined()) {
     Type type = this->VisitType(op->type_annotation.value());
     if (!op->type_annotation.same_as(type)) {
@@ -339,7 +349,7 @@ void ExprMutator::VisitBinding(const Binding& binding) {

 Var ExprMutator::VisitVarBinding(const VarBinding& binding) {
   Expr new_value = builder_->Normalize(this->Mutate(binding->value));
-  Var new_var = Downcast<Var>(this->Mutate(binding->var));
+
   // TODO(@altanh): this probably shouldn't live here, all passes would have to make sure to do it
   // in this method...
   // if (new_value->shape_.defined()) {
@@ -356,6 +366,7 @@ Var ExprMutator::VisitVarBinding(const VarBinding& binding) {
   //   new_var->checked_type_ = new_value->checked_type_;
   // }

+  Var new_var = Downcast<Var>(this->Mutate(binding->var));
   if (!builder_->CanProveShapeEqual(new_var->shape(), new_value->shape()) ||
       !StructuralEqual()(new_var->checked_type(), new_value->checked_type())) {
     new_var = Var(new_var->vid, NullOpt, NullOpt, new_var->span);
@@ -380,7 +391,14 @@ Var ExprMutator::VisitVarBinding(const VarBinding& binding) {
 void ExprMutator::VisitMatchShape(const MatchShape& binding) {
   Expr new_value = this->Mutate(binding->value);
   Expr new_pattern = this->Mutate(ShapeExpr(binding->pattern));
-  Var new_var = Downcast<Var>(this->Mutate(binding->var));
+  Var new_var;
+  if (binding->var.defined()) {
+    new_var = Downcast<Var>(this->Mutate(binding->var));
+  } else {
+    new_var = binding->var;
+  }
+
+  // TODO: when value's shape/type changed, create new var
   builder_->EmitMatchShape(
       MatchShape(new_value, Downcast<ShapeExpr>(new_pattern)->values, new_var));
 }
src/relax/transform/call_dps_rewrite.cc

Lines changed: 86 additions & 0 deletions

@@ -0,0 +1,86 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+/*!
+ * \file src/relax/transform/call_dps_rewrite.cc
+ * \brief
+ */
+#include <tvm/relax/attrs/memory.h>
+#include <tvm/relax/expr_functor.h>
+#include <tvm/relax/type.h>
+#include <tvm/tir/op.h>
+
+#include "../../relay/transforms/pattern_utils.h"
+
+namespace tvm {
+namespace relax {
+
+// ==================
+// CallDPSMutator
+// Example:
+// y: Tensor[n, m] = rx.call_dps((n, m), op.identity, (x))
+// -->
+// lv0 = rx.call("relax.builtin.alloc_tensor", [n, m])
+// rx.call_packed(op.identity, x, lv0)
+
+class CallDPSMutator : public ExprMutator {
+ public:
+  explicit CallDPSMutator(IRModule mod) { mod_ = mod; }
+
+  IRModule Lower() {
+    IRModule ret_mod = IRModule();
+    for (auto& p : mod_->functions) {
+      Expr func = p.second;
+      if (p.second->IsInstance<FunctionNode>()) {
+        func = this->Mutate(p.second);
+      }
+      ret_mod->Add(p.first, Downcast<BaseFunc>(func));
+    }
+    return ret_mod;
+  }
+
+  Expr VisitExpr_(const CallNode* call) override {
+    // post-order mutation
+    Expr expr = ExprMutator::VisitExpr_(call);
+    call = expr.as<CallNode>();
+    // TODO(@yuchen, @altanh): using mutate cause infinite recursion
+    // Expr expr = ExprMutator::Mutate(GetRef<Call>(call));
+
+    static const Op& call_dps_op = Op::Get("relax.call_dps");
+    static const Op& alloc_tensor_op = Op::Get("relax.builtin.alloc_tensor");
+
+    if (call->op == call_dps_op) {
+      ShapeExpr output_shape = Downcast<ShapeExpr>(call->args[0]);
+      Var tensor = builder_->Emit(Call(alloc_tensor_op, {call->args[0]}), "alloc");
+      builder_->Emit(Call(call->args[1], {call->args[2], tensor}), "_");
+      return tensor;
+    }
+
+    return GetRef<Expr>(call);
+  }
+
+ private:
+  IRModule mod_;
+};
+
+TVM_REGISTER_GLOBAL("relax.transform.call_dps_rewrite").set_body_typed([](IRModule mod) {
+  return CallDPSMutator(mod).Lower();
+});
+
+}  // namespace relax
+}  // namespace tvm
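Because the mutator is registered with TVM_REGISTER_GLOBAL, the pass is reachable from Python either through the relax.transform.call_dps_rewrite wrapper above or straight from the packed-function registry. A small sketch, assuming mod is a relax IRModule:

import tvm

# The name matches the TVM_REGISTER_GLOBAL string in this file.
rewrite = tvm.get_global_func("relax.transform.call_dps_rewrite")
mod = rewrite(mod)  # same effect as relax.transform.call_dps_rewrite(mod)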

src/relax/transform/memory_rewrite.cc

Lines changed: 51 additions & 30 deletions

@@ -20,6 +20,7 @@
  * \file src/relax/transform/memory_rewrite.cc
  * \brief
  */
+#include <tvm/relax/attrs/memory.h>
 #include <tvm/relax/expr_functor.h>
 #include <tvm/relax/type.h>
 #include <tvm/tir/op.h>
@@ -30,14 +31,31 @@ namespace tvm {
 namespace relax {

 // ==================
-// ExplicitMemMutator
+// MemLowerMutator
+// Lower the relax.builtin.alloc_tensor op to VM builtin functions.
 // Example:
-// y: Tensor[n, m] = rx.call_dps((n, m), op.identity, (x))
+// x = relax.builtin.alloc_tensor((m, n))
 // -->
-// lv0 = rx.call("relax.builtin.alloc_tensor", [n, m])
-// rx.call_packed(op.identity, x, lv0)
+// gv0 = relax.call_packed("vm.builtin.alloc_storage", (m * n), alignment, device_type,
+//   relax.attrs.AllocStorageAttrs)
+// gv1 = relax.call_packed("vm.builtin.alloc_tensor", gv0, offset, (m, n),
+//   relax.attrs.AllocTensorAttrs)
+
+class MemLowerMutator : public ExprMutator {
+ public:
+  explicit MemLowerMutator(IRModule mod) { mod_ = mod; }
+
+  IRModule Lower() {
+    IRModule ret_mod = IRModule();
+    for (auto& p : mod_->functions) {
+      Expr func = p.second;
+      if (p.second->IsInstance<FunctionNode>()) {
+        func = this->Mutate(p.second);
+      }
+      ret_mod->Add(p.first, Downcast<BaseFunc>(func));
+    }
+    return ret_mod;
+  }

-class ExplicitMemMutator : public ExprMutator {
   Expr ComputeStorageSize(const Expr& shape, const Type& type) const {
     DynTensorType tensor_type = Downcast<DynTensorType>(type);
     DataType dtype = DataType(tensor_type->dtype);
@@ -63,44 +81,47 @@
     return ret;
   }

-  BindingBlock VisitBindingBlock(const BindingBlock& block) {
-    builder_->BeginBindingBlock();
-    for (Binding binding : block->bindings) {
-      this->VisitBinding(binding);
-    }
-    return builder_->EndBlock();
-  }
-
   Expr VisitExpr_(const CallNode* call) override {
     // post-order mutation
     Expr expr = ExprMutator::VisitExpr_(call);
     call = expr.as<CallNode>();
-    // TODO(@yuchen, @altanh): using mutate cause infinite recursion
-    // Expr expr = ExprMutator::Mutate(GetRef<Call>(call));

-    static const Op& call_dps_op = Op::Get("relax.call_dps");
     static const Op& alloc_tensor_op = Op::Get("relax.builtin.alloc_tensor");

-    if (call->op == call_dps_op) {
-      ShapeExpr output_shape = Downcast<ShapeExpr>(call->args[0]);
-      Type arg_type = Downcast<Tuple>(call->args[2])->fields[0]->checked_type();
-      Expr output_size = ComputeStorageSize(output_shape, arg_type);
-      Var tensor = builder_->Emit(Call(alloc_tensor_op, {call->args[0]}), "alloc");
-      builder_->Emit(Call(call->args[1], {call->args[2], tensor}), "_");
-      return tensor;
+    if (call->op == alloc_tensor_op) {
+      ShapeExpr tensor_shape = Downcast<ShapeExpr>(call->args[0]);
+      // TODO(@yuchen): Get the type of input x; options: add an attr to relax.builtin.alloc_tensor
+      Type tensor_type = DynTensorType(tensor_shape->values.size(), DataType::Float(32));
+      Expr storage_size = ComputeStorageSize(tensor_shape, tensor_type);
+      ShapeExpr alignment = ShapeExpr({IntImm(DataType::Int(64), 64)});
+      ShapeExpr device_type = ShapeExpr({IntImm(DataType::Int(64), 1)});
+      auto storage_attr = make_object<AllocStorageAttrs>();
+      storage_attr->dtype = DataType::Float(32);
+      storage_attr->device_type = 1;
+
+      Var storage =
+          builder_->Emit(Call(ExternFunc("vm.builtin.alloc_storage"),
+                              {storage_size, alignment}, Attrs(storage_attr)),
+                         "storage");
+
+      ShapeExpr offset = ShapeExpr({IntImm(DataType::Int(64), 0)});
+      auto tensor_attr = make_object<AllocTensorAttrs>();
+      tensor_attr->dtype = DataType::Float(32);
+      Expr shape = call->args[0];
+      return builder_->Emit(
+          Call(ExternFunc("vm.builtin.alloc_tensor"), {storage, offset, shape}, Attrs(tensor_attr)),
+          "tensor");
     }

     return GetRef<Expr>(call);
   }
-};

-Expr ExplicitMemRewrite(const Expr& e) {
-  return ExplicitMemMutator().Mutate(e);
-}
+
+ private:
+  IRModule mod_;
+};

-TVM_REGISTER_GLOBAL("relax.transform.explicit_memory_rewrite")
-    .set_body_typed([](Expr expr) {
-      return ExplicitMemRewrite(expr);
+TVM_REGISTER_GLOBAL("relax.transform.memory_lower").set_body_typed([](IRModule mod) {
+  return MemLowerMutator(mod).Lower();
 });

 }  // namespace relax
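ComputeStorageSize, whose body is elided in this hunk, reduces for static shapes to bytes-per-element times the product of the extents. A hedged Python restatement of that arithmetic, assuming the usual TVM rounding of bits * lanes up to whole bytes:

def compute_storage_size(shape, bits=32, lanes=1):
    # Bytes needed for a tensor with static `shape` (assumed formula):
    # round bits up to whole bytes, then multiply by the element count.
    size = (bits * lanes + 7) // 8
    for extent in shape:
        size *= extent
    return size

# A (2, 3) float32 tensor occupies 6 * 4 = 24 bytes of storage.
assert compute_storage_size((2, 3)) == 24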
