3535 wait_exponential ,
3636)
3737
38- from graphdatascience .query_runner .arrow_authentication import ArrowAuthentication
38+ from graphdatascience .query_runner .arrow_authentication import ArrowAuthentication , UsernamePasswordAuthentication
3939from graphdatascience .retry_utils .retry_config import RetryConfig
4040from graphdatascience .retry_utils .retry_utils import before_log
4141
@@ -49,7 +49,7 @@ class GdsArrowClient:
4949 @staticmethod
5050 def create (
5151 arrow_info : ArrowInfo ,
52- arrow_authentication : Optional [ArrowAuthentication ] = None ,
52+ auth : Optional [Union [ ArrowAuthentication , tuple [ str , str ]] ] = None ,
5353 encrypted : bool = False ,
5454 disable_server_verification : bool = False ,
5555 tls_root_certs : Optional [bytes ] = None ,
@@ -81,7 +81,7 @@ def create(
8181 host ,
8282 retry_config ,
8383 int (port ),
84- arrow_authentication ,
84+ auth ,
8585 encrypted ,
8686 disable_server_verification ,
8787 tls_root_certs ,
@@ -93,7 +93,7 @@ def __init__(
9393 host : str ,
9494 retry_config : RetryConfig ,
9595 port : int = 8491 ,
96- auth : Optional [ArrowAuthentication ] = None ,
96+ auth : Optional [Union [ ArrowAuthentication , tuple [ str , str ]] ] = None ,
9797 encrypted : bool = False ,
9898 disable_server_verification : bool = False ,
9999 tls_root_certs : Optional [bytes ] = None ,
@@ -108,8 +108,8 @@ def __init__(
108108 The host address of the GDS Arrow server
109109 port: int
110110 The host port of the GDS Arrow server (default is 8491)
111- auth: Optional[ArrowAuthentication]
112- An implementation of ArrowAuthentication providing a pair to be used for basic authentication
111+ auth: Optional[Union[ ArrowAuthentication, tuple[str, str]] ]
112+ Either an implementation of ArrowAuthentication providing a pair to be used for basic authentication, or a username, password tuple
113113 encrypted: bool
114114 A flag that indicates whether the connection should be encrypted (default is False)
115115 disable_server_verification: bool
@@ -126,7 +126,7 @@ def __init__(
126126 self ._arrow_endpoint_version = arrow_endpoint_version
127127 self ._host = host
128128 self ._port = port
129- self ._auth = auth
129+ self ._auth = None
130130 self ._encrypted = encrypted
131131 self ._disable_server_verification = disable_server_verification
132132 self ._tls_root_certs = tls_root_certs
@@ -135,6 +135,10 @@ def __init__(
135135 self ._logger = logging .getLogger ("gds_arrow_client" )
136136
137137 if auth :
138+ if not isinstance (auth , ArrowAuthentication ):
139+ username , password = auth
140+ auth = UsernamePasswordAuthentication (username , password )
141+ self ._auth = auth
138142 self ._auth_middleware = AuthMiddleware (auth )
139143
140144 self ._flight_client = self ._instantiate_flight_client ()
0 commit comments