From abe6e886c2a342144679c4614bbe2efbf585941f Mon Sep 17 00:00:00 2001 From: Mingzhuo Yin Date: Wed, 15 Jan 2025 13:27:04 +0800 Subject: [PATCH] refactor: api about tokenizer Signed-off-by: Mingzhuo Yin --- Cargo.toml | 3 + README.md | 25 +- src/datatype/functions.rs | 103 +------ src/guc.rs | 12 - src/sql/finalize.sql | 43 +-- src/sql/tokenizer.sql | 38 +++ src/token.rs | 366 ++++++++++++++++++++--- tests/sqllogictest/delete.slt | 10 +- tests/sqllogictest/empty.slt | 8 +- tests/sqllogictest/index.slt | 10 +- tests/sqllogictest/temp.slt | 10 +- tests/sqllogictest/tokenizer.slt | 25 ++ tests/sqllogictest/unicode_tokenizer.slt | 88 ++++++ tests/sqllogictest/unlogged.slt | 8 +- tokenizer.md | 59 ++-- 15 files changed, 557 insertions(+), 251 deletions(-) create mode 100644 src/sql/tokenizer.sql create mode 100644 tests/sqllogictest/tokenizer.slt create mode 100644 tests/sqllogictest/unicode_tokenizer.slt diff --git a/Cargo.toml b/Cargo.toml index 4814c51..344f8e4 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -34,8 +34,11 @@ tantivy-stemmers = { version = "0.4.0", features = [ thiserror = "2" tokenizers = { version = "0.20", default-features = false, features = ["onig"] } +serde = { version = "1.0.217", features = ["derive"] } tocken = "0.1.0" +toml = "0.8.19" unicode-segmentation = "1.12.0" +validator = { version = "0.19.0", features = ["derive"] } [dev-dependencies] rand = "0.8" diff --git a/README.md b/README.md index 93d51bb..1fac1d8 100644 --- a/README.md +++ b/README.md @@ -7,7 +7,8 @@ A PostgreSQL extension for bm25 ranking algorithm. We implemented the Block-Weak ```sql CREATE TABLE documents ( id SERIAL PRIMARY KEY, - passage TEXT + passage TEXT, + embedding bm25vector ); INSERT INTO documents (passage) VALUES @@ -22,13 +23,11 @@ INSERT INTO documents (passage) VALUES ('Relational databases such as PostgreSQL can handle both structured and unstructured data.'), ('Effective search ranking algorithms, such as BM25, improve search results by understanding relevance.'); -ALTER TABLE documents ADD COLUMN embedding bm25vector; - -UPDATE documents SET embedding = tokenize(passage); +UPDATE documents SET embedding = tokenize(passage, 'Bert'); CREATE INDEX documents_embedding_bm25 ON documents USING bm25 (embedding bm25_ops); -SELECT id, passage, embedding <&> to_bm25query('documents_embedding_bm25', 'PostgreSQL') AS rank +SELECT id, passage, embedding <&> to_bm25query('documents_embedding_bm25', 'PostgreSQL', 'Bert') AS rank FROM documents ORDER BY rank LIMIT 10; @@ -92,22 +91,20 @@ CREATE EXTENSION vchord_bm25; ### Functions -- `tokenize(text) RETURNS bm25vector`: Tokenize the input text into a BM25 vector. -- `to_bm25query(index_name regclass, query text) RETURNS bm25query`: Convert the input text into a BM25 query. +- `create_tokenizer(tokenizer_name text, config text)`: Create a tokenizer with the given name and configuration. +- `create_unicode_tokenizer_and_trigger(tokenizer_name text, table_name text, source_column text, target_column text)`: Create a Unicode tokenizer and trigger function for the given table and columns. It will automatically build the tokenizer according to source_column and store the result in target_column. +- `drop_tokenizer(tokenizer_name text)`: Drop the tokenizer with the given name. +- `tokenize(content text, tokenizer_name text) RETURNS bm25vector`: Tokenize the content text into a BM25 vector. +- `to_bm25query(index_name regclass, query text, tokenizer_name text) RETURNS bm25query`: Convert the input text into a BM25 query. 
 - `bm25vector <&> bm25query RETURNS float4`: Calculate the **negative** BM25 score between the BM25 vector and query.
-- `unicode_tokenizer_trigger(text_column text, vec_column text, stored_token_table text) RETURNS TRIGGER`: A trigger function to tokenize the `text_column`, store the vector in `vec_column`, and store the new tokens in the `bm25_catalog.stored_token_table`. For more information, check the [tokenizer](./tokenizer.md) document.
-- `document_unicode_tokenize(content text, stored_token_table text) RETURNS bm25vector`: tokenize the `content` and store the new tokens in the `bm25_catalog.stored_token_table`. For more information, check the [tokenizer](./tokenizer.md) document.
-- `bm25_query_unicode_tokenize(index_name regclass, query text, stored_token_table text) RETURNS bm25query`: Tokenize the `query` into a BM25 query vector according to the tokens stored in `stored_token_table`. For more information, check the [tokenizer](./tokenizer.md) document.
+
+For more information about tokenizers, check the [tokenizer](./tokenizer.md) document.
 
 ### GUCs
 
 - `bm25_catalog.bm25_limit (integer)`: The maximum number of documents to return in a search. Default is 1, minimum is 1, and maximum is 65535.
 - `bm25_catalog.enable_index (boolean)`: Whether to enable the bm25 index. Default is false.
 - `bm25_catalog.segment_growing_max_page_size (integer)`: The maximum page count of the growing segment. When the size of the growing segment exceeds this value, the segment will be sealed into a read-only segment. Default is 1, minimum is 1, and maximum is 1,000,000.
-- `bm25_catalog.tokenizer (text)`: Tokenizer chosen from:
-  - `BERT`: default uncased BERT tokenizer.
-  - `TOCKEN`: a Unicode tokenizer pre-trained on wiki-103-raw.
-  - `UNICODE`: a Unicode tokenizer that will be trained on your data.
(need to work with the trigger function `unicode_tokenizer_trigger`) ## Contribution diff --git a/src/datatype/functions.rs b/src/datatype/functions.rs index 892db15..e5f2c4e 100644 --- a/src/datatype/functions.rs +++ b/src/datatype/functions.rs @@ -1,22 +1,13 @@ -use std::{collections::HashMap, num::NonZero}; - -use pgrx::{pg_sys::panic::ErrorReportable, IntoDatum}; +use std::num::NonZero; use crate::{ page::{page_read, METAPAGE_BLKNO}, segment::{meta::MetaPageData, term_stat::TermStatReader}, - token::unicode_tokenize, weight::bm25_score_batch, }; use super::memory_bm25vector::{Bm25VectorInput, Bm25VectorOutput}; -#[pgrx::pg_extern(immutable, strict, parallel_safe)] -pub fn tokenize(text: &str) -> Bm25VectorOutput { - let term_ids = crate::token::tokenize(text); - Bm25VectorOutput::from_ids(&term_ids) -} - #[pgrx::pg_extern(stable, strict, parallel_safe)] pub fn search_bm25query( target_vector: Bm25VectorInput, @@ -52,95 +43,3 @@ pub fn search_bm25query( scores * -1.0 } - -#[pgrx::pg_extern()] -pub fn document_unicode_tokenize(text: &str, token_table: &str) -> Bm25VectorOutput { - let tokens = unicode_tokenize(text); - let args = Some(vec![( - pgrx::PgBuiltInOids::TEXTARRAYOID.oid(), - tokens.clone().into_datum(), - )]); - - let mut token_ids = HashMap::new(); - pgrx::Spi::connect(|mut client| { - let query = format!( - r#" - WITH new_tokens AS (SELECT unnest($1::text[]) AS token), - to_insert AS ( - SELECT token FROM new_tokens - WHERE NOT EXISTS ( - SELECT 1 FROM bm25_catalog.{} WHERE token = new_tokens.token - ) - ), - ins AS ( - INSERT INTO bm25_catalog.{} (token) - SELECT token FROM to_insert - ON CONFLICT (token) DO NOTHING - RETURNING id, token - ) - SELECT id, token FROM ins - UNION ALL - SELECT id, token FROM bm25_catalog.{} WHERE token = ANY($1); - "#, - token_table, token_table, token_table - ); - let table = client.update(&query, None, args).unwrap_or_report(); - for row in table { - let id: i32 = row - .get_by_name("id") - .expect("no id column") - .expect("no id value"); - let token: String = row - .get_by_name("token") - .expect("no token column") - .expect("no token value"); - token_ids.insert(token, id as u32); - } - }); - - let ids = tokens - .iter() - .map(|t| *token_ids.get(t).expect("unknown token")) - .collect::>(); - Bm25VectorOutput::from_ids(&ids) -} - -#[pgrx::pg_extern(immutable, strict, parallel_safe)] -pub fn query_unicode_tokenize(query: &str, token_table: &str) -> Bm25VectorOutput { - let tokens = unicode_tokenize(query); - let args = Some(vec![( - pgrx::PgBuiltInOids::TEXTARRAYOID.oid(), - tokens.clone().into_datum(), - )]); - let mut token_ids = HashMap::new(); - pgrx::Spi::connect(|client| { - let table = client - .select( - &format!( - "SELECT id, token FROM bm25_catalog.{} WHERE token = ANY($1);", - token_table - ), - None, - args, - ) - .unwrap_or_report(); - for row in table { - let id: i32 = row - .get_by_name("id") - .expect("no id column") - .expect("no id value"); - let token: String = row - .get_by_name("token") - .expect("no token column") - .expect("no token value"); - token_ids.insert(token, id as u32); - } - }); - - let ids = tokens - .iter() - .filter(|&t| token_ids.contains_key(t)) - .map(|t| *token_ids.get(t).unwrap()) - .collect::>(); - Bm25VectorOutput::from_ids(&ids) -} diff --git a/src/guc.rs b/src/guc.rs index be75e4f..1718684 100644 --- a/src/guc.rs +++ b/src/guc.rs @@ -1,12 +1,8 @@ -use std::ffi::CStr; - use pgrx::{GucContext, GucFlags, GucRegistry, GucSetting}; pub static BM25_LIMIT: GucSetting = GucSetting::::new(100); pub static 
ENABLE_INDEX: GucSetting = GucSetting::::new(true); pub static SEGMENT_GROWING_MAX_PAGE_SIZE: GucSetting = GucSetting::::new(1000); -pub static TOKENIZER_NAME: GucSetting> = - GucSetting::>::new(Some(c"BERT")); pub unsafe fn init() { GucRegistry::define_int_guc( @@ -37,12 +33,4 @@ pub unsafe fn init() { GucContext::Userset, GucFlags::default(), ); - GucRegistry::define_string_guc( - "bm25_catalog.tokenizer", - "tokenizer name", - "tokenizer name", - &TOKENIZER_NAME, - GucContext::Userset, - GucFlags::default(), - ); } diff --git a/src/sql/finalize.sql b/src/sql/finalize.sql index f9e389a..1c0a760 100644 --- a/src/sql/finalize.sql +++ b/src/sql/finalize.sql @@ -13,9 +13,9 @@ CREATE TYPE bm25query AS ( query_vector bm25vector ); -CREATE FUNCTION to_bm25query(index_oid regclass, query_str text) RETURNS bm25query - IMMUTABLE STRICT PARALLEL SAFE LANGUAGE sql AS $$ - SELECT index_oid, tokenize(query_str); +CREATE FUNCTION to_bm25query(index_oid regclass, query_str text, tokenizer_name text) RETURNS bm25query + STABLE STRICT PARALLEL SAFE LANGUAGE sql AS $$ + SELECT index_oid, tokenize(query_str, tokenizer_name); $$; CREATE ACCESS METHOD bm25 TYPE INDEX HANDLER _bm25_amhandler; @@ -31,40 +31,3 @@ CREATE OPERATOR FAMILY bm25_ops USING bm25; CREATE OPERATOR CLASS bm25_ops FOR TYPE bm25vector USING bm25 FAMILY bm25_ops AS OPERATOR 1 pg_catalog.<&>(bm25vector, bm25query) FOR ORDER BY float_ops; - --- CREATE TABLE IF NOT EXISTS bm25_catalog.unicode_tokenizer ( --- id INT BY DEFAULT AS IDENTITY PRIMARY KEY, --- source_table TEXT NOT NULL, --- source_column TEXT NOT NULL, --- target_column TEXT NOT NULL, --- token_table TEXT NOT NULL --- ); - -CREATE OR REPLACE FUNCTION unicode_tokenizer_trigger() -RETURNS TRIGGER AS $$ -DECLARE - source_column TEXT := TG_ARGV[0]; - target_column TEXT := TG_ARGV[1]; - token_table TEXT := TG_ARGV[2]; - schema_name TEXT := 'bm25_catalog'; - result bm25vector; -BEGIN - IF NOT EXISTS ( - SELECT 1 - FROM information_schema.tables - WHERE table_schema = schema_name AND table_name = token_table - ) THEN - EXECUTE format('CREATE TABLE %I.%I (id int GENERATED BY DEFAULT AS IDENTITY PRIMARY KEY, token TEXT UNIQUE)', schema_name, token_table); - -- EXECUTE format('INSERT INTO %I.unicode_tokenizer (source_table, source_column, target_column, token_table) VALUES (%L, %L, %L, %L)', schema_name, TG_TABLE_NAME, source_column, target_column, token_table); - END IF; - - EXECUTE format('select document_unicode_tokenize($1.%I, ''%I'')', source_column, token_table) INTO result USING NEW; - EXECUTE format('UPDATE %I SET %I = %L WHERE id = $1.id', TG_TABLE_NAME, target_column, result) USING NEW; - RETURN NEW; -END; -$$ LANGUAGE plpgsql; - -CREATE OR REPLACE FUNCTION bm25_query_unicode_tokenize(index_oid regclass, query text, token_table text) RETURNS bm25query - IMMUTABLE STRICT PARALLEL SAFE LANGUAGE sql AS $$ - SELECT index_oid, query_unicode_tokenize(query, token_table); - $$; diff --git a/src/sql/tokenizer.sql b/src/sql/tokenizer.sql new file mode 100644 index 0000000..cbbc753 --- /dev/null +++ b/src/sql/tokenizer.sql @@ -0,0 +1,38 @@ +CREATE TABLE bm25_catalog.tokenizers ( + name TEXT NOT NULL UNIQUE PRIMARY KEY, + config TEXT NOT NULL +); + +CREATE FUNCTION unicode_tokenizer_insert_trigger() +RETURNS TRIGGER AS $$ +DECLARE + tokenizer_name TEXT := TG_ARGV[0]; + target_column TEXT := TG_ARGV[1]; +BEGIN + EXECUTE format(' + WITH new_tokens AS ( + SELECT unnest(unicode_tokenizer_split($1.%I)) AS token + ), + to_insert AS ( + SELECT token FROM new_tokens + WHERE NOT EXISTS ( + SELECT 1 
FROM bm25_catalog.%I WHERE token = new_tokens.token + ) + ) + INSERT INTO bm25_catalog.%I (token) SELECT token FROM to_insert ON CONFLICT (token) DO NOTHING', target_column, tokenizer_name, tokenizer_name) USING NEW; + RETURN NEW; +END; +$$ LANGUAGE plpgsql; + +CREATE FUNCTION create_unicode_tokenizer_and_trigger(tokenizer_name TEXT, table_name TEXT, source_column TEXT, target_column TEXT) +RETURNS VOID AS $body$ +BEGIN + EXECUTE format('SELECT create_tokenizer(%L, $$ + tokenizer = ''Unicode'' + table = %L + column = %L + $$)', tokenizer_name, table_name, source_column); + EXECUTE format('UPDATE %I SET %I = tokenize(%I, %L)', table_name, target_column, source_column, tokenizer_name); + EXECUTE format('CREATE TRIGGER "%s_trigger_insert" BEFORE INSERT OR UPDATE OF %I ON %I FOR EACH ROW EXECUTE FUNCTION unicode_tokenizer_set_target_column_trigger(%L, %I, %I)', tokenizer_name, source_column, table_name, tokenizer_name, source_column, target_column); +END; +$body$ LANGUAGE plpgsql; diff --git a/src/token.rs b/src/token.rs index c968ad7..165ae94 100644 --- a/src/token.rs +++ b/src/token.rs @@ -1,9 +1,15 @@ -use std::collections::HashSet; +use std::collections::{HashMap, HashSet}; +use pgrx::{ + extension_sql_file, pg_sys::panic::ErrorReportable, pg_trigger, prelude::PgHeapTuple, + spi::SpiClient, IntoDatum, WhoAllocated, +}; +use serde::{Deserialize, Serialize}; use tocken::tokenizer::Tokenizer as Tockenizer; use unicode_segmentation::UnicodeSegmentation; +use validator::{Validate, ValidationError}; -use crate::guc::TOKENIZER_NAME; +use crate::datatype::Bm25VectorOutput; static BERT_BASE_UNCASED_BYTES: &[u8] = include_bytes!("../tokenizer/bert_base_uncased.json"); static TOCKEN: &[u8] = include_bytes!("../tokenizer/wiki_tocken.json"); @@ -24,25 +30,50 @@ lazy_static::lazy_static! 
{ words.into_iter().collect() }; - static ref BERT_TOKENIZER: BertWithStemmerAndSplit = Default::default(); - static ref TOCKENIZER: Tocken = Tocken(Tockenizer::loads(std::str::from_utf8(TOCKEN).expect("str"))); + static ref BERT_TOKENIZER: BertWithStemmerAndSplit = BertWithStemmerAndSplit::new(); + static ref TOCKENIZER: Tocken = Tocken::new(); } -pub fn tokenize(text: &str) -> Vec { - match TOKENIZER_NAME - .get() - .expect("set guc") - .to_str() - .expect("str") - { - "BERT" => BERT_TOKENIZER.encode(text), - "TOCKEN" => TOCKENIZER.encode(text), - "UNICODE" => panic!("only support the trigger"), - _ => panic!("Unknown tokenizer"), +struct BertWithStemmerAndSplit(tokenizers::Tokenizer); + +impl BertWithStemmerAndSplit { + fn new() -> Self { + Self(tokenizers::Tokenizer::from_bytes(BERT_BASE_UNCASED_BYTES).unwrap()) + } + + fn encode(&self, text: &str) -> Vec { + let mut results = Vec::new(); + let lower_text = text.to_lowercase(); + let split = TOKEN_PATTERN_RE.find_iter(&lower_text); + for token in split { + if STOP_WORDS_NLTK.contains(token.as_str()) { + continue; + } + let stemmed_token = + tantivy_stemmers::algorithms::english_porter_2(token.as_str()).to_string(); + let encoding = self.0.encode_fast(stemmed_token, false).unwrap(); + results.extend_from_slice(encoding.get_ids()); + } + results + } +} + +struct Tocken(Tockenizer); + +impl Tocken { + fn new() -> Self { + Self(tocken::tokenizer::Tokenizer::loads( + std::str::from_utf8(TOCKEN).unwrap(), + )) + } + + fn encode(&self, text: &str) -> Vec { + self.0.tokenize(text) } } -pub fn unicode_tokenize(text: &str) -> Vec { +#[pgrx::pg_extern(immutable, strict, parallel_safe)] +pub fn unicode_tokenizer_split(text: &str) -> Vec { let mut tokens = Vec::new(); for word in text.unicode_words() { // trim `'s` for English @@ -68,40 +99,295 @@ pub fn unicode_tokenize(text: &str) -> Vec { tokens } -trait Tokenizer { - fn encode(&self, text: &str) -> Vec; +#[derive(Clone, Copy, Serialize, Deserialize)] +#[repr(i32)] +enum TokenizerKind { + Bert, + Tocken, + Unicode, } -struct BertWithStemmerAndSplit(tokenizers::Tokenizer); +#[derive(Clone, Serialize, Deserialize, Validate)] +#[validate(schema(function = "TokenizerConfig::validate_unicode"))] +#[serde(deny_unknown_fields)] +struct TokenizerConfig { + tokenizer: TokenizerKind, + #[serde(default)] + table: Option, + #[serde(default)] + column: Option, +} -impl Default for BertWithStemmerAndSplit { - fn default() -> Self { - Self(tokenizers::Tokenizer::from_bytes(BERT_BASE_UNCASED_BYTES).unwrap()) +impl TokenizerConfig { + fn validate_unicode(&self) -> Result<(), ValidationError> { + if !matches!(self.tokenizer, TokenizerKind::Unicode) { + return Ok(()); + } + if self.table.is_none() { + return Err(ValidationError::new( + "table is required for unicode tokenizer", + )); + } + if self.column.is_none() { + return Err(ValidationError::new( + "column is required for unicode tokenizer", + )); + } + Ok(()) } } -impl Tokenizer for BertWithStemmerAndSplit { - fn encode(&self, text: &str) -> Vec { - let mut results = Vec::new(); - let lower_text = text.to_lowercase(); - let split = TOKEN_PATTERN_RE.find_iter(&lower_text); - for token in split { - if STOP_WORDS_NLTK.contains(token.as_str()) { - continue; - } - let stemmed_token = - tantivy_stemmers::algorithms::english_porter_2(token.as_str()).to_string(); - let encoding = self.0.encode_fast(stemmed_token, false).unwrap(); - results.extend_from_slice(encoding.get_ids()); +extension_sql_file!( + "sql/tokenizer.sql", + name = "tokenizer_table", + requires = 
[unicode_tokenizer_split] +); + +#[pgrx::pg_extern(requires = ["tokenizer_table"])] +pub fn create_tokenizer(tokenizer_name: &str, config_str: &str) { + if let Err(e) = validate_tokenizer_name(tokenizer_name) { + panic!("Invalid tokenizer name: {}, Details: {}", tokenizer_name, e); + } + + let config: TokenizerConfig = toml::from_str(config_str).unwrap_or_report(); + if let Err(e) = config.validate() { + panic!("Invalid tokenizer config, Details: {}", e); + } + + pgrx::Spi::connect(|mut client| { + let query = "INSERT INTO bm25_catalog.tokenizers (name, config) VALUES ($1, $2)"; + let args = Some(vec![ + ( + pgrx::PgBuiltInOids::TEXTOID.oid(), + tokenizer_name.into_datum(), + ), + (pgrx::PgBuiltInOids::TEXTOID.oid(), config_str.into_datum()), + ]); + client.update(query, None, args).unwrap_or_report(); + if matches!(config.tokenizer, TokenizerKind::Unicode) { + create_unicode_tokenizer_table(&mut client, tokenizer_name, &config); } - results + }); +} + +#[pgrx::pg_extern(requires = ["tokenizer_table"])] +fn drop_tokenizer(tokenizer_name: &str) { + if let Err(e) = validate_tokenizer_name(tokenizer_name) { + panic!("Invalid tokenizer name: {}, Details: {}", tokenizer_name, e); } + + pgrx::Spi::connect(|mut client| { + let query = "SELECT config FROM bm25_catalog.tokenizers WHERE name = $1"; + let args = Some(vec![( + pgrx::PgBuiltInOids::TEXTOID.oid(), + tokenizer_name.into_datum(), + )]); + let mut rows = client.select(query, None, args).unwrap_or_report(); + if rows.len() != 1 { + panic!("Tokenizer not found"); + } + + let config: &str = rows + .next() + .unwrap() + .get(1) + .expect("no config value") + .expect("no config value"); + let config: TokenizerConfig = toml::from_str(config).unwrap_or_report(); + if matches!(config.tokenizer, TokenizerKind::Unicode) { + let table_name = format!("bm25_catalog.\"{}\"", tokenizer_name); + let drop_table = format!("DROP TABLE IF EXISTS {}", table_name); + client.update(&drop_table, None, None).unwrap_or_report(); + let drop_trigger = format!( + "DROP TRIGGER IF EXISTS \"{}_trigger\" ON {}", + tokenizer_name, + config.table.unwrap() + ); + client.update(&drop_trigger, None, None).unwrap_or_report(); + } + + let query = "DELETE FROM bm25_catalog.tokenizers WHERE name = $1"; + let args = Some(vec![( + pgrx::PgBuiltInOids::TEXTOID.oid(), + tokenizer_name.into_datum(), + )]); + client.update(query, None, args).unwrap_or_report(); + }); } -struct Tocken(Tockenizer); +const TOKENIZER_RESERVED_NAMES: [&[u8]; 3] = [b"Bert", b"Tocken", b"tokenizers"]; -impl Tokenizer for Tocken { - fn encode(&self, text: &str) -> Vec { - self.0.tokenize(text) +// 1. It only contains ascii letters, numbers, and underscores. +// 2. It starts with a letter. +// 3. Its length cannot exceed NAMEDATALEN - 1 +// 4. It is not a reserved name. +fn validate_tokenizer_name(name: &str) -> Result<(), String> { + let name = name.as_bytes(); + for &b in name { + if !b.is_ascii_alphanumeric() && b != b'_' { + return Err(format!("Invalid character: {}", b as char)); + } + } + if !(1..=pgrx::pg_sys::NAMEDATALEN as usize - 1).contains(&name.len()) { + return Err(format!( + "Name length must be between 1 and {}", + pgrx::pg_sys::NAMEDATALEN - 1 + )); + } + if !name[0].is_ascii_alphabetic() { + return Err("Name must start with a letter".to_string()); + } + if TOKENIZER_RESERVED_NAMES.contains(&name) { + return Err("The name is reserved, please choose another name".to_string()); + } + + Ok(()) +} + +// 1. create word table +// 2. 
scan the text and split it into words and insert them into the word table +// 3. create a trigger to insert new words into the word table +fn create_unicode_tokenizer_table( + client: &mut SpiClient<'_>, + name: &str, + config: &TokenizerConfig, +) { + let table_name = format!("bm25_catalog.\"{}\"", name); + let target_table = config.table.as_ref().unwrap(); + let column = config.column.as_ref().unwrap(); + + let create_table = format!( + r#" + CREATE TABLE {} ( + id int GENERATED BY DEFAULT AS IDENTITY PRIMARY KEY, + token TEXT NOT NULL UNIQUE + ); + "#, + table_name + ); + client.update(&create_table, None, None).unwrap_or_report(); + + let select_text = format!("SELECT {} FROM {}", column, target_table); + let rows = client.select(&select_text, None, None).unwrap_or_report(); + let mut tokens = HashSet::new(); + for row in rows { + let text: &str = row.get(1).unwrap_or_report().expect("no text value"); + let words = unicode_tokenizer_split(text); + tokens.extend(words); } + + let insert_text = format!( + r#" + INSERT INTO {} (token) VALUES ($1) + "#, + table_name + ); + for token in tokens { + let args = Some(vec![( + pgrx::PgBuiltInOids::TEXTOID.oid(), + token.into_datum(), + )]); + client.update(&insert_text, None, args).unwrap_or_report(); + } + + let trigger = format!( + r#" + CREATE TRIGGER "{}_trigger" + BEFORE INSERT OR UPDATE OF {} + ON {} + FOR EACH ROW + EXECUTE FUNCTION unicode_tokenizer_insert_trigger('{}', '{}'); + "#, + name, column, target_table, name, column + ); + client.update(&trigger, None, None).unwrap_or_report(); +} + +fn unicode_tokenize(client: &SpiClient<'_>, text: &str, tokenizer_name: &str) -> Vec { + let tokens = unicode_tokenizer_split(text); + let query = format!( + "SELECT id, token FROM bm25_catalog.\"{}\" WHERE token = ANY($1)", + tokenizer_name + ); + let args = Some(vec![( + pgrx::PgBuiltInOids::TEXTARRAYOID.oid(), + tokens.clone().into_datum(), + )]); + let rows = client.select(&query, None, args).unwrap_or_report(); + + let mut token_map = HashMap::new(); + for row in rows { + let id: i32 = row.get(1).unwrap_or_report().expect("no id value"); + let id = u32::try_from(id).expect("id is not a valid u32"); + let token: String = row.get(2).unwrap_or_report().expect("no token value"); + token_map.insert(token, id); + } + + tokens + .into_iter() + .filter_map(|token| token_map.get(&token).copied()) + .collect() +} + +#[pgrx::pg_extern(stable, strict, parallel_safe, requires = ["tokenizer_table"])] +pub fn tokenize(content: &str, tokenizer_name: &str) -> Bm25VectorOutput { + let term_ids = match tokenizer_name { + "Bert" => BERT_TOKENIZER.encode(content), + "Tocken" => TOCKENIZER.encode(content), + _ => custom_tokenize(content, tokenizer_name), + }; + Bm25VectorOutput::from_ids(&term_ids) +} + +fn custom_tokenize(text: &str, tokenizer_name: &str) -> Vec { + pgrx::Spi::connect(|client| { + let query = "SELECT config FROM bm25_catalog.tokenizers WHERE name = $1"; + let args = Some(vec![( + pgrx::PgBuiltInOids::TEXTOID.oid(), + tokenizer_name.into_datum(), + )]); + let mut rows = client.select(query, None, args).unwrap_or_report(); + if rows.len() != 1 { + panic!("Tokenizer not found"); + } + + let config: &str = rows + .next() + .unwrap() + .get(1) + .expect("no config value") + .expect("no config value"); + let config: TokenizerConfig = toml::from_str(config).unwrap_or_report(); + match config.tokenizer { + TokenizerKind::Bert => BERT_TOKENIZER.encode(text), + TokenizerKind::Tocken => TOCKENIZER.encode(text), + TokenizerKind::Unicode => 
unicode_tokenize(&client, text, tokenizer_name), + } + }) +} + +#[pg_trigger] +fn unicode_tokenizer_set_target_column_trigger<'a>( + trigger: &'a pgrx::PgTrigger<'a>, +) -> Result>, ()> { + let mut new = trigger.new().expect("new tuple is missing").into_owned(); + let tg_argv = trigger.extra_args().expect("trigger arguments are missing"); + if tg_argv.len() != 3 { + panic!("Invalid trigger arguments"); + } + let tokenizer_name = &tg_argv[0]; + let source_column = &tg_argv[1]; + let target_column = &tg_argv[2]; + + let source = new + .get_by_name::<&str>(source_column) + .expect("source column is missing"); + let Some(source) = source else { + return Ok(Some(new)); + }; + + let target = tokenize(source, tokenizer_name); + new.set_by_name(target_column, target) + .expect("set target column failed"); + Ok(Some(new)) } diff --git a/tests/sqllogictest/delete.slt b/tests/sqllogictest/delete.slt index 382e555..6683cfe 100644 --- a/tests/sqllogictest/delete.slt +++ b/tests/sqllogictest/delete.slt @@ -21,13 +21,13 @@ statement ok ALTER TABLE documents ADD COLUMN embedding bm25vector; statement ok -UPDATE documents SET embedding = tokenize(passage); +UPDATE documents SET embedding = tokenize(passage, 'Bert'); statement ok CREATE INDEX documents_embedding_bm25 ON documents USING bm25 (embedding bm25_ops); statement ok -SELECT id, passage, embedding <&> to_bm25query('documents_embedding_bm25', 'Post') AS rank +SELECT id, passage, embedding <&> to_bm25query('documents_embedding_bm25', 'Post', 'Bert') AS rank FROM documents ORDER BY rank LIMIT 10; @@ -49,13 +49,13 @@ INSERT INTO documents (passage) VALUES ('Effective search ranking algorithms, such as BM25, improve search results by understanding relevance.'); statement ok -UPDATE documents SET embedding = tokenize(passage) WHERE embedding IS NULL; +UPDATE documents SET embedding = tokenize(passage, 'Bert') WHERE embedding IS NULL; statement ok VACUUM documents; statement ok -SELECT id, passage, embedding <&> to_bm25query('documents_embedding_bm25', 'Post') AS rank +SELECT id, passage, embedding <&> to_bm25query('documents_embedding_bm25', 'Post', 'Bert') AS rank FROM documents ORDER BY rank LIMIT 10; @@ -64,7 +64,7 @@ statement ok SET enable_seqscan=off; statement ok -SELECT id, passage, embedding <&> to_bm25query('documents_embedding_bm25', 'Post') AS rank +SELECT id, passage, embedding <&> to_bm25query('documents_embedding_bm25', 'Post', 'Bert') AS rank FROM documents ORDER BY rank LIMIT 10; diff --git a/tests/sqllogictest/empty.slt b/tests/sqllogictest/empty.slt index f37ae51..74a4468 100644 --- a/tests/sqllogictest/empty.slt +++ b/tests/sqllogictest/empty.slt @@ -22,7 +22,7 @@ INSERT INTO documents (passage) VALUES ('Effective search ranking algorithms, such as BM25, improve search results by understanding relevance.'); statement ok -SELECT id, passage, embedding <&> to_bm25query('documents_embedding_bm25', 'Post') AS rank +SELECT id, passage, embedding <&> to_bm25query('documents_embedding_bm25', 'Post', 'Bert') AS rank FROM documents ORDER BY rank LIMIT 10; @@ -31,16 +31,16 @@ statement ok SET enable_seqscan=off; statement ok -SELECT id, passage, embedding <&> to_bm25query('documents_embedding_bm25', 'Post') AS rank +SELECT id, passage, embedding <&> to_bm25query('documents_embedding_bm25', 'Post', 'Bert') AS rank FROM documents ORDER BY rank LIMIT 10; statement ok -UPDATE documents SET embedding = tokenize(passage); +UPDATE documents SET embedding = tokenize(passage, 'Bert'); statement ok -SELECT id, passage, embedding <&> 
to_bm25query('documents_embedding_bm25', 'Post') AS rank +SELECT id, passage, embedding <&> to_bm25query('documents_embedding_bm25', 'Post', 'Bert') AS rank FROM documents ORDER BY rank LIMIT 10; diff --git a/tests/sqllogictest/index.slt b/tests/sqllogictest/index.slt index c6f25f3..657e29f 100644 --- a/tests/sqllogictest/index.slt +++ b/tests/sqllogictest/index.slt @@ -21,13 +21,13 @@ statement ok ALTER TABLE documents ADD COLUMN embedding bm25vector; statement ok -UPDATE documents SET embedding = tokenize(passage); +UPDATE documents SET embedding = tokenize(passage, 'Bert'); statement ok CREATE INDEX documents_embedding_bm25 ON documents USING bm25 (embedding bm25_ops); statement ok -SELECT id, passage, embedding <&> to_bm25query('documents_embedding_bm25', 'Post') AS rank +SELECT id, passage, embedding <&> to_bm25query('documents_embedding_bm25', 'Post', 'Bert') AS rank FROM documents ORDER BY rank LIMIT 10; @@ -36,7 +36,7 @@ statement ok SET enable_seqscan=off; statement ok -SELECT id, passage, embedding <&> to_bm25query('documents_embedding_bm25', 'Post') AS rank +SELECT id, passage, embedding <&> to_bm25query('documents_embedding_bm25', 'Post', 'Bert') AS rank FROM documents ORDER BY rank LIMIT 10; @@ -46,10 +46,10 @@ INSERT INTO documents (passage) VALUES ('vchord_bm25 is a postgresql extension for bm25 ranking algorithm.'); statement ok -UPDATE documents SET embedding = tokenize(passage) WHERE embedding IS NULL; +UPDATE documents SET embedding = tokenize(passage, 'Bert') WHERE embedding IS NULL; statement ok -SELECT id, passage, embedding <&> to_bm25query('documents_embedding_bm25', 'Post') AS rank +SELECT id, passage, embedding <&> to_bm25query('documents_embedding_bm25', 'Post', 'Bert') AS rank FROM documents ORDER BY rank LIMIT 10; diff --git a/tests/sqllogictest/temp.slt b/tests/sqllogictest/temp.slt index d4955ba..222a11e 100644 --- a/tests/sqllogictest/temp.slt +++ b/tests/sqllogictest/temp.slt @@ -24,13 +24,13 @@ statement ok ALTER TABLE documents ADD COLUMN embedding bm25vector; statement ok -UPDATE documents SET embedding = tokenize(passage); +UPDATE documents SET embedding = tokenize(passage, 'Bert'); statement ok CREATE INDEX documents_embedding_bm25 ON documents USING bm25 (embedding bm25_ops); statement ok -SELECT id, passage, embedding <&> to_bm25query('documents_embedding_bm25', 'Post') AS rank +SELECT id, passage, embedding <&> to_bm25query('documents_embedding_bm25', 'Post', 'Bert') AS rank FROM documents ORDER BY rank LIMIT 10; @@ -39,7 +39,7 @@ statement ok SET enable_seqscan=off; statement ok -SELECT id, passage, embedding <&> to_bm25query('documents_embedding_bm25', 'Post') AS rank +SELECT id, passage, embedding <&> to_bm25query('documents_embedding_bm25', 'Post', 'Bert') AS rank FROM documents ORDER BY rank LIMIT 10; @@ -49,10 +49,10 @@ INSERT INTO documents (passage) VALUES ('vchord_bm25 is a postgresql extension for bm25 ranking algorithm.'); statement ok -UPDATE documents SET embedding = tokenize(passage) WHERE embedding IS NULL; +UPDATE documents SET embedding = tokenize(passage, 'Bert') WHERE embedding IS NULL; statement ok -SELECT id, passage, embedding <&> to_bm25query('documents_embedding_bm25', 'Post') AS rank +SELECT id, passage, embedding <&> to_bm25query('documents_embedding_bm25', 'Post', 'Bert') AS rank FROM documents ORDER BY rank LIMIT 10; diff --git a/tests/sqllogictest/tokenizer.slt b/tests/sqllogictest/tokenizer.slt new file mode 100644 index 0000000..90314f4 --- /dev/null +++ b/tests/sqllogictest/tokenizer.slt @@ -0,0 +1,25 @@ +statement 
error +SELECT tokenize('PostgreSQL'); + +query I +SELECT tokenize('PostgreSQL', 'Bert'); +---- +{2015:1, 2140:1, 2695:1, 4160:1, 17603:1} + +query I +SELECT tokenize('PostgreSQL', 'Tocken'); +---- +{45687:1} + +statement ok +SELECT create_tokenizer('test_bert', $$ +tokenizer = "Bert" +$$); + +query I +SELECT tokenize('PostgreSQL', 'test_bert'); +---- +{2015:1, 2140:1, 2695:1, 4160:1, 17603:1} + +statement ok +SELECT drop_tokenizer('test_bert'); diff --git a/tests/sqllogictest/unicode_tokenizer.slt b/tests/sqllogictest/unicode_tokenizer.slt new file mode 100644 index 0000000..a27fd1b --- /dev/null +++ b/tests/sqllogictest/unicode_tokenizer.slt @@ -0,0 +1,88 @@ +# manual + +statement ok +CREATE TABLE documents ( + id SERIAL PRIMARY KEY, + passage TEXT +); + +statement ok +INSERT INTO documents (passage) VALUES +('PostgreSQL is a powerful, open-source object-relational database system. It has over 15 years of active development.'), +('Full-text search is a technique for searching in plain-text documents or textual database fields. PostgreSQL supports this with tsvector.'), +('BM25 is a ranking function used by search engines to estimate the relevance of documents to a given search query.'), +('PostgreSQL provides many advanced features like full-text search, window functions, and more.'), +('Search and ranking in databases are important in building effective information retrieval systems.'), +('The BM25 ranking algorithm is derived from the probabilistic retrieval framework.'), +('Full-text search indexes documents to allow fast text queries. PostgreSQL supports this through its GIN and GiST indexes.'), +('The PostgreSQL community is active and regularly improves the database system.'), +('Relational databases such as PostgreSQL can handle both structured and unstructured data.'), +('Effective search ranking algorithms, such as BM25, improve search results by understanding relevance.'); + +statement ok +SELECT create_tokenizer('documents_tokenizer', $$ +tokenizer = 'Unicode' +table = 'documents' +column = 'passage' +$$); + +statement ok +ALTER TABLE documents ADD COLUMN embedding bm25vector; + +statement ok +UPDATE documents SET embedding = tokenize(passage, 'documents_tokenizer'); + +statement ok +CREATE INDEX documents_embedding_bm25 ON documents USING bm25 (embedding bm25_ops); + +statement ok +SELECT id, passage, embedding <&> to_bm25query('documents_embedding_bm25', 'Postgresql', 'documents_tokenizer') AS rank +FROM documents +ORDER BY rank +LIMIT 10; + +statement ok +SELECT drop_tokenizer('documents_tokenizer'); + +statement ok +DROP TABLE documents; + +# trigger + +statement ok +CREATE TABLE documents ( + id SERIAL PRIMARY KEY, + passage TEXT, + embedding bm25vector +); + +statement ok +SELECT create_unicode_tokenizer_and_trigger('Documents_tokenizer', 'documents', 'passage', 'embedding'); + +statement ok +INSERT INTO documents (passage) VALUES +('PostgreSQL is a powerful, open-source object-relational database system. It has over 15 years of active development.'), +('Full-text search is a technique for searching in plain-text documents or textual database fields. 
PostgreSQL supports this with tsvector.'), +('BM25 is a ranking function used by search engines to estimate the relevance of documents to a given search query.'), +('PostgreSQL provides many advanced features like full-text search, window functions, and more.'), +('Search and ranking in databases are important in building effective information retrieval systems.'), +('The BM25 ranking algorithm is derived from the probabilistic retrieval framework.'), +('Full-text search indexes documents to allow fast text queries. PostgreSQL supports this through its GIN and GiST indexes.'), +('The PostgreSQL community is active and regularly improves the database system.'), +('Relational databases such as PostgreSQL can handle both structured and unstructured data.'), +('Effective search ranking algorithms, such as BM25, improve search results by understanding relevance.'); + +statement ok +CREATE INDEX documents_embedding_bm25 ON documents USING bm25 (embedding bm25_ops); + +statement ok +SELECT id, passage, embedding <&> to_bm25query('Documents_embedding_bm25', 'Postgresql', 'Documents_tokenizer') AS rank +FROM documents +ORDER BY rank +LIMIT 10; + +statement ok +DROP TABLE documents; + +statement ok +SELECT drop_tokenizer('Documents_tokenizer'); diff --git a/tests/sqllogictest/unlogged.slt b/tests/sqllogictest/unlogged.slt index 28e09ea..fbc9d51 100644 --- a/tests/sqllogictest/unlogged.slt +++ b/tests/sqllogictest/unlogged.slt @@ -22,7 +22,7 @@ INSERT INTO documents (passage) VALUES ('Effective search ranking algorithms, such as BM25, improve search results by understanding relevance.'); statement ok -SELECT id, passage, embedding <&> to_bm25query('documents_embedding_bm25', 'Post') AS rank +SELECT id, passage, embedding <&> to_bm25query('documents_embedding_bm25', 'Post', 'Bert') AS rank FROM documents ORDER BY rank LIMIT 10; @@ -31,16 +31,16 @@ statement ok SET enable_seqscan=off; statement ok -SELECT id, passage, embedding <&> to_bm25query('documents_embedding_bm25', 'Post') AS rank +SELECT id, passage, embedding <&> to_bm25query('documents_embedding_bm25', 'Post', 'Bert') AS rank FROM documents ORDER BY rank LIMIT 10; statement ok -UPDATE documents SET embedding = tokenize(passage); +UPDATE documents SET embedding = tokenize(passage, 'Bert'); statement ok -SELECT id, passage, embedding <&> to_bm25query('documents_embedding_bm25', 'Post') AS rank +SELECT id, passage, embedding <&> to_bm25query('documents_embedding_bm25', 'Post', 'Bert') AS rank FROM documents ORDER BY rank LIMIT 10; diff --git a/tokenizer.md b/tokenizer.md index fdba70b..d403384 100644 --- a/tokenizer.md +++ b/tokenizer.md @@ -2,58 +2,77 @@ Currently, we support the following tokenizers: -- `BERT`: default uncased BERT tokenizer. -- `TOCKEN`: a Unicode tokenizer pre-trained on wiki-103-raw with `min_freq=10`. -- `UNICODE`: a Unicode tokenizer that will be trained on your data. +- `Bert`: default uncased BERT tokenizer. +- `Tocken`: a Unicode tokenizer pre-trained on wiki-103-raw with `min_freq=10`. +- `Unicode`: a Unicode tokenizer that will be trained on your data. ## Usage ### Pre-trained Tokenizer -`BERT` and `TOCKEN` are pre-trained tokenizers. You can use them directly by calling the `tokenize` function. +`Bert` and `Tocken` are pre-trained tokenizers. You can use them directly by calling the `tokenize` function. 
 ```sql
-SET bm25_catalog.tokenizer = 'BERT'; -- or 'TOCKEN'
-SELECT tokenize('A quick brown fox jumps over the lazy dog.');
+SELECT tokenize('A quick brown fox jumps over the lazy dog.', 'Bert'); -- or 'Tocken'
 -- {2058:1, 2474:1, 2829:1, 3899:1, 4248:1, 4419:1, 5376:1, 5831:1}
 ```
 
 ### Train on Your Data
 
-`UNICODE` will be trained on your data during the document tokenization. You can use this function with/without the trigger:
+`Unicode` is trained on your data during document tokenization. You can use it with or without a trigger:
 
 - with trigger (convenient but slower)
 
 ```sql
-CREATE TABLE corpus (id TEXT, text TEXT, embedding bm25vector);
-CREATE TRIGGER test_trigger AFTER INSERT ON corpus FOR each row execute FUNCTION unicode_tokenizer_trigger('text', 'embedding', 'test_token');
--- insert text to the table
+CREATE TABLE corpus (id SERIAL, text TEXT, embedding bm25vector);
+SELECT create_unicode_tokenizer_and_trigger('test_token', 'corpus', 'text', 'embedding');
+INSERT INTO corpus (text) VALUES ('PostgreSQL is a powerful, open-source object-relational database system.'); -- insert text to the table
 CREATE INDEX corpus_embedding_bm25 ON corpus USING bm25 (embedding bm25_ops);
-SELECT id, text, embedding <&> bm25_query_unicode_tokenize('documents_embedding_bm25', 'PostgreSQL', 'test_token') AS rank
+SELECT id, text, embedding <&> to_bm25query('corpus_embedding_bm25', 'PostgreSQL', 'test_token') AS rank
 FROM corpus
 ORDER BY rank
 LIMIT 10;
 ```
 
-- without trigger (faster but need to call the `document_unicode_tokenize` function manually)
+- without trigger (faster, but you need to call the `tokenize` function manually)
 
 ```sql
-CREATE TABLE corpus (id TEXT, text TEXT);
--- insert text to the table
-CREATE TABLE bm25_catalog.test_token (id int GENERATED BY DEFAULT AS IDENTITY PRIMARY KEY, token TEXT UNIQUE);
-ALTER TABLE corpus ADD COLUMN embedding bm25vector;
-UPDATE corpus SET embedding = document_unicode_tokenize(text, 'test_token');
+CREATE TABLE corpus (id SERIAL, text TEXT, embedding bm25vector);
+INSERT INTO corpus (text) VALUES ('PostgreSQL is a powerful, open-source object-relational database system.'); -- insert text to the table
+SELECT create_tokenizer('test_token', $$
+tokenizer = 'Unicode'
+table = 'corpus'
+column = 'text'
+$$);
+UPDATE corpus SET embedding = tokenize(text, 'test_token');
 CREATE INDEX corpus_embedding_bm25 ON corpus USING bm25 (embedding bm25_ops);
-SELECT id, text, embedding <&> bm25_query_unicode_tokenize('documents_embedding_bm25', 'PostgreSQL', 'test_token') AS rank
+SELECT id, text, embedding <&> to_bm25query('corpus_embedding_bm25', 'PostgreSQL', 'test_token') AS rank
 FROM corpus
 ORDER BY rank
 LIMIT 10;
 ```
 
+## Configuration
+
+We use [`TOML`](https://toml.io/en/) to configure the tokenizer. You can specify the tokenizer type and, for the `Unicode` tokenizer, the table and column to train on.
+
+Here is what each field means:
+
+| Field | Type | Description |
+| --------- | ------ | ---------------------------------------------------- |
+| tokenizer | String | The tokenizer type (`Bert`, `Tocken`, or `Unicode`). |
+| table | String | The table to train on (required for `Unicode`). |
+| column | String | The column to train on (required for `Unicode`). |
+
+## Note
+
+- `tokenizer_name` is case-sensitive. Make sure to use the exact name when calling the `tokenize` function.
+- `tokenizer_name` can only contain ASCII letters, digits, and underscores, and it must start with a letter.
+- `tokenizer_name` must be unique. You cannot create two tokenizers with the same name.
+
 ## Contribution
 
 To create another tokenizer that is pre-trained on your data, you can follow the steps below:
 
-1. `impl Tokenizer` trait for your tokenizer.
+1. Update `TOKENIZER_RESERVED_NAMES` and the `create_tokenizer`, `drop_tokenizer`, and `tokenize` functions in [`token.rs`](src/token.rs); see the sketch below.
 2. (optional) pre-trained data can be stored under the [tokenizer](./tokenizer/) directory.
-3. Add your tokenizer to the `GUC TOKENIZER_NAME` match branch in the [`token.rs`](./src/token.rs).
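+
+As an illustration only (not part of the current API), suppose you add a hypothetical pre-trained tokenizer named `MyTok`, exposed as a `lazy_static` `MY_TOKENIZER` built the same way as `BERT_TOKENIZER` and `TOCKENIZER`. Step 1 then roughly amounts to:
+
+```rust
+// Reserve the built-in name so users cannot create a custom tokenizer that shadows it.
+const TOKENIZER_RESERVED_NAMES: [&[u8]; 4] = [b"Bert", b"Tocken", b"MyTok", b"tokenizers"];
+
+// Route the new name in `tokenize`; if it should also be accepted by `create_tokenizer`,
+// add a matching `TokenizerKind` variant and handle it in `custom_tokenize` as well.
+#[pgrx::pg_extern(stable, strict, parallel_safe, requires = ["tokenizer_table"])]
+pub fn tokenize(content: &str, tokenizer_name: &str) -> Bm25VectorOutput {
+    let term_ids = match tokenizer_name {
+        "Bert" => BERT_TOKENIZER.encode(content),
+        "Tocken" => TOCKENIZER.encode(content),
+        "MyTok" => MY_TOKENIZER.encode(content), // hypothetical pre-trained tokenizer
+        _ => custom_tokenize(content, tokenizer_name),
+    };
+    Bm25VectorOutput::from_ids(&term_ids)
+}
+```
+
+After rebuilding the extension, the new name can then be used directly, e.g. `tokenize(passage, 'MyTok')`.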