
Commit 11d73d2

Update on "[executorch] Add coreml quant recipes"
Fixing tests for stack that got reverted: #13265

Adds coreml quant recipes after FP32/16 recipes added in #13121

Recipes added:
- PT2E_INT8_STATIC
- PT2E_INT8_WEIGHT_ONLY
- INT4_WEIGHT_ONLY_PER_CHANNEL
- INT4_WEIGHT_ONLY_PER_GROUP
- INT8_WEIGHT_ONLY_PER_CHANNEL
- INT8_WEIGHT_ONLY_PER_GROUP
- CODEBOOK_WEIGHT_ONLY

Differential Revision: [D80206542](https://our.internmc.facebook.com/intern/diff/D80206542/)

[ghstack-poisoned]
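For context, a sketch of how one of the recipes listed above might be selected at export time. Every import path, the ExportRecipe/CoreMLRecipeType names, and the export() signature below are assumptions for illustration; none of them are shown in this diff.

# Hypothetical usage sketch -- every import path and API name below is an
# assumption, not taken from this commit.
import torch
from executorch.export import export, ExportRecipe                      # assumed API
from executorch.backends.apple.coreml.recipes import CoreMLRecipeType   # assumed path

class SmallModel(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.linear = torch.nn.Linear(16, 8)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return torch.relu(self.linear(x))

model = SmallModel().eval()
example_inputs = [(torch.randn(1, 16),)]

# Pick one of the recipes named above, e.g. per-channel INT4 weight-only.
recipe = ExportRecipe.get_recipe(CoreMLRecipeType.INT4_WEIGHT_ONLY_PER_CHANNEL)  # assumed
session = export(model=model, example_inputs=example_inputs, export_recipe=recipe)  # assumed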
2 parents 2fec78b + 2165def commit 11d73d2

25 files changed (+413, -108 lines)

backends/arm/runtime/EthosUBackend.cpp

Lines changed: 3 additions & 2 deletions
@@ -192,8 +192,9 @@ class EthosUBackend final : public ::executorch::runtime::BackendInterface {
     // Use a temporary allocator for the intermediate tensors of the
     // computation. The allocator is released in runtime/executor/method.cpp at
     // the end of the execution of the Ethos-U custom delegate
-    char* ethosu_scratch =
-        static_cast<char*>(temp_allocator->allocate(handles.scratch_data_size));
+    // Ethos-U driver requires 16 bit alignment.
+    char* ethosu_scratch = static_cast<char*>(
+        temp_allocator->allocate(handles.scratch_data_size, 16UL));
     if (ethosu_scratch == nullptr) {
       ET_LOG(
           Error,
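The new second argument to allocate() asks the temp allocator for 16-aligned scratch memory. A minimal sketch (not the ExecuTorch MemoryAllocator) of the align-up arithmetic an allocator typically applies to honor such a request:

# Minimal sketch (not the ExecuTorch MemoryAllocator) of the align-up step an
# allocator typically performs when given an alignment argument such as 16.
def align_up(addr: int, alignment: int) -> int:
    # Round addr up to the next multiple of alignment (alignment must be a power of two).
    return (addr + alignment - 1) & ~(alignment - 1)

assert align_up(0x1003, 16) == 0x1010  # unaligned address gets bumped up
assert align_up(0x1010, 16) == 0x1010  # already-aligned address is unchanged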

backends/arm/test/test_arm_baremetal.sh

Lines changed: 19 additions & 14 deletions
@@ -17,7 +17,6 @@ _setup_msg="please refer to ${et_root_dir}/examples/arm/setup.sh to properly ins
 
 
 TEST_SUITE=$1
-TOSA_VERSION="${2:-TOSA-1.0+INT}"
 
 # Source the tools
 # This should be prepared by the setup.sh
@@ -157,17 +156,23 @@ test_run_ethosu_fvp() { # End to End model tests using run.sh
 
     # TOSA quantized
     echo "${TEST_SUITE_NAME}: Test ethos-u target TOSA"
-    examples/arm/run.sh --et_build_root=arm_test/test_run --target=${TOSA_VERSION} --model_name=add
-    examples/arm/run.sh --et_build_root=arm_test/test_run --target=${TOSA_VERSION} --model_name=mul
+    examples/arm/run.sh --et_build_root=arm_test/test_run --target=TOSA-1.0+INT --model_name=add
+    examples/arm/run.sh --et_build_root=arm_test/test_run --target=TOSA-1.0+INT --model_name=mul
 
     # Ethos-U55
     echo "${TEST_SUITE_NAME}: Test ethos-u target Ethos-U55"
     examples/arm/run.sh --et_build_root=arm_test/test_run --target=ethos-u55-128 --model_name=add
+    examples/arm/run.sh --et_build_root=arm_test/test_run --target=ethos-u55-128 --model_name=add --bundleio
+    examples/arm/run.sh --et_build_root=arm_test/test_run --target=ethos-u55-128 --model_name=add --bundleio --etdump
+    examples/arm/run.sh --et_build_root=arm_test/test_run --target=ethos-u55-128 --model_name=add --etdump
     examples/arm/run.sh --et_build_root=arm_test/test_run --target=ethos-u55-128 --model_name=mul
 
     # Ethos-U85
     echo "${TEST_SUITE_NAME}: Test ethos-u target Ethos-U85"
     examples/arm/run.sh --et_build_root=arm_test/test_run --target=ethos-u85-128 --model_name=add
+    examples/arm/run.sh --et_build_root=arm_test/test_run --target=ethos-u85-128 --model_name=add --bundleio
+    examples/arm/run.sh --et_build_root=arm_test/test_run --target=ethos-u85-128 --model_name=add --bundleio --etdump
+    examples/arm/run.sh --et_build_root=arm_test/test_run --target=ethos-u85-128 --model_name=add --etdump
     examples/arm/run.sh --et_build_root=arm_test/test_run --target=ethos-u85-128 --model_name=mul
 
     # Cortex-M op tests
@@ -187,17 +192,17 @@ test_models_tosa() { # End to End model tests using model_test.py
 
     # TOSA quantized
     echo "${TEST_SUITE_NAME}: Test ethos-u target TOSA"
-    python3 backends/arm/test/test_model.py --test_output=arm_test/test_model --target=${TOSA_VERSION} --model=mv2
-    python3 backends/arm/test/test_model.py --test_output=arm_test/test_model --target=${TOSA_VERSION} --model=mv3
-    python3 backends/arm/test/test_model.py --test_output=arm_test/test_model --target=${TOSA_VERSION} --model=lstm
-    python3 backends/arm/test/test_model.py --test_output=arm_test/test_model --target=${TOSA_VERSION} --model=edsr
-    # python3 backends/arm/test/test_model.py --test_output=arm_test/test_model --target=${TOSA_VERSION} --model=emformer_transcribe # Takes long time to run
-    # python3 backends/arm/test/test_model.py --test_output=arm_test/test_model --target=${TOSA_VERSION} --model=emformer_join # Takes long time to run
-    python3 backends/arm/test/test_model.py --test_output=arm_test/test_model --target=${TOSA_VERSION} --model=w2l
-    python3 backends/arm/test/test_model.py --test_output=arm_test/test_model --target=${TOSA_VERSION} --model=ic3
-    python3 backends/arm/test/test_model.py --test_output=arm_test/test_model --target=${TOSA_VERSION} --model=ic4
-    python3 backends/arm/test/test_model.py --test_output=arm_test/test_model --target=${TOSA_VERSION} --model=resnet18
-    python3 backends/arm/test/test_model.py --test_output=arm_test/test_model --target=${TOSA_VERSION} --model=resnet50
+    python3 backends/arm/test/test_model.py --test_output=arm_test/test_model --target=TOSA-1.0+INT --model=mv2
+    python3 backends/arm/test/test_model.py --test_output=arm_test/test_model --target=TOSA-1.0+INT --model=mv3
+    python3 backends/arm/test/test_model.py --test_output=arm_test/test_model --target=TOSA-1.0+INT --model=lstm
+    python3 backends/arm/test/test_model.py --test_output=arm_test/test_model --target=TOSA-1.0+INT --model=edsr
+    # python3 backends/arm/test/test_model.py --test_output=arm_test/test_model --target=TOSA-1.0+INT --model=emformer_transcribe # Takes long time to run
+    # python3 backends/arm/test/test_model.py --test_output=arm_test/test_model --target=TOSA-1.0+INT --model=emformer_join # Takes long time to run
+    python3 backends/arm/test/test_model.py --test_output=arm_test/test_model --target=TOSA-1.0+INT --model=w2l
+    python3 backends/arm/test/test_model.py --test_output=arm_test/test_model --target=TOSA-1.0+INT --model=ic3
+    python3 backends/arm/test/test_model.py --test_output=arm_test/test_model --target=TOSA-1.0+INT --model=ic4
+    python3 backends/arm/test/test_model.py --test_output=arm_test/test_model --target=TOSA-1.0+INT --model=resnet18
+    python3 backends/arm/test/test_model.py --test_output=arm_test/test_model --target=TOSA-1.0+INT --model=resnet50
 
     echo "${TEST_SUITE_NAME}: PASS"
 }

backends/cadence/aot/functions.yaml

Lines changed: 10 additions & 0 deletions
@@ -219,6 +219,16 @@
     - arg_meta: null
       kernel_name: impl::reference::quantized_relu_per_tensor_out
 
+- func: cadence::quantized_relu_asym8s_asym8s.per_tensor_out(Tensor X, int X_zero_point, int out_zero_point, int out_multiplier, int out_shift, *, Tensor(a!) out) -> Tensor(a!)
+  kernels:
+    - arg_meta: null
+      kernel_name: impl::reference::quantized_relu_asym8s_asym8s_per_tensor_out
+
+- func: cadence::quantized_relu_asym8u_asym8u.per_tensor_out(Tensor X, int X_zero_point, int out_zero_point, int out_multiplier, int out_shift, *, Tensor(a!) out) -> Tensor(a!)
+  kernels:
+    - arg_meta: null
+      kernel_name: impl::reference::quantized_relu_asym8u_asym8u_per_tensor_out
+
 - func: cadence::quantized_matmul.out(Tensor X, int X_zero_point, Tensor Y, int Y_zero_point, Tensor? bias, int out_multiplier, int out_shift, int out_zero_point, bool transposed, *, Tensor(a!) out) -> Tensor(a!)
   kernels:
     - arg_meta: null

backends/cadence/aot/functions_hifi.yaml

Lines changed: 10 additions & 0 deletions
@@ -339,6 +339,16 @@
     - arg_meta: null
       kernel_name: cadence::impl::HiFi::quantized_relu_per_tensor_out
 
+- func: cadence::quantized_relu_asym8s_asym8s.per_tensor_out(Tensor X, int X_zero_point, int out_zero_point, int out_multiplier, int out_shift, *, Tensor(a!) out) -> Tensor(a!)
+  kernels:
+    - arg_meta: null
+      kernel_name: cadence::impl::HiFi::quantized_relu_asym8s_asym8s_per_tensor_out
+
+- func: cadence::quantized_relu_asym8u_asym8u.per_tensor_out(Tensor X, int X_zero_point, int out_zero_point, int out_multiplier, int out_shift, *, Tensor(a!) out) -> Tensor(a!)
+  kernels:
+    - arg_meta: null
+      kernel_name: cadence::impl::HiFi::quantized_relu_asym8u_asym8u_per_tensor_out
+
 - func: cadence::quantized_matmul.out(Tensor X, int X_zero_point, Tensor Y, int Y_zero_point, Tensor? bias, int out_multiplier, int out_shift, int out_zero_point, bool transposed, *, Tensor(a!) out) -> Tensor(a!)
   kernels:
     - arg_meta: null

backends/cadence/aot/ops_registrations.py

Lines changed: 36 additions & 0 deletions
@@ -232,6 +232,20 @@
     "quantized_relu.per_tensor_out(Tensor X, int X_zero_point, int out_zero_point, int out_multiplier, "
     "int out_shift, *, Tensor(a!) out) -> Tensor(a!)"
 )
+lib.define(
+    "quantized_relu_asym8s_asym8s.per_tensor(Tensor X, int X_zero_point, int out_zero_point, int out_multiplier, int out_shift) -> Tensor"
+)
+lib.define(
+    "quantized_relu_asym8s_asym8s.per_tensor_out(Tensor X, int X_zero_point, int out_zero_point, int out_multiplier, "
+    "int out_shift, *, Tensor(a!) out) -> Tensor(a!)"
+)
+lib.define(
+    "quantized_relu_asym8u_asym8u.per_tensor(Tensor X, int X_zero_point, int out_zero_point, int out_multiplier, int out_shift) -> Tensor"
+)
+lib.define(
+    "quantized_relu_asym8u_asym8u.per_tensor_out(Tensor X, int X_zero_point, int out_zero_point, int out_multiplier, "
+    "int out_shift, *, Tensor(a!) out) -> Tensor(a!)"
+)
 lib.define(
     "quantized_add.out(Tensor X, Tensor X_scale, Tensor X_zero_point, Tensor Y, Tensor Y_scale, "
     "Tensor Y_zero_point, float out_scale, int out_zero_point, *, Tensor(a!) out) -> Tensor(a!)"
@@ -770,6 +784,28 @@ def quantized_relu_per_tensor_meta(
     return input.new_empty(input.size(), dtype=input.dtype)
 
 
+@register_fake("cadence::quantized_relu_asym8s_asym8s.per_tensor")
+def quantized_relu_asym8s_asym8s_per_tensor_meta(
+    input: torch.Tensor,
+    in_zero_point: int,
+    out_zero_point: int,
+    out_multiplier: int,
+    out_shift: int,
+) -> torch.Tensor:
+    return input.new_empty(input.size(), dtype=input.dtype)
+
+
+@register_fake("cadence::quantized_relu_asym8u_asym8u.per_tensor")
+def quantized_relu_asym8u_asym8u_per_tensor_meta(
+    input: torch.Tensor,
+    in_zero_point: int,
+    out_zero_point: int,
+    out_multiplier: int,
+    out_shift: int,
+) -> torch.Tensor:
+    return input.new_empty(input.size(), dtype=input.dtype)
+
+
 @register_fake("cadence::fully_connected")
 def fully_connected_meta(
     src: torch.Tensor,
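The register_fake entries above only describe the output shape and dtype so the new ops can be traced with fake tensors at export time; the real math lives in the YAML-bound kernels. A self-contained sketch of the same pattern in a throwaway "demo" namespace (the op name and signature here are made up for illustration, not part of the cadence library):

# Self-contained sketch of the register_fake pattern, using a throwaway "demo"
# namespace; the op and its signature are illustrative only.
import torch
from torch._subclasses.fake_tensor import FakeTensorMode
from torch.library import Library, register_fake

lib = Library("demo", "DEF")
lib.define("relu_like(Tensor X, int X_zero_point) -> Tensor")

@register_fake("demo::relu_like")
def relu_like_meta(x: torch.Tensor, x_zero_point: int) -> torch.Tensor:
    # Shape/dtype propagation only; no real computation happens here.
    return x.new_empty(x.size(), dtype=x.dtype)

# Export traces with fake tensors, so the fake impl alone is enough for the op
# to flow through tracing even though no real (CPU) kernel is registered.
with FakeTensorMode():
    out = torch.ops.demo.relu_like(torch.empty(2, 3, dtype=torch.int8), 0)
    print(out.shape, out.dtype)  # torch.Size([2, 3]) torch.int8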

backends/cadence/aot/replace_ops.py

Lines changed: 7 additions & 1 deletion
@@ -2327,10 +2327,16 @@ def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
             # Cast the const_arg to the dtype of the x_arg
             full_arg = self.resolve_full_arg(x_arg, const_arg)
 
+            full_output_dtype = (
+                torch.int32 if isinstance(full_arg, int) else torch.float32
+            )
+
             # Extract an argument to a separate full op.
             with graph_module.graph.inserting_before(mul_node):
                 full_node = graph_module.graph.call_function(
-                    torch.ops.aten.full.default, args=([1], full_arg)
+                    torch.ops.aten.full.default,
+                    args=([1], full_arg),
+                    kwargs={"dtype": full_output_dtype},
                 )
                 full_node.meta = mul_node.meta
                 full_node.meta["val"] = [1]
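For context on the new explicit dtype: without it, aten.full infers the output dtype from the Python fill value (an int becomes int64 on recent PyTorch), whereas the pass standardizes on int32/float32. A standalone sketch of the pinned behaviour:

# Standalone sketch of the dtype pinning added above. The int64 default for an
# int fill value reflects recent PyTorch behaviour (an assumption about the
# version in use).
import torch

print(torch.full([1], 2).dtype)    # torch.int64, inferred from the int fill value
print(torch.full([1], 2.0).dtype)  # torch.float32, the default float dtype

full_arg = 2  # stand-in for the scalar the pass extracts into a full op
full_output_dtype = torch.int32 if isinstance(full_arg, int) else torch.float32
print(torch.full([1], full_arg, dtype=full_output_dtype).dtype)  # torch.int32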

backends/cadence/aot/tests/test_type_dispatch_passes.py

Lines changed: 48 additions & 0 deletions
@@ -137,3 +137,51 @@ def test_mixed_types_error(self) -> None:
         with self.assertRaises(RuntimeError) as context:
             cast(PassResult, p(gm)).graph_module
         self.assertIn("Unsupported input types", str(context.exception))
+
+    def test_int8_dispatch_quantized_relu(self) -> None:
+        """Test int8 input should dispatch to asym8s_asym8s variant for quantized_relu"""
+        x = torch.randint(-128, 127, (2, 3), dtype=torch.int8)
+        gm = single_op_builder(
+            placeholders=(x,),
+            op=exir_ops.edge.cadence.quantized_relu.per_tensor,
+            args=(x, 0, 0, 1, 0),
+        )
+        p = CompileTimeTypeDispatchPass()
+        gm = cast(PassResult, p(gm)).graph_module
+        # Original op should be replaced
+        self.assertEqual(
+            count_node(gm, exir_ops.edge.cadence.quantized_relu.per_tensor),
+            0,
+        )
+        # Should be replaced with int8 specific variant
+        self.assertEqual(
+            count_node(
+                gm,
+                exir_ops.edge.cadence.quantized_relu_asym8s_asym8s.per_tensor,
+            ),
+            1,
+        )
+
+    def test_uint8_dispatch_quantized_relu(self) -> None:
+        """Test uint8 input should dispatch to asym8u_asym8u variant for quantized_relu"""
+        x = torch.randint(0, 255, (2, 3), dtype=torch.uint8)
+        gm = single_op_builder(
+            placeholders=(x,),
+            op=exir_ops.edge.cadence.quantized_relu.per_tensor,
+            args=(x, 0, 0, 1, 0),
+        )
+        p = CompileTimeTypeDispatchPass()
+        gm = cast(PassResult, p(gm)).graph_module
+        # Original op should be replaced
+        self.assertEqual(
+            count_node(gm, exir_ops.edge.cadence.quantized_relu.per_tensor),
+            0,
+        )
+        # Should be replaced with uint8 specific variant
+        self.assertEqual(
+            count_node(
+                gm,
+                exir_ops.edge.cadence.quantized_relu_asym8u_asym8u.per_tensor,
+            ),
+            1,
+        )

backends/cadence/aot/type_dispatch.py

Lines changed: 40 additions & 17 deletions
@@ -23,40 +23,63 @@ class CompileTimeTypeDispatchPass(ExportPass):
     Replaces generic ops with ops that have explicit types.
     """
 
-    _TYPE_DISPATCH_MAP: dict[tuple[torch.dtype, torch.dtype], str] = {
+    _BINARY_TYPE_DISPATCH_MAP: dict[tuple[torch.dtype, torch.dtype], str] = {
         (torch.int8, torch.int8): "asym8sxasym8s_asym8s",
         (torch.uint8, torch.uint8): "asym8uxasym8u_asym8u",
     }
 
-    _SUPPORTED_OPS: dict[OpOverload, str] = {
+    _UNARY_TYPE_DISPATCH_MAP: dict[torch.dtype, str] = {
+        torch.int8: "asym8s_asym8s",
+        torch.uint8: "asym8u_asym8u",
+    }
+
+    _BINARY_SUPPORTED_OPS: dict[OpOverload, str] = {
         exir_ops.edge.cadence.quantized_fully_connected.per_tensor: "quantized_fully_connected",
         exir_ops.edge.cadence.quantized_linear.per_tensor: "quantized_linear",
     }
 
+    _SUPPORTED_UNARY_OPS: dict[OpOverload, str] = {
+        exir_ops.edge.cadence.quantized_relu.per_tensor: "quantized_relu",
+    }
+
     def call_operator(
         self,
         op: OpOverload,
         args: tuple[Argument, ...],
         kwargs: dict[str, Argument],
         meta: NodeMetadata,
     ) -> ProxyValue:
-        if op not in self._SUPPORTED_OPS:
-            return super().call_operator(op, args, kwargs, meta)
+        if op in self._BINARY_SUPPORTED_OPS:
+            # pyre-ignore[16]: None has no attribute `to_tensor`.
+            input_dtype = args[0].to_tensor().dtype
+            weight_dtype = args[1].to_tensor().dtype
+            dtype_pair = (input_dtype, weight_dtype)
+
+            if dtype_pair not in self._BINARY_TYPE_DISPATCH_MAP:
+                raise RuntimeError(
+                    f"Unsupported input types for {op}: {input_dtype} and {weight_dtype}"
+                )
+
+            base_op_name = self._BINARY_SUPPORTED_OPS[op]
+            type_suffix = self._BINARY_TYPE_DISPATCH_MAP[dtype_pair]
+
+            typed_op_name = f"{base_op_name}_{type_suffix}"
+            typed_op = getattr(exir_ops.edge.cadence, typed_op_name).per_tensor
+
+            return super().call_operator(typed_op, args, kwargs, meta)
+
+        elif op in self._SUPPORTED_UNARY_OPS:
+            input_dtype = args[0].to_tensor().dtype
 
-        # pyre-ignore[16]: None has no attribute `to_tensor`.
-        input_dtype = args[0].to_tensor().dtype
-        weight_dtype = args[1].to_tensor().dtype
-        dtype_pair = (input_dtype, weight_dtype)
+            if input_dtype not in self._UNARY_TYPE_DISPATCH_MAP:
+                raise RuntimeError(f"Unsupported input type for {op}: {input_dtype}")
 
-        if dtype_pair not in self._TYPE_DISPATCH_MAP:
-            raise RuntimeError(
-                f"Unsupported input types for {op}: {input_dtype} and {weight_dtype}"
-            )
+            base_op_name = self._SUPPORTED_UNARY_OPS[op]
+            type_suffix = self._UNARY_TYPE_DISPATCH_MAP[input_dtype]
 
-        base_op_name = self._SUPPORTED_OPS[op]
-        type_suffix = self._TYPE_DISPATCH_MAP[dtype_pair]
+            typed_op_name = f"{base_op_name}_{type_suffix}"
+            typed_op = getattr(exir_ops.edge.cadence, typed_op_name).per_tensor
 
-        typed_op_name = f"{base_op_name}_{type_suffix}"
-        typed_op = getattr(exir_ops.edge.cadence, typed_op_name).per_tensor
+            return super().call_operator(typed_op, args, kwargs, meta)
 
-        return super().call_operator(typed_op, args, kwargs, meta)
+        return super().call_operator(op, args, kwargs, meta)
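Taken together with the tests above, the unary path reduces to a dtype-keyed string lookup followed by a getattr on exir_ops.edge.cadence. A standalone sketch of just that name resolution, with no graph machinery:

# Standalone sketch of the unary name resolution in CompileTimeTypeDispatchPass;
# the real pass then resolves the name via
# getattr(exir_ops.edge.cadence, typed_op_name).per_tensor on the exported graph.
import torch

_UNARY_TYPE_DISPATCH_MAP: dict[torch.dtype, str] = {
    torch.int8: "asym8s_asym8s",
    torch.uint8: "asym8u_asym8u",
}

def typed_variant(base_op_name: str, input_dtype: torch.dtype) -> str:
    if input_dtype not in _UNARY_TYPE_DISPATCH_MAP:
        raise RuntimeError(f"Unsupported input type for {base_op_name}: {input_dtype}")
    return f"{base_op_name}_{_UNARY_TYPE_DISPATCH_MAP[input_dtype]}"

assert typed_variant("quantized_relu", torch.int8) == "quantized_relu_asym8s_asym8s"
assert typed_variant("quantized_relu", torch.uint8) == "quantized_relu_asym8u_asym8u"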
Lines changed: 52 additions & 0 deletions
@@ -0,0 +1,52 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#include <executorch/backends/cadence/hifi/kernels/kernels.h>
+#include <executorch/runtime/kernel/kernel_includes.h>
+#include <xa_nnlib_kernels_api.h>
+
+namespace cadence {
+namespace impl {
+namespace HiFi {
+namespace native {
+
+using ::executorch::aten::Tensor;
+using ::executorch::runtime::KernelRuntimeContext;
+
+void quantized_relu_asym8s_asym8s_per_tensor_out(
+    KernelRuntimeContext& ctx,
+    const Tensor& input,
+    const int64_t in_zero_point,
+    const int64_t out_zero_point,
+    const int64_t out_multiplier,
+    const int64_t out_shift,
+    Tensor& output) {
+  const int8_t* __restrict__ input_data = input.const_data_ptr<int8_t>();
+  int8_t* __restrict__ output_data = output.mutable_data_ptr<int8_t>();
+
+  const int32_t out_multipler_int32 = static_cast<int32_t>(out_multiplier);
+  const int32_t out_shift_int32 = static_cast<int32_t>(out_shift);
+
+  const int32_t ret = xa_nn_vec_relu_asym8s_asym8s(
+      output_data,
+      input_data,
+      in_zero_point,
+      out_multipler_int32,
+      out_shift_int32,
+      out_zero_point,
+      -128,
+      127,
+      input.numel());
+  ET_DCHECK_MSG(
+      ret == 0, "HiFi quantized_relu_asym8s_asym8s_per_tensor failed");
+}
+
+} // namespace native
+} // namespace HiFi
+} // namespace impl
+} // namespace cadence
