|
| 1 | +# Licensed to the Apache Software Foundation (ASF) under one |
| 2 | +# or more contributor license agreements. See the NOTICE file |
| 3 | +# distributed with this work for additional information |
| 4 | +# regarding copyright ownership. The ASF licenses this file |
| 5 | +# to you under the Apache License, Version 2.0 (the |
| 6 | +# "License"); you may not use this file except in compliance |
| 7 | +# with the License. You may obtain a copy of the License at |
| 8 | +# |
| 9 | +# http://www.apache.org/licenses/LICENSE-2.0 |
| 10 | +# |
| 11 | +# Unless required by applicable law or agreed to in writing, |
| 12 | +# software distributed under the License is distributed on an |
| 13 | +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| 14 | +# KIND, either express or implied. See the License for the |
| 15 | +# specific language governing permissions and limitations |
| 16 | +# under the License. |
| 17 | + |
| 18 | +"""Sparse operators""" |
| 19 | +from __future__ import absolute_import |
| 20 | +import tvm |
| 21 | + |
| 22 | +from ..util import get_const_tuple |
| 23 | + |
| 24 | + |
@tvm.target.generic_func
def sparse_dense(data, weight_data, weight_indices, weight_indptr):
    """
    Computes sparse-dense matrix multiplication of `data` and
    `(weight_data, weight_indices, weight_indptr).T`

    The sparse format is selected from the rank of `weight_data`:
    rank 1 dispatches to the CSR kernel, rank 3 to the BSR kernel.

    Parameters
    ----------
    data : tvm.Tensor
        2-D with shape [M, K], float32

    weight_data : tvm.Tensor
        1-D with shape [nnz] (CSR) or
        3-D with shape [num_blocks, bs_r, bs_c] (BSR)

    weight_indices : tvm.Tensor
        1-D with shape [nnz] (CSR) or
        1-D with shape [num_blocks] (BSR)

    weight_indptr : tvm.Tensor
        1-D with shape [N + 1] (CSR) or
        1-D with shape [N // bs_r + 1] (BSR)

    Returns
    -------
    output : tvm.Tensor
        2-D with shape [M, N]
    """
    ndim = len(weight_data.shape)
    assert ndim in (1, 3), (
        "weight_data must be 1-D (CSR) or 3-D (BSR), got %d-D" % ndim)
    # A total if/else keeps `func` bound even when assertions are
    # stripped (python -O); the two kernels are mutually exclusive.
    if ndim == 1:
        func = _sparse_dense_csrmm
    else:
        func = _sparse_dense_bsrmm
    return func(data, weight_data, weight_indices, weight_indptr)
| 59 | + |
| 60 | + |
def _sparse_dense_csrmm(data, weight_data, weight_indices, weight_indptr):
    """Declare `data` times the transpose of a CSR-encoded weight.

    The output has one row per row of `data` and one column per sparse
    row of the weight, i.e. `len(weight_indptr) - 1` columns. Output
    element (i, row) sums `weight_data[e] * data[i, weight_indices[e]]`
    over the entries `e` in `[weight_indptr[row], weight_indptr[row + 1])`.
    """
    num_out_rows = get_const_tuple(data.shape)[0]
    num_out_cols = get_const_tuple(weight_indptr.shape)[0] - 1

    # NOTE(review): the argument names (i, row) are kept unchanged because
    # tvm.compute appears to derive axis names from them -- confirm.
    def _row_dot(i, row):
        first = weight_indptr[row]
        last = weight_indptr[row + 1]
        # The reduction extent is data dependent: number of stored
        # entries in this sparse row.
        elem_idx = tvm.reduce_axis((0, last - first), name="elem_idx")
        pos = first + elem_idx
        sparse_val = weight_data[pos]
        dense_val = data[i, weight_indices[pos]]
        return tvm.sum(sparse_val * dense_val, axis=elem_idx)

    return tvm.compute(
        (num_out_rows, num_out_cols), _row_dot, tag="sparse_dense_csrmm")
| 76 | + |
| 77 | + |
def _sparse_dense_bsrmm(data, weight_data, weight_indices, weight_indptr):
    """Declare `data` times the transpose of a BSR-encoded weight.

    Two stages are declared: an intermediate tensor of shape
    (rows of data, block rows, bs_r) holding the per-block-row results,
    then a 2-D view of it with `block rows * bs_r` columns.
    """
    batch, _ = get_const_tuple(data.shape)
    _, bs_r, bs_c = get_const_tuple(weight_data.shape)
    [indptr_len] = get_const_tuple(weight_indptr.shape)
    # indptr has one more entry than there are block rows.
    num_block_rows = indptr_len - 1

    # NOTE(review): the argument names of both compute callbacks are kept
    # unchanged because tvm.compute appears to derive axis names from
    # them -- confirm.
    def _block_row_dot(i, nb_j, j):
        first = weight_indptr[nb_j]
        last = weight_indptr[nb_j + 1]
        # Data-dependent extent: number of stored blocks in this block row.
        elem_idx = tvm.reduce_axis((0, last - first), name="elem_idx")
        blk = first + elem_idx
        c = tvm.reduce_axis((0, bs_c), name="c")
        blk_col = weight_indices[blk]
        w_val = weight_data[blk][j][c]
        d_val = data[i, bs_c * blk_col + c]
        # Reduce over the blocks of the row and the columns inside a block.
        return tvm.sum(w_val * d_val, axis=[elem_idx, c])

    blocked = tvm.compute(
        (batch, num_block_rows, bs_r),
        _block_row_dot,
        tag="sparse_dense_bsrmm_block")

    def _flatten(m, n):
        # Map flat output column n back to (block row, row within block).
        return blocked[m, n // bs_r, n % bs_r]

    return tvm.compute(
        (batch, num_block_rows * bs_r), _flatten, tag="sparse_dense_bsrmm")
0 commit comments