
Commit d4af7ad

[TEST/PYTHON] Add unittest folder, add a build pipeline. Rename Buffer.ptr to Buffer.data to be consistent with Array. (#29)
1 parent 891630e commit d4af7ad

26 files changed: +155, -28 lines

include/tvm/buffer.h

Lines changed: 2 additions & 2 deletions
@@ -61,7 +61,7 @@ class BufferNode : public Node {
   /*! \brief optional name of the buffer */
   std::string name;
   /*! \brief The pointer to the head of the data */
-  Var ptr;
+  Var data;
   /*! \brief The shape of the buffer */
   Array<Expr> shape;
   /*!
@@ -77,7 +77,7 @@ class BufferNode : public Node {

   void VisitAttrs(AttrVisitor* v) final {
     v->Visit("name", &name);
-    v->Visit("ptr", &ptr);
+    v->Visit("data", &data);
     v->Visit("shape", &shape);
     v->Visit("strides", &strides);
     v->Visit("dtype", &dtype);

python/tvm/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -17,3 +17,4 @@

 from ._base import TVMError
 from .api import *
+from .build import build

python/tvm/api.py

Lines changed: 1 addition & 1 deletion
@@ -145,7 +145,7 @@ def Buffer(shape, dtype=None,
            name="buffer",
            ptr=None,
            strides=None):
-    """Create a new buffer
+    """Create a new symbolic buffer

     Parameters
     ----------
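
For illustration only, here is a minimal Python sketch of how the rename surfaces through the symbolic buffer API; the variable names, the string dtype, and the assumption that a data Var is created automatically when ptr is not supplied are mine, not part of this commit:

import tvm

n = tvm.Var("n")
# Symbolic 1-D buffer; dtype and name here are illustrative assumptions.
buf = tvm.Buffer((n,), dtype="float32", name="A_buf")
# The handle formerly exposed as buf.ptr is now reached as buf.data.
print(buf.data, buf.shape)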

python/tvm/build.py

Lines changed: 83 additions & 0 deletions
@@ -0,0 +1,83 @@
"""The build pipeline in python.

Eventually some of these pipelines will be moved to C++.
But the first pipeline will be kept in python for ease of change and evolving.
"""
# pylint: disable=invalid-name, no-member, too-many-locals, too-many-arguments

from . import api
from . import tensor
from . import schedule
from . import expr
from . import ir_pass
from . import codegen

def build(sch,
          args,
          target,
          name="default_function",
          binds=None,
          record_codes=None):
    """Build a function with the given arguments as its signature.

    Parameters
    ----------
    sch : tvm.Schedule
        The schedule to be built.

    args : list of Buffer or Tensor or Var
        The argument list of the function.

    target : str
        The target of the compilation.

    name : str
        The name of the result function.

    binds : dict, optional
        Dictionary mapping a Tensor to the symbolic Buffer it is bound to.
        By default, a new buffer is created for each tensor in the arguments.

    Returns
    -------
    f : Function, or pair of functions
        The result function.
        If the function requires host space allocation,
        a pair of functions will be returned.
    """
    binds = {} if binds is None else binds.copy()
    arg_list = []
    for x in args:
        if isinstance(x, tensor.Tensor):
            buf = api.Buffer(x.shape, dtype=x.dtype, name=x.op.name)
            assert x not in binds
            binds[x] = buf
            arg_list.append(buf)
        elif isinstance(x, schedule.Buffer):
            arg_list.append(x)
        elif isinstance(x, expr.Var):
            arg_list.append(x)
        else:
            raise ValueError("args must be Tensor, Buffer or Var")

    # lowering
    bounds = schedule.InferBound(sch)
    stmt = ir_pass.ScheduleOps(sch, bounds)
    stmt = ir_pass.StorageFlatten(stmt, binds)
    stmt = ir_pass.Simplify(stmt)
    fapi = codegen.MakeAPI(stmt, name, arg_list, len(arg_list))
    fsplits = codegen.SplitHostDevice(fapi)

    if record_codes is not None:
        output_ssa = False
        for i, f in enumerate(fsplits):
            t = target if i >= 1 else "c"
            record_codes.append(codegen.CompileToC(f, output_ssa, t))

    if target == "cuda":
        ret = codegen.BuildNVRTC(fsplits, "stackvm")
    elif target == "opencl":
        ret = codegen.BuildOpenCL(fsplits, "stackvm")
    else:
        raise ValueError("Unknown target %s" % target)
    return ret
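
As a hedged usage sketch (mirroring the unit test added at the bottom of this commit), assuming a schedule s over an output tensor C computed from placeholders A and B, the whole lowering and code generation pipeline becomes a single call:

codes = []                      # generated source is appended here
fadd = tvm.build(s, args=[A, B, C],
                 target="cuda", name="myadd",
                 record_codes=codes)
# codes[0] holds the host-side "c" code; later entries hold device code.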

python/tvm/collections.py

Lines changed: 0 additions & 6 deletions
@@ -58,12 +58,6 @@ class IterVar(NodeBase, _expr.ExprOp):
     pass


-@register_node
-class Buffer(NodeBase):
-    """Represent a Buffer in TVM."""
-    pass
-
-
 @register_node
 class LoweredFunc(NodeBase):
     """Represent a LoweredFunc in TVM."""

python/tvm/schedule.py

Lines changed: 5 additions & 0 deletions
@@ -5,6 +5,11 @@
 from . import _api_internal
 from . import tensor as _tensor

+@register_node
+class Buffer(NodeBase):
+    """Represent a Buffer in TVM."""
+    pass
+
 @register_node
 class Split(NodeBase):
     """Split operation on axis."""

src/codegen/make_api.cc

Lines changed: 2 additions & 2 deletions
@@ -138,9 +138,9 @@ LoweredFunc MakeAPI(Stmt body,
                     UIntImm::make(UInt(16), dtype.lanes()));
     seq_init.emplace_back(AssertStmt::make(cond, type_err_msg.str()));
     // Data Field
-    if (f_push(buf->ptr, TVMArrayGet(Handle(), v_arg, intrinsic::kData),
+    if (f_push(buf->data, TVMArrayGet(Handle(), v_arg, intrinsic::kData),
                v_arg->name_hint + ".data")) {
-      Var vptr(buf->ptr);
+      Var vptr(buf->data);
       handle_data_type.Set(vptr, make_const(buf->dtype, 0));
     }
     // shape field

src/lang/buffer.cc

Lines changed: 4 additions & 4 deletions
@@ -45,23 +45,23 @@ inline Expr BufferOffset(const BufferNode* n, Array<Expr> index) {

 Expr Buffer::MakeLoad(Array<Expr> index) const {
   const BufferNode* n = operator->();
-  return ir::Load::make(n->dtype, n->ptr, BufferOffset(n, index));
+  return ir::Load::make(n->dtype, n->data, BufferOffset(n, index));
 }

 Stmt Buffer::MakeStore(Array<Expr> index, Expr value) const {
   const BufferNode* n = operator->();
   CHECK_EQ(value.type(), n->dtype);
-  return ir::Store::make(n->ptr, value, BufferOffset(n, index));
+  return ir::Store::make(n->data, value, BufferOffset(n, index));
 }

 Buffer BufferNode::make(std::string name,
-                        Var ptr,
+                        Var data,
                         Array<Expr> shape,
                         Array<Expr> strides,
                         Type dtype) {
   auto n = std::make_shared<BufferNode>();
   n->name = name;
-  n->ptr = ptr;
+  n->data = data;
   n->shape = shape;
   n->strides = strides;
   n->dtype = dtype;

src/pass/storage_flatten.cc

Lines changed: 1 addition & 1 deletion
@@ -138,7 +138,7 @@ class StorageFlattener : public IRMutator {
       buf_map_[key].released = true;

       return Allocate::make(
-          e.buffer->ptr, e.buffer->dtype, e.buffer->shape,
+          e.buffer->data, e.buffer->dtype, e.buffer->shape,
           make_const(Bool(e.buffer->dtype.lanes()), true), body);
     }
   }
Lines changed: 44 additions & 0 deletions
@@ -0,0 +1,44 @@
import tvm
import numpy as np

def test_add():
    # graph
    n = tvm.Var('n')
    A = tvm.placeholder((n,), name='A')
    B = tvm.placeholder((n,), name='B')
    C = tvm.compute(A.shape, lambda *i: A(*i) + B(*i), name='C')
    # schedule
    s = tvm.Schedule(C.op)
    # create iter vars and assign them thread tags.
    num_thread = 256
    block_x = tvm.IterVar(thread_tag="blockIdx.x")
    thread_x = tvm.IterVar((0, num_thread), thread_tag="threadIdx.x")
    _, x = s[C].split(C.op.axis[0], factor=num_thread, outer=block_x)
    _, x = s[C].split(x, outer=thread_x)

    # one line to build the function.
    codes = []
    fadd = tvm.build(s, args=[A, B, C],
                     target="cuda", name="myadd",
                     record_codes=codes)
    for c in codes:
        print(c)

    # call the function
    num_device = 1
    for i in range(num_device):
        ctx = tvm.gpu(i)
        if not ctx.enabled:
            continue
        # launch the kernel.
        n = 1027
        a = tvm.nd.array(np.random.uniform(size=n).astype(A.dtype), ctx)
        b = tvm.nd.array(np.random.uniform(size=n).astype(B.dtype), ctx)
        c = tvm.nd.array(np.zeros(n, dtype=C.dtype), ctx)
        fadd(a, b, c)
        np.testing.assert_allclose(
            c.asnumpy(), a.asnumpy() + b.asnumpy())


if __name__ == "__main__":
    test_add()
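
Two details worth noting: the kernel launch is guarded by ctx.enabled, so the test is silently skipped when no CUDA device is available, and the vector length 1027 is deliberately not a multiple of the 256-thread split factor, so the generated kernel's boundary handling should be exercised as well.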
