
Conversation

namgyu-youn
Contributor

@namgyu-youn namgyu-youn commented Sep 21, 2025

Summary:
Introduce a new tensor subclass API for int8 quantization with a clearer interface.

The main change can be summarized as follows:

  • Old: Complex affine transform (AffineQuantizedTensor) with separate layout handling
  • New: Direct int8 tensor with qdata, scale, and zero_point attributes

Related Issue/PR: #3012 (comment) #2752

Test plan:
test/quantization/quantize_/workflows/int8/test_int8_tensor.py


Future plan:
Implement block-wise quantization using `block_size` parameter
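Usage sketch (illustrative only; `Int8Tensor.from_hp` and the `qdata`/`scale` attributes follow the tests in this PR, while the import path and the per-row `block_size` shown are assumptions):

import torch
# Import path is an assumption; Int8Tensor is introduced by this PR.
from torchao.quantization import Int8Tensor

# High-precision weight of a 16x32 linear layer
w = torch.randn(16, 32, dtype=torch.bfloat16)

# Per-row quantization: one scale per output channel (block spans a full row)
w_q8 = Int8Tensor.from_hp(w, block_size=[1, 32])

print(w_q8.qdata.dtype)  # torch.int8
print(w_q8.scale.shape)  # per-row scale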

pytorch-bot bot commented Sep 21, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/3038

Note: Links to docs will display an error until the docs builds have been completed.

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@meta-cla meta-cla bot added the CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. label Sep 21, 2025
@jerryzh168
Contributor

Can you add a version 2 and expose this tensor through `Int8DynamicActivationInt8WeightConfig(AOBaseConfig)`? Similar to `Float8DynamicActivationFloat8WeightConfig(AOBaseConfig)`.
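For illustration, the version gating inside the config transform could look roughly like this (a sketch only, not this PR's final code; `get_weight_block_size` and the abbreviated v1 call are assumptions):

# Hypothetical fragment of the Int8DynamicActivationInt8WeightConfig transform;
# names other than config.version and Int8Tensor.from_hp are assumptions.
def _int8_quantize_weight(weight, config):
    block_size = get_weight_block_size(weight)  # e.g. [1, weight.shape[1]] for per-row
    if config.version == 1:
        # existing AffineQuantizedTensor path (abbreviated)
        return to_affine_quantized_intx(
            weight, MappingType.SYMMETRIC, block_size, torch.int8
        )
    assert config.version == 2, f"Unexpected version: {config.version}"
    # new path: plain Int8Tensor exposed through the config
    return Int8Tensor.from_hp(weight, block_size=block_size)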

args[2] if len(args) > 2 else None,
)

if isinstance(input_tensor, Int8PlainInt8Tensor):
Contributor

We also need to quantize input_tensor in this function now; please check:

if act_quant_kwargs is not None:
    input_tensor = _choose_quant_func_and_quantize_tensor(
        input_tensor, act_quant_kwargs
    )
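In other words, roughly (a sketch only; where act_quant_kwargs lives on the weight tensor is an assumption borrowed from the float8 workflow):

# Sketch of the start of the linear impl for the new tensor; not this PR's code.
def _(func, types, args, kwargs):
    input_tensor, weight_tensor, bias = (
        args[0],
        args[1],
        args[2] if len(args) > 2 else None,
    )
    act_quant_kwargs = weight_tensor.act_quant_kwargs  # assumed attribute name
    if act_quant_kwargs is not None:
        # dynamically quantize the activation before the int8 matmul
        input_tensor = _choose_quant_func_and_quantize_tensor(
            input_tensor, act_quant_kwargs
        )
    # ... the int8 x int8 (or fp x int8) matmul path follows here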

@namgyu-youn namgyu-youn changed the title Add Int8PlainInt8Tensor for clearer interface Add Int8Tensor for clearer interface Sep 23, 2025
Comment on lines 176 to 182
x_int32 = input_tensor.qdata.to(torch.int32)
w_int32 = weight_tensor.qdata.to(torch.int32).t()

result = torch.mm(x_int32.view(-1, x_int32.size(-1)), w_int32)
scale = input_tensor.scale.view(-1, 1) * weight_tensor.scale.unsqueeze(0)
result = result.to(scale.dtype) * scale
result = result.view(*input_tensor.shape[:-1], -1)
Contributor

This is not the same as

def _linear_int8_act_int8_weight_check(input_tensor, weight_tensor, bias):
    return (
        isinstance(input_tensor, AffineQuantizedTensor)
        and _aqt_is_int8_reduced_range(input_tensor)
        and isinstance(weight_tensor, AffineQuantizedTensor)
        and _aqt_is_int8(weight_tensor)
        and input_tensor.dtype == weight_tensor.dtype
        and isinstance(input_tensor._layout, PlainLayout)
        and isinstance(weight_tensor._layout, PlainLayout)
    )


def _linear_int8_act_int8_weight_impl(input_tensor, weight_tensor, bias):
    #
    # 1. do the matrix form of dot(X_i, W_j)
    #
    #
    # 2. rescale the output
    #
    # in cases with large matrices, y_dot_int32 can grow sufficiently
    # large that y_dot_int32 * a float16 scale is greater than the maximum
    # value of a float 16, (which results in a value of inf even if multiplying
    # by the other scale would bring it within the expected range)
    x_vals_int8 = input_tensor.tensor_impl.int_data
    x_scales = input_tensor.tensor_impl.scale
    w_vals_int8_t = weight_tensor.tensor_impl.int_data.contiguous().t()
    w_scales = weight_tensor.tensor_impl.scale
    tmp = x_vals_int8.reshape(-1, x_vals_int8.shape[-1])
    x_scales_dtype = x_scales.dtype
    # Cast fp16 scale to float to avoid overflow in int_scaled_matmul
    intermediate_dtype = torch.float if x_scales_dtype == torch.half else x_scales_dtype
    y_dot_scaled = int_scaled_matmul(
        tmp, w_vals_int8_t, x_scales.reshape(-1, 1).to(intermediate_dtype)
    )
    y_dot_scaled = y_dot_scaled.to(x_scales_dtype)

    y = (y_dot_scaled * w_scales).reshape(
        *x_vals_int8.shape[:-1], y_dot_scaled.shape[-1]
    )

    # can downcast only at the very end
    output_dtype = input_tensor.dtype
    y = y.to(output_dtype)
    if bias is not None:
        y += bias
    return y

?
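To make the difference concrete, here is a small, self-contained illustration of the fp16 overflow described in the comment above (the numbers are chosen arbitrarily):

import torch

# An int32 accumulator value that a moderately large K can easily produce
y_dot_int32 = torch.tensor([131072], dtype=torch.int32)
x_scale = torch.tensor([0.25], dtype=torch.float16)       # activation scale
w_scale = torch.tensor([0.0078125], dtype=torch.float16)  # weight scale

# Naive path (as in the diff above): cast the int32 result straight to the
# fp16 scale dtype. 131072 > 65504 (fp16 max), so it becomes inf before any
# rescaling can bring it back into range.
naive = y_dot_int32.to(x_scale.dtype) * x_scale * w_scale
print(naive)  # tensor([inf], dtype=torch.float16)

# Existing impl: rescale by the activation scale in float32 first, and only
# then come back down to fp16 and apply the weight scale.
y_dot_scaled = (y_dot_int32.to(torch.float32) * x_scale.float()).to(x_scale.dtype)
safe = y_dot_scaled * w_scale
print(safe)  # tensor([256.], dtype=torch.float16)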

Contributor

Can you add a test to check the kernel that's used, similar to

def test_expected_gpu_kernel_fbgemm(self):

as well?

Contributor Author

can you add a test to check the kernel that's used similar to

def test_expected_gpu_kernel_fbgemm(self):

as well?

Yes, the linked workflow should be better to prevent overhead; I will fix it.
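For reference, such a kernel-check test could look roughly like this (a sketch under assumptions: the config name and version flag follow the discussion above, and the CUDA path lowers to torch._int_mm; the exact kernel substring may differ):

import torch
from torch._inductor.utils import run_and_get_code
from torch.testing import FileCheck
from torchao.quantization import quantize_, Int8DynamicActivationInt8WeightConfig

def test_expected_gpu_kernel_int8(self):
    # config/version names assumed from this PR's discussion
    m = torch.nn.Sequential(
        torch.nn.Linear(128, 256, dtype=torch.bfloat16, device="cuda")
    )
    quantize_(m, Int8DynamicActivationInt8WeightConfig(version=2))
    m = torch.compile(m)
    x = torch.randn(4, 128, dtype=torch.bfloat16, device="cuda")

    out, code = run_and_get_code(m, x)
    # Expect the int8 matmul kernel in the generated code; the exact
    # substring ("_int_mm") is an assumption and may need adjusting.
    FileCheck().check("_int_mm").run(code[0])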

result = result.to(scale.dtype) * scale
result = result.view(*input_tensor.shape[:-1], -1)
else:
# FP × INT8 (static)
Contributor

Also, this is the code for weight-only quant I think:

def _linear_fp_act_int8_weight_impl(input_tensor, weight_tensor, bias):

Contributor Author

@namgyu-youn namgyu-youn Sep 24, 2025

Done at 9383550, thanks for pointing it out.

block_size (Optional[list[int]]): block size for quantization granularity
"""

kernel_preference: KernelPreference = KernelPreference.AUTO
Contributor

Seems like there are no multiple kernel preferences right now, right? If so, we can remove this for now.

Contributor Author

We can remove this flag, but how about adding a TODO for real kernel preference support? Keeping the current structure might be helpful for that.

Contributor

We don't have different kernel options for this one, I think.



@unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
class TestInt8Tensor(TorchAOIntegrationTestCase):
Contributor

For the tests, maybe try to follow https://github.com/pytorch/ao/blob/main/test/quantization/quantize_/workflows/int4/test_int4_marlin_sparse_tensor.py for now, and also add some tests for slicing?

def test_slice(self, granularity):
    config = Float8DynamicActivationFloat8WeightConfig(granularity=granularity)
    dtype = torch.bfloat16
    device = "cuda"
    dummy = torch.nn.Linear(256, 256, bias=False, dtype=dtype, device=device)
    dummy1 = torch.nn.Linear(256, 64, bias=False, dtype=dtype, device=device)
    dummy1.weight = torch.nn.Parameter(
        dummy.weight.narrow(0, 0, 64), requires_grad=False
    )
    dummy2 = torch.nn.Linear(128, 256, dtype=dtype, device=device)
    dummy2.weight = torch.nn.Parameter(
        dummy.weight.narrow(1, 0, 128), requires_grad=False
    )
    quantize_(dummy, config)
    weight1 = dummy.weight.clone().narrow(0, 0, 64)
    weight2 = dummy.weight.clone().narrow(1, 0, 128)
    self.assertEqual(
        weight1.qdata,
        dummy.weight.qdata.narrow(0, 0, 64),
    )
    self.assertEqual(
        weight2.qdata,
        dummy.weight.qdata.narrow(1, 0, 128),
    )
    if isinstance(granularity, PerRow):
        self.assertEqual(
            weight1.scale,
            dummy.weight.scale.narrow(0, 0, 64),
        )
        self.assertEqual(
            weight2.scale,
            dummy.weight.scale,
        )
    else:
        self.assertEqual(
            weight1.scale,
            dummy.weight.scale,
        )
        self.assertEqual(
            weight2.scale,
            dummy.weight.scale,
        )
    # check for sliced weight, before and after float8 quantization
    # does not differ too much
    input = torch.randn(2, 256, dtype=dtype, device=device)
    res_ref = dummy1(input)
    dummy.weight = torch.nn.Parameter(weight1.contiguous(), requires_grad=False)
    res = dummy(input)
    sqnr = compute_error(res, res_ref)
    self.assertTrue(sqnr > 25, f"sqnr: {sqnr}")

    input = torch.randn(2, 128, dtype=dtype, device=device)
    res_ref = dummy2(input)
    dummy.weight = torch.nn.Parameter(weight2.contiguous(), requires_grad=False)
    res = dummy(input)
    sqnr = compute_error(res, res_ref)
    self.assertTrue(sqnr > 15, f"sqnr: {sqnr}")
and
def test_slice_preserves_aliasing(self, granularity):

Contributor Author

Yes, the linked unit test is helpful for slicing (PerTensor, PerRow) tests, but I haven't implemented granularity in this PR yet to keep the PR size small. Can I address it after this PR?

Contributor

I don't think the slicing tests are specific to a granularity; you should be able to adapt them for the currently supported granularity, I think.
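For completeness, the aliasing check could be adapted along these lines (a sketch only; the config/version names and the data_ptr-based check are assumptions modeled on the float8 test):

def test_slice_preserves_aliasing(self):
    config = Int8DynamicActivationInt8WeightConfig(version=2)  # assumed name/flag
    l = torch.nn.Linear(1024, 1024, bias=False, dtype=torch.bfloat16, device="cuda")
    quantize_(l, config)
    param = l.weight
    param_data = param.data.narrow(0, 0, 512)
    # narrow() should return views: the sliced qdata/scale must share storage
    # with the original tensor rather than being copies
    self.assertEqual(param.data.qdata.data_ptr(), param_data.qdata.data_ptr())
    self.assertEqual(param.data.scale.data_ptr(), param_data.scale.data_ptr())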

raise ValueError("Expected 2D tensor and block_size length 2")

# Rounding function from high precision dtype
scale = w.abs().max(dim=-1, keepdim=True)[0] / 127.0
Contributor

looks like block_size is not used? why is that?

Contributor

You can check out

def _linear_fp_act_int8_weight_check(input_tensor, weight_tensor, bias):

for the expected granularity.

Also, this should be using these quant primitive ops:

scale, zero_point = choose_qparams_affine(
    input=preprocessed_w,
    mapping_type=MappingType.SYMMETRIC,
    block_size=block_size,
    target_dtype=target_dtype,
    quant_min=quant_min,
    quant_max=quant_max,
    eps=1e-6,
)
wq = quantize_affine(
    input=preprocessed_w,
    block_size=block_size,
    scale=scale,
    zero_point=zero_point,
    output_dtype=target_dtype,
    quant_min=quant_min,
    quant_max=quant_max,
)

The arguments can be found by tracing through the code path for int8 in

new_weight = to_affine_quantized_intx(

and

scale, zero_point = choose_qparams_affine(

This might require a bit too much context, let me know if you would like us to take over.
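In other words, `Int8Tensor.from_hp` could look roughly like this with the primitives above (a rough sketch; the constructor arguments and the quant_min/quant_max choices are assumptions, not the final implementation):

import torch
from torchao.quantization.quant_primitives import (
    MappingType,
    choose_qparams_affine,
    quantize_affine,
)

@classmethod
def from_hp(cls, w: torch.Tensor, block_size: list[int]):
    # symmetric int8: zero_point is effectively zero, scale covers the int8 range
    scale, zero_point = choose_qparams_affine(
        input=w,
        mapping_type=MappingType.SYMMETRIC,
        block_size=block_size,
        target_dtype=torch.int8,
        quant_min=-128,
        quant_max=127,
        eps=1e-6,
    )
    qdata = quantize_affine(
        input=w,
        block_size=block_size,
        scale=scale,
        zero_point=zero_point,
        output_dtype=torch.int8,
        quant_min=-128,
        quant_max=127,
    )
    # constructor signature is an assumption
    return cls(qdata, scale, block_size, w.shape, dtype=w.dtype)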

Contributor Author

@namgyu-youn namgyu-youn Sep 29, 2025

Thanks, I surely want to take over! I drafted this PR for those updates, but will look into it today (6 hours later).

BTW, version 2 is updated at c53dad0 (version 1 is the default).

@namgyu-youn namgyu-youn marked this pull request as draft September 28, 2025 13:23
@namgyu-youn namgyu-youn marked this pull request as ready for review September 30, 2025 06:09
Contributor

@jerryzh168 jerryzh168 left a comment

please rebase, and let me know when this is ready for review again @namgyu-youn

@namgyu-youn namgyu-youn requested a review from jerryzh168 October 4, 2025 11:08
self.input_fp, weight_q8_dynamic, self.bias
)

self.assertEqual(result_dynamic.shape, reference.shape)
Contributor

nit: probably add a test for compute_error comparing floating point weight and int8+int8 weight as well
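Something like the following could cover that nit (a sketch; the fixture names mirror the existing tests and the 20 dB SQNR threshold is an assumption):

def test_weight_only_sqnr(self):
    weight_q8 = Int8Tensor.from_hp(self.weight_fp, self.block_size)

    reference = torch.nn.functional.linear(self.input_fp, self.weight_fp, self.bias)
    result = torch.nn.functional.linear(self.input_fp, weight_q8, self.bias)

    sqnr = compute_error(result, reference)
    self.assertGreater(sqnr, 20, f"sqnr: {sqnr}")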

)

def test_linear_operations(self):
"""Test fp+int8 and int8+int8 linear ops"""
Contributor

This is not int8+int8 I think? This is weight-only quant.

Comment on lines +116 to +130
def test_linear_operations(self):
    """Test fp+int8 and int8+int8 linear ops"""
    weight_q8 = Int8Tensor.from_hp(self.weight_fp, self.block_size)
    input_q8 = Int8Tensor.from_hp(self.input_fp, self.block_size)

    reference = torch.nn.functional.linear(self.input_fp, self.weight_fp, self.bias)
    result_fp = torch.nn.functional.linear(self.input_fp, weight_q8, self.bias)
    result_q8 = torch.nn.functional.linear(input_q8, weight_q8, self.bias)

    self.assertEqual(result_fp.shape, reference.shape)
    self.assertEqual(result_q8.shape, reference.shape)
    self.assertTrue(compute_error(result_fp, reference) > 10)
    self.assertTrue(compute_error(result_q8, reference) > 10)

def test_dynamic_quantization(self):
Contributor

I think you can remove these 2 tests actually, since they are already tested in test_int8_linear_variants

Comment on lines +180 to +181
self.assertEqual(weight1.qdata, dummy.weight.qdata.narrow(0, 0, 64))
self.assertEqual(weight2.qdata, dummy.weight.qdata.narrow(1, 0, 128))
Contributor

nit: add assert for scale as well?

self.assertEqual(weight1.qdata, dummy.weight.qdata.narrow(0, 0, 64))
self.assertEqual(weight2.qdata, dummy.weight.qdata.narrow(1, 0, 128))

def test_transpose(self):
Contributor

Is this used anywhere? For most of the tensors we actually don't support transpose so far; we tend to add this only when needed.

weight_q8 = Int8Tensor.from_hp(self.weight_fp, self.block_size)
selected = weight_q8.select(0, 0)

self.assertEqual(selected.shape, (3,))
Contributor

test the data as well?

Contributor

you can follow this:

)
else:
assert config.version == 2, f"Unexpected version: {config.version}"
block_size = [weight.shape[0], weight.shape[1]]
Contributor

This should be the same as L1393 I think; you can extract L1390-L1393 out of the first if branch and use that.

else:
quantized_weight = Int8Tensor.from_hp(
weight,
block_size=get_weight_block_size(weight),
Contributor

nit: can calculate block_size outside of the if/else
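i.e. roughly (a sketch only; the v1 branch shown is an abbreviation of the existing path, not verified against this PR):

# compute once, use in both branches
block_size = get_weight_block_size(weight)

if config.version == 1:
    quantized_weight = to_affine_quantized_intx(
        weight, MappingType.SYMMETRIC, block_size, torch.int8
    )
else:
    assert config.version == 2, f"Unexpected version: {config.version}"
    quantized_weight = Int8Tensor.from_hp(weight, block_size=block_size)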

elif isinstance(quant_kwargs, QuantizeTensorToInt8Kwargs):
return Int8Tensor.from_hp(
tensor,
quant_kwargs.block_size or [1, tensor.shape[-1]],
Contributor

nit: why not make block_size mandatory?

block_size (Optional[list[int]]): block size for quantization granularity
"""

block_size: Optional[list[int]] = None
Contributor

why is this optional?

"dtype",
]

def __new__(
Contributor

nit: please annotate the args with types to be clearer
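For example (types inferred from the attributes used elsewhere in this PR; the exact field list is an assumption):

def __new__(
    cls,
    qdata: torch.Tensor,
    scale: torch.Tensor,
    block_size: list[int],
    shape: torch.Size,
    dtype: torch.dtype,
):
    ...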

}
return torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs)

def __init__(
Contributor

same here

self.qdata = qdata
self.scale = scale
self.block_size = block_size
self._shape = shape
Contributor

nit: we don't need to set shape here, since it will be set in torch.Tensor._make_wrapper_subclass

Comment on lines +144 to +146
# Reshape 1D scale to [N, 1] for broadcasting with [N, K] qdata
if scale.ndim == 1:
    scale = scale.unsqueeze(1)
Contributor

is this needed?

)


@implements(aten.transpose.int)
Contributor

we don't need this yet I think, we can remove for now and add later when needed

if dim == 0 and tensor.scale.ndim >= 1:
sliced_scale = aten.slice.Tensor(tensor.scale, 0, start, end, step)

sliced_shape = list(
Contributor

why not get the shape from sliced tensor directly?

Contributor

can you check

? I'm not sure if the current implementation is enough to cover all cases actually
