diff --git a/driver/clirr-ignored-differences.xml b/driver/clirr-ignored-differences.xml index dbc9a40f0c..95185eeb5a 100644 --- a/driver/clirr-ignored-differences.xml +++ b/driver/clirr-ignored-differences.xml @@ -818,4 +818,10 @@ org.neo4j.driver.Logging systemLogging() + + org/neo4j/driver/types/TypeSystem + 7012 + org.neo4j.driver.types.Type VECTOR() + + diff --git a/driver/src/main/java/org/neo4j/driver/Values.java b/driver/src/main/java/org/neo4j/driver/Values.java index e68454d9a0..eb7fe636f1 100644 --- a/driver/src/main/java/org/neo4j/driver/Values.java +++ b/driver/src/main/java/org/neo4j/driver/Values.java @@ -42,9 +42,15 @@ import org.neo4j.driver.exceptions.ClientException; import org.neo4j.driver.internal.AsValue; import org.neo4j.driver.internal.GqlStatusError; +import org.neo4j.driver.internal.InternalByteVector; +import org.neo4j.driver.internal.InternalDoubleVector; +import org.neo4j.driver.internal.InternalFloatVector; +import org.neo4j.driver.internal.InternalIntVector; import org.neo4j.driver.internal.InternalIsoDuration; +import org.neo4j.driver.internal.InternalLongVector; import org.neo4j.driver.internal.InternalPoint2D; import org.neo4j.driver.internal.InternalPoint3D; +import org.neo4j.driver.internal.InternalShortVector; import org.neo4j.driver.internal.value.BooleanValue; import org.neo4j.driver.internal.value.BytesValue; import org.neo4j.driver.internal.value.DateTimeValue; @@ -60,6 +66,7 @@ import org.neo4j.driver.internal.value.PointValue; import org.neo4j.driver.internal.value.StringValue; import org.neo4j.driver.internal.value.TimeValue; +import org.neo4j.driver.internal.value.VectorValue; import org.neo4j.driver.mapping.Property; import org.neo4j.driver.types.Entity; import org.neo4j.driver.types.IsoDuration; @@ -69,6 +76,7 @@ import org.neo4j.driver.types.Point; import org.neo4j.driver.types.Relationship; import org.neo4j.driver.types.TypeSystem; +import org.neo4j.driver.types.Vector; import org.neo4j.driver.util.Preview; /** @@ -170,6 +178,9 @@ public static Value value(Object value) { if (value instanceof Point) { return value((Point) value); } + if (value instanceof Vector vector) { + return value(vector); + } if (value instanceof List) { return value((List) value); @@ -469,10 +480,11 @@ public static Value value(java.lang.Record record) { for (var recordComponent : recordComponents) { var propertyAnnotation = recordComponent.getAnnotation(Property.class); var property = propertyAnnotation != null ? propertyAnnotation.value() : recordComponent.getName(); + var isVector = recordComponent.getAnnotation(org.neo4j.driver.mapping.Vector.class) != null; Value value; try { var objectValue = recordComponent.getAccessor().invoke(record); - value = (objectValue != null) ? value(objectValue) : null; + value = (objectValue != null) ? isVector ? vector(objectValue) : value(objectValue) : null; } catch (Throwable throwable) { var message = "Failed to map '%s' property to value during mapping '%s' to map value" .formatted(property, record.getClass().getCanonicalName()); @@ -996,4 +1008,99 @@ public static Function> ofList() { public static Function> ofList(final Function innerMap) { return value -> value.asList(innerMap); } + + /** + * Returns Neo4j Vector that holds a sequence of {@code byte} values. + * + * @param elements the vector elements + * @return the vector value + * @since 6.0.0 + */ + @Preview(name = "Neo4j Vector") + public static Value vector(byte[] elements) { + return value(new InternalByteVector(elements)); + } + + /** + * Returns Neo4j Vector that holds a sequence of {@code short} values. + * + * @param elements the vector elements + * @return the vector value + * @since 6.0.0 + */ + @Preview(name = "Neo4j Vector") + public static Value vector(short[] elements) { + return value(new InternalShortVector(elements)); + } + + /** + * Returns Neo4j Vector that holds a sequence of {@code int} values. + * + * @param elements the vector elements + * @return the vector value + * @since 6.0.0 + */ + @Preview(name = "Neo4j Vector") + public static Value vector(int[] elements) { + return value(new InternalIntVector(elements)); + } + + /** + * Returns Neo4j Vector that holds a sequence of {@code long} values. + * + * @param elements the vector elements + * @return the vector value + * @since 6.0.0 + */ + @Preview(name = "Neo4j Vector") + public static Value vector(long[] elements) { + return value(new InternalLongVector(elements)); + } + + /** + * Returns Neo4j Vector that holds a sequence of {@code float} values. + * + * @param elements the vector elements + * @return the vector value + * @since 6.0.0 + */ + @Preview(name = "Neo4j Vector") + public static Value vector(float[] elements) { + return value(new InternalFloatVector(elements)); + } + + /** + * Returns Neo4j Vector that holds a sequence of {@code double} values. + * + * @param elements the vector elements + * @return the vector value + * @since 6.0.0 + */ + @Preview(name = "Neo4j Vector") + public static Value vector(double[] elements) { + return value(new InternalDoubleVector(elements)); + } + + private static Value value(Vector vector) { + return new VectorValue(vector); + } + + private static Value vector(Object array) { + if (array instanceof byte[] elements) { + return value(Values.vector(elements)); + } else if (array instanceof short[] elements) { + return value(Values.vector(elements)); + } else if (array instanceof int[] elements) { + return value(Values.vector(elements)); + } else if (array instanceof long[] elements) { + return value(Values.vector(elements)); + } else if (array instanceof float[] elements) { + return value(Values.vector(elements)); + } else if (array instanceof double[] elements) { + return value(Values.vector(elements)); + } else { + throw new IllegalArgumentException( + "Unsupported vector element type: " + array.getClass().getName()); + } + } } diff --git a/driver/src/main/java/org/neo4j/driver/internal/AbstractArrayVector.java b/driver/src/main/java/org/neo4j/driver/internal/AbstractArrayVector.java new file mode 100644 index 0000000000..eb14b160ee --- /dev/null +++ b/driver/src/main/java/org/neo4j/driver/internal/AbstractArrayVector.java @@ -0,0 +1,79 @@ +/* + * Copyright (c) "Neo4j" + * Neo4j Sweden AB [https://neo4j.com] + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.neo4j.driver.internal; + +import java.lang.reflect.Array; +import java.util.Objects; +import org.neo4j.bolt.connection.values.Vector; + +public abstract class AbstractArrayVector implements Vector { + private final Class elementType; + private final int length; + protected final T elements; + + AbstractArrayVector(T elements) { + this.elementType = elements.getClass().getComponentType(); + this.length = Array.getLength(elements); + this.elements = arraycopy(elements); + } + + public Class elementType() { + return elementType; + } + + public int length() { + return length; + } + + public T toArray() { + return arraycopy(elements); + } + + @Override + public Object elements() { + return elements; + } + + @Override + public boolean equals(Object o) { + if (o == null || getClass() != o.getClass()) return false; + var that = (AbstractArrayVector) o; + return length == that.length + && Objects.equals(elementType, that.elementType) + && Objects.equals(elements, that.elements); + } + + @Override + public int hashCode() { + return Objects.hash(elementType, length, elements); + } + + @Override + public String toString() { + return "AbstractArrayVector{" + "elementType=" + + elementType + ", length=" + + length + ", elements=" + + elements + '}'; + } + + @SuppressWarnings({"unchecked", "SuspiciousSystemArraycopy"}) + private T arraycopy(T elements) { + var result = (T) Array.newInstance(elementType, length); + System.arraycopy(elements, 0, result, 0, length); + return result; + } +} diff --git a/driver/src/main/java/org/neo4j/driver/internal/InternalByteVector.java b/driver/src/main/java/org/neo4j/driver/internal/InternalByteVector.java new file mode 100644 index 0000000000..bb47f09e0d --- /dev/null +++ b/driver/src/main/java/org/neo4j/driver/internal/InternalByteVector.java @@ -0,0 +1,25 @@ +/* + * Copyright (c) "Neo4j" + * Neo4j Sweden AB [https://neo4j.com] + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.neo4j.driver.internal; + +import org.neo4j.driver.types.ByteVector; + +public final class InternalByteVector extends AbstractArrayVector implements ByteVector { + public InternalByteVector(byte[] elements) { + super(elements); + } +} diff --git a/driver/src/main/java/org/neo4j/driver/internal/InternalDoubleVector.java b/driver/src/main/java/org/neo4j/driver/internal/InternalDoubleVector.java new file mode 100644 index 0000000000..8b201127c3 --- /dev/null +++ b/driver/src/main/java/org/neo4j/driver/internal/InternalDoubleVector.java @@ -0,0 +1,25 @@ +/* + * Copyright (c) "Neo4j" + * Neo4j Sweden AB [https://neo4j.com] + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.neo4j.driver.internal; + +import org.neo4j.driver.types.DoubleVector; + +public final class InternalDoubleVector extends AbstractArrayVector implements DoubleVector { + public InternalDoubleVector(double[] elements) { + super(elements); + } +} diff --git a/driver/src/main/java/org/neo4j/driver/internal/InternalFloatVector.java b/driver/src/main/java/org/neo4j/driver/internal/InternalFloatVector.java new file mode 100644 index 0000000000..4dc2f986f4 --- /dev/null +++ b/driver/src/main/java/org/neo4j/driver/internal/InternalFloatVector.java @@ -0,0 +1,25 @@ +/* + * Copyright (c) "Neo4j" + * Neo4j Sweden AB [https://neo4j.com] + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.neo4j.driver.internal; + +import org.neo4j.driver.types.FloatVector; + +public final class InternalFloatVector extends AbstractArrayVector implements FloatVector { + public InternalFloatVector(float[] elements) { + super(elements); + } +} diff --git a/driver/src/main/java/org/neo4j/driver/internal/InternalIntVector.java b/driver/src/main/java/org/neo4j/driver/internal/InternalIntVector.java new file mode 100644 index 0000000000..d030fb1c7f --- /dev/null +++ b/driver/src/main/java/org/neo4j/driver/internal/InternalIntVector.java @@ -0,0 +1,25 @@ +/* + * Copyright (c) "Neo4j" + * Neo4j Sweden AB [https://neo4j.com] + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.neo4j.driver.internal; + +import org.neo4j.driver.types.IntVector; + +public final class InternalIntVector extends AbstractArrayVector implements IntVector { + public InternalIntVector(int[] elements) { + super(elements); + } +} diff --git a/driver/src/main/java/org/neo4j/driver/internal/InternalLongVector.java b/driver/src/main/java/org/neo4j/driver/internal/InternalLongVector.java new file mode 100644 index 0000000000..7feaf1af9b --- /dev/null +++ b/driver/src/main/java/org/neo4j/driver/internal/InternalLongVector.java @@ -0,0 +1,25 @@ +/* + * Copyright (c) "Neo4j" + * Neo4j Sweden AB [https://neo4j.com] + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.neo4j.driver.internal; + +import org.neo4j.driver.types.LongVector; + +public final class InternalLongVector extends AbstractArrayVector implements LongVector { + public InternalLongVector(long[] elements) { + super(elements); + } +} diff --git a/driver/src/main/java/org/neo4j/driver/internal/InternalShortVector.java b/driver/src/main/java/org/neo4j/driver/internal/InternalShortVector.java new file mode 100644 index 0000000000..846121a451 --- /dev/null +++ b/driver/src/main/java/org/neo4j/driver/internal/InternalShortVector.java @@ -0,0 +1,25 @@ +/* + * Copyright (c) "Neo4j" + * Neo4j Sweden AB [https://neo4j.com] + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.neo4j.driver.internal; + +import org.neo4j.driver.types.ShortVector; + +public final class InternalShortVector extends AbstractArrayVector implements ShortVector { + public InternalShortVector(short[] elements) { + super(elements); + } +} diff --git a/driver/src/main/java/org/neo4j/driver/internal/types/InternalTypeSystem.java b/driver/src/main/java/org/neo4j/driver/internal/types/InternalTypeSystem.java index 404b5bcfc4..2b0eb267ef 100644 --- a/driver/src/main/java/org/neo4j/driver/internal/types/InternalTypeSystem.java +++ b/driver/src/main/java/org/neo4j/driver/internal/types/InternalTypeSystem.java @@ -36,6 +36,7 @@ import static org.neo4j.driver.internal.types.TypeConstructor.RELATIONSHIP; import static org.neo4j.driver.internal.types.TypeConstructor.STRING; import static org.neo4j.driver.internal.types.TypeConstructor.TIME; +import static org.neo4j.driver.internal.types.TypeConstructor.VECTOR; import org.neo4j.driver.Value; import org.neo4j.driver.types.Type; @@ -70,6 +71,7 @@ public class InternalTypeSystem implements TypeSystem { private final TypeRepresentation dateTimeType = constructType(DATE_TIME); private final TypeRepresentation durationType = constructType(DURATION); private final TypeRepresentation nullType = constructType(NULL); + private final TypeRepresentation vectorType = constructType(VECTOR); private InternalTypeSystem() {} @@ -173,6 +175,11 @@ public Type NULL() { return nullType; } + @Override + public Type VECTOR() { + return vectorType; + } + private TypeRepresentation constructType(TypeConstructor tyCon) { return new TypeRepresentation(tyCon); } diff --git a/driver/src/main/java/org/neo4j/driver/internal/types/TypeConstructor.java b/driver/src/main/java/org/neo4j/driver/internal/types/TypeConstructor.java index 9e3824dd93..68a6095981 100644 --- a/driver/src/main/java/org/neo4j/driver/internal/types/TypeConstructor.java +++ b/driver/src/main/java/org/neo4j/driver/internal/types/TypeConstructor.java @@ -56,7 +56,8 @@ public boolean covers(Value value) { LOCAL_DATE_TIME, DATE_TIME, DURATION, - NULL; + NULL, + VECTOR; private static TypeConstructor typeConstructorOf(Value value) { return ((InternalValue) value).typeConstructor(); diff --git a/driver/src/main/java/org/neo4j/driver/internal/value/BoltValueFactory.java b/driver/src/main/java/org/neo4j/driver/internal/value/BoltValueFactory.java index e9618f7274..cd4a7534f4 100644 --- a/driver/src/main/java/org/neo4j/driver/internal/value/BoltValueFactory.java +++ b/driver/src/main/java/org/neo4j/driver/internal/value/BoltValueFactory.java @@ -244,6 +244,27 @@ public Value point(int srid, double x, double y, double z) { return ((InternalValue) Values.point(srid, x, y, z)); } + @Override + public Value vector(Class elementType, Object elements) { + Value value; + if (elements.getClass().equals(byte[].class)) { + value = ((InternalValue) Values.vector((byte[]) elements)); + } else if (elements.getClass().equals(short[].class)) { + value = ((InternalValue) Values.vector((short[]) elements)); + } else if (elements.getClass().equals(int[].class)) { + value = ((InternalValue) Values.vector((int[]) elements)); + } else if (elements.getClass().equals(long[].class)) { + value = ((InternalValue) Values.vector((long[]) elements)); + } else if (elements.getClass().equals(float[].class)) { + value = ((InternalValue) Values.vector((float[]) elements)); + } else if (elements.getClass().equals(double[].class)) { + value = ((InternalValue) Values.vector((double[]) elements)); + } else { + throw new AssertionError("Unsupported type: " + elements.getClass()); + } + return value; + } + @Override public Value unsupportedDateTimeValue(DateTimeException e) { return new UnsupportedDateTimeValue(e); diff --git a/driver/src/main/java/org/neo4j/driver/internal/value/ValueAdapter.java b/driver/src/main/java/org/neo4j/driver/internal/value/ValueAdapter.java index 24d875518d..e3ae9327a9 100644 --- a/driver/src/main/java/org/neo4j/driver/internal/value/ValueAdapter.java +++ b/driver/src/main/java/org/neo4j/driver/internal/value/ValueAdapter.java @@ -29,6 +29,7 @@ import java.util.List; import java.util.Map; import java.util.function.Function; +import org.neo4j.bolt.connection.values.Vector; import org.neo4j.driver.Value; import org.neo4j.driver.exceptions.value.NotMultiValued; import org.neo4j.driver.exceptions.value.Uncoercible; @@ -339,6 +340,11 @@ public Iterable values(Function mapFunction) { throw new NotMultiValued(type().name() + " is not iterable"); } + @Override + public Vector asBoltVector() { + throw new Uncoercible(type().name(), "Vector"); + } + @Override public final TypeConstructor typeConstructor() { return ((TypeRepresentation) type()).constructor(); diff --git a/driver/src/main/java/org/neo4j/driver/internal/value/VectorValue.java b/driver/src/main/java/org/neo4j/driver/internal/value/VectorValue.java new file mode 100644 index 0000000000..090488a3e2 --- /dev/null +++ b/driver/src/main/java/org/neo4j/driver/internal/value/VectorValue.java @@ -0,0 +1,80 @@ +/* + * Copyright (c) "Neo4j" + * Neo4j Sweden AB [https://neo4j.com] + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.neo4j.driver.internal.value; + +import org.neo4j.driver.internal.AbstractArrayVector; +import org.neo4j.driver.internal.types.InternalTypeSystem; +import org.neo4j.driver.types.ByteVector; +import org.neo4j.driver.types.DoubleVector; +import org.neo4j.driver.types.FloatVector; +import org.neo4j.driver.types.IntVector; +import org.neo4j.driver.types.LongVector; +import org.neo4j.driver.types.ShortVector; +import org.neo4j.driver.types.Type; +import org.neo4j.driver.types.Vector; + +public class VectorValue extends ObjectValueAdapter { + public VectorValue(Vector vector) { + super(vector); + } + + @Override + public Type type() { + return InternalTypeSystem.TYPE_SYSTEM.VECTOR(); + } + + @Override + public org.neo4j.bolt.connection.values.Type boltValueType() { + return org.neo4j.bolt.connection.values.Type.VECTOR; + } + + @Override + public T as(Class targetClass) { + if (targetClass.isAssignableFrom(ByteVector.class) + || targetClass.isAssignableFrom(ShortVector.class) + || targetClass.isAssignableFrom(IntVector.class) + || targetClass.isAssignableFrom(LongVector.class) + || targetClass.isAssignableFrom(FloatVector.class) + || targetClass.isAssignableFrom(DoubleVector.class)) { + return targetClass.cast(asObject()); + } else if (targetClass.isArray()) { + var arrayVector = (AbstractArrayVector) asObject(); + if (targetClass.getComponentType().equals(arrayVector.elementType())) { + return targetClass.cast(arrayVector.toArray()); + } else { + throw new AssertionError("Unsupported type: " + targetClass); + } + } + return asMapped(targetClass); + } + + @Override + public org.neo4j.bolt.connection.values.Vector asBoltVector() { + var vector = (AbstractArrayVector) asObject(); + return new org.neo4j.bolt.connection.values.Vector() { + @Override + public Class elementType() { + return vector.elementType(); + } + + @Override + public Object elements() { + return vector.elements(); + } + }; + } +} diff --git a/driver/src/main/java/org/neo4j/driver/mapping/Vector.java b/driver/src/main/java/org/neo4j/driver/mapping/Vector.java new file mode 100644 index 0000000000..38b4a0cf78 --- /dev/null +++ b/driver/src/main/java/org/neo4j/driver/mapping/Vector.java @@ -0,0 +1,42 @@ +/* + * Copyright (c) "Neo4j" + * Neo4j Sweden AB [https://neo4j.com] + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.neo4j.driver.mapping; + +import java.lang.annotation.ElementType; +import java.lang.annotation.Retention; +import java.lang.annotation.RetentionPolicy; +import java.lang.annotation.Target; +import org.neo4j.driver.util.Preview; + +/** + * Marks the annotated array as Neo4j Vector. + *

+ * Example: + *

+ * {@code
+ * // assuming the following Java record
+ * public record Entity(String name, @Vector double[] vector) {}
+ * // the vector will be stored as Neo4j Vector in the database
+ * }
+ * 
+ * + * @since 6.0.0 + */ +@Target({ElementType.PARAMETER, ElementType.RECORD_COMPONENT}) +@Retention(RetentionPolicy.RUNTIME) +@Preview(name = "Neo4j Vector") +public @interface Vector {} diff --git a/driver/src/main/java/org/neo4j/driver/types/ByteVector.java b/driver/src/main/java/org/neo4j/driver/types/ByteVector.java new file mode 100644 index 0000000000..dbc0da6551 --- /dev/null +++ b/driver/src/main/java/org/neo4j/driver/types/ByteVector.java @@ -0,0 +1,38 @@ +/* + * Copyright (c) "Neo4j" + * Neo4j Sweden AB [https://neo4j.com] + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.neo4j.driver.types; + +import org.neo4j.driver.Values; +import org.neo4j.driver.internal.InternalByteVector; +import org.neo4j.driver.util.Preview; + +/** + * Represents Neo4j Vector type that holds a sequence of {@code byte} values. + * + * @since 6.0.0 + * @see Vector + * @see Values + */ +@Preview(name = "Neo4j Vector") +public sealed interface ByteVector extends Vector permits InternalByteVector { + /** + * Returns array with vector elements. + * + * @return the array with vector elements + */ + byte[] toArray(); +} diff --git a/driver/src/main/java/org/neo4j/driver/types/DoubleVector.java b/driver/src/main/java/org/neo4j/driver/types/DoubleVector.java new file mode 100644 index 0000000000..a1fb29105c --- /dev/null +++ b/driver/src/main/java/org/neo4j/driver/types/DoubleVector.java @@ -0,0 +1,38 @@ +/* + * Copyright (c) "Neo4j" + * Neo4j Sweden AB [https://neo4j.com] + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.neo4j.driver.types; + +import org.neo4j.driver.Values; +import org.neo4j.driver.internal.InternalDoubleVector; +import org.neo4j.driver.util.Preview; + +/** + * Represents Neo4j Vector type that holds a sequence of {@code double} values. + * + * @since 6.0.0 + * @see Vector + * @see Values + */ +@Preview(name = "Neo4j Vector") +public sealed interface DoubleVector extends Vector permits InternalDoubleVector { + /** + * Returns array with vector elements. + * + * @return the array with vector elements + */ + double[] toArray(); +} diff --git a/driver/src/main/java/org/neo4j/driver/types/FloatVector.java b/driver/src/main/java/org/neo4j/driver/types/FloatVector.java new file mode 100644 index 0000000000..94cd92b968 --- /dev/null +++ b/driver/src/main/java/org/neo4j/driver/types/FloatVector.java @@ -0,0 +1,38 @@ +/* + * Copyright (c) "Neo4j" + * Neo4j Sweden AB [https://neo4j.com] + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.neo4j.driver.types; + +import org.neo4j.driver.Values; +import org.neo4j.driver.internal.InternalFloatVector; +import org.neo4j.driver.util.Preview; + +/** + * Represents Neo4j Vector type that holds a sequence of {@code float} values. + * + * @since 6.0.0 + * @see Vector + * @see Values + */ +@Preview(name = "Neo4j Vector") +public sealed interface FloatVector extends Vector permits InternalFloatVector { + /** + * Returns array with vector elements. + * + * @return the array with vector elements + */ + float[] toArray(); +} diff --git a/driver/src/main/java/org/neo4j/driver/types/IntVector.java b/driver/src/main/java/org/neo4j/driver/types/IntVector.java new file mode 100644 index 0000000000..050cddc4a3 --- /dev/null +++ b/driver/src/main/java/org/neo4j/driver/types/IntVector.java @@ -0,0 +1,38 @@ +/* + * Copyright (c) "Neo4j" + * Neo4j Sweden AB [https://neo4j.com] + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.neo4j.driver.types; + +import org.neo4j.driver.Values; +import org.neo4j.driver.internal.InternalIntVector; +import org.neo4j.driver.util.Preview; + +/** + * Represents Neo4j Vector type that holds a sequence of {@code int} values. + * + * @since 6.0.0 + * @see Vector + * @see Values + */ +@Preview(name = "Neo4j Vector") +public sealed interface IntVector extends Vector permits InternalIntVector { + /** + * Returns array with vector elements. + * + * @return the array with vector elements + */ + int[] toArray(); +} diff --git a/driver/src/main/java/org/neo4j/driver/types/LongVector.java b/driver/src/main/java/org/neo4j/driver/types/LongVector.java new file mode 100644 index 0000000000..20993f5926 --- /dev/null +++ b/driver/src/main/java/org/neo4j/driver/types/LongVector.java @@ -0,0 +1,38 @@ +/* + * Copyright (c) "Neo4j" + * Neo4j Sweden AB [https://neo4j.com] + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.neo4j.driver.types; + +import org.neo4j.driver.Values; +import org.neo4j.driver.internal.InternalLongVector; +import org.neo4j.driver.util.Preview; + +/** + * Represents Neo4j Vector type that holds a sequence of {@code long} values. + * + * @since 6.0.0 + * @see Vector + * @see Values + */ +@Preview(name = "Neo4j Vector") +public sealed interface LongVector extends Vector permits InternalLongVector { + /** + * Returns array with vector elements. + * + * @return the array with vector elements + */ + long[] toArray(); +} diff --git a/driver/src/main/java/org/neo4j/driver/types/ShortVector.java b/driver/src/main/java/org/neo4j/driver/types/ShortVector.java new file mode 100644 index 0000000000..b9b21ac73c --- /dev/null +++ b/driver/src/main/java/org/neo4j/driver/types/ShortVector.java @@ -0,0 +1,38 @@ +/* + * Copyright (c) "Neo4j" + * Neo4j Sweden AB [https://neo4j.com] + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.neo4j.driver.types; + +import org.neo4j.driver.Values; +import org.neo4j.driver.internal.InternalShortVector; +import org.neo4j.driver.util.Preview; + +/** + * Represents Neo4j Vector type that holds a sequence of {@code short} values. + * + * @since 6.0.0 + * @see Vector + * @see Values + */ +@Preview(name = "Neo4j Vector") +public sealed interface ShortVector extends Vector permits InternalShortVector { + /** + * Returns array with vector elements. + * + * @return the array with vector elements + */ + short[] toArray(); +} diff --git a/driver/src/main/java/org/neo4j/driver/types/TypeSystem.java b/driver/src/main/java/org/neo4j/driver/types/TypeSystem.java index af2380877d..efced007a1 100644 --- a/driver/src/main/java/org/neo4j/driver/types/TypeSystem.java +++ b/driver/src/main/java/org/neo4j/driver/types/TypeSystem.java @@ -19,6 +19,7 @@ import static org.neo4j.driver.internal.types.InternalTypeSystem.TYPE_SYSTEM; import org.neo4j.driver.util.Immutable; +import org.neo4j.driver.util.Preview; /** * A listing of all database types this driver can handle. @@ -155,4 +156,12 @@ static TypeSystem getDefault() { * @return the type instance */ Type NULL(); + + /** + * Returns a {@link Type} instance representing Neo4j Vector. + * @return the type instance + * @since 6.0.0 + */ + @Preview(name = "Neo4j Vector") + Type VECTOR(); } diff --git a/driver/src/main/java/org/neo4j/driver/types/Vector.java b/driver/src/main/java/org/neo4j/driver/types/Vector.java new file mode 100644 index 0000000000..2ced7590ab --- /dev/null +++ b/driver/src/main/java/org/neo4j/driver/types/Vector.java @@ -0,0 +1,53 @@ +/* + * Copyright (c) "Neo4j" + * Neo4j Sweden AB [https://neo4j.com] + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.neo4j.driver.types; + +import org.neo4j.driver.Values; +import org.neo4j.driver.util.Preview; + +/** + * Represents Neo4j Vector type. + *

+ * Values that represent vectors be created using the following methods: + *

    + *
  • {@link Values#vector(byte[])} - returns {@link ByteVector}
  • + *
  • {@link Values#vector(short[])} - returns {@link ShortVector}
  • + *
  • {@link Values#vector(int[])} - returns {@link IntVector}
  • + *
  • {@link Values#vector(long[])} - returns {@link LongVector}
  • + *
  • {@link Values#vector(float[])} - returns {@link FloatVector}
  • + *
  • {@link Values#vector(double[])} - returns {@link DoubleVector}
  • + *
+ * + * @see Values + * @since 6.0.0 + */ +@Preview(name = "Neo4j Vector") +public sealed interface Vector permits ByteVector, DoubleVector, FloatVector, IntVector, LongVector, ShortVector { + /** + * Returns the element type. + * + * @return the element type + */ + Class elementType(); + + /** + * Returns the length. + * + * @return the length + */ + int length(); +} diff --git a/driver/src/test/java/org/neo4j/driver/internal/ValuesTest.java b/driver/src/test/java/org/neo4j/driver/internal/ValuesTest.java index f606539a6a..ca489ba4cf 100644 --- a/driver/src/test/java/org/neo4j/driver/internal/ValuesTest.java +++ b/driver/src/test/java/org/neo4j/driver/internal/ValuesTest.java @@ -21,6 +21,7 @@ import static org.hamcrest.MatcherAssert.assertThat; import static org.hamcrest.Matchers.instanceOf; import static org.hamcrest.collection.IsIterableContainingInOrder.contains; +import static org.junit.jupiter.api.Assertions.assertArrayEquals; import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertFalse; import static org.junit.jupiter.api.Assertions.assertNotEquals; @@ -38,7 +39,6 @@ import static org.neo4j.driver.Values.ofToString; import static org.neo4j.driver.Values.point; import static org.neo4j.driver.Values.value; -import static org.neo4j.driver.Values.values; import static org.neo4j.driver.internal.util.ValueFactory.emptyNodeValue; import static org.neo4j.driver.internal.util.ValueFactory.emptyRelationshipValue; import static org.neo4j.driver.internal.util.ValueFactory.filledPathValue; @@ -53,11 +53,13 @@ import java.time.ZonedDateTime; import java.time.temporal.ChronoUnit; import java.util.ArrayDeque; +import java.util.Arrays; import java.util.Collection; import java.util.HashMap; import java.util.HashSet; import java.util.List; import java.util.Map; +import java.util.Objects; import java.util.Set; import java.util.function.Function; import java.util.stream.Stream; @@ -74,8 +76,15 @@ import org.neo4j.driver.internal.value.LocalTimeValue; import org.neo4j.driver.internal.value.MapValue; import org.neo4j.driver.internal.value.TimeValue; +import org.neo4j.driver.mapping.Vector; +import org.neo4j.driver.types.ByteVector; +import org.neo4j.driver.types.DoubleVector; +import org.neo4j.driver.types.FloatVector; +import org.neo4j.driver.types.IntVector; import org.neo4j.driver.types.IsoDuration; +import org.neo4j.driver.types.LongVector; import org.neo4j.driver.types.Point; +import org.neo4j.driver.types.ShortVector; class ValuesTest { @Test @@ -579,6 +588,18 @@ void shouldMapJavaRecordToMap() { var javaDuration = Duration.of(1000, ChronoUnit.MINUTES); var point2d = new InternalPoint2D(0, 0, 0); var point3d = new InternalPoint3D(0, 0, 0, 0); + var byteVector = new byte[] {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}; + var byteVectorObject = Values.vector(byteVector).as(ByteVector.class); + var shortVector = new short[] {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}; + var shortVectorObject = Values.vector(shortVector).as(ShortVector.class); + var intVector = new int[] {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}; + var intVectorObject = Values.vector(intVector).as(IntVector.class); + var longVector = new long[] {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}; + var longVectorObject = Values.vector(longVector).as(LongVector.class); + var floatVector = new float[] {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}; + var floatVectorObject = Values.vector(floatVector).as(FloatVector.class); + var doubleVector = new double[] {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}; + var doubleVectorObject = Values.vector(doubleVector).as(DoubleVector.class); var valueHolder = new ValueHolder( string, null, @@ -595,17 +616,29 @@ void shouldMapJavaRecordToMap() { period, javaDuration, point2d, - point3d); + point3d, + byteVector, + byteVectorObject, + shortVector, + shortVectorObject, + intVector, + intVectorObject, + longVector, + longVectorObject, + floatVector, + floatVectorObject, + doubleVector, + doubleVectorObject); // when var mapValue = Values.value(valueHolder); // then - assertEquals(15, mapValue.size()); + assertEquals(27, mapValue.size()); assertEquals(string, mapValue.get("string").as(String.class)); assertFalse(mapValue.containsKey("nullValue")); assertEquals(listWithString, mapValue.get("listWithString").as(List.class)); - assertEquals(bytes, mapValue.get("bytes").as(byte[].class)); + assertArrayEquals(bytes, mapValue.get("bytes").as(byte[].class)); assertEquals(bool, mapValue.get("bool").as(boolean.class)); assertEquals(boltInteger, mapValue.get("boltInteger").as(long.class)); assertEquals(boltFloat, mapValue.get("boltFloat").as(double.class)); @@ -618,6 +651,36 @@ void shouldMapJavaRecordToMap() { assertEquals(javaDuration, mapValue.get("javaDuration").as(Duration.class)); assertEquals(point2d, mapValue.get("point2d").as(Point.class)); assertEquals(point3d, mapValue.get("point3d").as(Point.class)); + assertArrayEquals( + byteVector, mapValue.get("byteVector").as(ByteVector.class).toArray()); + assertArrayEquals( + byteVector, + mapValue.get("byteVectorObject").as(ByteVector.class).toArray()); + assertArrayEquals( + shortVector, mapValue.get("shortVector").as(ShortVector.class).toArray()); + assertArrayEquals( + shortVector, + mapValue.get("shortVectorObject").as(ShortVector.class).toArray()); + assertArrayEquals( + intVector, mapValue.get("intVector").as(IntVector.class).toArray()); + assertArrayEquals( + intVector, mapValue.get("intVectorObject").as(IntVector.class).toArray()); + assertArrayEquals( + longVector, mapValue.get("longVector").as(LongVector.class).toArray()); + assertArrayEquals( + longVector, + mapValue.get("longVectorObject").as(LongVector.class).toArray()); + assertArrayEquals( + floatVector, mapValue.get("floatVector").as(FloatVector.class).toArray()); + assertArrayEquals( + floatVector, + mapValue.get("floatVectorObject").as(FloatVector.class).toArray()); + assertArrayEquals( + doubleVector, + mapValue.get("doubleVector").as(DoubleVector.class).toArray()); + assertArrayEquals( + doubleVector, + mapValue.get("doubleVectorObject").as(DoubleVector.class).toArray()); assertEquals(valueHolder, mapValue.as(ValueHolder.class)); } @@ -637,5 +700,84 @@ public record ValueHolder( Period period, Duration javaDuration, Point point2d, - Point point3d) {} + Point point3d, + @Vector byte[] byteVector, + ByteVector byteVectorObject, + @Vector short[] shortVector, + ShortVector shortVectorObject, + @Vector int[] intVector, + IntVector intVectorObject, + @Vector long[] longVector, + LongVector longVectorObject, + @Vector float[] floatVector, + FloatVector floatVectorObject, + @Vector double[] doubleVector, + DoubleVector doubleVectorObject) { + @Override + public boolean equals(Object o) { + if (o == null || getClass() != o.getClass()) return false; + ValueHolder that = (ValueHolder) o; + return bool == that.bool + && boltInteger == that.boltInteger + && Double.compare(boltFloat, that.boltFloat) == 0 + && Objects.deepEquals(bytes, that.bytes) + && Objects.equals(string, that.string) + && Objects.equals(period, that.period) + && Objects.equals(point2d, that.point2d) + && Objects.equals(point3d, that.point3d) + && Objects.equals(date, that.date) + && Objects.equals(time, that.time) + && Objects.deepEquals(intVector, that.intVector) + && Objects.equals(nullValue, that.nullValue) + && Objects.deepEquals(byteVector, that.byteVector) + && Objects.deepEquals(longVector, that.longVector) + && Objects.deepEquals(shortVector, that.shortVector) + && Objects.deepEquals(floatVector, that.floatVector) + && Objects.equals(duration, that.duration) + && Objects.equals(javaDuration, that.javaDuration) + && Objects.deepEquals(doubleVector, that.doubleVector) + && Objects.equals(dateTime, that.dateTime) + && Objects.equals(intVectorObject, that.intVectorObject) + && Objects.equals(listWithString, that.listWithString) + && Objects.equals(localDateTime, that.localDateTime) + && Objects.equals(byteVectorObject, that.byteVectorObject) + && Objects.equals(longVectorObject, that.longVectorObject) + && Objects.equals(shortVectorObject, that.shortVectorObject) + && Objects.equals(floatVectorObject, that.floatVectorObject) + && Objects.equals(doubleVectorObject, that.doubleVectorObject); + } + + @Override + public int hashCode() { + return Objects.hash( + string, + nullValue, + listWithString, + Arrays.hashCode(bytes), + bool, + boltInteger, + boltFloat, + date, + time, + dateTime, + localDateTime, + duration, + period, + javaDuration, + point2d, + point3d, + Arrays.hashCode(byteVector), + byteVectorObject, + Arrays.hashCode(shortVector), + shortVectorObject, + Arrays.hashCode(intVector), + intVectorObject, + Arrays.hashCode(longVector), + longVectorObject, + Arrays.hashCode(floatVector), + floatVectorObject, + Arrays.hashCode(doubleVector), + doubleVectorObject); + } + } } diff --git a/pom.xml b/pom.xml index 69b8235550..b047cc81e7 100644 --- a/pom.xml +++ b/pom.xml @@ -33,7 +33,7 @@ true - 5.0.0 + 6.0.0 1.0.4 diff --git a/testkit-backend/pom.xml b/testkit-backend/pom.xml index f206a49090..625e86afc9 100644 --- a/testkit-backend/pom.xml +++ b/testkit-backend/pom.xml @@ -19,6 +19,7 @@ ${project.basedir}/.. ,-processing + 21 diff --git a/testkit-backend/src/main/java/neo4j/org/testkit/backend/messages/TestkitModule.java b/testkit-backend/src/main/java/neo4j/org/testkit/backend/messages/TestkitModule.java index cab0b15e2e..6267357c67 100644 --- a/testkit-backend/src/main/java/neo4j/org/testkit/backend/messages/TestkitModule.java +++ b/testkit-backend/src/main/java/neo4j/org/testkit/backend/messages/TestkitModule.java @@ -24,6 +24,7 @@ import neo4j.org.testkit.backend.messages.requests.deserializer.TestkitCypherDateTimeDeserializer; import neo4j.org.testkit.backend.messages.requests.deserializer.TestkitCypherDurationDeserializer; import neo4j.org.testkit.backend.messages.requests.deserializer.TestkitCypherTimeDeserializer; +import neo4j.org.testkit.backend.messages.requests.deserializer.TestkitCypherVectorDeserializer; import neo4j.org.testkit.backend.messages.requests.deserializer.TestkitListDeserializer; import neo4j.org.testkit.backend.messages.requests.deserializer.types.CypherDateTime; import neo4j.org.testkit.backend.messages.requests.deserializer.types.CypherTime; @@ -40,6 +41,7 @@ import neo4j.org.testkit.backend.messages.responses.serializer.TestkitRelationshipValueSerializer; import neo4j.org.testkit.backend.messages.responses.serializer.TestkitTimeValueSerializer; import neo4j.org.testkit.backend.messages.responses.serializer.TestkitValueSerializer; +import neo4j.org.testkit.backend.messages.responses.serializer.TestkitVectorSerializer; import org.neo4j.driver.Record; import org.neo4j.driver.Value; import org.neo4j.driver.internal.value.DateTimeValue; @@ -53,7 +55,9 @@ import org.neo4j.driver.internal.value.PathValue; import org.neo4j.driver.internal.value.RelationshipValue; import org.neo4j.driver.internal.value.TimeValue; +import org.neo4j.driver.internal.value.VectorValue; import org.neo4j.driver.types.IsoDuration; +import org.neo4j.driver.types.Vector; public class TestkitModule extends SimpleModule { @Serial @@ -66,6 +70,7 @@ public TestkitModule() { this.addDeserializer(CypherTime.class, new TestkitCypherTimeDeserializer()); this.addDeserializer(IsoDuration.class, new TestkitCypherDurationDeserializer()); this.addDeserializer(LocalDate.class, new TestkitCypherDateDeserializer()); + this.addDeserializer(Vector.class, new TestkitCypherVectorDeserializer()); this.addSerializer(Value.class, new TestkitValueSerializer()); this.addSerializer(NodeValue.class, new TestkitNodeValueSerializer()); @@ -80,5 +85,6 @@ public TestkitModule() { this.addSerializer(MapValue.class, new TestkitMapValueSerializer()); this.addSerializer(PathValue.class, new TestkitPathValueSerializer()); this.addSerializer(RelationshipValue.class, new TestkitRelationshipValueSerializer()); + this.addSerializer(VectorValue.class, new TestkitVectorSerializer()); } } diff --git a/testkit-backend/src/main/java/neo4j/org/testkit/backend/messages/VectorType.java b/testkit-backend/src/main/java/neo4j/org/testkit/backend/messages/VectorType.java new file mode 100644 index 0000000000..b99c99778e --- /dev/null +++ b/testkit-backend/src/main/java/neo4j/org/testkit/backend/messages/VectorType.java @@ -0,0 +1,47 @@ +/* + * Copyright (c) "Neo4j" + * Neo4j Sweden AB [https://neo4j.com] + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package neo4j.org.testkit.backend.messages; + +import java.util.Objects; + +public enum VectorType { + BYTE("i8"), + SHORT("i16"), + INT("i32"), + LONG("i64"), + FLOAT("f32"), + DOUBLE("f64"); + + private final String name; + + VectorType(String name) { + this.name = Objects.requireNonNull(name); + } + + public String getName() { + return name; + } + + public static VectorType of(String name) { + for (var type : VectorType.values()) { + if (type.name.equals(name)) { + return type; + } + } + throw new IllegalArgumentException("Unknown type " + name); + } +} diff --git a/testkit-backend/src/main/java/neo4j/org/testkit/backend/messages/requests/GetFeatures.java b/testkit-backend/src/main/java/neo4j/org/testkit/backend/messages/requests/GetFeatures.java index a6cc752800..f73ea248f9 100644 --- a/testkit-backend/src/main/java/neo4j/org/testkit/backend/messages/requests/GetFeatures.java +++ b/testkit-backend/src/main/java/neo4j/org/testkit/backend/messages/requests/GetFeatures.java @@ -44,6 +44,7 @@ public class GetFeatures implements TestkitRequest { "Feature:Bolt:5.6", "Feature:Bolt:5.7", "Feature:Bolt:5.8", + "Feature:Bolt:6.0", "AuthorizationExpiredTreatment", "ConfHint:connection.recv_timeout_seconds", "Feature:Auth:Bearer", diff --git a/testkit-backend/src/main/java/neo4j/org/testkit/backend/messages/requests/StartTest.java b/testkit-backend/src/main/java/neo4j/org/testkit/backend/messages/requests/StartTest.java index bcfe007691..4e469169df 100644 --- a/testkit-backend/src/main/java/neo4j/org/testkit/backend/messages/requests/StartTest.java +++ b/testkit-backend/src/main/java/neo4j/org/testkit/backend/messages/requests/StartTest.java @@ -75,6 +75,11 @@ public class StartTest implements TestkitRequest { skipMessage); COMMON_SKIP_PATTERN_TO_REASON.put( "^.*\\.TestHomeDbMixedCluster\\.test_connection_acquisition_timeout_during_fallback$", skipMessage); + COMMON_SKIP_PATTERN_TO_REASON.put( + "^.*\\.TestConnectionAcquisitionTimeoutMs\\.test_does_encompass_router_route_response$", skipMessage); + COMMON_SKIP_PATTERN_TO_REASON.put( + "^.*\\.TestConnectionAcquisitionTimeoutMs\\.test_router_handshake_shares_acquisition_timeout$", + skipMessage); skipMessage = "This test needs updating to implement expected behaviour"; COMMON_SKIP_PATTERN_TO_REASON.put( "^.*\\.TestAuthenticationSchemes[^.]+\\.test_custom_scheme_empty$", skipMessage); diff --git a/testkit-backend/src/main/java/neo4j/org/testkit/backend/messages/requests/deserializer/TestkitCypherVectorDeserializer.java b/testkit-backend/src/main/java/neo4j/org/testkit/backend/messages/requests/deserializer/TestkitCypherVectorDeserializer.java new file mode 100644 index 0000000000..edcebc4d78 --- /dev/null +++ b/testkit-backend/src/main/java/neo4j/org/testkit/backend/messages/requests/deserializer/TestkitCypherVectorDeserializer.java @@ -0,0 +1,147 @@ +/* + * Copyright (c) "Neo4j" + * Neo4j Sweden AB [https://neo4j.com] + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package neo4j.org.testkit.backend.messages.requests.deserializer; + +import com.fasterxml.jackson.core.JsonParser; +import com.fasterxml.jackson.databind.DeserializationContext; +import com.fasterxml.jackson.databind.deser.std.StdDeserializer; +import java.io.IOException; +import java.io.Serial; +import java.util.Arrays; +import neo4j.org.testkit.backend.messages.VectorType; +import org.neo4j.driver.internal.InternalByteVector; +import org.neo4j.driver.internal.InternalDoubleVector; +import org.neo4j.driver.internal.InternalFloatVector; +import org.neo4j.driver.internal.InternalIntVector; +import org.neo4j.driver.internal.InternalLongVector; +import org.neo4j.driver.internal.InternalShortVector; +import org.neo4j.driver.types.Vector; + +public class TestkitCypherVectorDeserializer extends StdDeserializer { + @Serial + private static final long serialVersionUID = 3489940766207129614L; + + @SuppressWarnings("serial") + private final TestkitCypherTypeMapper mapper; + + public TestkitCypherVectorDeserializer() { + super(Vector.class); + mapper = new TestkitCypherTypeMapper(); + } + + @Override + public Vector deserialize(JsonParser p, DeserializationContext ctxt) throws IOException { + var data = mapper.mapData(p, ctxt, new CypherVectorData()); + return switch (VectorType.of(data.dtype)) { + case BYTE -> new InternalByteVector(deserializeToBytes(data.data)); + case SHORT -> new InternalShortVector(deserializeToShorts(data.data)); + case INT -> new InternalIntVector(deserializeToIntegers(data.data)); + case LONG -> new InternalLongVector(deserializeToLongs(data.data)); + case FLOAT -> new InternalFloatVector(deserializeToFloats(data.data)); + case DOUBLE -> new InternalDoubleVector(deserializeToDoubles(data.data)); + }; + } + + private static final class CypherVectorData { + String dtype; + String data; + } + + public static byte[] deserializeToBytes(String hex) { + if (hex.isEmpty()) { + return new byte[0]; + } + var parts = hex.trim().split("\\s+"); + var result = new byte[parts.length]; + for (var i = 0; i < parts.length; i++) { + result[i] = (byte) Integer.parseInt(parts[i], 16); + } + return result; + } + + public static short[] deserializeToShorts(String hex) { + if (hex.isEmpty()) { + return new short[0]; + } + var parts = hex.trim().split("\\s+"); + if (parts.length % 2 != 0) throw new IllegalArgumentException("Invalid string: " + hex); + + var result = new short[parts.length / 2]; + for (var i = 0; i < result.length; i++) { + var hi = Integer.parseInt(parts[i * 2], 16); + var lo = Integer.parseInt(parts[i * 2 + 1], 16); + result[i] = (short) ((hi << 8) | lo); + } + return result; + } + + public static int[] deserializeToIntegers(String hex) { + if (hex.isEmpty()) { + return new int[0]; + } + var parts = hex.trim().split("\\s+"); + if (parts.length % 4 != 0) throw new IllegalArgumentException("Invalid string: " + hex); + + var result = new int[parts.length / 4]; + for (var i = 0; i < result.length; i++) { + var b0 = Integer.parseInt(parts[i * 4], 16); + var b1 = Integer.parseInt(parts[i * 4 + 1], 16); + var b2 = Integer.parseInt(parts[i * 4 + 2], 16); + var b3 = Integer.parseInt(parts[i * 4 + 3], 16); + result[i] = (b0 << 24) | (b1 << 16) | (b2 << 8) | b3; + } + return result; + } + + public static long[] deserializeToLongs(String hex) { + if (hex.isEmpty()) { + return new long[0]; + } + var parts = hex.trim().split("\\s+"); + if (parts.length % 8 != 0) throw new IllegalArgumentException("Invalid string: " + hex); + + var result = new long[parts.length / 8]; + for (var i = 0; i < result.length; i++) { + long val = 0; + for (var j = 0; j < 8; j++) { + val = (val << 8) | Integer.parseInt(parts[i * 8 + j], 16); + } + result[i] = val; + } + return result; + } + + public static float[] deserializeToFloats(String hex) { + if (hex.isEmpty()) { + return new float[0]; + } + var bits = deserializeToIntegers(hex); + var result = new float[bits.length]; + for (var i = 0; i < bits.length; i++) { + result[i] = Float.intBitsToFloat(bits[i]); + } + return result; + } + + public static double[] deserializeToDoubles(String hex) { + if (hex.isEmpty()) { + return new double[0]; + } + var bits = deserializeToLongs(hex); + return Arrays.stream(bits).mapToDouble(Double::longBitsToDouble).toArray(); + } +} diff --git a/testkit-backend/src/main/java/neo4j/org/testkit/backend/messages/responses/serializer/GenUtils.java b/testkit-backend/src/main/java/neo4j/org/testkit/backend/messages/responses/serializer/GenUtils.java index 1b7bc74410..5da08ae76c 100644 --- a/testkit-backend/src/main/java/neo4j/org/testkit/backend/messages/responses/serializer/GenUtils.java +++ b/testkit-backend/src/main/java/neo4j/org/testkit/backend/messages/responses/serializer/GenUtils.java @@ -26,6 +26,7 @@ import neo4j.org.testkit.backend.messages.requests.deserializer.types.CypherDateTime; import neo4j.org.testkit.backend.messages.requests.deserializer.types.CypherTime; import org.neo4j.driver.types.IsoDuration; +import org.neo4j.driver.types.Vector; @AllArgsConstructor(access = AccessLevel.PRIVATE) public final class GenUtils { @@ -87,6 +88,7 @@ public static Class cypherTypeToJavaType(String typeString) { case "CypherTime" -> CypherTime.class; case "CypherDate" -> LocalDate.class; case "CypherDuration" -> IsoDuration.class; + case "CypherVector" -> Vector.class; default -> null; }; } diff --git a/testkit-backend/src/main/java/neo4j/org/testkit/backend/messages/responses/serializer/TestkitVectorSerializer.java b/testkit-backend/src/main/java/neo4j/org/testkit/backend/messages/responses/serializer/TestkitVectorSerializer.java new file mode 100644 index 0000000000..1c1ee91206 --- /dev/null +++ b/testkit-backend/src/main/java/neo4j/org/testkit/backend/messages/responses/serializer/TestkitVectorSerializer.java @@ -0,0 +1,134 @@ +/* + * Copyright (c) "Neo4j" + * Neo4j Sweden AB [https://neo4j.com] + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package neo4j.org.testkit.backend.messages.responses.serializer; + +import static neo4j.org.testkit.backend.messages.responses.serializer.GenUtils.cypherObject; + +import com.fasterxml.jackson.core.JsonGenerator; +import com.fasterxml.jackson.databind.SerializerProvider; +import com.fasterxml.jackson.databind.ser.std.StdSerializer; +import java.io.IOException; +import java.io.Serial; +import java.util.Arrays; +import java.util.stream.Collectors; +import java.util.stream.IntStream; +import neo4j.org.testkit.backend.messages.VectorType; +import org.neo4j.driver.internal.value.VectorValue; +import org.neo4j.driver.types.ByteVector; +import org.neo4j.driver.types.DoubleVector; +import org.neo4j.driver.types.FloatVector; +import org.neo4j.driver.types.IntVector; +import org.neo4j.driver.types.LongVector; +import org.neo4j.driver.types.ShortVector; +import org.neo4j.driver.types.Vector; + +public class TestkitVectorSerializer extends StdSerializer { + @Serial + private static final long serialVersionUID = 5456264010641357998L; + + public TestkitVectorSerializer() { + super(VectorValue.class); + } + + @Override + public void serialize(VectorValue vectorValue, JsonGenerator gen, SerializerProvider serializerProvider) + throws IOException { + String dtype; + String data; + var vector = vectorValue.as(Vector.class); + switch (vector) { + case ByteVector byteVector -> { + dtype = VectorType.BYTE.getName(); + data = toHexString(byteVector.toArray()); + } + case ShortVector shortVector -> { + dtype = VectorType.SHORT.getName(); + data = toHexString(shortVector.toArray()); + } + case IntVector intVector -> { + dtype = VectorType.INT.getName(); + data = toHexString(intVector.toArray()); + } + case LongVector longVector -> { + dtype = VectorType.LONG.getName(); + data = toHexString(longVector.toArray()); + } + case FloatVector floatVector -> { + dtype = VectorType.FLOAT.getName(); + data = toHexString(floatVector.toArray()); + } + case DoubleVector doubleVector -> { + dtype = VectorType.DOUBLE.getName(); + data = toHexString(doubleVector.toArray()); + } + default -> throw new IllegalArgumentException( + "Unsupported vector type: " + vector.getClass().getName()); + } + + cypherObject(gen, "CypherVector", () -> { + gen.writeFieldName("dtype"); + gen.writeString(dtype); + gen.writeFieldName("data"); + gen.writeString(data); + }); + } + + public static String toHexString(byte[] array) { + return IntStream.range(0, array.length) + .mapToObj(i -> String.format("%02x", array[i] & 0xFF)) + .collect(Collectors.joining(" ")); + } + + public static String toHexString(short[] array) { + return IntStream.range(0, array.length) + .mapToObj(i -> String.format("%02x %02x", (array[i] >> 8) & 0xFF, array[i] & 0xFF)) + .collect(Collectors.joining(" ")); + } + + public static String toHexString(int[] array) { + return Arrays.stream(array) + .mapToObj(val -> String.format( + "%02x %02x %02x %02x", (val >> 24) & 0xFF, (val >> 16) & 0xFF, (val >> 8) & 0xFF, val & 0xFF)) + .collect(Collectors.joining(" ")); + } + + public static String toHexString(long[] array) { + return Arrays.stream(array) + .mapToObj(val -> String.format( + "%02x %02x %02x %02x %02x %02x %02x %02x", + (val >> 56) & 0xFF, + (val >> 48) & 0xFF, + (val >> 40) & 0xFF, + (val >> 32) & 0xFF, + (val >> 24) & 0xFF, + (val >> 16) & 0xFF, + (val >> 8) & 0xFF, + val & 0xFF)) + .collect(Collectors.joining(" ")); + } + + public static String toHexString(float[] array) { + return toHexString(IntStream.range(0, array.length) + .map(i -> Float.floatToRawIntBits(array[i])) + .toArray()); + } + + public static String toHexString(double[] array) { + return toHexString( + Arrays.stream(array).mapToLong(Double::doubleToRawLongBits).toArray()); + } +} diff --git a/testkit/Dockerfile b/testkit/Dockerfile index 5515042538..04139fdc3b 100644 --- a/testkit/Dockerfile +++ b/testkit/Dockerfile @@ -1,6 +1,6 @@ FROM debian:bullseye-slim -ENV JAVA_HOME=/usr/lib/jvm/openjdk-17 \ +ENV JAVA_HOME=/usr/lib/jvm/openjdk-21 \ PYTHON=python3 RUN apt-get update && apt-get install -y \ @@ -12,7 +12,7 @@ RUN apt-get update && apt-get install -y \ && rm -rf /var/lib/apt/lists/* # https://hub.docker.com/_/eclipse-temurin -COPY --from=eclipse-temurin:17-jdk /opt/java/openjdk $JAVA_HOME +COPY --from=eclipse-temurin:21-jdk /opt/java/openjdk $JAVA_HOME COPY --from=maven:3.9.2-eclipse-temurin-17 /usr/share/maven /opt/apache-maven