12 changes: 5 additions & 7 deletions colossalai/zero/sharded_optim/bookkeeping/bucket_store.py
@@ -7,7 +7,6 @@ class BucketStore(BaseStore):

def __init__(self, torch_pg: ProcessGroup):
super().__init__(torch_pg)
self._grads = dict()
self._params = dict()
self._num_elements_in_bucket = dict()

@@ -19,25 +18,24 @@ def num_elements_in_bucket(self, reduce_rank: int = None):
def add_num_elements_in_bucket(self, num_elements, reduce_rank: int = None):
self._num_elements_in_bucket[reduce_rank] += num_elements

def add_grad(self, tensor, reduce_rank: int = None):
self._grads[reduce_rank].append(tensor)

def add_param(self, tensor, reduce_rank: int = None):
self._params[reduce_rank].append(tensor)

def reset(self):
keys = [None] + list(range(self._world_size))
self._grads = {rank: [] for rank in keys}
self._params = {rank: [] for rank in keys}
self._num_elements_in_bucket = {rank: 0 for rank in keys}

def reset_by_rank(self, reduce_rank=None):
self._grads[reduce_rank] = []
self._params[reduce_rank] = []
self._num_elements_in_bucket[reduce_rank] = 0

def get_grad(self, reduce_rank: int = None):
return self._grads[reduce_rank]
param_list = self.get_param(reduce_rank)
for param in param_list:
# the param must have grad for reduction
assert param.grad is not None, f'Parameter of size ({param.size()}) has None grad, cannot be reduced'
return [param.grad for param in param_list]

def get_param(self, reduce_rank: int = None):
return self._params[reduce_rank]
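
For context outside the diff, here is a minimal standalone sketch of the new bookkeeping behaviour (a hypothetical `_BucketSketch`, not the library's `BucketStore`, which inherits `BaseStore` and is constructed from a `ProcessGroup`): the separate `_grads` dict is gone, and `get_grad` derives gradients from the stored parameters on demand, so the bucket can no longer hand back stale gradient tensors.

```python
import torch

class _BucketSketch:
    """Hypothetical stand-in for BucketStore: only params are stored, grads are derived."""

    def __init__(self, world_size: int):
        self._world_size = world_size
        self.reset()

    def reset(self):
        # one slot per reduce_rank, plus None for "reduce across the whole group"
        keys = [None] + list(range(self._world_size))
        self._params = {rank: [] for rank in keys}

    def add_param(self, tensor: torch.Tensor, reduce_rank=None):
        self._params[reduce_rank].append(tensor)

    def get_grad(self, reduce_rank=None):
        params = self._params[reduce_rank]
        # every bucketed param must already hold a grad, mirroring the assert in the PR
        assert all(p.grad is not None for p in params), 'a bucketed parameter has no grad'
        return [p.grad for p in params]


p = torch.nn.Parameter(torch.ones(4))
p.grad = torch.full((4,), 0.5)
store = _BucketSketch(world_size=2)
store.add_param(p)
print(store.get_grad())  # [tensor([0.5000, 0.5000, 0.5000, 0.5000])]
```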
159 changes: 73 additions & 86 deletions colossalai/zero/sharded_optim/low_level_optim.py
@@ -46,7 +46,7 @@ def __init__(
reduce_bucket_size: int = 1024 * 1024, # communication
communication_dtype: Optional[torch.dtype] = None,
overlap_communication: bool = False,
partition_grad: bool = False, # stage 2
partition_grad: bool = False, # stage 2 flag
cpu_offload: bool = False, # cpu offload
forced_dtype: Optional[torch.dtype] = None):

@@ -248,9 +248,13 @@ def _partition_param_list(self, param_list):
self._logger.info(f'Number of elements on ranks: {numel_per_rank}', ranks=[0])
return params_per_rank

###########################################################
# Backward Reduction Hook
###########################################################
###########################
# Backward Reduction Hook #
###########################

def _grad_handler(self, param, grad, reduce_rank):
self._add_to_reduction_bucket(param, reduce_rank)
return grad
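
This handler is registered as a plain tensor hook (see `param.register_hook(partial(...))` further down in this hunk), replacing the previous `AccumulateGrad`-object hooks. A small illustrative sketch with hypothetical names, showing how such a hook sees the finished gradient and returns it unchanged:

```python
from functools import partial

import torch

def grad_handler(param, grad, reduce_rank=None):
    # queue `param` for bucketed reduction here; return the grad unchanged
    print(f'bucketing a param with {param.numel()} elements for reduce_rank={reduce_rank}')
    return grad

w = torch.nn.Parameter(torch.randn(3))
w.register_hook(partial(grad_handler, w, reduce_rank=None))

(w * 2).sum().backward()   # the hook fires once w's gradient is computed
print(w.grad)              # tensor([2., 2., 2.])
```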

def _attach_reduction_hook(self):
# we iterate over the fp16 params
@@ -268,53 +272,61 @@ def _attach_reduction_hook(self):
else:
reduce_rank = None

def _define_and_attach(param, reduce_rank):
# get the AccumulateGrad object of the param itself
accum_grad_obj = get_grad_accumulate_object(param)
self._grad_store.add_accumulate_grad_object(accum_grad_obj)
param.register_hook(partial(self._grad_handler, param, reduce_rank=reduce_rank))

reduction_func = partial(self._reduce_and_remove_grads_by_bucket,
param=param,
reduce_rank=reduce_rank)
def _reduce_tensor_bucket(self, bucket: TensorBucket, reduce_rank):
if self._overlap_communication:
torch.cuda.synchronize()
self._param_store.clear_grads_of_previous_reduced_params()
stream = self._comm_stream
else:
stream = torch.cuda.current_stream()

# define hook
# NOT IMPORTANT BUT GOOD TO KNOW:
# args here are not the grad, but allow_unreachable and accumulate_grad
def reduce_grad_hook(*args):
reduction_func()
with torch.cuda.stream(stream):
flat = bucket.flatten()
reduce_global_rank = None
if reduce_rank is not None:
reduce_global_rank = self._dp_global_ranks[reduce_rank]
reduced_flat = reduce_tensor_dp_group(tensor=flat,
dtype=self._communication_dtype,
dst_local_rank=reduce_rank,
dst_global_rank=reduce_global_rank,
group=self._dp_torch_group)

accum_grad_obj.register_hook(reduce_grad_hook)
# update the reduced tensor
if reduce_rank is None or reduce_rank == self._local_rank:
bucket.unflatten_and_copy(reduced_flat)

_define_and_attach(param, reduce_rank)
def _reduce_tensor_list_with_one_dtype(self, tensor_list, bucket_size, reduce_rank):
param_bucket = TensorBucket(size=bucket_size)

def _reduce_and_remove_grads_by_bucket(self, param, reduce_rank=None):
param_size = param.numel()
for tensor in tensor_list:
param_bucket.add_to_bucket(tensor, allow_oversize=True)

# check if the bucket is full
# if full, will reduce the grads already in the bucket
# after reduction, the bucket will be empty
if self._bucket_store.num_elements_in_bucket(reduce_rank) + param_size > self._reduce_bucket_size:
self._reduce_grads_in_bucket(reduce_rank)
if param_bucket.is_full_or_oversized():
self._reduce_tensor_bucket(bucket=param_bucket, reduce_rank=reduce_rank)
param_bucket.empty()

# the param must not be reduced to ensure correctness
is_param_reduced = self._param_store.is_param_reduced(param)
if is_param_reduced:
msg = f'Parameter of size ({param.size()}) has already been reduced, ' \
+ 'duplicate reduction will lead to arithmetic incorrectness'
raise RuntimeError(msg)
if not param_bucket.is_empty():
self._reduce_tensor_bucket(bucket=param_bucket, reduce_rank=reduce_rank)

# the param must have grad for reduction
assert param.grad is not None, f'Parameter of size ({param.size()}) has None grad, cannot be reduced'
def _reduce_grads(self, reduce_rank, grads, bucket_size):
grad_buckets_by_dtype = split_half_float_double(grads)

self._bucket_store.add_num_elements_in_bucket(param_size, reduce_rank)
self._bucket_store.add_grad(param.grad, reduce_rank)
self._bucket_store.add_param(param, reduce_rank)
for tensor_list in grad_buckets_by_dtype:
self._reduce_tensor_list_with_one_dtype(tensor_list=tensor_list,
bucket_size=bucket_size,
reduce_rank=reduce_rank)
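
`split_half_float_double`, used above, groups gradients by dtype before bucketing so that each flattened bucket holds a single dtype. A rough stand-in for that grouping (not the library implementation):

```python
import torch

# Each resulting list feeds one flatten-and-reduce pass with a uniform dtype.
grads = [torch.ones(2, dtype=torch.float16), torch.ones(2), torch.ones(2, dtype=torch.float16)]

by_dtype = {}
for g in grads:
    by_dtype.setdefault(g.dtype, []).append(g)

print({dtype: len(ts) for dtype, ts in by_dtype.items()})
# {torch.float16: 2, torch.float32: 1}
```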

#######################
# Reduction Functions #
#######################

def _reduce_grads_in_bucket(self, reduce_rank=None):
def _run_reduction(self, reduce_rank=None):
# reduce grads
self._reduce_grads_by_rank(reduce_rank=reduce_rank,
grads=self._bucket_store.get_grad(reduce_rank=reduce_rank),
bucket_size=self._bucket_store.num_elements_in_bucket(reduce_rank))
self._reduce_grads(reduce_rank=reduce_rank,
grads=self._bucket_store.get_grad(reduce_rank=reduce_rank),
bucket_size=self._bucket_store.num_elements_in_bucket(reduce_rank))

# use communication stream if overlapping
# communication with computation
@@ -351,50 +363,24 @@ def _reduce_grads_in_bucket(self, reduce_rank=None):

self._bucket_store.reset_by_rank(reduce_rank)

def _reduce_grads_by_rank(self, reduce_rank, grads, bucket_size):
grad_buckets_by_dtype = split_half_float_double(grads)

for tensor_list in grad_buckets_by_dtype:
self._reduce_no_retain(tensor_list=tensor_list, bucket_size=bucket_size, reduce_rank=reduce_rank)

##############################
# Reduction Utility Function #
##############################
def _reduce_no_retain(self, tensor_list, bucket_size, reduce_rank):
param_bucket = TensorBucket(size=bucket_size)

for tensor in tensor_list:
param_bucket.add_to_bucket(tensor, allow_oversize=True)

if param_bucket.is_full_or_oversized():
self._reduce_and_copy(bucket=param_bucket, reduce_rank=reduce_rank)
param_bucket.empty()

if not param_bucket.is_empty():
self._reduce_and_copy(bucket=param_bucket, reduce_rank=reduce_rank)
def _add_to_reduction_bucket(self, param, reduce_rank=None):
param_size = param.numel()

def _reduce_and_copy(self, bucket: TensorBucket, reduce_rank):
if self._overlap_communication:
torch.cuda.synchronize()
self._param_store.clear_grads_of_previous_reduced_params()
stream = self._comm_stream
else:
stream = torch.cuda.current_stream()
# check if the bucket is full
# if full, will reduce the grads already in the bucket
# after reduction, the bucket will be empty
if self._bucket_store.num_elements_in_bucket(reduce_rank) + param_size > self._reduce_bucket_size:
self._run_reduction(reduce_rank)

with torch.cuda.stream(stream):
flat = bucket.flatten()
reduce_global_rank = None
if reduce_rank is not None:
reduce_global_rank = self._dp_global_ranks[reduce_rank]
reduced_flat = reduce_tensor_dp_group(tensor=flat,
dtype=self._communication_dtype,
dst_local_rank=reduce_rank,
dst_global_rank=reduce_global_rank,
group=self._dp_torch_group)
# the param must not be reduced to ensure correctness
is_param_reduced = self._param_store.is_param_reduced(param)
if is_param_reduced:
msg = f'Parameter of size ({param.size()}) has already been reduced, ' \
+ 'duplicate reduction will lead to arithmetic incorrectness'
raise RuntimeError(msg)

# update the reduced tensor
if reduce_rank is None or reduce_rank == self._local_rank:
bucket.unflatten_and_copy(reduced_flat)
self._bucket_store.add_num_elements_in_bucket(param_size, reduce_rank)
self._bucket_store.add_param(param, reduce_rank)

################################
# torch.optim.Optimizer methods
@@ -498,8 +484,9 @@ def step(self, closure=None):
# broadcast the updated model weights
handles = []
for group_id in range(self.num_param_groups):
for rank in range(self._world_size):
fp16_param = self._param_store.get_flat_fp16_param_by_rank_group(rank=rank, group_id=group_id)
for index in range(self._world_size):
rank = self._dp_global_ranks[index]
fp16_param = self._param_store.get_flat_fp16_param_by_rank_group(rank=index, group_id=group_id)
handle = dist.broadcast(fp16_param, src=rank, group=self._dp_torch_group, async_op=True)
handles.append(handle)
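
The loop above fixes a rank mix-up: `get_flat_fp16_param_by_rank_group` is indexed by the position inside the data-parallel group, while `dist.broadcast` needs the owner's global rank as `src`, and the two only coincide when the DP group spans global ranks 0..world_size-1. A toy illustration with hypothetical rank values:

```python
# Hypothetical placement: a 4-way data-parallel group living on global ranks 1, 3, 5, 7.
dp_global_ranks = [1, 3, 5, 7]

for index, src_global_rank in enumerate(dp_global_ranks):
    # real code: shard = param_store.get_flat_fp16_param_by_rank_group(rank=index, group_id=group_id)
    #            dist.broadcast(shard, src=src_global_rank, group=dp_group, async_op=True)
    print(f'shard {index} is broadcast from global rank {src_global_rank}')
```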

@@ -585,16 +572,16 @@ def _reduce_grad_stage1(self):
param_group = self._fp16_param_groups[group_id]
for param in param_group:
if param.grad is not None:
self._reduce_and_remove_grads_by_bucket(param)
self._add_to_reduction_bucket(param)

# we need to reduce the gradients
# left in the communication bucket
self._reduce_grads_in_bucket()
self._run_reduction()

def _reduce_grad_stage2(self):
# when partition_grads is True, reduction hooks
# are attached in the __init__ function, so we
# only need to reduce the gradients
# left in the communication bucket
for reduce_rank in range(self._world_size):
self._reduce_grads_in_bucket(reduce_rank)
self._run_reduction(reduce_rank)
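
A condensed sketch of the flatten → reduce → copy-back path that `_reduce_tensor_bucket` performs for each dtype bucket. The collective is replaced by a local stand-in so the snippet runs in a single process; in the real code `reduce_tensor_dp_group` does this step, on the dedicated communication stream when overlap is enabled, and the copy-back only happens on the rank that receives the reduced result.

```python
import torch
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors

grads = [torch.ones(3), torch.full((2,), 2.0)]

flat = _flatten_dense_tensors(grads)   # one contiguous buffer -> one collective call
reduced_flat = flat.clone()            # stand-in for reduce_tensor_dp_group(tensor=flat, ...)

# write the reduced values back into the original gradient tensors
for grad, synced in zip(grads, _unflatten_dense_tensors(reduced_flat, grads)):
    grad.copy_(synced)

print(grads[0], grads[1])  # tensor([1., 1., 1.]) tensor([2., 2.])
```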
17 changes: 12 additions & 5 deletions tests/test_tensor/common_utils/_utils.py
@@ -4,6 +4,7 @@
import numpy as np
import torch
import torch.distributed as dist
from torch.testing import assert_close

from colossalai.context import ParallelMode
from colossalai.core import global_context as gpc
@@ -41,14 +42,20 @@ def broadcast_tensor_chunk(tensor, chunk_size=1, local_rank=0):
return tensor_chunk.clone()


def tensor_equal(A, B):
return torch.allclose(A, B, rtol=1e-3, atol=1e-1)
def tensor_equal(t_a: torch.Tensor, t_b: torch.Tensor, rtol: float = 1e-3, atol: float = 1e-1):
assert_close(t_a, t_b, rtol=rtol, atol=atol)
return True


def tensor_shard_equal(tensor: torch.Tensor, shard: torch.Tensor, rank, world_size):
def tensor_shard_equal(tensor: torch.Tensor,
shard: torch.Tensor,
rank: int,
world_size: int,
rtol: float = 1e-3,
atol: float = 1e-1):
assert tensor.ndim == shard.ndim
if tensor.shape == shard.shape:
return tensor_equal(tensor, shard)
return tensor_equal(tensor, shard, rtol, atol)
else:
dims_not_eq = torch.nonzero(torch.tensor(tensor.shape) != torch.tensor(shard.shape))
if dims_not_eq.numel() == 1:
@@ -58,7 +65,7 @@ def tensor_shard_equal(tensor: torch.Tensor, shard: torch.Tensor, rank, world_size):
world_size = gpc.get_world_size(ParallelMode.PARALLEL_1D)
if rank is None:
rank = gpc.get_local_rank(ParallelMode.PARALLEL_1D)
return tensor_equal(tensor.chunk(world_size, dim)[rank], shard)
return tensor_equal(tensor.chunk(world_size, dim)[rank], shard, rtol, atol)
else:
raise NotImplementedError
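
A quick usage sketch of the updated helpers (assuming they are imported from the test utilities): the move to `torch.testing.assert_close` means a mismatch now raises with a detailed report instead of returning False, and the tolerances can be tightened per call.

```python
import torch

full = torch.arange(8, dtype=torch.float32)
shard = full.chunk(2, dim=0)[1]   # the slice rank 1 would hold under 1-D sharding

# shapes differ in exactly one dimension, so the matching chunk is compared
assert tensor_shard_equal(full, shard, rank=1, world_size=2, rtol=1e-3, atol=1e-1)

# a real mismatch now raises torch.testing's AssertionError instead of returning False
```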
