
Commit 81fd125

[TOP] split, reshape, concatenate (#43)
1 parent d34036f

File tree: 7 files changed (+118, -39 lines)

7 files changed

+118
-39
lines changed

nnvm/README.md

Lines changed: 8 additions & 11 deletions
@@ -3,7 +3,7 @@
 [![Build Status](https://travis-ci.org/dmlc/nnvm.svg?branch=master)](https://travis-ci.org/dmlc/nnvm)
 [![GitHub license](http://dmlc.github.io/img/apache2.svg)](./LICENSE)

-NNVM is a reusable computational graph optimization and compilation stack for deep learning systems. It provides modules to:
+NNVM is a reusable computational graph compilation stack for deep learning systems. It provides modules to:

 - Represent deep learning workloads from front-end frameworks via a graph IR.
 - Optimize computation graphs to improve performance.

@@ -20,26 +20,23 @@ from tvm.contrib import graph_runtime, rpc
 import nnvm.frontend
 import nnvm.compiler

-# get model from frameworks
+# GET model from frameworks
 # change xyz to supported framework name.
 graph, params = nnvm.frontend.from_xyz(...)

-# optimize and compile the graph to get a deployable module
+# OPTIMIZE and COMPILE the graph to get a deployable module
 # target can be "opencl", "llvm", "metal" or any target supported by tvm
 target = "cuda"
-graph, lib, params = nnvm.compiler.build(
-    graph, target, shape={"data": data_shape}, params=params)
+graph, lib, params = nnvm.compiler.build(graph, target, {"data": data_shape}, params=params)

-# deploy and run on gpu(0)
+# DEPLOY and run on gpu(0)
 module = graph_runtime.create(graph, lib, tvm.gpu(0))
 module.set_input(**params)
+module.run(data=data_array)
 output = tvm.nd.empty(out_shape, ctx=tvm.gpu(0))
-for data_array in dataset:
-    module.set_input("data", data_array)
-    module.run()
-    module.get_output(0, output)
+module.get_output(0, output)

-# deploy to remote mobile/rasp/browser with minimum tvm rpc runtime
+# DEPLOY to REMOTE mobile/rasp/browser with minimum tvm rpc runtime
 # useful for quick experiments on mobile devices
 remote = rpc.connect(remote_host, remote_port)
 lib.export_library("mylib.so")
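
Note: assembled end to end, the revised README example reads roughly as follows. This is a sketch only; from_xyz is the README's placeholder for a real frontend loader, and data_shape, out_shape, and data_array must be supplied by the caller.

    import tvm
    import nnvm.frontend
    import nnvm.compiler
    from tvm.contrib import graph_runtime

    # load a model through one of the supported frontends
    graph, params = nnvm.frontend.from_xyz(...)  # placeholder, not a real API

    # compile for a tvm target ("cuda", "opencl", "llvm", "metal", ...)
    graph, lib, params = nnvm.compiler.build(
        graph, "cuda", {"data": data_shape}, params=params)

    # run on gpu(0): set_input/run/get_output replace the old explicit loop
    module = graph_runtime.create(graph, lib, tvm.gpu(0))
    module.set_input(**params)
    module.run(data=data_array)
    output = tvm.nd.empty(out_shape, ctx=tvm.gpu(0))
    module.get_output(0, output)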

nnvm/python/nnvm/compiler/param_dict.py

Lines changed: 1 addition & 0 deletions
@@ -1,3 +1,4 @@
+# pylint: disable=invalid-name
 """Helper utility to save parameter dict"""
 import tvm

nnvm/python/nnvm/top/transform.py

Lines changed: 30 additions & 19 deletions
@@ -2,13 +2,12 @@
 """Tensor transformation ops"""
 from __future__ import absolute_import

-import tvm
 import topi
 from .tensor import _fschedule_broadcast, _fschedule_injective
 from . import registry as reg
 from .registry import OpPattern

-# Need add reshape
+# expand_dims
 @reg.register_compute("expand_dims")
 def compute_expand_dims(attrs, inputs, out_info):
     """Compute definition of expand_dims"""
@@ -18,34 +17,46 @@ def compute_expand_dims(attrs, inputs, out_info):
 reg.register_pattern("expand_dims", OpPattern.BROADCAST)
 reg.register_schedule("expand_dims", _fschedule_broadcast)

-
+# transpose
 @reg.register_compute("transpose")
 def compute_transpose(attrs, inputs, out_info):
-    """Compute definition of expand_dims"""
+    """Compute definition of transpose"""
     axes = attrs.get_int_tuple("axes")
     axes = tuple(axes) if axes else None
     return topi.transpose(inputs[0], axes)
 reg.register_pattern("transpose", OpPattern.INJECTIVE)
 reg.register_schedule("transpose", _fschedule_injective)

-
-def _flatten_index(indices, shape):
-    """flatten the index to 1D"""
-    idx = 0
-    for i, value in enumerate(shape):
-        if i != 0:
-            idx *= value
-        idx = idx + indices[i]
-    return idx
-
 # reshape
 @reg.register_compute("reshape")
 def compute_reshape(attrs, inputs, out_info):
-    """Compute definition of softmax"""
-    # TODO(sxj) add support for general reshape
-    assert len(inputs[0].shape) == 1, "Only support 1d input for now"
+    """Compute definition of reshape"""
     oshape = out_info[0].shape
-    x = inputs[0]
-    return tvm.compute(oshape, lambda *i: x(_flatten_index(i, oshape)))
+    return topi.reshape(inputs[0], oshape)
 reg.register_pattern("reshape", OpPattern.INJECTIVE)
 reg.register_schedule("reshape", _fschedule_injective)
+
+# concatenate
+@reg.register_compute("concatenate")
+def compute_concatenate(attrs, inputs, out_info):
+    """Compute definition of concatenate"""
+    axis = attrs.get_int("axis")
+    return topi.concatenate([x for x in inputs], axis=axis)
+
+reg.register_pattern("concatenate", OpPattern.INJECTIVE)
+reg.register_schedule("concatenate", _fschedule_injective)
+
+# split
+@reg.register_compute("split")
+def compute_split(attrs, inputs, out_info):
+    """Compute definition of split"""
+    x = attrs["indices_or_sections"]
+    if x.startswith("(") or x.startswith("["):
+        indices_or_sections = attrs.get_int_tuple("indices_or_sections")
+    else:
+        indices_or_sections = attrs.get_int("indices_or_sections")
+    return topi.split(inputs[0], indices_or_sections, axis=attrs.get_int("axis"))
+
+
+reg.register_pattern("split", OpPattern.INJECTIVE)
+reg.register_schedule("split", _fschedule_injective)
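
Note: the string test in compute_split exists because NNVM attribute values arrive as strings: a tuple-valued indices_or_sections serializes as "(3, 4)" while a section count serializes as "2". A minimal standalone sketch of that dispatch, using a hypothetical parse_indices_or_sections helper (not part of NNVM):

    def parse_indices_or_sections(raw):
        # raw is the attribute's string form, e.g. "2", "(3, 4)" or "[3, 4]"
        if raw.startswith("(") or raw.startswith("["):
            # tuple/list form: ascending cut points along the axis
            return tuple(int(v) for v in raw.strip("()[]").split(",") if v.strip())
        # scalar form: number of equal sections
        return int(raw)

    assert parse_indices_or_sections("2") == 2
    assert parse_indices_or_sections("(3, 4)") == (3, 4)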

nnvm/src/compiler/graph_runtime.h

Lines changed: 1 addition & 0 deletions
@@ -8,6 +8,7 @@

 #include <nnvm/graph.h>
 #include <vector>
+#include <string>
 #include "../../tvm/src/runtime/graph/graph_runtime.h"

 namespace nnvm {

nnvm/src/top/tensor/transform.cc

Lines changed: 11 additions & 9 deletions
@@ -225,15 +225,17 @@ inline bool SplitInferShape(const NodeAttrs& attrs,
    CHECK_EQ(out_shape->size(), static_cast<size_t>(num_outputs));
    CHECK_LT(param.axis, dshape.ndim());
    TShape oshape = dshape;
-   dim_t total = 0;
-   for (size_t i = 1; i < num_outputs; ++i) {
-     oshape[param.axis] = param.indices_or_sections[i - 1];
-     total += oshape[param.axis];
-     NNVM_ASSIGN_OUTPUT_SHAPE(attrs, *out_shape, i - 1, oshape);
+   dim_t begin = 0;
+   for (size_t i = 0; i < num_outputs - 1; ++i) {
+     CHECK_GT(param.indices_or_sections[i], begin)
+         << "indices_or_sections need to be a sorted ascending list";
+     oshape[param.axis] = param.indices_or_sections[i] - begin;
+     begin = param.indices_or_sections[i];
+     NNVM_ASSIGN_OUTPUT_SHAPE(attrs, *out_shape, i, oshape);
    }
-   CHECK_LT(total, dshape[param.axis])
+   CHECK_LT(begin, dshape[param.axis])
        << "The sum of sections must match the input.shape[axis]";
-   oshape[param.axis] = dshape[param.axis] - total;
+   oshape[param.axis] = dshape[param.axis] - begin;
    NNVM_ASSIGN_OUTPUT_SHAPE(attrs, *out_shape, num_outputs - 1, oshape);
  }
  return true;

@@ -256,11 +258,11 @@ NNVM_REGISTER_OP(split)
 along which to split the array.

 )code" NNVM_ADD_FILELINE)
-.add_argument("data", "Tensor", "List of arrays to concatenate")
+.add_argument("data", "Tensor", "Array to be split")
 .add_arguments(SplitParam::__FIELDS__())
 .set_attr_parser(SplitParamParser)
 .set_attr<FInferShape>("FInferShape", SplitInferShape)
-.set_attr<FInferType>("FInferType", ElemwiseType<-1, 1>)
+.set_attr<FInferType>("FInferType", ElemwiseType<1, -1>)
 .set_num_inputs(1)
 .set_num_outputs(SplitNumOutputs)
 .set_support_level(1);
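
Note: the rewritten SplitInferShape changes the meaning of indices_or_sections from per-output section sizes to ascending cut indices, matching np.split. A hypothetical Python mirror of the new shape arithmetic, for illustration only:

    def split_output_sizes(axis_len, indices):
        # indices are ascending cut points along the split axis (np.split style)
        sizes, begin = [], 0
        for idx in indices:
            assert idx > begin, "indices_or_sections need to be a sorted ascending list"
            sizes.append(idx - begin)
            begin = idx
        assert begin < axis_len, "cut points must leave a non-empty last section"
        sizes.append(axis_len - begin)
        return sizes

    # an axis of length 9 split at (3, 4) yields sections of size 3, 1 and 5,
    # matching np.split(np.zeros(9), [3, 4], axis=0)
    assert split_output_sizes(9, (3, 4)) == [3, 1, 5]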

nnvm/tests/python/compiler/test_top_level1.py

Lines changed: 45 additions & 0 deletions
@@ -177,7 +177,52 @@ def test_batchnorm():
         res.asnumpy(), res_np, atol=1e-5, rtol=1e-5)


+def verify_concatenate(ishape, axis):
+    x = [sym.Variable("x%d" % i) for i in range(len(ishape))]
+    y = sym.concatenate(*x, axis=axis) + 1
+    dtype = "float32"
+    for target, ctx in ctx_list():
+        # set input
+        data = []
+        for i, shape in enumerate(ishape):
+            data.append(np.random.uniform(size=shape).astype(dtype))
+        pdict = {"x%d" % i : v for i, v in enumerate(data)}
+        shape = {"x%d" % i : v.shape for i, v in enumerate(data)}
+        graph, lib, _ = nnvm.compiler.build(y, target, shape)
+        m = graph_runtime.create(graph, lib, ctx)
+        m.run(**pdict)
+        out_np = np.concatenate(data, axis=axis) + 1
+        out = m.get_output(0, tvm.nd.empty(out_np.shape))
+        np.testing.assert_allclose(out.asnumpy(), out_np, atol=1e-5, rtol=1e-5)
+
+def test_concatenate():
+    verify_concatenate([(2, 3, 4), (1, 3, 4)], axis=0)
+    verify_concatenate([(2, 4), (2, 7)], axis=1)
+
+
+def verify_split(ishape, indices_or_sections, axis):
+    x = sym.Variable("x")
+    y = sym.split(x, indices_or_sections=indices_or_sections, axis=axis)
+    dtype = "float32"
+    x_np = np.random.uniform(size=ishape).astype(dtype)
+    res = np.split(x_np, indices_or_sections, axis=axis)
+    for target, ctx in ctx_list():
+        # set input
+        graph, lib, _ = nnvm.compiler.build(y, target, {"x": ishape})
+        m = graph_runtime.create(graph, lib, ctx)
+        m.run(x=x_np)
+        for i, arr in enumerate(res):
+            out = m.get_output(i, tvm.nd.empty(arr.shape))
+            np.testing.assert_allclose(out.asnumpy(), arr, atol=1e-5, rtol=1e-5)
+
+def test_split():
+    verify_split((2, 3), 2, axis=0)
+    verify_split((5, 3), [3], axis=0)
+    verify_split((5, 9, 3), [3, 4], axis=1)
+
 if __name__ == "__main__":
+    test_split()
+    test_concatenate()
     test_log_softmax()
     test_batchnorm()
     test_dense()
nnvm/tests/python/compiler/test_top_level4.py

Lines changed: 22 additions & 0 deletions
@@ -51,7 +51,29 @@ def test_reduce():
     verify_reduce((4, 4, 3), np.sum, sym.sum, axis=(0, 2))


+def verify_reshape(dshape, oshape):
+    x = sym.Variable("x")
+    y = sym.reshape(x, shape=oshape)
+    y = y + 1
+    dtype = "float32"
+    for target, ctx in ctx_list():
+        graph, lib, _ = nnvm.compiler.build(y, target, {"x": dshape})
+        m = graph_runtime.create(graph, lib, ctx)
+        # set input
+        data = tvm.nd.array(np.random.uniform(size=dshape).astype(dtype))
+        m.run(x=data)
+        out_np = data.asnumpy().reshape(oshape) + 1
+        out = m.get_output(0, tvm.nd.empty(out_np.shape))
+        np.testing.assert_allclose(out.asnumpy(), out_np, atol=1e-5, rtol=1e-5)
+
+def test_reshape():
+    verify_reshape((2, 3, 4), (-1, 2, 1))
+    verify_reshape((2, 3, 4), (8, 3))
+    verify_reshape((4, 7), (2, 7, 2))
+
+
 if __name__ == "__main__":
+    test_reshape()
     test_reduce()
     test_tranpose()
     print(nnvm.compiler.engine.dump())
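
Note: the case verify_reshape((2, 3, 4), (-1, 2, 1)) relies on the usual -1 wildcard: one output dimension is inferred so that element counts match. A small sketch of that rule (an illustrative helper, not NNVM's actual inference code):

    def resolve_reshape(ishape, oshape):
        # replace a single -1 in oshape so the element counts agree
        total = 1
        for d in ishape:
            total *= d
        known = 1
        for d in oshape:
            if d != -1:
                known *= d
        return tuple(total // known if d == -1 else d for d in oshape)

    # (2, 3, 4) has 24 elements; the known dims give 2 * 1 = 2, so -1 -> 12
    assert resolve_reshape((2, 3, 4), (-1, 2, 1)) == (12, 2, 1)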
