Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
26 commits
Select commit Hold shift + click to select a range
ce29d83
[TIR] Introduce tir.allocate_const to TIR
Jun 30, 2021
998fb0e
[TIR] Integrate tir constant nodes in compilation pipeline
Jun 30, 2021
2e144d0
[TIR] tir.const extraction
d-smirnov Nov 11, 2021
f539d53
unit test fixed
d-smirnov Nov 17, 2021
bf3e6f2
cmsis-nn codegen fix
d-smirnov Nov 26, 2021
083f17b
Fixes for unittests
d-smirnov Nov 30, 2021
4bd2946
Rebasing tests fixes
d-smirnov Dec 8, 2021
bae53d7
Linter: added method param description
d-smirnov Dec 8, 2021
5ae2012
Printing removal fix
d-smirnov Dec 9, 2021
bcd42e2
Bugfix
d-smirnov Dec 23, 2021
1ddc762
Reworked logic for not to introduce empty constant list to modue attrs
d-smirnov Jan 6, 2022
7926fe4
Added support for tir builtin::tvm_access_ptr
d-smirnov Jan 11, 2022
2699375
Unit test fix
d-smirnov Jan 13, 2022
c73b9b7
Addressed requested changes
d-smirnov Jan 18, 2022
0447323
Namespace usage changed to conform earlier C++ standard
d-smirnov Jan 18, 2022
aea4331
Bugfix
d-smirnov Jan 19, 2022
ef759f9
updated IRModuleNode::ExtractPrimFuncConstants
d-smirnov Jan 27, 2022
c5239da
Minor changes
d-smirnov Jan 28, 2022
e2c3df6
Moved LinkedParam/LinkedParamNode
d-smirnov Jan 31, 2022
3818327
Addressed upstream comments
d-smirnov Feb 11, 2022
66bdcb7
Removed unnecessary forward declaration
d-smirnov Feb 16, 2022
d19ba1b
Constant extractor now is a separate pass
d-smirnov Feb 16, 2022
d452f19
Added forgotten file + unit test fix
d-smirnov Feb 16, 2022
4d8ed0d
Changed to IRModule pass
d-smirnov Feb 17, 2022
d03a5a2
bugfix after rebasing
d-smirnov Feb 21, 2022
b00ead9
-v -> -vv to have more debug information
d-smirnov Feb 22, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 42 additions & 0 deletions include/tvm/ir/module.h
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,44 @@
#include <vector>

namespace tvm {
/*!
* \brief Describes one parameter that should be linked into the generated module.
*
* When parameters are to be linked in with generated code (i.e. on target_host-compatible
* backends), Relay attaches instances of this object to a global TIR function. Code-generators
* use the information contained in this node to include the parameter data in the generated
* module.
*/
class LinkedParamNode : public Object {
public:
/*! \brief Unique numeric identifier used by runtimes to lookup this parameter. */
int64_t id;

/*! \brief Parameter data which should get linked into the final module. */
::tvm::runtime::NDArray param;

void VisitAttrs(tvm::AttrVisitor* v) {
v->Visit("id", &id);
v->Visit("param", &param);
}

static constexpr const char* _type_key = "tir.LinkedParam";
TVM_DECLARE_FINAL_OBJECT_INFO(LinkedParamNode, Object);
};

/*!
* \brief Managed reference to LinkedParamNode.
*/
class LinkedParam : public ObjectRef {
public:
TVM_DLL LinkedParam(int64_t id, tvm::runtime::NDArray param);

TVM_DEFINE_OBJECT_REF_METHODS(LinkedParam, ObjectRef, LinkedParamNode);
TVM_DEFINE_OBJECT_REF_COW_METHOD(LinkedParamNode);
};

class IRModule;

/*!
* \brief IRModule that holds functions and type definitions.
*
Expand Down Expand Up @@ -504,6 +541,11 @@ constexpr const char* kRuntime = "runtime";
*/
constexpr const char* kWorkspaceMemoryPools = "workspace_memory_pools";

/*
* \brief Module attribute for tir constants
*/
constexpr const char* kConstantsArray = "Constants";

} // namespace attr
} // namespace tvm
#endif // TVM_IR_MODULE_H_
9 changes: 9 additions & 0 deletions include/tvm/node/structural_hash.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@

#include <tvm/node/functor.h>
#include <tvm/runtime/data_type.h>
#include <tvm/runtime/ndarray.h>

