Skip to content
Original file line number Diff line number Diff line change
Expand Up @@ -7,34 +7,50 @@
import io.weaviate.client6.v1.api.collections.config.WeaviateConfigClient;
import io.weaviate.client6.v1.api.collections.data.WeaviateDataClient;
import io.weaviate.client6.v1.api.collections.pagination.Paginator;
import io.weaviate.client6.v1.api.collections.query.ConsistencyLevel;
import io.weaviate.client6.v1.api.collections.query.WeaviateQueryClient;
import io.weaviate.client6.v1.internal.ObjectBuilder;
import io.weaviate.client6.v1.internal.grpc.GrpcTransport;
import io.weaviate.client6.v1.internal.orm.CollectionDescriptor;
import io.weaviate.client6.v1.internal.rest.RestTransport;

public class CollectionHandle<T> {
public class CollectionHandle<PropertiesT> {
public final WeaviateConfigClient config;
public final WeaviateDataClient<T> data;
public final WeaviateQueryClient<T> query;
public final WeaviateDataClient<PropertiesT> data;
public final WeaviateQueryClient<PropertiesT> query;
public final WeaviateAggregateClient aggregate;

private final CollectionHandleDefaults defaults;

public CollectionHandle(
RestTransport restTransport,
GrpcTransport grpcTransport,
CollectionDescriptor<T> collectionDescriptor) {

CollectionDescriptor<PropertiesT> collectionDescriptor,
CollectionHandleDefaults defaults) {
this.config = new WeaviateConfigClient(collectionDescriptor, restTransport, grpcTransport);
this.data = new WeaviateDataClient<>(collectionDescriptor, restTransport, grpcTransport);
this.query = new WeaviateQueryClient<>(collectionDescriptor, grpcTransport);
this.aggregate = new WeaviateAggregateClient(collectionDescriptor, grpcTransport);
this.aggregate = new WeaviateAggregateClient(collectionDescriptor, grpcTransport, defaults);
this.query = new WeaviateQueryClient<>(collectionDescriptor, grpcTransport, defaults);
this.data = new WeaviateDataClient<>(collectionDescriptor, restTransport, grpcTransport, defaults);

this.defaults = defaults;
}

/** Copy constructor that sets new defaults. */
private CollectionHandle(CollectionHandle<PropertiesT> c, CollectionHandleDefaults defaults) {
this.config = c.config;
this.aggregate = c.aggregate;
this.query = new WeaviateQueryClient<>(c.query, defaults);
this.data = new WeaviateDataClient<>(c.data, defaults);

this.defaults = defaults;
}

public Paginator<T> paginate() {
public Paginator<PropertiesT> paginate() {
return Paginator.of(this.query);
}

public Paginator<T> paginate(Function<Paginator.Builder<T>, ObjectBuilder<Paginator<T>>> fn) {
public Paginator<PropertiesT> paginate(
Function<Paginator.Builder<PropertiesT>, ObjectBuilder<Paginator<PropertiesT>>> fn) {
return Paginator.of(this.query, fn);
}

Expand All @@ -57,4 +73,12 @@ public Paginator<T> paginate(Function<Paginator.Builder<T>, ObjectBuilder<Pagina
public long size() {
return this.aggregate.overAll(all -> all.includeTotalCount(true)).totalCount();
}

public ConsistencyLevel consistencyLevel() {
return defaults.consistencyLevel();
}

public CollectionHandle<PropertiesT> withConsistencyLevel(ConsistencyLevel consistencyLevel) {
return new CollectionHandle<>(this, CollectionHandleDefaults.of(def -> def.consistencyLevel(consistencyLevel)));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import io.weaviate.client6.v1.api.collections.config.WeaviateConfigClientAsync;
import io.weaviate.client6.v1.api.collections.data.WeaviateDataClientAsync;
import io.weaviate.client6.v1.api.collections.pagination.AsyncPaginator;
import io.weaviate.client6.v1.api.collections.query.ConsistencyLevel;
import io.weaviate.client6.v1.api.collections.query.WeaviateQueryClientAsync;
import io.weaviate.client6.v1.internal.ObjectBuilder;
import io.weaviate.client6.v1.internal.grpc.GrpcTransport;
Expand All @@ -21,15 +22,30 @@ public class CollectionHandleAsync<PropertiesT> {
public final WeaviateQueryClientAsync<PropertiesT> query;
public final WeaviateAggregateClientAsync aggregate;

private final CollectionHandleDefaults defaults;

public CollectionHandleAsync(
RestTransport restTransport,
GrpcTransport grpcTransport,
CollectionDescriptor<PropertiesT> collectionDescriptor) {
CollectionDescriptor<PropertiesT> collectionDescriptor,
CollectionHandleDefaults defaults) {

this.config = new WeaviateConfigClientAsync(collectionDescriptor, restTransport, grpcTransport);
this.data = new WeaviateDataClientAsync<>(collectionDescriptor, restTransport, grpcTransport);
this.query = new WeaviateQueryClientAsync<>(collectionDescriptor, grpcTransport);
this.aggregate = new WeaviateAggregateClientAsync(collectionDescriptor, grpcTransport);
this.aggregate = new WeaviateAggregateClientAsync(collectionDescriptor, grpcTransport, defaults);
this.query = new WeaviateQueryClientAsync<>(collectionDescriptor, grpcTransport, defaults);
this.data = new WeaviateDataClientAsync<>(collectionDescriptor, restTransport, grpcTransport, defaults);

this.defaults = defaults;
}

/** Copy constructor that sets new defaults. */
private CollectionHandleAsync(CollectionHandleAsync<PropertiesT> c, CollectionHandleDefaults defaults) {
this.config = c.config;
this.aggregate = c.aggregate;
this.query = new WeaviateQueryClientAsync<>(c.query, defaults);
this.data = new WeaviateDataClientAsync<>(c.data, defaults);

this.defaults = defaults;
}

public AsyncPaginator<PropertiesT> paginate() {
Expand Down Expand Up @@ -64,4 +80,13 @@ public CompletableFuture<Long> size() {
return this.aggregate.overAll(all -> all.includeTotalCount(true))
.thenApply(AggregateResponse::totalCount);
}

public ConsistencyLevel consistencyLevel() {
return defaults.consistencyLevel();
}

public CollectionHandleAsync<PropertiesT> withConsistencyLevel(ConsistencyLevel consistencyLevel) {
return new CollectionHandleAsync<>(this, CollectionHandleDefaults.of(
def -> def.consistencyLevel(consistencyLevel)));
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,221 @@
package io.weaviate.client6.v1.api.collections;

import java.util.HashMap;
import java.util.Map;
import java.util.function.BiFunction;
import java.util.function.Function;

import com.google.common.util.concurrent.ListenableFuture;

import io.weaviate.client6.v1.api.collections.query.ConsistencyLevel;
import io.weaviate.client6.v1.internal.ObjectBuilder;
import io.weaviate.client6.v1.internal.grpc.Rpc;
import io.weaviate.client6.v1.internal.grpc.protocol.WeaviateGrpc.WeaviateBlockingStub;
import io.weaviate.client6.v1.internal.grpc.protocol.WeaviateGrpc.WeaviateFutureStub;
import io.weaviate.client6.v1.internal.grpc.protocol.WeaviateProtoBatch;
import io.weaviate.client6.v1.internal.grpc.protocol.WeaviateProtoBatchDelete;
import io.weaviate.client6.v1.internal.grpc.protocol.WeaviateProtoSearchGet;
import io.weaviate.client6.v1.internal.rest.Endpoint;
import io.weaviate.client6.v1.internal.rest.EndpointBase;
import io.weaviate.client6.v1.internal.rest.JsonEndpoint;

public record CollectionHandleDefaults(ConsistencyLevel consistencyLevel) {
private static final String CONSISTENCY_LEVEL = "consistency_level";

/**
* Set default values for query / aggregation requests.
*
* @return CollectionHandleDefaults derived from applying {@code fn} to
* {@link Builder}.
*/
public static CollectionHandleDefaults of(Function<Builder, ObjectBuilder<CollectionHandleDefaults>> fn) {
return fn.apply(new Builder()).build();
}

/**
* Empty collection defaults.
*
* @return An tucked builder that does not leaves all defaults unset.
*/
public static Function<Builder, ObjectBuilder<CollectionHandleDefaults>> none() {
return ObjectBuilder.identity();
}

public CollectionHandleDefaults(Builder builder) {
this(builder.consistencyLevel);
}

public static final class Builder implements ObjectBuilder<CollectionHandleDefaults> {
private ConsistencyLevel consistencyLevel;

/** Set default consistency level for this collection handle. */
public Builder consistencyLevel(ConsistencyLevel consistencyLevel) {
this.consistencyLevel = consistencyLevel;
return this;
}

@Override
public CollectionHandleDefaults build() {
return new CollectionHandleDefaults(this);
}
}

public <RequestT, ResponseT> Endpoint<RequestT, ResponseT> endpoint(Endpoint<RequestT, ResponseT> ep,
Function<EndpointBuilder<RequestT, ResponseT>, ObjectBuilder<Endpoint<RequestT, ResponseT>>> fn) {
return fn.apply(new EndpointBuilder<>(ep)).build();
}

public <RequestT, RequestM, ResponseT, ReplyM> Rpc<RequestT, RequestM, ResponseT, ReplyM> rpc(
Rpc<RequestT, RequestM, ResponseT, ReplyM> rpc) {
return new ContextRpc<>(rpc);
}

/** Which part of the request a parameter should be added to. */
public static enum Location {
/** Query string. */
QUERY,
/**
* Request body. {@code RequestT} must implement {@link WithDefaults} for the
* changes to be applied.
*/
BODY;
}

public static interface WithDefaults<SelfT extends WithDefaults<SelfT>> {
ConsistencyLevel consistencyLevel();

SelfT withConsistencyLevel(ConsistencyLevel consistencyLevel);
}

private class ContextEndpoint<RequestT, ResponseT> extends EndpointBase<RequestT, ResponseT>
implements JsonEndpoint<RequestT, ResponseT> {

private final Location consistencyLevelLoc;
private final Endpoint<RequestT, ResponseT> endpoint;

ContextEndpoint(EndpointBuilder<RequestT, ResponseT> builder) {
super(builder.endpoint::method,
builder.endpoint::requestUrl,
builder.endpoint::queryParameters,
builder.endpoint::body);
this.consistencyLevelLoc = builder.consistencyLevelLoc;
this.endpoint = builder.endpoint;
}

/** Return consistencyLevel of the enclosing CollectionHandleDefaults object. */
private ConsistencyLevel consistencyLevel() {
return CollectionHandleDefaults.this.consistencyLevel;
}

@Override
public Map<String, Object> queryParameters(RequestT request) {
// Copy the map, as it's most likely unmodifiable.
var query = new HashMap<>(super.queryParameters(request));
if (consistencyLevel() != null && consistencyLevelLoc == Location.QUERY) {
query.putIfAbsent(CONSISTENCY_LEVEL, consistencyLevel());
}
return query;
}

@SuppressWarnings("unchecked")
@Override
public String body(RequestT request) {
if (request instanceof WithDefaults wd) {
if (wd.consistencyLevel() == null) {
wd = wd.withConsistencyLevel(consistencyLevel());
}
// This cast is safe as long as `wd` returns its own type,
// which it does as per the interface contract.
request = (RequestT) wd;
}
return super.body(request);
}

@Override
public ResponseT deserializeResponse(int statusCode, String responseBody) {
return EndpointBase.deserializeResponse(endpoint, statusCode, responseBody);
}
}

/**
* EndpointBuilder configures how CollectionHandleDefautls
* are added to a REST request.
*/
public class EndpointBuilder<RequestT, ResponseT> implements ObjectBuilder<Endpoint<RequestT, ResponseT>> {
private final Endpoint<RequestT, ResponseT> endpoint;

private Location consistencyLevelLoc;

EndpointBuilder(Endpoint<RequestT, ResponseT> ep) {
this.endpoint = ep;
}

/** Control which part of the request to add default consistency level to. */
public EndpointBuilder<RequestT, ResponseT> consistencyLevel(Location loc) {
this.consistencyLevelLoc = loc;
return this;
}

@Override
public Endpoint<RequestT, ResponseT> build() {
return new ContextEndpoint<>(this);
}
}

private class ContextRpc<RequestT, RequestM, ResponseT, ReplyM>
implements Rpc<RequestT, RequestM, ResponseT, ReplyM> {

private final Rpc<RequestT, RequestM, ResponseT, ReplyM> rpc;

ContextRpc(Rpc<RequestT, RequestM, ResponseT, ReplyM> rpc) {
this.rpc = rpc;
}

/** Return consistencyLevel of the enclosing CollectionHandleDefaults object. */
private ConsistencyLevel consistencyLevel() {
return CollectionHandleDefaults.this.consistencyLevel;
}

@SuppressWarnings("unchecked")
@Override
public RequestM marshal(RequestT request) {
var message = rpc.marshal(request);
if (message instanceof WeaviateProtoBatch.BatchObjectsRequest msg) {
var b = msg.toBuilder();
if (!msg.hasConsistencyLevel() && consistencyLevel() != null) {
consistencyLevel().appendTo(b);
return (RequestM) b.build();
}
} else if (message instanceof WeaviateProtoBatchDelete.BatchDeleteRequest msg) {
var b = msg.toBuilder();
if (!msg.hasConsistencyLevel() && consistencyLevel() != null) {
consistencyLevel().appendTo(b);
return (RequestM) b.build();
}
} else if (message instanceof WeaviateProtoSearchGet.SearchRequest msg) {
var b = msg.toBuilder();
if (!msg.hasConsistencyLevel() && consistencyLevel() != null) {
consistencyLevel().appendTo(b);
return (RequestM) b.build();
}
}

return message;
}

@Override
public ResponseT unmarshal(ReplyM reply) {
return rpc.unmarshal(reply);
}

@Override
public BiFunction<WeaviateBlockingStub, RequestM, ReplyM> method() {
return rpc.method();
}

@Override
public BiFunction<WeaviateFutureStub, RequestM, ListenableFuture<ReplyM>> methodAsync() {
return rpc.methodAsync();
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
import java.io.IOException;
import java.util.EnumMap;
import java.util.Map;
import java.util.function.Function;

import com.google.gson.Gson;
import com.google.gson.JsonObject;
Expand All @@ -20,7 +19,6 @@
import io.weaviate.client6.v1.api.collections.vectorizers.NoneVectorizer;
import io.weaviate.client6.v1.api.collections.vectorizers.Text2VecContextionaryVectorizer;
import io.weaviate.client6.v1.api.collections.vectorizers.Text2VecWeaviateVectorizer;
import io.weaviate.client6.v1.internal.ObjectBuilder;
import io.weaviate.client6.v1.internal.json.JsonEnum;

public interface Vectorizer {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,24 @@ public WeaviateCollectionsClient(RestTransport restTransport, GrpcTransport grpc
* properties.
*/
public CollectionHandle<Map<String, Object>> use(String collectionName) {
return new CollectionHandle<>(restTransport, grpcTransport, CollectionDescriptor.ofMap(collectionName));
return use(collectionName, CollectionHandleDefaults.none());
}

/**
* Obtain a handle to send requests to a particular collection.
* The returned object is thread-safe.
*
* @return a handle for a collection with {@code Map<String, Object>}
* properties.
*/
public CollectionHandle<Map<String, Object>> use(
String collectionName,
Function<CollectionHandleDefaults.Builder, ObjectBuilder<CollectionHandleDefaults>> fn) {
return new CollectionHandle<>(
restTransport,
grpcTransport,
CollectionDescriptor.ofMap(collectionName),
CollectionHandleDefaults.of(fn));
}

/**
Expand Down
Loading