Skip to content

Commit

Permalink
add udf
Browse files Browse the repository at this point in the history
  • Loading branch information
YinZheng-Sun committed Jan 15, 2025
1 parent 64a80f1 commit 3f011eb
Show file tree
Hide file tree
Showing 26 changed files with 650 additions and 1 deletion.
2 changes: 2 additions & 0 deletions metadata/INFORMATION_SCHEMA.sql
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
ATTACH DATABASE INFORMATION_SCHEMA
ENGINE = Memory
1 change: 1 addition & 0 deletions metadata/default
2 changes: 2 additions & 0 deletions metadata/default.sql
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
ATTACH DATABASE _ UUID '5abc80be-7165-462e-9e28-94d4467dd408'
ENGINE = Atomic
2 changes: 2 additions & 0 deletions metadata/information_schema.sql
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
ATTACH DATABASE information_schema
ENGINE = Memory
1 change: 1 addition & 0 deletions metadata/system
2 changes: 2 additions & 0 deletions metadata/system.sql
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
ATTACH DATABASE _ UUID '99a3fa29-6efd-4622-b822-ba63a202d9f4'
ENGINE = Atomic
44 changes: 44 additions & 0 deletions preprocessed_configs/config.xml
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
<!-- This file was generated automatically.
Do not edit it: it is likely to be discarded and generated again before it's read next time.
Files used to generate this file:
config.xml -->

<!-- Config that is used when server is run without config file. -->
<clickhouse>
<logger>
<level>trace</level>
<console>true</console>
</logger>

<http_port>8123</http_port>
<tcp_port>9000</tcp_port>
<mysql_port>9004</mysql_port>

<path>./</path>

<mlock_executable>true</mlock_executable>

<users>
<default>
<password/>

<networks>
<ip>::/0</ip>
</networks>

<profile>default</profile>
<quota>default</quota>

<access_management>1</access_management>
<named_collection_control>1</named_collection_control>
</default>
</users>

<profiles>
<default/>
</profiles>

<quotas>
<default/>
</quotas>
</clickhouse>
2 changes: 1 addition & 1 deletion programs/server/config.xml
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
[1]: https://github.com/pocoproject/poco/blob/poco-1.9.4-release/Foundation/include/Poco/Logger.h#L105-L114
-->
<level>trace</level>
<level>debug</level>
<log>/var/log/clickhouse-server/clickhouse-server.log</log>
<errorlog>/var/log/clickhouse-server/clickhouse-server.err.log</errorlog>
<!-- Rotation policy
Expand Down
20 changes: 20 additions & 0 deletions programs/server/udf.sql
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
-- Start from a clean slate so the script can be re-run.
DROP FUNCTION IF EXISTS ai_estimate_rows;

-- Python UDF that forwards a SQL text to BayesCard's test_inference_result()
-- and returns its result as UInt64.
CREATE FUNCTION imdb.ai_estimate_rows
RETURNS UInt64
LANGUAGE PYTHON AS
$code$
from iudf import IUDF
from overload import overload
from BayesCard.Evaluation.cardinality_estimation import test_inference_result

class ai_estimate_rows(IUDF):
    @overload
    def process(sql):
        res = test_inference_result(sql)
        return res
$code$;

-- Smoke test. This has to run while the function still exists, i.e. before the cleanup below.
select imdb.ai_estimate_rows('SELECT COUNT(*) FROM movie_info_idx WHERE id >= 713168 AND movie_id >= 2075341');

-- Cleanup: previously this DROP was placed before the SELECT, removing the function before it was called.
DROP FUNCTION IF EXISTS ai_estimate_rows;
12 changes: 12 additions & 0 deletions programs/server/user_defined/function_ai_estimate_rows.sql
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
CREATE FUNCTION imdb.ai_estimate_rows TYPE SCALAR RETURNS UInt64 LANGUAGE PYTHON AS
$code$
from iudf import IUDF
from overload import overload
from BayesCard.Evaluation.cardinality_estimation import test_inference_result

class ai_estimate_rows(IUDF):
    @overload
    def process(sql):
        res = test_inference_result(sql)
        return res
$code$
15 changes: 15 additions & 0 deletions src/Common/BitHelpers.h
Original file line number Diff line number Diff line change
Expand Up @@ -127,3 +127,18 @@ constexpr bool isPowerOf2(T number)
{
return number > 0 && (number & (number - 1)) == 0;
}

/// Converts an enumerator to the value of its underlying integer type
/// (the type is taken from std::underlying_type_t, so no widening or narrowing happens).
template <typename Enum>
constexpr auto to_underlying(Enum e) noexcept
{
    using Underlying = std::underlying_type_t<Enum>;
    return static_cast<Underlying>(e);
}

