3 changes: 3 additions & 0 deletions deepspeed/datastates/README.md
@@ -0,0 +1,3 @@
# DataStates-LLM checkpointing engine.

This feature is not enabled by default. To enable it, set the following options in `ds_config.json` and install the [DataStates-LLM checkpointing library](https://github.com/DataStates/datastates-llm/). A detailed tutorial is available [here](../../docs/_tutorials/datastates-async-checkpointing.md).
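
A minimal `ds_config.json` entry looks like the following (values are illustrative; each field is described in the checkpoint-engine README and the tutorial):

```json
{
    "datastates_ckpt": {
        "host_cache_size": 16,
        "parser_threads": 8
    }
}
```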
6 changes: 6 additions & 0 deletions deepspeed/datastates/__init__.py
@@ -0,0 +1,6 @@
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# Apache-2.0 License Copyright (c) UChicago Argonne LLC, operator of Argonne National Laboratory.

# DeepSpeed Team
21 changes: 21 additions & 0 deletions deepspeed/datastates/config.py
@@ -0,0 +1,21 @@
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# Apache-2.0 License Copyright (c) UChicago Argonne LLC, operator of Argonne National Laboratory.

# DeepSpeed Team

from deepspeed.runtime.config_utils import DeepSpeedConfigObject


class DeepSpeedDataStatesConfig(DeepSpeedConfigObject):

    def __init__(self, param_dict):
        super(DeepSpeedDataStatesConfig, self).__init__()

        self.enabled = None
        self.config = {}

        # Enabled only when a "datastates_ckpt" section is present in the DeepSpeed config;
        # its contents (e.g. host_cache_size, parser_threads) are passed through to the engine.
        if "datastates_ckpt" in param_dict.keys():
            self.enabled = True
            self.config = param_dict["datastates_ckpt"]
14 changes: 14 additions & 0 deletions deepspeed/runtime/checkpoint_engine/README.md
@@ -35,3 +35,17 @@ class CheckpointEngine(object):
        pass

```


### Asynchronous Lazy Checkpointing using DataStates-LLM

DataStates-LLM is an asynchronous checkpointing approach optimized for LLM pre-training; it can be obtained from https://github.com/DataStates/datastates-llm, and a detailed tutorial is available [here](../../../docs/_tutorials/datastates-async-checkpointing.md). To enable DataStates-LLM checkpointing, add the following lines to the config.json supplied at launch, where `host_cache_size` (in gigabytes) reserves pinned host memory for asynchronous checkpoint flushing and `parser_threads` sets how many checkpoint file requests are parsed in parallel:
```json
{
... other deepspeed config options,
"datastates_ckpt": {
"host_cache_size": 16,
"parser_threads": 8
}
}
```
4 changes: 4 additions & 0 deletions deepspeed/runtime/checkpoint_engine/checkpoint_engine.py
@@ -28,3 +28,7 @@ def load(self, path: str, map_location=None):
    def commit(self, tag):
        # to tell checkpoint services if all files are ready.
        pass

    def wait(self):
        # Wait in asynchronous checkpoint engines (e.g. DataStates-LLM) for the previous
        # snapshot to finish flushing; synchronous engines can leave this as a no-op.
        pass
deepspeed/runtime/checkpoint_engine/datastates_checkpoint_engine.py
@@ -0,0 +1,34 @@
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# Apache-2.0 License Copyright (c) UChicago Argonne LLC, operator of Argonne National Laboratory.

# DeepSpeed Team

from deepspeed.utils import log_dist
from deepspeed.runtime.checkpoint_engine.checkpoint_engine import \
    CheckpointEngine
from datastates.llm import Checkpointing


class DataStatesCheckpointEngine(CheckpointEngine):

    def __init__(self, deepspeed_config, rank):
        super().__init__(deepspeed_config)
        self.ckpt_engine = Checkpointing(deepspeed_config, rank)

    def create(self, tag):
        log_dist(f"[DataStates] Checkpoint {tag} is about to be saved!", ranks=[0])
        return None

    def save(self, state_dict, path: str):
        return self.ckpt_engine.save(state_dict, path)

    def load(self, path: str, map_location=None):
        return self.ckpt_engine.load(path, map_location)

    def commit(self, tag):
        return self.ckpt_engine.commit(tag)

    def wait(self):
        return self.ckpt_engine.wait()
deepspeed/runtime/checkpoint_engine/nebula_checkpoint_engine.py
@@ -11,6 +11,7 @@
    CheckpointEngine
from deepspeed.utils import logger, log_dist
from deepspeed.nebula.constants import *
from deepspeed.checkpoint.utils import clone_tensors_for_torch_save


def _get_tag_from_path(path):
@@ -48,7 +49,8 @@ def save(self, state_dict, path: str):
        tag = _get_tag_from_path(path)
        partition_name = os.path.basename(path)
        logger.info(f"[Nebula] Saving {partition_name} under tag {tag}...")
        self.checkpoint.save(partition_name, state_dict)
        debloated_state_dict = clone_tensors_for_torch_save(state_dict)
        self.checkpoint.save(partition_name, debloated_state_dict)
        logger.info(f"[Nebula] Saved {partition_name} under tag {tag}.")
        return None

deepspeed/runtime/checkpoint_engine/torch_checkpoint_engine.py
@@ -7,6 +7,7 @@
from deepspeed.utils import logger, log_dist
from deepspeed.runtime.checkpoint_engine.checkpoint_engine import \
    CheckpointEngine
from deepspeed.checkpoint.utils import clone_tensors_for_torch_save


class TorchCheckpointEngine(CheckpointEngine):
@@ -19,7 +20,8 @@ def create(self, tag):

    def save(self, state_dict, path: str):
        logger.info(f"[Torch] Saving {path}...")
        torch.save(state_dict, path)
        debloated_state_dict = clone_tensors_for_torch_save(state_dict)
        torch.save(debloated_state_dict, path)
        logger.info(f"[Torch] Saved {path}.")
        return None

2 changes: 2 additions & 0 deletions deepspeed/runtime/config.py
@@ -57,6 +57,7 @@
from ..profiling.config import DeepSpeedFlopsProfilerConfig
from ..autotuning.config import DeepSpeedAutotuningConfig
from ..nebula.config import DeepSpeedNebulaConfig
from ..datastates.config import DeepSpeedDataStatesConfig

from ..compression.config import get_compression_config, get_quantize_enabled
from ..compression.constants import *
@@ -908,6 +909,7 @@ def _initialize_params(self, param_dict):
        self.dataloader_drop_last = get_dataloader_drop_last(param_dict)

        self.nebula_config = DeepSpeedNebulaConfig(param_dict)
        self.datastates_config = DeepSpeedDataStatesConfig(param_dict)

        self.weight_quantization_config = WeightQuantConfig(
            **param_dict['weight_quantization']) if 'weight_quantization' in param_dict else None
16 changes: 16 additions & 0 deletions deepspeed/runtime/engine.py
@@ -1066,6 +1066,16 @@ def _configure_checkpointing(self, dist_init_required):
                logger.error(f"No torch_nebula was found! Will fall back to torch.save. Details: {err}")
                self.checkpoint_engine = TorchCheckpointEngine()

        if self._config is not None and self._config.datastates_config.enabled:
            try:
                from deepspeed.runtime.checkpoint_engine.datastates_checkpoint_engine import DataStatesCheckpointEngine
                self.checkpoint_engine = DataStatesCheckpointEngine(deepspeed_config=self._config,
                                                                    rank=dist.get_rank())
            except ImportError as err:
                raise Exception(
                    f"The datastates-llm checkpoint engine was not found! Please install it from https://github.com/DataStates/datastates-llm. Details: {err}")

        dp_rank = groups._get_sequence_data_parallel_rank()

        rank = self.local_rank if self.use_node_local_storage() else dp_rank
@@ -2254,6 +2264,12 @@ def _take_model_step(self, lr_kwargs, block_eigenvalue={}):
                # https://nvidia.github.io/apex/advanced.html#gradient-clipping
                master_params = amp.master_params(self.optimizer)
                clip_grad_norm_(parameters=master_params, max_norm=self.gradient_clipping(), mpu=self.mpu)

        # Asynchronous checkpoint engines must finish flushing the previous snapshot before
        # optimizer.step() mutates the parameters and optimizer states being captured.
        try:
            self.checkpoint_engine.wait()
        except Exception as exc:
            logger.error(f"Error while waiting for the checkpoint engine to finish flushing: {exc}")

        self.optimizer.step()

        if hasattr(self.optimizer, '_global_grad_norm'):
5 changes: 2 additions & 3 deletions deepspeed/runtime/pipe/module.py
@@ -20,7 +20,6 @@
from .topology import PipeDataParallelTopology, PipelineParallelGrid
from deepspeed.runtime.state_dict_factory import SDLoaderFactory
from deepspeed.accelerator import get_accelerator
from deepspeed.checkpoint.utils import clone_tensors_for_torch_save


class PipelineError(Exception):
@@ -628,8 +627,8 @@ def save_state_dict(self, save_dir, checkpoint_engine, exclude_frozen_params=False):
            if exclude_frozen_params:
                for n in self._get_frozen_parameter_names(layer):
                    del orig_state_dict[n]
            final_state_dict = clone_tensors_for_torch_save(orig_state_dict)
            checkpoint_engine.save(final_state_dict, model_ckpt_path)

            checkpoint_engine.save(orig_state_dict, model_ckpt_path)

    def load_state_dir(self, load_dir, checkpoint_engine, strict=True):
        for idx, layer in enumerate(self.forward_funcs):
67 changes: 67 additions & 0 deletions docs/_tutorials/datastates-async-checkpointing.md
@@ -0,0 +1,67 @@
---
title: "DataStates-LLM Checkpointing Engine"
tags: asynchronous checkpointing for minimizing I/O overheads.
---
This tutorial shows how to use [DataStates-LLM](https://github.com/DataStates/datastates-llm) for asynchronous checkpointing. DataStates-LLM introduces a lazy asynchronous checkpointing mechanism tailored to LLMs that minimizes I/O overhead and improves training efficiency; the sections below walk through integrating it with the DeepSpeed framework.

## Overview of DataStates-LLM

DataStates-LLM is designed to address the challenges of frequent checkpointing in LLM training by introducing a lazy asynchronous multi-level approach. It leverages the immutability of model parameters and optimizer states during forward and backward passes to perform non-blocking data transfers, thereby reducing interference with the training process. This method has demonstrated up to 48x faster checkpointing and 2.2x faster end-to-end training times compared to traditional approaches, as reported in [DataStates-LLM: Lazy Asynchronous Checkpointing for Large Language Models](https://arxiv.org/abs/2406.10707).
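
To make the mechanism concrete, here is a conceptual sketch of the pattern, not the DataStates implementation or its API: device tensors are copied into pinned host buffers while they are still immutable (between the backward pass and `optimizer.step()`), the disk flush runs on a background thread, and `wait()` blocks only if the previous flush has not finished.

```python
import threading
import torch


class AsyncSnapshotSketch:
    """Illustrative only: snapshot to pinned host memory, flush to disk in the background."""

    def __init__(self):
        self._flush_thread = None

    def save(self, state_dict, path):
        host_copy = {}
        for name, value in state_dict.items():
            if torch.is_tensor(value) and value.is_cuda:
                # A pinned host buffer allows an asynchronous device-to-host copy.
                pinned = torch.empty(value.shape, dtype=value.dtype, device="cpu", pin_memory=True)
                pinned.copy_(value, non_blocking=True)
                host_copy[name] = pinned
            else:
                host_copy[name] = value
        torch.cuda.synchronize()  # ensure all device-to-host copies have landed
        # The disk write overlaps with the next forward/backward pass.
        self._flush_thread = threading.Thread(target=torch.save, args=(host_copy, path))
        self._flush_thread.start()

    def wait(self):
        # Must complete before optimizer.step() mutates the tensors being captured.
        if self._flush_thread is not None:
            self._flush_thread.join()
            self._flush_thread = None
```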

## Prerequisites

Before integrating DataStates-LLM with DeepSpeed, ensure the following:

- **DeepSpeed Installation**: DeepSpeed should be installed in your environment. If not, refer to the [DeepSpeed Getting Started Guide](https://github.com/microsoft/DeepSpeed/blob/master/docs/_tutorials/getting-started.md) for installation instructions.

- **DataStates-LLM Repository**: Access the DataStates-LLM source code from its [GitHub repository](https://github.com/DataStates/datastates-llm) and follow the installation instructions provided therein.

## Configuring DeepSpeed for DataStates-LLM

To enable DataStates-LLM's asynchronous checkpointing within DeepSpeed, add a `datastates_ckpt` section to the `deepspeed_config.json` file. Below is an example configuration:

```json
{
// ... other DeepSpeed configuration options
"datastates_ckpt": {
"host_cache_size": 16,
"parser_threads": 8
}
}
```

### Configuration Parameters

- **`host_cache_size`**: Specifies the amount of pinned host memory (in gigabytes) reserved for asynchronous checkpoint flushing. Adjust this value based on your system's memory capacity and the size of your model checkpoints.

- **`parser_threads`**: Determines the number of threads dedicated to parsing checkpoint file requests in parallel. Increasing this value can enhance parsing throughput but may also increase CPU utilization.

## Implementing DataStates-LLM in Your Training Script

After enabling DataStates checkpointing in `deepspeed_config.json`, the checkpointing frequency can be configured by specifying the number of iterations between checkpoints with the command-line parameter `--save-interval`.
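
As a minimal sketch of what this looks like in a training script (the model, data loader, and `--save-interval` handling are placeholders, not code from this change), once `datastates_ckpt` is present in the config the usual DeepSpeed checkpoint calls transparently use the asynchronous engine:

```python
import deepspeed

# model and train_loader are assumed to be defined elsewhere; args.save_interval
# comes from the command line (e.g. --save-interval 100).
model_engine, optimizer, _, _ = deepspeed.initialize(model=model,
                                                      model_parameters=model.parameters(),
                                                      config="deepspeed_config.json")

for step, batch in enumerate(train_loader):
    loss = model_engine(batch)   # assumes the model returns its loss
    model_engine.backward(loss)
    model_engine.step()          # waits for the previous async snapshot before optimizer.step()

    if step > 0 and step % args.save_interval == 0:
        # Returns quickly; the snapshot is flushed to storage in the background.
        model_engine.save_checkpoint(save_dir="checkpoints", tag=f"global_step{step}")
```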

## Performance Results

The checkpoint acceleration achieved by DataStates-LLM for various models is shown in the figures below.

![Higher checkpointing throughput](/assets/images/datastates-async-checkpointing/diff-models-ckpt-throughput.png){: .align-center}

![Faster training iterations](/assets/images/datastates-async-checkpointing/diff-models-iter-times.png){: .align-center}


## Limitations and Ongoing Work

1. DataStates-LLM currently supports only the CUDA runtime on NVIDIA GPUs.

2. DataStates-LLM has only been tested with ZeRO stage 1, without offloading to other memory tiers.

3. While the checkpoint layout of DataStates matches Hugging Face's [safetensors](https://huggingface.co/docs/safetensors/) format, it is not yet fully compatible with the safetensors library because of the pickled objects DeepSpeed requires during restart.

4. DataStates-LLM does not yet support universal or elastic checkpointing.


## Questions and Support

Please use the [DataStates-LLM GitHub repository](https://github.com/DataStates/datastates-llm) for questions, issues, or feature requests.