|
| 1 | +# Copyright The Lightning team. |
| 2 | +# |
| 3 | +# Licensed under the Apache License, Version 2.0 (the "License"); |
| 4 | +# you may not use this file except in compliance with the License. |
| 5 | +# You may obtain a copy of the License at |
| 6 | +# |
| 7 | +# http://www.apache.org/licenses/LICENSE-2.0 |
| 8 | +# |
| 9 | +# Unless required by applicable law or agreed to in writing, software |
| 10 | +# distributed under the License is distributed on an "AS IS" BASIS, |
| 11 | +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 12 | +# See the License for the specific language governing permissions and |
| 13 | +# limitations under the License. |
| 14 | +import os |
| 15 | +import tempfile |
| 16 | +from pathlib import Path |
| 17 | +from typing import Any, Dict, Optional |
| 18 | + |
| 19 | +import requests |
| 20 | +import uvicorn |
| 21 | +from fastapi import FastAPI, HTTPException, status |
| 22 | +from fastapi.middleware.cors import CORSMiddleware |
| 23 | +from pydantic import BaseModel |
| 24 | + |
| 25 | +from lightning_app.utilities.app_helpers import Logger |
| 26 | +from lightning_app.utilities.cloud import _get_project |
| 27 | +from lightning_app.utilities.component import _set_flow_context |
| 28 | +from lightning_app.utilities.enum import AppStage |
| 29 | +from lightning_app.utilities.network import LightningClient |
| 30 | + |
| 31 | +logger = Logger(__name__) |
| 32 | + |
| 33 | + |
| 34 | +class Plugin: |
| 35 | + """A ``Plugin`` is a single-file Python class that can be executed within a cloudspace to perform actions.""" |
| 36 | + |
| 37 | + def __init__(self) -> None: |
| 38 | + self.app_url = None |
| 39 | + |
| 40 | + def run(self, name: str, entrypoint: str) -> None: |
| 41 | + """Override with the logic to execute on the client side.""" |
| 42 | + |
| 43 | + def run_app_command(self, command_name: str, config: Optional[BaseModel] = None) -> Dict[str, Any]: |
| 44 | + """Run a command on the app associated with this plugin. |
| 45 | +
|
| 46 | + Args: |
| 47 | + command_name: The name of the command to run. |
| 48 | + config: The command config or ``None`` if the command doesn't require configuration. |
| 49 | + """ |
| 50 | + if self.app_url is None: |
| 51 | + raise RuntimeError("The plugin must be set up before `run_app_command` can be called.") |
| 52 | + |
| 53 | + command = command_name.replace(" ", "_") |
| 54 | + resp = requests.post(self.app_url + f"/command/{command}", data=config.json() if config else None) |
| 55 | + if resp.status_code != 200: |
| 56 | + try: |
| 57 | + detail = str(resp.json()) |
| 58 | + except Exception: |
| 59 | + detail = "Internal Server Error" |
| 60 | + raise RuntimeError(f"Failed with status code {resp.status_code}. Detail: {detail}") |
| 61 | + |
| 62 | + return resp.json() |
| 63 | + |
| 64 | + def _setup(self, app_id: str) -> None: |
| 65 | + client = LightningClient() |
| 66 | + project_id = _get_project(client).project_id |
| 67 | + response = client.lightningapp_instance_service_list_lightningapp_instances( |
| 68 | + project_id=project_id, app_id=app_id |
| 69 | + ) |
| 70 | + if len(response.lightningapps) > 1: |
| 71 | + raise RuntimeError(f"Found multiple apps with ID: {app_id}") |
| 72 | + if len(response.lightningapps) == 0: |
| 73 | + raise RuntimeError(f"Found no apps with ID: {app_id}") |
| 74 | + self.app_url = response.lightningapps[0].status.url |
| 75 | + |
| 76 | + |
| 77 | +class _Run(BaseModel): |
| 78 | + plugin_name: str |
| 79 | + project_id: str |
| 80 | + cloudspace_id: str |
| 81 | + name: str |
| 82 | + entrypoint: str |
| 83 | + cluster_id: Optional[str] = None |
| 84 | + app_id: Optional[str] = None |
| 85 | + |
| 86 | + |
| 87 | +def _run_plugin(run: _Run) -> None: |
| 88 | + """Create a run with the given name and entrypoint under the cloudspace with the given ID.""" |
| 89 | + if run.app_id is None and run.plugin_name == "app": |
| 90 | + from lightning_app.runners.cloud import CloudRuntime |
| 91 | + |
| 92 | + # TODO: App dispatch should be a plugin |
| 93 | + # Dispatch the run |
| 94 | + _set_flow_context() |
| 95 | + |
| 96 | + entrypoint_file = Path("/content") / run.entrypoint |
| 97 | + |
| 98 | + app = CloudRuntime.load_app_from_file(str(entrypoint_file.resolve().absolute())) |
| 99 | + |
| 100 | + app.stage = AppStage.BLOCKING |
| 101 | + |
| 102 | + runtime = CloudRuntime( |
| 103 | + app=app, |
| 104 | + entrypoint=entrypoint_file, |
| 105 | + start_server=True, |
| 106 | + env_vars={}, |
| 107 | + secrets={}, |
| 108 | + run_app_comment_commands=True, |
| 109 | + ) |
| 110 | + # Used to indicate Lightning has been dispatched |
| 111 | + os.environ["LIGHTNING_DISPATCHED"] = "1" |
| 112 | + |
| 113 | + try: |
| 114 | + runtime.cloudspace_dispatch( |
| 115 | + project_id=run.project_id, |
| 116 | + cloudspace_id=run.cloudspace_id, |
| 117 | + name=run.name, |
| 118 | + cluster_id=run.cluster_id, |
| 119 | + ) |
| 120 | + except Exception as e: |
| 121 | + raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=str(e)) |
| 122 | + elif run.app_id is not None: |
| 123 | + from lightning_app.utilities.cli_helpers import _LightningAppOpenAPIRetriever |
| 124 | + from lightning_app.utilities.commands.base import _download_command |
| 125 | + |
| 126 | + retriever = _LightningAppOpenAPIRetriever(run.app_id) |
| 127 | + |
| 128 | + metadata = retriever.api_commands[run.plugin_name] # type: ignore |
| 129 | + |
| 130 | + with tempfile.TemporaryDirectory() as tmpdir: |
| 131 | + |
| 132 | + target_file = os.path.join(tmpdir, f"{run.plugin_name}.py") |
| 133 | + plugin = _download_command( |
| 134 | + run.plugin_name, |
| 135 | + metadata["cls_path"], |
| 136 | + metadata["cls_name"], |
| 137 | + run.app_id, |
| 138 | + target_file=target_file, |
| 139 | + ) |
| 140 | + |
| 141 | + if isinstance(plugin, Plugin): |
| 142 | + plugin._setup(app_id=run.app_id) |
| 143 | + plugin.run(run.name, run.entrypoint) |
| 144 | + else: |
| 145 | + # This should never be possible but we check just in case |
| 146 | + raise HTTPException( |
| 147 | + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, |
| 148 | + detail=f"The plugin {run.plugin_name} is an incorrect type.", |
| 149 | + ) |
| 150 | + else: |
| 151 | + raise HTTPException( |
| 152 | + status_code=status.HTTP_400_BAD_REQUEST, detail="App ID must be specified unless `plugin_name='app'`." |
| 153 | + ) |
| 154 | + |
| 155 | + |
| 156 | +def _start_plugin_server(host: str, port: int) -> None: |
| 157 | + """Start the plugin server which can be used to dispatch apps or run plugins.""" |
| 158 | + fastapi_service = FastAPI() |
| 159 | + |
| 160 | + fastapi_service.add_middleware( |
| 161 | + CORSMiddleware, |
| 162 | + allow_origins=["*"], |
| 163 | + allow_credentials=True, |
| 164 | + allow_methods=["*"], |
| 165 | + allow_headers=["*"], |
| 166 | + ) |
| 167 | + |
| 168 | + fastapi_service.post("/v1/runs")(_run_plugin) |
| 169 | + |
| 170 | + uvicorn.run(app=fastapi_service, host=host, port=port, log_level="error") |
0 commit comments