4 changes: 4 additions & 0 deletions include/tvm/relay/attrs/nn.h
@@ -366,6 +366,10 @@ struct DenseAttrs : public tvm::AttrsNode<DenseAttrs> {
}
};

/*! \brief Attributes for sparse_dense operator */
struct SparseDenseAttrs : public tvm::AttrsNode<SparseDenseAttrs> {
TVM_DECLARE_ATTRS(SparseDenseAttrs, "relay.attrs.SparseDenseAttrs") {}
};

/*! \brief Attributes for upsampling operator */
struct UpSamplingAttrs : public tvm::AttrsNode<UpSamplingAttrs> {
13 changes: 13 additions & 0 deletions python/tvm/relay/op/nn/_nn.py
@@ -85,6 +85,19 @@ def schedule_batch_matmul(attrs, outputs, target):

reg.register_pattern("nn.batch_matmul", reg.OpPattern.OUT_ELEMWISE_FUSABLE)

# sparse_dense
@reg.register_compute("nn.sparse_dense")
def compute_sparse_dense(attrs, inputs, out_type, target):
"""Compute definition of sparse_dense"""
return [topi.nn.sparse_dense(inputs[0], inputs[1], inputs[2], inputs[3])]

@reg.register_schedule("nn.sparse_dense")
def schedule_sparse_dense(attrs, outputs, target):
"""Schedule definition of batch_matmul"""
with target:
return topi.generic.schedule_sparse_dense(outputs)

reg.register_pattern("nn.sparse_dense", reg.OpPattern.OUT_ELEMWISE_FUSABLE)

# conv2d
def _find_conv2d_op(op):
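
The two registrations above are what let Relay lower `nn.sparse_dense`: the compute hook forwards the four operands to `topi.nn.sparse_dense`, and the schedule hook dispatches on the active target. A minimal sketch of exercising the TOPI pair directly, with hypothetical CSR shapes (`nnz` is the nonzero count):

import tvm
import topi

M, K, N, nnz = 4, 8, 6, 12
data = tvm.placeholder((M, K), name="data", dtype="float32")
w_data = tvm.placeholder((nnz,), name="w_data", dtype="float32")
w_indices = tvm.placeholder((nnz,), name="w_indices", dtype="int32")
w_indptr = tvm.placeholder((N + 1,), name="w_indptr", dtype="int32")

with tvm.target.create("llvm"):
    out = topi.nn.sparse_dense(data, w_data, w_indices, w_indptr)  # shape [M, N]
    s = topi.generic.schedule_sparse_dense([out])
f = tvm.build(s, [data, w_data, w_indices, w_indptr, out], "llvm")
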
33 changes: 33 additions & 0 deletions python/tvm/relay/op/nn/nn.py
@@ -839,6 +839,39 @@ def batch_matmul(x, y):
"""
return _make.batch_matmul(x, y)

def sparse_dense(data, weight):
r"""
Computes the matrix multiplication of `data` with the transpose of `weight`,
where `data` is a dense matrix and `weight` is a sparse (either BSR or CSR)
namedtuple with fields `data`, `indices`, and `indptr`.

.. math::

\mbox{sparse_dense}(data, weight)[m, n] = \mbox{matmul}(data, \mbox{as_dense}(weight)^T)[m, n]

where `as_dense` returns the dense equivalent of the given sparse matrix.

See
https://docs.scipy.org/doc/scipy/reference/generated/scipy.sparse.csr_matrix.html
and
https://docs.scipy.org/doc/scipy-0.14.0/reference/generated/scipy.sparse.bsr_matrix.html
for more detail on the sparse matrix representation.

Parameters
----------
data : tvm.relay.Expr
The input data for the matrix multiplication

weight : namedtuple
The sparse weight matrix for the matrix multiplication.

Returns
-------
result: tvm.relay.Expr
The computed result.
"""
return _make.sparse_dense(data, weight.data, weight.indices, weight.indptr)


def contrib_conv2d_winograd_without_weight_transform(data,
weight,
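
A usage sketch for the API above (illustrative, not part of the diff): any namedtuple carrying the three CSR/BSR arrays works, for instance one built from a scipy matrix. `SparseWeight` is a hypothetical name.

import collections
import scipy.sparse as sp
from tvm import relay

SparseWeight = collections.namedtuple("SparseWeight", ["data", "indices", "indptr"])

M, K, N = 4, 8, 6
x = relay.var("x", shape=(M, K), dtype="float32")
w_sp = sp.random(N, K, density=0.25, format="csr", dtype="float32")
# Wrap the three CSR arrays as Relay constants.
w = SparseWeight(data=relay.const(w_sp.data),
                 indices=relay.const(w_sp.indices),
                 indptr=relay.const(w_sp.indptr))
y = relay.nn.sparse_dense(x, w)  # type-infers to shape (M, N)
func = relay.Function([x], y)
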
97 changes: 97 additions & 0 deletions src/relay/op/nn/sparse.cc
@@ -0,0 +1,97 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

/*!
* Copyright (c) 2018 by Contributors
* \file sparse.cc
* \brief Property def of nn.sparse_dense operator.
*/

#include <tvm/data_layout.h>
#include <tvm/relay/op.h>
#include <tvm/relay/attrs/nn.h>
#include <vector>

#include "../../pass/alter_op_layout.h"

namespace tvm {
namespace relay {

// relay.nn.sparse_dense
TVM_REGISTER_NODE_TYPE(SparseDenseAttrs);

bool SparseDenseRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
const TypeReporter& reporter) {
CHECK_EQ(types.size(), 5);
const auto* data = types[0].as<TensorTypeNode>();
const auto* weight_data = types[1].as<TensorTypeNode>();
const auto* weight_indptr = types[3].as<TensorTypeNode>();
if (data == nullptr || weight_data == nullptr || weight_indptr == nullptr) return false;
CHECK(weight_data->shape.size() == 1 || weight_data->shape.size() == 3);

if (weight_data->shape.size() == 1) {
// CSR case.
Array<IndexExpr> oshape({data->shape[0], weight_indptr->shape[0] - 1});
reporter->Assign(types[4], TensorTypeNode::make(oshape, data->dtype));
return true;
}

if (weight_data->shape.size() == 3) {
// BSR case.
Array<IndexExpr> oshape({
data->shape[0],
(weight_indptr->shape[0] - 1) * weight_data->shape[1]});
reporter->Assign(types[4], TensorTypeNode::make(oshape, data->dtype));
return true;
}
LOG(FATAL) << "Unknown weight ndim for nn.sparse_dense, should be 1 (CSR) or 3 (BSR)";
return false;
}

// Positional relay function to create sparse_dense operator used by frontend FFI.
Expr MakeSparseDense(Expr data, Expr weight_data, Expr weight_indices, Expr weight_indptr) {
auto attrs = make_node<SparseDenseAttrs>();
static const Op& op = Op::Get("nn.sparse_dense");
return CallNode::make(op, {data, weight_data, weight_indices, weight_indptr}, Attrs(attrs), {});
}

TVM_REGISTER_API("relay.op.nn._make.sparse_dense")
.set_body([](const TVMArgs& args, TVMRetValue* rv) {
runtime::detail::unpack_call<Expr, 4>(MakeSparseDense, args, rv);
});

RELAY_REGISTER_OP("nn.sparse_dense")
.describe(R"code(Applies a sparse linear transformation: :math:`Y = XW^T` with X sparse.

- **data**: `(x1, x2, ..., xn, input_dim)`
- **weight**: `(units, input_dim)`
- **out**: `(x1, x2, ..., xn, units)`.

)code" TVM_ADD_FILELINE)
.set_attrs_type_key("relay.attrs.SparseDenseAttrs")
.set_num_inputs(4)
.add_argument("data", "nD Tensor", "Input data.")
.add_argument("weight_data", "1D Tensor", "Weight data matrix.")
.add_argument("weight_indices", "1D Tensor", "Weight indices matrix.")
.add_argument("weight_indptr", "1D Tensor", "Weight indptr matrix.")
.set_support_level(1)
.add_type_rel("SparseDense", SparseDenseRel);

} // namespace relay
} // namespace tvm
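
A small Python mirror of the relation's shape arithmetic may help; the values below are illustrative only:

def sparse_dense_out_shape(data_shape, weight_data_shape, weight_indptr_shape):
    """Mirror of SparseDenseRel's output-shape logic."""
    m = data_shape[0]
    if len(weight_data_shape) == 1:   # CSR: indptr holds N + 1 row offsets
        return (m, weight_indptr_shape[0] - 1)
    if len(weight_data_shape) == 3:   # BSR: each block row yields bs_r columns
        bs_r = weight_data_shape[1]
        return (m, (weight_indptr_shape[0] - 1) * bs_r)
    raise ValueError("weight ndim must be 1 (CSR) or 3 (BSR)")

assert sparse_dense_out_shape((4, 8), (12,), (7,)) == (4, 6)      # CSR
assert sparse_dense_out_shape((4, 8), (5, 2, 4), (4,)) == (4, 6)  # BSR
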
17 changes: 17 additions & 0 deletions topi/python/topi/generic/nn.py
@@ -513,6 +513,23 @@ def schedule_l2_normalize(outs):
cpp_target = cpp.TEST_create_target(target.target_name)
return cpp.generic.default_schedule(cpp_target, outs, False)

@tvm.target.generic_func
def schedule_sparse_dense(outs):
"""Schedule for sparse_dense

Parameters
----------
outs: Array of Tensor
The computation graph description of sparse_dense
in the format of an array of tensors.

Returns
-------
sch: Schedule
The computation schedule for the op.
"""
return _default_schedule(outs, False)

@tvm.target.generic_func
def schedule_batch_matmul(outs):
target = tvm.target.current_target(allow_none=False)
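
Because `schedule_sparse_dense` is a `tvm.target.generic_func`, backends override it per target key; the x86 file later in this diff registers for "cpu", and a hypothetical backend would follow the same pattern:

import tvm
from topi import generic

@generic.schedule_sparse_dense.register(["cuda"])  # hypothetical target key
def _schedule_sparse_dense_cuda(outs):
    # Start from a vanilla schedule and specialize per tagged op,
    # as the x86 registration below does.
    return tvm.create_schedule([x.op for x in outs])
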
1 change: 1 addition & 0 deletions topi/python/topi/nn/__init__.py
@@ -20,3 +20,4 @@
from .bitserial_dense import *
from .l2_normalize import *
from .batch_matmul import *
from .sparse import *
103 changes: 103 additions & 0 deletions topi/python/topi/nn/sparse.py
@@ -0,0 +1,103 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

"""Sparse operators"""
from __future__ import absolute_import
import tvm

from ..util import get_const_tuple


@tvm.target.generic_func
def sparse_dense(data, weight_data, weight_indices, weight_indptr):
"""
Computes sparse-dense matrix multiplication of `data` and
`(weight_data, weight_indices, weight_indptr).T`

Parameters
----------
data : tvm.Tensor
2-D with shape [M, K], float32

weight_data : tvm.Tensor
1-D with shape [nnz] (CSR) or
3-D with shape [num_blocks, bs_r, bs_c] (BSR)

weight_indices : tvm.Tensor
1-D with shape [nnz] (CSR) or
1-D with shape [num_blocks] (BSR)

weight_indptr : tvm.Tensor
1-D with shape [N + 1] (CSR) or
1-D with shape [N // bs_r + 1] (BSR)

Returns
-------
output : tvm.Tensor
2-D with shape [M, N]
"""
assert len(weight_data.shape) in (1, 3)
if len(weight_data.shape) == 1:
func = _sparse_dense_csrmm
elif len(weight_data.shape) == 3:
func = _sparse_dense_bsrmm
return func(data, weight_data, weight_indices, weight_indptr)


def _sparse_dense_csrmm(data, weight_data, weight_indices, weight_indptr):
oshape = (
get_const_tuple(data.shape)[0],
get_const_tuple(weight_indptr.shape)[0] - 1)

def f(i, row):
row_start = weight_indptr[row]
row_end = weight_indptr[row + 1]
row_elems = row_end - row_start
elem_idx = tvm.reduce_axis((0, row_elems), name="elem_idx")
elem = row_start + elem_idx
weight_val = weight_data[elem]
data_val = data[i, weight_indices[elem]]
return tvm.sum(weight_val * data_val, axis=elem_idx)
return tvm.compute(oshape, f, tag="sparse_dense_csrmm")
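
As a correctness sketch (not part of the PR), the CSR compute above corresponds to this NumPy loop:

import numpy as np

def sparse_dense_csrmm_ref(data, weight_data, weight_indices, weight_indptr):
    """out[i, row] = sum over the row's nonzeros of weight_data * gathered data."""
    m, n = data.shape[0], weight_indptr.shape[0] - 1
    out = np.zeros((m, n), dtype=data.dtype)
    for row in range(n):
        for elem in range(weight_indptr[row], weight_indptr[row + 1]):
            out[:, row] += weight_data[elem] * data[:, weight_indices[elem]]
    return out
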


def _sparse_dense_bsrmm(data, weight_data, weight_indices, weight_indptr):
(m, _) = get_const_tuple(data.shape)
(_, bs_r, bs_c) = get_const_tuple(weight_data.shape)
(num_blocks_plus_1, ) = get_const_tuple(weight_indptr.shape)
num_blocks = num_blocks_plus_1 - 1

def _compute_block(i, nb_j, j):
row_start = weight_indptr[nb_j]
row_end = weight_indptr[nb_j + 1]
row_elems = row_end - row_start
elem_idx = tvm.reduce_axis(
(0, row_elems), name="elem_idx")
block_offset = row_start + elem_idx
c = tvm.reduce_axis((0, bs_c), name="c")
block_j = weight_indices[block_offset]
block_ij_val = weight_data[block_offset][j][c]
x_val = data[i, bs_c * block_j + c]
return tvm.sum(block_ij_val * x_val, axis=[elem_idx, c])

bsrmm_block = tvm.compute(
(m, num_blocks, bs_r), _compute_block,
tag="sparse_dense_bsrmm_block")
return tvm.compute(
(m, num_blocks * bs_r),
lambda i, j: bsrmm_block[i, j // bs_r, j % bs_r],
tag="sparse_dense_bsrmm")
63 changes: 63 additions & 0 deletions topi/python/topi/x86/sparse.py
@@ -0,0 +1,63 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

"""sparse_dense schedule on x86"""
import tvm

from .. import generic
from ..util import traverse_inline, get_const_int
from .util import get_fp32_len


@generic.schedule_sparse_dense.register(["cpu"])
def _schedule_sparse_dense(outs):
s = tvm.create_schedule([x.op for x in outs])

def _callback(op):
simd_width = get_fp32_len()
if op.tag == "sparse_dense_csrmm" and op != outs[0].op:
(_, v_i) = s[op].op.axis
s[op].vectorize(v_i)
(y_o, y_i) = s[outs[0].op].split(
s[outs[0].op].op.axis[1], 2 * simd_width)
s[op].compute_at(s[outs[0]], y_o)
s[outs[0].op].vectorize(y_i)
if op.tag == "sparse_dense_bsrmm":
y_bsrmm = op.input_tensors[0]
assert y_bsrmm.op.tag == "sparse_dense_bsrmm_block"
y_reshape = op
(m, num_blocks, b_r) = s[y_bsrmm].op.axis
bs_r = get_const_int(b_r.dom.extent)
(elem_idx, c) = s[y_bsrmm].op.reduce_axis
s[y_bsrmm].reorder(num_blocks, m, elem_idx, b_r, c)
s[y_bsrmm].vectorize(b_r)
(m_o, n_o) = s[y_reshape].op.axis
(noo, noi) = s[y_reshape].split(n_o, bs_r)
s[y_bsrmm].compute_at(s[y_reshape], noi)
s[y_reshape].vectorize(noi)
if op != s[outs[0]].op:
(y_o, y_i) = s[outs[0].op].split(
s[outs[0].op].op.axis[1], 2 * simd_width)
s[y_reshape].compute_at(s[outs[0]], y_o)
s[outs[0].op].parallel(y_o)
s[outs[0].op].vectorize(y_i)
else:
m_o_noo = s[y_reshape].fuse(m_o, noo)
s[y_reshape].parallel(m_o_noo)

traverse_inline(s, outs[0].op, _callback)
return s
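
To see the vectorized and parallelized loop nest this schedule produces, one can lower a small CSR instance under an x86 target (hypothetical shapes; with `-mcpu=core-avx2`, `get_fp32_len()` should be 8):

import tvm
import topi

data = tvm.placeholder((4, 8), name="data", dtype="float32")
w_data = tvm.placeholder((12,), name="w_data", dtype="float32")
w_indices = tvm.placeholder((12,), name="w_indices", dtype="int32")
w_indptr = tvm.placeholder((7,), name="w_indptr", dtype="int32")

with tvm.target.create("llvm -mcpu=core-avx2"):
    out = topi.nn.sparse_dense(data, w_data, w_indices, w_indptr)
    s = topi.generic.schedule_sparse_dense([out])
print(tvm.lower(s, [data, w_data, w_indices, w_indptr, out], simple_mode=True))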