Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 9 additions & 2 deletions src/relay/pass/device_annotation.cc
Original file line number Diff line number Diff line change
Expand Up @@ -176,7 +176,11 @@ class RewriteAnnotation : public ExprMutator {
}

Expr VisitExpr_(const CallNode* call_node) final {
if (IsOnDeviceNode(call_node) || IsDeviceCopyNode(call_node)) {
if (IsOnDeviceNode(call_node)) {
Copy link
Member

@jroesch jroesch May 31, 2019

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is the only meaningful change right? this will just search deeper in the tree to rewrite until we hit another device annotation? just want to make sure I 100% understand this time.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yes, for example, on_device connects a and b expression, a -> on_device -> b, we have inserted device_copy when we visit b. Anytime when we visit on_device, if means we either have inserted device_copy op, or it is not necessary, therefore, we can safely delete on_device by returning 'a.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Okay cool, LGTM

return this->VisitExpr(call_node->args[0]);
}

if (IsDeviceCopyNode(call_node)) {
return ExprMutator::VisitExpr_(call_node);
}

Expand Down Expand Up @@ -358,6 +362,9 @@ class DeviceInfo {
public:
void Visit(const Expr& expr) {
if (const auto* fn = expr.as<FunctionNode>()) {
for (const auto& param : fn->params) {
this->VisitExpr(param);
}
this->VisitExpr(fn->body);
} else {
this->VisitExpr(expr);
Expand Down Expand Up @@ -402,7 +409,7 @@ class DeviceInfo {
}

void VisitExpr_(const VarNode* vn) final {
post_dfs_order_.push_back(std::make_pair(vn, has_copy_));
post_dfs_order_.push_back(std::make_pair(vn, has_copy_));
}

void VisitExpr_(const LetNode* ln) final {
Expand Down
126 changes: 85 additions & 41 deletions tests/python/relay/test_pass_annotation.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import tvm
from tvm import relay
from tvm.contrib import graph_runtime
from tvm.relay.expr_functor import ExprMutator


def test_redundant_annotation():
Expand All @@ -34,21 +35,22 @@ def annotated():
add = relay.add(x, y)
_add1 = relay.annotation.on_device(add, ctx2)
_add2 = relay.annotation.on_device(add, ctx2)
sub = relay.subtract(add, z)
sub1 = relay.subtract(_add1, z)
sub2 = relay.subtract(_add2, z)

func = relay.Function([x, y, z],
relay.Tuple(tvm.convert([_add1, _add2,
sub])))
func = relay.Function([x, y, z], relay.Tuple([sub1, sub2]))
func = relay.ir_pass.infer_type(func)
func = relay.ir_pass.rewrite_annotated_ops(func,
ctx1.device_type)
return func

def expected():
add = relay.add(x, y)
copy_add_sub = relay.device_copy(add, ctx2, ctx1)
sub = relay.subtract(copy_add_sub, z)
func = relay.Function([x, y, z], sub)
copy_add_sub1 = relay.device_copy(add, ctx2, ctx1)
sub1 = relay.subtract(copy_add_sub1, z)
copy_add_sub2 = relay.device_copy(add, ctx2, ctx1)
sub2 = relay.subtract(copy_add_sub2, z)
func = relay.Function([x, y, z], relay.Tuple([sub1, sub2]))
return func

annotated_func = relay.ir_pass.infer_type(annotated())
Expand All @@ -66,10 +68,9 @@ def test_annotate_expr():
def annotated():
add = relay.add(x, y)
_add = relay.annotation.on_device(add, ctx1)
sub = relay.subtract(add, z)
sub = relay.subtract(_add, z)
_sub = relay.annotation.on_device(sub, ctx2)
expr = relay.Tuple([sub, _add, _sub])
expr = relay.ir_pass.infer_type(expr)
expr = relay.ir_pass.infer_type(_sub)
expr = relay.ir_pass.rewrite_annotated_ops(expr,
ctx1.device_type)
return expr
Expand All @@ -95,12 +96,10 @@ def test_annotate_all():
def annotated():
add = relay.add(x, y)
_add = relay.annotation.on_device(add, ctx2)
sub = relay.subtract(add, z)
sub = relay.subtract(_add, z)
_sub = relay.annotation.on_device(sub, ctx2)

func = relay.Function([x, y, z],
relay.Tuple(tvm.convert([_add, _sub,
sub])))
func = relay.Function([x, y, z], _sub)
func = relay.ir_pass.infer_type(func)
func = relay.ir_pass.rewrite_annotated_ops(func,
ctx1.device_type)
Expand Down Expand Up @@ -168,6 +167,34 @@ def test_conv_network():
dev1 = tvm.context(1)
dev2 = tvm.context(2)

def original():
conv2d_1 = relay.nn.conv2d(
data1,
weight,
channels=64,
kernel_size=(3, 3),
padding=(1, 1))
conv2d_2 = relay.nn.conv2d(
data2,
weight,
channels=64,
kernel_size=(3, 3),
padding=(1, 1))
add = relay.add(conv2d_1, conv2d_2)
conv2d_3 = relay.nn.conv2d(
add,
weight,
channels=64,
kernel_size=(3, 3),
padding=(1, 1))

func = relay.Function([data1, data2, weight], conv2d_3)
func = relay.ir_pass.infer_type(func)
func = relay.ir_pass.rewrite_annotated_ops(func,
tvm.context(3).device_type)
return func


def annotated():
conv2d_1 = relay.nn.conv2d(
data1,
Expand All @@ -183,25 +210,40 @@ def annotated():
kernel_size=(3, 3),
padding=(1, 1))
_conv2d_2 = relay.annotation.on_device(conv2d_2, dev2)
add = relay.add(conv2d_1, conv2d_2)
add = relay.add(_conv2d_1, _conv2d_2)
_add = relay.annotation.on_device(add, dev1)
conv2d_3 = relay.nn.conv2d(
add,
_add,
weight,
channels=64,
kernel_size=(3, 3),
padding=(1, 1))
_conv2d_3 = relay.annotation.on_device(conv2d_3, dev2)

func = relay.Function([data1, data2, weight],
relay.Tuple(tvm.convert([_conv2d_1, _conv2d_2,
_conv2d_3, _add,
conv2d_3])))
func = relay.Function([data1, data2, weight], _conv2d_3)
func = relay.ir_pass.infer_type(func)
func = relay.ir_pass.rewrite_annotated_ops(func,
tvm.context(3).device_type)
return func

class ScheduleConv2d(ExprMutator):
def __init__(self, device):
self.device = device
super().__init__()

def visit_call(self, expr):
visit = super().visit_call(expr)
if expr.op == tvm.relay.op.get("nn.conv2d"):
return relay.annotation.on_device(visit, self.device)
else:
return visit

def annotate_with_visitor(func):
sched = ScheduleConv2d(dev2)
func = sched.visit(func)
func = relay.ir_pass.rewrite_annotated_ops(func, dev1.device_type)
return func

def expected():
conv2d_1 = relay.nn.conv2d(
data1,
Expand Down Expand Up @@ -249,10 +291,19 @@ def check_storage_and_device_types():
assert len(set(device_types)) == 2
assert set(device_types) == {1, 2}

annotated_func = annotated()
expected_func = expected()
check_annotated_graph(annotated_func, expected_func)
check_storage_and_device_types()
def test_manual_annotation():
annotated_func = annotated()
expected_func = expected()
check_annotated_graph(annotated_func, expected_func)
check_storage_and_device_types()

def test_visitor_annotation():
annotated_func = annotate_with_visitor(original())
expected_func = expected()
check_annotated_graph(annotated_func, expected_func)

test_manual_annotation()
test_visitor_annotation()


def run_fusible_network(dev, tgt):
Expand Down Expand Up @@ -321,12 +372,11 @@ def annotated():
sqrt = relay.sqrt(add)
_sqrt = relay.annotation.on_device(sqrt, dev_ctx)
log = relay.log(add)
subtract = relay.subtract(sqrt, log)
subtract = relay.subtract(_sqrt, log)
exp = relay.exp(subtract)
_exp = relay.annotation.on_device(exp, dev_ctx)

func = relay.Function([x, y],
relay.Tuple(tvm.convert([_sqrt, _exp, exp])))
func = relay.Function([x, y], _exp)
func = relay.ir_pass.infer_type(func)
func = relay.ir_pass.rewrite_annotated_ops(func,
cpu_ctx.device_type)
Expand Down Expand Up @@ -364,19 +414,16 @@ def test_fuse_all(device, tgt):
def annotated():
add = relay.add(x, y)
_add = relay.annotation.on_device(add, dev_ctx)
sqrt = relay.sqrt(add)
sqrt = relay.sqrt(_add)
_sqrt = relay.annotation.on_device(sqrt, dev_ctx)
log = relay.log(add)
log = relay.log(_add)
_log = relay.annotation.on_device(log, dev_ctx)
subtract = relay.subtract(sqrt, log)
subtract = relay.subtract(_sqrt, _log)
_subtract = relay.annotation.on_device(subtract, dev_ctx)
exp = relay.exp(subtract)
exp = relay.exp(_subtract)
_exp = relay.annotation.on_device(exp, dev_ctx)

func = relay.Function([x, y],
relay.Tuple(tvm.convert([_add, _sqrt, _log,
_subtract, _exp,
exp])))
func = relay.Function([x, y], _exp)
func = relay.ir_pass.infer_type(func)
func = relay.ir_pass.rewrite_annotated_ops(func,
cpu_ctx.device_type)
Expand All @@ -401,8 +448,7 @@ def annotated():
exp = relay.exp(subtract)
_exp = relay.annotation.on_device(exp, cpu_ctx)

func = relay.Function([x, y],
relay.Tuple(tvm.convert([_exp, exp])))
func = relay.Function([x, y], _exp)
func = relay.ir_pass.infer_type(func)
func = relay.ir_pass.rewrite_annotated_ops(func,
dev_ctx.device_type)
Expand Down Expand Up @@ -472,11 +518,9 @@ def annotated():
_add = relay.annotation.on_device(add, dev_ctx)
mul = relay.multiply(c, d)
_mul = relay.annotation.on_device(mul, cpu_ctx)
sub = relay.subtract(add, mul)
sub = relay.subtract(_add, _mul)
_sub = relay.annotation.on_device(sub, dev_ctx)
func = relay.Function([a, b, c, d],
relay.Tuple(tvm.convert([_add, _mul,
_sub, sub])))
func = relay.Function([a, b, c, d], _sub)
func = relay.ir_pass.infer_type(func)
func = relay.ir_pass.rewrite_annotated_ops(func,
dev_ctx.device_type)
Expand Down