diff --git a/nutkit/protocol/cypher.py b/nutkit/protocol/cypher.py index af76a610d..23c8f6e37 100644 --- a/nutkit/protocol/cypher.py +++ b/nutkit/protocol/cypher.py @@ -14,6 +14,8 @@ } """ + +import datetime import math @@ -424,6 +426,22 @@ def __eq__(self, other): "second", "nanosecond", "utc_offset_s", "timezone_id")) + def as_utc(self): + if self.utc_offset_s is None: + return self + us, ns = divmod(self.nanosecond, 1000) + dt = datetime.datetime( + year=self.year, month=self.month, day=self.day, hour=self.hour, + minute=self.minute, second=self.second, microsecond=us + ) + utc_dt = dt - datetime.timedelta(seconds=self.utc_offset_s) + + return CypherDateTime( + utc_dt.year, utc_dt.month, utc_dt.day, utc_dt.hour, utc_dt.minute, + utc_dt.second, utc_dt.microsecond * 1000 + ns, + utc_offset_s=0, timezone_id="UTC" + ) + class CypherDuration: def __init__(self, months, days, seconds, nanoseconds): diff --git a/tests/neo4j/datatypes/test_temporal_types.py b/tests/neo4j/datatypes/test_temporal_types.py index 3cec846c6..ed5d0a1e1 100644 --- a/tests/neo4j/datatypes/test_temporal_types.py +++ b/tests/neo4j/datatypes/test_temporal_types.py @@ -180,6 +180,26 @@ def test_should_echo_all_timezone_ids(self): self._verify_can_echo(dt) def test_date_time_cypher_created_tz_id(self): + def assert_utc_equal(dt_, cypher_dt_): + self.assertEqual(dt_.timezone_id, tz) + # We are comparing in UTC because the server's and the + # driver's timezone db may diverge. + self.assertEqual(dt_.as_utc(), cypher_dt_.as_utc()) + + def assert_wall_time_equal(dt_, cypher_dt_): + self.assertEqual(dt_.year, cypher_dt_.year) + self.assertEqual(dt_.month, cypher_dt_.month) + self.assertEqual(dt_.day, cypher_dt_.day) + self.assertEqual(dt_.hour, cypher_dt_.hour) + self.assertEqual(dt_.minute, cypher_dt_.minute) + self.assertEqual(dt_.second, cypher_dt_.second) + self.assertEqual(dt_.nanosecond, cypher_dt_.nanosecond) + # We are not testing the offset value because the + # server's and the driver's timezone db may diverge. + # self.assertEqual(dt.utc_offset_s, cypher_dt.utc_offset_s) + self.assertEqual(dt_.timezone_id, cypher_dt_.timezone_id) + pass + def work(tx): res = tx.run( f"WITH datetime('1970-01-01T10:08:09.000000001[{tz_id}]') " @@ -204,6 +224,7 @@ def work(tx): return map(lambda x: getattr(x, "value", x), rec.values) self._create_driver_and_session() + server_supports_utc = get_server_info().has_utc_patch for tz_id in TZ_IDS: if not self._timezone_server_support(tz_id): continue @@ -211,17 +232,29 @@ def work(tx): with self.expect_broken_utc_server(): dt, y, mo, d, h, m, s, ns, offset, tz = \ self._session.read_transaction(work) - self.assertEqual(dt.year, y) - self.assertEqual(dt.month, mo) - self.assertEqual(dt.day, d) - self.assertEqual(dt.hour, h) - self.assertEqual(dt.minute, m) - self.assertEqual(dt.second, s) - self.assertEqual(dt.nanosecond, ns) - # We are not testing the offset value because the server's - # and the driver's timezone db may diverge. - # self.assertEqual(dt.utc_offset_s, offset) - self.assertEqual(dt.timezone_id, tz) + cypher_dt = types.CypherDateTime( + y, mo, d, h, m, s, ns, offset, tz_id + ) + if server_supports_utc == 1: + # 5.0+ protocol sends date times in UTC + # => UTC times must be equal + assert_utc_equal(dt, cypher_dt) + elif server_supports_utc == 0: + # 4.2- protocol sends date times in wall clock time + # => Wall clock times must be equal + assert_wall_time_equal(dt, cypher_dt) + else: + # 4.4 and 4.3 protocol sends date times in + # wall clock time or UTC depending on server version, + # driver version and their handshake. + try: + assert_utc_equal(dt, cypher_dt) + except AssertionError: + # guess it was the other + pass + else: + continue + assert_wall_time_equal(dt, cypher_dt) def test_date_components(self): self._create_driver_and_session() diff --git a/tests/neo4j/shared.py b/tests/neo4j/shared.py index dc4412254..f466398df 100644 --- a/tests/neo4j/shared.py +++ b/tests/neo4j/shared.py @@ -109,6 +109,14 @@ def max_protocol_version(self): "5.0": "5.0", }[".".join(self.version.split(".")[:2])] + @property + def has_utc_patch(self): + if self.version >= "5": + return 1 + if self.version >= "4.3": + return 0.5 # maybe + return 0 + def get_server_info(): return ServerInfo(