/// __bf_shf(x): zero-based position of the least significant set bit of x.
/// __builtin_ffsll() is one-based and returns 0 for x == 0, so this evaluates to -1 for a zero argument.
/// The macro name starts with a double underscore (a reserved identifier), hence the locally silenced warning.
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Wreserved-macro-identifier"
#define __bf_shf(x) (__builtin_ffsll(x) - 1)
#pragma clang diagnostic pop

/// GENMASK(h, l): `unsigned long` mask with bits l..h (both inclusive) set.
/// The left operand keeps bit l and everything above it, the right operand keeps bit h and everything below it.
#define GENMASK(h, l) (((~0UL) - (1UL << (l)) + 1) & (~0UL >> (sizeof(long)*8 - 1 - (h))))
/// FIELD_GET(mask, reg): extracts the bits of `reg` selected by `mask`, shifted down so the field starts at bit 0.
/// The result is cast to the type of `mask`. `mask` must be non-zero (see __bf_shf).
#define FIELD_GET(_mask, _reg) (decltype(_mask))(((_reg) & (_mask)) >> __bf_shf(_mask))
/// FIELD_PREP(mask, val): casts `val` to the type of `mask`, shifts it up to the field position
/// and drops any bits that fall outside `mask`. `mask` must be non-zero (see __bf_shf).
#define FIELD_PREP(_mask, _val) (((decltype(_mask))(_val) << __bf_shf(_mask)) & (_mask))
76 changes: 76 additions & 0 deletions src/Functions/UDFFlags.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
#pragma once
#include <cstdint>
#include <limits>
#include <string>
#include <string.h>
#include <Common/BitHelpers.h>
#include <IO/ReadBufferFromMemory.h>



namespace DB::UDF
{
/// Spellings of the supported UDF languages and function types.
constexpr auto SQL = "SQL";
constexpr auto PYTHON = "PYTHON";
constexpr auto SCALAR = "SCALAR";
constexpr auto AGGREGATE = "AGGREGATE";
constexpr auto ML = "ML";

/// Packed UDF attributes. The function type lives in bits 25..28 and the language in bits 29..31
/// (see UDF_FIELD_TYPE / UDF_FIELD_LANG below).
using UDFFlags = uint32_t;

enum class UDFLanguage
{
    Sql,
    Python
};

enum class UDFFunctionType
{
    Scalar,
    Aggregate,
    ML
};

/// Name tables, indexed by the numeric value of the corresponding enumerator.
static const char * Languages[2] = {SQL, PYTHON};
static const char * FunctionTypes[3] = {SCALAR, AGGREGATE, ML};

#define UDF_FIELD_TYPE GENMASK(28, 25)
#define UDF_FIELD_LANG GENMASK(31, 29)

/// Returns the index of `enum_name` within the first `n` entries of `arr` (exact, case-sensitive match),
/// or -1 if it is not present. `n` is an element count, not a byte size.
/// Null `enum_name`, null `arr` and null entries are treated as "no match".
static inline int getEnumValFromString(const char * enum_name, const char * const * arr, size_t n)
{
    if (enum_name == nullptr || arr == nullptr)
        return -1;

    for (size_t i = 0; i < n; ++i)
    {
        if (arr[i] != nullptr && strcmp(arr[i], enum_name) == 0)
            return static_cast<int>(i);
    }
    return -1;
}

/// FIELD_GET / FIELD_PREP expand to C-style casts, so the warning is silenced for the functions that use them.
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Wold-style-cast"
/// Decodes the language field of `flags`. The raw field value is converted without a range check.
static inline UDFLanguage getLanguage(UDFFlags flags)
{
    return static_cast<UDFLanguage>(FIELD_GET(UDF_FIELD_LANG, flags));
}

/// Decodes the function type field of `flags`. The raw field value is converted without a range check.
static inline UDFFunctionType getFunctionType(UDFFlags flags)
{
    return static_cast<UDFFunctionType>(FIELD_GET(UDF_FIELD_TYPE, flags));
}

/// Looks `value` up in the first `n` entries of `enums` and stores its index in the field of `flags`
/// selected by `mask`. Returns false (leaving `flags` untouched) if `value` is unknown or `mask` is zero.
static inline bool setFlag(UDFFlags &flags, const char *value, const char** enums, size_t n, uint64_t mask)
{
    /// A zero mask has no lowest set bit: the shift inside FIELD_PREP would be by -1.
    if (mask == 0)
        return false;

    const int index = getEnumValFromString(value, enums, n);
    if (index < 0)
        return false;

    /// Clear the field first so that setting it a second time replaces the old value
    /// instead of OR-ing the two indices together.
    const uint64_t field = FIELD_PREP(mask, index);
    flags = static_cast<UDFFlags>((flags & ~mask) | field);
    return true;
}
#pragma clang diagnostic pop

/// Name of `language`. The argument must be one of the declared enumerators; the table is indexed unchecked.
static inline const char * getLanguageStr(UDFLanguage language)
{
    return Languages[to_underlying(language)];
}

/// Name of `func_type`. The argument must be one of the declared enumerators; the table is indexed unchecked.
static inline const char * getFunctionTypeStr(UDFFunctionType func_type)
{
    return FunctionTypes[to_underlying(func_type)];
}

/// Sets the function type field of `flags` from its name ("SCALAR", "AGGREGATE", "ML").
/// Returns false if the name is not recognised.
static inline bool setFunctionType(UDFFlags &flags, const char *type)
{
    /// Element count, not sizeof(FunctionTypes): the latter is a byte size and made the lookup read past the table.
    return setFlag(flags, type, FunctionTypes, sizeof(FunctionTypes) / sizeof(FunctionTypes[0]), UDF_FIELD_TYPE);
}

/// Sets the language field of `flags` from its name ("SQL", "PYTHON").
/// Returns false if the name is not recognised.
static inline bool setLanguage(UDFFlags &flags, const char *language)
{
    /// Element count, not sizeof(Languages): the latter is a byte size and made the lookup read past the table.
    return setFlag(flags, language, Languages, sizeof(Languages) / sizeof(Languages[0]), UDF_FIELD_LANG);
}

/// Misspelled original name, kept so existing callers continue to compile. Prefer setLanguage().
static inline bool setLangauge(UDFFlags &flags, const char *language)
{
    return setLanguage(flags, language);
}
}
53 changes: 53 additions & 0 deletions src/Functions/UserDefined/UserDefinedSQLFunctionFactory.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
#include <Interpreters/Context.h>
#include <Interpreters/FunctionNameNormalizer.h>
#include <Parsers/ASTCreateFunctionQuery.h>
#include <Parsers/ASTCreateUserDefinedFunctionQuery.h>
#include <Parsers/ASTFunction.h>
#include <Parsers/ASTIdentifier.h>
#include <Common/quoteString.h>
Expand Down Expand Up @@ -150,6 +151,34 @@ bool UserDefinedSQLFunctionFactory::registerFunction(const ContextMutablePtr & c
return true;
}

