diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index c4c6a79..5b6f8ae 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -1,4 +1,4 @@ -name: Cargo Build & Test +name: Cargo Build & Test the Crate on: push: diff --git a/.github/workflows/ci_python.yml b/.github/workflows/ci_python.yml new file mode 100644 index 0000000..38a8da5 --- /dev/null +++ b/.github/workflows/ci_python.yml @@ -0,0 +1,231 @@ +name: Cargo Build & Test the Python Bindings + +defaults: + run: + working-directory: python + +on: + push: + branches: + - main + tags: + - "*" + pull_request: + workflow_dispatch: + +permissions: + contents: read + +jobs: + format: + name: Check Python format + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + with: + fetch-depth: 0 + - name: Set up Python + uses: actions/setup-python@v5 + - name: Install dependencies + run: pip install ruff black + - name: Ruff + run: ruff check . + - name: Black + run: black --check --diff . + + rustfmt: + name: Check Rust format + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + with: + fetch-depth: 0 + - run: rustup update stable && rustup default stable + - run: rustup component add rustfmt + - run: cargo fmt --all --check + + test: + name: Run tests + runs-on: ubuntu-latest + strategy: + matrix: + python-version: ["3.10", "3.11", "3.12", "pypy3.10"] + + steps: + - uses: actions/checkout@v4 + with: + fetch-depth: 0 + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v4 + with: + python-version: ${{ matrix.python-version }} + - name: Install locally + run: pip install -e ".[test]" + - name: Install additional dependencies + run: pip install pytest-md pytest-emoji + - uses: pavelzw/pytest-action@v2 + with: + emoji: false + verbose: true + job-summary: true + - name: Test building wheels + uses: PyO3/maturin-action@v1 + with: + sccache: true + manylinux: auto + + linux: + runs-on: ubuntu-latest + strategy: + matrix: + target: [x86_64, aarch64, armv7] + steps: + - uses: actions/checkout@v4 + with: + fetch-depth: 0 + - uses: actions/setup-python@v4 + with: + python-version: "3.10" + - name: Build wheels + uses: PyO3/maturin-action@v1 + with: + target: ${{ matrix.target }} + args: --release --out dist --interpreter 3.8 pypy3.8 pypy3.9 pypy3.10 + sccache: true + manylinux: auto + - name: Upload wheels + uses: actions/upload-artifact@v4 + with: + name: wheels-linux-${{ matrix.target }} + path: dist + - name: pytest + if: ${{ startsWith(matrix.target, 'x86_64') }} + shell: bash + run: | + set -e + pip install --pre "mtc_token_healing[test]" --find-links dist --force-reinstall + pytest --import-mode=importlib + - name: pytest + if: ${{ !startsWith(matrix.target, 'x86') && matrix.target != 'ppc64' }} + uses: uraimo/run-on-arch-action@v2.7.1 + with: + arch: ${{ matrix.target }} + distro: ubuntu22.04 + githubToken: ${{ github.token }} + install: | + apt-get update + apt-get install -y --no-install-recommends python3 python3-pip + pip3 install -U pip + run: | + set -e + pip3 install --pre "mtc_token_healing[test]" --find-links dist --force-reinstall + pytest --import-mode=importlib + + windows: + runs-on: windows-latest + strategy: + matrix: + target: [x64] + steps: + - uses: actions/checkout@v4 + with: + fetch-depth: 0 + - uses: actions/setup-python@v4 + with: + python-version: "3.10" + architecture: ${{ matrix.target }} + - name: Build wheels + uses: PyO3/maturin-action@v1 + with: + target: ${{ matrix.target }} + args: --release --out dist --interpreter 3.8 pypy3.8 pypy3.9 pypy3.10 + sccache: true + - name: Upload wheels + uses: actions/upload-artifact@v4 + with: + name: wheels-windows-${{ matrix.target }} + path: dist + - name: pytest + if: ${{ !startsWith(matrix.target, 'aarch64') }} + shell: bash + run: | + set -e + pip install --pre "mtc_token_healing[test]" --find-links dist --force-reinstall + pytest --import-mode=importlib + + macos: + runs-on: macos-latest + strategy: + matrix: + target: [x86_64, aarch64] + steps: + - uses: actions/checkout@v4 + with: + fetch-depth: 0 + - uses: actions/setup-python@v4 + with: + python-version: "3.10" + - name: Build wheels + uses: PyO3/maturin-action@v1 + with: + target: ${{ matrix.target }} + args: --release --out dist --interpreter 3.8 pypy3.8 pypy3.9 pypy3.10 + sccache: true + - name: Upload wheels + uses: actions/upload-artifact@v4 + with: + name: wheels-macos-${{ matrix.target }} + path: dist + - name: pytest + if: ${{ !startsWith(matrix.target, 'aarch64') }} + shell: bash + run: | + set -e + pip install --pre "mtc_token_healing[test]" --find-links dist --force-reinstall + pytest --import-mode=importlib + + sdist: + needs: [test] + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + with: + fetch-depth: 0 + - name: Build sdist + uses: PyO3/maturin-action@v1 + with: + command: sdist + args: --out dist + - name: Upload sdist + uses: actions/upload-artifact@v4 + with: + name: wheels-sdist + path: dist + + release: + name: Release + runs-on: ubuntu-latest + if: "startsWith(github.ref, 'refs/tags/')" + needs: [test, format, rustfmt, linux, windows, macos, sdist] + permissions: + # Used to upload release artifacts + contents: write + steps: + - uses: actions/download-artifact@v4 + with: + pattern: wheels-* + merge-multiple: true + - name: Publish to PyPI + uses: PyO3/maturin-action@v1 + env: + MATURIN_PYPI_TOKEN: ${{ secrets.PYPI_API_TOKEN }} + with: + command: upload + args: --non-interactive --skip-existing * + - name: Upload to GitHub Release + uses: softprops/action-gh-release@v2 + with: + files: | + *.whl + *.tar.gz + prerelease: ${{ contains(github.ref, 'alpha') || contains(github.ref, 'beta') }} diff --git a/Cargo.toml b/Cargo.toml index fa00f84..6a448a2 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -9,6 +9,7 @@ description = "Token healing implementation" repository = "https://github.com/ModelTC/mtc-token-healing" homepage = "https://github.com/ModelTC/mtc-token-healing" documentation = "https://docs.rs/mtc-token-healing" +authors = ["Chielo Newctle "] [package] name = "mtc-token-healing" @@ -19,16 +20,16 @@ description.workspace = true repository.workspace = true homepage.workspace = true documentation.workspace = true +authors.workspace = true readme = "README.md" -authors = ["Chielo Newctle "] -exclude = ["release-plz.toml", ".github"] +exclude = ["release-plz.toml", ".github", "python"] [dependencies] derive_more = "0.99.17" general-sam = { version = "1.0.0", features = ["trie"] } pyo3 = { version = "0.21.2", optional = true } smallvec = "1.13.2" -thiserror = "1.0.59" +thiserror = "1.0.60" [features] pyo3 = ["dep:pyo3"] @@ -38,7 +39,7 @@ clap = { version = "4.5.4", features = ["derive", "env"] } color-eyre = "0.6.3" rand = "0.8.5" regex = "1.10.4" -serde_json = "1.0.116" +serde_json = "1.0.117" tokenizers = { version = "0.19.1", features = ["hf-hub", "http"] } tokio = { version = "1.37.0", features = ["rt-multi-thread"] } diff --git a/python/Cargo.toml b/python/Cargo.toml index 52d1c2d..f47f7b1 100644 --- a/python/Cargo.toml +++ b/python/Cargo.toml @@ -7,6 +7,7 @@ description.workspace = true repository.workspace = true homepage.workspace = true documentation.workspace = true +authors.workspace = true [lib] name = "mtc_token_healing" diff --git a/python/mtc_token_healing/__init__.py b/python/mtc_token_healing/__init__.py index 0fb207b..bfad032 100644 --- a/python/mtc_token_healing/__init__.py +++ b/python/mtc_token_healing/__init__.py @@ -1,5 +1,21 @@ -from .mtc_token_healing import CountInfo +from .mtc_token_healing import ( + BestChoice, + CountInfo, + InferRequest, + InferResponse, + Prediction, + VocabPrefixAutomaton, + ReorderedTokenId, + SearchTree, +) __all__ = [ + "BestChoice", "CountInfo", + "InferRequest", + "InferResponse", + "Prediction", + "VocabPrefixAutomaton", + "ReorderedTokenId", + "SearchTree", ] diff --git a/python/mtc_token_healing/mtc_token_healing.pyi b/python/mtc_token_healing/mtc_token_healing.pyi index aea03d5..58dfcae 100644 --- a/python/mtc_token_healing/mtc_token_healing.pyi +++ b/python/mtc_token_healing/mtc_token_healing.pyi @@ -1 +1,10 @@ +TokenId = int + +class BestChoice: ... class CountInfo: ... +class InferRequest: ... +class InferResponse: ... +class Prediction: ... +class VocabPrefixAutomaton: ... +class ReorderedTokenId: ... +class SearchTree: ... diff --git a/python/pyproject.toml b/python/pyproject.toml index be94c4f..acedf13 100644 --- a/python/pyproject.toml +++ b/python/pyproject.toml @@ -12,5 +12,7 @@ classifiers = [ ] dynamic = ["version"] +[tool.maturin] + [project.optional-dependencies] test = ["pytest"] diff --git a/python/src/lib.rs b/python/src/lib.rs index 578049b..3476516 100644 --- a/python/src/lib.rs +++ b/python/src/lib.rs @@ -1,8 +1,18 @@ -use ::mtc_token_healing::CountInfo; +use ::mtc_token_healing::{ + vocab::PyVocabPrefixAutomaton, BestChoice, CountInfo, InferRequest, InferResponse, Prediction, + ReorderedTokenId, SearchTree, +}; use pyo3::prelude::*; #[pymodule] fn mtc_token_healing(_py: Python, m: &Bound<'_, PyModule>) -> PyResult<()> { + m.add_class::()?; m.add_class::()?; + m.add_class::()?; + m.add_class::()?; + m.add_class::()?; + m.add_class::()?; + m.add_class::()?; + m.add_class::()?; Ok(()) } diff --git a/src/choice.rs b/src/choice.rs index 7476c53..a243734 100644 --- a/src/choice.rs +++ b/src/choice.rs @@ -1,6 +1,7 @@ use crate::TokenId; #[derive(Clone, Debug)] +#[cfg_attr(feature = "pyo3", pyo3::pyclass(get_all, frozen))] pub struct BestChoice { pub extra_token_ids: Vec, pub accum_log_prob: f64, diff --git a/src/search_tree.rs b/src/search_tree.rs index eb3e4bd..eedc03d 100644 --- a/src/search_tree.rs +++ b/src/search_tree.rs @@ -28,12 +28,14 @@ pub enum SearchTreeError { pub type SearchTreeResult = Result; #[derive(Clone, Debug, Default)] +#[cfg_attr(feature = "pyo3", pyo3::pyclass(get_all, set_all))] pub struct Prediction { pub token_id: ReorderedTokenId, pub log_prob: f64, } #[derive(Clone, Debug)] +#[cfg_attr(feature = "pyo3", pyo3::pyclass(get_all, frozen))] pub struct InferRequest { pub backtrace: usize, pub feed: Option, @@ -42,6 +44,7 @@ pub struct InferRequest { } #[derive(Clone, Debug, Default)] +#[cfg_attr(feature = "pyo3", pyo3::pyclass(get_all, set_all))] pub struct InferResponse { pub sampled: Option, pub sparse_choices: Vec, @@ -56,6 +59,7 @@ struct SearchState { } #[derive(Debug)] +#[cfg_attr(feature = "pyo3", pyo3::pyclass)] pub struct SearchTree { automaton: Arc, @@ -302,3 +306,104 @@ impl SearchTree { self.max_num_tokens } } + +#[cfg(feature = "pyo3")] +mod _pyo3 { + use std::collections::BTreeMap; + + use pyo3::{ + exceptions::PyValueError, pymethods, types::PyType, Bound, PyErr, PyObject, PyRefMut, + PyResult, Python, + }; + + use crate::{ + vocab::PyVocabPrefixAutomaton, BestChoice, InferRequest, InferResponse, Prediction, + ReorderedTokenId, SearchTree, SearchTreeError, TokenId, + }; + + use super::SearchTreeResult; + + impl From for PyErr { + fn from(value: SearchTreeError) -> Self { + PyValueError::new_err(value.to_string()) + } + } + + #[pymethods] + impl Prediction { + #[new] + pub fn new_py(token_id: ReorderedTokenId, log_prob: f64) -> Self { + Self { token_id, log_prob } + } + } + + #[pymethods] + impl InferResponse { + #[new] + pub fn new_py( + sampled: Option, + sparse_choices: Option>, + ) -> Self { + Self { + sampled, + sparse_choices: sparse_choices.unwrap_or_default(), + } + } + } + + #[pymethods] + impl SearchTree { + #[pyo3(name = "get_prefilled_token_ids")] + pub fn prefilled_token_ids_py(&self) -> Vec { + self.prefilled_token_ids.clone() + } + + #[pyo3(name = "get_best_choice")] + pub fn get_best_choice_py(&self) -> SearchTreeResult { + if self.best_choice.valid() { + Ok(self.best_choice.clone()) + } else { + Err(SearchTreeError::NoSampledResult) + } + } + + #[classmethod] + #[pyo3(name = "new")] + pub fn new_py( + _cls: &Bound<'_, PyType>, + py: Python<'_>, + automaton: &'_ PyVocabPrefixAutomaton, + tokenize_for_multiple_ending_positions: PyObject, + text: &str, + start_from: usize, + ) -> PyResult> { + let pos_to_cnt_info = + BTreeMap::from_iter(automaton.as_ref().parse_chars(text, start_from)); + + let end_pos = Vec::from_iter(pos_to_cnt_info.keys().copied()); + + let encoded: Vec<(usize, Vec)> = tokenize_for_multiple_ending_positions + .call1(py, (end_pos,))? + .extract(py)?; + + Ok(Self::from_encoded( + automaton.0.clone(), + pos_to_cnt_info, + encoded, + )) + } + + #[pyo3(name = "feed")] + pub fn feed_py( + mut self_: PyRefMut<'_, Self>, + res: InferResponse, + ) -> SearchTreeResult> { + self_.feed(res) + } + + #[getter("max_num_tokens")] + pub fn max_num_tokens_py(&self) -> usize { + self.max_num_tokens + } + } +} diff --git a/src/utils.rs b/src/utils.rs index f4e270b..5d0ee3a 100644 --- a/src/utils.rs +++ b/src/utils.rs @@ -4,8 +4,6 @@ use general_sam::{ BTreeTransTable, BoxBisectTable, GeneralSam, TransitionTable, Trie, TrieNodeAlike, SAM_ROOT_NODE_ID, }; -#[cfg(feature = "pyo3")] -use pyo3::pyclass; use smallvec::SmallVec; pub type TokenId = u32; @@ -22,10 +20,40 @@ pub type TokenId = u32; PartialOrd, Ord, )] +#[cfg_attr(feature = "pyo3", pyo3::pyclass)] pub struct ReorderedTokenId(pub u32); +#[cfg(feature = "pyo3")] +mod _pyo3 { + use pyo3::pymethods; + + use crate::ReorderedTokenId; + + #[pymethods] + impl ReorderedTokenId { + #[new] + fn new(value: u32) -> Self { + Self(value) + } + + pub fn __int__(&self) -> u32 { + self.0 + } + + #[getter] + pub fn get_value(&self) -> u32 { + self.0 + } + + #[setter] + pub fn set_value(&mut self, value: u32) { + self.0 = value; + } + } +} + #[derive(Clone, Debug, Default, PartialEq, Eq)] -#[cfg_attr(feature = "pyo3", pyclass(get_all, set_all))] +#[cfg_attr(feature = "pyo3", pyo3::pyclass(get_all, set_all))] pub struct CountInfo { pub cnt: usize, pub tot_cnt_lower: usize, diff --git a/src/vocab.rs b/src/vocab.rs index 15e187f..58d1497 100644 --- a/src/vocab.rs +++ b/src/vocab.rs @@ -1,6 +1,4 @@ use general_sam::{BoxBisectTable, GeneralSam}; -#[cfg(feature = "pyo3")] -use pyo3::pyclass; use crate::{ utils::{build_sam_of_reversed_tokens, gen_sam_cnt_info, sort_vocab_with_trie, TokenBytes}, @@ -8,7 +6,6 @@ use crate::{ }; #[derive(Clone, Debug)] -#[cfg_attr(feature = "pyo3", pyclass)] pub struct VocabPrefixAutomaton { vocab: Vec, order: Vec, @@ -83,17 +80,23 @@ impl VocabPrefixAutomaton { #[cfg(feature = "pyo3")] mod _pyo3 { - use pyo3::pymethods; + use std::sync::Arc; + + use pyo3::{pyclass, pymethods}; - use crate::utils::CountInfo; + use crate::{utils::CountInfo, ReorderedTokenId, TokenId}; use super::VocabPrefixAutomaton; + #[derive(Clone, Debug, derive_more::Deref)] + #[pyclass(name = "VocabPrefixAutomaton", frozen)] + pub struct PyVocabPrefixAutomaton(pub Arc); + #[pymethods] - impl VocabPrefixAutomaton { + impl PyVocabPrefixAutomaton { #[new] - fn py_new(vocab: Vec>) -> Self { - Self::new(vocab) + fn py_new(vocab: Vec) -> Self { + Self(Arc::new(VocabPrefixAutomaton::new(vocab))) } #[pyo3(name = "vocab_size")] @@ -102,21 +105,21 @@ mod _pyo3 { } #[pyo3(name = "get_order")] - fn get_order_py(&self) -> Vec { + fn get_order_py(&self) -> Vec { self.order.clone() } #[pyo3(name = "get_rank")] - fn get_rank_py(&self) -> Vec { - self.rank.iter().map(|x| x.0).collect() + fn get_rank_py(&self) -> Vec { + self.rank.clone() } #[pyo3(name = "parse_chars")] fn parse_chars_py(&self, text: &str, start_from: usize) -> Vec<(usize, CountInfo)> { self.parse_chars(text, start_from) - .into_iter() - .map(|(pos, cnt_info)| (pos, cnt_info.clone())) - .collect() } } } + +#[cfg(feature = "pyo3")] +pub use self::_pyo3::PyVocabPrefixAutomaton;