Commit e3dc315

CsRic, csric, and ver217 authored
state_dict sending adapts to new unwrap function (hpcaitech#20)
* prompt example
* prompt load csv data
* remove legacy try
* maker models require_grad set to False
* working on zero redundancy update
* mmmt_prompt example; naive strategy requires_grad state_dict & sharding; maker model requires_no_grad.
* remove legacy examples
* remove legacy examples
* remove replay buffer tp state. bad design
* opt benchmark
* better script
* nothing
* [chat] strategy refactor unwrap model
* [chat] strategy refactor save model
* [chat] add docstr
* [chat] refactor trainer save model
* [chat] fix strategy typing
* [chat] refactor trainer save model
* [chat] update readme
* [chat] fix unit test
* working on lora reconstruction
* state_dict sending adapts to new unwrap function
* remove comments

---------

Co-authored-by: csric <[email protected]>
Co-authored-by: ver217 <[email protected]>
1 parent 4843df6 commit e3dc315

File tree

9 files changed: +115 -65 lines changed

applications/Chat/coati/models/generation.py

Lines changed: 1 addition & 0 deletions
@@ -77,6 +77,7 @@ def sample(model: nn.Module,
         input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
         if update_model_kwargs_fn is not None:
             model_kwargs = update_model_kwargs_fn(outputs, model_kwargs)
+
         # if eos_token was found in one sentence, set sentence to finished
         if eos_token_id is not None:
             unfinished_sequences = unfinished_sequences.mul((next_tokens != eos_token_id).long())

applications/Chat/coati/ray/detached_trainer_ppo.py

Lines changed: 3 additions & 24 deletions
@@ -120,13 +120,14 @@ def _update_remote_makers(self, fully_update: bool = False, **config):
         ray.get(tasks)
         # sending loop
         tasks = []
-        for state_dict_shard in self._get_model_state_dict_shard(self.strategy._unwrap_model(self.actor), **config):
+
+        for state_dict_shard in self._get_model_state_dict_shard(self.actor, **config):
             for target_holder in self.target_holder_list:
                 tasks.append(
                     target_holder.update_experience_maker.remote(new_actor_state_dict=state_dict_shard,
                                                                  fully_update=fully_update))
         # sending loop
-        for state_dict_shard in self._get_model_state_dict_shard(self.strategy._unwrap_critic(self.critic), **config):
+        for state_dict_shard in self._get_model_state_dict_shard(self.critic, **config):
             for target_holder in self.target_holder_list:
                 tasks.append(
                     target_holder.update_experience_maker.remote(new_critic_state_dict=state_dict_shard,
@@ -176,28 +177,6 @@ def strategy_save_actor_optim(self, path: str, only_rank0: bool = False) -> None
     def strategy_save_critic_optim(self, path: str, only_rank0: bool = False) -> None:
         self.strategy.save_optimizer(self.critic_optim, path, only_rank0)

-    def _get_unwrapped_actor(self):
-        if False:
-            pass
-        elif isinstance(self.strategy, ColossalAIStrategy):
-            ret = Actor(self.strategy._unwrap_model(self.actor))
-            return ret
-        elif isinstance(self.strategy, DDPStrategy):
-            return Actor(self.strategy._unwrap_actor(self.actor))
-        elif isinstance(self.strategy, NaiveStrategy):
-            return self.actor
-
-    def _get_unwrapped_critic(self):
-        if False:
-            pass
-        elif isinstance(self.strategy, ColossalAIStrategy):
-            ret = self.strategy._unwrap_model(self.critic)
-            return ret
-        elif isinstance(self.strategy, DDPStrategy):
-            return self.critic.module
-        elif isinstance(self.strategy, NaiveStrategy):
-            return self.critic
-
     def _get_model_state_dict_shard(self, model: torch.nn.Module, **config):
         # try:
         #     self.strategy.merge_lora_weight(model)
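The sending loops above iterate over state-dict shards from the trainer-side models and dispatch one Ray task per shard per target holder, then block on all of them with ray.get. Below is a minimal, self-contained sketch of that pattern; the MakerHolder actor, the state_dict_shards helper, and the size budget are hypothetical stand-ins rather than the coati classes.

import ray
import torch.nn as nn


@ray.remote
class MakerHolder:
    # Hypothetical stand-in for an experience-maker holder: it receives partial
    # state dicts and loads them into its local model (strict=False accepts shards).
    def __init__(self):
        self.model = nn.Linear(4, 4)

    def update_experience_maker(self, new_actor_state_dict, fully_update=False):
        self.model.load_state_dict(new_actor_state_dict, strict=False)
        return fully_update


def state_dict_shards(model, max_numel_per_shard=8):
    # Naive sharding: group tensors until a parameter-count budget is reached.
    shard, numel = {}, 0
    for name, tensor in model.state_dict().items():
        shard[name] = tensor.cpu()
        numel += tensor.numel()
        if numel >= max_numel_per_shard:
            yield shard
            shard, numel = {}, 0
    if shard:
        yield shard


if __name__ == '__main__':
    ray.init()
    target_holders = [MakerHolder.remote() for _ in range(2)]
    trainer_actor = nn.Linear(4, 4)
    tasks = []
    # Mirrors the sending loop: every shard is pushed to every target holder.
    for shard in state_dict_shards(trainer_actor):
        for holder in target_holders:
            tasks.append(holder.update_experience_maker.remote(shard, fully_update=True))
    ray.get(tasks)
    ray.shutdown()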

applications/Chat/coati/ray/utils.py

Lines changed: 38 additions & 12 deletions
@@ -120,18 +120,6 @@ def set_dist_env(env_info: Dict[str, str]):
     os.environ['MASTER_ADDR'] = env_info['master_addr']


-def state_dict_to(state_dict: Dict[str, Any],
-                  dtype: torch.dtype = torch.float16,
-                  device: torch.device = torch.device('cpu')):
-    '''
-    keep state_dict intact
-    '''
-    new_state_dict = {}
-    for k, v in state_dict.items():
-        new_state_dict[k] = v.to(dtype=dtype, device=device)
-    return new_state_dict
-
-
 def get_model_numel(model: nn.Module) -> int:
     numel = sum(p.numel() for p in model.parameters())
     return numel
@@ -150,3 +138,41 @@ def get_receivers_per_sender(sender_idx: int, num_senders: int, num_receivers: i
     # a receiver may have more than one sender
     target_receivers.append(sender_idx % num_receivers)
     return target_receivers
+
+
+def state_dict_to(state_dict: Dict[str, Any],
+                  dtype: torch.dtype = torch.float16,
+                  device: torch.device = torch.device('cpu')):
+    '''
+    keep state_dict intact
+    '''
+    new_state_dict = OrderedDict()
+    for k, v in state_dict.items():
+        new_state_dict[k] = v.to(dtype=dtype, device=device)
+    return new_state_dict
+
+
+def state_dict_filter_lora(state_dict: Dict[str, Any], keep_non_lora=False):
+    '''
+    if keep_non_lora, also return non_lora state_dict
+    '''
+    state_dict_lora = OrderedDict()
+    state_dict_non_lora = OrderedDict()
+    for k, v in state_dict:
+        if 'lora_A' in k or 'lora_B' in k:
+            state_dict_lora[k] = v
+        elif keep_non_lora:
+            state_dict_non_lora[k] = v
+    if keep_non_lora:
+        return state_dict_lora, state_dict_non_lora
+    else:
+        return state_dict_lora
+
+
+def state_dict_lora_reconstruct(state_dict_lora: Dict[str, Any]):
+    '''
+    xxx.lora_A, xxx.lora_B -->> xxx.weight
+    '''
+    state_dict_reconstruct = OrderedDict()
+
+
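state_dict_filter_lora splits out the lora_A / lora_B tensors, and state_dict_lora_reconstruct is meant to fold them back into dense weight entries (its body is still a stub in this commit). The sketch below shows what such a merge typically looks like, using the standard LoRA update W + scaling * (B @ A); merge_lora_pairs and its arguments are illustrative names, not the repository's API.

from collections import OrderedDict
from typing import Any, Dict

import torch


def merge_lora_pairs(state_dict_lora: Dict[str, Any],
                     base_state_dict: Dict[str, Any],
                     scaling: float = 1.0) -> Dict[str, Any]:
    # Fold "<prefix>.lora_A" / "<prefix>.lora_B" back into "<prefix>.weight"
    # using the usual LoRA merge: W' = W + scaling * (B @ A).
    merged = OrderedDict(base_state_dict)
    for key, lora_a in state_dict_lora.items():
        if not key.endswith('lora_A'):
            continue
        prefix = key[:-len('lora_A')]
        lora_b = state_dict_lora[prefix + 'lora_B']
        weight_key = prefix + 'weight'
        merged[weight_key] = base_state_dict[weight_key] + scaling * (lora_b @ lora_a)
    return merged


if __name__ == '__main__':
    # Tiny smoke test with a 4x4 layer and rank-2 adapters.
    base = {'fc.weight': torch.zeros(4, 4)}
    lora = {'fc.lora_A': torch.randn(2, 4), 'fc.lora_B': torch.randn(4, 2)}
    print(merge_lora_pairs(lora, base)['fc.weight'].shape)  # torch.Size([4, 4])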

applications/Chat/coati/trainer/strategies/base.py

Lines changed: 4 additions & 4 deletions
@@ -104,10 +104,6 @@ def unwrap_model(model: nn.Module) -> nn.Module:
         """
         return get_base_model(model)

-    @staticmethod
-    def _unwrap_critic(critic: Critic) -> nn.Module:
-        return Strategy._unwrap_model(critic)
-
     @abstractmethod
     def save_model(self, model: nn.Module, path: str, only_rank0: bool = True) -> None:
         pass
@@ -134,3 +130,7 @@ def save_pretrained(self,
                         only_rank0: bool = True,
                         tokenizer: Optional[PreTrainedTokenizerBase] = None) -> None:
         pass
+
+    @abstractmethod
+    def get_model_state_dict_shard(self, model: nn.Module, **config):
+        pass
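With get_model_state_dict_shard now an abstract method on Strategy, every concrete strategy must expose a shard generator. A toy sketch of that contract (ToyStrategy is illustrative and skips the other abstract methods the real base class requires):

import torch.nn as nn


class ToyStrategy:
    # Illustrative only: real strategies inherit from coati's Strategy ABC.
    def get_model_state_dict_shard(self, model: nn.Module, **config):
        # Simplest possible "sharding": yield the whole state dict as one shard.
        yield model.state_dict()


for shard in ToyStrategy().get_model_state_dict_shard(nn.Linear(2, 2)):
    print(list(shard.keys()))  # ['weight', 'bias']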

applications/Chat/coati/trainer/strategies/colossalai.py

Lines changed: 12 additions & 12 deletions
@@ -171,18 +171,6 @@ def save_optimizer(self, optimizer: Optimizer, path: str, only_rank0: bool = Fal
                 f'Optimizer states are sharded when using ColossalAIStrategy. Only rank0 is not supported.')
         torch.save(optimizer.state_dict(), path)

-    def get_model_state_dict_shard(self, model: nn.Module, **config):
-        if self.stage != 3:
-            yield from super().get_model_state_dict_shard(model, **config)
-        else:
-            # unwrapped_model = self._unwrap_model(model)
-            # for module in unwrapped_model.modules():
-            #     if isinstance(module, LoraLinear):
-            #         module.merge_weights = True
-            #         module.eval()
-            model: ZeroDDP = model
-            yield from model.state_dict_shard(max_shard_size=1024, only_rank_0=False)
-
     def unwrap_model(self, model: nn.Module) -> nn.Module:
         base_model: Union[nn.Module, ZeroDDP] = get_base_model(model)
         if self.stage == 3:
@@ -198,3 +186,15 @@ def save_pretrained(self,
         if self.stage == 3:
             raise RuntimeError('ColossalAI strategy with stage-3 does not support save_pretrained() now')
         super().save_pretrained(model, path, only_rank0, tokenizer)
+
+    def get_model_state_dict_shard(self, model: nn.Module, **config):
+        if self.stage != 3:
+            yield from super().get_model_state_dict_shard(model, **config)
+        else:
+            # unwrapped_model = self._unwrap_model(model)
+            # for module in unwrapped_model.modules():
+            #     if isinstance(module, LoraLinear):
+            #         module.merge_weights = True
+            #         module.eval()
+            base_model: ZeroDDP = get_base_model(model)
+            yield from base_model.state_dict_shard(max_shard_size=1024, only_rank_0=False)

applications/Chat/coati/trainer/strategies/naive.py

Lines changed: 15 additions & 11 deletions
@@ -9,6 +9,9 @@
 import torch.optim as optim
 from coati.models.base import get_base_model
 from coati.replay_buffer import ReplayBuffer
+from coati.models.base import RewardModel
+from coati.models.lora import LoraLinear
+from coati.replay_buffer import ReplayBuffer
 from torch.optim import Optimizer
 from torch.utils.data import DataLoader
 from transformers.modeling_utils import PreTrainedModel
@@ -71,8 +74,20 @@ def load_optimizer(self, optimizer: Optimizer, path: str, map_location: Any = No
         state_dict = torch.load(path, map_location=map_location)
         optimizer.load_state_dict(state_dict)

+    def save_pretrained(self,
+                        model: nn.Module,
+                        path: str,
+                        only_rank0: bool = True,
+                        tokenizer: Optional[PreTrainedTokenizerBase] = None) -> None:
+        unwrapped_model = self.unwrap_model(model)
+        assert isinstance(unwrapped_model, PreTrainedModel)
+        unwrapped_model.save_pretrained(path)
+        if tokenizer is not None:
+            tokenizer.save_pretrained(path)
+
     def get_model_state_dict_shard(self, model: nn.Module, **config):
         # TODO: implement sharding on naive strategy
+        model = self.unwrap_model(model)
         if 'requires_grad_only' in config and config['requires_grad_only'] == True:
             state_dict = get_grad_required_state_dict(model)
         else:
@@ -111,14 +126,3 @@ def _try_init_dist(self, force: bool = False) -> None:
         except Exception as e:
             if force:
                 raise e
-
-    def save_pretrained(self,
-                        model: nn.Module,
-                        path: str,
-                        only_rank0: bool = True,
-                        tokenizer: Optional[PreTrainedTokenizerBase] = None) -> None:
-        unwrapped_model = self.unwrap_model(model)
-        assert isinstance(unwrapped_model, PreTrainedModel)
-        unwrapped_model.save_pretrained(path)
-        if tokenizer is not None:
-            tokenizer.save_pretrained(path)
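The requires_grad_only path above calls get_grad_required_state_dict, which is defined elsewhere in coati. A hedged sketch of that kind of filter is shown below; grad_required_state_dict is an assumed name rather than the repository's implementation, and it simply keeps the tensors backing trainable parameters (with LoRA fine-tuning, typically the lora_A / lora_B matrices).

import torch.nn as nn


def grad_required_state_dict(model: nn.Module):
    # Hypothetical helper: keep only tensors that back requires_grad parameters.
    return {name: param.data.cpu()
            for name, param in model.named_parameters()
            if param.requires_grad}


if __name__ == '__main__':
    layer = nn.Linear(4, 4)
    layer.bias.requires_grad_(False)
    print(list(grad_required_state_dict(layer).keys()))  # ['weight']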
Lines changed: 1 addition & 0 deletions
@@ -0,0 +1 @@
+logs/*
Lines changed: 39 additions & 0 deletions
@@ -0,0 +1,39 @@
+
+PROMPT_PATH=/home/lccsr/data3/awesome-chatgpt-prompts/prompts.csv
+
+num_trainers=4
+num_makers=4
+
+# "facebook/opt-2.7b"
+for pretrain in "facebook/opt-1.3b" "facebook/opt-6.7b" "facebook/opt-13b"
+do
+
+for experience_batch_size in 16 32 64
+do
+for train_batch_size in 16 32 64
+do
+for update_steps in 8 32 128
+do
+# set a big enough experience_steps for twice maker-update
+experience_steps=$((2*num_trainers*train_batch_size*update_steps/num_makers/experience_batch_size))
+
+config_string=${num_trainers}_${num_makers}_pretrain_${pretrain##*/}_experience_batch_size_${experience_batch_size}_train_batch_size_${train_batch_size}_update_steps_${update_steps}_experience_steps_${experience_steps}
+echo running: ${config_string}
+
+nohup python mmmt_prompt.py \
+    --prompt_path $PROMPT_PATH \
+    --trainer_strategy colossalai_gemini --maker_strategy naive \
+    --model 'opt' \
+    --pretrain $pretrain \
+    --critic_pretrain "facebook/opt-350m" \
+    --num_trainers $num_trainers \
+    --num_makers $num_makers \
+    --experience_steps $experience_steps \
+    --experience_batch_size $experience_batch_size \
+    --update_steps $update_steps \
+    --train_batch_size $train_batch_size \
+    --debug > logs/output_${config_string}.txt 2>&1
+done
+done
+done
+done
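To make the experience_steps formula concrete: with num_trainers=4, num_makers=4, train_batch_size=16, update_steps=8, and experience_batch_size=16, the script sets experience_steps = 2*4*16*8 / (4*16) = 16, which matches the comment's goal of generating enough experience to cover two rounds of maker updates.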

applications/Chat/examples/ray/mmmt_prompt.py

Lines changed: 2 additions & 2 deletions
@@ -101,8 +101,8 @@ def model_fn():
     ]

     def trainer_model_fn():
-        actor = get_actor_from_args(args.model, args.pretrain).half().cuda()
-        critic = get_critic_from_args(args.model, args.critic_pretrain).half().cuda()
+        actor = get_actor_from_args(args.model, args.pretrain, lora_rank=args.lora_rank).half().cuda()
+        critic = get_critic_from_args(args.model, args.critic_pretrain, lora_rank=args.lora_rank).half().cuda()
         return actor, critic

     # configure Trainer

0 commit comments
