Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
73 changes: 73 additions & 0 deletions scripts/benchmark_import_torchao.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
# This script measures cold import time for torchao in two modes:
# - eager: simulates pre-change behavior by importing heavy submodules explicitly
# - lazy: imports only top-level and accesses nothing
# It runs multiple trials in fresh subprocesses for stable results.

import json
import os
import subprocess
import sys
from statistics import mean, stdev

# Interpreter used to launch every trial subprocess.
PY = sys.executable
# Directory two levels above this file; trials run with it as their cwd.
REPO_ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))

# Both snippets print the elapsed seconds as the last line of stdout.
# They time with the performance counter: it is monotonic and
# high-resolution, whereas the wall clock may jump if the system time is
# adjusted while a trial is running.
EAGER_SNIPPET = r"""
import time
start = time.perf_counter()
import torchao
# simulate old eager behavior: import heavy submodules and attributes
import torchao.quantization  # heavy
# access re-exported attributes
from torchao.quantization import autoquant, quantize_
end = time.perf_counter()
print(round(end - start, 3))
"""

LAZY_SNIPPET = r"""
import time
start = time.perf_counter()
import torchao
end = time.perf_counter()
print(round(end - start, 3))
"""


def run_trial(code: str) -> float:
    """Run ``code`` in a fresh interpreter and return the time it reports.

    The snippet is executed as ``PY -c code`` with ``REPO_ROOT`` as the
    working directory. The last line of its stdout is parsed as a float.

    Args:
        code: Python source to execute in the subprocess.

    Returns:
        The value printed on the last stdout line, as a float.

    Raises:
        RuntimeError: if the subprocess exits with a non-zero status, prints
            nothing, or prints a last line that is not a number.
    """
    proc = subprocess.run(
        [PY, "-c", code], capture_output=True, text=True, cwd=REPO_ROOT
    )
    if proc.returncode != 0:
        raise RuntimeError(proc.stderr)

    lines = proc.stdout.strip().splitlines()
    if not lines:
        # Without this guard an empty stdout surfaces as a bare IndexError.
        raise RuntimeError(
            f"trial produced no output on stdout; stderr was: {proc.stderr!r}"
        )
    try:
        return float(lines[-1])
    except ValueError as err:
        raise RuntimeError(
            f"trial did not print a number on its last line: {lines[-1]!r}"
        ) from err


def run_many(code: str, n: int = 7):
    """Run ``code`` ``n`` times via ``run_trial`` and summarize the timings.

    Returns:
        A dict with the raw ``"times"`` list, their ``"mean"`` and their
        sample ``"stdev"``, both rounded to 4 decimals. The standard
        deviation is reported as 0.0 when there are fewer than two samples.
    """
    samples = []
    for _ in range(n):
        samples.append(run_trial(code))

    average = round(mean(samples), 4)
    if len(samples) > 1:
        spread = round(stdev(samples), 4)
    else:
        # stdev() needs at least two data points.
        spread = 0.0

    return {"times": samples, "mean": average, "stdev": spread}


def main():
    """Benchmark both import modes and print a JSON report to stdout.

    The report has the lazy results, the simulated-eager results, and
    ``speedup_x`` (eager mean divided by lazy mean, rounded to 2 decimals).
    ``speedup_x`` is ``None`` when either mean is not positive.
    """
    res_lazy = run_many(LAZY_SNIPPET)
    res_eager = run_many(EAGER_SNIPPET)

    speedup = None
    # The lazy mean is the divisor, so it is the value that must be non-zero;
    # the means are rounded, so a very fast import can legitimately be 0.0.
    if res_eager["mean"] > 0 and res_lazy["mean"] > 0:
        speedup = round(res_eager["mean"] / res_lazy["mean"], 2)

    print(
        json.dumps(
            {
                "lazy": res_lazy,
                "eager_simulated": res_eager,
                "speedup_x": speedup,
            },
            indent=2,
        )
    )


if __name__ == "__main__":
    main()
54 changes: 54 additions & 0 deletions test/core/test_lazy_import.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
import importlib
import sys

import pytest


# Skip tests if PyTorch is not installed; torchao requires torch at import time
pytest.importorskip("torch", reason="Requires PyTorch for torchao import")


def _cleanup_torchao_modules():
    """Drop ``torchao`` and all ``torchao.*`` entries from ``sys.modules``.

    The next ``import torchao`` then executes the package from scratch.
    """
    # Snapshot the keys first: sys.modules must not change size mid-iteration.
    stale = [
        mod_name
        for mod_name in tuple(sys.modules)
        if mod_name == "torchao" or mod_name.startswith("torchao.")
    ]
    for mod_name in stale:
        sys.modules.pop(mod_name, None)


def test_top_level_import_is_lightweight():
    """A bare ``import torchao`` must not load the heavy submodules."""
    # Start with no torchao module cached so the import really executes.
    _cleanup_torchao_modules()
    importlib.invalidate_caches()

    import torchao  # noqa: F401 - just to import

    heavy_modules = (
        "torchao.quantization",
        "torchao.experimental",
        "torchao.experimental.op_lib",
    )
    for heavy in heavy_modules:
        assert heavy not in sys.modules


