44import pathlib
55from dataclasses import dataclass
66from enum import Enum
7+ from time import sleep
78from typing import Optional
89from urllib .parse import urlencode
910
@@ -44,9 +45,11 @@ def __post_init__(self):
4445 setattr (self , key .suffix , os .environ .get (key .value , None ))
4546
4647 self ._with_env_var = bool (self .user_id and self .api_key ) # used by authenticate method
47- if self .api_key and not self .user_id :
48+ if self ._with_env_var :
49+ self .save ("" , self .user_id , self .api_key , self .user_id )
50+ logger .info ("Credentials loaded from environment variables" )
51+ elif self .api_key or self .user_id :
4852 raise ValueError (
49- f"{ Keys .USER_ID .value } is missing from env variables. "
5053 "To use env vars for authentication both "
5154 f"{ Keys .USER_ID .value } and { Keys .API_KEY .value } should be set."
5255 )
@@ -135,7 +138,8 @@ def authenticate(self) -> Optional[str]:
135138
136139
137140class AuthServer :
138- def get_auth_url (self , port : int ) -> str :
141+ @staticmethod
142+ def get_auth_url (port : int ) -> str :
139143 redirect_uri = f"http://localhost:{ port } /login-complete"
140144 params = urlencode (dict (redirectTo = redirect_uri ))
141145 return f"{ get_lightning_cloud_url ()} /sign-in?{ params } "
@@ -144,6 +148,7 @@ def login_with_browser(self, auth: Auth) -> None:
144148 app = FastAPI ()
145149 port = find_free_network_port ()
146150 url = self .get_auth_url (port )
151+
147152 try :
148153 # check if server is reachable or catch any network errors
149154 requests .head (url )
@@ -156,32 +161,42 @@ def login_with_browser(self, auth: Auth) -> None:
156161 f"An error occurred with the request. Please report this issue to Lightning Team \n { e } " # E501
157162 )
158163
159- logger .info (f"login started for lightning.ai, opening { url } " )
164+ logger .info (
165+ "\n Attempting to automatically open the login page in your default browser.\n "
166+ 'If the browser does not open, navigate to the "Keys" tab on your Lightning AI profile page:\n \n '
167+ f"{ get_lightning_cloud_url ()} /me/keys\n \n "
168+ 'Copy the "Headless CLI Login" command, and execute it in your terminal.\n '
169+ )
160170 click .launch (url )
161171
162172 @app .get ("/login-complete" )
163173 async def save_token (request : Request , token = "" , key = "" , user_id : str = Query ("" , alias = "userID" )):
164- if token :
165- auth .save (token = token , username = user_id , user_id = user_id , api_key = key )
166- logger .info ("Authentication Successful" )
167- else :
174+ async def stop_server_once_request_is_done ():
175+ while not await request .is_disconnected ():
176+ sleep (0.25 )
177+ server .should_exit = True
178+
179+ if not token :
168180 logger .warn (
169- "Authentication Failed. This is most likely because you're using an older version of the CLI. \n " # noqa E501
181+ "Login Failed. This is most likely because you're using an older version of the CLI. \n " # noqa E501
170182 "Please try to update the CLI or open an issue with this information \n " # E501
171183 f"expected token in { request .query_params .items ()} "
172184 )
185+ return RedirectResponse (
186+ url = f"{ get_lightning_cloud_url ()} /cli-login-failed" ,
187+ background = BackgroundTask (stop_server_once_request_is_done ),
188+ )
189+
190+ auth .save (token = token , username = user_id , user_id = user_id , api_key = key )
191+ logger .info ("Login Successful" )
173192
174193 # Include the credentials in the redirect so that UI will also be logged in
175194 params = urlencode (dict (token = token , key = key , userID = user_id ))
176195
177196 return RedirectResponse (
178- url = f"{ get_lightning_cloud_url ()} /me/apps?{ params } " ,
179- # The response background task is being executed right after the server finished writing the response
180- background = BackgroundTask (stop_server ),
197+ url = f"{ get_lightning_cloud_url ()} /cli-login-successful?{ params } " ,
198+ background = BackgroundTask (stop_server_once_request_is_done ),
181199 )
182200
183- def stop_server ():
184- server .should_exit = True
185-
186201 server = uvicorn .Server (config = uvicorn .Config (app , port = port , log_level = "error" ))
187202 server .run ()
0 commit comments