Skip to content

Commit fe5bdc4

Browse files
author
Eirene Pandi
committed
[TVMC] Add tvmc flag to print ir before / after pass names
1 parent c8ef902 commit fe5bdc4

File tree

3 files changed

+82
-2
lines changed

3 files changed

+82
-2
lines changed

python/tvm/driver/tvmc/compiler.py

Lines changed: 30 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@
3131
from tvm import autotvm, auto_scheduler
3232
from tvm import relay
3333
from tvm.driver.tvmc.registry import generate_registry_args, reconstruct_registry_entity
34-
from tvm.ir.instrument import PassInstrument, PassTimingInstrument
34+
from tvm.ir.instrument import PassInstrument, PassTimingInstrument, PassPrintBefore, PassPrintAfter
3535
from tvm.ir.memory_pools import WorkspaceMemoryPools
3636
from tvm.target import Target
3737
from tvm.relay.backend import Executor, Runtime
@@ -162,6 +162,18 @@ def add_compile_parser(subparsers, _, json_params):
162162
action="store_true",
163163
help="print compilation time per pass",
164164
)
165+
parser.add_argument(
166+
"--print-ir-before",
167+
help="print IR before each named pass of a comma-separated list of pass names."
168+
"e.g. '--print-ir-before [tir.SplitHostDevice,tir.ConvertSSA]' ",
169+
default="",
170+
)
171+
parser.add_argument(
172+
"--print-ir-after",
173+
help="print IR after each named pass of a comma-separated list of pass names."
174+
"e.g. '--print-ir-after [tir.SplitHostDevice,tir.ConvertSSA]' ",
175+
default="",
176+
)
165177
for one_entry in json_params:
166178
parser.set_defaults(**one_entry)
167179

@@ -220,6 +232,8 @@ def drive_compile(args):
220232
workspace_pools_recombobulate(args, [workspace_pools_target], extra_targets)
221233
),
222234
print_pass_times=args.print_pass_times,
235+
print_ir_before=args.print_ir_before,
236+
print_ir_after=args.print_ir_after,
223237
**transform_args,
224238
)
225239

@@ -247,6 +261,8 @@ def compile_model(
247261
mod_name: Optional[str] = "default",
248262
workspace_pools: Optional[WorkspaceMemoryPools] = None,
249263
print_pass_times: bool = False,
264+
print_ir_before: List[str] = "",
265+
print_ir_after: List[str] = "",
250266
instruments: Optional[Sequence[PassInstrument]] = None,
251267
desired_layout: Optional[str] = None,
252268
desired_layout_ops: Optional[List[str]] = None,
@@ -295,7 +311,7 @@ def compile_model(
295311
needs to be generated.
296312
disabled_pass: str, optional
297313
Comma-separated list of passes which needs to be disabled
298-
during compilation
314+
during compilation.
299315
pass_context_configs: list[str], optional
300316
List of strings containing a set of configurations to be passed to the
301317
PassContext.
@@ -310,6 +326,10 @@ def compile_model(
310326
compilation.
311327
print_pass_times: bool
312328
To enable printing a breakdown of compilation times by pass. Disabled by default.
329+
print_ir_before: list[str]
330+
To print ir before each named pass of a comma-separated list of passes.
331+
print_ir_after: list[str]
332+
To print ir after each named pass of a comma-separated list of passes.
313333
instruments: Optional[Sequence[PassInstrument]]
314334
The list of pass instrument implementations.
315335
desired_layout: str, optional
@@ -369,6 +389,14 @@ def compile_model(
369389
timing_inst = PassTimingInstrument()
370390
instruments = [timing_inst] if instruments is None else [timing_inst] + instruments
371391

392+
if print_ir_before:
393+
print_ir_before_instr = PassPrintBefore(print_ir_before)
394+
instruments = [print_ir_before_instr] if instruments is None else [print_ir_before_instr] + instruments
395+
396+
if print_ir_after:
397+
print_ir_after_instr = PassPrintAfter(print_ir_after)
398+
instruments = [print_ir_after_instr] if instruments is None else [print_ir_after_instr] + instruments
399+
372400
with tvm.transform.PassContext(
373401
opt_level=opt_level,
374402
config=config,

python/tvm/ir/instrument.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -216,6 +216,8 @@ def should_run(self, mod, pass_info)
216216
tvm.relay.build(mod, "llvm")
217217
"""
218218

219+
220+
219221
def create_pass_instrument(pi_cls):
220222
if not inspect.isclass(pi_cls):
221223
raise TypeError("pi_cls must be a class")
@@ -255,3 +257,21 @@ def render():
255257
profiles = timing_inst.render()
256258
"""
257259
return _ffi_instrument_api.RenderTimePassProfiles()
260+
261+
@pass_instrument
262+
class PassPrintBefore:
263+
def __init__(self, print_pass_name):
264+
self.print_pass_name = print_pass_name
265+
def run_before_pass(self, mod, pass_info):
266+
if (pass_info.name in self.print_pass_name):
267+
print("Print ir before:")
268+
print(str(pass_info.name) + '\n' + str(mod) + '\n\n')
269+
270+
@pass_instrument
271+
class PassPrintAfter:
272+
def __init__(self, print_pass_name):
273+
self.print_pass_name = print_pass_name
274+
def run_after_pass(self, mod, pass_info):
275+
if (pass_info.name in self.print_pass_name ):
276+
print("Print ir after:")
277+
print(str(pass_info.name) + '\n' + str(mod) + '\n\n')

tests/python/driver/tvmc/test_command_line.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -289,3 +289,35 @@ def test_tvmc_print_pass_times(capsys, keras_simple, tmpdir_factory):
289289
captured_out = capsys.readouterr().out
290290
for exp_str in ("Compilation time breakdown by pass:", "sequential:", "us]"):
291291
assert exp_str in captured_out
292+
293+
def test_tvmc_print_ir_before(capsys, keras_simple, tmpdir_factory):
294+
pytest.importorskip("tensorflow")
295+
tmpdir = tmpdir_factory.mktemp("out")
296+
print_cmd = "--print-ir-before=[tir.SplitHostDevice]"
297+
298+
# Compile model
299+
module_file = os.path.join(tmpdir, "keras-tvm.tar")
300+
compile_cmd = f"tvmc compile --target 'llvm' {keras_simple} --output {module_file} {print_cmd}"
301+
compile_args = compile_cmd.split(" ")[1:]
302+
_main(compile_args)
303+
304+
# Check for timing results output
305+
captured_out = capsys.readouterr().out
306+
for exp_str in ("Print ir before:\n","tir.SplitHostDevice\n"):
307+
assert exp_str in captured_out
308+
309+
def test_tvmc_print_ir_after(capsys, keras_simple, tmpdir_factory):
310+
pytest.importorskip("tensorflow")
311+
tmpdir = tmpdir_factory.mktemp("out")
312+
print_cmd = "--print-ir-after=[tir.SplitHostDevice]"
313+
314+
# Compile model
315+
module_file = os.path.join(tmpdir, "keras-tvm.tar")
316+
compile_cmd = f"tvmc compile --target 'llvm' {keras_simple} --output {module_file} {print_cmd}"
317+
compile_args = compile_cmd.split(" ")[1:]
318+
_main(compile_args)
319+
320+
# Check for timing results output
321+
captured_out = capsys.readouterr().out
322+
for exp_str in ("Print ir after:\n","tir.SplitHostDevice\n"):
323+
assert exp_str in captured_out

0 commit comments

Comments
 (0)