Skip to content

Commit

Permalink
GH-44393: [C++][Compute] Vector selection functions `inverse_permutat…
Browse files Browse the repository at this point in the history
…ion` and `scatter` (#44394)

### Rationale for this change

For background please see #44393.

When implementing the "scatter" function requested in #44393, I found it also useful to make it a public vector API. After a painful thinking, I decided to name it "permute". And when implementing permute, I found it fairly easy to implement it by first computing the "reverse indices" of the positions, and then invoking the existing "take", where I think "reverse_indices"  itself can also be a useful public vector API. Thus the PR categorized them as "placement functions".

### What changes are included in this PR?

Implement vector selection API `inverse_permutation` and `scatter`, where `scatter(values, indices)` is implemented as `take(values, inverse_permutation(indices))`.

### Are these changes tested?

UT included.

### Are there any user-facing changes?

Yes, new public APIs added. Documents updated.

* GitHub Issue: #44393

Lead-authored-by: Ruoxi Sun <[email protected]>
Co-authored-by: Rossi Sun <[email protected]>
Co-authored-by: Antoine Pitrou <[email protected]>
Signed-off-by: Antoine Pitrou <[email protected]>
  • Loading branch information
zanmato1984 and pitrou authored Jan 15, 2025
1 parent d7dc586 commit 3222e2a
Show file tree
Hide file tree
Showing 11 changed files with 1,343 additions and 16 deletions.
3 changes: 2 additions & 1 deletion cpp/src/arrow/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -771,13 +771,14 @@ if(ARROW_COMPUTE)
compute/kernels/scalar_validity.cc
compute/kernels/vector_array_sort.cc
compute/kernels/vector_cumulative_ops.cc
compute/kernels/vector_pairwise.cc
compute/kernels/vector_nested.cc
compute/kernels/vector_pairwise.cc
compute/kernels/vector_rank.cc
compute/kernels/vector_replace.cc
compute/kernels/vector_run_end_encode.cc
compute/kernels/vector_select_k.cc
compute/kernels/vector_sort.cc
compute/kernels/vector_swizzle.cc
compute/key_hash_internal.cc
compute/key_map_internal.cc
compute/light_array_internal.cc
Expand Down
33 changes: 33 additions & 0 deletions cpp/src/arrow/compute/api_vector.cc
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,12 @@ static auto kPairwiseOptionsType = GetFunctionOptionsType<PairwiseOptions>(
DataMember("periods", &PairwiseOptions::periods));
static auto kListFlattenOptionsType = GetFunctionOptionsType<ListFlattenOptions>(
DataMember("recursive", &ListFlattenOptions::recursive));
static auto kInversePermutationOptionsType =
GetFunctionOptionsType<InversePermutationOptions>(
DataMember("max_index", &InversePermutationOptions::max_index),
DataMember("output_type", &InversePermutationOptions::output_type));
static auto kScatterOptionsType = GetFunctionOptionsType<ScatterOptions>(
DataMember("max_index", &ScatterOptions::max_index));
} // namespace
} // namespace internal

Expand Down Expand Up @@ -230,6 +236,17 @@ ListFlattenOptions::ListFlattenOptions(bool recursive)
: FunctionOptions(internal::kListFlattenOptionsType), recursive(recursive) {}
constexpr char ListFlattenOptions::kTypeName[];

InversePermutationOptions::InversePermutationOptions(
int64_t max_index, std::shared_ptr<DataType> output_type)
: FunctionOptions(internal::kInversePermutationOptionsType),
max_index(max_index),
output_type(std::move(output_type)) {}
constexpr char InversePermutationOptions::kTypeName[];

ScatterOptions::ScatterOptions(int64_t max_index)
: FunctionOptions(internal::kScatterOptionsType), max_index(max_index) {}
constexpr char ScatterOptions::kTypeName[];

