Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 39 additions & 8 deletions src/neo4j/vector.py
Original file line number Diff line number Diff line change
Expand Up @@ -263,7 +263,9 @@ def __str__(self) -> str:
return str(self._inner)

def __repr__(self) -> str:
return f"{self.__class__.__name__}({self.raw()!r}, {self.dtype!r})"
return (
f"{self.__class__.__name__}({self.raw()!r}, {self.dtype.value!r})"
)

@classmethod
def from_bytes(
Expand Down Expand Up @@ -627,6 +629,7 @@ class _InnerVector(_abc.ABC):

dtype: _t.ClassVar[VectorDType]
size: _t.ClassVar[int]
cypher_inner_type_repr: _t.ClassVar[str]
_data: bytes
_data_le: bytes | None

Expand Down Expand Up @@ -676,6 +679,8 @@ def data_le(self, data: bytes, /) -> None:

def __init_subclass__(cls) -> None:
super().__init_subclass__()
if _abc.ABC in cls.__bases__:
return
dtype = getattr(cls, "dtype", None)
if not isinstance(dtype, VectorDType):
raise TypeError(
Expand Down Expand Up @@ -703,7 +708,12 @@ def __len__(self) -> int:

def __str__(self) -> str:
size = len(self)
return f"Vec[{self.dtype}; {size}]"
type_repr = self.cypher_inner_type_repr
values_repr = self._cypher_values_repr()
return f"vector({values_repr}, {size}, {type_repr})"

@_abc.abstractmethod
def _cypher_values_repr(self) -> str: ...

def __repr__(self) -> str:
cls_name = self.__class__.__name__
Expand Down Expand Up @@ -743,11 +753,20 @@ def from_pyarrow(cls, data: pyarrow.Array, /) -> _t.Self:
def to_pyarrow(self) -> pyarrow.Array: ...


class _VecF64(_InnerVector):
class _InnerVectorFloat(_InnerVector, _abc.ABC):
__slots__ = ()

def _cypher_values_repr(self) -> str:
res = str(self.to_native())
return res.replace("nan", "NaN").replace("inf", "Infinity")


class _VecF64(_InnerVectorFloat):
__slots__ = ()

dtype = VectorDType.F64
size = 8
cypher_inner_type_repr = "FLOAT NOT NULL"

@classmethod
def _from_native_rust(cls, data: _t.Iterable[object], /) -> _t.Self:
Expand Down Expand Up @@ -817,11 +836,12 @@ def to_pyarrow(self) -> pyarrow.Array:
)


class _VecF32(_InnerVector):
class _VecF32(_InnerVectorFloat):
__slots__ = ()

dtype = VectorDType.F32
size = 4
cypher_inner_type_repr = "FLOAT32 NOT NULL"

@classmethod
def _from_native_rust(cls, data: _t.Iterable[object], /) -> _t.Self:
Expand Down Expand Up @@ -891,15 +911,23 @@ def to_pyarrow(self) -> pyarrow.Array:
)


class _InnerVectorInt(_InnerVector, _abc.ABC):
__slots__ = ()

def _cypher_values_repr(self) -> str:
return str(self.to_native())


_I64_MIN = -9_223_372_036_854_775_808
_I64_MAX = 9_223_372_036_854_775_807


class _VecI64(_InnerVector):
class _VecI64(_InnerVectorInt):
__slots__ = ()

dtype = VectorDType.I64
size = 8
cypher_inner_type_repr = "INTEGER NOT NULL"

@classmethod
def _from_native_rust(cls, data: _t.Iterable[object], /) -> _t.Self:
Expand Down Expand Up @@ -987,11 +1015,12 @@ def to_pyarrow(self) -> pyarrow.Array:
_I32_MAX = 2_147_483_647


class _VecI32(_InnerVector):
class _VecI32(_InnerVectorInt):
__slots__ = ()

dtype = VectorDType.I32
size = 4
cypher_inner_type_repr = "INTEGER32 NOT NULL"

@classmethod
def _from_native_rust(cls, data: _t.Iterable[object], /) -> _t.Self:
Expand Down Expand Up @@ -1079,11 +1108,12 @@ def to_pyarrow(self) -> pyarrow.Array:
_I16_MAX = 32_767


class _VecI16(_InnerVector):
class _VecI16(_InnerVectorInt):
__slots__ = ()

dtype = VectorDType.I16
size = 2
cypher_inner_type_repr = "INTEGER16 NOT NULL"

@classmethod
def _from_native_rust(cls, data: _t.Iterable[object], /) -> _t.Self:
Expand Down Expand Up @@ -1171,11 +1201,12 @@ def to_pyarrow(self) -> pyarrow.Array:
_I8_MAX = 127


class _VecI8(_InnerVector):
class _VecI8(_InnerVectorInt):
__slots__ = ()

dtype = VectorDType.I8
size = 1
cypher_inner_type_repr = "INTEGER8 NOT NULL"

@classmethod
def _from_native_rust(cls, data: _t.Iterable[object], /) -> _t.Self:
Expand Down
Loading