Skip to content

Commit 4eb8305

Browse files
committed
Implement batch async
1 parent 03a2a13 commit 4eb8305

File tree

11 files changed

+1097
-95
lines changed

11 files changed

+1097
-95
lines changed
Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
package io.weaviate.client.base.grpc;
2+
3+
import com.google.common.util.concurrent.ListenableFuture;
4+
import io.grpc.ManagedChannel;
5+
import io.grpc.Metadata;
6+
import io.grpc.stub.MetadataUtils;
7+
import io.weaviate.client.Config;
8+
import io.weaviate.client.base.grpc.base.BaseGrpcClient;
9+
import io.weaviate.client.grpc.protocol.v1.WeaviateGrpc;
10+
import io.weaviate.client.grpc.protocol.v1.WeaviateProtoBatch;
11+
import io.weaviate.client.v1.auth.provider.AccessTokenProvider;
12+
import lombok.AccessLevel;
13+
import lombok.experimental.FieldDefaults;
14+
15+
@FieldDefaults(makeFinal = true, level = AccessLevel.PRIVATE)
16+
public class AsyncGrpcClient extends BaseGrpcClient {
17+
WeaviateGrpc.WeaviateFutureStub client;
18+
ManagedChannel channel;
19+
20+
private AsyncGrpcClient(WeaviateGrpc.WeaviateFutureStub client, ManagedChannel channel) {
21+
this.client = client;
22+
this.channel = channel;
23+
}
24+
25+
public ListenableFuture<WeaviateProtoBatch.BatchObjectsReply> batchObjects(WeaviateProtoBatch.BatchObjectsRequest request) {
26+
return this.client.batchObjects(request);
27+
}
28+
29+
public void shutdown() {
30+
this.channel.shutdown();
31+
}
32+
33+
public static AsyncGrpcClient create(Config config, AccessTokenProvider tokenProvider) {
34+
Metadata headers = getHeaders(config, tokenProvider);
35+
ManagedChannel channel = buildChannel(config);
36+
WeaviateGrpc.WeaviateFutureStub stub = WeaviateGrpc.newFutureStub(channel);
37+
WeaviateGrpc.WeaviateFutureStub client = stub.withInterceptors(MetadataUtils.newAttachHeadersInterceptor(headers));
38+
return new AsyncGrpcClient(client, channel);
39+
}
40+
}
Lines changed: 4 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,18 @@
11
package io.weaviate.client.base.grpc;
22

33
import io.grpc.ManagedChannel;
4-
import io.grpc.ManagedChannelBuilder;
54
import io.grpc.Metadata;
65
import io.grpc.stub.MetadataUtils;
76
import io.weaviate.client.Config;
7+
import io.weaviate.client.base.grpc.base.BaseGrpcClient;
88
import io.weaviate.client.grpc.protocol.v1.WeaviateGrpc;
99
import io.weaviate.client.grpc.protocol.v1.WeaviateProtoBatch;
1010
import io.weaviate.client.v1.auth.provider.AccessTokenProvider;
11-
import java.util.Map;
1211
import lombok.AccessLevel;
1312
import lombok.experimental.FieldDefaults;
1413

