|
5 | 5 | import torch |
6 | 6 | from torch import nn |
7 | 7 | from torch.testing._internal import common_utils |
8 | | -from torchao.dtypes import MarlinSparseLayoutType, SemiSparseLayoutType |
| 8 | +from torchao.dtypes import MarlinSparseLayout, SemiSparseLayout |
9 | 9 | from torchao.quantization.quant_api import ( |
10 | 10 | int4_weight_only, |
11 | 11 | int8_dynamic_activation_int8_weight, |
@@ -74,7 +74,7 @@ def test_quant_semi_sparse(self, compile): |
74 | 74 |
|
75 | 75 | quantize_( |
76 | 76 | model, |
77 | | - int8_dynamic_activation_int8_weight(layout_type=SemiSparseLayoutType()), |
| 77 | + int8_dynamic_activation_int8_weight(layout=SemiSparseLayout()), |
78 | 78 | ) |
79 | 79 | if compile: |
80 | 80 | model = torch.compile(model) |
@@ -108,7 +108,7 @@ def test_sparse_marlin(self, compile): |
108 | 108 | dense_result = model_copy(input.bfloat16()).half() |
109 | 109 |
|
110 | 110 | # Sparse + quantized |
111 | | - quantize_(model, int4_weight_only(layout_type=MarlinSparseLayoutType())) |
| 111 | + quantize_(model, int4_weight_only(layout=MarlinSparseLayout())) |
112 | 112 | if compile: |
113 | 113 | model = torch.compile(model) |
114 | 114 | sparse_result = model(input) |
@@ -185,12 +185,12 @@ def test_sparse(self, compile): |
185 | 185 | quantize_(model_copy, int8_dynamic_activation_int8_weight()) |
186 | 186 | reference = model_copy(input) |
187 | 187 |
|
188 | | - from torchao.dtypes.affine_quantized_tensor import BlockSparseLayoutType |
| 188 | + from torchao.dtypes.affine_quantized_tensor import BlockSparseLayout |
189 | 189 |
|
190 | 190 | quantize_( |
191 | 191 | model, |
192 | 192 | int8_dynamic_activation_int8_weight( |
193 | | - layout_type=BlockSparseLayoutType(blocksize=64) |
| 193 | + layout=BlockSparseLayout(blocksize=64) |
194 | 194 | ), |
195 | 195 | ) |
196 | 196 | if compile: |
|
0 commit comments