#include <functional>
#include <string>
Expand Down Expand Up @@ -199,5 +200,13 @@ class SHashReducer {
bool map_free_vars_;
};

class SEqualReducer;
struct NDArrayContainerTrait {
static constexpr const std::nullptr_t VisitAttrs = nullptr;
static void SHashReduce(const runtime::NDArray::Container* key, SHashReducer hash_reduce);
static bool SEqualReduce(const runtime::NDArray::Container* lhs,
const runtime::NDArray::Container* rhs, SEqualReducer equal);
};

} // namespace tvm
#endif // TVM_NODE_STRUCTURAL_HASH_H_
7 changes: 4 additions & 3 deletions include/tvm/relay/executor.h
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,8 @@ class ExecutorNode : public Object {
}

static constexpr const char* _type_key = "Executor";
static constexpr const bool _type_has_method_sequal_reduce = true;
static constexpr const bool _type_has_method_shash_reduce = true;
TVM_DECLARE_FINAL_OBJECT_INFO(ExecutorNode, Object);
};

Expand All @@ -122,8 +124,6 @@ class ExecutorNode : public Object {
*/
class Executor : public ObjectRef {
public:
Executor() = default;

/*!
* \brief Create a new Executor object using the registry
* \throws Error if name is not registered
Expand All @@ -147,7 +147,8 @@ class Executor : public ObjectRef {
TVM_DLL static Map<String, String> ListExecutorOptions(const String& name);

/*! \brief specify container node */
TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(Executor, ObjectRef, ExecutorNode);
TVM_DEFINE_OBJECT_REF_METHODS(Executor, ObjectRef, ExecutorNode);
TVM_DEFINE_OBJECT_REF_COW_METHOD(ExecutorNode)

private:
/*!
Expand Down
4 changes: 3 additions & 1 deletion include/tvm/relay/interpreter.h
Original file line number Diff line number Diff line change
Expand Up @@ -184,10 +184,12 @@ TypedPackedFunc<ObjectRef(Array<Expr>)> EvalFunction(IRModule mod, Expr expr, De
* \param import_set Already imported external modules.
* \param device The device on which all primitives will be executed.
* \param target The compiler target flag for compiling primitives.
* \param attrs Attributes for the expression to be evaluated with
* @return The object representing the result.
*/
ObjectRef Eval(Expr expr, Map<GlobalTypeVar, TypeData> type_definitions,
std::unordered_set<String> import_set, Device device, Target target);
std::unordered_set<String> import_set, Device device, Target target,
Map<String, ObjectRef> attrs = {});

} // namespace relay
} // namespace tvm
Expand Down
2 changes: 2 additions & 0 deletions include/tvm/relay/runtime.h
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,8 @@ class RuntimeNode : public Object {
}

static constexpr const char* _type_key = "Runtime";
static constexpr const bool _type_has_method_sequal_reduce = true;
static constexpr const bool _type_has_method_shash_reduce = true;
TVM_DECLARE_FINAL_OBJECT_INFO(RuntimeNode, Object);
};

Expand Down
38 changes: 1 addition & 37 deletions include/tvm/tir/function.h
Original file line number Diff line number Diff line change
Expand Up @@ -151,42 +151,6 @@ class PrimFunc : public BaseFunc {
TVM_DEFINE_OBJECT_REF_COW_METHOD(PrimFuncNode);
};

/*!
* \brief Describes one parameter that should be linked into the generated module.
*
* When parameters are to be linked in with generated code (i.e. on target_host-compatible
* backends), Relay attaches instances of this object to a global TIR function. Code-generators
* use the information contained in this node to include the parameter data in the generated
* module.
*/
class LinkedParamNode : public Object {
public:
/*! \brief Unique numeric identifier used by runtimes to lookup this parameter. */
int64_t id;

/*! \brief Parameter data which should get linked into the final module. */
::tvm::runtime::NDArray param;

void VisitAttrs(tvm::AttrVisitor* v) {
v->Visit("id", &id);
v->Visit("param", &param);
}

static constexpr const char* _type_key = "tir.LinkedParam";
TVM_DECLARE_FINAL_OBJECT_INFO(LinkedParamNode, Object);
};

/*!
* \brief Managed reference to LinkedParamNode.
*/
class LinkedParam : public ObjectRef {
public:
TVM_DLL LinkedParam(int64_t id, ::tvm::runtime::NDArray param);

TVM_DEFINE_OBJECT_REF_METHODS(LinkedParam, ObjectRef, LinkedParamNode);
TVM_DEFINE_OBJECT_REF_COW_METHOD(LinkedParamNode);
};

