Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 17 additions & 0 deletions src/lightning/app/core/flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@

from deepdiff import DeepHash

from lightning.app.core.plugin import Plugin
from lightning.app.core.work import LightningWork
from lightning.app.frontend import Frontend
from lightning.app.storage import Path
Expand Down Expand Up @@ -740,6 +741,22 @@ def configure_api(self):
"""
raise NotImplementedError

def configure_plugins(self) -> Optional[List[Dict[str, Plugin]]]:
"""Configure the plugins of this LightningFlow.

Returns a list of dictionaries mapping a plugin name to a :class:`lightning_app.core.plugin.Plugin`.

.. code-block:: python

class Flow(LightningFlow):
def __init__(self):
super().__init__()

def configure_plugins(self):
return [{"my_plugin_name": MyPlugin()}]
"""
pass

def state_dict(self):
"""Returns the current flow state but not its children."""
return {
Expand Down
170 changes: 170 additions & 0 deletions src/lightning/app/core/plugin.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,170 @@
# Copyright The Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import tempfile
from pathlib import Path
from typing import Any, Dict, Optional

import requests
import uvicorn
from fastapi import FastAPI, HTTPException, status
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel

from lightning.app.utilities.app_helpers import Logger
from lightning.app.utilities.cloud import _get_project
from lightning.app.utilities.component import _set_flow_context
from lightning.app.utilities.enum import AppStage
from lightning.app.utilities.network import LightningClient

logger = Logger(__name__)


class Plugin:
"""A ``Plugin`` is a single-file Python class that can be executed within a cloudspace to perform actions."""

def __init__(self) -> None:
self.app_url = None

def run(self, name: str, entrypoint: str) -> None:
"""Override with the logic to execute on the client side."""

def run_app_command(self, command_name: str, config: Optional[BaseModel] = None) -> Dict[str, Any]:
"""Run a command on the app associated with this plugin.

Args:
command_name: The name of the command to run.
config: The command config or ``None`` if the command doesn't require configuration.
"""
if self.app_url is None:
raise RuntimeError("The plugin must be set up before `run_app_command` can be called.")

command = command_name.replace(" ", "_")
resp = requests.post(self.app_url + f"/command/{command}", data=config.json() if config else None)
if resp.status_code != 200:
try:
detail = str(resp.json())
except Exception:
detail = "Internal Server Error"
raise RuntimeError(f"Failed with status code {resp.status_code}. Detail: {detail}")

return resp.json()

def _setup(self, app_id: str) -> None:
client = LightningClient()
project_id = _get_project(client).project_id
response = client.lightningapp_instance_service_list_lightningapp_instances(
project_id=project_id, app_id=app_id
)
if len(response.lightningapps) > 1:
raise RuntimeError(f"Found multiple apps with ID: {app_id}")
if len(response.lightningapps) == 0:
raise RuntimeError(f"Found no apps with ID: {app_id}")
self.app_url = response.lightningapps[0].status.url


class _Run(BaseModel):
plugin_name: str
project_id: str
cloudspace_id: str
name: str
entrypoint: str
cluster_id: Optional[str] = None
app_id: Optional[str] = None


def _run_plugin(run: _Run) -> None:
"""Create a run with the given name and entrypoint under the cloudspace with the given ID."""
if run.app_id is None and run.plugin_name == "app":
from lightning.app.runners.cloud import CloudRuntime

# TODO: App dispatch should be a plugin
# Dispatch the run
_set_flow_context()

entrypoint_file = Path("/content") / run.entrypoint

app = CloudRuntime.load_app_from_file(str(entrypoint_file.resolve().absolute()))

app.stage = AppStage.BLOCKING

runtime = CloudRuntime(
app=app,
entrypoint=entrypoint_file,
start_server=True,
env_vars={},
secrets={},
run_app_comment_commands=True,
)
# Used to indicate Lightning has been dispatched
os.environ["LIGHTNING_DISPATCHED"] = "1"

try:
runtime.cloudspace_dispatch(
project_id=run.project_id,
cloudspace_id=run.cloudspace_id,
name=run.name,
cluster_id=run.cluster_id,
)
except Exception as e:
raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=str(e))
elif run.app_id is not None:
from lightning.app.utilities.cli_helpers import _LightningAppOpenAPIRetriever
from lightning.app.utilities.commands.base import _download_command

retriever = _LightningAppOpenAPIRetriever(run.app_id)

metadata = retriever.api_commands[run.plugin_name] # type: ignore

with tempfile.TemporaryDirectory() as tmpdir:

target_file = os.path.join(tmpdir, f"{run.plugin_name}.py")
plugin = _download_command(
run.plugin_name,
metadata["cls_path"],
metadata["cls_name"],
run.app_id,
target_file=target_file,
)

if isinstance(plugin, Plugin):
plugin._setup(app_id=run.app_id)
plugin.run(run.name, run.entrypoint)
else:
# This should never be possible but we check just in case
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"The plugin {run.plugin_name} is an incorrect type.",
)
else:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST, detail="App ID must be specified unless `plugin_name='app'`."
)


def _start_plugin_server(host: str, port: int) -> None:
"""Start the plugin server which can be used to dispatch apps or run plugins."""
fastapi_service = FastAPI()

fastapi_service.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)

fastapi_service.post("/v1/runs")(_run_plugin)

uvicorn.run(app=fastapi_service, host=host, port=port, log_level="error")
72 changes: 70 additions & 2 deletions src/lightning/app/runners/cloud.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,6 +208,67 @@ def open(self, name: str, cluster_id: Optional[str] = None):
logger.error(e.body)
sys.exit(1)

def cloudspace_dispatch(
self,
project_id: str,
cloudspace_id: str,
name: str,
cluster_id: str = None,
):
"""Slim dispatch for creating runs from a cloudspace. This dispatch avoids resolution of some properties
such as the project and cluster IDs that are instead passed directly.

