From 3c34f84dd52e708bfd5bbfb00f5bc485ba6c9f45 Mon Sep 17 00:00:00 2001 From: David Li Date: Thu, 9 Jan 2025 01:16:05 -0500 Subject: [PATCH] GH-494: [Flight] Improve handling of unreachable locations in JDBC - Expose gRPC for the client builder - Cache failed locations and try them last - Allow configuring the connect timeout Fixes #494. --- .../org/apache/arrow/flight/FlightClient.java | 153 ++---------- .../apache/arrow/flight/FlightGrpcUtils.java | 14 ++ .../arrow/flight/grpc/NettyClientBuilder.java | 232 ++++++++++++++++++ flight/flight-sql-jdbc-core/pom.xml | 6 + .../driver/jdbc/ArrowFlightConnection.java | 3 + .../client/ArrowFlightSqlClientHandler.java | 130 ++++++---- .../jdbc/client/utils/FlightClientCache.java | 55 +++++ .../client/utils/FlightLocationQueue.java | 60 +++++ .../ArrowFlightConnectionConfigImpl.java | 21 +- .../arrow/driver/jdbc/ResultSetTest.java | 132 ++++++++++ ...rrowFlightSqlClientHandlerBuilderTest.java | 2 + .../client/utils/FlightClientCacheTest.java | 51 ++++ .../client/utils/FlightLocationQueueTest.java | 63 +++++ .../ArrowFlightConnectionConfigImplTest.java | 45 +++- .../jdbc/utils/FallbackFlightSqlProducer.java | 10 + 15 files changed, 787 insertions(+), 190 deletions(-) create mode 100644 flight/flight-core/src/main/java/org/apache/arrow/flight/grpc/NettyClientBuilder.java create mode 100644 flight/flight-sql-jdbc-core/src/main/java/org/apache/arrow/driver/jdbc/client/utils/FlightClientCache.java create mode 100644 flight/flight-sql-jdbc-core/src/main/java/org/apache/arrow/driver/jdbc/client/utils/FlightLocationQueue.java create mode 100644 flight/flight-sql-jdbc-core/src/test/java/org/apache/arrow/driver/jdbc/client/utils/FlightClientCacheTest.java create mode 100644 flight/flight-sql-jdbc-core/src/test/java/org/apache/arrow/driver/jdbc/client/utils/FlightLocationQueueTest.java diff --git a/flight/flight-core/src/main/java/org/apache/arrow/flight/FlightClient.java b/flight/flight-core/src/main/java/org/apache/arrow/flight/FlightClient.java index a15c3049a..9e6aba407 100644 --- a/flight/flight-core/src/main/java/org/apache/arrow/flight/FlightClient.java +++ b/flight/flight-core/src/main/java/org/apache/arrow/flight/FlightClient.java @@ -23,29 +23,21 @@ import io.grpc.ManagedChannel; import io.grpc.MethodDescriptor; import io.grpc.StatusRuntimeException; -import io.grpc.netty.GrpcSslContexts; import io.grpc.netty.NettyChannelBuilder; import io.grpc.stub.ClientCallStreamObserver; import io.grpc.stub.ClientCalls; import io.grpc.stub.ClientResponseObserver; import io.grpc.stub.StreamObserver; -import io.netty.channel.EventLoopGroup; -import io.netty.channel.ServerChannel; -import io.netty.handler.ssl.SslContextBuilder; -import io.netty.handler.ssl.util.InsecureTrustManagerFactory; import java.io.IOException; import java.io.InputStream; -import java.lang.reflect.InvocationTargetException; import java.net.URISyntaxException; import java.nio.ByteBuffer; -import java.util.ArrayList; import java.util.Iterator; import java.util.List; import java.util.Optional; import java.util.concurrent.ExecutionException; import java.util.concurrent.TimeUnit; import java.util.function.BooleanSupplier; -import javax.net.ssl.SSLException; import org.apache.arrow.flight.FlightProducer.StreamListener; import org.apache.arrow.flight.auth.BasicClientAuthHandler; import org.apache.arrow.flight.auth.ClientAuthHandler; @@ -57,6 +49,7 @@ import org.apache.arrow.flight.auth2.ClientIncomingAuthHeaderMiddleware; import org.apache.arrow.flight.grpc.ClientInterceptorAdapter; import org.apache.arrow.flight.grpc.CredentialCallOption; +import org.apache.arrow.flight.grpc.NettyClientBuilder; import org.apache.arrow.flight.grpc.StatusUtils; import org.apache.arrow.flight.impl.Flight; import org.apache.arrow.flight.impl.Flight.Empty; @@ -72,11 +65,6 @@ /** Client for Flight services. */ public class FlightClient implements AutoCloseable { private static final int PENDING_REQUESTS = 5; - /** - * The maximum number of trace events to keep on the gRPC Channel. This value disables channel - * tracing. - */ - private static final int MAX_CHANNEL_TRACE_EVENTS = 0; private final BufferAllocator allocator; private final ManagedChannel channel; @@ -771,176 +759,71 @@ public static Builder builder(BufferAllocator allocator, Location location) { /** A builder for Flight clients. */ public static final class Builder { - private BufferAllocator allocator; - private Location location; - private boolean forceTls = false; - private int maxInboundMessageSize = FlightServer.MAX_GRPC_MESSAGE_SIZE; - private InputStream trustedCertificates = null; - private InputStream clientCertificate = null; - private InputStream clientKey = null; - private String overrideHostname = null; - private List middleware = new ArrayList<>(); - private boolean verifyServer = true; - - private Builder() {} + private final NettyClientBuilder builder; + + private Builder() { + this.builder = new NettyClientBuilder(); + } private Builder(BufferAllocator allocator, Location location) { - this.allocator = Preconditions.checkNotNull(allocator); - this.location = Preconditions.checkNotNull(location); + this.builder = new NettyClientBuilder(allocator, location); } /** Force the client to connect over TLS. */ public Builder useTls() { - this.forceTls = true; + builder.useTls(); return this; } /** Override the hostname checked for TLS. Use with caution in production. */ public Builder overrideHostname(final String hostname) { - this.overrideHostname = hostname; + builder.overrideHostname(hostname); return this; } /** Set the maximum inbound message size. */ public Builder maxInboundMessageSize(int maxSize) { - Preconditions.checkArgument(maxSize > 0); - this.maxInboundMessageSize = maxSize; + builder.maxInboundMessageSize(maxSize); return this; } /** Set the trusted TLS certificates. */ public Builder trustedCertificates(final InputStream stream) { - this.trustedCertificates = Preconditions.checkNotNull(stream); + builder.trustedCertificates(stream); return this; } /** Set the trusted TLS certificates. */ public Builder clientCertificate( final InputStream clientCertificate, final InputStream clientKey) { - Preconditions.checkNotNull(clientKey); - this.clientCertificate = Preconditions.checkNotNull(clientCertificate); - this.clientKey = Preconditions.checkNotNull(clientKey); + builder.clientCertificate(clientCertificate, clientKey); return this; } public Builder allocator(BufferAllocator allocator) { - this.allocator = Preconditions.checkNotNull(allocator); + builder.allocator(allocator); return this; } public Builder location(Location location) { - this.location = Preconditions.checkNotNull(location); + builder.location(location); return this; } public Builder intercept(FlightClientMiddleware.Factory factory) { - middleware.add(factory); + builder.intercept(factory); return this; } public Builder verifyServer(boolean verifyServer) { - this.verifyServer = verifyServer; + builder.verifyServer(verifyServer); return this; } /** Create the client from this builder. */ public FlightClient build() { - final NettyChannelBuilder builder; - - switch (location.getUri().getScheme()) { - case LocationSchemes.GRPC: - case LocationSchemes.GRPC_INSECURE: - case LocationSchemes.GRPC_TLS: - { - builder = NettyChannelBuilder.forAddress(location.toSocketAddress()); - break; - } - case LocationSchemes.GRPC_DOMAIN_SOCKET: - { - // The implementation is platform-specific, so we have to find the classes at runtime - builder = NettyChannelBuilder.forAddress(location.toSocketAddress()); - try { - try { - // Linux - builder.channelType( - Class.forName("io.netty.channel.epoll.EpollDomainSocketChannel") - .asSubclass(ServerChannel.class)); - final EventLoopGroup elg = - Class.forName("io.netty.channel.epoll.EpollEventLoopGroup") - .asSubclass(EventLoopGroup.class) - .getDeclaredConstructor() - .newInstance(); - builder.eventLoopGroup(elg); - } catch (ClassNotFoundException e) { - // BSD - builder.channelType( - Class.forName("io.netty.channel.kqueue.KQueueDomainSocketChannel") - .asSubclass(ServerChannel.class)); - final EventLoopGroup elg = - Class.forName("io.netty.channel.kqueue.KQueueEventLoopGroup") - .asSubclass(EventLoopGroup.class) - .getDeclaredConstructor() - .newInstance(); - builder.eventLoopGroup(elg); - } - } catch (ClassNotFoundException - | InstantiationException - | IllegalAccessException - | NoSuchMethodException - | InvocationTargetException e) { - throw new UnsupportedOperationException( - "Could not find suitable Netty native transport implementation for domain socket address."); - } - break; - } - default: - throw new IllegalArgumentException( - "Scheme is not supported: " + location.getUri().getScheme()); - } - - if (this.forceTls || LocationSchemes.GRPC_TLS.equals(location.getUri().getScheme())) { - builder.useTransportSecurity(); - - final boolean hasTrustedCerts = this.trustedCertificates != null; - final boolean hasKeyCertPair = this.clientCertificate != null && this.clientKey != null; - if (!this.verifyServer && (hasTrustedCerts || hasKeyCertPair)) { - throw new IllegalArgumentException( - "FlightClient has been configured to disable server verification, " - + "but certificate options have been specified."); - } - - final SslContextBuilder sslContextBuilder = GrpcSslContexts.forClient(); - - if (!this.verifyServer) { - sslContextBuilder.trustManager(InsecureTrustManagerFactory.INSTANCE); - } else if (this.trustedCertificates != null - || this.clientCertificate != null - || this.clientKey != null) { - if (this.trustedCertificates != null) { - sslContextBuilder.trustManager(this.trustedCertificates); - } - if (this.clientCertificate != null && this.clientKey != null) { - sslContextBuilder.keyManager(this.clientCertificate, this.clientKey); - } - } - try { - builder.sslContext(sslContextBuilder.build()); - } catch (SSLException e) { - throw new RuntimeException(e); - } - - if (this.overrideHostname != null) { - builder.overrideAuthority(this.overrideHostname); - } - } else { - builder.usePlaintext(); - } - - builder - .maxTraceEvents(MAX_CHANNEL_TRACE_EVENTS) - .maxInboundMessageSize(maxInboundMessageSize) - .maxInboundMetadataSize(maxInboundMessageSize); - return new FlightClient(allocator, builder.build(), middleware); + final NettyChannelBuilder channelBuilder = builder.build(); + return new FlightClient(builder.allocator(), channelBuilder.build(), builder.middleware()); } } diff --git a/flight/flight-core/src/main/java/org/apache/arrow/flight/FlightGrpcUtils.java b/flight/flight-core/src/main/java/org/apache/arrow/flight/FlightGrpcUtils.java index 13e4f2f21..df5e29741 100644 --- a/flight/flight-core/src/main/java/org/apache/arrow/flight/FlightGrpcUtils.java +++ b/flight/flight-core/src/main/java/org/apache/arrow/flight/FlightGrpcUtils.java @@ -23,6 +23,7 @@ import io.grpc.ManagedChannel; import io.grpc.MethodDescriptor; import java.util.Collections; +import java.util.List; import java.util.concurrent.ExecutorService; import java.util.concurrent.TimeUnit; import org.apache.arrow.flight.auth.ServerAuthHandler; @@ -151,6 +152,19 @@ public static FlightClient createFlightClient( return new FlightClient(incomingAllocator, channel, Collections.emptyList()); } + /** + * Creates a Flight client. + * + * @param incomingAllocator Memory allocator + * @param channel provides a connection to a gRPC server. + */ + public static FlightClient createFlightClient( + BufferAllocator incomingAllocator, + ManagedChannel channel, + List middleware) { + return new FlightClient(incomingAllocator, channel, middleware); + } + /** * Creates a Flight client. * diff --git a/flight/flight-core/src/main/java/org/apache/arrow/flight/grpc/NettyClientBuilder.java b/flight/flight-core/src/main/java/org/apache/arrow/flight/grpc/NettyClientBuilder.java new file mode 100644 index 000000000..42cdaac01 --- /dev/null +++ b/flight/flight-core/src/main/java/org/apache/arrow/flight/grpc/NettyClientBuilder.java @@ -0,0 +1,232 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.arrow.flight.grpc; + +import io.grpc.ManagedChannel; +import io.grpc.netty.GrpcSslContexts; +import io.grpc.netty.NettyChannelBuilder; +import io.netty.channel.EventLoopGroup; +import io.netty.channel.ServerChannel; +import io.netty.handler.ssl.SslContextBuilder; +import io.netty.handler.ssl.util.InsecureTrustManagerFactory; +import java.io.InputStream; +import java.lang.reflect.InvocationTargetException; +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; +import javax.net.ssl.SSLException; +import org.apache.arrow.flight.FlightClientMiddleware; +import org.apache.arrow.flight.Location; +import org.apache.arrow.flight.LocationSchemes; +import org.apache.arrow.memory.BufferAllocator; +import org.apache.arrow.util.Preconditions; + +/** + * A wrapper around gRPC's Netty builder. + * + *

