diff --git a/Cargo.lock b/Cargo.lock index ae544b3f7b27..3f3d8d876481 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2603,7 +2603,6 @@ dependencies = [ "criterion", "databend-common-base", "databend-common-exception", - "databend-common-functions", "databend-common-io", "databend-common-meta-app", "derive-visitor", diff --git a/src/query/ast/Cargo.toml b/src/query/ast/Cargo.toml index ca7e660b6739..446fe7c706b7 100644 --- a/src/query/ast/Cargo.toml +++ b/src/query/ast/Cargo.toml @@ -13,7 +13,6 @@ doctest = false # Workspace dependencies databend-common-base = { path = "../../common/base" } databend-common-exception = { path = "../../common/exception" } -databend-common-functions = { path = "../functions" } databend-common-io = { path = "../../common/io" } databend-common-meta-app = { path = "../../meta/app" } diff --git a/src/query/ast/src/ast/expr.rs b/src/query/ast/src/ast/expr.rs index 5a695c412157..5c324957eb6c 100644 --- a/src/query/ast/src/ast/expr.rs +++ b/src/query/ast/src/ast/expr.rs @@ -18,7 +18,6 @@ use std::fmt::Formatter; use databend_common_exception::ErrorCode; use databend_common_exception::Result; use databend_common_exception::Span; -use databend_common_functions::aggregates::AggregateFunctionFactory; use databend_common_io::display_decimal_256; use databend_common_io::escape_string_with_quote; use derive_visitor::Drive; @@ -1386,90 +1385,3 @@ pub fn split_equivalent_predicate_expr(expr: &Expr) -> Option<(Expr, Expr)> { _ => None, } } - -// If contain agg function in Expr -pub fn contain_agg_func(expr: &Expr) -> bool { - match expr { - Expr::ColumnRef { .. } => false, - Expr::IsNull { expr, .. } => contain_agg_func(expr), - Expr::IsDistinctFrom { left, right, .. } => { - contain_agg_func(left) || contain_agg_func(right) - } - Expr::InList { expr, list, .. } => { - contain_agg_func(expr) || list.iter().any(contain_agg_func) - } - Expr::InSubquery { expr, .. } => contain_agg_func(expr), - Expr::Between { - expr, low, high, .. - } => contain_agg_func(expr) || contain_agg_func(low) || contain_agg_func(high), - Expr::BinaryOp { left, right, .. } => contain_agg_func(left) || contain_agg_func(right), - Expr::JsonOp { left, right, .. } => contain_agg_func(left) || contain_agg_func(right), - Expr::UnaryOp { expr, .. } => contain_agg_func(expr), - Expr::Cast { expr, .. } => contain_agg_func(expr), - Expr::TryCast { expr, .. } => contain_agg_func(expr), - Expr::Extract { expr, .. } => contain_agg_func(expr), - Expr::DatePart { expr, .. } => contain_agg_func(expr), - Expr::Position { - substr_expr, - str_expr, - .. - } => contain_agg_func(substr_expr) || contain_agg_func(str_expr), - Expr::Substring { - expr, - substring_for, - substring_from, - .. - } => { - if let Some(substring_for) = substring_for { - contain_agg_func(expr) || contain_agg_func(substring_for) - } else { - contain_agg_func(expr) || contain_agg_func(substring_from) - } - } - Expr::Trim { expr, .. } => contain_agg_func(expr), - Expr::Literal { .. } => false, - Expr::CountAll { .. } => false, - Expr::Tuple { exprs, .. } => exprs.iter().any(contain_agg_func), - Expr::FunctionCall { func, .. } => { - AggregateFunctionFactory::instance().contains(func.name.to_string()) - } - Expr::Case { - operand, - conditions, - results, - else_result, - .. - } => { - if let Some(operand) = operand { - if contain_agg_func(operand) { - return true; - } - } - if conditions.iter().any(contain_agg_func) { - return true; - } - if results.iter().any(contain_agg_func) { - return true; - } - if let Some(else_result) = else_result { - if contain_agg_func(else_result) { - return true; - } - } - false - } - Expr::Exists { .. } => false, - Expr::Subquery { .. } => false, - Expr::MapAccess { expr, .. } => contain_agg_func(expr), - Expr::Array { exprs, .. } => exprs.iter().any(contain_agg_func), - Expr::Map { kvs, .. } => kvs.iter().any(|(_, v)| contain_agg_func(v)), - Expr::Interval { expr, .. } => contain_agg_func(expr), - Expr::DateAdd { interval, date, .. } => { - contain_agg_func(interval) || contain_agg_func(date) - } - Expr::DateSub { interval, date, .. } => { - contain_agg_func(interval) || contain_agg_func(date) - } - Expr::DateTrunc { date, .. } => contain_agg_func(date), - } -} diff --git a/src/query/sql/src/planner/semantic/type_check.rs b/src/query/sql/src/planner/semantic/type_check.rs index c3bf1bfec370..eeea971ae526 100644 --- a/src/query/sql/src/planner/semantic/type_check.rs +++ b/src/query/sql/src/planner/semantic/type_check.rs @@ -17,7 +17,6 @@ use std::collections::VecDeque; use std::sync::Arc; use std::vec; -use databend_common_ast::ast::contain_agg_func; use databend_common_ast::ast::BinaryOperator; use databend_common_ast::ast::ColumnID; use databend_common_ast::ast::ColumnRef; @@ -80,6 +79,8 @@ use databend_common_meta_app::principal::UDFDefinition; use databend_common_meta_app::principal::UDFScript; use databend_common_meta_app::principal::UDFServer; use databend_common_users::UserApiProvider; +use derive_visitor::Drive; +use derive_visitor::Visitor; use indexmap::IndexMap; use itertools::Itertools; use jsonb::keypath::KeyPath; @@ -2362,7 +2363,21 @@ impl<'a> TypeChecker<'a> { let select = &select_stmt.select_list[0]; if let SelectTarget::AliasedExpr { expr, .. } = select { // Check if contain aggregation function - contain_agg = Some(contain_agg_func(expr)); + #[derive(Visitor)] + #[visitor(ASTFunctionCall(enter))] + struct AggFuncVisitor { + contain_agg: bool, + } + impl AggFuncVisitor { + fn enter_ast_function_call(&mut self, func: &ASTFunctionCall) { + self.contain_agg = self.contain_agg + || AggregateFunctionFactory::instance() + .contains(func.name.to_string()); + } + } + let mut visitor = AggFuncVisitor { contain_agg: false }; + expr.drive(&mut visitor); + contain_agg = Some(visitor.contain_agg); } } }