Skip to content

Commit ff39930

Browse files
tqchenwweic
authored andcommitted
[AUTOTVM] Refactor measure build func (apache#2927)
1 parent e75edce commit ff39930

File tree

4 files changed

+79
-40
lines changed

4 files changed

+79
-40
lines changed

python/tvm/autotvm/measure/measure_methods.py

Lines changed: 40 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919

2020
from ... import ir_pass, build, build_config, nd, TVMError, register_func, \
2121
rpc as _rpc, target as _target
22-
from ...contrib import nvcc, ndk
22+
from ...contrib import nvcc, ndk, tar
2323

2424
from ..util import get_const_tuple
2525
from ..env import AutotvmGlobalScope
@@ -58,20 +58,20 @@ class LocalBuilder(Builder):
5858
build_func: callable or str
5959
If is 'default', use default build function
6060
If is 'ndk', use function for android ndk
61-
If is callable, use it as custom build function
61+
If is callable, use it as custom build function, expect lib_format field.
6262
"""
6363
def __init__(self, timeout=10, n_parallel=None, build_func='default'):
6464
super(LocalBuilder, self).__init__(timeout, n_parallel)
6565

6666
if isinstance(build_func, str):
6767
if build_func == 'default':
68-
build_func = default_build_func
68+
build_func = tar.tar
6969
elif build_func == 'ndk':
70-
build_func = android_ndk_build_func
70+
build_func = ndk.create_shared
7171
else:
7272
raise ValueError("Invalid build_func" + build_func)
7373

74-
self.build_func = build_func
74+
self.build_func = _wrap_build_func(build_func)
7575
self.executor = LocalExecutor(timeout=timeout)
7676
self.tmp_dir = tempfile.mkdtemp()
7777

@@ -349,46 +349,47 @@ def _build_func_common(measure_input, check_gpu=None, cuda_arch=None, build_opti
349349
return func, tuple((get_const_tuple(x.shape), x.dtype) for x in args)
350350

351351

352-
def default_build_func(measure_input, tmp_dir, **kwargs):
352+
def _wrap_build_func(build_func):
353353
"""
354-
Default build func. This can work for cuda, opencl, llvm backend
354+
Wrap build_func to a function that can be used in measure.
355355
356356
Parameters
357357
----------
358-
measure_input: MeasureInput
359-
The input of measurement
360-
tmp_dir: str
361-
The path of temporary directory to export generated library
362-
"""
363-
tic = time.time()
364-
try:
365-
filename = os.path.join(tmp_dir, "tmp_func_%0x.tar" % getrandbits(64))
366-
func, arg_info = _build_func_common(measure_input, **kwargs)
367-
func.export_library(filename)
368-
except Exception as e: # pylint: disable=broad-except
369-
return BuildResult(None, None, e, time.time() - tic)
370-
return BuildResult(filename, arg_info, None, time.time() - tic)
371-
372-
373-
def android_ndk_build_func(measure_input, tmp_dir, **kwargs):
374-
"""
375-
Build function for android device using ndk.
358+
build_func : The compilation function
359+
We expect fcompile to contain an attr "output_format"
376360
377-
Parameters
378-
----------
379-
measure_input: MeasureInput
380-
The input of measurement
381-
tmp_dir: str
382-
The path of temporary directory to export generated library
361+
Returns
362+
-------
363+
wrapped_build_func : function
364+
The wrapped build function
383365
"""
384-
tic = time.time()
385-
try:
386-
filename = os.path.join(tmp_dir, "tmp_func_%0x.so" % getrandbits(64))
387-
func, arg_info = _build_func_common(measure_input, **kwargs)
388-
func.export_library(filename, ndk.create_shared)
389-
except Exception as e: # pylint: disable=broad-except
390-
return BuildResult(None, None, e, time.time() - tic)
391-
return BuildResult(filename, arg_info, None, time.time() - tic)
366+
if not hasattr(build_func, "output_format"):
367+
raise AttributeError("Expect build_func to have the attribute output_format.")
368+
output_format = build_func.output_format
369+
370+
def _wrapped(measure_input, tmp_dir, **kwargs):
371+
"""
372+
Wrapped build func.
373+
374+
Parameters
375+
----------
376+
measure_input: MeasureInput
377+
The input of measurement
378+
379+
tmp_dir: str
380+
The path of temporary directory to export generated library
381+
"""
382+
tic = time.time()
383+
try:
384+
filename = os.path.join(tmp_dir, "tmp_func_%0x.%s" % (
385+
getrandbits(64), output_format))
386+
# TODO(tvm-team) consider linline _build_func_common
387+
func, arg_info = _build_func_common(measure_input, **kwargs)
388+
func.export_library(filename, build_func)
389+
except Exception as e: # pylint: disable=broad-except
390+
return BuildResult(None, None, e, time.time() - tic)
391+
return BuildResult(filename, arg_info, None, time.time() - tic)
392+
return _wrapped
392393

393394

394395
def run_through_rpc(measure_input, build_result,

python/tvm/contrib/cc.py

Lines changed: 33 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,14 +29,46 @@ def create_shared(output,
2929
cc : str, optional
3030
The compile string.
3131
"""
32-
if sys.platform == "darwin" or sys.platform.startswith('linux'):
32+
if sys.platform == "darwin" or sys.platform.startswith("linux"):
3333
_linux_shared(output, objects, options, cc)
3434
elif sys.platform == "win32":
3535
_windows_shared(output, objects, options)
3636
else:
3737
raise ValueError("Unsupported platform")
3838

3939

40+
# assign so as default output format
41+
create_shared.output_format = "so" if sys.platform != "win32" else "dll"
42+
43+
44+
def cross_compiler(cc, options=None, output_format="so"):
45+
"""Create a cross compiler function.
46+
47+
Parameters
48+
----------
49+
cc : str
50+
The cross compiler name.
51+
52+
options : list, optional
53+
List of additional optional string.
54+
55+
output_format : str, optional
56+
Library output format.
57+
58+
Returns
59+
-------
60+
fcompile : function
61+
A compilation function that can be passed to export_library.
62+
"""
63+
def _fcompile(outputs, objects, opts=None):
64+
opts = opts if opts else []
65+
if options:
66+
opts += options
67+
_linux_shared(outputs, objects, opts, cc=cc)
68+
_fcompile.output_format = output_format
69+
return _fcompile
70+
71+
4072
def _linux_shared(output, objects, options, cc="g++"):
4173
cmd = [cc]
4274
cmd += ["-shared", "-fPIC"]

python/tvm/contrib/tar.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,9 @@ def tar(output, files):
4242
msg += py_str(out)
4343
raise RuntimeError(msg)
4444

45+
# assign output format
46+
tar.output_format = "tar"
47+
4548

4649
def untar(tar_file, directory):
4750
"""Unpack all tar files into the directory

python/tvm/contrib/xcode.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -98,6 +98,9 @@ def create_dylib(output, objects, arch, sdk="macosx"):
9898
raise RuntimeError(msg)
9999

100100

101+
# assign so as default output format
102+
create_dylib.output_format = "dylib"
103+
101104
def compile_metal(code, path_target=None, sdk="macosx"):
102105
"""Compile metal with CLI tool from env.
103106

0 commit comments

Comments
 (0)