|
19 | 19 | import abc |
20 | 20 | import asyncio |
21 | 21 | import logging |
| 22 | +import math |
22 | 23 | import typing as t |
23 | 24 | from collections import ( |
24 | 25 | defaultdict, |
@@ -669,6 +670,7 @@ async def acquire( |
669 | 670 | ): |
670 | 671 | # The access_mode and database is not needed for a direct connection, |
671 | 672 | # it's just there for consistency. |
| 673 | + _check_acquisition_timeout(timeout) |
672 | 674 | log.debug( |
673 | 675 | "[#0000] _: <POOL> acquire direct connection, " |
674 | 676 | "access_mode=%r, database=%r", |
@@ -969,6 +971,7 @@ async def update_routing_table( |
969 | 971 |
|
970 | 972 | :raise neo4j.exceptions.ServiceUnavailable: |
971 | 973 | """ |
| 974 | + _check_acquisition_timeout(acquisition_timeout) |
972 | 975 | async with self.refresh_lock: |
973 | 976 | routing_table = await self.get_routing_table(database) |
974 | 977 | if routing_table is not None: |
@@ -1152,11 +1155,7 @@ async def acquire( |
1152 | 1155 | if access_mode not in {WRITE_ACCESS, READ_ACCESS}: |
1153 | 1156 | # TODO: 6.0 - change this to be a ValueError |
1154 | 1157 | raise ClientError(f"Non valid 'access_mode'; {access_mode}") |
1155 | | - if not timeout: |
1156 | | - # TODO: 6.0 - change this to be a ValueError |
1157 | | - raise ClientError( |
1158 | | - f"'timeout' must be a float larger than 0; {timeout}" |
1159 | | - ) |
| 1158 | + _check_acquisition_timeout(timeout) |
1160 | 1159 |
|
1161 | 1160 | from ...api import check_access_mode |
1162 | 1161 |
|
@@ -1253,3 +1252,23 @@ async def on_write_failure(self, address, database): |
1253 | 1252 | if table is not None: |
1254 | 1253 | table.writers.discard(address) |
1255 | 1254 | log.debug("[#0000] _: <POOL> table=%r", self.routing_tables) |
| 1255 | + |
| 1256 | + |
| 1257 | +def _check_acquisition_timeout(timeout: object) -> None: |
| 1258 | + if isinstance(timeout, int): |
| 1259 | + if timeout <= 0: |
| 1260 | + raise ValueError( |
| 1261 | + f"Connection acquisition timeout must be > 0, got {timeout}" |
| 1262 | + ) |
| 1263 | + elif isinstance(timeout, float): |
| 1264 | + if math.isnan(timeout): |
| 1265 | + raise ValueError("Connection acquisition timeout must not be NaN") |
| 1266 | + if timeout <= 0: |
| 1267 | + raise ValueError( |
| 1268 | + f"Connection acquisition timeout must be > 0, got {timeout}" |
| 1269 | + ) |
| 1270 | + else: |
| 1271 | + raise TypeError( |
| 1272 | + "Connection acquisition timeout must be a number, " |
| 1273 | + f"got {type(timeout)}" |
| 1274 | + ) |
0 commit comments