Skip to content

Commit ea8bcdc

Browse files
committed
fix
1 parent 5464925 commit ea8bcdc

File tree

2 files changed

+9
-5
lines changed

2 files changed

+9
-5
lines changed

colossalai/shardformer/layer/embedding.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -304,7 +304,6 @@ def __init__(
304304

305305
# deal with tensor parallelism
306306
self.num_embeddings_per_partition = divide(self.num_embeddings, tensor_parallel_size)
307-
self.num_embeddings = self.num_embeddings_per_partition
308307
self.vocab_start_index = tensor_parallel_rank * self.num_embeddings_per_partition
309308
self.vocab_end_index = self.vocab_start_index + self.num_embeddings_per_partition
310309

colossalai/zero/gemini/gemini_ddp.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -518,12 +518,15 @@ def _save_to_state_dict(self, destination, prefix, keep_vars, only_rank_0=True):
518518
p_mapping = param_to_save_data
519519
for name, param in self.name2param.items():
520520
if param is not None:
521-
origin_shape = self.params_info["name2shape"][name]
522521
if is_ddp_ignored(param):
523522
# deal with ddp ignored parameters
524523
destination[prefix + name] = param if keep_vars else param.detach()
525524
else:
526-
destination[prefix + name] = p_mapping[param][: origin_shape[0], ...]
525+
if self.params_info is not None:
526+
origin_shape = self.params_info["name2shape"][name]
527+
destination[prefix + name] = p_mapping[param][: origin_shape[0], ...]
528+
else:
529+
destination[prefix + name] = p_mapping[param]
527530
del p_mapping
528531
del param_to_save_data
529532

@@ -891,8 +894,10 @@ def state_dict_shard(
891894
gathered_param_buffer.update(self._get_chunk_to_save_data(chunk, only_rank_0))
892895
gathered_param = gathered_param_buffer.pop(param_to_save)
893896

894-
origin_shape = self.params_info["name2shape"][name]
895-
gathered_param = gathered_param[: origin_shape[0], ...]
897+
if self.params_info is not None:
898+
origin_shape = self.params_info["name2shape"][name]
899+
gathered_param = gathered_param[: origin_shape[0], ...]
900+
896901
block, block_size = sharder.append_param(prefix + name, gathered_param)
897902
if block is not None:
898903
yield block, block_size

0 commit comments

Comments
 (0)