🐛 Describe the bug
When using tensor parallelism, model parameters are sharded across GPUs to reduce memory consumption and enable parallel execution.
However, the optimizer still holds the unsharded model parameters, which prevents the old unsharded parameters from being released and increases memory usage.
Example code (adapted from examples/language/gpt2/hybridparallelism/finetune.py):
import colossalai
from torch.optim import Adam
from colossalai.booster import Booster
from colossalai.booster.plugin import HybridParallelPlugin

colossalai.launch_from_torch(config={})
plugin = HybridParallelPlugin(tp_size=4, pp_size=1)
booster = Booster(plugin=plugin)
optimizer = Adam(model.parameters())
# model and dataloader initialization omitted
model, optimizer, *_ = booster.boost(model, optimizer, ...)

> model.module.transformer.wte.weight
Parameter containing:
tensor([[-0.1101, -0.0393, ...]], device='cuda:0', dtype=torch.float16, requires_grad=True)
> model.module.transformer.wte.weight.shape
torch.Size([12565, 768])
> optimizer.param_groups[0]["params"][0]
Parameter containing:
tensor([[-0.1101, -0.0393, ...]], device='cuda:0', requires_grad=True)
> optimizer.param_groups[0]["params"][0].shape
torch.Size([50257, 768])
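To confirm this is not specific to wte, the shapes can be compared across all parameters. A minimal sketch, assuming model and optimizer from the repro above are in scope and that the optimizer's param_groups list parameters in the same order as model.module.parameters() (an assumption, not something the API guarantees):

model_params = list(model.module.parameters())
opt_params = [p for g in optimizer.param_groups for p in g["params"]]
# Any pair with mismatched shapes means the optimizer still holds an unsharded copy.
for mp, op in zip(model_params, opt_params):
    if mp.shape != op.shape:
        print(f"model (sharded): {tuple(mp.shape)}  optimizer: {tuple(op.shape)}")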
This also affects MixedPrecisionOptimizer.master_to_working_map and MixedPrecisionOptimizer.working_to_master_map:
# model.module.transformer.wte.weight is supposed to be a working parameter in the map
> model.module.transformer.wte.weight.shape
torch.Size([12565, 768])
> id(model.module.transformer.wte.weight)
139684649437120
# First working parameter in map does not refer to this
> list(iter(optimizer.master_to_working_map))[0].shape
torch.Size([50257, 768])
> id(list(iter(optimizer.master_to_working_map))[0])
139693862695728
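The same check can be run over the whole map by comparing object identities. A minimal sketch, assuming master_to_working_map is a plain dict whose values are the working parameters (as the name suggests; treat this as an assumption):

live_ids = {id(p) for p in model.module.parameters()}
# Count mapped "working" tensors that are not parameters of the sharded model.
stale = [w for w in optimizer.master_to_working_map.values() if id(w) not in live_ids]
print(f"{len(stale)} of {len(optimizer.master_to_working_map)} mapped working params are stale")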
Because of this, it seems that only a portion of the parameters (i.e. the ones that are not sharded) actually gets trained: MixedPrecisionOptimizer.step() skips the sharded parameters because their gradients are never stored on the mismatched unsharded working parameters:
ColossalAI/colossalai/amp/naive_amp/mixed_precision_optimizer.py, lines 173 to 175 in df5e9c5:
if working_param.grad is not None:
    p.grad = working_param.grad.data.float()
    working_param.grad = None
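The skipped update can be reproduced in isolation with plain PyTorch. This toy sketch (not ColossalAI code) mimics the situation: the gradient lands on the new sharded parameter, while the stale full-size tensor still referenced by the optimizer has grad == None, so the branch above never copies a gradient onto the master parameter:

import torch

# Stand-ins: the stale full-size tensor the optimizer still references,
# and the new sharded parameter the model actually trains.
stale_working = torch.zeros(8, 4, dtype=torch.float16, requires_grad=True)
sharded_working = torch.zeros(2, 4, dtype=torch.float16, requires_grad=True)

loss = sharded_working.float().sum()
loss.backward()  # the gradient lands on sharded_working only

print(stale_working.grad)    # None -> step() skips this param, master weight never updated
print(sharded_working.grad)  # populated, but never copied to the master param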
Environment
PyTorch 2.2.1 / CUDA 12.1