Skip to content

Commit 1a62ae0

Browse files
committed
bug fix for pruner
1 parent 7119790 commit 1a62ae0

File tree

4 files changed

+12
-5
lines changed

4 files changed

+12
-5
lines changed

python/tvm/contrib/msc/core/tools/prune/pruner.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -340,7 +340,12 @@ def _prune_by_shape(tensor: MSCTensor, shape: List[int]):
340340
def _prune_by_channel(tensor: MSCTensor, dim, channel_axis: int = None):
341341
shape = tensor.get_shape()
342342
if channel_axis is None:
343-
channel_axis = tensor.layout_of("C")
343+
if self.has_w_node(tensor.name):
344+
w_node = self.find_w_node(tensor.name)
345+
_, channel_axis = self._get_io_axes(w_node)
346+
else:
347+
channel_axis = tensor.layout_of("C")
348+
assert channel_axis >= 0, "Can not infer channel_axis for " + str(tensor)
344349
shape[channel_axis] = dim
345350
return _prune_by_shape(tensor, shape)
346351

python/tvm/contrib/msc/core/tools/tool.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1620,6 +1620,9 @@ def _get_io_axes(self, w_node: WeightJoint) -> Tuple[int, int]:
16201620
in_axis, out_axis = w_node.weight.layout_of("I"), w_node.weight.layout_of("O")
16211621
if in_axis >= 0 and out_axis >= 0:
16221622
return in_axis, out_axis
1623+
if w_node.weight.ndim == 2 and w_node.weight.dim_at("N") > 0:
1624+
io_axis = 1 - w_node.weight.layout_of("N")
1625+
return io_axis, io_axis
16231626
if w_node.weight.layout_of("C") >= 0:
16241627
return w_node.weight.layout_of("C"), w_node.weight.layout_of("C")
16251628
raise Exception("Can not infer in_axis/out_axis from " + str(w_node))

src/contrib/msc/core/ir/graph.cc

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1116,9 +1116,8 @@ void WeightGraphNode::Build(const MSCGraph& graph, const Map<String, Array<Strin
11161116
const auto& tensor = node->OutputAt(0);
11171117
Map<String, String> attrs;
11181118
attrs.Set("producer_type", node->optype);
1119-
if (node->optype == "reshape" && node->InputAt(0)->LayoutOf("C") >= 0 &&
1120-
node->OutputAt(0)->LayoutOf("C") >= 0 &&
1121-
node->InputAt(0)->DimAt("C")->value == node->OutputAt(0)->DimAt("C")->value) {
1119+
if (node->optype == "reshape") {
1120+
// TODO(archermmt): check non-passby reshape
11221121
attrs.Set("weight_strategy", "passby");
11231122
} else {
11241123
attrs.Set("weight_strategy", relation_wtypes[node->optype]);

tests/python/contrib/test_msc/test_tools.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ def _get_config(
4747

4848
path = "_".join(["test_tools", model_type, compile_type] + [t["tool_type"] for t in tools])
4949
return {
50-
"workspace": msc_utils.msc_dir(path),
50+
"workspace": msc_utils.msc_dir(path, keep_history=False),
5151
"verbose": "critical",
5252
"model_type": model_type,
5353
"inputs": inputs,

0 commit comments

Comments
 (0)