fms_mo/aiu_addons/fp8/fp8_spyre_op.py (108 changes: 56 additions & 52 deletions)
@@ -17,6 +17,7 @@
 from typing import Optional
 
 # Third Party
+from packaging.version import Version
 from torch import Tensor
 import torch
 import torch.nn.functional as F
@@ -29,60 +30,63 @@
 # open issue in PyLint: https://github.com/pytorch/pytorch/issues/119482
 
 
-def _scaled_mm_cpu_out(
-    mat1: Tensor,
-    mat2: Tensor,
-    scale1: Tensor,
-    scale2: Tensor,
-    bias: Optional[Tensor] = None,
-    scale_result: Optional[Tensor] = None,
-    out_dtype: Optional[torch.dtype] = None,
-    use_fast_accum: bool = False,
-    *,
-    out: Optional[Tensor] = None,
-) -> Tensor:
-    if out_dtype is None:
-        out_dtype = torch.float32
-    mat1 = (mat1.to(dtype=out_dtype) * scale1).to(dtype=out_dtype)
-    mat2 = (mat2.to(dtype=out_dtype) * scale2).to(dtype=out_dtype)
-
-    if bias is not None:
-        ret = torch.addmm(bias, mat1, mat2).to(dtype=out_dtype)
-    else:
-        ret = torch.mm(mat1, mat2).to(dtype=out_dtype)
-
-    if out is not None:
-        out.copy_(ret)
-        return out
-    return ret
-
-
-torch.library.register_kernel(torch.ops.aten._scaled_mm.out, "cpu", _scaled_mm_cpu_out)
-
-
-@torch.library.register_kernel("aten::_scaled_mm", "cpu")
-def _scaled_mm_cpu(
-    mat1: Tensor,
-    mat2: Tensor,
-    scale1: Tensor,
-    scale2: Tensor,
-    bias: Optional[Tensor] = None,
-    scale_result: Optional[Tensor] = None,
-    out_dtype: Optional[torch.dtype] = None,
-    use_fast_accum: bool = False,
-) -> Tensor:
-    return _scaled_mm_cpu_out(
-        mat1,
-        mat2,
-        scale1,
-        scale2,
-        bias,
-        scale_result,
-        out_dtype,
-        use_fast_accum,
-        out=None,
+if Version(torch.__version__) <= Version("2.7"):
+    # PyTorch 2.8 adds scaled_mm_out op for CPU in the ATen set,
+    # while for earlier versions we need a custom definition
+    def _scaled_mm_cpu_out(
+        mat1: Tensor,
+        mat2: Tensor,
+        scale1: Tensor,
+        scale2: Tensor,
+        bias: Optional[Tensor] = None,
+        scale_result: Optional[Tensor] = None,
+        out_dtype: Optional[torch.dtype] = None,
+        use_fast_accum: bool = False,
+        *,
+        out: Optional[Tensor] = None,
+    ) -> Tensor:
+        if out_dtype is None:
+            out_dtype = torch.float32
+        mat1 = (mat1.to(dtype=out_dtype) * scale1).to(dtype=out_dtype)
+        mat2 = (mat2.to(dtype=out_dtype) * scale2).to(dtype=out_dtype)
+
+        if bias is not None:
+            ret = torch.addmm(bias, mat1, mat2).to(dtype=out_dtype)
+        else:
+            ret = torch.mm(mat1, mat2).to(dtype=out_dtype)
+
+        if out is not None:
+            out.copy_(ret)
+            return out
+        return ret
+
+    torch.library.register_kernel(
+        torch.ops.aten._scaled_mm.out, "cpu", _scaled_mm_cpu_out
+    )
+
+    @torch.library.register_kernel("aten::_scaled_mm", "cpu")
+    def _scaled_mm_cpu(
+        mat1: Tensor,
+        mat2: Tensor,
+        scale1: Tensor,
+        scale2: Tensor,
+        bias: Optional[Tensor] = None,
+        scale_result: Optional[Tensor] = None,
+        out_dtype: Optional[torch.dtype] = None,
+        use_fast_accum: bool = False,
+    ) -> Tensor:
+        return _scaled_mm_cpu_out(
+            mat1,
+            mat2,
+            scale1,
+            scale2,
+            bias,
+            scale_result,
+            out_dtype,
+            use_fast_accum,
+            out=None,
     )
 
 
 @torch.library.custom_op("spyre::scaled_bmm", mutates_args=())
 def spyre_scaled_bmm(
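The gated block above registers a CPU emulation of aten::_scaled_mm that upcasts the fp8 inputs, applies the scales, and falls back to torch.mm or torch.addmm. A minimal sketch of what such a call computes, assuming float8 dtypes are usable on the host; the shapes and unit scales are illustrative, not taken from the repository:

import torch

# Illustrative shapes and scales only.
a = torch.randn(8, 16).to(torch.float8_e4m3fn)
b = torch.randn(16, 4).to(torch.float8_e4m3fn)
scale_a = torch.tensor(1.0)
scale_b = torch.tensor(1.0)

# On PyTorch <= 2.7 this dispatches to the _scaled_mm_cpu kernel registered
# above; on 2.8+ the native ATen CPU kernel handles it.
out = torch._scaled_mm(a, b, scale_a, scale_b, out_dtype=torch.float32)

# Reference result: dequantize, apply scales, then a plain matmul.
ref = (a.to(torch.float32) * scale_a) @ (b.to(torch.float32) * scale_b)
torch.testing.assert_close(out, ref)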
fms_mo/aiu_addons/i8i8/i8i8_aiu_op.py (10 changes: 3 additions & 7 deletions)
@@ -36,10 +36,8 @@ def implement_op_decorator(op_namespace_id):
     Always compare against pytorch version in current environment.
     """
 
-    torch_version = Version(torch.__version__.split("+", maxsplit=1)[0])
-
     def decorator(func):
-        if torch_version < Version("2.4"):
+        if Version(torch.__version__) < Version("2.4"):
             return torch.library.impl(op_namespace_id, "default")(func)
         return torch.library.custom_op(op_namespace_id, mutates_args=())(func)
 
@@ -51,10 +49,8 @@ def register_op_decorator(op_namespace_id):
     Always compare against pytorch version in current environment.
     """
 
-    torch_version = Version(torch.__version__.split("+", maxsplit=1)[0])
-
     def decorator(func):
-        if torch_version < Version("2.4"):
+        if Version(torch.__version__) < Version("2.4"):
             return torch.library.impl_abstract(op_namespace_id)(func)
         return torch.library.register_fake(op_namespace_id)(func)
 
@@ -73,7 +69,7 @@ def register_aiu_i8i8_op():
         logger.warning("AIU op has already been registered")
         return
     op_namespace_id = "fms_mo::i8i8_aiu"
-    if Version(torch.__version__.split("+", maxsplit=1)[0]) < Version("2.4"):
+    if Version(torch.__version__) < Version("2.4"):
         torch.library.define(
             op_namespace_id,
             "(Tensor x, Tensor weight, Tensor bias, Tensor qdata, "
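The version checks now pass torch.__version__ straight to Version, since packaging parses PEP 440 local build tags such as "+cu121" on its own; a small sketch with illustrative version strings:

from packaging.version import Version

# Local build suffixes are valid PEP 440 local versions, so no manual
# split("+") is needed before comparing.
assert Version("2.5.1+cu121") >= Version("2.4")
assert Version("2.3.0+cpu") < Version("2.4")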
tests/aiu_addons/test_fp8_addon.py (6 changes: 3 additions & 3 deletions)
@@ -51,9 +51,9 @@ def test_fp8_op() -> None:
     # Local
     from fms_mo.aiu_addons.fp8.fp8_attn import _math_fp8_compute_op
 
-    query = torch.randn((1, 32, 64, 128), dtype=torch.bfloat16, device="cuda")
-    key = torch.randn((1, 32, 64, 128), dtype=torch.bfloat16, device="cuda")
-    value = torch.randn((1, 32, 64, 128), dtype=torch.bfloat16, device="cuda")
+    query = torch.randn((1, 64, 32, 128), dtype=torch.bfloat16, device="cuda")
+    key = torch.randn((1, 64, 32, 128), dtype=torch.bfloat16, device="cuda")
+    value = torch.randn((1, 64, 32, 128), dtype=torch.bfloat16, device="cuda")
 
     out = _math_fp8_compute_op(query, key, value, 32, 32, 0.0, None)
     assert out.size() == query.size()
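The reordered q/k/v sizes in the test suggest a batch-first, sequence-second layout; a hedged reading of the new shapes (the dimension names are an assumption, not taken from the fms_mo API):

# Assumed interpretation of the new test shapes; names are illustrative.
batch_size, seq_len, num_heads, head_dim = 1, 64, 32, 128
q_shape = (batch_size, seq_len, num_heads, head_dim)  # (1, 64, 32, 128)
# num_heads == 32 matches the two head-count arguments passed to
# _math_fp8_compute_op in the test.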