From 09e59e43dd7630f02664af8f8fb6672554351530 Mon Sep 17 00:00:00 2001 From: Richard Chien Date: Thu, 9 Jan 2025 16:57:02 +0800 Subject: [PATCH] feat(udf): support `CREATE FUNCTION/AGGREGATE IF NOT EXISTS` (#20079) Signed-off-by: Richard Chien --- e2e_test/udf/create_and_drop.slt | 115 ++++++++++++++++++ e2e_test/udf/drop_function.slt | 52 -------- src/frontend/src/handler/create_aggregate.rs | 23 ++-- src/frontend/src/handler/create_function.rs | 23 ++-- .../src/handler/create_sql_function.rs | 23 ++-- src/frontend/src/handler/mod.rs | 5 + src/frontend/src/session.rs | 39 ++++++ src/sqlparser/src/ast/mod.rs | 16 ++- src/sqlparser/src/parser.rs | 6 + src/sqlparser/tests/sqlparser_postgres.rs | 57 ++++++++- 10 files changed, 265 insertions(+), 94 deletions(-) create mode 100644 e2e_test/udf/create_and_drop.slt delete mode 100644 e2e_test/udf/drop_function.slt diff --git a/e2e_test/udf/create_and_drop.slt b/e2e_test/udf/create_and_drop.slt new file mode 100644 index 0000000000000..7b31dba16fdbd --- /dev/null +++ b/e2e_test/udf/create_and_drop.slt @@ -0,0 +1,115 @@ +# https://github.com/risingwavelabs/risingwave/issues/17263 + +statement ok +create table t (a int, b int); + +statement ok +create function add(a int, b int) returns int language python as $$ +def add(a, b): + return a+b +$$; + +statement error function with name add\(integer,integer\) exists +create function add(int, int) returns int language sql as $$select $1 + $2$$; + +statement ok +create function if not exists add(int, int) returns int language sql as $$select $1 + $2$$; + +statement ok +create function add_v2(int, int) returns int language sql as $$select $1 + $2$$; + +statement ok +create aggregate mysum(value int) returns int language python as $$ +def create_state(): + return 0 +def accumulate(state, value): + return state + value +def finish(state): + return state +$$; + +statement error function with name mysum\(integer\) exists +create aggregate mysum(value int) returns int language python as 
$$ +def create_state(): + return 0 +def accumulate(state, value): + return state + value +def finish(state): + return state +$$; + +statement ok +create aggregate if not exists mysum(value int) returns int language python as $$ +def create_state(): + return 0 +def accumulate(state, value): + return state + value +def finish(state): + return state +$$; + +statement ok +create materialized view mv as select add(a, b) + add_v2(a, b) as c from t; + +statement ok +create materialized view mv2 as select mysum(a) as s from t; + +statement error function used by 1 other objects +drop function add; + +statement error function used by 1 other objects +drop function if exists add; + +statement error function used by 1 other objects +drop function add_v2; + +statement error function used by 1 other objects +drop aggregate mysum; + +statement ok +drop materialized view mv; + +statement ok +drop materialized view mv2; + +statement ok +drop function add; + +statement error function not found +drop function add; + +statement ok +drop function if exists add; + +statement ok +drop function add_v2; + +statement ok +drop function if exists add_v2; + +statement ok +drop aggregate mysum; + +statement ok +drop aggregate if exists mysum; + +statement ok +create function add(a int, b int) returns int language python as $$ +def add(a, b): + return a+b +$$; + +statement ok +create sink s as select add(a, b) as c from t with (connector = 'blackhole'); + +statement error function used by 1 other objects +drop function add; + +statement ok +drop sink s; + +statement ok +drop function add; + +statement ok +drop table t; diff --git a/e2e_test/udf/drop_function.slt b/e2e_test/udf/drop_function.slt deleted file mode 100644 index ffe4e0eea481f..0000000000000 --- a/e2e_test/udf/drop_function.slt +++ /dev/null @@ -1,52 +0,0 @@ -# https://github.com/risingwavelabs/risingwave/issues/17263 - -statement ok -create table t (a int, b int); - -statement ok -create function add(a int, b int) returns int 
language python as $$ -def add(a, b): - return a+b -$$; - -statement ok -create function add_v2(INT, INT) returns int language sql as $$select $1 + $2$$; - -statement ok -create materialized view mv as select add(a, b) + add_v2(a, b) as c from t; - -statement error function used by 1 other objects -drop function add; - -statement error function used by 1 other objects -drop function add_v2; - -statement ok -drop materialized view mv; - -statement ok -drop function add; - -statement ok -drop function add_v2; - -statement ok -create function add(a int, b int) returns int language python as $$ -def add(a, b): - return a+b -$$; - -statement ok -create sink s as select add(a, b) as c from t with (connector = 'blackhole'); - -statement error function used by 1 other objects -drop function add; - -statement ok -drop sink s; - -statement ok -drop function add; - -statement ok -drop table t; diff --git a/src/frontend/src/handler/create_aggregate.rs b/src/frontend/src/handler/create_aggregate.rs index 32f326db9b1d9..85ba343fef408 100644 --- a/src/frontend/src/handler/create_aggregate.rs +++ b/src/frontend/src/handler/create_aggregate.rs @@ -13,6 +13,7 @@ // limitations under the License. 
use anyhow::Context; +use either::Either; use risingwave_common::catalog::FunctionId; use risingwave_expr::sig::{CreateFunctionOptions, UdfKind}; use risingwave_pb::catalog::function::{AggregateFunction, Kind}; @@ -20,12 +21,12 @@ use risingwave_pb::catalog::Function; use risingwave_sqlparser::ast::DataType as AstDataType; use super::*; -use crate::catalog::CatalogError; use crate::{bind_data_type, Binder}; pub async fn handle_create_aggregate( handler_args: HandlerArgs, or_replace: bool, + if_not_exists: bool, name: ObjectName, args: Vec, returns: AstDataType, @@ -74,20 +75,18 @@ pub async fn handle_create_aggregate( // resolve database and schema id let session = &handler_args.session; let db_name = &session.database(); - let (schema_name, function_name) = Binder::resolve_schema_qualified_name(db_name, name)?; + let (schema_name, function_name) = + Binder::resolve_schema_qualified_name(db_name, name.clone())?; let (database_id, schema_id) = session.get_database_and_schema_id_for_create(schema_name)?; // check if the function exists in the catalog - if (session.env().catalog_reader().read_guard()) - .get_schema_by_id(&database_id, &schema_id)? - .get_function_by_name_args(&function_name, &arg_types) - .is_some() - { - let name = format!( - "{function_name}({})", - arg_types.iter().map(|t| t.to_string()).join(",") - ); - return Err(CatalogError::Duplicated("function", name).into()); + if let Either::Right(resp) = session.check_function_name_duplicated( + StatementType::CREATE_FUNCTION, + name, + &arg_types, + if_not_exists, + )? { + return Ok(resp); } let link = match &params.using { diff --git a/src/frontend/src/handler/create_function.rs b/src/frontend/src/handler/create_function.rs index b87d3c90a3488..c212eaebb56f4 100644 --- a/src/frontend/src/handler/create_function.rs +++ b/src/frontend/src/handler/create_function.rs @@ -13,6 +13,7 @@ // limitations under the License. 
use anyhow::Context; +use either::Either; use risingwave_common::catalog::FunctionId; use risingwave_common::types::StructType; use risingwave_expr::sig::{CreateFunctionOptions, UdfKind}; @@ -20,13 +21,13 @@ use risingwave_pb::catalog::function::{Kind, ScalarFunction, TableFunction}; use risingwave_pb::catalog::Function; use super::*; -use crate::catalog::CatalogError; use crate::{bind_data_type, Binder}; pub async fn handle_create_function( handler_args: HandlerArgs, or_replace: bool, temporary: bool, + if_not_exists: bool, name: ObjectName, args: Option>, returns: Option, @@ -107,20 +108,18 @@ pub async fn handle_create_function( // resolve database and schema id let session = &handler_args.session; let db_name = &session.database(); - let (schema_name, function_name) = Binder::resolve_schema_qualified_name(db_name, name)?; + let (schema_name, function_name) = + Binder::resolve_schema_qualified_name(db_name, name.clone())?; let (database_id, schema_id) = session.get_database_and_schema_id_for_create(schema_name)?; // check if the function exists in the catalog - if (session.env().catalog_reader().read_guard()) - .get_schema_by_id(&database_id, &schema_id)? - .get_function_by_name_args(&function_name, &arg_types) - .is_some() - { - let name = format!( - "{function_name}({})", - arg_types.iter().map(|t| t.to_string()).join(",") - ); - return Err(CatalogError::Duplicated("function", name).into()); + if let Either::Right(resp) = session.check_function_name_duplicated( + StatementType::CREATE_FUNCTION, + name, + &arg_types, + if_not_exists, + )? 
{ + return Ok(resp); } let link = match &params.using { diff --git a/src/frontend/src/handler/create_sql_function.rs b/src/frontend/src/handler/create_sql_function.rs index 4725b37ab6511..71a31ce5173cc 100644 --- a/src/frontend/src/handler/create_sql_function.rs +++ b/src/frontend/src/handler/create_sql_function.rs @@ -14,6 +14,7 @@ use std::collections::HashMap; +use either::Either; use fancy_regex::Regex; use risingwave_common::catalog::FunctionId; use risingwave_common::types::{DataType, StructType}; @@ -23,7 +24,6 @@ use risingwave_sqlparser::parser::{Parser, ParserError}; use super::*; use crate::binder::UdfContext; -use crate::catalog::CatalogError; use crate::expr::{Expr, ExprImpl, Literal}; use crate::{bind_data_type, Binder}; @@ -122,6 +122,7 @@ pub async fn handle_create_sql_function( handler_args: HandlerArgs, or_replace: bool, temporary: bool, + if_not_exists: bool, name: ObjectName, args: Option>, returns: Option, @@ -214,20 +215,18 @@ pub async fn handle_create_sql_function( // resolve database and schema id let session = &handler_args.session; let db_name = &session.database(); - let (schema_name, function_name) = Binder::resolve_schema_qualified_name(db_name, name)?; + let (schema_name, function_name) = + Binder::resolve_schema_qualified_name(db_name, name.clone())?; let (database_id, schema_id) = session.get_database_and_schema_id_for_create(schema_name)?; // check if function exists - if (session.env().catalog_reader().read_guard()) - .get_schema_by_id(&database_id, &schema_id)? - .get_function_by_name_args(&function_name, &arg_types) - .is_some() - { - let name = format!( - "{function_name}({})", - arg_types.iter().map(|t| t.to_string()).join(",") - ); - return Err(CatalogError::Duplicated("function", name).into()); + if let Either::Right(resp) = session.check_function_name_duplicated( + StatementType::CREATE_FUNCTION, + name, + &arg_types, + if_not_exists, + )? 
{ + return Ok(resp); } // Parse function body here diff --git a/src/frontend/src/handler/mod.rs b/src/frontend/src/handler/mod.rs index bd559b9245fc7..4245e66c3034a 100644 --- a/src/frontend/src/handler/mod.rs +++ b/src/frontend/src/handler/mod.rs @@ -278,6 +278,7 @@ pub async fn handle( Statement::CreateFunction { or_replace, temporary, + if_not_exists, name, args, returns, @@ -298,6 +299,7 @@ pub async fn handle( handler_args, or_replace, temporary, + if_not_exists, name, args, returns, @@ -310,6 +312,7 @@ pub async fn handle( handler_args, or_replace, temporary, + if_not_exists, name, args, returns, @@ -320,6 +323,7 @@ pub async fn handle( } Statement::CreateAggregate { or_replace, + if_not_exists, name, args, returns, @@ -329,6 +333,7 @@ pub async fn handle( create_aggregate::handle_create_aggregate( handler_args, or_replace, + if_not_exists, name, args, returns, diff --git a/src/frontend/src/session.rs b/src/frontend/src/session.rs index fe03af7c0c786..7f99d91d6e0d5 100644 --- a/src/frontend/src/session.rs +++ b/src/frontend/src/session.rs @@ -22,6 +22,7 @@ use std::time::{Duration, Instant}; use anyhow::anyhow; use bytes::Bytes; use either::Either; +use itertools::Itertools; use parking_lot::{Mutex, RwLock, RwLockReadGuard}; use pgwire::error::{PsqlError, PsqlResult}; use pgwire::net::{Address, AddressRef}; @@ -961,6 +962,44 @@ impl SessionImpl { .map_err(RwError::from) } + pub fn check_function_name_duplicated( + &self, + stmt_type: StatementType, + name: ObjectName, + arg_types: &[DataType], + if_not_exists: bool, + ) -> Result> { + let db_name = &self.database(); + let (schema_name, function_name) = Binder::resolve_schema_qualified_name(db_name, name)?; + let (database_id, schema_id) = self.get_database_and_schema_id_for_create(schema_name)?; + + let catalog_reader = self.env().catalog_reader().read_guard(); + if catalog_reader + .get_schema_by_id(&database_id, &schema_id)? 
+ .get_function_by_name_args(&function_name, arg_types) + .is_some() + { + let full_name = format!( + "{function_name}({})", + arg_types.iter().map(|t| t.to_string()).join(",") + ); + if if_not_exists { + Ok(Either::Right( + PgResponse::builder(stmt_type) + .notice(format!( + "function \"{}\" already exists, skipping", + full_name + )) + .into(), + )) + } else { + Err(CatalogError::Duplicated("function", full_name).into()) + } + } else { + Ok(Either::Left(())) + } + } + /// Also check if the user has the privilege to create in the schema. pub fn get_database_and_schema_id_for_create( &self, diff --git a/src/sqlparser/src/ast/mod.rs b/src/sqlparser/src/ast/mod.rs index 16496b71c97eb..98f8599ccf89a 100644 --- a/src/sqlparser/src/ast/mod.rs +++ b/src/sqlparser/src/ast/mod.rs @@ -1349,6 +1349,7 @@ pub enum Statement { CreateFunction { or_replace: bool, temporary: bool, + if_not_exists: bool, name: ObjectName, args: Option>, returns: Option, @@ -1361,6 +1362,7 @@ pub enum Statement { /// Postgres: CreateAggregate { or_replace: bool, + if_not_exists: bool, name: ObjectName, args: Vec, returns: DataType, @@ -1768,6 +1770,7 @@ impl fmt::Display for Statement { Statement::CreateFunction { or_replace, temporary, + if_not_exists, name, args, returns, @@ -1776,9 +1779,10 @@ impl fmt::Display for Statement { } => { write!( f, - "CREATE {or_replace}{temp}FUNCTION {name}", + "CREATE {or_replace}{temp}FUNCTION {if_not_exists}{name}", temp = if *temporary { "TEMPORARY " } else { "" }, or_replace = if *or_replace { "OR REPLACE " } else { "" }, + if_not_exists = if *if_not_exists { "IF NOT EXISTS " } else { "" }, )?; if let Some(args) = args { write!(f, "({})", display_comma_separated(args))?; @@ -1792,6 +1796,7 @@ impl fmt::Display for Statement { } Statement::CreateAggregate { or_replace, + if_not_exists, name, args, returns, @@ -1800,8 +1805,9 @@ impl fmt::Display for Statement { } => { write!( f, - "CREATE {or_replace}AGGREGATE {name}", + "CREATE {or_replace}AGGREGATE 
{if_not_exists}{name}", or_replace = if *or_replace { "OR REPLACE " } else { "" }, + if_not_exists = if *if_not_exists { "IF NOT EXISTS " } else { "" }, )?; write!(f, "({})", display_comma_separated(args))?; write!(f, " RETURNS {}", returns)?; @@ -3551,8 +3557,9 @@ mod tests { #[test] fn test_create_function_display() { let create_function = Statement::CreateFunction { - temporary: false, or_replace: false, + temporary: false, + if_not_exists: false, name: ObjectName(vec![Ident::new_unchecked("foo")]), args: Some(vec![OperateFunctionArg::unnamed(DataType::Int)]), returns: Some(CreateFunctionReturns::Value(DataType::Int)), @@ -3573,8 +3580,9 @@ mod tests { format!("{}", create_function) ); let create_function = Statement::CreateFunction { - temporary: false, or_replace: false, + temporary: false, + if_not_exists: false, name: ObjectName(vec![Ident::new_unchecked("foo")]), args: Some(vec![OperateFunctionArg::unnamed(DataType::Int)]), returns: Some(CreateFunctionReturns::Value(DataType::Int)), diff --git a/src/sqlparser/src/parser.rs b/src/sqlparser/src/parser.rs index 9eb3d9e439967..2df0183cf5c5a 100644 --- a/src/sqlparser/src/parser.rs +++ b/src/sqlparser/src/parser.rs @@ -2210,6 +2210,8 @@ impl Parser<'_> { or_replace: bool, temporary: bool, ) -> PResult { + impl_parse_to!(if_not_exists => [Keyword::IF, Keyword::NOT, Keyword::EXISTS], self); + let name = self.parse_object_name()?; self.expect_token(&Token::LParen)?; let args = if self.peek_token().token == Token::RParen { @@ -2248,6 +2250,7 @@ impl Parser<'_> { Ok(Statement::CreateFunction { or_replace, temporary, + if_not_exists, name, args, returns: return_type, @@ -2257,6 +2260,8 @@ impl Parser<'_> { } fn parse_create_aggregate(&mut self, or_replace: bool) -> PResult { + impl_parse_to!(if_not_exists => [Keyword::IF, Keyword::NOT, Keyword::EXISTS], self); + let name = self.parse_object_name()?; self.expect_token(&Token::LParen)?; let args = self.parse_comma_separated(Parser::parse_function_arg)?; @@ -2270,6 
+2275,7 @@ impl Parser<'_> { Ok(Statement::CreateAggregate { or_replace, + if_not_exists, name, args, returns, diff --git a/src/sqlparser/tests/sqlparser_postgres.rs b/src/sqlparser/tests/sqlparser_postgres.rs index 7acf6d29b4444..1466a9024a6d5 100644 --- a/src/sqlparser/tests/sqlparser_postgres.rs +++ b/src/sqlparser/tests/sqlparser_postgres.rs @@ -753,6 +753,7 @@ fn parse_create_function() { Statement::CreateFunction { or_replace: false, temporary: false, + if_not_exists: false, name: ObjectName(vec![Ident::new_unchecked("add")]), args: Some(vec![ OperateFunctionArg::unnamed(DataType::Int), @@ -777,6 +778,7 @@ fn parse_create_function() { Statement::CreateFunction { or_replace: false, temporary: false, + if_not_exists: false, name: ObjectName(vec![Ident::new_unchecked("sub")]), args: Some(vec![ OperateFunctionArg::unnamed(DataType::Int), @@ -801,6 +803,7 @@ fn parse_create_function() { Statement::CreateFunction { or_replace: false, temporary: false, + if_not_exists: false, name: ObjectName(vec![Ident::new_unchecked("return_test")]), args: Some(vec![ OperateFunctionArg::unnamed(DataType::Int), @@ -826,6 +829,7 @@ fn parse_create_function() { Statement::CreateFunction { or_replace: true, temporary: false, + if_not_exists: false, name: ObjectName(vec![Ident::new_unchecked("add")]), args: Some(vec![ OperateFunctionArg::with_name("a", DataType::Int), @@ -851,12 +855,14 @@ fn parse_create_function() { } ); - let sql = "CREATE FUNCTION unnest(a INT[]) RETURNS TABLE (x INT) LANGUAGE SQL RETURN a"; + let sql = + "CREATE TEMPORARY FUNCTION unnest(a INT[]) RETURNS TABLE (x INT) LANGUAGE SQL RETURN a"; assert_eq!( verified_stmt(sql), Statement::CreateFunction { or_replace: false, - temporary: false, + temporary: true, + if_not_exists: false, name: ObjectName(vec![Ident::new_unchecked("unnest")]), args: Some(vec![OperateFunctionArg::with_name( "a", @@ -874,6 +880,32 @@ fn parse_create_function() { with_options: Default::default(), } ); + + let sql = + "CREATE FUNCTION IF NOT 
EXISTS add(INT, INT) RETURNS INT LANGUAGE SQL IMMUTABLE AS 'select $1 + $2;'"; + assert_eq!( + verified_stmt(sql), + Statement::CreateFunction { + or_replace: false, + temporary: false, + if_not_exists: true, + name: ObjectName(vec![Ident::new_unchecked("add")]), + args: Some(vec![ + OperateFunctionArg::unnamed(DataType::Int), + OperateFunctionArg::unnamed(DataType::Int), + ]), + returns: Some(CreateFunctionReturns::Value(DataType::Int)), + params: CreateFunctionBody { + language: Some("SQL".into()), + behavior: Some(FunctionBehavior::Immutable), + as_: Some(FunctionDefinition::SingleQuotedDef( + "select $1 + $2;".into() + )), + ..Default::default() + }, + with_options: Default::default(), + } + ); } #[test] @@ -884,6 +916,27 @@ fn parse_create_aggregate() { verified_stmt(sql), Statement::CreateAggregate { or_replace: true, + if_not_exists: false, + name: ObjectName(vec![Ident::new_unchecked("sum")]), + args: vec![OperateFunctionArg::unnamed(DataType::Int)], + returns: DataType::BigInt, + append_only: true, + params: CreateFunctionBody { + language: Some("python".into()), + as_: Some(FunctionDefinition::SingleQuotedDef("sum".into())), + using: Some(CreateFunctionUsing::Link("xxx".into())), + ..Default::default() + }, + } + ); + + let sql = + "CREATE AGGREGATE IF NOT EXISTS sum(INT) RETURNS BIGINT APPEND ONLY LANGUAGE python AS 'sum' USING LINK 'xxx'"; + assert_eq!( + verified_stmt(sql), + Statement::CreateAggregate { + or_replace: false, + if_not_exists: true, name: ObjectName(vec![Ident::new_unchecked("sum")]), args: vec![OperateFunctionArg::unnamed(DataType::Int)], returns: DataType::BigInt,