22import inspect
33import time
44from copy import deepcopy
5+ from dataclasses import dataclass
56from functools import wraps
67from multiprocessing import Queue
78from typing import Any , Callable , Dict , List , Optional
89from uuid import uuid4
910
10- from fastapi import FastAPI , HTTPException
11+ from fastapi import FastAPI , HTTPException , Request
12+ from lightning_utilities .core .apply_func import apply_to_collection
1113
1214from lightning_app .api .request_types import _APIRequest , _CommandRequest , _RequestResponse
1315from lightning_app .utilities .app_helpers import Logger
@@ -19,6 +21,77 @@ def _signature_proxy_function():
1921 pass
2022
2123
24+ @dataclass
25+ class _FastApiMockRequest :
26+ """This class is meant to mock FastAPI Request class that isn't pickle-able.
27+
28+ If a user relies on FastAPI Request annotation, the Lightning framework
29+ patches the annotation before pickling and replace them right after.
30+
31+ Finally, the FastAPI request is converted back to the _FastApiMockRequest
32+ before being delivered to the users.
33+
34+ Example:
35+
36+ import lightning as L
37+ from fastapi import Request
38+ from lightning.app.api import Post
39+
40+ class Flow(L.LightningFlow):
41+
42+ def request(self, request: Request) -> OutputRequestModel:
43+ ...
44+
45+ def configure_api(self):
46+ return [Post("/api/v1/request", self.request)]
47+ """
48+
49+ _body : Optional [str ] = None
50+ _json : Optional [str ] = None
51+ _method : Optional [str ] = None
52+ _headers : Optional [Dict ] = None
53+
54+ @property
55+ def receive (self ):
56+ raise NotImplementedError
57+
58+ @property
59+ def method (self ):
60+ raise self ._method
61+
62+ @property
63+ def headers (self ):
64+ return self ._headers
65+
66+ def body (self ):
67+ return self ._body
68+
69+ def json (self ):
70+ return self ._json
71+
72+ def stream (self ):
73+ raise NotImplementedError
74+
75+ def form (self ):
76+ raise NotImplementedError
77+
78+ def close (self ):
79+ raise NotImplementedError
80+
81+ def is_disconnected (self ):
82+ raise NotImplementedError
83+
84+
85+ async def _mock_fastapi_request (request : Request ):
86+ # TODO: Add more requests parameters.
87+ return _FastApiMockRequest (
88+ _body = await request .body (),
89+ _json = await request .json (),
90+ _headers = request .headers ,
91+ _method = request .method ,
92+ )
93+
94+
2295class _HttpMethod :
2396 def __init__ (self , route : str , method : Callable , method_name : Optional [str ] = None , timeout : int = 30 , ** kwargs ):
2497 """This class is used to inject user defined methods within the App Rest API.
@@ -34,6 +107,7 @@ def __init__(self, route: str, method: Callable, method_name: Optional[str] = No
34107 self .method_annotations = method .__annotations__
35108 # TODO: Validate the signature contains only pydantic models.
36109 self .method_signature = inspect .signature (method )
110+
37111 if not self .attached_to_flow :
38112 self .component_name = method .__name__
39113 self .method = method
@@ -43,10 +117,16 @@ def __init__(self, route: str, method: Callable, method_name: Optional[str] = No
43117 self .timeout = timeout
44118 self .kwargs = kwargs
45119
120+ # Enable the users to rely on FastAPI annotation typing with Request.
121+ # Note: Only a part of the Request functionatilities are supported.
122+ self ._patch_fast_api_request ()
123+
46124 def add_route (self , app : FastAPI , request_queue : Queue , responses_store : Dict [str , Any ]) -> None :
47125 # 1: Get the route associated with the http method.
48126 route = getattr (app , self .__class__ .__name__ .lower ())
49127
128+ self ._unpatch_fast_api_request ()
129+
50130 # 2: Create a proxy function with the signature of the wrapped method.
51131 fn = deepcopy (_signature_proxy_function )
52132 fn .__annotations__ = self .method_annotations
@@ -69,6 +149,11 @@ async def _handle_request(*args, **kwargs):
69149 @wraps (_signature_proxy_function )
70150 async def _handle_request (* args , ** kwargs ):
71151 async def fn (* args , ** kwargs ):
152+ args , kwargs = apply_to_collection ((args , kwargs ), Request , _mock_fastapi_request )
153+ for k , v in kwargs .items ():
154+ if hasattr (v , "__await__" ):
155+ kwargs [k ] = await v
156+
72157 request_id = str (uuid4 ()).split ("-" )[0 ]
73158 logger .debug (f"Processing request { request_id } for route: { self .route } " )
74159 request_queue .put (
@@ -101,6 +186,26 @@ async def fn(*args, **kwargs):
101186 # 4: Register the user provided route to the Rest API.
102187 route (self .route , ** self .kwargs )(_handle_request )
103188
189+ def _patch_fast_api_request (self ):
190+ """This function replaces signature annotation for Request with its mock."""
191+ for k , v in self .method_annotations .items ():
192+ if v == Request :
193+ self .method_annotations [k ] = _FastApiMockRequest
194+
195+ for v in self .method_signature .parameters .values ():
196+ if v ._annotation == Request :
197+ v ._annotation = _FastApiMockRequest
198+
199+ def _unpatch_fast_api_request (self ):
200+ """This function replaces back signature annotation to fastapi Request."""
201+ for k , v in self .method_annotations .items ():
202+ if v == _FastApiMockRequest :
203+ self .method_annotations [k ] = Request
204+
205+ for v in self .method_signature .parameters .values ():
206+ if v ._annotation == _FastApiMockRequest :
207+ v ._annotation = Request
208+
104209
105210class Post (_HttpMethod ):
106211 pass
0 commit comments