It is recommended to use the Netty channel builder directly with {@link + * org.apache.arrow.flight.FlightGrpcUtils#createFlightClient(BufferAllocator, ManagedChannel)}. + * However, this class provides an adapter that implements the existing Flight-specific builder + * interface but allows usage of the Netty builder as well. + */ +public class NettyClientBuilder { + /** + * The maximum number of trace events to keep on the gRPC Channel. This value disables channel + * tracing. + */ + private static final int MAX_CHANNEL_TRACE_EVENTS = 0; + + protected BufferAllocator allocator; + protected Location location; + protected boolean forceTls = false; + protected int maxInboundMessageSize = Integer.MAX_VALUE; + protected InputStream trustedCertificates = null; + protected InputStream clientCertificate = null; + protected InputStream clientKey = null; + protected String overrideHostname = null; + protected List middleware = new ArrayList<>(); + protected boolean verifyServer = true; + + public NettyClientBuilder() {} + + public NettyClientBuilder(BufferAllocator allocator, Location location) { + this.allocator = Preconditions.checkNotNull(allocator); + this.location = Preconditions.checkNotNull(location); + } + + /** Force the client to connect over TLS. */ + public NettyClientBuilder useTls() { + this.forceTls = true; + return this; + } + + /** Override the hostname checked for TLS. Use with caution in production. */ + public NettyClientBuilder overrideHostname(final String hostname) { + this.overrideHostname = hostname; + return this; + } + + /** Set the maximum inbound message size. */ + public NettyClientBuilder maxInboundMessageSize(int maxSize) { + Preconditions.checkArgument(maxSize > 0); + this.maxInboundMessageSize = maxSize; + return this; + } + + /** Set the trusted TLS certificates. */ + public NettyClientBuilder trustedCertificates(final InputStream stream) { + this.trustedCertificates = Preconditions.checkNotNull(stream); + return this; + } + + /** Set the trusted TLS certificates. */ + public NettyClientBuilder clientCertificate( + final InputStream clientCertificate, final InputStream clientKey) { + Preconditions.checkNotNull(clientKey); + this.clientCertificate = Preconditions.checkNotNull(clientCertificate); + this.clientKey = Preconditions.checkNotNull(clientKey); + return this; + } + + public BufferAllocator allocator() { + return allocator; + } + + public NettyClientBuilder allocator(BufferAllocator allocator) { + this.allocator = Preconditions.checkNotNull(allocator); + return this; + } + + public NettyClientBuilder location(Location location) { + this.location = Preconditions.checkNotNull(location); + return this; + } + + public List middleware() { + return Collections.unmodifiableList(middleware); + } + + public NettyClientBuilder intercept(FlightClientMiddleware.Factory factory) { + middleware.add(factory); + return this; + } + + public NettyClientBuilder verifyServer(boolean verifyServer) { + this.verifyServer = verifyServer; + return this; + } + + /** Create the client from this builder. */ + public NettyChannelBuilder build() { + final NettyChannelBuilder builder; + + switch (location.getUri().getScheme()) { + case LocationSchemes.GRPC: + case LocationSchemes.GRPC_INSECURE: + case LocationSchemes.GRPC_TLS: + { + builder = NettyChannelBuilder.forAddress(location.toSocketAddress()); + break; + } + case LocationSchemes.GRPC_DOMAIN_SOCKET: + { + // The implementation is platform-specific, so we have to find the classes at runtime + builder = NettyChannelBuilder.forAddress(location.toSocketAddress()); + try { + try { + // Linux + builder.channelType( + Class.forName("io.netty.channel.epoll.EpollDomainSocketChannel") + .asSubclass(ServerChannel.class)); + final EventLoopGroup elg = + Class.forName("io.netty.channel.epoll.EpollEventLoopGroup") + .asSubclass(EventLoopGroup.class) + .getDeclaredConstructor() + .newInstance(); + builder.eventLoopGroup(elg); + } catch (ClassNotFoundException e) { + // BSD + builder.channelType( + Class.forName("io.netty.channel.kqueue.KQueueDomainSocketChannel") + .asSubclass(ServerChannel.class)); + final EventLoopGroup elg = + Class.forName("io.netty.channel.kqueue.KQueueEventLoopGroup") + .asSubclass(EventLoopGroup.class) + .getDeclaredConstructor() + .newInstance(); + builder.eventLoopGroup(elg); + } + } catch (ClassNotFoundException + | InstantiationException + | IllegalAccessException + | NoSuchMethodException + | InvocationTargetException e) { + throw new UnsupportedOperationException( + "Could not find suitable Netty native transport implementation for domain socket address."); + } + break; + } + default: + throw new IllegalArgumentException( + "Scheme is not supported: " + location.getUri().getScheme()); + } + + if (this.forceTls || LocationSchemes.GRPC_TLS.equals(location.getUri().getScheme())) { + builder.useTransportSecurity(); + + final boolean hasTrustedCerts = this.trustedCertificates != null; + final boolean hasKeyCertPair = this.clientCertificate != null && this.clientKey != null; + if (!this.verifyServer && (hasTrustedCerts || hasKeyCertPair)) { + throw new IllegalArgumentException( + "FlightClient has been configured to disable server verification, " + + "but certificate options have been specified."); + } + + final SslContextBuilder sslContextBuilder = GrpcSslContexts.forClient(); + + if (!this.verifyServer) { + sslContextBuilder.trustManager(InsecureTrustManagerFactory.INSTANCE); + } else if (this.trustedCertificates != null + || this.clientCertificate != null + || this.clientKey != null) { + if (this.trustedCertificates != null) { + sslContextBuilder.trustManager(this.trustedCertificates); + } + if (this.clientCertificate != null && this.clientKey != null) { + sslContextBuilder.keyManager(this.clientCertificate, this.clientKey); + } + } + try { + builder.sslContext(sslContextBuilder.build()); + } catch (SSLException e) { + throw new RuntimeException(e); + } + + if (this.overrideHostname != null) { + builder.overrideAuthority(this.overrideHostname); + } + } else { + builder.usePlaintext(); + } + + builder + .maxTraceEvents(MAX_CHANNEL_TRACE_EVENTS) + .maxInboundMessageSize(maxInboundMessageSize) + .maxInboundMetadataSize(maxInboundMessageSize); + return builder; + } +} diff --git a/flight/flight-sql-jdbc-core/pom.xml b/flight/flight-sql-jdbc-core/pom.xml index 204e17efe..7c97fc026 100644 --- a/flight/flight-sql-jdbc-core/pom.xml +++ b/flight/flight-sql-jdbc-core/pom.xml @@ -128,6 +128,12 @@ under the License. org.checkerframework checker-qual + + + com.github.ben-manes.caffeine + caffeine + 3.1.8 + diff --git a/flight/flight-sql-jdbc-core/src/main/java/org/apache/arrow/driver/jdbc/ArrowFlightConnection.java b/flight/flight-sql-jdbc-core/src/main/java/org/apache/arrow/driver/jdbc/ArrowFlightConnection.java index c1b1c8f8e..747287ed1 100644 --- a/flight/flight-sql-jdbc-core/src/main/java/org/apache/arrow/driver/jdbc/ArrowFlightConnection.java +++ b/flight/flight-sql-jdbc-core/src/main/java/org/apache/arrow/driver/jdbc/ArrowFlightConnection.java @@ -24,6 +24,7 @@ import java.util.concurrent.ExecutorService; import java.util.concurrent.Executors; import org.apache.arrow.driver.jdbc.client.ArrowFlightSqlClientHandler; +import org.apache.arrow.driver.jdbc.client.utils.FlightClientCache; import org.apache.arrow.driver.jdbc.utils.ArrowFlightConnectionConfigImpl; import org.apache.arrow.flight.FlightClient; import org.apache.arrow.memory.BufferAllocator; @@ -113,6 +114,8 @@ private static ArrowFlightSqlClientHandler createNewClientHandler( .withRetainCookies(config.retainCookies()) .withRetainAuth(config.retainAuth()) .withCatalog(config.getCatalog()) + .withClientCache(config.useClientCache() ? new FlightClientCache() : null) + .withConnectTimeout(config.getConnectTimeout()) .build(); } catch (final SQLException e) { try { diff --git a/flight/flight-sql-jdbc-core/src/main/java/org/apache/arrow/driver/jdbc/client/ArrowFlightSqlClientHandler.java b/flight/flight-sql-jdbc-core/src/main/java/org/apache/arrow/driver/jdbc/client/ArrowFlightSqlClientHandler.java index 0e9c79a09..06b8f6701 100644 --- a/flight/flight-sql-jdbc-core/src/main/java/org/apache/arrow/driver/jdbc/client/ArrowFlightSqlClientHandler.java +++ b/flight/flight-sql-jdbc-core/src/main/java/org/apache/arrow/driver/jdbc/client/ArrowFlightSqlClientHandler.java @@ -17,39 +17,24 @@ package org.apache.arrow.driver.jdbc.client; import com.google.common.collect.ImmutableMap; +import io.grpc.netty.NettyChannelBuilder; +import io.netty.channel.ChannelOption; import java.io.IOException; import java.net.URI; import java.security.GeneralSecurityException; import java.sql.SQLException; -import java.util.ArrayList; -import java.util.Arrays; -import java.util.Collection; -import java.util.HashSet; -import java.util.List; -import java.util.Map; -import java.util.Optional; -import java.util.Set; +import java.time.Duration; +import java.util.*; import org.apache.arrow.driver.jdbc.client.utils.ClientAuthenticationUtils; -import org.apache.arrow.flight.CallOption; -import org.apache.arrow.flight.CallStatus; -import org.apache.arrow.flight.CloseSessionRequest; -import org.apache.arrow.flight.FlightClient; -import org.apache.arrow.flight.FlightClientMiddleware; -import org.apache.arrow.flight.FlightEndpoint; -import org.apache.arrow.flight.FlightInfo; -import org.apache.arrow.flight.FlightRuntimeException; -import org.apache.arrow.flight.FlightStatusCode; -import org.apache.arrow.flight.Location; -import org.apache.arrow.flight.LocationSchemes; -import org.apache.arrow.flight.SessionOptionValue; -import org.apache.arrow.flight.SessionOptionValueFactory; -import org.apache.arrow.flight.SetSessionOptionsRequest; -import org.apache.arrow.flight.SetSessionOptionsResult; +import org.apache.arrow.driver.jdbc.client.utils.FlightClientCache; +import org.apache.arrow.driver.jdbc.client.utils.FlightLocationQueue; +import org.apache.arrow.flight.*; import org.apache.arrow.flight.auth2.BearerCredentialWriter; import org.apache.arrow.flight.auth2.ClientBearerHeaderHandler; import org.apache.arrow.flight.auth2.ClientIncomingAuthHeaderMiddleware; import org.apache.arrow.flight.client.ClientCookieMiddleware; import org.apache.arrow.flight.grpc.CredentialCallOption; +import org.apache.arrow.flight.grpc.NettyClientBuilder; import org.apache.arrow.flight.sql.FlightSqlClient; import org.apache.arrow.flight.sql.impl.FlightSql.SqlInfo; import org.apache.arrow.flight.sql.util.TableRef; @@ -70,21 +55,27 @@ public final class ArrowFlightSqlClientHandler implements AutoCloseable { // JDBC connection string query parameter private static final String CATALOG = "catalog"; + private final String cacheKey; private final FlightSqlClient sqlClient; private final Set options = new HashSet<>(); private final Builder builder; private final Optional catalog; + private final @Nullable FlightClientCache flightClientCache; ArrowFlightSqlClientHandler( + final String cacheKey, final FlightSqlClient sqlClient, final Builder builder, final Collection credentialOptions, - final Optional catalog) { + final Optional catalog, + final @Nullable FlightClientCache flightClientCache) { this.options.addAll(builder.options); this.options.addAll(credentialOptions); + this.cacheKey = Preconditions.checkNotNull(cacheKey); this.sqlClient = Preconditions.checkNotNull(sqlClient); this.builder = builder; this.catalog = catalog; + this.flightClientCache = flightClientCache; } /** @@ -96,12 +87,15 @@ public final class ArrowFlightSqlClientHandler implements AutoCloseable { * @return a new {@link ArrowFlightSqlClientHandler}. */ static ArrowFlightSqlClientHandler createNewHandler( + final String cacheKey, final FlightClient client, final Builder builder, final Collection options, - final Optional catalog) { + final Optional catalog, + final @Nullable FlightClientCache flightClientCache) { final ArrowFlightSqlClientHandler handler = - new ArrowFlightSqlClientHandler(new FlightSqlClient(client), builder, options, catalog); + new ArrowFlightSqlClientHandler( + cacheKey, new FlightSqlClient(client), builder, options, catalog, flightClientCache); handler.setSetCatalogInSessionIfPresent(); return handler; } @@ -138,15 +132,19 @@ public List getStreams(final FlightInfo flightInfo) // Clone the builder and then set the new endpoint on it. // GH-38574: Currently a new FlightClient will be made for each partition that returns a - // non-empty Location - // then disposed of. It may be better to cache clients because a server may report the - // same Locations. - // It would also be good to identify when the reported location is the same as the - // original connection's - // Location and skip creating a FlightClient in that scenario. + // non-empty Location then disposed of. It may be better to cache clients because a server + // may report the same Locations. It would also be good to identify when the reported + // location + // is the same as the original connection's Location and skip creating a FlightClient in + // that scenario. + // Also copy the cache to the client so we can share a cache. Cache needs to cache + // negative attempts too. List exceptions = new ArrayList<>(); CloseableEndpointStreamPair stream = null; - for (Location location : endpoint.getLocations()) { + FlightLocationQueue locations = + new FlightLocationQueue(flightClientCache, endpoint.getLocations()); + while (locations.hasNext()) { + Location location = locations.next(); final URI endpointUri = location.getUri(); if (endpointUri.getScheme().equals(LocationSchemes.REUSE_CONNECTION)) { stream = @@ -158,7 +156,9 @@ public List getStreams(final FlightInfo flightInfo) new Builder(ArrowFlightSqlClientHandler.this.builder) .withHost(endpointUri.getHost()) .withPort(endpointUri.getPort()) - .withEncryption(endpointUri.getScheme().equals(LocationSchemes.GRPC_TLS)); + .withEncryption(endpointUri.getScheme().equals(LocationSchemes.GRPC_TLS)) + .withClientCache(flightClientCache) + .withConnectTimeout(builder.connectTimeout); ArrowFlightSqlClientHandler endpointHandler = null; try { @@ -172,11 +172,29 @@ public List getStreams(final FlightInfo flightInfo) stream.getStream().getSchema(); } catch (Exception ex) { if (endpointHandler != null) { + // If the exception is related to connectivity, mark the client as a dud. + if (flightClientCache != null) { + if (ex instanceof FlightRuntimeException + && ((FlightRuntimeException) ex).status().code() + == FlightStatusCode.UNAVAILABLE + && + // IOException covers SocketException and Netty's (private) + // AnnotatedSocketException + // We are looking for things like "Network is unreachable" + ex.getCause() instanceof IOException) { + flightClientCache.markLocationAsDud(location.toString()); + } + } + AutoCloseables.close(endpointHandler); } exceptions.add(ex); continue; } + + if (flightClientCache != null) { + flightClientCache.markLocationAsReachable(location.toString()); + } break; } if (stream != null) { @@ -543,6 +561,10 @@ public static final class Builder { @VisibleForTesting Optional catalog = Optional.empty(); + @VisibleForTesting @Nullable FlightClientCache flightClientCache; + + @VisibleForTesting @Nullable Duration connectTimeout; + // These two middleware are for internal use within build() and should not be exposed by builder // APIs. // Note that these middleware may not necessarily be registered. @@ -825,6 +847,27 @@ public Builder withCatalog(@Nullable final String catalog) { return this; } + public Builder withClientCache(FlightClientCache flightClientCache) { + this.flightClientCache = flightClientCache; + return this; + } + + public Builder withConnectTimeout(Duration connectTimeout) { + this.connectTimeout = connectTimeout; + return this; + } + + public String getCacheKey() { + return getLocation().toString(); + } + + public Location getLocation() { + if (useEncryption) { + return Location.forGrpcTls(host, port); + } + return Location.forGrpcInsecure(host, port); + } + /** * Builds a new {@link ArrowFlightSqlClientHandler} from the provided fields. * @@ -845,17 +888,15 @@ public ArrowFlightSqlClientHandler build() throws SQLException { if (isUsingUserPasswordAuth) { buildTimeMiddlewareFactories.add(authFactory); } - final FlightClient.Builder clientBuilder = FlightClient.builder().allocator(allocator); + final NettyClientBuilder clientBuilder = new NettyClientBuilder(); + clientBuilder.allocator(allocator); buildTimeMiddlewareFactories.add(new ClientCookieMiddleware.Factory()); buildTimeMiddlewareFactories.forEach(clientBuilder::intercept); - Location location; if (useEncryption) { - location = Location.forGrpcTls(host, port); clientBuilder.useTls(); - } else { - location = Location.forGrpcInsecure(host, port); } + Location location = getLocation(); clientBuilder.location(location); if (useEncryption) { @@ -883,7 +924,14 @@ public ArrowFlightSqlClientHandler build() throws SQLException { } } - client = clientBuilder.build(); + NettyChannelBuilder channelBuilder = clientBuilder.build(); + if (connectTimeout != null) { + channelBuilder.withOption( + ChannelOption.CONNECT_TIMEOUT_MILLIS, (int) connectTimeout.toMillis()); + } + client = + FlightGrpcUtils.createFlightClient( + allocator, channelBuilder.build(), clientBuilder.middleware()); final ArrayList credentialOptions = new ArrayList<>(); if (isUsingUserPasswordAuth) { // If the authFactory has already been used for a handshake, use the existing token. @@ -905,7 +953,7 @@ public ArrowFlightSqlClientHandler build() throws SQLException { options.toArray(new CallOption[0]))); } return ArrowFlightSqlClientHandler.createNewHandler( - client, this, credentialOptions, catalog); + getCacheKey(), client, this, credentialOptions, catalog, flightClientCache); } catch (final IllegalArgumentException | GeneralSecurityException diff --git a/flight/flight-sql-jdbc-core/src/main/java/org/apache/arrow/driver/jdbc/client/utils/FlightClientCache.java b/flight/flight-sql-jdbc-core/src/main/java/org/apache/arrow/driver/jdbc/client/utils/FlightClientCache.java new file mode 100644 index 000000000..36e8441ba --- /dev/null +++ b/flight/flight-sql-jdbc-core/src/main/java/org/apache/arrow/driver/jdbc/client/utils/FlightClientCache.java @@ -0,0 +1,55 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.arrow.driver.jdbc.client.utils; + +import com.github.benmanes.caffeine.cache.Cache; +import com.github.benmanes.caffeine.cache.Caffeine; +import java.time.Duration; +import org.apache.arrow.util.VisibleForTesting; + +/** + * A cache for Flight clients. + * + *

The intent is to avoid constantly recreating clients to the same locations. gRPC can multiplex + * multiple requests over a single TCP connection, and a cache would let us take advantage of that. + * + *

At the time being it only tracks whether a location is reachable or not. To actually cache + * clients, we would need a way to incorporate other connection parameters (authentication, etc.) + * into the cache key. + */ +public final class FlightClientCache { + @VisibleForTesting Cache clientCache; + + public FlightClientCache() { + this.clientCache = Caffeine.newBuilder().expireAfterWrite(Duration.ofSeconds(600)).build(); + } + + public boolean isDud(String key) { + return clientCache.getIfPresent(key) != null; + } + + public void markLocationAsDud(String key) { + clientCache.put(key, new ClientCacheEntry()); + } + + public void markLocationAsReachable(String key) { + clientCache.invalidate(key); + } + + /** A cache entry (empty because we only track reachability, see outer class docstring). */ + public static final class ClientCacheEntry {} +} diff --git a/flight/flight-sql-jdbc-core/src/main/java/org/apache/arrow/driver/jdbc/client/utils/FlightLocationQueue.java b/flight/flight-sql-jdbc-core/src/main/java/org/apache/arrow/driver/jdbc/client/utils/FlightLocationQueue.java new file mode 100644 index 000000000..37050d6e1 --- /dev/null +++ b/flight/flight-sql-jdbc-core/src/main/java/org/apache/arrow/driver/jdbc/client/utils/FlightLocationQueue.java @@ -0,0 +1,60 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.arrow.driver.jdbc.client.utils; + +import java.util.*; +import org.apache.arrow.flight.Location; +import org.checkerframework.checker.nullness.qual.Nullable; + +/** + * A queue of Flight locations to connect to for an endpoint. + * + *

This helper class is intended to encapsulate the retry logic in a testable manner. + */ +public final class FlightLocationQueue implements Iterator { + private final Deque locations; + private final Deque badLocations; + + public FlightLocationQueue( + @Nullable FlightClientCache flightClientCache, List locations) { + this.locations = new ArrayDeque<>(); + this.badLocations = new ArrayDeque<>(); + + for (Location location : locations) { + if (flightClientCache != null && flightClientCache.isDud(location.toString())) { + this.badLocations.add(location); + } else { + this.locations.add(location); + } + } + } + + @Override + public boolean hasNext() { + return !locations.isEmpty() || !badLocations.isEmpty(); + } + + @Override + public Location next() { + if (!locations.isEmpty()) { + return locations.pop(); + } else if (!badLocations.isEmpty()) { + return badLocations.pop(); + } + throw new NoSuchElementException(); + } +} diff --git a/flight/flight-sql-jdbc-core/src/main/java/org/apache/arrow/driver/jdbc/utils/ArrowFlightConnectionConfigImpl.java b/flight/flight-sql-jdbc-core/src/main/java/org/apache/arrow/driver/jdbc/utils/ArrowFlightConnectionConfigImpl.java index e8bae2a20..76ba964a5 100644 --- a/flight/flight-sql-jdbc-core/src/main/java/org/apache/arrow/driver/jdbc/utils/ArrowFlightConnectionConfigImpl.java +++ b/flight/flight-sql-jdbc-core/src/main/java/org/apache/arrow/driver/jdbc/utils/ArrowFlightConnectionConfigImpl.java @@ -16,6 +16,7 @@ */ package org.apache.arrow.driver.jdbc.utils; +import java.time.Duration; import java.util.Arrays; import java.util.HashMap; import java.util.Map; @@ -163,6 +164,21 @@ public String getCatalog() { return ArrowFlightConnectionProperty.CATALOG.getString(properties); } + /** The initial connect timeout. */ + public Duration getConnectTimeout() { + Integer timeout = ArrowFlightConnectionProperty.CONNECT_TIMEOUT_MILLIS.getInteger(properties); + if (timeout == null) { + return Duration.ofMillis( + (int) ArrowFlightConnectionProperty.CONNECT_TIMEOUT_MILLIS.defaultValue()); + } + return Duration.ofMillis(timeout); + } + + /** Whether to enable the client cache. */ + public boolean useClientCache() { + return ArrowFlightConnectionProperty.USE_CLIENT_CACHE.getBoolean(properties); + } + /** * Gets the {@link CallOption}s from this {@link ConnectionConfig}. * @@ -213,7 +229,10 @@ public enum ArrowFlightConnectionProperty implements ConnectionProperty { TOKEN("token", null, Type.STRING, false), RETAIN_COOKIES("retainCookies", true, Type.BOOLEAN, false), RETAIN_AUTH("retainAuth", true, Type.BOOLEAN, false), - CATALOG("catalog", null, Type.STRING, false); + CATALOG("catalog", null, Type.STRING, false), + CONNECT_TIMEOUT_MILLIS("connectTimeoutMs", 10000, Type.NUMBER, false), + USE_CLIENT_CACHE("useClientCache", true, Type.BOOLEAN, false), + ; private final String camelName; private final Object defaultValue; diff --git a/flight/flight-sql-jdbc-core/src/test/java/org/apache/arrow/driver/jdbc/ResultSetTest.java b/flight/flight-sql-jdbc-core/src/test/java/org/apache/arrow/driver/jdbc/ResultSetTest.java index a8d04dfc8..6e4e4a7fe 100644 --- a/flight/flight-sql-jdbc-core/src/test/java/org/apache/arrow/driver/jdbc/ResultSetTest.java +++ b/flight/flight-sql-jdbc-core/src/test/java/org/apache/arrow/driver/jdbc/ResultSetTest.java @@ -645,6 +645,138 @@ public void testFallbackSecondFlightServer() throws Exception { } } + @Test + public void testFallbackUnresolvableFlightServer() throws Exception { + final Schema schema = + new Schema( + Collections.singletonList(Field.nullable("int_column", Types.MinorType.INT.getType()))); + try (BufferAllocator allocator = new RootAllocator(Long.MAX_VALUE); + VectorSchemaRoot resultData = VectorSchemaRoot.create(schema, allocator)) { + resultData.setRowCount(1); + ((IntVector) resultData.getVector(0)).set(0, 1); + + try (final FallbackFlightSqlProducer rootProducer = + new FallbackFlightSqlProducer(resultData); + FlightServer rootServer = + FlightServer.builder(allocator, forGrpcInsecure("localhost", 0), rootProducer) + .build() + .start(); + Connection newConnection = + DriverManager.getConnection( + String.format( + "jdbc:arrow-flight-sql://%s:%d/?useEncryption=false", + rootServer.getLocation().getUri().getHost(), rootServer.getPort()))) { + // This first attempt should take a measurable amount of time. + long start = System.nanoTime(); + try (Statement newStatement = newConnection.createStatement()) { + try (ResultSet result = newStatement.executeQuery("fallback with unresolvable")) { + List actualData = new ArrayList<>(); + while (result.next()) { + actualData.add(result.getInt(1)); + } + + // Assert + assertEquals(resultData.getRowCount(), actualData.size()); + assertTrue(actualData.contains(((IntVector) resultData.getVector(0)).get(0))); + } + } + long attempt1 = System.nanoTime(); + double elapsedMs = (attempt1 - start) / 1_000_000.; + assertTrue( + elapsedMs >= 5000., + String.format( + "Expected first attempt to hit the timeout, but only %f ms elapsed", elapsedMs)); + + // This second attempt should take less time, since the failure from before should be + // cached. + start = System.nanoTime(); + try (Statement newStatement = newConnection.createStatement()) { + try (ResultSet result = newStatement.executeQuery("fallback with unresolvable")) { + List actualData = new ArrayList<>(); + while (result.next()) { + actualData.add(result.getInt(1)); + } + + // Assert + assertEquals(resultData.getRowCount(), actualData.size()); + assertTrue(actualData.contains(((IntVector) resultData.getVector(0)).get(0))); + } + } + attempt1 = System.nanoTime(); + elapsedMs = (attempt1 - start) / 1_000_000.; + assertTrue( + elapsedMs < 5000., + String.format("Expected second attempt to be faster, but %f ms elapsed", elapsedMs)); + } + } + } + + @Test + public void testFallbackUnresolvableFlightServerDisableCache() throws Exception { + final Schema schema = + new Schema( + Collections.singletonList(Field.nullable("int_column", Types.MinorType.INT.getType()))); + try (BufferAllocator allocator = new RootAllocator(Long.MAX_VALUE); + VectorSchemaRoot resultData = VectorSchemaRoot.create(schema, allocator)) { + resultData.setRowCount(1); + ((IntVector) resultData.getVector(0)).set(0, 1); + + try (final FallbackFlightSqlProducer rootProducer = + new FallbackFlightSqlProducer(resultData); + FlightServer rootServer = + FlightServer.builder(allocator, forGrpcInsecure("localhost", 0), rootProducer) + .build() + .start(); + Connection newConnection = + DriverManager.getConnection( + String.format( + "jdbc:arrow-flight-sql://%s:%d/?useEncryption=false&useClientCache=false", + rootServer.getLocation().getUri().getHost(), rootServer.getPort()))) { + // This first attempt should take a measurable amount of time. + long start = System.nanoTime(); + try (Statement newStatement = newConnection.createStatement()) { + try (ResultSet result = newStatement.executeQuery("fallback with unresolvable")) { + List actualData = new ArrayList<>(); + while (result.next()) { + actualData.add(result.getInt(1)); + } + + // Assert + assertEquals(resultData.getRowCount(), actualData.size()); + assertTrue(actualData.contains(((IntVector) resultData.getVector(0)).get(0))); + } + } + long attempt1 = System.nanoTime(); + double elapsedMs = (attempt1 - start) / 1_000_000.; + assertTrue( + elapsedMs >= 5000., + String.format( + "Expected first attempt to hit the timeout, but only %f ms elapsed", elapsedMs)); + + // This second attempt should take a long time still, since we disabled the cache. + start = System.nanoTime(); + try (Statement newStatement = newConnection.createStatement()) { + try (ResultSet result = newStatement.executeQuery("fallback with unresolvable")) { + List actualData = new ArrayList<>(); + while (result.next()) { + actualData.add(result.getInt(1)); + } + + // Assert + assertEquals(resultData.getRowCount(), actualData.size()); + assertTrue(actualData.contains(((IntVector) resultData.getVector(0)).get(0))); + } + } + attempt1 = System.nanoTime(); + elapsedMs = (attempt1 - start) / 1_000_000.; + assertTrue( + elapsedMs >= 5000., + String.format( + "Expected second attempt to hit the timeout, but only %f ms elapsed", elapsedMs)); + } + } + } + @Test public void testShouldRunSelectQueryWithEmptyVectorsEmbedded() throws Exception { try (Statement statement = connection.createStatement(); diff --git a/flight/flight-sql-jdbc-core/src/test/java/org/apache/arrow/driver/jdbc/client/ArrowFlightSqlClientHandlerBuilderTest.java b/flight/flight-sql-jdbc-core/src/test/java/org/apache/arrow/driver/jdbc/client/ArrowFlightSqlClientHandlerBuilderTest.java index 6beaba823..6524eaf39 100644 --- a/flight/flight-sql-jdbc-core/src/test/java/org/apache/arrow/driver/jdbc/client/ArrowFlightSqlClientHandlerBuilderTest.java +++ b/flight/flight-sql-jdbc-core/src/test/java/org/apache/arrow/driver/jdbc/client/ArrowFlightSqlClientHandlerBuilderTest.java @@ -147,6 +147,8 @@ public void testDefaults() { assertNull(builder.clientCertificatePath); assertNull(builder.clientKeyPath); assertEquals(Optional.empty(), builder.catalog); + assertNull(builder.flightClientCache); + assertNull(builder.connectTimeout); } @Test diff --git a/flight/flight-sql-jdbc-core/src/test/java/org/apache/arrow/driver/jdbc/client/utils/FlightClientCacheTest.java b/flight/flight-sql-jdbc-core/src/test/java/org/apache/arrow/driver/jdbc/client/utils/FlightClientCacheTest.java new file mode 100644 index 000000000..8e818967a --- /dev/null +++ b/flight/flight-sql-jdbc-core/src/test/java/org/apache/arrow/driver/jdbc/client/utils/FlightClientCacheTest.java @@ -0,0 +1,51 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.arrow.driver.jdbc.client.utils; + +import static org.junit.jupiter.api.Assertions.*; + +import org.apache.arrow.flight.Location; +import org.junit.jupiter.api.Test; + +class FlightClientCacheTest { + @Test + void basicOperation() { + FlightClientCache cache = new FlightClientCache(); + + Location location1 = Location.forGrpcInsecure("localhost", 8080); + Location location2 = Location.forGrpcInsecure("localhost", 8081); + + assertFalse(cache.isDud(location1.toString())); + assertFalse(cache.isDud(location2.toString())); + + cache.markLocationAsReachable(location1.toString()); + assertFalse(cache.isDud(location1.toString())); + assertFalse(cache.isDud(location2.toString())); + + cache.markLocationAsDud(location1.toString()); + assertTrue(cache.isDud(location1.toString())); + assertFalse(cache.isDud(location2.toString())); + + cache.markLocationAsDud(location2.toString()); + assertTrue(cache.isDud(location1.toString())); + assertTrue(cache.isDud(location2.toString())); + + cache.markLocationAsReachable(location1.toString()); + assertFalse(cache.isDud(location1.toString())); + assertTrue(cache.isDud(location2.toString())); + } +} diff --git a/flight/flight-sql-jdbc-core/src/test/java/org/apache/arrow/driver/jdbc/client/utils/FlightLocationQueueTest.java b/flight/flight-sql-jdbc-core/src/test/java/org/apache/arrow/driver/jdbc/client/utils/FlightLocationQueueTest.java new file mode 100644 index 000000000..0603f86e5 --- /dev/null +++ b/flight/flight-sql-jdbc-core/src/test/java/org/apache/arrow/driver/jdbc/client/utils/FlightLocationQueueTest.java @@ -0,0 +1,63 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.arrow.driver.jdbc.client.utils; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; + +import java.util.Collections; +import java.util.List; +import java.util.NoSuchElementException; +import org.apache.arrow.flight.Location; +import org.junit.jupiter.api.Test; + +class FlightLocationQueueTest { + @Test + void basicOperation() { + Location location1 = Location.forGrpcInsecure("localhost", 8080); + Location location2 = Location.forGrpcInsecure("localhost", 8081); + FlightLocationQueue queue = new FlightLocationQueue(null, List.of(location1, location2)); + assertTrue(queue.hasNext()); + assertEquals(location1, queue.next()); + assertTrue(queue.hasNext()); + assertEquals(location2, queue.next()); + assertFalse(queue.hasNext()); + } + + @Test + void badAfterGood() { + Location location1 = Location.forGrpcInsecure("localhost", 8080); + Location location2 = Location.forGrpcInsecure("localhost", 8081); + FlightClientCache cache = new FlightClientCache(); + cache.markLocationAsDud(location1.toString()); + FlightLocationQueue queue = new FlightLocationQueue(cache, List.of(location1, location2)); + assertTrue(queue.hasNext()); + assertEquals(location2, queue.next()); + assertTrue(queue.hasNext()); + assertEquals(location1, queue.next()); + assertFalse(queue.hasNext()); + } + + @Test + void iteratorInvariants() { + FlightLocationQueue empty = new FlightLocationQueue(null, Collections.emptyList()); + assertFalse(empty.hasNext()); + assertThrows(NoSuchElementException.class, empty::next); + } +} diff --git a/flight/flight-sql-jdbc-core/src/test/java/org/apache/arrow/driver/jdbc/utils/ArrowFlightConnectionConfigImplTest.java b/flight/flight-sql-jdbc-core/src/test/java/org/apache/arrow/driver/jdbc/utils/ArrowFlightConnectionConfigImplTest.java index 4a46b5f5b..f2cdfa43f 100644 --- a/flight/flight-sql-jdbc-core/src/test/java/org/apache/arrow/driver/jdbc/utils/ArrowFlightConnectionConfigImplTest.java +++ b/flight/flight-sql-jdbc-core/src/test/java/org/apache/arrow/driver/jdbc/utils/ArrowFlightConnectionConfigImplTest.java @@ -17,16 +17,11 @@ package org.apache.arrow.driver.jdbc.utils; import static java.lang.Runtime.getRuntime; -import static org.apache.arrow.driver.jdbc.utils.ArrowFlightConnectionConfigImpl.ArrowFlightConnectionProperty.CATALOG; -import static org.apache.arrow.driver.jdbc.utils.ArrowFlightConnectionConfigImpl.ArrowFlightConnectionProperty.HOST; -import static org.apache.arrow.driver.jdbc.utils.ArrowFlightConnectionConfigImpl.ArrowFlightConnectionProperty.PASSWORD; -import static org.apache.arrow.driver.jdbc.utils.ArrowFlightConnectionConfigImpl.ArrowFlightConnectionProperty.PORT; -import static org.apache.arrow.driver.jdbc.utils.ArrowFlightConnectionConfigImpl.ArrowFlightConnectionProperty.THREAD_POOL_SIZE; -import static org.apache.arrow.driver.jdbc.utils.ArrowFlightConnectionConfigImpl.ArrowFlightConnectionProperty.USER; -import static org.apache.arrow.driver.jdbc.utils.ArrowFlightConnectionConfigImpl.ArrowFlightConnectionProperty.USE_ENCRYPTION; +import static org.apache.arrow.driver.jdbc.utils.ArrowFlightConnectionConfigImpl.ArrowFlightConnectionProperty.*; import static org.hamcrest.CoreMatchers.is; import static org.hamcrest.MatcherAssert.assertThat; +import java.time.Duration; import java.util.Properties; import java.util.Random; import java.util.function.Function; @@ -59,49 +54,73 @@ public void setUp() { public void testGetProperty( ArrowFlightConnectionProperty property, Object value, + Object expected, Function configFunction) { properties.put(property.camelName(), value); arrowFlightConnectionConfigFunction = configFunction; - assertThat(configFunction.apply(arrowFlightConnectionConfig), is(value)); - assertThat(arrowFlightConnectionConfigFunction.apply(arrowFlightConnectionConfig), is(value)); + assertThat(configFunction.apply(arrowFlightConnectionConfig), is(expected)); + assertThat( + arrowFlightConnectionConfigFunction.apply(arrowFlightConnectionConfig), is(expected)); } public static Stream provideParameters() { + int port = RANDOM.nextInt(Short.toUnsignedInt(Short.MAX_VALUE)); + boolean useEncryption = RANDOM.nextBoolean(); + int threadPoolSize = RANDOM.nextInt(getRuntime().availableProcessors()); return Stream.of( Arguments.of( HOST, "host", + "host", (Function) ArrowFlightConnectionConfigImpl::getHost), Arguments.of( PORT, - RANDOM.nextInt(Short.toUnsignedInt(Short.MAX_VALUE)), + port, + port, (Function) ArrowFlightConnectionConfigImpl::getPort), Arguments.of( USER, "user", + "user", (Function) ArrowFlightConnectionConfigImpl::getUser), Arguments.of( PASSWORD, "password", + "password", (Function) ArrowFlightConnectionConfigImpl::getPassword), Arguments.of( USE_ENCRYPTION, - RANDOM.nextBoolean(), + useEncryption, + useEncryption, (Function) ArrowFlightConnectionConfigImpl::useEncryption), Arguments.of( THREAD_POOL_SIZE, - RANDOM.nextInt(getRuntime().availableProcessors()), + threadPoolSize, + threadPoolSize, (Function) ArrowFlightConnectionConfigImpl::threadPoolSize), Arguments.of( CATALOG, "catalog", + "catalog", + (Function) + ArrowFlightConnectionConfigImpl::getCatalog), + Arguments.of( + CONNECT_TIMEOUT_MILLIS, + 5000, + Duration.ofMillis(5000), + (Function) + ArrowFlightConnectionConfigImpl::getConnectTimeout), + Arguments.of( + USE_CLIENT_CACHE, + false, + false, (Function) - ArrowFlightConnectionConfigImpl::getCatalog)); + ArrowFlightConnectionConfigImpl::useClientCache)); } } diff --git a/flight/flight-sql-jdbc-core/src/test/java/org/apache/arrow/driver/jdbc/utils/FallbackFlightSqlProducer.java b/flight/flight-sql-jdbc-core/src/test/java/org/apache/arrow/driver/jdbc/utils/FallbackFlightSqlProducer.java index 9aa257172..670b9e3be 100644 --- a/flight/flight-sql-jdbc-core/src/test/java/org/apache/arrow/driver/jdbc/utils/FallbackFlightSqlProducer.java +++ b/flight/flight-sql-jdbc-core/src/test/java/org/apache/arrow/driver/jdbc/utils/FallbackFlightSqlProducer.java @@ -109,6 +109,16 @@ private FlightInfo getFlightInfo(FlightDescriptor descriptor, String query) { Location.forGrpcInsecure("localhost", 9999), Location.reuseConnection()) .build()); + } else if (query.equals("fallback with unresolvable")) { + endpoints = + Collections.singletonList( + FlightEndpoint.builder( + ticket, + // Inaccessible IP + // https://stackoverflow.com/questions/10456044/what-is-a-good-invalid-ip-address-to-use-for-unit-tests + Location.forGrpcInsecure("203.0.113.0", 9999), + Location.reuseConnection()) + .build()); } else { throw CallStatus.UNIMPLEMENTED.withDescription(query).toRuntimeException(); }