@@ -572,7 +572,7 @@ def compile_for_cutlass(mod, cutlass_target):
572572 return final_mod
573573
574574
575- def finalize_modules (lib , lib_path , tmp_dir ):
575+ def finalize_modules (lib , lib_path = "compile.so" , tmp_dir = "./tmp" ):
576576 """Returns lib with any C source, LLVM and static library modules complied and linked in ready
577577 for use by the graph or AOT executors. This method is not specific to CUTLASS, however it does
578578 assume nvcc will be used for final compilation and linking. It is provided here for
@@ -584,22 +584,23 @@ def finalize_modules(lib, lib_path, tmp_dir):
584584 The output from relay.build.
585585
586586 lib_path : string
587- Name for temporary library .so file .
587+ The path to a shared library which will be generated as the result of the build process .
588588
589- tmp_dir : Working temporary directory.
589+ tmp_dir : string
590+ A temporary directory where intermediate compiled artifacts will be stored.
590591
591592 Returns
592593 -------
593- updated_lib : runtime:: Module
594- The given lib with any final compilation and linking steps completed.
594+ updated_lib : runtime. Module
595+ The updated library with all compilation and linking completed.
595596
596597 """
597598 lib_path = os .path .join (tmp_dir , lib_path )
598599 lib .export_library (lib_path , workspace_dir = tmp_dir , cc = "nvcc" )
599600 return runtime .load_module (lib_path )
600601
601602
602- def finalize_modules_vm (vm_exec , lib_path , tmp_dir ):
603+ def finalize_modules_vm (vm_exec , lib_path = "compile.so" , vmcode_path = "vmcode.ro" , tmp_dir = "./tmp" ):
603604 """Returns vm_exec with any C source, LLVM and static library modules compiled and linked in
604605 ready for use by the VM executor. This method is not specific to CUTLASS, however it does
605606 assume nvcc will be used for final compilation and linking. It is provided here for
@@ -608,20 +609,28 @@ def finalize_modules_vm(vm_exec, lib_path, tmp_dir):
608609 Parameters
609610 ----------
610611 vm_exec : vm.Executable
611- The output from relay.vm.compile.
612+ The output from relay.vm.compile containing compiled host code and kernels .
612613
613614 lib_path : string
614- Name for temporary library .so file.
615+ The path to a shared library which will be generated as the result of the build process.
616+
617+ vmcode_path : string
618+ The path where the VM bytecode will be serialized to.
615619
616- tmp_dir : Working temporary directory.
620+ tmp_dir : string
621+ A temporary directory where intermediate compiled artifacts will be stored.
617622
618623 Returns
619624 -------
620625 updated_vm_exec : vm.Executable
621- The given lib with any final compilation and linking steps completed.
626+ The updated VM executable with all compilation and linking completed.
622627 """
623628 code , lib = vm_exec .save ()
624629 lib_path = os .path .join (tmp_dir , lib_path )
630+ vmcode_path = os .path .join (tmp_dir , vmcode_path )
625631 lib .export_library (lib_path , workspace_dir = tmp_dir , cc = "nvcc" )
632+ with open (vmcode_path , "wb" ) as fo :
633+ fo .write (code )
626634 lib = tvm .runtime .load_module (lib_path )
635+ code = bytearray (open (vmcode_path , "rb" ).read ())
627636 return tvm .runtime .vm .Executable .load_exec (code , lib )
0 commit comments