Skip to content

Commit d023076

Browse files
committed
[inference] add smoothquant llama (hpcaitech#4861)
* add smoothquant llama * fix attention accuracy * fix accuracy * add kv cache and save pretrained * refactor example * delete smooth * refactor code
1 parent c93bc2d commit d023076

File tree

12 files changed

+2166
-145
lines changed

12 files changed

+2166
-145
lines changed

colossalai/inference/quant/smoothquant/__init__.py

Whitespace-only changes.
Lines changed: 53 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,53 @@
1+
# Adapted from AutoGPTQ: https://github.com/PanQiWei/AutoGPTQ
2+
3+
import functools
4+
5+
import torch
6+
import torch.nn as nn
7+
from datasets import load_dataset
8+
from tqdm import tqdm
9+
10+
11+
def get_act_scales(model, tokenizer, dataset_path, num_samples=512, seq_len=512):
    """Collect per-channel max absolute input activations of every ``nn.Linear``.

    A forward hook is attached to each ``nn.Linear`` found by
    ``model.named_modules()``. Calibration texts are then tokenized and fed
    through the model one at a time, and for each hooked layer the running
    element-wise maximum of ``abs(input)`` over the last dimension is kept.

    Args:
        model: Module to calibrate. It is switched to eval mode, and inputs
            are moved to the device of its first parameter.
        tokenizer: Callable invoked as
            ``tokenizer(text, return_tensors="pt", max_length=seq_len,
            truncation=True)`` whose result exposes ``.input_ids``.
        dataset_path: Passed as ``data_files`` to
            ``load_dataset("json", ...)``. Texts are read from
            ``dataset["train"]["rows"][0][i]["row"]["text"]``.
        num_samples: Number of texts to run through the model.
        seq_len: Truncation length handed to the tokenizer.

    Returns:
        dict mapping each ``nn.Linear`` module name to a 1-D float CPU tensor
        holding the maximum absolute input value seen per channel.

    Raises:
        IndexError: If the dataset holds fewer than ``num_samples`` texts.
    """
    model.eval()
    device = next(model.parameters()).device
    act_scales = {}

    def stat_tensor(name, tensor):
        hidden_dim = tensor.shape[-1]
        # reshape (rather than view) also accepts non-contiguous activations.
        flat = tensor.detach().reshape(-1, hidden_dim).abs()
        incoming_max = torch.max(flat, dim=0)[0].float().cpu()
        if name in act_scales:
            act_scales[name] = torch.max(act_scales[name], incoming_max)
        else:
            act_scales[name] = incoming_max

    def stat_input_hook(m, x, y, name):
        # Forward hooks receive the positional inputs as a tuple; only the
        # first positional input is measured.
        if isinstance(x, tuple):
            x = x[0]
        stat_tensor(name, x)

    # Load the data before touching the model so that a loading failure
    # cannot leave hooks attached.
    dataset = load_dataset("json", data_files=dataset_path)
    dataset = dataset.shuffle(seed=42)
    # NOTE(review): shuffle() permutes top-level records, while the samples
    # used below all live inside the "rows" list of record 0 -- confirm
    # whether randomizing the sample order was intended.
    # Resolve the sample list once instead of on every loop iteration.
    rows = dataset["train"]["rows"][0]
    if num_samples > len(rows):
        raise IndexError(
            f"num_samples={num_samples} exceeds the {len(rows)} samples available in {dataset_path!r}"
        )

    hooks = [
        m.register_forward_hook(functools.partial(stat_input_hook, name=name))
        for name, m in model.named_modules()
        if isinstance(m, nn.Linear)
    ]

    try:
        # Calibration only observes activations; no autograd graph is needed.
        with torch.no_grad():
            for i in tqdm(range(num_samples)):
                input_ids = tokenizer(
                    rows[i]["row"]["text"],
                    return_tensors="pt",
                    max_length=seq_len,
                    truncation=True,
                ).input_ids.to(device)
                model(input_ids)
    finally:
        # Always detach the hooks, even when a forward pass fails, so the
        # caller's model is left unmodified.
        for h in hooks:
            h.remove()

    return act_scales

0 commit comments

Comments
 (0)