2323from ..utils import get_const_tuple
2424
2525
26- def sparse_dense (data , weight_data , weight_indices , weight_indptr ):
26+ def sparse_dense_v2 (data , weight_data , weight_indices , weight_indptr ):
2727 """
2828 Computes sparse-dense matrix multiplication of `data` and
2929 `(weight_data, weight_indices, weight_indptr).T`
@@ -52,13 +52,104 @@ def sparse_dense(data, weight_data, weight_indices, weight_indptr):
5252 """
5353 assert len (weight_data .shape ) in (1 , 3 )
5454 if len (weight_data .shape ) == 1 :
55- func = _sparse_dense_csrmm
55+ func = _sparse_dense_csrmm_v2
5656 if len (weight_data .shape ) == 3 :
57- func = _sparse_dense_bsrmm
57+ func = _sparse_dense_bsrmm_v2
5858 return func (data , weight_data , weight_indices , weight_indptr )
5959
6060
61- def _sparse_dense_csrmm (data , weight_data , weight_indices , weight_indptr ):
def sparse_dense_v1(data_data, data_indices, data_indptr, weight):
    """
    Computes sparse-dense matrix multiplication of
    `(data_data, data_indices, data_indptr)` and `weight.T`

    The sparse operand is on the left-hand side. Its storage format is
    inferred from the rank of `data_data`: 1-D means CSR, 3-D means BSR.

    Parameters
    ----------
    data_data : tvm.te.Tensor
        1-D with shape [nnz] (CSR) or
        3-D with shape [num_blocks, bs_r, bs_c] (BSR)

    data_indices : tvm.te.Tensor
        1-D with shape [nnz] (CSR) or
        1-D with shape [num_blocks] (BSR)

    data_indptr : tvm.te.Tensor
        1-D with shape [M + 1] (CSR) or
        1-D with shape [M // bs_r + 1] (BSR)

    weight : tvm.te.Tensor
        2-D with shape [N, K], float32

    Returns
    -------
    output : tvm.te.Tensor
        2-D with shape [M, N]
    """
    ndim = len(data_data.shape)
    assert ndim in (1, 3), "data_data must be 1-D (CSR) or 3-D (BSR)"
    if ndim == 1:
        return _sparse_dense_csrmm_v1(data_data, data_indices, data_indptr, weight)
    return _sparse_dense_bsrmm_v1(data_data, data_indices, data_indptr, weight)
94+
95+
# pylint: disable=no-else-return,inconsistent-return-statements
def sparse_dense(dense_data, sparse_data, sparse_indices, sparse_indptr, sparse_lhs=False):
    """
    Computes sparse-dense matrix multiplication of `dense_data` and
    `(sparse_data, sparse_indices, sparse_indptr).T`, if sparse_lhs=False
    or
    Computes sparse-dense matrix multiplication of
    `(sparse_data, sparse_indices, sparse_indptr)` and `dense_data.T`, if sparse_lhs=True

    Parameters
    ----------
    dense_data : tvm.te.Tensor
        2-D with shape [M, K] if sparse_lhs=False, or
        2-D with shape [N, K] if sparse_lhs=True; float32

    sparse_data : tvm.te.Tensor
        1-D with shape [nnz] (CSR) or
        3-D with shape [num_blocks, bs_r, bs_c] (BSR)

    sparse_indices : tvm.te.Tensor
        1-D with shape [nnz] (CSR) or
        1-D with shape [num_blocks] (BSR)

    sparse_indptr : tvm.te.Tensor
        1-D with shape [R + 1] (CSR) or
        1-D with shape [R // bs_r + 1] (BSR),
        where R is N if sparse_lhs=False and M if sparse_lhs=True

    sparse_lhs : bool, optional
        Indicates whether lhs or rhs matrix is sparse. Default value is False.

    Returns
    -------
    output : tvm.te.Tensor
        2-D with shape [M, N]
    """
    if sparse_lhs:
        # Sparse operand on the left: the v1 kernels take the sparse triple first.
        return sparse_dense_v1(sparse_data, sparse_indices, sparse_indptr, dense_data)
    # Sparse operand on the right: the v2 kernels take the dense tensor first.
    return sparse_dense_v2(dense_data, sparse_data, sparse_indices, sparse_indptr)
