Skip to content

Commit

Permalink
apacheGH-32662: [C#] Make dictionaries in file and memory implementat…
Browse files Browse the repository at this point in the history
…ions work correctly and support integration tests (apache#39146)

### Rationale for this change

While dictionary support was implemented for C# in apache#6870 for streams, support did not extend to files or memory buffers. This change rectifies that.

### What changes are included in this PR?

Changes to the memory and file implementations to support reading and writing of dictionaries, including nested dictionaries.
Changes to the integration tests so that they work with dictionaries.
Enabling the dictionary tests in CI.

### Are these changes tested?

Yes, both directly and indirectly via the integration tests.

### Are there any user-facing changes?

No.

* Closes: apache#32662

Authored-by: Curt Hagenlocher <[email protected]>
Signed-off-by: Curt Hagenlocher <[email protected]>
  • Loading branch information
CurtHagenlocher authored and dgreiss committed Feb 17, 2024
1 parent cf7f2f9 commit 27eb4e3
Show file tree
Hide file tree
Showing 16 changed files with 352 additions and 169 deletions.
34 changes: 34 additions & 0 deletions csharp/src/Apache.Arrow/Ipc/ArrowFileReaderImplementation.cs
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,8 @@ internal sealed class ArrowFileReaderImplementation : ArrowStreamReaderImplement

private ArrowFooter _footer;

private bool HasReadDictionaries => HasReadSchema && DictionaryMemo.LoadedDictionaryCount >= _footer.DictionaryCount;

public ArrowFileReaderImplementation(Stream stream, MemoryAllocator allocator, ICompressionCodecFactory compressionCodecFactory, bool leaveOpen)
: base(stream, allocator, compressionCodecFactory, leaveOpen)
{
Expand Down Expand Up @@ -143,6 +145,7 @@ private void ReadSchema(Memory<byte> buffer)
public async ValueTask<RecordBatch> ReadRecordBatchAsync(int index, CancellationToken cancellationToken)
{
await ReadSchemaAsync().ConfigureAwait(false);
await ReadDictionariesAsync(cancellationToken).ConfigureAwait(false);

if (index >= _footer.RecordBatchCount)
{
Expand All @@ -159,6 +162,7 @@ public async ValueTask<RecordBatch> ReadRecordBatchAsync(int index, Cancellation
public RecordBatch ReadRecordBatch(int index)
{
ReadSchema();
ReadDictionaries();

if (index >= _footer.RecordBatchCount)
{
Expand All @@ -175,6 +179,7 @@ public RecordBatch ReadRecordBatch(int index)
public override async ValueTask<RecordBatch> ReadNextRecordBatchAsync(CancellationToken cancellationToken)
{
await ReadSchemaAsync().ConfigureAwait(false);
await ReadDictionariesAsync(cancellationToken).ConfigureAwait(false);

if (_recordBatchIndex >= _footer.RecordBatchCount)
{
Expand All @@ -190,6 +195,7 @@ public override async ValueTask<RecordBatch> ReadNextRecordBatchAsync(Cancellati
public override RecordBatch ReadNextRecordBatch()
{
ReadSchema();
ReadDictionaries();

if (_recordBatchIndex >= _footer.RecordBatchCount)
{
Expand All @@ -202,6 +208,34 @@ public override RecordBatch ReadNextRecordBatch()
return result;
}

private async ValueTask ReadDictionariesAsync(CancellationToken cancellationToken = default)
{
if (HasReadDictionaries)
{
return;
}

foreach (Block block in _footer.Dictionaries)
{
BaseStream.Position = block.Offset;
await ReadMessageAsync(cancellationToken);
}
}

private void ReadDictionaries()
{
if (HasReadDictionaries)
{
return;
}

foreach (Block block in _footer.Dictionaries)
{
BaseStream.Position = block.Offset;
ReadMessage();
}
}

/// <summary>
/// Check if file format is valid. If it's valid don't run the validation again.
/// </summary>
Expand Down
54 changes: 49 additions & 5 deletions csharp/src/Apache.Arrow/Ipc/ArrowFileWriter.cs
Original file line number Diff line number Diff line change
Expand Up @@ -23,10 +23,12 @@

namespace Apache.Arrow.Ipc
{
public class ArrowFileWriter: ArrowStreamWriter
public class ArrowFileWriter : ArrowStreamWriter
{
private long _currentRecordBatchOffset = -1;
private long _currentDictionaryOffset = -1;

private List<Block> DictionaryBlocks { get; set; }
private List<Block> RecordBatchBlocks { get; }

public ArrowFileWriter(Stream stream, Schema schema)
Expand Down Expand Up @@ -105,6 +107,34 @@ private protected override void FinishedWritingRecordBatch(long bodyLength, long
_currentRecordBatchOffset = -1;
}

private protected override void StartingWritingDictionary()
{
if (DictionaryBlocks == null) { DictionaryBlocks = new List<Block>(); }
_currentDictionaryOffset = BaseStream.Position;
}

private protected override void FinishedWritingDictionary(long bodyLength, long metadataLength)
{
// Dictionaries only appear after a Schema is written, so the dictionary offsets must
// always be greater than 0.
Debug.Assert(_currentDictionaryOffset > 0, "_currentDictionaryOffset must be positive.");

int metadataLengthInt = checked((int)metadataLength);

Debug.Assert(BitUtility.IsMultipleOf8(_currentDictionaryOffset));
Debug.Assert(BitUtility.IsMultipleOf8(metadataLengthInt));
Debug.Assert(BitUtility.IsMultipleOf8(bodyLength));

var block = new Block(
offset: _currentDictionaryOffset,
length: bodyLength,
metadataLength: metadataLengthInt);

DictionaryBlocks.Add(block);

_currentDictionaryOffset = -1;
}

private protected override void WriteEndInternal()
{
base.WriteEndInternal();
Expand Down Expand Up @@ -161,9 +191,16 @@ private void WriteFooter(Schema schema)
Google.FlatBuffers.VectorOffset recordBatchesVectorOffset = Builder.EndVector();

// Serialize all dictionaries
// NOTE: Currently unsupported.

Flatbuf.Footer.StartDictionariesVector(Builder, 0);
int dictionaryCount = DictionaryBlocks?.Count ?? 0;
Flatbuf.Footer.StartDictionariesVector(Builder, dictionaryCount);

for (int i = dictionaryCount - 1; i >= 0; i--)
{
Block dictionary = DictionaryBlocks[i];
Flatbuf.Block.CreateBlock(
Builder, dictionary.Offset, dictionary.MetadataLength, dictionary.BodyLength);
}

Google.FlatBuffers.VectorOffset dictionaryBatchesOffset = Builder.EndVector();

Expand Down Expand Up @@ -221,9 +258,16 @@ private async Task WriteFooterAsync(Schema schema, CancellationToken cancellatio
Google.FlatBuffers.VectorOffset recordBatchesVectorOffset = Builder.EndVector();

// Serialize all dictionaries
// NOTE: Currently unsupported.

Flatbuf.Footer.StartDictionariesVector(Builder, 0);
int dictionaryCount = DictionaryBlocks?.Count ?? 0;
Flatbuf.Footer.StartDictionariesVector(Builder, dictionaryCount);

for (int i = dictionaryCount - 1; i >= 0; i--)
{
Block dictionary = DictionaryBlocks[i];
Flatbuf.Block.CreateBlock(
Builder, dictionary.Offset, dictionary.MetadataLength, dictionary.BodyLength);
}

Google.FlatBuffers.VectorOffset dictionaryBatchesOffset = Builder.EndVector();

Expand Down
4 changes: 2 additions & 2 deletions csharp/src/Apache.Arrow/Ipc/ArrowFooter.cs
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,8 @@ internal class ArrowFooter
private readonly List<Block> _dictionaries;
private readonly List<Block> _recordBatches;

public IEnumerable<Block> Dictionaries => _dictionaries;
public IEnumerable<Block> RecordBatches => _recordBatches;
public IReadOnlyList<Block> Dictionaries => _dictionaries;
public IReadOnlyList<Block> RecordBatches => _recordBatches;

public Block GetRecordBatchBlock(int i) => _recordBatches[i];

Expand Down
60 changes: 33 additions & 27 deletions csharp/src/Apache.Arrow/Ipc/ArrowMemoryReaderImplementation.cs
Original file line number Diff line number Diff line change
Expand Up @@ -43,48 +43,54 @@ public override RecordBatch ReadNextRecordBatch()
{
ReadSchema();

if (_buffer.Length <= _bufferPosition + sizeof(int))
RecordBatch batch = null;
while (batch == null)
{
// reached the end
return null;
}

// Get Length of record batch for message header.
int messageLength = BinaryPrimitives.ReadInt32LittleEndian(_buffer.Span.Slice(_bufferPosition));
_bufferPosition += sizeof(int);

if (messageLength == 0)
{
//reached the end
return null;
}
else if (messageLength == MessageSerializer.IpcContinuationToken)
{
// ARROW-6313, if the first 4 bytes are continuation message, read the next 4 for the length
if (_buffer.Length <= _bufferPosition + sizeof(int))
{
throw new InvalidDataException("Corrupted IPC message. Received a continuation token at the end of the message.");
// reached the end
return null;
}

messageLength = BinaryPrimitives.ReadInt32LittleEndian(_buffer.Span.Slice(_bufferPosition));
// Get Length of record batch for message header.
int messageLength = BinaryPrimitives.ReadInt32LittleEndian(_buffer.Span.Slice(_bufferPosition));
_bufferPosition += sizeof(int);

if (messageLength == 0)
{
//reached the end
return null;
}
}
else if (messageLength == MessageSerializer.IpcContinuationToken)
{
// ARROW-6313, if the first 4 bytes are continuation message, read the next 4 for the length
if (_buffer.Length <= _bufferPosition + sizeof(int))
{
throw new InvalidDataException("Corrupted IPC message. Received a continuation token at the end of the message.");
}

messageLength = BinaryPrimitives.ReadInt32LittleEndian(_buffer.Span.Slice(_bufferPosition));
_bufferPosition += sizeof(int);

if (messageLength == 0)
{
//reached the end
return null;
}
}

Message message = Message.GetRootAsMessage(
CreateByteBuffer(_buffer.Slice(_bufferPosition, messageLength)));
_bufferPosition += messageLength;

Message message = Message.GetRootAsMessage(
CreateByteBuffer(_buffer.Slice(_bufferPosition, messageLength)));
_bufferPosition += messageLength;
int bodyLength = (int)message.BodyLength;
ByteBuffer bodybb = CreateByteBuffer(_buffer.Slice(_bufferPosition, bodyLength));
_bufferPosition += bodyLength;

int bodyLength = (int)message.BodyLength;
ByteBuffer bodybb = CreateByteBuffer(_buffer.Slice(_bufferPosition, bodyLength));
_bufferPosition += bodyLength;
batch = CreateArrowObjectFromMessage(message, bodybb, memoryOwner: null);
}

return CreateArrowObjectFromMessage(message, bodybb, memoryOwner: null);
return batch;
}

private void ReadSchema()
Expand Down
Loading

0 comments on commit 27eb4e3

Please sign in to comment.