Skip to content

Commit 0e19a91

Browse files
author
Ashutosh Parkhi
committed
[CMSIS-NN] code generator for softmax
Change-Id: Ie248f82c67240885be880ddbb2b547bbe61b97fb
1 parent ee03a12 commit 0e19a91

File tree

12 files changed

+542
-6
lines changed

12 files changed

+542
-6
lines changed

CMakeLists.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@ tvm_option(HIDE_PRIVATE_SYMBOLS "Compile with -fvisibility=hidden." OFF)
4747
tvm_option(USE_TF_TVMDSOOP "Build with TensorFlow TVMDSOOp" OFF)
4848
tvm_option(USE_FALLBACK_STL_MAP "Use TVM's POD compatible Map" OFF)
4949
tvm_option(USE_ETHOSN "Build with Arm Ethos-N" OFF)
50+
tvm_option(USE_CMSISNN "Build with Arm CMSIS-NN" OFF)
5051
tvm_option(INDEX_DEFAULT_I64 "Defaults the index datatype to int64" ON)
5152
tvm_option(USE_LIBBACKTRACE "Build libbacktrace to supply linenumbers on stack traces" AUTO)
5253
tvm_option(BUILD_STATIC_RUNTIME "Build static version of libtvm_runtime" OFF)
@@ -390,6 +391,7 @@ include(cmake/modules/ROCM.cmake)
390391
include(cmake/modules/LLVM.cmake)
391392
include(cmake/modules/Micro.cmake)
392393
include(cmake/modules/contrib/EthosN.cmake)
394+
include(cmake/modules/contrib/CMSISNN.cmake)
393395
include(cmake/modules/contrib/BLAS.cmake)
394396
include(cmake/modules/contrib/CODEGENC.cmake)
395397
include(cmake/modules/contrib/DNNL.cmake)
Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
# Licensed to the Apache Software Foundation (ASF) under one
2+
# or more contributor license agreements. See the NOTICE file
3+
# distributed with this work for additional information
4+
# regarding copyright ownership. The ASF licenses this file
5+
# to you under the Apache License, Version 2.0 (the
6+
# "License"); you may not use this file except in compliance
7+
# with the License. You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing,
12+
# software distributed under the License is distributed on an
13+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
# KIND, either express or implied. See the License for the
15+
# specific language governing permissions and limitations
16+
# under the License.
17+
18+
if(USE_CMSISNN)
19+
message(STATUS "Build with CMSIS-NN support")
20+
file(GLOB RELAY_CONTRIB_CMSISNN_SRCS src/relay/backend/contrib/cmsisnn/*.cc)
21+
list(APPEND COMPILER_SRCS ${RELAY_CONTRIB_CMSISNN_SRCS})
22+
endif(USE_CMSISNN)

python/tvm/driver/tvmc/composite_target.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424

2525
from tvm.relay.op.contrib.arm_compute_lib import partition_for_arm_compute_lib
2626
from tvm.relay.op.contrib.ethosn import partition_for_ethosn
27+
from tvm.relay.op.contrib.cmsisnn import partition_for_cmsisnn
2728
from tvm.relay.op.contrib.bnns import partition_for_bnns
2829
from tvm.relay.op.contrib.vitis_ai import partition_for_vitis_ai
2930

@@ -49,6 +50,10 @@
4950
"config_key": None,
5051
"pass_pipeline": partition_for_arm_compute_lib,
5152
},
53+
"cmsis-nn": {
54+
"config_key": None,
55+
"pass_pipeline": partition_for_cmsisnn,
56+
},
5257
"ethos-n77": {
5358
"config_key": "relay.ext.ethos-n.options",
5459
"pass_pipeline": partition_for_ethosn,

python/tvm/relay/backend/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,3 +16,4 @@
1616
# under the License.
1717
"""Backend codegen modules for relay."""
1818
from . import compile_engine
19+
from .contrib import cmsisnn
Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
# Licensed to the Apache Software Foundation (ASF) under one
2+
# or more contributor license agreements. See the NOTICE file
3+
# distributed with this work for additional information
4+
# regarding copyright ownership. The ASF licenses this file
5+
# to you under the Apache License, Version 2.0 (the
6+
# "License"); you may not use this file except in compliance
7+
# with the License. You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing,
12+
# software distributed under the License is distributed on an
13+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
# KIND, either express or implied. See the License for the
15+
# specific language governing permissions and limitations
16+
# under the License.
17+
"""external backend codegen modules for relay."""
18+
from . import cmsisnn
Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
# Licensed to the Apache Software Foundation (ASF) under one
2+
# or more contributor license agreements. See the NOTICE file
3+
# distributed with this work for additional information
4+
# regarding copyright ownership. The ASF licenses this file
5+
# to you under the Apache License, Version 2.0 (the
6+
# "License"); you may not use this file except in compliance
7+
# with the License. You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing,
12+
# software distributed under the License is distributed on an
13+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
# KIND, either express or implied. See the License for the
15+
# specific language governing permissions and limitations
16+
# under the License.
17+
"""CMSIS-NN codegen modules for relay."""
18+
from . import codegen
Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
# Licensed to the Apache Software Foundation (ASF) under one
2+
# or more contributor license agreements. See the NOTICE file
3+
# distributed with this work for additional information
4+
# regarding copyright ownership. The ASF licenses this file
5+
# to you under the Apache License, Version 2.0 (the
6+
# "License"); you may not use this file except in compliance
7+
# with the License. You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing,
12+
# software distributed under the License is distributed on an
13+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
# KIND, either express or implied. See the License for the
15+
# specific language governing permissions and limitations
16+
# under the License.
17+
"""FFI APIs for CMSIS-NN relay transformation passes."""
18+
import tvm._ffi
19+
20+
tvm._ffi._init_api("relay.ext.cmsisnn", __name__)
Lines changed: 140 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,140 @@
1+
# Licensed to the Apache Software Foundation (ASF) under one
2+
# or more contributor license agreements. See the NOTICE file
3+
# distributed with this work for additional information
4+
# regarding copyright ownership. The ASF licenses this file
5+
# to you under the Apache License, Version 2.0 (the
6+
# "License"); you may not use this file except in compliance
7+
# with the License. You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing,
12+
# software distributed under the License is distributed on an
13+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
# KIND, either express or implied. See the License for the
15+
# specific language governing permissions and limitations
16+
# under the License.
17+
"""Codegen for CMSIS-NN"""
18+
import tvm
19+
from tvm import relay
20+
from tvm.relay.expr_functor import ExprVisitor
21+
22+
23+
def generate_tir(name, func):
24+
"""Generates TIR"""
25+
26+
class GenerateTIR(ExprVisitor):
27+
"""Generates TIR module containing TIR primfuncs corresponding to the Relay operators.
28+
Note: Relay operator to primfunc mapping may not be 1:1.
29+
"""
30+
31+
def __init__(self, name):
32+
super().__init__()
33+
self.name = name
34+
self.tir_mod = None
35+
self.scale = 1.0 / 256
36+
37+
def call_contains_op(self, call, op_name):
38+
if not isinstance(call.op, tvm.ir.op.Op):
39+
return False
40+
if call.op.name != op_name:
41+
return False
42+
return True
43+
44+
def is_quantized_softmax(self, call):
45+
"""Checks for the following relay sequence
46+
a = qnn.dequantize(in, scale, zero_point)
47+
b = nn.softmax(a)
48+
c = qnn.quantize(c, scale, zero_point)
49+
"""
50+
if not self.call_contains_op(call, "qnn.quantize"):
51+
return False
52+
softmax_call = call.args[0]
53+
if not self.call_contains_op(softmax_call, "nn.softmax"):
54+
return False
55+
dequantize_call = softmax_call.args[0]
56+
if not self.call_contains_op(dequantize_call, "qnn.dequantize"):
57+
return False
58+
if not call.attrs.out_dtype == "int8":
59+
return False
60+
self.scale = dequantize_call.args[1].data.numpy().item(0)
61+
return True
62+
63+
def emit_softmax_tir(self, call):
64+
"""Generates TIR extern_call for softmax"""
65+
shape = call.checked_type.shape # NHWC
66+
dtype = call.checked_type.dtype
67+
ir_builder = tvm.tir.ir_builder.create()
68+
in_buf = tvm.tir.decl_buffer(shape=shape, dtype=dtype)
69+
out_buf = tvm.tir.decl_buffer(shape=shape, dtype=dtype)
70+
num_rows = shape[0] * shape[1] * shape[2]
71+
row_size = shape[3]
72+
ir_builder.emit(
73+
tvm.tir.call_extern(
74+
dtype,
75+
"arm_softmax_s8",
76+
in_buf.data,
77+
num_rows,
78+
row_size,
79+
self.scale,
80+
out_buf.data,
81+
)
82+
)
83+
prim_func = tvm.tir.PrimFunc([in_buf, out_buf], ir_builder.get())
84+
prim_func = prim_func.with_attr("global_symbol", self.name)
85+
prim_func = prim_func.with_attr("tir.noalias", True)
86+
self.tir_mod = tvm.IRModule({self.name: prim_func})
87+
88+
def visit_call(self, call):
89+
"""Iterates over the relay operators within relay external function"""
90+
super().visit_call(call)
91+
if self.is_quantized_softmax(call):
92+
self.emit_softmax_tir(call)
93+
94+
def generate_tir(self, func):
95+
self.visit(func)
96+
return self.tir_mod
97+
98+
tir_mod = GenerateTIR(name).generate_tir(func)
99+
return tir_mod
100+
101+
102+
def relay_to_tir(name, func):
103+
"""Lower a Relay function to TIR for the CMSIS-NN target.
104+
105+
The Relay function should only contain operations supported
106+
by the CMSIS-NN target. This is enforced by the graph partitioner
107+
for CMSIS-NN.
108+
109+
Parameters
110+
----------
111+
name: str
112+
Name of the external relay function
113+
func : tvm.relay.Function
114+
The Relay function to lower.
115+
116+
Returns
117+
-------
118+
mod : tvm.IRModule
119+
The lowered TIR module.
120+
121+
"""
122+
tir_mod = generate_tir(name, func)
123+
return tir_mod
124+
125+
126+
@tvm._ffi.register_func("relay.ext.cmsisnn")
127+
def cmsisnn_compiler(relay_func):
128+
"""It compiles Relay's external function into equivalent TIR
129+
and subsequently converts that into 'c' code. During the 'c'
130+
code generation, it embeds CMSIS-NN APIs for the corresponding
131+
operators.
132+
"""
133+
assert isinstance(relay_func, tvm.ir.function.BaseFunc)
134+
mod = tvm.IRModule()
135+
mod["main"] = relay_func
136+
mod = relay.transform.InferType()(mod)
137+
func_name = relay_func.attrs["global_symbol"]
138+
tir_mod = relay_to_tir(func_name, mod["main"])
139+
cmsisnn_runtime = tvm._ffi.get_global_func("runtime.module.cmsisnn.create")
140+
return cmsisnn_runtime(tir_mod)

python/tvm/relay/op/contrib/cmsisnn.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -67,11 +67,15 @@ def softmax_pattern():
6767

6868
def check_quantized_softmax(extract):
6969
"""Check if softmax is supported by CMSIS-NN."""
70+
dequantize_call = extract.args[0].args[0]
71+
scale = dequantize_call.args[1].data.numpy().item(0)
72+
zero_point = dequantize_call.args[2].data.numpy().item(0)
7073

7174
# check for dtypes of quantize and dequantize
7275
return (
73-
extract.attrs.out_dtype == "int8"
74-
and extract.args[0].args[0].args[0].checked_type.dtype == "int8"
76+
(scale == 1.0 / 256 and zero_point == -128)
77+
and extract.attrs.out_dtype == "int8"
78+
and dequantize_call.args[0].checked_type.dtype == "int8"
7579
)
7680

7781
return [

0 commit comments

Comments
 (0)