21 changes: 15 additions & 6 deletions python/tvm/autotvm/task/topi_integration.py
@@ -87,6 +87,7 @@ def __init__(self):
topi.nn.dense: "topi_nn_dense",
topi.nn.bitserial_conv2d_nchw: "topi_nn_bitserial_conv2d_nchw",
topi.nn.bitserial_conv2d_nhwc: "topi_nn_bitserial_conv2d_nhwc",
topi.nn.bitserial_dense: "topi_nn_bitserial_dense",
topi.nn.deformable_conv2d_nchw: "topi_nn_deformable_conv2d_nchw",
}

@@ -101,6 +102,7 @@ def __init__(self):
topi.nn.dense: [topi.generic.schedule_dense],
topi.nn.bitserial_conv2d_nchw: [topi.generic.schedule_bitserial_conv2d_nchw],
topi.nn.bitserial_conv2d_nhwc: [topi.generic.schedule_bitserial_conv2d_nhwc],
topi.nn.bitserial_dense: [topi.generic.schedule_bitserial_dense],
topi.nn.deformable_conv2d_nchw: [topi.generic.schedule_deformable_conv2d_nchw],
}

@@ -200,18 +202,25 @@ def _topi_bitserial_conv2d_nhwc(*args, **kwargs):
args = deserialize_args(args)
C = topi.nn.bitserial_conv2d_nhwc(*args, **kwargs)
s = topi.generic.nn.schedule_bitserial_conv2d_nhwc([C])
data = args[0]
kernel = args[1]
return s, [data, kernel, C]
A, W = args[:2]
return s, [A, W, C]

@register("topi_nn_bitserial_conv2d_nchw")
def _topi_bitserial_conv2d_nchw(*args, **kwargs):
args = deserialize_args(args)
C = topi.nn.bitserial_conv2d_nchw(*args, **kwargs)
s = topi.generic.nn.schedule_bitserial_conv2d_nchw([C])
data = args[0]
kernel = args[1]
return s, [data, kernel, C]
A, W = args[:2]
return s, [A, W, C]

@register("topi_nn_bitserial_dense")
def _topi_nn_bitserial_dense(*args, **kwargs):
assert not kwargs, "Do not support kwargs in template function call"
args = deserialize_args(args)
A, W = args[:2]
C = topi.nn.bitserial_dense(*args, **kwargs)
s = topi.generic.schedule_bitserial_dense([C])
return s, [A, W, C]

@register("topi_nn_deformable_conv2d_nchw")
def _topi_nn_deformable_conv2d_nchw(*args, **kwargs):
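For context, a minimal sketch of how the newly registered "topi_nn_bitserial_dense" template might be instantiated as a tuning task; the shapes, dtypes, bit widths, and target string below are illustrative assumptions, not taken from this change:

import tvm
from tvm import autotvm

# Argument spec in the ('TENSOR', shape, dtype) form handled by deserialize_args;
# the trailing values are data_bits, weight_bits, pack_dtype, out_dtype, unipolar.
data = ('TENSOR', (1, 1024), 'uint32')       # [batch, in_dim]
weight = ('TENSOR', (1000, 1024), 'uint32')  # [out_dim, in_dim]

task = autotvm.task.create(
    'topi_nn_bitserial_dense',
    args=(data, weight, 2, 1, 'uint8', 'int16', True),
    target='llvm -device=arm_cpu')
print(task.config_space)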
1 change: 1 addition & 0 deletions topi/python/topi/arm_cpu/__init__.py
@@ -4,4 +4,5 @@
from . import depthwise_conv2d
from . import conv2d_transpose
from . import bitserial_conv2d
from . import bitserial_dense
from . import injective
7 changes: 4 additions & 3 deletions topi/python/topi/arm_cpu/bitserial_conv2d.py
@@ -21,7 +21,8 @@
from tvm import autotvm
from .. import tag
from ..nn.pad import pad
from ..nn.bitserial_conv2d import bitpack, bitserial_conv2d_nhwc
from ..nn.bitserial_conv2d import bitserial_conv2d_nhwc
from ..nn.bitserial_util import bitpack, binary_op_multiplier
from ..nn.util import get_pad_tuple
from ..util import get_const_int, get_const_tuple
from .. import generic
@@ -93,7 +94,8 @@ def spatial_pack_nhwc(cfg, data, kernel, stride, padding, activation_bits, weigh
policy='candidate', candidate=[
[n, oh, ow, co, vh, vw, kh, kw, ci_o, kb, ib, vc, ci_i],
[n, oh, ow, co, vh, vw, kw, kh, ci_o, kb, ib, vc, ci_i],])
cfg.add_flop(2 * N * OH * OW * CO * CI * 8 * KH * KW) # these are actually binary ops
# binary ops
cfg.add_flop(2 * N * OH * OW * CO * CI * KH * KW * binary_op_multiplier(pack_dtype))
# ====================

VC = cfg["tile_co"].size[-1]
@@ -310,7 +312,6 @@ def _schedule_spatial_conv2d_nhwc(cfg, s, data_pad, data_vec, kernel_vec,

s[conv_out].compute_at(s[last], co)
s[last].parallel(oh)
s = s.normalize()
return s

@autotvm.register_topi_schedule(generic.nn.schedule_bitserial_conv2d_nhwc, 'arm_cpu', 'direct')
184 changes: 184 additions & 0 deletions topi/python/topi/arm_cpu/bitserial_dense.py
@@ -0,0 +1,184 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=invalid-name, too-many-locals, too-many-arguments
"""Schedule for bitserial dense operator."""
from __future__ import absolute_import as _abs
import tvm
from tvm import autotvm
from topi.util import get_const_tuple
from .. import tag
from .. import generic
from .bitserial_conv2d import _intrin_popcount
from ..nn.pad import pad
from ..nn.bitserial_dense import bitserial_dense
from ..nn.bitserial_util import bitpack, binary_op_multiplier

@autotvm.register_topi_compute(bitserial_dense, ['arm_cpu'], 'direct')
def bitserial_dense_generic(cfg, data, weight, data_bits, weight_bits, pack_dtype, out_dtype,
unipolar):
"""The default implementation of bitserial dense in topi.

Parameters
----------
data : tvm.Tensor
2-D with shape [batch, in_dim]

weight : tvm.Tensor
2-D with shape [out_dim, in_dim]

Returns
-------
output : tvm.Tensor
2-D with shape [batch, out_dim]
"""
data_packed = bitpack(data, data_bits, pack_axis=1, bit_axis=1, pack_type=pack_dtype)
if len(weight.shape) == 2:
weight_packed = bitpack(weight, weight_bits, pack_axis=1, bit_axis=1, pack_type=pack_dtype)
else:
weight_packed = weight
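# bitpack splits each input into per-bit planes (new axis at bit_axis=1) and packs
# the in_dim axis into pack_dtype-wide words, so data_packed is laid out as
# [batch, data_bits, in_dim_packed] and weight_packed as [out_dim, weight_bits, in_dim_packed].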

batch, DB, in_dim = get_const_tuple(data_packed.shape)
out_dim, WB, in_dim = get_const_tuple(weight_packed.shape)

# Pad Inputs so that microkernel can be used
# out_dim and in_dim need to be multiples of 8
if out_dim % 8 != 0:
out_dim_pad = 8 - out_dim % 8
weight_packed = pad(weight_packed, [0, 0, 0], [out_dim_pad, 0, 0], name='PaddedInput')
out_dim += out_dim_pad

######## Search space

x, y = cfg.axis(batch), cfg.axis(out_dim)
db, wb, k = cfg.reduce_axis(DB), cfg.reduce_axis(WB), cfg.reduce_axis(in_dim)

ko, ki = cfg.define_split('tile_k', k, policy='all', num_outputs=2,
filter=lambda xx: xx.size[-1] == 8 or xx.size[-1] == 16)
xo, xi = cfg.define_split('tile_x', x, policy='all', num_outputs=2)
yo, yi = cfg.define_split('tile_y', y, policy='all', num_outputs=2,
filter=lambda xx: xx.size[-1] == 8)

cfg.define_reorder('reorder_0', [yo, xo, ko, xi, wb, db, yi, ki],
policy='candidate', candidate=[
[yo, xo, ko, xi, wb, db, yi, ki],
[yo, xo, xi, ko, wb, db, yi, ki],
[yo, xo, ko, xi, wb, db, yi, ki]])

###### Compute rule
VY = cfg['tile_y'].size[-1]
VK = cfg['tile_k'].size[-1]

wvshape = (out_dim//VY, in_dim//VK, WB, VY, VK)
oshape = (batch, out_dim)

k = tvm.reduce_axis((0, in_dim), name='k')
db = tvm.reduce_axis((0, DB), name='db')
wb = tvm.reduce_axis((0, WB), name='wb')

# Tile data and weights
weight_vec = tvm.compute(wvshape, lambda yo, ko, wb, vy, vk:
weight_packed[yo*VY+vy][wb][ko*VK+vk], name='weight_vec')
matmul_unipolar = tvm.compute(oshape, lambda x, y: tvm.sum(
(tvm.popcount(weight_vec[y//VY, k//VK, wb, y%VY, k%VK].astype(out_dtype) &
data_packed[x, db, k].astype(out_dtype)) -
tvm.popcount(~weight_vec[y//VY, k//VK, wb, y%VY, k%VK].astype(out_dtype) &
data_packed[x, db, k].astype(out_dtype)))
<< (wb+db).astype(out_dtype), axis=[wb, db, k]), tag='bitserial_dense_unipolar')
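# For unipolar kernels the weight bits encode +1 (set) / -1 (clear), so each
# bit-plane contributes popcount(w & d) - popcount(~w & d) to the dot product;
# the (wb + db) shift restores the significance of the combined bit-planes.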

matmul = tvm.compute(oshape, lambda x, y: tvm.sum(
tvm.popcount(weight_vec[y//VY, k//VK, wb, y%VY, k%VK].astype(out_dtype) &
data_packed[x, db, k].astype(out_dtype))
<< (wb+db).astype(out_dtype), axis=[wb, db, k]), tag='bitserial_dense')

cfg.add_flop(batch * out_dim * in_dim * binary_op_multiplier(pack_dtype))
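# binary_op_multiplier(pack_dtype) is the number of bits packed per word
# (e.g. 8 for 'uint8'), so this records one binary op per original bit rather
# than per packed element.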

if unipolar:
return matmul_unipolar
return matmul


@autotvm.register_topi_schedule(generic.nn.schedule_bitserial_dense, ['arm_cpu'], 'direct')
def schedule_bitserial_dense(cfg, outs):
"""Schedule for binary_dense.

Parameters
----------
outs: Array of Tensor
The computation graph description of the bitserial dense operator
in the format of an array of tensors.

Returns
-------
s: Schedule
The computation schedule for bitserial_dense.
"""
outs = [outs] if isinstance(outs, tvm.tensor.Tensor) else outs
s = tvm.create_schedule([x.op for x in outs])

def _schedule(cfg, s, data_vec, weight_vec, output, unipolar):

z, k, _, y, x = s[weight_vec].op.axis
s[weight_vec].parallel(z)
s[weight_vec].vectorize(x)

x, y = s[output].op.axis
wb, db, k = s[output].op.reduce_axis
_, DB, _ = get_const_tuple(data_vec.shape)
_, _, WB, _, _ = get_const_tuple(weight_vec.shape)

yo, yi = cfg["tile_y"].apply(s, output, y)
xo, xi = cfg["tile_x"].apply(s, output, x)
ko, ki = cfg["tile_k"].apply(s, output, k)

cfg["reorder_0"].apply(s, output, [yo, xo, ko, xi, wb, db, yi, ki])

fused = s[output].fuse(xo, yo)
s[output].parallel(fused)

nfactor = cfg['tile_y'].size[-1]
kfactor = cfg['tile_k'].size[-1]
if nfactor % 8 == 0:
pc = _intrin_popcount(nfactor, kfactor, WB, DB, unipolar)
s[output].tensorize(wb, pc)

return s

def traverse(op):
"""Internal travserse function"""
# inline all one-to-one-mapping operators except the last stage (output)
if tag.is_broadcast(op.tag) or 'elemwise' in op.tag:
if op not in s.outputs:
s[op].compute_inline()
for tensor in op.input_tensors:
if tensor.op.input_tensors:
traverse(tensor.op)

elif op.tag == 'bitserial_dense' or op.tag == 'bitserial_dense_unipolar':
output = op.output(0)
weight_vec = op.input_tensors[0]

data_vec = op.input_tensors[1]
data = data_vec.op.input_tensors[0]
if "QuantizeInput" in data.op.name:
data = data.op.input_tensors[0]
unipolar = (output.op.tag == 'bitserial_dense_unipolar')
_schedule(cfg, s, data_vec, weight_vec, output, unipolar)
else:
raise RuntimeError("Unsupported operator: %s" % op.tag)

traverse(outs[0].op)
return s
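As a usage sketch, the new compute and schedule would be picked up through the usual generic dispatch under an arm_cpu target; the shapes, dtypes, bit widths, and target string here are assumptions for illustration, not part of this change:

import tvm
import topi

target = 'llvm -device=arm_cpu'
A = tvm.placeholder((1, 1024), dtype='uint32', name='A')     # [batch, in_dim]
W = tvm.placeholder((1000, 1024), dtype='uint32', name='W')  # [out_dim, in_dim]

with tvm.target.create(target):
    C = topi.nn.bitserial_dense(A, W, data_bits=2, weight_bits=1,
                                pack_dtype='uint8', out_dtype='int16', unipolar=True)
    s = topi.generic.schedule_bitserial_dense([C])

func = tvm.build(s, [A, W, C], target=target)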
16 changes: 16 additions & 0 deletions topi/python/topi/generic/nn.py
@@ -312,6 +312,22 @@ def schedule_bitserial_conv2d_nhwc(outs):
return _default_schedule(outs, False)


@tvm.target.generic_func
def schedule_bitserial_dense(outs):
"""Schedule for bitserial_dense
Parameters
----------
outs: Array of Tensor
The computation graph description of bitserial_dense
in the format of an array of tensors.
Returns
-------
sch: Schedule
The computation schedule for the op.
"""
return _default_schedule(outs, False)


@tvm.target.override_native_generic_func("schedule_reduce")
def schedule_reduce(outs):
"""Schedule for reduction
1 change: 1 addition & 0 deletions topi/python/topi/nn/__init__.py
@@ -17,5 +17,6 @@
from .upsampling import *
from .local_response_norm import *
from .bitserial_conv2d import *
from .bitserial_dense import *
from .l2_normalize import *
from .batch_matmul import *