
Commit 4b14ad4

Wei Chen (were) authored and committed
[Hybrid Script] Unify the symbol tables to one; support tvm.container.Array (apache#2366)
1 parent 6256aa3 commit 4b14ad4

File tree

6 files changed: +258 −157 lines changed

docs/langref/hybrid_script.rst

Lines changed: 33 additions & 10 deletions
@@ -52,20 +52,23 @@ The current parse interface looks like:
     parser = tvm.hybrid.parse(outer_product, [a, b]) # return the parser of this function
 
 
-If we pass these tvm tensors to this function, it returns an op node:
+If we pass these TVM data structures, like ``Tensor``, ``Var``, ``Expr.*Imm``,
+or ``tvm.container.Array``, to this function, it returns an op node:
 
 .. code-block:: python
 
     a = tvm.placeholder((100, ), name='a')
     b = tvm.placeholder((99, ), name='b')
     c = outer_product(a, b, c) # return the output tensor(s) of the operator
 
-**Under construction, we are still deciding what kind of node should be returned.**
+You can use any method that can be applied to a TVM ``OpNode``, like ``create_schedule``, although
+so far the schedule functionality is as limited as that of ``ExternOpNode``. At the least, it can be
+built to an LLVM module.
 
 Tuning
 ~~~~~~
 
-**Under construction, not truly supported yet.**
+**Under construction, not supported yet.**
 
 Following the example above, you can use some TVM-like interfaces to tune the code:
 
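For instance, here is a minimal editorial sketch (not part of this diff) of scheduling and building such an op node, assuming a hybrid ``outer_product`` that allocates its output with ``output_tensor`` and the TVM API of this era:

.. code-block:: python

    import tvm

    @tvm.hybrid.script
    def outer_product(a, b):
        c = output_tensor((a.shape[0], b.shape[0]), a.dtype)
        for i in range(a.shape[0]):
            for j in range(b.shape[0]):
                c[i, j] = a[i] * b[j]
        return c

    a = tvm.placeholder((100, ), name='a')
    b = tvm.placeholder((99, ), name='b')
    c = outer_product(a, b)            # output tensor of the returned op node
    sch = tvm.create_schedule(c.op)    # schedule it like any other OpNode
    func = tvm.build(sch, [a, b, c], target='llvm')  # built to an LLVM module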
@@ -86,6 +89,21 @@ Here we use ``range`` aka ``serial``, ``unroll``, ``parallel``, and ``vectorize`
 these **4** keywords to annotate the corresponding types of for loops.
 The usage is roughly the same as Python's standard ``range``.
 
+In addition to the loop types supported in Halide, ``const_range`` is supported under certain conditions.
+Sometimes it is desirable to pass a ``tvm.container.Array`` as an argument, but TVM-HalideIR has no
+support for converting a ``tvm.container.Array`` to an ``Expr``. Thus, a limited feature is provided:
+such containers can only be accessed with constant indices or inside loops annotated as constant.
+
+.. code-block:: python
+
+    @tvm.hybrid.script
+    def foo(a, b): # b is a tvm.container.Array
+        c = output_tensor(a.shape, a.dtype)
+        for i in const_range(len(a)): # because b is accessed, i must be annotated as const_range
+            c[i] = a[i] + b[i]
+        return c
+
+
 Variables
 ~~~~~~~~~
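A hypothetical usage sketch of the ``foo`` above (editorial, not part of this diff; the values are illustrative, and ``tvm.convert`` is assumed as the way to obtain a ``tvm.container.Array``):

.. code-block:: python

    a = tvm.placeholder((4, ), name='a')
    b = tvm.convert([1.0, 2.0, 3.0, 4.0])  # a tvm.container.Array
    c = foo(a, b)                          # the const_range loop gets constant bounds
    sch = tvm.create_schedule(c.op)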

@@ -111,14 +129,14 @@ It regards the first store of a variable as its declaration.
         s += a[i, j] # do something with sum
         b[i] = sum # you can still use sum at this level
     a[0] = s # you CANNOT use s here, even though it is allowed in conventional Python
-    b = (1, 2) # this has NOT been supported yet!
 
 
 Attributes
 ~~~~~~~~~~
 
-So far, ONLY tensors' ``shape`` attribute is supported! The ``shape`` attribute is essentially a
-tuple, so you MUST access it as an array. Also, currently, only constant-indexed access is supported.
+So far, ONLY tensors' ``shape`` and ``dtype`` attributes are supported!
+The ``shape`` attribute is essentially a tuple, so you MUST access it as an array.
+Currently, only constant-indexed access is supported.
 
 .. code-block:: python
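For illustration (an editorial sketch, not part of this diff), the newly supported ``dtype`` attribute can be used alongside ``shape``:

.. code-block:: python

    @tvm.hybrid.script
    def copy(a):
        b = output_tensor(a.shape, a.dtype)  # dtype attribute is now usable
        for i in range(a.shape[0]):          # shape access uses a constant index
            b[i] = a[i]
        return b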
@@ -133,8 +151,11 @@ Conditional Statement and Expression
 
 .. code-block:: python
 
-    if condition:
-        # do something
+    if condition1 and condition2 and condition3:
+        # do something
+    else:
+        # do something else
+
+    # Select
     a = b if condition else c
 
 However, NO ``True`` or ``False`` keyword is supported yet.
@@ -153,7 +174,9 @@ Array Allocation
 **Under construction, this function will be supported later!**
 
 Use a function call ``allocate(shape, type, share/local)`` to declare an array buffer.
-The basic usage is roughly the same as a normal array.
+The basic usage is roughly the same as a normal ``numpy.array``, and you should access a
+high-dimensional array in ``a[i, j, k]`` fashion instead of ``a[i][j][k]``;
+for compilation, this applies even to ``tvm.container.Array``.
 
 
 Thread Bind
@@ -170,5 +193,5 @@ You can also do loop-thread bind by writing code like this:
 
 Keywords
 ~~~~~~~~
-- For keywords: ``serial``, ``range``, ``unroll``, ``parallel``, ``vectorize``, ``bind``
+- For keywords: ``serial``, ``range``, ``unroll``, ``parallel``, ``vectorize``, ``bind``, ``const_range``
 - Math keywords: ``log``, ``exp``, ``sigmoid``, ``tanh``, ``power``, ``popcount``

python/tvm/hybrid/calls.py

Lines changed: 23 additions & 9 deletions
@@ -12,15 +12,17 @@
 #pylint: disable=redefined-builtin
 
 LOOP_INTRIN = {
-    'range' : For.Serial,
-    'unroll' : For.Unrolled,
-    'parallel' : For.Parallel,
-    'vectorize': For.Vectorized,
+    'range'       : For.Serial,
+    'unroll'      : For.Unrolled,
+    'parallel'    : For.Parallel,
+    'vectorize'   : For.Vectorized,
+    'const_range' : (For.Unrolled, ),
 }
 
+
 def _range(annotation, args):
     """Handling TVM loop types"""
-    n = len(args)
+    n = args.__len__()
     if n == 1:
         low, ext = _api.const(0, dtype='int32'), args[0]
     else:
3335
return iter_var, low, ext, for_type
3436

3537

36-
range = unroll = vectorize = parallel = _range #pylint: disable=invalid-name
38+
range = unroll = vectorize = parallel = const_range = _range #pylint: disable=invalid-name
3739

3840

3941
def bind(func_id, args):
4042
"""Handling TVM thread binding"""
4143
_internal_assert(func_id == "bind", "This function cannot be directly invoked!")
42-
_internal_assert(len(args) == 2, "A loop bind should only have 2 arguments!")
44+
_internal_assert(args.__len__() == 2, "A loop bind should only have 2 arguments!")
4345
_internal_assert(isinstance(args[0], str), \
4446
"A loop bind's first argument should be a string!")
4547
iter_var = _api.thread_axis(args[0])
@@ -56,7 +58,7 @@ def _math_intrin(func_id, args):
5658

5759

5860
def _min_max(func_id, args):
59-
_internal_assert(len(args) == 2, "Max/Min function should have 2 elements")
61+
_internal_assert(args.__len__() == 2, "Max/Min function should have 2 elements")
6062
return getattr(_make, func_id.title())(args[0], args[1])
6163

6264

@@ -66,7 +68,7 @@ def _min_max(func_id, args):
6668
def _allocate_tensor(func_id, args):
6769
"""Handling TVM tensor allocation.
6870
You may refer hybrid.intrin.allocate for more details."""
69-
n = len(args)
71+
n = args.__len__()
7072
_internal_assert(isinstance(_api.convert(args[0]), Array), \
7173
"allocate's first argument should be a tuple of shape!")
7274
shape = args[0]
@@ -89,4 +91,16 @@ def _allocate_tensor(func_id, args):
8991
scope = 'global' if func_id != 'output_tensor' else 'output'
9092
return (shape, dtype, scope)
9193

94+
9295
output_tensor = allocate = _allocate_tensor #pylint: disable=invalid-name
96+
97+
98+
def len(func_id, args):
99+
"""Iterpret the len function"""
100+
_internal_assert(args.__len__() == 1, "Only 1 argument is expected!")
101+
_internal_assert(func_id == "len", "This function cannot be directly invoked!")
102+
try:
103+
return _api.convert(args[0].__len__())
104+
except: #pylint: disable=bare-except
105+
_internal_assert(args[0].shape.__len__() == 1, "Only one-dimension array can get len")
106+
return _api.convert(args[0].shape[0])
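An editorial note on the ``args.__len__()`` pattern used throughout this file: since the module now defines its own ``len`` intrinsic (with pylint's ``redefined-builtin`` check disabled), calling ``__len__()`` directly sidesteps the shadowing name. A minimal sketch of the effect, with hypothetical values:

    def len(func_id, args):  # module-level len shadows the builtin in this scope
        pass

    args = [1, 2]
    n = args.__len__()       # 2 -- reaches the real list length, not the shadow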

python/tvm/hybrid/intrin.py

Lines changed: 14 additions & 26 deletions
@@ -2,32 +2,19 @@
 
 import numpy
 
-class _range(object):
-    """Base class of the loop ranges in hybrid script"""
-    def __init__(self, a, b=None):
-        if b is None:
-            self.low = 0
-            self.ext = a
-        else:
-            self.low = a
-            self.ext = b
+
+class bind(object): #pylint: disable=invalid-name
+    """GPU bind software emulation runtime."""
+    def __init__(self, _, ext):
+        self.ext = ext
 
     def __iter__(self):
         i = 0
         while i < self.ext:
-            yield i + self.low
+            yield i
             i += 1
 
 
-class bind(_range): #pylint: disable=invalid-name
-    def __init__(self, tag, ext):
-        super(bind, self).__init__(ext)
-        self.tag = tag
-
-
-unroll = vectorize = parallel = _range #pylint: disable=invalid-name
-
-
 def allocate(shape, dtype='float32', scope='global'): #pylint: disable=unused-argument
     """Allocate a buffer with given shape
@@ -47,7 +34,6 @@ def allocate(shape, dtype='float32', scope='global'): #pylint: disable=unused-ar
     """
     return numpy.zeros(shape).astype(dtype)
 
-output_tensor = allocate #pylint: disable=invalid-name
 
 def popcount(x):
     """
@@ -87,17 +73,19 @@ def sigmoid(x):
 
 
 HYBRID_GLOBALS = {
-    'unroll'       : unroll,
-    'vectorize'    : vectorize,
-    'parallel'     : parallel,
-    'allocate'     : allocate,
-    'output_tensor': output_tensor,
+    'len'          : len,
+    'unroll'       : range,
+    'vectorize'    : range,
+    'parallel'     : range,
+    'const_range'  : range,
     'bind'         : bind,
+    'allocate'     : allocate,
+    'output_tensor': allocate,
     'sqrt'         : numpy.sqrt,
     'log'          : numpy.log,
     'tanh'         : numpy.tanh,
     'power'        : numpy.power,
     'exp'          : numpy.exp,
     'sigmoid'      : sigmoid,
-    'popcount'     : popcount
+    'popcount'     : popcount,
 }
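For context (a hedged editorial sketch, not part of this diff): ``HYBRID_GLOBALS`` supplies the pure-Python emulation bindings used when a hybrid function runs as ordinary Python instead of being parsed into HalideIR, so the loop annotations, including the new ``const_range``, simply fall back to the builtin ``range``:

    # Hedged sketch (hypothetical values) of the pure-Python emulation path.
    import numpy

    a = numpy.arange(4, dtype='float32')
    b = [1.0, 2.0, 3.0, 4.0]
    c = numpy.zeros(a.shape).astype(a.dtype)  # what allocate/output_tensor emulate
    for i in range(len(a)):                   # const_range behaves as builtin range
        c[i] = a[i] + b[i]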
