
Commit c4bb9e7

update readme. small fixes
1 parent cbc3f05 commit c4bb9e7

File tree

3 files changed: +15 -11 lines changed

benchmarks/benchmark_low_bit_adam.py

Lines changed: 4 additions & 1 deletion

@@ -11,6 +11,7 @@
 # To enable cosine learning rate scheduler, set --cosine_lr_scheduler
 
 import argparse
+import datetime
 import math
 from contextlib import nullcontext
 from functools import partial
@@ -175,6 +176,7 @@ def evaluate_model(model, args):
 
 grad_scaler = torch.amp.GradScaler("cuda", enabled=args.amp == "fp16")
 
+start_time = datetime.datetime.now()
 step = 0
 for epoch_idx in range(args.n_epochs):
     model.train()
@@ -214,4 +216,5 @@ def evaluate_model(model, args):
     print(f"Epoch {epoch_idx + 1}/{args.n_epochs}: val_acc={val_acc.item() * 100:.2f}")
     logger.log(dict(val_acc=val_acc), step=step)
 
-print(f"Max memory allocated: {torch.cuda.max_memory_allocated() / (1 << 30):.2f} GB")
+print(f"Time taken: {(datetime.datetime.now() - start_time)}")
+print(f"Max used: {torch.cuda.max_memory_allocated() / 1e9:.02f} GB")

torchao/prototype/low_bit_optim/README.md

Lines changed: 9 additions & 4 deletions

@@ -31,12 +31,17 @@ NOTE:
 
 Benchmark script for fine-tuning a [timm](https://github.com/huggingface/pytorch-image-models) model on [resisc45](https://huggingface.co/datasets/timm/resisc45) dataset is available at [benchmarks/benchmark_low_bit_adam.py](../../../benchmarks/benchmark_low_bit_adam.py).
 
-Results for fine-tuning ViT-B with BF16 AMP, on 4070Ti SUPER:
+Results for fine-tuning ViT-H (630M params) with BF16 AMP, batch size 4, 1 epoch, on 4070Ti SUPER:
 
-TODO: update this table
+Adam impl  | max memory (GB) | time taken | accuracy
+-----------|-----------------|------------|----------
+PyTorch    | 12.98           | 10m 08s    | 87.70
+bnb 8-bit  | 8.31            | 8m 38s     | 86.22
+ao 8-bit   | 8.32            | 10m 54s    | 86.67
+lpmm 4-bit | 7.72            | 7m 48s     | 84.70
+ao 4-bit   | 7.72            | 9m 17s     | 85.60
 
-Adam impl | max memory (GB) | training time | accuracy
-----------|-----------------|---------------|----------
+NOTE: time taken includes validation time, and compile time for torchao optimizers.
 
 ## Credits
 
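
As a rough usage sketch to go with the table above (hedged: it assumes the Adam8bit class exported by torchao.prototype.low_bit_optim and uses a toy model in place of the timm ViT from the benchmark), swapping in a low-bit optimizer is intended to look like a drop-in replacement for torch.optim.Adam:

import torch
from torch import nn
from torchao.prototype.low_bit_optim import Adam8bit  # assumed export of this prototype module

# toy stand-in for the timm ViT used in the benchmark; sizes chosen so parameter
# tensors divide evenly into typical quantization block sizes
model = nn.Linear(256, 256).cuda()
optim = Adam8bit(model.parameters(), lr=1e-4)  # used like torch.optim.Adam, but with 8-bit state

for _ in range(10):
    x = torch.randn(4, 256, device="cuda")
    y = torch.randint(0, 256, (4,), device="cuda")
    loss = nn.functional.cross_entropy(model(x), y)
    loss.backward()
    optim.step()
    optim.zero_grad()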

torchao/prototype/low_bit_optim/subclass_4bit.py

Lines changed: 2 additions & 6 deletions

@@ -49,9 +49,7 @@ def quantize_4bit_with_qmap(input: Tensor, qmap: Tensor, block_size: int, implem
         raise ValueError(f"Unsupported implementation={implementation}")
 
     # packing
-    codes1, codes2 = codes.chunk(2, 0)
-    codes = (codes1 << 4) | codes2
-
+    codes = (codes[::2] << 4) | codes[1::2]
     return codes, scale
 
 
@@ -90,9 +88,7 @@ def __tensor_unflatten__(cls, tensor_data_dict, tensor_attributes, outer_size=No
 
     def dequantize(self, output_dtype=None):
         # unpack
-        codes1 = self.codes >> 4
-        codes2 = self.codes & 0b1111
-        codes = torch.cat([codes1, codes2], 0)
+        codes = torch.stack([self.codes >> 4, self.codes & 0b1111], dim=-1)
 
         # torch.compile() cannot use uint8 as index
         float_data = self.qmap[codes.int()]
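
A small round-trip sketch of the interleaved 4-bit packing used above; this is just the bit manipulation on a toy uint8 tensor, not the tensor-subclass API.

import torch

# 8 values in [0, 16) stand in for the quantized 4-bit codes
codes = torch.randint(0, 16, (8,), dtype=torch.uint8)

# pack: even-indexed codes go to the high nibble, odd-indexed codes to the low nibble
packed = (codes[::2] << 4) | codes[1::2]  # shape (4,), still uint8

# unpack: pull both nibbles back out and re-interleave them with stack + flatten
unpacked = torch.stack([packed >> 4, packed & 0b1111], dim=-1).flatten()

assert torch.equal(unpacked, codes)

Compared with the previous chunk-based layout, each byte now holds two adjacent codes, so packing and unpacking no longer need to split or concatenate halves of the tensor along dim 0.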
