fms_mo/aiu_addons/fp8/fp8_spyre_op.py (79 additions, 57 deletions)
@@ -17,7 +17,6 @@
from typing import Optional

# Third Party
from packaging.version import Version
from torch import Tensor
import torch
import torch.nn.functional as F
@@ -30,62 +29,71 @@
# open issue in PyLint: https://github.com/pytorch/pytorch/issues/119482


if Version(torch.__version__) <= Version("2.7"):
# PyTorch 2.8 adds scaled_mm_out op for CPU in the ATen set,
# while for earlier versions we need a custom definition
def _scaled_mm_cpu_out(
mat1: Tensor,
mat2: Tensor,
scale1: Tensor,
scale2: Tensor,
bias: Optional[Tensor] = None,
scale_result: Optional[Tensor] = None,
out_dtype: Optional[torch.dtype] = None,
use_fast_accum: bool = False,
*,
out: Optional[Tensor] = None,
) -> Tensor:
if out_dtype is None:
out_dtype = torch.float32
mat1 = (mat1.to(dtype=out_dtype) * scale1).to(dtype=out_dtype)
mat2 = (mat2.to(dtype=out_dtype) * scale2).to(dtype=out_dtype)

if bias is not None:
ret = torch.addmm(bias, mat1, mat2).to(dtype=out_dtype)
else:
ret = torch.mm(mat1, mat2).to(dtype=out_dtype)

if out is not None:
out.copy_(ret)
return out
return ret
# PyTorch 2.8 adds a scaled_mm_out op for CPU to the ATen set.
# That CPU implementation is not sufficient for our use case, so we still
# have to keep our own custom version.
def _scaled_mm_cpu_out(
mat1: Tensor,
mat2: Tensor,
scale1: Tensor,
scale2: Tensor,
bias: Optional[Tensor] = None,
scale_result: Optional[Tensor] = None,
out_dtype: Optional[torch.dtype] = None,
use_fast_accum: bool = False,
*,
out: Optional[Tensor] = None,
) -> Tensor:
if out_dtype is None:
out_dtype = torch.float32
mat1 = (mat1.to(dtype=out_dtype) * scale1).to(dtype=out_dtype)
mat2 = (mat2.to(dtype=out_dtype) * scale2).to(dtype=out_dtype)

if bias is not None:
ret = torch.addmm(bias, mat1, mat2).to(dtype=out_dtype)
else:
ret = torch.mm(mat1, mat2).to(dtype=out_dtype)

if out is not None:
out.copy_(ret)
return out
return ret
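
Note: in effect this fallback dequantizes both operands with their per-tensor scales and then performs a plain matmul in out_dtype. A minimal equivalence sketch (illustration only, not part of this PR; shapes and scale values are made up):

import torch

a = torch.randn(4, 8).to(torch.float8_e4m3fn)    # (M, K) fp8 operand
b = torch.randn(8, 16).to(torch.float8_e4m3fn)   # (K, N) fp8 operand
scale_a = torch.tensor(0.5)                      # per-tensor scales
scale_b = torch.tensor(0.25)

# what _scaled_mm_cpu_out returns for out_dtype=torch.float32 and no bias
ref = torch.mm(a.to(torch.float32) * scale_a, b.to(torch.float32) * scale_b)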


def _scaled_mm_cpu(
mat1: Tensor,
mat2: Tensor,
scale1: Tensor,
scale2: Tensor,
bias: Optional[Tensor] = None,
scale_result: Optional[Tensor] = None,
out_dtype: Optional[torch.dtype] = None,
use_fast_accum: bool = False,
) -> Tensor:
return _scaled_mm_cpu_out(
mat1,
mat2,
scale1,
scale2,
bias,
scale_result,
out_dtype,
use_fast_accum,
out=None,
)


if torch.__version__ >= "2.8":
DispatchKey = torch._C.DispatchKey # type: ignore[attr-defined]
torch.ops.aten._scaled_mm.out.py_kernels[DispatchKey.CPU] = _scaled_mm_cpu_out
torch.ops.aten._scaled_mm.default.py_kernels[DispatchKey.CPU] = _scaled_mm_cpu
else:
torch.library.register_kernel(
torch.ops.aten._scaled_mm.out, "cpu", _scaled_mm_cpu_out
)

@torch.library.register_kernel("aten::_scaled_mm", "cpu")
def _scaled_mm_cpu(
mat1: Tensor,
mat2: Tensor,
scale1: Tensor,
scale2: Tensor,
bias: Optional[Tensor] = None,
scale_result: Optional[Tensor] = None,
out_dtype: Optional[torch.dtype] = None,
use_fast_accum: bool = False,
) -> Tensor:
return _scaled_mm_cpu_out(
mat1,
mat2,
scale1,
scale2,
bias,
scale_result,
out_dtype,
use_fast_accum,
out=None,
)
torch.library.register_kernel(
torch.ops.aten._scaled_mm.default, "cpu", _scaled_mm_cpu
)
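
Note: with either registration path above in place, a CPU call to torch._scaled_mm should route to the Python kernels defined in this file. A hedged usage sketch (illustration only, not part of this PR; assumes this module has been imported so the module-level registration has already run):

import torch

import fms_mo.aiu_addons.fp8.fp8_spyre_op  # noqa: F401  (import runs the CPU registration)

a = torch.randn(4, 8).to(torch.float8_e4m3fn)
b = torch.randn(8, 16).to(torch.float8_e4m3fn)
scale_a = torch.tensor(1.0)
scale_b = torch.tensor(1.0)

out = torch._scaled_mm(a, b, scale_a, scale_b, out_dtype=torch.bfloat16)
print(out.shape)  # torch.Size([4, 16])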


@torch.library.custom_op("spyre::scaled_bmm", mutates_args=())
@@ -115,7 +123,7 @@ def spyre_scaled_bmm(
device=mat1.device,
)
for b_idx in range(mat1.shape[0]):
out[b_idx] = torch._scaled_mm(
out[b_idx] = _scaled_mm_cpu_out(
mat1[b_idx],
mat2[b_idx],
scale1,
@@ -218,6 +226,7 @@ def scaled_paged_attn_compute(
num_kv_heads = value_cache.shape[2]
head_size = value_cache.shape[3]
block_size = value_cache.shape[1]
seq_len_q = query.shape[1]
num_seqs = query.shape[0]

block_tables_lst = block_table.cpu().tolist()
@@ -228,6 +237,7 @@
block_table = block_tables_lst[i]
start_pos = int(left_padded_prompt_mask[i].item())
seq_len = int(seq_lens_lst[i])
seq_len_q_i = seq_len_q

keys_lst: list[torch.Tensor] = []
values_lst: list[torch.Tensor] = []
@@ -243,13 +253,25 @@
values_lst.append(v)
keys = torch.stack(keys_lst, dim=0)
values = torch.stack(values_lst, dim=0)
seq_len_kv = keys.shape[0]

# cut the pads for first prefill
if q.shape[0] > seq_len_kv:
seq_len_q_i = seq_len_kv
q = q[-seq_len_kv:]

if num_kv_heads > 1:
# Handle MQA and GQA
keys = torch.repeat_interleave(keys, num_query_heads // num_kv_heads, dim=1)
values = torch.repeat_interleave(
values, num_query_heads // num_kv_heads, dim=1
)

# Generate mask for prefix attention
mask = torch.ones((1, 1, seq_len_q_i, seq_len_kv), dtype=torch.bool)
mask[:, :, :, -seq_len_q_i:] = torch.tril(mask[:, :, :, -seq_len_q_i:])
mask = torch.where(mask.logical_not(), -torch.inf, 0.0)

out = F.scaled_dot_product_attention( # noqa: E1102
q.transpose(0, 1).unsqueeze(0), # format for sdpa
(keys.transpose(0, 1).unsqueeze(0).to(dtype=q.dtype) * key_scale[i]).to(
@@ -258,12 +280,12 @@
(values.transpose(0, 1).unsqueeze(0).to(dtype=q.dtype) * value_scale[i]).to(
dtype=q.dtype
), # format for sdpa
is_causal=False, # decode assumes no causal mask
attn_mask=mask,  # prefix mask: causal only over the new query positions
scale=scale,
)

out = out.view(num_query_heads, head_size)
output[i].copy_(out, non_blocking=True)
out = out.transpose(1, 2).view(seq_len_q_i, num_query_heads, head_size)
output[i][-seq_len_q_i:] = out
return output
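
Note: the mask built in the loop above lets every new query token attend to the whole cached prefix while staying causal among the new tokens themselves, which is why the previous is_causal=False decode path no longer suffices once more than one query position is processed. A standalone sketch of the same construction (illustrative lengths only, not part of this PR):

import torch

seq_len_q, seq_len_kv = 3, 6  # e.g. 3 new query tokens over 6 cached positions

mask = torch.ones((1, 1, seq_len_q, seq_len_kv), dtype=torch.bool)
# causal only over the trailing block that corresponds to the new query tokens
mask[:, :, :, -seq_len_q:] = torch.tril(mask[:, :, :, -seq_len_q:])
mask = torch.where(mask.logical_not(), -torch.inf, 0.0)
print(mask[0, 0])
# tensor([[0., 0., 0., 0., -inf, -inf],
#         [0., 0., 0., 0., 0., -inf],
#         [0., 0., 0., 0., 0., 0.]])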

