 # under the License.
 # pylint: disable=invalid-name,missing-function-docstring
 """Intrinsics for tensorization on NVIDIA GPU."""
+from .. import Cast
 from ..._ffi import register_func
 from ...runtime import convert
 from .. import TensorIntrin
@@ -46,6 +47,7 @@ def index_map_shared_16x16_to_ldmatrix_32x8_layout(i, j):
 lift = convert

 M_DIM = 16
+N_DIM = 16
 WARP_SIZE = 32
 HALF_WARP = WARP_SIZE // 2
 HALF_WARP_expr = lift(HALF_WARP)
@@ -81,7 +83,6 @@ def get_ldmatrix_intrin(k_dim, dtype, is_b, transposed):
         assert dtype == "int8"

         if ldmatrix_col_major:
-            print("foo")
             index_map = shared_32x16_to_ldmatrix_32x16_layout
             shared_offset = (
                 lambda _, stride: stride
@@ -172,6 +173,148 @@ def ldmatrix_impl(warp_handle: T.handle, shared_handle: T.handle) -> None:
     return ldmatrix_desc, ldmatrix_impl


+def get_mma_intrin(k_dim, out_dtype, transposed):
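+    # Build the desc/impl pair for a warp-level mma.sync tensor intrinsic.
+    # local_size / local_size_out are the per-thread sizes of the input and
+    # 16x16 accumulator fragments distributed across the 32 threads of a warp.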
+    local_size = (M_DIM * k_dim) // WARP_SIZE
+    local_size_out = (M_DIM * N_DIM) // 32
+
+    index_map_C = shared_16x16_to_ldmatrix_32x8_layout
+
+    if k_dim == 16:
+        index_map_A = shared_16x16_to_ldmatrix_32x8_layout
+        index_map_B = shared_16x16_to_ldmatrix_32x8_layout
+        mma_prefix = "m16n8k16"
+    elif k_dim == 32 and transposed:
+        index_map_A = index_map_B = shared_16x32_to_ldmatrix_32x16_layout
+        mma_prefix = "m16n8k32"
+    elif k_dim == 32 and not transposed:
+        index_map_A = shared_16x32_to_ldmatrix_32x16_layout
+        index_map_B = shared_32x16_to_ldmatrix_32x16_layout
+        mma_prefix = "m16n8k32"
+    else:
+        assert False
+
+    out_dtype_abbrv = {"float16": "fp16", "float32": "fp32", "int32": "int32"}[out_dtype]
+
+    if out_dtype in ["float16", "float32"]:
+        in_dtype = "float16"
+        in_dtype_abbrv = "fp16"
+    else:
+        in_dtype = "int8"
+        in_dtype_abbrv = "int8"
+
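+    # With fp32/int32 accumulation the product must be widened to the
+    # accumulator dtype, matching what mma.sync computes.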
+    def maybe_cast(v):
+        if out_dtype in ["float32", "int32"]:
+            return Cast(out_dtype, v)
+        return v
+
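+    # The transposed variants take B in (n, k) layout, so swap the indices
+    # before applying the B index map.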
+    def maybe_swap(i, j):
+        if transposed:
+            return j, i
+        return i, j
+
+    @T.prim_func
+    def mma_sync_desc(a: T.handle, b: T.handle, c: T.handle) -> None:
+        A = T.match_buffer(
+            a, (WARP_SIZE, local_size), in_dtype, align=128, offset_factor=16, scope="warp"
+        )
+        B = T.match_buffer(
+            b, (WARP_SIZE, local_size), in_dtype, align=128, offset_factor=16, scope="warp"
+        )
+        C = T.match_buffer(
+            c, (WARP_SIZE, local_size_out), out_dtype, align=128, offset_factor=16, scope="warp"
+        )
+
+        with T.block("root"):
+            T.reads(
+                C[0:WARP_SIZE, 0:local_size_out],
+                A[0:WARP_SIZE, 0:local_size],
+                B[0:WARP_SIZE, 0:local_size],
+            )
+            T.writes(C[0:WARP_SIZE, 0:local_size_out])
+
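+            # The desc spells out the 16 x 16 x k_dim matmul on the warp-level
+            # fragments, mapping each (row, col) access through the index maps.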
+            for i, j, k in T.grid(M_DIM, N_DIM, k_dim):
+                with T.block("C"):
+                    i, j, k = T.axis.remap("SSR", [i, j, k])
+                    b_row_ind, b_col_ind = maybe_swap(k, j)
+
+                    thread_id_C, local_id_C = index_map_C(i, j)
+                    thread_id_A, local_id_A = index_map_A(i, k)
+                    thread_id_B, local_id_B = index_map_B(b_row_ind, b_col_ind)
+
+                    T.reads(
+                        C[thread_id_C, local_id_C],
+                        A[thread_id_A, local_id_A],
+                        B[thread_id_B, local_id_B],
+                    )
+                    T.writes(C[thread_id_C, local_id_C])
+
+                    C[thread_id_C, local_id_C] += maybe_cast(
+                        A[thread_id_A, local_id_A]
+                    ) * maybe_cast(B[thread_id_B, local_id_B])
+
+    @T.prim_func
+    def mma_sync_impl(a: T.handle, b: T.handle, c: T.handle) -> None:
+        A = T.match_buffer(
+            a, (WARP_SIZE, local_size), in_dtype, align=128, offset_factor=16, scope="warp"
+        )
+        B = T.match_buffer(
+            b, (WARP_SIZE, local_size), in_dtype, align=128, offset_factor=16, scope="warp"
+        )
+        C = T.match_buffer(
+            c, (WARP_SIZE, local_size_out), out_dtype, align=128, offset_factor=16, scope="warp"
+        )
+
+        with T.block("root"):
+            T.reads(
+                C[0:WARP_SIZE, 0:local_size_out],
+                A[0:WARP_SIZE, 0:local_size],
+                B[0:WARP_SIZE, 0:local_size],
+            )
+            T.writes(C[0:WARP_SIZE, 0:local_size_out])
+            tx = T.env_thread("threadIdx.x")
+            T.launch_thread(tx, WARP_SIZE)
+
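+            # The 16x16 accumulator tile is computed as two m16n8k mma.sync
+            # calls: the first covers columns 0-7 of C, the second the remaining
+            # columns 8-15, i.e. the second half of each thread's B and C fragments.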
+            T.evaluate(
+                T.ptx_mma(
+                    mma_prefix,
+                    "row",
+                    "col",
+                    in_dtype_abbrv,
+                    in_dtype_abbrv,
+                    out_dtype_abbrv,
+                    A.data,
+                    A.elem_offset + tx * lift(local_size),
+                    B.data,
+                    B.elem_offset + tx * lift(local_size),
+                    C.data,
+                    C.elem_offset + tx * lift(local_size_out),
+                    False,
+                    dtype=out_dtype,
+                )
+            )
+
+            T.evaluate(
+                T.ptx_mma(
+                    mma_prefix,
+                    "row",
+                    "col",
+                    in_dtype_abbrv,
+                    in_dtype_abbrv,
+                    out_dtype_abbrv,
+                    A.data,
+                    A.elem_offset + tx * lift(local_size),
+                    B.data,
+                    B.elem_offset + tx * lift(local_size) + lift(local_size) // 2,
+                    C.data,
+                    C.elem_offset + tx * lift(local_size_out) + lift(local_size_out) // 2,
+                    False,
+                    dtype=out_dtype,
+                )
+            )
+
+    return mma_sync_desc, mma_sync_impl
+
+
 LDMATRIX_16x16_A_INTRIN = "mma.ldmatrix_16x16_a"
 TensorIntrin.register(LDMATRIX_16x16_A_INTRIN, *get_ldmatrix_intrin(16, "float16", False, False))

@@ -191,3 +334,21 @@ def ldmatrix_impl(warp_handle: T.handle, shared_handle: T.handle) -> None:

 LDMATRIX_16x32_B_TRANS_INTRIN = "mma.ldmatrix_16x32_b_trans"
 TensorIntrin.register(LDMATRIX_16x32_B_TRANS_INTRIN, *get_ldmatrix_intrin(32, "int8", True, True))
+
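+# The names below encode the input and accumulator dtypes (e.g. f16f16f32 means
+# fp16 A/B with fp32 accumulation); the *_TRANS variants consume B loaded from a
+# transposed (n, k) layout.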
+MMA_f16f16f32_INTRIN = "mma_f16f16f32"
+TensorIntrin.register(MMA_f16f16f32_INTRIN, *get_mma_intrin(16, "float32", False))
+
+MMA_f16f16f32_TRANS_INTRIN = "mma_f16f16f32_trans"
+TensorIntrin.register(MMA_f16f16f32_TRANS_INTRIN, *get_mma_intrin(16, "float32", True))
+
+MMA_f16f16f16_INTRIN = "mma_f16f16f16"
+TensorIntrin.register(MMA_f16f16f16_INTRIN, *get_mma_intrin(16, "float16", False))
+
+MMA_f16f16f16_TRANS_INTRIN = "mma_f16f16f16_trans"
+TensorIntrin.register(MMA_f16f16f16_TRANS_INTRIN, *get_mma_intrin(16, "float16", True))
+
+MMA_i8i8i32_INTRIN = "mma_i8i8i32"
+TensorIntrin.register(MMA_i8i8i32_INTRIN, *get_mma_intrin(32, "int32", False))
+
+MMA_i8i8i32_TRANS_INTRIN = "mma_i8i8i32_trans"
+TensorIntrin.register(MMA_i8i8i32_TRANS_INTRIN, *get_mma_intrin(32, "int32", True))