
Commit ddf9d82

add triton kernels for float8 quantization with jagged rowwise scales
1 parent f788897 commit ddf9d82

File tree

6 files changed: +471 -31 lines changed


torchao/prototype/scaled_grouped_mm/kernels/__init__.py

Whitespace-only changes.
Lines changed: 348 additions & 0 deletions
@@ -0,0 +1,348 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD 3-Clause license found in the
# LICENSE file in the root directory of this source tree.

"""
Triton kernels for scaling high precision tensors to float8.
"""
import itertools
from typing import Tuple

import torch
import triton
import triton.language as tl

from torchao.prototype.scaled_grouped_mm.utils import _is_column_major

EPS = 1e-12

FP8_DTYPE_MAP = {
    torch.int8: tl.int8,
    torch.int16: tl.int16,
    torch.int32: tl.int32,
    torch.int64: tl.int64,
    torch.float8_e4m3fn: tl.float8e4nv,
    torch.float8_e5m2: tl.float8e5,
    torch.float16: tl.float16,
    torch.bfloat16: tl.bfloat16,
    torch.float32: tl.float32,
    torch.float64: tl.float64,
}

block_sizes = [128, 256]
kernel_configs_2D = [
    triton.Config({"BLOCK_SIZE_ROWS": block_size_rows, "BLOCK_SIZE_COLS": block_size_cols})
    for block_size_rows in block_sizes
    for block_size_cols in block_sizes
]

def triton_fp8_row_major_jagged_rowwise_scales(
    hp_tensor: torch.Tensor,
    offsets: torch.Tensor,
    output_dtype: torch.dtype = torch.float8_e4m3fn,
    round_scales_to_power_of_2: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Converts a high precision tensor to a float8 tensor in row-major memory layout,
    using 'jagged' rowwise scales (i.e., separate scales for each group/subtensor as
    determined by the offsets).

    Args:
    - hp_tensor: 2D high precision tensor to be converted
    - offsets: end index for each group/subtensor along dim 1
    - output_dtype: desired fp8 dtype
    - round_scales_to_power_of_2: if True, round each scale down to the nearest power of 2
    Returns:
    - float8 tensor
    - jagged rowwise scales (i.e., rowwise scales for each group)
    """
    assert hp_tensor.ndim == 2, "input tensor must be 2D"
    assert hp_tensor.is_contiguous(), "input tensor must be contiguous"

    num_elements = hp_tensor.numel()
    tl_input_dtype = FP8_DTYPE_MAP[hp_tensor.dtype]
    tl_output_dtype = FP8_DTYPE_MAP[output_dtype]

    fp8_dtype_min = torch.finfo(output_dtype).min
    fp8_dtype_max = torch.finfo(output_dtype).max

    m, k = hp_tensor.shape
    n_groups = offsets.numel()

    # perform fp8 conversion
    output_buffer = torch.empty_like(
        hp_tensor, dtype=output_dtype, device=hp_tensor.device
    )
    scales_buffer = torch.empty(
        (m * n_groups), dtype=torch.float32, device=hp_tensor.device
    )

    # parallelize across rows and groups (offsets)
    grid = lambda meta: (
        triton.cdiv(m, meta["BLOCK_SIZE_ROWS"]),
        offsets.numel(),
    )
    _triton_fp8_row_major_jagged_rowwise_scales[grid](
        hp_tensor,
        offsets,
        output_buffer,
        scales_buffer,
        m,
        k,
        hp_tensor.stride(0),
        hp_tensor.stride(1),
        output_buffer.stride(0),
        output_buffer.stride(1),
        num_elements,
        fp8_dtype_min,
        fp8_dtype_max,
        tl_input_dtype,
        tl_output_dtype,
        round_scales_to_power_of_2,
        EPS=EPS,
    )
    return output_buffer, scales_buffer

@triton.autotune(configs=kernel_configs_2D, key=["num_elements"])
@triton.jit
def _triton_fp8_row_major_jagged_rowwise_scales(
    input_ptr,
    offsets_ptr,
    out_ptr,
    scales_ptr,
    M: int,
    K: int,
    stride_input_row: int,
    stride_input_col: int,
    stride_output_row: int,
    stride_output_col: int,
    num_elements: int,
    fp8_dtype_min: tl.constexpr,
    fp8_dtype_max: tl.constexpr,
    input_dtype: tl.constexpr,
    output_dtype: tl.constexpr,
    round_scales_to_power_of_2: tl.constexpr,
    BLOCK_SIZE_ROWS: tl.constexpr,
    BLOCK_SIZE_COLS: tl.constexpr,
    EPS: tl.constexpr,
):
    # parallelize across rows and groups (offsets)
    block_row_id = tl.program_id(axis=0)
    offset_idx = tl.program_id(axis=1)

    # determine start and end column idx for this group
    block_row_offs = block_row_id * BLOCK_SIZE_ROWS + tl.arange(0, BLOCK_SIZE_ROWS)
    group_col_start_idx = tl.load(
        offsets_ptr + offset_idx - 1, mask=offset_idx > 0, other=0
    )
    group_col_end_idx = tl.load(offsets_ptr + offset_idx)

    # compute rowwise amaxes for this group
    amax_buffer = tl.zeros((BLOCK_SIZE_ROWS,), dtype=tl.float64)
    for col_start_idx in range(group_col_start_idx, group_col_end_idx, BLOCK_SIZE_COLS):
        block_col_offs = col_start_idx + tl.arange(0, BLOCK_SIZE_COLS)
        block_offs = (
            block_row_offs[:, None] * stride_input_row
            + block_col_offs[None, :] * stride_input_col
        )
        block_mask = (block_row_offs[:, None] < M) & (
            block_col_offs[None, :] < group_col_end_idx
        )
        data = tl.load(input_ptr + block_offs, mask=block_mask, other=0.0).to(
            input_dtype
        )
        amax_buffer = tl.maximum(amax_buffer, tl.max(tl.abs(data), axis=1))

    # compute rowwise scales for this group, optionally rounding each scale
    # down to the nearest power of 2.
    scales = (fp8_dtype_max / tl.clamp(amax_buffer, min=EPS, max=float("inf"))).to(
        tl.float32
    )
    if round_scales_to_power_of_2:
        scales = tl.exp2(tl.floor(tl.log2(scales)))

    # store rowwise scales for each group in contiguous memory:
    # [group0_row0, group0_row1, ..., group2_row0, group2_row1]
    scales_offs = block_row_offs + (M * offset_idx)
    scales_mask = tl.arange(0, BLOCK_SIZE_ROWS) < M
    tl.store(scales_ptr + scales_offs, scales, mask=scales_mask)

    # perform float8 conversion for this group
    for col_start_idx in range(group_col_start_idx, group_col_end_idx, BLOCK_SIZE_COLS):
        block_col_offs = col_start_idx + tl.arange(0, BLOCK_SIZE_COLS)
        block_offs = (
            block_row_offs[:, None] * stride_input_row
            + block_col_offs[None, :] * stride_input_col
        )
        block_mask = (block_row_offs[:, None] < M) & (block_col_offs[None, :] < K)
        data = tl.load(input_ptr + block_offs, mask=block_mask, other=0.0).to(
            input_dtype
        )
        scaled_data = data * scales[:, None]
        fp8_data = tl.clamp(scaled_data, min=fp8_dtype_min, max=fp8_dtype_max).to(
            output_dtype
        )
        out_offs = (
            block_row_offs[:, None] * stride_output_row
            + block_col_offs[None, :] * stride_output_col
        )
        out_mask = (block_row_offs[:, None] < M) & (block_col_offs[None, :] < K)
        tl.store(out_ptr + out_offs, fp8_data, mask=out_mask)

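A minimal usage sketch for the row-major wrapper above (not part of this commit): the shapes, offsets, and device below are made up, and it assumes a CUDA device with Triton available and triton_fp8_row_major_jagged_rowwise_scales already in scope.

    import torch

    # 16x96 bfloat16 activations, split into 3 jagged groups along dim 1
    x = torch.randn(16, 96, dtype=torch.bfloat16, device="cuda")
    # offsets are the end indices of each group along dim 1: cols [0,32), [32,64), [64,96)
    offsets = torch.tensor([32, 64, 96], dtype=torch.int32, device="cuda")

    fp8_x, scales = triton_fp8_row_major_jagged_rowwise_scales(
        x, offsets, output_dtype=torch.float8_e4m3fn, round_scales_to_power_of_2=True
    )
    # fp8_x has the same shape as x; scales is flat with one scale per (row, group),
    # laid out as [group0_row0 .. group0_row15, group1_row0, ...] (16 rows * 3 groups values)
    assert fp8_x.shape == x.shape and scales.numel() == 16 * 3
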
def triton_fp8_col_major_jagged_colwise_scales(
    hp_tensor: torch.Tensor,
    offsets: torch.Tensor,
    output_dtype: torch.dtype = torch.float8_e4m3fn,
    round_scales_to_power_of_2: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Converts a high precision tensor to a float8 tensor in column-major memory layout,
    using 'jagged' column-wise scales (i.e., separate scales for each group/subtensor as
    determined by the offsets).

    Args:
    - hp_tensor: 2D high precision tensor to be converted
    - offsets: end index for each group/subtensor along dim 0
    - output_dtype: desired fp8 dtype
    - round_scales_to_power_of_2: if True, round each scale down to the nearest power of 2
    Returns:
    - float8 tensor
    - jagged column-wise scales (i.e., column-wise scales for each group)
    """
    assert hp_tensor.ndim == 2, "input tensor must be 2D"
    assert _is_column_major(hp_tensor), "input tensor must be column-major"

    num_elements = hp_tensor.numel()
    tl_input_dtype = FP8_DTYPE_MAP[hp_tensor.dtype]
    tl_output_dtype = FP8_DTYPE_MAP[output_dtype]

    fp8_dtype_min = torch.finfo(output_dtype).min
    fp8_dtype_max = torch.finfo(output_dtype).max

    k, n = hp_tensor.shape
    n_groups = offsets.numel()

    # perform fp8 conversion
    output_buffer = torch.empty_like(
        hp_tensor, dtype=output_dtype, device=hp_tensor.device
    )
    scales_buffer = torch.empty(
        (n * n_groups), dtype=torch.float32, device=hp_tensor.device
    )

    # parallelize across columns and groups (offsets)
    grid = lambda meta: (
        triton.cdiv(n, meta["BLOCK_SIZE_COLS"]),
        offsets.numel(),
    )
    _triton_fp8_col_major_jagged_colwise_scales[grid](
        hp_tensor,
        offsets,
        output_buffer,
        scales_buffer,
        k,
        n,
        hp_tensor.stride(0),
        hp_tensor.stride(1),
        output_buffer.stride(0),
        output_buffer.stride(1),
        num_elements,
        fp8_dtype_min,
        fp8_dtype_max,
        tl_input_dtype,
        tl_output_dtype,
        round_scales_to_power_of_2,
        EPS=EPS,
    )
    return output_buffer, scales_buffer

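The column-wise variant above requires a column-major input (per the _is_column_major assert). A short sketch of one way to obtain that layout in PyTorch (illustrative only, not from this commit; shapes and names are made up):

    # A (K, N) column-major tensor can be built as the transpose of a contiguous (N, K) tensor.
    w = torch.randn(64, 256, dtype=torch.bfloat16, device="cuda")  # (N=64, K=256), row-major
    w_col_major = w.t()                                            # (K=256, N=64), column-major view
    # offsets mark group end indices along dim 0 (K): rows [0,128) and [128,256)
    offsets = torch.tensor([128, 256], dtype=torch.int32, device="cuda")
    fp8_w, scales = triton_fp8_col_major_jagged_colwise_scales(
        w_col_major, offsets, output_dtype=torch.float8_e4m3fn
    )
    # scales holds one value per (column, group): 64 columns * 2 groups values
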
@triton.autotune(configs=kernel_configs_2D, key=["num_elements"])
@triton.jit
def _triton_fp8_col_major_jagged_colwise_scales(
    input_ptr,
    offsets_ptr,
    out_ptr,
    scales_ptr,
    K: int,
    N: int,
    stride_input_row: int,
    stride_input_col: int,
    stride_output_row: int,
    stride_output_col: int,
    num_elements: int,
    fp8_dtype_min: tl.constexpr,
    fp8_dtype_max: tl.constexpr,
    input_dtype: tl.constexpr,
    output_dtype: tl.constexpr,
    round_scales_to_power_of_2: tl.constexpr,
    BLOCK_SIZE_ROWS: tl.constexpr,
    BLOCK_SIZE_COLS: tl.constexpr,
    EPS: tl.constexpr,
):
    # parallelize across columns and groups (offsets)
    block_col_id = tl.program_id(axis=0)
    offset_idx = tl.program_id(axis=1)

    # determine start and end row idx for this group
    block_col_offs = block_col_id * BLOCK_SIZE_COLS + tl.arange(0, BLOCK_SIZE_COLS)
    group_row_start_idx = tl.load(
        offsets_ptr + offset_idx - 1, mask=offset_idx > 0, other=0
    )
    group_row_end_idx = tl.load(offsets_ptr + offset_idx)

    # compute colwise amaxes for this group
    amax_buffer = tl.zeros((BLOCK_SIZE_COLS,), dtype=tl.float64)
    for row_start_idx in range(group_row_start_idx, group_row_end_idx, BLOCK_SIZE_ROWS):
        block_row_offs = row_start_idx + tl.arange(0, BLOCK_SIZE_ROWS)
        block_offs = (
            block_row_offs[:, None] * stride_input_row
            + block_col_offs[None, :] * stride_input_col
        )
        block_mask = (block_row_offs[:, None] < group_row_end_idx) & (
            block_col_offs[None, :] < N
        )
        data = tl.load(input_ptr + block_offs, mask=block_mask, other=0.0).to(
            input_dtype
        )
        amax_buffer = tl.maximum(amax_buffer, tl.max(tl.abs(data), axis=0))

    # compute colwise scales for this group, optionally rounding each scale
    # down to the nearest power of 2.
    scales = (fp8_dtype_max / tl.clamp(amax_buffer, min=EPS, max=float("inf"))).to(
        tl.float32
    )
    if round_scales_to_power_of_2:
        scales = tl.exp2(tl.floor(tl.log2(scales)))

    # store colwise scales for each group in contiguous memory:
    # [group0_col0, group0_col1, ..., group2_col0, group2_col1]
    # note: input tensor is in col-major memory layout.
    scales_offs = block_col_offs + (N * offset_idx)
    scales_mask = tl.arange(0, BLOCK_SIZE_COLS) < N
    tl.store(scales_ptr + scales_offs, scales, mask=scales_mask)

    # perform float8 conversion for this group
    for row_start_idx in range(group_row_start_idx, group_row_end_idx, BLOCK_SIZE_ROWS):
        block_row_offs = row_start_idx + tl.arange(0, BLOCK_SIZE_ROWS)
        block_offs = (
            block_row_offs[:, None] * stride_input_row
            + block_col_offs[None, :] * stride_input_col
        )
        block_mask = (block_row_offs[:, None] < group_row_end_idx) & (
            block_col_offs[None, :] < N
        )
        data = tl.load(input_ptr + block_offs, mask=block_mask, other=0.0).to(
            input_dtype
        )
        scaled_data = data * scales[None, :]
        fp8_data = tl.clamp(scaled_data, min=fp8_dtype_min, max=fp8_dtype_max).to(
            output_dtype
        )
        out_offs = (
            block_row_offs[:, None] * stride_output_row
            + block_col_offs[None, :] * stride_output_col
        )
        out_mask = (block_row_offs[:, None] < K) & (block_col_offs[None, :] < N)
        tl.store(out_ptr + out_offs, fp8_data, mask=out_mask)
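
For intuition, a pure-PyTorch reference of what the row-major jagged rowwise path computes is sketched below. It is an illustrative emulation under the same definitions (per-group rowwise amax, scale = fp8_max / amax, optional round-down to a power of 2, clamp, cast); it is not part of this commit or its tests, and ref_fp8_rowwise_jagged is a made-up name.

    import torch

    def ref_fp8_rowwise_jagged(
        x, offsets, out_dtype=torch.float8_e4m3fn, round_scales_to_power_of_2=False, eps=1e-12
    ):
        m, _ = x.shape
        fp8_max = torch.finfo(out_dtype).max
        fp8_min = torch.finfo(out_dtype).min
        out = torch.empty_like(x, dtype=out_dtype)
        scales = torch.empty(m * offsets.numel(), dtype=torch.float32, device=x.device)
        start = 0
        for g, end in enumerate(offsets.tolist()):
            group = x[:, start:end].to(torch.float64)
            amax = group.abs().amax(dim=1).clamp(min=eps)  # rowwise amax within this group
            s = (fp8_max / amax).to(torch.float32)         # one scale per row, per group
            if round_scales_to_power_of_2:
                s = torch.exp2(torch.floor(torch.log2(s)))
            scales[g * m : (g + 1) * m] = s
            scaled = group * s.to(torch.float64)[:, None]
            out[:, start:end] = scaled.clamp(fp8_min, fp8_max).to(out_dtype)
            start = end
        return out, scales
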

0 commit comments
