Skip to content

Commit 6cbc66a

Browse files
author
Eirene Pandi
committed
[TVMC] Add tvmc flag to print ir before / after pass names
1 parent c8ef902 commit 6cbc66a

File tree

3 files changed

+92
-4
lines changed

3 files changed

+92
-4
lines changed

python/tvm/driver/tvmc/compiler.py

Lines changed: 36 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@
3131
from tvm import autotvm, auto_scheduler
3232
from tvm import relay
3333
from tvm.driver.tvmc.registry import generate_registry_args, reconstruct_registry_entity
34-
from tvm.ir.instrument import PassInstrument, PassTimingInstrument
34+
from tvm.ir.instrument import PassInstrument, PassTimingInstrument, PassPrintBefore, PassPrintAfter
3535
from tvm.ir.memory_pools import WorkspaceMemoryPools
3636
from tvm.target import Target
3737
from tvm.relay.backend import Executor, Runtime
@@ -162,6 +162,18 @@ def add_compile_parser(subparsers, _, json_params):
162162
action="store_true",
163163
help="print compilation time per pass",
164164
)
165+
parser.add_argument(
166+
"--print-ir-before",
167+
help="print IR before each named pass of a comma-separated list of pass names."
168+
"e.g. '--print-ir-before [tir.SplitHostDevice,tir.ConvertSSA]' ",
169+
default="",
170+
)
171+
parser.add_argument(
172+
"--print-ir-after",
173+
help="print IR after each named pass of a comma-separated list of pass names."
174+
"e.g. '--print-ir-after [tir.SplitHostDevice,tir.ConvertSSA]' ",
175+
default="",
176+
)
165177
for one_entry in json_params:
166178
parser.set_defaults(**one_entry)
167179

@@ -220,6 +232,8 @@ def drive_compile(args):
220232
workspace_pools_recombobulate(args, [workspace_pools_target], extra_targets)
221233
),
222234
print_pass_times=args.print_pass_times,
235+
print_ir_before=args.print_ir_before,
236+
print_ir_after=args.print_ir_after,
223237
**transform_args,
224238
)
225239

