feat: add more tokenizers (#28)
- refer to #10 

## Features

This adds two new tokenizers:
- UNICODE: trained on users' data
- TOCKEN: pre-trained on wiki-103-raw

## TBD

- Cannot get the `vocab_len` of UNICODE dynamically; related to #25

---------

Signed-off-by: Keming <[email protected]>
kemingy authored Jan 8, 2025
1 parent 52cbdfb commit 1463c7a
Showing 9 changed files with 325 additions and 29 deletions.
3 changes: 3 additions & 0 deletions Cargo.toml
@@ -34,6 +34,9 @@ tantivy-stemmers = { version = "0.4.0", features = [
thiserror = "2"
tokenizers = { version = "0.20", default-features = false, features = ["onig"] }

tocken = "0.1.0"
unicode-segmentation = "1.12.0"

[dev-dependencies]
rand = "0.8"

16 changes: 13 additions & 3 deletions README.md
@@ -1,6 +1,6 @@
# VectorChord-BM25

A postgresql extension for bm25 ranking algorithm. We implemented the Block-WeakAnd Algorithms for BM25 ranking inside PostgreSQL. This extension is currently in **alpha** stage and not recommended for production use. We're still iterating on the API and performance. The interface may change in the future.
A PostgreSQL extension for the BM25 ranking algorithm. We implemented the Block-WeakAnd algorithm for BM25 ranking inside PostgreSQL. This extension is currently in the **alpha** stage and not recommended for production use. We're still iterating on the API and performance. The interface may change in the future.

## Example

@@ -64,10 +64,9 @@ You can follow the docs about [`pgvecto.rs`](https://docs.pgvecto.rs/developers/
cargo pgrx install --sudo --release
```

3. Configure your PostgreSQL by modifying the `shared_preload_libraries` and `search_path` to include the extension.
3. Configure your PostgreSQL by modifying `search_path` to include the extension.

```sh
psql -U postgres -c 'ALTER SYSTEM SET shared_preload_libraries = "vchord_bm25.so"'
psql -U postgres -c 'ALTER SYSTEM SET search_path TO "$user", public, bm25_catalog'
# You need to restart the PostgreSQL cluster for the change to take effect.
sudo systemctl restart postgresql.service # for vchord_bm25.rs running with systemd
@@ -96,12 +95,23 @@ CREATE EXTENSION vchord_bm25;
- `tokenize(text) RETURNS bm25vector`: Tokenize the input text into a BM25 vector.
- `to_bm25query(index_name regclass, query text) RETURNS bm25query`: Convert the input text into a BM25 query.
- `bm25vector <&> bm25query RETURNS float4`: Calculate the **negative** BM25 score between the BM25 vector and query.
- `unicode_tokenizer_trigger(text_column text, vec_column text, stored_token_table text) RETURNS TRIGGER`: A trigger function that tokenizes the `text_column`, stores the vector in `vec_column`, and stores the new tokens in `bm25_catalog.stored_token_table`. For more information, check the [tokenizer](./tokenizer.md) document.
- `document_unicode_tokenize(content text, stored_token_table text) RETURNS bm25vector`: Tokenize the `content` and store the new tokens in `bm25_catalog.stored_token_table`. For more information, check the [tokenizer](./tokenizer.md) document.
- `bm25_query_unicode_tokenize(index_name regclass, query text, stored_token_table text) RETURNS bm25query`: Tokenize the `query` into a BM25 query vector according to the tokens stored in `stored_token_table`. For more information, check the [tokenizer](./tokenizer.md) document. A short usage sketch follows this list.
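
A minimal search sketch using these functions. The table `documents`, its columns `id`/`passage`/`embedding`, the index name `documents_embedding_bm25`, and the stored-token table `documents_tokens` are hypothetical placeholders; the stored-token table is assumed to have already been created (for example by `unicode_tokenizer_trigger`).

```sql
-- Default (BERT) tokenizer path. The <&> score is negative,
-- so an ascending ORDER BY returns the best matches first.
SELECT id, passage
FROM documents
ORDER BY embedding <&> to_bm25query('documents_embedding_bm25', 'how to rank documents with bm25')
LIMIT 10;

-- UNICODE tokenizer path: the query is tokenized against the tokens
-- stored in bm25_catalog.documents_tokens.
SELECT id, passage
FROM documents
ORDER BY embedding <&> bm25_query_unicode_tokenize('documents_embedding_bm25', 'how to rank documents with bm25', 'documents_tokens')
LIMIT 10;
```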

### GUCs

- `bm25_catalog.bm25_limit (integer)`: The maximum number of documents to return in a search. Default is 100, minimum is 1, and maximum is 65535.
- `bm25_catalog.enable_index (boolean)`: Whether to enable the bm25 index. Default is true.
- `bm25_catalog.segment_growing_max_page_size (integer)`: The maximum page count of the growing segment. When the size of the growing segment exceeds this value, the segment will be sealed into a read-only segment. Default is 1000, minimum is 1, and maximum is 1,000,000.
- `bm25_catalog.tokenizer (text)`: Tokenizer chosen from the following (a session-level example follows this list):
  - `BERT`: the default uncased BERT tokenizer.
  - `TOCKEN`: a Unicode tokenizer pre-trained on wiki-103-raw.
  - `UNICODE`: a Unicode tokenizer trained on your own data (needs to be used together with the trigger function `unicode_tokenizer_trigger`).
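
As an illustration, and assuming these GUCs are user-settable (the registrations in `src/guc.rs` below use `GucContext::Userset`), they can be changed per session; the values here are arbitrary examples.

```sql
SET bm25_catalog.bm25_limit = 100;      -- cap the number of documents returned by a search
SET bm25_catalog.tokenizer = 'TOCKEN';  -- switch to the pre-trained Unicode tokenizer
SHOW bm25_catalog.tokenizer;            -- confirm the active tokenizer
```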

## Contribution

- To add a new tokenizer, check the [tokenizer](./tokenizer.md#contribution) document.

## License

120 changes: 99 additions & 21 deletions src/datatype/functions.rs
@@ -1,34 +1,20 @@
use std::{collections::BTreeMap, num::NonZero};
use std::{collections::HashMap, num::NonZero};

use pgrx::{pg_sys::panic::ErrorReportable, IntoDatum};

use crate::{
page::{page_read, METAPAGE_BLKNO},
segment::meta::MetaPageData,
segment::term_stat::TermStatReader,
segment::{meta::MetaPageData, term_stat::TermStatReader},
token::unicode_tokenize,
weight::bm25_score_batch,
};

use super::{
memory_bm25vector::{Bm25VectorInput, Bm25VectorOutput},
Bm25VectorBorrowed,
};
use super::memory_bm25vector::{Bm25VectorInput, Bm25VectorOutput};

#[pgrx::pg_extern(immutable, strict, parallel_safe)]
pub fn tokenize(text: &str) -> Bm25VectorOutput {
let term_ids = crate::token::tokenize(text);
let mut map: BTreeMap<u32, u32> = BTreeMap::new();
for term_id in term_ids {
*map.entry(term_id).or_insert(0) += 1;
}
let mut doc_len: u32 = 0;
let mut indexes = Vec::with_capacity(map.len());
let mut values = Vec::with_capacity(map.len());
for (index, value) in map {
indexes.push(index);
values.push(value);
doc_len = doc_len.checked_add(value).expect("overflow");
}
let vector = unsafe { Bm25VectorBorrowed::new_unchecked(doc_len, &indexes, &values) };
Bm25VectorOutput::new(vector)
Bm25VectorOutput::from_ids(&term_ids)
}

#[pgrx::pg_extern(stable, strict, parallel_safe)]
@@ -66,3 +52,95 @@ pub fn search_bm25query(

scores * -1.0
}

#[pgrx::pg_extern()]
pub fn document_unicode_tokenize(text: &str, token_table: &str) -> Bm25VectorOutput {
let tokens = unicode_tokenize(text);
let args = Some(vec![(
pgrx::PgBuiltInOids::TEXTARRAYOID.oid(),
tokens.clone().into_datum(),
)]);

let mut token_ids = HashMap::new();
pgrx::Spi::connect(|mut client| {
let query = format!(
r#"
WITH new_tokens AS (SELECT unnest($1::text[]) AS token),
to_insert AS (
SELECT token FROM new_tokens
WHERE NOT EXISTS (
SELECT 1 FROM bm25_catalog.{} WHERE token = new_tokens.token
)
),
ins AS (
INSERT INTO bm25_catalog.{} (token)
SELECT token FROM to_insert
ON CONFLICT (token) DO NOTHING
RETURNING id, token
)
SELECT id, token FROM ins
UNION ALL
SELECT id, token FROM bm25_catalog.{} WHERE token = ANY($1);
"#,
token_table, token_table, token_table
);
let table = client.update(&query, None, args).unwrap_or_report();
for row in table {
let id: i32 = row
.get_by_name("id")
.expect("no id column")
.expect("no id value");
let token: String = row
.get_by_name("token")
.expect("no token column")
.expect("no token value");
token_ids.insert(token, id as u32);
}
});

let ids = tokens
.iter()
.map(|t| *token_ids.get(t).expect("unknown token"))
.collect::<Vec<_>>();
Bm25VectorOutput::from_ids(&ids)
}

#[pgrx::pg_extern(immutable, strict, parallel_safe)]
pub fn query_unicode_tokenize(query: &str, token_table: &str) -> Bm25VectorOutput {
let tokens = unicode_tokenize(query);
let args = Some(vec![(
pgrx::PgBuiltInOids::TEXTARRAYOID.oid(),
tokens.clone().into_datum(),
)]);
let mut token_ids = HashMap::new();
pgrx::Spi::connect(|client| {
let table = client
.select(
&format!(
"SELECT id, token FROM bm25_catalog.{} WHERE token = ANY($1);",
token_table
),
None,
args,
)
.unwrap_or_report();
for row in table {
let id: i32 = row
.get_by_name("id")
.expect("no id column")
.expect("no id value");
let token: String = row
.get_by_name("token")
.expect("no token column")
.expect("no token value");
token_ids.insert(token, id as u32);
}
});

let ids = tokens
.iter()
.filter(|&t| token_ids.contains_key(t))
.map(|t| *token_ids.get(t).unwrap())
.collect::<Vec<_>>();
Bm25VectorOutput::from_ids(&ids)
}
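
Both functions above are exported to SQL via `pg_extern`, so they can also be called directly. A small sketch, assuming a stored-token table `bm25_catalog.documents_tokens` already exists (the trigger in `src/sql/finalize.sql` below creates it on demand):

```sql
-- Tokenize a document; tokens not yet known are inserted into bm25_catalog.documents_tokens.
SELECT document_unicode_tokenize('the quick brown fox jumps over the lazy dog', 'documents_tokens');

-- Tokenize a query against the same table; tokens absent from the table are dropped.
SELECT query_unicode_tokenize('quick fox', 'documents_tokens');
```
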
18 changes: 18 additions & 0 deletions src/datatype/memory_bm25vector.rs
@@ -1,3 +1,4 @@
use std::collections::BTreeMap;
use std::{alloc::Layout, ops::Deref, ptr::NonNull};

use pgrx::{
@@ -106,6 +107,23 @@ impl Bm25VectorOutput {
}
}

pub fn from_ids(ids: &[u32]) -> Self {
let mut map: BTreeMap<u32, u32> = BTreeMap::new();
for term_id in ids {
*map.entry(*term_id).or_insert(0) += 1;
}
let mut doc_len: u32 = 0;
let mut indexes = Vec::with_capacity(map.len());
let mut values = Vec::with_capacity(map.len());
for (index, value) in map {
indexes.push(index);
values.push(value);
doc_len = doc_len.checked_add(value).expect("overflow");
}
let vector = unsafe { Bm25VectorBorrowed::new_unchecked(doc_len, &indexes, &values) };
Self::new(vector)
}

pub fn into_raw(self) -> *mut Bm25VectorHeader {
let result = self.0.as_ptr();
std::mem::forget(self);
12 changes: 12 additions & 0 deletions src/guc.rs
@@ -1,8 +1,12 @@
use std::ffi::CStr;

use pgrx::{GucContext, GucFlags, GucRegistry, GucSetting};

pub static BM25_LIMIT: GucSetting<i32> = GucSetting::<i32>::new(100);
pub static ENABLE_INDEX: GucSetting<bool> = GucSetting::<bool>::new(true);
pub static SEGMENT_GROWING_MAX_PAGE_SIZE: GucSetting<i32> = GucSetting::<i32>::new(1000);
pub static TOKENIZER_NAME: GucSetting<Option<&CStr>> =
GucSetting::<Option<&CStr>>::new(Some(c"BERT"));

pub unsafe fn init() {
GucRegistry::define_int_guc(
@@ -33,4 +37,12 @@ pub unsafe fn init() {
GucContext::Userset,
GucFlags::default(),
);
GucRegistry::define_string_guc(
"bm25_catalog.tokenizer",
"tokenizer name",
"tokenizer name",
&TOKENIZER_NAME,
GucContext::Userset,
GucFlags::default(),
);
}
37 changes: 37 additions & 0 deletions src/sql/finalize.sql
@@ -31,3 +31,40 @@ CREATE OPERATOR FAMILY bm25_ops USING bm25;

CREATE OPERATOR CLASS bm25_ops FOR TYPE bm25vector USING bm25 FAMILY bm25_ops AS
OPERATOR 1 pg_catalog.<&>(bm25vector, bm25query) FOR ORDER BY float_ops;

-- CREATE TABLE IF NOT EXISTS bm25_catalog.unicode_tokenizer (
-- id INT BY DEFAULT AS IDENTITY PRIMARY KEY,
-- source_table TEXT NOT NULL,
-- source_column TEXT NOT NULL,
-- target_column TEXT NOT NULL,
-- token_table TEXT NOT NULL
-- );

CREATE OR REPLACE FUNCTION unicode_tokenizer_trigger()
RETURNS TRIGGER AS $$
DECLARE
source_column TEXT := TG_ARGV[0];
target_column TEXT := TG_ARGV[1];
token_table TEXT := TG_ARGV[2];
schema_name TEXT := 'bm25_catalog';
result bm25vector;
BEGIN
IF NOT EXISTS (
SELECT 1
FROM information_schema.tables
WHERE table_schema = schema_name AND table_name = token_table
) THEN
EXECUTE format('CREATE TABLE %I.%I (id int GENERATED BY DEFAULT AS IDENTITY PRIMARY KEY, token TEXT UNIQUE)', schema_name, token_table);
-- EXECUTE format('INSERT INTO %I.unicode_tokenizer (source_table, source_column, target_column, token_table) VALUES (%L, %L, %L, %L)', schema_name, TG_TABLE_NAME, source_column, target_column, token_table);
END IF;

EXECUTE format('select document_unicode_tokenize($1.%I, ''%I'')', source_column, token_table) INTO result USING NEW;
EXECUTE format('UPDATE %I SET %I = %L WHERE id = $1.id', TG_TABLE_NAME, target_column, result) USING NEW;
RETURN NEW;
END;
$$ LANGUAGE plpgsql;

CREATE OR REPLACE FUNCTION bm25_query_unicode_tokenize(index_oid regclass, query text, token_table text) RETURNS bm25query
IMMUTABLE STRICT PARALLEL SAFE LANGUAGE sql AS $$
SELECT index_oid, query_unicode_tokenize(query, token_table);
$$;
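
One plausible way to wire this trigger, sketched with hypothetical table and column names. Because the function updates the row by `id` (rather than modifying `NEW`), the source table is assumed to have an `id` column and the trigger is attached `AFTER INSERT` here.

```sql
-- Hypothetical source table with a text column and a bm25vector column.
CREATE TABLE documents (
    id        INT GENERATED BY DEFAULT AS IDENTITY PRIMARY KEY,
    passage   TEXT,
    embedding bm25vector
);

-- Arguments: source text column, target bm25vector column, stored-token table
-- (created under bm25_catalog on first use by the function above).
CREATE TRIGGER documents_unicode_tokenizer
AFTER INSERT ON documents
FOR EACH ROW
EXECUTE FUNCTION unicode_tokenizer_trigger('passage', 'embedding', 'documents_tokens');

-- Inserting a row populates documents.embedding and bm25_catalog.documents_tokens.
INSERT INTO documents (passage) VALUES ('PostgreSQL ranking with Block-WeakAnd BM25');
```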