Conversation

@zhengchenyu zhengchenyu commented Sep 23, 2025

Describe the bug

When the model is large and there are multiple subgroups, running ds_to_universal.py fails. The error log is below:

*** 1. Extracting ZeRO fragments
  0%|                                                     | 0/1 [00:03<?, ?it/s]
Traceback (most recent call last):
  File "/work/zhengchenyu/ai-project/qwen3/scripts/ds_to_universal_example.py", line 21, in <module>
    main()
  File "/work/zhengchenyu/ai-project/qwen3/scripts/ds_to_universal_example.py", line 18, in main
    ds_to_universal_main(args)
  File "/opt/conda/lib/python3.11/site-packages/deepspeed/checkpoint/ds_to_universal.py", line 523, in main
    _extract_zero_shard_files_stage3(args, optim_files, param_shapes, dp_degree, temp_dir)
  File "/opt/conda/lib/python3.11/site-packages/deepspeed/checkpoint/ds_to_universal.py", line 375, in _extract_zero_shard_files_stage3
    _do_parallel_work(do_work, list(range(dp_degree)), args.num_extract_workers)
  File "/opt/conda/lib/python3.11/site-packages/deepspeed/checkpoint/ds_to_universal.py", line 359, in _do_parallel_work
    results.append(do_work(work))
                   ^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/deepspeed/checkpoint/ds_to_universal.py", line 167, in extract_zero_shards_stage3
    dump_param_fragment(temp_dir, 0, dp_index, state_key, flat_state[state_key], name, offset,
  File "/opt/conda/lib/python3.11/site-packages/deepspeed/checkpoint/ds_to_universal.py", line 194, in dump_param_fragment
    state_flat_tensor = state_flat_tensor.narrow(0, offset, numel).clone()
                        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: start (0) + length (155582464) exceeds dimension size (74499072).

To Reproduce
Steps to reproduce the behavior:

  1. Train with a large model, or set sub_group_size to a lower value, then train and save a checkpoint.
  2. Run ds_to_universal.py (a minimal driver sketch is shown below).
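
For reference, a minimal conversion driver in the spirit of the ds_to_universal_example.py shown in the traceback might look like the sketch below. The paths and most argument names are assumptions for illustration; only num_extract_workers appears in the traceback, and the namespace must contain whatever fields your DeepSpeed version's ds_to_universal.main expects.

from types import SimpleNamespace

from deepspeed.checkpoint.ds_to_universal import main as ds_to_universal_main


def main():
    # Hypothetical argument namespace; field names other than num_extract_workers
    # are assumptions and must match the arguments parsed by ds_to_universal.py.
    args = SimpleNamespace(
        input_folder="checkpoints/global_step400",             # ZeRO-3 checkpoint dir (assumed path)
        output_folder="checkpoints/global_step400_universal",  # destination for the universal checkpoint
        num_extract_workers=4,
        num_merge_workers=2,
    )
    ds_to_universal_main(args)


if __name__ == "__main__":
    main()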

The reason

I found that the previous stage 3 universal checkpoint implementation did not take subgroups into account. I also found the following problems during debugging.

  • It cannot handle multiple sub-groups, which results in data loss.
  • When load_checkpoint is True, all processes save to the same ZeRO model checkpoint file. If multiple processes write at the same time, the file gets corrupted; this was occasionally observed during testing.

Related issue: #7584

filename = "zero_pp_rank_0"
else:
filename = "zero_pp_rank_{}".format(dist.get_rank(group=self.optimizer.dp_process_group))
# For stage 3, when loading the checkpoint, the world size may change, non-existent files may be loaded,
Collaborator

Please clarify this idea as it appears to contradict the design principle that we only support world size changes through universal checkpointing. In other words, if self.load_universal_checkpoint() == False there is no world size change.

Contributor Author

After #5475, all processes save to zero_pp_rank_0. This can cause multiple processes to write to the same file simultaneously and lead to file corruption, so I reverted this part of the logic.
Consider this scenario:

  • (1) We train with 2 processes, which generates two files: zero_pp_rank_0_mp_rank_00_model_states.pt and zero_pp_rank_1_mp_rank_00_model_states.pt.
  • (2) We then run ds_to_universal.py to convert to a universal checkpoint.
  • (3) We then train with 4 processes; ranks 2 and 3 try to read the non-existent files zero_pp_rank_2_mp_rank_00_model_states.pt and zero_pp_rank_3_mp_rank_00_model_states.pt, and loading fails.

For stage 3, we can read any of these files, so here we just read zero_pp_rank_0_mp_rank_00_model_states.

This is fine for stages 1 and 2, because they have a single model_states file.

Contributor

@zhengchenyu I can understand (1)(2)(3), but could you clarify why the previous version would fail?

if self.load_universal_checkpoint():
filename = "zero_pp_rank_0"

When universal checkpointing is enabled, only zero_pp_rank_0* gets loaded.

Contributor Author

The previous version does not fail outright, but it may produce corrupted files, because multiple processes may write to the same file. In my experiments, the model file occasionally could not be loaded, with an error saying the file was corrupted.
What I meant is that directly reverting the relevant logic in #5475 would cause loading to fail; (1)(2)(3) just describe why loading fails in that case.

Contributor

@zhengchenyu I see your point. The loading logic itself is fine, but when we enable the universal checkpoint and then save it, the _get_ckpt_name function will also be called. At that point, filename gets set to zero_pp_rank_0, which could lead to corrupted files.

To make the codebase cleaner and avoid introducing extra complexity, we could consider refactoring:

  1. Remove the self.load_universal_checkpoint() logic from _get_ckpt_name.

  2. Instead, move that logic into _get_all_ckpt_names (which is for loading ckpt) and, if self.load_universal_checkpoint() is true, reset ckpt_file_pattern to zero_pp_rank_0.

This way, we avoid adding a new argument like read_mode and keep the code simpler.
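
A rough sketch of this suggested shape (simplified; mp_placeholder and the regex-based rewrite are assumptions, not the actual DeepSpeed implementation):

import glob
import re

# Sketch of an engine method: _get_ckpt_name stays purely rank-based, and the
# universal-checkpoint special case lives only on the loading path.
def _get_all_ckpt_names(self, checkpoints_path, tag):
    ckpt_file_pattern = self._get_ckpt_name(checkpoints_path, tag, mp_placeholder="*")
    if self.load_universal_checkpoint():
        # For ZeRO stage 3, any rank's model_states file is sufficient when
        # loading a universal checkpoint, so point every rank at zero_pp_rank_0.
        ckpt_file_pattern = re.sub(r"zero_pp_rank_\d+", "zero_pp_rank_0", ckpt_file_pattern)
    ckpt_files = glob.glob(ckpt_file_pattern)
    ckpt_files.sort()
    return ckpt_files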

Collaborator

@zhengchenyu and @xylian86 thanks for resolving the issues.

@zhengchenyu please ping me when you address the last two comments so we can merge this PR. This is a great contribution. Thanks!

Contributor Author

@xylian86
The current implementation does not suffer from write-corruption issues. _get_all_ckpt_names is only used by _load_checkpoint; when we save, each rank writes its own zero_pp_rank_{rank} file, so it is fine.

Your refactoring suggestion is great; I have applied it and avoided introducing read_mode.

Contributor

@zhengchenyu Thank you!

@sfc-gh-truwase
Collaborator

@xylian86 can you help review?

@xylian86
Contributor

When the model is large and there are multiple subgroups, running ds_to_universal.py fails

@zhengchenyu Thanks for raising this issue and submitting the PR to address it. I just want to clarify that this is not the actual cause of the problem. The current implementation can already handle multiple subgroups, regardless of how many we shard.

The real cause is that when we initialize the ZeRO-3 Engine, the optimizer passes multiple param_groups. These param_groups may have different hyperparameters, but the current implementation does not yet account for that.
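
For example (illustrative only, not code from this PR), a common optimizer setup that yields multiple param_groups with different hyperparameters:

import torch

model = torch.nn.Linear(16, 16)
decay, no_decay = [], []
for name, p in model.named_parameters():
    (no_decay if name.endswith("bias") else decay).append(p)

# Two param_groups that differ in weight_decay.
optimizer = torch.optim.AdamW([
    {"params": decay, "weight_decay": 0.01},
    {"params": no_decay, "weight_decay": 0.0},
], lr=1e-4)

print(len(optimizer.param_groups))  # -> 2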

@zhengchenyu
Contributor Author

@xylian86
You can see here:

self._create_fp16_sub_groups(param_group["params"]) for param_group in fp16_param_groups)

Not only is the incoming param_group split; if the param_group is too large (i.e., the group exceeds sub_group_size), it is split a second time into multiple sub-groups.

The hyperparameters are retained; see the following code:

for param_group in self.optimizer.param_groups:
    # Generally, the hyperparameters of each parameter should be the same, we can obtain from any parameter.
    for key, value in optim_sd[OPTIMIZER_STATE_DICT]["param_groups"][0].items():
        if key == 'params':
            param_group['params'] = []
        else:
            param_group[key] = value

Each sub group can only belong to one incoming param group, so the hyperparameters are not lost here.

@xylian86
Contributor

@zhengchenyu

There are two levels of parameter grouping in the implementation:

  1. Higher level: param_group - From the user's optimizer config.
  2. Lower level: subgroup - Each param_group is further split based on sub_group_size

For example, the structure looks like:

{
  'param_group1': [subgroup1, subgroup2, subgroup3], 
  'param_group2': [subgroup1, subgroup2]
}

Already handled: Multiple subgroups within a single param_group

  • The existing code properly manages scenarios where one param_group contains multiple subgroups

Not previously covered: Multiple param_groups from user's optimizer configuration

  • This is precisely what your PR addresses, and I appreciate this contribution

When the model is large and there are multiple subgroups, running ds_to_universal.py fails; the error log is below:

I disagree with the statement that model size or multiple subgroups cause the ds_to_universal.py failure. The current implementation can handle large models and any number of subgroups, as long as there is only a single param_group.

As I mentioned earlier, the root cause of the error is specifically the presence of multiple param_groups in the user's optimizer configuration, not the model size or subgroup count within individual parameter groups.

@zhengchenyu
Contributor Author

zhengchenyu commented Sep 24, 2025

@xylian86

param_groups: List[List[Parameter]] = tuple(

Although param_groups is a 2-level structure, as the code below shows, it is flattened to one level, namely bit16_groups.

self.fp16_groups.append(sub_group)

If we have only one incoming param group, but it is very large and is split into three sub-groups, the structure of param_groups looks like

{
  'param_group1': [subgroup1, subgroup2, subgroup3], 
}

Then bit16_groups will be [subgroup1, subgroup2, subgroup3]. The problem is whether fp16_groups has multiple sub-groups, not whether the incoming optimizer has multiple param groups.
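
A self-contained illustration of this (simplified; not the actual _create_fp16_sub_groups implementation):

import torch

def split_into_sub_groups(params, sub_group_size):
    # Greedily pack parameters into sub-groups of at most sub_group_size elements.
    sub_groups, current, current_numel = [], [], 0
    for p in params:
        if current and current_numel + p.numel() > sub_group_size:
            sub_groups.append(current)
            current, current_numel = [], 0
        current.append(p)
        current_numel += p.numel()
    if current:
        sub_groups.append(current)
    return sub_groups

# One incoming param group whose total size exceeds sub_group_size.
fp16_param_groups = [{"params": [torch.zeros(60), torch.zeros(60), torch.zeros(60)]}]
fp16_groups = []
for param_group in fp16_param_groups:
    for sub_group in split_into_sub_groups(param_group["params"], sub_group_size=100):
        fp16_groups.append(sub_group)

print(len(fp16_param_groups), len(fp16_groups))  # -> 1 incoming group, 3 sub-groups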

So my explanation is fine.

@xylian86
Contributor

@zhengchenyu Thank you for the clarification. I realize that when saving checkpoints, in optimizer state, each subgroup has separate tensors rather than being concatenated. I believe I’m aligned with you now.

param_shapes = _parse_model_states_stage3(model_files)
param_shapes = {k: v for d in param_shapes for k, v in d.items()}
dp_degree = len(model_files)
dp_degree = len(optim_files)
Contributor

How about we still use dp_degree = len(model_files)? By design, the number of optim_files should match the number of model_files, since each DP worker saves both a model file and an optimizer file. However, if a developer chooses to resume training using only the weights and deletes the optimizer state files, then len(optim_files) = 0, which would cause issues.

Contributor Author

Good suggestion. I made this change because the number of model files used to be set to 1; however, this PR reverses that logic, so I will switch back to using len(model_files).

filename = "zero_pp_rank_0"
else:
filename = "zero_pp_rank_{}".format(dist.get_rank(group=self.optimizer.dp_process_group))
# For stage 3, when loading the checkpoint, the world size may change, non-existent files may be loaded,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@zhengchenyu I can understand (1)(2)(3), but could you clarify why the previous version will fail?

if self.load_universal_checkpoint():
filename = "zero_pp_rank_0"

When universal checkpoint is enabled, only the zero_pp_rank_0* will got loaded.

self,
module,
init_optimizer,
param_names,
Contributor

def _load_universal_checkpoint(self, checkpoint_folder, load_optimizer_states, load_from_fp32_weights,
                               param_shapes):
    self.load_hp_checkpoint_state_from_checkpoint_dir_stage3(checkpoint_folder, param_shapes)

def load_hp_checkpoint_state_from_checkpoint_dir_stage3(self, checkpoint_dir, param_shapes):

In earlier versions we passed param_shapes to supply parameter names for each subgroup. Could we reuse param_shapes here instead of introducing a new param_names argument, which seems redundant? Or is there a strong reason to replace param_shapes with param_names?

Contributor Author

This is a critical question!

But no, we cannot use param_shapes, because it comes from the checkpoint, i.e., from the previous distributed environment, so it reflects the sub-groups (fp16_groups) used in that environment.

If the world size changes, the sub-groups (fp16_groups) change too. Here we set the optimizer parameters for the current sub-groups (fp16_groups).

In fact, I used param_shapes at first too, but I found this problem when my new unit test (in which sub_group_size is set to 100) failed.

Contributor

@zhengchenyu Makes sense to me. Could you remove param_shapes to clean up the codebase a bit? Since we're using param_names, param_shapes is no longer needed.

Contributor Author

OK, I have removed param_shapes.

@zhengchenyu
Contributor Author

Adding some experimental results: we used Qwen3 0.6B and 400 steps.
Experimental Group 1:
Base was trained without interruption, world_size=8. Universal was interrupted at steps 50, 100, 200, and 300, and loaded via the universal checkpoint, world_size=8.

[Screenshot: loss curves, Experimental Group 1]

Experimental Group 2:
Base was trained without interruption, world_size=8. Universal_world_change was interrupted at steps 50, 100, 200, and 300, with the world size changed, and then loaded via the universal checkpoint. At steps 50 and 200, the world_size was reduced from 8 to 4. At steps 100 and 300, the world_size was increased from 4 to 8.

Total_batch_size was kept constant by modifying gradient_accumulation_steps.
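
For example, with an assumed per-GPU micro batch size of 4 (this value is not stated above):

micro_batch_per_gpu = 4  # assumed, for illustration only

total_ws8 = 8 * micro_batch_per_gpu * 2  # world_size=8, gradient_accumulation_steps=2 -> 64
total_ws4 = 4 * micro_batch_per_gpu * 4  # world_size=4, gradient_accumulation_steps=4 -> 64

assert total_ws8 == total_ws4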

[Screenshot: loss curves, Experimental Group 2]

@xylian86
Contributor

@zhengchenyu Thank you for sharing the loss trace — it looks good to me. I’ve also left two additional comments on the codebase with suggestions for improvement; could you take a look and address those? The other parts LGTM.

@xylian86
Contributor

@sfc-gh-truwase zhengchenyu has addressed all my comments. The current version LGTM.

@sfc-gh-truwase sfc-gh-truwase enabled auto-merge (squash) September 27, 2025 17:11
@sfc-gh-truwase sfc-gh-truwase merged commit 91d1452 into deepspeedai:master Sep 27, 2025
12 checks passed
@zhengchenyu zhengchenyu deleted the fix.universal branch September 28, 2025 03:33
@zhengchenyu
Contributor Author

zhengchenyu commented Sep 28, 2025

@sfc-gh-truwase @xylian86
Sorry, I found a problem with the last commit: it causes the universal checkpoint to fail to load when the world size scales up.
I'm curious why the unit test didn't fail; originally, TestZeROUniversalCheckpointDP::test_dp_world_size_2to4 could reproduce this problem.
I have submitted a new PR, #7599, to solve it.
