284 changes: 111 additions & 173 deletions fbgemm_gpu/experimental/gemm/triton_gemm/fp4_quantize.py

Large diffs are not rendered by default.
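The benchmark changes below import `_to_blocked` from this file, whose diff is not rendered here. As a rough, non-authoritative sketch: helpers of this kind rearrange a `(rows, cols)` scale matrix into the 128x4 blocked layout expected by the CUTLASS MX kernels, along the lines of torchao's `to_blocked`. The name `to_blocked_sketch` and the exact padding/swizzle below are assumptions for illustration, not code from this PR.

```python
import torch


def to_blocked_sketch(scales: torch.Tensor) -> torch.Tensor:
    """Hypothetical stand-in for _to_blocked: pad a (rows, cols) scale
    matrix to multiples of (128, 4) and swizzle it into 32x16 tiles,
    returned as a flat 1D tensor."""
    rows, cols = scales.shape
    n_row_blocks = (rows + 127) // 128
    n_col_blocks = (cols + 3) // 4
    padded = torch.zeros(
        n_row_blocks * 128,
        n_col_blocks * 4,
        dtype=scales.dtype,
        device=scales.device,
    )
    padded[:rows, :cols] = scales
    # (rb, 128, cb, 4) -> (rb, cb, 128, 4), then regroup into 32x16 tiles.
    blocks = padded.view(n_row_blocks, 128, n_col_blocks, 4).permute(0, 2, 1, 3)
    return blocks.reshape(-1, 4, 32, 4).transpose(1, 2).reshape(-1)
```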

71 changes: 71 additions & 0 deletions fbgemm_gpu/experimental/gemm/triton_gemm/fp8_gemm.py
@@ -4359,3 +4359,74 @@ def grid(meta):
xq, x_scale, x_dequant, M, K, BLOCK_M=block_m, BLOCK_K=block_k # pyre-ignore[6]
)
return x_dequant


# This function is extracted from https://github.com/pytorch/ao/blob/v0.12.0/torchao/prototype/mx_formats/mx_tensor.py#L142
def to_mxfp8(
data_hp: torch.Tensor,
block_size: int = 32,
):
assert data_hp.dtype in (
torch.bfloat16,
torch.float,
), f"{data_hp.dtype} is not supported yet"
assert (
data_hp.shape[-1] % block_size == 0
), f"the last dimension of shape {data_hp.shape} must be divisible by block_size {block_size}"
assert data_hp.is_contiguous(), "unsupported"

orig_shape = data_hp.shape
data_hp = data_hp.reshape(
*orig_shape[:-1], orig_shape[-1] // block_size, block_size
)

max_abs = torch.amax(torch.abs(data_hp), -1).unsqueeze(-1)

data_hp = data_hp.to(torch.float32)
max_abs = max_abs.to(torch.float32)

F8E4M3_MAX = torch.finfo(torch.float8_e4m3fn).max # 448.0
max_pos = F8E4M3_MAX

# RCEIL
def _to_mx_rceil(
data_hp: torch.Tensor,
max_abs: torch.Tensor,
max_pos: float,
) -> tuple[torch.Tensor, torch.Tensor]:
E8M0_EXPONENT_BIAS = 127
descale = max_abs / max_pos
exponent = torch.where(
torch.isnan(descale),
0xFF, # Handle biased exponent for nan
# NOTE: descale < (torch.finfo(torch.float32).smallest_normal / 2) is handled through clamping
(
torch.clamp(
torch.ceil(torch.log2(descale)),
min=-E8M0_EXPONENT_BIAS,
max=E8M0_EXPONENT_BIAS,
)
+ E8M0_EXPONENT_BIAS
).to(torch.uint8),
)

descale_fp = torch.where(
exponent == 0,
1.0,
torch.exp2(E8M0_EXPONENT_BIAS - exponent.to(torch.float32)),
)

# scale and saturated cast the data elements to max of target dtype
data_lp = torch.clamp(data_hp * descale_fp, min=-1 * max_pos, max=max_pos)
return exponent, data_lp

scale_e8m0_biased, data_lp = _to_mx_rceil(data_hp, max_abs, max_pos)

# cast to target dtype
data_lp = data_lp.to(torch.float8_e4m3fn)
# need to reshape at the end to help inductor fuse things
data_lp = data_lp.reshape(orig_shape)

scale_e8m0_biased = scale_e8m0_biased.view(torch.float8_e8m0fnu)
scale_e8m0_biased = scale_e8m0_biased.squeeze(-1)
return scale_e8m0_biased, data_lp
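
Not part of the diff, but as a quick sanity check of the helper above, a minimal usage sketch (it assumes a PyTorch build new enough to expose `torch.float8_e4m3fn` and `torch.float8_e8m0fnu`; the shapes are arbitrary):

```python
import torch

# Arbitrary example: 128 rows, 256 columns -> 256 / 32 = 8 scale blocks per row.
x = torch.randn(128, 256, dtype=torch.bfloat16)

scale, xq = to_mxfp8(x)

# One biased E8M0 exponent per 32-element block of the last dimension.
assert scale.shape == (128, 8) and scale.dtype == torch.float8_e8m0fnu
# Quantized data keeps the original shape, saturated-cast to float8_e4m3fn.
assert xq.shape == (128, 256) and xq.dtype == torch.float8_e4m3fn
```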
134 changes: 126 additions & 8 deletions fbgemm_gpu/experimental/gen_ai/bench/quantize_ops.py
@@ -15,6 +15,7 @@
import triton # @manual=//triton:triton

from fbgemm_gpu.experimental.gemm.triton_gemm.fp4_quantize import (
_to_blocked,
calculate_group_max,
mega_fp4_pack,
mega_fp4_quantize_kernel,
@@ -33,6 +34,7 @@
quantize_fp8_group,
quantize_fp8_row,
scale_fp8_row,
to_mxfp8,
triton_quantize_fp8_row,
)
from fbgemm_gpu.experimental.gemm.triton_gemm.grouped_gemm import (
@@ -2497,20 +2499,53 @@ def preprocess(self, x, w):
return x, wq, w_scale, m_sizes

def quantize(self, x, wq, w_scale, m_sizes):
xq, x_scale = zip(*[triton_quantize_mx4_unpack(i) for i in x])
xq = torch.stack(xq, dim=0).contiguous()
x_scale = torch.stack(x_scale, dim=0).contiguous()
starting_row_after_padding_list = [0]
xq_list = []
x_scale_list = []
for i in range(m_sizes.shape[0]):
scale_slice = x[i]
if m_sizes[i].item() != 0:
xq, x_scale = triton_quantize_mx4_unpack(scale_slice)
xq_list.append(xq)
x_scale_list.append(x_scale)
starting_row_after_padding_list.append(
starting_row_after_padding_list[i]
+ x_scale.numel() // (x[0].shape[1] // 32)
)
else:
starting_row_after_padding_list.append(
starting_row_after_padding_list[i]
)
xq = torch.cat(xq_list, dim=0).contiguous()
x_scale = torch.cat(x_scale_list, dim=0).contiguous()
x_scale = x_scale.reshape(-1, x[0].shape[-1] // 32)
xq = xq.view(-1, xq.shape[-1])
return xq, wq, x_scale, w_scale, m_sizes
return (
xq,
wq,
x_scale,
w_scale,
m_sizes,
torch.tensor(starting_row_after_padding_list, device=xq.device),
)

def compute(self, xq, wq, x_scale, w_scale, m_sizes):
def compute(self, xq, wq, x_scale, w_scale, m_sizes, starting_row_after_padding):
return torch.ops.fbgemm.f4f4bf16_grouped_stacked(
xq, wq, x_scale, w_scale, m_sizes
xq,
wq,
x_scale,
w_scale,
m_sizes,
starting_row_after_padding=starting_row_after_padding,
)

def quantize_and_compute(self, x, w):
xq, wq, x_scale, w_scale, m_sizes = self.quantize(x, w)
return self.compute(xq, wq, x_scale, w_scale, m_sizes)
xq, wq, x_scale, w_scale, m_sizes, starting_row_after_padding = self.quantize(
x, w
)
return self.compute(
xq, wq, x_scale, w_scale, m_sizes, starting_row_after_padding
)

@property
def name(self) -> str:
@@ -2835,3 +2870,86 @@ def hip(self) -> bool:
@property
def cuda(self) -> bool:
return True


@register_quantize_op
class MXFP8StackedGroupedGemm(QuantizeOpBase):
"""
MXFP8 grouped matmul with blockwise scaling and stacked inputs.
"""

def preprocess(self, x, w):
m_values = [i.shape[0] for i in x]
m_sizes = torch.tensor(m_values).to(dtype=torch.int64, device=x[0].device)
wq_list = []
w_scale_list = []
for i in range(m_sizes.shape[0]):
w_scale, wq = to_mxfp8(w[i])
w_scale = _to_blocked(w_scale)
wq_list.append(wq)
w_scale_list.append(w_scale)
wq = torch.stack(wq_list, dim=0).contiguous()
w_scale = torch.stack(w_scale_list, dim=0).contiguous()
return x, wq, w_scale, m_sizes

def quantize(self, x, wq, w_scale, m_sizes):
starting_row_after_padding_list = [0]
xq_list = []
x_scale_list = []
for i in range(m_sizes.shape[0]):
scale_slice = x[i]
if m_sizes[i].item() != 0:
x_scale, xq = to_mxfp8(scale_slice)
x_scale = _to_blocked(x_scale)
xq_list.append(xq)
x_scale_list.append(x_scale)
starting_row_after_padding_list.append(
starting_row_after_padding_list[i]
+ x_scale.numel() // (x[0].shape[1] // 32)
)
else:
starting_row_after_padding_list.append(
starting_row_after_padding_list[i]
)
xq = torch.cat(xq_list, dim=0).contiguous()
x_scale = torch.cat(x_scale_list, dim=0).contiguous()
x_scale = x_scale.reshape(-1, x[0].shape[-1] // 32)
xq = xq.view(-1, xq.shape[-1])
return (
xq,
wq,
x_scale,
w_scale,
m_sizes,
torch.tensor(starting_row_after_padding_list, device=xq.device),
)

def compute(self, xq, wq, x_scale, w_scale, m_sizes, starting_row_after_padding):
return torch.ops.fbgemm.mx8mx8bf16_grouped_stacked(
xq,
wq,
x_scale,
w_scale,
m_sizes,
starting_row_after_padding=starting_row_after_padding,
)

def quantize_and_compute(self, x, w):
xq, wq, x_scale, w_scale, m_sizes, starting_row_after_padding = self.quantize(
x, w
)
return self.compute(
xq, wq, x_scale, w_scale, m_sizes, starting_row_after_padding
)

@property
def name(self) -> str:
return "cutlass_mx8mx8bf16_grouped_stacked"

@property
def hip(self) -> bool:
return False

@property
def cuda(self) -> bool:
return True
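
Both stacked grouped ops above build `starting_row_after_padding` the same way: each non-empty group's scale tensor is padded to a multiple of 128 rows (for MXFP8 this is what `_to_blocked` produces), and the tensor records the cumulative padded row offsets the kernel later uses to index into the stacked scales. A small worked example of that bookkeeping, assuming K is a multiple of 128 so the column padding drops out of `x_scale.numel() // (K // 32)`:

```python
import torch

# Hypothetical per-group row counts for a stacked grouped GEMM.
m_sizes = torch.tensor([200, 0, 300])

# Each non-empty group occupies ceil(M_i / 128) * 128 padded scale rows;
# empty groups contribute nothing, mirroring the m_sizes[i] == 0 branch above.
starting_rows = [0]
for m in m_sizes.tolist():
    padded_rows = ((m + 127) // 128) * 128 if m != 0 else 0
    starting_rows.append(starting_rows[-1] + padded_rows)

print(starting_rows)  # [0, 256, 256, 640]
starting_row_after_padding = torch.tensor(starting_rows)
```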
@@ -160,15 +160,21 @@ __global__ void set_stacked_kernel_args_kernel(
int64_t offset_M = 0;
int64_t accumulated_x_scale = 0;
int64_t accumulated_w_scale = 0;
int ele_per_quantize_group = 16;
if (global_scale == nullptr) {
ele_per_quantize_group = 32;
}
for (int i = 0; i < group_index; i++) {
offset_M += M_sizes[i];
/* It's calculated this way since the scales are at least padded to
   multiples of (128, 4), and there is one scale per quantize group of
   ele_per_quantize_group elements (16 when a global scale is provided,
   32 otherwise).
*/
accumulated_w_scale +=
(((N + 128 - 1) / 128) * 128 * ((K + 4 - 1) / 4) * 4 / 16);
(((N + 128 - 1) / 128) * 128 * ((K + 4 - 1) / 4) * 4 /
ele_per_quantize_group);
}
accumulated_x_scale = starting_row_after_padding[group_index] * K / 16;
accumulated_x_scale =
starting_row_after_padding[group_index] * K / ele_per_quantize_group;
// Set the problem shape for this group.
problem_shape_ptr[non_zero_idx] = ProblemShape(N, M, K);
// Set input pointers.
@@ -646,7 +652,7 @@ at::Tensor f4f4bf16_grouped_impl(
layout_SFB,
nullptr,
nullptr,
nullptr);
starting_row_after_padding_ptr);
}
// Set the number of groups to the kernel to be at most the number of
// non-zero rows.
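
To make the offset bookkeeping above easier to follow, here is the same per-group scale-offset arithmetic transcribed into Python (an illustrative sketch only, mirroring the kernel's expressions; `has_global_scale` stands in for the `global_scale != nullptr` check):

```python
def scale_offsets(group_index, starting_row_after_padding, N, K, has_global_scale):
    # 16 elements share one scale when a global scale is provided (e.g. NVFP4),
    # 32 otherwise (e.g. MXFP8).
    ele_per_quantize_group = 16 if has_global_scale else 32

    # Weight scales: every earlier group's scale matrix is padded to
    # multiples of (128, 4) before being divided into quantize groups.
    accumulated_w_scale = 0
    for _ in range(group_index):
        padded_n = ((N + 127) // 128) * 128
        padded_k = ((K + 3) // 4) * 4
        accumulated_w_scale += padded_n * padded_k // ele_per_quantize_group

    # Activation scales: indexed through the padded-row prefix sums computed
    # on the host (starting_row_after_padding).
    accumulated_x_scale = (
        starting_row_after_padding[group_index] * K // ele_per_quantize_group
    )
    return accumulated_x_scale, accumulated_w_scale
```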