# under the License.
# pylint: disable=invalid-name
"""GEMM kernel generator and profiler for CUTLASS."""
-import logging
import os
import re
import tempfile
import subprocess
import multiprocessing
from .gemm_operation import GemmOperation, EmitGemmInstance
from .gemm_profiler import GemmProfilerEmitter
+from .gen_tensor_op import (
+    ProfilerEngine,
+    generate_sm75_tensor_op_1688,
+    generate_sm80_tensor_op_16816,
+)
from .library import (
    EpilogueFunctor,
    SwizzlingFunctor,
    TileDescription,
)

-logger = logging.getLogger("cutlass")
-

-def create_gemm_operator(
+def _create_gemm_operator(
    layouts,
    tile_descriptions,
    data_type,
@@ -132,141 +134,32 @@ def create_gemm_operator(
    return ret


-def generate_tensor_op_common(
-    math_instructions, alignment_constraints, get_tile_descriptions, batched=False
-):
-    """Common kernel generator to be used by architecture-specific generators."""
-    ops = []
-    layouts = [
-        (LayoutType.RowMajor, LayoutType.ColumnMajor, LayoutType.RowMajor),
-    ]
-    for math_inst in math_instructions:
-        tile_descriptions = get_tile_descriptions(math_inst)
-        data_type = [
-            math_inst.element_a,
-            math_inst.element_b,
-            math_inst.element_accumulator,
-            math_inst.element_accumulator,
-        ]
-
-        out = create_gemm_operator(
-            layouts, tile_descriptions, data_type, alignment_constraints, batched=batched
+def create_gemm_operator(batched):
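+    # `batched` is captured in a closure so the architecture-specific
+    # generators (imported from gen_tensor_op) can create both regular and
+    # batched GEMM operators through a single `op_creator` callback.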
+    def op_creator(
+        layouts,
+        tile_descriptions,
+        data_type,
+        alignment_constraints,
+        swizzling_functor=SwizzlingFunctor.Identity8,
+    ):
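+        # Forward all kernel parameters to _create_gemm_operator, injecting
+        # the captured `batched` flag.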
+        return _create_gemm_operator(
+            layouts,
+            tile_descriptions,
+            data_type,
+            alignment_constraints,
+            swizzling_functor,
+            batched=batched,
        )

-        ops.extend(out)
-
-    return ops
-
-
-def generate_sm75_tensor_op_1688(out_dtype, batched=False):
-    """Generate GEMM kernels for Turing."""
-    assert out_dtype in ["float32", "float16"]
-    math_instructions = {
-        "float32": [
-            MathInstruction(
-                [16, 8, 8],
-                DataType.f16,
-                DataType.f16,
-                DataType.f32,
-                OpcodeClass.TensorOp,
-                MathOperation.multiply_add,
-            )
-        ],
-        "float16": [
-            MathInstruction(
-                [16, 8, 8],
-                DataType.f16,
-                DataType.f16,
-                DataType.f16,
-                OpcodeClass.TensorOp,
-                MathOperation.multiply_add,
-            )
-        ],
-    }[out_dtype]
-
-    alignment_constraints = [8, 4, 2, 1]
-
-    def get_tile_descriptions(math_inst):
-        min_cc = 75
-        max_cc = 1024
-        return [
-            TileDescription([256, 128, 32], 2, [4, 2, 1], math_inst, min_cc, max_cc),
-            TileDescription([128, 256, 32], 2, [2, 4, 1], math_inst, min_cc, max_cc),
-            TileDescription([128, 128, 32], 2, [2, 2, 1], math_inst, min_cc, max_cc),
-            TileDescription([64, 128, 32], 2, [2, 2, 1], math_inst, min_cc, max_cc),
-            TileDescription([128, 64, 32], 2, [2, 2, 1], math_inst, min_cc, max_cc),
-            TileDescription([64, 64, 32], 2, [2, 2, 1], math_inst, min_cc, max_cc),
-            TileDescription([64, 128, 64], 2, [1, 2, 2], math_inst, min_cc, max_cc),
-        ]
-
-    return generate_tensor_op_common(
-        math_instructions, alignment_constraints, get_tile_descriptions, batched
-    )
-
-
-def generate_sm80_tensor_op_16816(out_dtype, batched=False):
-    """Generate GEMM kernels for Ampere."""
-    assert out_dtype in ["float32", "float16"]
-    math_instructions = {
-        "float32": [
-            MathInstruction(
-                [16, 8, 16],
-                DataType.f16,
-                DataType.f16,
-                DataType.f32,
-                OpcodeClass.TensorOp,
-                MathOperation.multiply_add,
-            )
-        ],
-        "float16": [
-            MathInstruction(
-                [16, 8, 16],
-                DataType.f16,
-                DataType.f16,
-                DataType.f16,
-                OpcodeClass.TensorOp,
-                MathOperation.multiply_add,
-            )
-        ],
-    }[out_dtype]
-
-    alignment_constraints = [8, 4, 2]
-
-    def get_tile_descriptions(math_inst):
-        min_cc = 80
-        max_cc = 1024
-        max_cc_smem_limited = 80
-        return [
-            TileDescription([256, 128, 32], 3, [4, 2, 1], math_inst, min_cc, max_cc),
-            TileDescription([128, 256, 32], 3, [2, 4, 1], math_inst, min_cc, max_cc),
-            TileDescription([256, 64, 32], 4, [4, 1, 1], math_inst, min_cc, max_cc),
-            TileDescription([64, 256, 32], 4, [1, 4, 1], math_inst, min_cc, max_cc),
-            TileDescription([128, 128, 32], 3, [2, 2, 1], math_inst, min_cc, max_cc),
-            TileDescription([128, 128, 32], 4, [2, 2, 1], math_inst, min_cc, max_cc),
-            TileDescription([128, 128, 32], 5, [2, 2, 1], math_inst, min_cc, max_cc),
-            TileDescription([128, 64, 32], 6, [2, 2, 1], math_inst, min_cc, max_cc),
-            TileDescription([64, 128, 32], 6, [2, 2, 1], math_inst, min_cc, max_cc),
-            TileDescription([64, 64, 32], 10, [2, 2, 1], math_inst, min_cc, max_cc),
-            TileDescription([256, 128, 64], 3, [4, 2, 1], math_inst, min_cc, max_cc_smem_limited),
-            TileDescription([128, 256, 64], 3, [2, 4, 1], math_inst, min_cc, max_cc_smem_limited),
-            TileDescription([256, 64, 64], 4, [4, 1, 1], math_inst, min_cc, max_cc_smem_limited),
-            TileDescription([64, 256, 64], 4, [1, 4, 1], math_inst, min_cc, max_cc_smem_limited),
-            TileDescription([128, 128, 64], 4, [2, 2, 1], math_inst, min_cc, max_cc),
-            TileDescription([128, 64, 64], 3, [2, 2, 1], math_inst, min_cc, max_cc),
-            TileDescription([64, 128, 64], 3, [2, 2, 1], math_inst, min_cc, max_cc),
-            TileDescription([64, 64, 64], 5, [2, 2, 1], math_inst, min_cc, max_cc),
-        ]
-
-    return generate_tensor_op_common(
-        math_instructions, alignment_constraints, get_tile_descriptions, batched
-    )
+    return op_creator


GENERATOR_FUNC_TABLE = {
    75: generate_sm75_tensor_op_1688,
    80: generate_sm80_tensor_op_16816,
}

+
# TODO(masahi): A sensible way to pick reasonable default kernels
DEFAULT_KERNELS = {
    75: {
@@ -280,66 +173,6 @@ def get_tile_descriptions(math_inst):
280173}
281174
282175
-class ProfilerEngine:
-    """Compile and run a given profiler executable."""
-
-    def __init__(self, cuda_arch, cutlass_path, binary_prefix):
-        self.cuda_arch = cuda_arch
-        self.binary_prefix = binary_prefix
-        self.cutlass = cutlass_path
-        self.cflags = "-I{cutlass}/include -I{cutlass}/tools/util/include -O3 -std=c++11".format(
-            cutlass=cutlass_path
-        )
-        self.cflags += " -DCUTLASS_ENABLE_TENSOR_CORE_MMA=1"
-        self.cflags += " -gencode=arch=compute_{arch},code=[sm_{arch},compute_{arch}]".format(
-            arch=cuda_arch
-        )
-        self.cflags += " -Xcompiler=-Wconversion -Xcompiler=-fno-strict-aliasing"
-        self.cmd = "nvcc {cflags} {src} -o {output}"
-
-    def _compile(self, op):
-        os.makedirs(self.binary_prefix, exist_ok=True)
-        opath = os.path.join(self.binary_prefix, op["name"])
-        if os.path.exists(opath):
-            return
-        fi = tempfile.NamedTemporaryFile("w", delete=False, suffix=".cu")
-        fi.write(op["src"])
-        fi.close()
-        cmd = self.cmd.format(cflags=self.cflags, src=fi.name, output=opath)
-        os.system(cmd)
-        os.unlink(fi.name)
-
-    def compile_all(self, ops, use_multiprocessing=False):
-        """Compile all profiler executables."""
-        if use_multiprocessing:
-            pool = multiprocessing.Pool(multiprocessing.cpu_count())
-            pool.map(self._compile, ops)
-        else:
-            for op in ops:
-                self._compile(op)
-
-    def evaluate(self, op, args):
-        """Run the profiler executable corresponding to op_name with args."""
-        op_name = op["name"]
-        opath = os.path.join(self.binary_prefix, op_name)
-        if not os.path.exists(opath):
-            self._compile(op)
-        cmd = [opath]
-        if args is not None:
-            cmd.append(str(args[0]))
-            cmd.append(str(args[1]))
-            cmd.append(str(args[2]))
-            if len(args) > 3:
-                cmd.append(str(args[3]))
-        try:
-            sp = subprocess.run(cmd, capture_output=True, check=True)
-            rt = float(sp.stdout)
-            logger.info("%s, %f", op_name, rt)
-        except subprocess.CalledProcessError:
-            rt = -1
-        return rt
-
-
class CutlassGemmProfiler:
    """Profile all candidate kernels and select the best one."""

@@ -362,7 +195,9 @@ def get_default(self, out_dtype, batched=False):
        """Return the default kernel for the requested architecture.
        For now, the default kernel was picked arbitrarily.
        """
-        ops = GENERATOR_FUNC_TABLE[self.sm](out_dtype, batched)
+        ops = GENERATOR_FUNC_TABLE[self.sm](
+            out_dtype, op_creator=create_gemm_operator(batched)
+        )
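+        # Look up the default kernel by its generated name; exactly one of
+        # the generated ops must match it.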
        default_kernel_name = DEFAULT_KERNELS[self.sm][out_dtype]
        filtered = list(filter(lambda op: op["name"] == default_kernel_name, ops))
        assert len(filtered) == 1
@@ -378,7 +213,9 @@ def profile(
        if (M, N, K) in self.cache:
            return self.cache[(M, N, K)]

-        ops = GENERATOR_FUNC_TABLE[self.sm](out_dtype, batched)
+        ops = GENERATOR_FUNC_TABLE[self.sm](
+            out_dtype, op_creator=create_gemm_operator(batched)
+        )
        ops = list(filter(lambda op: self.check_align(op["name"], M), ops))

        for op in ops: