18 commits
238ba1f
fix: forbid repeated deepspeed.initialize on training objects
traincheck-team Dec 16, 2024
d1e7777
fix: remove mark-time checking for non-existence of the flag as DeepS…
traincheck-team Dec 16, 2024
62067cc
handle callable types in init mark
traincheck-team Dec 19, 2024
2c5806b
change: do init checking and marking in one func
traincheck-team Dec 30, 2024
6a0b600
Merge branch 'master' into fix-6848-forbid-repeated-init
loadams Jan 2, 2025
7452786
Merge branch 'master' into fix-6848-forbid-repeated-init
loadams Jan 4, 2025
71d3e31
Merge branch 'master' into fix-6848-forbid-repeated-init
loadams Jan 13, 2025
80e9e16
Merge branch 'master' into fix-6848-forbid-repeated-init
tjruwase Jan 21, 2025
a9837f9
remove unnecessary prints
traincheck-team Jan 21, 2025
b1d4330
Merge branch 'master' into fix-6848-forbid-repeated-init
loadams Jan 21, 2025
1b15bea
add: split TestNoRepeatedInitializationAllowed test into two separate…
traincheck-team Jan 27, 2025
f84cca6
Merge branch 'master' into fix-6848-forbid-repeated-init
tjruwase Jan 28, 2025
13dbe56
Merge branch 'master' into fix-6848-forbid-repeated-init
loadams Jan 31, 2025
d2f315f
Merge branch 'master' into fix-6848-forbid-repeated-init
loadams Feb 7, 2025
ee20181
Merge branch 'master' into fix-6848-forbid-repeated-init
loadams Feb 14, 2025
15831ce
Merge branch 'master' into fix-6848-forbid-repeated-init
tjruwase Feb 23, 2025
5098754
Merge branch 'master' into fix-6848-forbid-repeated-init
tjruwase Mar 21, 2025
20e9203
Merge branch 'master' into fix-6848-forbid-repeated-init
loadams May 20, 2025
deepspeed/__init__.py (46 additions, 1 deletion)
@@ -6,7 +6,7 @@
import sys
import types
import json
from typing import Optional, Union
from typing import Callable, Optional, Union
import torch
from torch.optim import Optimizer
from torch.optim.lr_scheduler import _LRScheduler
@@ -27,6 +27,8 @@

from .accelerator import get_accelerator
from .constants import TORCH_DISTRIBUTED_DEFAULT_PORT
from .runtime.base_optimizer import DeepSpeedOptimizer
from .runtime.dataloader import DeepSpeedDataLoader, RepeatingLoader
from .runtime.engine import DeepSpeedEngine, DeepSpeedOptimizerCallable, DeepSpeedSchedulerCallable
from .runtime.engine import ADAM_OPTIMIZER, LAMB_OPTIMIZER
from .runtime.hybrid_engine import DeepSpeedHybridEngine
@@ -65,6 +67,44 @@ def _parse_version(version_str):
# Set to torch's distributed package or deepspeed.comm based inside DeepSpeedEngine init
dist = None

DS_PRIM_TYPES = (DeepSpeedEngine, DeepSpeedHybridEngine, DeepSpeedOptimizer, DeepSpeedDataLoader, RepeatingLoader)


def _mark_ds_initialized(trainobj: Union[torch.nn.Module, Optimizer, _LRScheduler]):
"""Mark a trainobj as initialized by setting the ds_is_inited attribute to True."""
if not isinstance(trainobj, DS_PRIM_TYPES): # only mark non-DeepSpeed objects
trainobj.ds_is_inited = True


def _is_ds_initialized(trainobj: Union[torch.nn.Module, Optimizer, _LRScheduler]):
"""Check if a trainobj has been initialized by checking the ds_is_inited attribute."""
if isinstance(trainobj, DS_PRIM_TYPES):
return True
else:
return getattr(trainobj, 'ds_is_inited', False)


def _ensure_and_mark_trainobjs_inited(
model: torch.nn.Module,
optimizer: Optional[Union[Optimizer, DeepSpeedOptimizerCallable]],
lr_scheduler: Optional[Union[_LRScheduler, DeepSpeedSchedulerCallable]],
ensures_not_inited: bool = False,
):
trainobjs = {"model": model, "optimizer": optimizer, "lr_scheduler": lr_scheduler}

for name, trainobj in trainobjs.items():
if trainobj is None:
continue
if name in ("optimizer", "lr_scheduler") and not isinstance(trainobj, (Optimizer, _LRScheduler)):
# skipping DeepSpeedOptimizerCallable and DeepSpeedSchedulerCallable
continue
if ensures_not_inited:
if _is_ds_initialized(trainobj):
raise ValueError(
f"{name} has already been initialized, please make sure to only call deepspeed.initialize on a {name} once."
)
_mark_ds_initialized(trainobj)


def initialize(args=None,
model: torch.nn.Module = None,
@@ -137,6 +177,8 @@ def initialize(args=None,
zero.partition_parameters.shutdown_init_context()

assert model is not None, "deepspeed.initialize requires a model"
# enforce that model, optimizer, and lr_scheduler have not been used in a previous deepspeed.initialize call
_ensure_and_mark_trainobjs_inited(model, optimizer, lr_scheduler, ensures_not_inited=True)

global dist
from deepspeed import comm as dist
@@ -221,6 +263,9 @@ def initialize(args=None,
# Restore zero.Init context if necessary
zero.partition_parameters.restore_init_context()

# mark engine, optimizer, and lr_scheduler as initialized
_ensure_and_mark_trainobjs_inited(engine, engine.optimizer, engine.lr_scheduler, ensures_not_inited=False)

return_items = [
engine,
engine.optimizer,
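For context, here is a minimal sketch of the user-facing behavior this change introduces. The config values are illustrative and a working DeepSpeed setup (a single process is enough) is assumed:

```python
import torch
import deepspeed

model = torch.nn.Linear(10, 10)
ds_config = {"train_batch_size": 1, "optimizer": {"type": "Adam"}}  # illustrative config

# The first call succeeds and marks the client model (its ds_is_inited attribute is set).
engine, optimizer, _, _ = deepspeed.initialize(model=model, config_params=ds_config)

# Passing the same model (or the returned engine/optimizer) into another
# deepspeed.initialize call now fails fast instead of silently re-wrapping it.
try:
    deepspeed.initialize(model=model, config_params=ds_config)
except ValueError as err:
    print(f"Repeated initialization rejected: {err}")
```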
tests/unit/runtime/test_ds_initialize.py (71 additions, 0 deletions)
@@ -21,6 +21,7 @@
from deepspeed.utils.torch import required_torch_version
from deepspeed.accelerator import get_accelerator
from deepspeed.ops.op_builder import FusedAdamBuilder
from deepspeed import _is_ds_initialized


@pytest.mark.parametrize('zero_stage', [0, 3])
@@ -434,3 +435,73 @@ def _lr_scheduler_callable(optimizer) -> _LRScheduler:
else:
# callable
assert isinstance(ds_lr_scheduler, OneCycleLR)


# https://github.com/microsoft/DeepSpeed/issues/6770
class TestNoRepeatedInitializationAllowed(DistributedTest):
world_size = 1

@pytest.mark.parametrize('optimizer_type', [None, Optimizer, Callable])
def test_objs_marked_ds_inited(self, optimizer_type):
hidden_dim = 10
model = SimpleModel(hidden_dim)

def _optimizer_callable(params) -> Optimizer:
return AdamW(params=params)

config_dict = {'train_batch_size': 1}
if optimizer_type is None:
client_optimizer = None
config_dict['optimizer'] = {'type': ADAM_OPTIMIZER}
elif optimizer_type is Optimizer:
client_optimizer = Adam(model.parameters())
else:
client_optimizer = _optimizer_callable

# Initialize DeepSpeed engine
model_engine, optim, _, _ = deepspeed.initialize(model=model,
optimizer=client_optimizer,
config_params=config_dict)

# arguments should be marked as initialized now
assert _is_ds_initialized(model), "Client model should be marked as initialized"
if optimizer_type is Optimizer:
assert _is_ds_initialized(client_optimizer), "Client optimizer should be marked as initialized"

# return values should also be marked as initialized
assert _is_ds_initialized(model_engine), "Model engine should be marked as initialized"
assert _is_ds_initialized(optim), "Optimizer should be marked as initialized"

@pytest.mark.parametrize('optimizer_type', [None, Optimizer, Callable])
def test_repeated_initialization_raises_error(self, optimizer_type):
hidden_dim = 10
model = SimpleModel(hidden_dim)

def _optimizer_callable(params) -> Optimizer:
return AdamW(params=params)

config_dict = {'train_batch_size': 1}
if optimizer_type is None:
client_optimizer = None
config_dict['optimizer'] = {'type': ADAM_OPTIMIZER}
elif optimizer_type is Optimizer:
client_optimizer = Adam(model.parameters())
else:
client_optimizer = _optimizer_callable

# Initialize DeepSpeed engine
model_engine, optim, _, _ = deepspeed.initialize(model=model,
optimizer=client_optimizer,
config_params=config_dict)
err_msg_pattern = "has already been initialized"
with pytest.raises(ValueError, match=err_msg_pattern):
deepspeed.initialize(model=model, optimizer=client_optimizer, config_params=config_dict)

with pytest.raises(ValueError, match=err_msg_pattern):
deepspeed.initialize(model=model_engine, optimizer=client_optimizer, config_params=config_dict)

with pytest.raises(ValueError, match=err_msg_pattern):
deepspeed.initialize(model=model, optimizer=optim, config_params=config_dict)

with pytest.raises(ValueError, match=err_msg_pattern):
deepspeed.initialize(model=model_engine, optimizer=optim, config_params=config_dict)
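Taken together, the tests pin down the intended usage: each model, optimizer, and LR scheduler may pass through deepspeed.initialize exactly once, and the training loop should reuse the returned engine rather than re-initializing. A sketch of that pattern, with placeholder model, data, and config:

```python
import torch
import deepspeed

model = torch.nn.Linear(10, 1)
ds_config = {"train_batch_size": 1, "optimizer": {"type": "Adam"}}  # illustrative config
data = [(torch.randn(1, 10), torch.randn(1, 1)) for _ in range(4)]

# Initialize once, before the loop; calling deepspeed.initialize again on
# `model` or on `engine` inside the loop would now raise ValueError.
engine, optimizer, _, _ = deepspeed.initialize(model=model, config_params=ds_config)

for epoch in range(2):
    for x, y in data:
        x, y = x.to(engine.device), y.to(engine.device)
        loss = torch.nn.functional.mse_loss(engine(x), y)
        engine.backward(loss)
        engine.step()
```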