Description
Background:
The SpinQuant paper introduces a method that improves quantization by applying additional rotation matrices to the model weights, which makes them easier to quantize.
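A rotation can be added without changing what the model computes because it is orthogonal: folding $R$ into a linear layer's weight and $R^\top$ into its input leaves the full-precision output unchanged, while the weight that actually gets quantized becomes $WR$. Writing $Q(\cdot)$ for a quantize-dequantize function (notation introduced here, not taken from the paper), the idea is to pick $R$ so that $Q(WR)$ loses less information than $Q(W)$, for example by spreading outlier values across many entries:

```latex
y = W x = (W R)(R^\top x), \qquad R R^\top = R^\top R = I
\hat{y} = Q(W R)\,(R^\top x) \neq Q(W)\,x \quad \text{in general}
```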
While SpinQuant is a fairly sophisticated technique, some of its pieces are independent and could be implemented modularly to get incremental improvements on a smaller scale.
(See image: https://imgur.com/jU60Iqs)
In the image above, each rotation in parts (a) and (b) of the figure could be implemented independently to improve the model's quantization accuracy. These rotations fall into three groups:
- Rotations that can be fully absorbed into the weight matrices of the linear ops and don't introduce additional ops: R2
- Rotations that need a constant number of additional ops per model: R1
- Rotations that require additional ops per block: R3, R4
While the second and third sets of rotations require adding ops to the model, the R2 rotation requires only a small change to the model weights and no additional ops.
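A minimal sketch of why R2 needs no extra ops, assuming R2 is a per-head rotation sitting between the attention value projection and the output projection (our reading of the SpinQuant figure, not something this issue states). Under that assumption the rotation can be folded into both weights: the value projection is left-multiplied by the transpose of a block-diagonal rotation and the output projection is right-multiplied by it, so the two cancel. The layer is a toy bias-free multi-head attention in plain PyTorch; `attention` and `absorb_r2` are hypothetical helpers, and a random orthogonal matrix from QR stands in for R2.

```python
import torch

torch.manual_seed(0)

# Hypothetical toy attention layer: standard multi-head attention, no biases,
# weights stored in nn.Linear layout (out_features, in_features).
dim, n_heads = 64, 4
head_dim = dim // n_heads
wq, wk, wv, wo = (torch.randn(dim, dim) / dim**0.5 for _ in range(4))


def attention(x, wq, wk, wv, wo):
    """Plain multi-head self-attention; x has shape (seq_len, dim)."""
    seq = x.shape[0]
    # Project and split into heads: (n_heads, seq_len, head_dim)
    q = (x @ wq.T).view(seq, n_heads, head_dim).transpose(0, 1)
    k = (x @ wk.T).view(seq, n_heads, head_dim).transpose(0, 1)
    v = (x @ wv.T).view(seq, n_heads, head_dim).transpose(0, 1)
    probs = torch.softmax(q @ k.transpose(-2, -1) / head_dim**0.5, dim=-1)
    out = (probs @ v).transpose(0, 1).reshape(seq, dim)  # merge heads
    return out @ wo.T


def absorb_r2(wv, wo, r2):
    """Fold an orthogonal per-head rotation r2 (head_dim x head_dim) into v_proj/o_proj."""
    # Block-diagonal matrix applying the same r2 to every head independently.
    r = torch.block_diag(*[r2] * n_heads)
    # v' = v @ r  <=>  wv' = r.T @ wv;  o_proj undoes it with wo' = wo @ r.
    return r.T @ wv, wo @ r


# Random orthogonal matrix via QR, standing in for R2.
r2, _ = torch.linalg.qr(torch.randn(head_dim, head_dim))
wv_rot, wo_rot = absorb_r2(wv, wo, r2)

x = torch.randn(10, dim)
ref = attention(x, wq, wk, wv, wo)
rot = attention(x, wq, wk, wv_rot, wo_rot)
print(torch.allclose(ref, rot, atol=1e-5))  # True: the rotation cancels in full precision
```

This works because attention mixes value vectors only along the sequence axis within each head, so a rotation of each head's value vectors commutes with it. Only the stored `wv_rot`/`wo_rot` change, and those are what would be quantized.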
Task:
Start by implementing the R2 rotation with a random Hadamard matrix (the paper indicates these perform fairly well) and demonstrate improved quantization accuracy for int8 dynamic quantization, int8 weight-only quantization, and int4 weight-only quantization. Ideally, eval results should improve compared to the version without SpinQuant. The code would go into a new file, torchao/quantization/spin_quant.py.
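A minimal sketch of the intended comparison on a single hypothetical `o_proj` weight with a few injected outlier input channels. It assumes one common construction of a random Hadamard matrix (a Sylvester Hadamard matrix with random column sign flips, so `head_dim` must be a power of 2) and uses a hand-rolled symmetric per-row fake-quantizer in place of torchao's int8/int4 quantization APIs. `hadamard`, `random_hadamard` and `fake_quant_per_row` are illustrative names, not existing torchao functions; a real evaluation would rotate every layer of a model and compare eval metrics.

```python
import torch

torch.manual_seed(0)


def hadamard(n):
    """Sylvester construction of an n x n Hadamard matrix (n must be a power of 2)."""
    assert n > 0 and n & (n - 1) == 0, "n must be a power of 2"
    h = torch.ones(1, 1)
    while h.shape[0] < n:
        h = torch.cat([torch.cat([h, h], dim=1), torch.cat([h, -h], dim=1)], dim=0)
    return h


def random_hadamard(n):
    """Orthogonal matrix H @ diag(s) / sqrt(n) with random signs s in {-1, +1}."""
    signs = torch.randint(0, 2, (n,)).float() * 2 - 1
    return hadamard(n) * signs / n**0.5  # broadcasting scales column j by signs[j]


def fake_quant_per_row(w, bits):
    """Symmetric per-output-channel (per-row) quantize, then dequantize."""
    qmax = 2 ** (bits - 1) - 1
    scale = w.abs().amax(dim=1, keepdim=True).clamp(min=1e-8) / qmax
    return torch.clamp(torch.round(w / scale), -qmax - 1, qmax) * scale


# Hypothetical o_proj weight (out_features x in_features) with outlier input channels.
dim, n_heads = 256, 8
head_dim = dim // n_heads
wo = torch.randn(dim, dim)
wo[:, torch.randint(0, dim, (4,))] *= 20.0

# R2 folded into o_proj: wo' = wo @ blockdiag(R2, ..., R2), the same R2 for every head.
r = torch.block_diag(*[random_hadamard(head_dim)] * n_heads)
wo_rot = wo @ r

for bits in (8, 4):
    err = (fake_quant_per_row(wo, bits) - wo).norm().item()
    # r is orthogonal, so this equals the error of the effective weight
    # fake_quant(wo @ r) @ r.T relative to wo, making the two numbers comparable.
    err_rot = (fake_quant_per_row(wo_rot, bits) - wo_rot).norm().item()
    print(f"int{bits}: plain={err:.3f} rotated={err_rot:.3f}")
```

The rotation spreads each outlier across the `head_dim` entries of its block, which lowers each row's maximum and therefore its quantization scale.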
Adding further rotations (and the ops they require), or a rotation optimization procedure for R2 as used in SpinQuant, can follow later.