Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 13 additions & 8 deletions src/neo4j/_codec/packstream/_python/_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,23 +14,28 @@
# limitations under the License.


from .... import _typing as t


class Structure:
def __init__(self, tag, *fields):
tag: bytes
fields: list[t.Any]

def __init__(self, tag: bytes, *fields: t.Any):
self.tag = tag
self.fields = list(fields)

def __repr__(self):
fields = ", ".join(map(repr, self.fields))
tag_int = ord(self.tag)
return f"Structure[0x{tag_int:02X}]({fields})"
def __repr__(self) -> str:
args = ", ".join(map(repr, (self.tag, *self.fields)))
return f"Structure({args})"

def __eq__(self, other):
def __eq__(self, other) -> bool:
try:
return self.tag == other.tag and self.fields == other.fields
except AttributeError:
return NotImplementedError
return NotImplemented

def __len__(self):
def __len__(self) -> int:
return len(self.fields)

def __getitem__(self, key):
Expand Down
139 changes: 139 additions & 0 deletions tests/unit/common/codec/packstream/test_structure.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,139 @@
# Copyright (c) "Neo4j"
# Neo4j Sweden AB [https://neo4j.com]
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import pytest

from neo4j._codec.packstream import Structure


@pytest.mark.parametrize(
"args",
(
(b"T", 1, 2, 3, "abc", 1.2, None, False),
(b"F",),
),
)
def test_structure_accessors(args):
tag = args[0]
fields = list(args[1:])
s1 = Structure(*args)
assert s1.tag == tag
assert s1.fields == fields


@pytest.mark.parametrize(
("other", "expected"),
(
(Structure(b"T", 1, 2, 3, "abc", 1.2, [{"a": "b"}, None]), True),
(Structure(b"T", 1, 2, 3, "abc", 1.2, [{"a": "b"}, 0]), False),
(Structure(b"T", 1, 2, 3, "abc", 1.2, [{"a": "B"}, None]), False),
(Structure(b"T", 1, 2, 3, "abc", 1.2, [{"A": "b"}, None]), False),
(Structure(b"T", 1, 2, 3, "abc", 1.3, [{"a": "b"}, None]), False),
(
Structure(b"T", 1, 2, 3, "aBc", float("Nan"), [{"a": "b"}, None]),
False,
),
(Structure(b"T", 2, 2, 3, "abc", 1.2, [{"a": "b"}, None]), False),
(Structure(b"T", 2, 3, "abc", 1.2, [{"a": "b"}, None]), False),
(Structure(b"T", [1, 2, 3, "abc", 1.2, [{"a": "b"}, None]]), False),
(object(), NotImplemented),
),
)
def test_structure_equality(other, expected):
s1 = Structure(b"T", 1, 2, 3, "abc", 1.2, [{"a": "b"}, None])
assert s1.__eq__(other) is expected # noqa: PLC2801
if expected is NotImplemented:
assert s1.__ne__(other) is NotImplemented # noqa: PLC2801
else:
assert s1.__ne__(other) is not expected # noqa: PLC2801


@pytest.mark.parametrize(
("args", "expected"),
(
((b"F", 1, 2), "Structure(b'F', 1, 2)"),
((b"f", [1, 2]), "Structure(b'f', [1, 2])"),
(
(b"T", 1.3, None, {"a": "b"}),
"Structure(b'T', 1.3, None, {'a': 'b'})",
),
),
)
def test_structure_repr(args, expected):
s1 = Structure(*args)
assert repr(s1) == expected
assert str(s1) == expected

# Ensure that the repr is consistent with the constructor
assert eval(repr(s1)) == s1
assert eval(str(s1)) == s1


@pytest.mark.parametrize(
("fields", "expected"),
(
((), 0),
(([],), 1),
((1, 2), 2),
((1, 2, []), 3),
(([1, 2], {"a": "foo", "b": "bar"}), 2),
),
)
def test_structure_len(fields, expected):
structure = Structure(b"F", *fields)
assert len(structure) == expected


def test_structure_getitem():
fields = [1, 2, 3, "abc", 1.2, None, False, {"a": "b"}]
structure = Structure(b"F", *fields)
for i, field in enumerate(fields):
assert structure[i] == field
assert structure[-len(fields) + i] == field
with pytest.raises(IndexError):
_ = structure[len(fields)]
with pytest.raises(IndexError):
_ = structure[-len(fields) - 1]


def test_structure_setitem():
test_value = object()
fields = [1, 2, 3, "abc", 1.2, None, False, {"a": "b"}]
structure = Structure(b"F", *fields)
for i, original_value in enumerate(fields):
structure[i] = test_value
assert structure[i] == test_value
assert structure[-len(fields) + i] == test_value
assert structure[i] != original_value
assert structure[-len(fields) + i] != original_value

structure[i] = original_value
assert structure[i] == original_value
assert structure[-len(fields) + i] == original_value

structure[-len(fields) + i] = test_value
assert structure[i] == test_value
assert structure[-len(fields) + i] == test_value
assert structure[i] != original_value
assert structure[-len(fields) + i] != original_value

structure[-len(fields) + i] = original_value
assert structure[i] == original_value
assert structure[-len(fields) + i] == original_value
with pytest.raises(IndexError):
structure[len(fields)] = test_value
with pytest.raises(IndexError):
structure[-len(fields) - 1] = test_value