|
7 | 7 | import warnings |
8 | 8 | from dataclasses import dataclass |
9 | 9 | from types import TracebackType |
10 | | -from typing import Any, Callable, Iterable, Optional, Type, Union |
| 10 | +from typing import Any, Callable, Dict, Iterable, Optional, Type, Union |
11 | 11 |
|
12 | 12 | import pyarrow |
13 | 13 | from neo4j.exceptions import ClientError |
@@ -89,19 +89,35 @@ def __init__( |
89 | 89 | self._host = host |
90 | 90 | self._port = port |
91 | 91 | self._auth = auth |
| 92 | + self._encrypted = encrypted |
| 93 | + self._disable_server_verification = disable_server_verification |
| 94 | + self._tls_root_certs = tls_root_certs |
| 95 | + self._user_agent = user_agent |
92 | 96 |
|
93 | | - location = flight.Location.for_grpc_tls(host, port) if encrypted else flight.Location.for_grpc_tcp(host, port) |
94 | | - |
95 | | - client_options: dict[str, Any] = {"disable_server_verification": disable_server_verification} |
96 | 97 | if auth: |
97 | 98 | self._auth_middleware = AuthMiddleware(auth) |
98 | | - if not user_agent: |
99 | | - user_agent = f"neo4j-graphdatascience-v{__version__} pyarrow-v{arrow_version}" |
100 | | - client_options["middleware"] = [AuthFactory(self._auth_middleware), UserAgentFactory(useragent=user_agent)] |
101 | | - if tls_root_certs: |
102 | | - client_options["tls_root_certs"] = tls_root_certs |
103 | 99 |
|
104 | | - self._flight_client = flight.FlightClient(location, **client_options) |
| 100 | + self._flight_client = self._instantiate_flight_client() |
| 101 | + |
| 102 | + def _instantiate_flight_client(self) -> flight.FlightClient: |
| 103 | + location = ( |
| 104 | + flight.Location.for_grpc_tls(self._host, self._port) |
| 105 | + if self._encrypted |
| 106 | + else flight.Location.for_grpc_tcp(self._host, self._port) |
| 107 | + ) |
| 108 | + client_options: dict[str, Any] = {"disable_server_verification": self._disable_server_verification} |
| 109 | + if self._auth: |
| 110 | + user_agent = f"neo4j-graphdatascience-v{__version__} pyarrow-v{arrow_version}" |
| 111 | + if self._user_agent: |
| 112 | + user_agent = self._user_agent |
| 113 | + |
| 114 | + client_options["middleware"] = [ |
| 115 | + AuthFactory(self._auth_middleware), |
| 116 | + UserAgentFactory(useragent=user_agent), |
| 117 | + ] |
| 118 | + if self._tls_root_certs: |
| 119 | + client_options["tls_root_certs"] = self._tls_root_certs |
| 120 | + return flight.FlightClient(location, **client_options) |
105 | 121 |
|
106 | 122 | def connection_info(self) -> tuple[str, int]: |
107 | 123 | """ |
@@ -537,11 +553,28 @@ def upload_triplets( |
537 | 553 | """ |
538 | 554 | self._upload_data(graph_name, "triplet", triplet_data, batch_size, progress_callback) |
539 | 555 |
|
| 556 | + def __getstate__(self) -> Dict[str, Any]: |
| 557 | + state = self.__dict__.copy() |
| 558 | + # Remove the FlightClient as it isn't serializable |
| 559 | + if "_flight_client" in state: |
| 560 | + del state["_flight_client"] |
| 561 | + return state |
| 562 | + |
| 563 | + def _client(self) -> flight.FlightClient: |
| 564 | + """ |
| 565 | + Lazy client construction to help pickle this class because a PyArrow |
| 566 | + FlightClient is not serializable. |
| 567 | + """ |
| 568 | + if not hasattr(self, "_flight_client") or not self._flight_client: |
| 569 | + self._flight_client = self._instantiate_flight_client() |
| 570 | + return self._flight_client |
| 571 | + |
540 | 572 | def _send_action(self, action_type: str, meta_data: dict[str, Any]) -> dict[str, Any]: |
541 | 573 | action_type = self._versioned_action_type(action_type) |
542 | 574 |
|
543 | 575 | try: |
544 | | - result = self._flight_client.do_action(flight.Action(action_type, json.dumps(meta_data).encode("utf-8"))) |
| 576 | + client = self._client() |
| 577 | + result = client.do_action(flight.Action(action_type, json.dumps(meta_data).encode("utf-8"))) |
545 | 578 |
|
546 | 579 | # Consume result fully to sanity check and avoid cancelled streams |
547 | 580 | collected_result = list(result) |
@@ -569,7 +602,9 @@ def _upload_data( |
569 | 602 |
|
570 | 603 | flight_descriptor = self._versioned_flight_descriptor({"name": graph_name, "entity_type": entity_type}) |
571 | 604 | upload_descriptor = flight.FlightDescriptor.for_command(json.dumps(flight_descriptor).encode("utf-8")) |
572 | | - put_stream, ack_stream = self._flight_client.do_put(upload_descriptor, batches[0].schema) |
| 605 | + |
| 606 | + client = self._client() |
| 607 | + put_stream, ack_stream = client.do_put(upload_descriptor, batches[0].schema) |
573 | 608 |
|
574 | 609 | @retry( |
575 | 610 | stop=(stop_after_delay(10) | stop_after_attempt(5)), |
|
0 commit comments