Skip to content

Commit

Permalink
feat(udf): support CREATE FUNCTION/AGGREGATE IF NOT EXISTS (#20079)
Browse files Browse the repository at this point in the history
Signed-off-by: Richard Chien <[email protected]>
  • Loading branch information
stdrc authored Jan 9, 2025
1 parent 934db16 commit 09e59e4
Show file tree
Hide file tree
Showing 10 changed files with 265 additions and 94 deletions.
115 changes: 115 additions & 0 deletions e2e_test/udf/create_and_drop.slt
Original file line number Diff line number Diff line change
@@ -0,0 +1,115 @@
# https://github.com/risingwavelabs/risingwave/issues/17263

statement ok
create table t (a int, b int);

statement ok
create function add(a int, b int) returns int language python as $$
def add(a, b):
return a+b
$$;

statement error function with name add\(integer,integer\) exists
create function add(int, int) returns int language sql as $$select $1 + $2$$;

statement ok
create function if not exists add(int, int) returns int language sql as $$select $1 + $2$$;

statement ok
create function add_v2(int, int) returns int language sql as $$select $1 + $2$$;

statement ok
create aggregate mysum(value int) returns int language python as $$
def create_state():
return 0
def accumulate(state, value):
return state + value
def finish(state):
return state
$$;

statement error function with name mysum\(integer\) exists
create aggregate mysum(value int) returns int language python as $$
def create_state():
return 0
def accumulate(state, value):
return state + value
def finish(state):
return state
$$;

statement ok
create aggregate if not exists mysum(value int) returns int language python as $$
def create_state():
return 0
def accumulate(state, value):
return state + value
def finish(state):
return state
$$;

statement ok
create materialized view mv as select add(a, b) + add_v2(a, b) as c from t;

statement ok
create materialized view mv2 as select mysum(a) as s from t;

statement error function used by 1 other objects
drop function add;

statement error function used by 1 other objects
drop function if exists add;

statement error function used by 1 other objects
drop function add_v2;

statement error function used by 1 other objects
drop aggregate mysum;

statement ok
drop materialized view mv;

statement ok
drop materialized view mv2;

statement ok
drop function add;

statement error function not found
drop function add;

statement ok
drop function if exists add;

statement ok
drop function add_v2;

statement ok
drop function if exists add_v2;

statement ok
drop aggregate mysum;

statement ok
drop aggregate if exists mysum;

statement ok
create function add(a int, b int) returns int language python as $$
def add(a, b):
return a+b
$$;

statement ok
create sink s as select add(a, b) as c from t with (connector = 'blackhole');

statement error function used by 1 other objects
drop function add;

statement ok
drop sink s;

statement ok
drop function add;

statement ok
drop table t;
52 changes: 0 additions & 52 deletions e2e_test/udf/drop_function.slt

This file was deleted.

23 changes: 11 additions & 12 deletions src/frontend/src/handler/create_aggregate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,19 +13,20 @@
// limitations under the License.

use anyhow::Context;
use either::Either;
use risingwave_common::catalog::FunctionId;
use risingwave_expr::sig::{CreateFunctionOptions, UdfKind};
use risingwave_pb::catalog::function::{AggregateFunction, Kind};
use risingwave_pb::catalog::Function;
use risingwave_sqlparser::ast::DataType as AstDataType;

use super::*;
use crate::catalog::CatalogError;
use crate::{bind_data_type, Binder};

pub async fn handle_create_aggregate(
handler_args: HandlerArgs,
or_replace: bool,
if_not_exists: bool,
name: ObjectName,
args: Vec<OperateFunctionArg>,
returns: AstDataType,
Expand Down Expand Up @@ -74,20 +75,18 @@ pub async fn handle_create_aggregate(
// resolve database and schema id
let session = &handler_args.session;
let db_name = &session.database();
let (schema_name, function_name) = Binder::resolve_schema_qualified_name(db_name, name)?;
let (schema_name, function_name) =
Binder::resolve_schema_qualified_name(db_name, name.clone())?;
let (database_id, schema_id) = session.get_database_and_schema_id_for_create(schema_name)?;

// check if the function exists in the catalog
if (session.env().catalog_reader().read_guard())
.get_schema_by_id(&database_id, &schema_id)?
.get_function_by_name_args(&function_name, &arg_types)
.is_some()
{
let name = format!(
"{function_name}({})",
arg_types.iter().map(|t| t.to_string()).join(",")
);
return Err(CatalogError::Duplicated("function", name).into());
if let Either::Right(resp) = session.check_function_name_duplicated(
StatementType::CREATE_FUNCTION,
name,
&arg_types,
if_not_exists,
)? {
return Ok(resp);
}

let link = match &params.using {
Expand Down
23 changes: 11 additions & 12 deletions src/frontend/src/handler/create_function.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,20 +13,21 @@
// limitations under the License.

use anyhow::Context;
use either::Either;
use risingwave_common::catalog::FunctionId;
use risingwave_common::types::StructType;
use risingwave_expr::sig::{CreateFunctionOptions, UdfKind};
use risingwave_pb::catalog::function::{Kind, ScalarFunction, TableFunction};
use risingwave_pb::catalog::Function;

use super::*;
use crate::catalog::CatalogError;
use crate::{bind_data_type, Binder};

pub async fn handle_create_function(
handler_args: HandlerArgs,
or_replace: bool,
temporary: bool,
if_not_exists: bool,
name: ObjectName,
args: Option<Vec<OperateFunctionArg>>,
returns: Option<CreateFunctionReturns>,
Expand Down Expand Up @@ -107,20 +108,18 @@ pub async fn handle_create_function(
// resolve database and schema id
let session = &handler_args.session;
let db_name = &session.database();
let (schema_name, function_name) = Binder::resolve_schema_qualified_name(db_name, name)?;
let (schema_name, function_name) =
Binder::resolve_schema_qualified_name(db_name, name.clone())?;
let (database_id, schema_id) = session.get_database_and_schema_id_for_create(schema_name)?;

// check if the function exists in the catalog
if (session.env().catalog_reader().read_guard())
.get_schema_by_id(&database_id, &schema_id)?
.get_function_by_name_args(&function_name, &arg_types)
.is_some()
{
let name = format!(
"{function_name}({})",
arg_types.iter().map(|t| t.to_string()).join(",")
);
return Err(CatalogError::Duplicated("function", name).into());
if let Either::Right(resp) = session.check_function_name_duplicated(
StatementType::CREATE_FUNCTION,
name,
&arg_types,
if_not_exists,
)? {
return Ok(resp);
}

let link = match &params.using {
Expand Down
23 changes: 11 additions & 12 deletions src/frontend/src/handler/create_sql_function.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

use std::collections::HashMap;

use either::Either;
use fancy_regex::Regex;
use risingwave_common::catalog::FunctionId;
use risingwave_common::types::{DataType, StructType};
Expand All @@ -23,7 +24,6 @@ use risingwave_sqlparser::parser::{Parser, ParserError};

use super::*;
use crate::binder::UdfContext;
use crate::catalog::CatalogError;
use crate::expr::{Expr, ExprImpl, Literal};
use crate::{bind_data_type, Binder};

Expand Down Expand Up @@ -122,6 +122,7 @@ pub async fn handle_create_sql_function(
handler_args: HandlerArgs,
or_replace: bool,
temporary: bool,
if_not_exists: bool,
name: ObjectName,
args: Option<Vec<OperateFunctionArg>>,
returns: Option<CreateFunctionReturns>,
Expand Down Expand Up @@ -214,20 +215,18 @@ pub async fn handle_create_sql_function(
// resolve database and schema id
let session = &handler_args.session;
let db_name = &session.database();
let (schema_name, function_name) = Binder::resolve_schema_qualified_name(db_name, name)?;
let (schema_name, function_name) =
Binder::resolve_schema_qualified_name(db_name, name.clone())?;
let (database_id, schema_id) = session.get_database_and_schema_id_for_create(schema_name)?;

// check if function exists
if (session.env().catalog_reader().read_guard())
.get_schema_by_id(&database_id, &schema_id)?
.get_function_by_name_args(&function_name, &arg_types)
.is_some()
{
let name = format!(
"{function_name}({})",
arg_types.iter().map(|t| t.to_string()).join(",")
);
return Err(CatalogError::Duplicated("function", name).into());
if let Either::Right(resp) = session.check_function_name_duplicated(
StatementType::CREATE_FUNCTION,
name,
&arg_types,
if_not_exists,
)? {
return Ok(resp);
}

// Parse function body here
Expand Down
5 changes: 5 additions & 0 deletions src/frontend/src/handler/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -278,6 +278,7 @@ pub async fn handle(
Statement::CreateFunction {
or_replace,
temporary,
if_not_exists,
name,
args,
returns,
Expand All @@ -298,6 +299,7 @@ pub async fn handle(
handler_args,
or_replace,
temporary,
if_not_exists,
name,
args,
returns,
Expand All @@ -310,6 +312,7 @@ pub async fn handle(
handler_args,
or_replace,
temporary,
if_not_exists,
name,
args,
returns,
Expand All @@ -320,6 +323,7 @@ pub async fn handle(
}
Statement::CreateAggregate {
or_replace,
if_not_exists,
name,
args,
returns,
Expand All @@ -329,6 +333,7 @@ pub async fn handle(
create_aggregate::handle_create_aggregate(
handler_args,
or_replace,
if_not_exists,
name,
args,
returns,
Expand Down
Loading

0 comments on commit 09e59e4

Please sign in to comment.