Fix the universal checkpoint issue for stage3 when there are multiple subgroups. #7585
Conversation
deepspeed/runtime/engine.py
Outdated
filename = "zero_pp_rank_0" | ||
else: | ||
filename = "zero_pp_rank_{}".format(dist.get_rank(group=self.optimizer.dp_process_group)) | ||
# For stage 3, when loading the checkpoint, the world size may change, non-existent files may be loaded, |
Please clarify this idea, as it appears to contradict the design principle that we only support world size changes through universal checkpointing. In other words, if `self.load_universal_checkpoint() == False`, there is no world size change.
After #5475, all processes save to zero_pp_rank_0. This can cause multiple processes to write to the same file simultaneously and lead to file corruption, so I reverted this part of the logic.
Consider the following scenario:
- (1) We train with 2 processes, which generates two files: `zero_pp_rank_0_mp_rank_00_model_states.pt` and `zero_pp_rank_1_mp_rank_00_model_states.pt`.
- (2) We then run ds_to_universal.py to convert to a universal checkpoint.
- (3) We then train with 4 processes; ranks 2 and 3 try to read the non-existent files `zero_pp_rank_2_mp_rank_00_model_states.pt` and `zero_pp_rank_3_mp_rank_00_model_states.pt`, and loading fails.

For stage 3 we can read any model_states file, so here we just read `zero_pp_rank_0_mp_rank_00_model_states` (see the sketch below).
This is fine for stage 1 and 2, because there is only one model_states file for stage 1 and 2.
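A rough sketch of the load-time filename choice described above (an illustrative helper, not the exact `engine.py` code; the rank and process-group handling is simplified):

```python
# Illustrative sketch of the load-time filename choice discussed above;
# not the exact DeepSpeed engine.py implementation.
def model_states_filename(dp_rank, stage3, loading_universal_checkpoint):
    if stage3 and loading_universal_checkpoint:
        # Stage 3 + universal checkpoint: any rank's model_states file works, and after
        # resharding, higher ranks may have no file of their own, so every rank reads rank 0's.
        return "zero_pp_rank_0_mp_rank_00_model_states.pt"
    # Otherwise each rank saves and loads its own file.
    return f"zero_pp_rank_{dp_rank}_mp_rank_00_model_states.pt"

# Scenario (1)-(3): checkpoint saved with 2 ranks, training resumed with 4 ranks.
for rank in range(4):
    print(rank, model_states_filename(rank, stage3=True, loading_universal_checkpoint=True))
```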
@zhengchenyu I can understand (1)(2)(3), but could you clarify why the previous version will fail?
DeepSpeed/deepspeed/runtime/engine.py
Lines 3003 to 3004 in 17d80ce
if self.load_universal_checkpoint():
    filename = "zero_pp_rank_0"
When universal checkpoint is enabled, only `zero_pp_rank_0*` will get loaded.
The previous version does not fail outright, but it may produce corrupted files: multiple processes may write to the same file. In my experiments, the model file occasionally could not be loaded, with an error saying the file was corrupted.
What I meant is that if I simply revert the relevant logic from #5475, loading will fail. Steps (1)(2)(3) describe why loading fails in that case.
@zhengchenyu I see your point. The loading logic itself is fine, but when we enable the universal checkpoint and then save, the `_get_ckpt_name` function is also called. At that point, filename gets set to `zero_pp_rank_0`, which could lead to corrupted files.
To make the codebase cleaner and avoid introducing extra complexity, we could consider refactoring (a rough sketch follows this list):
- Remove the `self.load_universal_checkpoint()` logic from `_get_ckpt_name`.
- Instead, move that logic into `_get_all_ckpt_names` (which is used for loading checkpoints) and, if `self.load_universal_checkpoint()` is true, reset `ckpt_file_pattern` to `zero_pp_rank_0`.

This way, we avoid adding a new argument like `read_mode` and keep the code simpler.
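A rough sketch of that refactor, written with simplified standalone signatures (the real `_get_ckpt_name`/`_get_all_ckpt_names` are engine methods and differ in detail):

```python
import glob
import os

# Sketch of the suggested refactor; illustrative signatures, not DeepSpeed's exact API.
def get_ckpt_name(checkpoints_path, tag, dp_rank):
    # Used for both saving and loading a single rank's file; no universal-checkpoint
    # special case here, so saving can never collapse all ranks onto zero_pp_rank_0.
    return os.path.join(checkpoints_path, str(tag),
                        f"zero_pp_rank_{dp_rank}_mp_rank_00_model_states.pt")

def get_all_ckpt_names(checkpoints_path, tag, load_universal_checkpoint):
    # Loading path only: with a universal checkpoint any rank may read rank 0's file,
    # so narrow the pattern here instead of threading a read_mode flag into get_ckpt_name.
    ckpt_file_pattern = "zero_pp_rank_0" if load_universal_checkpoint else "zero_pp_rank_*"
    pattern = os.path.join(checkpoints_path, str(tag),
                           f"{ckpt_file_pattern}_mp_rank_00_model_states.pt")
    return sorted(glob.glob(pattern))
```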
@zhengchenyu and @xylian86 thanks for resolving the issues.
@zhengchenyu please ping me when you address the last two comments so we can merge this PR. This is a great contribution. Thanks!
@xylian86 The current implementation does not suffer from write-corruption issues. `_get_all_ckpt_names` is only used by `_load_checkpoint`; when we save, each rank writes its own `zero_pp_rank_{rank}` file, so that path is fine.
Your refactoring suggestion is great; I have applied it and avoided introducing `read_mode`.
@zhengchenyu Thank you!
@xylian86 can you help review?
@zhengchenyu Thanks for raising this issue and submitting the PR to address it. I just want to clarify that this is not the actual cause of the problem. The current implementation can already handle multiple subgroups, regardless of how many we shard into. The real cause is that when we initialize the ZeRO-3 engine, the optimizer passes multiple param_groups. These param_groups may have different hyperparameters, but the current implementation does not yet account for that.
@xylian86
DeepSpeed/deepspeed/runtime/zero/stage3.py, line 727 in bc9ed47
Not only is each incoming param_group split, but if a param_group is too large (i.e., it exceeds sub_group_size), it is split again. The hyperparameters are retained; see the referenced code: each sub group belongs to exactly one incoming param group, so the hyperparameters are not lost here.
There are two levels of parameter grouping in the implementation. For example, the structure looks like:
{
  'param_group1': [subgroup1, subgroup2, subgroup3],
  'param_group2': [subgroup1, subgroup2]
}
Already handled: multiple subgroups within a single param_group.
Not previously covered: multiple param_groups.
I disagree with the statement that model size or multiple subgroups cause the issue. As I mentioned earlier, the root cause of the error is specifically the presence of multiple param_groups.
DeepSpeed/deepspeed/runtime/zero/stage3.py, line 726 in 17d80ce
(see also DeepSpeed/deepspeed/runtime/zero/stage3.py, line 735 in 17d80ce)
If we have only one incoming param group that is very large and is split into three sub groups, the structure is still a single param group containing multiple subgroups. The failure then still occurs, so my explanation is fine. A toy sketch of this splitting follows.
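To make the splitting concrete, here is a toy sketch of cutting one (possibly large) param group into sub-groups by a size budget; it mirrors the idea behind `sub_group_size` but is not DeepSpeed's actual `_create_fp16_sub_groups` code, and the parameter sizes are invented:

```python
# Toy illustration of param_group -> sub-group splitting by a size budget.
# Not DeepSpeed's implementation; parameter sizes are made up.
def split_into_sub_groups(param_numels, sub_group_size):
    sub_groups, current, current_size = [], [], 0
    for numel in param_numels:
        if current and current_size + numel > sub_group_size:
            sub_groups.append(current)
            current, current_size = [], 0
        current.append(numel)
        current_size += numel
    if current:
        sub_groups.append(current)
    return sub_groups

# One large incoming param group split into three sub-groups with sub_group_size=100:
print(split_into_sub_groups([40, 50, 30, 60, 20], sub_group_size=100))
# -> [[40, 50], [30, 60], [20]]  (one param group, three sub-groups)
```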
@zhengchenyu Thank you for the clarification. I realize that when saving checkpoints, in
param_shapes = _parse_model_states_stage3(model_files)
param_shapes = {k: v for d in param_shapes for k, v in d.items()}
dp_degree = len(model_files)
dp_degree = len(optim_files)
How about we still use `dp_degree = len(model_files)`? By design, the number of `optim_files` should match the number of `model_files`, since each DP worker saves both a model file and an optimizer file. However, if a developer chooses to resume training using only the weights and deletes the optimizer state files, then `len(optim_files) = 0`, which would cause issues.
Good suggestion. I made this change because the number of model files was previously set to 1. However, this PR reverts that logic, so I will switch back to using `len(model_files)` (a minimal sketch follows).
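A minimal sketch of the agreed resolution, with `model_files`/`optim_files` standing in for the checkpoint file lists the script discovers (illustrative only, not the actual `ds_to_universal.py` code):

```python
def infer_dp_degree(model_files, optim_files):
    # Prefer model-state files: one is saved per DP rank and they always exist,
    # whereas optimizer files may have been deleted when resuming from weights only.
    dp_degree = len(model_files)
    if optim_files and len(optim_files) != dp_degree:
        raise ValueError(f"expected {dp_degree} optimizer files, found {len(optim_files)}")
    return dp_degree
```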
self,
module,
init_optimizer,
param_names,
DeepSpeed/deepspeed/runtime/zero/stage3.py
Lines 2868 to 2872 in 17d80ce
def _load_universal_checkpoint(self, checkpoint_folder, load_optimizer_states, load_from_fp32_weights,
                               param_shapes):
    self.load_hp_checkpoint_state_from_checkpoint_dir_stage3(checkpoint_folder, param_shapes)

def load_hp_checkpoint_state_from_checkpoint_dir_stage3(self, checkpoint_dir, param_shapes):
In earlier versions we passed `param_shapes` to supply parameter names for each subgroup. Could we reuse `param_shapes` here instead of introducing a new `param_names` argument, which seems redundant? Or is there a strong reason to replace `param_shapes` with `param_names`?
This is a critical question!
But no, we cannot use `param_shapes`, because `param_shapes` comes from the checkpoint, i.e., from the previous distributed environment, and therefore reflects the sub groups (or `fp16_groups`) used in that previous environment.
If the world size changes, the sub groups (or `fp16_groups`) change as well, and here we need to set the optimizer parameters for the current sub groups (or `fp16_groups`).
In fact, I used `param_shapes` at first, but I found this problem when my new unit test (with `sub_group_size` set to 100) failed to run. A sketch of the distinction is below.
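A hedged sketch of the distinction being made, with invented parameter names, shapes, and groupings (not DeepSpeed's actual data layout): the grouping recorded in the checkpoint reflects the old sub-group layout, while restoration must follow the sub-groups the current engine builds for the new world size.

```python
# Illustrative only; names, shapes, and groupings are invented for this example.

# Grouping recorded in the checkpoint (old world size): one entry per OLD sub-group.
checkpoint_param_shapes = [
    {"layer1.weight": (100, 100), "layer1.bias": (100,)},   # old sub-group 0
    {"layer2.weight": (100, 100), "layer2.bias": (100,)},   # old sub-group 1
]

# Grouping built by the CURRENT engine after the world size changed:
# one entry per NEW sub-group, taken from the live fp16 groups (hence param_names).
current_param_names = [
    ["layer1.weight"],
    ["layer1.bias", "layer2.weight"],
    ["layer2.bias"],
]

# Restoring optimizer state must iterate the current grouping; the checkpoint's
# grouping is only consulted per parameter name, never to decide sub-group membership.
flat_shapes = {name: shape for group in checkpoint_param_shapes for name, shape in group.items()}
for i, sub_group in enumerate(current_param_names):
    for name in sub_group:
        print(f"new sub-group {i}: restore {name} with shape {flat_shapes[name]}")
```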
@zhengchenyu Makes sense to me. Could you remove `param_shapes` to clean up the codebase a bit? Since we're using `param_names`, `param_shapes` is no longer needed.
OK, I have removed `param_shapes`.
@zhengchenyu Thank you for sharing the loss trace — it looks good to me. I’ve also left two additional comments on the codebase with suggestions for improvement; could you take a look and address those? The other parts LGTM.
@sfc-gh-truwase zhengchenyu has addressed all my comments. The current version LGTM.
@sfc-gh-truwase @xylian86 |
Describe the bug
When the model is large and there are multiple subgroups, running ds_to_universal.py fails; the error log is below:
To Reproduce
Steps to reproduce the behavior:
The reason
I found that the previous stage3 universal checkpoint implementation did not take subgroups into account. I also found the following problems during debugging.
Related issue: #7584