4 changes: 4 additions & 0 deletions packaging/env_var_script_linux.sh
@@ -17,3 +17,7 @@ TORCH_CUDA_ARCH_LIST="8.0;8.6"
if [[ ${CU_VERSION:-} == "cu124" ]]; then
TORCH_CUDA_ARCH_LIST="${TORCH_CUDA_ARCH_LIST};9.0"
fi

# Ensure pip does not use PEP 517 build isolation so that pre-installed
# tools from pre_build_script.sh (setuptools, wheel) are visible to the build.
export PIP_NO_BUILD_ISOLATION=1
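With build isolation off, pip builds against the active environment, so the tools installed by pre_build_script.sh must already be importable there. A minimal sanity-check sketch in Python, assuming setuptools and wheel are the only required build tools:

import importlib.util

# PIP_NO_BUILD_ISOLATION=1 means pip will not create an isolated build
# environment, so the build backend's requirements must be pre-installed.
for mod in ("setuptools", "wheel"):
    if importlib.util.find_spec(mod) is None:
        raise SystemExit(f"{mod} must be pre-installed when build isolation is off")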
21 changes: 12 additions & 9 deletions packaging/post_build_script.sh
@@ -10,11 +10,13 @@ set -eux
# Prepare manywheel, only for CUDA.
# The wheel is a pure python wheel for other platforms.
if [[ "$CU_VERSION" == cu* ]]; then
WHEEL_NAME=$(ls dist/)

pushd dist
# Determine the original wheel produced by the build (there should be exactly one)
ORIG_WHEEL=$(ls -1 *.whl | head -n 1)
manylinux_plat=manylinux_2_28_x86_64
auditwheel repair --plat "$manylinux_plat" -w . \
# Only run auditwheel if the wheel contains at least one shared object (.so)
if unzip -l "$ORIG_WHEEL" | awk '{print $4}' | grep -E '\.so($|\.)' >/dev/null 2>&1; then
auditwheel repair --plat "$manylinux_plat" -w . \
--exclude libtorch.so \
--exclude libtorch_python.so \
--exclude libtorch_cuda.so \
@@ -23,15 +25,16 @@ if [[ "$CU_VERSION" == cu* ]]; then
--exclude libc10_cuda.so \
--exclude libcuda.so.* \
--exclude libcudart.so.* \
"${WHEEL_NAME}"
"${ORIG_WHEEL}"
else
echo "No shared libraries detected in wheel ${ORIG_WHEEL}; skipping auditwheel."
fi

ls -lah .
# Clean up the linux_x86_64 wheel
rm "${WHEEL_NAME}"
popd
fi

MANYWHEEL_NAME=$(ls dist/)
# Try to install the new wheel
pip install "dist/${MANYWHEEL_NAME}"
INSTALL_WHEEL=$(ls -1t dist/*.whl | head -n 1)
# Try to install the new wheel (pick the most recent wheel file)
pip install "${INSTALL_WHEEL}"
python -c "import torchao"
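For reference, the same shared-object detection expressed in Python rather than unzip/awk/grep; a sketch assuming the usual libfoo.so / libfoo.so.N member naming:

import re
import zipfile

_SO_RE = re.compile(r"\.so(\.|$)")  # matches libfoo.so and libfoo.so.1

def wheel_has_shared_objects(wheel_path: str) -> bool:
    # Same intent as the pipeline above: does any member of the wheel
    # look like a native shared library?
    with zipfile.ZipFile(wheel_path) as wf:
        return any(_SO_RE.search(name) for name in wf.namelist())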
220 changes: 134 additions & 86 deletions setup.py
@@ -13,6 +13,7 @@
from typing import List, Optional

from setuptools import Extension, find_packages, setup
from setuptools.command.build_ext import build_ext as _setuptools_build_ext

current_date = datetime.now().strftime("%Y%m%d")

@@ -98,16 +99,7 @@ def use_debug_mode():
return os.getenv("DEBUG", "0") == "1"


import torch
from torch.utils.cpp_extension import (
CUDA_HOME,
IS_WINDOWS,
ROCM_HOME,
BuildExtension,
CppExtension,
CUDAExtension,
_get_cuda_arch_flags,
)
# Heavy imports (torch, torch.utils.cpp_extension) are deferred to build time


class BuildOptions:
@@ -139,6 +131,8 @@ def __init__(self):
"TORCHAO_BUILD_EXPERIMENTAL_MPS", default=False
)
if self.build_experimental_mps:
import torch # Lazy import

assert is_macos, "TORCHAO_BUILD_EXPERIMENTAL_MPS requires macOS"
assert is_arm64, "TORCHAO_BUILD_EXPERIMENTAL_MPS requires arm64"
assert torch.mps.is_available(), (
@@ -260,10 +254,23 @@ def get_cuda_version_from_nvcc():
return None


def is_nvcc_available():
"""Check if nvcc is available on the system."""
try:
subprocess.check_output(["nvcc", "--version"], stderr=subprocess.STDOUT)
return True
except (OSError, subprocess.CalledProcessError):
return False
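A subprocess-free alternative to the probe above, shown only as a sketch (the PR keeps the nvcc --version probe):

import shutil

def is_nvcc_on_path() -> bool:
    # Hypothetical variant: treat nvcc as available if it resolves on PATH.
    return shutil.which("nvcc") is not None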


def get_cutlass_build_flags():
"""Determine which CUTLASS kernels to build based on CUDA version.
SM90a: CUDA 12.6+, SM100a: CUDA 12.8+
"""
# Lazy import torch and helper; only needed when building CUDA extensions
import torch
from torch.utils.cpp_extension import _get_cuda_arch_flags

# Try nvcc then torch version
cuda_version = get_cuda_version_from_nvcc() or torch.version.cuda

@@ -290,64 +297,77 @@ def get_cutlass_build_flags():
)
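The gating in get_cutlass_build_flags reduces to two version comparisons; a minimal sketch, assuming the CUDA version has already been parsed into a (major, minor) tuple:

def cutlass_targets(cuda_version):
    # SM90a kernels need CUDA >= 12.6, SM100a kernels need CUDA >= 12.8.
    return {
        "build_sm90a": cuda_version >= (12, 6),
        "build_sm100a": cuda_version >= (12, 8),
    }

assert cutlass_targets((12, 7)) == {"build_sm90a": True, "build_sm100a": False}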


# BuildExtension is a subclass of setuptools.command.build_ext.build_ext
class TorchAOBuildExt(BuildExtension):
def __init__(self, *args, **kwargs) -> None:
super().__init__(*args, **kwargs)

def build_extensions(self):
cmake_extensions = [
ext for ext in self.extensions if isinstance(ext, CMakeExtension)
]
other_extensions = [
ext for ext in self.extensions if not isinstance(ext, CMakeExtension)
]
for ext in cmake_extensions:
self.build_cmake(ext)

# Use BuildExtension to build other extensions
self.extensions = other_extensions
super().build_extensions()

self.extensions = other_extensions + cmake_extensions

def build_cmake(self, ext):
extdir = os.path.abspath(os.path.dirname(self.get_ext_fullpath(ext.name)))

if not os.path.exists(self.build_temp):
os.makedirs(self.build_temp)

# Get the expected extension file name that Python will look for
# We force CMake to use this library name
ext_filename = os.path.basename(self.get_ext_filename(ext.name))
ext_basename = os.path.splitext(ext_filename)[0]

print(
"CMAKE COMMANG",
[
"cmake",
ext.cmake_lists_dir,
]
+ ext.cmake_args
+ [
"-DCMAKE_LIBRARY_OUTPUT_DIRECTORY=" + extdir,
"-DTORCHAO_CMAKE_EXT_SO_NAME=" + ext_basename,
],
)
class LazyTorchAOBuildExt(_setuptools_build_ext):
def run(self):
# Import heavy torch build only when actually running build_ext
from torch.utils.cpp_extension import BuildExtension as _BuildExtension

class _TorchAOBuildExt(_BuildExtension):
def run(self_inner):
if os.getenv("USE_CPP", "1") != "0":
check_submodules()
if not self_inner.distribution.ext_modules:
self_inner.distribution.ext_modules = get_extensions()
super(_TorchAOBuildExt, self_inner).run()

def build_extensions(self_inner):
cmake_extensions = [
ext
for ext in self_inner.extensions
if isinstance(ext, CMakeExtension)
]
other_extensions = [
ext
for ext in self_inner.extensions
if not isinstance(ext, CMakeExtension)
]
for ext in cmake_extensions:
self_inner.build_cmake(ext)

self_inner.extensions = other_extensions
super(_TorchAOBuildExt, self_inner).build_extensions()
self_inner.extensions = other_extensions + cmake_extensions

def build_cmake(self_inner, ext):
extdir = os.path.abspath(
os.path.dirname(self_inner.get_ext_fullpath(ext.name))
)
if not os.path.exists(self_inner.build_temp):
os.makedirs(self_inner.build_temp)
ext_filename = os.path.basename(self_inner.get_ext_filename(ext.name))
ext_basename = os.path.splitext(ext_filename)[0]
if os.getenv("VERBOSE_BUILD", "0") == "1" or use_debug_mode():
print(
"CMAKE COMMAND",
[
"cmake",
ext.cmake_lists_dir,
]
+ ext.cmake_args
+ [
"-DCMAKE_LIBRARY_OUTPUT_DIRECTORY=" + extdir,
"-DTORCHAO_CMAKE_EXT_SO_NAME=" + ext_basename,
],
)
subprocess.check_call(
[
"cmake",
ext.cmake_lists_dir,
]
+ ext.cmake_args
+ [
"-DCMAKE_LIBRARY_OUTPUT_DIRECTORY=" + extdir,
"-DTORCHAO_CMAKE_EXT_SO_NAME=" + ext_basename,
],
cwd=self_inner.build_temp,
)
subprocess.check_call(
["cmake", "--build", "."], cwd=self_inner.build_temp
)

subprocess.check_call(
[
"cmake",
ext.cmake_lists_dir,
]
+ ext.cmake_args
+ [
"-DCMAKE_LIBRARY_OUTPUT_DIRECTORY=" + extdir,
"-DTORCHAO_CMAKE_EXT_SO_NAME=" + ext_basename,
],
cwd=self.build_temp,
)
subprocess.check_call(["cmake", "--build", "."], cwd=self.build_temp)
# Morph this instance into the real BuildExtension subclass and run
self.__class__ = _TorchAOBuildExt
return _TorchAOBuildExt.run(self)
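The pattern above, stripped to its essentials: keep setup.py importable without torch by subclassing the lightweight setuptools command, then re-class the instance once the heavy backend has been imported. A minimal sketch:

from setuptools.command.build_ext import build_ext

class LazyBuildExt(build_ext):
    def run(self):
        # The heavy import happens only when build_ext actually executes.
        from torch.utils.cpp_extension import BuildExtension

        class RealBuildExt(BuildExtension):
            pass

        # Morph this instance so all later hooks dispatch to BuildExtension.
        self.__class__ = RealBuildExt
        return RealBuildExt.run(self)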


class CMakeExtension(Extension):
@@ -371,16 +391,33 @@ def get_extensions():
if debug_mode:
print("Compiling in debug mode")

if CUDA_HOME is None and torch.version.cuda:
print("CUDA toolkit is not available. Skipping compilation of CUDA extensions")
# Heavy imports moved here to minimize setup.py import overhead
import torch
from torch.utils.cpp_extension import (
CUDA_HOME,
IS_WINDOWS,
ROCM_HOME,
CppExtension,
CUDAExtension,
)

# Only skip CUDA extensions if neither CUDA_HOME nor nvcc is available.
# In many CI environments CUDA_HOME may be unset even though nvcc is on PATH.
if torch.version.cuda and CUDA_HOME is None and not is_nvcc_available():
print(
"CUDA toolkit is not available (CUDA_HOME unset and nvcc not found). Skipping compilation of CUDA extensions"
)
print(
"If you'd like to compile CUDA extensions locally please install the cudatoolkit from https://anaconda.org/nvidia/cuda-toolkit"
)
if ROCM_HOME is None and torch.version.hip:
print("ROCm is not available. Skipping compilation of ROCm extensions")
print("If you'd like to compile ROCm extensions locally please install ROCm")

use_cuda = torch.version.cuda and CUDA_HOME is not None
# Build CUDA extensions if CUDA is available and either CUDA_HOME is set or nvcc is present
use_cuda = bool(torch.version.cuda) and (
CUDA_HOME is not None or is_nvcc_available()
)
use_rocm = torch.version.hip and ROCM_HOME is not None
extension = CUDAExtension if (use_cuda or use_rocm) else CppExtension

@@ -452,11 +489,13 @@ def get_extensions():
found_col16 = False
found_vec_ext = False
found_outer_vec = False
print("ROCM_HOME", ROCM_HOME)
if os.getenv("VERBOSE_BUILD", "0") == "1" or debug_mode:
print("ROCM_HOME", ROCM_HOME)
hipblaslt_headers = list(
glob.glob(os.path.join(ROCM_HOME, "include", "hipblaslt", "hipblaslt.h"))
)
print("hipblaslt_headers", hipblaslt_headers)
if os.getenv("VERBOSE_BUILD", "0") == "1" or debug_mode:
print("hipblaslt_headers", hipblaslt_headers)
for header in hipblaslt_headers:
with open(header) as f:
text = f.read()
@@ -468,17 +507,22 @@
found_outer_vec = True
if found_col16:
extra_compile_args["cxx"].append("-DHIPBLASLT_HAS_ORDER_COL16")
print("hipblaslt found extended col order enums")
if os.getenv("VERBOSE_BUILD", "0") == "1" or debug_mode:
print("hipblaslt found extended col order enums")
else:
print("hipblaslt does not have extended col order enums")
if os.getenv("VERBOSE_BUILD", "0") == "1" or debug_mode:
print("hipblaslt does not have extended col order enums")
if found_outer_vec:
extra_compile_args["cxx"].append("-DHIPBLASLT_OUTER_VEC")
print("hipblaslt found outer vec")
if os.getenv("VERBOSE_BUILD", "0") == "1" or debug_mode:
print("hipblaslt found outer vec")
elif found_vec_ext:
extra_compile_args["cxx"].append("-DHIPBLASLT_VEC_EXT")
print("hipblaslt found vec ext")
if os.getenv("VERBOSE_BUILD", "0") == "1" or debug_mode:
print("hipblaslt found vec ext")
else:
print("hipblaslt does not have vec ext")
if os.getenv("VERBOSE_BUILD", "0") == "1" or debug_mode:
print("hipblaslt does not have vec ext")

# Get base directory and source paths
curdir = os.path.dirname(os.path.curdir)
@@ -641,7 +685,8 @@ def get_extensions():

ext_modules = []
if len(sources) > 0:
print("SOURCES", sources)
if os.getenv("VERBOSE_BUILD", "0") == "1" or debug_mode:
print("SOURCES", sources)
# Double-check to ensure mx_fp_cutlass_kernels.cu is not in sources
sources = [
s for s in sources if os.path.basename(s) != "mx_fp_cutlass_kernels.cu"
@@ -735,9 +780,13 @@ def get_extensions():
def bool_to_on_off(value):
return "ON" if value else "OFF"

from distutils.sysconfig import get_python_lib
import importlib.util

torch_dir = get_python_lib() + "/torch/share/cmake/Torch"
spec = importlib.util.find_spec("torch")
if spec is None or spec.origin is None:
raise RuntimeError("Unable to locate 'torch' package for CMake config")
torch_pkg_dir = os.path.dirname(spec.origin)
torch_dir = os.path.join(torch_pkg_dir, "share", "cmake", "Torch")

ext_modules.append(
CMakeExtension(
@@ -762,24 +811,23 @@ def bool_to_on_off(value):
return ext_modules


# Only check submodules if we're going to build C++ extensions
if use_cpp != "0":
check_submodules()
# Defer submodule checks to build time via build_ext

setup(
name="torchao",
version=version + version_suffix,
packages=find_packages(exclude=["benchmarks", "benchmarks.*"]),
packages=find_packages(include=["torchao*"]),
include_package_data=True,
package_data={
"torchao.kernel.configs": ["*.pkl"],
},
ext_modules=get_extensions(),
# Defer extension discovery to build time for performance
ext_modules=[],
extras_require={"dev": read_requirements("dev-requirements.txt")},
description="Package for applying ao techniques to GPU models",
long_description=open("README.md", encoding="utf-8").read(),
long_description_content_type="text/markdown",
url="https://github.com/pytorch/ao",
cmdclass={"build_ext": TorchAOBuildExt},
cmdclass={"build_ext": LazyTorchAOBuildExt},
options={"bdist_wheel": {"py_limited_api": "cp39"}},
)
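End to end, the deferred pieces come together at build time; a usage sketch, assuming the environment set up by env_var_script_linux.sh:

import os
import subprocess
import sys

# Metadata-only operations now import setup.py without pulling in torch;
# `pip wheel` triggers LazyTorchAOBuildExt, which imports torch, checks
# submodules, and discovers the extension modules.
env = dict(os.environ, USE_CPP="1", PIP_NO_BUILD_ISOLATION="1")
subprocess.check_call([sys.executable, "-m", "pip", "wheel", ".", "-w", "dist"], env=env)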