[BUG]: HybridParallelOptimizer holds unsharded model parameters after sharding #5539

@insujang

Description

🐛 Describe the bug

When using tensor parallelism, model parameters are sharded across GPUs to reduce memory consumption and enable parallel execution.
However, the optimizer still holds references to the unsharded model parameters, preventing them from being released and increasing memory usage.

Example code (adapted from examples/language/gpt2/hybridparallelism/finetune.py):

colossalai.launch_from_torch(config={})
plugin = HybridParallelPlugin(tp_size=4, pp_size=1)
optimizer = Adam(model.parameters())
# initialize dataloader

booster = Booster(plugin=plugin)
model, optimizer, *_ = booster.boost(model, optimizer, ...)
> model.module.transformer.wte.weight
Parameter containing:
tensor([[-0.1101, -0.0393, ...]], device='cuda:0', dtype=torch.float16, requires_grad=True)

> model.module.transformer.wte.weight.shape
torch.Size([12565, 768])

> optimizer.param_groups[0]["params"][0]
Parameter containing:
tensor([[-0.1101, -0.0393, ...]], device='cuda:0', requires_grad=True)

> optimizer.param_groups[0]["params"][0].shape
torch.Size([50257, 768])

This also affects MixedPrecisionOptimizer.master_to_working_map and MixedPrecisionOptimizer.working_to_master_map:

# model.module.transformer.wte.weight is supposed to be in a working parameter
> model.module.transformer.wte.weight.shape
torch.Size([12565, 768])
> id(model.module.transformer.wte.weight)
139684649437120

# The first working parameter in the map does not refer to it
> list(iter(optimizer.master_to_working_map))[0].shape
torch.Size([50257, 768])
> id(list(iter(optimizer.master_to_working_map))[0])
139693862695728

Because of this, it seems only a portion of the parameters (i.e. the unsharded ones) are actually trained: MixedPrecisionOptimizer.step() skips the sharded parameters, since their gradients are never stored in the mismatched unsharded working parameters:

if working_param.grad is not None:
    p.grad = working_param.grad.data.float()
    working_param.grad = None
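The mechanism can be reproduced without ColossalAI. A minimal sketch (all names hypothetical) of the stale-reference problem: a param_groups-style list captures the unsharded parameter, the module's parameter is later replaced by a shard, gradients land only on the shard, and the stale working parameter therefore has no gradient, so the step is skipped.

```python
import torch

# Unsharded parameter, captured by an optimizer-style list before sharding
full = torch.nn.Parameter(torch.zeros(50257, 768))
captured = [full]

# Tensor parallelism replaces the module's parameter with a shard
# (slice sizes taken from the shapes in this report)
shard = torch.nn.Parameter(full.detach()[:12565].clone())
module_param = shard

# Backward writes gradients only onto the shard
module_param.grad = torch.ones_like(module_param)

working_param = captured[0]          # stale, unsharded tensor
assert working_param is not shard    # identity mismatch, as shown above
assert working_param.grad is None    # so step() skips this parameter
```

The fix presumably requires the plugin to rewrite optimizer.param_groups (and the master/working maps) to point at the sharded tensors after sharding.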

Environment

PyTorch 2.2.1 / CUDA 12.1
