Commit f6d0f72

only keep opencl related code

1 parent 335f99c · commit f6d0f72

File tree

3 files changed: +49 additions, -175 deletions

src/relay/transforms/device_annotation.cc

Lines changed: 42 additions & 37 deletions
@@ -386,38 +386,22 @@ class DeviceInfo {
     }
 
     void VisitExpr_(const ConstantNode* cn) final {
-      device_tag_[cn] = dev_type_;
+      post_dfs_order_.push_back(std::make_pair(cn, has_copy_));
     }
 
     void VisitExpr_(const CallNode* call) final {
       // Skip annotation nodes.
       if (!IsOnDeviceNode(call)) {
-        if (const auto* node = GetDeviceCopyNode(call)) {
-          CHECK(node->IsInstance<CallNode>());
-          const auto* call_node = static_cast<const CallNode*>(node);
-          auto attrs = call_node->attrs.as<DeviceCopyAttrs>();
-
+        if (GetDeviceCopyNode(call)) {
           num_device_copy_ops_++;
           bool has_copy_prev = has_copy_;
           has_copy_ = true;
-          dev_type_ = attrs->src_dev_type;
-          for (auto& arg : call->args) {
-            Visit(arg);
-            // restore the type for remaining arguments
-            dev_type_ = attrs->src_dev_type;
-          }
-          device_tag_[call] = attrs->dst_dev_type;
-          // update the out_dev_type_, which should be the dst_dev_type of last copy
-          out_dev_type_ = attrs->dst_dev_type;
+          ExprVisitor::VisitExpr_(call);
+          post_dfs_order_.push_back(std::make_pair(call, has_copy_));
           has_copy_ = has_copy_prev;
         } else {
-          for (auto& arg : call->args) {
-            int cur_dev_type = dev_type_;
-            Visit(arg);
-            // restore the type for remaining arguments
-            dev_type_ = cur_dev_type;
-          }
-          device_tag_[call] = dev_type_;
+          ExprVisitor::VisitExpr_(call);
+          post_dfs_order_.push_back(std::make_pair(call, has_copy_));
         }
       }
     }
@@ -430,24 +414,22 @@ class DeviceInfo {
     void VisitExpr_(const TupleGetItemNode* op) final { ExprVisitor::VisitExpr_(op); }
 
     void VisitExpr_(const VarNode* vn) final {
-      device_tag_[vn] = dev_type_;
+      post_dfs_order_.push_back(std::make_pair(vn, has_copy_));
    }
 
     void VisitExpr_(const LetNode* ln) final {
       ExprVisitor::VisitExpr_(ln);
-      device_tag_[ln] = dev_type_;
+      post_dfs_order_.push_back(std::make_pair(ln, has_copy_));
     }
 
     void VisitExpr_(const IfNode* in) final {
       ExprVisitor::VisitExpr_(in);
-      device_tag_[in] = dev_type_;
+      post_dfs_order_.push_back(std::make_pair(in, has_copy_));
     }
 
     int num_device_copy_ops_{0};
     bool has_copy_ = false;
-    int dev_type_ = -1;
-    int out_dev_type_ = -1;
-    std::unordered_map<const ExprNode*, int> device_tag_;
+    std::vector<std::pair<const ExprNode*, bool>> post_dfs_order_;
     friend DeviceInfo;
   };
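Taken together, the hunks above change what the visitor records: instead of tagging every node with a device type while walking (which required threading dev_type_ through the whole traversal), it now only appends (node, has_copy_) pairs in post-DFS order and defers all device decisions to a later pass. A minimal Python model of the recording pass, using a hypothetical toy IR rather than TVM's C++ classes (Call, PostDfsRecorder, and "device_copy" are illustrative names, not TVM API):

from collections import namedtuple

# Toy IR: a call node has an op name and a tuple of args; any other
# value is a leaf, the analogue of Var/Constant. Illustrative only.
Call = namedtuple("Call", ["op", "args"])

class PostDfsRecorder:
    """Model of PostDfsOrderVisitor after this commit: record
    (node, has_copy) pairs in post-DFS order, decide devices later."""

    def __init__(self):
        self.post_dfs_order = []  # list of (node, has_copy) pairs
        self.has_copy = False

    def visit(self, node):
        if isinstance(node, Call):
            has_copy_prev = self.has_copy
            if node.op == "device_copy":
                # every node visited under a copy is flagged has_copy=True
                self.has_copy = True
            for arg in node.args:
                self.visit(arg)
            self.post_dfs_order.append((node, self.has_copy))
            self.has_copy = has_copy_prev  # restore, as the C++ does
        else:
            self.post_dfs_order.append((node, self.has_copy))

# Example: "x" -> device_copy -> add(..., "y")
copy = Call("device_copy", ("x",))
root = Call("add", (copy, "y"))
rec = PostDfsRecorder()
rec.visit(root)
# rec.post_dfs_order == [("x", True), (copy, True), ("y", False), (root, False)]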

@@ -473,14 +455,39 @@ class DeviceInfo {
   }
 
   void PropagateDeviceId() {
-    int out_dev_type = post_visitor_.out_dev_type_;
-    for (auto& it : post_visitor_.device_tag_) {
-      if (it.second != -1) {
-        device_map_.Set(GetRef<Expr>(it.first), it.second);
-      } else {
-        device_map_.Set(GetRef<Expr>(it.first), out_dev_type);
+    // Bottom-up propagation.
+    int out_dev_type = BottomUpPropagation();
+    // propagation for remained nodes.
+    FillPropagation(out_dev_type);
+  }
+
+  int BottomUpPropagation() {
+    const CallNode* last_copy_node = nullptr;
+    int cur_dev_type = -1;
+    int out_dev_type = -1;
+    for (auto it = post_visitor_.post_dfs_order_.crbegin();
+         it != post_visitor_.post_dfs_order_.crend(); ++it) {
+      if (const auto* node = GetDeviceCopyNode(it->first)) {
+        CHECK(node->IsInstance<CallNode>());
+        last_copy_node = static_cast<const CallNode*>(node);
+        const auto* attrs = last_copy_node->attrs.as<DeviceCopyAttrs>();
+        cur_dev_type = attrs->src_dev_type;
+        if (out_dev_type == -1) out_dev_type = attrs->dst_dev_type;
+        if (it->second) device_map_.Set(GetRef<Expr>(it->first), attrs->dst_dev_type);
+      } else if (last_copy_node) {
+        Expr expr = GetRef<Expr>(it->first);
+        CHECK_EQ(device_map_.count(expr), 0U);
+        if (it->second) device_map_.Set(expr, cur_dev_type);
       }
     }
+    return out_dev_type;
+  }
+
+  void FillPropagation(int out_dev_type) {
+    for (const auto& it : post_visitor_.post_dfs_order_) {
+      Expr expr = GetRef<Expr>(it.first);
+      if (!it.second) device_map_.Set(expr, out_dev_type);
+    }
   }
 
   PostDfsOrderVisitor post_visitor_;
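The new BottomUpPropagation/FillPropagation pair replaces the old map-based loop: walk the recorded order backwards so that each device_copy's source device reaches the producers before it, remember the destination of the last copy as the overall output device, then assign that output device to every node that never sat under a copy. A sketch of the same logic in Python, continuing the toy model above (the copy_attrs dict stands in for DeviceCopyAttrs; this is a reading of the diff, not TVM code):

def propagate_device_ids(post_dfs_order, copy_attrs):
    """post_dfs_order: (node, has_copy) pairs from PostDfsRecorder.
    copy_attrs: maps a device_copy node to (src_dev_type, dst_dev_type)."""
    device_map = {}
    cur_dev_type = -1
    out_dev_type = -1
    # Pass 1 (BottomUpPropagation): reverse walk from the last copy.
    for node, has_copy in reversed(post_dfs_order):
        if node in copy_attrs:
            src, dst = copy_attrs[node]
            cur_dev_type = src
            if out_dev_type == -1:
                out_dev_type = dst  # dst of the last copy = output device
            if has_copy:
                device_map[node] = dst
        elif cur_dev_type != -1:  # only once a copy has been seen
            assert node not in device_map
            if has_copy:
                device_map[node] = cur_dev_type
    # Pass 2 (FillPropagation): everything outside a copy region runs on
    # the output device.
    for node, has_copy in post_dfs_order:
        if not has_copy:
            device_map[node] = out_dev_type
    return device_map

# Continuing the example: propagate_device_ids(rec.post_dfs_order, {copy: (1, 2)})
# returns {"x": 1, copy: 2, "y": 2, root: 2}; the copy's producer stays on
# the source device, everything else lands on the destination device.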
@@ -534,9 +541,7 @@ Expr RewriteAnnotatedOps(const Expr& expr, int fallback_device) {
   }
 }
 
-Map<Expr, Integer> CollectDeviceInfo(const Expr& expr) {
-  return DeviceInfo::GetDeviceMap(expr);
-}
+Map<Expr, Integer> CollectDeviceInfo(const Expr& expr) { return DeviceInfo::GetDeviceMap(expr); }
 
 Map<Expr, Integer> CollectDeviceAnnotationOps(const Expr& expr) {
   return AnnotatationVisitor::GetAnnotations(expr);
vta/python/vta/top/graphpack.py

Lines changed: 5 additions & 135 deletions
@@ -93,7 +93,7 @@ def _weight_shape_match_transpose(data, dshape, channels, cfactor_out):
     if pad_width != 0:
         pad_width = cfactor_out - pad_width
         data = op.nn.pad(data, [[0, 0], [0, pad_width], [0, 0], [0, 0]])
-        dshape = tuple([dshape[0]] + [dshape[1] + pad_width, dshape[2], dshape[3]])
+        dshape = tuple(dshape[0], [dshape[1] + pad_width, dshape[2], dshape[3]])
 
     if channels_pad != 0:
         channels = channels + (cfactor_out - channels_pad)
@@ -174,104 +174,6 @@ def _operator_idx_inc(expr, count_meta, operator_current_idx):
         operator_current_idx = operator_current_idx + 1
     return operator_current_idx
 
-
-class ExprDeviceAnnot(ExprMutator):
-    """Visitor to perform graph annotation on an AST.
-
-    Parameters
-    ----------
-    start: int
-        the start location to mark run on vta (inclusive)
-    end: int
-        the end location to mark run on vta (exclusive)
-
-    Returns
-    ---------
-    None
-    """
-    def __init__(self, start=-1, end=-1):
-        self.ext_ctx = tvm.context("ext_dev")
-        self.cpu_ctx = tvm.context("cpu")
-        self.cast = op.op.get("cast")
-        self.counter = -1
-        self.start = start
-        self.end = end
-        super().__init__()
-
-    def visit_call(self, call):
-        """ Visit the children. """
-        # First visit the children.
-        oshape = _get_tensor_shape(call)
-        odtype = _get_tensor_type(call)
-        input_types = [arg.checked_type for arg in call.args]
-        args = [self.visit(arg) for arg in call.args]
-
-        self.counter += 1
-        if self.counter == self.start:
-            ret = relay.Call(call.op, args, call.attrs)
-            ret = relay.annotation.on_device(ret, self.ext_ctx)
-            return ret
-        elif self.counter == self.end:
-            ret = relay.Call(call.op, args, call.attrs)
-            ret = relay.annotation.on_device(ret, self.cpu_ctx)
-            return ret
-        elif self.counter > self.start and self.counter < self.end:
-            ret = relay.Call(call.op, args, call.attrs)
-
-            # skip the float op, i.e., float->int cast
-            if self.is_float_op(call):
-                return ret
-
-            return relay.annotation.on_device(ret, self.ext_ctx)
-
-        return relay.Call(self.visit(call.op), args, call.attrs)
-
-    def is_float_op(self, call):
-        """check if this op belongs to a float op
-        in general, float op's odtype is float;
-        a special case is float->int cast, which follow this op sequence:
-        multiply(float) -> round(float) -> clip(float) -> cast(int);
-        """
-        args = call.args
-        odtype = _get_tensor_type(call)
-        op = call.op
-
-        if odtype == "float32":
-            return True
-        elif op == self.cast:
-            idtype = _get_tensor_type(args[0])
-            if idtype == "float32":
-                return True
-
-        return False
-
-
-class ExprLocater(ExprMutator):
-    """Visitor to locate op on an AST.
-    """
-    def __init__(self):
-        self.counter = -1
-        self.op2nodes = {}
-        super().__init__()
-
-    def visit_call(self, call):
-        """ Visit the children. """
-        # First visit the children.
-        args = [self.visit(arg) for arg in call.args]
-
-        odtype = _get_tensor_type(call)
-        self.counter += 1
-        if (call.op, odtype) in self.op2nodes:
-            self.op2nodes[(call.op, odtype)].append(self.counter)
-        else:
-            self.op2nodes[(call.op, odtype)] = [self.counter]
-
-        return relay.Call(
-            self.visit(call.op),
-            args,
-            call.attrs)
-
-
 class ExprPack(ExprMutator):
     """Visitor to perform graph packing on an AST.
     """
@@ -415,7 +317,7 @@ def visit_call(self, call):
         elif self.start_pack and call.op == op.op.get('cast') and \
                 input_types[0].dtype == 'int32':
             cast = relay.Call(op.op.get('cast'), [args[0]], call.attrs)
-            return cast
+            return relay.Call(op.op.get('copy'), [cast])
         elif call.op == self.pad:
             pad_width = call.attrs.pad_width
             if len(pad_width) == 6:
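The rewritten branch above routes the freshly built int32 cast through Relay's identity copy op instead of returning it bare; a plausible reading (the commit message does not say) is that the explicit copy acts as a boundary that keeps this cast from being fused into its consumer. A minimal Relay snippet that builds the same expression shape, assuming a standard TVM install (the int8 target dtype is illustrative; in the pass it comes from call.attrs):

from tvm import relay

# Hypothetical free-standing equivalent of the rewritten return value.
x = relay.var("x", shape=(1, 16), dtype="int32")
cast = relay.cast(x, "int8")  # the cast the pass emits
wrapped = relay.copy(cast)    # identity copy acting as a boundary
print(wrapped)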
@@ -510,10 +412,7 @@ def graph_pack(expr,
                stop_name="nn.global_avg_pool2d",
                start_name_idx=None,
                stop_name_idx=None,
-               count_meta=False,
-               device_annot=False,
-               annot_start_name="nn.conv2d",
-               annot_end_name="annotation.stop_fusion"):
+               count_meta=False):
     """Pack the graph into batch&channel packed format.
 
     Parameters
@@ -550,47 +449,18 @@
         'expr.astext(show_meta_data=False)'. When count_meta is True, the operator increase
         logic would count the meta.
 
-    device_annot: boolean, optional
-        if we want to annoate the device_type
-
-    annot_start_name: str, optional
-        device annotation start node, from which we mark the nodes as `ext_dev`
-
-    annot_end_name: str, optional
-        device annotation end node, after which we mark the nodes as 'cpu'
-
     Returns
     -------
     expr : Expr
         The transformed expression.
     """
     assert isinstance(expr, relay.Function)
-    assert ((start_name != stop_name) or (start_name_idx is None != stop_name_idx is None) or \
-            (not (start_name_idx is None and stop_name_idx is None)) or (start_name_idx < stop_name_idx))
+    assert ((start_name != stop_name) or (start_name_idx < stop_name_idx))
     expr = get_subgraph(expr, start_name, stop_name, start_name_idx, stop_name_idx, count_meta)
     expr = run_opt_pass(expr, transform.InferType())
     packer = ExprPack(
         bfactor, cfactor,
         weight_bits)
     expr = packer.visit(expr)
     assert not packer.start_pack
-    expr = run_opt_pass(expr, transform.InferType())
-
-    if device_annot:
-        expr_locator = ExprLocater()
-        expr_locator.visit(expr)
-
-        annot_start = op.op.get(annot_start_name)
-        start = expr_locator.op2nodes[(annot_start, "int32")][0]
-
-        annot_end = op.op.get(annot_end_name)
-        # we mark the next op to the last stop_fusion on cpu device
-        end = expr_locator.op2nodes[(annot_end, "int8")][-1] + 1
-
-        device_annot = ExprDeviceAnnot(start=start, end=end)
-        expr = device_annot.visit(expr)
-        ret = run_opt_pass(expr, transform.InferType())
-
-        return ret
-    else:
-        return expr
+    return run_opt_pass(expr, transform.InferType())
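With the device_annot path gone, graph_pack ends at the packing pass followed by a single InferType. A hedged usage sketch for the slimmed signature, following the VTA tutorials (env comes from vta.get_env(); mod is assumed to be an InferType'd Relay module from a frontend importer and is not defined here):

import vta
from vta.top import graph_pack

env = vta.get_env()
relay_prog = graph_pack(
    mod["main"],  # assumed: a relay.Function after InferType
    env.BATCH,
    env.BLOCK_OUT,
    env.WGT_WIDTH,
    start_name="nn.conv2d",
    stop_name="nn.global_avg_pool2d")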

vta/tutorials/autotvm/tune_alu_vta.py

Lines changed: 2 additions & 3 deletions
@@ -42,7 +42,7 @@
 # Compile network
 # ---------------
 # Perform vta-specific compilation with Relay from a Gluon model
-def compile_network(env, target, model, start_pack, stop_pack, device_annot=False):
+def compile_network(env, target, model, start_pack, stop_pack):
 
     # Populate the shape and data type dictionary
     dtype_dict = {"data": 'float32'}
@@ -70,8 +70,7 @@ def compile_network(env, target, model, start_pack, stop_pack, device_annot=False):
                 env.BLOCK_OUT,
                 env.WGT_WIDTH,
                 start_name=start_pack,
-                stop_name=stop_pack,
-                device_annot=device_annot)
+                stop_name=stop_pack)
 
     return relay_prog, params
 