diff --git a/csharp/examples/Examples.sln b/csharp/examples/Examples.sln index c0a4199ca5605..46858102b3209 100644 --- a/csharp/examples/Examples.sln +++ b/csharp/examples/Examples.sln @@ -7,6 +7,10 @@ Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "FluentBuilderExample", "Flu EndProject Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Apache.Arrow", "..\src\Apache.Arrow\Apache.Arrow.csproj", "{1FE1DE95-FF6E-4895-82E7-909713C53524}" EndProject +Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "FlightAspServerExample", "FlightAspServerExample\FlightAspServerExample.csproj", "{51701AC8-5C3C-47EA-B481-56F46B8C5673}" +EndProject +Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "FlightClientExample", "FlightClientExample\FlightClientExample.csproj", "{9F54DCD2-68C2-47A9-ABE2-816068176328}" +EndProject Global GlobalSection(SolutionConfigurationPlatforms) = preSolution Debug|Any CPU = Debug|Any CPU @@ -21,6 +25,14 @@ Global {1FE1DE95-FF6E-4895-82E7-909713C53524}.Debug|Any CPU.Build.0 = Debug|Any CPU {1FE1DE95-FF6E-4895-82E7-909713C53524}.Release|Any CPU.ActiveCfg = Release|Any CPU {1FE1DE95-FF6E-4895-82E7-909713C53524}.Release|Any CPU.Build.0 = Release|Any CPU + {51701AC8-5C3C-47EA-B481-56F46B8C5673}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {51701AC8-5C3C-47EA-B481-56F46B8C5673}.Debug|Any CPU.Build.0 = Debug|Any CPU + {51701AC8-5C3C-47EA-B481-56F46B8C5673}.Release|Any CPU.ActiveCfg = Release|Any CPU + {51701AC8-5C3C-47EA-B481-56F46B8C5673}.Release|Any CPU.Build.0 = Release|Any CPU + {9F54DCD2-68C2-47A9-ABE2-816068176328}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {9F54DCD2-68C2-47A9-ABE2-816068176328}.Debug|Any CPU.Build.0 = Debug|Any CPU + {9F54DCD2-68C2-47A9-ABE2-816068176328}.Release|Any CPU.ActiveCfg = Release|Any CPU + {9F54DCD2-68C2-47A9-ABE2-816068176328}.Release|Any CPU.Build.0 = Release|Any CPU EndGlobalSection GlobalSection(SolutionProperties) = preSolution HideSolutionNode = FALSE diff --git a/csharp/src/Apache.Arrow.Flight/FlightRecordBatchStreamReader.cs b/csharp/src/Apache.Arrow.Flight/FlightRecordBatchStreamReader.cs index 588127537dfeb..7b04243a55a1c 100644 --- a/csharp/src/Apache.Arrow.Flight/FlightRecordBatchStreamReader.cs +++ b/csharp/src/Apache.Arrow.Flight/FlightRecordBatchStreamReader.cs @@ -38,11 +38,11 @@ public abstract class FlightRecordBatchStreamReader : IAsyncStreamReader flightDataStream) { - _arrowReaderImplementation = new RecordBatcReaderImplementation(flightDataStream); + _arrowReaderImplementation = new RecordBatchReaderImplementation(flightDataStream); } public ValueTask Schema => _arrowReaderImplementation.ReadSchema(); diff --git a/csharp/src/Apache.Arrow.Flight/Internal/FlightDataStream.cs b/csharp/src/Apache.Arrow.Flight/Internal/FlightDataStream.cs index 72c1551be2917..54824a6c0d40d 100644 --- a/csharp/src/Apache.Arrow.Flight/Internal/FlightDataStream.cs +++ b/csharp/src/Apache.Arrow.Flight/Internal/FlightDataStream.cs @@ -36,12 +36,14 @@ internal class FlightDataStream : ArrowStreamWriter private readonly FlightDescriptor _flightDescriptor; private readonly IAsyncStreamWriter _clientStreamWriter; private Protocol.FlightData _currentFlightData; + private ByteString _currentAppMetadata; public FlightDataStream(IAsyncStreamWriter clientStreamWriter, FlightDescriptor flightDescriptor, Schema schema) : base(new MemoryStream(), schema) { _clientStreamWriter = clientStreamWriter; _flightDescriptor = flightDescriptor; + AlwaysWriteDictionaries = true; } private async Task SendSchema() @@ -66,29 +68,42 @@ private void ResetStream() this.BaseStream.SetLength(0); } - public async Task Write(RecordBatch recordBatch, ByteString applicationMetadata) + private void ResetFlightData() { - if (!HasWrittenSchema) + _currentFlightData = new Protocol.FlightData(); + } + + private void AddMetadata() + { + if (_currentAppMetadata != null) { - await SendSchema().ConfigureAwait(false); + _currentFlightData.AppMetadata = _currentAppMetadata; } - ResetStream(); + } - _currentFlightData = new Protocol.FlightData(); + private async Task SetFlightDataBodyFromBaseStreamAsync() + { + BaseStream.Position = 0; + var body = await ByteString.FromStreamAsync(BaseStream).ConfigureAwait(false); + _currentFlightData.DataBody = body; + } + + private async Task WriteFlightDataAsync() + { + await _clientStreamWriter.WriteAsync(_currentFlightData).ConfigureAwait(false); + } - if(applicationMetadata != null) + public async Task Write(RecordBatch recordBatch, ByteString applicationMetadata) + { + _currentAppMetadata = applicationMetadata; + if (!HasWrittenSchema) { - _currentFlightData.AppMetadata = applicationMetadata; + await SendSchema().ConfigureAwait(false); } + ResetStream(); + ResetFlightData(); await WriteRecordBatchInternalAsync(recordBatch).ConfigureAwait(false); - - //Reset stream position - this.BaseStream.Position = 0; - var bodyData = await ByteString.FromStreamAsync(this.BaseStream).ConfigureAwait(false); - - _currentFlightData.DataBody = bodyData; - await _clientStreamWriter.WriteAsync(_currentFlightData).ConfigureAwait(false); } private protected override ValueTask WriteMessageAsync(MessageHeader headerType, Offset headerOffset, int bodyLength, CancellationToken cancellationToken) @@ -105,5 +120,23 @@ private protected override ValueTask WriteMessageAsync(MessageHeader he return new ValueTask(0); } + + private protected override async Task PostRecordBatchAsync() + { + // Consume the MemoryStream and write to the flight stream + await SetFlightDataBodyFromBaseStreamAsync().ConfigureAwait(false); + AddMetadata(); + await WriteFlightDataAsync().ConfigureAwait(false); + } + + private protected override async Task PostDictionaryAsync() + { + // Consume the MemoryStream and write to the flight stream + await SetFlightDataBodyFromBaseStreamAsync().ConfigureAwait(false); + await WriteFlightDataAsync().ConfigureAwait(false); + // Reset the stream for the next dictionary or record batch + ResetStream(); + ResetFlightData(); + } } } diff --git a/csharp/src/Apache.Arrow.Flight/Internal/FlightMessageSerializer.cs b/csharp/src/Apache.Arrow.Flight/Internal/FlightMessageSerializer.cs index 9df28b5033c06..7797692f0f03f 100644 --- a/csharp/src/Apache.Arrow.Flight/Internal/FlightMessageSerializer.cs +++ b/csharp/src/Apache.Arrow.Flight/Internal/FlightMessageSerializer.cs @@ -50,10 +50,8 @@ public static Schema DecodeSchema(ReadOnlyMemory buffer) return schema; } - internal static Schema DecodeSchema(ByteBuffer schemaBuffer) + internal static Schema DecodeSchema(ByteBuffer schemaBuffer, ref DictionaryMemo dictionaryMemo) { - //DictionaryBatch not supported for now - DictionaryMemo dictionaryMemo = null; var schema = MessageSerializer.GetSchema(ArrowReaderImplementation.ReadMessage(schemaBuffer), ref dictionaryMemo); return schema; } diff --git a/csharp/src/Apache.Arrow.Flight/Internal/RecordBatcReaderImplementation.cs b/csharp/src/Apache.Arrow.Flight/Internal/RecordBatchReaderImplementation.cs similarity index 79% rename from csharp/src/Apache.Arrow.Flight/Internal/RecordBatcReaderImplementation.cs rename to csharp/src/Apache.Arrow.Flight/Internal/RecordBatchReaderImplementation.cs index 10d4d731eb9f7..44a025364b601 100644 --- a/csharp/src/Apache.Arrow.Flight/Internal/RecordBatcReaderImplementation.cs +++ b/csharp/src/Apache.Arrow.Flight/Internal/RecordBatchReaderImplementation.cs @@ -25,13 +25,13 @@ namespace Apache.Arrow.Flight.Internal { - internal class RecordBatcReaderImplementation : ArrowReaderImplementation + internal class RecordBatchReaderImplementation : ArrowReaderImplementation { private readonly IAsyncStreamReader _flightDataStream; private FlightDescriptor _flightDescriptor; private readonly List _applicationMetadatas; - public RecordBatcReaderImplementation(IAsyncStreamReader streamReader) + public RecordBatchReaderImplementation(IAsyncStreamReader streamReader) { _flightDataStream = streamReader; _applicationMetadatas = new List(); @@ -87,7 +87,7 @@ public async ValueTask ReadSchema() switch (message.HeaderType) { case MessageHeader.Schema: - Schema = FlightMessageSerializer.DecodeSchema(message.ByteBuffer); + Schema = FlightMessageSerializer.DecodeSchema(message.ByteBuffer, ref _dictionaryMemo); break; default: throw new Exception($"Expected schema as the first message, but got: {message.HeaderType.ToString()}"); @@ -103,8 +103,10 @@ public override async ValueTask ReadNextRecordBatchAsync(Cancellati { await ReadSchema().ConfigureAwait(false); } - var moveNextResult = await _flightDataStream.MoveNext().ConfigureAwait(false); - if (moveNextResult) + + // Keep reading dictionary batches until we get a record batch + var keepGoing = await _flightDataStream.MoveNext().ConfigureAwait(false); + while (keepGoing) { //AppMetadata will never be null, but length 0 if empty //Those are skipped @@ -121,8 +123,17 @@ public override async ValueTask ReadNextRecordBatchAsync(Cancellati case MessageHeader.RecordBatch: var body = _flightDataStream.Current.DataBody.Memory; return CreateArrowObjectFromMessage(message, CreateByteBuffer(body.Slice(0, (int)message.BodyLength)), null); + case MessageHeader.DictionaryBatch: + var dictionaryBody = _flightDataStream.Current.DataBody.Memory; + CreateArrowObjectFromMessage(message, CreateByteBuffer(dictionaryBody.Slice(0, (int)message.BodyLength)), null); + keepGoing = await _flightDataStream.MoveNext().ConfigureAwait(false); + if (!keepGoing) + { + throw new InvalidOperationException("Flight Data Stream ended after reading dictionaries"); + } + break; default: - throw new NotImplementedException(); + throw new NotImplementedException($"Message type {message.HeaderType} is not implemented."); } } return null; diff --git a/csharp/src/Apache.Arrow/Ipc/ArrowStreamWriter.cs b/csharp/src/Apache.Arrow/Ipc/ArrowStreamWriter.cs index dcb8852bc1f65..b20235b5a42f5 100644 --- a/csharp/src/Apache.Arrow/Ipc/ArrowStreamWriter.cs +++ b/csharp/src/Apache.Arrow/Ipc/ArrowStreamWriter.cs @@ -228,6 +228,8 @@ public void Visit(IArrowArray array) private bool HasWrittenDictionaryBatch { get; set; } + protected bool AlwaysWriteDictionaries { get; set; } + private bool HasWrittenStart { get; set; } private bool HasWrittenEnd { get; set; } @@ -316,7 +318,7 @@ private protected void WriteRecordBatchInternal(RecordBatch recordBatch) HasWrittenSchema = true; } - if (!HasWrittenDictionaryBatch) + if (!HasWrittenDictionaryBatch || AlwaysWriteDictionaries) { DictionaryCollector.Collect(recordBatch, ref _dictionaryMemo); WriteDictionaries(recordBatch); @@ -342,6 +344,8 @@ private protected void WriteRecordBatchInternal(RecordBatch recordBatch) long bufferLength = WriteBufferData(recordBatchBuilder.Buffers); FinishedWritingRecordBatch(bufferLength, metadataLength); + + PostRecordBatch(); } private protected async Task WriteRecordBatchInternalAsync(RecordBatch recordBatch, @@ -355,7 +359,7 @@ private protected async Task WriteRecordBatchInternalAsync(RecordBatch recordBat HasWrittenSchema = true; } - if (!HasWrittenDictionaryBatch) + if (!HasWrittenDictionaryBatch || AlwaysWriteDictionaries) { DictionaryCollector.Collect(recordBatch, ref _dictionaryMemo); await WriteDictionariesAsync(recordBatch, cancellationToken).ConfigureAwait(false); @@ -382,6 +386,8 @@ private protected async Task WriteRecordBatchInternalAsync(RecordBatch recordBat long bufferLength = await WriteBufferDataAsync(recordBatchBuilder.Buffers, cancellationToken).ConfigureAwait(false); FinishedWritingRecordBatch(bufferLength, metadataLength); + + await PostRecordBatchAsync().ConfigureAwait(false); } private long WriteBufferData(IReadOnlyList buffers) @@ -492,7 +498,6 @@ private Tuple PreparingWritingR return Tuple.Create(recordBatchBuilder, fieldNodesVectorOffset); } - private protected void WriteDictionaries(RecordBatch recordBatch) { foreach (Field field in recordBatch.Schema.FieldsList) @@ -522,6 +527,8 @@ private protected void WriteDictionary(Field field) dictionaryBatchOffset, recordBatchBuilder.TotalLength); WriteBufferData(recordBatchBuilder.Buffers); + + PostDictionary(); } private protected async Task WriteDictionariesAsync(RecordBatch recordBatch, CancellationToken cancellationToken) @@ -553,6 +560,8 @@ await WriteMessageAsync(Flatbuf.MessageHeader.DictionaryBatch, dictionaryBatchOffset, recordBatchBuilder.TotalLength, cancellationToken).ConfigureAwait(false); await WriteBufferDataAsync(recordBatchBuilder.Buffers, cancellationToken).ConfigureAwait(false); + + await PostDictionaryAsync().ConfigureAwait(false); } private Tuple> CreateDictionaryBatchOffset(Field field) @@ -616,6 +625,24 @@ private protected virtual void FinishedWritingRecordBatch(long bodyLength, long { } + private protected virtual void PostRecordBatch() + { + } + + private protected virtual void PostDictionary() + { + } + + private protected virtual async Task PostRecordBatchAsync() + { + await Task.CompletedTask; + } + + private protected virtual async Task PostDictionaryAsync() + { + await Task.CompletedTask; + } + public virtual void WriteRecordBatch(RecordBatch recordBatch) { WriteRecordBatchInternal(recordBatch); diff --git a/csharp/src/Apache.Arrow/Ipc/DictionaryMemo.cs b/csharp/src/Apache.Arrow/Ipc/DictionaryMemo.cs index 24f25a142966c..c35b59f391ed0 100644 --- a/csharp/src/Apache.Arrow/Ipc/DictionaryMemo.cs +++ b/csharp/src/Apache.Arrow/Ipc/DictionaryMemo.cs @@ -24,13 +24,13 @@ class DictionaryMemo { private readonly Dictionary _idToDictionary; private readonly Dictionary _idToValueType; - private readonly Dictionary _fieldToId; + private readonly Dictionary _fieldToId; public DictionaryMemo() { _idToDictionary = new Dictionary(); _idToValueType = new Dictionary(); - _fieldToId = new Dictionary(); + _fieldToId = new Dictionary(); } public IArrowType GetDictionaryType(long id) @@ -53,7 +53,7 @@ public IArrowArray GetDictionary(long id) public void AddField(long id, Field field) { - if (_fieldToId.ContainsKey(field)) + if (_fieldToId.ContainsKey(field.Name)) { throw new ArgumentException($"Field {field.Name} is already in Memo"); } @@ -73,13 +73,13 @@ public void AddField(long id, Field field) } } - _fieldToId.Add(field, id); + _fieldToId.Add(field.Name, id); _idToValueType.Add(id, valueType); } public long GetId(Field field) { - if (!_fieldToId.TryGetValue(field, out long id)) + if (!_fieldToId.TryGetValue(field.Name, out long id)) { throw new ArgumentException($"Field with name {field.Name} not found"); } @@ -88,7 +88,7 @@ public long GetId(Field field) public long GetOrAssignId(Field field) { - if (!_fieldToId.TryGetValue(field, out long id)) + if (!_fieldToId.TryGetValue(field.Name, out long id)) { id = _fieldToId.Count + 1; AddField(id, field); diff --git a/csharp/test/Apache.Arrow.Flight.Tests/FlightTests.cs b/csharp/test/Apache.Arrow.Flight.Tests/FlightTests.cs index 267fe4e4b606d..3fa7f480963cc 100644 --- a/csharp/test/Apache.Arrow.Flight.Tests/FlightTests.cs +++ b/csharp/test/Apache.Arrow.Flight.Tests/FlightTests.cs @@ -20,6 +20,7 @@ using Apache.Arrow.Flight.Client; using Apache.Arrow.Flight.TestWeb; using Apache.Arrow.Tests; +using Apache.Arrow.Types; using Google.Protobuf; using Grpc.Core.Utils; using Xunit; @@ -52,6 +53,10 @@ private RecordBatch CreateTestBatch(int startValue, int length) builder.Append(startValue + i); } batchBuilder.Append("test", true, builder.Build()); + var keys = new UInt16Array.Builder().AppendRange(Enumerable.Range(startValue, length).Select(i => (ushort)i)).Build(); + var dictionary = new StringArray.Builder().AppendRange(Enumerable.Range(startValue, length).Select(i => i.ToString())).Build(); + var dictArray = new DictionaryArray(new DictionaryType(UInt16Type.Default, StringType.Default, false), keys, dictionary); + batchBuilder.Append("dict", true, dictArray); return batchBuilder.Build(); } @@ -187,8 +192,8 @@ public async Task TestGetFlightMetadata() var getStream = _flightClient.GetStream(endpoint.Ticket); - List actualMetadata = new List(); - while(await getStream.ResponseStream.MoveNext(default)) + List actualMetadata = new List(); + while (await getStream.ResponseStream.MoveNext(default)) { actualMetadata.AddRange(getStream.ResponseStream.ApplicationMetadata); }