Skip to content

Commit d14b3c2

Browse files
jerryzh168 authored and amdfaa committed
Fix a bug in LinearActivationQuantizedTensor (#1400)
* Fix a bug in LinearActivationQuantizedTensor

  Summary: quant_kwargs is not populated in some places

  Test Plan: python test/dtypes/test_affine_quantized_tensor_parallel.py

  Reviewers:
  Subscribers:
  Tasks:
  Tags:

* ruff
1 parent 5dde6e4 commit d14b3c2

File tree

2 files changed

+13
-6
lines changed

2 files changed

+13
-6
lines changed

test/dtypes/test_affine_quantized_tensor_parallel.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -181,6 +181,9 @@ class TestFloat8dqRowAffineQuantizedTensorParallel(
181181
def test_tp(self, dtype):
182182
return self._test_tp(dtype)
183183

184+
common_utils.instantiate_parametrized_tests(
185+
TestFloat8woAffineQuantizedTensorParallel
186+
)
184187
common_utils.instantiate_parametrized_tests(
185188
TestFloat8dqTensorAffineQuantizedTensorParallel
186189
)

torchao/quantization/linear_activation_quantized_tensor.py

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -147,8 +147,8 @@ def _(func, types, args, kwargs):
147147
)
148148
input_quant_func = weight_tensor.input_quant_func
149149
original_weight_tensor = weight_tensor.original_weight_tensor
150-
aqt = input_quant_func(input_tensor)
151-
return func(bias, aqt, original_weight_tensor)
150+
qtensor = input_quant_func(input_tensor, **weight_tensor.quant_kwargs)
151+
return func(bias, qtensor, original_weight_tensor)
152152
else:
153153
# aten.mm.default
154154
assert args[0].shape[-1] == args[1].shape[0], (
@@ -161,8 +161,8 @@ def _(func, types, args, kwargs):
161161
)
162162
input_quant_func = weight_tensor.input_quant_func
163163
original_weight_tensor = weight_tensor.original_weight_tensor
164-
aqt = input_quant_func(input_tensor)
165-
return func(aqt, original_weight_tensor)
164+
qtensor = input_quant_func(input_tensor, **weight_tensor.quant_kwargs)
165+
return func(qtensor, original_weight_tensor)
166166

167167

168168
@implements(aten.detach.default)
@@ -203,7 +203,9 @@ def _(func, types, args, kwargs):
203203
args,
204204
kwargs,
205205
LinearActivationQuantizedTensor(
206-
func(args[0].original_weight_tensor, *args[1:]), args[0].input_quant_func
206+
func(args[0].original_weight_tensor, *args[1:]),
207+
args[0].input_quant_func,
208+
args[0].quant_kwargs,
207209
),
208210
)
209211

@@ -216,7 +218,9 @@ def _(func, types, args, kwargs):
216218
args,
217219
kwargs,
218220
LinearActivationQuantizedTensor(
219-
func(args[0].original_weight_tensor, *args[1:]), args[0].input_quant_func
221+
func(args[0].original_weight_tensor, *args[1:]),
222+
args[0].input_quant_func,
223+
args[0].quant_kwargs,
220224
),
221225
)
222226

0 commit comments

Comments (0)