diff --git a/.github/workflows/nv-nightly.yml b/.github/workflows/nv-nightly.yml
index a74a31cf9297..16100cafe87d 100644
--- a/.github/workflows/nv-nightly.yml
+++ b/.github/workflows/nv-nightly.yml
@@ -38,6 +38,10 @@ jobs:
           git rev-parse --short HEAD
           pip install .
 
+      - name: Install datasets
+        run: |
+          pip install datasets
+
       - name: Install deepspeed
         run: |
           pip install .[dev,1bit,autotuning,inf]
diff --git a/deepspeed/runtime/zero/parameter_offload.py b/deepspeed/runtime/zero/parameter_offload.py
index ab554297159b..88dc41867d1f 100644
--- a/deepspeed/runtime/zero/parameter_offload.py
+++ b/deepspeed/runtime/zero/parameter_offload.py
@@ -492,11 +492,10 @@ def _run_after_backward_function(sub_module):
         # post backward hook
         self.backward_hooks.append(module.register_forward_pre_hook(_post_backward_module_hook))
 
+    @torch.no_grad()
     def pre_sub_module_forward_function(self, sub_module):
         see_memory_usage(f"Before sub module function {sub_module.__class__.__name__}", force=False)
-        prev_grad_state = torch.is_grad_enabled(
-        )  # we don't want to enable grad for sub modules fetching, yet the subfunction need to know if grad is enabled
-        torch.set_grad_enabled(False)
+
         global FWD_MODULE_STACK
         FWD_MODULE_STACK.append(sub_module)
 
@@ -504,8 +503,8 @@ def pre_sub_module_forward_function(self, sub_module):
         param_coordinator.trace_prologue(sub_module)
         if param_coordinator.is_record_trace():
             param_coordinator.record_module(sub_module)
-        param_coordinator.fetch_sub_module(sub_module, forward=prev_grad_state)
-        torch.set_grad_enabled(prev_grad_state)
+        param_coordinator.fetch_sub_module(sub_module, forward=True)
+
         see_memory_usage(f"Before sub module function {sub_module.__class__.__name__} after fetch", force=False)
 
     @torch.no_grad()
@@ -514,7 +513,7 @@ def post_sub_module_forward_function(self, sub_module):
                          force=False)
 
         param_coordinator = self.get_param_coordinator(training=sub_module.training)
-        param_coordinator.release_sub_module(sub_module, backward=False)
+        param_coordinator.release_sub_module(sub_module)
 
         see_memory_usage(f"After sub module function {sub_module.__class__.__name__} {sub_module.id} after release",
                          force=False)
@@ -535,7 +534,7 @@ def post_sub_module_backward_function(self, sub_module):
             f"After sub module backward function {sub_module.__class__.__name__} {sub_module.id} before release",
             force=False)
 