namespace internal {
void RegisterVectorOptions(FunctionRegistry* registry) {
DCHECK_OK(registry->AddFunctionOptionsType(kFilterOptionsType));
Expand All @@ -244,6 +261,8 @@ void RegisterVectorOptions(FunctionRegistry* registry) {
DCHECK_OK(registry->AddFunctionOptionsType(kRankOptionsType));
DCHECK_OK(registry->AddFunctionOptionsType(kPairwiseOptionsType));
DCHECK_OK(registry->AddFunctionOptionsType(kListFlattenOptionsType));
DCHECK_OK(registry->AddFunctionOptionsType(kInversePermutationOptionsType));
DCHECK_OK(registry->AddFunctionOptionsType(kScatterOptionsType));
}
} // namespace internal

Expand Down Expand Up @@ -429,5 +448,19 @@ Result<Datum> CumulativeMean(const Datum& values, const CumulativeOptions& optio
return CallFunction("cumulative_mean", {Datum(values)}, &options, ctx);
}

// ----------------------------------------------------------------------
// Swizzle functions

Result<Datum> InversePermutation(const Datum& indices,
const InversePermutationOptions& options,
ExecContext* ctx) {
return CallFunction("inverse_permutation", {indices}, &options, ctx);
}

Result<Datum> Scatter(const Datum& values, const Datum& indices,
const ScatterOptions& options, ExecContext* ctx) {
return CallFunction("scatter", {values, indices}, &options, ctx);
}

} // namespace compute
} // namespace arrow
87 changes: 87 additions & 0 deletions cpp/src/arrow/compute/api_vector.h
Original file line number Diff line number Diff line change
Expand Up @@ -257,6 +257,40 @@ class ARROW_EXPORT ListFlattenOptions : public FunctionOptions {
bool recursive = false;
};

/// \brief Options for inverse_permutation function
class ARROW_EXPORT InversePermutationOptions : public FunctionOptions {
public:
explicit InversePermutationOptions(int64_t max_index = -1,
std::shared_ptr<DataType> output_type = NULLPTR);
static constexpr char const kTypeName[] = "InversePermutationOptions";
static InversePermutationOptions Defaults() { return InversePermutationOptions(); }

/// \brief The max value in the input indices to allow. The length of the function's
/// output will be this value plus 1. If negative, this value will be set to the length
/// of the input indices minus 1 and the length of the function's output will be the
/// length of the input indices.
int64_t max_index = -1;
/// \brief The type of the output inverse permutation. If null, the output will be of
/// the same type as the input indices, otherwise must be signed integer type. An
/// invalid error will be reported if this type is not able to store the length of the
/// input indices.
std::shared_ptr<DataType> output_type = NULLPTR;
};

/// \brief Options for scatter function
class ARROW_EXPORT ScatterOptions : public FunctionOptions {
public:
explicit ScatterOptions(int64_t max_index = -1);
static constexpr char const kTypeName[] = "ScatterOptions";
static ScatterOptions Defaults() { return ScatterOptions(); }

/// \brief The max value in the input indices to allow. The length of the function's
/// output will be this value plus 1. If negative, this value will be set to the length
/// of the input indices minus 1 and the length of the function's output will be the
/// length of the input indices.
int64_t max_index = -1;
};

/// @}

/// \brief Filter with a boolean selection filter
Expand Down Expand Up @@ -705,5 +739,58 @@ Result<std::shared_ptr<Array>> PairwiseDiff(const Array& array,
bool check_overflow = false,
ExecContext* ctx = NULLPTR);

/// \brief Return the inverse permutation of the given indices.
///
/// For indices[i] = x, inverse_permutation[x] = i. And inverse_permutation[x] = null if x
/// does not appear in the input indices. Indices must be in the range of [0, max_index],
/// or null, which will be ignored. If multiple indices point to the same value, the last
/// one is used.
///
/// For example, with
/// indices = [null, 0, null, 2, 4, 1, 1]
/// the inverse permutation is
/// [1, 6, 3, null, 4, null, null]
/// if max_index = 6.
///
/// \param[in] indices array-like indices
/// \param[in] options configures the max index and the output type
/// \param[in] ctx the function execution context, optional
/// \return the resulting inverse permutation
///
/// \since 20.0.0
/// \note API not yet finalized
ARROW_EXPORT
Result<Datum> InversePermutation(
const Datum& indices,
const InversePermutationOptions& options = InversePermutationOptions::Defaults(),
ExecContext* ctx = NULLPTR);

