From 3222e2a252bb421c1c0b870227f58a939cef12b3 Mon Sep 17 00:00:00 2001 From: Rossi Sun Date: Thu, 16 Jan 2025 00:37:06 +0800 Subject: [PATCH] GH-44393: [C++][Compute] Vector selection functions `inverse_permutation` and `scatter` (#44394) ### Rationale for this change For background please see #44393. When implementing the "scatter" function requested in #44393, I found it useful to also make it a public vector API. After much deliberation, I initially named it "permute"; the name used in this patch is `scatter`. While implementing it, I found it fairly easy to do so by first computing the "reverse indices" of the positions and then invoking the existing "take", and I think the reverse-indices computation can itself be a useful public vector API; it is exposed in this patch as `inverse_permutation` (initially "reverse_indices"). The PR initially categorized them as "placement functions"; in this patch they are grouped as "swizzle functions". ### What changes are included in this PR? Implement vector selection API `inverse_permutation` and `scatter`, where `scatter(values, indices)` is implemented as `take(values, inverse_permutation(indices))`. ### Are these changes tested? Unit tests are included. ### Are there any user-facing changes? Yes, new public APIs are added and the documentation is updated. 
* GitHub Issue: #44393 Lead-authored-by: Ruoxi Sun Co-authored-by: Rossi Sun Co-authored-by: Antoine Pitrou Signed-off-by: Antoine Pitrou --- cpp/src/arrow/CMakeLists.txt | 3 +- cpp/src/arrow/compute/api_vector.cc | 33 + cpp/src/arrow/compute/api_vector.h | 87 ++ cpp/src/arrow/compute/function_test.cc | 4 + cpp/src/arrow/compute/kernels/CMakeLists.txt | 6 + .../arrow/compute/kernels/codegen_internal.h | 5 +- .../arrow/compute/kernels/vector_swizzle.cc | 421 ++++++++++ .../compute/kernels/vector_swizzle_test.cc | 756 ++++++++++++++++++ cpp/src/arrow/compute/registry.cc | 1 + cpp/src/arrow/compute/registry_internal.h | 1 + docs/source/cpp/compute.rst | 42 +- 11 files changed, 1343 insertions(+), 16 deletions(-) create mode 100644 cpp/src/arrow/compute/kernels/vector_swizzle.cc create mode 100644 cpp/src/arrow/compute/kernels/vector_swizzle_test.cc diff --git a/cpp/src/arrow/CMakeLists.txt b/cpp/src/arrow/CMakeLists.txt index 6e2294371e7a6..eb9860b240f16 100644 --- a/cpp/src/arrow/CMakeLists.txt +++ b/cpp/src/arrow/CMakeLists.txt @@ -771,13 +771,14 @@ if(ARROW_COMPUTE) compute/kernels/scalar_validity.cc compute/kernels/vector_array_sort.cc compute/kernels/vector_cumulative_ops.cc - compute/kernels/vector_pairwise.cc compute/kernels/vector_nested.cc + compute/kernels/vector_pairwise.cc compute/kernels/vector_rank.cc compute/kernels/vector_replace.cc compute/kernels/vector_run_end_encode.cc compute/kernels/vector_select_k.cc compute/kernels/vector_sort.cc + compute/kernels/vector_swizzle.cc compute/key_hash_internal.cc compute/key_map_internal.cc compute/light_array_internal.cc diff --git a/cpp/src/arrow/compute/api_vector.cc b/cpp/src/arrow/compute/api_vector.cc index f0d5c0fcc3d72..22ecf1cc87844 100644 --- a/cpp/src/arrow/compute/api_vector.cc +++ b/cpp/src/arrow/compute/api_vector.cc @@ -155,6 +155,12 @@ static auto kPairwiseOptionsType = GetFunctionOptionsType( DataMember("periods", &PairwiseOptions::periods)); static auto kListFlattenOptionsType = 
GetFunctionOptionsType( DataMember("recursive", &ListFlattenOptions::recursive)); +static auto kInversePermutationOptionsType = + GetFunctionOptionsType( + DataMember("max_index", &InversePermutationOptions::max_index), + DataMember("output_type", &InversePermutationOptions::output_type)); +static auto kScatterOptionsType = GetFunctionOptionsType( + DataMember("max_index", &ScatterOptions::max_index)); } // namespace } // namespace internal @@ -230,6 +236,17 @@ ListFlattenOptions::ListFlattenOptions(bool recursive) : FunctionOptions(internal::kListFlattenOptionsType), recursive(recursive) {} constexpr char ListFlattenOptions::kTypeName[]; +InversePermutationOptions::InversePermutationOptions( + int64_t max_index, std::shared_ptr output_type) + : FunctionOptions(internal::kInversePermutationOptionsType), + max_index(max_index), + output_type(std::move(output_type)) {} +constexpr char InversePermutationOptions::kTypeName[]; + +ScatterOptions::ScatterOptions(int64_t max_index) + : FunctionOptions(internal::kScatterOptionsType), max_index(max_index) {} +constexpr char ScatterOptions::kTypeName[]; + namespace internal { void RegisterVectorOptions(FunctionRegistry* registry) { DCHECK_OK(registry->AddFunctionOptionsType(kFilterOptionsType)); @@ -244,6 +261,8 @@ void RegisterVectorOptions(FunctionRegistry* registry) { DCHECK_OK(registry->AddFunctionOptionsType(kRankOptionsType)); DCHECK_OK(registry->AddFunctionOptionsType(kPairwiseOptionsType)); DCHECK_OK(registry->AddFunctionOptionsType(kListFlattenOptionsType)); + DCHECK_OK(registry->AddFunctionOptionsType(kInversePermutationOptionsType)); + DCHECK_OK(registry->AddFunctionOptionsType(kScatterOptionsType)); } } // namespace internal @@ -429,5 +448,19 @@ Result CumulativeMean(const Datum& values, const CumulativeOptions& optio return CallFunction("cumulative_mean", {Datum(values)}, &options, ctx); } +// ---------------------------------------------------------------------- +// Swizzle functions + +Result 
InversePermutation(const Datum& indices, + const InversePermutationOptions& options, + ExecContext* ctx) { + return CallFunction("inverse_permutation", {indices}, &options, ctx); +} + +Result Scatter(const Datum& values, const Datum& indices, + const ScatterOptions& options, ExecContext* ctx) { + return CallFunction("scatter", {values, indices}, &options, ctx); +} + } // namespace compute } // namespace arrow diff --git a/cpp/src/arrow/compute/api_vector.h b/cpp/src/arrow/compute/api_vector.h index e5bcc37329661..ada1665b3ec7c 100644 --- a/cpp/src/arrow/compute/api_vector.h +++ b/cpp/src/arrow/compute/api_vector.h @@ -257,6 +257,40 @@ class ARROW_EXPORT ListFlattenOptions : public FunctionOptions { bool recursive = false; }; +/// \brief Options for inverse_permutation function +class ARROW_EXPORT InversePermutationOptions : public FunctionOptions { + public: + explicit InversePermutationOptions(int64_t max_index = -1, + std::shared_ptr output_type = NULLPTR); + static constexpr char const kTypeName[] = "InversePermutationOptions"; + static InversePermutationOptions Defaults() { return InversePermutationOptions(); } + + /// \brief The max value in the input indices to allow. The length of the function's + /// output will be this value plus 1. If negative, this value will be set to the length + /// of the input indices minus 1 and the length of the function's output will be the + /// length of the input indices. + int64_t max_index = -1; + /// \brief The type of the output inverse permutation. If null, the output will be of + /// the same type as the input indices, otherwise must be signed integer type. An + /// invalid error will be reported if this type is not able to store the length of the + /// input indices. 
+ std::shared_ptr output_type = NULLPTR; +}; + +/// \brief Options for scatter function +class ARROW_EXPORT ScatterOptions : public FunctionOptions { + public: + explicit ScatterOptions(int64_t max_index = -1); + static constexpr char const kTypeName[] = "ScatterOptions"; + static ScatterOptions Defaults() { return ScatterOptions(); } + + /// \brief The max value in the input indices to allow. The length of the function's + /// output will be this value plus 1. If negative, this value will be set to the length + /// of the input indices minus 1 and the length of the function's output will be the + /// length of the input indices. + int64_t max_index = -1; +}; + /// @} /// \brief Filter with a boolean selection filter @@ -705,5 +739,58 @@ Result> PairwiseDiff(const Array& array, bool check_overflow = false, ExecContext* ctx = NULLPTR); +/// \brief Return the inverse permutation of the given indices. +/// +/// For indices[i] = x, inverse_permutation[x] = i. And inverse_permutation[x] = null if x +/// does not appear in the input indices. Indices must be in the range of [0, max_index], +/// or null, which will be ignored. If multiple indices point to the same value, the last +/// one is used. +/// +/// For example, with +/// indices = [null, 0, null, 2, 4, 1, 1] +/// the inverse permutation is +/// [1, 6, 3, null, 4, null, null] +/// if max_index = 6. +/// +/// \param[in] indices array-like indices +/// \param[in] options configures the max index and the output type +/// \param[in] ctx the function execution context, optional +/// \return the resulting inverse permutation +/// +/// \since 20.0.0 +/// \note API not yet finalized +ARROW_EXPORT +Result InversePermutation( + const Datum& indices, + const InversePermutationOptions& options = InversePermutationOptions::Defaults(), + ExecContext* ctx = NULLPTR); + +/// \brief Scatter the values into specified positions according to the indices. +/// +/// For indices[i] = x, output[x] = values[i]. 
And output[x] = null if x does not appear +/// in the input indices. Indices must be in the range of [0, max_index], or null, in +/// which case the corresponding value will be ignored. If multiple indices point to the +/// same value, the last one is used. +/// +/// For example, with +/// values = [a, b, c, d, e, f, g] +/// indices = [null, 0, null, 2, 4, 1, 1] +/// the output is +/// [b, g, d, null, e, null, null] +/// if max_index = 6. +/// +/// \param[in] values datum to scatter +/// \param[in] indices array-like indices +/// \param[in] options configures the max index to scatter +/// \param[in] ctx the function execution context, optional +/// \return the resulting datum +/// +/// \since 20.0.0 +/// \note API not yet finalized +ARROW_EXPORT +Result Scatter(const Datum& values, const Datum& indices, + const ScatterOptions& options = ScatterOptions::Defaults(), + ExecContext* ctx = NULLPTR); + } // namespace compute } // namespace arrow diff --git a/cpp/src/arrow/compute/function_test.cc b/cpp/src/arrow/compute/function_test.cc index c269de0763217..b7d017d482013 100644 --- a/cpp/src/arrow/compute/function_test.cc +++ b/cpp/src/arrow/compute/function_test.cc @@ -136,6 +136,10 @@ TEST(FunctionOptions, Equality) { options.emplace_back(new SelectKOptions(5, {{SortKey("key", SortOrder::Ascending)}})); options.emplace_back(new Utf8NormalizeOptions()); options.emplace_back(new Utf8NormalizeOptions(Utf8NormalizeOptions::NFD)); + options.emplace_back( + new InversePermutationOptions(/*max_index=*/42, /*output_type=*/int32())); + options.emplace_back(new ScatterOptions()); + options.emplace_back(new ScatterOptions(/*max_index=*/42)); for (size_t i = 0; i < options.size(); i++) { const size_t prev_i = i == 0 ? 
options.size() - 1 : i - 1; diff --git a/cpp/src/arrow/compute/kernels/CMakeLists.txt b/cpp/src/arrow/compute/kernels/CMakeLists.txt index 7c7b9c8b68d45..84b508f5d9be4 100644 --- a/cpp/src/arrow/compute/kernels/CMakeLists.txt +++ b/cpp/src/arrow/compute/kernels/CMakeLists.txt @@ -115,6 +115,12 @@ add_arrow_compute_test(vector_selection_test EXTRA_LINK_LIBS arrow_compute_kernels_testing) +add_arrow_compute_test(vector_swizzle_test + SOURCES + vector_swizzle_test.cc + EXTRA_LINK_LIBS + arrow_compute_kernels_testing) + add_arrow_benchmark(vector_hash_benchmark PREFIX "arrow-compute") add_arrow_benchmark(vector_sort_benchmark PREFIX "arrow-compute") add_arrow_benchmark(vector_partition_benchmark PREFIX "arrow-compute") diff --git a/cpp/src/arrow/compute/kernels/codegen_internal.h b/cpp/src/arrow/compute/kernels/codegen_internal.h index 594bd1fce0b84..2a492f581f53b 100644 --- a/cpp/src/arrow/compute/kernels/codegen_internal.h +++ b/cpp/src/arrow/compute/kernels/codegen_internal.h @@ -1037,8 +1037,9 @@ ArrayKernelExec GenerateFloatingPoint(detail::GetTypeId get_id) { // Generate a kernel given a templated functor for integer types // // See "Numeric" above for description of the generator functor -template