Merged
28 commits
48ddf31
Fix potential secondary tensor out of sync issue
HeyangQin Dec 19, 2023
1bc053f
remove all forward/backward flag and use existence of secondary tenso…
HeyangQin Jan 6, 2024
03e85c5
fix conflict with master
HeyangQin Jan 6, 2024
b99d25a
fix format
HeyangQin Jan 6, 2024
165a854
enable hpz for lora frozen weights
HeyangQin Jan 8, 2024
aad9403
Merge branch 'master' into HeyangQin/mixz_hpz_fix
HeyangQin Jan 8, 2024
a516257
change param partitioning logic
HeyangQin Jan 8, 2024
f86e880
Merge branch 'HeyangQin/mixz_hpz_fix' of https://github.com/microsoft…
HeyangQin Jan 8, 2024
71bbf1c
bring back profiler
HeyangQin Jan 8, 2024
9f25158
Merge branch 'master' into HeyangQin/mixz_hpz_fix
HeyangQin Jan 12, 2024
7d20df7
bring back forward flag for profiling
HeyangQin Jan 12, 2024
6bbb2ae
update unit test
HeyangQin Jan 14, 2024
ae50855
add convergence test
HeyangQin Jan 17, 2024
ecfd60f
tmp change to workflow for test
HeyangQin Jan 17, 2024
63512ef
relax time out for convergence test
HeyangQin Jan 24, 2024
786eed0
revert tmp changes
HeyangQin Jan 24, 2024
67d5e2e
add nightly flag
HeyangQin Jan 24, 2024
86a72a5
Merge branch 'master' into HeyangQin/mixz_hpz_fix
HeyangQin Jan 24, 2024
b88f001
Fix whitespace in nv-torch-latest-v100.yml
HeyangQin Jan 24, 2024
1485bfa
fix format
HeyangQin Jan 24, 2024
77d227e
fix incorrect format by clang
HeyangQin Jan 24, 2024
b418841
skip test if datasets is not installed
HeyangQin Jan 24, 2024
5e684e6
fix format
HeyangQin Jan 25, 2024
8857c07
Update nv-nightly.yml
mrwyattii Jan 25, 2024
d58631f
remove test skip
mrwyattii Jan 25, 2024
345f653
re-add import
mrwyattii Jan 25, 2024
d6c4b1a
Merge branch 'master' into HeyangQin/mixz_hpz_fix
mrwyattii Jan 25, 2024
e17de06
Update tests/unit/runtime/zero/test_zeropp.py
mrwyattii Jan 25, 2024
4 changes: 4 additions & 0 deletions .github/workflows/nv-nightly.yml
@@ -38,6 +38,10 @@ jobs:
git rev-parse --short HEAD
pip install .

- name: Install datasets
run: |
pip install datasets

- name: Install deepspeed
run: |
pip install .[dev,1bit,autotuning,inf]
13 changes: 6 additions & 7 deletions deepspeed/runtime/zero/parameter_offload.py
@@ -492,20 +492,19 @@ def _run_after_backward_function(sub_module):
# post backward hook
self.backward_hooks.append(module.register_forward_pre_hook(_post_backward_module_hook))

@torch.no_grad()
def pre_sub_module_forward_function(self, sub_module):
see_memory_usage(f"Before sub module function {sub_module.__class__.__name__}", force=False)
prev_grad_state = torch.is_grad_enabled(
) # we don't want to enable grad for sub modules fetching, yet the subfunction need to know if grad is enabled
torch.set_grad_enabled(False)

global FWD_MODULE_STACK
FWD_MODULE_STACK.append(sub_module)

param_coordinator = self.get_param_coordinator(training=sub_module.training)
param_coordinator.trace_prologue(sub_module)
if param_coordinator.is_record_trace():
param_coordinator.record_module(sub_module)
param_coordinator.fetch_sub_module(sub_module, forward=prev_grad_state)
torch.set_grad_enabled(prev_grad_state)
param_coordinator.fetch_sub_module(sub_module, forward=True)

see_memory_usage(f"Before sub module function {sub_module.__class__.__name__} after fetch", force=False)

@torch.no_grad()
@@ -514,7 +513,7 @@ def post_sub_module_forward_function(sub_module):
force=False)

param_coordinator = self.get_param_coordinator(training=sub_module.training)
param_coordinator.release_sub_module(sub_module, backward=False)
param_coordinator.release_sub_module(sub_module)

see_memory_usage(f"After sub module function {sub_module.__class__.__name__} {sub_module.id} after release",
force=False)
@@ -535,7 +534,7 @@ def post_sub_module_backward_function(sub_module):
f"After sub module backward function {sub_module.__class__.__name__} {sub_module.id} before release",
force=False)

self.get_param_coordinator(training=True).release_sub_module(sub_module, backward=True)
self.get_param_coordinator(training=True).release_sub_module(sub_module)

see_memory_usage(
f"After sub module backward function {sub_module.__class__.__name__} {sub_module.id} after release",
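The two hunks above drop the hand-rolled grad-state bookkeeping in pre_sub_module_forward_function and the backward flag on release_sub_module. fetch_sub_module keeps its forward argument only to pick the profiler event name (see the "bring back forward flag for profiling" commit), so the pre-forward hook can now pass forward=True unconditionally. A minimal sketch, not the DeepSpeed hook itself, of why the @torch.no_grad() decorator makes the removed prev_grad_state juggling unnecessary:

import torch

# Sketch only: torch.no_grad used as a decorator disables grad for the call and
# restores the caller's grad mode on return, which is exactly what the removed
# prev_grad_state / torch.set_grad_enabled(...) lines did by hand.
@torch.no_grad()
def fetch_params(sub_module):
    assert not torch.is_grad_enabled()  # grad is off inside the hook body
    return sub_module

with torch.enable_grad():
    fetch_params(torch.nn.Linear(2, 2))
    assert torch.is_grad_enabled()      # caller's grad mode is restored afterwards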
60 changes: 26 additions & 34 deletions deepspeed/runtime/zero/partition_parameters.py
@@ -652,15 +652,13 @@ def __init__(
partitions: List[Tensor],
world_size: int,
use_secondary_tensor=False,
forward=False,
quantization=None,
) -> None:
self.allgather_handle = allgather_handle
self.params = params
self.partitions = partitions
self.world_size = world_size
self.use_secondary_tensor = use_secondary_tensor
self.forward = forward
self.complete = False
self.quantization = quantization

@@ -691,7 +689,7 @@ def wait(self) -> None:
assert param.ds_status == ZeroParamStatus.INFLIGHT, f"expected param {param.ds_summary()} to be inflight"
partitions: List[Tensor] = []
ds_tensor_numel = param.ds_tensor.ds_numel
if self.use_secondary_tensor and not self.forward:
if self.use_secondary_tensor:
ds_tensor_numel *= param.ds_secondary_tensor_num_of_groups
for rank in range(self.world_size):
param_start = rank * ds_tensor_numel
@@ -946,7 +944,7 @@ def __init__(
self.num_ranks_in_param_group = groups._get_zero_param_intra_parallel_group_world_size()
self.num_param_groups = int(self.dp_world_size / self.num_ranks_in_param_group)
self.rank_in_group = groups._get_zero_param_intra_parallel_rank_in_mygroup()
print_rank_0(f"hpZeRO group size? {self.num_ranks_in_param_group}", force=True)
print_rank_0(f"hpZeRO group size: {self.num_ranks_in_param_group}", force=True)

logger.debug(
"hpZeRO partition parameter my rank in world {} my rank in group {} ranks in my param partition group: {} "
@@ -1115,10 +1113,10 @@ def all_gather(param_list=None, async_op=False, hierarchy=0):
param_list = [cls]
return self._all_gather(param_list, async_op=async_op, hierarchy=hierarchy)

def _all_gather_dtype(dtype, params, forward, world_size, rank_in_group, ds_process_group):
def _all_gather_dtype(dtype, params, world_size, rank_in_group, ds_process_group):
partition_sz = sum(p.ds_tensor.ds_numel for p in params)

use_secondary_tensor = params[0].ds_secondary_tensor is not None and not forward
use_secondary_tensor = params[0].ds_secondary_tensor is not None

if use_secondary_tensor:
partition_sz = sum(p.ds_tensor.ds_numel * p.ds_secondary_tensor_num_of_groups for p in params)
@@ -1148,12 +1146,10 @@ def _all_gather_dtype(dtype, params, forward, world_size, rank_in_group, ds_proc
partitions=partitions,
world_size=world_size,
use_secondary_tensor=use_secondary_tensor,
forward=forward,
)

@instrument_w_nvtx
def all_gather_coalesced(params: Iterable[Parameter],
forward: bool = True,
safe_mode: bool = False,
quantize: bool = False) -> AllGatherCoalescedHandle:

@@ -1172,8 +1168,8 @@ def all_gather_coalesced(params: Iterable[Parameter],
ds_process_group = self.ds_process_group
rank_in_group = self.rank
world_size = self.dp_world_size
use_secondary_tensor = params[0].ds_secondary_tensor is not None and not forward
if self.zero_param_process_group and not forward:
use_secondary_tensor = params[0].ds_secondary_tensor is not None
if self.zero_param_process_group and use_secondary_tensor:
ds_process_group = self.zero_param_process_group #intragroup
rank_in_group = self.rank_in_group
world_size = self.num_ranks_in_param_group
@@ -1253,8 +1249,7 @@ def all_gather_coalesced(params: Iterable[Parameter],
dtype_params[p.ds_tensor.dtype].append(p)
handles = []
for dtype, params in dtype_params.items():
handles.append(
_all_gather_dtype(dtype, params, forward, world_size, rank_in_group, ds_process_group))
handles.append(_all_gather_dtype(dtype, params, world_size, rank_in_group, ds_process_group))

return MultipleAllGatherHandles(handles)

@@ -1315,11 +1310,10 @@ def all_gather_coalesced(params: Iterable[Parameter],
partitions=None,
world_size=world_size,
use_secondary_tensor=use_secondary_tensor,
forward=forward,
quantization=quant_info,
)

def partition(param_list=None, backward=False, hierarchy=0, has_been_updated=False):
def partition(param_list=None, hierarchy=0, has_been_updated=False):
cls = param
print_rank_0(f"{'--'*hierarchy}----Partitioning param {debug_param2name_id_shape_device(cls)}",
force=False)
@@ -1473,7 +1467,7 @@ def _partition(self, param_list, force=False, has_been_updated=False):
for param in param_list:
print_rank_0(f"Before Partitioning Param {param.ds_id}", force=False)
if self.zero_param_process_group is not None:
self._partition_param_sec(param, has_been_updated=has_been_updated)
self._partition_param_sec(param)
self._partition_param(param, has_been_updated=has_been_updated)

param.ds_status = ZeroParamStatus.NOT_AVAILABLE
@@ -1608,30 +1602,28 @@ def _partition_param_sec(self, param, buffer=None, has_been_updated=False):
##support for NVME secondary param offload
#print_rank_0(f"SEC Param id {param.ds_id} status is {param.ds_status}", force=True)
if param.ds_status is ZeroParamStatus.AVAILABLE:
if param.ds_secondary_tensor is not None and not has_been_updated: ##param already partitioned

return
#check padding
tensor_size = self._aligned_size(param)
partition_size = tensor_size // self.dp_world_size

secondary_partition_size = int(tensor_size // self.num_ranks_in_param_group)
final_location = None
secondary_partitioned_tensor = torch.empty(secondary_partition_size,
dtype=param.dtype,
device=self.remote_device)

if self.pin_memory:
secondary_partitioned_tensor = secondary_partitioned_tensor.pin_memory()
# quantize the tensor if it's not trainable
if not param.requires_grad and self.quantized_nontrainable_weights:
secondary_partitioned_tensor, secondary_partitioned_tensor.ds_quant_scale = self.quantizer_module.quantize(
secondary_partitioned_tensor)
secondary_partitioned_tensor.requires_grad = False
param.ds_secondary_tensor = secondary_partitioned_tensor
param.ds_secondary_tensor.ds_numel = secondary_partition_size
param.ds_secondary_tensor.status = PartitionedParamStatus.AVAILABLE
param.ds_secondary_tensor.final_location = final_location
if param.ds_secondary_tensor is None:
final_location = None
secondary_partitioned_tensor = torch.empty(secondary_partition_size,
dtype=param.dtype,
device=self.remote_device)

if self.pin_memory:
secondary_partitioned_tensor = secondary_partitioned_tensor.pin_memory()
# quantize the tensor if it's not trainable
if not param.requires_grad and self.quantized_nontrainable_weights:
secondary_partitioned_tensor, secondary_partitioned_tensor.ds_quant_scale = self.quantizer_module.quantize(
secondary_partitioned_tensor)
secondary_partitioned_tensor.requires_grad = False
param.ds_secondary_tensor = secondary_partitioned_tensor
param.ds_secondary_tensor.ds_numel = secondary_partition_size
param.ds_secondary_tensor.status = PartitionedParamStatus.AVAILABLE
param.ds_secondary_tensor.final_location = final_location

#use rank in group for secondary tensor
secondary_start = secondary_partition_size * self.rank_in_group
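Across this file the change is uniform: the forward/backward flag is gone, and the hpZ (secondary-tensor) path is taken purely based on whether a parameter currently holds a ds_secondary_tensor; _partition_param_sec likewise allocates the secondary buffer only when none exists yet and otherwise refills the one already there. A minimal standalone sketch of the selection logic, using simplified stand-in fields modeled on the diff rather than the real Init class:

from dataclasses import dataclass
from typing import Optional

@dataclass
class ToyParam:
    # Stand-in for a ZeRO-3 partitioned parameter; only the fields the decision reads.
    ds_numel_per_rank: int                        # plays the role of p.ds_tensor.ds_numel
    ds_secondary_tensor: Optional[object] = None
    ds_secondary_tensor_num_of_groups: int = 1

def plan_allgather(params, dp_world_size, ranks_in_param_group):
    # Post-PR selection: use the intra-node hpZ group iff a secondary tensor exists,
    # regardless of whether we are in the forward or the backward pass.
    use_secondary = params[0].ds_secondary_tensor is not None
    if use_secondary:
        world_size = ranks_in_param_group
        partition_sz = sum(p.ds_numel_per_rank * p.ds_secondary_tensor_num_of_groups
                           for p in params)
    else:
        world_size = dp_world_size
        partition_sz = sum(p.ds_numel_per_rank for p in params)
    return world_size, partition_sz

# Frozen (e.g. LoRA base) weights keep their secondary tensor and stay on the hpZ path;
# trainable weights lose theirs after each optimizer step and fall back to the full group.
print(plan_allgather([ToyParam(1024, ds_secondary_tensor=object(),
                               ds_secondary_tensor_num_of_groups=4)], 64, 8))   # (8, 4096)
print(plan_allgather([ToyParam(1024)], 64, 8))                                  # (64, 1024)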
42 changes: 25 additions & 17 deletions deepspeed/runtime/zero/partitioned_param_coordinator.py
@@ -387,15 +387,15 @@ def _is_currently_on_nvme(param):

@instrument_w_nvtx
@torch.no_grad()
def release_sub_module(self, submodule: Module, backward: bool) -> None:
def release_sub_module(self, submodule: Module) -> None:
"""release the parameters of a sub module, assuming they meet conditions to
be released."""
params_to_release = (self.__params_to_release(submodule, self.__step_id) if self.is_complete_trace() else set(
p.ds_id for p in iter_params(submodule, recurse=z3_leaf_module(submodule))))
for param in iter_params(submodule, recurse=z3_leaf_module(submodule)):
param.ds_active_sub_modules.discard(submodule.id)
if param.ds_id in params_to_release and not param.is_external_param:
self.__release_param(param, backward)
self.__release_param(param)

@instrument_w_nvtx
@torch.no_grad()
@@ -408,7 +408,7 @@ def release_and_reset_all(self, module: Module) -> None:
# TODO. make this throw if if there are still active submodules. currently
# there's a hook execution issue
param.ds_active_sub_modules.clear()
self.__release_param(param, backward=False)
self.__release_param(param)

for param in iter_params(module, recurse=True):
if param.ds_status != ZeroParamStatus.NOT_AVAILABLE:
@@ -439,19 +439,27 @@ def __all_gather_params_(self, params: Set[Parameter], forward: bool, quantize:
all_gather_numel += param.ds_numel

if partitioned_params:
partitioned_params
self.__n_available_params += all_gather_numel
with get_accelerator().stream(self.__allgather_stream):
event_name = __class__.FORWARD_ALL_GATHER if forward else __class__.BACKWARD_ALL_GATHER
self.__profiler.start_event(event_name)
handle = partitioned_params[0].all_gather_coalesced(partitioned_params,
forward=forward,
quantize=quantize)
self.__profiler.stop_event(event_name, all_gather_numel)

for param in partitioned_params:
assert param.ds_status == ZeroParamStatus.INFLIGHT, param.ds_summary()
self.__inflight_param_registry[param] = handle
# here we need to handle a special case where some of the parameters have a valid hpz secondary tensor (e.g. they are not trainable so their secondary tensor never expire) but others do not.
partitioned_params_with_secondary_tensors = [
p for p in partitioned_params if p.ds_secondary_tensor is not None
]
partitioned_params_without_secondary_tensors = [
p for p in partitioned_params if p.ds_secondary_tensor is None
]
for param_group in [
partitioned_params_with_secondary_tensors, partitioned_params_without_secondary_tensors
]:
if not param_group:
continue
with get_accelerator().stream(self.__allgather_stream):
event_name = __class__.FORWARD_ALL_GATHER if forward else __class__.BACKWARD_ALL_GATHER
self.__profiler.start_event(event_name)
handle = param_group[0].all_gather_coalesced(param_group, quantize=quantize)
self.__profiler.stop_event(event_name, all_gather_numel)
for param in param_group:
assert param.ds_status == ZeroParamStatus.INFLIGHT, param.ds_summary()
self.__inflight_param_registry[param] = handle

# Release swap buffers for persisted params on nvme since they will never be partitioned or evicted from GPU
swap_persisted_params = [
@@ -461,11 +469,11 @@ def __all_gather_params_(self, params: Set[Parameter], forward: bool, quantize:
swap_persisted_params[0].nvme_swapper.remove_partition_and_release_buffers(swap_persisted_params)

@instrument_w_nvtx
def __release_param(self, param: Parameter, backward: bool) -> None:
def __release_param(self, param: Parameter) -> None:
if param.ds_status == ZeroParamStatus.AVAILABLE and not param.ds_active_sub_modules:
if logger.isEnabledFor(logging.DEBUG):
debug_rank0(f"-release: {param.ds_summary()}")
param.partition(backward=backward)
param.partition()
self.__n_available_params -= param.ds_numel

@instrument_w_nvtx
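The substantive fix in __all_gather_params_ is spelled out in the new comment: after this PR, frozen parameters can still hold a valid secondary tensor while trainable ones in the same submodule do not, and a single coalesced all-gather would force one process group onto both. Splitting the batch keeps each coalesced call homogeneous. A minimal sketch of that grouping, using a hypothetical helper name:

def split_by_secondary_tensor(partitioned_params):
    # Hypothetical helper, not part of the coordinator: batch params so that every
    # coalesced all-gather sees a uniform answer to "is ds_secondary_tensor set?".
    with_secondary = [p for p in partitioned_params if p.ds_secondary_tensor is not None]
    without_secondary = [p for p in partitioned_params if p.ds_secondary_tensor is None]
    return [group for group in (with_secondary, without_secondary) if group]

# Each non-empty group then launches its own coalesced all-gather, mirroring the loop
# in the diff above:
#     for group in split_by_secondary_tensor(partitioned_params):
#         handle = group[0].all_gather_coalesced(group, quantize=quantize)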
2 changes: 1 addition & 1 deletion deepspeed/runtime/zero/stage3.py
@@ -1927,7 +1927,7 @@ def _post_step(self, timer_names):
if self.swap_optimizer:
self.optimizer_swapper.log_timers()

# self.invalidate_secondary_tensor() # given that we want hpz in forward pass when no_grad is set, we need to keep the secondary tensor
self.invalidate_secondary_tensor()

self.timers.log(timer_names)

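Re-enabling invalidate_secondary_tensor() closes the loop: once the optimizer has updated the trainable parameters, their hpZ secondary copies are stale and must be refreshed, while frozen parameters are never touched by the optimizer and keep theirs, which is exactly the situation the coordinator comment above handles. A rough sketch of that idea, under the assumption that the call only walks the optimizer's own (trainable) parameter groups; this is not the actual stage-3 method:

def invalidate_secondary_tensors(trainable_param_groups):
    # Assumed behavior, for illustration only: drop the stale secondary copy of every
    # parameter the optimizer just updated so the next partition/fetch rebuilds it.
    for group in trainable_param_groups:
        for p in group:
            if p.ds_secondary_tensor is not None:
                p.ds_secondary_tensor = None

# Reusing the ToyParam stand-in from the earlier sketch:
frozen = ToyParam(1024, ds_secondary_tensor=object())   # not in any optimizer group
tuned = ToyParam(1024, ds_secondary_tensor=object())
invalidate_secondary_tensors([[tuned]])
assert tuned.ds_secondary_tensor is None and frozen.ds_secondary_tensor is not None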