refactor: api about tokenizer
Signed-off-by: Mingzhuo Yin <[email protected]>
silver-ymz committed Jan 15, 2025
1 parent 0e64b6f commit abe6e88
Showing 15 changed files with 557 additions and 251 deletions.
3 changes: 3 additions & 0 deletions Cargo.toml
@@ -34,8 +34,11 @@ tantivy-stemmers = { version = "0.4.0", features = [
thiserror = "2"
tokenizers = { version = "0.20", default-features = false, features = ["onig"] }

serde = { version = "1.0.217", features = ["derive"] }
tocken = "0.1.0"
toml = "0.8.19"
unicode-segmentation = "1.12.0"
validator = { version = "0.19.0", features = ["derive"] }

[dev-dependencies]
rand = "0.8"
25 changes: 11 additions & 14 deletions README.md
@@ -7,7 +7,8 @@ A PostgreSQL extension for bm25 ranking algorithm. We implemented the Block-Weak
```sql
CREATE TABLE documents (
id SERIAL PRIMARY KEY,
passage TEXT
passage TEXT,
embedding bm25vector
);

INSERT INTO documents (passage) VALUES
@@ -22,13 +23,11 @@ INSERT INTO documents (passage) VALUES
('Relational databases such as PostgreSQL can handle both structured and unstructured data.'),
('Effective search ranking algorithms, such as BM25, improve search results by understanding relevance.');

ALTER TABLE documents ADD COLUMN embedding bm25vector;

UPDATE documents SET embedding = tokenize(passage);
UPDATE documents SET embedding = tokenize(passage, 'Bert');

CREATE INDEX documents_embedding_bm25 ON documents USING bm25 (embedding bm25_ops);

SELECT id, passage, embedding <&> to_bm25query('documents_embedding_bm25', 'PostgreSQL') AS rank
SELECT id, passage, embedding <&> to_bm25query('documents_embedding_bm25', 'PostgreSQL', 'Bert') AS rank
FROM documents
ORDER BY rank
LIMIT 10;
@@ -92,22 +91,20 @@ CREATE EXTENSION vchord_bm25;

### Functions

- `tokenize(text) RETURNS bm25vector`: Tokenize the input text into a BM25 vector.
- `to_bm25query(index_name regclass, query text) RETURNS bm25query`: Convert the input text into a BM25 query.
- `create_tokenizer(tokenizer_name text, config text)`: Create a tokenizer with the given name and configuration.
- `create_unicode_tokenizer_and_trigger(tokenizer_name text, table_name text, source_column text, target_column text)`: Create a Unicode tokenizer and a trigger for the given table and columns. It automatically builds the tokenizer from `source_column` and stores the resulting vectors in `target_column`.
- `drop_tokenizer(tokenizer_name text)`: Drop the tokenizer with the given name.
- `tokenize(content text, tokenizer_name text) RETURNS bm25vector`: Tokenize the content text into a BM25 vector.
- `to_bm25query(index_name regclass, query text, tokenizer_name text) RETURNS bm25query`: Convert the input text into a BM25 query.
- `bm25vector <&> bm25query RETURNS float4`: Calculate the **negative** BM25 score between the BM25 vector and query.
- `unicode_tokenizer_trigger(text_column text, vec_column text, stored_token_table text) RETURNS TRIGGER`: A trigger function to tokenize the `text_column`, store the vector in `vec_column`, and store the new tokens in the `bm25_catalog.stored_token_table`. For more information, check the [tokenizer](./tokenizer.md) document.
- `document_unicode_tokenize(content text, stored_token_table text) RETURNS bm25vector`: Tokenize the `content` and store the new tokens in the `bm25_catalog.stored_token_table`. For more information, check the [tokenizer](./tokenizer.md) document.
- `bm25_query_unicode_tokenize(index_name regclass, query text, stored_token_table text) RETURNS bm25query`: Tokenize the `query` into a BM25 query vector according to the tokens stored in `stored_token_table`. For more information, check the [tokenizer](./tokenizer.md) document.

For more information about tokenizers, check the [tokenizer](./tokenizer.md) document; a short usage sketch follows below.
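
As a sketch of how these functions fit together for a Unicode tokenizer trained on your own data (the tokenizer name `demo_unicode` is purely illustrative; the `documents` table and index come from the example above):

```sql
-- Build a Unicode tokenizer from documents.passage, backfill the embedding column,
-- and install a trigger that keeps it up to date on INSERT/UPDATE.
SELECT create_unicode_tokenizer_and_trigger('demo_unicode', 'documents', 'passage', 'embedding');

-- Query with the same tokenizer name so the query text is tokenized against the same vocabulary.
SELECT id, passage,
       embedding <&> to_bm25query('documents_embedding_bm25', 'PostgreSQL', 'demo_unicode') AS rank
FROM documents
ORDER BY rank
LIMIT 10;
```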

### GUCs

- `bm25_catalog.bm25_limit (integer)`: The maximum number of documents to return in a search. Default is 1, minimum is 1, and maximum is 65535.
- `bm25_catalog.enable_index (boolean)`: Whether to enable the bm25 index. Default is false.
- `bm25_catalog.segment_growing_max_page_size (integer)`: The maximum page count of the growing segment. When the size of the growing segment exceeds this value, the segment will be sealed into a read-only segment. Default is 1, minimum is 1, and maximum is 1,000,000.
- `bm25_catalog.tokenizer (text)`: Tokenizer chosen from:
- `BERT`: default uncased BERT tokenizer.
- `TOCKEN`: a Unicode tokenizer pre-trained on wiki-103-raw.
- `UNICODE`: a Unicode tokenizer that is trained on your data (needs to be used together with the trigger function `unicode_tokenizer_trigger`).
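
These settings are ordinary PostgreSQL GUCs, so they can be adjusted per session with `SET` or persisted in `postgresql.conf`; a minimal sketch with an illustrative value:

```sql
-- Raise the per-search result cap for the current session (1000 is an arbitrary example value).
SET bm25_catalog.bm25_limit = 1000;
```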

## Contribution

103 changes: 1 addition & 102 deletions src/datatype/functions.rs
@@ -1,22 +1,13 @@
use std::{collections::HashMap, num::NonZero};

use pgrx::{pg_sys::panic::ErrorReportable, IntoDatum};
use std::num::NonZero;

use crate::{
page::{page_read, METAPAGE_BLKNO},
segment::{meta::MetaPageData, term_stat::TermStatReader},
token::unicode_tokenize,
weight::bm25_score_batch,
};

use super::memory_bm25vector::{Bm25VectorInput, Bm25VectorOutput};

#[pgrx::pg_extern(immutable, strict, parallel_safe)]
pub fn tokenize(text: &str) -> Bm25VectorOutput {
let term_ids = crate::token::tokenize(text);
Bm25VectorOutput::from_ids(&term_ids)
}

#[pgrx::pg_extern(stable, strict, parallel_safe)]
pub fn search_bm25query(
target_vector: Bm25VectorInput,
@@ -52,95 +43,3 @@ pub fn search_bm25query(

scores * -1.0
}

#[pgrx::pg_extern()]
pub fn document_unicode_tokenize(text: &str, token_table: &str) -> Bm25VectorOutput {
let tokens = unicode_tokenize(text);
let args = Some(vec![(
pgrx::PgBuiltInOids::TEXTARRAYOID.oid(),
tokens.clone().into_datum(),
)]);

let mut token_ids = HashMap::new();
pgrx::Spi::connect(|mut client| {
let query = format!(
r#"
WITH new_tokens AS (SELECT unnest($1::text[]) AS token),
to_insert AS (
SELECT token FROM new_tokens
WHERE NOT EXISTS (
SELECT 1 FROM bm25_catalog.{} WHERE token = new_tokens.token
)
),
ins AS (
INSERT INTO bm25_catalog.{} (token)
SELECT token FROM to_insert
ON CONFLICT (token) DO NOTHING
RETURNING id, token
)
SELECT id, token FROM ins
UNION ALL
SELECT id, token FROM bm25_catalog.{} WHERE token = ANY($1);
"#,
token_table, token_table, token_table
);
let table = client.update(&query, None, args).unwrap_or_report();
for row in table {
let id: i32 = row
.get_by_name("id")
.expect("no id column")
.expect("no id value");
let token: String = row
.get_by_name("token")
.expect("no token column")
.expect("no token value");
token_ids.insert(token, id as u32);
}
});

let ids = tokens
.iter()
.map(|t| *token_ids.get(t).expect("unknown token"))
.collect::<Vec<_>>();
Bm25VectorOutput::from_ids(&ids)
}

#[pgrx::pg_extern(immutable, strict, parallel_safe)]
pub fn query_unicode_tokenize(query: &str, token_table: &str) -> Bm25VectorOutput {
let tokens = unicode_tokenize(query);
let args = Some(vec![(
pgrx::PgBuiltInOids::TEXTARRAYOID.oid(),
tokens.clone().into_datum(),
)]);
let mut token_ids = HashMap::new();
pgrx::Spi::connect(|client| {
let table = client
.select(
&format!(
"SELECT id, token FROM bm25_catalog.{} WHERE token = ANY($1);",
token_table
),
None,
args,
)
.unwrap_or_report();
for row in table {
let id: i32 = row
.get_by_name("id")
.expect("no id column")
.expect("no id value");
let token: String = row
.get_by_name("token")
.expect("no token column")
.expect("no token value");
token_ids.insert(token, id as u32);
}
});

let ids = tokens
.iter()
.filter(|&t| token_ids.contains_key(t))
.map(|t| *token_ids.get(t).unwrap())
.collect::<Vec<_>>();
Bm25VectorOutput::from_ids(&ids)
}
12 changes: 0 additions & 12 deletions src/guc.rs
@@ -1,12 +1,8 @@
use std::ffi::CStr;

use pgrx::{GucContext, GucFlags, GucRegistry, GucSetting};

pub static BM25_LIMIT: GucSetting<i32> = GucSetting::<i32>::new(100);
pub static ENABLE_INDEX: GucSetting<bool> = GucSetting::<bool>::new(true);
pub static SEGMENT_GROWING_MAX_PAGE_SIZE: GucSetting<i32> = GucSetting::<i32>::new(1000);
pub static TOKENIZER_NAME: GucSetting<Option<&CStr>> =
GucSetting::<Option<&CStr>>::new(Some(c"BERT"));

pub unsafe fn init() {
GucRegistry::define_int_guc(
@@ -37,12 +33,4 @@ pub unsafe fn init() {
GucContext::Userset,
GucFlags::default(),
);
GucRegistry::define_string_guc(
"bm25_catalog.tokenizer",
"tokenizer name",
"tokenizer name",
&TOKENIZER_NAME,
GucContext::Userset,
GucFlags::default(),
);
}
43 changes: 3 additions & 40 deletions src/sql/finalize.sql
@@ -13,9 +13,9 @@ CREATE TYPE bm25query AS (
query_vector bm25vector
);

CREATE FUNCTION to_bm25query(index_oid regclass, query_str text) RETURNS bm25query
IMMUTABLE STRICT PARALLEL SAFE LANGUAGE sql AS $$
SELECT index_oid, tokenize(query_str);
CREATE FUNCTION to_bm25query(index_oid regclass, query_str text, tokenizer_name text) RETURNS bm25query
STABLE STRICT PARALLEL SAFE LANGUAGE sql AS $$
SELECT index_oid, tokenize(query_str, tokenizer_name);
$$;

CREATE ACCESS METHOD bm25 TYPE INDEX HANDLER _bm25_amhandler;
@@ -31,40 +31,3 @@ CREATE OPERATOR FAMILY bm25_ops USING bm25;

CREATE OPERATOR CLASS bm25_ops FOR TYPE bm25vector USING bm25 FAMILY bm25_ops AS
OPERATOR 1 pg_catalog.<&>(bm25vector, bm25query) FOR ORDER BY float_ops;

-- CREATE TABLE IF NOT EXISTS bm25_catalog.unicode_tokenizer (
-- id INT BY DEFAULT AS IDENTITY PRIMARY KEY,
-- source_table TEXT NOT NULL,
-- source_column TEXT NOT NULL,
-- target_column TEXT NOT NULL,
-- token_table TEXT NOT NULL
-- );

CREATE OR REPLACE FUNCTION unicode_tokenizer_trigger()
RETURNS TRIGGER AS $$
DECLARE
source_column TEXT := TG_ARGV[0];
target_column TEXT := TG_ARGV[1];
token_table TEXT := TG_ARGV[2];
schema_name TEXT := 'bm25_catalog';
result bm25vector;
BEGIN
IF NOT EXISTS (
SELECT 1
FROM information_schema.tables
WHERE table_schema = schema_name AND table_name = token_table
) THEN
EXECUTE format('CREATE TABLE %I.%I (id int GENERATED BY DEFAULT AS IDENTITY PRIMARY KEY, token TEXT UNIQUE)', schema_name, token_table);
-- EXECUTE format('INSERT INTO %I.unicode_tokenizer (source_table, source_column, target_column, token_table) VALUES (%L, %L, %L, %L)', schema_name, TG_TABLE_NAME, source_column, target_column, token_table);
END IF;

EXECUTE format('select document_unicode_tokenize($1.%I, ''%I'')', source_column, token_table) INTO result USING NEW;
EXECUTE format('UPDATE %I SET %I = %L WHERE id = $1.id', TG_TABLE_NAME, target_column, result) USING NEW;
RETURN NEW;
END;
$$ LANGUAGE plpgsql;

CREATE OR REPLACE FUNCTION bm25_query_unicode_tokenize(index_oid regclass, query text, token_table text) RETURNS bm25query
IMMUTABLE STRICT PARALLEL SAFE LANGUAGE sql AS $$
SELECT index_oid, query_unicode_tokenize(query, token_table);
$$;
38 changes: 38 additions & 0 deletions src/sql/tokenizer.sql
@@ -0,0 +1,38 @@
CREATE TABLE bm25_catalog.tokenizers (
name TEXT NOT NULL UNIQUE PRIMARY KEY,
config TEXT NOT NULL
);

CREATE FUNCTION unicode_tokenizer_insert_trigger()
RETURNS TRIGGER AS $$
DECLARE
tokenizer_name TEXT := TG_ARGV[0];
target_column TEXT := TG_ARGV[1];
BEGIN
EXECUTE format('
WITH new_tokens AS (
SELECT unnest(unicode_tokenizer_split($1.%I)) AS token
),
to_insert AS (
SELECT token FROM new_tokens
WHERE NOT EXISTS (
SELECT 1 FROM bm25_catalog.%I WHERE token = new_tokens.token
)
)
INSERT INTO bm25_catalog.%I (token) SELECT token FROM to_insert ON CONFLICT (token) DO NOTHING', target_column, tokenizer_name, tokenizer_name) USING NEW;
RETURN NEW;
END;
$$ LANGUAGE plpgsql;

CREATE FUNCTION create_unicode_tokenizer_and_trigger(tokenizer_name TEXT, table_name TEXT, source_column TEXT, target_column TEXT)
RETURNS VOID AS $body$
BEGIN
EXECUTE format('SELECT create_tokenizer(%L, $$
tokenizer = ''Unicode''
table = %L
column = %L
$$)', tokenizer_name, table_name, source_column);
EXECUTE format('UPDATE %I SET %I = tokenize(%I, %L)', table_name, target_column, source_column, tokenizer_name);
EXECUTE format('CREATE TRIGGER "%s_trigger_insert" BEFORE INSERT OR UPDATE OF %I ON %I FOR EACH ROW EXECUTE FUNCTION unicode_tokenizer_set_target_column_trigger(%L, %I, %I)', tokenizer_name, source_column, table_name, tokenizer_name, source_column, target_column);
END;
$body$ LANGUAGE plpgsql;
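
For reference, a sketch of roughly what the first two `EXECUTE format(...)` calls in `create_unicode_tokenizer_and_trigger('tok', 'documents', 'passage', 'embedding')` produce, based on the function body above; the names are illustrative:

```sql
-- Register a Unicode tokenizer whose vocabulary is built from documents.passage.
-- The dollar-quoted string is the TOML-style config passed to create_tokenizer.
SELECT create_tokenizer('tok', $$
tokenizer = 'Unicode'
table = 'documents'
column = 'passage'
$$);

-- Backfill the target column with vectors produced by the new tokenizer.
UPDATE documents SET embedding = tokenize(passage, 'tok');
```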