3131from tvm import autotvm , auto_scheduler
3232from tvm import relay
3333from tvm .driver .tvmc .registry import generate_registry_args , reconstruct_registry_entity
34- from tvm .ir .instrument import PassInstrument , PassTimingInstrument
34+ from tvm .ir .instrument import PassInstrument , PassTimingInstrument , PassPrintBefore , PassPrintAfter
3535from tvm .ir .memory_pools import WorkspaceMemoryPools
3636from tvm .target import Target
3737from tvm .relay .backend import Executor , Runtime
@@ -162,6 +162,18 @@ def add_compile_parser(subparsers, _, json_params):
162162 action = "store_true" ,
163163 help = "print compilation time per pass" ,
164164 )
165+ parser .add_argument (
166+ "--print-ir-before" ,
167+ help = "print IR before each named pass of a comma-separated list of pass names."
168+ "e.g. '--print-ir-before [tir.SplitHostDevice,tir.ConvertSSA]' " ,
169+ default = "" ,
170+ )
171+ parser .add_argument (
172+ "--print-ir-after" ,
173+ help = "print IR after each named pass of a comma-separated list of pass names."
174+ "e.g. '--print-ir-after [tir.SplitHostDevice,tir.ConvertSSA]' " ,
175+ default = "" ,
176+ )
165177 for one_entry in json_params :
166178 parser .set_defaults (** one_entry )
167179
@@ -220,6 +232,8 @@ def drive_compile(args):
220232 workspace_pools_recombobulate (args , [workspace_pools_target ], extra_targets )
221233 ),
222234 print_pass_times = args .print_pass_times ,
235+ print_ir_before = args .print_ir_before ,
236+ print_ir_after = args .print_ir_after ,
223237 ** transform_args ,
224238 )
225239
@@ -247,6 +261,8 @@ def compile_model(
247261 mod_name : Optional [str ] = "default" ,
248262 workspace_pools : Optional [WorkspaceMemoryPools ] = None ,
249263 print_pass_times : bool = False ,
264+ print_ir_before : List [str ] = "" ,
265+ print_ir_after : List [str ] = "" ,
250266 instruments : Optional [Sequence [PassInstrument ]] = None ,
251267 desired_layout : Optional [str ] = None ,
252268 desired_layout_ops : Optional [List [str ]] = None ,
@@ -295,7 +311,7 @@ def compile_model(
295311 needs to be generated.
296312 disabled_pass: str, optional
297313 Comma-separated list of passes which needs to be disabled
298- during compilation
314+ during compilation.
299315 pass_context_configs: list[str], optional
300316 List of strings containing a set of configurations to be passed to the
301317 PassContext.
@@ -310,6 +326,10 @@ def compile_model(
310326 compilation.
311327 print_pass_times: bool
312328 To enable printing a breakdown of compilation times by pass. Disabled by default.
329+ print_ir_before: list[str]
330+ To print ir before each named pass of a comma-separated list of passes.
331+ print_ir_after: list[str]
332+ To print ir after each named pass of a comma-separated list of passes.
313333 instruments: Optional[Sequence[PassInstrument]]
314334 The list of pass instrument implementations.
315335 desired_layout: str, optional
@@ -369,6 +389,14 @@ def compile_model(
369389 timing_inst = PassTimingInstrument ()
370390 instruments = [timing_inst ] if instruments is None else [timing_inst ] + instruments
371391
392+ if print_ir_before :
393+ print_ir_before_instr = PassPrintBefore (print_ir_before )
394+ instruments = [print_ir_before_instr ] if instruments is None else [print_ir_before_instr ] + instruments
395+
396+ if print_ir_after :
397+ print_ir_after_instr = PassPrintAfter (print_ir_after )
398+ instruments = [print_ir_after_instr ] if instruments is None else [print_ir_after_instr ] + instruments
399+
372400 with tvm .transform .PassContext (
373401 opt_level = opt_level ,
374402 config = config ,
0 commit comments