
Commit 9f85488

Allow Int4WeightOnlyQuantizer to set different dtype for scales_and_zeros (#479)
* Allow Int4WeightOnlyQuantizer to set different dtype for scales_and_zeros

  Currently `Int4WeightOnlyQuantizer` is hardcoded to return `scales_and_zeros` with dtype `torch.bfloat16`. Add a `dtype` argument through the flow (exposed as `precision` on the quantizer) so that a different dtype can be used.

* Add comment
1 parent ec73788 commit 9f85488
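
In practice, the new knob is the `precision` argument on `Int4WeightOnlyQuantizer`. A minimal usage sketch, not part of this commit: it assumes a CUDA build of PyTorch, an installed torchao, and that the `Quantizer` base class exposes the usual `quantize()` entry point chaining `_create_quantized_state_dict` and `_convert_for_runtime`.

import torch
from torchao.quantization.GPTQ import Int4WeightOnlyQuantizer

# Toy model; the int4 tinygemm packing op used internally requires CUDA.
model = torch.nn.Sequential(torch.nn.Linear(1024, 1024, bias=False)).to(device="cuda", dtype=torch.float16)

quantizer = Int4WeightOnlyQuantizer(
    groupsize=128,
    padding_allowed=True,
    inner_k_tiles=8,
    device=torch.device("cuda"),
    precision=torch.float16,  # new: dtype for scales_and_zeros (default remains torch.bfloat16)
)
model = quantizer.quantize(model)  # assumed Quantizer API, not shown in this diff

# The converted module now carries float16 scales_and_zeros instead of bfloat16.
print(model[0].scales_and_zeros.dtype)  # torch.float16

Leaving `precision` unset keeps the previous bfloat16 behavior, so existing callers are unaffected.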

2 files changed: +21 -13 lines


torchao/quantization/GPTQ.py

Lines changed: 17 additions & 9 deletions
@@ -525,14 +525,14 @@ def _check_linear_int4_k(k, groupsize = 1, inner_k_tiles = None):
         return k_divisible_by_groupsize and k_divisible_by_16_times_inner_k_tiles
     return k_divisible_by_groupsize

-def linear_forward_int4(x, weight_int4pack, scales_and_zeros, out_features, groupsize):
+def linear_forward_int4(x, weight_int4pack, scales_and_zeros, out_features, groupsize, dtype=torch.bfloat16):
     origin_x_size = x.size()
     x = x.reshape(-1, origin_x_size[-1])
     c = torch.ops.aten._weight_int4pack_mm(
-        x.to(torch.bfloat16),
+        x.to(dtype),
         weight_int4pack,
         groupsize,
-        scales_and_zeros.to(torch.bfloat16)
+        scales_and_zeros.to(dtype)
     ).to(dtype=x.dtype)
     new_shape = origin_x_size[:-1] + (out_features,)
     c = c.reshape(new_shape)
@@ -546,12 +546,12 @@ class WeightOnlyInt4Linear(torch.nn.Module):

     def __init__(
         self, in_features: int, out_features: int,
-        bias=False, device=None, dtype=None, groupsize: int = 128, inner_k_tiles: int = 8,
+        bias=False, device=None, dtype=torch.bfloat16, groupsize: int = 128, inner_k_tiles: int = 8,
     ) -> None:
         super().__init__()
         self.padding = not _check_linear_int4_k(in_features, groupsize, inner_k_tiles)
         if self.padding:
-            from model import find_multiple
+            from .utils import find_multiple
             self.origin_in_features = in_features
             in_features = find_multiple(in_features, 1024)

@@ -567,9 +567,10 @@ def __init__(
             "weight",
             torch.empty((out_features // 8, in_features // (inner_k_tiles * 16), 32, inner_k_tiles // 2), dtype=torch.int32)
         )
+        self.dtype = dtype
         self.register_buffer(
             "scales_and_zeros",
-            torch.empty((in_features // groupsize, out_features, 2), dtype=torch.bfloat16)
+            torch.empty((in_features // groupsize, out_features, 2), dtype=self.dtype)
         )

     def forward(self, input: torch.Tensor) -> torch.Tensor:
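
Because the buffer is now allocated from the constructor's `dtype`, the effect can be seen by instantiating the layer directly. A small sketch using toy sizes (anything divisible by `groupsize` and by `inner_k_tiles * 16` avoids the padding path); no quantized weights are computed here, and it assumes only that torchao is importable.

import torch
from torchao.quantization.GPTQ import WeightOnlyInt4Linear

lin = WeightOnlyInt4Linear(
    256, 256, bias=False,
    dtype=torch.float16,  # controls the scales_and_zeros buffer dtype
    groupsize=128, inner_k_tiles=8,
)
print(lin.scales_and_zeros.dtype)  # torch.float16 (previously always torch.bfloat16)
print(lin.scales_and_zeros.shape)  # torch.Size([2, 256, 2]) == (in_features // groupsize, out_features, 2)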
@@ -578,20 +579,21 @@ def forward(self, input: torch.Tensor) -> torch.Tensor:
             input = F.pad(input, pad=(0, self.in_features - self.origin_in_features))
         return linear_forward_int4(
             input,
-            self.weight, self.scales_and_zeros, self.out_features, self.groupsize
+            self.weight, self.scales_and_zeros, self.out_features, self.groupsize, self.dtype
         )

-def replace_linear_int4(module, groupsize, inner_k_tiles, padding_allowed, skip_layer_func = None):
+def replace_linear_int4(module, groupsize, inner_k_tiles, padding_allowed, skip_layer_func = None, dtype=torch.bfloat16):

     for name, child in module.named_children():
         if isinstance(child, nn.Linear) and (skip_layer_func is None or not skip_layer_func(child.weight)):
             if _check_linear_int4_k(child.in_features, groupsize, inner_k_tiles) or padding_allowed:
                 setattr(module, name, WeightOnlyInt4Linear(
                     child.in_features, child.out_features, bias=False,
                     groupsize=groupsize, inner_k_tiles=inner_k_tiles,
+                    dtype=dtype,
                 ))
         else:
-            replace_linear_int4(child, groupsize, inner_k_tiles, padding_allowed, skip_layer_func)
+            replace_linear_int4(child, groupsize, inner_k_tiles, padding_allowed, skip_layer_func, dtype)

 class Int4WeightOnlyQuantizer(Quantizer):
     def __init__(
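
With the extra keyword, `replace_linear_int4` forwards the dtype both to every swapped-in `WeightOnlyInt4Linear` and to its own recursive call, so nested submodules pick it up as well. A sketch on a toy module tree (assumes only that torchao is importable; the swapped modules hold empty buffers until a quantized state dict is loaded).

import torch
import torch.nn as nn
from torchao.quantization.GPTQ import WeightOnlyInt4Linear, replace_linear_int4

# Nested container so the recursive branch is exercised too.
model = nn.Sequential(
    nn.Linear(256, 512, bias=False),
    nn.Sequential(nn.Linear(512, 256, bias=False)),
)
replace_linear_int4(
    model,
    groupsize=128,
    inner_k_tiles=8,
    padding_allowed=True,
    skip_layer_func=None,
    dtype=torch.float16,  # new keyword introduced by this commit
)
print(type(model[0]).__name__, model[0].scales_and_zeros.dtype)        # WeightOnlyInt4Linear torch.float16
print(type(model[1][0]).__name__, model[1][0].scales_and_zeros.dtype)  # WeightOnlyInt4Linear torch.float16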
@@ -600,6 +602,7 @@ def __init__(
         padding_allowed: bool = True,
         inner_k_tiles: Optional[int] = 8,
         device: torch.device = torch.device("cuda"),
+        precision: torch.dtype = torch.bfloat16,
     ) -> None:
         super().__init__()
         assert inner_k_tiles in [2, 4, 8]
@@ -609,6 +612,8 @@ def __init__(
         self.groupsize: int = groupsize
         self.padding_allowed: bool = padding_allowed
         self.device: torch.device = device
+        # precision and dtype are being used interchangeably here
+        self.precision: torch.dtype = precision

     @torch.no_grad()
     def _create_quantized_state_dict(
@@ -648,6 +653,7 @@ def _create_quantized_state_dict(
                     weight,
                     4, # n_bit
                     self.groupsize,
+                    self.precision, # dtype for scales_and_zeros
                 )
                 weight_int4pack = torch.ops.aten._convert_weight_to_int4pack(w_int4x8.to(self.device), self.inner_k_tiles)
                 cur_state_dict[f"{fqn}.weight"] = weight_int4pack.to(self.device)
@@ -660,6 +666,8 @@ def _convert_for_runtime(self, model: torch.nn.Module) -> torch.nn.Module:
             self.groupsize,
             self.inner_k_tiles,
             self.padding_allowed,
+            skip_layer_func=None,
+            dtype=self.precision,
         )
         return model

torchao/quantization/utils.py

Lines changed: 4 additions & 4 deletions
@@ -307,9 +307,9 @@ def get_groupwise_affine_qparams(w, n_bit=4, groupsize=128, dtype=torch.bfloat16):
     ).reshape(w.shape[0], -1)


-def pack_tinygemm_scales_and_zeros(scales, zeros):
-    guard_dtype_size(scales, "scales", dtype=torch.bfloat16, size=zeros.size())
-    guard_dtype_size(zeros, "zeros", dtype=torch.bfloat16)
+def pack_tinygemm_scales_and_zeros(scales, zeros, dtype=torch.bfloat16):
+    guard_dtype_size(scales, "scales", dtype=dtype, size=zeros.size())
+    guard_dtype_size(zeros, "zeros", dtype=dtype)
     return (
         torch.cat(
             [
@@ -376,7 +376,7 @@ def groupwise_affine_quantize_tensor(w, n_bit=4, groupsize=128, dtype=torch.bfloat16):
     w_int4x8 = groupwise_affine_quantize_tensor_from_qparams(
         w, scales, zeros, n_bit, groupsize
     )
-    scales_and_zeros = pack_tinygemm_scales_and_zeros(scales, zeros)
+    scales_and_zeros = pack_tinygemm_scales_and_zeros(scales, zeros, dtype)
     return w_int4x8, scales_and_zeros
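
At the utils level, `pack_tinygemm_scales_and_zeros` now accepts (and guards for) an explicit dtype rather than assuming bfloat16. A quick sketch, assuming the `(out_features, in_features // groupsize)` qparam layout that `get_groupwise_affine_qparams` produces:

import torch
from torchao.quantization.utils import pack_tinygemm_scales_and_zeros

# Toy qparams: out_features=256, in_features=256, groupsize=128 -> two groups per output row.
scales = torch.rand(256, 2, dtype=torch.float16)
zeros = torch.rand(256, 2, dtype=torch.float16)

packed = pack_tinygemm_scales_and_zeros(scales, zeros, dtype=torch.float16)
print(packed.shape, packed.dtype)  # expected: torch.Size([2, 256, 2]) torch.float16

Calling it without the third argument keeps the old bfloat16 guard, so existing callers keep working.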

382382