134+
135+
def _sparse_dense_csrmm_v1(data_data, data_indices, data_indptr, weight):
    """
    CSR kernel for `(data_data, data_indices, data_indptr)` times `weight.T`
    with the sparse operand on the left-hand side.

    The output has `len(data_indptr) - 1` rows and `weight.shape[0]` columns.
    Each output element reduces over the stored entries of one sparse row.
    """
    oshape = (get_const_tuple(data_indptr.shape)[0] - 1, get_const_tuple(weight.shape)[0])

    # NOTE(review): parameter names are left untouched; te.compute may use
    # them to name the output axes -- confirm before renaming.
    def f(row, i):
        # Stored entries of sparse row `row` live in [row_start, row_end).
        row_start = data_indptr[row]
        row_end = data_indptr[row + 1]
        row_elems = row_end - row_start
        elem_idx = te.reduce_axis((0, row_elems), name="elem_idx")
        elem = row_start + elem_idx
        a_val = data_data[elem]
        # data_indices[elem] is the column of the stored entry, which selects
        # the matching column of row `i` of the dense weight.
        weight_val = weight[i, data_indices[elem]]
        return te.sum(a_val * weight_val, axis=elem_idx)

    return te.compute(oshape, f, tag="sparse_dense_csrmm_v1")
150+
151+
152+ def _sparse_dense_csrmm_v2 (data , weight_data , weight_indices , weight_indptr ):
62153 oshape = (get_const_tuple (data .shape )[0 ], get_const_tuple (weight_indptr .shape )[0 ] - 1 )
63154
64155 def f (i , row ):
@@ -71,10 +162,41 @@ def f(i, row):
71162 weight_val = data [i , weight_indices [elem ]]
72163 return te .sum (a_val * weight_val , axis = elem_idx )
73164
74- return te .compute (oshape , f , tag = "sparse_dense_csrmm " )
165+ return te .compute (oshape , f , tag = "sparse_dense_csrmm_v2 " )
75166
76167
77- def _sparse_dense_bsrmm (data , weight_data , weight_indices , weight_indptr ):
def _sparse_dense_bsrmm_v1(data_data, data_indices, data_indptr, weight):
    """
    BSR kernel for `(data_data, data_indices, data_indptr)` times `weight.T`
    with the sparse operand on the left-hand side.

    A blocked intermediate of shape `(num_blocks, bs_r, weight.shape[0])` is
    computed first, then flattened to `(num_blocks * bs_r, weight.shape[0])`.
    Here `num_blocks` is `len(data_indptr) - 1`, i.e. the number of block rows.
    """
    (weight_rows, _) = get_const_tuple(weight.shape)
    (_, bs_r, bs_c) = get_const_tuple(data_data.shape)
    (num_blocks_plus_1,) = get_const_tuple(data_indptr.shape)
    num_blocks = num_blocks_plus_1 - 1

    # NOTE(review): parameter names of the compute functions are left
    # untouched; te.compute may use them to name the output axes -- confirm
    # before renaming.
    def _compute_block(nb_j, j, i):
        # Stored blocks of block-row `nb_j` live in [row_start, row_end).
        row_start = data_indptr[nb_j]
        row_end = data_indptr[nb_j + 1]
        row_elems = row_end - row_start
        elem_idx = te.reduce_axis((0, row_elems), name="elem_idx")
        block_offset = row_start + elem_idx
        c = te.reduce_axis((0, bs_c), name="c")
        block_j = data_indices[block_offset]
        block_ij_val = data_data[block_offset][j][c]
        # Block column `block_j` covers dense columns starting at bs_c * block_j.
        x_val = weight[i, bs_c * block_j + c]
        return te.sum(block_ij_val * x_val, axis=[elem_idx, c])

    idxd = tvm.tir.indexdiv
    idxm = tvm.tir.indexmod

    bsrmm_block = te.compute(
        (num_blocks, bs_r, weight_rows), _compute_block, tag="sparse_dense_bsrmm_block_v1"
    )
    # Flatten (block-row, row-in-block) into a single output row index.
    return te.compute(
        (num_blocks * bs_r, weight_rows),
        lambda m, n: bsrmm_block[idxd(m, bs_r), idxm(m, bs_r), n],
        tag="sparse_dense_bsrmm_v1",
    )
197+
198+
199+ def _sparse_dense_bsrmm_v2 (data , weight_data , weight_indices , weight_indptr ):
78200 (m , _ ) = get_const_tuple (data .shape )
79201 (_ , bs_r , bs_c ) = get_const_tuple (weight_data .shape )
80202 (num_blocks_plus_1 ,) = get_const_tuple (weight_indptr .shape )
@@ -95,11 +217,13 @@ def _compute_block(i, nb_j, j):
95217 idxd = tvm .tir .indexdiv
96218 idxm = tvm .tir .indexmod
97219
98- bsrmm_block = te .compute ((m , num_blocks , bs_r ), _compute_block , tag = "sparse_dense_bsrmm_block" )
220+ bsrmm_block = te .compute (
221+ (m , num_blocks , bs_r ), _compute_block , tag = "sparse_dense_bsrmm_block_v2"
222+ )
99223 return te .compute (
100224 (m , num_blocks * bs_r ),
101225 lambda m , n : bsrmm_block [m , idxd (n , bs_r ), idxm (n , bs_r )],
102- tag = "sparse_dense_bsrmm " ,
226+ tag = "sparse_dense_bsrmm_v2 " ,
103227 )
104228
105229
0 commit comments