2020 * \file src/relax/transform/memory_rewrite.cc
2121 * \brief
2222 */
23+ #include < tvm/relax/attrs/memory.h>
2324#include < tvm/relax/expr_functor.h>
2425#include < tvm/relax/type.h>
2526#include < tvm/tir/op.h>
@@ -30,14 +31,31 @@ namespace tvm {
3031namespace relax {
3132
3233// ==================
33- // ExplicitMemMutator
34+ // MemLowerMutator
35+ // Lower the relax.builtin.alloc_tensor op to VM builtin functions.
3436// Example:
35- // y: Tensor[n, m] = rx.call_dps((n, m), op.identity, (x ))
37+ // x = relax.builtin.alloc_tensor((m, n ))
3638// -->
37- // lv0 = rx.call("relax.builtin.alloc_tensor", [n, m])
38- // rx.call_packed(op.identity, x, lv0)
39+ // gv0 = relax.call_packed("vm.builtin.alloc_storage", (m * n), alignment, device_type,
40+ // relax.attrs.AllocStorageAttrs) gv1 = relax.call_packed("vm.builtin.alloc_tensor", gv0, offset,
41+ // (m, n), relax.attrs.AllocTensorAttrs)
42+
43+ class MemLowerMutator : public ExprMutator {
44+ public:
45+ explicit MemLowerMutator (IRModule mod) { mod_ = mod; }
46+
47+ IRModule Lower () {
48+ IRModule ret_mod = IRModule ();
49+ for (auto & p : mod_->functions ) {
50+ Expr func = p.second ;
51+ if (p.second ->IsInstance <FunctionNode>()) {
52+ func = this ->Mutate (p.second );
53+ }
54+ ret_mod->Add (p.first , Downcast<BaseFunc>(func));
55+ }
56+ return ret_mod;
57+ }
3958
40- class ExplicitMemMutator : public ExprMutator {
4159 Expr ComputeStorageSize (const Expr& shape, const Type& type) const {
4260 DynTensorType tensor_type = Downcast<DynTensorType>(type);
4361 DataType dtype = DataType (tensor_type->dtype );
@@ -63,44 +81,47 @@ class ExplicitMemMutator : public ExprMutator {
6381 return ret;
6482 }
6583
66- BindingBlock VisitBindingBlock (const BindingBlock& block) {
67- builder_->BeginBindingBlock ();
68- for (Binding binding : block->bindings ) {
69- this ->VisitBinding (binding);
70- }
71- return builder_->EndBlock ();
72- }
73-
7484 Expr VisitExpr_ (const CallNode* call) override {
7585 // post-order mutation
7686 Expr expr = ExprMutator::VisitExpr_ (call);
7787 call = expr.as <CallNode>();
78- // TODO(@yuchen, @altanh): using mutate cause infinite recursion
79- // Expr expr = ExprMutator::Mutate(GetRef<Call>(call));
8088
81- static const Op& call_dps_op = Op::Get (" relax.call_dps" );
8289 static const Op& alloc_tensor_op = Op::Get (" relax.builtin.alloc_tensor" );
8390
84- if (call->op == call_dps_op) {
85- ShapeExpr output_shape = Downcast<ShapeExpr>(call->args [0 ]);
86- Type arg_type = Downcast<Tuple>(call->args [2 ])->fields [0 ]->checked_type ();
87- Expr output_size = ComputeStorageSize (output_shape, arg_type);
88- Var tensor = builder_->Emit (Call (alloc_tensor_op, {call->args [0 ]}), " alloc" );
89- builder_->Emit (Call (call->args [1 ], {call->args [2 ], tensor}), " _" );
90- return tensor;
91+ if (call->op == alloc_tensor_op) {
92+ ShapeExpr tensor_shape = Downcast<ShapeExpr>(call->args [0 ]);
93+ // TODO(@yuchen): Get the type of input x, options: add an attr to relax.builtin.alloc_tensor
94+ Type tensor_type = DynTensorType (tensor_shape->values .size (), DataType::Float (32 ));
95+ Expr storage_size = ComputeStorageSize (tensor_shape, tensor_type);
96+ ShapeExpr alignment = ShapeExpr ({IntImm (DataType::Int (64 ), 64 )});
97+ ShapeExpr device_type = ShapeExpr ({IntImm (DataType::Int (64 ), 1 )});
98+ auto storage_attr = make_object<AllocStorageAttrs>();
99+ storage_attr->dtype = DataType::Float (32 );
100+ storage_attr->device_type = 1 ;
101+
102+ Var storage =
103+ builder_->Emit (Call (ExternFunc (" vm.builtin.alloc_storage" ),
104+ {storage_size, alignment}, Attrs (storage_attr)),
105+ " storage" );
106+
107+ ShapeExpr offset = ShapeExpr ({IntImm (DataType::Int (64 ), 0 )});
108+ auto tensor_attr = make_object<AllocTensorAttrs>();
109+ tensor_attr->dtype = DataType::Float (32 );
110+ Expr shape = call->args [0 ];
111+ return builder_->Emit (
112+ Call (ExternFunc (" vm.builtin.alloc_tensor" ), {storage, offset, shape}, Attrs (tensor_attr)),
113+ " tensor" );
91114 }
92115
93116 return GetRef<Expr>(call);
94117 }
95- };
96118
97- Expr ExplicitMemRewrite ( const Expr& e) {
98- return ExplicitMemMutator (). Mutate (e);
99- }
119+ private:
120+ IRModule mod_;
121+ };
100122
101- TVM_REGISTER_GLOBAL (" relax.transform.explicit_memory_rewrite" )
102- .set_body_typed([](Expr expr) {
103- return ExplicitMemRewrite (expr);
123+ TVM_REGISTER_GLOBAL (" relax.transform.memory_lower" ).set_body_typed([](IRModule mod) {
124+ return MemLowerMutator (mod).Lower ();
104125});
105126
106127} // namespace relax
0 commit comments