Args:
project_id: The ID of the project.
cloudspace_id: The ID of the cloudspace.
name: The name for the run.
cluster_id: The ID of the cluster to run on.

Raises:
ApiException: If there was an issue in the backend.
RuntimeError: If there are validation errors.
ValueError: If there are validation errors.
"""
# Dispatch in four phases: resolution, validation, spec creation, API transactions
# Resolution
root = self._resolve_root()
ignore_functions = self._resolve_open_ignore_functions()
repo = self._resolve_repo(root, ignore_functions)
cloudspace = self._resolve_cloudspace(project_id, cloudspace_id)
cluster_id = self._resolve_cluster_id(cluster_id, project_id, [cloudspace])
queue_server_type = self._resolve_queue_server_type()

self.app._update_index_file()

# Validation
# TODO: Validate repo and surface to the user
# self._validate_repo(root, repo)
self._validate_work_build_specs_and_compute()
self._validate_drives()
self._validate_mounts()

# Spec creation
flow_servers = self._get_flow_servers()
network_configs = self._get_network_configs(flow_servers)
works = self._get_works()
run_body = self._get_run_body(cluster_id, flow_servers, network_configs, works, False, root, True)
env_vars = self._get_env_vars(self.env_vars, self.secrets, self.run_app_comment_commands)

# API transactions
run = self._api_create_run(project_id, cloudspace_id, run_body)
self._api_package_and_upload_repo(repo, run)

self._api_create_run_instance(
cluster_id,
project_id,
name,
cloudspace_id,
run.id,
V1LightningappInstanceState.RUNNING,
queue_server_type,
env_vars,
)

def dispatch(
self,
name: str = "",
Expand Down Expand Up @@ -410,6 +471,13 @@ def _resolve_project(self) -> V1Membership:
"""Determine the project to run on, choosing a default if multiple projects are found."""
return _get_project(self.backend.client)

def _resolve_cloudspace(self, project_id: str, cloudspace_id: str) -> V1CloudSpace:
"""Get a cloudspace by project / cloudspace ID."""
return self.backend.client.cloud_space_service_get_cloud_space(
project_id=project_id,
id=cloudspace_id,
)

def _resolve_existing_cloudspaces(self, project, cloudspace_name: str) -> List[V1CloudSpace]:
"""Lists all the cloudspaces with a name matching the provided cloudspace name."""
# TODO: Add pagination, otherwise this could break if users have a lot of cloudspaces.
Expand Down Expand Up @@ -871,7 +939,7 @@ def _api_create_run_instance(
self,
cluster_id: str,
project_id: str,
cloudspace_name: str,
run_name: str,
cloudspace_id: str,
run_id: str,
desired_state: V1LightningappInstanceState,
Expand All @@ -886,7 +954,7 @@ def _api_create_run_instance(
id=run_id,
body=IdGetBody1(
cluster_id=cluster_id,
name=cloudspace_name,
name=run_name,
desired_state=desired_state,
queue_server_type=queue_server_type,
env=env_vars,
Expand Down
34 changes: 26 additions & 8 deletions src/lightning/app/utilities/commands/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@

from lightning.app.api.http_methods import Post
from lightning.app.api.request_types import _APIRequest, _CommandRequest, _RequestResponse
from lightning.app.core.plugin import Plugin
from lightning.app.utilities import frontend
from lightning.app.utilities.app_helpers import is_overridden, Logger
from lightning.app.utilities.cloud import _get_project
Expand Down Expand Up @@ -108,7 +109,7 @@ def _download_command(
app_id: Optional[str] = None,
debug_mode: bool = False,
target_file: Optional[str] = None,
) -> ClientCommand:
) -> Union[ClientCommand, Plugin]:
# TODO: This is a skateboard implementation and the final version will rely on versioned
# immutable commands for security concerns
command_name = command_name.replace(" ", "_")
Expand Down Expand Up @@ -139,7 +140,13 @@ def _download_command(
mod = module_from_spec(spec)
sys.modules[cls_name] = mod
spec.loader.exec_module(mod)
command = getattr(mod, cls_name)(method=None)
command_type = getattr(mod, cls_name)
if issubclass(command_type, ClientCommand):
command = command_type(method=None)
elif issubclass(command_type, Plugin):
command = command_type()
else:
raise ValueError(f"Expected class {cls_name} for command {command_name} to be a `ClientCommand` or `Plugin`.")
if tmpdir and os.path.exists(tmpdir):
shutil.rmtree(tmpdir)
return command
Expand Down Expand Up @@ -182,12 +189,11 @@ def _validate_client_command(command: ClientCommand):
)


def _upload_command(command_name: str, command: ClientCommand) -> Optional[str]:
def _upload(name: str, prefix: str, obj: Any) -> Optional[str]:
from lightning.app.storage.path import _filesystem, _is_s3fs_available, _shared_storage_path

command_name = command_name.replace(" ", "_")
filepath = f"commands/{command_name}.py"
remote_url = str(_shared_storage_path() / "artifacts" / filepath)
name = name.replace(" ", "_")
filepath = f"{prefix}/{name}.py"
fs = _filesystem()

if _is_s3fs_available():
Expand All @@ -196,7 +202,7 @@ def _upload_command(command_name: str, command: ClientCommand) -> Optional[str]:
if not isinstance(fs, S3FileSystem):
return

source_file = str(inspect.getfile(command.__class__))
source_file = str(inspect.getfile(obj.__class__))
remote_url = str(_shared_storage_path() / "artifacts" / filepath)
fs.put(source_file, remote_url)
return filepath
Expand All @@ -211,13 +217,25 @@ def _prepare_commands(app) -> List:
for command_mapping in commands:
for command_name, command in command_mapping.items():
if isinstance(command, ClientCommand):
_upload_command(command_name, command)
_upload(command_name, "commands", command)

# 2: Cache the commands on the app.
app.commands = commands
return commands


def _prepare_plugins(app) -> List:
if not is_overridden("configure_plugins", app.root):
return []

# 1: Upload the plugins to s3.
plugins = app.root.configure_plugins()
for plugin_mapping in plugins:
for plugin_name, plugin in plugin_mapping.items():
if isinstance(plugin, Plugin):
_upload(plugin_name, "plugins", plugin)


def _process_api_request(app, request: _APIRequest):
flow = app.get_component_by_name(request.name)
method = getattr(flow, request.method_name)
Expand Down
Loading