Skip to content

Commit

Permalink
feat: enable "substring" as a UDF in addition to "substr" (#11277)
Browse files Browse the repository at this point in the history
* feat: enable "substring" as a UDF in addition to "substr"

Substrait uses the name "substring", and it already exists in DF SQL

The setup here is a bit weird: I would have added "substring" as an alias for "substr", but a "substring" version is already created as a UDF and exported through export_functions, with slightly different args than substr
(even though in reality the underlying function for both is the same substr impl).

I think this PR should work, but if you have suggestions on how to make the situation here cleaner, I'd be happy to hear them!

* okay redo everything: add an alias instead, and add renaming in the substrait producer

* add alias into scalar_functions.md
  • Loading branch information
Blizzara authored Jul 6, 2024
1 parent 682fc05 commit 6f86bfa
Show file tree
Hide file tree
Showing 5 changed files with 65 additions and 43 deletions.
6 changes: 6 additions & 0 deletions datafusion/functions/src/unicode/substr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ use crate::utils::{make_scalar_function, utf8_to_str_type};
#[derive(Debug)]
pub struct SubstrFunc {
signature: Signature,
aliases: Vec<String>,
}

impl Default for SubstrFunc {
Expand All @@ -53,6 +54,7 @@ impl SubstrFunc {
],
Volatility::Immutable,
),
aliases: vec![String::from("substring")],
}
}
}
Expand Down Expand Up @@ -81,6 +83,10 @@ impl ScalarUDFImpl for SubstrFunc {
other => exec_err!("Unsupported data type {other:?} for function substr"),
}
}

fn aliases(&self) -> &[String] {
&self.aliases
}
}

/// Extracts the substring of string starting at the start'th character, and extending for count characters if that is specified. (Same as substring(string from start for count).)
Expand Down
60 changes: 30 additions & 30 deletions datafusion/substrait/src/logical_plan/consumer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -88,36 +88,36 @@ use substrait::proto::{
};
use substrait::proto::{FunctionArgument, SortField};

