Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/lightning/fabric/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
### Fixed

- Fixed an issue causing a wrong environment plugin to be selected when `accelerator=tpu` and `devices > 1` ([#16806](https://github.com/Lightning-AI/lightning/pull/16806))
- Fixed parsing of defaults for `--accelerator` and `--precision` in Fabric CLI when `accelerator` and `precision` are set to non-default values in the code ([#16818](https://github.com/Lightning-AI/lightning/pull/16818))


## [1.9.2] - 2023-02-15
Expand Down
10 changes: 6 additions & 4 deletions src/lightning/fabric/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ def _get_supported_strategies() -> List[str]:
@click.option(
"--accelerator",
type=click.Choice(_SUPPORTED_ACCELERATORS),
default="cpu",
default=None,
help="The hardware accelerator to run on.",
)
@click.option(
Expand Down Expand Up @@ -108,7 +108,7 @@ def _get_supported_strategies() -> List[str]:
@click.option(
"--precision",
type=click.Choice(get_args(_PRECISION_INPUT_STR) + get_args(_PRECISION_INPUT_STR_ALIAS)),
default="32-true",
default=None,
help=(
"Double precision (``64-true`` or ``64``), full precision (``32-true`` or ``64``), "
"half precision (``16-mixed`` or ``16``) or bfloat16 precision (``bf16-mixed`` or ``bf16``)"
Expand All @@ -133,12 +133,14 @@ def _set_env_variables(args: Namespace) -> None:
The Fabric connector will parse the arguments set here.
"""
os.environ["LT_CLI_USED"] = "1"
os.environ["LT_ACCELERATOR"] = str(args.accelerator)
if args.accelerator is not None:
os.environ["LT_ACCELERATOR"] = str(args.accelerator)
if args.strategy is not None:
os.environ["LT_STRATEGY"] = str(args.strategy)
os.environ["LT_DEVICES"] = str(args.devices)
os.environ["LT_NUM_NODES"] = str(args.num_nodes)
os.environ["LT_PRECISION"] = str(args.precision)
if args.precision is not None:
os.environ["LT_PRECISION"] = str(args.precision)


def _get_num_processes(accelerator: str, devices: str) -> int:
Expand Down
4 changes: 2 additions & 2 deletions tests/tests_fabric/test_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,11 +39,11 @@ def test_cli_env_vars_defaults(monkeypatch, fake_script):
_run_model.main([fake_script])
assert e.value.code == 0
assert os.environ["LT_CLI_USED"] == "1"
assert os.environ["LT_ACCELERATOR"] == "cpu"
assert "LT_ACCELERATOR" not in os.environ
assert "LT_STRATEGY" not in os.environ
assert os.environ["LT_DEVICES"] == "1"
assert os.environ["LT_NUM_NODES"] == "1"
assert os.environ["LT_PRECISION"] == "32-true"
assert "LT_PRECISION" not in os.environ


@pytest.mark.parametrize("accelerator", ["cpu", "gpu", "cuda", pytest.param("mps", marks=RunIf(mps=True))])
Expand Down
12 changes: 9 additions & 3 deletions tests/tests_fabric/test_connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -813,18 +813,22 @@ def test_strategy_str_passed_being_case_insensitive(_, strategy, strategy_cls):
assert isinstance(connector.strategy, strategy_cls)


@pytest.mark.parametrize("precision", ["64-true", "32-true", "16-mixed", "bf16-mixed"])
@pytest.mark.parametrize("precision", [None, "64-true", "32-true", "16-mixed", "bf16-mixed"])
@mock.patch("lightning.fabric.accelerators.cuda.num_cuda_devices", return_value=1)
def test_precision_from_environment(_, precision):
"""Test that the precision input can be set through the environment variable."""
with mock.patch.dict(os.environ, {"LT_PRECISION": precision}):
env_vars = {}
if precision is not None:
env_vars["LT_PRECISION"] = precision
with mock.patch.dict(os.environ, env_vars):
connector = _Connector(accelerator="cuda") # need to use cuda, because AMP not available on CPU
assert isinstance(connector.precision, Precision)


@pytest.mark.parametrize(
"accelerator, strategy, expected_accelerator, expected_strategy",
[
(None, None, CPUAccelerator, SingleDeviceStrategy),
("cpu", None, CPUAccelerator, SingleDeviceStrategy),
("cpu", "ddp", CPUAccelerator, DDPStrategy),
pytest.param("mps", None, MPSAccelerator, SingleDeviceStrategy, marks=RunIf(mps=True)),
Expand All @@ -836,7 +840,9 @@ def test_precision_from_environment(_, precision):
)
def test_accelerator_strategy_from_environment(accelerator, strategy, expected_accelerator, expected_strategy):
"""Test that the accelerator and strategy input can be set through the environment variables."""
env_vars = {"LT_ACCELERATOR": accelerator}
env_vars = {}
if accelerator is not None:
env_vars["LT_ACCELERATOR"] = accelerator
if strategy is not None:
env_vars["LT_STRATEGY"] = strategy

Expand Down