diff --git a/exir/TARGETS b/exir/TARGETS index cc20f766bad..170a22f1328 100644 --- a/exir/TARGETS +++ b/exir/TARGETS @@ -144,6 +144,7 @@ python_library( "//executorch/exir/capture:lib", "//executorch/exir/emit:lib", "//executorch/exir/program:lib", + "//executorch/exir/serde:serialize", ], ) diff --git a/exir/__init__.py b/exir/__init__.py index d71a2be064b..c6a1939d357 100644 --- a/exir/__init__.py +++ b/exir/__init__.py @@ -24,6 +24,7 @@ ExirExportedProgram, to_edge, ) +from executorch.exir.serde.serialize import load, save from executorch.exir.tracer import ExirDynamoConfig from torch.export import ExportedProgram, ExportGraphSignature @@ -49,4 +50,6 @@ "ExecutorchBackendConfig", "Value", "ExirDynamoConfig", + "load", + "save", ] diff --git a/exir/serde/TARGETS b/exir/serde/TARGETS index ff3cf7999f6..fceebe7869f 100644 --- a/exir/serde/TARGETS +++ b/exir/serde/TARGETS @@ -13,7 +13,6 @@ python_library( ":schema", "//caffe2:torch", "//executorch/exir:delegate", - "//executorch/exir:lib", "//executorch/exir:lowered_backend_module", "//executorch/exir:memory", "//executorch/exir/backend:compile_spec_schema", diff --git a/exir/serde/schema.py b/exir/serde/schema.py index 494aa29de29..7536a323ee0 100644 --- a/exir/serde/schema.py +++ b/exir/serde/schema.py @@ -25,3 +25,7 @@ class LoweredBackendModule: compile_specs: List[CompileSpec] original_module: export_schema.ExportedProgram original_state_dict: str + + +# NOTE: Please update this value if any modifications are made to the schema +SCHEMA_VERSION = (1, 0) diff --git a/exir/serde/serialize.py b/exir/serde/serialize.py index 5eb28b830ce..e1521b98c09 100644 --- a/exir/serde/serialize.py +++ b/exir/serde/serialize.py @@ -9,9 +9,12 @@ import base64 import copy import dataclasses +import io import json import logging import operator +import os +import zipfile from typing import Any, Callable, Dict, List, Optional, Union import executorch.exir as exir @@ -30,9 +33,11 @@ from executorch.exir.lowered_backend_module import ( LoweredBackendModule as ExirLoweredBackendModule, ) +from executorch.exir.serde.export_serialize import SerializedArtifact from executorch.exir.serde.schema import ( CompileSpec, LoweredBackendModule as SerdeLoweredBackendModule, + SCHEMA_VERSION, ) from torch._export.serde.schema import SchemaVersion from torch._export.serde.serialize import SerializeError @@ -628,7 +633,7 @@ class ExportedProgramDeserializer(export_serialize.ExportedProgramDeserializer): def deserialize( self, serialized_artifact: export_serialize.SerializedArtifact, - ) -> exir.ExportedProgram: + ) -> ep.ExportedProgram: assert isinstance(serialized_artifact.exported_program, schema.ExportedProgram) symbol_name_to_range = { @@ -738,7 +743,7 @@ def serialize( def deserialize( artifact: export_serialize.SerializedArtifact, expected_opset_version: Optional[Dict[str, int]] = None, -) -> exir.ExportedProgram: +) -> ep.ExportedProgram: assert isinstance(artifact.exported_program, bytes) exported_program_str = artifact.exported_program.decode("utf-8") exported_program_dict = json.loads(exported_program_str) @@ -750,3 +755,96 @@ def deserialize( serialized_exported_program, artifact.state_dict, artifact.constants ) ) + + +def save( + ep_save: ep.ExportedProgram, + f: Union[str, os.PathLike, io.BytesIO], + *, + extra_files: Optional[Dict[str, Any]] = None, + opset_version: Optional[Dict[str, int]] = None, +) -> None: + if not isinstance(ep_save, ep.ExportedProgram): + raise TypeError(f"save() expects an ExportedProgram but got {type(ep)}") + + artifact: SerializedArtifact = serialize(ep_save, opset_version) + + if isinstance(f, (str, os.PathLike)): + f = os.fspath(f) + + with zipfile.ZipFile(f, "w") as zipf: + # Save every field in the SerializedArtifact to a file. + assert isinstance(artifact.exported_program, bytes) + zipf.writestr("serialized_exported_program.json", artifact.exported_program) + zipf.writestr("serialized_state_dict.pt", artifact.state_dict) + zipf.writestr("serialized_constants.pt", artifact.constants) + + zipf.writestr("version", ".".join(map(str, SCHEMA_VERSION))) + + # Add extra files if provided + if extra_files: + for extra_file_name, content in extra_files.items(): + encoded_content = content.encode("utf-8") + zipf.writestr(f"extra_files/{extra_file_name}", encoded_content) + + +def load( + f: Union[str, os.PathLike, io.BytesIO], + *, + extra_files: Optional[Dict[str, Any]] = None, + expected_opset_version: Optional[Dict[str, int]] = None, +) -> ep.ExportedProgram: + if isinstance(f, (str, os.PathLike)): + f = os.fspath(f) + + extra_files = extra_files or {} + + with zipfile.ZipFile(f, "r") as zipf: + # Check the version + version = zipf.read("version").decode().split(".") + + assert len(version) == len(SCHEMA_VERSION) + if version[0] != str(SCHEMA_VERSION[0]): + raise RuntimeError( + f"Serialized version {version} does not match our current " + f"schema version {SCHEMA_VERSION}." + ) + + # Load serialized_ep and serialized_state_dict from the zip file + + serialized_exported_program: Optional[bytes] = None + serialized_state_dict: Optional[bytes] = None + serialized_constants: Optional[bytes] = None + + for file_info in zipf.infolist(): + file_content = zipf.read(file_info.filename) + + if file_info.filename == "serialized_exported_program.json": + serialized_exported_program = file_content + elif file_info.filename == "serialized_state_dict.json": + print("This version of file is deprecated") + serialized_state_dict = file_content + elif file_info.filename == "serialized_constants.json": + print("This version of file is deprecated") + serialized_constants = file_content + elif file_info.filename == "serialized_state_dict.pt": + serialized_state_dict = file_content + elif file_info.filename == "serialized_constants.pt": + serialized_constants = file_content + elif file_info.filename.startswith("extra_files"): + filename = file_info.filename.split("/", 1)[1] + extra_files[filename] = file_content.decode("utf-8") + + assert serialized_exported_program is not None + assert serialized_state_dict is not None + assert serialized_constants is not None + artifact: SerializedArtifact = SerializedArtifact( + serialized_exported_program, + serialized_state_dict, + serialized_constants, + ) + + # Deserialize ExportedProgram + ep = deserialize(artifact, expected_opset_version) + + return ep diff --git a/exir/tests/test_serde.py b/exir/tests/test_serde.py index d4be4686590..1ada5169479 100644 --- a/exir/tests/test_serde.py +++ b/exir/tests/test_serde.py @@ -6,6 +6,7 @@ # pyre-strict +import io import unittest from typing import Tuple @@ -47,7 +48,7 @@ def check_ep( self.assertTrue(torch.allclose(orig, loaded)) # pyre-ignore - def check_serde(self, m, inputs) -> None: + def check_serde(self, m, inputs, check_executorch=True) -> None: aten = export(m, inputs) aten_new = deserialize(serialize(aten)) self.check_ep(aten, aten_new, inputs) @@ -56,10 +57,23 @@ def check_serde(self, m, inputs) -> None: edge_new = deserialize(serialize(edge.exported_program())) self.check_ep(edge.exported_program(), edge_new, inputs) + buffer = io.BytesIO() + exir.save(edge.exported_program(), buffer) + buffer.seek(0) + loaded_ep = exir.load(buffer) + self.check_ep(edge.exported_program(), loaded_ep, inputs) + executorch = edge.to_executorch().exported_program() executorch_new = deserialize(serialize(executorch)) - with torch.no_grad(): - self.check_ep(executorch, executorch_new, inputs) + if check_executorch: + with torch.no_grad(): + self.check_ep(executorch, executorch_new, inputs) + + buffer = io.BytesIO() + exir.save(executorch, buffer) + buffer.seek(0) + loaded_ep = exir.load(buffer) + self.check_ep(executorch, loaded_ep, inputs) def test_basic(self) -> None: class MyModule(torch.nn.Module): @@ -88,7 +102,12 @@ def get_random_inputs(self): model = MyModel() inputs = model.get_random_inputs() - self.check_serde(model, inputs) + # We set check_executorch to false for this test because this triggers + # an edge case where calling .module() on the executorch exported program + # will cause an unlift pass to be run on the graph and dead code elimination + # will be subsequently run, which essentially causes the split_copy op to be + # removed. + self.check_serde(model, inputs, check_executorch=False) def test_to_out_variant_multiple_out(self) -> None: class MyModel(torch.nn.Module):