diff --git a/pom.xml b/pom.xml index 8576d7044..d67979e38 100644 --- a/pom.xml +++ b/pom.xml @@ -90,7 +90,6 @@ io.grpc grpc-netty-shaded ${grpc-netty-shaded.version} - runtime io.grpc diff --git a/src/main/java/io/weaviate/client6/v1/api/Config.java b/src/main/java/io/weaviate/client6/v1/api/Config.java index f1da7a136..a9a6c9dda 100644 --- a/src/main/java/io/weaviate/client6/v1/api/Config.java +++ b/src/main/java/io/weaviate/client6/v1/api/Config.java @@ -5,6 +5,8 @@ import java.util.Map; import java.util.function.Function; +import javax.net.ssl.TrustManagerFactory; + import io.weaviate.client6.v1.internal.ObjectBuilder; import io.weaviate.client6.v1.internal.TokenProvider; import io.weaviate.client6.v1.internal.grpc.GrpcChannelOptions; @@ -17,7 +19,8 @@ public record Config( String grpcHost, int grpcPort, Map headers, - TokenProvider tokenProvider) { + TokenProvider tokenProvider, + TrustManagerFactory trustManagerFactory) { public static Config of(Function> fn) { return fn.apply(new Custom()).build(); @@ -31,15 +34,16 @@ private Config(Builder builder) { builder.grpcHost, builder.grpcPort, builder.headers, - builder.tokenProvider); + builder.tokenProvider, + builder.trustManagerFactory); } - public RestTransportOptions restTransportOptions() { - return new RestTransportOptions(scheme, httpHost, httpPort, headers, tokenProvider); + RestTransportOptions restTransportOptions() { + return new RestTransportOptions(scheme, httpHost, httpPort, headers, tokenProvider, trustManagerFactory); } - public GrpcChannelOptions grpcTransportOptions() { - return new GrpcChannelOptions(scheme, grpcHost, grpcPort, headers, tokenProvider); + GrpcChannelOptions grpcTransportOptions() { + return new GrpcChannelOptions(scheme, grpcHost, grpcPort, headers, tokenProvider, trustManagerFactory); } private abstract static class Builder> implements ObjectBuilder { @@ -50,20 +54,33 @@ private abstract static class Builder> implements Obj protected String grpcHost; protected int grpcPort; protected TokenProvider tokenProvider; + protected TrustManagerFactory trustManagerFactory; protected Map headers = new HashMap<>(); + /** + * Set URL scheme. Subclasses may increase the visibility of this method to + * {@code public} if using a different scheme is allowed. + */ @SuppressWarnings("unchecked") protected SELF scheme(String scheme) { this.scheme = scheme; return (SELF) this; } + /** + * Set port for REST requests. Subclasses may increase the visibility of this + * method to {@code public} if using a different port is allowed. + */ @SuppressWarnings("unchecked") protected SELF httpHost(String httpHost) { this.httpHost = trimScheme(httpHost); return (SELF) this; } + /** + * Set port for gRPC requests. Subclasses may increase the visibility of this + * method to {@code public} if using a different port is allowed. + */ @SuppressWarnings("unchecked") protected SELF grpcHost(String grpcHost) { this.grpcHost = trimScheme(grpcHost); @@ -75,18 +92,41 @@ private String trimScheme(String url) { return url.replaceFirst("^https?\\/\\/", ""); } + /** + * Provide a {@link TrustManagerFactory}. Subclasses which support + * secure connection should expose this method. + */ + @SuppressWarnings("unchecked") + protected SELF trustManagerFactory(TrustManagerFactory tmf) { + this.trustManagerFactory = tmf; + return (SELF) this; + } + + /** + * Set a single request header. The client does not support header lists, + * so there is no equivalent {@code addHeader} to append to existing header. + * This will be applied both to REST and gRPC requests. + */ @SuppressWarnings("unchecked") public SELF setHeader(String key, String value) { this.headers.put(key, value); return (SELF) this; } + /** + * Set multiple request headers. + * This will be applied both to REST and gRPC requests. + */ @SuppressWarnings("unchecked") public SELF setHeaders(Map headers) { - this.headers = Map.copyOf(headers); + this.headers.putAll(Map.copyOf(headers)); return (SELF) this; } + /** + * Weaviate will use the URL in this header to call Weaviate Embeddings + * Service if an appropriate vectorizer is configured for collection. + */ private static final String HEADER_X_WEAVIATE_CLUSTER_URL = "X-Weaviate-Cluster-URL"; /** @@ -102,6 +142,8 @@ private static boolean isWeaviateDomain(String host) { @Override public Config build() { + // For clusters hosted on Weaviate Cloud, Weaviate Embedding Service + // will be available under the same domain. if (isWeaviateDomain(httpHost) && tokenProvider != null) { setHeader(HEADER_X_WEAVIATE_CLUSTER_URL, "https://" + httpHost + ":" + httpPort); } @@ -109,6 +151,18 @@ public Config build() { } } + /** + * Configuration for Weaviate instances deployed locally. + * + *

+ * Has sane defaults that match standard Weaviate deployment configuration: + *

    + *
  • {@code scheme: http}
  • + *
  • {@code host: localhost}
  • + *
  • {@code httpPort: 8080}
  • + *
  • {@code grpcPort: 50051}
  • + *
+ */ public static class Local extends Builder { public Local() { scheme("http"); @@ -117,23 +171,37 @@ public Local() { grpcPort(50051); } + /** + * Set a different hostname. + * This changes both {@code httpHost} and {@code grpcHost}. + */ public Local host(String host) { httpHost(host); grpcHost(host); return this; } + /** Override default HTTP port. */ public Local httpPort(int port) { this.httpPort = port; return this; } + /** Override default gRPC port. */ public Local grpcPort(int port) { this.grpcPort = port; return this; } } + /** + * Configuration for instances hosted on Weaviate Cloud. + * {@link WeaviateCloud} will create a secure client + * with {@code schema: https} and {@code http-/grpcPort: 443}. + * + * Custom SSL certificates are suppored via + * {@link #trustManagerFactory}. + */ public static class WeaviateCloud extends Builder { public WeaviateCloud(String httpHost, TokenProvider tokenProvider) { this(URI.create(httpHost), tokenProvider); @@ -144,18 +212,46 @@ public WeaviateCloud(URI clusterUri, TokenProvider tokenProvider) { super.httpHost(clusterUri.getHost() != null ? clusterUri.getHost() // https://[example.com]/about : clusterUri.getPath().split("/")[0]); // [example.com]/about - this.httpPort = 443; super.grpcHost("grpc-" + this.httpHost); + this.httpPort = 443; this.grpcPort = 443; this.tokenProvider = tokenProvider; } + + /** + * Configure a custom TrustStore to validate third-party SSL certificates. + * + *

+ * Usage: + * + *

{@code
+     * // Create a TrustManagerFactory to validate custom certificates.
+     * TrustManagerFactory tmf;
+     * try (var keys = new FileInputStream("/path/to/custom/truststore.p12")) {
+     *   KeyStore trustStore = KeyStore.getInstance(KeyStore.getDefaultType());
+     *   trustStore.load(myKeys, "secret-password".toCharArra());
+     *
+     *   tmf = TrustManagerFactory.getInstance(TrustManagerFactory.getDefaultAlgorithm());
+     *   tmf.init(trustStore);
+     * }
+     *
+     * // Pass it to wcd -> wcd.trustManagerFactory(tmf)
+     * }
+ */ + public WeaviateCloud trustManagerFactory(TrustManagerFactory tmf) { + return super.trustManagerFactory(tmf); + } } + /** Configuration for custom Weaviate deployements. */ public static class Custom extends Builder { /** * Scheme controls which protocol will be used for the database connection. * REST and gRPC ports will be automatically inferred from it: * 443 for HTTPS connection and 80 for HTTP. + * + * These can be overriden with {@link #httpPort(int)} and + * {@link #grpcPort(int)}. */ public Custom scheme(String scheme) { httpPort("https".equals(scheme) ? 443 : 80); @@ -163,29 +259,61 @@ public Custom scheme(String scheme) { return super.scheme(scheme); } + /** Set HTTP hostname. */ public Custom httpHost(String httpHost) { super.httpHost(httpHost); return this; } + /** Set HTTP port. */ public Custom httpPort(int port) { this.httpPort = port; return this; } + /** Set gRPC hostname. */ public Custom grpcHost(String grpcHost) { super.grpcHost(grpcHost); return this; } + /** Set gRPC port. */ public Custom grpcPort(int port) { this.grpcPort = port; return this; } + /** + * Set authorization method. Setting this to {@code null} or omitting + * will not use any authorization mechanism. + */ public Custom authorization(TokenProvider tokenProvider) { this.tokenProvider = tokenProvider; return this; } + + /** + * Configure a custom TrustStore to validate third-party SSL certificates. + * + *

+ * Usage: + * + *

{@code
+     * // Create a TrustManagerFactory to validate custom certificates.
+     * TrustManagerFactory tmf;
+     * try (var keys = new FileInputStream("/path/to/custom/truststore.p12")) {
+     *   KeyStore trustStore = KeyStore.getInstance(KeyStore.getDefaultType());
+     *   trustStore.load(myKeys, "secret-password".toCharArra());
+     *
+     *   tmf = TrustManagerFactory.getInstance(TrustManagerFactory.getDefaultAlgorithm());
+     *   tmf.init(trustStore);
+     * }
+     *
+     * // Pass it to custom -> custom.trustManagerFactory(tmf)
+     * }
+ */ + public Custom trustManagerFactory(TrustManagerFactory tmf) { + return super.trustManagerFactory(tmf); + } } } diff --git a/src/main/java/io/weaviate/client6/v1/api/WeaviateClient.java b/src/main/java/io/weaviate/client6/v1/api/WeaviateClient.java index 4fd1728f7..1e2127e22 100644 --- a/src/main/java/io/weaviate/client6/v1/api/WeaviateClient.java +++ b/src/main/java/io/weaviate/client6/v1/api/WeaviateClient.java @@ -28,32 +28,81 @@ public WeaviateClient(Config config) { this.collections = new WeaviateCollectionsClient(restTransport, grpcTransport); } + /** + * Create {@link WeaviateClientAsync} with identical configurations. + * It is a shorthand for: + * + *
{@code
+   * var config = new Config(...);
+   * var client = new WeaviateClient(config);
+   * var async = new WeaviateClientAsync(config);
+   * }
+ * + * and as such, this does not manage or reuse resources (transport, gRPC + * channel, etc) used by the original client. Keep that in mind and make + * sure to close the original and async clients individually. + * + *

+ * Example: + * + *

{@code
+   * var client = WeaviateClient.local();
+   *
+   * // Need to make the next request non-blocking
+   * try (final var async = client.async()) {
+   *   async.collections.create("Things");
+   * }
+   * // At this point only `async` resource has been auto-closed.
+   *
+   * client.close();
+   * }
+ * + * + * If you only intend to use {@link WeaviateClientAsync}, prefer creating it + * directly via one of its static factories: + *
    + *
  • {@link WeaviateClientAsync#local} + *
  • {@link WeaviateClientAsync#wcd} + *
  • {@link WeaviateClientAsync#custom} + *
+ * + * Otherwise the client wastes time initializing resources it will never use. + */ public WeaviateClientAsync async() { return new WeaviateClientAsync(config); } + /** Connect to a local Weaviate instance. */ public static WeaviateClient local() { return local(ObjectBuilder.identity()); } + /** Connect to a local Weaviate instance. */ public static WeaviateClient local(Function> fn) { return new WeaviateClient(fn.apply(new Config.Local()).build()); } + /** Connect to a Weaviate Cloud instance. */ public static WeaviateClient wcd(String httpHost, String apiKey) { return wcd(httpHost, apiKey, ObjectBuilder.identity()); } + /** Connect to a Weaviate Cloud instance. */ public static WeaviateClient wcd(String httpHost, String apiKey, Function> fn) { var config = new Config.WeaviateCloud(httpHost, Authorization.apiKey(apiKey)); return new WeaviateClient(fn.apply(config).build()); } + /** Connect to a Weaviate instance with custom configuration. */ public static WeaviateClient custom(Function> fn) { return new WeaviateClient(fn.apply(new Config.Custom()).build()); } + /** + * Close {@link #restTransport} and {@link #grpcTransport} + * and release associated resources. + */ @Override public void close() throws IOException { this.restTransport.close(); diff --git a/src/main/java/io/weaviate/client6/v1/api/WeaviateClientAsync.java b/src/main/java/io/weaviate/client6/v1/api/WeaviateClientAsync.java index 12ac88a3f..3f144c5e9 100644 --- a/src/main/java/io/weaviate/client6/v1/api/WeaviateClientAsync.java +++ b/src/main/java/io/weaviate/client6/v1/api/WeaviateClientAsync.java @@ -24,28 +24,37 @@ public WeaviateClientAsync(Config config) { this.collections = new WeaviateCollectionsClientAsync(restTransport, grpcTransport); } + /** Connect to a local Weaviate instance. */ public static WeaviateClientAsync local() { return local(ObjectBuilder.identity()); } + /** Connect to a local Weaviate instance. */ public static WeaviateClientAsync local(Function> fn) { return new WeaviateClientAsync(fn.apply(new Config.Local()).build()); } + /** Connect to a Weaviate Cloud instance. */ public static WeaviateClientAsync wcd(String httpHost, String apiKey) { return wcd(httpHost, apiKey, ObjectBuilder.identity()); } + /** Connect to a Weaviate Cloud instance. */ public static WeaviateClientAsync wcd(String httpHost, String apiKey, Function> fn) { var config = new Config.WeaviateCloud(httpHost, Authorization.apiKey(apiKey)); return new WeaviateClientAsync(fn.apply(config).build()); } + /** Connect to a Weaviate instance with custom configuration. */ public static WeaviateClientAsync custom(Function> fn) { return new WeaviateClientAsync(Config.of(fn)); } + /** + * Close {@link #restTransport} and {@link #grpcTransport} + * and release associated resources. + */ @Override public void close() throws IOException { this.restTransport.close(); diff --git a/src/main/java/io/weaviate/client6/v1/internal/TransportOptions.java b/src/main/java/io/weaviate/client6/v1/internal/TransportOptions.java index 03ee045b7..8700bee47 100644 --- a/src/main/java/io/weaviate/client6/v1/internal/TransportOptions.java +++ b/src/main/java/io/weaviate/client6/v1/internal/TransportOptions.java @@ -1,22 +1,28 @@ package io.weaviate.client6.v1.internal; +import javax.annotation.Nullable; +import javax.net.ssl.TrustManagerFactory; + public abstract class TransportOptions { private final String scheme; private final String host; private final int port; private final TokenProvider tokenProvider; private final H headers; + private final TrustManagerFactory trustManagerFactory; - protected TransportOptions(String scheme, String host, int port, H headers, TokenProvider tokenProvider) { + protected TransportOptions(String scheme, String host, int port, H headers, TokenProvider tokenProvider, + TrustManagerFactory tmf) { this.scheme = scheme; this.host = host; this.port = port; this.tokenProvider = tokenProvider; this.headers = headers; + this.trustManagerFactory = tmf; } public boolean isSecure() { - return scheme == "https"; + return scheme.equals("https"); } public String scheme() { @@ -31,6 +37,7 @@ public int port() { return this.port; } + @Nullable public TokenProvider tokenProvider() { return this.tokenProvider; } @@ -38,4 +45,9 @@ public TokenProvider tokenProvider() { public H headers() { return this.headers; } + + @Nullable + public TrustManagerFactory trustManagerFactory() { + return this.trustManagerFactory; + } } diff --git a/src/main/java/io/weaviate/client6/v1/internal/grpc/DefaultGrpcTransport.java b/src/main/java/io/weaviate/client6/v1/internal/grpc/DefaultGrpcTransport.java index 82aea9598..75893b06b 100644 --- a/src/main/java/io/weaviate/client6/v1/internal/grpc/DefaultGrpcTransport.java +++ b/src/main/java/io/weaviate/client6/v1/internal/grpc/DefaultGrpcTransport.java @@ -3,12 +3,16 @@ import java.io.IOException; import java.util.concurrent.CompletableFuture; +import javax.net.ssl.SSLException; + import com.google.common.util.concurrent.FutureCallback; import com.google.common.util.concurrent.Futures; import com.google.common.util.concurrent.ListenableFuture; import io.grpc.ManagedChannel; -import io.grpc.ManagedChannelBuilder; +import io.grpc.netty.shaded.io.grpc.netty.GrpcSslContexts; +import io.grpc.netty.shaded.io.grpc.netty.NettyChannelBuilder; +import io.grpc.netty.shaded.io.netty.handler.ssl.SslContext; import io.grpc.stub.MetadataUtils; import io.weaviate.client6.v1.internal.grpc.protocol.WeaviateGrpc; import io.weaviate.client6.v1.internal.grpc.protocol.WeaviateGrpc.WeaviateBlockingStub; @@ -80,7 +84,7 @@ public void onFailure(Throwable t) { } private static ManagedChannel buildChannel(GrpcChannelOptions transportOptions) { - var channel = ManagedChannelBuilder.forAddress(transportOptions.host(), transportOptions.port()); + var channel = NettyChannelBuilder.forAddress(transportOptions.host(), transportOptions.port()); if (transportOptions.isSecure()) { channel.useTransportSecurity(); @@ -88,7 +92,21 @@ private static ManagedChannel buildChannel(GrpcChannelOptions transportOptions) channel.usePlaintext(); } + if (transportOptions.trustManagerFactory() != null) { + SslContext sslCtx; + try { + sslCtx = GrpcSslContexts.forClient() + .trustManager(transportOptions.trustManagerFactory()) + .build(); + } catch (SSLException e) { + // todo: rethrow as WeaviateConnectionException + throw new RuntimeException("create grpc transport", e); + } + channel.sslContext(sslCtx); + } + channel.intercept(MetadataUtils.newAttachHeadersInterceptor(transportOptions.headers())); + return channel.build(); } diff --git a/src/main/java/io/weaviate/client6/v1/internal/grpc/GrpcChannelOptions.java b/src/main/java/io/weaviate/client6/v1/internal/grpc/GrpcChannelOptions.java index da67cb0c2..6e32d9738 100644 --- a/src/main/java/io/weaviate/client6/v1/internal/grpc/GrpcChannelOptions.java +++ b/src/main/java/io/weaviate/client6/v1/internal/grpc/GrpcChannelOptions.java @@ -2,14 +2,16 @@ import java.util.Map; +import javax.net.ssl.TrustManagerFactory; + import io.grpc.Metadata; import io.weaviate.client6.v1.internal.TokenProvider; import io.weaviate.client6.v1.internal.TransportOptions; public class GrpcChannelOptions extends TransportOptions { public GrpcChannelOptions(String scheme, String host, int port, Map headers, - TokenProvider tokenProvider) { - super(scheme, host, port, buildMetadata(headers), tokenProvider); + TokenProvider tokenProvider, TrustManagerFactory tmf) { + super(scheme, host, port, buildMetadata(headers), tokenProvider, tmf); } private static final Metadata buildMetadata(Map headers) { diff --git a/src/main/java/io/weaviate/client6/v1/internal/rest/DefaultRestTransport.java b/src/main/java/io/weaviate/client6/v1/internal/rest/DefaultRestTransport.java index 99206aac4..396eb3b1f 100644 --- a/src/main/java/io/weaviate/client6/v1/internal/rest/DefaultRestTransport.java +++ b/src/main/java/io/weaviate/client6/v1/internal/rest/DefaultRestTransport.java @@ -1,14 +1,23 @@ package io.weaviate.client6.v1.internal.rest; import java.io.IOException; +import java.security.KeyManagementException; +import java.security.NoSuchAlgorithmException; import java.util.concurrent.CompletableFuture; +import javax.net.ssl.SSLContext; + import org.apache.hc.client5.http.async.methods.SimpleHttpRequest; import org.apache.hc.client5.http.async.methods.SimpleHttpResponse; import org.apache.hc.client5.http.impl.async.CloseableHttpAsyncClient; import org.apache.hc.client5.http.impl.async.HttpAsyncClients; import org.apache.hc.client5.http.impl.classic.CloseableHttpClient; import org.apache.hc.client5.http.impl.classic.HttpClients; +import org.apache.hc.client5.http.impl.io.PoolingHttpClientConnectionManager; +import org.apache.hc.client5.http.impl.io.PoolingHttpClientConnectionManagerBuilder; +import org.apache.hc.client5.http.impl.nio.PoolingAsyncClientConnectionManager; +import org.apache.hc.client5.http.impl.nio.PoolingAsyncClientConnectionManagerBuilder; +import org.apache.hc.client5.http.ssl.DefaultClientTlsStrategy; import org.apache.hc.core5.concurrent.FutureCallback; import org.apache.hc.core5.http.ClassicHttpRequest; import org.apache.hc.core5.http.ContentType; @@ -36,6 +45,27 @@ public DefaultRestTransport(RestTransportOptions transportOptions) { var httpClientAsync = HttpAsyncClients.custom() .setDefaultHeaders(transportOptions.headers()); + // Apply custom SSL context + if (transportOptions.trustManagerFactory() != null) { + DefaultClientTlsStrategy tlsStrategy; + try { + var sslCtx = SSLContext.getInstance("TLS"); + sslCtx.init(null, transportOptions.trustManagerFactory().getTrustManagers(), null); + tlsStrategy = new DefaultClientTlsStrategy(sslCtx); + } catch (NoSuchAlgorithmException | KeyManagementException e) { + // todo: throw WeaviateConnectionException + throw new RuntimeException("connect to Weaviate", e); + } + + PoolingHttpClientConnectionManager syncManager = PoolingHttpClientConnectionManagerBuilder.create() + .setTlsSocketStrategy(tlsStrategy).build(); + PoolingAsyncClientConnectionManager asyncManager = PoolingAsyncClientConnectionManagerBuilder.create() + .setTlsStrategy(tlsStrategy).build(); + + httpClient.setConnectionManager(syncManager); + httpClientAsync.setConnectionManager(asyncManager); + } + if (transportOptions.tokenProvider() != null) { var interceptor = new AuthorizationInterceptor(transportOptions.tokenProvider()); httpClient.addRequestInterceptorFirst(interceptor); diff --git a/src/main/java/io/weaviate/client6/v1/internal/rest/RestTransportOptions.java b/src/main/java/io/weaviate/client6/v1/internal/rest/RestTransportOptions.java index 795695e72..80f3299af 100644 --- a/src/main/java/io/weaviate/client6/v1/internal/rest/RestTransportOptions.java +++ b/src/main/java/io/weaviate/client6/v1/internal/rest/RestTransportOptions.java @@ -4,6 +4,8 @@ import java.util.HashSet; import java.util.Map; +import javax.net.ssl.TrustManagerFactory; + import org.apache.hc.core5.http.message.BasicHeader; import io.weaviate.client6.v1.internal.TokenProvider; @@ -13,8 +15,8 @@ public final class RestTransportOptions extends TransportOptions headers, - TokenProvider tokenProvider) { - super(scheme, host, port, buildHeaders(headers), tokenProvider); + TokenProvider tokenProvider, TrustManagerFactory trust) { + super(scheme, host, port, buildHeaders(headers), tokenProvider, trust); } private static final Collection buildHeaders(Map headers) { diff --git a/src/it/java/io/weaviate/integration/AuthorizationITest.java b/src/test/java/io/weaviate/client6/v1/api/AuthorizationTest.java similarity index 90% rename from src/it/java/io/weaviate/integration/AuthorizationITest.java rename to src/test/java/io/weaviate/client6/v1/api/AuthorizationTest.java index 7c847e6bd..fcd74fe44 100644 --- a/src/it/java/io/weaviate/integration/AuthorizationITest.java +++ b/src/test/java/io/weaviate/client6/v1/api/AuthorizationTest.java @@ -1,4 +1,4 @@ -package io.weaviate.integration; +package io.weaviate.client6.v1.api; import java.io.IOException; import java.util.Collections; @@ -9,13 +9,11 @@ import org.mockserver.integration.ClientAndServer; import org.mockserver.model.HttpRequest; -import io.weaviate.ConcurrentTest; -import io.weaviate.client6.v1.api.Authorization; import io.weaviate.client6.v1.internal.rest.DefaultRestTransport; import io.weaviate.client6.v1.internal.rest.Endpoint; import io.weaviate.client6.v1.internal.rest.RestTransportOptions; -public class AuthorizationITest extends ConcurrentTest { +public class AuthorizationTest { private ClientAndServer mockServer; @Before @@ -35,7 +33,7 @@ public void startMockServer() throws IOException { public void testAuthorization_apiKey() throws IOException { var transportOptions = new RestTransportOptions( "http", "localhost", mockServer.getLocalPort(), - Collections.emptyMap(), Authorization.apiKey("my-api-key")); + Collections.emptyMap(), Authorization.apiKey("my-api-key"), null); try (final var restClient = new DefaultRestTransport(transportOptions)) { restClient.performRequest(null, Endpoint.of( diff --git a/src/test/java/io/weaviate/client6/v1/internal/rest/DefaultRestTransportTest.java b/src/test/java/io/weaviate/client6/v1/internal/rest/DefaultRestTransportTest.java new file mode 100644 index 000000000..ab58ac5be --- /dev/null +++ b/src/test/java/io/weaviate/client6/v1/internal/rest/DefaultRestTransportTest.java @@ -0,0 +1,90 @@ +package io.weaviate.client6.v1.internal.rest; + +import java.io.IOException; +import java.util.Collections; +import java.util.concurrent.ExecutionException; + +import javax.net.ssl.TrustManagerFactory; + +import org.assertj.core.api.Assertions; +import org.junit.After; +import org.junit.Before; +import org.junit.Test; +import org.mockserver.integration.ClientAndServer; +import org.mockserver.model.HttpRequest; + +import io.weaviate.testutil.truststore.SingleTrustManagerFactory; +import io.weaviate.testutil.truststore.SpyTrustManager; + +public class DefaultRestTransportTest { + private ClientAndServer mockServer; + private DefaultRestTransport transport; + private TrustManagerFactory tmf; + + @Before + public void setUp() throws IOException { + // MockServer does not verify exclusive ownership of the port + // and using any well-known port like 8080 will produce flaky + // test results with fairly confusing errors, like: + // + // path /mockserver/verifySequence was not found + // + // if another webserver is listening to that port. + // We use 0 to let the underlying system find an available port. + mockServer = ClientAndServer.startClientAndServer(0); + mockServer.withSecure(true); + + tmf = SingleTrustManagerFactory.create(new SpyTrustManager()); + transport = new DefaultRestTransport(new RestTransportOptions( + "https", "localhost", mockServer.getLocalPort(), + Collections.emptyMap(), null, tmf)); + } + + @Test + public void testCustomTrustStore_sync() throws IOException { + transport.performRequest(null, Endpoint.of( + request -> "GET", + request -> "/", + (gson, request) -> null, + request -> null, + code -> code != 200, + (gson, response) -> null)); + + mockServer.verify( + HttpRequest.request() + .withMethod("GET") + .withPath("/v1/")); + + var spy = SpyTrustManager.getSpy(tmf); + Assertions.assertThat(spy).get() + .as("HttpClient uses custom TrustManager") + .returns(true, SpyTrustManager::wasUsed); + } + + @Test + public void testCustomTrustStore_async() throws IOException, ExecutionException, InterruptedException { + transport.performRequestAsync(null, Endpoint.of( + request -> "GET", + request -> "/", + (gson, request) -> null, + request -> null, + code -> code != 200, + (gson, response) -> null)).get(); + + mockServer.verify( + HttpRequest.request() + .withMethod("GET") + .withPath("/v1/")); + + var spy = SpyTrustManager.getSpy(tmf); + Assertions.assertThat(spy).get() + .as("HttpClient uses custom TrustManager") + .returns(true, SpyTrustManager::wasUsed); + } + + @After + public void tearDown() throws IOException { + mockServer.stop(); + transport.close(); + } +} diff --git a/src/test/java/io/weaviate/testutil/truststore/SingleTrustManagerFactory.java b/src/test/java/io/weaviate/testutil/truststore/SingleTrustManagerFactory.java new file mode 100644 index 000000000..19f6456f0 --- /dev/null +++ b/src/test/java/io/weaviate/testutil/truststore/SingleTrustManagerFactory.java @@ -0,0 +1,50 @@ +package io.weaviate.testutil.truststore; + +import java.security.InvalidAlgorithmParameterException; +import java.security.KeyStore; +import java.security.KeyStoreException; + +import javax.net.ssl.ManagerFactoryParameters; +import javax.net.ssl.TrustManager; +import javax.net.ssl.TrustManagerFactory; +import javax.net.ssl.TrustManagerFactorySpi; + +/** TrustManagerFactory which always returns the same {@code TrustManager}. */ +public final class SingleTrustManagerFactory extends TrustManagerFactory { + + /** Create a factory that will return {@code TrustManager tm}. */ + public static TrustManagerFactory create(TrustManager tm) { + return new SingleTrustManagerFactory(tm); + } + + protected SingleTrustManagerFactory(TrustManager tm) { + super(new SingleTrustManagerFactorySpi(tm), null, TrustManagerFactory.getDefaultAlgorithm()); + } + + /** + * Naive {@code TrustManagerFactorySpi} implementation + * which always returns the same {@code TrustManager}. + */ + private static final class SingleTrustManagerFactorySpi extends TrustManagerFactorySpi { + private final TrustManager[] trustManagers; + + private SingleTrustManagerFactorySpi(TrustManager tm) { + this.trustManagers = new TrustManager[] { tm }; + } + + @Override + protected void engineInit(KeyStore ks) throws KeyStoreException { + return; + } + + @Override + protected void engineInit(ManagerFactoryParameters spec) throws InvalidAlgorithmParameterException { + return; + } + + @Override + protected TrustManager[] engineGetTrustManagers() { + return trustManagers; + } + } +} diff --git a/src/test/java/io/weaviate/testutil/truststore/SpyTrustManager.java b/src/test/java/io/weaviate/testutil/truststore/SpyTrustManager.java new file mode 100644 index 000000000..a4a881d0b --- /dev/null +++ b/src/test/java/io/weaviate/testutil/truststore/SpyTrustManager.java @@ -0,0 +1,45 @@ +package io.weaviate.testutil.truststore; + +import java.security.cert.CertificateException; +import java.security.cert.X509Certificate; +import java.util.Optional; + +import javax.net.ssl.TrustManagerFactory; +import javax.net.ssl.X509TrustManager; + +/** + * Test fixture that records when this TrustManager has been used. + * Combine with {@link SingleTrustManagerFactory#create} to mock + * a custom TrustStore. + */ +public class SpyTrustManager implements X509TrustManager { + private boolean used = false; + + public boolean wasUsed() { + return this.used; + } + + @Override + public void checkClientTrusted(X509Certificate[] chain, String authType) throws CertificateException { + this.used = true; + } + + @Override + public void checkServerTrusted(X509Certificate[] chain, String authType) throws CertificateException { + this.used = true; + } + + @Override + public X509Certificate[] getAcceptedIssuers() { + this.used = true; + return new X509Certificate[0]; + } + + public static Optional getSpy(TrustManagerFactory tmf) { + var managers = tmf.getTrustManagers(); + if (managers.length == 0) { + return Optional.empty(); + } + return Optional.of((SpyTrustManager) managers[0]); + } +}