diff --git a/java/flight/flight-core/src/main/java/org/apache/arrow/flight/FlightClient.java b/java/flight/flight-core/src/main/java/org/apache/arrow/flight/FlightClient.java index 422ed01c394d1..91e3b4d052f39 100644 --- a/java/flight/flight-core/src/main/java/org/apache/arrow/flight/FlightClient.java +++ b/java/flight/flight-core/src/main/java/org/apache/arrow/flight/FlightClient.java @@ -429,6 +429,16 @@ public ClientStreamListener getWriter() { return writer; } + /** + * Make sure stream is drained. You must call this to be notified of any errors that may have + * happened after the exchange is complete. This should be called after `getWriter().completed()` + * and instead of `getWriter().getResult()`. + */ + public void getResult() { + // After exchange is complete, make sure stream is drained to propagate errors through reader + while (reader.next()) { }; + } + /** Shut down the streams in this call. */ @Override public void close() throws Exception { diff --git a/java/flight/flight-core/src/main/java/org/apache/arrow/flight/FlightStream.java b/java/flight/flight-core/src/main/java/org/apache/arrow/flight/FlightStream.java index 03ce13c9780e3..ad4ffcbebdec1 100644 --- a/java/flight/flight-core/src/main/java/org/apache/arrow/flight/FlightStream.java +++ b/java/flight/flight-core/src/main/java/org/apache/arrow/flight/FlightStream.java @@ -207,6 +207,8 @@ public void close() throws Exception { } else { AutoCloseables.close(closeables); } + // Remove any metadata after closing to prevent negative refcnt + applicationMetadata = null; } finally { // The value of this CompletableFuture is meaningless, only whether it's completed (or has an exception) // No-op if already complete diff --git a/java/flight/flight-core/src/test/java/org/apache/arrow/flight/TestDoExchange.java b/java/flight/flight-core/src/test/java/org/apache/arrow/flight/TestDoExchange.java index c2f8e75596904..f9db9bfd23a88 100644 --- a/java/flight/flight-core/src/test/java/org/apache/arrow/flight/TestDoExchange.java +++ b/java/flight/flight-core/src/test/java/org/apache/arrow/flight/TestDoExchange.java @@ -55,6 +55,7 @@ public class TestDoExchange { static byte[] EXCHANGE_METADATA_ONLY = "only-metadata".getBytes(StandardCharsets.UTF_8); static byte[] EXCHANGE_TRANSFORM = "transform".getBytes(StandardCharsets.UTF_8); static byte[] EXCHANGE_CANCEL = "cancel".getBytes(StandardCharsets.UTF_8); + static byte[] EXCHANGE_ERROR = "error".getBytes(StandardCharsets.UTF_8); private BufferAllocator allocator; private FlightServer server; @@ -365,6 +366,37 @@ public void testClientCancel() throws Exception { } } + /** Test a DoExchange error handling. */ + @Test + public void testDoExchangeError() throws Exception { + final Schema schema = new Schema(Collections.singletonList(Field.nullable("a", new ArrowType.Int(32, true)))); + try (final FlightClient.ExchangeReaderWriter stream = client.doExchange(FlightDescriptor.command(EXCHANGE_ERROR)); + final VectorSchemaRoot root = VectorSchemaRoot.create(schema, allocator)) { + final FlightStream reader = stream.getReader(); + + // Write data and check that it gets echoed back. + IntVector iv = (IntVector) root.getVector("a"); + iv.allocateNew(); + stream.getWriter().start(root); + for (int i = 0; i < 10; i++) { + iv.setSafe(0, i); + root.setRowCount(1); + stream.getWriter().putNext(); + + assertTrue(reader.next()); + assertEquals(root.getSchema(), reader.getSchema()); + assertEquals(i, ((IntVector) reader.getRoot().getVector("a")).get(0)); + } + + // Complete the stream so that the server knows not to expect any more messages from us. + stream.getWriter().completed(); + + // Must call reader.next() to get any errors after exchange, will return false if no error + final FlightRuntimeException fre = assertThrows(FlightRuntimeException.class, stream::getResult); + assertEquals("error completing exchange", fre.status().description()); + } + } + /** Have the client close the stream without reading; ensure memory is not leaked. */ @Test public void testClientClose() throws Exception { @@ -381,6 +413,38 @@ public void testClientClose() throws Exception { client = null; } + /** Test closing with Metadata can't lead to error. */ + @Test + public void testCloseWithMetadata() throws Exception { + // Send a particular descriptor to the server and check for a particular response pattern. + try (final FlightClient.ExchangeReaderWriter stream = + client.doExchange(FlightDescriptor.command(EXCHANGE_METADATA_ONLY))) { + final FlightStream reader = stream.getReader(); + + // Server starts by sending a message without data (hence no VectorSchemaRoot should be present) + assertTrue(reader.next()); + assertFalse(reader.hasRoot()); + assertEquals(42, reader.getLatestMetadata().getInt(0)); + + // Write a metadata message to the server (without sending any data) + ArrowBuf buf = allocator.buffer(4); + buf.writeInt(84); + stream.getWriter().putMetadata(buf); + + // Check that the server echoed the metadata back to us + assertTrue(reader.next()); + assertFalse(reader.hasRoot()); + assertEquals(84, reader.getLatestMetadata().getInt(0)); + + // Close our write channel and ensure the server also closes theirs + stream.getWriter().completed(); + stream.getResult(); + + // Not necessary to close reader here, but check closing twice doesn't lead to negative refcnt from metadata + stream.getReader().close(); + } + } + static class Producer extends NoOpFlightProducer { static final Schema SCHEMA = new Schema( Collections.singletonList(Field.nullable("a", new ArrowType.Int(32, true)))); @@ -404,6 +468,8 @@ public void doExchange(CallContext context, FlightStream reader, ServerStreamLis transform(context, reader, writer); } else if (Arrays.equals(reader.getDescriptor().getCommand(), EXCHANGE_CANCEL)) { cancel(context, reader, writer); + } else if (Arrays.equals(reader.getDescriptor().getCommand(), EXCHANGE_ERROR)) { + error(context, reader, writer); } else { writer.error(CallStatus.UNIMPLEMENTED.withDescription("Command not implemented").toRuntimeException()); } @@ -534,5 +600,30 @@ private void transform(CallContext context, FlightStream reader, ServerStreamLis private void cancel(CallContext context, FlightStream reader, ServerStreamListener writer) { writer.error(CallStatus.CANCELLED.withDescription("expected").toRuntimeException()); } + + private void error(CallContext context, FlightStream reader, ServerStreamListener writer) { + VectorSchemaRoot root = null; + VectorLoader loader = null; + while (reader.next()) { + + if (root == null) { + root = VectorSchemaRoot.create(reader.getSchema(), allocator); + loader = new VectorLoader(root); + writer.start(root); + } + VectorUnloader unloader = new VectorUnloader(reader.getRoot()); + try (final ArrowRecordBatch arb = unloader.getRecordBatch()) { + loader.load(arb); + } + + writer.putNext(); + } + if (root != null) { + root.close(); + } + + // An error occurs before completing the writer + writer.error(CallStatus.INTERNAL.withDescription("error completing exchange").toRuntimeException()); + } } }