Skip to content

Commit

Permalink
Finished making dictionaries work in File and Memory implementations …
Browse files Browse the repository at this point in the history
…and updated tests.
  • Loading branch information
CurtHagenlocher committed Dec 8, 2023
1 parent 536a867 commit 3709500
Show file tree
Hide file tree
Showing 8 changed files with 209 additions and 116 deletions.
74 changes: 58 additions & 16 deletions csharp/src/Apache.Arrow/Ipc/ArrowFileReaderImplementation.cs
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ public ArrowFileReaderImplementation(Stream stream, MemoryAllocator allocator, I
{
}

public async ValueTask<int> RecordBatchCountAsync()
public async ValueTask<int> RecordBatchCountAsync(CancellationToken cancellationToken = default)
{
if (!HasReadSchema)
{
Expand Down Expand Up @@ -145,7 +145,7 @@ private void ReadSchema(Memory<byte> buffer)
public async ValueTask<RecordBatch> ReadRecordBatchAsync(int index, CancellationToken cancellationToken)
{
await ReadSchemaAsync().ConfigureAwait(false);
await ReadDictionariesAsync().ConfigureAwait(false);
await ReadDictionariesAsync(cancellationToken).ConfigureAwait(false);

if (index >= _footer.RecordBatchCount)
{
Expand Down Expand Up @@ -179,7 +179,7 @@ public RecordBatch ReadRecordBatch(int index)
public override async ValueTask<RecordBatch> ReadNextRecordBatchAsync(CancellationToken cancellationToken)
{
await ReadSchemaAsync().ConfigureAwait(false);
await ReadDictionariesAsync().ConfigureAwait(false);
await ReadDictionariesAsync(cancellationToken).ConfigureAwait(false);

if (_recordBatchIndex >= _footer.RecordBatchCount)
{
Expand Down Expand Up @@ -208,22 +208,43 @@ public override RecordBatch ReadNextRecordBatch()
return result;
}

private async ValueTask ReadDictionariesAsync()
private async ValueTask ReadDictionariesAsync(CancellationToken cancellationToken = default)
{
if (HasReadDictionaries)
{
return;
}

// We don't know in what order the dictionaries have been serialized, so we deserialize
// just their indices and then construct them in X order
foreach (Block block in _footer.Dictionaries)
int index = 0;
while (index < _footer.DictionaryCount)
{
BaseStream.Position = block.Offset;
await ReadRecordBatchAsync(deferDictionaryLoad: true).ConfigureAwait(false);
index = await ReadNextDictionaryAsync(index, cancellationToken).ConfigureAwait(false);
}
}

private async ValueTask<int> ReadNextDictionaryAsync(int index, CancellationToken cancellationToken)
{
Block block = _footer.Dictionaries[index++];
BaseStream.Position = block.Offset;
await ReadMessageAsync(async (message, cancellationToken) =>
{
if (message.HeaderType != Flatbuf.MessageHeader.DictionaryBatch)
{
return null;
}
Flatbuf.DictionaryBatch dictionaryBatch = message.Header<Flatbuf.DictionaryBatch>().Value;

DictionaryMemo.FinishLoad();
long position = BaseStream.Position;
while (!DictionaryMemo.CanLoad(dictionaryBatch.Id))
{
// recursive load
index = await ReadNextDictionaryAsync(index, cancellationToken);
}
BaseStream.Position = position;
return await CreateArrowObjectAsync(message, cancellationToken);
}, cancellationToken).ConfigureAwait(false);

return index;
}

private void ReadDictionaries()
Expand All @@ -233,15 +254,36 @@ private void ReadDictionaries()
return;
}

// We don't know in what order the dictionaries have been serialized, so we deserialize
// just their indices and then construct them in X order
foreach (Block block in _footer.Dictionaries)
int index = 0;
while (index < _footer.DictionaryCount)
{
BaseStream.Position = block.Offset;
ReadRecordBatch(deferDictionaryLoad: true);
index = ReadNextDictionary(index);
}
}

private int ReadNextDictionary(int index)
{
Block block = _footer.Dictionaries[index++];
BaseStream.Position = block.Offset;
ReadMessage(message =>
{
if (message.HeaderType != Flatbuf.MessageHeader.DictionaryBatch)
{
return null;
}
Flatbuf.DictionaryBatch dictionaryBatch = message.Header<Flatbuf.DictionaryBatch>().Value;

long position = BaseStream.Position;
while (!DictionaryMemo.CanLoad(dictionaryBatch.Id))
{
// recursive load
index = ReadNextDictionary(index);
}
BaseStream.Position = position;
return CreateArrowObject(message);
});

DictionaryMemo.FinishLoad();
return index;
}

/// <summary>
Expand Down
60 changes: 33 additions & 27 deletions csharp/src/Apache.Arrow/Ipc/ArrowMemoryReaderImplementation.cs
Original file line number Diff line number Diff line change
Expand Up @@ -43,48 +43,54 @@ public override RecordBatch ReadNextRecordBatch()
{
ReadSchema();

if (_buffer.Length <= _bufferPosition + sizeof(int))
RecordBatch batch = null;
while (batch == null)
{
// reached the end
return null;
}

// Get Length of record batch for message header.
int messageLength = BinaryPrimitives.ReadInt32LittleEndian(_buffer.Span.Slice(_bufferPosition));
_bufferPosition += sizeof(int);

if (messageLength == 0)
{
//reached the end
return null;
}
else if (messageLength == MessageSerializer.IpcContinuationToken)
{
// ARROW-6313, if the first 4 bytes are continuation message, read the next 4 for the length
if (_buffer.Length <= _bufferPosition + sizeof(int))
{
throw new InvalidDataException("Corrupted IPC message. Received a continuation token at the end of the message.");
// reached the end
return null;
}

messageLength = BinaryPrimitives.ReadInt32LittleEndian(_buffer.Span.Slice(_bufferPosition));
// Get Length of record batch for message header.
int messageLength = BinaryPrimitives.ReadInt32LittleEndian(_buffer.Span.Slice(_bufferPosition));
_bufferPosition += sizeof(int);

if (messageLength == 0)
{
//reached the end
return null;
}
}
else if (messageLength == MessageSerializer.IpcContinuationToken)
{
// ARROW-6313, if the first 4 bytes are continuation message, read the next 4 for the length
if (_buffer.Length <= _bufferPosition + sizeof(int))
{
throw new InvalidDataException("Corrupted IPC message. Received a continuation token at the end of the message.");
}

messageLength = BinaryPrimitives.ReadInt32LittleEndian(_buffer.Span.Slice(_bufferPosition));
_bufferPosition += sizeof(int);

if (messageLength == 0)
{
//reached the end
return null;
}
}

Message message = Message.GetRootAsMessage(
CreateByteBuffer(_buffer.Slice(_bufferPosition, messageLength)));
_bufferPosition += messageLength;

Message message = Message.GetRootAsMessage(
CreateByteBuffer(_buffer.Slice(_bufferPosition, messageLength)));
_bufferPosition += messageLength;
int bodyLength = (int)message.BodyLength;
ByteBuffer bodybb = CreateByteBuffer(_buffer.Slice(_bufferPosition, bodyLength));
_bufferPosition += bodyLength;

int bodyLength = (int)message.BodyLength;
ByteBuffer bodybb = CreateByteBuffer(_buffer.Slice(_bufferPosition, bodyLength));
_bufferPosition += bodyLength;
batch = CreateArrowObjectFromMessage(message, bodybb, memoryOwner: null);
}

return CreateArrowObjectFromMessage(message, bodybb, memoryOwner: null);
return batch;
}

private void ReadSchema()
Expand Down
49 changes: 19 additions & 30 deletions csharp/src/Apache.Arrow/Ipc/ArrowReaderImplementation.cs
Original file line number Diff line number Diff line change
Expand Up @@ -107,10 +107,7 @@ private static bool MatchEnum(Flatbuf.MessageHeader messageHeader, Type flatBuff
/// Null when the message type is not RecordBatch.
/// </returns>
protected RecordBatch CreateArrowObjectFromMessage(
Flatbuf.Message message,
ByteBuffer bodyByteBuffer,
IMemoryOwner<byte> memoryOwner,
bool deferDictionaryLoad = false)
Flatbuf.Message message, ByteBuffer bodyByteBuffer, IMemoryOwner<byte> memoryOwner)
{
switch (message.HeaderType)
{
Expand All @@ -119,11 +116,11 @@ protected RecordBatch CreateArrowObjectFromMessage(
break;
case Flatbuf.MessageHeader.DictionaryBatch:
Flatbuf.DictionaryBatch dictionaryBatch = message.Header<Flatbuf.DictionaryBatch>().Value;
ReadDictionaryBatch(message.Version, dictionaryBatch, bodyByteBuffer, memoryOwner, deferDictionaryLoad);
ReadDictionaryBatch(message.Version, dictionaryBatch, bodyByteBuffer, memoryOwner);
break;
case Flatbuf.MessageHeader.RecordBatch:
Flatbuf.RecordBatch rb = message.Header<Flatbuf.RecordBatch>().Value;
List<IArrowArray> arrays = BuildArrays(message.Version, Schema, bodyByteBuffer, rb, deferDictionaryLoad);
List<IArrowArray> arrays = BuildArrays(message.Version, Schema, bodyByteBuffer, rb);
return new RecordBatch(Schema, memoryOwner, arrays, (int)rb.Length);
default:
// NOTE: Skip unsupported message type
Expand All @@ -143,8 +140,7 @@ private void ReadDictionaryBatch(
MetadataVersion version,
Flatbuf.DictionaryBatch dictionaryBatch,
ByteBuffer bodyByteBuffer,
IMemoryOwner<byte> memoryOwner,
bool deferDictionaryLoad)
IMemoryOwner<byte> memoryOwner)
{
long id = dictionaryBatch.Id;
IArrowType valueType = DictionaryMemo.GetDictionaryType(id);
Expand All @@ -157,18 +153,14 @@ private void ReadDictionaryBatch(

Field valueField = new Field("dummy", valueType, true);
var schema = new Schema(new[] { valueField }, default);
IList<IArrowArray> arrays = BuildArrays(version, schema, bodyByteBuffer, recordBatch.Value, deferDictionaryLoad);
IList<IArrowArray> arrays = BuildArrays(version, schema, bodyByteBuffer, recordBatch.Value);

if (arrays.Count != 1)
{
throw new InvalidDataException("Dictionary record batch must contain only one field");
}

if (deferDictionaryLoad)
{
DictionaryMemo.AddDictionaryValues(id, arrays[0].Data);
}
else if (dictionaryBatch.IsDelta)
if (dictionaryBatch.IsDelta)
{
DictionaryMemo.AddDeltaDictionary(id, arrays[0], _allocator);
}
Expand All @@ -182,8 +174,7 @@ private List<IArrowArray> BuildArrays(
MetadataVersion version,
Schema schema,
ByteBuffer messageBuffer,
Flatbuf.RecordBatch recordBatchMessage,
bool deferDictionaryLoad)
Flatbuf.RecordBatch recordBatchMessage)
{
var arrays = new List<IArrowArray>(recordBatchMessage.NodesLength);

Expand All @@ -201,8 +192,8 @@ private List<IArrowArray> BuildArrays(
Flatbuf.FieldNode fieldNode = recordBatchEnumerator.CurrentNode;

ArrayData arrayData = field.DataType.IsFixedPrimitive()
? LoadPrimitiveField(version, ref recordBatchEnumerator, field, in fieldNode, messageBuffer, bufferCreator, deferDictionaryLoad)
: LoadVariableField(version, ref recordBatchEnumerator, field, in fieldNode, messageBuffer, bufferCreator, deferDictionaryLoad);
? LoadPrimitiveField(version, ref recordBatchEnumerator, field, in fieldNode, messageBuffer, bufferCreator)
: LoadVariableField(version, ref recordBatchEnumerator, field, in fieldNode, messageBuffer, bufferCreator);

arrays.Add(ArrowArrayFactory.BuildArray(arrayData));
} while (recordBatchEnumerator.MoveNextNode());
Expand Down Expand Up @@ -244,8 +235,7 @@ private ArrayData LoadPrimitiveField(
Field field,
in Flatbuf.FieldNode fieldNode,
ByteBuffer bodyData,
IBufferCreator bufferCreator,
bool deferDictionaryLoad)
IBufferCreator bufferCreator)
{

int fieldLength = (int)fieldNode.Length;
Expand Down Expand Up @@ -298,10 +288,10 @@ private ArrayData LoadPrimitiveField(
recordBatchEnumerator.MoveNextBuffer();
}

ArrayData[] children = GetChildren(version, ref recordBatchEnumerator, field, bodyData, bufferCreator, deferDictionaryLoad);
ArrayData[] children = GetChildren(version, ref recordBatchEnumerator, field, bodyData, bufferCreator);

IArrowArray dictionary = null;
if (field.DataType.TypeId == ArrowTypeId.Dictionary && !deferDictionaryLoad)
if (field.DataType.TypeId == ArrowTypeId.Dictionary)
{
long id = DictionaryMemo.GetId(field);
dictionary = DictionaryMemo.GetDictionary(id);
Expand All @@ -316,9 +306,9 @@ private ArrayData LoadVariableField(
Field field,
in Flatbuf.FieldNode fieldNode,
ByteBuffer bodyData,
IBufferCreator bufferCreator,
bool deferDictionaryLoad)
IBufferCreator bufferCreator)
{

ArrowBuffer nullArrowBuffer = BuildArrowBuffer(bodyData, recordBatchEnumerator.CurrentBuffer, bufferCreator);
if (!recordBatchEnumerator.MoveNextBuffer())
{
Expand Down Expand Up @@ -346,10 +336,10 @@ private ArrayData LoadVariableField(
}

ArrowBuffer[] arrowBuff = new[] { nullArrowBuffer, offsetArrowBuffer, valueArrowBuffer };
ArrayData[] children = GetChildren(version, ref recordBatchEnumerator, field, bodyData, bufferCreator, deferDictionaryLoad);
ArrayData[] children = GetChildren(version, ref recordBatchEnumerator, field, bodyData, bufferCreator);

IArrowArray dictionary = null;
if (field.DataType.TypeId == ArrowTypeId.Dictionary && !deferDictionaryLoad)
if (field.DataType.TypeId == ArrowTypeId.Dictionary)
{
long id = DictionaryMemo.GetId(field);
dictionary = DictionaryMemo.GetDictionary(id);
Expand All @@ -363,8 +353,7 @@ private ArrayData[] GetChildren(
ref RecordBatchEnumerator recordBatchEnumerator,
Field field,
ByteBuffer bodyData,
IBufferCreator bufferCreator,
bool deferDictionaryLoad)
IBufferCreator bufferCreator)
{
if (!(field.DataType is NestedType type)) return null;

Expand All @@ -377,8 +366,8 @@ private ArrayData[] GetChildren(

Field childField = type.Fields[index];
ArrayData child = childField.DataType.IsFixedPrimitive()
? LoadPrimitiveField(version, ref recordBatchEnumerator, childField, in childFieldNode, bodyData, bufferCreator, deferDictionaryLoad)
: LoadVariableField(version, ref recordBatchEnumerator, childField, in childFieldNode, bodyData, bufferCreator, deferDictionaryLoad);
? LoadPrimitiveField(version, ref recordBatchEnumerator, childField, in childFieldNode, bodyData, bufferCreator)
: LoadVariableField(version, ref recordBatchEnumerator, childField, in childFieldNode, bodyData, bufferCreator);

children[index] = child;
}
Expand Down
Loading

0 comments on commit 3709500

Please sign in to comment.