
Commit 2a502f7

Add FP6-LLM doc and move FP6-LLM to prototype (#358)
* add doc. move fp6_llm under prototype
* doc update
* update doc. rename functions
1 parent 6f44d25 commit 2a502f7

File tree: 11 files changed, +62 -15 lines


README.md

Lines changed: 1 addition & 0 deletions

@@ -124,6 +124,7 @@ To learn more try out our APIs, you can check out API examples in
 4. [Bleeding Edge Kernels](./torchao/prototype/) for experimental kernels without backwards compatibility guarantees
     - [GaLore](https://github.com/pytorch/ao/tree/main/torchao/prototype/galore) for memory efficient finetuning
     - [fused HQQ Gemm Kernel](https://github.com/pytorch/ao/tree/main/torchao/prototype/hqq) for compute bound workloads
+    - [FP6-LLM](torchao/prototype/fp6_llm) mixed matmul FP16 x FP6 kernel for io bound workloads
 
 ## Our Goals

benchmarks/benchmark_fp6_llm.py

Lines changed: 1 addition & 1 deletion

@@ -1,6 +1,6 @@
 import torch
 from torch import nn
-from torchao.quantization.fp6_llm import Fp6LlmLinear, from_tc_float6_e3m2
+from torchao.prototype.fp6_llm.fp6_llm import Fp6LlmLinear, from_tc_float6_e3m2
 from torch.utils.benchmark import Timer
 import pandas as pd
 from tqdm import tqdm
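Only the import path changes here; the benchmark body is outside this hunk. A minimal timing sketch in the same spirit, assuming a CUDA device and the 1024x512 shapes from the README added by this commit (the actual shapes and reporting in `benchmark_fp6_llm.py` may differ):

```python
import torch
from torch.utils.benchmark import Timer
from torchao.prototype.fp6_llm import to_scaled_tc_float6_e3m2
from torchao.ops import fp6_llm_linear

# shapes follow the README example added in this commit
fp32_weight = torch.randn(1024, 512, device="cuda")
fp16_weight = fp32_weight.half()
fp16_act = torch.randn(1, 512, dtype=torch.float16, device="cuda")

# quantize + pack the weight into the tensor-core FP6 layout
fp6_weight, scales = to_scaled_tc_float6_e3m2(fp32_weight)

t_fp6 = Timer(
    "fp6_llm_linear(fp16_act, fp6_weight, scales)",
    globals={"fp6_llm_linear": fp6_llm_linear, "fp16_act": fp16_act,
             "fp6_weight": fp6_weight, "scales": scales},
).blocked_autorange()
t_fp16 = Timer(
    "fp16_act @ fp16_weight.T",
    globals={"fp16_act": fp16_act, "fp16_weight": fp16_weight},
).blocked_autorange()
print(f"FP6 kernel: {t_fp6.median * 1e6:.1f} us | FP16 matmul: {t_fp16.median * 1e6:.1f} us")
```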

test/quantization/test_fp6_llm.py renamed to test/prototype/test_fp6_llm.py

Lines changed: 1 addition & 1 deletion

@@ -7,7 +7,7 @@
     parametrize,
     run_tests,
 )
-from torchao.quantization.fp6_llm import (
+from torchao.prototype.fp6_llm.fp6_llm import (
     to_tc_float6_e3m2,
     from_tc_float6_e3m2,
     _to_tc_float6_e3m2_ref,

test/test_ops.py

Lines changed: 6 additions & 6 deletions

@@ -2,7 +2,7 @@
 from torch.testing._internal.common_utils import TestCase, IS_FBCODE
 from torch.testing._internal.optests import opcheck
 import torchao
-from torchao.quantization.fp6_llm import from_tc_float6_e3m2
+from torchao.prototype.fp6_llm.fp6_llm import from_tc_float6_e3m2
 import unittest
 from parameterized import parameterized
 import pytest
@@ -26,27 +26,27 @@ def _create_fp6_inputs(self, BS: int, OC: int, IC: int, device):
         return fp6_weight.to(device), fp16_scale.to(device), fp16_activation.to(device)
 
     @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available")
-    def test_fp16act_fp6weight_linear(self):
+    def test_fp6_llm_linear(self):
         BS = 2
         OC = 256
         IC = 256
         splitK = 1
         fp6_weight, fp16_scale, fp16_activation = self._create_fp6_inputs(BS, OC, IC, "cuda")
 
         # smoke test
-        torchao.ops.fp16act_fp6weight_linear(fp16_activation, fp6_weight, fp16_scale, splitK)
+        torchao.ops.fp6_llm_linear(fp16_activation, fp6_weight, fp16_scale, splitK)
 
         # comprehensive testing
         test_utils = ["test_schema", "test_autograd_registration", "test_faketensor", "test_aot_dispatch_dynamic"]
-        opcheck(torch.ops.torchao.fp16act_fp6weight_linear, (fp16_activation, fp6_weight, fp16_scale, splitK), test_utils=test_utils)
+        opcheck(torch.ops.torchao.fp6_llm_linear, (fp16_activation, fp6_weight, fp16_scale, splitK), test_utils=test_utils)
 
     # adapted from https://github.com/usyd-fsalab/fp6_llm/blob/main/tests/python/kernel_test.py
     @parameterized.expand([(1, 2048, 4096, 5), (2, 8192, 8192, 6)])
     @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available")
-    def test_fp6_matmul_correctness(self, BS, OC, IC, splitK):
+    def test_fp6_llm_linear_correctness(self, BS, OC, IC, splitK):
         fp6_weight, fp16_scale, fp16_activation = self._create_fp6_inputs(BS, OC, IC, "cuda")
 
-        results_fp6 = torchao.ops.fp16act_fp6weight_linear(fp16_activation, fp6_weight, fp16_scale, splitK)
+        results_fp6 = torchao.ops.fp6_llm_linear(fp16_activation, fp6_weight, fp16_scale, splitK)
 
         fp16_weight = from_tc_float6_e3m2(fp6_weight.view(torch.uint8), dtype=torch.float16) * fp16_scale[:, None]
         results_fp16 = fp16_activation @ fp16_weight.T
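The hunk ends before the test's assertion. A self-contained sketch of the comparison it sets up, using the weight pre-processing helper from the new README instead of the test's `_create_fp6_inputs` (not shown in this diff); the tolerance is illustrative, not the value the test uses:

```python
import torch
from torchao.prototype.fp6_llm import to_scaled_tc_float6_e3m2
from torchao.prototype.fp6_llm.fp6_llm import from_tc_float6_e3m2
from torchao.ops import fp6_llm_linear

# small shapes; the smoke test above uses OC = IC = 256
fp32_weight = torch.randn(256, 256, device="cuda")
fp16_act = torch.randn(2, 256, dtype=torch.float16, device="cuda")

fp6_weight, scales = to_scaled_tc_float6_e3m2(fp32_weight)
results_fp6 = fp6_llm_linear(fp16_act, fp6_weight, scales)

# dequantize back to FP16 and compare against a plain matmul, as in the hunk above
fp16_weight = from_tc_float6_e3m2(fp6_weight.view(torch.uint8), dtype=torch.float16) * scales[:, None]
results_fp16 = fp16_act @ fp16_weight.T

rel_err = (results_fp6 - results_fp16).abs().mean() / results_fp16.abs().mean()
assert rel_err < 1e-2  # illustrative tolerance; FP6 (e3m2) quantization is lossy
```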

torchao/csrc/cuda/fp6_llm/fp6_linear.cu

Lines changed: 1 addition & 1 deletion

@@ -178,7 +178,7 @@ torch::Tensor fp6_linear_forward_cuda(torch::Tensor _in_feats,
 }
 
 TORCH_LIBRARY_IMPL(torchao, CUDA, m) {
-  m.impl("torchao::fp16act_fp6weight_linear", &fp6_linear_forward_cuda);
+  m.impl("torchao::fp6_llm_linear", &fp6_linear_forward_cuda);
 }
 
 } // namespace torchao

torchao/csrc/fp6_llm.cpp

Lines changed: 1 addition & 1 deletion

@@ -4,5 +4,5 @@
 
 TORCH_LIBRARY_FRAGMENT(torchao, m) {
   m.impl_abstract_pystub("torchao.ops");
-  m.def("fp16act_fp6weight_linear(Tensor _in_feats, Tensor _weights, Tensor _scales, int splitK) -> Tensor");
+  m.def("fp6_llm_linear(Tensor _in_feats, Tensor _weights, Tensor _scales, int splitK) -> Tensor");
 }
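On the Python side, the schema defined in this fragment is what `torch.ops.torchao.fp6_llm_linear` exposes after the rename. A small hedged check, assuming the compiled extension is importable:

```python
import torch
import torchao  # importing torchao loads the extension that runs the registration above

op = torch.ops.torchao.fp6_llm_linear.default
# prints the schema declared in fp6_llm.cpp, along the lines of
# torchao::fp6_llm_linear(Tensor _in_feats, Tensor _weights, Tensor _scales, int splitK) -> Tensor
print(op._schema)
```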

torchao/ops.py

Lines changed: 3 additions & 3 deletions

@@ -12,7 +12,7 @@ def decorator(func):
     return decorator
 
 
-def fp16act_fp6weight_linear(_in_feats: Tensor, _weights: Tensor, _scales: Tensor, splitK: int = 1) -> Tensor:
+def fp6_llm_linear(_in_feats: Tensor, _weights: Tensor, _scales: Tensor, splitK: int = 1) -> Tensor:
     """
     FP6-LLM linear layer A @ W.T. See https://arxiv.org/abs/2401.14112 for more details.
@@ -25,10 +25,10 @@ def fp16act_fp6weight_linear(_in_feats: Tensor, _weights: Tensor, _scales: Tenso
     Returns
         output of linear layer
     """
-    return torch.ops.torchao.fp16act_fp6weight_linear.default(_in_feats, _weights, _scales, splitK)
+    return torch.ops.torchao.fp6_llm_linear.default(_in_feats, _weights, _scales, splitK)
 
 
-@register_custom_op("torchao::fp16act_fp6weight_linear")
+@register_custom_op("torchao::fp6_llm_linear")
 def _(_in_feats, _weights, _scales, splitK = 1):
     torch._check(_in_feats.dim() == 2, lambda: f"input should be a 2d tensor, got {_in_feats.dim()}D")
     torch._check(_in_feats.dtype is torch.float16, lambda: f"weight must be FP16, got {_in_feats.dtype}")
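The abstract impl registered via `register_custom_op` is what lets fake tensors, and hence `torch.compile`, trace the op without launching the CUDA kernel. A hedged sketch of the renamed wrapper under compilation, with shapes borrowed from the new README (the wrapper function below is illustrative, not part of this commit):

```python
import torch
from torchao.prototype.fp6_llm import to_scaled_tc_float6_e3m2
from torchao.ops import fp6_llm_linear

# quantize + pack an FP32 weight into the FP6 tensor-core layout
fp6_weight, scales = to_scaled_tc_float6_e3m2(torch.randn(1024, 512, device="cuda"))

@torch.compile(fullgraph=True)
def forward(act):
    # traced through the registered abstract impl, executed by the CUDA kernel
    return fp6_llm_linear(act, fp6_weight, scales)

out = forward(torch.randn(1, 512, dtype=torch.float16, device="cuda"))
print(out.shape)  # torch.Size([1, 1024])
```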

torchao/prototype/README.md

Lines changed: 1 addition & 0 deletions

@@ -9,6 +9,7 @@
 - `galore` - fused kernels for memory-efficient pre-training / fine-tuning per the [GaLore algorithm](https://arxiv.org/abs/2403.03507)
   - `galore/kernels` - `triton` kernels that fuse various steps of the `GaLore` algorithm
   - `galore/docs` - implementation notes and discussion of issues faced in kernel design.
+- [`fp6_llm`](fp6_llm) - FP16 x FP6 mixed matmul kernel per [FP6-LLM](https://arxiv.org/abs/2401.14112)
 
 #### Roadmap
Lines changed: 44 additions & 0 deletions (new file)

# FP6-LLM

This is an FP16 x FP6 mixed matmul kernel optimized for IO-bound workloads per [FP6-LLM](https://arxiv.org/abs/2401.14112). The actual CUDA kernel is located under [csrc/cuda/fp6_llm/](../../csrc/cuda/fp6_llm/). This module provides helper functions to quantize FP32 weights to FP6, as well as a facility to convert existing models to FP6.

## Usage

```python
from torchao.prototype.fp6_llm import convert_fp6_llm

model = ...
convert_fp6_llm(model)  # convert model in-place, replacing nn.Linear modules with Fp6LlmLinear

# fully compatible with torch.compile()
model.compile(mode="max-autotune", fullgraph=True)
```

It is also possible to pre-process the weight and call the kernel directly.

```python
import torch
from torchao.prototype.fp6_llm import to_scaled_tc_float6_e3m2
from torchao.ops import fp6_llm_linear

fp32_weight = torch.randn(1024, 512).cuda()

# pre-process the weight. this will quantize the weight to FP6 and pack it in a special
# layout for tensor cores. refer to the paper for more details.
fp6_weight, scales = to_scaled_tc_float6_e3m2(fp32_weight)

fp16_act = torch.randn(1, 512).cuda().half()
outputs = fp6_llm_linear(fp16_act, fp6_weight, scales)  # shape (1, 1024)
```

## TODO

- [ ] Compile CUDA kernel for Windows
- [ ] Merge FP5 from upstream

## Credits

Credit to the FP6-LLM authors:

- Paper: https://arxiv.org/abs/2401.14112
- Code: https://github.com/usyd-fsalab/fp6_llm
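A hedged end-to-end sketch of the module-swap path this new doc describes; the toy model, its `bias=False` layers, and the dimensions are assumptions for illustration, while `convert_fp6_llm` and the FP16 activation requirement come from the doc itself:

```python
import torch
from torch import nn
from torchao.prototype.fp6_llm import convert_fp6_llm

# toy model for illustration; bias=False keeps it close to the plain A @ W.T kernel
model = nn.Sequential(
    nn.Linear(512, 1024, bias=False),
    nn.ReLU(),
    nn.Linear(1024, 512, bias=False),
).cuda()
convert_fp6_llm(model)  # in-place: eligible nn.Linear modules become Fp6LlmLinear
print(type(model[0]).__name__)  # expected: Fp6LlmLinear

# the kernel consumes FP16 activations, as in the direct-call example above
x = torch.randn(1, 512, dtype=torch.float16, device="cuda")
with torch.no_grad():
    y = model(x)
print(y.shape)  # torch.Size([1, 512])
```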
Lines changed: 1 addition & 0 deletions (new file)

@@ -0,0 +1 @@
+from .fp6_llm import Fp6LlmLinear, convert_fp6_llm, to_scaled_tc_float6_e3m2