def test_accessing_lazy_attrs_triggers_import():
    """Reading a lazy top-level attribute imports the module that owns it."""
    # Start with no torchao module cached so the import really executes.
    _cleanup_torchao_modules()
    importlib.invalidate_caches()

    import torchao

    # First access: expected to import torchao.quantization.
    _ = torchao.autoquant
    quantization_loaded = "torchao.quantization" in sys.modules
    assert quantization_loaded
    assert hasattr(torchao, "autoquant")

    # Second symbol: served from the module that is already imported.
    _ = torchao.quantize_
    assert hasattr(torchao, "quantize_")


def test_lazy_submodule_resolution():
    """Reading a lazy submodule attribute imports that submodule on demand."""
    # Start with no torchao module cached so the import really executes.
    _cleanup_torchao_modules()
    importlib.invalidate_caches()

    import torchao

    submodule = "dtypes"
    _ = getattr(torchao, submodule)
    assert f"torchao.{submodule}" in sys.modules


66 changes: 58 additions & 8 deletions torchao/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
# torch/nested/_internal/nested_tensor.py:417: UserWarning: Failed to initialize NumPy: No module named 'numpy'
import warnings

import importlib
import sys
import torch

warnings.filterwarnings(
Expand Down Expand Up @@ -51,25 +53,73 @@
torch.ops.load_library(str(file))
from . import ops

# The following library contains CPU kernels from torchao/experimental
# They are built automatically by ao/setup.py if on an ARM machine.
# They can also be built outside of the torchao install process by
# running the script `torchao/experimental/build_torchao_ops.sh <aten|executorch>`
# For more information, see https://github.com/pytorch/ao/blob/main/torchao/experimental/docs/readme.md
# Avoid eagerly importing experimental op_lib as it is heavy and not always needed.
# Users can trigger it by importing `torchao.experimental` or setting up kernels explicitly.
# The following registers meta kernels for some CPU kernels
from torchao.csrc_meta_ops import * # noqa: F403
except Exception as e:
logger.debug(f"Skipping import of cpp extensions: {e}")

from torchao.quantization import (
autoquant,
quantize_,
)
# Lazy submodule and attribute exposure to reduce import-time overhead
# Maps each lazily exposed submodule attribute to its fully qualified
# module name. Every entry is simply "torchao.<name>".
_LAZY_SUBMODULES = {
    short_name: f"torchao.{short_name}"
    for short_name in (
        "dtypes",
        "optim",
        "quantization",
        "swizzle",
        "testing",
        "ops",
        "kernel",
        "float8",
        "sparsity",
        "prototype",
        "experimental",
        "_models",
        "core",
    )
}

from . import dtypes, optim, quantization, swizzle, testing
# Top-level convenience re-exports: exported name -> (module, attribute).
# Both currently come from torchao.quantization under the same name.
_LAZY_ATTRS = {
    exported: ("torchao.quantization", exported)
    for exported in ("autoquant", "quantize_")
}

# Public names of the package. Each name appears exactly once.
__all__ = [
    # Submodules
    "dtypes",
    "optim",
    "quantization",
    "swizzle",
    "testing",
    "ops",
    "kernel",
    "float8",
    "sparsity",
    "prototype",
    "experimental",
    "_models",
    "core",
    # Attributes
    "autoquant",
    "quantize_",
]

def __getattr__(name):
    """Resolve lazily exposed submodules and re-exported attributes.

    This is the module-level ``__getattr__`` hook (PEP 562), so it runs only
    when ``name`` is not already present in the package namespace.

    Args:
        name: The attribute being looked up on the package.

    Returns:
        The imported submodule for names in ``_LAZY_SUBMODULES``, or the
        attribute fetched from its source module for names in
        ``_LAZY_ATTRS``.

    Raises:
        AttributeError: if ``name`` is not a lazily exposed name.
    """
    if name in _LAZY_SUBMODULES:
        value = importlib.import_module(_LAZY_SUBMODULES[name])
    elif name in _LAZY_ATTRS:
        mod_name, attr_name = _LAZY_ATTRS[name]
        value = getattr(importlib.import_module(mod_name), attr_name)
    else:
        raise AttributeError(f"module '{__name__}' has no attribute '{name}'")

    # Cache in this module's own namespace so later lookups bypass the hook.
    # globals() is used instead of sys.modules[__name__] because the latter
    # can be missing or point to a different module object if the package
    # was removed from, or replaced in, sys.modules.
    globals()[name] = value
    return value

def __dir__():
    """Return the sorted union of the module globals and ``__all__``.

    Names listed in ``__all__`` are included even before they have been
    resolved into the module namespace.
    """
    names = set(__all__)
    names.update(globals())
    return sorted(names)
Loading