
Spin Quant in TorchAO #579

@HDCharles

Description


Background:
The SpinQuant paper introduces a method of improving quantization accuracy by applying additional rotation matrices to the model weights.

While SpinQuant is a fairly sophisticated technique, some of its independent pieces could be implemented modularly to get incremental improvements on a smaller scale.

(see image)
https://imgur.com/jU60Iqs

In the above image, each rotation in both parts (a) and (b) of the figure could be independently implemented to improve the model's quantization accuracy. These rotations fall into three groups:

Rotations which can be fully absorbed by weight matrices of the linear ops and don’t introduce additional ops:
R2
Rotations which need a constant number of additional ops per model:
R1
Rotations which require additional ops per block:
R3, R4

While the second and third groups of rotations require adding additional ops to the model, the R2 rotation requires only a small change to the model weights and no additional ops.
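As a rough illustration of the fusion idea behind R2 (a sketch, not the TorchAO implementation), an orthogonal rotation inserted between two adjacent linear layers can be absorbed into their weights without changing the network's output. The `random_hadamard` helper, the square layer shapes, and the power-of-two dimension are all assumptions made for this sketch:

```python
import numpy as np

def random_hadamard(n, seed=0):
    """Orthonormal Hadamard matrix with random +/-1 column signs.

    Sylvester construction, so n must be a power of two (an assumption
    for this sketch)."""
    H = np.array([[1.0]])
    while H.shape[0] < n:
        H = np.block([[H, H], [H, -H]])
    H /= np.sqrt(n)  # make columns orthonormal
    signs = np.random.default_rng(seed).choice([-1.0, 1.0], size=n)
    return H * signs  # random column signs preserve orthogonality

n = 8
rng = np.random.default_rng(1)
W1 = rng.normal(size=(n, n))  # stand-in for the first linear's weight
W2 = rng.normal(size=(n, n))  # stand-in for the following linear's weight
x = rng.normal(size=n)

R = random_hadamard(n)
W1_rot = R.T @ W1  # rotate the first layer's output space
W2_rot = W2 @ R    # absorb the inverse rotation into the next layer

# The composed function is unchanged because R is orthogonal (R @ R.T == I)
assert np.allclose(W2 @ (W1 @ x), W2_rot @ (W1_rot @ x))
```

The point is that `W2 @ R @ R.T @ W1 == W2 @ W1`, so the rotation is "free" at inference time while still reshaping each weight matrix's value distribution before quantization.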

Task:

Start by implementing the R2 rotation with a random Hadamard matrix (the paper indicates these perform fairly well) and demonstrate the improved quantization accuracy for int8 dynamic/weight-only and int4 weight-only quantization. Ideally we'd like to see improved eval performance compared to the non-SpinQuant version. Code would ideally go into a new file in torchao/quantization/spin_quant.py.
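To make the intuition behind the expected accuracy gain concrete, here is a small NumPy sketch (not the requested torchao/quantization/spin_quant.py code): a weight outlier inflates the per-tensor int8 scale, while a Hadamard rotation spreads that outlier's energy across channels, shrinking the scale and the round-trip quantization error. The `int8_roundtrip` helper and the synthetic outlier are illustrative assumptions:

```python
import numpy as np

def sylvester_hadamard(n):
    # Orthonormal Hadamard matrix; n must be a power of two (sketch assumption)
    H = np.array([[1.0]])
    while H.shape[0] < n:
        H = np.block([[H, H], [H, -H]])
    return H / np.sqrt(n)

def int8_roundtrip(w):
    # Symmetric per-tensor int8 quantize -> dequantize
    scale = np.abs(w).max() / 127.0
    return np.clip(np.round(w / scale), -127, 127) * scale

n = 64
rng = np.random.default_rng(0)
w = rng.normal(scale=0.1, size=n)
w[0] = 5.0  # a single outlier inflates the quantization scale

H = sylvester_hadamard(n)
w_rot = H @ w  # rotation spreads the outlier's energy across channels

err_plain = np.mean((int8_roundtrip(w) - w) ** 2)
# Quantize in the rotated space, then rotate back before comparing
err_rot = np.mean((H.T @ int8_roundtrip(w_rot) - w) ** 2)
assert err_rot < err_plain
```

Because the rotation is orthogonal, quantizing in the rotated space and rotating back measures error against the original weights fairly; the rotated version wins here because its max-abs value, and hence its scale, is much smaller.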

Adding additional rotations (and the necessary additional ops) or a rotation optimization procedure for R2 as used in Spin Quant can follow after.
