Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ cmake_minimum_required(VERSION 3.26 FATAL_ERROR)
project(driss_torch LANGUAGES CXX CUDA)

# Set the C++ standard for all targets
set(CMAKE_CXX_STANDARD 20)
set(CMAKE_CXX_STANDARD 20) # NOTE: this might be unsafe since PyTorch itself is built with C++17
set(CMAKE_CXX_STANDARD_REQUIRED ON)

# Enable better clangd support
Expand Down
28 changes: 13 additions & 15 deletions benchmarks/benchmark_saturated_casting.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import torch

from driss_torch import saturated_cast
from float8_experimental.float8_utils import amax_to_scale
from jsonargparse import CLI

from tabulate import tabulate
Expand All @@ -22,16 +23,17 @@

def eager_scaled_quant(
a: torch.Tensor,
scale: torch.Tensor,
amax: torch.Tensor,
fp8_dtype: torch.dtype,
):
"""Quantize tensor to fp8 using a delayed scaled and calculate abs_max

Args:
a: Input tensor to quantize
scale: Scale to apply to input tensor, calculated from previous abs_max
amax: Absolute maximum (amax) of the input tensor, used to compute the scale
fp8_dtype: FP8 datatype to quantize to
"""
scale = amax_to_scale(amax, fp8_dtype, a.dtype)
out = a * scale
out = torch.where(out > torch.finfo(fp8_dtype).max, torch.finfo(fp8_dtype).max, out)
out = torch.where(out < -1 * torch.finfo(fp8_dtype).max, -1 * torch.finfo(fp8_dtype).max, out)
Expand Down Expand Up @@ -97,42 +99,38 @@ def run_experiment(config: ExperimentConfig) -> ExperimentResult:
config.num_rows, config.num_cols, dtype=config.high_precision_dtype, device=device
)
cuda_hp_tensor = high_precision_tensor.clone()
cuda_scale = torch.ones(1, dtype=torch.bfloat16, device=device)
cuda_amax = torch.abs(high_precision_tensor).max().to(torch.float32)

eager_abs_max = torch.abs(high_precision_tensor).max().to(torch.float32)

scale = torch.finfo(config.low_precision_dtype).max / eager_abs_max
scale = scale.to(torch.float32)
scale = torch.ones(1, dtype=torch.float32, device=device)
eager_amax = torch.abs(high_precision_tensor).max().to(torch.float32)

# Correctness check:
cuda_out = saturated_cast(cuda_hp_tensor, config.low_precision_dtype, cuda_scale)
cuda_out, cuda_scale = saturated_cast(cuda_hp_tensor, eager_amax, config.low_precision_dtype)
cuda_out_hp = cuda_out.to(config.high_precision_dtype)

eager_out = eager_scaled_quant(high_precision_tensor, scale, config.low_precision_dtype).to(
config.high_precision_dtype
)
eager_out = eager_scaled_quant(
high_precision_tensor, eager_amax, config.low_precision_dtype
).to(config.high_precision_dtype)
eager_out_hp = eager_out.to(config.high_precision_dtype)

torch.testing.assert_close(cuda_out_hp, eager_out_hp, rtol=1e-3, atol=1e-3)

