From d5cb41aa3f95576660fea6f604180c986c883a96 Mon Sep 17 00:00:00 2001 From: Curt Hagenlocher Date: Sun, 10 Dec 2023 06:51:07 -0800 Subject: [PATCH] GH-32662: [C#] Make dictionaries in file and memory implementations work correctly and support integration tests (#39146) ### Rationale for this change While dictionary support was implemented for C# in #6870 for streams, support did not extend to files or memory buffers. This change rectifies that. ### What changes are included in this PR? Changes to the memory and file implementations to support reading and writing of dictionaries, including nested dictionaries. Changes to the integration tests so that they work with dictionaries. Enabling the dictionary tests in CI. ### Are these changes tested? Yes, both directly and indirectly via the integration tests. ### Are there any user-facing changes? No. * Closes: #32662 Authored-by: Curt Hagenlocher Signed-off-by: Curt Hagenlocher --- .../Ipc/ArrowFileReaderImplementation.cs | 34 +++++ .../src/Apache.Arrow/Ipc/ArrowFileWriter.cs | 54 +++++++- csharp/src/Apache.Arrow/Ipc/ArrowFooter.cs | 4 +- .../Ipc/ArrowMemoryReaderImplementation.cs | 60 ++++---- .../Ipc/ArrowStreamReaderImplementation.cs | 128 +++++++++++------- .../src/Apache.Arrow/Ipc/ArrowStreamWriter.cs | 78 +++++------ csharp/src/Apache.Arrow/Ipc/DictionaryMemo.cs | 10 +- .../src/Apache.Arrow/Types/DictionaryType.cs | 3 +- .../ArrowReaderBenchmark.cs | 2 +- .../IntegrationCommand.cs | 23 ++-- .../Apache.Arrow.IntegrationTest/JsonFile.cs | 106 +++++++++++++-- .../ArrowFileReaderTests.cs | 6 +- csharp/test/Apache.Arrow.Tests/TestData.cs | 2 +- dev/archery/archery/integration/datagen.py | 5 +- .../archery/integration/tester_csharp.py | 4 +- docs/source/status.rst | 2 +- 16 files changed, 352 insertions(+), 169 deletions(-) diff --git a/csharp/src/Apache.Arrow/Ipc/ArrowFileReaderImplementation.cs b/csharp/src/Apache.Arrow/Ipc/ArrowFileReaderImplementation.cs index d88665e496dc9..3ae475885f16a 100644 --- a/csharp/src/Apache.Arrow/Ipc/ArrowFileReaderImplementation.cs +++ b/csharp/src/Apache.Arrow/Ipc/ArrowFileReaderImplementation.cs @@ -35,6 +35,8 @@ internal sealed class ArrowFileReaderImplementation : ArrowStreamReaderImplement private ArrowFooter _footer; + private bool HasReadDictionaries => HasReadSchema && DictionaryMemo.LoadedDictionaryCount >= _footer.DictionaryCount; + public ArrowFileReaderImplementation(Stream stream, MemoryAllocator allocator, ICompressionCodecFactory compressionCodecFactory, bool leaveOpen) : base(stream, allocator, compressionCodecFactory, leaveOpen) { @@ -143,6 +145,7 @@ private void ReadSchema(Memory buffer) public async ValueTask ReadRecordBatchAsync(int index, CancellationToken cancellationToken) { await ReadSchemaAsync().ConfigureAwait(false); + await ReadDictionariesAsync(cancellationToken).ConfigureAwait(false); if (index >= _footer.RecordBatchCount) { @@ -159,6 +162,7 @@ public async ValueTask ReadRecordBatchAsync(int index, Cancellation public RecordBatch ReadRecordBatch(int index) { ReadSchema(); + ReadDictionaries(); if (index >= _footer.RecordBatchCount) { @@ -175,6 +179,7 @@ public RecordBatch ReadRecordBatch(int index) public override async ValueTask ReadNextRecordBatchAsync(CancellationToken cancellationToken) { await ReadSchemaAsync().ConfigureAwait(false); + await ReadDictionariesAsync(cancellationToken).ConfigureAwait(false); if (_recordBatchIndex >= _footer.RecordBatchCount) { @@ -190,6 +195,7 @@ public override async ValueTask ReadNextRecordBatchAsync(Cancellati public override RecordBatch ReadNextRecordBatch() { ReadSchema(); + ReadDictionaries(); if (_recordBatchIndex >= _footer.RecordBatchCount) { @@ -202,6 +208,34 @@ public override RecordBatch ReadNextRecordBatch() return result; } + private async ValueTask ReadDictionariesAsync(CancellationToken cancellationToken = default) + { + if (HasReadDictionaries) + { + return; + } + + foreach (Block block in _footer.Dictionaries) + { + BaseStream.Position = block.Offset; + await ReadMessageAsync(cancellationToken); + } + } + + private void ReadDictionaries() + { + if (HasReadDictionaries) + { + return; + } + + foreach (Block block in _footer.Dictionaries) + { + BaseStream.Position = block.Offset; + ReadMessage(); + } + } + /// /// Check if file format is valid. If it's valid don't run the validation again. /// diff --git a/csharp/src/Apache.Arrow/Ipc/ArrowFileWriter.cs b/csharp/src/Apache.Arrow/Ipc/ArrowFileWriter.cs index 4fefb121cb669..95b9f60fffe0f 100644 --- a/csharp/src/Apache.Arrow/Ipc/ArrowFileWriter.cs +++ b/csharp/src/Apache.Arrow/Ipc/ArrowFileWriter.cs @@ -23,10 +23,12 @@ namespace Apache.Arrow.Ipc { - public class ArrowFileWriter: ArrowStreamWriter + public class ArrowFileWriter : ArrowStreamWriter { private long _currentRecordBatchOffset = -1; + private long _currentDictionaryOffset = -1; + private List DictionaryBlocks { get; set; } private List RecordBatchBlocks { get; } public ArrowFileWriter(Stream stream, Schema schema) @@ -105,6 +107,34 @@ private protected override void FinishedWritingRecordBatch(long bodyLength, long _currentRecordBatchOffset = -1; } + private protected override void StartingWritingDictionary() + { + if (DictionaryBlocks == null) { DictionaryBlocks = new List(); } + _currentDictionaryOffset = BaseStream.Position; + } + + private protected override void FinishedWritingDictionary(long bodyLength, long metadataLength) + { + // Dictionaries only appear after a Schema is written, so the dictionary offsets must + // always be greater than 0. + Debug.Assert(_currentDictionaryOffset > 0, "_currentDictionaryOffset must be positive."); + + int metadataLengthInt = checked((int)metadataLength); + + Debug.Assert(BitUtility.IsMultipleOf8(_currentDictionaryOffset)); + Debug.Assert(BitUtility.IsMultipleOf8(metadataLengthInt)); + Debug.Assert(BitUtility.IsMultipleOf8(bodyLength)); + + var block = new Block( + offset: _currentDictionaryOffset, + length: bodyLength, + metadataLength: metadataLengthInt); + + DictionaryBlocks.Add(block); + + _currentDictionaryOffset = -1; + } + private protected override void WriteEndInternal() { base.WriteEndInternal(); @@ -161,9 +191,16 @@ private void WriteFooter(Schema schema) Google.FlatBuffers.VectorOffset recordBatchesVectorOffset = Builder.EndVector(); // Serialize all dictionaries - // NOTE: Currently unsupported. - Flatbuf.Footer.StartDictionariesVector(Builder, 0); + int dictionaryCount = DictionaryBlocks?.Count ?? 0; + Flatbuf.Footer.StartDictionariesVector(Builder, dictionaryCount); + + for (int i = dictionaryCount - 1; i >= 0; i--) + { + Block dictionary = DictionaryBlocks[i]; + Flatbuf.Block.CreateBlock( + Builder, dictionary.Offset, dictionary.MetadataLength, dictionary.BodyLength); + } Google.FlatBuffers.VectorOffset dictionaryBatchesOffset = Builder.EndVector(); @@ -221,9 +258,16 @@ private async Task WriteFooterAsync(Schema schema, CancellationToken cancellatio Google.FlatBuffers.VectorOffset recordBatchesVectorOffset = Builder.EndVector(); // Serialize all dictionaries - // NOTE: Currently unsupported. - Flatbuf.Footer.StartDictionariesVector(Builder, 0); + int dictionaryCount = DictionaryBlocks?.Count ?? 0; + Flatbuf.Footer.StartDictionariesVector(Builder, dictionaryCount); + + for (int i = dictionaryCount - 1; i >= 0; i--) + { + Block dictionary = DictionaryBlocks[i]; + Flatbuf.Block.CreateBlock( + Builder, dictionary.Offset, dictionary.MetadataLength, dictionary.BodyLength); + } Google.FlatBuffers.VectorOffset dictionaryBatchesOffset = Builder.EndVector(); diff --git a/csharp/src/Apache.Arrow/Ipc/ArrowFooter.cs b/csharp/src/Apache.Arrow/Ipc/ArrowFooter.cs index db269ae019b51..600624ef9ef12 100644 --- a/csharp/src/Apache.Arrow/Ipc/ArrowFooter.cs +++ b/csharp/src/Apache.Arrow/Ipc/ArrowFooter.cs @@ -25,8 +25,8 @@ internal class ArrowFooter private readonly List _dictionaries; private readonly List _recordBatches; - public IEnumerable Dictionaries => _dictionaries; - public IEnumerable RecordBatches => _recordBatches; + public IReadOnlyList Dictionaries => _dictionaries; + public IReadOnlyList RecordBatches => _recordBatches; public Block GetRecordBatchBlock(int i) => _recordBatches[i]; diff --git a/csharp/src/Apache.Arrow/Ipc/ArrowMemoryReaderImplementation.cs b/csharp/src/Apache.Arrow/Ipc/ArrowMemoryReaderImplementation.cs index af4f963ee520f..6e2336a591bf1 100644 --- a/csharp/src/Apache.Arrow/Ipc/ArrowMemoryReaderImplementation.cs +++ b/csharp/src/Apache.Arrow/Ipc/ArrowMemoryReaderImplementation.cs @@ -43,30 +43,17 @@ public override RecordBatch ReadNextRecordBatch() { ReadSchema(); - if (_buffer.Length <= _bufferPosition + sizeof(int)) + RecordBatch batch = null; + while (batch == null) { - // reached the end - return null; - } - - // Get Length of record batch for message header. - int messageLength = BinaryPrimitives.ReadInt32LittleEndian(_buffer.Span.Slice(_bufferPosition)); - _bufferPosition += sizeof(int); - - if (messageLength == 0) - { - //reached the end - return null; - } - else if (messageLength == MessageSerializer.IpcContinuationToken) - { - // ARROW-6313, if the first 4 bytes are continuation message, read the next 4 for the length if (_buffer.Length <= _bufferPosition + sizeof(int)) { - throw new InvalidDataException("Corrupted IPC message. Received a continuation token at the end of the message."); + // reached the end + return null; } - messageLength = BinaryPrimitives.ReadInt32LittleEndian(_buffer.Span.Slice(_bufferPosition)); + // Get Length of record batch for message header. + int messageLength = BinaryPrimitives.ReadInt32LittleEndian(_buffer.Span.Slice(_bufferPosition)); _bufferPosition += sizeof(int); if (messageLength == 0) @@ -74,17 +61,36 @@ public override RecordBatch ReadNextRecordBatch() //reached the end return null; } - } + else if (messageLength == MessageSerializer.IpcContinuationToken) + { + // ARROW-6313, if the first 4 bytes are continuation message, read the next 4 for the length + if (_buffer.Length <= _bufferPosition + sizeof(int)) + { + throw new InvalidDataException("Corrupted IPC message. Received a continuation token at the end of the message."); + } + + messageLength = BinaryPrimitives.ReadInt32LittleEndian(_buffer.Span.Slice(_bufferPosition)); + _bufferPosition += sizeof(int); + + if (messageLength == 0) + { + //reached the end + return null; + } + } + + Message message = Message.GetRootAsMessage( + CreateByteBuffer(_buffer.Slice(_bufferPosition, messageLength))); + _bufferPosition += messageLength; - Message message = Message.GetRootAsMessage( - CreateByteBuffer(_buffer.Slice(_bufferPosition, messageLength))); - _bufferPosition += messageLength; + int bodyLength = (int)message.BodyLength; + ByteBuffer bodybb = CreateByteBuffer(_buffer.Slice(_bufferPosition, bodyLength)); + _bufferPosition += bodyLength; - int bodyLength = (int)message.BodyLength; - ByteBuffer bodybb = CreateByteBuffer(_buffer.Slice(_bufferPosition, bodyLength)); - _bufferPosition += bodyLength; + batch = CreateArrowObjectFromMessage(message, bodybb, memoryOwner: null); + } - return CreateArrowObjectFromMessage(message, bodybb, memoryOwner: null); + return batch; } private void ReadSchema() diff --git a/csharp/src/Apache.Arrow/Ipc/ArrowStreamReaderImplementation.cs b/csharp/src/Apache.Arrow/Ipc/ArrowStreamReaderImplementation.cs index df80ffe1e0fa5..184e0348e5e07 100644 --- a/csharp/src/Apache.Arrow/Ipc/ArrowStreamReaderImplementation.cs +++ b/csharp/src/Apache.Arrow/Ipc/ArrowStreamReaderImplementation.cs @@ -57,79 +57,93 @@ protected async ValueTask ReadRecordBatchAsync(CancellationToken ca { await ReadSchemaAsync().ConfigureAwait(false); - RecordBatch result = null; - - while (result == null) + ReadResult result = default; + do { - int messageLength = await ReadMessageLengthAsync(throwOnFullRead: false, cancellationToken) - .ConfigureAwait(false); + result = await ReadMessageAsync(cancellationToken).ConfigureAwait(false); + } while (result.Batch == null && result.MessageLength > 0); - if (messageLength == 0) - { - // reached end - return null; - } + return result.Batch; + } - await ArrayPool.Shared.RentReturnAsync(messageLength, async (messageBuff) => - { - int bytesRead = await BaseStream.ReadFullBufferAsync(messageBuff, cancellationToken) - .ConfigureAwait(false); - EnsureFullRead(messageBuff, bytesRead); + protected async ValueTask ReadMessageAsync(CancellationToken cancellationToken) + { + int messageLength = await ReadMessageLengthAsync(throwOnFullRead: false, cancellationToken) + .ConfigureAwait(false); - Flatbuf.Message message = Flatbuf.Message.GetRootAsMessage(CreateByteBuffer(messageBuff)); + if (messageLength == 0) + { + // reached end + return default; + } - int bodyLength = checked((int)message.BodyLength); + RecordBatch result = null; + await ArrayPool.Shared.RentReturnAsync(messageLength, async (messageBuff) => + { + int bytesRead = await BaseStream.ReadFullBufferAsync(messageBuff, cancellationToken) + .ConfigureAwait(false); + EnsureFullRead(messageBuff, bytesRead); - IMemoryOwner bodyBuffOwner = _allocator.Allocate(bodyLength); - Memory bodyBuff = bodyBuffOwner.Memory.Slice(0, bodyLength); - bytesRead = await BaseStream.ReadFullBufferAsync(bodyBuff, cancellationToken) - .ConfigureAwait(false); - EnsureFullRead(bodyBuff, bytesRead); + Flatbuf.Message message = Flatbuf.Message.GetRootAsMessage(CreateByteBuffer(messageBuff)); - Google.FlatBuffers.ByteBuffer bodybb = CreateByteBuffer(bodyBuff); - result = CreateArrowObjectFromMessage(message, bodybb, bodyBuffOwner); - }).ConfigureAwait(false); - } + int bodyLength = checked((int)message.BodyLength); + + IMemoryOwner bodyBuffOwner = _allocator.Allocate(bodyLength); + Memory bodyBuff = bodyBuffOwner.Memory.Slice(0, bodyLength); + bytesRead = await BaseStream.ReadFullBufferAsync(bodyBuff, cancellationToken) + .ConfigureAwait(false); + EnsureFullRead(bodyBuff, bytesRead); + + Google.FlatBuffers.ByteBuffer bodybb = CreateByteBuffer(bodyBuff); + result = CreateArrowObjectFromMessage(message, bodybb, bodyBuffOwner); + }).ConfigureAwait(false); - return result; + return new ReadResult(messageLength, result); } protected RecordBatch ReadRecordBatch() { ReadSchema(); - RecordBatch result = null; - - while (result == null) + ReadResult result = default; + do { - int messageLength = ReadMessageLength(throwOnFullRead: false); + result = ReadMessage(); + } while (result.Batch == null && result.MessageLength > 0); - if (messageLength == 0) - { - // reached end - return null; - } + return result.Batch; + } - ArrayPool.Shared.RentReturn(messageLength, messageBuff => - { - int bytesRead = BaseStream.ReadFullBuffer(messageBuff); - EnsureFullRead(messageBuff, bytesRead); + protected ReadResult ReadMessage() + { + int messageLength = ReadMessageLength(throwOnFullRead: false); - Flatbuf.Message message = Flatbuf.Message.GetRootAsMessage(CreateByteBuffer(messageBuff)); + if (messageLength == 0) + { + // reached end + return default; + } - int bodyLength = checked((int)message.BodyLength); + RecordBatch result = null; + ArrayPool.Shared.RentReturn(messageLength, messageBuff => + { + int bytesRead = BaseStream.ReadFullBuffer(messageBuff); + EnsureFullRead(messageBuff, bytesRead); - IMemoryOwner bodyBuffOwner = _allocator.Allocate(bodyLength); - Memory bodyBuff = bodyBuffOwner.Memory.Slice(0, bodyLength); - bytesRead = BaseStream.ReadFullBuffer(bodyBuff); - EnsureFullRead(bodyBuff, bytesRead); + Flatbuf.Message message = Flatbuf.Message.GetRootAsMessage(CreateByteBuffer(messageBuff)); - Google.FlatBuffers.ByteBuffer bodybb = CreateByteBuffer(bodyBuff); - result = CreateArrowObjectFromMessage(message, bodybb, bodyBuffOwner); - }); - } + int bodyLength = checked((int)message.BodyLength); + + IMemoryOwner bodyBuffOwner = _allocator.Allocate(bodyLength); + Memory bodyBuff = bodyBuffOwner.Memory.Slice(0, bodyLength); + bytesRead = BaseStream.ReadFullBuffer(bodyBuff); + EnsureFullRead(bodyBuff, bytesRead); - return result; + Google.FlatBuffers.ByteBuffer bodybb = CreateByteBuffer(bodyBuff); + result = CreateArrowObjectFromMessage(message, bodybb, bodyBuffOwner); + }); + + return new ReadResult(messageLength, result); } protected virtual async ValueTask ReadSchemaAsync() @@ -264,5 +278,17 @@ internal static void EnsureFullRead(Memory buffer, int bytesRead) throw new InvalidOperationException("Unexpectedly reached the end of the stream before a full buffer was read."); } } + + internal struct ReadResult + { + public readonly int MessageLength; + public readonly RecordBatch Batch; + + public ReadResult(int messageLength, RecordBatch batch) + { + MessageLength = messageLength; + Batch = batch; + } + } } } diff --git a/csharp/src/Apache.Arrow/Ipc/ArrowStreamWriter.cs b/csharp/src/Apache.Arrow/Ipc/ArrowStreamWriter.cs index dcb8852bc1f65..d4e8bb48df4e1 100644 --- a/csharp/src/Apache.Arrow/Ipc/ArrowStreamWriter.cs +++ b/csharp/src/Apache.Arrow/Ipc/ArrowStreamWriter.cs @@ -270,7 +270,6 @@ public ArrowStreamWriter(Stream baseStream, Schema schema, bool leaveOpen, IpcOp _options = options ?? IpcOptions.Default; } - private void CreateSelfAndChildrenFieldNodes(ArrayData data) { if (data.DataType is NestedType) @@ -319,7 +318,7 @@ private protected void WriteRecordBatchInternal(RecordBatch recordBatch) if (!HasWrittenDictionaryBatch) { DictionaryCollector.Collect(recordBatch, ref _dictionaryMemo); - WriteDictionaries(recordBatch); + WriteDictionaries(_dictionaryMemo); HasWrittenDictionaryBatch = true; } @@ -358,7 +357,7 @@ private protected async Task WriteRecordBatchInternalAsync(RecordBatch recordBat if (!HasWrittenDictionaryBatch) { DictionaryCollector.Collect(recordBatch, ref _dictionaryMemo); - await WriteDictionariesAsync(recordBatch, cancellationToken).ConfigureAwait(false); + await WriteDictionariesAsync(_dictionaryMemo, cancellationToken).ConfigureAwait(false); HasWrittenDictionaryBatch = true; } @@ -492,74 +491,65 @@ private Tuple PreparingWritingR return Tuple.Create(recordBatchBuilder, fieldNodesVectorOffset); } + private protected virtual void StartingWritingDictionary() + { + } - private protected void WriteDictionaries(RecordBatch recordBatch) + private protected virtual void FinishedWritingDictionary(long bodyLength, long metadataLength) { - foreach (Field field in recordBatch.Schema.FieldsList) - { - WriteDictionary(field); - } } - private protected void WriteDictionary(Field field) + private protected void WriteDictionaries(DictionaryMemo dictionaryMemo) { - if (field.DataType.TypeId != ArrowTypeId.Dictionary) + int fieldCount = dictionaryMemo?.DictionaryCount ?? 0; + for (int i = 0; i < fieldCount; i++) { - if (field.DataType is NestedType nestedType) - { - foreach (Field child in nestedType.Fields) - { - WriteDictionary(child); - } - } - return; + WriteDictionary(i, dictionaryMemo.GetDictionaryType(i), dictionaryMemo.GetDictionary(i)); } + } + + private protected void WriteDictionary(long id, IArrowType valueType, IArrowArray dictionary) + { + StartingWritingDictionary(); (ArrowRecordBatchFlatBufferBuilder recordBatchBuilder, Offset dictionaryBatchOffset) = - CreateDictionaryBatchOffset(field); + CreateDictionaryBatchOffset(id, valueType, dictionary); - WriteMessage(Flatbuf.MessageHeader.DictionaryBatch, + long metadataLength = WriteMessage(Flatbuf.MessageHeader.DictionaryBatch, dictionaryBatchOffset, recordBatchBuilder.TotalLength); - WriteBufferData(recordBatchBuilder.Buffers); + long bufferLength = WriteBufferData(recordBatchBuilder.Buffers); + + FinishedWritingDictionary(bufferLength, metadataLength); } - private protected async Task WriteDictionariesAsync(RecordBatch recordBatch, CancellationToken cancellationToken) + private protected async Task WriteDictionariesAsync(DictionaryMemo dictionaryMemo, CancellationToken cancellationToken) { - foreach (Field field in recordBatch.Schema.FieldsList) + int fieldCount = dictionaryMemo?.DictionaryCount ?? 0; + for (int i = 0; i < fieldCount; i++) { - await WriteDictionaryAsync(field, cancellationToken).ConfigureAwait(false); + await WriteDictionaryAsync(i, dictionaryMemo.GetDictionaryType(i), dictionaryMemo.GetDictionary(i), cancellationToken).ConfigureAwait(false); } } - private protected async Task WriteDictionaryAsync(Field field, CancellationToken cancellationToken) + private protected async Task WriteDictionaryAsync(long id, IArrowType valueType, IArrowArray dictionary, CancellationToken cancellationToken) { - if (field.DataType.TypeId != ArrowTypeId.Dictionary) - { - if (field.DataType is NestedType nestedType) - { - foreach (Field child in nestedType.Fields) - { - await WriteDictionaryAsync(child, cancellationToken).ConfigureAwait(false); - } - } - return; - } + StartingWritingDictionary(); (ArrowRecordBatchFlatBufferBuilder recordBatchBuilder, Offset dictionaryBatchOffset) = - CreateDictionaryBatchOffset(field); + CreateDictionaryBatchOffset(id, valueType, dictionary); - await WriteMessageAsync(Flatbuf.MessageHeader.DictionaryBatch, + long metadataLength = await WriteMessageAsync(Flatbuf.MessageHeader.DictionaryBatch, dictionaryBatchOffset, recordBatchBuilder.TotalLength, cancellationToken).ConfigureAwait(false); - await WriteBufferDataAsync(recordBatchBuilder.Buffers, cancellationToken).ConfigureAwait(false); + long bufferLength = await WriteBufferDataAsync(recordBatchBuilder.Buffers, cancellationToken).ConfigureAwait(false); + + FinishedWritingDictionary(bufferLength, metadataLength); } - private Tuple> CreateDictionaryBatchOffset(Field field) + private Tuple> CreateDictionaryBatchOffset(long id, IArrowType valueType, IArrowArray dictionary) { - Field dictionaryField = new Field("dummy", ((DictionaryType)field.DataType).ValueType, false); - long id = DictionaryMemo.GetId(field); - IArrowArray dictionary = DictionaryMemo.GetDictionary(id); + Field dictionaryField = new Field("dummy", valueType, false); var fields = new Field[] { dictionaryField }; @@ -987,12 +977,12 @@ private static void CollectDictionary(Field field, ArrayData arrayData, ref Dict arrayData.Dictionary.EnsureDataType(dictionaryType.ValueType.TypeId); IArrowArray dictionary = ArrowArrayFactory.BuildArray(arrayData.Dictionary); + WalkChildren(dictionary.Data, ref dictionaryMemo); dictionaryMemo ??= new DictionaryMemo(); long id = dictionaryMemo.GetOrAssignId(field); dictionaryMemo.AddOrReplaceDictionary(id, dictionary); - WalkChildren(dictionary.Data, ref dictionaryMemo); } else { diff --git a/csharp/src/Apache.Arrow/Ipc/DictionaryMemo.cs b/csharp/src/Apache.Arrow/Ipc/DictionaryMemo.cs index 24f25a142966c..b107cc65bfac5 100644 --- a/csharp/src/Apache.Arrow/Ipc/DictionaryMemo.cs +++ b/csharp/src/Apache.Arrow/Ipc/DictionaryMemo.cs @@ -33,6 +33,9 @@ public DictionaryMemo() _fieldToId = new Dictionary(); } + public int DictionaryCount => _fieldToId.Count; + public int LoadedDictionaryCount => _idToDictionary.Count; + public IArrowType GetDictionaryType(long id) { if (!_idToValueType.TryGetValue(id, out IArrowType type)) @@ -72,9 +75,12 @@ public void AddField(long id, Field field) throw new ArgumentException($"Field type {field.DataType.Name} does not match the existing type {valueTypeInDic})"); } } + else + { + _idToValueType.Add(id, valueType); + } _fieldToId.Add(field, id); - _idToValueType.Add(id, valueType); } public long GetId(Field field) @@ -90,7 +96,7 @@ public long GetOrAssignId(Field field) { if (!_fieldToId.TryGetValue(field, out long id)) { - id = _fieldToId.Count + 1; + id = _fieldToId.Count; AddField(id, field); } return id; diff --git a/csharp/src/Apache.Arrow/Types/DictionaryType.cs b/csharp/src/Apache.Arrow/Types/DictionaryType.cs index 5c1dd4095eb16..6316578aa6a5d 100644 --- a/csharp/src/Apache.Arrow/Types/DictionaryType.cs +++ b/csharp/src/Apache.Arrow/Types/DictionaryType.cs @@ -20,6 +20,7 @@ namespace Apache.Arrow.Types { public sealed class DictionaryType : FixedWidthType { + [Obsolete] public static readonly DictionaryType Default = new DictionaryType(Int64Type.Default, Int64Type.Default, false); public DictionaryType(IArrowType indexType, IArrowType valueType, bool ordered) @@ -36,7 +37,7 @@ public DictionaryType(IArrowType indexType, IArrowType valueType, bool ordered) public override ArrowTypeId TypeId => ArrowTypeId.Dictionary; public override string Name => "dictionary"; - public override int BitWidth => 64; + public override int BitWidth => ((IntegerType)IndexType).BitWidth; public override void Accept(IArrowTypeVisitor visitor) => Accept(this, visitor); public IArrowType IndexType { get; private set; } diff --git a/csharp/test/Apache.Arrow.Benchmarks/ArrowReaderBenchmark.cs b/csharp/test/Apache.Arrow.Benchmarks/ArrowReaderBenchmark.cs index 4e491a2a6b128..cd8198d434cc7 100644 --- a/csharp/test/Apache.Arrow.Benchmarks/ArrowReaderBenchmark.cs +++ b/csharp/test/Apache.Arrow.Benchmarks/ArrowReaderBenchmark.cs @@ -38,7 +38,7 @@ public class ArrowReaderBenchmark [GlobalSetup] public async Task GlobalSetup() { - RecordBatch batch = TestData.CreateSampleRecordBatch(length: Count); + RecordBatch batch = TestData.CreateSampleRecordBatch(length: Count, createDictionaryArray: false); _memoryStream = new MemoryStream(); ArrowStreamWriter writer = new ArrowStreamWriter(_memoryStream, batch.Schema); diff --git a/csharp/test/Apache.Arrow.IntegrationTest/IntegrationCommand.cs b/csharp/test/Apache.Arrow.IntegrationTest/IntegrationCommand.cs index d19d19f1ce7c1..6a1e91240989b 100644 --- a/csharp/test/Apache.Arrow.IntegrationTest/IntegrationCommand.cs +++ b/csharp/test/Apache.Arrow.IntegrationTest/IntegrationCommand.cs @@ -14,14 +14,8 @@ // limitations under the License. using System; -using System.Collections.Generic; -using System.Globalization; using System.IO; -using System.Numerics; -using System.Text; -using System.Text.Json; using System.Threading.Tasks; -using Apache.Arrow.Arrays; using Apache.Arrow.Ipc; using Apache.Arrow.Tests; using Apache.Arrow.Types; @@ -49,6 +43,7 @@ public async Task Execute() "json-to-arrow" => JsonToArrow, "stream-to-file" => StreamToFile, "file-to-stream" => FileToStream, + "round-trip-json-arrow" => RoundTripJsonArrow, _ => () => { Console.WriteLine($"Mode '{Mode}' is not supported."); @@ -58,6 +53,14 @@ public async Task Execute() return await commandDelegate(); } + private async Task RoundTripJsonArrow() + { + int status = await JsonToArrow(); + if (status != 0) { return status; } + + return await Validate(); + } + private async Task Validate() { JsonFile jsonFile = await ParseJsonFile(); @@ -72,7 +75,7 @@ private async Task Validate() return -1; } - Schema jsonFileSchema = jsonFile.Schema.ToArrow(); + Schema jsonFileSchema = jsonFile.GetSchemaAndDictionaries(out Func dictionaries); Schema arrowFileSchema = reader.Schema; SchemaComparer.Compare(jsonFileSchema, arrowFileSchema); @@ -80,7 +83,7 @@ private async Task Validate() for (int i = 0; i < batchCount; i++) { RecordBatch arrowFileRecordBatch = reader.ReadNextRecordBatch(); - RecordBatch jsonFileRecordBatch = jsonFile.Batches[i].ToArrow(jsonFileSchema); + RecordBatch jsonFileRecordBatch = jsonFile.Batches[i].ToArrow(jsonFileSchema, dictionaries); ArrowReaderVerifier.CompareBatches(jsonFileRecordBatch, arrowFileRecordBatch, strictCompare: false); } @@ -98,7 +101,7 @@ private async Task Validate() private async Task JsonToArrow() { JsonFile jsonFile = await ParseJsonFile(); - Schema schema = jsonFile.Schema.ToArrow(); + Schema schema = jsonFile.GetSchemaAndDictionaries(out Func dictionaries); using (FileStream fs = ArrowFileInfo.Create()) { @@ -107,7 +110,7 @@ private async Task JsonToArrow() foreach (var jsonRecordBatch in jsonFile.Batches) { - RecordBatch batch = jsonRecordBatch.ToArrow(schema); + RecordBatch batch = jsonRecordBatch.ToArrow(schema, dictionaries); await writer.WriteRecordBatchAsync(batch); } await writer.WriteEndAsync(); diff --git a/csharp/test/Apache.Arrow.IntegrationTest/JsonFile.cs b/csharp/test/Apache.Arrow.IntegrationTest/JsonFile.cs index 987a236a10191..bdb9e2682bb01 100644 --- a/csharp/test/Apache.Arrow.IntegrationTest/JsonFile.cs +++ b/csharp/test/Apache.Arrow.IntegrationTest/JsonFile.cs @@ -15,9 +15,9 @@ using System; using System.Collections.Generic; -using System.Diagnostics; using System.Globalization; using System.IO; +using System.Linq; using System.Numerics; using System.Text; using System.Text.Json; @@ -31,8 +31,10 @@ namespace Apache.Arrow.IntegrationTest public class JsonFile { public JsonSchema Schema { get; set; } + + public List Dictionaries { get; set; } + public List Batches { get; set; } - //public List Dictionaries {get;set;} public static async ValueTask ParseAsync(FileInfo fileInfo) { @@ -48,6 +50,33 @@ public static JsonFile Parse(FileInfo fileInfo) return JsonSerializer.Deserialize(fileStream, options); } + public Schema GetSchemaAndDictionaries(out Func dictionaries) + { + Schema schema = Schema.ToArrow(out Dictionary dictionaryIndexes); + + Func lookup = null; + lookup = type => Dictionaries.Single(d => d.Id == dictionaryIndexes[type]).Data.ToArrow(type.ValueType, lookup); + dictionaries = lookup; + + return schema; + } + + /// + /// Return both the schema and a specific batch number. + /// This method is used by C Data Interface integration testing. + /// + public Schema ToArrow(int batchNumber, out RecordBatch batch) + { + Schema schema = Schema.ToArrow(out Dictionary dictionaryIndexes); + + Func lookup = null; + lookup = type => Dictionaries.Single(d => d.Id == dictionaryIndexes[type]).Data.ToArrow(type.ValueType, lookup); + + batch = Batches[batchNumber].ToArrow(schema, lookup); + + return schema; + } + private static JsonSerializerOptions GetJsonOptions() { JsonSerializerOptions options = new JsonSerializerOptions() @@ -67,22 +96,33 @@ public class JsonSchema /// /// Decode this JSON schema as a Schema instance. /// + public Schema ToArrow(out Dictionary dictionaryIndexes) + { + dictionaryIndexes = new Dictionary(); + return CreateSchema(this, dictionaryIndexes); + } + + /// + /// Decode this JSON schema as a Schema instance without computing dictionaries. + /// This method is used by C Data Interface integration testing. + /// public Schema ToArrow() { - return CreateSchema(this); + Dictionary dictionaryIndexes = new Dictionary(); + return CreateSchema(this, dictionaryIndexes); } - private static Schema CreateSchema(JsonSchema jsonSchema) + private static Schema CreateSchema(JsonSchema jsonSchema, Dictionary dictionaryIndexes) { Schema.Builder builder = new Schema.Builder(); for (int i = 0; i < jsonSchema.Fields.Count; i++) { - builder.Field(f => CreateField(f, jsonSchema.Fields[i])); + builder.Field(f => CreateField(f, jsonSchema.Fields[i], dictionaryIndexes)); } return builder.Build(); } - private static void CreateField(Field.Builder builder, JsonField jsonField) + private static void CreateField(Field.Builder builder, JsonField jsonField, Dictionary dictionaryIndexes) { Field[] children = null; if (jsonField.Children?.Count > 0) @@ -91,13 +131,26 @@ private static void CreateField(Field.Builder builder, JsonField jsonField) for (int i = 0; i < jsonField.Children.Count; i++) { Field.Builder field = new Field.Builder(); - CreateField(field, jsonField.Children[i]); + CreateField(field, jsonField.Children[i], dictionaryIndexes); children[i] = field.Build(); } } + IArrowType type = ToArrowType(jsonField.Type, children); + + if (jsonField.Dictionary != null) + { + DictionaryType dictType = new DictionaryType( + ToArrowType(jsonField.Dictionary.IndexType, new Field[0]), + type, + jsonField.Dictionary.IsOrdered); + + dictionaryIndexes[dictType] = jsonField.Dictionary.Id; + type = dictType; + } + builder.Name(jsonField.Name) - .DataType(ToArrowType(jsonField.Type, children)) + .DataType(type) .Nullable(jsonField.Nullable); if (jsonField.Metadata != null) @@ -300,10 +353,18 @@ public class JsonArrowType public class JsonDictionaryIndex { public int Id { get; set; } - public JsonArrowType Type { get; set; } + public JsonArrowType IndexType { get; set; } public bool IsOrdered { get; set; } } + public class JsonDictionary + { + public int Id { get; set; } + + [JsonPropertyName("data")] + public JsonRecordBatch Data { get; set; } + } + public class JsonMetadata : List> { } @@ -316,12 +377,19 @@ public class JsonRecordBatch /// /// Decode this JSON record batch as a RecordBatch instance. /// - public RecordBatch ToArrow(Schema schema) + public RecordBatch ToArrow(Schema schema, Func dictionaries) + { + return CreateRecordBatch(schema, dictionaries, this); + } + + public IArrowArray ToArrow(IArrowType arrowType, Func dictionaries) { - return CreateRecordBatch(schema, this); + ArrayCreator creator = new ArrayCreator(this.Columns[0], dictionaries); + arrowType.Accept(creator); + return creator.Array; } - private RecordBatch CreateRecordBatch(Schema schema, JsonRecordBatch jsonRecordBatch) + private RecordBatch CreateRecordBatch(Schema schema, Func dictionaries, JsonRecordBatch jsonRecordBatch) { if (schema.FieldsList.Count != jsonRecordBatch.Columns.Count) { @@ -333,7 +401,7 @@ private RecordBatch CreateRecordBatch(Schema schema, JsonRecordBatch jsonRecordB { JsonFieldData data = jsonRecordBatch.Columns[i]; Field field = schema.FieldsList[i]; - ArrayCreator creator = new ArrayCreator(data); + ArrayCreator creator = new ArrayCreator(data, dictionaries); field.DataType.Accept(creator); arrays.Add(creator.Array); } @@ -369,14 +437,18 @@ private class ArrayCreator : IArrowTypeVisitor, IArrowTypeVisitor, IArrowTypeVisitor, + IArrowTypeVisitor, IArrowTypeVisitor { private JsonFieldData JsonFieldData { get; set; } public IArrowArray Array { get; private set; } - public ArrayCreator(JsonFieldData jsonFieldData) + private readonly Func dictionaries; + + public ArrayCreator(JsonFieldData jsonFieldData, Func dictionaries) { JsonFieldData = jsonFieldData; + this.dictionaries = dictionaries; } public void Visit(BooleanType type) @@ -656,6 +728,12 @@ public void Visit(MapType type) Array = new MapArray(arrayData); } + public void Visit(DictionaryType type) + { + type.IndexType.Accept(this); + Array = new DictionaryArray(type, Array, this.dictionaries(type)); + } + private ArrayData[] GetChildren(NestedType type) { ArrayData[] children = new ArrayData[type.Fields.Count]; diff --git a/csharp/test/Apache.Arrow.Tests/ArrowFileReaderTests.cs b/csharp/test/Apache.Arrow.Tests/ArrowFileReaderTests.cs index 2f2229ded4c46..585b1acc27f17 100644 --- a/csharp/test/Apache.Arrow.Tests/ArrowFileReaderTests.cs +++ b/csharp/test/Apache.Arrow.Tests/ArrowFileReaderTests.cs @@ -66,7 +66,7 @@ public async Task Ctor_MemoryPool_AllocatesFromPool(bool shouldLeaveOpen) ArrowFileReader reader = new ArrowFileReader(stream, memoryPool, leaveOpen: shouldLeaveOpen); reader.ReadNextRecordBatch(); - Assert.Equal(1, memoryPool.Statistics.Allocations); + Assert.Equal(2, memoryPool.Statistics.Allocations); Assert.True(memoryPool.Statistics.BytesAllocated > 0); reader.Dispose(); @@ -132,8 +132,8 @@ private static async Task TestReadRecordBatchHelper( [Fact] public async Task TestReadMultipleRecordBatchAsync() { - RecordBatch originalBatch1 = TestData.CreateSampleRecordBatch(length: 100); - RecordBatch originalBatch2 = TestData.CreateSampleRecordBatch(length: 50); + RecordBatch originalBatch1 = TestData.CreateSampleRecordBatch(length: 100, createDictionaryArray: false); + RecordBatch originalBatch2 = TestData.CreateSampleRecordBatch(length: 50, createDictionaryArray: false); using (MemoryStream stream = new MemoryStream()) { diff --git a/csharp/test/Apache.Arrow.Tests/TestData.cs b/csharp/test/Apache.Arrow.Tests/TestData.cs index 3af6efb97b437..79e886f0deabb 100644 --- a/csharp/test/Apache.Arrow.Tests/TestData.cs +++ b/csharp/test/Apache.Arrow.Tests/TestData.cs @@ -23,7 +23,7 @@ namespace Apache.Arrow.Tests { public static class TestData { - public static RecordBatch CreateSampleRecordBatch(int length, bool createDictionaryArray = false) + public static RecordBatch CreateSampleRecordBatch(int length, bool createDictionaryArray = true) { return CreateSampleRecordBatch(length, columnSetCount: 1, createDictionaryArray); } diff --git a/dev/archery/archery/integration/datagen.py b/dev/archery/archery/integration/datagen.py index 80cc1c1e76425..341b48117ab80 100644 --- a/dev/archery/archery/integration/datagen.py +++ b/dev/archery/archery/integration/datagen.py @@ -1836,15 +1836,12 @@ def _temp_path(): .skip_tester('C#') .skip_tester('JS'), - generate_dictionary_case() - .skip_tester('C#'), + generate_dictionary_case(), generate_dictionary_unsigned_case() - .skip_tester('C#') .skip_tester('Java'), # TODO(ARROW-9377) generate_nested_dictionary_case() - .skip_tester('C#') .skip_tester('Java'), # TODO(ARROW-7779) generate_run_end_encoded_case() diff --git a/dev/archery/archery/integration/tester_csharp.py b/dev/archery/archery/integration/tester_csharp.py index 4f7765641130d..9aab5b0b28ef9 100644 --- a/dev/archery/archery/integration/tester_csharp.py +++ b/dev/archery/archery/integration/tester_csharp.py @@ -78,9 +78,7 @@ def _pointer_to_int(self, c_ptr): def _read_batch_from_json(self, json_path, num_batch): from Apache.Arrow.IntegrationTest import CDataInterface - jf = CDataInterface.ParseJsonFile(json_path) - schema = jf.Schema.ToArrow() - return schema, jf.Batches[num_batch].ToArrow(schema) + return CDataInterface.ParseJsonFile(json_path).ToArrow(num_batch) def _run_gc(self): from Apache.Arrow.IntegrationTest import CDataInterface diff --git a/docs/source/status.rst b/docs/source/status.rst index 6167d3037ba77..140e15f44cbca 100644 --- a/docs/source/status.rst +++ b/docs/source/status.rst @@ -100,7 +100,7 @@ Data Types | Data type | C++ | Java | Go | JavaScript | C# | Rust | Julia | Swift | | (special) | | | | | | | | | +===================+=======+=======+=======+============+=======+=======+=======+=======+ -| Dictionary | ✓ | ✓ (3) | ✓ | ✓ | ✓ (3) | ✓ (3) | ✓ | | +| Dictionary | ✓ | ✓ (3) | ✓ | ✓ | ✓ | ✓ (3) | ✓ | | +-------------------+-------+-------+-------+------------+-------+-------+-------+-------+ | Extension | ✓ | ✓ | ✓ | | | ✓ | ✓ | | +-------------------+-------+-------+-------+------------+-------+-------+-------+-------+