Skip to content

Commit 20c7e72

Browse files
authored
[ingress][mlir-gen] Add Python mlir-gen and supporting venv install and basic ir builder scripts (#2)
A Python rewrite of the original C++ mlir-gen from tpp-mlir.
1 parent 821f270 commit 20c7e72

File tree

9 files changed

+764
-0
lines changed

9 files changed

+764
-0
lines changed
Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
#!/usr/bin/env bash

# Regenerate the cached MLP examples: one MLIR file per `--output` kind,
# all produced with the same kernel, layer, batch, bias and relu settings.

# Run with the interpreter and packages from the project's virtual environment.
. mlir-gen-venv/bin/activate

# Value handed to `--layers` for every invocation below.
LAYERS=1024,2048,4096,512

mkdir -p cache

# The invocations differ only in the output kind and the matching file name.
for kind in named einsum generic; do
    python -m mlir_gen --output "$kind" --kernel args --layers "$LAYERS" --batch 256 --bias --relu \
        > "cache/linalg-${kind}-3layer-mlp.mlir"
done
Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
#!/usr/bin/env bash

# Bootstrap the `mlir-gen-venv` virtual environment used by mlir-gen:
# install/upgrade uv, create the venv, then install the Python dependencies.

# Step 1: uv itself is installed through pip of the current interpreter.
printf '%s\n' "First ensure uv is installed"
python -m pip install uv --upgrade

# Step 2: create the environment and switch the current shell over to it.
printf '%s\n' "Preparing the virtual environment"
python -m uv venv mlir-gen-venv --python 3.13
. mlir-gen-venv/bin/activate

# Step 3: `-f` adds the extra wheel index that is searched for packages.
printf '%s\n' "Installing mlir-python-bindings and numpy"
uv pip install numpy mlir-python-bindings -f https://makslevental.github.io/wheels

ingress/mlir-gen/mlir_gen/__init__.py

Whitespace-only changes.
Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
import sys

from . import main

# Entry point for `python -m mlir_gen`: hand every command-line argument
# after the program name to the package's `main.main`.
cli_args = sys.argv[1:]
main.main(cli_args)
Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,79 @@
1+
from typing import Union
2+
3+
from mlir import ir
4+
from mlir.dialects import arith, linalg, tensor
5+
6+
from . import named, generic
7+
from .utils import get_outputs, get_weights, get_bias, affine_map
8+
9+
10+
def times_weights(
    inputs: ir.Value,
    weights_or_weights_type: Union[ir.Value, ir.RankedTensorType],
    outputs_or_outputs_type: Union[ir.Value, ir.RankedTensorType],
) -> ir.Value:
    """Emit a `linalg.contract` multiplying `inputs` by the weights.

    The weights and outputs arguments are resolved through `get_weights` and
    `get_outputs`. The indexing maps are shared with the `generic` module and
    are selected by the rank of the weights.

    When the weights have rank 5 (tiled weights with vnni blocking), the last
    dimension of `inputs` is first split into `(size // vnni_block,
    vnni_block)` with `tensor.expand_shape`, where `vnni_block` is the size
    of the weights' last dimension.

    Returns:
        The result of the `linalg.contract` call.
    """
    weights: ir.Value = get_weights(weights_or_weights_type)
    outputs: ir.Value = get_outputs(outputs_or_outputs_type)

    weights_rank = weights.type.rank
    # Only the maps are needed here; `linalg.contract` takes no iterator types.
    indexing_maps = generic.affine_maps_and_iter_types(weights_rank)[0]

    if weights_rank == 5:  # tiled weights with vnni blocking
        in_type = inputs.type
        in_shape = in_type.shape
        vnni_block = weights.type.get_dim_size(4)
        assert in_shape[-1] % vnni_block == 0

        # Same leading dims; the innermost dim becomes two dims.
        split_shape = in_shape[:-1] + [in_shape[-1] // vnni_block, vnni_block]
        inputs = tensor.expand_shape(
            ir.RankedTensorType.get(split_shape, in_type.element_type),
            inputs,
            # Hard-coded for a rank-4 input: only source dim 3 is split.
            reassociation=[[0], [1], [2], [3, 4]],
            output_shape=[],
            static_output_shape=split_shape,
        )

    return linalg.contract(
        inputs, weights, outs=[outputs], indexing_maps=indexing_maps
    )
41+
42+
43+
def add_bias(inputs: ir.Value, bias_or_bias_type: Union[ir.Value, ir.Type]) -> ir.Value:
    """Emit a `linalg.elementwise` add of the bias onto `inputs`.

    The bias argument is resolved through `get_bias`. Only rank-2 and rank-4
    `inputs` are supported; any other rank raises `KeyError`. The bias is
    indexed by `N` (rank 2) or by `(N, nb)` (rank 4), while `inputs` and the
    freshly created output tensor use the identity map.

    Returns:
        The result of the `linalg.elementwise` call.
    """
    bias: ir.Value = get_bias(bias_or_bias_type)

    in_type = inputs.type
    rank = in_type.rank
    M, N, mb, nb = (ir.AffineDimExpr.get(i) for i in range(4))

    if rank == 2:
        bias_dims, full_dims = [N], [M, N]
    elif rank == 4:
        bias_dims, full_dims = [N, nb], [M, N, mb, nb]
    else:
        raise KeyError(rank)

    # Operand order is (bias, inputs, output); the last two share one layout.
    indexing_maps = [
        affine_map(rank, bias_dims),
        affine_map(rank, full_dims),
        affine_map(rank, full_dims),
    ]

    result_buf = tensor.EmptyOp(in_type.shape, in_type.element_type)
    return linalg.elementwise(
        bias,
        inputs,
        outs=(result_buf,),
        kind=linalg.ElementwiseKind.add,
        indexing_maps=indexing_maps,
    )
64+
65+
66+
def relu(inputs: ir.Value) -> ir.Value:
    """Emit a ReLU as an elementwise maximum of `inputs` and a zero tensor.

    A tensor of the same shape and element type as `inputs` is filled with
    zero and used as the second operand of a `linalg.elementwise` max.

    Returns:
        The result of the `linalg.elementwise` call.
    """
    in_type = inputs.type
    elem_type = in_type.element_type

    zero_cst = arith.constant(elem_type, 0.0)
    empty_buf = tensor.EmptyOp(in_type.shape, elem_type)
    zeros = linalg.fill(zero_cst, outs=empty_buf)

    # NB: on float args, max_signed gives arith.maximumf in body
    max_kind = linalg.ElementwiseKind.max_signed
    return linalg.elementwise(inputs, zeros, outs=(empty_buf,), kind=max_kind)
77+
78+
79+
# Re-export the `named` module's softmax under this module's namespace.
softmax = named.softmax
Lines changed: 177 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,177 @@
1+
from typing import Union
2+
3+
from mlir import ir
4+
from mlir.dialects import linalg, arith, tensor, math
5+
6+
from .utils import (
7+
affine_map,
8+
get_bias,
9+
get_outputs,
10+
get_weights,
11+
parallel,
12+
reduction,
13+
)
14+
15+
16+
def affine_maps_and_iter_types(rank: int):
    """Return the indexing maps and iterator types for inputs-times-weights.

    Args:
        rank: Rank of the weights tensor. Supported values are 2 (plain 2D
            weights), 4 (tiled weights, no vnni blocking) and 5 (tiled
            weights with vnni blocking).

    Returns:
        A pair `(affine_maps, iterator_types)`. `affine_maps` holds three
        maps, in operand order (inputs, weights, outputs). `iterator_types`
        holds one entry per iteration dimension (3, 6 or 7 entries).

    Raises:
        AssertionError: If `rank` is not 2, 4 or 5.
    """
    # Validate up front with an explicit raise: a bare `assert False` is
    # stripped under `python -O`, which would let an unsupported rank fall
    # through to an UnboundLocalError at the return statement.
    if rank not in (2, 4, 5):
        raise AssertionError(
            f"unsupported weights rank: {rank} (expected 2, 4 or 5)"
        )

    M, N, K = [ir.AffineDimExpr.get(i) for i in range(3)]

    if rank == 2:  # plain 2D weights
        affine_maps = [
            affine_map(3, [M, K]),
            affine_map(3, [K, N]),
            affine_map(3, [M, N]),
        ]
        iterator_types = [parallel, parallel, reduction]
    elif rank == 4:  # tiled weights, no vnni blocking
        mb, nb, kb = [ir.AffineDimExpr.get(i) for i in range(3, 6)]
        affine_maps = [
            affine_map(6, [M, K, mb, kb]),
            affine_map(6, [N, K, kb, nb]),  # transposed K and N on B
            affine_map(6, [M, N, mb, nb]),
        ]
        # Same (parallel, parallel, reduction) pattern for outer and tile dims.
        iterator_types = [parallel, parallel, reduction] * 2
    else:  # rank == 5: tiled weights with vnni blocking
        # FIXME: due to replicating C++ code, vnni dim is in middle instead of at end.
        k_vnni, mb, nb, kb = [ir.AffineDimExpr.get(i) for i in range(3, 7)]

        affine_maps = [
            affine_map(7, [M, K, mb, kb, k_vnni]),
            # TODO(RM): check if kb and (k_)vnni _not_ being contiguous makes sense.
            affine_map(7, [N, K, kb, nb, k_vnni]),  # transposed K and N on B
            affine_map(7, [M, N, mb, nb]),
        ]
        # Listed in dimension order d0..d6, matching the expressions above.
        iterator_types = [
            parallel,  # M
            parallel,  # N
            reduction,  # K
            reduction,  # vnni
            parallel,  # mb
            parallel,  # nb
            reduction,  # kb
        ]

    return affine_maps, iterator_types
57+
58+
59+
def times_weights(
    inputs: ir.Value,
    weights_or_weights_type: Union[ir.Value, ir.RankedTensorType],
    outputs_or_outputs_type: Union[ir.Value, ir.RankedTensorType],
) -> ir.Value:
    """Emit a `linalg.generic` multiply-accumulate of `inputs` and the weights.

    The weights and outputs arguments are resolved through `get_weights` and
    `get_outputs`. Indexing maps and iterator types are chosen by the rank of
    the weights via `affine_maps_and_iter_types`.

    When the weights have rank 5 (tiled weights with vnni blocking), the last
    dimension of `inputs` is first split into `(size // vnni_block,
    vnni_block)` with `tensor.expand_shape`, where `vnni_block` is the size
    of the weights' last dimension.

    Returns:
        Whatever the `linalg.generic` decorator binds to the decorated body.
    """
    weights: ir.Value = get_weights(weights_or_weights_type)
    outputs: ir.Value = get_outputs(outputs_or_outputs_type)
    weights_rank = weights.type.rank

    if weights_rank == 5:  # tiled weights with vnni blocking
        in_type = inputs.type
        in_shape = in_type.shape
        vnni_block = weights.type.get_dim_size(4)
        assert in_shape[-1] % vnni_block == 0

        # Same leading dims; the innermost dim becomes two dims.
        split_shape = in_shape[:-1] + [in_shape[-1] // vnni_block, vnni_block]
        inputs = tensor.expand_shape(
            ir.RankedTensorType.get(split_shape, in_type.element_type),
            inputs,
            # Hard-coded for a rank-4 input: only source dim 3 is split.
            reassociation=[[0], [1], [2], [3, 4]],
            output_shape=[],
            static_output_shape=split_shape,
        )

    indexing_maps, iter_kinds = affine_maps_and_iter_types(weights_rank)

    @linalg.generic([inputs, weights], [outputs], indexing_maps, iter_kinds)
    def multiply_accumulate(lhs, rhs, acc):
        # Body of the generic op: acc + lhs * rhs.
        product = arith.MulFOp(lhs, rhs)
        return arith.AddFOp(product.result, acc)

    return multiply_accumulate
92+
93+
94+
def add_bias(inputs: ir.Value, bias_or_bias_type: Union[ir.Value, ir.Type] = None):
    """Emit a `linalg.generic` that adds the bias onto `inputs`.

    The bias argument is resolved through `get_bias`. `inputs` is passed as
    the output operand of the generic op, so the body adds the bias element
    to the current output element. Only rank-2 and rank-4 `inputs` are
    supported; any other rank raises `KeyError`.

    Returns:
        Whatever the `linalg.generic` decorator binds to the decorated body.
    """
    bias: ir.Value = get_bias(bias_or_bias_type)

    rank = inputs.type.rank
    M, N, mb, nb = (ir.AffineDimExpr.get(i) for i in range(4))

    if rank == 2:
        bias_dims, out_dims = [N], [M, N]
    elif rank == 4:
        bias_dims, out_dims = [N, nb], [M, N, mb, nb]
    else:
        raise KeyError(rank)

    indexing_maps = [affine_map(rank, bias_dims), affine_map(rank, out_dims)]
    # Every iteration dimension is parallel for both supported ranks.
    iter_kinds = [parallel] * rank

    @linalg.generic([bias], [inputs], indexing_maps, iter_kinds)
    def with_bias(bias_elem, out_elem):
        return arith.AddFOp(bias_elem, out_elem)

    return with_bias
108+
109+
110+
def relu(inputs: ir.Value):
    """Emit a `linalg.generic` computing the maximum of each element and zero.

    `inputs` is passed as the sole (output) operand of the generic op, with an
    identity indexing map and all-parallel iterators. Only rank-2 and rank-4
    `inputs` are supported; any other rank raises `KeyError`.

    Returns:
        Whatever the `linalg.generic` decorator binds to the decorated body.
    """
    in_type = inputs.type
    zero_cst = arith.constant(in_type.element_type, 0.0)

    rank = in_type.rank
    if rank not in (2, 4):
        raise KeyError(rank)

    # Identity map over the first `rank` dims: (M, N) or (M, N, mb, nb).
    dim_exprs = [ir.AffineDimExpr.get(i) for i in range(rank)]
    indexing_maps = [affine_map(rank, dim_exprs)]
    iter_kinds = [parallel] * rank

    @linalg.generic([], [inputs], indexing_maps, iter_kinds)
    def clamped(elem):
        return arith.MaximumFOp(elem, zero_cst)

    return clamped
124+
125+
126+
def softmax(
    inputs: ir.Value, softmax_buf_or_softmax_buf_type: Union[ir.Value, ir.Type]
) -> ir.Value:
    """Emit softmax over `inputs` as a chain of four `linalg.generic` ops.

    The steps are:
      1. elementwise `exp` of `inputs` into a fresh tensor;
      2. sum of the exponentials into a zero-filled `(shape[0], 1)` tensor;
      3. broadcast of that sum back to the shape of `inputs`;
      4. elementwise division of step 1 by step 3, written into the buffer
         resolved from `softmax_buf_or_softmax_buf_type` via `get_outputs`.

    NOTE(review): the reduction map, the reduction iterator types and the
    `(shape[0], 1)` buffer are written for rank-2 `inputs`; other ranks look
    unsupported -- confirm against callers.

    Returns:
        Whatever the `linalg.generic` decorator binds to the final
        (division) body.
    """
    softmax_buf = get_outputs(softmax_buf_or_softmax_buf_type)

    shape, elem_type = inputs.type.shape, inputs.type.element_type
    rank = inputs.type.rank
    exp_out_uninit = tensor.EmptyOp(shape, elem_type)

    dims = [ir.AffineDimExpr.get(i) for i in range(rank)]
    par_affine_map = affine_map(rank, dims)
    # One indexing map per operand of the exp op (1 input + 1 output). The
    # count is tied to the operand count, not to the rank of `inputs`.
    par_affine_maps = [par_affine_map] * 2
    par_iter_types = [parallel] * rank
    # Maps (d0, ...) -> (d0, 0): every element of a row hits the same slot.
    red_affine_map = affine_map(rank, [dims[0], ir.AffineConstantExpr.get(0)])
    red_iter_types = [parallel, reduction] * (rank // 2)

    @linalg.generic([inputs], [exp_out_uninit.result], par_affine_maps, par_iter_types)
    def exped(elem, _output):
        return math.exp(elem)

    zero = arith.constant(elem_type, 0.0)
    reduction_out_uninit = tensor.EmptyOp((shape[0], 1), elem_type)
    # The accumulator must start at zero before the sum-reduction.
    reduction_out = linalg.fill(zero, outs=reduction_out_uninit)

    @linalg.generic(
        [exped], [reduction_out], [par_affine_map, red_affine_map], red_iter_types
    )
    def summed_exped(exped_input, redex):
        return arith.AddFOp(exped_input, redex)

    bcast_out_uninit = tensor.EmptyOp(shape, elem_type)

    @linalg.generic(
        [summed_exped],
        [bcast_out_uninit.result],
        [red_affine_map, par_affine_map],
        par_iter_types,
    )
    def bcasted_summed_exped(row_sum, _output):
        # Pure broadcast: yield the row sum for every element of the row.
        return row_sum

    @linalg.generic(
        [exped, bcasted_summed_exped],
        [softmax_buf],
        [par_affine_map] * 3,
        par_iter_types,
    )
    def dived_bcasted_summed_exped(exped_input, normalizing_term, _output):
        return arith.DivFOp(exped_input, normalizing_term)

    return dived_bcasted_summed_exped

0 commit comments

Comments
 (0)