middleware) {
+ return new FlightClient(incomingAllocator, channel, middleware);
+ }
+
/**
* Creates a Flight client.
*
diff --git a/flight/flight-core/src/main/java/org/apache/arrow/flight/grpc/NettyClientBuilder.java b/flight/flight-core/src/main/java/org/apache/arrow/flight/grpc/NettyClientBuilder.java
new file mode 100644
index 000000000..42cdaac01
--- /dev/null
+++ b/flight/flight-core/src/main/java/org/apache/arrow/flight/grpc/NettyClientBuilder.java
@@ -0,0 +1,232 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.arrow.flight.grpc;
+
+import io.grpc.ManagedChannel;
+import io.grpc.netty.GrpcSslContexts;
+import io.grpc.netty.NettyChannelBuilder;
+import io.netty.channel.EventLoopGroup;
+import io.netty.channel.ServerChannel;
+import io.netty.handler.ssl.SslContextBuilder;
+import io.netty.handler.ssl.util.InsecureTrustManagerFactory;
+import java.io.InputStream;
+import java.lang.reflect.InvocationTargetException;
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.List;
+import javax.net.ssl.SSLException;
+import org.apache.arrow.flight.FlightClientMiddleware;
+import org.apache.arrow.flight.Location;
+import org.apache.arrow.flight.LocationSchemes;
+import org.apache.arrow.memory.BufferAllocator;
+import org.apache.arrow.util.Preconditions;
+
+/**
+ * A wrapper around gRPC's Netty builder.
+ *
+ * It is recommended to use the Netty channel builder directly with {@link
+ * org.apache.arrow.flight.FlightGrpcUtils#createFlightClient(BufferAllocator, ManagedChannel)}.
+ * However, this class provides an adapter that implements the existing Flight-specific builder
+ * interface but allows usage of the Netty builder as well.
+ */
+public class NettyClientBuilder {
+ /**
+ * The maximum number of trace events to keep on the gRPC Channel. This value disables channel
+ * tracing.
+ */
+ private static final int MAX_CHANNEL_TRACE_EVENTS = 0;
+
+ protected BufferAllocator allocator;
+ protected Location location;
+ protected boolean forceTls = false;
+ protected int maxInboundMessageSize = Integer.MAX_VALUE;
+ protected InputStream trustedCertificates = null;
+ protected InputStream clientCertificate = null;
+ protected InputStream clientKey = null;
+ protected String overrideHostname = null;
+ protected List middleware = new ArrayList<>();
+ protected boolean verifyServer = true;
+
+ public NettyClientBuilder() {}
+
+ public NettyClientBuilder(BufferAllocator allocator, Location location) {
+ this.allocator = Preconditions.checkNotNull(allocator);
+ this.location = Preconditions.checkNotNull(location);
+ }
+
+ /** Force the client to connect over TLS. */
+ public NettyClientBuilder useTls() {
+ this.forceTls = true;
+ return this;
+ }
+
+ /** Override the hostname checked for TLS. Use with caution in production. */
+ public NettyClientBuilder overrideHostname(final String hostname) {
+ this.overrideHostname = hostname;
+ return this;
+ }
+
+ /** Set the maximum inbound message size. */
+ public NettyClientBuilder maxInboundMessageSize(int maxSize) {
+ Preconditions.checkArgument(maxSize > 0);
+ this.maxInboundMessageSize = maxSize;
+ return this;
+ }
+
+ /** Set the trusted TLS certificates. */
+ public NettyClientBuilder trustedCertificates(final InputStream stream) {
+ this.trustedCertificates = Preconditions.checkNotNull(stream);
+ return this;
+ }
+
+ /** Set the trusted TLS certificates. */
+ public NettyClientBuilder clientCertificate(
+ final InputStream clientCertificate, final InputStream clientKey) {
+ Preconditions.checkNotNull(clientKey);
+ this.clientCertificate = Preconditions.checkNotNull(clientCertificate);
+ this.clientKey = Preconditions.checkNotNull(clientKey);
+ return this;
+ }
+
+ public BufferAllocator allocator() {
+ return allocator;
+ }
+
+ public NettyClientBuilder allocator(BufferAllocator allocator) {
+ this.allocator = Preconditions.checkNotNull(allocator);
+ return this;
+ }
+
+ public NettyClientBuilder location(Location location) {
+ this.location = Preconditions.checkNotNull(location);
+ return this;
+ }
+
+ public List middleware() {
+ return Collections.unmodifiableList(middleware);
+ }
+
+ public NettyClientBuilder intercept(FlightClientMiddleware.Factory factory) {
+ middleware.add(factory);
+ return this;
+ }
+
+ public NettyClientBuilder verifyServer(boolean verifyServer) {
+ this.verifyServer = verifyServer;
+ return this;
+ }
+
+ /** Create the client from this builder. */
+ public NettyChannelBuilder build() {
+ final NettyChannelBuilder builder;
+
+ switch (location.getUri().getScheme()) {
+ case LocationSchemes.GRPC:
+ case LocationSchemes.GRPC_INSECURE:
+ case LocationSchemes.GRPC_TLS:
+ {
+ builder = NettyChannelBuilder.forAddress(location.toSocketAddress());
+ break;
+ }
+ case LocationSchemes.GRPC_DOMAIN_SOCKET:
+ {
+ // The implementation is platform-specific, so we have to find the classes at runtime
+ builder = NettyChannelBuilder.forAddress(location.toSocketAddress());
+ try {
+ try {
+ // Linux
+ builder.channelType(
+ Class.forName("io.netty.channel.epoll.EpollDomainSocketChannel")
+ .asSubclass(ServerChannel.class));
+ final EventLoopGroup elg =
+ Class.forName("io.netty.channel.epoll.EpollEventLoopGroup")
+ .asSubclass(EventLoopGroup.class)
+ .getDeclaredConstructor()
+ .newInstance();
+ builder.eventLoopGroup(elg);
+ } catch (ClassNotFoundException e) {
+ // BSD
+ builder.channelType(
+ Class.forName("io.netty.channel.kqueue.KQueueDomainSocketChannel")
+ .asSubclass(ServerChannel.class));
+ final EventLoopGroup elg =
+ Class.forName("io.netty.channel.kqueue.KQueueEventLoopGroup")
+ .asSubclass(EventLoopGroup.class)
+ .getDeclaredConstructor()
+ .newInstance();
+ builder.eventLoopGroup(elg);
+ }
+ } catch (ClassNotFoundException
+ | InstantiationException
+ | IllegalAccessException
+ | NoSuchMethodException
+ | InvocationTargetException e) {
+ throw new UnsupportedOperationException(
+ "Could not find suitable Netty native transport implementation for domain socket address.");
+ }
+ break;
+ }
+ default:
+ throw new IllegalArgumentException(
+ "Scheme is not supported: " + location.getUri().getScheme());
+ }
+
+ if (this.forceTls || LocationSchemes.GRPC_TLS.equals(location.getUri().getScheme())) {
+ builder.useTransportSecurity();
+
+ final boolean hasTrustedCerts = this.trustedCertificates != null;
+ final boolean hasKeyCertPair = this.clientCertificate != null && this.clientKey != null;
+ if (!this.verifyServer && (hasTrustedCerts || hasKeyCertPair)) {
+ throw new IllegalArgumentException(
+ "FlightClient has been configured to disable server verification, "
+ + "but certificate options have been specified.");
+ }
+
+ final SslContextBuilder sslContextBuilder = GrpcSslContexts.forClient();
+
+ if (!this.verifyServer) {
+ sslContextBuilder.trustManager(InsecureTrustManagerFactory.INSTANCE);
+ } else if (this.trustedCertificates != null
+ || this.clientCertificate != null
+ || this.clientKey != null) {
+ if (this.trustedCertificates != null) {
+ sslContextBuilder.trustManager(this.trustedCertificates);
+ }
+ if (this.clientCertificate != null && this.clientKey != null) {
+ sslContextBuilder.keyManager(this.clientCertificate, this.clientKey);
+ }
+ }
+ try {
+ builder.sslContext(sslContextBuilder.build());
+ } catch (SSLException e) {
+ throw new RuntimeException(e);
+ }
+
+ if (this.overrideHostname != null) {
+ builder.overrideAuthority(this.overrideHostname);
+ }
+ } else {
+ builder.usePlaintext();
+ }
+
+ builder
+ .maxTraceEvents(MAX_CHANNEL_TRACE_EVENTS)
+ .maxInboundMessageSize(maxInboundMessageSize)
+ .maxInboundMetadataSize(maxInboundMessageSize);
+ return builder;
+ }
+}
diff --git a/flight/flight-sql-jdbc-core/pom.xml b/flight/flight-sql-jdbc-core/pom.xml
index 204e17efe..7c97fc026 100644
--- a/flight/flight-sql-jdbc-core/pom.xml
+++ b/flight/flight-sql-jdbc-core/pom.xml
@@ -128,6 +128,12 @@ under the License.
org.checkerframework
checker-qual
+
+
+ com.github.ben-manes.caffeine
+ caffeine
+ 3.1.8
+
diff --git a/flight/flight-sql-jdbc-core/src/main/java/org/apache/arrow/driver/jdbc/ArrowFlightConnection.java b/flight/flight-sql-jdbc-core/src/main/java/org/apache/arrow/driver/jdbc/ArrowFlightConnection.java
index c1b1c8f8e..747287ed1 100644
--- a/flight/flight-sql-jdbc-core/src/main/java/org/apache/arrow/driver/jdbc/ArrowFlightConnection.java
+++ b/flight/flight-sql-jdbc-core/src/main/java/org/apache/arrow/driver/jdbc/ArrowFlightConnection.java
@@ -24,6 +24,7 @@
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import org.apache.arrow.driver.jdbc.client.ArrowFlightSqlClientHandler;
+import org.apache.arrow.driver.jdbc.client.utils.FlightClientCache;
import org.apache.arrow.driver.jdbc.utils.ArrowFlightConnectionConfigImpl;
import org.apache.arrow.flight.FlightClient;
import org.apache.arrow.memory.BufferAllocator;
@@ -113,6 +114,8 @@ private static ArrowFlightSqlClientHandler createNewClientHandler(
.withRetainCookies(config.retainCookies())
.withRetainAuth(config.retainAuth())
.withCatalog(config.getCatalog())
+ .withClientCache(config.useClientCache() ? new FlightClientCache() : null)
+ .withConnectTimeout(config.getConnectTimeout())
.build();
} catch (final SQLException e) {
try {
diff --git a/flight/flight-sql-jdbc-core/src/main/java/org/apache/arrow/driver/jdbc/client/ArrowFlightSqlClientHandler.java b/flight/flight-sql-jdbc-core/src/main/java/org/apache/arrow/driver/jdbc/client/ArrowFlightSqlClientHandler.java
index 0e9c79a09..06b8f6701 100644
--- a/flight/flight-sql-jdbc-core/src/main/java/org/apache/arrow/driver/jdbc/client/ArrowFlightSqlClientHandler.java
+++ b/flight/flight-sql-jdbc-core/src/main/java/org/apache/arrow/driver/jdbc/client/ArrowFlightSqlClientHandler.java
@@ -17,39 +17,24 @@
package org.apache.arrow.driver.jdbc.client;
import com.google.common.collect.ImmutableMap;
+import io.grpc.netty.NettyChannelBuilder;
+import io.netty.channel.ChannelOption;
import java.io.IOException;
import java.net.URI;
import java.security.GeneralSecurityException;
import java.sql.SQLException;
-import java.util.ArrayList;
-import java.util.Arrays;
-import java.util.Collection;
-import java.util.HashSet;
-import java.util.List;
-import java.util.Map;
-import java.util.Optional;
-import java.util.Set;
+import java.time.Duration;
+import java.util.*;
import org.apache.arrow.driver.jdbc.client.utils.ClientAuthenticationUtils;
-import org.apache.arrow.flight.CallOption;
-import org.apache.arrow.flight.CallStatus;
-import org.apache.arrow.flight.CloseSessionRequest;
-import org.apache.arrow.flight.FlightClient;
-import org.apache.arrow.flight.FlightClientMiddleware;
-import org.apache.arrow.flight.FlightEndpoint;
-import org.apache.arrow.flight.FlightInfo;
-import org.apache.arrow.flight.FlightRuntimeException;
-import org.apache.arrow.flight.FlightStatusCode;
-import org.apache.arrow.flight.Location;
-import org.apache.arrow.flight.LocationSchemes;
-import org.apache.arrow.flight.SessionOptionValue;
-import org.apache.arrow.flight.SessionOptionValueFactory;
-import org.apache.arrow.flight.SetSessionOptionsRequest;
-import org.apache.arrow.flight.SetSessionOptionsResult;
+import org.apache.arrow.driver.jdbc.client.utils.FlightClientCache;
+import org.apache.arrow.driver.jdbc.client.utils.FlightLocationQueue;
+import org.apache.arrow.flight.*;
import org.apache.arrow.flight.auth2.BearerCredentialWriter;
import org.apache.arrow.flight.auth2.ClientBearerHeaderHandler;
import org.apache.arrow.flight.auth2.ClientIncomingAuthHeaderMiddleware;
import org.apache.arrow.flight.client.ClientCookieMiddleware;
import org.apache.arrow.flight.grpc.CredentialCallOption;
+import org.apache.arrow.flight.grpc.NettyClientBuilder;
import org.apache.arrow.flight.sql.FlightSqlClient;
import org.apache.arrow.flight.sql.impl.FlightSql.SqlInfo;
import org.apache.arrow.flight.sql.util.TableRef;
@@ -70,21 +55,27 @@ public final class ArrowFlightSqlClientHandler implements AutoCloseable {
// JDBC connection string query parameter
private static final String CATALOG = "catalog";
+ private final String cacheKey;
private final FlightSqlClient sqlClient;
private final Set options = new HashSet<>();
private final Builder builder;
private final Optional catalog;
+ private final @Nullable FlightClientCache flightClientCache;
ArrowFlightSqlClientHandler(
+ final String cacheKey,
final FlightSqlClient sqlClient,
final Builder builder,
final Collection credentialOptions,
- final Optional catalog) {
+ final Optional catalog,
+ final @Nullable FlightClientCache flightClientCache) {
this.options.addAll(builder.options);
this.options.addAll(credentialOptions);
+ this.cacheKey = Preconditions.checkNotNull(cacheKey);
this.sqlClient = Preconditions.checkNotNull(sqlClient);
this.builder = builder;
this.catalog = catalog;
+ this.flightClientCache = flightClientCache;
}
/**
@@ -96,12 +87,15 @@ public final class ArrowFlightSqlClientHandler implements AutoCloseable {
* @return a new {@link ArrowFlightSqlClientHandler}.
*/
static ArrowFlightSqlClientHandler createNewHandler(
+ final String cacheKey,
final FlightClient client,
final Builder builder,
final Collection options,
- final Optional catalog) {
+ final Optional catalog,
+ final @Nullable FlightClientCache flightClientCache) {
final ArrowFlightSqlClientHandler handler =
- new ArrowFlightSqlClientHandler(new FlightSqlClient(client), builder, options, catalog);
+ new ArrowFlightSqlClientHandler(
+ cacheKey, new FlightSqlClient(client), builder, options, catalog, flightClientCache);
handler.setSetCatalogInSessionIfPresent();
return handler;
}
@@ -138,15 +132,19 @@ public List getStreams(final FlightInfo flightInfo)
// Clone the builder and then set the new endpoint on it.
// GH-38574: Currently a new FlightClient will be made for each partition that returns a
- // non-empty Location
- // then disposed of. It may be better to cache clients because a server may report the
- // same Locations.
- // It would also be good to identify when the reported location is the same as the
- // original connection's
- // Location and skip creating a FlightClient in that scenario.
+ // non-empty Location then disposed of. It may be better to cache clients because a server
+ // may report the same Locations. It would also be good to identify when the reported
+ // location
+ // is the same as the original connection's Location and skip creating a FlightClient in
+ // that scenario.
+ // Also copy the cache to the client so we can share a cache. Cache needs to cache
+ // negative attempts too.
List exceptions = new ArrayList<>();
CloseableEndpointStreamPair stream = null;
- for (Location location : endpoint.getLocations()) {
+ FlightLocationQueue locations =
+ new FlightLocationQueue(flightClientCache, endpoint.getLocations());
+ while (locations.hasNext()) {
+ Location location = locations.next();
final URI endpointUri = location.getUri();
if (endpointUri.getScheme().equals(LocationSchemes.REUSE_CONNECTION)) {
stream =
@@ -158,7 +156,9 @@ public List getStreams(final FlightInfo flightInfo)
new Builder(ArrowFlightSqlClientHandler.this.builder)
.withHost(endpointUri.getHost())
.withPort(endpointUri.getPort())
- .withEncryption(endpointUri.getScheme().equals(LocationSchemes.GRPC_TLS));
+ .withEncryption(endpointUri.getScheme().equals(LocationSchemes.GRPC_TLS))
+ .withClientCache(flightClientCache)
+ .withConnectTimeout(builder.connectTimeout);
ArrowFlightSqlClientHandler endpointHandler = null;
try {
@@ -172,11 +172,29 @@ public List getStreams(final FlightInfo flightInfo)
stream.getStream().getSchema();
} catch (Exception ex) {
if (endpointHandler != null) {
+ // If the exception is related to connectivity, mark the client as a dud.
+ if (flightClientCache != null) {
+ if (ex instanceof FlightRuntimeException
+ && ((FlightRuntimeException) ex).status().code()
+ == FlightStatusCode.UNAVAILABLE
+ &&
+ // IOException covers SocketException and Netty's (private)
+ // AnnotatedSocketException
+ // We are looking for things like "Network is unreachable"
+ ex.getCause() instanceof IOException) {
+ flightClientCache.markLocationAsDud(location.toString());
+ }
+ }
+
AutoCloseables.close(endpointHandler);
}
exceptions.add(ex);
continue;
}
+
+ if (flightClientCache != null) {
+ flightClientCache.markLocationAsReachable(location.toString());
+ }
break;
}
if (stream != null) {
@@ -543,6 +561,10 @@ public static final class Builder {
@VisibleForTesting Optional catalog = Optional.empty();
+ @VisibleForTesting @Nullable FlightClientCache flightClientCache;
+
+ @VisibleForTesting @Nullable Duration connectTimeout;
+
// These two middleware are for internal use within build() and should not be exposed by builder
// APIs.
// Note that these middleware may not necessarily be registered.
@@ -825,6 +847,27 @@ public Builder withCatalog(@Nullable final String catalog) {
return this;
}
+ public Builder withClientCache(FlightClientCache flightClientCache) {
+ this.flightClientCache = flightClientCache;
+ return this;
+ }
+
+ public Builder withConnectTimeout(Duration connectTimeout) {
+ this.connectTimeout = connectTimeout;
+ return this;
+ }
+
+ public String getCacheKey() {
+ return getLocation().toString();
+ }
+
+ public Location getLocation() {
+ if (useEncryption) {
+ return Location.forGrpcTls(host, port);
+ }
+ return Location.forGrpcInsecure(host, port);
+ }
+
/**
* Builds a new {@link ArrowFlightSqlClientHandler} from the provided fields.
*
@@ -845,17 +888,15 @@ public ArrowFlightSqlClientHandler build() throws SQLException {
if (isUsingUserPasswordAuth) {
buildTimeMiddlewareFactories.add(authFactory);
}
- final FlightClient.Builder clientBuilder = FlightClient.builder().allocator(allocator);
+ final NettyClientBuilder clientBuilder = new NettyClientBuilder();
+ clientBuilder.allocator(allocator);
buildTimeMiddlewareFactories.add(new ClientCookieMiddleware.Factory());
buildTimeMiddlewareFactories.forEach(clientBuilder::intercept);
- Location location;
if (useEncryption) {
- location = Location.forGrpcTls(host, port);
clientBuilder.useTls();
- } else {
- location = Location.forGrpcInsecure(host, port);
}
+ Location location = getLocation();
clientBuilder.location(location);
if (useEncryption) {
@@ -883,7 +924,14 @@ public ArrowFlightSqlClientHandler build() throws SQLException {
}
}
- client = clientBuilder.build();
+ NettyChannelBuilder channelBuilder = clientBuilder.build();
+ if (connectTimeout != null) {
+ channelBuilder.withOption(
+ ChannelOption.CONNECT_TIMEOUT_MILLIS, (int) connectTimeout.toMillis());
+ }
+ client =
+ FlightGrpcUtils.createFlightClient(
+ allocator, channelBuilder.build(), clientBuilder.middleware());
final ArrayList credentialOptions = new ArrayList<>();
if (isUsingUserPasswordAuth) {
// If the authFactory has already been used for a handshake, use the existing token.
@@ -905,7 +953,7 @@ public ArrowFlightSqlClientHandler build() throws SQLException {
options.toArray(new CallOption[0])));
}
return ArrowFlightSqlClientHandler.createNewHandler(
- client, this, credentialOptions, catalog);
+ getCacheKey(), client, this, credentialOptions, catalog, flightClientCache);
} catch (final IllegalArgumentException
| GeneralSecurityException
diff --git a/flight/flight-sql-jdbc-core/src/main/java/org/apache/arrow/driver/jdbc/client/utils/FlightClientCache.java b/flight/flight-sql-jdbc-core/src/main/java/org/apache/arrow/driver/jdbc/client/utils/FlightClientCache.java
new file mode 100644
index 000000000..36e8441ba
--- /dev/null
+++ b/flight/flight-sql-jdbc-core/src/main/java/org/apache/arrow/driver/jdbc/client/utils/FlightClientCache.java
@@ -0,0 +1,55 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.arrow.driver.jdbc.client.utils;
+
+import com.github.benmanes.caffeine.cache.Cache;
+import com.github.benmanes.caffeine.cache.Caffeine;
+import java.time.Duration;
+import org.apache.arrow.util.VisibleForTesting;
+
+/**
+ * A cache for Flight clients.
+ *
+ * The intent is to avoid constantly recreating clients to the same locations. gRPC can multiplex
+ * multiple requests over a single TCP connection, and a cache would let us take advantage of that.
+ *
+ *
At the time being it only tracks whether a location is reachable or not. To actually cache
+ * clients, we would need a way to incorporate other connection parameters (authentication, etc.)
+ * into the cache key.
+ */
+public final class FlightClientCache {
+ @VisibleForTesting Cache clientCache;
+
+ public FlightClientCache() {
+ this.clientCache = Caffeine.newBuilder().expireAfterWrite(Duration.ofSeconds(600)).build();
+ }
+
+ public boolean isDud(String key) {
+ return clientCache.getIfPresent(key) != null;
+ }
+
+ public void markLocationAsDud(String key) {
+ clientCache.put(key, new ClientCacheEntry());
+ }
+
+ public void markLocationAsReachable(String key) {
+ clientCache.invalidate(key);
+ }
+
+ /** A cache entry (empty because we only track reachability, see outer class docstring). */
+ public static final class ClientCacheEntry {}
+}
diff --git a/flight/flight-sql-jdbc-core/src/main/java/org/apache/arrow/driver/jdbc/client/utils/FlightLocationQueue.java b/flight/flight-sql-jdbc-core/src/main/java/org/apache/arrow/driver/jdbc/client/utils/FlightLocationQueue.java
new file mode 100644
index 000000000..37050d6e1
--- /dev/null
+++ b/flight/flight-sql-jdbc-core/src/main/java/org/apache/arrow/driver/jdbc/client/utils/FlightLocationQueue.java
@@ -0,0 +1,60 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.arrow.driver.jdbc.client.utils;
+
+import java.util.*;
+import org.apache.arrow.flight.Location;
+import org.checkerframework.checker.nullness.qual.Nullable;
+
+/**
+ * A queue of Flight locations to connect to for an endpoint.
+ *
+ * This helper class is intended to encapsulate the retry logic in a testable manner.
+ */
+public final class FlightLocationQueue implements Iterator {
+ private final Deque locations;
+ private final Deque badLocations;
+
+ public FlightLocationQueue(
+ @Nullable FlightClientCache flightClientCache, List locations) {
+ this.locations = new ArrayDeque<>();
+ this.badLocations = new ArrayDeque<>();
+
+ for (Location location : locations) {
+ if (flightClientCache != null && flightClientCache.isDud(location.toString())) {
+ this.badLocations.add(location);
+ } else {
+ this.locations.add(location);
+ }
+ }
+ }
+
+ @Override
+ public boolean hasNext() {
+ return !locations.isEmpty() || !badLocations.isEmpty();
+ }
+
+ @Override
+ public Location next() {
+ if (!locations.isEmpty()) {
+ return locations.pop();
+ } else if (!badLocations.isEmpty()) {
+ return badLocations.pop();
+ }
+ throw new NoSuchElementException();
+ }
+}
diff --git a/flight/flight-sql-jdbc-core/src/main/java/org/apache/arrow/driver/jdbc/utils/ArrowFlightConnectionConfigImpl.java b/flight/flight-sql-jdbc-core/src/main/java/org/apache/arrow/driver/jdbc/utils/ArrowFlightConnectionConfigImpl.java
index e8bae2a20..76ba964a5 100644
--- a/flight/flight-sql-jdbc-core/src/main/java/org/apache/arrow/driver/jdbc/utils/ArrowFlightConnectionConfigImpl.java
+++ b/flight/flight-sql-jdbc-core/src/main/java/org/apache/arrow/driver/jdbc/utils/ArrowFlightConnectionConfigImpl.java
@@ -16,6 +16,7 @@
*/
package org.apache.arrow.driver.jdbc.utils;
+import java.time.Duration;
import java.util.Arrays;
import java.util.HashMap;
import java.util.Map;
@@ -163,6 +164,21 @@ public String getCatalog() {
return ArrowFlightConnectionProperty.CATALOG.getString(properties);
}
+ /** The initial connect timeout. */
+ public Duration getConnectTimeout() {
+ Integer timeout = ArrowFlightConnectionProperty.CONNECT_TIMEOUT_MILLIS.getInteger(properties);
+ if (timeout == null) {
+ return Duration.ofMillis(
+ (int) ArrowFlightConnectionProperty.CONNECT_TIMEOUT_MILLIS.defaultValue());
+ }
+ return Duration.ofMillis(timeout);
+ }
+
+ /** Whether to enable the client cache. */
+ public boolean useClientCache() {
+ return ArrowFlightConnectionProperty.USE_CLIENT_CACHE.getBoolean(properties);
+ }
+
/**
* Gets the {@link CallOption}s from this {@link ConnectionConfig}.
*
@@ -213,7 +229,10 @@ public enum ArrowFlightConnectionProperty implements ConnectionProperty {
TOKEN("token", null, Type.STRING, false),
RETAIN_COOKIES("retainCookies", true, Type.BOOLEAN, false),
RETAIN_AUTH("retainAuth", true, Type.BOOLEAN, false),
- CATALOG("catalog", null, Type.STRING, false);
+ CATALOG("catalog", null, Type.STRING, false),
+ CONNECT_TIMEOUT_MILLIS("connectTimeoutMs", 10000, Type.NUMBER, false),
+ USE_CLIENT_CACHE("useClientCache", true, Type.BOOLEAN, false),
+ ;
private final String camelName;
private final Object defaultValue;
diff --git a/flight/flight-sql-jdbc-core/src/test/java/org/apache/arrow/driver/jdbc/ResultSetTest.java b/flight/flight-sql-jdbc-core/src/test/java/org/apache/arrow/driver/jdbc/ResultSetTest.java
index a8d04dfc8..6e4e4a7fe 100644
--- a/flight/flight-sql-jdbc-core/src/test/java/org/apache/arrow/driver/jdbc/ResultSetTest.java
+++ b/flight/flight-sql-jdbc-core/src/test/java/org/apache/arrow/driver/jdbc/ResultSetTest.java
@@ -645,6 +645,138 @@ public void testFallbackSecondFlightServer() throws Exception {
}
}
+ @Test
+ public void testFallbackUnresolvableFlightServer() throws Exception {
+ final Schema schema =
+ new Schema(
+ Collections.singletonList(Field.nullable("int_column", Types.MinorType.INT.getType())));
+ try (BufferAllocator allocator = new RootAllocator(Long.MAX_VALUE);
+ VectorSchemaRoot resultData = VectorSchemaRoot.create(schema, allocator)) {
+ resultData.setRowCount(1);
+ ((IntVector) resultData.getVector(0)).set(0, 1);
+
+ try (final FallbackFlightSqlProducer rootProducer =
+ new FallbackFlightSqlProducer(resultData);
+ FlightServer rootServer =
+ FlightServer.builder(allocator, forGrpcInsecure("localhost", 0), rootProducer)
+ .build()
+ .start();
+ Connection newConnection =
+ DriverManager.getConnection(
+ String.format(
+ "jdbc:arrow-flight-sql://%s:%d/?useEncryption=false",
+ rootServer.getLocation().getUri().getHost(), rootServer.getPort()))) {
+ // This first attempt should take a measurable amount of time.
+ long start = System.nanoTime();
+ try (Statement newStatement = newConnection.createStatement()) {
+ try (ResultSet result = newStatement.executeQuery("fallback with unresolvable")) {
+ List actualData = new ArrayList<>();
+ while (result.next()) {
+ actualData.add(result.getInt(1));
+ }
+
+ // Assert
+ assertEquals(resultData.getRowCount(), actualData.size());
+ assertTrue(actualData.contains(((IntVector) resultData.getVector(0)).get(0)));
+ }
+ }
+ long attempt1 = System.nanoTime();
+ double elapsedMs = (attempt1 - start) / 1_000_000.;
+ assertTrue(
+ elapsedMs >= 5000.,
+ String.format(
+ "Expected first attempt to hit the timeout, but only %f ms elapsed", elapsedMs));
+
+ // This second attempt should take less time, since the failure from before should be
+ // cached.
+ start = System.nanoTime();
+ try (Statement newStatement = newConnection.createStatement()) {
+ try (ResultSet result = newStatement.executeQuery("fallback with unresolvable")) {
+ List actualData = new ArrayList<>();
+ while (result.next()) {
+ actualData.add(result.getInt(1));
+ }
+
+ // Assert
+ assertEquals(resultData.getRowCount(), actualData.size());
+ assertTrue(actualData.contains(((IntVector) resultData.getVector(0)).get(0)));
+ }
+ }
+ attempt1 = System.nanoTime();
+ elapsedMs = (attempt1 - start) / 1_000_000.;
+ assertTrue(
+ elapsedMs < 5000.,
+ String.format("Expected second attempt to be faster, but %f ms elapsed", elapsedMs));
+ }
+ }
+ }
+
+ @Test
+ public void testFallbackUnresolvableFlightServerDisableCache() throws Exception {
+ final Schema schema =
+ new Schema(
+ Collections.singletonList(Field.nullable("int_column", Types.MinorType.INT.getType())));
+ try (BufferAllocator allocator = new RootAllocator(Long.MAX_VALUE);
+ VectorSchemaRoot resultData = VectorSchemaRoot.create(schema, allocator)) {
+ resultData.setRowCount(1);
+ ((IntVector) resultData.getVector(0)).set(0, 1);
+
+ try (final FallbackFlightSqlProducer rootProducer =
+ new FallbackFlightSqlProducer(resultData);
+ FlightServer rootServer =
+ FlightServer.builder(allocator, forGrpcInsecure("localhost", 0), rootProducer)
+ .build()
+ .start();
+ Connection newConnection =
+ DriverManager.getConnection(
+ String.format(
+ "jdbc:arrow-flight-sql://%s:%d/?useEncryption=false&useClientCache=false",
+ rootServer.getLocation().getUri().getHost(), rootServer.getPort()))) {
+ // This first attempt should take a measurable amount of time.
+ long start = System.nanoTime();
+ try (Statement newStatement = newConnection.createStatement()) {
+ try (ResultSet result = newStatement.executeQuery("fallback with unresolvable")) {
+ List actualData = new ArrayList<>();
+ while (result.next()) {
+ actualData.add(result.getInt(1));
+ }
+
+ // Assert
+ assertEquals(resultData.getRowCount(), actualData.size());
+ assertTrue(actualData.contains(((IntVector) resultData.getVector(0)).get(0)));
+ }
+ }
+ long attempt1 = System.nanoTime();
+ double elapsedMs = (attempt1 - start) / 1_000_000.;
+ assertTrue(
+ elapsedMs >= 5000.,
+ String.format(
+ "Expected first attempt to hit the timeout, but only %f ms elapsed", elapsedMs));
+
+ // This second attempt should take a long time still, since we disabled the cache.
+ start = System.nanoTime();
+ try (Statement newStatement = newConnection.createStatement()) {
+ try (ResultSet result = newStatement.executeQuery("fallback with unresolvable")) {
+ List actualData = new ArrayList<>();
+ while (result.next()) {
+ actualData.add(result.getInt(1));
+ }
+
+ // Assert
+ assertEquals(resultData.getRowCount(), actualData.size());
+ assertTrue(actualData.contains(((IntVector) resultData.getVector(0)).get(0)));
+ }
+ }
+ attempt1 = System.nanoTime();
+ elapsedMs = (attempt1 - start) / 1_000_000.;
+ assertTrue(
+ elapsedMs >= 5000.,
+ String.format(
+ "Expected second attempt to hit the timeout, but only %f ms elapsed", elapsedMs));
+ }
+ }
+ }
+
@Test
public void testShouldRunSelectQueryWithEmptyVectorsEmbedded() throws Exception {
try (Statement statement = connection.createStatement();
diff --git a/flight/flight-sql-jdbc-core/src/test/java/org/apache/arrow/driver/jdbc/client/ArrowFlightSqlClientHandlerBuilderTest.java b/flight/flight-sql-jdbc-core/src/test/java/org/apache/arrow/driver/jdbc/client/ArrowFlightSqlClientHandlerBuilderTest.java
index 6beaba823..6524eaf39 100644
--- a/flight/flight-sql-jdbc-core/src/test/java/org/apache/arrow/driver/jdbc/client/ArrowFlightSqlClientHandlerBuilderTest.java
+++ b/flight/flight-sql-jdbc-core/src/test/java/org/apache/arrow/driver/jdbc/client/ArrowFlightSqlClientHandlerBuilderTest.java
@@ -147,6 +147,8 @@ public void testDefaults() {
assertNull(builder.clientCertificatePath);
assertNull(builder.clientKeyPath);
assertEquals(Optional.empty(), builder.catalog);
+ assertNull(builder.flightClientCache);
+ assertNull(builder.connectTimeout);
}
@Test
diff --git a/flight/flight-sql-jdbc-core/src/test/java/org/apache/arrow/driver/jdbc/client/utils/FlightClientCacheTest.java b/flight/flight-sql-jdbc-core/src/test/java/org/apache/arrow/driver/jdbc/client/utils/FlightClientCacheTest.java
new file mode 100644
index 000000000..8e818967a
--- /dev/null
+++ b/flight/flight-sql-jdbc-core/src/test/java/org/apache/arrow/driver/jdbc/client/utils/FlightClientCacheTest.java
@@ -0,0 +1,51 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.arrow.driver.jdbc.client.utils;
+
+import static org.junit.jupiter.api.Assertions.*;
+
+import org.apache.arrow.flight.Location;
+import org.junit.jupiter.api.Test;
+
+class FlightClientCacheTest {
+ @Test
+ void basicOperation() {
+ FlightClientCache cache = new FlightClientCache();
+
+ Location location1 = Location.forGrpcInsecure("localhost", 8080);
+ Location location2 = Location.forGrpcInsecure("localhost", 8081);
+
+ assertFalse(cache.isDud(location1.toString()));
+ assertFalse(cache.isDud(location2.toString()));
+
+ cache.markLocationAsReachable(location1.toString());
+ assertFalse(cache.isDud(location1.toString()));
+ assertFalse(cache.isDud(location2.toString()));
+
+ cache.markLocationAsDud(location1.toString());
+ assertTrue(cache.isDud(location1.toString()));
+ assertFalse(cache.isDud(location2.toString()));
+
+ cache.markLocationAsDud(location2.toString());
+ assertTrue(cache.isDud(location1.toString()));
+ assertTrue(cache.isDud(location2.toString()));
+
+ cache.markLocationAsReachable(location1.toString());
+ assertFalse(cache.isDud(location1.toString()));
+ assertTrue(cache.isDud(location2.toString()));
+ }
+}
diff --git a/flight/flight-sql-jdbc-core/src/test/java/org/apache/arrow/driver/jdbc/client/utils/FlightLocationQueueTest.java b/flight/flight-sql-jdbc-core/src/test/java/org/apache/arrow/driver/jdbc/client/utils/FlightLocationQueueTest.java
new file mode 100644
index 000000000..0603f86e5
--- /dev/null
+++ b/flight/flight-sql-jdbc-core/src/test/java/org/apache/arrow/driver/jdbc/client/utils/FlightLocationQueueTest.java
@@ -0,0 +1,63 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.arrow.driver.jdbc.client.utils;
+
+import static org.junit.jupiter.api.Assertions.assertEquals;
+import static org.junit.jupiter.api.Assertions.assertFalse;
+import static org.junit.jupiter.api.Assertions.assertThrows;
+import static org.junit.jupiter.api.Assertions.assertTrue;
+
+import java.util.Collections;
+import java.util.List;
+import java.util.NoSuchElementException;
+import org.apache.arrow.flight.Location;
+import org.junit.jupiter.api.Test;
+
+class FlightLocationQueueTest {
+ @Test
+ void basicOperation() {
+ Location location1 = Location.forGrpcInsecure("localhost", 8080);
+ Location location2 = Location.forGrpcInsecure("localhost", 8081);
+ FlightLocationQueue queue = new FlightLocationQueue(null, List.of(location1, location2));
+ assertTrue(queue.hasNext());
+ assertEquals(location1, queue.next());
+ assertTrue(queue.hasNext());
+ assertEquals(location2, queue.next());
+ assertFalse(queue.hasNext());
+ }
+
+ @Test
+ void badAfterGood() {
+ Location location1 = Location.forGrpcInsecure("localhost", 8080);
+ Location location2 = Location.forGrpcInsecure("localhost", 8081);
+ FlightClientCache cache = new FlightClientCache();
+ cache.markLocationAsDud(location1.toString());
+ FlightLocationQueue queue = new FlightLocationQueue(cache, List.of(location1, location2));
+ assertTrue(queue.hasNext());
+ assertEquals(location2, queue.next());
+ assertTrue(queue.hasNext());
+ assertEquals(location1, queue.next());
+ assertFalse(queue.hasNext());
+ }
+
+ @Test
+ void iteratorInvariants() {
+ FlightLocationQueue empty = new FlightLocationQueue(null, Collections.emptyList());
+ assertFalse(empty.hasNext());
+ assertThrows(NoSuchElementException.class, empty::next);
+ }
+}
diff --git a/flight/flight-sql-jdbc-core/src/test/java/org/apache/arrow/driver/jdbc/utils/ArrowFlightConnectionConfigImplTest.java b/flight/flight-sql-jdbc-core/src/test/java/org/apache/arrow/driver/jdbc/utils/ArrowFlightConnectionConfigImplTest.java
index 4a46b5f5b..f2cdfa43f 100644
--- a/flight/flight-sql-jdbc-core/src/test/java/org/apache/arrow/driver/jdbc/utils/ArrowFlightConnectionConfigImplTest.java
+++ b/flight/flight-sql-jdbc-core/src/test/java/org/apache/arrow/driver/jdbc/utils/ArrowFlightConnectionConfigImplTest.java
@@ -17,16 +17,11 @@
package org.apache.arrow.driver.jdbc.utils;
import static java.lang.Runtime.getRuntime;
-import static org.apache.arrow.driver.jdbc.utils.ArrowFlightConnectionConfigImpl.ArrowFlightConnectionProperty.CATALOG;
-import static org.apache.arrow.driver.jdbc.utils.ArrowFlightConnectionConfigImpl.ArrowFlightConnectionProperty.HOST;
-import static org.apache.arrow.driver.jdbc.utils.ArrowFlightConnectionConfigImpl.ArrowFlightConnectionProperty.PASSWORD;
-import static org.apache.arrow.driver.jdbc.utils.ArrowFlightConnectionConfigImpl.ArrowFlightConnectionProperty.PORT;
-import static org.apache.arrow.driver.jdbc.utils.ArrowFlightConnectionConfigImpl.ArrowFlightConnectionProperty.THREAD_POOL_SIZE;
-import static org.apache.arrow.driver.jdbc.utils.ArrowFlightConnectionConfigImpl.ArrowFlightConnectionProperty.USER;
-import static org.apache.arrow.driver.jdbc.utils.ArrowFlightConnectionConfigImpl.ArrowFlightConnectionProperty.USE_ENCRYPTION;
+import static org.apache.arrow.driver.jdbc.utils.ArrowFlightConnectionConfigImpl.ArrowFlightConnectionProperty.*;
import static org.hamcrest.CoreMatchers.is;
import static org.hamcrest.MatcherAssert.assertThat;
+import java.time.Duration;
import java.util.Properties;
import java.util.Random;
import java.util.function.Function;
@@ -59,49 +54,73 @@ public void setUp() {
public void testGetProperty(
ArrowFlightConnectionProperty property,
Object value,
+ Object expected,
Function configFunction) {
properties.put(property.camelName(), value);
arrowFlightConnectionConfigFunction = configFunction;
- assertThat(configFunction.apply(arrowFlightConnectionConfig), is(value));
- assertThat(arrowFlightConnectionConfigFunction.apply(arrowFlightConnectionConfig), is(value));
+ assertThat(configFunction.apply(arrowFlightConnectionConfig), is(expected));
+ assertThat(
+ arrowFlightConnectionConfigFunction.apply(arrowFlightConnectionConfig), is(expected));
}
public static Stream provideParameters() {
+ int port = RANDOM.nextInt(Short.toUnsignedInt(Short.MAX_VALUE));
+ boolean useEncryption = RANDOM.nextBoolean();
+ int threadPoolSize = RANDOM.nextInt(getRuntime().availableProcessors());
return Stream.of(
Arguments.of(
HOST,
"host",
+ "host",
(Function)
ArrowFlightConnectionConfigImpl::getHost),
Arguments.of(
PORT,
- RANDOM.nextInt(Short.toUnsignedInt(Short.MAX_VALUE)),
+ port,
+ port,
(Function)
ArrowFlightConnectionConfigImpl::getPort),
Arguments.of(
USER,
"user",
+ "user",
(Function)
ArrowFlightConnectionConfigImpl::getUser),
Arguments.of(
PASSWORD,
"password",
+ "password",
(Function)
ArrowFlightConnectionConfigImpl::getPassword),
Arguments.of(
USE_ENCRYPTION,
- RANDOM.nextBoolean(),
+ useEncryption,
+ useEncryption,
(Function)
ArrowFlightConnectionConfigImpl::useEncryption),
Arguments.of(
THREAD_POOL_SIZE,
- RANDOM.nextInt(getRuntime().availableProcessors()),
+ threadPoolSize,
+ threadPoolSize,
(Function)
ArrowFlightConnectionConfigImpl::threadPoolSize),
Arguments.of(
CATALOG,
"catalog",
+ "catalog",
+ (Function)
+ ArrowFlightConnectionConfigImpl::getCatalog),
+ Arguments.of(
+ CONNECT_TIMEOUT_MILLIS,
+ 5000,
+ Duration.ofMillis(5000),
+ (Function)
+ ArrowFlightConnectionConfigImpl::getConnectTimeout),
+ Arguments.of(
+ USE_CLIENT_CACHE,
+ false,
+ false,
(Function)
- ArrowFlightConnectionConfigImpl::getCatalog));
+ ArrowFlightConnectionConfigImpl::useClientCache));
}
}
diff --git a/flight/flight-sql-jdbc-core/src/test/java/org/apache/arrow/driver/jdbc/utils/FallbackFlightSqlProducer.java b/flight/flight-sql-jdbc-core/src/test/java/org/apache/arrow/driver/jdbc/utils/FallbackFlightSqlProducer.java
index 9aa257172..670b9e3be 100644
--- a/flight/flight-sql-jdbc-core/src/test/java/org/apache/arrow/driver/jdbc/utils/FallbackFlightSqlProducer.java
+++ b/flight/flight-sql-jdbc-core/src/test/java/org/apache/arrow/driver/jdbc/utils/FallbackFlightSqlProducer.java
@@ -109,6 +109,16 @@ private FlightInfo getFlightInfo(FlightDescriptor descriptor, String query) {
Location.forGrpcInsecure("localhost", 9999),
Location.reuseConnection())
.build());
+ } else if (query.equals("fallback with unresolvable")) {
+ endpoints =
+ Collections.singletonList(
+ FlightEndpoint.builder(
+ ticket,
+ // Inaccessible IP
+ // https://stackoverflow.com/questions/10456044/what-is-a-good-invalid-ip-address-to-use-for-unit-tests
+ Location.forGrpcInsecure("203.0.113.0", 9999),
+ Location.reuseConnection())
+ .build());
} else {
throw CallStatus.UNIMPLEMENTED.withDescription(query).toRuntimeException();
}