Skip to content

Commit 4f8394e

Browse files
authored
Rename LightningLite (1/n) (#15932)
1 parent 0d2d8dc commit 4f8394e

File tree

24 files changed

+102
-108
lines changed

24 files changed

+102
-108
lines changed

docs/source-pytorch/starter/lightning_lite.rst

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ Lightning Lite
77
on any kind of device while retaining full control over their own loops and optimization logic.
88

99
.. image:: https://pl-public-data.s3.amazonaws.com/docs/static/images/lite/lightning_lite.gif
10-
:alt: Animation showing how to convert your PyTorch code to LightningLite.
10+
:alt: Animation showing how to convert your PyTorch code to Fabric.
1111
:width: 500
1212
:align: center
1313

@@ -72,8 +72,8 @@ The ``train`` function contains a standard training loop used to train ``MyModel
7272
----------
7373

7474

75-
Convert to LightningLite
76-
========================
75+
Convert to Fabric
76+
=================
7777

7878
Here are five easy steps to let :class:`~pytorch_lightning.lite.LightningLite` scale your PyTorch models.
7979

@@ -89,7 +89,7 @@ Here are five easy steps to let :class:`~pytorch_lightning.lite.LightningLite` s
8989
import torch
9090
from torch import nn
9191
from torch.utils.data import DataLoader, Dataset
92-
from lightning.lite import LightningLite
92+
from lightning.lite import Fabric
9393
9494
9595
class MyModel(nn.Module):
@@ -102,7 +102,7 @@ Here are five easy steps to let :class:`~pytorch_lightning.lite.LightningLite` s
102102
103103
def train(args):
104104
105-
lite = LightningLite()
105+
lite = Fabric()
106106
107107
model = MyModel(...)
108108
optimizer = torch.optim.SGD(model.parameters(), ...)
@@ -124,7 +124,7 @@ Here are five easy steps to let :class:`~pytorch_lightning.lite.LightningLite` s
124124
125125
126126
That's all you need to do to your code. You can now train on any kind of device and scale your training.
127-
Check out `this <https://github.com/Lightning-AI/lightning/blob/master/examples/lite/image_classifier_2_lite.py>`_ full MNIST training example with LightningLite.
127+
Check out `this <https://github.com/Lightning-AI/lightning/blob/master/examples/lite/image_classifier_2_lite.py>`_ full MNIST training example with Fabric.
128128

129129
Here is how to train on eight GPUs with `torch.bfloat16 <https://pytorch.org/docs/1.10.0/generated/torch.Tensor.bfloat16.html>`_ precision:
130130

@@ -149,7 +149,7 @@ You can also easily use distributed collectives if required.
149149

150150
.. code-block:: python
151151
152-
lite = LightningLite()
152+
lite = Fabric()
153153
154154
# Transfer and concatenate tensors across processes
155155
lite.all_gather(...)

examples/app_multi_node/train_lite.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
import lightning as L
44
from lightning.app.components import LiteMultiNode
5-
from lightning.lite import LightningLite
5+
from lightning.lite import Fabric
66

77

88
class LitePyTorchDistributed(L.LightningWork):
@@ -14,8 +14,8 @@ def run(self):
1414
torch.nn.Linear(1, 1),
1515
)
1616

17-
# 2. Create LightningLite.
18-
lite = LightningLite(strategy="ddp", precision=16)
17+
# 2. Create Fabric.
18+
lite = Fabric(strategy="ddp", precision=16)
1919
model, optimizer = lite.setup(model, torch.optim.SGD(model.parameters(), lr=0.01))
2020
criterion = torch.nn.MSELoss()
2121

examples/lite/image_classifier_2_lite.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -12,11 +12,11 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15-
"""Here are 4 easy steps to use LightningLite in your PyTorch code.
15+
"""Here are 4 easy steps to use Fabric in your PyTorch code.
1616
1717
1. Create the Lightning Lite object at the beginning of your script.
1818
19-
2. Remove all ``.to`` and ``.cuda`` calls since LightningLite will take care of it.
19+
2. Remove all ``.to`` and ``.cuda`` calls since Fabric will take care of it.
2020
2121
3. Apply ``setup`` over each model and optimizers pair, ``setup_dataloaders`` on all your dataloaders,
2222
and replace ``loss.backward()`` with ``self.backward(loss)``.
@@ -40,7 +40,7 @@
4040
from torchmetrics.classification import Accuracy
4141
from torchvision.datasets import MNIST
4242

43-
from lightning.lite import LightningLite # import LightningLite
43+
from lightning.lite import Fabric # import Fabric
4444
from lightning.lite import seed_everything
4545

