Skip to content

Commit 9c16e05

Browse files
committed
- Masa's comments
1 parent d3df1d3 commit 9c16e05

File tree

1 file changed

+19
-10
lines changed

1 file changed

+19
-10
lines changed

python/tvm/contrib/cutlass/build.py

Lines changed: 19 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -572,7 +572,7 @@ def compile_for_cutlass(mod, cutlass_target):
572572
return final_mod
573573

574574

575-
def finalize_modules(lib, lib_path, tmp_dir):
575+
def finalize_modules(lib, lib_path="compile.so", tmp_dir="./tmp"):
576576
"""Returns lib with any C source, LLVM and static library modules complied and linked in ready
577577
for use by the graph or AOT executors. This method is not specific to CUTLASS, however it does
578578
assume nvcc will be used for final compilation and linking. It is provided here for
@@ -584,22 +584,23 @@ def finalize_modules(lib, lib_path, tmp_dir):
584584
The output from relay.build.
585585
586586
lib_path : string
587-
Name for temporary library .so file.
587+
The path to a shared library which will be generated as the result of the build process.
588588
589-
tmp_dir : Working temporary directory.
589+
tmp_dir : string
590+
A temporary directory where intermediate compiled artifacts will be stored.
590591
591592
Returns
592593
-------
593-
updated_lib : runtime::Module
594-
The given lib with any final compilation and linking steps completed.
594+
updated_lib : runtime.Module
595+
The updated library with all compilation and linking completed.
595596
596597
"""
597598
lib_path = os.path.join(tmp_dir, lib_path)
598599
lib.export_library(lib_path, workspace_dir=tmp_dir, cc="nvcc")
599600
return runtime.load_module(lib_path)
600601

601602

602-
def finalize_modules_vm(vm_exec, lib_path, tmp_dir):
603+
def finalize_modules_vm(vm_exec, lib_path="compile.so", vmcode_path="vmcode.ro", tmp_dir="./tmp"):
603604
"""Returns vm_exec with any C source, LLVM and static library modules compiled and linked in
604605
ready for use by the VM executor. This method is not specific to CUTLASS, however it does
605606
assume nvcc will be used for final compilation and linking. It is provided here for
@@ -608,20 +609,28 @@ def finalize_modules_vm(vm_exec, lib_path, tmp_dir):
608609
Parameters
609610
----------
610611
vm_exec : vm.Executable
611-
The output from relay.vm.compile.
612+
The output from relay.vm.compile containing compiled host code and kernels.
612613
613614
lib_path : string
614-
Name for temporary library .so file.
615+
The path to a shared library which will be generated as the result of the build process.
616+
617+
vmcode_path : string
618+
The path where the VM bytecode will be serialized to.
615619
616-
tmp_dir : Working temporary directory.
620+
tmp_dir : string
621+
A temporary directory where intermediate compiled artifacts will be stored.
617622
618623
Returns
619624
-------
620625
updated_vm_exec : vm.Executable
621-
The given lib with any final compilation and linking steps completed.
626+
The updated VM executable with all compilation and linking completed.
622627
"""
623628
code, lib = vm_exec.save()
624629
lib_path = os.path.join(tmp_dir, lib_path)
630+
vmcode_path = os.path.join(tmp_dir, vmcode_path)
625631
lib.export_library(lib_path, workspace_dir=tmp_dir, cc="nvcc")
632+
with open(vmcode_path, "wb") as fo:
633+
fo.write(code)
626634
lib = tvm.runtime.load_module(lib_path)
635+
code = bytearray(open(vmcode_path, "rb").read())
627636
return tvm.runtime.vm.Executable.load_exec(code, lib)

0 commit comments

Comments
 (0)