/// \brief Scatter the values into specified positions according to the indices.
///
/// For indices[i] = x, output[x] = values[i]. And output[x] = null if x does not appear
/// in the input indices. Indices must be in the range of [0, max_index], or null, in
/// which case the corresponding value will be ignored. If multiple indices point to the
/// same value, the last one is used.
///
/// For example, with
/// values = [a, b, c, d, e, f, g]
/// indices = [null, 0, null, 2, 4, 1, 1]
/// the output is
/// [b, g, d, null, e, null, null]
/// if max_index = 6.
///
/// \param[in] values datum to scatter
/// \param[in] indices array-like indices
/// \param[in] options configures the max index of to scatter
/// \param[in] ctx the function execution context, optional
/// \return the resulting datum
///
/// \since 20.0.0
/// \note API not yet finalized
ARROW_EXPORT
Result<Datum> Scatter(const Datum& values, const Datum& indices,
const ScatterOptions& options = ScatterOptions::Defaults(),
ExecContext* ctx = NULLPTR);

} // namespace compute
} // namespace arrow
4 changes: 4 additions & 0 deletions cpp/src/arrow/compute/function_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,10 @@ TEST(FunctionOptions, Equality) {
options.emplace_back(new SelectKOptions(5, {{SortKey("key", SortOrder::Ascending)}}));
options.emplace_back(new Utf8NormalizeOptions());
options.emplace_back(new Utf8NormalizeOptions(Utf8NormalizeOptions::NFD));
options.emplace_back(
new InversePermutationOptions(/*max_index=*/42, /*output_type=*/int32()));
options.emplace_back(new ScatterOptions());
options.emplace_back(new ScatterOptions(/*max_index=*/42));

for (size_t i = 0; i < options.size(); i++) {
const size_t prev_i = i == 0 ? options.size() - 1 : i - 1;
Expand Down
6 changes: 6 additions & 0 deletions cpp/src/arrow/compute/kernels/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,12 @@ add_arrow_compute_test(vector_selection_test
EXTRA_LINK_LIBS
arrow_compute_kernels_testing)

add_arrow_compute_test(vector_swizzle_test
SOURCES
vector_swizzle_test.cc
EXTRA_LINK_LIBS
arrow_compute_kernels_testing)

add_arrow_benchmark(vector_hash_benchmark PREFIX "arrow-compute")
add_arrow_benchmark(vector_sort_benchmark PREFIX "arrow-compute")
add_arrow_benchmark(vector_partition_benchmark PREFIX "arrow-compute")
Expand Down
5 changes: 3 additions & 2 deletions cpp/src/arrow/compute/kernels/codegen_internal.h
Original file line number Diff line number Diff line change
Expand Up @@ -1037,8 +1037,9 @@ ArrayKernelExec GenerateFloatingPoint(detail::GetTypeId get_id) {
// Generate a kernel given a templated functor for integer types
//
// See "Numeric" above for description of the generator functor
template <template <typename...> class Generator, typename Type0, typename... Args>
ArrayKernelExec GenerateInteger(detail::GetTypeId get_id) {
template <template <typename...> class Generator, typename Type0,
typename KernelType = ArrayKernelExec, typename... Args>
KernelType GenerateInteger(detail::GetTypeId get_id) {
switch (get_id.id) {
case Type::INT8:
return Generator<Type0, Int8Type, Args...>::Exec;
Expand Down
Loading

0 comments on commit 3222e2a

Please sign in to comment.