Skip to content

Commit c4eb59f

Browse files
committed
Manifest memory allocations
1 parent 1853ea2 commit c4eb59f

File tree

26 files changed

+392
-310
lines changed

26 files changed

+392
-310
lines changed

cmake/util/FindANTLR.cmake

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -61,5 +61,5 @@ macro(find_antlr use_antlr)
6161
elseif(NOT ${use_antlr} STREQUAL "OFF")
6262
set(ANTLR4 ${JAVA_PROGRAM} -jar ${use_antlr})
6363
endif()
64-
message(STATUS "ANTLR4="${ANTLR4})
64+
message(STATUS "ANTLR4=${ANTLR4}")
6565
endmacro(find_antlr)

include/tvm/expr_operator.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,10 +28,10 @@
2828
#ifndef TVM_EXPR_OPERATOR_H_
2929
#define TVM_EXPR_OPERATOR_H_
3030

31+
#include <tvm/expr.h>
32+
#include <tvm/ir.h>
3133
#include <algorithm>
3234
#include <type_traits>
33-
#include "expr.h"
34-
#include "ir.h"
3535

3636
namespace tvm {
3737

include/tvm/relay/attrs/annotation.h

Lines changed: 0 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,6 @@
2525
#define TVM_RELAY_ATTRS_ANNOTATION_H_
2626

2727
#include <tvm/attrs.h>
28-
#include <tvm/relay/expr.h>
2928
#include <string>
3029

3130
namespace tvm {
@@ -58,42 +57,6 @@ struct CastHintAttrs : public tvm::AttrsNode<CastHintAttrs> {
5857
}
5958
};
6059

61-
/*!
62-
* \brief Options for the device annotation operators.
63-
*/
64-
struct AllocTensorAttrs : public tvm::AttrsNode<AllocTensorAttrs> {
65-
tvm::relay::Constant const_shape;
66-
Array<IndexExpr> assert_shape;
67-
DataType dtype;
68-
69-
TVM_DECLARE_ATTRS(AllocTensorAttrs, "relay.attrs.AllocTensorAttrs") {
70-
TVM_ATTR_FIELD(dtype)
71-
.describe(
72-
"The virutal device/context type that an expression is annotated with.")
73-
.set_default(Float(32, 1));
74-
TVM_ATTR_FIELD(const_shape)
75-
.describe(
76-
"The virutal device/context type that an expression is annotated with.");
77-
TVM_ATTR_FIELD(assert_shape)
78-
.describe(
79-
"The virutal device/context type that an expression is annotated with.");
80-
}
81-
};
82-
83-
/*!
84-
* \brief Options for the device annotation operators.
85-
*/
86-
struct ShapeFuncAttrs : public tvm::AttrsNode<ShapeFuncAttrs> {
87-
bool dependent{false};
88-
89-
TVM_DECLARE_ATTRS(ShapeFuncAttrs, "relay.attrs.ShapeFuncAttrs") {
90-
TVM_ATTR_FIELD(dependent)
91-
.describe(
92-
"Wheather the shape function is input dependent.")
93-
.set_default(false);
94-
}
95-
};
96-
9760
} // namespace relay
9861
} // namespace tvm
9962
#endif // TVM_RELAY_ATTRS_ANNOTATION_H_

include/tvm/relay/attrs/memory.h

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ namespace tvm {
3232
namespace relay {
3333

3434
/*!
35-
* \brief Options for the device annotation operators.
35+
* \brief Options for allocating tensors.
3636
*/
3737
struct AllocTensorAttrs : public tvm::AttrsNode<AllocTensorAttrs> {
3838
tvm::relay::Constant const_shape;
@@ -46,24 +46,25 @@ struct AllocTensorAttrs : public tvm::AttrsNode<AllocTensorAttrs> {
4646
.set_default(Float(32, 1));
4747
TVM_ATTR_FIELD(const_shape)
4848
.describe(
49-
"The shape if constant used to aid in type inference.");
49+
"The shape of constant used to aid in type inference.");
5050
TVM_ATTR_FIELD(assert_shape)
5151
.describe(
52-
"The shape to cast the return type of the allocation to, used to specify the shape obtained via further analysis.");
52+
"The shape to cast the return type of the allocation to, "\
53+
"used to specify the shape obtained via further analysis.");
5354
}
5455
};
5556

