|
19 | 19 |
|
20 | 20 | from ... import ir_pass, build, build_config, nd, TVMError, register_func, \ |
21 | 21 | rpc as _rpc, target as _target |
22 | | -from ...contrib import nvcc, ndk |
| 22 | +from ...contrib import nvcc, ndk, tar |
23 | 23 |
|
24 | 24 | from ..util import get_const_tuple |
25 | 25 | from ..env import AutotvmGlobalScope |
@@ -58,20 +58,20 @@ class LocalBuilder(Builder): |
58 | 58 | build_func: callable or str |
59 | 59 | If is 'default', use default build function |
60 | 60 | If is 'ndk', use function for android ndk |
61 | | - If is callable, use it as custom build function |
| 61 | + If is callable, use it as custom build function, expect lib_format field. |
62 | 62 | """ |
63 | 63 | def __init__(self, timeout=10, n_parallel=None, build_func='default'): |
64 | 64 | super(LocalBuilder, self).__init__(timeout, n_parallel) |
65 | 65 |
|
66 | 66 | if isinstance(build_func, str): |
67 | 67 | if build_func == 'default': |
68 | | - build_func = default_build_func |
| 68 | + build_func = tar.tar |
69 | 69 | elif build_func == 'ndk': |
70 | | - build_func = android_ndk_build_func |
| 70 | + build_func = ndk.create_shared |
71 | 71 | else: |
72 | 72 | raise ValueError("Invalid build_func" + build_func) |
73 | 73 |
|
74 | | - self.build_func = build_func |
| 74 | + self.build_func = _wrap_build_func(build_func) |
75 | 75 | self.executor = LocalExecutor(timeout=timeout) |
76 | 76 | self.tmp_dir = tempfile.mkdtemp() |
77 | 77 |
|
@@ -349,46 +349,47 @@ def _build_func_common(measure_input, check_gpu=None, cuda_arch=None, build_opti |
349 | 349 | return func, tuple((get_const_tuple(x.shape), x.dtype) for x in args) |
350 | 350 |
|
351 | 351 |
|
352 | | -def default_build_func(measure_input, tmp_dir, **kwargs): |
| 352 | +def _wrap_build_func(build_func): |
353 | 353 | """ |
354 | | - Default build func. This can work for cuda, opencl, llvm backend |
| 354 | + Wrap build_func to a function that can be used in measure. |
355 | 355 |
|
356 | 356 | Parameters |
357 | 357 | ---------- |
358 | | - measure_input: MeasureInput |
359 | | - The input of measurement |
360 | | - tmp_dir: str |
361 | | - The path of temporary directory to export generated library |
362 | | - """ |
363 | | - tic = time.time() |
364 | | - try: |
365 | | - filename = os.path.join(tmp_dir, "tmp_func_%0x.tar" % getrandbits(64)) |
366 | | - func, arg_info = _build_func_common(measure_input, **kwargs) |
367 | | - func.export_library(filename) |
368 | | - except Exception as e: # pylint: disable=broad-except |
369 | | - return BuildResult(None, None, e, time.time() - tic) |
370 | | - return BuildResult(filename, arg_info, None, time.time() - tic) |
371 | | - |
372 | | - |
373 | | -def android_ndk_build_func(measure_input, tmp_dir, **kwargs): |
374 | | - """ |
375 | | - Build function for android device using ndk. |
| 358 | + build_func : The compilation function |
| 359 | + We expect fcompile to contain an attr "output_format" |
376 | 360 |
|
377 | | - Parameters |
378 | | - ---------- |
379 | | - measure_input: MeasureInput |
380 | | - The input of measurement |
381 | | - tmp_dir: str |
382 | | - The path of temporary directory to export generated library |
| 361 | + Returns |
| 362 | + ------- |
| 363 | + wrapped_build_func : function |
| 364 | + The wrapped build function |
383 | 365 | """ |
384 | | - tic = time.time() |
385 | | - try: |
386 | | - filename = os.path.join(tmp_dir, "tmp_func_%0x.so" % getrandbits(64)) |
387 | | - func, arg_info = _build_func_common(measure_input, **kwargs) |
388 | | - func.export_library(filename, ndk.create_shared) |
389 | | - except Exception as e: # pylint: disable=broad-except |
390 | | - return BuildResult(None, None, e, time.time() - tic) |
391 | | - return BuildResult(filename, arg_info, None, time.time() - tic) |
| 366 | + if not hasattr(build_func, "output_format"): |
| 367 | + raise AttributeError("Expect build_func to have the attribute output_format.") |
| 368 | + output_format = build_func.output_format |
| 369 | + |
| 370 | + def _wrapped(measure_input, tmp_dir, **kwargs): |
| 371 | + """ |
| 372 | + Wrapped build func. |
| 373 | +
|
| 374 | + Parameters |
| 375 | + ---------- |
| 376 | + measure_input: MeasureInput |
| 377 | + The input of measurement |
| 378 | +
|
| 379 | + tmp_dir: str |
| 380 | + The path of temporary directory to export generated library |
| 381 | + """ |
| 382 | + tic = time.time() |
| 383 | + try: |
| 384 | + filename = os.path.join(tmp_dir, "tmp_func_%0x.%s" % ( |
| 385 | + getrandbits(64), output_format)) |
| 386 | + # TODO(tvm-team) consider linline _build_func_common |
| 387 | + func, arg_info = _build_func_common(measure_input, **kwargs) |
| 388 | + func.export_library(filename, build_func) |
| 389 | + except Exception as e: # pylint: disable=broad-except |
| 390 | + return BuildResult(None, None, e, time.time() - tic) |
| 391 | + return BuildResult(filename, arg_info, None, time.time() - tic) |
| 392 | + return _wrapped |
392 | 393 |
|
393 | 394 |
|
394 | 395 | def run_through_rpc(measure_input, build_result, |
|
0 commit comments