Commit ffd36e3

_layout -> layout for public API

1 parent 1b027ef

File tree: 12 files changed, +46 −46 lines
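The rename is mechanical at every call site: the keyword argument `_layout` becomes `layout` on the public quantization APIs. A minimal before/after sketch using APIs touched in this commit (the toy model is hypothetical; Marlin sparse layouts additionally require a CUDA fp16 model, per the README below):

```python
import torch
from torchao.quantization.quant_api import quantize_, int4_weight_only
from torchao.dtypes import MarlinSparseLayout

# Hypothetical toy model standing in for a real network with nn.Linear layers.
model = torch.nn.Sequential(torch.nn.Linear(256, 256)).cuda().half()

# Before this commit, the public API exposed a private-looking kwarg:
#   quantize_(model, int4_weight_only(_layout=MarlinSparseLayout()))
# After this commit, the same call uses `layout`:
quantize_(model, int4_weight_only(layout=MarlinSparseLayout()))
```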

test/dtypes/test_affine_quantized.py
Lines changed: 1 addition & 1 deletion

@@ -31,7 +31,7 @@ def get_quantization_functions(do_sparse: bool, do_int4: bool):
         base_functions.append(int4_weight_only(group_size=32))

     if do_sparse:
-        base_functions.append(int8_dynamic_activation_int8_weight(_layout=SemiSparseLayout()))
+        base_functions.append(int8_dynamic_activation_int8_weight(layout=SemiSparseLayout()))

     if is_cuda_8_9:
         base_functions.append(float8_weight_only())

test/integration/test_integration.py
Lines changed: 2 additions & 2 deletions

@@ -876,7 +876,7 @@ def test_int4_weight_only_quant_subclass_api_grouped(self, device, dtype):
         for test_shape in ([(256, 256, 16)] + ([(256, 256, 8)] if device=='cuda' else [])):
             for groupsize in [64, 32]:
                 for inner_k_tiles in [4, 2]:
-                    kwargs = {"groupsize": groupsize, "_layout": TensorCoreTiledLayout(inner_k_tiles=inner_k_tiles)}
+                    kwargs = {"groupsize": groupsize, "layout": TensorCoreTiledLayout(inner_k_tiles=inner_k_tiles)}

                     def api(mod):
                         kwargs_copy = kwargs.copy()
@@ -888,7 +888,7 @@ def api(mod):
                             unwrap_tensor_subclass(mod)
                         else:
                             kwargs_copy["inner_k_tiles"] = inner_k_tiles
-                            del kwargs_copy["_layout"]
+                            del kwargs_copy["layout"]
                             change_linear_weights_to_int4_woqtensors(mod, **kwargs_copy)

                     self._test_lin_weight_subclass_api_impl(

test/sparsity/test_marlin.py
Lines changed: 2 additions & 2 deletions

@@ -50,7 +50,7 @@ def test_quant_sparse_marlin_layout_eager(self):
         dense_result = model_copy(self.input.bfloat16()).half()

         # Sparse + quantized
-        quantize_(self.model, int4_weight_only(_layout=MarlinSparseLayout()))
+        quantize_(self.model, int4_weight_only(layout=MarlinSparseLayout()))
         sparse_result = self.model(self.input)

         assert torch.allclose(dense_result, sparse_result, atol=3e-1), "Results are not close"
@@ -67,7 +67,7 @@ def test_quant_sparse_marlin_layout_compile(self):
         dense_result = model_copy(self.input.bfloat16()).half()

         # Sparse + quantized
-        quantize_(self.model, int4_weight_only(_layout=MarlinSparseLayout()))
+        quantize_(self.model, int4_weight_only(layout=MarlinSparseLayout()))
         self.model.forward = torch.compile(self.model.forward, fullgraph=True)
         sparse_result = self.model(self.input)


test/sparsity/test_sparse_api.py
Lines changed: 3 additions & 3 deletions

@@ -74,7 +74,7 @@ def test_quant_semi_sparse(self, compile):

         quantize_(
             model,
-            int8_dynamic_activation_int8_weight(_layout=SemiSparseLayout()),
+            int8_dynamic_activation_int8_weight(layout=SemiSparseLayout()),
         )
         if compile:
             model = torch.compile(model)
@@ -108,7 +108,7 @@ def test_sparse_marlin(self, compile):
         dense_result = model_copy(input.bfloat16()).half()

         # Sparse + quantized
-        quantize_(model, int4_weight_only(_layout=MarlinSparseLayout()))
+        quantize_(model, int4_weight_only(layout=MarlinSparseLayout()))
         if compile:
             model = torch.compile(model)
         sparse_result = model(input)
@@ -190,7 +190,7 @@ def test_sparse(self, compile):
         quantize_(
             model,
             int8_dynamic_activation_int8_weight(
-                _layout=BlockSparseLayout(blocksize=64)
+                layout=BlockSparseLayout(blocksize=64)
             ),
         )
         if compile:

torchao/_models/llama/eval.py
Lines changed: 1 addition & 1 deletion

@@ -98,7 +98,7 @@ def run_evaluation(
             quantize_(model, uintx_weight_only(dtype, group_size, use_hqq=use_hqq))
         if "marlin" in quantization:
             from torchao.dtypes import MarlinSparseLayout
-            quantize_(model, int4_weight_only(_layout=MarlinSparseLayout()))
+            quantize_(model, int4_weight_only(layout=MarlinSparseLayout()))
         if "int4wo" in quantization and "gptq" in quantization:
             # avoid circular imports
             from torchao._models._eval import InputRecorder

torchao/_models/llama/generate.py
Lines changed: 1 addition & 1 deletion

@@ -231,7 +231,7 @@ def main(
             quantize_(model, int4_weight_only(group_size=groupsize))
         if "marlin" in quantization:
             from torchao.dtypes import MarlinSparseLayout
-            quantize_(model, int4_weight_only(_layout=MarlinSparseLayout()))
+            quantize_(model, int4_weight_only(layout=MarlinSparseLayout()))
         if "fp6" in quantization:
             quantize_(model, fpx_weight_only(3, 2))
         if quantization.startswith("awq"):

torchao/_models/sam/eval_combo.py
Lines changed: 2 additions & 2 deletions

@@ -315,7 +315,7 @@ def mlp_only(mod, name):
                   int8_dynamic_activation_int8_weight(),
                   attn_only)
         quantize_(predictor.model.image_encoder,
-                  int8_dynamic_activation_int8_weight(_layout=SemiSparseLayout()),
+                  int8_dynamic_activation_int8_weight(layout=SemiSparseLayout()),
                   mlp_lin1_only)
         sparsify_(predictor.model.image_encoder,
                   semi_sparse_weight(),
@@ -330,7 +330,7 @@ def mlp_only(mod, name):
         quantize_(predictor.model.image_encoder,
                   int8_dynamic_activation_int8_weight(),
                   attn_only)
-        quantize_(predictor.model.image_encoder, int4_weight_only(_layout=MarlinSparseLayout()), mlp_lin1_only)
+        quantize_(predictor.model.image_encoder, int4_weight_only(layout=MarlinSparseLayout()), mlp_lin1_only)
         sparsify_(predictor.model.image_encoder,
                   semi_sparse_weight(),
                   mlp_lin2_only)

torchao/quantization/quant_api.py
Lines changed: 9 additions & 9 deletions

@@ -511,7 +511,7 @@ def int8_dynamic_activation_int4_weight(group_size=32, mapping_type=MappingType.
     return _get_linear_subclass_inserter(apply_int8_dynamic_activation_int4_weight_quant, group_size=group_size, mapping_type=mapping_type)


-def int4_weight_only(group_size=128, _layout=TensorCoreTiledLayout(inner_k_tiles=8), use_hqq=False):
+def int4_weight_only(group_size=128, layout=TensorCoreTiledLayout(inner_k_tiles=8), use_hqq=False):
     """
     Applies uint4 weight-only asymmetric per-group quantization to linear layers, using
     "tensor_core_tiled" layout for speedup with tinygemm kernel
@@ -527,7 +527,7 @@ def int4_weight_only(group_size=128, _layout=TensorCoreTiledLayout(inner_k_tiles
     Args:
         `group_size`: parameter for quantization, controls the granularity of quantization, smaller
             size is more fine grained, choices are [256, 128, 64, 32]
-        `_layout`: layout type for quantized tensor, default is `TensorCoreTiledLayout(inner_k_tiles=8)`
+        `layout`: layout type for quantized tensor, default is `TensorCoreTiledLayout(inner_k_tiles=8)`
         `use_hqq`: whether to use hqq or default quantization mode, default is False
     """
     def apply_int4_weight_only_quant(weight):
@@ -550,12 +550,12 @@ def apply_int4_weight_only_quant(weight):
         # Sparse Marlin only supports symmetric quantization.
         # NOTE: If we start having lots of layouts that require different configurations,
         # we should consider moving this logic somewhere else.
-        if isinstance(_layout, MarlinSparseLayout):
+        if isinstance(layout, MarlinSparseLayout):
             mapping_type = MappingType.SYMMETRIC
             preserve_zero = True
             zero_point_domain = ZeroPointDomain.INT

-        return to_affine_quantized_intx(weight, mapping_type, block_size, target_dtype, quant_min, quant_max, eps, zero_point_dtype=zero_point_dtype, preserve_zero=preserve_zero, zero_point_domain=zero_point_domain, _layout=_layout, use_hqq=use_hqq)
+        return to_affine_quantized_intx(weight, mapping_type, block_size, target_dtype, quant_min, quant_max, eps, zero_point_dtype=zero_point_dtype, preserve_zero=preserve_zero, zero_point_domain=zero_point_domain, _layout=layout, use_hqq=use_hqq)

     return _get_linear_subclass_inserter(apply_int4_weight_only_quant)

@@ -583,7 +583,7 @@ def _int8_symm_per_token_reduced_range_quant(x: torch.Tensor) -> torch.Tensor:
     return to_affine_quantized_intx(x, mapping_type, _get_per_token_block_size(x), target_dtype, eps=eps, quant_min=quant_min, quant_max=quant_max, scale_dtype=torch.float32 if x.dtype == torch.float16 else None)


-def int8_dynamic_activation_int8_weight(_layout=PlainLayout()):
+def int8_dynamic_activation_int8_weight(layout=PlainLayout()):
     """
     Applies int8 dynamic symmetric per-token activation and int8 per-channel weight
     quantization to linear layers
@@ -609,7 +609,7 @@ def get_weight_block_size(x):
         input_quant_func = _int8_symm_per_token_reduced_range_quant

         block_size = get_weight_block_size(weight)
-        weight = to_affine_quantized_intx(weight, mapping_type, block_size, target_dtype, eps=eps, zero_point_dtype=zero_point_dtype, _layout=_layout)
+        weight = to_affine_quantized_intx(weight, mapping_type, block_size, target_dtype, eps=eps, zero_point_dtype=zero_point_dtype, _layout=layout)
         weight = to_linear_activation_quantized(weight, input_quant_func)
         return weight

@@ -621,12 +621,12 @@ def int8_dynamic_activation_int8_semi_sparse_weight():
     Applies int8 dnynamic symmetric per-token activation and int8 per-channel weight
     quantization + 2:4 sparsity to linear layers.
     """
-    warnings.warn("""int8_dyanmic_activation_int8_semi_sparse_weight() will be deprecated at a later release. Please use the _layout kwarg in int8_dynamic_activation_int8_weight instead.
+    warnings.warn("""int8_dyanmic_activation_int8_semi_sparse_weight() will be deprecated at a later release. Please use the layout kwarg in int8_dynamic_activation_int8_weight instead.

     from torchao.dtypes import SemiSparseLayout
-    int8_dynamic_activation_int8_weight(_layout=SemiSparseLayout()""")
+    int8_dynamic_activation_int8_weight(layout=SemiSparseLayout()""")

-    return int8_dynamic_activation_int8_weight(_layout=SemiSparseLayout())
+    return int8_dynamic_activation_int8_weight(layout=SemiSparseLayout())


 def float8_weight_only(weight_dtype: torch.dtype = torch.float8_e4m3fn):
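Worth noting in these hunks: only the public keyword changes. The internal `to_affine_quantized_intx` constructor still takes `_layout`, so the public wrappers now forward `_layout=layout`. For callers of the deprecated `int8_dynamic_activation_int8_semi_sparse_weight()`, the migration the warning suggests looks like this (a minimal sketch; the toy model is hypothetical, and `SemiSparseLayout` requires a CUDA model):

```python
import torch
from torchao.quantization.quant_api import quantize_, int8_dynamic_activation_int8_weight
from torchao.dtypes import SemiSparseLayout

# Hypothetical toy model standing in for a real network.
model = torch.nn.Sequential(torch.nn.Linear(128, 128)).cuda()

# Deprecated: int8_dynamic_activation_int8_semi_sparse_weight()
# Replacement after this commit:
quantize_(model, int8_dynamic_activation_int8_weight(layout=SemiSparseLayout()))
```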

torchao/sparsity/README.md
Lines changed: 2 additions & 2 deletions

@@ -57,7 +57,7 @@ from torchao.dtypes import MarlinSparseLayout

 # Your FP16 model
 model = model.cuda().half()
-quantize_(model, int4_weight_only(_layout=MarlinSparseLayout()))
+quantize_(model, int4_weight_only(layout=MarlinSparseLayout()))
 ```

 Note the existing API results in an extremely high accuracy degredation and is intended to be used in concert with an already sparsified+finetuned checkpoint where possible until we develop
@@ -72,7 +72,7 @@ from torchao.quantization.quant_api import quantize_, int8_dynamic_activation_in
 from torchao.dtypes import SemiSparseLayout

 model = model.cuda()
-quantize_(model, int8_dynamic_activation_int8_weight(_layout=SemiSparseLayout()))
+quantize_(model, int8_dynamic_activation_int8_weight(layout=SemiSparseLayout()))
 ```

 ### 2:4 sparsity

torchao/sparsity/prototype/superblock/utils.py
Lines changed: 1 addition & 1 deletion

@@ -164,7 +164,7 @@ def accelerate_with_sparsity(model, args):

         quantize_(
             model,
-            int8_dynamic_activation_int8_weight(_layout=SemiSparseLayout()),
+            int8_dynamic_activation_int8_weight(layout=SemiSparseLayout()),
             mlp_0_only,
         )
         sparsify_(model, semi_sparse_weight(), mlp_3_only)
