Based on https://github.com/facebookresearch/SpinQuant
"""

+from pathlib import Path
import typing

import torch
@@ -31,12 +32,12 @@ def forward(self, x):
        return x


-def apply_spinquant(model: Transformer):
+def apply_spinquant(model: Transformer, use_r1=True, use_r2=True, use_r4=True, pretrained_rotation_path=None):
    """
    Apply SpinQuant to a Transformer model: https://arxiv.org/abs/2405.16406

-    Currently, this only applies the R1 + R2 + R4 rotation matrices to the model
-    (not R3, and no Cayley optimization).
+    Currently, this has the option of applying R1, R2, and R4 rotation matrices to
+    the model (not R3, and no Cayley optimization).
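+
+    Example (an illustrative sketch, assuming `model` is a `Transformer` that
+    has already been created and loaded with pretrained weights):
+
+        apply_spinquant(model, use_r1=True, use_r2=True, use_r4=True)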
4041 """
    assert isinstance(model, Transformer), "Only Transformer models are supported"

@@ -45,25 +46,59 @@ def apply_spinquant(model: Transformer):
    model.to(device=device)
    torch.manual_seed(0)  # for reproducibility of random Hadamard matrices

-    fuse_layernorm_weights_into_linear(model)
-    apply_spinquant_r1(model, device)
-    apply_spinquant_r2(model, device)
-    apply_spinquant_r4(model, device)
+    # For testing purposes (remove later)
+    # Weights link: https://drive.google.com/drive/folders/1nV9juzE6_OHr10y6Ke5KCyOiGqDr0srX
+    pretrained_rotation_path = "7B_W4A16KV16_lr_1.5_seed_0/R.bin"
+
+    if pretrained_rotation_path is not None:
+        assert Path(pretrained_rotation_path).exists(), "Pretrained rotation path does not exist"
+        assert Path(pretrained_rotation_path).suffix == ".bin", "Expected a .bin file."
+
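+    # Note: R1 relies on the norm scale parameters being fused into the
+    # adjacent linear layers first, since a non-unit norm weight would not
+    # commute with the rotation. That is why the fusion happens under use_r1.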
+    if use_r1:
+        fuse_layernorm_weights_into_linear(model)
+        apply_spinquant_r1(model, device, pretrained_rotation_path)
+    if use_r2:
+        apply_spinquant_r2(model, device, pretrained_rotation_path)
+    if use_r4:
+        apply_spinquant_r4(model, device)

    model.to(device=original_device)


-def apply_spinquant_r1(model, device):
-    R1 = random_hadamard_matrix(model.config.dim, device)
+def apply_spinquant_r1(model, device, pretrained_rotation_path=None):
69+ """Apply the SpinQuant R1 rotation matrix to the model."""
+
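+    # Per the SpinQuant paper, R1 is a dim x dim orthogonal rotation of the
+    # residual stream: weights that read from the stream absorb R1 and weights
+    # that write to it absorb R1.T, leaving the network function unchanged.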
+    # Load R1 matrix
+    if pretrained_rotation_path is not None:
+        R1 = torch.load(pretrained_rotation_path)["R1"].to(device).to(torch.float64)
+        assert R1.shape == (model.config.dim, model.config.dim), f"{R1.shape} vs {model.config.dim}"
+    else:
+        R1 = random_hadamard_matrix(model.config.dim, device)
+
    _rotate_model_r1(model, R1)


-def apply_spinquant_r2(model, device):
-    R2 = random_hadamard_matrix(model.config.head_dim, device)
-    _rotate_model_r2(model, R2)
+def apply_spinquant_r2(model, device, pretrained_rotation_path=None):
82+ """Apply the SpinQuant R2 rotation matrix to the model."""
+
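+    # Unlike R1, R2 is a per-layer head_dim x head_dim rotation acting on each
+    # attention head's value/output path, so one matrix is built (or loaded)
+    # per transformer layer.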
+    # Load R2 matrices (read the pretrained file once, not once per layer)
+    R2s = []
+    head_dim = model.config.head_dim
+    pretrained_R2s = torch.load(pretrained_rotation_path) if pretrained_rotation_path is not None else None
+    for i, _ in enumerate(model.layers):
+        if pretrained_R2s is not None:
+            R2 = pretrained_R2s[f"model.layers.{i}.self_attn.R2"].to(device).to(torch.float64)
+            assert R2.shape == (head_dim, head_dim), f"{R2.shape} != ({head_dim}, {head_dim})"
+        else:
+            R2 = random_hadamard_matrix(head_dim, device)
+        R2s.append(R2)
+
+    _rotate_model_r2(model, R2s)


def apply_spinquant_r4(model, device):
101+ """Apply the SpinQuant R4 rotation matrix to the model."""
    _rotate_model_r4(model)
    _add_activation_wrappers_r4(model)

@@ -102,14 +137,16 @@ def _rotate_model_r1(model, R1):


@torch.inference_mode()
-def _rotate_model_r2(model, R2):
+def _rotate_model_r2(model, R2s):
    """Rotate the W_v and W_o weights of the multi-head self-attention modules."""
-
+
    # Apply R2 rotation to all multi-head self-attention modules
-    for layer in model.layers:
+    for idx, layer in enumerate(model.layers):
        attn = layer.attention
        head_dim = model.config.head_dim

+        R2 = R2s[idx]
+
        # Rotate W_o
        apply_exact_had_to_linear(attn.wo, had_dim=head_dim, output=False, R2=R2)

@@ -169,10 +206,9 @@ def fuse_layernorm_weights_into_linear(model):
    something.)
    """
    # Embedding fusion (from utils/fuse_norm_utils.py:43)
-    # I currently don't understand why this is necessary, so I'm omitting it (I
-    # contacted the authors about it:
-    # https://github.com/facebookresearch/SpinQuant/issues/14). It doesn't seem
-    # to affect performance (tested on int4wo)
+    # I currently don't understand why this is necessary, so I have asked the
+    # authors about it:
+    # https://github.com/facebookresearch/SpinQuant/issues/14
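+    # (Mechanically, the loop below just removes each embedding row's mean,
+    # i.e. its component along the all-ones direction.)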
    for W in [model.tok_embeddings]:
        W_ = W.weight.data.double()
        W.weight.data = (W_ - W_.mean(dim=-1, keepdim=True)).to(W.weight.data.dtype)
@@ -196,7 +232,7 @@ def fuse_layernorm_weights_into_linear(model):
def _rotate_mlp_output(layer, R1):
    W = layer.feed_forward.w2
    dtype = W.weight.dtype
-    W_ = W.weight.data.to( dtype=torch.float64)
+    W_ = W.weight.data.to(dtype=torch.float64)
    W.weight.data = torch.matmul(R1.T, W_).to(dtype=dtype)
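    # Left-multiplying by R1.T maps w2's output into the rotated
    # residual-stream basis (hidden states are rotated by R1 model-wide).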
    if W.bias is not None:
        b = W.bias.data.to(dtype=torch.float64)