import types
import warnings
from abc import abstractmethod
+from functools import partial
from os.path import isdir, isfile, join
from typing import Dict, List, Optional, Union

import accelerate
+import numpy as np
import torch
import torch.nn as nn
import transformers
from safetensors.torch import save_file as safe_save
from torch import device
+from tqdm import tqdm
from transformers import AutoConfig, AutoModelForCausalLM, PreTrainedModel
from transformers.modeling_utils import no_init_weights
from transformers.utils.generic import ContextManagers
from transformers.utils.hub import PushToHubMixin, cached_file

-from ....tensor_parallel.batch_infer_state import BatchInferState
-from ....tensor_parallel.kvcache_manager import MemoryManager
+from colossalai.inference.tensor_parallel.batch_infer_state import BatchInferState
+from colossalai.inference.tensor_parallel.kvcache_manager import MemoryManager

CPU = device("cpu")
-CUDA_0 = device("cuda:0")

SUPPORTED_MODELS = ["llama"]


-def get_module_by_name_suffix(model, module_name: str):
-    for name, module in model.named_modules():
-        if name.endswith(module_name):
-            return module
-
-
-def simple_dispatch_model(model, device_map):
-    from accelerate.hooks import AlignDevicesHook, add_hook_to_module
-
-    if "" in device_map:
-        d = device_map[""]
-        model = model.to(torch.device(d))
-        model.hf_device_map = device_map
-        return model
-
-    tied_params = accelerate.utils.modeling.find_tied_parameters(model)
-    if set(device_map.values()) == {"cpu"} or set(device_map.values()) == {"cpu", "disk"}:
-        main_device = "cpu"
-    else:
-        main_device = [d for d in device_map.values() if d not in ["cpu", "disk"]][0]
-
-    cpu_offload_group = [(n, d) for n, d in device_map.items() if d == "cpu"]
-    prev_hook = None
-    for idx, (n, d) in enumerate(cpu_offload_group):
-        m = get_module_by_name_suffix(model, n)
-        _, prev_hook = accelerate.cpu_offload_with_hook(m, execution_device=main_device, prev_module_hook=prev_hook)
-    # set first cpu offload module's prev_module_hook to the last cpu offload module's hook
-    if len(cpu_offload_group) > 1:
-        get_module_by_name_suffix(model, cpu_offload_group[0][0])._hf_hook.prev_module_hook = prev_hook
-
-    for n, d in device_map.items():
-        m = get_module_by_name_suffix(model, n)
-        if d != "cpu":
-            d = torch.device(d)
-            hook = AlignDevicesHook(d, io_same_device=True, place_submodules=True)
-            add_hook_to_module(m, hook)
-    accelerate.utils.modeling.retie_parameters(model, tied_params)
-    model.hf_device_map = device_map
-
-    return model
-
-
class BaseSmoothForCausalLM(nn.Module, PushToHubMixin):
    layer_type: str = None

@@ -132,6 +92,7 @@ def init_batch_state(self, max_output_len=256, **kwargs):
        batch_infer_state.past_key_values_len = 0
        batch_infer_state.is_context_stage = True
        batch_infer_state.set_cache_manager(self.cache_manager)
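+        # release all previously allocated KV-cache blocks so the new batch starts from an empty cache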
+        batch_infer_state.cache_manager.free_all()
        return batch_infer_state

    @abstractmethod
@@ -157,15 +118,79 @@ def generate(self, **kwargs):
        if self.config.model_type == "llama":
            setattr(self.model.model, "infer_state", batch_infer_state)

-        batch_infer_state.is_context_stage = True
-
        with torch.inference_mode():
            return self.model.generate(**kwargs)

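+    # A minimal call sketch, assuming generate() reads the batch state from its
+    # kwargs (the kwarg name below is illustrative, not confirmed by this diff):
+    #   state = self.init_batch_state(max_output_len=128, input_ids=input_ids)
+    #   output = self.generate(input_ids=input_ids, batch_infer_state=state)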
    def prepare_inputs_for_generation(self, *args, **kwargs):
        """shortcut for model.prepare_inputs_for_generation"""
        return self.model.prepare_inputs_for_generation(*args, **kwargs)

+    def collect_act_scales(self, model, tokenizer, dataset, device, num_samples=512, seq_len=512):
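+        # calibration loop: each forward pass feeds the hooks registered by the caller;
+        # note that num_samples is currently unused and the whole dataset is iterated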
+        for text in tqdm(dataset):
+            input_ids = tokenizer(text, return_tensors="pt", max_length=seq_len, truncation=True).input_ids.to(device)
+            model(input_ids)
+
+    def collect_act_dict(self, model, tokenizer, dataset, act_dict, device, num_samples=512, seq_len=512):
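+        # same calibration loop, but also reports the running mean of the per-module
+        # input scales accumulated into act_dict by externally registered hooks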
+        pbar = tqdm(dataset)
+        for text in pbar:
+            input_ids = tokenizer(text, return_tensors="pt", max_length=seq_len, truncation=True).input_ids.to(device)
+            model(input_ids)
+            mean_scale = np.mean([v["input"] for v in act_dict.values()])
+            pbar.set_description(f"Mean input scale: {mean_scale:.2f}")
+
+    def get_act_scales(self, model, tokenizer, dataset, num_samples=512, seq_len=512):
+        model.eval()
+        device = next(model.parameters()).device
+        act_scales = {}
+
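+        # keep a running per-channel maximum of |activation| for each hooked Linear input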
+        def stat_tensor(name, tensor):
+            hidden_dim = tensor.shape[-1]
+            tensor = tensor.view(-1, hidden_dim).abs().detach()
+            coming_max = torch.max(tensor, dim=0)[0].float().cpu()
+            if name in act_scales:
+                act_scales[name] = torch.max(act_scales[name], coming_max)
+            else:
+                act_scales[name] = coming_max
+
+        def stat_input_hook(m, x, y, name):
+            if isinstance(x, tuple):
+                x = x[0]
+            stat_tensor(name, x)
+
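+        # hook the input of every nn.Linear so the calibration forwards populate act_scales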
+        hooks = []
+        for name, m in model.named_modules():
+            if isinstance(m, nn.Linear):
+                hooks.append(m.register_forward_hook(partial(stat_input_hook, name=name)))
+
+        self.collect_act_scales(model, tokenizer, dataset, device, num_samples, seq_len)
+
+        for h in hooks:
+            h.remove()
+
+        return act_scales
+
+    @torch.no_grad()
+    def smooth_ln_fcs(self, ln, fcs, act_scales, alpha=0.5):
+        if not isinstance(fcs, list):
+            fcs = [fcs]
+        for fc in fcs:
+            assert isinstance(fc, nn.Linear)
+            assert ln.weight.numel() == fc.in_features == act_scales.numel()
+
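+        # SmoothQuant: per input channel j, s_j = max|X_j|^alpha / max|W_j|^(1 - alpha);
+        # dividing the preceding norm weight by s and multiplying each fc weight by s
+        # preserves the layer output while flattening activation outliers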
+        device, dtype = fcs[0].weight.device, fcs[0].weight.dtype
+        act_scales = act_scales.to(device=device, dtype=dtype)
+        weight_scales = torch.cat([fc.weight.abs().max(dim=0, keepdim=True)[0] for fc in fcs], dim=0)
+        weight_scales = weight_scales.max(dim=0)[0].clamp(min=1e-5)
+
+        scales = (act_scales.pow(alpha) / weight_scales.pow(1 - alpha)).clamp(min=1e-5).to(device).to(dtype)
+
+        ln.weight.div_(scales)
+        if getattr(ln, "bias", None) is not None:
+            ln.bias.div_(scales)
+
+        for fc in fcs:
+            fc.weight.mul_(scales.view(1, -1))
+
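+    # A minimal usage sketch, assuming the Hugging Face LlamaForCausalLM layout
+    # (model.model.layers, self_attn.q_proj, input_layernorm); the act_scales keys
+    # follow the named_modules() names recorded in get_act_scales:
+    #
+    #   act_scales = self.get_act_scales(model, tokenizer, calib_dataset)
+    #   for i, layer in enumerate(model.model.layers):
+    #       qkv = [layer.self_attn.q_proj, layer.self_attn.k_proj, layer.self_attn.v_proj]
+    #       self.smooth_ln_fcs(layer.input_layernorm, qkv,
+    #                          act_scales[f"model.layers.{i}.self_attn.q_proj"], alpha=0.5)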
    def save_quantized(
        self,
        save_dir: str,