|
| 1 | +""" |
| 2 | +HTTP Server to route message requests to message producer function. |
| 3 | +""" |
| 4 | + |
| 5 | +from __future__ import annotations |
| 6 | + |
| 7 | +import logging |
| 8 | +import re |
| 9 | +import signal |
| 10 | +import socket |
| 11 | +import subprocess |
| 12 | +import sys |
| 13 | +import time |
| 14 | +from contextlib import closing, contextmanager |
| 15 | +from importlib import import_module |
| 16 | +from pathlib import Path |
| 17 | +from threading import Thread |
| 18 | +from typing import Generator, NoReturn, Tuple, Union |
| 19 | + |
| 20 | +import requests |
| 21 | + |
| 22 | +sys.path.append(str(Path(__file__).parent.parent.parent.parent)) |
| 23 | + |
| 24 | +import flask |
| 25 | +from yarl import URL |
| 26 | + |
| 27 | +logger = logging.getLogger(__name__) |
| 28 | + |
| 29 | + |
| 30 | +class Provider: |
| 31 | + """ |
| 32 | + Provider class to route message requests to message producer function. |
| 33 | +
|
| 34 | + Sets up three endpoints: |
| 35 | + - /_test/ping: A simple ping endpoint for testing. |
| 36 | + - /produce_message: Route message requests to the handler function. |
| 37 | + - /set_provider_state: Set the provider state. |
| 38 | +
|
| 39 | + The specific `produce_message` and `set_provider_state` URLs can be configured |
| 40 | + with the `produce_message_url` and `set_provider_state_url` arguments. |
| 41 | + """ |
| 42 | + |
| 43 | + def __init__( # noqa: PLR0913 |
| 44 | + self, |
| 45 | + handler_module: str, |
| 46 | + handler_function: str, |
| 47 | + produce_message_url: str, |
| 48 | + state_provider_module: str, |
| 49 | + state_provider_function: str, |
| 50 | + set_provider_state_url: str, |
| 51 | + ) -> None: |
| 52 | + """ |
| 53 | + Initialize the provider. |
| 54 | +
|
| 55 | + Args: |
| 56 | + handler_module: |
| 57 | + The name of the module containing the handler function. |
| 58 | + handler_function: |
| 59 | + The name of the handler function. |
| 60 | + produce_message_url: |
| 61 | + The URL to route message requests to the handler function. |
| 62 | + state_provider_module: |
| 63 | + The name of the module containing the state provider setup function. |
| 64 | + state_provider_function: |
| 65 | + The name of the state provider setup function. |
| 66 | + set_provider_state_url: |
| 67 | + The URL to set the provider state. |
| 68 | + """ |
| 69 | + self.app = flask.Flask("Provider") |
| 70 | + self.handler_function = getattr( |
| 71 | + import_module(handler_module), handler_function |
| 72 | + ) |
| 73 | + self.produce_message_url = produce_message_url |
| 74 | + self.set_provider_state_url = set_provider_state_url |
| 75 | + if (state_provider_module): |
| 76 | + self.state_provider_function = getattr( |
| 77 | + import_module(state_provider_module), |
| 78 | + state_provider_function |
| 79 | + ) |
| 80 | + |
| 81 | + @self.app.get("/_test/ping") |
| 82 | + def ping() -> str: |
| 83 | + """Simple ping endpoint for testing.""" |
| 84 | + return "pong" |
| 85 | + |
| 86 | + @self.app.route(self.produce_message_url, methods=["POST"]) |
| 87 | + def produce_message() -> Union[str, Tuple[str, int]]: |
| 88 | + """ |
| 89 | + Route a message request to the handler function. |
| 90 | +
|
| 91 | + Returns: |
| 92 | + The response from the handler function. |
| 93 | + """ |
| 94 | + try: |
| 95 | + body, content_type = self.handler_function() |
| 96 | + return flask.Response( |
| 97 | + response=body, |
| 98 | + status=200, |
| 99 | + content_type=content_type, |
| 100 | + direct_passthrough=True, |
| 101 | + ) |
| 102 | + except Exception as e: # noqa: BLE001 |
| 103 | + return str(e), 500 |
| 104 | + |
| 105 | + @self.app.route(self.set_provider_state_url, methods=["POST"]) |
| 106 | + def set_provider_state() -> Tuple[str, int]: |
| 107 | + """ |
| 108 | + Calls the state provider function with the state provided in the request. |
| 109 | +
|
| 110 | + Returns: |
| 111 | + A response indicating that the state has been set. |
| 112 | + """ |
| 113 | + if self.state_provider_function: |
| 114 | + self.state_provider_function(flask.request.args["state"]) |
| 115 | + return "Provider state set", 200 |
| 116 | + |
| 117 | + def _find_free_port(self) -> int: |
| 118 | + """ |
| 119 | + Find a free port. |
| 120 | +
|
| 121 | + This is used to find a free port to host the API on when running locally. It |
| 122 | + is allocated, and then released immediately so that it can be used by the |
| 123 | + API. |
| 124 | +
|
| 125 | + Returns: |
| 126 | + The port number. |
| 127 | + """ |
| 128 | + with closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as s: |
| 129 | + s.bind(("", 0)) |
| 130 | + s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) |
| 131 | + return s.getsockname()[1] |
| 132 | + |
| 133 | + def run(self) -> None: |
| 134 | + """ |
| 135 | + Start the provider. |
| 136 | + """ |
| 137 | + url = URL(f"http://localhost:{self._find_free_port()}") |
| 138 | + sys.stderr.write(f"Starting provider on {url}\n") |
| 139 | + |
| 140 | + self.app.run( |
| 141 | + host=url.host, |
| 142 | + port=url.port, |
| 143 | + debug=True, |
| 144 | + ) |
| 145 | + |
| 146 | + |
| 147 | +def start_provider(**kwargs: dict([str, str])) -> Generator[URL, None, None]: # noqa: C901 |
| 148 | + """ |
| 149 | + Start the provider app. |
| 150 | +
|
| 151 | + Expects kwargs to to contain the following: |
| 152 | + handler_module: Required. The name of the module containing |
| 153 | + the handler function. |
| 154 | + handler_function: Required. The name of the handler function. |
| 155 | + produce_message_url: Optional. The URL to route message requests to |
| 156 | + the handler function. |
| 157 | + state_provider_module: Optional. The name of the module containing |
| 158 | + the state provider setup function. |
| 159 | + state_provider_function: Optional. The name of the state provider |
| 160 | + setup function. |
| 161 | + set_provider_state_url: Optional. The URL to set the provider state. |
| 162 | + """ |
| 163 | + process = subprocess.Popen( |
| 164 | + [ # noqa: S603 |
| 165 | + sys.executable, |
| 166 | + Path(__file__), |
| 167 | + kwargs.pop("handler_module"), |
| 168 | + kwargs.pop("handler_function"), |
| 169 | + kwargs.pop("produce_message_url", "/produce_message"), |
| 170 | + kwargs.pop("state_provider_module", ""), |
| 171 | + kwargs.pop("state_provider_function", ""), |
| 172 | + kwargs.pop("set_provider_state_url", "/set_provider_state"), |
| 173 | + ], |
| 174 | + cwd=Path.cwd(), |
| 175 | + stdout=subprocess.PIPE, |
| 176 | + stderr=subprocess.PIPE, |
| 177 | + encoding="utf-8", |
| 178 | + ) |
| 179 | + |
| 180 | + pattern = re.compile(r" \* Running on (?P<url>[^ ]+)") |
| 181 | + while True: |
| 182 | + if process.poll() is not None: |
| 183 | + logger.error("Provider process exited with code %d", process.returncode) |
| 184 | + logger.error( |
| 185 | + "Provider stdout: %s", process.stdout.read() if process.stdout else "" |
| 186 | + ) |
| 187 | + logger.error( |
| 188 | + "Provider stderr: %s", process.stderr.read() if process.stderr else "" |
| 189 | + ) |
| 190 | + msg = f"Provider process exited with code {process.returncode}" |
| 191 | + raise RuntimeError(msg) |
| 192 | + if ( |
| 193 | + process.stderr |
| 194 | + and (line := process.stderr.readline()) |
| 195 | + and (match := pattern.match(line)) |
| 196 | + ): |
| 197 | + break |
| 198 | + time.sleep(0.1) |
| 199 | + |
| 200 | + url = URL(match.group("url")) |
| 201 | + logger.debug("Provider started on %s", url) |
| 202 | + for _ in range(50): |
| 203 | + try: |
| 204 | + response = requests.get(str(url / "_test" / "ping"), timeout=1) |
| 205 | + assert response.text == "pong" |
| 206 | + break |
| 207 | + except (requests.RequestException, AssertionError): |
| 208 | + time.sleep(0.1) |
| 209 | + continue |
| 210 | + else: |
| 211 | + msg = "Failed to ping provider" |
| 212 | + raise RuntimeError(msg) |
| 213 | + |
| 214 | + def redirect() -> NoReturn: |
| 215 | + while True: |
| 216 | + if process.stdout: |
| 217 | + while line := process.stdout.readline(): |
| 218 | + logger.debug("Provider stdout: %s", line.strip()) |
| 219 | + if process.stderr: |
| 220 | + while line := process.stderr.readline(): |
| 221 | + logger.debug("Provider stderr: %s", line.strip()) |
| 222 | + |
| 223 | + thread = Thread(target=redirect, daemon=True) |
| 224 | + thread.start() |
| 225 | + |
| 226 | + try: |
| 227 | + yield url |
| 228 | + finally: |
| 229 | + process.send_signal(signal.SIGINT) |
| 230 | + |
| 231 | + |
| 232 | +start_provider_context = contextmanager(start_provider) |
| 233 | + |
| 234 | +if __name__ == "__main__": |
| 235 | + import sys |
| 236 | + |
| 237 | + if len(sys.argv) < 5: # noqa: PLR2004 |
| 238 | + sys.stderr.write( |
| 239 | + f"Usage: {sys.argv[0]} <state_provider_module> <state_provider_function> " |
| 240 | + f"<handler_module> <handler_function>" |
| 241 | + ) |
| 242 | + sys.exit(1) |
| 243 | + |
| 244 | + handler_module = sys.argv[1] |
| 245 | + handler_function = sys.argv[2] |
| 246 | + produce_message_url = sys.argv[3] |
| 247 | + state_provider_module = sys.argv[4] |
| 248 | + state_provider_function = sys.argv[5] |
| 249 | + set_provider_state_url = sys.argv[6] |
| 250 | + Provider( |
| 251 | + handler_module, |
| 252 | + handler_function, |
| 253 | + produce_message_url, |
| 254 | + state_provider_module, |
| 255 | + state_provider_function, |
| 256 | + set_provider_state_url, |
| 257 | + ).run() |
0 commit comments