Skip to content

Commit a4ceeb1

Browse files
fix: moved uint2 to prototype folder
1 parent 783c364 commit a4ceeb1

File tree

3 files changed

+7
-23
lines changed

3 files changed

+7
-23
lines changed

test/dtypes/test_uint2.py

Lines changed: 3 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,9 @@
11
from unittest import main
2-
32
import torch
43
import torch.nn as nn
5-
6-
from torch.testing._internal.common_quantization import (
7-
QuantizationTestCase,
8-
)
9-
10-
from torchao.dtypes.uint2 import (
11-
BitnetTensor
12-
)
13-
from torchao.quantization.quant_api import (
14-
_replace_with_custom_fn_if_matches_filter,
15-
)
4+
from torch.testing._internal.common_quantization import QuantizationTestCase
5+
from torchao.prototype.dtypes.uint2 import BitnetTensor
6+
from torchao.quantization.quant_api import _replace_with_custom_fn_if_matches_filter
167

178
def _apply_weight_only_uint2_quant(model):
189
def fn(mod):
@@ -25,7 +16,6 @@ def fn(mod):
2516
lambda mod, fqn: isinstance(mod, torch.nn.Linear),
2617
)
2718

28-
2919
class TestUInt2(QuantizationTestCase):
3020
def test_gpu_quant(self):
3121
device = 'cuda' if torch.cuda.is_available() else 'cpu'
@@ -35,11 +25,6 @@ def test_gpu_quant(self):
3525
y_ref = m(x)
3626
_apply_weight_only_uint2_quant(m)
3727
y_wo = m(x)
38-
# sqnr = compute_error(y_ref, y_wo)
39-
# opt = torch.compile(m, fullgraph=True, mode="max-autotune")
40-
# make sure it runs
41-
# opt(x)
42-
4328

4429
if __name__ == "__main__":
4530
main()

torchao/dtypes/__init__.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,12 @@
11
from .nf4tensor import NF4Tensor, to_nf4
2-
from .uint2 import UInt2Tensor, BitnetTensor
2+
# from ..prototype.dtypes.uint2 import UInt2Tensor, BitnetTensor
33
from .uint4 import UInt4Tensor
44
from .aqt import AffineQuantizedTensor, to_aq
55
from .float6_e3m2 import to_float6_e3m2, from_float6_e3m2
66

77
__all__ = [
88
"NF4Tensor",
99
"to_nf4",
10-
"UInt2Tensor",
11-
"BitnetTensor",
1210
"UInt4Tensor"
1311
"AffineQuantizedTensor",
1412
"to_aq",

torchao/dtypes/uint2.py renamed to torchao/prototype/dtypes/uint2.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
import torch._prims_common as utils
33
import torch.utils._pytree as pytree
44
from torch.library import impl, Library
5-
from .uint4 import qtensor_lib
5+
from ...dtypes.uint4 import qtensor_lib
66

77

88
def down_size(size):
@@ -94,6 +94,7 @@ def pack_uint2(uint8_data: torch.Tensor) -> torch.Tensor:
9494
return output
9595

9696
else:
97+
# TODO: re-enable @torch.compile below once https://github.com/pytorch/pytorch/issues/127374 is fixed
9798
#@torch.compile
9899
def unpack_uint8_to_trinary2(uint8_data: torch.Tensor) -> torch.Tensor:
99100
# since we are using uint8 we will decode 4 entries per byte
@@ -150,7 +151,7 @@ def fill_defaults(args, n, defaults_tail):
150151
return r
151152

152153

153-
#qtensor_lib = Library("qtensors", "DEF")
154+
# NOTE: qtensor_lib is imported from ...dtypes.uint4 instead of being created here via Library("qtensors", "DEF")
154155
qtensor_lib.define(
155156
"quantize_per_tensor_uint2(Tensor input, float scale, int zero_point) -> Tensor"
156157
)

0 commit comments

Comments
 (0)