
Commit 2fe0ca0

Split implements and implements_torch_function (#2866)
* Split implements and implements_torch_function:
  1) Added two registries, _ATEN_OP_TABLE and _TORCH_FN_TABLE, instead of one.
  2) Split the decorator into two.
* Updates: updated the places where the decorators are called; added a test for the case where both decorators wrap the same function.
* Updating test to check the user code
* test fixes
* Fix for tests
* Delete int8_dynamic_activation_lut_tensor.py
* Updates: some more files were added with the decorator.
* updated test for the condition
* Changes
* updates
1 parent d446acd commit 2fe0ca0

29 files changed: +184 −77 lines
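
The torchao/utils.py change that actually performs the split is not included in the excerpt below. Based only on the commit message, a minimal sketch of the mechanism might look like the following; the registry names _ATEN_OP_TABLE and _TORCH_FN_TABLE come from the message, while the class name TorchAOBaseTensorSketch and all of the dispatch wiring are assumptions for illustration, not the library's actual code.

    import torch

    # Per the commit message: two registries instead of one, keyed by subclass.
    _ATEN_OP_TABLE = {}   # cls -> {aten overload -> handler}
    _TORCH_FN_TABLE = {}  # cls -> {torch function -> handler}

    def _register(table, cls, funcs, impl):
        # Accept a single op/function or a list of them, as the diffs below do.
        if not isinstance(funcs, (list, tuple)):
            funcs = [funcs]
        for f in funcs:
            table.setdefault(cls, {})[f] = impl

    class TorchAOBaseTensorSketch(torch.Tensor):
        @classmethod
        def implements(cls, aten_ops):
            """Register a handler for ATen ops (reached via __torch_dispatch__)."""
            def decorator(impl):
                _register(_ATEN_OP_TABLE, cls, aten_ops, impl)
                return impl  # handler returned unchanged, so decorators stack
            return decorator

        @classmethod
        def implements_torch_function(cls, torch_fns):
            """Register a handler for torch functions (reached via __torch_function__)."""
            def decorator(impl):
                _register(_TORCH_FN_TABLE, cls, torch_fns, impl)
                return impl
            return decorator

        @classmethod
        def __torch_function__(cls, func, types, args=(), kwargs=None):
            kwargs = kwargs or {}
            if func in _TORCH_FN_TABLE.get(cls, {}):
                return _TORCH_FN_TABLE[cls][func](func, types, args, kwargs)
            with torch._C.DisableTorchFunctionSubclass():
                return func(*args, **kwargs)

        @classmethod
        def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
            kwargs = kwargs or {}
            if func in _ATEN_OP_TABLE.get(cls, {}):
                return _ATEN_OP_TABLE[cls][func](func, types, args, kwargs)
            raise NotImplementedError(f"{cls.__name__}: no registered handler for {func}")

The new test in test/test_utils.py below exercises exactly these two paths: F.linear hits the torch-function table, torch.ops.aten.t.default hits the ATen table.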

test/test_utils.py

Lines changed: 48 additions & 0 deletions
@@ -8,6 +8,7 @@
 from unittest.mock import patch

 import torch
+import torch.nn.functional as F

 from torchao.testing.utils import skip_if_no_cuda
 from torchao.utils import TorchAOBaseTensor, torch_version_at_least
@@ -344,6 +345,53 @@ def __init__(
         )
         self._test_default_impls_helper(lp_tensor, lp_tensor_for_copy)

+    def test_implements_and_torch_function_together(self):
+        """Ensure a function decorated with both @_implements and @_implements_torch_function works."""
+        counter = {"calls": 0}
+
+        class MyTensor(TorchAOBaseTensor):
+            tensor_data_names = ["qdata"]
+            tensor_attribute_names = ["attr", "device"]
+
+            def __new__(cls, qdata: torch.Tensor, attr: str = "attr", device=None):
+                kwargs = {}
+                if device is None:
+                    device = qdata.device
+                kwargs["device"] = device
+                kwargs["dtype"] = qdata.dtype
+                r = torch.Tensor._make_wrapper_subclass(cls, qdata.shape, **kwargs)
+                r.qdata = qdata
+                r.attr = attr
+                return r
+
+            def __init__(self, qdata: torch.Tensor, attr: str = "attr", device=None):
+                pass
+
+        implements = MyTensor.implements
+        implements_torch_function = MyTensor.implements_torch_function
+
+        @implements([torch.ops.aten.t.default])
+        @implements_torch_function([F.linear])
+        def fake_linear(func, types, args, kwargs):
+            counter["calls"] += 1
+
+        l = torch.nn.Linear(2, 3)
+        l.weight = torch.nn.Parameter(MyTensor(l.weight.detach(), "attr", None))
+        x = torch.randn(4, 2)
+
+        # Torch function path
+        F.linear(x, l.weight, l.bias)
+        self.assertEqual(
+            counter["calls"], 1, "Expected fake_linear to be called via F.linear"
+        )
+
+        # ATen path
+        mt = MyTensor(torch.randn(3, 4))
+        torch.ops.aten.t.default(mt)
+        self.assertEqual(
+            counter["calls"], 2, "Expected fake_linear to be called via aten.t.default"
+        )
+

 if __name__ == "__main__":
     unittest.main()

torchao/dtypes/affine_quantized_tensor_ops.py

Lines changed: 4 additions & 2 deletions
@@ -262,9 +262,11 @@ def _register_aqt_quantized_linear_dispatches():
 _register_aqt_quantized_linear_dispatches()

 implements = AffineQuantizedTensor.implements
+implements_torch_function = AffineQuantizedTensor.implements_torch_function


-@implements([torch.nn.functional.linear, aten.linear.default])
+@implements([aten.linear.default])
+@implements_torch_function([torch.nn.functional.linear])
 def _(func, types, args, kwargs):
     input_tensor, weight_tensor, bias = (
         args[0],
@@ -296,7 +298,7 @@ def _(func, types, args, kwargs):
     return torch.nn.functional.linear(input_tensor, weight_tensor, bias)


-@implements(torch.nn.functional.embedding)
+@implements_torch_function(torch.nn.functional.embedding)
 def _(func, types, args, kwargs):
     if _embedding_q_dq_check(args, kwargs):
         return _embedding_q_dq_impl(args, kwargs)

torchao/prototype/quantization/autoquant_v2.py

Lines changed: 2 additions & 1 deletion
@@ -847,7 +847,8 @@ def from_float(cls, weight):
         return cls(weight)


-@Float32Tensor.implements([torch.nn.functional.linear, aten.linear.default])
+@Float32Tensor.implements(aten.linear.default)
+@Float32Tensor.implements_torch_function(torch.nn.functional.linear)
 def _(func, types, args, kwargs):
     input_tensor, weight_tensor, bias = (
         args[0],

torchao/prototype/quantization/codebook_coreml/codebook_quantized_tensor.py

Lines changed: 5 additions & 2 deletions
@@ -164,9 +164,11 @@ def to(self, *args, **kwargs):


 implements = CodebookQuantizedTensor.implements
+implements_torch_function = CodebookQuantizedTensor.implements_torch_function


-@implements([torch.nn.functional.linear, aten.linear.default])
+@implements([aten.linear.default])
+@implements_torch_function([torch.nn.functional.linear])
 def _(func, types, args, kwargs):
     input_tensor, weight_tensor, bias = (
         args[0],
@@ -177,7 +179,8 @@ def _(func, types, args, kwargs):
     return func(input_tensor, weight_tensor, bias)


-@implements([torch.nn.functional.embedding, aten.embedding.default])
+@implements([aten.embedding.default])
+@implements_torch_function([torch.nn.functional.embedding])
 def _(func, types, args, kwargs):
     assert len(args) == 2
     indices, weight_tensor = (

torchao/prototype/quantization/codebook_groupwise/codebook_quantized_tensor.py

Lines changed: 2 additions & 1 deletion
@@ -167,9 +167,10 @@ def from_codebook_quantized_tensor(


 implements = CodebookQuantizedPackedTensor.implements
+implements_torch_function = CodebookQuantizedPackedTensor.implements_torch_function


-@implements([F.linear])
+@implements_torch_function(F.linear)
 def _(func, types, args, kwargs):
     """
     Override for `torch.nn.functional.linear` specifically for the

torchao/prototype/quantization/gguf/gguf_quantized_tensor.py

Lines changed: 3 additions & 1 deletion
@@ -218,6 +218,7 @@ def from_float(cls, input_float, n_blocks_per_superblock, target_dtype):


 implements = GGUFQuantizedTensor.implements
+implements_torch_function = GGUFQuantizedTensor.implements_torch_function


 @implements([aten.detach.default, aten.alias.default])
@@ -244,7 +245,8 @@ def _(func, types, args, kwargs):
     )


-@implements([torch.nn.functional.linear, aten.linear.default])
+@implements(aten.linear.default)
+@implements_torch_function(torch.nn.functional.linear)
 def _(func, types, args, kwargs):
     input_tensor, weight_tensor, bias = (
         args[0],

torchao/prototype/quantization/int8_lut_tensor/int8_lut_tensor.py

Lines changed: 3 additions & 1 deletion
@@ -165,6 +165,7 @@ def from_intx_unpacked_to_int8_tensor(


 implements = Int8LutTensor.implements
+implements_torch_function = Int8LutTensor.implements_torch_function


 def _linear_impl_2d(
@@ -202,7 +203,8 @@ def _linear_impl_2d(
     return res


-@implements([torch.nn.functional.linear, aten.linear.default])
+@implements(aten.linear.default)
+@implements_torch_function(torch.nn.functional.linear)
 def _(func, types, args, kwargs):
     input_tensor, weight_tensor, bias = (
         args[0],

torchao/prototype/quantized_training/bitnet.py

Lines changed: 2 additions & 2 deletions
@@ -140,7 +140,7 @@ def fsdp_post_all_gather(
         return BitNetPacked2bitLinearWeight(data_i2, scale), all_gather_outputs


-@BitNetTrainingLinearWeight.implements(F.linear)
+@BitNetTrainingLinearWeight.implements_torch_function(F.linear)
 def _(func, types, args, kwargs):
     if torch.is_autocast_enabled("cuda"):
         dtype = torch.get_autocast_gpu_dtype()
@@ -324,7 +324,7 @@ def dequantize(self, out_dtype=None):
         return out


-@BitNetPacked2bitLinearWeight.implements(F.linear)
+@BitNetPacked2bitLinearWeight.implements_torch_function(F.linear)
 def _(func, types, args, kwargs):
     return _BitNetPacked2bitLinear.apply(*args, **kwargs)


torchao/prototype/quantized_training/int8.py

Lines changed: 2 additions & 1 deletion
@@ -174,9 +174,10 @@ def backward(ctx, grad_output):


 implements = Int8QuantizedTrainingLinearWeight.implements
+implements_torch_function = Int8QuantizedTrainingLinearWeight.implements_torch_function


-@implements(torch.nn.functional.linear)
+@implements_torch_function(torch.nn.functional.linear)
 def _(func, types, args, kwargs):
     return _Int8WeightOnlyLinear.apply(*args, **kwargs)


torchao/quantization/autoquant.py

Lines changed: 2 additions & 1 deletion
@@ -833,7 +833,8 @@ def from_float(cls, weight):
         return cls(weight)


-@Float32Tensor.implements([torch.nn.functional.linear, aten.linear.default])
+@Float32Tensor.implements_torch_function(torch.nn.functional.linear)
+@Float32Tensor.implements(aten.linear.default)
 def _(func, types, args, kwargs):
     input_tensor, weight_tensor, bias = (
         args[0],
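
One detail worth noting across these files: autoquant_v2.py stacks @implements above @implements_torch_function, while autoquant.py here uses the opposite order. Assuming each decorator simply records the handler in its own table and returns it unchanged (as in the sketch near the top of this page), the stacking order should be immaterial. A quick illustration, reusing the hypothetical TorchAOBaseTensorSketch, _ATEN_OP_TABLE, and _TORCH_FN_TABLE from that sketch (none of this is part of the commit itself):

    import torch
    import torch.nn.functional as F

    class T(TorchAOBaseTensorSketch):
        pass

    # Inner decorator runs first, but since each decorator returns the
    # handler unchanged, either stacking order yields the same two entries.
    @T.implements(torch.ops.aten.linear.default)
    @T.implements_torch_function(F.linear)
    def _(func, types, args, kwargs):
        raise NotImplementedError

    assert torch.ops.aten.linear.default in _ATEN_OP_TABLE[T]
    assert F.linear in _TORCH_FN_TABLE[T]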
