
Commit aefc05a

ajtulloch authored and wweic committed
We observe multiple groups across a range of domains (ASR, NMT, LM, etc.) (apache#3566)

internally and externally that are interested in replacing standard dense layers with block-sparse matrix multiplication layers. The motivations are generally higher performance (due to the reduction in FLOPs and in memory bandwidth/cache footprint) and enabling larger models (e.g. fitting more layers in a given memory budget). Some public work along these lines:

* https://openai.com/blog/block-sparse-gpu-kernels/
* https://openai.com/blog/sparse-transformer/
* https://arxiv.org/abs/1802.08435
* https://arxiv.org/abs/1711.02782

Various groups have successfully trained models at reasonable levels of sparsity (90%+) with marginal accuracy changes, which suggests substantial speedups are possible, since this implies a >10x reduction in FLOPs. It is fairly straightforward to realize these theoretical speedups; see e.g. the TVM benchmarks for Intel CPUs in https://gist.github.com/ajtulloch/e65f90487bceb8848128e8db582fe902 and the CUDA results in https://github.com/openai/blocksparse. Related implementations and representations:

* https://github.com/openai/blocksparse (CUDA)
* https://software.intel.com/en-us/mkl-developer-reference-c-mkl-bsrmm (MKL BSRMM)
* https://docs.scipy.org/doc/scipy-0.14.0/reference/generated/scipy.sparse.bsr_matrix.html (SciPy BSR representation)

This is extracted from a patch we have been using internally. Various extensions are possible (int8/fp16/bf16, CUDA/other GPU architectures), but this is a reasonable starting point; it still needs more thorough unit test coverage. We follow the conventions established by scipy.sparse.bsr_matrix and other libraries; see the unit tests for details. For folks interested in experimenting with scheduling/AutoTVM etc., https://gist.github.com/ajtulloch/e65f90487bceb8848128e8db582fe902 is a useful starting point.
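
As a rough sketch (not part of this commit) of the block-sparse weight layout the scipy.sparse.bsr_matrix convention implies, here is how a pruned dense weight matrix maps onto the (data, indices, indptr) triple consumed by the new operator; the sizes, block shape, and sparsity level below are illustrative only:

import numpy as np
import scipy.sparse as sp

# Illustrative sizes: units N = 64, input_dim K = 128, 16x16 blocks, ~90% block sparsity.
w = np.random.randn(64, 128).astype("float32")
mask = np.random.rand(64 // 16, 128 // 16) < 0.1          # keep ~10% of the blocks
w *= np.repeat(np.repeat(mask, 16, axis=0), 16, axis=1)

w_bsr = sp.bsr_matrix(w, blocksize=(16, 16))
# The three arrays consumed by sparse_dense:
#   w_bsr.data    : [num_blocks, 16, 16]  the stored non-zero blocks
#   w_bsr.indices : [num_blocks]          column-block index of each stored block
#   w_bsr.indptr  : [64 // 16 + 1]        block-row pointers
print(w_bsr.data.shape, w_bsr.indices.shape, w_bsr.indptr.shape)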
1 parent c8fe9e6 commit aefc05a

9 files changed: +430 -0 lines changed


include/tvm/relay/attrs/nn.h

Lines changed: 4 additions & 0 deletions
@@ -366,6 +366,10 @@ struct DenseAttrs : public tvm::AttrsNode<DenseAttrs> {
   }
 };
 
+/*! \brief Attributes for sparse_dense operator */
+struct SparseDenseAttrs : public tvm::AttrsNode<SparseDenseAttrs> {
+  TVM_DECLARE_ATTRS(SparseDenseAttrs, "relay.attrs.SparseDenseAttrs") {}
+};
 
 /*! \brief Attributes for upsampling operator */
 struct UpSamplingAttrs : public tvm::AttrsNode<UpSamplingAttrs> {

python/tvm/relay/op/nn/_nn.py

Lines changed: 13 additions & 0 deletions
@@ -85,6 +85,19 @@ def schedule_batch_matmul(attrs, outputs, target):
 
 reg.register_pattern("nn.batch_matmul", reg.OpPattern.OUT_ELEMWISE_FUSABLE)
 
+# sparse_dense
+@reg.register_compute("nn.sparse_dense")
+def compute_sparse_dense(attrs, inputs, out_type, target):
+    """Compute definition of sparse_dense"""
+    return [topi.nn.sparse_dense(inputs[0], inputs[1], inputs[2], inputs[3])]
+
+@reg.register_schedule("nn.sparse_dense")
+def schedule_sparse_dense(attrs, outputs, target):
+    """Schedule definition of sparse_dense"""
+    with target:
+        return topi.generic.schedule_sparse_dense(outputs)
+
+reg.register_pattern("nn.sparse_dense", reg.OpPattern.OUT_ELEMWISE_FUSABLE)
 
 
 # conv2d
 def _find_conv2d_op(op):

python/tvm/relay/op/nn/nn.py

Lines changed: 33 additions & 0 deletions
@@ -839,6 +839,39 @@ def batch_matmul(x, y):
     """
     return _make.batch_matmul(x, y)
 
+def sparse_dense(data, weight):
+    r"""
+    Computes the matrix multiplication of `data` and `weight`, where `data` is
+    a dense matrix and `weight` is a sparse (either BSR or CSR) namedtuple with
+    fields `data`, `indices`, and `indptr`.
+
+    .. math::
+
+        \mbox{sparse_dense}(data, weight)[m, n] = \mbox{matmul}(data, \mbox{as_dense}(weight)^T)[m, n]
+
+    where `as_dense` returns the dense equivalent of the given sparse matrix.
+
+    See
+    https://docs.scipy.org/doc/scipy/reference/generated/scipy.sparse.csr_matrix.html
+    and
+    https://docs.scipy.org/doc/scipy-0.14.0/reference/generated/scipy.sparse.bsr_matrix.html
+    for more detail on the sparse matrix representation.
+
+    Parameters
+    ----------
+    data : tvm.relay.Expr
+        The input data for the matrix multiplication.
+
+    weight : namedtuple
+        The sparse weight matrix for the matrix multiplication.
+
+    Returns
+    -------
+    result : tvm.relay.Expr
+        The computed result.
+    """
+    return _make.sparse_dense(data, weight.data, weight.indices, weight.indptr)
+
 
 def contrib_conv2d_winograd_without_weight_transform(data,
                                                      weight,
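
A minimal usage sketch for the frontend function above (not part of this diff), assuming the namedtuple is built from scipy.sparse.bsr_matrix arrays; names and shapes are illustrative:

import collections
import numpy as np
import scipy.sparse as sp
from tvm import relay

SparseWeight = collections.namedtuple("SparseWeight", ["data", "indices", "indptr"])

# Illustrative shapes: batch M = 8, input_dim K = 128, units N = 64, 16x16 blocks.
w_bsr = sp.bsr_matrix(np.random.randn(64, 128).astype("float32"), blocksize=(16, 16))

x = relay.var("x", shape=(8, 128), dtype="float32")
w = SparseWeight(data=relay.const(w_bsr.data),        # [num_blocks, 16, 16]
                 indices=relay.const(w_bsr.indices),  # [num_blocks]
                 indptr=relay.const(w_bsr.indptr))    # [64 // 16 + 1]
y = relay.nn.sparse_dense(x, w)                       # dense x times sparse w^T -> (8, 64)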

src/relay/op/nn/sparse.cc

Lines changed: 97 additions & 0 deletions
@@ -0,0 +1,97 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * Copyright (c) 2018 by Contributors
+ * \file sparse.cc
+ * \brief Property def of nn.sparse_dense operator.
+ */
+
+#include <tvm/data_layout.h>
+#include <tvm/relay/op.h>
+#include <tvm/relay/attrs/nn.h>
+#include <vector>
+
+#include "../../pass/alter_op_layout.h"
+
+namespace tvm {
+namespace relay {
+
+// relay.nn.sparse_dense
+TVM_REGISTER_NODE_TYPE(SparseDenseAttrs);
+
+bool SparseDenseRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
+                    const TypeReporter& reporter) {
+  CHECK_EQ(types.size(), 5);
+  const auto* data = types[0].as<TensorTypeNode>();
+  const auto* weight_data = types[1].as<TensorTypeNode>();
+  CHECK(weight_data->shape.size() == 1 || weight_data->shape.size() == 3);
+  const auto* weight_indptr = types[3].as<TensorTypeNode>();
+  if (data == nullptr) return false;
+
+  if (weight_data->shape.size() == 1) {
+    // CSR case.
+    Array<IndexExpr> oshape({data->shape[0], weight_indptr->shape[0] - 1});
+    reporter->Assign(types[4], TensorTypeNode::make(oshape, data->dtype));
+    return true;
+  }
+
+  if (weight_data->shape.size() == 3) {
+    // BSR case.
+    Array<IndexExpr> oshape({
+        data->shape[0],
+        (weight_indptr->shape[0] - 1) * weight_data->shape[1]});
+    reporter->Assign(types[4], TensorTypeNode::make(oshape, data->dtype));
+    return true;
+  }
+  LOG(FATAL) << "Unknown weight ndim for nn.sparse_dense, should be 1 (CSR) or 3 (BSR)";
+  return false;
+}
+
+// Positional relay function to create the sparse_dense operator, used by the frontend FFI.
+Expr MakeSparseDense(Expr data, Expr weight_data, Expr weight_indices, Expr weight_indptr) {
+  auto attrs = make_node<SparseDenseAttrs>();
+  static const Op& op = Op::Get("nn.sparse_dense");
+  return CallNode::make(op, {data, weight_data, weight_indices, weight_indptr}, Attrs(attrs), {});
+}
+
+TVM_REGISTER_API("relay.op.nn._make.sparse_dense")
+.set_body([](const TVMArgs& args, TVMRetValue* rv) {
+  runtime::detail::unpack_call<Expr, 4>(MakeSparseDense, args, rv);
+});
+
+RELAY_REGISTER_OP("nn.sparse_dense")
+.describe(R"code(Applies a sparse linear transformation: :math:`Y = XW^T` with W sparse.
+
+- **data**: `(x1, x2, ..., xn, input_dim)`
+- **weight**: `(units, input_dim)`
+- **out**: `(x1, x2, ..., xn, units)`.
+
+)code" TVM_ADD_FILELINE)
+.set_attrs_type_key("relay.attrs.SparseDenseAttrs")
+.set_num_inputs(4)
+.add_argument("data", "nD Tensor", "Input data.")
+.add_argument("weight_data", "1D Tensor", "Weight data matrix.")
+.add_argument("weight_indices", "1D Tensor", "Weight indices matrix.")
+.add_argument("weight_indptr", "1D Tensor", "Weight indptr matrix.")
+.set_support_level(1)
+.add_type_rel("SparseDense", SparseDenseRel);
+
+}  // namespace relay
+}  // namespace tvm

topi/python/topi/generic/nn.py

Lines changed: 17 additions & 0 deletions
@@ -513,6 +513,23 @@ def schedule_l2_normalize(outs):
     cpp_target = cpp.TEST_create_target(target.target_name)
     return cpp.generic.default_schedule(cpp_target, outs, False)
 
+@tvm.target.generic_func
+def schedule_sparse_dense(outs):
+    """Schedule for sparse_dense
+
+    Parameters
+    ----------
+    outs: Array of Tensor
+        The computation graph description of sparse_dense
+        in the format of an array of tensors.
+
+    Returns
+    -------
+    sch: Schedule
+        The computation schedule for the op.
+    """
+    return _default_schedule(outs, False)
+
 @tvm.target.generic_func
 def schedule_batch_matmul(outs):
     target = tvm.target.current_target(allow_none=False)

topi/python/topi/nn/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -20,3 +20,4 @@
 from .bitserial_dense import *
 from .l2_normalize import *
 from .batch_matmul import *
+from .sparse import *

topi/python/topi/nn/sparse.py

Lines changed: 103 additions & 0 deletions
@@ -0,0 +1,103 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+"""Sparse operators"""
+from __future__ import absolute_import
+import tvm
+
+from ..util import get_const_tuple
+
+
+@tvm.target.generic_func
+def sparse_dense(data, weight_data, weight_indices, weight_indptr):
+    """
+    Computes sparse-dense matrix multiplication of `data` and
+    `(weight_data, weight_indices, weight_indptr).T`
+
+    Parameters
+    ----------
+    data : tvm.Tensor
+        2-D with shape [M, K], float32
+
+    weight_data : tvm.Tensor
+        1-D with shape [nnz] (CSR) or
+        3-D with shape [num_blocks, bs_r, bs_c] (BSR)
+
+    weight_indices : tvm.Tensor
+        1-D with shape [nnz] (CSR) or
+        1-D with shape [num_blocks] (BSR)
+
+    weight_indptr : tvm.Tensor
+        1-D with shape [N + 1] (CSR) or
+        1-D with shape [(N // bs_r) + 1] (BSR)
+
+    Returns
+    -------
+    output : tvm.Tensor
+        2-D with shape [M, N]
+    """
+    assert len(weight_data.shape) in (1, 3)
+    if len(weight_data.shape) == 1:
+        func = _sparse_dense_csrmm
+    if len(weight_data.shape) == 3:
+        func = _sparse_dense_bsrmm
+    return func(data, weight_data, weight_indices, weight_indptr)
+
+
+def _sparse_dense_csrmm(data, weight_data, weight_indices, weight_indptr):
+    oshape = (
+        get_const_tuple(data.shape)[0],
+        get_const_tuple(weight_indptr.shape)[0] - 1)
+
+    def f(i, row):
+        row_start = weight_indptr[row]
+        row_end = weight_indptr[row + 1]
+        row_elems = row_end - row_start
+        elem_idx = tvm.reduce_axis((0, row_elems), name="elem_idx")
+        elem = row_start + elem_idx
+        a_val = weight_data[elem]
+        weight_val = data[i, weight_indices[elem]]
+        return tvm.sum(a_val * weight_val, axis=elem_idx)
+    return tvm.compute(oshape, f, tag="sparse_dense_csrmm")
+
+
+def _sparse_dense_bsrmm(data, weight_data, weight_indices, weight_indptr):
+    (m, _) = get_const_tuple(data.shape)
+    (_, bs_r, bs_c) = get_const_tuple(weight_data.shape)
+    (num_blocks_plus_1, ) = get_const_tuple(weight_indptr.shape)
+    num_blocks = num_blocks_plus_1 - 1
+
+    def _compute_block(i, nb_j, j):
+        row_start = weight_indptr[nb_j]
+        row_end = weight_indptr[nb_j + 1]
+        row_elems = row_end - row_start
+        elem_idx = tvm.reduce_axis(
+            (0, row_elems), name="elem_idx")
+        block_offset = row_start + elem_idx
+        c = tvm.reduce_axis((0, bs_c), name="c")
+        block_j = weight_indices[block_offset]
+        block_ij_val = weight_data[block_offset][j][c]
+        x_val = data[i, bs_c * block_j + c]
+        return tvm.sum(block_ij_val * x_val, axis=[elem_idx, c])
+
+    bsrmm_block = tvm.compute(
+        (m, num_blocks, bs_r), _compute_block,
+        tag="sparse_dense_bsrmm_block")
+    return tvm.compute(
+        (m, num_blocks * bs_r),
+        lambda m, n: bsrmm_block[m, n // bs_r, n % bs_r],
+        tag="sparse_dense_bsrmm")
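
For intuition about what _sparse_dense_bsrmm above computes, here is a numpy reference sketch (not part of this diff) of Y = X * W^T with W stored in the scipy-style BSR layout; variable names are illustrative:

import numpy as np

def sparse_dense_bsr_reference(x, w_data, w_indices, w_indptr):
    """Y = X * W^T, with W given as scipy-style BSR arrays.

    x         : [M, K]
    w_data    : [num_blocks, bs_r, bs_c]
    w_indices : [num_blocks]         column-block index of each stored block
    w_indptr  : [N // bs_r + 1]      block-row pointers
    """
    m = x.shape[0]
    num_block_rows = w_indptr.shape[0] - 1
    _, bs_r, bs_c = w_data.shape
    y = np.zeros((m, num_block_rows * bs_r), dtype=x.dtype)
    for nb_j in range(num_block_rows):
        for block in range(w_indptr[nb_j], w_indptr[nb_j + 1]):
            block_j = w_indices[block]                            # which column block of W
            x_block = x[:, block_j * bs_c:(block_j + 1) * bs_c]   # [M, bs_c]
            y[:, nb_j * bs_r:(nb_j + 1) * bs_r] += x_block @ w_data[block].T
    return y

This mirrors the two-stage TVM compute: the inner accumulation corresponds to the bsrmm_block reduction over elem_idx and c, and the final write-out corresponds to the reshape to (m, num_blocks * bs_r).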

topi/python/topi/x86/sparse.py

Lines changed: 63 additions & 0 deletions
@@ -0,0 +1,63 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+"""sparse_dense schedule on x86"""
+import tvm
+
+from .. import generic
+from ..util import traverse_inline, get_const_int
+from .util import get_fp32_len
+
+
+@generic.schedule_sparse_dense.register(["cpu"])
+def _schedule_sparse_dense(outs):
+    s = tvm.create_schedule([x.op for x in outs])
+
+    def _callback(op):
+        simd_width = get_fp32_len()
+        if op.tag == "sparse_dense_csrmm" and op != outs[0].op:
+            (_, v_i) = s[op].op.axis
+            s[op].vectorize(v_i)
+            (y_o, y_i) = s[outs[0].op].split(
+                s[outs[0].op].op.axis[1], 2 * simd_width)
+            s[op].compute_at(s[outs[0]], y_o)
+            s[outs[0].op].vectorize(y_i)
+        if op.tag == "sparse_dense_bsrmm":
+            y_bsrmm = op.input_tensors[0]
+            assert y_bsrmm.op.tag == "sparse_dense_bsrmm_block"
+            y_reshape = op
+            (m, num_blocks, b_r) = s[y_bsrmm].op.axis
+            bs_r = get_const_int(b_r.dom.extent)
+            (elem_idx, c) = s[y_bsrmm].op.reduce_axis
+            s[y_bsrmm].reorder(num_blocks, m, elem_idx, b_r, c)
+            s[y_bsrmm].vectorize(b_r)
+            (m_o, n_o) = s[y_reshape].op.axis
+            (noo, noi) = s[y_reshape].split(n_o, bs_r)
+            s[y_bsrmm].compute_at(s[y_reshape], noi)
+            s[y_reshape].vectorize(noi)
+            if op != s[outs[0]].op:
+                (y_o, y_i) = s[outs[0].op].split(
+                    s[outs[0].op].op.axis[1], 2 * simd_width)
+                s[y_reshape].compute_at(s[outs[0]], y_o)
+                s[outs[0].op].parallel(y_o)
+                s[outs[0].op].vectorize(y_i)
+            else:
+                m_o_noo = s[y_reshape].fuse(m_o, noo)
+                s[y_reshape].parallel(m_o_noo)
+
+    traverse_inline(s, outs[0].op, _callback)
+    return s
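
As a rough end-to-end sketch of how the x86 schedule above would be exercised (not part of this diff; the benchmark gist linked in the commit message is the authoritative reference), assuming the TVM Python API of this era (tvm.placeholder / tvm.build) and illustrative shapes:

import numpy as np
import scipy.sparse as sp
import tvm
import topi

# Illustrative sizes: M = 8, K = 128, N = 64, 16x16 blocks.
w = sp.bsr_matrix(np.random.randn(64, 128).astype("float32"), blocksize=(16, 16))
x_np = np.random.randn(8, 128).astype("float32")

X = tvm.placeholder((8, 128), name="X", dtype="float32")
W_data = tvm.placeholder(w.data.shape, name="W_data", dtype="float32")
W_indices = tvm.placeholder(w.indices.shape, name="W_indices", dtype="int32")
W_indptr = tvm.placeholder(w.indptr.shape, name="W_indptr", dtype="int32")

with tvm.target.create("llvm"):
    Y = topi.nn.sparse_dense(X, W_data, W_indices, W_indptr)
    s = topi.generic.schedule_sparse_dense([Y])
    func = tvm.build(s, [X, W_data, W_indices, W_indptr, Y])

ctx = tvm.cpu()
y = tvm.nd.empty((8, 64), ctx=ctx)
func(tvm.nd.array(x_np, ctx), tvm.nd.array(w.data, ctx),
     tvm.nd.array(w.indices.astype("int32"), ctx),
     tvm.nd.array(w.indptr.astype("int32"), ctx), y)
np.testing.assert_allclose(y.asnumpy(), x_np @ w.toarray().T, rtol=1e-4, atol=1e-4)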