1514
@FieldDefaults(makeFinal = true, level = AccessLevel.PRIVATE)
16-
public class GrpcClient {
15+
public class GrpcClient extends BaseGrpcClient {
1716
WeaviateGrpc.WeaviateBlockingStub client;
1817
ManagedChannel channel;
1918

@@ -31,38 +30,10 @@ public void shutdown() {
3130
}
3231

3332
public static GrpcClient create(Config config, AccessTokenProvider tokenProvider) {
34-
Metadata headers = new Metadata();
35-
if (config.getHeaders() != null) {
36-
for (Map.Entry<String, String> e : config.getHeaders().entrySet()) {
37-
headers.put(Metadata.Key.of(e.getKey(), Metadata.ASCII_STRING_MARSHALLER), e.getValue());
38-
}
39-
}
40-
if (tokenProvider != null) {
41-
headers.put(Metadata.Key.of("Authorization", Metadata.ASCII_STRING_MARSHALLER), String.format("Bearer %s", tokenProvider.getAccessToken()));
42-
}
43-
ManagedChannelBuilder<?> channelBuilder = ManagedChannelBuilder.forTarget(getAddress(config));
44-
if (config.isGRPCSecured()) {
45-
channelBuilder = channelBuilder.useTransportSecurity();
46-
} else {
47-
channelBuilder.usePlaintext();
48-
}
49-
ManagedChannel channel = channelBuilder.build();
33+
Metadata headers = getHeaders(config, tokenProvider);
34+
ManagedChannel channel = buildChannel(config);
5035
WeaviateGrpc.WeaviateBlockingStub blockingStub = WeaviateGrpc.newBlockingStub(channel);
5136
WeaviateGrpc.WeaviateBlockingStub client = blockingStub.withInterceptors(MetadataUtils.newAttachHeadersInterceptor(headers));
5237
return new GrpcClient(client, channel);
5338
}
54-
55-
private static String getAddress(Config config) {
56-
if (config.getGRPCHost() != null) {
57-
String host = config.getGRPCHost();
58-
if (host.contains(":")) {
59-
return host;
60-
}
61-
if (config.isGRPCSecured()) {
62-
return String.format("%s:443", host);
63-
}
64-
return String.format("%s:80", host);
65-
}
66-
return "";
67-
}
6839
}
Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
package io.weaviate.client.base.grpc.base;
2+
3+
import io.grpc.ManagedChannel;
4+
import io.grpc.ManagedChannelBuilder;
5+
import io.grpc.Metadata;
6+
import io.weaviate.client.Config;
7+
import io.weaviate.client.v1.auth.provider.AccessTokenProvider;
8+
import java.util.Map;
9+
10+
public class BaseGrpcClient {
11+
12+
protected static Metadata getHeaders(Config config, AccessTokenProvider tokenProvider) {
13+
Metadata headers = new Metadata();
14+
if (config.getHeaders() != null) {
15+
for (Map.Entry<String, String> e : config.getHeaders().entrySet()) {
16+
headers.put(Metadata.Key.of(e.getKey(), Metadata.ASCII_STRING_MARSHALLER), e.getValue());
17+
}
18+
}
19+
if (tokenProvider != null) {
20+
headers.put(Metadata.Key.of("Authorization", Metadata.ASCII_STRING_MARSHALLER), String.format("Bearer %s", tokenProvider.getAccessToken()));
21+
}
22+
return headers;
23+
}
24+
25+
protected static ManagedChannel buildChannel(Config config) {
26+
ManagedChannelBuilder<?> channelBuilder = ManagedChannelBuilder.forTarget(getAddress(config));
27+
if (config.isGRPCSecured()) {
28+
channelBuilder = channelBuilder.useTransportSecurity();
29+
} else {
30+
channelBuilder.usePlaintext();
31+
}
32+
return channelBuilder.build();
33+
}
34+
35+
private static String getAddress(Config config) {
36+
if (config.getGRPCHost() != null) {
37+
String host = config.getGRPCHost();
38+
if (host.contains(":")) {
39+
return host;
40+
}
41+
if (config.isGRPCSecured()) {
42+
return String.format("%s:443", host);
43+
}
44+
return String.format("%s:80", host);
45+
}
46+
return "";
47+
}
48+
}

src/main/java/io/weaviate/client/v1/async/WeaviateAsyncClient.java

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
import io.weaviate.client.base.http.async.AsyncHttpClient;
66
import io.weaviate.client.base.util.DbVersionProvider;
77
import io.weaviate.client.base.util.DbVersionSupport;
8+
import io.weaviate.client.v1.async.batch.Batch;
89
import io.weaviate.client.v1.async.classifications.Classifications;
910
import io.weaviate.client.v1.async.cluster.Cluster;
1011
import io.weaviate.client.v1.async.data.Data;
@@ -42,6 +43,10 @@ public Data data() {
4243
return new Data(client, config, dbVersionSupport);
4344
}
4445

46+
public Batch batch() {
47+
return new Batch(client, config);
48+
}
49+
4550
public Cluster cluster() {
4651
return new Cluster(client, config);
4752
}
Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
package io.weaviate.client.v1.async.batch;
2+
3+
import io.weaviate.client.Config;
4+
import io.weaviate.client.v1.async.batch.api.ObjectsBatchDeleter;
5+
import io.weaviate.client.v1.async.batch.api.ObjectsBatcher;
6+
import io.weaviate.client.v1.batch.util.ObjectsPath;
7+
import org.apache.hc.client5.http.impl.async.CloseableHttpAsyncClient;
8+
9+
public class Batch {
10+
private final CloseableHttpAsyncClient client;
11+
private final Config config;
12+
private final ObjectsPath objectsPath;
13+
14+
public Batch(CloseableHttpAsyncClient client, Config config) {
15+
this.client = client;
16+
this.config = config;
17+
this.objectsPath = new ObjectsPath();
18+
}
19+
20+
public ObjectsBatcher objectsBatcher() {
21+
return objectsBatcher(ObjectsBatcher.BatchRetriesConfig.defaultConfig().build());
22+
}
23+
24+
public ObjectsBatcher objectsBatcher(ObjectsBatcher.BatchRetriesConfig batchRetriesConfig) {
25+
return ObjectsBatcher.create(client, config, null, objectsPath, null, null, batchRetriesConfig);
26+
// return ObjectsBatcher.create(client, config, data, objectsPath, tokenProvider, grpcVersionSupport, batchRetriesConfig);
27+
}
28+
29+
public ObjectsBatchDeleter objectsBatchDeleter() {
30+
return new ObjectsBatchDeleter(client, config, objectsPath);
31+
}
32+
}
Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,80 @@
1+
package io.weaviate.client.v1.async.batch.api;
2+
3+
import io.weaviate.client.Config;
4+
import io.weaviate.client.base.AsyncBaseClient;
5+
import io.weaviate.client.base.AsyncClientResult;
6+
import io.weaviate.client.base.Result;
7+
import io.weaviate.client.v1.batch.model.BatchDeleteResponse;
8+
import io.weaviate.client.v1.batch.util.ObjectsPath;
9+
import io.weaviate.client.v1.filters.WhereFilter;
10+
import java.util.concurrent.Future;
11+
import org.apache.hc.client5.http.impl.async.CloseableHttpAsyncClient;
12+
import org.apache.hc.core5.concurrent.FutureCallback;
13+
14+
public class ObjectsBatchDeleter extends AsyncBaseClient<BatchDeleteResponse> implements AsyncClientResult<BatchDeleteResponse> {
15+
private final ObjectsPath objectsPath;
16+
private String className;
17+
private String consistencyLevel;
18+
private String tenant;
19+
private WhereFilter where;
20+
private String output;
21+
private Boolean dryRun;
22+
23+
public ObjectsBatchDeleter(CloseableHttpAsyncClient client, Config config, ObjectsPath objectsPath) {
24+
super(client, config);
25+
this.objectsPath = objectsPath;
26+
}
27+
28+
public ObjectsBatchDeleter withClassName(String className) {
29+
this.className = className;
30+
return this;
31+
}
32+
33+
public ObjectsBatchDeleter withConsistencyLevel(String consistencyLevel) {
34+
this.consistencyLevel = consistencyLevel;
35+
return this;
36+
}
37+
38+
public ObjectsBatchDeleter withTenant(String tenant) {
39+
this.tenant = tenant;
40+
return this;
41+
}
42+
43+
public ObjectsBatchDeleter withWhere(WhereFilter where) {
44+
this.where = where;
45+
return this;
46+
}
47+
48+
public ObjectsBatchDeleter withOutput(String output) {
49+
this.output = output;
50+
return this;
51+
}
52+
53+
public ObjectsBatchDeleter withDryRun(Boolean dryRun) {
54+
this.dryRun = dryRun;
55+
return this;
56+
}
57+
58+
@Override
59+
public Future<Result<BatchDeleteResponse>> run() {
60+
return run(null);
61+
}
62+
63+
@Override
64+
public Future<Result<BatchDeleteResponse>> run(FutureCallback<Result<BatchDeleteResponse>> callback) {
65+
io.weaviate.client.v1.batch.api.ObjectsBatchDeleter.BatchDeleteMatch match = io.weaviate.client.v1.batch.api.ObjectsBatchDeleter.BatchDeleteMatch.builder()
66+
.className(className)
67+
.whereFilter(where)
68+
.build();
69+
io.weaviate.client.v1.batch.api.ObjectsBatchDeleter.BatchDelete batchDelete = io.weaviate.client.v1.batch.api.ObjectsBatchDeleter.BatchDelete.builder()
70+
.dryRun(dryRun)
71+
.output(output)
72+
.match(match)
73+
.build();
74+
String path = objectsPath.buildDelete(ObjectsPath.Params.builder()
75+
.consistencyLevel(consistencyLevel)
76+
.tenant(tenant)
77+
.build());
78+
return sendDeleteRequest(path, batchDelete, BatchDeleteResponse.class, callback);
79+
}
80+
}

0 commit comments

Comments
 (0)