cuda_time = benchmark_torch_function_in_microseconds(
saturated_cast,
cuda_hp_tensor,
eager_amax,
config.low_precision_dtype,
cuda_scale,
)
pytorch_time = benchmark_torch_function_in_microseconds(
eager_scaled_quant,
high_precision_tensor,
scale,
eager_amax,
config.low_precision_dtype,
)
compiled_pytorch_fn = torch.compile(eager_scaled_quant, fullgraph=True)
compiled_pytorch_time = benchmark_torch_function_in_microseconds(
compiled_pytorch_fn,
high_precision_tensor,
scale,
eager_amax,
config.low_precision_dtype,
)
return ExperimentResult(
Expand Down
27 changes: 9 additions & 18 deletions driss_torch/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from pathlib import Path
from typing import Optional
from typing import Tuple

import torch

Expand All @@ -17,31 +17,22 @@ def list_ops():
return ops.__dir__()


def add_one(x: torch.Tensor) -> torch.Tensor:
"""Add one to a tensor.
This is a dummy test op to demonstrate how to add custom ops to PyTorch.
Args:
x: The input tensor.
Returns:
The output tensor.
"""
return ops.add_one(x)


def saturated_cast(
x: torch.Tensor,
scale: torch.Tensor,
amax: torch.Tensor,
out_dtype: torch.dtype,
transpose: bool = False,
) -> torch.Tensor:
) -> Tuple[torch.Tensor, torch.Tensor]:
"""This op takes in a tensor and returns the fp8 saturated casted version of it.
Args:
x: The input tensor.
out_dtype: The output data type, must be a float8 dtype.
scale: An optional on device tensor, this is expected to be a singleton tensor whose value will be multiplied
x before casting
amax: An on-device tensor, expected to be a singleton tensor whose value is
max(abs(x)) before casting; it is used to calculate the scale
using the formula `scale = fp8_max / max(amax, 1e-12)`, where fp8_max is the largest finite value of out_dtype
transpose: If true will transpose the input tensor during casting
Returns:
The output tensor.
A tuple of the output tensor and the on-device scale tensor.
"""
return ops.saturated_cast(x, scale, out_dtype, transpose)
assert not transpose, "Transpose is not supported yet"
return ops.saturated_cast(x, amax, out_dtype, transpose)
8 changes: 4 additions & 4 deletions driss_torch/abstract_impls.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
import torch
from torch.library import impl_abstract

print(__name__)


@impl_abstract("DrissTorch::saturated_cast")
def saturated_cast_meta(
x: torch.Tensor,
scale: torch.Tensor,
amax: torch.Tensor,
out_dtype: torch.dtype,
transpose: bool = False,
):
return torch.empty_like(x, dtype=out_dtype)
return torch.empty_like(x, dtype=out_dtype), torch.empty(
(), device=x.device, dtype=torch.float32
)
40 changes: 0 additions & 40 deletions src/add.cu

This file was deleted.

7 changes: 0 additions & 7 deletions src/include/add.h

This file was deleted.

5 changes: 1 addition & 4 deletions src/include/saturated_cast.h
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,6 @@

namespace driss_torch {

at::Tensor saturated_cast(const at::Tensor &input, const at::Tensor &attn_mask,
std::tuple<at::Tensor, at::Tensor> saturated_cast(const at::Tensor &input, const at::Tensor &amax,
at::ScalarType dtype, bool transpose);
at::Tensor saturated_cast_meta(const at::Tensor &input,
const at::Tensor &attn_mask,
at::ScalarType dtype, bool transpose);
} // namespace driss_torch
9 changes: 8 additions & 1 deletion src/include/utils.h
Original file line number Diff line number Diff line change
@@ -1,11 +1,18 @@
#pragma once

#include <cstdio>
#include <functional>
#include <stdio.h>


namespace driss_torch {

template <typename... Args>
__device__ void thread_zero_print(const char *fmt, Args &&...args) {
if (threadIdx.x == 0 && blockIdx.x == 0) {
printf(fmt, std::forward<Args>(args)...);
}
}

// error checking macro
#define cudaCheckErrors(msg) \
do { \
Expand Down
5 changes: 1 addition & 4 deletions src/register_ops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,11 @@
#include <torch/library.h>

// Custom op headers
#include "add.h"
#include "saturated_cast.h"

TORCH_LIBRARY(DrissTorch, m) {
m.impl_abstract_pystub("driss_torch.abstract_impls");
m.def("add_one(Tensor input) -> Tensor");
m.impl("add_one", c10::DispatchKey::CUDA, TORCH_FN(driss_torch::add_one));
// Saturated cast func from bf16 to fp8 types
m.def("saturated_cast(Tensor input, Tensor scale, ScalarType dtype, bool transpose) -> Tensor");
m.def("saturated_cast(Tensor input, Tensor amax, ScalarType dtype, bool transpose) -> (Tensor, Tensor)");
m.impl("saturated_cast", c10::DispatchKey::CUDA, TORCH_FN(driss_torch::saturated_cast));
}
Loading