 import torch.optim as optim
 from coati.models.base import get_base_model
 from coati.replay_buffer import ReplayBuffer
+from coati.models.base import RewardModel
+from coati.models.lora import LoraLinear
+from coati.replay_buffer import ReplayBuffer
 from torch.optim import Optimizer
 from torch.utils.data import DataLoader
 from transformers.modeling_utils import PreTrainedModel
@@ -71,8 +74,20 @@ def load_optimizer(self, optimizer: Optimizer, path: str, map_location: Any = No |
         state_dict = torch.load(path, map_location=map_location)
         optimizer.load_state_dict(state_dict)

+    def save_pretrained(self,
+                        model: nn.Module,
+                        path: str,
+                        only_rank0: bool = True,
+                        tokenizer: Optional[PreTrainedTokenizerBase] = None) -> None:
+        unwrapped_model = self.unwrap_model(model)
+        assert isinstance(unwrapped_model, PreTrainedModel)
+        unwrapped_model.save_pretrained(path)
+        if tokenizer is not None:
+            tokenizer.save_pretrained(path)
+
     def get_model_state_dict_shard(self, model: nn.Module, **config):
         # TODO: implement sharding on naive strategy
+        model = self.unwrap_model(model)
         if 'requires_grad_only' in config and config['requires_grad_only'] == True:
             state_dict = get_grad_required_state_dict(model)
         else:
@@ -111,14 +126,3 @@ def _try_init_dist(self, force: bool = False) -> None: |
         except Exception as e:
             if force:
                 raise e
-
-    def save_pretrained(self,
-                        model: nn.Module,
-                        path: str,
-                        only_rank0: bool = True,
-                        tokenizer: Optional[PreTrainedTokenizerBase] = None) -> None:
-        unwrapped_model = self.unwrap_model(model)
-        assert isinstance(unwrapped_model, PreTrainedModel)
-        unwrapped_model.save_pretrained(path)
-        if tokenizer is not None:
-            tokenizer.save_pretrained(path)
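
In short, the diff moves save_pretrained above get_model_state_dict_shard and has both paths unwrap the wrapped module before touching its state. A minimal usage sketch follows, assuming the change lands in coati's NaiveStrategy; the import path, the 'gpt2' checkpoint, and the './output' directory are placeholders for illustration, not taken from this diff.

    # Minimal sketch (assumed import path and placeholder names, not part of the diff).
    from coati.trainer.strategies import NaiveStrategy
    from transformers import AutoModelForCausalLM, AutoTokenizer

    strategy = NaiveStrategy()
    model = AutoModelForCausalLM.from_pretrained('gpt2')     # placeholder model
    tokenizer = AutoTokenizer.from_pretrained('gpt2')        # placeholder tokenizer

    # save_pretrained (as relocated above) unwraps the module, asserts that it
    # is a transformers PreTrainedModel, and writes the weights to the target
    # directory, along with the tokenizer files when a tokenizer is passed.
    strategy.save_pretrained(model, path='./output', tokenizer=tokenizer)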