
Commit 2ba46c1

kai01ai authored and LysandreJik committed
fix _resize_token_embeddings will set lm head size to 0 when enabled deepspeed zero3 (#26024)
1 parent 8160e42 commit 2ba46c1
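Why the fix is needed, in brief: under DeepSpeed ZeRO-3 every parameter is partitioned across ranks, so the locally visible tensor can be empty and new_embeddings.weight.shape[0] reads 0 unless the weight is gathered first. The patch below therefore gathers the embedding weight before taking its row count and reuses that value when resizing the LM head. A minimal sketch of that gather-before-read pattern follows (the helper name is illustrative, and the exact import path of is_deepspeed_zero3_enabled varies across transformers versions):

```python
# Sketch only: read the true vocabulary size of an embedding whose weight may
# be partitioned by DeepSpeed ZeRO-3. The helper name is illustrative.
from transformers.integrations import is_deepspeed_zero3_enabled  # import path may vary by version


def embedding_rows(embeddings):
    """Return embeddings.weight.shape[0], gathering the weight under ZeRO-3."""
    if is_deepspeed_zero3_enabled():
        import deepspeed

        # modifier_rank=None: gather read-only on all ranks; no rank writes back.
        with deepspeed.zero.GatheredParameters(embeddings.weight, modifier_rank=None):
            return embeddings.weight.shape[0]
    # Outside ZeRO-3 the full weight is local, so the shape is already correct.
    return embeddings.weight.shape[0]
```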

File tree

1 file changed: +11 −1 lines

src/transformers/modeling_utils.py

Lines changed: 11 additions & 1 deletion
@@ -1451,10 +1451,20 @@ def _resize_token_embeddings(self, new_num_tokens, pad_to_multiple_of=None):
             add_hook_to_module(new_embeddings, hook)
         self.set_input_embeddings(new_embeddings)
 
+        # Update new_num_tokens with the actual size of new_embeddings
+        if pad_to_multiple_of is not None:
+            if is_deepspeed_zero3_enabled():
+                import deepspeed
+
+                with deepspeed.zero.GatheredParameters(new_embeddings.weight, modifier_rank=None):
+                    new_num_tokens = new_embeddings.weight.shape[0]
+            else:
+                new_num_tokens = new_embeddings.weight.shape[0]
+
         # if word embeddings are not tied, make sure that lm head is resized as well
         if self.get_output_embeddings() is not None and not self.config.tie_word_embeddings:
             old_lm_head = self.get_output_embeddings()
-            new_lm_head = self._get_resized_lm_head(old_lm_head, new_embeddings.weight.shape[0])
+            new_lm_head = self._get_resized_lm_head(old_lm_head, new_num_tokens)
             if hasattr(old_lm_head, "_hf_hook"):
                 hook = old_lm_head._hf_hook
                 add_hook_to_module(new_lm_head, hook)
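For context, here is a hedged sketch of the user-facing call path this commit repairs; the model and tokenizer names, the added tokens, and the padding multiple are illustrative, and the ZeRO-3 setup itself (a DeepSpeed config with stage 3, e.g. via Trainer or HfDeepSpeedConfig) is assumed rather than shown:

```python
# Sketch only: resizing embeddings with pad_to_multiple_of on a model whose
# input and output embeddings are untied. Before this commit, with ZeRO-3
# enabled, the LM head was resized using new_embeddings.weight.shape[0], which
# reads 0 on a partitioned weight, yielding a 0-row LM head.
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2", tie_word_embeddings=False)

tokenizer.add_tokens(["<extra_0>", "<extra_1>"])

# After the fix, the LM head is resized to the gathered (true) embedding size.
model.resize_token_embeddings(len(tokenizer), pad_to_multiple_of=64)
```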
