From 45f6e85f50873bab4fa3fb596e7a2a5abdbf4947 Mon Sep 17 00:00:00 2001 From: Matt Topol Date: Wed, 4 Sep 2024 15:34:23 -0400 Subject: [PATCH] GH-43956: [C++][Format] Add initial Decimal32/Decimal64 implementations --- cpp/src/arrow/acero/tpch_node.cc | 18 +- cpp/src/arrow/array/array_base.cc | 8 + cpp/src/arrow/array/array_decimal.cc | 28 ++ cpp/src/arrow/array/array_decimal.h | 32 ++ cpp/src/arrow/array/array_test.cc | 120 ++++- cpp/src/arrow/array/array_view_test.cc | 38 +- cpp/src/arrow/array/builder_base.cc | 2 + cpp/src/arrow/array/builder_decimal.cc | 71 +++ cpp/src/arrow/array/builder_decimal.h | 62 +++ cpp/src/arrow/array/builder_dict.h | 18 +- cpp/src/arrow/array/concatenate.cc | 2 +- cpp/src/arrow/array/diff.cc | 6 +- cpp/src/arrow/array/diff_test.cc | 2 + cpp/src/arrow/array/util.cc | 36 ++ cpp/src/arrow/array/validate.cc | 10 + cpp/src/arrow/builder.cc | 2 + cpp/src/arrow/c/bridge.cc | 4 + cpp/src/arrow/c/bridge_test.cc | 44 +- cpp/src/arrow/compare.cc | 24 + cpp/src/arrow/compute/kernel_test.cc | 10 +- .../arrow/compute/kernels/aggregate_basic.cc | 2 +- .../compute/kernels/aggregate_internal.h | 10 + .../arrow/compute/kernels/aggregate_mode.cc | 2 +- .../compute/kernels/aggregate_tdigest.cc | 2 + .../compute/kernels/aggregate_var_std.cc | 4 +- .../arrow/compute/kernels/codegen_internal.h | 62 ++- .../arrow/compute/kernels/hash_aggregate.cc | 28 +- .../arrow/compute/kernels/vector_hash_test.cc | 6 +- .../compute/kernels/vector_pairwise_test.cc | 6 +- cpp/src/arrow/csv/converter_test.cc | 28 +- .../engine/substrait/expression_internal.cc | 32 ++ cpp/src/arrow/integration/json_internal.cc | 54 ++- cpp/src/arrow/ipc/json_simple.cc | 12 +- cpp/src/arrow/ipc/json_simple_test.cc | 14 +- cpp/src/arrow/ipc/metadata_internal.cc | 37 +- cpp/src/arrow/json/converter.cc | 4 +- cpp/src/arrow/pretty_print_test.cc | 4 +- cpp/src/arrow/scalar.cc | 22 + cpp/src/arrow/scalar.h | 8 + cpp/src/arrow/testing/gtest_util.h | 3 +- cpp/src/arrow/testing/random.cc | 123 ++++- cpp/src/arrow/testing/random.h | 31 ++ cpp/src/arrow/testing/random_test.cc | 22 +- cpp/src/arrow/type.cc | 100 +++- cpp/src/arrow/type.h | 71 +++ cpp/src/arrow/type_fwd.h | 24 + cpp/src/arrow/type_test.cc | 16 +- cpp/src/arrow/type_traits.cc | 2 + cpp/src/arrow/type_traits.h | 40 ++ cpp/src/arrow/util/align_util_test.cc | 17 +- cpp/src/arrow/util/basic_decimal.cc | 407 ++++++++++++++++ cpp/src/arrow/util/basic_decimal.h | 308 +++++++++++++ cpp/src/arrow/util/decimal.cc | 433 ++++++++++++++++++ cpp/src/arrow/util/decimal.h | 244 ++++++++++ cpp/src/arrow/util/decimal_internal.h | 104 +++++ cpp/src/arrow/util/decimal_test.cc | 4 +- cpp/src/arrow/util/formatting.h | 12 + cpp/src/arrow/util/formatting_util_test.cc | 14 +- cpp/src/arrow/visitor.cc | 6 + cpp/src/arrow/visitor.h | 6 + cpp/src/arrow/visitor_generate.h | 2 + .../parquet/arrow/arrow_reader_writer_test.cc | 30 +- cpp/src/parquet/arrow/arrow_schema_test.cc | 22 +- cpp/src/parquet/arrow/test_util.h | 20 +- format/Schema.fbs | 9 +- 65 files changed, 2802 insertions(+), 142 deletions(-) diff --git a/cpp/src/arrow/acero/tpch_node.cc b/cpp/src/arrow/acero/tpch_node.cc index 137b62ad38a95..abc742f9fa10b 100644 --- a/cpp/src/arrow/acero/tpch_node.cc +++ b/cpp/src/arrow/acero/tpch_node.cc @@ -838,12 +838,12 @@ class PartAndPartSupplierGenerator { const std::vector> kPartTypes = { int32(), utf8(), fixed_size_binary(25), fixed_size_binary(10), - utf8(), int32(), fixed_size_binary(10), decimal(12, 2), + utf8(), int32(), fixed_size_binary(10), decimal128(12, 2), utf8(), }; const std::vector> kPartsuppTypes = { - int32(), int32(), int32(), decimal(12, 2), utf8(), + int32(), int32(), int32(), decimal128(12, 2), utf8(), }; Status AllocatePartBatch(size_t thread_index, int column) { @@ -1527,7 +1527,7 @@ class OrdersAndLineItemGenerator { const std::vector> kOrdersTypes = {int32(), int32(), fixed_size_binary(1), - decimal(12, 2), + decimal128(12, 2), date32(), fixed_size_binary(15), fixed_size_binary(15), @@ -1539,10 +1539,10 @@ class OrdersAndLineItemGenerator { int32(), int32(), int32(), - decimal(12, 2), - decimal(12, 2), - decimal(12, 2), - decimal(12, 2), + decimal128(12, 2), + decimal128(12, 2), + decimal128(12, 2), + decimal128(12, 2), fixed_size_binary(1), fixed_size_binary(1), date32(), @@ -2489,7 +2489,7 @@ class SupplierGenerator : public TpchTableGenerator { std::vector> kTypes = { int32(), fixed_size_binary(25), utf8(), - int32(), fixed_size_binary(15), decimal(12, 2), + int32(), fixed_size_binary(15), decimal128(12, 2), utf8(), }; @@ -2872,7 +2872,7 @@ class CustomerGenerator : public TpchTableGenerator { utf8(), int32(), fixed_size_binary(15), - decimal(12, 2), + decimal128(12, 2), fixed_size_binary(10), utf8(), }; diff --git a/cpp/src/arrow/array/array_base.cc b/cpp/src/arrow/array/array_base.cc index 6927f51283eb7..fe4dc5e2223e9 100644 --- a/cpp/src/arrow/array/array_base.cc +++ b/cpp/src/arrow/array/array_base.cc @@ -74,6 +74,14 @@ struct ScalarFromArraySlotImpl { return Finish(a.Value(index_)); } + Status Visit(const Decimal32Array& a) { + return Finish(Decimal32(a.GetValue(index_))); + } + + Status Visit(const Decimal64Array& a) { + return Finish(Decimal64(a.GetValue(index_))); + } + Status Visit(const Decimal128Array& a) { return Finish(Decimal128(a.GetValue(index_))); } diff --git a/cpp/src/arrow/array/array_decimal.cc b/cpp/src/arrow/array/array_decimal.cc index d65f6ee53564f..a2c9cae3451a1 100644 --- a/cpp/src/arrow/array/array_decimal.cc +++ b/cpp/src/arrow/array/array_decimal.cc @@ -32,6 +32,34 @@ namespace arrow { using internal::checked_cast; +// ---------------------------------------------------------------------- +// Decimal32 + +Decimal32Array::Decimal32Array(const std::shared_ptr& data) + : FixedSizeBinaryArray(data) { + ARROW_CHECK_EQ(data->type->id(), Type::DECIMAL32); +} + +std::string Decimal32Array::FormatValue(int64_t i) const { + const auto& type_ = checked_cast(*type()); + const Decimal32 value(GetValue(i)); + return value.ToString(type_.scale()); +} + +// ---------------------------------------------------------------------- +// Decimal64 + +Decimal64Array::Decimal64Array(const std::shared_ptr& data) + : FixedSizeBinaryArray(data) { + ARROW_CHECK_EQ(data->type->id(), Type::DECIMAL64); +} + +std::string Decimal64Array::FormatValue(int64_t i) const { + const auto& type_ = checked_cast(*type()); + const Decimal64 value(GetValue(i)); + return value.ToString(type_.scale()); +} + // ---------------------------------------------------------------------- // Decimal128 diff --git a/cpp/src/arrow/array/array_decimal.h b/cpp/src/arrow/array/array_decimal.h index f14812549089a..2f10bb8429996 100644 --- a/cpp/src/arrow/array/array_decimal.h +++ b/cpp/src/arrow/array/array_decimal.h @@ -32,6 +32,38 @@ namespace arrow { /// /// @{ +// ---------------------------------------------------------------------- +// Decimal32Array + +/// Concrete Array class for 32-bit decimal data +class ARROW_EXPORT Decimal32Array : public FixedSizeBinaryArray { + public: + using TypeClass = Decimal32Type; + + using FixedSizeBinaryArray::FixedSizeBinaryArray; + + /// \brief Construct Decimal32Array from ArrayData instance + explicit Decimal32Array(const std::shared_ptr& data); + + std::string FormatValue(int64_t i) const; +}; + +// ---------------------------------------------------------------------- +// Decimal64Array + +/// Concrete Array class for 64-bit decimal data +class ARROW_EXPORT Decimal64Array : public FixedSizeBinaryArray { + public: + using TypeClass = Decimal64Type; + + using FixedSizeBinaryArray::FixedSizeBinaryArray; + + /// \brief Construct Decimal64Array from ArrayData instance + explicit Decimal64Array(const std::shared_ptr& data); + + std::string FormatValue(int64_t i) const; +}; + // ---------------------------------------------------------------------- // Decimal128Array diff --git a/cpp/src/arrow/array/array_test.cc b/cpp/src/arrow/array/array_test.cc index 32806d9d2edb3..b5e89e4ad9e48 100644 --- a/cpp/src/arrow/array/array_test.cc +++ b/cpp/src/arrow/array/array_test.cc @@ -667,7 +667,9 @@ static ScalarVector GetScalars() { std::make_shared(hello), std::make_shared( hello, fixed_size_binary(static_cast(hello->size()))), - std::make_shared(Decimal128(10), decimal(16, 4)), + std::make_shared(Decimal32(10), decimal(7, 4)), + std::make_shared(Decimal64(10), decimal(12, 4)), + std::make_shared(Decimal128(10), decimal(20, 4)), std::make_shared(Decimal256(10), decimal(76, 38)), std::make_shared(hello), std::make_shared(hello), @@ -3092,6 +3094,98 @@ class DecimalTest : public ::testing::TestWithParam { } }; +using Decimal32Test = DecimalTest; + +TEST_P(Decimal32Test, NoNulls) { + int32_t precision = GetParam(); + std::vector draw = {Decimal32(1), Decimal32(-2), Decimal32(2389), + Decimal32(4), Decimal32(-12348)}; + std::vector valid_bytes = {true, true, true, true, true}; + this->TestCreate(precision, draw, valid_bytes, 0); + this->TestCreate(precision, draw, valid_bytes, 2); +} + +TEST_P(Decimal32Test, WithNulls) { + int32_t precision = GetParam(); + std::vector draw = {Decimal32(1), Decimal32(2), Decimal32(-1), Decimal32(4), + Decimal32(-1), Decimal32(1), Decimal32(2)}; + Decimal32 big; + ASSERT_OK_AND_ASSIGN(big, Decimal32::FromString("23034.234234")); + draw.push_back(big); + + Decimal32 big_negative; + ASSERT_OK_AND_ASSIGN(big_negative, Decimal32::FromString("-23049.235234")); + draw.push_back(big_negative); + + std::vector valid_bytes = {true, true, false, true, false, + true, true, true, true}; + this->TestCreate(precision, draw, valid_bytes, 0); + this->TestCreate(precision, draw, valid_bytes, 2); +} + +TEST_P(Decimal32Test, ValidateFull) { + int32_t precision = GetParam(); + std::vector draw; + Decimal32 val = Decimal32::GetMaxValue(precision) + 1; + + draw = {Decimal32(), val}; + auto arr = this->TestCreate(precision, draw, {true, false}, 0); + ASSERT_OK(arr->ValidateFull()); + + draw = {val, Decimal32()}; + arr = this->TestCreate(precision, draw, {true, false}, 0); + EXPECT_RAISES_WITH_MESSAGE_THAT( + Invalid, ::testing::HasSubstr("does not fit in precision of"), arr->ValidateFull()); +} + +INSTANTIATE_TEST_SUITE_P(Decimal32Test, Decimal32Test, ::testing::Range(1, 9)); + +using Decimal64Test = DecimalTest; + +TEST_P(Decimal64Test, NoNulls) { + int32_t precision = GetParam(); + std::vector draw = {Decimal64(1), Decimal64(-2), Decimal64(2389), + Decimal64(4), Decimal64(-12348)}; + std::vector valid_bytes = {true, true, true, true, true}; + this->TestCreate(precision, draw, valid_bytes, 0); + this->TestCreate(precision, draw, valid_bytes, 2); +} + +TEST_P(Decimal64Test, WithNulls) { + int32_t precision = GetParam(); + std::vector draw = {Decimal64(1), Decimal64(2), Decimal64(-1), Decimal64(4), + Decimal64(-1), Decimal64(1), Decimal64(2)}; + Decimal64 big; + ASSERT_OK_AND_ASSIGN(big, Decimal64::FromString("23034.234234")); + draw.push_back(big); + + Decimal64 big_negative; + ASSERT_OK_AND_ASSIGN(big_negative, Decimal64::FromString("-23049.235234")); + draw.push_back(big_negative); + + std::vector valid_bytes = {true, true, false, true, false, + true, true, true, true}; + this->TestCreate(precision, draw, valid_bytes, 0); + this->TestCreate(precision, draw, valid_bytes, 2); +} + +TEST_P(Decimal64Test, ValidateFull) { + int32_t precision = GetParam(); + std::vector draw; + Decimal64 val = Decimal64::GetMaxValue(precision) + 1; + + draw = {Decimal64(), val}; + auto arr = this->TestCreate(precision, draw, {true, false}, 0); + ASSERT_OK(arr->ValidateFull()); + + draw = {val, Decimal64()}; + arr = this->TestCreate(precision, draw, {true, false}, 0); + EXPECT_RAISES_WITH_MESSAGE_THAT( + Invalid, ::testing::HasSubstr("does not fit in precision of"), arr->ValidateFull()); +} + +INSTANTIATE_TEST_SUITE_P(Decimal64Test, Decimal64Test, ::testing::Range(1, 9)); + using Decimal128Test = DecimalTest; TEST_P(Decimal128Test, NoNulls) { @@ -3315,6 +3409,28 @@ TEST(TestSwapEndianArrayData, PrimitiveType) { expected_data = ArrayData::Make(uint64(), 1, {null_buffer, data_int64_buffer}, 0); AssertArrayDataEqualsWithSwapEndian(data, expected_data); + auto data_4byte_buffer = Buffer::FromString( + "\x01" + "12\x01"); + data = ArrayData::Make(decimal32(9, 8), 1, {null_buffer, data_4byte_buffer}); + auto data_decimal32_buffer = Buffer::FromString( + "\x01" + "21\x01"); + expected_data = + ArrayData::Make(decimal32(9, 8), 1, {null_buffer, data_decimal32_buffer}, 0); + AssertArrayDataEqualsWithSwapEndian(data, expected_data); + + auto data_8byte_buffer = Buffer::FromString( + "\x01" + "123456\x01"); + data = ArrayData::Make(decimal64(18, 8), 1, {null_buffer, data_8byte_buffer}); + auto data_decimal64_buffer = Buffer::FromString( + "\x01" + "654321\x01"); + expected_data = + ArrayData::Make(decimal64(18, 8), 1, {null_buffer, data_decimal64_buffer}, 0); + AssertArrayDataEqualsWithSwapEndian(data, expected_data); + auto data_16byte_buffer = Buffer::FromString( "\x01" "123456789abcde\x01"); @@ -3647,6 +3763,8 @@ DataTypeVector SwappableTypes() { uint16(), uint32(), uint64(), + decimal32(8, 1), + decimal64(16, 2), decimal128(19, 4), decimal256(37, 8), timestamp(TimeUnit::MICRO, ""), diff --git a/cpp/src/arrow/array/array_view_test.cc b/cpp/src/arrow/array/array_view_test.cc index 97110ea97f3fc..280690d5b51fe 100644 --- a/cpp/src/arrow/array/array_view_test.cc +++ b/cpp/src/arrow/array/array_view_test.cc @@ -385,10 +385,34 @@ TEST(TestArrayView, SparseUnionAsStruct) { CheckView(expected, arr); } -TEST(TestArrayView, DecimalRoundTrip) { +TEST(TestArrayView, Decimal32RoundTrip) { + auto ty1 = decimal(9, 4); + auto arr = ArrayFromJSON(ty1, R"(["123.4567", "-78.9000", null])"); + + auto ty2 = fixed_size_binary(4); + ASSERT_OK_AND_ASSIGN(auto v, arr->View(ty2)); + ASSERT_OK(v->ValidateFull()); + ASSERT_OK_AND_ASSIGN(auto w, v->View(ty1)); + ASSERT_OK(w->ValidateFull()); + AssertArraysEqual(*arr, *w); +} + +TEST(TestArrayView, Decimal64RoundTrip) { auto ty1 = decimal(10, 4); auto arr = ArrayFromJSON(ty1, R"(["123.4567", "-78.9000", null])"); + auto ty2 = fixed_size_binary(8); + ASSERT_OK_AND_ASSIGN(auto v, arr->View(ty2)); + ASSERT_OK(v->ValidateFull()); + ASSERT_OK_AND_ASSIGN(auto w, v->View(ty1)); + ASSERT_OK(w->ValidateFull()); + AssertArraysEqual(*arr, *w); +} + +TEST(TestArrayView, Decimal128RoundTrip) { + auto ty1 = decimal(20, 4); + auto arr = ArrayFromJSON(ty1, R"(["123.4567", "-78.9000", null])"); + auto ty2 = fixed_size_binary(16); ASSERT_OK_AND_ASSIGN(auto v, arr->View(ty2)); ASSERT_OK(v->ValidateFull()); @@ -397,6 +421,18 @@ TEST(TestArrayView, DecimalRoundTrip) { AssertArraysEqual(*arr, *w); } +TEST(TestArrayView, Decimal256RoundTrip) { + auto ty1 = decimal256(10, 4); + auto arr = ArrayFromJSON(ty1, R"(["123.4567", "-78.9000", null])"); + + auto ty2 = fixed_size_binary(32); + ASSERT_OK_AND_ASSIGN(auto v, arr->View(ty2)); + ASSERT_OK(v->ValidateFull()); + ASSERT_OK_AND_ASSIGN(auto w, v->View(ty1)); + ASSERT_OK(w->ValidateFull()); + AssertArraysEqual(*arr, *w); +} + TEST(TestArrayView, Dictionaries) { // ARROW-6049 auto ty1 = dictionary(int8(), float32()); diff --git a/cpp/src/arrow/array/builder_base.cc b/cpp/src/arrow/array/builder_base.cc index 40e705aa3e440..2e6e1bfd13032 100644 --- a/cpp/src/arrow/array/builder_base.cc +++ b/cpp/src/arrow/array/builder_base.cc @@ -119,6 +119,8 @@ struct AppendScalarImpl { } Status Visit(const FixedSizeBinaryType& t) { return HandleFixedWidth(t); } + Status Visit(const Decimal32Type& t) { return HandleFixedWidth(t); } + Status Visit(const Decimal64Type& t) { return HandleFixedWidth(t); } Status Visit(const Decimal128Type& t) { return HandleFixedWidth(t); } Status Visit(const Decimal256Type& t) { return HandleFixedWidth(t); } diff --git a/cpp/src/arrow/array/builder_decimal.cc b/cpp/src/arrow/array/builder_decimal.cc index 3b1262819df7f..e4d67d5f1769f 100644 --- a/cpp/src/arrow/array/builder_decimal.cc +++ b/cpp/src/arrow/array/builder_decimal.cc @@ -32,6 +32,77 @@ namespace arrow { class Buffer; class MemoryPool; +// ---------------------------------------------------------------------- +// Decimal32Builder + +Decimal32Builder::Decimal32Builder(const std::shared_ptr& type, + MemoryPool* pool, int64_t alignment) + : FixedSizeBinaryBuilder(type, pool, alignment), + decimal_type_(internal::checked_pointer_cast(type)) {} + +Status Decimal32Builder::Append(Decimal32 value) { + RETURN_NOT_OK(FixedSizeBinaryBuilder::Reserve(1)); + UnsafeAppend(value); + return Status::OK(); +} + +void Decimal32Builder::UnsafeAppend(Decimal32 value) { + value.ToBytes(GetMutableValue(length())); + byte_builder_.UnsafeAdvance(4); + UnsafeAppendToBitmap(true); +} + +void Decimal32Builder::UnsafeAppend(std::string_view value) { + FixedSizeBinaryBuilder::UnsafeAppend(value); +} + +Status Decimal32Builder::FinishInternal(std::shared_ptr* out) { + std::shared_ptr data; + RETURN_NOT_OK(byte_builder_.Finish(&data)); + std::shared_ptr null_bitmap; + RETURN_NOT_OK(null_bitmap_builder_.Finish(&null_bitmap)); + + *out = ArrayData::Make(type(), length_, {null_bitmap, data}, null_count_); + capacity_ = length_ = null_count_ = 0; + return Status::OK(); +} + + +// ---------------------------------------------------------------------- +// Decimal64Builder + +Decimal64Builder::Decimal64Builder(const std::shared_ptr& type, + MemoryPool* pool, int64_t alignment) + : FixedSizeBinaryBuilder(type, pool, alignment), + decimal_type_(internal::checked_pointer_cast(type)) {} + +Status Decimal64Builder::Append(Decimal64 value) { + RETURN_NOT_OK(FixedSizeBinaryBuilder::Reserve(1)); + UnsafeAppend(value); + return Status::OK(); +} + +void Decimal64Builder::UnsafeAppend(Decimal64 value) { + value.ToBytes(GetMutableValue(length())); + byte_builder_.UnsafeAdvance(8); + UnsafeAppendToBitmap(true); +} + +void Decimal64Builder::UnsafeAppend(std::string_view value) { + FixedSizeBinaryBuilder::UnsafeAppend(value); +} + +Status Decimal64Builder::FinishInternal(std::shared_ptr* out) { + std::shared_ptr data; + RETURN_NOT_OK(byte_builder_.Finish(&data)); + std::shared_ptr null_bitmap; + RETURN_NOT_OK(null_bitmap_builder_.Finish(&null_bitmap)); + + *out = ArrayData::Make(type(), length_, {null_bitmap, data}, null_count_); + capacity_ = length_ = null_count_ = 0; + return Status::OK(); +} + // ---------------------------------------------------------------------- // Decimal128Builder diff --git a/cpp/src/arrow/array/builder_decimal.h b/cpp/src/arrow/array/builder_decimal.h index 8094250aef8d4..c506a78b05276 100644 --- a/cpp/src/arrow/array/builder_decimal.h +++ b/cpp/src/arrow/array/builder_decimal.h @@ -33,6 +33,68 @@ namespace arrow { /// /// @{ +class ARROW_EXPORT Decimal32Builder : public FixedSizeBinaryBuilder { + public: + using TypeClass = Decimal32Type; + using ValueType = Decimal32; + + explicit Decimal32Builder(const std::shared_ptr& type, + MemoryPool* pool = default_memory_pool(), + int64_t alignment = kDefaultBufferAlignment); + + using FixedSizeBinaryBuilder::Append; + using FixedSizeBinaryBuilder::AppendValues; + using FixedSizeBinaryBuilder::Reset; + + Status Append(Decimal32 val); + void UnsafeAppend(Decimal32 val); + void UnsafeAppend(std::string_view val); + + Status FinishInternal(std::shared_ptr* out) override; + + /// \cond FALSE + using ArrayBuilder::Finish; + /// \endcond + + Status Finish(std::shared_ptr* out) { return FinishTyped(out); } + + std::shared_ptr type() const override { return decimal_type_; } + + protected: + std::shared_ptr decimal_type_; +}; + +class ARROW_EXPORT Decimal64Builder : public FixedSizeBinaryBuilder { + public: + using TypeClass = Decimal64Type; + using ValueType = Decimal64; + + explicit Decimal64Builder(const std::shared_ptr& type, + MemoryPool* pool = default_memory_pool(), + int64_t alignment = kDefaultBufferAlignment); + + using FixedSizeBinaryBuilder::Append; + using FixedSizeBinaryBuilder::AppendValues; + using FixedSizeBinaryBuilder::Reset; + + Status Append(Decimal64 val); + void UnsafeAppend(Decimal64 val); + void UnsafeAppend(std::string_view val); + + Status FinishInternal(std::shared_ptr* out) override; + + /// \cond FALSE + using ArrayBuilder::Finish; + /// \endcond + + Status Finish(std::shared_ptr* out) { return FinishTyped(out); } + + std::shared_ptr type() const override { return decimal_type_; } + + protected: + std::shared_ptr decimal_type_; +}; + class ARROW_EXPORT Decimal128Builder : public FixedSizeBinaryBuilder { public: using TypeClass = Decimal128Type; diff --git a/cpp/src/arrow/array/builder_dict.h b/cpp/src/arrow/array/builder_dict.h index 3f0d711dc5bb5..2b6df0cabeab9 100644 --- a/cpp/src/arrow/array/builder_dict.h +++ b/cpp/src/arrow/array/builder_dict.h @@ -298,6 +298,22 @@ class DictionaryBuilderBase : public ArrayBuilder { return Append(std::string_view(value, length)); } + /// \brief Append a decimal (only for Decimal32Type) + template + enable_if_decimal32 Append(const Decimal32& value) { + uint8_t data[4]; + value.ToBytes(data); + return Append(data, 4); + } + + /// \brief Append a decimal (only for Decimal64Type) + template + enable_if_decimal64 Append(const Decimal64& value) { + uint8_t data[8]; + value.ToBytes(data); + return Append(data, 8); + } + /// \brief Append a decimal (only for Decimal128Type) template enable_if_decimal128 Append(const Decimal128& value) { @@ -306,7 +322,7 @@ class DictionaryBuilderBase : public ArrayBuilder { return Append(data, 16); } - /// \brief Append a decimal (only for Decimal128Type) + /// \brief Append a decimal (only for Decimal256Type) template enable_if_decimal256 Append(const Decimal256& value) { uint8_t data[32]; diff --git a/cpp/src/arrow/array/concatenate.cc b/cpp/src/arrow/array/concatenate.cc index b4638dd6593d8..d8a69868d1543 100644 --- a/cpp/src/arrow/array/concatenate.cc +++ b/cpp/src/arrow/array/concatenate.cc @@ -377,7 +377,7 @@ class ConcatenateImpl { } Status Visit(const FixedWidthType& fixed) { - // Handles numbers, decimal128, decimal256, fixed_size_binary + // Handles numbers, decimal32, decimal64, decimal128, decimal256, fixed_size_binary ARROW_ASSIGN_OR_RAISE(auto buffers, Buffers(1, fixed)); return ConcatenateBuffers(buffers, pool_).Value(&out_->buffers[1]); } diff --git a/cpp/src/arrow/array/diff.cc b/cpp/src/arrow/array/diff.cc index f9714eda34c61..26fc854af89d3 100644 --- a/cpp/src/arrow/array/diff.cc +++ b/cpp/src/arrow/array/diff.cc @@ -707,7 +707,11 @@ class MakeFormatterImpl { template enable_if_decimal Visit(const T&) { impl_ = [](const Array& array, int64_t index, std::ostream* os) { - if constexpr (T::type_id == Type::DECIMAL128) { + if constexpr (T::type_id == Type::DECIMAL32) { + *os << checked_cast(array).FormatValue(index); + } else if constexpr (T::type_id == Type::DECIMAL64) { + *os << checked_cast(array).FormatValue(index); + } else if constexpr (T::type_id == Type::DECIMAL128) { *os << checked_cast(array).FormatValue(index); } else { *os << checked_cast(array).FormatValue(index); diff --git a/cpp/src/arrow/array/diff_test.cc b/cpp/src/arrow/array/diff_test.cc index 145978a91ad54..02bcf5bbb4c5b 100644 --- a/cpp/src/arrow/array/diff_test.cc +++ b/cpp/src/arrow/array/diff_test.cc @@ -707,6 +707,8 @@ TEST_F(DiffTest, UnifiedDiffFormatter) { } for (const auto& type : { + decimal32(8, 4), + decimal64(10, 4), decimal128(10, 4), decimal256(10, 4), }) { diff --git a/cpp/src/arrow/array/util.cc b/cpp/src/arrow/array/util.cc index b56ea25f9e421..d468a8db5fe0c 100644 --- a/cpp/src/arrow/array/util.cc +++ b/cpp/src/arrow/array/util.cc @@ -152,6 +152,42 @@ class ArrayDataEndianSwapper { return Status::OK(); } + Status Visit(const Decimal32Type& type) { + auto data = reinterpret_cast(data_->buffers[1]->data()); + ARROW_ASSIGN_OR_RAISE(auto new_buffer, + AllocateBuffer(data_->buffers[1]->size(), pool_)); + auto new_data = reinterpret_cast(new_buffer->mutable_data()); + // NOTE: data_->length not trusted (see warning above) + const int64_t length = data_->buffers[1]->size() / Decimal32Type::kByteWidth; + for (int64_t i = 0; i < length; i++) { +#if ARROW_LITTLE_ENDIAN + new_data[i] = bit_util::FromBigEndian(data[i]); +#else + new_data[i] = bit_util::FromLittleEndian(data[i]); +#endif + } + out_->buffers[1] = std::move(new_buffer); + return Status::OK(); + } + + Status Visit(const Decimal64Type& type) { + auto data = reinterpret_cast(data_->buffers[1]->data()); + ARROW_ASSIGN_OR_RAISE(auto new_buffer, + AllocateBuffer(data_->buffers[1]->size(), pool_)); + auto new_data = reinterpret_cast(new_buffer->mutable_data()); + // NOTE: data_->length not trusted (see warning above) + const int64_t length = data_->buffers[1]->size() / Decimal64Type::kByteWidth; + for (int64_t i = 0; i < length; i++) { +#if ARROW_LITTLE_ENDIAN + new_data[i] = bit_util::FromBigEndian(data[i]); +#else + new_data[i] = bit_util::FromLittleEndian(data[i]); +#endif + } + out_->buffers[1] = std::move(new_buffer); + return Status::OK(); + } + Status Visit(const Decimal128Type& type) { auto data = reinterpret_cast(data_->buffers[1]->data()); ARROW_ASSIGN_OR_RAISE(auto new_buffer, diff --git a/cpp/src/arrow/array/validate.cc b/cpp/src/arrow/array/validate.cc index 0d940d3bc869e..4cce2b76c2fc9 100644 --- a/cpp/src/arrow/array/validate.cc +++ b/cpp/src/arrow/array/validate.cc @@ -144,6 +144,16 @@ struct ValidateArrayImpl { Status Visit(const FixedWidthType&) { return ValidateFixedWidthBuffers(); } + Status Visit(const Decimal32Type& type) { + RETURN_NOT_OK(ValidateFixedWidthBuffers()); + return ValidateDecimals(type); + } + + Status Visit(const Decimal64Type& type) { + RETURN_NOT_OK(ValidateFixedWidthBuffers()); + return ValidateDecimals(type); + } + Status Visit(const Decimal128Type& type) { RETURN_NOT_OK(ValidateFixedWidthBuffers()); return ValidateDecimals(type); diff --git a/cpp/src/arrow/builder.cc b/cpp/src/arrow/builder.cc index 7042d9818c691..46969e73e22ae 100644 --- a/cpp/src/arrow/builder.cc +++ b/cpp/src/arrow/builder.cc @@ -151,6 +151,8 @@ struct DictionaryBuilderCase { Status Visit(const BinaryViewType&) { return CreateFor(); } Status Visit(const StringViewType&) { return CreateFor(); } Status Visit(const FixedSizeBinaryType&) { return CreateFor(); } + Status Visit(const Decimal32Type&) { return CreateFor(); } + Status Visit(const Decimal64Type&) { return CreateFor(); } Status Visit(const Decimal128Type&) { return CreateFor(); } Status Visit(const Decimal256Type&) { return CreateFor(); } diff --git a/cpp/src/arrow/c/bridge.cc b/cpp/src/arrow/c/bridge.cc index eba575f4cf39c..9deac40bcb0d8 100644 --- a/cpp/src/arrow/c/bridge.cc +++ b/cpp/src/arrow/c/bridge.cc @@ -1253,6 +1253,10 @@ struct SchemaImporter { type_ = decimal128(prec_scale[0], prec_scale[1]); } else if (prec_scale[2] == 256) { type_ = decimal256(prec_scale[0], prec_scale[1]); + } else if (prec_scale[2] == 32) { + type_ = decimal32(prec_scale[0], prec_scale[1]); + } else if (prec_scale[2] == 64) { + type_ = decimal64(prec_scale[0], prec_scale[1]); } else { return f_parser_.Invalid(); } diff --git a/cpp/src/arrow/c/bridge_test.cc b/cpp/src/arrow/c/bridge_test.cc index 09bb524adbdf0..d0a79104c14a2 100644 --- a/cpp/src/arrow/c/bridge_test.cc +++ b/cpp/src/arrow/c/bridge_test.cc @@ -363,13 +363,19 @@ TEST_F(TestSchemaExport, Primitive) { TestPrimitive(binary_view(), "vz"); TestPrimitive(utf8_view(), "vu"); - TestPrimitive(decimal(16, 4), "d:16,4"); + TestPrimitive(decimal(8, 4), "d:8,4,32"); + TestPrimitive(decimal(16, 4), "d:16,4,64"); + TestPrimitive(decimal128(16, 4), "d:16,4"); TestPrimitive(decimal256(16, 4), "d:16,4,256"); - TestPrimitive(decimal(15, 0), "d:15,0"); + TestPrimitive(decimal(8, 0), "d:8,0,32"); + TestPrimitive(decimal(15, 0), "d:15,0,64"); + TestPrimitive(decimal128(15, 0), "d:15,0"); TestPrimitive(decimal256(15, 0), "d:15,0,256"); - TestPrimitive(decimal(15, -4), "d:15,-4"); + TestPrimitive(decimal(8, -4), "d:8,-4,32"); + TestPrimitive(decimal(15, -4), "d:15,-4,64"); + TestPrimitive(decimal128(15, -4), "d:15,-4"); TestPrimitive(decimal256(15, -4), "d:15,-4,256"); } @@ -1951,6 +1957,10 @@ TEST_F(TestSchemaImport, Primitive) { CheckImport(field("", decimal128(16, 4))); FillPrimitive("d:16,4,256"); CheckImport(field("", decimal256(16, 4))); + FillPrimitive("d:4,4,32"); + CheckImport(field("", decimal32(4, 4))); + FillPrimitive("d:16,4,64"); + CheckImport(field("", decimal64(16, 4))); FillPrimitive("d:16,0"); CheckImport(field("", decimal128(16, 0))); @@ -1958,6 +1968,10 @@ TEST_F(TestSchemaImport, Primitive) { CheckImport(field("", decimal128(16, 0))); FillPrimitive("d:16,0,256"); CheckImport(field("", decimal256(16, 0))); + FillPrimitive("d:4,0,32"); + CheckImport(field("", decimal32(4, 0))); + FillPrimitive("d:16,0,64"); + CheckImport(field("", decimal64(16, 0))); FillPrimitive("d:16,-4"); CheckImport(field("", decimal128(16, -4))); @@ -1965,6 +1979,10 @@ TEST_F(TestSchemaImport, Primitive) { CheckImport(field("", decimal128(16, -4))); FillPrimitive("d:16,-4,256"); CheckImport(field("", decimal256(16, -4))); + FillPrimitive("d:4,-4,32"); + CheckImport(field("", decimal32(4, -4))); + FillPrimitive("d:16,-4,64"); + CheckImport(field("", decimal64(16, -4))); } TEST_F(TestSchemaImport, Temporal) { @@ -2034,7 +2052,7 @@ TEST_F(TestSchemaImport, String) { FillPrimitive("w:3"); CheckImport(fixed_size_binary(3)); FillPrimitive("d:15,4"); - CheckImport(decimal(15, 4)); + CheckImport(decimal128(15, 4)); } TEST_F(TestSchemaImport, List) { @@ -2950,26 +2968,26 @@ TEST_F(TestArrayImport, FixedSizeBinary) { FillPrimitive(2, 0, 0, primitive_buffers_no_nulls2); CheckImport(ArrayFromJSON(fixed_size_binary(3), R"(["abc", "def"])")); FillPrimitive(2, 0, 0, primitive_buffers_no_nulls3); - CheckImport(ArrayFromJSON(decimal(15, 4), R"(["12345.6789", "98765.4321"])")); + CheckImport(ArrayFromJSON(decimal128(15, 4), R"(["12345.6789", "98765.4321"])")); // Empty array with null data pointers FillPrimitive(0, 0, 0, all_buffers_omitted); CheckImport(ArrayFromJSON(fixed_size_binary(3), "[]")); FillPrimitive(0, 0, 0, all_buffers_omitted); - CheckImport(ArrayFromJSON(decimal(15, 4), "[]")); + CheckImport(ArrayFromJSON(decimal128(15, 4), "[]")); } TEST_F(TestArrayImport, FixedSizeBinaryWithOffset) { FillPrimitive(1, 0, 1, primitive_buffers_no_nulls2); CheckImport(ArrayFromJSON(fixed_size_binary(3), R"(["def"])")); FillPrimitive(1, 0, 1, primitive_buffers_no_nulls3); - CheckImport(ArrayFromJSON(decimal(15, 4), R"(["98765.4321"])")); + CheckImport(ArrayFromJSON(decimal128(15, 4), R"(["98765.4321"])")); // Empty array with null data pointers FillPrimitive(0, 0, 1, all_buffers_omitted); CheckImport(ArrayFromJSON(fixed_size_binary(3), "[]")); FillPrimitive(0, 0, 1, all_buffers_omitted); - CheckImport(ArrayFromJSON(decimal(15, 4), "[]")); + CheckImport(ArrayFromJSON(decimal128(15, 4), "[]")); } TEST_F(TestArrayImport, List) { @@ -3624,10 +3642,16 @@ TEST_F(TestSchemaRoundtrip, Primitive) { TestWithTypeFactory(boolean); TestWithTypeFactory(float16); + TestWithTypeFactory([] { return decimal32(8, 4); }); + TestWithTypeFactory([] { return decimal64(16, 4); }); TestWithTypeFactory([] { return decimal128(19, 4); }); TestWithTypeFactory([] { return decimal256(19, 4); }); + TestWithTypeFactory([] { return decimal32(8, 0); }); + TestWithTypeFactory([] { return decimal64(16, 0); }); TestWithTypeFactory([] { return decimal128(19, 0); }); TestWithTypeFactory([] { return decimal256(19, 0); }); + TestWithTypeFactory([] { return decimal32(8, -5); }); + TestWithTypeFactory([] { return decimal64(16, -5); }); TestWithTypeFactory([] { return decimal128(19, -5); }); TestWithTypeFactory([] { return decimal256(19, -5); }); TestWithTypeFactory([] { return fixed_size_binary(3); }); @@ -3901,6 +3925,8 @@ TEST_F(TestArrayRoundtrip, Primitive) { TestWithJSON(int32(), "[]"); TestWithJSON(int32(), "[4, 5, null]"); + TestWithJSON(decimal32(8, 4), R"(["0.4759", "1234.5670", null])"); + TestWithJSON(decimal64(16, 4), R"(["0.4759", "1234.5670", null])"); TestWithJSON(decimal128(16, 4), R"(["0.4759", "1234.5670", null])"); TestWithJSON(decimal256(16, 4), R"(["0.4759", "1234.5670", null])"); @@ -3908,6 +3934,8 @@ TEST_F(TestArrayRoundtrip, Primitive) { TestWithJSONSliced(int32(), "[4, 5]"); TestWithJSONSliced(int32(), "[4, 5, 6, null]"); + TestWithJSONSliced(decimal32(8, 4), R"(["0.4759", "1234.5670", null])"); + TestWithJSONSliced(decimal64(16, 4), R"(["0.4759", "1234.5670", null])"); TestWithJSONSliced(decimal128(16, 4), R"(["0.4759", "1234.5670", null])"); TestWithJSONSliced(decimal256(16, 4), R"(["0.4759", "1234.5670", null])"); TestWithJSONSliced(month_day_nano_interval(), diff --git a/cpp/src/arrow/compare.cc b/cpp/src/arrow/compare.cc index e983b47e39dc4..4feda085e56f2 100644 --- a/cpp/src/arrow/compare.cc +++ b/cpp/src/arrow/compare.cc @@ -750,6 +750,18 @@ class TypeEqualsVisitor { return Status::OK(); } + Status Visit(const Decimal32Type& left) { + const auto& right = checked_cast(right_); + result_ = left.precision() == right.precision() && left.scale() == right.scale(); + return Status::OK(); + } + + Status Visit(const Decimal64Type& left) { + const auto& right = checked_cast(right_); + result_ = left.precision() == right.precision() && left.scale() == right.scale(); + return Status::OK(); + } + Status Visit(const Decimal128Type& left) { const auto& right = checked_cast(right_); result_ = left.precision() == right.precision() && left.scale() == right.scale(); @@ -900,6 +912,18 @@ class ScalarEqualsVisitor { return Status::OK(); } + Status Visit(const Decimal32Scalar& left) { + const auto& right = checked_cast(right_); + result_ = left.value == right.value; + return Status::OK(); + } + + Status Visit(const Decimal64Scalar& left) { + const auto& right = checked_cast(right_); + result_ = left.value == right.value; + return Status::OK(); + } + Status Visit(const Decimal128Scalar& left) { const auto& right = checked_cast(right_); result_ = left.value == right.value; diff --git a/cpp/src/arrow/compute/kernel_test.cc b/cpp/src/arrow/compute/kernel_test.cc index 5daf7d2991d2a..21b2338e59329 100644 --- a/cpp/src/arrow/compute/kernel_test.cc +++ b/cpp/src/arrow/compute/kernel_test.cc @@ -36,7 +36,7 @@ namespace compute { TEST(TypeMatcher, SameTypeId) { std::shared_ptr matcher = match::SameTypeId(Type::DECIMAL); - ASSERT_TRUE(matcher->Matches(*decimal(12, 2))); + ASSERT_TRUE(matcher->Matches(*decimal(20, 2))); ASSERT_FALSE(matcher->Matches(*int8())); ASSERT_EQ("Type::DECIMAL128", matcher->ToString()); @@ -120,7 +120,7 @@ TEST(InputType, Constructors) { InputType ty2(Type::DECIMAL); ASSERT_EQ(InputType::USE_TYPE_MATCHER, ty2.kind()); ASSERT_EQ("Type::DECIMAL128", ty2.ToString()); - ASSERT_TRUE(ty2.type_matcher().Matches(*decimal(12, 2))); + ASSERT_TRUE(ty2.type_matcher().Matches(*decimal128(12, 2))); ASSERT_FALSE(ty2.type_matcher().Matches(*int16())); // Implicit construction in a vector @@ -203,7 +203,7 @@ TEST(InputType, Matches) { ASSERT_TRUE(input1.Matches(*int8())); ASSERT_FALSE(input1.Matches(*int16())); - InputType input2(Type::DECIMAL); + InputType input2(Type::DECIMAL64); ASSERT_TRUE(input2.Matches(*decimal(12, 2))); auto ty2 = decimal(12, 2); @@ -312,7 +312,7 @@ TEST(OutputType, Resolve) { TEST(KernelSignature, Basics) { // (int8, decimal) -> utf8 - std::vector in_types({int8(), InputType(Type::DECIMAL)}); + std::vector in_types({int8(), InputType(Type::DECIMAL64)}); OutputType out_type(utf8()); KernelSignature sig(in_types, out_type); @@ -381,7 +381,7 @@ TEST(KernelSignature, MatchesInputs) { ASSERT_FALSE(sig2.MatchesInputs({})); ASSERT_FALSE(sig2.MatchesInputs({int8()})); - ASSERT_TRUE(sig2.MatchesInputs({int8(), decimal(12, 2)})); + ASSERT_TRUE(sig2.MatchesInputs({int8(), decimal128(12, 2)})); // (int8, int32) -> boolean KernelSignature sig3({int8(), int32()}, boolean()); diff --git a/cpp/src/arrow/compute/kernels/aggregate_basic.cc b/cpp/src/arrow/compute/kernels/aggregate_basic.cc index 1fbcd6a249093..ae6069e190f8c 100644 --- a/cpp/src/arrow/compute/kernels/aggregate_basic.cc +++ b/cpp/src/arrow/compute/kernels/aggregate_basic.cc @@ -1042,7 +1042,7 @@ void RegisterScalarAggregateBasic(FunctionRegistry* registry) { func = std::make_shared("sum", Arity::Unary(), sum_doc, &default_scalar_aggregate_options); - AddArrayScalarAggKernels(SumInit, {boolean()}, uint64(), func.get()); + AddArrayScalarAggKernels(SumInit, {boolean()}, uint64(), func.get()); AddAggKernel(KernelSignature::Make({Type::DECIMAL128}, FirstType), SumInit, func.get(), SimdLevel::NONE); AddAggKernel(KernelSignature::Make({Type::DECIMAL256}, FirstType), SumInit, func.get(), diff --git a/cpp/src/arrow/compute/kernels/aggregate_internal.h b/cpp/src/arrow/compute/kernels/aggregate_internal.h index 168f063c770f3..9dab049821d5c 100644 --- a/cpp/src/arrow/compute/kernels/aggregate_internal.h +++ b/cpp/src/arrow/compute/kernels/aggregate_internal.h @@ -52,6 +52,16 @@ struct FindAccumulatorType> { using Type = DoubleType; }; +template +struct FindAccumulatorType> { + using Type = Decimal32Type; +}; + +template +struct FindAccumulatorType> { + using Type = Decimal64Type; +}; + template struct FindAccumulatorType> { using Type = Decimal128Type; diff --git a/cpp/src/arrow/compute/kernels/aggregate_mode.cc b/cpp/src/arrow/compute/kernels/aggregate_mode.cc index 3f84c0a5ee4c4..86b20674b85ec 100644 --- a/cpp/src/arrow/compute/kernels/aggregate_mode.cc +++ b/cpp/src/arrow/compute/kernels/aggregate_mode.cc @@ -454,7 +454,7 @@ VectorKernel NewModeKernel(const std::shared_ptr& in_type, ArrayKernel kernel.init = ModeState::Init; kernel.can_execute_chunkwise = false; kernel.output_chunked = false; - switch (in_type->id()) { + switch (in_type->id()) { case Type::DECIMAL128: case Type::DECIMAL256: kernel.signature = diff --git a/cpp/src/arrow/compute/kernels/aggregate_tdigest.cc b/cpp/src/arrow/compute/kernels/aggregate_tdigest.cc index 1dab92632ef2d..83d01091b3c8d 100644 --- a/cpp/src/arrow/compute/kernels/aggregate_tdigest.cc +++ b/cpp/src/arrow/compute/kernels/aggregate_tdigest.cc @@ -51,6 +51,8 @@ struct TDigestImpl : public ScalarAggregator { double ToDouble(T value) const { return static_cast(value); } + double ToDouble(const Decimal32& value) const { return value.ToDouble(decimal_scale); } + double ToDouble(const Decimal64& value) const { return value.ToDouble(decimal_scale); } double ToDouble(const Decimal128& value) const { return value.ToDouble(decimal_scale); } double ToDouble(const Decimal256& value) const { return value.ToDouble(decimal_scale); } diff --git a/cpp/src/arrow/compute/kernels/aggregate_var_std.cc b/cpp/src/arrow/compute/kernels/aggregate_var_std.cc index c2fab48dbe208..28a98ee960d78 100644 --- a/cpp/src/arrow/compute/kernels/aggregate_var_std.cc +++ b/cpp/src/arrow/compute/kernels/aggregate_var_std.cc @@ -46,6 +46,8 @@ struct VarStdState { double ToDouble(T value) const { return static_cast(value); } + double ToDouble(const Decimal32& value) const { return value.ToDouble(decimal_scale); } + double ToDouble(const Decimal64& value) const { return value.ToDouble(decimal_scale); } double ToDouble(const Decimal128& value) const { return value.ToDouble(decimal_scale); } double ToDouble(const Decimal256& value) const { return value.ToDouble(decimal_scale); } @@ -53,7 +55,7 @@ struct VarStdState { // algorithm` // https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Two-pass_algorithm template - enable_if_t::value || (sizeof(CType) > 4)> Consume( + enable_if_t::value || (sizeof(CType) > 4) || (!is_integer_type::value && sizeof(CType) == 4)> Consume( const ArraySpan& array) { this->all_valid = array.GetNullCount() == 0; int64_t count = array.length - array.GetNullCount(); diff --git a/cpp/src/arrow/compute/kernels/codegen_internal.h b/cpp/src/arrow/compute/kernels/codegen_internal.h index 9e46a21887f8c..7bfdb92611cd4 100644 --- a/cpp/src/arrow/compute/kernels/codegen_internal.h +++ b/cpp/src/arrow/compute/kernels/codegen_internal.h @@ -140,6 +140,30 @@ struct GetViewType::value || static T LogicalValue(PhysicalType value) { return value; } }; +template <> +struct GetViewType { + using T = Decimal32; + using PhysicalType = std::string_view; + + static T LogicalValue(PhysicalType value) { + return Decimal32(reinterpret_cast(value.data())); + } + + static T LogicalValue(T value) { return value; } +}; + +template <> +struct GetViewType { + using T = Decimal64; + using PhysicalType = std::string_view; + + static T LogicalValue(PhysicalType value) { + return Decimal64(reinterpret_cast(value.data())); + } + + static T LogicalValue(T value) { return value; } +}; + template <> struct GetViewType { using T = Decimal128; @@ -177,6 +201,16 @@ struct GetOutputType::value>> { using T = std::string; }; +template <> +struct GetOutputType { + using T = Decimal32; +}; + +template <> +struct GetOutputType { + using T = Decimal64; +}; + template <> struct GetOutputType { using T = Decimal128; @@ -224,7 +258,9 @@ using enable_if_not_floating_value = enable_if_t::val template using enable_if_decimal_value = - enable_if_t::value || std::is_same::value, + enable_if_t::value || std::is_same::value || + std::is_same::value || + std::is_same::value, R>; // ---------------------------------------------------------------------- @@ -353,6 +389,22 @@ struct UnboxScalar> { } }; +template <> +struct UnboxScalar { + using T = Decimal32; + static const T& Unbox(const Scalar& val) { + return checked_cast(val).value; + } +}; + +template <> +struct UnboxScalar { + using T = Decimal64; + static const T& Unbox(const Scalar& val) { + return checked_cast(val).value; + } +}; + template <> struct UnboxScalar { using T = Decimal128; @@ -1116,6 +1168,10 @@ ArrayKernelExec GeneratePhysicalNumeric(detail::GetTypeId get_id) { template