pub fn name_to_op(name: &str) -> Result<Operator> {
pub fn name_to_op(name: &str) -> Option<Operator> {
match name {
"equal" => Ok(Operator::Eq),
"not_equal" => Ok(Operator::NotEq),
"lt" => Ok(Operator::Lt),
"lte" => Ok(Operator::LtEq),
"gt" => Ok(Operator::Gt),
"gte" => Ok(Operator::GtEq),
"add" => Ok(Operator::Plus),
"subtract" => Ok(Operator::Minus),
"multiply" => Ok(Operator::Multiply),
"divide" => Ok(Operator::Divide),
"mod" => Ok(Operator::Modulo),
"and" => Ok(Operator::And),
"or" => Ok(Operator::Or),
"is_distinct_from" => Ok(Operator::IsDistinctFrom),
"is_not_distinct_from" => Ok(Operator::IsNotDistinctFrom),
"regex_match" => Ok(Operator::RegexMatch),
"regex_imatch" => Ok(Operator::RegexIMatch),
"regex_not_match" => Ok(Operator::RegexNotMatch),
"regex_not_imatch" => Ok(Operator::RegexNotIMatch),
"bitwise_and" => Ok(Operator::BitwiseAnd),
"bitwise_or" => Ok(Operator::BitwiseOr),
"str_concat" => Ok(Operator::StringConcat),
"at_arrow" => Ok(Operator::AtArrow),
"arrow_at" => Ok(Operator::ArrowAt),
"bitwise_xor" => Ok(Operator::BitwiseXor),
"bitwise_shift_right" => Ok(Operator::BitwiseShiftRight),
"bitwise_shift_left" => Ok(Operator::BitwiseShiftLeft),
_ => not_impl_err!("Unsupported function name: {name:?}"),
"equal" => Some(Operator::Eq),
"not_equal" => Some(Operator::NotEq),
"lt" => Some(Operator::Lt),
"lte" => Some(Operator::LtEq),
"gt" => Some(Operator::Gt),
"gte" => Some(Operator::GtEq),
"add" => Some(Operator::Plus),
"subtract" => Some(Operator::Minus),
"multiply" => Some(Operator::Multiply),
"divide" => Some(Operator::Divide),
"mod" => Some(Operator::Modulo),
"and" => Some(Operator::And),
"or" => Some(Operator::Or),
"is_distinct_from" => Some(Operator::IsDistinctFrom),
"is_not_distinct_from" => Some(Operator::IsNotDistinctFrom),
"regex_match" => Some(Operator::RegexMatch),
"regex_imatch" => Some(Operator::RegexIMatch),
"regex_not_match" => Some(Operator::RegexNotMatch),
"regex_not_imatch" => Some(Operator::RegexNotIMatch),
"bitwise_and" => Some(Operator::BitwiseAnd),
"bitwise_or" => Some(Operator::BitwiseOr),
"str_concat" => Some(Operator::StringConcat),
"at_arrow" => Some(Operator::AtArrow),
"arrow_at" => Some(Operator::ArrowAt),
"bitwise_xor" => Some(Operator::BitwiseXor),
"bitwise_shift_right" => Some(Operator::BitwiseShiftRight),
"bitwise_shift_left" => Some(Operator::BitwiseShiftLeft),
_ => None,
}
}

Expand Down Expand Up @@ -1124,7 +1124,7 @@ pub async fn from_substrait_rex(
Ok(Arc::new(Expr::ScalarFunction(
expr::ScalarFunction::new_udf(func.to_owned(), args),
)))
} else if let Ok(op) = name_to_op(fn_name) {
} else if let Some(op) = name_to_op(fn_name) {
if f.arguments.len() < 2 {
return not_impl_err!(
"Expect at least two arguments for binary operator {op:?}, the provided number of operators is {:?}",
Expand Down
32 changes: 20 additions & 12 deletions datafusion/substrait/src/logical_plan/producer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -818,7 +818,7 @@ pub fn to_substrait_agg_measure(
for arg in args {
arguments.push(FunctionArgument { arg_type: Some(ArgType::Value(to_substrait_rex(ctx, arg, schema, 0, extension_info)?)) });
}
let function_anchor = _register_function(fun.to_string(), extension_info);
let function_anchor = register_function(fun.to_string(), extension_info);
Ok(Measure {
measure: Some(AggregateFunction {
function_reference: function_anchor,
Expand Down Expand Up @@ -849,7 +849,7 @@ pub fn to_substrait_agg_measure(
for arg in args {
arguments.push(FunctionArgument { arg_type: Some(ArgType::Value(to_substrait_rex(ctx, arg, schema, 0, extension_info)?)) });
}
let function_anchor = _register_function(fun.name().to_string(), extension_info);
let function_anchor = register_function(fun.name().to_string(), extension_info);
Ok(Measure {
measure: Some(AggregateFunction {
function_reference: function_anchor,
Expand Down Expand Up @@ -917,7 +917,7 @@ fn to_substrait_sort_field(
}
}

fn _register_function(
fn register_function(
function_name: String,
extension_info: &mut (
Vec<extensions::SimpleExtensionDeclaration>,
Expand All @@ -926,6 +926,14 @@ fn _register_function(
) -> u32 {
let (function_extensions, function_set) = extension_info;
let function_name = function_name.to_lowercase();

// Some functions are named differently in Substrait default extensions than in DF
// Rename those to match the Substrait extensions for interoperability
let function_name = match function_name.as_str() {
"substr" => "substring".to_string(),
_ => function_name,
};

// To prevent ambiguous references between ScalarFunctions and AggregateFunctions,
// a plan-relative identifier starting from 0 is used as the function_anchor.
// The consumer is responsible for correctly registering <function_anchor,function_name>
Expand Down Expand Up @@ -969,7 +977,7 @@ pub fn make_binary_op_scalar_func(
),
) -> Expression {
let function_anchor =
_register_function(operator_to_name(op).to_string(), extension_info);
register_function(operator_to_name(op).to_string(), extension_info);
Expression {
rex_type: Some(RexType::ScalarFunction(ScalarFunction {
function_reference: function_anchor,
Expand Down Expand Up @@ -1044,7 +1052,7 @@ pub fn to_substrait_rex(

if *negated {
let function_anchor =
_register_function("not".to_string(), extension_info);
register_function("not".to_string(), extension_info);

Ok(Expression {
rex_type: Some(RexType::ScalarFunction(ScalarFunction {
Expand Down Expand Up @@ -1076,7 +1084,7 @@ pub fn to_substrait_rex(
}

let function_anchor =
_register_function(fun.name().to_string(), extension_info);
register_function(fun.name().to_string(), extension_info);
Ok(Expression {
rex_type: Some(RexType::ScalarFunction(ScalarFunction {
function_reference: function_anchor,
Expand Down Expand Up @@ -1252,7 +1260,7 @@ pub fn to_substrait_rex(
null_treatment: _,
}) => {
// function reference
let function_anchor = _register_function(fun.to_string(), extension_info);
let function_anchor = register_function(fun.to_string(), extension_info);
// arguments
let mut arguments: Vec<FunctionArgument> = vec![];
for arg in args {
Expand Down Expand Up @@ -1330,7 +1338,7 @@ pub fn to_substrait_rex(
};
if *negated {
let function_anchor =
_register_function("not".to_string(), extension_info);
register_function("not".to_string(), extension_info);

Ok(Expression {
rex_type: Some(RexType::ScalarFunction(ScalarFunction {
Expand Down Expand Up @@ -1727,9 +1735,9 @@ fn make_substrait_like_expr(
),
) -> Result<Expression> {
let function_anchor = if ignore_case {
_register_function("ilike".to_string(), extension_info)
register_function("ilike".to_string(), extension_info)
} else {
_register_function("like".to_string(), extension_info)
register_function("like".to_string(), extension_info)
};
let expr = to_substrait_rex(ctx, expr, schema, col_ref_offset, extension_info)?;
let pattern = to_substrait_rex(ctx, pattern, schema, col_ref_offset, extension_info)?;
Expand Down Expand Up @@ -1759,7 +1767,7 @@ fn make_substrait_like_expr(
};

if negated {
let function_anchor = _register_function("not".to_string(), extension_info);
let function_anchor = register_function("not".to_string(), extension_info);

Ok(Expression {
rex_type: Some(RexType::ScalarFunction(ScalarFunction {
Expand Down Expand Up @@ -2128,7 +2136,7 @@ fn to_substrait_unary_scalar_fn(
HashMap<String, u32>,
),
) -> Result<Expression> {
let function_anchor = _register_function(fn_name.to_string(), extension_info);
let function_anchor = register_function(fn_name.to_string(), extension_info);
let substrait_expr =
to_substrait_rex(ctx, arg, schema, col_ref_offset, extension_info)?;

Expand Down
2 changes: 1 addition & 1 deletion datafusion/substrait/tests/cases/roundtrip_logical_plan.rs
Original file line number Diff line number Diff line change
Expand Up @@ -327,7 +327,7 @@ async fn simple_scalar_function_pow() -> Result<()> {

#[tokio::test]
async fn simple_scalar_function_substr() -> Result<()> {
roundtrip("SELECT * FROM data WHERE a = SUBSTR('datafusion', 0, 3)").await
roundtrip("SELECT SUBSTR(f, 1, 3) FROM data").await
}

#[tokio::test]
Expand Down
8 changes: 8 additions & 0 deletions docs/source/user-guide/sql/scalar_functions.md
Original file line number Diff line number Diff line change
Expand Up @@ -1132,6 +1132,14 @@ substr(str, start_pos[, length])
- **length**: Number of characters to extract.
If not specified, returns the rest of the string after the start position.

#### Aliases

- substring

### `substring`

_Alias of [substr](#substr)._

### `translate`

Translates characters in a string to specified translation characters.
Expand Down

0 comments on commit 6f86bfa

Please sign in to comment.