diff --git a/exir/emit/_emit_program.py b/exir/emit/_emit_program.py index fc3e446af9c..6b545e0a7d3 100644 --- a/exir/emit/_emit_program.py +++ b/exir/emit/_emit_program.py @@ -8,7 +8,6 @@ from dataclasses import dataclass from typing import Any, Dict, List, Optional, Union -import executorch.extension.pytree as ex_pytree import torch import torch.fx from executorch.exir.emit._emitter import ( @@ -18,89 +17,12 @@ _TopLevelEmitter, ) from executorch.exir.error import ExportError, ExportErrorType -from executorch.exir.schema import ( - Bool, - Chain, - ContainerMetadata, - Double, - EValue, - ExecutionPlan, - Int, - Program, - String, - SubsegmentOffsets, -) -from executorch.exir.tensor import layout_enum, scalar_type_enum +from executorch.exir.schema import Program, SubsegmentOffsets from executorch.exir.version import EXECUTORCH_SCHEMA_VERSION from torch.export.exported_program import ExportedProgram, OutputKind from torch.utils import _pytree as pytree -def _emit_prim_getters(prim_getters: Dict[str, Any]) -> List[ExecutionPlan]: - """ - Given a mapping of function names to return values, emit simple execution - plans that just return these constant values. - - Precondition: All the values are primitives (bool, float, int, str, enum) - or structures (list, dict) of them. - """ - plans = [] - # flatten any structures - for method, vals in prim_getters.items(): - # pyre-fixme[16]: Module `pytree` has no attribute `tree_flatten`. - flattened_output, spec = ex_pytree.tree_flatten(vals) - spec = spec.to_str() - chain = Chain( - inputs=[], - outputs=[], - instructions=[], - stacktrace=None, - ) - - # switch on type of prim - values = [] - for val in flattened_output: - if isinstance(val, float): - values.append(EValue(Double(val))) - - elif isinstance(val, bool): - values.append(EValue(Bool(val))) - - elif isinstance(val, int): - values.append(EValue(Int(val))) - - elif isinstance(val, str): - values.append(EValue(String(val))) - - elif isinstance(val, torch.dtype): - values.append(EValue(Int(scalar_type_enum(val)))) - - elif isinstance(val, torch.layout): - values.append(EValue(Int(layout_enum(val)))) - - else: - raise ExportError( - ExportErrorType.NOT_SUPPORTED, - f"Error emitting {method} which returns a value of type {type(val)}. which is not a supported primitive", - ) - - # add to plans - plans.append( - ExecutionPlan( - name=method, - values=values, - inputs=[], - outputs=list(range(0, len(values))), - chains=[chain], - operators=[], - delegates=[], - non_const_buffer_sizes=[0, 0], - container_meta_type=ContainerMetadata("", spec), - ) - ) - return plans - - @dataclass class EmitterOutput: """ @@ -220,7 +142,7 @@ def emit_program( # emit any primitive getters if prim_getters is not None: - plans.extend(_emit_prim_getters(prim_getters)) + plans.extend(emitter._emit_prim_getters(prim_getters)) return EmitterOutput( debug_handle_map=debug_handle_map, diff --git a/exir/emit/_emitter.py b/exir/emit/_emitter.py index ee5fa6af4ee..e581789cb31 100644 --- a/exir/emit/_emitter.py +++ b/exir/emit/_emitter.py @@ -1173,6 +1173,77 @@ def _emit_free(self, spec: TensorSpec) -> _AbstractValue: # The value is not used but the caller expects an AbstractValue returned. return _AbstractValue(None, None) # pyre-ignore + def _emit_prim_getters(self, prim_getters: Dict[str, Any]) -> List[ExecutionPlan]: + """ + Given a mapping of function names to return values, emit simple execution + plans that just return these constant values. + + Precondition: All the values are primitives (bool, float, int, str, enum) + or structures (list, dict) of them. + """ + plans = [] + # flatten any structures + for method, vals in prim_getters.items(): + # pyre-fixme[16]: Module `pytree` has no attribute `tree_flatten`. + flattened_output, spec = ex_pytree.tree_flatten(vals) + spec = spec.to_str() + chain = Chain( + inputs=[], + outputs=[], + instructions=[], + stacktrace=None, + ) + + # switch on type of prim + values = [] + for val in flattened_output: + if isinstance(val, float): + values.append(EValue(Double(val))) + + elif isinstance(val, bool): + values.append(EValue(Bool(val))) + + elif isinstance(val, int): + values.append(EValue(Int(val))) + + elif isinstance(val, str): + values.append(EValue(String(val))) + + elif isinstance(val, torch.dtype): + values.append(EValue(Int(scalar_type_enum(val)))) + + elif isinstance(val, torch.layout): + values.append(EValue(Int(layout_enum(val)))) + + elif isinstance(val, torch.Tensor): + values.append( + self._tensor_spec_to_evalue( + TensorSpec.from_tensor(val, const=True) + ) + ) + + else: + raise ExportError( + ExportErrorType.NOT_SUPPORTED, + f"Error emitting {method} which returns a value of type {type(val)}. which is not a supported primitive", + ) + + # add to plans + plans.append( + ExecutionPlan( + name=method, + values=values, + inputs=[], + outputs=list(range(0, len(values))), + chains=[chain], + operators=[], + delegates=[], + non_const_buffer_sizes=[0], + container_meta_type=ContainerMetadata("", spec), + ) + ) + return plans + def fetch_attr(self, target: _Target) -> _AbstractValue: """Fetch weights and other module parameters. If the attribute is a tensor, emit it.""" attr = super().fetch_attr(target) diff --git a/exir/emit/test/test_emit.py b/exir/emit/test/test_emit.py index b55fb5e5dae..23481a07aaf 100644 --- a/exir/emit/test/test_emit.py +++ b/exir/emit/test/test_emit.py @@ -1065,6 +1065,9 @@ def forward(self, k: torch.Tensor) -> torch.Tensor: self.check_tensor_buffer_loc(1, execution_plan.values, 0, 1, 48) def test_emit_prims(self) -> None: + tensor_output = torch.rand(1, 4) + tensor_list_output = [torch.rand(1, 4), torch.rand(1, 4)] + class Simple(torch.nn.Module): def __init__(self) -> None: super().__init__() @@ -1078,6 +1081,12 @@ def get_ints(self) -> Tuple[int]: def get_str(self) -> str: return "foo" + def get_tensor(self) -> torch.Tensor: + return tensor_output + + def get_tensor_list(self) -> List[torch.Tensor]: + return tensor_list_output + def forward(self, x: torch.Tensor) -> torch.Tensor: return torch.nn.functional.sigmoid(self.linear(x)) @@ -1090,9 +1099,12 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: getters = {} getters["get_ints"] = model.get_ints() getters["get_str"] = model.get_str() - print(getters["get_str"]) + getters["get_tensor"] = model.get_tensor() + getters["get_tensor_list"] = model.get_tensor_list() + merged_program = emit_program(exir_input, False, getters).program - self.assertEqual(len(merged_program.execution_plan), 3) + + self.assertEqual(len(merged_program.execution_plan), 5) self.assertEqual( merged_program.execution_plan[0].name, @@ -1106,6 +1118,15 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: merged_program.execution_plan[2].name, "get_str", ) + self.assertEqual( + merged_program.execution_plan[3].name, + "get_tensor", + ) + self.assertEqual( + merged_program.execution_plan[4].name, + "get_tensor_list", + ) + # no instructions in a getter self.assertEqual( len(merged_program.execution_plan[1].chains[0].instructions), @@ -1141,6 +1162,17 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: merged_program.execution_plan[2].values[0].val.string_val, "foo", ) + self.assertEqual(len(merged_program.execution_plan[3].outputs), 1) + self.assertEqual(len(merged_program.execution_plan[4].outputs), 2) + + merged_program = to_edge( + export(model, inputs), constant_methods=getters + ).to_executorch() + executorch_module = _load_for_executorch_from_buffer(merged_program.buffer) + torch.allclose(executorch_module.run_method("get_tensor", [])[0], tensor_output) + model_output = executorch_module.run_method("get_tensor_list", []) + for i in range(len(tensor_list_output)): + torch.allclose(model_output[i], tensor_list_output[i]) def test_emit_debug_handle_map(self) -> None: mul_model = Mul()