Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
129 changes: 128 additions & 1 deletion python/tvm/relax/base_py_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@
# under the License.
"""BasePyModule: Base class for IRModules with Python function support."""

import inspect
import os
from typing import Any, Dict, List, Optional, Tuple, Union

import numpy as np
Expand Down Expand Up @@ -369,7 +371,6 @@ def add_python_function(self, name: str, func: callable):
# Create a wrapper that handles both instance methods and static functions
# pylint: disable=import-outside-toplevel
import functools
import inspect

@functools.wraps(func)
def wrapper(*args, **kwargs):
Expand All @@ -383,3 +384,129 @@ def wrapper(*args, **kwargs):

# Set the wrapper as an instance attribute
setattr(self, name, wrapper)

def script(
self,
*,
name: Optional[str] = None,
show_meta: bool = False,
ir_prefix: str = "I",
tir_prefix: str = "T",
relax_prefix: str = "R",
module_alias: str = "cls",
buffer_dtype: str = "float32",
int_dtype: str = "int32",
float_dtype: str = "void",
verbose_expr: bool = False,
indent_spaces: int = 4,
print_line_numbers: bool = False,
num_context_lines: int = -1,
syntax_sugar: bool = True,
show_object_address: bool = False,
show_all_struct_info: bool = True,
) -> str:
"""Print TVM IR into TVMScript text format with Python function support.

This method extends the standard IRModule script() method to handle
Python functions stored in the IRModule's pyfuncs attribute.
"""
# First get the standard IRModule script
base_script = self.ir_mod.script(
name=name,
show_meta=show_meta,
ir_prefix=ir_prefix,
tir_prefix=tir_prefix,
relax_prefix=relax_prefix,
module_alias=module_alias,
buffer_dtype=buffer_dtype,
int_dtype=int_dtype,
float_dtype=float_dtype,
verbose_expr=verbose_expr,
indent_spaces=indent_spaces,
print_line_numbers=print_line_numbers,
num_context_lines=num_context_lines,
syntax_sugar=syntax_sugar,
show_object_address=show_object_address,
show_all_struct_info=show_all_struct_info,
)

# If there are no Python functions, return the base script
if not hasattr(self.ir_mod, "pyfuncs") or not self.ir_mod.pyfuncs:
return base_script

# Insert Python functions into the script
return self._insert_python_functions(base_script, indent_spaces)

def _insert_python_functions(self, base_script: str, indent_spaces: int) -> str:
"""Insert Python functions into the TVMScript output."""
lines = base_script.split("\n")
result_lines = []

# Find the class definition line and insert Python functions after it
class_found = False
class_indent = 0

for line in lines:
result_lines.append(line)

# Look for class definition
if not class_found and line.strip().startswith("class "):
class_found = True
class_indent = len(line) - len(line.lstrip())

# Insert Python functions after the class definition
if hasattr(self.ir_mod, "pyfuncs") and self.ir_mod.pyfuncs:
for func_name, func in self.ir_mod.pyfuncs.items():
# Get the function source code
func_source = self._get_function_source(func)
if func_source:
# Format the function with proper indentation
formatted_func = self._format_python_function(
func_name, func_source, class_indent + indent_spaces
)
result_lines.append(formatted_func)
result_lines.append("") # Add empty line for separation

return "\n".join(result_lines)

def _get_function_source(self, func: callable) -> Optional[str]:
"""Get the source code of a Python function."""
try:
source = inspect.getsource(func)
return source
except (OSError, TypeError):
# If we can't get the source, return None
return None

def _format_python_function(self, _func_name: str, func_source: str, indent: int) -> str:
"""Format a Python function with proper indentation for TVMScript."""
lines = func_source.split("\n")
formatted_lines = []

for line in lines:
# Skip the function definition line if it's already properly indented
if line.strip().startswith("def ") or line.strip().startswith("@"):
# Keep decorators and function definition as is
formatted_lines.append(" " * indent + line.strip())
else:
# Add proper indentation for the function body
formatted_lines.append(" " * indent + line.strip())

return "\n".join(formatted_lines)

def show(
self, style: Optional[str] = None, black_format: Optional[bool] = None, **kwargs
) -> None:
"""A sugar for print highlighted TVM script with Python function support.

This method extends the standard IRModule show() method to handle
Python functions stored in the IRModule's pyfuncs attribute.
"""
from tvm.script.highlight import cprint # pylint: disable=import-outside-toplevel

if black_format is None:
env = os.environ.get("TVM_BLACK_FORMAT")
black_format = env and int(env)

script_content = self.script(**kwargs)
cprint(script_content, style=style, black_format=black_format)
Loading
Loading