/*!
* \brief Tensor intrinsics for tensorization
*/
Expand Down Expand Up @@ -239,7 +203,7 @@ class TensorIntrin : public ObjectRef {
TVM_DEFINE_OBJECT_REF_METHODS(TensorIntrin, ObjectRef, TensorIntrinNode)
};

/*!
/*
* \brief Specialize parameters of PrimFunc.
* \param func The PrimFunc to be specialized.
* \param param_map The mapping from function params to the instance.
Expand Down
96 changes: 94 additions & 2 deletions include/tvm/tir/stmt.h
Original file line number Diff line number Diff line change
Expand Up @@ -559,16 +559,18 @@ class AllocateNode : public StmtNode {
* Otherwise return 0.
* \return The result.
*/
int32_t constant_allocation_size() const { return constant_allocation_size(extents); }
int64_t ConstantAllocationSize() const { return ConstantAllocationSize(extents); }
/*!
* \brief If the buffer size is constant, return the size.
* Otherwise return 0.
* \param extents The extents of the buffer.
* \return The result.
*/
TVM_DLL static int32_t constant_allocation_size(const Array<PrimExpr>& extents);
TVM_DLL static int64_t ConstantAllocationSize(const Array<PrimExpr>& extents);

static constexpr const char* _type_key = "tir.Allocate";
static constexpr const bool _type_has_method_sequal_reduce = true;
static constexpr const bool _type_has_method_shash_reduce = true;
TVM_DECLARE_FINAL_OBJECT_INFO(AllocateNode, StmtNode);
};

Expand All @@ -585,6 +587,96 @@ class Allocate : public Stmt {
TVM_DEFINE_OBJECT_REF_METHODS(Allocate, Stmt, AllocateNode);
};

/*!
* \brief Allocate a buffer that can be used in body.
*/
class AllocateConstNode : public StmtNode {
public:
/*! \brief The buffer variable. */
Var buffer_var;
/*! \brief The optional data associated to the constant.
*/
Optional<runtime::NDArray> data;
/*! \brief If the PrimFunc containing the Stmt is added to IRModule,
this is an optional index to indicate the index within
"Constants" attribute, that is a Array<NDArray> of IRModule.
*/
Optional<Integer> irmod_storage_idx;
/*! \brief The type of the buffer. */
DataType dtype;
/*! \brief The extents of the buffer. */
Array<PrimExpr> extents;
/*! \brief The body to be executed. */
Stmt body;
/*!
* \brief Additional annotations about the allocation.
*
* These annotations can be used as auxiliary hint
* to future transformations.
*/
Map<String, ObjectRef> annotations;

void VisitAttrs(AttrVisitor* v) {
v->Visit("buffer_var", &buffer_var);
v->Visit("data", &data);
v->Visit("irmod_storage_idx", &irmod_storage_idx);
v->Visit("dtype", &dtype);
v->Visit("extents", &extents);
v->Visit("body", &body);
v->Visit("annotations", &annotations);
v->Visit("span", &span);
}

bool SEqualReduce(const AllocateConstNode* other, SEqualReducer equal) const {
return equal.DefEqual(buffer_var, other->buffer_var) && equal(dtype, other->dtype) &&
equal(extents, other->extents) && equal(data, other->data) && equal(body, other->body) &&
equal(annotations, other->annotations);
}

void SHashReduce(SHashReducer hash_reduce) const {
hash_reduce.DefHash(buffer_var);
hash_reduce(dtype);
hash_reduce(extents);
hash_reduce(body);
hash_reduce(annotations);
hash_reduce(data);
}

/*!
* \brief If the buffer size is constant, return the size.
* Otherwise return 0.
* \return The result.
*/
int64_t ConstantAllocationSize() const { return ConstantAllocationSize(extents); }
/*!
* \brief If the buffer size is constant, return the size.
* Otherwise return 0.
* \param extents The extents of the buffer.
* \return The result.
*/
TVM_DLL static int64_t ConstantAllocationSize(const Array<PrimExpr>& extents);

static constexpr const char* _type_key = "tir.AllocateConst";
static constexpr const bool _type_has_method_sequal_reduce = true;
static constexpr const bool _type_has_method_shash_reduce = true;
TVM_DECLARE_FINAL_OBJECT_INFO(AllocateConstNode, StmtNode);
};

