diff --git a/src/lightning/app/core/flow.py b/src/lightning/app/core/flow.py index 1aedf30f4a5da..f3e3f697b4bdf 100644 --- a/src/lightning/app/core/flow.py +++ b/src/lightning/app/core/flow.py @@ -21,6 +21,7 @@ from deepdiff import DeepHash +from lightning.app.core.plugin import Plugin from lightning.app.core.work import LightningWork from lightning.app.frontend import Frontend from lightning.app.storage import Path @@ -740,6 +741,22 @@ def configure_api(self): """ raise NotImplementedError + def configure_plugins(self) -> Optional[List[Dict[str, Plugin]]]: + """Configure the plugins of this LightningFlow. + + Returns a list of dictionaries mapping a plugin name to a :class:`lightning_app.core.plugin.Plugin`. + + .. code-block:: python + + class Flow(LightningFlow): + def __init__(self): + super().__init__() + + def configure_plugins(self): + return [{"my_plugin_name": MyPlugin()}] + """ + pass + def state_dict(self): """Returns the current flow state but not its children.""" return { diff --git a/src/lightning/app/core/plugin.py b/src/lightning/app/core/plugin.py new file mode 100644 index 0000000000000..a75ff33c42263 --- /dev/null +++ b/src/lightning/app/core/plugin.py @@ -0,0 +1,170 @@ +# Copyright The Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import os +import tempfile +from pathlib import Path +from typing import Any, Dict, Optional + +import requests +import uvicorn +from fastapi import FastAPI, HTTPException, status +from fastapi.middleware.cors import CORSMiddleware +from pydantic import BaseModel + +from lightning.app.utilities.app_helpers import Logger +from lightning.app.utilities.cloud import _get_project +from lightning.app.utilities.component import _set_flow_context +from lightning.app.utilities.enum import AppStage +from lightning.app.utilities.network import LightningClient + +logger = Logger(__name__) + + +class Plugin: + """A ``Plugin`` is a single-file Python class that can be executed within a cloudspace to perform actions.""" + + def __init__(self) -> None: + self.app_url = None + + def run(self, name: str, entrypoint: str) -> None: + """Override with the logic to execute on the client side.""" + + def run_app_command(self, command_name: str, config: Optional[BaseModel] = None) -> Dict[str, Any]: + """Run a command on the app associated with this plugin. + + Args: + command_name: The name of the command to run. + config: The command config or ``None`` if the command doesn't require configuration. + """ + if self.app_url is None: + raise RuntimeError("The plugin must be set up before `run_app_command` can be called.") + + command = command_name.replace(" ", "_") + resp = requests.post(self.app_url + f"/command/{command}", data=config.json() if config else None) + if resp.status_code != 200: + try: + detail = str(resp.json()) + except Exception: + detail = "Internal Server Error" + raise RuntimeError(f"Failed with status code {resp.status_code}. Detail: {detail}") + + return resp.json() + + def _setup(self, app_id: str) -> None: + client = LightningClient() + project_id = _get_project(client).project_id + response = client.lightningapp_instance_service_list_lightningapp_instances( + project_id=project_id, app_id=app_id + ) + if len(response.lightningapps) > 1: + raise RuntimeError(f"Found multiple apps with ID: {app_id}") + if len(response.lightningapps) == 0: + raise RuntimeError(f"Found no apps with ID: {app_id}") + self.app_url = response.lightningapps[0].status.url + + +class _Run(BaseModel): + plugin_name: str + project_id: str + cloudspace_id: str + name: str + entrypoint: str + cluster_id: Optional[str] = None + app_id: Optional[str] = None + + +def _run_plugin(run: _Run) -> None: + """Create a run with the given name and entrypoint under the cloudspace with the given ID.""" + if run.app_id is None and run.plugin_name == "app": + from lightning.app.runners.cloud import CloudRuntime + + # TODO: App dispatch should be a plugin + # Dispatch the run + _set_flow_context() + + entrypoint_file = Path("/content") / run.entrypoint + + app = CloudRuntime.load_app_from_file(str(entrypoint_file.resolve().absolute())) + + app.stage = AppStage.BLOCKING + + runtime = CloudRuntime( + app=app, + entrypoint=entrypoint_file, + start_server=True, + env_vars={}, + secrets={}, + run_app_comment_commands=True, + ) + # Used to indicate Lightning has been dispatched + os.environ["LIGHTNING_DISPATCHED"] = "1" + + try: + runtime.cloudspace_dispatch( + project_id=run.project_id, + cloudspace_id=run.cloudspace_id, + name=run.name, + cluster_id=run.cluster_id, + ) + except Exception as e: + raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=str(e)) + elif run.app_id is not None: + from lightning.app.utilities.cli_helpers import _LightningAppOpenAPIRetriever + from lightning.app.utilities.commands.base import _download_command + + retriever = _LightningAppOpenAPIRetriever(run.app_id) + + metadata = retriever.api_commands[run.plugin_name] # type: ignore + + with tempfile.TemporaryDirectory() as tmpdir: + + target_file = os.path.join(tmpdir, f"{run.plugin_name}.py") + plugin = _download_command( + run.plugin_name, + metadata["cls_path"], + metadata["cls_name"], + run.app_id, + target_file=target_file, + ) + + if isinstance(plugin, Plugin): + plugin._setup(app_id=run.app_id) + plugin.run(run.name, run.entrypoint) + else: + # This should never be possible but we check just in case + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail=f"The plugin {run.plugin_name} is an incorrect type.", + ) + else: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, detail="App ID must be specified unless `plugin_name='app'`." + ) + + +def _start_plugin_server(host: str, port: int) -> None: + """Start the plugin server which can be used to dispatch apps or run plugins.""" + fastapi_service = FastAPI() + + fastapi_service.add_middleware( + CORSMiddleware, + allow_origins=["*"], + allow_credentials=True, + allow_methods=["*"], + allow_headers=["*"], + ) + + fastapi_service.post("/v1/runs")(_run_plugin) + + uvicorn.run(app=fastapi_service, host=host, port=port, log_level="error") diff --git a/src/lightning/app/runners/cloud.py b/src/lightning/app/runners/cloud.py index dcf3a27038f69..c03c5e73db7c8 100644 --- a/src/lightning/app/runners/cloud.py +++ b/src/lightning/app/runners/cloud.py @@ -208,6 +208,67 @@ def open(self, name: str, cluster_id: Optional[str] = None): logger.error(e.body) sys.exit(1) + def cloudspace_dispatch( + self, + project_id: str, + cloudspace_id: str, + name: str, + cluster_id: str = None, + ): + """Slim dispatch for creating runs from a cloudspace. This dispatch avoids resolution of some properties + such as the project and cluster IDs that are instead passed directly. + + Args: + project_id: The ID of the project. + cloudspace_id: The ID of the cloudspace. + name: The name for the run. + cluster_id: The ID of the cluster to run on. + + Raises: + ApiException: If there was an issue in the backend. + RuntimeError: If there are validation errors. + ValueError: If there are validation errors. + """ + # Dispatch in four phases: resolution, validation, spec creation, API transactions + # Resolution + root = self._resolve_root() + ignore_functions = self._resolve_open_ignore_functions() + repo = self._resolve_repo(root, ignore_functions) + cloudspace = self._resolve_cloudspace(project_id, cloudspace_id) + cluster_id = self._resolve_cluster_id(cluster_id, project_id, [cloudspace]) + queue_server_type = self._resolve_queue_server_type() + + self.app._update_index_file() + + # Validation + # TODO: Validate repo and surface to the user + # self._validate_repo(root, repo) + self._validate_work_build_specs_and_compute() + self._validate_drives() + self._validate_mounts() + + # Spec creation + flow_servers = self._get_flow_servers() + network_configs = self._get_network_configs(flow_servers) + works = self._get_works() + run_body = self._get_run_body(cluster_id, flow_servers, network_configs, works, False, root, True) + env_vars = self._get_env_vars(self.env_vars, self.secrets, self.run_app_comment_commands) + + # API transactions + run = self._api_create_run(project_id, cloudspace_id, run_body) + self._api_package_and_upload_repo(repo, run) + + self._api_create_run_instance( + cluster_id, + project_id, + name, + cloudspace_id, + run.id, + V1LightningappInstanceState.RUNNING, + queue_server_type, + env_vars, + ) + def dispatch( self, name: str = "", @@ -410,6 +471,13 @@ def _resolve_project(self) -> V1Membership: """Determine the project to run on, choosing a default if multiple projects are found.""" return _get_project(self.backend.client) + def _resolve_cloudspace(self, project_id: str, cloudspace_id: str) -> V1CloudSpace: + """Get a cloudspace by project / cloudspace ID.""" + return self.backend.client.cloud_space_service_get_cloud_space( + project_id=project_id, + id=cloudspace_id, + ) + def _resolve_existing_cloudspaces(self, project, cloudspace_name: str) -> List[V1CloudSpace]: """Lists all the cloudspaces with a name matching the provided cloudspace name.""" # TODO: Add pagination, otherwise this could break if users have a lot of cloudspaces. @@ -871,7 +939,7 @@ def _api_create_run_instance( self, cluster_id: str, project_id: str, - cloudspace_name: str, + run_name: str, cloudspace_id: str, run_id: str, desired_state: V1LightningappInstanceState, @@ -886,7 +954,7 @@ def _api_create_run_instance( id=run_id, body=IdGetBody1( cluster_id=cluster_id, - name=cloudspace_name, + name=run_name, desired_state=desired_state, queue_server_type=queue_server_type, env=env_vars, diff --git a/src/lightning/app/utilities/commands/base.py b/src/lightning/app/utilities/commands/base.py index 2d0ec762e1a83..4ce208184cfce 100644 --- a/src/lightning/app/utilities/commands/base.py +++ b/src/lightning/app/utilities/commands/base.py @@ -31,6 +31,7 @@ from lightning.app.api.http_methods import Post from lightning.app.api.request_types import _APIRequest, _CommandRequest, _RequestResponse +from lightning.app.core.plugin import Plugin from lightning.app.utilities import frontend from lightning.app.utilities.app_helpers import is_overridden, Logger from lightning.app.utilities.cloud import _get_project @@ -108,7 +109,7 @@ def _download_command( app_id: Optional[str] = None, debug_mode: bool = False, target_file: Optional[str] = None, -) -> ClientCommand: +) -> Union[ClientCommand, Plugin]: # TODO: This is a skateboard implementation and the final version will rely on versioned # immutable commands for security concerns command_name = command_name.replace(" ", "_") @@ -139,7 +140,13 @@ def _download_command( mod = module_from_spec(spec) sys.modules[cls_name] = mod spec.loader.exec_module(mod) - command = getattr(mod, cls_name)(method=None) + command_type = getattr(mod, cls_name) + if issubclass(command_type, ClientCommand): + command = command_type(method=None) + elif issubclass(command_type, Plugin): + command = command_type() + else: + raise ValueError(f"Expected class {cls_name} for command {command_name} to be a `ClientCommand` or `Plugin`.") if tmpdir and os.path.exists(tmpdir): shutil.rmtree(tmpdir) return command @@ -182,12 +189,11 @@ def _validate_client_command(command: ClientCommand): ) -def _upload_command(command_name: str, command: ClientCommand) -> Optional[str]: +def _upload(name: str, prefix: str, obj: Any) -> Optional[str]: from lightning.app.storage.path import _filesystem, _is_s3fs_available, _shared_storage_path - command_name = command_name.replace(" ", "_") - filepath = f"commands/{command_name}.py" - remote_url = str(_shared_storage_path() / "artifacts" / filepath) + name = name.replace(" ", "_") + filepath = f"{prefix}/{name}.py" fs = _filesystem() if _is_s3fs_available(): @@ -196,7 +202,7 @@ def _upload_command(command_name: str, command: ClientCommand) -> Optional[str]: if not isinstance(fs, S3FileSystem): return - source_file = str(inspect.getfile(command.__class__)) + source_file = str(inspect.getfile(obj.__class__)) remote_url = str(_shared_storage_path() / "artifacts" / filepath) fs.put(source_file, remote_url) return filepath @@ -211,13 +217,25 @@ def _prepare_commands(app) -> List: for command_mapping in commands: for command_name, command in command_mapping.items(): if isinstance(command, ClientCommand): - _upload_command(command_name, command) + _upload(command_name, "commands", command) # 2: Cache the commands on the app. app.commands = commands return commands +def _prepare_plugins(app) -> List: + if not is_overridden("configure_plugins", app.root): + return [] + + # 1: Upload the plugins to s3. + plugins = app.root.configure_plugins() + for plugin_mapping in plugins: + for plugin_name, plugin in plugin_mapping.items(): + if isinstance(plugin, Plugin): + _upload(plugin_name, "plugins", plugin) + + def _process_api_request(app, request: _APIRequest): flow = app.get_component_by_name(request.name) method = getattr(flow, request.method_name) diff --git a/tests/tests_app/core/test_plugin.py b/tests/tests_app/core/test_plugin.py new file mode 100644 index 0000000000000..2756955cefc0f --- /dev/null +++ b/tests/tests_app/core/test_plugin.py @@ -0,0 +1,117 @@ +from pathlib import Path +from unittest import mock + +import pytest +from fastapi import status +from fastapi.testclient import TestClient + +from lightning.app.core.plugin import _Run, _start_plugin_server, Plugin + + +@pytest.fixture() +@mock.patch("lightning.app.core.plugin.uvicorn") +def mock_plugin_server(mock_uvicorn) -> TestClient: + """This fixture returns a `TestClient` for the plugin server.""" + + test_client = {} + + def create_test_client(app, **_): + test_client["client"] = TestClient(app) + + mock_uvicorn.run.side_effect = create_test_client + + _start_plugin_server("0.0.0.0", 8888) + + return test_client["client"] + + +def test_run_bad_request(mock_plugin_server): + body = _Run( + plugin_name="test", + project_id="any", + cloudspace_id="any", + name="any", + entrypoint="any", + ) + + response = mock_plugin_server.post("/v1/runs", json=body.dict(exclude_none=True)) + + assert response.status_code == status.HTTP_400_BAD_REQUEST + assert "App ID must be specified" in response.text + + +@mock.patch("lightning.app.runners.cloud.CloudRuntime") +def test_run_app(mock_cloud_runtime, mock_plugin_server): + """Tests that app dispatch call the correct `CloudRuntime` methods with the correct arguments.""" + body = _Run( + plugin_name="app", + project_id="test_project_id", + cloudspace_id="test_cloudspace_id", + name="test_name", + entrypoint="test_entrypoint", + ) + + mock_app = mock.MagicMock() + mock_cloud_runtime.load_app_from_file.return_value = mock_app + + response = mock_plugin_server.post("/v1/runs", json=body.dict(exclude_none=True)) + + assert response.status_code == status.HTTP_200_OK + + mock_cloud_runtime.load_app_from_file.assert_called_once_with( + str((Path("/content") / "test_entrypoint").absolute()) + ) + + mock_cloud_runtime.assert_called_once_with( + app=mock_app, + entrypoint=Path("/content/test_entrypoint"), + start_server=True, + env_vars={}, + secrets={}, + run_app_comment_commands=True, + ) + + mock_cloud_runtime().cloudspace_dispatch.assert_called_once_with( + project_id=body.project_id, + cloudspace_id=body.cloudspace_id, + name=body.name, + cluster_id=body.cluster_id, + ) + + +@mock.patch("lightning.app.utilities.commands.base._download_command") +@mock.patch("lightning.app.utilities.cli_helpers._LightningAppOpenAPIRetriever") +def test_run_plugin(mock_retriever, mock_download_command, mock_plugin_server): + """Tests that running a plugin calls the correct `CloudRuntime` methods with the correct arguments.""" + body = _Run( + plugin_name="test_plugin", + project_id="test_project_id", + cloudspace_id="test_cloudspace_id", + name="test_name", + entrypoint="test_entrypoint", + app_id="test_app_id", + ) + + mock_plugin = mock.MagicMock(spec=Plugin) + mock_download_command.return_value = mock_plugin + + mock_retriever.return_value.api_commands = { + body.plugin_name: {"cls_path": "test_cls_path", "cls_name": "test_cls_name"} + } + + response = mock_plugin_server.post("/v1/runs", json=body.dict(exclude_none=True)) + + assert response.status_code == status.HTTP_200_OK + + mock_retriever.assert_called_once_with(body.app_id) + + mock_download_command.assert_called_once_with( + body.plugin_name, + "test_cls_path", + "test_cls_name", + body.app_id, + target_file=mock.ANY, + ) + + mock_plugin._setup.assert_called_once_with(app_id=body.app_id) + mock_plugin.run.assert_called_once_with(body.name, body.entrypoint) diff --git a/tests/tests_app/runners/test_cloud.py b/tests/tests_app/runners/test_cloud.py index e54fc3551e9fb..0af7b3acc776f 100644 --- a/tests/tests_app/runners/test_cloud.py +++ b/tests/tests_app/runners/test_cloud.py @@ -1601,6 +1601,50 @@ def test_not_enabled(self, monkeypatch, capsys): assert "`lightning open` command has not been enabled" in out +class TestCloudspaceDispatch: + def test_cloudspace_dispatch(self, monkeypatch): + """Tests that the cloudspace_dispatch method calls the expected API endpoints.""" + mock_client = mock.MagicMock() + mock_client.auth_service_get_user.return_value = V1GetUserResponse( + username="tester", + ) + mock_client.projects_service_list_memberships.return_value = V1ListMembershipsResponse( + memberships=[V1Membership(name="project", project_id="project_id")] + ) + mock_client.cloud_space_service_create_lightning_run.return_value = V1LightningRun(id="run_id") + mock_client.cloud_space_service_create_lightning_run_instance.return_value = Externalv1LightningappInstance( + id="instance_id" + ) + + cluster = Externalv1Cluster(id="test", spec=V1ClusterSpec(cluster_type=V1ClusterType.GLOBAL)) + mock_client.projects_service_list_project_cluster_bindings.return_value = V1ListProjectClusterBindingsResponse( + clusters=[V1ProjectClusterBinding(cluster_id="test")], + ) + mock_client.cluster_service_list_clusters.return_value = V1ListClustersResponse([cluster]) + mock_client.cluster_service_get_cluster.return_value = cluster + + cloud_backend = mock.MagicMock() + cloud_backend.client = mock_client + monkeypatch.setattr(backends, "CloudBackend", mock.MagicMock(return_value=cloud_backend)) + mock_local_source = mock.MagicMock() + monkeypatch.setattr(cloud, "LocalSourceCodeDir", mock_local_source) + + cloud_runtime = cloud.CloudRuntime(app=mock.MagicMock(), entrypoint=Path(".")) + + cloud_runtime.cloudspace_dispatch("project_id", "cloudspace_id", "run_name") + + mock_client.cloud_space_service_create_lightning_run.assert_called_once_with( + project_id="project_id", + cloudspace_id="cloudspace_id", + body=mock.ANY, + ) + mock_client.cloud_space_service_create_lightning_run_instance.assert_called_once_with( + project_id="project_id", cloudspace_id="cloudspace_id", id="run_id", body=mock.ANY + ) + + assert mock_client.cloud_space_service_create_lightning_run_instance.call_args.kwargs["body"].name == "run_name" + + @mock.patch("lightning.app.core.queues.QueuingSystem", MagicMock()) @mock.patch("lightning.app.runners.backends.cloud.LightningClient", MagicMock()) def test_get_project(monkeypatch):