diff --git a/colossalai/builder/pipeline.py b/colossalai/builder/pipeline.py
index caf5c8472cc6..4de5c96cbbea 100644
--- a/colossalai/builder/pipeline.py
+++ b/colossalai/builder/pipeline.py
@@ -198,7 +198,7 @@ def _partition_layers(self, method):
                 for st, ed in self.parts[stage]:
                     for idx, layer in enumerate(self.layers[st: ed]):
                         log_str += f'\t{idx + st:2d}: {layer}\n'
-            self._logger.info(log_str)
+            self._logger.info(log_str, ranks=[0])
 
         # Save the partition
         self._interval = self.parts[pipeline_rank]
diff --git a/colossalai/communication/__init__.py b/colossalai/communication/__init__.py
index 4241bff4becd..5da045326ce0 100644
--- a/colossalai/communication/__init__.py
+++ b/colossalai/communication/__init__.py
@@ -1,4 +1,4 @@
-from .collective import all_gather, reduce_scatter, scatter
+from .collective import all_gather, reduce_scatter, all_reduce
 from .p2p import (send_forward, send_forward_recv_forward, send_backward_recv_forward,
                   send_backward, send_backward_recv_backward, send_forward_recv_backward,
                   send_forward_backward_recv_forward_backward, recv_forward, recv_backward)
@@ -6,7 +6,7 @@ from .utils import send_tensor_meta, recv_tensor_meta
 
 
 __all__ = [
-    'all_gather', 'reduce_scatter', 'scatter',
+    'all_gather', 'reduce_scatter', 'all_reduce',
     'send_forward', 'send_forward_recv_forward', 'send_forward_backward_recv_forward_backward',
     'send_backward', 'send_backward_recv_backward', 'send_backward_recv_forward',
     'send_forward_recv_backward', 'recv_backward', 'recv_forward',
diff --git a/colossalai/communication/collective.py b/colossalai/communication/collective.py
index 6db799c99c73..5778028ea5a6 100644
--- a/colossalai/communication/collective.py
+++ b/colossalai/communication/collective.py
@@ -11,7 +11,7 @@
 
 
 def all_gather(tensor: Tensor, dim: int,
-               parallel_mode: ParallelMode) -> Tensor:
+               parallel_mode: ParallelMode, async_op=False) -> Tensor:
     """Gathers all tensors from the parallel group and concatenates them in a
     specific dimension.
 
@@ -26,18 +26,28 @@ def all_gather(tensor: Tensor, dim: int,
     """
     depth = gpc.get_world_size(parallel_mode)
     temp = tensor.clone()
-    shape = list(temp.shape)
-    shape[dim] *= depth
-    out = torch.empty(shape, dtype=temp.dtype, device=get_current_device())
-    out = list(torch.chunk(out, depth, dim=dim))
-    out = [val.contiguous() for val in out]
-    dist.all_gather(out, temp, group=gpc.get_group(parallel_mode))
-    out = torch.cat(out, dim=dim)
-    return out
+    # shape = list(temp.shape)
+    # shape[dim] *= depth
+    # out = torch.zeros(shape, dtype=temp.dtype, device=get_current_device())
+    # out = list(torch.chunk(out, depth, dim=dim))
+    # out = [val.contiguous() for val in out]
+    shape = [1] * len(tensor.shape)
+    shape[dim] = depth
+    out = tensor.repeat(shape)
+    out = list(map(lambda x: x.contiguous(), torch.chunk(out, depth, dim=dim)))
+    op = dist.all_gather(tensor_list=out,
+                         tensor=temp,
+                         group=gpc.get_group(parallel_mode),
+                         async_op=async_op)
+    # out = torch.cat(out, dim=dim)
+    if async_op:
+        return out, op
+    else:
+        return out
 
 
 def reduce_scatter(tensor: Tensor, dim: int,
-                   parallel_mode: ParallelMode) -> Tensor:
+                   parallel_mode: ParallelMode, async_op=False) -> Tensor:
     """Reduces all tensors then scatters it in a specific dimension to all
     members in the parallel group.
 
@@ -51,34 +61,52 @@ def reduce_scatter(tensor: Tensor, dim: int, :rtype: Tensor """ depth = gpc.get_world_size(parallel_mode) - temp = list(torch.chunk(tensor, depth, dim=dim)) - temp = [val.contiguous() for val in temp] - out = torch.empty(temp[0].shape, - dtype=temp[0].dtype, - device=get_current_device()) - dist.reduce_scatter(output=out, - input_list=temp, - group=gpc.get_group(parallel_mode)) - return out + # temp = list(torch.chunk(tensor, depth, dim=dim)) + # temp = [val.contiguous() for val in temp] + # out = torch.zeros(temp[0].shape, + # dtype=temp[0].dtype, + # device=get_current_device()) + temp = list(map(lambda x: x.contiguous(), torch.chunk(tensor, depth, dim=dim))) + out = temp[0].clone() + op = dist.reduce_scatter(output=out, + input_list=temp, + group=gpc.get_group(parallel_mode), + async_op=async_op) + if async_op: + return out, op + else: + return out -def scatter(tensor: Tensor, src: int, dim: int, - parallel_mode: ParallelMode) -> Tensor: - """Scatters in a specific dimension from source rank to all ranks in - the parallel group. +def all_reduce(tensor: Tensor, + parallel_mode: ParallelMode, + async_op=False) -> Tensor: + op = dist.all_reduce(tensor, + group=gpc.get_group(parallel_mode), + async_op=async_op) + if async_op: + return tensor, op + else: + return tensor + + +# def scatter(tensor: Tensor, src: int, dim: int, +# parallel_mode: ParallelMode) -> Tensor: +# """Scatters in a specific dimension from source rank to all ranks in +# the parallel group. - :param tensor: Tensor to be scattered - :param dim: The dimension scattering in - :param parallel_mode: Parallel group mode used in this communication - :type tensor: Tensor - :type dim: int - :type parallel_mode: ParallelMode - :return: The tensor generated by scatter - :rtype: Tensor - """ - depth = gpc.get_world_size(parallel_mode) - temp = tensor.clone() - dist.broadcast(temp, src=src, group=gpc.get_group(parallel_mode)) - rank = gpc.get_local_rank(parallel_mode) - out = torch.chunk(temp, depth, dim=dim)[rank].contiguous() - return out +# :param tensor: Tensor to be scattered +# :param dim: The dimension scattering in +# :param parallel_mode: Parallel group mode used in this communication +# :type tensor: Tensor +# :type dim: int +# :type parallel_mode: ParallelMode +# :return: The tensor generated by scatter +# :rtype: Tensor +# """ +# depth = gpc.get_world_size(parallel_mode) +# temp = tensor.clone() +# dist.broadcast(temp, src=src, group=gpc.get_group(parallel_mode)) +# rank = gpc.get_local_rank(parallel_mode) +# out = torch.chunk(temp, depth, dim=dim)[rank].contiguous() +# return out diff --git a/colossalai/constants.py b/colossalai/constants.py index 073dd2d2a3f1..874c53d7291f 100644 --- a/colossalai/constants.py +++ b/colossalai/constants.py @@ -25,7 +25,11 @@ # 3D parallel DEPTH_3D = 'DEPTH_3D' +INPUT_GROUP_3D = 'PARALLEL_3D_INPUT' +WEIGHT_GROUP_3D = 'PARALLEL_3D_WEIGHT' +OUTPUT_GROUP_3D = 'PARALLEL_3D_OUTPUT' # Tensor parallel attributes IS_TENSOR_PARALLEL = 'is_tensor_parallel' -TENSOR_PARALLEL_ATTRIBUTES = [IS_TENSOR_PARALLEL] +NUM_PARTITIONS = 'num_partitions' +TENSOR_PARALLEL_ATTRIBUTES = [IS_TENSOR_PARALLEL, NUM_PARTITIONS] diff --git a/colossalai/context/parallel_context.py b/colossalai/context/parallel_context.py index 1a4d8bd432bc..5ced84021447 100644 --- a/colossalai/context/parallel_context.py +++ b/colossalai/context/parallel_context.py @@ -277,10 +277,18 @@ def init_global_dist(self, addr=None, port=None): :type port: int, optional """ # get config - rank = self._dist_args.local_rank + 
local_rank = self._dist_args.local_rank + rank = self._dist_args.rank world_size = self._dist_args.world_size + if local_rank is None: + local_rank = os.getenv('LOCAL_RANK') + if rank is None: + rank = os.getenv('RANK') + if world_size is None: + world_size = os.getenv('WORLD_SIZE') # default env config, overwrite by exporting # them in your bash script + addr = os.getenv('MASTER_ADDR', 'localhost') if addr is None else addr port = os.getenv('MASTER_PORT', '8008') if port is None else port init_method = f'tcp://{addr}:{port}' @@ -293,7 +301,8 @@ def init_global_dist(self, addr=None, port=None): # None will give the default global process group for pytorch dist operations self._register_dist(rank, world_size, None, list(range(world_size)), ParallelMode.GLOBAL) - self._global_ranks[ParallelMode.GLOBAL] = rank + self.add_global_rank(ParallelMode.GLOBAL, rank) + # self._global_ranks[ParallelMode.GLOBAL] = rank def _register_dist(self, local_rank, world_size, process_group, ranks_in_group, mode): @@ -426,18 +435,15 @@ def set_seed(self): if torch.cuda.is_available(): # create random seed for different parallel modes # data parallel seed are kept the same - parallel_seed = seed + tp_rank = self._local_ranks.get(ParallelMode.TENSOR, 0) + pp_rank = self._local_ranks.get(ParallelMode.PIPELINE, 0) + parallel_seed = seed + tp_rank + pp_rank * 1024 add_seed(ParallelMode.DATA, parallel_seed) - # model parallel seeds are different across ranks - pipeline_offset = self._local_ranks.get(ParallelMode.PIPELINE, 0) - # add seed for data parallel and tensor parallel only if self.is_initialized(ParallelMode.TENSOR): - tp_rank = self.get_local_rank(ParallelMode.TENSOR) - # 100 is only to increase the diff in seeds between pipeline stages - tp_rank_with_offset = tp_rank + pipeline_offset * 1024 - tp_seed = seed + tp_rank_with_offset + dp_rank = self._local_ranks.get(ParallelMode.DATA, 0) + 1 + tp_seed = parallel_seed + dp_rank * 128 add_seed(ParallelMode.TENSOR, tp_seed) set_mode(ParallelMode.DATA) diff --git a/colossalai/context/process_group_initializer/initializer_3d.py b/colossalai/context/process_group_initializer/initializer_3d.py index 3912307679de..464049193f0f 100644 --- a/colossalai/context/process_group_initializer/initializer_3d.py +++ b/colossalai/context/process_group_initializer/initializer_3d.py @@ -5,7 +5,7 @@ import os import torch.distributed as dist -from colossalai.constants import DEPTH_3D +from colossalai.constants import DEPTH_3D, INPUT_GROUP_3D, WEIGHT_GROUP_3D, OUTPUT_GROUP_3D from colossalai.registry import DIST_GROUP_INITIALIZER from ..parallel_mode import ParallelMode @@ -18,7 +18,7 @@ def _check_depth_env_var(depth): if env_depth: assert int(env_depth) == depth, \ - 'SUMMA_DIM has been set in the current environment and ' \ + 'DEPTH_3D has been set in the current environment and ' \ 'does not match with the value passed to this initialized' else: os.environ[DEPTH_3D] = str(depth) @@ -43,6 +43,7 @@ def init_dist_group(self): process_group = None group_world_size = None mode = ParallelMode.PARALLEL_3D_INPUT + os.environ[INPUT_GROUP_3D] = INPUT_GROUP_3D for h in range(self.num_group): for i in range(self.depth): @@ -82,6 +83,7 @@ def init_dist_group(self): process_group = None group_world_size = None mode = ParallelMode.PARALLEL_3D_WEIGHT + os.environ[WEIGHT_GROUP_3D] = WEIGHT_GROUP_3D for h in range(self.num_group): for k in range(self.depth): @@ -121,6 +123,7 @@ def init_dist_group(self): process_group = None group_world_size = None mode = ParallelMode.PARALLEL_3D_OUTPUT + 
os.environ[OUTPUT_GROUP_3D] = OUTPUT_GROUP_3D for h in range(self.num_group): for i in range(self.depth): diff --git a/colossalai/initialize.py b/colossalai/initialize.py index 3c94c1cbfed3..351c67947097 100644 --- a/colossalai/initialize.py +++ b/colossalai/initialize.py @@ -42,19 +42,21 @@ def parse_args(): type=str, default=None, help='the master port for distributed training') - parser.add_argument('--world_size', type=int, help='world size for ') + parser.add_argument('--world_size', type=int, help='world size for distributed training') + parser.add_argument('--rank', type=int, help='rank for the default process group') parser.add_argument('--local_rank', type=int, - help='rank for the default process group') + help='local rank on the node') parser.add_argument('--backend', type=str, default='nccl', - help='backend for torch.distributed') + help='backend for distributed communication') return parser.parse_args() def init_dist(config: Union[str, dict] = None, local_rank: int = None, + rank: int = None, world_size: int = None, host: str = None, port: str = None, @@ -86,6 +88,8 @@ def init_dist(config: Union[str, dict] = None, config = args.config if local_rank is None: local_rank = args.local_rank + if rank is None: + rank = args.rank if world_size is None: world_size = args.world_size if host is None: @@ -99,12 +103,14 @@ def init_dist(config: Union[str, dict] = None, host=host, port=port, world_size=world_size, + rank=rank, local_rank=local_rank, backend=backend)) # set distributed settings dist_args = Config( dict(local_rank=args.local_rank, + rank=rank, world_size=args.world_size, backend=args.backend)) @@ -178,6 +184,7 @@ def seed_worker(worker_id): def initialize(config: Union[str, dict] = None, local_rank: int = None, + rank: int = None, world_size: int = None, host: str = None, port: str = None, @@ -209,6 +216,7 @@ def initialize(config: Union[str, dict] = None, # initialize distributed environment init_dist(config=config, local_rank=local_rank, + rank=rank, world_size=world_size, host=host, port=port, diff --git a/colossalai/nn/init.py b/colossalai/nn/init.py new file mode 100644 index 000000000000..057cc008d32b --- /dev/null +++ b/colossalai/nn/init.py @@ -0,0 +1,33 @@ +import math + +from torch import Tensor +from torch.nn import init as init + + +def init_weight_(tensor: Tensor, fan_in: int, fan_out: int = None, init_method: str = 'torch'): + if init_method == 'torch': + a = math.sqrt(5) + nonlinearity = 'leaky_relu' + std = init.calculate_gain(nonlinearity, a) / math.sqrt(fan_in) + bound = math.sqrt(3.0) * std + init.uniform_(tensor, -bound, bound) + elif init_method == 'jax': + std = math.sqrt(2.0 / float(fan_in + fan_out)) + a = math.sqrt(3.0) * std + init.uniform_(tensor, -a, a) + elif init_method == 'jax_embed': + std = math.sqrt(1.0 / fan_in) + init.trunc_normal_(tensor, std=std / .87962566103423978) + elif init_method == 'zero': + init.zeros_(tensor) + +def init_bias_(tensor: Tensor, fan_in: int, init_method: str = 'torch'): + if init_method == 'torch': + bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0 + init.uniform_(tensor, -bound, bound) + elif init_method == 'jax': + init.normal_(tensor, std=1e-6) + elif init_method == 'jax_embed': + init.trunc_normal_(tensor, std=.02) + elif init_method == 'zero': + init.zeros_(tensor) diff --git a/colossalai/nn/layer/_common_utils.py b/colossalai/nn/layer/_common_utils.py index 3bb45c365824..db0f362b270b 100644 --- a/colossalai/nn/layer/_common_utils.py +++ b/colossalai/nn/layer/_common_utils.py @@ -3,12 +3,12 @@ import math 
+import numpy as np +from colossalai.utils.common import print_rank_0 import torch -from torch import Tensor -from torch import nn +from colossalai.constants import IS_TENSOR_PARALLEL, NUM_PARTITIONS from colossalai.utils import checkpoint - -from colossalai.constants import IS_TENSOR_PARALLEL +from torch import Tensor, nn def divide(numerator, denominator): @@ -33,9 +33,11 @@ def swish(x: Tensor) -> Tensor: ACT2FN = {"gelu": torch.nn.functional.gelu, "relu": torch.nn.functional.relu, "swish": swish} -def set_tensor_parallel_attribute(param): - if not hasattr(param, IS_TENSOR_PARALLEL): - setattr(param, IS_TENSOR_PARALLEL, True) +def set_tensor_parallel_attribute(param, size): + # if not hasattr(param, IS_TENSOR_PARALLEL): + setattr(param, IS_TENSOR_PARALLEL, True) + # if not hasattr(param, NUM_PARTITIONS): + setattr(param, NUM_PARTITIONS, size // np.prod(param.shape)) class CheckpointModule(nn.Module): diff --git a/colossalai/nn/layer/parallel_3d/_operation.py b/colossalai/nn/layer/parallel_3d/_operation.py index cb790fb510fe..f8287f932ae9 100644 --- a/colossalai/nn/layer/parallel_3d/_operation.py +++ b/colossalai/nn/layer/parallel_3d/_operation.py @@ -1,21 +1,223 @@ #!/usr/bin/env python # -*- encoding: utf-8 -*- -from typing import Any, Tuple +from typing import Any, Optional, Tuple import torch import torch.distributed as dist -from colossalai.communication import all_gather, reduce_scatter, scatter +from colossalai.communication import all_gather, all_reduce, reduce_scatter from colossalai.context.parallel_mode import ParallelMode from colossalai.core import global_context as gpc -from colossalai.utils import empty_cache, get_current_device from torch import Tensor +from torch.cuda.amp import custom_bwd, custom_fwd + + +class linear_3d(torch.autograd.Function): + @staticmethod + @custom_fwd(cast_inputs=torch.float16) + def forward(ctx: Any, + input_: Tensor, + weight: Tensor, + bias: Optional[Tensor], + input_parallel_mode: ParallelMode, + weight_parallel_mode: ParallelMode, + output_parallel_mode: ParallelMode, + input_dim: int = 0, + weight_dim: int = -1, + output_dim: int = 0) -> Tensor: + assert input_.shape[-1] == weight.shape[0], \ + 'Invalid shapes: input = {}, weight = {}.'.format(input_.shape, weight.shape) + + ctx.use_bias = bias is not None + + input_ = all_gather(input_, input_dim, input_parallel_mode) + input_ = torch.cat(input_, dim=input_dim) + # weight = all_gather(weight, weight_dim, weight_parallel_mode) + ctx.save_for_backward(input_, weight) + + output = torch.matmul(input_, weight) + output = reduce_scatter(output, output_dim, output_parallel_mode) + + if bias is not None: + # ranks_in_group = gpc.get_ranks_in_group(output_parallel_mode) + # src_rank = ranks_in_group[gpc.get_local_rank(input_parallel_mode)] + # dist.broadcast(bias, + # src=src_rank, + # group=gpc.get_group(output_parallel_mode)) + # bias = all_gather(bias, -1, weight_parallel_mode) + output += bias + # ctx.src_rank = src_rank + + # ctx.save_for_backward(input_, weight) + # output = torch.matmul(input_, weight) + # dist.all_reduce(output, group=gpc.get_group(output_parallel_mode)) + # output += bias + + ctx.input_parallel_mode = input_parallel_mode + ctx.weight_parallel_mode = weight_parallel_mode + ctx.output_parallel_mode = output_parallel_mode + ctx.input_dim = input_dim + ctx.weight_dim = weight_dim + ctx.output_dim = output_dim + return output + + @staticmethod + @custom_bwd + def backward(ctx: Any, output_grad: Tensor) -> Tuple[Tensor, ...]: + input_, weight = ctx.saved_tensors + with 
torch.no_grad(): + # input_grad = torch.matmul(output_grad, weight.transpose(0, 1)) + # dist.all_reduce(input_grad, + # group=gpc.get_group(ctx.input_parallel_mode)) + # weight_grad = torch.matmul( + # input_.reshape(-1, input_.shape[-1]).transpose(0, 1), + # output_grad.reshape(-1, output_grad.shape[-1])) + # dist.all_reduce(weight_grad, + # group=gpc.get_group(ctx.weight_parallel_mode)) + + # bias_grad = torch.sum(output_grad, + # dim=tuple( + # range(len(output_grad.shape))[:-1])) + # bias_grad = reduce_scatter(bias_grad, -1, + # ctx.weight_parallel_mode) + # dist.reduce(bias_grad, + # dst=ctx.src_rank, + # group=gpc.get_group(ctx.output_parallel_mode)) + # if gpc.get_local_rank( + # ctx.output_parallel_mode) != gpc.get_local_rank( + # ctx.input_parallel_mode): + # bias_grad = None + + # input_ = all_gather(input_, ctx.input_dim, ctx.input_parallel_mode) + # weight = all_gather(weight, ctx.weight_dim, + # ctx.weight_parallel_mode) + + output_grad = all_gather(output_grad, ctx.output_dim, + ctx.output_parallel_mode) + output_grad = torch.cat(output_grad, dim=ctx.output_dim) + + input_grad = torch.matmul(output_grad, weight.transpose(0, 1)) + + input_grad, input_op = reduce_scatter(input_grad, ctx.input_dim, + ctx.input_parallel_mode, + async_op=True) + weight_grad = torch.matmul( + input_.reshape(-1, input_.shape[-1]).transpose(0, 1), + output_grad.reshape(-1, output_grad.shape[-1])) + + # weight_grad = torch.matmul( + # input_.reshape(-1, input_.shape[-1]).transpose(0, 1), + # output_grad.reshape(-1, output_grad.shape[-1])) + # weight_grad = reduce_scatter(weight_grad, ctx.weight_dim, + # ctx.weight_parallel_mode) + if ctx.use_bias: + bias_grad = torch.sum(output_grad, + dim=tuple( + range(len(output_grad.shape))[:-1])) + # bias_grad =all_reduce(bias_grad, ctx.output_parallel_mode) + # dist.all_reduce(bias_grad, + # group=gpc.get_group(ctx.weight_parallel_mode)) + weight_grad = torch.cat([weight_grad, torch.unsqueeze(bias_grad, dim=0)]) + + weight_grad, weight_op = all_reduce(weight_grad, ctx.weight_parallel_mode, async_op=True) + + input_op.wait() + weight_op.wait() + if ctx.use_bias: + bias_grad = weight_grad[-1] + weight_grad = weight_grad[:-1] + + return input_grad, weight_grad, bias_grad, None, None, None, None, None, None + + +class layer_norm_3d(torch.autograd.Function): + @staticmethod + @custom_fwd(cast_inputs=torch.float16) + def forward(ctx: Any, input_: Tensor, weight: Tensor, bias: Tensor, + normalized_shape: int, eps: float, + input_parallel_mode: ParallelMode, + weight_parallel_mode: ParallelMode, + output_parallel_mode: ParallelMode) -> Tensor: + # mean = torch.sum(input_, dim=-1) + # dist.all_reduce(mean, group=gpc.get_group(output_parallel_mode)) + # mean /= normalized_shape + # mu = input_ - mean + # var = torch.sum(torch.pow(mu, 2), dim=-1) + # dist.all_reduce(var, group=gpc.get_group(output_parallel_mode)) + # var /= normalized_shape + # std_dev = torch.sqrt(var + eps) + # ctx.save_for_backward(input_, mu, std_dev, weight) + + # output = weight * mu / std_dev + bias + + mean = all_reduce(torch.sum(input_, dim=-1, keepdim=True), + output_parallel_mode) / normalized_shape + mu = input_ - mean + var = all_reduce(torch.sum(mu**2, dim=-1, keepdim=True), + output_parallel_mode) / normalized_shape + sigma = torch.sqrt(var + eps) + + # ranks_in_group = gpc.get_ranks_in_group(input_parallel_mode) + # src_rank = ranks_in_group[gpc.get_local_rank(output_parallel_mode)] + # transforms = torch.stack([weight, bias]).contiguous() + # dist.broadcast(transforms, + # src=src_rank, + 
# group=gpc.get_group(input_parallel_mode)) + # transforms = all_gather(transforms, -1, weight_parallel_mode) + # weight, bias = transforms[0], transforms[1] + + ctx.save_for_backward(mu, sigma, weight) + + z = mu / sigma + output = weight * z + bias + + # ctx.src_rank = src_rank + ctx.normalized_shape = normalized_shape + ctx.input_parallel_mode = input_parallel_mode + ctx.weight_parallel_mode = weight_parallel_mode + ctx.output_parallel_mode = output_parallel_mode + + return output + + @staticmethod + @custom_bwd + def backward(ctx: Any, output_grad: Tensor) -> Tuple[Tensor, ...]: + mu, sigma, weight = ctx.saved_tensors + with torch.no_grad(): + bias_grad, weight_grad = output_grad, output_grad * mu / sigma + grads = torch.stack([bias_grad, weight_grad]).contiguous() + grads = torch.sum(grads, dim=tuple(range(len(grads.shape))[1:-1])) + grads = all_reduce(grads, ctx.weight_parallel_mode) + grads = all_reduce(grads, ctx.input_parallel_mode) + bias_grad, weight_grad = grads[0], grads[1] + + # grads = reduce_scatter(grads, -1, ctx.weight_parallel_mode) + # dist.reduce(grads, + # dst=ctx.src_rank, + # group=gpc.get_group(ctx.input_parallel_mode)) + # if gpc.get_local_rank( + # ctx.input_parallel_mode) == gpc.get_local_rank( + # ctx.output_parallel_mode): + # bias_grad, weight_grad = grads[0], grads[1] + # else: + # bias_grad, weight_grad = None, None + + dz = output_grad * weight + dvar = dz * mu * (-0.5) * sigma**(-3) + dvar = all_reduce(torch.sum(dvar, dim=-1, keepdim=True), ctx.output_parallel_mode) + dmean = dz * (-1 / sigma) + dvar * -2 * mu / ctx.normalized_shape + dmean = all_reduce(torch.sum(dmean, dim=-1, keepdim=True), ctx.output_parallel_mode) + + input_grad = dz / sigma + dvar * 2 * mu / ctx.normalized_shape + dmean / ctx.normalized_shape + + return input_grad, weight_grad, bias_grad, None, None, None, None, None class Matmul_AB_3D(torch.autograd.Function): """Matrix multiplication for :math:`C = AB` """ @staticmethod + @custom_fwd(cast_inputs=torch.float16) def forward(ctx: Any, A: Tensor, B: Tensor, @@ -29,7 +231,6 @@ def forward(ctx: Any, # A: [m/q^2, n, k/q] # B: [k/q, h/q^2] # C: [m/q^2, n, h/q] - empty_cache() ctx.save_for_backward(A, B) assert A.shape[-1] == B.shape[0], \ @@ -52,6 +253,7 @@ def forward(ctx: Any, return out @staticmethod + @custom_bwd def backward(ctx: Any, output_grad: Tensor) -> Tuple[Tensor, ...]: A, B = ctx.saved_tensors with torch.no_grad(): @@ -72,6 +274,7 @@ class Matmul_ABT_3D(torch.autograd.Function): """Matrix multiplication for :math:`C = AB^T` """ @staticmethod + @custom_fwd(cast_inputs=torch.float16) def forward(ctx: Any, A: Tensor, B: Tensor, @@ -85,7 +288,6 @@ def forward(ctx: Any, # A: [m/q^2, n, h/q] # B: [k/q, h/q^2] # C: [m/q^2, n, k/q] - empty_cache() ctx.save_for_backward(A, B) A_temp = all_gather(A, input_dim, input_parallel_mode) @@ -105,6 +307,7 @@ def forward(ctx: Any, return out @staticmethod + @custom_bwd def backward(ctx: Any, output_grad: Tensor) -> Tuple[Tensor, ...]: A, B = ctx.saved_tensors with torch.no_grad(): @@ -125,6 +328,7 @@ class Matmul_ATB_3D(torch.autograd.Function): """Matrix multiplication for :math:`C = A^TB` """ @staticmethod + @custom_fwd(cast_inputs=torch.float16) def forward(ctx: Any, A: Tensor, B: Tensor, @@ -138,7 +342,6 @@ def forward(ctx: Any, # A: [m/q^2, n, k/q] # B: [m/q^2, n, h/q] # C: [k/q, h/q^2] - empty_cache() ctx.save_for_backward(A, B) A_temp = all_gather(A, input_dim, input_parallel_mode) @@ -160,6 +363,7 @@ def forward(ctx: Any, return out @staticmethod + @custom_bwd def backward(ctx: Any, 
output_grad: Tensor) -> Tuple[Tensor, ...]: A, B = ctx.saved_tensors with torch.no_grad(): @@ -180,6 +384,7 @@ class Add_3D(torch.autograd.Function): """Matrix add bias: :math:`C = A + b` """ @staticmethod + @custom_fwd(cast_inputs=torch.float16) def forward(ctx: Any, input_: Tensor, bias: Tensor, depth: int, input_parallel_mode: ParallelMode, weight_parallel_mode: ParallelMode, @@ -206,6 +411,7 @@ def forward(ctx: Any, input_: Tensor, bias: Tensor, depth: int, return out @staticmethod + @custom_bwd def backward(ctx: Any, output_grad: Tensor) -> Tuple[Tensor, ...]: # output_grad: [m/q^2, n, h/q] with torch.no_grad(): @@ -217,8 +423,8 @@ def backward(ctx: Any, output_grad: Tensor) -> Tuple[Tensor, ...]: dst=ctx.src_rank, group=gpc.get_group(ctx.A_group_parallel_mode)) if gpc.get_local_rank( - ctx.A_group_parallel_mode) != gpc.get_local_rank( - ctx.C_group_parallel_mode): + ctx.A_group_parallel_mode) != gpc.get_local_rank( + ctx.C_group_parallel_mode): bias_grad = None return output_grad, bias_grad, None, None, None, None @@ -227,6 +433,7 @@ class Mul_3D(torch.autograd.Function): """Matrix multiplication for :math:`C = A * b` """ @staticmethod + @custom_fwd(cast_inputs=torch.float16) def forward(ctx: Any, input_: Tensor, bias: Tensor, depth: int, input_parallel_mode: ParallelMode, weight_parallel_mode: ParallelMode, @@ -243,7 +450,7 @@ def forward(ctx: Any, input_: Tensor, bias: Tensor, depth: int, # [h/q] bias_temp = all_gather(bias_temp, -1, weight_parallel_mode) - empty_cache() + # empty_cache() ctx.save_for_backward(input_, bias_temp) out = torch.mul(input_, bias_temp) @@ -257,6 +464,7 @@ def forward(ctx: Any, input_: Tensor, bias: Tensor, depth: int, return out @staticmethod + @custom_bwd def backward(ctx: Any, output_grad: Tensor) -> Tuple[Tensor, ...]: # output_grad: [m/q^2, n, h/q] with torch.no_grad(): @@ -272,8 +480,8 @@ def backward(ctx: Any, output_grad: Tensor) -> Tuple[Tensor, ...]: dst=ctx.src_rank, group=gpc.get_group(ctx.A_group_parallel_mode)) if gpc.get_local_rank( - ctx.A_group_parallel_mode) != gpc.get_local_rank( - ctx.C_group_parallel_mode): + ctx.A_group_parallel_mode) != gpc.get_local_rank( + ctx.C_group_parallel_mode): bias_grad = None return input_grad, bias_grad, None, None, None, None @@ -282,6 +490,7 @@ class Sum_3D(torch.autograd.Function): """Compute the sum of input tensors """ @staticmethod + @custom_fwd(cast_inputs=torch.float16) def forward(ctx: Any, input_: Tensor, dim: int, @@ -299,6 +508,7 @@ def forward(ctx: Any, return out @staticmethod + @custom_bwd def backward(ctx: Any, output_grad: Tensor) -> Tuple[Tensor, ...]: with torch.no_grad(): output_grad = output_grad.contiguous() @@ -315,35 +525,39 @@ class Reduce_3D(torch.autograd.Function): """Reduce input tensors """ @staticmethod + @custom_fwd(cast_inputs=torch.float16) def forward(ctx: Any, input_: Tensor, depth: int, parallel_mode: ParallelMode) -> Tensor: dist.all_reduce(input_, group=gpc.get_group(parallel_mode)) return input_.clone() @staticmethod + @custom_bwd def backward(ctx: Any, output_grad: Tensor) -> Tuple[Tensor, ...]: return output_grad, None, None -class Slice_3D(torch.autograd.Function): - """Slice input tensor - """ - @staticmethod - def forward(ctx: Any, input_: Tensor, dim: int, depth: int, - parallel_mode: ParallelMode) -> Tensor: - rank = gpc.get_local_rank(parallel_mode) - out = torch.chunk(input_, depth, dim=dim)[rank].contiguous() - - ctx.depth = depth - ctx.parallel_mode = parallel_mode - ctx.dim = dim - ctx.input_shape = input_.shape - - return out - - @staticmethod - def 
backward(ctx: Any, output_grad: Tensor) -> Tuple[Tensor, ...]: - with torch.no_grad(): - input_grad = all_gather(output_grad, ctx.dim, ctx.parallel_mode) - input_grad.reshape(ctx.input_shape) - return input_grad, None, None, None +# class Slice_3D(torch.autograd.Function): +# """Slice input tensor +# """ +# @staticmethod +# @custom_fwd(cast_inputs=torch.float16) +# def forward(ctx: Any, input_: Tensor, dim: int, depth: int, +# parallel_mode: ParallelMode) -> Tensor: +# rank = gpc.get_local_rank(parallel_mode) +# out = torch.chunk(input_, depth, dim=dim)[rank].contiguous() + +# ctx.depth = depth +# ctx.parallel_mode = parallel_mode +# ctx.dim = dim +# ctx.input_shape = input_.shape + +# return out + +# @staticmethod +# @custom_bwd +# def backward(ctx: Any, output_grad: Tensor) -> Tuple[Tensor, ...]: +# with torch.no_grad(): +# input_grad = all_gather(output_grad, ctx.dim, ctx.parallel_mode) +# input_grad.reshape(ctx.input_shape) +# return input_grad, None, None, None diff --git a/colossalai/nn/layer/parallel_3d/_utils.py b/colossalai/nn/layer/parallel_3d/_utils.py index 3c92360174cb..ca3b405eaf64 100644 --- a/colossalai/nn/layer/parallel_3d/_utils.py +++ b/colossalai/nn/layer/parallel_3d/_utils.py @@ -3,7 +3,8 @@ import os -from colossalai.constants import DEPTH_3D +from colossalai.constants import (DEPTH_3D, INPUT_GROUP_3D, OUTPUT_GROUP_3D, + WEIGHT_GROUP_3D) from colossalai.context.parallel_mode import ParallelMode from colossalai.core import global_context as gpc from torch import Tensor @@ -23,6 +24,10 @@ def get_depth_from_env() -> int: ) +def get_parallel_mode_from_env(group): + return getattr(ParallelMode, os.environ[group]) + + def get_last_group(a, b): mapping = { ParallelMode.PARALLEL_3D_INPUT: 'A', @@ -41,6 +46,11 @@ def get_last_group(a, b): return ParallelMode.PARALLEL_3D_OUTPUT +def swap_in_out_group(): + os.environ[INPUT_GROUP_3D], os.environ[OUTPUT_GROUP_3D] = \ + os.environ[OUTPUT_GROUP_3D], os.environ[INPUT_GROUP_3D] + + def dbg_check_shape(tensor: Tensor, shape: tuple): rank = gpc.get_global_rank() if rank == 0: diff --git a/colossalai/nn/layer/parallel_3d/_vit.py b/colossalai/nn/layer/parallel_3d/_vit.py index ffe7a146af71..09d9370433f1 100644 --- a/colossalai/nn/layer/parallel_3d/_vit.py +++ b/colossalai/nn/layer/parallel_3d/_vit.py @@ -1,17 +1,21 @@ import math -from typing import Tuple +import os +from typing import Tuple, Optional import torch import torch.distributed as dist +from colossalai.constants import (INPUT_GROUP_3D, OUTPUT_GROUP_3D, + WEIGHT_GROUP_3D) from colossalai.context import ParallelMode, seed from colossalai.core import global_context as gpc from colossalai.registry import LAYERS +from colossalai.nn.init import init_bias_, init_weight_ from colossalai.utils import checkpoint, get_current_device from torch import Tensor, dtype, nn from .._common_utils import ACT2FN, divide, set_tensor_parallel_attribute from ..vanilla_vision_transformer.layers import to_2tuple -from ._utils import get_depth_from_env +from ._utils import get_depth_from_env, get_parallel_mode_from_env, get_last_group from .layers import Linear3D @@ -38,28 +42,35 @@ def __init__(self, in_chans: int, embed_size: int, drop_prob: float, - flatten: bool = True): + flatten: bool = True, + init_method: str ='torch'): super().__init__() self.depth = get_depth_from_env() - self.input_parallel_mode = ParallelMode.PARALLEL_3D_INPUT - self.weight_parallel_mode = ParallelMode.PARALLEL_3D_WEIGHT - self.output_parallel_mode = ParallelMode.PARALLEL_3D_OUTPUT + self.input_parallel_mode = 
get_parallel_mode_from_env(INPUT_GROUP_3D) + self.weight_parallel_mode = get_parallel_mode_from_env(WEIGHT_GROUP_3D) + self.output_parallel_mode = get_last_group(self.input_parallel_mode, + self.weight_parallel_mode) img_size = to_2tuple(img_size) patch_size = to_2tuple(patch_size) self.img_size = img_size self.patch_size = patch_size self.grid_size = (img_size[0] // patch_size[0], img_size[1] // patch_size[1]) + self.in_chans = in_chans self.embed_size = embed_size self.embed_size_per_partition = divide(self.embed_size, self.depth) self.num_patches = self.grid_size[0] * self.grid_size[1] self.flatten = flatten + self.init_weight = 'torch' + self.init_bias = 'torch' + if init_method == 'jax': + self.init_weight = 'jax_embed' + self.init_bias = 'zero' - with seed(ParallelMode.TENSOR): - self.proj = nn.Conv2d(in_chans, - self.embed_size_per_partition, - kernel_size=patch_size, - stride=patch_size) + self.proj = nn.Conv2d(self.in_chans, + self.embed_size_per_partition, + kernel_size=patch_size, + stride=patch_size) self.cls_token = nn.Parameter( torch.zeros(1, 1, self.embed_size_per_partition)) @@ -68,23 +79,26 @@ def __init__(self, self.embed_size_per_partition)) self.pos_drop = nn.Dropout(drop_prob) - self._sync_parameters() - self.proj.weight.register_hook(self._sync_grad_hook) - self.proj.bias.register_hook(self._sync_grad_hook) - self.cls_token.register_hook(self._sync_grad_hook) - self.pos_embed.register_hook(self._sync_grad_hook) - self._set_tensor_parallel_attribute() - - def _set_tensor_parallel_attribute(self): - set_tensor_parallel_attribute(self.proj.weight) - set_tensor_parallel_attribute(self.proj.bias) - set_tensor_parallel_attribute(self.cls_token) - set_tensor_parallel_attribute(self.pos_embed) - - def groups_for_next_layer(self) -> Tuple[ParallelMode, ParallelMode]: - return self.input_parallel_mode, self.weight_parallel_mode - - def _sync_parameters(self): + self.reset_parameters(self.init_weight, self.init_bias) + self._set_tensor_parallel_attributes() + + def _set_tensor_parallel_attributes(self): + set_tensor_parallel_attribute(self.proj.weight, self.in_chans * self.embed_size * self.num_patches) + set_tensor_parallel_attribute(self.proj.bias, self.embed_size) + set_tensor_parallel_attribute(self.cls_token, 1 * 1 * self.embed_size) + set_tensor_parallel_attribute(self.pos_embed, 1 * (self.num_patches + 1) * self.embed_size) + + def reset_parameters(self, init_weight, init_bias): + fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.proj.weight) + # std = math.sqrt(1.0 / fan_in) + # nn.init.trunc_normal_(self.proj.weight, std=std / .87962566103423978) + # nn.init.zeros_(self.proj.bias) + if init_weight != 'torch': + init_weight_(self.proj.weight, fan_in, init_method=init_weight) + init_bias_(self.pos_embed, fan_in, init_method=init_weight) + if init_bias != 'torch': + init_bias_(self.proj.bias, fan_in, init_method=init_bias) + self.to(get_current_device()) weight_src_rank = gpc.get_ranks_in_group(self.weight_parallel_mode)[0] dist.broadcast(self.proj.weight, @@ -100,10 +114,11 @@ def _sync_parameters(self): dist.broadcast(self.proj.bias, src=input_src_rank, group=gpc.get_group(self.input_parallel_mode)) - set_tensor_parallel_attribute(self.proj.weight) - set_tensor_parallel_attribute(self.proj.bias) - set_tensor_parallel_attribute(self.cls_token) - set_tensor_parallel_attribute(self.pos_embed) + + self.proj.weight.register_hook(self._sync_grad_hook) + self.proj.bias.register_hook(self._sync_grad_hook) + self.cls_token.register_hook(self._sync_grad_hook) + 
self.pos_embed.register_hook(self._sync_grad_hook) def _sync_grad_hook(self, grad) -> None: dist.all_reduce(grad, group=gpc.get_group(self.input_parallel_mode)) @@ -111,6 +126,12 @@ def _sync_grad_hook(self, grad) -> None: return grad def forward(self, x: Tensor) -> Tensor: + # split a partition from inputs + x = torch.chunk(x, self.depth, dim=0)[gpc.get_local_rank( + self.weight_parallel_mode)].contiguous() + x = torch.chunk(x, self.depth, dim=0)[gpc.get_local_rank( + self.input_parallel_mode)].contiguous() + B, C, H, W = x.shape assert H == self.img_size[0] and W == self.img_size[1], \ f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})." @@ -118,12 +139,6 @@ def forward(self, x: Tensor) -> Tensor: if self.flatten: x = x.flatten(2).transpose(1, 2) # BCHW -> BNC - # split a partition from embedded states - x = torch.chunk(x, self.depth, dim=0)[gpc.get_local_rank( - self.weight_parallel_mode)].contiguous() - x = torch.chunk(x, self.depth, dim=0)[gpc.get_local_rank( - self.input_parallel_mode)].contiguous() - # add cls token & pos embedding # [b/q^2,s,h/q] --> [b/q^2, 1+s, h/q] cls_token = self.cls_token.expand(x.shape[0], -1, -1) @@ -165,36 +180,47 @@ def __init__(self, hidden_dropout_prob: float, dtype: dtype = None, bias: bool = True, - checkpoint: bool = False): + checkpoint: bool = False, + init_method: str ='torch'): super().__init__() self.depth = get_depth_from_env() - self.input_parallel_mode = ParallelMode.PARALLEL_3D_INPUT - self.weight_parallel_mode = ParallelMode.PARALLEL_3D_WEIGHT - self.output_parallel_mode = ParallelMode.PARALLEL_3D_OUTPUT + # self.input_parallel_mode = get_parallel_mode_from_env(INPUT_GROUP_3D) + # self.weight_parallel_mode = get_parallel_mode_from_env(WEIGHT_GROUP_3D) + # self.output_parallel_mode = get_last_group(self.input_parallel_mode, + # self.weight_parallel_mode) self.hidden_size = hidden_size self.num_attention_heads = divide(num_attention_heads, self.depth) self.attention_head_size = divide(hidden_size, num_attention_heads) self.all_head_size = self.num_attention_heads * self.attention_head_size self.checkpoint = checkpoint + self.init_weight = 'torch' + self.init_bias = 'torch' + if init_method == 'jax': + self.init_weight = 'jax' + self.init_bias = 'zero' self.query_key_value = Linear3D(self.hidden_size, 3 * self.hidden_size, - self.input_parallel_mode, - self.weight_parallel_mode, + # self.input_parallel_mode, + # self.weight_parallel_mode, dtype=dtype, - bias=bias) + bias=bias, + init_weight=self.init_weight, + init_bias=self.init_bias) self.attention_dropout = nn.Dropout(attention_probs_dropout_prob) self.dense = Linear3D(self.hidden_size, self.hidden_size, - self.output_parallel_mode, - self.weight_parallel_mode, + # self.output_parallel_mode, + # self.weight_parallel_mode, dtype=dtype, - bias=bias) + bias=bias, + init_weight=self.init_weight, + init_bias=self.init_bias) self.dropout = nn.Dropout(hidden_dropout_prob) self.softmax = nn.Softmax(dim=-1) - def groups_for_next_layer(self) -> Tuple[ParallelMode, ParallelMode]: - return self.input_parallel_mode, self.weight_parallel_mode + # def groups_for_next_layer(self) -> Tuple[ParallelMode, ParallelMode]: + # return self.input_parallel_mode, self.weight_parallel_mode def _forward(self, hidden_states: Tensor) -> Tensor: query_key_value = self.query_key_value(hidden_states) @@ -266,33 +292,41 @@ def __init__(self, hidden_act: str = 'gelu', dtype: dtype = None, bias: bool = True, - checkpoint: bool = False): + checkpoint: bool = False, + init_method: str = 
'torch'): super().__init__() - self.depth = get_depth_from_env() - self.input_parallel_mode = ParallelMode.PARALLEL_3D_INPUT - self.weight_parallel_mode = ParallelMode.PARALLEL_3D_WEIGHT - self.output_parallel_mode = ParallelMode.PARALLEL_3D_OUTPUT + # self.depth = get_depth_from_env() + # self.input_parallel_mode = get_parallel_mode_from_env(INPUT_GROUP_3D) + # self.weight_parallel_mode = get_parallel_mode_from_env(WEIGHT_GROUP_3D) + # self.output_parallel_mode = get_last_group(self.input_parallel_mode, + # self.weight_parallel_mode) self.hidden_size = hidden_size self.mlp_ratio = mlp_ratio self.checkpoint = checkpoint + self.init_weight = init_method + self.init_bias = init_method self.dense_1 = Linear3D(self.hidden_size, self.mlp_ratio * self.hidden_size, - self.input_parallel_mode, - self.weight_parallel_mode, + # self.input_parallel_mode, + # self.weight_parallel_mode, dtype=dtype, - bias=bias) + bias=bias, + init_weight=self.init_weight, + init_bias=self.init_bias) self.activation_func = ACT2FN[hidden_act] self.dense_2 = Linear3D(self.mlp_ratio * self.hidden_size, self.hidden_size, - self.output_parallel_mode, - self.weight_parallel_mode, + # self.output_parallel_mode, + # self.weight_parallel_mode, dtype=dtype, - bias=bias) + bias=bias, + init_weight=self.init_weight, + init_bias=self.init_bias) self.dropout = nn.Dropout(hidden_dropout_prob) - def groups_for_next_layer(self) -> Tuple[ParallelMode, ParallelMode]: - return self.input_parallel_mode, self.weight_parallel_mode + # def groups_for_next_layer(self) -> Tuple[ParallelMode, ParallelMode]: + # return self.input_parallel_mode, self.weight_parallel_mode def _forward(self, hidden_states: Tensor) -> Tensor: intermediate_output = self.dense_1(hidden_states) @@ -335,33 +369,41 @@ def __init__(self, in_features: int, num_classes: int, dtype: dtype = None, - bias: bool = True): + bias: bool = True, + init_method: str = 'torch'): super().__init__() - self.depth = get_depth_from_env() - self.input_parallel_mode = ParallelMode.PARALLEL_3D_INPUT - self.weight_parallel_mode = ParallelMode.PARALLEL_3D_WEIGHT - self.output_parallel_mode = ParallelMode.PARALLEL_3D_OUTPUT + # self.depth = get_depth_from_env() + # self.input_parallel_mode = get_parallel_mode_from_env(INPUT_GROUP_3D) + # self.weight_parallel_mode = get_parallel_mode_from_env(WEIGHT_GROUP_3D) + # self.output_parallel_mode = get_last_group(self.input_parallel_mode, + # self.weight_parallel_mode) self.in_features = in_features self.num_classes = num_classes - out_features = math.ceil(self.num_classes / - (self.depth**2)) * (self.depth**2) - self.num_classes_per_partition = divide(self.num_classes, self.depth) + # out_features = math.ceil(self.num_classes / + # (self.depth**2)) * (self.depth**2) + # self.num_classes_per_partition = divide(self.num_classes, self.depth) + self.init_weight = 'torch' + self.init_bias = 'torch' + if init_method == 'jax': + self.init_weight = 'zero' + self.init_bias = 'zero' + self.linear = Linear3D(self.in_features, - out_features, - self.input_parallel_mode, - self.weight_parallel_mode, + self.num_classes, + # self.input_parallel_mode, + # self.weight_parallel_mode, dtype=dtype, - bias=bias) - - def groups_for_next_layer(self) -> Tuple[ParallelMode, ParallelMode]: - return self.linear.groups_for_next_layer() + bias=bias, + init_weight=self.init_weight, + init_bias=self.init_bias) def forward(self, x: Tensor) -> Tensor: # [b/q^2, s, h/q] --> [b/q^2, h/q] x = x[:, 0] # [b/q^2, h/q] --> [b/q^2, c/q] x = self.linear(x) - return x[:, 
:self.num_classes_per_partition] + # return x[:, :self.num_classes_per_partition] + return x def extra_repr(self): return 'in_features={}, num_classes={}'.format(self.in_features, diff --git a/colossalai/nn/layer/parallel_3d/layers.py b/colossalai/nn/layer/parallel_3d/layers.py index c6d63100872c..775fc207a2a1 100644 --- a/colossalai/nn/layer/parallel_3d/layers.py +++ b/colossalai/nn/layer/parallel_3d/layers.py @@ -2,19 +2,28 @@ # -*- encoding: utf-8 -*- import math +import os from typing import Tuple import torch +import torch.distributed as dist import torch.nn as nn +from colossalai.constants import (INPUT_GROUP_3D, OUTPUT_GROUP_3D, + WEIGHT_GROUP_3D) from colossalai.context import ParallelMode, seed +from colossalai.core import global_context as gpc +from colossalai.nn.init import init_bias_, init_weight_ from colossalai.registry import LAYERS from colossalai.utils import get_current_device from torch import Tensor, dtype from torch.nn import Parameter +from torch.nn import init as init from .._common_utils import divide, set_tensor_parallel_attribute -from ._operation import Add_3D, Matmul_AB_3D, Mul_3D, Sum_3D -from ._utils import get_depth_from_env, get_last_group +from ._operation import (Add_3D, Matmul_AB_3D, Mul_3D, Sum_3D, layer_norm_3d, + linear_3d) +from ._utils import (get_depth_from_env, get_last_group, + get_parallel_mode_from_env, swap_in_out_group) @LAYERS.register_module @@ -22,20 +31,19 @@ class LayerNorm3D(nn.Module): def __init__( self, normalized_shape: int, - input_parallel_mode: ParallelMode, - weight_parallel_mode: ParallelMode, + # input_parallel_mode: ParallelMode, + # weight_parallel_mode: ParallelMode, eps: float = 1e-12, dtype: dtype = None, ): super().__init__() - self.input_parallel_mode = input_parallel_mode - self.weight_parallel_mode = weight_parallel_mode + self.input_parallel_mode = get_parallel_mode_from_env(INPUT_GROUP_3D) + self.weight_parallel_mode = get_parallel_mode_from_env(WEIGHT_GROUP_3D) self.output_parallel_mode = get_last_group(self.input_parallel_mode, self.weight_parallel_mode) self.depth = get_depth_from_env() self.normalized_shape = normalized_shape - self.normalized_shape_per_partition = divide(normalized_shape, - self.depth**2) + self.normalized_shape_per_partition = divide(normalized_shape, self.depth) self.weight = Parameter( torch.ones(self.normalized_shape_per_partition, @@ -49,37 +57,40 @@ def __init__( self._set_tensor_parallel_attributes() def _set_tensor_parallel_attributes(self): - set_tensor_parallel_attribute(self.weight) - set_tensor_parallel_attribute(self.bias) - - def groups_for_next_layer(self) -> Tuple[ParallelMode, ParallelMode]: - return self.input_parallel_mode, self.weight_parallel_mode + set_tensor_parallel_attribute(self.weight, self.normalized_shape) + set_tensor_parallel_attribute(self.bias, self.normalized_shape) def reset_parameters(self): - nn.init.zeros_(self.bias) - nn.init.ones_(self.weight) + init.zeros_(self.bias) + init.ones_(self.weight) def forward(self, input_: Tensor) -> Tensor: - '''x = weight * (x - mean) / sqrt(var + eps) + bias''' - # input: [m/q^2, n, h/q] - # [m/q^2, n, 1] - mean = Sum_3D.apply(input_, -1, self.depth, self.output_parallel_mode, - True) / self.normalized_shape - # [m/q^2, n, 1] - var = (input_ - mean).pow(2) - var = Sum_3D.apply(var, -1, self.depth, self.output_parallel_mode, - True) / self.normalized_shape - - output = (input_ - mean) / torch.sqrt(var + self.variance_epsilon) - output = Mul_3D.apply(output, self.weight, self.depth, - self.input_parallel_mode, - 
self.weight_parallel_mode, - self.output_parallel_mode) - output = Add_3D.apply(output, self.bias, self.depth, - self.input_parallel_mode, - self.weight_parallel_mode, - self.output_parallel_mode) - return output + # '''x = weight * (x - mean) / sqrt(var + eps) + bias''' + # # input: [m/q^2, n, h/q] + # # [m/q^2, n, 1] + # mean = Sum_3D.apply(input_, -1, self.depth, self.output_parallel_mode, + # True) / self.normalized_shape + # # [m/q^2, n, 1] + # var = (input_ - mean).pow(2) + # var = Sum_3D.apply(var, -1, self.depth, self.output_parallel_mode, + # True) / self.normalized_shape + + # output = (input_ - mean) / torch.sqrt(var + self.variance_epsilon) + # output = Mul_3D.apply(output, self.weight, self.depth, + # self.input_parallel_mode, + # self.weight_parallel_mode, + # self.output_parallel_mode) + # output = Add_3D.apply(output, self.bias, self.depth, + # self.input_parallel_mode, + # self.weight_parallel_mode, + # self.output_parallel_mode) + # return output + return layer_norm_3d.apply(input_, self.weight, self.bias, + self.normalized_shape, + self.variance_epsilon, + self.input_parallel_mode, + self.weight_parallel_mode, + self.output_parallel_mode) def extra_repr(self): return '{}, eps={}'.format(self.normalized_shape, @@ -88,33 +99,36 @@ def extra_repr(self): @LAYERS.register_module class Linear3D(nn.Module): - def __init__(self, - in_features: int, - out_features: int, - input_parallel_mode: ParallelMode, - weight_parallel_mode: ParallelMode, - bias: bool = True, - dtype: dtype = None): + def __init__( + self, + in_features: int, + out_features: int, + # input_parallel_mode: ParallelMode, + # weight_parallel_mode: ParallelMode, + bias: bool = True, + dtype: dtype = None, + init_weight: str ='torch', + init_bias: str ='torch'): super().__init__() self.in_features = in_features self.out_features = out_features - self.input_parallel_mode = input_parallel_mode - self.weight_parallel_mode = weight_parallel_mode + self.input_parallel_mode = get_parallel_mode_from_env(INPUT_GROUP_3D) + self.weight_parallel_mode = get_parallel_mode_from_env(WEIGHT_GROUP_3D) self.output_parallel_mode = get_last_group(self.input_parallel_mode, self.weight_parallel_mode) - self.with_bias = bias + # self.with_bias = bias self.depth = get_depth_from_env() self.in_features_per_partition = divide(in_features, self.depth) - self.out_features_per_partition = divide(out_features, self.depth**2) + self.out_features_per_partition = divide(out_features, self.depth) - # [k/q, h/q^2] + # [k/q, h/q] self.weight = Parameter( torch.empty(self.in_features_per_partition, self.out_features_per_partition, device=get_current_device(), dtype=dtype)) - # [h/q^2] + # [h/q] if bias: self.bias = Parameter( torch.zeros(self.out_features_per_partition, @@ -123,49 +137,54 @@ def __init__(self, else: self.register_parameter('bias', None) - self.reset_parameters() + self.reset_parameters(init_weight, init_bias) self._set_tensor_parallel_attributes() + swap_in_out_group() def _set_tensor_parallel_attributes(self): - set_tensor_parallel_attribute(self.weight) + set_tensor_parallel_attribute(self.weight, self.in_features * self.out_features) if self.bias is not None: - set_tensor_parallel_attribute(self.bias) + set_tensor_parallel_attribute(self.bias, self.out_features) - def groups_for_next_layer(self) -> Tuple[ParallelMode, ParallelMode]: - return self.output_parallel_mode, self.weight_parallel_mode - - def reset_parameters(self): + def reset_parameters(self, init_weight, init_bias) -> None: # setting - fan_in = self.in_features - a = 
math.sqrt(5) - nonlinearity = 'leaky_relu' - + fan_in, fan_out = self.in_features, self.out_features + weight_src_rank = gpc.get_ranks_in_group(self.weight_parallel_mode)[0] + output_src_rank = gpc.get_ranks_in_group(self.output_parallel_mode)[0] + # init weight - std = nn.init.calculate_gain(nonlinearity, a) / math.sqrt(fan_in) - bound = math.sqrt(3.0) * std - with seed(ParallelMode.TENSOR): - nn.init.uniform_(self.weight, -bound, bound) - + init_weight_(self.weight, fan_in, fan_out, init_method=init_weight) + dist.broadcast(self.weight, + src=weight_src_rank, + group=gpc.get_group(self.weight_parallel_mode)) # init bias - if self.with_bias: - bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0 - with seed(ParallelMode.TENSOR): - nn.init.uniform_(self.bias, -bound, bound) + if self.bias is not None: + init_bias_(self.bias, fan_in, init_method=init_bias) + dist.broadcast(self.bias, + src=weight_src_rank, + group=gpc.get_group(self.weight_parallel_mode)) + dist.broadcast(self.bias, + src=output_src_rank, + group=gpc.get_group(self.output_parallel_mode)) def forward(self, input_: Tensor) -> Tensor: - # input: [m/q^2, n, k/q] - # output: [m/q^2, n, h/q] - output = Matmul_AB_3D.apply(input_, self.weight, self.depth, - self.input_parallel_mode, - self.weight_parallel_mode, - self.output_parallel_mode) - - if self.with_bias: - output = Add_3D.apply(output, self.bias, self.depth, - self.output_parallel_mode, - self.weight_parallel_mode, - self.input_parallel_mode) - return output + # # input: [m/q^2, n, k/q] + # # output: [m/q^2, n, h/q] + # output = Matmul_AB_3D.apply(input_, self.weight, self.depth, + # self.input_parallel_mode, + # self.weight_parallel_mode, + # self.output_parallel_mode) + + # if self.bias is not None: + # output = Add_3D.apply(output, self.bias, self.depth, + # self.output_parallel_mode, + # self.weight_parallel_mode, + # self.input_parallel_mode) + # return output + return linear_3d.apply(input_, self.weight, self.bias, + self.input_parallel_mode, + self.weight_parallel_mode, + self.output_parallel_mode) def extra_repr(self): return 'in_features={}, out_features={}, bias={}'.format( diff --git a/colossalai/nn/loss/cross_entropy_3d.py b/colossalai/nn/loss/cross_entropy_3d.py index b1ef7731bc39..97409322d1f5 100644 --- a/colossalai/nn/loss/cross_entropy_3d.py +++ b/colossalai/nn/loss/cross_entropy_3d.py @@ -1,32 +1,20 @@ #!/usr/bin/env python # -*- encoding: utf-8 -*- +import os + import torch import torch.distributed as dist -from torch.nn.modules.loss import _Loss - -from colossalai.communication import all_gather +from colossalai.constants import (INPUT_GROUP_3D, OUTPUT_GROUP_3D, + WEIGHT_GROUP_3D) from colossalai.core import global_context as gpc from colossalai.nn.layer.parallel_3d._operation import Reduce_3D -from colossalai.nn.layer.parallel_3d._utils import get_last_group, get_depth_from_env +from colossalai.nn.layer.parallel_3d._utils import (get_depth_from_env, + get_last_group, + get_parallel_mode_from_env) from colossalai.registry import LOSSES from colossalai.utils import get_current_device - - -def accuracy_3d(output, target, input_parallel_mode, weight_parallel_mode): - depth = get_depth_from_env() - output_parallel_mode = get_last_group(input_parallel_mode, - weight_parallel_mode) - j = gpc.get_local_rank(input_parallel_mode) - i = gpc.get_local_rank(weight_parallel_mode) - target = torch.chunk(target, depth, dim=0)[i] - target = torch.chunk(target, depth, dim=0)[j] - output = all_gather(output, -1, output_parallel_mode) - prediction = torch.argmax(output, 
dim=-1) - correct = torch.sum(prediction == target) - dist.all_reduce(correct, group=gpc.get_group(input_parallel_mode)) - dist.all_reduce(correct, group=gpc.get_group(weight_parallel_mode)) - return correct.item() +from torch.nn.modules.loss import _Loss class _ParallelCrossEntropyLossFunction_3D(torch.autograd.Function): @@ -112,16 +100,18 @@ class CrossEntropyLoss3D(_Loss): :param reduction: whether to average the loss, defaults to True :type reduction: bool, optional """ - def __init__(self, - input_parallel_mode, - weight_parallel_mode, - reduction=True): + def __init__( + self, + # input_parallel_mode, + # weight_parallel_mode, + reduction=True, + label_smoothing=0.0): super().__init__() self.depth = get_depth_from_env() - self.input_parallel_mode = input_parallel_mode - self.weight_parallel_mode = weight_parallel_mode - self.output_parallel_mode = get_last_group(input_parallel_mode, - weight_parallel_mode) + self.input_parallel_mode = get_parallel_mode_from_env(INPUT_GROUP_3D) + self.weight_parallel_mode = get_parallel_mode_from_env(WEIGHT_GROUP_3D) + self.output_parallel_mode = get_last_group(self.input_parallel_mode, + self.weight_parallel_mode) self.input_rank = gpc.get_local_rank(self.input_parallel_mode) self.weight_rank = gpc.get_local_rank(self.weight_parallel_mode) self.reduction_mean = reduction @@ -141,53 +131,53 @@ def forward(self, logits, targets): return loss -@LOSSES.register_module -class LabelSmoothingCrossEntropy3D(_Loss): - """ - NLL loss with label smoothing, adapted from timm.loss.LabelSmoothingCrossEntropy - - :param input_parallel_mode: parallel mode for input tensor - :type input_parallel_mode: ParallelMode - :param weight_parallel_mode: parallel mode for weight - :type weight_parallel_mode: ParallelMode - :param smoothing: label smoothing value, defaults to 0.1 - :type smoothing: float - :param reduction: whether to average the loss, defaults to True - :type reduction: bool, optional - """ - def __init__(self, - input_parallel_mode, - weight_parallel_mode, - smoothing=0.1, - reduction=True): - super().__init__() - assert smoothing < 1.0 - self.smoothing = smoothing - self.confidence = 1. 
- smoothing - self.depth = get_depth_from_env() - self.input_parallel_mode = input_parallel_mode - self.weight_parallel_mode = weight_parallel_mode - self.output_parallel_mode = get_last_group(input_parallel_mode, - weight_parallel_mode) - self.reduction_mean = reduction - - def forward(self, logits, targets): - # split label partition from the entire batch - j = gpc.get_local_rank(self.input_parallel_mode) - i = gpc.get_local_rank(self.weight_parallel_mode) - targets = torch.chunk(targets, self.depth, dim=0)[i] - targets = torch.chunk(targets, self.depth, dim=0)[j] - exp_logits = torch.exp(logits) - sum_exp_logits = Sum3D.apply(exp_logits, -1, depth, - self.output_parallel_mode, False) - log_probs = torch.log(sum_exp_logits) - logits - nll_loss = _ParallelCrossEntropyLossFunction_3D.apply( - logits, targets, self.depth, self.output_parallel_mode) - smooth_loss = -log_probs.mean(dim=-1) - loss = self.confidence * nll_loss + self.smoothing * smooth_loss - if self.reduction_mean: - loss = loss.sum() - loss = Reduce_3D.apply(loss, self.depth, self.input_parallel_mode) - loss = Reduce_3D.apply(loss, self.depth, self.weight_parallel_mode) - loss /= batch_size - return loss +# @LOSSES.register_module +# class LabelSmoothingCrossEntropy3D(_Loss): +# """ +# NLL loss with label smoothing, adapted from timm.loss.LabelSmoothingCrossEntropy + +# :param input_parallel_mode: parallel mode for input tensor +# :type input_parallel_mode: ParallelMode +# :param weight_parallel_mode: parallel mode for weight +# :type weight_parallel_mode: ParallelMode +# :param smoothing: label smoothing value, defaults to 0.1 +# :type smoothing: float +# :param reduction: whether to average the loss, defaults to True +# :type reduction: bool, optional +# """ +# def __init__(self, +# input_parallel_mode, +# weight_parallel_mode, +# smoothing=0.1, +# reduction=True): +# super().__init__() +# assert smoothing < 1.0 +# self.smoothing = smoothing +# self.confidence = 1. 
- smoothing +# self.depth = get_depth_from_env() +# self.input_parallel_mode = input_parallel_mode +# self.weight_parallel_mode = weight_parallel_mode +# self.output_parallel_mode = get_last_group(input_parallel_mode, +# weight_parallel_mode) +# self.reduction_mean = reduction + +# def forward(self, logits, targets): +# # split label partition from the entire batch +# j = gpc.get_local_rank(self.input_parallel_mode) +# i = gpc.get_local_rank(self.weight_parallel_mode) +# targets = torch.chunk(targets, self.depth, dim=0)[i] +# targets = torch.chunk(targets, self.depth, dim=0)[j] +# exp_logits = torch.exp(logits) +# sum_exp_logits = Sum3D.apply(exp_logits, -1, depth, +# self.output_parallel_mode, False) +# log_probs = torch.log(sum_exp_logits) - logits +# nll_loss = _ParallelCrossEntropyLossFunction_3D.apply( +# logits, targets, self.depth, self.output_parallel_mode) +# smooth_loss = -log_probs.mean(dim=-1) +# loss = self.confidence * nll_loss + self.smoothing * smooth_loss +# if self.reduction_mean: +# loss = loss.sum() +# loss = Reduce_3D.apply(loss, self.depth, self.input_parallel_mode) +# loss = Reduce_3D.apply(loss, self.depth, self.weight_parallel_mode) +# loss /= batch_size +# return loss diff --git a/colossalai/nn/optimizer/_utils.py b/colossalai/nn/optimizer/_utils.py index 255b48ea9720..31fc62213437 100644 --- a/colossalai/nn/optimizer/_utils.py +++ b/colossalai/nn/optimizer/_utils.py @@ -1,6 +1,7 @@ #!/usr/bin/env python # -*- encoding: utf-8 -*- +from colossalai.utils.common import print_rank_0 import torch from torch._six import inf @@ -10,8 +11,9 @@ print('Colossalai should be built with cuda extension to use the FP16 optimizer') from ..multi_tensor_apply import multi_tensor_applier + +from colossalai.constants import IS_TENSOR_PARALLEL, TENSOR_PARALLEL_ATTRIBUTES, NUM_PARTITIONS import torch.distributed as dist -from colossalai.constants import IS_TENSOR_PARALLEL, TENSOR_PARALLEL_ATTRIBUTES from colossalai.context.parallel_mode import ParallelMode from colossalai.core import global_context as gpc @@ -105,7 +107,8 @@ def clip_grad_norm_fp32(parameters, max_norm, norm_type=2): no_tensor_parallel_grads = [] for p in params: if is_model_parallel_parameter(p): - tensor_parallel_grads.append(p.grad.data) + reductor = (gpc.get_world_size(ParallelMode.TENSOR) / getattr(p, NUM_PARTITIONS)) ** (1 / norm_type) + tensor_parallel_grads.append(p.grad.data / reductor) else: no_tensor_parallel_grads.append(p.grad.data) if norm_type == 2.0: diff --git a/colossalai/nn/optimizer/fp16_optimizer.py b/colossalai/nn/optimizer/fp16_optimizer.py index c64a732c9dee..4ae970910c5c 100644 --- a/colossalai/nn/optimizer/fp16_optimizer.py +++ b/colossalai/nn/optimizer/fp16_optimizer.py @@ -113,7 +113,7 @@ def update(self, found_inf): if self._hysteresis_tracker <= 0: self._scale = torch.max(self._scale * self.backoff_factor, self.min_scale) - self._logger.info(f'overflow occurs, loss scale is adjusted to {self._scale}') + self._logger.info(f'overflow occurs, loss scale is adjusted to {self._scale}', ranks=[0]) else: # If there is no nan/inf, increment the growth tracker. self._growth_tracker += 1 @@ -125,10 +125,10 @@ def update(self, found_inf): # and scale up the loss scale. 
if self._max_scale is not None and self._scale >= self._max_scale: self._logger.info( - f'Current loss scale {self._scale} has reached the max scale {self._max_scale} allowed') + f'Current loss scale {self._scale} has reached the max scale {self._max_scale} allowed', ranks=[0]) else: self._scale = self._scale * self.growth_factor - self._logger.info(f'no consecutive overflow, loss scale is adjusted to {self._scale}') + self._logger.info(f'no consecutive overflow, loss scale is adjusted to {self._scale}', ranks=[0]) def state_dict(self): state_dict = {} diff --git a/colossalai/registry/__init__.py b/colossalai/registry/__init__.py index 1de1c98aea1e..9f270f049611 100644 --- a/colossalai/registry/__init__.py +++ b/colossalai/registry/__init__.py @@ -2,7 +2,7 @@ import torch.nn as nn import torch.optim as optim import torchvision.models as tv_models -from torchvision.transforms import transforms +from torchvision import transforms from .registry import Registry diff --git a/colossalai/trainer/_trainer.py b/colossalai/trainer/_trainer.py index 96a82d995817..92ef64393305 100644 --- a/colossalai/trainer/_trainer.py +++ b/colossalai/trainer/_trainer.py @@ -9,6 +9,7 @@ from tqdm import tqdm from colossalai.builder import build_hooks +from colossalai.core import global_context as gpc from colossalai.engine import Engine from colossalai.logging import get_global_dist_logger from colossalai.nn.data import DataParallelSampler @@ -159,13 +160,22 @@ def _train_epoch(self, else: progress = tqdm(progress, desc=f'[Epoch {epoch} train]') - # train 1 epoch + # train 1 epoch + # ### metric measured by zbian + # train_loss = 0 + # batch_cnt = 0 + # num_samples = 0 + # ###### self._call_hooks('before_train_epoch') self._call_timer(action='start', item='train-epoch') for i in progress: self._call_hooks('before_train_iter') self._call_timer(action='start', item='train-step') + # ### metric measured by zbian + # cur_lr = self._engine.optimizer.param_groups[0]['lr'] + # ###### + if i == self._steps_per_epoch - 1: is_last_iteration = True else: @@ -178,6 +188,24 @@ def _train_epoch(self, self._cur_step += 1 + # ### metric measured by zbian + # if display_progress: + # if isinstance(label, (tuple, list)): + # batch_size = label[0].size(0) + # else: + # batch_size = label.size(0) + # batch_size *= self._engine._grad_accum_size * gpc.data_parallel_size + # train_loss += loss.item() + # num_samples += batch_size + # batch_cnt += 1 + # batch_time = self._timer.get_timer('train-step').get_elapsed_time() + # print_features = dict(lr='%g' % cur_lr, + # loss='%.3f' % (train_loss / (i + 1)), + # throughput='%.3f (samples/sec)' % + # (batch_size / (batch_time + 1e-12))) + # progress.set_postfix(**print_features) + # ###### + # stop when max iter is reached if self._exceed_max_step(): break @@ -185,6 +213,16 @@ def _train_epoch(self, self._call_timer(action='stop', item='train-epoch', keep_in_history=True) self._call_hooks('after_train_epoch') self._call_timer(action='reset', item='train-step') + # ### metric measured by zbian + # if display_progress: + # epoch_time = self._timer.get_timer('train-epoch').get_elapsed_time() + # epoch_loss = train_loss / batch_cnt + # epoch_throughput = num_samples / (epoch_time + 1e-12) + # if display_progress: + # self._logger.info( + # '[Epoch %d] Loss: %.3f | Throughput: %.3f (samples/sec)' % + # (epoch, epoch_loss, epoch_throughput)) + # ###### def _eval(self, test_dataloader: DataLoader, @@ -352,4 +390,4 @@ def predict(self, data: Union[Tensor, List[Tensor]]): simple_dataloader = [(data, 
None)] data_iter = iter(simple_dataloader) output, _, _ = self._engine.step(data_iter, return_loss=False) - return output + return output \ No newline at end of file diff --git a/colossalai/trainer/hooks/_checkpoint_hook.py b/colossalai/trainer/hooks/_checkpoint_hook.py index e1d9d4714277..0f53f79c9e5e 100644 --- a/colossalai/trainer/hooks/_checkpoint_hook.py +++ b/colossalai/trainer/hooks/_checkpoint_hook.py @@ -68,7 +68,7 @@ def after_train_epoch(self): self.trainer.engine.optimizer, self._lr_scheduler) self.logger.info( - f'checkpoint for epoch {self.trainer.cur_epoch} is saved to {self.checkpoint_dir}') + f'checkpoint for epoch {self.trainer.cur_epoch} is saved to {self.checkpoint_dir}', ranks=[0]) @HOOKS.register_module @@ -135,6 +135,6 @@ def before_train(self): self.trainer.cur_epoch = last_epoch self.logger.info( - f'loaded checkpoint from {path}') + f'loaded checkpoint from {path}', ranks=[0]) else: raise FileNotFoundError(f'checkpoint is not found at {path}') diff --git a/colossalai/trainer/hooks/_lr_scheduler_hook.py b/colossalai/trainer/hooks/_lr_scheduler_hook.py index ca483aebe14b..ae9fcd2561cc 100644 --- a/colossalai/trainer/hooks/_lr_scheduler_hook.py +++ b/colossalai/trainer/hooks/_lr_scheduler_hook.py @@ -31,14 +31,27 @@ def __init__(self, super().__init__(trainer=trainer, priority=priority) self.by_epoch = by_epoch + assert not ('warmup_epochs' in lr_scheduler_cfg and 'warmup_steps' in lr_scheduler_cfg), \ + 'Do not set both warmup_epochs and warmup_steps for lr_scheduler.' + warmup_steps = 0 if by_epoch: total_steps = trainer.max_epochs + if 'warmup_epochs' in lr_scheduler_cfg: + warmup_steps = lr_scheduler_cfg['warmup_epochs'] + elif 'warmup_steps' in lr_scheduler_cfg: + warmup_steps = lr_scheduler_cfg['warmup_steps'] else: total_steps = trainer.max_epochs * trainer.steps_per_epoch if trainer.max_steps is not None: total_steps = min(total_steps, trainer.max_steps) + if 'warmup_epochs' in lr_scheduler_cfg: + warmup_steps = lr_scheduler_cfg['warmup_epochs'] * trainer.steps_per_epoch + elif 'warmup_steps' in lr_scheduler_cfg: + warmup_steps = lr_scheduler_cfg['warmup_steps'] lr_scheduler_cfg['total_steps'] = total_steps + lr_scheduler_cfg['warmup_steps'] = warmup_steps + lr_scheduler_cfg.pop('warmup_epochs', None) self.lr_scheduler = build_lr_scheduler( lr_scheduler_cfg, trainer.engine.optimizer) diff --git a/colossalai/trainer/hooks/_metric_hook.py b/colossalai/trainer/hooks/_metric_hook.py index 31004d90ec22..50834d8ab705 100644 --- a/colossalai/trainer/hooks/_metric_hook.py +++ b/colossalai/trainer/hooks/_metric_hook.py @@ -168,15 +168,15 @@ class Accuracy3DHook(MetricHook): def __init__(self, trainer: Trainer, - input_parallel_mode: ParallelMode, - weight_parallel_mode: ParallelMode, + # input_parallel_mode: ParallelMode, + # weight_parallel_mode: ParallelMode, priority: int = 10): super().__init__(trainer, priority) if self._is_stage_to_compute: - self.metric = Accuracy3D(epoch_only=True, - input_parallel_mode=input_parallel_mode, - weight_parallel_mode=weight_parallel_mode) + self.metric = Accuracy3D(epoch_only=True) + # input_parallel_mode=input_parallel_mode, + # weight_parallel_mode=weight_parallel_mode) # register the metric self.trainer.states['metrics']['test'][ diff --git a/colossalai/trainer/metric.py b/colossalai/trainer/metric.py index d0255b4ea3b3..5038826c96ac 100644 --- a/colossalai/trainer/metric.py +++ b/colossalai/trainer/metric.py @@ -3,12 +3,14 @@ import torch import torch.distributed as dist - from colossalai.communication import all_gather +from 
colossalai.constants import (INPUT_GROUP_3D, OUTPUT_GROUP_3D, + WEIGHT_GROUP_3D) from colossalai.context import ParallelMode from colossalai.core import global_context as gpc from colossalai.nn.layer._parallel_utilities import _gather -from colossalai.nn.layer.parallel_3d._utils import get_last_group +from colossalai.nn.layer.parallel_3d._utils import (get_last_group, + get_parallel_mode_from_env) from colossalai.utils import get_current_device @@ -22,7 +24,6 @@ class Metric(ABC): :param epoch_only: Whether the metric only read for the full epoch :type epoch_only: bool """ - def __init__(self, epoch_only: bool): # is the metric only read for the full epoch self._epoch_only = epoch_only @@ -80,7 +81,6 @@ class Loss(Metric): :param epoch_only: Whether the metric only read for the full epoch :type epoch_only: bool """ - def __init__(self, epoch_only): super().__init__(epoch_only=epoch_only) self.last_step_loss = torch.zeros(1, device=get_current_device()) @@ -110,7 +110,8 @@ def get_accumulated_value(self): """Returns accumulated loss. """ if gpc.is_initialized(ParallelMode.DATA): - dist.all_reduce(self.accum_loss, op=dist.ReduceOp.SUM, + dist.all_reduce(self.accum_loss, + op=dist.ReduceOp.SUM, group=gpc.get_group(ParallelMode.DATA)) self.accum_loss.div_(gpc.get_world_size(ParallelMode.DATA)) @@ -132,7 +133,6 @@ class LearningRate(Metric): :param epoch_only: Whether the metric only read for the full epoch :type epoch_only: bool """ - def __init__(self, epoch_only: bool, initial_lr: float = 0.): super().__init__(epoch_only=epoch_only) self.lr = 0. @@ -160,7 +160,6 @@ class Accuracy(Metric): :param epoch_only: Whether the metric only read for the full epoch :type epoch_only: bool """ - def __init__(self, epoch_only: bool): super().__init__(epoch_only=epoch_only) self.last_step_sum = torch.zeros(1, device=get_current_device()) @@ -219,7 +218,6 @@ class Accuracy2D(Accuracy): :param epoch_only: Whether the metric only read for the full epoch :type epoch_only: bool """ - def __init__(self, epoch_only: bool): super().__init__(epoch_only=epoch_only) @@ -229,11 +227,7 @@ def update(self, logits, label) -> None: if isinstance(label, (list, tuple)): label = label[0] - logits = _gather( - logits, - ParallelMode.PARALLEL_2D_ROW, - 1 - ) + logits = _gather(logits, ParallelMode.PARALLEL_2D_ROW, 1) logits = _gather( logits, ParallelMode.PARALLEL_2D_COL, @@ -290,11 +284,7 @@ def update(self, logits, label) -> None: if isinstance(label, (list, tuple)): label = label[0] - logits = _gather( - logits, - ParallelMode.PARALLEL_2P5D_ROW, - 1 - ) + logits = _gather(logits, ParallelMode.PARALLEL_2P5D_ROW, 1) logits = _gather( logits, ParallelMode.PARALLEL_2P5D_COL, @@ -329,14 +319,14 @@ class Accuracy3D(Accuracy): :param epoch_only: Whether the metric only read for the full epoch :type epoch_only: bool """ - - def __init__(self, epoch_only, input_parallel_mode, weight_parallel_mode): + def __init__(self, epoch_only): + # input_parallel_mode, weight_parallel_mode): super().__init__(epoch_only=epoch_only) self.depth = int(os.environ['DEPTH_3D']) - self.input_parallel_mode = input_parallel_mode - self.weight_parallel_mode = weight_parallel_mode - self.output_parallel_mode = get_last_group(input_parallel_mode, - weight_parallel_mode) + self.input_parallel_mode = get_parallel_mode_from_env(INPUT_GROUP_3D) + self.weight_parallel_mode = get_parallel_mode_from_env(WEIGHT_GROUP_3D) + self.output_parallel_mode = get_last_group(self.input_parallel_mode, + self.weight_parallel_mode) def update(self, logits, target): if 
isinstance(logits, (list, tuple)): @@ -352,6 +342,7 @@ def update(self, logits, target): target = torch.chunk(target, self.depth, dim=0)[j] logits = all_gather(logits, -1, self.output_parallel_mode) + logits = torch.cat(logits, dim=-1) prediction = torch.argmax(logits, dim=-1) correct = torch.sum(prediction == target) diff --git a/colossalai/utils/memory.py b/colossalai/utils/memory.py index b47b4099dafe..f60029b045c0 100644 --- a/colossalai/utils/memory.py +++ b/colossalai/utils/memory.py @@ -15,10 +15,18 @@ def bytes_to_GB(val, decimal=2): '''A byte-to-Gigabyte converter, defaultly using binary notation. :param val: X bytes to convert - :return: X' Gb + :return: X' GB ''' return round(val / (1024 * 1024 * 1024), decimal) +def bytes_to_MB(val, decimal=2): + '''A byte-to-Megabyte converter, defaultly using binary notation. + + :param val: X bytes to convert + :return: X' MB + ''' + return round(val / (1024 * 1024), decimal) + def report_memory_usage(message): '''Calculate and print RAM usage (in GB) @@ -35,14 +43,14 @@ def report_memory_usage(message): vm_stats = psutil.virtual_memory() vm_used = bytes_to_GB(vm_stats.total - vm_stats.available) - gpu_allocated = bytes_to_GB(torch.cuda.memory_allocated()) - gpu_max_allocated = bytes_to_GB(torch.cuda.max_memory_allocated()) - gpu_cached = bytes_to_GB(torch.cuda.memory_cached()) - gpu_max_cached = bytes_to_GB(torch.cuda.max_memory_cached()) + gpu_allocated = bytes_to_MB(torch.cuda.memory_allocated()) + gpu_max_allocated = bytes_to_MB(torch.cuda.max_memory_allocated()) + gpu_cached = bytes_to_MB(torch.cuda.memory_reserved()) + gpu_max_cached = bytes_to_MB(torch.cuda.max_memory_reserved()) get_global_dist_logger().info( - f"{message} - GPU: allocated {gpu_allocated}GB, max allocated {gpu_max_allocated}GB, cached: {gpu_cached} GB, " - f"max cached: {gpu_max_cached}GB, CPU Virtual Memory: used = {vm_used}GB, percent = {vm_stats.percent}%") + f"{message} - GPU: allocated {gpu_allocated}MB, max allocated {gpu_max_allocated}MB, cached: {gpu_cached} MB, " + f"max cached: {gpu_max_cached}MB, CPU Virtual Memory: used = {vm_used}GB, percent = {vm_stats.percent}%") # get the peak memory to report correct data, so reset the counter for the next call if hasattr(torch.cuda, "reset_peak_memory_stats"): # pytorch 1.4+ diff --git a/configs/vit/vit_2d.py b/configs/vit/vit_2d.py index 23ddc8d6cad8..b771b583e9d9 100644 --- a/configs/vit/vit_2d.py +++ b/configs/vit/vit_2d.py @@ -144,7 +144,7 @@ parallel = dict( pipeline=dict(size=1), - tensor=dict(size=1, mode='2d'), + tensor=dict(size=4, mode='2d'), ) # for fp16 training diff --git a/setup.py b/setup.py index 8541b0a6ce3a..f7684d4daac9 100644 --- a/setup.py +++ b/setup.py @@ -132,4 +132,4 @@ def fetch_requirements(path): ext_modules=ext_modules, cmdclass={'build_ext': BuildExtension} if ext_modules else {}, install_requires=install_requires, -) +) \ No newline at end of file diff --git a/tests/test_layers/test_3d/common.py b/tests/test_layers/test_3d/common.py index c85046855a15..88c0f41c6038 100644 --- a/tests/test_layers/test_3d/common.py +++ b/tests/test_layers/test_3d/common.py @@ -7,9 +7,9 @@ BATCH_SIZE = 512 SEQ_LENGTH = 128 HIDDEN_SIZE = 512 -NUM_CLASSES = 10 +NUM_CLASSES = 1000 NUM_BLOCKS = 6 -IMG_SIZE = 32 +IMG_SIZE = 224 def check_equal(A, B): - return torch.allclose(A, B, rtol=1e-5, atol=1e-2) + return torch.allclose(A, B, rtol=1e-4, atol=1e-2) diff --git a/tests/test_layers/test_3d/test_3d.py b/tests/test_layers/test_3d/test_3d.py index 21c560820e3e..7c1212c20f4a 100644 --- 
a/tests/test_layers/test_3d/test_3d.py +++ b/tests/test_layers/test_3d/test_3d.py @@ -10,14 +10,14 @@ seed=0) -def check_operations(): - check_AB() - check_ABT() - check_ATB() - check_add() - check_mul() - check_sum() - # check_pooler() +# def check_operations(): +# check_AB() +# check_ABT() +# check_ATB() +# check_add() +# check_mul() +# check_sum() +# check_pooler() def check_layer(): @@ -48,7 +48,7 @@ def _test_main(): torch.backends.cudnn.benchmark = True # check operation - check_operations() + # check_operations() # check layers check_layer() diff --git a/tests/test_layers/test_3d/test_conn.py b/tests/test_layers/test_3d/test_conn.py index 83cb32dd5203..c88368b93edf 100644 --- a/tests/test_layers/test_3d/test_conn.py +++ b/tests/test_layers/test_3d/test_conn.py @@ -1,19 +1,34 @@ +import time + import torch import torch.distributed as dist +from colossalai.communication import all_gather, reduce_scatter, all_reduce +from colossalai.context import ParallelMode +from colossalai.core import global_context as gpc +from colossalai.initialize import init_dist, parse_args +from colossalai.utils import get_current_device, print_rank_0 + +# ARGS = parse_args() +# size = ARGS.world_size +# rank = ARGS.rank -from colossalai.initialize import parse_args -from colossalai.utils import get_current_device +# init_method = f'tcp://{ARGS.host}:{ARGS.port}' +# dist.init_process_group(backend='nccl', rank=rank, world_size=size, init_method=init_method) +CONFIG = dict(parallel=dict(data=8, pipeline=1, tensor=dict(mode=None, size=1))) +init_dist(CONFIG) -ARGS = parse_args() -size = ARGS.world_size -rank = ARGS.local_rank +assert dist.get_rank() == gpc.get_global_rank() -init_method = f'tcp://{ARGS.host}:{ARGS.port}' -dist.init_process_group(backend='nccl', rank=rank, world_size=size, init_method=init_method) print('Rank {} / {}'.format(dist.get_rank(), dist.get_world_size())) SIZE = 8 tensor = torch.randn(SIZE) tensor = tensor.to(get_current_device()) -dist.all_reduce(tensor) -print('Rank {0}: {1}'.format(rank, tensor.detach().cpu().numpy().tolist())) +print('Before: Rank {0} - {1}'.format(dist.get_rank(), tensor)) +time.sleep(1) +# tensor, op = all_gather(tensor, 0, ParallelMode.GLOBAL, async_op=True) +# tensor, op = reduce_scatter(tensor, 0, ParallelMode.GLOBAL, async_op=True) +tensor, op = all_reduce(tensor, ParallelMode.GLOBAL, async_op=True) +print_rank_0('After: Rank {0} - {1}'.format(dist.get_rank(), tensor)) +op.wait() +print_rank_0('Complete: Rank {0} - {1}'.format(dist.get_rank(), tensor)) diff --git a/tests/test_layers/test_3d/test_layer.py b/tests/test_layers/test_3d/test_layer.py index db5de22a4cbf..4c661ed658bc 100644 --- a/tests/test_layers/test_3d/test_layer.py +++ b/tests/test_layers/test_3d/test_layer.py @@ -10,6 +10,8 @@ from colossalai.logging import get_global_dist_logger from colossalai.registry import LAYERS, LOSSES from colossalai.utils import get_current_device, print_rank_0 +from colossalai.nn.layer.parallel_3d._utils import get_parallel_mode_from_env +from colossalai.constants import INPUT_GROUP_3D, WEIGHT_GROUP_3D, OUTPUT_GROUP_3D from common import * @@ -22,24 +24,38 @@ def check_linear(): INPUT_SIZE = HIDDEN_SIZE OUTPUT_SIZE = 2 * HIDDEN_SIZE - j = A_rank = global_context.get_local_rank(ParallelMode.PARALLEL_3D_INPUT) - i = B_rank = global_context.get_local_rank(ParallelMode.PARALLEL_3D_WEIGHT) - k = C_rank = global_context.get_local_rank(ParallelMode.PARALLEL_3D_OUTPUT) + input_parallel_mode = get_parallel_mode_from_env(INPUT_GROUP_3D) + weight_parallel_mode = 
get_parallel_mode_from_env(WEIGHT_GROUP_3D) + output_parallel_mode = get_parallel_mode_from_env(OUTPUT_GROUP_3D) + + j = A_rank = global_context.get_local_rank(input_parallel_mode) + i = B_rank = global_context.get_local_rank(weight_parallel_mode) + k = C_rank = global_context.get_local_rank(output_parallel_mode) layer = LAYERS.get_module('Linear3D')(INPUT_SIZE, OUTPUT_SIZE, - ParallelMode.PARALLEL_3D_INPUT, - ParallelMode.PARALLEL_3D_WEIGHT, + # ParallelMode.PARALLEL_3D_INPUT, + # ParallelMode.PARALLEL_3D_WEIGHT, dtype=dtype, bias=True) - torch.nn.init.zeros_(layer.bias) - torch.nn.init.ones_(layer.weight) + # torch.nn.init.zeros_(layer.bias) + # torch.nn.init.ones_(layer.weight) layer = layer.to(device) layer_master = torch.nn.Linear(INPUT_SIZE, OUTPUT_SIZE) - torch.nn.init.zeros_(layer_master.bias) - torch.nn.init.ones_(layer_master.weight) + # torch.nn.init.zeros_(layer_master.bias) + # torch.nn.init.ones_(layer_master.weight) layer_master = layer_master.to(device) + weight_master = layer_master.weight.data.transpose(0, 1) + torch.distributed.broadcast(weight_master, src=0) + weight = torch.chunk(weight_master, DEPTH, dim=0)[k] + weight = torch.chunk(weight, DEPTH, dim=-1)[j] + layer.weight = torch.nn.Parameter(weight) + bias_master = layer_master.bias.data + torch.distributed.broadcast(bias_master, src=0) + bias = torch.chunk(bias_master, DEPTH)[j] + layer.bias = torch.nn.Parameter(bias) + A_shape = (BATCH_SIZE, SEQ_LENGTH, INPUT_SIZE) A_master = torch.randn(A_shape, dtype=dtype, device=device) torch.distributed.broadcast(A_master, src=0) @@ -89,21 +105,15 @@ def check_linear(): B_grad = layer_master.weight.grad.transpose(0, 1) B_grad = torch.chunk(B_grad, DEPTH, dim=0)[k] B_grad = torch.chunk(B_grad, DEPTH, dim=-1)[j] - B_grad = torch.chunk(B_grad, DEPTH, dim=-1)[i] + # B_grad = torch.chunk(B_grad, DEPTH, dim=-1)[i] logger.info('Rank {} linear backward (weight_grad): {}'.format( rank, check_equal(B_grad, layer.weight.grad))) - if j == k: - bias_grad = layer_master.bias.grad - bias_grad = torch.chunk(bias_grad, DEPTH)[j] - bias_grad = torch.chunk(bias_grad, DEPTH)[i] - logger.info('Rank {} linear backward (bias_grad): {}'.format( - rank, check_equal(bias_grad, layer.bias.grad))) - else: - logger.info('Rank {} linear backward (bias_grad): {}'.format( - rank, - # np.count_nonzero(layer.bias.grad.detach().cpu().numpy()) == 0)) - layer.bias.grad is None)) + bias_grad = layer_master.bias.grad + bias_grad = torch.chunk(bias_grad, DEPTH)[j] + logger.info('Rank {} linear backward (bias_grad): {}'.format( + rank, check_equal(bias_grad, layer.bias.grad))) + # logger.info(f'\nRank {rank} Master:\n{layer_master.bias.grad}\nRank {rank} True:\n{bias_grad}\nRank {rank} Out:\n{layer.bias.grad}') return fwd_end - fwd_start, bwd_end - bwd_start @@ -115,18 +125,31 @@ def check_layernorm(): dtype = torch.float32 INPUT_SIZE = HIDDEN_SIZE - j = A_rank = global_context.get_local_rank(ParallelMode.PARALLEL_3D_INPUT) - i = B_rank = global_context.get_local_rank(ParallelMode.PARALLEL_3D_WEIGHT) - k = C_rank = global_context.get_local_rank(ParallelMode.PARALLEL_3D_OUTPUT) + input_parallel_mode = get_parallel_mode_from_env(INPUT_GROUP_3D) + weight_parallel_mode = get_parallel_mode_from_env(WEIGHT_GROUP_3D) + output_parallel_mode = get_parallel_mode_from_env(OUTPUT_GROUP_3D) + + j = A_rank = global_context.get_local_rank(input_parallel_mode) + i = B_rank = global_context.get_local_rank(weight_parallel_mode) + k = C_rank = global_context.get_local_rank(output_parallel_mode) norm = 
LAYERS.get_module('LayerNorm3D')(INPUT_SIZE, - ParallelMode.PARALLEL_3D_INPUT, - ParallelMode.PARALLEL_3D_WEIGHT, + # ParallelMode.PARALLEL_3D_INPUT, + # ParallelMode.PARALLEL_3D_WEIGHT, eps=1e-6, dtype=dtype) norm = norm.to(device) norm_master = torch.nn.LayerNorm(INPUT_SIZE, eps=1e-6) norm_master = norm_master.to(device) + + weight_master = norm_master.weight.data + torch.distributed.broadcast(weight_master, src=0) + weight = torch.chunk(weight_master, DEPTH)[k] + norm.weight = torch.nn.Parameter(weight) + bias_master = norm_master.bias.data + torch.distributed.broadcast(bias_master, src=0) + bias = torch.chunk(bias_master, DEPTH)[k] + norm.bias = torch.nn.Parameter(bias) A_shape = (BATCH_SIZE, SEQ_LENGTH, INPUT_SIZE) A_master = torch.randn(A_shape, dtype=dtype, device=device) @@ -181,29 +204,15 @@ def check_layernorm(): logger.info('Rank {} layernorm backward (input_grad): {}'.format( rank, check_equal(A_grad, A.grad))) - if j == k: - bias_grad = norm_master.weight.grad - bias_grad = torch.chunk(bias_grad, DEPTH)[j] - bias_grad = torch.chunk(bias_grad, DEPTH)[i] - logger.info('Rank {} linear backward (weight_grad): {}'.format( - rank, check_equal(bias_grad, norm.weight.grad))) - else: - logger.info('Rank {} linear backward (weight_grad): {}'.format( - rank, - # np.count_nonzero(layer.bias.grad.detach().cpu().numpy()) == 0)) - norm.weight.grad is None)) - - if j == k: - bias_grad = norm_master.bias.grad - bias_grad = torch.chunk(bias_grad, DEPTH)[j] - bias_grad = torch.chunk(bias_grad, DEPTH)[i] - logger.info('Rank {} linear backward (bias_grad): {}'.format( - rank, check_equal(bias_grad, norm.bias.grad))) - else: - logger.info('Rank {} linear backward (bias_grad): {}'.format( - rank, - # np.count_nonzero(layer.bias.grad.detach().cpu().numpy()) == 0)) - norm.bias.grad is None)) + bias_grad = norm_master.weight.grad + bias_grad = torch.chunk(bias_grad, DEPTH)[k] + logger.info('Rank {} layernorm backward (weight_grad): {}'.format( + rank, check_equal(bias_grad, norm.weight.grad))) + + bias_grad = norm_master.bias.grad + bias_grad = torch.chunk(bias_grad, DEPTH)[k] + logger.info('Rank {} layernorm backward (bias_grad): {}'.format( + rank, check_equal(bias_grad, norm.bias.grad))) return fwd_end - fwd_start, bwd_end - bwd_start @@ -216,9 +225,13 @@ def check_attention(): INPUT_SIZE = HIDDEN_SIZE NUM_ATTENTION_HEADS = 2 - j = A_rank = global_context.get_local_rank(ParallelMode.PARALLEL_3D_INPUT) - i = B_rank = global_context.get_local_rank(ParallelMode.PARALLEL_3D_WEIGHT) - k = C_rank = global_context.get_local_rank(ParallelMode.PARALLEL_3D_OUTPUT) + input_parallel_mode = get_parallel_mode_from_env(INPUT_GROUP_3D) + weight_parallel_mode = get_parallel_mode_from_env(WEIGHT_GROUP_3D) + output_parallel_mode = get_parallel_mode_from_env(OUTPUT_GROUP_3D) + + j = A_rank = global_context.get_local_rank(input_parallel_mode) + i = B_rank = global_context.get_local_rank(weight_parallel_mode) + k = C_rank = global_context.get_local_rank(output_parallel_mode) layer = LAYERS.get_module('ViTSelfAttention3D')(HIDDEN_SIZE, NUM_ATTENTION_HEADS, @@ -268,9 +281,13 @@ def check_mlp(): dtype = torch.float32 INPUT_SIZE = HIDDEN_SIZE - j = A_rank = global_context.get_local_rank(ParallelMode.PARALLEL_3D_INPUT) - i = B_rank = global_context.get_local_rank(ParallelMode.PARALLEL_3D_WEIGHT) - k = C_rank = global_context.get_local_rank(ParallelMode.PARALLEL_3D_OUTPUT) + input_parallel_mode = get_parallel_mode_from_env(INPUT_GROUP_3D) + weight_parallel_mode = get_parallel_mode_from_env(WEIGHT_GROUP_3D) + 
output_parallel_mode = get_parallel_mode_from_env(OUTPUT_GROUP_3D) + + j = A_rank = global_context.get_local_rank(input_parallel_mode) + i = B_rank = global_context.get_local_rank(weight_parallel_mode) + k = C_rank = global_context.get_local_rank(output_parallel_mode) layer = LAYERS.get_module('ViTMLP3D')(HIDDEN_SIZE, 1, @@ -325,23 +342,37 @@ def check_head(): dtype = torch.float32 INPUT_SIZE = HIDDEN_SIZE - j = A_rank = global_context.get_local_rank(ParallelMode.PARALLEL_3D_INPUT) - i = B_rank = global_context.get_local_rank(ParallelMode.PARALLEL_3D_WEIGHT) - k = C_rank = global_context.get_local_rank(ParallelMode.PARALLEL_3D_OUTPUT) + input_parallel_mode = get_parallel_mode_from_env(INPUT_GROUP_3D) + weight_parallel_mode = get_parallel_mode_from_env(WEIGHT_GROUP_3D) + output_parallel_mode = get_parallel_mode_from_env(OUTPUT_GROUP_3D) + + j = A_rank = global_context.get_local_rank(input_parallel_mode) + i = B_rank = global_context.get_local_rank(weight_parallel_mode) + k = C_rank = global_context.get_local_rank(output_parallel_mode) head = LAYERS.get_module('ViTHead3D')(INPUT_SIZE, NUM_CLASSES, dtype=dtype, bias=True) - torch.nn.init.zeros_(head.linear.bias) - torch.nn.init.ones_(head.linear.weight) + # torch.nn.init.zeros_(head.linear.bias) + # torch.nn.init.ones_(head.linear.weight) head = head.to(device) layer = Testvithead(INPUT_SIZE, NUM_CLASSES, bias=True) - torch.nn.init.zeros_(layer.linear.bias) - torch.nn.init.ones_(layer.linear.weight) + # torch.nn.init.zeros_(layer.linear.bias) + # torch.nn.init.ones_(layer.linear.weight) layer = layer.to(device) + weight_master = layer.linear.weight.data.transpose(0, 1) + torch.distributed.broadcast(weight_master, src=0) + weight = torch.chunk(weight_master, DEPTH, dim=0)[k] + weight = torch.chunk(weight, DEPTH, dim=-1)[j] + head.linear.weight = torch.nn.Parameter(weight) + bias_master = layer.linear.bias.data + torch.distributed.broadcast(bias_master, src=0) + bias = torch.chunk(bias_master, DEPTH)[j] + head.linear.bias = torch.nn.Parameter(bias) + A_shape = (BATCH_SIZE, SEQ_LENGTH, INPUT_SIZE) A_master = torch.randn(A_shape, dtype=dtype, device=device) torch.distributed.broadcast(A_master, src=0) @@ -397,31 +428,43 @@ def check_head(): B_grad = layer.linear.weight.grad.transpose(0, 1) B_grad = torch.chunk(B_grad, DEPTH, dim=0)[k] B_grad = torch.chunk(B_grad, DEPTH, dim=-1)[j] - pad_shape = (B_grad.shape[0], math.ceil(B_grad.shape[-1] / DEPTH) * DEPTH - - B_grad.shape[-1]) - B_grad = torch.cat( - [B_grad, torch.zeros(pad_shape, dtype=dtype, device=device)], dim=-1) - B_grad = torch.chunk(B_grad, DEPTH, dim=-1)[i] + # B_grad = torch.chunk(B_grad, DEPTH, dim=-1)[i] logger.info('Rank {} head backward (weight_grad): {}'.format( rank, check_equal(B_grad, head.linear.weight.grad))) - if j == k: - bias_grad = layer.linear.bias.grad - bias_grad = torch.chunk(bias_grad, DEPTH)[j] - pad_shape = (math.ceil(bias_grad.shape[0] / DEPTH) * DEPTH - - bias_grad.shape[0], ) - bias_grad = torch.cat( - [bias_grad, - torch.zeros(pad_shape, dtype=dtype, device=device)]) - bias_grad = torch.chunk(bias_grad, DEPTH)[i] - logger.info('Rank {} head backward (bias_grad): {}'.format( - rank, check_equal(bias_grad, head.linear.bias.grad))) - else: - logger.info('Rank {} head backward (bias_grad): {}'.format( - rank, - # np.count_nonzero( - # head.linear.bias.grad.detach().cpu().numpy()) == 0)) - head.linear.bias.grad is None)) + bias_grad = layer.linear.bias.grad + bias_grad = torch.chunk(bias_grad, DEPTH)[j] + logger.info('Rank {} head backward (bias_grad): {}'.format( + 
rank, check_equal(bias_grad, head.linear.bias.grad))) + + # B_grad = layer.linear.weight.grad.transpose(0, 1) + # B_grad = torch.chunk(B_grad, DEPTH, dim=0)[k] + # B_grad = torch.chunk(B_grad, DEPTH, dim=-1)[j] + # pad_shape = (B_grad.shape[0], math.ceil(B_grad.shape[-1] / DEPTH) * DEPTH - + # B_grad.shape[-1]) + # B_grad = torch.cat( + # [B_grad, torch.zeros(pad_shape, dtype=dtype, device=device)], dim=-1) + # B_grad = torch.chunk(B_grad, DEPTH, dim=-1)[i] + # logger.info('Rank {} head backward (weight_grad): {}'.format( + # rank, check_equal(B_grad, head.linear.weight.grad))) + + # if j == k: + # bias_grad = layer.linear.bias.grad + # bias_grad = torch.chunk(bias_grad, DEPTH)[j] + # pad_shape = (math.ceil(bias_grad.shape[0] / DEPTH) * DEPTH - + # bias_grad.shape[0], ) + # bias_grad = torch.cat( + # [bias_grad, + # torch.zeros(pad_shape, dtype=dtype, device=device)]) + # bias_grad = torch.chunk(bias_grad, DEPTH)[i] + # logger.info('Rank {} head backward (bias_grad): {}'.format( + # rank, check_equal(bias_grad, head.linear.bias.grad))) + # else: + # logger.info('Rank {} head backward (bias_grad): {}'.format( + # rank, + # # np.count_nonzero( + # # head.linear.bias.grad.detach().cpu().numpy()) == 0)) + # head.linear.bias.grad is None)) return fwd_end - fwd_start, bwd_end - bwd_start @@ -455,9 +498,13 @@ def check_embed(): logger = get_global_dist_logger() dtype = torch.float32 - j = A_rank = global_context.get_local_rank(ParallelMode.PARALLEL_3D_INPUT) - i = B_rank = global_context.get_local_rank(ParallelMode.PARALLEL_3D_WEIGHT) - k = C_rank = global_context.get_local_rank(ParallelMode.PARALLEL_3D_OUTPUT) + input_parallel_mode = get_parallel_mode_from_env(INPUT_GROUP_3D) + weight_parallel_mode = get_parallel_mode_from_env(WEIGHT_GROUP_3D) + output_parallel_mode = get_parallel_mode_from_env(OUTPUT_GROUP_3D) + + j = A_rank = global_context.get_local_rank(input_parallel_mode) + i = B_rank = global_context.get_local_rank(weight_parallel_mode) + k = C_rank = global_context.get_local_rank(output_parallel_mode) layer = LAYERS.get_module('ViTPatchEmbedding3D')(IMG_SIZE, 4, 3, HIDDEN_SIZE, 0.) 
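Every hunk in this test file applies the same substitution: the explicit ParallelMode.PARALLEL_3D_* arguments are dropped and each call site resolves its 3D process groups through get_parallel_mode_from_env. A minimal sketch of that lookup pattern, assuming the helpers behave as they are used in the hunks above (this is an illustration, not library source):

# Sketch only: the 3D group-lookup pattern repeated throughout this file.
# Helper names come from the imports added above; their exact semantics are assumed.
from colossalai.constants import INPUT_GROUP_3D, WEIGHT_GROUP_3D, OUTPUT_GROUP_3D
from colossalai.core import global_context
from colossalai.nn.layer.parallel_3d._utils import get_parallel_mode_from_env

def get_3d_local_ranks():
    # Resolve the three orthogonal parallel modes from the environment-registered
    # group names instead of receiving them as function arguments.
    input_parallel_mode = get_parallel_mode_from_env(INPUT_GROUP_3D)
    weight_parallel_mode = get_parallel_mode_from_env(WEIGHT_GROUP_3D)
    output_parallel_mode = get_parallel_mode_from_env(OUTPUT_GROUP_3D)
    # Local coordinates (j, i, k) of this rank in the 3D mesh, matching the
    # j/i/k naming used by the checks in this file.
    j = global_context.get_local_rank(input_parallel_mode)
    i = global_context.get_local_rank(weight_parallel_mode)
    k = global_context.get_local_rank(output_parallel_mode)
    return j, i, k

The check_loss hunk below follows the same pattern: CrossEntropyLoss3D is now constructed without parallel-mode arguments and reads its groups internally.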
@@ -589,12 +636,16 @@ def check_loss(): device = get_current_device() dtype = torch.float32 - j = A_rank = global_context.get_local_rank(ParallelMode.PARALLEL_3D_INPUT) - i = B_rank = global_context.get_local_rank(ParallelMode.PARALLEL_3D_WEIGHT) - k = C_rank = global_context.get_local_rank(ParallelMode.PARALLEL_3D_OUTPUT) + input_parallel_mode = get_parallel_mode_from_env(INPUT_GROUP_3D) + weight_parallel_mode = get_parallel_mode_from_env(WEIGHT_GROUP_3D) + output_parallel_mode = get_parallel_mode_from_env(OUTPUT_GROUP_3D) + + j = A_rank = global_context.get_local_rank(input_parallel_mode) + i = B_rank = global_context.get_local_rank(weight_parallel_mode) + k = C_rank = global_context.get_local_rank(output_parallel_mode) - criterion = LOSSES.get_module('CrossEntropyLoss3D')( - ParallelMode.PARALLEL_3D_INPUT, ParallelMode.PARALLEL_3D_WEIGHT) + criterion = LOSSES.get_module('CrossEntropyLoss3D')() + # ParallelMode.PARALLEL_3D_INPUT, ParallelMode.PARALLEL_3D_WEIGHT) criterion_master = torch.nn.CrossEntropyLoss() out_shape = (BATCH_SIZE, NUM_CLASSES) diff --git a/tests/test_models/test_vision_transformer/configs/vit_2d_imagenet.py b/tests/test_models/test_vision_transformer/configs/vit_2d_imagenet.py new file mode 100644 index 000000000000..8cac68b06a43 --- /dev/null +++ b/tests/test_models/test_vision_transformer/configs/vit_2d_imagenet.py @@ -0,0 +1,105 @@ +from colossalai.engine import AMP_TYPE + +BATCH_SIZE = 128 +LEARNING_RATE = 0.001 +IMG_SIZE = 224 +PATCH_SIZE = 16 +DIM = 2048 +NUM_ATTENTION_HEADS = 16 +NUM_CLASSES = 1000 +DEPTH = 48 +NUM_EPOCHS = 300 + +parallel = dict( + data=4, + pipeline=1, + tensor=dict(size=1, mode='2d'), +) + +model = dict( + type='VisionTransformerFromConfig', + tensor_splitting_cfg=dict(type='ViTInputSplitter2D', ), + embedding_cfg=dict( + type='ViTPatchEmbedding2D', + img_size=IMG_SIZE, + patch_size=PATCH_SIZE, + embed_dim=DIM, + ), + token_fusion_cfg=dict(type='ViTTokenFuser2D', + img_size=IMG_SIZE, + patch_size=PATCH_SIZE, + embed_dim=DIM, + drop_rate=0.1), + norm_cfg=dict( + type='LayerNorm2D', + normalized_shape=DIM, + eps=1e-6, + ), + block_cfg=dict( + type='ViTBlock', + attention_cfg=dict(type='ViTSelfAttention2D', + hidden_size=DIM, + num_attention_heads=NUM_ATTENTION_HEADS, + attention_dropout_prob=0., + hidden_dropout_prob=0.1, + checkpoint=True), + droppath_cfg=dict(type='VanillaViTDropPath', ), + mlp_cfg=dict(type='ViTMLP2D', + in_features=DIM, + dropout_prob=0.1, + mlp_ratio=4, + checkpoint=True), + norm_cfg=dict( + type='LayerNorm2D', + normalized_shape=DIM, + eps=1e-6, + ), + ), + head_cfg=dict( + type='ViTHead2D', + hidden_size=DIM, + num_classes=NUM_CLASSES, + ), + embed_dim=DIM, + depth=DEPTH, + drop_path_rate=0., +) + +optimizer = dict( + type='AdamW', + lr=3e-3, + weight_decay=0.3, +) + +loss = dict(type='CrossEntropyLoss2D', reduction=True) + +clip_grad = 1.0 + +num_epochs = NUM_EPOCHS + +fp16 = dict(mode=AMP_TYPE.PARALLEL, initial_scale=2**8) + +# this engine config can be ignored if you want to use default values +engine = dict( + # schedule=None, + schedule=dict(num_microbatches=4), + gradient_handlers=None, + gradient_accumulation=1, + gradient_clipping=1.0, +) + +hooks = [ + dict(type='LogMetricByEpochHook'), + dict(type='LogMemoryByEpochHook'), + dict(type='LogTimingByEpochHook'), + dict(type='Accuracy2DHook'), + dict(type='LossHook'), + dict(type='LRSchedulerHook', + by_epoch=True, + lr_scheduler_cfg=dict(type='CosineAnnealingWarmupLR', + warmup_steps=32)) +] + +logging = dict( + root_path= + 
f"./vit_2d_imagenet1k_bs{BATCH_SIZE}_{fp16['mode']}_clip_grad{clip_grad}") diff --git a/tests/test_models/test_vision_transformer/configs/vit_3d.py b/tests/test_models/test_vision_transformer/configs/vit_3d.py index ad041efd0a22..5dffcb753b5a 100644 --- a/tests/test_models/test_vision_transformer/configs/vit_3d.py +++ b/tests/test_models/test_vision_transformer/configs/vit_3d.py @@ -4,18 +4,23 @@ import os from pathlib import Path -from colossalai.context import ParallelMode +# from colossalai.context import ParallelMode +from colossalai.engine import AMP_TYPE +from torchvision.transforms import AutoAugmentPolicy IMG_SIZE = 32 PATCH_SIZE = 4 -EMBED_SIZE = 512 -HIDDEN_SIZE = 512 -NUM_HEADS = 8 +EMBED_SIZE = 256 +HIDDEN_SIZE = 256 +NUM_HEADS = 4 NUM_CLASSES = 10 -NUM_BLOCKS = 6 +NUM_BLOCKS = 7 DROP_RATE = 0.1 + BATCH_SIZE = 512 LEARNING_RATE = 0.001 +WEIGHT_DECAY = 3e-2 + DATASET_PATH = Path(os.environ['DATA']) model = dict( @@ -34,8 +39,8 @@ type='LayerNorm3D', normalized_shape=HIDDEN_SIZE, eps=1e-6, - input_parallel_mode=ParallelMode.PARALLEL_3D_INPUT, - weight_parallel_mode=ParallelMode.PARALLEL_3D_WEIGHT, + # input_parallel_mode=ParallelMode.PARALLEL_3D_INPUT, + # weight_parallel_mode=ParallelMode.PARALLEL_3D_WEIGHT, ), attention_cfg=dict( type='ViTSelfAttention3D', @@ -48,7 +53,7 @@ mlp_cfg=dict( type='ViTMLP3D', hidden_size=HIDDEN_SIZE, - mlp_ratio=1, + mlp_ratio=2, hidden_dropout_prob=DROP_RATE, hidden_act='gelu', ), @@ -56,8 +61,9 @@ norm_cfg=dict(type='LayerNorm3D', normalized_shape=HIDDEN_SIZE, eps=1e-6, - input_parallel_mode=ParallelMode.PARALLEL_3D_INPUT, - weight_parallel_mode=ParallelMode.PARALLEL_3D_WEIGHT), + # input_parallel_mode=ParallelMode.PARALLEL_3D_INPUT, + # weight_parallel_mode=ParallelMode.PARALLEL_3D_WEIGHT, + ), head_cfg=dict( type='ViTHead3D', in_features=HIDDEN_SIZE, @@ -69,28 +75,31 @@ ) loss = dict(type='CrossEntropyLoss3D', - input_parallel_mode=ParallelMode.PARALLEL_3D_OUTPUT, - weight_parallel_mode=ParallelMode.PARALLEL_3D_WEIGHT, - reduction=True) - -optimizer = dict(type='Adam', lr=LEARNING_RATE, weight_decay=0) - -train_data = dict(dataset=dict(type='CIFAR10Dataset', - root=DATASET_PATH, - transform_pipeline=[ - dict(type='RandomCrop', - size=IMG_SIZE, - padding=4), - dict(type='RandomHorizontalFlip'), - dict(type='ToTensor'), - dict(type='Normalize', - mean=[0.4914, 0.4822, 0.4465], - std=[0.2023, 0.1994, 0.2010]), - ]), - dataloader=dict(batch_size=BATCH_SIZE, - pin_memory=True, - shuffle=True, - num_workers=8)) + # input_parallel_mode=ParallelMode.PARALLEL_3D_OUTPUT, + # weight_parallel_mode=ParallelMode.PARALLEL_3D_WEIGHT, + # reduction=True, + ) +# loss = dict(type='CrossEntropyLoss', label_smoothing=0.1) + +optimizer = dict(type='AdamW', lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY) + +train_data = dict( + dataset=dict( + type='CIFAR10Dataset', + root=DATASET_PATH, + transform_pipeline=[ + dict(type='RandomCrop', size=IMG_SIZE, padding=4), + # dict(type='RandomHorizontalFlip'), + dict(type='AutoAugment', policy=AutoAugmentPolicy.CIFAR10), + dict(type='ToTensor'), + dict(type='Normalize', + mean=[0.4914, 0.4822, 0.4465], + std=[0.2023, 0.1994, 0.2010]), + ]), + dataloader=dict(batch_size=BATCH_SIZE, + pin_memory=True, + shuffle=True, + num_workers=1)) test_data = dict(dataset=dict(type='CIFAR10Dataset', root=DATASET_PATH, @@ -102,34 +111,45 @@ mean=[0.4914, 0.4822, 0.4465], std=[0.2023, 0.1994, 0.2010]), ]), - dataloader=dict(batch_size=400, - pin_memory=True, - num_workers=8)) + dataloader=dict(batch_size=1000, + pin_memory=True)) + +parallel = 
dict( + data=1, + pipeline=1, + tensor=dict(mode='3d', size=8), +) + +clip_grad = 1.0 + +engine = dict( + schedule=None, + gradient_handlers=None, + gradient_accumulation=1, + gradient_clipping=clip_grad, +) + +num_epochs = 200 hooks = [ dict(type='LogMetricByEpochHook'), - dict(type='LogTimingByEpochHook'), dict(type='LogMemoryByEpochHook'), dict( type='Accuracy3DHook', - input_parallel_mode=ParallelMode.PARALLEL_3D_OUTPUT, - weight_parallel_mode=ParallelMode.PARALLEL_3D_WEIGHT, + # input_parallel_mode=ParallelMode.PARALLEL_3D_OUTPUT, + # weight_parallel_mode=ParallelMode.PARALLEL_3D_WEIGHT, ), dict(type='LossHook'), - dict( - type='LRSchedulerHook', - by_epoch=True, - lr_scheduler_cfg=dict( - type='LinearWarmupLR', - warmup_steps=5 - ) - ), + dict(type='LRSchedulerHook', + by_epoch=False, + lr_scheduler_cfg=dict(type='CosineAnnealingWarmupLR', + warmup_epochs=10, + eta_min=1e-5)), ] -parallel = dict( - data=1, - pipeline=1, - tensor=dict(mode='3d', size=8), -) +# fp16 = dict(mode=AMP_TYPE.TORCH, init_scale=2**6) -num_epochs = 60 +logging = dict( + root_path= + f"./vit_3d_cifar10_bs{BATCH_SIZE}_lr{LEARNING_RATE}_clip_grad{clip_grad}" +) diff --git a/tests/test_models/test_vision_transformer/configs/vit_3d_imagenet.py b/tests/test_models/test_vision_transformer/configs/vit_3d_imagenet.py new file mode 100644 index 000000000000..175c87f0acbc --- /dev/null +++ b/tests/test_models/test_vision_transformer/configs/vit_3d_imagenet.py @@ -0,0 +1,119 @@ +from colossalai.engine import AMP_TYPE +from colossalai.context import ParallelMode + +### VIT-S/16 +IMG_SIZE = 224 +PATCH_SIZE = 16 +EMBED_SIZE = 384 +HIDDEN_SIZE = 384 +MLP_RATIO = 4 +NUM_HEADS = 6 +NUM_CLASSES = 1000 +DROP_RATE = 0.1 +DEPTH = 12 +### + +# ### ViT-L/16 +# IMG_SIZE = 224 +# PATCH_SIZE = 16 +# EMBED_SIZE = 10240 +# HIDDEN_SIZE = 10240 +# MLP_RATIO = 4 +# NUM_HEADS = 64 +# NUM_CLASSES = 1000 +# DROP_RATE = 0.1 +# DEPTH = 64 +# ### + +BATCH_SIZE = 4096 + +parallel = dict( + pipeline=dict(size=1), + tensor=dict(size=8, mode='3d'), +) + +optimizer = dict( + type='AdamW', + lr=3e-3, + weight_decay=0.3, +) + +loss = dict(type='CrossEntropyLoss3D', reduction=True) + +model = dict( + type='VisionTransformerFromConfig', + embedding_cfg=dict( + type='ViTPatchEmbedding3D', + img_size=IMG_SIZE, + patch_size=PATCH_SIZE, + in_chans=3, + embed_size=EMBED_SIZE, + drop_prob=DROP_RATE, + init_method='jax', + ), + block_cfg=dict( + type='ViTBlock', + norm_cfg=dict( + type='LayerNorm3D', + normalized_shape=HIDDEN_SIZE, + eps=1e-6, + ), + attention_cfg=dict(type='ViTSelfAttention3D', + hidden_size=HIDDEN_SIZE, + num_attention_heads=NUM_HEADS, + attention_probs_dropout_prob=0., + hidden_dropout_prob=DROP_RATE, + checkpoint=True, + init_method='jax'), + droppath_cfg=dict(type='VanillaViTDropPath', ), + mlp_cfg=dict(type='ViTMLP3D', + hidden_size=HIDDEN_SIZE, + mlp_ratio=4, + hidden_dropout_prob=DROP_RATE, + hidden_act='gelu', + checkpoint=True, + init_method='jax'), + ), + norm_cfg=dict(type='LayerNorm3D', normalized_shape=HIDDEN_SIZE, eps=1e-6), + head_cfg=dict( + type='ViTHead3D', + in_features=HIDDEN_SIZE, + num_classes=NUM_CLASSES, + init_method='jax', + ), + embed_dim=HIDDEN_SIZE, + depth=DEPTH, + drop_path_rate=0., +) + +clip_grad = 1.0 + +engine = dict( + schedule=None, + gradient_handlers=None, + gradient_accumulation=1, + gradient_clipping=clip_grad, +) + +num_epochs = 300 + +hooks = [ + dict(type='LogMetricByEpochHook'), + dict(type='LogMemoryByEpochHook'), + dict(type='LogTimingByEpochHook'), + dict(type='Accuracy3DHook', ), + 
dict(type='LossHook'), + dict(type='LRSchedulerHook', + by_epoch=True, + lr_scheduler_cfg=dict(type='CosineAnnealingWarmupLR', + warmup_steps=32, + eta_min=1e-5)), +] + +fp16 = dict(mode=AMP_TYPE.TORCH, ) + +logging = dict( + root_path= + f"./vit_3d_imagenet1k_bs{BATCH_SIZE}_{fp16['mode']}_clip_grad{clip_grad}") + +seed = 42 \ No newline at end of file diff --git a/tests/test_models/test_vision_transformer/test_vit_3d/profiling_3d.py b/tests/test_models/test_vision_transformer/test_vit_3d/profiling_3d.py new file mode 100644 index 000000000000..1044710986a3 --- /dev/null +++ b/tests/test_models/test_vision_transformer/test_vit_3d/profiling_3d.py @@ -0,0 +1,360 @@ +#!/usr/bin/env python +# -*- encoding: utf-8 -*- + +import time +import colossalai + +import torch +from tqdm import tqdm + +from colossalai import initialize +from colossalai.context import ParallelMode +from colossalai.core import global_context as gpc +from colossalai.logging import get_global_dist_logger +from colossalai.utils import print_rank_0, report_memory_usage +from colossalai.utils import empty_cache + +WAIT_STEPS = 3 +WARMUP_STEPS = 50 +ACTIVE_STEPS = 100 +PROFILE_CYCLE = WAIT_STEPS + WARMUP_STEPS + ACTIVE_STEPS + + +def _train_epoch(epoch, engine, dataloader, profiler=None): + logger = get_global_dist_logger() + print_rank_0('[Epoch %d] training start' % (epoch), logger) + engine.train() + data_iter = iter(dataloader) + + train_loss = 0 + batch_cnt = 0 + num_samples = 0 + now = time.time() + epoch_start = now + progress = range(PROFILE_CYCLE) + if gpc.get_global_rank() == 0: + progress = tqdm(progress, desc='[Epoch %d]' % epoch, miniters=1) + for step in progress: + cur_lr = engine.optimizer.param_groups[0]['lr'] + + _, targets, loss = engine.step(data_iter) + if profiler is not None: + profiler.step() + + batch_size = targets[0].size( + 0) * engine._grad_accum_size * gpc.data_parallel_size + train_loss += loss.item() + num_samples += batch_size + batch_cnt += 1 + + batch_time = time.time() - now + now = time.time() + if gpc.get_global_rank() == 0: + print_features = dict(lr='%g' % cur_lr, + loss='%.3f' % (train_loss / (step + 1)), + throughput='%.3f (images/sec)' % + (batch_size / (batch_time + 1e-12))) + progress.set_postfix(**print_features) + + epoch_end = time.time() + epoch_loss = train_loss / batch_cnt + epoch_throughput = num_samples / (epoch_end - epoch_start + 1e-12) + print_rank_0( + '[Epoch %d] Loss: %.3f | Throughput: %.3f (samples/sec)' % + (epoch, epoch_loss, epoch_throughput), logger) + if gpc.get_global_rank() == 0: + report_memory_usage('Memory usage') + + +def test_cifar(): + engine, train_dataloader, test_dataloader = initialize() + + logger = get_global_dist_logger() + logger.info("Train start", ranks=[0]) + data_iter = iter(train_dataloader) + output, targets, loss = engine.step(data_iter) + if gpc.get_global_rank() == 0: + with torch.profiler.profile( + activities=[ + torch.profiler.ProfilerActivity.CPU, + torch.profiler.ProfilerActivity.CUDA, + ], + schedule=torch.profiler.schedule(wait=WAIT_STEPS, + warmup=WARMUP_STEPS, + active=ACTIVE_STEPS), + on_trace_ready=torch.profiler.tensorboard_trace_handler( + f'./log_cifar_{gpc.config.parallel.tensor.mode}_{gpc.get_world_size(ParallelMode.GLOBAL)}' + ), + record_shapes=True, + # profile_memory=True, + with_flops=True, + with_modules=True, + ) as prof: + _train_epoch(0, engine, train_dataloader, prof) + + torch.cuda.synchronize() + + print('Test complete. 
Generating profiling report ...') + print( + prof.key_averages(group_by_input_shape=True).table( + sort_by="cuda_time_total")) + + torch.distributed.barrier() + else: + _train_epoch(0, engine, train_dataloader) + torch.cuda.synchronize() + torch.distributed.barrier() + + +def test_imagenet(): + from test_vit_3d import build_dali_train, build_dali_test + engine, train_dataloader, test_dataloader = initialize( + train_dataloader=build_dali_train, test_dataloader=build_dali_test) + + logger = get_global_dist_logger() + logger.info("Train start", ranks=[0]) + if gpc.get_global_rank() == 0: + with torch.profiler.profile( + activities=[ + torch.profiler.ProfilerActivity.CPU, + torch.profiler.ProfilerActivity.CUDA, + ], + schedule=torch.profiler.schedule(wait=WAIT_STEPS, + warmup=WARMUP_STEPS, + active=ACTIVE_STEPS), + on_trace_ready=torch.profiler.tensorboard_trace_handler( + f'./log_imagenet_{gpc.config.parallel.tensor.mode}_{gpc.get_world_size(ParallelMode.GLOBAL)}' + ), + record_shapes=True, + # profile_memory=True, + with_flops=True, + with_modules=True, + ) as prof: + _train_epoch(0, engine, train_dataloader, prof) + + torch.cuda.synchronize() + + print('Test complete. Generating profiling report ...') + print( + prof.key_averages(group_by_input_shape=True).table( + sort_by="cuda_time_total")) + + torch.distributed.barrier() + else: + _train_epoch(0, engine, train_dataloader) + torch.cuda.synchronize() + torch.distributed.barrier() + + +def test_allgather_n_broadcast(): + from colossalai.communication import all_gather + from colossalai.initialize import init_dist + from colossalai.utils import get_current_device + from tqdm import trange + + init_dist() + + logger = get_global_dist_logger() + + BATCH_SIZE = 4024 + HIDDEN_SIZE = 512 + DEPTH = torch.distributed.get_world_size() + SEQ_LENGTH = 128 + + logger.info("Test start", ranks=[0]) + if gpc.get_global_rank() == 0: + with torch.profiler.profile( + activities=[ + torch.profiler.ProfilerActivity.CPU, + torch.profiler.ProfilerActivity.CUDA, + ], + schedule=torch.profiler.schedule(wait=1, + warmup=5, + active=10, + repeat=2), + on_trace_ready=torch.profiler.tensorboard_trace_handler( + f'./log_allgather_n_broadcast_{gpc.get_world_size(ParallelMode.GLOBAL)}' + ), + record_shapes=True, + # profile_memory=True, + with_flops=True, + with_modules=True, + ) as prof: + tensor_shape = (BATCH_SIZE, SEQ_LENGTH, HIDDEN_SIZE // DEPTH) + for _ in trange(16): + x = torch.randn(tensor_shape, + dtype=torch.float, + device=get_current_device()) + x = all_gather(x, -1, ParallelMode.GLOBAL) + prof.step() + + torch.cuda.synchronize() + torch.cuda.empty_cache() + + tensor_shape = (BATCH_SIZE, SEQ_LENGTH, HIDDEN_SIZE) + for _ in trange(16): + x = torch.randn(tensor_shape, + dtype=torch.float, + device=get_current_device()) + x = x.clone() + torch.distributed.broadcast(x, src=0) + prof.step() + + torch.cuda.synchronize() + torch.cuda.empty_cache() + + print('Test complete. 
Generating profiling report ...') + print( + prof.key_averages(group_by_input_shape=True).table( + sort_by="cuda_time_total")) + torch.distributed.barrier() + else: + tensor_shape = (BATCH_SIZE, SEQ_LENGTH, HIDDEN_SIZE // DEPTH) + for _ in range(16): + x = torch.randn(tensor_shape, + dtype=torch.float, + device=get_current_device()) + x = all_gather(x, -1, ParallelMode.GLOBAL) + + torch.cuda.synchronize() + torch.cuda.empty_cache() + + tensor_shape = (BATCH_SIZE, SEQ_LENGTH, HIDDEN_SIZE) + for _ in range(16): + x = torch.randn(tensor_shape, + dtype=torch.float, + device=get_current_device()) + x = x.clone() + torch.distributed.broadcast(x, src=0) + + torch.cuda.synchronize() + torch.cuda.empty_cache() + torch.distributed.barrier() + + +def test_layer(): + from colossalai.initialize import init_dist + from colossalai.utils import get_current_device + from tqdm import trange + from colossalai.nn.layer.parallel_3d import Linear3D, LayerNorm3D + + CONFIG = dict(parallel=dict(pipeline=1, tensor=dict(mode='3d', size=8)), + seed=0) + + init_dist(config=CONFIG) + torch.backends.cudnn.benchmark = False + torch.backends.cudnn.deterministic = True + gpc.set_seed() + + logger = get_global_dist_logger() + + BATCH_SIZE = 512 + HIDDEN_SIZE = 4096 + DEPTH = colossalai.nn.layer.parallel_3d._utils.get_depth_from_env() + SEQ_LENGTH = 128 + linear1 = Linear3D(HIDDEN_SIZE, HIDDEN_SIZE * 4) + linear2 = Linear3D(HIDDEN_SIZE * 4, HIDDEN_SIZE) + dropout = torch.nn.Dropout(0.0) + norm = LayerNorm3D(HIDDEN_SIZE, eps=1e-5) + layer = torch.nn.Sequential(linear1, linear2, dropout, norm) + + logger.info("Test start", ranks=[0]) + tensor_shape = (BATCH_SIZE // DEPTH ** 2, SEQ_LENGTH, HIDDEN_SIZE // DEPTH) + + if gpc.get_global_rank() == 0: + for _ in trange(WARMUP_STEPS): + x = torch.randn(tensor_shape, + dtype=torch.float, + device=get_current_device()) + x = layer(x) + grad = torch.randn(x.shape, + dtype=torch.float, + device=get_current_device()) + x.backward(grad) + empty_cache() + start = time.time() + for _ in trange(ACTIVE_STEPS): + x = torch.randn(tensor_shape, + dtype=torch.float, + device=get_current_device()) + x = layer(x) + grad = torch.randn(x.shape, + dtype=torch.float, + device=get_current_device()) + x.backward(grad) + empty_cache() + torch.cuda.synchronize() + end = time.time() + avg_step_time = (end - start) / ACTIVE_STEPS + throughput = ACTIVE_STEPS * BATCH_SIZE / (end - start) + logger.info('Avg step time = {:.3f} s | Throughput = {:.3f} /s'.format(avg_step_time, throughput)) + else: + for _ in range(WARMUP_STEPS + ACTIVE_STEPS): + x = torch.randn(tensor_shape, + dtype=torch.float, + device=get_current_device()) + x = layer(x) + grad = torch.randn(x.shape, + dtype=torch.float, + device=get_current_device()) + x.backward(grad) + empty_cache() + torch.cuda.synchronize() + torch.distributed.barrier() + + # if gpc.get_global_rank() == 0: + # with torch.profiler.profile( + # activities=[ + # torch.profiler.ProfilerActivity.CPU, + # torch.profiler.ProfilerActivity.CUDA, + # ], + # schedule=torch.profiler.schedule(wait=WAIT_STEPS, + # warmup=WARMUP_STEPS, + # active=ACTIVE_STEPS), + # on_trace_ready=torch.profiler.tensorboard_trace_handler( + # f'./log_layer_3d_{gpc.get_world_size(ParallelMode.GLOBAL)}' + # ), + # record_shapes=True, + # # profile_memory=True, + # with_flops=True, + # with_modules=True, + # ) as prof: + # for _ in trange(PROFILE_CYCLE): + # x = torch.randn(tensor_shape, + # dtype=torch.float, + # device=get_current_device()) + # x = layer(x) + # grad = torch.randn(x.shape, + # 
dtype=torch.float, + # device=get_current_device()) + # x.backward(grad) + # prof.step() + + # torch.cuda.synchronize() + + # report_memory_usage('Memory usage') + # print('Test complete. Generating profiling report ...') + # print( + # prof.key_averages(group_by_input_shape=True).table( + # sort_by="cuda_time_total")) + # torch.distributed.barrier() + # else: + # for _ in range(PROFILE_CYCLE): + # x = torch.randn(tensor_shape, + # dtype=torch.float, + # device=get_current_device()) + # x = layer(x) + # grad = torch.randn(x.shape, + # dtype=torch.float, + # device=get_current_device()) + # x.backward(grad) + + # torch.cuda.synchronize() + # torch.distributed.barrier() + + +if __name__ == '__main__': + # test_cifar() + # test_imagenet() + # test_allgather_n_broadcast() + test_layer() diff --git a/tests/test_models/test_vision_transformer/test_vit_3d/test_vit_3d.py b/tests/test_models/test_vision_transformer/test_vit_3d/test_vit_3d.py index 7bee2c78b4b2..8a450581ef33 100644 --- a/tests/test_models/test_vision_transformer/test_vit_3d/test_vit_3d.py +++ b/tests/test_models/test_vision_transformer/test_vit_3d/test_vit_3d.py @@ -1,105 +1,205 @@ #!/usr/bin/env python # -*- encoding: utf-8 -*- -import time -from pathlib import Path -import torch -from tqdm import tqdm +import glob +import os import colossalai +import nvidia.dali.fn as fn +import nvidia.dali.tfrecord as tfrec +import torch from colossalai.context import ParallelMode from colossalai.core import global_context as gpc from colossalai.logging import get_global_dist_logger from colossalai.trainer import Trainer -from colossalai.trainer.metric import Accuracy3D -from colossalai.utils import print_rank_0 +from colossalai.utils import (get_global_multitimer, + set_global_multitimer_status) +from nvidia.dali import types +from nvidia.dali.pipeline import Pipeline +from nvidia.dali.plugin.pytorch import DALIClassificationIterator + +DATASET_PATH = str(os.environ['DATA']) +# imagenet 100 +# TRAIN_RECS = '/project/scratch/p200012/imagenet-100/train/*' +# VAL_RECS = '/project/scratch/p200012/imagenet-100/validation/*' +# TRAIN_IDX = '/project/scratch/p200012/imagenet-100/idx_files/train/*' +# VAL_IDX = '/project/scratch/p200012/imagenet-100/idx_files/validation/*' + +# imagenet 1000 +TRAIN_RECS = DATASET_PATH + '/train/*' +VAL_RECS = DATASET_PATH + '/validation/*' +TRAIN_IDX = DATASET_PATH + '/idx_files/train/*' +VAL_IDX = DATASET_PATH + '/idx_files/validation/*' + + +class DaliDataloader(DALIClassificationIterator): + def __init__(self, + tfrec_filenames, + tfrec_idx_filenames, + shard_id=0, + num_shards=1, + batch_size=128, + num_threads=4, + resize=256, + crop=224, + prefetch=2, + training=True, + gpu_aug=False, + cuda=True): + pipe = Pipeline( + batch_size=batch_size, + num_threads=num_threads, + device_id=torch.cuda.current_device() if cuda else None, + seed=1024) + with pipe: + inputs = fn.readers.tfrecord(path=tfrec_filenames, + index_path=tfrec_idx_filenames, + random_shuffle=training, + shard_id=shard_id, + num_shards=num_shards, + initial_fill=10000, + read_ahead=True, + prefetch_queue_depth=prefetch, + name='Reader', + features={ + 'image/encoded': + tfrec.FixedLenFeature( + (), tfrec.string, ""), + 'image/class/label': + tfrec.FixedLenFeature([1], + tfrec.int64, + -1), + }) + images = inputs["image/encoded"] + + if training: + images = fn.decoders.image( + images, + device='mixed' if gpu_aug else 'cpu', + output_type=types.RGB) + images = fn.random_resized_crop( + images, size=crop, device='gpu' if gpu_aug else 'cpu') + flip_lr = 
fn.random.coin_flip(probability=0.5) + else: + # decode jpeg and resize + images = fn.decoders.image( + images, + device='mixed' if gpu_aug else 'cpu', + output_type=types.RGB) + images = fn.resize(images, + device='gpu' if gpu_aug else 'cpu', + resize_x=resize, + resize_y=resize, + dtype=types.FLOAT, + interp_type=types.INTERP_TRIANGULAR) + flip_lr = False + + # center crop and normalise + images = fn.crop_mirror_normalize(images, + dtype=types.FLOAT, + crop=(crop, crop), + mean=[127.5], + std=[127.5], + mirror=flip_lr) + label = inputs["image/class/label"] - 1 # 0-999 + # LSG: element_extract will raise exception, let's flatten outside + # label = fn.element_extract(label, element_map=0) # Flatten + if cuda: # transfer data to gpu + pipe.set_outputs(images.gpu(), label.gpu()) + else: + pipe.set_outputs(images, label) + + pipe.build() + last_batch_policy = 'DROP' if training else 'PARTIAL' + super().__init__(pipe, + reader_name="Reader", + auto_reset=True, + last_batch_policy=last_batch_policy) + + def __iter__(self): + # if not reset (after an epoch), reset; if just initialize, ignore + if self._counter >= self._size or self._size < 0: + self.reset() + return self + + def __next__(self): + data = super().__next__() + img, label = data[0]['data'], data[0]['label'] + label = label.squeeze() + return (img, ), (label, ) + + +def build_dali_train(): + return DaliDataloader(sorted(glob.glob(TRAIN_RECS)), + sorted(glob.glob(TRAIN_IDX)), + batch_size=gpc.config.BATCH_SIZE // + gpc.data_parallel_size, + shard_id=gpc.get_local_rank(ParallelMode.DATA), + num_shards=gpc.get_world_size(ParallelMode.DATA), + training=True, + gpu_aug=True, + cuda=True) + + +def build_dali_test(): + return DaliDataloader(sorted(glob.glob(VAL_RECS)), + sorted(glob.glob(VAL_IDX)), + batch_size=gpc.config.BATCH_SIZE // + gpc.data_parallel_size, + shard_id=gpc.get_local_rank(ParallelMode.DATA), + num_shards=gpc.get_world_size(ParallelMode.DATA), + training=False, + gpu_aug=True, + cuda=True) + + +def train_cifar(): + # init dist + engine, train_dataloader, test_dataloader = colossalai.initialize() + logger = get_global_dist_logger() + set_global_multitimer_status(True) -CONFIG_PATH = Path(__file__).parent.parent.joinpath('configs/vit_3d.py') + logger.info("Engine is built", ranks=[0]) + trainer = Trainer(engine=engine, + timer=get_global_multitimer(), + verbose=True) + logger.info("Trainer is built", ranks=[0]) -def _train_epoch(epoch, engine): - logger = get_global_dist_logger() - print_rank_0('[Epoch %d] training start' % (epoch), logger) - engine.train() - - train_loss = 0 - batch_cnt = 0 - num_samples = 0 - now = time.time() - epoch_start = now - progress = range(engine._schedule.num_steps) - if gpc.get_global_rank() == 0: - progress = tqdm(progress, desc='[Epoch %d]' % epoch, miniters=1) - for step in progress: - cur_lr = engine.get_lr() - - _, targets, loss = engine.step() - - batch_size = targets[0].size(0) - train_loss += loss.item() - num_samples += batch_size - batch_cnt += 1 - - batch_time = time.time() - now - now = time.time() - if gpc.get_global_rank() == 0: - print_features = dict(lr='%g' % cur_lr, - loss='%.3f' % (train_loss / (step + 1)), - throughput='%.3f (images/sec)' % - (batch_size / (batch_time + 1e-12))) - progress.set_postfix(**print_features) - - epoch_end = time.time() - epoch_loss = train_loss / batch_cnt - epoch_throughput = num_samples / (epoch_end - epoch_start + 1e-12) - print_rank_0( - '[Epoch %d] Loss: %.3f | Throughput: %.3f (samples/sec)' % - (epoch, epoch_loss, epoch_throughput), logger) 
- - -def _eval(epoch, engine): - logger = get_global_dist_logger() - engine.eval() - - eval_loss = 0 - acc = Accuracy3D(True, ParallelMode.PARALLEL_3D_OUTPUT, - ParallelMode.PARALLEL_3D_WEIGHT) - total = 0 - with torch.no_grad(): - for _ in range(engine._schedule.num_steps): - outputs, targets, loss = engine.step() - if isinstance(outputs, (list, tuple)): - outputs = outputs[0] - if isinstance(targets, (list, tuple)): - targets = targets[0] - eval_loss += loss.item() - acc.update(outputs, targets) - total += targets.size(0) - - print_rank_0( - '[Epoch %d] Evaluation loss: %.3f | Acc: %.3f%%' % - (epoch, eval_loss / engine._schedule.num_steps, - acc.get_accumulated_value() * 100), logger) - - -def train(): + logger.info("Train start", ranks=[0]) + trainer.fit(train_dataloader=train_dataloader, + test_dataloader=test_dataloader, + # epochs=gpc.config.num_epochs, + epochs=5, + hooks_cfg=gpc.config.hooks, + display_progress=True, + test_interval=1) + + +def train_imagenet(): # init dist - engine, train_dataloader, test_dataloader = colossalai.initialize(CONFIG_PATH) + engine, train_dataloader, test_dataloader = colossalai.initialize( + train_dataloader=build_dali_train, test_dataloader=build_dali_test) logger = get_global_dist_logger() + set_global_multitimer_status(True) logger.info("Engine is built", ranks=[0]) - trainer = Trainer(engine=engine, verbose=True) + trainer = Trainer(engine=engine, + timer=get_global_multitimer(), + verbose=True) logger.info("Trainer is built", ranks=[0]) logger.info("Train start", ranks=[0]) trainer.fit(train_dataloader=train_dataloader, - test_dataloader=test_dataloader, + # test_dataloader=test_dataloader, epochs=gpc.config.num_epochs, + max_steps=100, hooks_cfg=gpc.config.hooks, display_progress=True, test_interval=1) if __name__ == '__main__': - train() + # train_cifar() + train_imagenet()
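
For reference, the `DaliDataloader` introduced above repacks each DALI batch into the `((img,), (label,))` tuple format consumed by the engine's schedule, and every data-parallel rank reads its own TFRecord shard with a `1/num_shards` slice of the global batch. Below is a minimal sketch of driving one shard by hand; it assumes the `DaliDataloader`, `TRAIN_RECS`, and `TRAIN_IDX` definitions from this file, a prior call to `colossalai.initialize()`, and an illustrative divisibility check that is not part of the diff.

import glob

import torch
from colossalai.context import ParallelMode
from colossalai.core import global_context as gpc


def check_dali_shard():
    # Sketch only: assumes colossalai.initialize() has already set up gpc and
    # that DaliDataloader / TRAIN_RECS / TRAIN_IDX come from this test file.
    world = gpc.get_world_size(ParallelMode.DATA)
    rank = gpc.get_local_rank(ParallelMode.DATA)
    per_rank_bs = gpc.config.BATCH_SIZE // gpc.data_parallel_size
    # Illustrative assumption: the global batch splits evenly across DATA ranks.
    assert per_rank_bs * world == gpc.config.BATCH_SIZE

    loader = DaliDataloader(sorted(glob.glob(TRAIN_RECS)),
                            sorted(glob.glob(TRAIN_IDX)),
                            batch_size=per_rank_bs,
                            shard_id=rank,
                            num_shards=world,
                            training=True,
                            gpu_aug=True,
                            cuda=True)

    # __next__ returns ((img,), (label,)); labels are already squeezed to 1-D.
    (img,), (label,) = next(iter(loader))
    assert img.shape[0] == label.shape[0] == per_rank_bs
    assert img.dtype == torch.float32  # crop_mirror_normalize emits FLOAT

Dropping the last partial batch during training (`last_batch_policy = 'DROP'`) keeps every data-parallel rank on the same number of steps per epoch, while `'PARTIAL'` during validation ensures no sample is skipped.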