/*!
* \brief Managed reference to AllocateConstNode.
* \sa AllocateConstNode
*/
class AllocateConst : public Stmt {
public:
/* The constructor to create a IRNode with constant data
* depending on the type of ObjectRef, it will either
* create AllocateConstNode with irmod_storage_idx or data
*/
TVM_DLL AllocateConst(Var buffer_var, DataType dtype, Array<PrimExpr> extents,
ObjectRef data_or_idx, Stmt body, Span span = Span());
TVM_DEFINE_OBJECT_REF_METHODS(AllocateConst, Stmt, AllocateConstNode);
};

/*!
* \brief The container of seq statement.
* Represent a sequence of statements.
Expand Down
4 changes: 4 additions & 0 deletions include/tvm/tir/stmt_functor.h
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,7 @@ class StmtFunctor<R(const Stmt& n, Args... args)> {
virtual R VisitStmt_(const ForNode* op, Args... args) STMT_FUNCTOR_DEFAULT;
virtual R VisitStmt_(const WhileNode* op, Args... args) STMT_FUNCTOR_DEFAULT;
virtual R VisitStmt_(const AllocateNode* op, Args... args) STMT_FUNCTOR_DEFAULT;
virtual R VisitStmt_(const AllocateConstNode* op, Args... args) STMT_FUNCTOR_DEFAULT;
virtual R VisitStmt_(const StoreNode* op, Args... args) STMT_FUNCTOR_DEFAULT;
virtual R VisitStmt_(const BufferStoreNode* op, Args... args) STMT_FUNCTOR_DEFAULT;
virtual R VisitStmt_(const BufferRealizeNode* op, Args... args) STMT_FUNCTOR_DEFAULT;
Expand All @@ -113,6 +114,7 @@ class StmtFunctor<R(const Stmt& n, Args... args)> {
IR_STMT_FUNCTOR_DISPATCH(ForNode);
IR_STMT_FUNCTOR_DISPATCH(WhileNode);
IR_STMT_FUNCTOR_DISPATCH(AllocateNode);
IR_STMT_FUNCTOR_DISPATCH(AllocateConstNode);
IR_STMT_FUNCTOR_DISPATCH(StoreNode);
IR_STMT_FUNCTOR_DISPATCH(AssertStmtNode);
IR_STMT_FUNCTOR_DISPATCH(ProducerStoreNode);
Expand Down Expand Up @@ -155,6 +157,7 @@ class TVM_DLL StmtVisitor : protected StmtFunctor<void(const Stmt&)> {
void VisitStmt_(const ForNode* op) override;
void VisitStmt_(const WhileNode* op) override;
void VisitStmt_(const AllocateNode* op) override;
void VisitStmt_(const AllocateConstNode* op) override;
void VisitStmt_(const StoreNode* op) override;
void VisitStmt_(const BufferStoreNode* op) override;
void VisitStmt_(const BufferRealizeNode* op) override;
Expand Down Expand Up @@ -255,6 +258,7 @@ class TVM_DLL StmtMutator : protected StmtFunctor<Stmt(const Stmt&)> {
Stmt VisitStmt_(const ForNode* op) override;
Stmt VisitStmt_(const WhileNode* op) override;
Stmt VisitStmt_(const AllocateNode* op) override;
Stmt VisitStmt_(const AllocateConstNode* op) override;
Stmt VisitStmt_(const StoreNode* op) override;
Stmt VisitStmt_(const BufferStoreNode* op) override;
Stmt VisitStmt_(const BufferRealizeNode* op) override;
Expand Down
10 changes: 10 additions & 0 deletions include/tvm/tir/transform.h
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
#include <tvm/tir/function.h>

#include <string>
#include <vector>

namespace tvm {
namespace tir {
Expand Down Expand Up @@ -601,6 +602,15 @@ TVM_DLL Pass UnifiedStaticMemoryPlanner();
*/
TVM_DLL Pass InjectSoftwarePipeline();

TVM_DLL Pass BindParams(const Array<runtime::NDArray>& constants);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

docs? Also shall we expose this pass to python like we did for all other passes?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@junrushao1994 @d-smirnov -- should we address this nit in a follow up PR if this CI runs good ?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Either way works for me given it’s just a nit


/*!
* \brief Pass to collect tir non-scalar constants into module's 'Constants' attribute.
*
* \return The pass.
*/
TVM_DLL Pass ExtractPrimFuncConstants();

} // namespace transform
} // namespace tir
} // namespace tvm
Expand Down
Loading