Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions include/tvm/ir/expr.h
Original file line number Diff line number Diff line change
Expand Up @@ -180,6 +180,8 @@ class RelayExprNode : public BaseExprNode {
* the call to the function or closure is stored (instead of where the function itself is stored).
* The VirtualDevice's Target field describes how the body of the function should be compiled.
*
* Set to VirtualDevice::FullyUnconstrained by default.
*
* \note Unfortunately, the type of virtual_device_ needs to be ObjectRef to avoid a circular
* import.
*/
Expand Down
1 change: 1 addition & 0 deletions include/tvm/relay/expr.h
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@ class ConstantNode : public ExprNode {

void VisitAttrs(tvm::AttrVisitor* v) {
v->Visit("data", &data);
v->Visit("virtual_device_", &virtual_device_);
v->Visit("span", &span);
v->Visit("_checked_type_", &checked_type_);
}
Expand Down
41 changes: 31 additions & 10 deletions src/relay/ir/expr.cc
Original file line number Diff line number Diff line change
Expand Up @@ -28,10 +28,13 @@
namespace tvm {

VirtualDevice RelayExprNode::virtual_device() const {
if (virtual_device_.defined()) {
return Downcast<VirtualDevice>(this->virtual_device_);
if (!this->virtual_device_.defined()) {
// virtual_device_ should always be defined, unless we imported this node from JSON using an old
// version of TVM, in which case we want to set it to the default, which is
// VirtualDevice::FullyUnconstrained().
return VirtualDevice::FullyUnconstrained();
}
return VirtualDevice::FullyUnconstrained();
return Downcast<VirtualDevice>(this->virtual_device_);
}

namespace relay {
Expand Down Expand Up @@ -76,6 +79,7 @@ TensorType ConstantNode::tensor_type() const {
Tuple::Tuple(tvm::Array<relay::Expr> fields, Span span) {
ObjectPtr<TupleNode> n = make_object<TupleNode>();
n->fields = std::move(fields);
n->virtual_device_ = VirtualDevice::FullyUnconstrained();
n->span = std::move(span);
data_ = std::move(n);
}
Expand All @@ -100,7 +104,8 @@ Tuple WithFields(Tuple tuple, Optional<Array<Expr>> opt_fields,
all_fields_unchanged = false;
}

all_fields_unchanged = all_fields_unchanged && span.same_as(tuple->span);
all_fields_unchanged = all_fields_unchanged && virtual_device.same_as(tuple->virtual_device()) &&
span.same_as(tuple->span);
if (!all_fields_unchanged) {
TupleNode* cow_tuple_node = tuple.CopyOnWrite();
cow_tuple_node->fields = fields;
Expand All @@ -120,6 +125,7 @@ Var::Var(Id vid, Type type_annotation, Span span) {
ObjectPtr<VarNode> n = make_object<VarNode>();
n->vid = std::move(vid);
n->type_annotation = std::move(type_annotation);
n->virtual_device_ = VirtualDevice::FullyUnconstrained();
n->span = std::move(span);
data_ = std::move(n);
}
Expand All @@ -139,7 +145,7 @@ Var WithFields(Var var, Optional<Id> opt_vid, Optional<Type> opt_type_annotation
Span span = opt_span.value_or(var->span);

bool unchanged = vid.same_as(var->vid) && type_annotation.same_as(var->type_annotation) &&
span.same_as(var->span);
virtual_device.same_as(var->virtual_device()) && span.same_as(var->span);

if (!unchanged) {
VarNode* cow_var_node = var.CopyOnWrite();
Expand Down Expand Up @@ -174,6 +180,7 @@ Call::Call(Expr op, Array<Expr> args, Attrs attrs, Array<Type> type_args, Span s
n->args = std::move(args);
n->attrs = std::move(attrs);
n->type_args = std::move(type_args);
n->virtual_device_ = VirtualDevice::FullyUnconstrained();
n->span = std::move(span);
data_ = std::move(n);
}
Expand All @@ -188,7 +195,8 @@ Call WithFields(Call call, Optional<Expr> opt_op, Optional<Array<Expr>> opt_args
VirtualDevice virtual_device = opt_virtual_device.value_or(call->virtual_device());
Span span = opt_span.value_or(call->span);

bool unchanged = op.same_as(call->op) && attrs.same_as(call->attrs) && span.same_as(call->span);
bool unchanged = op.same_as(call->op) && attrs.same_as(call->attrs) &&
virtual_device.same_as(call->virtual_device()) && span.same_as(call->span);

// Check that the args are unchanged
if (unchanged) {
Expand Down Expand Up @@ -248,6 +256,7 @@ Let::Let(Var var, Expr value, Expr body, Span span) {
n->var = std::move(var);
n->value = std::move(value);
n->body = std::move(body);
n->virtual_device_ = VirtualDevice::FullyUnconstrained();
n->span = std::move(span);
data_ = std::move(n);
}
Expand All @@ -261,7 +270,7 @@ Let WithFields(Let let, Optional<Var> opt_var, Optional<Expr> opt_value, Optiona
Span span = opt_span.value_or(let->span);

bool unchanged = var.same_as(let->var) && value.same_as(let->value) && body.same_as(let->body) &&
span.same_as(let->span);
virtual_device.same_as(let->virtual_device()) && span.same_as(let->span);

if (!unchanged) {
LetNode* cow_let_node = let.CopyOnWrite();
Expand Down Expand Up @@ -291,6 +300,7 @@ If::If(Expr cond, Expr true_branch, Expr false_branch, Span span) {
n->cond = std::move(cond);
n->true_branch = std::move(true_branch);
n->false_branch = std::move(false_branch);
n->virtual_device_ = VirtualDevice::FullyUnconstrained();
n->span = std::move(span);
data_ = std::move(n);
}
Expand All @@ -305,7 +315,8 @@ If WithFields(If if_expr, Optional<Expr> opt_cond, Optional<Expr> opt_true_branc
Span span = opt_span.value_or(if_expr->span);

bool unchanged = cond.same_as(if_expr->cond) && true_branch.same_as(if_expr->true_branch) &&
false_branch.same_as(if_expr->false_branch) && span.same_as(if_expr->span);
false_branch.same_as(if_expr->false_branch) &&
virtual_device.same_as(if_expr->virtual_device()) && span.same_as(if_expr->span);

if (!unchanged) {
IfNode* cow_if_node = if_expr.CopyOnWrite();
Expand Down Expand Up @@ -336,6 +347,7 @@ TupleGetItem::TupleGetItem(Expr tuple, int index, Span span) {
ObjectPtr<TupleGetItemNode> n = make_object<TupleGetItemNode>();
n->tuple = std::move(tuple);
n->index = index;
n->virtual_device_ = VirtualDevice::FullyUnconstrained();
n->span = std::move(span);
data_ = std::move(n);
}
Expand All @@ -349,6 +361,7 @@ TupleGetItem WithFields(TupleGetItem tuple_get_item, Optional<Expr> opt_tuple,
Span span = opt_span.value_or(tuple_get_item->span);

bool unchanged = tuple.same_as(tuple_get_item->tuple) && (index == tuple_get_item->index) &&
virtual_device.same_as(tuple_get_item->virtual_device()) &&
span.same_as(tuple_get_item->span);
if (!unchanged) {
TupleGetItemNode* cow_tuple_get_item_node = tuple_get_item.CopyOnWrite();
Expand All @@ -375,6 +388,7 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
RefCreate::RefCreate(Expr value, Span span) {
ObjectPtr<RefCreateNode> n = make_object<RefCreateNode>();
n->value = std::move(value);
n->virtual_device_ = VirtualDevice::FullyUnconstrained();
n->span = std::move(span);
data_ = std::move(n);
}
Expand All @@ -385,7 +399,9 @@ RefCreate WithFields(RefCreate ref_create, Optional<Expr> opt_value,
VirtualDevice virtual_device = opt_virtual_device.value_or(ref_create->virtual_device());
Span span = opt_span.value_or(ref_create->span);

bool unchanged = value.same_as(ref_create->value) && span.same_as(ref_create->span);
bool unchanged = value.same_as(ref_create->value) &&
virtual_device.same_as(ref_create->virtual_device()) &&
span.same_as(ref_create->span);
if (!unchanged) {
RefCreateNode* cow_ref_create_node = ref_create.CopyOnWrite();
cow_ref_create_node->value = value;
Expand All @@ -410,6 +426,7 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
RefRead::RefRead(Expr ref, Span span) {
ObjectPtr<RefReadNode> n = make_object<RefReadNode>();
n->ref = std::move(ref);
n->virtual_device_ = VirtualDevice::FullyUnconstrained();
n->span = std::move(span);
data_ = std::move(n);
}
Expand All @@ -420,7 +437,9 @@ RefRead WithFields(RefRead ref_read, Optional<Expr> opt_ref,
VirtualDevice virtual_device = opt_virtual_device.value_or(ref_read->virtual_device());
Span span = opt_span.value_or(ref_read->span);

bool unchanged = ref.same_as(ref_read->ref) && span.same_as(ref_read->span);
bool unchanged = ref.same_as(ref_read->ref) &&
virtual_device.same_as(ref_read->virtual_device()) &&
span.same_as(ref_read->span);
if (!unchanged) {
RefReadNode* cow_ref_read_node = ref_read.CopyOnWrite();
cow_ref_read_node->ref = ref;
Expand All @@ -444,6 +463,7 @@ RefWrite::RefWrite(Expr ref, Expr value, Span span) {
ObjectPtr<RefWriteNode> n = make_object<RefWriteNode>();
n->ref = std::move(ref);
n->value = std::move(value);
n->virtual_device_ = VirtualDevice::FullyUnconstrained();
n->span = std::move(span);
data_ = std::move(n);
}
Expand All @@ -456,6 +476,7 @@ RefWrite WithFields(RefWrite ref_write, Optional<Expr> opt_ref, Optional<Expr> o
Span span = opt_span.value_or(ref_write->span);

bool unchanged = ref.same_as(ref_write->ref) && value.same_as(ref_write->value) &&
virtual_device.same_as(ref_write->virtual_device()) &&
span.same_as(ref_write->span);
if (!unchanged) {
RefWriteNode* cow_ref_write_node = ref_write.CopyOnWrite();
Expand Down
5 changes: 4 additions & 1 deletion src/relay/ir/function.cc
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ Function::Function(tvm::Array<Var> params, Expr body, Type ret_type,
n->ret_type = std::move(ret_type);
n->type_params = std::move(type_params);
n->attrs = std::move(attrs);
n->virtual_device_ = VirtualDevice::FullyUnconstrained();
n->span = std::move(span);
data_ = std::move(n);
}
Expand All @@ -53,7 +54,9 @@ Function WithFields(Function function, Optional<Array<Var>> opt_params, Optional
Span span = opt_span.value_or(function->span);

bool unchanged = body.same_as(function->body) && ret_type.same_as(function->ret_type) &&
attrs.same_as(function->attrs) && span.same_as(function->span);
attrs.same_as(function->attrs) &&
virtual_device.same_as(function->virtual_device()) &&
span.same_as(function->span);

// Check that all the type params are unchanged
if (unchanged) {
Expand Down
1 change: 1 addition & 0 deletions src/relay/transforms/de_duplicate.cc
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ Expr DeDup(const Expr& e) {
Expr DispatchVisitExpr(const Expr& e) final {
auto ret = ExprMutator::VisitExpr(e);
ret->checked_type_ = e->checked_type_;
ret->virtual_device_ = e->virtual_device_;
return ret;
}

Expand Down
12 changes: 10 additions & 2 deletions tests/python/relay/test_json_compact.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,12 @@ def test_var():
{"type_key": ""},
{
"type_key": "relay.Var",
"attrs": {"_checked_type_": "0", "span": "0", "type_annotation": "0", "vid": "2"},
"attrs": {
"_checked_type_": "0",
"span": "0",
"type_annotation": "0",
"vid": "2",
},
},
{"type_key": "relay.Id", "attrs": {"name_hint": "a3"}},
{"type_key": "relay.TensorType", "attrs": {"dtype": "float32", "shape": "4", "span": "0"}},
Expand Down Expand Up @@ -133,7 +138,10 @@ def test_global_var():
assert isinstance(tvar, tvm.ir.GlobalVar)
nodes = [
{"type_key": ""},
{"type_key": "GlobalVar", "attrs": {"_checked_type_": "0", "name_hint": "x", "span": "0"}},
{
"type_key": "GlobalVar",
"attrs": {"_checked_type_": "0", "name_hint": "x", "span": "0"},
},
]
data = {
"root": 1,
Expand Down