Skip to content

Commit

Permalink
feat(expr): support array_position and array_replace for 1d scena…
Browse files Browse the repository at this point in the history
…rio (#10166)
  • Loading branch information
xiangjinwu authored Jun 6, 2023
1 parent ac2085d commit 572780b
Show file tree
Hide file tree
Showing 8 changed files with 266 additions and 23 deletions.
2 changes: 2 additions & 0 deletions proto/expr.proto
Original file line number Diff line number Diff line change
Expand Up @@ -179,6 +179,8 @@ message ExprNode {
ARRAY_POSITIONS = 539;
TRIM_ARRAY = 540;
STRING_TO_ARRAY = 541;
ARRAY_POSITION = 542;
ARRAY_REPLACE = 543;

// Int256 functions
HEX_TO_INT256 = 560;
Expand Down
115 changes: 115 additions & 0 deletions src/expr/src/vector_op/array_positions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,121 @@ use risingwave_expr_macro::function;
use crate::error::ExprError;
use crate::Result;

/// Returns the subscript of the first occurrence of the second argument in the array, or `NULL` if
/// it's not present.
///
/// Examples:
///
/// ```slt
/// query I
/// select array_position(array[1, null, 2, null], null);
/// ----
/// 2
///
/// query I
/// select array_position(array[3, 4, 5], 2);
/// ----
/// NULL
///
/// query I
/// select array_position(null, 4);
/// ----
/// NULL
///
/// query I
/// select array_position(null, null);
/// ----
/// NULL
///
/// query I
/// select array_position('{yes}', true);
/// ----
/// 1
///
/// # Like in PostgreSQL, searching `int` in multidimensional array is disallowed.
/// statement error
/// select array_position(array[array[1, 2], array[3, 4]], 1);
///
/// # Unlike in PostgreSQL, it is okay to search `int[]` inside `int[][]`.
/// query I
/// select array_position(array[array[1, 2], array[3, 4]], array[3, 4]);
/// ----
/// 2
///
/// statement error
/// select array_position(array[3, 4], true);
///
/// query I
/// select array_position(array[3, 4], 4.0);
/// ----
/// 2
/// ```
#[function("array_position(list, *) -> int32")]
fn array_position<'a, T: ScalarRef<'a>>(
    array: Option<ListRef<'_>>,
    element: Option<T>,
) -> Result<Option<i32>> {
    // The two-argument form has no start position, so the scan skips nothing
    // and begins at the first element of the array.
    let skip = 0;
    array_position_common(array, element, skip)
}

/// Returns the subscript of the first occurrence of the second argument in the array, or `NULL` if
/// it's not present. The search begins at the third argument.
///
/// Examples:
///
/// ```slt
/// statement error
/// select array_position(array[1, null, 2, null], null, false);
///
/// statement error
/// select array_position(array[1, null, 2, null], null, null::int);
///
/// query II
/// select v, array_position(array[1, null, 2, null], null, v) from generate_series(-1, 5) as t(v);
/// ----
/// -1 2
/// 0 2
/// 1 2
/// 2 2
/// 3 4
/// 4 4
/// 5 NULL
/// ```
#[function("array_position(list, *, int32) -> int32")]
fn array_position_start<'a, T: ScalarRef<'a>>(
    array: Option<ListRef<'_>>,
    element: Option<T>,
    start: Option<i32>,
) -> Result<Option<i32>> {
    // A NULL start position is an error rather than a NULL result.
    let Some(first) = start else {
        return Err(ExprError::InvalidParam {
            name: "start",
            reason: "initial position must not be null".into(),
        });
    };
    // `start` is a 1-based subscript. Values below 1 are clamped to 1, which
    // translates to skipping zero elements.
    let skip = (first.max(1) - 1) as usize;
    array_position_common(array, element, skip)
}

/// Shared implementation of `array_position` with and without a start position.
///
/// Ignores the first `skip` elements of `array`, then returns the 1-based subscript
/// (relative to the whole array) of the first element equal to `element`. A `NULL`
/// `element` matches `NULL` array elements. Returns `Ok(None)` when `array` is `NULL`
/// or no element matches, and an error when the array length does not fit in `i32`.
fn array_position_common<'a, T: ScalarRef<'a>>(
    array: Option<ListRef<'_>>,
    element: Option<T>,
    skip: usize,
) -> Result<Option<i32>> {
    let list = match array {
        Some(list) => list,
        None => return Ok(None),
    };
    // The subscript is cast to `i32` below, so reject lengths that cannot be
    // represented in that type up front.
    if i32::try_from(list.len()).is_err() {
        return Err(ExprError::CastOutOfRange("invalid array length"));
    }

    // Convert the searched value once instead of once per compared element.
    let needle = element.map(Into::into);
    let found = list.iter().skip(skip).position(|candidate| candidate == needle);
    // `position` counts from the first non-skipped element; add `skip` back and
    // shift to a 1-based subscript.
    Ok(found.map(|offset| (offset + 1 + skip) as _))
}

/// Returns an array of the subscripts of all occurrences of the second argument in the array
/// given as first argument. Note the behavior is slightly different from PG.
///
Expand Down
90 changes: 90 additions & 0 deletions src/expr/src/vector_op/array_replace.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
// Copyright 2023 RisingWave Labs
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

use risingwave_common::array::{ListRef, ListValue};
use risingwave_common::types::{ScalarRef, ToOwnedDatum};
use risingwave_expr_macro::function;

/// Replaces each array element equal to the second argument with the third argument.
///
/// Examples:
///
/// ```slt
/// query T
/// select array_replace(array[7, null, 8, null], null, 0.5);
/// ----
/// {7,0.5,8,0.5}
///
/// query T
/// select array_replace(null, 1, 5);
/// ----
/// NULL
///
/// query T
/// select array_replace(null, null, null);
/// ----
/// NULL
///
/// statement error
/// select array_replace(array[3, null, 4], true, false);
///
/// # Replacing `int` in multidimensional array is not supported yet. (OK in PostgreSQL)
/// statement error
/// select array_replace(array[array[array[0, 1], array[2, 3]], array[array[4, 5], array[6, 7]]], 3, 9);
///
/// # Unlike PostgreSQL, it is okay to replace `int[][]` inside `int[][][]`.
/// query T
/// select array_replace(array[array[array[0, 1], array[2, 3]], array[array[4, 5], array[6, 7]]], array[array[4, 5], array[6, 7]], array[array[2, 3], array[4, 5]]);
/// ----
/// {{{0,1},{2,3}},{{2,3},{4,5}}}
///
/// # Replacing `int[]` inside `int[][][]` is not supported by either PostgreSQL or RisingWave.
/// # This may or may not be supported later, whichever makes the `int` support above simpler.
/// statement error
/// select array_replace(array[array[array[0, 1], array[2, 3]], array[array[4, 5], array[6, 7]]], array[4, 5], array[8, 9]);
/// ```
#[function("array_replace(list, boolean, boolean) -> list")]
#[function("array_replace(list, int16, int16) -> list")]
#[function("array_replace(list, int32, int32) -> list")]
#[function("array_replace(list, int64, int64) -> list")]
#[function("array_replace(list, decimal, decimal) -> list")]
#[function("array_replace(list, float32, float32) -> list")]
#[function("array_replace(list, float64, float64) -> list")]
#[function("array_replace(list, varchar, varchar) -> list")]
#[function("array_replace(list, bytea, bytea) -> list")]
#[function("array_replace(list, time, time) -> list")]
#[function("array_replace(list, interval, interval) -> list")]
#[function("array_replace(list, date, date) -> list")]
#[function("array_replace(list, timestamp, timestamp) -> list")]
#[function("array_replace(list, timestamptz, timestamptz) -> list")]
#[function("array_replace(list, list, list) -> list")]
#[function("array_replace(list, struct, struct) -> list")]
#[function("array_replace(list, jsonb, jsonb) -> list")]
#[function("array_replace(list, int256, int256) -> list")]
fn array_replace<'a, T: ScalarRef<'a>>(
    arr: Option<ListRef<'_>>,
    elem_from: Option<T>,
    elem_to: Option<T>,
) -> Option<ListValue> {
    // A NULL array stays NULL regardless of the other arguments.
    let list = arr?;
    // Convert the searched value once; a NULL `elem_from` compares equal to
    // NULL elements of the array.
    let target = elem_from.map(Into::into);
    Some(ListValue::new(
        list.iter()
            .map(|item| {
                if item == target {
                    // Matching elements are swapped for an owned copy of `elem_to`.
                    elem_to.map(Into::into).to_owned_datum()
                } else {
                    // Everything else is carried over unchanged.
                    item.to_owned_datum()
                }
            })
            .collect(),
    ))
}
1 change: 1 addition & 0 deletions src/expr/src/vector_op/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ pub mod array_length;
pub mod array_positions;
pub mod array_range_access;
pub mod array_remove;
pub mod array_replace;
pub mod ascii;
pub mod bitwise_op;
pub mod cardinality;
Expand Down
2 changes: 2 additions & 0 deletions src/frontend/src/binder/expr/function.rs
Original file line number Diff line number Diff line change
Expand Up @@ -524,6 +524,8 @@ impl Binder {
("array_length", raw_call(ExprType::ArrayLength)),
("cardinality", raw_call(ExprType::Cardinality)),
("array_remove", raw_call(ExprType::ArrayRemove)),
("array_replace", raw_call(ExprType::ArrayReplace)),
("array_position", raw_call(ExprType::ArrayPosition)),
("array_positions", raw_call(ExprType::ArrayPositions)),
("trim_array", raw_call(ExprType::TrimArray)),
// int256
Expand Down
2 changes: 2 additions & 0 deletions src/frontend/src/expr/pure.rs
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,8 @@ impl ExprVisitor<bool> for ImpureAnalyzer {
| expr_node::Type::Cardinality
| expr_node::Type::TrimArray
| expr_node::Type::ArrayRemove
| expr_node::Type::ArrayReplace
| expr_node::Type::ArrayPosition
| expr_node::Type::HexToInt256
| expr_node::Type::JsonbAccessInner
| expr_node::Type::JsonbAccessStr
Expand Down
31 changes: 31 additions & 0 deletions src/frontend/src/expr/type_inference/func.rs
Original file line number Diff line number Diff line change
Expand Up @@ -544,6 +544,37 @@ fn infer_type_for_special(
.into()),
}
}
ExprType::ArrayReplace => {
ensure_arity!("array_replace", | inputs | == 3);
let common_type = align_array_and_element(0, &[1, 2], inputs);
match common_type {
Ok(casted) => Ok(Some(casted)),
Err(_) => Err(ErrorCode::BindError(format!(
"Cannot replace {} with {} in {}",
inputs[1].return_type(),
inputs[2].return_type(),
inputs[0].return_type(),
))
.into()),
}
}
ExprType::ArrayPosition => {
ensure_arity!("array_position", 2 <= | inputs | <= 3);
if let Some(start) = inputs.get_mut(2) {
let owned = std::mem::replace(start, ExprImpl::literal_bool(false));
*start = owned.cast_implicit(DataType::Int32)?;
}
let common_type = align_array_and_element(0, &[1], inputs);
match common_type {
Ok(_) => Ok(Some(DataType::Int32)),
Err(_) => Err(ErrorCode::BindError(format!(
"Cannot get position of {} in {}",
inputs[1].return_type(),
inputs[0].return_type()
))
.into()),
}
}
ExprType::ArrayPositions => {
ensure_arity!("array_positions", | inputs | == 2);
let common_type = align_array_and_element(0, &[1], inputs);
Expand Down
46 changes: 23 additions & 23 deletions src/tests/regress/data/sql/arrays.sql
Original file line number Diff line number Diff line change
Expand Up @@ -261,14 +261,14 @@ SELECT array_cat(ARRAY[1,2], ARRAY[3,4]) AS "{1,2,3,4}";
SELECT array_cat(ARRAY[1,2], ARRAY[[3,4],[5,6]]) AS "{{1,2},{3,4},{5,6}}";
SELECT array_cat(ARRAY[[3,4],[5,6]], ARRAY[1,2]) AS "{{3,4},{5,6},{1,2}}";

--@ SELECT array_position(ARRAY[1,2,3,4,5], 4);
--@ SELECT array_position(ARRAY[5,3,4,2,1], 4);
--@ SELECT array_position(ARRAY[[1,2],[3,4]], 3);
--@ SELECT array_position(ARRAY['sun','mon','tue','wed','thu','fri','sat'], 'mon');
--@ SELECT array_position(ARRAY['sun','mon','tue','wed','thu','fri','sat'], 'sat');
--@ SELECT array_position(ARRAY['sun','mon','tue','wed','thu','fri','sat'], NULL);
--@ SELECT array_position(ARRAY['sun','mon','tue','wed','thu',NULL,'fri','sat'], NULL);
--@ SELECT array_position(ARRAY['sun','mon','tue','wed','thu',NULL,'fri','sat'], 'sat');
SELECT array_position(ARRAY[1,2,3,4,5], 4);
SELECT array_position(ARRAY[5,3,4,2,1], 4);
SELECT array_position(ARRAY[[1,2],[3,4]], 3);
SELECT array_position(ARRAY['sun','mon','tue','wed','thu','fri','sat'], 'mon');
SELECT array_position(ARRAY['sun','mon','tue','wed','thu','fri','sat'], 'sat');
SELECT array_position(ARRAY['sun','mon','tue','wed','thu','fri','sat'], NULL);
SELECT array_position(ARRAY['sun','mon','tue','wed','thu',NULL,'fri','sat'], NULL);
SELECT array_position(ARRAY['sun','mon','tue','wed','thu',NULL,'fri','sat'], 'sat');

SELECT array_positions(NULL, 10);
SELECT array_positions(NULL, NULL::int);
Expand Down Expand Up @@ -296,15 +296,15 @@ SELECT array_positions(ARRAY[1,2,3,NULL,5,6,1,2,3,NULL,5,6], NULL);

--@ SELECT array_position('[2:4]={1,2,3}'::int[], 1);
--@ SELECT array_positions('[2:4]={1,2,3}'::int[], 1);
--@
--@ SELECT
--@ array_position(ids, (1, 1)),
--@ array_positions(ids, (1, 1))
--@ FROM
--@ (VALUES
--@ (ARRAY[(0, 0), (1, 1)]),
--@ (ARRAY[(1, 1)])
--@ ) AS f (ids);

SELECT
array_position(ids, (1, 1)),
array_positions(ids, (1, 1))
FROM
(VALUES
(ARRAY[(0, 0), (1, 1)]),
(ARRAY[(1, 1)])
) AS f (ids);

-- operators
--@ SELECT a FROM arrtest WHERE b = ARRAY[[[113,142],[1,147]]];
Expand Down Expand Up @@ -625,12 +625,12 @@ select array_remove(array['A','CC','D','C','RR'], 'RR');
select array_remove(array[1.0, 2.1, 3.3], 1);
select array_remove('{{1,2,2},{1,4,3}}', 2); -- not allowed
select array_remove(array['X','X','X'], 'X') = '{}';
--@ select array_replace(array[1,2,5,4],5,3);
--@ select array_replace(array[1,2,5,4],5,NULL);
--@ select array_replace(array[1,2,NULL,4,NULL],NULL,5);
--@ select array_replace(array['A','B','DD','B'],'B','CC');
--@ select array_replace(array[1,NULL,3],NULL,NULL);
--@ select array_replace(array['AB',NULL,'CDE'],NULL,'12');
select array_replace(array[1,2,5,4],5,3);
select array_replace(array[1,2,5,4],5,NULL);
select array_replace(array[1,2,NULL,4,NULL],NULL,5);
select array_replace(array['A','B','DD','B'],'B','CC');
select array_replace(array[1,NULL,3],NULL,NULL);
select array_replace(array['AB',NULL,'CDE'],NULL,'12');

-- array(select array-value ...)
--@ select array(select array[i,i/2] from generate_series(1,5) i);
Expand Down

0 comments on commit 572780b

Please sign in to comment.