
Commit f400fef

Merge remote-tracking branch 'origin/main' into bench_structure
2 parents 0a2499c + 81f0bf2 commit f400fef

122 files changed: +5238, -4650 lines


.github/workflows/torchao_experimental_test.yml

Lines changed: 10 additions & 1 deletion

@@ -33,7 +33,10 @@ jobs:
       - name: Install requirements
         run: |
           conda activate venv
-          pip install torch --index-url "https://download.pytorch.org/whl/nightly/cpu"
+          # Install executorch first because it installs its own version
+          # of torch and torchao, which we do not want to use
+          pip install executorch
+          pip install torch --index-url "https://download.pytorch.org/whl/nightly/cpu" --force-reinstall
           pip install numpy
           pip install pytest
           pip install parameterized
@@ -57,6 +60,12 @@ jobs:
           sh build_and_run_tests.sh
           rm -rf /tmp/cmake-out
           popd
+      - name: ET ops build
+        run: |
+          conda activate venv
+          pushd torchao/experimental
+          sh build_torchao_ops.sh executorch
+          popd

   test-mps-ops:
     strategy:

README.md

Lines changed: 3 additions & 3 deletions

@@ -115,13 +115,13 @@ swap_linear_with_semi_sparse_linear(model, {"seq.0": SemiSparseLinear})
 ADAM takes 2x as much memory as the model params so we can quantize the optimizer state to either 8 or 4 bit effectively reducing the optimizer VRAM requirements by 2x or 4x respectively over an fp16 baseline

 ```python
-from torchao.prototype.low_bit_optim import AdamW8bit, AdamW4bit, AdamWFp8
+from torchao.optim import AdamW8bit, AdamW4bit, AdamWFp8
 optim = AdamW8bit(model.parameters()) # replace with Adam4bit and AdamFp8 for the 4 / fp8 versions
 ```

-In practice, we are a tiny bit slower than expertly written kernels but the implementations for these optimizers were written in a **few hundred lines of PyTorch code** and compiled so please use them or copy-paste them for your quantized optimizers. Benchmarks [here](https://github.com/pytorch/ao/tree/main/torchao/prototype/low_bit_optim)
+In practice, we are a tiny bit slower than expertly written kernels but the implementations for these optimizers were written in a **few hundred lines of PyTorch code** and compiled so please use them or copy-paste them for your quantized optimizers. Benchmarks [here](https://github.com/pytorch/ao/tree/main/torchao/optim)

-We also have support for [single GPU CPU offloading](https://github.com/pytorch/ao/tree/main/torchao/prototype/low_bit_optim#optimizer-cpu-offload) where both the gradients (same size as weights) and the optimizers will be efficiently sent to the CPU. This alone can **reduce your VRAM requirements by 60%**
+We also have support for [single GPU CPU offloading](https://github.com/pytorch/ao/tree/main/torchao/optim#optimizer-cpu-offload) where both the gradients (same size as weights) and the optimizers will be efficiently sent to the CPU. This alone can **reduce your VRAM requirements by 60%**

 ```python
 optim = CPUOffloadOptimizer(model.parameters(), torch.optim.AdamW, fused=True)
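
For orientation, the hunk above only changes the import path from `torchao.prototype.low_bit_optim` to `torchao.optim`; the API itself is unchanged. Below is a minimal sketch of the renamed API, mirroring the two snippets in the README diff; the toy model and hyperparameters are placeholders.

```python
import torch
from torchao.optim import AdamW8bit, CPUOffloadOptimizer

# Placeholder model; any nn.Module on CUDA works the same way.
model = torch.nn.Sequential(
    torch.nn.Linear(1024, 4096),
    torch.nn.ReLU(),
    torch.nn.Linear(4096, 1024),
).cuda()

# Option 1: keep the Adam state in 8 bits (same step()/zero_grad() API as torch.optim).
optimizer = AdamW8bit(model.parameters(), lr=1e-3)

# Option 2: offload optimizer state to CPU instead, as in the second README snippet.
# optimizer = CPUOffloadOptimizer(model.parameters(), torch.optim.AdamW, fused=True)
```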

benchmarks/benchmark_low_bit_adam.py

Lines changed: 6 additions & 8 deletions

@@ -34,7 +34,7 @@
 from torchvision.transforms import v2
 from tqdm import tqdm

-from torchao.prototype import low_bit_optim
+from torchao import optim
 from torchao.utils import get_available_devices

 _DEVICE = get_available_devices()[-1]
@@ -43,9 +43,9 @@
 OPTIM_MAP = dict(
     AdamW=partial(torch.optim.AdamW, fused=True),
     AdamW8bitBnb=bnb.optim.AdamW8bit,
-    AdamW8bitAo=low_bit_optim.AdamW8bit,
-    AdamWFp8Ao=low_bit_optim.AdamWFp8,
-    AdamW4bitAo=low_bit_optim.AdamW4bit,
+    AdamW8bitAo=optim.AdamW8bit,
+    AdamWFp8Ao=optim.AdamWFp8,
+    AdamW4bitAo=optim.AdamW4bit,
 )

 try:
@@ -249,12 +249,10 @@ def evaluate_model(model, args):
 optim_cls = OPTIM_MAP[args.optim]

 if args.optim_cpu_offload == "ao":
-    optim_cls = partial(
-        low_bit_optim.CPUOffloadOptimizer, optimizer_class=optim_cls
-    )
+    optim_cls = partial(optim.CPUOffloadOptimizer, optimizer_class=optim_cls)
 elif args.optim_cpu_offload == "ao_offload_grads":
     optim_cls = partial(
-        low_bit_optim.CPUOffloadOptimizer,
+        optim.CPUOffloadOptimizer,
         optimizer_class=optim_cls,
         offload_gradients=True,
     )
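
The hunk above keeps the existing `functools.partial` pattern: the offload wrapper and the inner optimizer class are bound first, and the parameters plus optimizer kwargs are supplied later. A minimal sketch of that pattern under the renamed `torchao.optim` module; the model and learning rate are placeholders.

```python
from functools import partial

import torch
from torchao import optim

model = torch.nn.Linear(1024, 1024).cuda()  # placeholder model

# Bind the wrapper and the inner optimizer class now, pass parameters and kwargs later,
# as the benchmark does for its "ao_offload_grads" option.
optim_cls = partial(
    optim.CPUOffloadOptimizer,
    optimizer_class=optim.AdamW8bit,
    offload_gradients=True,
)
optimizer = optim_cls(model.parameters(), lr=1e-3)
```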

benchmarks/benchmark_rowwise_scaled_linear_cutlass.py

Lines changed: 35 additions & 21 deletions

@@ -7,41 +7,55 @@
     rowwise_scaled_linear_cutlass_s4s4,
     rowwise_scaled_linear_cutlass_s8s4,
 )
+from torchao.quantization.quant_api import (
+    _int4_symm_cutlass_quant,
+    _int8_symm_cutlass_quant,
+)
+
+dtype = torch.bfloat16
+dtypeq = torch.int8
+dtype_scale = torch.float32
+device = torch.device("cuda")


 def benchmark_microseconds(f, *args):
     return do_bench(lambda: f(*args), return_mode="median") * 1e3


-def get_problem(m: int, n: int, k: int, A_nbits: int, B_nbits: int):
-    assert A_nbits in (4, 8) and B_nbits in (4, 8)
+def get_problem(m: int, n: int, k: int, Xq_nbits: int):
+    assert k % 2 == 0
+    assert Xq_nbits in [4, 8]
+
+    X_ref = torch.randn((m, k), dtype=dtype, device=device)
+    W_ref = torch.rand((n, k), dtype=dtype, device=device)

-    dev = torch.device("cuda")
-    A = torch.randint(-128, 127, (m, k * A_nbits // 8), dtype=torch.int8, device=dev)
-    A_scale = torch.randn((m,), dtype=torch.half, device=dev)
-    B = torch.randint(
-        -128, 127, size=(n, k * B_nbits // 8), dtype=torch.int8, device=dev
+    X_quant_func = (
+        _int4_symm_cutlass_quant if Xq_nbits == 4 else _int8_symm_cutlass_quant
     )
-    B_scale = torch.randn((n,), dtype=torch.half, device=dev)
-    C = None
+    W_quant_func = _int4_symm_cutlass_quant
+    X_aqt = X_quant_func(X_ref)
+    W_aqt = W_quant_func(W_ref)

-    return A, A_scale, B, B_scale, C
+    Xq = X_aqt.tensor_impl.int_data
+    X_scale = X_aqt.tensor_impl.scale
+    Wq = W_aqt.tensor_impl.int_data
+    W_scale = W_aqt.tensor_impl.scale
+    bias = None
+    out_dtype = dtype

+    return (X_ref, W_ref), (Xq, X_scale, Wq, W_scale, bias, out_dtype)

-def benchmark(m: int, k: int, n: int):
-    dev = torch.device("cuda")
-    A_ref = torch.randn((m, k), dtype=torch.half, device=dev)
-    B_ref = torch.randn((n, k), dtype=torch.half, device=dev)
-    fp16_time = benchmark_microseconds(torch.nn.functional.linear, A_ref, B_ref)

-    A, A_scale, B, B_scale, C = get_problem(m, n, k, 8, 4)
-    rowwise_scaled_linear_cutlass_s8s4_time = benchmark_microseconds(
-        rowwise_scaled_linear_cutlass_s8s4, A, A_scale, B, B_scale, C
+def benchmark(m: int, k: int, n: int):
+    ref_args, args = get_problem(m, n, k, 4)
+    fp16_time = benchmark_microseconds(torch.nn.functional.linear, *ref_args)
+    rowwise_scaled_linear_cutlass_s4s4_time = benchmark_microseconds(
+        rowwise_scaled_linear_cutlass_s4s4, *args
     )

-    A, A_scale, B, B_scale, C = get_problem(m, n, k, 4, 4)
-    rowwise_scaled_linear_cutlass_s4s4_time = benchmark_microseconds(
-        rowwise_scaled_linear_cutlass_s4s4, A, A_scale, B, B_scale, C
+    _, args = get_problem(m, n, k, 8)
+    rowwise_scaled_linear_cutlass_s8s4_time = benchmark_microseconds(
+        rowwise_scaled_linear_cutlass_s8s4, *args
     )

     return {

Lines changed: 72 additions & 0 deletions (new file)

@@ -0,0 +1,72 @@
+import pandas as pd
+import torch
+from tqdm import tqdm
+from triton.testing import do_bench
+
+from torchao.ops import rowwise_scaled_linear_sparse_cutlass_f8f8
+from torchao.quantization.quant_api import (
+    _float8_cutlass_quant,
+    _float8_cutlass_quant_sparse,
+)
+from torchao.sparsity.utils import create_semi_structured_tensor
+
+dtype = torch.bfloat16
+dtypeq_X = torch.float8_e5m2
+dtypeq_W = torch.float8_e4m3fn
+device = torch.device("cuda")
+
+
+def benchmark_microseconds(f, *args):
+    return do_bench(lambda: f(*args), return_mode="median") * 1e3
+
+
+def get_problem(m: int, n: int, k: int):
+    X_ref = torch.randn((m, k), dtype=dtype, device=device)
+    W_ref = create_semi_structured_tensor(n, k, dtype=dtype).to(device)
+
+    X_quant_func = _float8_cutlass_quant
+    W_quant_func = _float8_cutlass_quant_sparse
+    X_aqt = X_quant_func(X_ref, dtypeq_X)
+    W_aqt = W_quant_func(W_ref, dtypeq_W)
+
+    Xq = X_aqt.tensor_impl.float8_data
+    X_scale = X_aqt.tensor_impl.scale
+    Wq_sparse = W_aqt.tensor_impl.sparse
+    W_meta = W_aqt.tensor_impl.meta
+    W_scale = W_aqt.tensor_impl.scale
+    bias = None
+    out_dtype = dtype
+
+    return (X_ref, W_ref), (Xq, X_scale, Wq_sparse, W_meta, W_scale, bias, out_dtype)
+
+
+def benchmark(m: int, k: int, n: int):
+    ref_args, args = get_problem(m, n, k)
+    fp16_time = benchmark_microseconds(torch.nn.functional.linear, *ref_args)
+    rowwise_scaled_linear_sparse_cutlass_f8f8_time = benchmark_microseconds(
+        rowwise_scaled_linear_sparse_cutlass_f8f8, *args
+    )
+
+    return {
+        "m": m,
+        "k": k,
+        "n": n,
+        "fp16_latency (ms)": fp16_time,
+        "rowwise_scaled_linear_sparse_cutlass_f8f8 latency (ms)": rowwise_scaled_linear_sparse_cutlass_f8f8_time,
+        "f8f8 speedup (d/s)": fp16_time
+        / rowwise_scaled_linear_sparse_cutlass_f8f8_time,
+    }
+
+
+if __name__ == "__main__":
+    k_vals = (8192, 8192, 8192, 28672)
+    n_vals = (8192, 10240, 57344, 8192)
+
+    results = []
+    for m in tqdm([1 << i for i in range(10)]):
+        for n, k in zip(n_vals, k_vals):
+            results.append(benchmark(m, k, n))
+
+    df = pd.DataFrame(results)
+    df.to_csv("rowwise_scaled_linear_sparse_cutlass_time_results.csv", index=False)
+    print(df.to_markdown(index=False))

benchmarks/float8/float8_roofline.py

Lines changed: 4 additions & 1 deletion

@@ -184,8 +184,11 @@ def get_gemm_times(
     elif float8_recipe_name in ("rowwise", "rowwise_with_gw_hp"):
         scale_a = torch.ones(M, 1, device=device)
         scale_b = torch.ones(1, N, device=device)
+    elif mx_recipe_name == "mxfp8_cublas":
+        scale_a = torch.ones(M, K // 32, device=device, dtype=torch.float8_e8m0fnu)
+        scale_b = torch.ones(N, K // 32, device=device, dtype=torch.float8_e8m0fnu)
     else:
-        assert False, "TODO add mx gemm here"
+        assert False, "TODO add cutlass mx gemm here"

     def do_matmul(A, B):
         return torch._scaled_mm(

benchmarks/float8/training/README.md

Lines changed: 5 additions & 5 deletions

@@ -4,15 +4,15 @@ The `float8_training_benchmark.sh` script in this directory can be used to launc

 ## Usage

-Example: `TORCHTITAN_ROOT=${HOME}/torchtitan FLOAT8_RECIPE=rowwise ./float8_training_benchmark.sh`
+Example: `TORCHTITAN_ROOT=${HOME}/torchtitan FLOAT8_RECIPE_WITH_BEST_SETTINGS=rowwise ./float8_training_benchmark.sh`

 Training parameters can be configured via environment variables.

 - Required:
-  - `TORCHTITAN_ROOT`
+  - `TORCHTITAN_ROOT`: Root directory of torchtitan in your local filesystem
 - Optional:
-  - `RECIPE`: rowwise|tensorwise. defaults to tensorwise.
-  - `BATCH_SIZE`: defaults to 1.
-  - `STEPS`: defaults to 100.
+  - `FLOAT8_RECIPE_WITH_BEST_SETTINGS`: "rowwise" or "tensorwise". Applies float8 training with the specified scaling recipe, as well as additional training configs which are optimal for that scaling recipe. See `float8_training_benchmark.sh` for more details.
+  - `BATCH_SIZE`: Defaults to 1.
+  - `STEPS`: Defaults to 100.

 **NOTE**: `torch.compile` and FSDP2 are always used. Other forms of parallelism supported in torchtitan are not yet supported in this script.

benchmarks/quantized_training/pretrain_llama2.py

Lines changed: 6 additions & 7 deletions

@@ -22,14 +22,13 @@
 from torch.utils.checkpoint import checkpoint
 from tqdm import tqdm

-from torchao import quantize_
+from torchao import optim, quantize_
 from torchao._models.llama.model import (
     ModelArgs,
     RMSNorm,
     Transformer,
     transformer_configs,
 )
-from torchao.prototype import low_bit_optim
 from torchao.prototype.quantized_training import (
     bitnet_training,
     int8_mixed_precision_training,
@@ -190,10 +189,10 @@ def insert_rmsnorm(module: torch.nn.Module):
 print(f"No. of buffers: {sum(p.numel() for p in model.buffers()):,}")
 torch.cuda.reset_peak_memory_stats()  # don't count memory occupied by unquantized weights

-# only use optimizers from torchao.prototype.low_bit_optim to support quantized training
+# only use optimizers from torchao.optim to support quantized training
 if args.optim == "AdamW":
     args.optim = "_AdamW"
-optim = getattr(low_bit_optim, args.optim)(
+optimizer = getattr(optim, args.optim)(
     model.parameters(),
     lr=args.lr,
     weight_decay=args.weight_decay,
@@ -228,15 +227,15 @@ def insert_rmsnorm(module: torch.nn.Module):
 if step % args.log_interval == 0:
     log_dict = dict(
         loss=loss.item(),
-        lr=optim.param_groups[0]["lr"],
+        lr=optimizer.param_groups[0]["lr"],
         max_memory_allocated=torch.cuda.max_memory_allocated() / 1e9,
         max_memory_reserved=torch.cuda.max_memory_reserved() / 1e9,
     )
     run.log(log_dict, step=step)
     pbar.set_postfix(loss=log_dict["loss"])

-optim.step()
-optim.zero_grad()
+optimizer.step()
+optimizer.zero_grad()

 step += 1
 pbar.update()
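
Besides switching the import, this hunk renames the local variable from `optim` to `optimizer` so it no longer shadows the imported `torchao.optim` module. A minimal sketch of the resulting pattern; the tiny model and hyperparameters are placeholders, and the `"_AdamW"` remap follows the script above.

```python
import torch
from torchao import optim  # module name stays available for getattr lookups

model = torch.nn.Linear(256, 256).cuda()  # placeholder model

optim_name = "_AdamW"  # the script maps the user-facing "AdamW" choice to "_AdamW"
optimizer = getattr(optim, optim_name)(model.parameters(), lr=3e-4, weight_decay=0.1)

loss = model(torch.randn(8, 256, device="cuda")).sum()
loss.backward()
optimizer.step()
optimizer.zero_grad()
```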

docs/source/api_ref_dtypes.rst

Lines changed: 1 addition & 0 deletions

@@ -28,6 +28,7 @@ Layouts and Tensor Subclasses
     MarlinQQQLayout
     Int4CPULayout
     CutlassInt4PackedLayout
+    CutlassSemiSparseLayout

 Quantization techniques
 -----------------------