/// Hands `create_function_query` to the user defined SQL objects storage under `function_name`
/// and returns whatever the storage reports (false means nothing was stored).
/// Exceptions thrown by the storage are annotated with the function name and rethrown.
bool UserDefinedSQLFunctionFactory::registerUserDefinedFunction(const ContextMutablePtr & context, const String & function_name, ASTPtr create_function_query)
{
    /// Validation and normalization are currently disabled for this entry point:
    // checkCanBeRegistered(context, function_name, *create_function_query);
    // create_function_query = normalizeCreateFunctionQuery(*create_function_query);

    try
    {
        /// The two literal flags are presumably throw_if_exists = false and replace_if_exists = true
        /// (TODO confirm against the storeObject() declaration).
        return context->getUserDefinedSQLObjectsStorage().storeObject(
            context,
            UserDefinedSQLObjectType::Function,
            function_name,
            create_function_query,
            false,
            true,
            context->getSettingsRef());
    }
    catch (Exception & exception)
    {
        exception.addMessage(fmt::format("while storing user defined function {}", backQuote(function_name)));
        throw;
    }
}

bool UserDefinedSQLFunctionFactory::unregisterFunction(const ContextMutablePtr & context, const String & function_name, bool throw_if_not_exists)
{
checkCanBeUnregistered(context, function_name);
Expand All @@ -174,6 +203,30 @@ bool UserDefinedSQLFunctionFactory::unregisterFunction(const ContextMutablePtr &
return true;
}

/// Asks the user defined SQL objects storage to remove the function stored as `function_name`
/// and returns whatever the storage reports (false means nothing was removed).
/// Exceptions thrown by the storage are annotated with the function name and rethrown.
bool UserDefinedSQLFunctionFactory::unregisterUserDefinedFunction(const ContextMutablePtr & context, const String & function_name)
{
    /// The removal pre-check is currently disabled for this entry point:
    // checkCanBeUnregistered(context, function_name);

    try
    {
        /// The trailing literal is presumably throw_if_not_exists = false
        /// (TODO confirm against the removeObject() declaration).
        return context->getUserDefinedSQLObjectsStorage().removeObject(
            context,
            UserDefinedSQLObjectType::Function,
            function_name,
            false);
    }
    catch (Exception & exception)
    {
        exception.addMessage(fmt::format("while removing user defined function {}", backQuote(function_name)));
        throw;
    }
}

ASTPtr UserDefinedSQLFunctionFactory::get(const String & function_name) const
{
return global_context->getUserDefinedSQLObjectsStorage().get(function_name);
Expand Down
3 changes: 3 additions & 0 deletions src/Functions/UserDefined/UserDefinedSQLFunctionFactory.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,12 @@ class UserDefinedSQLFunctionFactory : public IHints<>
/// Register function for function_name in factory for specified create_function_query.
bool registerFunction(const ContextMutablePtr & context, const String & function_name, ASTPtr create_function_query, bool throw_if_exists, bool replace_if_exists);

/// Store create_function_query for function_name in the user defined SQL objects storage.
/// Returns false if the storage reports that nothing was stored.
bool registerUserDefinedFunction(const ContextMutablePtr & context, const String & function_name, ASTPtr create_function_query);
/// Unregister function for function_name.
bool unregisterFunction(const ContextMutablePtr & context, const String & function_name, bool throw_if_not_exists);

/// Remove the object stored for function_name from the user defined SQL objects storage.
/// Returns false if the storage reports that nothing was removed.
bool unregisterUserDefinedFunction(const ContextMutablePtr & context, const String & function_name);

/// Get function create query for function_name. If no function registered with function_name throws exception.
ASTPtr get(const String & function_name) const;

Expand Down
29 changes: 29 additions & 0 deletions src/IO/ReadHelpers.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1050,6 +1050,35 @@ template void readJSONStringInto<NullOutput>(NullOutput & s, ReadBuffer & buf);
template void readJSONStringInto<String>(String & s, ReadBuffer & buf);
template bool readJSONStringInto<String, bool>(String & s, ReadBuffer & buf);

/** Reads a dollar-quoted literal `$tag$ ... $tag$` (the tag may be empty: `$$ ... $$`)
  * from the `size` bytes that start at buf.position().
  *
  * On success the opening delimiter (including both '$') is appended to `tag`, the body is appended to `s`
  * and the read position is left right after the opening delimiter, i.e. at the start of the body.
  * The body and the closing delimiter are not consumed.
  *
  * Returns false if the input is empty, does not start with '$', has no terminated opening delimiter,
  * or has no closing delimiter within `size` bytes. In the last case the opening delimiter has already
  * been appended to `tag` and skipped in `buf`.
  *
  * Only the [position, position + size) range is inspected; the data need not be NUL-terminated.
  */
bool readDollarQuotedStringInto(String & s, ReadBuffer & buf, String & tag, size_t size)
{
    const char * const begin = buf.position();
    const char * const end = begin + size;

    if (size == 0 || *begin != '$')
        return false;

    /// find_first_symbols() returns `end` (not nullptr) when nothing is found.
    const char * const tag_end = find_first_symbols<'$'>(begin + 1, end);
    if (tag_end == nullptr || tag_end == end)
        return false;

    /// Length of the delimiter just parsed. Using this instead of tag.length() keeps the function
    /// correct when `tag` already has content on entry.
    const size_t tag_length = static_cast<size_t>(tag_end + 1 - begin);

    appendToStringOrVector(tag, buf, tag_end + 1);
    buf.position() += tag_length;

    /// Bounded search for the closing delimiter: every '$' in the remaining range is a candidate.
    const char * const body_begin = buf.position();
    const char * closing = nullptr;
    for (const char * candidate = find_first_symbols<'$'>(body_begin, end);
         candidate != end && static_cast<size_t>(end - candidate) >= tag_length;
         candidate = find_first_symbols<'$'>(candidate + 1, end))
    {
        if (memcmp(candidate, begin, tag_length) == 0)
        {
            closing = candidate;
            break;
        }
    }

    if (closing == nullptr)
        return false;

    /// NOTE(review): as before, the byte immediately preceding the closing delimiter is not copied
    /// (presumably the newline that precedes `$tag$` - confirm with the caller). An empty body copies nothing.
    if (closing > body_begin)
        appendToStringOrVector(s, buf, closing - 1);

    return true;
}

template <typename Vector, typename ReturnType, char opening_bracket, char closing_bracket>
ReturnType readJSONObjectOrArrayPossiblyInvalid(Vector & s, ReadBuffer & buf)
{
Expand Down
2 changes: 2 additions & 0 deletions src/IO/ReadHelpers.h
Original file line number Diff line number Diff line change
Expand Up @@ -668,6 +668,8 @@ bool tryReadJSONStringInto(Vector & s, ReadBuffer & buf)
return readJSONStringInto<Vector, bool>(s, buf);
}

/// Reads a dollar-quoted literal (`$tag$ ... $tag$`) starting at buf.position().
/// The opening delimiter is appended to `tag` and the body to `s`; `size` bounds the search for the
/// end of the opening delimiter, counted from the current position.
/// Returns false if the input does not start with '$' or no closing delimiter is found.
bool readDollarQuotedStringInto(String & s, ReadBuffer & buf, String & tag, size_t size);

template <bool enable_sql_style_quoting, typename Vector>
bool tryReadQuotedStringInto(Vector & s, ReadBuffer & buf);

Expand Down
Loading

0 comments on commit 3f011eb

Please sign in to comment.