|
| 1 | +"""The build pipeline in python. |
| 2 | +
|
| 3 | +Eventually some of these pipelines will be moved to C++. |
| 4 | +But the first pipeline will be kept in python for ease of change and evolving. |
| 5 | +""" |
| 6 | +# pylint: disable=invalid-name, no-member, too-many-locals, too-many-arguments |
| 7 | + |
| 8 | +from . import api |
| 9 | +from . import tensor |
| 10 | +from . import schedule |
| 11 | +from . import expr |
| 12 | +from . import ir_pass |
| 13 | +from . import codegen |
| 14 | + |
| 15 | +def build(sch, |
| 16 | + args, |
| 17 | + target, |
| 18 | + name="default_function", |
| 19 | + binds=None, |
| 20 | + record_codes=None): |
| 21 | + """Build a function with arguments as signiture. |
| 22 | +
|
| 23 | + Parameters |
| 24 | + ---------- |
| 25 | + sch : tvm.Schedule |
| 26 | + The schedule to be builded |
| 27 | +
|
| 28 | + args : list of Buffer or Tensor or Var |
| 29 | + The argument lists to the function. |
| 30 | +
|
| 31 | + target : str |
| 32 | + The target of the compilation. |
| 33 | +
|
| 34 | + name : str |
| 35 | + The name of result function. |
| 36 | +
|
| 37 | + binds : dict, optional |
| 38 | + Dictionary that maps the binding of symbolic buffer to Tensor. |
| 39 | + By default, a new buffer is created for each tensor in the argument. |
| 40 | +
|
| 41 | + Returns |
| 42 | + ------- |
| 43 | + f : Function, or pair of functions |
| 44 | + The result function. |
| 45 | + If the function requires host space allocation, |
| 46 | + a pair of functions will be returned. |
| 47 | + """ |
| 48 | + binds = {} if binds is None else binds.copy() |
| 49 | + arg_list = [] |
| 50 | + for x in args: |
| 51 | + if isinstance(x, tensor.Tensor): |
| 52 | + buf = api.Buffer(x.shape, dtype=x.dtype, name=x.op.name) |
| 53 | + assert x not in binds |
| 54 | + binds[x] = buf |
| 55 | + arg_list.append(buf) |
| 56 | + elif isinstance(x, schedule.Buffer): |
| 57 | + arg_list.append(x) |
| 58 | + elif isinstance(x, expr.Var): |
| 59 | + arg_list.append(x) |
| 60 | + else: |
| 61 | + raise ValueError("args must be Tensor, Buffer or Var") |
| 62 | + |
| 63 | + # lowering |
| 64 | + bounds = schedule.InferBound(sch) |
| 65 | + stmt = ir_pass.ScheduleOps(sch, bounds) |
| 66 | + stmt = ir_pass.StorageFlatten(stmt, binds) |
| 67 | + stmt = ir_pass.Simplify(stmt) |
| 68 | + fapi = codegen.MakeAPI(stmt, name, arg_list, len(arg_list)) |
| 69 | + fsplits = codegen.SplitHostDevice(fapi) |
| 70 | + |
| 71 | + if record_codes is not None: |
| 72 | + output_ssa = False |
| 73 | + for i, f in enumerate(fsplits): |
| 74 | + t = target if i >= 1 else "c" |
| 75 | + record_codes.append(codegen.CompileToC(f, output_ssa, t)) |
| 76 | + |
| 77 | + if target == "cuda": |
| 78 | + ret = codegen.BuildNVRTC(fsplits, "stackvm") |
| 79 | + elif target == "opencl": |
| 80 | + ret = codegen.BuildOpenCL(fsplits, "stackvm") |
| 81 | + else: |
| 82 | + raise ValueError("Unknown target %s" % target) |
| 83 | + return ret |
0 commit comments