Skip to content

Commit 5ff99b6

Browse files
tqchenWei Chen
authored andcommitted
[NODE] Macro to define NodeRef methods, constructor style example (apache#3224)
1 parent e6b68b4 commit 5ff99b6

File tree

5 files changed

+86
-55
lines changed

5 files changed

+86
-55
lines changed

include/tvm/arithmetic.h

Lines changed: 28 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -48,11 +48,7 @@ namespace arith {
4848

4949
// Forward declare Analyzer
5050
class Analyzer;
51-
/*!
52-
* \brief reference class to ConstIntBoundNode
53-
* \sa ConstIntBoundNode
54-
*/
55-
class ConstIntBound;
51+
5652
/*!
5753
* \brief Constant integer up and lower bound(inclusive).
5854
* Useful for value bound analysis.
@@ -69,8 +65,6 @@ class ConstIntBoundNode : public Node {
6965
v->Visit("max_value", &max_value);
7066
}
7167

72-
TVM_DLL static ConstIntBound make(int64_t min_value, int64_t max_value);
73-
7468
/*! \brief Number to represent +inf */
7569
static const constexpr int64_t kPosInf = std::numeric_limits<int64_t>::max();
7670
/*!
@@ -83,7 +77,23 @@ class ConstIntBoundNode : public Node {
8377
TVM_DECLARE_NODE_TYPE_INFO(ConstIntBoundNode, Node);
8478
};
8579

86-
TVM_DEFINE_NODE_REF(ConstIntBound, ConstIntBoundNode);
80+
/*!
81+
* \brief reference class to ConstIntBoundNode
82+
* \sa ConstIntBoundNode
83+
*/
84+
class ConstIntBound : public NodeRef {
85+
public:
86+
/*!
87+
* \brief constructor by fields.
88+
* \param min_value The mininum value.
89+
* \param max_value The maximum value.
90+
*/
91+
TVM_DLL ConstIntBound(int64_t min_value, int64_t max_value);
92+
93+
static const constexpr int64_t kPosInf = ConstIntBoundNode::kPosInf;
94+
static const constexpr int64_t kNegInf = ConstIntBoundNode::kNegInf;
95+
TVM_DEFINE_NODE_REF_METHODS(ConstIntBound, NodeRef, ConstIntBoundNode);
96+
};
8797

8898
/*!
8999
* \brief Analyzer to get constant integer bound over expression.
@@ -133,11 +143,6 @@ class ConstIntBoundAnalyzer {
133143
Impl* impl_;
134144
};
135145

136-
/*!
137-
* \brief reference of ModularSetNode
138-
* \sa ModularSetNode
139-
*/
140-
class ModularSet;
141146
/*!
142147
* \brief Range of a linear integer function.
143148
* Use to do specify the possible index values.
@@ -162,13 +167,20 @@ class ModularSetNode : public Node {
162167
v->Visit("base", &base);
163168
}
164169

165-
TVM_DLL static ModularSet make(int64_t coeff, int64_t base);
166-
167170
static constexpr const char* _type_key = "arith.ModularSet";
168171
TVM_DECLARE_NODE_TYPE_INFO(ModularSetNode, Node);
169172
};
170173

171-
TVM_DEFINE_NODE_REF(ModularSet, ModularSetNode);
174+
/*!
175+
* \brief reference of ModularSetNode
176+
* \sa ModularSetNode
177+
*/
178+
class ModularSet : public NodeRef {
179+
public:
180+
TVM_DLL ModularSet(int64_t coeff, int64_t base);
181+
182+
TVM_DEFINE_NODE_REF_METHODS(ModularSet, NodeRef, ModularSetNode);
183+
};
172184

173185
/*!
174186
* \brief Analyzer to get modular information over expression.

include/tvm/base.h

Lines changed: 36 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -39,21 +39,24 @@ using ::tvm::Node;
3939
using ::tvm::NodeRef;
4040
using ::tvm::AttrVisitor;
4141

42-
/*! \brief Macro to make it easy to define node ref type given node */
43-
#define TVM_DEFINE_NODE_REF(TypeName, NodeName) \
44-
class TypeName : public ::tvm::NodeRef { \
45-
public: \
46-
TypeName() {} \
47-
explicit TypeName(::tvm::NodePtr<::tvm::Node> n) : NodeRef(n) {} \
48-
const NodeName* operator->() const { \
49-
return static_cast<const NodeName*>(node_.get()); \
50-
} \
51-
using ContainerType = NodeName; \
52-
}; \
42+
/*!
43+
* \brief Macro to define common node ref methods.
44+
* \param TypeName The name of the NodeRef.
45+
* \param BaseTypeName The Base type.
46+
* \param NodeName The node container type.
47+
*/
48+
#define TVM_DEFINE_NODE_REF_METHODS(TypeName, BaseTypeName, NodeName) \
49+
TypeName() {} \
50+
explicit TypeName(::tvm::NodePtr<::tvm::Node> n) : BaseTypeName(n) {} \
51+
const NodeName* operator->() const { \
52+
return static_cast<const NodeName*>(node_.get()); \
53+
} \
54+
operator bool() const { return this->defined(); } \
55+
using ContainerType = NodeName;
5356

5457
/*!
55-
* \brief Macro to make it easy to define node ref type that
56-
* has a CopyOnWrite member function.
58+
* \brief Macro to define CopyOnWrite function in a NodeRef.
59+
* \param NodeName The Type of the Node.
5760
*
5861
* CopyOnWrite will generate a unique copy of the internal node.
5962
* The node will be copied if it is referenced by multiple places.
@@ -70,25 +73,33 @@ using ::tvm::AttrVisitor;
7073
*
7174
* \endcode
7275
*/
73-
#define TVM_DEFINE_COW_NODE_REF(TypeName, BaseType, NodeName) \
74-
class TypeName : public BaseType { \
75-
public: \
76-
TypeName() {} \
77-
explicit TypeName(::tvm::NodePtr<::tvm::Node> n) : BaseType(n) {} \
78-
const NodeName* operator->() const { \
79-
return static_cast<const NodeName*>(node_.get()); \
80-
} \
81-
inline NodeName* CopyOnWrite() { \
76+
#define TVM_DEFINE_NODE_REF_COW(NodeName) \
77+
NodeName* CopyOnWrite() { \
8278
CHECK(node_ != nullptr); \
8379
if (!node_.unique()) { \
8480
NodePtr<NodeName> n = make_node<NodeName>(*(operator->())); \
8581
NodePtr<Node>(std::move(n)).swap(node_); \
8682
} \
8783
return static_cast<NodeName*>(node_.get()); \
88-
} \
89-
using ContainerType = NodeName; \
90-
};
84+
}
9185

86+
/*! \brief Macro to make it easy to define node ref type given node */
87+
#define TVM_DEFINE_NODE_REF(TypeName, NodeName) \
88+
class TypeName : public ::tvm::NodeRef { \
89+
public: \
90+
TVM_DEFINE_NODE_REF_METHODS(TypeName, ::tvm::NodeRef, NodeName); \
91+
}; \
92+
93+
/*!
94+
* \brief Macro to make it easy to define node ref type that
95+
* has a CopyOnWrite member function.
96+
*/
97+
#define TVM_DEFINE_COW_NODE_REF(TypeName, BaseType, NodeName) \
98+
class TypeName : public BaseType { \
99+
public: \
100+
TVM_DEFINE_NODE_REF_METHODS(TypeName, BaseType, NodeName); \
101+
TVM_DEFINE_NODE_REF_COW(NodeName); \
102+
};
92103

93104
/*!
94105
* \brief save the node as well as all the node it depends on as json.

src/api/api_arith.cc

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,9 @@
66
* to you under the Apache License, Version 2.0 (the
77
* "License"); you may not use this file except in compliance
88
* with the License. You may obtain a copy of the License at
9-
*
9+
*
1010
* http://www.apache.org/licenses/LICENSE-2.0
11-
*
11+
*
1212
* Unless required by applicable law or agreed to in writing,
1313
* software distributed under the License is distributed on an
1414
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
@@ -58,7 +58,6 @@ TVM_REGISTER_API("arith.DeduceBound")
5858
TVM_REGISTER_API("arith.DomainTouched")
5959
.set_body_typed(DomainTouched);
6060

61-
6261
TVM_REGISTER_API("_IntervalSetGetMin")
6362
.set_body_method(&IntSet::min);
6463

@@ -71,11 +70,19 @@ TVM_REGISTER_API("_IntSetIsNothing")
7170
TVM_REGISTER_API("_IntSetIsEverything")
7271
.set_body_method(&IntSet::is_everything);
7372

73+
ConstIntBound MakeConstIntBound(int64_t min_value, int64_t max_value) {
74+
return ConstIntBound(min_value, max_value);
75+
}
76+
7477
TVM_REGISTER_API("arith._make_ConstIntBound")
75-
.set_body_typed(ConstIntBoundNode::make);
78+
.set_body_typed(MakeConstIntBound);
79+
80+
ModularSet MakeModularSet(int64_t coeff, int64_t base) {
81+
return ModularSet(coeff, base);
82+
}
7683

7784
TVM_REGISTER_API("arith._make_ModularSet")
78-
.set_body_typed(ModularSetNode::make);
85+
.set_body_typed(MakeModularSet);
7986

8087
TVM_REGISTER_API("arith._CreateAnalyzer")
8188
.set_body([](TVMArgs args, TVMRetValue* ret) {

src/arithmetic/const_int_bound.cc

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -34,12 +34,12 @@ using namespace ir;
3434

3535
TVM_REGISTER_NODE_TYPE(ConstIntBoundNode);
3636

37-
ConstIntBound ConstIntBoundNode::make(
37+
ConstIntBound::ConstIntBound(
3838
int64_t min_value, int64_t max_value) {
3939
auto node = make_node<ConstIntBoundNode>();
4040
node->min_value = min_value;
4141
node->max_value = max_value;
42-
return ConstIntBound(node);
42+
node_ = std::move(node);
4343
}
4444

4545
TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
@@ -289,8 +289,8 @@ class ConstIntBoundAnalyzer::Impl :
289289
std::vector<BoundInfo> additional_info_;
290290
// constants: the limit value means umlimited
291291
// NOTE: kNegInf/kPosInf are used to represent infinity.
292-
static const constexpr int64_t kNegInf = ConstIntBoundNode::kNegInf;
293-
static const constexpr int64_t kPosInf = ConstIntBoundNode::kPosInf;
292+
static const constexpr int64_t kNegInf = ConstIntBound::kNegInf;
293+
static const constexpr int64_t kPosInf = ConstIntBound::kPosInf;
294294
static_assert(-kNegInf == kPosInf, "invariant of inf");
295295
// internal helper functions
296296
/*!
@@ -462,7 +462,7 @@ class ConstIntBoundAnalyzer::Impl :
462462

463463
ConstIntBound ConstIntBoundAnalyzer::operator()(const Expr& expr) {
464464
Entry ret = impl_->VisitExpr(expr);
465-
return ConstIntBoundNode::make(ret.min_value, ret.max_value);
465+
return ConstIntBound(ret.min_value, ret.max_value);
466466
}
467467

468468
void ConstIntBoundAnalyzer::Update(const Var& var,

src/arithmetic/modular_set.cc

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -35,11 +35,12 @@ using namespace ir;
3535

3636
TVM_REGISTER_NODE_TYPE(ModularSetNode);
3737

38-
ModularSet ModularSetNode::make(int64_t coeff, int64_t base) {
38+
ModularSet::ModularSet(int64_t coeff, int64_t base) {
3939
auto node = make_node<ModularSetNode>();
4040
node->coeff = coeff;
4141
node->base = base;
42-
return ModularSet(node);
42+
// finish construction.
43+
node_ = std::move(node);
4344
}
4445

4546
TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
@@ -366,13 +367,13 @@ class ModularSetAnalyzer::Impl :
366367
* \return Bound that represent everything dtype can represent.
367368
*/
368369
static Entry Nothing() {
369-
return Entry(0, 1);
370+
return Entry(0, 1);
370371
}
371372
};
372373

373374
ModularSet ModularSetAnalyzer::operator()(const Expr& expr) {
374375
Entry ret = impl_->VisitExpr(expr);
375-
return ModularSetNode::make(ret.coeff, ret.base);
376+
return ModularSet(ret.coeff, ret.base);
376377
}
377378

378379
void ModularSetAnalyzer::Update(const Var& var,

0 commit comments

Comments
 (0)