4646
DATASETS_PATH = path.join(path.dirname(__file__), "..", "..", "Datasets")
@@ -49,7 +49,7 @@
4949
def run(hparams):
5050
# Create the Lightning Lite object. The parameters like accelerator, strategy, devices etc. will be proided
5151
# by the command line. See all options: `lightning run model --help`
52-
lite = LightningLite()
52+
lite = Fabric()
5353

5454
lite.hparams = hparams
5555
seed_everything(hparams.seed) # instead of torch.manual_seed(...)
@@ -148,7 +148,7 @@ def run(hparams):
148148
# Arguments can be passed in through the CLI as normal and will be parsed here
149149
# Example:
150150
# lightning run model image_classifier.py accelerator=cuda --epochs=3
151-
parser = argparse.ArgumentParser(description="LightningLite MNIST Example")
151+
parser = argparse.ArgumentParser(description="Fabric MNIST Example")
152152
parser.add_argument(
153153
"--batch-size", type=int, default=64, metavar="N", help="input batch size for training (default: 64)"
154154
)

src/lightning/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ def _detail(self: Any, message: str, *args: Any, **kwargs: Any) -> None:
3636
from lightning.app.perf import pdb # noqa: E402
3737
from lightning.app.utilities.packaging.build_config import BuildConfig # noqa: E402
3838
from lightning.app.utilities.packaging.cloud_compute import CloudCompute # noqa: E402
39-
from lightning.lite.lite import LightningLite # noqa: E402
39+
from lightning.lite.lite import Fabric # noqa: E402
4040
from lightning.pytorch.callbacks import Callback # noqa: E402
4141
from lightning.pytorch.core import LightningDataModule, LightningModule # noqa: E402
4242
from lightning.pytorch.trainer import Trainer # noqa: E402
@@ -60,7 +60,7 @@ def _detail(self: Any, message: str, *args: Any, **kwargs: Any) -> None:
6060
"LightningModule",
6161
"Callback",
6262
"seed_everything",
63-
"LightningLite",
63+
"Fabric",
6464
"storage",
6565
"pdb",
6666
]

src/lightning_app/utilities/introspection.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -261,7 +261,7 @@ class TorchMetricVisitor(LightningVisitor):
261261

262262

263263
class LightningLiteVisitor(LightningVisitor):
264-
class_name = "LightningLite"
264+
class_name = "Fabric"
265265

266266

267267
class LightningBaseProfilerVisitor(LightningVisitor):

