diff --git a/scripts/benchmark_import_torchao.py b/scripts/benchmark_import_torchao.py new file mode 100644 index 0000000000..bc112302e2 --- /dev/null +++ b/scripts/benchmark_import_torchao.py @@ -0,0 +1,73 @@ +# This script measures cold import time for torchao in two modes: +# - eager: simulates pre-change behavior by importing heavy submodules explicitly +# - lazy: imports only top-level and accesses nothing +# It runs multiple trials in fresh subprocesses for stable results. + +import json +import os +import subprocess +import sys +from statistics import mean, stdev + +PY = sys.executable +REPO_ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) + +EAGER_SNIPPET = r""" +import time, importlib +start=time.time() +import torchao +# simulate old eager behavior: import heavy submodules and attributes +import torchao.quantization # heavy +# access re-exported attributes +from torchao.quantization import autoquant, quantize_ +end=time.time() +print(round(end-start,3)) +""" + +LAZY_SNIPPET = r""" +import time, importlib +start=time.time() +import torchao +end=time.time() +print(round(end-start,3)) +""" + + +def run_trial(code: str) -> float: + proc = subprocess.run( + [PY, "-c", code], capture_output=True, text=True, cwd=REPO_ROOT + ) + if proc.returncode != 0: + raise RuntimeError(proc.stderr) + return float(proc.stdout.strip().splitlines()[-1]) + + +def run_many(code: str, n: int = 7): + times = [run_trial(code) for _ in range(n)] + return { + "times": times, + "mean": round(mean(times), 4), + "stdev": round(stdev(times), 4) if len(times) > 1 else 0.0, + } + + +def main(): + res_lazy = run_many(LAZY_SNIPPET) + res_eager = run_many(EAGER_SNIPPET) + speedup = None + if res_eager["mean"] > 0: + speedup = round(res_eager["mean"] / res_lazy["mean"], 2) + print( + json.dumps( + { + "lazy": res_lazy, + "eager_simulated": res_eager, + "speedup_x": speedup, + }, + indent=2, + ) + ) + + +if __name__ == "__main__": + main() diff --git a/test/core/test_lazy_import.py b/test/core/test_lazy_import.py new file mode 100644 index 0000000000..989dfd832f --- /dev/null +++ b/test/core/test_lazy_import.py @@ -0,0 +1,54 @@ +import importlib +import sys + +import pytest + + +# Skip tests if PyTorch is not installed; torchao requires torch at import time +pytest.importorskip("torch", reason="Requires PyTorch for torchao import") + + +def _cleanup_torchao_modules(): + # Remove torchao and its submodules from sys.modules to force a fresh import + for name in list(sys.modules.keys()): + if name == "torchao" or name.startswith("torchao."): + sys.modules.pop(name, None) + + +def test_top_level_import_is_lightweight(): + _cleanup_torchao_modules() + importlib.invalidate_caches() + + import torchao # noqa: F401 - just to import + + # Heavy submodules should not be imported eagerly + assert "torchao.quantization" not in sys.modules + assert "torchao.experimental" not in sys.modules + assert "torchao.experimental.op_lib" not in sys.modules + + +def test_accessing_lazy_attrs_triggers_import(): + _cleanup_torchao_modules() + importlib.invalidate_caches() + + import torchao + + # Accessing lazy attributes should import their source module + _ = torchao.autoquant # triggers torchao.quantization import + assert "torchao.quantization" in sys.modules + assert hasattr(torchao, "autoquant") + + _ = torchao.quantize_ # already imported module should provide symbol + assert hasattr(torchao, "quantize_") + + +def test_lazy_submodule_resolution(): + _cleanup_torchao_modules() + importlib.invalidate_caches() + + import torchao + + # Accessing a lazy submodule should import it on demand + _ = torchao.dtypes + assert "torchao.dtypes" in sys.modules + diff --git a/torchao/__init__.py b/torchao/__init__.py index 3a25a72114..6a70c8ccec 100644 --- a/torchao/__init__.py +++ b/torchao/__init__.py @@ -3,6 +3,8 @@ # torch/nested/_internal/nested_tensor.py:417: UserWarning: Failed to initialize NumPy: No module named 'numpy' import warnings +import importlib +import sys import torch warnings.filterwarnings( @@ -51,25 +53,73 @@ torch.ops.load_library(str(file)) from . import ops + # The following library contains CPU kernels from torchao/experimental + # They are built automatically by ao/setup.py if on an ARM machine. + # They can also be built outside of the torchao install process by + # running the script `torchao/experimental/build_torchao_ops.sh ` + # For more information, see https://github.com/pytorch/ao/blob/main/torchao/experimental/docs/readme.md + # Avoid eagerly importing experimental op_lib as it is heavy and not always needed. + # Users can trigger it by importing `torchao.experimental` or setting up kernels explicitly. # The following registers meta kernels for some CPU kernels from torchao.csrc_meta_ops import * # noqa: F403 except Exception as e: logger.debug(f"Skipping import of cpp extensions: {e}") -from torchao.quantization import ( - autoquant, - quantize_, -) +# Lazy submodule and attribute exposure to reduce import-time overhead +_LAZY_SUBMODULES = { + "dtypes": "torchao.dtypes", + "optim": "torchao.optim", + "quantization": "torchao.quantization", + "swizzle": "torchao.swizzle", + "testing": "torchao.testing", + "ops": "torchao.ops", + "kernel": "torchao.kernel", + "float8": "torchao.float8", + "sparsity": "torchao.sparsity", + "prototype": "torchao.prototype", + "experimental": "torchao.experimental", + "_models": "torchao._models", + "core": "torchao.core", +} -from . import dtypes, optim, quantization, swizzle, testing +_LAZY_ATTRS = { + # Top-level convenience re-exports + "autoquant": ("torchao.quantization", "autoquant"), + "quantize_": ("torchao.quantization", "quantize_"), +} __all__ = [ + # Submodules "dtypes", - "autoquant", "optim", - "quantize_", + "quantization", "swizzle", "testing", "ops", - "quantization", + "kernel", + "float8", + "sparsity", + "prototype", + "experimental", + "_models", + "core", + # Attributes + "autoquant", + "quantize_", ] + +def __getattr__(name): + if name in _LAZY_SUBMODULES: + module = importlib.import_module(_LAZY_SUBMODULES[name]) + setattr(sys.modules[__name__], name, module) + return module + if name in _LAZY_ATTRS: + mod_name, attr_name = _LAZY_ATTRS[name] + module = importlib.import_module(mod_name) + value = getattr(module, attr_name) + setattr(sys.modules[__name__], name, value) + return value + raise AttributeError(f"module '{__name__}' has no attribute '{name}'") + +def __dir__(): + return sorted(set(globals().keys()) | set(__all__))