Skip to content

Commit ecbd94f

Browse files
ethanwharris authored and Borda committed
[App] Initial plugin server (#16523)
(cherry picked from commit 1288e4c)
1 parent 4d0c2ce commit ecbd94f

File tree

6 files changed

+444
-10
lines changed

6 files changed

+444
-10
lines changed

src/lightning_app/core/flow.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121

2222
from deepdiff import DeepHash
2323

24+
from lightning_app.core.plugin import Plugin
2425
from lightning_app.core.work import LightningWork
2526
from lightning_app.frontend import Frontend
2627
from lightning_app.storage import Path
@@ -740,6 +741,22 @@ def configure_api(self):
740741
"""
741742
raise NotImplementedError
742743

744+
def configure_plugins(self) -> Optional[List[Dict[str, Plugin]]]:
    """Declare the plugins exposed by this LightningFlow.

    Override this method to return a list of dictionaries, each one mapping a plugin name to a
    :class:`lightning_app.core.plugin.Plugin` instance. The base implementation returns ``None``.

    .. code-block:: python

        class Flow(LightningFlow):
            def __init__(self):
                super().__init__()

            def configure_plugins(self):
                return [{"my_plugin_name": MyPlugin()}]
    """
    return None
759+
743760
def state_dict(self):
744761
"""Returns the current flow state but not its children."""
745762
return {

src/lightning_app/core/plugin.py

Lines changed: 170 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,170 @@
1+
# Copyright The Lightning team.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
import os
15+
import tempfile
16+
from pathlib import Path
17+
from typing import Any, Dict, Optional
18+
19+
import requests
20+
import uvicorn
21+
from fastapi import FastAPI, HTTPException, status
22+
from fastapi.middleware.cors import CORSMiddleware
23+
from pydantic import BaseModel
24+
25+
from lightning_app.utilities.app_helpers import Logger
26+
from lightning_app.utilities.cloud import _get_project
27+
from lightning_app.utilities.component import _set_flow_context
28+
from lightning_app.utilities.enum import AppStage
29+
from lightning_app.utilities.network import LightningClient
30+
31+
logger = Logger(__name__)
32+
33+
34+
class Plugin:
    """Base class for plugins: single-file Python classes executed within a cloudspace to perform actions.

    Subclasses override :meth:`run`. ``app_url`` starts as ``None`` and is filled in by :meth:`_setup`;
    :meth:`run_app_command` refuses to work until that has happened.
    """

    def __init__(self) -> None:
        self.app_url = None

    def run(self, name: str, entrypoint: str) -> None:
        """Hook for subclasses: implement the logic to execute on the client side here."""

    def run_app_command(self, command_name: str, config: Optional[BaseModel] = None) -> Dict[str, Any]:
        """Invoke a command on the app this plugin is attached to and return the decoded JSON response.

        Args:
            command_name: The name of the command to run. Spaces are replaced with underscores.
            config: The command config, sent as the JSON request body, or ``None`` if the command
                doesn't require configuration.

        Raises:
            RuntimeError: If the plugin has not been set up yet, or if the app answers with a status
                code other than 200.
        """
        if self.app_url is None:
            raise RuntimeError("The plugin must be set up before `run_app_command` can be called.")

        endpoint = self.app_url + "/command/" + command_name.replace(" ", "_")
        payload = config.json() if config else None
        response = requests.post(endpoint, data=payload)

        if response.status_code == 200:
            return response.json()

        # Error path: report the response body when it decodes as JSON, otherwise a generic message.
        try:
            detail = str(response.json())
        except Exception:
            detail = "Internal Server Error"
        raise RuntimeError(f"Failed with status code {response.status_code}. Detail: {detail}")

    def _setup(self, app_id: str) -> None:
        """Look up the app instance with ID ``app_id`` in the current project and store its URL.

        Raises:
            RuntimeError: If the lookup yields no app, or more than one app.
        """
        client = LightningClient()
        listing = client.lightningapp_instance_service_list_lightningapp_instances(
            project_id=_get_project(client).project_id, app_id=app_id
        )
        matches = listing.lightningapps
        found = len(matches)
        if found == 0:
            raise RuntimeError(f"Found no apps with ID: {app_id}")
        if found > 1:
            raise RuntimeError(f"Found multiple apps with ID: {app_id}")
        self.app_url = matches[0].status.url
75+
76+
77+
class _Run(BaseModel):
    # Request payload consumed by ``_run_plugin``, which the plugin server registers on ``POST /v1/runs``.

    # "app" (with no ``app_id``) dispatches the entrypoint as an app; otherwise the name of the plugin
    # looked up among the commands of the app identified by ``app_id``.
    plugin_name: str
    # Forwarded to ``CloudRuntime.cloudspace_dispatch`` when dispatching an app.
    project_id: str
    # Forwarded to ``CloudRuntime.cloudspace_dispatch`` when dispatching an app.
    cloudspace_id: str
    # Name of the run; passed to ``cloudspace_dispatch`` or to ``Plugin.run``.
    name: str
    # Joined onto ``/content`` for app dispatch; passed through unchanged to ``Plugin.run``.
    entrypoint: str
    # Optional cluster to run on; only used for app dispatch.
    cluster_id: Optional[str] = None
    # When set, the plugin is downloaded for this app and set up against it.
    app_id: Optional[str] = None
85+
86+
87+
def _run_plugin(run: _Run) -> None:
    """Create a run with the given name and entrypoint under the cloudspace with the given ID.

    Two modes are supported, selected by the request:

    * ``app_id`` is ``None`` and ``plugin_name == "app"``: the entrypoint (relative to ``/content``) is
      loaded as an app and dispatched through ``CloudRuntime.cloudspace_dispatch``.
    * ``app_id`` is set: the plugin named ``plugin_name`` is looked up on that app, downloaded into a
      temporary directory, set up against the app and run.

    Args:
        run: The request describing what to run.

    Raises:
        HTTPException: 500 if the app dispatch fails or the downloaded object is not a ``Plugin``;
            404 if the app exposes no plugin with the requested name; 400 if neither mode applies.
    """
    if run.app_id is None and run.plugin_name == "app":
        from lightning_app.runners.cloud import CloudRuntime

        # TODO: App dispatch should be a plugin
        # Dispatch the run
        _set_flow_context()

        # NOTE(review): ``run.entrypoint`` comes straight from the request body; an absolute path or
        # ``..`` segments would escape ``/content``. Confirm whether this needs to be restricted.
        entrypoint_file = Path("/content") / run.entrypoint

        app = CloudRuntime.load_app_from_file(str(entrypoint_file.resolve().absolute()))

        app.stage = AppStage.BLOCKING

        runtime = CloudRuntime(
            app=app,
            entrypoint=entrypoint_file,
            start_server=True,
            env_vars={},
            secrets={},
            run_app_comment_commands=True,
        )
        # Used to indicate Lightning has been dispatched
        os.environ["LIGHTNING_DISPATCHED"] = "1"

        try:
            runtime.cloudspace_dispatch(
                project_id=run.project_id,
                cloudspace_id=run.cloudspace_id,
                name=run.name,
                cluster_id=run.cluster_id,
            )
        except Exception as e:
            # Chain the original error so the server-side traceback keeps the root cause.
            raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=str(e)) from e
    elif run.app_id is not None:
        from lightning_app.utilities.cli_helpers import _LightningAppOpenAPIRetriever
        from lightning_app.utilities.commands.base import _download_command

        retriever = _LightningAppOpenAPIRetriever(run.app_id)

        # ``api_commands`` is presumably optional on the retriever (the original lookup carried a
        # ``type: ignore``), so guard against ``None`` as well as against an unknown plugin name and
        # answer with an explicit 404 instead of letting a KeyError/TypeError escape.
        metadata = (retriever.api_commands or {}).get(run.plugin_name)
        if metadata is None:
            raise HTTPException(
                status_code=status.HTTP_404_NOT_FOUND,
                detail=f"No plugin named '{run.plugin_name}' was found for the app with ID: {run.app_id}",
            )

        with tempfile.TemporaryDirectory() as tmpdir:

            target_file = os.path.join(tmpdir, f"{run.plugin_name}.py")
            plugin = _download_command(
                run.plugin_name,
                metadata["cls_path"],
                metadata["cls_name"],
                run.app_id,
                target_file=target_file,
            )

            if isinstance(plugin, Plugin):
                plugin._setup(app_id=run.app_id)
                plugin.run(run.name, run.entrypoint)
            else:
                # This should never be possible but we check just in case
                raise HTTPException(
                    status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
                    detail=f"The plugin {run.plugin_name} is an incorrect type.",
                )
    else:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST, detail="App ID must be specified unless `plugin_name='app'`."
        )
154+
155+
156+
def _start_plugin_server(host: str, port: int) -> None:
    """Serve the plugin API, which can be used to dispatch apps or run plugins.

    Builds a FastAPI application with a permissive CORS middleware, registers ``_run_plugin`` on
    ``POST /v1/runs`` and hands the application to ``uvicorn.run``.

    Args:
        host: The host to bind the server to.
        port: The port to bind the server to.
    """
    server = FastAPI()

    cors_settings = {
        "allow_origins": ["*"],
        "allow_credentials": True,
        "allow_methods": ["*"],
        "allow_headers": ["*"],
    }
    server.add_middleware(CORSMiddleware, **cors_settings)

    register_runs_route = server.post("/v1/runs")
    register_runs_route(_run_plugin)

    uvicorn.run(app=server, host=host, port=port, log_level="error")

src/lightning_app/runners/cloud.py

Lines changed: 70 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -208,6 +208,67 @@ def open(self, name: str, cluster_id: Optional[str] = None):
208208
logger.error(e.body)
209209
sys.exit(1)
210210

211+
def cloudspace_dispatch(
    self,
    project_id: str,
    cloudspace_id: str,
    name: str,
    cluster_id: Optional[str] = None,
) -> None:
    """Slim dispatch for creating runs from a cloudspace. This dispatch avoids resolution of some properties
    such as the project and cluster IDs that are instead passed directly.

    The cloudspace is fetched by ID rather than looked up by name, and the given ``cluster_id`` is still
    passed through ``_resolve_cluster_id`` together with that cloudspace. The run instance is created
    with the ``RUNNING`` desired state.

    Args:
        project_id: The ID of the project.
        cloudspace_id: The ID of the cloudspace.
        name: The name for the run.
        cluster_id: The ID of the cluster to run on, or ``None`` to let ``_resolve_cluster_id`` decide.

    Raises:
        ApiException: If there was an issue in the backend.
        RuntimeError: If there are validation errors.
        ValueError: If there are validation errors.
    """
    # Dispatch in four phases: resolution, validation, spec creation, API transactions
    # Resolution
    root = self._resolve_root()
    ignore_functions = self._resolve_open_ignore_functions()
    repo = self._resolve_repo(root, ignore_functions)
    cloudspace = self._resolve_cloudspace(project_id, cloudspace_id)
    cluster_id = self._resolve_cluster_id(cluster_id, project_id, [cloudspace])
    queue_server_type = self._resolve_queue_server_type()

    self.app._update_index_file()

    # Validation
    # TODO: Validate repo and surface to the user
    # self._validate_repo(root, repo)
    self._validate_work_build_specs_and_compute()
    self._validate_drives()
    self._validate_mounts()

    # Spec creation
    flow_servers = self._get_flow_servers()
    network_configs = self._get_network_configs(flow_servers)
    works = self._get_works()
    run_body = self._get_run_body(cluster_id, flow_servers, network_configs, works, False, root, True)
    env_vars = self._get_env_vars(self.env_vars, self.secrets, self.run_app_comment_commands)

    # API transactions
    run = self._api_create_run(project_id, cloudspace_id, run_body)
    self._api_package_and_upload_repo(repo, run)

    # ``name`` is the run name here, not the cloudspace name.
    self._api_create_run_instance(
        cluster_id,
        project_id,
        name,
        cloudspace_id,
        run.id,
        V1LightningappInstanceState.RUNNING,
        queue_server_type,
        env_vars,
    )
271+
211272
def dispatch(
212273
self,
213274
name: str = "",
@@ -410,6 +471,13 @@ def _resolve_project(self) -> V1Membership:
410471
"""Determine the project to run on, choosing a default if multiple projects are found."""
411472
return _get_project(self.backend.client)
412473

474+
def _resolve_cloudspace(self, project_id: str, cloudspace_id: str) -> V1CloudSpace:
    """Fetch the cloudspace with ID ``cloudspace_id`` in the project with ID ``project_id``."""
    client = self.backend.client
    return client.cloud_space_service_get_cloud_space(project_id=project_id, id=cloudspace_id)
480+
413481
def _resolve_existing_cloudspaces(self, project, cloudspace_name: str) -> List[V1CloudSpace]:
414482
"""Lists all the cloudspaces with a name matching the provided cloudspace name."""
415483
# TODO: Add pagination, otherwise this could break if users have a lot of cloudspaces.
@@ -871,7 +939,7 @@ def _api_create_run_instance(
871939
self,
872940
cluster_id: str,
873941
project_id: str,
874-
cloudspace_name: str,
942+
run_name: str,
875943
cloudspace_id: str,
876944
run_id: str,
877945
desired_state: V1LightningappInstanceState,
@@ -886,7 +954,7 @@ def _api_create_run_instance(
886954
id=run_id,
887955
body=IdGetBody1(
888956
cluster_id=cluster_id,
889-
name=cloudspace_name,
957+
name=run_name,
890958
desired_state=desired_state,
891959
queue_server_type=queue_server_type,
892960
env=env_vars,

src/lightning_app/utilities/commands/base.py

Lines changed: 26 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131

3232
from lightning_app.api.http_methods import Post
3333
from lightning_app.api.request_types import _APIRequest, _CommandRequest, _RequestResponse
34+
from lightning_app.core.plugin import Plugin
3435
from lightning_app.utilities import frontend
3536
from lightning_app.utilities.app_helpers import is_overridden, Logger
3637
from lightning_app.utilities.cloud import _get_project
@@ -108,7 +109,7 @@ def _download_command(
108109
app_id: Optional[str] = None,
109110
debug_mode: bool = False,
110111
target_file: Optional[str] = None,
111-
) -> ClientCommand:
112+
) -> Union[ClientCommand, Plugin]:
112113
# TODO: This is a skateboard implementation and the final version will rely on versioned
113114
# immutable commands for security concerns
114115
command_name = command_name.replace(" ", "_")
@@ -139,7 +140,13 @@ def _download_command(
139140
mod = module_from_spec(spec)
140141
sys.modules[cls_name] = mod
141142
spec.loader.exec_module(mod)
142-
command = getattr(mod, cls_name)(method=None)
143+
command_type = getattr(mod, cls_name)
144+
if issubclass(command_type, ClientCommand):
145+
command = command_type(method=None)
146+
elif issubclass(command_type, Plugin):
147+
command = command_type()
148+
else:
149+
raise ValueError(f"Expected class {cls_name} for command {command_name} to be a `ClientCommand` or `Plugin`.")
143150
if tmpdir and os.path.exists(tmpdir):
144151
shutil.rmtree(tmpdir)
145152
return command
@@ -182,12 +189,11 @@ def _validate_client_command(command: ClientCommand):
182189
)
183190

184191

185-
def _upload_command(command_name: str, command: ClientCommand) -> Optional[str]:
192+
def _upload(name: str, prefix: str, obj: Any) -> Optional[str]:
186193
from lightning_app.storage.path import _filesystem, _is_s3fs_available, _shared_storage_path
187194

188-
command_name = command_name.replace(" ", "_")
189-
filepath = f"commands/{command_name}.py"
190-
remote_url = str(_shared_storage_path() / "artifacts" / filepath)
195+
name = name.replace(" ", "_")
196+
filepath = f"{prefix}/{name}.py"
191197
fs = _filesystem()
192198

193199
if _is_s3fs_available():
@@ -196,7 +202,7 @@ def _upload_command(command_name: str, command: ClientCommand) -> Optional[str]:
196202
if not isinstance(fs, S3FileSystem):
197203
return
198204

199-
source_file = str(inspect.getfile(command.__class__))
205+
source_file = str(inspect.getfile(obj.__class__))
200206
remote_url = str(_shared_storage_path() / "artifacts" / filepath)
201207
fs.put(source_file, remote_url)
202208
return filepath
@@ -211,13 +217,25 @@ def _prepare_commands(app) -> List:
211217
for command_mapping in commands:
212218
for command_name, command in command_mapping.items():
213219
if isinstance(command, ClientCommand):
214-
_upload_command(command_name, command)
220+
_upload(command_name, "commands", command)
215221

216222
# 2: Cache the commands on the app.
217223
app.commands = commands
218224
return commands
219225

220226

227+
def _prepare_plugins(app) -> List:
    """Upload the plugins configured on the app's root flow and return the plugin mappings.

    Args:
        app: The app whose root flow may override ``configure_plugins``.

    Returns:
        The list of ``{plugin_name: plugin}`` mappings returned by ``configure_plugins``, or an empty
        list when the root flow does not override it or returns ``None``.
    """
    if not is_overridden("configure_plugins", app.root):
        return []

    # ``configure_plugins`` is annotated as Optional, so treat ``None`` as "no plugins" rather than
    # failing when iterating over it.
    plugins = app.root.configure_plugins() or []

    # 1: Upload the plugins to s3.
    for plugin_mapping in plugins:
        for plugin_name, plugin in plugin_mapping.items():
            if isinstance(plugin, Plugin):
                _upload(plugin_name, "plugins", plugin)

    # Return the mappings so the result matches the declared ``List`` return type.
    return plugins
237+
238+
221239
def _process_api_request(app, request: _APIRequest):
222240
flow = app.get_component_by_name(request.name)
223241
method = getattr(flow, request.method_name)

0 commit comments

Comments
 (0)