src/lightning_lite/CHANGELOG.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
2828
- The `LightningLite.run()` method is no longer abstract ([#14992](https://github.com/Lightning-AI/lightning/issues/14992))
2929

3030

31-
-
31+
- Renamed the class `LightningLite` to `Fabric` ([#15932](https://github.com/Lightning-AI/lightning/issues/15932))
3232

3333

3434
### Deprecated

src/lightning_lite/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,10 +20,10 @@
2020
os.environ["PYTORCH_NVML_BASED_CUDA_CHECK"] = "1"
2121

2222

23-
from lightning_lite.lite import LightningLite # noqa: E402
23+
from lightning_lite.lite import Fabric # noqa: E402
2424
from lightning_lite.utilities.seed import seed_everything # noqa: E402
2525

26-
__all__ = ["LightningLite", "seed_everything"]
26+
__all__ = ["Fabric", "seed_everything"]
2727

2828
# for compatibility with namespace packages
2929
__import__("pkg_resources").declare_namespace(__name__)

src/lightning_lite/cli.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -106,9 +106,9 @@
106106
)
107107
@click.argument("script_args", nargs=-1, type=click.UNPROCESSED)
108108
def _run_model(**kwargs: Any) -> None:
109-
"""Run a Lightning Lite script.
109+
"""Run a Lightning Fabric script.
110110
111-
SCRIPT is the path to the Python script with the code to run. The script must contain a LightningLite object.
111+
SCRIPT is the path to the Python script with the code to run. The script must contain a Fabric object.
112112
113113
SCRIPT_ARGS are the remaining arguments that you can pass to the script itself and are expected to be parsed
114114
there.
@@ -120,7 +120,7 @@ def _run_model(**kwargs: Any) -> None:
120120
def _set_env_variables(args: Namespace) -> None:
121121
"""Set the environment variables for the new processes.
122122
123-
The Lite connector will parse the arguments set here.
123+
The Fabric connector will parse the arguments set here.
124124
"""
125125
os.environ["LT_CLI_USED"] = "1"
126126
os.environ["LT_ACCELERATOR"] = str(args.accelerator)
@@ -187,7 +187,7 @@ def main(args: Namespace, script_args: Optional[List[str]] = None) -> None:
187187
if __name__ == "__main__":
188188
if not _CLICK_AVAILABLE: # pragma: no cover
189189
_log.error(
190-
"To use the Lightning Lite CLI, you must have `click` installed."
190+
"To use the Lightning Fabric CLI, you must have `click` installed."
191191
" Install it by running `pip install -U click`."
192192
)
193193
raise SystemExit(1)

src/lightning_lite/connector.py

Lines changed: 14 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,7 @@
6565

6666

6767
class _Connector:
68-
"""The Connector parses several Lite arguments and instantiates the Strategy including its owned components.
68+
"""The Connector parses several Fabric arguments and instantiates the Strategy including its owned components.
6969
7070
A. accelerator flag could be:
7171
1. accelerator class
@@ -297,7 +297,7 @@ def _check_device_config_and_set_final_flags(
297297
else self._accelerator_flag
298298
)
299299
raise ValueError(
300-
f"`Lite(devices={self._devices_flag!r})` value is not a valid input"
300+
f"`Fabric(devices={self._devices_flag!r})` value is not a valid input"
301301
f" using {accelerator_name} accelerator."
302302
)
303303

@@ -345,7 +345,7 @@ def _set_parallel_devices_and_init_accelerator(self) -> None:
345345
f"`{accelerator_cls.__qualname__}` can not run on your system"
346346
" since the accelerator is not available. The following accelerator(s)"
347347
" is available and can be passed into `accelerator` argument of"
348-
f" `Lite`: {available_accelerator}."
348+
f" `Fabric`: {available_accelerator}."
349349
)
350350

351351
self._set_devices_flag_if_auto_passed()
@@ -416,14 +416,14 @@ def _check_strategy_and_fallback(self) -> None:
416416
strategy_flag = "ddp"
417417
if strategy_flag in _DDP_FORK_ALIASES and "fork" not in torch.multiprocessing.get_all_start_methods():
418418
raise ValueError(
419-
f"You selected `Lite(strategy='{strategy_flag}')` but process forking is not supported on this"
420-
f" platform. We recommed `Lite(strategy='ddp_spawn')` instead."
419+
f"You selected `Fabric(strategy='{strategy_flag}')` but process forking is not supported on this"
420+
f" platform. We recommed `Fabric(strategy='ddp_spawn')` instead."
421421
)
422422
if (
423423
strategy_flag in _FSDP_ALIASES or isinstance(self._strategy_flag, FSDPStrategy)
424424
) and self._accelerator_flag not in ("cuda", "gpu"):
425425
raise ValueError(
426-
"You selected the FSDP strategy but FSDP is only available on GPU. Set `Lite(accelerator='gpu', ...)`"
426+
"You selected the FSDP strategy but FSDP is only available on GPU. Set `Fabric(accelerator='gpu', ...)`"
427427
" to continue or select a different strategy."
428428
)
429429
if strategy_flag:
@@ -449,7 +449,7 @@ def _check_and_init_precision(self) -> Precision:
449449
elif self._precision_input in (16, "bf16"):
450450
if self._precision_input == 16:
451451
rank_zero_warn(
452-
"You passed `Lite(accelerator='tpu', precision=16)` but AMP"
452+
"You passed `Fabric(accelerator='tpu', precision=16)` but AMP"
453453
" is not supported with TPUs. Using `precision='bf16'` instead."
454454
)
455455
return TPUBf16Precision()
@@ -463,7 +463,7 @@ def _check_and_init_precision(self) -> Precision:
463463

464464
if self._precision_input == 16 and self._accelerator_flag == "cpu":
465465
rank_zero_warn(
466-
"You passed `Lite(accelerator='cpu', precision=16)` but native AMP is not supported on CPU."
466+
"You passed `Fabric(accelerator='cpu', precision=16)` but native AMP is not supported on CPU."
467467
" Using `precision='bf16'` instead."
468468
)
469469
self._precision_input = "bf16"
@@ -487,7 +487,7 @@ def _validate_precision_choice(self) -> None:
487487
if isinstance(self.accelerator, TPUAccelerator):
488488
if self._precision_input == 64:
489489
raise NotImplementedError(
490-
"`Lite(accelerator='tpu', precision=64)` is not implemented."
490+
"`Fabric(accelerator='tpu', precision=64)` is not implemented."
491491
" Please, open an issue in `https://github.com/Lightning-AI/lightning/issues`"
492492
" requesting this feature."
493493
)
@@ -519,10 +519,10 @@ def _lazy_init_strategy(self) -> None:
519519

520520
if _IS_INTERACTIVE and self.strategy.launcher and not self.strategy.launcher.is_interactive_compatible:
521521
raise RuntimeError(
522-
f"`Lite(strategy={self._strategy_flag!r})` is not compatible with an interactive"
522+
f"`Fabric(strategy={self._strategy_flag!r})` is not compatible with an interactive"
523523
" environment. Run your code as a script, or choose one of the compatible strategies:"
524-
f" Lite(strategy=None|{'|'.join(_StrategyType.interactive_compatible_types())})."
525-
" In case you are spawning processes yourself, make sure to include the Lite"
524+
f" Fabric(strategy=None|{'|'.join(_StrategyType.interactive_compatible_types())})."
525+
" In case you are spawning processes yourself, make sure to include the Fabric"
526526
" creation inside the worker function."
527527
)
528528

@@ -549,9 +549,9 @@ def _argument_from_env(name: str, current: Any, default: Any) -> Any:
549549

550550
if env_value is not None and env_value != current and current != default:
551551
raise ValueError(
552-
f"Your code has `LightningLite({name}={current!r}, ...)` but it conflicts with the value "
552+
f"Your code has `Fabric({name}={current!r}, ...)` but it conflicts with the value "
553553
f"`--{name}={current}` set through the CLI. "
554-
" Remove it either from the CLI or from the Lightning Lite object."
554+
" Remove it either from the CLI or from the Lightning Fabric object."
555555
)
556556
if env_value is None:
557557
return current

src/lightning_lite/lite.py

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -54,8 +54,8 @@
5454
from lightning_lite.wrappers import _LiteDataLoader, _LiteModule, _LiteOptimizer
5555

5656

57-
class LightningLite:
58-
"""Lite accelerates your PyTorch training or inference code with minimal changes required.
57+
class Fabric:
58+
"""Fabric accelerates your PyTorch training or inference code with minimal changes required.
5959
6060
- Automatic placement of models and data onto the device.
6161
- Automatic support for mixed and double precision (smaller memory footprint).
@@ -139,7 +139,7 @@ def is_global_zero(self) -> bool:
139139
return self._strategy.is_global_zero
140140

141141
def run(self, *args: Any, **kwargs: Any) -> Any:
142-
"""All the code inside this run method gets accelerated by Lite.
142+
"""All the code inside this run method gets accelerated by Fabric.
143143
144144
You can pass arbitrary arguments to this function when overriding it.
145145
"""
@@ -502,16 +502,16 @@ def load(self, filepath: Union[str, Path]) -> Any:
502502
"""
503503
return self._strategy.load_checkpoint(filepath)
504504

505-
def launch(self, function: Optional[Callable[["LightningLite"], Any]] = None, *args: Any, **kwargs: Any) -> Any:
505+
def launch(self, function: Optional[Callable[["Fabric"], Any]] = None, *args: Any, **kwargs: Any) -> Any:
506506
if _is_using_cli():
507507
raise RuntimeError(
508508
"This script was launched through the CLI, and processes have already been created. Calling "
509509
" `.launch()` again is not allowed."
510510
)
511511
if function is not None and not inspect.signature(function).parameters:
512512
raise TypeError(
513-
"The function passed to `Lite.launch()` needs to take at least one argument. The launcher will pass"
514-
" in the `LightningLite` object so you can use it inside the function."
513+
"The function passed to `Fabric.launch()` needs to take at least one argument. The launcher will pass"
514+
" in the `Fabric` object so you can use it inside the function."
515515
)
516516
function = partial(self._run_with_setup, function or _do_nothing)
517517
args = [self, *args]
@@ -550,9 +550,9 @@ def _move_model_to_device(self, model: nn.Module, optimizers: List[Optimizer]) -
550550
initial_device = next(model.parameters()).device
551551
if any(param.device != initial_device for param in model.parameters()):
552552
rank_zero_warn(
553-
"The model passed to `Lite.setup()` has parameters on different devices. Since `move_to_device=True`,"
553+
"The model passed to `Fabric.setup()` has parameters on different devices. Since `move_to_device=True`,"
554554
" all parameters will be moved to the new device. If this is not desired, set "
555-
" `Lite.setup(..., move_to_device=False)`.",
555+
" `Fabric.setup(..., move_to_device=False)`.",
556556
category=PossibleUserWarning,
557557
)
558558

@@ -586,10 +586,10 @@ def _get_distributed_sampler(dataloader: DataLoader, **kwargs: Any) -> Distribut
586586
return DistributedSamplerWrapper(dataloader.sampler, **kwargs)
587587

588588
def _prepare_run_method(self) -> None:
589-
if is_overridden("run", self, LightningLite) and _is_using_cli():
589+
if is_overridden("run", self, Fabric) and _is_using_cli():
590590
raise TypeError(
591-
"Overriding `LightningLite.run()` and launching from the CLI is not allowed. Run the script normally,"
592-
" or change your code to directly call `lite = LightningLite(...); lite.setup(...)` etc."
591+
"Overriding `Fabric.run()` and launching from the CLI is not allowed. Run the script normally,"
592+
" or change your code to directly call `lite = Fabric(...); lite.setup(...)` etc."
593593
)
594594
# wrap the run method, so we can inject setup logic or spawn processes for the user
595595
setattr(self, "run", partial(self._run_impl, self.run))

0 commit comments

Comments
 (0)