From 8581bdfdd0f28e6fdcfd759ea5a5ccf2bec69464 Mon Sep 17 00:00:00 2001 From: Jim Pivarski Date: Tue, 3 Dec 2019 13:14:08 -0600 Subject: [PATCH] Add RecordArray (and Record) to Numba. (#26) RecordArray and Record objects can be used in Numba and `FillableArray.beginrecord`, `field`, `endrecord` has been extended to Numba. * Start adding RecordArray (and Record) to Numba. * Stub files for RecordArray in Numba. * Access a RecordArray's 'lookup' dict in Python. * Adding accessors to C++ that support Numba. * Record's first attribute should be named 'array', not 'recordarray'. * Record* made a round-trip to Numba. * Compute 'length' once in iteration. * Stubs for StringLiteral in slices. * RecordArray.getitem_range. * Finished 'getitem_str' for all array types. * Record.getitem_str. * Bring signatures up to date. * Finished Record getitems. * [skip ci] Started implementing getitem_tuple; broke everything. * Fixed it: RecordArray.getitem_next(anything but field) should work. * Finished RecordArray.getitem_next(*), need to do *.getitem_next(string). * Finished *.getitem_next(str). Moving on to FillableArray. * FillableArray.begintuple/index/endtuple works. * FillableArray.beginrecord/field/endrecord works. Probably done with this PR. --- README.md | 4 +- VERSION_INFO | 2 +- awkward1/_numba/__init__.py | 1 + awkward1/_numba/array/emptyarray.py | 3 + awkward1/_numba/array/listarray.py | 33 ++ awkward1/_numba/array/listoffsetarray.py | 32 ++ awkward1/_numba/array/numpyarray.py | 6 + awkward1/_numba/array/recordarray.py | 379 ++++++++++++++++++ awkward1/_numba/array/regulararray.py | 32 ++ awkward1/_numba/content.py | 12 +- awkward1/_numba/fillable.py | 104 +++++ awkward1/_numba/iterator.py | 20 +- awkward1/_numba/libawkward.py | 49 +++ awkward1/_numba/util.py | 28 +- awkward1/signatures/NumpyArray_8cpp.xml | 99 +---- awkward1/signatures/RecordArray_8cpp.xml | 54 +++ awkward1/signatures/Record_8cpp.xml | 48 +++ .../signatures/cpu-kernels_2util_8cpp.xml | 4 +- awkward1/signatures/getitem_8cpp.xml | 76 ++-- awkward1/signatures/identity_8cpp.xml | 249 +++++++++++- awkward1/signatures/libawkward_2util_8cpp.xml | 167 ++++++-- awkward1/util.py | 23 ++ include/awkward/array/Record.h | 21 +- include/awkward/array/RecordArray.h | 2 +- include/awkward/fillable/FillableArray.h | 8 +- src/libawkward/array/NumpyArray.cpp | 24 +- src/libawkward/array/Record.cpp | 56 +-- src/libawkward/array/RecordArray.cpp | 2 +- src/libawkward/fillable/FillableArray.cpp | 81 +++- src/pyawkward.cpp | 55 ++- tests/test_PR025_record_array.py | 9 +- tests/test_PR026_recordarray_in_numba.py | 225 +++++++++++ 32 files changed, 1672 insertions(+), 236 deletions(-) create mode 100644 awkward1/_numba/array/recordarray.py create mode 100644 awkward1/signatures/RecordArray_8cpp.xml create mode 100644 awkward1/signatures/Record_8cpp.xml create mode 100644 tests/test_PR026_recordarray_in_numba.py diff --git a/README.md b/README.md index 886b9358d5..020b28e32b 100644 --- a/README.md +++ b/README.md @@ -65,7 +65,7 @@ Completed items are ☑check-marked. See [closed PRs](https://github.com/scikit- * [X] Test all (tested in mock [studies/fillable.py](tree/master/studies/fillable.py)). * [X] JSON → Awkward via header-only [RapidJSON](https://rapidjson.org) and `awkward.fromiter`. * [ ] Explicit broadcasting functions for jagged and non-jagged arrays and scalars. - * [ ] Structure-preserving ufunc-like operation on the C++ side that applies a lambda function to inner data. The Python `__array_ufunc__` implementation will _call_ this to preserve structure. + * [ ] ~~Structure-preserving ufunc-like operation on the C++ side that applies a lambda function to inner data. The Python `__array_ufunc__` implementation will _call_ this to preserve structure.~~ * [ ] Extend `__getitem__` to take jagged arrays of integers and booleans (same behavior as old). * [ ] Full suite of array types: * [X] `EmptyArray`: 1-dimensional array with length 0 and unknown type (result of `UnknownFillable`, compatible with all types of arrays). @@ -75,7 +75,7 @@ Completed items are ☑check-marked. See [closed PRs](https://github.com/scikit- * [X] `ListOffsetArray`: the `JaggedArray` case with no unreachable data between reachable data (gaps). * [X] `RegularArray`: for building rectilinear, N-dimensional arrays of arbitrary contents, e.g. putting jagged dimensions inside fixed dimensions. * [X] `RecordArray`: the new `Table` _without_ lazy-slicing. - * [ ] Implement it in Numba as well. + * [X] Implement it in Numba as well. * [ ] `MaskedArray`, `BitMaskedArray`, `IndexedMaskedArray`: same as the old versions. * [ ] `UnionArray`: same as the old version; `SparseUnionArray`: the additional case found in Apache Arrow. * [ ] `IndexedArray`: same as the old version. diff --git a/VERSION_INFO b/VERSION_INFO index 0e7400f186..7db267292f 100644 --- a/VERSION_INFO +++ b/VERSION_INFO @@ -1 +1 @@ -0.1.25 +0.1.26 diff --git a/awkward1/_numba/__init__.py b/awkward1/_numba/__init__.py index e4037026dc..9e287e845a 100644 --- a/awkward1/_numba/__init__.py +++ b/awkward1/_numba/__init__.py @@ -18,3 +18,4 @@ import awkward1._numba.array.listoffsetarray import awkward1._numba.array.emptyarray import awkward1._numba.array.regulararray + import awkward1._numba.array.recordarray diff --git a/awkward1/_numba/array/emptyarray.py b/awkward1/_numba/array/emptyarray.py index 6c528e8fdc..2a5258451e 100644 --- a/awkward1/_numba/array/emptyarray.py +++ b/awkward1/_numba/array/emptyarray.py @@ -27,6 +27,9 @@ def getitem_int(self): def getitem_range(self): return self + def getitem_str(self): + raise IndexError("cannot slice EmptyArray with str (Record field name)") + def getitem_tuple(self, wheretpe): if len(wheretpe.types) == 0: return self diff --git a/awkward1/_numba/array/listarray.py b/awkward1/_numba/array/listarray.py index dcd8686ac2..8abe1a5759 100644 --- a/awkward1/_numba/array/listarray.py +++ b/awkward1/_numba/array/listarray.py @@ -46,6 +46,9 @@ def getitem_int(self): def getitem_range(self): return self + def getitem_str(self, key): + return ListArrayType(self.startstpe, self.stopstpe, self.contenttpe.getitem_str(key), self.idtpe) + def getitem_tuple(self, wheretpe): nexttpe = ListArrayType(util.index64tpe, util.index64tpe, self, numba.none) outtpe = nexttpe.getitem_next(wheretpe, False) @@ -65,6 +68,9 @@ def getitem_next(self, wheretpe, isadvanced): contenttpe = self.contenttpe.carry().getitem_next(tailtpe, isadvanced) return awkward1._numba.array.listoffsetarray.ListOffsetArrayType(util.indextpe(self.indexname), contenttpe, self.idtpe) + elif isinstance(headtpe, numba.types.StringLiteral): + return self.getitem_str(headtpe.literal_value).getitem_next(tailtpe, isadvanced) + elif isinstance(headtpe, numba.types.EllipsisType): raise NotImplementedError("ellipsis") @@ -102,6 +108,10 @@ def lower_getitem_int(self): def lower_getitem_range(self): return lower_getitem_range + @property + def lower_getitem_str(self): + return lower_getitem_str + @property def lower_getitem_next(self): return lower_getitem_next @@ -224,6 +234,24 @@ def lower_getitem_range(context, builder, sig, args): context.nrt.incref(builder, rettpe, out) return out +@numba.extending.lower_builtin(operator.getitem, ListArrayType, numba.types.StringLiteral) +def lower_getitem_str(context, builder, sig, args): + rettpe, (tpe, wheretpe) = sig.return_type, sig.args + val, whereval = args + + proxyin = numba.cgutils.create_struct_proxy(tpe)(context, builder, value=val) + proxyout = numba.cgutils.create_struct_proxy(rettpe)(context, builder) + proxyout.starts = proxyin.starts + proxyout.stops = proxyin.stops + proxyout.content = tpe.contenttpe.lower_getitem_str(context, builder, rettpe.contenttpe(tpe.contenttpe, wheretpe), (proxyin.content, whereval)) + if tpe.idtpe != numba.none: + proxyout.id = proxyin.id + + out = proxyout._getvalue() + if context.enable_nrt: + context.nrt.incref(builder, rettpe, out) + return out + @numba.extending.lower_builtin(operator.getitem, ListArrayType, numba.types.BaseTuple) def lower_getitem_tuple(context, builder, sig, args): return content.lower_getitem_tuple(context, builder, sig, args) @@ -364,6 +392,11 @@ def lower_getitem_next(context, builder, arraytpe, wheretpe, arrayval, whereval, proxyout.id = proxyin.id return proxyout._getvalue() + elif isinstance(headtpe, numba.types.StringLiteral): + nexttpe = arraytpe.getitem_str(headtpe.literal_value) + nextval = lower_getitem_str(context, builder, nexttpe(arraytpe, headtpe), (arrayval, headval)) + return lower_getitem_next(context, builder, nexttpe, tailtpe, nextval, tailval, advanced) + elif isinstance(headtpe, numba.types.EllipsisType): raise NotImplementedError("ListArray.getitem_next(ellipsis)") diff --git a/awkward1/_numba/array/listoffsetarray.py b/awkward1/_numba/array/listoffsetarray.py index 48388e32f4..b34d359619 100644 --- a/awkward1/_numba/array/listoffsetarray.py +++ b/awkward1/_numba/array/listoffsetarray.py @@ -40,6 +40,9 @@ def getitem_int(self): def getitem_range(self): return self + def getitem_str(self, key): + return ListOffsetArrayType(self.offsetstpe, self.contenttpe.getitem_str(key), self.idtpe) + def getitem_tuple(self, wheretpe): import awkward1._numba.array.listarray nexttpe = awkward1._numba.array.listarray.ListArrayType(util.index64tpe, util.index64tpe, self, numba.none) @@ -60,6 +63,9 @@ def getitem_next(self, wheretpe, isadvanced): contenttpe = self.contenttpe.carry().getitem_next(tailtpe, isadvanced) return ListOffsetArrayType(util.indextpe(self.indexname), contenttpe, self.idtpe) + elif isinstance(headtpe, numba.types.StringLiteral): + return self.getitem_str(headtpe.literal_value).getitem_next(tailtpe, isadvanced) + elif isinstance(headtpe, numba.types.EllipsisType): raise NotImplementedError("ellipsis") @@ -98,6 +104,10 @@ def lower_getitem_int(self): def lower_getitem_range(self): return lower_getitem_range + @property + def lower_getitem_str(self): + return lower_getitem_str + @property def lower_getitem_next(self): return lower_getitem_next @@ -222,6 +232,23 @@ def lower_getitem_range(context, builder, sig, args): context.nrt.incref(builder, rettpe, out) return out +@numba.extending.lower_builtin(operator.getitem, ListOffsetArrayType, numba.types.StringLiteral) +def lower_getitem_str(context, builder, sig, args): + rettpe, (tpe, wheretpe) = sig.return_type, sig.args + val, whereval = args + + proxyin = numba.cgutils.create_struct_proxy(tpe)(context, builder, value=val) + proxyout = numba.cgutils.create_struct_proxy(rettpe)(context, builder) + proxyout.offsets = proxyin.offsets + proxyout.content = tpe.contenttpe.lower_getitem_str(context, builder, rettpe.contenttpe(tpe.contenttpe, wheretpe), (proxyin.content, whereval)) + if tpe.idtpe != numba.none: + proxyout.id = proxyin.id + + out = proxyout._getvalue() + if context.enable_nrt: + context.nrt.incref(builder, rettpe, out) + return out + @numba.extending.lower_builtin(operator.getitem, ListOffsetArrayType, numba.types.BaseTuple) def lower_getitem_tuple(context, builder, sig, args): return content.lower_getitem_tuple(context, builder, sig, args) @@ -377,6 +404,11 @@ def lower_getitem_next(context, builder, arraytpe, wheretpe, arrayval, whereval, proxyout.id = proxyin.id return proxyout._getvalue() + elif isinstance(headtpe, numba.types.StringLiteral): + nexttpe = arraytpe.getitem_str(headtpe.literal_value) + nextval = lower_getitem_str(context, builder, nexttpe(arraytpe, headtpe), (arrayval, headval)) + return lower_getitem_next(context, builder, nexttpe, tailtpe, nextval, tailval, advanced) + elif isinstance(headtpe, numba.types.EllipsisType): raise NotImplementedError("ListOffsetArray.getitem_next(ellipsis)") diff --git a/awkward1/_numba/array/numpyarray.py b/awkward1/_numba/array/numpyarray.py index aabf275519..f479345e95 100644 --- a/awkward1/_numba/array/numpyarray.py +++ b/awkward1/_numba/array/numpyarray.py @@ -30,6 +30,9 @@ def getitem_int(self): def getitem_range(self): return self.getitem_tuple(numba.types.slice2_type) + def getitem_str(self): + raise IndexError("cannot slice NumpyArray with str (Record field name)") + def getitem_tuple(self, wheretpe): outtpe = numba.typing.arraydecl.get_array_index_type(self.arraytpe, wheretpe).result if isinstance(outtpe, numba.types.Array): @@ -40,6 +43,9 @@ def getitem_tuple(self, wheretpe): def getitem_next(self, wheretpe, isadvanced): if len(wheretpe.types) > self.arraytpe.ndim: raise IndexError("too many dimensions in slice") + if any(isinstance(x, numba.types.StringLiteral) for x in wheretpe): + raise IndexError("cannot slice NumpyArray with str (Record field name)") + if isadvanced: numreduce = sum(1 if isinstance(x, (numba.types.Integer, numba.types.Array)) else 0 for x in wheretpe.types) else: diff --git a/awkward1/_numba/array/recordarray.py b/awkward1/_numba/array/recordarray.py new file mode 100644 index 0000000000..172bf491ce --- /dev/null +++ b/awkward1/_numba/array/recordarray.py @@ -0,0 +1,379 @@ +# BSD 3-Clause License; see https://github.com/jpivarski/awkward-1.0/blob/master/LICENSE + +import operator + +import numpy +import numba +import numba.typing.arraydecl + +import awkward1.layout +from ..._numba import cpu, util, content + +@numba.extending.typeof_impl.register(awkward1.layout.RecordArray) +def typeof(val, c): + return RecordArrayType([numba.typeof(x) for x in val.values()], val.lookup, val.reverselookup, numba.typeof(val.id)) + +@numba.extending.typeof_impl.register(awkward1.layout.Record) +def typeof(val, c): + return RecordType(numba.typeof(val.array)) + +class RecordArrayType(content.ContentType): + def __init__(self, contenttpes, lookup, reverselookup, idtpe): + super(RecordArrayType, self).__init__(name="RecordArrayType([{}], {}, {}, id={})".format(", ".join(x.name for x in contenttpes), lookup, reverselookup, idtpe.name)) + self.contenttpes = contenttpes + self.lookup = lookup + self.reverselookup = reverselookup + self.idtpe = idtpe + + @property + def istuple(self): + return self.lookup is None + + @property + def numfields(self): + return len(self.contenttpes) + + @property + def ndim(self): + return 1 + + def getitem_int(self): + return RecordType(self) + + def getitem_range(self): + return self + + def getitem_str(self, key): + return self.contenttpes[awkward1.util.field2index(self.lookup, self.numfields, key)] + + def getitem_tuple(self, wheretpe): + import awkward1._numba.array.regulararray + nexttpe = awkward1._numba.array.regulararray.RegularArrayType(self, numba.none) + out = nexttpe.getitem_next(wheretpe, False) + return out.getitem_int() + + def getitem_next(self, wheretpe, isadvanced): + if len(wheretpe.types) == 0: + return self + headtpe = wheretpe.types[0] + tailtpe = numba.types.Tuple(wheretpe.types[1:]) + + if isinstance(headtpe, numba.types.StringLiteral): + index = awkward1.util.field2index(self.lookup, self.numfields, headtpe.literal_value) + nexttpe = self.contenttpes[index] + + else: + contenttpes = [] + for t in self.contenttpes: + contenttpes.append(t.getitem_next(numba.types.Tuple((headtpe,)), isadvanced)) + nexttpe = RecordArrayType(contenttpes, self.lookup, self.reverselookup, numba.none) + + return nexttpe.getitem_next(tailtpe, isadvanced) + + def carry(self): + return RecordArrayType([x.carry() for x in self.contenttpes], self.lookup, self.reverselookup, self.idtpe) + + @property + def lower_len(self): + return lower_len + + @property + def lower_getitem_nothing(self): + return content.lower_getitem_nothing + + @property + def lower_getitem_int(self): + return lower_getitem_int + + @property + def lower_getitem_range(self): + return lower_getitem_range + + @property + def lower_getitem_str(self): + return lower_getitem_str + + @property + def lower_getitem_next(self): + return lower_getitem_next + + @property + def lower_carry(self): + return lower_carry + +class RecordType(numba.types.Type): + def __init__(self, arraytpe): + self.arraytpe = arraytpe + super(RecordType, self).__init__("Record({})".format(self.arraytpe.name)) + assert isinstance(arraytpe, RecordArrayType) + + @property + def istuple(self): + return self.arraytpe.istuple + + def getitem_str(self, key): + outtpe = self.arraytpe.getitem_str(key) + return outtpe.getitem_int() + + def getitem_tuple(self, wheretpe): + nextwheretpe = numba.types.Tuple((numba.int64,) + wheretpe.types) + return self.arraytpe.getitem_tuple(nextwheretpe) + +@numba.typing.templates.infer_global(operator.getitem) +class type_getitem_record(numba.typing.templates.AbstractTemplate): + def generic(self, args, kwargs): + if len(args) == 2 and len(kwargs) == 0: + tpe, wheretpe = args + + if isinstance(tpe, RecordType): + original_wheretpe = wheretpe + if isinstance(wheretpe, numba.types.Integer): + raise TypeError("Record[int]") + if isinstance(wheretpe, numba.types.SliceType): + raise TypeError("Record[slice]") + if isinstance(wheretpe, numba.types.StringLiteral): + return numba.typing.templates.signature(tpe.getitem_str(wheretpe.literal_value), tpe, original_wheretpe) + + if not isinstance(wheretpe, numba.types.BaseTuple): + wheretpe = numba.types.Tuple((wheretpe,)) + + wheretpe = util.typing_regularize_slice(wheretpe) + content.type_getitem.check_slice_types(wheretpe) + + return numba.typing.templates.signature(tpe.getitem_tuple(wheretpe), tpe, original_wheretpe) + +def field(i): + return "f" + str(i) + +@numba.extending.register_model(RecordArrayType) +class RecordArrayModel(numba.datamodel.models.StructModel): + def __init__(self, dmm, fe_type): + members = [("length", numba.int64)] + for i, x in enumerate(fe_type.contenttpes): + members.append((field(i), x)) + if fe_type.idtpe != numba.none: + members.append(("id", fe_type.idtpe)) + super(RecordArrayModel, self).__init__(dmm, fe_type, members) + +@numba.datamodel.registry.register_default(RecordType) +class RecordModel(numba.datamodel.models.StructModel): + def __init__(self, dmm, fe_type): + members = [("array", fe_type.arraytpe), + ("at", numba.int64)] + super(RecordModel, self).__init__(dmm, fe_type, members) + +@numba.extending.unbox(RecordArrayType) +def unbox(tpe, obj, c): + len_obj = c.pyapi.unserialize(c.pyapi.serialize_object(len)) + length_obj = c.pyapi.call_function_objargs(len_obj, (obj,)) + proxyout = numba.cgutils.create_struct_proxy(tpe)(c.context, c.builder) + proxyout.length = c.pyapi.to_native_value(numba.int64, length_obj).value + c.pyapi.decref(len_obj) + c.pyapi.decref(length_obj) + field_obj = c.pyapi.object_getattr_string(obj, "field") + for i, t in enumerate(tpe.contenttpes): + i_obj = c.pyapi.long_from_longlong(c.context.get_constant(numba.int64, i)) + x_obj = c.pyapi.call_function_objargs(field_obj, (i_obj,)) + setattr(proxyout, field(i), c.pyapi.to_native_value(t, x_obj).value) + c.pyapi.decref(i_obj) + c.pyapi.decref(x_obj) + c.pyapi.decref(field_obj) + if tpe.idtpe != numba.none: + id_obj = c.pyapi.object_getattr_string(obj, "id") + proxyout.id = c.pyapi.to_native_value(tpe.idtpe, id_obj).value + c.pyapi.decref(id_obj) + is_error = numba.cgutils.is_not_null(c.builder, c.pyapi.err_occurred()) + return numba.extending.NativeValue(proxyout._getvalue(), is_error) + +@numba.extending.unbox(RecordType) +def unbox_record(tpe, obj, c): + array_obj = c.pyapi.object_getattr_string(obj, "array") + at_obj = c.pyapi.object_getattr_string(obj, "at") + proxyout = numba.cgutils.create_struct_proxy(tpe)(c.context, c.builder) + proxyout.array = c.pyapi.to_native_value(tpe.arraytpe, array_obj).value + proxyout.at = c.pyapi.to_native_value(numba.int64, at_obj).value + if c.context.enable_nrt: + c.context.nrt.incref(c.builder, tpe.arraytpe, proxyout.array) + c.pyapi.decref(array_obj) + c.pyapi.decref(at_obj) + is_error = numba.cgutils.is_not_null(c.builder, c.pyapi.err_occurred()) + return numba.extending.NativeValue(proxyout._getvalue(), is_error) + +@numba.extending.box(RecordArrayType) +def box(tpe, val, c): + RecordArray_obj = c.pyapi.unserialize(c.pyapi.serialize_object(awkward1.layout.RecordArray)) + istuple_obj = c.pyapi.unserialize(c.pyapi.serialize_object(tpe.istuple)) + proxyin = numba.cgutils.create_struct_proxy(tpe)(c.context, c.builder, value=val) + length_obj = c.pyapi.long_from_longlong(proxyin.length) + if tpe.idtpe != numba.none: + id_obj = c.pyapi.from_native_value(tpe.idtpe, proxyin.id, c.env_manager) + out = c.pyapi.call_function_objargs(RecordArray_obj, (length_obj, istuple_obj, id_obj)) + c.pyapi.decref(id_obj) + else: + out = c.pyapi.call_function_objargs(RecordArray_obj, (length_obj, istuple_obj)) + append_obj = c.pyapi.object_getattr_string(out, "append") + for i, t in enumerate(tpe.contenttpes): + x_obj = c.pyapi.from_native_value(t, getattr(proxyin, field(i)), c.env_manager) + if tpe.reverselookup is None or len(tpe.reverselookup) <= i: + c.pyapi.call_function_objargs(append_obj, (x_obj,)) + else: + key_obj = c.pyapi.unserialize(c.pyapi.serialize_object(tpe.reverselookup[i])) + c.pyapi.call_function_objargs(append_obj, (x_obj, key_obj)) + c.pyapi.decref(key_obj) + c.pyapi.decref(x_obj) + c.pyapi.decref(RecordArray_obj) + c.pyapi.decref(istuple_obj) + c.pyapi.decref(length_obj) + c.pyapi.decref(append_obj) + return out + +@numba.extending.box(RecordType) +def box(tpe, val, c): + Record_obj = c.pyapi.unserialize(c.pyapi.serialize_object(awkward1.layout.Record)) + proxyin = numba.cgutils.create_struct_proxy(tpe)(c.context, c.builder, value=val) + array_obj = c.pyapi.from_native_value(tpe.arraytpe, proxyin.array, c.env_manager) + at_obj = c.pyapi.from_native_value(numba.int64, proxyin.at, c.env_manager) + out = c.pyapi.call_function_objargs(Record_obj, (array_obj, at_obj)) + c.pyapi.decref(Record_obj) + c.pyapi.decref(array_obj) + c.pyapi.decref(at_obj) + return out + +@numba.extending.lower_builtin(len, RecordArrayType) +def lower_len(context, builder, sig, args): + rettpe, (tpe,) = sig.return_type, sig.args + val, = args + proxyin = numba.cgutils.create_struct_proxy(tpe)(context, builder, value=val) + return util.cast(context, builder, numba.int64, numba.intp, proxyin.length) + +@numba.extending.lower_builtin(operator.getitem, RecordArrayType, numba.types.Integer) +def lower_getitem_int(context, builder, sig, args): + rettpe, (tpe, wheretpe) = sig.return_type, sig.args + val, whereval = args + proxyin = numba.cgutils.create_struct_proxy(tpe)(context, builder, value=val) + proxyout = numba.cgutils.create_struct_proxy(rettpe)(context, builder) + proxyout.array = val + proxyout.at = util.cast(context, builder, wheretpe, numba.int64, whereval) + if context.enable_nrt: + context.nrt.incref(builder, tpe, val) + return numba.targets.imputils.impl_ret_new_ref(context, builder, rettpe, proxyout._getvalue()) + +@numba.extending.lower_builtin(operator.getitem, RecordArrayType, numba.types.slice2_type) +def lower_getitem_range(context, builder, sig, args): + import awkward1._numba.identity + + rettpe, (tpe, wheretpe) = sig.return_type, sig.args + val, whereval = args + + proxyin = numba.cgutils.create_struct_proxy(tpe)(context, builder, value=val) + + proxyslicein = context.make_helper(builder, wheretpe, value=whereval) + numba.targets.slicing.guard_invalid_slice(context, builder, wheretpe, proxyslicein) + numba.targets.slicing.fix_slice(builder, proxyslicein, util.cast(context, builder, numba.int64, numba.intp, proxyin.length)) + proxysliceout = numba.cgutils.create_struct_proxy(numba.types.slice2_type)(context, builder) + proxysliceout.start = proxyslicein.start + proxysliceout.stop = proxyslicein.stop + proxysliceout.step = proxyslicein.step + sliceout = proxysliceout._getvalue() + + proxyout = numba.cgutils.create_struct_proxy(tpe)(context, builder) + proxyout.length = util.cast(context, builder, numba.intp, numba.int64, builder.sub(proxyslicein.stop, proxyslicein.start)) + for i, t in enumerate(tpe.contenttpes): + setattr(proxyout, field(i), t.lower_getitem_range(context, builder, t.getitem_range()(t, numba.types.slice2_type), (getattr(proxyin, field(i)), sliceout))) + if tpe.idtpe != numba.none: + proxyout.id = awkward1._numba.identity.lower_getitem_any(context, builder, tpe.idtpe, wheretpe, proxyin.id, whereval) + + out = proxyout._getvalue() + if context.enable_nrt: + context.nrt.incref(builder, rettpe, out) + return out + +@numba.extending.lower_builtin(operator.getitem, RecordArrayType, numba.types.StringLiteral) +def lower_getitem_str(context, builder, sig, args): + rettpe, (tpe, wheretpe) = sig.return_type, sig.args + val, whereval = args + index = awkward1.util.field2index(tpe.lookup, tpe.numfields, wheretpe.literal_value) + + proxyin = numba.cgutils.create_struct_proxy(tpe)(context, builder, value=val) + + out = getattr(proxyin, field(index)) + if context.enable_nrt: + context.nrt.incref(builder, rettpe, out) + return out + +@numba.extending.lower_builtin(operator.getitem, RecordType, numba.types.StringLiteral) +def lower_getitem_str_record(context, builder, sig, args): + rettpe, (tpe, wheretpe) = sig.return_type, sig.args + val, whereval = args + + proxyin = numba.cgutils.create_struct_proxy(tpe)(context, builder, value=val) + + outtpe = tpe.arraytpe.getitem_str(wheretpe.literal_value) + outval = lower_getitem_str(context, builder, outtpe(tpe.arraytpe, wheretpe), (proxyin.array, whereval)) + return outtpe.lower_getitem_int(context, builder, rettpe(outtpe, numba.int64), (outval, proxyin.at)) + +@numba.extending.lower_builtin(operator.getitem, RecordArrayType, numba.types.BaseTuple) +def lower_getitem_tuple(context, builder, sig, args): + return content.lower_getitem_tuple(context, builder, sig, args) + +@numba.extending.lower_builtin(operator.getitem, RecordType, numba.types.BaseTuple) +def lower_getitem_tuple_record(context, builder, sig, args): + rettpe, (tpe, wheretpe) = sig.return_type, sig.args + val, whereval = args + + proxyin = numba.cgutils.create_struct_proxy(tpe)(context, builder, value=val) + + nextwheretpe = numba.types.Tuple((numba.int64,) + wheretpe.types) + nextwhereval = context.make_tuple(builder, nextwheretpe, (proxyin.at,) + numba.cgutils.unpack_tuple(builder, whereval)) + + return lower_getitem_tuple(context, builder, rettpe(tpe.arraytpe, nextwheretpe), (proxyin.array, nextwhereval)) + +@numba.extending.lower_builtin(operator.getitem, RecordArrayType, numba.types.Array) +@numba.extending.lower_builtin(operator.getitem, RecordArrayType, numba.types.List) +@numba.extending.lower_builtin(operator.getitem, RecordArrayType, numba.types.ArrayCompatible) +@numba.extending.lower_builtin(operator.getitem, RecordArrayType, numba.types.EllipsisType) +@numba.extending.lower_builtin(operator.getitem, RecordArrayType, type(numba.typeof(numpy.newaxis))) +def lower_getitem_other(context, builder, sig, args): + return content.lower_getitem_other(context, builder, sig, args) + +def lower_getitem_next(context, builder, arraytpe, wheretpe, arrayval, whereval, advanced): + if len(wheretpe.types) == 0: + return arrayval + + headtpe = wheretpe.types[0] + tailtpe = numba.types.Tuple(wheretpe.types[1:]) + headval = numba.cgutils.unpack_tuple(builder, whereval)[0] + tailval = context.make_tuple(builder, tailtpe, numba.cgutils.unpack_tuple(builder, whereval)[1:]) + + proxyin = numba.cgutils.create_struct_proxy(arraytpe)(context, builder, value=arrayval) + + if isinstance(headtpe, numba.types.StringLiteral): + index = awkward1.util.field2index(arraytpe.lookup, arraytpe.numfields, headtpe.literal_value) + nexttpe = arraytpe.contenttpes[index] + nextval = getattr(proxyin, field(index)) + + else: + nexttpe = RecordArrayType([t.getitem_next(numba.types.Tuple((headtpe,)), advanced is not None) for t in arraytpe.contenttpes], arraytpe.lookup, arraytpe.reverselookup, numba.none) + proxyout = numba.cgutils.create_struct_proxy(nexttpe)(context, builder) + proxyout.length = proxyin.length + wrappedheadtpe = numba.types.Tuple((headtpe,)) + wrappedheadval = context.make_tuple(builder, wrappedheadtpe, (headval,)) + + for i, t in enumerate(arraytpe.contenttpes): + setattr(proxyout, field(i), t.lower_getitem_next(context, builder, t, wrappedheadtpe, getattr(proxyin, field(i)), wrappedheadval, advanced)) + nextval = proxyout._getvalue() + + rettpe = nexttpe.getitem_next(tailtpe, advanced is not None) + return rettpe.lower_getitem_next(context, builder, nexttpe, tailtpe, nextval, tailval, advanced) + +def lower_carry(context, builder, arraytpe, carrytpe, arrayval, carryval): + import awkward1._numba.identity + rettpe = arraytpe.carry() + proxyin = numba.cgutils.create_struct_proxy(arraytpe)(context, builder, value=arrayval) + proxyout = numba.cgutils.create_struct_proxy(rettpe)(context, builder) + proxyout.length = util.arraylen(context, builder, carrytpe, carryval, totpe=numba.int64) + for i, t in enumerate(arraytpe.contenttpes): + setattr(proxyout, field(i), t.lower_carry(context, builder, t, carrytpe, getattr(proxyin, field(i)), carryval)) + if rettpe.idtpe != numba.none: + proxyout.id = awkward1._numba.identity.lower_getitem_any(context, builder, rettpe.idtpe, carrytpe, proxyin.id, carryval) + return proxyout._getvalue() diff --git a/awkward1/_numba/array/regulararray.py b/awkward1/_numba/array/regulararray.py index 03ab92422b..1a096deb71 100644 --- a/awkward1/_numba/array/regulararray.py +++ b/awkward1/_numba/array/regulararray.py @@ -29,6 +29,9 @@ def getitem_int(self): def getitem_range(self): return self + def getitem_str(self, key): + return RegularArrayType(self.contenttpe.getitem_str(key), self.idtpe) + def getitem_tuple(self, wheretpe): nexttpe = RegularArrayType(self, numba.none) out = nexttpe.getitem_next(wheretpe, False) @@ -47,6 +50,9 @@ def getitem_next(self, wheretpe, isadvanced): contenttpe = self.contenttpe.carry().getitem_next(tailtpe, isadvanced) return RegularArrayType(contenttpe, self.idtpe) + elif isinstance(headtpe, numba.types.StringLiteral): + return self.getitem_str(headtpe.literal_value).getitem_next(tailtpe, isadvanced) + elif isinstance(headtpe, numba.types.EllipsisType): raise NotImplementedError("ellipsis") @@ -84,6 +90,10 @@ def lower_getitem_int(self): def lower_getitem_range(self): return lower_getitem_range + @property + def lower_getitem_str(self): + return lower_getitem_str + @property def lower_getitem_next(self): return lower_getitem_next @@ -190,6 +200,23 @@ def lower_getitem_range(context, builder, sig, args): context.nrt.incref(builder, rettpe, out) return out +@numba.extending.lower_builtin(operator.getitem, RegularArrayType, numba.types.StringLiteral) +def lower_getitem_str(context, builder, sig, args): + rettpe, (tpe, wheretpe) = sig.return_type, sig.args + val, whereval = args + + proxyin = numba.cgutils.create_struct_proxy(tpe)(context, builder, value=val) + proxyout = numba.cgutils.create_struct_proxy(rettpe)(context, builder) + proxyout.size = proxyin.size + proxyout.content = tpe.contenttpe.lower_getitem_str(context, builder, rettpe.contenttpe(tpe.contenttpe, wheretpe), (proxyin.content, whereval)) + if tpe.idtpe != numba.none: + proxyout.id = proxyin.id + + out = proxyout._getvalue() + if context.enable_nrt: + context.nrt.incref(builder, rettpe, out) + return out + @numba.extending.lower_builtin(operator.getitem, RegularArrayType, numba.types.BaseTuple) def lower_getitem_tuple(context, builder, sig, args): return content.lower_getitem_tuple(context, builder, sig, args) @@ -270,6 +297,11 @@ def lower_getitem_next(context, builder, arraytpe, wheretpe, arrayval, whereval, proxyout.id = proxyin.id return proxyout._getvalue() + elif isinstance(headtpe, numba.types.StringLiteral): + nexttpe = arraytpe.getitem_str(headtpe.literal_value) + nextval = lower_getitem_str(context, builder, nexttpe(arraytpe, headtpe), (arrayval, headval)) + return lower_getitem_next(context, builder, nexttpe, tailtpe, nextval, tailval, advanced) + elif isinstance(headtpe, numba.types.EllipsisType): raise NotImplementedError("RegularArray.getitem_next(ellipsis)") diff --git a/awkward1/_numba/content.py b/awkward1/_numba/content.py index bc93005fb3..008c07c2e6 100644 --- a/awkward1/_numba/content.py +++ b/awkward1/_numba/content.py @@ -32,17 +32,23 @@ def generic(self, args, kwargs): return numba.typing.templates.signature(arraytpe.getitem_int(), arraytpe, original_wheretpe) if isinstance(wheretpe, numba.types.SliceType) and not wheretpe.has_step: return numba.typing.templates.signature(arraytpe.getitem_range(), arraytpe, original_wheretpe) + if isinstance(wheretpe, numba.types.StringLiteral): + return numba.typing.templates.signature(arraytpe.getitem_str(wheretpe.literal_value), arraytpe, original_wheretpe) if not isinstance(wheretpe, numba.types.BaseTuple): wheretpe = numba.types.Tuple((wheretpe,)) wheretpe = util.typing_regularize_slice(wheretpe) - - if any(not isinstance(t, (numba.types.Integer, numba.types.SliceType, numba.types.EllipsisType, type(numba.typeof(numpy.newaxis)), numba.types.Array)) for t in wheretpe.types): - raise TypeError("only integers, slices (`:`), ellipsis (`...`), numpy.newaxis (`None`), and integer or boolean arrays (possibly jagged) are valid indices") + self.check_slice_types(wheretpe) return numba.typing.templates.signature(arraytpe.getitem_tuple(wheretpe), arraytpe, original_wheretpe) + @staticmethod + def check_slice_types(wheretpe): + if any(not isinstance(t, (numba.types.Integer, numba.types.SliceType, numba.types.EllipsisType, type(numba.typeof(numpy.newaxis)), numba.types.StringLiteral)) and not (isinstance(t, numba.types.Array) and isinstance(t.dtype, (numba.types.Boolean, numba.types.Integer))) for t in wheretpe.types): + raise TypeError("only integers, slices (`:`), ellipsis (`...`), numpy.newaxis (`None`), integer or boolean arrays (possibly jagged), and constant strings (known at compile-time) are valid indices") + + def lower_getitem_nothing(context, builder, tpe, val): proxyin = numba.cgutils.create_struct_proxy(tpe)(context, builder, value=val) proxyslice = numba.cgutils.create_struct_proxy(numba.types.slice2_type)(context, builder) diff --git a/awkward1/_numba/fillable.py b/awkward1/_numba/fillable.py index 89e5d965fc..9f1c16a09c 100644 --- a/awkward1/_numba/fillable.py +++ b/awkward1/_numba/fillable.py @@ -114,6 +114,50 @@ def resolve_endlist(self, arraytpe, args, kwargs): else: raise TypeError("wrong number of arguments for FillableArray.endlist") + @numba.typing.templates.bound_function("begintuple") + def resolve_begintuple(self, arraytpe, args, kwargs): + if len(args) == 1 and len(kwargs) == 0 and isinstance(args[0], numba.types.Integer): + return numba.types.none(args[0]) + else: + raise TypeError("wrong number of arguments for FillableArray.begintuple") + + @numba.typing.templates.bound_function("index") + def resolve_index(self, arraytpe, args, kwargs): + if len(args) == 1 and len(kwargs) == 0 and isinstance(args[0], numba.types.Integer): + return numba.types.none(args[0]) + else: + raise TypeError("wrong number of arguments for FillableArray.index") + + @numba.typing.templates.bound_function("endtuple") + def resolve_endtuple(self, arraytpe, args, kwargs): + if len(args) == 0 and len(kwargs) == 0: + return numba.types.none() + else: + raise TypeError("wrong number of arguments for FillableArray.endtuple") + + @numba.typing.templates.bound_function("beginrecord") + def resolve_beginrecord(self, arraytpe, args, kwargs): + if len(args) == 0 and len(kwargs) == 0: + return numba.types.none() + elif len(args) == 1 and len(kwargs) == 0 and isinstance(args[0], numba.types.StringLiteral): + return numba.types.none(args[0]) + else: + raise TypeError("wrong number of arguments for FillableArray.beginrecord") + + @numba.typing.templates.bound_function("field") + def resolve_field(self, arraytpe, args, kwargs): + if len(args) == 1 and len(kwargs) == 0 and isinstance(args[0], numba.types.StringLiteral): + return numba.types.none(args[0]) + else: + raise TypeError("wrong number of arguments for FillableArray.field") + + @numba.typing.templates.bound_function("endrecord") + def resolve_endrecord(self, arraytpe, args, kwargs): + if len(args) == 0 and len(kwargs) == 0: + return numba.types.none() + else: + raise TypeError("wrong number of arguments for FillableArray.endrecord") + @numba.extending.lower_builtin("clear", FillableArrayType) def lower_clear(context, builder, sig, args): tpe, = sig.args @@ -182,3 +226,63 @@ def lower_endlist(context, builder, sig, args): proxyin = numba.cgutils.create_struct_proxy(tpe)(context, builder, value=val) call(context, builder, libawkward.FillableArray_endlist, (proxyin.rawptr,)) return context.get_dummy_value() + +@numba.extending.lower_builtin("begintuple", FillableArrayType, numba.types.Integer) +def lower_begintuple(context, builder, sig, args): + tpe, numfieldstpe = sig.args + val, numfieldsval = args + proxyin = numba.cgutils.create_struct_proxy(tpe)(context, builder, value=val) + numfields = util.cast(context, builder, numfieldstpe, numba.int64, numfieldsval) + call(context, builder, libawkward.FillableArray_begintuple, (proxyin.rawptr, numfields)) + return context.get_dummy_value() + +@numba.extending.lower_builtin("index", FillableArrayType, numba.types.Integer) +def lower_index(context, builder, sig, args): + tpe, indextpe = sig.args + val, indexval = args + proxyin = numba.cgutils.create_struct_proxy(tpe)(context, builder, value=val) + index = util.cast(context, builder, indextpe, numba.int64, indexval) + call(context, builder, libawkward.FillableArray_index, (proxyin.rawptr, index)) + return context.get_dummy_value() + +@numba.extending.lower_builtin("endtuple", FillableArrayType) +def lower_endtuple(context, builder, sig, args): + tpe, = sig.args + val, = args + proxyin = numba.cgutils.create_struct_proxy(tpe)(context, builder, value=val) + call(context, builder, libawkward.FillableArray_endtuple, (proxyin.rawptr,)) + return context.get_dummy_value() + +@numba.extending.lower_builtin("beginrecord", FillableArrayType) +def lower_beginrecord(context, builder, sig, args): + tpe, = sig.args + val, = args + proxyin = numba.cgutils.create_struct_proxy(tpe)(context, builder, value=val) + call(context, builder, libawkward.FillableArray_beginrecord, (proxyin.rawptr, context.get_constant(numba.int64, 0))) + return context.get_dummy_value() + +@numba.extending.lower_builtin("beginrecord", FillableArrayType, numba.types.StringLiteral) +def lower_beginrecord(context, builder, sig, args): + tpe, nametpe = sig.args + val, nameval = args + proxyin = numba.cgutils.create_struct_proxy(tpe)(context, builder, value=val) + name = util.globalstring(context, builder, nametpe.literal_value, inttype=numba.int64) + call(context, builder, libawkward.FillableArray_beginrecord, (proxyin.rawptr, name)) + return context.get_dummy_value() + +@numba.extending.lower_builtin("field", FillableArrayType, numba.types.StringLiteral) +def lower_field(context, builder, sig, args): + tpe, keytpe = sig.args + val, keyval = args + proxyin = numba.cgutils.create_struct_proxy(tpe)(context, builder, value=val) + key = util.globalstring(context, builder, keytpe.literal_value) + call(context, builder, libawkward.FillableArray_field_fast, (proxyin.rawptr, key)) + return context.get_dummy_value() + +@numba.extending.lower_builtin("endrecord", FillableArrayType) +def lower_endrecord(context, builder, sig, args): + tpe, = sig.args + val, = args + proxyin = numba.cgutils.create_struct_proxy(tpe)(context, builder, value=val) + call(context, builder, libawkward.FillableArray_endrecord, (proxyin.rawptr,)) + return context.get_dummy_value() diff --git a/awkward1/_numba/iterator.py b/awkward1/_numba/iterator.py index b9a4b03f1a..2dbcc7bbf6 100644 --- a/awkward1/_numba/iterator.py +++ b/awkward1/_numba/iterator.py @@ -8,7 +8,7 @@ class IteratorType(numba.types.common.SimpleIteratorType): def __init__(self, arraytpe): self.arraytpe = arraytpe - super(IteratorType, self).__init__("iter({0})".format(self.arraytpe.name), self.arraytpe.getitem_int()) + super(IteratorType, self).__init__("iter({})".format(self.arraytpe.name), self.arraytpe.getitem_int()) @numba.typing.templates.infer class ContentType_type_getiter(numba.typing.templates.AbstractTemplate): @@ -24,7 +24,8 @@ def generic(self, args, kwargs): class IteratorModel(numba.datamodel.models.StructModel): def __init__(self, dmm, fe_type): members = [("array", fe_type.arraytpe), - ("where", numba.types.EphemeralPointer(numba.int64))] + ("length", numba.int64), + ("at", numba.types.EphemeralPointer(numba.int64))] super(IteratorModel, self).__init__(dmm, fe_type, members) @numba.extending.lower_builtin("getiter", content.ContentType) @@ -33,7 +34,8 @@ def lower_getiter(context, builder, sig, args): val, = args proxyout = context.make_helper(builder, rettpe) proxyout.array = val - proxyout.where = numba.cgutils.alloca_once_value(builder, context.get_constant(numba.int64, 0)) + proxyout.length = util.cast(context, builder, numba.intp, numba.int64, tpe.lower_len(context, builder, numba.intp(tpe), (val,))) + proxyout.at = numba.cgutils.alloca_once_value(builder, context.get_constant(numba.int64, 0)) if context.enable_nrt: context.nrt.incref(builder, tpe, val) return numba.targets.imputils.impl_ret_new_ref(context, builder, rettpe, proxyout._getvalue()) @@ -45,14 +47,12 @@ def lower_iternext(context, builder, sig, args, result): val, = args proxyin = context.make_helper(builder, tpe, value=val) - where = builder.load(proxyin.where) - length = tpe.arraytpe.lower_len(context, builder, numba.intp(tpe.arraytpe), (proxyin.array,)) - length = util.cast(context, builder, numba.intp, numba.int64, length) + at = builder.load(proxyin.at) - is_valid = builder.icmp_signed("<", where, length) + is_valid = builder.icmp_signed("<", at, proxyin.length) result.set_valid(is_valid) with builder.if_then(is_valid, likely=True): - result.yield_(tpe.arraytpe.lower_getitem_int(context, builder, tpe.yield_type(tpe.arraytpe, numba.int64), (proxyin.array, where))) - nextwhere = numba.cgutils.increment_index(builder, where) - builder.store(nextwhere, proxyin.where) + result.yield_(tpe.arraytpe.lower_getitem_int(context, builder, tpe.yield_type(tpe.arraytpe, numba.int64), (proxyin.array, at))) + nextat = numba.cgutils.increment_index(builder, at) + builder.store(nextat, proxyin.at) diff --git a/awkward1/_numba/libawkward.py b/awkward1/_numba/libawkward.py index 2d3723f35d..73b4bb0b5c 100644 --- a/awkward1/_numba/libawkward.py +++ b/awkward1/_numba/libawkward.py @@ -73,3 +73,52 @@ FillableArray_endlist.argtypes = [ctypes.c_voidp] FillableArray_endlist.restype = ctypes.c_uint8 FillableArray_endlist.numbatpe = numba.typing.ctypes_utils.make_function_type(FillableArray_endlist) + +# uint8_t awkward_FillableArray_begintuple(void* fillablearray, int64_t numfields); +FillableArray_begintuple = lib.awkward_FillableArray_begintuple +FillableArray_begintuple.name = "FillableArray.begintuple" +FillableArray_begintuple.argtypes = [ctypes.c_voidp, ctypes.c_int64] +FillableArray_begintuple.restype = ctypes.c_uint8 +FillableArray_begintuple.numbatpe = numba.typing.ctypes_utils.make_function_type(FillableArray_begintuple) + +# uint8_t awkward_FillableArray_index(void* fillablearray, int64_t index); +FillableArray_index = lib.awkward_FillableArray_index +FillableArray_index.name = "FillableArray.index" +FillableArray_index.argtypes = [ctypes.c_voidp, ctypes.c_int64] +FillableArray_index.restype = ctypes.c_uint8 +FillableArray_index.numbatpe = numba.typing.ctypes_utils.make_function_type(FillableArray_index) + +# uint8_t awkward_FillableArray_endtuple(void* fillablearray); +FillableArray_endtuple = lib.awkward_FillableArray_endtuple +FillableArray_endtuple.name = "FillableArray.endtuple" +FillableArray_endtuple.argtypes = [ctypes.c_voidp] +FillableArray_endtuple.restype = ctypes.c_uint8 +FillableArray_endtuple.numbatpe = numba.typing.ctypes_utils.make_function_type(FillableArray_endtuple) + +# uint8_t awkward_FillableArray_beginrecord(void* fillablearray, int64_t disambiguator); +FillableArray_beginrecord = lib.awkward_FillableArray_beginrecord +FillableArray_beginrecord.name = "FillableArray.beginrecord" +FillableArray_beginrecord.argtypes = [ctypes.c_voidp, ctypes.c_int64] +FillableArray_beginrecord.restype = ctypes.c_uint8 +FillableArray_beginrecord.numbatpe = numba.typing.ctypes_utils.make_function_type(FillableArray_beginrecord) + +# uint8_t awkward_FillableArray_field_fast(void* fillablearray, const char* key); +FillableArray_field_fast = lib.awkward_FillableArray_field_fast +FillableArray_field_fast.name = "FillableArray.field_fast" +FillableArray_field_fast.argtypes = [ctypes.c_voidp, ctypes.c_voidp] +FillableArray_field_fast.restype = ctypes.c_uint8 +FillableArray_field_fast.numbatpe = numba.typing.ctypes_utils.make_function_type(FillableArray_field_fast) + +# uint8_t awkward_FillableArray_field_check(void* fillablearray, const char* key); +FillableArray_field_check = lib.awkward_FillableArray_field_check +FillableArray_field_check.name = "FillableArray.field_check" +FillableArray_field_check.argtypes = [ctypes.c_voidp, ctypes.c_voidp] +FillableArray_field_check.restype = ctypes.c_uint8 +FillableArray_field_check.numbatpe = numba.typing.ctypes_utils.make_function_type(FillableArray_field_check) + +# uint8_t awkward_FillableArray_endrecord(void* fillablearray); +FillableArray_endrecord = lib.awkward_FillableArray_endrecord +FillableArray_endrecord.name = "FillableArray.endrecord" +FillableArray_endrecord.argtypes = [ctypes.c_voidp] +FillableArray_endrecord.restype = ctypes.c_uint8 +FillableArray_endrecord.numbatpe = numba.typing.ctypes_utils.make_function_type(FillableArray_endrecord) diff --git a/awkward1/_numba/util.py b/awkward1/_numba/util.py index 0a02a92853..cb26219f59 100644 --- a/awkward1/_numba/util.py +++ b/awkward1/_numba/util.py @@ -5,11 +5,31 @@ import numpy import numba import llvmlite.ir.types +import llvmlite.llvmpy.core from .._numba import cpu py27 = (sys.version_info[0] < 3) +if not py27: + exec(""" +def debug(context, builder, *args): + assert len(args) % 2 == 0 + tpes, vals = args[0::2], args[1::2] + context.get_function(print, numba.none(*tpes))(builder, tuple(vals)) +""", globals()) + +dynamic_addrs = {} +def globalstring(context, builder, pyvalue, inttype=None): + if pyvalue not in dynamic_addrs: + buf = dynamic_addrs[pyvalue] = numpy.array(pyvalue.encode("utf-8") + b"\x00") + context.add_dynamic_addr(builder, buf.ctypes.data, info="str({})".format(repr(pyvalue))) + if inttype is None: + ptr = context.get_constant(numba.types.uintp, dynamic_addrs[pyvalue].ctypes.data) + return builder.inttoptr(ptr, llvmlite.llvmpy.core.Type.pointer(llvmlite.llvmpy.core.Type.int(8))) + else: + return context.get_constant(inttype, dynamic_addrs[pyvalue].ctypes.data) + RefType = numba.int64 index8tpe = numba.types.Array(numba.int8, 1, "C") @@ -31,14 +51,6 @@ def indextpe(indexname): else: raise AssertionError("unrecognized index type: {}".format(indexname)) -if not py27: - exec(""" -def debug(context, builder, *args): - assert len(args) % 2 == 0 - tpes, vals = args[0::2], args[1::2] - context.get_function(print, numba.none(*tpes))(builder, tuple(vals)) -""", globals()) - def cast(context, builder, fromtpe, totpe, val): if isinstance(fromtpe, llvmlite.ir.types.IntType): if fromtpe.width == 8: diff --git a/awkward1/signatures/NumpyArray_8cpp.xml b/awkward1/signatures/NumpyArray_8cpp.xml index b3aa0e8a62..4c565b6504 100644 --- a/awkward1/signatures/NumpyArray_8cpp.xml +++ b/awkward1/signatures/NumpyArray_8cpp.xml @@ -9,6 +9,7 @@ awkward/cpu-kernels/getitem.h awkward/type/PrimitiveType.h awkward/type/RegularType.h + awkward/util.h awkward/array/NumpyArray.h @@ -33,6 +34,11 @@ + + + + + @@ -43,7 +49,7 @@ - + @@ -86,92 +92,7 @@ - - - - void - void awkward::tojson_boolean - (ToJson &builder, bool *array, int64_t length) - tojson_boolean - - ToJson & - builder - - - bool * - array - - - int64_t - length - - - - - - - - - - - - - typename T - - - void - void awkward::tojson_integer - (ToJson &builder, T *array, int64_t length) - tojson_integer - - ToJson & - builder - - - T * - array - - - int64_t - length - - - - - - - - - - - - - typename T - - - void - void awkward::tojson_real - (ToJson &builder, T *array, int64_t length) - tojson_real - - ToJson & - builder - - - T * - array - - - int64_t - length - - - - - - - - + const std::vector< ssize_t > @@ -188,7 +109,7 @@ - + const std::vector< ssize_t > @@ -205,7 +126,7 @@ - + diff --git a/awkward1/signatures/RecordArray_8cpp.xml b/awkward1/signatures/RecordArray_8cpp.xml new file mode 100644 index 0000000000..def527b089 --- /dev/null +++ b/awkward1/signatures/RecordArray_8cpp.xml @@ -0,0 +1,54 @@ + + + + RecordArray.cpp + sstream + awkward/cpu-kernels/identity.h + awkward/cpu-kernels/getitem.h + awkward/type/RecordType.h + awkward/array/Record.h + awkward/array/RecordArray.h + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + awkward + + + + + + + diff --git a/awkward1/signatures/Record_8cpp.xml b/awkward1/signatures/Record_8cpp.xml new file mode 100644 index 0000000000..425f902636 --- /dev/null +++ b/awkward1/signatures/Record_8cpp.xml @@ -0,0 +1,48 @@ + + + + Record.cpp + sstream + awkward/cpu-kernels/identity.h + awkward/cpu-kernels/getitem.h + awkward/type/RecordType.h + awkward/array/Record.h + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + awkward + + + + + + + diff --git a/awkward1/signatures/cpu-kernels_2util_8cpp.xml b/awkward1/signatures/cpu-kernels_2util_8cpp.xml index 24651eabda..cea9d261cc 100644 --- a/awkward1/signatures/cpu-kernels_2util_8cpp.xml +++ b/awkward1/signatures/cpu-kernels_2util_8cpp.xml @@ -32,7 +32,7 @@ - + struct Error @@ -57,7 +57,7 @@ - + diff --git a/awkward1/signatures/getitem_8cpp.xml b/awkward1/signatures/getitem_8cpp.xml index 6c286382ad..327574056e 100644 --- a/awkward1/signatures/getitem_8cpp.xml +++ b/awkward1/signatures/getitem_8cpp.xml @@ -1505,7 +1505,7 @@ - + ERROR @@ -1558,7 +1558,7 @@ - + ERROR @@ -1611,7 +1611,7 @@ - + ERROR @@ -1664,7 +1664,7 @@ - + @@ -1697,7 +1697,7 @@ - + ERROR @@ -1722,7 +1722,7 @@ - + ERROR @@ -1747,7 +1747,7 @@ - + ERROR @@ -1772,7 +1772,7 @@ - + @@ -1809,7 +1809,7 @@ - + ERROR @@ -1838,7 +1838,7 @@ - + ERROR @@ -1867,7 +1867,7 @@ - + ERROR @@ -1896,7 +1896,7 @@ - + @@ -1957,7 +1957,7 @@ - + ERROR @@ -2010,7 +2010,7 @@ - + ERROR @@ -2063,7 +2063,7 @@ - + ERROR @@ -2116,7 +2116,7 @@ - + @@ -2181,7 +2181,7 @@ - + ERROR @@ -2238,7 +2238,7 @@ - + ERROR @@ -2295,7 +2295,7 @@ - + ERROR @@ -2352,7 +2352,7 @@ - + @@ -2409,7 +2409,7 @@ - + ERROR @@ -2458,7 +2458,7 @@ - + ERROR @@ -2507,7 +2507,7 @@ - + ERROR @@ -2556,7 +2556,7 @@ - + @@ -2590,7 +2590,7 @@ - + ERROR @@ -2619,7 +2619,7 @@ - + @@ -2661,7 +2661,7 @@ - + ERROR @@ -2698,7 +2698,7 @@ - + @@ -2732,7 +2732,7 @@ - + ERROR @@ -2761,7 +2761,7 @@ - + @@ -2795,7 +2795,7 @@ - + ERROR @@ -2824,7 +2824,7 @@ - + @@ -2866,7 +2866,7 @@ - + ERROR @@ -2903,7 +2903,7 @@ - + @@ -2949,7 +2949,7 @@ - + ERROR @@ -2990,7 +2990,7 @@ - + @@ -3024,7 +3024,7 @@ - + ERROR @@ -3053,7 +3053,7 @@ - + diff --git a/awkward1/signatures/identity_8cpp.xml b/awkward1/signatures/identity_8cpp.xml index 147c54ec7c..d7567ec410 100644 --- a/awkward1/signatures/identity_8cpp.xml +++ b/awkward1/signatures/identity_8cpp.xml @@ -112,6 +112,239 @@ + + + + typename ID + + + typename T + + + ERROR + ERROR awkward_identity_from_listoffsetarray + (ID *toptr, const ID *fromptr, const T *fromoffsets, int64_t fromptroffset, int64_t offsetsoffset, int64_t tolength, int64_t fromlength, int64_t fromwidth) + awkward_identity_from_listoffsetarray + + ID * + toptr + + + const ID * + fromptr + + + const T * + fromoffsets + + + int64_t + fromptroffset + + + int64_t + offsetsoffset + + + int64_t + tolength + + + int64_t + fromlength + + + int64_t + fromwidth + + + + + + + + + + + ERROR + ERROR awkward_identity32_from_listoffsetarray32 + (int32_t *toptr, const int32_t *fromptr, const int32_t *fromoffsets, int64_t fromptroffset, int64_t offsetsoffset, int64_t tolength, int64_t fromlength, int64_t fromwidth) + awkward_identity32_from_listoffsetarray32 + + int32_t * + toptr + + + const int32_t * + fromptr + + + const int32_t * + fromoffsets + + + int64_t + fromptroffset + + + int64_t + offsetsoffset + + + int64_t + tolength + + + int64_t + fromlength + + + int64_t + fromwidth + + + + + + + + + + + ERROR + ERROR awkward_identity64_from_listoffsetarray32 + (int64_t *toptr, const int64_t *fromptr, const int32_t *fromoffsets, int64_t fromptroffset, int64_t offsetsoffset, int64_t tolength, int64_t fromlength, int64_t fromwidth) + awkward_identity64_from_listoffsetarray32 + + int64_t * + toptr + + + const int64_t * + fromptr + + + const int32_t * + fromoffsets + + + int64_t + fromptroffset + + + int64_t + offsetsoffset + + + int64_t + tolength + + + int64_t + fromlength + + + int64_t + fromwidth + + + + + + + + + + + ERROR + ERROR awkward_identity64_from_listoffsetarrayU32 + (int64_t *toptr, const int64_t *fromptr, const uint32_t *fromoffsets, int64_t fromptroffset, int64_t offsetsoffset, int64_t tolength, int64_t fromlength, int64_t fromwidth) + awkward_identity64_from_listoffsetarrayU32 + + int64_t * + toptr + + + const int64_t * + fromptr + + + const uint32_t * + fromoffsets + + + int64_t + fromptroffset + + + int64_t + offsetsoffset + + + int64_t + tolength + + + int64_t + fromlength + + + int64_t + fromwidth + + + + + + + + + + + ERROR + ERROR awkward_identity64_from_listoffsetarray64 + (int64_t *toptr, const int64_t *fromptr, const int64_t *fromoffsets, int64_t fromptroffset, int64_t offsetsoffset, int64_t tolength, int64_t fromlength, int64_t fromwidth) + awkward_identity64_from_listoffsetarray64 + + int64_t * + toptr + + + const int64_t * + fromptr + + + const int64_t * + fromoffsets + + + int64_t + fromptroffset + + + int64_t + offsetsoffset + + + int64_t + tolength + + + int64_t + fromlength + + + int64_t + fromwidth + + + + + + + + + @@ -171,7 +404,7 @@ - + ERROR @@ -224,7 +457,7 @@ - + ERROR @@ -277,7 +510,7 @@ - + ERROR @@ -330,7 +563,7 @@ - + ERROR @@ -383,7 +616,7 @@ - + @@ -429,7 +662,7 @@ - + ERROR @@ -470,7 +703,7 @@ - + ERROR @@ -511,7 +744,7 @@ - + diff --git a/awkward1/signatures/libawkward_2util_8cpp.xml b/awkward1/signatures/libawkward_2util_8cpp.xml index 2bc02bd37b..a3616dab9c 100644 --- a/awkward1/signatures/libawkward_2util_8cpp.xml +++ b/awkward1/signatures/libawkward_2util_8cpp.xml @@ -72,6 +72,121 @@ + + std::string + std::string awkward::util::quote + (std::string x, bool doublequote) + quote + + std::string + x + + + bool + doublequote + + + + + + + + + + + + + Error + Error awkward::util::awkward_identity64_from_listoffsetarray< uint32_t > + (int64_t *toptr, const int64_t *fromptr, const uint32_t *fromoffsets, int64_t fromptroffset, int64_t offsetsoffset, int64_t tolength, int64_t fromlength, int64_t fromwidth) + awkward_identity64_from_listoffsetarray< uint32_t > + + int64_t * + toptr + + + const int64_t * + fromptr + + + const uint32_t * + fromoffsets + + + int64_t + fromptroffset + + + int64_t + offsetsoffset + + + int64_t + tolength + + + int64_t + fromlength + + + int64_t + fromwidth + + + + + + + + + + + + + Error + Error awkward::util::awkward_identity64_from_listoffsetarray< int64_t > + (int64_t *toptr, const int64_t *fromptr, const int64_t *fromoffsets, int64_t fromptroffset, int64_t offsetsoffset, int64_t tolength, int64_t fromlength, int64_t fromwidth) + awkward_identity64_from_listoffsetarray< int64_t > + + int64_t * + toptr + + + const int64_t * + fromptr + + + const int64_t * + fromoffsets + + + int64_t + fromptroffset + + + int64_t + offsetsoffset + + + int64_t + tolength + + + int64_t + fromlength + + + int64_t + fromwidth + + + + + + + + + @@ -125,7 +240,7 @@ - + @@ -180,7 +295,7 @@ - + @@ -223,7 +338,7 @@ - + @@ -266,7 +381,7 @@ - + @@ -309,7 +424,7 @@ - + @@ -360,7 +475,7 @@ - + @@ -411,7 +526,7 @@ - + @@ -462,7 +577,7 @@ - + @@ -517,7 +632,7 @@ - + @@ -572,7 +687,7 @@ - + @@ -627,7 +742,7 @@ - + @@ -654,7 +769,7 @@ - + @@ -681,7 +796,7 @@ - + @@ -708,7 +823,7 @@ - + @@ -739,7 +854,7 @@ - + @@ -770,7 +885,7 @@ - + @@ -801,7 +916,7 @@ - + @@ -856,7 +971,7 @@ - + @@ -911,7 +1026,7 @@ - + @@ -966,7 +1081,7 @@ - + @@ -1025,7 +1140,7 @@ - + @@ -1084,7 +1199,7 @@ - + @@ -1143,7 +1258,7 @@ - + @@ -1194,7 +1309,7 @@ - + @@ -1245,7 +1360,7 @@ - + @@ -1296,7 +1411,7 @@ - + diff --git a/awkward1/util.py b/awkward1/util.py index 76ca65bb4e..7664ee93b0 100644 --- a/awkward1/util.py +++ b/awkward1/util.py @@ -1 +1,24 @@ # BSD 3-Clause License; see https://github.com/jpivarski/awkward-1.0/blob/master/LICENSE + +import numbers +import re + +import numpy + +def field2index(lookup, numfields, key): + if isinstance(key, (int, numbers.Integral, numpy.integer)): + attempt = key + else: + attempt = None if lookup is None else lookup.get(key) + + if attempt is None: + m = field2index._pattern.match(key) + if m is not None: + attempt = m.group(0) + + if attempt is None or attempt >= numfields: + raise ValueError("key {0} not found in Record".format(repr(key))) + else: + return attempt + +field2index._pattern = re.compile(r"^[1-9][0-9]*$") diff --git a/include/awkward/array/Record.h b/include/awkward/array/Record.h index fc8020c6f6..360427a911 100644 --- a/include/awkward/array/Record.h +++ b/include/awkward/array/Record.h @@ -8,13 +8,22 @@ namespace awkward { class Record: public Content { public: - Record(const RecordArray& recordarray, int64_t at) - : recordarray_(recordarray) + Record(const RecordArray& array, int64_t at) + : array_(array) , at_(at) { } - const std::shared_ptr recordarray() const { return recordarray_.shallow_copy(); } + const std::shared_ptr array() const { return array_.shallow_copy(); } int64_t at() const { return at_; } - bool istuple() const { return recordarray_.istuple(); } + const std::vector> contents() const { + std::vector> out; + for (auto item : array_.contents()) { + out.push_back(item.get()->getitem_at_nowrap(at_)); + } + return out; + } + const std::shared_ptr lookup() const { return array_.lookup(); } + const std::shared_ptr reverselookup() const { return array_.reverselookup(); } + bool istuple() const { return lookup().get() == nullptr; } virtual bool isscalar() const; virtual const std::string classname() const; @@ -48,7 +57,7 @@ namespace awkward { const std::vector keys() const; const std::vector> values() const; const std::vector>> items() const; - const Record withoutkeys() const; + const Record astuple() const; protected: virtual const std::shared_ptr getitem_next(const SliceAt& at, const Slice& tail, const Index64& advanced) const; @@ -58,7 +67,7 @@ namespace awkward { virtual const std::shared_ptr getitem_next(const SliceFields& fields, const Slice& tail, const Index64& advanced) const; private: - const RecordArray recordarray_; + const RecordArray array_; int64_t at_; }; } diff --git a/include/awkward/array/RecordArray.h b/include/awkward/array/RecordArray.h index 858b67f62f..011859299a 100644 --- a/include/awkward/array/RecordArray.h +++ b/include/awkward/array/RecordArray.h @@ -77,7 +77,7 @@ namespace awkward { const std::vector keys() const; const std::vector> values() const; const std::vector>> items() const; - const RecordArray withoutkeys() const; + const RecordArray astuple() const; void append(const std::shared_ptr& content, const std::string& key); void append(const std::shared_ptr& content); diff --git a/include/awkward/fillable/FillableArray.h b/include/awkward/fillable/FillableArray.h index 7b0e0cb1f8..060b49230a 100644 --- a/include/awkward/fillable/FillableArray.h +++ b/include/awkward/fillable/FillableArray.h @@ -26,7 +26,6 @@ namespace awkward { const std::shared_ptr getitem_fields(const std::vector& keys) const; const std::shared_ptr getitem(const Slice& where) const; - bool active() const; void null(); void boolean(bool x); void integer(int64_t x); @@ -70,6 +69,13 @@ extern "C" { uint8_t awkward_FillableArray_real(void* fillablearray, double x); uint8_t awkward_FillableArray_beginlist(void* fillablearray); uint8_t awkward_FillableArray_endlist(void* fillablearray); + uint8_t awkward_FillableArray_begintuple(void* fillablearray, int64_t numfields); + uint8_t awkward_FillableArray_index(void* fillablearray, int64_t index); + uint8_t awkward_FillableArray_endtuple(void* fillablearray); + uint8_t awkward_FillableArray_beginrecord(void* fillablearray, int64_t disambiguator); + uint8_t awkward_FillableArray_field_fast(void* fillablearray, const char* key); + uint8_t awkward_FillableArray_field_check(void* fillablearray, const char* key); + uint8_t awkward_FillableArray_endrecord(void* fillablearray); } #endif // AWKWARD_FILLABLE_H_ diff --git a/src/libawkward/array/NumpyArray.cpp b/src/libawkward/array/NumpyArray.cpp index 55949e9700..d013da16db 100644 --- a/src/libawkward/array/NumpyArray.cpp +++ b/src/libawkward/array/NumpyArray.cpp @@ -87,7 +87,12 @@ namespace awkward { if (i != 0) { out << " "; } - out << ptr[i]; + if (std::is_same::value) { + out << (ptr[i] ? "true" : "false"); + } + else { + out << ptr[i]; + } } } else { @@ -95,14 +100,24 @@ namespace awkward { if (i != 0) { out << " "; } - out << ptr[i]; + if (std::is_same::value) { + out << (ptr[i] ? "true" : "false"); + } + else { + out << ptr[i]; + } } out << " ... "; for (int64_t i = length - 5; i < length; i++) { if (i != length - 5) { out << " "; } - out << ptr[i]; + if (std::is_same::value) { + out << (ptr[i] ? "true" : "false"); + } + else { + out << ptr[i]; + } } } } @@ -149,6 +164,9 @@ namespace awkward { else if (ndim() == 1 && format_.compare("d") == 0) { tostring_as(out, reinterpret_cast(byteptr()), length()); } + else if (ndim() == 1 && format_.compare("?") == 0) { + tostring_as(out, reinterpret_cast(byteptr()), length()); + } else { ssize_t len = bytelength(); if (len <= 32) { diff --git a/src/libawkward/array/Record.cpp b/src/libawkward/array/Record.cpp index 9f1a4dc501..f727fa94f4 100644 --- a/src/libawkward/array/Record.cpp +++ b/src/libawkward/array/Record.cpp @@ -18,7 +18,7 @@ namespace awkward { } const std::shared_ptr Record::id() const { - std::shared_ptr recid = recordarray_.id(); + std::shared_ptr recid = array_.id(); if (recid.get() == nullptr) { return recid; } @@ -38,21 +38,21 @@ namespace awkward { const std::string Record::tostring_part(const std::string indent, const std::string pre, const std::string post) const { std::stringstream out; out << indent << pre << "<" << classname() << " at=\"" << at_ << "\">\n"; - out << recordarray_.tostring_part(indent + std::string(" "), "", "\n"); + out << array_.tostring_part(indent + std::string(" "), "", "\n"); out << indent << "" << post; return out.str(); } void Record::tojson_part(ToJson& builder) const { size_t cols = (size_t)numfields(); - std::shared_ptr keys = recordarray_.reverselookup(); + std::shared_ptr keys = array_.reverselookup(); if (istuple()) { keys = std::shared_ptr(new RecordArray::ReverseLookup); for (size_t j = 0; j < cols; j++) { keys.get()->push_back(std::to_string(j)); } } - std::vector> contents = recordarray_.contents(); + std::vector> contents = array_.contents(); builder.beginrec(); for (size_t j = 0; j < cols; j++) { builder.fieldkey(keys.get()->at(j).c_str()); @@ -62,7 +62,7 @@ namespace awkward { } const std::shared_ptr Record::type_part() const { - return recordarray_.type_part(); + return array_.type_part(); } int64_t Record::length() const { @@ -70,12 +70,12 @@ namespace awkward { } const std::shared_ptr Record::shallow_copy() const { - return std::shared_ptr(new Record(recordarray_, at_)); + return std::shared_ptr(new Record(array_, at_)); } void Record::check_for_iteration() const { - if (recordarray_.id().get() != nullptr && recordarray_.id().get()->length() != 1) { - util::handle_error(failure("len(id) != 1 for scalar Record", kSliceNone, kSliceNone), recordarray_.id().get()->classname(), nullptr); + if (array_.id().get() != nullptr && array_.id().get()->length() != 1) { + util::handle_error(failure("len(id) != 1 for scalar Record", kSliceNone, kSliceNone), array_.id().get()->classname(), nullptr); } } @@ -100,19 +100,19 @@ namespace awkward { } const std::shared_ptr Record::getitem_field(const std::string& key) const { - return recordarray_.field(key).get()->getitem_at_nowrap(at_); + return array_.field(key).get()->getitem_at_nowrap(at_); } const std::shared_ptr Record::getitem_fields(const std::vector& keys) const { - RecordArray out(recordarray_.id(), length(), istuple()); + RecordArray out(array_.id(), length(), istuple()); if (istuple()) { for (auto key : keys) { - out.append(recordarray_.field(key)); + out.append(array_.field(key)); } } else { for (auto key : keys) { - out.append(recordarray_.field(key), key); + out.append(array_.field(key), key); } } return out.getitem_at_nowrap(at_); @@ -123,74 +123,74 @@ namespace awkward { } const std::pair Record::minmax_depth() const { - return recordarray_.minmax_depth(); + return array_.minmax_depth(); } int64_t Record::numfields() const { - return recordarray_.numfields(); + return array_.numfields(); } int64_t Record::index(const std::string& key) const { - return recordarray_.index(key); + return array_.index(key); } const std::string Record::key(int64_t index) const { - return recordarray_.key(index); + return array_.key(index); } bool Record::has(const std::string& key) const { - return recordarray_.has(key); + return array_.has(key); } const std::vector Record::aliases(int64_t index) const { - return recordarray_.aliases(index); + return array_.aliases(index); } const std::vector Record::aliases(const std::string& key) const { - return recordarray_.aliases(key); + return array_.aliases(key); } const std::shared_ptr Record::field(int64_t index) const { - return recordarray_.field(index).get()->getitem_at_nowrap(at_); + return array_.field(index).get()->getitem_at_nowrap(at_); } const std::shared_ptr Record::field(const std::string& key) const { - return recordarray_.field(key).get()->getitem_at_nowrap(at_); + return array_.field(key).get()->getitem_at_nowrap(at_); } const std::vector Record::keys() const { - return recordarray_.keys(); + return array_.keys(); } const std::vector> Record::values() const { std::vector> out; int64_t cols = numfields(); for (int64_t j = 0; j < cols; j++) { - out.push_back(recordarray_.field(j).get()->getitem_at_nowrap(at_)); + out.push_back(array_.field(j).get()->getitem_at_nowrap(at_)); } return out; } const std::vector>> Record::items() const { std::vector>> out; - std::shared_ptr keys = recordarray_.reverselookup(); + std::shared_ptr keys = array_.reverselookup(); if (istuple()) { int64_t cols = numfields(); for (int64_t j = 0; j < cols; j++) { - out.push_back(std::pair>(std::to_string(j), recordarray_.field(j).get()->getitem_at_nowrap(at_))); + out.push_back(std::pair>(std::to_string(j), array_.field(j).get()->getitem_at_nowrap(at_))); } } else { int64_t cols = numfields(); for (int64_t j = 0; j < cols; j++) { - out.push_back(std::pair>(keys.get()->at((size_t)j), recordarray_.field(j).get()->getitem_at_nowrap(at_))); + out.push_back(std::pair>(keys.get()->at((size_t)j), array_.field(j).get()->getitem_at_nowrap(at_))); } } return out; } - const Record Record::withoutkeys() const { - return Record(recordarray_.withoutkeys(), at_); + const Record Record::astuple() const { + return Record(array_.astuple(), at_); } const std::shared_ptr Record::getitem_next(const SliceAt& at, const Slice& tail, const Index64& advanced) const { diff --git a/src/libawkward/array/RecordArray.cpp b/src/libawkward/array/RecordArray.cpp index 04ab51e871..2e0cb3e91e 100644 --- a/src/libawkward/array/RecordArray.cpp +++ b/src/libawkward/array/RecordArray.cpp @@ -381,7 +381,7 @@ namespace awkward { return out; } - const RecordArray RecordArray::withoutkeys() const { + const RecordArray RecordArray::astuple() const { return RecordArray(id_, contents_); } diff --git a/src/libawkward/fillable/FillableArray.cpp b/src/libawkward/fillable/FillableArray.cpp index 5551a3c492..c91eebeade 100644 --- a/src/libawkward/fillable/FillableArray.cpp +++ b/src/libawkward/fillable/FillableArray.cpp @@ -49,10 +49,6 @@ namespace awkward { return snapshot().get()->getitem(where); } - bool FillableArray::active() const { - return fillable_.get()->active(); - } - void FillableArray::null() { maybeupdate(fillable_.get()->null()); } @@ -207,3 +203,80 @@ uint8_t awkward_FillableArray_endlist(void* fillablearray) { } return 0; } + +uint8_t awkward_FillableArray_begintuple(void* fillablearray, int64_t numfields) { + awkward::FillableArray* obj = reinterpret_cast(fillablearray); + try { + obj->begintuple(numfields); + } + catch (...) { + return 1; + } + return 0; +} + +uint8_t awkward_FillableArray_index(void* fillablearray, int64_t index) { + awkward::FillableArray* obj = reinterpret_cast(fillablearray); + try { + obj->index(index); + } + catch (...) { + return 1; + } + return 0; +} + +uint8_t awkward_FillableArray_endtuple(void* fillablearray) { + awkward::FillableArray* obj = reinterpret_cast(fillablearray); + try { + obj->endtuple(); + } + catch (...) { + return 1; + } + return 0; +} + +uint8_t awkward_FillableArray_beginrecord(void* fillablearray, int64_t disambiguator) { + awkward::FillableArray* obj = reinterpret_cast(fillablearray); + try { + obj->beginrecord(disambiguator); + } + catch (...) { + return 1; + } + return 0; +} + +uint8_t awkward_FillableArray_field_fast(void* fillablearray, const char* key) { + awkward::FillableArray* obj = reinterpret_cast(fillablearray); + try { + obj->field_fast(key); + } + catch (...) { + return 1; + } + return 0; +} + +uint8_t awkward_FillableArray_field_check(void* fillablearray, const char* key) { + awkward::FillableArray* obj = reinterpret_cast(fillablearray); + try { + obj->field_check(key); + } + catch (...) { + return 1; + } + return 0; +} + +uint8_t awkward_FillableArray_endrecord(void* fillablearray) { + awkward::FillableArray* obj = reinterpret_cast(fillablearray); + try { + obj->endrecord(); + } + catch (...) { + return 1; + } + return 0; +} diff --git a/src/pyawkward.cpp b/src/pyawkward.cpp index 4caf4592ce..e0cb26f8ae 100644 --- a/src/pyawkward.cpp +++ b/src/pyawkward.cpp @@ -1090,6 +1090,40 @@ py::class_ make_RegularArray(py::handle m, std::s /////////////////////////////////////////////////////////////// RecordArray +template +py::object lookup(const T& self) { + std::shared_ptr lookup = self.lookup(); + if (lookup.get() == nullptr) { + return py::none(); + } + else { + py::dict out; + for (auto pair : *lookup.get()) { + std::string cppkey = pair.first; + py::str pykey(PyUnicode_DecodeUTF8(cppkey.data(), cppkey.length(), "surrogateescape")); + out[pykey] = py::cast(pair.second); + } + return out; + } +} + +template +py::object reverselookup(const T& self) { + std::shared_ptr reverselookup = self.reverselookup(); + if (reverselookup.get() == nullptr) { + return py::none(); + } + else { + py::list out; + for (auto item : *reverselookup.get()) { + std::string cppkey = item; + py::str pykey(PyUnicode_DecodeUTF8(cppkey.data(), cppkey.length(), "surrogateescape")); + out.append(pykey); + } + return out; + } +} + py::class_ make_RecordArray(py::handle m, std::string name) { return content(py::class_(m, name.c_str()) .def(py::init([](py::dict contents, py::object id) -> ak::RecordArray { @@ -1102,6 +1136,9 @@ py::class_ make_RecordArray(py::handle m, std::str reverselookup.get()->push_back(key); out.push_back(unbox_content(x.second)); } + if (out.size() == 0) { + throw std::invalid_argument("construct RecordArrays without fields using RecordArray(length) where length is an integer"); + } return ak::RecordArray(unbox_id(id), out, lookup, reverselookup); }), py::arg("contents"), py::arg("id") = py::none()) .def(py::init([](py::iterable contents, py::object id) -> ak::RecordArray { @@ -1109,6 +1146,9 @@ py::class_ make_RecordArray(py::handle m, std::str for (auto x : contents) { out.push_back(unbox_content(x)); } + if (out.size() == 0) { + throw std::invalid_argument("construct RecordArrays without fields using RecordArray(length) where length is an integer"); + } return ak::RecordArray(unbox_id(id), out, std::shared_ptr(nullptr), std::shared_ptr(nullptr)); }), py::arg("contents"), py::arg("id") = py::none()) .def(py::init([](int64_t length, bool istuple, py::object id) -> ak::RecordArray { @@ -1152,8 +1192,10 @@ py::class_ make_RecordArray(py::handle m, std::str } return out; }) - .def_property_readonly("withoutkeys", [](ak::RecordArray& self) -> py::object { - return box(self.withoutkeys().shallow_copy()); + .def_property_readonly("lookup", &lookup) + .def_property_readonly("reverselookup", &reverselookup) + .def_property_readonly("astuple", [](ak::RecordArray& self) -> py::object { + return box(self.astuple().shallow_copy()); }) .def("append", [](ak::RecordArray& self, py::object content, py::object key) -> void { @@ -1171,6 +1213,7 @@ py::class_ make_RecordArray(py::handle m, std::str py::class_ make_Record(py::handle m, std::string name) { return py::class_(m, name.c_str()) + .def(py::init()) .def("__repr__", &repr) .def_property_readonly("id", [](ak::Record& self) -> py::object { return box(self.id()); }) .def("__getitem__", &getitem) @@ -1178,6 +1221,8 @@ py::class_ make_Record(py::handle m, std::string name) { .def("tojson", &tojson_file, py::arg("destination"), py::arg("pretty") = false, py::arg("maxdecimals") = py::none(), py::arg("buffersize") = 65536) .def_property_readonly("type", &ak::Content::type) + .def_property_readonly("array", [](ak::Record& self) -> py::object { return box(self.array()); }) + .def_property_readonly("at", &ak::Record::at) .def_property_readonly("istuple", &ak::Record::istuple) .def_property_readonly("numfields", &ak::Record::numfields) .def("index", &ak::Record::index) @@ -1215,8 +1260,10 @@ py::class_ make_Record(py::handle m, std::string name) { } return out; }) - .def_property_readonly("withoutkeys", [](ak::RecordArray& self) -> py::object { - return box(self.withoutkeys().shallow_copy()); + .def_property_readonly("lookup", &lookup) + .def_property_readonly("reverselookup", &reverselookup) + .def_property_readonly("astuple", [](ak::Record& self) -> py::object { + return box(self.astuple().shallow_copy()); }) .def_property_readonly("location", &location) diff --git a/tests/test_PR025_record_array.py b/tests/test_PR025_record_array.py index 0bbc7b9b04..1d5d88c45b 100644 --- a/tests/test_PR025_record_array.py +++ b/tests/test_PR025_record_array.py @@ -61,7 +61,10 @@ def test_basic(): assert awkward1.tolist(pairs[1][1]) == [[1.1, 2.2, 3.3], [], [4.4, 5.5], [6.6], [7.7, 8.8, 9.9]] assert awkward1.tolist(pairs[2][1]) == [1.1, 2.2, 3.3, 4.4, 5.5, 6.6, 7.7, 8.8, 9.9] - assert awkward1.tojson(recordarray.withoutkeys) == '[{"0":1,"1":[1.1,2.2,3.3],"2":1.1},{"0":2,"1":[],"2":2.2},{"0":3,"1":[4.4,5.5],"2":3.3},{"0":4,"1":[6.6],"2":4.4},{"0":5,"1":[7.7,8.8,9.9],"2":5.5}]' + assert awkward1.tojson(recordarray.astuple) == '[{"0":1,"1":[1.1,2.2,3.3],"2":1.1},{"0":2,"1":[],"2":2.2},{"0":3,"1":[4.4,5.5],"2":3.3},{"0":4,"1":[6.6],"2":4.4},{"0":5,"1":[7.7,8.8,9.9],"2":5.5}]' + + assert recordarray.lookup == {"one": 0, "two": 1, "wonky": 0} + assert recordarray.astuple.lookup is None def test_scalar_record(): content1 = awkward1.layout.NumpyArray(numpy.array([1, 2, 3, 4, 5])) @@ -85,6 +88,8 @@ def test_scalar_record(): assert awkward1.tolist(pairs[1][1]) == [4.4, 5.5] assert awkward1.tolist(recordarray[2]) == {"one": 3, "two": [4.4, 5.5]} + assert awkward1.tolist(awkward1.layout.Record(recordarray, 2)) == {"one": 3, "two": [4.4, 5.5]} + def test_type(): content1 = awkward1.layout.NumpyArray(numpy.array([1, 2, 3, 4, 5], dtype=numpy.int64)) content2 = awkward1.layout.NumpyArray(numpy.array([1.1, 2.2, 3.3, 4.4, 5.5, 6.6, 7.7, 8.8, 9.9], dtype=numpy.float64)) @@ -94,6 +99,7 @@ def test_type(): recordarray.append(content1) recordarray.append(listoffsetarray) assert str(awkward1.typeof(recordarray)) == '5 * (int64, var * float64)' + assert recordarray.lookup is None assert awkward1.typeof(recordarray) == awkward1.layout.ArrayType(awkward1.layout.RecordType( awkward1.layout.PrimitiveType("int64"), @@ -111,6 +117,7 @@ def test_type(): recordarray.setkey(0, "one") recordarray.setkey(1, "two") assert str(awkward1.typeof(recordarray)) in ('5 * {"one": int64, "two": var * float64}', '5 * {"two": var * float64, "one": int64}') + assert recordarray.lookup == {"one": 0, "two": 1} assert str(awkward1.layout.RecordType( awkward1.layout.PrimitiveType("int32"), diff --git a/tests/test_PR026_recordarray_in_numba.py b/tests/test_PR026_recordarray_in_numba.py new file mode 100644 index 0000000000..42f1b1cedd --- /dev/null +++ b/tests/test_PR026_recordarray_in_numba.py @@ -0,0 +1,225 @@ +# BSD 3-Clause License; see https://github.com/jpivarski/awkward-1.0/blob/master/LICENSE + +import sys +import itertools + +import pytest +import numpy + +numba = pytest.importorskip("numba") + +import awkward1 + +content1 = awkward1.layout.NumpyArray(numpy.array([1, 2, 3, 4, 5])) +content2 = awkward1.layout.NumpyArray(numpy.array([1.1, 2.2, 3.3, 4.4, 5.5, 6.6, 7.7, 8.8, 9.9])) +offsets = awkward1.layout.Index64(numpy.array([0, 3, 3, 5, 6, 9])) +listoffsetarray = awkward1.layout.ListOffsetArray64(offsets, content2) +recordarray = awkward1.layout.RecordArray({"one": content1, "two": listoffsetarray}) + +def test_boxing(): + @numba.njit + def f1(q): + return 3.14 + + assert f1(recordarray) == 3.14 + assert f1(recordarray[2]) == 3.14 + + @numba.njit + def f2(q): + return q + + assert awkward1.tolist(f2(recordarray)) == [{"one": 1, "two": [1.1, 2.2, 3.3]}, {"one": 2, "two": []}, {"one": 3, "two": [4.4, 5.5]}, {"one": 4, "two": [6.6]}, {"one": 5, "two": [7.7, 8.8, 9.9]}] + assert awkward1.tolist(f2(recordarray[2])) == {"one": 3, "two": [4.4, 5.5]} + +def test_len(): + @numba.njit + def f1(q): + return len(q) + + assert f1(recordarray) == 5 + with pytest.raises(numba.TypingError): + f1(recordarray[2]) + +def test_getitem_int(): + @numba.njit + def f1(q): + return q[2] + + assert awkward1.tolist(f1(recordarray)) == {"one": 3, "two": [4.4, 5.5]} + with pytest.raises(numba.TypingError): + f1(recordarray[2]) + +def test_getitem_iter(): + @numba.njit + def f1(q): + out = 0 + for x in q: + out += 1 + return out + + assert f1(recordarray) == 5 + with pytest.raises(numba.TypingError): + f1(recordarray[2]) + +def test_getitem_range(): + @numba.njit + def f1(q): + return q[1:4] + + assert awkward1.tolist(f1(recordarray)) == [{"one": 2, "two": []}, {"one": 3, "two": [4.4, 5.5]}, {"one": 4, "two": [6.6]}] + with pytest.raises(numba.TypingError): + f1(recordarray[2]) + +def test_getitem_str(): + outer_starts = numpy.array([0, 3, 3], dtype=numpy.int64) + outer_stops = numpy.array([3, 3, 5], dtype=numpy.int64) + outer_offsets = numpy.array([0, 3, 3, 5], dtype=numpy.int64) + outer_listarray = awkward1.layout.ListArray64(awkward1.layout.Index64(outer_starts), awkward1.layout.Index64(outer_stops), recordarray) + outer_listoffsetarray = awkward1.layout.ListOffsetArray64(awkward1.layout.Index64(outer_offsets), recordarray) + outer_regulararray = awkward1.layout.RegularArray(recordarray, 2) + + @numba.njit + def f1(q): + return q["one"] + + assert awkward1.tolist(f1(recordarray)) == [1, 2, 3, 4, 5] + + assert sys.getrefcount(outer_starts), sys.getrefcount(outer_stops) == (3, 3) + assert awkward1.tolist(f1(outer_listarray)) == [[1, 2, 3], [], [4, 5]] + assert sys.getrefcount(outer_starts), sys.getrefcount(outer_stops) == (3, 3) + + assert sys.getrefcount(outer_offsets) == 3 + assert awkward1.tolist(f1(outer_listoffsetarray)) == [[1, 2, 3], [], [4, 5]] + assert sys.getrefcount(outer_offsets) == 3 + + assert awkward1.tolist(f1(outer_regulararray)) == [[1, 2], [3, 4]] + + @numba.njit + def f2(q): + return q["two"] + + assert awkward1.tolist(f1(recordarray[2])) == 3 + assert awkward1.tolist(f2(recordarray[2])) == [4.4, 5.5] + +def test_getitem_tuple(): + content3 = awkward1.layout.NumpyArray(numpy.array([1, 2, 3, 4, 5, 6, 7, 8, 9, 10])) + regulararray = awkward1.layout.RegularArray(content3, 2) + offsets2 = awkward1.layout.Index64(numpy.array([0, 3, 4, 5, 8, 9])) + listoffsetarray2 = awkward1.layout.ListOffsetArray64(offsets2, content2) + recordarray2 = awkward1.layout.RecordArray({"one": regulararray, "two": listoffsetarray2}) + + assert awkward1.tolist(recordarray2) == [{"one": [1, 2], "two": [1.1, 2.2, 3.3]}, {"one": [3, 4], "two": [4.4]}, {"one": [5, 6], "two": [5.5]}, {"one": [7, 8], "two": [6.6, 7.7, 8.8]}, {"one": [9, 10], "two": [9.9]}] + + @numba.njit + def f1(q): + return q[:, -1] + + assert awkward1.tolist(f1(recordarray2)) == [{"one": 2, "two": 3.3}, {"one": 4, "two": 4.4}, {"one": 6, "two": 5.5}, {"one": 8, "two": 8.8}, {"one": 10, "two": 9.9}] + + @numba.njit + def f2(q): + return q[:, -2:] + + assert awkward1.tolist(f2(recordarray2)) == [{"one": [1, 2], "two": [2.2, 3.3]}, {"one": [3, 4], "two": [4.4]}, {"one": [5, 6], "two": [5.5]}, {"one": [7, 8], "two": [7.7, 8.8]}, {"one": [9, 10], "two": [9.9]}] + + @numba.njit + def f3(q): + return q[2:, "two"] + + assert awkward1.tolist(f3(recordarray2)) == [[5.5], [6.6, 7.7, 8.8], [9.9]] + + @numba.njit + def f4(q): + return q["two", 1:-1] + + assert awkward1.tolist(f4(recordarray2)) == [[4.4], [5.5], [6.6, 7.7, 8.8]] + +def test_fillablearray_tuple(): + fillablearray = awkward1.layout.FillableArray() + + @numba.njit + def f1(q): + q.begintuple(3) + q.index(0); q.boolean(True) + q.index(1); q.integer(1) + q.index(2); q.real(1.1) + q.endtuple() + + q.begintuple(3) + q.index(0); q.boolean(False) + q.index(1); q.integer(1) + q.index(2); q.real(1.1) + q.endtuple() + + q.begintuple(3) + q.index(0); q.boolean(True) + q.index(1); q.integer(1) + q.index(2); q.real(1.1) + q.endtuple() + + return q + + fillablearray2 = f1(fillablearray) + + assert awkward1.tolist(fillablearray.snapshot()) == [(True, 1, 1.1), (False, 1, 1.1), (True, 1, 1.1)] + assert awkward1.tolist(fillablearray2.snapshot()) == [(True, 1, 1.1), (False, 1, 1.1), (True, 1, 1.1)] + +def test_fillablearray_record_1(): + fillablearray = awkward1.layout.FillableArray() + + @numba.njit + def f1(q): + q.beginrecord() + q.field("one"); q.boolean(True) + q.field("two"); q.integer(1) + q.field("three"); q.real(1.1) + q.endrecord() + + q.beginrecord() + q.field("one"); q.boolean(False) + q.field("two"); q.integer(2) + q.field("three"); q.real(2.2) + q.endrecord() + + q.beginrecord() + q.field("one"); q.boolean(True) + q.field("two"); q.integer(3) + q.field("three"); q.real(3.3) + q.endrecord() + + return q + + fillablearray2 = f1(fillablearray) + + assert awkward1.tolist(fillablearray.snapshot()) == [{'one': True, 'two': 1, 'three': 1.1}, {'one': False, 'two': 2, 'three': 2.2}, {'one': True, 'two': 3, 'three': 3.3}] + assert awkward1.tolist(fillablearray2.snapshot()) == [{'one': True, 'two': 1, 'three': 1.1}, {'one': False, 'two': 2, 'three': 2.2}, {'one': True, 'two': 3, 'three': 3.3}] + +def test_fillablearray_record_2(): + fillablearray = awkward1.layout.FillableArray() + + @numba.njit + def f1(q): + q.beginrecord("wowie") + q.field("one"); q.boolean(True) + q.field("two"); q.integer(1) + q.field("three"); q.real(1.1) + q.endrecord() + + q.beginrecord("wowie") + q.field("one"); q.boolean(False) + q.field("two"); q.integer(2) + q.field("three"); q.real(2.2) + q.endrecord() + + q.beginrecord("wowie") + q.field("one"); q.boolean(True) + q.field("two"); q.integer(3) + q.field("three"); q.real(3.3) + q.endrecord() + + return q + + fillablearray2 = f1(fillablearray) + + assert awkward1.tolist(fillablearray.snapshot()) == [{'one': True, 'two': 1, 'three': 1.1}, {'one': False, 'two': 2, 'three': 2.2}, {'one': True, 'two': 3, 'three': 3.3}] + assert awkward1.tolist(fillablearray2.snapshot()) == [{'one': True, 'two': 1, 'three': 1.1}, {'one': False, 'two': 2, 'three': 2.2}, {'one': True, 'two': 3, 'three': 3.3}]