diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/CollectionHandle.java b/src/main/java/io/weaviate/client6/v1/api/collections/CollectionHandle.java index fb262cbc2..c4c49737c 100644 --- a/src/main/java/io/weaviate/client6/v1/api/collections/CollectionHandle.java +++ b/src/main/java/io/weaviate/client6/v1/api/collections/CollectionHandle.java @@ -7,34 +7,50 @@ import io.weaviate.client6.v1.api.collections.config.WeaviateConfigClient; import io.weaviate.client6.v1.api.collections.data.WeaviateDataClient; import io.weaviate.client6.v1.api.collections.pagination.Paginator; +import io.weaviate.client6.v1.api.collections.query.ConsistencyLevel; import io.weaviate.client6.v1.api.collections.query.WeaviateQueryClient; import io.weaviate.client6.v1.internal.ObjectBuilder; import io.weaviate.client6.v1.internal.grpc.GrpcTransport; import io.weaviate.client6.v1.internal.orm.CollectionDescriptor; import io.weaviate.client6.v1.internal.rest.RestTransport; -public class CollectionHandle { +public class CollectionHandle { public final WeaviateConfigClient config; - public final WeaviateDataClient data; - public final WeaviateQueryClient query; + public final WeaviateDataClient data; + public final WeaviateQueryClient query; public final WeaviateAggregateClient aggregate; + private final CollectionHandleDefaults defaults; + public CollectionHandle( RestTransport restTransport, GrpcTransport grpcTransport, - CollectionDescriptor collectionDescriptor) { - + CollectionDescriptor collectionDescriptor, + CollectionHandleDefaults defaults) { this.config = new WeaviateConfigClient(collectionDescriptor, restTransport, grpcTransport); - this.data = new WeaviateDataClient<>(collectionDescriptor, restTransport, grpcTransport); - this.query = new WeaviateQueryClient<>(collectionDescriptor, grpcTransport); - this.aggregate = new WeaviateAggregateClient(collectionDescriptor, grpcTransport); + this.aggregate = new WeaviateAggregateClient(collectionDescriptor, grpcTransport, defaults); + this.query = new WeaviateQueryClient<>(collectionDescriptor, grpcTransport, defaults); + this.data = new WeaviateDataClient<>(collectionDescriptor, restTransport, grpcTransport, defaults); + + this.defaults = defaults; + } + + /** Copy constructor that sets new defaults. */ + private CollectionHandle(CollectionHandle c, CollectionHandleDefaults defaults) { + this.config = c.config; + this.aggregate = c.aggregate; + this.query = new WeaviateQueryClient<>(c.query, defaults); + this.data = new WeaviateDataClient<>(c.data, defaults); + + this.defaults = defaults; } - public Paginator paginate() { + public Paginator paginate() { return Paginator.of(this.query); } - public Paginator paginate(Function, ObjectBuilder>> fn) { + public Paginator paginate( + Function, ObjectBuilder>> fn) { return Paginator.of(this.query, fn); } @@ -57,4 +73,12 @@ public Paginator paginate(Function, ObjectBuilder all.includeTotalCount(true)).totalCount(); } + + public ConsistencyLevel consistencyLevel() { + return defaults.consistencyLevel(); + } + + public CollectionHandle withConsistencyLevel(ConsistencyLevel consistencyLevel) { + return new CollectionHandle<>(this, CollectionHandleDefaults.of(def -> def.consistencyLevel(consistencyLevel))); + } } diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/CollectionHandleAsync.java b/src/main/java/io/weaviate/client6/v1/api/collections/CollectionHandleAsync.java index 9a646d518..dccc85dcd 100644 --- a/src/main/java/io/weaviate/client6/v1/api/collections/CollectionHandleAsync.java +++ b/src/main/java/io/weaviate/client6/v1/api/collections/CollectionHandleAsync.java @@ -9,6 +9,7 @@ import io.weaviate.client6.v1.api.collections.config.WeaviateConfigClientAsync; import io.weaviate.client6.v1.api.collections.data.WeaviateDataClientAsync; import io.weaviate.client6.v1.api.collections.pagination.AsyncPaginator; +import io.weaviate.client6.v1.api.collections.query.ConsistencyLevel; import io.weaviate.client6.v1.api.collections.query.WeaviateQueryClientAsync; import io.weaviate.client6.v1.internal.ObjectBuilder; import io.weaviate.client6.v1.internal.grpc.GrpcTransport; @@ -21,15 +22,30 @@ public class CollectionHandleAsync { public final WeaviateQueryClientAsync query; public final WeaviateAggregateClientAsync aggregate; + private final CollectionHandleDefaults defaults; + public CollectionHandleAsync( RestTransport restTransport, GrpcTransport grpcTransport, - CollectionDescriptor collectionDescriptor) { + CollectionDescriptor collectionDescriptor, + CollectionHandleDefaults defaults) { this.config = new WeaviateConfigClientAsync(collectionDescriptor, restTransport, grpcTransport); - this.data = new WeaviateDataClientAsync<>(collectionDescriptor, restTransport, grpcTransport); - this.query = new WeaviateQueryClientAsync<>(collectionDescriptor, grpcTransport); - this.aggregate = new WeaviateAggregateClientAsync(collectionDescriptor, grpcTransport); + this.aggregate = new WeaviateAggregateClientAsync(collectionDescriptor, grpcTransport, defaults); + this.query = new WeaviateQueryClientAsync<>(collectionDescriptor, grpcTransport, defaults); + this.data = new WeaviateDataClientAsync<>(collectionDescriptor, restTransport, grpcTransport, defaults); + + this.defaults = defaults; + } + + /** Copy constructor that sets new defaults. */ + private CollectionHandleAsync(CollectionHandleAsync c, CollectionHandleDefaults defaults) { + this.config = c.config; + this.aggregate = c.aggregate; + this.query = new WeaviateQueryClientAsync<>(c.query, defaults); + this.data = new WeaviateDataClientAsync<>(c.data, defaults); + + this.defaults = defaults; } public AsyncPaginator paginate() { @@ -64,4 +80,13 @@ public CompletableFuture size() { return this.aggregate.overAll(all -> all.includeTotalCount(true)) .thenApply(AggregateResponse::totalCount); } + + public ConsistencyLevel consistencyLevel() { + return defaults.consistencyLevel(); + } + + public CollectionHandleAsync withConsistencyLevel(ConsistencyLevel consistencyLevel) { + return new CollectionHandleAsync<>(this, CollectionHandleDefaults.of( + def -> def.consistencyLevel(consistencyLevel))); + } } diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/CollectionHandleDefaults.java b/src/main/java/io/weaviate/client6/v1/api/collections/CollectionHandleDefaults.java new file mode 100644 index 000000000..5512b491d --- /dev/null +++ b/src/main/java/io/weaviate/client6/v1/api/collections/CollectionHandleDefaults.java @@ -0,0 +1,221 @@ +package io.weaviate.client6.v1.api.collections; + +import java.util.HashMap; +import java.util.Map; +import java.util.function.BiFunction; +import java.util.function.Function; + +import com.google.common.util.concurrent.ListenableFuture; + +import io.weaviate.client6.v1.api.collections.query.ConsistencyLevel; +import io.weaviate.client6.v1.internal.ObjectBuilder; +import io.weaviate.client6.v1.internal.grpc.Rpc; +import io.weaviate.client6.v1.internal.grpc.protocol.WeaviateGrpc.WeaviateBlockingStub; +import io.weaviate.client6.v1.internal.grpc.protocol.WeaviateGrpc.WeaviateFutureStub; +import io.weaviate.client6.v1.internal.grpc.protocol.WeaviateProtoBatch; +import io.weaviate.client6.v1.internal.grpc.protocol.WeaviateProtoBatchDelete; +import io.weaviate.client6.v1.internal.grpc.protocol.WeaviateProtoSearchGet; +import io.weaviate.client6.v1.internal.rest.Endpoint; +import io.weaviate.client6.v1.internal.rest.EndpointBase; +import io.weaviate.client6.v1.internal.rest.JsonEndpoint; + +public record CollectionHandleDefaults(ConsistencyLevel consistencyLevel) { + private static final String CONSISTENCY_LEVEL = "consistency_level"; + + /** + * Set default values for query / aggregation requests. + * + * @return CollectionHandleDefaults derived from applying {@code fn} to + * {@link Builder}. + */ + public static CollectionHandleDefaults of(Function> fn) { + return fn.apply(new Builder()).build(); + } + + /** + * Empty collection defaults. + * + * @return An tucked builder that does not leaves all defaults unset. + */ + public static Function> none() { + return ObjectBuilder.identity(); + } + + public CollectionHandleDefaults(Builder builder) { + this(builder.consistencyLevel); + } + + public static final class Builder implements ObjectBuilder { + private ConsistencyLevel consistencyLevel; + + /** Set default consistency level for this collection handle. */ + public Builder consistencyLevel(ConsistencyLevel consistencyLevel) { + this.consistencyLevel = consistencyLevel; + return this; + } + + @Override + public CollectionHandleDefaults build() { + return new CollectionHandleDefaults(this); + } + } + + public Endpoint endpoint(Endpoint ep, + Function, ObjectBuilder>> fn) { + return fn.apply(new EndpointBuilder<>(ep)).build(); + } + + public Rpc rpc( + Rpc rpc) { + return new ContextRpc<>(rpc); + } + + /** Which part of the request a parameter should be added to. */ + public static enum Location { + /** Query string. */ + QUERY, + /** + * Request body. {@code RequestT} must implement {@link WithDefaults} for the + * changes to be applied. + */ + BODY; + } + + public static interface WithDefaults> { + ConsistencyLevel consistencyLevel(); + + SelfT withConsistencyLevel(ConsistencyLevel consistencyLevel); + } + + private class ContextEndpoint extends EndpointBase + implements JsonEndpoint { + + private final Location consistencyLevelLoc; + private final Endpoint endpoint; + + ContextEndpoint(EndpointBuilder builder) { + super(builder.endpoint::method, + builder.endpoint::requestUrl, + builder.endpoint::queryParameters, + builder.endpoint::body); + this.consistencyLevelLoc = builder.consistencyLevelLoc; + this.endpoint = builder.endpoint; + } + + /** Return consistencyLevel of the enclosing CollectionHandleDefaults object. */ + private ConsistencyLevel consistencyLevel() { + return CollectionHandleDefaults.this.consistencyLevel; + } + + @Override + public Map queryParameters(RequestT request) { + // Copy the map, as it's most likely unmodifiable. + var query = new HashMap<>(super.queryParameters(request)); + if (consistencyLevel() != null && consistencyLevelLoc == Location.QUERY) { + query.putIfAbsent(CONSISTENCY_LEVEL, consistencyLevel()); + } + return query; + } + + @SuppressWarnings("unchecked") + @Override + public String body(RequestT request) { + if (request instanceof WithDefaults wd) { + if (wd.consistencyLevel() == null) { + wd = wd.withConsistencyLevel(consistencyLevel()); + } + // This cast is safe as long as `wd` returns its own type, + // which it does as per the interface contract. + request = (RequestT) wd; + } + return super.body(request); + } + + @Override + public ResponseT deserializeResponse(int statusCode, String responseBody) { + return EndpointBase.deserializeResponse(endpoint, statusCode, responseBody); + } + } + + /** + * EndpointBuilder configures how CollectionHandleDefautls + * are added to a REST request. + */ + public class EndpointBuilder implements ObjectBuilder> { + private final Endpoint endpoint; + + private Location consistencyLevelLoc; + + EndpointBuilder(Endpoint ep) { + this.endpoint = ep; + } + + /** Control which part of the request to add default consistency level to. */ + public EndpointBuilder consistencyLevel(Location loc) { + this.consistencyLevelLoc = loc; + return this; + } + + @Override + public Endpoint build() { + return new ContextEndpoint<>(this); + } + } + + private class ContextRpc + implements Rpc { + + private final Rpc rpc; + + ContextRpc(Rpc rpc) { + this.rpc = rpc; + } + + /** Return consistencyLevel of the enclosing CollectionHandleDefaults object. */ + private ConsistencyLevel consistencyLevel() { + return CollectionHandleDefaults.this.consistencyLevel; + } + + @SuppressWarnings("unchecked") + @Override + public RequestM marshal(RequestT request) { + var message = rpc.marshal(request); + if (message instanceof WeaviateProtoBatch.BatchObjectsRequest msg) { + var b = msg.toBuilder(); + if (!msg.hasConsistencyLevel() && consistencyLevel() != null) { + consistencyLevel().appendTo(b); + return (RequestM) b.build(); + } + } else if (message instanceof WeaviateProtoBatchDelete.BatchDeleteRequest msg) { + var b = msg.toBuilder(); + if (!msg.hasConsistencyLevel() && consistencyLevel() != null) { + consistencyLevel().appendTo(b); + return (RequestM) b.build(); + } + } else if (message instanceof WeaviateProtoSearchGet.SearchRequest msg) { + var b = msg.toBuilder(); + if (!msg.hasConsistencyLevel() && consistencyLevel() != null) { + consistencyLevel().appendTo(b); + return (RequestM) b.build(); + } + } + + return message; + } + + @Override + public ResponseT unmarshal(ReplyM reply) { + return rpc.unmarshal(reply); + } + + @Override + public BiFunction method() { + return rpc.method(); + } + + @Override + public BiFunction> methodAsync() { + return rpc.methodAsync(); + } + } +} diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/Vectorizer.java b/src/main/java/io/weaviate/client6/v1/api/collections/Vectorizer.java index b5b6c68bb..b0cc27d2e 100644 --- a/src/main/java/io/weaviate/client6/v1/api/collections/Vectorizer.java +++ b/src/main/java/io/weaviate/client6/v1/api/collections/Vectorizer.java @@ -3,7 +3,6 @@ import java.io.IOException; import java.util.EnumMap; import java.util.Map; -import java.util.function.Function; import com.google.gson.Gson; import com.google.gson.JsonObject; @@ -20,7 +19,6 @@ import io.weaviate.client6.v1.api.collections.vectorizers.NoneVectorizer; import io.weaviate.client6.v1.api.collections.vectorizers.Text2VecContextionaryVectorizer; import io.weaviate.client6.v1.api.collections.vectorizers.Text2VecWeaviateVectorizer; -import io.weaviate.client6.v1.internal.ObjectBuilder; import io.weaviate.client6.v1.internal.json.JsonEnum; public interface Vectorizer { diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/WeaviateCollectionsClient.java b/src/main/java/io/weaviate/client6/v1/api/collections/WeaviateCollectionsClient.java index d8c175692..f71d7e9e9 100644 --- a/src/main/java/io/weaviate/client6/v1/api/collections/WeaviateCollectionsClient.java +++ b/src/main/java/io/weaviate/client6/v1/api/collections/WeaviateCollectionsClient.java @@ -29,7 +29,24 @@ public WeaviateCollectionsClient(RestTransport restTransport, GrpcTransport grpc * properties. */ public CollectionHandle> use(String collectionName) { - return new CollectionHandle<>(restTransport, grpcTransport, CollectionDescriptor.ofMap(collectionName)); + return use(collectionName, CollectionHandleDefaults.none()); + } + + /** + * Obtain a handle to send requests to a particular collection. + * The returned object is thread-safe. + * + * @return a handle for a collection with {@code Map} + * properties. + */ + public CollectionHandle> use( + String collectionName, + Function> fn) { + return new CollectionHandle<>( + restTransport, + grpcTransport, + CollectionDescriptor.ofMap(collectionName), + CollectionHandleDefaults.of(fn)); } /** diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/WeaviateCollectionsClientAsync.java b/src/main/java/io/weaviate/client6/v1/api/collections/WeaviateCollectionsClientAsync.java index d357d56cc..4fc449cf1 100644 --- a/src/main/java/io/weaviate/client6/v1/api/collections/WeaviateCollectionsClientAsync.java +++ b/src/main/java/io/weaviate/client6/v1/api/collections/WeaviateCollectionsClientAsync.java @@ -22,8 +22,17 @@ public WeaviateCollectionsClientAsync(RestTransport restTransport, GrpcTransport } public CollectionHandleAsync> use(String collectionName) { - return new CollectionHandleAsync<>(restTransport, grpcTransport, - CollectionDescriptor.ofMap(collectionName)); + return use(collectionName, CollectionHandleDefaults.none()); + } + + public CollectionHandleAsync> use( + String collectionName, + Function> fn) { + return new CollectionHandleAsync<>( + restTransport, + grpcTransport, + CollectionDescriptor.ofMap(collectionName), + CollectionHandleDefaults.of(fn)); } public CompletableFuture create(String name) { diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/aggregate/AbstractAggregateClient.java b/src/main/java/io/weaviate/client6/v1/api/collections/aggregate/AbstractAggregateClient.java index 4258947bd..3b557a244 100644 --- a/src/main/java/io/weaviate/client6/v1/api/collections/aggregate/AbstractAggregateClient.java +++ b/src/main/java/io/weaviate/client6/v1/api/collections/aggregate/AbstractAggregateClient.java @@ -3,6 +3,7 @@ import java.util.List; import java.util.function.Function; +import io.weaviate.client6.v1.api.collections.CollectionHandleDefaults; import io.weaviate.client6.v1.api.collections.query.Hybrid; import io.weaviate.client6.v1.api.collections.query.NearAudio; import io.weaviate.client6.v1.api.collections.query.NearDepth; @@ -20,10 +21,15 @@ abstract class AbstractAggregateClient { protected final CollectionDescriptor collection; protected final GrpcTransport transport; + protected final CollectionHandleDefaults defaults; - AbstractAggregateClient(CollectionDescriptor collection, GrpcTransport transport) { + AbstractAggregateClient( + CollectionDescriptor collection, + GrpcTransport transport, + CollectionHandleDefaults defaults) { this.transport = transport; this.collection = collection; + this.defaults = defaults; } protected abstract ResponseT performRequest(Aggregation aggregation); diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/aggregate/WeaviateAggregateClient.java b/src/main/java/io/weaviate/client6/v1/api/collections/aggregate/WeaviateAggregateClient.java index 8f61720f4..aeb769e76 100644 --- a/src/main/java/io/weaviate/client6/v1/api/collections/aggregate/WeaviateAggregateClient.java +++ b/src/main/java/io/weaviate/client6/v1/api/collections/aggregate/WeaviateAggregateClient.java @@ -1,12 +1,16 @@ package io.weaviate.client6.v1.api.collections.aggregate; +import io.weaviate.client6.v1.api.collections.CollectionHandleDefaults; import io.weaviate.client6.v1.internal.grpc.GrpcTransport; import io.weaviate.client6.v1.internal.orm.CollectionDescriptor; public class WeaviateAggregateClient extends AbstractAggregateClient { - public WeaviateAggregateClient(CollectionDescriptor collection, GrpcTransport transport) { - super(collection, transport); + public WeaviateAggregateClient( + CollectionDescriptor collection, + GrpcTransport transport, + CollectionHandleDefaults defaults) { + super(collection, transport, defaults); } protected final AggregateResponse performRequest(Aggregation aggregation) { diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/aggregate/WeaviateAggregateClientAsync.java b/src/main/java/io/weaviate/client6/v1/api/collections/aggregate/WeaviateAggregateClientAsync.java index cdb138867..5a87cfdbc 100644 --- a/src/main/java/io/weaviate/client6/v1/api/collections/aggregate/WeaviateAggregateClientAsync.java +++ b/src/main/java/io/weaviate/client6/v1/api/collections/aggregate/WeaviateAggregateClientAsync.java @@ -2,14 +2,18 @@ import java.util.concurrent.CompletableFuture; +import io.weaviate.client6.v1.api.collections.CollectionHandleDefaults; import io.weaviate.client6.v1.internal.grpc.GrpcTransport; import io.weaviate.client6.v1.internal.orm.CollectionDescriptor; public class WeaviateAggregateClientAsync extends AbstractAggregateClient, CompletableFuture> { - public WeaviateAggregateClientAsync(CollectionDescriptor collection, GrpcTransport transport) { - super(collection, transport); + public WeaviateAggregateClientAsync( + CollectionDescriptor collection, + GrpcTransport transport, + CollectionHandleDefaults defaults) { + super(collection, transport, defaults); } protected final CompletableFuture performRequest(Aggregation aggregation) { diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/data/DeleteManyRequest.java b/src/main/java/io/weaviate/client6/v1/api/collections/data/DeleteManyRequest.java index fe4788481..5cbb64465 100644 --- a/src/main/java/io/weaviate/client6/v1/api/collections/data/DeleteManyRequest.java +++ b/src/main/java/io/weaviate/client6/v1/api/collections/data/DeleteManyRequest.java @@ -2,6 +2,7 @@ import java.util.function.Function; +import io.weaviate.client6.v1.api.collections.CollectionHandleDefaults; import io.weaviate.client6.v1.api.collections.query.Where; import io.weaviate.client6.v1.internal.ObjectBuilder; import io.weaviate.client6.v1.internal.grpc.ByteStringUtil; @@ -15,44 +16,46 @@ public record DeleteManyRequest(Where where, Boolean verbose, Boolean dryRun) { public static Rpc rpc( - CollectionDescriptor collectionDescriptor) { - return Rpc - .of( - request -> { - var message = WeaviateProtoBatchDelete.BatchDeleteRequest.newBuilder(); - message.setCollection(collectionDescriptor.name()); - - if (request.verbose != null) { - message.setVerbose(request.verbose); - } - if (request.dryRun != null) { - message.setDryRun(request.dryRun); - } - - var filters = WeaviateProtoBase.Filters.newBuilder(); - request.where.appendTo(filters); - message.setFilters(filters); - - return message.build(); - }, - reply -> { - var objects = reply.getObjectsList() - .stream() - .map(obj -> new DeleteManyResponse.DeletedObject( - ByteStringUtil.decodeUuid(obj.getUuid()).toString(), - obj.getSuccessful(), - obj.getError())) - .toList(); - - return new DeleteManyResponse( - reply.getTook(), - reply.getFailed(), - reply.getMatches(), - reply.getSuccessful(), - objects); - }, - () -> WeaviateBlockingStub::batchDelete, - () -> WeaviateFutureStub::batchDelete); + CollectionDescriptor collection, + CollectionHandleDefaults defaults) { + return defaults.rpc( + Rpc + .of( + request -> { + var message = WeaviateProtoBatchDelete.BatchDeleteRequest.newBuilder(); + message.setCollection(collection.name()); + + if (request.verbose != null) { + message.setVerbose(request.verbose); + } + if (request.dryRun != null) { + message.setDryRun(request.dryRun); + } + + var filters = WeaviateProtoBase.Filters.newBuilder(); + request.where.appendTo(filters); + message.setFilters(filters); + + return message.build(); + }, + reply -> { + var objects = reply.getObjectsList() + .stream() + .map(obj -> new DeleteManyResponse.DeletedObject( + ByteStringUtil.decodeUuid(obj.getUuid()).toString(), + obj.getSuccessful(), + obj.getError())) + .toList(); + + return new DeleteManyResponse( + reply.getTook(), + reply.getFailed(), + reply.getMatches(), + reply.getSuccessful(), + objects); + }, + () -> WeaviateBlockingStub::batchDelete, + () -> WeaviateFutureStub::batchDelete)); } public static DeleteManyRequest of(Where where) { diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/data/DeleteObjectRequest.java b/src/main/java/io/weaviate/client6/v1/api/collections/data/DeleteObjectRequest.java index 217a27682..ecb2adb2d 100644 --- a/src/main/java/io/weaviate/client6/v1/api/collections/data/DeleteObjectRequest.java +++ b/src/main/java/io/weaviate/client6/v1/api/collections/data/DeleteObjectRequest.java @@ -2,13 +2,23 @@ import java.util.Collections; +import io.weaviate.client6.v1.api.collections.CollectionHandleDefaults; +import io.weaviate.client6.v1.api.collections.CollectionHandleDefaults.Location; +import io.weaviate.client6.v1.internal.orm.CollectionDescriptor; import io.weaviate.client6.v1.internal.rest.Endpoint; import io.weaviate.client6.v1.internal.rest.SimpleEndpoint; -public record DeleteObjectRequest(String collectionName, String uuid) { +public record DeleteObjectRequest(String uuid) { - public static final Endpoint _ENDPOINT = SimpleEndpoint.sideEffect( - request -> "DELETE", - request -> "/objects/" + request.collectionName + "/" + request.uuid, - request -> Collections.emptyMap()); + public static final Endpoint endpoint( + CollectionDescriptor collection, + CollectionHandleDefaults defaults) { + return defaults.endpoint( + SimpleEndpoint.sideEffect( + request -> "DELETE", + request -> "/objects/" + collection.name() + "/" + request.uuid, + request -> Collections.emptyMap()), + add -> add + .consistencyLevel(Location.QUERY)); + } } diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/data/InsertManyRequest.java b/src/main/java/io/weaviate/client6/v1/api/collections/data/InsertManyRequest.java index 48c41ebec..d270fa336 100644 --- a/src/main/java/io/weaviate/client6/v1/api/collections/data/InsertManyRequest.java +++ b/src/main/java/io/weaviate/client6/v1/api/collections/data/InsertManyRequest.java @@ -6,6 +6,7 @@ import java.util.UUID; import java.util.stream.Collectors; +import io.weaviate.client6.v1.api.collections.CollectionHandleDefaults; import io.weaviate.client6.v1.api.collections.ObjectMetadata; import io.weaviate.client6.v1.api.collections.WeaviateObject; import io.weaviate.client6.v1.internal.grpc.ByteStringUtil; @@ -37,50 +38,52 @@ public static final InsertManyRequest of(T... properties) { public static Rpc, WeaviateProtoBatch.BatchObjectsRequest, InsertManyResponse, WeaviateProtoBatch.BatchObjectsReply> rpc( List> insertObjects, - CollectionDescriptor collectionsDescriptor) { - return Rpc.of( - request -> { - var message = WeaviateProtoBatch.BatchObjectsRequest.newBuilder(); - - var batch = request.objects.stream().map(obj -> { - var batchObject = WeaviateProtoBatch.BatchObject.newBuilder(); - buildObject(batchObject, obj, collectionsDescriptor); - return batchObject.build(); - }).toList(); - - message.addAllObjects(batch); - return message.build(); - }, - response -> { - var insertErrors = response.getErrorsList(); - - var responses = new ArrayList(insertObjects.size()); - var errors = new ArrayList(insertErrors.size()); - var uuids = new ArrayList(); - - var failed = insertErrors.stream() - .collect(Collectors.toMap(err -> err.getIndex(), err -> err.getError())); - - var iter = insertObjects.listIterator(); - while (iter.hasNext()) { - var idx = iter.nextIndex(); - var next = iter.next(); - var uuid = next.metadata() != null ? next.metadata().uuid() : null; - - if (failed.containsKey(idx)) { - var err = failed.get(idx); - errors.add(err); - responses.add(new InsertManyResponse.InsertObject(uuid, false, err)); - } else { - uuids.add(uuid); - responses.add(new InsertManyResponse.InsertObject(uuid, true, null)); - } - } + CollectionDescriptor collectionsDescriptor, + CollectionHandleDefaults defaults) { + return defaults.rpc( + Rpc.of( + request -> { + var message = WeaviateProtoBatch.BatchObjectsRequest.newBuilder(); + + var batch = request.objects.stream().map(obj -> { + var batchObject = WeaviateProtoBatch.BatchObject.newBuilder(); + buildObject(batchObject, obj, collectionsDescriptor); + return batchObject.build(); + }).toList(); + + message.addAllObjects(batch); + return message.build(); + }, + response -> { + var insertErrors = response.getErrorsList(); + + var responses = new ArrayList(insertObjects.size()); + var errors = new ArrayList(insertErrors.size()); + var uuids = new ArrayList(); + + var failed = insertErrors.stream() + .collect(Collectors.toMap(err -> err.getIndex(), err -> err.getError())); + + var iter = insertObjects.listIterator(); + while (iter.hasNext()) { + var idx = iter.nextIndex(); + var next = iter.next(); + var uuid = next.metadata() != null ? next.metadata().uuid() : null; + + if (failed.containsKey(idx)) { + var err = failed.get(idx); + errors.add(err); + responses.add(new InsertManyResponse.InsertObject(uuid, false, err)); + } else { + uuids.add(uuid); + responses.add(new InsertManyResponse.InsertObject(uuid, true, null)); + } + } - return new InsertManyResponse(response.getTook(), responses, uuids, errors); - }, - () -> WeaviateBlockingStub::batchObjects, - () -> WeaviateFutureStub::batchObjects); + return new InsertManyResponse(response.getTook(), responses, uuids, errors); + }, + () -> WeaviateBlockingStub::batchObjects, + () -> WeaviateFutureStub::batchObjects)); } public static void buildObject(WeaviateProtoBatch.BatchObject.Builder object, diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/data/InsertObjectRequest.java b/src/main/java/io/weaviate/client6/v1/api/collections/data/InsertObjectRequest.java index b1b460b11..97fecddc8 100644 --- a/src/main/java/io/weaviate/client6/v1/api/collections/data/InsertObjectRequest.java +++ b/src/main/java/io/weaviate/client6/v1/api/collections/data/InsertObjectRequest.java @@ -5,6 +5,8 @@ import com.google.gson.reflect.TypeToken; +import io.weaviate.client6.v1.api.collections.CollectionHandleDefaults; +import io.weaviate.client6.v1.api.collections.CollectionHandleDefaults.Location; import io.weaviate.client6.v1.api.collections.ObjectMetadata; import io.weaviate.client6.v1.api.collections.Vectors; import io.weaviate.client6.v1.api.collections.WeaviateObject; @@ -18,16 +20,20 @@ public record InsertObjectRequest(WeaviateObject Endpoint, WeaviateObject> endpoint( - CollectionDescriptor descriptor) { - return new SimpleEndpoint<>( - request -> "POST", - request -> "/objects/", - request -> Collections.emptyMap(), - request -> JSON.serialize(request.object, TypeToken.getParameterized( - WeaviateObject.class, descriptor.typeToken().getType(), Reference.class, ObjectMetadata.class)), - (statusCode, response) -> JSON.deserialize(response, - (TypeToken>) TypeToken.getParameterized( - WeaviateObject.class, descriptor.typeToken().getType(), Object.class, ObjectMetadata.class))); + CollectionDescriptor collection, + CollectionHandleDefaults defaults) { + return defaults.endpoint( + new SimpleEndpoint<>( + request -> "POST", + request -> "/objects/", + request -> Collections.emptyMap(), + request -> JSON.serialize(request.object, TypeToken.getParameterized( + WeaviateObject.class, collection.typeToken().getType(), Reference.class, ObjectMetadata.class)), + (statusCode, response) -> JSON.deserialize(response, + (TypeToken>) TypeToken.getParameterized( + WeaviateObject.class, collection.typeToken().getType(), Object.class, ObjectMetadata.class))), + add -> add + .consistencyLevel(Location.QUERY)); } public static InsertObjectRequest of(String collectionName, T properties) { @@ -72,4 +78,5 @@ public InsertObjectRequest build() { return new InsertObjectRequest<>(this); } } + } diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/data/ReferenceAddManyRequest.java b/src/main/java/io/weaviate/client6/v1/api/collections/data/ReferenceAddManyRequest.java index 0530d23e9..72b673bba 100644 --- a/src/main/java/io/weaviate/client6/v1/api/collections/data/ReferenceAddManyRequest.java +++ b/src/main/java/io/weaviate/client6/v1/api/collections/data/ReferenceAddManyRequest.java @@ -4,6 +4,8 @@ import java.util.Collections; import java.util.List; +import io.weaviate.client6.v1.api.collections.CollectionHandleDefaults; +import io.weaviate.client6.v1.api.collections.CollectionHandleDefaults.Location; import io.weaviate.client6.v1.internal.json.JSON; import io.weaviate.client6.v1.internal.rest.Endpoint; import io.weaviate.client6.v1.internal.rest.SimpleEndpoint; @@ -11,24 +13,28 @@ public record ReferenceAddManyRequest(List references) { public static final Endpoint endpoint( - List references) { - return new SimpleEndpoint<>( - request -> "POST", - request -> "/batch/references", - request -> Collections.emptyMap(), - request -> JSON.serialize(request.references), - (statusCode, response) -> { - var result = JSON.deserialize(response, ReferenceAddManyResponse.class); - var errors = new ArrayList(); + List references, + CollectionHandleDefaults defaults) { + return defaults.endpoint( + new SimpleEndpoint<>( + request -> "POST", + request -> "/batch/references", + request -> Collections.emptyMap(), + request -> JSON.serialize(request.references), + (statusCode, response) -> { + var result = JSON.deserialize(response, ReferenceAddManyResponse.class); + var errors = new ArrayList(); - for (var err : result.errors()) { - errors.add(new ReferenceAddManyResponse.BatchError( - err.message(), - references.get(err.referenceIndex()), - err.referenceIndex())); - } - return new ReferenceAddManyResponse(errors); - }); + for (var err : result.errors()) { + errors.add(new ReferenceAddManyResponse.BatchError( + err.message(), + references.get(err.referenceIndex()), + err.referenceIndex())); + } + return new ReferenceAddManyResponse(errors); + }), + add -> add + .consistencyLevel(Location.QUERY)); } } diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/data/ReferenceAddRequest.java b/src/main/java/io/weaviate/client6/v1/api/collections/data/ReferenceAddRequest.java index 5da29e0fe..5fea35c24 100644 --- a/src/main/java/io/weaviate/client6/v1/api/collections/data/ReferenceAddRequest.java +++ b/src/main/java/io/weaviate/client6/v1/api/collections/data/ReferenceAddRequest.java @@ -2,6 +2,8 @@ import java.util.Collections; +import io.weaviate.client6.v1.api.collections.CollectionHandleDefaults; +import io.weaviate.client6.v1.api.collections.CollectionHandleDefaults.Location; import io.weaviate.client6.v1.internal.json.JSON; import io.weaviate.client6.v1.internal.orm.CollectionDescriptor; import io.weaviate.client6.v1.internal.rest.Endpoint; @@ -10,11 +12,16 @@ public record ReferenceAddRequest(String fromUuid, String fromProperty, Reference reference) { public static final Endpoint endpoint( - CollectionDescriptor descriptor) { - return SimpleEndpoint.sideEffect( - request -> "POST", - request -> "/objects/" + descriptor.name() + "/" + request.fromUuid + "/references/" + request.fromProperty, - request -> Collections.emptyMap(), - request -> JSON.serialize(request.reference)); + CollectionDescriptor descriptor, + CollectionHandleDefaults defautls) { + return defautls.endpoint( + SimpleEndpoint.sideEffect( + request -> "POST", + request -> "/objects/" + descriptor.name() + "/" + request.fromUuid + "/references/" + request.fromProperty, + request -> Collections.emptyMap(), + request -> JSON.serialize(request.reference)), + add -> add + .consistencyLevel(Location.QUERY)); + } } diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/data/ReferenceDeleteRequest.java b/src/main/java/io/weaviate/client6/v1/api/collections/data/ReferenceDeleteRequest.java index f7f037e23..5038e0812 100644 --- a/src/main/java/io/weaviate/client6/v1/api/collections/data/ReferenceDeleteRequest.java +++ b/src/main/java/io/weaviate/client6/v1/api/collections/data/ReferenceDeleteRequest.java @@ -2,6 +2,8 @@ import java.util.Collections; +import io.weaviate.client6.v1.api.collections.CollectionHandleDefaults; +import io.weaviate.client6.v1.api.collections.CollectionHandleDefaults.Location; import io.weaviate.client6.v1.internal.json.JSON; import io.weaviate.client6.v1.internal.orm.CollectionDescriptor; import io.weaviate.client6.v1.internal.rest.Endpoint; @@ -10,11 +12,15 @@ public record ReferenceDeleteRequest(String fromUuid, String fromProperty, Reference reference) { public static final Endpoint endpoint( - CollectionDescriptor descriptor) { - return SimpleEndpoint.sideEffect( - request -> "DELETE", - request -> "/objects/" + descriptor.name() + "/" + request.fromUuid + "/references/" + request.fromProperty, - request -> Collections.emptyMap(), - request -> JSON.serialize(request.reference)); + CollectionDescriptor descriptor, + CollectionHandleDefaults defaults) { + return defaults.endpoint( + SimpleEndpoint.sideEffect( + request -> "DELETE", + request -> "/objects/" + descriptor.name() + "/" + request.fromUuid + "/references/" + request.fromProperty, + request -> Collections.emptyMap(), + request -> JSON.serialize(request.reference)), + add -> add + .consistencyLevel(Location.QUERY)); } } diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/data/ReferenceReplaceRequest.java b/src/main/java/io/weaviate/client6/v1/api/collections/data/ReferenceReplaceRequest.java index 746fe6966..8d8aaf1e2 100644 --- a/src/main/java/io/weaviate/client6/v1/api/collections/data/ReferenceReplaceRequest.java +++ b/src/main/java/io/weaviate/client6/v1/api/collections/data/ReferenceReplaceRequest.java @@ -3,6 +3,8 @@ import java.util.Collections; import java.util.List; +import io.weaviate.client6.v1.api.collections.CollectionHandleDefaults; +import io.weaviate.client6.v1.api.collections.CollectionHandleDefaults.Location; import io.weaviate.client6.v1.internal.json.JSON; import io.weaviate.client6.v1.internal.orm.CollectionDescriptor; import io.weaviate.client6.v1.internal.rest.Endpoint; @@ -11,11 +13,15 @@ public record ReferenceReplaceRequest(String fromUuid, String fromProperty, Reference reference) { public static final Endpoint endpoint( - CollectionDescriptor descriptor) { - return SimpleEndpoint.sideEffect( - request -> "PUT", - request -> "/objects/" + descriptor.name() + "/" + request.fromUuid + "/references/" + request.fromProperty, - request -> Collections.emptyMap(), - request -> JSON.serialize(List.of(request.reference))); + CollectionDescriptor descriptor, + CollectionHandleDefaults defaults) { + return defaults.endpoint( + SimpleEndpoint.sideEffect( + request -> "PUT", + request -> "/objects/" + descriptor.name() + "/" + request.fromUuid + "/references/" + request.fromProperty, + request -> Collections.emptyMap(), + request -> JSON.serialize(List.of(request.reference))), + add -> add + .consistencyLevel(Location.QUERY)); } } diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/data/ReplaceObjectRequest.java b/src/main/java/io/weaviate/client6/v1/api/collections/data/ReplaceObjectRequest.java index 09704839b..1ecb5455c 100644 --- a/src/main/java/io/weaviate/client6/v1/api/collections/data/ReplaceObjectRequest.java +++ b/src/main/java/io/weaviate/client6/v1/api/collections/data/ReplaceObjectRequest.java @@ -5,6 +5,8 @@ import com.google.gson.reflect.TypeToken; +import io.weaviate.client6.v1.api.collections.CollectionHandleDefaults; +import io.weaviate.client6.v1.api.collections.CollectionHandleDefaults.Location; import io.weaviate.client6.v1.api.collections.ObjectMetadata; import io.weaviate.client6.v1.api.collections.Vectors; import io.weaviate.client6.v1.api.collections.WeaviateObject; @@ -16,13 +18,17 @@ public record ReplaceObjectRequest(WeaviateObject object) { - static final Endpoint, Void> endpoint(CollectionDescriptor collectionDescriptor) { - return SimpleEndpoint.sideEffect( - request -> "PUT", - request -> "/objects/" + collectionDescriptor.name() + "/" + request.object.metadata().uuid(), - request -> Collections.emptyMap(), - request -> JSON.serialize(request.object, TypeToken.getParameterized( - WeaviateObject.class, collectionDescriptor.typeToken().getType(), Reference.class, ObjectMetadata.class))); + static final Endpoint, Void> endpoint(CollectionDescriptor collection, + CollectionHandleDefaults defaults) { + return defaults.endpoint( + SimpleEndpoint.sideEffect( + request -> "PUT", + request -> "/objects/" + collection.name() + "/" + request.object.metadata().uuid(), + request -> Collections.emptyMap(), + request -> JSON.serialize(request.object, TypeToken.getParameterized( + WeaviateObject.class, collection.typeToken().getType(), Reference.class, ObjectMetadata.class))), + add -> add + .consistencyLevel(Location.QUERY)); } public static ReplaceObjectRequest of(String collectionName, String uuid, diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/data/UpdateObjectRequest.java b/src/main/java/io/weaviate/client6/v1/api/collections/data/UpdateObjectRequest.java index f1f64022d..28423c752 100644 --- a/src/main/java/io/weaviate/client6/v1/api/collections/data/UpdateObjectRequest.java +++ b/src/main/java/io/weaviate/client6/v1/api/collections/data/UpdateObjectRequest.java @@ -5,6 +5,8 @@ import com.google.gson.reflect.TypeToken; +import io.weaviate.client6.v1.api.collections.CollectionHandleDefaults; +import io.weaviate.client6.v1.api.collections.CollectionHandleDefaults.Location; import io.weaviate.client6.v1.api.collections.ObjectMetadata; import io.weaviate.client6.v1.api.collections.Vectors; import io.weaviate.client6.v1.api.collections.WeaviateObject; @@ -16,13 +18,17 @@ public record UpdateObjectRequest(WeaviateObject object) { - static final Endpoint, Void> endpoint(CollectionDescriptor collectionDescriptor) { - return SimpleEndpoint.sideEffect( - request -> "PATCH", - request -> "/objects/" + collectionDescriptor.name() + "/" + request.object.metadata().uuid(), - request -> Collections.emptyMap(), - request -> JSON.serialize(request.object, TypeToken.getParameterized( - WeaviateObject.class, collectionDescriptor.typeToken().getType(), Reference.class, ObjectMetadata.class))); + static final Endpoint, Void> endpoint(CollectionDescriptor collection, + CollectionHandleDefaults defaults) { + return defaults.endpoint( + SimpleEndpoint.sideEffect( + request -> "PATCH", + request -> "/objects/" + collection.name() + "/" + request.object.metadata().uuid(), + request -> Collections.emptyMap(), + request -> JSON.serialize(request.object, TypeToken.getParameterized( + WeaviateObject.class, collection.typeToken().getType(), Reference.class, ObjectMetadata.class))), + add -> add + .consistencyLevel(Location.QUERY)); } public static UpdateObjectRequest of(String collectionName, String uuid, diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/data/WeaviateDataClient.java b/src/main/java/io/weaviate/client6/v1/api/collections/data/WeaviateDataClient.java index cb0771e0f..6cda6883c 100644 --- a/src/main/java/io/weaviate/client6/v1/api/collections/data/WeaviateDataClient.java +++ b/src/main/java/io/weaviate/client6/v1/api/collections/data/WeaviateDataClient.java @@ -5,6 +5,7 @@ import java.util.List; import java.util.function.Function; +import io.weaviate.client6.v1.api.collections.CollectionHandleDefaults; import io.weaviate.client6.v1.api.collections.ObjectMetadata; import io.weaviate.client6.v1.api.collections.WeaviateObject; import io.weaviate.client6.v1.api.collections.query.WeaviateQueryClient; @@ -15,96 +16,112 @@ import io.weaviate.client6.v1.internal.orm.CollectionDescriptor; import io.weaviate.client6.v1.internal.rest.RestTransport; -public class WeaviateDataClient { +public class WeaviateDataClient { private final RestTransport restTransport; private final GrpcTransport grpcTransport; - private final CollectionDescriptor collectionDescriptor; + private final CollectionDescriptor collection; - private final WeaviateQueryClient query; + private final WeaviateQueryClient query; + private final CollectionHandleDefaults defaults; - public WeaviateDataClient(CollectionDescriptor collectionDescriptor, RestTransport restTransport, - GrpcTransport grpcTransport) { + public WeaviateDataClient( + CollectionDescriptor collectionDescriptor, + RestTransport restTransport, + GrpcTransport grpcTransport, + CollectionHandleDefaults defaults) { this.restTransport = restTransport; this.grpcTransport = grpcTransport; - this.collectionDescriptor = collectionDescriptor; - this.query = new WeaviateQueryClient<>(collectionDescriptor, grpcTransport); + this.collection = collectionDescriptor; + this.query = new WeaviateQueryClient<>(collectionDescriptor, grpcTransport, defaults); + this.defaults = defaults; + } + /** Copy constructor that updates the {@link #query} to use new defaults. */ + public WeaviateDataClient(WeaviateDataClient c, CollectionHandleDefaults defaults) { + this.restTransport = c.restTransport; + this.grpcTransport = c.grpcTransport; + this.collection = c.collection; + this.query = new WeaviateQueryClient<>(collection, grpcTransport, defaults); + this.defaults = defaults; } - public WeaviateObject insert(T properties) throws IOException { - return insert(InsertObjectRequest.of(collectionDescriptor.name(), properties)); + public WeaviateObject insert(PropertiesT properties) throws IOException { + return insert(InsertObjectRequest.of(collection.name(), properties)); } - public WeaviateObject insert(T properties, - Function, ObjectBuilder>> fn) + public WeaviateObject insert(PropertiesT properties, + Function, ObjectBuilder>> fn) throws IOException { - return insert(InsertObjectRequest.of(collectionDescriptor.name(), properties, fn)); + return insert(InsertObjectRequest.of(collection.name(), properties, fn)); } @SafeVarargs - public final InsertManyResponse insertMany(T... objects) { + public final InsertManyResponse insertMany(PropertiesT... objects) { return insertMany(InsertManyRequest.of(objects)); } - public InsertManyResponse insertMany(List> objects) { + public InsertManyResponse insertMany(List> objects) { return insertMany(new InsertManyRequest<>(objects)); } - public InsertManyResponse insertMany(InsertManyRequest request) { - return this.grpcTransport.performRequest(request, InsertManyRequest.rpc(request.objects(), collectionDescriptor)); + public InsertManyResponse insertMany(InsertManyRequest request) { + return this.grpcTransport.performRequest(request, + InsertManyRequest.rpc(request.objects(), collection, defaults)); } - public WeaviateObject insert(InsertObjectRequest request) throws IOException { - return this.restTransport.performRequest(request, InsertObjectRequest.endpoint(collectionDescriptor)); + public WeaviateObject insert(InsertObjectRequest request) + throws IOException { + return this.restTransport.performRequest(request, InsertObjectRequest.endpoint(collection, defaults)); } - public boolean exists(String uuid) throws IOException { + public boolean exists(String uuid) { return this.query.byId(uuid).isPresent(); } - public void update(String uuid, Function, ObjectBuilder>> fn) + public void update(String uuid, + Function, ObjectBuilder>> fn) throws IOException { - this.restTransport.performRequest(UpdateObjectRequest.of(collectionDescriptor.name(), uuid, fn), - UpdateObjectRequest.endpoint(collectionDescriptor)); + this.restTransport.performRequest(UpdateObjectRequest.of(collection.name(), uuid, fn), + UpdateObjectRequest.endpoint(collection, defaults)); } - public void replace(String uuid, Function, ObjectBuilder>> fn) + public void replace(String uuid, + Function, ObjectBuilder>> fn) throws IOException { - this.restTransport.performRequest(ReplaceObjectRequest.of(collectionDescriptor.name(), uuid, fn), - ReplaceObjectRequest.endpoint(collectionDescriptor)); + this.restTransport.performRequest(ReplaceObjectRequest.of(collection.name(), uuid, fn), + ReplaceObjectRequest.endpoint(collection, defaults)); } public void delete(String uuid) throws IOException { - this.restTransport.performRequest(new DeleteObjectRequest(collectionDescriptor.name(), uuid), - DeleteObjectRequest._ENDPOINT); + this.restTransport.performRequest(new DeleteObjectRequest(uuid), + DeleteObjectRequest.endpoint(collection, defaults)); } - public DeleteManyResponse deleteMany(String... uuids) throws IOException { + public DeleteManyResponse deleteMany(String... uuids) { var either = Arrays.stream(uuids) .map(uuid -> (WhereOperand) Where.uuid().eq(uuid)) .toList(); return deleteMany(DeleteManyRequest.of(Where.or(either))); } - public DeleteManyResponse deleteMany(Where where) throws IOException { + public DeleteManyResponse deleteMany(Where where) { return deleteMany(DeleteManyRequest.of(where)); } public DeleteManyResponse deleteMany(Where where, - Function> fn) - throws IOException { + Function> fn) { return deleteMany(DeleteManyRequest.of(where, fn)); } - public DeleteManyResponse deleteMany(DeleteManyRequest request) throws IOException { - return this.grpcTransport.performRequest(request, DeleteManyRequest.rpc(collectionDescriptor)); + public DeleteManyResponse deleteMany(DeleteManyRequest request) { + return this.grpcTransport.performRequest(request, DeleteManyRequest.rpc(collection, defaults)); } public void referenceAdd(String fromUuid, String fromProperty, Reference reference) throws IOException { for (var uuid : reference.uuids()) { var singleRef = new Reference(reference.collection(), uuid); this.restTransport.performRequest(new ReferenceAddRequest(fromUuid, fromProperty, singleRef), - ReferenceAddRequest.endpoint(collectionDescriptor)); + ReferenceAddRequest.endpoint(collection, defaults)); } } @@ -114,14 +131,14 @@ public ReferenceAddManyResponse referenceAddMany(BatchReference... references) t public ReferenceAddManyResponse referenceAddMany(List references) throws IOException { return this.restTransport.performRequest(new ReferenceAddManyRequest(references), - ReferenceAddManyRequest.endpoint(references)); + ReferenceAddManyRequest.endpoint(references, defaults)); } public void referenceDelete(String fromUuid, String fromProperty, Reference reference) throws IOException { for (var uuid : reference.uuids()) { var singleRef = new Reference(reference.collection(), uuid); this.restTransport.performRequest(new ReferenceDeleteRequest(fromUuid, fromProperty, singleRef), - ReferenceDeleteRequest.endpoint(collectionDescriptor)); + ReferenceDeleteRequest.endpoint(collection, defaults)); } } @@ -129,7 +146,7 @@ public void referenceReplace(String fromUuid, String fromProperty, Reference ref for (var uuid : reference.uuids()) { var singleRef = new Reference(reference.collection(), uuid); this.restTransport.performRequest(new ReferenceReplaceRequest(fromUuid, fromProperty, singleRef), - ReferenceReplaceRequest.endpoint(collectionDescriptor)); + ReferenceReplaceRequest.endpoint(collection, defaults)); } } } diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/data/WeaviateDataClientAsync.java b/src/main/java/io/weaviate/client6/v1/api/collections/data/WeaviateDataClientAsync.java index 506020b4b..f85696a5f 100644 --- a/src/main/java/io/weaviate/client6/v1/api/collections/data/WeaviateDataClientAsync.java +++ b/src/main/java/io/weaviate/client6/v1/api/collections/data/WeaviateDataClientAsync.java @@ -1,6 +1,5 @@ package io.weaviate.client6.v1.api.collections.data; -import java.io.IOException; import java.util.Arrays; import java.util.Collection; import java.util.List; @@ -8,6 +7,7 @@ import java.util.concurrent.CompletableFuture; import java.util.function.Function; +import io.weaviate.client6.v1.api.collections.CollectionHandleDefaults; import io.weaviate.client6.v1.api.collections.ObjectMetadata; import io.weaviate.client6.v1.api.collections.WeaviateObject; import io.weaviate.client6.v1.api.collections.query.WeaviateQueryClientAsync; @@ -18,48 +18,62 @@ import io.weaviate.client6.v1.internal.orm.CollectionDescriptor; import io.weaviate.client6.v1.internal.rest.RestTransport; -public class WeaviateDataClientAsync { +public class WeaviateDataClientAsync { private final RestTransport restTransport; private final GrpcTransport grpcTransport; - private final CollectionDescriptor collectionDescriptor; + private final CollectionDescriptor collection; - private final WeaviateQueryClientAsync query; + private final WeaviateQueryClientAsync query; + private final CollectionHandleDefaults defaults; - public WeaviateDataClientAsync(CollectionDescriptor collectionDescriptor, RestTransport restTransport, - GrpcTransport grpcTransport) { + public WeaviateDataClientAsync( + CollectionDescriptor collectionDescriptor, + RestTransport restTransport, + GrpcTransport grpcTransport, + CollectionHandleDefaults defaults) { this.restTransport = restTransport; this.grpcTransport = grpcTransport; - this.collectionDescriptor = collectionDescriptor; - this.query = new WeaviateQueryClientAsync<>(collectionDescriptor, grpcTransport); + this.collection = collectionDescriptor; + this.query = new WeaviateQueryClientAsync<>(collectionDescriptor, grpcTransport, defaults); + this.defaults = defaults; } - public CompletableFuture> insert(T properties) throws IOException { - return insert(InsertObjectRequest.of(collectionDescriptor.name(), properties)); + /** Copy constructor that updates the {@link #query} to use new defaults. */ + public WeaviateDataClientAsync(WeaviateDataClientAsync c, CollectionHandleDefaults defaults) { + this.restTransport = c.restTransport; + this.grpcTransport = c.grpcTransport; + this.collection = c.collection; + this.query = new WeaviateQueryClientAsync<>(collection, grpcTransport, defaults); + this.defaults = defaults; } - public CompletableFuture> insert(T properties, - Function, ObjectBuilder>> fn) - throws IOException { - return insert(InsertObjectRequest.of(collectionDescriptor.name(), properties, fn)); + public CompletableFuture> insert(PropertiesT properties) { + return insert(InsertObjectRequest.of(collection.name(), properties)); } - public CompletableFuture> insert(InsertObjectRequest request) - throws IOException { - return this.restTransport.performRequestAsync(request, InsertObjectRequest.endpoint(collectionDescriptor)); + public CompletableFuture> insert(PropertiesT properties, + Function, ObjectBuilder>> fn) { + return insert(InsertObjectRequest.of(collection.name(), properties, fn)); + } + + public CompletableFuture> insert( + InsertObjectRequest request) { + return this.restTransport.performRequestAsync(request, InsertObjectRequest.endpoint(collection, defaults)); } @SafeVarargs - public final CompletableFuture insertMany(T... objects) { + public final CompletableFuture insertMany(PropertiesT... objects) { return insertMany(InsertManyRequest.of(objects)); } - public CompletableFuture insertMany(List> objects) { + public CompletableFuture insertMany( + List> objects) { return insertMany(new InsertManyRequest<>(objects)); } - public CompletableFuture insertMany(InsertManyRequest request) { + public CompletableFuture insertMany(InsertManyRequest request) { return this.grpcTransport.performRequestAsync(request, - InsertManyRequest.rpc(request.objects(), collectionDescriptor)); + InsertManyRequest.rpc(request.objects(), collection, defaults)); } public CompletableFuture exists(String uuid) { @@ -67,68 +81,64 @@ public CompletableFuture exists(String uuid) { } public CompletableFuture update(String uuid, - Function, ObjectBuilder>> fn) - throws IOException { - return this.restTransport.performRequestAsync(UpdateObjectRequest.of(collectionDescriptor.name(), uuid, fn), - UpdateObjectRequest.endpoint(collectionDescriptor)); + Function, ObjectBuilder>> fn) { + return this.restTransport.performRequestAsync(UpdateObjectRequest.of(collection.name(), uuid, fn), + UpdateObjectRequest.endpoint(collection, defaults)); } public CompletableFuture replace(String uuid, - Function, ObjectBuilder>> fn) - throws IOException { - return this.restTransport.performRequestAsync(ReplaceObjectRequest.of(collectionDescriptor.name(), uuid, fn), - ReplaceObjectRequest.endpoint(collectionDescriptor)); + Function, ObjectBuilder>> fn) { + return this.restTransport.performRequestAsync(ReplaceObjectRequest.of(collection.name(), uuid, fn), + ReplaceObjectRequest.endpoint(collection, defaults)); } public CompletableFuture delete(String uuid) { - return this.restTransport.performRequestAsync(new DeleteObjectRequest(collectionDescriptor.name(), uuid), - DeleteObjectRequest._ENDPOINT); + return this.restTransport.performRequestAsync(new DeleteObjectRequest(uuid), + DeleteObjectRequest.endpoint(collection, defaults)); } - public CompletableFuture deleteMany(String... uuids) throws IOException { + public CompletableFuture deleteMany(String... uuids) { var either = Arrays.stream(uuids) .map(uuid -> (WhereOperand) Where.uuid().eq(uuid)) .toList(); return deleteMany(DeleteManyRequest.of(Where.or(either))); } - public CompletableFuture deleteMany(Where where) throws IOException { + public CompletableFuture deleteMany(Where where) { return deleteMany(DeleteManyRequest.of(where)); } public CompletableFuture deleteMany(Where where, - Function> fn) - throws IOException { + Function> fn) { return deleteMany(DeleteManyRequest.of(where, fn)); } - public CompletableFuture deleteMany(DeleteManyRequest request) throws IOException { - return this.grpcTransport.performRequestAsync(request, DeleteManyRequest.rpc(collectionDescriptor)); + public CompletableFuture deleteMany(DeleteManyRequest request) { + return this.grpcTransport.performRequestAsync(request, DeleteManyRequest.rpc(collection, defaults)); } public CompletableFuture referenceAdd(String fromUuid, String fromProperty, Reference reference) { return forEachAsync(reference.uuids(), uuid -> { var singleRef = new Reference(reference.collection(), (String) uuid); return this.restTransport.performRequestAsync(new ReferenceAddRequest(fromUuid, fromProperty, singleRef), - ReferenceAddRequest.endpoint(collectionDescriptor)); + ReferenceAddRequest.endpoint(collection, defaults)); }); } - public CompletableFuture referenceAddMany(BatchReference... references) throws IOException { + public CompletableFuture referenceAddMany(BatchReference... references) { return referenceAddMany(Arrays.asList(references)); } - public CompletableFuture referenceAddMany(List references) - throws IOException { + public CompletableFuture referenceAddMany(List references) { return this.restTransport.performRequestAsync(new ReferenceAddManyRequest(references), - ReferenceAddManyRequest.endpoint(references)); + ReferenceAddManyRequest.endpoint(references, defaults)); } public CompletableFuture referenceDelete(String fromUuid, String fromProperty, Reference reference) { return forEachAsync(reference.uuids(), uuid -> { var singleRef = new Reference(reference.collection(), (String) uuid); return this.restTransport.performRequestAsync(new ReferenceDeleteRequest(fromUuid, fromProperty, singleRef), - ReferenceDeleteRequest.endpoint(collectionDescriptor)); + ReferenceDeleteRequest.endpoint(collection, defaults)); }); } @@ -136,7 +146,7 @@ public CompletableFuture referenceReplace(String fromUuid, String fromProp return forEachAsync(reference.uuids(), uuid -> { var singleRef = new Reference(reference.collection(), (String) uuid); return this.restTransport.performRequestAsync(new ReferenceReplaceRequest(fromUuid, fromProperty, singleRef), - ReferenceReplaceRequest.endpoint(collectionDescriptor)); + ReferenceReplaceRequest.endpoint(collection, defaults)); }); } diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/query/AbstractQueryClient.java b/src/main/java/io/weaviate/client6/v1/api/collections/query/AbstractQueryClient.java index 0db66bd6b..e24fbadcb 100644 --- a/src/main/java/io/weaviate/client6/v1/api/collections/query/AbstractQueryClient.java +++ b/src/main/java/io/weaviate/client6/v1/api/collections/query/AbstractQueryClient.java @@ -4,6 +4,8 @@ import java.util.Optional; import java.util.function.Function; +import io.weaviate.client6.v1.api.collections.CollectionHandleDefaults; +import io.weaviate.client6.v1.api.collections.WeaviateObject; import io.weaviate.client6.v1.internal.ObjectBuilder; import io.weaviate.client6.v1.internal.grpc.GrpcTransport; import io.weaviate.client6.v1.internal.orm.CollectionDescriptor; @@ -11,10 +13,19 @@ abstract class AbstractQueryClient { protected final CollectionDescriptor collection; protected final GrpcTransport grpcTransport; + protected final CollectionHandleDefaults defaults; - AbstractQueryClient(CollectionDescriptor collection, GrpcTransport grpcTransport) { + AbstractQueryClient(CollectionDescriptor collection, GrpcTransport grpcTransport, + CollectionHandleDefaults defaults) { this.collection = collection; this.grpcTransport = grpcTransport; + this.defaults = defaults; + } + + /** Copy constructor that sets new defaults. */ + AbstractQueryClient(AbstractQueryClient qc, + CollectionHandleDefaults defaults) { + this(qc.collection, qc.grpcTransport, defaults); } protected abstract SingleT byId(ById byId); @@ -30,11 +41,22 @@ public SingleT byId(String uuid) { } public SingleT byId(String uuid, Function> fn) { + // Collection handle defaults (consistencyLevel / tenant) are irrelevant for + // by-ID lookup. Do not `applyDefaults` to `fn`. return byId(ById.of(uuid, fn)); } - protected final Optional optionalFirst(List objects) { - return objects.isEmpty() ? Optional.empty() : Optional.ofNullable(objects.get(0)); + /** + * Retrieve the first result from query response if any. + * + * @param response Query response. + * @return An object from the list or empty {@link Optional}. + */ + protected final Optional> optionalFirst(QueryResponse response) { + return response == null || response.objects().isEmpty() + ? Optional.empty() + : Optional.ofNullable(response.objects().get(0)); + } // Object queries ----------------------------------------------------------- diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/query/ConsistencyLevel.java b/src/main/java/io/weaviate/client6/v1/api/collections/query/ConsistencyLevel.java index 27cd30221..5ed88258f 100644 --- a/src/main/java/io/weaviate/client6/v1/api/collections/query/ConsistencyLevel.java +++ b/src/main/java/io/weaviate/client6/v1/api/collections/query/ConsistencyLevel.java @@ -1,20 +1,37 @@ package io.weaviate.client6.v1.api.collections.query; import io.weaviate.client6.v1.internal.grpc.protocol.WeaviateProtoBase; +import io.weaviate.client6.v1.internal.grpc.protocol.WeaviateProtoBatch; +import io.weaviate.client6.v1.internal.grpc.protocol.WeaviateProtoBatchDelete; import io.weaviate.client6.v1.internal.grpc.protocol.WeaviateProtoSearchGet; public enum ConsistencyLevel { - ONE(WeaviateProtoBase.ConsistencyLevel.CONSISTENCY_LEVEL_ONE), - QUORUM(WeaviateProtoBase.ConsistencyLevel.CONSISTENCY_LEVEL_ONE), - ALL(WeaviateProtoBase.ConsistencyLevel.CONSISTENCY_LEVEL_ONE); + ONE(WeaviateProtoBase.ConsistencyLevel.CONSISTENCY_LEVEL_ONE, "ONE"), + QUORUM(WeaviateProtoBase.ConsistencyLevel.CONSISTENCY_LEVEL_ONE, "QUORUM"), + ALL(WeaviateProtoBase.ConsistencyLevel.CONSISTENCY_LEVEL_ONE, "ALL"); private final WeaviateProtoBase.ConsistencyLevel consistencyLevel; + private final String queryParameter; - ConsistencyLevel(WeaviateProtoBase.ConsistencyLevel consistencyLevel) { + ConsistencyLevel(WeaviateProtoBase.ConsistencyLevel consistencyLevel, String queryParameter) { this.consistencyLevel = consistencyLevel; + this.queryParameter = queryParameter; } - final void appendTo(WeaviateProtoSearchGet.SearchRequest.Builder req) { + public final void appendTo(WeaviateProtoSearchGet.SearchRequest.Builder req) { req.setConsistencyLevel(consistencyLevel); } + + public final void appendTo(WeaviateProtoBatchDelete.BatchDeleteRequest.Builder req) { + req.setConsistencyLevel(consistencyLevel); + } + + public final void appendTo(WeaviateProtoBatch.BatchObjectsRequest.Builder req) { + req.setConsistencyLevel(consistencyLevel); + } + + @Override + public String toString() { + return queryParameter; + } } diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/query/QueryOperator.java b/src/main/java/io/weaviate/client6/v1/api/collections/query/QueryOperator.java index d7a3e4bb2..a3844b4eb 100644 --- a/src/main/java/io/weaviate/client6/v1/api/collections/query/QueryOperator.java +++ b/src/main/java/io/weaviate/client6/v1/api/collections/query/QueryOperator.java @@ -3,5 +3,6 @@ import io.weaviate.client6.v1.internal.grpc.protocol.WeaviateProtoSearchGet; interface QueryOperator { + /** Append QueryOperator to the request message. */ void appendTo(WeaviateProtoSearchGet.SearchRequest.Builder req); } diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/query/QueryRequest.java b/src/main/java/io/weaviate/client6/v1/api/collections/query/QueryRequest.java index 22ebbfa7c..9d1911286 100644 --- a/src/main/java/io/weaviate/client6/v1/api/collections/query/QueryRequest.java +++ b/src/main/java/io/weaviate/client6/v1/api/collections/query/QueryRequest.java @@ -9,6 +9,7 @@ import java.util.stream.Collectors; import java.util.stream.Stream; +import io.weaviate.client6.v1.api.collections.CollectionHandleDefaults; import io.weaviate.client6.v1.api.collections.ObjectMetadata; import io.weaviate.client6.v1.api.collections.Vectors; import io.weaviate.client6.v1.api.collections.WeaviateObject; @@ -24,8 +25,9 @@ public record QueryRequest(QueryOperator operator, GroupBy groupBy) { static Rpc, WeaviateProtoSearchGet.SearchReply> rpc( - CollectionDescriptor collection) { - return Rpc.of( + CollectionDescriptor collection, + CollectionHandleDefaults defaults) { + return defaults.rpc(Rpc.of( request -> { var message = WeaviateProtoSearchGet.SearchRequest.newBuilder(); message.setUses127Api(true); @@ -48,12 +50,13 @@ static Rpc(objects); }, () -> WeaviateBlockingStub::search, - () -> WeaviateFutureStub::search); + () -> WeaviateFutureStub::search)); } static Rpc, WeaviateProtoSearchGet.SearchReply> grouped( - CollectionDescriptor collection) { - var rpc = rpc(collection); + CollectionDescriptor collection, + CollectionHandleDefaults defaults) { + var rpc = rpc(collection, defaults); return Rpc.of( request -> rpc.marshal(request), reply -> { diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/query/WeaviateQueryClient.java b/src/main/java/io/weaviate/client6/v1/api/collections/query/WeaviateQueryClient.java index 54801ca12..b80640b4a 100644 --- a/src/main/java/io/weaviate/client6/v1/api/collections/query/WeaviateQueryClient.java +++ b/src/main/java/io/weaviate/client6/v1/api/collections/query/WeaviateQueryClient.java @@ -2,6 +2,7 @@ import java.util.Optional; +import io.weaviate.client6.v1.api.collections.CollectionHandleDefaults; import io.weaviate.client6.v1.api.collections.WeaviateObject; import io.weaviate.client6.v1.internal.grpc.GrpcTransport; import io.weaviate.client6.v1.internal.orm.CollectionDescriptor; @@ -10,28 +11,36 @@ public class WeaviateQueryClient extends AbstractQueryClient>, QueryResponse, QueryResponseGrouped> { - public WeaviateQueryClient(CollectionDescriptor collection, GrpcTransport grpcTransport) { - super(collection, grpcTransport); + public WeaviateQueryClient( + CollectionDescriptor collection, + GrpcTransport grpcTransport, + CollectionHandleDefaults defaults) { + super(collection, grpcTransport, defaults); + } + + /** Copy constructor that sets new defaults. */ + public WeaviateQueryClient(WeaviateQueryClient qc, CollectionHandleDefaults defaults) { + super(qc, defaults); } @Override protected Optional> byId(ById byId) { var request = new QueryRequest(byId, null); - var result = this.grpcTransport.performRequest(request, QueryRequest.rpc(collection)); - return optionalFirst(result.objects()); + var result = this.grpcTransport.performRequest(request, QueryRequest.rpc(collection, defaults)); + return optionalFirst(result); } @Override protected final QueryResponse performRequest(QueryOperator operator) { var request = new QueryRequest(operator, null); - return this.grpcTransport.performRequest(request, QueryRequest.rpc(collection)); + return this.grpcTransport.performRequest(request, QueryRequest.rpc(collection, defaults)); } @Override protected final QueryResponseGrouped performRequest(QueryOperator operator, GroupBy groupBy) { var request = new QueryRequest(operator, groupBy); - return this.grpcTransport.performRequest(request, QueryRequest.grouped(collection)); + return this.grpcTransport.performRequest(request, QueryRequest.grouped(collection, defaults)); } } diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/query/WeaviateQueryClientAsync.java b/src/main/java/io/weaviate/client6/v1/api/collections/query/WeaviateQueryClientAsync.java index e8415314f..0c195e80d 100644 --- a/src/main/java/io/weaviate/client6/v1/api/collections/query/WeaviateQueryClientAsync.java +++ b/src/main/java/io/weaviate/client6/v1/api/collections/query/WeaviateQueryClientAsync.java @@ -3,6 +3,7 @@ import java.util.Optional; import java.util.concurrent.CompletableFuture; +import io.weaviate.client6.v1.api.collections.CollectionHandleDefaults; import io.weaviate.client6.v1.api.collections.WeaviateObject; import io.weaviate.client6.v1.internal.grpc.GrpcTransport; import io.weaviate.client6.v1.internal.orm.CollectionDescriptor; @@ -11,28 +12,36 @@ public class WeaviateQueryClientAsync extends AbstractQueryClient>>, CompletableFuture>, CompletableFuture>> { - public WeaviateQueryClientAsync(CollectionDescriptor collection, GrpcTransport grpcTransport) { - super(collection, grpcTransport); + public WeaviateQueryClientAsync( + CollectionDescriptor collection, + GrpcTransport grpcTransport, + CollectionHandleDefaults defaults) { + super(collection, grpcTransport, defaults); + } + + /** Copy constructor that sets new defaults. */ + public WeaviateQueryClientAsync(WeaviateQueryClientAsync qc, CollectionHandleDefaults defaults) { + super(qc, defaults); } @Override protected CompletableFuture>> byId( ById byId) { var request = new QueryRequest(byId, null); - var result = this.grpcTransport.performRequestAsync(request, QueryRequest.rpc(collection)); - return result.thenApply(r -> optionalFirst(r.objects())); + var result = this.grpcTransport.performRequestAsync(request, QueryRequest.rpc(collection, defaults)); + return result.thenApply(this::optionalFirst); } @Override protected final CompletableFuture> performRequest(QueryOperator operator) { var request = new QueryRequest(operator, null); - return this.grpcTransport.performRequestAsync(request, QueryRequest.rpc(collection)); + return this.grpcTransport.performRequestAsync(request, QueryRequest.rpc(collection, defaults)); } @Override protected final CompletableFuture> performRequest(QueryOperator operator, GroupBy groupBy) { var request = new QueryRequest(operator, groupBy); - return this.grpcTransport.performRequestAsync(request, QueryRequest.grouped(collection)); + return this.grpcTransport.performRequestAsync(request, QueryRequest.grouped(collection, defaults)); } } diff --git a/src/main/java/io/weaviate/client6/v1/internal/rest/BooleanEndpoint.java b/src/main/java/io/weaviate/client6/v1/internal/rest/BooleanEndpoint.java index b0b20665e..a4e29b20f 100644 --- a/src/main/java/io/weaviate/client6/v1/internal/rest/BooleanEndpoint.java +++ b/src/main/java/io/weaviate/client6/v1/internal/rest/BooleanEndpoint.java @@ -8,7 +8,7 @@ public class BooleanEndpoint extends EndpointBase { public BooleanEndpoint( Function method, Function requestUrl, - Function> queryParameters, + Function> queryParameters, Function body) { super(method, requestUrl, queryParameters, body); } diff --git a/src/main/java/io/weaviate/client6/v1/internal/rest/DefaultRestTransport.java b/src/main/java/io/weaviate/client6/v1/internal/rest/DefaultRestTransport.java index bcd5a07ab..99d5bcb42 100644 --- a/src/main/java/io/weaviate/client6/v1/internal/rest/DefaultRestTransport.java +++ b/src/main/java/io/weaviate/client6/v1/internal/rest/DefaultRestTransport.java @@ -75,6 +75,12 @@ public DefaultRestTransport(RestTransportOptions transportOptions) { this.httpClientAsync.start(); } + private String uri(Endpoint ep, RequestT req) { + return transportOptions.baseUrl() + + ep.requestUrl(req) + + UrlEncoder.encodeQuery(ep.queryParameters(req)); + } + @Override public ResponseT performRequest(RequestT request, Endpoint endpoint) @@ -86,7 +92,7 @@ public ResponseT performRequest(RequestT reque private ClassicHttpRequest prepareClassicRequest(RequestT request, Endpoint endpoint) { var method = endpoint.method(request); - var uri = transportOptions.baseUrl() + endpoint.requestUrl(request); + var uri = uri(endpoint, request); // TODO: apply options; var req = ClassicRequestBuilder.create(method).setUri(uri); @@ -138,8 +144,7 @@ public void cancelled() { private SimpleHttpRequest prepareSimpleRequest(RequestT request, Endpoint endpoint) { var method = endpoint.method(request); - var uri = transportOptions.baseUrl() + endpoint.requestUrl(request); - // TODO: apply options; + var uri = uri(endpoint, request); var body = endpoint.body(request); var req = SimpleHttpRequest.create(method, uri); @@ -166,19 +171,7 @@ private ResponseT _handleResponse(Endpoint endpoint, S var message = endpoint.deserializeError(statusCode, body); throw WeaviateApiException.http(method, url, statusCode, message); } - - if (endpoint instanceof JsonEndpoint json) { - @SuppressWarnings("unchecked") - ResponseT response = (ResponseT) json.deserializeResponse(statusCode, body); - return response; - } else if (endpoint instanceof BooleanEndpoint bool) { - @SuppressWarnings("unchecked") - ResponseT response = (ResponseT) ((Boolean) bool.getResult(statusCode)); - return response; - } - - // TODO: make it a WeaviateTransportException - throw new RuntimeException("Unhandled endpoint type " + endpoint.getClass().getSimpleName()); + return EndpointBase.deserializeResponse(endpoint, statusCode, body); } @Override diff --git a/src/main/java/io/weaviate/client6/v1/internal/rest/Endpoint.java b/src/main/java/io/weaviate/client6/v1/internal/rest/Endpoint.java index 52cc37c3b..6e1e33760 100644 --- a/src/main/java/io/weaviate/client6/v1/internal/rest/Endpoint.java +++ b/src/main/java/io/weaviate/client6/v1/internal/rest/Endpoint.java @@ -10,7 +10,7 @@ public interface Endpoint { String body(RequestT request); - Map queryParameters(RequestT request); + Map queryParameters(RequestT request); /** Should this status code be considered an error? */ boolean isError(int statusCode); diff --git a/src/main/java/io/weaviate/client6/v1/internal/rest/EndpointBase.java b/src/main/java/io/weaviate/client6/v1/internal/rest/EndpointBase.java index 2ebe61d6d..d38622915 100644 --- a/src/main/java/io/weaviate/client6/v1/internal/rest/EndpointBase.java +++ b/src/main/java/io/weaviate/client6/v1/internal/rest/EndpointBase.java @@ -14,7 +14,7 @@ public abstract class EndpointBase implements Endpoint method; protected final Function requestUrl; protected final Function body; - protected final Function> queryParameters; + protected final Function> queryParameters; @SuppressWarnings("unchecked") protected static Function nullBody() { @@ -24,7 +24,7 @@ protected static Function nullBody() { public EndpointBase( Function method, Function requestUrl, - Function> queryParameters, + Function> queryParameters, Function body) { this.method = method; this.requestUrl = requestUrl; @@ -43,7 +43,7 @@ public String requestUrl(RequestT request) { } @Override - public Map queryParameters(RequestT request) { + public Map queryParameters(RequestT request) { return queryParameters.apply(request); } @@ -67,6 +67,19 @@ public String deserializeError(int statusCode, String responseBody) { return response.errors.get(0).text(); } + @SuppressWarnings("unchecked") + public static ResponseT deserializeResponse(Endpoint endpoint, int statusCode, + String responseBody) { + if (endpoint instanceof JsonEndpoint json) { + return (ResponseT) json.deserializeResponse(statusCode, responseBody); + } else if (endpoint instanceof BooleanEndpoint bool) { + return (ResponseT) ((Boolean) bool.getResult(statusCode)); + } + + // TODO: make it a WeaviateTransportException + throw new RuntimeException("Unhandled endpoint type " + endpoint.getClass().getSimpleName()); + } + static record ErrorResponse(@SerializedName("error") List errors) { private static record ErrorMessage(@SerializedName("message") String text) { } diff --git a/src/main/java/io/weaviate/client6/v1/internal/rest/OptionalEndpoint.java b/src/main/java/io/weaviate/client6/v1/internal/rest/OptionalEndpoint.java index 0b6052573..c3863bf97 100644 --- a/src/main/java/io/weaviate/client6/v1/internal/rest/OptionalEndpoint.java +++ b/src/main/java/io/weaviate/client6/v1/internal/rest/OptionalEndpoint.java @@ -10,7 +10,7 @@ public class OptionalEndpoint extends SimpleEndpoint OptionalEndpoint noBodyOptional( Function method, Function requestUrl, - Function> queryParameters, + Function> queryParameters, BiFunction deserializeResponse) { return new OptionalEndpoint<>(method, requestUrl, queryParameters, nullBody(), deserializeResponse); } @@ -18,7 +18,7 @@ public static OptionalEndpoint noBody public OptionalEndpoint( Function method, Function requestUrl, - Function> queryParameters, + Function> queryParameters, Function body, BiFunction deserializeResponse) { super(method, requestUrl, queryParameters, body, optional(deserializeResponse)); diff --git a/src/main/java/io/weaviate/client6/v1/internal/rest/SimpleEndpoint.java b/src/main/java/io/weaviate/client6/v1/internal/rest/SimpleEndpoint.java index 9f5c6fa9c..963cf4e3a 100644 --- a/src/main/java/io/weaviate/client6/v1/internal/rest/SimpleEndpoint.java +++ b/src/main/java/io/weaviate/client6/v1/internal/rest/SimpleEndpoint.java @@ -17,7 +17,7 @@ protected static BiFunction nullResponse() { public static SimpleEndpoint noBody( Function method, Function requestUrl, - Function> queryParameters, + Function> queryParameters, BiFunction deserializeResponse) { return new SimpleEndpoint<>(method, requestUrl, queryParameters, nullBody(), deserializeResponse); } @@ -25,7 +25,7 @@ public static SimpleEndpoint noBody( public static SimpleEndpoint sideEffect( Function method, Function requestUrl, - Function> queryParameters, + Function> queryParameters, Function body) { return new SimpleEndpoint<>(method, requestUrl, queryParameters, body, nullResponse()); } @@ -33,14 +33,14 @@ public static SimpleEndpoint sideEffect( public static SimpleEndpoint sideEffect( Function method, Function requestUrl, - Function> queryParameters) { + Function> queryParameters) { return new SimpleEndpoint<>(method, requestUrl, queryParameters, nullBody(), nullResponse()); } public SimpleEndpoint( Function method, Function requestUrl, - Function> queryParameters, + Function> queryParameters, Function body, BiFunction deserializeResponse) { super(method, requestUrl, queryParameters, body); diff --git a/src/main/java/io/weaviate/client6/v1/internal/rest/UrlEncoder.java b/src/main/java/io/weaviate/client6/v1/internal/rest/UrlEncoder.java new file mode 100644 index 000000000..9c0f5f6ee --- /dev/null +++ b/src/main/java/io/weaviate/client6/v1/internal/rest/UrlEncoder.java @@ -0,0 +1,27 @@ +package io.weaviate.client6.v1.internal.rest; + +import java.io.UnsupportedEncodingException; +import java.net.URLEncoder; +import java.nio.charset.StandardCharsets; +import java.util.Map; +import java.util.stream.Collectors; + +public final class UrlEncoder { + + private static String encodeValue(Object value) { + try { + return URLEncoder.encode(value.toString(), StandardCharsets.UTF_8.toString()); + } catch (UnsupportedEncodingException e) { + throw new AssertionError(e); // should never happen with a standard encoding + } + } + + public static String encodeQuery(Map queryParams) { + if (queryParams == null || queryParams.isEmpty()) { + return ""; + } + return queryParams.entrySet().stream() + .map(qp -> qp.getKey() + "=" + encodeValue(qp.getValue())) + .collect(Collectors.joining("&", "?", "")); + } +} diff --git a/src/test/java/io/weaviate/client6/v1/api/collections/CollectionHandleDefaultsTest.java b/src/test/java/io/weaviate/client6/v1/api/collections/CollectionHandleDefaultsTest.java new file mode 100644 index 000000000..1c0d7aff2 --- /dev/null +++ b/src/test/java/io/weaviate/client6/v1/api/collections/CollectionHandleDefaultsTest.java @@ -0,0 +1,51 @@ +package io.weaviate.client6.v1.api.collections; + +import java.util.Map; + +import org.assertj.core.api.Assertions; +import org.junit.Test; + +import io.weaviate.client6.v1.api.collections.query.ConsistencyLevel; +import io.weaviate.client6.v1.internal.ObjectBuilder; +import io.weaviate.client6.v1.internal.orm.CollectionDescriptor; + +public class CollectionHandleDefaultsTest { + private static final CollectionDescriptor> DESCRIPTOR = CollectionDescriptor.ofMap("Things"); + private static final CollectionHandleDefaults NONE_DEFAULTS = CollectionHandleDefaults.of(ObjectBuilder.identity()); + + /** CollectionHandle with no defaults. */ + private static final CollectionHandle> HANDLE_NONE = new CollectionHandle<>( + null, null, DESCRIPTOR, NONE_DEFAULTS); + + /** CollectionHandleAsync with no defaults. */ + private static final CollectionHandleAsync> HANDLE_NONE_ASYNC = new CollectionHandleAsync<>( + null, null, DESCRIPTOR, NONE_DEFAULTS); + + /** All defaults are {@code null} if none were set. */ + @Test + public void test_defaults() { + Assertions.assertThat(HANDLE_NONE.consistencyLevel()).isNull(); + } + + /** + * {@link CollectionHandle#withConsistencyLevel} should create a copy with + * different defaults but not modify the original. + */ + @Test + public void test_withConsistencyLevel() { + var handle = HANDLE_NONE.withConsistencyLevel(ConsistencyLevel.QUORUM); + Assertions.assertThat(handle.consistencyLevel()).isEqualTo(ConsistencyLevel.QUORUM); + Assertions.assertThat(HANDLE_NONE.consistencyLevel()).isNull(); + } + + /** + * {@link CollectionHandleAsync#withConsistencyLevel} should create a copy with + * different defaults but not modify the original. + */ + @Test + public void test_withConsistencyLevel_async() { + var handle = HANDLE_NONE_ASYNC.withConsistencyLevel(ConsistencyLevel.QUORUM); + Assertions.assertThat(handle.consistencyLevel()).isEqualTo(ConsistencyLevel.QUORUM); + Assertions.assertThat(HANDLE_NONE_ASYNC.consistencyLevel()).isNull(); + } +} diff --git a/src/test/java/io/weaviate/client6/v1/api/collections/data/WeaviateDataClientTest.java b/src/test/java/io/weaviate/client6/v1/api/collections/data/WeaviateDataClientTest.java new file mode 100644 index 000000000..aecffae65 --- /dev/null +++ b/src/test/java/io/weaviate/client6/v1/api/collections/data/WeaviateDataClientTest.java @@ -0,0 +1,155 @@ +package io.weaviate.client6.v1.api.collections.data; + +import java.util.Map; + +import org.assertj.core.api.Assertions; +import org.junit.AfterClass; +import org.junit.BeforeClass; +import org.junit.Test; +import org.junit.runner.RunWith; + +import com.google.gson.JsonParser; +import com.jparams.junit4.JParamsTestRunner; +import com.jparams.junit4.data.DataMethod; +import com.jparams.junit4.description.Name; + +import io.weaviate.client6.v1.api.collections.CollectionHandleDefaults; +import io.weaviate.client6.v1.api.collections.CollectionHandleDefaults.Location; +import io.weaviate.client6.v1.api.collections.query.ConsistencyLevel; +import io.weaviate.client6.v1.internal.ObjectBuilder; +import io.weaviate.client6.v1.internal.grpc.protocol.WeaviateProtoBase; +import io.weaviate.client6.v1.internal.json.JSON; +import io.weaviate.client6.v1.internal.orm.CollectionDescriptor; +import io.weaviate.testutil.transport.MockGrpcTransport; +import io.weaviate.testutil.transport.MockRestTransport; + +@RunWith(JParamsTestRunner.class) +public class WeaviateDataClientTest { + private static MockRestTransport rest; + private static MockGrpcTransport grpc; + + @BeforeClass + public static void setUp() { + rest = new MockRestTransport(); + grpc = new MockGrpcTransport(); + } + + @AfterClass + public static void tearDown() throws Exception { + rest.close(); + grpc.close(); + } + + @FunctionalInterface + interface Act { + void apply(WeaviateDataClient> client) throws Exception; + } + + public static Object[][] restTestCases() { + return new Object[][] { + { + "insert single object", + ConsistencyLevel.ONE, Location.QUERY, + (Act) client -> client.insert(Map.of()), + }, + { + "replace single object", + ConsistencyLevel.ONE, Location.QUERY, + (Act) client -> client.replace("test-uuid", ObjectBuilder.identity()), + }, + { + "update single object", + ConsistencyLevel.ONE, Location.QUERY, + (Act) client -> client.update("test-uuid", ObjectBuilder.identity()), + }, + { + "delete by id", + ConsistencyLevel.ONE, Location.QUERY, + (Act) client -> client.delete("test-uuid"), + }, + { + "add reference", + ConsistencyLevel.ONE, Location.QUERY, + (Act) client -> client.referenceAdd("from-uuid", "from_property", Reference.uuids("to-uuid")), + }, + { + "add reference many", + ConsistencyLevel.ONE, Location.QUERY, + (Act) client -> client.referenceAddMany(), + }, + { + "replace reference", + ConsistencyLevel.ONE, Location.QUERY, + (Act) client -> client.referenceReplace("from-uuid", "from_property", Reference.uuids("to-uuid")), + }, + { + "delete reference", + ConsistencyLevel.ONE, Location.QUERY, + (Act) client -> client.referenceDelete("from-uuid", "from_property", Reference.uuids("to-uuid")), + }, + }; + } + + @Name("0") + @DataMethod(source = WeaviateDataClientTest.class, method = "restTestCases") + @Test + public void test_collectionHandleDefaults_rest(String __, ConsistencyLevel cl, Location clLoc, Act act) + throws Exception { + // Arrange + var collection = CollectionDescriptor.ofMap("Things"); + var defaults = new CollectionHandleDefaults(cl); + var client = new WeaviateDataClient>( + collection, rest, null, defaults); + + // Act + act.apply(client); + + // Assert + rest.assertNext((method, requestUrl, body, query) -> { + switch (clLoc) { + case QUERY: + Assertions.assertThat(query).containsEntry("consistency_level", defaults.consistencyLevel()); + break; + case BODY: + assertJsonHasValue(body, "consistency_level", defaults.consistencyLevel()); + } + }); + } + + private void assertJsonHasValue(String json, String key, T value) { + var gotJson = JsonParser.parseString(json).getAsJsonObject(); + Assertions.assertThat(gotJson.has(key)) + .describedAs("missing key \"%s\" in %s", key, json) + .isTrue(); + + var wantValue = JsonParser.parseString(JSON.serialize(value)); + Assertions.assertThat(gotJson.get(key)).isEqualTo(wantValue); + } + + public static Object[][] grpcTestCases() { + return new Object[][] { + { "object exists", (Act) client -> client.exists("test-uuid") }, + { "insert many", (Act) client -> client.insertMany() }, + { "delete many", (Act) client -> client.deleteMany() }, + }; + } + + @Name("0") + @DataMethod(source = WeaviateDataClientTest.class, method = "grpcTestCases") + @Test + public void test_collectionHandleDefaults_grpc(String __, Act act) + throws Exception { + // Arrange + var collection = CollectionDescriptor.ofMap("Things"); + var defaults = new CollectionHandleDefaults(ConsistencyLevel.ONE); + var client = new WeaviateDataClient>( + collection, null, grpc, defaults); + + // Act + act.apply(client); + + // Assert + grpc.assertNext(json -> assertJsonHasValue(json, "consistencyLevel", + WeaviateProtoBase.ConsistencyLevel.CONSISTENCY_LEVEL_ONE.toString())); + } +} diff --git a/src/test/java/io/weaviate/client6/v1/api/collections/query/WeaviateQueryClientTest.java b/src/test/java/io/weaviate/client6/v1/api/collections/query/WeaviateQueryClientTest.java new file mode 100644 index 000000000..aef374385 --- /dev/null +++ b/src/test/java/io/weaviate/client6/v1/api/collections/query/WeaviateQueryClientTest.java @@ -0,0 +1,87 @@ +package io.weaviate.client6.v1.api.collections.query; + +import java.util.Map; + +import org.assertj.core.api.Assertions; +import org.junit.AfterClass; +import org.junit.BeforeClass; +import org.junit.Test; +import org.junit.runner.RunWith; + +import com.google.gson.JsonParser; +import com.jparams.junit4.JParamsTestRunner; +import com.jparams.junit4.data.DataMethod; +import com.jparams.junit4.description.Name; + +import io.weaviate.client6.v1.api.collections.CollectionHandleDefaults; +import io.weaviate.client6.v1.internal.ObjectBuilder; +import io.weaviate.client6.v1.internal.grpc.protocol.WeaviateProtoBase; +import io.weaviate.client6.v1.internal.json.JSON; +import io.weaviate.client6.v1.internal.orm.CollectionDescriptor; +import io.weaviate.testutil.transport.MockGrpcTransport; + +@RunWith(JParamsTestRunner.class) +public class WeaviateQueryClientTest { + private static MockGrpcTransport grpc; + + @BeforeClass + public static void setUp() { + grpc = new MockGrpcTransport(); + } + + @AfterClass + public static void tearDown() throws Exception { + grpc.close(); + } + + @FunctionalInterface + interface Act { + void apply(WeaviateQueryClient> client) throws Exception; + } + + private void assertJsonHasValue(String json, String key, T value) { + var gotJson = JsonParser.parseString(json).getAsJsonObject(); + Assertions.assertThat(gotJson.has(key)) + .describedAs("missing key \"%s\" in %s", key, json) + .isTrue(); + + var wantValue = JsonParser.parseString(JSON.serialize(value)); + Assertions.assertThat(gotJson.get(key)).isEqualTo(wantValue); + } + + public static Object[][] grpcTestCases() { + return new Object[][] { + { "get by id", (Act) client -> client.byId("test-uuid") }, + { "fetch objects", (Act) client -> client.fetchObjects(ObjectBuilder.identity()) }, + { "bm25", (Act) client -> client.bm25("red ballon") }, + { "hybrid", (Act) client -> client.hybrid("red ballon") }, + { "nearVector", (Act) client -> client.nearVector(new float[] {}) }, + { "nearText", (Act) client -> client.nearText("weather in Arizona") }, + { "nearObject", (Act) client -> client.nearObject("test-uuid") }, + { "nearImage", (Act) client -> client.nearImage("img.jpeg") }, + { "nearAudio", (Act) client -> client.nearAudio("song.mp3") }, + { "nearVideo", (Act) client -> client.nearVideo("clip.mp4") }, + { "nearDepth", (Act) client -> client.nearDepth("20.000 leagues") }, + { "nearThermal", (Act) client -> client.nearThermal("Fahrenheit 451") }, + { "nearImu", (Act) client -> client.nearImu("6 m/s") }, + }; + } + + @Name("0") + @DataMethod(source = WeaviateQueryClientTest.class, method = "grpcTestCases") + @Test + public void test_collectionHandleDefaults_grpc(String __, Act act) + throws Exception { + // Arrange + var collection = CollectionDescriptor.ofMap("Things"); + var defaults = new CollectionHandleDefaults(ConsistencyLevel.ONE); + var client = new WeaviateQueryClient>(collection, grpc, defaults); + + // Act + act.apply(client); + + // Assert + grpc.assertNext(json -> assertJsonHasValue(json, "consistencyLevel", + WeaviateProtoBase.ConsistencyLevel.CONSISTENCY_LEVEL_ONE.toString())); + } +} diff --git a/src/test/java/io/weaviate/client6/v1/internal/json/JSONTest.java b/src/test/java/io/weaviate/client6/v1/internal/json/JSONTest.java index 56e189f2e..45607c4f4 100644 --- a/src/test/java/io/weaviate/client6/v1/internal/json/JSONTest.java +++ b/src/test/java/io/weaviate/client6/v1/internal/json/JSONTest.java @@ -345,6 +345,7 @@ public void test_serialize(Object cls, Object in, String want) { } + @FunctionalInterface private interface CustomAssert extends BiConsumer { } diff --git a/src/test/java/io/weaviate/testutil/transport/MockGrpcTransport.java b/src/test/java/io/weaviate/testutil/transport/MockGrpcTransport.java new file mode 100644 index 000000000..cb504d01b --- /dev/null +++ b/src/test/java/io/weaviate/testutil/transport/MockGrpcTransport.java @@ -0,0 +1,56 @@ +package io.weaviate.testutil.transport; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.List; +import java.util.concurrent.CompletableFuture; + +import com.google.protobuf.InvalidProtocolBufferException; +import com.google.protobuf.MessageOrBuilder; +import com.google.protobuf.util.JsonFormat; + +import io.weaviate.client6.v1.internal.grpc.GrpcTransport; +import io.weaviate.client6.v1.internal.grpc.Rpc; + +public class MockGrpcTransport implements GrpcTransport { + + @FunctionalInterface + public interface AssertFunction { + void apply(String json); + } + + private List requests = new ArrayList<>(); + + public void assertNext(AssertFunction... assertions) { + var assertN = Math.min(assertions.length, requests.size()); + for (var i = 0; i < assertN; i++) { + var req = requests.get(i); + String json; + try { + json = JsonFormat.printer().print(req); + } catch (InvalidProtocolBufferException e) { + throw new RuntimeException(e); + } + assertions[i].apply(json); + } + requests.clear(); + } + + @Override + public ResponseT performRequest(RequestT request, + Rpc rpc) { + requests.add((MessageOrBuilder) rpc.marshal(request)); + return null; + } + + @Override + public CompletableFuture performRequestAsync(RequestT request, + Rpc rpc) { + requests.add((MessageOrBuilder) rpc.marshal(request)); + return null; + } + + @Override + public void close() throws IOException { + } +} diff --git a/src/test/java/io/weaviate/testutil/transport/MockRestTransport.java b/src/test/java/io/weaviate/testutil/transport/MockRestTransport.java new file mode 100644 index 000000000..c3cdf1234 --- /dev/null +++ b/src/test/java/io/weaviate/testutil/transport/MockRestTransport.java @@ -0,0 +1,55 @@ +package io.weaviate.testutil.transport; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.List; +import java.util.Map; +import java.util.concurrent.CompletableFuture; + +import io.weaviate.client6.v1.internal.rest.Endpoint; +import io.weaviate.client6.v1.internal.rest.RestTransport; + +public class MockRestTransport implements RestTransport { + + private record Request(String method, String requestUrl, String body, + Map queryParameters) { + + Request(RequestT req, Endpoint ep) { + this(ep.method(req), ep.requestUrl(req), ep.body(req), ep.queryParameters(req)); + } + } + + @FunctionalInterface + public interface AssertFunction { + void apply(String method, String requestUrl, String body, Map queryParameters); + } + + private List> requests = new ArrayList<>(); + + public void assertNext(AssertFunction... assertions) { + var assertN = Math.min(assertions.length, requests.size()); + for (var i = 0; i < assertN; i++) { + var req = requests.get(i); + assertions[i].apply(req.method, req.requestUrl, req.body, req.queryParameters); + } + requests.clear(); + } + + @Override + public ResponseT performRequest(RequestT request, + Endpoint endpoint) throws IOException { + requests.add(new Request<>(request, endpoint)); + return null; + } + + @Override + public CompletableFuture performRequestAsync(RequestT request, + Endpoint endpoint) { + requests.add(new Request<>(request, endpoint)); + return null; + } + + @Override + public void close() throws IOException { + } +}