|
8 | 8 | from unittest.mock import patch
|
9 | 9 |
|
10 | 10 | import torch
|
| 11 | +import torch.nn.functional as F |
11 | 12 |
|
12 | 13 | from torchao.testing.utils import skip_if_no_cuda
|
13 | 14 | from torchao.utils import TorchAOBaseTensor, torch_version_at_least
|
@@ -344,6 +345,53 @@ def __init__(
|
344 | 345 | )
|
345 | 346 | self._test_default_impls_helper(lp_tensor, lp_tensor_for_copy)
|
346 | 347 |
|
| 348 | + def test_implements_and_torch_function_together(self): |
| 349 | + """Ensure a function decorated with both @_implements and @_implements_torch_function works.""" |
| 350 | + counter = {"calls": 0} |
| 351 | + |
| 352 | + class MyTensor(TorchAOBaseTensor): |
| 353 | + tensor_data_names = ["qdata"] |
| 354 | + tensor_attribute_names = ["attr", "device"] |
| 355 | + |
| 356 | + def __new__(cls, qdata: torch.Tensor, attr: str = "attr", device=None): |
| 357 | + kwargs = {} |
| 358 | + if device is None: |
| 359 | + device = qdata.device |
| 360 | + kwargs["device"] = device |
| 361 | + kwargs["dtype"] = qdata.dtype |
| 362 | + r = torch.Tensor._make_wrapper_subclass(cls, qdata.shape, **kwargs) |
| 363 | + r.qdata = qdata |
| 364 | + r.attr = attr |
| 365 | + return r |
| 366 | + |
| 367 | + def __init__(self, qdata: torch.Tensor, attr: str = "attr", device=None): |
| 368 | + pass |
| 369 | + |
| 370 | + implements = MyTensor.implements |
| 371 | + implements_torch_function = MyTensor.implements_torch_function |
| 372 | + |
| 373 | + @implements([torch.ops.aten.t.default]) |
| 374 | + @implements_torch_function([F.linear]) |
| 375 | + def fake_linear(func, types, args, kwargs): |
| 376 | + counter["calls"] += 1 |
| 377 | + |
| 378 | + l = torch.nn.Linear(2, 3) |
| 379 | + l.weight = torch.nn.Parameter(MyTensor(l.weight.detach(), "attr", None)) |
| 380 | + x = torch.randn(4, 2) |
| 381 | + |
| 382 | + # Torch function path |
| 383 | + F.linear(x, l.weight, l.bias) |
| 384 | + self.assertEqual( |
| 385 | + counter["calls"], 1, "Expected fake_linear to be called via F.linear" |
| 386 | + ) |
| 387 | + |
| 388 | + # ATen path |
| 389 | + mt = MyTensor(torch.randn(3, 4)) |
| 390 | + torch.ops.aten.t.default(mt) |
| 391 | + self.assertEqual( |
| 392 | + counter["calls"], 2, "Expected fake_linear to be called via aten.t.default" |
| 393 | + ) |
| 394 | + |
347 | 395 |
|
348 | 396 | if __name__ == "__main__":
|
349 | 397 | unittest.main()
|
0 commit comments