GH-44393: [C++][Compute] Vector selection functions inverse_permutation and scatter #44394

Merged on Jan 15, 2025 (61 commits)
8de08ae
WIP
zanmato1984 Sep 29, 2024
d36330e
WIP
zanmato1984 Sep 30, 2024
e68a6d4
Add permute function options
zanmato1984 Oct 2, 2024
41bc1cd
WIP
zanmato1984 Oct 4, 2024
49f835f
Implementation done and basic tests
zanmato1984 Oct 6, 2024
b928c07
Implement permute
zanmato1984 Oct 7, 2024
157bc1d
Fix API and doc
zanmato1984 Oct 10, 2024
fef6355
Fix API and doc
zanmato1984 Oct 10, 2024
9fe37a4
Init docs
zanmato1984 Oct 10, 2024
604fa69
Reorg reverse_index
zanmato1984 Oct 10, 2024
02914f4
Refine
zanmato1984 Oct 10, 2024
5aab878
Update docs
zanmato1984 Oct 10, 2024
a21b237
Refine doc
zanmato1984 Oct 11, 2024
bdfcf55
Add comments for the implementation
zanmato1984 Oct 11, 2024
797bd94
Refine docs
zanmato1984 Oct 11, 2024
8946cda
Fix uint64 overflow check
zanmato1984 Oct 11, 2024
ee671f2
Reverse indices tests
zanmato1984 Oct 11, 2024
b89c4bb
Forbit non-array-like argument
zanmato1984 Oct 11, 2024
4a43367
Fix permute option default
zanmato1984 Oct 11, 2024
1cc1355
Refine
zanmato1984 Oct 11, 2024
04832af
WIP permute tests
zanmato1984 Oct 11, 2024
774ccfb
Refine tests
zanmato1984 Oct 12, 2024
afd2787
More permute tests
zanmato1984 Oct 12, 2024
ea0b0da
Add if-else tests using permute
zanmato1984 Oct 13, 2024
338687b
Update some comments
zanmato1984 Oct 13, 2024
5519cd1
Fix lint
zanmato1984 Oct 14, 2024
bd5ec35
Update comment
zanmato1984 Oct 14, 2024
776fe89
Fix typo
zanmato1984 Oct 14, 2024
d7a2f61
Typo
zanmato1984 Oct 14, 2024
a892bab
Refine
zanmato1984 Oct 17, 2024
3df345e
Update cpp/src/arrow/compute/kernels/vector_placement_test.cc
zanmato1984 Oct 31, 2024
7108c12
Rename function category to swizzle
zanmato1984 Nov 4, 2024
cab0d4c
reverse_indices -> inverse_permutation
zanmato1984 Nov 4, 2024
2e6cb70
output_length -> max_index
zanmato1984 Nov 4, 2024
e35af38
Permute -> Scatter
zanmato1984 Nov 4, 2024
5c22d7c
Fixing some renamings
zanmato1984 Nov 6, 2024
ab010f0
Update docs/source/cpp/compute.rst
zanmato1984 Dec 11, 2024
84049ac
Update cpp/src/arrow/compute/api_vector.h
zanmato1984 Dec 11, 2024
85ed9dc
Update cpp/src/arrow/compute/api_vector.h
zanmato1984 Dec 11, 2024
fd03ad7
Update cpp/src/arrow/compute/kernels/vector_swizzle_test.cc
zanmato1984 Dec 11, 2024
d9b10cc
Update cpp/src/arrow/compute/kernels/vector_swizzle_test.cc
zanmato1984 Dec 11, 2024
ffd2f36
Limit input/output type to signed integers
zanmato1984 Dec 11, 2024
acb4fce
Make visit method public and remove friend
zanmato1984 Dec 11, 2024
fa1d9f2
Show no mercy to index out of bounds
zanmato1984 Dec 12, 2024
4527c47
Use type error instead of invalid
zanmato1984 Dec 12, 2024
5419a85
Remove errornous predict false
zanmato1984 Dec 12, 2024
a3bd7c3
Avoid uninitialized data buf
zanmato1984 Dec 12, 2024
962749b
Coding convention of instantce variables
zanmato1984 Dec 12, 2024
ff71d77
Optimize buffer initializing
zanmato1984 Dec 12, 2024
6805784
Reduce typed tests
zanmato1984 Dec 12, 2024
e30f33c
Naming
zanmato1984 Dec 12, 2024
28bfddc
Remove repetition of test cases
zanmato1984 Dec 12, 2024
75d96d6
Doc about output length
zanmato1984 Dec 12, 2024
0d44639
Fix ci error
zanmato1984 Dec 12, 2024
953d3b1
Move new functions into selection category in doc
zanmato1984 Jan 15, 2025
b82ad94
Allocate uninitialized buffer and fill the capacity bytes
zanmato1984 Jan 15, 2025
eaa9a3c
type -> value_type
zanmato1984 Jan 15, 2025
b1f1208
Simplify chunked cases
zanmato1984 Jan 15, 2025
4342033
Use common type lists and test for more numeric types
zanmato1984 Jan 15, 2025
8e4ecab
Bump `since` version
zanmato1984 Jan 15, 2025
6ac321c
Refine function docs
zanmato1984 Jan 15, 2025
3 changes: 2 additions & 1 deletion cpp/src/arrow/CMakeLists.txt
Expand Up @@ -771,13 +771,14 @@ if(ARROW_COMPUTE)
 compute/kernels/scalar_validity.cc
 compute/kernels/vector_array_sort.cc
 compute/kernels/vector_cumulative_ops.cc
-compute/kernels/vector_pairwise.cc
 compute/kernels/vector_nested.cc
+compute/kernels/vector_pairwise.cc
 compute/kernels/vector_rank.cc
 compute/kernels/vector_replace.cc
 compute/kernels/vector_run_end_encode.cc
 compute/kernels/vector_select_k.cc
 compute/kernels/vector_sort.cc
+compute/kernels/vector_swizzle.cc
 compute/key_hash_internal.cc
 compute/key_map_internal.cc
 compute/light_array_internal.cc
Expand Down
33 changes: 33 additions & 0 deletions cpp/src/arrow/compute/api_vector.cc
Expand Up @@ -155,6 +155,12 @@ static auto kPairwiseOptionsType = GetFunctionOptionsType<PairwiseOptions>(
DataMember("periods", &PairwiseOptions::periods));
static auto kListFlattenOptionsType = GetFunctionOptionsType<ListFlattenOptions>(
DataMember("recursive", &ListFlattenOptions::recursive));
static auto kInversePermutationOptionsType =
GetFunctionOptionsType<InversePermutationOptions>(
DataMember("max_index", &InversePermutationOptions::max_index),
DataMember("output_type", &InversePermutationOptions::output_type));
static auto kScatterOptionsType = GetFunctionOptionsType<ScatterOptions>(
DataMember("max_index", &ScatterOptions::max_index));
} // namespace
} // namespace internal

Expand Down Expand Up @@ -230,6 +236,17 @@ ListFlattenOptions::ListFlattenOptions(bool recursive)
: FunctionOptions(internal::kListFlattenOptionsType), recursive(recursive) {}
constexpr char ListFlattenOptions::kTypeName[];

InversePermutationOptions::InversePermutationOptions(
int64_t max_index, std::shared_ptr<DataType> output_type)
: FunctionOptions(internal::kInversePermutationOptionsType),
max_index(max_index),
output_type(std::move(output_type)) {}
constexpr char InversePermutationOptions::kTypeName[];

ScatterOptions::ScatterOptions(int64_t max_index)
: FunctionOptions(internal::kScatterOptionsType), max_index(max_index) {}
constexpr char ScatterOptions::kTypeName[];

namespace internal {
void RegisterVectorOptions(FunctionRegistry* registry) {
DCHECK_OK(registry->AddFunctionOptionsType(kFilterOptionsType));
Expand All @@ -244,6 +261,8 @@ void RegisterVectorOptions(FunctionRegistry* registry) {
DCHECK_OK(registry->AddFunctionOptionsType(kRankOptionsType));
DCHECK_OK(registry->AddFunctionOptionsType(kPairwiseOptionsType));
DCHECK_OK(registry->AddFunctionOptionsType(kListFlattenOptionsType));
DCHECK_OK(registry->AddFunctionOptionsType(kInversePermutationOptionsType));
DCHECK_OK(registry->AddFunctionOptionsType(kScatterOptionsType));
}
} // namespace internal

Expand Down Expand Up @@ -429,5 +448,19 @@ Result<Datum> CumulativeMean(const Datum& values, const CumulativeOptions& optio
return CallFunction("cumulative_mean", {Datum(values)}, &options, ctx);
}

// ----------------------------------------------------------------------
// Swizzle functions

Result<Datum> InversePermutation(const Datum& indices,
const InversePermutationOptions& options,
ExecContext* ctx) {
return CallFunction("inverse_permutation", {indices}, &options, ctx);
}

Result<Datum> Scatter(const Datum& values, const Datum& indices,
const ScatterOptions& options, ExecContext* ctx) {
return CallFunction("scatter", {values, indices}, &options, ctx);
}

} // namespace compute
} // namespace arrow
87 changes: 87 additions & 0 deletions cpp/src/arrow/compute/api_vector.h
Expand Up @@ -257,6 +257,40 @@ class ARROW_EXPORT ListFlattenOptions : public FunctionOptions {
bool recursive = false;
};

/// \brief Options for inverse_permutation function
class ARROW_EXPORT InversePermutationOptions : public FunctionOptions {
public:
explicit InversePermutationOptions(int64_t max_index = -1,
std::shared_ptr<DataType> output_type = NULLPTR);
static constexpr char const kTypeName[] = "InversePermutationOptions";
static InversePermutationOptions Defaults() { return InversePermutationOptions(); }

/// \brief The max value in the input indices to allow. The length of the function's
/// output will be this value plus 1. If negative, this value will be set to the length
/// of the input indices minus 1 and the length of the function's output will be the
/// length of the input indices.
int64_t max_index = -1;
/// \brief The type of the output inverse permutation. If null, the output will be of
/// the same type as the input indices, otherwise must be signed integer type. An
/// invalid error will be reported if this type is not able to store the length of the
/// input indices.
std::shared_ptr<DataType> output_type = NULLPTR;
};

/// \brief Options for scatter function
class ARROW_EXPORT ScatterOptions : public FunctionOptions {
public:
explicit ScatterOptions(int64_t max_index = -1);
static constexpr char const kTypeName[] = "ScatterOptions";
static ScatterOptions Defaults() { return ScatterOptions(); }

/// \brief The max value in the input indices to allow. The length of the function's
/// output will be this value plus 1. If negative, this value will be set to the length
/// of the input indices minus 1 and the length of the function's output will be the
/// length of the input indices.
int64_t max_index = -1;
};
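
The `max_index` doc comment in both option classes describes one rule for the output length. As a minimal sketch of that rule only, assuming a hypothetical helper name `ResolveOutputLength` that does not exist in Arrow:

```cpp
#include <cstdint>

// Hypothetical helper, not part of Arrow: models how the documented
// max_index option maps to the output length.
inline std::int64_t ResolveOutputLength(std::int64_t max_index,
                                        std::int64_t indices_length) {
  // A negative max_index is treated as indices_length - 1, so the output
  // is as long as the input indices. Otherwise the output has
  // max_index + 1 slots.
  return max_index < 0 ? indices_length : max_index + 1;
}
```

With the default `max_index = -1` and 7 input indices the output has 7 slots; with `max_index = 9` it has 10, regardless of the input length.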

/// @}

/// \brief Filter with a boolean selection filter
Expand Down Expand Up @@ -705,5 +739,58 @@ Result<std::shared_ptr<Array>> PairwiseDiff(const Array& array,
bool check_overflow = false,
ExecContext* ctx = NULLPTR);

/// \brief Return the inverse permutation of the given indices.
///
/// For indices[i] = x, inverse_permutation[x] = i. And inverse_permutation[x] = null if x
/// does not appear in the input indices. Indices must be in the range of [0, max_index],
/// or null, which will be ignored. If multiple indices point to the same value, the last
/// one is used.
///
/// For example, with
/// indices = [null, 0, null, 2, 4, 1, 1]
/// the inverse permutation is
/// [1, 6, 3, null, 4, null, null]
/// if max_index = 6.
///
/// \param[in] indices array-like indices
/// \param[in] options configures the max index and the output type
/// \param[in] ctx the function execution context, optional
/// \return the resulting inverse permutation
///
/// \since 20.0.0
/// \note API not yet finalized
ARROW_EXPORT
Result<Datum> InversePermutation(
const Datum& indices,
const InversePermutationOptions& options = InversePermutationOptions::Defaults(),
ExecContext* ctx = NULLPTR);
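
The semantics documented for `InversePermutation` can be modelled without Arrow. The following is a rough sketch, not the kernel from this PR: `InversePermutationSketch` and `OptIndex` are hypothetical names, `std::nullopt` stands in for an Arrow null, and throwing on an out-of-range index is an assumption about error handling.

```cpp
#include <cstddef>
#include <cstdint>
#include <optional>
#include <stdexcept>
#include <vector>

// std::nullopt stands in for an Arrow null in this model.
using OptIndex = std::optional<std::int64_t>;

// Hypothetical model of the documented behaviour; not the Arrow kernel.
std::vector<OptIndex> InversePermutationSketch(const std::vector<OptIndex>& indices,
                                               std::int64_t max_index = -1) {
  const std::int64_t n = static_cast<std::int64_t>(indices.size());
  if (max_index < 0) max_index = n - 1;  // default: output as long as input
  // Every slot starts as null and stays null unless some index points at it.
  std::vector<OptIndex> out(static_cast<std::size_t>(max_index + 1));
  for (std::int64_t i = 0; i < n; ++i) {
    const OptIndex& x = indices[static_cast<std::size_t>(i)];
    if (!x.has_value()) continue;  // null indices are ignored
    if (*x < 0 || *x > max_index) {
      throw std::out_of_range("index out of bounds");  // assumed error handling
    }
    out[static_cast<std::size_t>(*x)] = i;  // a later duplicate overwrites an earlier one
  }
  return out;
}
```

Feeding it the indices from the doc comment, `[null, 0, null, 2, 4, 1, 1]` with `max_index = 6`, yields `[1, 6, 3, null, 4, null, null]`.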

/// \brief Scatter the values into specified positions according to the indices.
///
/// For indices[i] = x, output[x] = values[i]. And output[x] = null if x does not appear
/// in the input indices. Indices must be in the range of [0, max_index], or null, in
/// which case the corresponding value will be ignored. If multiple indices point to the
/// same value, the last one is used.
///
/// For example, with
/// values = [a, b, c, d, e, f, g]
/// indices = [null, 0, null, 2, 4, 1, 1]
/// the output is
/// [b, g, d, null, e, null, null]
/// if max_index = 6.
///
/// \param[in] values datum to scatter
/// \param[in] indices array-like indices
/// \param[in] options configures the max index to scatter to
/// \param[in] ctx the function execution context, optional
/// \return the resulting datum
///
/// \since 20.0.0
/// \note API not yet finalized
ARROW_EXPORT
Result<Datum> Scatter(const Datum& values, const Datum& indices,
const ScatterOptions& options = ScatterOptions::Defaults(),
ExecContext* ctx = NULLPTR);
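
The `Scatter` rule `output[x] = values[i]` can be illustrated in the same Arrow-free way. This is a sketch under assumptions: `ScatterSketch` is a hypothetical name, `std::nullopt` stands in for null, and both the length check and the out-of-range exception are assumed rather than taken from the kernel.

```cpp
#include <cstddef>
#include <cstdint>
#include <optional>
#include <stdexcept>
#include <vector>

// Hypothetical model of the documented behaviour; not the Arrow kernel.
template <typename T>
std::vector<std::optional<T>> ScatterSketch(
    const std::vector<T>& values,
    const std::vector<std::optional<std::int64_t>>& indices,
    std::int64_t max_index = -1) {
  // Assumption: values and indices must have the same length.
  if (values.size() != indices.size()) {
    throw std::invalid_argument("values and indices must have equal length");
  }
  const std::int64_t n = static_cast<std::int64_t>(indices.size());
  if (max_index < 0) max_index = n - 1;  // default: output as long as input
  // Slots that no index points at remain null.
  std::vector<std::optional<T>> out(static_cast<std::size_t>(max_index + 1));
  for (std::int64_t i = 0; i < n; ++i) {
    const auto& x = indices[static_cast<std::size_t>(i)];
    if (!x.has_value()) continue;  // a null index drops the corresponding value
    if (*x < 0 || *x > max_index) {
      throw std::out_of_range("index out of bounds");  // assumed error handling
    }
    // A later duplicate index overwrites the earlier value.
    out[static_cast<std::size_t>(*x)] = values[static_cast<std::size_t>(i)];
  }
  return out;
}
```

With `values = [a, b, c, d, e, f, g]`, the indices from the doc comment and `max_index = 6`, this produces `[b, g, d, null, e, null, null]`.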

} // namespace compute
} // namespace arrow
4 changes: 4 additions & 0 deletions cpp/src/arrow/compute/function_test.cc
Expand Up @@ -136,6 +136,10 @@ TEST(FunctionOptions, Equality) {
options.emplace_back(new SelectKOptions(5, {{SortKey("key", SortOrder::Ascending)}}));
options.emplace_back(new Utf8NormalizeOptions());
options.emplace_back(new Utf8NormalizeOptions(Utf8NormalizeOptions::NFD));
options.emplace_back(
new InversePermutationOptions(/*max_index=*/42, /*output_type=*/int32()));
options.emplace_back(new ScatterOptions());
options.emplace_back(new ScatterOptions(/*max_index=*/42));

for (size_t i = 0; i < options.size(); i++) {
const size_t prev_i = i == 0 ? options.size() - 1 : i - 1;
Expand Down
6 changes: 6 additions & 0 deletions cpp/src/arrow/compute/kernels/CMakeLists.txt
Expand Up @@ -115,6 +115,12 @@ add_arrow_compute_test(vector_selection_test
EXTRA_LINK_LIBS
arrow_compute_kernels_testing)

add_arrow_compute_test(vector_swizzle_test
SOURCES
vector_swizzle_test.cc
EXTRA_LINK_LIBS
arrow_compute_kernels_testing)

add_arrow_benchmark(vector_hash_benchmark PREFIX "arrow-compute")
add_arrow_benchmark(vector_sort_benchmark PREFIX "arrow-compute")
add_arrow_benchmark(vector_partition_benchmark PREFIX "arrow-compute")
Expand Down
5 changes: 3 additions & 2 deletions cpp/src/arrow/compute/kernels/codegen_internal.h
Expand Up @@ -1037,8 +1037,9 @@ ArrayKernelExec GenerateFloatingPoint(detail::GetTypeId get_id) {
 // Generate a kernel given a templated functor for integer types
 //
 // See "Numeric" above for description of the generator functor
-template <template <typename...> class Generator, typename Type0, typename... Args>
-ArrayKernelExec GenerateInteger(detail::GetTypeId get_id) {
+template <template <typename...> class Generator, typename Type0,
+          typename KernelType = ArrayKernelExec, typename... Args>
+KernelType GenerateInteger(detail::GetTypeId get_id) {
   switch (get_id.id) {
     case Type::INT8:
       return Generator<Type0, Int8Type, Args...>::Exec;
Expand Down
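
The `GenerateInteger` change adds a `KernelType` template parameter that defaults to `ArrayKernelExec`, presumably so the same dispatcher can return other kernel function types while existing callers compile unchanged. A self-contained sketch of that pattern follows. Every name in it (`TypeId`, `ArrayExec`, `ChunkedExec`, `WidthOf`, `WidthTimes`, `GenerateIntegerSketch`) is a hypothetical stand-in, not Arrow's.

```cpp
#include <cstddef>
#include <cstdint>

// Hypothetical stand-ins for Arrow's type ids and kernel signatures.
enum class TypeId { kInt8, kInt32, kInt64 };
using ArrayExec = std::size_t (*)();
using ChunkedExec = std::size_t (*)(int);

// Generator whose Exec matches the default kernel signature.
template <typename Out, typename Index>
struct WidthOf {
  static std::size_t Exec() { return sizeof(Index); }
};

// Generator whose Exec matches the alternative kernel signature.
template <typename Out, typename Index>
struct WidthTimes {
  static std::size_t Exec(int n) { return sizeof(Index) * static_cast<std::size_t>(n); }
};

// The return type is a template parameter with a default, so callers that
// do not name it keep getting the original function pointer type.
template <template <typename...> class Generator, typename Type0,
          typename KernelType = ArrayExec, typename... Args>
KernelType GenerateIntegerSketch(TypeId id) {
  switch (id) {
    case TypeId::kInt8:
      return Generator<Type0, std::int8_t, Args...>::Exec;
    case TypeId::kInt32:
      return Generator<Type0, std::int32_t, Args...>::Exec;
    case TypeId::kInt64:
      return Generator<Type0, std::int64_t, Args...>::Exec;
  }
  return nullptr;  // unreachable for the ids above
}
```

Calling `GenerateIntegerSketch<WidthOf, void>(id)` uses the default and returns an `ArrayExec`; naming the third argument, as in `GenerateIntegerSketch<WidthTimes, void, ChunkedExec>(id)`, selects the other signature.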