Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
103 changes: 76 additions & 27 deletions tools/codegen/install_xpu_headers.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import shutil
from pathlib import Path

VERBOSE = False

parser = argparse.ArgumentParser(description="Utils for append ops headers")
parser.add_argument(
def append_xpu_function_header(src, dst):
    r"""
    Append to ``dst`` every ``#include <ATen/ops/...>`` line of ``src`` that ``dst`` lacks.

    Trailing blank lines of ``dst`` are dropped before the new includes are
    appended. ``dst`` is rewritten only when at least one include is missing,
    so an up-to-date file is never touched.

    Does nothing when ``args.dry_run`` is set. I/O errors are not raised: the
    function returns early and reports the problem only when ``VERBOSE`` is true.
    """
    if args.dry_run:
        return

    # Read source file and match header lines
    try:
        with open(src, encoding="utf-8") as fr:
            src_text = fr.read()
    except OSError as e:
        if VERBOSE:
            print(f"Warning: Could not read source file {src}: {e}")
        return

    pattern = r"^#include <ATen/ops/.*>\s*\r?\n"
    # The trailing `\s*` can swallow blank lines that follow an include, so the
    # raw match text is not a reliable key. Strip each match down to the
    # directive itself and drop duplicates while preserving source order.
    headers = list(
        dict.fromkeys(
            match.strip() for match in re.findall(pattern, src_text, re.MULTILINE)
        )
    )
    if not headers:
        return

    try:
        with open(dst, encoding="utf-8") as f:
            dst_lines = f.readlines()
    except OSError as e:
        if VERBOSE:
            print(f"Warning: Could not read destination file {dst}: {e}")
        return

    # Compare whole (stripped) lines so surrounding whitespace or a missing
    # final newline in dst cannot make an existing include look absent.
    present = {line.strip() for line in dst_lines}
    missing_headers = [header for header in headers if header not in present]
    if not missing_headers:
        return

    new_dst_lines = list(dst_lines)
    # Remove trailing empty lines so the appended includes follow the content directly
    while new_dst_lines and not new_dst_lines[-1].strip():
        new_dst_lines.pop()
    # Make sure the first appended include starts on its own line
    if new_dst_lines and not new_dst_lines[-1].endswith("\n"):
        new_dst_lines[-1] += "\n"
    new_dst_lines.extend(f"{header}\n" for header in missing_headers)

    try:
        with open(dst, "w", encoding="utf-8") as f:
            f.writelines(new_dst_lines)
    except OSError as e:
        if VERBOSE:
            print(f"Error: Could not write to {dst}: {e}")

def parse_ops_headers(src):
Expand Down Expand Up @@ -78,18 +98,36 @@ def classify_ops_headers(src_dir, dst_dir):

def generate_xpu_ops_headers_cmake(src_dir, dst_dir, xpu_ops_headers):
    r"""
    Generate XPU ops headers xpu_ops_generated_headers.cmake only if content changes

    The file is written to ``src_dir`` and defines the CMake list
    ``xpu_ops_generated_headers``, with one quoted POSIX-style path
    (``dst_dir`` joined with the header name) per entry of ``xpu_ops_headers``.
    An existing file whose content already matches is left untouched.
    """
    output_file = os.path.join(src_dir, "xpu_ops_generated_headers.cmake")

    # Generate new content
    entries = [
        f' "{Path(os.path.join(dst_dir, header)).as_posix()}"\n'
        for header in xpu_ops_headers
    ]
    new_content = "".join(["set(xpu_ops_generated_headers\n", *entries, ")\n"])

    # Read the current file directly instead of checking existence first.
    # A missing, unreadable or undecodable file simply means it must be
    # (re)generated.
    try:
        with open(output_file, encoding="utf-8") as f:
            up_to_date = f.read() == new_content
    except (OSError, UnicodeDecodeError):
        up_to_date = False

    if up_to_date:
        return

    with open(output_file, "w", encoding="utf-8") as fw:
        fw.write(new_content)


def append_xpu_ops_headers(src_dir, dst_dir, common_headers, xpu_ops_headers):
r"""
For XPU-specific ops headers, copy them to destination build and append XPU declarations to common headers.
Copies and appends are done only if leading to file changes to prevent unnecessary recompilations.
"""
if args.dry_run:
return
Expand All @@ -99,7 +137,16 @@ def append_xpu_ops_headers(src_dir, dst_dir, common_headers, xpu_ops_headers):
# assert "xpu" in f, f"Error: The function signature or namespace in '{f}' is incorrect. Expected 'xpu' to be present."
src = os.path.join(src_dir, f)
dst = os.path.join(dst_dir, f)
shutil.copy(src, dst)
# Only copy if src and dst differ or dst does not exist
should_copy = True
if os.path.exists(dst):
try:
with open(src, "rb") as fsrc, open(dst, "rb") as fdst:
should_copy = fsrc.read() != fdst.read()
except OSError:
should_copy = True
if should_copy:
shutil.copy(src, dst)

for f in common_headers:
src = os.path.join(src_dir, f)
Expand All @@ -118,6 +165,7 @@ def append_xpu_ops_headers(src_dir, dst_dir, common_headers, xpu_ops_headers):
with open(dst, "r+", encoding="utf-8") as f:
dst_lines = f.readlines()
dst_text = "".join(dst_lines)
old_content = "".join(dst_lines)
missing_declarations = []
insertion_index = None
for index, line in enumerate(dst_lines):
Expand All @@ -133,9 +181,10 @@ def append_xpu_ops_headers(src_dir, dst_dir, common_headers, xpu_ops_headers):
break
assert (insertion_index is not None), f"Error: No TORCH_API declaration found in {dst}."

f.seek(0)
f.writelines(dst_lines)
f.truncate()
if old_content != "".join(dst_lines):
f.seek(0)
f.writelines(dst_lines)
f.truncate()


def main():
Expand Down
Loading