Skip to content

Commit ff814b7

Browse files
lhutton1 authored and ylc committed
[microNPU] Allow constants to be given as input to an operator (apache#9515)
* [microNPU] Allow constants to be given as input to an operator

  Currently the expectation is that all constants need to be encoded; however, this is not always the case for scalar inputs. This PR ensures that constants that don't need encoding are not treated like encoded constants by the EncodeConstants pass.

  Change-Id: I79cf4aa10d01c4ae9ce9cdafb6f21ebb2d028126

* address comments

  Change-Id: I67b61a2d2f67de25c47d2ace0e3a22c59ba8ea15
1 parent 8e8791b commit ff814b7

File tree

4 files changed

+150
-3
lines changed

4 files changed

+150
-3
lines changed

python/tvm/relay/backend/contrib/ethosu/tir/passes.py

Lines changed: 7 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -392,14 +392,19 @@ def _visit_rewrite(stmt):
392392
# For extern calls, we need to rewrite pairs of arguments corresponding to
393393
# base address load and the length of the load.
394394
new_args = [stmt.args[0]]
395+
new_buffers = rewrite_buffer.values()
395396
for i in range(1, len(stmt.args)):
396397
# If the previous argument was a load, the current should be a length
397398
if isinstance(stmt.args[i - 1], tvm.tir.Load):
398399
load = stmt.args[i - 1]
399400
pointer = load.buffer_var
400401
if pointer in pointer_to_buffer:
401-
new_args.append(np.prod(list(pointer_to_buffer[pointer].shape)))
402-
continue
402+
buffer = pointer_to_buffer[pointer]
403+
# Only rewrite the arguments of buffers that have been encoded
404+
if buffer in new_buffers:
405+
new_arg = np.prod(list(pointer_to_buffer[pointer].shape))
406+
new_args.append(new_arg)
407+
continue
403408
new_args.append(stmt.args[i])
404409

405410
return tvm.tir.Call(stmt.dtype, stmt.op, new_args, stmt.span)

tests/python/contrib/test_ethosu/test_codegen.py

Lines changed: 50 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -435,6 +435,56 @@ def representative_dataset():
435435
infra.verify_source(compiled_models, accel_type)
436436

437437

438+
@pytest.mark.parametrize("accel_type", ACCEL_TYPES)
439+
def test_binary_add_from_constant_scalar(accel_type):
440+
dtype = "uint8"
441+
ifm_shape = (1, 4, 4, 8)
442+
443+
def create_relay_graph():
444+
inp = relay.var("input", shape=ifm_shape, dtype=dtype)
445+
scalar = relay.const(np.ones((1, 1, 1, 1), dtype=dtype), dtype=dtype)
446+
add = relay.qnn.op.add(
447+
inp,
448+
scalar,
449+
relay.const(1.0, dtype="float32"),
450+
relay.const(0, dtype="int32"),
451+
relay.const(1.0, dtype="float32"),
452+
relay.const(0, dtype="int32"),
453+
relay.const(1.0, dtype="float32"),
454+
relay.const(0, dtype="int32"),
455+
)
456+
func = relay.Function(relay.analysis.free_vars(add), add)
457+
return tvm.IRModule.from_expr(func)
458+
459+
mod = create_relay_graph()
460+
partitioned_mod = partition_for_ethosu(mod)
461+
462+
# Generate reference data
463+
input_data = {"input": np.random.randint(low=0, high=255, size=ifm_shape, dtype=dtype)}
464+
output_data = generate_ref_data(mod, input_data)
465+
466+
compiled_models = infra.build_source(
467+
partitioned_mod,
468+
input_data,
469+
output_data,
470+
accel_type,
471+
output_tolerance=0,
472+
)
473+
474+
# Assumes only two runtime.Modules are created -- i.e. single offload module
475+
imported_modules = compiled_models[0].executor_factory.lib.imported_modules
476+
assert len(imported_modules) == 2
477+
ethosu_module = imported_modules[0]
478+
479+
# Verify generated C source
480+
get_cs = tvm._ffi.get_global_func("runtime.module.ethos-u.getcs")
481+
cmms = get_cs(ethosu_module)
482+
cmms = bytes.fromhex(cmms)
483+
484+
infra.print_payload(cmms)
485+
infra.verify_source(compiled_models, accel_type)
486+
487+
438488
@pytest.mark.parametrize("accel_type", ACCEL_TYPES)
439489
@pytest.mark.parametrize(
440490
"ifm_shape, ifm2_shape",

tests/python/contrib/test_ethosu/test_encode_constants.py

Lines changed: 46 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -15,6 +15,7 @@
1515
# specific language governing permissions and limitations
1616
# under the License.
1717
import pytest
18+
import numpy as np
1819

1920
pytest.importorskip("ethosu.vela")
2021
import tvm
@@ -23,8 +24,10 @@
2324
from tvm.relay.testing import run_opt_pass
2425
from tvm.relay.backend.contrib.ethosu.tir.compiler import lower_to_tir
2526
from tvm.relay.backend.contrib.ethosu.tir.scheduler import Convolution2DCompute
27+
from tvm.relay.backend.contrib.ethosu.tir.scheduler import copy_constants
28+
from tvm.relay.backend.contrib.ethosu import tir_to_cs_translator
2629

27-
from .infra import make_ethosu_conv2d
30+
from .infra import make_ethosu_conv2d, make_ethosu_binary_elementwise
2831

2932

3033
# fmt: off
@@ -270,5 +273,47 @@ def _get_func():
270273
assert reference_const_sizes == test_const_sizes
271274

272275

276+
def test_constant_as_input():
277+
"""Test to check that constants specified as inputs aren't
278+
interpreted as an encoded constant."""
279+
280+
def get_graph():
281+
dtype = "uint8"
282+
ifm = relay.var("ifm", shape=(1, 16, 16, 32), dtype=dtype)
283+
conv1 = make_ethosu_conv2d(
284+
ifm,
285+
32,
286+
16,
287+
(1, 1),
288+
(0, 0),
289+
(1, 1),
290+
(1, 1),
291+
)
292+
scalar = relay.const(np.ones((1, 1, 1, 1), dtype=dtype), dtype=dtype)
293+
add1 = make_ethosu_binary_elementwise(
294+
conv1, scalar, ifm_channels=32, ifm2_channels=1, operator_type="ADD", ofm_dtype=dtype
295+
)
296+
func = relay.Function(relay.analysis.free_vars(add1), add1)
297+
func = run_opt_pass(func, relay.transform.InferType())
298+
return func
299+
300+
tir_mod, params = lower_to_tir(get_graph(), copy_constants())
301+
302+
# Check tile address for the scalar constant input hasn't been
303+
# overwritten.
304+
extern_calls = tir_mod["main"].body.body.body.body.body
305+
binary_elementwise = extern_calls[-1].value
306+
args = binary_elementwise.args
307+
308+
reason = "Tile address overwritten"
309+
assert args[26] == 0, reason
310+
assert args[27] == 0, reason
311+
assert args[28] == 0, reason
312+
313+
# More generally, check compiles successfully to make sure
314+
# nothing else was overwritten.
315+
tir_to_cs_translator.translate(tir_mod, params)
316+
317+
273318
if __name__ == "__main__":
274319
pytest.main([__file__])

tests/python/contrib/test_ethosu/test_legalize.py

Lines changed: 47 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -693,6 +693,53 @@ def verify(ext_func):
693693
verify(mod["tvmgen_default_ethos_u_main_0"])
694694

695695

696+
def test_binary_add_from_constant_scalar():
697+
dtype = "uint8"
698+
ifm_shape = (1, 4, 4, 8)
699+
700+
def create_graph():
701+
inp = relay.var("input", shape=ifm_shape, dtype=dtype)
702+
scalar = relay.const(np.ones((1, 1, 1, 1), dtype=dtype), dtype=dtype)
703+
add = relay.qnn.op.add(
704+
inp,
705+
scalar,
706+
relay.const(1.0, dtype="float32"),
707+
relay.const(0, dtype="int32"),
708+
relay.const(1.0, dtype="float32"),
709+
relay.const(0, dtype="int32"),
710+
relay.const(1.0, dtype="float32"),
711+
relay.const(0, dtype="int32"),
712+
)
713+
func = relay.Function(relay.analysis.free_vars(add), add)
714+
return tvm.IRModule.from_expr(func)
715+
716+
def verify(ext_func):
717+
op = ext_func.body
718+
assert list(op.args[0].checked_type.shape) == [1, 4, 4, 8]
719+
assert list(op.args[1].checked_type.shape) == [1, 1, 1, 1]
720+
assert op.args[0].checked_type.dtype == "uint8"
721+
assert list(op.checked_type.shape) == [1, 4, 4, 8]
722+
assert op.checked_type.dtype == "uint8"
723+
assert op.attrs.operator_type == "ADD"
724+
725+
rewriter = legalize.AddRewriter()
726+
pattern_table = [
727+
(
728+
ethosu.AddParams.composite_name,
729+
ethosu.qnn_add_pattern(),
730+
lambda pat: ethosu.AddParams(pat).is_valid(),
731+
),
732+
]
733+
734+
mod = create_graph()
735+
mod = partition_ethosu_by_table(mod, pattern_table)
736+
737+
mod["tvmgen_default_ethos_u_main_0"] = dataflow_pattern.rewrite(
738+
rewriter, mod["tvmgen_default_ethos_u_main_0"]
739+
)
740+
verify(mod["tvmgen_default_ethos_u_main_0"])
741+
742+
696743
@pytest.mark.parametrize(
697744
"ifm_shape, ifm2_shape, reversed_operands",
698745
[

0 commit comments

Comments
 (0)