30 changes: 10 additions & 20 deletions colossalai/zero/sharded_optim/_utils.py
@@ -103,7 +103,11 @@ def split_half_float_double(tensor_list):
return buckets


def reduce_tensor_dp_group(tensor, dtype=None, dst_rank=None, pg: Optional[ProcessGroup] = None):
def reduce_tensor_dp_group(tensor: torch.Tensor,
dtype: Optional[torch.dtype] = None,
dst_local_rank: Optional[int] = None,
dst_global_rank: Optional[int] = None,
group: Optional[dist.ProcessGroup] = None):
"""
Reduce the tensor in the data parallel process group

@@ -128,36 +132,22 @@ def reduce_tensor_dp_group(tensor, dtype=None, dst_rank=None, pg: Optional[ProcessGroup] = None):
else:
tensor_to_reduce = tensor

if isinstance(pg, ProcessGroup):
group = pg.dp_process_group()
world_size = pg.dp_world_size()
else:
world_size = gpc.get_world_size(ParallelMode.DATA)
group = gpc.get_group(ParallelMode.DATA)

world_size = dist.get_world_size(group=group)
tensor_to_reduce.div_(world_size)

# if rank is None, all reduce will be used
# else, reduce is used
use_all_reduce = dst_rank is None
use_all_reduce = dst_local_rank is None

if use_all_reduce:
dist.all_reduce(tensor_to_reduce, group=group)
else:
if pg is not None:
ranks_in_group = pg.dp_rank_list()
else:
ranks_in_group = gpc.get_ranks_in_group(ParallelMode.DATA)
global_rank = ranks_in_group[dst_rank]
dist.reduce(tensor=tensor_to_reduce, dst=global_rank, group=group)
dist.reduce(tensor=tensor_to_reduce, dst=dst_global_rank, group=group)

# recover the original dtype
if tensor.dtype != dtype and tensor is not tensor_to_reduce:
if pg is not None:
local_rank = pg.dp_local_rank()
else:
local_rank = gpc.get_local_rank(ParallelMode.DATA)
if use_all_reduce or dst_rank == local_rank:
local_rank = dist.get_rank(group=group)
if use_all_reduce or dst_local_rank == local_rank:
tensor.copy_(tensor_to_reduce)

return tensor
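
For reference, a minimal call sketch of the reworked helper: it now takes a plain torch.distributed process group plus both the local index and the global rank of the reduction destination. The group and tensor names below are illustrative assumptions, not part of the patch.

import torch
import torch.distributed as dist

# Hypothetical data-parallel group covering all ranks (assumption for this sketch).
dp_ranks = list(range(dist.get_world_size()))
dp_group = dist.new_group(ranks=dp_ranks)

flat_grads = torch.randn(1024, device='cuda')    # placeholder bucket of gradients
reduced = reduce_tensor_dp_group(tensor=flat_grads,
                                 dtype=torch.float32,
                                 dst_local_rank=0,              # index inside dp_ranks
                                 dst_global_rank=dp_ranks[0],   # matching global rank
                                 group=dp_group)
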
17 changes: 5 additions & 12 deletions colossalai/zero/sharded_optim/bookkeeping/base_store.py
@@ -1,19 +1,12 @@
from typing import Optional

from colossalai.context import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.tensor import ProcessGroup
import torch.distributed as dist
from torch.distributed import ProcessGroup


class BaseStore:

def __init__(self, pg: Optional[ProcessGroup] = None):
if isinstance(pg, ProcessGroup):
self._world_size = pg.dp_world_size()
self._local_rank = pg.dp_local_rank()
else:
self._world_size = gpc.get_world_size(ParallelMode.DATA)
self._local_rank = gpc.get_local_rank(ParallelMode.DATA)
def __init__(self, torch_pg: ProcessGroup):
self._world_size = dist.get_world_size(group=torch_pg)
self._local_rank = dist.get_rank(group=torch_pg)

@property
def world_size(self):
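
A short usage sketch of the simplified store: world size and rank now come straight from torch.distributed, so any torch ProcessGroup works. The group below is an assumed data-parallel group, not part of the patch.

import torch.distributed as dist

dp_group = dist.new_group(ranks=list(range(dist.get_world_size())))  # assumed DP group
store = BaseStore(dp_group)
# dist.get_rank(group=...) returns the rank *within* the group, matching the
# old gpc.get_local_rank(ParallelMode.DATA) behaviour on the data-parallel axis.
assert store.world_size == dist.get_world_size(group=dp_group)
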
8 changes: 3 additions & 5 deletions colossalai/zero/sharded_optim/bookkeeping/bucket_store.py
@@ -1,14 +1,12 @@
from typing import Optional

from colossalai.tensor import ProcessGroup
from torch.distributed import ProcessGroup

from .base_store import BaseStore


class BucketStore(BaseStore):

def __init__(self, pg: Optional[ProcessGroup] = None):
super().__init__(pg)
def __init__(self, torch_pg: ProcessGroup):
super().__init__(torch_pg)
self._grads = dict()
self._params = dict()
self._num_elements_in_bucket = dict()
colossalai/zero/sharded_optim/bookkeeping/parameter_store.py
@@ -1,16 +1,15 @@
from typing import List, Optional
from typing import List

from torch import Tensor

from colossalai.tensor import ProcessGroup
from torch.distributed import ProcessGroup

from .base_store import BaseStore


class ParameterStore(BaseStore):

def __init__(self, pg: Optional[ProcessGroup] = None):
super().__init__(pg)
def __init__(self, torch_pg: ProcessGroup):
super().__init__(torch_pg)
# param partitioning data structures
self._fp16_param_to_rank = dict()
self._rank_groupid_to_fp16_param_list = dict()
135 changes: 69 additions & 66 deletions colossalai/zero/sharded_optim/low_level_optim.py
@@ -10,7 +10,7 @@
from colossalai.core import global_context as gpc
from colossalai.logging import get_dist_logger
from colossalai.nn.optimizer import ColossalaiOptimizer
from colossalai.tensor import ProcessGroup
from colossalai.tensor import ColoParameter, ProcessGroup
from colossalai.utils.cuda import get_current_device

from ._utils import (
@@ -34,32 +34,21 @@ class LowLevelZeroOptimizer(ColossalaiOptimizer):
def __init__(
self,
optimizer: Optimizer,
pg: Optional[ProcessGroup] = None,
# grad scaler config
initial_scale=2**16,
min_scale=1,
growth_factor=2,
backoff_factor=0.5,
growth_interval=2000,
hysteresis=2,
initial_scale: int = 2**16, # grad scaler config
min_scale: int = 1,
growth_factor: float = 2.,
backoff_factor: float = .5,
growth_interval: int = 2000,
hysteresis: int = 2,
max_scale: int = 2**24,

# grad clipping
clip_grad_norm=0.0,
verbose=False,

# communication
reduce_bucket_size=1024 * 1024,
communication_dtype=None,
overlap_communication=False,

# stage 2
partition_grad=False,
# cpu offload
cpu_offload=False,

# forced dtype
forced_dtype=None):
clip_grad_norm: float = 0.0, # grad clipping
verbose: bool = False,
reduce_bucket_size: int = 1024 * 1024, # communication
communication_dtype: Optional[torch.dtype] = None,
overlap_communication: bool = False,
partition_grad: bool = False, # stage 2
cpu_offload: bool = False, # cpu offload
forced_dtype: Optional[torch.dtype] = None):

# TODO: add support for
# 1. fp16 master weights
@@ -76,31 +65,30 @@ def __init__(

self._cpu_offload = cpu_offload

self._pg = pg
if isinstance(pg, ProcessGroup):
self._local_rank = pg.dp_local_rank()
self._world_size = pg.dp_world_size()
self._dp_group = pg.dp_process_group()
if pg.tp_world_size() > 1:
self._mp_group = pg.tp_process_group()
else:
self._mp_group = None
elif pg is None:
colo_pg = self._search_colo_process_group()
if isinstance(colo_pg, ProcessGroup):
self._local_rank = colo_pg.dp_local_rank()
self._world_size = colo_pg.dp_world_size()
self._dp_global_ranks = colo_pg.get_ranks_in_dp()
self._dp_torch_group = colo_pg.dp_process_group()
self._mp_torch_group = None
if colo_pg.tp_world_size() > 1:
self._mp_torch_group = colo_pg.tp_process_group()
elif colo_pg is None:
dp_parallel_mode = ParallelMode.DATA
mp_parallel_mode = ParallelMode.MODEL

self._dp_parallel_mode = dp_parallel_mode
self._mp_parallel_mode = mp_parallel_mode
self._local_rank = gpc.get_local_rank(dp_parallel_mode)
self._world_size = gpc.get_world_size(dp_parallel_mode)

self._dp_group = gpc.get_group(dp_parallel_mode)
self._dp_global_ranks = gpc.get_ranks_in_group(dp_parallel_mode)
self._dp_torch_group = gpc.get_group(dp_parallel_mode)
self._mp_torch_group = None
if gpc.is_initialized(mp_parallel_mode) and gpc.get_world_size(mp_parallel_mode) > 1:
self._mp_group = gpc.get_group(mp_parallel_mode)
else:
self._mp_group = None
self._mp_torch_group = gpc.get_group(mp_parallel_mode)
else:
raise TypeError(f"pg should be None or a ProcesGroup")
raise NotImplementedError
# fp16 and fp32 params for mixed precision training
self._fp16_param_groups = dict()
self._fp32_flat_param_groups_of_current_rank = dict()
@@ -136,14 +124,9 @@ def __init__(

# ParameterStore will manage the tensor buffers used for zero
# it will not manage the tensors used by mixed precision training
if self._pg is not None:
self._param_store = ParameterStore(self._pg)
self._grad_store = GradientStore(self._pg)
self._bucket_store = BucketStore(self._pg)
else:
self._param_store = ParameterStore(self._dp_parallel_mode)
self._grad_store = GradientStore(self._dp_parallel_mode)
self._bucket_store = BucketStore(self._dp_parallel_mode)
self._param_store = ParameterStore(self._dp_torch_group)
self._grad_store = GradientStore(self._dp_torch_group)
self._bucket_store = BucketStore(self._dp_torch_group)

# iterate over the param group in the optimizer
# partition these param groups for data parallel training
@@ -224,6 +207,30 @@ def loss_scale(self):
def num_param_groups(self):
return len(self._fp16_param_groups)

def _sanity_checks(self):
assert torch.cuda.is_available(), 'CUDA is required'
for param_group in self.optim.param_groups:
group_params = param_group['params']
for param in group_params:
assert param.dtype == self._dtype, \
f"Parameters are expected to have the same dtype `{self._dtype}`, but got `{param.dtype}`"

def _search_colo_process_group(self):
colo_flag = False
colo_pg = None
for param_group in self.optim.param_groups:
group_params = param_group['params']
for param in group_params:
if isinstance(param, ColoParameter):
colo_flag = True
if colo_pg is None:
colo_pg = param.get_process_group()
else:
assert colo_pg == param.get_process_group(), "All parameters should be in a same process group"
elif colo_flag:
raise RuntimeError("All parameters should be ColoParameter if you use ColoParameter.")
return colo_pg
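
In other words, the search enforces an all-or-nothing rule: either every parameter is a ColoParameter sharing one ProcessGroup (whose DP/TP sub-groups are then used), or none is and the optimizer falls back to gpc's DATA/MODEL groups. A condensed behavioural sketch, written as a hypothetical stand-alone helper rather than the method above:

def sketch_search_colo_pg(params):
    # Mirrors the rule enforced by _search_colo_process_group.
    colo_pg = None
    seen_colo = False
    for p in params:
        if isinstance(p, ColoParameter):
            seen_colo = True
            pg = p.get_process_group()
            if colo_pg is None:
                colo_pg = pg
            assert colo_pg == pg, "All parameters should be in a same process group"
        elif seen_colo:
            raise RuntimeError("All parameters should be ColoParameter if you use ColoParameter.")
    return colo_pg
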

def _partition_param_list(self, param_list):
params_per_rank = [[] for _ in range(self._world_size)]
numel_per_rank = [0 for _ in range(self._world_size)]
@@ -241,14 +248,6 @@ def _partition_param_list(self, param_list):
self._logger.info(f'Number of elements on ranks: {numel_per_rank}', ranks=[0])
return params_per_rank

def _sanity_checks(self):
assert torch.cuda.is_available(), 'CUDA is required'
for param_group in self.optim.param_groups:
group_params = param_group['params']
for param in group_params:
assert param.dtype == self._dtype, \
f"Parameters are expected to have the same dtype `{self._dtype}`, but got `{param.dtype}`"

###########################################################
# Backward Reduction Hook
###########################################################
@@ -384,10 +383,14 @@ def _reduce_and_copy(self, bucket: TensorBucket, reduce_rank):

with torch.cuda.stream(stream):
flat = bucket.flatten()
reduce_global_rank = None
if reduce_rank is not None:
reduce_global_rank = self._dp_global_ranks[reduce_rank]
reduced_flat = reduce_tensor_dp_group(tensor=flat,
dtype=self._communication_dtype,
dst_rank=reduce_rank,
pg=self._pg)
dst_local_rank=reduce_rank,
dst_global_rank=reduce_global_rank,
group=self._dp_torch_group)

# update the reduced tensor
if reduce_rank is None or reduce_rank == self._local_rank:
@@ -456,8 +459,8 @@ def step(self, closure=None):
norm_group = compute_norm(gradients=self._grad_store._averaged_gradients[group_id],
params=self._param_store.get_fp16_params_by_rank_group(group_id=group_id,
rank=self._local_rank),
dp_group=self._dp_group,
mp_group=self._mp_group)
dp_group=self._dp_torch_group,
mp_group=self._mp_torch_group)
norm_groups.append(norm_group)

# create flat gradient for the flat fp32 params
@@ -497,7 +500,7 @@ def step(self, closure=None):
for group_id in range(self.num_param_groups):
for rank in range(self._world_size):
fp16_param = self._param_store.get_flat_fp16_param_by_rank_group(rank=rank, group_id=group_id)
handle = dist.broadcast(fp16_param, src=rank, group=self._dp_group, async_op=True)
handle = dist.broadcast(fp16_param, src=rank, group=self._dp_torch_group, async_op=True)
handles.append(handle)

for handle in handles:
Expand All @@ -519,11 +522,11 @@ def _check_overflow(self):
break

# all-reduce across dp group
dist.all_reduce(self._found_overflow, op=dist.ReduceOp.MAX, group=self._dp_group)
dist.all_reduce(self._found_overflow, op=dist.ReduceOp.MAX, group=self._dp_torch_group)

# all-reduce over model parallel group
if self._mp_group:
dist.all_reduce(self._found_overflow, op=dist.ReduceOp.MAX, group=self._mp_group)
if self._mp_torch_group:
dist.all_reduce(self._found_overflow, op=dist.ReduceOp.MAX, group=self._mp_torch_group)

if self._found_overflow.item() > 0:
return True
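
The overflow check above reduces a single flag with MAX across the data-parallel torch group and, if present, the model-parallel torch group. A stand-alone sketch of the same pattern, with the group handles assumed to exist:

import torch
import torch.distributed as dist

found_overflow = torch.zeros(1, device='cuda')
# ... set found_overflow to 1.0 locally when any gradient contains inf/nan ...
dist.all_reduce(found_overflow, op=dist.ReduceOp.MAX, group=dp_torch_group)    # assumed DP group
if mp_torch_group is not None:                                                 # assumed MP group or None
    dist.all_reduce(found_overflow, op=dist.ReduceOp.MAX, group=mp_torch_group)
overflow = found_overflow.item() > 0
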
12 changes: 3 additions & 9 deletions tests/test_zero/low_level_zero/test_grad_acc.py
@@ -35,18 +35,15 @@ def exam_zero_1_2_grad_acc():
# create model
zero1_model = TestModel().cuda()
zero2_model = copy.deepcopy(zero1_model)
pg = ProcessGroup()
# create optimizer
zero1_optimizer = torch.optim.Adam(zero1_model.parameters(), lr=1)
zero2_optimizer = torch.optim.Adam(zero2_model.parameters(), lr=1)
zero1_optimizer = LowLevelZeroOptimizer(zero1_optimizer,
pg=pg,
overlap_communication=True,
initial_scale=32,
clip_grad_norm=1.0,
verbose=True)
zero2_optimizer = LowLevelZeroOptimizer(zero2_optimizer,
pg=pg,
overlap_communication=True,
partition_grad=True,
initial_scale=32,
@@ -86,7 +83,7 @@ def fwd_bwd_func(number, cur_data):
assert torch.equal(z1p.data, z2p.data)


def exam_zero_1_grad_acc(use_pg=True):
def exam_zero_1_grad_acc():
local_rank = torch.distributed.get_rank()
grad_scale = 32
seed_all(2008)
@@ -105,9 +102,7 @@ def exam_zero_1_grad_acc():
# we only test stage 1 here
# in `check_sharded_param_consistency.py`, we will test whether
# level 1 and 2 will produce exactly the same results
pg = ProcessGroup() if use_pg else None #ProcessGroup()
zero_optimizer = LowLevelZeroOptimizer(zero_optimizer,
pg=pg,
overlap_communication=False,
initial_scale=grad_scale,
reduce_bucket_size=262144,
@@ -158,9 +153,8 @@ def fwd_bwd_func(number, cur_data, check_flag):
def run_dist(rank, world_size, port):
colossalai.launch(config=dict(), rank=rank, world_size=world_size, port=port, host='localhost')

exam_zero_1_grad_acc(True)
exam_zero_1_grad_acc(False)
# exam_zero_1_2_grad_acc()
exam_zero_1_grad_acc()
exam_zero_1_2_grad_acc()


@pytest.mark.dist
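
With the `pg` argument gone, the tests construct the optimizer directly and let it discover the process groups itself. A construction sketch along the lines of the updated test; the model and hyper-parameters are illustrative, and the import path simply follows the file location in this PR:

import torch
from colossalai.zero.sharded_optim.low_level_optim import LowLevelZeroOptimizer

model = TestModel().cuda()                      # TestModel as defined in the test file
base_optim = torch.optim.Adam(model.parameters(), lr=1)
zero_optim = LowLevelZeroOptimizer(base_optim,
                                   overlap_communication=True,
                                   partition_grad=False,        # stage 1
                                   initial_scale=32,
                                   clip_grad_norm=1.0)
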