
Commit 6f6445d

Add option to load pretrained R1/R2 matrices
1 parent a378156 commit 6f6445d

1 file changed: +56 −20

torchao/quantization/spin_quant.py

Lines changed: 56 additions & 20 deletions
@@ -4,6 +4,7 @@
 Based on https://github.com/facebookresearch/SpinQuant
 """

+from pathlib import Path
 import typing

 import torch
@@ -31,12 +32,12 @@ def forward(self, x):
         return x


-def apply_spinquant(model: Transformer):
+def apply_spinquant(model: Transformer, use_r1=True, use_r2=True, use_r4=True, pretrained_rotation_path=None):
     """
     Apply SpinQuant to a Transformer model: https://arxiv.org/abs/2405.16406

-    Currently, this only applies the R1 + R2 + R4 rotation matrices to the model
-    (not R3, and no Cayley optimization).
+    Currently, this has the option of applying R1, R2, and R4 rotation matrices to
+    the model (not R3, and no Cayley optimization).
     """
     assert isinstance(model, Transformer), "Only Transformer models are supported"

@@ -45,25 +46,59 @@ def apply_spinquant(model: Transformer):
     model.to(device=device)
     torch.manual_seed(0)  # for reproducibility of random Hadamard matrices

-    fuse_layernorm_weights_into_linear(model)
-    apply_spinquant_r1(model, device)
-    apply_spinquant_r2(model, device)
-    apply_spinquant_r4(model, device)
+    # For testing purposes (remove later)
+    # Weights link: https://drive.google.com/drive/folders/1nV9juzE6_OHr10y6Ke5KCyOiGqDr0srX
+    pretrained_rotation_path = "7B_W4A16KV16_lr_1.5_seed_0/R.bin"
+
+    if pretrained_rotation_path is not None:
+        assert Path(pretrained_rotation_path).exists(), "Pretrained rotation path does not exist"
+        assert Path(pretrained_rotation_path).suffix == ".bin", "Expected a .bin file."
+
+    if use_r1:
+        fuse_layernorm_weights_into_linear(model)
+        apply_spinquant_r1(model, device, pretrained_rotation_path)
+    if use_r2:
+        apply_spinquant_r2(model, device, pretrained_rotation_path)
+    if use_r4:
+        apply_spinquant_r4(model, device)

     model.to(device=original_device)


-def apply_spinquant_r1(model, device):
-    R1 = random_hadamard_matrix(model.config.dim, device)
+def apply_spinquant_r1(model, device, pretrained_rotation_path=None):
+    """Apply the SpinQuant R1 rotation matrix to the model."""
+
+    # Load a pretrained R1 matrix if a path is given, else generate a random Hadamard matrix
+    if pretrained_rotation_path is not None:
+        R1 = torch.load(pretrained_rotation_path)["R1"].to(device).to(torch.float64)
+        assert R1.shape == (model.config.dim, model.config.dim), f"{R1.shape} vs {model.config.dim}"
+    else:
+        R1 = random_hadamard_matrix(model.config.dim, device)
+
     _rotate_model_r1(model, R1)


-def apply_spinquant_r2(model, device):
-    R2 = random_hadamard_matrix(model.config.head_dim, device)
-    _rotate_model_r2(model, R2)
+def apply_spinquant_r2(model, device, pretrained_rotation_path=None):
+    """Apply the SpinQuant R2 rotation matrices to the model."""
+
+    # Load or generate one R2 matrix per attention layer
+    R2s = []
+    head_dim = model.config.head_dim
+    for i, _ in enumerate(model.layers):
+        if pretrained_rotation_path is not None:
+            key = f"model.layers.{i}.self_attn.R2"
+            R2s_ = torch.load(pretrained_rotation_path)
+            R2 = R2s_[key].to(device).to(torch.float64)
+            assert R2.shape == (head_dim, head_dim), f"{R2.shape} != ({head_dim}, {head_dim})"
+        else:
+            R2 = random_hadamard_matrix(head_dim, device)
+        R2s.append(R2)
+
+    _rotate_model_r2(model, R2s)


 def apply_spinquant_r4(model, device):
+    """Apply the SpinQuant R4 rotation matrix to the model."""
     _rotate_model_r4(model)
     _add_activation_wrappers_r4(model)

@@ -102,14 +137,16 @@ def _rotate_model_r1(model, R1):


 @torch.inference_mode()
-def _rotate_model_r2(model, R2):
+def _rotate_model_r2(model, R2s):
     """Rotate the W_v and W_o weights of the multi-head self-attention modules."""

     # Apply R2 rotation to all multi-head self-attention modules
-    for layer in model.layers:
+    for idx, layer in enumerate(model.layers):
         attn = layer.attention
         head_dim = model.config.head_dim

+        R2 = R2s[idx]
+
         # Rotate W_o
         apply_exact_had_to_linear(attn.wo, had_dim=head_dim, output=False, R2=R2)

@@ -169,10 +206,9 @@ def fuse_layernorm_weights_into_linear(model):
     something.)
     """
     # Embedding fusion (from utils/fuse_norm_utils.py:43)
-    # I currently don't understand why this is necessary, so I'm omitting it (I
-    # contacted the authors about it:
-    # https://github.com/facebookresearch/SpinQuant/issues/14). It doesn't seem
-    # to affect performance (tested on int4wo)
+    # I currently don't understand why this is necessary, so I contacted the
+    # authors about it:
+    # https://github.com/facebookresearch/SpinQuant/issues/14
     for W in [model.tok_embeddings]:
         W_ = W.weight.data.double()
         W.weight.data = (W_ - W_.mean(dim=-1, keepdim=True)).to(W.weight.data.dtype)
@@ -196,7 +232,7 @@ def fuse_layernorm_weights_into_linear(model):
 def _rotate_mlp_output(layer, R1):
     W = layer.feed_forward.w2
     dtype = W.weight.dtype
-    W_ = W.weight.data.to( dtype=torch.float64)
+    W_ = W.weight.data.to(dtype=torch.float64)
     W.weight.data = torch.matmul(R1.T, W_).to(dtype=dtype)
     if W.bias is not None:
         b = W.bias.data.to(dtype=torch.float64)
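For reference, a minimal usage sketch of the new options (not part of the commit; the helper name apply_pretrained_spinquant is hypothetical). It assumes apply_spinquant is imported from torchao/quantization/spin_quant.py and that model is the torchao Transformer the module already requires; the expected checkpoint layout is taken from the asserts in apply_spinquant_r1/apply_spinquant_r2 above. Note that in this commit the path is still hard-coded inside apply_spinquant for testing, so the argument only takes effect once that override is removed.

from torchao.quantization.spin_quant import apply_spinquant


def apply_pretrained_spinquant(model, rotation_path):
    """Hypothetical helper: apply R1/R2/R4 rotations, loading R1 and R2 from a .bin file.

    The file at rotation_path must be torch.load-able and contain an "R1" tensor of
    shape (dim, dim) plus one "model.layers.{i}.self_attn.R2" tensor of shape
    (head_dim, head_dim) per layer, matching the asserts in the diff above.
    """
    apply_spinquant(
        model,
        use_r1=True,
        use_r2=True,
        use_r4=True,
        pretrained_rotation_path=rotation_path,
    )


# Without a checkpoint, random Hadamard rotations are used (the previous behavior):
# apply_spinquant(model)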
