33 changes: 33 additions & 0 deletions backends/nxp/quantizer/utils.py
@@ -14,11 +14,13 @@
import torch
from torch import fx
from torch._ops import OpOverload
from torch.export import ExportedProgram
from torch.fx.passes.utils.source_matcher_utils import (
    check_subgraphs_connected,
    SourcePartition,
)
from torchao.quantization.pt2e import ObserverOrFakeQuantize
from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e
from torchao.quantization.pt2e.quantizer.quantizer import Q_ANNOTATION_KEY


@@ -149,3 +151,34 @@ def find_sequential_partitions_aten(
        if _partitions_sequential(candidate):
            fused_partitions.append(candidate)
    return fused_partitions


def post_training_quantize(
    model: ExportedProgram | fx.GraphModule,
    calibration_inputs: list[tuple[torch.Tensor, ...]],
[Collaborator] Will the call from aot_neutron_compile work? The type hint changed.

    quantizer=None,
) -> fx.GraphModule:
    """Quantize the provided model.

    :param model: Aten model (or its GraphModule representation) to quantize.
    :param calibration_inputs: A list of calibration input tuples, where each tuple element
        corresponds to one model input.
    :param quantizer: Optional quantizer to use; defaults to the NXP NeutronQuantizer.

    :return: Quantized GraphModule.
    """

    if isinstance(model, ExportedProgram):
        model = model.module()

    if not quantizer:
        from executorch.backends.nxp.quantizer.neutron_quantizer import NeutronQuantizer
[Collaborator] Is it necessary to conditionally import? It could be the default parameter value.


        quantizer = NeutronQuantizer()

    # Standard PT2E flow: insert observers, calibrate on each input tuple, convert.
    m = prepare_pt2e(model, quantizer)
    for data in calibration_inputs:
        m(*data)
    m = convert_pt2e(m)

    return m
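
For context, a minimal usage sketch of the new helper (the model, shapes, and calibration data below are illustrative, not part of this PR). Note the argument order differs from the removed `_quantize_model`: calibration inputs come second, the quantizer comes last and is optional, and an `ExportedProgram` can be passed directly since the helper calls `.module()` itself.

```python
import torch

from executorch.backends.nxp.quantizer.utils import post_training_quantize

# Illustrative eager model and example input.
model = torch.nn.Sequential(torch.nn.Conv2d(3, 8, 3), torch.nn.ReLU()).eval()
example_input = (torch.randn(1, 3, 32, 32),)

# Export to an ATen-level ExportedProgram.
exported = torch.export.export(model, example_input, strict=True)

# Random calibration data: one tuple per sample, one tensor per model input.
calibration_inputs = [(torch.randn(1, 3, 32, 32),) for _ in range(10)]

# With no quantizer argument, the helper falls back to the NeutronQuantizer.
quantized_module = post_training_quantize(exported, calibration_inputs)
```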
19 changes: 4 additions & 15 deletions backends/nxp/tests/executorch_pipeline.py
@@ -18,6 +18,7 @@
from executorch.backends.nxp.neutron_partitioner import NeutronPartitioner
from executorch.backends.nxp.nxp_backend import generate_neutron_compile_spec
from executorch.backends.nxp.quantizer.neutron_quantizer import NeutronQuantizer
+from executorch.backends.nxp.quantizer.utils import post_training_quantize
from executorch.exir import (
    EdgeCompileConfig,
    EdgeProgramManager,
@@ -26,7 +27,6 @@
)
from executorch.extension.export_util.utils import export_to_edge
from torch import nn
-from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e


@dataclass
@@ -35,17 +35,6 @@ class ModelInputSpec:
    dtype: torch.dtype = torch.float32


-def _quantize_model(
-    model, quantizer, calibration_inputs: list[tuple[torch.Tensor, ...]]
-):
-    m = prepare_pt2e(model, quantizer)
-    for data in calibration_inputs:
-        m(*data)
-    m = convert_pt2e(m)
-
-    return m


def get_random_calibration_inputs(
    input_spec: tuple[ModelInputSpec, ...]
) -> list[tuple[torch.Tensor, ...]]:
@@ -99,10 +88,10 @@ def to_quantized_edge_program(

    exir_program_aten = torch.export.export(model, example_input, strict=True)

-    exir_program_aten__module_quant = _quantize_model(
-        exir_program_aten.module(),
-        get_quantizer_fn(),
+    exir_program_aten__module_quant = post_training_quantize(
+        exir_program_aten,
        calibration_inputs,
+        get_quantizer_fn(),
    )

    edge_compile_config = EdgeCompileConfig(_check_ir_validity=False)
6 changes: 3 additions & 3 deletions backends/nxp/tests/test_removing_dead_code.py
@@ -10,7 +10,7 @@
import torch

from executorch.backends.nxp.quantizer.neutron_quantizer import NeutronQuantizer
-from executorch.backends.nxp.tests.executorch_pipeline import _quantize_model
+from executorch.backends.nxp.quantizer.utils import post_training_quantize
from executorch.backends.nxp.tests.executors import graph_contains_any_of_ops


@@ -52,8 +52,8 @@ def test_removing_dead_code(self):

        # The `NeutronQuantizer` should remove the dead code in the `transform_for_annotation()` method.
        quantizer = NeutronQuantizer()
-        exir_program_aten_quant = _quantize_model(
-            exir_program_aten.module(), quantizer, [example_inputs]
+        exir_program_aten_quant = post_training_quantize(
+            exir_program_aten, [example_inputs], quantizer
        )

        # Make sure there is no `add` operation in the graph anymore.
6 changes: 3 additions & 3 deletions backends/nxp/tests/test_split_group_convolution.py
@@ -18,8 +18,8 @@
from executorch.backends.nxp.neutron_partitioner import NeutronPartitioner
from executorch.backends.nxp.nxp_backend import generate_neutron_compile_spec
from executorch.backends.nxp.quantizer.neutron_quantizer import NeutronQuantizer
+from executorch.backends.nxp.quantizer.utils import post_training_quantize
from executorch.backends.nxp.tests.executorch_pipeline import (
-    _quantize_model,
get_random_calibration_inputs,
to_model_input_spec,
)
@@ -42,8 +42,8 @@ def _quantize_and_lower_module(
    calibration_inputs = get_random_calibration_inputs(to_model_input_spec(input_shape))
    quantizer = NeutronQuantizer()

-    exir_program_aten__module_quant = _quantize_model(
-        module, quantizer, calibration_inputs
+    exir_program_aten__module_quant = post_training_quantize(
+        module, calibration_inputs, quantizer
    )

    edge_compile_config = EdgeCompileConfig(_check_ir_validity=False)
14 changes: 14 additions & 0 deletions docs/source/backends-nxp.md
@@ -56,6 +56,8 @@ List of Aten operators supported by Neutron quantizer:
`reshape`, `view`, `softmax.int`, `sigmoid`, `tanh`, `tanh_`

#### Example

To quantize your model, you can either use the PT2E workflow:
```python
import torch
from executorch.backends.nxp.quantizer.neutron_quantizer import NeutronQuantizer
@@ -73,6 +75,18 @@ for data in calibration_inputs:
m = convert_pt2e(m)
```

Or you can use the predefined post-training quantization function provided by the NXP backend:
```python
from executorch.backends.nxp.quantizer.utils import post_training_quantize

...

quantized_graph_module = post_training_quantize(
    aten_model,
    calibration_inputs,
)
```
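
The `calibration_inputs` argument is a list of input tuples, one tensor per model input. A minimal sketch of building random calibration data (the shape is illustrative):

```python
import torch

# One tuple per calibration sample; each tuple element corresponds to one model input.
calibration_inputs = [(torch.randn(1, 3, 224, 224),) for _ in range(100)]
```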

## Runtime Integration

To learn how to run the converted model on NXP hardware, use one of the ExecuTorch runtime example projects from the MCUXpresso IDE example projects list.
39 changes: 1 addition & 38 deletions examples/nxp/aot_neutron_compile.py
@@ -9,7 +9,6 @@
import io
import logging
from collections import defaultdict
-from typing import Iterator

import executorch.extension.pybindings.portable_lib
import executorch.kernels.quantized # noqa F401
@@ -20,7 +19,7 @@
)
from executorch.backends.nxp.neutron_partitioner import NeutronPartitioner
from executorch.backends.nxp.nxp_backend import generate_neutron_compile_spec
-from executorch.backends.nxp.quantizer.neutron_quantizer import NeutronQuantizer
+from executorch.backends.nxp.quantizer.utils import post_training_quantize
from executorch.examples.models import MODEL_NAME_TO_MODEL
from executorch.examples.models.model_factory import EagerModelFactory
from executorch.exir import (
@@ -30,7 +29,6 @@
)
from executorch.extension.export_util import save_pte_program
from torch.export import export
-from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e

from .experimental.cifar_net.cifar_net import CifarNet, test_cifarnet_model
from .models.mobilenet_v2 import MobilenetV2
@@ -102,41 +100,6 @@ def get_model_and_inputs_from_name(model_name: str):
}


-def post_training_quantize(
-    model, calibration_inputs: tuple[torch.Tensor] | Iterator[tuple[torch.Tensor]]
-):
-    """Quantize the provided model.
-
-    :param model: Aten model to quantize.
-    :param calibration_inputs: Either a tuple of calibration input tensors where each element corresponds to a model
-        input. Or an iterator over such tuples.
-    """
-    # Based on executorch.examples.arm.aot_arm_compiler.quantize
-    logging.info("Quantizing model")
-    logging.debug(f"---> Original model: {model}")
-    quantizer = NeutronQuantizer()
-
-    m = prepare_pt2e(model, quantizer)
-    # Calibration:
-    logging.debug("Calibrating model")
-
-    def _get_batch_size(data):
-        return data[0].shape[0]
-
-    if not isinstance(
-        calibration_inputs, tuple
-    ): # Assumption that calibration_inputs is finite.
-        for i, data in enumerate(calibration_inputs):
-            if i % (1000 // _get_batch_size(data)) == 0:
-                logging.debug(f"{i * _get_batch_size(data)} calibration inputs done")
-            m(*data)
-    else:
-        m(*calibration_inputs)
-    m = convert_pt2e(m)
-    logging.debug(f"---> Quantized model: {m}")
-    return m


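The example's local `post_training_quantize` above is removed in favor of the shared helper imported from `executorch.backends.nxp.quantizer.utils`. The updated call site is outside the visible hunks; a hedged sketch of what it presumably looks like (variable names assumed). Note the shared helper types `calibration_inputs` as a list of tuples, whereas the removed version also accepted an iterator; the collaborator comment in `utils.py` flags exactly this.

```python
# Assumed replacement call site; the actual updated lines are not shown in this diff.
quantized_module = post_training_quantize(exported_program, calibration_inputs)
```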
if __name__ == "__main__": # noqa C901
    parser = argparse.ArgumentParser()
    parser.add_argument(