Skip to content

Commit 7b9fd1e

Browse files
authored
[microNPU] Removing constant args from PrimFunc (#9951)
Before this commit, the microNPU codegen created the PrimFunc as if it accepted constants from its caller. This commit removes the constants from the PrimFunc's arguments, as they are not provided by the main function.
1 parent e6af874 commit 7b9fd1e

File tree

11 files changed

+460
-410
lines changed

11 files changed

+460
-410
lines changed

python/tvm/relay/backend/contrib/ethosu/codegen.py

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -309,8 +309,8 @@ def relay_to_tir_func(ext_func: relay.Function) -> tvm.tir.PrimFunc:
309309
# scratch memory size.
310310
tir_mod, const_dict = lower_to_tir(mod["main"], copy_constants())
311311

312-
for idx in const_dict.keys():
313-
const_dict[idx] = tvm.nd.array(const_dict[idx])
312+
for param in const_dict.keys():
313+
const_dict[param] = tvm.nd.array(const_dict[param])
314314

315315
primfunc = tir_mod["main"]
316316
primfunc = primfunc.with_attr("global_symbol", ext_func.attrs["global_symbol"])
@@ -341,11 +341,9 @@ def primfunc_to_artifact(primfunc: tvm.tir.PrimFunc) -> util.CompilationArtifact
341341
tir_mod = tvm.IRModule()
342342
tir_mod[symbol] = primfunc
343343

344-
const_dict_with_int_keys = dict()
345-
for idx in const_dict.keys():
346-
const_dict_with_int_keys[int(idx)] = const_dict[idx].numpy()
344+
const_dict_np = dict()
345+
for buffer_var in const_dict.keys():
346+
const_dict_np[buffer_var] = const_dict[buffer_var].numpy()
347347

348-
cmms, encoded_constants, base_addresses = tir_to_cs_translator.translate(
349-
tir_mod, const_dict_with_int_keys
350-
)
348+
cmms, encoded_constants, base_addresses = tir_to_cs_translator.translate(tir_mod, const_dict_np)
351349
return util.CompilationArtifact(symbol, cmms, encoded_constants, base_addresses)

python/tvm/relay/backend/contrib/ethosu/tir/compiler.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,7 @@ def lower_ethosu(sch, args, const_dict, name="main"):
9090
mod = tvm.tir.transform.StorageRewrite()(mod)
9191
mod = tvm.tir.transform.RemoveNoOp()(mod)
9292
mod = ethosu_passes.AnnotateAllocates()(mod)
93+
mod, const_dict = ethosu_passes.CreatePrimFuncWithoutConstants(const_dict)(mod)
9394
return mod, const_dict
9495

9596

python/tvm/relay/backend/contrib/ethosu/tir/passes.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -687,3 +687,40 @@ def _ftransform(f, mod, ctx):
687687
return tvm.tir.transform.prim_func_pass(
688688
_ftransform, opt_level=0, name="tir.ethosu.remove_concatenates"
689689
)
690+
691+
692+
def CreatePrimFuncWithoutConstants(const_dict):
    """
    Build a transform that strips constant arguments from PrimFuncs.

    Every PrimFunc parameter whose positional index is a key of
    ``const_dict`` is dropped from the parameter list and from the
    buffer map. This is a stop-gap; the constants should be replaced
    properly with tir.allocate_const when it becomes available.

    Because the positional indices stop being meaningful once those
    parameters are gone, the constants are re-keyed by the data Var
    of the buffer bound to each removed parameter.

    Parameters
    ----------
    const_dict : dict
        Maps PrimFunc parameter index --> constant value.

    Returns
    -------
    callable
        A function that takes the module, runs the pass over it and
        returns a tuple of the transformed module and the re-keyed
        constant dictionary.
    """

    rekeyed_consts = {}

    def _ftransform(f, mod, ctx):
        # Record each constant under the buffer var of its parameter
        # before that parameter disappears from the signature.
        for const_idx in const_dict:
            const_param = f.params[const_idx]
            rekeyed_consts[f.buffer_map[const_param].data] = const_dict[const_idx]

        # Keep only the parameters that do not stand for a constant.
        kept_params = []
        kept_buffer_map = {}
        for idx, param in enumerate(f.params):
            if idx in const_dict:
                continue
            kept_params.append(param)
            kept_buffer_map[param] = f.buffer_map[param]
        return tvm.tir.PrimFunc(kept_params, f.body, f.ret_type, kept_buffer_map, f.attrs, f.span)

    def _create_primfunc_without_constants(mod):
        strip_constants = tvm.tir.transform.prim_func_pass(
            _ftransform, opt_level=0, name="tir.ethosu.CreatePrimFuncWithoutConstants"
        )
        return strip_constants(mod), rekeyed_consts

    return _create_primfunc_without_constants

python/tvm/relay/backend/contrib/ethosu/tir_to_cs_translator.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -208,7 +208,7 @@ def extract_buffer_info(
208208
----------
209209
mod : tvm.IRModule
210210
The NPU TIR IRModule.
211-
param_dict : Dict[int, np.ndarray]
211+
param_dict : Dict[tvm.tir.Var, np.ndarray]
212212
A dictionary containing buffer var --> const numpy.NDArray
213213
214214
Returns
@@ -222,8 +222,7 @@ def extract_buffer_info(
222222
assert len(mod.functions.items()) == 1
223223
primfunc = mod.functions.items()[0][1]
224224

225-
for idx, const_data in param_dict.items():
226-
param = primfunc.params[idx]
225+
for param, const_data in param_dict.items():
227226
buffer_info[param] = BufferInfo(
228227
const_data, const_data.shape, const_data.dtype, BufferType.constant
229228
)
@@ -257,7 +256,6 @@ def populate_allocate_buffer_info(stmt):
257256
)
258257

259258
tvm.tir.stmt_functor.post_order_visit(primfunc.body, populate_allocate_buffer_info)
260-
261259
return buffer_info
262260

263261

tests/python/contrib/test_ethosu/test_compiler.py

Lines changed: 37 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -20,27 +20,46 @@
2020
import tvm
2121
from tvm import relay
2222
from tvm.relay.backend.contrib.ethosu.tir.compiler import lower_to_tir
23+
from . import infra
2324

2425

25-
def test_lower_to_tir():
26-
data = relay.var("data", shape=(1, 1, 1, 1024), dtype="uint8")
27-
weight = relay.var("weight", shape=(1, 1, 1024, 1001), dtype="int8")
28-
p2 = relay.var("p2", shape=(1, 1, 1, 1), dtype="int32")
29-
conv = relay.nn.conv2d(
30-
data,
31-
weight,
32-
kernel_size=(1, 1),
33-
data_layout="NHWC",
34-
kernel_layout="HWIO",
35-
out_dtype="int32",
36-
)
37-
tile = relay.tile(p2, reps=(1, 1, 1, 1001))
38-
subtract = relay.subtract(conv, tile)
39-
func = subtract
40-
expr = relay.Function(relay.analysis.free_vars(func), func)
41-
mod = tvm.IRModule.from_expr(expr)
26+
def _create_single_conv2d():
    """Return a Relay function wrapping one ethosu conv2d applied to int8 input "x"."""
    input_var = relay.var("x", shape=(1, 8, 8, 4), dtype="int8")
    conv = infra.make_ethosu_conv2d(input_var, 4, 4, (3, 3), (1, 1), (1, 1), (1, 1))
    return relay.Function(relay.analysis.free_vars(conv), conv)
31+
32+
33+
def _create_double_conv2d():
    """Return a Relay function chaining two ethosu conv2d ops on int8 input "x"."""
    input_var = relay.var("x", shape=(1, 8, 8, 4), dtype="int8")
    first_conv = infra.make_ethosu_conv2d(input_var, 4, 4, (3, 3), (1, 1), (1, 1), (1, 1))
    # The second convolution consumes the output of the first.
    second_conv = infra.make_ethosu_conv2d(first_conv, 4, 7, (2, 2), (1, 1), (1, 1), (1, 1))
    return relay.Function(relay.analysis.free_vars(second_conv), second_conv)
39+
40+
41+
def _create_non_linear_conv2d():
    """Return a Relay function that ADDs two conv2d branches fed by inputs "x" and "y"."""
    ifm_shape = (1, 8, 8, 4)
    lhs_input = relay.var("x", shape=ifm_shape, dtype="int8")
    rhs_input = relay.var("y", shape=ifm_shape, dtype="int8")
    lhs_conv = infra.make_ethosu_conv2d(lhs_input, 4, 4, (3, 3), (1, 1), (1, 1), (1, 1))
    rhs_conv = infra.make_ethosu_conv2d(rhs_input, 4, 4, (3, 3), (1, 1), (1, 1), (1, 1))
    # Both elementwise operands take the last entry of the input shape.
    last_dim = ifm_shape[3]
    summed = infra.make_ethosu_binary_elementwise(
        lhs_conv, rhs_conv, last_dim, last_dim, "ADD", "int8"
    )
    return relay.Function(relay.analysis.free_vars(summed), summed)
50+
51+
52+
@pytest.mark.parametrize(
    "relay_function, arg_count",
    [
        (_create_single_conv2d, 2),
        (_create_double_conv2d, 2),
        (_create_non_linear_conv2d, 3),
    ],
)
def test_lower_to_tir_arg_count(relay_function, arg_count):
    """Check that the lowered "main" PrimFunc has the expected number of params."""
    relay_mod = tvm.IRModule()
    relay_mod["main"] = relay_function()
    relay_mod = relay.transform.InferType()(relay_mod)
    # lower_to_tir returns more than the module; only its first element is needed.
    lowered = lower_to_tir(relay_mod["main"])
    main_primfunc = lowered[0]["main"]
    assert len(main_primfunc.params) == arg_count
4463

4564

4665
if __name__ == "__main__":

0 commit comments

Comments
 (0)