Skip to content

Commit cf0134d

Browse files
author
ZihengJiang
committed
[TensorOp] Add testcase for scheduling tensor_compute_op.
1 parent 753cb03 commit cf0134d

File tree

9 files changed

+344
-316
lines changed

9 files changed

+344
-316
lines changed

include/tvm/operation.h

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -186,6 +186,8 @@ class TensorComputeOpNode : public OperationNode {
186186
public:
187187
Array<IterVar> axis;
188188

189+
Array<IterVar> out_axis;
190+
189191
Array<IterVar> tensor_axis;
190192

191193
Array<IterVar> reduce_axis;
@@ -229,20 +231,21 @@ class TensorComputeOpNode : public OperationNode {
229231
v->Visit("name", &name);
230232
v->Visit("tag", &tag);
231233
v->Visit("axis", &axis);
234+
v->Visit("out_axis", &out_axis);
232235
v->Visit("tensor_axis", &tensor_axis);
233236
v->Visit("reduce_axis", &reduce_axis);
234237
v->Visit("inputs", &inputs);
235238
}
236239

237240
static Operation make(std::string name,
238241
std::string tag,
239-
Array<IterVar> axis,
242+
Array<IterVar> out_axis,
240243
Array<IterVar> tensor_axis,
241244
TensorIntrinCall intrin_call);
242245

243246
static Operation make(std::string name,
244247
std::string tag,
245-
Array<IterVar> axis,
248+
Array<IterVar> out_axis,
246249
Array<IterVar> tensor_axis,
247250
Array<IterVar> reduce_axis,
248251
Array<Tensor> tensors,

include/tvm/tensor_intrin.h

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -26,14 +26,6 @@ class TensorIntrin : public NodeRef {
2626
*/
2727
inline const TensorIntrinNode* operator->() const;
2828

29-
// template<typename... Args>
30-
// inline Stmt operator()(Args&& ...args) const {
31-
// Array<Expr> inputs{std::forward<Args>(args)...};
32-
// return operator()(inputs);
33-
// }
34-
35-
// TVM_DLL TensorIntrinCall operator()(Array<Expr> inputs) const;
36-
3729
/*! \brief specify container node */
3830
using ContainerType = TensorIntrinNode;
3931
};

python/tvm/api.py

Lines changed: 4 additions & 87 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33
from __future__ import absolute_import as _abs
44

55
from numbers import Integral as _Integral
6-
from collections import namedtuple
76

87
from ._ffi.base import string_types
98
from ._ffi.node import register_node, NodeBase
@@ -244,6 +243,8 @@ def compute(shape, fcompute, name="compute", tag="", attrs=None):
244243
raise ValueError("nested tag is not allowed for now")
245244
tag = _tag.TagScope.current.tag
246245
shape = (shape,) if isinstance(shape, _expr.Expr) else shape
246+
# for python3
247+
shape = tuple([int(s) if isinstance(s, float) else s for s in shape])
247248
ndim = len(shape)
248249
code = fcompute.__code__
249250

@@ -254,7 +255,6 @@ def compute(shape, fcompute, name="compute", tag="", attrs=None):
254255
arg_names = code.co_varnames[:code.co_argcount]
255256
out_ndim = code.co_argcount
256257

257-
# TODO check ndim, arg_names
258258
if out_ndim != len(arg_names):
259259
raise ValueError("fcompute do not match dimension, ndim=%d" % ndim)
260260

@@ -264,8 +264,8 @@ def compute(shape, fcompute, name="compute", tag="", attrs=None):
264264
if isinstance(body, _tensor.TensorIntrinCall):
265265
tensor_var = []
266266
for i, s in enumerate(shape[out_ndim:]):
267-
name = "ax" + str(i)
268-
tensor_var.append(_IterVar((0, s), name, 4))
267+
var_name = "ax" + str(i)
268+
tensor_var.append(_IterVar((0, s), var_name, 4))
269269
op_node = _api_internal._TensorComputeOp(name,
270270
tag,
271271
dim_var,
@@ -275,7 +275,6 @@ def compute(shape, fcompute, name="compute", tag="", attrs=None):
275275
if not isinstance(body, (list, tuple)):
276276
body = [body]
277277
body = convert(body)
278-
# print('body: {0}'.format(body))
279278
op_node = _api_internal._ComputeOp(
280279
name, tag, attrs, dim_var, body)
281280

@@ -353,88 +352,6 @@ def scan(init, update, state_placeholder, inputs=None, name="scan", tag="", attr
353352
return res[0] if len(res) == 1 else res
354353

355354

356-
def _get_region(tslice):
357-
region = []
358-
for idx in tslice.indices:
359-
if isinstance(idx, slice):
360-
assert idx.step is None
361-
region.append(Range(idx.start, idx.stop))
362-
else:
363-
if isinstance(idx, _schedule.IterVar):
364-
begin = idx.var
365-
else:
366-
begin = idx
367-
region.append(_make.range_by_min_extent(begin, 1))
368-
return region
369-
370-
371-
# def tensor_op(out_dims,
372-
# in_dims, # pylint: disable=unused-argument
373-
# finputs,
374-
# intrin,
375-
# raxis=None,
376-
# name='tensor_op',
377-
# tag=""):
378-
# """Construct new tensors with intrinsic.
379-
#
380-
# Parameters
381-
# ----------
382-
# out_dims: tuple
383-
# The dimensions out of the tensorized region, which can be
384-
# scheduled through `reorder`, `split`.
385-
#
386-
# in_dims: tuple
387-
# The dimensions inside of the tensorized region, which cannot
388-
# be manipulated.
389-
#
390-
# finputs: lambda function of out_dims -> list of TensorSlice
391-
# Specifies involved regions of input tensors.
392-
#
393-
# tensor_intrin : TensorIntrin
394-
# The tensor intrinsic used for computation.
395-
#
396-
# raxis : IterVar
397-
# An iteration variable representing the value.
398-
#
399-
# name: str, optional
400-
# The name hint of the tensor
401-
#
402-
# tag: str, optional
403-
# Additonal tag information about the compute.
404-
# """
405-
# if _tag.TagScope.current is not None:
406-
# if tag != "":
407-
# raise ValueError("nested tag is not allowed for now")
408-
# tag = _tag.TagScope.current.tag
409-
#
410-
# code = finputs.__code__
411-
# if finputs.__code__.co_argcount == 0:
412-
# arg_names = ["i%d" % i for i in range(ndim)]
413-
# else:
414-
# arg_names = code.co_varnames[:code.co_argcount]
415-
#
416-
# if len(out_dims) != len(arg_names):
417-
# raise ValueError("finputs do not match dimension, ndim=%d" % out_dims)
418-
#
419-
# out_var = [_IterVar((0, extent), arg_name, 0)
420-
# for arg_name, extent in zip(arg_names, out_dims)]
421-
# if isinstance(raxis, _schedule.IterVar):
422-
# raxis = [raxis]
423-
# if raxis is None:
424-
# raxis = []
425-
# tensor_regions = finputs(*[v.var for v in out_var])
426-
#
427-
# op = _api_internal._TensorOp(name,
428-
# tag,
429-
# out_var,
430-
# raxis,
431-
# [x.tensor for x in tensor_regions],
432-
# [_get_region(x) for x in tensor_regions],
433-
# intrin)
434-
# # only support single output
435-
# return op.output(0)
436-
437-
438355
def extern(shape,
439356
inputs,
440357
fcompute,

python/tvm/tensor.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -161,6 +161,12 @@ def reduce_axis(self):
161161
return self.__getattr__("reduce_axis")
162162

163163

164+
@register_node
165+
class TensorComputeOp(Operation):
166+
"""Tensor operation."""
167+
pass
168+
169+
164170
@register_node
165171
class ScanOp(Operation):
166172
"""Scan operation."""
@@ -174,9 +180,3 @@ def scan_axis(self):
174180
class ExternOp(Operation):
175181
"""Extern operation."""
176182
pass
177-
178-
179-
@register_node
180-
class TensorComputeOp(Operation):
181-
"""Tensor operation."""
182-
pass

python/tvm/tensor_intrin.py

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,25 @@
66
from . import stmt as _stmt
77
from . import make as _make
88
from . import tensor as _tensor
9+
from . import schedule as _schedule
910
from .build_module import current_build_config
1011
from ._ffi.node import NodeBase, register_node
1112

13+
14+
def _get_region(tslice):
15+
region = []
16+
for idx in tslice.indices:
17+
if isinstance(idx, slice):
18+
assert idx.step is None
19+
region.append(_api.Range(idx.start, idx.stop))
20+
else:
21+
if isinstance(idx, _schedule.IterVar):
22+
begin = idx.var
23+
else:
24+
begin = idx
25+
region.append(_make.range_by_min_extent(begin, 1))
26+
return region
27+
1228
@register_node
1329
class TensorIntrin(NodeBase):
1430
"""Tensor intrinsic functions for certain computation.
@@ -19,7 +35,7 @@ class TensorIntrin(NodeBase):
1935
"""
2036
def __call__(self, *args, **kwargs):
2137
tensors = [x.tensor for x in args]
22-
regions = [_api._get_region(x) for x in args]
38+
regions = [_get_region(x) for x in args]
2339
reduce_axis = []
2440
if "reduce_axis" in kwargs:
2541
reduce_axis = kwargs["reduce_axis"]

src/lang/tensor.cc

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -37,12 +37,6 @@ Tensor Operation::output(size_t i) const {
3737
return Tensor(node);
3838
}
3939

40-
// TensorIntrinCall TensorIntrin::operator()(Array<Expr> inputs) const {
41-
// using HalideIR::Internal::Call;
42-
// LOG(FATAL) << "CallTensorIntrin";
43-
// CHECK_EQ(tensors.size(), regions.size());
44-
// }
45-
4640
Tensor TensorNode::make(Array<Expr> shape,
4741
Type dtype,
4842
Operation op,

0 commit comments

Comments
 (0)