diff --git a/sqlalchemy_example.py b/sqlalchemy_example.py index d76716f..ee2a2ce 100644 --- a/sqlalchemy_example.py +++ b/sqlalchemy_example.py @@ -53,7 +53,7 @@ # See src/databricks/sql/thrift_backend.py for complete list extra_connect_args = { "_tls_verify_hostname": True, - "_user_agent_entry": "PySQL Example Script", + "user_agent_entry": "PySQL Example Script", } diff --git a/src/databricks/sqlalchemy/base.py b/src/databricks/sqlalchemy/base.py index 9148de7..fba05fc 100644 --- a/src/databricks/sqlalchemy/base.py +++ b/src/databricks/sqlalchemy/base.py @@ -409,21 +409,12 @@ def receive_do_connect(dialect, conn_rec, cargs, cparams): if not dialect.name == "databricks": return - ua = cparams.get("_user_agent_entry", "") + ua = cparams.get("user_agent_entry", "sqlalchemy") - def add_sqla_tag_if_not_present(val: str): - if not val: - output = "sqlalchemy" + if "sqlalchemy" not in ua: + ua = f"sqlalchemy + {ua}" - if val and "sqlalchemy" in val: - output = val - - else: - output = f"sqlalchemy + {val}" - - return output - - cparams["_user_agent_entry"] = add_sqla_tag_if_not_present(ua) + cparams["user_agent_entry"] = ua if sqlalchemy.__version__.startswith("1.3"): # SQLAlchemy 1.3.x fails to parse the http_path, catalog, and schema from our connection string diff --git a/tests/test_local/e2e/test_basic.py b/tests/test_local/e2e/test_basic.py index ce0b5d8..a83919f 100644 --- a/tests/test_local/e2e/test_basic.py +++ b/tests/test_local/e2e/test_basic.py @@ -55,7 +55,7 @@ def version_agnostic_connect_arguments(connection_details) -> Tuple[str, dict]: CATALOG = connection_details["catalog"] SCHEMA = connection_details["schema"] - ua_connect_args = {"_user_agent_entry": USER_AGENT_TOKEN} + ua_connect_args = {"user_agent_entry": USER_AGENT_TOKEN} if sqlalchemy_1_3(): conn_string = f"databricks://token:{ACCESS_TOKEN}@{HOST}" @@ -510,7 +510,7 @@ def engine(self, connection_details: dict): SCHEMA = connection_details["schema"] connection_string = f"databricks://token:{ACCESS_TOKEN}@{HOST}?http_path={HTTP_PATH}&catalog={CATALOG}&schema={SCHEMA}" - connect_args = {"_user_agent_entry": USER_AGENT_TOKEN} + connect_args = {"user_agent_entry": USER_AGENT_TOKEN} engine = create_engine(connection_string, connect_args=connect_args) return engine diff --git a/tests/test_local/e2e/test_setup.py b/tests/test_local/e2e/test_setup.py index 94a37aa..935ac8d 100644 --- a/tests/test_local/e2e/test_setup.py +++ b/tests/test_local/e2e/test_setup.py @@ -16,7 +16,7 @@ def db_engine(self) -> Engine: CATALOG = self.arguments["catalog"] SCHEMA = self.arguments["schema"] - connect_args = {"_user_agent_entry": "SQLAlchemy e2e Tests"} + connect_args = {"user_agent_entry": "SQLAlchemy e2e Tests"} conn_string = f"databricks://token:{ACCESS_TOKEN}@{HOST}?http_path={HTTP_PATH}&catalog={CATALOG}&schema={SCHEMA}" return create_engine(conn_string, connect_args=connect_args)