Skip to content

Commit 741e6b3

Browse files
authored
Improve packstream Structure class (#1232)
These changes have no affect on the driver's public API. They're targeted at improving the development and debugging experience. * Adjust `repr` to follow Python's recommendations * Fix `__eq__` returning `NotImplementedError` instead of `NotImplemented` * Add type hints * Add tests
1 parent 23207fd commit 741e6b3

File tree

2 files changed

+152
-8
lines changed

2 files changed

+152
-8
lines changed

src/neo4j/_codec/packstream/_python/_common.py

Lines changed: 13 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -14,23 +14,28 @@
1414
# limitations under the License.
1515

1616

17+
from .... import _typing as t
18+
19+
1720
class Structure:
18-
def __init__(self, tag, *fields):
21+
tag: bytes
22+
fields: list[t.Any]
23+
24+
def __init__(self, tag: bytes, *fields: t.Any):
1925
self.tag = tag
2026
self.fields = list(fields)
2127

22-
def __repr__(self):
23-
fields = ", ".join(map(repr, self.fields))
24-
tag_int = ord(self.tag)
25-
return f"Structure[0x{tag_int:02X}]({fields})"
28+
def __repr__(self) -> str:
29+
args = ", ".join(map(repr, (self.tag, *self.fields)))
30+
return f"Structure({args})"
2631

27-
def __eq__(self, other):
32+
def __eq__(self, other) -> bool:
2833
try:
2934
return self.tag == other.tag and self.fields == other.fields
3035
except AttributeError:
31-
return NotImplementedError
36+
return NotImplemented
3237

33-
def __len__(self):
38+
def __len__(self) -> int:
3439
return len(self.fields)
3540

3641
def __getitem__(self, key):
Lines changed: 139 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,139 @@
1+
# Copyright (c) "Neo4j"
2+
# Neo4j Sweden AB [https://neo4j.com]
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# https://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
16+
17+
import pytest
18+
19+
from neo4j._codec.packstream import Structure
20+
21+
22+
@pytest.mark.parametrize(
23+
"args",
24+
(
25+
(b"T", 1, 2, 3, "abc", 1.2, None, False),
26+
(b"F",),
27+
),
28+
)
29+
def test_structure_accessors(args):
30+
tag = args[0]
31+
fields = list(args[1:])
32+
s1 = Structure(*args)
33+
assert s1.tag == tag
34+
assert s1.fields == fields
35+
36+
37+
@pytest.mark.parametrize(
38+
("other", "expected"),
39+
(
40+
(Structure(b"T", 1, 2, 3, "abc", 1.2, [{"a": "b"}, None]), True),
41+
(Structure(b"T", 1, 2, 3, "abc", 1.2, [{"a": "b"}, 0]), False),
42+
(Structure(b"T", 1, 2, 3, "abc", 1.2, [{"a": "B"}, None]), False),
43+
(Structure(b"T", 1, 2, 3, "abc", 1.2, [{"A": "b"}, None]), False),
44+
(Structure(b"T", 1, 2, 3, "abc", 1.3, [{"a": "b"}, None]), False),
45+
(
46+
Structure(b"T", 1, 2, 3, "aBc", float("Nan"), [{"a": "b"}, None]),
47+
False,
48+
),
49+
(Structure(b"T", 2, 2, 3, "abc", 1.2, [{"a": "b"}, None]), False),
50+
(Structure(b"T", 2, 3, "abc", 1.2, [{"a": "b"}, None]), False),
51+
(Structure(b"T", [1, 2, 3, "abc", 1.2, [{"a": "b"}, None]]), False),
52+
(object(), NotImplemented),
53+
),
54+
)
55+
def test_structure_equality(other, expected):
56+
s1 = Structure(b"T", 1, 2, 3, "abc", 1.2, [{"a": "b"}, None])
57+
assert s1.__eq__(other) is expected # noqa: PLC2801
58+
if expected is NotImplemented:
59+
assert s1.__ne__(other) is NotImplemented # noqa: PLC2801
60+
else:
61+
assert s1.__ne__(other) is not expected # noqa: PLC2801
62+
63+
64+
@pytest.mark.parametrize(
65+
("args", "expected"),
66+
(
67+
((b"F", 1, 2), "Structure(b'F', 1, 2)"),
68+
((b"f", [1, 2]), "Structure(b'f', [1, 2])"),
69+
(
70+
(b"T", 1.3, None, {"a": "b"}),
71+
"Structure(b'T', 1.3, None, {'a': 'b'})",
72+
),
73+
),
74+
)
75+
def test_structure_repr(args, expected):
76+
s1 = Structure(*args)
77+
assert repr(s1) == expected
78+
assert str(s1) == expected
79+
80+
# Ensure that the repr is consistent with the constructor
81+
assert eval(repr(s1)) == s1
82+
assert eval(str(s1)) == s1
83+
84+
85+
@pytest.mark.parametrize(
86+
("fields", "expected"),
87+
(
88+
((), 0),
89+
(([],), 1),
90+
((1, 2), 2),
91+
((1, 2, []), 3),
92+
(([1, 2], {"a": "foo", "b": "bar"}), 2),
93+
),
94+
)
95+
def test_structure_len(fields, expected):
96+
structure = Structure(b"F", *fields)
97+
assert len(structure) == expected
98+
99+
100+
def test_structure_getitem():
101+
fields = [1, 2, 3, "abc", 1.2, None, False, {"a": "b"}]
102+
structure = Structure(b"F", *fields)
103+
for i, field in enumerate(fields):
104+
assert structure[i] == field
105+
assert structure[-len(fields) + i] == field
106+
with pytest.raises(IndexError):
107+
_ = structure[len(fields)]
108+
with pytest.raises(IndexError):
109+
_ = structure[-len(fields) - 1]
110+
111+
112+
def test_structure_setitem():
113+
test_value = object()
114+
fields = [1, 2, 3, "abc", 1.2, None, False, {"a": "b"}]
115+
structure = Structure(b"F", *fields)
116+
for i, original_value in enumerate(fields):
117+
structure[i] = test_value
118+
assert structure[i] == test_value
119+
assert structure[-len(fields) + i] == test_value
120+
assert structure[i] != original_value
121+
assert structure[-len(fields) + i] != original_value
122+
123+
structure[i] = original_value
124+
assert structure[i] == original_value
125+
assert structure[-len(fields) + i] == original_value
126+
127+
structure[-len(fields) + i] = test_value
128+
assert structure[i] == test_value
129+
assert structure[-len(fields) + i] == test_value
130+
assert structure[i] != original_value
131+
assert structure[-len(fields) + i] != original_value
132+
133+
structure[-len(fields) + i] = original_value
134+
assert structure[i] == original_value
135+
assert structure[-len(fields) + i] == original_value
136+
with pytest.raises(IndexError):
137+
structure[len(fields)] = test_value
138+
with pytest.raises(IndexError):
139+
structure[-len(fields) - 1] = test_value

0 commit comments

Comments
 (0)