-        self.get_param_coordinator(training=True).release_sub_module(sub_module, backward=True)
+        self.get_param_coordinator(training=True).release_sub_module(sub_module)
 
         see_memory_usage(
             f"After sub module backward function {sub_module.__class__.__name__} {sub_module.id} after release",
diff --git a/deepspeed/runtime/zero/partition_parameters.py b/deepspeed/runtime/zero/partition_parameters.py
index 030a050b88e2..940519d7db85 100755
--- a/deepspeed/runtime/zero/partition_parameters.py
+++ b/deepspeed/runtime/zero/partition_parameters.py
@@ -652,7 +652,6 @@ def __init__(
         partitions: List[Tensor],
         world_size: int,
         use_secondary_tensor=False,
-        forward=False,
         quantization=None,
     ) -> None:
         self.allgather_handle = allgather_handle
@@ -660,7 +659,6 @@ def __init__(
         self.partitions = partitions
         self.world_size = world_size
         self.use_secondary_tensor = use_secondary_tensor
-        self.forward = forward
         self.complete = False
         self.quantization = quantization
 
@@ -691,7 +689,7 @@ def wait(self) -> None:
             assert param.ds_status == ZeroParamStatus.INFLIGHT, f"expected param {param.ds_summary()} to be inflight"
             partitions: List[Tensor] = []
             ds_tensor_numel = param.ds_tensor.ds_numel
-            if self.use_secondary_tensor and not self.forward:
+            if self.use_secondary_tensor:
                 ds_tensor_numel *= param.ds_secondary_tensor_num_of_groups
             for rank in range(self.world_size):
                 param_start = rank * ds_tensor_numel
@@ -946,7 +944,7 @@ def __init__(
         self.num_ranks_in_param_group = groups._get_zero_param_intra_parallel_group_world_size()
         self.num_param_groups = int(self.dp_world_size / self.num_ranks_in_param_group)
         self.rank_in_group = groups._get_zero_param_intra_parallel_rank_in_mygroup()
-        print_rank_0(f"hpZeRO group size? {self.num_ranks_in_param_group}", force=True)
+        print_rank_0(f"hpZeRO group size: {self.num_ranks_in_param_group}", force=True)
 
         logger.debug(
             "hpZeRO partition parameter my rank in world {} my rank in group {} ranks in my param partition group: {} "
@@ -1115,10 +1113,10 @@ def all_gather(param_list=None, async_op=False, hierarchy=0):
                 param_list = [cls]
             return self._all_gather(param_list, async_op=async_op, hierarchy=hierarchy)
 
-        def _all_gather_dtype(dtype, params, forward, world_size, rank_in_group, ds_process_group):
+        def _all_gather_dtype(dtype, params, world_size, rank_in_group, ds_process_group):
             partition_sz = sum(p.ds_tensor.ds_numel for p in params)
 
-            use_secondary_tensor = params[0].ds_secondary_tensor is not None and not forward
+            use_secondary_tensor = params[0].ds_secondary_tensor is not None
             if use_secondary_tensor:
                 partition_sz = sum(p.ds_tensor.ds_numel * p.ds_secondary_tensor_num_of_groups for p in params)
 
@@ -1148,12 +1146,10 @@ def _all_gather_dtype(dtype, params, forward, world_size, rank_in_group, ds_proc
                 partitions=partitions,
                 world_size=world_size,
                 use_secondary_tensor=use_secondary_tensor,
-                forward=forward,
             )
 
         @instrument_w_nvtx
         def all_gather_coalesced(params: Iterable[Parameter],
-                                 forward: bool = True,
                                  safe_mode: bool = False,
                                  quantize: bool = False) -> AllGatherCoalescedHandle:
 
@@ -1172,8 +1168,8 @@ def all_gather_coalesced(params: Iterable[Parameter],
            ds_process_group = self.ds_process_group
            rank_in_group = self.rank
            world_size = self.dp_world_size
-           use_secondary_tensor = params[0].ds_secondary_tensor is not None and not forward
-           if self.zero_param_process_group and not forward:
+           use_secondary_tensor = params[0].ds_secondary_tensor is not None
+           if self.zero_param_process_group and use_secondary_tensor:
                ds_process_group = self.zero_param_process_group #intragroup
                rank_in_group = self.rank_in_group
                world_size = self.num_ranks_in_param_group
@@ -1253,8 +1249,7 @@ def all_gather_coalesced(params: Iterable[Parameter],
                     dtype_params[p.ds_tensor.dtype].append(p)
                 handles = []
                 for dtype, params in dtype_params.items():
-                    handles.append(
-                        _all_gather_dtype(dtype, params, forward, world_size, rank_in_group, ds_process_group))
+                    handles.append(_all_gather_dtype(dtype, params, world_size, rank_in_group, ds_process_group))
 
                 return MultipleAllGatherHandles(handles)
 
@@ -1315,11 +1310,10 @@ def all_gather_coalesced(params: Iterable[Parameter],
                     partitions=None,
                     world_size=world_size,
                     use_secondary_tensor=use_secondary_tensor,
-                    forward=forward,
                     quantization=quant_info,
                 )
 
-        def partition(param_list=None, backward=False, hierarchy=0, has_been_updated=False):
+        def partition(param_list=None, hierarchy=0, has_been_updated=False):
             cls = param
             print_rank_0(f"{'--'*hierarchy}----Partitioning param {debug_param2name_id_shape_device(cls)}",
                          force=False)
@@ -1473,7 +1467,7 @@ def _partition(self, param_list, force=False, has_been_updated=False):
         for param in param_list:
             print_rank_0(f"Before Partitioning Param {param.ds_id}", force=False)
             if self.zero_param_process_group is not None:
-                self._partition_param_sec(param, has_been_updated=has_been_updated)
+                self._partition_param_sec(param)
             self._partition_param(param, has_been_updated=has_been_updated)
 
             param.ds_status = ZeroParamStatus.NOT_AVAILABLE
@@ -1608,30 +1602,28 @@ def _partition_param_sec(self, param, buffer=None, has_been_updated=False):
         ##support for NVME secondary param offload
         #print_rank_0(f"SEC Param id {param.ds_id} status is {param.ds_status}", force=True)
         if param.ds_status is ZeroParamStatus.AVAILABLE:
-            if param.ds_secondary_tensor is not None and not has_been_updated:  ##param already partitioned
-
-                return
             #check padding
             tensor_size = self._aligned_size(param)
             partition_size = tensor_size // self.dp_world_size
             secondary_partition_size = int(tensor_size // self.num_ranks_in_param_group)
 
-            final_location = None
-            secondary_partitioned_tensor = torch.empty(secondary_partition_size,
-                                                       dtype=param.dtype,
-                                                       device=self.remote_device)
-
-            if self.pin_memory:
-                secondary_partitioned_tensor = secondary_partitioned_tensor.pin_memory()
-            # quantize the tensor if it's not trainable
-            if not param.requires_grad and self.quantized_nontrainable_weights:
-                secondary_partitioned_tensor, secondary_partitioned_tensor.ds_quant_scale = self.quantizer_module.quantize(
-                    secondary_partitioned_tensor)
-            secondary_partitioned_tensor.requires_grad = False
-            param.ds_secondary_tensor = secondary_partitioned_tensor
-            param.ds_secondary_tensor.ds_numel = secondary_partition_size
-            param.ds_secondary_tensor.status = PartitionedParamStatus.AVAILABLE
-            param.ds_secondary_tensor.final_location = final_location
+            if param.ds_secondary_tensor is None:
+                final_location = None
+                secondary_partitioned_tensor = torch.empty(secondary_partition_size,
+                                                           dtype=param.dtype,
+                                                           device=self.remote_device)
+
+                if self.pin_memory:
+                    secondary_partitioned_tensor = secondary_partitioned_tensor.pin_memory()
+                # quantize the tensor if it's not trainable
+                if not param.requires_grad and self.quantized_nontrainable_weights:
+                    secondary_partitioned_tensor, secondary_partitioned_tensor.ds_quant_scale = self.quantizer_module.quantize(
+                        secondary_partitioned_tensor)
+                secondary_partitioned_tensor.requires_grad = False
+                param.ds_secondary_tensor = secondary_partitioned_tensor
+                param.ds_secondary_tensor.ds_numel = secondary_partition_size
+                param.ds_secondary_tensor.status = PartitionedParamStatus.AVAILABLE
+                param.ds_secondary_tensor.final_location = final_location
 
             #use rank in group for secondary tensor
             secondary_start = secondary_partition_size * self.rank_in_group
diff --git a/deepspeed/runtime/zero/partitioned_param_coordinator.py b/deepspeed/runtime/zero/partitioned_param_coordinator.py
index 299138e84712..bcbf544cd07e 100644
--- a/deepspeed/runtime/zero/partitioned_param_coordinator.py
+++ b/deepspeed/runtime/zero/partitioned_param_coordinator.py
@@ -387,7 +387,7 @@ def _is_currently_on_nvme(param):
 
     @instrument_w_nvtx
     @torch.no_grad()
-    def release_sub_module(self, submodule: Module, backward: bool) -> None:
+    def release_sub_module(self, submodule: Module) -> None:
         """release the parameters of a sub module, assuming they meet conditions to be released."""
         params_to_release = (self.__params_to_release(submodule, self.__step_id) if self.is_complete_trace() else set(
             p.ds_id for p in iter_params(submodule, recurse=z3_leaf_module(submodule))))
@@ -395,7 +395,7 @@ def release_sub_module(self, submodule: Module, backward: bool) -> None:
         for param in iter_params(submodule, recurse=z3_leaf_module(submodule)):
             param.ds_active_sub_modules.discard(submodule.id)
             if param.ds_id in params_to_release and not param.is_external_param:
-                self.__release_param(param, backward)
+                self.__release_param(param)
 
     @instrument_w_nvtx
     @torch.no_grad()
@@ -408,7 +408,7 @@ def release_and_reset_all(self, module: Module) -> None:
             # TODO. make this throw if if there are still active submodules. currently
             # there's a hook execution issue
             param.ds_active_sub_modules.clear()
-            self.__release_param(param, backward=False)
+            self.__release_param(param)
 
         for param in iter_params(module, recurse=True):
             if param.ds_status != ZeroParamStatus.NOT_AVAILABLE:
@@ -439,19 +439,27 @@ def __all_gather_params_(self, params: Set[Parameter], forward: bool, quantize:
                 all_gather_numel += param.ds_numel
 
         if partitioned_params:
-            partitioned_params
             self.__n_available_params += all_gather_numel
-            with get_accelerator().stream(self.__allgather_stream):
-                event_name = __class__.FORWARD_ALL_GATHER if forward else __class__.BACKWARD_ALL_GATHER
-                self.__profiler.start_event(event_name)
-                handle = partitioned_params[0].all_gather_coalesced(partitioned_params,
-                                                                    forward=forward,
-                                                                    quantize=quantize)
-                self.__profiler.stop_event(event_name, all_gather_numel)
-
-            for param in partitioned_params:
-                assert param.ds_status == ZeroParamStatus.INFLIGHT, param.ds_summary()
-                self.__inflight_param_registry[param] = handle
+            # here we need to handle a special case where some of the parameters have a valid hpz secondary tensor (e.g. they are not trainable so their secondary tensor never expire) but others do not.
+            partitioned_params_with_secondary_tensors = [
+                p for p in partitioned_params if p.ds_secondary_tensor is not None
+            ]
+            partitioned_params_without_secondary_tensors = [
+                p for p in partitioned_params if p.ds_secondary_tensor is None
+            ]
+            for param_group in [
+                    partitioned_params_with_secondary_tensors, partitioned_params_without_secondary_tensors
+            ]:
+                if not param_group:
+                    continue
+                with get_accelerator().stream(self.__allgather_stream):
+                    event_name = __class__.FORWARD_ALL_GATHER if forward else __class__.BACKWARD_ALL_GATHER
+                    self.__profiler.start_event(event_name)
+                    handle = param_group[0].all_gather_coalesced(param_group, quantize=quantize)
+                    self.__profiler.stop_event(event_name, all_gather_numel)
+                for param in param_group:
+                    assert param.ds_status == ZeroParamStatus.INFLIGHT, param.ds_summary()
+                    self.__inflight_param_registry[param] = handle
 
         # Release swap buffers for persisted params on nvme since they will never be partitioned or evicted from GPU
         swap_persisted_params = [
@@ -461,11 +469,11 @@ def __all_gather_params_(self, params: Set[Parameter], forward: bool, quantize:
             swap_persisted_params[0].nvme_swapper.remove_partition_and_release_buffers(swap_persisted_params)
 
     @instrument_w_nvtx
-    def __release_param(self, param: Parameter, backward: bool) -> None:
+    def __release_param(self, param: Parameter) -> None:
         if param.ds_status == ZeroParamStatus.AVAILABLE and not param.ds_active_sub_modules:
             if logger.isEnabledFor(logging.DEBUG):
                 debug_rank0(f"-release: {param.ds_summary()}")
-            param.partition(backward=backward)
+            param.partition()
             self.__n_available_params -= param.ds_numel
 
     @instrument_w_nvtx
diff --git a/deepspeed/runtime/zero/stage3.py b/deepspeed/runtime/zero/stage3.py
index 56a031fb00d4..77fcbdadb1d1 100644
--- a/deepspeed/runtime/zero/stage3.py
+++ b/deepspeed/runtime/zero/stage3.py
@@ -1927,7 +1927,7 @@ def _post_step(self, timer_names):
         if self.swap_optimizer:
             self.optimizer_swapper.log_timers()
 
-        # self.invalidate_secondary_tensor() # given that we want hpz in forward pass when no_grad is set, we need to keep the secondary tensor
+        self.invalidate_secondary_tensor()
 
         self.timers.log(timer_names)
 
diff --git a/tests/unit/runtime/zero/test_zeropp.py b/tests/unit/runtime/zero/test_zeropp.py
index 27ec7269afc6..545ed98ad2ef 100644
--- a/tests/unit/runtime/zero/test_zeropp.py
+++ b/tests/unit/runtime/zero/test_zeropp.py
@@ -14,6 +14,12 @@
 from deepspeed.runtime.zero.config import DeepSpeedZeroConfig
 
 import torch.nn as nn
+import torch
+
+from transformers import AutoModelForCausalLM, AutoTokenizer
+from torch.utils.data import DataLoader
+
+import numpy as np
 
 
 class NNModel(nn.Module):
@@ -40,9 +46,16 @@ def _assert_no_secondary_tensor_group(model: Module) -> None:
         assert param.ds_zero_param_process_group is None
 
 
-def _assert_secondary_tensor_size(model: Module) -> None:
+def _check_secondary_tensor_existence(model: Module) -> None:
     for _, param in model.named_parameters():
-        assert param.ds_secondary_tensor is not None
+        if param.ds_secondary_tensor is not None:
+            return True
+    return False
+
+
+def _assert_secondary_tensor_size(model: Module) -> None:
+    for name, param in model.named_parameters():
+        assert param.ds_secondary_tensor is not None, f"param {param.ds_id}:{name} does not have secondary tensor"
         assert param.ds_secondary_tensor.size()[0] % param.ds_tensor.size()[0] == 0
 
 
@@ -50,7 +63,7 @@ def _assert_secondary_tensor_size(model: Module) -> None:
 #Assert when zpg=1 that secondary group and tensors are invalid
 @pytest.mark.sequential
 @pytest.mark.parametrize("h_dim", [1024])
-@pytest.mark.parametrize("n_layers", [4, 9])
+@pytest.mark.parametrize("n_layers", [9])
 @pytest.mark.parametrize("zpg", [1, 2, 4])
 class TestZeroPPConfigSweep(DistributedTest):
     world_size = 4
@@ -92,3 +105,172 @@ def test(self, h_dim: int, n_layers: int, zpg: int) -> None:
             loss = model(batch[0], batch[1])
             model.backward(loss)
             model.step()
+
+    def test_eval(self, h_dim: int, n_layers: int, zpg: int) -> None:
+        # in this test case, we are testing that hpz should be enabled when eval mode is on
+        config_dict = {
+            "train_micro_batch_size_per_gpu": 1,
+            "zero_optimization": {
+                "stage": 3,
+                "stage3_max_reuse_distance": 0,
+                "zero_hpz_partition_size": zpg,
+                "contiguous_gradients": True,
+                "overlap_comm": True,
+            },
+            "optimizer": {
+                "type": "Adam",
+                "params": {
+                    "lr": 1.
+                }
+            },
+            "fp16": {
+                "enabled": True,
+                "loss_scale": 1.,
+            }
+        }
+
+        model = NNModel(h_dim, n_layers)
+        model, _, _, _ = deepspeed.initialize(model=model, model_parameters=model.parameters(), config=config_dict)
+        data_loader = random_dataloader(model=model, total_samples=20, hidden_dim=h_dim, device=model.device)
+        dist.barrier()
+        if zpg == 1:
+            _assert_no_secondary_tensor_group(model)
+
+        for n, batch in enumerate(data_loader):
+            if zpg != 1:
+                # here we check that the hpz is enabled when the previous iteration does not update the model
+                _assert_secondary_tensor_size(model)
+            with torch.no_grad():
+                loss = model(batch[0], batch[1])
+
+    def test_gradient_accumulation(self, h_dim: int, n_layers: int, zpg: int) -> None:
+        # in this test case, we are testing that hpz should be enabled for the intermediate gradient accumulation steps
+        # In this test, we should disable loss_scale
+        config_dict = {
+            "train_micro_batch_size_per_gpu": 1,
+            "gradient_accumulation_steps": 3,
+            "zero_optimization": {
+                "stage": 3,
+                "stage3_max_reuse_distance": 0,
+                "zero_hpz_partition_size": zpg,
+                "contiguous_gradients": True,
+                "overlap_comm": True,
+            },
+            "optimizer": {
+                "type": "Adam",
+                "params": {
+                    "lr": 1.
+                }
+            },
+            "fp16": {
+                "enabled": True,
+                "loss_scale": 0.,
+            }
+        }
+
+        model = NNModel(h_dim, n_layers)
+        model, _, _, _ = deepspeed.initialize(model=model, model_parameters=model.parameters(), config=config_dict)
+        data_loader = random_dataloader(model=model, total_samples=20, hidden_dim=h_dim, device=model.device)
+        dist.barrier()
+        if zpg == 1:
+            _assert_no_secondary_tensor_group(model)
+
+        for n, batch in enumerate(data_loader):
+            if n == 0 and zpg != 1:
+                _assert_secondary_tensor_size(model)
+            # here we cannot assert that secondary tensor does not exist because the gradient is likely overflowed as we use random data
+            if n > 0 and n % 3 != 0 and zpg != 1:
+                # if the previous iteration does not update the model, then the hpz should be enabled
+                assert _check_secondary_tensor_existence(model), f"n={n}"
+            loss = model(batch[0], batch[1])
+            model.backward(loss)
+            model.step()
+
+
+@pytest.mark.nightly
+@pytest.mark.parametrize("model_name", ["gpt2"])
+class TestZeroPPConvergence(DistributedTest):
+    world_size = 4
+
+    def load_and_prepare_data(self, model_name):
+        """Load model, tokenizer and dataset, and prepare data loader."""
+        from datasets import load_dataset
+
+        # Load model and tokenizer
+        model = AutoModelForCausalLM.from_pretrained(model_name)
+        tokenizer = AutoTokenizer.from_pretrained(model_name)
+        tokenizer.pad_token = tokenizer.eos_token
+
+        # Load and tokenize dataset
+        dataset = load_dataset("wikitext", 'wikitext-103-raw-v1', split='train[:1%]')
+
+        def tokenize_function(examples):
+            # Tokenize and ensure 'labels' are the same as 'input_ids'
+            tokenized_output = tokenizer(examples["text"], padding="max_length", truncation=True, return_tensors='pt')
+            tokenized_output["labels"] = tokenized_output["input_ids"].clone()
+            return tokenized_output
+
+        tokenized_dataset = dataset.map(tokenize_function, batched=True)
+        tokenized_dataset.set_format('torch', columns=['input_ids', 'attention_mask', 'labels'])
+
+        # Create data loader
+        data_loader = DataLoader(tokenized_dataset, batch_size=1, shuffle=False)
+        return model, data_loader
+
+    def get_loss(self, model, data_loader, config_dict, step=500):
+        """Train the model and calculate average loss."""
+        # Initialize DeepSpeed
+        model, _, _, _ = deepspeed.initialize(model=model, model_parameters=model.parameters(), config=config_dict)
+        dist.barrier()
+        model.train()
+
+        # Training loop
+        losses = []
+        for n, batch in enumerate(data_loader):
+            if n >= step:
+                break
+            batch = {k: v.to(model.device) for k, v in batch.items()}
+            outputs = model(**batch)
+            loss = outputs.loss
+            model.backward(loss)
+            model.step()
+            losses.append(loss.item())
+
+        return np.nanmean(losses[-100:])
+
+    def get_config_dict(self, use_quantized_weights=False, use_hpz=False):
+        """Generate the configuration dictionary for DeepSpeed."""
+        config = {
+            "train_micro_batch_size_per_gpu": 1,
+            "zero_optimization": {
+                "stage": 3,
+                "stage3_max_reuse_distance": 0,
+                "contiguous_gradients": True,
+                "overlap_comm": True,
+            },
+            "optimizer": {
+                "type": "Adam",
+                "params": {
+                    "lr": 1e-5
+                }
+            },
+            "fp16": {
+                "enabled": True
+            }
+        }
+        if use_quantized_weights:
+            config["zero_optimization"]["zero_quantized_weights"] = True
+        if use_hpz:
+            config["zero_optimization"]["zero_hpz_partition_size"] = self.world_size // 2
+        return config
+
+    def test(self, model_name):
+        torch.manual_seed(0)
+        model, data_loader = self.load_and_prepare_data(model_name)
+        zeropp_loss = self.get_loss(model, data_loader, self.get_config_dict(use_quantized_weights=True, use_hpz=True))
+        model, data_loader = self.load_and_prepare_data(model_name)
+        baseline_loss = self.get_loss(model, data_loader, self.get_config_dict())
+
+        # Output and assert
+        print(f"zeropp_loss={zeropp_loss}, baseline_loss={baseline_loss}")
+        assert zeropp_loss < baseline_loss * 1.1, f"zeropp_loss={zeropp_loss}, baseline_loss={baseline_loss}"