
Commit f6ddd52

[microNPU] Expose compute cycle annotations to TIR lowering (#11288)

* [microNPU] Expose compute cycle annotations to TIR lowering

  Adds an AttrStmt "compute_cycles_hint" to each NPU operation for later
  passes to consume.

  Change-Id: I09779bdab6de6ef2094db610bb20d6e052e68ee3

* compute_cycles -> compute_cycles_hint

  Change-Id: Iebd71e699522e92a28fd321ffdb41ed7924db4e0

* Add a test to check the annotations in the compilation flow

  Change-Id: Idcdcc8c8b5536c4732f297246b71aa8378a2732c

* Add compute cycles hints for copy operations

  Change-Id: I007ba19732e16081fa2ea9baca40c64a653c93cf

* Fix annotations for copies and improve test coverage

  Change-Id: Ib812c4151fab03f4c1adcc016b4e798003a22e5e

* Rebase

  Change-Id: I653101908706096ae25ad1ebf08e7b6c4f1196c7

1 parent 8135860 commit f6ddd52
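
For orientation, a minimal sketch of how a later pass might consume these annotations once lowering has prefixed the pragma key (the tests below confirm the "pragma_compute_cycles_hint" key); the helper name and traversal are illustrative, not part of this commit:

import tvm

def collect_compute_cycle_hints(primfunc):
    """Gather every compute cycles hint attached to a lowered PrimFunc."""
    hints = []

    def visit(node):
        # A te.Schedule pragma named "compute_cycles_hint" lowers to an
        # AttrStmt whose key carries the "pragma_" prefix.
        if isinstance(node, tvm.tir.AttrStmt) and node.attr_key == "pragma_compute_cycles_hint":
            hints.append(int(node.value))

    tvm.tir.stmt_functor.post_order_visit(primfunc.body, visit)
    return hints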

File tree

7 files changed: +301 / -51 lines changed

python/tvm/contrib/ethosu/cascader/plan_generator.py

Lines changed: 22 additions & 2 deletions
@@ -15,9 +15,9 @@
 # specific language governing permissions and limitations
 # under the License.
 """Algorithms to generate Plans for a CascaderGraph."""
-from typing import List, Dict
+from typing import List, Dict, Tuple

-from tvm.contrib.ethosu.cascader.tensor_config import MemoryRegion
+from tvm.contrib.ethosu.cascader.tensor_config import MemoryRegion, TensorConfig

 from . import _ffi_api
 from .cascader_options import CascaderOptions
@@ -55,3 +55,23 @@ def _generate_graph_plans(
         home_map,
         options,
     )
+
+
+def get_copy_cycles_hint(tensor_config: TensorConfig) -> Tuple[int, int]:
+    """
+    Returns a hint estimating the number of cycles for the copy
+    specified by tensor_config.
+
+    Parameters
+    ----------
+    tensor_config : TensorConfig
+        The tensor configuration to estimate.
+
+    Returns
+    -------
+    mem2mem_cycles : int
+        Total estimated cycles.
+    initial_mem2mem_cycles : int
+        Estimated cycles for the first block.
+    """
+    return _ffi_api.GetCopyCyclesHint(tensor_config)
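
A brief usage sketch, assuming tensor_config is a TensorConfig produced during plan generation; note that the total comes first in the returned pair:

# Total cycles across all blocks, then the cycles for just the first block.
total_cycles, initial_cycles = get_copy_cycles_hint(tensor_config)
# scheduler.py (next file) keeps only the total as the copy's hint:
# compute_cycles_hint, _ = get_copy_cycles_hint(tensor_config)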

python/tvm/contrib/ethosu/cascader/scheduler.py

Lines changed: 17 additions & 1 deletion
@@ -31,6 +31,7 @@
 from .tensor_config import MemoryRegion
 from .proposal import Proposal
 from .proposal_generator import generate_proposals
+from .plan_generator import get_copy_cycles_hint
 from .graph import create_cascader_graph
 from .device_config import EthosuDeviceConfig
 from .logging import Logging
@@ -134,7 +135,11 @@ def apply_proposal(proposal: Proposal, sch: te.Schedule) -> None:
             if isinstance(part, EthosuPart):
                 tensor_config = plan.tensor_configs[part.output_tensor]
                 stripe_config = tensor_config.stripe_configs[0]
+                buffer_mode = tensor_config.buffer_mode
                 block_config = part.get_block_config(stripe_config)
+                compute_cycles = part.get_performance_info(
+                    stripe_config, buffer_mode
+                ).compute_cycles
                 iv = part.subgraph.output_tensor.op.axis[0]
                 block_shape = block_config.output_shape
                 if len(block_shape) == 4:
@@ -147,6 +152,10 @@ def apply_proposal(proposal: Proposal, sch: te.Schedule) -> None:
                 sch[part.subgraph.output_tensor].pragma(iv, "block_config_width", width)
                 sch[part.subgraph.output_tensor].pragma(iv, "block_config_depth", depth)

+                # Attach AttrStmt directly to npu op so it isn't removed by ReplaceOperators
+                npu_op = part.subgraph.output_tensor.op.input_tensors[0].op.input_tensors[0]
+                sch[npu_op].pragma(npu_op.op.axis[0], "compute_cycles_hint", compute_cycles)
+
         output_tensor_config = plan.output_config
         output_tensor = output_tensor_config.tensor
         output_part = output_tensor.producers[0]
@@ -156,6 +165,7 @@ def apply_proposal(proposal: Proposal, sch: te.Schedule) -> None:
         stripe_shape = [int(x) for x in stripe_config.shape]
         stripe_stage, stripe_axis = stripe_part(output_part, stripe_shape, sch)
         copy_te_tensors = []
+        compute_cycles_hints = []
         readers = defaultdict(list)
         for part in plan.part_group:
             if part != output_part:
@@ -167,8 +177,14 @@ def apply_proposal(proposal: Proposal, sch: te.Schedule) -> None:
                 if tensor_config.home_region != tensor_config.copy_region:
                     copy_te_tensors.append(part.subgraph.input_tensors[i])

-        for te_tensor in copy_te_tensors:
+                    compute_cycles_hint, _ = get_copy_cycles_hint(tensor_config)
+                    compute_cycles_hints.append(compute_cycles_hint)
+
+        for te_tensor, compute_cycles_hint in zip(copy_te_tensors, compute_cycles_hints):
             copy_stage = sch.cache_read(te_tensor, "global", readers[te_tensor])
+            sch[copy_stage].pragma(
+                copy_stage.op.axis[0], "compute_cycles_hint", compute_cycles_hint
+            )
             sch[copy_stage].compute_at(stripe_stage, stripe_axis)
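The pragmas attached here travel through lowering as AttrStmt nodes. A self-contained sketch of that mechanism using generic TE operators (not the Ethos-U scheduler itself):

import tvm
from tvm import te

A = te.placeholder((16,), name="A")
B = te.compute((16,), lambda i: A[i] + 1, name="B")
s = te.create_schedule(B.op)
# Same call shape as sch[npu_op].pragma(...) above.
s[B].pragma(B.op.axis[0], "compute_cycles_hint", 320)
# The lowered body now contains:
#   attr [IterVar(...)] "pragma_compute_cycles_hint" = 320
print(tvm.lower(s, [A, B]))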

python/tvm/relay/backend/contrib/ethosu/tir/scheduler.py

Lines changed: 7 additions & 0 deletions
@@ -263,6 +263,13 @@ def _detect_cache_read(stage):
         if stage.attach_type != 2:  # Not inlined
             if _detect_cache_read(stage):
                 fax = stage.fuse(*stage.op.axis)
+
+                # propagate pragmas placed on the outer loop
+                if len(stage.op.axis) > 0 and stage.op.axis[0] in stage.iter_var_attrs:
+                    attrs = stage.iter_var_attrs[stage.op.axis[0]]
+                    for k, v in zip(attrs.pragma_keys, attrs.pragma_values):
+                        stage.pragma(fax, k.value, v)
+
                 stage.pragma(fax, "op", "ethosu_copy")
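
Fusing replaces the original axes with a new IterVar, so a pragma recorded against the old outer axis in iter_var_attrs would otherwise be dropped. A standalone sketch of the propagation, using the same generic TE set-up as before:

import tvm
from tvm import te

A = te.placeholder((4, 4), name="A")
B = te.compute((4, 4), lambda i, j: A[i, j], name="B")
s = te.create_schedule(B.op)
s[B].pragma(B.op.axis[0], "compute_cycles_hint", 64)  # set before fusing

stage = s[B]
fax = stage.fuse(*stage.op.axis)
# Mirror of the propagation above: re-apply the outer axis's pragmas.
if len(stage.op.axis) > 0 and stage.op.axis[0] in stage.iter_var_attrs:
    attrs = stage.iter_var_attrs[stage.op.axis[0]]
    for k, v in zip(attrs.pragma_keys, attrs.pragma_values):
        stage.pragma(fax, k.value, v)
print(tvm.lower(s, [A, B]))  # the hint now sits on the fused loop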

src/contrib/ethosu/cascader/plan_generator.cc

Lines changed: 45 additions & 19 deletions
@@ -301,6 +301,42 @@ int GetInteriorMemoryUsage(const std::vector<TensorConfig>& input_configs,
   return memory_usage;
 }

+/**
+ * \brief Returns a hint estimating the number of cycles required for
+ * the copy specified by tensor_config.
+ *
+ * \param tensor_config The tensor configuration to estimate.
+ * \return mem2mem_cycles Total estimated cycles.
+ * \return initial_mem2mem_cycles Estimated cycles for the first block.
+ */
+std::pair<int, int> GetCopyCyclesHint(const TensorConfig& tensor_config) {
+  Tensor tensor = tensor_config->GetTensor();
+  MemoryRegion home_region = tensor_config->GetHomeRegion();
+  MemoryRegion copy_region = tensor_config->GetCopyRegion();
+  int initial_mem2mem_cycles = 0;
+  int mem2mem_cycles = 0;
+
+  // This Tensor needs to be copied - Count stripes for this config
+  for (const auto& stripe_config : tensor_config->GetStripeConfigs()) {
+    std::map<std::vector<int>, int> input_blocks = CountStripes(stripe_config, true);
+    bool first_block = true;
+    for (const auto& block : input_blocks) {
+      int bytes_transferred = mul_reduce(block.first) * tensor->GetDataType().bytes() *
+                              tensor->GetCompressionRatio() * block.second;
+      int read_cycles = bytes_transferred * home_region->read_bandwidth + home_region->read_latency;
+      int write_cycles = bytes_transferred * copy_region->write_bandwidth;
+
+      if (first_block) {
+        first_block = false;
+        initial_mem2mem_cycles += std::max(read_cycles, write_cycles);
+      }
+      mem2mem_cycles += std::max(read_cycles, write_cycles);
+    }
+  }
+
+  return {mem2mem_cycles, initial_mem2mem_cycles};
+}
+
 std::vector<Plan> GenerateSinglePlans(
     const Part& part, const std::vector<StripeConfig>& output_stripe_configs,
     const std::unordered_map<Tensor, std::vector<MemoryRegion>, ObjectPtrHash, ObjectPtrEqual>&
@@ -372,28 +408,12 @@ std::vector<Plan> GenerateSinglePlans(
         BlockConfig block_config = perf_info->block_config;
         for (size_t i = 0; i < input_configs.size(); i++) {
           Tensor tensor = input_configs[i]->GetTensor();
-          MemoryRegion home_region = input_configs[i]->GetHomeRegion();
           MemoryRegion copy_region = input_configs[i]->GetCopyRegion();

           if (input_configs[i]->DoCopy()) {
-            // This Tensor needs to be copied - Count stripes for this config
-            for (const auto& stripe_config : input_configs[i]->GetStripeConfigs()) {
-              std::map<std::vector<int>, int> input_blocks = CountStripes(stripe_config, true);
-              bool first_block = true;
-              for (const auto& block : input_blocks) {
-                int bytes_transferred = mul_reduce(block.first) * tensor->GetDataType().bytes() *
-                                        tensor->GetCompressionRatio() * block.second;
-                int read_cycles = bytes_transferred * home_region->read_bandwidth +
-                                  input_configs[i]->GetHomeRegion()->read_latency;
-                int write_cycles = bytes_transferred * copy_region->write_bandwidth;
-
-                if (first_block) {
-                  first_block = false;
-                  initial_mem2mem_cycles += std::max(read_cycles, write_cycles);
-                }
-                mem2mem_cycles += std::max(read_cycles, write_cycles);
-              }
-            }
+            std::pair<int, int> ret = GetCopyCyclesHint(input_configs[i]);
+            mem2mem_cycles += ret.first;
+            initial_mem2mem_cycles += ret.second;
           }
           float read_efficiency =
               GetTransferEfficiency(tensor, block_config->GetInputBlockShape(), copy_region);
@@ -585,6 +605,12 @@ TVM_REGISTER_GLOBAL("contrib.ethosu.cascader.GenerateGraphPlans")
       return tclosed_plans;
     });

+TVM_REGISTER_GLOBAL("contrib.ethosu.cascader.GetCopyCyclesHint")
+    .set_body_typed([](TensorConfig tensor_config) {
+      std::pair<int, int> ret = GetCopyCyclesHint(tensor_config);
+      return Array<Integer>({ret.first, ret.second});
+    });
+
 }  // namespace cascader
 }  // namespace ethosu
 }  // namespace contrib
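
For reference, a back-of-envelope Python transcription of the estimate above (a hypothetical helper; in the real code CountStripes supplies the block shapes and counts): each block's read and write cycles are computed from the bytes moved, and the slower side dominates.

def copy_cycles_hint(blocks, dtype_bytes, compression_ratio,
                     read_bandwidth, read_latency, write_bandwidth):
    """Mirror of GetCopyCyclesHint's inner loop for one stripe config.

    `blocks` stands in for CountStripes' result: (block_shape, count) pairs.
    """
    mem2mem_cycles = 0
    initial_mem2mem_cycles = 0
    first_block = True
    for shape, count in blocks:
        volume = 1
        for dim in shape:
            volume *= dim  # mul_reduce(block.first)
        bytes_transferred = int(volume * dtype_bytes * compression_ratio * count)
        read_cycles = bytes_transferred * read_bandwidth + read_latency
        write_cycles = bytes_transferred * write_bandwidth
        if first_block:
            first_block = False
            initial_mem2mem_cycles += max(read_cycles, write_cycles)
        mem2mem_cycles += max(read_cycles, write_cycles)
    return mem2mem_cycles, initial_mem2mem_cycles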

src/tir/contrib/ethosu/passes.cc

Lines changed: 7 additions & 2 deletions
@@ -168,9 +168,14 @@ class CopyComputeReorderingMutator : public StmtExprMutator {
   }

   tvm::runtime::Array<tvm::PrimExpr> get_stmt_args(const Stmt& stmt) {
-    auto eval_node{stmt.as<EvaluateNode>()};
+    Stmt eval_stmt = stmt;
+    if (const auto* attr_stmt = eval_stmt.as<AttrStmtNode>()) {
+      eval_stmt = attr_stmt->body;
+    }
+
+    auto eval_node{eval_stmt.as<EvaluateNode>()};
     ICHECK(eval_node) << "Expected statement to be an evaluate node, but was "
-                      << stmt->GetTypeKey();
+                      << eval_stmt->GetTypeKey();
     auto call_node{eval_node->value.as<CallNode>()};
     ICHECK(call_node) << "Expected expression to be a call node, but was "
                       << eval_node->value->GetTypeKey();
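
In Python terms the change amounts to peeling one AttrStmt layer before the existing checks, since NPU call statements can now arrive wrapped in the compute-cycles annotation. A hedged analogue of the C++ above, not code from the commit:

import tvm

def get_stmt_args(stmt):
    # NPU ops may now be wrapped in a "pragma_compute_cycles_hint" AttrStmt,
    # so look through it before expecting an Evaluate of a Call.
    if isinstance(stmt, tvm.tir.AttrStmt):
        stmt = stmt.body
    assert isinstance(stmt, tvm.tir.Evaluate), f"Expected Evaluate, got {type(stmt)}"
    call = stmt.value
    assert isinstance(call, tvm.tir.Call), f"Expected Call, got {type(call)}"
    return call.args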
Lines changed: 145 additions & 0 deletions

@@ -0,0 +1,145 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=wrong-import-position,invalid-name
+
+"""
+Test the cascader in the compilation flow.
+"""
+
+import pytest
+
+pytest.importorskip("ethosu.vela")
+
+import numpy as np
+
+import tvm
+from tvm import relay
+from tvm.relay.backend.contrib.ethosu.codegen import _create_cascader
+from tvm.relay.backend.contrib.ethosu.tir.compiler import _lower_to_tir
+from tvm.contrib.ethosu.cascader import MemoryRegion, EthosuDeviceConfig
+
+from .. import infra as test_infra
+from . import infra as cascader_test_infra
+
+
+def _ethos_u55_cascader():
+    sram = MemoryRegion(
+        name="SRAM",
+        size=10**6,
+        read_bandwidth=16,
+        write_bandwidth=16,
+        read_latency=0,
+        write_latency=0,
+        burst_length=1,
+    )
+    flash = MemoryRegion(name="FLASH", size=10**7, read_bandwidth=4, write_bandwidth=4)
+
+    device_config = EthosuDeviceConfig("ethos-u55-256")
+    cascader_options = cascader_test_infra.make_options(
+        cascade_region=sram,
+        max_proposals=64,
+        stripe_factors=4,
+        max_plan_size=10,
+        max_open_plans=8,
+        max_closed_plans=32,
+        always_copy_size=1024,
+        disable_pareto_plans=False,
+        disable_pareto_proposals=False,
+        enable_striping=False,
+    )
+    return _create_cascader(
+        options=cascader_options,
+        io_region=sram,
+        constant_region=flash,
+        working_regions=[sram],
+        device_config=device_config,
+    )
+
+
+def _compile_model(relay_function):
+    mod = tvm.IRModule()
+    mod["main"] = relay_function
+    mod = relay.transform.InferType()(mod)
+    tir_mod = _lower_to_tir(mod["main"], _ethos_u55_cascader())[0]
+    return tir_mod["main"]
+
+
+def _create_single_conv2d():
+    ifm = relay.var("x", shape=(1, 8, 8, 4), dtype="int8")
+    conv1 = test_infra.make_ethosu_conv2d(ifm, 4, 4, (3, 3), (1, 1), (1, 1), (1, 1))
+    func = relay.Function(relay.analysis.free_vars(conv1), conv1)
+    return func
+
+
+def _create_double_conv2d():
+    ifm = relay.var("x", shape=(1, 8, 8, 4), dtype="int8")
+    conv1 = test_infra.make_ethosu_conv2d(ifm, 4, 4, (3, 3), (1, 1), (1, 1), (1, 1))
+    conv2 = test_infra.make_ethosu_conv2d(conv1, 4, 4, (1, 3), (1, 1), (1, 1), (1, 1))
+    func = relay.Function(relay.analysis.free_vars(conv2), conv2)
+    return func
+
+
+def _create_scalar_add():
+    ifm = relay.var("x", shape=(1, 5, 4, 3), dtype="int8")
+    ifm2 = relay.const(np.ones((1, 1, 1, 1)), dtype="int8")
+    add = test_infra.make_ethosu_binary_elementwise(
+        ifm, ifm2, ifm_channels=3, ifm2_channels=1, operator_type="ADD", ofm_dtype="int8"
+    )
+    func = relay.Function(relay.analysis.free_vars(add), add)
+    return func
+
+
+def test_single_conv_compute_cycles_hint():
+    """
+    Check that the "compute_cycles_hint" annotation remains in the lowering
+    flow for a single convolution.
+    """
+    primfunc = _compile_model(_create_single_conv2d())
+    ops = primfunc.body.body.body.seq
+
+    compute_cycles_hints = [2304, 640, 320]
+    for op, compute_cycle_hint in zip(ops, compute_cycles_hints):
+        assert op.attr_key == "pragma_compute_cycles_hint"
+        assert op.value == compute_cycle_hint
+
+
+def test_double_conv_compute_cycles_hint():
+    """
+    Check that the "compute_cycles_hint" annotation remains in the lowering
+    flow for a double convolution.
+    """
+    primfunc = _compile_model(_create_double_conv2d())
+    ops = primfunc.body.body.body.body.body.body.seq
+
+    compute_cycles_hints = [2304, 640, 768, 640, 320, 240]
+    for op, compute_cycle_hint in zip(ops, compute_cycles_hints):
+        assert op.attr_key == "pragma_compute_cycles_hint"
+        assert op.value == compute_cycle_hint
+
+
+def test_scalar_add_compute_cycles_hint():
+    """
+    Check that the "compute_cycles_hint" annotation remains in the lowering
+    flow for an add with scalar values.
+    """
+    primfunc = _compile_model(_create_scalar_add())
+    ops = primfunc.body.body.seq
+
+    compute_cycles_hints = [16, 24]
+    for op, compute_cycle_hint in zip(ops, compute_cycles_hints):
+        assert op.attr_key == "pragma_compute_cycles_hint"
+        assert op.value == compute_cycle_hint