diff --git a/src/it/java/io/weaviate/integration/DataITest.java b/src/it/java/io/weaviate/integration/DataITest.java index 0c701985e..e4d1338f9 100644 --- a/src/it/java/io/weaviate/integration/DataITest.java +++ b/src/it/java/io/weaviate/integration/DataITest.java @@ -13,10 +13,13 @@ import io.weaviate.client6.v1.api.collections.Property; import io.weaviate.client6.v1.api.collections.Vectors; import io.weaviate.client6.v1.api.collections.WeaviateObject; +import io.weaviate.client6.v1.api.collections.data.BatchReference; +import io.weaviate.client6.v1.api.collections.data.DeleteManyResponse; import io.weaviate.client6.v1.api.collections.data.Reference; import io.weaviate.client6.v1.api.collections.query.Metadata; import io.weaviate.client6.v1.api.collections.query.QueryMetadata; import io.weaviate.client6.v1.api.collections.query.QueryReference; +import io.weaviate.client6.v1.api.collections.query.Where; import io.weaviate.client6.v1.api.collections.vectorindex.Hnsw; import io.weaviate.client6.v1.api.collections.vectorizers.NoneVectorizer; import io.weaviate.containers.Container; @@ -271,4 +274,117 @@ public void testUpdate() throws IOException { .returns(vector, Vectors::getDefaultSingle); }); } + + @Test + public void testDeleteMany() throws IOException { + // Arrange + var nsThings = ns("Things"); + + client.collections.create(nsThings, + collection -> collection + .properties(Property.integer("last_used"))); + + var things = client.collections.use(nsThings); + things.data.insert(Map.of("last_used", 1)); + var delete_1 = things.data.insert(Map.of("last_used", 5)).metadata().uuid(); + var delete_2 = things.data.insert(Map.of("last_used", 9)).metadata().uuid(); + + // Act (dry run) + things.data.deleteMany( + Where.property("last_used").gte(4), + opt -> opt.dryRun(true)); + + // Assert + Assertions.assertThat(things.data.exists(delete_1)).as("#1 exists").isTrue(); + Assertions.assertThat(things.data.exists(delete_2)).as("#2 exists").isTrue(); + + // Act (live run) + var deleted = things.data.deleteMany( + Where.property("last_used").gte(4), + opt -> opt.verbose(true)); + + // Assert + Assertions.assertThat(deleted) + .returns(2L, DeleteManyResponse::matches) + .returns(2L, DeleteManyResponse::successful) + .returns(0L, DeleteManyResponse::failed) + .extracting(DeleteManyResponse::objects, InstanceOfAssertFactories.list(DeleteManyResponse.DeletedObject.class)) + .extracting(DeleteManyResponse.DeletedObject::uuid) + .containsOnly(delete_1, delete_2); + + var count = things.aggregate.overAll( + cnt -> cnt + .objectLimit(100) + .includeTotalCount(true)) + .totalCount(); + + Assertions.assertThat(count) + .as("one object remaining") + .isEqualTo(1); + + } + + @Test + public void testInsertMany() throws IOException { + // Arrange + var nsThings = ns("Things"); + + client.collections.create(nsThings); + + var things = client.collections.use(nsThings); + + // Act + things.data.insertMany(Map.of(), Map.of(), Map.of(), Map.of(), Map.of()); + + // Assert + var count = things.aggregate.overAll( + cnt -> cnt + .objectLimit(100) + .includeTotalCount(true)) + .totalCount(); + + Assertions.assertThat(count) + .as("collection has 5 objects") + .isEqualTo(5); + } + + @Test + public void testReferenceAddMany() throws IOException { + // Arrange + var nsCities = ns("Cities"); + var nsAirports = ns("Airports"); + + client.collections.create(nsAirports); + client.collections.create(nsCities, c -> c + .references(Property.reference("hasAirports", nsAirports))); + + var airports = client.collections.use(nsAirports); + var cities = client.collections.use(nsCities); + + var alpha = airports.data.insert(Map.of()).uuid(); + var goodburg = cities.data.insert(Map.of(), city -> city + .reference("hasAirports", Reference.uuids(alpha))); + + // Act + var newAirports = airports.data.insertMany(Map.of(), Map.of()); + var bravo = newAirports.responses().get(0).uuid(); + var charlie = newAirports.responses().get(1).uuid(); + + var response = cities.data.referenceAddMany(BatchReference.uuids(goodburg, "hasAirports", bravo, charlie)); + + // Assert + Assertions.assertThat(response.errors()).isEmpty(); + + var goodburgAirports = cities.query.byId(goodburg.metadata().uuid(), + city -> city.returnReferences( + QueryReference.single("hasAirports", + airport -> airport.returnMetadata(Metadata.ID)))); + + Assertions.assertThat(goodburgAirports).get() + .as("Goodburg has 3 airports") + .extracting(WeaviateObject::references) + .extracting(references -> references.get("hasAirports"), InstanceOfAssertFactories.list(WeaviateObject.class)) + .extracting(WeaviateObject::uuid) + .contains(alpha, bravo, charlie); + } } diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/CollectionHandle.java b/src/main/java/io/weaviate/client6/v1/api/collections/CollectionHandle.java index 2dd4529bd..aeea29032 100644 --- a/src/main/java/io/weaviate/client6/v1/api/collections/CollectionHandle.java +++ b/src/main/java/io/weaviate/client6/v1/api/collections/CollectionHandle.java @@ -20,8 +20,8 @@ public CollectionHandle( CollectionDescriptor collectionDescriptor) { this.config = new WeaviateConfigClient(collectionDescriptor, restTransport, grpcTransport); + this.data = new WeaviateDataClient<>(collectionDescriptor, restTransport, grpcTransport); this.query = new WeaviateQueryClient<>(collectionDescriptor, grpcTransport); - this.data = new WeaviateDataClient<>(collectionDescriptor, restTransport, this.query); this.aggregate = new WeaviateAggregateClient(collectionDescriptor, grpcTransport); } } diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/CollectionHandleAsync.java b/src/main/java/io/weaviate/client6/v1/api/collections/CollectionHandleAsync.java index ec41a5a21..27b262c6d 100644 --- a/src/main/java/io/weaviate/client6/v1/api/collections/CollectionHandleAsync.java +++ b/src/main/java/io/weaviate/client6/v1/api/collections/CollectionHandleAsync.java @@ -20,8 +20,8 @@ public CollectionHandleAsync( CollectionDescriptor collectionDescriptor) { this.config = new WeaviateConfigClientAsync(collectionDescriptor, restTransport, grpcTransport); + this.data = new WeaviateDataClientAsync<>(collectionDescriptor, restTransport, grpcTransport); this.query = new WeaviateQueryClientAsync<>(collectionDescriptor, grpcTransport); - this.data = new WeaviateDataClientAsync<>(collectionDescriptor, restTransport, this.query); this.aggregate = new WeaviateAggregateClientAsync(collectionDescriptor, grpcTransport); } } diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/ObjectMetadata.java b/src/main/java/io/weaviate/client6/v1/api/collections/ObjectMetadata.java index 6fdecbb84..6732b5256 100644 --- a/src/main/java/io/weaviate/client6/v1/api/collections/ObjectMetadata.java +++ b/src/main/java/io/weaviate/client6/v1/api/collections/ObjectMetadata.java @@ -1,5 +1,6 @@ package io.weaviate.client6.v1.api.collections; +import java.util.UUID; import java.util.function.Function; import com.google.gson.annotations.SerializedName; @@ -24,6 +25,10 @@ public static class Builder implements ObjectBuilder { private String uuid; private Vectors vectors; + public Builder uuid(UUID uuid) { + return uuid(uuid.toString()); + } + public Builder uuid(String uuid) { this.uuid = uuid; return this; diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/Vectors.java b/src/main/java/io/weaviate/client6/v1/api/collections/Vectors.java index cfd647ce3..9638bed49 100644 --- a/src/main/java/io/weaviate/client6/v1/api/collections/Vectors.java +++ b/src/main/java/io/weaviate/client6/v1/api/collections/Vectors.java @@ -104,6 +104,10 @@ public Float[][] getDefaultMulti() { return getMulti(VectorIndex.DEFAULT_VECTOR_NAME); } + public Map asMap() { + return Map.copyOf(namedVectors); + } + public static enum CustomTypeAdapterFactory implements TypeAdapterFactory { INSTANCE; diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/WeaviateObject.java b/src/main/java/io/weaviate/client6/v1/api/collections/WeaviateObject.java index 2f7345f2a..7b1b40306 100644 --- a/src/main/java/io/weaviate/client6/v1/api/collections/WeaviateObject.java +++ b/src/main/java/io/weaviate/client6/v1/api/collections/WeaviateObject.java @@ -6,6 +6,7 @@ import java.util.HashMap; import java.util.List; import java.util.Map; +import java.util.function.Function; import com.google.gson.Gson; import com.google.gson.JsonArray; @@ -26,6 +27,21 @@ public record WeaviateObject( Map> references, M metadata) { + /** Shorthand for accesing objects's UUID from metadata. */ + public String uuid() { + return metadata.uuid(); + } + + /** Shorthand for accesing objects's vectors from metadata. */ + public Vectors vectors() { + return metadata.vectors(); + } + + public static WeaviateObject of( + Function, ObjectBuilder>> fn) { + return fn.apply(new Builder<>()).build(); + } + public WeaviateObject(Builder builder) { this(builder.collection, builder.properties, builder.references, builder.metadata); } diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/data/BatchReference.java b/src/main/java/io/weaviate/client6/v1/api/collections/data/BatchReference.java new file mode 100644 index 000000000..8e14a04ec --- /dev/null +++ b/src/main/java/io/weaviate/client6/v1/api/collections/data/BatchReference.java @@ -0,0 +1,108 @@ +package io.weaviate.client6.v1.api.collections.data; + +import java.io.IOException; +import java.util.Arrays; + +import com.google.gson.TypeAdapter; +import com.google.gson.stream.JsonReader; +import com.google.gson.stream.JsonWriter; + +import io.weaviate.client6.v1.api.collections.WeaviateObject; + +public record BatchReference(String fromCollection, String fromProperty, String fromUuid, Reference reference) { + + public static BatchReference[] objects(WeaviateObject fromObject, String fromProperty, + WeaviateObject... toObjects) { + return Arrays.stream(toObjects) + .map(to -> new BatchReference( + fromObject.collection(), fromProperty, fromObject.metadata().uuid(), + Reference.object(to))) + .toArray(BatchReference[]::new); + } + + public static BatchReference[] uuids(WeaviateObject fromObject, String fromProperty, + String... toUuids) { + return Arrays.stream(toUuids) + .map(to -> new BatchReference( + fromObject.collection(), fromProperty, fromObject.metadata().uuid(), + Reference.uuids(to))) + .toArray(BatchReference[]::new); + } + + public static final TypeAdapter TYPE_ADAPTER = new TypeAdapter() { + + @Override + public void write(JsonWriter out, BatchReference value) throws IOException { + out.beginObject(); + + out.name("from"); + out.value(Reference.toBeacon(value.fromCollection, value.fromProperty, value.fromUuid)); + + out.name("to"); + out.value(Reference.toBeacon(value.reference.collection(), value.reference.uuids().get(0))); + + // TODO: add tenant + + out.endObject(); + } + + @Override + public BatchReference read(JsonReader in) throws IOException { + String fromCollection = null; + String fromProperty = null; + String fromUuid = null; + Reference toReference = null; + + in.beginObject(); + while (in.hasNext()) { + switch (in.nextName()) { + + case "from": { + var beacon = in.nextString(); + beacon = beacon.replaceFirst("weaviate://localhost/", ""); + + var parts = beacon.split("/"); + fromCollection = parts[0]; + fromUuid = parts[1]; + fromProperty = parts[2]; + break; + } + + case "to": { + String collection = null; + String id = null; + + var beacon = in.nextString(); + beacon = beacon.replaceFirst("weaviate://localhost/", ""); + if (beacon.contains("/")) { + var parts = beacon.split("/"); + collection = parts[0]; + id = parts[1]; + } else { + id = beacon; + } + toReference = new Reference(collection, id); + break; + } + + // case "tenant": + // switch (in.peek()) { + // case STRING: + // in.nextString(); + // case NULL: + // in.nextNull(); + // default: + // // We don't expect anything else + // } + // System.out.println("processed tenant"); + // break; + // default: + // in.skipValue(); + } + } + in.endObject(); + + return new BatchReference(fromCollection, fromProperty, fromUuid, toReference); + } + }.nullSafe(); +} diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/data/DeleteManyRequest.java b/src/main/java/io/weaviate/client6/v1/api/collections/data/DeleteManyRequest.java new file mode 100644 index 000000000..fe4788481 --- /dev/null +++ b/src/main/java/io/weaviate/client6/v1/api/collections/data/DeleteManyRequest.java @@ -0,0 +1,99 @@ +package io.weaviate.client6.v1.api.collections.data; + +import java.util.function.Function; + +import io.weaviate.client6.v1.api.collections.query.Where; +import io.weaviate.client6.v1.internal.ObjectBuilder; +import io.weaviate.client6.v1.internal.grpc.ByteStringUtil; +import io.weaviate.client6.v1.internal.grpc.Rpc; +import io.weaviate.client6.v1.internal.grpc.protocol.WeaviateGrpc.WeaviateBlockingStub; +import io.weaviate.client6.v1.internal.grpc.protocol.WeaviateGrpc.WeaviateFutureStub; +import io.weaviate.client6.v1.internal.grpc.protocol.WeaviateProtoBase; +import io.weaviate.client6.v1.internal.grpc.protocol.WeaviateProtoBatchDelete; +import io.weaviate.client6.v1.internal.orm.CollectionDescriptor; + +public record DeleteManyRequest(Where where, Boolean verbose, Boolean dryRun) { + + public static Rpc rpc( + CollectionDescriptor collectionDescriptor) { + return Rpc + .of( + request -> { + var message = WeaviateProtoBatchDelete.BatchDeleteRequest.newBuilder(); + message.setCollection(collectionDescriptor.name()); + + if (request.verbose != null) { + message.setVerbose(request.verbose); + } + if (request.dryRun != null) { + message.setDryRun(request.dryRun); + } + + var filters = WeaviateProtoBase.Filters.newBuilder(); + request.where.appendTo(filters); + message.setFilters(filters); + + return message.build(); + }, + reply -> { + var objects = reply.getObjectsList() + .stream() + .map(obj -> new DeleteManyResponse.DeletedObject( + ByteStringUtil.decodeUuid(obj.getUuid()).toString(), + obj.getSuccessful(), + obj.getError())) + .toList(); + + return new DeleteManyResponse( + reply.getTook(), + reply.getFailed(), + reply.getMatches(), + reply.getSuccessful(), + objects); + }, + () -> WeaviateBlockingStub::batchDelete, + () -> WeaviateFutureStub::batchDelete); + } + + public static DeleteManyRequest of(Where where) { + return of(where, ObjectBuilder.identity()); + } + + public DeleteManyRequest(Builder builder) { + this( + builder.where, + builder.verbose, + builder.dryRun); + } + + public static DeleteManyRequest of(Where where, Function> fn) { + return fn.apply(new Builder(where)).build(); + } + + public static class Builder implements ObjectBuilder { + // Required request parameters; + private final Where where; + + private Boolean verbose; + private Boolean dryRun; + + public Builder(Where where) { + this.where = where; + } + + public Builder verbose(boolean verbose) { + this.verbose = verbose; + return this; + } + + public Builder dryRun(boolean dryRun) { + this.dryRun = dryRun; + return this; + } + + @Override + public DeleteManyRequest build() { + return new DeleteManyRequest(this); + } + } +} diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/data/DeleteManyResponse.java b/src/main/java/io/weaviate/client6/v1/api/collections/data/DeleteManyResponse.java new file mode 100644 index 000000000..5df7fd93a --- /dev/null +++ b/src/main/java/io/weaviate/client6/v1/api/collections/data/DeleteManyResponse.java @@ -0,0 +1,9 @@ +package io.weaviate.client6.v1.api.collections.data; + +import java.util.List; + +public record DeleteManyResponse(float took, long failed, long matches, long successful, List objects) { + + public static record DeletedObject(String uuid, boolean successful, String error) { + } +} diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/data/InsertManyRequest.java b/src/main/java/io/weaviate/client6/v1/api/collections/data/InsertManyRequest.java new file mode 100644 index 000000000..6c6f42748 --- /dev/null +++ b/src/main/java/io/weaviate/client6/v1/api/collections/data/InsertManyRequest.java @@ -0,0 +1,171 @@ +package io.weaviate.client6.v1.api.collections.data; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; +import java.util.UUID; +import java.util.stream.Collectors; + +import io.weaviate.client6.v1.api.collections.ObjectMetadata; +import io.weaviate.client6.v1.api.collections.WeaviateObject; +import io.weaviate.client6.v1.internal.grpc.ByteStringUtil; +import io.weaviate.client6.v1.internal.grpc.Rpc; +import io.weaviate.client6.v1.internal.grpc.protocol.WeaviateGrpc.WeaviateBlockingStub; +import io.weaviate.client6.v1.internal.grpc.protocol.WeaviateGrpc.WeaviateFutureStub; +import io.weaviate.client6.v1.internal.grpc.protocol.WeaviateProtoBase; +import io.weaviate.client6.v1.internal.grpc.protocol.WeaviateProtoBase.Vectors.VectorType; +import io.weaviate.client6.v1.internal.grpc.protocol.WeaviateProtoBatch; +import io.weaviate.client6.v1.internal.orm.CollectionDescriptor; + +public record InsertManyRequest(List> objects) { + + @SafeVarargs + public InsertManyRequest(WeaviateObject... objects) { + this(Arrays.asList(objects)); + } + + @SafeVarargs + public static final InsertManyRequest of(T... properties) { + var objects = Arrays.stream(properties) + .map(p -> WeaviateObject.of( + obj -> obj + .properties(p) + .metadata(ObjectMetadata.of(m -> m.uuid(UUID.randomUUID()))))) + .toList(); + return new InsertManyRequest(objects); + } + + public static Rpc, WeaviateProtoBatch.BatchObjectsRequest, InsertManyResponse, WeaviateProtoBatch.BatchObjectsReply> rpc( + List> insertObjects, + CollectionDescriptor collectionsDescriptor) { + return Rpc.of( + request -> { + var message = WeaviateProtoBatch.BatchObjectsRequest.newBuilder(); + + var batch = request.objects.stream().map(obj -> { + var batchObject = WeaviateProtoBatch.BatchObject.newBuilder(); + buildObject(batchObject, obj, collectionsDescriptor); + return batchObject.build(); + }).toList(); + + message.addAllObjects(batch); + return message.build(); + }, + response -> { + var insertErrors = response.getErrorsList(); + + var responses = new ArrayList(insertObjects.size()); + var errors = new ArrayList(insertErrors.size()); + var uuids = new ArrayList(); + + var failed = insertErrors.stream() + .collect(Collectors.toMap(err -> err.getIndex(), err -> err.getError())); + + var iter = insertObjects.listIterator(); + while (iter.hasNext()) { + var idx = iter.nextIndex(); + var next = iter.next(); + var uuid = next.metadata() != null ? next.metadata().uuid() : null; + + if (failed.containsKey(idx)) { + var err = failed.get(idx); + errors.add(err); + responses.add(new InsertManyResponse.InsertObject(uuid, false, err)); + } else { + uuids.add(uuid); + responses.add(new InsertManyResponse.InsertObject(uuid, true, null)); + } + } + + return new InsertManyResponse(response.getTook(), responses, uuids, errors); + }, + () -> WeaviateBlockingStub::batchObjects, + () -> WeaviateFutureStub::batchObjects); + } + + public static void buildObject(WeaviateProtoBatch.BatchObject.Builder object, + WeaviateObject insert, + CollectionDescriptor collectionDescriptor) { + object.setCollection(collectionDescriptor.name()); + + var metadata = insert.metadata(); + if (metadata != null) { + if (metadata.uuid() != null) { + object.setUuid(metadata.uuid()); + } + + if (metadata.vectors() != null) { + var vectors = metadata.vectors().asMap() + .entrySet().stream().map(entry -> { + var value = entry.getValue(); + var vector = WeaviateProtoBase.Vectors.newBuilder() + .setName(entry.getKey()); + + if (value instanceof Float[] single) { + vector.setType(VectorType.VECTOR_TYPE_SINGLE_FP32); + vector.setVectorBytes(ByteStringUtil.encodeVectorSingle(single)); + } else if (value instanceof Float[][] multi) { + vector.setVectorBytes(ByteStringUtil.encodeVectorMulti(multi)); + vector.setType(VectorType.VECTOR_TYPE_MULTI_FP32); + } + + return vector.build(); + }).toList(); + object.addAllVectors(vectors); + } + } + + var properties = WeaviateProtoBatch.BatchObject.Properties.newBuilder(); + var nonRef = com.google.protobuf.Struct.newBuilder(); + var singleRef = new ArrayList(); + var multiRef = new ArrayList(); + + collectionDescriptor + .propertiesReader(insert.properties()).readProperties() + .entrySet().stream().forEach(entry -> { + var value = entry.getValue(); + var protoValue = com.google.protobuf.Value.newBuilder(); + + if (value instanceof String v) { + protoValue.setStringValue(v); + } else if (value instanceof Number v) { + protoValue.setNumberValue(v.doubleValue()); + } else { + assert false : "(insertMany) branch not covered"; + } + + nonRef.putFields(entry.getKey(), protoValue.build()); + }); + + insert.references() + .entrySet().stream().forEach(entry -> { + var references = entry.getValue(); + + // dyma: How are we supposed to know if the reference + // is single- or multi-target? + for (var ref : references) { + if (ref.collection() == null) { + singleRef.add( + WeaviateProtoBatch.BatchObject.SingleTargetRefProps.newBuilder() + .addAllUuids(ref.uuids()) + .setPropName(entry.getKey()) + .build()); + } else { + multiRef.add( + WeaviateProtoBatch.BatchObject.MultiTargetRefProps.newBuilder() + .setTargetCollection(ref.collection()) + .addAllUuids(ref.uuids()) + .setPropName(entry.getKey()) + .build()); + } + } + }); + + properties + .setNonRefProperties(nonRef) + .addAllSingleTargetRefProps(singleRef) + .addAllMultiTargetRefProps(multiRef); + + object.setProperties(properties); + } +} diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/data/InsertManyResponse.java b/src/main/java/io/weaviate/client6/v1/api/collections/data/InsertManyResponse.java new file mode 100644 index 000000000..ee41cc458 --- /dev/null +++ b/src/main/java/io/weaviate/client6/v1/api/collections/data/InsertManyResponse.java @@ -0,0 +1,9 @@ +package io.weaviate.client6.v1.api.collections.data; + +import java.util.List; + +public record InsertManyResponse(float took, List responses, List uuids, List errors) { + + public static record InsertObject(String uuid, boolean successful, String error) { + } +} diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/data/Reference.java b/src/main/java/io/weaviate/client6/v1/api/collections/data/Reference.java index b6cb75b33..3a4206feb 100644 --- a/src/main/java/io/weaviate/client6/v1/api/collections/data/Reference.java +++ b/src/main/java/io/weaviate/client6/v1/api/collections/data/Reference.java @@ -44,6 +44,22 @@ public static Reference collection(String collection, String... uuids) { return new Reference(collection, Arrays.asList(uuids)); } + public static String toBeacon(String collection, String uuid) { + return toBeacon(collection, null, uuid); + } + + public static String toBeacon(String collection, String property, String uuid) { + var beacon = "weaviate://localhost"; + if (collection != null) { + beacon += "/" + collection; + } + beacon += "/" + uuid; + if (property != null) { + beacon += "/" + property; + } + return beacon; + } + public static final TypeAdapter TYPE_ADAPTER = new TypeAdapter() { @Override @@ -51,14 +67,7 @@ public void write(JsonWriter out, Reference value) throws IOException { for (var uuid : value.uuids()) { out.beginObject(); out.name("beacon"); - - var beacon = "weaviate://localhost"; - if (value.collection() != null) { - beacon += "/" + value.collection(); - } - beacon += "/" + uuid; - - out.value(beacon); + out.value(toBeacon(value.collection(), uuid)); out.endObject(); } } diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/data/ReferenceAddManyRequest.java b/src/main/java/io/weaviate/client6/v1/api/collections/data/ReferenceAddManyRequest.java new file mode 100644 index 000000000..808a158c5 --- /dev/null +++ b/src/main/java/io/weaviate/client6/v1/api/collections/data/ReferenceAddManyRequest.java @@ -0,0 +1,36 @@ +package io.weaviate.client6.v1.api.collections.data; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; + +import org.apache.hc.core5.http.HttpStatus; + +import io.weaviate.client6.v1.internal.json.JSON; +import io.weaviate.client6.v1.internal.rest.Endpoint; + +public record ReferenceAddManyRequest(List references) { + + public static final Endpoint endpoint( + List references) { + return Endpoint.of( + request -> "POST", + request -> "/batch/references", + (gson, request) -> JSON.serialize(request.references), + request -> Collections.emptyMap(), + code -> code != HttpStatus.SC_SUCCESS, + (gson, response) -> { + var result = JSON.deserialize(response, ReferenceAddManyResponse.class); + var errors = new ArrayList(); + + for (var err : result.errors()) { + errors.add(new ReferenceAddManyResponse.BatchError( + err.message(), + references.get(err.referenceIndex()), + err.referenceIndex())); + } + return new ReferenceAddManyResponse(errors); + }); + } + +} diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/data/ReferenceAddManyResponse.java b/src/main/java/io/weaviate/client6/v1/api/collections/data/ReferenceAddManyResponse.java new file mode 100644 index 000000000..d0fc89ace --- /dev/null +++ b/src/main/java/io/weaviate/client6/v1/api/collections/data/ReferenceAddManyResponse.java @@ -0,0 +1,41 @@ +package io.weaviate.client6.v1.api.collections.data; + +import java.lang.reflect.Type; +import java.util.ArrayList; +import java.util.List; + +import com.google.gson.JsonDeserializationContext; +import com.google.gson.JsonDeserializer; +import com.google.gson.JsonElement; +import com.google.gson.JsonParseException; + +public record ReferenceAddManyResponse(List errors) { + public record BatchError(String message, BatchReference reference, int referenceIndex) { + } + + public static enum CustomJsonDeserializer implements JsonDeserializer { + INSTANCE; + + @Override + public ReferenceAddManyResponse deserialize(JsonElement json, Type typeOfT, JsonDeserializationContext context) + throws JsonParseException { + + var errors = new ArrayList(); + int i = 0; + for (var el : json.getAsJsonArray()) { + var result = el.getAsJsonObject().get("result").getAsJsonObject(); + if (result.get("status").getAsString().equals("FAILED")) { + var errorMsg = result + .get("errors").getAsJsonObject() + .get("error").getAsJsonArray() + .get(0).getAsString(); + + var batchErr = new BatchError(errorMsg, null, i); + errors.add(batchErr); + } + i++; + } + return new ReferenceAddManyResponse(errors); + } + } +} diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/data/WeaviateDataClient.java b/src/main/java/io/weaviate/client6/v1/api/collections/data/WeaviateDataClient.java index 9228d7ec8..cb0771e0f 100644 --- a/src/main/java/io/weaviate/client6/v1/api/collections/data/WeaviateDataClient.java +++ b/src/main/java/io/weaviate/client6/v1/api/collections/data/WeaviateDataClient.java @@ -1,26 +1,34 @@ package io.weaviate.client6.v1.api.collections.data; import java.io.IOException; +import java.util.Arrays; +import java.util.List; import java.util.function.Function; import io.weaviate.client6.v1.api.collections.ObjectMetadata; import io.weaviate.client6.v1.api.collections.WeaviateObject; import io.weaviate.client6.v1.api.collections.query.WeaviateQueryClient; +import io.weaviate.client6.v1.api.collections.query.Where; +import io.weaviate.client6.v1.api.collections.query.WhereOperand; import io.weaviate.client6.v1.internal.ObjectBuilder; +import io.weaviate.client6.v1.internal.grpc.GrpcTransport; import io.weaviate.client6.v1.internal.orm.CollectionDescriptor; import io.weaviate.client6.v1.internal.rest.RestTransport; public class WeaviateDataClient { private final RestTransport restTransport; + private final GrpcTransport grpcTransport; private final CollectionDescriptor collectionDescriptor; private final WeaviateQueryClient query; public WeaviateDataClient(CollectionDescriptor collectionDescriptor, RestTransport restTransport, - WeaviateQueryClient query) { + GrpcTransport grpcTransport) { this.restTransport = restTransport; + this.grpcTransport = grpcTransport; this.collectionDescriptor = collectionDescriptor; - this.query = query; + this.query = new WeaviateQueryClient<>(collectionDescriptor, grpcTransport); + } public WeaviateObject insert(T properties) throws IOException { @@ -33,6 +41,19 @@ public WeaviateObject insert(T properties, return insert(InsertObjectRequest.of(collectionDescriptor.name(), properties, fn)); } + @SafeVarargs + public final InsertManyResponse insertMany(T... objects) { + return insertMany(InsertManyRequest.of(objects)); + } + + public InsertManyResponse insertMany(List> objects) { + return insertMany(new InsertManyRequest<>(objects)); + } + + public InsertManyResponse insertMany(InsertManyRequest request) { + return this.grpcTransport.performRequest(request, InsertManyRequest.rpc(request.objects(), collectionDescriptor)); + } + public WeaviateObject insert(InsertObjectRequest request) throws IOException { return this.restTransport.performRequest(request, InsertObjectRequest.endpoint(collectionDescriptor)); } @@ -58,6 +79,27 @@ public void delete(String uuid) throws IOException { DeleteObjectRequest._ENDPOINT); } + public DeleteManyResponse deleteMany(String... uuids) throws IOException { + var either = Arrays.stream(uuids) + .map(uuid -> (WhereOperand) Where.uuid().eq(uuid)) + .toList(); + return deleteMany(DeleteManyRequest.of(Where.or(either))); + } + + public DeleteManyResponse deleteMany(Where where) throws IOException { + return deleteMany(DeleteManyRequest.of(where)); + } + + public DeleteManyResponse deleteMany(Where where, + Function> fn) + throws IOException { + return deleteMany(DeleteManyRequest.of(where, fn)); + } + + public DeleteManyResponse deleteMany(DeleteManyRequest request) throws IOException { + return this.grpcTransport.performRequest(request, DeleteManyRequest.rpc(collectionDescriptor)); + } + public void referenceAdd(String fromUuid, String fromProperty, Reference reference) throws IOException { for (var uuid : reference.uuids()) { var singleRef = new Reference(reference.collection(), uuid); @@ -66,6 +108,15 @@ public void referenceAdd(String fromUuid, String fromProperty, Reference referen } } + public ReferenceAddManyResponse referenceAddMany(BatchReference... references) throws IOException { + return referenceAddMany(Arrays.asList(references)); + } + + public ReferenceAddManyResponse referenceAddMany(List references) throws IOException { + return this.restTransport.performRequest(new ReferenceAddManyRequest(references), + ReferenceAddManyRequest.endpoint(references)); + } + public void referenceDelete(String fromUuid, String fromProperty, Reference reference) throws IOException { for (var uuid : reference.uuids()) { var singleRef = new Reference(reference.collection(), uuid); diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/data/WeaviateDataClientAsync.java b/src/main/java/io/weaviate/client6/v1/api/collections/data/WeaviateDataClientAsync.java index 8f43ca293..506020b4b 100644 --- a/src/main/java/io/weaviate/client6/v1/api/collections/data/WeaviateDataClientAsync.java +++ b/src/main/java/io/weaviate/client6/v1/api/collections/data/WeaviateDataClientAsync.java @@ -1,7 +1,9 @@ package io.weaviate.client6.v1.api.collections.data; import java.io.IOException; +import java.util.Arrays; import java.util.Collection; +import java.util.List; import java.util.Optional; import java.util.concurrent.CompletableFuture; import java.util.function.Function; @@ -9,21 +11,26 @@ import io.weaviate.client6.v1.api.collections.ObjectMetadata; import io.weaviate.client6.v1.api.collections.WeaviateObject; import io.weaviate.client6.v1.api.collections.query.WeaviateQueryClientAsync; +import io.weaviate.client6.v1.api.collections.query.Where; +import io.weaviate.client6.v1.api.collections.query.WhereOperand; import io.weaviate.client6.v1.internal.ObjectBuilder; +import io.weaviate.client6.v1.internal.grpc.GrpcTransport; import io.weaviate.client6.v1.internal.orm.CollectionDescriptor; import io.weaviate.client6.v1.internal.rest.RestTransport; public class WeaviateDataClientAsync { private final RestTransport restTransport; + private final GrpcTransport grpcTransport; private final CollectionDescriptor collectionDescriptor; private final WeaviateQueryClientAsync query; public WeaviateDataClientAsync(CollectionDescriptor collectionDescriptor, RestTransport restTransport, - WeaviateQueryClientAsync query) { + GrpcTransport grpcTransport) { this.restTransport = restTransport; + this.grpcTransport = grpcTransport; this.collectionDescriptor = collectionDescriptor; - this.query = query; + this.query = new WeaviateQueryClientAsync<>(collectionDescriptor, grpcTransport); } public CompletableFuture> insert(T properties) throws IOException { @@ -41,6 +48,20 @@ public CompletableFuture> insert(Inser return this.restTransport.performRequestAsync(request, InsertObjectRequest.endpoint(collectionDescriptor)); } + @SafeVarargs + public final CompletableFuture insertMany(T... objects) { + return insertMany(InsertManyRequest.of(objects)); + } + + public CompletableFuture insertMany(List> objects) { + return insertMany(new InsertManyRequest<>(objects)); + } + + public CompletableFuture insertMany(InsertManyRequest request) { + return this.grpcTransport.performRequestAsync(request, + InsertManyRequest.rpc(request.objects(), collectionDescriptor)); + } + public CompletableFuture exists(String uuid) { return this.query.byId(uuid).thenApply(Optional::isPresent); } @@ -64,6 +85,27 @@ public CompletableFuture delete(String uuid) { DeleteObjectRequest._ENDPOINT); } + public CompletableFuture deleteMany(String... uuids) throws IOException { + var either = Arrays.stream(uuids) + .map(uuid -> (WhereOperand) Where.uuid().eq(uuid)) + .toList(); + return deleteMany(DeleteManyRequest.of(Where.or(either))); + } + + public CompletableFuture deleteMany(Where where) throws IOException { + return deleteMany(DeleteManyRequest.of(where)); + } + + public CompletableFuture deleteMany(Where where, + Function> fn) + throws IOException { + return deleteMany(DeleteManyRequest.of(where, fn)); + } + + public CompletableFuture deleteMany(DeleteManyRequest request) throws IOException { + return this.grpcTransport.performRequestAsync(request, DeleteManyRequest.rpc(collectionDescriptor)); + } + public CompletableFuture referenceAdd(String fromUuid, String fromProperty, Reference reference) { return forEachAsync(reference.uuids(), uuid -> { var singleRef = new Reference(reference.collection(), (String) uuid); @@ -72,6 +114,16 @@ public CompletableFuture referenceAdd(String fromUuid, String fromProperty }); } + public CompletableFuture referenceAddMany(BatchReference... references) throws IOException { + return referenceAddMany(Arrays.asList(references)); + } + + public CompletableFuture referenceAddMany(List references) + throws IOException { + return this.restTransport.performRequestAsync(new ReferenceAddManyRequest(references), + ReferenceAddManyRequest.endpoint(references)); + } + public CompletableFuture referenceDelete(String fromUuid, String fromProperty, Reference reference) { return forEachAsync(reference.uuids(), uuid -> { var singleRef = new Reference(reference.collection(), (String) uuid); diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/query/AbstractQueryClient.java b/src/main/java/io/weaviate/client6/v1/api/collections/query/AbstractQueryClient.java index 5d84e9f43..cc0017527 100644 --- a/src/main/java/io/weaviate/client6/v1/api/collections/query/AbstractQueryClient.java +++ b/src/main/java/io/weaviate/client6/v1/api/collections/query/AbstractQueryClient.java @@ -10,11 +10,11 @@ abstract class AbstractQueryClient { protected final CollectionDescriptor collection; - protected final GrpcTransport transport; + protected final GrpcTransport grpcTransport; - AbstractQueryClient(CollectionDescriptor collection, GrpcTransport transport) { + AbstractQueryClient(CollectionDescriptor collection, GrpcTransport grpcTransport) { this.collection = collection; - this.transport = transport; + this.grpcTransport = grpcTransport; } protected abstract SingleT byId(ById byId); diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/query/NearVector.java b/src/main/java/io/weaviate/client6/v1/api/collections/query/NearVector.java index d5f2525c0..266e9ddbf 100644 --- a/src/main/java/io/weaviate/client6/v1/api/collections/query/NearVector.java +++ b/src/main/java/io/weaviate/client6/v1/api/collections/query/NearVector.java @@ -4,7 +4,7 @@ import io.weaviate.client6.v1.api.collections.aggregate.AggregateObjectFilter; import io.weaviate.client6.v1.internal.ObjectBuilder; -import io.weaviate.client6.v1.internal.grpc.GRPC; +import io.weaviate.client6.v1.internal.grpc.ByteStringUtil; import io.weaviate.client6.v1.internal.grpc.protocol.WeaviateProtoAggregate; import io.weaviate.client6.v1.internal.grpc.protocol.WeaviateProtoBase; import io.weaviate.client6.v1.internal.grpc.protocol.WeaviateProtoBaseSearch; @@ -58,7 +58,7 @@ WeaviateProtoBaseSearch.NearVector.Builder protoBuilder() { var nearVector = WeaviateProtoBaseSearch.NearVector.newBuilder(); nearVector.addVectors(WeaviateProtoBase.Vectors.newBuilder() .setType(WeaviateProtoBase.Vectors.VectorType.VECTOR_TYPE_SINGLE_FP32) - .setVectorBytes(GRPC.toByteString(vector))); + .setVectorBytes(ByteStringUtil.encodeVectorSingle(vector))); if (certainty != null) { nearVector.setCertainty(certainty); diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/query/QueryRequest.java b/src/main/java/io/weaviate/client6/v1/api/collections/query/QueryRequest.java index 70348a797..12121f678 100644 --- a/src/main/java/io/weaviate/client6/v1/api/collections/query/QueryRequest.java +++ b/src/main/java/io/weaviate/client6/v1/api/collections/query/QueryRequest.java @@ -12,7 +12,7 @@ import io.weaviate.client6.v1.api.collections.ObjectMetadata; import io.weaviate.client6.v1.api.collections.Vectors; import io.weaviate.client6.v1.api.collections.WeaviateObject; -import io.weaviate.client6.v1.internal.grpc.GRPC; +import io.weaviate.client6.v1.internal.grpc.ByteStringUtil; import io.weaviate.client6.v1.internal.grpc.Rpc; import io.weaviate.client6.v1.internal.grpc.protocol.WeaviateGrpc.WeaviateBlockingStub; import io.weaviate.client6.v1.internal.grpc.protocol.WeaviateGrpc.WeaviateFutureStub; @@ -153,10 +153,10 @@ private static WeaviateObject unmarshalReferences var vectorName = vector.getName(); switch (vector.getType()) { case VECTOR_TYPE_SINGLE_FP32: - vectors.vector(vectorName, GRPC.fromByteString(vector.getVectorBytes())); + vectors.vector(vectorName, ByteStringUtil.decodeVectorSingle(vector.getVectorBytes())); break; case VECTOR_TYPE_MULTI_FP32: - vectors.vector(vectorName, GRPC.fromByteStringMulti(vector.getVectorBytes())); + vectors.vector(vectorName, ByteStringUtil.decodeVectorMulti(vector.getVectorBytes())); break; default: continue; diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/query/WeaviateQueryClient.java b/src/main/java/io/weaviate/client6/v1/api/collections/query/WeaviateQueryClient.java index c48e0171c..54801ca12 100644 --- a/src/main/java/io/weaviate/client6/v1/api/collections/query/WeaviateQueryClient.java +++ b/src/main/java/io/weaviate/client6/v1/api/collections/query/WeaviateQueryClient.java @@ -10,14 +10,14 @@ public class WeaviateQueryClient extends AbstractQueryClient>, QueryResponse, QueryResponseGrouped> { - public WeaviateQueryClient(CollectionDescriptor collection, GrpcTransport transport) { - super(collection, transport); + public WeaviateQueryClient(CollectionDescriptor collection, GrpcTransport grpcTransport) { + super(collection, grpcTransport); } @Override protected Optional> byId(ById byId) { var request = new QueryRequest(byId, null); - var result = this.transport.performRequest(request, QueryRequest.rpc(collection)); + var result = this.grpcTransport.performRequest(request, QueryRequest.rpc(collection)); return optionalFirst(result.objects()); } @@ -25,13 +25,13 @@ protected Optional> byId(ById byId) { @Override protected final QueryResponse performRequest(QueryOperator operator) { var request = new QueryRequest(operator, null); - return this.transport.performRequest(request, QueryRequest.rpc(collection)); + return this.grpcTransport.performRequest(request, QueryRequest.rpc(collection)); } @Override protected final QueryResponseGrouped performRequest(QueryOperator operator, GroupBy groupBy) { var request = new QueryRequest(operator, groupBy); - return this.transport.performRequest(request, QueryRequest.grouped(collection)); + return this.grpcTransport.performRequest(request, QueryRequest.grouped(collection)); } } diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/query/WeaviateQueryClientAsync.java b/src/main/java/io/weaviate/client6/v1/api/collections/query/WeaviateQueryClientAsync.java index 35e309e99..e8415314f 100644 --- a/src/main/java/io/weaviate/client6/v1/api/collections/query/WeaviateQueryClientAsync.java +++ b/src/main/java/io/weaviate/client6/v1/api/collections/query/WeaviateQueryClientAsync.java @@ -11,28 +11,28 @@ public class WeaviateQueryClientAsync extends AbstractQueryClient>>, CompletableFuture>, CompletableFuture>> { - public WeaviateQueryClientAsync(CollectionDescriptor collection, GrpcTransport transport) { - super(collection, transport); + public WeaviateQueryClientAsync(CollectionDescriptor collection, GrpcTransport grpcTransport) { + super(collection, grpcTransport); } @Override protected CompletableFuture>> byId( ById byId) { var request = new QueryRequest(byId, null); - var result = this.transport.performRequestAsync(request, QueryRequest.rpc(collection)); + var result = this.grpcTransport.performRequestAsync(request, QueryRequest.rpc(collection)); return result.thenApply(r -> optionalFirst(r.objects())); } @Override protected final CompletableFuture> performRequest(QueryOperator operator) { var request = new QueryRequest(operator, null); - return this.transport.performRequestAsync(request, QueryRequest.rpc(collection)); + return this.grpcTransport.performRequestAsync(request, QueryRequest.rpc(collection)); } @Override protected final CompletableFuture> performRequest(QueryOperator operator, GroupBy groupBy) { var request = new QueryRequest(operator, groupBy); - return this.transport.performRequestAsync(request, QueryRequest.grouped(collection)); + return this.grpcTransport.performRequestAsync(request, QueryRequest.grouped(collection)); } } diff --git a/src/main/java/io/weaviate/client6/v1/internal/grpc/GRPC.java b/src/main/java/io/weaviate/client6/v1/internal/grpc/ByteStringUtil.java similarity index 81% rename from src/main/java/io/weaviate/client6/v1/internal/grpc/GRPC.java rename to src/main/java/io/weaviate/client6/v1/internal/grpc/ByteStringUtil.java index e5a6c0b5a..c4dbd7785 100644 --- a/src/main/java/io/weaviate/client6/v1/internal/grpc/GRPC.java +++ b/src/main/java/io/weaviate/client6/v1/internal/grpc/ByteStringUtil.java @@ -4,16 +4,25 @@ import java.nio.ByteOrder; import java.nio.FloatBuffer; import java.util.Arrays; +import java.util.UUID; import org.apache.commons.lang3.ArrayUtils; import com.google.protobuf.ByteString; -public class GRPC { +public class ByteStringUtil { private static final ByteOrder BYTE_ORDER = ByteOrder.LITTLE_ENDIAN; + /** Decode ByteString to UUID. */ + public static UUID decodeUuid(ByteString bs) { + var buf = ByteBuffer.wrap(bs.toByteArray()); + var most = buf.getLong(); + var least = buf.getLong(); + return new UUID(most, least); + } + /** Encode Float[] to ByteString. */ - public static ByteString toByteString(Float[] vector) { + public static ByteString encodeVectorSingle(Float[] vector) { if (vector == null || vector.length == 0) { return ByteString.EMPTY; } @@ -23,7 +32,7 @@ public static ByteString toByteString(Float[] vector) { } /** Encode float[] to ByteString. */ - public static ByteString toByteString(float[] vector) { + public static ByteString encodeVectorSingle(float[] vector) { ByteBuffer buffer = ByteBuffer.allocate(vector.length * Float.BYTES).order(BYTE_ORDER); for (float f : vector) { buffer.putFloat(f); @@ -37,7 +46,7 @@ public static ByteString toByteString(float[] vector) { * The first 2 bytes of the resulting ByteString encode the number of dimensions * (uint16 / short) followed by concatenated vectors (4 bytes per element). */ - public static ByteString toByteString(Float[][] vectors) { + public static ByteString encodeVectorMulti(Float[][] vectors) { if (vectors == null || vectors.length == 0 || vectors[0].length == 0) { return ByteString.EMPTY; } @@ -56,7 +65,7 @@ public static ByteString toByteString(Float[][] vectors) { * Decode ByteString into a Float[]. ByteString size must be a multiple of * {@link Float#BYTES}, throws {@link IllegalArgumentException} otherwise. */ - public static Float[] fromByteString(ByteString bs) { + public static Float[] decodeVectorSingle(ByteString bs) { if (bs.size() % Float.BYTES != 0) { throw new IllegalArgumentException( "byte string size not a multiple of " + String.valueOf(Float.BYTES) + " (Float.BYTES)"); @@ -66,8 +75,8 @@ public static Float[] fromByteString(ByteString bs) { return ArrayUtils.toObject(vector); } - /** Decode ByteString into a Float[][]. */ - public static Float[][] fromByteStringMulti(ByteString bs) { + /** Decode ByteString to Float[][]. */ + public static Float[][] decodeVectorMulti(ByteString bs) { if (bs == null || bs.size() == 0) { return new Float[0][0]; } diff --git a/src/main/java/io/weaviate/client6/v1/internal/json/JSON.java b/src/main/java/io/weaviate/client6/v1/internal/json/JSON.java index 8285e574b..c58fa7072 100644 --- a/src/main/java/io/weaviate/client6/v1/internal/json/JSON.java +++ b/src/main/java/io/weaviate/client6/v1/internal/json/JSON.java @@ -9,6 +9,8 @@ public final class JSON { static { var gsonBuilder = new GsonBuilder(); + + // TypeAdapterFactories --------------------------------------------------- gsonBuilder.registerTypeAdapterFactory( io.weaviate.client6.v1.api.collections.WeaviateObject.CustomTypeAdapterFactory.INSTANCE); gsonBuilder.registerTypeAdapterFactory( @@ -24,12 +26,21 @@ public final class JSON { gsonBuilder.registerTypeAdapterFactory( io.weaviate.client6.v1.api.collections.Generative.CustomTypeAdapterFactory.INSTANCE); + // TypeAdapters ----------------------------------------------------------- gsonBuilder.registerTypeAdapter( io.weaviate.client6.v1.api.collections.vectorizers.NoneVectorizer.class, io.weaviate.client6.v1.api.collections.vectorizers.NoneVectorizer.TYPE_ADAPTER); gsonBuilder.registerTypeAdapter( io.weaviate.client6.v1.api.collections.data.Reference.class, io.weaviate.client6.v1.api.collections.data.Reference.TYPE_ADAPTER); + gsonBuilder.registerTypeAdapter( + io.weaviate.client6.v1.api.collections.data.BatchReference.class, + io.weaviate.client6.v1.api.collections.data.BatchReference.TYPE_ADAPTER); + + // Deserilizers ----------------------------------------------------------- + gsonBuilder.registerTypeAdapter( + io.weaviate.client6.v1.api.collections.data.ReferenceAddManyResponse.class, + io.weaviate.client6.v1.api.collections.data.ReferenceAddManyResponse.CustomJsonDeserializer.INSTANCE); gson = gsonBuilder.create(); } diff --git a/src/test/java/io/weaviate/client6/v1/internal/grpc/GRPCTest.java b/src/test/java/io/weaviate/client6/v1/internal/grpc/GRPCTest.java index d18f4a00e..1bc5d76a4 100644 --- a/src/test/java/io/weaviate/client6/v1/internal/grpc/GRPCTest.java +++ b/src/test/java/io/weaviate/client6/v1/internal/grpc/GRPCTest.java @@ -1,6 +1,7 @@ package io.weaviate.client6.v1.internal.grpc; import static org.junit.Assert.assertArrayEquals; +import static org.junit.jupiter.api.Assertions.assertEquals; import org.junit.Test; @@ -15,34 +16,42 @@ */ public class GRPCTest { @Test - public void test_toBytesString_1d() { + public void test_encodeVector_1d() { Float[] vector = { 1f, 2f, 3f }; byte[] want = { 0, 0, -128, 63, 0, 0, 0, 64, 0, 0, 64, 64 }; - byte[] got = GRPC.toByteString(vector).toByteArray(); + byte[] got = ByteStringUtil.encodeVectorSingle(vector).toByteArray(); assertArrayEquals(want, got); } @Test - public void test_fromBytesString_1d() { + public void test_decodeVector_1d() { byte[] bytes = { 0, 0, -128, 63, 0, 0, 0, 64, 0, 0, 64, 64 }; Float[] want = { 1f, 2f, 3f }; - Float[] got = GRPC.fromByteString(ByteString.copyFrom(bytes)); + Float[] got = ByteStringUtil.decodeVectorSingle(ByteString.copyFrom(bytes)); assertArrayEquals(want, got); } @Test - public void test_toBytesString_2d() { + public void test_encodeVector_2d() { Float[][] vector = { { 1f, 2f, 3f }, { 4f, 5f, 6f } }; byte[] want = { 3, 0, 0, 0, -128, 63, 0, 0, 0, 64, 0, 0, 64, 64, 0, 0, -128, 64, 0, 0, -96, 64, 0, 0, -64, 64 }; - byte[] got = GRPC.toByteString(vector).toByteArray(); + byte[] got = ByteStringUtil.encodeVectorMulti(vector).toByteArray(); assertArrayEquals(want, got); } @Test - public void test_fromBytesString_2d() { + public void test_decodeVector_2d() { byte[] bytes = { 3, 0, 0, 0, -128, 63, 0, 0, 0, 64, 0, 0, 64, 64, 0, 0, -128, 64, 0, 0, -96, 64, 0, 0, -64, 64 }; Float[][] want = { { 1f, 2f, 3f }, { 4f, 5f, 6f } }; - Float[][] got = GRPC.fromByteStringMulti(ByteString.copyFrom(bytes)); + Float[][] got = ByteStringUtil.decodeVectorMulti(ByteString.copyFrom(bytes)); assertArrayEquals(want, got); } + + @Test + public void test_decodeUuid() { + byte[] bytes = { 38, 19, -74, 24, -114, -19, 73, 43, -112, -60, 47, 96, 83, -89, -35, -23 }; + String want = "2613b618-8eed-492b-90c4-2f6053a7dde9"; + String got = ByteStringUtil.decodeUuid(ByteString.copyFrom(bytes)).toString(); + assertEquals(want, got); + } } diff --git a/src/test/java/io/weaviate/client6/v1/internal/json/JSONTest.java b/src/test/java/io/weaviate/client6/v1/internal/json/JSONTest.java index 8f4f211a1..e76516bcd 100644 --- a/src/test/java/io/weaviate/client6/v1/internal/json/JSONTest.java +++ b/src/test/java/io/weaviate/client6/v1/internal/json/JSONTest.java @@ -23,7 +23,9 @@ import io.weaviate.client6.v1.api.collections.Vectorizer; import io.weaviate.client6.v1.api.collections.Vectors; import io.weaviate.client6.v1.api.collections.WeaviateObject; +import io.weaviate.client6.v1.api.collections.data.BatchReference; import io.weaviate.client6.v1.api.collections.data.Reference; +import io.weaviate.client6.v1.api.collections.data.ReferenceAddManyResponse; import io.weaviate.client6.v1.api.collections.rerankers.CohereReranker; import io.weaviate.client6.v1.api.collections.vectorindex.Distance; import io.weaviate.client6.v1.api.collections.vectorindex.Flat; @@ -289,6 +291,17 @@ public static Object[][] testCases() { } """, }, + { + BatchReference.class, + new BatchReference("FromCollection", "fromProperty", "from-uuid", + Reference.collection("ToCollection", "to-uuid")), + """ + { + "from": "weaviate://localhost/FromCollection/from-uuid/fromProperty", + "to": "weaviate://localhost/ToCollection/to-uuid" + } + """, + }, }; } @@ -346,4 +359,26 @@ private static void compareVectors(Object got, Object want) { .withEqualsForType(Arrays::deepEquals, Float[][].class) .isEqualTo(want); } + + @Test + public void test_ReferenceAddManyResponse_CustomDeserializer() { + var json = """ + [ + { + "result": { "status": "SUCCESS", "errors": {} } + }, + { + "result": { "status": "FAILED", "errors": { + "error": [ "oops" ] + }} + } + ] + """; + + var got = JSON.deserialize(json, ReferenceAddManyResponse.class); + + Assertions.assertThat(got.errors()) + .as("response contains 1 error") + .hasSize(1); + } }