
Commit d565a24

[zero] add unit testings for hybrid parallelism (#2486)
1 parent fcc6d61 commit d565a24

4 files changed: +188 additions, -98 deletions

colossalai/zero/sharded_optim/bookkeeping/bucket_store.py

Lines changed: 5 additions & 7 deletions
@@ -7,7 +7,6 @@ class BucketStore(BaseStore):
 
     def __init__(self, torch_pg: ProcessGroup):
         super().__init__(torch_pg)
-        self._grads = dict()
         self._params = dict()
         self._num_elements_in_bucket = dict()
 
@@ -19,25 +18,24 @@ def num_elements_in_bucket(self, reduce_rank: int = None):
     def add_num_elements_in_bucket(self, num_elements, reduce_rank: int = None):
         self._num_elements_in_bucket[reduce_rank] += num_elements
 
-    def add_grad(self, tensor, reduce_rank: int = None):
-        self._grads[reduce_rank].append(tensor)
-
     def add_param(self, tensor, reduce_rank: int = None):
         self._params[reduce_rank].append(tensor)
 
     def reset(self):
         keys = [None] + list(range(self._world_size))
-        self._grads = {rank: [] for rank in keys}
         self._params = {rank: [] for rank in keys}
         self._num_elements_in_bucket = {rank: 0 for rank in keys}
 
     def reset_by_rank(self, reduce_rank=None):
-        self._grads[reduce_rank] = []
        self._params[reduce_rank] = []
         self._num_elements_in_bucket[reduce_rank] = 0
 
     def get_grad(self, reduce_rank: int = None):
-        return self._grads[reduce_rank]
+        param_list = self.get_param(reduce_rank)
+        for param in param_list:
+            # the param must have grad for reduction
+            assert param.grad is not None, f'Parameter of size ({param.size()}) has None grad, cannot be reduced'
+        return [param.grad for param in param_list]
 
     def get_param(self, reduce_rank: int = None):
         return self._params[reduce_rank]
colossalai/zero/sharded_optim/low_level_optim.py

Lines changed: 73 additions & 86 deletions
@@ -46,7 +46,7 @@ def __init__(
         reduce_bucket_size: int = 1024 * 1024,    # communication
         communication_dtype: Optional[torch.dtype] = None,
         overlap_communication: bool = False,
-        partition_grad: bool = False,    # stage 2
+        partition_grad: bool = False,    # stage 2 flag
         cpu_offload: bool = False,    # cpu offload
         forced_dtype: Optional[torch.dtype] = None):
 
@@ -248,9 +248,13 @@ def _partition_param_list(self, param_list):
         self._logger.info(f'Number of elements on ranks: {numel_per_rank}', ranks=[0])
         return params_per_rank
 
-    ###########################################################
-    # Backward Reduction Hook
-    ###########################################################
+    ###########################
+    # Backward Reduction Hook #
+    ###########################
+
+    def _grad_handler(self, param, grad, reduce_rank):
+        self._add_to_reduction_bucket(param, reduce_rank)
+        return grad
 
     def _attach_reduction_hook(self):
         # we iterate over the fp16 params
@@ -268,53 +272,61 @@ def _attach_reduction_hook(self):
                     else:
                         reduce_rank = None
 
-                    def _define_and_attach(param, reduce_rank):
-                        # get the AccumulateGrad object of the param itself
-                        accum_grad_obj = get_grad_accumulate_object(param)
-                        self._grad_store.add_accumulate_grad_object(accum_grad_obj)
+                    param.register_hook(partial(self._grad_handler, param, reduce_rank=reduce_rank))
 
-                        reduction_func = partial(self._reduce_and_remove_grads_by_bucket,
-                                                 param=param,
-                                                 reduce_rank=reduce_rank)
+    def _reduce_tensor_bucket(self, bucket: TensorBucket, reduce_rank):
+        if self._overlap_communication:
+            torch.cuda.synchronize()
+            self._param_store.clear_grads_of_previous_reduced_params()
+            stream = self._comm_stream
+        else:
+            stream = torch.cuda.current_stream()
 
-                        # define hook
-                        # NOT IMPORTANT BUT GOOD TO KNOW:
-                        # args here is not grad, but allow_unreacable and accumulate_grad
-                        def reduce_grad_hook(*args):
-                            reduction_func()
+        with torch.cuda.stream(stream):
+            flat = bucket.flatten()
+            reduce_global_rank = None
+            if reduce_rank is not None:
+                reduce_global_rank = self._dp_global_ranks[reduce_rank]
+            reduced_flat = reduce_tensor_dp_group(tensor=flat,
+                                                  dtype=self._communication_dtype,
+                                                  dst_local_rank=reduce_rank,
+                                                  dst_global_rank=reduce_global_rank,
+                                                  group=self._dp_torch_group)
 
-                        accum_grad_obj.register_hook(reduce_grad_hook)
+            # update the reduced tensor
+            if reduce_rank is None or reduce_rank == self._local_rank:
+                bucket.unflatten_and_copy(reduced_flat)
 
-                    _define_and_attach(param, reduce_rank)
+    def _reduce_tensor_list_with_one_dtype(self, tensor_list, bucket_size, reduce_rank):
+        param_bucket = TensorBucket(size=bucket_size)
 
-    def _reduce_and_remove_grads_by_bucket(self, param, reduce_rank=None):
-        param_size = param.numel()
+        for tensor in tensor_list:
+            param_bucket.add_to_bucket(tensor, allow_oversize=True)
 
-        # check if the bucket is full
-        # if full, will reduce the grads already in the bucket
-        # after reduction, the bucket will be empty
-        if self._bucket_store.num_elements_in_bucket(reduce_rank) + param_size > self._reduce_bucket_size:
-            self._reduce_grads_in_bucket(reduce_rank)
+            if param_bucket.is_full_or_oversized():
+                self._reduce_tensor_bucket(bucket=param_bucket, reduce_rank=reduce_rank)
+                param_bucket.empty()
 
-        # the param must not be reduced to ensure correctness
-        is_param_reduced = self._param_store.is_param_reduced(param)
-        if is_param_reduced:
-            msg = f'Parameter of size ({param.size()}) has already been reduced, ' \
-                + 'duplicate reduction will lead to arithmetic incorrectness'
-            raise RuntimeError(msg)
+        if not param_bucket.is_empty():
+            self._reduce_tensor_bucket(bucket=param_bucket, reduce_rank=reduce_rank)
 
-        # the param must have grad for reduction
-        assert param.grad is not None, f'Parameter of size ({param.size()}) has None grad, cannot be reduced'
+    def _reduce_grads(self, reduce_rank, grads, bucket_size):
+        grad_buckets_by_dtype = split_half_float_double(grads)
 
-        self._bucket_store.add_num_elements_in_bucket(param_size, reduce_rank)
-        self._bucket_store.add_grad(param.grad, reduce_rank)
-        self._bucket_store.add_param(param, reduce_rank)
+        for tensor_list in grad_buckets_by_dtype:
+            self._reduce_tensor_list_with_one_dtype(tensor_list=tensor_list,
+                                                    bucket_size=bucket_size,
+                                                    reduce_rank=reduce_rank)
+
+    #######################
+    # Reduction Functions #
+    #######################
 
-    def _reduce_grads_in_bucket(self, reduce_rank=None):
+    def _run_reduction(self, reduce_rank=None):
         # reduce grads
-        self._reduce_grads_by_rank(reduce_rank=reduce_rank,
-                                   grads=self._bucket_store.get_grad(reduce_rank=reduce_rank),
-                                   bucket_size=self._bucket_store.num_elements_in_bucket(reduce_rank))
+        self._reduce_grads(reduce_rank=reduce_rank,
+                           grads=self._bucket_store.get_grad(reduce_rank=reduce_rank),
+                           bucket_size=self._bucket_store.num_elements_in_bucket(reduce_rank))
 
         # use communication stream if overlapping
         # communication with computation
@@ -351,50 +363,24 @@ def _reduce_grads_in_bucket(self, reduce_rank=None):
 
         self._bucket_store.reset_by_rank(reduce_rank)
 
-    def _reduce_grads_by_rank(self, reduce_rank, grads, bucket_size):
-        grad_buckets_by_dtype = split_half_float_double(grads)
-
-        for tensor_list in grad_buckets_by_dtype:
-            self._reduce_no_retain(tensor_list=tensor_list, bucket_size=bucket_size, reduce_rank=reduce_rank)
-
-    ##############################
-    # Reduction Utility Function #
-    ##############################
-    def _reduce_no_retain(self, tensor_list, bucket_size, reduce_rank):
-        param_bucket = TensorBucket(size=bucket_size)
-
-        for tensor in tensor_list:
-            param_bucket.add_to_bucket(tensor, allow_oversize=True)
-
-            if param_bucket.is_full_or_oversized():
-                self._reduce_and_copy(bucket=param_bucket, reduce_rank=reduce_rank)
-                param_bucket.empty()
-
-        if not param_bucket.is_empty():
-            self._reduce_and_copy(bucket=param_bucket, reduce_rank=reduce_rank)
+    def _add_to_reduction_bucket(self, param, reduce_rank=None):
+        param_size = param.numel()
 
-    def _reduce_and_copy(self, bucket: TensorBucket, reduce_rank):
-        if self._overlap_communication:
-            torch.cuda.synchronize()
-            self._param_store.clear_grads_of_previous_reduced_params()
-            stream = self._comm_stream
-        else:
-            stream = torch.cuda.current_stream()
+        # check if the bucket is full
+        # if full, will reduce the grads already in the bucket
+        # after reduction, the bucket will be empty
+        if self._bucket_store.num_elements_in_bucket(reduce_rank) + param_size > self._reduce_bucket_size:
+            self._run_reduction(reduce_rank)
 
-        with torch.cuda.stream(stream):
-            flat = bucket.flatten()
-            reduce_global_rank = None
-            if reduce_rank is not None:
-                reduce_global_rank = self._dp_global_ranks[reduce_rank]
-            reduced_flat = reduce_tensor_dp_group(tensor=flat,
-                                                  dtype=self._communication_dtype,
-                                                  dst_local_rank=reduce_rank,
-                                                  dst_global_rank=reduce_global_rank,
-                                                  group=self._dp_torch_group)
+        # the param must not be reduced to ensure correctness
+        is_param_reduced = self._param_store.is_param_reduced(param)
+        if is_param_reduced:
+            msg = f'Parameter of size ({param.size()}) has already been reduced, ' \
+                + 'duplicate reduction will lead to arithmetic incorrectness'
+            raise RuntimeError(msg)
 
-        # update the reduced tensor
-        if reduce_rank is None or reduce_rank == self._local_rank:
-            bucket.unflatten_and_copy(reduced_flat)
+        self._bucket_store.add_num_elements_in_bucket(param_size, reduce_rank)
+        self._bucket_store.add_param(param, reduce_rank)
 
     ################################
     # torch.optim.Optimizer methods
@@ -498,8 +484,9 @@ def step(self, closure=None):
         # broadcast the updated model weights
         handles = []
         for group_id in range(self.num_param_groups):
-            for rank in range(self._world_size):
-                fp16_param = self._param_store.get_flat_fp16_param_by_rank_group(rank=rank, group_id=group_id)
+            for index in range(self._world_size):
+                rank = self._dp_global_ranks[index]
+                fp16_param = self._param_store.get_flat_fp16_param_by_rank_group(rank=index, group_id=group_id)
                 handle = dist.broadcast(fp16_param, src=rank, group=self._dp_torch_group, async_op=True)
                 handles.append(handle)
 
@@ -585,16 +572,16 @@ def _reduce_grad_stage1(self):
             param_group = self._fp16_param_groups[group_id]
             for param in param_group:
                 if param.grad is not None:
-                    self._reduce_and_remove_grads_by_bucket(param)
+                    self._add_to_reduction_bucket(param)
 
         # we need to reduce the gradients
        # left in the communication bucket
-        self._reduce_grads_in_bucket()
+        self._run_reduction()
 
     def _reduce_grad_stage2(self):
         # when partition_grads is True, reduction hooks
         # are attached in the __init__ function, so we
         # only need to reduce the gradients
         # left in the communication bucket
         for reduce_rank in range(self._world_size):
-            self._reduce_grads_in_bucket(reduce_rank)
+            self._run_reduction(reduce_rank)
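
In low_level_optim.py, the AccumulateGrad-based hook is replaced by a tensor hook registered directly on each parameter: _grad_handler receives the gradient, queues the parameter for bucketed reduction, and returns the gradient unchanged. A hedged, single-process sketch of that hook pattern (the _grad_handler and bucket names below are illustrative, not the optimizer's API):

from functools import partial

import torch


def _grad_handler(param, grad, bucket):
    # queue the parameter for a later bucketed reduction, then pass the grad through
    bucket.append(param)
    return grad


bucket = []
p = torch.nn.Parameter(torch.randn(4))
p.register_hook(partial(_grad_handler, p, bucket=bucket))

loss = (p * 2).sum()
loss.backward()

print(len(bucket))    # 1 -- the hook fired when p's gradient was computed
print(p.grad)         # tensor([2., 2., 2., 2.])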

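The step() change is the part that matters for hybrid parallelism: the loop variable is now the local index into the data-parallel group (used to look up the flat fp16 shard), while the src passed to dist.broadcast is translated to the corresponding global rank. A hedged, single-process sketch of that local-index-to-global-rank pattern (the gloo setup and variable names are illustrative):

import os

import torch
import torch.distributed as dist

# single-process gloo group, purely for illustration
os.environ.setdefault('MASTER_ADDR', 'localhost')
os.environ.setdefault('MASTER_PORT', '29500')
dist.init_process_group('gloo', rank=0, world_size=1)

dp_group = dist.new_group(ranks=[0])      # data-parallel sub-group of the world
dp_global_ranks = [0]                     # local dp index -> global rank
flat_shards = {0: torch.randn(8)}         # local dp index -> flat fp16 shard

handles = []
for index in range(len(dp_global_ranks)):
    src = dp_global_ranks[index]          # broadcast src must be a *global* rank
    handles.append(dist.broadcast(flat_shards[index], src=src, group=dp_group, async_op=True))

for handle in handles:
    handle.wait()

dist.destroy_process_group()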
tests/test_tensor/common_utils/_utils.py

Lines changed: 12 additions & 5 deletions
@@ -4,6 +4,7 @@
 import numpy as np
 import torch
 import torch.distributed as dist
+from torch.testing import assert_close
 
 from colossalai.context import ParallelMode
 from colossalai.core import global_context as gpc
@@ -41,14 +42,20 @@ def broadcast_tensor_chunk(tensor, chunk_size=1, local_rank=0):
     return tensor_chunk.clone()
 
 
-def tensor_equal(A, B):
-    return torch.allclose(A, B, rtol=1e-3, atol=1e-1)
+def tensor_equal(t_a: torch.Tensor, t_b: torch.Tensor, rtol: float = 1e-3, atol: float = 1e-1):
+    assert_close(t_a, t_b, rtol=rtol, atol=atol)
+    return True
 
 
-def tensor_shard_equal(tensor: torch.Tensor, shard: torch.Tensor, rank, world_size):
+def tensor_shard_equal(tensor: torch.Tensor,
+                       shard: torch.Tensor,
+                       rank: int,
+                       world_size: int,
+                       rtol: float = 1e-3,
+                       atol: float = 1e-1):
     assert tensor.ndim == shard.ndim
     if tensor.shape == shard.shape:
-        return tensor_equal(tensor, shard)
+        return tensor_equal(tensor, shard, rtol, atol)
     else:
         dims_not_eq = torch.nonzero(torch.tensor(tensor.shape) != torch.tensor(shard.shape))
         if dims_not_eq.numel() == 1:
@@ -58,7 +65,7 @@ def tensor_shard_equal(tensor: torch.Tensor, shard: torch.Tensor, rank, world_si
             world_size = gpc.get_world_size(ParallelMode.PARALLEL_1D)
             if rank is None:
                 rank = gpc.get_local_rank(ParallelMode.PARALLEL_1D)
-            return tensor_equal(tensor.chunk(world_size, dim)[rank], shard)
+            return tensor_equal(tensor.chunk(world_size, dim)[rank], shard, rtol, atol)
         else:
             raise NotImplementedError
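
The updated test helpers now surface mismatches through torch.testing.assert_close, which raises with a detailed message instead of silently returning False from torch.allclose, and the tolerances are configurable. A hedged usage sketch of the sharded-comparison path (tensor values and tolerances below are illustrative):

import torch
from torch.testing import assert_close


def tensor_equal(t_a, t_b, rtol=1e-3, atol=1e-1):
    assert_close(t_a, t_b, rtol=rtol, atol=atol)
    return True


world_size, rank = 2, 1
full = torch.arange(8, dtype=torch.float32)
shard = full.chunk(world_size, dim=0)[rank]    # the shard this rank would hold

# mirrors the sharded branch of tensor_shard_equal: chunk the full tensor, compare one shard
assert tensor_equal(full.chunk(world_size, dim=0)[rank], shard, rtol=1e-5, atol=1e-8)
print('shard comparison passed')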
