Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

GH-32662: [C#] Make dictionaries in file and memory implementations work correctly and support integration tests #39146

Merged
merged 8 commits into from
Dec 10, 2023
86 changes: 85 additions & 1 deletion csharp/src/Apache.Arrow/Ipc/ArrowFileReaderImplementation.cs
Original file line number Diff line number Diff line change
Expand Up @@ -35,12 +35,14 @@ internal sealed class ArrowFileReaderImplementation : ArrowStreamReaderImplement

private ArrowFooter _footer;

private bool HasReadDictionaries => HasReadSchema && DictionaryMemo.LoadedDictionaryCount >= _footer.DictionaryCount;

public ArrowFileReaderImplementation(Stream stream, MemoryAllocator allocator, ICompressionCodecFactory compressionCodecFactory, bool leaveOpen)
: base(stream, allocator, compressionCodecFactory, leaveOpen)
{
}

public async ValueTask<int> RecordBatchCountAsync()
public async ValueTask<int> RecordBatchCountAsync(CancellationToken cancellationToken = default)
CurtHagenlocher marked this conversation as resolved.
Show resolved Hide resolved
{
if (!HasReadSchema)
{
Expand Down Expand Up @@ -143,6 +145,7 @@ private void ReadSchema(Memory<byte> buffer)
public async ValueTask<RecordBatch> ReadRecordBatchAsync(int index, CancellationToken cancellationToken)
{
await ReadSchemaAsync().ConfigureAwait(false);
await ReadDictionariesAsync(cancellationToken).ConfigureAwait(false);

if (index >= _footer.RecordBatchCount)
{
Expand All @@ -159,6 +162,7 @@ public async ValueTask<RecordBatch> ReadRecordBatchAsync(int index, Cancellation
public RecordBatch ReadRecordBatch(int index)
{
ReadSchema();
ReadDictionaries();

if (index >= _footer.RecordBatchCount)
{
Expand All @@ -175,6 +179,7 @@ public RecordBatch ReadRecordBatch(int index)
public override async ValueTask<RecordBatch> ReadNextRecordBatchAsync(CancellationToken cancellationToken)
{
await ReadSchemaAsync().ConfigureAwait(false);
await ReadDictionariesAsync(cancellationToken).ConfigureAwait(false);

if (_recordBatchIndex >= _footer.RecordBatchCount)
{
Expand All @@ -190,6 +195,7 @@ public override async ValueTask<RecordBatch> ReadNextRecordBatchAsync(Cancellati
public override RecordBatch ReadNextRecordBatch()
{
ReadSchema();
ReadDictionaries();

if (_recordBatchIndex >= _footer.RecordBatchCount)
{
Expand All @@ -202,6 +208,84 @@ public override RecordBatch ReadNextRecordBatch()
return result;
}

/// <summary>
/// Loads the dictionary blocks listed in the footer, unless
/// <c>HasReadDictionaries</c> already reports them as loaded.
/// </summary>
/// <param name="cancellationToken">Token forwarded to each dictionary read.</param>
private async ValueTask ReadDictionariesAsync(CancellationToken cancellationToken = default)
{
    if (!HasReadDictionaries)
    {
        // ReadNextDictionaryAsync hands back the index of the next block to read,
        // so the loop variable is advanced by the call rather than by the loop header.
        for (int next = 0; next < _footer.DictionaryCount;)
        {
            next = await ReadNextDictionaryAsync(next, cancellationToken).ConfigureAwait(false);
        }
    }
}

/// <summary>
/// Reads the dictionary block at <paramref name="index"/> from the footer. While
/// <c>DictionaryMemo.CanLoad</c> reports that this dictionary cannot be loaded yet,
/// the following footer blocks are read first (recursively).
/// </summary>
/// <param name="index">Position in <c>_footer.Dictionaries</c> of the block to read.</param>
/// <param name="cancellationToken">Token forwarded to the underlying message reads.</param>
/// <returns>The index of the first footer block not consumed by this call.</returns>
private async ValueTask<int> ReadNextDictionaryAsync(int index, CancellationToken cancellationToken)
{
    Block block = _footer.Dictionaries[index++];
    BaseStream.Position = block.Offset;

    // The lambda parameter is deliberately given its own name instead of shadowing the
    // method's cancellationToken, so it is obvious which token the nested reads use.
    await ReadMessageAsync(async (message, messageToken) =>
    {
        if (message.HeaderType != Flatbuf.MessageHeader.DictionaryBatch)
        {
            return null;
        }
        Flatbuf.DictionaryBatch dictionaryBatch = message.Header<Flatbuf.DictionaryBatch>().Value;

        // The recursive reads below reposition the stream, so remember where this
        // message's data continues and restore it before creating the object.
        long position = BaseStream.Position;
        while (!DictionaryMemo.CanLoad(dictionaryBatch.Id))
        {
            // recursive load
            // NOTE(review): if the footer runs out of blocks before CanLoad succeeds,
            // the indexer in the nested call throws; confirm that is acceptable for corrupt files.
            index = await ReadNextDictionaryAsync(index, messageToken).ConfigureAwait(false);
        }
        BaseStream.Position = position;

        // Library code: do not capture the caller's synchronization context.
        return await CreateArrowObjectAsync(message, messageToken).ConfigureAwait(false);
    }, cancellationToken).ConfigureAwait(false);

    return index;
}

/// <summary>
/// Synchronously loads the dictionary blocks listed in the footer, unless
/// <c>HasReadDictionaries</c> already reports them as loaded.
/// </summary>
private void ReadDictionaries()
{
    if (!HasReadDictionaries)
    {
        // ReadNextDictionary returns the index of the next block to read, which may
        // skip ahead by more than one when it had to read later blocks recursively.
        for (int next = 0; next < _footer.DictionaryCount;)
        {
            next = ReadNextDictionary(next);
        }
    }
}

/// <summary>
/// Reads the dictionary block at <paramref name="index"/> from the footer. While
/// <c>DictionaryMemo.CanLoad</c> returns false for this dictionary's id, the
/// following footer blocks are read first by calling this method recursively.
/// </summary>
/// <param name="index">Position in <c>_footer.Dictionaries</c> of the block to read.</param>
/// <returns>The index of the first footer block not consumed by this call.</returns>
private int ReadNextDictionary(int index)
{
// Take the block at the current index and advance past it.
Block block = _footer.Dictionaries[index++];
// Seek to the start of that block before reading its message.
BaseStream.Position = block.Offset;
ReadMessage(message =>
{
// Messages that are not dictionary batches yield no object.
if (message.HeaderType != Flatbuf.MessageHeader.DictionaryBatch)
{
return null;
}
Flatbuf.DictionaryBatch dictionaryBatch = message.Header<Flatbuf.DictionaryBatch>().Value;

// The recursive calls below move BaseStream.Position to other blocks, so the
// current position is saved here and restored before the object is created.
long position = BaseStream.Position;
// NOTE(review): presumably CanLoad is false while a dictionary this one depends on
// has not been loaded yet; confirm against DictionaryMemo.
while (!DictionaryMemo.CanLoad(dictionaryBatch.Id))
{
// recursive load
index = ReadNextDictionary(index);
}
BaseStream.Position = position;
return CreateArrowObject(message);
});

// Reflects every block consumed, including those read recursively inside the callback.
return index;
}

/// <summary>
/// Check if file format is valid. If it's valid don't run the validation again.
/// </summary>
Expand Down
54 changes: 49 additions & 5 deletions csharp/src/Apache.Arrow/Ipc/ArrowFileWriter.cs
Original file line number Diff line number Diff line change
Expand Up @@ -23,10 +23,12 @@

namespace Apache.Arrow.Ipc
{
public class ArrowFileWriter: ArrowStreamWriter
public class ArrowFileWriter : ArrowStreamWriter
{
private long _currentRecordBatchOffset = -1;
private long _currentDictionaryOffset = -1;

private List<Block> DictionaryBlocks { get; set; }
private List<Block> RecordBatchBlocks { get; }

public ArrowFileWriter(Stream stream, Schema schema)
Expand Down Expand Up @@ -105,6 +107,34 @@ private protected override void FinishedWritingRecordBatch(long bodyLength, long
_currentRecordBatchOffset = -1;
}

/// <summary>
/// Creates the dictionary block list on first use and records the stream position
/// at which the dictionary about to be written begins.
/// </summary>
private protected override void StartingWritingDictionary()
{
    if (DictionaryBlocks is null)
    {
        DictionaryBlocks = new List<Block>();
    }

    _currentDictionaryOffset = BaseStream.Position;
}

/// <summary>
/// Appends a block describing the dictionary that was just written to
/// <c>DictionaryBlocks</c> and resets the pending dictionary offset.
/// </summary>
/// <param name="bodyLength">Length of the dictionary body that was written.</param>
/// <param name="metadataLength">Length of the dictionary metadata; must fit in an <c>int</c>.</param>
private protected override void FinishedWritingDictionary(long bodyLength, long metadataLength)
{
    // A Schema is always written before any dictionary, so a dictionary can never
    // start at offset 0; -1 is the value this field holds when no write is pending.
    Debug.Assert(_currentDictionaryOffset > 0, "_currentDictionaryOffset must be positive.");

    // checked: overflow throws instead of silently truncating the length.
    int narrowedMetadataLength = checked((int)metadataLength);

    Debug.Assert(BitUtility.IsMultipleOf8(_currentDictionaryOffset));
    Debug.Assert(BitUtility.IsMultipleOf8(narrowedMetadataLength));
    Debug.Assert(BitUtility.IsMultipleOf8(bodyLength));

    DictionaryBlocks.Add(new Block(
        offset: _currentDictionaryOffset,
        length: bodyLength,
        metadataLength: narrowedMetadataLength));

    // Mark that no dictionary write is in progress.
    _currentDictionaryOffset = -1;
}

private protected override void WriteEndInternal()
{
base.WriteEndInternal();
Expand Down Expand Up @@ -161,9 +191,16 @@ private void WriteFooter(Schema schema)
Google.FlatBuffers.VectorOffset recordBatchesVectorOffset = Builder.EndVector();

// Serialize all dictionaries
// NOTE: Currently unsupported.

Flatbuf.Footer.StartDictionariesVector(Builder, 0);
int dictionaryCount = DictionaryBlocks?.Count ?? 0;
Flatbuf.Footer.StartDictionariesVector(Builder, dictionaryCount);

for (int i = 0; i < dictionaryCount; i++)
{
Block dictionary = DictionaryBlocks[i];
Flatbuf.Block.CreateBlock(
Builder, dictionary.Offset, dictionary.MetadataLength, dictionary.BodyLength);
}

Google.FlatBuffers.VectorOffset dictionaryBatchesOffset = Builder.EndVector();

Expand Down Expand Up @@ -221,9 +258,16 @@ private async Task WriteFooterAsync(Schema schema, CancellationToken cancellatio
Google.FlatBuffers.VectorOffset recordBatchesVectorOffset = Builder.EndVector();

// Serialize all dictionaries
// NOTE: Currently unsupported.

Flatbuf.Footer.StartDictionariesVector(Builder, 0);
int dictionaryCount = DictionaryBlocks?.Count ?? 0;
Flatbuf.Footer.StartDictionariesVector(Builder, dictionaryCount);

for (int i = 0; i < dictionaryCount; i++)
{
Block dictionary = DictionaryBlocks[i];
Flatbuf.Block.CreateBlock(
Builder, dictionary.Offset, dictionary.MetadataLength, dictionary.BodyLength);
}

Google.FlatBuffers.VectorOffset dictionaryBatchesOffset = Builder.EndVector();

Expand Down
4 changes: 2 additions & 2 deletions csharp/src/Apache.Arrow/Ipc/ArrowFooter.cs
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,8 @@ internal class ArrowFooter
private readonly List<Block> _dictionaries;
private readonly List<Block> _recordBatches;

public IEnumerable<Block> Dictionaries => _dictionaries;
public IEnumerable<Block> RecordBatches => _recordBatches;
public IReadOnlyList<Block> Dictionaries => _dictionaries;
public IReadOnlyList<Block> RecordBatches => _recordBatches;

public Block GetRecordBatchBlock(int i) => _recordBatches[i];

Expand Down
60 changes: 33 additions & 27 deletions csharp/src/Apache.Arrow/Ipc/ArrowMemoryReaderImplementation.cs
Original file line number Diff line number Diff line change
Expand Up @@ -43,48 +43,54 @@ public override RecordBatch ReadNextRecordBatch()
{
ReadSchema();

if (_buffer.Length <= _bufferPosition + sizeof(int))
RecordBatch batch = null;
while (batch == null)
{
// reached the end
return null;
}

// Get Length of record batch for message header.
int messageLength = BinaryPrimitives.ReadInt32LittleEndian(_buffer.Span.Slice(_bufferPosition));
_bufferPosition += sizeof(int);

if (messageLength == 0)
{
//reached the end
return null;
}
else if (messageLength == MessageSerializer.IpcContinuationToken)
{
// ARROW-6313, if the first 4 bytes are continuation message, read the next 4 for the length
if (_buffer.Length <= _bufferPosition + sizeof(int))
{
throw new InvalidDataException("Corrupted IPC message. Received a continuation token at the end of the message.");
// reached the end
return null;
}

messageLength = BinaryPrimitives.ReadInt32LittleEndian(_buffer.Span.Slice(_bufferPosition));
// Get Length of record batch for message header.
int messageLength = BinaryPrimitives.ReadInt32LittleEndian(_buffer.Span.Slice(_bufferPosition));
_bufferPosition += sizeof(int);

if (messageLength == 0)
{
//reached the end
return null;
}
}
else if (messageLength == MessageSerializer.IpcContinuationToken)
{
// ARROW-6313, if the first 4 bytes are continuation message, read the next 4 for the length
if (_buffer.Length <= _bufferPosition + sizeof(int))
{
throw new InvalidDataException("Corrupted IPC message. Received a continuation token at the end of the message.");
}

messageLength = BinaryPrimitives.ReadInt32LittleEndian(_buffer.Span.Slice(_bufferPosition));
_bufferPosition += sizeof(int);

if (messageLength == 0)
{
//reached the end
return null;
}
}

Message message = Message.GetRootAsMessage(
CreateByteBuffer(_buffer.Slice(_bufferPosition, messageLength)));
_bufferPosition += messageLength;

Message message = Message.GetRootAsMessage(
CreateByteBuffer(_buffer.Slice(_bufferPosition, messageLength)));
_bufferPosition += messageLength;
int bodyLength = (int)message.BodyLength;
ByteBuffer bodybb = CreateByteBuffer(_buffer.Slice(_bufferPosition, bodyLength));
_bufferPosition += bodyLength;

int bodyLength = (int)message.BodyLength;
ByteBuffer bodybb = CreateByteBuffer(_buffer.Slice(_bufferPosition, bodyLength));
_bufferPosition += bodyLength;
batch = CreateArrowObjectFromMessage(message, bodybb, memoryOwner: null);
}

return CreateArrowObjectFromMessage(message, bodybb, memoryOwner: null);
return batch;
}

private void ReadSchema()
Expand Down
Loading
Loading