diff --git a/CHANGELOG.md b/CHANGELOG.md index 2d8ec92e1..22a42e54d 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -156,6 +156,15 @@ See also https://github.com/neo4j/neo4j-python-driver/wiki for a full changelog. should be treated as immutable. - Graph type sets (`neo4j.graph.EntitySetView`) can no longer by indexed by legacy `id` (`int`, e.g., `graph.nodes[0]`). Use the `element_id` instead (`str`, e.g., `graph.nodes["..."]`). +- Make all comparator magic methods return `NotImplemented` instead of `False` (or raising `TypeError` in some + instances) if the other operand is not of a supported type. + This means that when comparing a driver type with another type is doesn't support, the other type get the chance to + handle the comparison. + Affected types: + - `neo4j.Record` + - `neo4j.graph.Node`, `neo4j.graph.Relationship`, `neo4j.graph.Path` + - `neo4j.time.Date`, `neo4j.time.Time`, `neo4j.time.DateTime` + - `neo4j.spatial.Point` (and subclasses) ## Version 5.28 diff --git a/src/neo4j/_codec/packstream/_python/_common.py b/src/neo4j/_codec/packstream/_python/_common.py index 3cb230838..f2fa6bb01 100644 --- a/src/neo4j/_codec/packstream/_python/_common.py +++ b/src/neo4j/_codec/packstream/_python/_common.py @@ -28,10 +28,7 @@ def __eq__(self, other): try: return self.tag == other.tag and self.fields == other.fields except AttributeError: - return False - - def __ne__(self, other): - return not self.__eq__(other) + return NotImplementedError def __len__(self): return len(self.fields) diff --git a/src/neo4j/_data.py b/src/neo4j/_data.py index 4c5ee7ed3..27b6d459a 100644 --- a/src/neo4j/_data.py +++ b/src/neo4j/_data.py @@ -110,10 +110,7 @@ def __eq__(self, other: object) -> bool: other = t.cast(t.Mapping, other) return dict(self) == dict(other) else: - return False - - def __ne__(self, other: object) -> bool: - return not self.__eq__(other) + return NotImplemented def __hash__(self): return reduce(xor_operator, map(hash, self.items())) diff --git a/src/neo4j/_io/__init__.py b/src/neo4j/_io/__init__.py index 648efffb8..b0c227890 100644 --- a/src/neo4j/_io/__init__.py +++ b/src/neo4j/_io/__init__.py @@ -61,28 +61,28 @@ def __ne__(self, other: object) -> bool: return self.version != other return NotImplemented - def __lt__(self, other: object) -> bool: + def __lt__(self, other: BoltProtocolVersion | tuple) -> bool: if isinstance(other, BoltProtocolVersion): return self.version < other.version if isinstance(other, tuple): return self.version < other return NotImplemented - def __le__(self, other: object) -> bool: + def __le__(self, other: BoltProtocolVersion | tuple) -> bool: if isinstance(other, BoltProtocolVersion): return self.version <= other.version if isinstance(other, tuple): return self.version <= other return NotImplemented - def __gt__(self, other: object) -> bool: + def __gt__(self, other: BoltProtocolVersion | tuple) -> bool: if isinstance(other, BoltProtocolVersion): return self.version > other.version if isinstance(other, tuple): return self.version > other return NotImplemented - def __ge__(self, other: object) -> bool: + def __ge__(self, other: BoltProtocolVersion | tuple) -> bool: if isinstance(other, BoltProtocolVersion): return self.version >= other.version if isinstance(other, tuple): diff --git a/src/neo4j/api.py b/src/neo4j/api.py index fb39cbf3b..4a6a9eff5 100644 --- a/src/neo4j/api.py +++ b/src/neo4j/api.py @@ -113,7 +113,7 @@ def __init__( if parameters: self.parameters = parameters - def __eq__(self, other: _t.Any) -> bool: + def __eq__(self, other: object) -> bool: if not isinstance(other, Auth): return NotImplemented return vars(self) == vars(other) diff --git a/src/neo4j/graph/__init__.py b/src/neo4j/graph/__init__.py index 1f1de91eb..e52d9e5c5 100644 --- a/src/neo4j/graph/__init__.py +++ b/src/neo4j/graph/__init__.py @@ -117,7 +117,6 @@ def __init__( } def __eq__(self, other: _t.Any) -> bool: - # TODO: 6.0 - return NotImplemented on type mismatch instead of False try: return ( type(self) is type(other) @@ -125,10 +124,7 @@ def __eq__(self, other: _t.Any) -> bool: and self.element_id == other.element_id ) except AttributeError: - return False - - def __ne__(self, other: object) -> bool: - return not self.__eq__(other) + return NotImplemented def __hash__(self): return hash(self._element_id) @@ -325,17 +321,13 @@ def __repr__(self) -> str: ) def __eq__(self, other: _t.Any) -> bool: - # TODO: 6.0 - return NotImplemented on type mismatch instead of False try: return ( self.start_node == other.start_node and self.relationships == other.relationships ) except AttributeError: - return False - - def __ne__(self, other: object) -> bool: - return not self.__eq__(other) + return NotImplemented def __hash__(self): value = hash(self._nodes[0]) diff --git a/src/neo4j/spatial/__init__.py b/src/neo4j/spatial/__init__.py index d8d09212b..a2e07726e 100644 --- a/src/neo4j/spatial/__init__.py +++ b/src/neo4j/spatial/__init__.py @@ -77,10 +77,7 @@ def __eq__(self, other: object) -> bool: _t.cast(Point, other) ) except (AttributeError, TypeError): - return False - - def __ne__(self, other: object) -> bool: - return not self.__eq__(other) + return NotImplemented def __hash__(self): return hash(type(self)) ^ hash(tuple(self)) diff --git a/src/neo4j/time/__init__.py b/src/neo4j/time/__init__.py index b830dd533..2c7afa93d 100644 --- a/src/neo4j/time/__init__.py +++ b/src/neo4j/time/__init__.py @@ -1217,60 +1217,31 @@ def __hash__(self): def __eq__(self, other: object) -> bool: """``==`` comparison with :class:`.Date` or :class:`datetime.date`.""" if not isinstance(other, (Date, _date)): - # TODO: 6.0 - return NotImplemented for non-Date objects - # return NotImplemented - return False + return NotImplemented return self.toordinal() == other.toordinal() - def __ne__(self, other: object) -> bool: - """``!=`` comparison with :class:`.Date` or :class:`datetime.date`.""" - # TODO: 6.0 - return NotImplemented for non-Date objects - # if not isinstance(other, (Date, date)): - # return NotImplemented - return not self.__eq__(other) - def __lt__(self, other: Date | _date) -> bool: """``<`` comparison with :class:`.Date` or :class:`datetime.date`.""" if not isinstance(other, (Date, _date)): - # TODO: 6.0 - return NotImplemented for non-Date objects - # return NotImplemented - raise TypeError( - "'<' not supported between instances of 'Date' and " - f"{type(other).__name__!r}" - ) + return NotImplemented return self.toordinal() < other.toordinal() def __le__(self, other: Date | _date) -> bool: """``<=`` comparison with :class:`.Date` or :class:`datetime.date`.""" if not isinstance(other, (Date, _date)): - # TODO: 6.0 - return NotImplemented for non-Date objects - # return NotImplemented - raise TypeError( - "'<=' not supported between instances of 'Date' and " - f"{type(other).__name__!r}" - ) + return NotImplemented return self.toordinal() <= other.toordinal() def __ge__(self, other: Date | _date) -> bool: """``>=`` comparison with :class:`.Date` or :class:`datetime.date`.""" if not isinstance(other, (Date, _date)): - # TODO: 6.0 - return NotImplemented for non-Date objects - # return NotImplemented - raise TypeError( - "'>=' not supported between instances of 'Date' and " - f"{type(other).__name__!r}" - ) + return NotImplemented return self.toordinal() >= other.toordinal() def __gt__(self, other: Date | _date) -> bool: """``>`` comparison with :class:`.Date` or :class:`datetime.date`.""" if not isinstance(other, (Date, _date)): - # TODO: 6.0 - return NotImplemented for non-Date objects - # return NotImplemented - raise TypeError( - "'>' not supported between instances of 'Date' and " - f"{type(other).__name__!r}" - ) + return NotImplemented return self.toordinal() > other.toordinal() def __add__(self, other: Duration) -> Date: # type: ignore[override] @@ -1917,29 +1888,37 @@ def tzinfo(self) -> _tzinfo | None: # OPERATIONS # - def _get_both_normalized_ticks(self, other: object, strict=True): - if isinstance(other, (_time, Time)) and ( - (self.utc_offset() is None) ^ (other.utcoffset() is None) - ): + @_t.overload + def _get_both_normalized_ticks( + self, other: Time | _time, strict: _t.Literal[True] = True + ) -> tuple[int, int]: ... + + @_t.overload + def _get_both_normalized_ticks( + self, other: Time | _time, strict: _t.Literal[False] + ) -> tuple[int, int] | None: ... + + def _get_both_normalized_ticks( + self, other: Time | _time, strict: bool = True + ) -> tuple[int, int] | None: + if (self.utc_offset() is None) ^ (other.utcoffset() is None): if strict: raise TypeError( "can't compare offset-naive and offset-aware times" ) else: - return None, None + return None other_ticks: int if isinstance(other, Time): other_ticks = other.__ticks - elif isinstance(other, _time): + else: + assert isinstance(other, _time) other_ticks = int( 3600000000000 * other.hour + 60000000000 * other.minute + _NANO_SECONDS * other.second + 1000 * other.microsecond ) - else: - return None, None - assert isinstance(other, (Time, _time)) utc_offset: _timedelta | None = other.utcoffset() if utc_offset is not None: other_ticks -= int(utc_offset.total_seconds() * _NANO_SECONDS) @@ -1959,43 +1938,40 @@ def __hash__(self): def __eq__(self, other: object) -> bool: """`==` comparison with :class:`.Time` or :class:`datetime.time`.""" - self_ticks, other_ticks = self._get_both_normalized_ticks( - other, strict=False - ) - if self_ticks is None: + if not isinstance(other, (Time, _time)): + return NotImplemented + ticks = self._get_both_normalized_ticks(other, strict=False) + if ticks is None: return False + self_ticks, other_ticks = ticks return self_ticks == other_ticks - def __ne__(self, other: object) -> bool: - """`!=` comparison with :class:`.Time` or :class:`datetime.time`.""" - return not self.__eq__(other) - def __lt__(self, other: Time | _time) -> bool: """`<` comparison with :class:`.Time` or :class:`datetime.time`.""" - self_ticks, other_ticks = self._get_both_normalized_ticks(other) - if self_ticks is None: + if not isinstance(other, (Time, _time)): return NotImplemented + self_ticks, other_ticks = self._get_both_normalized_ticks(other) return self_ticks < other_ticks def __le__(self, other: Time | _time) -> bool: """`<=` comparison with :class:`.Time` or :class:`datetime.time`.""" - self_ticks, other_ticks = self._get_both_normalized_ticks(other) - if self_ticks is None: + if not isinstance(other, (Time, _time)): return NotImplemented + self_ticks, other_ticks = self._get_both_normalized_ticks(other) return self_ticks <= other_ticks def __ge__(self, other: Time | _time) -> bool: """`>=` comparison with :class:`.Time` or :class:`datetime.time`.""" - self_ticks, other_ticks = self._get_both_normalized_ticks(other) - if self_ticks is None: + if not isinstance(other, (Time, _time)): return NotImplemented + self_ticks, other_ticks = self._get_both_normalized_ticks(other) return self_ticks >= other_ticks def __gt__(self, other: Time | _time) -> bool: """`>` comparison with :class:`.Time` or :class:`datetime.time`.""" - self_ticks, other_ticks = self._get_both_normalized_ticks(other) - if self_ticks is None: + if not isinstance(other, (Time, _time)): return NotImplemented + self_ticks, other_ticks = self._get_both_normalized_ticks(other) return self_ticks > other_ticks # INSTANCE METHODS # @@ -2603,29 +2579,36 @@ def hour_minute_second_nanosecond(self) -> tuple[int, int, int, int]: # OPERATIONS # - def _get_both_normalized(self, other, strict=True): - if isinstance(other, (_datetime, DateTime)) and ( - (self.utc_offset() is None) ^ (other.utcoffset() is None) - ): + @_t.overload + def _get_both_normalized( + self, other: _datetime | DateTime, strict: _t.Literal[True] = True + ) -> tuple[DateTime, DateTime | _datetime]: ... + + @_t.overload + def _get_both_normalized( + self, other: _datetime | DateTime, strict: _t.Literal[False] + ) -> tuple[DateTime, DateTime | _datetime] | None: ... + + def _get_both_normalized( + self, other: _datetime | DateTime, strict: bool = True + ) -> tuple[DateTime, DateTime | _datetime] | None: + if (self.utc_offset() is None) ^ (other.utcoffset() is None): if strict: raise TypeError( "can't compare offset-naive and offset-aware datetimes" ) else: - return None, None + return None self_norm = self utc_offset = self.utc_offset() if utc_offset is not None: self_norm -= utc_offset self_norm = self_norm.replace(tzinfo=None) other_norm = other - if isinstance(other, (_datetime, DateTime)): - utc_offset = other.utcoffset() - if utc_offset is not None: - other_norm -= utc_offset - other_norm = other_norm.replace(tzinfo=None) - else: - return None, None + utc_offset = other.utcoffset() + if utc_offset is not None: + other_norm -= utc_offset + other_norm = other_norm.replace(tzinfo=None) return self_norm, other_norm def __hash__(self): @@ -2647,21 +2630,12 @@ def __eq__(self, other: object) -> bool: return NotImplemented if self.utc_offset() == other.utcoffset(): return self.date() == other.date() and self.time() == other.time() - self_norm, other_norm = self._get_both_normalized(other, strict=False) - if self_norm is None: + normalized = self._get_both_normalized(other, strict=False) + if normalized is None: return False + self_norm, other_norm = normalized return self_norm == other_norm - def __ne__(self, other: object) -> bool: - """ - ``!=`` comparison with another datetime. - - Accepts :class:`.DateTime` and :class:`datetime.datetime`. - """ - if not isinstance(other, (DateTime, _datetime)): - return NotImplemented - return not self.__eq__(other) - def __lt__( # type: ignore[override] self, other: _datetime | DateTime ) -> bool: diff --git a/tests/unit/common/time/test_datetime.py b/tests/unit/common/time/test_datetime.py index 3a7740ef2..5610f393b 100644 --- a/tests/unit/common/time/test_datetime.py +++ b/tests/unit/common/time/test_datetime.py @@ -1190,6 +1190,62 @@ def test_comparison(dt1, dt2) -> None: assert not dt1 >= dt2 +@pytest.mark.parametrize( + ("dt1_args", "dt2_args"), + ( + ( + (2022, 11, 25, 12, 34, 56, 789124), + (2022, 11, 25, 12, 34, 56, 789124), + ), + ( + (2022, 11, 25, 12, 33, 56, 789124), + (2022, 11, 25, 12, 34, 56, 789124), + ), + ( + (2022, 11, 25, 12, 34, 56, 789124), + (2022, 11, 25, 12, 35, 56, 789124), + ), + ( + (2022, 11, 25, 12, 32, 56, 789124), + (2022, 11, 25, 12, 34, 56, 789124), + ), + ( + (2022, 11, 25, 12, 34, 56, 789124), + (2022, 11, 25, 12, 36, 56, 789124), + ), + ), +) +@pytest.mark.parametrize("dt1_cls", (DateTime, datetime)) +@pytest.mark.parametrize("dt2_cls", (DateTime, datetime)) +@pytest.mark.parametrize( + "tz", + (FixedOffset(0), FixedOffset(1), FixedOffset(-1), utc, timezone_berlin), +) +def test_comparison_only_one_with_tzinfo( + dt1_args, dt1_cls, dt2_args, dt2_cls, tz +) -> None: + dt1 = dt1_cls(*dt1_args) + dt2 = dt2_cls(*dt2_args, tzinfo=None) + err_msg = "can't compare offset-naive and offset-aware" + dt2 = dt2.replace(tzinfo=tz) + with pytest.raises(TypeError, match=err_msg): + assert not dt1 < dt2 + with pytest.raises(TypeError, match=err_msg): + assert not dt2 < dt1 + with pytest.raises(TypeError, match=err_msg): + assert not dt1 <= dt2 + with pytest.raises(TypeError, match=err_msg): + assert not dt2 <= dt1 + with pytest.raises(TypeError, match=err_msg): + assert not dt1 > dt2 + with pytest.raises(TypeError, match=err_msg): + assert not dt2 > dt1 + with pytest.raises(TypeError, match=err_msg): + assert not dt1 <= dt2 + with pytest.raises(TypeError, match=err_msg): + assert not dt2 <= dt1 + + def test_str() -> None: dt = DateTime(2018, 4, 26, 23, 0, 17, 914390409) assert str(dt) == "2018-04-26T23:00:17.914390409" diff --git a/tests/unit/common/time/test_time.py b/tests/unit/common/time/test_time.py index 4fa414375..2dc7a342c 100644 --- a/tests/unit/common/time/test_time.py +++ b/tests/unit/common/time/test_time.py @@ -577,6 +577,62 @@ def test_pickle(self, expected): assert expected.foo is not actual.foo +@pytest.mark.parametrize( + ("t1_args", "t2_args"), + ( + ( + (12, 34, 56, 789124), + (12, 34, 56, 789124), + ), + ( + (12, 33, 56, 789124), + (12, 34, 56, 789124), + ), + ( + (12, 34, 56, 789124), + (12, 35, 56, 789124), + ), + ( + (12, 32, 56, 789124), + (12, 34, 56, 789124), + ), + ( + (12, 34, 56, 789124), + (12, 36, 56, 789124), + ), + ), +) +@pytest.mark.parametrize("t1_cls", (Time, time)) +@pytest.mark.parametrize("t2_cls", (Time, time)) +@pytest.mark.parametrize( + "tz", + (FixedOffset(0), FixedOffset(1), FixedOffset(-1), utc), +) +def test_comparison_only_one_with_tzinfo( + t1_args, t1_cls, t2_args, t2_cls, tz +) -> None: + t1 = t1_cls(*t1_args) + t2 = t2_cls(*t2_args, tzinfo=None) + err_msg = "can't compare offset-naive and offset-aware" + t2 = t2.replace(tzinfo=tz) + with pytest.raises(TypeError, match=err_msg): + assert not t1 < t2 + with pytest.raises(TypeError, match=err_msg): + assert not t2 < t1 + with pytest.raises(TypeError, match=err_msg): + assert not t1 <= t2 + with pytest.raises(TypeError, match=err_msg): + assert not t2 <= t1 + with pytest.raises(TypeError, match=err_msg): + assert not t1 > t2 + with pytest.raises(TypeError, match=err_msg): + assert not t2 > t1 + with pytest.raises(TypeError, match=err_msg): + assert not t1 <= t2 + with pytest.raises(TypeError, match=err_msg): + assert not t2 <= t1 + + def test_str() -> None: t = Time(12, 34, 56, 789123001) assert str(t) == "12:34:56.789123001"