Commit c82b177
[inference] add smooth function and delete useless code for smoothquant (#4895)
* add smooth function and delete useless code
* update datasets
* remove duplicate import
* delete useless file
1 parent: 82afdc7

File tree

3 files changed (+102, -121 lines)

colossalai/inference/quant/smoothquant/calibration.py

Lines changed: 0 additions & 53 deletions
This file was deleted.

colossalai/inference/quant/smoothquant/models/base_model.py

Lines changed: 72 additions & 47 deletions
@@ -4,71 +4,31 @@
 import types
 import warnings
 from abc import abstractmethod
+from functools import partial
 from os.path import isdir, isfile, join
 from typing import Dict, List, Optional, Union

 import accelerate
+import numpy as np
 import torch
 import torch.nn as nn
 import transformers
 from safetensors.torch import save_file as safe_save
 from torch import device
+from tqdm import tqdm
 from transformers import AutoConfig, AutoModelForCausalLM, PreTrainedModel
 from transformers.modeling_utils import no_init_weights
 from transformers.utils.generic import ContextManagers
 from transformers.utils.hub import PushToHubMixin, cached_file

-from ....tensor_parallel.batch_infer_state import BatchInferState
-from ....tensor_parallel.kvcache_manager import MemoryManager
+from colossalai.inference.tensor_parallel.batch_infer_state import BatchInferState
+from colossalai.inference.tensor_parallel.kvcache_manager import MemoryManager

 CPU = device("cpu")
-CUDA_0 = device("cuda:0")

 SUPPORTED_MODELS = ["llama"]


-def get_module_by_name_suffix(model, module_name: str):
-    for name, module in model.named_modules():
-        if name.endswith(module_name):
-            return module
-
-
-def simple_dispatch_model(model, device_map):
-    from accelerate.hooks import AlignDevicesHook, add_hook_to_module
-
-    if "" in device_map:
-        d = device_map[""]
-        model = model.to(torch.device(d))
-        model.hf_device_map = device_map
-        return model
-
-    tied_params = accelerate.utils.modeling.find_tied_parameters(model)
-    if set(device_map.values()) == {"cpu"} or set(device_map.values()) == {"cpu", "disk"}:
-        main_device = "cpu"
-    else:
-        main_device = [d for d in device_map.values() if d not in ["cpu", "disk"]][0]
-
-    cpu_offload_group = [(n, d) for n, d in device_map.items() if d == "cpu"]
-    prev_hook = None
-    for idx, (n, d) in enumerate(cpu_offload_group):
-        m = get_module_by_name_suffix(model, n)
-        _, prev_hook = accelerate.cpu_offload_with_hook(m, execution_device=main_device, prev_module_hook=prev_hook)
-    # set first cpu offload module's prev_module_hook to the last cpu offload module's hook
-    if len(cpu_offload_group) > 1:
-        get_module_by_name_suffix(model, cpu_offload_group[0][0])._hf_hook.prev_module_hook = prev_hook
-
-    for n, d in device_map.items():
-        m = get_module_by_name_suffix(model, n)
-        if d != "cpu":
-            d = torch.device(d)
-            hook = AlignDevicesHook(d, io_same_device=True, place_submodules=True)
-            add_hook_to_module(m, hook)
-    accelerate.utils.modeling.retie_parameters(model, tied_params)
-    model.hf_device_map = device_map
-
-    return model
-
-
 class BaseSmoothForCausalLM(nn.Module, PushToHubMixin):
     layer_type: str = None

@@ -132,6 +92,7 @@ def init_batch_state(self, max_output_len=256, **kwargs):
         batch_infer_state.past_key_values_len = 0
         batch_infer_state.is_context_stage = True
         batch_infer_state.set_cache_manager(self.cache_manager)
+        batch_infer_state.cache_manager.free_all()
         return batch_infer_state

     @abstractmethod
@@ -157,15 +118,79 @@ def generate(self, **kwargs):
         if self.config.model_type == "llama":
             setattr(self.model.model, "infer_state", batch_infer_state)

-        batch_infer_state.is_context_stage = True
-
         with torch.inference_mode():
             return self.model.generate(**kwargs)

     def prepare_inputs_for_generation(self, *args, **kwargs):
         """shortcut for model.prepare_inputs_for_generation"""
         return self.model.prepare_inputs_for_generation(*args, **kwargs)

+    def collect_act_scales(self, model, tokenizer, dataset, device, num_samples=512, seq_len=512):
+        for text in tqdm(dataset):
+            input_ids = tokenizer(text, return_tensors="pt", max_length=seq_len, truncation=True).input_ids.to(device)
+            model(input_ids)
+
+    def collect_act_dict(self, model, tokenizer, dataset, act_dict, device, num_samples=512, seq_len=512):
+        pbar = tqdm(dataset)
+        for text in pbar:
+            input_ids = tokenizer(text, return_tensors="pt", max_length=seq_len, truncation=True).input_ids.to(device)
+            model(input_ids)
+            mean_scale = np.mean([v["input"] for v in act_dict.values()])
+            pbar.set_description(f"Mean input scale: {mean_scale:.2f}")
+
+    def get_act_scales(self, model, tokenizer, dataset, num_samples=512, seq_len=512):
+        model.eval()
+        device = next(model.parameters()).device
+        act_scales = {}
+
+        def stat_tensor(name, tensor):
+            hidden_dim = tensor.shape[-1]
+            tensor = tensor.view(-1, hidden_dim).abs().detach()
+            coming_max = torch.max(tensor, dim=0)[0].float().cpu()
+            if name in act_scales:
+                act_scales[name] = torch.max(act_scales[name], coming_max)
+            else:
+                act_scales[name] = coming_max
+
+        def stat_input_hook(m, x, y, name):
+            if isinstance(x, tuple):
+                x = x[0]
+            stat_tensor(name, x)
+
+        hooks = []
+        for name, m in model.named_modules():
+            if isinstance(m, nn.Linear):
+                hooks.append(m.register_forward_hook(partial(stat_input_hook, name=name)))
+
+        self.collect_act_scales(model, tokenizer, dataset, device, num_samples, seq_len)
+
+        for h in hooks:
+            h.remove()
+
+        return act_scales
+
+    @torch.no_grad()
+    def smooth_ln_fcs(self, ln, fcs, act_scales, alpha=0.5):
+        if not isinstance(fcs, list):
+            fcs = [fcs]
+        for fc in fcs:
+            assert isinstance(fc, nn.Linear)
+            assert ln.weight.numel() == fc.in_features == act_scales.numel()
+
+        device, dtype = fcs[0].weight.device, fcs[0].weight.dtype
+        act_scales = act_scales.to(device=device, dtype=dtype)
+        weight_scales = torch.cat([fc.weight.abs().max(dim=0, keepdim=True)[0] for fc in fcs], dim=0)
+        weight_scales = weight_scales.max(dim=0)[0].clamp(min=1e-5)
+
+        scales = (act_scales.pow(alpha) / weight_scales.pow(1 - alpha)).clamp(min=1e-5).to(device).to(dtype)
+
+        ln.weight.div_(scales)
+        if hasattr(ln, "bias"):
+            ln.bias.div_(scales)
+
+        for fc in fcs:
+            fc.weight.mul_(scales.view(1, -1))
+
     def save_quantized(
         self,
         save_dir: str,
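
Why this works: the smooth_ln_fcs method added above implements the SmoothQuant migration step. Per-channel scales s = act_scales^alpha / weight_scales^(1 - alpha) are divided out of the preceding norm and multiplied into the following linear weights, so the composed function is unchanged while activation outliers shrink by 1/s. A minimal sketch of that identity on toy tensors (not part of this commit; it uses nn.LayerNorm, whereas the commit applies the same fold to Llama's norm/QKV pairs):

import torch
import torch.nn as nn

torch.manual_seed(0)
hidden = 16
ln = nn.LayerNorm(hidden)
fc = nn.Linear(hidden, 32)
x = torch.randn(4, hidden)

ref = fc(ln(x))  # output before smoothing

# Per-channel scales, mirroring smooth_ln_fcs with alpha = 0.5.
act_scales = x.abs().amax(dim=0)  # stand-in for calibrated activation maxima
weight_scales = fc.weight.abs().amax(dim=0).clamp(min=1e-5)
alpha = 0.5
scales = (act_scales.pow(alpha) / weight_scales.pow(1 - alpha)).clamp(min=1e-5)

with torch.no_grad():
    ln.weight.div_(scales)   # norm output shrinks by 1/s per channel...
    ln.bias.div_(scales)
    fc.weight.mul_(scales.view(1, -1))  # ...and the linear layer scales it back

out = fc(ln(x))
assert torch.allclose(ref, out, atol=1e-5)  # smoothing preserves the output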

colossalai/inference/quant/smoothquant/models/llama.py

Lines changed: 30 additions & 21 deletions
@@ -7,14 +7,10 @@
 from functools import partial
 from typing import List, Optional, Tuple, Union

-import numpy as np
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-from datasets import load_dataset
-from torch import nn
 from torch_int.nn.bmm import BMM_S8T_S8N_F32T, BMM_S8T_S8N_S8T
-from tqdm import tqdm
 from transformers import PreTrainedModel
 from transformers.modeling_outputs import BaseModelOutputWithPast
 from transformers.models.llama.configuration_llama import LlamaConfig
@@ -756,15 +752,14 @@ class SmoothLlamaForCausalLM(BaseSmoothForCausalLM):
     def __init__(self, model: PreTrainedModel, quantized: bool = False):
         super().__init__(model, quantized)

-    def quantized(
+    def get_act_dict(
         self,
         tokenizer,
-        dataset_path,
+        dataset,
         num_samples=512,
         seq_len=512,
     ):
         llama_model = self.model
-        llama_config = llama_model.config

         llama_model.eval()
         device = next(llama_model.parameters()).device
@@ -798,23 +793,37 @@ def stat_io_hook(m, x, y, name):
             if isinstance(m, torch.nn.Linear):
                 hooks.append(m.register_forward_hook(partial(stat_io_hook, name=name)))

-        print("Collecting activation scales...")
-        pbar = tqdm(range(num_samples))
-        dataset = load_dataset("json", data_files=dataset_path, split="train")
-        dataset = dataset.shuffle(seed=42)
-        for i in pbar:
-            input_ids = tokenizer(
-                dataset["rows"][0][i]["row"]["text"],
-                return_tensors="pt",
-                max_length=seq_len,
-                truncation=True,
-            ).input_ids.to(device)
-            llama_model(input_ids)
-            mean_scale = np.mean([v["input"] for v in act_dict.values()])
-            pbar.set_description(f"Mean input scale: {mean_scale:.2f}")
+        self.collect_act_dict(llama_model, tokenizer, dataset, act_dict, device, num_samples, seq_len)
+
         for hook in hooks:
             hook.remove()
+        return act_dict
+
+    def smooth_fn(self, scales, alpha=0.5):
+        model = self.model
+        for name, module in model.named_modules():
+            if isinstance(module, LlamaDecoderLayer):
+                attn_ln = module.input_layernorm
+                qkv = [module.self_attn.q_proj, module.self_attn.k_proj, module.self_attn.v_proj]
+                qkv_input_scales = scales[name + ".self_attn.q_proj"]
+                self.smooth_ln_fcs(attn_ln, qkv, qkv_input_scales, alpha)
+
+    def quantized(
+        self,
+        tokenizer,
+        dataset,
+        num_samples=512,
+        seq_len=512,
+        alpha=0.5,
+    ):
+        llama_model = self.model
+        llama_config = llama_model.config
+
+        act_scales = self.get_act_scales(llama_model, tokenizer, dataset, num_samples, seq_len)
+
+        self.smooth_fn(act_scales, alpha)

+        act_dict = self.get_act_dict(tokenizer, dataset, num_samples, seq_len)
         decoder_layer_scales = []

         for idx in range(llama_config.num_hidden_layers):
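
End-to-end, the reworked API takes a tokenizer plus an iterable of calibration texts instead of a dataset_path. A hedged usage sketch (the checkpoint path and calibration corpus are placeholders, this driver is not part of the commit, and it assumes the class is importable from its module path):

from transformers import LlamaForCausalLM, LlamaTokenizer

from colossalai.inference.quant.smoothquant.models.llama import SmoothLlamaForCausalLM

model_path = "path/to/llama"  # placeholder checkpoint path
tokenizer = LlamaTokenizer.from_pretrained(model_path)
base = LlamaForCausalLM.from_pretrained(model_path)

smooth_model = SmoothLlamaForCausalLM(base)

# Placeholder calibration corpus; any iterable of strings works,
# since collect_act_scales/collect_act_dict simply loop over texts.
calib_texts = ["an example calibration sentence"] * 512

# quantized() now (1) collects activation scales via get_act_scales,
# (2) smooths the attention norm/QKV pairs with smooth_fn, and
# (3) re-runs calibration through get_act_dict to build per-layer scales.
smooth_model.quantized(tokenizer, calib_texts, num_samples=512, seq_len=512, alpha=0.5)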
