Skip to content

Commit

Permalink
apacheGH-494: [Flight] Improve handling of unreachable locations in JDBC
Browse files Browse the repository at this point in the history
- Expose gRPC for the client builder
- Cache failed locations and try them last
- Allow configuring the connect timeout

Fixes apache#494.
  • Loading branch information
lidavidm committed Jan 9, 2025
1 parent 3440633 commit 3c34f84
Show file tree
Hide file tree
Showing 15 changed files with 787 additions and 190 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -23,29 +23,21 @@
import io.grpc.ManagedChannel;
import io.grpc.MethodDescriptor;
import io.grpc.StatusRuntimeException;
import io.grpc.netty.GrpcSslContexts;
import io.grpc.netty.NettyChannelBuilder;
import io.grpc.stub.ClientCallStreamObserver;
import io.grpc.stub.ClientCalls;
import io.grpc.stub.ClientResponseObserver;
import io.grpc.stub.StreamObserver;
import io.netty.channel.EventLoopGroup;
import io.netty.channel.ServerChannel;
import io.netty.handler.ssl.SslContextBuilder;
import io.netty.handler.ssl.util.InsecureTrustManagerFactory;
import java.io.IOException;
import java.io.InputStream;
import java.lang.reflect.InvocationTargetException;
import java.net.URISyntaxException;
import java.nio.ByteBuffer;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.Optional;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.TimeUnit;
import java.util.function.BooleanSupplier;
import javax.net.ssl.SSLException;
import org.apache.arrow.flight.FlightProducer.StreamListener;
import org.apache.arrow.flight.auth.BasicClientAuthHandler;
import org.apache.arrow.flight.auth.ClientAuthHandler;
Expand All @@ -57,6 +49,7 @@
import org.apache.arrow.flight.auth2.ClientIncomingAuthHeaderMiddleware;
import org.apache.arrow.flight.grpc.ClientInterceptorAdapter;
import org.apache.arrow.flight.grpc.CredentialCallOption;
import org.apache.arrow.flight.grpc.NettyClientBuilder;
import org.apache.arrow.flight.grpc.StatusUtils;
import org.apache.arrow.flight.impl.Flight;
import org.apache.arrow.flight.impl.Flight.Empty;
Expand All @@ -72,11 +65,6 @@
/** Client for Flight services. */
public class FlightClient implements AutoCloseable {
private static final int PENDING_REQUESTS = 5;
/**
* The maximum number of trace events to keep on the gRPC Channel. This value disables channel
* tracing.
*/
private static final int MAX_CHANNEL_TRACE_EVENTS = 0;

private final BufferAllocator allocator;
private final ManagedChannel channel;
Expand Down Expand Up @@ -771,176 +759,71 @@ public static Builder builder(BufferAllocator allocator, Location location) {

/** A builder for Flight clients. */
public static final class Builder {
private BufferAllocator allocator;
private Location location;
private boolean forceTls = false;
private int maxInboundMessageSize = FlightServer.MAX_GRPC_MESSAGE_SIZE;
private InputStream trustedCertificates = null;
private InputStream clientCertificate = null;
private InputStream clientKey = null;
private String overrideHostname = null;
private List<FlightClientMiddleware.Factory> middleware = new ArrayList<>();
private boolean verifyServer = true;

private Builder() {}
private final NettyClientBuilder builder;

private Builder() {
this.builder = new NettyClientBuilder();
}

private Builder(BufferAllocator allocator, Location location) {
this.allocator = Preconditions.checkNotNull(allocator);
this.location = Preconditions.checkNotNull(location);
this.builder = new NettyClientBuilder(allocator, location);
}

/** Force the client to connect over TLS. */
public Builder useTls() {
this.forceTls = true;
builder.useTls();
return this;
}

/** Override the hostname checked for TLS. Use with caution in production. */
public Builder overrideHostname(final String hostname) {
this.overrideHostname = hostname;
builder.overrideHostname(hostname);
return this;
}

/** Set the maximum inbound message size. */
public Builder maxInboundMessageSize(int maxSize) {
Preconditions.checkArgument(maxSize > 0);
this.maxInboundMessageSize = maxSize;
builder.maxInboundMessageSize(maxSize);
return this;
}

/** Set the trusted TLS certificates. */
public Builder trustedCertificates(final InputStream stream) {
this.trustedCertificates = Preconditions.checkNotNull(stream);
builder.trustedCertificates(stream);
return this;
}

/** Set the trusted TLS certificates. */
public Builder clientCertificate(
final InputStream clientCertificate, final InputStream clientKey) {
Preconditions.checkNotNull(clientKey);
this.clientCertificate = Preconditions.checkNotNull(clientCertificate);
this.clientKey = Preconditions.checkNotNull(clientKey);
builder.clientCertificate(clientCertificate, clientKey);
return this;
}

public Builder allocator(BufferAllocator allocator) {
this.allocator = Preconditions.checkNotNull(allocator);
builder.allocator(allocator);
return this;
}

public Builder location(Location location) {
this.location = Preconditions.checkNotNull(location);
builder.location(location);
return this;
}

public Builder intercept(FlightClientMiddleware.Factory factory) {
middleware.add(factory);
builder.intercept(factory);
return this;
}

public Builder verifyServer(boolean verifyServer) {
this.verifyServer = verifyServer;
builder.verifyServer(verifyServer);
return this;
}

/** Create the client from this builder. */
public FlightClient build() {
final NettyChannelBuilder builder;

switch (location.getUri().getScheme()) {
case LocationSchemes.GRPC:
case LocationSchemes.GRPC_INSECURE:
case LocationSchemes.GRPC_TLS:
{
builder = NettyChannelBuilder.forAddress(location.toSocketAddress());
break;
}
case LocationSchemes.GRPC_DOMAIN_SOCKET:
{
// The implementation is platform-specific, so we have to find the classes at runtime
builder = NettyChannelBuilder.forAddress(location.toSocketAddress());
try {
try {
// Linux
builder.channelType(
Class.forName("io.netty.channel.epoll.EpollDomainSocketChannel")
.asSubclass(ServerChannel.class));
final EventLoopGroup elg =
Class.forName("io.netty.channel.epoll.EpollEventLoopGroup")
.asSubclass(EventLoopGroup.class)
.getDeclaredConstructor()
.newInstance();
builder.eventLoopGroup(elg);
} catch (ClassNotFoundException e) {
// BSD
builder.channelType(
Class.forName("io.netty.channel.kqueue.KQueueDomainSocketChannel")
.asSubclass(ServerChannel.class));
final EventLoopGroup elg =
Class.forName("io.netty.channel.kqueue.KQueueEventLoopGroup")
.asSubclass(EventLoopGroup.class)
.getDeclaredConstructor()
.newInstance();
builder.eventLoopGroup(elg);
}
} catch (ClassNotFoundException
| InstantiationException
| IllegalAccessException
| NoSuchMethodException
| InvocationTargetException e) {
throw new UnsupportedOperationException(
"Could not find suitable Netty native transport implementation for domain socket address.");
}
break;
}
default:
throw new IllegalArgumentException(
"Scheme is not supported: " + location.getUri().getScheme());
}

if (this.forceTls || LocationSchemes.GRPC_TLS.equals(location.getUri().getScheme())) {
builder.useTransportSecurity();

final boolean hasTrustedCerts = this.trustedCertificates != null;
final boolean hasKeyCertPair = this.clientCertificate != null && this.clientKey != null;
if (!this.verifyServer && (hasTrustedCerts || hasKeyCertPair)) {
throw new IllegalArgumentException(
"FlightClient has been configured to disable server verification, "
+ "but certificate options have been specified.");
}

final SslContextBuilder sslContextBuilder = GrpcSslContexts.forClient();

if (!this.verifyServer) {
sslContextBuilder.trustManager(InsecureTrustManagerFactory.INSTANCE);
} else if (this.trustedCertificates != null
|| this.clientCertificate != null
|| this.clientKey != null) {
if (this.trustedCertificates != null) {
sslContextBuilder.trustManager(this.trustedCertificates);
}
if (this.clientCertificate != null && this.clientKey != null) {
sslContextBuilder.keyManager(this.clientCertificate, this.clientKey);
}
}
try {
builder.sslContext(sslContextBuilder.build());
} catch (SSLException e) {
throw new RuntimeException(e);
}

if (this.overrideHostname != null) {
builder.overrideAuthority(this.overrideHostname);
}
} else {
builder.usePlaintext();
}

builder
.maxTraceEvents(MAX_CHANNEL_TRACE_EVENTS)
.maxInboundMessageSize(maxInboundMessageSize)
.maxInboundMetadataSize(maxInboundMessageSize);
return new FlightClient(allocator, builder.build(), middleware);
final NettyChannelBuilder channelBuilder = builder.build();
return new FlightClient(builder.allocator(), channelBuilder.build(), builder.middleware());
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import io.grpc.ManagedChannel;
import io.grpc.MethodDescriptor;
import java.util.Collections;
import java.util.List;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.TimeUnit;
import org.apache.arrow.flight.auth.ServerAuthHandler;
Expand Down Expand Up @@ -151,6 +152,19 @@ public static FlightClient createFlightClient(
return new FlightClient(incomingAllocator, channel, Collections.emptyList());
}

/**
* Creates a Flight client.
*
* @param incomingAllocator Memory allocator
* @param channel provides a connection to a gRPC server.
*/
public static FlightClient createFlightClient(
BufferAllocator incomingAllocator,
ManagedChannel channel,
List<FlightClientMiddleware.Factory> middleware) {
return new FlightClient(incomingAllocator, channel, middleware);
}

/**
* Creates a Flight client.
*
Expand Down
Loading

0 comments on commit 3c34f84

Please sign in to comment.