
Commit ddf1143

_layout -> layout for public API
1 parent 1b027ef commit ddf1143
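
This commit renames the `_layout` keyword argument to `layout` across the public quantization API, dropping the underscore prefix that conventionally marks a parameter as private. A minimal before/after sketch of a call site, assuming a toy FP16 CUDA model (the module itself is illustrative, not taken from the diff):

    import torch
    from torch import nn
    from torchao.quantization.quant_api import quantize_, int4_weight_only
    from torchao.dtypes import MarlinSparseLayout

    # Toy stand-in for a real checkpoint; as the README diff below notes,
    # real use assumes an already 2:4-sparsified and finetuned model.
    model = nn.Sequential(nn.Linear(1024, 1024)).cuda().half()

    # Before this commit:
    #   quantize_(model, int4_weight_only(_layout=MarlinSparseLayout()))
    # After this commit:
    quantize_(model, int4_weight_only(layout=MarlinSparseLayout()))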

11 files changed: +44 −44 lines changed


test/dtypes/test_affine_quantized.py

Lines changed: 1 addition & 1 deletion

@@ -31,7 +31,7 @@ def get_quantization_functions(do_sparse: bool, do_int4: bool):
         base_functions.append(int4_weight_only(group_size=32))
 
     if do_sparse:
-        base_functions.append(int8_dynamic_activation_int8_weight(_layout=SemiSparseLayout()))
+        base_functions.append(int8_dynamic_activation_int8_weight(layout=SemiSparseLayout()))
 
     if is_cuda_8_9:
         base_functions.append(float8_weight_only())

test/sparsity/test_marlin.py

Lines changed: 2 additions & 2 deletions

@@ -50,7 +50,7 @@ def test_quant_sparse_marlin_layout_eager(self):
         dense_result = model_copy(self.input.bfloat16()).half()
 
         # Sparse + quantized
-        quantize_(self.model, int4_weight_only(_layout=MarlinSparseLayout()))
+        quantize_(self.model, int4_weight_only(layout=MarlinSparseLayout()))
         sparse_result = self.model(self.input)
 
         assert torch.allclose(dense_result, sparse_result, atol=3e-1), "Results are not close"
@@ -67,7 +67,7 @@ def test_quant_sparse_marlin_layout_compile(self):
         dense_result = model_copy(self.input.bfloat16()).half()
 
         # Sparse + quantized
-        quantize_(self.model, int4_weight_only(_layout=MarlinSparseLayout()))
+        quantize_(self.model, int4_weight_only(layout=MarlinSparseLayout()))
         self.model.forward = torch.compile(self.model.forward, fullgraph=True)
         sparse_result = self.model(self.input)

test/sparsity/test_sparse_api.py

Lines changed: 3 additions & 3 deletions

@@ -74,7 +74,7 @@ def test_quant_semi_sparse(self, compile):
 
         quantize_(
             model,
-            int8_dynamic_activation_int8_weight(_layout=SemiSparseLayout()),
+            int8_dynamic_activation_int8_weight(layout=SemiSparseLayout()),
         )
         if compile:
             model = torch.compile(model)
@@ -108,7 +108,7 @@ def test_sparse_marlin(self, compile):
         dense_result = model_copy(input.bfloat16()).half()
 
         # Sparse + quantized
-        quantize_(model, int4_weight_only(_layout=MarlinSparseLayout()))
+        quantize_(model, int4_weight_only(layout=MarlinSparseLayout()))
         if compile:
             model = torch.compile(model)
         sparse_result = model(input)
@@ -190,7 +190,7 @@ def test_sparse(self, compile):
         quantize_(
             model,
             int8_dynamic_activation_int8_weight(
-                _layout=BlockSparseLayout(blocksize=64)
+                layout=BlockSparseLayout(blocksize=64)
             ),
         )
         if compile:

torchao/_models/llama/eval.py

Lines changed: 1 addition & 1 deletion

@@ -98,7 +98,7 @@ def run_evaluation(
         quantize_(model, uintx_weight_only(dtype, group_size, use_hqq=use_hqq))
     if "marlin" in quantization:
         from torchao.dtypes import MarlinSparseLayout
-        quantize_(model, int4_weight_only(_layout=MarlinSparseLayout()))
+        quantize_(model, int4_weight_only(layout=MarlinSparseLayout()))
     if "int4wo" in quantization and "gptq" in quantization:
         # avoid circular imports
         from torchao._models._eval import InputRecorder

torchao/_models/llama/generate.py

Lines changed: 1 addition & 1 deletion

@@ -231,7 +231,7 @@ def main(
         quantize_(model, int4_weight_only(group_size=groupsize))
     if "marlin" in quantization:
         from torchao.dtypes import MarlinSparseLayout
-        quantize_(model, int4_weight_only(_layout=MarlinSparseLayout()))
+        quantize_(model, int4_weight_only(layout=MarlinSparseLayout()))
     if "fp6" in quantization:
         quantize_(model, fpx_weight_only(3, 2))
     if quantization.startswith("awq"):

torchao/_models/sam/eval_combo.py

Lines changed: 2 additions & 2 deletions

@@ -315,7 +315,7 @@ def mlp_only(mod, name):
                   int8_dynamic_activation_int8_weight(),
                   attn_only)
         quantize_(predictor.model.image_encoder,
-                  int8_dynamic_activation_int8_weight(_layout=SemiSparseLayout()),
+                  int8_dynamic_activation_int8_weight(layout=SemiSparseLayout()),
                   mlp_lin1_only)
         sparsify_(predictor.model.image_encoder,
                   semi_sparse_weight(),
@@ -330,7 +330,7 @@ def mlp_only(mod, name):
         quantize_(predictor.model.image_encoder,
                   int8_dynamic_activation_int8_weight(),
                   attn_only)
-        quantize_(predictor.model.image_encoder, int4_weight_only(_layout=MarlinSparseLayout()), mlp_lin1_only)
+        quantize_(predictor.model.image_encoder, int4_weight_only(layout=MarlinSparseLayout()), mlp_lin1_only)
         sparsify_(predictor.model.image_encoder,
                   semi_sparse_weight(),
                   mlp_lin2_only)
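
The SAM benchmark above routes different recipes to different submodules by passing a filter function as the third argument of `quantize_`. A minimal sketch of that pattern with a hypothetical predicate in place of `mlp_lin1_only`; the `(module, name)` signature follows the `filter_fn` documented in torchao/sparsity/sparse_api.py below:

    import torch
    from torch import nn
    from torchao.quantization.quant_api import quantize_, int8_dynamic_activation_int8_weight
    from torchao.dtypes import SemiSparseLayout

    model = nn.Sequential(nn.Linear(128, 128), nn.ReLU(), nn.Linear(128, 128)).cuda()

    # Hypothetical predicate: quantize only the first linear layer.
    # Real use assumes weights already pruned to a 2:4 pattern.
    def first_linear_only(module: nn.Module, name: str) -> bool:
        return isinstance(module, nn.Linear) and name == "0"

    quantize_(model, int8_dynamic_activation_int8_weight(layout=SemiSparseLayout()), first_linear_only)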

torchao/quantization/quant_api.py

Lines changed: 9 additions & 9 deletions

@@ -511,7 +511,7 @@ def int8_dynamic_activation_int4_weight(group_size=32, mapping_type=MappingType.
     return _get_linear_subclass_inserter(apply_int8_dynamic_activation_int4_weight_quant, group_size=group_size, mapping_type=mapping_type)
 
 
-def int4_weight_only(group_size=128, _layout=TensorCoreTiledLayout(inner_k_tiles=8), use_hqq=False):
+def int4_weight_only(group_size=128, layout=TensorCoreTiledLayout(inner_k_tiles=8), use_hqq=False):
     """
     Applies uint4 weight-only asymmetric per-group quantization to linear layers, using
     "tensor_core_tiled" layout for speedup with tinygemm kernel
@@ -527,7 +527,7 @@ def int4_weight_only(group_size=128, _layout=TensorCoreTiledLayout(inner_k_tiles
     Args:
         `group_size`: parameter for quantization, controls the granularity of quantization, smaller
             size is more fine grained, choices are [256, 128, 64, 32]
-        `_layout`: layout type for quantized tensor, default is `TensorCoreTiledLayout(inner_k_tiles=8)`
+        `layout`: layout type for quantized tensor, default is `TensorCoreTiledLayout(inner_k_tiles=8)`
         `use_hqq`: whether to use hqq or default quantization mode, default is False
     """
     def apply_int4_weight_only_quant(weight):
@@ -550,12 +550,12 @@ def apply_int4_weight_only_quant(weight):
         # Sparse Marlin only supports symmetric quantization.
         # NOTE: If we start having lots of layouts that require different configurations,
         # we should consider moving this logic somewhere else.
-        if isinstance(_layout, MarlinSparseLayout):
+        if isinstance(layout, MarlinSparseLayout):
             mapping_type = MappingType.SYMMETRIC
             preserve_zero = True
             zero_point_domain = ZeroPointDomain.INT
 
-        return to_affine_quantized_intx(weight, mapping_type, block_size, target_dtype, quant_min, quant_max, eps, zero_point_dtype=zero_point_dtype, preserve_zero=preserve_zero, zero_point_domain=zero_point_domain, _layout=_layout, use_hqq=use_hqq)
+        return to_affine_quantized_intx(weight, mapping_type, block_size, target_dtype, quant_min, quant_max, eps, zero_point_dtype=zero_point_dtype, preserve_zero=preserve_zero, zero_point_domain=zero_point_domain, _layout=layout, use_hqq=use_hqq)
 
     return _get_linear_subclass_inserter(apply_int4_weight_only_quant)
 
@@ -583,7 +583,7 @@ def _int8_symm_per_token_reduced_range_quant(x: torch.Tensor) -> torch.Tensor:
     return to_affine_quantized_intx(x, mapping_type, _get_per_token_block_size(x), target_dtype, eps=eps, quant_min=quant_min, quant_max=quant_max, scale_dtype=torch.float32 if x.dtype == torch.float16 else None)
 
 
-def int8_dynamic_activation_int8_weight(_layout=PlainLayout()):
+def int8_dynamic_activation_int8_weight(layout=PlainLayout()):
     """
     Applies int8 dynamic symmetric per-token activation and int8 per-channel weight
     quantization to linear layers
@@ -609,7 +609,7 @@ def get_weight_block_size(x):
         input_quant_func = _int8_symm_per_token_reduced_range_quant
 
         block_size = get_weight_block_size(weight)
-        weight = to_affine_quantized_intx(weight, mapping_type, block_size, target_dtype, eps=eps, zero_point_dtype=zero_point_dtype, _layout=_layout)
+        weight = to_affine_quantized_intx(weight, mapping_type, block_size, target_dtype, eps=eps, zero_point_dtype=zero_point_dtype, _layout=layout)
         weight = to_linear_activation_quantized(weight, input_quant_func)
         return weight
 
@@ -621,12 +621,12 @@ def int8_dynamic_activation_int8_semi_sparse_weight():
     Applies int8 dnynamic symmetric per-token activation and int8 per-channel weight
     quantization + 2:4 sparsity to linear layers.
     """
-    warnings.warn("""int8_dyanmic_activation_int8_semi_sparse_weight() will be deprecated at a later release. Please use the _layout kwarg in int8_dynamic_activation_int8_weight instead.
+    warnings.warn("""int8_dyanmic_activation_int8_semi_sparse_weight() will be deprecated at a later release. Please use the layout kwarg in int8_dynamic_activation_int8_weight instead.
 
     from torchao.dtypes import SemiSparseLayout
-    int8_dynamic_activation_int8_weight(_layout=SemiSparseLayout()""")
+    int8_dynamic_activation_int8_weight(layout=SemiSparseLayout()""")
 
-    return int8_dynamic_activation_int8_weight(_layout=SemiSparseLayout())
+    return int8_dynamic_activation_int8_weight(layout=SemiSparseLayout())
 
 
 def float8_weight_only(weight_dtype: torch.dtype = torch.float8_e4m3fn):
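
Note the rename stops at the public surface: inside both functions the internal constructor `to_affine_quantized_intx` keeps its `_layout` parameter, so the public value is simply forwarded as `_layout=layout`. The deprecated `int8_dynamic_activation_int8_semi_sparse_weight()` still works and now steers callers toward the renamed kwarg; a minimal sketch of observing that warning, assuming the helper is importable from `torchao.quantization.quant_api` as in the diff:

    import warnings
    from torchao.quantization.quant_api import (
        int8_dynamic_activation_int8_semi_sparse_weight,
    )

    # Calling the deprecated helper emits a warning that names the new
    # `layout` kwarg, then delegates to int8_dynamic_activation_int8_weight.
    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")
        config = int8_dynamic_activation_int8_semi_sparse_weight()
    print(caught[0].message)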

torchao/sparsity/README.md

Lines changed: 2 additions & 2 deletions

@@ -57,7 +57,7 @@ from torchao.dtypes import MarlinSparseLayout
 
 # Your FP16 model
 model = model.cuda().half()
-quantize_(model, int4_weight_only(_layout=MarlinSparseLayout()))
+quantize_(model, int4_weight_only(layout=MarlinSparseLayout()))
 ```
 
 Note the existing API results in an extremely high accuracy degredation and is intended to be used in concert with an already sparsified+finetuned checkpoint where possible until we develop
@@ -72,7 +72,7 @@ from torchao.quantization.quant_api import quantize_, int8_dynamic_activation_in
 from torchao.dtypes import SemiSparseLayout
 
 model = model.cuda()
-quantize_(model, int8_dynamic_activation_int8_weight(_layout=SemiSparseLayout()))
+quantize_(model, int8_dynamic_activation_int8_weight(layout=SemiSparseLayout()))
 ```
 
 ### 2:4 sparsity

torchao/sparsity/prototype/superblock/utils.py

Lines changed: 1 addition & 1 deletion

@@ -164,7 +164,7 @@ def accelerate_with_sparsity(model, args):
 
         quantize_(
             model,
-            int8_dynamic_activation_int8_weight(_layout=SemiSparseLayout()),
+            int8_dynamic_activation_int8_weight(layout=SemiSparseLayout()),
             mlp_0_only,
         )
         sparsify_(model, semi_sparse_weight(), mlp_3_only)

torchao/sparsity/sparse_api.py

Lines changed: 3 additions & 3 deletions

@@ -48,8 +48,8 @@ def sparsify_(
 
     Currently, we support three options for sparsity:
         - semi-structured (2:4) sparsity with `semi_sparse_weight`
-        - int8 dynamic quantization + 2:4 sparsity with `_layout=SemiSparseLayout`
-        - int4 weight-only quantization + 2:4 sparsity with `_layout=SparseMarlinLayout`
+        - int8 dynamic quantization + 2:4 sparsity with `layout=SemiSparseLayout`
+        - int4 weight-only quantization + 2:4 sparsity with `layout=SparseMarlinLayout`
 
     Args:
         model (torch.nn.Module): input model
@@ -73,7 +73,7 @@ def filter_fn(module: nn.Module, fqn: str) -> bool:
 
         # for int8 dynamic quantization + 2:4 sparsity
        from torchao.dtypes import SemiSparseLayout
-        m = quantize_(m, int8_dynamic_activation_int8_weight(_layout=SemiSparseLayout), filter_fn)
+        m = quantize_(m, int8_dynamic_activation_int8_weight(layout=SemiSparseLayout), filter_fn)
     """
     _replace_with_custom_fn_if_matches_filter(
         model,
