18 changes: 12 additions & 6 deletions colossalai/gemini/gemini_mgr.py
@@ -50,6 +50,17 @@ def __init__(self, placement_policy: str, chunk_manager: ChunkManager, memstats:
         self._warmup = True
         self._comp_cuda_demand_time = 0

+    def reset_attributes(self):
+        self._compute_idx = -1
+        self._h2d_volume = 0
+        self._d2h_volume = 0
+        self._layout_time = 0
+        self._evict_time = 0
+        self._comp_cuda_demand_time = 0
+
+    def is_warmup(self):
+        return self._warmup
+
     def memstats(self):
         """memstats

@@ -73,12 +84,7 @@ def post_iter(self):
         if self._mem_stats_collector and self._warmup:
             self._mem_stats_collector.finish_collection()
         self._warmup = False
-        self._compute_idx = -1
-        self._h2d_volume = 0
-        self._d2h_volume = 0
-        self._layout_time = 0
-        self._evict_time = 0
-        self._comp_cuda_demand_time = 0
+        self.reset_attributes()

     def adjust_layout(self, chunks: Tuple[Chunk, ...]) -> None:
         """ Adjust the layout of stateful tensors according to the information provided
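For context, a minimal caller-side sketch of how the two new helpers are intended to be used; the gemini_manager instance name below is illustrative and not part of this diff. post_iter() keeps the training-path cleanup by delegating to the shared reset_attributes(), while is_warmup() lets a grad-free forward be rejected until one full warmup iteration has collected memory statistics.

# Illustrative sketch only; gemini_manager is assumed to be a GeminiManager
# owned by a ZeroDDP module (see the data_parallel.py change below).

# Inference path: refuse a no-grad forward until warmup has finished,
# since the placement policy relies on the collected memory statistics.
if gemini_manager.is_warmup():
    raise RuntimeError("Run one complete training iteration before inference.")

# End of a training iteration: post_iter() still resets the per-iteration
# counters, now via the shared reset_attributes() helper.
gemini_manager.post_iter()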
23 changes: 23 additions & 0 deletions colossalai/nn/parallel/data_parallel.py
@@ -268,12 +268,35 @@ def __init__(self,

         self._logger = get_dist_logger()

+    def _post_forward(self):
+        """This function is only triggered for inference.
+        """
+        access_list = list(self.chunk_manager.accessed_chunks)
+        # we need to scatter all accessed chunks and move them to their original places
+        for chunk in access_list:
+            assert chunk.can_release
+            self.chunk_manager.release_chunk(chunk)
+            first_param = next(iter(chunk.tensors_info))
+            self.chunk_manager.move_chunk(chunk, self.grads_device[first_param])
+        assert self.chunk_manager.accessed_mem == 0
+        # reset all recorded attributes
+        self.gemini_manager.reset_attributes()
+
     def forward(self, *args, **kwargs):
+        # check whether we are in an inference mode
+        grad_flag = torch.is_grad_enabled()
+        if not grad_flag:
+            assert not self.gemini_manager.is_warmup(), "You should run a completed iteration as your warmup iter"
+
         args, kwargs = _cast_float(args, torch.half), _cast_float(kwargs, torch.half)
         self.module.zero_grad(set_to_none=True)
         self.gemini_manager.pre_iter(*args)
         with ColoParamOpHookManager.use_hooks(self.param_op_hook):
             outputs = self.module(*args, **kwargs)
+        # scatter chunks in inference mode
+        if not grad_flag:
+            self._post_forward()
+
         if self.force_outputs_fp32:
             return _cast_float(outputs, torch.float)
         return outputs
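A minimal usage sketch of the new inference path. Assumptions: model is a ZeroDDP-wrapped module that has already finished one complete training iteration (so the warmup pass is over), and input_ids is a batch on the current device; neither name is part of this diff.

import torch

model.eval()
with torch.no_grad():              # grad disabled, so forward() takes the inference branch
    outputs = model(input_ids)     # afterwards _post_forward() releases every accessed chunk
                                   # and GeminiManager.reset_attributes() clears the counters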
122 changes: 122 additions & 0 deletions tests/test_gemini/update/test_inference.py
@@ -0,0 +1,122 @@
from functools import partial

import pytest
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.testing import assert_close

import colossalai
from colossalai.amp import convert_to_apex_amp
from colossalai.gemini.chunk import ChunkManager, init_chunk_manager, search_chunk_configuration
from colossalai.gemini.gemini_mgr import GeminiManager
from colossalai.nn.optimizer import HybridAdam
from colossalai.nn.optimizer.zero_optimizer import ZeroOptimizer
from colossalai.nn.parallel import ZeroDDP
from colossalai.testing import parameterize, rerun_if_address_is_in_use
from colossalai.utils import free_port
from colossalai.utils.cuda import get_current_device
from colossalai.utils.model.colo_init_context import ColoInitContext, post_process_colo_init_ctx
from tests.components_to_test import run_fwd_bwd
from tests.components_to_test.registry import non_distributed_component_funcs
from tests.test_tensor.common_utils import debug_print, set_seed


def check_param(model: ZeroDDP, torch_model: torch.nn.Module):
    zero_dict = model.state_dict(only_rank_0=False)
    torch_dict = torch_model.state_dict()

    for key, value in torch_dict.items():
        # key is 'module.model.PARAMETER', so we truncate it
        key = key[7:]
        assert key in zero_dict, "{} not in ZeRO dictionary.".format(key)
        temp_zero_value = zero_dict[key].to(device=value.device, dtype=value.dtype)
        # debug_print([0], "max range: ", key, torch.max(torch.abs(value - temp_zero_value)))
        assert_close(value, temp_zero_value, rtol=1e-3, atol=4e-3)


@parameterize('placement_policy', ['cuda', 'cpu', 'auto', 'const'])
@parameterize('model_name', ['gpt2'])
def exam_inference(placement_policy, model_name: str):
    set_seed(19360226)
    get_components_func = non_distributed_component_funcs.get_callable(model_name)
    model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func()

    torch_model = model_builder().cuda()
    amp_config = dict(opt_level='O2', keep_batchnorm_fp32=False, loss_scale=128)
    torch_optim = torch.optim.Adam(torch_model.parameters(), lr=1e-3)
    torch_model, torch_optim = convert_to_apex_amp(torch_model, torch_optim, amp_config)
    torch_model = DDP(torch_model, device_ids=[dist.get_rank()])

    init_dev = get_current_device()
    with ColoInitContext(device=init_dev):
        model = model_builder()

    for torch_p, p in zip(torch_model.parameters(), model.parameters()):
        p.data.copy_(torch_p.data)

    world_size = torch.distributed.get_world_size()
    config_dict, _ = search_chunk_configuration(model, search_range_mb=1, search_interval_byte=100)
    config_dict[world_size]['chunk_size'] = 5000
    config_dict[world_size]['keep_gathered'] = False
    if placement_policy != 'cuda':
        init_device = torch.device('cpu')
    else:
        init_device = None
    chunk_manager = ChunkManager(config_dict, init_device=init_device)
    gemini_manager = GeminiManager(placement_policy, chunk_manager)
    model = ZeroDDP(model, gemini_manager, pin_memory=True)

    optimizer = HybridAdam(model.parameters(), lr=1e-3)
    zero_optim = ZeroOptimizer(optimizer, model, initial_scale=128)

    model.eval()
    torch_model.eval()

    set_seed(dist.get_rank() * 3 + 128)
    train_dataloader = iter(train_dataloader)

    def train_iter():
        input_ids, label = next(train_dataloader)
        input_ids, label = input_ids.cuda(), label.cuda()
        zero_optim.zero_grad()
        torch_optim.zero_grad()
        torch_loss = run_fwd_bwd(torch_model, input_ids, label, criterion, torch_optim)
        loss = run_fwd_bwd(model, input_ids, label, criterion, zero_optim)
        assert_close(torch_loss, loss)
        zero_optim.step()
        torch_optim.step()
        check_param(model, torch_model)

    def inference_iter():
        input_ids, label = next(train_dataloader)
        input_ids, label = input_ids.cuda(), label.cuda()
        with torch.no_grad():
            torch_output = torch_model(input_ids)
            torch_loss = criterion(torch_output.float(), label)
            zero_output = model(input_ids)
            zero_loss = criterion(zero_output.float(), label)
        assert_close(torch_loss, zero_loss)

    train_iter()
    inference_iter()
    train_iter()


def run_dist(rank, world_size, port):
    config = {}
    colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
    exam_inference()


@pytest.mark.dist
@pytest.mark.parametrize('world_size', [1, 4])
@rerun_if_address_is_in_use()
def test_inference(world_size):
    run_func = partial(run_dist, world_size=world_size, port=free_port())
    mp.spawn(run_func, nprocs=world_size)


if __name__ == '__main__':
    test_inference(1)