5657
/*!
57-
* \brief Options for the device annotation operators.
58+
* \brief Options for the shape function operator.
5859
*/
5960
struct ShapeFuncAttrs : public tvm::AttrsNode<ShapeFuncAttrs> {
60-
bool dependent{false};
61+
Array<tvm::Integer> is_input;
6162

6263
TVM_DECLARE_ATTRS(ShapeFuncAttrs, "relay.attrs.ShapeFuncAttrs") {
63-
TVM_ATTR_FIELD(dependent)
64+
TVM_ATTR_FIELD(is_input)
6465
.describe(
65-
"Wheather the shape function is input dependent.")
66-
.set_default(false);
66+
"A bool indicating whether the shape function should"\
67+
"expect shape or input in each position.");
6768
}
6869
};
6970

include/tvm/relay/expr_functor.h

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -230,18 +230,29 @@ class ExprMutator
230230
std::unordered_map<Expr, Expr, NodeHash, NodeEqual> memo_;
231231
};
232232

233+
/*! \brief A helper class for matching and rewriting operators. */
233234
template<typename R>
234235
class OpMatch {
235236
public:
236237
using MatchFunc =
237238
std::function<R(const Array<Expr>& args, const Attrs& attrs, const Array<Type>& type_args)>;
238239

240+
/*! \brief Match an operator with the given name.
241+
* \param op_name The name of the operator to match.
242+
* \param func The function to execute when it matches.
243+
* \return A self-reference for builder style API.
244+
*/
239245
inline OpMatch& Match(const std::string& op_name, MatchFunc func) {
240246
auto op = Op::Get(op_name);
241247
match_map_.insert({op, func});
242248
return *this;
243249
}
244250

251+
/*! \brief Rewrite a call operation based on the operator and the registered
252+
* match functions.
253+
* \param call The call to rewrite.
254+
* \return The result of rewriting.
255+
*/
245256
inline R operator()(const Call& call) {
246257
auto it = match_map_.find(Downcast<Op>(call->op));
247258
if (it != match_map_.end()) {
@@ -256,7 +267,9 @@ class OpMatch {
256267
}
257268

258269
private:
270+
/*! \brief The match function map. */
259271
std::unordered_map<Op, MatchFunc, NodeHash, NodeEqual> match_map_;
272+
/*! \brief An optional default case. */
260273
MatchFunc default_;
261274
};
262275

include/tvm/runtime/object.h

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -283,6 +283,8 @@ class Object {
283283
* \note The deleter will be called when ref_counter_ becomes zero.
284284
*/
285285
inline void DecRef();
286+
287+
private:
286288
/*!
287289
* \return The usage count of the cell.
288290
* \note We use stl style naming to be consistent with known API in shared_ptr.
@@ -675,6 +677,16 @@ struct ObjectEqual {
675677
operator bool() const { return data_ != nullptr; } \
676678
using ContainerType = ObjectName;
677679

680+
#define TVM_DEFINE_OBJECT_REF_METHODS_MUT(TypeName, ParentType, ObjectName) \
681+
TypeName() {} \
682+
explicit TypeName( \
683+
::tvm::runtime::ObjectPtr<::tvm::runtime::Object> n) \
684+
: ParentType(n) {} \
685+
ObjectName* operator->() { \
686+
return static_cast<ObjectName*>(data_.get()); \
687+
} \
688+
operator bool() const { return data_ != nullptr; } \
689+
using ContainerType = ObjectName;
678690

679691
// Implementations details below
680692
// Object reference counting.

include/tvm/runtime/vm.h

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -159,6 +159,7 @@ struct Instruction {
159159

160160
union {
161161
struct /* AllocTensor Operands */ {
162+
/*! \brief The storage to allocate from. */
162163
RegName storage;
163164
/*! \brief The number of dimensions. */
164165
uint32_t ndim;
@@ -168,6 +169,7 @@ struct Instruction {
168169
DLDataType dtype;
169170
} alloc_tensor;
170171
struct /* AllocTensorReg Operands */ {
172+
/*! \brief The storage to allocate from. */
171173
RegName storage;
172174
/*! \brief The register to read the shape out of. */
173175
RegName shape_register;
@@ -257,8 +259,11 @@ struct Instruction {
257259
RegName* free_vars;
258260
};
259261
struct /* AllocStorage Operands */ {
262+
/*! \brief The size of the allocation. */
260263
RegName allocation_size;
264+
/*! \brief The alignment of the allocation. */
261265
RegName alignment;
266+
/*! \brief The hint of the dtype. */
262267
DLDataType dtype_hint;
263268
} alloc_storage;
264269
};
@@ -282,30 +287,32 @@ struct Instruction {
282287
static Instruction InvokePacked(Index packed_index, Index arity, Index output_size,
283288
const std::vector<RegName>& args);
284289
/*! \brief Construct an allocate tensor instruction with constant shape.
290+
* \param storage The storage to allocate out of.
285291
* \param shape The shape of the tensor.
286292
* \param dtype The dtype of the tensor.
287293
* \param dst The destination register.
288294
* \return The allocate tensor instruction.
289295
*/
290-
static Instruction AllocTensor(RegName storage, const std::vector<int64_t>& shape, DLDataType dtype, RegName dst);
296+
static Instruction AllocTensor(RegName storage,
297+
const std::vector<int64_t>& shape, DLDataType dtype, RegName dst);
291298
/*! \brief Construct an allocate tensor instruction with register.
292-
* \param The storage to allocate out of.
299+
* \param storage The storage to allocate out of.
293300
* \param shape_register The register containing the shape.
294301
* \param dtype The dtype of the tensor.
295302
* \param dst The destination register.
296303
* \return The allocate tensor instruction.
297304
*/
298-
static Instruction AllocTensorReg(RegName storage, RegName shape_register, DLDataType dtype, RegName dst);
305+
static Instruction AllocTensorReg(RegName storage,
306+
RegName shape_register, DLDataType dtype, RegName dst);
299307
/*! \brief Construct an allocate datatype instruction.
300-
* \param The storage to allocate out of.
301308
* \param tag The datatype tag.
302309
* \param num_fields The number of fields for the datatype.
303310
* \param fields The registers containing the fields.
304311
* \param dst The register name of the destination.
305312
* \return The allocate instruction tensor.
306313
*/
307314
static Instruction AllocADT(Index tag, Index num_fields, const std::vector<RegName>& fields,
308-
RegName dst);
315+
RegName dst);
309316
/*! \brief Construct an allocate closure instruction.
310317
* \param func_index The index of the function table.
311318
* \param num_freevar The number of free variables.
@@ -381,7 +388,8 @@ struct Instruction {
381388
* \param dst The destination to place the storage.
382389
* \return The alloc storage instruction.
383390
*/
384-
static Instruction AllocStorage(RegName size, RegName alignment, DLDataType dtype_hint, RegName dst);
391+
static Instruction AllocStorage(RegName size, RegName alignment,
392+
DLDataType dtype_hint, RegName dst);
385393

386394
Instruction();
387395
Instruction(const Instruction& instr);

python/tvm/relay/debug.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,8 @@ def _debugger_init(expr, stack):
2727

2828
@register_func("relay.debug")
2929
def _debug(*args):
30-
import pdb; pdb.set_trace()
30+
import pdb
31+
pdb.set_trace()
3132

3233
# pylint: disable=unused-argument
3334
@register_func("relay.debug_interp")

python/tvm/relay/expr.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -315,6 +315,8 @@ def set_params(self, params):
315315
if isinstance(value, NDArray):
316316
params[key] = Constant(value)
317317

318+
return _expr.FunctionSetParams(self, params)
319+
318320
def set_attribute(self, name, ref):
319321
return _expr.FunctionSetAttr(self, name, ref)
320322

0 commit comments

Comments
 (0)