Skip to content

Commit ca1ae27

Browse files
committed
introduce gen_tensor_op.py
1 parent 37bb918 commit ca1ae27

File tree

2 files changed

+255
-192
lines changed

2 files changed

+255
-192
lines changed

python/tvm/contrib/cutlass/gen_gemm.py

Lines changed: 29 additions & 192 deletions
Original file line numberDiff line numberDiff line change
@@ -16,14 +16,18 @@
1616
# under the License.
1717
# pylint: disable=invalid-name
1818
"""GEMM kernel generator and profiler for CUTLASS."""
19-
import logging
2019
import os
2120
import re
2221
import tempfile
2322
import subprocess
2423
import multiprocessing
2524
from .gemm_operation import GemmOperation, EmitGemmInstance
2625
from .gemm_profiler import GemmProfilerEmitter
26+
from .gen_tensor_op import (
27+
ProfilerEngine,
28+
generate_sm75_tensor_op_1688,
29+
generate_sm80_tensor_op_16816,
30+
)
2731
from .library import (
2832
EpilogueFunctor,
2933
SwizzlingFunctor,
@@ -37,10 +41,8 @@
3741
TileDescription,
3842
)
3943

40-
logger = logging.getLogger("cutlass")
41-
4244

43-
def create_gemm_operator(
45+
def _create_gemm_operator(
4446
layouts,
4547
tile_descriptions,
4648
data_type,
@@ -132,141 +134,32 @@ def create_gemm_operator(
132134
return ret
133135

134136

135-
def generate_tensor_op_common(
136-
math_instructions, alignment_constraints, get_tile_descriptions, batched=False
137-
):
138-
"""Common kernel generator to be used by archtecture specific generators."""
139-
ops = []
140-
layouts = [
141-
(LayoutType.RowMajor, LayoutType.ColumnMajor, LayoutType.RowMajor),
142-
]
143-
for math_inst in math_instructions:
144-
tile_descriptions = get_tile_descriptions(math_inst)
145-
data_type = [
146-
math_inst.element_a,
147-
math_inst.element_b,
148-
math_inst.element_accumulator,
149-
math_inst.element_accumulator,
150-
]
151-
152-
out = create_gemm_operator(
153-
layouts, tile_descriptions, data_type, alignment_constraints, batched=batched
137+
def create_gemm_operator(batched):
138+
def op_creator(
139+
layouts,
140+
tile_descriptions,
141+
data_type,
142+
alignment_constraints,
143+
swizzling_functor=SwizzlingFunctor.Identity8,
144+
):
145+
return _create_gemm_operator(
146+
layouts,
147+
tile_descriptions,
148+
data_type,
149+
alignment_constraints,
150+
swizzling_functor,
151+
batched=batched,
154152
)
155153

156-
ops.extend(out)
157-
158-
return ops
159-
160-
161-
def generate_sm75_tensor_op_1688(out_dtype, batched=False):
162-
"""Generate GEMM kernels for Turing."""
163-
assert out_dtype in ["float32", "float16"]
164-
math_instructions = {
165-
"float32": [
166-
MathInstruction(
167-
[16, 8, 8],
168-
DataType.f16,
169-
DataType.f16,
170-
DataType.f32,
171-
OpcodeClass.TensorOp,
172-
MathOperation.multiply_add,
173-
)
174-
],
175-
"float16": [
176-
MathInstruction(
177-
[16, 8, 8],
178-
DataType.f16,
179-
DataType.f16,
180-
DataType.f16,
181-
OpcodeClass.TensorOp,
182-
MathOperation.multiply_add,
183-
)
184-
],
185-
}[out_dtype]
186-
187-
alignment_constraints = [8, 4, 2, 1]
188-
189-
def get_tile_descriptions(math_inst):
190-
min_cc = 75
191-
max_cc = 1024
192-
return [
193-
TileDescription([256, 128, 32], 2, [4, 2, 1], math_inst, min_cc, max_cc),
194-
TileDescription([128, 256, 32], 2, [2, 4, 1], math_inst, min_cc, max_cc),
195-
TileDescription([128, 128, 32], 2, [2, 2, 1], math_inst, min_cc, max_cc),
196-
TileDescription([64, 128, 32], 2, [2, 2, 1], math_inst, min_cc, max_cc),
197-
TileDescription([128, 64, 32], 2, [2, 2, 1], math_inst, min_cc, max_cc),
198-
TileDescription([64, 64, 32], 2, [2, 2, 1], math_inst, min_cc, max_cc),
199-
TileDescription([64, 128, 64], 2, [1, 2, 2], math_inst, min_cc, max_cc),
200-
]
201-
202-
return generate_tensor_op_common(
203-
math_instructions, alignment_constraints, get_tile_descriptions, batched
204-
)
205-
206-
207-
def generate_sm80_tensor_op_16816(out_dtype, batched=False):
208-
"""Generate GEMM kernels for Ampere."""
209-
assert out_dtype in ["float32", "float16"]
210-
math_instructions = {
211-
"float32": [
212-
MathInstruction(
213-
[16, 8, 16],
214-
DataType.f16,
215-
DataType.f16,
216-
DataType.f32,
217-
OpcodeClass.TensorOp,
218-
MathOperation.multiply_add,
219-
)
220-
],
221-
"float16": [
222-
MathInstruction(
223-
[16, 8, 16],
224-
DataType.f16,
225-
DataType.f16,
226-
DataType.f16,
227-
OpcodeClass.TensorOp,
228-
MathOperation.multiply_add,
229-
)
230-
],
231-
}[out_dtype]
232-
233-
alignment_constraints = [8, 4, 2]
234-
235-
def get_tile_descriptions(math_inst):
236-
min_cc = 80
237-
max_cc = 1024
238-
max_cc_smem_limited = 80
239-
return [
240-
TileDescription([256, 128, 32], 3, [4, 2, 1], math_inst, min_cc, max_cc),
241-
TileDescription([128, 256, 32], 3, [2, 4, 1], math_inst, min_cc, max_cc),
242-
TileDescription([256, 64, 32], 4, [4, 1, 1], math_inst, min_cc, max_cc),
243-
TileDescription([64, 256, 32], 4, [1, 4, 1], math_inst, min_cc, max_cc),
244-
TileDescription([128, 128, 32], 3, [2, 2, 1], math_inst, min_cc, max_cc),
245-
TileDescription([128, 128, 32], 4, [2, 2, 1], math_inst, min_cc, max_cc),
246-
TileDescription([128, 128, 32], 5, [2, 2, 1], math_inst, min_cc, max_cc),
247-
TileDescription([128, 64, 32], 6, [2, 2, 1], math_inst, min_cc, max_cc),
248-
TileDescription([64, 128, 32], 6, [2, 2, 1], math_inst, min_cc, max_cc),
249-
TileDescription([64, 64, 32], 10, [2, 2, 1], math_inst, min_cc, max_cc),
250-
TileDescription([256, 128, 64], 3, [4, 2, 1], math_inst, min_cc, max_cc_smem_limited),
251-
TileDescription([128, 256, 64], 3, [2, 4, 1], math_inst, min_cc, max_cc_smem_limited),
252-
TileDescription([256, 64, 64], 4, [4, 1, 1], math_inst, min_cc, max_cc_smem_limited),
253-
TileDescription([64, 256, 64], 4, [1, 4, 1], math_inst, min_cc, max_cc_smem_limited),
254-
TileDescription([128, 128, 64], 4, [2, 2, 1], math_inst, min_cc, max_cc),
255-
TileDescription([128, 64, 64], 3, [2, 2, 1], math_inst, min_cc, max_cc),
256-
TileDescription([64, 128, 64], 3, [2, 2, 1], math_inst, min_cc, max_cc),
257-
TileDescription([64, 64, 64], 5, [2, 2, 1], math_inst, min_cc, max_cc),
258-
]
259-
260-
return generate_tensor_op_common(
261-
math_instructions, alignment_constraints, get_tile_descriptions, batched
262-
)
154+
return op_creator
263155

264156

265157
GENERATOR_FUNC_TABLE = {
266158
75: generate_sm75_tensor_op_1688,
267159
80: generate_sm80_tensor_op_16816,
268160
}
269161

162+
270163
# TODO(masahi): A sensible way to pick reasonable default kernels
271164
DEFAULT_KERNELS = {
272165
75: {
@@ -280,66 +173,6 @@ def get_tile_descriptions(math_inst):
280173
}
281174

282175

283-
class ProfilerEngine:
284-
"""Compile and run a given profiler executable."""
285-
286-
def __init__(self, cuda_arch, cutlass_path, binary_prefix):
287-
self.cuda_arch = cuda_arch
288-
self.binary_prefix = binary_prefix
289-
self.cutlass = cutlass_path
290-
self.cflags = "-I{cutlass}/include -I{cutlass}/tools/util/include -O3 -std=c++11".format(
291-
cutlass=cutlass_path
292-
)
293-
self.cflags += " -DCUTLASS_ENABLE_TENSOR_CORE_MMA=1"
294-
self.cflags += " -gencode=arch=compute_{arch},code=[sm_{arch},compute_{arch}]".format(
295-
arch=cuda_arch
296-
)
297-
self.cflags += " -Xcompiler=-Wconversion -Xcompiler=-fno-strict-aliasing"
298-
self.cmd = "nvcc {cflags} {src} -o {output}"
299-
300-
def _compile(self, op):
301-
os.makedirs(self.binary_prefix, exist_ok=True)
302-
opath = os.path.join(self.binary_prefix, op["name"])
303-
if os.path.exists(opath):
304-
return
305-
fi = tempfile.NamedTemporaryFile("w", delete=False, suffix=".cu")
306-
fi.write(op["src"])
307-
fi.close()
308-
cmd = self.cmd.format(cflags=self.cflags, src=fi.name, output=opath)
309-
os.system(cmd)
310-
os.unlink(fi.name)
311-
312-
def compile_all(self, ops, use_multiprocessing=False):
313-
"""Compile all profiler executables."""
314-
if use_multiprocessing:
315-
pool = multiprocessing.Pool(multiprocessing.cpu_count())
316-
pool.map(self._compile, ops)
317-
else:
318-
for op in ops:
319-
self._compile(op)
320-
321-
def evaluate(self, op, args):
322-
"""Run the profiler executable corresponding to op_name with args."""
323-
op_name = op["name"]
324-
opath = os.path.join(self.binary_prefix, op_name)
325-
if not os.path.exists(opath):
326-
self._compile(op)
327-
cmd = [opath]
328-
if args is not None:
329-
cmd.append(str(args[0]))
330-
cmd.append(str(args[1]))
331-
cmd.append(str(args[2]))
332-
if len(args) > 3:
333-
cmd.append(str(args[3]))
334-
try:
335-
sp = subprocess.run(cmd, capture_output=True, check=True)
336-
rt = float(sp.stdout)
337-
logger.info("%s, %f", op_name, rt)
338-
except subprocess.CalledProcessError:
339-
rt = -1
340-
return rt
341-
342-
343176
class CutlassGemmProfiler:
344177
"""Profile all candidate kernels and select the best one."""
345178

@@ -362,7 +195,9 @@ def get_default(self, out_dtype, batched=False):
362195
"""Return the default kernel for the requested architecture.
363196
For now, the default kernel was picked arbitrary.
364197
"""
365-
ops = GENERATOR_FUNC_TABLE[self.sm](out_dtype, batched)
198+
ops = GENERATOR_FUNC_TABLE[self.sm](
199+
out_dtype, op_creator=create_gemm_operator(batched)
200+
)
366201
default_kernel_name = DEFAULT_KERNELS[self.sm][out_dtype]
367202
filtered = list(filter(lambda op: op["name"] == default_kernel_name, ops))
368203
assert len(filtered) == 1
@@ -378,7 +213,9 @@ def profile(
378213
if (M, N, K) in self.cache:
379214
return self.cache[(M, N, K)]
380215

381-
ops = GENERATOR_FUNC_TABLE[self.sm](out_dtype, batched)
216+
ops = GENERATOR_FUNC_TABLE[self.sm](
217+
out_dtype, op_creator=create_gemm_operator(batched)
218+
)
382219
ops = list(filter(lambda op: self.check_align(op["name"], M), ops))
383220

384221
for op in ops:

0 commit comments

Comments
 (0)