2 changes: 1 addition & 1 deletion colossalai/builder/pipeline.py
@@ -198,7 +198,7 @@ def _partition_layers(self, method):
for st, ed in self.parts[stage]:
for idx, layer in enumerate(self.layers[st: ed]):
log_str += f'\t{idx + st:2d}: {layer}\n'
self._logger.info(log_str)
self._logger.info(log_str, ranks=[0])

# Save the partition
self._interval = self.parts[pipeline_rank]
4 changes: 2 additions & 2 deletions colossalai/communication/__init__.py
@@ -1,12 +1,12 @@
from .collective import all_gather, reduce_scatter, scatter
from .collective import all_gather, reduce_scatter, all_reduce
from .p2p import (send_forward, send_forward_recv_forward, send_backward_recv_forward,
send_backward, send_backward_recv_backward, send_forward_recv_backward,
send_forward_backward_recv_forward_backward, recv_forward, recv_backward)
from .ring import ring_forward
from .utils import send_tensor_meta, recv_tensor_meta

__all__ = [
'all_gather', 'reduce_scatter', 'scatter',
'all_gather', 'reduce_scatter', 'all_reduce',
'send_forward', 'send_forward_recv_forward', 'send_forward_backward_recv_forward_backward',
'send_backward', 'send_backward_recv_backward', 'send_backward_recv_forward',
'send_forward_recv_backward', 'recv_backward', 'recv_forward',
104 changes: 66 additions & 38 deletions colossalai/communication/collective.py
@@ -11,7 +11,7 @@


def all_gather(tensor: Tensor, dim: int,
parallel_mode: ParallelMode) -> Tensor:
parallel_mode: ParallelMode, async_op=False) -> Tensor:
"""Gathers all tensors from the parallel group and concatenates them in a
specific dimension.

@@ -26,18 +26,28 @@ def all_gather(tensor: Tensor, dim: int,
"""
depth = gpc.get_world_size(parallel_mode)
temp = tensor.clone()
shape = list(temp.shape)
shape[dim] *= depth
out = torch.empty(shape, dtype=temp.dtype, device=get_current_device())
out = list(torch.chunk(out, depth, dim=dim))
out = [val.contiguous() for val in out]
dist.all_gather(out, temp, group=gpc.get_group(parallel_mode))
out = torch.cat(out, dim=dim)
return out
# shape = list(temp.shape)
# shape[dim] *= depth
# out = torch.zeros(shape, dtype=temp.dtype, device=get_current_device())
# out = list(torch.chunk(out, depth, dim=dim))
# out = [val.contiguous() for val in out]
shape = [1] * len(tensor.shape)
shape[dim] = depth
out = tensor.repeat(shape)
out = list(map(lambda x: x.contiguous(), torch.chunk(out, depth, dim=dim)))
op = dist.all_gather(tensor_list=out,
tensor=temp,
group=gpc.get_group(parallel_mode),
async_op=async_op)
# out = torch.cat(out, dim=dim)
if async_op:
return out, op
else:
return out


def reduce_scatter(tensor: Tensor, dim: int,
parallel_mode: ParallelMode) -> Tensor:
parallel_mode: ParallelMode, async_op=False) -> Tensor:
"""Reduces all tensors then scatters it in a specific dimension to all
members in the parallel group.

@@ -51,34 +61,52 @@ def reduce_scatter(tensor: Tensor, dim: int,
:rtype: Tensor
"""
depth = gpc.get_world_size(parallel_mode)
temp = list(torch.chunk(tensor, depth, dim=dim))
temp = [val.contiguous() for val in temp]
out = torch.empty(temp[0].shape,
dtype=temp[0].dtype,
device=get_current_device())
dist.reduce_scatter(output=out,
input_list=temp,
group=gpc.get_group(parallel_mode))
return out
# temp = list(torch.chunk(tensor, depth, dim=dim))
# temp = [val.contiguous() for val in temp]
# out = torch.zeros(temp[0].shape,
# dtype=temp[0].dtype,
# device=get_current_device())
temp = list(map(lambda x: x.contiguous(), torch.chunk(tensor, depth, dim=dim)))
out = temp[0].clone()
op = dist.reduce_scatter(output=out,
input_list=temp,
group=gpc.get_group(parallel_mode),
async_op=async_op)
if async_op:
return out, op
else:
return out


def scatter(tensor: Tensor, src: int, dim: int,
parallel_mode: ParallelMode) -> Tensor:
"""Scatters in a specific dimension from source rank to all ranks in
the parallel group.
def all_reduce(tensor: Tensor,
parallel_mode: ParallelMode,
async_op=False) -> Tensor:
op = dist.all_reduce(tensor,
group=gpc.get_group(parallel_mode),
async_op=async_op)
if async_op:
return tensor, op
else:
return tensor


# def scatter(tensor: Tensor, src: int, dim: int,
# parallel_mode: ParallelMode) -> Tensor:
# """Scatters in a specific dimension from source rank to all ranks in
# the parallel group.

:param tensor: Tensor to be scattered
:param dim: The dimension scattering in
:param parallel_mode: Parallel group mode used in this communication
:type tensor: Tensor
:type dim: int
:type parallel_mode: ParallelMode
:return: The tensor generated by scatter
:rtype: Tensor
"""
depth = gpc.get_world_size(parallel_mode)
temp = tensor.clone()
dist.broadcast(temp, src=src, group=gpc.get_group(parallel_mode))
rank = gpc.get_local_rank(parallel_mode)
out = torch.chunk(temp, depth, dim=dim)[rank].contiguous()
return out
# :param tensor: Tensor to be scattered
# :param dim: The dimension scattering in
# :param parallel_mode: Parallel group mode used in this communication
# :type tensor: Tensor
# :type dim: int
# :type parallel_mode: ParallelMode
# :return: The tensor generated by scatter
# :rtype: Tensor
# """
# depth = gpc.get_world_size(parallel_mode)
# temp = tensor.clone()
# dist.broadcast(temp, src=src, group=gpc.get_group(parallel_mode))
# rank = gpc.get_local_rank(parallel_mode)
# out = torch.chunk(temp, depth, dim=dim)[rank].contiguous()
# return out
6 changes: 5 additions & 1 deletion colossalai/constants.py
@@ -25,7 +25,11 @@

# 3D parallel
DEPTH_3D = 'DEPTH_3D'
INPUT_GROUP_3D = 'PARALLEL_3D_INPUT'
WEIGHT_GROUP_3D = 'PARALLEL_3D_WEIGHT'
OUTPUT_GROUP_3D = 'PARALLEL_3D_OUTPUT'

# Tensor parallel attributes
IS_TENSOR_PARALLEL = 'is_tensor_parallel'
TENSOR_PARALLEL_ATTRIBUTES = [IS_TENSOR_PARALLEL]
NUM_PARTITIONS = 'num_partitions'
TENSOR_PARALLEL_ATTRIBUTES = [IS_TENSOR_PARALLEL, NUM_PARTITIONS]
26 changes: 16 additions & 10 deletions colossalai/context/parallel_context.py
@@ -277,10 +277,18 @@ def init_global_dist(self, addr=None, port=None):
:type port: int, optional
"""
# get config
rank = self._dist_args.local_rank
local_rank = self._dist_args.local_rank
rank = self._dist_args.rank
world_size = self._dist_args.world_size
if local_rank is None:
local_rank = os.getenv('LOCAL_RANK')
if rank is None:
rank = os.getenv('RANK')
if world_size is None:
world_size = os.getenv('WORLD_SIZE')
# default env config, overwrite by exporting
# them in your bash script

addr = os.getenv('MASTER_ADDR', 'localhost') if addr is None else addr
port = os.getenv('MASTER_PORT', '8008') if port is None else port
init_method = f'tcp://{addr}:{port}'
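
Editor's note on usage: the fallbacks above read the environment variables that a distributed launcher (or your own launch script) typically exports per process. A hedged sketch of that environment, with illustrative values set by hand; note that `os.getenv` returns strings, so ranks and world size still need integer conversion before being handed to `torch.distributed`.

# Hedged sketch only: these variables are normally exported by the launcher
# or a bash script, not set in user code.
import os

os.environ.setdefault('RANK', '0')            # global rank of this process
os.environ.setdefault('LOCAL_RANK', '0')      # rank within the current node
os.environ.setdefault('WORLD_SIZE', '1')      # total number of processes
os.environ.setdefault('MASTER_ADDR', 'localhost')
os.environ.setdefault('MASTER_PORT', '8008')  # matches the default used above

# With these present, init_global_dist() can be called without explicit rank or
# world_size arguments; the values come back as strings and need int() downstream.
rank = int(os.getenv('RANK'))
world_size = int(os.getenv('WORLD_SIZE'))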
@@ -293,7 +301,8 @@ def init_global_dist(self, addr=None, port=None):
# None will give the default global process group for pytorch dist operations
self._register_dist(rank, world_size, None,
list(range(world_size)), ParallelMode.GLOBAL)
self._global_ranks[ParallelMode.GLOBAL] = rank
self.add_global_rank(ParallelMode.GLOBAL, rank)
# self._global_ranks[ParallelMode.GLOBAL] = rank

def _register_dist(self, local_rank, world_size,
process_group, ranks_in_group, mode):
@@ -426,18 +435,15 @@ def set_seed(self):
if torch.cuda.is_available():
# create random seed for different parallel modes
# data parallel seed are kept the same
parallel_seed = seed
tp_rank = self._local_ranks.get(ParallelMode.TENSOR, 0)
pp_rank = self._local_ranks.get(ParallelMode.PIPELINE, 0)
parallel_seed = seed + tp_rank + pp_rank * 1024
add_seed(ParallelMode.DATA, parallel_seed)

# model parallel seeds are different across ranks
pipeline_offset = self._local_ranks.get(ParallelMode.PIPELINE, 0)

# add seed for data parallel and tensor parallel only
if self.is_initialized(ParallelMode.TENSOR):
tp_rank = self.get_local_rank(ParallelMode.TENSOR)
# 100 is only to increase the diff in seeds between pipeline stages
tp_rank_with_offset = tp_rank + pipeline_offset * 1024
tp_seed = seed + tp_rank_with_offset
dp_rank = self._local_ranks.get(ParallelMode.DATA, 0) + 1
tp_seed = parallel_seed + dp_rank * 128
add_seed(ParallelMode.TENSOR, tp_seed)

set_mode(ParallelMode.DATA)
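
Editor's note on the seeding change: under the revised `set_seed`, the data-parallel seed is offset by the tensor- and pipeline-parallel ranks, and the tensor-parallel seed is further offset by the data-parallel rank. A small worked example with assumed rank values:

# Worked example of the new seed formulas (rank values are illustrative only).
seed = 1024                            # user-provided base seed
tp_rank, pp_rank, dp_rank = 1, 2, 3    # assumed local ranks for one process

parallel_seed = seed + tp_rank + pp_rank * 1024   # 1024 + 1 + 2048 = 3073
tp_seed = parallel_seed + (dp_rank + 1) * 128     # 3073 + 4 * 128  = 3585

# add_seed(ParallelMode.DATA,   parallel_seed)  -> 3073, same for every data-parallel replica
# add_seed(ParallelMode.TENSOR, tp_seed)        -> 3585, distinct per data-parallel rank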
@@ -5,7 +5,7 @@
import os

import torch.distributed as dist
from colossalai.constants import DEPTH_3D
from colossalai.constants import DEPTH_3D, INPUT_GROUP_3D, WEIGHT_GROUP_3D, OUTPUT_GROUP_3D
from colossalai.registry import DIST_GROUP_INITIALIZER

from ..parallel_mode import ParallelMode
@@ -18,7 +18,7 @@ def _check_depth_env_var(depth):

if env_depth:
assert int(env_depth) == depth, \
'SUMMA_DIM has been set in the current environment and ' \
'DEPTH_3D has been set in the current environment and ' \
'does not match the value passed to this initializer'
else:
os.environ[DEPTH_3D] = str(depth)
@@ -43,6 +43,7 @@ def init_dist_group(self):
process_group = None
group_world_size = None
mode = ParallelMode.PARALLEL_3D_INPUT
os.environ[INPUT_GROUP_3D] = INPUT_GROUP_3D

for h in range(self.num_group):
for i in range(self.depth):
@@ -82,6 +83,7 @@ def init_dist_group(self):
process_group = None
group_world_size = None
mode = ParallelMode.PARALLEL_3D_WEIGHT
os.environ[WEIGHT_GROUP_3D] = WEIGHT_GROUP_3D

for h in range(self.num_group):
for k in range(self.depth):
@@ -121,6 +123,7 @@ def init_dist_group(self):
process_group = None
group_world_size = None
mode = ParallelMode.PARALLEL_3D_OUTPUT
os.environ[OUTPUT_GROUP_3D] = OUTPUT_GROUP_3D

for h in range(self.num_group):
for i in range(self.depth):
14 changes: 11 additions & 3 deletions colossalai/initialize.py
@@ -42,19 +42,21 @@ def parse_args():
type=str,
default=None,
help='the master port for distributed training')
parser.add_argument('--world_size', type=int, help='world size for ')
parser.add_argument('--world_size', type=int, help='world size for distributed training')
parser.add_argument('--rank', type=int, help='rank for the default process group')
parser.add_argument('--local_rank',
type=int,
help='rank for the default process group')
help='local rank on the node')
parser.add_argument('--backend',
type=str,
default='nccl',
help='backend for torch.distributed')
help='backend for distributed communication')
return parser.parse_args()


def init_dist(config: Union[str, dict] = None,
local_rank: int = None,
rank: int = None,
world_size: int = None,
host: str = None,
port: str = None,
@@ -86,6 +88,8 @@ def init_dist(config: Union[str, dict] = None,
config = args.config
if local_rank is None:
local_rank = args.local_rank
if rank is None:
rank = args.rank
if world_size is None:
world_size = args.world_size
if host is None:
@@ -99,12 +103,14 @@
host=host,
port=port,
world_size=world_size,
rank=rank,
local_rank=local_rank,
backend=backend))

# set distributed settings
dist_args = Config(
dict(local_rank=args.local_rank,
rank=rank,
world_size=args.world_size,
backend=args.backend))

@@ -178,6 +184,7 @@ def seed_worker(worker_id):

def initialize(config: Union[str, dict] = None,
local_rank: int = None,
rank: int = None,
world_size: int = None,
host: str = None,
port: str = None,
@@ -209,6 +216,7 @@ def initialize(config: Union[str, dict] = None,
# initialize distributed environment
init_dist(config=config,
local_rank=local_rank,
rank=rank,
world_size=world_size,
host=host,
port=port,
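
Editor's note on usage: with the new `rank` parameter threaded through, a launch script can pass the global rank explicitly rather than relying on environment variables. A minimal sketch of calling `init_dist` with the signature shown above; the config path, address, and port are placeholders, and passing `backend` as a keyword is an assumption based on the CLI parser rather than something visible in this hunk.

# Hedged sketch: explicit-rank launch of one process of an 8-process job.
from colossalai.initialize import init_dist

init_dist(
    config='./config.py',   # path to the experiment config (placeholder)
    local_rank=0,           # rank of this process on its node
    rank=0,                 # global rank across all nodes (newly added argument)
    world_size=8,           # total number of processes
    host='127.0.0.1',       # master address (placeholder)
    port='29500',           # master port (placeholder)
    backend='nccl',         # torch.distributed backend (assumed keyword)
)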
33 changes: 33 additions & 0 deletions colossalai/nn/init.py
@@ -0,0 +1,33 @@
import math

from torch import Tensor
from torch.nn import init as init


def init_weight_(tensor: Tensor, fan_in: int, fan_out: int = None, init_method: str = 'torch'):
if init_method == 'torch':
a = math.sqrt(5)
nonlinearity = 'leaky_relu'
std = init.calculate_gain(nonlinearity, a) / math.sqrt(fan_in)
bound = math.sqrt(3.0) * std
init.uniform_(tensor, -bound, bound)
elif init_method == 'jax':
std = math.sqrt(2.0 / float(fan_in + fan_out))
a = math.sqrt(3.0) * std
init.uniform_(tensor, -a, a)
elif init_method == 'jax_embed':
std = math.sqrt(1.0 / fan_in)
init.trunc_normal_(tensor, std=std / .87962566103423978)
elif init_method == 'zero':
init.zeros_(tensor)

def init_bias_(tensor: Tensor, fan_in: int, init_method: str = 'torch'):
if init_method == 'torch':
bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
init.uniform_(tensor, -bound, bound)
elif init_method == 'jax':
init.normal_(tensor, std=1e-6)
elif init_method == 'jax_embed':
init.trunc_normal_(tensor, std=.02)
elif init_method == 'zero':
init.zeros_(tensor)
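
Editor's note on usage: the two helpers above select an initialization scheme by name ('torch', 'jax', 'jax_embed', or 'zero'). A short usage sketch applying them to a standard linear layer, with fan_in/fan_out taken from the layer's feature dimensions:

# Hedged usage sketch for the helpers in colossalai/nn/init.py above.
import torch.nn as nn
from colossalai.nn.init import init_weight_, init_bias_

layer = nn.Linear(in_features=512, out_features=1024)

# 'jax' weight init is Xavier/Glorot uniform: bound = sqrt(6 / (fan_in + fan_out))
init_weight_(layer.weight, fan_in=512, fan_out=1024, init_method='jax')

# 'torch' bias init matches nn.Linear's default: uniform(-1/sqrt(fan_in), 1/sqrt(fan_in))
init_bias_(layer.bias, fan_in=512, init_method='torch')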
16 changes: 9 additions & 7 deletions colossalai/nn/layer/_common_utils.py
@@ -3,12 +3,12 @@

import math

import numpy as np
from colossalai.utils.common import print_rank_0
import torch
from torch import Tensor
from torch import nn
from colossalai.constants import IS_TENSOR_PARALLEL, NUM_PARTITIONS
from colossalai.utils import checkpoint

from colossalai.constants import IS_TENSOR_PARALLEL
from torch import Tensor, nn


def divide(numerator, denominator):
@@ -33,9 +33,11 @@ def swish(x: Tensor) -> Tensor:
ACT2FN = {"gelu": torch.nn.functional.gelu, "relu": torch.nn.functional.relu, "swish": swish}


def set_tensor_parallel_attribute(param):
if not hasattr(param, IS_TENSOR_PARALLEL):
setattr(param, IS_TENSOR_PARALLEL, True)
def set_tensor_parallel_attribute(param, size):
# if not hasattr(param, IS_TENSOR_PARALLEL):
setattr(param, IS_TENSOR_PARALLEL, True)
# if not hasattr(param, NUM_PARTITIONS):
setattr(param, NUM_PARTITIONS, size // np.prod(param.shape))


class CheckpointModule(nn.Module):
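
Editor's note on the attribute change: `set_tensor_parallel_attribute` now takes `size`, the element count of the full, unpartitioned parameter, so `NUM_PARTITIONS = size // np.prod(param.shape)` records how many ways the parameter was split. A small sketch with assumed shapes for a weight partitioned four ways:

# Hedged sketch: assumed shapes for a weight split across 4 tensor-parallel ranks.
import torch
import torch.nn as nn
from colossalai.constants import IS_TENSOR_PARALLEL, NUM_PARTITIONS
from colossalai.nn.layer._common_utils import set_tensor_parallel_attribute

full_out, full_in, depth = 1024, 512, 4
local_weight = nn.Parameter(torch.empty(full_out // depth, full_in))  # 256 x 512 shard

# `size` is the element count of the whole (unpartitioned) weight: 1024 * 512
set_tensor_parallel_attribute(local_weight, size=full_out * full_in)

assert getattr(local_weight, IS_TENSOR_PARALLEL) is True
assert getattr(local_weight, NUM_PARTITIONS) == 4   # 524288 // (256 * 512)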