Commit 3406392 — commit message: "asdf"
1 parent: 5c864ac

File tree: 9 files changed, +402 −0 lines

include/tvm/relay/attrs/nn.h
Lines changed: 4 additions & 0 deletions

@@ -366,6 +366,10 @@ struct DenseAttrs : public tvm::AttrsNode<DenseAttrs> {
   }
 };
 
+/*! \brief Attributes for sparse_dense operator */
+struct SparseDenseAttrs : public tvm::AttrsNode<SparseDenseAttrs> {
+  TVM_DECLARE_ATTRS(SparseDenseAttrs, "relay.attrs.SparseDenseAttrs") {}
+};
 
 /*! \brief Attributes for upsampling operator */
 struct UpSamplingAttrs : public tvm::AttrsNode<UpSamplingAttrs> {

python/tvm/relay/op/nn/_nn.py
Lines changed: 13 additions & 0 deletions

@@ -85,6 +85,19 @@ def schedule_batch_matmul(attrs, outputs, target):
 
 reg.register_pattern("nn.batch_matmul", reg.OpPattern.OUT_ELEMWISE_FUSABLE)
 
+# sparse_dense
+@reg.register_compute("nn.sparse_dense")
+def compute_sparse_dense(attrs, inputs, out_type, target):
+    """Compute definition of sparse_dense"""
+    return [topi.nn.sparse_dense(inputs[0], inputs[1], inputs[2], inputs[3])]
+
+@reg.register_schedule("nn.sparse_dense")
+def schedule_sparse_dense(attrs, outputs, target):
+    """Schedule definition of sparse_dense"""
+    with target:
+        return topi.generic.schedule_sparse_dense(outputs)
+
+reg.register_pattern("nn.sparse_dense", reg.OpPattern.OUT_ELEMWISE_FUSABLE)
 
 # conv2d
 def _find_conv2d_op(op):

python/tvm/relay/op/nn/nn.py
Lines changed: 33 additions & 0 deletions

@@ -817,6 +817,39 @@ def batch_matmul(x, y):
     """
     return _make.batch_matmul(x, y)
 
+def sparse_dense(data, weight):
+    r"""
+    Computes the matrix multiplication of `data` and `weight`, where `data` is
+    a dense matrix and `weight` is a sparse (either BSR or CSR) namedtuple with
+    fields `data`, `indices`, and `indptr`.
+
+    .. math::
+
+        \mbox{sparse_dense}(data, weight)[m, n] = \mbox{matmul}(data, \mbox{as_dense}(weight)^T)[m, n]
+
+    where `as_dense` returns the dense equivalent of the given sparse matrix.
+
+    See
+    https://docs.scipy.org/doc/scipy/reference/generated/scipy.sparse.csr_matrix.html
+    and
+    https://docs.scipy.org/doc/scipy-0.14.0/reference/generated/scipy.sparse.bsr_matrix.html
+    for more detail on the sparse matrix representation.
+
+    Parameters
+    ----------
+    data : tvm.relay.Expr
+        The input data for the matrix multiplication.
+
+    weight : namedtuple
+        The sparse weight matrix for the matrix multiplication.
+
+    Returns
+    -------
+    result : tvm.relay.Expr
+        The computed result.
+    """
+    return _make.sparse_dense(data, weight.data, weight.indices, weight.indptr)
 
 def contrib_conv2d_winograd_without_weight_transform(data,
                                                      weight,
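
Because `weight` can be any namedtuple exposing `data`, `indices`, and `indptr`, a SciPy BSR matrix can be passed in directly once its arrays are wrapped as Relay constants. A minimal usage sketch against the API added above (the `SparseWeight` namedtuple, shapes, and block size are illustrative, not part of the commit):

    import collections
    import scipy.sparse as sp
    from tvm import relay

    SparseWeight = collections.namedtuple("SparseWeight", ["data", "indices", "indptr"])

    # Dense input x: (M, K) = (8, 128); sparse weight w: (N, K) = (64, 128),
    # stored as BSR with (16, 1) blocks.
    w = sp.bsr_matrix(
        sp.random(64, 128, density=0.1, format="csr", dtype="float32"),
        blocksize=(16, 1))
    x = relay.var("x", shape=(8, 128), dtype="float32")
    weight = SparseWeight(
        data=relay.const(w.data),                        # (num_blocks, 16, 1)
        indices=relay.const(w.indices.astype("int32")),  # (num_blocks,)
        indptr=relay.const(w.indptr.astype("int32")))    # (64 // 16 + 1,)
    y = relay.nn.sparse_dense(x, weight)                 # y has shape (8, 64)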

src/relay/op/nn/sparse.cc
Lines changed: 99 additions & 0 deletions (new file)

/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * Copyright (c) 2018 by Contributors
 * \file sparse.cc
 * \brief Property def of nn.sparse_dense operator.
 */

#include <tvm/data_layout.h>
#include <tvm/relay/op.h>
#include <tvm/relay/attrs/nn.h>
#include <vector>

#include "../../pass/alter_op_layout.h"

namespace tvm {
namespace relay {

// relay.nn.sparse_dense
TVM_REGISTER_NODE_TYPE(SparseDenseAttrs);

bool SparseDenseRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
                    const TypeReporter& reporter) {
  CHECK_EQ(types.size(), 5);
  const auto* data = types[0].as<TensorTypeNode>();
  const auto* weight_data = types[1].as<TensorTypeNode>();
  const auto* weight_indptr = types[3].as<TensorTypeNode>();
  if (data == nullptr || weight_data == nullptr || weight_indptr == nullptr) return false;
  CHECK(weight_data->shape.size() == 1 || weight_data->shape.size() == 3);

  if (weight_data->shape.size() == 1) {
    // CSR case: indptr has N + 1 entries, so the output is (M, N).
    Array<IndexExpr> oshape({data->shape[0], weight_indptr->shape[0] - 1});
    reporter->Assign(types[4], TensorTypeNode::make(oshape, data->dtype));
    return true;
  }

  if (weight_data->shape.size() == 3) {
    // BSR case: each of the (len(indptr) - 1) block rows spans bs_r output columns.
    Array<IndexExpr> oshape({
        data->shape[0],
        (weight_indptr->shape[0] - 1) * weight_data->shape[1]});
    reporter->Assign(types[4], TensorTypeNode::make(oshape, data->dtype));
    return true;
  }
  LOG(FATAL) << "Unknown weight ndim for nn.sparse_dense, should be 1 (CSR) or 3 (BSR)";
  return false;
}

// Positional relay function to create sparse_dense operator used by frontend FFI.
Expr MakeSparseDense(Expr data, Expr weight_data, Expr weight_indices, Expr weight_indptr) {
  auto attrs = make_node<SparseDenseAttrs>();
  static const Op& op = Op::Get("nn.sparse_dense");
  return CallNode::make(op, {data, weight_data, weight_indices, weight_indptr}, Attrs(attrs), {});
}

TVM_REGISTER_API("relay.op.nn._make.sparse_dense")
.set_body([](const TVMArgs& args, TVMRetValue* rv) {
    runtime::detail::unpack_call<Expr, 4>(MakeSparseDense, args, rv);
  });

RELAY_REGISTER_OP("nn.sparse_dense")
.describe(R"code(Applies a sparse linear transformation: :math:`Y = XW^T`, where `W` is sparse.

- **data**: `(M, K)`
- **weight**: a sparse `(N, K)` matrix, given as `weight_data`, `weight_indices`,
  and `weight_indptr` in CSR or BSR format
- **out**: `(M, N)`.

)code" TVM_ADD_FILELINE)
.set_attrs_type_key("relay.attrs.SparseDenseAttrs")
.set_num_inputs(4)
.add_argument("data", "2D Tensor", "Input data.")
.add_argument("weight_data", "1D or 3D Tensor", "Weight data matrix.")
.add_argument("weight_indices", "1D Tensor", "Weight indices matrix.")
.add_argument("weight_indptr", "1D Tensor", "Weight indptr matrix.")
.set_support_level(1)
.add_type_rel("SparseDense", SparseDenseRel);

}  // namespace relay
}  // namespace tvm
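
The type relation above derives the output shape purely from the `indptr` length: for CSR, `indptr` has N + 1 entries, so the output is (M, N); for BSR, each of the `len(indptr) - 1` block rows contributes `bs_r` output columns. A Python restatement of the same rule (illustrative, not part of the commit):

    def sparse_dense_out_shape(data_shape, weight_data_shape, indptr_len):
        """Mirrors SparseDenseRel's output-shape logic."""
        m = data_shape[0]
        if len(weight_data_shape) == 1:   # CSR: weight_data is (nnz,)
            return (m, indptr_len - 1)
        if len(weight_data_shape) == 3:   # BSR: weight_data is (num_blocks, bs_r, bs_c)
            return (m, (indptr_len - 1) * weight_data_shape[1])
        raise ValueError("weight ndim must be 1 (CSR) or 3 (BSR)")

    assert sparse_dense_out_shape((8, 128), (300,), 65) == (8, 64)      # CSR
    assert sparse_dense_out_shape((8, 128), (20, 16, 1), 5) == (8, 64)  # BSR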

topi/python/topi/generic/nn.py
Lines changed: 17 additions & 0 deletions

@@ -513,6 +513,23 @@ def schedule_l2_normalize(outs):
     cpp_target = cpp.TEST_create_target(target.target_name)
     return cpp.generic.default_schedule(cpp_target, outs, False)
 
+@tvm.target.generic_func
+def schedule_sparse_dense(outs):
+    """Schedule for sparse_dense
+
+    Parameters
+    ----------
+    outs: Array of Tensor
+        The computation graph description of sparse_dense
+        in the format of an array of tensors.
+
+    Returns
+    -------
+    sch: Schedule
+        The computation schedule for the op.
+    """
+    return _default_schedule(outs, False)
+
 @tvm.target.generic_func
 def schedule_batch_matmul(outs):
     target = tvm.target.current_target(allow_none=False)
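
The `tvm.target.generic_func` decorator makes this the fallback implementation: it runs unless a target-specific override is registered, as the x86 schedule later in this commit does via `@generic.schedule_sparse_dense.register(["cpu"])`. A self-contained sketch of the dispatch mechanism (`my_sched` is hypothetical):

    import tvm

    @tvm.target.generic_func
    def my_sched(outs):
        return "default schedule"   # fallback when no override matches the target

    @my_sched.register(["cpu"])
    def my_sched_cpu(outs):
        return "cpu schedule"       # picked inside any target carrying the "cpu" key

    with tvm.target.create("llvm"):          # llvm targets carry the "cpu" key
        assert my_sched(None) == "cpu schedule"
    assert my_sched(None) == "default schedule"  # no target scope -> fallback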

topi/python/topi/nn/__init__.py
Lines changed: 1 addition & 0 deletions

@@ -20,3 +20,4 @@
 from .bitserial_dense import *
 from .l2_normalize import *
 from .batch_matmul import *
+from .sparse import *

topi/python/topi/nn/sparse.py
Lines changed: 107 additions & 0 deletions (new file)

# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.

"""Sparse operators"""
from __future__ import absolute_import
import tvm

from ..util import get_const_tuple


@tvm.target.generic_func
def sparse_dense(data, weight_data, weight_indices, weight_indptr):
    """
    Computes sparse-dense matrix multiplication of `data` and
    `(weight_data, weight_indices, weight_indptr).T`

    Parameters
    ----------
    data : tvm.Tensor
        2-D with shape [M, K], float32

    weight_data : tvm.Tensor
        1-D with shape [nnz] (CSR) or
        3-D with shape [num_blocks, bs_r, bs_c] (BSR)

    weight_indices : tvm.Tensor
        1-D with shape [nnz] (CSR) or
        1-D with shape [num_blocks] (BSR)

    weight_indptr : tvm.Tensor
        1-D with shape [N + 1] (CSR) or
        1-D with shape [N // bs_r + 1] (BSR)

    Returns
    -------
    output : tvm.Tensor
        2-D with shape [M, N]
    """
    assert len(weight_data.shape) in (1, 3)
    if len(weight_data.shape) == 1:
        return _sparse_dense_csrmv(
            data, weight_data, weight_indices, weight_indptr)
    if len(weight_data.shape) == 3:
        return _sparse_dense_bsrmv(
            data, weight_data, weight_indices, weight_indptr)


def _sparse_dense_csrmv(data, weight_data, weight_indices, weight_indptr):
    # Output is (M, N), where indptr has N + 1 entries.
    oshape = (
        get_const_tuple(data.shape)[0],
        get_const_tuple(weight_indptr.shape)[0] - 1)
    assert weight_indices.dtype == "int32", weight_indices.dtype
    assert weight_indptr.dtype == "int32", weight_indptr.dtype

    def f(i, row):
        assert row.dtype == "int32"
        row_start = weight_indptr[row]
        row_end = weight_indptr[row + 1]
        row_elems = row_end - row_start
        elem_idx = tvm.reduce_axis((0, row_elems), name="elem_idx")
        elem = row_start + elem_idx
        weight_val = weight_data[elem].astype("float32")
        data_val = data[i, weight_indices[elem]]
        return tvm.sum(weight_val * data_val, axis=elem_idx)
    return tvm.compute(oshape, f, tag="sparse_dense_csrmv")


def _sparse_dense_bsrmv(data, weight_data, weight_indices, weight_indptr):
    (M, _) = get_const_tuple(data.shape)
    (_, BS_R, BS_C) = get_const_tuple(weight_data.shape)
    (NB_plus_1, ) = get_const_tuple(weight_indptr.shape)
    NB = NB_plus_1 - 1
    oshape = (M, NB, BS_R)

    def f(i, nb, r):
        row_start = weight_indptr[nb]
        row_end = weight_indptr[nb + 1]
        row_elems = row_end - row_start
        elem_idx = tvm.reduce_axis(
            (0, row_elems), name="elem_idx")
        jj = row_start + elem_idx
        c = tvm.reduce_axis((0, BS_C), name="c")
        j = weight_indices[jj]
        block_ij_val = weight_data[jj][r][c]
        assert weight_data.dtype == "float32"
        x_val = data[i, BS_C * j + c]
        return tvm.sum(block_ij_val * x_val, axis=[elem_idx, c])

    # Compute in (M, NB, BS_R) block layout, then reshape to (M, N).
    Y = tvm.compute(
        oshape, f, tag="sparse_dense_bsrmv_block")
    return tvm.compute(
        (M, NB * BS_R), lambda m, n: Y[m, n // BS_R, n % BS_R],
        tag="sparse_dense_bsrmv")

topi/python/topi/x86/sparse.py
Lines changed: 62 additions & 0 deletions (new file)

# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.

import tvm

from .. import generic
from ..util import traverse_inline, get_const_int
from .util import get_fp32_len


@generic.schedule_sparse_dense.register(["cpu"])
def schedule_sparse_dense(outs):
    s = tvm.create_schedule([x.op for x in outs])

    def callback(op):
        simd_width = get_fp32_len()
        if op.tag == "sparse_dense_csrmv" and op != outs[0].op:
            # CSR: vectorize the inner output axis, then fuse the matmul into
            # its consumer, split by two SIMD vectors.
            (_, vi) = s[op].op.axis
            s[op].vectorize(vi)
            (yo, yi) = s[outs[0].op].split(
                s[outs[0].op].op.axis[1], 2 * simd_width)
            s[op].compute_at(s[outs[0]], yo)
            s[outs[0].op].vectorize(yi)
        if op.tag == "sparse_dense_bsrmv":
            # BSR: vectorize over the block-row axis, and compute each block
            # result at the reshape stage that flattens (M, NB, BS_R) to (M, N).
            Y_bsrmv = op.input_tensors[0]
            assert Y_bsrmv.op.tag == "sparse_dense_bsrmv_block"
            Y_reshape = op
            (m, nb, br) = s[Y_bsrmv].op.axis
            BS_R = get_const_int(br.dom.extent)
            (elem_idx, c) = s[Y_bsrmv].op.reduce_axis
            s[Y_bsrmv].reorder(nb, m, elem_idx, br, c)
            s[Y_bsrmv].vectorize(br)
            (mo, no) = s[Y_reshape].op.axis
            (noo, noi) = s[Y_reshape].split(no, BS_R)
            s[Y_bsrmv].compute_at(s[Y_reshape], noi)
            s[Y_reshape].vectorize(noi)
            if op != s[outs[0]].op:
                # Fused into a consumer: parallelize and vectorize the consumer.
                (yo, yi) = s[outs[0].op].split(
                    s[outs[0].op].op.axis[1], 2 * simd_width)
                s[Y_reshape].compute_at(s[outs[0]], yo)
                s[outs[0].op].parallel(yo)
                s[outs[0].op].vectorize(yi)
            else:
                # sparse_dense is itself the output: parallelize over the
                # fused (M, NB) loop.
                mo_noo = s[Y_reshape].fuse(mo, noo)
                s[Y_reshape].parallel(mo_noo)

    traverse_inline(s, outs[0].op, callback)
    return s
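
To see what the BSR path of this schedule produces, one can lower a small workload and inspect the IR; a hedged sketch (placeholder shapes are illustrative: 20 nonzero (16, 1) blocks, indptr length 5, so N = 64):

    import tvm
    import topi

    X = tvm.placeholder((8, 128), name="X", dtype="float32")
    W_data = tvm.placeholder((20, 16, 1), name="W_data", dtype="float32")
    W_indices = tvm.placeholder((20,), name="W_indices", dtype="int32")
    W_indptr = tvm.placeholder((5,), name="W_indptr", dtype="int32")

    with tvm.target.create("llvm"):
        Y = topi.nn.sparse_dense(X, W_data, W_indices, W_indptr)
        s = topi.generic.schedule_sparse_dense([Y])  # dispatches to the x86 schedule

    # The printed IR should show the fused (M, NB) parallel loop and the
    # vectorized BS_R axis.
    print(tvm.lower(s, [X, W_data, W_indices, W_indptr, Y], simple_mode=True))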
