|  | 
|  | 1 | +""" | 
|  | 2 | +HTTP Server to route message requests to message producer function. | 
|  | 3 | +""" | 
|  | 4 | + | 
|  | 5 | +from __future__ import annotations | 
|  | 6 | + | 
|  | 7 | +import logging | 
|  | 8 | +import re | 
|  | 9 | +import signal | 
|  | 10 | +import socket | 
|  | 11 | +import subprocess | 
|  | 12 | +import sys | 
|  | 13 | +import time | 
|  | 14 | +from contextlib import closing, contextmanager | 
|  | 15 | +from importlib import import_module | 
|  | 16 | +from pathlib import Path | 
|  | 17 | +from threading import Thread | 
|  | 18 | +from typing import Generator, NoReturn, Tuple, Union | 
|  | 19 | + | 
|  | 20 | +import requests | 
|  | 21 | + | 
|  | 22 | +sys.path.append(str(Path(__file__).parent.parent.parent)) | 
|  | 23 | + | 
|  | 24 | +import flask | 
|  | 25 | +from yarl import URL | 
|  | 26 | + | 
|  | 27 | +logger = logging.getLogger(__name__) | 
|  | 28 | + | 
|  | 29 | + | 
|  | 30 | +class Provider: | 
|  | 31 | +    """ | 
|  | 32 | +    Provider class to route message requests to message producer function. | 
|  | 33 | +
 | 
|  | 34 | +    Sets up three endpoints: | 
|  | 35 | +        - /_test/ping: A simple ping endpoint for testing. | 
|  | 36 | +        - /produce_message: Route message requests to the handler function. | 
|  | 37 | +        - /set_provider_state: Set the provider state. | 
|  | 38 | +
 | 
|  | 39 | +    The specific `produce_message` and `set_provider_state` URLs can be configured | 
|  | 40 | +    with the `produce_message_url` and `set_provider_state_url` arguments. | 
|  | 41 | +    """ | 
|  | 42 | + | 
|  | 43 | +    def __init__( # noqa: PLR0913 | 
|  | 44 | +        self, | 
|  | 45 | +        handler_module: str, | 
|  | 46 | +        handler_function: str, | 
|  | 47 | +        produce_message_url: str, | 
|  | 48 | +        state_provider_module: str, | 
|  | 49 | +        state_provider_function: str, | 
|  | 50 | +        set_provider_state_url: str, | 
|  | 51 | +    ) -> None: | 
|  | 52 | +        """ | 
|  | 53 | +        Initialize the provider. | 
|  | 54 | +
 | 
|  | 55 | +        Args: | 
|  | 56 | +            handler_module: | 
|  | 57 | +                The name of the module containing the handler function. | 
|  | 58 | +            handler_function: | 
|  | 59 | +                The name of the handler function. | 
|  | 60 | +            produce_message_url: | 
|  | 61 | +                The URL to route message requests to the handler function. | 
|  | 62 | +            state_provider_module: | 
|  | 63 | +                The name of the module containing the state provider setup function. | 
|  | 64 | +            state_provider_function: | 
|  | 65 | +                The name of the state provider setup function. | 
|  | 66 | +            set_provider_state_url: | 
|  | 67 | +                The URL to set the provider state. | 
|  | 68 | +        """ | 
|  | 69 | +        self.app = flask.Flask("Provider") | 
|  | 70 | +        self.handler_function = getattr( | 
|  | 71 | +            import_module(handler_module), handler_function | 
|  | 72 | +        ) | 
|  | 73 | +        self.produce_message_url = produce_message_url | 
|  | 74 | +        self.set_provider_state_url = set_provider_state_url | 
|  | 75 | +        if (state_provider_module): | 
|  | 76 | +            self.state_provider_function = getattr( | 
|  | 77 | +                import_module(state_provider_module), | 
|  | 78 | +                state_provider_function | 
|  | 79 | +            ) | 
|  | 80 | + | 
|  | 81 | +        @self.app.get("/_test/ping") | 
|  | 82 | +        def ping() -> str: | 
|  | 83 | +            """Simple ping endpoint for testing.""" | 
|  | 84 | +            return "pong" | 
|  | 85 | + | 
|  | 86 | +        @self.app.route(self.produce_message_url, methods=["POST"]) | 
|  | 87 | +        def produce_message() -> Union[str, Tuple[str, int]]: | 
|  | 88 | +            """ | 
|  | 89 | +            Route a message request to the handler function. | 
|  | 90 | +
 | 
|  | 91 | +            Returns: | 
|  | 92 | +                The response from the handler function. | 
|  | 93 | +            """ | 
|  | 94 | +            try: | 
|  | 95 | +                body, content_type = self.handler_function() | 
|  | 96 | +                return flask.Response( | 
|  | 97 | +                    response=body, | 
|  | 98 | +                    status=200, | 
|  | 99 | +                    content_type=content_type, | 
|  | 100 | +                    direct_passthrough=True, | 
|  | 101 | +                ) | 
|  | 102 | +            except Exception as e: # noqa: BLE001 | 
|  | 103 | +                return str(e), 500 | 
|  | 104 | + | 
|  | 105 | +        @self.app.route(self.set_provider_state_url, methods=["POST"]) | 
|  | 106 | +        def set_provider_state() -> Tuple[str, int]: | 
|  | 107 | +            """ | 
|  | 108 | +            Calls the state provider function with the state provided in the request. | 
|  | 109 | +
 | 
|  | 110 | +            Returns: | 
|  | 111 | +                A response indicating that the state has been set. | 
|  | 112 | +            """ | 
|  | 113 | +            if self.state_provider_function: | 
|  | 114 | +                self.state_provider_function(flask.request.args["state"]) | 
|  | 115 | +            return "Provider state set", 200 | 
|  | 116 | + | 
|  | 117 | +    def _find_free_port(self) -> int: | 
|  | 118 | +        """ | 
|  | 119 | +        Find a free port. | 
|  | 120 | +
 | 
|  | 121 | +        This is used to find a free port to host the API on when running locally. It | 
|  | 122 | +        is allocated, and then released immediately so that it can be used by the | 
|  | 123 | +        API. | 
|  | 124 | +
 | 
|  | 125 | +        Returns: | 
|  | 126 | +            The port number. | 
|  | 127 | +        """ | 
|  | 128 | +        with closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as s: | 
|  | 129 | +            s.bind(("", 0)) | 
|  | 130 | +            s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) | 
|  | 131 | +            return s.getsockname()[1] | 
|  | 132 | + | 
|  | 133 | +    def run(self) -> None: | 
|  | 134 | +        """ | 
|  | 135 | +        Start the provider. | 
|  | 136 | +        """ | 
|  | 137 | +        url = URL(f"http://localhost:{self._find_free_port()}") | 
|  | 138 | +        sys.stderr.write(f"Starting provider on {url}\n") | 
|  | 139 | + | 
|  | 140 | +        self.app.run( | 
|  | 141 | +            host=url.host, | 
|  | 142 | +            port=url.port, | 
|  | 143 | +            debug=True, | 
|  | 144 | +        ) | 
|  | 145 | + | 
|  | 146 | + | 
|  | 147 | +def start_provider(**kwargs: dict([str, str])) -> Generator[URL, None, None]: # noqa: C901 | 
|  | 148 | +    """ | 
|  | 149 | +    Start the provider app. | 
|  | 150 | +
 | 
|  | 151 | +    Expects kwargs to to contain the following: | 
|  | 152 | +        handler_module: Required. The name of the module containing | 
|  | 153 | +                        the handler function. | 
|  | 154 | +        handler_function: Required. The name of the handler function. | 
|  | 155 | +        produce_message_url: Optional. The URL to route message requests to | 
|  | 156 | +                             the handler function. | 
|  | 157 | +        state_provider_module: Optional. The name of the module containing | 
|  | 158 | +                               the state provider setup function. | 
|  | 159 | +        state_provider_function: Optional. The name of the state provider | 
|  | 160 | +                                 setup function. | 
|  | 161 | +        set_provider_state_url: Optional. The URL to set the provider state. | 
|  | 162 | +    """ | 
|  | 163 | +    process = subprocess.Popen( | 
|  | 164 | +        [  # noqa: S603 | 
|  | 165 | +            sys.executable, | 
|  | 166 | +            Path(__file__), | 
|  | 167 | +            kwargs.pop("handler_module"), | 
|  | 168 | +            kwargs.pop("handler_function"), | 
|  | 169 | +            kwargs.pop("produce_message_url", "/produce_message"), | 
|  | 170 | +            kwargs.pop("state_provider_module", ""), | 
|  | 171 | +            kwargs.pop("state_provider_function", ""), | 
|  | 172 | +            kwargs.pop("set_provider_state_url", "/set_provider_state"), | 
|  | 173 | +        ], | 
|  | 174 | +        cwd=Path.cwd(), | 
|  | 175 | +        stdout=subprocess.PIPE, | 
|  | 176 | +        stderr=subprocess.PIPE, | 
|  | 177 | +        encoding="utf-8", | 
|  | 178 | +    ) | 
|  | 179 | + | 
|  | 180 | +    pattern = re.compile(r" \* Running on (?P<url>[^ ]+)") | 
|  | 181 | +    while True: | 
|  | 182 | +        if process.poll() is not None: | 
|  | 183 | +            logger.error("Provider process exited with code %d", process.returncode) | 
|  | 184 | +            logger.error( | 
|  | 185 | +                "Provider stdout: %s", process.stdout.read() if process.stdout else "" | 
|  | 186 | +            ) | 
|  | 187 | +            logger.error( | 
|  | 188 | +                "Provider stderr: %s", process.stderr.read() if process.stderr else "" | 
|  | 189 | +            ) | 
|  | 190 | +            msg = f"Provider process exited with code {process.returncode}" | 
|  | 191 | +            raise RuntimeError(msg) | 
|  | 192 | +        if ( | 
|  | 193 | +            process.stderr | 
|  | 194 | +            and (line := process.stderr.readline()) | 
|  | 195 | +            and (match := pattern.match(line)) | 
|  | 196 | +        ): | 
|  | 197 | +            break | 
|  | 198 | +        time.sleep(0.1) | 
|  | 199 | + | 
|  | 200 | +    url = URL(match.group("url")) | 
|  | 201 | +    logger.debug("Provider started on %s", url) | 
|  | 202 | +    for _ in range(50): | 
|  | 203 | +        try: | 
|  | 204 | +            response = requests.get(str(url / "_test" / "ping"), timeout=1) | 
|  | 205 | +            assert response.text == "pong" | 
|  | 206 | +            break | 
|  | 207 | +        except (requests.RequestException, AssertionError): | 
|  | 208 | +            time.sleep(0.1) | 
|  | 209 | +            continue | 
|  | 210 | +    else: | 
|  | 211 | +        msg = "Failed to ping provider" | 
|  | 212 | +        raise RuntimeError(msg) | 
|  | 213 | + | 
|  | 214 | +    def redirect() -> NoReturn: | 
|  | 215 | +        while True: | 
|  | 216 | +            if process.stdout: | 
|  | 217 | +                while line := process.stdout.readline(): | 
|  | 218 | +                    logger.debug("Provider stdout: %s", line.strip()) | 
|  | 219 | +            if process.stderr: | 
|  | 220 | +                while line := process.stderr.readline(): | 
|  | 221 | +                    logger.debug("Provider stderr: %s", line.strip()) | 
|  | 222 | + | 
|  | 223 | +    thread = Thread(target=redirect, daemon=True) | 
|  | 224 | +    thread.start() | 
|  | 225 | + | 
|  | 226 | +    try: | 
|  | 227 | +        yield url | 
|  | 228 | +    finally: | 
|  | 229 | +        process.send_signal(signal.SIGINT) | 
|  | 230 | + | 
|  | 231 | + | 
|  | 232 | +start_provider_context = contextmanager(start_provider) | 
|  | 233 | + | 
|  | 234 | +if __name__ == "__main__": | 
|  | 235 | +    import sys | 
|  | 236 | + | 
|  | 237 | +    if len(sys.argv) < 5: # noqa: PLR2004 | 
|  | 238 | +        sys.stderr.write( | 
|  | 239 | +            f"Usage: {sys.argv[0]} <state_provider_module> <state_provider_function> " | 
|  | 240 | +            f"<handler_module> <handler_function>" | 
|  | 241 | +        ) | 
|  | 242 | +        sys.exit(1) | 
|  | 243 | + | 
|  | 244 | +    handler_module = sys.argv[1] | 
|  | 245 | +    handler_function = sys.argv[2] | 
|  | 246 | +    produce_message_url = sys.argv[3] | 
|  | 247 | +    state_provider_module = sys.argv[4] | 
|  | 248 | +    state_provider_function = sys.argv[5] | 
|  | 249 | +    set_provider_state_url = sys.argv[6] | 
|  | 250 | +    Provider( | 
|  | 251 | +        handler_module, | 
|  | 252 | +        handler_function, | 
|  | 253 | +        produce_message_url, | 
|  | 254 | +        state_provider_module, | 
|  | 255 | +        state_provider_function, | 
|  | 256 | +        set_provider_state_url, | 
|  | 257 | +    ).run() | 
0 commit comments