colossalai/gemini/chunk/search_utils.py (9 changes: 3 additions & 6 deletions)

@@ -6,17 +6,14 @@
 
 from colossalai.gemini.memory_tracer import MemStats, OrderedParamGenerator
 from colossalai.tensor import ColoParameter
-
-
-def in_ddp(param: nn.Parameter) -> bool:
-    return not getattr(param, '_ddp_to_ignore', False)
+from colossalai.utils import is_ddp_ignored
 
 
 def _filter_exlarge_params(model: nn.Module, size_dict: Dict[int, List[int]]) -> None:
     """
     Filter those parameters whose size is too large (more than 3x standard deviations) from others.
     """
-    params_size = [p.numel() for p in model.parameters() if in_ddp(p)]
+    params_size = [p.numel() for p in model.parameters() if not is_ddp_ignored(p)]
     params_size_arr = np.array(params_size)
 
     std = np.std(params_size_arr)

@@ -56,7 +53,7 @@ def classify_params_by_dp_degree(param_order: OrderedParamGenerator) -> Dict[int
     params_dict: Dict[int, List[ColoParameter]] = dict()
     for param in param_order.generate():
         assert isinstance(param, ColoParameter), "please init model in the ColoInitContext"
-        if not in_ddp(param):
+        if is_ddp_ignored(param):
            continue
 
         param_key = param.process_group.dp_world_size()

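Note for readers (not part of the diff): the helper removed here and the helper it is replaced by are logical negations of each other, which is why `if not in_ddp(param)` above becomes `if is_ddp_ignored(param)`, and the `in_ddp(p)` filters elsewhere in this PR become `not is_ddp_ignored(p)`. A minimal sketch of that equivalence, with both helpers inlined purely for illustration:

import torch.nn as nn


def in_ddp(param) -> bool:    # old local helper, deleted from search_utils.py
    return not getattr(param, '_ddp_to_ignore', False)


def is_ddp_ignored(param) -> bool:    # new shared helper, added in colossalai/utils/common.py below
    return getattr(param, '_ddp_to_ignore', False)


p = nn.Linear(2, 2).weight
assert in_ddp(p) == (not is_ddp_ignored(p))    # holds when the flag is absent...
p._ddp_to_ignore = True
assert in_ddp(p) == (not is_ddp_ignored(p))    # ...and when it is set
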
colossalai/gemini/chunk/utils.py (6 changes: 3 additions & 3 deletions)

@@ -6,8 +6,8 @@
 import torch.nn as nn
 
 from colossalai.gemini.chunk import ChunkManager
-from colossalai.gemini.chunk.search_utils import in_ddp, search_chunk_configuration
-from colossalai.gemini.memory_tracer import MemStats
+from colossalai.gemini.chunk.search_utils import search_chunk_configuration
+from colossalai.utils import is_ddp_ignored
 
 
 def init_chunk_manager(model: nn.Module,

@@ -34,7 +34,7 @@ def init_chunk_manager(model: nn.Module,
     if filter_exlarge_params:
         kwargs_dict["filter_exlarge_params"] = filter_exlarge_params
 
-    params_sizes = [p.numel() for p in model.parameters() if in_ddp(p)]
+    params_sizes = [p.numel() for p in model.parameters() if not is_ddp_ignored(p)]
     total_size = sum(params_sizes) / 1024**2
 
     dist.barrier()

colossalai/nn/optimizer/zero_optimizer.py (4 changes: 2 additions & 2 deletions)

@@ -12,7 +12,7 @@
 from colossalai.logging import get_dist_logger
 from colossalai.nn.optimizer import ColossalaiOptimizer, CPUAdam, FusedAdam, HybridAdam
 from colossalai.nn.parallel.data_parallel import ZeroDDP
-from colossalai.utils import disposable, get_current_device
+from colossalai.utils import disposable, get_current_device, is_ddp_ignored
 
 _AVAIL_OPTIM_LIST = {FusedAdam, CPUAdam, HybridAdam}
 

@@ -78,7 +78,7 @@ def __init__(self,
         if self.clipping_flag:
             assert norm_type == 2.0, "ZeroOptimizer only supports L2 norm now"
 
-        params_list = [p for p in module.parameters() if not getattr(p, '_ddp_to_ignore', False)]
+        params_list = [p for p in module.parameters() if not is_ddp_ignored(p)]
         for p, fp32_p in zip(params_list, module.fp32_params):
             chunk_16 = self.chunk_manager.get_chunk(p)
             if chunk_16 not in self.chunk16_set:

colossalai/nn/parallel/data_parallel.py (12 changes: 6 additions & 6 deletions)

@@ -14,7 +14,7 @@
 from colossalai.tensor import ProcessGroup as ColoProcessGroup
 from colossalai.tensor.colo_parameter import ColoParameter, ColoTensor, ColoTensorSpec
 from colossalai.tensor.param_op_hook import ColoParamOpHookManager
-from colossalai.utils import get_current_device
+from colossalai.utils import get_current_device, is_ddp_ignored
 from colossalai.zero.utils.gemini_hook import GeminiZeROHook
 
 from .reducer import Reducer

@@ -81,7 +81,7 @@ def __init__(self,
         self.reducer = Reducer(bucket_cap_mb)
         self.rebuild_bucket = rebuild_bucket
         for p in module.parameters():
-            if getattr(p, '_ddp_to_ignore', False):
+            if is_ddp_ignored(p):
                continue
             if p.requires_grad:
                 p.register_hook(partial(self.grad_handle, p))

@@ -116,7 +116,7 @@ def backward(self, loss: torch.Tensor):
         if self.rebuild_bucket:
             self.reducer.free()
         for p in self.module.parameters():
-            if getattr(p, '_ddp_to_ignore', False):
+            if is_ddp_ignored(p):
                continue
             if p.grad.device.type != "cpu":
                 p.grad = p._saved_grad

@@ -232,7 +232,7 @@ def __init__(self,
         for p in param_order.generate():
             assert isinstance(p, ColoParameter)
 
-            if getattr(p, '_ddp_to_ignore', False):
+            if is_ddp_ignored(p):
                p.data = p.data.half()
                 continue
 

@@ -256,7 +256,7 @@ def __init__(self,
         self.chunk_manager.close_all_groups()
         self._cast_buffers()
 
-        params_list = [p for p in param_order.generate() if not getattr(p, '_ddp_to_ignore', False)]
+        params_list = [p for p in param_order.generate() if not is_ddp_ignored(p)]
         for p, fp32_p in zip(params_list, self.fp32_params):
             chunk_16 = self.chunk_manager.get_chunk(p)
             chunk_32 = self.chunk_manager.get_chunk(fp32_p)

@@ -303,7 +303,7 @@ def forward(self, *args, **kwargs):
 
     def _setup_grads_ptr(self):
         for p in self.module.parameters():
-            if getattr(p, '_ddp_to_ignore', False):
+            if is_ddp_ignored(p):
                continue
             p.grad = None
 

colossalai/utils/__init__.py (42 changes: 33 additions & 9 deletions)

@@ -1,22 +1,46 @@
-from .cuda import empty_cache, get_current_device, set_to_cuda, synchronize
 from .activation_checkpoint import checkpoint
 from .checkpointing import load_checkpoint, save_checkpoint
-from .common import (clip_grad_norm_fp32, conditional_context, copy_tensor_parallel_attributes, count_zeros_fp32,
-                     ensure_path_exists, free_port, is_dp_rank_0, is_model_parallel_parameter, is_no_pp_or_last_stage,
-                     is_tp_rank_0, is_using_ddp, is_using_pp, is_using_sequence, multi_tensor_applier,
-                     param_is_not_tensor_parallel_duplicate, print_rank_0, switch_virtual_pipeline_parallel_rank,
-                     sync_model_param, disposable)
+from .common import (
+    clip_grad_norm_fp32,
+    conditional_context,
+    copy_tensor_parallel_attributes,
+    count_zeros_fp32,
+    disposable,
+    ensure_path_exists,
+    free_port,
+    is_ddp_ignored,
+    is_dp_rank_0,
+    is_model_parallel_parameter,
+    is_no_pp_or_last_stage,
+    is_tp_rank_0,
+    is_using_ddp,
+    is_using_pp,
+    is_using_sequence,
+    multi_tensor_applier,
+    param_is_not_tensor_parallel_duplicate,
+    print_rank_0,
+    switch_virtual_pipeline_parallel_rank,
+    sync_model_param,
+)
+from .cuda import empty_cache, get_current_device, set_to_cuda, synchronize
 from .data_sampler import DataParallelSampler, get_dataloader
-from .memory import (report_memory_usage, colo_device_memory_used, colo_set_process_memory_fraction,
-                     colo_device_memory_capacity, colo_set_cpu_memory_capacity, colo_get_cpu_memory_capacity)
-from .timer import MultiTimer, Timer
+from .memory import (
+    colo_device_memory_capacity,
+    colo_device_memory_used,
+    colo_get_cpu_memory_capacity,
+    colo_set_cpu_memory_capacity,
+    colo_set_process_memory_fraction,
+    report_memory_usage,
+)
 from .tensor_detector import TensorDetector
+from .timer import MultiTimer, Timer
 
 __all__ = [
     'checkpoint',
     'free_port',
     'print_rank_0',
     'sync_model_param',
+    'is_ddp_ignored',
     'is_dp_rank_0',
     'is_tp_rank_0',
     'is_no_pp_or_last_stage',

colossalai/utils/common.py (8 changes: 6 additions & 2 deletions)

@@ -126,14 +126,18 @@ def is_model_parallel_parameter(p):
     return hasattr(p, IS_TENSOR_PARALLEL) and getattr(p, IS_TENSOR_PARALLEL)
 
 
+def is_ddp_ignored(p):
+    return getattr(p, '_ddp_to_ignore', False)
+
+
 def _calc_l2_norm(grads):
-    # we should not
+    # we should not
     global fused_optim
 
     if fused_optim is None:
         from colossalai.kernel.op_builder import FusedOptimBuilder
         fused_optim = FusedOptimBuilder().load()
 
     norm = 0.0
     if len(grads) > 0:
         dummy_overflow_buf = torch.cuda.IntTensor([0])

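For context (not part of the diff), a minimal usage sketch of the helper added above. The `_ddp_to_ignore` attribute name is exactly what the call sites in this PR check; the `nn.Linear` module is only an illustrative stand-in, and marking a parameter by setting the attribute directly is an assumption about how callers opt a parameter out.

import torch.nn as nn

from colossalai.utils import is_ddp_ignored

layer = nn.Linear(8, 8)

# By default the attribute is absent, so the parameter is handled normally.
assert not is_ddp_ignored(layer.weight)

# Setting the flag makes every call site in this PR skip the parameter
# (no grad hook, no chunk registration, excluded from size statistics).
layer.bias._ddp_to_ignore = True
assert is_ddp_ignored(layer.bias)
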
colossalai/zero/utils/gemini_hook.py (5 changes: 3 additions & 2 deletions)

@@ -8,6 +8,7 @@
 from colossalai.gemini import TensorState
 from colossalai.gemini.gemini_mgr import GeminiManager
 from colossalai.tensor.param_op_hook import ColoParamOpHook
+from colossalai.utils import is_ddp_ignored
 
 
 class TrainingPhase(Enum):

@@ -24,7 +25,7 @@ def __init__(self, gemini_manager: GeminiManager) -> None:
         self._training_phase = TrainingPhase.FORWARD
 
     def pre_op(self, params):
-        params = [p for p in params if not getattr(p, '_ddp_to_ignore', False)]
+        params = [p for p in params if not is_ddp_ignored(p)]
         chunks = self._chunk_manager.get_chunks(params)
         for p in params:
             self._chunk_manager.trans_tensor_state(p, TensorState.COMPUTE)

@@ -37,7 +38,7 @@ def pre_op(self, params):
         self._gemini_manager.record_model_data_volume()
 
     def post_op(self, params):
-        params = [p for p in params if not getattr(p, '_ddp_to_ignore', False)]
+        params = [p for p in params if not is_ddp_ignored(p)]
         for p in params:
             tensor_state = TensorState.HOLD if self._training_phase == TrainingPhase.FORWARD or not p.requires_grad else TensorState.HOLD_AFTER_BWD
             self._chunk_manager.trans_tensor_state(p, tensor_state)