@@ -247,6 +261,8 @@ def compile_model(
247261
mod_name: Optional[str] = "default",
248262
workspace_pools: Optional[WorkspaceMemoryPools] = None,
249263
print_pass_times: bool = False,
264+
print_ir_before: List[str] = "",
265+
print_ir_after: List[str] = "",
250266
instruments: Optional[Sequence[PassInstrument]] = None,
251267
desired_layout: Optional[str] = None,
252268
desired_layout_ops: Optional[List[str]] = None,
@@ -295,7 +311,7 @@ def compile_model(
295311
needs to be generated.
296312
disabled_pass: str, optional
297313
Comma-separated list of passes which needs to be disabled
298-
during compilation
314+
during compilation.
299315
pass_context_configs: list[str], optional
300316
List of strings containing a set of configurations to be passed to the
301317
PassContext.
@@ -310,6 +326,10 @@ def compile_model(
310326
compilation.
311327
print_pass_times: bool
312328
To enable printing a breakdown of compilation times by pass. Disabled by default.
329+
print_ir_before: list[str]
330+
To print ir before each named pass of a comma-separated list of passes.
331+
print_ir_after: list[str]
332+
To print ir after each named pass of a comma-separated list of passes.
313333
instruments: Optional[Sequence[PassInstrument]]
314334
The list of pass instrument implementations.
315335
desired_layout: str, optional
@@ -369,6 +389,20 @@ def compile_model(
369389
timing_inst = PassTimingInstrument()
370390
instruments = [timing_inst] if instruments is None else [timing_inst] + instruments
371391

392+
if print_ir_before:
393+
print_ir_before_instr = PassPrintBefore(print_ir_before)
394+
instruments = (
395+
[print_ir_before_instr]
396+
if instruments is None
397+
else [print_ir_before_instr] + instruments
398+
)
399+
400+
if print_ir_after:
401+
print_ir_after_instr = PassPrintAfter(print_ir_after)
402+
instruments = (
403+
[print_ir_after_instr] if instruments is None else [print_ir_after_instr] + instruments
404+
)
405+
372406
with tvm.transform.PassContext(
373407
opt_level=opt_level,
374408
config=config,
@@ -581,7 +615,6 @@ def dump_operation_offloads(mod: tvm.ir.IRModule, initial_mod: tvm.ir.IRModule,
581615
save_to_file = all([dump_path != "-", dump_path != ""])
582616

583617
if print_to_console or save_to_file:
584-
585618
operations_distribution = analyze_operations_distribution(mod)
586619

587620
def annotate_f(x):

python/tvm/ir/instrument.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -255,3 +255,25 @@ def render():
255255
profiles = timing_inst.render()
256256
"""
257257
return _ffi_instrument_api.RenderTimePassProfiles()
258+
259+
260+
@pass_instrument
261+
class PassPrintBefore:
262+
def __init__(self, print_pass_name):
263+
self.print_pass_name = print_pass_name
264+
265+
def run_before_pass(self, mod, pass_info):
266+
if pass_info.name in self.print_pass_name:
267+
print("Print ir before:")
268+
print(str(pass_info.name) + "\n" + str(mod) + "\n\n")
269+
270+
271+
@pass_instrument
272+
class PassPrintAfter:
273+
def __init__(self, print_pass_name):
274+
self.print_pass_name = print_pass_name
275+
276+
def run_after_pass(self, mod, pass_info):
277+
if pass_info.name in self.print_pass_name:
278+
print("Print ir after:")
279+
print(str(pass_info.name) + "\n" + str(mod) + "\n\n")

tests/python/driver/tvmc/test_command_line.py

Lines changed: 34 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -201,7 +201,6 @@ def paddle_model(paddle_resnet50):
201201
@mock.patch.object(compiler, "compile_model")
202202
# @mock.patch.object(compiler, "compile_model")
203203
def test_tvmc_compile_input_model(mock_compile_model, tmpdir_factory, model):
204-
205204
output_dir = tmpdir_factory.mktemp("output")
206205
output_file = output_dir / "model.tar"
207206

@@ -289,3 +288,37 @@ def test_tvmc_print_pass_times(capsys, keras_simple, tmpdir_factory):
289288
captured_out = capsys.readouterr().out
290289
for exp_str in ("Compilation time breakdown by pass:", "sequential:", "us]"):
291290
assert exp_str in captured_out
291+
292+
293+
def test_tvmc_print_ir_before(capsys, keras_simple, tmpdir_factory):
294+
pytest.importorskip("tensorflow")
295+
tmpdir = tmpdir_factory.mktemp("out")
296+
print_cmd = "--print-ir-before=[tir.SplitHostDevice]"
297+
298+
# Compile model
299+
module_file = os.path.join(tmpdir, "keras-tvm.tar")
300+
compile_cmd = f"tvmc compile --target 'llvm' {keras_simple} --output {module_file} {print_cmd}"
301+
compile_args = compile_cmd.split(" ")[1:]
302+
_main(compile_args)
303+
304+
# Check for timing results output
305+
captured_out = capsys.readouterr().out
306+
for exp_str in ("Print ir before:\n", "tir.SplitHostDevice\n"):
307+
assert exp_str in captured_out
308+
309+
310+
def test_tvmc_print_ir_after(capsys, keras_simple, tmpdir_factory):
311+
pytest.importorskip("tensorflow")
312+
tmpdir = tmpdir_factory.mktemp("out")
313+
print_cmd = "--print-ir-after=[tir.SplitHostDevice]"
314+
315+
# Compile model
316+
module_file = os.path.join(tmpdir, "keras-tvm.tar")
317+
compile_cmd = f"tvmc compile --target 'llvm' {keras_simple} --output {module_file} {print_cmd}"
318+
compile_args = compile_cmd.split(" ")[1:]
319+
_main(compile_args)
320+
321+
# Check for timing results output
322+
captured_out = capsys.readouterr().out
323+
for exp_str in ("Print ir after:\n", "tir.SplitHostDevice\n"):
324+
assert exp_str in captured_out

0 commit comments

Comments
 (0)