diff options
Diffstat (limited to 'src/arrow/csharp/src/Apache.Arrow/Ipc/ArrowReaderImplementation.cs')
-rw-r--r-- | src/arrow/csharp/src/Apache.Arrow/Ipc/ArrowReaderImplementation.cs | 357 |
1 files changed, 357 insertions, 0 deletions
diff --git a/src/arrow/csharp/src/Apache.Arrow/Ipc/ArrowReaderImplementation.cs b/src/arrow/csharp/src/Apache.Arrow/Ipc/ArrowReaderImplementation.cs new file mode 100644 index 000000000..35199477b --- /dev/null +++ b/src/arrow/csharp/src/Apache.Arrow/Ipc/ArrowReaderImplementation.cs @@ -0,0 +1,357 @@ +// Licensed to the Apache Software Foundation (ASF) under one or more +// contributor license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright ownership. +// The ASF licenses this file to You under the Apache License, Version 2.0 +// (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +using FlatBuffers; +using System; +using System.Buffers; +using System.Collections.Generic; +using System.Diagnostics; +using System.IO; +using System.Threading; +using System.Threading.Tasks; +using Apache.Arrow.Types; +using Apache.Arrow.Memory; + +namespace Apache.Arrow.Ipc +{ + internal abstract class ArrowReaderImplementation : IDisposable + { + public Schema Schema { get; protected set; } + protected bool HasReadSchema => Schema != null; + + private protected DictionaryMemo _dictionaryMemo; + private protected DictionaryMemo DictionaryMemo => _dictionaryMemo ??= new DictionaryMemo(); + private protected readonly MemoryAllocator _allocator; + + private protected ArrowReaderImplementation() : this(null) + { } + + private protected ArrowReaderImplementation(MemoryAllocator allocator) + { + _allocator = allocator ?? MemoryAllocator.Default.Value; + } + + public void Dispose() + { + Dispose(true); + GC.SuppressFinalize(this); + } + + protected virtual void Dispose(bool disposing) + { + } + + public abstract ValueTask<RecordBatch> ReadNextRecordBatchAsync(CancellationToken cancellationToken); + public abstract RecordBatch ReadNextRecordBatch(); + + internal static T ReadMessage<T>(ByteBuffer bb) + where T : struct, IFlatbufferObject + { + Type returnType = typeof(T); + Flatbuf.Message msg = Flatbuf.Message.GetRootAsMessage(bb); + + if (MatchEnum(msg.HeaderType, returnType)) + { + return msg.Header<T>().Value; + } + else + { + throw new Exception($"Requested type '{returnType.Name}' " + + $"did not match type found at offset => '{msg.HeaderType}'"); + } + } + + private static bool MatchEnum(Flatbuf.MessageHeader messageHeader, Type flatBuffType) + { + switch (messageHeader) + { + case Flatbuf.MessageHeader.RecordBatch: + return flatBuffType == typeof(Flatbuf.RecordBatch); + case Flatbuf.MessageHeader.DictionaryBatch: + return flatBuffType == typeof(Flatbuf.DictionaryBatch); + case Flatbuf.MessageHeader.Schema: + return flatBuffType == typeof(Flatbuf.Schema); + case Flatbuf.MessageHeader.Tensor: + return flatBuffType == typeof(Flatbuf.Tensor); + case Flatbuf.MessageHeader.NONE: + throw new ArgumentException("MessageHeader NONE has no matching flatbuf types", nameof(messageHeader)); + default: + throw new ArgumentException($"Unexpected MessageHeader value", nameof(messageHeader)); + } + } + + /// <summary> + /// Create a record batch or dictionary batch from Flatbuf.Message. + /// </summary> + /// <remarks> + /// This method adds data to _dictionaryMemo and returns null when the message type is DictionaryBatch. + /// </remarks>> + /// <returns> + /// The record batch when the message type is RecordBatch. + /// Null when the message type is not RecordBatch. + /// </returns> + protected RecordBatch CreateArrowObjectFromMessage( + Flatbuf.Message message, ByteBuffer bodyByteBuffer, IMemoryOwner<byte> memoryOwner) + { + switch (message.HeaderType) + { + case Flatbuf.MessageHeader.Schema: + // TODO: Read schema and verify equality? + break; + case Flatbuf.MessageHeader.DictionaryBatch: + Flatbuf.DictionaryBatch dictionaryBatch = message.Header<Flatbuf.DictionaryBatch>().Value; + ReadDictionaryBatch(dictionaryBatch, bodyByteBuffer, memoryOwner); + break; + case Flatbuf.MessageHeader.RecordBatch: + Flatbuf.RecordBatch rb = message.Header<Flatbuf.RecordBatch>().Value; + List<IArrowArray> arrays = BuildArrays(Schema, bodyByteBuffer, rb); + return new RecordBatch(Schema, memoryOwner, arrays, (int)rb.Length); + default: + // NOTE: Skip unsupported message type + Debug.WriteLine($"Skipping unsupported message type '{message.HeaderType}'"); + break; + } + + return null; + } + + internal static ByteBuffer CreateByteBuffer(ReadOnlyMemory<byte> buffer) + { + return new ByteBuffer(new ReadOnlyMemoryBufferAllocator(buffer), 0); + } + + private void ReadDictionaryBatch(Flatbuf.DictionaryBatch dictionaryBatch, ByteBuffer bodyByteBuffer, IMemoryOwner<byte> memoryOwner) + { + long id = dictionaryBatch.Id; + IArrowType valueType = DictionaryMemo.GetDictionaryType(id); + Flatbuf.RecordBatch? recordBatch = dictionaryBatch.Data; + + if (!recordBatch.HasValue) + { + throw new InvalidDataException("Dictionary must contain RecordBatch"); + } + + Field valueField = new Field("dummy", valueType, true); + var schema = new Schema(new[] { valueField }, default); + IList<IArrowArray> arrays = BuildArrays(schema, bodyByteBuffer, recordBatch.Value); + + if (arrays.Count != 1) + { + throw new InvalidDataException("Dictionary record batch must contain only one field"); + } + + if (dictionaryBatch.IsDelta) + { + DictionaryMemo.AddDeltaDictionary(id, arrays[0], _allocator); + } + else + { + DictionaryMemo.AddOrReplaceDictionary(id, arrays[0]); + } + } + + private List<IArrowArray> BuildArrays( + Schema schema, + ByteBuffer messageBuffer, + Flatbuf.RecordBatch recordBatchMessage) + { + var arrays = new List<IArrowArray>(recordBatchMessage.NodesLength); + + if (recordBatchMessage.NodesLength == 0) + { + return arrays; + } + + var recordBatchEnumerator = new RecordBatchEnumerator(in recordBatchMessage); + int schemaFieldIndex = 0; + do + { + Field field = schema.GetFieldByIndex(schemaFieldIndex++); + Flatbuf.FieldNode fieldNode = recordBatchEnumerator.CurrentNode; + + ArrayData arrayData = field.DataType.IsFixedPrimitive() + ? LoadPrimitiveField(ref recordBatchEnumerator, field, in fieldNode, messageBuffer) + : LoadVariableField(ref recordBatchEnumerator, field, in fieldNode, messageBuffer); + + arrays.Add(ArrowArrayFactory.BuildArray(arrayData)); + } while (recordBatchEnumerator.MoveNextNode()); + + return arrays; + } + + private ArrayData LoadPrimitiveField( + ref RecordBatchEnumerator recordBatchEnumerator, + Field field, + in Flatbuf.FieldNode fieldNode, + ByteBuffer bodyData) + { + + ArrowBuffer nullArrowBuffer = BuildArrowBuffer(bodyData, recordBatchEnumerator.CurrentBuffer); + if (!recordBatchEnumerator.MoveNextBuffer()) + { + throw new Exception("Unable to move to the next buffer."); + } + + int fieldLength = (int)fieldNode.Length; + int fieldNullCount = (int)fieldNode.NullCount; + + if (fieldLength < 0) + { + throw new InvalidDataException("Field length must be >= 0"); // TODO:Localize exception message + } + + if (fieldNullCount < 0) + { + throw new InvalidDataException("Null count length must be >= 0"); // TODO:Localize exception message + } + + ArrowBuffer[] arrowBuff; + if (field.DataType.TypeId == ArrowTypeId.Struct) + { + arrowBuff = new[] { nullArrowBuffer }; + } + else + { + ArrowBuffer valueArrowBuffer = BuildArrowBuffer(bodyData, recordBatchEnumerator.CurrentBuffer); + recordBatchEnumerator.MoveNextBuffer(); + + arrowBuff = new[] { nullArrowBuffer, valueArrowBuffer }; + } + + ArrayData[] children = GetChildren(ref recordBatchEnumerator, field, bodyData); + + IArrowArray dictionary = null; + if (field.DataType.TypeId == ArrowTypeId.Dictionary) + { + long id = DictionaryMemo.GetId(field); + dictionary = DictionaryMemo.GetDictionary(id); + } + + return new ArrayData(field.DataType, fieldLength, fieldNullCount, 0, arrowBuff, children, dictionary?.Data); + } + + private ArrayData LoadVariableField( + ref RecordBatchEnumerator recordBatchEnumerator, + Field field, + in Flatbuf.FieldNode fieldNode, + ByteBuffer bodyData) + { + + ArrowBuffer nullArrowBuffer = BuildArrowBuffer(bodyData, recordBatchEnumerator.CurrentBuffer); + if (!recordBatchEnumerator.MoveNextBuffer()) + { + throw new Exception("Unable to move to the next buffer."); + } + ArrowBuffer offsetArrowBuffer = BuildArrowBuffer(bodyData, recordBatchEnumerator.CurrentBuffer); + if (!recordBatchEnumerator.MoveNextBuffer()) + { + throw new Exception("Unable to move to the next buffer."); + } + ArrowBuffer valueArrowBuffer = BuildArrowBuffer(bodyData, recordBatchEnumerator.CurrentBuffer); + recordBatchEnumerator.MoveNextBuffer(); + + int fieldLength = (int)fieldNode.Length; + int fieldNullCount = (int)fieldNode.NullCount; + + if (fieldLength < 0) + { + throw new InvalidDataException("Field length must be >= 0"); // TODO: Localize exception message + } + + if (fieldNullCount < 0) + { + throw new InvalidDataException("Null count length must be >= 0"); //TODO: Localize exception message + } + + ArrowBuffer[] arrowBuff = new[] { nullArrowBuffer, offsetArrowBuffer, valueArrowBuffer }; + ArrayData[] children = GetChildren(ref recordBatchEnumerator, field, bodyData); + + IArrowArray dictionary = null; + if (field.DataType.TypeId == ArrowTypeId.Dictionary) + { + long id = DictionaryMemo.GetId(field); + dictionary = DictionaryMemo.GetDictionary(id); + } + + return new ArrayData(field.DataType, fieldLength, fieldNullCount, 0, arrowBuff, children, dictionary?.Data); + } + + private ArrayData[] GetChildren( + ref RecordBatchEnumerator recordBatchEnumerator, + Field field, + ByteBuffer bodyData) + { + if (!(field.DataType is NestedType type)) return null; + + int childrenCount = type.Fields.Count; + var children = new ArrayData[childrenCount]; + for (int index = 0; index < childrenCount; index++) + { + recordBatchEnumerator.MoveNextNode(); + Flatbuf.FieldNode childFieldNode = recordBatchEnumerator.CurrentNode; + + Field childField = type.Fields[index]; + ArrayData child = childField.DataType.IsFixedPrimitive() + ? LoadPrimitiveField(ref recordBatchEnumerator, childField, in childFieldNode, bodyData) + : LoadVariableField(ref recordBatchEnumerator, childField, in childFieldNode, bodyData); + + children[index] = child; + } + return children; + } + + private ArrowBuffer BuildArrowBuffer(ByteBuffer bodyData, Flatbuf.Buffer buffer) + { + if (buffer.Length <= 0) + { + return ArrowBuffer.Empty; + } + + int offset = (int)buffer.Offset; + int length = (int)buffer.Length; + + var data = bodyData.ToReadOnlyMemory(offset, length); + return new ArrowBuffer(data); + } + } + + internal struct RecordBatchEnumerator + { + private Flatbuf.RecordBatch RecordBatch { get; } + internal int CurrentBufferIndex { get; private set; } + internal int CurrentNodeIndex { get; private set; } + + internal Flatbuf.Buffer CurrentBuffer => RecordBatch.Buffers(CurrentBufferIndex).GetValueOrDefault(); + + internal Flatbuf.FieldNode CurrentNode => RecordBatch.Nodes(CurrentNodeIndex).GetValueOrDefault(); + + internal bool MoveNextBuffer() + { + return ++CurrentBufferIndex < RecordBatch.BuffersLength; + } + + internal bool MoveNextNode() + { + return ++CurrentNodeIndex < RecordBatch.NodesLength; + } + + internal RecordBatchEnumerator(in Flatbuf.RecordBatch recordBatch) + { + RecordBatch = recordBatch; + CurrentBufferIndex = 0; + CurrentNodeIndex = 0; + } + } +} |