From 4bff3e4c82e47a95ade06ef98c1fab71555e8437 Mon Sep 17 00:00:00 2001 From: Michal Moskal Date: Tue, 26 Sep 2023 19:49:20 -0700 Subject: [PATCH 001/301] make gvmprog into gvm_abi lib --- gvm_abi/.cargo/config.toml | 2 + gvm_abi/Cargo.lock | 7 +++ gvm_abi/Cargo.toml | 9 +++ gvm_abi/build.sh | 10 ++++ gvm_abi/src/lib.rs | 120 +++++++++++++++++++++++++++++++++++++ 5 files changed, 148 insertions(+) create mode 100644 gvm_abi/.cargo/config.toml create mode 100644 gvm_abi/Cargo.lock create mode 100644 gvm_abi/Cargo.toml create mode 100755 gvm_abi/build.sh create mode 100644 gvm_abi/src/lib.rs diff --git a/gvm_abi/.cargo/config.toml b/gvm_abi/.cargo/config.toml new file mode 100644 index 00000000..f4e8c002 --- /dev/null +++ b/gvm_abi/.cargo/config.toml @@ -0,0 +1,2 @@ +[build] +target = "wasm32-unknown-unknown" diff --git a/gvm_abi/Cargo.lock b/gvm_abi/Cargo.lock new file mode 100644 index 00000000..d95b4cac --- /dev/null +++ b/gvm_abi/Cargo.lock @@ -0,0 +1,7 @@ +# This file is automatically @generated by Cargo. +# It is not intended for manual editing. +version = 3 + +[[package]] +name = "gvm_abi" +version = "0.1.0" diff --git a/gvm_abi/Cargo.toml b/gvm_abi/Cargo.toml new file mode 100644 index 00000000..ed3df9ec --- /dev/null +++ b/gvm_abi/Cargo.toml @@ -0,0 +1,9 @@ +[package] +name = "gvm_abi" +version = "0.1.0" +edition = "2021" + +[lib] +name = "gvm_abi" + +[dependencies] diff --git a/gvm_abi/build.sh b/gvm_abi/build.sh new file mode 100755 index 00000000..2cfcc76a --- /dev/null +++ b/gvm_abi/build.sh @@ -0,0 +1,10 @@ +#!/bin/sh + +set -x +set -e +cargo build --release +wasm-opt -Oz target/wasm32-unknown-unknown/release/gvmprog.wasm -o target/opt.wasm +wasm-strip target/opt.wasm +p=`pwd` +cd ../gvmrt +cargo run -- --module $p/target/opt.wasm diff --git a/gvm_abi/src/lib.rs b/gvm_abi/src/lib.rs new file mode 100644 index 00000000..6050a78a --- /dev/null +++ b/gvm_abi/src/lib.rs @@ -0,0 +1,120 @@ +/// Expose method as extern "C", usage: +/// expose!(Foo::set_count(n: i32) -> i32); +/// Generates "C" function: +/// set_count(Foo *, i32) -> i32 +#[macro_export] +macro_rules! expose { + ($struct_name:ident :: $method_name:ident ( $($arg:ident : $typ:ty),* ) -> $ret:ty) => { + #[no_mangle] + pub extern "C" fn $method_name(self_: *mut $struct_name, $($arg : $typ),*) -> $ret { + unsafe { + (&mut *self_).$method_name($($arg),*) + } + } + }; + ($struct_name:ident :: $field:ident :: $method_name:ident ( $($arg:ident : $typ:ty),* ) -> $ret:ty) => { + #[no_mangle] + pub extern "C" fn $method_name(self_: *mut $struct_name, $($arg : $typ),*) -> $ret { + unsafe { + (&mut *self_).$field.$method_name($($arg),*) + } + } + }; +} + +#[derive(Clone)] +pub struct GuidanceVmHelper { + tokens: Vec, + prompt_length: usize, + logit_biases: Vec, +} + +// gvm_* are exposed to C in both GuidanceVm and GuidanceVmHelper +impl GuidanceVmHelper { + pub fn new() -> Self { + GuidanceVmHelper { + tokens: Vec::new(), + prompt_length: 0, + logit_biases: Vec::new(), + } + } + pub fn gvm_get_logit_bias_buffer(&mut self, size: u32) -> *mut f32 { + self.logit_biases.resize(size as usize, 0.0); + self.logit_biases.as_mut_ptr() + } + pub fn gvm_get_prompt_buffer(&mut self, size: u32) -> *mut u32 { + self.prompt_length = size as usize; + self.tokens.resize(self.prompt_length, 0); + self.tokens.as_mut_ptr() + } +} + +pub trait GuidanceVm { + /// Create a new instance of VM + fn gvm_create() -> Self; + /// Create a new instance of VM, based on existing instance, for example when doing beam-search. 
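+    /// The clone starts in the same state as `self`; the host may then advance the two copies independently.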
+ fn gvm_clone(&mut self) -> Self; + /// The prompt is in self.helper.tokens. + /// On return, self.helper.logit_biases are supposed to be updated. + fn gvm_process_prompt(&mut self); + /// On return, self.helper.logit_biases are supposed to be updated. + fn gvm_append_token(&mut self, token: u32); +} + +struct MyGvm { + helper: GuidanceVmHelper, +} + +impl GuidanceVm for MyGvm { + fn gvm_create() -> Self { + MyGvm { + helper: GuidanceVmHelper::new(), + } + } + + fn gvm_process_prompt(&mut self) {} + + fn gvm_append_token(&mut self, token: u32) { + let toks = &mut self.helper.tokens; + toks.push(token); + // finish generation at 10 tokens + if toks.len() - self.helper.prompt_length >= 3 { + self.helper.logit_biases[50256] = 100.0 + } else { + self.helper.logit_biases[50256] = -100.0 + } + } + + fn gvm_clone(&mut self) -> Self { + MyGvm { + helper: self.helper.clone(), + } + } +} + +#[macro_export] +macro_rules! gvm_expose_all { + ($struct_name:ident ) => { + expose!($struct_name::gvm_process_prompt() -> ()); + expose!($struct_name::gvm_append_token(token: u32) -> ()); + expose!($struct_name::helper::gvm_get_logit_bias_buffer(size: u32) -> *mut f32); + expose!($struct_name::helper::gvm_get_prompt_buffer(size: u32) -> *mut u32); + + #[no_mangle] + pub extern "C" fn gvm_create() -> *mut $struct_name { + let b = Box::new($struct_name::gvm_create()); + Box::into_raw(b) + } + + #[no_mangle] + pub extern "C" fn gvm_clone(self_: *mut $struct_name) -> *mut $struct_name { + let b = unsafe { (&mut *self_).gvm_clone() }; + Box::into_raw(Box::new(b)) + } + + #[no_mangle] + pub extern "C" fn gvm_free(self_: *mut $struct_name) { + let _drop = unsafe { Box::from_raw(self_) }; + } + } +} From 7273e98474e1bd7d14fd68b575c66c3736d2fe63 Mon Sep 17 00:00:00 2001 From: Michal Moskal Date: Tue, 26 Sep 2023 19:55:57 -0700 Subject: [PATCH 002/301] use the new lib --- gvm_abi/build.sh | 10 ---------- gvm_abi/src/lib.rs | 37 +++---------------------------------- 2 files changed, 3 insertions(+), 44 deletions(-) delete mode 100755 gvm_abi/build.sh diff --git a/gvm_abi/build.sh b/gvm_abi/build.sh deleted file mode 100755 index 2cfcc76a..00000000 --- a/gvm_abi/build.sh +++ /dev/null @@ -1,10 +0,0 @@ -#!/bin/sh - -set -x -set -e -cargo build --release -wasm-opt -Oz target/wasm32-unknown-unknown/release/gvmprog.wasm -o target/opt.wasm -wasm-strip target/opt.wasm -p=`pwd` -cd ../gvmrt -cargo run -- --module $p/target/opt.wasm diff --git a/gvm_abi/src/lib.rs b/gvm_abi/src/lib.rs index 6050a78a..2748138d 100644 --- a/gvm_abi/src/lib.rs +++ b/gvm_abi/src/lib.rs @@ -24,9 +24,9 @@ macro_rules! 
expose { #[derive(Clone)] pub struct GuidanceVmHelper { - tokens: Vec, - prompt_length: usize, - logit_biases: Vec, + pub tokens: Vec, + pub prompt_length: usize, + pub logit_biases: Vec, } // gvm_* are exposed to C in both GuidanceVm and GuidanceVmHelper @@ -61,37 +61,6 @@ pub trait GuidanceVm { fn gvm_append_token(&mut self, token: u32); } -struct MyGvm { - helper: GuidanceVmHelper, -} - -impl GuidanceVm for MyGvm { - fn gvm_create() -> Self { - MyGvm { - helper: GuidanceVmHelper::new(), - } - } - - fn gvm_process_prompt(&mut self) {} - - fn gvm_append_token(&mut self, token: u32) { - let toks = &mut self.helper.tokens; - toks.push(token); - // finish generation at 10 tokens - if toks.len() - self.helper.prompt_length >= 3 { - self.helper.logit_biases[50256] = 100.0 - } else { - self.helper.logit_biases[50256] = -100.0 - } - } - - fn gvm_clone(&mut self) -> Self { - MyGvm { - helper: self.helper.clone(), - } - } -} - #[macro_export] macro_rules! gvm_expose_all { ($struct_name:ident ) => { From 364a8be5862c78c198d13d0f92cd758a1639e8b0 Mon Sep 17 00:00:00 2001 From: Michal Moskal Date: Wed, 27 Sep 2023 08:26:26 -0700 Subject: [PATCH 003/301] macros for regexes --- gvm_abi/src/lib.rs | 48 +++++++++++++++++---- gvm_abi/src/rx.rs | 103 ++++++++++++++++++++++++++++++++++++++++++++ gvm_abi/src/rxvm.rs | 55 +++++++++++++++++++++++ 3 files changed, 198 insertions(+), 8 deletions(-) create mode 100644 gvm_abi/src/rx.rs create mode 100644 gvm_abi/src/rxvm.rs diff --git a/gvm_abi/src/lib.rs b/gvm_abi/src/lib.rs index 2748138d..e1da33e8 100644 --- a/gvm_abi/src/lib.rs +++ b/gvm_abi/src/lib.rs @@ -1,3 +1,6 @@ +pub mod rx; +pub mod rxvm; + /// Expose method as extern "C", usage: /// expose!(Foo::set_count(n: i32) -> i32); /// Generates "C" function: @@ -50,8 +53,6 @@ impl GuidanceVmHelper { } pub trait GuidanceVm { - /// Create a new instance of VM - fn gvm_create() -> Self; /// Create a new instance of VM, based on existing instance, for example when doing beam-search. fn gvm_clone(&mut self) -> Self; /// The prompt is in self.helper.tokens. @@ -63,15 +64,15 @@ pub trait GuidanceVm { #[macro_export] macro_rules! gvm_expose_all { - ($struct_name:ident ) => { - expose!($struct_name::gvm_process_prompt() -> ()); - expose!($struct_name::gvm_append_token(token: u32) -> ()); - expose!($struct_name::helper::gvm_get_logit_bias_buffer(size: u32) -> *mut f32); - expose!($struct_name::helper::gvm_get_prompt_buffer(size: u32) -> *mut u32); + ($struct_name:ident, $new:expr) => { + $crate::expose!($struct_name::gvm_process_prompt() -> ()); + $crate::expose!($struct_name::gvm_append_token(token: u32) -> ()); + $crate::expose!($struct_name::helper::gvm_get_logit_bias_buffer(size: u32) -> *mut f32); + $crate::expose!($struct_name::helper::gvm_get_prompt_buffer(size: u32) -> *mut u32); #[no_mangle] pub extern "C" fn gvm_create() -> *mut $struct_name { - let b = Box::new($struct_name::gvm_create()); + let b = Box::new($new); Box::into_raw(b) } @@ -87,3 +88,34 @@ macro_rules! gvm_expose_all { } } } + +#[macro_export] +macro_rules! 
include_bytes_as { + ($align_ty:ty, $path:literal) => {{ + #[repr(C)] // guarantee 'bytes' comes after '_align' + pub struct AlignedAs { + pub _align: [Align; 0], + pub bytes: Bytes, + } + + // this assignment is made possible by CoerceUnsized + static ALIGNED: &AlignedAs<$align_ty, [u8]> = &AlignedAs { + _align: [], + bytes: *include_bytes!($path), + }; + + let slice = &ALIGNED.bytes; + let ptr = slice.as_ptr() as *const $align_ty; + unsafe { std::slice::from_raw_parts(ptr, slice.len() / std::mem::size_of::<$align_ty>()) } + }}; +} + +#[macro_export] +macro_rules! TokenCompiled_from_bin { + () => { + $crate::rx::TokenCompiled { + token_data: $crate::include_bytes_as!(u16, "token_data.bin"), + state_data: $crate::include_bytes_as!(u32, "state_data.bin"), + } + }; +} diff --git a/gvm_abi/src/rx.rs b/gvm_abi/src/rx.rs new file mode 100644 index 00000000..13d0ff9a --- /dev/null +++ b/gvm_abi/src/rx.rs @@ -0,0 +1,103 @@ +const NO_TOKEN: TokenId = 0; + +pub type TokenId = u16; + +#[derive(Clone, Copy, PartialEq, Eq)] +pub struct StateId { + off: u32, +} + +impl StateId { + const NONE: StateId = StateId { off: 0 }; + const DEAD: StateId = StateId { off: 1 }; + pub const START: StateId = StateId { off: 4 }; +} + +#[derive(Clone)] +pub struct TokenCompiled { + pub token_data: &'static [u16], + pub state_data: &'static [u32], +} + +impl TokenCompiled { + fn token_in_token_set(&self, token: TokenId, set: u32) -> bool { + assert!(token != NO_TOKEN); + let mut idx = set as usize; + loop { + let v = self.token_data[idx]; + if v == token { + return true; + } + if v == NO_TOKEN { + return false; + } + idx = idx + 1; + } + } + + fn state_bias(state: StateId) -> f32 { + if state == StateId::DEAD { + -100.0 + } else { + 0.0 + } + } + + pub fn compute_logit_bias(&self, state_offset: StateId, bias: &mut [f32]) { + let mut p = state_offset.off as usize; + let default_state = StateId { + off: self.state_data[p], + }; + p += 1; + + let init_val = Self::state_bias(default_state); + for idx in 0..bias.len() { + bias[idx] = init_val; + } + + loop { + let state = StateId { + off: self.state_data[p], + }; + if state == StateId::NONE { + break; + } + p += 1; + let toks = self.state_data[p]; + p += 1; + let val = Self::state_bias(state); + + let mut idx = toks as usize; + loop { + let tok = self.token_data[idx]; + if tok == NO_TOKEN { + break; + } + bias[tok as usize] = val; + idx = idx + 1; + } + } + } + + pub fn advance(&self, state_offset: StateId, token: TokenId) -> StateId { + let mut p = state_offset.off as usize; + let default_state = StateId { + off: self.state_data[p], + }; + p += 1; + loop { + let state = StateId { + off: self.state_data[p], + }; + if state == StateId::NONE { + return default_state; + } + p += 1; + let toks = self.state_data[p]; + p += 1; + if self.token_in_token_set(token, toks) { + return state; + } + } + } +} diff --git a/gvm_abi/src/rxvm.rs b/gvm_abi/src/rxvm.rs new file mode 100644 index 00000000..228f251f --- /dev/null +++ b/gvm_abi/src/rxvm.rs @@ -0,0 +1,55 @@ +use crate::rx::{StateId, TokenCompiled}; +use crate::{GuidanceVm, GuidanceVmHelper}; + +pub struct RxGvm { + pub helper: GuidanceVmHelper, + pub compiled: TokenCompiled, + pub state: StateId, +} + +impl RxGvm { + pub fn from_token_compiled(compiled: TokenCompiled) -> Self { + RxGvm { + helper: GuidanceVmHelper::new(), + compiled, + state: StateId::START, + } + } +} + +impl GuidanceVm for RxGvm { + fn gvm_process_prompt(&mut self) { + // the regex doesn't care about the prompt + } + + fn gvm_append_token(&mut self, token: u32) { 
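+        // advance the regex automaton by the token the host just sampled;
+        // the logit biases for the next step are recomputed below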
+ self.state = self.compiled.advance(self.state, token as u16); + + // save the token, just in case + let toks = &mut self.helper.tokens; + toks.push(token); + + // compute biases + self.compiled + .compute_logit_bias(self.state, &mut self.helper.logit_biases); + } + + // implement by hand for now - we may need some special processing here + fn gvm_clone(&mut self) -> Self { + RxGvm { + helper: self.helper.clone(), + compiled: self.compiled.clone(), + state: self.state.clone(), + } + } +} + +#[macro_export] +macro_rules! RxGvm_from_bin { + () => { + $crate::gvm_expose_all!( + RxGvm, + RxGvm::from_token_compiled($crate::TokenCompiled_from_bin!()) + ); + }; +} From 8673a2a0166963eb692f2ebf0e8523ca57d3b383 Mon Sep 17 00:00:00 2001 From: Michal Moskal Date: Wed, 27 Sep 2023 12:49:55 -0700 Subject: [PATCH 004/301] now deserializes --- gvm_abi/src/lib.rs | 16 +----- gvm_abi/src/rx.rs | 120 +++++++++++++++++++++++++++++++++++++------- gvm_abi/src/rxvm.rs | 20 ++------ 3 files changed, 109 insertions(+), 47 deletions(-) diff --git a/gvm_abi/src/lib.rs b/gvm_abi/src/lib.rs index e1da33e8..34e34c21 100644 --- a/gvm_abi/src/lib.rs +++ b/gvm_abi/src/lib.rs @@ -90,7 +90,7 @@ macro_rules! gvm_expose_all { } #[macro_export] -macro_rules! include_bytes_as { +macro_rules! include_bytes_aligned { ($align_ty:ty, $path:literal) => {{ #[repr(C)] // guarantee 'bytes' comes after '_align' pub struct AlignedAs { @@ -104,18 +104,6 @@ macro_rules! include_bytes_as { bytes: *include_bytes!($path), }; - let slice = &ALIGNED.bytes; - let ptr = slice.as_ptr() as *const $align_ty; - unsafe { std::slice::from_raw_parts(ptr, slice.len() / std::mem::size_of::<$align_ty>()) } + &ALIGNED.bytes }}; } - -#[macro_export] -macro_rules! TokenCompiled_from_bin { - () => { - $crate::rx::TokenCompiled { - token_data: $crate::include_bytes_as!(u16, "token_data.bin"), - state_data: $crate::include_bytes_as!(u32, "state_data.bin"), - } - }; -} diff --git a/gvm_abi/src/rx.rs b/gvm_abi/src/rx.rs index 13d0ff9a..3b8fd85d 100644 --- a/gvm_abi/src/rx.rs +++ b/gvm_abi/src/rx.rs @@ -1,25 +1,109 @@ +use std::{mem::size_of, slice::from_raw_parts}; + const NO_TOKEN: TokenId = 0; pub type TokenId = u16; #[derive(Clone, Copy, PartialEq, Eq)] -pub struct StateId { - off: u32, +pub struct StateOffset { + pub off: u32, +} + +impl StateOffset { + pub const NONE: StateOffset = StateOffset { off: 0 }; + pub const DEAD: StateOffset = StateOffset { off: 1 }; + pub const START: StateOffset = StateOffset { off: 4 }; } -impl StateId { - const NONE: StateId = StateId { off: 0 }; - const DEAD: StateId = StateId { off: 1 }; - pub const START: StateId = StateId { off: 4 }; +#[repr(C)] +struct TokRxHeader { + magic: u32, + hd_size: u32, + state_bytes: u32, + token_bytes: u32, + info: TokRxInfo, + align: [u32; 0], } +#[repr(C)] #[derive(Clone)] -pub struct TokenCompiled { +pub struct TokRxInfo { + pub tok_eos: TokenId, +} + +fn clone_vec_as_bytes(input: &Vec) -> Vec { + unsafe { + let byte_slice = from_raw_parts(input.as_ptr() as *const u8, input.len() * size_of::()); + byte_slice.to_vec() + } +} + +fn clone_as_bytes(input: &T) -> Vec { + unsafe { + let byte_slice = from_raw_parts(input as *const T as *const u8, size_of::()); + byte_slice.to_vec() + } +} + +impl TokRxHeader { + pub const MAGIC: u32 = 0x6623f10b; + pub const SIZE: u32 = size_of::() as u32; +} + +#[derive(Clone)] +pub struct TokRx { + pub info: &'static TokRxInfo, pub token_data: &'static [u16], pub state_data: &'static [u32], } -impl TokenCompiled { +impl TokRx { + pub fn deserialize(bytes: 
&'static [u8]) -> TokRx { + unsafe { + assert!(bytes.len() > TokRxHeader::SIZE as usize); + let hd = (bytes.as_ptr() as *const TokRxHeader).as_ref().unwrap(); + assert!(hd.magic == TokRxHeader::MAGIC); + assert!(hd.hd_size == TokRxHeader::SIZE); + let state_data = from_raw_parts( + bytes.as_ptr().add(TokRxHeader::SIZE as usize) as *const u32, + hd.state_bytes as usize / size_of::(), + ); + let token_data = from_raw_parts( + bytes + .as_ptr() + .add((TokRxHeader::SIZE + hd.state_bytes) as usize) + as *const u16, + hd.token_bytes as usize / size_of::(), + ); + TokRx { + info: &hd.info, + state_data, + token_data, + } + } + } + + pub fn serialize( + info: &TokRxInfo, + token_data: &Vec, + state_data: &Vec, + ) -> Vec { + let mut token_bytes = clone_vec_as_bytes(&token_data); + let mut state_bytes = clone_vec_as_bytes(&state_data); + let hd = TokRxHeader { + magic: TokRxHeader::MAGIC, + hd_size: TokRxHeader::SIZE, + info: info.clone(), + state_bytes: state_bytes.len() as u32, + token_bytes: token_bytes.len() as u32, + align: [], + }; + let mut bytes = clone_as_bytes(&hd); + bytes.append(&mut state_bytes); + bytes.append(&mut token_bytes); + bytes + } + fn token_in_token_set(&self, token: TokenId, set: u32) -> bool { assert!(token != NO_TOKEN); let mut idx = set as usize; @@ -35,17 +119,17 @@ impl TokenCompiled { } } - fn state_bias(state: StateId) -> f32 { - if state == StateId::DEAD { + fn state_bias(state: StateOffset) -> f32 { + if state == StateOffset::DEAD { -100.0 } else { 0.0 } } - pub fn compute_logit_bias(&self, state_offset: StateId, bias: &mut [f32]) { + pub fn compute_logit_bias(&self, state_offset: StateOffset, bias: &mut [f32]) { let mut p = state_offset.off as usize; - let default_state = StateId { + let default_state = StateOffset { off: self.state_data[p], }; p += 1; @@ -56,10 +140,10 @@ impl TokenCompiled { } loop { - let state = StateId { + let state = StateOffset { off: self.state_data[p], }; - if state == StateId::NONE { + if state == StateOffset::NONE { break; } p += 1; @@ -79,17 +163,17 @@ impl TokenCompiled { } } - pub fn advance(&self, state_offset: StateId, token: TokenId) -> StateId { + pub fn advance(&self, state_offset: StateOffset, token: TokenId) -> StateOffset { let mut p = state_offset.off as usize; - let default_state = StateId { + let default_state = StateOffset { off: self.state_data[p], }; p += 1; loop { - let state = StateId { + let state = StateOffset { off: self.state_data[p], }; - if state == StateId::NONE { + if state == StateOffset::NONE { return default_state; } p += 1; diff --git a/gvm_abi/src/rxvm.rs b/gvm_abi/src/rxvm.rs index 228f251f..d65a894b 100644 --- a/gvm_abi/src/rxvm.rs +++ b/gvm_abi/src/rxvm.rs @@ -1,18 +1,18 @@ -use crate::rx::{StateId, TokenCompiled}; +use crate::rx::{StateOffset, TokRx}; use crate::{GuidanceVm, GuidanceVmHelper}; pub struct RxGvm { pub helper: GuidanceVmHelper, - pub compiled: TokenCompiled, - pub state: StateId, + pub compiled: TokRx, + pub state: StateOffset, } impl RxGvm { - pub fn from_token_compiled(compiled: TokenCompiled) -> Self { + pub fn from_token_compiled(compiled: TokRx) -> Self { RxGvm { helper: GuidanceVmHelper::new(), compiled, - state: StateId::START, + state: StateOffset::START, } } } @@ -43,13 +43,3 @@ impl GuidanceVm for RxGvm { } } } - -#[macro_export] -macro_rules! 
RxGvm_from_bin { - () => { - $crate::gvm_expose_all!( - RxGvm, - RxGvm::from_token_compiled($crate::TokenCompiled_from_bin!()) - ); - }; -} From 7edda880344d0321904a7427ec1aec0da35fc4d9 Mon Sep 17 00:00:00 2001 From: Michal Moskal Date: Wed, 27 Sep 2023 13:24:15 -0700 Subject: [PATCH 005/301] prefix coding --- gvm_abi/src/rx.rs | 105 ++++++++++++++++++++-------------------------- 1 file changed, 46 insertions(+), 59 deletions(-) diff --git a/gvm_abi/src/rx.rs b/gvm_abi/src/rx.rs index 3b8fd85d..c9488020 100644 --- a/gvm_abi/src/rx.rs +++ b/gvm_abi/src/rx.rs @@ -1,8 +1,17 @@ use std::{mem::size_of, slice::from_raw_parts}; -const NO_TOKEN: TokenId = 0; - pub type TokenId = u16; +pub type Transition = (StateOffset, TokenSetOffset); + +#[derive(Clone, Copy, PartialEq, Eq)] +pub struct TokenSetOffset { + pub off: u32, +} + +pub struct StateDesc { + default_transition: StateOffset, + transitions: &'static [Transition], +} #[derive(Clone, Copy, PartialEq, Eq)] pub struct StateOffset { @@ -10,9 +19,8 @@ pub struct StateOffset { } impl StateOffset { - pub const NONE: StateOffset = StateOffset { off: 0 }; pub const DEAD: StateOffset = StateOffset { off: 1 }; - pub const START: StateOffset = StateOffset { off: 4 }; + pub const START: StateOffset = StateOffset { off: 3 }; } #[repr(C)] @@ -104,18 +112,26 @@ impl TokRx { bytes } - fn token_in_token_set(&self, token: TokenId, set: u32) -> bool { - assert!(token != NO_TOKEN); - let mut idx = set as usize; - loop { - let v = self.token_data[idx]; - if v == token { - return true; - } - if v == NO_TOKEN { - return false; - } - idx = idx + 1; + fn token_set(&self, set: TokenSetOffset) -> &'static [TokenId] { + let idx = set.off as usize; + let sz = self.token_data[idx] as usize; + unsafe { from_raw_parts(self.token_data.as_ptr().add(idx + 1), sz) } + } + + fn state_desc(&self, state: StateOffset) -> StateDesc { + let idx = state.off as usize; + let default_transition = StateOffset { + off: self.state_data[idx], + }; + let sz = self.state_data[idx + 1] as usize; + StateDesc { + default_transition, + transitions: unsafe { + from_raw_parts( + self.state_data.as_ptr().add(idx + 2) as *const Transition, + sz, + ) + }, } } @@ -128,60 +144,31 @@ impl TokRx { } pub fn compute_logit_bias(&self, state_offset: StateOffset, bias: &mut [f32]) { - let mut p = state_offset.off as usize; - let default_state = StateOffset { - off: self.state_data[p], - }; - p += 1; + let state = self.state_desc(state_offset); - let init_val = Self::state_bias(default_state); + let init_val = Self::state_bias(state.default_transition); for idx in 0..bias.len() { bias[idx] = init_val; } - loop { - let state = StateOffset { - off: self.state_data[p], - }; - if state == StateOffset::NONE { - break; - } - p += 1; - let toks = self.state_data[p]; - p += 1; - let val = Self::state_bias(state); - - let mut idx = toks as usize; - loop { - let tok = self.token_data[idx]; - if tok == NO_TOKEN { - break; - } - bias[tok as usize] = val; - idx = idx + 1; + for (st, ts) in state.transitions { + let val = Self::state_bias(*st); + let toks = self.token_set(*ts); + for tok in toks { + bias[*tok as usize] = val; } } } pub fn advance(&self, state_offset: StateOffset, token: TokenId) -> StateOffset { - let mut p = state_offset.off as usize; - let default_state = StateOffset { - off: self.state_data[p], - }; - p += 1; - loop { - let state = StateOffset { - off: self.state_data[p], - }; - if state == StateOffset::NONE { - return default_state; - } - p += 1; - let toks = self.state_data[p]; - p += 1; - if 
self.token_in_token_set(token, toks) { - return state; + let state = self.state_desc(state_offset); + + for (st, ts) in state.transitions { + if self.token_set(*ts).contains(&token) { + return *st; } } + + state.default_transition } } From 71ebcb385fd7919c6c7b6797d5fb96e3ecca0ee9 Mon Sep 17 00:00:00 2001 From: Michal Moskal Date: Wed, 27 Sep 2023 15:29:35 -0700 Subject: [PATCH 006/301] printing, serialization validation --- gvm_abi/src/lib.rs | 19 +++++++++++++++ gvm_abi/src/printing.rs | 51 +++++++++++++++++++++++++++++++++++++++++ 2 files changed, 70 insertions(+) create mode 100644 gvm_abi/src/printing.rs diff --git a/gvm_abi/src/lib.rs b/gvm_abi/src/lib.rs index 34e34c21..97f78862 100644 --- a/gvm_abi/src/lib.rs +++ b/gvm_abi/src/lib.rs @@ -1,3 +1,4 @@ +pub mod printing; pub mod rx; pub mod rxvm; @@ -107,3 +108,21 @@ macro_rules! include_bytes_aligned { &ALIGNED.bytes }}; } + +#[macro_export] +macro_rules! println { + () => { + $crate::printing::_print("\n") + }; + ($($arg:tt)*) => {{ + $crate::printing::_print(&format!($($arg)*)); + $crate::printing::_print("\n"); + }}; +} + +#[macro_export] +macro_rules! print { + ($($arg:tt)*) => {{ + $crate::printing::_print(&format!($($arg)*)); + }}; +} diff --git a/gvm_abi/src/printing.rs b/gvm_abi/src/printing.rs new file mode 100644 index 00000000..3683cbf5 --- /dev/null +++ b/gvm_abi/src/printing.rs @@ -0,0 +1,51 @@ +use std::{io, panic}; + +extern "C" { + fn gvm_host_print(ptr: *const u8, len: u32); +} + +pub struct Printer {} + +impl io::Write for Printer { + fn write(&mut self, buf: &[u8]) -> io::Result { + unsafe { gvm_host_print(buf.as_ptr(), buf.len() as u32) }; + Ok(buf.len()) + } + + fn flush(&mut self) -> io::Result<()> { + Ok(()) + } +} + +pub fn init() { + panic::set_hook(Box::new(|info| { + let file = info.location().unwrap().file(); + let line = info.location().unwrap().line(); + let col = info.location().unwrap().column(); + + let msg = match info.payload().downcast_ref::<&'static str>() { + Some(s) => *s, + None => match info.payload().downcast_ref::() { + Some(s) => &s[..], + None => "Box", + }, + }; + + let err_info = format!("Panicked at '{}', {}:{}:{}\n", msg, file, line, col); + _print(&err_info); + })) +} + +pub fn stdout() -> Printer { + Printer {} +} + +pub fn _print(msg: &str) { + let vec: Vec = msg.into(); + unsafe { gvm_host_print(vec.as_ptr(), vec.len() as u32) }; +} + +#[no_mangle] +pub extern "C" fn gvm_init() { + init(); +} From 193a486ac90e0a122e23614335e7fe2e85bc3fca Mon Sep 17 00:00:00 2001 From: Michal Moskal Date: Thu, 28 Sep 2023 00:33:06 +0000 Subject: [PATCH 007/301] fixes --- gvm_abi/src/rxvm.rs | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/gvm_abi/src/rxvm.rs b/gvm_abi/src/rxvm.rs index d65a894b..368916fe 100644 --- a/gvm_abi/src/rxvm.rs +++ b/gvm_abi/src/rxvm.rs @@ -1,5 +1,5 @@ use crate::rx::{StateOffset, TokRx}; -use crate::{GuidanceVm, GuidanceVmHelper}; +use crate::{println, GuidanceVm, GuidanceVmHelper}; pub struct RxGvm { pub helper: GuidanceVmHelper, @@ -20,9 +20,13 @@ impl RxGvm { impl GuidanceVm for RxGvm { fn gvm_process_prompt(&mut self) { // the regex doesn't care about the prompt + self.state = StateOffset::START; + self.compiled + .compute_logit_bias(self.state, &mut self.helper.logit_biases); } fn gvm_append_token(&mut self, token: u32) { + println!("xapp {} {}", token, self.state.off); self.state = self.compiled.advance(self.state, token as u16); // save the token, just in case From 5c616d08f8c8cd29dabd10f90836a276b2bfba9a Mon Sep 17 00:00:00 2001 From: Michal 
Moskal
Date: Thu, 28 Sep 2023 17:56:40 +0000
Subject: [PATCH 008/301] bugfixes; setup logging

---
 gvm_abi/src/rxvm.rs | 9 ++++++---
 1 file changed, 6 insertions(+), 3 deletions(-)

diff --git a/gvm_abi/src/rxvm.rs b/gvm_abi/src/rxvm.rs
index 368916fe..5d993f48 100644
--- a/gvm_abi/src/rxvm.rs
+++ b/gvm_abi/src/rxvm.rs
@@ -19,6 +19,7 @@ impl RxGvm {
 
 impl GuidanceVm for RxGvm {
     fn gvm_process_prompt(&mut self) {
+        println!("prompt, {} tokens", self.helper.prompt_length);
         // the regex doesn't care about the prompt
         self.state = StateOffset::START;
         self.compiled
@@ -26,7 +27,7 @@ impl GuidanceVm for RxGvm {
     }
 
     fn gvm_append_token(&mut self, token: u32) {
-        println!("xapp {} {}", token, self.state.off);
+        // println!("xapp {:?} {} {}", self as *const _, token, self.state.off);
         self.state = self.compiled.advance(self.state, token as u16);
 
         // save the token, just in case
@@ -40,10 +41,12 @@ impl GuidanceVm for RxGvm {
 
     // implement by hand for now - we may need some special processing here
     fn gvm_clone(&mut self) -> Self {
-        RxGvm {
+        let r = RxGvm {
             helper: self.helper.clone(),
             compiled: self.compiled.clone(),
             state: self.state.clone(),
-        }
+        };
+        println!("{} -> {}", self.state.off, r.state.off);
+        r
     }
 }

From 3fe0c31ff526fb097ba6684f7481331c908ee37c Mon Sep 17 00:00:00 2001
From: Michal Moskal
Date: Thu, 28 Sep 2023 16:02:38 -0700
Subject: [PATCH 009/301] document WASM iface

---
 gvm_abi/src/gvm_iface.h | 53 +++++++++++++++++++++++++++++++++++++++++
 1 file changed, 53 insertions(+)
 create mode 100644 gvm_abi/src/gvm_iface.h

diff --git a/gvm_abi/src/gvm_iface.h b/gvm_abi/src/gvm_iface.h
new file mode 100644
index 00000000..4a51f597
--- /dev/null
+++ b/gvm_abi/src/gvm_iface.h
@@ -0,0 +1,53 @@
+//
+// This interface needs to be implemented by the WASM binary
+//
+
+// Tokens are assumed to be at most 32 bit.
+typedef uint32_t token_t;
+
+// Called first, after instantiating the WASM module.
+void gvm_init(void);
+
+// Called once per module, to get a GVM for a specific query.
+Gvm *gvm_create(void);
+
+// If a query is split into several (e.g., during beam-search, or when returning several results)
+// this is called to get a GVM for the sub-query.
+Gvm *gvm_clone(Gvm *parent);
+
+// These two are called after gvm_create() and gvm_clone() on the fresh GVM.
+// They should return the buffers that the WASM code has to allocate and keep around
+// until the relevant gvm_free().
+
+// Return the buffer where the prompt will be written. `size` is the number of tokens in the prompt.
+token_t *gvm_get_prompt_buffer(Gvm *gvm, uint32_t size);
+
+// Return the buffer where the WASM code will write logit biases after
+// gvm_process_prompt() and gvm_append_token().
+// `size` is the number of biases (which equals the size of the vocabulary).
+float *gvm_get_logit_bias_buffer(Gvm *gvm, uint32_t size);
+
+// This is called once, when the GVM should process the prompt in its buffer.
+// It should set the values in the logit bias buffer.
+void gvm_process_prompt(Gvm *gvm);
+
+// This is called after a token is sampled.
+// It should set the values in the logit bias buffer.
+void gvm_append_token(Gvm *gvm, token_t tok);
+
+// This is called for GVMs that are no longer needed (e.g. because generation completed,
+// or a beam-search branch was cut).
+void gvm_free(Gvm *gvm);
+
+//
+// This interface is available to the WASM binary
+//
+
+// Log a string.
+void gvm_host_print(const uint8_t *ptr, uint32_t size);
+
+// Provisional, not implemented yet:
+
+// Get the bytes corresponding to a given token. `size` is `sizeof(dst)`.
+// The length of the token is returned (even if it's bigger than `size`).
+uint32_t gvm_host_token_to_bytes(token_t token, uint8_t dst[], uint32_t size);

From 17828d03ebbb3e4639a75a3f266d8c582f61fee3 Mon Sep 17 00:00:00 2001
From: Michal Moskal
Date: Thu, 28 Sep 2023 16:05:34 -0700
Subject: [PATCH 010/301] more info

---
 gvm_abi/src/gvm_iface.h | 4 ++++
 1 file changed, 4 insertions(+)

diff --git a/gvm_abi/src/gvm_iface.h b/gvm_abi/src/gvm_iface.h
index 4a51f597..bcb706b5 100644
--- a/gvm_abi/src/gvm_iface.h
+++ b/gvm_abi/src/gvm_iface.h
@@ -30,10 +30,14 @@ float *gvm_get_logit_bias_buffer(Gvm *gvm, uint32_t size);
 // This is called once, when the GVM should process the prompt in its buffer.
 // It should set the values in the logit bias buffer.
 void gvm_process_prompt(Gvm *gvm);
+// The logical type (if WASM would allow such things) of this function is:
+// float[vocab_size] gvm_process_prompt(Gvm *gvm, token_t[] prompt);
 
 // This is called after a token is sampled.
 // It should set the values in the logit bias buffer.
 void gvm_append_token(Gvm *gvm, token_t tok);
+// The logical type (if WASM would allow such things) of this function is:
+// float[vocab_size] gvm_append_token(Gvm *gvm, token_t tok);
 
 // This is called for GVMs that are no longer needed (e.g. because generation completed,
 // or a beam-search branch was cut).

From 6ff778123fc845934f6300871b77cc7a73d2ea48 Mon Sep 17 00:00:00 2001
From: Michal Moskal
Date: Mon, 2 Oct 2023 16:02:01 -0700
Subject: [PATCH 011/301] switch tokens u16->u32

---
 gvm_abi/src/gvm_iface.h | 1 +
 gvm_abi/src/rx.rs       | 8 ++++----
 gvm_abi/src/rxvm.rs     | 2 +-
 3 files changed, 6 insertions(+), 5 deletions(-)

diff --git a/gvm_abi/src/gvm_iface.h b/gvm_abi/src/gvm_iface.h
index bcb706b5..2c028808 100644
--- a/gvm_abi/src/gvm_iface.h
+++ b/gvm_abi/src/gvm_iface.h
@@ -3,6 +3,7 @@
 //
 
 // Tokens are assumed to be at most 32 bit.
+// Typical models range from 30k (LLAMA) to 100k (GPT-4) tokens.
 typedef uint32_t token_t;
 
 // Called first, after instantiating the WASM module.
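// For reference, a sketch of the host-side call sequence implied by the
// interface documented above (illustrative only; `prompt`, `prompt_len`,
// `vocab_size`, `model_logits`, `sample_with_bias` and `done` are assumed
// host-side names, not part of this ABI):
//
//   gvm_init();
//   Gvm *g = gvm_create();
//   token_t *p = gvm_get_prompt_buffer(g, prompt_len);
//   memcpy(p, prompt, prompt_len * sizeof(token_t));
//   float *bias = gvm_get_logit_bias_buffer(g, vocab_size);
//   gvm_process_prompt(g);           // fills bias[0..vocab_size)
//   while (!done) {
//     token_t tok = sample_with_bias(model_logits, bias);
//     gvm_append_token(g, tok);      // fills bias for the next step
//   }
//   gvm_free(g);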
diff --git a/gvm_abi/src/rx.rs b/gvm_abi/src/rx.rs index c9488020..4164dcfd 100644 --- a/gvm_abi/src/rx.rs +++ b/gvm_abi/src/rx.rs @@ -1,6 +1,6 @@ use std::{mem::size_of, slice::from_raw_parts}; -pub type TokenId = u16; +pub type TokenId = u32; pub type Transition = (StateOffset, TokenSetOffset); #[derive(Clone, Copy, PartialEq, Eq)] @@ -61,7 +61,7 @@ impl TokRxHeader { #[derive(Clone)] pub struct TokRx { pub info: &'static TokRxInfo, - pub token_data: &'static [u16], + pub token_data: &'static [TokenId], pub state_data: &'static [u32], } @@ -80,8 +80,8 @@ impl TokRx { bytes .as_ptr() .add((TokRxHeader::SIZE + hd.state_bytes) as usize) - as *const u16, - hd.token_bytes as usize / size_of::(), + as *const TokenId, + hd.token_bytes as usize / size_of::(), ); TokRx { info: &hd.info, diff --git a/gvm_abi/src/rxvm.rs b/gvm_abi/src/rxvm.rs index 5d993f48..e3cf2468 100644 --- a/gvm_abi/src/rxvm.rs +++ b/gvm_abi/src/rxvm.rs @@ -28,7 +28,7 @@ impl GuidanceVm for RxGvm { fn gvm_append_token(&mut self, token: u32) { // println!("xapp {:?} {} {}", self as *const _, token, self.state.off); - self.state = self.compiled.advance(self.state, token as u16); + self.state = self.compiled.advance(self.state, token); // save the token, just in case let toks = &mut self.helper.tokens; From a1895745d17a657a4e22db64a08bee0c502a75cb Mon Sep 17 00:00:00 2001 From: Michal Moskal Date: Tue, 3 Oct 2023 08:39:55 -0700 Subject: [PATCH 012/301] move tokenizers to gvm_tokenizers --- gvm_abi/src/lib.rs | 1 + gvm_abi/src/toktree.rs | 89 ++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 90 insertions(+) create mode 100644 gvm_abi/src/toktree.rs diff --git a/gvm_abi/src/lib.rs b/gvm_abi/src/lib.rs index 97f78862..2f1ef2b8 100644 --- a/gvm_abi/src/lib.rs +++ b/gvm_abi/src/lib.rs @@ -1,6 +1,7 @@ pub mod printing; pub mod rx; pub mod rxvm; +pub mod toktree; /// Expose method as extern "C", usage: /// expose!(Foo::set_count(n: i32) -> i32); diff --git a/gvm_abi/src/toktree.rs b/gvm_abi/src/toktree.rs new file mode 100644 index 00000000..9ccaa6ee --- /dev/null +++ b/gvm_abi/src/toktree.rs @@ -0,0 +1,89 @@ +// use 8:24 encoding - num_ch:tok_id (ch_idx:ch_off)* - 8 bytes per token +// special case num_ch=0xff -> num_ch=0x100 + +use crate::rx::TokenId; + +pub struct TokNode { + pub byte: u8, + off: usize, + data: &'static [u32], +} + +impl TokNode { + const NO_TOKEN: u32 = 0xffffff; + + pub fn token_id(&self) -> Option { + let r = self.data[self.off] >> 8; + if r == Self::NO_TOKEN { + None + } else { + Some(r) + } + } + + pub fn num_children(&self) -> usize { + let num_ch = self.data[self.off] & 0xff; + if num_ch == 0xff { + 0x100 + } else { + num_ch as usize + } + } + + pub fn child_at_idx(&self, idx: usize) -> TokNode { + assert!(idx < self.num_children()); + let off = self.off + 1 + idx; + let ch_off = self.data[off] >> 8; + TokNode { + byte: (self.data[off] & 0xff) as u8, + off: ch_off as usize, + data: self.data, + } + } + + pub fn child_at_byte(&self, byte: u8) -> Option { + let num_ch = self.num_children(); + for idx in 0..num_ch { + let off = self.off + 1 + idx; + if (self.data[off] & 0xff) as u8 == byte { + return Some(self.child_at_idx(idx)); + } + } + None + } + + pub fn children(&self) -> TokNodeChildrenIter { + TokNodeChildrenIter { + parent: self, + idx: 0, + max_idx: self.num_children(), + } + } +} + +pub struct TokNodeChildrenIter<'a> { + parent: &'a TokNode, + idx: usize, + max_idx: usize, +} + +impl<'a> Iterator for TokNodeChildrenIter<'a> { + type Item = TokNode; + + fn next(&mut self) -> Option { + if 
self.idx < self.max_idx { + let child = self.parent.child_at_idx(self.idx); + self.idx += 1; + Some(child) + } else { + None + } + } +} + +#[repr(C)] +pub struct TokenizerBin { + magic: u32, + tokens_bytes: u32, + tree_bytes: u32, +} From a5ef95c307f1e52fa064b4b376023ff70e60d512 Mon Sep 17 00:00:00 2001 From: Michal Moskal Date: Tue, 3 Oct 2023 09:34:01 -0700 Subject: [PATCH 013/301] fighting lifetimes --- gvm_abi/src/toktree.rs | 65 ++++++++++++++++++++++++++++++------------ 1 file changed, 47 insertions(+), 18 deletions(-) diff --git a/gvm_abi/src/toktree.rs b/gvm_abi/src/toktree.rs index 9ccaa6ee..26da4d8e 100644 --- a/gvm_abi/src/toktree.rs +++ b/gvm_abi/src/toktree.rs @@ -3,17 +3,37 @@ use crate::rx::TokenId; -pub struct TokNode { +pub struct TokTrie { + data: Vec, +} + +pub struct TrieNode<'a> { + trie: &'a TokTrie, pub byte: u8, off: usize, - data: &'static [u32], + parent: Option<&'a TrieNode<'a>> } -impl TokNode { +impl TokTrie { + pub fn new() -> TokTrie { + TokTrie { data: Vec::new() } + } + + pub fn root<'a>(&'a self) -> TrieNode<'a> { + TrieNode { + trie: &self, + byte: 0, + off: 0, + parent: None, + } + } +} + +impl<'a> TrieNode<'a> { const NO_TOKEN: u32 = 0xffffff; pub fn token_id(&self) -> Option { - let r = self.data[self.off] >> 8; + let r = self.trie.data[self.off] >> 8; if r == Self::NO_TOKEN { None } else { @@ -22,7 +42,7 @@ impl TokNode { } pub fn num_children(&self) -> usize { - let num_ch = self.data[self.off] & 0xff; + let num_ch = self.trie.data[self.off] & 0xff; if num_ch == 0xff { 0x100 } else { @@ -30,30 +50,31 @@ impl TokNode { } } - pub fn child_at_idx(&self, idx: usize) -> TokNode { + pub fn child_at_idx(&'a self, idx: usize) -> TrieNode<'a> { assert!(idx < self.num_children()); let off = self.off + 1 + idx; - let ch_off = self.data[off] >> 8; - TokNode { - byte: (self.data[off] & 0xff) as u8, + let ch_off = self.trie.data[off] >> 8; + TrieNode { + trie: self.trie, + byte: (self.trie.data[off] & 0xff) as u8, off: ch_off as usize, - data: self.data, + parent: Some(self), } } - pub fn child_at_byte(&self, byte: u8) -> Option { + pub fn child_at_byte(&'a self, byte: u8) -> Option> { let num_ch = self.num_children(); for idx in 0..num_ch { let off = self.off + 1 + idx; - if (self.data[off] & 0xff) as u8 == byte { + if (self.trie.data[off] & 0xff) as u8 == byte { return Some(self.child_at_idx(idx)); } } None } - pub fn children(&self) -> TokNodeChildrenIter { - TokNodeChildrenIter { + pub fn children(&self) -> TrieNodeChildrenIter { + TrieNodeChildrenIter { parent: self, idx: 0, max_idx: self.num_children(), @@ -61,14 +82,14 @@ impl TokNode { } } -pub struct TokNodeChildrenIter<'a> { - parent: &'a TokNode, +pub struct TrieNodeChildrenIter<'a> { + parent: &'a TrieNode<'a>, idx: usize, max_idx: usize, } -impl<'a> Iterator for TokNodeChildrenIter<'a> { - type Item = TokNode; +impl<'a> Iterator for TrieNodeChildrenIter<'a> { + type Item = TrieNode<'a>; fn next(&mut self) -> Option { if self.idx < self.max_idx { @@ -87,3 +108,11 @@ pub struct TokenizerBin { tokens_bytes: u32, tree_bytes: u32, } + +pub fn iter(word: &[u8]) { + let trie = TokTrie::new(); + let mut n = trie.root(); + for &ch in word { + n = n.child_at_byte(ch).unwrap(); + } +} From 697d70de41e6649665aaf23c8e2700344dba0b06 Mon Sep 17 00:00:00 2001 From: Michal Moskal Date: Tue, 3 Oct 2023 10:36:58 -0700 Subject: [PATCH 014/301] building trie --- gvm_abi/src/toktree.rs | 169 ++++++++++++++++++++++++++++++----------- 1 file changed, 123 insertions(+), 46 deletions(-) diff --git a/gvm_abi/src/toktree.rs 
b/gvm_abi/src/toktree.rs index 26da4d8e..0a11001f 100644 --- a/gvm_abi/src/toktree.rs +++ b/gvm_abi/src/toktree.rs @@ -1,48 +1,59 @@ -// use 8:24 encoding - num_ch:tok_id (ch_idx:ch_off)* - 8 bytes per token +// use 8:24 encoding - num_ch:tok_id (ch_byte:ch_off)* - 8 bytes per tree node // special case num_ch=0xff -> num_ch=0x100 +use std::collections::HashMap; + use crate::rx::TokenId; pub struct TokTrie { data: Vec, } -pub struct TrieNode<'a> { - trie: &'a TokTrie, - pub byte: u8, - off: usize, - parent: Option<&'a TrieNode<'a>> +#[derive(Clone, Copy, Debug)] +pub struct TrieNode { + // ch_byte:ch_off + bits: u32, } +const NO_TOKEN: u32 = 0xffffff; + impl TokTrie { - pub fn new() -> TokTrie { - TokTrie { data: Vec::new() } + pub fn from(words: &Vec>) -> TokTrie { + let mut trie = TrieHash::new(); + for (idx, word) in words.iter().enumerate() { + if word.len() > 0 { + trie.insert(word, idx as u32) + } + } + let mut data = Vec::new(); + trie.serialize(&mut data); + TokTrie { data } } - pub fn root<'a>(&'a self) -> TrieNode<'a> { - TrieNode { - trie: &self, - byte: 0, - off: 0, - parent: None, - } + pub fn root(&self) -> TrieNode { + TrieNode { bits: 0 } } -} -impl<'a> TrieNode<'a> { - const NO_TOKEN: u32 = 0xffffff; + #[inline(always)] + fn at(&self, n: TrieNode, off: usize) -> u32 { + self.data[(n.bits >> 8) as usize + off] + } - pub fn token_id(&self) -> Option { - let r = self.trie.data[self.off] >> 8; - if r == Self::NO_TOKEN { + pub fn child_byte(&self, n: TrieNode) -> u8 { + (n.bits & 0xff) as u8 + } + + pub fn token_id(&self, n: TrieNode) -> Option { + let r = self.at(n, 0) >> 8; + if r == NO_TOKEN { None } else { Some(r) } } - pub fn num_children(&self) -> usize { - let num_ch = self.trie.data[self.off] & 0xff; + pub fn num_children(&self, n: TrieNode) -> usize { + let num_ch = self.at(n, 0) & 0xff; if num_ch == 0xff { 0x100 } else { @@ -50,50 +61,58 @@ impl<'a> TrieNode<'a> { } } - pub fn child_at_idx(&'a self, idx: usize) -> TrieNode<'a> { - assert!(idx < self.num_children()); - let off = self.off + 1 + idx; - let ch_off = self.trie.data[off] >> 8; + pub fn child_at_idx(&self, n: TrieNode, idx: usize) -> TrieNode { + assert!(idx < self.num_children(n)); TrieNode { - trie: self.trie, - byte: (self.trie.data[off] & 0xff) as u8, - off: ch_off as usize, - parent: Some(self), + bits: self.at(n, 1 + idx), } } - pub fn child_at_byte(&'a self, byte: u8) -> Option> { - let num_ch = self.num_children(); + pub fn child_at_byte(&self, n: TrieNode, byte: u8) -> Option { + let num_ch = self.num_children(n); for idx in 0..num_ch { - let off = self.off + 1 + idx; - if (self.trie.data[off] & 0xff) as u8 == byte { - return Some(self.child_at_idx(idx)); + // let byte2 = self.child_byte(self.child_at_idx(n, idx)); + let byte2 = (self.at(n, 1 + idx) & 0xff) as u8; + if byte2 == byte { + return Some(self.child_at_idx(n, idx)); } } None } - pub fn children(&self) -> TrieNodeChildrenIter { + pub fn child_at_bytes(&self, mut n: TrieNode, bytes: &[u8]) -> Option { + for &byte in bytes { + n = match self.child_at_byte(n, byte) { + Some(n) => n, + None => return None, + } + } + Some(n) + } + + pub fn children(&self, n: TrieNode) -> TrieNodeChildrenIter { TrieNodeChildrenIter { parent: self, + node: n, idx: 0, - max_idx: self.num_children(), + max_idx: self.num_children(n), } } } pub struct TrieNodeChildrenIter<'a> { - parent: &'a TrieNode<'a>, + parent: &'a TokTrie, + node: TrieNode, idx: usize, max_idx: usize, } impl<'a> Iterator for TrieNodeChildrenIter<'a> { - type Item = TrieNode<'a>; + type Item = 
TrieNode; fn next(&mut self) -> Option { if self.idx < self.max_idx { - let child = self.parent.child_at_idx(self.idx); + let child = self.parent.child_at_idx(self.node, self.idx); self.idx += 1; Some(child) } else { @@ -109,10 +128,68 @@ pub struct TokenizerBin { tree_bytes: u32, } -pub fn iter(word: &[u8]) { - let trie = TokTrie::new(); - let mut n = trie.root(); - for &ch in word { - n = n.child_at_byte(ch).unwrap(); +struct TrieHash { + token_id: u32, + children: HashMap, +} + +impl TrieHash { + fn new() -> TrieHash { + TrieHash { + token_id: NO_TOKEN, + children: HashMap::new(), + } + } + fn insert(&mut self, word: &[u8], token_id: u32) { + if word.len() == 0 { + assert!(self.token_id == NO_TOKEN); + self.token_id = token_id; + } else { + let ch = word[0]; + let child = self.children.entry(ch).or_insert_with(Self::new); + child.insert(&word[1..], token_id); + } + } + fn serialize(&self, data: &mut Vec) { + let mut child_ids = self.children.keys().collect::>(); + child_ids.sort(); + let mut len = child_ids.len(); + if len == 0x100 { + len = 0xff; + } else { + assert!(len < 0xf0); + } + let idx = data.len(); + data.push((self.token_id << 8) | len as u32); + data.resize(idx + 1 + child_ids.len(), 0); + for ch_idx in 0..child_ids.len() { + let ptr = data.len() as u32; + let ch_byte = child_ids[ch_idx]; + assert!((ptr << 8) >> 8 == ptr); + data[idx + 1 + ch_idx] = (ptr << 8) | (*ch_byte as u32); + self.children.get(ch_byte).unwrap().serialize(data); + } + } +} + +pub fn test_trie() { + let mut words0 = vec!["a", "b", "abc"]; + let words = words0 + .iter() + .map(|s| s.as_bytes().to_vec()) + .collect::>(); + let trie = TokTrie::from(&words); + let root = trie.root(); + words0.push("ab"); + words0.push("foo"); + for w in words0 { + match trie.child_at_bytes(root, &w.as_bytes().to_vec()) { + Some(n) => { + println!("{} -> {:?}", w, trie.token_id(n)); + } + None => { + println!("{} -> not found", w); + } + } } } From 4cfbfb45a85ddcfb08e1683226bade6893d4ce94 Mon Sep 17 00:00:00 2001 From: Michal Moskal Date: Tue, 3 Oct 2023 11:16:47 -0700 Subject: [PATCH 015/301] faster --- gvm_abi/src/toktree.rs | 93 +++++++++++++++++++++++++++++++++--------- 1 file changed, 73 insertions(+), 20 deletions(-) diff --git a/gvm_abi/src/toktree.rs b/gvm_abi/src/toktree.rs index 0a11001f..2f737781 100644 --- a/gvm_abi/src/toktree.rs +++ b/gvm_abi/src/toktree.rs @@ -128,16 +128,22 @@ pub struct TokenizerBin { tree_bytes: u32, } +enum TrieChildren { + None, + One { k: u8, v: Box }, + Many { hash: HashMap }, +} + struct TrieHash { token_id: u32, - children: HashMap, + children: TrieChildren, } impl TrieHash { fn new() -> TrieHash { TrieHash { token_id: NO_TOKEN, - children: HashMap::new(), + children: TrieChildren::None, } } fn insert(&mut self, word: &[u8], token_id: u32) { @@ -146,29 +152,76 @@ impl TrieHash { self.token_id = token_id; } else { let ch = word[0]; - let child = self.children.entry(ch).or_insert_with(Self::new); - child.insert(&word[1..], token_id); + let rest = &word[1..]; + let children = std::mem::replace(&mut self.children, TrieChildren::None); + match children { + TrieChildren::Many { mut hash } => { + let child = hash.entry(ch).or_insert_with(Self::new); + child.insert(rest, token_id); + self.children = TrieChildren::Many { hash }; + } + TrieChildren::One { k, mut v } => { + if k == ch { + v.insert(rest, token_id); + self.children = TrieChildren::One { k, v }; + } else { + let mut child = Self::new(); + child.insert(rest, token_id); + let mut hash = HashMap::new(); + hash.insert(k, *v); + 
hash.insert(ch, child); + self.children = TrieChildren::Many { hash }; + } + } + TrieChildren::None => { + let mut child = Self::new(); + child.insert(rest, token_id); + self.children = TrieChildren::One { + k: ch, + v: Box::new(child), + }; + } + } } } + fn serialize_val(&self, len: usize) -> u32 { + (self.token_id << 8) | len as u32 + } + fn serialize(&self, data: &mut Vec) { - let mut child_ids = self.children.keys().collect::>(); - child_ids.sort(); - let mut len = child_ids.len(); - if len == 0x100 { - len = 0xff; - } else { - assert!(len < 0xf0); - } - let idx = data.len(); - data.push((self.token_id << 8) | len as u32); - data.resize(idx + 1 + child_ids.len(), 0); - for ch_idx in 0..child_ids.len() { - let ptr = data.len() as u32; - let ch_byte = child_ids[ch_idx]; + fn serialize_ch(off: usize, ch: u8) -> u32 { + let ptr = off as u32; + assert!(ptr as usize == off); assert!((ptr << 8) >> 8 == ptr); - data[idx + 1 + ch_idx] = (ptr << 8) | (*ch_byte as u32); - self.children.get(ch_byte).unwrap().serialize(data); + (ptr << 8) | (ch as u32) } + let idx = data.len(); + match &self.children { + TrieChildren::None => { + data.push(self.serialize_val(0)); + } + TrieChildren::One { k, v } => { + data.push(self.serialize_val(1)); + data.push(serialize_ch(idx + 2, *k)); + v.serialize(data); + } + TrieChildren::Many { hash } => { + let mut child_ids = hash.keys().map(|v| *v).collect::>(); + child_ids.sort(); + let mut len = child_ids.len(); + if len == 0x100 { + len = 0xff; + } else { + assert!(len < 0xf0); + } + data.push(self.serialize_val(len)); + data.resize(idx + 1 + child_ids.len(), 0); + for (ch_idx, ch_byte) in child_ids.iter().enumerate() { + data[idx + 1 + ch_idx] = serialize_ch(data.len(), *ch_byte); + hash.get(ch_byte).unwrap().serialize(data); + } + } + }; } } From 95e1f3a574f5e95bd02e36620bbb76e40663b604 Mon Sep 17 00:00:00 2001 From: Michal Moskal Date: Tue, 3 Oct 2023 12:36:27 -0700 Subject: [PATCH 016/301] faster yet --- gvm_abi/src/toktree.rs | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/gvm_abi/src/toktree.rs b/gvm_abi/src/toktree.rs index 2f737781..c4971cd9 100644 --- a/gvm_abi/src/toktree.rs +++ b/gvm_abi/src/toktree.rs @@ -1,7 +1,7 @@ // use 8:24 encoding - num_ch:tok_id (ch_byte:ch_off)* - 8 bytes per tree node // special case num_ch=0xff -> num_ch=0x100 -use std::collections::HashMap; +use std::collections::BTreeMap; use crate::rx::TokenId; @@ -131,7 +131,7 @@ pub struct TokenizerBin { enum TrieChildren { None, One { k: u8, v: Box }, - Many { hash: HashMap }, + Many { hash: BTreeMap }, } struct TrieHash { @@ -167,7 +167,7 @@ impl TrieHash { } else { let mut child = Self::new(); child.insert(rest, token_id); - let mut hash = HashMap::new(); + let mut hash = BTreeMap::new(); hash.insert(k, *v); hash.insert(ch, child); self.children = TrieChildren::Many { hash }; @@ -206,8 +206,8 @@ impl TrieHash { v.serialize(data); } TrieChildren::Many { hash } => { - let mut child_ids = hash.keys().map(|v| *v).collect::>(); - child_ids.sort(); + let child_ids = hash.keys().map(|v| *v).collect::>(); + // child_ids.sort(); let mut len = child_ids.len(); if len == 0x100 { len = 0xff; From ddeac53f1bfbe0babf5d8bcc47680ff57c447eb0 Mon Sep 17 00:00:00 2001 From: Michal Moskal Date: Tue, 3 Oct 2023 12:47:14 -0700 Subject: [PATCH 017/301] simpler, faster --- gvm_abi/src/toktree.rs | 117 ++++++++++------------------------------- 1 file changed, 28 insertions(+), 89 deletions(-) diff --git a/gvm_abi/src/toktree.rs b/gvm_abi/src/toktree.rs index c4971cd9..cf3a1083 
100644 --- a/gvm_abi/src/toktree.rs +++ b/gvm_abi/src/toktree.rs @@ -1,8 +1,6 @@ // use 8:24 encoding - num_ch:tok_id (ch_byte:ch_off)* - 8 bytes per tree node // special case num_ch=0xff -> num_ch=0x100 -use std::collections::BTreeMap; - use crate::rx::TokenId; pub struct TokTrie { @@ -19,7 +17,7 @@ const NO_TOKEN: u32 = 0xffffff; impl TokTrie { pub fn from(words: &Vec>) -> TokTrie { - let mut trie = TrieHash::new(); + let mut trie = TrieHash::new(0xff); for (idx, word) in words.iter().enumerate() { if word.len() > 0 { trie.insert(word, idx as u32) @@ -128,22 +126,18 @@ pub struct TokenizerBin { tree_bytes: u32, } -enum TrieChildren { - None, - One { k: u8, v: Box }, - Many { hash: BTreeMap }, -} - struct TrieHash { token_id: u32, - children: TrieChildren, + byte: u8, + children: Vec, } impl TrieHash { - fn new() -> TrieHash { + fn new(byte: u8) -> TrieHash { TrieHash { token_id: NO_TOKEN, - children: TrieChildren::None, + byte, + children: Vec::new(), } } fn insert(&mut self, word: &[u8], token_id: u32) { @@ -151,44 +145,22 @@ impl TrieHash { assert!(self.token_id == NO_TOKEN); self.token_id = token_id; } else { - let ch = word[0]; - let rest = &word[1..]; - let children = std::mem::replace(&mut self.children, TrieChildren::None); - match children { - TrieChildren::Many { mut hash } => { - let child = hash.entry(ch).or_insert_with(Self::new); - child.insert(rest, token_id); - self.children = TrieChildren::Many { hash }; - } - TrieChildren::One { k, mut v } => { - if k == ch { - v.insert(rest, token_id); - self.children = TrieChildren::One { k, v }; - } else { - let mut child = Self::new(); - child.insert(rest, token_id); - let mut hash = BTreeMap::new(); - hash.insert(k, *v); - hash.insert(ch, child); - self.children = TrieChildren::Many { hash }; - } - } - TrieChildren::None => { - let mut child = Self::new(); - child.insert(rest, token_id); - self.children = TrieChildren::One { - k: ch, - v: Box::new(child), - }; + for idx in 0..self.children.len() { + if self.children[idx].byte == word[0] { + self.children[idx].insert(&word[1..], token_id); + return; } } + let mut ch = TrieHash::new(word[0]); + ch.insert(&word[1..], token_id); + self.children.push(ch); } } fn serialize_val(&self, len: usize) -> u32 { (self.token_id << 8) | len as u32 } - fn serialize(&self, data: &mut Vec) { + fn serialize(&mut self, data: &mut Vec) { fn serialize_ch(off: usize, ch: u8) -> u32 { let ptr = off as u32; assert!(ptr as usize == off); @@ -196,53 +168,20 @@ impl TrieHash { (ptr << 8) | (ch as u32) } let idx = data.len(); - match &self.children { - TrieChildren::None => { - data.push(self.serialize_val(0)); - } - TrieChildren::One { k, v } => { - data.push(self.serialize_val(1)); - data.push(serialize_ch(idx + 2, *k)); - v.serialize(data); - } - TrieChildren::Many { hash } => { - let child_ids = hash.keys().map(|v| *v).collect::>(); - // child_ids.sort(); - let mut len = child_ids.len(); - if len == 0x100 { - len = 0xff; - } else { - assert!(len < 0xf0); - } - data.push(self.serialize_val(len)); - data.resize(idx + 1 + child_ids.len(), 0); - for (ch_idx, ch_byte) in child_ids.iter().enumerate() { - data[idx + 1 + ch_idx] = serialize_ch(data.len(), *ch_byte); - hash.get(ch_byte).unwrap().serialize(data); - } - } - }; - } -} - -pub fn test_trie() { - let mut words0 = vec!["a", "b", "abc"]; - let words = words0 - .iter() - .map(|s| s.as_bytes().to_vec()) - .collect::>(); - let trie = TokTrie::from(&words); - let root = trie.root(); - words0.push("ab"); - words0.push("foo"); - for w in words0 { - match 
trie.child_at_bytes(root, &w.as_bytes().to_vec()) { - Some(n) => { - println!("{} -> {:?}", w, trie.token_id(n)); - } - None => { - println!("{} -> not found", w); - } + let mut len = self.children.len(); + if len == 0x100 { + len = 0xff; + } else { + assert!(len < 0xf0); + } + data.push(self.serialize_val(len)); + data.resize(idx + 1 + self.children.len(), 0); + self.children.sort_by_key(|e| e.byte); + let mut ch_idx = idx + 1; + for entry in &mut self.children { + data[ch_idx] = serialize_ch(data.len(), entry.byte); + ch_idx += 1; + entry.serialize(data); } } } From faf0a8097308a31d2003dda44a5b99ababca7c1a Mon Sep 17 00:00:00 2001 From: Michal Moskal Date: Tue, 3 Oct 2023 13:02:18 -0700 Subject: [PATCH 018/301] optimize lookup --- gvm_abi/src/toktree.rs | 28 ++++++++++++++++++++++++---- 1 file changed, 24 insertions(+), 4 deletions(-) diff --git a/gvm_abi/src/toktree.rs b/gvm_abi/src/toktree.rs index cf3a1083..674c94ad 100644 --- a/gvm_abi/src/toktree.rs +++ b/gvm_abi/src/toktree.rs @@ -4,7 +4,7 @@ use crate::rx::TokenId; pub struct TokTrie { - data: Vec, + pub data: Vec, } #[derive(Clone, Copy, Debug)] @@ -68,6 +68,9 @@ impl TokTrie { pub fn child_at_byte(&self, n: TrieNode, byte: u8) -> Option { let num_ch = self.num_children(n); + if num_ch == 0x100 { + return Some(self.child_at_idx(n, byte as usize)); + } for idx in 0..num_ch { // let byte2 = self.child_byte(self.child_at_idx(n, idx)); let byte2 = (self.at(n, 1 + idx) & 0xff) as u8; @@ -145,15 +148,32 @@ impl TrieHash { assert!(self.token_id == NO_TOKEN); self.token_id = token_id; } else { - for idx in 0..self.children.len() { - if self.children[idx].byte == word[0] { - self.children[idx].insert(&word[1..], token_id); + if self.children.len() == 0x100 { + // assert!(self.children[word[0] as usize].byte == word[0]); + self.children[word[0] as usize].insert(&word[1..], token_id); + return; + } + + for ch in &mut self.children { + if ch.byte == word[0] { + ch.insert(&word[1..], token_id); return; } } + let mut ch = TrieHash::new(word[0]); ch.insert(&word[1..], token_id); self.children.push(ch); + + // if it's getting dense, make it full + if self.children.len() > 50 { + let mut v2 = (0..=255).map(TrieHash::new).collect::>(); + for ch in self.children.drain(..) { + let idx = ch.byte as usize; + v2[idx] = ch; + } + self.children = v2; + } } } fn serialize_val(&self, len: usize) -> u32 { From ab5b96f85c5f3321bf91602f924155a5450c8aa6 Mon Sep 17 00:00:00 2001 From: Michal Moskal Date: Tue, 3 Oct 2023 14:57:33 -0700 Subject: [PATCH 019/301] add Recognizer iface --- gvm_abi/src/lib.rs | 1 + gvm_abi/src/recognizer.rs | 104 ++++++++++++++++++++++++++++++++++++++ gvm_abi/src/toktree.rs | 4 +- 3 files changed, 108 insertions(+), 1 deletion(-) create mode 100644 gvm_abi/src/recognizer.rs diff --git a/gvm_abi/src/lib.rs b/gvm_abi/src/lib.rs index 2f1ef2b8..45c625e1 100644 --- a/gvm_abi/src/lib.rs +++ b/gvm_abi/src/lib.rs @@ -2,6 +2,7 @@ pub mod printing; pub mod rx; pub mod rxvm; pub mod toktree; +pub mod recognizer; /// Expose method as extern "C", usage: /// expose!(Foo::set_count(n: i32) -> i32); diff --git a/gvm_abi/src/recognizer.rs b/gvm_abi/src/recognizer.rs new file mode 100644 index 00000000..8042ff30 --- /dev/null +++ b/gvm_abi/src/recognizer.rs @@ -0,0 +1,104 @@ +use crate::toktree::{TokTrie, TrieNode}; + +pub trait Recognizer { + fn append(&self, bytes: &[u8]) -> Self + where + Self: Sized, + { + let mut rec = self.append1(bytes[0]); + for b in &bytes[1..] 
{ + rec = rec.append1(*b); + } + rec + } + fn append1(&self, byte: u8) -> Self; + fn allowed(&self) -> Vec>; +} + +fn append_bias( + trie: &TokTrie, + rec: &impl Recognizer, + logits: &mut [f32], + mut n: Option, + v: &Vec, +) { + for b in v { + match n { + Some(n2) => { + n = trie.child_at_byte(n2, *b); + match n { + Some(n3) => { + if let Some(tok) = trie.token_id(n3) { + logits[tok as usize] = 0.0; + } + } + None => break, + } + } + None => break, + } + } + + if n.is_some() { + let rec = rec.append(v); + for v in rec.allowed() { + append_bias(trie, &rec, logits, n, &v); + } + } +} + +pub fn compute_bias(trie: &TokTrie, rec: &impl Recognizer, logits: &mut [f32]) { + logits.iter_mut().for_each(|x| *x = -100.0); + for v in rec.allowed() { + append_bias(trie, rec, logits, Some(trie.root()), &v); + } +} + +pub struct Uppercase { + len: usize, +} + +impl Uppercase { + pub fn new() -> Self { + Uppercase { len: 0 } + } +} + +impl Recognizer for Uppercase { + fn append1(&self, _byte: u8) -> Self { + Uppercase { len: self.len + 1 } + } + + fn allowed(&self) -> Vec> { + if self.len > 1 { + ('a'..'z').map(|c| vec![c as u8]).collect() + } else { + ('A'..'Z').map(|c| vec![c as u8]).collect() + } + } +} + +// pub struct PrefixEnum { +// prefix_ch: u8, +// depth: u32, +// allowed: Vec>, +// } + +// impl Recognizer for PrefixEnum { +// fn append1(&self, byte: u8) -> Self { +// let mut depth = self.depth; +// for b in bytes { +// if depth > 0 { +// depth += 1; +// } +// if depth == 0 && *b == self.prefix_ch { +// depth = 1 +// } +// } +// todo!() +// } + +// fn allowed(&self) -> Vec> { +// self.allowed.clone() +// } +// } diff --git a/gvm_abi/src/toktree.rs b/gvm_abi/src/toktree.rs index 674c94ad..8ba39e97 100644 --- a/gvm_abi/src/toktree.rs +++ b/gvm_abi/src/toktree.rs @@ -166,7 +166,9 @@ impl TrieHash { self.children.push(ch); // if it's getting dense, make it full - if self.children.len() > 50 { + // for cl100k threshold 60->15 nodes, 50->22, 40->45 30->94 + // for llama (32k) 50->5, 40->15 + if self.children.len() > 40 { let mut v2 = (0..=255).map(TrieHash::new).collect::>(); for ch in self.children.drain(..) { let idx = ch.byte as usize; From 9cc484d2d0ca847483ab4c78caab554d34862fe7 Mon Sep 17 00:00:00 2001 From: Michal Moskal Date: Tue, 3 Oct 2023 15:13:37 -0700 Subject: [PATCH 020/301] mask recognizer --- gvm_abi/src/recognizer.rs | 69 +++++++++++++++------------------------ 1 file changed, 27 insertions(+), 42 deletions(-) diff --git a/gvm_abi/src/recognizer.rs b/gvm_abi/src/recognizer.rs index 8042ff30..d0db3f50 100644 --- a/gvm_abi/src/recognizer.rs +++ b/gvm_abi/src/recognizer.rs @@ -1,57 +1,35 @@ use crate::toktree::{TokTrie, TrieNode}; pub trait Recognizer { - fn append(&self, bytes: &[u8]) -> Self - where - Self: Sized, - { - let mut rec = self.append1(bytes[0]); - for b in &bytes[1..] 
{ - rec = rec.append1(*b); - } - rec - } - fn append1(&self, byte: u8) -> Self; - fn allowed(&self) -> Vec>; + fn append(&self, byte: u8) -> Self; + fn allowed(&self, mask: &mut [u8]); } fn append_bias( trie: &TokTrie, rec: &impl Recognizer, logits: &mut [f32], - mut n: Option, - v: &Vec, + mask: &mut [u8], + n: TrieNode, ) { - for b in v { - match n { - Some(n2) => { - n = trie.child_at_byte(n2, *b); - match n { - Some(n3) => { - if let Some(tok) = trie.token_id(n3) { - logits[tok as usize] = 0.0; - } - } - None => break, + rec.allowed(mask); + for idx in 0..=255 { + if mask[idx] != 0 { + if let Some(ch) = trie.child_at_byte(n, idx as u8) { + if let Some(tok) = trie.token_id(ch) { + logits[tok as usize] = 0.0; } + append_bias(trie, &rec.append(idx as u8), logits, mask, ch) } - None => break, - } - } - - if n.is_some() { - let rec = rec.append(v); - for v in rec.allowed() { - append_bias(trie, &rec, logits, n, &v); } } } pub fn compute_bias(trie: &TokTrie, rec: &impl Recognizer, logits: &mut [f32]) { logits.iter_mut().for_each(|x| *x = -100.0); - for v in rec.allowed() { - append_bias(trie, rec, logits, Some(trie.root()), &v); - } + let mut mask = Vec::new(); + mask.resize(256, 0); + append_bias(trie, rec, logits, &mut mask, trie.root()); } pub struct Uppercase { @@ -61,19 +39,26 @@ pub struct Uppercase { impl Uppercase { pub fn new() -> Self { Uppercase { len: 0 } - } + } } impl Recognizer for Uppercase { - fn append1(&self, _byte: u8) -> Self { + fn append(&self, _byte: u8) -> Self { Uppercase { len: self.len + 1 } } - fn allowed(&self) -> Vec> { - if self.len > 1 { - ('a'..'z').map(|c| vec![c as u8]).collect() + fn allowed(&self, mask: &mut [u8]) { + for idx in 0..255 { + mask[idx] = 0; + } + if self.len < 2 { + for ch in 'A'..'Z' { + mask[ch as usize] = 1; + } } else { - ('A'..'Z').map(|c| vec![c as u8]).collect() + for ch in 'a'..'z' { + mask[ch as usize] = 1; + } } } } From ebc3934078cd2dc111cd6f30fed6c95cdf62302a Mon Sep 17 00:00:00 2001 From: Michal Moskal Date: Tue, 3 Oct 2023 15:19:23 -0700 Subject: [PATCH 021/301] faster, more correct --- gvm_abi/src/recognizer.rs | 18 ++++++++++-------- 1 file changed, 10 insertions(+), 8 deletions(-) diff --git a/gvm_abi/src/recognizer.rs b/gvm_abi/src/recognizer.rs index d0db3f50..0ae8e945 100644 --- a/gvm_abi/src/recognizer.rs +++ b/gvm_abi/src/recognizer.rs @@ -13,15 +13,17 @@ fn append_bias( n: TrieNode, ) { rec.allowed(mask); - for idx in 0..=255 { - if mask[idx] != 0 { - if let Some(ch) = trie.child_at_byte(n, idx as u8) { - if let Some(tok) = trie.token_id(ch) { - logits[tok as usize] = 0.0; - } - append_bias(trie, &rec.append(idx as u8), logits, mask, ch) - } + + let sel = trie + .children(n) + .filter(|c| mask[trie.child_byte(*c) as usize] != 0) + .collect::>(); + + for ch in sel { + if let Some(tok) = trie.token_id(ch) { + logits[tok as usize] = 0.0; } + append_bias(trie, &rec.append(trie.child_byte(n)), logits, mask, ch) } } From 05a8fc0b419e5c9338cbf8cfa25ce1dc7e693752 Mon Sep 17 00:00:00 2001 From: Michal Moskal Date: Tue, 3 Oct 2023 16:18:49 -0700 Subject: [PATCH 022/301] AllowedResult --- gvm_abi/src/recognizer.rs | 63 +++++++++++++++++++++++++-------------- 1 file changed, 41 insertions(+), 22 deletions(-) diff --git a/gvm_abi/src/recognizer.rs b/gvm_abi/src/recognizer.rs index 0ae8e945..4ee7cd35 100644 --- a/gvm_abi/src/recognizer.rs +++ b/gvm_abi/src/recognizer.rs @@ -2,28 +2,31 @@ use crate::toktree::{TokTrie, TrieNode}; pub trait Recognizer { fn append(&self, byte: u8) -> Self; - fn allowed(&self, mask: &mut [u8]); + fn 
allowed<'a>(&self, mask: &'a mut [u8]) -> AllowedResult<'a>; } fn append_bias( trie: &TokTrie, rec: &impl Recognizer, logits: &mut [f32], - mask: &mut [u8], + maskbuf: &mut [u8], n: TrieNode, ) { - rec.allowed(mask); - - let sel = trie - .children(n) - .filter(|c| mask[trie.child_byte(*c) as usize] != 0) - .collect::>(); + let sel = trie.children(n); + let sel: Vec = match rec.allowed(maskbuf) { + AllowedResult::All => sel.collect(), + AllowedResult::None => return, + AllowedResult::Mask(mask) => sel + .filter(|c| mask[trie.child_byte(*c) as usize] != 0) + .collect(), + AllowedResult::List(lst) => sel.filter(|c| lst.contains(&trie.child_byte(*c))).collect(), + }; for ch in sel { if let Some(tok) = trie.token_id(ch) { logits[tok as usize] = 0.0; } - append_bias(trie, &rec.append(trie.child_byte(n)), logits, mask, ch) + append_bias(trie, &rec.append(trie.child_byte(ch)), logits, maskbuf, ch) } } @@ -44,24 +47,40 @@ impl Uppercase { } } +pub enum AllowedResult<'a> { + All, + None, + Mask(&'a [u8]), + List(&'a [u8]), +} + impl Recognizer for Uppercase { fn append(&self, _byte: u8) -> Self { Uppercase { len: self.len + 1 } } - fn allowed(&self, mask: &mut [u8]) { - for idx in 0..255 { - mask[idx] = 0; - } - if self.len < 2 { - for ch in 'A'..'Z' { - mask[ch as usize] = 1; - } - } else { - for ch in 'a'..'z' { - mask[ch as usize] = 1; - } - } + fn allowed<'a>(&self, _mask: &'a mut [u8]) -> AllowedResult<'a> { + AllowedResult::All + + // if self.len < 2 { + // AllowedResult::List(b"ABCDEFGHIJKLMNOPQRSTUVWXYZ") + // } else { + // AllowedResult::List(b"abcdefghijklmnopqrstuvwxyz") + // } + + // let mut idx = 0; + // if self.len < 2 { + // for ch in 'A'..'Z' { + // mask[idx] = ch as u8; + // idx += 1; + // } + // } else { + // for ch in 'a'..'z' { + // mask[idx] = ch as u8; + // idx += 1; + // } + // } + // AllowedResult::List(&mask[0..idx]) } } From 8d7dfb62afb6f63cbebea3c2a17847f1cd08db76 Mon Sep 17 00:00:00 2001 From: Michal Moskal Date: Tue, 3 Oct 2023 16:57:41 -0700 Subject: [PATCH 023/301] 31ms/tok --- gvm_abi/src/recognizer.rs | 45 ++++++--------------------------------- gvm_abi/src/toktree.rs | 42 +++++++++++++++++++++++++++++++++++- 2 files changed, 48 insertions(+), 39 deletions(-) diff --git a/gvm_abi/src/recognizer.rs b/gvm_abi/src/recognizer.rs index 4ee7cd35..e74931ce 100644 --- a/gvm_abi/src/recognizer.rs +++ b/gvm_abi/src/recognizer.rs @@ -2,7 +2,7 @@ use crate::toktree::{TokTrie, TrieNode}; pub trait Recognizer { fn append(&self, byte: u8) -> Self; - fn allowed<'a>(&self, mask: &'a mut [u8]) -> AllowedResult<'a>; + fn allowed(&self, c: u8) -> bool; } fn append_bias( @@ -12,17 +12,7 @@ fn append_bias( maskbuf: &mut [u8], n: TrieNode, ) { - let sel = trie.children(n); - let sel: Vec = match rec.allowed(maskbuf) { - AllowedResult::All => sel.collect(), - AllowedResult::None => return, - AllowedResult::Mask(mask) => sel - .filter(|c| mask[trie.child_byte(*c) as usize] != 0) - .collect(), - AllowedResult::List(lst) => sel.filter(|c| lst.contains(&trie.child_byte(*c))).collect(), - }; - - for ch in sel { + for ch in trie.masked_children(n, rec) { if let Some(tok) = trie.token_id(ch) { logits[tok as usize] = 0.0; } @@ -47,40 +37,19 @@ impl Uppercase { } } -pub enum AllowedResult<'a> { - All, - None, - Mask(&'a [u8]), - List(&'a [u8]), -} - impl Recognizer for Uppercase { fn append(&self, _byte: u8) -> Self { Uppercase { len: self.len + 1 } } - fn allowed<'a>(&self, _mask: &'a mut [u8]) -> AllowedResult<'a> { - AllowedResult::All - - // if self.len < 2 { - // 
AllowedResult::List(b"ABCDEFGHIJKLMNOPQRSTUVWXYZ") - // } else { - // AllowedResult::List(b"abcdefghijklmnopqrstuvwxyz") - // } - - // let mut idx = 0; + fn allowed(&self, byte: u8) -> bool { + byte != 0xff + // let ch = _byte as char; // if self.len < 2 { - // for ch in 'A'..'Z' { - // mask[idx] = ch as u8; - // idx += 1; - // } + // 'A' <= ch && ch <= 'Z' // } else { - // for ch in 'a'..'z' { - // mask[idx] = ch as u8; - // idx += 1; - // } + // 'a' <= ch && ch <= 'z' // } - // AllowedResult::List(&mask[0..idx]) } } diff --git a/gvm_abi/src/toktree.rs b/gvm_abi/src/toktree.rs index 8ba39e97..7b30d6f8 100644 --- a/gvm_abi/src/toktree.rs +++ b/gvm_abi/src/toktree.rs @@ -1,7 +1,7 @@ // use 8:24 encoding - num_ch:tok_id (ch_byte:ch_off)* - 8 bytes per tree node // special case num_ch=0xff -> num_ch=0x100 -use crate::rx::TokenId; +use crate::{recognizer::Recognizer, rx::TokenId}; pub struct TokTrie { pub data: Vec, @@ -99,6 +99,46 @@ impl TokTrie { max_idx: self.num_children(n), } } + + pub fn masked_children<'a, T: Recognizer>( + &'a self, + n: TrieNode, + rec: &'a T, + ) -> MaskedChildrenIterator<'a, T> { + let len = self.data.len(); + let index = (n.bits >> 8) as usize + 1; + let max_index = index + self.num_children(n); + assert!(max_index <= len); + MaskedChildrenIterator { + recognizer: rec, + ptr: self.data.as_ptr(), + index, + max_index, + } + } +} + +pub struct MaskedChildrenIterator<'a, T: Recognizer> { + recognizer: &'a T, + ptr: *const u32, + index: usize, + max_index: usize, +} + +impl<'a, T: Recognizer> Iterator for MaskedChildrenIterator<'a, T> { + type Item = TrieNode; + + fn next(&mut self) -> Option { + while self.index < self.max_index { + let bits = unsafe { *self.ptr.add(self.index) }; + self.index += 1; + let byte = (bits & 0xff) as u8; + if self.recognizer.allowed(byte) { + return Some(TrieNode { bits }); + } + } + None + } } pub struct TrieNodeChildrenIter<'a> { From 8978a27ce280f8106182b31777bbe7f462b64285 Mon Sep 17 00:00:00 2001 From: Michal Moskal Date: Tue, 3 Oct 2023 17:56:16 -0700 Subject: [PATCH 024/301] inline --- gvm_abi/src/recognizer.rs | 3 ++- gvm_abi/src/toktree.rs | 4 ++++ 2 files changed, 6 insertions(+), 1 deletion(-) diff --git a/gvm_abi/src/recognizer.rs b/gvm_abi/src/recognizer.rs index e74931ce..020a2390 100644 --- a/gvm_abi/src/recognizer.rs +++ b/gvm_abi/src/recognizer.rs @@ -2,7 +2,7 @@ use crate::toktree::{TokTrie, TrieNode}; pub trait Recognizer { fn append(&self, byte: u8) -> Self; - fn allowed(&self, c: u8) -> bool; + fn allowed(&self, byte: u8) -> bool; } fn append_bias( @@ -20,6 +20,7 @@ fn append_bias( } } +#[inline(never)] pub fn compute_bias(trie: &TokTrie, rec: &impl Recognizer, logits: &mut [f32]) { logits.iter_mut().for_each(|x| *x = -100.0); let mut mask = Vec::new(); diff --git a/gvm_abi/src/toktree.rs b/gvm_abi/src/toktree.rs index 7b30d6f8..4d41ddae 100644 --- a/gvm_abi/src/toktree.rs +++ b/gvm_abi/src/toktree.rs @@ -37,10 +37,12 @@ impl TokTrie { self.data[(n.bits >> 8) as usize + off] } + #[inline(always)] pub fn child_byte(&self, n: TrieNode) -> u8 { (n.bits & 0xff) as u8 } + #[inline(always)] pub fn token_id(&self, n: TrieNode) -> Option { let r = self.at(n, 0) >> 8; if r == NO_TOKEN { @@ -50,6 +52,7 @@ impl TokTrie { } } + #[inline(always)] pub fn num_children(&self, n: TrieNode) -> usize { let num_ch = self.at(n, 0) & 0xff; if num_ch == 0xff { @@ -100,6 +103,7 @@ impl TokTrie { } } + #[inline(always)] pub fn masked_children<'a, T: Recognizer>( &'a self, n: TrieNode, From ce36e386417caeceb8b8e56c36a30f890aa440f9 Mon Sep 17 
00:00:00 2001 From: Michal Moskal Date: Tue, 3 Oct 2023 21:54:40 -0700 Subject: [PATCH 025/301] simplify and speed up encoding --- gvm_abi/src/recognizer.rs | 24 +---- gvm_abi/src/toktree.rs | 198 ++++++++++++-------------------------- 2 files changed, 69 insertions(+), 153 deletions(-) diff --git a/gvm_abi/src/recognizer.rs b/gvm_abi/src/recognizer.rs index 020a2390..add29e4c 100644 --- a/gvm_abi/src/recognizer.rs +++ b/gvm_abi/src/recognizer.rs @@ -1,31 +1,15 @@ -use crate::toktree::{TokTrie, TrieNode}; +use crate::toktree::{append_bias, TokTrie}; pub trait Recognizer { fn append(&self, byte: u8) -> Self; fn allowed(&self, byte: u8) -> bool; } -fn append_bias( - trie: &TokTrie, - rec: &impl Recognizer, - logits: &mut [f32], - maskbuf: &mut [u8], - n: TrieNode, -) { - for ch in trie.masked_children(n, rec) { - if let Some(tok) = trie.token_id(ch) { - logits[tok as usize] = 0.0; - } - append_bias(trie, &rec.append(trie.child_byte(ch)), logits, maskbuf, ch) - } -} - #[inline(never)] pub fn compute_bias(trie: &TokTrie, rec: &impl Recognizer, logits: &mut [f32]) { logits.iter_mut().for_each(|x| *x = -100.0); - let mut mask = Vec::new(); - mask.resize(256, 0); - append_bias(trie, rec, logits, &mut mask, trie.root()); + let n = trie.root(); + append_bias(rec, logits, n); } pub struct Uppercase { @@ -39,10 +23,12 @@ impl Uppercase { } impl Recognizer for Uppercase { + #[inline(always)] fn append(&self, _byte: u8) -> Self { Uppercase { len: self.len + 1 } } + #[inline(always)] fn allowed(&self, byte: u8) -> bool { byte != 0xff // let ch = _byte as char; diff --git a/gvm_abi/src/toktree.rs b/gvm_abi/src/toktree.rs index 4d41ddae..490e2680 100644 --- a/gvm_abi/src/toktree.rs +++ b/gvm_abi/src/toktree.rs @@ -1,90 +1,91 @@ // use 8:24 encoding - num_ch:tok_id (ch_byte:ch_off)* - 8 bytes per tree node // special case num_ch=0xff -> num_ch=0x100 -use crate::{recognizer::Recognizer, rx::TokenId}; +use crate::recognizer::Recognizer; pub struct TokTrie { - pub data: Vec, + pub data: Vec, } -#[derive(Clone, Copy, Debug)] pub struct TrieNode { - // ch_byte:ch_off + // byte:token bits: u32, + subtree_size: u32, } const NO_TOKEN: u32 = 0xffffff; -impl TokTrie { - pub fn from(words: &Vec>) -> TokTrie { - let mut trie = TrieHash::new(0xff); - for (idx, word) in words.iter().enumerate() { - if word.len() > 0 { - trie.insert(word, idx as u32) - } +impl TrieNode { + fn new(byte: u8, token_id: u32) -> TrieNode { + TrieNode { + bits: (token_id << 8) | byte as u32, + subtree_size: 0, } - let mut data = Vec::new(); - trie.serialize(&mut data); - TokTrie { data } } - pub fn root(&self) -> TrieNode { - TrieNode { bits: 0 } + #[inline(always)] + unsafe fn next(&self) -> *const TrieNode { + self.ptr().add(self.subtree_size as usize) + } + + #[inline(always)] + unsafe fn ptr(&self) -> *const TrieNode { + self as *const TrieNode } #[inline(always)] - fn at(&self, n: TrieNode, off: usize) -> u32 { - self.data[(n.bits >> 8) as usize + off] + unsafe fn child0(&self) -> *const TrieNode { + self.ptr().add(1) } #[inline(always)] - pub fn child_byte(&self, n: TrieNode) -> u8 { - (n.bits & 0xff) as u8 + pub fn byte(&self) -> u8 { + (self.bits & 0xff) as u8 } #[inline(always)] - pub fn token_id(&self, n: TrieNode) -> Option { - let r = self.at(n, 0) >> 8; + pub fn token_id(&self) -> Option { + let r = self.bits >> 8; if r == NO_TOKEN { None } else { Some(r) } } +} - #[inline(always)] - pub fn num_children(&self, n: TrieNode) -> usize { - let num_ch = self.at(n, 0) & 0xff; - if num_ch == 0xff { - 0x100 - } else { - num_ch as usize 
+impl TokTrie { + pub fn from(words: &Vec>) -> TokTrie { + let mut trie = TrieHash::new(0xff); + for (idx, word) in words.iter().enumerate() { + if word.len() > 0 { + trie.insert(word, idx as u32) + } } + let mut data = Vec::new(); + trie.serialize(&mut data); + TokTrie { data } } - pub fn child_at_idx(&self, n: TrieNode, idx: usize) -> TrieNode { - assert!(idx < self.num_children(n)); - TrieNode { - bits: self.at(n, 1 + idx), - } + pub fn root(&self) -> &TrieNode { + &self.data[0] } - pub fn child_at_byte(&self, n: TrieNode, byte: u8) -> Option { - let num_ch = self.num_children(n); - if num_ch == 0x100 { - return Some(self.child_at_idx(n, byte as usize)); - } - for idx in 0..num_ch { - // let byte2 = self.child_byte(self.child_at_idx(n, idx)); - let byte2 = (self.at(n, 1 + idx) & 0xff) as u8; - if byte2 == byte { - return Some(self.child_at_idx(n, idx)); + pub fn child_at_byte(&self, n: &TrieNode, byte: u8) -> Option<&TrieNode> { + unsafe { + let mut p = n.child0(); + let endp = n.next(); + while p < endp { + if (*p).byte() == byte { + return Some(&*p); + } + p = (*p).next(); } } None } - pub fn child_at_bytes(&self, mut n: TrieNode, bytes: &[u8]) -> Option { + pub fn child_at_bytes<'a>(&'a self, mut n: &'a TrieNode, bytes: &[u8]) -> Option<&'a TrieNode> { for &byte in bytes { n = match self.child_at_byte(n, byte) { Some(n) => n, @@ -93,76 +94,23 @@ impl TokTrie { } Some(n) } - - pub fn children(&self, n: TrieNode) -> TrieNodeChildrenIter { - TrieNodeChildrenIter { - parent: self, - node: n, - idx: 0, - max_idx: self.num_children(n), - } - } - - #[inline(always)] - pub fn masked_children<'a, T: Recognizer>( - &'a self, - n: TrieNode, - rec: &'a T, - ) -> MaskedChildrenIterator<'a, T> { - let len = self.data.len(); - let index = (n.bits >> 8) as usize + 1; - let max_index = index + self.num_children(n); - assert!(max_index <= len); - MaskedChildrenIterator { - recognizer: rec, - ptr: self.data.as_ptr(), - index, - max_index, - } - } -} - -pub struct MaskedChildrenIterator<'a, T: Recognizer> { - recognizer: &'a T, - ptr: *const u32, - index: usize, - max_index: usize, } -impl<'a, T: Recognizer> Iterator for MaskedChildrenIterator<'a, T> { - type Item = TrieNode; - - fn next(&mut self) -> Option { - while self.index < self.max_index { - let bits = unsafe { *self.ptr.add(self.index) }; - self.index += 1; - let byte = (bits & 0xff) as u8; - if self.recognizer.allowed(byte) { - return Some(TrieNode { bits }); +pub fn append_bias(rec: &impl Recognizer, logits: &mut [f32], n: &TrieNode) { + unsafe { + let mut p = n.child0(); + let endp = n.next(); + while p < endp { + let n = &*p; + p = n.next(); + let b = n.byte(); + if rec.allowed(b) { + if let Some(tok) = n.token_id() { + logits[tok as usize] = 0.0; + } + append_bias(&rec.append(b), logits, n); } } - None - } -} - -pub struct TrieNodeChildrenIter<'a> { - parent: &'a TokTrie, - node: TrieNode, - idx: usize, - max_idx: usize, -} - -impl<'a> Iterator for TrieNodeChildrenIter<'a> { - type Item = TrieNode; - - fn next(&mut self) -> Option { - if self.idx < self.max_idx { - let child = self.parent.child_at_idx(self.node, self.idx); - self.idx += 1; - Some(child) - } else { - None - } } } @@ -212,7 +160,8 @@ impl TrieHash { // if it's getting dense, make it full // for cl100k threshold 60->15 nodes, 50->22, 40->45 30->94 // for llama (32k) 50->5, 40->15 - if self.children.len() > 40 { + // TODO remove this? + if self.children.len() > 250 { let mut v2 = (0..=255).map(TrieHash::new).collect::>(); for ch in self.children.drain(..) 
{ let idx = ch.byte as usize; @@ -222,32 +171,13 @@ impl TrieHash { } } } - fn serialize_val(&self, len: usize) -> u32 { - (self.token_id << 8) | len as u32 - } - - fn serialize(&mut self, data: &mut Vec) { - fn serialize_ch(off: usize, ch: u8) -> u32 { - let ptr = off as u32; - assert!(ptr as usize == off); - assert!((ptr << 8) >> 8 == ptr); - (ptr << 8) | (ch as u32) - } + fn serialize(&mut self, data: &mut Vec) { let idx = data.len(); - let mut len = self.children.len(); - if len == 0x100 { - len = 0xff; - } else { - assert!(len < 0xf0); - } - data.push(self.serialize_val(len)); - data.resize(idx + 1 + self.children.len(), 0); + data.push(TrieNode::new(self.byte, self.token_id)); self.children.sort_by_key(|e| e.byte); - let mut ch_idx = idx + 1; for entry in &mut self.children { - data[ch_idx] = serialize_ch(data.len(), entry.byte); - ch_idx += 1; entry.serialize(data); } + data[idx].subtree_size = (data.len() - idx) as u32; } } From d0e611c0e5ec84d69f5b9337c85cda4a46e32971 Mon Sep 17 00:00:00 2001 From: Michal Moskal Date: Tue, 3 Oct 2023 22:05:23 -0700 Subject: [PATCH 026/301] 19ms --- gvm_abi/src/recognizer.rs | 3 +-- gvm_abi/src/toktree.rs | 13 ++++++++++--- 2 files changed, 11 insertions(+), 5 deletions(-) diff --git a/gvm_abi/src/recognizer.rs b/gvm_abi/src/recognizer.rs index add29e4c..94424ce2 100644 --- a/gvm_abi/src/recognizer.rs +++ b/gvm_abi/src/recognizer.rs @@ -8,8 +8,7 @@ pub trait Recognizer { #[inline(never)] pub fn compute_bias(trie: &TokTrie, rec: &impl Recognizer, logits: &mut [f32]) { logits.iter_mut().for_each(|x| *x = -100.0); - let n = trie.root(); - append_bias(rec, logits, n); + append_bias(trie, rec, logits); } pub struct Uppercase { diff --git a/gvm_abi/src/toktree.rs b/gvm_abi/src/toktree.rs index 490e2680..539d129d 100644 --- a/gvm_abi/src/toktree.rs +++ b/gvm_abi/src/toktree.rs @@ -96,10 +96,15 @@ impl TokTrie { } } -pub fn append_bias(rec: &impl Recognizer, logits: &mut [f32], n: &TrieNode) { +pub fn append_bias(trie: &TokTrie, rec: &impl Recognizer, logits: &mut [f32]) { + let n = trie.root(); + append_bias_core(rec, logits, n); +} + +fn append_bias_core(rec: &impl Recognizer, logits: &mut [f32], n: &TrieNode) { unsafe { - let mut p = n.child0(); let endp = n.next(); + let mut p = n.child0(); while p < endp { let n = &*p; p = n.next(); @@ -108,7 +113,9 @@ pub fn append_bias(rec: &impl Recognizer, logits: &mut [f32], n: &TrieNode) { if let Some(tok) = n.token_id() { logits[tok as usize] = 0.0; } - append_bias(&rec.append(b), logits, n); + if n.subtree_size > 1 { + append_bias_core(&rec.append(b), logits, n); + } } } } From c2323e722b13ee36c7ca1d652fdc68434cb1cd54 Mon Sep 17 00:00:00 2001 From: Michal Moskal Date: Wed, 4 Oct 2023 07:53:00 -0700 Subject: [PATCH 027/301] experiments --- gvm_abi/src/recognizer.rs | 4 ++-- gvm_abi/src/toktree.rs | 30 ++++++++++++++++++++++++++++++ 2 files changed, 32 insertions(+), 2 deletions(-) diff --git a/gvm_abi/src/recognizer.rs b/gvm_abi/src/recognizer.rs index 94424ce2..d9c5bb20 100644 --- a/gvm_abi/src/recognizer.rs +++ b/gvm_abi/src/recognizer.rs @@ -22,12 +22,12 @@ impl Uppercase { } impl Recognizer for Uppercase { - #[inline(always)] + //#[inline(always)] fn append(&self, _byte: u8) -> Self { Uppercase { len: self.len + 1 } } - #[inline(always)] + //#[inline(always)] fn allowed(&self, byte: u8) -> bool { byte != 0xff // let ch = _byte as char; diff --git a/gvm_abi/src/toktree.rs b/gvm_abi/src/toktree.rs index 539d129d..4cf766cb 100644 --- a/gvm_abi/src/toktree.rs +++ b/gvm_abi/src/toktree.rs @@ -121,6 +121,36 @@ fn 
append_bias_core(rec: &impl Recognizer, logits: &mut [f32], n: &TrieNode) { } } +fn walk_core(n: &TrieNode) -> u32 { + let mut sum = 0; + let mut stack = Vec::with_capacity(20); + unsafe { + stack.push((n.child0(), n.next())); + loop { + let (mut p, mut endp) = stack.pop().unwrap(); + while p < endp { + let n = &*p; + p = n.next(); + sum += n.subtree_size; + if n.subtree_size > 1 { + stack.push((p, endp)); + endp = p; + p = n.child0(); + } + } + if stack.is_empty() { + break; + } + } + } + sum +} + +pub fn walk(trie: &TokTrie) -> u32 { + let n = trie.root(); + walk_core(n) +} + #[repr(C)] pub struct TokenizerBin { magic: u32, From b8b480e5cb29fa19eee3620ee6c3d2ef2d8fab17 Mon Sep 17 00:00:00 2001 From: Michal Moskal Date: Wed, 4 Oct 2023 20:22:50 +0000 Subject: [PATCH 028/301] tokenizer bin-serialization --- gvm_abi/src/rx.rs | 49 +++++++++++-- gvm_abi/src/toktree.rs | 156 ++++++++++++++++++++++++++++++++++++++--- 2 files changed, 193 insertions(+), 12 deletions(-) diff --git a/gvm_abi/src/rx.rs b/gvm_abi/src/rx.rs index 4164dcfd..1a052726 100644 --- a/gvm_abi/src/rx.rs +++ b/gvm_abi/src/rx.rs @@ -1,4 +1,8 @@ -use std::{mem::size_of, slice::from_raw_parts}; +use std::{ + mem::{self, size_of}, + ptr, + slice::from_raw_parts, +}; pub type TokenId = u32; pub type Transition = (StateOffset, TokenSetOffset); @@ -34,25 +38,62 @@ struct TokRxHeader { } #[repr(C)] -#[derive(Clone)] +#[derive(Clone, PartialEq, Eq, Debug)] pub struct TokRxInfo { + pub vocab_size: u32, pub tok_eos: TokenId, } -fn clone_vec_as_bytes(input: &Vec) -> Vec { +pub fn clone_vec_as_bytes(input: &Vec) -> Vec { unsafe { let byte_slice = from_raw_parts(input.as_ptr() as *const u8, input.len() * size_of::()); byte_slice.to_vec() } } -fn clone_as_bytes(input: &T) -> Vec { +pub fn clone_as_bytes(input: &T) -> Vec { unsafe { let byte_slice = from_raw_parts(input as *const T as *const u8, size_of::()); byte_slice.to_vec() } } +pub fn box_from_bytes(bytes: &[u8]) -> Box { + if bytes.len() != mem::size_of::() { + panic!( + "T: got {} bytes, needed {}", + bytes.len(), + mem::size_of::() + ); + } + let mut t: Box = Box::new(unsafe { mem::zeroed() }); + unsafe { + ptr::copy_nonoverlapping( + bytes.as_ptr(), + &mut *t as *mut T as *mut u8, + mem::size_of::(), + ); + } + t +} + +pub fn vec_from_bytes(bytes: &[u8]) -> Vec { + if bytes.len() % mem::size_of::() != 0 { + panic!( + "vecT: got {} bytes, needed mult of {}", + bytes.len(), + mem::size_of::() + ); + } + let num_elements = bytes.len() / mem::size_of::(); + let mut result = Vec::with_capacity(num_elements); + unsafe { + result.set_len(num_elements); + std::ptr::copy_nonoverlapping(bytes.as_ptr(), result.as_mut_ptr() as *mut u8, bytes.len()); + } + result +} + impl TokRxHeader { pub const MAGIC: u32 = 0x6623f10b; pub const SIZE: u32 = size_of::() as u32; diff --git a/gvm_abi/src/toktree.rs b/gvm_abi/src/toktree.rs index 4cf766cb..dcdfb93c 100644 --- a/gvm_abi/src/toktree.rs +++ b/gvm_abi/src/toktree.rs @@ -1,12 +1,34 @@ // use 8:24 encoding - num_ch:tok_id (ch_byte:ch_off)* - 8 bytes per tree node // special case num_ch=0xff -> num_ch=0x100 -use crate::recognizer::Recognizer; +use crate::{ + recognizer::Recognizer, + rx::{box_from_bytes, clone_as_bytes, clone_vec_as_bytes, vec_from_bytes, TokRxInfo, TokenId}, +}; pub struct TokTrie { - pub data: Vec, + info: TokRxInfo, + token_offsets: Vec, + token_data: Vec, + nodes: Vec, } +#[repr(C)] +pub struct TokTrieHeader { + magic: u32, + hd_size: u32, + trie_bytes: u32, + token_offset_bytes: u32, + token_data_bytes: u32, + info: TokRxInfo, + 
align: [u32; 0], +} + +impl TokTrieHeader { + const MAGIC: u32 = 0x558b6fd3; +} + +#[repr(C)] pub struct TrieNode { // byte:token bits: u32, @@ -55,23 +77,141 @@ impl TrieNode { } impl TokTrie { - pub fn from(words: &Vec>) -> TokTrie { + pub fn from(info: &TokRxInfo, words: &Vec>) -> Self { let mut trie = TrieHash::new(0xff); + let mut token_offsets = Vec::new(); + let mut token_data = Vec::new(); + println!("info: {:?} wl={}", info, words.len()); + assert!(info.vocab_size == words.len() as u32); for (idx, word) in words.iter().enumerate() { if word.len() > 0 { trie.insert(word, idx as u32) } + assert!(word.len() < 0xff); + let desc = (word.len() as u32) | ((token_data.len() as u32) << 8); + token_offsets.push(desc); + token_data.extend_from_slice(word); } - let mut data = Vec::new(); - trie.serialize(&mut data); - TokTrie { data } + let mut nodes = Vec::new(); + trie.serialize(&mut nodes); + let r = TokTrie { + info: info.clone(), + token_offsets, + token_data, + nodes, + }; + r.validate(); + r + } + + pub fn info(&self) -> &TokRxInfo { + &self.info + } + + pub fn token(&self, idx: u32) -> &[u8] { + let off = self.token_offsets[idx as usize]; + let len = off & 0xff; + let off = (off >> 8) as usize; + &self.token_data[off..(off + len as usize)] + } + + pub fn from_bytes(bytes: &[u8]) -> Self { + let pref = std::mem::size_of::(); + let hd = *box_from_bytes::(&bytes[0..pref]); + assert!(hd.magic == TokTrieHeader::MAGIC); + assert!(hd.hd_size as usize == pref); + + let trie_end = pref + hd.trie_bytes as usize; + let nodes = vec_from_bytes(&bytes[pref..trie_end]); + let offsets_end = trie_end + hd.token_offset_bytes as usize; + let token_offsets = vec_from_bytes(&bytes[trie_end..offsets_end]); + let token_data = vec_from_bytes(&bytes[offsets_end..]); + + let r = TokTrie { + info: hd.info, + token_offsets, + token_data, + nodes, + }; + r.validate(); + r + } + + fn validate_node(&self, n: &TrieNode, ep: *const TrieNode, used: &mut [bool]) { + if let Some(tok) = n.token_id() { + assert!(tok < self.info.vocab_size); + assert!(!used[tok as usize]); + used[tok as usize] = true; + } + unsafe { + let endp = n.next(); + assert!(endp <= ep); + let mut p = n.child0(); + while p < endp { + let n = &*p; + p = n.next(); + self.validate_node(n, endp, used) + } + } + } + + fn validate(&self) { + self.validate_node( + self.root(), + self.nodes.as_ptr_range().end, + &mut vec![false; self.info.vocab_size as usize], + ); + for idx in 0..self.info.vocab_size { + let _ = self.token(idx); + } + } + + pub fn serialize(&self) -> Vec { + let mut trie_data = clone_vec_as_bytes(&self.nodes); + let mut token_offsets = clone_vec_as_bytes(&self.token_offsets); + let mut token_data = clone_vec_as_bytes(&self.token_data); + + let hd = TokTrieHeader { + magic: TokTrieHeader::MAGIC, + hd_size: std::mem::size_of::() as u32, + trie_bytes: trie_data.len() as u32, + token_offset_bytes: token_offsets.len() as u32, + token_data_bytes: trie_data.len() as u32, + info: self.info.clone(), + align: [], + }; + + let mut bytes = clone_as_bytes(&hd); + bytes.append(&mut trie_data); + bytes.append(&mut token_offsets); + bytes.append(&mut token_data); + bytes } pub fn root(&self) -> &TrieNode { - &self.data[0] + &self.nodes[0] + } + + pub fn check_against(&self, tokens: &Vec>) { + let vocab_size = tokens.len(); + for idx in 0..vocab_size { + let bytes = &tokens[idx]; + let tid = idx as TokenId; + assert!(bytes == self.token(tid)); + let root = self.root(); + if bytes.len() > 0 { + assert!( + self.child_at_bytes(root, &bytes) + .unwrap() + 
.token_id() + .unwrap() + == tid + ); + } + } } - pub fn child_at_byte(&self, n: &TrieNode, byte: u8) -> Option<&TrieNode> { + pub fn child_at_byte<'a>(&'a self, n: &'a TrieNode, byte: u8) -> Option<&'a TrieNode> { unsafe { let mut p = n.child0(); let endp = n.next(); From 34b8ffdb9cc02a21643b9b5b68c1d949e05498b4 Mon Sep 17 00:00:00 2001 From: Michal Moskal Date: Wed, 4 Oct 2023 20:58:57 +0000 Subject: [PATCH 029/301] optimize walk() --- gvm_abi/src/toktree.rs | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/gvm_abi/src/toktree.rs b/gvm_abi/src/toktree.rs index dcdfb93c..a8c3b4af 100644 --- a/gvm_abi/src/toktree.rs +++ b/gvm_abi/src/toktree.rs @@ -263,22 +263,26 @@ fn append_bias_core(rec: &impl Recognizer, logits: &mut [f32], n: &TrieNode) { fn walk_core(n: &TrieNode) -> u32 { let mut sum = 0; - let mut stack = Vec::with_capacity(20); + let mut stack_buf: [(*const TrieNode, *const TrieNode); 130] = [(0 as _, 0 as _); 130]; + let mut stack_ptr = 0; unsafe { - stack.push((n.child0(), n.next())); + stack_buf[stack_ptr] = (n.child0(), n.next()); + stack_ptr += 1; loop { - let (mut p, mut endp) = stack.pop().unwrap(); + stack_ptr -= 1; + let (mut p, mut endp) = stack_buf[stack_ptr]; while p < endp { let n = &*p; p = n.next(); sum += n.subtree_size; if n.subtree_size > 1 { - stack.push((p, endp)); + stack_buf[stack_ptr] = (p, endp); + stack_ptr += 1; endp = p; p = n.child0(); } } - if stack.is_empty() { + if stack_ptr == 0 { break; } } From 633810f55f23442f5a20c2cd64f182fa3b699842 Mon Sep 17 00:00:00 2001 From: Michal Moskal Date: Wed, 4 Oct 2023 21:20:46 +0000 Subject: [PATCH 030/301] run the actual test --- gvm_abi/src/toktree.rs | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/gvm_abi/src/toktree.rs b/gvm_abi/src/toktree.rs index a8c3b4af..a1b32a4b 100644 --- a/gvm_abi/src/toktree.rs +++ b/gvm_abi/src/toktree.rs @@ -108,6 +108,10 @@ impl TokTrie { &self.info } + pub fn vocab_size(&self) -> usize { + self.info.vocab_size as usize + } + pub fn token(&self, idx: u32) -> &[u8] { let off = self.token_offsets[idx as usize]; let len = off & 0xff; From 34ee9bb611953659292f11dc568ee5bc33ed0390 Mon Sep 17 00:00:00 2001 From: Michal Moskal Date: Thu, 5 Oct 2023 15:32:58 +0000 Subject: [PATCH 031/301] printing improvements --- gvm_abi/src/lib.rs | 4 ++-- gvm_abi/src/printing.rs | 22 ++++++++++++++++------ gvm_abi/src/rxvm.rs | 8 ++++---- 3 files changed, 22 insertions(+), 12 deletions(-) diff --git a/gvm_abi/src/lib.rs b/gvm_abi/src/lib.rs index 45c625e1..d134d345 100644 --- a/gvm_abi/src/lib.rs +++ b/gvm_abi/src/lib.rs @@ -112,7 +112,7 @@ macro_rules! include_bytes_aligned { } #[macro_export] -macro_rules! println { +macro_rules! wprintln { () => { $crate::printing::_print("\n") }; @@ -123,7 +123,7 @@ macro_rules! println { } #[macro_export] -macro_rules! print { +macro_rules! 
wprint { ($($arg:tt)*) => {{ $crate::printing::_print(&format!($($arg)*)); }}; diff --git a/gvm_abi/src/printing.rs b/gvm_abi/src/printing.rs index 3683cbf5..07133695 100644 --- a/gvm_abi/src/printing.rs +++ b/gvm_abi/src/printing.rs @@ -1,4 +1,4 @@ -use std::{io, panic}; +use std::io; extern "C" { fn gvm_host_print(ptr: *const u8, len: u32); @@ -17,8 +17,9 @@ impl io::Write for Printer { } } -pub fn init() { - panic::set_hook(Box::new(|info| { +pub fn init_panic() { + #[cfg(target_arch = "wasm32")] + std::panic::set_hook(Box::new(|info| { let file = info.location().unwrap().file(); let line = info.location().unwrap().line(); let col = info.location().unwrap().column(); @@ -41,11 +42,20 @@ pub fn stdout() -> Printer { } pub fn _print(msg: &str) { - let vec: Vec = msg.into(); - unsafe { gvm_host_print(vec.as_ptr(), vec.len() as u32) }; + #[cfg(target_arch = "wasm32")] + { + let vec: Vec = msg.into(); + unsafe { gvm_host_print(vec.as_ptr(), vec.len() as u32) }; + } + + #[cfg(not(target_arch = "wasm32"))] + { + use std::io::Write; + std::io::stdout().write_all(msg.as_bytes()).unwrap(); + } } #[no_mangle] pub extern "C" fn gvm_init() { - init(); + init_panic(); } diff --git a/gvm_abi/src/rxvm.rs b/gvm_abi/src/rxvm.rs index e3cf2468..057bc41d 100644 --- a/gvm_abi/src/rxvm.rs +++ b/gvm_abi/src/rxvm.rs @@ -1,5 +1,5 @@ use crate::rx::{StateOffset, TokRx}; -use crate::{println, GuidanceVm, GuidanceVmHelper}; +use crate::{wprintln, GuidanceVm, GuidanceVmHelper}; pub struct RxGvm { pub helper: GuidanceVmHelper, @@ -19,7 +19,7 @@ impl RxGvm { impl GuidanceVm for RxGvm { fn gvm_process_prompt(&mut self) { - println!("prompt, {} tokens", self.helper.prompt_length); + wprintln!("prompt, {} tokens", self.helper.prompt_length); // the regex doesn't care about the prompt self.state = StateOffset::START; self.compiled @@ -27,7 +27,7 @@ impl GuidanceVm for RxGvm { } fn gvm_append_token(&mut self, token: u32) { - // println!("xapp {:?} {} {}", self as *const _, token, self.state.off); + // wprintln!("xapp {:?} {} {}", self as *const _, token, self.state.off); self.state = self.compiled.advance(self.state, token); // save the token, just in case @@ -46,7 +46,7 @@ impl GuidanceVm for RxGvm { compiled: self.compiled.clone(), state: self.state.clone(), }; - println!("{} -> {}", self.state.off, r.state.off); + wprintln!("{} -> {}", self.state.off, r.state.off); r } } From efdbb76a31ecd313f5c29906f875ad1a4f65d0a5 Mon Sep 17 00:00:00 2001 From: Michal Moskal Date: Thu, 5 Oct 2023 16:00:20 +0000 Subject: [PATCH 032/301] add num_parents field --- gvm_abi/src/toktree.rs | 40 ++++++++++++++++++++++++++++------------ 1 file changed, 28 insertions(+), 12 deletions(-) diff --git a/gvm_abi/src/toktree.rs b/gvm_abi/src/toktree.rs index a1b32a4b..e07fc257 100644 --- a/gvm_abi/src/toktree.rs +++ b/gvm_abi/src/toktree.rs @@ -32,22 +32,22 @@ impl TokTrieHeader { pub struct TrieNode { // byte:token bits: u32, - subtree_size: u32, + bits2: u32, } const NO_TOKEN: u32 = 0xffffff; impl TrieNode { - fn new(byte: u8, token_id: u32) -> TrieNode { + fn new(byte: u8, token_id: u32, num_parents: u8) -> TrieNode { TrieNode { bits: (token_id << 8) | byte as u32, - subtree_size: 0, + bits2: num_parents as u32, } } #[inline(always)] unsafe fn next(&self) -> *const TrieNode { - self.ptr().add(self.subtree_size as usize) + self.ptr().add(self.subtree_size()) } #[inline(always)] @@ -65,6 +65,16 @@ impl TrieNode { (self.bits & 0xff) as u8 } + #[inline(always)] + pub fn subtree_size(&self) -> usize { + (self.bits2 >> 8) as usize + } + + 
#[inline(always)] + pub fn num_parents(&self) -> usize { + (self.bits2 & 0xff) as usize + } + #[inline(always)] pub fn token_id(&self) -> Option { let r = self.bits >> 8; @@ -93,7 +103,7 @@ impl TokTrie { token_data.extend_from_slice(word); } let mut nodes = Vec::new(); - trie.serialize(&mut nodes); + trie.serialize(&mut nodes, 0); let r = TokTrie { info: info.clone(), token_offsets, @@ -257,7 +267,7 @@ fn append_bias_core(rec: &impl Recognizer, logits: &mut [f32], n: &TrieNode) { if let Some(tok) = n.token_id() { logits[tok as usize] = 0.0; } - if n.subtree_size > 1 { + if n.subtree_size() > 1 { append_bias_core(&rec.append(b), logits, n); } } @@ -278,8 +288,8 @@ fn walk_core(n: &TrieNode) -> u32 { while p < endp { let n = &*p; p = n.next(); - sum += n.subtree_size; - if n.subtree_size > 1 { + sum += n.subtree_size() as u32; + if n.subtree_size() > 1 { stack_buf[stack_ptr] = (p, endp); stack_ptr += 1; endp = p; @@ -356,13 +366,19 @@ impl TrieHash { } } } - fn serialize(&mut self, data: &mut Vec) { + fn serialize(&mut self, data: &mut Vec, num_parents: u32) { let idx = data.len(); - data.push(TrieNode::new(self.byte, self.token_id)); + data.push(TrieNode::new(self.byte, self.token_id, num_parents as u8)); self.children.sort_by_key(|e| e.byte); + let mut num_ch = self.children.len(); for entry in &mut self.children { - entry.serialize(data); + num_ch -= 1; + if num_ch == 0 { + entry.serialize(data, num_parents + 1); + } else { + entry.serialize(data, 0); + } } - data[idx].subtree_size = (data.len() - idx) as u32; + data[idx].bits2 |= ((data.len() - idx) as u32) << 8; } } From 643bb9fea5ddd169692c9128801f3d7c7b8b4220 Mon Sep 17 00:00:00 2001 From: Michal Moskal Date: Thu, 5 Oct 2023 16:02:37 +0000 Subject: [PATCH 033/301] more interesting recognizer --- gvm_abi/src/recognizer.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/gvm_abi/src/recognizer.rs b/gvm_abi/src/recognizer.rs index d9c5bb20..480c1bb1 100644 --- a/gvm_abi/src/recognizer.rs +++ b/gvm_abi/src/recognizer.rs @@ -29,7 +29,7 @@ impl Recognizer for Uppercase { //#[inline(always)] fn allowed(&self, byte: u8) -> bool { - byte != 0xff + byte != (('z' as usize + self.len) & 0xff) as u8 // let ch = _byte as char; // if self.len < 2 { // 'A' <= ch && ch <= 'Z' From ff2a5ad15b4f0e1e37c9574c9773ad3528ed29bd Mon Sep 17 00:00:00 2001 From: Michal Moskal Date: Thu, 5 Oct 2023 18:12:12 +0000 Subject: [PATCH 034/301] yet faster tree walking --- gvm_abi/src/recognizer.rs | 7 ++++--- gvm_abi/src/toktree.rs | 42 +++++++++++++++++++-------------------- 2 files changed, 25 insertions(+), 24 deletions(-) diff --git a/gvm_abi/src/recognizer.rs b/gvm_abi/src/recognizer.rs index 480c1bb1..1277747b 100644 --- a/gvm_abi/src/recognizer.rs +++ b/gvm_abi/src/recognizer.rs @@ -6,11 +6,12 @@ pub trait Recognizer { } #[inline(never)] -pub fn compute_bias(trie: &TokTrie, rec: &impl Recognizer, logits: &mut [f32]) { +pub fn compute_bias(trie: &TokTrie, rec: (impl Recognizer + Copy), logits: &mut [f32]) { logits.iter_mut().for_each(|x| *x = -100.0); append_bias(trie, rec, logits); } +#[derive(Copy, Clone)] pub struct Uppercase { len: usize, } @@ -22,12 +23,12 @@ impl Uppercase { } impl Recognizer for Uppercase { - //#[inline(always)] + #[inline(never)] fn append(&self, _byte: u8) -> Self { Uppercase { len: self.len + 1 } } - //#[inline(always)] + #[inline(never)] fn allowed(&self, byte: u8) -> bool { byte != (('z' as usize + self.len) & 0xff) as u8 // let ch = _byte as char; diff --git a/gvm_abi/src/toktree.rs b/gvm_abi/src/toktree.rs index 
e07fc257..3a812c09 100644 --- a/gvm_abi/src/toktree.rs +++ b/gvm_abi/src/toktree.rs @@ -103,7 +103,7 @@ impl TokTrie { token_data.extend_from_slice(word); } let mut nodes = Vec::new(); - trie.serialize(&mut nodes, 0); + trie.serialize(&mut nodes, 1); let r = TokTrie { info: info.clone(), token_offsets, @@ -250,28 +250,32 @@ impl TokTrie { } } -pub fn append_bias(trie: &TokTrie, rec: &impl Recognizer, logits: &mut [f32]) { +pub fn append_bias(trie: &TokTrie, rec0: T, logits: &mut [f32]) { let n = trie.root(); - append_bias_core(rec, logits, n); -} - -fn append_bias_core(rec: &impl Recognizer, logits: &mut [f32], n: &TrieNode) { + let mut stack_buf = [rec0; 130]; + let mut stack_ptr = 1; + let defl_tok = trie.vocab_size() as u32; unsafe { - let endp = n.next(); let mut p = n.child0(); + let endp = n.next(); while p < endp { let n = &*p; - p = n.next(); let b = n.byte(); + let rec = &stack_buf[stack_ptr - 1]; if rec.allowed(b) { - if let Some(tok) = n.token_id() { - logits[tok as usize] = 0.0; - } - if n.subtree_size() > 1 { - append_bias_core(&rec.append(b), logits, n); + logits[n.token_id().unwrap_or(defl_tok) as usize] = 0.0; + stack_buf[stack_ptr] = rec.append(b); + stack_ptr += 1; + if n.subtree_size() == 1 { + stack_ptr -= n.num_parents(); } + p = n.child0(); + } else { + p = n.next(); + stack_ptr -= n.num_parents() - 1; } } + //panic!("st: {}", stack_ptr); } } @@ -366,18 +370,14 @@ impl TrieHash { } } } - fn serialize(&mut self, data: &mut Vec, num_parents: u32) { + fn serialize(&mut self, data: &mut Vec, num_parents: u8) { let idx = data.len(); - data.push(TrieNode::new(self.byte, self.token_id, num_parents as u8)); - self.children.sort_by_key(|e| e.byte); let mut num_ch = self.children.len(); + data.push(TrieNode::new(self.byte, self.token_id, num_parents)); + self.children.sort_by_key(|e| e.byte); for entry in &mut self.children { num_ch -= 1; - if num_ch == 0 { - entry.serialize(data, num_parents + 1); - } else { - entry.serialize(data, 0); - } + entry.serialize(data, if num_ch == 0 { num_parents + 1 } else { 1 }); } data[idx].bits2 |= ((data.len() - idx) as u32) << 8; } From b7f27dc614969061caf6d0f44138d920059556bb Mon Sep 17 00:00:00 2001 From: Michal Moskal Date: Thu, 5 Oct 2023 22:44:30 +0000 Subject: [PATCH 035/301] add gvm_host_read_token_trie() --- gvm_abi/src/gvm_iface.h | 8 +++----- gvm_abi/src/toktree.rs | 18 ++++++++++++++++++ 2 files changed, 21 insertions(+), 5 deletions(-) diff --git a/gvm_abi/src/gvm_iface.h b/gvm_abi/src/gvm_iface.h index 2c028808..2b0cb7f1 100644 --- a/gvm_abi/src/gvm_iface.h +++ b/gvm_abi/src/gvm_iface.h @@ -51,8 +51,6 @@ void gvm_free(Gvm *gvm); // Log a string. void gvm_host_print(const uint8_t *ptr, uint32_t size); -// Provisional, not implemented yet: - -// Get bytes corresponding to given token. `size` is `sizeof(dst)`. -// The length of token is returned (even if its bigger than `size`). -uint32_t gvm_host_token_to_bytes(token_t token, uint8_t dst[], uint32_t size); +// Read binary representation of TokTrie. +// Always returns the size of the trie, will write up to `size` bytes to `dst`. 
+uint32_t gvm_host_read_token_trie(uint8_t *dst, uint32_t size); diff --git a/gvm_abi/src/toktree.rs b/gvm_abi/src/toktree.rs index 3a812c09..a169531c 100644 --- a/gvm_abi/src/toktree.rs +++ b/gvm_abi/src/toktree.rs @@ -86,7 +86,25 @@ impl TrieNode { } } +#[allow(dead_code)] +extern "C" { + fn gvm_host_read_token_trie(ptr: *mut u8, len: u32) -> u32; +} + impl TokTrie { + pub fn from_env() -> Self { + #[cfg(target_arch = "wasm32")] + unsafe { + let size = gvm_host_read_token_trie(0 as _, 0); + let mut buffer = vec![0u8; size as usize]; + gvm_host_read_token_trie(buffer.as_mut_ptr(), size); + Self::from_bytes(&buffer) + } + + #[cfg(not(target_arch = "wasm32"))] + Self::from_bytes(&std::fs::read("tokenizer.bin").unwrap()) + } + pub fn from(info: &TokRxInfo, words: &Vec>) -> Self { let mut trie = TrieHash::new(0xff); let mut token_offsets = Vec::new(); From a5ca43a290d164f3d5bfdabe1e5db7e61545acb8 Mon Sep 17 00:00:00 2001 From: Michal Moskal Date: Thu, 5 Oct 2023 22:44:43 +0000 Subject: [PATCH 036/301] clean up unused code warnings --- gvm_abi/src/printing.rs | 16 +++++++++++++++- 1 file changed, 15 insertions(+), 1 deletion(-) diff --git a/gvm_abi/src/printing.rs b/gvm_abi/src/printing.rs index 07133695..0b66df8b 100644 --- a/gvm_abi/src/printing.rs +++ b/gvm_abi/src/printing.rs @@ -1,11 +1,17 @@ use std::io; +#[allow(dead_code)] extern "C" { fn gvm_host_print(ptr: *const u8, len: u32); } +#[cfg(not(target_arch = "wasm32"))] +pub type Printer = std::io::Stdout; + +#[cfg(target_arch = "wasm32")] pub struct Printer {} +#[cfg(target_arch = "wasm32")] impl io::Write for Printer { fn write(&mut self, buf: &[u8]) -> io::Result { unsafe { gvm_host_print(buf.as_ptr(), buf.len() as u32) }; @@ -38,7 +44,15 @@ pub fn init_panic() { } pub fn stdout() -> Printer { - Printer {} + #[cfg(target_arch = "wasm32")] + { + Printer {} + } + + #[cfg(not(target_arch = "wasm32"))] + { + io::stdout() + } } pub fn _print(msg: &str) { From 6a9dda76bd0ff1e822aa356a2458ead6d36a4e2e Mon Sep 17 00:00:00 2001 From: Michal Moskal Date: Fri, 6 Oct 2023 00:14:40 +0000 Subject: [PATCH 037/301] better recognizer iface --- gvm_abi/src/recognizer.rs | 36 ++++++++++-------------------------- gvm_abi/src/toktree.rs | 21 +++++++++++++-------- 2 files changed, 23 insertions(+), 34 deletions(-) diff --git a/gvm_abi/src/recognizer.rs b/gvm_abi/src/recognizer.rs index 1277747b..b24893f5 100644 --- a/gvm_abi/src/recognizer.rs +++ b/gvm_abi/src/recognizer.rs @@ -1,42 +1,26 @@ -use crate::toktree::{append_bias, TokTrie}; - -pub trait Recognizer { - fn append(&self, byte: u8) -> Self; - fn allowed(&self, byte: u8) -> bool; -} +use crate::toktree::{append_bias, TokTrie, Recognizer}; #[inline(never)] -pub fn compute_bias(trie: &TokTrie, rec: (impl Recognizer + Copy), logits: &mut [f32]) { +pub fn compute_bias(trie: &TokTrie, rec: &mut impl Recognizer, logits: &mut [f32]) { logits.iter_mut().for_each(|x| *x = -100.0); append_bias(trie, rec, logits); } -#[derive(Copy, Clone)] -pub struct Uppercase { - len: usize, -} +pub struct LenExcluder {} -impl Uppercase { - pub fn new() -> Self { - Uppercase { len: 0 } +impl Recognizer for LenExcluder { + fn initial(&mut self) -> u32 { + 0 } -} -impl Recognizer for Uppercase { #[inline(never)] - fn append(&self, _byte: u8) -> Self { - Uppercase { len: self.len + 1 } + fn append(&mut self, state: u32, _byte: u8) -> u32 { + state + 1 } #[inline(never)] - fn allowed(&self, byte: u8) -> bool { - byte != (('z' as usize + self.len) & 0xff) as u8 - // let ch = _byte as char; - // if self.len < 2 { - // 'A' <= ch && 
ch <= 'Z' - // } else { - // 'a' <= ch && ch <= 'z' - // } + fn allowed(&mut self, state: u32, byte: u8) -> bool { + byte != (('z' as u32 + state) & 0xff) as u8 } } diff --git a/gvm_abi/src/toktree.rs b/gvm_abi/src/toktree.rs index a169531c..52930d15 100644 --- a/gvm_abi/src/toktree.rs +++ b/gvm_abi/src/toktree.rs @@ -1,11 +1,16 @@ // use 8:24 encoding - num_ch:tok_id (ch_byte:ch_off)* - 8 bytes per tree node // special case num_ch=0xff -> num_ch=0x100 -use crate::{ - recognizer::Recognizer, - rx::{box_from_bytes, clone_as_bytes, clone_vec_as_bytes, vec_from_bytes, TokRxInfo, TokenId}, +use crate::rx::{ + box_from_bytes, clone_as_bytes, clone_vec_as_bytes, vec_from_bytes, TokRxInfo, TokenId, }; +pub trait Recognizer { + fn initial(&mut self) -> S; + fn append(&mut self, state: S, byte: u8) -> S; + fn allowed(&mut self, state: S, byte: u8) -> bool; +} + pub struct TokTrie { info: TokRxInfo, token_offsets: Vec, @@ -268,9 +273,9 @@ impl TokTrie { } } -pub fn append_bias(trie: &TokTrie, rec0: T, logits: &mut [f32]) { +pub fn append_bias(trie: &TokTrie, r: &mut impl Recognizer, logits: &mut [f32]) { let n = trie.root(); - let mut stack_buf = [rec0; 130]; + let mut stack_buf = [r.initial(); 130]; let mut stack_ptr = 1; let defl_tok = trie.vocab_size() as u32; unsafe { @@ -279,10 +284,10 @@ pub fn append_bias(trie: &TokTrie, rec0: T, logits: &mut [ while p < endp { let n = &*p; let b = n.byte(); - let rec = &stack_buf[stack_ptr - 1]; - if rec.allowed(b) { + let rec = stack_buf[stack_ptr - 1]; + if r.allowed(rec, b) { logits[n.token_id().unwrap_or(defl_tok) as usize] = 0.0; - stack_buf[stack_ptr] = rec.append(b); + stack_buf[stack_ptr] = r.append(rec, b); stack_ptr += 1; if n.subtree_size() == 1 { stack_ptr -= n.num_parents(); From 73667a2f556900911f97d3cea6a864886b1be1f4 Mon Sep 17 00:00:00 2001 From: Michal Moskal Date: Fri, 6 Oct 2023 00:36:30 +0000 Subject: [PATCH 038/301] start on GvmRecognizer --- gvm_abi/src/recognizer.rs | 97 ++++++++++++++++++++++++++++----------- gvm_abi/src/toktree.rs | 4 +- 2 files changed, 72 insertions(+), 29 deletions(-) diff --git a/gvm_abi/src/recognizer.rs b/gvm_abi/src/recognizer.rs index b24893f5..66f5bf6f 100644 --- a/gvm_abi/src/recognizer.rs +++ b/gvm_abi/src/recognizer.rs @@ -1,9 +1,17 @@ -use crate::toktree::{append_bias, TokTrie, Recognizer}; +use crate::{ + toktree::{append_bias, Recognizer, TokTrie}, + wprintln, GuidanceVm, GuidanceVmHelper, +}; #[inline(never)] -pub fn compute_bias(trie: &TokTrie, rec: &mut impl Recognizer, logits: &mut [f32]) { +pub fn compute_bias( + trie: &TokTrie, + rec: &mut impl Recognizer, + state: S, + logits: &mut [f32], +) { logits.iter_mut().for_each(|x| *x = -100.0); - append_bias(trie, rec, logits); + append_bias(trie, rec, state, logits); } pub struct LenExcluder {} @@ -24,27 +32,62 @@ impl Recognizer for LenExcluder { } } -// pub struct PrefixEnum { -// prefix_ch: u8, -// depth: u32, -// allowed: Vec>, -// } - -// impl Recognizer for PrefixEnum { -// fn append1(&self, byte: u8) -> Self { -// let mut depth = self.depth; -// for b in bytes { -// if depth > 0 { -// depth += 1; -// } -// if depth == 0 && *b == self.prefix_ch { -// depth = 1 -// } -// } -// todo!() -// } - -// fn allowed(&self) -> Vec> { -// self.allowed.clone() -// } -// } +pub struct GvmRecognizer<'a, S: Copy, R: Recognizer + Clone> { + pub helper: GuidanceVmHelper, + pub rec: &'a mut R, + pub trie: &'a TokTrie, + pub state: S, +} + +impl<'a, S: Copy, R: Recognizer + Clone> GvmRecognizer<'a, S, R> { + pub fn from_recognizer(trie: &'a TokTrie, rec: &'a mut 
R) -> Self { + let state = rec.initial(); + GvmRecognizer { + helper: GuidanceVmHelper::new(), + rec, + state, + trie, + } + } + + fn compute(&mut self) { + compute_bias( + self.trie, + self.rec, + self.state, + &mut self.helper.logit_biases, + ); + } +} + +impl<'a, S: Copy, R: Recognizer + Clone> GuidanceVm for GvmRecognizer<'a, S, R> { + fn gvm_clone(&mut self) -> Self { + GvmRecognizer { + helper: self.helper.clone(), + rec: self.rec, + state: self.state, + trie: self.trie, + } + } + + fn gvm_process_prompt(&mut self) { + wprintln!("prompt, {} tokens", self.helper.prompt_length); + // the regex doesn't care about the prompt + self.state = self.rec.initial(); + self.compute(); + } + + fn gvm_append_token(&mut self, token: u32) { + // wprintln!("xapp {:?} {} {}", self as *const _, token, self.state.off); + let bytes = self.trie.token(token); + for b in bytes { + self.state = self.rec.append(self.state, *b); + } + + // save the token, just in case + let toks = &mut self.helper.tokens; + toks.push(token); + + self.compute(); + } +} diff --git a/gvm_abi/src/toktree.rs b/gvm_abi/src/toktree.rs index 52930d15..5a7f8f60 100644 --- a/gvm_abi/src/toktree.rs +++ b/gvm_abi/src/toktree.rs @@ -273,9 +273,9 @@ impl TokTrie { } } -pub fn append_bias(trie: &TokTrie, r: &mut impl Recognizer, logits: &mut [f32]) { +pub fn append_bias(trie: &TokTrie, r: &mut impl Recognizer, state: S, logits: &mut [f32]) { let n = trie.root(); - let mut stack_buf = [r.initial(); 130]; + let mut stack_buf = [state; 130]; let mut stack_ptr = 1; let defl_tok = trie.vocab_size() as u32; unsafe { From 1e7d4b726971b91435889fa5f692866b90d2668c Mon Sep 17 00:00:00 2001 From: Michal Moskal Date: Fri, 6 Oct 2023 00:48:41 +0000 Subject: [PATCH 039/301] switch to boxes --- gvm_abi/src/recognizer.rs | 37 ++++++++++++++++++------------------- gvm_abi/src/toktree.rs | 8 ++++---- 2 files changed, 22 insertions(+), 23 deletions(-) diff --git a/gvm_abi/src/recognizer.rs b/gvm_abi/src/recognizer.rs index 66f5bf6f..6d4e20c2 100644 --- a/gvm_abi/src/recognizer.rs +++ b/gvm_abi/src/recognizer.rs @@ -1,3 +1,5 @@ +use std::rc::Rc; + use crate::{ toktree::{append_bias, Recognizer, TokTrie}, wprintln, GuidanceVm, GuidanceVmHelper, @@ -6,7 +8,7 @@ use crate::{ #[inline(never)] pub fn compute_bias( trie: &TokTrie, - rec: &mut impl Recognizer, + rec: &impl Recognizer, state: S, logits: &mut [f32], ) { @@ -17,31 +19,31 @@ pub fn compute_bias( pub struct LenExcluder {} impl Recognizer for LenExcluder { - fn initial(&mut self) -> u32 { + fn initial(&self) -> u32 { 0 } #[inline(never)] - fn append(&mut self, state: u32, _byte: u8) -> u32 { + fn append(&self, state: u32, _byte: u8) -> u32 { state + 1 } #[inline(never)] - fn allowed(&mut self, state: u32, byte: u8) -> bool { + fn allowed(&self, state: u32, byte: u8) -> bool { byte != (('z' as u32 + state) & 0xff) as u8 } } -pub struct GvmRecognizer<'a, S: Copy, R: Recognizer + Clone> { +pub struct GvmRecognizer + Clone> { pub helper: GuidanceVmHelper, - pub rec: &'a mut R, - pub trie: &'a TokTrie, + pub rec: Rc>, + pub trie: Rc>, pub state: S, } -impl<'a, S: Copy, R: Recognizer + Clone> GvmRecognizer<'a, S, R> { - pub fn from_recognizer(trie: &'a TokTrie, rec: &'a mut R) -> Self { - let state = rec.initial(); +impl + Clone> GvmRecognizer { + pub fn from_recognizer(trie: Rc>, rec: Rc>) -> Self { + let state = rec.as_ref().initial(); GvmRecognizer { helper: GuidanceVmHelper::new(), rec, @@ -51,22 +53,19 @@ impl<'a, S: Copy, R: Recognizer + Clone> GvmRecognizer<'a, S, R> { } fn compute(&mut self) { - 
compute_bias( - self.trie, - self.rec, - self.state, - &mut self.helper.logit_biases, - ); + let trie = (*self.trie).as_ref(); + let rec = (*self.rec).as_ref(); + compute_bias(trie, rec, self.state, &mut self.helper.logit_biases); } } -impl<'a, S: Copy, R: Recognizer + Clone> GuidanceVm for GvmRecognizer<'a, S, R> { +impl + Clone> GuidanceVm for GvmRecognizer { fn gvm_clone(&mut self) -> Self { GvmRecognizer { helper: self.helper.clone(), - rec: self.rec, + rec: self.rec.clone(), state: self.state, - trie: self.trie, + trie: self.trie.clone(), } } diff --git a/gvm_abi/src/toktree.rs b/gvm_abi/src/toktree.rs index 5a7f8f60..aa53dbc4 100644 --- a/gvm_abi/src/toktree.rs +++ b/gvm_abi/src/toktree.rs @@ -6,9 +6,9 @@ use crate::rx::{ }; pub trait Recognizer { - fn initial(&mut self) -> S; - fn append(&mut self, state: S, byte: u8) -> S; - fn allowed(&mut self, state: S, byte: u8) -> bool; + fn initial(&self) -> S; + fn append(&self, state: S, byte: u8) -> S; + fn allowed(&self, state: S, byte: u8) -> bool; } pub struct TokTrie { @@ -273,7 +273,7 @@ impl TokTrie { } } -pub fn append_bias(trie: &TokTrie, r: &mut impl Recognizer, state: S, logits: &mut [f32]) { +pub fn append_bias(trie: &TokTrie, r: &impl Recognizer, state: S, logits: &mut [f32]) { let n = trie.root(); let mut stack_buf = [state; 130]; let mut stack_ptr = 1; From 92460fd6e62ccd6d7133d156d85f305f4bb9b972 Mon Sep 17 00:00:00 2001 From: Michal Moskal Date: Fri, 6 Oct 2023 01:04:05 +0000 Subject: [PATCH 040/301] hook it all up --- gvm_abi/src/lib.rs | 6 ++++-- gvm_abi/src/recognizer.rs | 6 +++--- 2 files changed, 7 insertions(+), 5 deletions(-) diff --git a/gvm_abi/src/lib.rs b/gvm_abi/src/lib.rs index d134d345..14f506ea 100644 --- a/gvm_abi/src/lib.rs +++ b/gvm_abi/src/lib.rs @@ -1,8 +1,8 @@ pub mod printing; +pub mod recognizer; pub mod rx; pub mod rxvm; pub mod toktree; -pub mod recognizer; /// Expose method as extern "C", usage: /// expose!(Foo::set_count(n: i32) -> i32); @@ -45,7 +45,9 @@ impl GuidanceVmHelper { } } pub fn gvm_get_logit_bias_buffer(&mut self, size: u32) -> *mut f32 { - self.logit_biases.resize(size as usize, 0.0); + // we keep one more logit at the end as a placeholder to avoid branching in + // the inner loop of append_bias + self.logit_biases.resize((size + 1) as usize, 0.0); self.logit_biases.as_mut_ptr() } pub fn gvm_get_prompt_buffer(&mut self, size: u32) -> *mut u32 { diff --git a/gvm_abi/src/recognizer.rs b/gvm_abi/src/recognizer.rs index 6d4e20c2..850603ec 100644 --- a/gvm_abi/src/recognizer.rs +++ b/gvm_abi/src/recognizer.rs @@ -34,14 +34,14 @@ impl Recognizer for LenExcluder { } } -pub struct GvmRecognizer + Clone> { +pub struct GvmRecognizer> { pub helper: GuidanceVmHelper, pub rec: Rc>, pub trie: Rc>, pub state: S, } -impl + Clone> GvmRecognizer { +impl> GvmRecognizer { pub fn from_recognizer(trie: Rc>, rec: Rc>) -> Self { let state = rec.as_ref().initial(); GvmRecognizer { @@ -59,7 +59,7 @@ impl + Clone> GvmRecognizer { } } -impl + Clone> GuidanceVm for GvmRecognizer { +impl> GuidanceVm for GvmRecognizer { fn gvm_clone(&mut self) -> Self { GvmRecognizer { helper: self.helper.clone(), From b24ae03dfda1ad82a75b6649cfde247b8a2b3375 Mon Sep 17 00:00:00 2001 From: Michal Moskal Date: Fri, 6 Oct 2023 15:38:51 +0000 Subject: [PATCH 041/301] limit unsafe code --- gvm_abi/src/recognizer.rs | 4 +- gvm_abi/src/toktree.rs | 122 +++++++++++++------------------------- 2 files changed, 44 insertions(+), 82 deletions(-) diff --git a/gvm_abi/src/recognizer.rs b/gvm_abi/src/recognizer.rs index 850603ec..5c401a50 100644 
--- a/gvm_abi/src/recognizer.rs +++ b/gvm_abi/src/recognizer.rs @@ -1,7 +1,7 @@ use std::rc::Rc; use crate::{ - toktree::{append_bias, Recognizer, TokTrie}, + toktree::{Recognizer, TokTrie}, wprintln, GuidanceVm, GuidanceVmHelper, }; @@ -13,7 +13,7 @@ pub fn compute_bias( logits: &mut [f32], ) { logits.iter_mut().for_each(|x| *x = -100.0); - append_bias(trie, rec, state, logits); + trie.append_bias(rec, state, logits); } pub struct LenExcluder {} diff --git a/gvm_abi/src/toktree.rs b/gvm_abi/src/toktree.rs index aa53dbc4..51d96d0d 100644 --- a/gvm_abi/src/toktree.rs +++ b/gvm_abi/src/toktree.rs @@ -50,21 +50,6 @@ impl TrieNode { } } - #[inline(always)] - unsafe fn next(&self) -> *const TrieNode { - self.ptr().add(self.subtree_size()) - } - - #[inline(always)] - unsafe fn ptr(&self) -> *const TrieNode { - self as *const TrieNode - } - - #[inline(always)] - unsafe fn child0(&self) -> *const TrieNode { - self.ptr().add(1) - } - #[inline(always)] pub fn byte(&self) -> u8 { (self.bits & 0xff) as u8 @@ -137,6 +122,22 @@ impl TokTrie { r } + fn node_offset(&self, n: &TrieNode) -> usize { + let off = unsafe { (n as *const TrieNode).offset_from(self.root() as *const TrieNode) }; + assert!(off >= 0); + let off = off as usize; + assert!(off < self.nodes.len()); + off + } + + fn node_child0(&self, n: &TrieNode) -> usize { + return self.node_offset(n) + 1; + } + + fn next_node(&self, n: &TrieNode) -> usize { + return self.node_offset(n) + n.subtree_size(); + } + pub fn info(&self) -> &TokRxInfo { &self.info } @@ -174,28 +175,26 @@ impl TokTrie { r } - fn validate_node(&self, n: &TrieNode, ep: *const TrieNode, used: &mut [bool]) { + fn validate_node(&self, n: &TrieNode, ep: usize, used: &mut [bool]) { if let Some(tok) = n.token_id() { assert!(tok < self.info.vocab_size); assert!(!used[tok as usize]); used[tok as usize] = true; } - unsafe { - let endp = n.next(); - assert!(endp <= ep); - let mut p = n.child0(); - while p < endp { - let n = &*p; - p = n.next(); - self.validate_node(n, endp, used) - } + let endp = self.next_node(n); + assert!(endp <= ep); + let mut p = self.node_child0(n); + while p < endp { + let n = &self.nodes[p]; + p = self.next_node(n); + self.validate_node(n, endp, used) } } fn validate(&self) { self.validate_node( self.root(), - self.nodes.as_ptr_range().end, + self.next_node(self.root()), &mut vec![false; self.info.vocab_size as usize], ); for idx in 0..self.info.vocab_size { @@ -249,15 +248,14 @@ impl TokTrie { } pub fn child_at_byte<'a>(&'a self, n: &'a TrieNode, byte: u8) -> Option<&'a TrieNode> { - unsafe { - let mut p = n.child0(); - let endp = n.next(); - while p < endp { - if (*p).byte() == byte { - return Some(&*p); - } - p = (*p).next(); + let mut p = self.node_child0(n); + let endp = self.next_node(n); + while p < endp { + let n = &self.nodes[p]; + if n.byte() == byte { + return Some(n); } + p = self.next_node(n); } None } @@ -271,18 +269,16 @@ impl TokTrie { } Some(n) } -} -pub fn append_bias(trie: &TokTrie, r: &impl Recognizer, state: S, logits: &mut [f32]) { - let n = trie.root(); - let mut stack_buf = [state; 130]; - let mut stack_ptr = 1; - let defl_tok = trie.vocab_size() as u32; - unsafe { - let mut p = n.child0(); - let endp = n.next(); + pub fn append_bias(&self, r: &impl Recognizer, state: S, logits: &mut [f32]) { + let n = self.root(); + let mut stack_buf = [state; 130]; + let mut stack_ptr = 1; + let defl_tok = self.vocab_size() as u32; + let mut p = self.node_child0(n); + let endp = self.next_node(n); while p < endp { - let n = &*p; + let n = &self.nodes[p]; 
let b = n.byte(); let rec = stack_buf[stack_ptr - 1]; if r.allowed(rec, b) { @@ -292,9 +288,9 @@ pub fn append_bias(trie: &TokTrie, r: &impl Recognizer, state: S, lo if n.subtree_size() == 1 { stack_ptr -= n.num_parents(); } - p = n.child0(); + p += 1; } else { - p = n.next(); + p += n.subtree_size(); stack_ptr -= n.num_parents() - 1; } } @@ -302,40 +298,6 @@ pub fn append_bias(trie: &TokTrie, r: &impl Recognizer, state: S, lo } } -fn walk_core(n: &TrieNode) -> u32 { - let mut sum = 0; - let mut stack_buf: [(*const TrieNode, *const TrieNode); 130] = [(0 as _, 0 as _); 130]; - let mut stack_ptr = 0; - unsafe { - stack_buf[stack_ptr] = (n.child0(), n.next()); - stack_ptr += 1; - loop { - stack_ptr -= 1; - let (mut p, mut endp) = stack_buf[stack_ptr]; - while p < endp { - let n = &*p; - p = n.next(); - sum += n.subtree_size() as u32; - if n.subtree_size() > 1 { - stack_buf[stack_ptr] = (p, endp); - stack_ptr += 1; - endp = p; - p = n.child0(); - } - } - if stack_ptr == 0 { - break; - } - } - } - sum -} - -pub fn walk(trie: &TokTrie) -> u32 { - let n = trie.root(); - walk_core(n) -} - #[repr(C)] pub struct TokenizerBin { magic: u32, From 865c6d85491324484a3b35e1b7cb3104d45d91df Mon Sep 17 00:00:00 2001 From: Michal Moskal Date: Fri, 6 Oct 2023 15:43:50 +0000 Subject: [PATCH 042/301] move byte-twiddling to 'bytes' from 'rx' --- gvm_abi/src/bytes.rs | 53 +++++++++++++++++++++++++++++++++ gvm_abi/src/lib.rs | 1 + gvm_abi/src/rx.rs | 67 +++--------------------------------------- gvm_abi/src/toktree.rs | 2 +- 4 files changed, 59 insertions(+), 64 deletions(-) create mode 100644 gvm_abi/src/bytes.rs diff --git a/gvm_abi/src/bytes.rs b/gvm_abi/src/bytes.rs new file mode 100644 index 00000000..64112cb6 --- /dev/null +++ b/gvm_abi/src/bytes.rs @@ -0,0 +1,53 @@ +use std::{mem::size_of, slice::from_raw_parts}; + +pub type TokenId = u32; + +#[repr(C)] +#[derive(Clone, PartialEq, Eq, Debug)] +pub struct TokRxInfo { + pub vocab_size: u32, + pub tok_eos: TokenId, +} + + +pub fn clone_vec_as_bytes(input: &Vec) -> Vec { + unsafe { + let byte_slice = from_raw_parts(input.as_ptr() as *const u8, input.len() * size_of::()); + byte_slice.to_vec() + } +} + +pub fn clone_as_bytes(input: &T) -> Vec { + unsafe { + let byte_slice = from_raw_parts(input as *const T as *const u8, size_of::()); + byte_slice.to_vec() + } +} + +pub fn box_from_bytes(bytes: &[u8]) -> Box { + if bytes.len() != size_of::() { + panic!("T: got {} bytes, needed {}", bytes.len(), size_of::()); + } + let mut t: Box = Box::new(unsafe { std::mem::zeroed() }); + unsafe { + std::ptr::copy_nonoverlapping(bytes.as_ptr(), &mut *t as *mut T as *mut u8, size_of::()); + } + t +} + +pub fn vec_from_bytes(bytes: &[u8]) -> Vec { + if bytes.len() % size_of::() != 0 { + panic!( + "vecT: got {} bytes, needed mult of {}", + bytes.len(), + size_of::() + ); + } + let num_elements = bytes.len() / size_of::(); + let mut result = Vec::with_capacity(num_elements); + unsafe { + result.set_len(num_elements); + std::ptr::copy_nonoverlapping(bytes.as_ptr(), result.as_mut_ptr() as *mut u8, bytes.len()); + } + result +} diff --git a/gvm_abi/src/lib.rs b/gvm_abi/src/lib.rs index 14f506ea..0727a1e2 100644 --- a/gvm_abi/src/lib.rs +++ b/gvm_abi/src/lib.rs @@ -1,3 +1,4 @@ +pub mod bytes; pub mod printing; pub mod recognizer; pub mod rx; diff --git a/gvm_abi/src/rx.rs b/gvm_abi/src/rx.rs index 1a052726..ccfd904e 100644 --- a/gvm_abi/src/rx.rs +++ b/gvm_abi/src/rx.rs @@ -1,10 +1,8 @@ -use std::{ - mem::{self, size_of}, - ptr, - slice::from_raw_parts, -}; +use std::{mem::size_of, 
slice::from_raw_parts}; -pub type TokenId = u32; +use crate::bytes::{clone_as_bytes, clone_vec_as_bytes, TokRxInfo}; + +pub type TokenId = crate::bytes::TokenId; pub type Transition = (StateOffset, TokenSetOffset); #[derive(Clone, Copy, PartialEq, Eq)] @@ -37,63 +35,6 @@ struct TokRxHeader { align: [u32; 0], } -#[repr(C)] -#[derive(Clone, PartialEq, Eq, Debug)] -pub struct TokRxInfo { - pub vocab_size: u32, - pub tok_eos: TokenId, -} - -pub fn clone_vec_as_bytes(input: &Vec) -> Vec { - unsafe { - let byte_slice = from_raw_parts(input.as_ptr() as *const u8, input.len() * size_of::()); - byte_slice.to_vec() - } -} - -pub fn clone_as_bytes(input: &T) -> Vec { - unsafe { - let byte_slice = from_raw_parts(input as *const T as *const u8, size_of::()); - byte_slice.to_vec() - } -} - -pub fn box_from_bytes(bytes: &[u8]) -> Box { - if bytes.len() != mem::size_of::() { - panic!( - "T: got {} bytes, needed {}", - bytes.len(), - mem::size_of::() - ); - } - let mut t: Box = Box::new(unsafe { mem::zeroed() }); - unsafe { - ptr::copy_nonoverlapping( - bytes.as_ptr(), - &mut *t as *mut T as *mut u8, - mem::size_of::(), - ); - } - t -} - -pub fn vec_from_bytes(bytes: &[u8]) -> Vec { - if bytes.len() % mem::size_of::() != 0 { - panic!( - "vecT: got {} bytes, needed mult of {}", - bytes.len(), - mem::size_of::() - ); - } - let num_elements = bytes.len() / mem::size_of::(); - let mut result = Vec::with_capacity(num_elements); - unsafe { - result.set_len(num_elements); - std::ptr::copy_nonoverlapping(bytes.as_ptr(), result.as_mut_ptr() as *mut u8, bytes.len()); - } - result -} - impl TokRxHeader { pub const MAGIC: u32 = 0x6623f10b; pub const SIZE: u32 = size_of::() as u32; diff --git a/gvm_abi/src/toktree.rs b/gvm_abi/src/toktree.rs index 51d96d0d..97cb4594 100644 --- a/gvm_abi/src/toktree.rs +++ b/gvm_abi/src/toktree.rs @@ -1,7 +1,7 @@ // use 8:24 encoding - num_ch:tok_id (ch_byte:ch_off)* - 8 bytes per tree node // special case num_ch=0xff -> num_ch=0x100 -use crate::rx::{ +use crate::bytes::{ box_from_bytes, clone_as_bytes, clone_vec_as_bytes, vec_from_bytes, TokRxInfo, TokenId, }; From 3a7d83a34a0bf5870ad848a26bb64b3467b9fa46 Mon Sep 17 00:00:00 2001 From: Michal Moskal Date: Fri, 6 Oct 2023 15:51:48 +0000 Subject: [PATCH 043/301] node_children() iterator --- gvm_abi/src/toktree.rs | 55 ++++++++++++++++++++++++++++-------------- 1 file changed, 37 insertions(+), 18 deletions(-) diff --git a/gvm_abi/src/toktree.rs b/gvm_abi/src/toktree.rs index 97cb4594..d9e51351 100644 --- a/gvm_abi/src/toktree.rs +++ b/gvm_abi/src/toktree.rs @@ -130,10 +130,6 @@ impl TokTrie { off } - fn node_child0(&self, n: &TrieNode) -> usize { - return self.node_offset(n) + 1; - } - fn next_node(&self, n: &TrieNode) -> usize { return self.node_offset(n) + n.subtree_size(); } @@ -183,11 +179,8 @@ impl TokTrie { } let endp = self.next_node(n); assert!(endp <= ep); - let mut p = self.node_child0(n); - while p < endp { - let n = &self.nodes[p]; - p = self.next_node(n); - self.validate_node(n, endp, used) + for child in self.node_children(n) { + self.validate_node(child, endp, used); } } @@ -248,18 +241,23 @@ impl TokTrie { } pub fn child_at_byte<'a>(&'a self, n: &'a TrieNode, byte: u8) -> Option<&'a TrieNode> { - let mut p = self.node_child0(n); - let endp = self.next_node(n); - while p < endp { - let n = &self.nodes[p]; - if n.byte() == byte { - return Some(n); + for child in self.node_children(n) { + if child.byte() == byte { + return Some(child); } - p = self.next_node(n); } None } + pub fn node_children(&self, n: &TrieNode) -> 
NodeChildren { + let off = self.node_offset(n); + NodeChildren { + trie: self, + current_offset: off + 1, + end_offset: off + n.subtree_size(), + } + } + pub fn child_at_bytes<'a>(&'a self, mut n: &'a TrieNode, bytes: &[u8]) -> Option<&'a TrieNode> { for &byte in bytes { n = match self.child_at_byte(n, byte) { @@ -275,8 +273,9 @@ impl TokTrie { let mut stack_buf = [state; 130]; let mut stack_ptr = 1; let defl_tok = self.vocab_size() as u32; - let mut p = self.node_child0(n); - let endp = self.next_node(n); + let off = self.node_offset(n); + let mut p = off + 1; + let endp = off + n.subtree_size(); while p < endp { let n = &self.nodes[p]; let b = n.byte(); @@ -298,6 +297,26 @@ impl TokTrie { } } +pub struct NodeChildren<'a> { + trie: &'a TokTrie, + current_offset: usize, + end_offset: usize, +} + +impl<'a> Iterator for NodeChildren<'a> { + type Item = &'a TrieNode; + + fn next(&mut self) -> Option { + if self.current_offset < self.end_offset { + let node = &self.trie.nodes[self.current_offset]; + self.current_offset += node.subtree_size(); + Some(node) + } else { + None + } + } +} + #[repr(C)] pub struct TokenizerBin { magic: u32, From 3087bb03501a66055a76fc56bd09262d59d4e2b7 Mon Sep 17 00:00:00 2001 From: Michal Moskal Date: Fri, 6 Oct 2023 17:05:47 +0000 Subject: [PATCH 044/301] re-jig interfaces --- gvm_abi/src/recognizer.rs | 89 ++++++++++++++++++++++++++------------- gvm_abi/src/toktree.rs | 23 +++++----- 2 files changed, 70 insertions(+), 42 deletions(-) diff --git a/gvm_abi/src/recognizer.rs b/gvm_abi/src/recognizer.rs index 5c401a50..f0195fb1 100644 --- a/gvm_abi/src/recognizer.rs +++ b/gvm_abi/src/recognizer.rs @@ -1,24 +1,13 @@ -use std::rc::Rc; +use std::{cell::RefCell, rc::Rc}; use crate::{ toktree::{Recognizer, TokTrie}, wprintln, GuidanceVm, GuidanceVmHelper, }; -#[inline(never)] -pub fn compute_bias( - trie: &TokTrie, - rec: &impl Recognizer, - state: S, - logits: &mut [f32], -) { - logits.iter_mut().for_each(|x| *x = -100.0); - trie.append_bias(rec, state, logits); -} - pub struct LenExcluder {} -impl Recognizer for LenExcluder { +impl FunctionalRecognizer for LenExcluder { fn initial(&self) -> u32 { 0 } @@ -34,37 +23,32 @@ impl Recognizer for LenExcluder { } } -pub struct GvmRecognizer> { +pub struct GvmRecognizer { pub helper: GuidanceVmHelper, - pub rec: Rc>, + pub rec: RefCell, pub trie: Rc>, - pub state: S, } -impl> GvmRecognizer { - pub fn from_recognizer(trie: Rc>, rec: Rc>) -> Self { - let state = rec.as_ref().initial(); +impl GvmRecognizer { + pub fn from_recognizer(trie: Rc>, rec: R) -> Self { GvmRecognizer { helper: GuidanceVmHelper::new(), - rec, - state, + rec: RefCell::new(rec), trie, } } fn compute(&mut self) { - let trie = (*self.trie).as_ref(); - let rec = (*self.rec).as_ref(); - compute_bias(trie, rec, self.state, &mut self.helper.logit_biases); + let rec = &mut *self.rec.get_mut(); + self.trie.compute_bias(rec, &mut self.helper.logit_biases); } } -impl> GuidanceVm for GvmRecognizer { +impl GuidanceVm for GvmRecognizer { fn gvm_clone(&mut self) -> Self { GvmRecognizer { helper: self.helper.clone(), - rec: self.rec.clone(), - state: self.state, + rec: RefCell::new((*self.rec.borrow()).clone()), trie: self.trie.clone(), } } @@ -72,15 +56,16 @@ impl> GuidanceVm for GvmRecognizer { fn gvm_process_prompt(&mut self) { wprintln!("prompt, {} tokens", self.helper.prompt_length); // the regex doesn't care about the prompt - self.state = self.rec.initial(); self.compute(); } fn gvm_append_token(&mut self, token: u32) { // wprintln!("xapp {:?} {} {}", self as *const _, 
token, self.state.off); let bytes = self.trie.token(token); + + let rec = &mut *self.rec.get_mut(); for b in bytes { - self.state = self.rec.append(self.state, *b); + rec.push_byte(*b) } // save the token, just in case @@ -90,3 +75,49 @@ impl> GuidanceVm for GvmRecognizer { self.compute(); } } + +pub trait FunctionalRecognizer { + fn initial(&self) -> S; + fn append(&self, state: S, byte: u8) -> S; + fn allowed(&self, state: S, byte: u8) -> bool; +} + +#[derive(Clone)] +pub struct StackRecognizer> { + rec: R, + stack: Vec, + stack_ptr: usize, +} + +impl> StackRecognizer { + pub fn from(rec: R) -> Self { + let stack = vec![rec.initial(); 130]; + StackRecognizer { + rec, + stack, + stack_ptr: 0, + } + } + + pub fn reset(&mut self) { + self.stack_ptr = 0; + self.stack[0] = self.rec.initial(); + } +} + +impl> Recognizer for StackRecognizer { + fn push_byte(&mut self, byte: u8) { + let state = self.stack[self.stack_ptr]; + let state = self.rec.append(state, byte); + self.stack_ptr += 1; + self.stack[self.stack_ptr] = state; + } + + fn pop_bytes(&mut self, num: usize) { + self.stack_ptr -= num; + } + + fn byte_allowed(&mut self, byte: u8) -> bool { + self.rec.allowed(self.stack[self.stack_ptr], byte) + } +} diff --git a/gvm_abi/src/toktree.rs b/gvm_abi/src/toktree.rs index d9e51351..749c5aa2 100644 --- a/gvm_abi/src/toktree.rs +++ b/gvm_abi/src/toktree.rs @@ -5,10 +5,10 @@ use crate::bytes::{ box_from_bytes, clone_as_bytes, clone_vec_as_bytes, vec_from_bytes, TokRxInfo, TokenId, }; -pub trait Recognizer { - fn initial(&self) -> S; - fn append(&self, state: S, byte: u8) -> S; - fn allowed(&self, state: S, byte: u8) -> bool; +pub trait Recognizer { + fn push_byte(&mut self, byte: u8); + fn pop_bytes(&mut self, num: usize); + fn byte_allowed(&mut self, byte: u8) -> bool; } pub struct TokTrie { @@ -268,10 +268,9 @@ impl TokTrie { Some(n) } - pub fn append_bias(&self, r: &impl Recognizer, state: S, logits: &mut [f32]) { + pub fn compute_bias(&self, r: &mut impl Recognizer, logits: &mut [f32]) { + logits.iter_mut().for_each(|x| *x = -100.0); let n = self.root(); - let mut stack_buf = [state; 130]; - let mut stack_ptr = 1; let defl_tok = self.vocab_size() as u32; let off = self.node_offset(n); let mut p = off + 1; @@ -279,18 +278,16 @@ impl TokTrie { while p < endp { let n = &self.nodes[p]; let b = n.byte(); - let rec = stack_buf[stack_ptr - 1]; - if r.allowed(rec, b) { + if r.byte_allowed(b) { logits[n.token_id().unwrap_or(defl_tok) as usize] = 0.0; - stack_buf[stack_ptr] = r.append(rec, b); - stack_ptr += 1; + r.push_byte(b); if n.subtree_size() == 1 { - stack_ptr -= n.num_parents(); + r.pop_bytes(n.num_parents()); } p += 1; } else { p += n.subtree_size(); - stack_ptr -= n.num_parents() - 1; + r.pop_bytes(n.num_parents() - 1); } } //panic!("st: {}", stack_ptr); From ce9c7decbd4e4e18e0c7fdcea1fa37ca29b6c52d Mon Sep 17 00:00:00 2001 From: Michal Moskal Date: Fri, 6 Oct 2023 17:14:40 +0000 Subject: [PATCH 045/301] drop refcell --- gvm_abi/src/recognizer.rs | 19 ++++++++++--------- 1 file changed, 10 insertions(+), 9 deletions(-) diff --git a/gvm_abi/src/recognizer.rs b/gvm_abi/src/recognizer.rs index f0195fb1..45d2bc89 100644 --- a/gvm_abi/src/recognizer.rs +++ b/gvm_abi/src/recognizer.rs @@ -1,4 +1,4 @@ -use std::{cell::RefCell, rc::Rc}; +use std::rc::Rc; use crate::{ toktree::{Recognizer, TokTrie}, @@ -25,7 +25,7 @@ impl FunctionalRecognizer for LenExcluder { pub struct GvmRecognizer { pub helper: GuidanceVmHelper, - pub rec: RefCell, + pub rec: R, pub trie: Rc>, } @@ -33,14 +33,14 @@ impl 
GvmRecognizer { pub fn from_recognizer(trie: Rc>, rec: R) -> Self { GvmRecognizer { helper: GuidanceVmHelper::new(), - rec: RefCell::new(rec), + rec, trie, } } fn compute(&mut self) { - let rec = &mut *self.rec.get_mut(); - self.trie.compute_bias(rec, &mut self.helper.logit_biases); + self.trie + .compute_bias(&mut self.rec, &mut self.helper.logit_biases); } } @@ -48,7 +48,7 @@ impl GuidanceVm for GvmRecognizer { fn gvm_clone(&mut self) -> Self { GvmRecognizer { helper: self.helper.clone(), - rec: RefCell::new((*self.rec.borrow()).clone()), + rec: self.rec.clone(), trie: self.trie.clone(), } } @@ -62,10 +62,8 @@ impl GuidanceVm for GvmRecognizer { fn gvm_append_token(&mut self, token: u32) { // wprintln!("xapp {:?} {} {}", self as *const _, token, self.state.off); let bytes = self.trie.token(token); - - let rec = &mut *self.rec.get_mut(); for b in bytes { - rec.push_byte(*b) + self.rec.push_byte(*b) } // save the token, just in case @@ -106,6 +104,7 @@ impl> StackRecognizer { } impl> Recognizer for StackRecognizer { + #[inline(always)] fn push_byte(&mut self, byte: u8) { let state = self.stack[self.stack_ptr]; let state = self.rec.append(state, byte); @@ -113,10 +112,12 @@ impl> Recognizer for StackRecognizer { self.stack[self.stack_ptr] = state; } + #[inline(always)] fn pop_bytes(&mut self, num: usize) { self.stack_ptr -= num; } + #[inline(always)] fn byte_allowed(&mut self, byte: u8) -> bool { self.rec.allowed(self.stack[self.stack_ptr], byte) } From d9439a2ef66ab22660bdda772efc3ce2a2f7b11f Mon Sep 17 00:00:00 2001 From: Michal Moskal Date: Fri, 6 Oct 2023 17:37:47 +0000 Subject: [PATCH 046/301] perf fix --- gvm_abi/src/toktree.rs | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/gvm_abi/src/toktree.rs b/gvm_abi/src/toktree.rs index 749c5aa2..cdcddcca 100644 --- a/gvm_abi/src/toktree.rs +++ b/gvm_abi/src/toktree.rs @@ -281,9 +281,12 @@ impl TokTrie { if r.byte_allowed(b) { logits[n.token_id().unwrap_or(defl_tok) as usize] = 0.0; r.push_byte(b); - if n.subtree_size() == 1 { - r.pop_bytes(n.num_parents()); - } + // note that the inner-if is much faster than outer, due to branch misprediction + r.pop_bytes(if n.subtree_size() == 1 { + n.num_parents() + } else { + 0 + }); p += 1; } else { p += n.subtree_size(); From 73156ff0c6fd3436399aa55733f7a963daf042a5 Mon Sep 17 00:00:00 2001 From: Michal Moskal Date: Fri, 6 Oct 2023 18:13:49 +0000 Subject: [PATCH 047/301] perf work --- gvm_abi/src/toktree.rs | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/gvm_abi/src/toktree.rs b/gvm_abi/src/toktree.rs index cdcddcca..73875a10 100644 --- a/gvm_abi/src/toktree.rs +++ b/gvm_abi/src/toktree.rs @@ -280,13 +280,21 @@ impl TokTrie { let b = n.byte(); if r.byte_allowed(b) { logits[n.token_id().unwrap_or(defl_tok) as usize] = 0.0; + + // This is slower due to branch mis-prediction: + // if n.subtree_size() == 1 { + // r.pop_bytes(n.num_parents() - 1) + // } else { + // r.push_byte(b) + // } + r.push_byte(b); - // note that the inner-if is much faster than outer, due to branch misprediction r.pop_bytes(if n.subtree_size() == 1 { n.num_parents() } else { 0 }); + p += 1; } else { p += n.subtree_size(); From 83d2689e5f8237cf3067ad2863fd50b0b906c9b5 Mon Sep 17 00:00:00 2001 From: Michal Moskal Date: Fri, 6 Oct 2023 21:10:55 +0000 Subject: [PATCH 048/301] add testing harness --- gvm_abi/src/lib.rs | 27 +++++++++++++++++++++++++++ gvm_abi/src/recognizer.rs | 6 +++++- gvm_abi/src/rxvm.rs | 4 ++++ 3 files changed, 36 insertions(+), 1 deletion(-) diff --git 
a/gvm_abi/src/lib.rs b/gvm_abi/src/lib.rs index 0727a1e2..b35da53b 100644 --- a/gvm_abi/src/lib.rs +++ b/gvm_abi/src/lib.rs @@ -1,3 +1,5 @@ +use bytes::TokenId; + pub mod bytes; pub mod printing; pub mod recognizer; @@ -66,6 +68,8 @@ pub trait GuidanceVm { fn gvm_process_prompt(&mut self); /// On return, self.helper.logit_biases are supposed to be updated. fn gvm_append_token(&mut self, token: u32); + // Used in testing. + fn get_helper(&mut self) -> &mut GuidanceVmHelper; } #[macro_export] @@ -131,3 +135,26 @@ macro_rules! wprint { $crate::printing::_print(&format!($($arg)*)); }}; } + +pub fn gvm_harness(gvm: &mut impl GuidanceVm, vocab_size: usize, prompt: &[TokenId]) { + let logits = unsafe { + std::slice::from_raw_parts_mut( + gvm.get_helper() + .gvm_get_logit_bias_buffer(vocab_size as u32), + vocab_size, + ) + }; + let prompt_buf = unsafe { + std::slice::from_raw_parts_mut( + gvm.get_helper().gvm_get_prompt_buffer(prompt.len() as u32), + prompt.len(), + ) + }; + prompt_buf.copy_from_slice(&prompt); + gvm.gvm_process_prompt(); + let p0 = logits.iter().filter(|x| **x > -50.0).count(); + wprintln!("res0: {}", p0); + gvm.gvm_append_token(13); + let p1 = logits.iter().filter(|x| **x > -50.0).count(); + wprintln!("res1: {}", p1); +} diff --git a/gvm_abi/src/recognizer.rs b/gvm_abi/src/recognizer.rs index 45d2bc89..dcd2dddc 100644 --- a/gvm_abi/src/recognizer.rs +++ b/gvm_abi/src/recognizer.rs @@ -60,8 +60,8 @@ impl GuidanceVm for GvmRecognizer { } fn gvm_append_token(&mut self, token: u32) { - // wprintln!("xapp {:?} {} {}", self as *const _, token, self.state.off); let bytes = self.trie.token(token); + wprintln!("xapp {} {:?}", token, bytes); for b in bytes { self.rec.push_byte(*b) } @@ -72,6 +72,10 @@ impl GuidanceVm for GvmRecognizer { self.compute(); } + + fn get_helper(&mut self) -> &mut GuidanceVmHelper { + &mut self.helper + } } pub trait FunctionalRecognizer { diff --git a/gvm_abi/src/rxvm.rs b/gvm_abi/src/rxvm.rs index 057bc41d..a5dfea11 100644 --- a/gvm_abi/src/rxvm.rs +++ b/gvm_abi/src/rxvm.rs @@ -49,4 +49,8 @@ impl GuidanceVm for RxGvm { wprintln!("{} -> {}", self.state.off, r.state.off); r } + + fn get_helper(&mut self) -> &mut GuidanceVmHelper { + &mut self.helper + } } From 5ffd40d093eda2020406f16328b6e66170abf901 Mon Sep 17 00:00:00 2001 From: Michal Moskal Date: Fri, 6 Oct 2023 22:34:52 +0000 Subject: [PATCH 049/301] it constrains! 
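The trie walk and the recognizer now share a simple stack protocol:
compute_bias() calls push_byte()/pop_bytes() as it descends and backtracks,
trie_finished() checks that exactly one state remains, and once a token is
actually sampled, collapse() commits the top state as the new stack bottom.
A minimal sketch of the expected call sequence (illustrative only, not part
of the diff below; it assumes LenExcluder, the toy recognizer from
recognizer.rs):

    // assumes: use gvm_abi::recognizer::{LenExcluder, StackRecognizer};
    //          use gvm_abi::toktree::Recognizer;
    let mut rec = StackRecognizer::from(LenExcluder {});
    if rec.byte_allowed(b'a') {
        rec.push_byte(b'a'); // descend one trie edge
        rec.pop_bytes(1);    // backtrack to the parent node
    }
    rec.trie_finished();     // the walk must leave a single state
    rec.push_byte(b'a');     // then consume a sampled token's bytes
    rec.collapse();          // and commit them as the new base state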
--- gvm_abi/src/recognizer.rs | 18 +++++++++++++++--- gvm_abi/src/toktree.rs | 6 ++++-- 2 files changed, 19 insertions(+), 5 deletions(-) diff --git a/gvm_abi/src/recognizer.rs b/gvm_abi/src/recognizer.rs index dcd2dddc..ecacd140 100644 --- a/gvm_abi/src/recognizer.rs +++ b/gvm_abi/src/recognizer.rs @@ -1,4 +1,4 @@ -use std::rc::Rc; +use std::{fmt::Debug, rc::Rc}; use crate::{ toktree::{Recognizer, TokTrie}, @@ -39,6 +39,7 @@ impl GvmRecognizer { } fn compute(&mut self) { + // wprintln!("compute"); self.trie .compute_bias(&mut self.rec, &mut self.helper.logit_biases); } @@ -61,10 +62,11 @@ impl GuidanceVm for GvmRecognizer { fn gvm_append_token(&mut self, token: u32) { let bytes = self.trie.token(token); - wprintln!("xapp {} {:?}", token, bytes); + // wprintln!("xapp {} {:?}", token, bytes); for b in bytes { self.rec.push_byte(*b) } + self.rec.collapse(); // save the token, just in case let toks = &mut self.helper.tokens; @@ -107,7 +109,7 @@ impl> StackRecognizer { } } -impl> Recognizer for StackRecognizer { +impl> Recognizer for StackRecognizer { #[inline(always)] fn push_byte(&mut self, byte: u8) { let state = self.stack[self.stack_ptr]; @@ -125,4 +127,14 @@ impl> Recognizer for StackRecognizer { fn byte_allowed(&mut self, byte: u8) -> bool { self.rec.allowed(self.stack[self.stack_ptr], byte) } + + fn trie_finished(&mut self) { + // wprintln!("{:?}", &self.stack[0..=self.stack_ptr]); + assert!(self.stack_ptr == 0); + } + + fn collapse(&mut self) { + self.stack[0] = self.stack[self.stack_ptr]; + self.stack_ptr = 0; + } } diff --git a/gvm_abi/src/toktree.rs b/gvm_abi/src/toktree.rs index 73875a10..53bee265 100644 --- a/gvm_abi/src/toktree.rs +++ b/gvm_abi/src/toktree.rs @@ -9,6 +9,8 @@ pub trait Recognizer { fn push_byte(&mut self, byte: u8); fn pop_bytes(&mut self, num: usize); fn byte_allowed(&mut self, byte: u8) -> bool; + fn trie_finished(&mut self); + fn collapse(&mut self); } pub struct TokTrie { @@ -111,7 +113,7 @@ impl TokTrie { token_data.extend_from_slice(word); } let mut nodes = Vec::new(); - trie.serialize(&mut nodes, 1); + trie.serialize(&mut nodes, 0); let r = TokTrie { info: info.clone(), token_offsets, @@ -301,7 +303,7 @@ impl TokTrie { r.pop_bytes(n.num_parents() - 1); } } - //panic!("st: {}", stack_ptr); + r.trie_finished(); } } From 90f05067af9810c8588b3d1e5e899529a2a449ec Mon Sep 17 00:00:00 2001 From: Michal Moskal Date: Fri, 6 Oct 2023 22:50:00 +0000 Subject: [PATCH 050/301] deal with special tokens (incl eos) --- gvm_abi/src/recognizer.rs | 21 +++++++++++++++++---- gvm_abi/src/toktree.rs | 24 ++++++++++++++++++++++++ 2 files changed, 41 insertions(+), 4 deletions(-) diff --git a/gvm_abi/src/recognizer.rs b/gvm_abi/src/recognizer.rs index ecacd140..4f20d6e3 100644 --- a/gvm_abi/src/recognizer.rs +++ b/gvm_abi/src/recognizer.rs @@ -1,7 +1,7 @@ use std::{fmt::Debug, rc::Rc}; use crate::{ - toktree::{Recognizer, TokTrie}, + toktree::{Recognizer, SpecialToken, TokTrie}, wprintln, GuidanceVm, GuidanceVmHelper, }; @@ -18,9 +18,17 @@ impl FunctionalRecognizer for LenExcluder { } #[inline(never)] - fn allowed(&self, state: u32, byte: u8) -> bool { + fn byte_allowed(&self, state: u32, byte: u8) -> bool { byte != (('z' as u32 + state) & 0xff) as u8 } + + #[inline(never)] + fn special_allowed(&self, state: u32, tok: SpecialToken) -> bool { + match tok { + SpecialToken::EndOfSentence => state < 10, + _ => false, + } + } } pub struct GvmRecognizer { @@ -83,7 +91,8 @@ impl GuidanceVm for GvmRecognizer { pub trait FunctionalRecognizer { fn initial(&self) -> S; fn append(&self, 
state: S, byte: u8) -> S; - fn allowed(&self, state: S, byte: u8) -> bool; + fn byte_allowed(&self, state: S, byte: u8) -> bool; + fn special_allowed(&self, state: S, tok: SpecialToken) -> bool; } #[derive(Clone)] @@ -125,7 +134,7 @@ impl> Recognizer for StackRecognizer #[inline(always)] fn byte_allowed(&mut self, byte: u8) -> bool { - self.rec.allowed(self.stack[self.stack_ptr], byte) + self.rec.byte_allowed(self.stack[self.stack_ptr], byte) } fn trie_finished(&mut self) { @@ -137,4 +146,8 @@ impl> Recognizer for StackRecognizer self.stack[0] = self.stack[self.stack_ptr]; self.stack_ptr = 0; } + + fn special_allowed(&mut self, tok: SpecialToken) -> bool { + self.rec.special_allowed(self.stack[self.stack_ptr], tok) + } } diff --git a/gvm_abi/src/toktree.rs b/gvm_abi/src/toktree.rs index 53bee265..09ec46d6 100644 --- a/gvm_abi/src/toktree.rs +++ b/gvm_abi/src/toktree.rs @@ -5,10 +5,20 @@ use crate::bytes::{ box_from_bytes, clone_as_bytes, clone_vec_as_bytes, vec_from_bytes, TokRxInfo, TokenId, }; +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +pub enum SpecialToken { + Unknown, + Padding, + Separator, + BeginningOfSentence, + EndOfSentence, +} + pub trait Recognizer { fn push_byte(&mut self, byte: u8); fn pop_bytes(&mut self, num: usize); fn byte_allowed(&mut self, byte: u8) -> bool; + fn special_allowed(&mut self, tok: SpecialToken) -> bool; fn trie_finished(&mut self); fn collapse(&mut self); } @@ -140,6 +150,13 @@ impl TokTrie { &self.info } + pub fn special_token(&self, tok: SpecialToken) -> TokenId { + match tok { + SpecialToken::EndOfSentence => self.info.tok_eos, + _ => todo!(), + } + } + pub fn vocab_size(&self) -> usize { self.info.vocab_size as usize } @@ -272,6 +289,13 @@ impl TokTrie { pub fn compute_bias(&self, r: &mut impl Recognizer, logits: &mut [f32]) { logits.iter_mut().for_each(|x| *x = -100.0); + + for tok in vec![SpecialToken::EndOfSentence] { + if r.special_allowed(tok) { + logits[self.special_token(tok) as usize] = 0.0; + } + } + let n = self.root(); let defl_tok = self.vocab_size() as u32; let off = self.node_offset(n); From e666e378dfdf9102437d8e8d47f6c66911c0765c Mon Sep 17 00:00:00 2001 From: Michal Moskal Date: Fri, 6 Oct 2023 22:56:44 +0000 Subject: [PATCH 051/301] add some docs --- gvm_abi/src/recognizer.rs | 4 ++++ gvm_abi/src/toktree.rs | 9 ++++++++- 2 files changed, 12 insertions(+), 1 deletion(-) diff --git a/gvm_abi/src/recognizer.rs b/gvm_abi/src/recognizer.rs index 4f20d6e3..ce3861a2 100644 --- a/gvm_abi/src/recognizer.rs +++ b/gvm_abi/src/recognizer.rs @@ -89,9 +89,13 @@ impl GuidanceVm for GvmRecognizer { } pub trait FunctionalRecognizer { + /// Initial state fn initial(&self) -> S; + /// Extend the recognizer with given byte. fn append(&self, state: S, byte: u8) -> S; + /// Check if given byte is allowed in given state. fn byte_allowed(&self, state: S, byte: u8) -> bool; + /// Check if given special token is allowed in given state. 
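+    /// (e.g. whether `SpecialToken::EndOfSentence` may be produced, ending generation)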
fn special_allowed(&self, state: S, tok: SpecialToken) -> bool;
 }

diff --git a/gvm_abi/src/toktree.rs b/gvm_abi/src/toktree.rs
index 09ec46d6..379ad354 100644
--- a/gvm_abi/src/toktree.rs
+++ b/gvm_abi/src/toktree.rs
@@ -15,12 +15,19 @@ pub enum SpecialToken {
 }
 
 pub trait Recognizer {
+    /// stack.push(X) where stack.top() transitions via byte to X
     fn push_byte(&mut self, byte: u8);
+    /// for _ in 0..num { stack.pop() }
     fn pop_bytes(&mut self, num: usize);
+    /// X = stack.top(); stack.empty(); stack.push(X)
+    fn collapse(&mut self);
+    /// check if stack.top() transitions via byte to a viable state
     fn byte_allowed(&mut self, byte: u8) -> bool;
+    /// check if stack.top() transitions via tok to a viable state
     fn special_allowed(&mut self, tok: SpecialToken) -> bool;
+    /// Called when iteration over the trie is finished
+    /// Stack has exactly one element then.
     fn trie_finished(&mut self);
-    fn collapse(&mut self);
 }
 
 pub struct TokTrie {

From 3043273ad04deee2fd7fa66d5cc538e1ef3302b3 Mon Sep 17 00:00:00 2001
From: Michal Moskal
Date: Tue, 10 Oct 2023 21:25:18 +0000
Subject: [PATCH 052/301] add docs

---
 gvm_abi/src/toktree.rs | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/gvm_abi/src/toktree.rs b/gvm_abi/src/toktree.rs
index 379ad354..4a68bf98 100644
--- a/gvm_abi/src/toktree.rs
+++ b/gvm_abi/src/toktree.rs
@@ -15,7 +15,7 @@ pub enum SpecialToken {
 }
 
 pub trait Recognizer {
-    /// stack.push(X) where stack.top() transitions via byte to X
+    /// If `stack.top()` transitions via `byte` to `X`, execute `stack.push(X)`.
     fn push_byte(&mut self, byte: u8);

From 47e88aa46644adf02bfb8e7872d2e3cbd09e713a Mon Sep 17 00:00:00 2001
From: Michal Moskal
Date: Wed, 11 Oct 2023 21:13:54 +0000
Subject: [PATCH 053/301] rename GVM -> AICI

---
 gvm_abi/Cargo.lock        |  2 +-
 gvm_abi/Cargo.toml        |  4 +--
 gvm_abi/src/gvm_iface.h   | 38 ++++++++++++++--------------
 gvm_abi/src/lib.rs        | 52 +++++++++++++++++++--------------------
 gvm_abi/src/printing.rs   |  8 +++---
 gvm_abi/src/recognizer.rs | 24 +++++++++---------
 gvm_abi/src/rxvm.rs       | 24 +++++++++---------
 gvm_abi/src/toktree.rs    |  6 ++---
 8 files changed, 79 insertions(+), 79 deletions(-)

diff --git a/gvm_abi/Cargo.lock b/gvm_abi/Cargo.lock
index d95b4cac..3e8bff87 100644
--- a/gvm_abi/Cargo.lock
+++ b/gvm_abi/Cargo.lock
@@ -3,5 +3,5 @@
 version = 3
 
 [[package]]
-name = "gvm_abi"
+name = "aici_abi"
 version = "0.1.0"

diff --git a/gvm_abi/Cargo.toml b/gvm_abi/Cargo.toml
index ed3df9ec..edbc774f 100644
--- a/gvm_abi/Cargo.toml
+++ b/gvm_abi/Cargo.toml
@@ -1,9 +1,9 @@
 [package]
-name = "gvm_abi"
+name = "aici_abi"
 version = "0.1.0"
 edition = "2021"
 
 [lib]
-name = "gvm_abi"
+name = "aici_abi"
 
 [dependencies]

diff --git a/gvm_abi/src/gvm_iface.h b/gvm_abi/src/gvm_iface.h
index 2b0cb7f1..500ae06e 100644
--- a/gvm_abi/src/gvm_iface.h
+++ b/gvm_abi/src/gvm_iface.h
@@ -7,50 +7,50 @@ typedef uint32_t token_t;
 
 // Called first, after instantiating WASM module.
-void gvm_init(void);
+void aici_init(void);
 
-// Called once per module, to get a GVM for a specific query
+// Called once per module, to get a AICI for a specific query
-Gvm *gvm_create(void);
+Aici *aici_create(void);
 
-// If a query is split into several (eg., during beam-search, or when returning several results)
-// this is called to get GVM for the sub-query.
-Gvm *gvm_clone(Gvm *parent);
+// this is called to get AICI for the sub-query.
+Aici *aici_clone(Aici *parent);
 
-// These two are called after gvm_create() and gvm_clone() on the fresh GVM.
+// These two are called after aici_create() and aici_clone() on the fresh AICI.
 // They should return the buffers that the WASM code has to allocated and keep around
-// until relevant gvm_free().
+// until relevant aici_free().
 
 // Return buffer where the prompt will be written. `size` is number of tokens in the prompt.
-token_t *gvm_get_prompt_buffer(Gvm *gvm, uint32_t size);
+token_t *aici_get_prompt_buffer(Aici *aici, uint32_t size);
 
 // Return the buffer where the WASM code will write logit biases after
-// gvm_process_prompt() and gvm_append_token().
+// aici_process_prompt() and aici_append_token().
 // Size of number of biases (which equals size of the vocabulary).
-float *gvm_get_logit_bias_buffer(Gvm *gvm, uint32_t size);
+float *aici_get_logit_bias_buffer(Aici *aici, uint32_t size);
 
-// This called once, when the GVM should process the prompt in its buffer.
+// This is called once, when the AICI should process the prompt in its buffer.
 // It should set the values in logit bias buffer.
-void gvm_process_prompt(Gvm *gvm);
+void aici_process_prompt(Aici *aici);
 // The logical type (if WASM would allow such things) of this function is:
-// float[vocab_size] gvm_process_prompt(Gvm *gvm, token_t[] prompt);
+// float[vocab_size] aici_process_prompt(Aici *aici, token_t[] prompt);
 
 // This is called after a token is sampled.
 // It should set the values in logit bias buffer.
-void gvm_append_token(Gvm *gvm, token_t tok);
+void aici_append_token(Aici *aici, token_t tok);
 // The logical type (if WASM would allow such things) of this function is:
-// float[vocab_size] gvm_append_token(Gvm *gvm, token_t tok);
+// float[vocab_size] aici_append_token(Aici *aici, token_t tok);
 
-// This is called for GVMs that no longer needed (eg. because generation completed,
+// This is called for AICIs that no longer needed (eg. because generation completed,
 // or beam-search branch was cut).
-void gvm_free(Gvm *gvm);
+void aici_free(Aici *aici);
 
 //
 // This interface is available to the WASM binary
 //
 
 // Log a string.
-void gvm_host_print(const uint8_t *ptr, uint32_t size);
+void aici_host_print(const uint8_t *ptr, uint32_t size);
 
 // Read binary representation of TokTrie.
 // Always returns the size of the trie, will write up to `size` bytes to `dst`.
-uint32_t gvm_host_read_token_trie(uint8_t *dst, uint32_t size);
+uint32_t aici_host_read_token_trie(uint8_t *dst, uint32_t size);

diff --git a/gvm_abi/src/lib.rs b/gvm_abi/src/lib.rs
index b35da53b..c6052452 100644
--- a/gvm_abi/src/lib.rs
+++ b/gvm_abi/src/lib.rs
@@ -32,68 +32,68 @@ macro_rules! 
expose { } #[derive(Clone)] -pub struct GuidanceVmHelper { +pub struct AiciVmHelper { pub tokens: Vec, pub prompt_length: usize, pub logit_biases: Vec, } -// gvm_* are exposed to C in both GuidanceVm and GuidanceVmHelper -impl GuidanceVmHelper { +// aici_* are exposed to C in both AiciVm and AiciVmHelper +impl AiciVmHelper { pub fn new() -> Self { - GuidanceVmHelper { + AiciVmHelper { tokens: Vec::new(), prompt_length: 0, logit_biases: Vec::new(), } } - pub fn gvm_get_logit_bias_buffer(&mut self, size: u32) -> *mut f32 { + pub fn aici_get_logit_bias_buffer(&mut self, size: u32) -> *mut f32 { // we keep one more logit at the end as a placeholder to avoid branching in // the inner loop of append_bias self.logit_biases.resize((size + 1) as usize, 0.0); self.logit_biases.as_mut_ptr() } - pub fn gvm_get_prompt_buffer(&mut self, size: u32) -> *mut u32 { + pub fn aici_get_prompt_buffer(&mut self, size: u32) -> *mut u32 { self.prompt_length = size as usize; self.tokens.resize(self.prompt_length, 0); self.tokens.as_mut_ptr() } } -pub trait GuidanceVm { +pub trait AiciVm { /// Create a new instance of VM, based on existing instance, for example when doing beam-search. - fn gvm_clone(&mut self) -> Self; + fn aici_clone(&mut self) -> Self; /// The prompt is in self.helper.tokens. /// On return, self.helper.logit_biases are supposed to be updated. - fn gvm_process_prompt(&mut self); + fn aici_process_prompt(&mut self); /// On return, self.helper.logit_biases are supposed to be updated. - fn gvm_append_token(&mut self, token: u32); + fn aici_append_token(&mut self, token: u32); // Used in testing. - fn get_helper(&mut self) -> &mut GuidanceVmHelper; + fn get_helper(&mut self) -> &mut AiciVmHelper; } #[macro_export] -macro_rules! gvm_expose_all { +macro_rules! aici_expose_all { ($struct_name:ident, $new:expr) => { - $crate::expose!($struct_name::gvm_process_prompt() -> ()); - $crate::expose!($struct_name::gvm_append_token(token: u32) -> ()); - $crate::expose!($struct_name::helper::gvm_get_logit_bias_buffer(size: u32) -> *mut f32); - $crate::expose!($struct_name::helper::gvm_get_prompt_buffer(size: u32) -> *mut u32); + $crate::expose!($struct_name::aici_process_prompt() -> ()); + $crate::expose!($struct_name::aici_append_token(token: u32) -> ()); + $crate::expose!($struct_name::helper::aici_get_logit_bias_buffer(size: u32) -> *mut f32); + $crate::expose!($struct_name::helper::aici_get_prompt_buffer(size: u32) -> *mut u32); #[no_mangle] - pub extern "C" fn gvm_create() -> *mut $struct_name { + pub extern "C" fn aici_create() -> *mut $struct_name { let b = Box::new($new); Box::into_raw(b) } #[no_mangle] - pub extern "C" fn gvm_clone(self_: *mut $struct_name) -> *mut $struct_name { - let b = unsafe { (&mut *self_).gvm_clone() }; + pub extern "C" fn aici_clone(self_: *mut $struct_name) -> *mut $struct_name { + let b = unsafe { (&mut *self_).aici_clone() }; Box::into_raw(Box::new(b)) } #[no_mangle] - pub extern "C" fn gvm_free(self_: *mut $struct_name) { + pub extern "C" fn aici_free(self_: *mut $struct_name) { let _drop = unsafe { Box::from_raw(self_) }; } } @@ -136,25 +136,25 @@ macro_rules! 
wprint { }}; } -pub fn gvm_harness(gvm: &mut impl GuidanceVm, vocab_size: usize, prompt: &[TokenId]) { +pub fn aici_harness(aici: &mut impl AiciVm, vocab_size: usize, prompt: &[TokenId]) { let logits = unsafe { std::slice::from_raw_parts_mut( - gvm.get_helper() - .gvm_get_logit_bias_buffer(vocab_size as u32), + aici.get_helper() + .aici_get_logit_bias_buffer(vocab_size as u32), vocab_size, ) }; let prompt_buf = unsafe { std::slice::from_raw_parts_mut( - gvm.get_helper().gvm_get_prompt_buffer(prompt.len() as u32), + aici.get_helper().aici_get_prompt_buffer(prompt.len() as u32), prompt.len(), ) }; prompt_buf.copy_from_slice(&prompt); - gvm.gvm_process_prompt(); + aici.aici_process_prompt(); let p0 = logits.iter().filter(|x| **x > -50.0).count(); wprintln!("res0: {}", p0); - gvm.gvm_append_token(13); + aici.aici_append_token(13); let p1 = logits.iter().filter(|x| **x > -50.0).count(); wprintln!("res1: {}", p1); } diff --git a/gvm_abi/src/printing.rs b/gvm_abi/src/printing.rs index 0b66df8b..d1f07848 100644 --- a/gvm_abi/src/printing.rs +++ b/gvm_abi/src/printing.rs @@ -2,7 +2,7 @@ use std::io; #[allow(dead_code)] extern "C" { - fn gvm_host_print(ptr: *const u8, len: u32); + fn aici_host_print(ptr: *const u8, len: u32); } #[cfg(not(target_arch = "wasm32"))] @@ -14,7 +14,7 @@ pub struct Printer {} #[cfg(target_arch = "wasm32")] impl io::Write for Printer { fn write(&mut self, buf: &[u8]) -> io::Result { - unsafe { gvm_host_print(buf.as_ptr(), buf.len() as u32) }; + unsafe { aici_host_print(buf.as_ptr(), buf.len() as u32) }; Ok(buf.len()) } @@ -59,7 +59,7 @@ pub fn _print(msg: &str) { #[cfg(target_arch = "wasm32")] { let vec: Vec = msg.into(); - unsafe { gvm_host_print(vec.as_ptr(), vec.len() as u32) }; + unsafe { aici_host_print(vec.as_ptr(), vec.len() as u32) }; } #[cfg(not(target_arch = "wasm32"))] @@ -70,6 +70,6 @@ pub fn _print(msg: &str) { } #[no_mangle] -pub extern "C" fn gvm_init() { +pub extern "C" fn aici_init() { init_panic(); } diff --git a/gvm_abi/src/recognizer.rs b/gvm_abi/src/recognizer.rs index ce3861a2..fd93f70c 100644 --- a/gvm_abi/src/recognizer.rs +++ b/gvm_abi/src/recognizer.rs @@ -2,7 +2,7 @@ use std::{fmt::Debug, rc::Rc}; use crate::{ toktree::{Recognizer, SpecialToken, TokTrie}, - wprintln, GuidanceVm, GuidanceVmHelper, + wprintln, AiciVm, AiciVmHelper, }; pub struct LenExcluder {} @@ -31,16 +31,16 @@ impl FunctionalRecognizer for LenExcluder { } } -pub struct GvmRecognizer { - pub helper: GuidanceVmHelper, +pub struct AiciRecognizer { + pub helper: AiciVmHelper, pub rec: R, pub trie: Rc>, } -impl GvmRecognizer { +impl AiciRecognizer { pub fn from_recognizer(trie: Rc>, rec: R) -> Self { - GvmRecognizer { - helper: GuidanceVmHelper::new(), + AiciRecognizer { + helper: AiciVmHelper::new(), rec, trie, } @@ -53,22 +53,22 @@ impl GvmRecognizer { } } -impl GuidanceVm for GvmRecognizer { - fn gvm_clone(&mut self) -> Self { - GvmRecognizer { +impl AiciVm for AiciRecognizer { + fn aici_clone(&mut self) -> Self { + AiciRecognizer { helper: self.helper.clone(), rec: self.rec.clone(), trie: self.trie.clone(), } } - fn gvm_process_prompt(&mut self) { + fn aici_process_prompt(&mut self) { wprintln!("prompt, {} tokens", self.helper.prompt_length); // the regex doesn't care about the prompt self.compute(); } - fn gvm_append_token(&mut self, token: u32) { + fn aici_append_token(&mut self, token: u32) { let bytes = self.trie.token(token); // wprintln!("xapp {} {:?}", token, bytes); for b in bytes { @@ -83,7 +83,7 @@ impl GuidanceVm for GvmRecognizer { self.compute(); } - fn 
get_helper(&mut self) -> &mut GuidanceVmHelper { + fn get_helper(&mut self) -> &mut AiciVmHelper { &mut self.helper } } diff --git a/gvm_abi/src/rxvm.rs b/gvm_abi/src/rxvm.rs index a5dfea11..4942a0e7 100644 --- a/gvm_abi/src/rxvm.rs +++ b/gvm_abi/src/rxvm.rs @@ -1,24 +1,24 @@ use crate::rx::{StateOffset, TokRx}; -use crate::{wprintln, GuidanceVm, GuidanceVmHelper}; +use crate::{wprintln, AiciVm, AiciVmHelper}; -pub struct RxGvm { - pub helper: GuidanceVmHelper, +pub struct RxAici { + pub helper: AiciVmHelper, pub compiled: TokRx, pub state: StateOffset, } -impl RxGvm { +impl RxAici { pub fn from_token_compiled(compiled: TokRx) -> Self { - RxGvm { - helper: GuidanceVmHelper::new(), + RxAici { + helper: AiciVmHelper::new(), compiled, state: StateOffset::START, } } } -impl GuidanceVm for RxGvm { - fn gvm_process_prompt(&mut self) { +impl AiciVm for RxAici { + fn aici_process_prompt(&mut self) { wprintln!("prompt, {} tokens", self.helper.prompt_length); // the regex doesn't care about the prompt self.state = StateOffset::START; @@ -26,7 +26,7 @@ impl GuidanceVm for RxGvm { .compute_logit_bias(self.state, &mut self.helper.logit_biases); } - fn gvm_append_token(&mut self, token: u32) { + fn aici_append_token(&mut self, token: u32) { // wprintln!("xapp {:?} {} {}", self as *const _, token, self.state.off); self.state = self.compiled.advance(self.state, token); @@ -40,8 +40,8 @@ impl GuidanceVm for RxGvm { } // implement by hand for now - we may need some special processing here - fn gvm_clone(&mut self) -> Self { - let r = RxGvm { + fn aici_clone(&mut self) -> Self { + let r = RxAici { helper: self.helper.clone(), compiled: self.compiled.clone(), state: self.state.clone(), @@ -50,7 +50,7 @@ impl GuidanceVm for RxGvm { r } - fn get_helper(&mut self) -> &mut GuidanceVmHelper { + fn get_helper(&mut self) -> &mut AiciVmHelper { &mut self.helper } } diff --git a/gvm_abi/src/toktree.rs b/gvm_abi/src/toktree.rs index 4a68bf98..e9306b73 100644 --- a/gvm_abi/src/toktree.rs +++ b/gvm_abi/src/toktree.rs @@ -97,16 +97,16 @@ impl TrieNode { #[allow(dead_code)] extern "C" { - fn gvm_host_read_token_trie(ptr: *mut u8, len: u32) -> u32; + fn aici_host_read_token_trie(ptr: *mut u8, len: u32) -> u32; } impl TokTrie { pub fn from_env() -> Self { #[cfg(target_arch = "wasm32")] unsafe { - let size = gvm_host_read_token_trie(0 as _, 0); + let size = aici_host_read_token_trie(0 as _, 0); let mut buffer = vec![0u8; size as usize]; - gvm_host_read_token_trie(buffer.as_mut_ptr(), size); + aici_host_read_token_trie(buffer.as_mut_ptr(), size); Self::from_bytes(&buffer) } From 0ceea1456b791f0f0b72becf2e1e1bb34a790087 Mon Sep 17 00:00:00 2001 From: Michal Moskal Date: Wed, 11 Oct 2023 21:17:00 +0000 Subject: [PATCH 054/301] rename folders to aici --- {gvm_abi => aici_abi}/.cargo/config.toml | 0 {gvm_abi => aici_abi}/Cargo.lock | 0 {gvm_abi => aici_abi}/Cargo.toml | 0 {gvm_abi => aici_abi}/src/bytes.rs | 0 {gvm_abi => aici_abi}/src/gvm_iface.h | 0 {gvm_abi => aici_abi}/src/lib.rs | 0 {gvm_abi => aici_abi}/src/printing.rs | 0 {gvm_abi => aici_abi}/src/recognizer.rs | 0 {gvm_abi => aici_abi}/src/rx.rs | 0 {gvm_abi => aici_abi}/src/rxvm.rs | 0 {gvm_abi => aici_abi}/src/toktree.rs | 0 11 files changed, 0 insertions(+), 0 deletions(-) rename {gvm_abi => aici_abi}/.cargo/config.toml (100%) rename {gvm_abi => aici_abi}/Cargo.lock (100%) rename {gvm_abi => aici_abi}/Cargo.toml (100%) rename {gvm_abi => aici_abi}/src/bytes.rs (100%) rename {gvm_abi => aici_abi}/src/gvm_iface.h (100%) rename {gvm_abi => aici_abi}/src/lib.rs (100%) 
rename {gvm_abi => aici_abi}/src/printing.rs (100%) rename {gvm_abi => aici_abi}/src/recognizer.rs (100%) rename {gvm_abi => aici_abi}/src/rx.rs (100%) rename {gvm_abi => aici_abi}/src/rxvm.rs (100%) rename {gvm_abi => aici_abi}/src/toktree.rs (100%) diff --git a/gvm_abi/.cargo/config.toml b/aici_abi/.cargo/config.toml similarity index 100% rename from gvm_abi/.cargo/config.toml rename to aici_abi/.cargo/config.toml diff --git a/gvm_abi/Cargo.lock b/aici_abi/Cargo.lock similarity index 100% rename from gvm_abi/Cargo.lock rename to aici_abi/Cargo.lock diff --git a/gvm_abi/Cargo.toml b/aici_abi/Cargo.toml similarity index 100% rename from gvm_abi/Cargo.toml rename to aici_abi/Cargo.toml diff --git a/gvm_abi/src/bytes.rs b/aici_abi/src/bytes.rs similarity index 100% rename from gvm_abi/src/bytes.rs rename to aici_abi/src/bytes.rs diff --git a/gvm_abi/src/gvm_iface.h b/aici_abi/src/gvm_iface.h similarity index 100% rename from gvm_abi/src/gvm_iface.h rename to aici_abi/src/gvm_iface.h diff --git a/gvm_abi/src/lib.rs b/aici_abi/src/lib.rs similarity index 100% rename from gvm_abi/src/lib.rs rename to aici_abi/src/lib.rs diff --git a/gvm_abi/src/printing.rs b/aici_abi/src/printing.rs similarity index 100% rename from gvm_abi/src/printing.rs rename to aici_abi/src/printing.rs diff --git a/gvm_abi/src/recognizer.rs b/aici_abi/src/recognizer.rs similarity index 100% rename from gvm_abi/src/recognizer.rs rename to aici_abi/src/recognizer.rs diff --git a/gvm_abi/src/rx.rs b/aici_abi/src/rx.rs similarity index 100% rename from gvm_abi/src/rx.rs rename to aici_abi/src/rx.rs diff --git a/gvm_abi/src/rxvm.rs b/aici_abi/src/rxvm.rs similarity index 100% rename from gvm_abi/src/rxvm.rs rename to aici_abi/src/rxvm.rs diff --git a/gvm_abi/src/toktree.rs b/aici_abi/src/toktree.rs similarity index 100% rename from gvm_abi/src/toktree.rs rename to aici_abi/src/toktree.rs From 8fcecd7f99d0ec0dd5454dcb3fc289d7d8863217 Mon Sep 17 00:00:00 2001 From: Michal Moskal Date: Wed, 11 Oct 2023 21:18:42 +0000 Subject: [PATCH 055/301] more renames --- aici_abi/src/{gvm_iface.h => aici_iface.h} | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename aici_abi/src/{gvm_iface.h => aici_iface.h} (100%) diff --git a/aici_abi/src/gvm_iface.h b/aici_abi/src/aici_iface.h similarity index 100% rename from aici_abi/src/gvm_iface.h rename to aici_abi/src/aici_iface.h From b24232ad6668fafc6de27aac5d01d7c950eb03e3 Mon Sep 17 00:00:00 2001 From: Michal Moskal Date: Thu, 12 Oct 2023 00:21:32 +0000 Subject: [PATCH 056/301] add AnythingGoes --- aici_abi/src/recognizer.rs | 21 +++++++++++++++++++++ 1 file changed, 21 insertions(+) diff --git a/aici_abi/src/recognizer.rs b/aici_abi/src/recognizer.rs index fd93f70c..58440748 100644 --- a/aici_abi/src/recognizer.rs +++ b/aici_abi/src/recognizer.rs @@ -155,3 +155,24 @@ impl> Recognizer for StackRecognizer self.rec.special_allowed(self.stack[self.stack_ptr], tok) } } + +#[derive(Clone)] +pub struct AnythingGoes {} + +impl FunctionalRecognizer<()> for AnythingGoes { + fn initial(&self) -> () { + () + } + + fn append(&self, state: (), _byte: u8) -> () { + state + } + + fn byte_allowed(&self, _state: (), _byte: u8) -> bool { + true + } + + fn special_allowed(&self, _state: (), _tok: SpecialToken) -> bool { + true + } +} From 5c7b36ee16b8d6361a14019f76c503a0419ed8eb Mon Sep 17 00:00:00 2001 From: Michal Moskal Date: Thu, 12 Oct 2023 00:21:52 +0000 Subject: [PATCH 057/301] call Recognizer through vtable (slower) --- aici_abi/src/toktree.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff 
--git a/aici_abi/src/toktree.rs b/aici_abi/src/toktree.rs index e9306b73..57f3f02a 100644 --- a/aici_abi/src/toktree.rs +++ b/aici_abi/src/toktree.rs @@ -294,7 +294,7 @@ impl TokTrie { Some(n) } - pub fn compute_bias(&self, r: &mut impl Recognizer, logits: &mut [f32]) { + pub fn compute_bias(&self, r: &mut dyn Recognizer, logits: &mut [f32]) { logits.iter_mut().for_each(|x| *x = -100.0); for tok in vec![SpecialToken::EndOfSentence] { From b6732acc634284dabe4fbb38938830ba8083d23a Mon Sep 17 00:00:00 2001 From: Michal Moskal Date: Thu, 12 Oct 2023 18:43:41 +0000 Subject: [PATCH 058/301] hook up ast_runner --- aici_abi/src/toktree.rs | 29 +++++++++++++++++++++++++++++ 1 file changed, 29 insertions(+) diff --git a/aici_abi/src/toktree.rs b/aici_abi/src/toktree.rs index 57f3f02a..8f67e600 100644 --- a/aici_abi/src/toktree.rs +++ b/aici_abi/src/toktree.rs @@ -175,6 +175,32 @@ impl TokTrie { &self.token_data[off..(off + len as usize)] } + pub fn token_id(&self, bytes: &[u8]) -> Option { + let (tok, len) = self.prefix_token_id(bytes); + // wprintln!("tok_id {:?} {:?} {:?} ", bytes, tok, len); + if len == bytes.len() { + Some(tok) + } else { + None + } + } + + pub fn prefix_token_id(&self, bytes: &[u8]) -> (TokenId, usize) { + assert!(bytes.len() > 0); + let mut last = (0, 0); + let mut n = self.root(); + for (idx, byte) in bytes.iter().enumerate() { + n = match self.child_at_byte(n, *byte) { + Some(n) => n, + None => break, + }; + if let Some(tok) = n.token_id() { + last = (tok, idx + 1); + } + } + return last; + } + pub fn from_bytes(bytes: &[u8]) -> Self { let pref = std::mem::size_of::(); let hd = *box_from_bytes::(&bytes[0..pref]); @@ -296,7 +322,10 @@ impl TokTrie { pub fn compute_bias(&self, r: &mut dyn Recognizer, logits: &mut [f32]) { logits.iter_mut().for_each(|x| *x = -100.0); + self.add_bias(r, logits) + } + pub fn add_bias(&self, r: &mut dyn Recognizer, logits: &mut [f32]) { for tok in vec![SpecialToken::EndOfSentence] { if r.special_allowed(tok) { logits[self.special_token(tok) as usize] = 0.0; From fcfc1c6b37dd7559475b0fe63cbf067280b638a4 Mon Sep 17 00:00:00 2001 From: Michal Moskal Date: Fri, 13 Oct 2023 01:08:30 +0000 Subject: [PATCH 059/301] hook up args --- aici_abi/src/aici_iface.h | 3 +++ aici_abi/src/arg.rs | 17 +++++++++++++++++ aici_abi/src/lib.rs | 1 + 3 files changed, 21 insertions(+) create mode 100644 aici_abi/src/arg.rs diff --git a/aici_abi/src/aici_iface.h b/aici_abi/src/aici_iface.h index 500ae06e..131e098e 100644 --- a/aici_abi/src/aici_iface.h +++ b/aici_abi/src/aici_iface.h @@ -54,3 +54,6 @@ void aici_host_print(const uint8_t *ptr, uint32_t size); // Read binary representation of TokTrie. // Always returns the size of the trie, will write up to `size` bytes to `dst`. uint32_t aici_host_read_token_trie(uint8_t *dst, uint32_t size); + +// Similar, for argument passed by the user (typically JSON). 
+uint32_t aici_host_read_arg(uint8_t *dst, uint32_t size); diff --git a/aici_abi/src/arg.rs b/aici_abi/src/arg.rs new file mode 100644 index 00000000..01e6bedb --- /dev/null +++ b/aici_abi/src/arg.rs @@ -0,0 +1,17 @@ +#[allow(dead_code)] +extern "C" { + fn aici_host_read_arg(ptr: *mut u8, len: u32) -> u32; +} + +pub fn arg_bytes() -> Vec { + #[cfg(target_arch = "wasm32")] + unsafe { + let size = aici_host_read_arg(0 as _, 0); + let mut buffer = vec![0u8; size as usize]; + aici_host_read_arg(buffer.as_mut_ptr(), size); + return buffer; + } + + #[cfg(not(target_arch = "wasm32"))] + std::fs::read("arg.json").unwrap() +} diff --git a/aici_abi/src/lib.rs b/aici_abi/src/lib.rs index c6052452..0d51fa0d 100644 --- a/aici_abi/src/lib.rs +++ b/aici_abi/src/lib.rs @@ -6,6 +6,7 @@ pub mod recognizer; pub mod rx; pub mod rxvm; pub mod toktree; +pub mod arg; /// Expose method as extern "C", usage: /// expose!(Foo::set_count(n: i32) -> i32); From 81bef0da5ff4829747c4240f75ebc5146c9dd292 Mon Sep 17 00:00:00 2001 From: Michal Moskal Date: Fri, 13 Oct 2023 22:54:49 +0000 Subject: [PATCH 060/301] drop aici_clone() --- aici_abi/src/aici_iface.h | 12 ++---------- aici_abi/src/lib.rs | 18 +++--------------- aici_abi/src/recognizer.rs | 8 -------- aici_abi/src/rxvm.rs | 11 ----------- 4 files changed, 5 insertions(+), 44 deletions(-) diff --git a/aici_abi/src/aici_iface.h b/aici_abi/src/aici_iface.h index 131e098e..990678c1 100644 --- a/aici_abi/src/aici_iface.h +++ b/aici_abi/src/aici_iface.h @@ -9,14 +9,10 @@ typedef uint32_t token_t; // Called first, after instantiating WASM module. void aici_init(void); -// Called once per module, to get a AICI for a specific query +// Called once per module, to get an AICI for a specific query Aici *aici_create(void); -// If a query is split into several (eg., during beam-search, or when returning several results) -// this is called to get AICI for the sub-query. -Aici *aici_clone(Aici *parent); - -// These two are called after aici_create() and aici_clone() on the fresh AICI. +// These two are called after aici_create() on the fresh AICI. // They should return the buffers that the WASM code has to allocated and keep around // until relevant aici_free(). @@ -40,10 +36,6 @@ void aici_append_token(Aici *aici, token_t tok); // The logical type (if WASM would allow such things) of this function is: // float[vocab_size] aici_append_token(Aici *aici, token_t tok); -// This is called for AICIs that no longer needed (eg. because generation completed, -// or beam-search branch was cut). -void aici_free(Aici *aici); - // // This interface is available to the WASM binary // diff --git a/aici_abi/src/lib.rs b/aici_abi/src/lib.rs index 0d51fa0d..7723c5b2 100644 --- a/aici_abi/src/lib.rs +++ b/aici_abi/src/lib.rs @@ -1,12 +1,12 @@ use bytes::TokenId; +pub mod arg; pub mod bytes; pub mod printing; pub mod recognizer; pub mod rx; pub mod rxvm; pub mod toktree; -pub mod arg; /// Expose method as extern "C", usage: /// expose!(Foo::set_count(n: i32) -> i32); @@ -62,8 +62,6 @@ impl AiciVmHelper { } pub trait AiciVm { - /// Create a new instance of VM, based on existing instance, for example when doing beam-search. - fn aici_clone(&mut self) -> Self; /// The prompt is in self.helper.tokens. /// On return, self.helper.logit_biases are supposed to be updated. fn aici_process_prompt(&mut self); @@ -86,17 +84,6 @@ macro_rules! 
From 81bef0da5ff4829747c4240f75ebc5146c9dd292 Mon Sep 17 00:00:00 2001
From: Michal Moskal
Date: Fri, 13 Oct 2023 22:54:49 +0000
Subject: [PATCH 060/301] drop aici_clone()

---
 aici_abi/src/aici_iface.h  | 12 ++----------
 aici_abi/src/lib.rs        | 18 +++---------------
 aici_abi/src/recognizer.rs |  8 --------
 aici_abi/src/rxvm.rs       | 11 -----------
 4 files changed, 5 insertions(+), 44 deletions(-)

diff --git a/aici_abi/src/aici_iface.h b/aici_abi/src/aici_iface.h
index 131e098e..990678c1 100644
--- a/aici_abi/src/aici_iface.h
+++ b/aici_abi/src/aici_iface.h
@@ -9,14 +9,10 @@ typedef uint32_t token_t;
 // Called first, after instantiating WASM module.
 void aici_init(void);
 
-// Called once per module, to get a AICI for a specific query
+// Called once per module, to get an AICI for a specific query
 Aici *aici_create(void);
 
-// If a query is split into several (eg., during beam-search, or when returning several results)
-// this is called to get AICI for the sub-query.
-Aici *aici_clone(Aici *parent);
-
-// These two are called after aici_create() and aici_clone() on the fresh AICI.
+// These two are called after aici_create() on the fresh AICI.
 // They should return the buffers that the WASM code has to allocate and keep around
 // until relevant aici_free().
 
@@ -40,10 +36,6 @@ void aici_append_token(Aici *aici, token_t tok);
 // The logical type (if WASM would allow such things) of this function is:
 // float[vocab_size] aici_append_token(Aici *aici, token_t tok);
 
-// This is called for AICIs that are no longer needed (eg. because generation completed,
-// or beam-search branch was cut).
-void aici_free(Aici *aici);
-
 //
 // This interface is available to the WASM binary
 //
diff --git a/aici_abi/src/lib.rs b/aici_abi/src/lib.rs
index 0d51fa0d..7723c5b2 100644
--- a/aici_abi/src/lib.rs
+++ b/aici_abi/src/lib.rs
@@ -1,12 +1,12 @@
 use bytes::TokenId;
 
+pub mod arg;
 pub mod bytes;
 pub mod printing;
 pub mod recognizer;
 pub mod rx;
 pub mod rxvm;
 pub mod toktree;
-pub mod arg;
 
 /// Expose method as extern "C", usage:
 /// expose!(Foo::set_count(n: i32) -> i32);
@@ -62,8 +62,6 @@ impl AiciVmHelper {
 }
 
 pub trait AiciVm {
-    /// Create a new instance of VM, based on existing instance, for example when doing beam-search.
-    fn aici_clone(&mut self) -> Self;
     /// The prompt is in self.helper.tokens.
     /// On return, self.helper.logit_biases are supposed to be updated.
     fn aici_process_prompt(&mut self);
@@ -86,17 +84,6 @@ macro_rules! aici_expose_all {
             let b = Box::new($new);
             Box::into_raw(b)
         }
-
-        #[no_mangle]
-        pub extern "C" fn aici_clone(self_: *mut $struct_name) -> *mut $struct_name {
-            let b = unsafe { (&mut *self_).aici_clone() };
-            Box::into_raw(Box::new(b))
-        }
-
-        #[no_mangle]
-        pub extern "C" fn aici_free(self_: *mut $struct_name) {
-            let _drop = unsafe { Box::from_raw(self_) };
-        }
     }
 }
@@ -147,7 +134,8 @@ pub fn aici_harness(aici: &mut impl AiciVm, vocab_size: usize, prompt: &[TokenId])
     };
     let prompt_buf = unsafe {
         std::slice::from_raw_parts_mut(
-            aici.get_helper().aici_get_prompt_buffer(prompt.len() as u32),
+            aici.get_helper()
+                .aici_get_prompt_buffer(prompt.len() as u32),
             prompt.len(),
         )
     };
diff --git a/aici_abi/src/recognizer.rs b/aici_abi/src/recognizer.rs
index 58440748..40ddce33 100644
--- a/aici_abi/src/recognizer.rs
+++ b/aici_abi/src/recognizer.rs
@@ -54,14 +54,6 @@ impl<R: Recognizer + Clone> AiciRecognizer<R> {
 }
 
 impl<R: Recognizer + Clone> AiciVm for AiciRecognizer<R> {
-    fn aici_clone(&mut self) -> Self {
-        AiciRecognizer {
-            helper: self.helper.clone(),
-            rec: self.rec.clone(),
-            trie: self.trie.clone(),
-        }
-    }
-
     fn aici_process_prompt(&mut self) {
         wprintln!("prompt, {} tokens", self.helper.prompt_length);
         // the regex doesn't care about the prompt
diff --git a/aici_abi/src/rxvm.rs b/aici_abi/src/rxvm.rs
index 4942a0e7..c7ac6096 100644
--- a/aici_abi/src/rxvm.rs
+++ b/aici_abi/src/rxvm.rs
@@ -39,17 +39,6 @@ impl AiciVm for RxAici {
             .compute_logit_bias(self.state, &mut self.helper.logit_biases);
     }
 
-    // implement by hand for now - we may need some special processing here
-    fn aici_clone(&mut self) -> Self {
-        let r = RxAici {
-            helper: self.helper.clone(),
-            compiled: self.compiled.clone(),
-            state: self.state.clone(),
-        };
-        wprintln!("{} -> {}", self.state.off, r.state.off);
-        r
-    }
-
     fn get_helper(&mut self) -> &mut AiciVmHelper {
         &mut self.helper
     }

From 67c58d5efe14a4cd4e95a1bb018b2c3a4182af7c Mon Sep 17 00:00:00 2001
From: Michal Moskal
Date: Mon, 16 Oct 2023 10:51:45 -0700
Subject: [PATCH 061/301] tokenizers in host

---
 aici_abi/src/aici_iface.h | 3 +++
 aici_abi/src/bytes.rs     | 2 +-
 2 files changed, 4 insertions(+), 1 deletion(-)

diff --git a/aici_abi/src/aici_iface.h b/aici_abi/src/aici_iface.h
index 990678c1..c0d73624 100644
--- a/aici_abi/src/aici_iface.h
+++ b/aici_abi/src/aici_iface.h
@@ -49,3 +49,6 @@ uint32_t aici_host_read_token_trie(uint8_t *dst, uint32_t size);
 
 // Similar, for argument passed by the user (typically JSON).
 uint32_t aici_host_read_arg(uint8_t *dst, uint32_t size);
+
+// Tokenize given UTF8 string. `dst_size` is in elements, not bytes. Returns number of generated tokens.
+uint32_t aici_host_tokenize(const uint8_t *src, uint32_t src_size, uint32_t *dst, uint32_t dst_size);
\ No newline at end of file
diff --git a/aici_abi/src/bytes.rs b/aici_abi/src/bytes.rs
index 64112cb6..e50c95c0 100644
--- a/aici_abi/src/bytes.rs
+++ b/aici_abi/src/bytes.rs
@@ -10,7 +10,7 @@ pub struct TokRxInfo {
 }
 
-pub fn clone_vec_as_bytes<T>(input: &Vec<T>) -> Vec<u8> {
+pub fn clone_vec_as_bytes<T>(input: &[T]) -> Vec<u8> {
     unsafe {
         let byte_slice = from_raw_parts(input.as_ptr() as *const u8, input.len() * size_of::<T>());
         byte_slice.to_vec()
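The `&Vec<T>` to `&[T]` change above is a small but useful API loosening: a `&Vec<T>` argument only accepts vectors, while `&[T]` accepts vectors (via deref coercion), sub-slices, and arrays alike. A standalone illustration of the difference, not part of the patch:

    fn takes_slice<T>(input: &[T]) -> usize {
        input.len()
    }

    fn main() {
        let v = vec![1u32, 2, 3];
        assert_eq!(takes_slice(&v), 3);        // &Vec<u32> coerces to &[u32]
        assert_eq!(takes_slice(&v[1..]), 2);   // sub-slices work too
        assert_eq!(takes_slice(&[0u8; 4]), 4); // as do fixed-size arrays
    }
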
From 6a98bcf7a3dd812038822b921aa5f7bb8bc690bb Mon Sep 17 00:00:00 2001
From: Michal Moskal
Date: Mon, 16 Oct 2023 10:55:11 -0700
Subject: [PATCH 062/301] move stuff around

---
 aici_abi/src/arg.rs                   | 17 -----------------
 aici_abi/src/{printing.rs => host.rs} | 18 ++++++++++++++++++
 aici_abi/src/lib.rs                   | 11 +++++------
 3 files changed, 23 insertions(+), 23 deletions(-)
 delete mode 100644 aici_abi/src/arg.rs
 rename aici_abi/src/{printing.rs => host.rs} (79%)

diff --git a/aici_abi/src/arg.rs b/aici_abi/src/arg.rs
deleted file mode 100644
index 01e6bedb..00000000
--- a/aici_abi/src/arg.rs
+++ /dev/null
@@ -1,17 +0,0 @@
-#[allow(dead_code)]
-extern "C" {
-    fn aici_host_read_arg(ptr: *mut u8, len: u32) -> u32;
-}
-
-pub fn arg_bytes() -> Vec<u8> {
-    #[cfg(target_arch = "wasm32")]
-    unsafe {
-        let size = aici_host_read_arg(0 as _, 0);
-        let mut buffer = vec![0u8; size as usize];
-        aici_host_read_arg(buffer.as_mut_ptr(), size);
-        return buffer;
-    }
-
-    #[cfg(not(target_arch = "wasm32"))]
-    std::fs::read("arg.json").unwrap()
-}
diff --git a/aici_abi/src/printing.rs b/aici_abi/src/host.rs
similarity index 79%
rename from aici_abi/src/printing.rs
rename to aici_abi/src/host.rs
index d1f07848..054e1c50 100644
--- a/aici_abi/src/printing.rs
+++ b/aici_abi/src/host.rs
@@ -73,3 +73,21 @@ pub fn _print(msg: &str) {
 pub extern "C" fn aici_init() {
     init_panic();
 }
+
+#[allow(dead_code)]
+extern "C" {
+    fn aici_host_read_arg(ptr: *mut u8, len: u32) -> u32;
+}
+
+pub fn arg_bytes() -> Vec<u8> {
+    #[cfg(target_arch = "wasm32")]
+    unsafe {
+        let size = aici_host_read_arg(0 as _, 0);
+        let mut buffer = vec![0u8; size as usize];
+        aici_host_read_arg(buffer.as_mut_ptr(), size);
+        return buffer;
+    }
+
+    #[cfg(not(target_arch = "wasm32"))]
+    std::fs::read("arg.json").unwrap()
+}
diff --git a/aici_abi/src/lib.rs b/aici_abi/src/lib.rs
index 7723c5b2..f1f5447a 100644
--- a/aici_abi/src/lib.rs
+++ b/aici_abi/src/lib.rs
@@ -1,8 +1,7 @@
 use bytes::TokenId;
 
-pub mod arg;
+pub mod host;
 pub mod bytes;
-pub mod printing;
 pub mod recognizer;
 pub mod rx;
 pub mod rxvm;
 pub mod toktree;
@@ -109,18 +108,18 @@ macro_rules! include_bytes_aligned {
 #[macro_export]
 macro_rules! wprintln {
     () => {
-        $crate::printing::_print("\n")
+        $crate::host::_print("\n")
     };
     ($($arg:tt)*) => {{
-        $crate::printing::_print(&format!($($arg)*));
-        $crate::printing::_print("\n");
+        $crate::host::_print(&format!($($arg)*));
+        $crate::host::_print("\n");
     }};
 }
 
 #[macro_export]
 macro_rules! wprint {
     ($($arg:tt)*) => {{
-        $crate::printing::_print(&format!($($arg)*));
+        $crate::host::_print(&format!($($arg)*));
     }};
 }

From 67c94432fa73a6ca20bedd235392247b42d4cf6d Mon Sep 17 00:00:00 2001
From: Michal Moskal
Date: Mon, 16 Oct 2023 10:58:44 -0700
Subject: [PATCH 063/301] cleanup host

---
 aici_abi/src/host.rs    | 20 +++++++++++++++-----
 aici_abi/src/toktree.rs | 26 ++++++++------------------
 2 files changed, 23 insertions(+), 23 deletions(-)

diff --git a/aici_abi/src/host.rs b/aici_abi/src/host.rs
index 054e1c50..0b85ba70 100644
--- a/aici_abi/src/host.rs
+++ b/aici_abi/src/host.rs
@@ -3,6 +3,8 @@ use std::io;
 #[allow(dead_code)]
 extern "C" {
     fn aici_host_print(ptr: *const u8, len: u32);
+    fn aici_host_read_arg(ptr: *mut u8, len: u32) -> u32;
+    fn aici_host_read_token_trie(ptr: *mut u8, len: u32) -> u32;
 }
 
 #[cfg(not(target_arch = "wasm32"))]
@@ -74,11 +76,6 @@ pub extern "C" fn aici_init() {
     init_panic();
 }
 
-#[allow(dead_code)]
-extern "C" {
-    fn aici_host_read_arg(ptr: *mut u8, len: u32) -> u32;
-}
-
 pub fn arg_bytes() -> Vec<u8> {
     #[cfg(target_arch = "wasm32")]
     unsafe {
@@ -91,3 +88,16 @@ pub fn arg_bytes() -> Vec<u8> {
     #[cfg(not(target_arch = "wasm32"))]
     std::fs::read("arg.json").unwrap()
 }
+
+pub fn trie_bytes() -> Vec<u8> {
+    #[cfg(target_arch = "wasm32")]
+    unsafe {
+        let size = aici_host_read_token_trie(0 as _, 0);
+        let mut buffer = vec![0u8; size as usize];
+        aici_host_read_token_trie(buffer.as_mut_ptr(), size);
+        buffer
+    }
+
+    #[cfg(not(target_arch = "wasm32"))]
+    std::fs::read("tokenizer.bin").unwrap()
+}
diff --git a/aici_abi/src/toktree.rs b/aici_abi/src/toktree.rs
index 8f67e600..0bde76cb 100644
--- a/aici_abi/src/toktree.rs
+++ b/aici_abi/src/toktree.rs
@@ -1,8 +1,11 @@
 // use 8:24 encoding - num_ch:tok_id (ch_byte:ch_off)* - 8 bytes per tree node
 // special case num_ch=0xff -> num_ch=0x100
 
-use crate::bytes::{
-    box_from_bytes, clone_as_bytes, clone_vec_as_bytes, vec_from_bytes, TokRxInfo, TokenId,
+use crate::{
+    bytes::{
+        box_from_bytes, clone_as_bytes, clone_vec_as_bytes, vec_from_bytes, TokRxInfo, TokenId,
+    },
+    host::trie_bytes,
 };
 
 #[derive(Clone, Copy, Debug, PartialEq, Eq)]
@@ -95,23 +98,10 @@ impl TrieNode {
     }
 }
 
-#[allow(dead_code)]
-extern "C" {
-    fn aici_host_read_token_trie(ptr: *mut u8, len: u32) -> u32;
-}
-
 impl TokTrie {
-    pub fn from_env() -> Self {
-        #[cfg(target_arch = "wasm32")]
-        unsafe {
-            let size = aici_host_read_token_trie(0 as _, 0);
-            let mut buffer = vec![0u8; size as usize];
-            aici_host_read_token_trie(buffer.as_mut_ptr(), size);
-            Self::from_bytes(&buffer)
-        }
-
-        #[cfg(not(target_arch = "wasm32"))]
-        Self::from_bytes(&std::fs::read("tokenizer.bin").unwrap())
+    pub fn from_host() -> Self {
+        let buffer = trie_bytes();
+        Self::from_bytes(&buffer)
     }
 
     pub fn from(info: &TokRxInfo, words: &Vec<Vec<u8>>) -> Self {
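With the trie bytes now fetched through the host module, constructing the token trie inside a controller reduces to a single call. A minimal usage sketch (the surrounding controller code is assumed, not shown in the patch):

    use aici_abi::toktree::TokTrie;

    fn vocab_report() {
        // On wasm32 this pulls the serialized trie from the host;
        // natively it falls back to reading tokenizer.bin.
        let trie = TokTrie::from_host();
        aici_abi::wprintln!("vocab size: {}", trie.vocab_size());
    }
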
From a0411d9a9e32d68f43247fa09a841cb6e44769a5 Mon Sep 17 00:00:00 2001
From: Michal Moskal
Date: Mon, 16 Oct 2023 13:14:32 -0700
Subject: [PATCH 064/301] tokenize fixed steps

---
 aici_abi/src/bytes.rs   |  2 +-
 aici_abi/src/host.rs    | 31 ++++++++++++++++++++++++++++++-
 aici_abi/src/lib.rs     | 23 +++++++++++++++++++++--
 aici_abi/src/toktree.rs |  4 ++--
 4 files changed, 54 insertions(+), 6 deletions(-)

diff --git a/aici_abi/src/bytes.rs b/aici_abi/src/bytes.rs
index e50c95c0..11b45dc2 100644
--- a/aici_abi/src/bytes.rs
+++ b/aici_abi/src/bytes.rs
@@ -1,6 +1,6 @@
 use std::{mem::size_of, slice::from_raw_parts};
 
-pub type TokenId = u32;
+pub(crate) type TokenId = u32;
 
 #[repr(C)]
 #[derive(Clone, PartialEq, Eq, Debug)]
 pub struct TokRxInfo {
diff --git a/aici_abi/src/host.rs b/aici_abi/src/host.rs
index 0b85ba70..1d23332f 100644
--- a/aici_abi/src/host.rs
+++ b/aici_abi/src/host.rs
@@ -1,10 +1,21 @@
 use std::io;
 
+use crate::bytes::TokenId;
+
 #[allow(dead_code)]
 extern "C" {
+    // Log a string.
     fn aici_host_print(ptr: *const u8, len: u32);
-    fn aici_host_read_arg(ptr: *mut u8, len: u32) -> u32;
+
+    // Read binary representation of TokTrie.
+    // Always returns the size of the trie, will write up to `size` bytes to `dst`.
     fn aici_host_read_token_trie(ptr: *mut u8, len: u32) -> u32;
+
+    // Similar, for argument passed by the user (typically JSON).
+    fn aici_host_read_arg(ptr: *mut u8, len: u32) -> u32;
+
+    // Tokenize given UTF8 string. `dst_size` is in elements, not bytes. Returns number of generated tokens.
+    fn aici_host_tokenize(src: *const u8, src_size: u32, dst: *mut u32, dst_size: u32) -> u32;
 }
 
 #[cfg(not(target_arch = "wasm32"))]
@@ -101,3 +112,21 @@ pub fn trie_bytes() -> Vec<u8> {
     #[cfg(not(target_arch = "wasm32"))]
     std::fs::read("tokenizer.bin").unwrap()
 }
+
+pub fn tokenize(s: &str) -> Vec<TokenId> {
+    // fn aici_host_tokenize(src: *const u8, src_size: u32, dst: *mut u32, dst_size: u32) -> u32;
+    let slen = s.len() as u32;
+    let cap = slen / 3 + 10;
+    let mut res = Vec::with_capacity(cap as usize);
+    let len = unsafe { aici_host_tokenize(s.as_ptr(), slen, res.as_mut_ptr(), cap) };
+    if len > res.len() as u32 {
+        // unlikely...
+        res = Vec::with_capacity(len as usize);
+        unsafe { aici_host_tokenize(s.as_ptr(), slen, res.as_mut_ptr(), len) };
+    }
+    unsafe {
+        res.set_len(len as usize);
+    }
+    // trim size
+    res.clone()
+}
diff --git a/aici_abi/src/lib.rs b/aici_abi/src/lib.rs
index f1f5447a..d17bd46a 100644
--- a/aici_abi/src/lib.rs
+++ b/aici_abi/src/lib.rs
@@ -1,12 +1,16 @@
-use bytes::TokenId;
+use std::rc::Rc;
+
+use toktree::{SpecialToken, TokTrie};
 
 pub mod bytes;
 pub mod host;
 pub mod recognizer;
 pub mod rx;
 pub mod rxvm;
 pub mod toktree;
 
+pub type TokenId = bytes::TokenId;
+
 /// Expose method as extern "C", usage:
 /// expose!(Foo::set_count(n: i32) -> i32);
 /// Generates "C" function:
@@ -36,6 +40,7 @@ pub struct AiciVmHelper {
     pub tokens: Vec<u32>,
     pub prompt_length: usize,
     pub logit_biases: Vec<f32>,
+    pub trie: Rc<Box<TokTrie>>,
 }
 
 // aici_* are exposed to C in both AiciVm and AiciVmHelper
@@ -45,6 +50,7 @@ impl AiciVmHelper {
             tokens: Vec::new(),
             prompt_length: 0,
             logit_biases: Vec::new(),
+            trie: Rc::new(Box::new(TokTrie::from_host())),
         }
     }
     pub fn aici_get_logit_bias_buffer(&mut self, size: u32) -> *mut f32 {
@@ -58,6 +64,19 @@ impl AiciVmHelper {
         self.tokens.resize(self.prompt_length, 0);
         self.tokens.as_mut_ptr()
     }
+
+    pub fn all_disallowed(&mut self) {
+        self.logit_biases.iter_mut().for_each(|x| *x = -100.0);
+    }
+
+    pub fn allow_one(&mut self, tok: TokenId) {
+        self.all_disallowed();
+        self.logit_biases[tok as usize] = 0.0;
+    }
+
+    pub fn allow_eos(&mut self) {
+        self.allow_one(self.trie.special_token(SpecialToken::EndOfSentence));
+    }
 }
 
 pub trait AiciVm {
diff --git a/aici_abi/src/toktree.rs b/aici_abi/src/toktree.rs
index 0bde76cb..7d4e41ea 100644
--- a/aici_abi/src/toktree.rs
+++ b/aici_abi/src/toktree.rs
@@ -310,12 +310,12 @@ impl TokTrie {
         Some(n)
     }
 
-    pub fn compute_bias(&self, r: &mut dyn Recognizer, logits: &mut [f32]) {
+    pub fn compute_bias(&self, r: &mut impl Recognizer, logits: &mut [f32]) {
         logits.iter_mut().for_each(|x| *x = -100.0);
         self.add_bias(r, logits)
     }
 
-    pub fn add_bias(&self, r: &mut dyn Recognizer, logits: &mut [f32]) {
+    pub fn add_bias(&self, r: &mut impl Recognizer, logits: &mut [f32]) {
         for tok in vec![SpecialToken::EndOfSentence] {
             if r.special_allowed(tok) {
                 logits[self.special_token(tok) as usize] = 0.0;
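Patch 064's `all_disallowed` / `allow_one` pair is the whole mechanism behind fixed decoding steps: bias everything to -100, then put exactly one token back at 0. A hedged sketch of a controller that forces a scripted token sequence (the `FixedVm` type and its bookkeeping are illustrative, not from the patch):

    use aici_abi::{AiciVmHelper, TokenId};

    struct FixedVm {
        helper: AiciVmHelper,
        script: Vec<TokenId>, // e.g. filled via host::tokenize("Hello world")
        pos: usize,
    }

    impl FixedVm {
        fn step(&mut self) {
            self.helper.all_disallowed();
            if self.pos < self.script.len() {
                // Only the next scripted token survives the bias.
                self.helper.allow_one(self.script[self.pos]);
                self.pos += 1;
            } else {
                self.helper.allow_eos();
            }
        }
    }
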
From 4430041f3b8afea28544f024819bdfb6bb3c44e2 Mon Sep 17 00:00:00 2001
From: Michal Moskal
Date: Mon, 16 Oct 2023 20:32:27 +0000
Subject: [PATCH 065/301] fix tokenizers

---
 aici_abi/src/host.rs | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/aici_abi/src/host.rs b/aici_abi/src/host.rs
index 1d23332f..c409f782 100644
--- a/aici_abi/src/host.rs
+++ b/aici_abi/src/host.rs
@@ -1,6 +1,6 @@
 use std::io;
 
-use crate::bytes::TokenId;
+use crate::{bytes::TokenId, wprintln};
 
 #[allow(dead_code)]
 extern "C" {
@@ -114,7 +114,6 @@ pub fn trie_bytes() -> Vec<u8> {
 }
 
 pub fn tokenize(s: &str) -> Vec<TokenId> {
-    // fn aici_host_tokenize(src: *const u8, src_size: u32, dst: *mut u32, dst_size: u32) -> u32;
     let slen = s.len() as u32;
     let cap = slen / 3 + 10;
     let mut res = Vec::with_capacity(cap as usize);
@@ -127,6 +126,7 @@ pub fn tokenize(s: &str) -> Vec<TokenId> {
     unsafe {
         res.set_len(len as usize);
     }
+    wprintln!("tokenize: '{}' -> {:?}", s, res);
     // trim size
     res.clone()
 }

From a2b726e20c9f3429f8b3d255628267c278d83ecb Mon Sep 17 00:00:00 2001
From: Michal Moskal
Date: Mon, 16 Oct 2023 23:41:52 +0000
Subject: [PATCH 066/301] rework step handling

---
 aici_abi/src/lib.rs        |  1 -
 aici_abi/src/recognizer.rs |  4 ++--
 aici_abi/src/toktree.rs    | 36 ++++++++++++++++++++++++++++++------
 3 files changed, 32 insertions(+), 9 deletions(-)

diff --git a/aici_abi/src/lib.rs b/aici_abi/src/lib.rs
index d17bd46a..c73eee7f 100644
--- a/aici_abi/src/lib.rs
+++ b/aici_abi/src/lib.rs
@@ -70,7 +70,6 @@ impl AiciVmHelper {
     }
 
     pub fn allow_one(&mut self, tok: TokenId) {
-        self.all_disallowed();
         self.logit_biases[tok as usize] = 0.0;
     }
 
diff --git a/aici_abi/src/recognizer.rs b/aici_abi/src/recognizer.rs
index 40ddce33..a692c234 100644
--- a/aici_abi/src/recognizer.rs
+++ b/aici_abi/src/recognizer.rs
@@ -129,7 +129,7 @@ impl<S: Copy + Debug, R: FunctionalRecognizer<S>> Recognizer for StackRecognizer<S, R> {
     }
 
     #[inline(always)]
-    fn byte_allowed(&mut self, byte: u8) -> bool {
+    fn byte_allowed(&self, byte: u8) -> bool {
         self.rec.byte_allowed(self.stack[self.stack_ptr], byte)
     }
 
@@ -143,7 +143,7 @@ impl<S: Copy + Debug, R: FunctionalRecognizer<S>> Recognizer for StackRecognizer<S, R> {
         self.stack_ptr = 0;
     }
 
-    fn special_allowed(&mut self, tok: SpecialToken) -> bool {
+    fn special_allowed(&self, tok: SpecialToken) -> bool {
         self.rec.special_allowed(self.stack[self.stack_ptr], tok)
     }
 }
diff --git a/aici_abi/src/toktree.rs b/aici_abi/src/toktree.rs
index 7d4e41ea..46d6d8dd 100644
--- a/aici_abi/src/toktree.rs
+++ b/aici_abi/src/toktree.rs
@@ -25,9 +25,9 @@ pub trait Recognizer {
     /// X = stack.top(); stack.empty(); stack.push(X)
     fn collapse(&mut self);
     /// check if stack.top() transitions via byte to a viable state
-    fn byte_allowed(&mut self, byte: u8) -> bool;
+    fn byte_allowed(&self, byte: u8) -> bool;
     /// check if stack.top() transitions via tok to a viable state
-    fn special_allowed(&mut self, tok: SpecialToken) -> bool;
+    fn special_allowed(&self, tok: SpecialToken) -> bool;
     /// Called when iteration over the trie is finished
     /// Stack has exactly one element then.
     fn trie_finished(&mut self);
@@ -312,16 +312,40 @@ impl TokTrie {
 
     pub fn compute_bias(&self, r: &mut impl Recognizer, logits: &mut [f32]) {
         logits.iter_mut().for_each(|x| *x = -100.0);
-        self.add_bias(r, logits)
-    }
-
-    pub fn add_bias(&self, r: &mut impl Recognizer, logits: &mut [f32]) {
         for tok in vec![SpecialToken::EndOfSentence] {
             if r.special_allowed(tok) {
                 logits[self.special_token(tok) as usize] = 0.0;
             }
         }
+        self.add_bias(r, logits)
+    }
 
+    pub fn append_token(&self, r: &mut impl Recognizer, t: TokenId) {
+        let bytes = self.token(t);
+        for &byte in bytes {
+            r.push_byte(byte)
+        }
+        r.collapse()
+    }
+
+    pub fn token_allowed(&self, r: &mut impl Recognizer, t: TokenId) -> bool {
+        let bytes = self.token(t);
+        let mut num = 0;
+        let mut ok = true;
+        for &byte in bytes {
+            if r.byte_allowed(byte) {
+                r.push_byte(byte);
+                num += 1;
+            } else {
+                ok = false;
+                break;
+            }
+        }
+        r.pop_bytes(num);
+        ok
+    }
+
+    pub fn add_bias(&self, r: &mut impl Recognizer, logits: &mut [f32]) {
         let n = self.root();
         let defl_tok = self.vocab_size() as u32;
         let off = self.node_offset(n);
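`token_allowed` above is deliberately non-destructive: it pushes the token's bytes one at a time, and whether it succeeds or fails it pops exactly what it pushed, so the recognizer ends up where it started. Committing is a separate step. A sketch of the intended probe-then-commit pattern (the surrounding controller is assumed):

    use aici_abi::toktree::{Recognizer, TokTrie};
    use aici_abi::TokenId;

    fn accept_if_allowed(trie: &TokTrie, rec: &mut impl Recognizer, sampled: TokenId) -> bool {
        if trie.token_allowed(rec, sampled) {
            // Commit the bytes for real; collapse() folds the stack
            // back down to a single state.
            trie.append_token(rec, sampled);
            true
        } else {
            false
        }
    }
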
From b06486597cbc94f14140487cf590039c68ea260c Mon Sep 17 00:00:00 2001
From: Michal Moskal
Date: Fri, 20 Oct 2023 14:42:33 -0700
Subject: [PATCH 067/301] add try_push_byte()

---
 aici_abi/src/toktree.rs | 22 +++++++++++-----------
 1 file changed, 11 insertions(+), 11 deletions(-)

diff --git a/aici_abi/src/toktree.rs b/aici_abi/src/toktree.rs
index 46d6d8dd..0b4506b5 100644
--- a/aici_abi/src/toktree.rs
+++ b/aici_abi/src/toktree.rs
@@ -31,6 +31,15 @@ pub trait Recognizer {
     /// Called when iteration over the trie is finished
     /// Stack has exactly one element then.
     fn trie_finished(&mut self);
+
+    fn try_push_byte(&mut self, byte: u8) -> bool {
+        if self.byte_allowed(byte) {
+            self.push_byte(byte);
+            true
+        } else {
+            false
+        }
+    }
 }
 
 pub struct TokTrie {
@@ -333,8 +342,7 @@ impl TokTrie {
         let mut num = 0;
         let mut ok = true;
         for &byte in bytes {
-            if r.byte_allowed(byte) {
-                r.push_byte(byte);
+            if r.try_push_byte(byte) {
                 num += 1;
             } else {
                 ok = false;
                 break;
@@ -354,17 +362,9 @@ impl TokTrie {
         while p < endp {
             let n = &self.nodes[p];
             let b = n.byte();
-            if r.byte_allowed(b) {
+            if r.try_push_byte(b) {
                 logits[n.token_id().unwrap_or(defl_tok) as usize] = 0.0;
 
-                // This is slower due to branch mis-prediction:
-                // if n.subtree_size() == 1 {
-                //     r.pop_bytes(n.num_parents() - 1)
-                // } else {
-                //     r.push_byte(b)
-                // }
-
-                r.push_byte(b);
                 r.pop_bytes(if n.subtree_size() == 1 {
                     n.num_parents()
                 } else {
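The default `try_push_byte` simply fuses the `byte_allowed` check with `push_byte`, which is the hot pair in `add_bias`; implementations with a cheaper combined transition can override it. A toy recognizer using the default method (illustrative only, not from the patches): it accepts runs of ASCII digits, tracking the run length as its state.

    use aici_abi::toktree::{Recognizer, SpecialToken};

    struct DigitRun {
        stack: Vec<u32>, // digits consumed so far, one entry per pushed byte
    }

    impl DigitRun {
        fn new() -> Self {
            DigitRun { stack: vec![0] }
        }
    }

    impl Recognizer for DigitRun {
        fn push_byte(&mut self, byte: u8) {
            debug_assert!(byte.is_ascii_digit());
            let top = *self.stack.last().unwrap();
            self.stack.push(top + 1);
        }
        fn pop_bytes(&mut self, num: usize) {
            let keep = self.stack.len() - num;
            self.stack.truncate(keep);
        }
        fn collapse(&mut self) {
            let top = *self.stack.last().unwrap();
            self.stack.clear();
            self.stack.push(top);
        }
        fn byte_allowed(&self, byte: u8) -> bool {
            byte.is_ascii_digit()
        }
        fn special_allowed(&self, tok: SpecialToken) -> bool {
            matches!(tok, SpecialToken::EndOfSentence)
        }
        fn trie_finished(&mut self) {
            debug_assert_eq!(self.stack.len(), 1);
        }
        // try_push_byte() comes from the trait's default implementation.
    }
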
From 692b8526c57218acfb45831732e64b6cac0291fc Mon Sep 17 00:00:00 2001
From: Michal Moskal
Date: Sun, 22 Oct 2023 12:16:24 -0700
Subject: [PATCH 068/301] utility functions

---
 aici_abi/src/toktree.rs | 41 ++++++++++++++++++++++++++++++++++++++++-
 1 file changed, 40 insertions(+), 1 deletion(-)

diff --git a/aici_abi/src/toktree.rs b/aici_abi/src/toktree.rs
index 0b4506b5..ab72f3b8 100644
--- a/aici_abi/src/toktree.rs
+++ b/aici_abi/src/toktree.rs
@@ -18,7 +18,7 @@ pub enum SpecialToken {
 }
 
 pub trait Recognizer {
-    /// If `stack.top()` trasitions via `byte` to `X`, execute `stack.push(X)`.
+    /// If `stack.top()` transitions via `byte` to `X`, execute `stack.push(X)`.
     fn push_byte(&mut self, byte: u8);
     /// for _ in 0..num { stack.pop() }
     fn pop_bytes(&mut self, num: usize);
@@ -167,6 +167,14 @@ impl TokTrie {
         self.info.vocab_size as usize
     }
 
+    pub fn alloc_logits(&self) -> Vec<f32> {
+        vec![0.0; self.vocab_size() + 1]
+    }
+
+    pub fn token_str(&self, idx: u32) -> String {
+        String::from_utf8_lossy(self.token(idx)).to_string()
+    }
+
     pub fn token(&self, idx: u32) -> &[u8] {
         let off = self.token_offsets[idx as usize];
         let len = off & 0xff;
@@ -174,6 +182,37 @@ impl TokTrie {
         &self.token_data[off..(off + len as usize)]
     }
 
+    pub fn greedy_tokenize(&self, bytes: &[u8]) -> Vec<TokenId> {
+        let mut r = Vec::new();
+        if bytes.len() == 0 {
+            return r;
+        }
+
+        let mut n = self.root();
+        let mut last_tok = None;
+        let mut last_idx = 0;
+        let mut idx = 0;
+        while idx < bytes.len() {
+            match self.child_at_byte(n, bytes[idx]) {
+                Some(c) => {
+                    if let Some(tok) = c.token_id() {
+                        last_tok = Some(tok);
+                        last_idx = idx;
+                    }
+                    n = c;
+                }
+                None => {
+                    r.push(last_tok.unwrap());
+                    idx = last_idx;
+                    n = self.root();
+                }
+            }
+            idx = idx + 1;
+        }
+        r.push(last_tok.unwrap());
+        r
+    }
+
     pub fn token_id(&self, bytes: &[u8]) -> Option<TokenId> {
         let (tok, len) = self.prefix_token_id(bytes);
         // wprintln!("tok_id {:?} {:?} {:?} ", bytes, tok, len);
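`greedy_tokenize` walks the trie taking the longest token match at each position; that is not always identical to what the model's real BPE tokenizer would produce, but it is cheap, deterministic, and runs entirely in the guest. A usage sketch:

    use aici_abi::toktree::TokTrie;

    fn show_tokens(trie: &TokTrie, text: &str) {
        for t in trie.greedy_tokenize(text.as_bytes()) {
            aici_abi::wprintln!("{} -> {:?}", t, trie.token_str(t));
        }
    }
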
From 167436a9970015739a5d68eaabc17a96675b043c Mon Sep 17 00:00:00 2001
From: Michal Moskal
Date: Sun, 22 Oct 2023 12:16:43 -0700
Subject: [PATCH 069/301] add simple rng

---
 aici_abi/src/lib.rs |  1 +
 aici_abi/src/rng.rs | 34 ++++++++++++++++++++++++++++++
 2 files changed, 35 insertions(+)
 create mode 100644 aici_abi/src/rng.rs

diff --git a/aici_abi/src/lib.rs b/aici_abi/src/lib.rs
index c73eee7f..f0ee0776 100644
--- a/aici_abi/src/lib.rs
+++ b/aici_abi/src/lib.rs
@@ -8,6 +8,7 @@ pub mod recognizer;
 pub mod rx;
 pub mod rxvm;
 pub mod toktree;
+pub mod rng;
 
 pub type TokenId = bytes::TokenId;
 
diff --git a/aici_abi/src/rng.rs b/aici_abi/src/rng.rs
new file mode 100644
index 00000000..add3a1ec
--- /dev/null
+++ b/aici_abi/src/rng.rs
@@ -0,0 +1,34 @@
+pub struct Rng {
+    state: usize,
+}
+
+impl Rng {
+    pub fn new(seed: usize) -> Self {
+        Self {
+            state: if seed == 0 { 13 } else { seed },
+        }
+    }
+
+    pub fn gen(&mut self) -> usize {
+        // xor-shift algorithm
+        let mut x = self.state;
+        x ^= x << 13;
+        x ^= x >> 17;
+        x ^= x << 5;
+        self.state = x;
+        x
+    }
+
+    pub fn gen_up_to(&mut self, mx: usize) -> usize {
+        let mut mask = 1;
+        while mask < mx {
+            mask = (mask << 1) | 1;
+        }
+        loop {
+            let r = self.gen() & mask;
+            if r <= mx {
+                return r;
+            }
+        }
+    }
+}

From d7010b1e7cc18785fafbe0146b20f05c31e3e4cc Mon Sep 17 00:00:00 2001
From: Michal Moskal
Date: Sun, 22 Oct 2023 12:38:51 -0700
Subject: [PATCH 070/301] add 64-bit version of rng

---
 aici_abi/src/rng.rs | 24 ++++++++++++++++++------
 1 file changed, 18 insertions(+), 6 deletions(-)

diff --git a/aici_abi/src/rng.rs b/aici_abi/src/rng.rs
index add3a1ec..89e8fef7 100644
--- a/aici_abi/src/rng.rs
+++ b/aici_abi/src/rng.rs
@@ -11,12 +11,24 @@ impl Rng {
 
     pub fn gen(&mut self) -> usize {
         // xor-shift algorithm
-        let mut x = self.state;
-        x ^= x << 13;
-        x ^= x >> 17;
-        x ^= x << 5;
-        self.state = x;
-        x
+        #[cfg(all(target_pointer_width = "32"))]
+        {
+            let mut x = self.state;
+            x ^= x << 13;
+            x ^= x >> 17;
+            x ^= x << 5;
+            self.state = x;
+            x
+        }
+        #[cfg(all(target_pointer_width = "64"))]
+        {
+            let mut x = self.state;
+            x ^= x << 13;
+            x ^= x >> 7;
+            x ^= x << 17;
+            self.state = x;
+            x
+        }
     }
 
     pub fn gen_up_to(&mut self, mx: usize) -> usize {

From 85d908699deab2292eaf42cd04f69c6483fafd09 Mon Sep 17 00:00:00 2001
From: Michal Moskal
Date: Sun, 22 Oct 2023 16:07:06 -0700
Subject: [PATCH 071/301] use rustc-hash (~2x faster)

---
 aici_abi/src/rng.rs | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/aici_abi/src/rng.rs b/aici_abi/src/rng.rs
index 89e8fef7..dbd0010b 100644
--- a/aici_abi/src/rng.rs
+++ b/aici_abi/src/rng.rs
@@ -11,7 +11,7 @@ impl Rng {
 
     pub fn gen(&mut self) -> usize {
         // xor-shift algorithm
-        #[cfg(all(target_pointer_width = "32"))]
+        #[cfg(target_pointer_width = "32")]
         {
             let mut x = self.state;
             x ^= x << 13;
@@ -20,7 +20,7 @@ impl Rng {
             self.state = x;
             x
         }
-        #[cfg(all(target_pointer_width = "64"))]
+        #[cfg(target_pointer_width = "64")]
        {
             let mut x = self.state;
             x ^= x << 13;
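The xorshift generator above is deterministic and dependency-free, which matters inside a WASM guest. Note that `gen_up_to(mx)` is inclusive: it rejection-samples against the smallest all-ones mask covering `mx` and returns values `r <= mx`, so picking a random index needs `len() - 1`. A usage sketch (the seed choice is arbitrary):

    use aici_abi::rng::Rng;

    fn pick<'a>(options: &[&'a str], seed: usize) -> &'a str {
        let mut rng = Rng::new(seed);
        // gen_up_to() is inclusive, so pass the last valid index.
        options[rng.gen_up_to(options.len() - 1)]
    }
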
From b3503985f5c01c4cfe8deb17a3d7986d8d58b79c Mon Sep 17 00:00:00 2001
From: Michal Moskal
Date: Mon, 23 Oct 2023 10:01:28 -0700
Subject: [PATCH 072/301] explicit stack of pstacks

---
 aici_abi/src/recognizer.rs | 40 ++++++++++++--------------------------
 aici_abi/src/toktree.rs    | 26 ++++++++++++++-----------
 2 files changed, 27 insertions(+), 39 deletions(-)

diff --git a/aici_abi/src/recognizer.rs b/aici_abi/src/recognizer.rs
index a692c234..9b9400f8 100644
--- a/aici_abi/src/recognizer.rs
+++ b/aici_abi/src/recognizer.rs
@@ -5,32 +5,6 @@ use crate::{
     wprintln, AiciVm, AiciVmHelper,
 };
 
-pub struct LenExcluder {}
-
-impl FunctionalRecognizer<u32> for LenExcluder {
-    fn initial(&self) -> u32 {
-        0
-    }
-
-    #[inline(never)]
-    fn append(&self, state: u32, _byte: u8) -> u32 {
-        state + 1
-    }
-
-    #[inline(never)]
-    fn byte_allowed(&self, state: u32, byte: u8) -> bool {
-        byte != (('z' as u32 + state) & 0xff) as u8
-    }
-
-    #[inline(never)]
-    fn special_allowed(&self, state: u32, tok: SpecialToken) -> bool {
-        match tok {
-            SpecialToken::EndOfSentence => state < 10,
-            _ => false,
-        }
-    }
-}
-
 pub struct AiciRecognizer<R: Recognizer> {
     pub helper: AiciVmHelper,
     pub rec: R,
     pub trie: Rc<Box<TokTrie>>,
 }
@@ -129,7 +103,7 @@ impl<S: Copy + Debug, R: FunctionalRecognizer<S>> Recognizer for StackRecognizer<S, R> {
     }
 
     #[inline(always)]
-    fn byte_allowed(&self, byte: u8) -> bool {
+    fn byte_allowed(&mut self, byte: u8) -> bool {
         self.rec.byte_allowed(self.stack[self.stack_ptr], byte)
     }
 
@@ -143,9 +117,19 @@ impl<S: Copy + Debug, R: FunctionalRecognizer<S>> Recognizer for StackRecognizer<S, R> {
         self.stack_ptr = 0;
     }
 
-    fn special_allowed(&self, tok: SpecialToken) -> bool {
+    fn special_allowed(&mut self, tok: SpecialToken) -> bool {
         self.rec.special_allowed(self.stack[self.stack_ptr], tok)
     }
+
+    #[inline(always)]
+    fn try_push_byte(&mut self, byte: u8) -> bool {
+        if self.rec.byte_allowed(self.stack[self.stack_ptr], byte) {
+            self.push_byte(byte);
+            true
+        } else {
+            false
+        }
+    }
 }
diff --git a/aici_abi/src/toktree.rs b/aici_abi/src/toktree.rs
index ab72f3b8..55eb1350 100644
--- a/aici_abi/src/toktree.rs
+++ b/aici_abi/src/toktree.rs
@@ -19,27 +19,31 @@ pub enum SpecialToken {
 }
 
 pub trait Recognizer {
     /// If `stack.top()` transitions via `byte` to `X`, execute `stack.push(X)`.
-    fn push_byte(&mut self, byte: u8);
+    fn push_byte(&mut self, byte: u8) {
+        if !self.try_push_byte(byte) {
+            panic!("byte {:?} not allowed", byte as char)
+        }
+    }
     /// for _ in 0..num { stack.pop() }
     fn pop_bytes(&mut self, num: usize);
     /// X = stack.top(); stack.empty(); stack.push(X)
     fn collapse(&mut self);
     /// check if stack.top() transitions via byte to a viable state
-    fn byte_allowed(&self, byte: u8) -> bool;
+    fn byte_allowed(&mut self, byte: u8) -> bool {
+        if self.try_push_byte(byte) {
+            self.pop_bytes(1);
+            true
+        } else {
+            false
+        }
+    }
     /// check if stack.top() transitions via tok to a viable state
-    fn special_allowed(&self, tok: SpecialToken) -> bool;
-    /// Called when iteration over the trie is finished
-    /// Stack has exactly one element then.
-    fn trie_finished(&mut self);
-
-    fn try_push_byte(&mut self, byte: u8) -> bool {
-        if self.byte_allowed(byte) {
-            self.push_byte(byte);
-            true
-        } else {
-            false
-        }
-    }
+    fn special_allowed(&mut self, tok: SpecialToken) -> bool;
+    /// Called when iteration over the trie is finished
+    /// Stack has exactly one element then.
+    fn trie_finished(&mut self);
+
+    fn try_push_byte(&mut self, byte: u8) -> bool;
 }
From c826d88301ccd12fe5295048424c8aadec987b5d Mon Sep 17 00:00:00 2001
From: Michal Moskal
Date: Mon, 23 Oct 2023 22:46:13 +0000
Subject: [PATCH 073/301] add docs for yacc

---
 aici_abi/src/toktree.rs | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/aici_abi/src/toktree.rs b/aici_abi/src/toktree.rs
index 55eb1350..2d2b0312 100644
--- a/aici_abi/src/toktree.rs
+++ b/aici_abi/src/toktree.rs
@@ -42,7 +42,7 @@ pub trait Recognizer {
     /// Called when iteration over the trie is finished
     /// Stack has exactly one element then.
     fn trie_finished(&mut self);
-
+    /// This combines `push_byte` and `byte_allowed` into one function for performance.
     fn try_push_byte(&mut self, byte: u8) -> bool;
 }

From 690df142c5fc8367650a44e032bbe26ba5ad11eb Mon Sep 17 00:00:00 2001
From: Michal Moskal
Date: Fri, 27 Oct 2023 09:43:52 -0700
Subject: [PATCH 074/301] add AllowToken trait

---
 aici_abi/src/toktree.rs | 23 ++++++++++++++++++---
 1 file changed, 20 insertions(+), 3 deletions(-)

diff --git a/aici_abi/src/toktree.rs b/aici_abi/src/toktree.rs
index 2d2b0312..5db3088a 100644
--- a/aici_abi/src/toktree.rs
+++ b/aici_abi/src/toktree.rs
@@ -396,7 +396,7 @@ impl TokTrie {
         ok
     }
 
-    pub fn add_bias(&self, r: &mut impl Recognizer, logits: &mut [f32]) {
+    pub fn add_bias(&self, r: &mut impl Recognizer, mut logits: impl AllowToken) {
         let n = self.root();
         let defl_tok = self.vocab_size() as u32;
         let off = self.node_offset(n);
@@ -406,8 +406,7 @@ impl TokTrie {
             let n = &self.nodes[p];
             let b = n.byte();
             if r.try_push_byte(b) {
-                logits[n.token_id().unwrap_or(defl_tok) as usize] = 0.0;
-
+                logits.allow_token(n.token_id().unwrap_or(defl_tok));
                 r.pop_bytes(if n.subtree_size() == 1 {
                     n.num_parents()
                 } else {
@@ -424,6 +423,24 @@ impl TokTrie {
     }
 }
 
+pub trait AllowToken {
+    fn allow_token(&mut self, tok: TokenId);
+}
+
+impl AllowToken for &mut [f32] {
+    #[inline(always)]
+    fn allow_token(&mut self, tok: TokenId) {
+        self[tok as usize] = 0.0;
+    }
+}
+
+impl AllowToken for &mut Vec<f32> {
+    #[inline(always)]
+    fn allow_token(&mut self, tok: TokenId) {
+        self[tok as usize] = 0.0;
+    }
+}
+
 pub struct NodeChildren<'a> {
     trie: &'a TokTrie,
     current_offset: usize,

From 1dbefdf0eae6debd3c78a60afcb8103cee76ff74 Mon Sep 17 00:00:00 2001
From: Michal Moskal
Date: Fri, 27 Oct 2023 09:46:39 -0700
Subject: [PATCH 075/301] add AllowToken for [u8]

---
 aici_abi/src/toktree.rs | 7 +++++++
 1 file changed, 7 insertions(+)

diff --git a/aici_abi/src/toktree.rs b/aici_abi/src/toktree.rs
index 5db3088a..f4fa6cdd 100644
--- a/aici_abi/src/toktree.rs
+++ b/aici_abi/src/toktree.rs
@@ -441,6 +441,13 @@ impl AllowToken for &mut Vec<f32> {
     }
 }
 
+impl AllowToken for &mut [u8] {
+    #[inline(always)]
+    fn allow_token(&mut self, tok: TokenId) {
+        self[tok as usize] = 1;
+    }
+}
+
 pub struct NodeChildren<'a> {
     trie: &'a TokTrie,
     current_offset: usize,
From 164b93eea29e6d5ee9ee5740ddd21a1c49170870 Mon Sep 17 00:00:00 2001
From: Michal Moskal
Date: Fri, 27 Oct 2023 17:28:12 +0000
Subject: [PATCH 076/301] simplify allowed-tokens

---
 aici_abi/src/lib.rs        | 16 +++++++++--
 aici_abi/src/recognizer.rs |  3 +-
 aici_abi/src/svob.rs       | 59 ++++++++++++++++++++++++++++++++++++++
 aici_abi/src/toktree.rs    | 36 ++++-------------------
 4 files changed, 80 insertions(+), 34 deletions(-)
 create mode 100644 aici_abi/src/svob.rs

diff --git a/aici_abi/src/lib.rs b/aici_abi/src/lib.rs
index f0ee0776..7033d8f3 100644
--- a/aici_abi/src/lib.rs
+++ b/aici_abi/src/lib.rs
@@ -1,14 +1,16 @@
 use std::rc::Rc;
 
+use svob::SimpleVob;
 use toktree::{SpecialToken, TokTrie};
 
 pub mod bytes;
 pub mod host;
 pub mod recognizer;
+pub mod rng;
 pub mod rx;
 pub mod rxvm;
+pub mod svob;
 pub mod toktree;
-pub mod rng;
 
 pub type TokenId = bytes::TokenId;
 
@@ -41,6 +43,7 @@ pub struct AiciVmHelper {
     pub tokens: Vec<u32>,
     pub prompt_length: usize,
     pub logit_biases: Vec<f32>,
+    pub allowed_tokens: SimpleVob,
     pub trie: Rc<Box<TokTrie>>,
 }
 
@@ -51,6 +54,7 @@ impl AiciVmHelper {
             tokens: Vec::new(),
             prompt_length: 0,
             logit_biases: Vec::new(),
+            allowed_tokens: SimpleVob::new(),
             trie: Rc::new(Box::new(TokTrie::from_host())),
         }
     }
@@ -58,6 +62,7 @@ impl AiciVmHelper {
         // we keep one more logit at the end as a placeholder to avoid branching in
         // the inner loop of append_bias
         self.logit_biases.resize((size + 1) as usize, 0.0);
+        self.allowed_tokens.resize(self.logit_biases.len());
         self.logit_biases.as_mut_ptr()
     }
     pub fn aici_get_prompt_buffer(&mut self, size: u32) -> *mut u32 {
@@ -67,16 +72,21 @@ impl AiciVmHelper {
     }
 
     pub fn all_disallowed(&mut self) {
-        self.logit_biases.iter_mut().for_each(|x| *x = -100.0);
+        self.allowed_tokens.set_all(false);
     }
 
     pub fn allow_one(&mut self, tok: TokenId) {
-        self.logit_biases[tok as usize] = 0.0;
+        self.allowed_tokens.allow_token(tok);
     }
 
     pub fn allow_eos(&mut self) {
         self.allow_one(self.trie.special_token(SpecialToken::EndOfSentence));
     }
+
+    pub fn compute_biases(&mut self) {
+        self.logit_biases.iter_mut().for_each(|x| *x = -100.0);
+        self.allowed_tokens.apply_to(&mut self.logit_biases);
+    }
 }
diff --git a/aici_abi/src/recognizer.rs b/aici_abi/src/recognizer.rs
index 9b9400f8..6ca7b763 100644
--- a/aici_abi/src/recognizer.rs
+++ b/aici_abi/src/recognizer.rs
@@ -23,7 +23,8 @@ impl<R: Recognizer + Clone> AiciRecognizer<R> {
     fn compute(&mut self) {
         // wprintln!("compute");
         self.trie
-            .compute_bias(&mut self.rec, &mut self.helper.logit_biases);
+            .compute_bias(&mut self.rec, &mut self.helper.allowed_tokens);
+        self.helper.compute_biases();
     }
 }
diff --git a/aici_abi/src/svob.rs b/aici_abi/src/svob.rs
new file mode 100644
index 00000000..319e3392
--- /dev/null
+++ b/aici_abi/src/svob.rs
@@ -0,0 +1,59 @@
+use crate::TokenId;
+
+#[derive(Clone)]
+pub struct SimpleVob {
+    data: Vec<u32>,
+}
+
+const BITS: usize = 32;
+
+impl SimpleVob {
+    pub fn new() -> Self {
+        Self { data: Vec::new() }
+    }
+
+    pub fn len(&self) -> usize {
+        self.data.len() * BITS
+    }
+
+    #[inline(always)]
+    pub fn allow_token(&mut self, tok: TokenId) {
+        let idx = tok as usize;
+        let byte_idx = idx / BITS;
+        let bit_idx = idx % BITS;
+        self.data[byte_idx] |= 1 << bit_idx;
+    }
+
+    pub fn resize(&mut self, size: usize) {
+        let new_size = size / BITS + 1;
+        assert!(new_size >= self.data.len());
+        self.data.resize(new_size, 0);
+    }
+
+    #[inline(always)]
+    pub fn is_allowed(&self, tok: TokenId) -> bool {
+        let idx = tok as usize;
+        let byte_idx = idx / 32;
+        let bit_idx = idx % 32;
+        (self.data[byte_idx] & (1 << bit_idx)) != 0
+    }
+
+    pub fn set_all(&mut self, val: bool) {
+        let val = if val { !0 } else { 0 };
+        self.data.iter_mut().for_each(|x| *x = val);
+    }
+
+    pub fn apply_to(&self, logits: &mut [f32]) {
+        for (idx, v) in self.data.iter().enumerate() {
+            if *v == 0 {
+                continue;
+            }
+            let idx = idx * BITS;
+            for bit_idx in 0..BITS {
+                if v & (1 << bit_idx) != 0 {
+                    logits[idx + bit_idx] = 0.0;
+                }
+            }
+        }
+    }
+}
diff --git a/aici_abi/src/toktree.rs b/aici_abi/src/toktree.rs
index f4fa6cdd..b704f091 100644
--- a/aici_abi/src/toktree.rs
+++ b/aici_abi/src/toktree.rs
@@ -6,6 +6,7 @@ use crate::{
         box_from_bytes, clone_as_bytes, clone_vec_as_bytes, vec_from_bytes, TokRxInfo, TokenId,
     },
     host::trie_bytes,
+    svob::SimpleVob,
 };
 
 #[derive(Clone, Copy, Debug, PartialEq, Eq)]
@@ -362,11 +363,11 @@ impl TokTrie {
         Some(n)
     }
 
-    pub fn compute_bias(&self, r: &mut impl Recognizer, logits: &mut [f32]) {
-        logits.iter_mut().for_each(|x| *x = -100.0);
+    pub fn compute_bias(&self, r: &mut impl Recognizer, logits: &mut SimpleVob) {
+        logits.set_all(false);
         for tok in vec![SpecialToken::EndOfSentence] {
             if r.special_allowed(tok) {
-                logits[self.special_token(tok) as usize] = 0.0;
+                logits.allow_token(self.special_token(tok))
             }
         }
         self.add_bias(r, logits)
@@ -396,7 +397,7 @@ impl TokTrie {
         ok
     }
 
-    pub fn add_bias(&self, r: &mut impl Recognizer, mut logits: impl AllowToken) {
+    pub fn add_bias(&self, r: &mut impl Recognizer, toks: &mut SimpleVob) {
         let n = self.root();
         let defl_tok = self.vocab_size() as u32;
         let off = self.node_offset(n);
@@ -406,8 +407,7 @@ impl TokTrie {
             let n = &self.nodes[p];
             let b = n.byte();
             if r.try_push_byte(b) {
-                logits.allow_token(n.token_id().unwrap_or(defl_tok));
+                toks.allow_token(n.token_id().unwrap_or(defl_tok));
                 r.pop_bytes(if n.subtree_size() == 1 {
                     n.num_parents()
                 } else {
@@ -424,31 +424,6 @@ impl TokTrie {
     }
 }
 
-pub trait AllowToken {
-    fn allow_token(&mut self, tok: TokenId);
-}
-
-impl AllowToken for &mut [f32] {
-    #[inline(always)]
-    fn allow_token(&mut self, tok: TokenId) {
-        self[tok as usize] = 0.0;
-    }
-}
-
-impl AllowToken for &mut Vec<f32> {
-    #[inline(always)]
-    fn allow_token(&mut self, tok: TokenId) {
-        self[tok as usize] = 0.0;
-    }
-}
-
-impl AllowToken for &mut [u8] {
-    #[inline(always)]
-    fn allow_token(&mut self, tok: TokenId) {
-        self[tok as usize] = 1;
-    }
-}
-
 pub struct NodeChildren<'a> {
     trie: &'a TokTrie,
     current_offset: usize,
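`SimpleVob` packs one bit per token into `u32` words; `apply_to` only raises the allowed entries to 0.0 and leaves everything else alone, so the caller pre-fills the buffer with the negative bias. A sketch of the intended sequence (the vocabulary size and token id are placeholders):

    use aici_abi::svob::SimpleVob;

    fn eos_only_bias(vocab_size: usize, eos: u32) -> Vec<f32> {
        let mut allowed = SimpleVob::new();
        allowed.resize(vocab_size + 1); // one spare slot, as the helper does
        allowed.set_all(false);
        allowed.allow_token(eos);
        // Pre-fill with -100.0, then let apply_to() open the allowed bits.
        let mut biases = vec![-100.0f32; allowed.len()];
        allowed.apply_to(&mut biases);
        biases
    }
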
From 94930e3ead9343c307c15cacd1dc247d769ef1db Mon Sep 17 00:00:00 2001
From: Michal Moskal
Date: Fri, 27 Oct 2023 17:57:02 +0000
Subject: [PATCH 077/301] remove more unused code

---
 aici_abi/src/lib.rs  |   2 -
 aici_abi/src/rx.rs   | 156 -------------------------------------------
 aici_abi/src/rxvm.rs |  45 -------------
 3 files changed, 203 deletions(-)
 delete mode 100644 aici_abi/src/rx.rs
 delete mode 100644 aici_abi/src/rxvm.rs

diff --git a/aici_abi/src/lib.rs b/aici_abi/src/lib.rs
index 7033d8f3..741b54ea 100644
--- a/aici_abi/src/lib.rs
+++ b/aici_abi/src/lib.rs
@@ -7,8 +7,6 @@ pub mod bytes;
 pub mod host;
 pub mod recognizer;
 pub mod rng;
-pub mod rx;
-pub mod rxvm;
 pub mod svob;
 pub mod toktree;
 
diff --git a/aici_abi/src/rx.rs b/aici_abi/src/rx.rs
deleted file mode 100644
index ccfd904e..00000000
--- a/aici_abi/src/rx.rs
+++ /dev/null
@@ -1,156 +0,0 @@
-use std::{mem::size_of, slice::from_raw_parts};
-
-use crate::bytes::{clone_as_bytes, clone_vec_as_bytes, TokRxInfo};
-
-pub type TokenId = crate::bytes::TokenId;
-pub type Transition = (StateOffset, TokenSetOffset);
-
-#[derive(Clone, Copy, PartialEq, Eq)]
-pub struct TokenSetOffset {
-    pub off: u32,
-}
-
-pub struct StateDesc {
-    default_transition: StateOffset,
-    transitions: &'static [Transition],
-}
-
-#[derive(Clone, Copy, PartialEq, Eq)]
-pub struct StateOffset {
-    pub off: u32,
-}
-
-impl StateOffset {
-    pub const DEAD: StateOffset = StateOffset { off: 1 };
-    pub const START: StateOffset = StateOffset { off: 3 };
-}
-
-#[repr(C)]
-struct TokRxHeader {
-    magic: u32,
-    hd_size: u32,
-    state_bytes: u32,
-    token_bytes: u32,
-    info: TokRxInfo,
-    align: [u32; 0],
-}
-
-impl TokRxHeader {
-    pub const MAGIC: u32 = 0x6623f10b;
-    pub const SIZE: u32 = size_of::<TokRxHeader>() as u32;
-}
-
-#[derive(Clone)]
-pub struct TokRx {
-    pub info: &'static TokRxInfo,
-    pub token_data: &'static [TokenId],
-    pub state_data: &'static [u32],
-}
-
-impl TokRx {
-    pub fn deserialize(bytes: &'static [u8]) -> TokRx {
-        unsafe {
-            assert!(bytes.len() > TokRxHeader::SIZE as usize);
-            let hd = (bytes.as_ptr() as *const TokRxHeader).as_ref().unwrap();
-            assert!(hd.magic == TokRxHeader::MAGIC);
-            assert!(hd.hd_size == TokRxHeader::SIZE);
-            let state_data = from_raw_parts(
-                bytes.as_ptr().add(TokRxHeader::SIZE as usize) as *const u32,
-                hd.state_bytes as usize / size_of::<u32>(),
-            );
-            let token_data = from_raw_parts(
-                bytes
-                    .as_ptr()
-                    .add((TokRxHeader::SIZE + hd.state_bytes) as usize)
-                    as *const TokenId,
-                hd.token_bytes as usize / size_of::<TokenId>(),
-            );
-            TokRx {
-                info: &hd.info,
-                state_data,
-                token_data,
-            }
-        }
-    }
-
-    pub fn serialize(
-        info: &TokRxInfo,
-        token_data: &Vec<TokenId>,
-        state_data: &Vec<u32>,
-    ) -> Vec<u8> {
-        let mut token_bytes = clone_vec_as_bytes(&token_data);
-        let mut state_bytes = clone_vec_as_bytes(&state_data);
-        let hd = TokRxHeader {
-            magic: TokRxHeader::MAGIC,
-            hd_size: TokRxHeader::SIZE,
-            info: info.clone(),
-            state_bytes: state_bytes.len() as u32,
-            token_bytes: token_bytes.len() as u32,
-            align: [],
-        };
-        let mut bytes = clone_as_bytes(&hd);
-        bytes.append(&mut state_bytes);
-        bytes.append(&mut token_bytes);
-        bytes
-    }
-
-    fn token_set(&self, set: TokenSetOffset) -> &'static [TokenId] {
-        let idx = set.off as usize;
-        let sz = self.token_data[idx] as usize;
-        unsafe { from_raw_parts(self.token_data.as_ptr().add(idx + 1), sz) }
-    }
-
-    fn state_desc(&self, state: StateOffset) -> StateDesc {
-        let idx = state.off as usize;
-        let default_transition = StateOffset {
-            off: self.state_data[idx],
-        };
-        let sz = self.state_data[idx + 1] as usize;
-        StateDesc {
-            default_transition,
-            transitions: unsafe {
-                from_raw_parts(
-                    self.state_data.as_ptr().add(idx + 2) as *const Transition,
-                    sz,
-                )
-            },
-        }
-    }
-
-    fn state_bias(state: StateOffset) -> f32 {
-        if state == StateOffset::DEAD {
-            -100.0
-        } else {
-            0.0
-        }
-    }
-
-    pub fn compute_logit_bias(&self, state_offset: StateOffset, bias: &mut [f32]) {
-        let state = self.state_desc(state_offset);
-
-        let init_val = Self::state_bias(state.default_transition);
-        for idx in 0..bias.len() {
-            bias[idx] = init_val;
-        }
-
-        for (st, ts) in state.transitions {
-            let val = Self::state_bias(*st);
-            let toks = self.token_set(*ts);
-            for tok in toks {
-                bias[*tok as usize] = val;
-            }
-        }
-    }
-
-    pub fn advance(&self, state_offset: StateOffset, token: TokenId) -> StateOffset {
-        let state = self.state_desc(state_offset);
-
-        for (st, ts) in state.transitions {
-            if self.token_set(*ts).contains(&token) {
-                return *st;
-            }
-        }
-
-        state.default_transition
-    }
-}
diff --git a/aici_abi/src/rxvm.rs b/aici_abi/src/rxvm.rs
deleted file mode 100644
index c7ac6096..00000000
--- a/aici_abi/src/rxvm.rs
+++ /dev/null
@@ -1,45 +0,0 @@
-use crate::rx::{StateOffset, TokRx};
-use crate::{wprintln, AiciVm, AiciVmHelper};
-
-pub struct RxAici {
-    pub helper: AiciVmHelper,
-    pub compiled: TokRx,
-    pub state: StateOffset,
-}
-
-impl RxAici {
-    pub fn from_token_compiled(compiled: TokRx) -> Self {
-        RxAici {
-            helper: AiciVmHelper::new(),
-            compiled,
-            state: StateOffset::START,
-        }
-    }
-}
-
-impl AiciVm for RxAici {
-    fn aici_process_prompt(&mut self) {
-        wprintln!("prompt, {} tokens", self.helper.prompt_length);
-        // the regex doesn't care about the prompt
-        self.state = StateOffset::START;
-        self.compiled
-            .compute_logit_bias(self.state, &mut self.helper.logit_biases);
-    }
-
-    fn aici_append_token(&mut self, token: u32) {
-        // wprintln!("xapp {:?} {} {}", self as *const _, token, self.state.off);
-        self.state = self.compiled.advance(self.state, token);
-
-        // save the token, just in case
-        let toks = &mut self.helper.tokens;
-        toks.push(token);
-
-        // compute biases
-        self.compiled
-            .compute_logit_bias(self.state, &mut self.helper.logit_biases);
-    }
-
-    fn get_helper(&mut self) -> &mut AiciVmHelper {
-        &mut self.helper
-    }
-}
From 24230b8a95f4d115b4e5b0f3be737104c73bb652 Mon Sep 17 00:00:00 2001
From: Michal Moskal
Date: Fri, 27 Oct 2023 18:32:16 +0000
Subject: [PATCH 078/301] perf work; add disasm script

---
 aici_abi/src/toktree.rs | 1 +
 1 file changed, 1 insertion(+)

diff --git a/aici_abi/src/toktree.rs b/aici_abi/src/toktree.rs
index b704f091..396fb3e4 100644
--- a/aici_abi/src/toktree.rs
+++ b/aici_abi/src/toktree.rs
@@ -397,6 +397,7 @@ impl TokTrie {
         ok
     }
 
+    #[inline(never)]
     pub fn add_bias(&self, r: &mut impl Recognizer, toks: &mut SimpleVob) {
         let n = self.root();
         let defl_tok = self.vocab_size() as u32;

From b17e195c79604f9d51ac2026e5f65b25134802af Mon Sep 17 00:00:00 2001
From: Michal Moskal
Date: Wed, 1 Nov 2023 23:21:56 +0000
Subject: [PATCH 079/301] initial work on ff tokens

---
 aici_abi/src/host.rs | 13 +++++++++++++
 1 file changed, 13 insertions(+)

diff --git a/aici_abi/src/host.rs b/aici_abi/src/host.rs
index c409f782..c9a04251 100644
--- a/aici_abi/src/host.rs
+++ b/aici_abi/src/host.rs
@@ -16,6 +16,12 @@ extern "C" {
 
     // Tokenize given UTF8 string. `dst_size` is in elements, not bytes. Returns number of generated tokens.
     fn aici_host_tokenize(src: *const u8, src_size: u32, dst: *mut u32, dst_size: u32) -> u32;
+
+    // Append fast-forward (FF) token.
+    // First FF token has to be returned by setting logit bias appropriately.
+    // Next tokens are added using this interface.
+    // All FF tokens are then generated in one go.
+    fn aici_host_ff_token(token: u32);
 }
 
 #[cfg(not(target_arch = "wasm32"))]
@@ -130,3 +136,10 @@ pub fn tokenize(s: &str) -> Vec<TokenId> {
     // trim size
     res.clone()
 }
+
+pub fn ff_token(token: TokenId) {
+    unsafe {
+        aici_host_ff_token(token);
+    }
+}
+
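The fast-forward protocol is split in two: the first forced token still travels through the logit bias (so the host samples it as usual), while subsequent ones are queued with `ff_token` and appended in a single batch. A hedged sketch of how these calls are meant to combine, using the helper methods from the earlier patches:

    use aici_abi::{host, AiciVmHelper, TokenId};

    fn force_tokens(helper: &mut AiciVmHelper, toks: &[TokenId]) {
        // First token: bias everything out except toks[0].
        helper.all_disallowed();
        helper.allow_one(toks[0]);
        helper.compute_biases();
        // Remaining tokens: fast-forwarded host-side in one go.
        for &t in &toks[1..] {
            host::ff_token(t);
        }
    }
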
From f7783c660ca69c2b329946ac3d5302ba36390b7f Mon Sep 17 00:00:00 2001
From: Michal Moskal
Date: Thu, 2 Nov 2023 23:26:45 +0000
Subject: [PATCH 080/301] generalizing APIs

---
 aici_abi/src/bytes.rs      | 13 +++++-
 aici_abi/src/host.rs       | 86 +++++++++++++++++++++-----------------
 aici_abi/src/lib.rs        | 42 ++-----------------
 aici_abi/src/recognizer.rs | 35 ++++++++--------
 4 files changed, 80 insertions(+), 96 deletions(-)

diff --git a/aici_abi/src/bytes.rs b/aici_abi/src/bytes.rs
index 11b45dc2..86a53c12 100644
--- a/aici_abi/src/bytes.rs
+++ b/aici_abi/src/bytes.rs
@@ -9,7 +9,6 @@ pub struct TokRxInfo {
     pub tok_eos: TokenId,
 }
 
-
 pub fn clone_vec_as_bytes<T>(input: &[T]) -> Vec<u8> {
     unsafe {
         let byte_slice = from_raw_parts(input.as_ptr() as *const u8, input.len() * size_of::<T>());
         byte_slice.to_vec()
@@ -51,3 +50,15 @@ pub fn vec_from_bytes<T>(bytes: &[u8]) -> Vec<T> {
     }
     result
 }
+
+pub fn limit_str(s: &str, max_len: usize) -> String {
+    limit_bytes(s.as_bytes(), max_len)
+}
+
+pub fn limit_bytes(s: &[u8], max_len: usize) -> String {
+    if s.len() > max_len {
+        format!("{}...", String::from_utf8_lossy(&s[0..max_len]))
+    } else {
+        String::from_utf8_lossy(s).to_string()
+    }
+}
diff --git a/aici_abi/src/host.rs b/aici_abi/src/host.rs
index c9a04251..7ec83c8b 100644
--- a/aici_abi/src/host.rs
+++ b/aici_abi/src/host.rs
@@ -1,21 +1,34 @@
 use std::io;
 
-use crate::{bytes::TokenId, wprintln};
+use crate::{
+    bytes::{vec_from_bytes, TokenId},
+    wprintln,
+};
+
+#[repr(transparent)]
+#[derive(Clone, Copy, Debug, PartialEq, Eq)]
+struct BlobId(u32);
 
 #[allow(dead_code)]
 extern "C" {
     // Log a string.
     fn aici_host_print(ptr: *const u8, len: u32);
 
-    // Read binary representation of TokTrie.
-    // Always returns the size of the trie, will write up to `size` bytes to `dst`.
-    fn aici_host_read_token_trie(ptr: *mut u8, len: u32) -> u32;
+    // Read binary blob.
+    // Always returns the size of the blob, will write up to `size` bytes to `dst`.
+    fn aici_host_read_blob(blob: BlobId, dst: *mut u8, size: u32) -> u32;
+
+    // Return the ID of TokTrie binary representation.
+    fn aici_host_token_trie() -> BlobId;
 
-    // Similar, for argument passed by the user (typically JSON).
-    fn aici_host_read_arg(ptr: *mut u8, len: u32) -> u32;
+    // Return the ID of argument passed by the user.
+    fn aici_host_module_arg() -> BlobId;
+
+    // Return the ID of the tokens to process.
+    fn aici_host_tokens() -> BlobId;
+
+    // Tokenize given UTF8 string. The result is only valid until next call to this function.
+    fn aici_host_tokenize(src: *const u8, src_size: u32) -> BlobId;
 
     // Append fast-forward (FF) token.
     // First FF token has to be returned by setting logit bias appropriately.
@@ -24,6 +37,19 @@ extern "C" {
     fn aici_host_ff_token(token: u32);
 }
 
+// TODO: add
+fn read_blob(blob: BlobId, prefetch_size: usize) -> Vec<u8> {
+    let mut buffer = vec![0u8; prefetch_size];
+    let prefetch_size = prefetch_size as u32;
+    let size = unsafe { aici_host_read_blob(blob, buffer.as_mut_ptr(), prefetch_size) };
+    buffer.resize(size as usize, 0);
+    if size > prefetch_size {
+        // didn't read everything; retry
+        unsafe { aici_host_read_blob(blob, buffer.as_mut_ptr(), size) };
+    }
+    buffer
+}
+
 #[cfg(not(target_arch = "wasm32"))]
 pub type Printer = std::io::Stdout;
@@ -95,46 +121,31 @@ pub extern "C" fn aici_init() {
 
 pub fn arg_bytes() -> Vec<u8> {
     #[cfg(target_arch = "wasm32")]
-    unsafe {
-        let size = aici_host_read_arg(0 as _, 0);
-        let mut buffer = vec![0u8; size as usize];
-        aici_host_read_arg(buffer.as_mut_ptr(), size);
-        return buffer;
-    }
+    return read_blob(unsafe { aici_host_module_arg() }, 1024);
 
     #[cfg(not(target_arch = "wasm32"))]
-    std::fs::read("arg.json").unwrap()
+    return std::fs::read("arg.json").unwrap();
 }
 
 pub fn trie_bytes() -> Vec<u8> {
     #[cfg(target_arch = "wasm32")]
-    unsafe {
-        let size = aici_host_read_token_trie(0 as _, 0);
-        let mut buffer = vec![0u8; size as usize];
-        aici_host_read_token_trie(buffer.as_mut_ptr(), size);
-        buffer
-    }
+    return read_blob(unsafe { aici_host_token_trie() }, 0);
 
     #[cfg(not(target_arch = "wasm32"))]
-    std::fs::read("tokenizer.bin").unwrap()
+    return std::fs::read("tokenizer.bin").unwrap();
+}
+
+pub fn tokens_arg() -> Vec<TokenId> {
+    let r = read_blob(unsafe { aici_host_tokens() }, 256);
+    vec_from_bytes(&r)
 }
 
 pub fn tokenize(s: &str) -> Vec<TokenId> {
-    let slen = s.len() as u32;
-    let cap = slen / 3 + 10;
-    let mut res = Vec::with_capacity(cap as usize);
-    let len = unsafe { aici_host_tokenize(s.as_ptr(), slen, res.as_mut_ptr(), cap) };
-    if len > res.len() as u32 {
-        // unlikely...
-        res = Vec::with_capacity(len as usize);
-        unsafe { aici_host_tokenize(s.as_ptr(), slen, res.as_mut_ptr(), len) };
-    }
-    unsafe {
-        res.set_len(len as usize);
-    }
-    wprintln!("tokenize: '{}' -> {:?}", s, res);
-    // trim size
-    res.clone()
+    let id = unsafe { aici_host_tokenize(s.as_ptr(), s.len() as u32) };
+    let r = read_blob(id, 4 * (s.len() / 3 + 10));
+    let res = vec_from_bytes(&r);
+    wprintln!("tokenize: {:?} -> {:?}", s, res);
+    res
 }
 
 pub fn ff_token(token: TokenId) {
@@ -142,4 +153,3 @@ pub fn ff_token(token: TokenId) {
         aici_host_ff_token(token);
     }
 }
-
diff --git a/aici_abi/src/lib.rs b/aici_abi/src/lib.rs
index 741b54ea..30adee3e 100644
--- a/aici_abi/src/lib.rs
+++ b/aici_abi/src/lib.rs
@@ -38,8 +38,6 @@ macro_rules! expose {
 
 #[derive(Clone)]
 pub struct AiciVmHelper {
-    pub tokens: Vec<u32>,
-    pub prompt_length: usize,
     pub logit_biases: Vec<f32>,
     pub allowed_tokens: SimpleVob,
     pub trie: Rc<Box<TokTrie>>,
 }
@@ -49,8 +47,6 @@ impl AiciVmHelper {
     pub fn new() -> Self {
         AiciVmHelper {
-            tokens: Vec::new(),
-            prompt_length: 0,
             logit_biases: Vec::new(),
             allowed_tokens: SimpleVob::new(),
             trie: Rc::new(Box::new(TokTrie::from_host())),
         }
     }
@@ -63,11 +59,6 @@ impl AiciVmHelper {
         self.allowed_tokens.resize(self.logit_biases.len());
         self.logit_biases.as_mut_ptr()
     }
-    pub fn aici_get_prompt_buffer(&mut self, size: u32) -> *mut u32 {
-        self.prompt_length = size as usize;
-        self.tokens.resize(self.prompt_length, 0);
-        self.tokens.as_mut_ptr()
-    }
 
     pub fn all_disallowed(&mut self) {
         self.allowed_tokens.set_all(false);
@@ -88,11 +79,9 @@ impl AiciVmHelper {
 }
 
 pub trait AiciVm {
-    /// The prompt is in self.helper.tokens.
-    /// On return, self.helper.logit_biases are supposed to be updated.
-    fn aici_process_prompt(&mut self);
+    /// The prompt, a single generated token, or all ff tokens are passed via host::tokens_arg().
     /// On return, self.helper.logit_biases are supposed to be updated.
-    fn aici_append_token(&mut self, token: u32);
+    fn aici_process(&mut self);
     // Used in testing.
     fn get_helper(&mut self) -> &mut AiciVmHelper;
 }
@@ -100,10 +89,8 @@ pub trait AiciVm {
 macro_rules! aici_expose_all {
     ($struct_name:ident, $new:expr) => {
-        $crate::expose!($struct_name::aici_process_prompt() -> ());
-        $crate::expose!($struct_name::aici_append_token(token: u32) -> ());
+        $crate::expose!($struct_name::aici_process() -> ());
         $crate::expose!($struct_name::helper::aici_get_logit_bias_buffer(size: u32) -> *mut f32);
-        $crate::expose!($struct_name::helper::aici_get_prompt_buffer(size: u32) -> *mut u32);
 
         #[no_mangle]
         pub extern "C" fn aici_create() -> *mut $struct_name {
@@ -150,26 +137,3 @@ macro_rules! wprint {
     }};
 }
-
-pub fn aici_harness(aici: &mut impl AiciVm, vocab_size: usize, prompt: &[TokenId]) {
-    let logits = unsafe {
-        std::slice::from_raw_parts_mut(
-            aici.get_helper()
-                .aici_get_logit_bias_buffer(vocab_size as u32),
-            vocab_size,
-        )
-    };
-    let prompt_buf = unsafe {
-        std::slice::from_raw_parts_mut(
-            aici.get_helper()
-                .aici_get_prompt_buffer(prompt.len() as u32),
-            prompt.len(),
-        )
-    };
-    prompt_buf.copy_from_slice(&prompt);
-    aici.aici_process_prompt();
-    let p0 = logits.iter().filter(|x| **x > -50.0).count();
-    wprintln!("res0: {}", p0);
-    aici.aici_append_token(13);
-    let p1 = logits.iter().filter(|x| **x > -50.0).count();
-    wprintln!("res1: {}", p1);
-}
diff --git a/aici_abi/src/recognizer.rs b/aici_abi/src/recognizer.rs
index 6ca7b763..84ff3969 100644
--- a/aici_abi/src/recognizer.rs
+++ b/aici_abi/src/recognizer.rs
@@ -1,14 +1,16 @@
 use std::{fmt::Debug, rc::Rc};
 
 use crate::{
+    host::tokens_arg,
     toktree::{Recognizer, SpecialToken, TokTrie},
-    wprintln, AiciVm, AiciVmHelper,
+    AiciVm, AiciVmHelper,
 };
 
 pub struct AiciRecognizer<R: Recognizer> {
     pub helper: AiciVmHelper,
     pub rec: R,
     pub trie: Rc<Box<TokTrie>>,
+    pub is_prompt: bool,
 }
 
 impl<R: Recognizer + Clone> AiciRecognizer<R> {
@@ -17,6 +19,7 @@ impl<R: Recognizer + Clone> AiciRecognizer<R> {
             helper: AiciVmHelper::new(),
             rec,
             trie,
+            is_prompt: true,
         }
     }
 
@@ -29,24 +32,20 @@ impl<R: Recognizer + Clone> AiciRecognizer<R> {
 }
 
 impl<R: Recognizer + Clone> AiciVm for AiciRecognizer<R> {
-    fn aici_process_prompt(&mut self) {
-        wprintln!("prompt, {} tokens", self.helper.prompt_length);
-        // the regex doesn't care about the prompt
-        self.compute();
-    }
-
-    fn aici_append_token(&mut self, token: u32) {
-        let bytes = self.trie.token(token);
-        // wprintln!("xapp {} {:?}", token, bytes);
-        for b in bytes {
-            self.rec.push_byte(*b)
+    fn aici_process(&mut self) {
+        if self.is_prompt {
+            // the regex doesn't care about the prompt
+            self.is_prompt = false;
+        } else {
+            for token in tokens_arg() {
+                let bytes = self.trie.token(token);
+                // wprintln!("xapp {} {:?}", token, bytes);
+                for b in bytes {
+                    self.rec.push_byte(*b)
+                }
+                self.rec.collapse();
+            }
         }
-        self.rec.collapse();
-
-        // save the token, just in case
-        let toks = &mut self.helper.tokens;
-        toks.push(token);
-
         self.compute();
     }
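After this patch a controller no longer fills prompt buffers; each call to `aici_process` reads whatever tokens the host staged via `host::tokens_arg()`. A minimal end-to-end sketch of a controller under the new API (illustrative only; it just logs the tokens it receives and then permits only EOS):

    use aici_abi::{aici_expose_all, host, AiciVm, AiciVmHelper};

    pub struct EchoVm {
        helper: AiciVmHelper,
    }

    impl AiciVm for EchoVm {
        fn aici_process(&mut self) {
            for t in host::tokens_arg() {
                aici_abi::wprintln!("got token {}", t);
            }
            self.helper.all_disallowed();
            self.helper.allow_eos();
            self.helper.compute_biases();
        }
        fn get_helper(&mut self) -> &mut AiciVmHelper {
            &mut self.helper
        }
    }

    aici_expose_all!(EchoVm, EchoVm { helper: AiciVmHelper::new() });
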
From 2e464527737cbefb3aa5b8939b8ff9f1bcebac1b Mon Sep 17 00:00:00 2001
From: Michal Moskal
Date: Tue, 7 Nov 2023 19:52:46 +0000
Subject: [PATCH 081/301] towards JSON ops

---
 aici_abi/Cargo.lock  | 58 ++++++++++++++++++++++++++++++++++++++++++++
 aici_abi/Cargo.toml  |  1 +
 aici_abi/src/host.rs | 15 ++++++++++++
 aici_abi/src/lib.rs  | 28 ++++++++++-----------
 aici_abi/src/svob.rs |  4 +++
 5 files changed, 91 insertions(+), 15 deletions(-)

diff --git a/aici_abi/Cargo.lock b/aici_abi/Cargo.lock
index 3e8bff87..bd21f7fd 100644
--- a/aici_abi/Cargo.lock
+++ b/aici_abi/Cargo.lock
@@ -5,3 +5,61 @@ version = 3
 [[package]]
 name = "aici_abi"
 version = "0.1.0"
+dependencies = [
+ "serde",
+]
+
+[[package]]
+name = "proc-macro2"
+version = "1.0.69"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "134c189feb4956b20f6f547d2cf727d4c0fe06722b20a0eec87ed445a97f92da"
+dependencies = [
+ "unicode-ident",
+]
+
+[[package]]
+name = "quote"
+version = "1.0.33"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "5267fca4496028628a95160fc423a33e8b2e6af8a5302579e322e4b520293cae"
+dependencies = [
+ "proc-macro2",
+]
+
+[[package]]
+name = "serde"
+version = "1.0.192"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "bca2a08484b285dcb282d0f67b26cadc0df8b19f8c12502c13d966bf9482f001"
+dependencies = [
+ "serde_derive",
+]
+
+[[package]]
+name = "serde_derive"
+version = "1.0.192"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "d6c7207fbec9faa48073f3e3074cbe553af6ea512d7c21ba46e434e70ea9fbc1"
+dependencies = [
+ "proc-macro2",
+ "quote",
+ "syn",
+]
+
+[[package]]
+name = "syn"
+version = "2.0.39"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "23e78b90f2fcf45d3e842032ce32e3f2d1545ba6636271dcbf24fa306d87be7a"
+dependencies = [
+ "proc-macro2",
+ "quote",
+ "unicode-ident",
+]
+
+[[package]]
+name = "unicode-ident"
+version = "1.0.12"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "3354b9ac3fae1ff6755cb6db53683adb661634f67557942dea4facebec0fee4b"
diff --git a/aici_abi/Cargo.toml b/aici_abi/Cargo.toml
index edbc774f..b1a02809 100644
--- a/aici_abi/Cargo.toml
+++ b/aici_abi/Cargo.toml
@@ -7,3 +7,4 @@ edition = "2021"
 name = "aici_abi"
 
 [dependencies]
+serde = { version = "1.0.192", features = ["derive"] }
diff --git a/aici_abi/src/host.rs b/aici_abi/src/host.rs
index 7ec83c8b..cc596f60 100644
--- a/aici_abi/src/host.rs
+++ b/aici_abi/src/host.rs
@@ -2,6 +2,7 @@ use std::io;
 
 use crate::{
     bytes::{vec_from_bytes, TokenId},
+    svob::SimpleVob,
     wprintln,
 };
 
@@ -24,12 +25,19 @@ extern "C" {
     // Return the ID of argument passed by the user.
     fn aici_host_module_arg() -> BlobId;
 
+    // Return the ID of argument passed to the process() function.
+    // It's a JSON serialization of ProcessArg.
+    fn aici_host_process_arg() -> BlobId;
+
     // Return the ID of the tokens to process.
     fn aici_host_tokens() -> BlobId;
 
     // Tokenize given UTF8 string. The result is only valid until next call to this function.
     fn aici_host_tokenize(src: *const u8, src_size: u32) -> BlobId;
 
+    // Set logit bias based on bitmask in src.
+    fn aici_host_return_logits(src: *const u32);
+
     // Append fast-forward (FF) token.
     // First FF token has to be returned by setting logit bias appropriately.
     // Next tokens are added using this interface.
@@ -153,3 +161,10 @@ pub fn ff_token(token: TokenId) {
         aici_host_ff_token(token);
     }
 }
+
+pub fn return_logits(vob: &SimpleVob) {
+    assert!(vob.len() > 0);
+    unsafe {
+        aici_host_return_logits(vob.as_ptr());
+    }
+}
diff --git a/aici_abi/src/lib.rs b/aici_abi/src/lib.rs
index 30adee3e..a640234c 100644
--- a/aici_abi/src/lib.rs
+++ b/aici_abi/src/lib.rs
@@ -1,5 +1,6 @@
 use std::rc::Rc;
 
+use serde::{Deserialize, Serialize};
 use svob::SimpleVob;
 use toktree::{SpecialToken, TokTrie};
 
@@ -12,6 +13,12 @@ pub mod toktree;
 
 pub type TokenId = bytes::TokenId;
 
+#[derive(Serialize, Deserialize, Debug)]
+pub enum ProcessArg {
+    Prompt {},
+    Gen { tokens: Vec<TokenId> },
+}
+
 /// Expose method as extern "C", usage:
 /// expose!(Foo::set_count(n: i32) -> i32);
 /// Generates "C" function:
@@ -38,7 +45,6 @@ macro_rules! expose {
 
 #[derive(Clone)]
 pub struct AiciVmHelper {
-    pub logit_biases: Vec<f32>,
     pub allowed_tokens: SimpleVob,
     pub trie: Rc<Box<TokTrie>>,
 }
@@ -46,19 +52,14 @@ pub struct AiciVmHelper {
 // aici_* are exposed to C in both AiciVm and AiciVmHelper
 impl AiciVmHelper {
     pub fn new() -> Self {
+        let trie = TokTrie::from_host();
+        let mut allowed_tokens = SimpleVob::new();
+        allowed_tokens.resize(trie.vocab_size() + 1);
         AiciVmHelper {
-            logit_biases: Vec::new(),
-            allowed_tokens: SimpleVob::new(),
-            trie: Rc::new(Box::new(TokTrie::from_host())),
+            allowed_tokens,
+            trie: Rc::new(Box::new(trie)),
         }
     }
-    pub fn aici_get_logit_bias_buffer(&mut self, size: u32) -> *mut f32 {
-        // we keep one more logit at the end as a placeholder to avoid branching in
-        // the inner loop of append_bias
-        self.logit_biases.resize((size + 1) as usize, 0.0);
-        self.allowed_tokens.resize(self.logit_biases.len());
-        self.logit_biases.as_mut_ptr()
-    }
 
     pub fn all_disallowed(&mut self) {
         self.allowed_tokens.set_all(false);
@@ -73,8 +74,7 @@ impl AiciVmHelper {
     }
 
     pub fn compute_biases(&mut self) {
-        self.logit_biases.iter_mut().for_each(|x| *x = -100.0);
-        self.allowed_tokens.apply_to(&mut self.logit_biases);
+        host::return_logits(&self.allowed_tokens);
     }
 }
 
@@ -90,7 +89,6 @@ pub trait AiciVm {
 macro_rules! aici_expose_all {
     ($struct_name:ident, $new:expr) => {
         $crate::expose!($struct_name::aici_process() -> ());
-        $crate::expose!($struct_name::helper::aici_get_logit_bias_buffer(size: u32) -> *mut f32);
 
         #[no_mangle]
         pub extern "C" fn aici_create() -> *mut $struct_name {
@@ -136,4 +135,3 @@ macro_rules! wprint {
         $crate::host::_print(&format!($($arg)*));
     }};
 }
-
diff --git a/aici_abi/src/svob.rs b/aici_abi/src/svob.rs
index 319e3392..f3dd062f 100644
--- a/aici_abi/src/svob.rs
+++ b/aici_abi/src/svob.rs
@@ -16,6 +16,10 @@ impl SimpleVob {
         self.data.len() * BITS
     }
 
+    pub unsafe fn as_ptr(&self) -> *const u32 {
+        self.data.as_ptr()
+    }
+
     #[inline(always)]
     pub fn allow_token(&mut self, tok: TokenId) {
         let idx = tok as usize;
From 167b5c831f0428bc2d3ff8a6c6caba8508e12664 Mon Sep 17 00:00:00 2001
From: Michal Moskal
Date: Tue, 7 Nov 2023 22:05:11 +0000
Subject: [PATCH 082/301] pass shm around

---
 aici_abi/Cargo.lock        | 24 ++++++++++++++++++++++++
 aici_abi/Cargo.toml        |  1 +
 aici_abi/src/host.rs       | 13 +++++--------
 aici_abi/src/lib.rs        | 13 ++++++++++---
 aici_abi/src/recognizer.rs | 38 ++++++++++++++++++--------------------
 5 files changed, 58 insertions(+), 31 deletions(-)

diff --git a/aici_abi/Cargo.lock b/aici_abi/Cargo.lock
index bd21f7fd..acd53088 100644
--- a/aici_abi/Cargo.lock
+++ b/aici_abi/Cargo.lock
@@ -7,8 +7,15 @@ name = "aici_abi"
 version = "0.1.0"
 dependencies = [
  "serde",
+ "serde_json",
 ]
 
+[[package]]
+name = "itoa"
+version = "1.0.9"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "af150ab688ff2122fcef229be89cb50dd66af9e01a4ff320cc137eecc9bacc38"
+
 [[package]]
 name = "proc-macro2"
 version = "1.0.69"
@@ -27,6 +34,12 @@ dependencies = [
  "proc-macro2",
 ]
 
+[[package]]
+name = "ryu"
+version = "1.0.15"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "1ad4cc8da4ef723ed60bced201181d83791ad433213d8c24efffda1eec85d741"
+
 [[package]]
 name = "serde"
 version = "1.0.192"
@@ -47,6 +60,17 @@ dependencies = [
  "syn",
 ]
 
+[[package]]
+name = "serde_json"
+version = "1.0.108"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "3d1c7e3eac408d115102c4c24ad393e0821bb3a5df4d506a80f85f7a742a526b"
+dependencies = [
+ "itoa",
+ "ryu",
+ "serde",
+]
+
 [[package]]
 name = "syn"
 version = "2.0.39"
diff --git a/aici_abi/Cargo.toml b/aici_abi/Cargo.toml
index b1a02809..1b1787a8 100644
--- a/aici_abi/Cargo.toml
+++ b/aici_abi/Cargo.toml
@@ -8,3 +8,4 @@ name = "aici_abi"
 
 [dependencies]
 serde = { version = "1.0.192", features = ["derive"] }
+serde_json = "1.0.108"
diff --git a/aici_abi/src/host.rs b/aici_abi/src/host.rs
index cc596f60..b79044e3 100644
--- a/aici_abi/src/host.rs
+++ b/aici_abi/src/host.rs
@@ -29,9 +29,6 @@ extern "C" {
     // Return the ID of argument passed to the process() function.
     // It's a JSON serialization of ProcessArg.
     fn aici_host_process_arg() -> BlobId;
 
-    // Return the ID of the tokens to process.
-    fn aici_host_tokens() -> BlobId;
-
     // Tokenize given UTF8 string. The result is only valid until next call to this function.
     fn aici_host_tokenize(src: *const u8, src_size: u32) -> BlobId;
 
@@ -143,11 +140,6 @@ pub fn trie_bytes() -> Vec<u8> {
     return std::fs::read("tokenizer.bin").unwrap();
 }
 
-pub fn tokens_arg() -> Vec<TokenId> {
-    let r = read_blob(unsafe { aici_host_tokens() }, 256);
-    vec_from_bytes(&r)
-}
-
 pub fn tokenize(s: &str) -> Vec<TokenId> {
     let id = unsafe { aici_host_tokenize(s.as_ptr(), s.len() as u32) };
     let r = read_blob(id, 4 * (s.len() / 3 + 10));
@@ -168,3 +160,8 @@ pub fn return_logits(vob: &SimpleVob) {
         aici_host_return_logits(vob.as_ptr());
     }
 }
+
+pub fn process_arg_bytes() -> Vec<u8> {
+    return read_blob(unsafe { aici_host_process_arg() }, 1024);
+}
+
diff --git a/aici_abi/src/lib.rs b/aici_abi/src/lib.rs
index a640234c..bf7ed659 100644
--- a/aici_abi/src/lib.rs
+++ b/aici_abi/src/lib.rs
@@ -15,7 +15,8 @@ pub type TokenId = bytes::TokenId;
 
 #[derive(Serialize, Deserialize, Debug)]
 pub enum ProcessArg {
-    Prompt {},
+    InitialPrompt { tokens: Vec<TokenId> },
+    StepPrompt {},
     Gen { tokens: Vec<TokenId> },
 }
 
@@ -79,9 +80,15 @@ impl AiciVmHelper {
 }
 
 pub trait AiciVm {
-    /// The prompt, a single generated token, or all ff tokens are passed via host::tokens_arg().
+    fn process(&mut self, arg: ProcessArg);
+
+    /// The prompt, a single generated token, or all ff tokens are passed via host::process_arg().
     /// On return, self.helper.logit_biases are supposed to be updated.
-    fn aici_process(&mut self);
+    fn aici_process(&mut self) {
+        let arg: ProcessArg = serde_json::from_slice(&host::process_arg_bytes()).unwrap();
+        self.process(arg);
+    }
+
     // Used in testing.
     fn get_helper(&mut self) -> &mut AiciVmHelper;
 }
diff --git a/aici_abi/src/recognizer.rs b/aici_abi/src/recognizer.rs
index 84ff3969..98dc3bd2 100644
--- a/aici_abi/src/recognizer.rs
+++ b/aici_abi/src/recognizer.rs
@@ -1,16 +1,14 @@
 use std::{fmt::Debug, rc::Rc};
 
 use crate::{
-    host::tokens_arg,
     toktree::{Recognizer, SpecialToken, TokTrie},
-    AiciVm, AiciVmHelper,
+    AiciVm, AiciVmHelper, ProcessArg,
 };
 
 pub struct AiciRecognizer<R: Recognizer> {
     pub helper: AiciVmHelper,
     pub rec: R,
     pub trie: Rc<Box<TokTrie>>,
-    pub is_prompt: bool,
 }
 
 impl<R: Recognizer + Clone> AiciRecognizer<R> {
@@ -19,7 +17,6 @@ impl<R: Recognizer + Clone> AiciRecognizer<R> {
             helper: AiciVmHelper::new(),
             rec,
             trie,
-            is_prompt: true,
         }
     }
 
@@ -32,25 +29,26 @@ impl<R: Recognizer + Clone> AiciRecognizer<R> {
 }
 
 impl<R: Recognizer + Clone> AiciVm for AiciRecognizer<R> {
-    fn aici_process(&mut self) {
-        if self.is_prompt {
-            // the regex doesn't care about the prompt
-            self.is_prompt = false;
-        } else {
-            for token in tokens_arg() {
-                let bytes = self.trie.token(token);
-                // wprintln!("xapp {} {:?}", token, bytes);
-                for b in bytes {
-                    self.rec.push_byte(*b)
+    fn get_helper(&mut self) -> &mut AiciVmHelper {
+        &mut self.helper
+    }
+
+    fn process(&mut self, arg: ProcessArg) {
+        match arg {
+            ProcessArg::InitialPrompt { .. } => {}
+            ProcessArg::StepPrompt {} => self.compute(),
+            ProcessArg::Gen { tokens } => {
+                for token in tokens {
+                    let bytes = self.trie.token(token);
+                    // wprintln!("xapp {} {:?}", token, bytes);
+                    for b in bytes {
+                        self.rec.push_byte(*b)
+                    }
+                    self.rec.collapse();
                 }
-                self.rec.collapse();
+                self.compute();
             }
         }
-        self.compute();
     }
-
-    fn get_helper(&mut self) -> &mut AiciVmHelper {
-        &mut self.helper
-    }
 }
} => {} + ProcessArg::StepPrompt {} => self.compute(), + ProcessArg::Gen { tokens } => { + for token in tokens { + let bytes = self.trie.token(token); + // wprintln!("xapp {} {:?}", token, bytes); + for b in bytes { + self.rec.push_byte(*b) + } + self.rec.collapse(); } - self.rec.collapse(); + self.compute(); } } - self.compute(); - } - - fn get_helper(&mut self) -> &mut AiciVmHelper { - &mut self.helper } } From a98c2f2bd20f110760aa54405ca512161803da69 Mon Sep 17 00:00:00 2001 From: Michal Moskal Date: Wed, 8 Nov 2023 11:01:36 -0800 Subject: [PATCH 083/301] working on interfaces --- aici_abi/src/aici_iface.h | 54 -------------- aici_abi/src/host.rs | 41 +++++++--- aici_abi/src/lib.rs | 148 ++++++++++++++++++++++++++----------- aici_abi/src/recognizer.rs | 20 ++--- 4 files changed, 141 insertions(+), 122 deletions(-) delete mode 100644 aici_abi/src/aici_iface.h diff --git a/aici_abi/src/aici_iface.h b/aici_abi/src/aici_iface.h deleted file mode 100644 index c0d73624..00000000 --- a/aici_abi/src/aici_iface.h +++ /dev/null @@ -1,54 +0,0 @@ -// -// This interface needs to be implemented by the WASM binary -// - -// Tokens are assumed to be at most 32 bit. -// Typical models range 30k (LLAMA) to 100k (GPT4) tokens. -typedef uint32_t token_t; - -// Called first, after instantiating WASM module. -void aici_init(void); - -// Called once per module, to get an AICI for a specific query -Aici *aici_create(void); - -// These two are called after aici_create() on the fresh AICI. -// They should return the buffers that the WASM code has to allocated and keep around -// until relevant aici_free(). - -// Return buffer where the prompt will be written. `size` is number of tokens in the prompt. -token_t *aici_get_prompt_buffer(Aici *aici, uint32_t size); - -// Return the buffer where the WASM code will write logit biases after -// aici_process_prompt() and aici_append_token(). -// Size of number of biases (which equals size of the vocabulary). -float *aici_get_logit_bias_buffer(Aici *aici, uint32_t size); - -// This called once, when the AICI should process the prompt in its buffer. -// It should set the values in logit bias buffer. -void aici_process_prompt(Aici *aici); -// The logical type (if WASM would allow such things) of this function is: -// float[vocab_size] aici_process_prompt(Aici *aici, token_t[] prompt); - -// This is called after a token is sampled. -// It should set the values in logit bias buffer. -void aici_append_token(Aici *aici, token_t tok); -// The logical type (if WASM would allow such things) of this function is: -// float[vocab_size] aici_append_token(Aici *aici, token_t tok); - -// -// This interface is available to the WASM binary -// - -// Log a string. -void aici_host_print(const uint8_t *ptr, uint32_t size); - -// Read binary representation of TokTrie. -// Always returns the size of the trie, will write up to `size` bytes to `dst`. -uint32_t aici_host_read_token_trie(uint8_t *dst, uint32_t size); - -// Similar, for argument passed by the user (typically JSON). -uint32_t aici_host_read_arg(uint8_t *dst, uint32_t size); - -// Tokenize given UTF8 string. `dst_size` is in elements, not bytes. Returns number of generated tokens. 
-uint32_t aici_host_tokenize(const uint8_t *src, uint32_t src_size, uint32_t *dst, uint32_t dst_size); \ No newline at end of file diff --git a/aici_abi/src/host.rs b/aici_abi/src/host.rs index b79044e3..9253763a 100644 --- a/aici_abi/src/host.rs +++ b/aici_abi/src/host.rs @@ -3,7 +3,7 @@ use std::io; use crate::{ bytes::{vec_from_bytes, TokenId}, svob::SimpleVob, - wprintln, + wprintln, SeqId, }; #[repr(transparent)] @@ -33,13 +33,17 @@ extern "C" { fn aici_host_tokenize(src: *const u8, src_size: u32) -> BlobId; // Set logit bias based on bitmask in src. - fn aici_host_return_logits(src: *const u32); + fn aici_host_return_logit_bias(src: *const u32); // Append fast-forward (FF) token. // First FF token has to be returned by setting logit bias appropriately. // Next tokens are added using this interface. // All FF tokens are then generated in one go. fn aici_host_ff_token(token: u32); + + fn aici_host_self_seq_id() -> u32; + + fn aici_host_return_process_result(res: *const u8, res_size: u32); } // TODO: add @@ -140,24 +144,16 @@ pub fn trie_bytes() -> Vec { return std::fs::read("tokenizer.bin").unwrap(); } -pub fn tokenize(s: &str) -> Vec { - let id = unsafe { aici_host_tokenize(s.as_ptr(), s.len() as u32) }; - let r = read_blob(id, 4 * (s.len() / 3 + 10)); - let res = vec_from_bytes(&r); - wprintln!("tokenize: {:?} -> {:?}", s, res); - res -} - pub fn ff_token(token: TokenId) { unsafe { aici_host_ff_token(token); } } -pub fn return_logits(vob: &SimpleVob) { +pub fn return_logit_bias(vob: &SimpleVob) { assert!(vob.len() > 0); unsafe { - aici_host_return_logits(vob.as_ptr()); + aici_host_return_logit_bias(vob.as_ptr()); } } @@ -165,3 +161,24 @@ pub fn process_arg_bytes() -> Vec { return read_blob(unsafe { aici_host_process_arg() }, 1024); } +pub fn return_process_result(res: &[u8]) { + unsafe { + aici_host_return_process_result(res.as_ptr(), res.len() as u32); + } +} + +// Public APIs + +/// Tokenize given UTF8 string. +pub fn tokenize(s: &str) -> Vec { + let id = unsafe { aici_host_tokenize(s.as_ptr(), s.len() as u32) }; + let r = read_blob(id, 4 * (s.len() / 3 + 10)); + let res = vec_from_bytes(&r); + wprintln!("tokenize: {:?} -> {:?}", s, res); + res +} + +/// Return the ID of the current process. +pub fn self_seq_id() -> SeqId { + unsafe { SeqId(aici_host_self_seq_id()) } +} diff --git a/aici_abi/src/lib.rs b/aici_abi/src/lib.rs index bf7ed659..1fc3ab98 100644 --- a/aici_abi/src/lib.rs +++ b/aici_abi/src/lib.rs @@ -1,8 +1,8 @@ +use serde::{Deserialize, Serialize}; use std::rc::Rc; -use serde::{Deserialize, Serialize}; -use svob::SimpleVob; -use toktree::{SpecialToken, TokTrie}; +use crate::svob::SimpleVob; +use crate::toktree::{SpecialToken, TokTrie}; pub mod bytes; pub mod host; @@ -13,35 +13,86 @@ pub mod toktree; pub type TokenId = bytes::TokenId; +#[derive(Serialize, Deserialize, Debug)] +pub struct InitPromptArg { + pub prompt: Vec, +} + +#[repr(transparent)] +#[derive(Serialize, Deserialize, Debug)] +pub struct SeqId(u32); + #[derive(Serialize, Deserialize, Debug)] pub enum ProcessArg { - InitialPrompt { tokens: Vec }, - StepPrompt {}, - Gen { tokens: Vec }, + /// Generally, issued after each token generated by the model. + /// `tokens` is typically just this one token, except for the first call, when + /// `tokens` is empty, and the cases when fast-forward tokens are used. + Append { tokens: Vec }, + + /// Issued after ProcessResult::Fork. + /// Use host::self_seq_id() to get the ID of the current sequence. 
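+    /// A sketch of the intended flow (assuming `group` lists every sequence in
+    /// the fork, the current one included): compare each entry against
+    /// host::self_seq_id() to find this sequence's position in the group.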
+ Fork { group: Vec }, } -/// Expose method as extern "C", usage: -/// expose!(Foo::set_count(n: i32) -> i32); -/// Generates "C" function: -/// set_count(Foo *, i32) -> i32 -#[macro_export] -macro_rules! expose { - ($struct_name:ident :: $method_name:ident ( $($arg:ident : $typ:ty),* ) -> $ret:ty) => { - #[no_mangle] - pub extern "C" fn $method_name(self_: *mut $struct_name, $($arg : $typ),*) -> $ret { - unsafe { - (&mut *self_).$method_name($($arg),*) - } - } - }; - ($struct_name:ident :: $field:ident :: $method_name:ident ( $($arg:ident : $typ:ty),* ) -> $ret:ty) => { - #[no_mangle] - pub extern "C" fn $method_name(self_: *mut $struct_name, $($arg : $typ),*) -> $ret { - unsafe { - (&mut *self_).$field.$method_name($($arg),*) - } - } - }; +#[derive(Serialize, Deserialize, Debug)] +pub enum ProcessResult { + /// Stop the current sequence. + /// Similar to strong bias to EOS. + Stop, + + /// Sample next token in the current sequence, using bias set with `return_logit_bias()` + SampleWithBias, + + /// First pop `backtrack` tokens, + /// then force next tokens to be generated to be `ff_tokens`. + /// `backtrack` can be 0, and `ff_tokens` can be empty but not both. + Splice { + backtrack: u32, + ff_tokens: Vec, + }, + + /// Fork the current sequence into `num_children` sequences (including current one). + /// `resume_fork(0)` will be called on this VM, while children will be resumed + /// with `resume_fork(1)` ... `resume_fork(num_children - 1)` + /// (thus, `Fork {1}` will not create any new sequences). + Fork { num_children: u32 }, + + /// Wait until all listed variables are available for reading, + /// and all listed sequences have finished executing. + WaitAll { + variables: Vec, + finished: Vec, + }, +} + +pub trait AiciVm { + /// Called with the initial prompt. Has long time limit. + /// By default ignore prompt. + fn init_prompt(&mut self, _arg: InitPromptArg) {} + + /// This is the main entry point for the module. + /// Following calls are issued: + /// * `Append { tokens: [] }` - to generate bias for the first token of the output + /// And then any combination of: + /// * `Append { tokens: [t] }` - when a token `t` is sampled + /// * `Append { tokens: [t...] }` - after fast-forward + /// Either way, a bias should be eventually generated. + fn process(&mut self, arg: ProcessArg) -> ProcessResult; + + fn get_helper(&mut self) -> &mut AiciVmHelper; + + // Internals + fn aici_process(&mut self) { + let arg: ProcessArg = serde_json::from_slice(&host::process_arg_bytes()).unwrap(); + let res = self.process(arg); + let res_bytes = serde_json::to_vec(&res).unwrap(); + host::return_process_result(&res_bytes); + } + + fn aici_init_prompt(&mut self) { + let arg: InitPromptArg = serde_json::from_slice(&host::process_arg_bytes()).unwrap(); + self.init_prompt(arg); + } } #[derive(Clone)] @@ -50,7 +101,6 @@ pub struct AiciVmHelper { pub trie: Rc>, } -// aici_* are exposed to C in both AiciVm and AiciVmHelper impl AiciVmHelper { pub fn new() -> Self { let trie = TokTrie::from_host(); @@ -74,29 +124,41 @@ impl AiciVmHelper { self.allow_one(self.trie.special_token(SpecialToken::EndOfSentence)); } - pub fn compute_biases(&mut self) { - host::return_logits(&self.allowed_tokens); + pub fn return_logit_bias(&mut self) -> ProcessResult { + host::return_logit_bias(&self.allowed_tokens); + ProcessResult::SampleWithBias } } -pub trait AiciVm { - fn process(&mut self, arg: ProcessArg); - - /// The prompt, single generated token, or all ff tokens, arg in host::process_arg(). 
- /// On return, self.helper.logit_biases are supposed to be updated. - fn aici_process(&mut self) { - let arg: ProcessArg = serde_json::from_slice(&host::process_arg_bytes()).unwrap(); - self.process(arg); - } - - // Used in testing. - fn get_helper(&mut self) -> &mut AiciVmHelper; +/// Expose method as extern "C", usage: +/// expose!(Foo::set_count(n: i32) -> i32); +/// Generates "C" function: +/// set_count(Foo *, i32) -> i32 +#[macro_export] +macro_rules! expose { + ($struct_name:ident :: $method_name:ident ( $($arg:ident : $typ:ty),* ) -> $ret:ty) => { + #[no_mangle] + pub extern "C" fn $method_name(self_: *mut $struct_name, $($arg : $typ),*) -> $ret { + unsafe { + (&mut *self_).$method_name($($arg),*) + } + } + }; + ($struct_name:ident :: $field:ident :: $method_name:ident ( $($arg:ident : $typ:ty),* ) -> $ret:ty) => { + #[no_mangle] + pub extern "C" fn $method_name(self_: *mut $struct_name, $($arg : $typ),*) -> $ret { + unsafe { + (&mut *self_).$field.$method_name($($arg),*) + } + } + }; } #[macro_export] macro_rules! aici_expose_all { ($struct_name:ident, $new:expr) => { $crate::expose!($struct_name::aici_process() -> ()); + $crate::expose!($struct_name::aici_init_prompt() -> ()); #[no_mangle] pub extern "C" fn aici_create() -> *mut $struct_name { diff --git a/aici_abi/src/recognizer.rs b/aici_abi/src/recognizer.rs index 98dc3bd2..a7afe5fb 100644 --- a/aici_abi/src/recognizer.rs +++ b/aici_abi/src/recognizer.rs @@ -2,7 +2,7 @@ use std::{fmt::Debug, rc::Rc}; use crate::{ toktree::{Recognizer, SpecialToken, TokTrie}, - AiciVm, AiciVmHelper, ProcessArg, + AiciVm, AiciVmHelper, ProcessArg, ProcessResult, }; pub struct AiciRecognizer { @@ -19,13 +19,6 @@ impl AiciRecognizer { trie, } } - - fn compute(&mut self) { - // wprintln!("compute"); - self.trie - .compute_bias(&mut self.rec, &mut self.helper.allowed_tokens); - self.helper.compute_biases(); - } } impl AiciVm for AiciRecognizer { @@ -33,11 +26,9 @@ impl AiciVm for AiciRecognizer { &mut self.helper } - fn process(&mut self, arg: ProcessArg) { + fn process(&mut self, arg: ProcessArg) -> ProcessResult { match arg { - ProcessArg::InitialPrompt { .. } => {} - ProcessArg::StepPrompt {} => self.compute(), - ProcessArg::Gen { tokens } => { + ProcessArg::Append { tokens } => { for token in tokens { let bytes = self.trie.token(token); // wprintln!("xapp {} {:?}", token, bytes); @@ -46,8 +37,11 @@ impl AiciVm for AiciRecognizer { } self.rec.collapse(); } - self.compute(); + self.trie + .compute_bias(&mut self.rec, &mut self.helper.allowed_tokens); + self.helper.return_logit_bias() } + ProcessArg::Fork { .. } => panic!("fork not requested!"), } } } From 92c43c6d6093a135d2b79b5fdaa1eb4de76ede86 Mon Sep 17 00:00:00 2001 From: Michal Moskal Date: Wed, 8 Nov 2023 14:17:26 -0800 Subject: [PATCH 084/301] work on variables --- aici_abi/src/lib.rs | 35 +++++++++++++++++++++++++++++++++++ 1 file changed, 35 insertions(+) diff --git a/aici_abi/src/lib.rs b/aici_abi/src/lib.rs index 1fc3ab98..50b0955a 100644 --- a/aici_abi/src/lib.rs +++ b/aici_abi/src/lib.rs @@ -13,6 +13,41 @@ pub mod toktree; pub type TokenId = bytes::TokenId; +#[derive(Serialize, Deserialize, Debug)] +pub enum StorageOp { + Set, + Append, +} + +#[derive(Serialize, Deserialize, Debug)] +pub enum StorageCmd { + /// Read variable. Returns StorageResp::ReadVar or StorageResp::VariableMissing. + ReadVar { name: String }, + + /// Write variable. + /// If `when_version_is == None`, always writes the variable and returns StorageResp::WriteVar. 
+ /// Otherwise, if the variable has the specified version, it writes the variable + /// and returns StorageResp::WriteVar. + /// Otherwise (version conflict), returns either StorageResp::ReadVar or StorageResp::VariableMissing + /// just like ReadVar would. + WriteVar { + name: String, + value: Vec, + op: StorageOp, + when_version_is: Option, + }, +} + +#[derive(Serialize, Deserialize, Debug)] +pub enum StorageResp { + /// Upon handling the request the variable had the specified value and version number. + ReadVar { version: u64, value: Vec }, + /// Upon handling the request the variable was unset. + VariableMissing {}, + /// The variable has been written, and the new version is returned. + WriteVar { version: u64 }, +} + #[derive(Serialize, Deserialize, Debug)] pub struct InitPromptArg { pub prompt: Vec, From cc18e099bb842046af2e8d1489ec6d6ec4233568 Mon Sep 17 00:00:00 2001 From: Michal Moskal Date: Wed, 8 Nov 2023 14:29:43 -0800 Subject: [PATCH 085/301] variable api --- aici_abi/src/host.rs | 93 ++++++++++++++++++++++++++++++++++++++++++++ aici_abi/src/lib.rs | 35 ----------------- 2 files changed, 93 insertions(+), 35 deletions(-) diff --git a/aici_abi/src/host.rs b/aici_abi/src/host.rs index 9253763a..0e7ceaf9 100644 --- a/aici_abi/src/host.rs +++ b/aici_abi/src/host.rs @@ -1,3 +1,4 @@ +use serde::{Deserialize, Serialize}; use std::io; use crate::{ @@ -44,6 +45,8 @@ extern "C" { fn aici_host_self_seq_id() -> u32; fn aici_host_return_process_result(res: *const u8, res_size: u32); + + fn aici_host_storage_cmd(cmd: *const u8, cmd_size: u32) -> BlobId; } // TODO: add @@ -167,8 +170,98 @@ pub fn return_process_result(res: &[u8]) { } } +#[derive(Serialize, Deserialize, Debug)] +pub enum StorageOp { + Set, + Append, +} + +#[derive(Serialize, Deserialize, Debug)] +pub enum StorageCmd { + /// Read variable. Returns StorageResp::ReadVar or StorageResp::VariableMissing. + ReadVar { name: String }, + + /// Write variable. + /// If `when_version_is == None`, always writes the variable and returns StorageResp::WriteVar. + /// Otherwise, if the variable has the specified version, it writes the variable + /// and returns StorageResp::WriteVar. + /// Otherwise (version conflict), returns either StorageResp::ReadVar or StorageResp::VariableMissing + /// just like ReadVar would. + WriteVar { + name: String, + value: Vec, + op: StorageOp, + when_version_is: Option, + }, +} + +#[derive(Serialize, Deserialize, Debug)] +pub enum StorageResp { + /// Upon handling the request the variable had the specified value and version number. + ReadVar { version: u64, value: Vec }, + /// Upon handling the request the variable was unset. + VariableMissing {}, + /// The variable has been written, and the new version is returned. + WriteVar { version: u64 }, +} + +pub fn storage_cmd(cmd: StorageCmd) -> StorageResp { + let cmd_bytes = serde_json::to_vec(&cmd).unwrap(); + let res_id = unsafe { aici_host_storage_cmd(cmd_bytes.as_ptr(), cmd_bytes.len() as u32) }; + let resp_bytes = read_blob(res_id, 1024); + serde_json::from_slice(&resp_bytes).unwrap() +} + // Public APIs +pub struct VariableStorage { + // no fields yet +} + +impl VariableStorage { + /// Create a new instance of VariableStorage. It currently has no fields. + pub fn new() -> Self { + VariableStorage {} + } + + /// Read variable. Returns None if the variable is unset. + pub fn get(&self, name: &str) -> Option> { + self.get_with_version(name).map(|x| x.1) + } + + /// Write specified value to variable. 
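+    /// A hypothetical usage sketch (variable name and value are illustrative):
+    /// ```ignore
+    /// let vars = VariableStorage::new();
+    /// vars.set("answer", b"42".to_vec());
+    /// assert_eq!(vars.get("answer"), Some(b"42".to_vec()));
+    /// ```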
+ pub fn set(&self, name: &str, value: Vec) { + let _ver = self.write_var(name, value, StorageOp::Set); + } + + /// Append specified value to variable. + pub fn append(&self, name: &str, value: Vec) { + let _ver = self.write_var(name, value, StorageOp::Append); + } + + fn write_var(&self, name: &str, value: Vec, op: StorageOp) -> u64 { + match storage_cmd(StorageCmd::WriteVar { + name: name.to_string(), + value, + op, + when_version_is: None, + }) { + StorageResp::WriteVar { version } => version, + _ => panic!("unexpected response to writevar"), + } + } + + fn get_with_version(&self, name: &str) -> Option<(u64, Vec)> { + match storage_cmd(StorageCmd::ReadVar { + name: name.to_string(), + }) { + StorageResp::ReadVar { version, value } => Some((version, value)), + StorageResp::VariableMissing {} => None, + StorageResp::WriteVar { .. } => panic!("unexpected response to readvar"), + } + } +} + /// Tokenize given UTF8 string. pub fn tokenize(s: &str) -> Vec { let id = unsafe { aici_host_tokenize(s.as_ptr(), s.len() as u32) }; diff --git a/aici_abi/src/lib.rs b/aici_abi/src/lib.rs index 50b0955a..1fc3ab98 100644 --- a/aici_abi/src/lib.rs +++ b/aici_abi/src/lib.rs @@ -13,41 +13,6 @@ pub mod toktree; pub type TokenId = bytes::TokenId; -#[derive(Serialize, Deserialize, Debug)] -pub enum StorageOp { - Set, - Append, -} - -#[derive(Serialize, Deserialize, Debug)] -pub enum StorageCmd { - /// Read variable. Returns StorageResp::ReadVar or StorageResp::VariableMissing. - ReadVar { name: String }, - - /// Write variable. - /// If `when_version_is == None`, always writes the variable and returns StorageResp::WriteVar. - /// Otherwise, if the variable has the specified version, it writes the variable - /// and returns StorageResp::WriteVar. - /// Otherwise (version conflict), returns either StorageResp::ReadVar or StorageResp::VariableMissing - /// just like ReadVar would. - WriteVar { - name: String, - value: Vec, - op: StorageOp, - when_version_is: Option, - }, -} - -#[derive(Serialize, Deserialize, Debug)] -pub enum StorageResp { - /// Upon handling the request the variable had the specified value and version number. - ReadVar { version: u64, value: Vec }, - /// Upon handling the request the variable was unset. - VariableMissing {}, - /// The variable has been written, and the new version is returned. - WriteVar { version: u64 }, -} - #[derive(Serialize, Deserialize, Debug)] pub struct InitPromptArg { pub prompt: Vec, From 4ac0d4b399355066a841e47adc9bda14ea53c7e7 Mon Sep 17 00:00:00 2001 From: Michal Moskal Date: Wed, 8 Nov 2023 14:55:58 -0800 Subject: [PATCH 086/301] group/storage msgs --- aici_abi/src/host.rs | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/aici_abi/src/host.rs b/aici_abi/src/host.rs index 0e7ceaf9..362773d3 100644 --- a/aici_abi/src/host.rs +++ b/aici_abi/src/host.rs @@ -33,7 +33,7 @@ extern "C" { // Tokenize given UTF8 string. The result is only valid until next call to this function. fn aici_host_tokenize(src: *const u8, src_size: u32) -> BlobId; - // Set logit bias based on bitmask in src. + // Set logit bias based on bit-mask in src. fn aici_host_return_logit_bias(src: *const u32); // Append fast-forward (FF) token. 
@@ -247,7 +247,7 @@ impl VariableStorage { when_version_is: None, }) { StorageResp::WriteVar { version } => version, - _ => panic!("unexpected response to writevar"), + _ => panic!("unexpected response to write var"), } } @@ -257,7 +257,7 @@ impl VariableStorage { }) { StorageResp::ReadVar { version, value } => Some((version, value)), StorageResp::VariableMissing {} => None, - StorageResp::WriteVar { .. } => panic!("unexpected response to readvar"), + StorageResp::WriteVar { .. } => panic!("unexpected response to read var"), } } } From 06b134dae8539dbeee0806a1b2b09c49a0a3810c Mon Sep 17 00:00:00 2001 From: Michal Moskal Date: Wed, 8 Nov 2023 15:52:57 -0800 Subject: [PATCH 087/301] working on blob api --- aici_abi/src/host.rs | 2 +- aici_abi/src/lib.rs | 5 +++++ 2 files changed, 6 insertions(+), 1 deletion(-) diff --git a/aici_abi/src/host.rs b/aici_abi/src/host.rs index 362773d3..7ec4120e 100644 --- a/aici_abi/src/host.rs +++ b/aici_abi/src/host.rs @@ -215,7 +215,7 @@ pub fn storage_cmd(cmd: StorageCmd) -> StorageResp { // Public APIs pub struct VariableStorage { - // no fields yet + // no fields (yet?) } impl VariableStorage { diff --git a/aici_abi/src/lib.rs b/aici_abi/src/lib.rs index 1fc3ab98..b3dd73a6 100644 --- a/aici_abi/src/lib.rs +++ b/aici_abi/src/lib.rs @@ -165,6 +165,11 @@ macro_rules! aici_expose_all { let b = Box::new($new); Box::into_raw(b) } + + #[no_mangle] + pub extern "C" fn aici_panic() { + panic!("aici_panic()") + } } } From 36d6a92d5c2b4f341ac824c9be617fa811f74ed0 Mon Sep 17 00:00:00 2001 From: Michal Moskal Date: Thu, 9 Nov 2023 11:22:34 -0800 Subject: [PATCH 088/301] handle Splice (no backtrack) --- aici_abi/src/host.rs | 12 ------------ 1 file changed, 12 deletions(-) diff --git a/aici_abi/src/host.rs b/aici_abi/src/host.rs index 7ec4120e..48c7a33c 100644 --- a/aici_abi/src/host.rs +++ b/aici_abi/src/host.rs @@ -36,12 +36,6 @@ extern "C" { // Set logit bias based on bit-mask in src. fn aici_host_return_logit_bias(src: *const u32); - // Append fast-forward (FF) token. - // First FF token has to be returned by setting logit bias appropriately. - // Next tokens are added using this interface. - // All FF tokens are then generated in one go. - fn aici_host_ff_token(token: u32); - fn aici_host_self_seq_id() -> u32; fn aici_host_return_process_result(res: *const u8, res_size: u32); @@ -147,12 +141,6 @@ pub fn trie_bytes() -> Vec { return std::fs::read("tokenizer.bin").unwrap(); } -pub fn ff_token(token: TokenId) { - unsafe { - aici_host_ff_token(token); - } -} - pub fn return_logit_bias(vob: &SimpleVob) { assert!(vob.len() > 0); unsafe { From 1114a665c2963dc45961b5ab182c2d20596a9559 Mon Sep 17 00:00:00 2001 From: Michal Moskal Date: Sat, 11 Nov 2023 01:05:59 +0000 Subject: [PATCH 089/301] attn mask support --- aici_abi/src/lib.rs | 54 ++++++++++++++++++++++++++++---------- aici_abi/src/recognizer.rs | 23 +++++++--------- 2 files changed, 49 insertions(+), 28 deletions(-) diff --git a/aici_abi/src/lib.rs b/aici_abi/src/lib.rs index b3dd73a6..7e294adb 100644 --- a/aici_abi/src/lib.rs +++ b/aici_abi/src/lib.rs @@ -23,15 +23,31 @@ pub struct InitPromptArg { pub struct SeqId(u32); #[derive(Serialize, Deserialize, Debug)] -pub enum ProcessArg { +pub struct PreProcessArg { /// Generally, issued after each token generated by the model. /// `tokens` is typically just this one token, except for the first call, when /// `tokens` is empty, and the cases when fast-forward tokens are used. 
- Append { tokens: Vec }, + pub tokens: Vec, +} + +#[derive(Serialize, Deserialize, Debug)] +pub struct PreProcessResult { + /// If no attention masks are returned - stop the sequence. + /// If one is returned - just continue with this mask. + /// If more than one attention mask is returned - fork the generation. + /// Attention mask of length 0 is equivalent [1.0, ..., 1.0]. + /// Otherwise, length of the mask should be the same as the number of prompt + generated tokens. + pub attention_masks: Vec>, +} - /// Issued after ProcessResult::Fork. +#[derive(Serialize, Deserialize, Debug)] +pub struct ProcessArg { + /// This is the same tokens as in PreProcessArg. + pub tokens: Vec, + /// fork_group.len() == attention_masks.len(). /// Use host::self_seq_id() to get the ID of the current sequence. - Fork { group: Vec }, + /// TODO: not impl yet + pub fork_group: Vec, } #[derive(Serialize, Deserialize, Debug)] @@ -51,12 +67,6 @@ pub enum ProcessResult { ff_tokens: Vec, }, - /// Fork the current sequence into `num_children` sequences (including current one). - /// `resume_fork(0)` will be called on this VM, while children will be resumed - /// with `resume_fork(1)` ... `resume_fork(num_children - 1)` - /// (thus, `Fork {1}` will not create any new sequences). - Fork { num_children: u32 }, - /// Wait until all listed variables are available for reading, /// and all listed sequences have finished executing. WaitAll { @@ -70,6 +80,13 @@ pub trait AiciVm { /// By default ignore prompt. fn init_prompt(&mut self, _arg: InitPromptArg) {} + /// Called after tokens are appended, before process(). + fn pre_process(&mut self, _arg: PreProcessArg) -> PreProcessResult { + PreProcessResult { + attention_masks: vec![vec![]], + } + } + /// This is the main entry point for the module. /// Following calls are issued: /// * `Append { tokens: [] }` - to generate bias for the first token of the output @@ -82,6 +99,18 @@ pub trait AiciVm { fn get_helper(&mut self) -> &mut AiciVmHelper; // Internals + fn aici_init_prompt(&mut self) { + let arg: InitPromptArg = serde_json::from_slice(&host::process_arg_bytes()).unwrap(); + self.init_prompt(arg); + } + + fn aici_pre_process(&mut self) { + let arg: PreProcessArg = serde_json::from_slice(&host::process_arg_bytes()).unwrap(); + let res = self.pre_process(arg); + let res_bytes = serde_json::to_vec(&res).unwrap(); + host::return_process_result(&res_bytes); + } + fn aici_process(&mut self) { let arg: ProcessArg = serde_json::from_slice(&host::process_arg_bytes()).unwrap(); let res = self.process(arg); @@ -89,10 +118,6 @@ pub trait AiciVm { host::return_process_result(&res_bytes); } - fn aici_init_prompt(&mut self) { - let arg: InitPromptArg = serde_json::from_slice(&host::process_arg_bytes()).unwrap(); - self.init_prompt(arg); - } } #[derive(Clone)] @@ -158,6 +183,7 @@ macro_rules! expose { macro_rules! 
aici_expose_all { ($struct_name:ident, $new:expr) => { $crate::expose!($struct_name::aici_process() -> ()); + $crate::expose!($struct_name::aici_pre_process() -> ()); $crate::expose!($struct_name::aici_init_prompt() -> ()); #[no_mangle] diff --git a/aici_abi/src/recognizer.rs b/aici_abi/src/recognizer.rs index a7afe5fb..c46087e1 100644 --- a/aici_abi/src/recognizer.rs +++ b/aici_abi/src/recognizer.rs @@ -27,22 +27,17 @@ impl AiciVm for AiciRecognizer { } fn process(&mut self, arg: ProcessArg) -> ProcessResult { - match arg { - ProcessArg::Append { tokens } => { - for token in tokens { - let bytes = self.trie.token(token); - // wprintln!("xapp {} {:?}", token, bytes); - for b in bytes { - self.rec.push_byte(*b) - } - self.rec.collapse(); - } - self.trie - .compute_bias(&mut self.rec, &mut self.helper.allowed_tokens); - self.helper.return_logit_bias() + for token in arg.tokens { + let bytes = self.trie.token(token); + // wprintln!("process {} {:?}", token, bytes); + for b in bytes { + self.rec.push_byte(*b) } - ProcessArg::Fork { .. } => panic!("fork not requested!"), + self.rec.collapse(); } + self.trie + .compute_bias(&mut self.rec, &mut self.helper.allowed_tokens); + self.helper.return_logit_bias() } } From 5cc6fe534495335336f7306db64617e8582e809b Mon Sep 17 00:00:00 2001 From: Michal Moskal Date: Tue, 14 Nov 2023 00:41:10 +0000 Subject: [PATCH 090/301] more work on forks --- aici_abi/src/lib.rs | 9 ++++----- aici_abi/src/recognizer.rs | 8 +++++++- aici_abi/src/toktree.rs | 2 +- 3 files changed, 12 insertions(+), 7 deletions(-) diff --git a/aici_abi/src/lib.rs b/aici_abi/src/lib.rs index 7e294adb..0e3424af 100644 --- a/aici_abi/src/lib.rs +++ b/aici_abi/src/lib.rs @@ -19,8 +19,8 @@ pub struct InitPromptArg { } #[repr(transparent)] -#[derive(Serialize, Deserialize, Debug)] -pub struct SeqId(u32); +#[derive(Serialize, Deserialize, Debug, PartialEq, Eq)] +pub struct SeqId(pub u32); #[derive(Serialize, Deserialize, Debug)] pub struct PreProcessArg { @@ -28,6 +28,8 @@ pub struct PreProcessArg { /// `tokens` is typically just this one token, except for the first call, when /// `tokens` is empty, and the cases when fast-forward tokens are used. pub tokens: Vec, + + pub max_context_size: usize, } #[derive(Serialize, Deserialize, Debug)] @@ -42,11 +44,8 @@ pub struct PreProcessResult { #[derive(Serialize, Deserialize, Debug)] pub struct ProcessArg { - /// This is the same tokens as in PreProcessArg. - pub tokens: Vec, /// fork_group.len() == attention_masks.len(). /// Use host::self_seq_id() to get the ID of the current sequence. 
- /// TODO: not impl yet pub fork_group: Vec, } diff --git a/aici_abi/src/recognizer.rs b/aici_abi/src/recognizer.rs index c46087e1..a554cd79 100644 --- a/aici_abi/src/recognizer.rs +++ b/aici_abi/src/recognizer.rs @@ -26,7 +26,7 @@ impl AiciVm for AiciRecognizer { &mut self.helper } - fn process(&mut self, arg: ProcessArg) -> ProcessResult { + fn pre_process(&mut self, arg: crate::PreProcessArg) -> crate::PreProcessResult { for token in arg.tokens { let bytes = self.trie.token(token); // wprintln!("process {} {:?}", token, bytes); @@ -35,6 +35,12 @@ impl AiciVm for AiciRecognizer { } self.rec.collapse(); } + crate::PreProcessResult { + attention_masks: vec![vec![]], + } + } + + fn process(&mut self, _arg: ProcessArg) -> ProcessResult { self.trie .compute_bias(&mut self.rec, &mut self.helper.allowed_tokens); self.helper.return_logit_bias() diff --git a/aici_abi/src/toktree.rs b/aici_abi/src/toktree.rs index 396fb3e4..514e77c1 100644 --- a/aici_abi/src/toktree.rs +++ b/aici_abi/src/toktree.rs @@ -164,7 +164,7 @@ impl TokTrie { pub fn special_token(&self, tok: SpecialToken) -> TokenId { match tok { SpecialToken::EndOfSentence => self.info.tok_eos, - _ => todo!(), + _ => panic!("non-EOS special_token() called"), // TODO? } } From 734c1319adcd191baf4b276e2cda9d26df591d4f Mon Sep 17 00:00:00 2001 From: Michal Moskal Date: Tue, 14 Nov 2023 00:46:29 +0000 Subject: [PATCH 091/301] cleanup --- aici_abi/src/lib.rs | 2 -- 1 file changed, 2 deletions(-) diff --git a/aici_abi/src/lib.rs b/aici_abi/src/lib.rs index 0e3424af..efe44503 100644 --- a/aici_abi/src/lib.rs +++ b/aici_abi/src/lib.rs @@ -28,8 +28,6 @@ pub struct PreProcessArg { /// `tokens` is typically just this one token, except for the first call, when /// `tokens` is empty, and the cases when fast-forward tokens are used. pub tokens: Vec, - - pub max_context_size: usize, } #[derive(Serialize, Deserialize, Debug)] From 9202de1d0dc49db7ab9072e3f18190274b1b7168 Mon Sep 17 00:00:00 2001 From: Michal Moskal Date: Tue, 14 Nov 2023 23:00:05 +0000 Subject: [PATCH 092/301] working on variables --- aici_abi/src/lib.rs | 2 -- aici_abi/src/recognizer.rs | 4 ---- 2 files changed, 6 deletions(-) diff --git a/aici_abi/src/lib.rs b/aici_abi/src/lib.rs index efe44503..685c59bf 100644 --- a/aici_abi/src/lib.rs +++ b/aici_abi/src/lib.rs @@ -93,8 +93,6 @@ pub trait AiciVm { /// Either way, a bias should be eventually generated. 
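     /// For instance (a sketch, not an exhaustive list of cases): a constraint
     /// checker typically returns ProcessResult::SampleWithBias, while a module
     /// that forces fixed text returns ProcessResult::Splice with backtrack 0.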
fn process(&mut self, arg: ProcessArg) -> ProcessResult; - fn get_helper(&mut self) -> &mut AiciVmHelper; - // Internals fn aici_init_prompt(&mut self) { let arg: InitPromptArg = serde_json::from_slice(&host::process_arg_bytes()).unwrap(); diff --git a/aici_abi/src/recognizer.rs b/aici_abi/src/recognizer.rs index a554cd79..ac7c871c 100644 --- a/aici_abi/src/recognizer.rs +++ b/aici_abi/src/recognizer.rs @@ -22,10 +22,6 @@ impl AiciRecognizer { } impl AiciVm for AiciRecognizer { - fn get_helper(&mut self) -> &mut AiciVmHelper { - &mut self.helper - } - fn pre_process(&mut self, arg: crate::PreProcessArg) -> crate::PreProcessResult { for token in arg.tokens { let bytes = self.trie.token(token); From 962f460ddeffbf53f98eadc35d4dc5a9de04834d Mon Sep 17 00:00:00 2001 From: Michal Moskal Date: Tue, 14 Nov 2023 23:26:43 +0000 Subject: [PATCH 093/301] get rid of AiciVmHelper --- aici_abi/src/lib.rs | 39 -------------------------------------- aici_abi/src/recognizer.rs | 20 +++++++++---------- aici_abi/src/toktree.rs | 6 ++++++ 3 files changed, 16 insertions(+), 49 deletions(-) diff --git a/aici_abi/src/lib.rs b/aici_abi/src/lib.rs index 685c59bf..bffddaa3 100644 --- a/aici_abi/src/lib.rs +++ b/aici_abi/src/lib.rs @@ -1,8 +1,4 @@ use serde::{Deserialize, Serialize}; -use std::rc::Rc; - -use crate::svob::SimpleVob; -use crate::toktree::{SpecialToken, TokTrie}; pub mod bytes; pub mod host; @@ -115,41 +111,6 @@ pub trait AiciVm { } -#[derive(Clone)] -pub struct AiciVmHelper { - pub allowed_tokens: SimpleVob, - pub trie: Rc>, -} - -impl AiciVmHelper { - pub fn new() -> Self { - let trie = TokTrie::from_host(); - let mut allowed_tokens = SimpleVob::new(); - allowed_tokens.resize(trie.vocab_size() + 1); - AiciVmHelper { - allowed_tokens, - trie: Rc::new(Box::new(trie)), - } - } - - pub fn all_disallowed(&mut self) { - self.allowed_tokens.set_all(false); - } - - pub fn allow_one(&mut self, tok: TokenId) { - self.allowed_tokens.allow_token(tok); - } - - pub fn allow_eos(&mut self) { - self.allow_one(self.trie.special_token(SpecialToken::EndOfSentence)); - } - - pub fn return_logit_bias(&mut self) -> ProcessResult { - host::return_logit_bias(&self.allowed_tokens); - ProcessResult::SampleWithBias - } -} - /// Expose method as extern "C", usage: /// expose!(Foo::set_count(n: i32) -> i32); /// Generates "C" function: diff --git a/aici_abi/src/recognizer.rs b/aici_abi/src/recognizer.rs index ac7c871c..50e3e9ab 100644 --- a/aici_abi/src/recognizer.rs +++ b/aici_abi/src/recognizer.rs @@ -1,22 +1,21 @@ -use std::{fmt::Debug, rc::Rc}; +use std::fmt::Debug; use crate::{ + host, toktree::{Recognizer, SpecialToken, TokTrie}, - AiciVm, AiciVmHelper, ProcessArg, ProcessResult, + AiciVm, ProcessArg, ProcessResult, }; pub struct AiciRecognizer { - pub helper: AiciVmHelper, + pub trie: TokTrie, pub rec: R, - pub trie: Rc>, } impl AiciRecognizer { - pub fn from_recognizer(trie: Rc>, rec: R) -> Self { + pub fn from_recognizer(rec: R) -> Self { AiciRecognizer { - helper: AiciVmHelper::new(), + trie: TokTrie::from_host(), rec, - trie, } } } @@ -37,9 +36,10 @@ impl AiciVm for AiciRecognizer { } fn process(&mut self, _arg: ProcessArg) -> ProcessResult { - self.trie - .compute_bias(&mut self.rec, &mut self.helper.allowed_tokens); - self.helper.return_logit_bias() + let mut set = self.trie.alloc_token_set(); + self.trie.compute_bias(&mut self.rec, &mut set); + host::return_logit_bias(&set); + ProcessResult::SampleWithBias } } diff --git a/aici_abi/src/toktree.rs b/aici_abi/src/toktree.rs index 514e77c1..b3202695 100644 --- 
a/aici_abi/src/toktree.rs +++ b/aici_abi/src/toktree.rs @@ -172,6 +172,12 @@ impl TokTrie { self.info.vocab_size as usize } + pub fn alloc_token_set(&self) -> SimpleVob { + let mut r = SimpleVob::new(); + r.resize(self.vocab_size() + 1); + r + } + pub fn alloc_logits(&self) -> Vec { vec![0.0; self.vocab_size() + 1] } From 37f3c94e4d768a516a52fa014829ef7e6fbd3107 Mon Sep 17 00:00:00 2001 From: Michal Moskal Date: Thu, 16 Nov 2023 01:03:02 +0000 Subject: [PATCH 094/301] add wait() command to ast --- aici_abi/src/lib.rs | 28 ++++++++++++++++++++++++---- aici_abi/src/recognizer.rs | 4 +--- 2 files changed, 25 insertions(+), 7 deletions(-) diff --git a/aici_abi/src/lib.rs b/aici_abi/src/lib.rs index bffddaa3..860f2ab2 100644 --- a/aici_abi/src/lib.rs +++ b/aici_abi/src/lib.rs @@ -34,6 +34,29 @@ pub struct PreProcessResult { /// Attention mask of length 0 is equivalent [1.0, ..., 1.0]. /// Otherwise, length of the mask should be the same as the number of prompt + generated tokens. pub attention_masks: Vec>, + + pub suspend: bool, +} + +impl PreProcessResult { + pub fn new(attention_masks: Vec>) -> Self { + PreProcessResult { + attention_masks, + suspend: false, + } + } + pub fn continue_() -> Self { + PreProcessResult::new(vec![vec![]]) + } + pub fn suspend() -> Self { + PreProcessResult { + attention_masks: vec![vec![]], + suspend: true, + } + } + pub fn stop() -> Self { + PreProcessResult::new(vec![]) + } } #[derive(Serialize, Deserialize, Debug)] @@ -75,9 +98,7 @@ pub trait AiciVm { /// Called after tokens are appended, before process(). fn pre_process(&mut self, _arg: PreProcessArg) -> PreProcessResult { - PreProcessResult { - attention_masks: vec![vec![]], - } + PreProcessResult::continue_() } /// This is the main entry point for the module. @@ -108,7 +129,6 @@ pub trait AiciVm { let res_bytes = serde_json::to_vec(&res).unwrap(); host::return_process_result(&res_bytes); } - } /// Expose method as extern "C", usage: diff --git a/aici_abi/src/recognizer.rs b/aici_abi/src/recognizer.rs index 50e3e9ab..ca9a4189 100644 --- a/aici_abi/src/recognizer.rs +++ b/aici_abi/src/recognizer.rs @@ -30,9 +30,7 @@ impl AiciVm for AiciRecognizer { } self.rec.collapse(); } - crate::PreProcessResult { - attention_masks: vec![vec![]], - } + crate::PreProcessResult::continue_() } fn process(&mut self, _arg: ProcessArg) -> ProcessResult { From 8eefcec97c6fd2743161bddd35ace5cf0b9e1913 Mon Sep 17 00:00:00 2001 From: Michal Moskal Date: Tue, 21 Nov 2023 17:58:17 +0000 Subject: [PATCH 095/301] new 3-step callback structure --- aici_abi/src/lib.rs | 78 ++++++++++++++++++++++++-------------- aici_abi/src/recognizer.rs | 20 +++++----- 2 files changed, 59 insertions(+), 39 deletions(-) diff --git a/aici_abi/src/lib.rs b/aici_abi/src/lib.rs index 860f2ab2..995d09f4 100644 --- a/aici_abi/src/lib.rs +++ b/aici_abi/src/lib.rs @@ -19,12 +19,7 @@ pub struct InitPromptArg { pub struct SeqId(pub u32); #[derive(Serialize, Deserialize, Debug)] -pub struct PreProcessArg { - /// Generally, issued after each token generated by the model. - /// `tokens` is typically just this one token, except for the first call, when - /// `tokens` is empty, and the cases when fast-forward tokens are used. 
- pub tokens: Vec, -} +pub struct PreProcessArg {} #[derive(Serialize, Deserialize, Debug)] pub struct PreProcessResult { @@ -38,27 +33,6 @@ pub struct PreProcessResult { pub suspend: bool, } -impl PreProcessResult { - pub fn new(attention_masks: Vec>) -> Self { - PreProcessResult { - attention_masks, - suspend: false, - } - } - pub fn continue_() -> Self { - PreProcessResult::new(vec![vec![]]) - } - pub fn suspend() -> Self { - PreProcessResult { - attention_masks: vec![vec![]], - suspend: true, - } - } - pub fn stop() -> Self { - PreProcessResult::new(vec![]) - } -} - #[derive(Serialize, Deserialize, Debug)] pub struct ProcessArg { /// fork_group.len() == attention_masks.len(). @@ -91,12 +65,45 @@ pub enum ProcessResult { }, } +#[derive(Serialize, Deserialize, Debug)] +pub struct PostProcessArg { + /// Generally, issued after each token generated by the model. + /// `tokens` is typically just this one token, except for the + /// cases when fast-forward tokens are used. + pub tokens: Vec, +} + +#[derive(Serialize, Deserialize, Debug)] +pub struct PostProcessResult {} + +impl PreProcessResult { + pub fn new(attention_masks: Vec>) -> Self { + PreProcessResult { + attention_masks, + suspend: false, + } + } + pub fn continue_() -> Self { + PreProcessResult::new(vec![vec![]]) + } + pub fn suspend() -> Self { + PreProcessResult { + attention_masks: vec![vec![]], + suspend: true, + } + } + pub fn stop() -> Self { + PreProcessResult::new(vec![]) + } +} + pub trait AiciVm { /// Called with the initial prompt. Has long time limit. /// By default ignore prompt. fn init_prompt(&mut self, _arg: InitPromptArg) {} - /// Called after tokens are appended, before process(). + /// Called before process(), can return attention masks. Has short time limit. + /// Should be stateless. fn pre_process(&mut self, _arg: PreProcessArg) -> PreProcessResult { PreProcessResult::continue_() } @@ -110,6 +117,11 @@ pub trait AiciVm { /// Either way, a bias should be eventually generated. fn process(&mut self, arg: ProcessArg) -> ProcessResult; + /// Called after tokens are appended, before process(). + fn post_process(&mut self, _arg: PostProcessArg) -> PostProcessResult { + PostProcessResult {} + } + // Internals fn aici_init_prompt(&mut self) { let arg: InitPromptArg = serde_json::from_slice(&host::process_arg_bytes()).unwrap(); @@ -129,6 +141,13 @@ pub trait AiciVm { let res_bytes = serde_json::to_vec(&res).unwrap(); host::return_process_result(&res_bytes); } + + fn aici_post_process(&mut self) { + let arg: PostProcessArg = serde_json::from_slice(&host::process_arg_bytes()).unwrap(); + let res = self.post_process(arg); + let res_bytes = serde_json::to_vec(&res).unwrap(); + host::return_process_result(&res_bytes); + } } /// Expose method as extern "C", usage: @@ -158,8 +177,9 @@ macro_rules! expose { #[macro_export] macro_rules! 
aici_expose_all { ($struct_name:ident, $new:expr) => { - $crate::expose!($struct_name::aici_process() -> ()); $crate::expose!($struct_name::aici_pre_process() -> ()); + $crate::expose!($struct_name::aici_process() -> ()); + $crate::expose!($struct_name::aici_post_process() -> ()); $crate::expose!($struct_name::aici_init_prompt() -> ()); #[no_mangle] diff --git a/aici_abi/src/recognizer.rs b/aici_abi/src/recognizer.rs index ca9a4189..80989092 100644 --- a/aici_abi/src/recognizer.rs +++ b/aici_abi/src/recognizer.rs @@ -3,7 +3,7 @@ use std::fmt::Debug; use crate::{ host, toktree::{Recognizer, SpecialToken, TokTrie}, - AiciVm, ProcessArg, ProcessResult, + AiciVm, PostProcessArg, PostProcessResult, ProcessArg, ProcessResult, }; pub struct AiciRecognizer { @@ -21,7 +21,14 @@ impl AiciRecognizer { } impl AiciVm for AiciRecognizer { - fn pre_process(&mut self, arg: crate::PreProcessArg) -> crate::PreProcessResult { + fn process(&mut self, _arg: ProcessArg) -> ProcessResult { + let mut set = self.trie.alloc_token_set(); + self.trie.compute_bias(&mut self.rec, &mut set); + host::return_logit_bias(&set); + ProcessResult::SampleWithBias + } + + fn post_process(&mut self, arg: PostProcessArg) -> PostProcessResult { for token in arg.tokens { let bytes = self.trie.token(token); // wprintln!("process {} {:?}", token, bytes); @@ -30,14 +37,7 @@ impl AiciVm for AiciRecognizer { } self.rec.collapse(); } - crate::PreProcessResult::continue_() - } - - fn process(&mut self, _arg: ProcessArg) -> ProcessResult { - let mut set = self.trie.alloc_token_set(); - self.trie.compute_bias(&mut self.rec, &mut set); - host::return_logit_bias(&set); - ProcessResult::SampleWithBias + PostProcessResult {} } } From d1fa9f61f0dd803ecf7e6d7fca9b55f071850de4 Mon Sep 17 00:00:00 2001 From: Michal Moskal Date: Tue, 21 Nov 2023 18:10:41 +0000 Subject: [PATCH 096/301] rename: process -> mid_process --- aici_abi/src/host.rs | 2 +- aici_abi/src/lib.rs | 14 +++++++------- aici_abi/src/recognizer.rs | 6 +++--- 3 files changed, 11 insertions(+), 11 deletions(-) diff --git a/aici_abi/src/host.rs b/aici_abi/src/host.rs index 48c7a33c..d40c5b6c 100644 --- a/aici_abi/src/host.rs +++ b/aici_abi/src/host.rs @@ -27,7 +27,7 @@ extern "C" { fn aici_host_module_arg() -> BlobId; // Return the ID of argument passed to the process() function. - // It's a JSON serialization of ProcessArg. + // It's a JSON serialization of Pre/Mid/PostProcessArg. fn aici_host_process_arg() -> BlobId; // Tokenize given UTF8 string. The result is only valid until next call to this function. diff --git a/aici_abi/src/lib.rs b/aici_abi/src/lib.rs index 995d09f4..8f5bcab2 100644 --- a/aici_abi/src/lib.rs +++ b/aici_abi/src/lib.rs @@ -34,14 +34,14 @@ pub struct PreProcessResult { } #[derive(Serialize, Deserialize, Debug)] -pub struct ProcessArg { +pub struct MidProcessArg { /// fork_group.len() == attention_masks.len(). /// Use host::self_seq_id() to get the ID of the current sequence. pub fork_group: Vec, } #[derive(Serialize, Deserialize, Debug)] -pub enum ProcessResult { +pub enum MidProcessResult { /// Stop the current sequence. /// Similar to strong bias to EOS. Stop, @@ -115,7 +115,7 @@ pub trait AiciVm { /// * `Append { tokens: [t] }` - when a token `t` is sampled /// * `Append { tokens: [t...] }` - after fast-forward /// Either way, a bias should be eventually generated. - fn process(&mut self, arg: ProcessArg) -> ProcessResult; + fn mid_process(&mut self, arg: MidProcessArg) -> MidProcessResult; /// Called after tokens are appended, before process(). 
fn post_process(&mut self, _arg: PostProcessArg) -> PostProcessResult { @@ -135,9 +135,9 @@ pub trait AiciVm { host::return_process_result(&res_bytes); } - fn aici_process(&mut self) { - let arg: ProcessArg = serde_json::from_slice(&host::process_arg_bytes()).unwrap(); - let res = self.process(arg); + fn aici_mid_process(&mut self) { + let arg: MidProcessArg = serde_json::from_slice(&host::process_arg_bytes()).unwrap(); + let res = self.mid_process(arg); let res_bytes = serde_json::to_vec(&res).unwrap(); host::return_process_result(&res_bytes); } @@ -178,7 +178,7 @@ macro_rules! expose { macro_rules! aici_expose_all { ($struct_name:ident, $new:expr) => { $crate::expose!($struct_name::aici_pre_process() -> ()); - $crate::expose!($struct_name::aici_process() -> ()); + $crate::expose!($struct_name::aici_mid_process() -> ()); $crate::expose!($struct_name::aici_post_process() -> ()); $crate::expose!($struct_name::aici_init_prompt() -> ()); diff --git a/aici_abi/src/recognizer.rs b/aici_abi/src/recognizer.rs index 80989092..5fafe111 100644 --- a/aici_abi/src/recognizer.rs +++ b/aici_abi/src/recognizer.rs @@ -3,7 +3,7 @@ use std::fmt::Debug; use crate::{ host, toktree::{Recognizer, SpecialToken, TokTrie}, - AiciVm, PostProcessArg, PostProcessResult, ProcessArg, ProcessResult, + AiciVm, PostProcessArg, PostProcessResult, MidProcessArg, MidProcessResult, }; pub struct AiciRecognizer { @@ -21,11 +21,11 @@ impl AiciRecognizer { } impl AiciVm for AiciRecognizer { - fn process(&mut self, _arg: ProcessArg) -> ProcessResult { + fn mid_process(&mut self, _arg: MidProcessArg) -> MidProcessResult { let mut set = self.trie.alloc_token_set(); self.trie.compute_bias(&mut self.rec, &mut set); host::return_logit_bias(&set); - ProcessResult::SampleWithBias + MidProcessResult::SampleWithBias } fn post_process(&mut self, arg: PostProcessArg) -> PostProcessResult { From 1c3e36c9876d056456412147fa0a6fd1749b4ddb Mon Sep 17 00:00:00 2001 From: Michal Moskal Date: Tue, 21 Nov 2023 18:23:53 +0000 Subject: [PATCH 097/301] remove unused option --- aici_abi/src/lib.rs | 7 ------- 1 file changed, 7 deletions(-) diff --git a/aici_abi/src/lib.rs b/aici_abi/src/lib.rs index 8f5bcab2..8d9741e5 100644 --- a/aici_abi/src/lib.rs +++ b/aici_abi/src/lib.rs @@ -56,13 +56,6 @@ pub enum MidProcessResult { backtrack: u32, ff_tokens: Vec, }, - - /// Wait until all listed variables are available for reading, - /// and all listed sequences have finished executing. 
- WaitAll { - variables: Vec, - finished: Vec, - }, } #[derive(Serialize, Deserialize, Debug)] From 3a5c45a9969856aa70ae11d979300ffa97678bf1 Mon Sep 17 00:00:00 2001 From: Michal Moskal Date: Tue, 21 Nov 2023 19:47:56 +0000 Subject: [PATCH 098/301] cleaner return value from mid_process() --- aici_abi/src/lib.rs | 29 ++++++++++++++++++++++------- aici_abi/src/recognizer.rs | 8 ++++---- aici_abi/src/svob.rs | 16 ++++++++++++++++ 3 files changed, 42 insertions(+), 11 deletions(-) diff --git a/aici_abi/src/lib.rs b/aici_abi/src/lib.rs index 8d9741e5..b262fbf1 100644 --- a/aici_abi/src/lib.rs +++ b/aici_abi/src/lib.rs @@ -1,7 +1,8 @@ use serde::{Deserialize, Serialize}; +use svob::SimpleVob; pub mod bytes; -pub mod host; +mod host; pub mod recognizer; pub mod rng; pub mod svob; @@ -9,6 +10,11 @@ pub mod toktree; pub type TokenId = bytes::TokenId; +pub use host::{ + _print, arg_bytes, self_seq_id, stdout, tokenize, StorageCmd, StorageOp, StorageResp, + VariableStorage, +}; + #[derive(Serialize, Deserialize, Debug)] pub struct InitPromptArg { pub prompt: Vec, @@ -46,8 +52,11 @@ pub enum MidProcessResult { /// Similar to strong bias to EOS. Stop, - /// Sample next token in the current sequence, using bias set with `return_logit_bias()` - SampleWithBias, + /// Sample next token in the current sequence + SampleWithBias { + #[serde(skip)] + allowed_tokens: SimpleVob, + }, /// First pop `backtrack` tokens, /// then force next tokens to be generated to be `ff_tokens`. @@ -131,6 +140,12 @@ pub trait AiciVm { fn aici_mid_process(&mut self) { let arg: MidProcessArg = serde_json::from_slice(&host::process_arg_bytes()).unwrap(); let res = self.mid_process(arg); + match &res { + MidProcessResult::SampleWithBias { allowed_tokens } => { + host::return_logit_bias(allowed_tokens); + } + _ => {} + } let res_bytes = serde_json::to_vec(&res).unwrap(); host::return_process_result(&res_bytes); } @@ -210,17 +225,17 @@ macro_rules! include_bytes_aligned { #[macro_export] macro_rules! wprintln { () => { - $crate::host::_print("\n") + $crate::_print("\n") }; ($($arg:tt)*) => {{ - $crate::host::_print(&format!($($arg)*)); - $crate::host::_print("\n"); + $crate::_print(&format!($($arg)*)); + $crate::_print("\n"); }}; } #[macro_export] macro_rules! 
wprint { ($($arg:tt)*) => {{ - $crate::host::_print(&format!($($arg)*)); + $crate::_print(&format!($($arg)*)); }}; } diff --git a/aici_abi/src/recognizer.rs b/aici_abi/src/recognizer.rs index 5fafe111..e35f1602 100644 --- a/aici_abi/src/recognizer.rs +++ b/aici_abi/src/recognizer.rs @@ -1,9 +1,8 @@ use std::fmt::Debug; use crate::{ - host, toktree::{Recognizer, SpecialToken, TokTrie}, - AiciVm, PostProcessArg, PostProcessResult, MidProcessArg, MidProcessResult, + AiciVm, MidProcessArg, MidProcessResult, PostProcessArg, PostProcessResult, }; pub struct AiciRecognizer { @@ -24,8 +23,9 @@ impl AiciVm for AiciRecognizer { fn mid_process(&mut self, _arg: MidProcessArg) -> MidProcessResult { let mut set = self.trie.alloc_token_set(); self.trie.compute_bias(&mut self.rec, &mut set); - host::return_logit_bias(&set); - MidProcessResult::SampleWithBias + MidProcessResult::SampleWithBias { + allowed_tokens: set, + } } fn post_process(&mut self, arg: PostProcessArg) -> PostProcessResult { diff --git a/aici_abi/src/svob.rs b/aici_abi/src/svob.rs index f3dd062f..9e97e7ce 100644 --- a/aici_abi/src/svob.rs +++ b/aici_abi/src/svob.rs @@ -1,3 +1,5 @@ +use std::fmt::Debug; + use crate::TokenId; #[derive(Clone)] @@ -5,6 +7,20 @@ pub struct SimpleVob { data: Vec, } +impl Debug for SimpleVob { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("SimpleVob") + .field("len", &self.len()) + .finish() + } +} + +impl Default for SimpleVob { + fn default() -> Self { + Self::new() + } +} + const BITS: usize = 32; impl SimpleVob { From 8959dc6e90667c1574a150a768ea7cad37bc9af5 Mon Sep 17 00:00:00 2001 From: Michal Moskal Date: Tue, 21 Nov 2023 19:59:59 +0000 Subject: [PATCH 099/301] update comments --- aici_abi/src/lib.rs | 14 ++++---------- 1 file changed, 4 insertions(+), 10 deletions(-) diff --git a/aici_abi/src/lib.rs b/aici_abi/src/lib.rs index b262fbf1..ba968605 100644 --- a/aici_abi/src/lib.rs +++ b/aici_abi/src/lib.rs @@ -100,26 +100,20 @@ impl PreProcessResult { } pub trait AiciVm { - /// Called with the initial prompt. Has long time limit. + /// Called with the initial prompt. ~1000ms time limit. /// By default ignore prompt. fn init_prompt(&mut self, _arg: InitPromptArg) {} - /// Called before process(), can return attention masks. Has short time limit. + /// Called before mid_process(), can return attention masks. ~1ms time limit. /// Should be stateless. fn pre_process(&mut self, _arg: PreProcessArg) -> PreProcessResult { PreProcessResult::continue_() } - /// This is the main entry point for the module. - /// Following calls are issued: - /// * `Append { tokens: [] }` - to generate bias for the first token of the output - /// And then any combination of: - /// * `Append { tokens: [t] }` - when a token `t` is sampled - /// * `Append { tokens: [t...] }` - after fast-forward - /// Either way, a bias should be eventually generated. + /// This is the main entry point for the module. ~20ms time limit. fn mid_process(&mut self, arg: MidProcessArg) -> MidProcessResult; - /// Called after tokens are appended, before process(). + /// Called after tokens are appended, after mid_process(). ~1ms time limit. 
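+    /// Taken together, a typical step runs: pre_process() (may suspend or fork),
+    /// then mid_process() (returns the sampling decision), then post_process()
+    /// (observes the tokens that were actually appended).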
    fn post_process(&mut self, _arg: PostProcessArg) -> PostProcessResult {
        PostProcessResult {}
    }

From 4d936c9cba98c32e5eb14ca7461c8098748d5f93 Mon Sep 17 00:00:00 2001
From: Michal Moskal 
Date: Tue, 21 Nov 2023 23:05:11 +0000
Subject: [PATCH 100/301] backtracking in ast runner

---
 aici_abi/src/lib.rs | 1 +
 1 file changed, 1 insertion(+)

diff --git a/aici_abi/src/lib.rs b/aici_abi/src/lib.rs
index ba968605..df771476 100644
--- a/aici_abi/src/lib.rs
+++ b/aici_abi/src/lib.rs
@@ -60,6 +60,7 @@ pub enum MidProcessResult {
 
     /// First pop `backtrack` tokens,
     /// then force next tokens to be generated to be `ff_tokens`.
+    /// `backtrack` count includes the token about to be generated from this step.
     /// `backtrack` can be 0, and `ff_tokens` can be empty but not both.
     Splice {
         backtrack: u32,

From 0889158bcf998712778dae4e738073d28743febe Mon Sep 17 00:00:00 2001
From: Michal Moskal 
Date: Wed, 22 Nov 2023 00:04:57 +0000
Subject: [PATCH 101/301] hooking up backtracking

---
 aici_abi/src/lib.rs | 3 +++
 1 file changed, 3 insertions(+)

diff --git a/aici_abi/src/lib.rs b/aici_abi/src/lib.rs
index df771476..0377770e 100644
--- a/aici_abi/src/lib.rs
+++ b/aici_abi/src/lib.rs
@@ -74,6 +74,9 @@ pub struct PostProcessArg {
     /// `tokens` is typically just this one token, except for the
     /// cases when fast-forward tokens are used.
     pub tokens: Vec<TokenId>,
+
+    /// Typically 0.
+    pub backtrack: u32,
 }
 

From d6d5026239467850b68131ae2015926a4bf4b5d9 Mon Sep 17 00:00:00 2001
From: Michal Moskal 
Date: Wed, 29 Nov 2023 00:09:35 +0000
Subject: [PATCH 102/301] expressions in ast

---
 aici_abi/src/host.rs | 9 +++++++++
 aici_abi/src/lib.rs  | 2 +-
 aici_abi/src/svob.rs | 8 ++++++++
 3 files changed, 18 insertions(+), 1 deletion(-)

diff --git a/aici_abi/src/host.rs b/aici_abi/src/host.rs
index d40c5b6c..f939d617 100644
--- a/aici_abi/src/host.rs
+++ b/aici_abi/src/host.rs
@@ -250,6 +250,15 @@ impl VariableStorage {
     }
 }
 
+/// Tokenize given byte string.
+pub fn tokenize_bytes(s: &[u8]) -> Vec<TokenId> {
+    let id = unsafe { aici_host_tokenize(s.as_ptr(), s.len() as u32) };
+    let r = read_blob(id, 4 * (s.len() / 3 + 10));
+    let res = vec_from_bytes(&r);
+    wprintln!("tokenize_bytes: {:?} -> {:?}", String::from_utf8_lossy(s), res);
+    res
+}
+
 /// Tokenize given UTF8 string.
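 /// A usage sketch (the returned IDs are tokenizer-specific):
 /// `let ids = tokenize("Hello world");`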
pub fn tokenize(s: &str) -> Vec<TokenId> {
    let id = unsafe { aici_host_tokenize(s.as_ptr(), s.len() as u32) };

diff --git a/aici_abi/src/lib.rs b/aici_abi/src/lib.rs
index 0377770e..56b8f88c 100644
--- a/aici_abi/src/lib.rs
+++ b/aici_abi/src/lib.rs
@@ -11,7 +11,7 @@ pub mod toktree;
 pub type TokenId = bytes::TokenId;
 
 pub use host::{
-    _print, arg_bytes, self_seq_id, stdout, tokenize, StorageCmd, StorageOp, StorageResp,
+    _print, arg_bytes, self_seq_id, stdout, tokenize, tokenize_bytes, StorageCmd, StorageOp, StorageResp,
     VariableStorage,
 };
 
diff --git a/aici_abi/src/svob.rs b/aici_abi/src/svob.rs
index 9e97e7ce..b3809f7c 100644
--- a/aici_abi/src/svob.rs
+++ b/aici_abi/src/svob.rs
@@ -44,6 +44,14 @@ impl SimpleVob {
         self.data[byte_idx] |= 1 << bit_idx;
     }
 
+    #[inline(always)]
+    pub fn disallow_token(&mut self, tok: TokenId) {
+        let idx = tok as usize;
+        let byte_idx = idx / BITS;
+        let bit_idx = idx % BITS;
+        self.data[byte_idx] &= !(1 << bit_idx);
+    }
+
     pub fn resize(&mut self, size: usize) {
         let new_size = size / BITS + 1;
         assert!(new_size >= self.data.len());

From 97e2ba7b51a0837c5b35e8d52523d4cdf3336c11 Mon Sep 17 00:00:00 2001
From: Michal Moskal
Date: Wed, 29 Nov 2023 18:21:07 +0000
Subject: [PATCH 103/301] return storage operations to the user

---
 aici_abi/src/host.rs | 49 ++++++++++++++++++++++++++++++++++++++++++--
 1 file changed, 47 insertions(+), 2 deletions(-)

diff --git a/aici_abi/src/host.rs b/aici_abi/src/host.rs
index f939d617..35bc2ead 100644
--- a/aici_abi/src/host.rs
+++ b/aici_abi/src/host.rs
@@ -164,6 +164,42 @@ pub enum StorageOp {
     Append,
 }
 
+#[allow(dead_code)]
+pub mod bin_string {
+    use serde::{Deserialize, Serialize};
+    use serde::{Deserializer, Serializer};
+
+    pub fn serialize<S: Serializer>(v: &Vec<u8>, s: S) -> Result<S::Ok, S::Error> {
+        let binstr = String::from_iter(v.iter().map(|b| *b as char));
+        String::serialize(&binstr, s)
+    }
+
+    pub fn deserialize<'de, D: Deserializer<'de>>(d: D) -> Result<Vec<u8>, D::Error> {
+        let binstr = String::deserialize(d)?;
+        Ok(binstr.chars().map(|c| c as u8).collect())
+    }
+}
+
+pub mod hex_string {
+    use serde::{Deserialize, Serialize};
+    use serde::{Deserializer, Serializer};
+
+    pub fn serialize<S: Serializer>(v: &Vec<u8>, s: S) -> Result<S::Ok, S::Error> {
+        let hexstr = String::from_iter(v.iter().map(|b| format!("{:02x}", b)));
+        String::serialize(&hexstr, s)
+    }
+
+    pub fn deserialize<'de, D: Deserializer<'de>>(d: D) -> Result<Vec<u8>, D::Error> {
+        let hexstr = String::deserialize(d)?;
+        let mut res = Vec::new();
+        for i in 0..(hexstr.len() / 2) {
+            let b = u8::from_str_radix(&hexstr[2 * i..2 * i + 2], 16).map_err(serde::de::Error::custom)?;
+            res.push(b);
+        }
+        Ok(res)
+    }
+}
+
 #[derive(Serialize, Deserialize, Debug)]
 pub enum StorageCmd {
     /// Read variable. Returns StorageResp::ReadVar or StorageResp::VariableMissing.
@@ -177,6 +213,7 @@ pub enum StorageCmd {
     /// just like ReadVar would.
     WriteVar {
         name: String,
+        #[serde(with = "hex_string")]
        value: Vec<u8>,
         op: StorageOp,
         when_version_is: Option<u64>,
     },
 
@@ -186,7 +223,11 @@ pub enum StorageCmd {
 #[derive(Serialize, Deserialize, Debug)]
 pub enum StorageResp {
     /// Upon handling the request the variable had the specified value and version number.
-    ReadVar { version: u64, value: Vec<u8> },
+    ReadVar {
+        version: u64,
+        #[serde(with = "hex_string")]
+        value: Vec<u8>,
+    },
     /// Upon handling the request the variable was unset.
     VariableMissing {},
     /// The variable has been written, and the new version is returned.
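
Editor's note: with `#[serde(with = "hex_string")]` the byte payload crosses the JSON
boundary as a plain hex string. A quick round-trip sketch (the `Wrapper` type is
hypothetical; it assumes the `hex_string` module above is in scope):

```rust
use serde::{Deserialize, Serialize};

#[derive(Serialize, Deserialize)]
struct Wrapper {
    #[serde(with = "hex_string")]
    value: Vec<u8>,
}

fn main() {
    let w = Wrapper { value: vec![0xde, 0xad] };
    // each byte becomes two lowercase hex digits: {"value":"dead"}
    let js = serde_json::to_string(&w).unwrap();
    assert_eq!(js, r#"{"value":"dead"}"#);
    let back: Wrapper = serde_json::from_str(&js).unwrap();
    assert_eq!(back.value, vec![0xde, 0xad]);
}
```
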
@@ -255,7 +296,11 @@ pub fn tokenize_bytes(s: &[u8]) -> Vec<TokenId> {
     let id = unsafe { aici_host_tokenize(s.as_ptr(), s.len() as u32) };
     let r = read_blob(id, 4 * (s.len() / 3 + 10));
     let res = vec_from_bytes(&r);
-    wprintln!("tokenize_bytes: {:?} -> {:?}", String::from_utf8_lossy(s), res);
+    wprintln!(
+        "tokenize_bytes: {:?} -> {:?}",
+        String::from_utf8_lossy(s),
+        res
+    );
     res
 }

From cd6761c134e0c8981097327c546c5a7acb28e510 Mon Sep 17 00:00:00 2001
From: Michal Moskal
Date: Fri, 1 Dec 2023 23:17:14 +0000
Subject: [PATCH 104/301] basic sample working with py

---
 aici_abi/src/lib.rs | 8 +++++---
 1 file changed, 5 insertions(+), 3 deletions(-)

diff --git a/aici_abi/src/lib.rs b/aici_abi/src/lib.rs
index 56b8f88c..7daa60d0 100644
--- a/aici_abi/src/lib.rs
+++ b/aici_abi/src/lib.rs
@@ -11,8 +11,8 @@ pub mod toktree;
 pub type TokenId = bytes::TokenId;
 
 pub use host::{
-    _print, arg_bytes, self_seq_id, stdout, tokenize, tokenize_bytes, StorageCmd, StorageOp, StorageResp,
-    VariableStorage,
+    _print, arg_bytes, return_logit_bias, self_seq_id, stdout, tokenize, tokenize_bytes,
+    StorageCmd, StorageOp, StorageResp, VariableStorage,
 };
 
@@ -140,7 +140,9 @@ pub trait AiciVm {
         let res = self.mid_process(arg);
         match &res {
             MidProcessResult::SampleWithBias { allowed_tokens } => {
-                host::return_logit_bias(allowed_tokens);
+                if allowed_tokens.len() > 0 {
+                    host::return_logit_bias(allowed_tokens);
+                }
             }
             _ => {}
         }

From 8716610557d2aa4a9e138166b58c7f736152a377 Mon Sep 17 00:00:00 2001
From: Michal Moskal
Date: Sat, 2 Dec 2023 00:11:50 +0000
Subject: [PATCH 105/301] add rx

---
 aici_abi/src/recognizer.rs | 8 ++++++++
 1 file changed, 8 insertions(+)

diff --git a/aici_abi/src/recognizer.rs b/aici_abi/src/recognizer.rs
index e35f1602..b1d1f366 100644
--- a/aici_abi/src/recognizer.rs
+++ b/aici_abi/src/recognizer.rs
@@ -75,6 +75,14 @@ impl<S: Copy, R: FunctionalRecognizer<S>> StackRecognizer<S, R> {
     }
 }
 
+impl<S: Copy + Debug, R: FunctionalRecognizer<S>> Debug for StackRecognizer<S, R> {
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        f.debug_struct("StackRecognizer")
+            .field("top", &self.stack[self.stack_ptr])
+            .finish()
+    }
+}
+
 impl<S: Copy, R: FunctionalRecognizer<S>> Recognizer for StackRecognizer<S, R> {
     #[inline(always)]
     fn push_byte(&mut self, byte: u8) {

From ea599949ab7c624b366554f3e8f0d01d818363f5 Mon Sep 17 00:00:00 2001
From: Michal Moskal
Date: Sat, 2 Dec 2023 04:22:44 +0000
Subject: [PATCH 106/301] work on seq finishing

---
 aici_abi/src/host.rs       | 11 ++++++++++-
 aici_abi/src/lib.rs        | 24 +++++++++++++++++++++---
 aici_abi/src/recognizer.rs |  6 +++---
 3 files changed, 34 insertions(+), 7 deletions(-)

diff --git a/aici_abi/src/host.rs b/aici_abi/src/host.rs
index 35bc2ead..052e5662 100644
--- a/aici_abi/src/host.rs
+++ b/aici_abi/src/host.rs
@@ -41,6 +41,9 @@ extern "C" {
     fn aici_host_return_process_result(res: *const u8, res_size: u32);
 
     fn aici_host_storage_cmd(cmd: *const u8, cmd_size: u32) -> BlobId;
+
+    // This can also be obtained from the TokTrie.
+    fn aici_host_eos_token() -> TokenId;
 }
 
 // TODO: add
@@ -193,7 +196,8 @@ pub mod hex_string {
         let hexstr = String::deserialize(d)?;
         let mut res = Vec::new();
         for i in 0..(hexstr.len() / 2) {
-            let b = u8::from_str_radix(&hexstr[2 * i..2 * i + 2], 16).map_err(serde::de::Error::custom)?;
+            let b = u8::from_str_radix(&hexstr[2 * i..2 * i + 2], 16)
+                .map_err(serde::de::Error::custom)?;
             res.push(b);
         }
         Ok(res)
@@ -317,3 +321,8 @@ pub fn tokenize(s: &str) -> Vec<TokenId> {
 pub fn self_seq_id() -> SeqId {
     unsafe { SeqId(aici_host_self_seq_id()) }
 }
+
+/// Return the ID of the EOS token.
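+/// Handy in post_process(), e.g.
+/// `if arg.tokens.contains(&eos_token()) { /* stop the sequence */ }`.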
+pub fn eos_token() -> TokenId {
+    unsafe { aici_host_eos_token() }
+}
diff --git a/aici_abi/src/lib.rs b/aici_abi/src/lib.rs
index 7daa60d0..28ed45e2 100644
--- a/aici_abi/src/lib.rs
+++ b/aici_abi/src/lib.rs
@@ -80,7 +80,25 @@ pub struct PostProcessArg {
 }
 
 #[derive(Serialize, Deserialize, Debug)]
-pub struct PostProcessResult {}
+pub struct PostProcessResult {
+    /// If true, stop the sequence.
+    pub stop: bool,
+}
+
+impl PostProcessResult {
+    pub fn stop() -> Self {
+        PostProcessResult { stop: true }
+    }
+
+    pub fn continue_() -> Self {
+        PostProcessResult { stop: false }
+    }
+
+    pub fn from_arg(arg: &PostProcessArg) -> Self {
+        let stop = arg.tokens.contains(&host::eos_token());
+        PostProcessResult { stop }
+    }
+}
 
 impl PreProcessResult {
     pub fn new(attention_masks: Vec<Vec<f32>>) -> Self {
@@ -118,8 +136,8 @@ pub trait AiciVm {
     fn mid_process(&mut self, arg: MidProcessArg) -> MidProcessResult;
 
     /// Called after tokens are appended, after mid_process(). ~1ms time limit.
-    fn post_process(&mut self, _arg: PostProcessArg) -> PostProcessResult {
-        PostProcessResult {}
+    fn post_process(&mut self, arg: PostProcessArg) -> PostProcessResult {
+        PostProcessResult::from_arg(&arg)
     }
 
     // Internals
diff --git a/aici_abi/src/recognizer.rs b/aici_abi/src/recognizer.rs
index b1d1f366..5c26b3ea 100644
--- a/aici_abi/src/recognizer.rs
+++ b/aici_abi/src/recognizer.rs
@@ -29,15 +29,15 @@ impl<R: Recognizer> AiciVm for AiciRecognizer<R> {
     }
 
     fn post_process(&mut self, arg: PostProcessArg) -> PostProcessResult {
-        for token in arg.tokens {
-            let bytes = self.trie.token(token);
+        for token in &arg.tokens {
+            let bytes = self.trie.token(*token);
             // wprintln!("process {} {:?}", token, bytes);
             for b in bytes {
                 self.rec.push_byte(*b)
             }
             self.rec.collapse();
         }
-        PostProcessResult {}
+        PostProcessResult::from_arg(&arg)
     }
 }

From d6d5026239467850b68131ae2015926a4bf4b5d9 Mon Sep 17 00:00:00 2001
From: Michal Moskal
Date: Sat, 9 Dec 2023 00:30:45 +0000
Subject: [PATCH 107/301] move JSON types to aici_abi::api

---
 aici_abi/src/api.rs | 80 +++++++++++++++++++++++++++++++++++++++++++++
 aici_abi/src/lib.rs |  1 +
 2 files changed, 81 insertions(+)
 create mode 100644 aici_abi/src/api.rs

diff --git a/aici_abi/src/api.rs b/aici_abi/src/api.rs
new file mode 100644
index 00000000..9e3765f6
--- /dev/null
+++ b/aici_abi/src/api.rs
@@ -0,0 +1,80 @@
+use serde::{Deserialize, Serialize};
+use serde_json::Value;
+
+use crate::TokenId;
+
+pub type ModuleInstId = usize;
+
+#[derive(Serialize, Deserialize)]
+pub struct AiciPreProcessReq {
+    pub max_context_len: usize, // in tokens
+    pub freed: Vec<ModuleInstId>,
+    pub ops: Vec<AiciPreOp>,
+}
+
+#[derive(Serialize, Deserialize)]
+pub struct AiciProcessReq {
+    pub ops: Vec<AiciMidOp>,
+}
+
+#[derive(Serialize, Deserialize)]
+pub struct AiciPostProcessReq {
+    pub ops: Vec<AiciPostOp>,
+}
+
+#[derive(Clone, Serialize, Deserialize, Debug)]
+pub struct AiciPreOp {
+    pub id: ModuleInstId,
+    pub req_id: Option<String>,
+}
+
+#[derive(Clone, Serialize, Deserialize, Debug)]
+pub struct AiciMidOp {
+    pub id: ModuleInstId,
+    pub clone_id: Option<ModuleInstId>,
+}
+
+#[derive(Serialize, Deserialize)]
+pub struct AiciPostOp {
+    pub id: ModuleInstId,
+    pub tokens: Vec<Token>,
+    #[serde(default)]
+    pub backtrack: u32,
+    pub clone_id: Option<ModuleInstId>,
+}
+
+#[derive(Serialize, Deserialize)]
+pub struct MkModuleReq {
+    pub binary: String,
+    #[serde(default)]
+    pub meta: Value,
+}
+
+#[derive(Serialize, Deserialize, Clone)]
+pub struct InstantiateReq {
+    pub req_id: String,
+    // [TokenId] or str
+    pub prompt: Value,
+    pub module_id: String,
+    #[serde(default)]
+    pub module_arg: Value,
+}
+
+pub type Token = TokenId;
+
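
Editor's note: these are plain serde types, so runtime messages can be built and
serialized directly. A hedged sketch of constructing an instantiation request (all
field values here are made up):

```rust
use serde_json::json;

fn instantiate_sketch() -> String {
    let req = InstantiateReq {
        req_id: "req-1".to_string(),
        prompt: json!("Hello world"),      // may also be a JSON array of TokenIds
        module_id: "0123abcd".to_string(), // hypothetical module hash
        module_arg: json!({}),
    };
    serde_json::to_string(&req).unwrap()
}
```
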
+#[derive(Serialize, Deserialize)] +pub struct SpecialTokenIds { + pub bos: Option, + pub eos: Option, + pub unk: Option, + pub sep: Option, + pub pad: Option, + pub cls: Option, +} + +#[derive(Serialize, Deserialize)] +pub struct TokensReq { + pub tokens: Vec, + pub special: SpecialTokenIds, +} + diff --git a/aici_abi/src/lib.rs b/aici_abi/src/lib.rs index 28ed45e2..7013282b 100644 --- a/aici_abi/src/lib.rs +++ b/aici_abi/src/lib.rs @@ -7,6 +7,7 @@ pub mod recognizer; pub mod rng; pub mod svob; pub mod toktree; +pub mod api; pub type TokenId = bytes::TokenId; From 34fb22ee4bc09ec44aa830847c5bd3b9d71d85d2 Mon Sep 17 00:00:00 2001 From: Michal Moskal Date: Sat, 9 Dec 2023 00:49:39 +0000 Subject: [PATCH 108/301] adding aicirt lib --- aici_abi/src/api.rs | 80 --------------------------------------------- aici_abi/src/lib.rs | 1 - 2 files changed, 81 deletions(-) delete mode 100644 aici_abi/src/api.rs diff --git a/aici_abi/src/api.rs b/aici_abi/src/api.rs deleted file mode 100644 index 9e3765f6..00000000 --- a/aici_abi/src/api.rs +++ /dev/null @@ -1,80 +0,0 @@ -use serde::{Deserialize, Serialize}; -use serde_json::Value; - -use crate::TokenId; - -pub type ModuleInstId = usize; - -#[derive(Serialize, Deserialize)] -pub struct AiciPreProcessReq { - pub max_context_len: usize, // in tokens - pub freed: Vec, - pub ops: Vec, -} - -#[derive(Serialize, Deserialize)] -pub struct AiciProcessReq { - pub ops: Vec, -} - -#[derive(Serialize, Deserialize)] -pub struct AiciPostProcessReq { - pub ops: Vec, -} - -#[derive(Clone, Serialize, Deserialize, Debug)] -pub struct AiciPreOp { - pub id: ModuleInstId, - pub req_id: Option, -} - -#[derive(Clone, Serialize, Deserialize, Debug)] -pub struct AiciMidOp { - pub id: ModuleInstId, - pub clone_id: Option, -} - -#[derive(Serialize, Deserialize)] -pub struct AiciPostOp { - pub id: ModuleInstId, - pub tokens: Vec, - #[serde(default)] - pub backtrack: u32, - pub clone_id: Option, -} - -#[derive(Serialize, Deserialize)] -pub struct MkModuleReq { - pub binary: String, - #[serde(default)] - pub meta: Value, -} - -#[derive(Serialize, Deserialize, Clone)] -pub struct InstantiateReq { - pub req_id: String, - // [TokenId] or str - pub prompt: Value, - pub module_id: String, - #[serde(default)] - pub module_arg: Value, -} - -pub type Token = TokenId; - -#[derive(Serialize, Deserialize)] -pub struct SpecialTokenIds { - pub bos: Option, - pub eos: Option, - pub unk: Option, - pub sep: Option, - pub pad: Option, - pub cls: Option, -} - -#[derive(Serialize, Deserialize)] -pub struct TokensReq { - pub tokens: Vec, - pub special: SpecialTokenIds, -} - diff --git a/aici_abi/src/lib.rs b/aici_abi/src/lib.rs index 7013282b..28ed45e2 100644 --- a/aici_abi/src/lib.rs +++ b/aici_abi/src/lib.rs @@ -7,7 +7,6 @@ pub mod recognizer; pub mod rng; pub mod svob; pub mod toktree; -pub mod api; pub type TokenId = bytes::TokenId; From 5da3567a6a146f92b86dd8f550287db63b4307ca Mon Sep 17 00:00:00 2001 From: Michal Moskal Date: Sat, 9 Dec 2023 01:23:13 +0000 Subject: [PATCH 109/301] use TokTrie in server --- aici_abi/src/toktree.rs | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/aici_abi/src/toktree.rs b/aici_abi/src/toktree.rs index b3202695..079deb1d 100644 --- a/aici_abi/src/toktree.rs +++ b/aici_abi/src/toktree.rs @@ -193,6 +193,17 @@ impl TokTrie { &self.token_data[off..(off + len as usize)] } + pub fn decode(&self, tokens: &[TokenId]) -> Vec { + tokens + .iter() + .flat_map(|t| self.token(*t).to_vec()) + .collect() + } + + pub fn decode_str(&self, tokens: &[TokenId]) -> String { 
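+        // lossy: token boundaries may fall inside multi-byte UTF-8 sequences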
+ String::from_utf8_lossy(&self.decode(tokens)).to_string() + } + pub fn greedy_tokenize(&self, bytes: &[u8]) -> Vec { let mut r = Vec::new(); if bytes.len() == 0 { From ffa02cf6692c7fab5f3686f87100f1c4cae5c1b2 Mon Sep 17 00:00:00 2001 From: Michal Moskal Date: Sat, 9 Dec 2023 19:32:53 +0000 Subject: [PATCH 110/301] typed aici cmdchannel ifaces --- aici_abi/src/host.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/aici_abi/src/host.rs b/aici_abi/src/host.rs index 052e5662..da67f46f 100644 --- a/aici_abi/src/host.rs +++ b/aici_abi/src/host.rs @@ -161,7 +161,7 @@ pub fn return_process_result(res: &[u8]) { } } -#[derive(Serialize, Deserialize, Debug)] +#[derive(Serialize, Deserialize, Debug, Clone)] pub enum StorageOp { Set, Append, @@ -204,7 +204,7 @@ pub mod hex_string { } } -#[derive(Serialize, Deserialize, Debug)] +#[derive(Serialize, Deserialize, Debug, Clone)] pub enum StorageCmd { /// Read variable. Returns StorageResp::ReadVar or StorageResp::VariableMissing. ReadVar { name: String }, From ed4654ee73e4acab3de5dc1cd8ec64b41bd652b1 Mon Sep 17 00:00:00 2001 From: Michal Moskal Date: Sun, 10 Dec 2023 01:09:50 +0000 Subject: [PATCH 111/301] fixes --- aici_abi/src/toktree.rs | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/aici_abi/src/toktree.rs b/aici_abi/src/toktree.rs index 079deb1d..c642615d 100644 --- a/aici_abi/src/toktree.rs +++ b/aici_abi/src/toktree.rs @@ -182,6 +182,10 @@ impl TokTrie { vec![0.0; self.vocab_size() + 1] } + pub fn token_dbg(&self, idx: u32) -> String { + format!("{:?}[{}]", self.token_str(idx), idx) + } + pub fn token_str(&self, idx: u32) -> String { String::from_utf8_lossy(self.token(idx)).to_string() } From f9937ad5d7c887a76ee8a5979351905b1e61612f Mon Sep 17 00:00:00 2001 From: Michal Moskal Date: Thu, 14 Dec 2023 20:52:38 -0800 Subject: [PATCH 112/301] allow returning ff_tokens from init_prompt --- aici_abi/src/lib.rs | 13 +++++++++++-- 1 file changed, 11 insertions(+), 2 deletions(-) diff --git a/aici_abi/src/lib.rs b/aici_abi/src/lib.rs index 28ed45e2..b6d0e4c0 100644 --- a/aici_abi/src/lib.rs +++ b/aici_abi/src/lib.rs @@ -20,6 +20,11 @@ pub struct InitPromptArg { pub prompt: Vec, } +#[derive(Serialize, Deserialize, Debug, Default)] +pub struct InitPromptResult { + pub ff_tokens: Vec, +} + #[repr(transparent)] #[derive(Serialize, Deserialize, Debug, PartialEq, Eq)] pub struct SeqId(pub u32); @@ -124,7 +129,9 @@ impl PreProcessResult { pub trait AiciVm { /// Called with the initial prompt. ~1000ms time limit. /// By default ignore prompt. - fn init_prompt(&mut self, _arg: InitPromptArg) {} + fn init_prompt(&mut self, _arg: InitPromptArg) -> InitPromptResult { + InitPromptResult { ff_tokens: vec![] } + } /// Called before mid_process(), can return attention masks. ~1ms time limit. /// Should be stateless. 
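
Editor's note: with init_prompt() now returning a value, a controller can acknowledge
the prompt explicitly. A minimal sketch (the `MyCtrl` type and its `prompt` field are
hypothetical):

```rust
impl AiciVm for MyCtrl {
    fn init_prompt(&mut self, arg: InitPromptArg) -> InitPromptResult {
        // remember the prompt for later constraint decisions
        self.prompt = arg.prompt.clone();
        InitPromptResult::default()
    }
    // mid_process() etc. unchanged
}
```
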
@@ -143,7 +150,9 @@ pub trait AiciVm { // Internals fn aici_init_prompt(&mut self) { let arg: InitPromptArg = serde_json::from_slice(&host::process_arg_bytes()).unwrap(); - self.init_prompt(arg); + let res = self.init_prompt(arg); + let res_bytes = serde_json::to_vec(&res).unwrap(); + host::return_process_result(&res_bytes); } fn aici_pre_process(&mut self) { From 0dc833b553cd40f2dc8ff51ccb2d1a0229d6e4d8 Mon Sep 17 00:00:00 2001 From: Michal Moskal Date: Thu, 14 Dec 2023 21:36:39 -0800 Subject: [PATCH 113/301] remove ff_tokens from InitPromptResult --- aici_abi/src/lib.rs | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/aici_abi/src/lib.rs b/aici_abi/src/lib.rs index b6d0e4c0..109b1aa6 100644 --- a/aici_abi/src/lib.rs +++ b/aici_abi/src/lib.rs @@ -21,9 +21,7 @@ pub struct InitPromptArg { } #[derive(Serialize, Deserialize, Debug, Default)] -pub struct InitPromptResult { - pub ff_tokens: Vec, -} +pub struct InitPromptResult {} #[repr(transparent)] #[derive(Serialize, Deserialize, Debug, PartialEq, Eq)] @@ -130,7 +128,7 @@ pub trait AiciVm { /// Called with the initial prompt. ~1000ms time limit. /// By default ignore prompt. fn init_prompt(&mut self, _arg: InitPromptArg) -> InitPromptResult { - InitPromptResult { ff_tokens: vec![] } + InitPromptResult::default() } /// Called before mid_process(), can return attention masks. ~1ms time limit. From f5f787874325f6ee53ed5fd24bf074a2ae63d813 Mon Sep 17 00:00:00 2001 From: Michal Moskal Date: Thu, 14 Dec 2023 22:18:18 -0800 Subject: [PATCH 114/301] allow for pre.ff_tokens --- aici_abi/src/lib.rs | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/aici_abi/src/lib.rs b/aici_abi/src/lib.rs index 109b1aa6..89f3ad6f 100644 --- a/aici_abi/src/lib.rs +++ b/aici_abi/src/lib.rs @@ -40,6 +40,10 @@ pub struct PreProcessResult { pub attention_masks: Vec>, pub suspend: bool, + + /// If non-empty, the tokens may be appended and post_process() be called immediately, + /// skipping mid_process(); pre_process() is then typically called again. 
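+    /// For example, `PreProcessResult::ff_tokens(tokenize("\n"))` appends a newline
+    /// without sampling.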
+ pub ff_tokens: Vec, } #[derive(Serialize, Deserialize, Debug)] @@ -108,6 +112,7 @@ impl PreProcessResult { PreProcessResult { attention_masks, suspend: false, + ff_tokens: vec![], } } pub fn continue_() -> Self { @@ -117,11 +122,19 @@ impl PreProcessResult { PreProcessResult { attention_masks: vec![vec![]], suspend: true, + ff_tokens: vec![], } } pub fn stop() -> Self { PreProcessResult::new(vec![]) } + pub fn ff_tokens(toks: Vec) -> Self { + PreProcessResult { + attention_masks: vec![vec![]], + suspend: false, + ff_tokens: toks, + } + } } pub trait AiciVm { From 92eacd644b104e24ff71cdc8fc4da20a95518ff5 Mon Sep 17 00:00:00 2001 From: Michal Moskal Date: Wed, 20 Dec 2023 10:00:25 +0000 Subject: [PATCH 115/301] merge imports --- aici_abi/src/host.rs | 11 ++++------- aici_abi/src/recognizer.rs | 3 +-- aici_abi/src/svob.rs | 3 +-- 3 files changed, 6 insertions(+), 11 deletions(-) diff --git a/aici_abi/src/host.rs b/aici_abi/src/host.rs index da67f46f..5d8b00b2 100644 --- a/aici_abi/src/host.rs +++ b/aici_abi/src/host.rs @@ -1,11 +1,10 @@ -use serde::{Deserialize, Serialize}; -use std::io; - use crate::{ bytes::{vec_from_bytes, TokenId}, svob::SimpleVob, wprintln, SeqId, }; +use serde::{Deserialize, Serialize}; +use std::io; #[repr(transparent)] #[derive(Clone, Copy, Debug, PartialEq, Eq)] @@ -169,8 +168,7 @@ pub enum StorageOp { #[allow(dead_code)] pub mod bin_string { - use serde::{Deserialize, Serialize}; - use serde::{Deserializer, Serializer}; + use serde::{Deserialize, Deserializer, Serialize, Serializer}; pub fn serialize(v: &Vec, s: S) -> Result { let binstr = String::from_iter(v.iter().map(|b| *b as char)); @@ -184,8 +182,7 @@ pub mod bin_string { } pub mod hex_string { - use serde::{Deserialize, Serialize}; - use serde::{Deserializer, Serializer}; + use serde::{Deserialize, Deserializer, Serialize, Serializer}; pub fn serialize(v: &Vec, s: S) -> Result { let hexstr = String::from_iter(v.iter().map(|b| format!("{:02x}", b))); diff --git a/aici_abi/src/recognizer.rs b/aici_abi/src/recognizer.rs index 5c26b3ea..eefa9bfc 100644 --- a/aici_abi/src/recognizer.rs +++ b/aici_abi/src/recognizer.rs @@ -1,9 +1,8 @@ -use std::fmt::Debug; - use crate::{ toktree::{Recognizer, SpecialToken, TokTrie}, AiciVm, MidProcessArg, MidProcessResult, PostProcessArg, PostProcessResult, }; +use std::fmt::Debug; pub struct AiciRecognizer { pub trie: TokTrie, diff --git a/aici_abi/src/svob.rs b/aici_abi/src/svob.rs index b3809f7c..6863b0ac 100644 --- a/aici_abi/src/svob.rs +++ b/aici_abi/src/svob.rs @@ -1,6 +1,5 @@ -use std::fmt::Debug; - use crate::TokenId; +use std::fmt::Debug; #[derive(Clone)] pub struct SimpleVob { From 912904611874f039f696f846c978601a8f683bb9 Mon Sep 17 00:00:00 2001 From: Michal Moskal Date: Tue, 2 Jan 2024 13:44:57 +0000 Subject: [PATCH 116/301] moving to workspace --- aici_abi/Cargo.lock | 89 --------------------------------------------- 1 file changed, 89 deletions(-) delete mode 100644 aici_abi/Cargo.lock diff --git a/aici_abi/Cargo.lock b/aici_abi/Cargo.lock deleted file mode 100644 index acd53088..00000000 --- a/aici_abi/Cargo.lock +++ /dev/null @@ -1,89 +0,0 @@ -# This file is automatically @generated by Cargo. -# It is not intended for manual editing. 
-version = 3 - -[[package]] -name = "aici_abi" -version = "0.1.0" -dependencies = [ - "serde", - "serde_json", -] - -[[package]] -name = "itoa" -version = "1.0.9" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "af150ab688ff2122fcef229be89cb50dd66af9e01a4ff320cc137eecc9bacc38" - -[[package]] -name = "proc-macro2" -version = "1.0.69" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "134c189feb4956b20f6f547d2cf727d4c0fe06722b20a0eec87ed445a97f92da" -dependencies = [ - "unicode-ident", -] - -[[package]] -name = "quote" -version = "1.0.33" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5267fca4496028628a95160fc423a33e8b2e6af8a5302579e322e4b520293cae" -dependencies = [ - "proc-macro2", -] - -[[package]] -name = "ryu" -version = "1.0.15" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1ad4cc8da4ef723ed60bced201181d83791ad433213d8c24efffda1eec85d741" - -[[package]] -name = "serde" -version = "1.0.192" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bca2a08484b285dcb282d0f67b26cadc0df8b19f8c12502c13d966bf9482f001" -dependencies = [ - "serde_derive", -] - -[[package]] -name = "serde_derive" -version = "1.0.192" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d6c7207fbec9faa48073f3e3074cbe553af6ea512d7c21ba46e434e70ea9fbc1" -dependencies = [ - "proc-macro2", - "quote", - "syn", -] - -[[package]] -name = "serde_json" -version = "1.0.108" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3d1c7e3eac408d115102c4c24ad393e0821bb3a5df4d506a80f85f7a742a526b" -dependencies = [ - "itoa", - "ryu", - "serde", -] - -[[package]] -name = "syn" -version = "2.0.39" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "23e78b90f2fcf45d3e842032ce32e3f2d1545ba6636271dcbf24fa306d87be7a" -dependencies = [ - "proc-macro2", - "quote", - "unicode-ident", -] - -[[package]] -name = "unicode-ident" -version = "1.0.12" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3354b9ac3fae1ff6755cb6db53683adb661634f67557942dea4facebec0fee4b" From b461b59714a823b003cb1819c8d62c3722bc0e30 Mon Sep 17 00:00:00 2001 From: Michal Moskal Date: Wed, 3 Jan 2024 11:29:35 +0000 Subject: [PATCH 117/301] clean up python exception output --- aici_abi/src/host.rs | 10 ++++++++++ aici_abi/src/lib.rs | 7 ++++--- 2 files changed, 14 insertions(+), 3 deletions(-) diff --git a/aici_abi/src/host.rs b/aici_abi/src/host.rs index 5d8b00b2..34f1be13 100644 --- a/aici_abi/src/host.rs +++ b/aici_abi/src/host.rs @@ -43,6 +43,10 @@ extern "C" { // This can be also obtained from the TokTrie. fn aici_host_eos_token() -> TokenId; + + // Stop the program - any error info is assumed to have been printed already. + // Backtraces will be limited. + fn aici_host_stop(); } // TODO: add @@ -323,3 +327,9 @@ pub fn self_seq_id() -> SeqId { pub fn eos_token() -> TokenId { unsafe { aici_host_eos_token() } } + +/// Stop the program - any error info is assumed to have been printed already. +pub fn aici_stop() -> ! 
{ + unsafe { aici_host_stop() }; + panic!("didn't stop"); +} diff --git a/aici_abi/src/lib.rs b/aici_abi/src/lib.rs index 89f3ad6f..0edac727 100644 --- a/aici_abi/src/lib.rs +++ b/aici_abi/src/lib.rs @@ -11,7 +11,7 @@ pub mod toktree; pub type TokenId = bytes::TokenId; pub use host::{ - _print, arg_bytes, return_logit_bias, self_seq_id, stdout, tokenize, tokenize_bytes, + _print, aici_stop, arg_bytes, return_logit_bias, self_seq_id, stdout, tokenize, tokenize_bytes, StorageCmd, StorageOp, StorageResp, VariableStorage, }; @@ -174,7 +174,8 @@ pub trait AiciVm { } fn aici_mid_process(&mut self) { - let arg: MidProcessArg = serde_json::from_slice(&host::process_arg_bytes()).unwrap(); + let arg: MidProcessArg = serde_json::from_slice(&host::process_arg_bytes()) + .expect("aici_mid_process: failed to deserialize MidProcessArg"); let res = self.mid_process(arg); match &res { MidProcessResult::SampleWithBias { allowed_tokens } => { @@ -184,7 +185,7 @@ pub trait AiciVm { } _ => {} } - let res_bytes = serde_json::to_vec(&res).unwrap(); + let res_bytes = serde_json::to_vec(&res).expect("aici_mid_process: failed to serialize"); host::return_process_result(&res_bytes); } From eae1a63ae006606bea581ac2bf566ae607b4a6ac Mon Sep 17 00:00:00 2001 From: Michal Moskal Date: Thu, 4 Jan 2024 15:07:16 +0000 Subject: [PATCH 118/301] move constraints to aici_abi --- aici_abi/Cargo.toml | 13 + aici_abi/src/cfg.rs | 574 ++++++++++++++++++++++++++++++++++++++ aici_abi/src/lex.rs | 349 +++++++++++++++++++++++ aici_abi/src/lib.rs | 10 + aici_abi/src/rx.rs | 66 +++++ aici_abi/src/substring.rs | 274 ++++++++++++++++++ 6 files changed, 1286 insertions(+) create mode 100644 aici_abi/src/cfg.rs create mode 100644 aici_abi/src/lex.rs create mode 100644 aici_abi/src/rx.rs create mode 100644 aici_abi/src/substring.rs diff --git a/aici_abi/Cargo.toml b/aici_abi/Cargo.toml index 1b1787a8..21e10f93 100644 --- a/aici_abi/Cargo.toml +++ b/aici_abi/Cargo.toml @@ -9,3 +9,16 @@ name = "aici_abi" [dependencies] serde = { version = "1.0.192", features = ["derive"] } serde_json = "1.0.108" +anyhow = "1.0.75" +regex-automata = { version = "0.4.3", default-features = false, features = ["std", "dfa", "syntax", "perf", "meta"], optional = true } +cfgrammar = { version = "0.13.3", optional = true } +lrlex = { version = "0.13.3", optional = true } +lrpar = { version = "0.13.3", optional = true } +lrtable = { version = "0.13.3", optional = true } +vob = { version = "3.0.3", optional = true } +rustc-hash = { version = "1.1.0", optional = true } + +[features] +default = ["cfg", "rx"] +cfg = ["dep:cfgrammar", "dep:lrlex", "dep:lrpar", "dep:lrtable", "dep:vob", "dep:rustc-hash"] +rx = ["dep:regex-automata"] diff --git a/aici_abi/src/cfg.rs b/aici_abi/src/cfg.rs new file mode 100644 index 00000000..f05e32eb --- /dev/null +++ b/aici_abi/src/cfg.rs @@ -0,0 +1,574 @@ +use crate::lex::{Lexer, LexerState, StateID, VobIdx, VobSet}; +use crate::{ + svob::SimpleVob, + toktree::{Recognizer, SpecialToken, TokTrie}, + wprint, wprintln, +}; +use anyhow::Result; +use cfgrammar::{ + yacc::{YaccGrammar, YaccKind}, + Span, Spanned, Symbol, TIdx, +}; +use lrtable::{from_yacc, Action, Minimiser, StIdx, StateTable}; +use rustc_hash::FxHashMap; +use std::{cell::RefCell, vec}; +use vob::{vob, Vob}; + +type StorageT = u32; +type PStack = Vec>; // Parse stack + +const LOG_PARSER: bool = false; + +#[derive(Debug, Clone, Copy)] +enum ParseResult { + Accept, + Error, + Continue, +} + +struct CfgStats { + yacc_actions: usize, + states_pushed: usize, +} + +pub struct CfgParser 
{
+    grm: YaccGrammar<StorageT>,
+    stable: StateTable<StorageT>,
+    lexer: Lexer,
+    byte_states: Vec<ByteState>,
+    pat_idx_to_tidx: Vec<TIdx<StorageT>>,
+    vobset: VobSet,
+    stats: RefCell<CfgStats>,
+    tidx_to_pat_idx: FxHashMap<TIdx<StorageT>, usize>,
+    parse_stacks: Vec<Vec<StIdx<StorageT>>>,
+    skip_patterns: Vob,
+    friendly_pattern_names: Vec<String>,
+    viable_vobidx_by_state: Vec<VobIdx>,
+}
+
+fn is_rx(name: &str) -> bool {
+    name.len() > 2 && name.starts_with("/") && name.ends_with("/")
+}
+
+fn quote_rx(name: &str) -> String {
+    name.chars()
+        .map(|ch| {
+            if ('0' <= ch && ch <= '9')
+                || ('a' <= ch && ch <= 'z')
+                || ('A' <= ch && ch <= 'Z')
+                || '<' == ch
+                || '>' == ch
+            {
+                ch.to_string()
+            } else {
+                format!("\\{}", ch)
+            }
+        })
+        .collect::<String>()
+}
+
+impl CfgParser {
+    fn span_to_str(s: &Span, src: &str) -> String {
+        let mut line = 1;
+        let mut last_nl = 0;
+        for (idx, ch) in src.chars().enumerate() {
+            if idx == s.start() {
+                break;
+            }
+            if ch == '\n' {
+                line += 1;
+                last_nl = idx;
+            }
+        }
+        let column = s.start() - last_nl;
+        format!("({},{})", line, column)
+    }
+
+    pub fn from_yacc(yacc: &str) -> Result<Self> {
+        let grmkind = YaccKind::Original(cfgrammar::yacc::YaccOriginalActionKind::NoAction);
+        let grm = match YaccGrammar::new(grmkind, yacc) {
+            Ok(grm) => grm,
+            Err(e) => {
+                let err_str = e
+                    .iter()
+                    .map(|e| {
+                        let spans = e
+                            .spans()
+                            .iter()
+                            .map(|s| Self::span_to_str(s, yacc))
+                            .collect::<Vec<_>>()
+                            .join(", ");
+                        format!("{}: {}", spans, e)
+                    })
+                    .collect::<Vec<_>>()
+                    .join("\n");
+                anyhow::bail!("yacc grammar errors:\n{}", err_str);
+            }
+        };
+
+        // TIME: all these annotations are for native release x86 build for C grammar
+        // TIME: 27ms
+        let (sgraph, stable) = match from_yacc(&grm, Minimiser::Pager) {
+            Ok(r) => r,
+            Err(e) => {
+                // not sure this works:
+                // anyhow::bail!("state table error:\n{e} on {:?}", grm.action(e.pidx));
+                anyhow::bail!("state table error:\n{e}");
+            }
+        };
+
+        if false {
+            wprintln!("core\n{}\n\n", sgraph.pp(&grm, true));
+            for pidx in grm.iter_pidxs() {
+                let prod = grm.prod(pidx);
+                wprintln!("{:?} -> {}", prod, prod.len());
+            }
+        }
+
+        let mut pat_idx_to_tidx = grm
+            .iter_tidxs()
+            .filter(|tidx| grm.token_name(*tidx).is_some())
+            .collect::<Vec<_>>();
+
+        pat_idx_to_tidx.sort_by_key(|tidx| {
+            let name = grm.token_name(*tidx).unwrap();
+            let l = name.len() as isize;
+            if is_rx(name) {
+                -l + 100000
+            } else {
+                -l
+            }
+        });
+
+        let patterns = pat_idx_to_tidx
+            .iter()
+            .map(|tok| {
+                let name = grm.token_name(*tok).unwrap();
+                if is_rx(name) {
+                    name[1..name.len() - 1].to_string()
+                } else {
+                    quote_rx(name)
+                }
+            })
+            .collect::<Vec<_>>();
+
+        let mut tidx_to_pat_idx = FxHashMap::default();
+        for (idx, _tok) in patterns.iter().enumerate() {
+            tidx_to_pat_idx.insert(pat_idx_to_tidx[idx], idx);
+        }
+
+        let mut skip_patterns = vob![false; patterns.len()];
+        let mut friendly_pattern_names = pat_idx_to_tidx
+            .iter()
+            .map(|tok| grm.token_name(*tok).unwrap().to_string())
+            .collect::<Vec<_>>();
+
+        for ridx in grm.iter_rules() {
+            let rname = grm.rule_name_str(ridx);
+            if rname.to_uppercase() != rname {
+                continue;
+            }
+            for pidx in grm.rule_to_prods(ridx) {
+                let toks = grm.prod(*pidx);
+                if let [Symbol::Token(tidx)] = toks {
+                    let idx = *tidx_to_pat_idx.get(&tidx).unwrap();
+                    friendly_pattern_names[idx] = rname.to_string();
+                    if rname == "SKIP" {
+                        skip_patterns.set(idx, true);
+                    }
+                }
+            }
+        }
+
+        wprintln!("patterns: {:?}", friendly_pattern_names);
+
+        let mut vobset = VobSet::new();
+        // all-zero has to be inserted first
+        let _all0 = vobset.get(&vob![false; patterns.len()]);
+        let all1 = vobset.get(&vob![true; patterns.len()]);
+
+        // TIME: 27ms
+        let dfa =
Lexer::from(patterns, &mut vobset); + + let parse_stacks = vec![vec![stable.start_state()]]; + + let byte_state = ByteState { + lexer_state: dfa.file_start_state(), + parse_stack_idx: PStackIdx(0), + viable: all1, + }; + + let viable_vobidx_by_state = sgraph + .iter_stidxs() + .enumerate() + .map(|(idx, stidx)| { + assert!(idx == stidx.as_storaget() as usize); + + // skip patterns (whitespace) are always viable + let mut r = skip_patterns.clone(); + for tidx in stable.state_actions(stidx) { + match stable.action(stidx, tidx) { + Action::Error => {} + _ => { + if let Some(pat_idx) = tidx_to_pat_idx.get(&tidx) { + r.set(*pat_idx, true); + } + } + } + } + + vobset.get(&r) + }) + .collect::>(); + + let mut cfg = CfgParser { + grm, + stable, + lexer: dfa, + byte_states: vec![byte_state], + pat_idx_to_tidx, + tidx_to_pat_idx, + viable_vobidx_by_state, + skip_patterns, + friendly_pattern_names, + parse_stacks, + vobset, + stats: RefCell::new(CfgStats { + yacc_actions: 0, + states_pushed: 0, + }), + }; + + cfg.vobset.pre_compute(); + + Ok(cfg) + } + + fn viable_vobidx(&self, stidx: StIdx) -> VobIdx { + self.viable_vobidx_by_state[stidx.as_storaget() as usize] + } + + #[allow(dead_code)] + fn friendly_token_name(&self, lexeme: TIdx) -> &str { + if let Some(pidx) = self.tidx_to_pat_idx.get(&lexeme) { + &self.friendly_pattern_names[*pidx] + } else if self.grm.eof_token_idx() == lexeme { + return ""; + } else { + return ""; + } + } + + fn parse_lexeme(&self, lexeme: TIdx, pstack: &mut PStack) -> ParseResult { + loop { + let stidx = *pstack.last().unwrap(); + + let act = self.stable.action(stidx, lexeme); + + if LOG_PARSER { + wprintln!( + "parse: {:?} {:?} -> {:?}", + pstack, + self.friendly_token_name(lexeme), + act + ); + } + + match act { + Action::Reduce(pidx) => { + let ridx = self.grm.prod_to_rule(pidx); + let pop_idx = pstack.len() - self.grm.prod(pidx).len(); + pstack.drain(pop_idx..); + let prior = *pstack.last().unwrap(); + pstack.push(self.stable.goto(prior, ridx).unwrap()); + } + Action::Shift(state_id) => { + pstack.push(state_id); + return ParseResult::Continue; + } + Action::Accept => { + // only happens when lexeme is EOF + return ParseResult::Accept; + } + Action::Error => { + return ParseResult::Error; + } + } + } + } + + #[allow(dead_code)] + fn print_viable(&self, lbl: &str, vob: &Vob) { + wprintln!("viable tokens {}:", lbl); + for (idx, b) in vob.iter().enumerate() { + if b { + wprintln!(" {}: {}", idx, self.friendly_pattern_names[idx]); + } + } + } + + // None means EOF + #[inline(always)] + fn try_push(&mut self, byte: Option) -> Option { + let top = self.byte_states.last().unwrap().clone(); + if LOG_PARSER { + wprint!("try_push: "); + if let Some(b) = byte { + wprint!("{:?}", b as char) + } else { + wprint!("") + } + } + let (info, res) = match self.lexer.advance(top.lexer_state, byte) { + // Error? 
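+            // (advance() returns None when the lexer DFA is dead and no token can be emitted)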
+ None => ("lex-err", None), + // Just new state, no token - the hot path + Some((ls, None)) => ( + "lex", + self.mk_byte_state(ls, top.parse_stack_idx, top.viable), + ), + // New state and token generated + Some((ls, Some(pat_idx))) => ("parse", self.run_parser(pat_idx, &top, ls)), + }; + if LOG_PARSER { + wprintln!( + " -> {} {}", + info, + if res.is_none() { "error" } else { "ok" } + ); + } + res + } + + fn pstack_for(&self, top: &ByteState) -> &PStack { + &self.parse_stacks[top.parse_stack_idx.0] + } + + fn push_pstack(&mut self, top: &ByteState, pstack: Vec>) -> PStackIdx { + let new_idx = PStackIdx(top.parse_stack_idx.0 + 1); + if self.parse_stacks.len() <= new_idx.0 { + self.parse_stacks.push(Vec::new()); + } + self.parse_stacks[new_idx.0] = pstack; + new_idx + } + + fn run_parser(&mut self, pat_idx: usize, top: &ByteState, ls: LexerState) -> Option { + { + let mut s = self.stats.borrow_mut(); + s.yacc_actions += 1; + } + if LOG_PARSER { + wprintln!(); + } + let pstack = self.pstack_for(top); + if self.skip_patterns[pat_idx] { + let stidx = *pstack.last().unwrap(); + let viable = self.viable_vobidx(stidx); + //self.print_viable("reset", &viable); + if LOG_PARSER { + wprintln!("parse: {:?} skip", pstack); + } + // reset viable states - they have been narrowed down to SKIP + self.mk_byte_state(ls, top.parse_stack_idx, viable) + } else { + let tidx = self.pat_idx_to_tidx[pat_idx]; + let mut pstack = pstack.clone(); + match self.parse_lexeme(tidx, &mut pstack) { + ParseResult::Accept => panic!("accept non EOF?"), + ParseResult::Continue => { + let stidx = *pstack.last().unwrap(); + let viable = self.viable_vobidx(stidx); + let new_idx = self.push_pstack(top, pstack); + self.mk_byte_state(ls, new_idx, viable) + } + ParseResult::Error => None, + } + } + } + + #[allow(dead_code)] + pub fn viable_now(&self) { + let v = self.byte_states.last().unwrap().viable; + self.print_viable("now", self.vobset.resolve(v)) + } + + pub fn get_stats(&self) -> String { + let mut s = self.stats.borrow_mut(); + let r = format!("yacc: {}/{}", s.yacc_actions, s.states_pushed); + s.yacc_actions = 0; + s.states_pushed = 0; + r + } + + fn mk_byte_state( + &self, + ls: LexerState, + pstack: PStackIdx, + viable: VobIdx, + ) -> Option { + { + let mut s = self.stats.borrow_mut(); + s.states_pushed += 1; + } + if self.vobset.and_is_zero(viable, ls.reachable) { + None + } else { + Some(ByteState { + lexer_state: ls.state, + parse_stack_idx: pstack, + viable, + }) + } + } +} + +#[derive(Clone, Copy)] +struct PStackIdx(usize); + +#[derive(Clone)] +struct ByteState { + lexer_state: StateID, + parse_stack_idx: PStackIdx, + viable: VobIdx, +} + +impl Recognizer for CfgParser { + fn pop_bytes(&mut self, num: usize) { + self.byte_states.truncate(self.byte_states.len() - num); + } + + fn collapse(&mut self) { + let final_state = self.byte_states.pop().unwrap(); + self.byte_states.clear(); + self.byte_states.push(final_state); + } + + fn special_allowed(&mut self, tok: SpecialToken) -> bool { + match tok { + SpecialToken::EndOfSentence => { + if let Some(st) = self.try_push(None) { + let tidx = self.grm.eof_token_idx(); + let mut pstack = self.pstack_for(&st).clone(); + match self.parse_lexeme(tidx, &mut pstack) { + ParseResult::Accept => true, + _ => false, + } + } else { + false + } + } + _ => false, + } + } + + fn trie_finished(&mut self) { + assert!(self.byte_states.len() == 1); + } + + #[inline(always)] + fn try_push_byte(&mut self, byte: u8) -> bool { + if let Some(st) = self.try_push(Some(byte)) { + 
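+            // remember the state so pop_bytes() can backtrack over this byte later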
self.byte_states.push(st); + true + } else { + false + } + } +} + +#[allow(dead_code)] +pub fn cfg_test() -> Result<()> { + let yacc_bytes = include_bytes!("../../grammars/c.y"); + let mut cfg = CfgParser::from_yacc(&String::from_utf8_lossy(yacc_bytes)).unwrap(); + let sample = include_bytes!("../../grammars/sample.c"); + + if true { + let trie = TokTrie::from_host(); + let toks = trie.greedy_tokenize(sample); + + #[cfg(not(target_arch = "wasm32"))] + let t0 = std::time::Instant::now(); + + let mut line = 1; + let mut vob = SimpleVob::new(); + vob.resize(trie.vocab_size() + 1); + + for tok in &toks[0..1000] { + let tok = *tok; + trie.compute_bias(&mut cfg, &mut vob); + if !vob.is_allowed(tok) { + wprintln!("reject, line={}, tok={:?}", line, trie.token_str(tok)); + panic!(); + } + for b in trie.token(tok) { + if *b == b'\n' { + line += 1; + } + } + if false { + wprintln!( + "tok: {:?} {}; {}", + trie.token_str(tok), + vob.is_allowed(tok), + cfg.get_stats() + ); + cfg.viable_now(); + } + trie.append_token(&mut cfg, tok); + } + + #[cfg(not(target_arch = "wasm32"))] + wprintln!("time: {:?} ", t0.elapsed()); + + wprintln!("stats: {}", cfg.get_stats()); + } + + if false { + let mut rng = crate::rng::Rng::new(0); + let mut ok = true; + let mut idx = 0; + while idx < sample.len() { + let b = sample[idx]; + // wprintln!("idx {} {:?}", idx, b as char); + let r = cfg.try_push_byte(b); + if !r { + ok = false; + wprintln!( + "reject at\n{:?}\n{:?}", + String::from_utf8_lossy(&sample[idx.saturating_sub(50)..idx]), + String::from_utf8_lossy(&sample[idx..std::cmp::min(idx + 30, sample.len())]) + ); + break; + } + idx += 1; + + if false { + let max_pop = cfg.byte_states.len() - 1; + if max_pop > 0 && rng.gen_up_to(4) == 0 { + let num = rng.gen_up_to(max_pop - 1) + 1; + // wprintln!("pop {} {}", num, cfg.byte_states.len()); + cfg.pop_bytes(num); + idx -= num; + } + + if rng.gen_up_to(10) == 0 { + // wprintln!("collapse"); + cfg.collapse(); + } + } + } + + if ok { + if cfg.special_allowed(SpecialToken::EndOfSentence) { + wprintln!("accept EOS"); + } else { + wprintln!("reject EOS"); + } + } else { + wprintln!("reject"); + } + } + + Ok(()) +} diff --git a/aici_abi/src/lex.rs b/aici_abi/src/lex.rs new file mode 100644 index 00000000..25d96098 --- /dev/null +++ b/aici_abi/src/lex.rs @@ -0,0 +1,349 @@ +use crate::wprintln; +use regex_automata::{ + dfa::{dense, Automaton}, + util::syntax, +}; +use rustc_hash::FxHashMap; +use std::{hash::Hash, vec}; +use vob::{vob, Vob}; + +pub type PatIdx = usize; +pub type StateID = regex_automata::util::primitives::StateID; + +const LOG_LEXER: bool = false; + +// enabling this is slightly faster, but it requires ~ |lexer_states|*|parser_states| bits +const PRECOMPUTE_AND: bool = false; + +#[derive(Debug, Clone, Hash, PartialEq, Eq, Copy)] +pub struct LexerState { + pub state: StateID, + pub reachable: VobIdx, +} + +impl LexerState { + fn fake() -> Self { + LexerState { + state: StateID::default(), + reachable: VobIdx::all_zero(), + } + } +} + +#[derive(Debug, Clone, Hash, PartialEq, Eq, Copy)] +pub struct VobIdx { + v: u32, +} + +impl VobIdx { + pub fn new(v: usize) -> Self { + VobIdx { v: v as u32 } + } + + pub fn all_zero() -> Self { + VobIdx { v: 0 } + } + + pub fn as_usize(&self) -> usize { + self.v as usize + } + + pub fn is_zero(&self) -> bool { + self.v == 0 + } +} + +pub struct VobSet { + vobs: Vec, + by_vob: FxHashMap, + non_empty: Vob, +} + +impl VobSet { + pub fn new() -> Self { + VobSet { + vobs: Vec::new(), + by_vob: FxHashMap::default(), + non_empty: Vob::new(), + 
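+            // non_empty stays empty unless pre_compute() fills it (PRECOMPUTE_AND)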
} + } + + pub fn get(&mut self, vob: &Vob) -> VobIdx { + if let Some(idx) = self.by_vob.get(vob) { + return *idx; + } + let len = self.vobs.len(); + if len == 0 && !vob_is_zero(vob) { + panic!("first vob must be empty"); + } + let idx = VobIdx::new(len); + self.vobs.push(vob.clone()); + self.by_vob.insert(vob.clone(), idx); + idx + } + + pub fn resolve(&self, idx: VobIdx) -> &Vob { + &self.vobs[idx.as_usize()] + } + + pub fn and_is_zero(&self, a: VobIdx, b: VobIdx) -> bool { + if PRECOMPUTE_AND { + !self.non_empty[a.as_usize() * self.vobs.len() + b.as_usize()] + } else { + vob_and_is_zero(&self.vobs[a.as_usize()], &self.vobs[b.as_usize()]) + } + } + + pub fn pre_compute(&mut self) { + if PRECOMPUTE_AND { + let l = self.vobs.len(); + self.non_empty.resize(l * l, false); + for x in 0..self.vobs.len() { + for y in 0..=x { + if !vob_and_is_zero(&self.vobs[x], &self.vobs[y]) { + self.non_empty.set(x * l + y, true); + self.non_empty.set(y * l + x, true); + } + } + } + wprintln!( + "vob set: {} VOBs, {} nonempty", + self.vobs.len(), + self.non_empty.len() + ); + } + } +} + +pub struct Lexer { + dfa: dense::DFA>, + initial: LexerState, + vobidx_by_state_off: Vec, +} + +impl Lexer { + pub fn from(patterns: Vec, vobset: &mut VobSet) -> Self { + // TIME: 4ms + let dfa = dense::Builder::new() + .configure( + dense::Config::new() + .start_kind(regex_automata::dfa::StartKind::Anchored) + .match_kind(regex_automata::MatchKind::All), + ) + .syntax(syntax::Config::new().unicode(false).utf8(false)) + .build_many(&patterns) + .unwrap(); + + wprintln!( + "dfa: {} bytes, {} patterns", + dfa.memory_usage(), + patterns.len(), + ); + if false { + for p in &patterns { + wprintln!(" {}", p) + } + } + + let anch = regex_automata::Anchored::Yes; + + let mut incoming = FxHashMap::default(); + let initial = dfa.universal_start_state(anch).unwrap(); + let mut todo = vec![initial]; + incoming.insert(initial, Vec::new()); + + // TIME: 1.5ms + while todo.len() > 0 { + let s = todo.pop().unwrap(); + for b in 0..=255 { + let s2 = dfa.next_state(s, b); + if !incoming.contains_key(&s2) { + todo.push(s2); + incoming.insert(s2, Vec::new()); + } + incoming.get_mut(&s2).unwrap().push(s); + } + } + + let states = incoming.keys().map(|x| *x).collect::>(); + let mut reachable_patterns = FxHashMap::default(); + + for s in &states { + let mut v = vob![false; patterns.len()]; + let s2 = dfa.next_eoi_state(*s); + if dfa.is_match_state(s2) { + for idx in 0..dfa.match_len(s2) { + let idx = dfa.match_pattern(s2, idx).as_usize(); + v.set(idx, true); + if LOG_LEXER { + wprintln!(" match: {:?} {}", *s, patterns[idx]) + } + } + } + reachable_patterns.insert(*s, v); + } + + // TIME: 20ms + loop { + let mut num_set = 0; + + for s in &states { + let ours = reachable_patterns.get(s).unwrap().clone(); + for o in &incoming[s] { + let theirs = reachable_patterns.get(o).unwrap(); + let mut tmp = ours.clone(); + tmp |= theirs; + if tmp != *theirs { + num_set += 1; + reachable_patterns.insert(*o, tmp); + } + } + } + + if LOG_LEXER { + wprintln!("iter {} {}", num_set, states.len()); + } + if num_set == 0 { + break; + } + } + + let mut states_idx = states.iter().map(|x| x.as_usize()).collect::>(); + states_idx.sort(); + + let shift = dfa.stride2(); + let mut vobidx_by_state_off = + vec![VobIdx::all_zero(); 1 + (states_idx.iter().max().unwrap() >> shift)]; + for (k, v) in reachable_patterns.iter() { + vobidx_by_state_off[k.as_usize() >> shift] = vobset.get(v); + } + + wprintln!("initial: {:?}; {} states", initial, states.len()); + + let mut lex = Lexer { + 
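+            // `initial` is a placeholder; it is replaced via mk_state() just below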
dfa, + vobidx_by_state_off, + initial: LexerState::fake(), + }; + + lex.initial = lex.mk_state(initial); + + if LOG_LEXER { + for s in &states { + if lex.is_dead(*s) { + wprintln!("dead: {:?} {}", s, lex.dfa.is_dead_state(*s)); + } + } + + wprintln!("reachable: {:#?}", reachable_patterns); + } + + lex + } + + pub fn file_start_state(&self) -> StateID { + // pretend we've just seen a newline at the beginning of the file + // TODO: this should be configurable + self.dfa.next_state(self.initial.state, b'\n') + } + + fn mk_state(&self, state: StateID) -> LexerState { + LexerState { + state, + reachable: self.reachable_tokens(state), + } + } + + fn is_dead(&self, state: StateID) -> bool { + self.reachable_tokens(state).is_zero() + } + + fn reachable_tokens(&self, state: StateID) -> VobIdx { + self.vobidx_by_state_off[state.as_usize() >> self.dfa.stride2()] + } + + fn get_token(&self, prev: StateID) -> Option { + let state = self.dfa.next_eoi_state(prev); + if !self.dfa.is_match_state(state) { + return None; + } + + // we take the first token that matched + // (eg., "while" will match both keyword and identifier, but keyword is first) + let pat_idx = (0..self.dfa.match_len(state)) + .map(|idx| self.dfa.match_pattern(state, idx).as_usize()) + .min() + .unwrap(); + + if LOG_LEXER { + wprintln!("token: {}", pat_idx); + } + + Some(pat_idx) + } + + #[inline(always)] + pub fn advance(&self, prev: StateID, byte: Option) -> Option<(LexerState, Option)> { + let dfa = &self.dfa; + if let Some(byte) = byte { + let state = dfa.next_state(prev, byte); + if LOG_LEXER { + wprintln!( + "lex: {:?} -{:?}-> {:?} d={}", + prev, + byte as char, + state, + self.is_dead(state), + ); + } + let v = self.reachable_tokens(state); + if v.is_zero() { + // if final_state is a match state, find the token that matched + let tok = self.get_token(prev); + if tok.is_none() { + None + } else { + let state = dfa.next_state(self.initial.state, byte); + if LOG_LEXER { + wprintln!("lex0: {:?} -{:?}-> {:?}", self.initial, byte as char, state); + } + Some((self.mk_state(state), tok)) + } + } else { + Some(( + LexerState { + state, + reachable: v, + }, + None, + )) + } + } else { + let tok = self.get_token(prev); + if tok.is_none() { + None + } else { + Some((self.initial, tok)) + } + } + } +} + +fn vob_and_is_zero(a: &Vob, b: &Vob) -> bool { + debug_assert!(a.len() == b.len()); + for (a, b) in a.iter_storage().zip(b.iter_storage()) { + if a & b != 0 { + return false; + } + } + return true; +} + +fn vob_is_zero(v: &Vob) -> bool { + for b in v.iter_storage() { + if b != 0 { + return false; + } + } + true +} diff --git a/aici_abi/src/lib.rs b/aici_abi/src/lib.rs index 0edac727..f298d106 100644 --- a/aici_abi/src/lib.rs +++ b/aici_abi/src/lib.rs @@ -8,6 +8,16 @@ pub mod rng; pub mod svob; pub mod toktree; +#[cfg(feature = "cfg")] +pub mod cfg; +#[cfg(feature = "cfg")] +mod lex; + +#[cfg(feature = "rx")] +pub mod rx; + +pub mod substring; + pub type TokenId = bytes::TokenId; pub use host::{ diff --git a/aici_abi/src/rx.rs b/aici_abi/src/rx.rs new file mode 100644 index 00000000..e9a517b7 --- /dev/null +++ b/aici_abi/src/rx.rs @@ -0,0 +1,66 @@ +use crate::{ + recognizer::{FunctionalRecognizer, StackRecognizer}, + toktree::SpecialToken, + wprintln, +}; +use regex_automata::{ + dfa::{dense, Automaton}, + util::{primitives::StateID, syntax}, +}; + +pub type RecRxState = StateID; + +#[derive(Clone)] +pub struct RecRx { + dfa: dense::DFA>, +} + +pub type RxStackRecognizer = StackRecognizer; + +impl RecRx { + pub fn from_rx(rx: &str) -> Self { + 
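+        // anchor the pattern at the end, so only complete matches of `rx` are accepted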
let rx = if rx.ends_with("$") { + rx.to_string() + } else { + rx.to_string() + "$" + }; + let dfa = dense::Builder::new() + .configure(dense::Config::new().start_kind(regex_automata::dfa::StartKind::Anchored)) + .syntax(syntax::Config::new().unicode(false).utf8(false)) + .build(&rx) + .unwrap(); + wprintln!("dfa: {} bytes", dfa.memory_usage()); + Self { dfa } + } + + pub fn to_stack_recognizer(self) -> RxStackRecognizer { + StackRecognizer::from(self) + } +} + +impl FunctionalRecognizer for RecRx { + fn initial(&self) -> RecRxState { + self.dfa + .universal_start_state(regex_automata::Anchored::Yes) + .expect("dfa has no universal start state; make sure it doesn't match empty string") + } + + #[inline(always)] + fn append(&self, state: RecRxState, byte: u8) -> RecRxState { + self.dfa.next_state(state, byte) + } + + #[inline(always)] + fn byte_allowed(&self, state: RecRxState, byte: u8) -> bool { + !self.dfa.is_dead_state(self.dfa.next_state(state, byte)) + } + + #[inline(always)] + fn special_allowed(&self, state: RecRxState, tok: SpecialToken) -> bool { + let state = self.dfa.next_eoi_state(state); + match tok { + SpecialToken::EndOfSentence => self.dfa.is_match_state(state), + _ => false, + } + } +} diff --git a/aici_abi/src/substring.rs b/aici_abi/src/substring.rs new file mode 100644 index 00000000..d5d262f3 --- /dev/null +++ b/aici_abi/src/substring.rs @@ -0,0 +1,274 @@ +use std::fmt::Display; + +use crate::{ + bytes::limit_str, + recognizer::{FunctionalRecognizer, StackRecognizer}, + toktree::SpecialToken, +}; +use serde_json::json; + +enum Node { + Inner { children: Vec<(u8, usize)> }, + Leaf { source_offset: usize }, +} + +pub struct SubStrMatcher { + end_str: String, + source: String, + nodes: Vec, +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum SubStrState { + Dead, + Node(usize), + SourceOffset(usize), + EndStrOffset(usize), +} + +pub type SubStrStackRecognizer = StackRecognizer; + +fn add_node(nodes: &mut Vec, n: Node) -> usize { + let idx = nodes.len(); + nodes.push(n); + idx +} + +impl Display for SubStrMatcher { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + self.pp(f, 0, 0) + } +} + +impl SubStrMatcher { + #[allow(dead_code)] + fn to_json(&self, node_idx: usize) -> serde_json::Value { + match &self.nodes[node_idx] { + Node::Inner { children } => { + let mut children_json = serde_json::Map::new(); + for (c, idx) in children.iter() { + children_json.insert(format!("{}", *c as char), self.to_json(*idx)); + } + serde_json::Value::Object(children_json) + } + Node::Leaf { source_offset } => { + json!(limit_str(&self.source[*source_offset..], 20)) + } + } + } + + fn pp( + &self, + f: &mut std::fmt::Formatter<'_>, + indent: usize, + node_idx: usize, + ) -> std::fmt::Result { + let node = &self.nodes[node_idx]; + match node { + Node::Inner { children } => { + for (c, idx) in children.iter() { + writeln!(f, "{:indent$}{:?} -> {}", "", *c as char, idx)?; + self.pp(f, indent + 1, *idx)?; + } + } + Node::Leaf { source_offset } => { + writeln!( + f, + "{:indent$}{}: {:?}", + "", + *source_offset, + limit_str(&self.source[*source_offset..], 20), + )?; + } + } + Ok(()) + } + + pub fn new(source: &str, end_str: &str) -> Self { + let mut tmp = Self { + source: source.to_string() + " ", + end_str: end_str.to_string(), + nodes: vec![Node::Inner { children: vec![] }], + }; + tmp.add(0); + for i in 0..tmp.source.len() { + if tmp.source.as_bytes()[i] == b' ' { + tmp.add(i + 1); + } + } + // println!("{}", tmp); + // println!("JSON: {}", 
serde_json::to_string(&tmp.to_json(0)).unwrap()); + tmp + } + + fn find(&self, s: &str) -> (usize, usize) { + let mut node_idx = 0; + for (i, b) in s.bytes().enumerate() { + let node = &self.nodes[node_idx]; + match node { + Node::Inner { children } => { + let mut found = false; + for (c, idx) in children.iter() { + if *c == b { + node_idx = *idx; + found = true; + break; + } + } + if !found { + return (node_idx, i); + } + } + Node::Leaf { .. } => return (node_idx, i), + } + } + (node_idx, s.len()) + } + + fn add(&mut self, source_offset1: usize) { + let s1 = &self.source[source_offset1..]; + let (mut node_idx, offset) = self.find(s1); + if offset >= s1.len() { + return; + } + let source_offset1 = source_offset1 + offset; + let s1 = &self.source[source_offset1..]; + + let num_nodes = self.nodes.len(); + match &mut self.nodes[node_idx] { + Node::Inner { children } => { + children.push((s1.as_bytes()[0], num_nodes)); + let n = add_node( + &mut self.nodes, + Node::Leaf { + source_offset: source_offset1 + 1, + }, + ); + assert!(n == num_nodes); + } + Node::Leaf { source_offset } => { + let source_offset2 = *source_offset; + let s2 = &self.source[source_offset2..]; + if s2.starts_with(s1) { + return; + } + if s1.starts_with(s2) { + self.nodes[node_idx] = Node::Leaf { + source_offset: source_offset1, + }; + return; + } + + for i in 0..s1.len() { + let b1 = s1.as_bytes()[i]; + let b2 = s2.as_bytes()[i]; + if b1 != b2 { + let n1 = add_node( + &mut self.nodes, + Node::Leaf { + source_offset: source_offset1 + i + 1, + }, + ); + let n2 = add_node( + &mut self.nodes, + Node::Leaf { + source_offset: source_offset2 + i + 1, + }, + ); + self.nodes[node_idx] = Node::Inner { + children: vec![(b1, n1), (b2, n2)], + }; + return; + } else { + let n1 = add_node(&mut self.nodes, Node::Inner { children: vec![] }); + self.nodes[node_idx] = Node::Inner { + children: vec![(b1, n1)], + }; + node_idx = n1; + } + } + } + } + } + + pub fn to_stack_recognizer(self) -> SubStrStackRecognizer { + StackRecognizer::from(self) + } + + fn append_to_src_off(&self, off: usize, byte: u8) -> SubStrState { + if off < self.source.len() && self.source.as_bytes()[off] == byte { + SubStrState::SourceOffset(off + 1) + } else { + SubStrState::Dead + } + } + + fn append_inner(&self, state: SubStrState, byte: u8) -> SubStrState { + match state { + SubStrState::Dead => SubStrState::Dead, + SubStrState::EndStrOffset(off) => { + if off < self.end_str.len() && self.end_str.as_bytes()[off] == byte { + SubStrState::EndStrOffset(off + 1) + } else { + SubStrState::Dead + } + } + SubStrState::Node(state) => { + let node = &self.nodes[state]; + match node { + Node::Inner { children } => { + for (c, idx) in children.iter() { + if *c == byte { + return SubStrState::Node(*idx); + } + } + SubStrState::Dead + } + Node::Leaf { source_offset } => self.append_to_src_off(*source_offset, byte), + } + } + SubStrState::SourceOffset(off) => self.append_to_src_off(off, byte), + } + } +} + +impl FunctionalRecognizer for SubStrMatcher { + fn initial(&self) -> SubStrState { + SubStrState::Node(0) + } + + #[inline(always)] + fn append(&self, state: SubStrState, byte: u8) -> SubStrState { + let state = match state { + SubStrState::Node(_) | SubStrState::SourceOffset(_) + if self.end_str.as_bytes().first() == Some(&byte) + && self.append_inner(state, b' ') != SubStrState::Dead => + { + SubStrState::EndStrOffset(0) + } + _ => state, + }; + + self.append_inner(state, byte) + } + + #[inline(always)] + fn byte_allowed(&self, state: SubStrState, byte: u8) -> bool { + 
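+        // a byte is allowed iff the matcher survives appending it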
self.append(state, byte) != SubStrState::Dead + } + + #[inline(always)] + fn special_allowed(&self, state: SubStrState, tok: SpecialToken) -> bool { + match tok { + SpecialToken::EndOfSentence => { + let l = self.end_str.len(); + if l == 0 { + self.append_inner(state, b' ') != SubStrState::Dead + } else { + state == SubStrState::EndStrOffset(l) + } + } + _ => false, + } + } +} From 3191c29043fedbf9120a3f7a40d375f9ce31593a Mon Sep 17 00:00:00 2001 From: Michal Moskal Date: Thu, 4 Jan 2024 15:17:56 +0000 Subject: [PATCH 119/301] build everything for wasm32-wasi not wasm32-unknown-unknown --- aici_abi/.cargo/config.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/aici_abi/.cargo/config.toml b/aici_abi/.cargo/config.toml index f4e8c002..6b77899c 100644 --- a/aici_abi/.cargo/config.toml +++ b/aici_abi/.cargo/config.toml @@ -1,2 +1,2 @@ [build] -target = "wasm32-unknown-unknown" +target = "wasm32-wasi" From 4191f11c7e65a761f0dfe8d2f5367830c5c63677 Mon Sep 17 00:00:00 2001 From: Michal Moskal Date: Thu, 4 Jan 2024 15:23:03 +0000 Subject: [PATCH 120/301] remove specialized wprint etc; rely on WASI --- aici_abi/src/cfg.rs | 47 +++++++++++++++---------------- aici_abi/src/host.rs | 57 +++----------------------------------- aici_abi/src/lex.rs | 23 ++++++++------- aici_abi/src/lib.rs | 22 ++------------- aici_abi/src/recognizer.rs | 4 +-- aici_abi/src/rx.rs | 3 +- aici_abi/src/toktree.rs | 2 +- 7 files changed, 44 insertions(+), 114 deletions(-) diff --git a/aici_abi/src/cfg.rs b/aici_abi/src/cfg.rs index f05e32eb..fb7d667b 100644 --- a/aici_abi/src/cfg.rs +++ b/aici_abi/src/cfg.rs @@ -2,7 +2,6 @@ use crate::lex::{Lexer, LexerState, StateID, VobIdx, VobSet}; use crate::{ svob::SimpleVob, toktree::{Recognizer, SpecialToken, TokTrie}, - wprint, wprintln, }; use anyhow::Result; use cfgrammar::{ @@ -118,10 +117,10 @@ impl CfgParser { }; if false { - wprintln!("core\n{}\n\n", sgraph.pp(&grm, true)); + println!("core\n{}\n\n", sgraph.pp(&grm, true)); for pidx in grm.iter_pidxs() { let prod = grm.prod(pidx); - wprintln!("{:?} -> {}", prod, prod.len()); + println!("{:?} -> {}", prod, prod.len()); } } @@ -180,7 +179,7 @@ impl CfgParser { } } - wprintln!("patterns: {:?}", friendly_pattern_names); + println!("patterns: {:?}", friendly_pattern_names); let mut vobset = VobSet::new(); // all-zero has to be inserted first @@ -266,7 +265,7 @@ impl CfgParser { let act = self.stable.action(stidx, lexeme); if LOG_PARSER { - wprintln!( + println!( "parse: {:?} {:?} -> {:?}", pstack, self.friendly_token_name(lexeme), @@ -299,10 +298,10 @@ impl CfgParser { #[allow(dead_code)] fn print_viable(&self, lbl: &str, vob: &Vob) { - wprintln!("viable tokens {}:", lbl); + println!("viable tokens {}:", lbl); for (idx, b) in vob.iter().enumerate() { if b { - wprintln!(" {}: {}", idx, self.friendly_pattern_names[idx]); + println!(" {}: {}", idx, self.friendly_pattern_names[idx]); } } } @@ -312,11 +311,11 @@ impl CfgParser { fn try_push(&mut self, byte: Option) -> Option { let top = self.byte_states.last().unwrap().clone(); if LOG_PARSER { - wprint!("try_push: "); + print!("try_push: "); if let Some(b) = byte { - wprint!("{:?}", b as char) + print!("{:?}", b as char) } else { - wprint!("") + print!("") } } let (info, res) = match self.lexer.advance(top.lexer_state, byte) { @@ -331,7 +330,7 @@ impl CfgParser { Some((ls, Some(pat_idx))) => ("parse", self.run_parser(pat_idx, &top, ls)), }; if LOG_PARSER { - wprintln!( + println!( " -> {} {}", info, if res.is_none() { "error" } else { "ok" } @@ -359,7 +358,7 @@ impl 
CfgParser { s.yacc_actions += 1; } if LOG_PARSER { - wprintln!(); + println!(); } let pstack = self.pstack_for(top); if self.skip_patterns[pat_idx] { @@ -367,7 +366,7 @@ impl CfgParser { let viable = self.viable_vobidx(stidx); //self.print_viable("reset", &viable); if LOG_PARSER { - wprintln!("parse: {:?} skip", pstack); + println!("parse: {:?} skip", pstack); } // reset viable states - they have been narrowed down to SKIP self.mk_byte_state(ls, top.parse_stack_idx, viable) @@ -498,7 +497,7 @@ pub fn cfg_test() -> Result<()> { let tok = *tok; trie.compute_bias(&mut cfg, &mut vob); if !vob.is_allowed(tok) { - wprintln!("reject, line={}, tok={:?}", line, trie.token_str(tok)); + println!("reject, line={}, tok={:?}", line, trie.token_str(tok)); panic!(); } for b in trie.token(tok) { @@ -507,7 +506,7 @@ pub fn cfg_test() -> Result<()> { } } if false { - wprintln!( + println!( "tok: {:?} {}; {}", trie.token_str(tok), vob.is_allowed(tok), @@ -519,9 +518,9 @@ pub fn cfg_test() -> Result<()> { } #[cfg(not(target_arch = "wasm32"))] - wprintln!("time: {:?} ", t0.elapsed()); + println!("time: {:?} ", t0.elapsed()); - wprintln!("stats: {}", cfg.get_stats()); + println!("stats: {}", cfg.get_stats()); } if false { @@ -530,11 +529,11 @@ pub fn cfg_test() -> Result<()> { let mut idx = 0; while idx < sample.len() { let b = sample[idx]; - // wprintln!("idx {} {:?}", idx, b as char); + // println!("idx {} {:?}", idx, b as char); let r = cfg.try_push_byte(b); if !r { ok = false; - wprintln!( + println!( "reject at\n{:?}\n{:?}", String::from_utf8_lossy(&sample[idx.saturating_sub(50)..idx]), String::from_utf8_lossy(&sample[idx..std::cmp::min(idx + 30, sample.len())]) @@ -547,13 +546,13 @@ pub fn cfg_test() -> Result<()> { let max_pop = cfg.byte_states.len() - 1; if max_pop > 0 && rng.gen_up_to(4) == 0 { let num = rng.gen_up_to(max_pop - 1) + 1; - // wprintln!("pop {} {}", num, cfg.byte_states.len()); + // println!("pop {} {}", num, cfg.byte_states.len()); cfg.pop_bytes(num); idx -= num; } if rng.gen_up_to(10) == 0 { - // wprintln!("collapse"); + // println!("collapse"); cfg.collapse(); } } @@ -561,12 +560,12 @@ pub fn cfg_test() -> Result<()> { if ok { if cfg.special_allowed(SpecialToken::EndOfSentence) { - wprintln!("accept EOS"); + println!("accept EOS"); } else { - wprintln!("reject EOS"); + println!("reject EOS"); } } else { - wprintln!("reject"); + println!("reject"); } } diff --git a/aici_abi/src/host.rs b/aici_abi/src/host.rs index 34f1be13..5a41df5f 100644 --- a/aici_abi/src/host.rs +++ b/aici_abi/src/host.rs @@ -1,10 +1,9 @@ use crate::{ bytes::{vec_from_bytes, TokenId}, svob::SimpleVob, - wprintln, SeqId, + SeqId, }; use serde::{Deserialize, Serialize}; -use std::io; #[repr(transparent)] #[derive(Clone, Copy, Debug, PartialEq, Eq)] @@ -12,9 +11,6 @@ struct BlobId(u32); #[allow(dead_code)] extern "C" { - // Log a string. - fn aici_host_print(ptr: *const u8, len: u32); - // Read binary blob. // Always returns the size of the blob, will write up to `size` bytes to `dst`. 
fn aici_host_read_blob(blob: BlobId, dst: *mut u8, size: u32) -> u32; @@ -62,24 +58,6 @@ fn read_blob(blob: BlobId, prefetch_size: usize) -> Vec { buffer } -#[cfg(not(target_arch = "wasm32"))] -pub type Printer = std::io::Stdout; - -#[cfg(target_arch = "wasm32")] -pub struct Printer {} - -#[cfg(target_arch = "wasm32")] -impl io::Write for Printer { - fn write(&mut self, buf: &[u8]) -> io::Result { - unsafe { aici_host_print(buf.as_ptr(), buf.len() as u32) }; - Ok(buf.len()) - } - - fn flush(&mut self) -> io::Result<()> { - Ok(()) - } -} - pub fn init_panic() { #[cfg(target_arch = "wasm32")] std::panic::set_hook(Box::new(|info| { @@ -95,37 +73,10 @@ pub fn init_panic() { }, }; - let err_info = format!("Panicked at '{}', {}:{}:{}\n", msg, file, line, col); - _print(&err_info); + println!("Panicked at '{}', {}:{}:{}", msg, file, line, col); })) } -pub fn stdout() -> Printer { - #[cfg(target_arch = "wasm32")] - { - Printer {} - } - - #[cfg(not(target_arch = "wasm32"))] - { - io::stdout() - } -} - -pub fn _print(msg: &str) { - #[cfg(target_arch = "wasm32")] - { - let vec: Vec = msg.into(); - unsafe { aici_host_print(vec.as_ptr(), vec.len() as u32) }; - } - - #[cfg(not(target_arch = "wasm32"))] - { - use std::io::Write; - std::io::stdout().write_all(msg.as_bytes()).unwrap(); - } -} - #[no_mangle] pub extern "C" fn aici_init() { init_panic(); @@ -301,7 +252,7 @@ pub fn tokenize_bytes(s: &[u8]) -> Vec { let id = unsafe { aici_host_tokenize(s.as_ptr(), s.len() as u32) }; let r = read_blob(id, 4 * (s.len() / 3 + 10)); let res = vec_from_bytes(&r); - wprintln!( + println!( "tokenize_bytes: {:?} -> {:?}", String::from_utf8_lossy(s), res @@ -314,7 +265,7 @@ pub fn tokenize(s: &str) -> Vec { let id = unsafe { aici_host_tokenize(s.as_ptr(), s.len() as u32) }; let r = read_blob(id, 4 * (s.len() / 3 + 10)); let res = vec_from_bytes(&r); - wprintln!("tokenize: {:?} -> {:?}", s, res); + println!("tokenize: {:?} -> {:?}", s, res); res } diff --git a/aici_abi/src/lex.rs b/aici_abi/src/lex.rs index 25d96098..4ad1e672 100644 --- a/aici_abi/src/lex.rs +++ b/aici_abi/src/lex.rs @@ -1,4 +1,3 @@ -use crate::wprintln; use regex_automata::{ dfa::{dense, Automaton}, util::syntax, @@ -106,7 +105,7 @@ impl VobSet { } } } - wprintln!( + println!( "vob set: {} VOBs, {} nonempty", self.vobs.len(), self.non_empty.len() @@ -134,14 +133,14 @@ impl Lexer { .build_many(&patterns) .unwrap(); - wprintln!( + println!( "dfa: {} bytes, {} patterns", dfa.memory_usage(), patterns.len(), ); if false { for p in &patterns { - wprintln!(" {}", p) + println!(" {}", p) } } @@ -176,7 +175,7 @@ impl Lexer { let idx = dfa.match_pattern(s2, idx).as_usize(); v.set(idx, true); if LOG_LEXER { - wprintln!(" match: {:?} {}", *s, patterns[idx]) + println!(" match: {:?} {}", *s, patterns[idx]) } } } @@ -201,7 +200,7 @@ impl Lexer { } if LOG_LEXER { - wprintln!("iter {} {}", num_set, states.len()); + println!("iter {} {}", num_set, states.len()); } if num_set == 0 { break; @@ -218,7 +217,7 @@ impl Lexer { vobidx_by_state_off[k.as_usize() >> shift] = vobset.get(v); } - wprintln!("initial: {:?}; {} states", initial, states.len()); + println!("initial: {:?}; {} states", initial, states.len()); let mut lex = Lexer { dfa, @@ -231,11 +230,11 @@ impl Lexer { if LOG_LEXER { for s in &states { if lex.is_dead(*s) { - wprintln!("dead: {:?} {}", s, lex.dfa.is_dead_state(*s)); + println!("dead: {:?} {}", s, lex.dfa.is_dead_state(*s)); } } - wprintln!("reachable: {:#?}", reachable_patterns); + println!("reachable: {:#?}", reachable_patterns); } lex @@ -276,7 +275,7 
@@ impl Lexer { .unwrap(); if LOG_LEXER { - wprintln!("token: {}", pat_idx); + println!("token: {}", pat_idx); } Some(pat_idx) @@ -288,7 +287,7 @@ impl Lexer { if let Some(byte) = byte { let state = dfa.next_state(prev, byte); if LOG_LEXER { - wprintln!( + println!( "lex: {:?} -{:?}-> {:?} d={}", prev, byte as char, @@ -305,7 +304,7 @@ impl Lexer { } else { let state = dfa.next_state(self.initial.state, byte); if LOG_LEXER { - wprintln!("lex0: {:?} -{:?}-> {:?}", self.initial, byte as char, state); + println!("lex0: {:?} -{:?}-> {:?}", self.initial, byte as char, state); } Some((self.mk_state(state), tok)) } diff --git a/aici_abi/src/lib.rs b/aici_abi/src/lib.rs index f298d106..73b6fd50 100644 --- a/aici_abi/src/lib.rs +++ b/aici_abi/src/lib.rs @@ -21,8 +21,8 @@ pub mod substring; pub type TokenId = bytes::TokenId; pub use host::{ - _print, aici_stop, arg_bytes, return_logit_bias, self_seq_id, stdout, tokenize, tokenize_bytes, - StorageCmd, StorageOp, StorageResp, VariableStorage, + aici_stop, arg_bytes, return_logit_bias, self_seq_id, tokenize, tokenize_bytes, StorageCmd, + StorageOp, StorageResp, VariableStorage, }; #[derive(Serialize, Deserialize, Debug)] @@ -270,21 +270,3 @@ macro_rules! include_bytes_aligned { &ALIGNED.bytes }}; } - -#[macro_export] -macro_rules! wprintln { - () => { - $crate::_print("\n") - }; - ($($arg:tt)*) => {{ - $crate::_print(&format!($($arg)*)); - $crate::_print("\n"); - }}; -} - -#[macro_export] -macro_rules! wprint { - ($($arg:tt)*) => {{ - $crate::_print(&format!($($arg)*)); - }}; -} diff --git a/aici_abi/src/recognizer.rs b/aici_abi/src/recognizer.rs index eefa9bfc..d97992b9 100644 --- a/aici_abi/src/recognizer.rs +++ b/aici_abi/src/recognizer.rs @@ -30,7 +30,7 @@ impl AiciVm for AiciRecognizer { fn post_process(&mut self, arg: PostProcessArg) -> PostProcessResult { for token in &arg.tokens { let bytes = self.trie.token(*token); - // wprintln!("process {} {:?}", token, bytes); + // println!("process {} {:?}", token, bytes); for b in bytes { self.rec.push_byte(*b) } @@ -102,7 +102,7 @@ impl> Recognizer for StackRecognizer } fn trie_finished(&mut self) { - // wprintln!("{:?}", &self.stack[0..=self.stack_ptr]); + // println!("{:?}", &self.stack[0..=self.stack_ptr]); assert!(self.stack_ptr == 0); } diff --git a/aici_abi/src/rx.rs b/aici_abi/src/rx.rs index e9a517b7..04ebfbcf 100644 --- a/aici_abi/src/rx.rs +++ b/aici_abi/src/rx.rs @@ -1,7 +1,6 @@ use crate::{ recognizer::{FunctionalRecognizer, StackRecognizer}, toktree::SpecialToken, - wprintln, }; use regex_automata::{ dfa::{dense, Automaton}, @@ -29,7 +28,7 @@ impl RecRx { .syntax(syntax::Config::new().unicode(false).utf8(false)) .build(&rx) .unwrap(); - wprintln!("dfa: {} bytes", dfa.memory_usage()); + println!("dfa: {} bytes", dfa.memory_usage()); Self { dfa } } diff --git a/aici_abi/src/toktree.rs b/aici_abi/src/toktree.rs index c642615d..b27f0c13 100644 --- a/aici_abi/src/toktree.rs +++ b/aici_abi/src/toktree.rs @@ -241,7 +241,7 @@ impl TokTrie { pub fn token_id(&self, bytes: &[u8]) -> Option { let (tok, len) = self.prefix_token_id(bytes); - // wprintln!("tok_id {:?} {:?} {:?} ", bytes, tok, len); + // println!("tok_id {:?} {:?} {:?} ", bytes, tok, len); if len == bytes.len() { Some(tok) } else { From bf99d6044524093080a3db97d561bf7e895e9271 Mon Sep 17 00:00:00 2001 From: Michal Moskal Date: Thu, 4 Jan 2024 15:48:06 +0000 Subject: [PATCH 121/301] don't use wasm-strip --- aici_abi/.cargo/config.toml | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/aici_abi/.cargo/config.toml 
b/aici_abi/.cargo/config.toml index 6b77899c..e0b0d22a 100644 --- a/aici_abi/.cargo/config.toml +++ b/aici_abi/.cargo/config.toml @@ -1,2 +1,8 @@ [build] target = "wasm32-wasi" + +[profile.dev] +strip = "debuginfo" + +[profile.release] +strip = "debuginfo" From f62d7ae1cf78cba85cfa18cd516969366f0b12ec Mon Sep 17 00:00:00 2001 From: Michal Moskal Date: Thu, 4 Jan 2024 16:58:14 +0000 Subject: [PATCH 122/301] add simple wasm/rust sample --- aici_abi/Cargo.toml | 4 ++ aici_abi/src/recognizer.rs | 9 +-- aici_abi/src/toktree.rs | 6 ++ aici_abi/src/uppercase.rs | 109 +++++++++++++++++++++++++++++++++++++ 4 files changed, 120 insertions(+), 8 deletions(-) create mode 100644 aici_abi/src/uppercase.rs diff --git a/aici_abi/Cargo.toml b/aici_abi/Cargo.toml index 21e10f93..aafe97ec 100644 --- a/aici_abi/Cargo.toml +++ b/aici_abi/Cargo.toml @@ -22,3 +22,7 @@ rustc-hash = { version = "1.1.0", optional = true } default = ["cfg", "rx"] cfg = ["dep:cfgrammar", "dep:lrlex", "dep:lrpar", "dep:lrtable", "dep:vob", "dep:rustc-hash"] rx = ["dep:regex-automata"] + +[[bin]] +name = "uppercase" +path = "src/uppercase.rs" \ No newline at end of file diff --git a/aici_abi/src/recognizer.rs b/aici_abi/src/recognizer.rs index d97992b9..c4ce4041 100644 --- a/aici_abi/src/recognizer.rs +++ b/aici_abi/src/recognizer.rs @@ -28,14 +28,7 @@ impl AiciVm for AiciRecognizer { } fn post_process(&mut self, arg: PostProcessArg) -> PostProcessResult { - for token in &arg.tokens { - let bytes = self.trie.token(*token); - // println!("process {} {:?}", token, bytes); - for b in bytes { - self.rec.push_byte(*b) - } - self.rec.collapse(); - } + self.trie.append_tokens(&mut self.rec, &arg.tokens); PostProcessResult::from_arg(&arg) } } diff --git a/aici_abi/src/toktree.rs b/aici_abi/src/toktree.rs index b27f0c13..18c19869 100644 --- a/aici_abi/src/toktree.rs +++ b/aici_abi/src/toktree.rs @@ -394,6 +394,12 @@ impl TokTrie { self.add_bias(r, logits) } + pub fn append_tokens(&self, r: &mut impl Recognizer, ts: &[TokenId]) { + for t in ts { + self.append_token(r, *t) + } + } + pub fn append_token(&self, r: &mut impl Recognizer, t: TokenId) { let bytes = self.token(t); for &byte in bytes { diff --git a/aici_abi/src/uppercase.rs b/aici_abi/src/uppercase.rs new file mode 100644 index 00000000..644d9ec9 --- /dev/null +++ b/aici_abi/src/uppercase.rs @@ -0,0 +1,109 @@ +use aici_abi::{ + recognizer::{FunctionalRecognizer, StackRecognizer}, + tokenize, + toktree::{SpecialToken, TokTrie}, + AiciVm, InitPromptArg, InitPromptResult, MidProcessArg, MidProcessResult, PostProcessArg, + PostProcessResult, PreProcessArg, PreProcessResult, +}; + +// This constraints enforces an upper case letter every second byte +// The state is the position in the output stream +struct EvenUpper {} +impl FunctionalRecognizer for EvenUpper { + fn initial(&self) -> usize { + 0 + } + + fn append(&self, state: usize, _byte: u8) -> usize { + state + 1 + } + + fn byte_allowed(&self, state: usize, byte: u8) -> bool { + if state % 4 == 0 { + byte.is_ascii_uppercase() + } else { + true + } + } + + fn special_allowed(&self, _state: usize, tok: SpecialToken) -> bool { + match tok { + SpecialToken::EndOfSentence => false, + _ => false, + } + } +} + +pub struct Runner { + toktrie: TokTrie, + tokens: Vec, + rec: StackRecognizer, +} + +impl Runner { + pub fn new(aici_arg: Vec) -> Self { + println!("user passed in {} bytes", aici_arg.len()); + Runner { + toktrie: TokTrie::from_host(), + tokens: Vec::new(), + rec: StackRecognizer::from(EvenUpper {}), + } + } +} + +impl AiciVm for Runner { 
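+    // Callback lifecycle (as described by the code below): init_prompt() runs
+    // once with the initial tokens; then each generation step calls
+    // pre_process() (which may fast-forward tokens or just continue),
+    // mid_process() (which computes the sampling bias), and post_process()
+    // (which observes the tokens that were actually appended).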
+ fn init_prompt(&mut self, arg: InitPromptArg) -> InitPromptResult { + // with VMs, the prompt is often empty, but let's print it + println!( + "init_prompt: {:?} {:?}", + arg.prompt, + self.toktrie.decode_str(&arg.prompt) + ); + // result is currently empty + InitPromptResult::default() + } + + fn pre_process(&mut self, _arg: PreProcessArg) -> PreProcessResult { + if self.tokens.is_empty() { + // if no tokens yet, send our prompt + let toks = tokenize("Here's a tweet:\n"); + PreProcessResult::ff_tokens(toks) + } else { + // otherwise just continue - the other option is to suspend + PreProcessResult::continue_() + } + } + + fn mid_process(&mut self, _arg: MidProcessArg) -> MidProcessResult { + if self.tokens.len() > 50 { + // stop after 50 tokens + return MidProcessResult::Stop; + } + + // otherwise, compute bias according to our recognizer + let mut set = self.toktrie.alloc_token_set(); + self.toktrie.compute_bias(&mut self.rec, &mut set); + MidProcessResult::SampleWithBias { + allowed_tokens: set, + } + } + + fn post_process(&mut self, arg: PostProcessArg) -> PostProcessResult { + // save our tokens + self.tokens.extend_from_slice(&arg.tokens); + // and update the state of our recognizer + self.toktrie.append_tokens(&mut self.rec, &arg.tokens); + // ::from_arg() will translate generation of EOS token into Stop instruction + PostProcessResult::from_arg(&arg) + } +} + +fn runner_from_env() -> Runner { + Runner::new(aici_abi::arg_bytes()) +} + +fn main() { + // test code here? +} + +aici_abi::aici_expose_all!(Runner, runner_from_env()); From df531638a9fa8a57837f4209f632e23eb27381a0 Mon Sep 17 00:00:00 2001 From: Michal Moskal Date: Thu, 4 Jan 2024 17:25:24 +0000 Subject: [PATCH 123/301] add yes/no example --- aici_abi/Cargo.toml | 7 +++- aici_abi/src/uppercase.rs | 10 +++--- aici_abi/src/yesno.rs | 70 +++++++++++++++++++++++++++++++++++++++ 3 files changed, 81 insertions(+), 6 deletions(-) create mode 100644 aici_abi/src/yesno.rs diff --git a/aici_abi/Cargo.toml b/aici_abi/Cargo.toml index aafe97ec..4b296c24 100644 --- a/aici_abi/Cargo.toml +++ b/aici_abi/Cargo.toml @@ -25,4 +25,9 @@ rx = ["dep:regex-automata"] [[bin]] name = "uppercase" -path = "src/uppercase.rs" \ No newline at end of file +path = "src/uppercase.rs" + + +[[bin]] +name = "yesno" +path = "src/yesno.rs" \ No newline at end of file diff --git a/aici_abi/src/uppercase.rs b/aici_abi/src/uppercase.rs index 644d9ec9..87cc0f05 100644 --- a/aici_abi/src/uppercase.rs +++ b/aici_abi/src/uppercase.rs @@ -6,10 +6,10 @@ use aici_abi::{ PostProcessResult, PreProcessArg, PreProcessResult, }; -// This constraints enforces an upper case letter every second byte +// This constraints enforces an upper case letter every 4th byte // The state is the position in the output stream -struct EvenUpper {} -impl FunctionalRecognizer for EvenUpper { +struct QuadUpper {} +impl FunctionalRecognizer for QuadUpper { fn initial(&self) -> usize { 0 } @@ -37,7 +37,7 @@ impl FunctionalRecognizer for EvenUpper { pub struct Runner { toktrie: TokTrie, tokens: Vec, - rec: StackRecognizer, + rec: StackRecognizer, } impl Runner { @@ -46,7 +46,7 @@ impl Runner { Runner { toktrie: TokTrie::from_host(), tokens: Vec::new(), - rec: StackRecognizer::from(EvenUpper {}), + rec: StackRecognizer::from(QuadUpper {}), } } } diff --git a/aici_abi/src/yesno.rs b/aici_abi/src/yesno.rs new file mode 100644 index 00000000..e593e684 --- /dev/null +++ b/aici_abi/src/yesno.rs @@ -0,0 +1,70 @@ +use aici_abi::{ + tokenize, toktree::TokTrie, AiciVm, InitPromptArg, InitPromptResult, 
MidProcessArg, + MidProcessResult, PostProcessArg, PostProcessResult, PreProcessArg, PreProcessResult, TokenId, +}; + +pub struct Runner { + toktrie: TokTrie, + tokens: Vec, + yes: TokenId, + no: TokenId, +} + +impl Runner { + pub fn new() -> Self { + let yes = tokenize("Yes")[0]; + let no = tokenize("No")[0]; + // ignore user-passed arg + Runner { + toktrie: TokTrie::from_host(), + tokens: Vec::new(), + yes, + no, + } + } +} + +impl AiciVm for Runner { + fn init_prompt(&mut self, arg: InitPromptArg) -> InitPromptResult { + if arg.prompt.len() < 2 { + // we'll be forcing answer; require a question + panic!("prompt too short") + } + InitPromptResult::default() + } + + fn pre_process(&mut self, _arg: PreProcessArg) -> PreProcessResult { + if self.tokens.is_empty() { + // Make sure the prompt ends with newline + let toks = tokenize("\n"); + PreProcessResult::ff_tokens(toks) + } else { + PreProcessResult::continue_() + } + } + + fn mid_process(&mut self, _arg: MidProcessArg) -> MidProcessResult { + let mut set = self.toktrie.alloc_token_set(); + set.allow_token(self.yes); + set.allow_token(self.no); + MidProcessResult::SampleWithBias { + allowed_tokens: set, + } + } + + fn post_process(&mut self, arg: PostProcessArg) -> PostProcessResult { + // save our tokens + self.tokens.extend_from_slice(&arg.tokens); + if self.tokens.len() >= 2 { + PostProcessResult::stop() + } else { + PostProcessResult::from_arg(&arg) + } + } +} + +fn main() { + // test code here? +} + +aici_abi::aici_expose_all!(Runner, Runner::new()); From e88d4aa98188e955e4073fffb79dda96ee69dcc9 Mon Sep 17 00:00:00 2001 From: Michal Moskal Date: Fri, 5 Jan 2024 10:20:45 +0000 Subject: [PATCH 124/301] Better wasm error printing --- aici_abi/src/host.rs | 16 ++-------------- 1 file changed, 2 insertions(+), 14 deletions(-) diff --git a/aici_abi/src/host.rs b/aici_abi/src/host.rs index 5a41df5f..07bb8cf3 100644 --- a/aici_abi/src/host.rs +++ b/aici_abi/src/host.rs @@ -58,22 +58,10 @@ fn read_blob(blob: BlobId, prefetch_size: usize) -> Vec { buffer } -pub fn init_panic() { +fn init_panic() { #[cfg(target_arch = "wasm32")] std::panic::set_hook(Box::new(|info| { - let file = info.location().unwrap().file(); - let line = info.location().unwrap().line(); - let col = info.location().unwrap().column(); - - let msg = match info.payload().downcast_ref::<&'static str>() { - Some(s) => *s, - None => match info.payload().downcast_ref::() { - Some(s) => &s[..], - None => "Box", - }, - }; - - println!("Panicked at '{}', {}:{}:{}", msg, file, line, col); + println!("{}", info); })) } From 517f38c6f0c910bcc1f621cfda2c7d6ec31e4da2 Mon Sep 17 00:00:00 2001 From: Michal Moskal Date: Fri, 5 Jan 2024 10:22:19 +0000 Subject: [PATCH 125/301] add comment --- aici_abi/src/host.rs | 1 + 1 file changed, 1 insertion(+) diff --git a/aici_abi/src/host.rs b/aici_abi/src/host.rs index 07bb8cf3..27415cf6 100644 --- a/aici_abi/src/host.rs +++ b/aici_abi/src/host.rs @@ -61,6 +61,7 @@ fn read_blob(blob: BlobId, prefetch_size: usize) -> Vec { fn init_panic() { #[cfg(target_arch = "wasm32")] std::panic::set_hook(Box::new(|info| { + // skip 'run with `RUST_BACKTRACE=1`' message (not relevant for remote running) println!("{}", info); })) } From 692a094c19daeca28c6deb5c4258b55e8211924e Mon Sep 17 00:00:00 2001 From: Michal Moskal Date: Sat, 6 Jan 2024 17:27:55 +0100 Subject: [PATCH 126/301] spellchecking --- aici_abi/src/bytes.rs | 2 +- aici_abi/src/cfg.rs | 6 ++++-- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/aici_abi/src/bytes.rs b/aici_abi/src/bytes.rs 
b/aici_abi/src/bytes.rs
index 86a53c12..1c471e6c 100644
--- a/aici_abi/src/bytes.rs
+++ b/aici_abi/src/bytes.rs
@@ -37,7 +37,7 @@ pub fn box_from_bytes<T>(bytes: &[u8]) -> Box<T> {
 pub fn vec_from_bytes<T>(bytes: &[u8]) -> Vec<T> {
     if bytes.len() % size_of::<T>() != 0 {
         panic!(
-            "vecT: got {} bytes, needed mult of {}",
+            "vecT: got {} bytes, needed multiple of {}",
             bytes.len(),
             size_of::<T>()
         );
diff --git a/aici_abi/src/cfg.rs b/aici_abi/src/cfg.rs
index fb7d667b..695220bf 100644
--- a/aici_abi/src/cfg.rs
+++ b/aici_abi/src/cfg.rs
@@ -110,8 +110,10 @@ impl CfgParser {
         let (sgraph, stable) = match from_yacc(&grm, Minimiser::Pager) {
             Ok(r) => r,
             Err(e) => {
-                // not sure this works:
-                // anyhow::bail!("state table error:\n{e} on {:?}", grm.action(e.pidx));
+                if false {
+                    // not sure this works:
+                    anyhow::bail!("state table error:\n{e} on {:?}", grm.action(e.pidx));
+                }
                 anyhow::bail!("state table error:\n{e}");
             }
         };

From 2da88ceba85f358eeaeb3716287cd68797a337b8 Mon Sep 17 00:00:00 2001
From: Michal Moskal
Date: Thu, 11 Jan 2024 18:20:04 +0000
Subject: [PATCH 127/301] move text around readmes

---
 aici_abi/README.md         | 204 +++++++++++++++++++++++++++++++++++++
 aici_abi/implementation.md | 153 ++++++++++++++++++++++++++++
 2 files changed, 357 insertions(+)
 create mode 100644 aici_abi/README.md
 create mode 100644 aici_abi/implementation.md

diff --git a/aici_abi/README.md b/aici_abi/README.md
new file mode 100644
index 00000000..78a4e7ad
--- /dev/null
+++ b/aici_abi/README.md
@@ -0,0 +1,204 @@
+# aici_abi
+
+This crate specifies the application binary interface (ABI) for the AICI Controllers.
+It also provides higher-level interfaces for implementing controllers.
+
+## Low-level interface
+
+Conceptually, the lowest-level interface to an AICI constraint is this:
+
+```rust
+type TokenId = u32;
+type SeqId = u32;
+
+trait AiciVm {
+    /// Called with the initial prompt. ~1000ms time limit.
+    fn init_prompt(prompt: Vec<TokenId>);
+
+    /// Called before mid_process(), can fork or suspend. ~1ms.
+    fn pre_process() -> enum {
+        Stop,
+        Continue, // Same as Fork { num_forks: 1 }
+        Suspend, // skip this generation round
+        Fork { num_forks: u32 },
+    }
+
+    /// This is the main entry point for the module. ~20ms.
+    fn mid_process(fork_group: Vec<SeqId>) -> enum {
+        Stop,
+        SampleWithBias { bias: Vec<f32> },
+        Splice { backtrack: u32, ff_tokens: Vec<TokenId> }
+    };
+
+    /// Called after tokens are appended. ~1ms.
+    fn post_process(tokens: Vec<TokenId>) -> enum { Stop, Continue };
+}
+```
+
+Tokens depend on the tokenizer used (eg., for Llama there are 32000 tokens, and for GPT-4 there are ~100k).
+
+The actual binary interface is a bit more complicated, due
+to limitations in passing values to and from Wasm.
+A Wasm module instance is created for each token sequence.
+Also, when the sequence forks (as in beam search), the module instance is cloned.
+See the [AiciVm Rust trait](aici_abi/src/lib.rs) for details.
+
+A number of functions are exposed to the Wasm module.
+
+First, there are functions for accessing the current tokenizer:
+
+```rust
+/// Given a byte sequence, return a sequence of token Ids.
+fn tokenize_bytes(s: Vec<u8>) -> Vec<TokenId>;
+
+/// Represents a trie of all tokens in the current tokenizer.
+impl TokTrie {
+    /// Get Id for EOS token etc.
+    fn special_token(tok: SpecialToken) -> TokenId;
+    /// Number of tokens.
+    fn vocab_size() -> usize;
+    /// Convert token Id to bytes (often UTF-8 string).
+    fn token(token: TokenId) -> Vec<u8>;
+    /// Given a Recognizer, compute the set of allowed tokens.
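+    /// (One entry per token; following the convention used in
+    /// implementation.md, `0.0` leaves a token allowed while a large
+    /// negative value such as `-100.0` effectively forbids it.)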
+    fn compute_bias(rec: impl Recognizer) -> Vec<f32>;
+}
+```
+
+Different forks in a sequence can communicate via shared variables:
+
+```rust
+/// This can be looked up in fork_group.
+fn self_seq_id() -> SeqId;
+
+trait VariableStorage {
+    fn get(name: str) -> Option<Vec<u8>>;
+    fn set(name: str, value: Vec<u8>);
+    fn append(name: str, value: Vec<u8>);
+}
+```
+
+Additionally, the `stdout` and `stderr` file descriptors are captured by the runtime
+and returned to the user when streaming results.
+
+This interface may need to be extended in the future.
+
+## Byte stack interface
+
+The constraints are typically expressed on strings or bytes, not tokens.
+To compute the set of tokens that match a string constraint, one needs to go through all the possible tokens
+and apply the constraint.
+An efficient way to do this is to walk a prefix tree (trie) of all tokens.
+The `aici_abi` library implements this trie and exposes a way of filtering when provided with a constraint
+implementing the [following interface](aici_abi/src/toktree.rs):
+
+```rust
+pub trait Recognizer {
+    /// If `stack.top()` transitions via `byte` to `X`, execute `stack.push(X)`.
+    fn push_byte(&mut self, byte: u8);
+    /// for _ in 0..num { stack.pop() }
+    fn pop_bytes(&mut self, num: usize);
+    /// X = stack.top(); stack.empty(); stack.push(X)
+    fn collapse(&mut self);
+    /// check if stack.top() transitions via byte to a viable state
+    fn byte_allowed(&mut self, byte: u8) -> bool;
+    /// check if stack.top() transitions via tok to a viable state
+    fn special_allowed(&mut self, tok: SpecialToken) -> bool;
+    /// Called when iteration over the trie is finished
+    /// Stack has exactly one element then.
+    fn trie_finished(&mut self);
+    /// This combines `push_byte` and `byte_allowed` into one function for performance.
+    fn try_push_byte(&mut self, byte: u8) -> bool;
+}
+```
+
+The `AiciRecognizer` struct converts `Recognizer` to `AiciVm`.
+
+## Functional byte interface
+
+The following interface can be transformed into `Recognizer` using the `StackRecognizer` struct.
+
+```rust
+pub trait FunctionalRecognizer<S> {
+    /// Initial state
+    fn initial(&self) -> S;
+    /// Extend the recognizer with given byte.
+    fn append(&self, state: S, byte: u8) -> S;
+    /// Check if given byte is allowed in given state.
+    fn byte_allowed(&self, state: S, byte: u8) -> bool;
+    /// Check if given special token is allowed in given state.
+    fn special_allowed(&self, state: S, tok: SpecialToken) -> bool;
+}
+```
+
+These three layers add up to about 40k of compiled code (Wasm).
+
+## Regular expressions
+
+The `FunctionalRecognizer` interface is implemented for regular expressions.
+The `S` type is the state of the DFA (Deterministic Finite Automaton) that recognizes the regular expression,
+then `append()` and `byte_allowed()` are the standard DFA operations,
+while `special_allowed()` is only implemented for the end-of-sequence token
+(which is allowed when the current state is accepting).
+
+## LR(1) grammars
+
+The `Recognizer` interface is implemented for LR(1) grammars and DFA-based lexers.
+
+The grammar uses inline syntax for the lexer:
+
+- `"keyword"` or `'keyword'` for keywords; any string works, eg. `"+="`, `"while"`, ...
+- `"/.../"` or `'/.../'` for regular expressions; you cannot have both `'` and `"` in the regex.
+  The special `SKIP` rule is used to indicate tokens that need to be skipped by the LR(1) parser (eg., whitespace and comments).
+
+The lexer has a DFA which recognizes all regexps and keywords
+(a big disjunction, but with additional machinery to disambiguate between different branches).
+It goes byte by byte, until the DFA gets to a dead state (from which no match is possible).
+Then it goes back one byte and checks for a match.
+It prefers keywords over regexps.
+If no match is found, an error is reported, which requires careful design of the lexical part of the grammar
+(eg., see how the `white-space` rule below is a prefix of the `pre-processor` rule).
+
+For example, this is a fragment of the [grammar for C](./grammars/c.y):
+
+```yacc
+%start translation_unit
+%%
+
+SKIP
+    : "//\*[^*]*\*+([^/*][^*]*\*+)*//" // block comment
+    | "///.*/" // line comment
+    | "/\n[ \t\v\f]*#(.*\\\n)*.*/" // pre-processor
+    | "/\n?[ \t\v\f]*/" // white-space
+    ;
+
+IDENTIFIER: "/[a-zA-Z_][0-9a-zA-Z_]*/" ;
+
+CONSTANT
+    : "/0[xX][0-9a-fA-F]+[uUlL]*?/"
+    | "/0[0-9]+[uUlL]*?/"
+    ;
+
+STRING_LITERAL: '/"(\\.|[^\\"])*"/' ;
+
+primary_expression
+    : IDENTIFIER
+    | CONSTANT
+    | STRING_LITERAL
+    | "(" expression ")"
+    ;
+
+// ...
+
+enum_specifier
+    : "enum" "{" enumerator_list "}"
+    | "enum" IDENTIFIER "{" enumerator_list "}"
+    | "enum" IDENTIFIER
+    ;
+
+// ...
+
+translation_unit
+    : external_declaration
+    | translation_unit external_declaration
+    ;
+```
diff --git a/aici_abi/implementation.md b/aici_abi/implementation.md
new file mode 100644
index 00000000..bd766709
--- /dev/null
+++ b/aici_abi/implementation.md
@@ -0,0 +1,153 @@
+# Implementation notes
+
+## Token trie
+
+The round nodes represent tokens, the square nodes do not have a corresponding token.
+
+The number (`num_parents`) specifies how many parents you need to pop to get to the parent of the node which comes after our children in DFS order.
+
+We also keep the `token_id` and a `subtree_size` (which includes the node itself) in each node.
+A bogus `token_id` is used for nodes that do not have a corresponding token.
+
+```mermaid
+graph TD
+  root[ε, 0] -- a --> a((a, 1))
+  root -- b --> b((b, 1))
+  root -- c --> c((c, 1))
+  a -- x --> ax((ax, 1))
+  a -- y --> ay[ay, 1]
+  a -- z --> az((az, 2))
+  az -- a --> azq((aza, 3))
+  ay -- a --> ayq((aya, 1))
+  ay -- b --> ayw((ayb, 2))
+```
+
+The traversal algorithm computes the set of tokens allowed by a stack-based recognizer.
+The set is stored in the `logits` array - entries with `0.0` are allowed.
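+
+(The `+ 1` in the snippet below reserves an extra slot at index `VOCAB_SIZE`:
+trie nodes that have no corresponding token reuse that index, so the traversal
+can write `logits[n.token_id]` unconditionally.)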
+ +```rust +let mut logits = vec![-100.0; VOCAB_SIZE + 1]; +``` + +A simple version of traversal algorithm: + +```rust +fn traverse(n) { + // mark token as allowed; nodes without token use `token_id == VOCAB_SIZE` + logits[n.token_id] = 0.0; + for c in n.children { + // for every child that starts with an allowed byte + if byte_allowed(c.byte) { + push_byte(c.byte); + // traverse it + traverse(c); + pop_bytes(1); + } + } +} +``` + +Now, assume the tree is laid out in memory in DFS order: + +```rust +fn traverse(mut p) { + let endp = p + nodes[p].subtree_size; + p += 1; // move to first child + while p < endp { + let n = nodes[p]; + if byte_allowed(n.byte) { + push_byte(n.byte); + logits[n.token_id] = 0.0; + // p is moved by n.subtree_size + p = traverse(p); + pop_bytes(1); + } else { + p += n.subtree_size; + } + } +} +``` + +Now, we get rid of the recursion: + +```rust +let mut p = 0; +while p < nodes.len() { + let n = nodes[p]; + if byte_allowed(n.byte) { + push_byte(n.byte); + logits[n.token_id] = 0.0; + // if the node is a leaf, we need to pop all the parents + pop_bytes(if n.subtree_size == 1 { n.num_parents } else { 0 }); + // move to first child, or sibling if no children + p += 1; + } else { + // skip the children, and go to the sibling node + p += n.subtree_size; + // regardless if the node is a leaf, we need to pop all the parents + pop_bytes(n.num_parents - 1); + } +} +``` + +Note that the only branch that gets mis-predicted here is the `if byte_allowed(n.byte)`. +The `if` in argument to `pop_bytes` is compiled to bit operations, so it is branchless. + +## LR(1) parsing + +The LR(1) parsing consists of DFA-based lexer and the actual LR(1) parser. +DFA has a single number as the state, while the state of the LR(1) is a stack of numbers. +The LR(1) action is determined based on the next token from the lexer and the top of the stack. + +The `Recognizer` interface also has a concept of stack, however every entry on that +stack contains a DFA state and an LR(1) stack. + +Most of the time (~98.5% for the C grammar), pushing a byte involves only updating the DFA state, +while the LR(1) stack is copied unchanged (the memory is shared). + + +### Early error detection + +Consider the following invalid C program: + +```c +int 123456; +``` + +The lexer would produce `int` keyword, whitespace, `123456` constant and `;` keyword. +The parser would reject `123456`, however only after all six characters of it have been read. +This is too late for the LLM. + +To detect such errors early, we compute a set of reachable tokens for each DFA state. +For example, consider a DFA that recognizes `int`, `if`, `ID` (`/[a-z][a-z0-9]*/`) and `INTLIT` (`/[0-9]+/`). +The initial DFA state has a full set of tokens, while a state after `'i'` +has only `int`, `if`, and `ID`, +and a state after `'1'` includes only `INTLIT`. +In the picture below, each state is labelled by its reachable set, +and the token for which it is a match (if any) is postfixed with `*`. We only use lower-case letters and digits for simplicity. 
+ +```mermaid +graph LR + 0["{int,if,ID,INTLIT}"] -- "[i]" --> i(("{int,if,ID*}")) + 0 -- "[a-z] - [i]" --> id(("{ID*}")) + 0 -- "[0-9]" --> const(("{INTLIT*}")) + const -- "[0-9]" --> const + const -- "[a-z]" --> bot["{}"] + i -- "[a-z0-9] - [nf]" --> id + id -- "[a-z0-9]" --> id + i -- "[n]" --> in(("{int,ID*}")) + in -- "[t]" --> int(("{int*,ID}")) + in -- "[a-z0-9] - [t]" --> id + int -- "[a-z0-9]" --> id + i -- "[f]" --> if(("{if*,ID}")) + if -- "[a-z0-9]" --> id +``` + +For each LR(1) automaton state we compute a set of viable tokens, i.e., ones that do +not immediately lead to an error. + +While parsing input, if the intersection of viable and reachable tokens is empty, we report an error. + +In the example above, the viable tokens after `int` do not include `INTLIT`, +and thus the parser fails immediately at `1`. + From dfb0afac484802071bd7d60d8028102929266125 Mon Sep 17 00:00:00 2001 From: Michal Moskal Date: Thu, 11 Jan 2024 18:25:48 +0000 Subject: [PATCH 128/301] move grammars folder --- aici_abi/grammars/c.y | 442 +++++++++++++ aici_abi/grammars/sample.c | 1245 ++++++++++++++++++++++++++++++++++++ aici_abi/src/cfg.rs | 4 +- 3 files changed, 1689 insertions(+), 2 deletions(-) create mode 100644 aici_abi/grammars/c.y create mode 100644 aici_abi/grammars/sample.c diff --git a/aici_abi/grammars/c.y b/aici_abi/grammars/c.y new file mode 100644 index 00000000..7397a971 --- /dev/null +++ b/aici_abi/grammars/c.y @@ -0,0 +1,442 @@ +// based on http://www.lysator.liu.se/c/ANSI-C-grammar-y.html + +%start translation_unit +%% + +SKIP + : "//\*[^*]*\*+([^/*][^*]*\*+)*//" // block comment + | "///.*/" // line comment + | "/\n[ \t\v\f]*#(.*\\\n)*.*/" // pre-processor + | "/\n?[ \t\v\f]*/" // white-space + ; + +IDENTIFIER: "/[a-zA-Z_][0-9a-zA-Z_]*/" ; + +TYPE_NAME: "/[a-zA-Z_][0-9a-zA-Z_]*_t/" ; + +CONSTANT + : "/0[xX][0-9a-fA-F]+[uUlL]*?/" + | "/0[0-9]+[uUlL]*?/" + | "/[0-9]+[uUlL]*?/" + | "/[a-zA-Z_]?'(\\.|[^\\'])+'/" + | "/[0-9]+[Ee][+-]?[0-9]+[flFL]?/" + | "/[0-9]*\\.[0-9]+([Ee][+-]?[0-9]+)?[flFL]?/" + | "/[0-9]+\\.[0-9]*([Ee][+-]?[0-9]+)?[flFL]?/" + ; + +STRING_LITERAL: '/[a-zA-Z_]?"(\\.|[^\\"])*"/' ; + +primary_expression + : IDENTIFIER + | CONSTANT + | STRING_LITERAL + | "(" expression ")" + ; + +postfix_expression + : primary_expression + | postfix_expression "[" expression "]" + | postfix_expression "(" ")" + | postfix_expression "(" argument_expression_list ")" + | postfix_expression "." IDENTIFIER + | postfix_expression "->" IDENTIFIER + | postfix_expression "++" + | postfix_expression "--" + ; + +argument_expression_list + : assignment_expression + | argument_expression_list "," assignment_expression + ; + +unary_expression + : postfix_expression + | "++" unary_expression + | "--" unary_expression + | unary_operator cast_expression + | "sizeof" unary_expression + | "sizeof" "(" type_name ")" + ; + +unary_operator + : "&" + | "*" + | "+" + | "-" + | "~" + | "!" 
+ ; + +cast_expression + : unary_expression + | "(" type_name ")" cast_expression + ; + +multiplicative_expression + : cast_expression + | multiplicative_expression "*" cast_expression + | multiplicative_expression "/" cast_expression + | multiplicative_expression "%" cast_expression + ; + +additive_expression + : multiplicative_expression + | additive_expression "+" multiplicative_expression + | additive_expression "-" multiplicative_expression + ; + +shift_expression + : additive_expression + | shift_expression "<<" additive_expression + | shift_expression ">>" additive_expression + ; + +relational_expression + : shift_expression + | relational_expression "<" shift_expression + | relational_expression ">" shift_expression + | relational_expression "<=" shift_expression + | relational_expression ">=" shift_expression + ; + +equality_expression + : relational_expression + | equality_expression "==" relational_expression + | equality_expression "!=" relational_expression + ; + +and_expression + : equality_expression + | and_expression "&" equality_expression + ; + +exclusive_or_expression + : and_expression + | exclusive_or_expression "^" and_expression + ; + +inclusive_or_expression + : exclusive_or_expression + | inclusive_or_expression "|" exclusive_or_expression + ; + +logical_and_expression + : inclusive_or_expression + | logical_and_expression "&&" inclusive_or_expression + ; + +logical_or_expression + : logical_and_expression + | logical_or_expression "||" logical_and_expression + ; + +conditional_expression + : logical_or_expression + | logical_or_expression "?" expression ":" conditional_expression + ; + +assignment_expression + : conditional_expression + | unary_expression assignment_operator assignment_expression + ; + +assignment_operator + : "=" + | "*=" + | "/=" + | "%=" + | "+=" + | "-=" + | "<<=" + | ">>=" + | "&=" + | "^=" + | "|=" + ; + +expression + : assignment_expression + | expression "," assignment_expression + ; + +constant_expression + : conditional_expression + ; + +declaration + : declaration_specifiers ";" + | declaration_specifiers init_declarator_list ";" + ; + +declaration_specifiers + : storage_class_specifier + | storage_class_specifier declaration_specifiers + | type_specifier + | type_specifier declaration_specifiers + | type_qualifier + | type_qualifier declaration_specifiers + ; + +init_declarator_list + : init_declarator + | init_declarator_list "," init_declarator + ; + +init_declarator + : declarator + | declarator "=" initializer + ; + +storage_class_specifier + : "typedef" + | "extern" + | "static" + | "auto" + | "register" + | "inline" + ; + +type_specifier + : "void" + | "char" + | "short" + | "int" + | "long" + | "float" + | "double" + | "signed" + | "unsigned" + | "bool" + | struct_or_union_specifier + | enum_specifier + | TYPE_NAME + ; + +struct_or_union_specifier + : struct_or_union IDENTIFIER "{" struct_declaration_list "}" + | struct_or_union "{" struct_declaration_list "}" + | struct_or_union IDENTIFIER + ; + +struct_or_union + : "struct" + | "union" + ; + +struct_declaration_list + : struct_declaration + | struct_declaration_list struct_declaration + ; + +struct_declaration + : specifier_qualifier_list struct_declarator_list ";" + ; + +specifier_qualifier_list + : type_specifier specifier_qualifier_list + | type_specifier + | type_qualifier specifier_qualifier_list + | type_qualifier + ; + +struct_declarator_list + : struct_declarator + | struct_declarator_list "," struct_declarator + ; + +struct_declarator + : declarator + | ":" 
constant_expression + | declarator ":" constant_expression + ; + +enum_specifier + : "enum" "{" enumerator_list "}" + | "enum" IDENTIFIER "{" enumerator_list "}" + | "enum" IDENTIFIER + ; + +enumerator_list + : enumerator + | enumerator_list "," enumerator + ; + +enumerator + : IDENTIFIER + | IDENTIFIER "=" constant_expression + ; + +type_qualifier + : "const" + | "volatile" + ; + +declarator + : pointer direct_declarator + | direct_declarator + ; + +direct_declarator + : IDENTIFIER + | "(" declarator ")" + | direct_declarator "[" constant_expression "]" + | direct_declarator "[" "]" + | direct_declarator "(" parameter_type_list ")" + | direct_declarator "(" identifier_list ")" + | direct_declarator "(" ")" + ; + +pointer + : "*" + | "*" type_qualifier_list + | "*" pointer + | "*" type_qualifier_list pointer + ; + +type_qualifier_list + : type_qualifier + | type_qualifier_list type_qualifier + ; + + +parameter_type_list + : parameter_list + | parameter_list "," "..." + ; + +parameter_list + : parameter_declaration + | parameter_list "," parameter_declaration + ; + +parameter_declaration + : declaration_specifiers declarator + | declaration_specifiers abstract_declarator + | declaration_specifiers + ; + +identifier_list + : IDENTIFIER + | identifier_list "," IDENTIFIER + ; + +type_name + : specifier_qualifier_list + | specifier_qualifier_list abstract_declarator + ; + +abstract_declarator + : pointer + | direct_abstract_declarator + | pointer direct_abstract_declarator + ; + +direct_abstract_declarator + : "(" abstract_declarator ")" + | "[" "]" + | "[" constant_expression "]" + | direct_abstract_declarator "[" "]" + | direct_abstract_declarator "[" constant_expression "]" + | "(" ")" + | "(" parameter_type_list ")" + | direct_abstract_declarator "(" ")" + | direct_abstract_declarator "(" parameter_type_list ")" + ; + +initializer + : assignment_expression + | "." 
IDENTIFIER "=" assignment_expression + | "[" assignment_expression "]" "=" assignment_expression + | "{" initializer_list "}" + | "{" initializer_list "," "}" + ; + +initializer_list + : initializer + | initializer_list "," initializer + ; + +statement + : labeled_statement + | compound_statement + | expression_statement + | selection_statement + | iteration_statement + | jump_statement + ; + +labeled_statement + : IDENTIFIER ":" statement + | "case" constant_expression ":" statement + | "default" ":" statement + ; + +compound_statement + : "{" "}" + | "{" statement_list "}" + ; + +declaration_list + : declaration + | declaration_list declaration + ; + +statement_or_declaration + : statement + | declaration + ; + +statement_list + : statement_or_declaration + | statement_list statement_or_declaration + ; + +expression_statement + : ";" + | expression ";" + ; + +for_decl + : expression_statement + | declaration + ; + +selection_statement + : "if" "(" expression ")" statement + | "if" "(" expression ")" statement "else" statement + | "switch" "(" expression ")" statement + ; + +iteration_statement + : "while" "(" expression ")" statement + | "do" statement "while" "(" expression ")" ";" + | "for" "(" for_decl expression_statement ")" statement + | "for" "(" for_decl expression_statement expression ")" statement + ; + +jump_statement + : "goto" IDENTIFIER ";" + | "continue" ";" + | "break" ";" + | "return" ";" + | "return" expression ";" + ; + +translation_unit + : external_declaration + | translation_unit external_declaration + ; + +external_declaration + : function_definition + | declaration + ; + +function_definition + : declaration_specifiers declarator declaration_list compound_statement + | declaration_specifiers declarator compound_statement + | declarator declaration_list compound_statement + | declarator compound_statement + ; + +%% diff --git a/aici_abi/grammars/sample.c b/aici_abi/grammars/sample.c new file mode 100644 index 00000000..97753eb0 --- /dev/null +++ b/aici_abi/grammars/sample.c @@ -0,0 +1,1245 @@ +#include "devs_internal.h" +#include "devs_objects.h" + +// #define LOG_TAG "obj" +#include "devs_logging.h" + +void devs_map_clear(devs_ctx_t *ctx, devs_map_t *map) { + if (map->data) { + devs_free(ctx, map->data); + map->data = NULL; + map->capacity = 0; + map->length = 0; + } +} + +static inline uint16_t *short_keys(devs_short_map_t *map) { + return (uint16_t *)(map->short_data + map->capacity); +} + +static value_t *lookup_short(devs_ctx_t *ctx, devs_short_map_t *map, uint16_t key) { + unsigned len = map->length; + uint16_t *keys = short_keys(map); + for (unsigned i = 0; i < len; i++) { + if (keys[i] == key) { + return &map->short_data[i]; + } + } + return NULL; +} + +static value_t *lookup(devs_ctx_t *ctx, devs_map_t *map, value_t key) { + if (!devs_is_string(ctx, key)) + return NULL; + + value_t *data = map->data; + uint32_t kh = devs_handle_value(key); + unsigned len2 = map->length * 2; + + // do a quick reference-only check + for (unsigned i = 0; i < len2; i += 2) { + // check the low bits first, since they are more likely to be different + if (devs_handle_value(data[i]) == kh && data[i].u64 == key.u64) { + return &data[i + 1]; + } + } + + // slow path - compare strings + unsigned ksz, csz; + const char *cp, *kp = devs_string_get_utf8(ctx, key, &ksz); + for (unsigned i = 0; i < len2; i += 2) { + cp = devs_string_get_utf8(ctx, data[i], &csz); + if (csz == ksz && memcmp(kp, cp, ksz) == 0) + return &data[i + 1]; + } + + // nothing found... 
+ return NULL; +} + +static value_t proto_value(devs_ctx_t *ctx, const devs_builtin_proto_entry_t *p) { + unsigned idx = p->builtin_idx; + if (idx <= DEVS_BUILTIN_OBJECT___MAX) + return devs_builtin_object_value(ctx, idx); + JD_ASSERT(idx >= DEVS_FIRST_BUILTIN_FUNCTION); + return devs_value_from_handle(DEVS_HANDLE_TYPE_STATIC_FUNCTION, idx); +} + +unsigned devs_maplike_iter(devs_ctx_t *ctx, devs_maplike_t *src, void *userdata, + devs_map_iter_cb_t cb) { + if (devs_is_service_spec(ctx, src)) { + // Object.keys() etc or debugger inspection on compiled spec + // return empty for now, do not crash + return 0; + } else if (devs_is_builtin_proto(src)) { + const devs_builtin_proto_t *proto = (const devs_builtin_proto_t *)src; + const devs_builtin_proto_entry_t *p = proto->entries; + while (p->builtin_string_id) { + if (cb) + cb(ctx, userdata, devs_builtin_string(p->builtin_string_id), proto_value(ctx, p)); + p++; + } + return p - proto->entries; + } else { + JD_ASSERT(devs_is_map(src)); + devs_map_t *srcmap = (devs_map_t *)src; + unsigned len = srcmap->length; + + if (cb != NULL) { + unsigned len2 = srcmap->length * 2; + value_t *data = srcmap->data; + for (unsigned i = 0; i < len2; i += 2) { + cb(ctx, userdata, data[i], data[i + 1]); + } + } + + if (devs_gc_tag(srcmap) == DEVS_GC_TAG_HALF_STATIC_MAP) + len += devs_maplike_iter(ctx, srcmap->proto, userdata, cb); + + return len; + } +} + +void devs_map_copy_into(devs_ctx_t *ctx, devs_map_t *dst, devs_maplike_t *src) { + devs_maplike_iter(ctx, src, dst, (devs_map_iter_cb_t)devs_map_set); +} + +struct kv_ctx { + unsigned dp; + bool keys; + devs_array_t *arr; +}; + +static void kv_add(devs_ctx_t *ctx, void *userdata, value_t k, value_t v) { + struct kv_ctx *acc = userdata; + acc->arr->data[acc->dp++] = acc->keys ? 
k : v; +} + +bool devs_maplike_is_map(devs_ctx_t *ctx, devs_maplike_t *src) { + if (src == NULL || devs_is_builtin_proto(src) || devs_is_service_spec(ctx, src)) + return false; + JD_ASSERT(devs_is_map(src)); + return true; +} + +void devs_maplike_keys_or_values(devs_ctx_t *ctx, devs_maplike_t *src, devs_array_t *arr, + bool keys) { + struct kv_ctx acc = { + .dp = arr->length, + .arr = arr, + .keys = keys, + }; + + unsigned len = devs_maplike_iter(ctx, src, NULL, NULL); + + if (devs_array_insert(ctx, arr, acc.dp, len) != 0) + return; + + devs_maplike_iter(ctx, src, &acc, kv_add); +} + +static int grow_len(int capacity) { + int newlen = capacity * 10 / 8; + if (newlen < 4) + newlen = 4; + return newlen; +} + +void devs_map_set(devs_ctx_t *ctx, devs_map_t *map, value_t key, value_t v) { + value_t *tmp = lookup(ctx, map, key); + if (tmp != NULL) { + *tmp = v; + return; + } + + if (!devs_is_string(ctx, key)) { + devs_throw_expecting_error(ctx, DEVS_BUILTIN_STRING_STRING, key); + return; + } + + JD_ASSERT(map->capacity >= map->length); + + if (map->capacity == map->length) { + int newlen = grow_len(map->capacity); + tmp = devs_try_alloc(ctx, newlen * (2 * sizeof(value_t))); + if (!tmp) + return; + map->capacity = newlen; + if (map->length) { + memcpy(tmp, map->data, map->length * sizeof(value_t) * 2); + } + map->data = tmp; + jd_gc_unpin(ctx->gc, tmp); + } + + map->data[map->length * 2] = key; + map->data[map->length * 2 + 1] = v; + map->length++; +} + +void devs_short_map_set(devs_ctx_t *ctx, devs_short_map_t *map, uint16_t key, value_t v) { + value_t *tmp = lookup_short(ctx, map, key); + if (tmp != NULL) { + *tmp = v; + return; + } + + JD_ASSERT(map->capacity >= map->length); + + if (map->capacity == map->length) { + int newlen = grow_len(map->capacity); + tmp = devs_try_alloc(ctx, newlen * (sizeof(value_t) + sizeof(uint16_t))); + if (!tmp) + return; + uint16_t *srckeys = short_keys(map); + map->capacity = newlen; + if (map->length) { + memcpy(tmp, map->short_data, map->length * sizeof(value_t)); + memcpy(tmp + newlen, srckeys, map->length * sizeof(uint16_t)); + } + map->short_data = tmp; + jd_gc_unpin(ctx->gc, tmp); + } + + map->short_data[map->length] = v; + short_keys(map)[map->length] = key; + map->length++; +} + +int devs_map_delete(devs_ctx_t *ctx, devs_map_t *map, value_t key) { + value_t *tmp = lookup(ctx, map, key); + if (tmp == NULL) { + return -1; + } + + tmp--; + unsigned off = tmp - map->data; + unsigned trailing = map->length - off / 2 - 1; + map->length--; + if (trailing) + memmove(tmp, tmp + 2, trailing * 2 * sizeof(value_t)); + return 0; +} + +bool devs_is_service_spec(devs_ctx_t *ctx, const void *ptr) { + return (uintptr_t)((const uint8_t *)ptr - + (const uint8_t *)devs_img_get_service_spec(ctx->img, 0)) < + (sizeof(devs_service_spec_t) * ctx->img.header->num_service_specs); +} + +value_t devs_map_get(devs_ctx_t *ctx, devs_map_t *map, value_t key) { + value_t *tmp = lookup(ctx, map, key); + if (tmp == NULL) + return devs_undefined; + return *tmp; +} + +value_t devs_short_map_get(devs_ctx_t *ctx, devs_short_map_t *map, uint16_t key) { + value_t *tmp = lookup_short(ctx, map, key); + if (tmp == NULL) + return devs_undefined; + return *tmp; +} + +static const devs_builtin_proto_t *get_static_built_in_proto(devs_ctx_t *ctx, unsigned idx) { + JD_ASSERT(idx <= DEVS_BUILTIN_OBJECT___MAX); + if (devs_builtin_protos[idx].entries == NULL) + return NULL; // not there? 
+ return &devs_builtin_protos[idx]; +} + +static const uint8_t builtin_proto_idx[] = { + [DEVS_BUILTIN_OBJECT_MATH] = 1, + [DEVS_BUILTIN_OBJECT_BUFFER_PROTOTYPE] = 2, + [DEVS_BUILTIN_OBJECT_ARRAY_PROTOTYPE] = 3, + [DEVS_BUILTIN_OBJECT_STRING_PROTOTYPE] = 4, + [DEVS_BUILTIN_OBJECT_DSREGISTER_PROTOTYPE] = 5, + [DEVS_BUILTIN_OBJECT_DSROLE_PROTOTYPE] = 6, + [DEVS_BUILTIN_OBJECT_DSEVENT_PROTOTYPE] = 7, + [DEVS_BUILTIN_OBJECT_DEVICESCRIPT] = 8, + [DEVS_BUILTIN_OBJECT_IMAGE_PROTOTYPE] = 9, + [DEVS_BUILTIN_OBJECT_BUFFER] = 10, + [DEVS_BUILTIN_OBJECT_GPIO_PROTOTYPE] = 11, + [DEVS_BUILTIN_OBJECT_GPIO] = 12, +}; +#define MAX_PROTO 12 + +devs_maplike_t *devs_get_builtin_object(devs_ctx_t *ctx, unsigned idx) { + if (idx < sizeof(builtin_proto_idx)) { + unsigned midx = builtin_proto_idx[idx]; + if (midx > 0) { + midx--; + if (ctx->_builtin_protos == NULL) { + ctx->_builtin_protos = devs_try_alloc(ctx, sizeof(void *) * MAX_PROTO); + ctx->_num_builtin_protos = MAX_PROTO; + if (ctx->_builtin_protos == NULL) + return NULL; // whoops + } + JD_ASSERT(midx < MAX_PROTO); + devs_map_t *m = ctx->_builtin_protos[midx]; + if (m == NULL) { + m = devs_any_try_alloc(ctx, DEVS_GC_TAG_HALF_STATIC_MAP, sizeof(devs_map_t)); + if (m != NULL) { + ctx->_builtin_protos[midx] = m; + m->proto = (devs_maplike_t *)get_static_built_in_proto(ctx, idx); + } + } + return (devs_maplike_t *)m; + } + } + + return (devs_maplike_t *)get_static_built_in_proto(ctx, idx); +} + +bool devs_static_streq(devs_ctx_t *ctx, unsigned stridx, const char *other, unsigned other_len) { + unsigned size; + const char *r = devs_img_get_utf8(ctx->img, stridx, &size); + if (other_len != size) + return false; + return memcmp(r, other, size) == 0; +} + +#define MAX_OFF_BITS (DEVS_PACK_SHIFT - DEVS_ROLE_BITS) + +value_t devs_value_from_service_spec_idx(devs_ctx_t *ctx, unsigned idx) { + return devs_value_from_handle(DEVS_HANDLE_TYPE_ROLE_MEMBER, + DEVS_ROLE_INVALID | (idx << DEVS_ROLE_BITS)); +} + +value_t devs_value_from_service_spec(devs_ctx_t *ctx, const devs_service_spec_t *spec) { + unsigned idx = spec - devs_img_get_service_spec(ctx->img, 0); + JD_ASSERT(idx < ctx->img.header->num_service_specs); + return devs_value_from_service_spec_idx(ctx, idx); +} + +value_t devs_value_from_packet_spec(devs_ctx_t *ctx, const devs_packet_spec_t *pkt) { + if (pkt == NULL) + return devs_undefined; + const uint32_t *baseoff = (const void *)devs_img_get_service_spec(ctx->img, 0); + uintptr_t off = (const uint32_t *)pkt - baseoff; + JD_ASSERT(off < (1 << MAX_OFF_BITS)); + return devs_value_from_handle(DEVS_HANDLE_TYPE_ROLE_MEMBER, + DEVS_ROLE_INVALID | (off << DEVS_ROLE_BITS)); +} + +int devs_value_to_service_spec_idx(devs_ctx_t *ctx, value_t v) { + if (devs_handle_type(v) != DEVS_HANDLE_TYPE_ROLE_MEMBER) + return -1; + unsigned off = devs_handle_value(v) >> DEVS_ROLE_BITS; + if (off < ctx->img.header->num_service_specs) + return off; + return -1; +} + +const devs_service_spec_t *devs_value_to_service_spec(devs_ctx_t *ctx, value_t v) { + int off = devs_value_to_service_spec_idx(ctx, v); + if (off < 0) + return NULL; + return devs_img_get_service_spec(ctx->img, off); +} + +const devs_packet_spec_t *devs_decode_role_packet(devs_ctx_t *ctx, value_t v, unsigned *roleidx) { + if (roleidx) + *roleidx = DEVS_ROLE_INVALID; + if (devs_handle_type(v) != DEVS_HANDLE_TYPE_ROLE_MEMBER) + return NULL; + if (devs_value_to_service_spec(ctx, v)) + return NULL; + uint32_t h = devs_handle_value(v); + if (roleidx) + *roleidx = h & DEVS_ROLE_MASK; + return devs_img_get_packet_spec(ctx->img, h 
>> DEVS_ROLE_BITS); +} + +int devs_spec_idx(devs_ctx_t *ctx, const devs_service_spec_t *spec) { + if (spec == NULL) + return -1; + unsigned idx = spec - devs_img_get_service_spec(ctx->img, 0); + JD_ASSERT(idx < ctx->img.header->num_service_specs); + return idx; +} + +const devs_service_spec_t *devs_role_spec_for_class(devs_ctx_t *ctx, uint32_t cls) { + for (unsigned i = 0; i < ctx->img.header->num_service_specs; ++i) { + const devs_service_spec_t *spec = devs_img_get_service_spec(ctx->img, i); + if (spec->service_class == cls) + return spec; + } + return NULL; +} + +int devs_packet_spec_parent(devs_ctx_t *ctx, const devs_packet_spec_t *pspec) { + int off = (uint8_t *)pspec - ctx->img.data - ctx->img.header->service_specs.start; + for (unsigned i = 0; i < ctx->img.header->num_service_specs; ++i) { + const devs_service_spec_t *spec = devs_img_get_service_spec(ctx->img, i); + int idx = off - 4 * spec->packets_offset; + if (0 <= idx && idx < (int)(spec->num_packets * sizeof(devs_packet_spec_t))) + return i; + } + JD_PANIC(); + return -1; +} + +const devs_service_spec_t *devs_role_spec(devs_ctx_t *ctx, unsigned roleidx) { + if (roleidx >= DEVS_ROLE_FIRST_SPEC) { + unsigned specidx = roleidx - DEVS_ROLE_FIRST_SPEC; + if (specidx >= ctx->img.header->num_service_specs) + return NULL; + return devs_img_get_service_spec(ctx->img, specidx); + } + + devs_role_t *r = devs_role(ctx, roleidx); + + if (!r) + return NULL; + + return devs_role_spec_for_class(ctx, r->jdrole->service_class); +} + +devs_role_t *devs_role_or_fail(devs_ctx_t *ctx, unsigned roleidx) { + devs_role_t *r = devs_role(ctx, roleidx); + if (r == NULL) + devs_invalid_program(ctx, 60130); + return r; +} + +jd_device_service_t *devs_role_service(devs_ctx_t *ctx, unsigned roleidx) { + devs_role_t *r = devs_role(ctx, roleidx); + if (r == NULL) + return NULL; + return r->jdrole->service; +} + +const char *devs_role_name(devs_ctx_t *ctx, unsigned idx) { + devs_role_t *r = devs_role(ctx, idx); + if (r == NULL) + return "???"; + return r->jdrole->name; +} + +const devs_service_spec_t *devs_get_base_spec(devs_ctx_t *ctx, const devs_service_spec_t *spec) { + if (spec->service_class == JD_SERVICE_CLASS_BASE) + return NULL; + int idx = spec->flags & DEVS_SERVICESPEC_FLAG_DERIVE_MASK; + JD_ASSERT(idx <= DEVS_SERVICESPEC_FLAG_DERIVE_LAST); + return devs_img_get_service_spec(ctx->img, idx); +} + +value_t devs_spec_lookup(devs_ctx_t *ctx, const devs_service_spec_t *spec, value_t key) { + while (spec) { + JD_ASSERT(devs_is_service_spec(ctx, spec)); + const devs_packet_spec_t *pkts = devs_img_get_packet_spec(ctx->img, spec->packets_offset); + unsigned num_packets = spec->num_packets; + + if (devs_handle_type(key) == DEVS_HANDLE_TYPE_IMG_BUFFERISH) { + unsigned kidx = devs_handle_value(key); + for (unsigned i = 0; i < num_packets; ++i) { + if (pkts[i].name_idx == kidx) + return devs_value_from_packet_spec(ctx, &pkts[i]); + } + } + + unsigned ksz; + const char *kptr = devs_string_get_utf8(ctx, key, &ksz); + if (ksz == 0) + return devs_undefined; + + for (unsigned i = 0; i < num_packets; ++i) { + if (devs_static_streq(ctx, pkts[i].name_idx, kptr, ksz)) + return devs_value_from_packet_spec(ctx, &pkts[i]); + } + + spec = devs_get_base_spec(ctx, spec); + } + + return devs_undefined; +} + +static value_t devs_proto_lookup(devs_ctx_t *ctx, const devs_builtin_proto_t *proto, value_t key) { + JD_ASSERT(devs_is_proto(proto)); + + while (proto) { + const devs_builtin_proto_entry_t *p = proto->entries; + + if (devs_handle_type(key) == DEVS_HANDLE_TYPE_IMG_BUFFERISH && 
+            (devs_handle_value(key) >> DEVS_STRIDX__SHIFT) == DEVS_STRIDX_BUILTIN) {
+            unsigned kidx = devs_handle_value(key) & ((1 << DEVS_STRIDX__SHIFT) - 1);
+            while (p->builtin_string_id) {
+                if (p->builtin_string_id == kidx)
+                    return proto_value(ctx, p);
+                p++;
+            }
+        } else {
+            unsigned ksz;
+            const char *kptr = devs_string_get_utf8(ctx, key, &ksz);
+            if (ksz != strlen(kptr))
+                return devs_undefined;
+            while (p->builtin_string_id) {
+                if (strcmp(devs_builtin_string_by_idx(p->builtin_string_id), kptr) == 0)
+                    return proto_value(ctx, p);
+                p++;
+            }
+        }
+
+        proto = proto->parent;
+    }
+
+    return devs_undefined;
+}
+
+static value_t devs_function_bind_alloc(devs_ctx_t *ctx, value_t obj, value_t fn) {
+    devs_bound_function_t *res =
+        devs_any_try_alloc(ctx, DEVS_GC_TAG_BOUND_FUNCTION, sizeof(devs_bound_function_t));
+    if (res == NULL)
+        return devs_undefined;
+
+    res->this_val = obj;
+    res->func = fn;
+    return devs_value_from_gc_obj(ctx, res);
+}
+
+static const devs_builtin_function_t *devs_get_property_desc(devs_ctx_t *ctx, value_t fn) {
+    int htp = devs_handle_type(fn);
+
+    if (htp != DEVS_HANDLE_TYPE_STATIC_FUNCTION)
+        return NULL;
+
+    unsigned fidx = devs_handle_value(fn);
+
+    int bltin = fidx - DEVS_FIRST_BUILTIN_FUNCTION;
+    if (bltin >= 0) {
+        JD_ASSERT(bltin < devs_num_builtin_functions);
+        const devs_builtin_function_t *h = &devs_builtin_functions[bltin];
+        if (h->flags & DEVS_BUILTIN_FLAG_IS_PROPERTY) {
+            JD_ASSERT(h->num_args == 0);
+            return h;
+        }
+    }
+
+    return NULL;
+}
+
+// if `fn` is a static function, return `(obj, fn)` tuple
+// if `fn` is a role member and `obj` is role, return (a different) `(obj, fn)` tuple
+// otherwise return `fn`
+// it may allocate an object for the tuple, but typically it doesn't
+value_t devs_function_bind(devs_ctx_t *ctx, value_t obj, value_t fn) {
+    int htp = devs_handle_type(fn);
+
+    if (htp == DEVS_HANDLE_TYPE_ROLE_MEMBER && devs_handle_type(obj) == DEVS_HANDLE_TYPE_ROLE &&
+        !devs_value_to_service_spec(ctx, fn)) {
+        uint32_t role = devs_handle_value(obj);
+        JD_ASSERT((role & DEVS_ROLE_MASK) == role);
+        role |= devs_handle_value(fn) & ~DEVS_ROLE_MASK;
+        return devs_value_from_handle(DEVS_HANDLE_TYPE_ROLE_MEMBER, role);
+    }
+
+    if (htp == DEVS_HANDLE_TYPE_CLOSURE)
+        return devs_function_bind_alloc(ctx, obj, fn);
+
+    if (htp != DEVS_HANDLE_TYPE_STATIC_FUNCTION)
+        return fn;
+
+    const devs_builtin_function_t *h = devs_get_property_desc(ctx, fn);
+    if (h)
+        return h->handler.prop(ctx, obj);
+
+    unsigned fidx = devs_handle_value(fn);
+    int otp = devs_handle_type(obj);
+
+    if (fidx <= 0xffff)
+        switch (otp) {
+        case DEVS_HANDLE_TYPE_SPECIAL:
+        case DEVS_HANDLE_TYPE_FIBER:
+        case DEVS_HANDLE_TYPE_ROLE:
+        case DEVS_HANDLE_TYPE_ROLE_MEMBER:
+        case DEVS_HANDLE_TYPE_STATIC_FUNCTION:
+        case DEVS_HANDLE_TYPE_IMG_BUFFERISH: {
+            uint32_t hv = devs_handle_value(obj);
+            JD_ASSERT((((uint32_t)otp << DEVS_PACK_SHIFT) >> DEVS_PACK_SHIFT) == (uint32_t)otp);
+            JD_ASSERT((hv >> DEVS_PACK_SHIFT) == 0);
+            JD_ASSERT(devs_handle_high_value(obj) == 0);
+            return devs_value_from_handle(DEVS_HANDLE_TYPE_BOUND_FUNCTION_STATIC | (fidx << 4),
+                                          (otp << DEVS_PACK_SHIFT) | hv);
+        }
+
+        case DEVS_HANDLE_TYPE_GC_OBJECT:
+            JD_ASSERT(devs_handle_high_value(obj) == 0);
+            return devs_value_from_handle(DEVS_HANDLE_TYPE_BOUND_FUNCTION | (fidx << 4),
+                                          devs_handle_value(obj));
+        }
+
+    return devs_function_bind_alloc(ctx, obj, fn);
+}
+
+value_t devs_make_closure(devs_ctx_t *ctx, devs_activation_t *closure, unsigned fnidx) {
+    JD_ASSERT(fnidx <= 0xffff);
+    return devs_value_from_pointer(ctx,
DEVS_HANDLE_TYPE_CLOSURE | (fnidx << 4), closure); +} + +static int devs_get_fnidx_core(devs_ctx_t *ctx, value_t src, value_t *this_val, + devs_activation_t **closure, int depth) { + *closure = NULL; + *this_val = devs_undefined; + + if (depth > 2) + return -1; + + uint32_t hv = devs_handle_value(src); + switch (devs_handle_type(src)) { + case DEVS_HANDLE_TYPE_STATIC_FUNCTION: + *this_val = devs_undefined; + return hv; + case DEVS_HANDLE_TYPE_BOUND_FUNCTION_STATIC: + *this_val = + devs_value_from_handle(hv >> DEVS_PACK_SHIFT, hv & ((1 << DEVS_PACK_SHIFT) - 1)); + return devs_handle_high_value(src); + case DEVS_HANDLE_TYPE_BOUND_FUNCTION: + *this_val = devs_value_from_handle(DEVS_HANDLE_TYPE_GC_OBJECT, hv); + return devs_handle_high_value(src); + case DEVS_HANDLE_TYPE_CLOSURE: + *closure = devs_handle_ptr_value(ctx, src); + return devs_handle_high_value(src); + case DEVS_HANDLE_TYPE_GC_OBJECT: { + devs_bound_function_t *bnd = devs_handle_ptr_value(ctx, src); + if (devs_gc_tag(bnd) == DEVS_GC_TAG_BOUND_FUNCTION) { + int r = devs_get_fnidx_core(ctx, bnd->func, this_val, closure, depth + 1); + *this_val = bnd->this_val; + return r; + } else { + return -1; + } + } + default: { + if (devs_is_nullish(src)) + return -1; + value_t f = devs_object_get_built_in_field(ctx, src, DEVS_BUILTIN_STRING___FUNC__); + if (devs_is_undefined(f)) + return -1; + else { + int r = devs_get_fnidx_core(ctx, f, this_val, closure, depth + 1); + *this_val = src; + return r; + } + } + } +} + +int devs_get_fnidx(devs_ctx_t *ctx, value_t src, value_t *this_val, devs_activation_t **closure) { + return devs_get_fnidx_core(ctx, src, this_val, closure, 0); +} + +#define ATTACH_RW 0x01 +#define ATTACH_ENUM 0x02 +#define ATTACH_DIRECT 0x04 + +static void throw_field_error_str(devs_ctx_t *ctx, unsigned attach_flags, const char *objdesc) { + const char *op = attach_flags & ATTACH_RW ? "setting" : "getting"; + char *objd = jd_strdup(objdesc); + + if (devs_is_undefined(ctx->diag_field)) + devs_throw_type_error(ctx, "%s fields of %s", op, objd); + else + devs_throw_type_error(ctx, "%s field '%s' of %s", op, devs_show_value(ctx, ctx->diag_field), + objd); + + jd_free(objd); +} + +static void throw_field_error(devs_ctx_t *ctx, unsigned attach_flags, value_t v) { + throw_field_error_str(ctx, attach_flags, devs_show_value(ctx, v)); +} + +static devs_maplike_t *devs_get_static_proto(devs_ctx_t *ctx, int tp, unsigned attach_flags) { + if ((attach_flags & (ATTACH_DIRECT | ATTACH_ENUM)) == ATTACH_ENUM) + return NULL; + + devs_maplike_t *r = devs_get_builtin_object(ctx, tp); + + // accessing prototype on static object - can't attach properties + if (attach_flags & ATTACH_RW) { + if (attach_flags & ATTACH_DIRECT) { + if (devs_is_builtin_proto(r)) { + throw_field_error_str(ctx, attach_flags, "a builtin frozen object"); + return NULL; + } else { + JD_ASSERT(devs_is_map(r)); + return r; + } + } else { + // note that in ES writing to string/... 
properties is no-op + // we make it an error + throw_field_error_str(ctx, attach_flags, "a primitive"); + return NULL; + } + } else { + return r; + } +} + +devs_map_t *devs_get_spec_proto(devs_ctx_t *ctx, uint32_t spec_idx) { + value_t r = devs_short_map_get(ctx, ctx->spec_protos, spec_idx); + if (!devs_is_undefined(r)) + return devs_value_to_gc_obj(ctx, r); + + devs_map_t *m = devs_any_try_alloc(ctx, DEVS_GC_TAG_HALF_STATIC_MAP, sizeof(devs_map_t)); + if (m == NULL) + return NULL; + value_t v = devs_value_from_gc_obj(ctx, m); + devs_value_pin(ctx, v); + m->proto = (const void *)devs_img_get_service_spec(ctx->img, spec_idx); + devs_short_map_set(ctx, ctx->spec_protos, spec_idx, v); + devs_value_unpin(ctx, v); + return m; +} + +devs_map_t *devs_get_role_proto(devs_ctx_t *ctx, unsigned roleidx) { + devs_role_t *r = devs_role(ctx, roleidx); + if (!r) + return NULL; + + const devs_service_spec_t *spec = devs_role_spec_for_class(ctx, r->jdrole->service_class); + int idx = devs_spec_idx(ctx, spec); + if (idx < 0) + return NULL; // ??? + + return devs_get_spec_proto(ctx, idx); +} + +static devs_maplike_t *devs_object_get_attached(devs_ctx_t *ctx, value_t v, unsigned attach_flags) { + static const uint8_t proto_by_object_type[] = { + [DEVS_OBJECT_TYPE_NUMBER] = DEVS_BUILTIN_OBJECT_NUMBER_PROTOTYPE, + [DEVS_OBJECT_TYPE_FIBER] = DEVS_BUILTIN_OBJECT_DSFIBER_PROTOTYPE, + [DEVS_OBJECT_TYPE_ROLE] = DEVS_BUILTIN_OBJECT_DSROLE_PROTOTYPE, + [DEVS_OBJECT_TYPE_FUNCTION] = DEVS_BUILTIN_OBJECT_FUNCTION_PROTOTYPE, + [DEVS_OBJECT_TYPE_STRING] = DEVS_BUILTIN_OBJECT_STRING_PROTOTYPE, + [DEVS_OBJECT_TYPE_BUFFER] = DEVS_BUILTIN_OBJECT_BUFFER_PROTOTYPE, + [DEVS_OBJECT_TYPE_IMAGE] = DEVS_BUILTIN_OBJECT_IMAGE_PROTOTYPE, + [DEVS_OBJECT_TYPE_BOOL] = DEVS_BUILTIN_OBJECT_BOOLEAN_PROTOTYPE, + [DEVS_OBJECT_TYPE_EXOTIC] = DEVS_BUILTIN_OBJECT_OBJECT_PROTOTYPE, + }; + + if (devs_is_null_or_undefined(v)) { + throw_field_error(ctx, attach_flags, v); + return NULL; + } + + int htp = devs_handle_type(v); + + if (htp == DEVS_HANDLE_TYPE_ROLE_MEMBER) { + unsigned roleidx; + int pt; + const devs_packet_spec_t *spec = devs_decode_role_packet(ctx, v, &roleidx); + if (roleidx == DEVS_ROLE_INVALID) + pt = devs_value_to_service_spec(ctx, v) ? 
DEVS_BUILTIN_OBJECT_DSSERVICESPEC_PROTOTYPE + : DEVS_BUILTIN_OBJECT_DSPACKETSPEC_PROTOTYPE; + else + switch (spec->code & DEVS_PACKETSPEC_CODE_MASK) { + case DEVS_PACKETSPEC_CODE_REGISTER: + pt = DEVS_BUILTIN_OBJECT_DSREGISTER_PROTOTYPE; + break; + case DEVS_PACKETSPEC_CODE_EVENT: + pt = DEVS_BUILTIN_OBJECT_DSEVENT_PROTOTYPE; + break; + case DEVS_PACKETSPEC_CODE_COMMAND: + pt = DEVS_BUILTIN_OBJECT_DSCOMMAND_PROTOTYPE; + break; + case DEVS_PACKETSPEC_CODE_REPORT: + pt = DEVS_BUILTIN_OBJECT_DSREPORT_PROTOTYPE; + break; + default: + JD_PANIC(); + } + return devs_get_static_proto(ctx, pt, attach_flags); + } + + if (htp == DEVS_HANDLE_TYPE_ROLE) { + unsigned roleidx = devs_handle_value(v); + devs_role_t *rl = devs_role(ctx, roleidx); + if (!rl) + return NULL; + const void *r = rl->attached; + if (r || (attach_flags & ATTACH_ENUM)) + return r; + r = devs_get_role_proto(ctx, roleidx); + if (!r) + return NULL; + if (attach_flags & ATTACH_RW) { + devs_map_t *m = devs_map_try_alloc(ctx, r); + rl->attached = m; + r = m; + } + return r; + } + + if (htp != DEVS_HANDLE_TYPE_GC_OBJECT) { + int pt = 0; + int tp = devs_value_typeof(ctx, v); + if (tp == DEVS_OBJECT_TYPE_MAP && devs_is_special(v)) { + uint32_t hv = devs_handle_value(v); + if (devs_handle_is_builtin(hv)) + return devs_get_static_proto(ctx, hv - DEVS_SPECIAL_BUILTIN_OBJ_FIRST, + attach_flags | ATTACH_DIRECT); + } + if (tp == DEVS_OBJECT_TYPE_FUNCTION) { + value_t this_val; + devs_activation_t *closure; + int fidx = devs_get_fnidx(ctx, v, &this_val, &closure); + if (fidx >= 0) { + value_t r = devs_short_map_get(ctx, ctx->fn_values, fidx); + if (devs_is_undefined(r) && attach_flags) { + r = devs_value_from_gc_obj( + ctx, + devs_map_try_alloc(ctx, devs_get_builtin_object( + ctx, DEVS_BUILTIN_OBJECT_FUNCTION_PROTOTYPE))); + if (!devs_is_undefined(r)) { + devs_value_pin(ctx, r); + devs_short_map_set(ctx, ctx->fn_values, fidx, r); + devs_value_unpin(ctx, r); + } + } + if (!devs_is_undefined(r)) + return devs_value_to_gc_obj(ctx, r); + } + } + if (tp < (int)sizeof(proto_by_object_type)) { + pt = proto_by_object_type[tp]; + } + JD_ASSERT(pt != 0); + return devs_get_static_proto(ctx, pt, attach_flags); + } + + devs_gc_object_t *obj = devs_handle_ptr_value(ctx, v); + + devs_map_t **attached; + int builtin; + + switch (devs_gc_tag(obj)) { + case DEVS_GC_TAG_BUFFER: + attached = &((devs_buffer_t *)obj)->attached; + builtin = DEVS_BUILTIN_OBJECT_BUFFER_PROTOTYPE; + break; + case DEVS_GC_TAG_IMAGE: + attached = &((devs_gimage_t *)obj)->attached; + builtin = DEVS_BUILTIN_OBJECT_IMAGE_PROTOTYPE; + break; + case DEVS_GC_TAG_ARRAY: + attached = &((devs_array_t *)obj)->attached; + builtin = DEVS_BUILTIN_OBJECT_ARRAY_PROTOTYPE; + break; + case DEVS_GC_TAG_PACKET: + attached = &((devs_packet_t *)obj)->attached; + builtin = DEVS_BUILTIN_OBJECT_DSPACKET_PROTOTYPE; + break; + case DEVS_GC_TAG_HALF_STATIC_MAP: + case DEVS_GC_TAG_MAP: + return (devs_maplike_t *)obj; + case DEVS_GC_TAG_STRING_JMP: + case DEVS_GC_TAG_STRING: + return devs_get_static_proto(ctx, DEVS_BUILTIN_OBJECT_STRING_PROTOTYPE, attach_flags); + case DEVS_GC_TAG_BOUND_FUNCTION: + return devs_get_static_proto(ctx, DEVS_BUILTIN_OBJECT_FUNCTION_PROTOTYPE, attach_flags); + case DEVS_GC_TAG_BUILTIN_PROTO: + case DEVS_GC_TAG_SHORT_MAP: + default: + JD_PANIC(); + break; + } + + devs_map_t *map = *attached; + + if (!map && (attach_flags & ATTACH_RW)) { + map = *attached = devs_map_try_alloc(ctx, devs_get_builtin_object(ctx, builtin)); + if (map == NULL) + return NULL; + } + + if (map || (attach_flags & 
ATTACH_ENUM)) + return (devs_maplike_t *)map; + else + return devs_get_builtin_object(ctx, builtin); +} + +devs_map_t *devs_object_get_attached_rw(devs_ctx_t *ctx, value_t v) { + const void *r = devs_object_get_attached(ctx, v, ATTACH_RW); + JD_ASSERT(r == NULL || devs_is_map(r)); + ctx->diag_field = devs_undefined; + return (void *)r; +} + +devs_maplike_t *devs_object_get_attached_ro(devs_ctx_t *ctx, value_t v) { + devs_maplike_t *r = devs_object_get_attached(ctx, v, 0); + ctx->diag_field = devs_undefined; + return r; +} + +devs_maplike_t *devs_object_get_attached_enum(devs_ctx_t *ctx, value_t v) { + devs_maplike_t *r = devs_object_get_attached(ctx, v, ATTACH_ENUM); + ctx->diag_field = devs_undefined; + return r; +} + +devs_maplike_t *devs_maplike_get_proto(devs_ctx_t *ctx, devs_maplike_t *obj) { + const void *res; + + if (devs_is_builtin_proto(obj)) { + res = ((const devs_builtin_proto_t *)obj)->parent; + } else if (devs_is_service_spec(ctx, obj)) { + res = devs_get_builtin_object(ctx, DEVS_BUILTIN_OBJECT_DSROLE_PROTOTYPE); + } else { + JD_ASSERT(devs_is_map(obj)); + devs_map_t *map = (devs_map_t *)obj; + return map->proto; + } + + if (res == NULL) + res = devs_get_builtin_object(ctx, DEVS_BUILTIN_OBJECT_OBJECT_PROTOTYPE); + if (obj == res) // Object.prototype.__proto__ == NULL + return NULL; + return res; +} + +devs_maplike_t *devs_get_prototype_field(devs_ctx_t *ctx, value_t cls) { + value_t cls_proto_val = devs_object_get_built_in_field(ctx, cls, DEVS_BUILTIN_STRING_PROTOTYPE); + if (devs_is_undefined(cls_proto_val)) { + if (!ctx->in_throw) + devs_throw_type_error(ctx, "no .prototype"); + return NULL; + } else { + devs_maplike_t *cls_proto = devs_object_get_attached_enum(ctx, cls_proto_val); + if (cls_proto == NULL) + devs_throw_type_error(ctx, "invalid .prototype"); + return cls_proto; + } +} + +bool devs_instance_of(devs_ctx_t *ctx, value_t obj, devs_maplike_t *cls_proto) { + if (cls_proto == NULL || devs_is_nullish(obj)) + return false; + + devs_maplike_t *proto = devs_object_get_attached_ro(ctx, obj); + devs_maplike_t *en = devs_object_get_attached_enum(ctx, obj); + if (proto && proto == en) + proto = devs_maplike_get_proto(ctx, proto); + if (proto == NULL) + return false; + + while (proto) { + if (cls_proto == proto) + return true; + proto = devs_maplike_get_proto(ctx, proto); + } + + return false; +} + +value_t devs_maplike_get_no_bind(devs_ctx_t *ctx, devs_maplike_t *proto, value_t key) { + value_t ptmp, *tmp = NULL; + + while (proto) { + devs_map_t *map; + if (devs_is_builtin_proto(proto)) { + ptmp = devs_proto_lookup(ctx, (const devs_builtin_proto_t *)proto, key); + tmp = &ptmp; + break; + } else if (devs_is_service_spec(ctx, proto)) { + ptmp = devs_spec_lookup(ctx, (const devs_service_spec_t *)proto, key); + if (!devs_is_undefined(ptmp)) { + tmp = &ptmp; + break; + } else { + proto = devs_get_builtin_object(ctx, DEVS_BUILTIN_OBJECT_DSROLE_PROTOTYPE); + continue; + } + } else { + JD_ASSERT(devs_is_map(proto)); + map = (devs_map_t *)proto; + tmp = lookup(ctx, map, key); + if (tmp) + break; + } + + proto = map->proto; + } + + if (tmp == NULL) + return devs_undefined; + return *tmp; +} + +value_t devs_object_get(devs_ctx_t *ctx, value_t obj, value_t key) { + ctx->diag_field = key; + value_t tmp = devs_maplike_get_no_bind(ctx, devs_object_get_attached_ro(ctx, obj), key); + return devs_function_bind(ctx, obj, tmp); +} + +value_t devs_object_get_built_in_field(devs_ctx_t *ctx, value_t obj, unsigned idx) { + value_t key = devs_builtin_string(idx); + ctx->diag_field = key; + value_t 
fn = devs_maplike_get_no_bind(ctx, devs_object_get_attached_ro(ctx, obj), key); + const devs_builtin_function_t *h = devs_get_property_desc(ctx, fn); + if (h) + return h->handler.prop(ctx, obj); + return fn; +} + +value_t devs_seq_get(devs_ctx_t *ctx, value_t seq, unsigned idx) { + if (idx > DEVS_MAX_ALLOC) + return devs_undefined; + + unsigned len; + const uint8_t *p = devs_bufferish_data(ctx, seq, &len); + if (p && idx < len) { + if (devs_is_string(ctx, seq)) { + int off = devs_string_index(ctx, seq, idx); + if (off < 0) + return devs_undefined; + p += off; + unsigned len = devs_utf8_code_point_length((const char *)p); + return devs_value_from_gc_obj(ctx, + devs_string_try_alloc_init(ctx, (const char *)p, len)); + } + return devs_value_from_int(p[idx]); + } + + devs_array_t *arr = devs_value_to_gc_obj(ctx, seq); + if (devs_gc_tag(arr) == DEVS_GC_TAG_ARRAY) { + if (idx < arr->length) + return arr->data[idx]; + } + + return devs_undefined; +} + +bool devs_looks_indexable(devs_ctx_t *ctx, value_t seq) { + return devs_is_array(ctx, seq) || devs_is_buffer(ctx, seq) || devs_is_string(ctx, seq); +} + +value_t devs_any_get(devs_ctx_t *ctx, value_t obj, value_t key) { + if (devs_is_number(key) && devs_looks_indexable(ctx, obj)) { + unsigned idx = devs_value_to_int(ctx, key); + return devs_seq_get(ctx, obj, idx); + } else if (devs_is_string(ctx, key)) { + return devs_object_get(ctx, obj, key); + } else { + key = devs_value_to_string(ctx, key); + devs_value_pin(ctx, key); + value_t res = devs_object_get(ctx, obj, key); + devs_value_unpin(ctx, key); + return res; + } +} + +void devs_any_set(devs_ctx_t *ctx, value_t obj, value_t key, value_t v) { + if (devs_is_number(key) && devs_looks_indexable(ctx, obj)) { + unsigned idx = devs_value_to_int(ctx, key); + devs_seq_set(ctx, obj, idx, v); + } else { + ctx->diag_field = key; + devs_map_t *map = devs_object_get_attached_rw(ctx, obj); + if (!map) + return; + if (devs_is_string(ctx, key)) + devs_map_set(ctx, map, key, v); + else { + key = devs_value_to_string(ctx, key); + devs_value_pin(ctx, key); + devs_map_set(ctx, map, key, v); + devs_value_unpin(ctx, key); + } + } +} + +static int array_ensure_len(devs_ctx_t *ctx, devs_array_t *arr, unsigned newlen) { + if (arr->capacity < newlen) { + newlen = grow_len(newlen); + value_t *newarr = devs_try_alloc(ctx, newlen * sizeof(value_t)); + if (newarr == NULL) + return -1; + if (arr->data) + memcpy(newarr, arr->data, sizeof(value_t) * arr->length); + arr->data = newarr; + arr->capacity = newlen; + jd_gc_unpin(ctx->gc, newarr); + } + return 0; +} + +void devs_array_set(devs_ctx_t *ctx, devs_array_t *arr, unsigned idx, value_t v) { + if (idx > DEVS_MAX_ALLOC / sizeof(value_t)) + devs_throw_too_big_error(ctx, DEVS_BUILTIN_STRING_ARRAY); + else { + if (array_ensure_len(ctx, arr, idx + 1) != 0) + return; + arr->data[idx] = v; + if (idx >= arr->length) + arr->length = idx + 1; + } +} + +void devs_array_pin_push(devs_ctx_t *ctx, devs_array_t *arr, value_t v) { + devs_value_pin(ctx, v); + devs_array_set(ctx, arr, arr->length, v); + devs_value_unpin(ctx, v); +} + +void devs_seq_set(devs_ctx_t *ctx, value_t seq, unsigned idx, value_t v) { + if (idx > DEVS_MAX_ALLOC) { + devs_throw_too_big_error(ctx, DEVS_BUILTIN_STRING_ARRAY); + } else if (devs_buffer_is_writable(ctx, seq)) { + unsigned len; + uint8_t *p = devs_buffer_data(ctx, seq, &len); + if (idx < len) { + p[idx] = devs_value_to_int(ctx, v) & 0xff; + } else { + devs_throw_range_error(ctx, "buffer write at %u, len=%u", idx, len); + } + } else { + devs_array_t *arr = 
devs_value_to_gc_obj(ctx, seq); + if (devs_gc_tag(arr) == DEVS_GC_TAG_ARRAY) { + devs_array_set(ctx, arr, idx, v); + } else { + devs_throw_expecting_error(ctx, DEVS_BUILTIN_STRING_ARRAY, seq); + } + } +} + +int devs_array_insert(devs_ctx_t *ctx, devs_array_t *arr, unsigned idx, int count) { + if (count > (int)(DEVS_MAX_ALLOC / sizeof(value_t))) { + devs_throw_too_big_error(ctx, DEVS_BUILTIN_STRING_ARRAY); + return -4; + } + + int newlen = arr->length + count; + if (newlen < 0) { + count = -arr->length; + newlen = 0; + } + + if (count == 0) + return 0; + + if (newlen > (int)(DEVS_MAX_ALLOC / sizeof(value_t))) { + devs_throw_too_big_error(ctx, DEVS_BUILTIN_STRING_ARRAY); + return -6; + } + + if (idx > arr->length) + idx = arr->length; + + if (array_ensure_len(ctx, arr, newlen)) + return -5; + + unsigned trailing = arr->length - idx; + + if (count < 0) { + count = -count; + memmove(arr->data + idx, arr->data + idx + count, sizeof(value_t) * (trailing - count)); + } else { + memmove(arr->data + idx + count, arr->data + idx, sizeof(value_t) * trailing); + memset(arr->data + idx, 0, count * sizeof(value_t)); + } + arr->length = newlen; + + return 0; +} + +int32_t devs_arg_int_defl(devs_ctx_t *ctx, unsigned idx, int32_t defl) { + value_t arg = devs_arg(ctx, idx); + if (devs_is_null_or_undefined(arg)) + return defl; + return devs_value_to_int(ctx, arg); +} + +int32_t devs_arg_int(devs_ctx_t *ctx, unsigned idx) { + return devs_value_to_int(ctx, devs_arg(ctx, idx)); +} + +bool devs_arg_bool(devs_ctx_t *ctx, unsigned idx) { + return devs_value_to_bool(ctx, devs_arg(ctx, idx)); +} + +double devs_arg_double(devs_ctx_t *ctx, unsigned idx) { + return devs_value_to_double(ctx, devs_arg(ctx, idx)); +} + +const char *devs_arg_utf8_with_conv(devs_ctx_t *ctx, unsigned idx, unsigned *sz) { + // store it on the stack, so it doesn't get GCed + ctx->the_stack[idx + 1] = devs_value_to_string(ctx, devs_arg(ctx, idx)); + return devs_string_get_utf8(ctx, devs_arg(ctx, idx), sz); +} + +void devs_ret_double(devs_ctx_t *ctx, double v) { + devs_ret(ctx, devs_value_from_double(v)); +} + +void devs_ret_int(devs_ctx_t *ctx, int v) { + devs_ret(ctx, devs_value_from_int(v)); +} + +void devs_ret_bool(devs_ctx_t *ctx, bool v) { + devs_ret(ctx, devs_value_from_bool(v)); +} + +void devs_ret_gc_ptr(devs_ctx_t *ctx, void *v) { + devs_ret(ctx, devs_value_from_gc_obj(ctx, v)); +} + +devs_map_t *devs_arg_self_map(devs_ctx_t *ctx) { + value_t s = devs_arg_self(ctx); + void *p = devs_value_to_gc_obj(ctx, s); + if (devs_is_map(p)) + return p; + devs_throw_type_error(ctx, "object expected"); + return NULL; +} + +void devs_setup_resume(devs_fiber_t *f, devs_resume_cb_t cb, void *userdata) { + if (devs_did_yield(f->ctx)) { + f->resume_cb = cb; + f->resume_data = userdata; + } else { + cb(f->ctx, userdata); + } +} + +bool devs_can_attach(devs_ctx_t *ctx, value_t v) { + switch (devs_value_typeof(ctx, v)) { + case DEVS_OBJECT_TYPE_MAP: + case DEVS_OBJECT_TYPE_ROLE: + case DEVS_OBJECT_TYPE_ARRAY: + case DEVS_OBJECT_TYPE_BUFFER: + case DEVS_OBJECT_TYPE_IMAGE: + return true; + default: + return false; + } +} + +value_t devs_builtin_object_value(devs_ctx_t *ctx, unsigned idx) { + if (idx > DEVS_BUILTIN_OBJECT___MAX) + return devs_undefined; + + devs_maplike_t *p = devs_get_builtin_object(ctx, idx); + if (devs_is_builtin_proto(p)) + return devs_value_from_handle(DEVS_HANDLE_TYPE_SPECIAL, + DEVS_SPECIAL_BUILTIN_OBJ_FIRST + idx); + else + return devs_value_from_gc_obj(ctx, (void *)p); +} + +value_t devs_maplike_to_value(devs_ctx_t *ctx, 
devs_maplike_t *obj) {
+    if (devs_is_builtin_proto(obj)) {
+        return devs_builtin_object_value(ctx,
+                                         (const devs_builtin_proto_t *)obj - devs_builtin_protos);
+    } else if (devs_is_service_spec(ctx, obj)) {
+        // this shouldn't happen
+        return devs_undefined;
+    } else {
+        JD_ASSERT(devs_is_map(obj));
+        devs_map_t *map = (devs_map_t *)obj;
+        if (devs_gc_tag(map) == DEVS_GC_TAG_HALF_STATIC_MAP && devs_is_builtin_proto(map->proto))
+            return devs_maplike_to_value(ctx, map->proto);
+        return devs_value_from_gc_obj(ctx, map);
+    }
+}
\ No newline at end of file
diff --git a/aici_abi/src/cfg.rs b/aici_abi/src/cfg.rs
index 695220bf..5a1cc4ef 100644
--- a/aici_abi/src/cfg.rs
+++ b/aici_abi/src/cfg.rs
@@ -480,9 +480,9 @@ impl Recognizer for CfgParser {
 
 #[allow(dead_code)]
 pub fn cfg_test() -> Result<()> {
-    let yacc_bytes = include_bytes!("../../grammars/c.y");
+    let yacc_bytes = include_bytes!("../grammars/c.y");
     let mut cfg = CfgParser::from_yacc(&String::from_utf8_lossy(yacc_bytes)).unwrap();
-    let sample = include_bytes!("../../grammars/sample.c");
+    let sample = include_bytes!("../grammars/sample.c");
 
     if true {
         let trie = TokTrie::from_host();

From 4724bd35a126221004e81b67caa5260b4bdba5ab Mon Sep 17 00:00:00 2001
From: Michal Moskal
Date: Thu, 11 Jan 2024 18:27:20 +0000
Subject: [PATCH 129/301] AiciVm -> AiciCtrl

---
 aici_abi/README.md         | 6 +++---
 aici_abi/src/lib.rs        | 2 +-
 aici_abi/src/recognizer.rs | 4 ++--
 aici_abi/src/uppercase.rs  | 4 ++--
 aici_abi/src/yesno.rs      | 4 ++--
 5 files changed, 10 insertions(+), 10 deletions(-)

diff --git a/aici_abi/README.md b/aici_abi/README.md
index 78a4e7ad..bda556b8 100644
--- a/aici_abi/README.md
+++ b/aici_abi/README.md
@@ -11,7 +11,7 @@ Conceptually, the lowest level interface to AICI constraint is this:
 type TokenId = u32;
 type SeqId = u32;
 
-trait AiciVm {
+trait AiciCtrl {
     /// Called with the initial prompt. ~1000ms time limit.
     fn init_prompt(prompt: Vec<TokenId>);
 
@@ -41,7 +41,7 @@ The actual binary interface is a bit more complicated, due to limitations in
 passing values to and from Wasm. A Wasm module instance is created for each token sequence. Also, when the
 sequence forks (as in beam search), the module instance is cloned.
 
-See the [AiciVm Rust trait](aici_abi/src/lib.rs) for details.
+See the [AiciCtrl Rust trait](aici_abi/src/lib.rs) for details.
 
 A number of functions are exposed to the Wasm module.
 
@@ -111,7 +111,7 @@ pub trait Recognizer {
 }
 ```
 
-The `AiciRecognizer` struct converts `Recognizer` to `AiciVm`.
+The `AiciRecognizer` struct converts `Recognizer` to `AiciCtrl`.
 
 ## Functional byte interface
 
diff --git a/aici_abi/src/lib.rs b/aici_abi/src/lib.rs
index 73b6fd50..97c0307b 100644
--- a/aici_abi/src/lib.rs
+++ b/aici_abi/src/lib.rs
@@ -147,7 +147,7 @@ impl PreProcessResult {
     }
 }
 
-pub trait AiciVm {
+pub trait AiciCtrl {
     /// Called with the initial prompt. ~1000ms time limit.
     /// By default ignore prompt.
     fn init_prompt(&mut self, _arg: InitPromptArg) -> InitPromptResult {
diff --git a/aici_abi/src/recognizer.rs b/aici_abi/src/recognizer.rs
index c4ce4041..ec6d8767 100644
--- a/aici_abi/src/recognizer.rs
+++ b/aici_abi/src/recognizer.rs
@@ -1,6 +1,6 @@
 use crate::{
     toktree::{Recognizer, SpecialToken, TokTrie},
-    AiciVm, MidProcessArg, MidProcessResult, PostProcessArg, PostProcessResult,
+    AiciCtrl, MidProcessArg, MidProcessResult, PostProcessArg, PostProcessResult,
 };
 use std::fmt::Debug;
 
@@ -18,7 +18,7 @@ impl<R: Recognizer> AiciRecognizer<R> {
     }
 }
 
-impl<R: Recognizer> AiciVm for AiciRecognizer<R> {
+impl<R: Recognizer> AiciCtrl for AiciRecognizer<R> {
     fn mid_process(&mut self, _arg: MidProcessArg) -> MidProcessResult {
         let mut set = self.trie.alloc_token_set();
         self.trie.compute_bias(&mut self.rec, &mut set);
diff --git a/aici_abi/src/uppercase.rs b/aici_abi/src/uppercase.rs
index 87cc0f05..f37b7f8b 100644
--- a/aici_abi/src/uppercase.rs
+++ b/aici_abi/src/uppercase.rs
@@ -2,7 +2,7 @@ use aici_abi::{
     recognizer::{FunctionalRecognizer, StackRecognizer},
     tokenize,
     toktree::{SpecialToken, TokTrie},
-    AiciVm, InitPromptArg, InitPromptResult, MidProcessArg, MidProcessResult, PostProcessArg,
+    AiciCtrl, InitPromptArg, InitPromptResult, MidProcessArg, MidProcessResult, PostProcessArg,
     PostProcessResult, PreProcessArg, PreProcessResult,
 };
 
@@ -51,7 +51,7 @@ impl Runner {
     }
 }
 
-impl AiciVm for Runner {
+impl AiciCtrl for Runner {
     fn init_prompt(&mut self, arg: InitPromptArg) -> InitPromptResult {
         // with VMs, the prompt is often empty, but let's print it
         println!(
diff --git a/aici_abi/src/yesno.rs b/aici_abi/src/yesno.rs
index e593e684..36fede93 100644
--- a/aici_abi/src/yesno.rs
+++ b/aici_abi/src/yesno.rs
@@ -1,5 +1,5 @@
 use aici_abi::{
-    tokenize, toktree::TokTrie, AiciVm, InitPromptArg, InitPromptResult, MidProcessArg,
+    tokenize, toktree::TokTrie, AiciCtrl, InitPromptArg, InitPromptResult, MidProcessArg,
     MidProcessResult, PostProcessArg, PostProcessResult, PreProcessArg, PreProcessResult, TokenId,
 };
 
@@ -24,7 +24,7 @@ impl Runner {
     }
 }
 
-impl AiciVm for Runner {
+impl AiciCtrl for Runner {
     fn init_prompt(&mut self, arg: InitPromptArg) -> InitPromptResult {
         if arg.prompt.len() < 2 {
             // we'll be forcing answer; require a question

From 666d0c06be31ddaeff7b847e20ca9bb7b5ffdfef Mon Sep 17 00:00:00 2001
From: Michal Moskal
Date: Thu, 11 Jan 2024 18:41:25 +0000
Subject: [PATCH 130/301] vm -> controller in docs etc

---
 aici_abi/src/uppercase.rs | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/aici_abi/src/uppercase.rs b/aici_abi/src/uppercase.rs
index f37b7f8b..c97962dc 100644
--- a/aici_abi/src/uppercase.rs
+++ b/aici_abi/src/uppercase.rs
@@ -53,7 +53,7 @@ impl Runner {
 
 impl AiciCtrl for Runner {
     fn init_prompt(&mut self, arg: InitPromptArg) -> InitPromptResult {
-        // with VMs, the prompt is often empty, but let's print it
+        // when using AICI Controllers, the prompt is often empty, but let's print it
         println!(
             "init_prompt: {:?} {:?}",
             arg.prompt,

From f1bcf49b7aabf7f5e7201f96440d92d5630ba71a Mon Sep 17 00:00:00 2001
From: Michal Moskal
Date: Thu, 11 Jan 2024 19:24:51 +0000
Subject: [PATCH 131/301] separate sample project folder; readme movements

---
 aici_abi/Cargo.toml       |   5 --
 aici_abi/src/uppercase.rs | 109 --------------------------------------
 2 files changed, 114 deletions(-)
 delete mode 100644 aici_abi/src/uppercase.rs

diff --git a/aici_abi/Cargo.toml b/aici_abi/Cargo.toml
index 4b296c24..5e7e3ffe 100644
--- a/aici_abi/Cargo.toml
+++ b/aici_abi/Cargo.toml
@@ -23,11 +23,6 @@ default = ["cfg", "rx"]
 cfg = ["dep:cfgrammar", "dep:lrlex", "dep:lrpar", "dep:lrtable", "dep:vob", "dep:rustc-hash"]
"dep:lrlex", "dep:lrpar", "dep:lrtable", "dep:vob", "dep:rustc-hash"] rx = ["dep:regex-automata"] -[[bin]] -name = "uppercase" -path = "src/uppercase.rs" - - [[bin]] name = "yesno" path = "src/yesno.rs" \ No newline at end of file diff --git a/aici_abi/src/uppercase.rs b/aici_abi/src/uppercase.rs deleted file mode 100644 index c97962dc..00000000 --- a/aici_abi/src/uppercase.rs +++ /dev/null @@ -1,109 +0,0 @@ -use aici_abi::{ - recognizer::{FunctionalRecognizer, StackRecognizer}, - tokenize, - toktree::{SpecialToken, TokTrie}, - AiciCtrl, InitPromptArg, InitPromptResult, MidProcessArg, MidProcessResult, PostProcessArg, - PostProcessResult, PreProcessArg, PreProcessResult, -}; - -// This constraints enforces an upper case letter every 4th byte -// The state is the position in the output stream -struct QuadUpper {} -impl FunctionalRecognizer for QuadUpper { - fn initial(&self) -> usize { - 0 - } - - fn append(&self, state: usize, _byte: u8) -> usize { - state + 1 - } - - fn byte_allowed(&self, state: usize, byte: u8) -> bool { - if state % 4 == 0 { - byte.is_ascii_uppercase() - } else { - true - } - } - - fn special_allowed(&self, _state: usize, tok: SpecialToken) -> bool { - match tok { - SpecialToken::EndOfSentence => false, - _ => false, - } - } -} - -pub struct Runner { - toktrie: TokTrie, - tokens: Vec, - rec: StackRecognizer, -} - -impl Runner { - pub fn new(aici_arg: Vec) -> Self { - println!("user passed in {} bytes", aici_arg.len()); - Runner { - toktrie: TokTrie::from_host(), - tokens: Vec::new(), - rec: StackRecognizer::from(QuadUpper {}), - } - } -} - -impl AiciCtrl for Runner { - fn init_prompt(&mut self, arg: InitPromptArg) -> InitPromptResult { - // when using AICI Controllers, the prompt is often empty, but let's print it - println!( - "init_prompt: {:?} {:?}", - arg.prompt, - self.toktrie.decode_str(&arg.prompt) - ); - // result is currently empty - InitPromptResult::default() - } - - fn pre_process(&mut self, _arg: PreProcessArg) -> PreProcessResult { - if self.tokens.is_empty() { - // if no tokens yet, send our prompt - let toks = tokenize("Here's a tweet:\n"); - PreProcessResult::ff_tokens(toks) - } else { - // otherwise just continue - the other option is to suspend - PreProcessResult::continue_() - } - } - - fn mid_process(&mut self, _arg: MidProcessArg) -> MidProcessResult { - if self.tokens.len() > 50 { - // stop after 50 tokens - return MidProcessResult::Stop; - } - - // otherwise, compute bias according to our recognizer - let mut set = self.toktrie.alloc_token_set(); - self.toktrie.compute_bias(&mut self.rec, &mut set); - MidProcessResult::SampleWithBias { - allowed_tokens: set, - } - } - - fn post_process(&mut self, arg: PostProcessArg) -> PostProcessResult { - // save our tokens - self.tokens.extend_from_slice(&arg.tokens); - // and update the state of our recognizer - self.toktrie.append_tokens(&mut self.rec, &arg.tokens); - // ::from_arg() will translate generation of EOS token into Stop instruction - PostProcessResult::from_arg(&arg) - } -} - -fn runner_from_env() -> Runner { - Runner::new(aici_abi::arg_bytes()) -} - -fn main() { - // test code here? 
-}
-
-aici_abi::aici_expose_all!(Runner, runner_from_env());

From c94e7fc56955539596b0e86f3b0dea2d92a1623e Mon Sep 17 00:00:00 2001
From: Michal Moskal
Date: Wed, 17 Jan 2024 16:31:24 +0000
Subject: [PATCH 132/301] Add new methods to SimpleVob struct

---
 aici_abi/src/svob.rs | 28 +++++++++++++++++++++++++++-
 1 file changed, 27 insertions(+), 1 deletion(-)

diff --git a/aici_abi/src/svob.rs b/aici_abi/src/svob.rs
index 6863b0ac..8a5a5cd4 100644
--- a/aici_abi/src/svob.rs
+++ b/aici_abi/src/svob.rs
@@ -1,5 +1,5 @@
 use crate::TokenId;
-use std::fmt::Debug;
+use std::{fmt::Debug, ops::Index};
 
 #[derive(Clone)]
 pub struct SimpleVob {
@@ -27,6 +27,12 @@ impl SimpleVob {
         Self { data: Vec::new() }
     }
 
+    pub fn alloc(size: usize) -> Self {
+        let mut r = Self::new();
+        r.resize(size);
+        r
+    }
+
     pub fn len(&self) -> usize {
         self.data.len() * BITS
     }
@@ -51,6 +57,14 @@ impl SimpleVob {
         self.data[byte_idx] &= !(1 << bit_idx);
     }
 
+    pub fn set(&mut self, tok: TokenId, val: bool) {
+        if val {
+            self.allow_token(tok);
+        } else {
+            self.disallow_token(tok);
+        }
+    }
+
     pub fn resize(&mut self, size: usize) {
         let new_size = size / BITS + 1;
         assert!(new_size >= self.data.len());
@@ -84,3 +98,15 @@ impl SimpleVob {
         }
     }
 }
+
+impl Index<usize> for SimpleVob {
+    type Output = bool;
+
+    fn index(&self, index: usize) -> &Self::Output {
+        if self.is_allowed(index as TokenId) {
+            &true
+        } else {
+            &false
+        }
+    }
+}

From 27d5fdaa6de37a1d681fae05e6aee9178821893a Mon Sep 17 00:00:00 2001
From: Michal Moskal
Date: Tue, 23 Jan 2024 23:27:39 +0000
Subject: [PATCH 133/301] capture llama.cpp logs

---
 aici_abi/src/toktree.rs | 1 -
 1 file changed, 1 deletion(-)

diff --git a/aici_abi/src/toktree.rs b/aici_abi/src/toktree.rs
index 18c19869..d77263ed 100644
--- a/aici_abi/src/toktree.rs
+++ b/aici_abi/src/toktree.rs
@@ -122,7 +122,6 @@ impl TokTrie {
         let mut trie = TrieHash::new(0xff);
         let mut token_offsets = Vec::new();
         let mut token_data = Vec::new();
-        println!("info: {:?} wl={}", info, words.len());
         assert!(info.vocab_size == words.len() as u32);
         for (idx, word) in words.iter().enumerate() {
             if word.len() > 0 {

From 12d3b09b2165508b31a19a873ccbd30e706433a0 Mon Sep 17 00:00:00 2001
From: Michal Moskal
Date: Sat, 27 Jan 2024 01:44:21 +0000
Subject: [PATCH 134/301] don't use prompt in samples

---
 aici_abi/src/host.rs  |  5 +++++
 aici_abi/src/lib.rs   |  4 ++--
 aici_abi/src/yesno.rs | 17 +++++------------
 3 files changed, 12 insertions(+), 14 deletions(-)

diff --git a/aici_abi/src/host.rs b/aici_abi/src/host.rs
index 27415cf6..7ed1de9c 100644
--- a/aici_abi/src/host.rs
+++ b/aici_abi/src/host.rs
@@ -79,6 +79,11 @@ pub fn arg_bytes() -> Vec<u8> {
     return std::fs::read("arg.json").unwrap();
 }
 
+pub fn arg_string() -> String {
+    String::from_utf8_lossy(&arg_bytes()).to_string()
+}
+
+
 pub fn trie_bytes() -> Vec<u8> {
     #[cfg(target_arch = "wasm32")]
     return read_blob(unsafe { aici_host_token_trie() }, 0);
diff --git a/aici_abi/src/lib.rs b/aici_abi/src/lib.rs
index 97c0307b..80c5588b 100644
--- a/aici_abi/src/lib.rs
+++ b/aici_abi/src/lib.rs
@@ -21,8 +21,8 @@ pub mod substring;
 pub type TokenId = bytes::TokenId;
 
 pub use host::{
-    aici_stop, arg_bytes, return_logit_bias, self_seq_id, tokenize, tokenize_bytes, StorageCmd,
-    StorageOp, StorageResp, VariableStorage,
+    aici_stop, arg_bytes, arg_string, return_logit_bias, self_seq_id, tokenize, tokenize_bytes,
+    StorageCmd, StorageOp, StorageResp, VariableStorage,
 };
 
 #[derive(Serialize, Deserialize, Debug)]
diff --git a/aici_abi/src/yesno.rs b/aici_abi/src/yesno.rs
index 36fede93..bbf903f4 100644
--- a/aici_abi/src/yesno.rs
+++ b/aici_abi/src/yesno.rs
@@ -1,11 +1,12 @@
 use aici_abi::{
-    tokenize, toktree::TokTrie, AiciCtrl, InitPromptArg, InitPromptResult, MidProcessArg,
-    MidProcessResult, PostProcessArg, PostProcessResult, PreProcessArg, PreProcessResult, TokenId,
+    arg_string, tokenize, toktree::TokTrie, AiciCtrl, MidProcessArg, MidProcessResult,
+    PostProcessArg, PostProcessResult, PreProcessArg, PreProcessResult, TokenId,
 };
 
 pub struct Runner {
     toktrie: TokTrie,
     tokens: Vec<TokenId>,
+    question: String,
     yes: TokenId,
     no: TokenId,
 }
@@ -18,6 +19,7 @@ impl Runner {
         Runner {
             toktrie: TokTrie::from_host(),
             tokens: Vec::new(),
+            question: arg_string() + "\n",
             yes,
             no,
         }
@@ -25,18 +27,9 @@ impl Runner {
 }
 
 impl AiciCtrl for Runner {
-    fn init_prompt(&mut self, arg: InitPromptArg) -> InitPromptResult {
-        if arg.prompt.len() < 2 {
-            // we'll be forcing answer; require a question
-            panic!("prompt too short")
-        }
-        InitPromptResult::default()
-    }
-
     fn pre_process(&mut self, _arg: PreProcessArg) -> PreProcessResult {
         if self.tokens.is_empty() {
-            // Make sure the prompt ends with newline
-            let toks = tokenize("\n");
+            let toks = tokenize(&self.question);
             PreProcessResult::ff_tokens(toks)
         } else {
             PreProcessResult::continue_()

From 959e89270c856d32bee36b25d7aa2ffa63e86150 Mon Sep 17 00:00:00 2001
From: Michal Moskal
Date: Sat, 27 Jan 2024 02:03:30 +0000
Subject: [PATCH 135/301] fix examples

---
 aici_abi/src/lib.rs   | 2 ++
 aici_abi/src/yesno.rs | 9 ++++-----
 2 files changed, 6 insertions(+), 5 deletions(-)

diff --git a/aici_abi/src/lib.rs b/aici_abi/src/lib.rs
index 80c5588b..18183c2f 100644
--- a/aici_abi/src/lib.rs
+++ b/aici_abi/src/lib.rs
@@ -27,6 +27,7 @@ pub use host::{
 
 #[derive(Serialize, Deserialize, Debug)]
 pub struct InitPromptArg {
+    /// Typically just the start token if any.
     pub prompt: Vec<TokenId>,
 }
 
@@ -150,6 +151,7 @@ impl PreProcessResult {
 pub trait AiciCtrl {
     /// Called with the initial prompt. ~1000ms time limit.
     /// By default ignore prompt.
+    /// This is typically just the start token if any (REST API forces empty prompt).
     fn init_prompt(&mut self, _arg: InitPromptArg) -> InitPromptResult {
diff --git a/aici_abi/src/yesno.rs b/aici_abi/src/yesno.rs
index bbf903f4..493f9ebf 100644
--- a/aici_abi/src/yesno.rs
+++ b/aici_abi/src/yesno.rs
@@ -6,7 +6,7 @@ use aici_abi::{
 pub struct Runner {
     toktrie: TokTrie,
     tokens: Vec<TokenId>,
-    question: String,
+    question: Vec<TokenId>,
     yes: TokenId,
     no: TokenId,
 }
@@ -19,7 +19,7 @@ impl Runner {
         Runner {
             toktrie: TokTrie::from_host(),
             tokens: Vec::new(),
-            question: arg_string() + "\n",
+            question: tokenize(&(arg_string() + "\n")),
             yes,
             no,
         }
@@ -29,8 +29,7 @@ impl Runner {
 }
 
 impl AiciCtrl for Runner {
     fn pre_process(&mut self, _arg: PreProcessArg) -> PreProcessResult {
         if self.tokens.is_empty() {
-            let toks = tokenize(&self.question);
-            PreProcessResult::ff_tokens(toks)
+            PreProcessResult::ff_tokens(self.question.clone())
         } else {
             PreProcessResult::continue_()
         }
@@ -48,7 +47,7 @@ impl AiciCtrl for Runner {
     fn post_process(&mut self, arg: PostProcessArg) -> PostProcessResult {
         // save our tokens
         self.tokens.extend_from_slice(&arg.tokens);
-        if self.tokens.len() >= 2 {
+        if self.tokens.len() >= self.question.len() + 1 {
             PostProcessResult::stop()
         } else {
             PostProcessResult::from_arg(&arg)

From 04d1bf9a058ff616fe73580cacdb18cb802a6a46 Mon Sep 17 00:00:00 2001
From: Michal Moskal
Date: Mon, 29 Jan 2024 23:30:49 +0000
Subject: [PATCH 136/301] new rest API; fixes #39

---
 aici_abi/src/host.rs | 10 +++++-----
 1 file changed, 5 insertions(+), 5 deletions(-)

diff --git a/aici_abi/src/host.rs b/aici_abi/src/host.rs
index 7ed1de9c..6d78f065 100644
--- a/aici_abi/src/host.rs
+++ b/aici_abi/src/host.rs
@@ -246,11 +246,11 @@ pub fn tokenize_bytes(s: &[u8]) -> Vec<TokenId> {
     let id = unsafe { aici_host_tokenize(s.as_ptr(), s.len() as u32) };
     let r = read_blob(id, 4 * (s.len() / 3 + 10));
     let res = vec_from_bytes(&r);
-    println!(
-        "tokenize_bytes: {:?} -> {:?}",
-        String::from_utf8_lossy(s),
-        res
-    );
+    // println!(
+    //     "tokenize_bytes: {:?} -> {:?}",
+    //     String::from_utf8_lossy(s),
+    //     res
+    // );
     res
 }
 
From 98ab80a17898bf80a2df939f28ecb27d2cd3fc3e Mon Sep 17 00:00:00 2001
From: Michal Moskal
Date: Mon, 29 Jan 2024 23:54:10 +0000
Subject: [PATCH 137/301] further docs updates

---
 aici_abi/src/host.rs | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/aici_abi/src/host.rs b/aici_abi/src/host.rs
index 6d78f065..d25665c5 100644
--- a/aici_abi/src/host.rs
+++ b/aici_abi/src/host.rs
@@ -259,7 +259,7 @@ pub fn tokenize(s: &str) -> Vec<TokenId> {
     let id = unsafe { aici_host_tokenize(s.as_ptr(), s.len() as u32) };
     let r = read_blob(id, 4 * (s.len() / 3 + 10));
     let res = vec_from_bytes(&r);
-    println!("tokenize: {:?} -> {:?}", s, res);
+    // println!("tokenize: {:?} -> {:?}", s, res);
     res
 }
 
From e4cff487f7202ef23019fa9936b7c943351a749a Mon Sep 17 00:00:00 2001
From: Michal Moskal
Date: Tue, 30 Jan 2024 02:15:25 +0000
Subject: [PATCH 138/301] merge post/pre_process aicirt calls

---
 aici_abi/src/lib.rs | 24 +++++++++++-------------
 1 file changed, 11 insertions(+), 13 deletions(-)

diff --git a/aici_abi/src/lib.rs b/aici_abi/src/lib.rs
index 18183c2f..69625545 100644
--- a/aici_abi/src/lib.rs
+++ b/aici_abi/src/lib.rs
@@ -43,12 +43,10 @@ pub struct PreProcessArg {}
 
 #[derive(Serialize, Deserialize, Debug)]
 pub struct PreProcessResult {
-    /// If no attention masks are returned - stop the sequence.
-    /// If one is returned - just continue with this mask.
-    /// If more than one attention mask is returned - fork the generation.
-    /// Attention mask of length 0 is equivalent [1.0, ..., 1.0].
-    /// Otherwise, length of the mask should be the same as the number of prompt + generated tokens.
-    pub attention_masks: Vec<Vec<f32>>,
+    /// If 0 - stop the sequence.
+    /// If 1 - just continue.
+    /// If more than 1 - fork the generation.
+    pub num_forks: usize,
 
     pub suspend: bool,
 
@@ -59,7 +57,7 @@ pub struct PreProcessResult {
 
 #[derive(Serialize, Deserialize, Debug)]
 pub struct MidProcessArg {
-    /// fork_group.len() == attention_masks.len().
+    /// fork_group.len() == num_forks.
     /// Use host::self_seq_id() to get the ID of the current sequence.
     pub fork_group: Vec<SeqId>,
 }
@@ -119,29 +117,29 @@ impl PostProcessResult {
 }
 
 impl PreProcessResult {
-    pub fn new(attention_masks: Vec<Vec<f32>>) -> Self {
+    pub fn new(num_forks: usize) -> Self {
         PreProcessResult {
-            attention_masks,
+            num_forks,
             suspend: false,
             ff_tokens: vec![],
         }
     }
     pub fn continue_() -> Self {
-        PreProcessResult::new(vec![vec![]])
+        PreProcessResult::new(1)
     }
     pub fn suspend() -> Self {
         PreProcessResult {
-            attention_masks: vec![vec![]],
+            num_forks: 1,
             suspend: true,
             ff_tokens: vec![],
         }
     }
     pub fn stop() -> Self {
-        PreProcessResult::new(vec![])
+        PreProcessResult::new(0)
     }
     pub fn ff_tokens(toks: Vec<TokenId>) -> Self {
         PreProcessResult {
-            attention_masks: vec![vec![]],
+            num_forks: 1,
             suspend: false,
             ff_tokens: toks,
         }

From ce4cef6f85044d8017765ced0dd115567f6e20b0 Mon Sep 17 00:00:00 2001
From: Michal Moskal
Date: Thu, 1 Feb 2024 00:44:39 +0000
Subject: [PATCH 139/301] specify min rust version

---
 aici_abi/Cargo.toml | 1 +
 1 file changed, 1 insertion(+)

diff --git a/aici_abi/Cargo.toml b/aici_abi/Cargo.toml
index 5e7e3ffe..b39afd81 100644
--- a/aici_abi/Cargo.toml
+++ b/aici_abi/Cargo.toml
@@ -2,6 +2,7 @@
 name = "aici_abi"
 version = "0.1.0"
 edition = "2021"
+rust-version = "1.75.0"
 
 [lib]
 name = "aici_abi"

From 1a29808c66e79ef0f830a1d21ab44f3635f51d10 Mon Sep 17 00:00:00 2001
From: Michal Moskal
Date: Thu, 1 Feb 2024 01:36:44 +0000
Subject: [PATCH 140/301] fix links

---
 aici_abi/README.md | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/aici_abi/README.md b/aici_abi/README.md
index bda556b8..e394b1cd 100644
--- a/aici_abi/README.md
+++ b/aici_abi/README.md
@@ -41,7 +41,7 @@ The actual binary interface is a bit more complicated, due to limitations in
 passing values to and from Wasm. A Wasm module instance is created for each token sequence. Also, when the
 sequence forks (as in beam search), the module instance is cloned.
 
-See the [AiciCtrl Rust trait](aici_abi/src/lib.rs) for details.
+See the [AiciCtrl Rust trait](src/lib.rs) for details.
 
 A number of functions are exposed to the Wasm module.
 
@@ -89,7 +89,7 @@ To compute the set of tokens that match a string constraint, one needs go throug
 and apply the constraint. An efficient way to do this is walk a prefix tree (trie) of all tokens.
 
 The `aici_abi` library implements this trie and exposes a way of filtering when provided with a constraints
-implementing the [following interface](aici_abi/src/toktree.rs):
+implementing the [following interface](src/toktree.rs):
 
 ```rust
 pub trait Recognizer {

From 1f7b5527cf63142d879f4dd004daae459280d4d2 Mon Sep 17 00:00:00 2001
From: Devis Lucato
Date: Thu, 1 Feb 2024 15:28:11 -0800
Subject: [PATCH 141/301] Minor typos fix

---
 aici_abi/README.md | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/aici_abi/README.md b/aici_abi/README.md
index e394b1cd..a15dd336 100644
--- a/aici_abi/README.md
+++ b/aici_abi/README.md
@@ -1,6 +1,6 @@
 # aici_abi
 
-This crate specifies the application binary inferface (ABI) for the AICI Controllers.
+This crate specifies the application binary interface (ABI) for the AICI Controllers. It also provides higher-level interfaces for implementing controllers. ## Low-level interface From 3928e0f321990d618ffde74c3a4f5d416da8c3a0 Mon Sep 17 00:00:00 2001 From: Michal Moskal Date: Fri, 9 Feb 2024 01:20:20 +0000 Subject: [PATCH 142/301] start moving files around; fixes #56 --- {aici_abi => controllers/aici_abi}/.cargo/config.toml | 0 {aici_abi => controllers/aici_abi}/Cargo.toml | 0 {aici_abi => controllers/aici_abi}/README.md | 0 {aici_abi => controllers/aici_abi}/grammars/c.y | 0 {aici_abi => controllers/aici_abi}/grammars/sample.c | 0 {aici_abi => controllers/aici_abi}/implementation.md | 0 {aici_abi => controllers/aici_abi}/src/bytes.rs | 0 {aici_abi => controllers/aici_abi}/src/cfg.rs | 0 {aici_abi => controllers/aici_abi}/src/host.rs | 0 {aici_abi => controllers/aici_abi}/src/lex.rs | 0 {aici_abi => controllers/aici_abi}/src/lib.rs | 0 {aici_abi => controllers/aici_abi}/src/recognizer.rs | 0 {aici_abi => controllers/aici_abi}/src/rng.rs | 0 {aici_abi => controllers/aici_abi}/src/rx.rs | 0 {aici_abi => controllers/aici_abi}/src/substring.rs | 0 {aici_abi => controllers/aici_abi}/src/svob.rs | 0 {aici_abi => controllers/aici_abi}/src/toktree.rs | 0 {aici_abi => controllers/aici_abi}/src/yesno.rs | 0 18 files changed, 0 insertions(+), 0 deletions(-) rename {aici_abi => controllers/aici_abi}/.cargo/config.toml (100%) rename {aici_abi => controllers/aici_abi}/Cargo.toml (100%) rename {aici_abi => controllers/aici_abi}/README.md (100%) rename {aici_abi => controllers/aici_abi}/grammars/c.y (100%) rename {aici_abi => controllers/aici_abi}/grammars/sample.c (100%) rename {aici_abi => controllers/aici_abi}/implementation.md (100%) rename {aici_abi => controllers/aici_abi}/src/bytes.rs (100%) rename {aici_abi => controllers/aici_abi}/src/cfg.rs (100%) rename {aici_abi => controllers/aici_abi}/src/host.rs (100%) rename {aici_abi => controllers/aici_abi}/src/lex.rs (100%) rename {aici_abi => controllers/aici_abi}/src/lib.rs (100%) rename {aici_abi => controllers/aici_abi}/src/recognizer.rs (100%) rename {aici_abi => controllers/aici_abi}/src/rng.rs (100%) rename {aici_abi => controllers/aici_abi}/src/rx.rs (100%) rename {aici_abi => controllers/aici_abi}/src/substring.rs (100%) rename {aici_abi => controllers/aici_abi}/src/svob.rs (100%) rename {aici_abi => controllers/aici_abi}/src/toktree.rs (100%) rename {aici_abi => controllers/aici_abi}/src/yesno.rs (100%) diff --git a/aici_abi/.cargo/config.toml b/controllers/aici_abi/.cargo/config.toml similarity index 100% rename from aici_abi/.cargo/config.toml rename to controllers/aici_abi/.cargo/config.toml diff --git a/aici_abi/Cargo.toml b/controllers/aici_abi/Cargo.toml similarity index 100% rename from aici_abi/Cargo.toml rename to controllers/aici_abi/Cargo.toml diff --git a/aici_abi/README.md b/controllers/aici_abi/README.md similarity index 100% rename from aici_abi/README.md rename to controllers/aici_abi/README.md diff --git a/aici_abi/grammars/c.y b/controllers/aici_abi/grammars/c.y similarity index 100% rename from aici_abi/grammars/c.y rename to controllers/aici_abi/grammars/c.y diff --git a/aici_abi/grammars/sample.c b/controllers/aici_abi/grammars/sample.c similarity index 100% rename from aici_abi/grammars/sample.c rename to controllers/aici_abi/grammars/sample.c diff --git a/aici_abi/implementation.md b/controllers/aici_abi/implementation.md similarity index 100% rename from aici_abi/implementation.md rename to 
controllers/aici_abi/implementation.md
diff --git a/aici_abi/src/bytes.rs b/controllers/aici_abi/src/bytes.rs
similarity index 100%
rename from aici_abi/src/bytes.rs
rename to controllers/aici_abi/src/bytes.rs
diff --git a/aici_abi/src/cfg.rs b/controllers/aici_abi/src/cfg.rs
similarity index 100%
rename from aici_abi/src/cfg.rs
rename to controllers/aici_abi/src/cfg.rs
diff --git a/aici_abi/src/host.rs b/controllers/aici_abi/src/host.rs
similarity index 100%
rename from aici_abi/src/host.rs
rename to controllers/aici_abi/src/host.rs
diff --git a/aici_abi/src/lex.rs b/controllers/aici_abi/src/lex.rs
similarity index 100%
rename from aici_abi/src/lex.rs
rename to controllers/aici_abi/src/lex.rs
diff --git a/aici_abi/src/lib.rs b/controllers/aici_abi/src/lib.rs
similarity index 100%
rename from aici_abi/src/lib.rs
rename to controllers/aici_abi/src/lib.rs
diff --git a/aici_abi/src/recognizer.rs b/controllers/aici_abi/src/recognizer.rs
similarity index 100%
rename from aici_abi/src/recognizer.rs
rename to controllers/aici_abi/src/recognizer.rs
diff --git a/aici_abi/src/rng.rs b/controllers/aici_abi/src/rng.rs
similarity index 100%
rename from aici_abi/src/rng.rs
rename to controllers/aici_abi/src/rng.rs
diff --git a/aici_abi/src/rx.rs b/controllers/aici_abi/src/rx.rs
similarity index 100%
rename from aici_abi/src/rx.rs
rename to controllers/aici_abi/src/rx.rs
diff --git a/aici_abi/src/substring.rs b/controllers/aici_abi/src/substring.rs
similarity index 100%
rename from aici_abi/src/substring.rs
rename to controllers/aici_abi/src/substring.rs
diff --git a/aici_abi/src/svob.rs b/controllers/aici_abi/src/svob.rs
similarity index 100%
rename from aici_abi/src/svob.rs
rename to controllers/aici_abi/src/svob.rs
diff --git a/aici_abi/src/toktree.rs b/controllers/aici_abi/src/toktree.rs
similarity index 100%
rename from aici_abi/src/toktree.rs
rename to controllers/aici_abi/src/toktree.rs
diff --git a/aici_abi/src/yesno.rs b/controllers/aici_abi/src/yesno.rs
similarity index 100%
rename from aici_abi/src/yesno.rs
rename to controllers/aici_abi/src/yesno.rs

From 59b01aaa05966b2ac68f395f4d09de62ce632e12 Mon Sep 17 00:00:00 2001
From: Michal Moskal
Date: Thu, 15 Feb 2024 01:43:13 +0000
Subject: [PATCH 143/301] run first pre() callback in instantiate()

---
 controllers/aici_abi/src/lib.rs | 12 +++++++++++-
 1 file changed, 11 insertions(+), 1 deletion(-)

diff --git a/controllers/aici_abi/src/lib.rs b/controllers/aici_abi/src/lib.rs
index 69625545..ae6a418f 100644
--- a/controllers/aici_abi/src/lib.rs
+++ b/controllers/aici_abi/src/lib.rs
@@ -41,7 +41,7 @@ pub struct SeqId(pub u32);
 #[derive(Serialize, Deserialize, Debug)]
 pub struct PreProcessArg {}
 
-#[derive(Serialize, Deserialize, Debug)]
+#[derive(Serialize, Deserialize, Debug, Clone)]
 pub struct PreProcessResult {
     /// If 0 - stop the sequence.
     /// If 1 - just continue.
@@ -55,6 +55,16 @@ pub struct PreProcessResult {
     pub ff_tokens: Vec<TokenId>,
 }
 
+impl Default for PreProcessResult {
+    fn default() -> Self {
+        PreProcessResult {
+            num_forks: 1,
+            suspend: false,
+            ff_tokens: vec![],
+        }
+    }
+}
+
 #[derive(Serialize, Deserialize, Debug)]
 pub struct MidProcessArg {
     /// fork_group.len() == num_forks.
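With the `Default` impl above, `PreProcessResult::default()` is equivalent to `PreProcessResult::continue_()` (one fork, no suspend, no fast-forward tokens), which is what the runtime can fall back on when it runs the first `pre_process()` during instantiation. The following is a minimal editorial sketch of a controller written against the API at this point in the series; the `PassThrough` type is hypothetical and does not appear in any patch:

```rust
use aici_abi::{
    toktree::TokTrie, AiciCtrl, MidProcessArg, MidProcessResult, PostProcessArg,
    PostProcessResult, PreProcessArg, PreProcessResult,
};

// Hypothetical controller that never constrains sampling.
pub struct PassThrough {
    trie: TokTrie,
}

impl AiciCtrl for PassThrough {
    fn pre_process(&mut self, _arg: PreProcessArg) -> PreProcessResult {
        // One fork, no suspend, no ff_tokens -- same as continue_().
        PreProcessResult::default()
    }

    fn mid_process(&mut self, _arg: MidProcessArg) -> MidProcessResult {
        // Allow every token in the vocabulary.
        let mut set = self.trie.alloc_token_set();
        for tok in 0..self.trie.vocab_size() as u32 {
            set.allow_token(tok);
        }
        MidProcessResult::SampleWithBias {
            allowed_tokens: set,
        }
    }

    fn post_process(&mut self, arg: PostProcessArg) -> PostProcessResult {
        // from_arg() translates generation of EOS into a Stop instruction.
        PostProcessResult::from_arg(&arg)
    }
}
```

(The constructor and the `aici_expose_all!` wiring are omitted; see the `yesno` sample in the patches above for the full pattern.)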
From 7b4e33a3963ec61e124154815d4fe8b294a95203 Mon Sep 17 00:00:00 2001
From: Michal Moskal
Date: Thu, 22 Feb 2024 00:37:17 +0000
Subject: [PATCH 144/301] add TokenSet.num_set/repr in pyctrl

---
 controllers/aici_abi/src/svob.rs    |  4 ++++
 controllers/aici_abi/src/toktree.rs | 23 +++++++++++++++++++++++
 2 files changed, 27 insertions(+)

diff --git a/controllers/aici_abi/src/svob.rs b/controllers/aici_abi/src/svob.rs
index 8a5a5cd4..6513080d 100644
--- a/controllers/aici_abi/src/svob.rs
+++ b/controllers/aici_abi/src/svob.rs
@@ -37,6 +37,10 @@ impl SimpleVob {
         self.data.len() * BITS
     }
 
+    pub fn num_set(&self) -> usize {
+        self.data.iter().map(|x| x.count_ones() as usize).sum()
+    }
+
     pub unsafe fn as_ptr(&self) -> *const u32 {
         self.data.as_ptr()
     }
diff --git a/controllers/aici_abi/src/toktree.rs b/controllers/aici_abi/src/toktree.rs
index d77263ed..c64055a2 100644
--- a/controllers/aici_abi/src/toktree.rs
+++ b/controllers/aici_abi/src/toktree.rs
@@ -177,6 +177,29 @@ impl TokTrie {
         r
     }
 
+    pub fn token_set_dbg(&self, ts: &SimpleVob) -> String {
+        let num_set = ts.num_set();
+        let max_tok = std::cmp::min(100, num_set);
+        let mut token_names = Vec::new();
+        for idx in 0..self.vocab_size() {
+            if ts.is_allowed(idx as TokenId) {
+                token_names.push(self.token_dbg(idx as TokenId));
+                if token_names.len() >= max_tok {
+                    break;
+                }
+            }
+        }
+        if token_names.len() < num_set {
+            token_names.push("...".to_string());
+        }
+        format!(
+            "TokenSet: {}/{}; {}",
+            num_set,
+            self.vocab_size(),
+            token_names.join(", ")
+        )
+    }
+
     pub fn alloc_logits(&self) -> Vec<f32> {
         vec![0.0; self.vocab_size() + 1]
     }

From 1012be809f09246c5ef5107816ac93ded500ccc1 Mon Sep 17 00:00:00 2001
From: Michal Moskal
Date: Thu, 22 Feb 2024 01:21:01 +0000
Subject: [PATCH 145/301] CFG: don't add fake \n at the start of file

---
 controllers/aici_abi/src/lex.rs | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/controllers/aici_abi/src/lex.rs b/controllers/aici_abi/src/lex.rs
index 4ad1e672..b8679b66 100644
--- a/controllers/aici_abi/src/lex.rs
+++ b/controllers/aici_abi/src/lex.rs
@@ -241,9 +241,10 @@ impl Lexer {
     }
 
     pub fn file_start_state(&self) -> StateID {
+        self.initial.state
         // pretend we've just seen a newline at the beginning of the file
         // TODO: this should be configurable
-        self.dfa.next_state(self.initial.state, b'\n')
+        // self.dfa.next_state(self.initial.state, b'\n')
     }
 
     fn mk_state(&self, state: StateID) -> LexerState {

From 2a09a6ffa6e23230e73ae556cfd9ddd19dabc08b Mon Sep 17 00:00:00 2001
From: Michal Moskal
Date: Thu, 22 Feb 2024 01:21:29 +0000
Subject: [PATCH 146/301] cfg: fix initial viable states

---
 controllers/aici_abi/src/cfg.rs | 24 ++++++++++++++++++++----
 1 file changed, 20 insertions(+), 4 deletions(-)

diff --git a/controllers/aici_abi/src/cfg.rs b/controllers/aici_abi/src/cfg.rs
index 5a1cc4ef..61d2c9df 100644
--- a/controllers/aici_abi/src/cfg.rs
+++ b/controllers/aici_abi/src/cfg.rs
@@ -159,7 +159,7 @@ impl CfgParser {
         }
 
         let mut skip_patterns = vob![false; patterns.len()];
-        let mut friendly_pattern_names = pat_idx_to_tidx
+        let friendly_pattern_names = pat_idx_to_tidx
             .iter()
             .map(|tok| grm.token_name(*tok).unwrap().to_string())
             .collect::<Vec<_>>();
@@ -173,7 +173,8 @@ impl CfgParser {
             let toks = grm.prod(*pidx);
             if let [Symbol::Token(tidx)] = toks {
                 let idx = *tidx_to_pat_idx.get(&tidx).unwrap();
-                friendly_pattern_names[idx] = rname.to_string();
+                // this doesn't seem very useful
+                // friendly_pattern_names[idx] = rname.to_string();
                 if rname == "SKIP" {
                     skip_patterns.set(idx, true);
                 }
@@ -191,7 +192,8 @@ impl CfgParser {
         // TIME: 27ms
         let dfa = Lexer::from(patterns, &mut vobset);
 
-        let parse_stacks = vec![vec![stable.start_state()]];
+        let cfg_start = stable.start_state();
+        let parse_stacks = vec![vec![cfg_start]];
 
         let byte_state = ByteState {
             lexer_state: dfa.file_start_state(),
@@ -242,6 +244,15 @@ impl CfgParser {
 
         cfg.vobset.pre_compute();
 
+        // compute viable set of initial tokens
+        cfg.byte_states[0].viable = cfg.viable_vobidx(cfg_start);
+        if LOG_PARSER {
+            println!(
+                "initial viable: {:?}",
+                cfg.vobset.resolve(cfg.byte_states[0].viable)
+            );
+        }
+
         Ok(cfg)
     }
 
@@ -313,7 +324,7 @@ impl CfgParser {
     fn try_push(&mut self, byte: Option<u8>) -> Option<ByteState> {
         let top = self.byte_states.last().unwrap().clone();
         if LOG_PARSER {
-            print!("try_push: ");
+            print!("try_push[{}]: ", self.byte_states.len());
             if let Some(b) = byte {
                 print!("{:?}", b as char)
             } else {
@@ -415,6 +426,11 @@ impl CfgParser {
         if self.vobset.and_is_zero(viable, ls.reachable) {
             None
         } else {
+            // print!(
+            //     " {:?} {:?} ",
+            //     self.vobset.resolve(viable),
+            //     self.vobset.resolve(ls.reachable)
+            // );
             Some(ByteState {
                 lexer_state: ls.state,
                 parse_stack_idx: pstack,

From b3cdc2f9d7c7b2985db56ccb21fddf30c8afb231 Mon Sep 17 00:00:00 2001
From: Michal Moskal
Date: Thu, 22 Feb 2024 01:21:41 +0000
Subject: [PATCH 147/301] usability fixes

---
 controllers/aici_abi/src/toktree.rs | 11 ++++++++++-
 1 file changed, 10 insertions(+), 1 deletion(-)

diff --git a/controllers/aici_abi/src/toktree.rs b/controllers/aici_abi/src/toktree.rs
index c64055a2..72150d12 100644
--- a/controllers/aici_abi/src/toktree.rs
+++ b/controllers/aici_abi/src/toktree.rs
@@ -205,7 +205,14 @@ impl TokTrie {
     }
 
     pub fn token_dbg(&self, idx: u32) -> String {
-        format!("{:?}[{}]", self.token_str(idx), idx)
+        if idx == self.info.tok_eos {
+            "EOS".to_string()
+        } else if idx as usize >= self.vocab_size() {
+            format!("OOB[{}]", idx)
+        } else {
+            // format!("{:?}[{}]", self.token_str(idx), idx)
+            format!("{:?}", self.token_str(idx))
+        }
     }
 
     pub fn token_str(&self, idx: u32) -> String {
@@ -471,6 +478,8 @@ impl TokTrie {
             }
         }
         r.trie_finished();
+        // revert the fake token
+        toks.disallow_token(defl_tok);
     }
 }

From 9ca59827ed757eacb02cbb847f36590b0fa2a379 Mon Sep 17 00:00:00 2001
From: Michal Moskal
Date: Fri, 1 Mar 2024 00:24:37 +0000
Subject: [PATCH 148/301] playing with earley parsers

---
 controllers/aici_abi/Cargo.toml         |   3 +-
 controllers/aici_abi/src/cfg.rs         |  91 ++++---
 controllers/aici_abi/src/earley.rs      | 336 ++++++++++++++++++++++++
 controllers/aici_abi/src/earley_yacc.rs |  90 +++++++
 controllers/aici_abi/src/lib.rs         |   5 +
 5 files changed, 481 insertions(+), 44 deletions(-)
 create mode 100644 controllers/aici_abi/src/earley.rs
 create mode 100644 controllers/aici_abi/src/earley_yacc.rs

diff --git a/controllers/aici_abi/Cargo.toml b/controllers/aici_abi/Cargo.toml
index b39afd81..3e2a3acd 100644
--- a/controllers/aici_abi/Cargo.toml
+++ b/controllers/aici_abi/Cargo.toml
@@ -20,9 +20,10 @@ vob = { version = "3.0.3", optional = true }
 rustc-hash = { version = "1.1.0", optional = true }
 
 [features]
-default = ["cfg", "rx"]
+default = ["cfg", "rx", "earley"]
 cfg = ["dep:cfgrammar", "dep:lrlex", "dep:lrpar", "dep:lrtable", "dep:vob", "dep:rustc-hash"]
 rx = ["dep:regex-automata"]
+earley = ["rx", "dep:vob", "dep:rustc-hash"]
 
 [[bin]]
 name = "yesno"
diff --git a/controllers/aici_abi/src/cfg.rs b/controllers/aici_abi/src/cfg.rs
index 61d2c9df..675a0beb 100644
--- a/controllers/aici_abi/src/cfg.rs
+++ b/controllers/aici_abi/src/cfg.rs
@@ -66,45 +66,57 @@ fn quote_rx(name: &str) -> String {
.collect::() } -impl CfgParser { - fn span_to_str(s: &Span, src: &str) -> String { - let mut line = 1; - let mut last_nl = 0; - for (idx, ch) in src.chars().enumerate() { - if idx == s.start() { - break; - } - if ch == '\n' { - line += 1; - last_nl = idx; - } +pub(crate) fn parse_rx_token(name: &str) -> String { + if is_rx(name) { + name[1..name.len() - 1].to_string() + } else { + quote_rx(name) + } +} + +fn span_to_str(s: &Span, src: &str) -> String { + let mut line = 1; + let mut last_nl = 0; + for (idx, ch) in src.chars().enumerate() { + if idx == s.start() { + break; + } + if ch == '\n' { + line += 1; + last_nl = idx; } - let column = s.start() - last_nl; - format!("({},{})", line, column) } + let column = s.start() - last_nl; + format!("({},{})", line, column) +} - pub fn from_yacc(yacc: &str) -> Result { - let grmkind = YaccKind::Original(cfgrammar::yacc::YaccOriginalActionKind::NoAction); - let grm = match YaccGrammar::new(grmkind, yacc) { - Ok(grm) => grm, - Err(e) => { - let err_str = e - .iter() - .map(|e| { - let spans = e - .spans() - .iter() - .map(|s| Self::span_to_str(s, yacc)) - .collect::>() - .join(", "); - format!("{}: {}", spans, e) - }) - .collect::>() - .join("\n"); - anyhow::bail!("yacc grammar errors:\n{}", err_str); - } - }; +pub(crate) fn parse_yacc(yacc: &str) -> Result { + let grmkind = YaccKind::Original(cfgrammar::yacc::YaccOriginalActionKind::NoAction); + let grm = match YaccGrammar::new(grmkind, yacc) { + Ok(grm) => grm, + Err(e) => { + let err_str = e + .iter() + .map(|e| { + let spans = e + .spans() + .iter() + .map(|s| span_to_str(s, yacc)) + .collect::>() + .join(", "); + format!("{}: {}", spans, e) + }) + .collect::>() + .join("\n"); + anyhow::bail!("yacc grammar errors:\n{}", err_str); + } + }; + Ok(grm) +} +impl CfgParser { + pub fn from_yacc(yacc: &str) -> Result { + let grm = parse_yacc(yacc)?; // TIME: all these annotation are for native release x86 build for C grammar // TIME: 27ms let (sgraph, stable) = match from_yacc(&grm, Minimiser::Pager) { @@ -143,14 +155,7 @@ impl CfgParser { let patterns = pat_idx_to_tidx .iter() - .map(|tok| { - let name = grm.token_name(*tok).unwrap(); - if is_rx(name) { - name[1..name.len() - 1].to_string() - } else { - quote_rx(name) - } - }) + .map(|tok| parse_rx_token(grm.token_name(*tok).unwrap())) .collect::>(); let mut tidx_to_pat_idx = FxHashMap::default(); diff --git a/controllers/aici_abi/src/earley.rs b/controllers/aici_abi/src/earley.rs new file mode 100644 index 00000000..bf094cf6 --- /dev/null +++ b/controllers/aici_abi/src/earley.rs @@ -0,0 +1,336 @@ +use std::{fmt::Debug, rc::Rc, vec}; + +use rustc_hash::FxHashMap; + +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +pub struct SymIdx(u32); + +// format: +// symbol_index : rule_index +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +pub struct RuleIdx { + data: u32, +} + +const SYM_IDX_BITS: u32 = 12; +const RULE_IDX_BITS: u32 = 10; +const DOT_POS_BITS: u32 = 7; +const TOK_POS_BITS: u32 = 64 - (DOT_POS_BITS + SYM_IDX_BITS + RULE_IDX_BITS); + +fn mask32(bits: u32) -> u32 { + (1 << bits) - 1 +} + +impl RuleIdx { + fn sym_idx(&self) -> SymIdx { + SymIdx(self.data >> RULE_IDX_BITS) + } + + fn sym_rule_idx(&self) -> usize { + (self.data & mask32(RULE_IDX_BITS)) as usize + } +} + +impl SymIdx { + fn rule_at(&self, rule: usize) -> RuleIdx { + assert!(rule < mask32(RULE_IDX_BITS) as usize); + RuleIdx { + data: (self.0 << RULE_IDX_BITS) | rule as u32, + } + } +} + +struct Symbol { + idx: SymIdx, + name: String, + rx: Option, + rules: Vec, + nullable: bool, +} 
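// [Editor's note: illustrative sketch, not part of the original patch.]
// The indices above are bit-packed: `RuleIdx` stores a 12-bit symbol index
// next to a 10-bit per-symbol rule index, and `Item` (below) additionally
// packs a 7-bit dot position and the start row into one u64. A round-trip
// check under those assumptions:
#[cfg(test)]
mod packing_sketch {
    use super::*;

    #[test]
    fn rule_idx_roundtrip() {
        let sym = SymIdx(5);
        let rule = sym.rule_at(3);
        // Both components survive the packing unchanged.
        assert_eq!(rule.sym_idx(), sym);
        assert_eq!(rule.sym_rule_idx(), 3);
    }
}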
+ +struct Rule { + idx: RuleIdx, + rhs: Vec, +} + +impl Rule { + fn lhs(&self) -> SymIdx { + self.idx.sym_idx() + } +} + +pub struct Grammar { + symbols: Vec, + symbol_by_name: FxHashMap, +} + +// format: +// token_position : dot_position : symbol_index : rule_index +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +struct Item(u64); + +pub struct Row { + token: SymIdx, + // TODO index this by .after_dot() ? + items: Vec, +} + +impl Row { + +} + +impl Item { + fn new(rule: RuleIdx, dot: usize, start: usize) -> Self { + assert!(start < mask32(TOK_POS_BITS) as usize); + assert!(dot < mask32(DOT_POS_BITS) as usize); + let data = (start as u64) << (DOT_POS_BITS + SYM_IDX_BITS + RULE_IDX_BITS) + | (dot as u64) << (SYM_IDX_BITS + RULE_IDX_BITS) + | (rule.data as u64); + Item(data) + } + + fn rule_idx(&self) -> RuleIdx { + RuleIdx { + data: self.0 as u32 & mask32(SYM_IDX_BITS + RULE_IDX_BITS), + } + } + + fn dot_pos(&self) -> usize { + (self.0 >> (SYM_IDX_BITS + RULE_IDX_BITS)) as usize & mask32(DOT_POS_BITS) as usize + } + + fn start_pos(&self) -> usize { + (self.0 >> (DOT_POS_BITS + SYM_IDX_BITS + RULE_IDX_BITS)) as usize + & mask32(TOK_POS_BITS) as usize + } + + fn advance_dot(&self) -> Self { + Item::new(self.rule_idx(), self.dot_pos() + 1, self.start_pos()) + } +} + +pub struct Parser { + grammar: Rc, + rows: Vec, +} + +impl Parser { + pub fn new(grammar: Rc) -> Self { + let start = grammar.start(); + let init_rules = grammar + .sym_data(start) + .rules + .iter() + .map(|r| Item::new(r.idx, 0, 0)) + .collect(); + let mut r = Parser { + grammar, + rows: vec![], + }; + // 'start' token is bogus + r.push_row(init_rules, start); + r + } + + fn after_dot(&self, item: Item) -> Option { + let rule = self.grammar.rule_data(item.rule_idx()); + if item.dot_pos() < rule.rhs.len() { + Some(rule.rhs[item.dot_pos()]) + } else { + None + } + } + + fn scan(&mut self, token: SymIdx) { + let next_row = self.items_with_after_dot(token, self.rows.len() - 1); + self.push_row(next_row, token); + } + + fn items_with_after_dot(&self, sym: SymIdx, row_idx: usize) -> Vec { + let mut r = vec![]; + for item in &self.rows[row_idx].items { + if self.after_dot(*item) == Some(sym) { + r.push(*item); + } + } + r + } + + fn push_row(&mut self, mut curr_row: Vec, token: SymIdx) { + let curr_idx = self.rows.len(); + let mut agenda = curr_row.clone(); + let mut predicated_syms = vec![]; + + while !agenda.is_empty() { + let item = agenda.pop().unwrap(); + let lhs = item.rule_idx().sym_idx(); + let mut to_add = vec![]; + match self.after_dot(item) { + Some(after_dot) => { + let sym_data = self.grammar.sym_data(after_dot); + if sym_data.nullable { + let new_item = item.advance_dot(); + if !to_add.contains(&new_item) { + to_add.push(new_item); + } + } + if !predicated_syms.contains(&after_dot) { + predicated_syms.push(after_dot); + for rule in &sym_data.rules { + let new_item = Item::new(rule.idx, 0, curr_idx); + if !to_add.contains(&new_item) { + to_add.push(new_item); + } + } + } + } + // complete + None => { + if item.start_pos() < curr_idx { + // if item.start_pos() == curr_idx, then we handled it above in the nullable check + for parent in self.items_with_after_dot(lhs, item.start_pos()) { + let new_item = parent.advance_dot(); + if !to_add.contains(&new_item) { + to_add.push(new_item); + } + } + } + } + } + + for new_item in to_add { + if !curr_row.contains(&new_item) { + curr_row.push(new_item); + agenda.push(new_item); + } + } + } + + self.rows.push(Row { + token, + items: curr_row, + }); + } +} + +impl Grammar { + pub fn 
new() -> Self { + let mut r = Grammar { + symbols: vec![], + symbol_by_name: FxHashMap::default(), + }; + let _ = r.symbol("_start"); + r + } + + pub fn start(&self) -> SymIdx { + self.symbols[0].idx + } + + fn sym_data(&self, sym: SymIdx) -> &Symbol { + &self.symbols[sym.0 as usize] + } + + fn sym_data_mut(&mut self, sym: SymIdx) -> &mut Symbol { + &mut self.symbols[sym.0 as usize] + } + + fn rule_data(&self, rule: RuleIdx) -> &Rule { + let sym = self.sym_data(rule.sym_idx()); + &sym.rules[rule.sym_rule_idx()] + } + + fn propagate_nullable(&mut self) { + loop { + let mut to_null = vec![]; + for sym in self.symbols.iter() { + for rule in sym.rules.iter() { + if rule.rhs.iter().all(|s| self.sym_data(*s).nullable) { + if !sym.nullable { + to_null.push(sym.idx); + } + } + } + } + if to_null.is_empty() { + break; + } + for sym in to_null { + self.sym_data_mut(sym).nullable = true; + } + } + } + + pub fn add_rule(&mut self, lhs: SymIdx, rhs: Vec) { + assert!(rhs.len() < mask32(DOT_POS_BITS) as usize); + + let is_nullable = rhs.iter().all(|s| self.sym_data(*s).nullable); + + if rhs.len() > 0 { + let sym = self.sym_data_mut(lhs); + sym.rules.push(Rule { + idx: lhs.rule_at(sym.rules.len()), + rhs, + }); + } + + if is_nullable { + self.sym_data_mut(lhs).nullable = true; + self.propagate_nullable(); + } + } + + pub fn make_terminal(&mut self, sym: SymIdx, rx: &str) { + self.symbols[sym.0 as usize].rx = Some(rx.to_string()); + } + + pub fn sym_name(&self, sym: SymIdx) -> &str { + &self.symbols[sym.0 as usize].name + } + + fn rule_to_string(&self, rule: &Rule) -> String { + let lhs = self.sym_name(rule.lhs()); + let rhs = rule + .rhs + .iter() + .map(|s| self.sym_name(*s)) + .collect::>() + .join(" "); + format!("{} ::= {}", lhs, rhs) + } + + pub fn symbol(&mut self, name: &str) -> SymIdx { + match self.symbol_by_name.get(name) { + Some(idx) => *idx, + None => { + let idx = SymIdx(self.symbols.len() as u32); + self.symbols.push(Symbol { + name: name.to_string(), + rx: None, + idx, + rules: vec![], + nullable: false, + }); + self.symbol_by_name.insert(name.to_string(), idx); + idx + } + } + } +} + +impl Debug for Grammar { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + for sym in &self.symbols { + match sym.rx { + Some(ref rx) => writeln!(f, "{} /= {:?}", sym.name, rx)?, + None => {} + } + } + for sym in &self.symbols { + for rule in &sym.rules { + writeln!(f, "{}", self.rule_to_string(rule))?; + } + } + Ok(()) + } +} diff --git a/controllers/aici_abi/src/earley_yacc.rs b/controllers/aici_abi/src/earley_yacc.rs new file mode 100644 index 00000000..f95e80d2 --- /dev/null +++ b/controllers/aici_abi/src/earley_yacc.rs @@ -0,0 +1,90 @@ +use anyhow::Result; +use cfgrammar::Symbol; + +use crate::{ + cfg::{parse_rx_token, parse_yacc}, + earley::Grammar, + toktree::TokTrie, +}; + +pub fn earley_grm_from_yacc(yacc: &str) -> Result { + let grm = parse_yacc(yacc)?; + + let mut res = Grammar::new(); + + for pidx in grm.iter_pidxs() { + let ridx = grm.prod_to_rule(pidx); + + let lhs = res.symbol(grm.rule_name_str(ridx)); + let rhs = grm + .prod(pidx) + .iter() + .map(|sym| match sym { + Symbol::Token(tidx) => { + let name = grm.token_name(*tidx).unwrap(); + let t = res.symbol(name); + res.make_terminal(t, &parse_rx_token(name)); + t + } + Symbol::Rule(ridx) => res.symbol(grm.rule_name_str(*ridx)), + }) + .collect(); + + res.add_rule(lhs, rhs); + } + + let start_sym = grm.rule_name_str(grm.start_rule_idx()); + println!("start_sym: {:?}", start_sym); + let ss = res.symbol(start_sym); + 
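// [Editor's note: hedged usage sketch in a comment, not part of the
// original patch.] The loop above turned every yacc production into one
// Earley rule; the next line then ties the built-in `_start` symbol to the
// grammar's start rule. Building a small grammar by hand with the same API
// would look like:
//
//     let mut g = Grammar::new();
//     let s = g.symbol("S");
//     let a = g.symbol("A");
//     g.make_terminal(a, "a");           // terminal matching regex "a"
//     g.add_rule(s, vec![a, s]);         // S ::= A S
//     g.add_rule(s, vec![a]);            // S ::= A
//     g.add_rule(g.start(), vec![s]);    // _start ::= S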
res.add_rule(res.start(), vec![ss]); + + Ok(res) +} + +#[allow(dead_code)] +pub fn earley_test(trie: TokTrie) { + let yacc_bytes = include_bytes!("../grammars/c.y"); + let cfg = earley_grm_from_yacc(&String::from_utf8_lossy(yacc_bytes)).unwrap(); + + println!("cfg: {:?}", cfg); + + let sample = include_bytes!("../grammars/sample.c"); + let toks = trie.greedy_tokenize(sample); + + println!("toks: {:?}", toks.len()); + + // #[cfg(not(target_arch = "wasm32"))] + // let t0 = std::time::Instant::now(); + + // let mut line = 1; + // let mut vob = trie.alloc_token_set(); + + // for tok in &toks[0..1000] { + // let tok = *tok; + // trie.compute_bias(&mut cfg, &mut vob); + // if !vob.is_allowed(tok) { + // println!("reject, line={}, tok={:?}", line, trie.token_str(tok)); + // panic!(); + // } + // for b in trie.token(tok) { + // if *b == b'\n' { + // line += 1; + // } + // } + // if false { + // println!( + // "tok: {:?} {}; {}", + // trie.token_str(tok), + // vob.is_allowed(tok), + // cfg.get_stats() + // ); + // cfg.viable_now(); + // } + // trie.append_token(&mut cfg, tok); + // } + + // #[cfg(not(target_arch = "wasm32"))] + // println!("time: {:?} ", t0.elapsed()); + + // println!("stats: {}", cfg.get_stats()); +} diff --git a/controllers/aici_abi/src/lib.rs b/controllers/aici_abi/src/lib.rs index ae6a418f..7a98308a 100644 --- a/controllers/aici_abi/src/lib.rs +++ b/controllers/aici_abi/src/lib.rs @@ -16,6 +16,11 @@ mod lex; #[cfg(feature = "rx")] pub mod rx; +#[cfg(feature = "earley")] +pub mod earley; +#[cfg(all(feature = "earley", feature = "cfg"))] +pub mod earley_yacc; + pub mod substring; pub type TokenId = bytes::TokenId; From e5688f4844a9fa5f77e00cf1c28149780364de85 Mon Sep 17 00:00:00 2001 From: Michal Moskal Date: Fri, 1 Mar 2024 17:53:10 +0000 Subject: [PATCH 149/301] working on guidance deserialization --- controllers/aici_abi/Cargo.toml | 5 +- controllers/aici_abi/src/earley.rs | 55 ++- controllers/aici_abi/src/earley_yacc.rs | 65 +++- controllers/aici_abi/src/guidance.rs | 455 ++++++++++++++++++++++++ controllers/aici_abi/src/lib.rs | 1 + 5 files changed, 565 insertions(+), 16 deletions(-) create mode 100644 controllers/aici_abi/src/guidance.rs diff --git a/controllers/aici_abi/Cargo.toml b/controllers/aici_abi/Cargo.toml index 3e2a3acd..ac8cf6eb 100644 --- a/controllers/aici_abi/Cargo.toml +++ b/controllers/aici_abi/Cargo.toml @@ -18,13 +18,14 @@ lrpar = { version = "0.13.3", optional = true } lrtable = { version = "0.13.3", optional = true } vob = { version = "3.0.3", optional = true } rustc-hash = { version = "1.1.0", optional = true } +quick-protobuf = { version = "0.8.1", optional = true } [features] default = ["cfg", "rx", "earley"] cfg = ["dep:cfgrammar", "dep:lrlex", "dep:lrpar", "dep:lrtable", "dep:vob", "dep:rustc-hash"] rx = ["dep:regex-automata"] -earley = ["rx", "dep:vob", "dep:rustc-hash"] +earley = ["rx", "dep:vob", "dep:rustc-hash", "dep:quick-protobuf"] [[bin]] name = "yesno" -path = "src/yesno.rs" \ No newline at end of file +path = "src/yesno.rs" diff --git a/controllers/aici_abi/src/earley.rs b/controllers/aici_abi/src/earley.rs index bf094cf6..1f4c22a4 100644 --- a/controllers/aici_abi/src/earley.rs +++ b/controllers/aici_abi/src/earley.rs @@ -40,10 +40,41 @@ impl SymIdx { } } +pub struct ByteSet { + mask: [u32; 8], +} + +impl ByteSet { + pub fn new() -> Self { + ByteSet { mask: [0; 8] } + } + + pub fn add(&mut self, byte: u8) { + let idx = byte as usize / 32; + let bit = byte as usize % 32; + self.mask[idx] |= 1 << bit; + } + + pub fn contains(&self, byte: 
u8) -> bool { + let idx = byte as usize / 32; + let bit = byte as usize % 32; + self.mask[idx] & (1 << bit) != 0 + } + + pub fn from_range(start: u8, end: u8) -> Self { + let mut r = ByteSet::new(); + // TODO optimize + for b in start..=end { + r.add(b); + } + r + } +} + struct Symbol { idx: SymIdx, name: String, - rx: Option, + bytes: Option, rules: Vec, nullable: bool, } @@ -75,9 +106,7 @@ pub struct Row { items: Vec, } -impl Row { - -} +impl Row {} impl Item { fn new(rule: RuleIdx, dot: usize, start: usize) -> Self { @@ -280,8 +309,8 @@ impl Grammar { } } - pub fn make_terminal(&mut self, sym: SymIdx, rx: &str) { - self.symbols[sym.0 as usize].rx = Some(rx.to_string()); + pub fn make_terminal(&mut self, sym: SymIdx, bytes: ByteSet) { + self.symbols[sym.0 as usize].bytes = Some(bytes); } pub fn sym_name(&self, sym: SymIdx) -> &str { @@ -306,7 +335,7 @@ impl Grammar { let idx = SymIdx(self.symbols.len() as u32); self.symbols.push(Symbol { name: name.to_string(), - rx: None, + bytes: None, idx, rules: vec![], nullable: false, @@ -320,12 +349,12 @@ impl Grammar { impl Debug for Grammar { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - for sym in &self.symbols { - match sym.rx { - Some(ref rx) => writeln!(f, "{} /= {:?}", sym.name, rx)?, - None => {} - } - } + // for sym in &self.symbols { + // match sym.bytes { + // Some(ref rx) => writeln!(f, "{} /= {:?}", sym.name, rx)?, + // None => {} + // } + // } for sym in &self.symbols { for rule in &sym.rules { writeln!(f, "{}", self.rule_to_string(rule))?; diff --git a/controllers/aici_abi/src/earley_yacc.rs b/controllers/aici_abi/src/earley_yacc.rs index f95e80d2..e3bb5793 100644 --- a/controllers/aici_abi/src/earley_yacc.rs +++ b/controllers/aici_abi/src/earley_yacc.rs @@ -1,9 +1,12 @@ use anyhow::Result; use cfgrammar::Symbol; +use quick_protobuf::MessageRead; +use rustc_hash::FxHashSet; use crate::{ cfg::{parse_rx_token, parse_yacc}, - earley::Grammar, + earley::{ByteSet, Grammar}, + guidance, toktree::TokTrie, }; @@ -41,6 +44,66 @@ pub fn earley_grm_from_yacc(yacc: &str) -> Result { Ok(res) } +pub fn earley_grm_from_guidance(bytes: &[u8]) -> Result { + let mut reader = quick_protobuf::BytesReader::from_bytes(bytes); + let gg = guidance::Grammar::from_reader(&mut reader, bytes).unwrap(); + let mut grm = Grammar::new(); + + let symbols = gg + .nodes + .iter() + .map(|n| match &n.function_type { + guidance::mod_GrammarFunction::OneOffunction_type::join(n) => grm.symbol(&n.name), + guidance::mod_GrammarFunction::OneOffunction_type::select(n) => grm.symbol(&n.name), + guidance::mod_GrammarFunction::OneOffunction_type::byte(n) => { + assert!(n.byte.len() == 1); + let sym = grm.symbol(&format!("b'{}", n.byte[0])); + grm.make_terminal(sym, ByteSet::from_range(n.byte[0], n.byte[0])); + sym + } + guidance::mod_GrammarFunction::OneOffunction_type::byte_range(n) => { + assert!(n.byte_range.len() == 2); + let sym = grm.symbol(&format!("b'{}-{}", n.byte_range[0], n.byte_range[1])); + grm.make_terminal(sym, ByteSet::from_range(n.byte_range[0], n.byte_range[1])); + sym + } + guidance::mod_GrammarFunction::OneOffunction_type::model_variable(n) => { + grm.symbol(&n.name) + } + guidance::mod_GrammarFunction::OneOffunction_type::None => { + panic!("None function type in guidance::Grammar") + } + }) + .collect::>(); + + let set = FxHashSet::from_iter(symbols.iter()); + assert!(set.len() == symbols.len(), "duplicate symbols"); + + for (n, sym) in gg.nodes.iter().zip(symbols.iter()) { + let lhs = *sym; + match &n.function_type { + 
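// [Editor's note: commentary, not part of the original patch.]
// Deserialization is two-pass: the first pass above allocated one grammar
// symbol per protobuf node (byte / byte_range nodes became terminals via
// ByteSet::from_range), and this second pass adds the rules:
//     join(v0, v1, ..)    =>  lhs ::= v0 v1 ..        (a single rule)
//     select(v0, v1, ..)  =>  lhs ::= v0 | v1 | ..    (one rule per value)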
guidance::mod_GrammarFunction::OneOffunction_type::join(n) => { + let rhs = n.values.iter().map(|idx| symbols[*idx as usize]).collect(); + grm.add_rule(lhs, rhs); + } + guidance::mod_GrammarFunction::OneOffunction_type::select(n) => { + for v in &n.values { + grm.add_rule(lhs, vec![symbols[*v as usize]]); + } + } + guidance::mod_GrammarFunction::OneOffunction_type::byte(_) => {} + guidance::mod_GrammarFunction::OneOffunction_type::byte_range(_) => {} + guidance::mod_GrammarFunction::OneOffunction_type::model_variable(n) => { + // eos_token, bos_token etc + panic!("model_variable not implemented yet ({:?})", n.name); + } + guidance::mod_GrammarFunction::OneOffunction_type::None => panic!("???"), + } + } + + Ok(grm) +} + #[allow(dead_code)] pub fn earley_test(trie: TokTrie) { let yacc_bytes = include_bytes!("../grammars/c.y"); diff --git a/controllers/aici_abi/src/guidance.rs b/controllers/aici_abi/src/guidance.rs new file mode 100644 index 00000000..24273c82 --- /dev/null +++ b/controllers/aici_abi/src/guidance.rs @@ -0,0 +1,455 @@ +// Automatically generated rust module for '_serialization.proto' file +// pb-rs _serialization.proto + +#![allow(non_snake_case)] +#![allow(non_upper_case_globals)] +#![allow(non_camel_case_types)] +#![allow(unused_imports)] +#![allow(unknown_lints)] +#![allow(clippy::all)] +#![cfg_attr(rustfmt, rustfmt_skip)] + + +use std::borrow::Cow; +use std::collections::HashMap; +type KVMap = HashMap; +use quick_protobuf::{MessageInfo, MessageRead, MessageWrite, BytesReader, Writer, WriterBackend, Result}; +use quick_protobuf::sizeofs::*; +use super::*; + +#[allow(clippy::derive_partial_eq_without_eq)] +#[derive(Debug, Default, PartialEq, Clone)] +pub struct Grammar<'a> { + pub nodes: Vec>, +} + +impl<'a> MessageRead<'a> for Grammar<'a> { + fn from_reader(r: &mut BytesReader, bytes: &'a [u8]) -> Result { + let mut msg = Self::default(); + while !r.is_eof() { + match r.next_tag(bytes) { + Ok(10) => msg.nodes.push(r.read_message::(bytes)?), + Ok(t) => { r.read_unknown(bytes, t)?; } + Err(e) => return Err(e), + } + } + Ok(msg) + } +} + +impl<'a> MessageWrite for Grammar<'a> { + fn get_size(&self) -> usize { + 0 + + self.nodes.iter().map(|s| 1 + sizeof_len((s).get_size())).sum::() + } + + fn write_message(&self, w: &mut Writer) -> Result<()> { + for s in &self.nodes { w.write_with_tag(10, |w| w.write_message(s))?; } + Ok(()) + } +} + +#[allow(clippy::derive_partial_eq_without_eq)] +#[derive(Debug, Default, PartialEq, Clone)] +pub struct EngineCallResponse<'a> { + pub new_bytes: Cow<'a, [u8]>, + pub is_generated: bool, + pub new_bytes_prob: f32, + pub capture_groups: KVMap, Cow<'a, str>>, + pub capture_group_log_probs: KVMap, f32>, + pub new_token_count: i32, +} + +impl<'a> MessageRead<'a> for EngineCallResponse<'a> { + fn from_reader(r: &mut BytesReader, bytes: &'a [u8]) -> Result { + let mut msg = Self::default(); + while !r.is_eof() { + match r.next_tag(bytes) { + Ok(10) => msg.new_bytes = r.read_bytes(bytes).map(Cow::Borrowed)?, + Ok(16) => msg.is_generated = r.read_bool(bytes)?, + Ok(29) => msg.new_bytes_prob = r.read_float(bytes)?, + Ok(34) => { + let (key, value) = r.read_map(bytes, |r, bytes| Ok(r.read_string(bytes).map(Cow::Borrowed)?), |r, bytes| Ok(r.read_string(bytes).map(Cow::Borrowed)?))?; + msg.capture_groups.insert(key, value); + } + Ok(42) => { + let (key, value) = r.read_map(bytes, |r, bytes| Ok(r.read_string(bytes).map(Cow::Borrowed)?), |r, bytes| Ok(r.read_float(bytes)?))?; + msg.capture_group_log_probs.insert(key, value); + } + Ok(48) => 
msg.new_token_count = r.read_int32(bytes)?, + Ok(t) => { r.read_unknown(bytes, t)?; } + Err(e) => return Err(e), + } + } + Ok(msg) + } +} + +impl<'a> MessageWrite for EngineCallResponse<'a> { + fn get_size(&self) -> usize { + 0 + + if self.new_bytes == Cow::Borrowed(b"") { 0 } else { 1 + sizeof_len((&self.new_bytes).len()) } + + if self.is_generated == false { 0 } else { 1 + sizeof_varint(*(&self.is_generated) as u64) } + + if self.new_bytes_prob == 0f32 { 0 } else { 1 + 4 } + + self.capture_groups.iter().map(|(k, v)| 1 + sizeof_len(2 + sizeof_len((k).len()) + sizeof_len((v).len()))).sum::() + + self.capture_group_log_probs.iter().map(|(k, v)| 1 + sizeof_len(2 + sizeof_len((k).len()) + 4)).sum::() + + if self.new_token_count == 0i32 { 0 } else { 1 + sizeof_varint(*(&self.new_token_count) as u64) } + } + + fn write_message(&self, w: &mut Writer) -> Result<()> { + if self.new_bytes != Cow::Borrowed(b"") { w.write_with_tag(10, |w| w.write_bytes(&**&self.new_bytes))?; } + if self.is_generated != false { w.write_with_tag(16, |w| w.write_bool(*&self.is_generated))?; } + if self.new_bytes_prob != 0f32 { w.write_with_tag(29, |w| w.write_float(*&self.new_bytes_prob))?; } + for (k, v) in self.capture_groups.iter() { w.write_with_tag(34, |w| w.write_map(2 + sizeof_len((k).len()) + sizeof_len((v).len()), 10, |w| w.write_string(&**k), 18, |w| w.write_string(&**v)))?; } + for (k, v) in self.capture_group_log_probs.iter() { w.write_with_tag(42, |w| w.write_map(2 + sizeof_len((k).len()) + 4, 10, |w| w.write_string(&**k), 21, |w| w.write_float(*v)))?; } + if self.new_token_count != 0i32 { w.write_with_tag(48, |w| w.write_int32(*&self.new_token_count))?; } + Ok(()) + } +} + +#[allow(clippy::derive_partial_eq_without_eq)] +#[derive(Debug, Default, PartialEq, Clone)] +pub struct Byte<'a> { + pub byte: Cow<'a, [u8]>, + pub hidden: bool, + pub commit_point: bool, + pub nullable: bool, + pub capture_name: Cow<'a, str>, + pub temperature: f32, +} + +impl<'a> MessageRead<'a> for Byte<'a> { + fn from_reader(r: &mut BytesReader, bytes: &'a [u8]) -> Result { + let mut msg = Self::default(); + while !r.is_eof() { + match r.next_tag(bytes) { + Ok(10) => msg.byte = r.read_bytes(bytes).map(Cow::Borrowed)?, + Ok(16) => msg.hidden = r.read_bool(bytes)?, + Ok(24) => msg.commit_point = r.read_bool(bytes)?, + Ok(32) => msg.nullable = r.read_bool(bytes)?, + Ok(42) => msg.capture_name = r.read_string(bytes).map(Cow::Borrowed)?, + Ok(53) => msg.temperature = r.read_float(bytes)?, + Ok(t) => { r.read_unknown(bytes, t)?; } + Err(e) => return Err(e), + } + } + Ok(msg) + } +} + +impl<'a> MessageWrite for Byte<'a> { + fn get_size(&self) -> usize { + 0 + + if self.byte == Cow::Borrowed(b"") { 0 } else { 1 + sizeof_len((&self.byte).len()) } + + if self.hidden == false { 0 } else { 1 + sizeof_varint(*(&self.hidden) as u64) } + + if self.commit_point == false { 0 } else { 1 + sizeof_varint(*(&self.commit_point) as u64) } + + if self.nullable == false { 0 } else { 1 + sizeof_varint(*(&self.nullable) as u64) } + + if self.capture_name == "" { 0 } else { 1 + sizeof_len((&self.capture_name).len()) } + + if self.temperature == 0f32 { 0 } else { 1 + 4 } + } + + fn write_message(&self, w: &mut Writer) -> Result<()> { + if self.byte != Cow::Borrowed(b"") { w.write_with_tag(10, |w| w.write_bytes(&**&self.byte))?; } + if self.hidden != false { w.write_with_tag(16, |w| w.write_bool(*&self.hidden))?; } + if self.commit_point != false { w.write_with_tag(24, |w| w.write_bool(*&self.commit_point))?; } + if self.nullable != false { w.write_with_tag(32, 
|w| w.write_bool(*&self.nullable))?; } + if self.capture_name != "" { w.write_with_tag(42, |w| w.write_string(&**&self.capture_name))?; } + if self.temperature != 0f32 { w.write_with_tag(53, |w| w.write_float(*&self.temperature))?; } + Ok(()) + } +} + +#[allow(clippy::derive_partial_eq_without_eq)] +#[derive(Debug, Default, PartialEq, Clone)] +pub struct ByteRange<'a> { + pub byte_range: Cow<'a, [u8]>, + pub hidden: bool, + pub commit_point: bool, + pub capture_name: Cow<'a, str>, + pub temperature: f32, +} + +impl<'a> MessageRead<'a> for ByteRange<'a> { + fn from_reader(r: &mut BytesReader, bytes: &'a [u8]) -> Result { + let mut msg = Self::default(); + while !r.is_eof() { + match r.next_tag(bytes) { + Ok(10) => msg.byte_range = r.read_bytes(bytes).map(Cow::Borrowed)?, + Ok(24) => msg.hidden = r.read_bool(bytes)?, + Ok(32) => msg.commit_point = r.read_bool(bytes)?, + Ok(42) => msg.capture_name = r.read_string(bytes).map(Cow::Borrowed)?, + Ok(53) => msg.temperature = r.read_float(bytes)?, + Ok(t) => { r.read_unknown(bytes, t)?; } + Err(e) => return Err(e), + } + } + Ok(msg) + } +} + +impl<'a> MessageWrite for ByteRange<'a> { + fn get_size(&self) -> usize { + 0 + + if self.byte_range == Cow::Borrowed(b"") { 0 } else { 1 + sizeof_len((&self.byte_range).len()) } + + if self.hidden == false { 0 } else { 1 + sizeof_varint(*(&self.hidden) as u64) } + + if self.commit_point == false { 0 } else { 1 + sizeof_varint(*(&self.commit_point) as u64) } + + if self.capture_name == "" { 0 } else { 1 + sizeof_len((&self.capture_name).len()) } + + if self.temperature == 0f32 { 0 } else { 1 + 4 } + } + + fn write_message(&self, w: &mut Writer) -> Result<()> { + if self.byte_range != Cow::Borrowed(b"") { w.write_with_tag(10, |w| w.write_bytes(&**&self.byte_range))?; } + if self.hidden != false { w.write_with_tag(24, |w| w.write_bool(*&self.hidden))?; } + if self.commit_point != false { w.write_with_tag(32, |w| w.write_bool(*&self.commit_point))?; } + if self.capture_name != "" { w.write_with_tag(42, |w| w.write_string(&**&self.capture_name))?; } + if self.temperature != 0f32 { w.write_with_tag(53, |w| w.write_float(*&self.temperature))?; } + Ok(()) + } +} + +#[allow(clippy::derive_partial_eq_without_eq)] +#[derive(Debug, Default, PartialEq, Clone)] +pub struct Null { } + +impl<'a> MessageRead<'a> for Null { + fn from_reader(r: &mut BytesReader, _: &[u8]) -> Result { + r.read_to_end(); + Ok(Self::default()) + } +} + +impl MessageWrite for Null { } + +#[allow(clippy::derive_partial_eq_without_eq)] +#[derive(Debug, Default, PartialEq, Clone)] +pub struct ModelVariable<'a> { + pub name: Cow<'a, str>, + pub hidden: bool, + pub commit_point: bool, + pub capture_name: Cow<'a, str>, + pub nullable: bool, +} + +impl<'a> MessageRead<'a> for ModelVariable<'a> { + fn from_reader(r: &mut BytesReader, bytes: &'a [u8]) -> Result { + let mut msg = Self::default(); + while !r.is_eof() { + match r.next_tag(bytes) { + Ok(10) => msg.name = r.read_string(bytes).map(Cow::Borrowed)?, + Ok(16) => msg.hidden = r.read_bool(bytes)?, + Ok(24) => msg.commit_point = r.read_bool(bytes)?, + Ok(34) => msg.capture_name = r.read_string(bytes).map(Cow::Borrowed)?, + Ok(40) => msg.nullable = r.read_bool(bytes)?, + Ok(t) => { r.read_unknown(bytes, t)?; } + Err(e) => return Err(e), + } + } + Ok(msg) + } +} + +impl<'a> MessageWrite for ModelVariable<'a> { + fn get_size(&self) -> usize { + 0 + + if self.name == "" { 0 } else { 1 + sizeof_len((&self.name).len()) } + + if self.hidden == false { 0 } else { 1 + sizeof_varint(*(&self.hidden) as u64) } + + 
if self.commit_point == false { 0 } else { 1 + sizeof_varint(*(&self.commit_point) as u64) } + + if self.capture_name == "" { 0 } else { 1 + sizeof_len((&self.capture_name).len()) } + + if self.nullable == false { 0 } else { 1 + sizeof_varint(*(&self.nullable) as u64) } + } + + fn write_message(&self, w: &mut Writer) -> Result<()> { + if self.name != "" { w.write_with_tag(10, |w| w.write_string(&**&self.name))?; } + if self.hidden != false { w.write_with_tag(16, |w| w.write_bool(*&self.hidden))?; } + if self.commit_point != false { w.write_with_tag(24, |w| w.write_bool(*&self.commit_point))?; } + if self.capture_name != "" { w.write_with_tag(34, |w| w.write_string(&**&self.capture_name))?; } + if self.nullable != false { w.write_with_tag(40, |w| w.write_bool(*&self.nullable))?; } + Ok(()) + } +} + +#[allow(clippy::derive_partial_eq_without_eq)] +#[derive(Debug, Default, PartialEq, Clone)] +pub struct Join<'a> { + pub nullable: bool, + pub values: Vec, + pub name: Cow<'a, str>, + pub hidden: bool, + pub commit_point: bool, + pub capture_name: Cow<'a, str>, + pub max_tokens: i32, +} + +impl<'a> MessageRead<'a> for Join<'a> { + fn from_reader(r: &mut BytesReader, bytes: &'a [u8]) -> Result { + let mut msg = Self::default(); + while !r.is_eof() { + match r.next_tag(bytes) { + Ok(8) => msg.nullable = r.read_bool(bytes)?, + Ok(18) => msg.values = r.read_packed(bytes, |r, bytes| Ok(r.read_int32(bytes)?))?, + Ok(26) => msg.name = r.read_string(bytes).map(Cow::Borrowed)?, + Ok(32) => msg.hidden = r.read_bool(bytes)?, + Ok(40) => msg.commit_point = r.read_bool(bytes)?, + Ok(50) => msg.capture_name = r.read_string(bytes).map(Cow::Borrowed)?, + Ok(56) => msg.max_tokens = r.read_int32(bytes)?, + Ok(t) => { r.read_unknown(bytes, t)?; } + Err(e) => return Err(e), + } + } + Ok(msg) + } +} + +impl<'a> MessageWrite for Join<'a> { + fn get_size(&self) -> usize { + 0 + + if self.nullable == false { 0 } else { 1 + sizeof_varint(*(&self.nullable) as u64) } + + if self.values.is_empty() { 0 } else { 1 + sizeof_len(self.values.iter().map(|s| sizeof_varint(*(s) as u64)).sum::()) } + + if self.name == "" { 0 } else { 1 + sizeof_len((&self.name).len()) } + + if self.hidden == false { 0 } else { 1 + sizeof_varint(*(&self.hidden) as u64) } + + if self.commit_point == false { 0 } else { 1 + sizeof_varint(*(&self.commit_point) as u64) } + + if self.capture_name == "" { 0 } else { 1 + sizeof_len((&self.capture_name).len()) } + + if self.max_tokens == 0i32 { 0 } else { 1 + sizeof_varint(*(&self.max_tokens) as u64) } + } + + fn write_message(&self, w: &mut Writer) -> Result<()> { + if self.nullable != false { w.write_with_tag(8, |w| w.write_bool(*&self.nullable))?; } + w.write_packed_with_tag(18, &self.values, |w, m| w.write_int32(*m), &|m| sizeof_varint(*(m) as u64))?; + if self.name != "" { w.write_with_tag(26, |w| w.write_string(&**&self.name))?; } + if self.hidden != false { w.write_with_tag(32, |w| w.write_bool(*&self.hidden))?; } + if self.commit_point != false { w.write_with_tag(40, |w| w.write_bool(*&self.commit_point))?; } + if self.capture_name != "" { w.write_with_tag(50, |w| w.write_string(&**&self.capture_name))?; } + if self.max_tokens != 0i32 { w.write_with_tag(56, |w| w.write_int32(*&self.max_tokens))?; } + Ok(()) + } +} + +#[allow(clippy::derive_partial_eq_without_eq)] +#[derive(Debug, Default, PartialEq, Clone)] +pub struct Select<'a> { + pub nullable: bool, + pub values: Vec, + pub name: Cow<'a, str>, + pub hidden: bool, + pub commit_point: bool, + pub capture_name: Cow<'a, str>, + pub max_tokens: i32, + 
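// [Editor's note: commentary, not part of the original patch.] The Ok(N)
// tags matched in `from_reader` below follow the protobuf wire format,
// key = (field_number << 3) | wire_type: Ok(8) is field 1 as a varint
// (`nullable`), Ok(18) is field 2 length-delimited (the packed `values`),
// Ok(26) is field 3 (`name`), and so on up to Ok(64) for field 8
// (`recursive`).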
pub recursive: bool, +} + +impl<'a> MessageRead<'a> for Select<'a> { + fn from_reader(r: &mut BytesReader, bytes: &'a [u8]) -> Result { + let mut msg = Self::default(); + while !r.is_eof() { + match r.next_tag(bytes) { + Ok(8) => msg.nullable = r.read_bool(bytes)?, + Ok(18) => msg.values = r.read_packed(bytes, |r, bytes| Ok(r.read_int32(bytes)?))?, + Ok(26) => msg.name = r.read_string(bytes).map(Cow::Borrowed)?, + Ok(32) => msg.hidden = r.read_bool(bytes)?, + Ok(40) => msg.commit_point = r.read_bool(bytes)?, + Ok(50) => msg.capture_name = r.read_string(bytes).map(Cow::Borrowed)?, + Ok(56) => msg.max_tokens = r.read_int32(bytes)?, + Ok(64) => msg.recursive = r.read_bool(bytes)?, + Ok(t) => { r.read_unknown(bytes, t)?; } + Err(e) => return Err(e), + } + } + Ok(msg) + } +} + +impl<'a> MessageWrite for Select<'a> { + fn get_size(&self) -> usize { + 0 + + if self.nullable == false { 0 } else { 1 + sizeof_varint(*(&self.nullable) as u64) } + + if self.values.is_empty() { 0 } else { 1 + sizeof_len(self.values.iter().map(|s| sizeof_varint(*(s) as u64)).sum::()) } + + if self.name == "" { 0 } else { 1 + sizeof_len((&self.name).len()) } + + if self.hidden == false { 0 } else { 1 + sizeof_varint(*(&self.hidden) as u64) } + + if self.commit_point == false { 0 } else { 1 + sizeof_varint(*(&self.commit_point) as u64) } + + if self.capture_name == "" { 0 } else { 1 + sizeof_len((&self.capture_name).len()) } + + if self.max_tokens == 0i32 { 0 } else { 1 + sizeof_varint(*(&self.max_tokens) as u64) } + + if self.recursive == false { 0 } else { 1 + sizeof_varint(*(&self.recursive) as u64) } + } + + fn write_message(&self, w: &mut Writer) -> Result<()> { + if self.nullable != false { w.write_with_tag(8, |w| w.write_bool(*&self.nullable))?; } + w.write_packed_with_tag(18, &self.values, |w, m| w.write_int32(*m), &|m| sizeof_varint(*(m) as u64))?; + if self.name != "" { w.write_with_tag(26, |w| w.write_string(&**&self.name))?; } + if self.hidden != false { w.write_with_tag(32, |w| w.write_bool(*&self.hidden))?; } + if self.commit_point != false { w.write_with_tag(40, |w| w.write_bool(*&self.commit_point))?; } + if self.capture_name != "" { w.write_with_tag(50, |w| w.write_string(&**&self.capture_name))?; } + if self.max_tokens != 0i32 { w.write_with_tag(56, |w| w.write_int32(*&self.max_tokens))?; } + if self.recursive != false { w.write_with_tag(64, |w| w.write_bool(*&self.recursive))?; } + Ok(()) + } +} + +#[allow(clippy::derive_partial_eq_without_eq)] +#[derive(Debug, Default, PartialEq, Clone)] +pub struct GrammarFunction<'a> { + pub function_type: guidance::mod_GrammarFunction::OneOffunction_type<'a>, +} + +impl<'a> MessageRead<'a> for GrammarFunction<'a> { + fn from_reader(r: &mut BytesReader, bytes: &'a [u8]) -> Result { + let mut msg = Self::default(); + while !r.is_eof() { + match r.next_tag(bytes) { + Ok(10) => msg.function_type = guidance::mod_GrammarFunction::OneOffunction_type::join(r.read_message::(bytes)?), + Ok(18) => msg.function_type = guidance::mod_GrammarFunction::OneOffunction_type::select(r.read_message::(bytes)?), + Ok(26) => msg.function_type = guidance::mod_GrammarFunction::OneOffunction_type::byte(r.read_message::(bytes)?), + Ok(34) => msg.function_type = guidance::mod_GrammarFunction::OneOffunction_type::byte_range(r.read_message::(bytes)?), + Ok(42) => msg.function_type = guidance::mod_GrammarFunction::OneOffunction_type::model_variable(r.read_message::(bytes)?), + Ok(t) => { r.read_unknown(bytes, t)?; } + Err(e) => return Err(e), + } + } + Ok(msg) + } +} + +impl<'a> MessageWrite for 
GrammarFunction<'a> { + fn get_size(&self) -> usize { + 0 + + match self.function_type { + guidance::mod_GrammarFunction::OneOffunction_type::join(ref m) => 1 + sizeof_len((m).get_size()), + guidance::mod_GrammarFunction::OneOffunction_type::select(ref m) => 1 + sizeof_len((m).get_size()), + guidance::mod_GrammarFunction::OneOffunction_type::byte(ref m) => 1 + sizeof_len((m).get_size()), + guidance::mod_GrammarFunction::OneOffunction_type::byte_range(ref m) => 1 + sizeof_len((m).get_size()), + guidance::mod_GrammarFunction::OneOffunction_type::model_variable(ref m) => 1 + sizeof_len((m).get_size()), + guidance::mod_GrammarFunction::OneOffunction_type::None => 0, + } } + + fn write_message(&self, w: &mut Writer) -> Result<()> { + match self.function_type { guidance::mod_GrammarFunction::OneOffunction_type::join(ref m) => { w.write_with_tag(10, |w| w.write_message(m))? }, + guidance::mod_GrammarFunction::OneOffunction_type::select(ref m) => { w.write_with_tag(18, |w| w.write_message(m))? }, + guidance::mod_GrammarFunction::OneOffunction_type::byte(ref m) => { w.write_with_tag(26, |w| w.write_message(m))? }, + guidance::mod_GrammarFunction::OneOffunction_type::byte_range(ref m) => { w.write_with_tag(34, |w| w.write_message(m))? }, + guidance::mod_GrammarFunction::OneOffunction_type::model_variable(ref m) => { w.write_with_tag(42, |w| w.write_message(m))? }, + guidance::mod_GrammarFunction::OneOffunction_type::None => {}, + } Ok(()) + } +} + +pub mod mod_GrammarFunction { + +use super::*; + +#[derive(Debug, PartialEq, Clone)] +pub enum OneOffunction_type<'a> { + join(guidance::Join<'a>), + select(guidance::Select<'a>), + byte(guidance::Byte<'a>), + byte_range(guidance::ByteRange<'a>), + model_variable(guidance::ModelVariable<'a>), + None, +} + +impl<'a> Default for OneOffunction_type<'a> { + fn default() -> Self { + OneOffunction_type::None + } +} + +} + diff --git a/controllers/aici_abi/src/lib.rs b/controllers/aici_abi/src/lib.rs index 7a98308a..7d489ac0 100644 --- a/controllers/aici_abi/src/lib.rs +++ b/controllers/aici_abi/src/lib.rs @@ -20,6 +20,7 @@ pub mod rx; pub mod earley; #[cfg(all(feature = "earley", feature = "cfg"))] pub mod earley_yacc; +mod guidance; pub mod substring; From 31e00989fb8bdce85a8a6f142af0992379031455 Mon Sep 17 00:00:00 2001 From: Michal Moskal Date: Fri, 1 Mar 2024 23:18:10 +0000 Subject: [PATCH 150/301] guidance grammar deserialization --- controllers/aici_abi/grammars/json0.guidance | Bin 0 -> 1326 bytes controllers/aici_abi/src/earley.rs | 365 --------- .../from_guidance.rs} | 89 +-- .../aici_abi/src/{ => earley}/guidance.rs | 1 + controllers/aici_abi/src/earley/mod.rs | 5 + controllers/aici_abi/src/earley/parser.rs | 704 ++++++++++++++++++ controllers/aici_abi/src/lib.rs | 3 - 7 files changed, 748 insertions(+), 419 deletions(-) create mode 100644 controllers/aici_abi/grammars/json0.guidance delete mode 100644 controllers/aici_abi/src/earley.rs rename controllers/aici_abi/src/{earley_yacc.rs => earley/from_guidance.rs} (66%) rename controllers/aici_abi/src/{ => earley}/guidance.rs (99%) create mode 100644 controllers/aici_abi/src/earley/mod.rs create mode 100644 controllers/aici_abi/src/earley/parser.rs diff --git a/controllers/aici_abi/grammars/json0.guidance b/controllers/aici_abi/grammars/json0.guidance new file mode 100644 index 0000000000000000000000000000000000000000..bcad296f58677710f08530ebefa086ef3f84dbb6 GIT binary patch literal 1326 zcmZvc+j0^?5Qfvs3S>e!Ob}y>kzmXc2@s=bLZUg~F`jP~%VHHGDxip`SXSkQdzpU(4SAX5}_wM#TDx*@`CFfe19e=QP`0IB~DUK`pR-c$S{HX#QQB!$? 
z+k#VST1VVtF4I@2FG_t*WwjfP#V0xDdxLhr<3Y_pYJ&6I5ceVOK-`d61QSN^Kn~l~ z4PBj1WoC|_IC<*ynfbHlyruKY)s+huFRv%lI-8s2{9K{9uvnsUrKVOlHm@?7>3{rj zw@D>cLOY2hlP{-D6(^O`J`gQ|X{Hwdcb*{BOfLbhtRu`abp!@)SwWU#>H>R?Y?f&s zn1_s+UI8m0%QN+WEhC#_dJU|KtibdJ*gUc#(_3I$$QGFTz-q`AnQ()iA}cWsq3;xDOWQu8w8mmknp>_jnFEM=*Hls4ysYf~Ld#%<0 zSGk1^SC~Fy! u32 { - (1 << bits) - 1 -} - -impl RuleIdx { - fn sym_idx(&self) -> SymIdx { - SymIdx(self.data >> RULE_IDX_BITS) - } - - fn sym_rule_idx(&self) -> usize { - (self.data & mask32(RULE_IDX_BITS)) as usize - } -} - -impl SymIdx { - fn rule_at(&self, rule: usize) -> RuleIdx { - assert!(rule < mask32(RULE_IDX_BITS) as usize); - RuleIdx { - data: (self.0 << RULE_IDX_BITS) | rule as u32, - } - } -} - -pub struct ByteSet { - mask: [u32; 8], -} - -impl ByteSet { - pub fn new() -> Self { - ByteSet { mask: [0; 8] } - } - - pub fn add(&mut self, byte: u8) { - let idx = byte as usize / 32; - let bit = byte as usize % 32; - self.mask[idx] |= 1 << bit; - } - - pub fn contains(&self, byte: u8) -> bool { - let idx = byte as usize / 32; - let bit = byte as usize % 32; - self.mask[idx] & (1 << bit) != 0 - } - - pub fn from_range(start: u8, end: u8) -> Self { - let mut r = ByteSet::new(); - // TODO optimize - for b in start..=end { - r.add(b); - } - r - } -} - -struct Symbol { - idx: SymIdx, - name: String, - bytes: Option, - rules: Vec, - nullable: bool, -} - -struct Rule { - idx: RuleIdx, - rhs: Vec, -} - -impl Rule { - fn lhs(&self) -> SymIdx { - self.idx.sym_idx() - } -} - -pub struct Grammar { - symbols: Vec, - symbol_by_name: FxHashMap, -} - -// format: -// token_position : dot_position : symbol_index : rule_index -#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] -struct Item(u64); - -pub struct Row { - token: SymIdx, - // TODO index this by .after_dot() ? 
- items: Vec, -} - -impl Row {} - -impl Item { - fn new(rule: RuleIdx, dot: usize, start: usize) -> Self { - assert!(start < mask32(TOK_POS_BITS) as usize); - assert!(dot < mask32(DOT_POS_BITS) as usize); - let data = (start as u64) << (DOT_POS_BITS + SYM_IDX_BITS + RULE_IDX_BITS) - | (dot as u64) << (SYM_IDX_BITS + RULE_IDX_BITS) - | (rule.data as u64); - Item(data) - } - - fn rule_idx(&self) -> RuleIdx { - RuleIdx { - data: self.0 as u32 & mask32(SYM_IDX_BITS + RULE_IDX_BITS), - } - } - - fn dot_pos(&self) -> usize { - (self.0 >> (SYM_IDX_BITS + RULE_IDX_BITS)) as usize & mask32(DOT_POS_BITS) as usize - } - - fn start_pos(&self) -> usize { - (self.0 >> (DOT_POS_BITS + SYM_IDX_BITS + RULE_IDX_BITS)) as usize - & mask32(TOK_POS_BITS) as usize - } - - fn advance_dot(&self) -> Self { - Item::new(self.rule_idx(), self.dot_pos() + 1, self.start_pos()) - } -} - -pub struct Parser { - grammar: Rc, - rows: Vec, -} - -impl Parser { - pub fn new(grammar: Rc) -> Self { - let start = grammar.start(); - let init_rules = grammar - .sym_data(start) - .rules - .iter() - .map(|r| Item::new(r.idx, 0, 0)) - .collect(); - let mut r = Parser { - grammar, - rows: vec![], - }; - // 'start' token is bogus - r.push_row(init_rules, start); - r - } - - fn after_dot(&self, item: Item) -> Option { - let rule = self.grammar.rule_data(item.rule_idx()); - if item.dot_pos() < rule.rhs.len() { - Some(rule.rhs[item.dot_pos()]) - } else { - None - } - } - - fn scan(&mut self, token: SymIdx) { - let next_row = self.items_with_after_dot(token, self.rows.len() - 1); - self.push_row(next_row, token); - } - - fn items_with_after_dot(&self, sym: SymIdx, row_idx: usize) -> Vec { - let mut r = vec![]; - for item in &self.rows[row_idx].items { - if self.after_dot(*item) == Some(sym) { - r.push(*item); - } - } - r - } - - fn push_row(&mut self, mut curr_row: Vec, token: SymIdx) { - let curr_idx = self.rows.len(); - let mut agenda = curr_row.clone(); - let mut predicated_syms = vec![]; - - while !agenda.is_empty() { - let item = agenda.pop().unwrap(); - let lhs = item.rule_idx().sym_idx(); - let mut to_add = vec![]; - match self.after_dot(item) { - Some(after_dot) => { - let sym_data = self.grammar.sym_data(after_dot); - if sym_data.nullable { - let new_item = item.advance_dot(); - if !to_add.contains(&new_item) { - to_add.push(new_item); - } - } - if !predicated_syms.contains(&after_dot) { - predicated_syms.push(after_dot); - for rule in &sym_data.rules { - let new_item = Item::new(rule.idx, 0, curr_idx); - if !to_add.contains(&new_item) { - to_add.push(new_item); - } - } - } - } - // complete - None => { - if item.start_pos() < curr_idx { - // if item.start_pos() == curr_idx, then we handled it above in the nullable check - for parent in self.items_with_after_dot(lhs, item.start_pos()) { - let new_item = parent.advance_dot(); - if !to_add.contains(&new_item) { - to_add.push(new_item); - } - } - } - } - } - - for new_item in to_add { - if !curr_row.contains(&new_item) { - curr_row.push(new_item); - agenda.push(new_item); - } - } - } - - self.rows.push(Row { - token, - items: curr_row, - }); - } -} - -impl Grammar { - pub fn new() -> Self { - let mut r = Grammar { - symbols: vec![], - symbol_by_name: FxHashMap::default(), - }; - let _ = r.symbol("_start"); - r - } - - pub fn start(&self) -> SymIdx { - self.symbols[0].idx - } - - fn sym_data(&self, sym: SymIdx) -> &Symbol { - &self.symbols[sym.0 as usize] - } - - fn sym_data_mut(&mut self, sym: SymIdx) -> &mut Symbol { - &mut self.symbols[sym.0 as usize] - } - - fn rule_data(&self, 
rule: RuleIdx) -> &Rule { - let sym = self.sym_data(rule.sym_idx()); - &sym.rules[rule.sym_rule_idx()] - } - - fn propagate_nullable(&mut self) { - loop { - let mut to_null = vec![]; - for sym in self.symbols.iter() { - for rule in sym.rules.iter() { - if rule.rhs.iter().all(|s| self.sym_data(*s).nullable) { - if !sym.nullable { - to_null.push(sym.idx); - } - } - } - } - if to_null.is_empty() { - break; - } - for sym in to_null { - self.sym_data_mut(sym).nullable = true; - } - } - } - - pub fn add_rule(&mut self, lhs: SymIdx, rhs: Vec) { - assert!(rhs.len() < mask32(DOT_POS_BITS) as usize); - - let is_nullable = rhs.iter().all(|s| self.sym_data(*s).nullable); - - if rhs.len() > 0 { - let sym = self.sym_data_mut(lhs); - sym.rules.push(Rule { - idx: lhs.rule_at(sym.rules.len()), - rhs, - }); - } - - if is_nullable { - self.sym_data_mut(lhs).nullable = true; - self.propagate_nullable(); - } - } - - pub fn make_terminal(&mut self, sym: SymIdx, bytes: ByteSet) { - self.symbols[sym.0 as usize].bytes = Some(bytes); - } - - pub fn sym_name(&self, sym: SymIdx) -> &str { - &self.symbols[sym.0 as usize].name - } - - fn rule_to_string(&self, rule: &Rule) -> String { - let lhs = self.sym_name(rule.lhs()); - let rhs = rule - .rhs - .iter() - .map(|s| self.sym_name(*s)) - .collect::>() - .join(" "); - format!("{} ::= {}", lhs, rhs) - } - - pub fn symbol(&mut self, name: &str) -> SymIdx { - match self.symbol_by_name.get(name) { - Some(idx) => *idx, - None => { - let idx = SymIdx(self.symbols.len() as u32); - self.symbols.push(Symbol { - name: name.to_string(), - bytes: None, - idx, - rules: vec![], - nullable: false, - }); - self.symbol_by_name.insert(name.to_string(), idx); - idx - } - } - } -} - -impl Debug for Grammar { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - // for sym in &self.symbols { - // match sym.bytes { - // Some(ref rx) => writeln!(f, "{} /= {:?}", sym.name, rx)?, - // None => {} - // } - // } - for sym in &self.symbols { - for rule in &sym.rules { - writeln!(f, "{}", self.rule_to_string(rule))?; - } - } - Ok(()) - } -} diff --git a/controllers/aici_abi/src/earley_yacc.rs b/controllers/aici_abi/src/earley/from_guidance.rs similarity index 66% rename from controllers/aici_abi/src/earley_yacc.rs rename to controllers/aici_abi/src/earley/from_guidance.rs index e3bb5793..6da219d8 100644 --- a/controllers/aici_abi/src/earley_yacc.rs +++ b/controllers/aici_abi/src/earley/from_guidance.rs @@ -1,48 +1,16 @@ use anyhow::Result; -use cfgrammar::Symbol; use quick_protobuf::MessageRead; use rustc_hash::FxHashSet; use crate::{ - cfg::{parse_rx_token, parse_yacc}, - earley::{ByteSet, Grammar}, - guidance, + earley::{ + guidance, + parser::{ByteSet, Parser}, + }, toktree::TokTrie, }; -pub fn earley_grm_from_yacc(yacc: &str) -> Result { - let grm = parse_yacc(yacc)?; - - let mut res = Grammar::new(); - - for pidx in grm.iter_pidxs() { - let ridx = grm.prod_to_rule(pidx); - - let lhs = res.symbol(grm.rule_name_str(ridx)); - let rhs = grm - .prod(pidx) - .iter() - .map(|sym| match sym { - Symbol::Token(tidx) => { - let name = grm.token_name(*tidx).unwrap(); - let t = res.symbol(name); - res.make_terminal(t, &parse_rx_token(name)); - t - } - Symbol::Rule(ridx) => res.symbol(grm.rule_name_str(*ridx)), - }) - .collect(); - - res.add_rule(lhs, rhs); - } - - let start_sym = grm.rule_name_str(grm.start_rule_idx()); - println!("start_sym: {:?}", start_sym); - let ss = res.symbol(start_sym); - res.add_rule(res.start(), vec![ss]); - - Ok(res) -} +use super::parser::Grammar; pub fn 
earley_grm_from_guidance(bytes: &[u8]) -> Result { let mut reader = quick_protobuf::BytesReader::from_bytes(bytes); @@ -53,22 +21,20 @@ pub fn earley_grm_from_guidance(bytes: &[u8]) -> Result { .nodes .iter() .map(|n| match &n.function_type { - guidance::mod_GrammarFunction::OneOffunction_type::join(n) => grm.symbol(&n.name), - guidance::mod_GrammarFunction::OneOffunction_type::select(n) => grm.symbol(&n.name), + guidance::mod_GrammarFunction::OneOffunction_type::join(n) => grm.fresh_symbol(&n.name), + guidance::mod_GrammarFunction::OneOffunction_type::select(n) => { + grm.fresh_symbol(&n.name) + } guidance::mod_GrammarFunction::OneOffunction_type::byte(n) => { assert!(n.byte.len() == 1); - let sym = grm.symbol(&format!("b'{}", n.byte[0])); - grm.make_terminal(sym, ByteSet::from_range(n.byte[0], n.byte[0])); - sym + grm.terminal(ByteSet::from_range(n.byte[0], n.byte[0])) } guidance::mod_GrammarFunction::OneOffunction_type::byte_range(n) => { assert!(n.byte_range.len() == 2); - let sym = grm.symbol(&format!("b'{}-{}", n.byte_range[0], n.byte_range[1])); - grm.make_terminal(sym, ByteSet::from_range(n.byte_range[0], n.byte_range[1])); - sym + grm.terminal(ByteSet::from_range(n.byte_range[0], n.byte_range[1])) } guidance::mod_GrammarFunction::OneOffunction_type::model_variable(n) => { - grm.symbol(&n.name) + grm.fresh_symbol(&n.name) } guidance::mod_GrammarFunction::OneOffunction_type::None => { panic!("None function type in guidance::Grammar") @@ -83,10 +49,17 @@ pub fn earley_grm_from_guidance(bytes: &[u8]) -> Result { let lhs = *sym; match &n.function_type { guidance::mod_GrammarFunction::OneOffunction_type::join(n) => { + if n.nullable { + //println!("nullable join: {:?}", n.name); + } let rhs = n.values.iter().map(|idx| symbols[*idx as usize]).collect(); grm.add_rule(lhs, rhs); } guidance::mod_GrammarFunction::OneOffunction_type::select(n) => { + if n.nullable { + // println!("nullable sel: {:?} {:?}", n.name, n.values); + grm.add_rule(lhs, vec![]); + } for v in &n.values { grm.add_rule(lhs, vec![symbols[*v as usize]]); } @@ -101,21 +74,35 @@ pub fn earley_grm_from_guidance(bytes: &[u8]) -> Result { } } + grm.add_rule(grm.start(), vec![symbols[0]]); + Ok(grm) } #[allow(dead_code)] pub fn earley_test(trie: TokTrie) { - let yacc_bytes = include_bytes!("../grammars/c.y"); - let cfg = earley_grm_from_yacc(&String::from_utf8_lossy(yacc_bytes)).unwrap(); - + let g_bytes = include_bytes!("../../grammars/json0.guidance"); + let cfg = earley_grm_from_guidance(g_bytes).unwrap(); + println!("cfg0: {:?}", cfg); + let cfg = cfg.optimize(); println!("cfg: {:?}", cfg); - let sample = include_bytes!("../grammars/sample.c"); - let toks = trie.greedy_tokenize(sample); + let input = r#"{"name":"Joe","info":{"foo":10,"bar":"20"}}"#.as_bytes(); + let toks = trie.greedy_tokenize(input); println!("toks: {:?}", toks.len()); + let mut parser = Parser::new(cfg); + for b in input { + let row = parser.scan(*b); + if row.is_empty() { + println!("reject"); + break; + } + println!("row: {}", parser.row_to_string(&row)); + parser.push_row(row); + } + // #[cfg(not(target_arch = "wasm32"))] // let t0 = std::time::Instant::now(); diff --git a/controllers/aici_abi/src/guidance.rs b/controllers/aici_abi/src/earley/guidance.rs similarity index 99% rename from controllers/aici_abi/src/guidance.rs rename to controllers/aici_abi/src/earley/guidance.rs index 24273c82..c6493592 100644 --- a/controllers/aici_abi/src/guidance.rs +++ b/controllers/aici_abi/src/earley/guidance.rs @@ -5,6 +5,7 @@ #![allow(non_upper_case_globals)] 
#![allow(non_camel_case_types)] #![allow(unused_imports)] +#![allow(unused_variables)] #![allow(unknown_lints)] #![allow(clippy::all)] #![cfg_attr(rustfmt, rustfmt_skip)] diff --git a/controllers/aici_abi/src/earley/mod.rs b/controllers/aici_abi/src/earley/mod.rs new file mode 100644 index 00000000..3e228090 --- /dev/null +++ b/controllers/aici_abi/src/earley/mod.rs @@ -0,0 +1,5 @@ +pub mod parser; +mod guidance; +mod from_guidance; + +pub use from_guidance::earley_test; \ No newline at end of file diff --git a/controllers/aici_abi/src/earley/parser.rs b/controllers/aici_abi/src/earley/parser.rs new file mode 100644 index 00000000..9311edcc --- /dev/null +++ b/controllers/aici_abi/src/earley/parser.rs @@ -0,0 +1,704 @@ +use std::{ + fmt::{Debug, Display}, + rc::Rc, + vec, +}; + +use rustc_hash::FxHashMap; + +const DEBUG: bool = false; + +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +pub struct SymIdx(u32); + +// format: +// symbol_index : rule_index +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +pub struct RuleIdx { + data: u32, +} + +const SYM_IDX_BITS: u32 = 12; +const RULE_IDX_BITS: u32 = 10; +const DOT_POS_BITS: u32 = 7; +const TOK_POS_BITS: u32 = 64 - (DOT_POS_BITS + SYM_IDX_BITS + RULE_IDX_BITS); + +fn mask32(bits: u32) -> u32 { + (1 << bits) - 1 +} + +fn mask64(bits: u32) -> u64 { + (1u64 << (bits as u64)) - 1 +} + +impl RuleIdx { + fn sym_idx(&self) -> SymIdx { + SymIdx(self.data >> RULE_IDX_BITS) + } + + fn sym_rule_idx(&self) -> usize { + (self.data & mask32(RULE_IDX_BITS)) as usize + } +} + +impl Symbol { + fn is_terminal(&self) -> bool { + self.bytes.is_some() + } +} + +impl SymIdx { + fn rule_at(&self, rule: usize) -> RuleIdx { + assert!(rule < mask32(RULE_IDX_BITS) as usize); + RuleIdx { + data: (self.0 << RULE_IDX_BITS) | rule as u32, + } + } +} + +const BYTESET_LEN: usize = 8; + +#[derive(Clone, PartialEq, Eq, Hash)] +pub struct ByteSet { + mask: [u32; BYTESET_LEN], +} + +fn byte_to_string(b: u8) -> String { + if b >= 0x7f { + format!("x{:02x}", b) + } else { + let b = b as char; + match b { + '_' | 'a'..='z' | 'A'..='Z' | '0'..='9' => format!("{}", b), + _ => format!("{:?}", b as char), + } + } +} + +impl Display for ByteSet { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + let mut start = None; + let mut first = true; + for i in 0u32..=256 { + if i <= 0xff && self.contains(i as u8) { + if start.is_none() { + start = Some(i); + } + } else { + if let Some(start) = start { + if !first { + write!(f, ";")?; + } + first = false; + write!(f, "{}", byte_to_string(start as u8))?; + if i - start > 1 { + write!(f, "-{}", byte_to_string((i - 1) as u8))?; + } + } + start = None; + } + } + Ok(()) + } +} + +impl ByteSet { + pub fn new() -> Self { + ByteSet { + mask: [0; BYTESET_LEN], + } + } + + pub fn from_sum<'a>(elts: impl Iterator) -> Self { + let mut r = ByteSet::new(); + for e in elts { + r.add_set(&e); + } + r + } + + pub fn add_set(&mut self, other: &ByteSet) { + for i in 0..BYTESET_LEN { + self.mask[i] |= other.mask[i]; + } + } + + pub fn add(&mut self, byte: u8) { + let idx = byte as usize / 32; + let bit = byte as usize % 32; + self.mask[idx] |= 1 << bit; + } + + pub fn contains(&self, byte: u8) -> bool { + let idx = byte as usize / 32; + let bit = byte as usize % 32; + self.mask[idx] & (1 << bit) != 0 + } + + pub fn from_range(start: u8, end: u8) -> Self { + let mut r = ByteSet::new(); + // TODO optimize + for b in start..=end { + r.add(b); + } + r + } +} + +struct Symbol { + idx: SymIdx, + name: String, + bytes: Option, + rules: Vec, + 
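// [Editor's note: commentary, not part of the original patch.]
// `bytes: Some(..)` is what marks a symbol as a terminal (see
// `is_terminal()` above); `nullable` is filled in later by
// `propagate_nullable()` once all rules are known.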
nullable: bool, +} + +struct Rule { + idx: RuleIdx, + rhs: Vec, +} + +impl Rule { + fn lhs(&self) -> SymIdx { + self.idx.sym_idx() + } +} + +pub struct Grammar { + symbols: Vec, + symbol_by_name: FxHashMap, + terminals: FxHashMap, +} + +#[derive(Clone)] +pub struct OptimizedGrammar { + inner: Rc, +} + +impl Debug for OptimizedGrammar { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + self.inner.fmt(f) + } +} + +// format: +// token_position : dot_position : symbol_index : rule_index +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +struct Item(u64); + +pub struct Row { + token: u8, + position: usize, + // TODO index this by .after_dot() ? + items: Vec, +} + +impl Row { + pub fn is_empty(&self) -> bool { + self.items.is_empty() + } +} + +impl Item { + fn new(rule: RuleIdx, dot: usize, start: usize) -> Self { + assert!(start < mask64(TOK_POS_BITS) as usize); + assert!(dot < mask32(DOT_POS_BITS) as usize); + let data = (start as u64) << (DOT_POS_BITS + SYM_IDX_BITS + RULE_IDX_BITS) + | (dot as u64) << (SYM_IDX_BITS + RULE_IDX_BITS) + | (rule.data as u64); + Item(data) + } + + fn rule_idx(&self) -> RuleIdx { + RuleIdx { + data: self.0 as u32 & mask32(SYM_IDX_BITS + RULE_IDX_BITS), + } + } + + fn dot_pos(&self) -> usize { + (self.0 >> (SYM_IDX_BITS + RULE_IDX_BITS)) as usize & mask32(DOT_POS_BITS) as usize + } + + fn start_pos(&self) -> usize { + ((self.0 >> (DOT_POS_BITS + SYM_IDX_BITS + RULE_IDX_BITS)) & mask64(TOK_POS_BITS)) as usize + } + + fn advance_dot(&self) -> Self { + Item::new(self.rule_idx(), self.dot_pos() + 1, self.start_pos()) + } +} + +pub struct Parser { + grammar: Rc, + rows: Vec, +} + +impl Parser { + pub fn new(grammar: OptimizedGrammar) -> Self { + let grammar = grammar.inner; + let init_rules = grammar + .sym_data(grammar.start()) + .rules + .iter() + .map(|r| Item::new(r.idx, 0, 0)) + .collect(); + let mut r = Parser { + grammar, + rows: vec![], + }; + // '0' token is bogus + let row = r.make_row(init_rules, 0); + println!("init: {}", r.row_to_string(&row)); + r.push_row(row); + r + } + + fn item_to_string(&self, item: &Item) -> String { + let rule = self.grammar.rule_data(item.rule_idx()); + self.grammar.rule_to_string(rule, item.dot_pos()) + } + + pub fn row_to_string(&self, row: &Row) -> String { + let mut r = vec![format!("token: {}", byte_to_string(row.token))]; + for item in &row.items { + r.push(self.item_to_string(item)); + } + r.join("\n") + "\n" + } + + pub fn scan(&mut self, b: u8) -> Row { + let mut r = vec![]; + let row_idx = self.rows.len() - 1; + for item in &self.rows[row_idx].items { + if let Some(s) = self.grammar.after_dot(*item) { + if let Some(bytes) = self.grammar.sym_data(s).bytes.clone() { + if bytes.contains(b) { + r.push(item.advance_dot()); + } + } + } + } + self.make_row(r, b) + } + + pub fn pop_rows(&mut self, n: usize) { + self.rows.drain(self.rows.len() - n..); + } + + pub fn push_row(&mut self, row: Row) { + assert!(row.position == self.rows.len()); + self.rows.push(row); + } + + fn items_with_after_dot(&self, sym: SymIdx, row_idx: usize) -> Vec { + let mut r = vec![]; + for item in &self.rows[row_idx].items { + if self.grammar.after_dot(*item) == Some(sym) { + r.push(*item); + } + } + r + } + + fn make_row(&self, mut curr_row: Vec, token: u8) -> Row { + let curr_idx = self.rows.len(); + let mut agenda = curr_row.clone(); + let mut predicated_syms = vec![]; + + if DEBUG { + let row0 = Row { + token, + position: curr_idx, + items: curr_row.clone(), + }; + println!("row0: {}", self.row_to_string(&row0)); + } + + 
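// [Editor's note: commentary, not part of the original patch.] The loop
// below is the classic Earley closure over the agenda. For an item with a
// symbol after the dot it PREDICTs that symbol's rules into the current
// row (also advancing the dot immediately when the symbol is nullable);
// for a finished item (dot at the end) it COMPLETEs by advancing every
// parent item that was waiting on this item's LHS in the row where the
// item started. The SCAN step already happened in `Parser::scan`, which
// seeded `curr_row` with the items whose terminal byte set matched the
// scanned byte `token`.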
while !agenda.is_empty() { + let item = agenda.pop().unwrap(); + if DEBUG { + println!("from agenda: {}", self.item_to_string(&item)); + } + let lhs = item.rule_idx().sym_idx(); + let mut to_add = vec![]; + match self.grammar.after_dot(item) { + Some(after_dot) => { + let sym_data = self.grammar.sym_data(after_dot); + if sym_data.nullable { + let new_item = item.advance_dot(); + if !to_add.contains(&new_item) { + to_add.push(new_item); + if DEBUG { + println!(" adding (nullable): {}", self.item_to_string(&new_item)); + } + } + } + if !predicated_syms.contains(&after_dot) { + predicated_syms.push(after_dot); + for rule in &sym_data.rules { + let new_item = Item::new(rule.idx, 0, curr_idx); + if !to_add.contains(&new_item) { + to_add.push(new_item); + if DEBUG { + println!(" adding: {}", self.item_to_string(&new_item)); + } + } + } + } + } + // complete + None => { + if item.start_pos() < curr_idx { + // if item.start_pos() == curr_idx, then we handled it above in the nullable check + for parent in self.items_with_after_dot(lhs, item.start_pos()) { + let new_item = parent.advance_dot(); + if !to_add.contains(&new_item) { + to_add.push(new_item); + if DEBUG { + println!( + " adding complete: {}", + self.item_to_string(&new_item) + ); + } + } + } + } + } + } + + for new_item in to_add { + if !curr_row.contains(&new_item) { + curr_row.push(new_item); + agenda.push(new_item); + } + } + } + + Row { + token, + position: curr_idx, + items: curr_row, + } + } +} + +impl Grammar { + pub fn new() -> Self { + let mut r = Grammar { + symbols: vec![], + symbol_by_name: FxHashMap::default(), + terminals: FxHashMap::default(), + }; + let _ = r.symbol("_start"); + r + } + + pub fn start(&self) -> SymIdx { + self.symbols[0].idx + } + + fn sym_data(&self, sym: SymIdx) -> &Symbol { + &self.symbols[sym.0 as usize] + } + + fn sym_data_mut(&mut self, sym: SymIdx) -> &mut Symbol { + &mut self.symbols[sym.0 as usize] + } + + fn rule_data(&self, rule: RuleIdx) -> &Rule { + let sym = self.sym_data(rule.sym_idx()); + if rule.sym_rule_idx() >= sym.rules.len() { + println!("invalid rule index; {}", sym.name); + } + &sym.rules[rule.sym_rule_idx()] + } + + fn propagate_nullable(&mut self) { + for sym in self.symbols.iter_mut() { + if sym.rules.iter().any(|r| r.rhs.is_empty()) { + sym.nullable = true; + sym.rules.retain(|r| !r.rhs.is_empty()); + // re-number them + for (i, r) in sym.rules.iter_mut().enumerate() { + r.idx = sym.idx.rule_at(i); + } + } + } + loop { + let mut to_null = vec![]; + for sym in self.symbols.iter() { + for rule in sym.rules.iter() { + if rule.rhs.iter().all(|s| self.sym_data(*s).nullable) { + if !sym.nullable { + to_null.push(sym.idx); + } + } + } + } + if to_null.is_empty() { + break; + } + for sym in to_null { + self.sym_data_mut(sym).nullable = true; + } + } + } + + pub fn add_rule(&mut self, lhs: SymIdx, rhs: Vec) { + assert!(rhs.len() < mask32(DOT_POS_BITS) as usize); + assert!(!self.sym_data(lhs).is_terminal()); + let sym = self.sym_data_mut(lhs); + sym.rules.push(Rule { + idx: lhs.rule_at(sym.rules.len()), + rhs, + }); + } + + pub fn terminal(&mut self, bytes: ByteSet) -> SymIdx { + match self.terminals.get(&bytes) { + Some(sym) => *sym, + None => { + let mut name = format!("T:{}", bytes); + if name.len() > 40 { + name = format!("T@{}", self.terminals.len()); + } + let sym = self.fresh_symbol(&name); + self.sym_data_mut(sym).bytes = Some(bytes.clone()); + self.terminals.insert(bytes, sym); + sym + } + } + } + + pub fn sym_name(&self, sym: SymIdx) -> &str { + &self.symbols[sym.0 as usize].name 
+ } + + fn rule_to_string(&self, rule: &Rule, dot: usize) -> String { + let lhs = self.sym_name(rule.lhs()); + let mut rhs = rule + .rhs + .iter() + .enumerate() + .map(|(i, s)| { + format!( + "{}{}", + if i == dot { "(*) " } else { "" }, + self.sym_name(*s) + ) + }) + .collect::>() + .join(" "); + if dot == rule.rhs.len() { + rhs.push_str(" (*)"); + } + format!("{} ::= {}", lhs, rhs) + } + + fn copy_from(&mut self, other: &Grammar, sym: SymIdx) -> SymIdx { + let sym_data = other.sym_data(sym); + if sym_data.is_terminal() { + self.terminal(sym_data.bytes.clone().unwrap()) + } else { + self.symbol(&sym_data.name) + } + } + + fn collapse_terminals(&self) -> Self { + let mut outp = Grammar::new(); + for sym in &self.symbols { + if sym.rules.is_empty() { + continue; + } + let mut rules_by_shape = FxHashMap::default(); + for rule in &sym.rules { + let shape = rule + .rhs + .iter() + .map(|s| { + if self.sym_data(*s).is_terminal() { + None + } else { + Some(*s) + } + }) + .collect::>(); + rules_by_shape + .entry(shape) + .or_insert_with(Vec::new) + .push(rule); + } + let lhs = outp.copy_from(self, sym.idx); + for rules in rules_by_shape.values() { + let rhs = rules[0] + .rhs + .iter() + .enumerate() + .map(|(i, s)| { + let sym = self.sym_data(*s); + if sym.is_terminal() { + let terminals = rules + .iter() + .map(|r| self.sym_data(r.rhs[i]).bytes.clone().unwrap()); + outp.terminal(ByteSet::from_sum(terminals)) + } else { + outp.copy_from(self, *s) + } + }) + .collect(); + outp.add_rule(lhs, rhs); + } + } + outp + } + + fn expand_shortcuts(&self) -> Self { + let mut use_count = vec![0; self.symbols.len()]; + for sym in &self.symbols { + for r in sym.rules.iter() { + for s in &r.rhs { + use_count[s.0 as usize] += 1; + } + } + } + + let mut repl = FxHashMap::default(); + for sym in &self.symbols { + if sym.idx == self.start() { + continue; + } + if sym.rules.len() == 1 + && (use_count[sym.idx.0 as usize] == 1 || sym.rules[0].rhs.len() == 1) + { + // eliminate sym.idx + repl.insert(sym.idx, sym.rules[0].rhs.clone()); + } + } + + // fix-point expand the mapping + loop { + let to_change = repl + .iter() + .filter_map(|(lhs, rhs)| { + let rhs2 = rhs + .iter() + .flat_map(|s| repl.get(s).cloned().unwrap_or_else(|| vec![*s])) + .collect::>(); + assert!(rhs2.iter().all(|s| *s != *lhs), "cyclic?"); + if *rhs != rhs2 { + Some((*lhs, rhs2)) + } else { + None + } + }) + .collect::>(); + if to_change.is_empty() { + break; + } + for (lhs, rhs) in to_change { + repl.insert(lhs, rhs); + } + } + + let mut outp = Grammar::new(); + for sym in &self.symbols { + if repl.contains_key(&sym.idx) { + continue; + } + let lhs = outp.copy_from(self, sym.idx); + for rule in &sym.rules { + let rhs = rule + .rhs + .iter() + .flat_map(|s| repl.get(s).cloned().unwrap_or_else(|| vec![*s])) + .map(|s| outp.copy_from(self, s)) + .collect(); + outp.add_rule(lhs, rhs); + } + } + outp + } + + pub fn optimize(&self) -> OptimizedGrammar { + let mut outp = self + .expand_shortcuts() + .collapse_terminals() + .expand_shortcuts(); + outp.propagate_nullable(); + OptimizedGrammar { + inner: Rc::new(outp), + } + } + + pub fn fresh_symbol(&mut self, name0: &str) -> SymIdx { + let mut name = name0.to_string(); + let mut idx = 2; + while self.symbol_by_name.contains_key(&name) { + name = format!("{}#{}", name0, idx); + idx += 1; + } + + let idx = SymIdx(self.symbols.len() as u32); + self.symbols.push(Symbol { + name: name.clone(), + bytes: None, + idx, + rules: vec![], + nullable: false, + }); + self.symbol_by_name.insert(name, idx); + idx + } + + 
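+    // Annotation for clarity: symbol() interns a non-terminal by name,
+    // returning the existing SymIdx or minting one via fresh_symbol(). A
+    // minimal usage sketch of the grammar-building API defined in this file;
+    // the names and byte ranges below are illustrative only, not taken from
+    // this repo:
+    //
+    //     let mut g = Grammar::new();
+    //     let ab = g.symbol("ab");
+    //     let a = g.terminal(ByteSet::from_range(b'a', b'a'));
+    //     let b = g.terminal(ByteSet::from_range(b'b', b'b'));
+    //     g.add_rule(ab, vec![a, b]); // ab ::= a b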
pub fn symbol(&mut self, name: &str) -> SymIdx { + match self.symbol_by_name.get(name) { + Some(idx) => *idx, + None => self.fresh_symbol(name), + } + } + + fn after_dot(&self, item: Item) -> Option { + let rule = self.rule_data(item.rule_idx()); + if item.dot_pos() < rule.rhs.len() { + Some(rule.rhs[item.dot_pos()]) + } else { + None + } + } +} + +impl Debug for Grammar { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + for sym in &self.symbols { + match sym.bytes { + Some(ref bytes) if sym.name.starts_with("T@") => { + writeln!(f, "{} := {}", sym.name, bytes)? + } + _ => {} + } + } + let mut num_term = 0; + let mut num_rules = 0; + let mut num_non_term = 0; + for sym in &self.symbols { + if sym.is_terminal() { + num_term += 1; + } else { + num_non_term += 1; + num_rules += sym.rules.len(); + } + if sym.nullable { + writeln!(f, "{} ::= ϵ", sym.name)?; + } + for rule in &sym.rules { + writeln!(f, "{}", self.rule_to_string(rule, usize::MAX))?; + } + } + writeln!( + f, + "stats: {} terminals; {} non-terminals with {} rules\n", + num_term, num_non_term, num_rules + )?; + Ok(()) + } +} diff --git a/controllers/aici_abi/src/lib.rs b/controllers/aici_abi/src/lib.rs index 7d489ac0..9307df1a 100644 --- a/controllers/aici_abi/src/lib.rs +++ b/controllers/aici_abi/src/lib.rs @@ -18,9 +18,6 @@ pub mod rx; #[cfg(feature = "earley")] pub mod earley; -#[cfg(all(feature = "earley", feature = "cfg"))] -pub mod earley_yacc; -mod guidance; pub mod substring; From aa20ab47d15bac0f51e6c27f978cae99bbe43913 Mon Sep 17 00:00:00 2001 From: Michal Moskal Date: Sat, 2 Mar 2024 01:02:15 +0000 Subject: [PATCH 151/301] move byteset out --- controllers/aici_abi/src/earley/byteset.rs | 90 +++++++++++++++++++ .../aici_abi/src/earley/from_guidance.rs | 11 +-- controllers/aici_abi/src/earley/mod.rs | 9 +- controllers/aici_abi/src/earley/parser.rs | 90 +------------------ 4 files changed, 100 insertions(+), 100 deletions(-) create mode 100644 controllers/aici_abi/src/earley/byteset.rs diff --git a/controllers/aici_abi/src/earley/byteset.rs b/controllers/aici_abi/src/earley/byteset.rs new file mode 100644 index 00000000..414df548 --- /dev/null +++ b/controllers/aici_abi/src/earley/byteset.rs @@ -0,0 +1,90 @@ +use std::fmt::Display; + +const BYTESET_LEN: usize = 8; + +#[derive(Clone, PartialEq, Eq, Hash)] +pub struct ByteSet { + mask: [u32; BYTESET_LEN], +} + +pub fn byte_to_string(b: u8) -> String { + if b >= 0x7f { + format!("x{:02x}", b) + } else { + let b = b as char; + match b { + '_' | 'a'..='z' | 'A'..='Z' | '0'..='9' => format!("{}", b), + _ => format!("{:?}", b as char), + } + } +} + +impl Display for ByteSet { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + let mut start = None; + let mut first = true; + for i in 0u32..=256 { + if i <= 0xff && self.contains(i as u8) { + if start.is_none() { + start = Some(i); + } + } else { + if let Some(start) = start { + if !first { + write!(f, ";")?; + } + first = false; + write!(f, "{}", byte_to_string(start as u8))?; + if i - start > 1 { + write!(f, "-{}", byte_to_string((i - 1) as u8))?; + } + } + start = None; + } + } + Ok(()) + } +} + +impl ByteSet { + pub fn new() -> Self { + ByteSet { + mask: [0; BYTESET_LEN], + } + } + + pub fn from_sum<'a>(elts: impl Iterator) -> Self { + let mut r = ByteSet::new(); + for e in elts { + r.add_set(&e); + } + r + } + + pub fn add_set(&mut self, other: &ByteSet) { + for i in 0..BYTESET_LEN { + self.mask[i] |= other.mask[i]; + } + } + + pub fn add(&mut self, byte: u8) { + let idx = byte as 
usize / 32; + let bit = byte as usize % 32; + self.mask[idx] |= 1 << bit; + } + + pub fn contains(&self, byte: u8) -> bool { + let idx = byte as usize / 32; + let bit = byte as usize % 32; + self.mask[idx] & (1 << bit) != 0 + } + + pub fn from_range(start: u8, end: u8) -> Self { + let mut r = ByteSet::new(); + // TODO optimize + for b in start..=end { + r.add(b); + } + r + } +} diff --git a/controllers/aici_abi/src/earley/from_guidance.rs b/controllers/aici_abi/src/earley/from_guidance.rs index 6da219d8..d8fdbc36 100644 --- a/controllers/aici_abi/src/earley/from_guidance.rs +++ b/controllers/aici_abi/src/earley/from_guidance.rs @@ -2,15 +2,8 @@ use anyhow::Result; use quick_protobuf::MessageRead; use rustc_hash::FxHashSet; -use crate::{ - earley::{ - guidance, - parser::{ByteSet, Parser}, - }, - toktree::TokTrie, -}; - -use super::parser::Grammar; +use super::{guidance, ByteSet, Parser, Grammar}; +use crate::toktree::TokTrie; pub fn earley_grm_from_guidance(bytes: &[u8]) -> Result { let mut reader = quick_protobuf::BytesReader::from_bytes(bytes); diff --git a/controllers/aici_abi/src/earley/mod.rs b/controllers/aici_abi/src/earley/mod.rs index 3e228090..224a9bd5 100644 --- a/controllers/aici_abi/src/earley/mod.rs +++ b/controllers/aici_abi/src/earley/mod.rs @@ -1,5 +1,8 @@ -pub mod parser; -mod guidance; +mod byteset; mod from_guidance; +mod guidance; +mod parser; -pub use from_guidance::earley_test; \ No newline at end of file +pub use byteset::ByteSet; +pub use from_guidance::earley_test; +pub use parser::{Grammar, Parser}; diff --git a/controllers/aici_abi/src/earley/parser.rs b/controllers/aici_abi/src/earley/parser.rs index 9311edcc..7d1e8c4d 100644 --- a/controllers/aici_abi/src/earley/parser.rs +++ b/controllers/aici_abi/src/earley/parser.rs @@ -6,6 +6,8 @@ use std::{ use rustc_hash::FxHashMap; +use super::{byteset::byte_to_string, ByteSet}; + const DEBUG: bool = false; #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] @@ -56,94 +58,6 @@ impl SymIdx { } } -const BYTESET_LEN: usize = 8; - -#[derive(Clone, PartialEq, Eq, Hash)] -pub struct ByteSet { - mask: [u32; BYTESET_LEN], -} - -fn byte_to_string(b: u8) -> String { - if b >= 0x7f { - format!("x{:02x}", b) - } else { - let b = b as char; - match b { - '_' | 'a'..='z' | 'A'..='Z' | '0'..='9' => format!("{}", b), - _ => format!("{:?}", b as char), - } - } -} - -impl Display for ByteSet { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - let mut start = None; - let mut first = true; - for i in 0u32..=256 { - if i <= 0xff && self.contains(i as u8) { - if start.is_none() { - start = Some(i); - } - } else { - if let Some(start) = start { - if !first { - write!(f, ";")?; - } - first = false; - write!(f, "{}", byte_to_string(start as u8))?; - if i - start > 1 { - write!(f, "-{}", byte_to_string((i - 1) as u8))?; - } - } - start = None; - } - } - Ok(()) - } -} - -impl ByteSet { - pub fn new() -> Self { - ByteSet { - mask: [0; BYTESET_LEN], - } - } - - pub fn from_sum<'a>(elts: impl Iterator) -> Self { - let mut r = ByteSet::new(); - for e in elts { - r.add_set(&e); - } - r - } - - pub fn add_set(&mut self, other: &ByteSet) { - for i in 0..BYTESET_LEN { - self.mask[i] |= other.mask[i]; - } - } - - pub fn add(&mut self, byte: u8) { - let idx = byte as usize / 32; - let bit = byte as usize % 32; - self.mask[idx] |= 1 << bit; - } - - pub fn contains(&self, byte: u8) -> bool { - let idx = byte as usize / 32; - let bit = byte as usize % 32; - self.mask[idx] & (1 << bit) != 0 - } - - pub fn from_range(start: u8, end: u8) -> 
Self { - let mut r = ByteSet::new(); - // TODO optimize - for b in start..=end { - r.add(b); - } - r - } -} struct Symbol { idx: SymIdx, From 341411e200d405a8a17c7284e708666f9cec3740 Mon Sep 17 00:00:00 2001 From: Michal Moskal Date: Sat, 2 Mar 2024 01:33:00 +0000 Subject: [PATCH 152/301] implementing recognizer --- .../aici_abi/src/earley/from_guidance.rs | 107 ++++++++++++------ controllers/aici_abi/src/earley/parser.rs | 54 ++++----- 2 files changed, 97 insertions(+), 64 deletions(-) diff --git a/controllers/aici_abi/src/earley/from_guidance.rs b/controllers/aici_abi/src/earley/from_guidance.rs index d8fdbc36..5b120c50 100644 --- a/controllers/aici_abi/src/earley/from_guidance.rs +++ b/controllers/aici_abi/src/earley/from_guidance.rs @@ -2,8 +2,8 @@ use anyhow::Result; use quick_protobuf::MessageRead; use rustc_hash::FxHashSet; -use super::{guidance, ByteSet, Parser, Grammar}; -use crate::toktree::TokTrie; +use super::{guidance, ByteSet, Grammar, Parser}; +use crate::toktree::{Recognizer, SpecialToken, TokTrie}; pub fn earley_grm_from_guidance(bytes: &[u8]) -> Result { let mut reader = quick_protobuf::BytesReader::from_bytes(bytes); @@ -72,6 +72,38 @@ pub fn earley_grm_from_guidance(bytes: &[u8]) -> Result { Ok(grm) } +impl Recognizer for Parser { + fn pop_bytes(&mut self, num: usize) { + self.pop_rows(num); + } + + fn collapse(&mut self) { + // does nothing - we need to keep the entire state + } + + fn special_allowed(&mut self, tok: SpecialToken) -> bool { + if tok == SpecialToken::EndOfSentence { + self.curr_row().is_accepting() + } else { + false + } + } + + fn trie_finished(&mut self) { + // do nothing? + } + + fn try_push_byte(&mut self, byte: u8) -> bool { + let row = self.scan(byte); + if row.is_empty() { + false + } else { + self.push_row(row); + true + } + } +} + #[allow(dead_code)] pub fn earley_test(trie: TokTrie) { let g_bytes = include_bytes!("../../grammars/json0.guidance"); @@ -85,7 +117,7 @@ pub fn earley_test(trie: TokTrie) { let toks = trie.greedy_tokenize(input); println!("toks: {:?}", toks.len()); - let mut parser = Parser::new(cfg); + let mut parser = Parser::new(cfg.clone()); for b in input { let row = parser.scan(*b); if row.is_empty() { @@ -96,38 +128,39 @@ pub fn earley_test(trie: TokTrie) { parser.push_row(row); } - // #[cfg(not(target_arch = "wasm32"))] - // let t0 = std::time::Instant::now(); - - // let mut line = 1; - // let mut vob = trie.alloc_token_set(); - - // for tok in &toks[0..1000] { - // let tok = *tok; - // trie.compute_bias(&mut cfg, &mut vob); - // if !vob.is_allowed(tok) { - // println!("reject, line={}, tok={:?}", line, trie.token_str(tok)); - // panic!(); - // } - // for b in trie.token(tok) { - // if *b == b'\n' { - // line += 1; - // } - // } - // if false { - // println!( - // "tok: {:?} {}; {}", - // trie.token_str(tok), - // vob.is_allowed(tok), - // cfg.get_stats() - // ); - // cfg.viable_now(); - // } - // trie.append_token(&mut cfg, tok); - // } - - // #[cfg(not(target_arch = "wasm32"))] - // println!("time: {:?} ", t0.elapsed()); - - // println!("stats: {}", cfg.get_stats()); + #[cfg(not(target_arch = "wasm32"))] + let t0 = std::time::Instant::now(); + + let mut line = 1; + let mut vob = trie.alloc_token_set(); + + parser = Parser::new(cfg); + println!("start!"); + let mut times = vec![]; + + for tok in &toks { + let tok = *tok; + let tt = std::time::Instant::now(); + trie.compute_bias(&mut parser, &mut vob); + if !vob.is_allowed(tok) { + println!("reject, line={}, tok={:?}", line, trie.token_str(tok)); + panic!(); + } + for b in 
trie.token(tok) { + if *b == b'\n' { + line += 1; + } + } + println!("TOKENS: {}", trie.token_set_dbg(&vob)); + trie.append_token(&mut parser, tok); + times.push(tt.elapsed().as_micros() as u32); + } + + #[cfg(not(target_arch = "wasm32"))] + println!( + "time: {:?} ({:?}/tok)", + t0.elapsed(), + t0.elapsed() / toks.len() as u32 + ); + println!("times: {:?}", times); } diff --git a/controllers/aici_abi/src/earley/parser.rs b/controllers/aici_abi/src/earley/parser.rs index 7d1e8c4d..fff6527d 100644 --- a/controllers/aici_abi/src/earley/parser.rs +++ b/controllers/aici_abi/src/earley/parser.rs @@ -1,8 +1,4 @@ -use std::{ - fmt::{Debug, Display}, - rc::Rc, - vec, -}; +use std::{fmt::Debug, rc::Rc, vec}; use rustc_hash::FxHashMap; @@ -58,7 +54,6 @@ impl SymIdx { } } - struct Symbol { idx: SymIdx, name: String, @@ -105,12 +100,17 @@ pub struct Row { position: usize, // TODO index this by .after_dot() ? items: Vec, + accepting: bool, } impl Row { pub fn is_empty(&self) -> bool { self.items.is_empty() } + + pub fn is_accepting(&self) -> bool { + self.accepting + } } impl Item { @@ -199,6 +199,10 @@ impl Parser { self.rows.drain(self.rows.len() - n..); } + pub fn curr_row(&self) -> &Row { + &self.rows[self.rows.len() - 1] + } + pub fn push_row(&mut self, row: Row) { assert!(row.position == self.rows.len()); self.rows.push(row); @@ -218,12 +222,14 @@ impl Parser { let curr_idx = self.rows.len(); let mut agenda = curr_row.clone(); let mut predicated_syms = vec![]; + let mut accepting = false; if DEBUG { let row0 = Row { token, position: curr_idx, items: curr_row.clone(), + accepting, }; println!("row0: {}", self.row_to_string(&row0)); } @@ -234,29 +240,30 @@ impl Parser { println!("from agenda: {}", self.item_to_string(&item)); } let lhs = item.rule_idx().sym_idx(); + if lhs == self.grammar.start() && self.grammar.after_dot(item).is_none() { + accepting = true; + } let mut to_add = vec![]; + let mut add = |new_item: Item, tag: &str| { + if !to_add.contains(&new_item) { + to_add.push(new_item); + if DEBUG { + println!(" adding {}: {}", tag, self.item_to_string(&new_item)); + } + } + }; match self.grammar.after_dot(item) { Some(after_dot) => { let sym_data = self.grammar.sym_data(after_dot); if sym_data.nullable { let new_item = item.advance_dot(); - if !to_add.contains(&new_item) { - to_add.push(new_item); - if DEBUG { - println!(" adding (nullable): {}", self.item_to_string(&new_item)); - } - } + add(new_item, "null"); } if !predicated_syms.contains(&after_dot) { predicated_syms.push(after_dot); for rule in &sym_data.rules { let new_item = Item::new(rule.idx, 0, curr_idx); - if !to_add.contains(&new_item) { - to_add.push(new_item); - if DEBUG { - println!(" adding: {}", self.item_to_string(&new_item)); - } - } + add(new_item, "predict"); } } } @@ -266,15 +273,7 @@ impl Parser { // if item.start_pos() == curr_idx, then we handled it above in the nullable check for parent in self.items_with_after_dot(lhs, item.start_pos()) { let new_item = parent.advance_dot(); - if !to_add.contains(&new_item) { - to_add.push(new_item); - if DEBUG { - println!( - " adding complete: {}", - self.item_to_string(&new_item) - ); - } - } + add(new_item, "complete"); } } } @@ -292,6 +291,7 @@ impl Parser { token, position: curr_idx, items: curr_row, + accepting, } } } From 45b2180beadcf81053f9edc73564ae2519e9345d Mon Sep 17 00:00:00 2001 From: Michal Moskal Date: Sat, 2 Mar 2024 02:50:48 +0000 Subject: [PATCH 153/301] re-order memory in grammar --- .../aici_abi/src/earley/from_guidance.rs | 10 +- 
controllers/aici_abi/src/earley/grammar.rs | 465 +++++++++++++++ controllers/aici_abi/src/earley/mod.rs | 4 +- controllers/aici_abi/src/earley/parser.rs | 537 ++---------------- 4 files changed, 536 insertions(+), 480 deletions(-) create mode 100644 controllers/aici_abi/src/earley/grammar.rs diff --git a/controllers/aici_abi/src/earley/from_guidance.rs b/controllers/aici_abi/src/earley/from_guidance.rs index 5b120c50..5c93a8a9 100644 --- a/controllers/aici_abi/src/earley/from_guidance.rs +++ b/controllers/aici_abi/src/earley/from_guidance.rs @@ -108,7 +108,7 @@ impl Recognizer for Parser { pub fn earley_test(trie: TokTrie) { let g_bytes = include_bytes!("../../grammars/json0.guidance"); let cfg = earley_grm_from_guidance(g_bytes).unwrap(); - println!("cfg0: {:?}", cfg); + // println!("cfg0: {:?}", cfg); let cfg = cfg.optimize(); println!("cfg: {:?}", cfg); @@ -117,14 +117,14 @@ pub fn earley_test(trie: TokTrie) { let toks = trie.greedy_tokenize(input); println!("toks: {:?}", toks.len()); - let mut parser = Parser::new(cfg.clone()); + let mut parser = Parser::new(cfg.compile()); for b in input { let row = parser.scan(*b); if row.is_empty() { println!("reject"); break; } - println!("row: {}", parser.row_to_string(&row)); + // println!("row: {}", parser.row_to_string(&row)); parser.push_row(row); } @@ -134,7 +134,7 @@ pub fn earley_test(trie: TokTrie) { let mut line = 1; let mut vob = trie.alloc_token_set(); - parser = Parser::new(cfg); + parser = Parser::new(cfg.compile()); println!("start!"); let mut times = vec![]; @@ -151,7 +151,7 @@ pub fn earley_test(trie: TokTrie) { line += 1; } } - println!("TOKENS: {}", trie.token_set_dbg(&vob)); + // println!("TOK: {} ===> {}", trie.token_dbg(tok), trie.token_set_dbg(&vob)); trie.append_token(&mut parser, tok); times.push(tt.elapsed().as_micros() as u32); } diff --git a/controllers/aici_abi/src/earley/grammar.rs b/controllers/aici_abi/src/earley/grammar.rs new file mode 100644 index 00000000..83b8acba --- /dev/null +++ b/controllers/aici_abi/src/earley/grammar.rs @@ -0,0 +1,465 @@ +use std::fmt::Debug; + +use super::ByteSet; +use rustc_hash::FxHashMap; +use vob::Vob; + +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +pub struct SymIdx(u32); + +impl Symbol { + fn is_terminal(&self) -> bool { + self.bytes.is_some() + } +} + +struct Symbol { + idx: SymIdx, + name: String, + bytes: Option, + rules: Vec, +} + +struct Rule { + lhs: SymIdx, + rhs: Vec, +} + +impl Rule { + fn lhs(&self) -> SymIdx { + self.lhs + } +} + +pub struct Grammar { + symbols: Vec, + symbol_by_name: FxHashMap, + terminals: FxHashMap, +} + +impl Grammar { + pub fn new() -> Self { + let mut r = Grammar { + symbols: vec![], + symbol_by_name: FxHashMap::default(), + terminals: FxHashMap::default(), + }; + let _ = r.symbol("_start"); + r + } + + pub fn start(&self) -> SymIdx { + self.symbols[0].idx + } + + fn sym_data(&self, sym: SymIdx) -> &Symbol { + &self.symbols[sym.0 as usize] + } + + fn sym_data_mut(&mut self, sym: SymIdx) -> &mut Symbol { + &mut self.symbols[sym.0 as usize] + } + + pub fn add_rule(&mut self, lhs: SymIdx, rhs: Vec) { + assert!(!self.sym_data(lhs).is_terminal()); + let sym = self.sym_data_mut(lhs); + sym.rules.push(Rule { lhs, rhs }); + } + + pub fn terminal(&mut self, bytes: ByteSet) -> SymIdx { + match self.terminals.get(&bytes) { + Some(sym) => *sym, + None => { + let mut name = format!("T:{}", bytes); + if name.len() > 40 { + name = format!("T@{}", self.terminals.len()); + } + let sym = self.fresh_symbol(&name); + self.sym_data_mut(sym).bytes = 
Some(bytes.clone()); + self.terminals.insert(bytes, sym); + sym + } + } + } + + pub fn sym_name(&self, sym: SymIdx) -> &str { + &self.symbols[sym.0 as usize].name + } + + fn rule_to_string(&self, rule: &Rule, dot: usize) -> String { + let lhs = self.sym_name(rule.lhs()); + let mut rhs = rule + .rhs + .iter() + .enumerate() + .map(|(i, s)| { + format!( + "{}{}", + if i == dot { "(*) " } else { "" }, + self.sym_name(*s) + ) + }) + .collect::>() + .join(" "); + if rule.rhs.is_empty() { + rhs.push_str("ϵ"); + } + if dot == rule.rhs.len() { + rhs.push_str(" (*)"); + } + format!("{} ::= {}", lhs, rhs) + } + + fn copy_from(&mut self, other: &Grammar, sym: SymIdx) -> SymIdx { + let sym_data = other.sym_data(sym); + if sym_data.is_terminal() { + self.terminal(sym_data.bytes.clone().unwrap()) + } else { + self.symbol(&sym_data.name) + } + } + + fn collapse_terminals(&self) -> Self { + let mut outp = Grammar::new(); + for sym in &self.symbols { + if sym.rules.is_empty() { + continue; + } + let mut rules_by_shape = FxHashMap::default(); + for rule in &sym.rules { + let shape = rule + .rhs + .iter() + .map(|s| { + if self.sym_data(*s).is_terminal() { + None + } else { + Some(*s) + } + }) + .collect::>(); + rules_by_shape + .entry(shape) + .or_insert_with(Vec::new) + .push(rule); + } + let lhs = outp.copy_from(self, sym.idx); + for rules in rules_by_shape.values() { + let rhs = rules[0] + .rhs + .iter() + .enumerate() + .map(|(i, s)| { + let sym = self.sym_data(*s); + if sym.is_terminal() { + let terminals = rules + .iter() + .map(|r| self.sym_data(r.rhs[i]).bytes.clone().unwrap()); + outp.terminal(ByteSet::from_sum(terminals)) + } else { + outp.copy_from(self, *s) + } + }) + .collect(); + outp.add_rule(lhs, rhs); + } + } + outp + } + + fn expand_shortcuts(&self) -> Self { + let mut use_count = vec![0; self.symbols.len()]; + for sym in &self.symbols { + for r in sym.rules.iter() { + for s in &r.rhs { + use_count[s.0 as usize] += 1; + } + } + } + + let mut repl = FxHashMap::default(); + for sym in &self.symbols { + if sym.idx == self.start() { + continue; + } + if sym.rules.len() == 1 + && (use_count[sym.idx.0 as usize] == 1 || sym.rules[0].rhs.len() == 1) + { + // eliminate sym.idx + repl.insert(sym.idx, sym.rules[0].rhs.clone()); + } + } + + // fix-point expand the mapping + loop { + let to_change = repl + .iter() + .filter_map(|(lhs, rhs)| { + let rhs2 = rhs + .iter() + .flat_map(|s| repl.get(s).cloned().unwrap_or_else(|| vec![*s])) + .collect::>(); + assert!(rhs2.iter().all(|s| *s != *lhs), "cyclic?"); + if *rhs != rhs2 { + Some((*lhs, rhs2)) + } else { + None + } + }) + .collect::>(); + if to_change.is_empty() { + break; + } + for (lhs, rhs) in to_change { + repl.insert(lhs, rhs); + } + } + + let mut outp = Grammar::new(); + for sym in &self.symbols { + if repl.contains_key(&sym.idx) { + continue; + } + let lhs = outp.copy_from(self, sym.idx); + for rule in &sym.rules { + let rhs = rule + .rhs + .iter() + .flat_map(|s| repl.get(s).cloned().unwrap_or_else(|| vec![*s])) + .map(|s| outp.copy_from(self, s)) + .collect(); + outp.add_rule(lhs, rhs); + } + } + outp + } + + pub fn optimize(&self) -> Self { + self.expand_shortcuts() + .collapse_terminals() + .expand_shortcuts() + } + + pub fn compile(&self) -> OptGrammar { + OptGrammar::from_grammar(self) + } + + pub fn fresh_symbol(&mut self, name0: &str) -> SymIdx { + let mut name = name0.to_string(); + let mut idx = 2; + while self.symbol_by_name.contains_key(&name) { + name = format!("{}#{}", name0, idx); + idx += 1; + } + + let idx = 
SymIdx(self.symbols.len() as u32); + self.symbols.push(Symbol { + name: name.clone(), + bytes: None, + idx, + rules: vec![], + }); + self.symbol_by_name.insert(name, idx); + idx + } + + pub fn symbol(&mut self, name: &str) -> SymIdx { + match self.symbol_by_name.get(name) { + Some(idx) => *idx, + None => self.fresh_symbol(name), + } + } +} + +impl Debug for Grammar { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + for sym in &self.symbols { + match sym.bytes { + Some(ref bytes) if sym.name.starts_with("T@") => { + writeln!(f, "{} := {}", sym.name, bytes)? + } + _ => {} + } + } + let mut num_term = 0; + let mut num_rules = 0; + let mut num_non_term = 0; + for sym in &self.symbols { + if sym.is_terminal() { + num_term += 1; + } else { + num_non_term += 1; + num_rules += sym.rules.len(); + } + for rule in &sym.rules { + writeln!(f, "{}", self.rule_to_string(rule, usize::MAX))?; + } + } + writeln!( + f, + "stats: {} terminals; {} non-terminals with {} rules\n", + num_term, num_non_term, num_rules + )?; + Ok(()) + } +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +pub struct OptSymIdx(u16); + +impl OptSymIdx { + pub const NULL: OptSymIdx = OptSymIdx(0); + + pub fn as_index(&self) -> usize { + self.0 as usize + } +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +pub struct RuleIdx(u32); + +impl RuleIdx { + pub fn advance(&self) -> RuleIdx { + RuleIdx(self.0 + 1) + } + + pub fn as_index(&self) -> usize { + self.0 as usize + } +} + +pub struct OptSymbol { + pub idx: OptSymIdx, + pub name: String, + pub is_terminal: bool, + pub is_nullable: bool, + pub rules: Vec, +} + +pub struct OptGrammar { + start_symbol: OptSymIdx, + terminals: Vec, + symbols: Vec, + rules: Vec, + terminals_by_byte: Vec, +} + +impl OptGrammar { + pub fn sym_data(&self, sym: OptSymIdx) -> &OptSymbol { + &self.symbols[sym.0 as usize] + } + + fn sym_data_mut(&mut self, sym: OptSymIdx) -> &mut OptSymbol { + &mut self.symbols[sym.0 as usize] + } + + pub fn terminals_by_byte(&self, b: u8) -> &Vob { + &self.terminals_by_byte[b as usize] + } + + pub fn sym_idx_at(&self, idx: RuleIdx) -> OptSymIdx { + self.rules[idx.0 as usize] + } + + pub fn start(&self) -> OptSymIdx { + self.start_symbol + } + + pub fn is_accepting(&self, sym: OptSymIdx, rule: RuleIdx) -> bool { + sym == self.start() && self.sym_idx_at(rule) == OptSymIdx::NULL + } + + pub fn rules_of(&self, sym: OptSymIdx) -> &[RuleIdx] { + &self.sym_data(sym).rules + } + + fn from_grammar(grammar: &Grammar) -> Self { + let mut outp = OptGrammar { + start_symbol: OptSymIdx::NULL, + terminals: vec![ByteSet::new()], + symbols: vec![], + rules: vec![], + terminals_by_byte: vec![], + }; + let mut sym_map = FxHashMap::default(); + for (_, sidx) in &grammar.terminals { + let sym = grammar.sym_data(*sidx); + outp.terminals.push(sym.bytes.clone().unwrap()); + let idx = outp.symbols.len() as u16; + outp.symbols.push(OptSymbol { + idx: OptSymIdx(idx), + name: sym.name.clone(), + is_terminal: true, + is_nullable: false, + rules: vec![], + }); + sym_map.insert(sym.idx, OptSymIdx(idx)); + } + for sym in &grammar.symbols { + if sym.is_terminal() { + continue; + } + let idx = outp.symbols.len() as u16; + outp.symbols.push(OptSymbol { + idx: OptSymIdx(idx), + name: sym.name.clone(), + is_terminal: false, + is_nullable: sym.rules.iter().any(|r| r.rhs.is_empty()), + rules: vec![], + }); + sym_map.insert(sym.idx, OptSymIdx(idx)); + } + outp.start_symbol = sym_map[&grammar.start()]; + for sym in &grammar.symbols { + if sym.is_terminal() { + continue; + } + let 
idx = sym_map[&sym.idx]; + for rule in &sym.rules { + let curr = RuleIdx(outp.rules.len().try_into().unwrap()); + outp.sym_data_mut(idx).rules.push(curr); + // outp.rules.push(idx); + for r in &rule.rhs { + outp.rules.push(sym_map[r]); + } + outp.rules.push(OptSymIdx::NULL); + } + } + + loop { + let mut to_null = vec![]; + for sym in &outp.symbols { + if sym.is_nullable { + continue; + } + 'rules: for rule in sym.rules.iter() { + let mut idx = rule.as_index(); + while outp.rules[idx] != OptSymIdx::NULL { + if outp.sym_data(outp.rules[idx]).is_nullable { + to_null.push(sym.idx); + break 'rules; + } + idx += 1; + } + } + } + if to_null.is_empty() { + break; + } + for sym in to_null { + outp.sym_data_mut(sym).is_nullable = true; + } + } + + for b in 0..=255 { + let mut v = Vob::from_elem(false, outp.terminals.len()); + for (i, bytes) in outp.terminals.iter().enumerate() { + if bytes.contains(b as u8) { + v.set(i, true); + } + } + outp.terminals_by_byte.push(v); + } + outp + } +} diff --git a/controllers/aici_abi/src/earley/mod.rs b/controllers/aici_abi/src/earley/mod.rs index 224a9bd5..65c7de23 100644 --- a/controllers/aici_abi/src/earley/mod.rs +++ b/controllers/aici_abi/src/earley/mod.rs @@ -1,8 +1,10 @@ mod byteset; mod from_guidance; +mod grammar; mod guidance; mod parser; pub use byteset::ByteSet; pub use from_guidance::earley_test; -pub use parser::{Grammar, Parser}; +pub use parser::Parser; +pub use grammar::Grammar; diff --git a/controllers/aici_abi/src/earley/parser.rs b/controllers/aici_abi/src/earley/parser.rs index fff6527d..2b969cfb 100644 --- a/controllers/aici_abi/src/earley/parser.rs +++ b/controllers/aici_abi/src/earley/parser.rs @@ -1,100 +1,19 @@ -use std::{fmt::Debug, rc::Rc, vec}; +use std::{fmt::Debug, vec}; -use rustc_hash::FxHashMap; - -use super::{byteset::byte_to_string, ByteSet}; +use super::{ + byteset::byte_to_string, + grammar::{OptGrammar, OptSymIdx, RuleIdx}, +}; const DEBUG: bool = false; #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] -pub struct SymIdx(u32); - -// format: -// symbol_index : rule_index -#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] -pub struct RuleIdx { - data: u32, -} - -const SYM_IDX_BITS: u32 = 12; -const RULE_IDX_BITS: u32 = 10; -const DOT_POS_BITS: u32 = 7; -const TOK_POS_BITS: u32 = 64 - (DOT_POS_BITS + SYM_IDX_BITS + RULE_IDX_BITS); - -fn mask32(bits: u32) -> u32 { - (1 << bits) - 1 -} - -fn mask64(bits: u32) -> u64 { - (1u64 << (bits as u64)) - 1 -} - -impl RuleIdx { - fn sym_idx(&self) -> SymIdx { - SymIdx(self.data >> RULE_IDX_BITS) - } - - fn sym_rule_idx(&self) -> usize { - (self.data & mask32(RULE_IDX_BITS)) as usize - } -} - -impl Symbol { - fn is_terminal(&self) -> bool { - self.bytes.is_some() - } -} - -impl SymIdx { - fn rule_at(&self, rule: usize) -> RuleIdx { - assert!(rule < mask32(RULE_IDX_BITS) as usize); - RuleIdx { - data: (self.0 << RULE_IDX_BITS) | rule as u32, - } - } -} - -struct Symbol { - idx: SymIdx, - name: String, - bytes: Option, - rules: Vec, - nullable: bool, -} - -struct Rule { - idx: RuleIdx, - rhs: Vec, -} - -impl Rule { - fn lhs(&self) -> SymIdx { - self.idx.sym_idx() - } -} - -pub struct Grammar { - symbols: Vec, - symbol_by_name: FxHashMap, - terminals: FxHashMap, +struct Item { + rule_idx: RuleIdx, + start: u32, + sym_idx: OptSymIdx, } -#[derive(Clone)] -pub struct OptimizedGrammar { - inner: Rc, -} - -impl Debug for OptimizedGrammar { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - self.inner.fmt(f) - } -} - -// format: -// token_position : dot_position : 
symbol_index : rule_index -#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] -struct Item(u64); - pub struct Row { token: u8, position: usize, @@ -114,47 +33,43 @@ impl Row { } impl Item { - fn new(rule: RuleIdx, dot: usize, start: usize) -> Self { - assert!(start < mask64(TOK_POS_BITS) as usize); - assert!(dot < mask32(DOT_POS_BITS) as usize); - let data = (start as u64) << (DOT_POS_BITS + SYM_IDX_BITS + RULE_IDX_BITS) - | (dot as u64) << (SYM_IDX_BITS + RULE_IDX_BITS) - | (rule.data as u64); - Item(data) + fn new(sym: OptSymIdx, rule: RuleIdx, start: usize) -> Self { + Item { + sym_idx: sym, + rule_idx: rule, + start: start.try_into().unwrap(), + } } fn rule_idx(&self) -> RuleIdx { - RuleIdx { - data: self.0 as u32 & mask32(SYM_IDX_BITS + RULE_IDX_BITS), - } + self.rule_idx } - fn dot_pos(&self) -> usize { - (self.0 >> (SYM_IDX_BITS + RULE_IDX_BITS)) as usize & mask32(DOT_POS_BITS) as usize + fn sym_idx(&self) -> OptSymIdx { + self.sym_idx } fn start_pos(&self) -> usize { - ((self.0 >> (DOT_POS_BITS + SYM_IDX_BITS + RULE_IDX_BITS)) & mask64(TOK_POS_BITS)) as usize + self.start as usize } fn advance_dot(&self) -> Self { - Item::new(self.rule_idx(), self.dot_pos() + 1, self.start_pos()) + Item::new(self.sym_idx, self.rule_idx.advance(), self.start_pos()) } } pub struct Parser { - grammar: Rc, + grammar: OptGrammar, rows: Vec, } impl Parser { - pub fn new(grammar: OptimizedGrammar) -> Self { - let grammar = grammar.inner; + pub fn new(grammar: OptGrammar) -> Self { + let start = grammar.start(); let init_rules = grammar - .sym_data(grammar.start()) - .rules + .rules_of(start) .iter() - .map(|r| Item::new(r.idx, 0, 0)) + .map(|r| Item::new(start, *r, 0)) .collect(); let mut r = Parser { grammar, @@ -168,8 +83,12 @@ impl Parser { } fn item_to_string(&self, item: &Item) -> String { - let rule = self.grammar.rule_data(item.rule_idx()); - self.grammar.rule_to_string(rule, item.dot_pos()) + // let rule = self.grammar.rule_data(item.rule_idx()); + // self.grammar.rule_to_string(rule, item.dot_pos()) + format!( + "item: rule: {:?}, dot: {:?}, start: {}", + item.rule_idx, item.sym_idx, item.start + ) } pub fn row_to_string(&self, row: &Row) -> String { @@ -180,16 +99,15 @@ impl Parser { r.join("\n") + "\n" } - pub fn scan(&mut self, b: u8) -> Row { + pub fn scan(&self, b: u8) -> Row { + let allowed = self.grammar.terminals_by_byte(b); let mut r = vec![]; let row_idx = self.rows.len() - 1; for item in &self.rows[row_idx].items { - if let Some(s) = self.grammar.after_dot(*item) { - if let Some(bytes) = self.grammar.sym_data(s).bytes.clone() { - if bytes.contains(b) { - r.push(item.advance_dot()); - } - } + let idx = self.grammar.sym_idx_at(item.rule_idx()).as_index(); + assert!(idx != 0); + if idx < allowed.len() && allowed[idx] { + r.push(item.advance_dot()); } } self.make_row(r, b) @@ -208,16 +126,6 @@ impl Parser { self.rows.push(row); } - fn items_with_after_dot(&self, sym: SymIdx, row_idx: usize) -> Vec { - let mut r = vec![]; - for item in &self.rows[row_idx].items { - if self.grammar.after_dot(*item) == Some(sym) { - r.push(*item); - } - } - r - } - fn make_row(&self, mut curr_row: Vec, token: u8) -> Row { let curr_idx = self.rows.len(); let mut agenda = curr_row.clone(); @@ -239,10 +147,6 @@ impl Parser { if DEBUG { println!("from agenda: {}", self.item_to_string(&item)); } - let lhs = item.rule_idx().sym_idx(); - if lhs == self.grammar.start() && self.grammar.after_dot(item).is_none() { - accepting = true; - } let mut to_add = vec![]; let mut add = |new_item: Item, tag: &str| { if 
!to_add.contains(&new_item) { @@ -252,29 +156,35 @@ impl Parser { } } }; - match self.grammar.after_dot(item) { - Some(after_dot) => { - let sym_data = self.grammar.sym_data(after_dot); - if sym_data.nullable { - let new_item = item.advance_dot(); - add(new_item, "null"); - } - if !predicated_syms.contains(&after_dot) { - predicated_syms.push(after_dot); - for rule in &sym_data.rules { - let new_item = Item::new(rule.idx, 0, curr_idx); - add(new_item, "predict"); + + let lhs = item.sym_idx(); + let rule = item.rule_idx(); + let after_dot = self.grammar.sym_idx_at(rule); + + if after_dot == OptSymIdx::NULL { + // complete + if lhs == self.grammar.start() { + accepting = true; + } + + if item.start_pos() < curr_idx { + // if item.start_pos() == curr_idx, then we handled it above in the nullable check + for item in self.rows[item.start_pos()].items.iter() { + if self.grammar.sym_idx_at(item.rule_idx()) == lhs { + add(item.advance_dot(), "complete"); } } } - // complete - None => { - if item.start_pos() < curr_idx { - // if item.start_pos() == curr_idx, then we handled it above in the nullable check - for parent in self.items_with_after_dot(lhs, item.start_pos()) { - let new_item = parent.advance_dot(); - add(new_item, "complete"); - } + } else { + let sym_data = self.grammar.sym_data(after_dot); + if sym_data.is_nullable { + add(item.advance_dot(), "null"); + } + if !predicated_syms.contains(&after_dot) { + predicated_syms.push(after_dot); + for rule in &sym_data.rules { + let new_item = Item::new(after_dot, *rule, curr_idx); + add(new_item, "predict"); } } } @@ -295,324 +205,3 @@ impl Parser { } } } - -impl Grammar { - pub fn new() -> Self { - let mut r = Grammar { - symbols: vec![], - symbol_by_name: FxHashMap::default(), - terminals: FxHashMap::default(), - }; - let _ = r.symbol("_start"); - r - } - - pub fn start(&self) -> SymIdx { - self.symbols[0].idx - } - - fn sym_data(&self, sym: SymIdx) -> &Symbol { - &self.symbols[sym.0 as usize] - } - - fn sym_data_mut(&mut self, sym: SymIdx) -> &mut Symbol { - &mut self.symbols[sym.0 as usize] - } - - fn rule_data(&self, rule: RuleIdx) -> &Rule { - let sym = self.sym_data(rule.sym_idx()); - if rule.sym_rule_idx() >= sym.rules.len() { - println!("invalid rule index; {}", sym.name); - } - &sym.rules[rule.sym_rule_idx()] - } - - fn propagate_nullable(&mut self) { - for sym in self.symbols.iter_mut() { - if sym.rules.iter().any(|r| r.rhs.is_empty()) { - sym.nullable = true; - sym.rules.retain(|r| !r.rhs.is_empty()); - // re-number them - for (i, r) in sym.rules.iter_mut().enumerate() { - r.idx = sym.idx.rule_at(i); - } - } - } - loop { - let mut to_null = vec![]; - for sym in self.symbols.iter() { - for rule in sym.rules.iter() { - if rule.rhs.iter().all(|s| self.sym_data(*s).nullable) { - if !sym.nullable { - to_null.push(sym.idx); - } - } - } - } - if to_null.is_empty() { - break; - } - for sym in to_null { - self.sym_data_mut(sym).nullable = true; - } - } - } - - pub fn add_rule(&mut self, lhs: SymIdx, rhs: Vec) { - assert!(rhs.len() < mask32(DOT_POS_BITS) as usize); - assert!(!self.sym_data(lhs).is_terminal()); - let sym = self.sym_data_mut(lhs); - sym.rules.push(Rule { - idx: lhs.rule_at(sym.rules.len()), - rhs, - }); - } - - pub fn terminal(&mut self, bytes: ByteSet) -> SymIdx { - match self.terminals.get(&bytes) { - Some(sym) => *sym, - None => { - let mut name = format!("T:{}", bytes); - if name.len() > 40 { - name = format!("T@{}", self.terminals.len()); - } - let sym = self.fresh_symbol(&name); - self.sym_data_mut(sym).bytes = 
Some(bytes.clone()); - self.terminals.insert(bytes, sym); - sym - } - } - } - - pub fn sym_name(&self, sym: SymIdx) -> &str { - &self.symbols[sym.0 as usize].name - } - - fn rule_to_string(&self, rule: &Rule, dot: usize) -> String { - let lhs = self.sym_name(rule.lhs()); - let mut rhs = rule - .rhs - .iter() - .enumerate() - .map(|(i, s)| { - format!( - "{}{}", - if i == dot { "(*) " } else { "" }, - self.sym_name(*s) - ) - }) - .collect::>() - .join(" "); - if dot == rule.rhs.len() { - rhs.push_str(" (*)"); - } - format!("{} ::= {}", lhs, rhs) - } - - fn copy_from(&mut self, other: &Grammar, sym: SymIdx) -> SymIdx { - let sym_data = other.sym_data(sym); - if sym_data.is_terminal() { - self.terminal(sym_data.bytes.clone().unwrap()) - } else { - self.symbol(&sym_data.name) - } - } - - fn collapse_terminals(&self) -> Self { - let mut outp = Grammar::new(); - for sym in &self.symbols { - if sym.rules.is_empty() { - continue; - } - let mut rules_by_shape = FxHashMap::default(); - for rule in &sym.rules { - let shape = rule - .rhs - .iter() - .map(|s| { - if self.sym_data(*s).is_terminal() { - None - } else { - Some(*s) - } - }) - .collect::>(); - rules_by_shape - .entry(shape) - .or_insert_with(Vec::new) - .push(rule); - } - let lhs = outp.copy_from(self, sym.idx); - for rules in rules_by_shape.values() { - let rhs = rules[0] - .rhs - .iter() - .enumerate() - .map(|(i, s)| { - let sym = self.sym_data(*s); - if sym.is_terminal() { - let terminals = rules - .iter() - .map(|r| self.sym_data(r.rhs[i]).bytes.clone().unwrap()); - outp.terminal(ByteSet::from_sum(terminals)) - } else { - outp.copy_from(self, *s) - } - }) - .collect(); - outp.add_rule(lhs, rhs); - } - } - outp - } - - fn expand_shortcuts(&self) -> Self { - let mut use_count = vec![0; self.symbols.len()]; - for sym in &self.symbols { - for r in sym.rules.iter() { - for s in &r.rhs { - use_count[s.0 as usize] += 1; - } - } - } - - let mut repl = FxHashMap::default(); - for sym in &self.symbols { - if sym.idx == self.start() { - continue; - } - if sym.rules.len() == 1 - && (use_count[sym.idx.0 as usize] == 1 || sym.rules[0].rhs.len() == 1) - { - // eliminate sym.idx - repl.insert(sym.idx, sym.rules[0].rhs.clone()); - } - } - - // fix-point expand the mapping - loop { - let to_change = repl - .iter() - .filter_map(|(lhs, rhs)| { - let rhs2 = rhs - .iter() - .flat_map(|s| repl.get(s).cloned().unwrap_or_else(|| vec![*s])) - .collect::>(); - assert!(rhs2.iter().all(|s| *s != *lhs), "cyclic?"); - if *rhs != rhs2 { - Some((*lhs, rhs2)) - } else { - None - } - }) - .collect::>(); - if to_change.is_empty() { - break; - } - for (lhs, rhs) in to_change { - repl.insert(lhs, rhs); - } - } - - let mut outp = Grammar::new(); - for sym in &self.symbols { - if repl.contains_key(&sym.idx) { - continue; - } - let lhs = outp.copy_from(self, sym.idx); - for rule in &sym.rules { - let rhs = rule - .rhs - .iter() - .flat_map(|s| repl.get(s).cloned().unwrap_or_else(|| vec![*s])) - .map(|s| outp.copy_from(self, s)) - .collect(); - outp.add_rule(lhs, rhs); - } - } - outp - } - - pub fn optimize(&self) -> OptimizedGrammar { - let mut outp = self - .expand_shortcuts() - .collapse_terminals() - .expand_shortcuts(); - outp.propagate_nullable(); - OptimizedGrammar { - inner: Rc::new(outp), - } - } - - pub fn fresh_symbol(&mut self, name0: &str) -> SymIdx { - let mut name = name0.to_string(); - let mut idx = 2; - while self.symbol_by_name.contains_key(&name) { - name = format!("{}#{}", name0, idx); - idx += 1; - } - - let idx = SymIdx(self.symbols.len() as u32); - 
self.symbols.push(Symbol { - name: name.clone(), - bytes: None, - idx, - rules: vec![], - nullable: false, - }); - self.symbol_by_name.insert(name, idx); - idx - } - - pub fn symbol(&mut self, name: &str) -> SymIdx { - match self.symbol_by_name.get(name) { - Some(idx) => *idx, - None => self.fresh_symbol(name), - } - } - - fn after_dot(&self, item: Item) -> Option { - let rule = self.rule_data(item.rule_idx()); - if item.dot_pos() < rule.rhs.len() { - Some(rule.rhs[item.dot_pos()]) - } else { - None - } - } -} - -impl Debug for Grammar { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - for sym in &self.symbols { - match sym.bytes { - Some(ref bytes) if sym.name.starts_with("T@") => { - writeln!(f, "{} := {}", sym.name, bytes)? - } - _ => {} - } - } - let mut num_term = 0; - let mut num_rules = 0; - let mut num_non_term = 0; - for sym in &self.symbols { - if sym.is_terminal() { - num_term += 1; - } else { - num_non_term += 1; - num_rules += sym.rules.len(); - } - if sym.nullable { - writeln!(f, "{} ::= ϵ", sym.name)?; - } - for rule in &sym.rules { - writeln!(f, "{}", self.rule_to_string(rule, usize::MAX))?; - } - } - writeln!( - f, - "stats: {} terminals; {} non-terminals with {} rules\n", - num_term, num_non_term, num_rules - )?; - Ok(()) - } -} From 04acaf08e450529f999b68837d3215de22762b61 Mon Sep 17 00:00:00 2001 From: Michal Moskal Date: Sat, 2 Mar 2024 02:58:17 +0000 Subject: [PATCH 154/301] fixes --- controllers/aici_abi/src/earley/grammar.rs | 8 +++++++- controllers/aici_abi/src/earley/parser.rs | 2 +- 2 files changed, 8 insertions(+), 2 deletions(-) diff --git a/controllers/aici_abi/src/earley/grammar.rs b/controllers/aici_abi/src/earley/grammar.rs index 83b8acba..16409972 100644 --- a/controllers/aici_abi/src/earley/grammar.rs +++ b/controllers/aici_abi/src/earley/grammar.rs @@ -377,7 +377,13 @@ impl OptGrammar { let mut outp = OptGrammar { start_symbol: OptSymIdx::NULL, terminals: vec![ByteSet::new()], - symbols: vec![], + symbols: vec![OptSymbol { + idx: OptSymIdx::NULL, + name: "NULL".to_string(), + is_terminal: true, + is_nullable: false, + rules: vec![], + }], rules: vec![], terminals_by_byte: vec![], }; diff --git a/controllers/aici_abi/src/earley/parser.rs b/controllers/aici_abi/src/earley/parser.rs index 2b969cfb..52b4b54f 100644 --- a/controllers/aici_abi/src/earley/parser.rs +++ b/controllers/aici_abi/src/earley/parser.rs @@ -105,7 +105,7 @@ impl Parser { let row_idx = self.rows.len() - 1; for item in &self.rows[row_idx].items { let idx = self.grammar.sym_idx_at(item.rule_idx()).as_index(); - assert!(idx != 0); + // idx == 0 => completed if idx < allowed.len() && allowed[idx] { r.push(item.advance_dot()); } From 170a159f88b22e9eb1511466a99d7a0c01c107a4 Mon Sep 17 00:00:00 2001 From: Michal Moskal Date: Mon, 4 Mar 2024 00:01:17 +0000 Subject: [PATCH 155/301] speed up --- .../aici_abi/src/earley/from_guidance.rs | 23 +- controllers/aici_abi/src/earley/grammar.rs | 4 + controllers/aici_abi/src/earley/parser.rs | 211 +++++++++++------- 3 files changed, 143 insertions(+), 95 deletions(-) diff --git a/controllers/aici_abi/src/earley/from_guidance.rs b/controllers/aici_abi/src/earley/from_guidance.rs index 5c93a8a9..73657d68 100644 --- a/controllers/aici_abi/src/earley/from_guidance.rs +++ b/controllers/aici_abi/src/earley/from_guidance.rs @@ -3,7 +3,10 @@ use quick_protobuf::MessageRead; use rustc_hash::FxHashSet; use super::{guidance, ByteSet, Grammar, Parser}; -use crate::toktree::{Recognizer, SpecialToken, TokTrie}; +use crate::{ + 
earley::parser::ParseResult, + toktree::{Recognizer, SpecialToken, TokTrie}, +}; pub fn earley_grm_from_guidance(bytes: &[u8]) -> Result { let mut reader = quick_protobuf::BytesReader::from_bytes(bytes); @@ -83,7 +86,7 @@ impl Recognizer for Parser { fn special_allowed(&mut self, tok: SpecialToken) -> bool { if tok == SpecialToken::EndOfSentence { - self.curr_row().is_accepting() + self.is_accepting() } else { false } @@ -94,11 +97,10 @@ impl Recognizer for Parser { } fn try_push_byte(&mut self, byte: u8) -> bool { - let row = self.scan(byte); - if row.is_empty() { + let res = self.scan(byte); + if res == ParseResult::Reject { false } else { - self.push_row(row); true } } @@ -118,14 +120,16 @@ pub fn earley_test(trie: TokTrie) { println!("toks: {:?}", toks.len()); let mut parser = Parser::new(cfg.compile()); + let mut last_res = ParseResult::Reject; for b in input { - let row = parser.scan(*b); - if row.is_empty() { + last_res = parser.scan(*b); + if last_res == ParseResult::Reject { println!("reject"); break; } - // println!("row: {}", parser.row_to_string(&row)); - parser.push_row(row); + } + if last_res != ParseResult::Accept { + println!("final non-accept"); } #[cfg(not(target_arch = "wasm32"))] @@ -142,6 +146,7 @@ pub fn earley_test(trie: TokTrie) { let tok = *tok; let tt = std::time::Instant::now(); trie.compute_bias(&mut parser, &mut vob); + // parser.print_stats(); if !vob.is_allowed(tok) { println!("reject, line={}, tok={:?}", line, trie.token_str(tok)); panic!(); diff --git a/controllers/aici_abi/src/earley/grammar.rs b/controllers/aici_abi/src/earley/grammar.rs index 16409972..8fbb7f7c 100644 --- a/controllers/aici_abi/src/earley/grammar.rs +++ b/controllers/aici_abi/src/earley/grammar.rs @@ -357,6 +357,10 @@ impl OptGrammar { &self.terminals_by_byte[b as usize] } + pub fn terminal_allowed(&self, b: u8, sym: OptSymIdx) -> bool { + self.terminals_by_byte[b as usize].get(sym.0 as usize) == Some(true) + } + pub fn sym_idx_at(&self, idx: RuleIdx) -> OptSymIdx { self.rules[idx.0 as usize] } diff --git a/controllers/aici_abi/src/earley/parser.rs b/controllers/aici_abi/src/earley/parser.rs index 52b4b54f..735fb485 100644 --- a/controllers/aici_abi/src/earley/parser.rs +++ b/controllers/aici_abi/src/earley/parser.rs @@ -1,9 +1,6 @@ use std::{fmt::Debug, vec}; -use super::{ - byteset::byte_to_string, - grammar::{OptGrammar, OptSymIdx, RuleIdx}, -}; +use super::grammar::{OptGrammar, OptSymIdx, RuleIdx}; const DEBUG: bool = false; @@ -14,22 +11,25 @@ struct Item { sym_idx: OptSymIdx, } -pub struct Row { - token: u8, - position: usize, - // TODO index this by .after_dot() ? 
- items: Vec, - accepting: bool, +#[derive(Debug, Default)] +pub struct Stats { + pub rows: usize, + pub empty_rows: usize, + pub nontrivial_scans: usize, + pub scan_items: usize, + pub all_items: usize, } -impl Row { - pub fn is_empty(&self) -> bool { - self.items.is_empty() - } +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum ParseResult { + Accept, + Reject, + Continue, +} - pub fn is_accepting(&self) -> bool { - self.accepting - } +struct Row { + first_item: usize, + last_item: usize, } impl Item { @@ -58,30 +58,68 @@ impl Item { } } +#[derive(Default)] +struct Scratch { + row_start: usize, + row_end: usize, + items: Vec, +} + pub struct Parser { grammar: OptGrammar, + scratch: Scratch, rows: Vec, + stats: Stats, + is_accepting: bool, +} + +impl Scratch { + fn row_len(&self) -> usize { + self.row_end - self.row_start + } + + fn ensure_items(&mut self, n: usize) { + if self.items.len() < n { + let missing = n - self.items.len(); + self.items.reserve(missing); + unsafe { self.items.set_len(n) } + } + } + + fn just_add(&mut self, item: Item) { + self.ensure_items(self.row_end + 1); + self.items[self.row_end] = item; + self.row_end += 1; + } + + fn add_unique(&mut self, item: Item, _info: &str) { + if !self.items[self.row_start..self.row_end].contains(&item) { + self.just_add(item); + } + } } impl Parser { pub fn new(grammar: OptGrammar) -> Self { let start = grammar.start(); - let init_rules = grammar - .rules_of(start) - .iter() - .map(|r| Item::new(start, *r, 0)) - .collect(); let mut r = Parser { grammar, rows: vec![], + scratch: Scratch::default(), + stats: Stats::default(), + is_accepting: false, }; - // '0' token is bogus - let row = r.make_row(init_rules, 0); - println!("init: {}", r.row_to_string(&row)); - r.push_row(row); + for rule in r.grammar.rules_of(start).to_vec() { + r.scratch.add_unique(Item::new(start, rule, 0), "init"); + } + let _ = r.push_row(); r } + pub fn is_accepting(&self) -> bool { + self.is_accepting + } + fn item_to_string(&self, item: &Item) -> String { // let rule = self.grammar.rule_data(item.rule_idx()); // self.grammar.rule_to_string(rule, item.dot_pos()) @@ -91,71 +129,64 @@ impl Parser { ) } - pub fn row_to_string(&self, row: &Row) -> String { - let mut r = vec![format!("token: {}", byte_to_string(row.token))]; - for item in &row.items { - r.push(self.item_to_string(item)); - } - r.join("\n") + "\n" - } + // fn row_to_string(&self, row: &Row) -> String { + // // let mut r = vec![format!("token: {}", byte_to_string(row.token))]; + // // for item in &row.items { + // // r.push(self.item_to_string(item)); + // // } + // // r.join("\n") + "\n" + // "todo".to_string() + // } - pub fn scan(&self, b: u8) -> Row { - let allowed = self.grammar.terminals_by_byte(b); - let mut r = vec![]; + pub fn scan(&mut self, b: u8) -> ParseResult { let row_idx = self.rows.len() - 1; - for item in &self.rows[row_idx].items { + let last = self.rows[row_idx].last_item; + let mut i = self.rows[row_idx].first_item; + let n = last - i; + self.scratch.ensure_items(last + n + 100); + + let allowed = self.grammar.terminals_by_byte(b); + + // for next row: + self.scratch.row_start = last; + self.scratch.row_end = last; + + while i < last { + let item = self.scratch.items[i]; let idx = self.grammar.sym_idx_at(item.rule_idx()).as_index(); // idx == 0 => completed if idx < allowed.len() && allowed[idx] { - r.push(item.advance_dot()); + self.scratch.just_add(item.advance_dot()); } + i += 1; } - self.make_row(r, b) + self.push_row() } pub fn pop_rows(&mut self, n: usize) { 
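+        // Annotation for clarity: drops the most recent n Earley rows; the
+        // token trie uses this to backtrack byte-by-byte (see
+        // Recognizer::pop_bytes in from_guidance.rs, which forwards here).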
self.rows.drain(self.rows.len() - n..); } - pub fn curr_row(&self) -> &Row { - &self.rows[self.rows.len() - 1] - } - - pub fn push_row(&mut self, row: Row) { - assert!(row.position == self.rows.len()); - self.rows.push(row); + pub fn print_stats(&mut self) { + println!("stats: {:?}", self.stats); + self.stats = Stats::default(); } - fn make_row(&self, mut curr_row: Vec, token: u8) -> Row { + fn push_row(&mut self) -> ParseResult { let curr_idx = self.rows.len(); - let mut agenda = curr_row.clone(); + let mut agenda_ptr = self.scratch.row_start; + let mut predicated_syms = vec![]; - let mut accepting = false; - - if DEBUG { - let row0 = Row { - token, - position: curr_idx, - items: curr_row.clone(), - accepting, - }; - println!("row0: {}", self.row_to_string(&row0)); - } - while !agenda.is_empty() { - let item = agenda.pop().unwrap(); + self.stats.rows += 1; + self.is_accepting = false; + + while agenda_ptr < self.scratch.row_end { + let item = self.scratch.items[agenda_ptr]; + agenda_ptr += 1; if DEBUG { println!("from agenda: {}", self.item_to_string(&item)); } - let mut to_add = vec![]; - let mut add = |new_item: Item, tag: &str| { - if !to_add.contains(&new_item) { - to_add.push(new_item); - if DEBUG { - println!(" adding {}: {}", tag, self.item_to_string(&new_item)); - } - } - }; let lhs = item.sym_idx(); let rule = item.rule_idx(); @@ -164,44 +195,52 @@ impl Parser { if after_dot == OptSymIdx::NULL { // complete if lhs == self.grammar.start() { - accepting = true; + self.is_accepting = true; } if item.start_pos() < curr_idx { // if item.start_pos() == curr_idx, then we handled it above in the nullable check - for item in self.rows[item.start_pos()].items.iter() { + let srow = &self.rows[item.start_pos()]; + for i in srow.first_item..srow.last_item { + let item = self.scratch.items[i]; if self.grammar.sym_idx_at(item.rule_idx()) == lhs { - add(item.advance_dot(), "complete"); + self.scratch.add_unique(item.advance_dot(), "complete"); } } } } else { let sym_data = self.grammar.sym_data(after_dot); if sym_data.is_nullable { - add(item.advance_dot(), "null"); + self.scratch.add_unique(item.advance_dot(), "null"); } + // TODO this is slow if !predicated_syms.contains(&after_dot) { predicated_syms.push(after_dot); for rule in &sym_data.rules { let new_item = Item::new(after_dot, *rule, curr_idx); - add(new_item, "predict"); + self.scratch.add_unique(new_item, "predict"); } } } + } - for new_item in to_add { - if !curr_row.contains(&new_item) { - curr_row.push(new_item); - agenda.push(new_item); - } - } + let row_len = self.scratch.row_len(); + self.stats.all_items += row_len; + + if row_len == 0 { + assert!(!self.is_accepting); + return ParseResult::Reject; } - Row { - token, - position: curr_idx, - items: curr_row, - accepting, + self.rows.push(Row { + first_item: self.scratch.row_start, + last_item: self.scratch.row_end, + }); + + if self.is_accepting { + ParseResult::Accept + } else { + ParseResult::Continue } } } From 18cfcf8c87808290745ba44d3ce1427fdc7a1ddc Mon Sep 17 00:00:00 2001 From: Michal Moskal Date: Mon, 4 Mar 2024 00:30:24 +0000 Subject: [PATCH 156/301] more speed up --- .../aici_abi/src/earley/from_guidance.rs | 77 ++++++++++++------- controllers/aici_abi/src/earley/parser.rs | 26 ++++--- 2 files changed, 64 insertions(+), 39 deletions(-) diff --git a/controllers/aici_abi/src/earley/from_guidance.rs b/controllers/aici_abi/src/earley/from_guidance.rs index 73657d68..5ac2c0e0 100644 --- a/controllers/aici_abi/src/earley/from_guidance.rs +++ 
b/controllers/aici_abi/src/earley/from_guidance.rs @@ -1,3 +1,5 @@ +use std::time::Duration; + use anyhow::Result; use quick_protobuf::MessageRead; use rustc_hash::FxHashSet; @@ -132,40 +134,59 @@ pub fn earley_test(trie: TokTrie) { println!("final non-accept"); } - #[cfg(not(target_arch = "wasm32"))] - let t0 = std::time::Instant::now(); + const NUM_REP: usize = 200; + let mut durations = vec![]; + println!("start!"); - let mut line = 1; - let mut vob = trie.alloc_token_set(); + for _ in 0..NUM_REP { + #[cfg(not(target_arch = "wasm32"))] + let t0 = std::time::Instant::now(); - parser = Parser::new(cfg.compile()); - println!("start!"); - let mut times = vec![]; - - for tok in &toks { - let tok = *tok; - let tt = std::time::Instant::now(); - trie.compute_bias(&mut parser, &mut vob); - // parser.print_stats(); - if !vob.is_allowed(tok) { - println!("reject, line={}, tok={:?}", line, trie.token_str(tok)); - panic!(); - } - for b in trie.token(tok) { - if *b == b'\n' { - line += 1; + let mut line = 1; + let mut vob = trie.alloc_token_set(); + + const COLLECT_TIMES: bool = false; + + parser = Parser::new(cfg.compile()); + let mut times = vec![]; + + for tok in &toks { + let tok = *tok; + let tt = std::time::Instant::now(); + trie.compute_bias(&mut parser, &mut vob); + // parser.print_stats(); + if !vob.is_allowed(tok) { + println!("reject, line={}, tok={:?}", line, trie.token_str(tok)); + panic!(); + } + for b in trie.token(tok) { + if *b == b'\n' { + line += 1; + } + } + // println!("TOK: {} ===> {}", trie.token_dbg(tok), trie.token_set_dbg(&vob)); + trie.append_token(&mut parser, tok); + if COLLECT_TIMES { + times.push(tt.elapsed().as_micros() as u32); } } - // println!("TOK: {} ===> {}", trie.token_dbg(tok), trie.token_set_dbg(&vob)); - trie.append_token(&mut parser, tok); - times.push(tt.elapsed().as_micros() as u32); + + durations.push(t0.elapsed()); + + if COLLECT_TIMES { + println!("times: {:?}", times); + } + } + + let mut total = Duration::ZERO; + for d in &durations { + total += *d; } - #[cfg(not(target_arch = "wasm32"))] println!( - "time: {:?} ({:?}/tok)", - t0.elapsed(), - t0.elapsed() / toks.len() as u32 + "time: {:?} - {:?} - {:?}", + durations.iter().min().unwrap(), + total / durations.len() as u32, + durations.iter().max().unwrap(), ); - println!("times: {:?}", times); } diff --git a/controllers/aici_abi/src/earley/parser.rs b/controllers/aici_abi/src/earley/parser.rs index 735fb485..e6fb6980 100644 --- a/controllers/aici_abi/src/earley/parser.rs +++ b/controllers/aici_abi/src/earley/parser.rs @@ -1,4 +1,4 @@ -use std::{fmt::Debug, vec}; +use std::{fmt::Debug, ops::Range, vec}; use super::grammar::{OptGrammar, OptSymIdx, RuleIdx}; @@ -32,6 +32,12 @@ struct Row { last_item: usize, } +impl Row { + fn item_indices(&self) -> Range { + self.first_item..self.last_item + } +} + impl Item { fn new(sym: OptSymIdx, rule: RuleIdx, start: usize) -> Self { Item { @@ -63,6 +69,7 @@ struct Scratch { row_start: usize, row_end: usize, items: Vec, + predicated_syms: Vec, } pub struct Parser { @@ -164,7 +171,8 @@ impl Parser { } pub fn pop_rows(&mut self, n: usize) { - self.rows.drain(self.rows.len() - n..); + unsafe { self.rows.set_len(self.rows.len() - n) } + // self.rows.drain(self.rows.len() - n..); } pub fn print_stats(&mut self) { @@ -176,7 +184,7 @@ impl Parser { let curr_idx = self.rows.len(); let mut agenda_ptr = self.scratch.row_start; - let mut predicated_syms = vec![]; + self.scratch.predicated_syms.clear(); self.stats.rows += 1; self.is_accepting = false; @@ -194,14 +202,11 @@ impl 
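
The benchmarking change above repeats the whole parse NUM_REP times and reports min/mean/max, which is far less noise-sensitive than timing a single run. A sketch of the same harness shape, with `run_once` standing in for one full parse over the token stream:

    use std::time::{Duration, Instant};

    fn bench(mut run_once: impl FnMut(), reps: usize) {
        let mut durations: Vec<Duration> = Vec::with_capacity(reps);
        for _ in 0..reps {
            let t0 = Instant::now();
            run_once();
            durations.push(t0.elapsed());
        }
        let total: Duration = durations.iter().sum();
        println!(
            "time: {:?} - {:?} - {:?}",
            durations.iter().min().unwrap(),
            total / durations.len() as u32,
            durations.iter().max().unwrap()
        );
    }
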
Parser { if after_dot == OptSymIdx::NULL { // complete - if lhs == self.grammar.start() { - self.is_accepting = true; - } + self.is_accepting = self.is_accepting || lhs == self.grammar.start(); if item.start_pos() < curr_idx { // if item.start_pos() == curr_idx, then we handled it above in the nullable check - let srow = &self.rows[item.start_pos()]; - for i in srow.first_item..srow.last_item { + for i in self.rows[item.start_pos()].item_indices() { let item = self.scratch.items[i]; if self.grammar.sym_idx_at(item.rule_idx()) == lhs { self.scratch.add_unique(item.advance_dot(), "complete"); @@ -213,9 +218,8 @@ impl Parser { if sym_data.is_nullable { self.scratch.add_unique(item.advance_dot(), "null"); } - // TODO this is slow - if !predicated_syms.contains(&after_dot) { - predicated_syms.push(after_dot); + if !self.scratch.predicated_syms.contains(&after_dot) { + self.scratch.predicated_syms.push(after_dot); for rule in &sym_data.rules { let new_item = Item::new(after_dot, *rule, curr_idx); self.scratch.add_unique(new_item, "predict"); From f3227fbe4f8d3773e2596e42a148c1b68afa9139 Mon Sep 17 00:00:00 2001 From: Michal Moskal Date: Mon, 4 Mar 2024 00:45:19 +0000 Subject: [PATCH 157/301] logging --- controllers/aici_abi/src/earley/from_guidance.rs | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/controllers/aici_abi/src/earley/from_guidance.rs b/controllers/aici_abi/src/earley/from_guidance.rs index 5ac2c0e0..fc032412 100644 --- a/controllers/aici_abi/src/earley/from_guidance.rs +++ b/controllers/aici_abi/src/earley/from_guidance.rs @@ -134,7 +134,8 @@ pub fn earley_test(trie: TokTrie) { println!("final non-accept"); } - const NUM_REP: usize = 200; + const COLLECT_TIMES: bool = false; + const NUM_REP: usize = if COLLECT_TIMES { 5 } else { 200 }; let mut durations = vec![]; println!("start!"); @@ -145,8 +146,6 @@ pub fn earley_test(trie: TokTrie) { let mut line = 1; let mut vob = trie.alloc_token_set(); - const COLLECT_TIMES: bool = false; - parser = Parser::new(cfg.compile()); let mut times = vec![]; From 599911a2932b1330a50a7eb1fa556b1da18602f9 Mon Sep 17 00:00:00 2001 From: Michal Moskal Date: Mon, 4 Mar 2024 00:57:24 +0000 Subject: [PATCH 158/301] hashing items --- .../aici_abi/src/earley/from_guidance.rs | 6 +- controllers/aici_abi/src/earley/grammar.rs | 18 +++++ controllers/aici_abi/src/earley/parser.rs | 72 ++++++++++++++++--- 3 files changed, 86 insertions(+), 10 deletions(-) diff --git a/controllers/aici_abi/src/earley/from_guidance.rs b/controllers/aici_abi/src/earley/from_guidance.rs index fc032412..7caaacad 100644 --- a/controllers/aici_abi/src/earley/from_guidance.rs +++ b/controllers/aici_abi/src/earley/from_guidance.rs @@ -163,7 +163,11 @@ pub fn earley_test(trie: TokTrie) { line += 1; } } - // println!("TOK: {} ===> {}", trie.token_dbg(tok), trie.token_set_dbg(&vob)); + // println!( + // "TOK: {} ===> {}", + // trie.token_dbg(tok), + // trie.token_set_dbg(&vob) + // ); trie.append_token(&mut parser, tok); if COLLECT_TIMES { times.push(tt.elapsed().as_micros() as u32); diff --git a/controllers/aici_abi/src/earley/grammar.rs b/controllers/aici_abi/src/earley/grammar.rs index 8fbb7f7c..05ff1500 100644 --- a/controllers/aici_abi/src/earley/grammar.rs +++ b/controllers/aici_abi/src/earley/grammar.rs @@ -315,6 +315,24 @@ impl OptSymIdx { } } +pub trait SimpleHash { + fn simple_hash(&self) -> u32; + + fn mask64(&self) -> u64 { + 1 << (self.simple_hash() & 63) + } + + fn mask32(&self) -> u32 { + 1 << (self.simple_hash() & 31) + } +} + +impl SimpleHash for 
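
pop_rows() in patch 156 above switches from drain() to an unsafe set_len shrink; that is only sound because Row holds plain indices with no destructor. A conservative stand-alone version of the same shortcut, using T: Copy as a compile-time proxy for "no drop glue":

    fn truncate_fast<T: Copy>(v: &mut Vec<T>, new_len: usize) {
        assert!(new_len <= v.len());
        // SAFETY: new_len <= len and T: Copy, so no destructor is skipped
        // and every retained element stays initialized
        unsafe { v.set_len(new_len) }
    }
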
OptSymIdx { + fn simple_hash(&self) -> u32 { + (self.0 as u32).wrapping_mul(79667123) + } +} + #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] pub struct RuleIdx(u32); diff --git a/controllers/aici_abi/src/earley/parser.rs b/controllers/aici_abi/src/earley/parser.rs index e6fb6980..5cfceb45 100644 --- a/controllers/aici_abi/src/earley/parser.rs +++ b/controllers/aici_abi/src/earley/parser.rs @@ -1,6 +1,6 @@ -use std::{fmt::Debug, ops::Range, vec}; +use std::{fmt::Debug, hash::Hash, ops::Range, vec}; -use super::grammar::{OptGrammar, OptSymIdx, RuleIdx}; +use super::grammar::{OptGrammar, OptSymIdx, RuleIdx, SimpleHash}; const DEBUG: bool = false; @@ -64,12 +64,59 @@ impl Item { } } +impl SimpleHash for Item { + fn simple_hash(&self) -> u32 { + (self.rule_idx.as_index() as u32) + .wrapping_mul(16315967) + .wrapping_add((self.start as u32).wrapping_mul(33398653)) + } +} + +struct SimpleSet { + hash: u64, + items: Vec, +} + +impl Default for SimpleSet { + fn default() -> Self { + SimpleSet { + hash: 0, + items: vec![], + } + } +} + +impl SimpleSet { + fn clear(&mut self) { + self.hash = 0; + self.items.clear(); + } + + fn insert(&mut self, item: T) { + let mask = item.mask64(); + if (self.hash & mask) != 0 && self.items.contains(&item) { + return; + } + self.hash |= mask; + self.items.push(item); + } + + fn contains(&self, item: T) -> bool { + if (item.mask64() & self.hash) == 0 { + false + } else { + self.items.contains(&item) + } + } +} + #[derive(Default)] struct Scratch { row_start: usize, row_end: usize, + row_hash: u64, items: Vec, - predicated_syms: Vec, + predicated_syms: SimpleSet, } pub struct Parser { @@ -81,6 +128,12 @@ pub struct Parser { } impl Scratch { + fn new_row(&mut self, pos: usize) { + self.row_start = pos; + self.row_end = pos; + self.row_hash = 0; + } + fn row_len(&self) -> usize { self.row_end - self.row_start } @@ -97,10 +150,13 @@ impl Scratch { self.ensure_items(self.row_end + 1); self.items[self.row_end] = item; self.row_end += 1; + self.row_hash |= item.mask64(); } fn add_unique(&mut self, item: Item, _info: &str) { - if !self.items[self.row_start..self.row_end].contains(&item) { + if self.row_hash & item.mask64() == 0 + || !self.items[self.row_start..self.row_end].contains(&item) + { self.just_add(item); } } @@ -154,9 +210,7 @@ impl Parser { let allowed = self.grammar.terminals_by_byte(b); - // for next row: - self.scratch.row_start = last; - self.scratch.row_end = last; + self.scratch.new_row(last); while i < last { let item = self.scratch.items[i]; @@ -218,8 +272,8 @@ impl Parser { if sym_data.is_nullable { self.scratch.add_unique(item.advance_dot(), "null"); } - if !self.scratch.predicated_syms.contains(&after_dot) { - self.scratch.predicated_syms.push(after_dot); + if !self.scratch.predicated_syms.contains(after_dot) { + self.scratch.predicated_syms.insert(after_dot); for rule in &sym_data.rules { let new_item = Item::new(after_dot, *rule, curr_idx); self.scratch.add_unique(new_item, "predict"); From f5f9baa05d0b6b3ebc07bd2ff993371060bbded9 Mon Sep 17 00:00:00 2001 From: Michal Moskal Date: Mon, 4 Mar 2024 02:13:28 +0000 Subject: [PATCH 159/301] more speed up --- controllers/aici_abi/disasm.sh | 14 ++++++++++++++ controllers/aici_abi/src/earley/parser.rs | 12 ++++++------ 2 files changed, 20 insertions(+), 6 deletions(-) create mode 100755 controllers/aici_abi/disasm.sh diff --git a/controllers/aici_abi/disasm.sh b/controllers/aici_abi/disasm.sh new file mode 100755 index 00000000..1dad5260 --- /dev/null +++ b/controllers/aici_abi/disasm.sh @@ -0,0 +1,14 
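
SimpleSet above pairs a linear-scan Vec with a 64-bit fingerprint so most negative membership probes never touch the Vec at all. A distilled sketch of the mechanism; `TinySet` is illustrative, and the multiplier simply mirrors simple_hash above:

    fn mask64(x: u32) -> u64 {
        1u64 << (x.wrapping_mul(79667123) & 63)
    }

    #[derive(Default)]
    struct TinySet {
        hash: u64,       // union of mask64() over everything inserted
        items: Vec<u32>,
    }

    impl TinySet {
        fn insert(&mut self, item: u32) {
            let m = mask64(item);
            // pay for the linear scan only when the fingerprint says "maybe"
            if self.hash & m != 0 && self.items.contains(&item) {
                return;
            }
            self.hash |= m;
            self.items.push(item);
        }
        fn contains(&self, item: u32) -> bool {
            self.hash & mask64(item) != 0 && self.items.contains(&item)
        }
    }

In the patch this backs predicated_syms, where the common case is a miss, so the one-word prefilter pays for itself.
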
@@ +#!/bin/sh + +RUSTFLAGS="--emit asm" cargo build --release --target x86_64-unknown-linux-gnu +F=`echo ../../target/x86_64-unknown-linux-gnu/release/deps/aici_abi-*.s` +# if $F has more than one file +if [ `echo $F | wc -w` -gt 1 ]; then + echo "More than one file found: $F; removing; try again" + rm -f $F + exit 1 +fi + +mkdir -p tmp + +rustfilt < $F > tmp/aici_abi.s diff --git a/controllers/aici_abi/src/earley/parser.rs b/controllers/aici_abi/src/earley/parser.rs index 5cfceb45..77f71090 100644 --- a/controllers/aici_abi/src/earley/parser.rs +++ b/controllers/aici_abi/src/earley/parser.rs @@ -114,7 +114,6 @@ impl SimpleSet { struct Scratch { row_start: usize, row_end: usize, - row_hash: u64, items: Vec, predicated_syms: SimpleSet, } @@ -131,13 +130,13 @@ impl Scratch { fn new_row(&mut self, pos: usize) { self.row_start = pos; self.row_end = pos; - self.row_hash = 0; } fn row_len(&self) -> usize { self.row_end - self.row_start } + #[inline(always)] fn ensure_items(&mut self, n: usize) { if self.items.len() < n { let missing = n - self.items.len(); @@ -146,17 +145,16 @@ impl Scratch { } } + #[inline(always)] fn just_add(&mut self, item: Item) { self.ensure_items(self.row_end + 1); self.items[self.row_end] = item; self.row_end += 1; - self.row_hash |= item.mask64(); } + #[inline(always)] fn add_unique(&mut self, item: Item, _info: &str) { - if self.row_hash & item.mask64() == 0 - || !self.items[self.row_start..self.row_end].contains(&item) - { + if !self.items[self.row_start..self.row_end].contains(&item) { self.just_add(item); } } @@ -201,6 +199,7 @@ impl Parser { // "todo".to_string() // } + #[inline(always)] pub fn scan(&mut self, b: u8) -> ParseResult { let row_idx = self.rows.len() - 1; let last = self.rows[row_idx].last_item; @@ -234,6 +233,7 @@ impl Parser { self.stats = Stats::default(); } + #[inline(always)] fn push_row(&mut self) -> ParseResult { let curr_idx = self.rows.len(); let mut agenda_ptr = self.scratch.row_start; From da12c7733ad2634f23513eda5159ccbca40c54a5 Mon Sep 17 00:00:00 2001 From: Michal Moskal Date: Mon, 4 Mar 2024 23:35:20 +0000 Subject: [PATCH 160/301] disasm scripts --- controllers/aici_abi/annotate_asm.js | 103 +++++++++++++++++++++++++++ controllers/aici_abi/disasm.sh | 11 +-- 2 files changed, 110 insertions(+), 4 deletions(-) create mode 100644 controllers/aici_abi/annotate_asm.js diff --git a/controllers/aici_abi/annotate_asm.js b/controllers/aici_abi/annotate_asm.js new file mode 100644 index 00000000..26e1a260 --- /dev/null +++ b/controllers/aici_abi/annotate_asm.js @@ -0,0 +1,103 @@ +const child_process = require("child_process") +const fs = require("fs") + +const sysroot = child_process.execSync("rustc --print sysroot").toString().trim() + +function main(sname, filter) { + if (!filter) { + console.error("please pass filter arg") + return + } + + const sections = {} + const files = [] + let idx = 0 + for (const sect of fs.readFileSync(sname, "utf8").split("\n\n")) { + idx++ + let sectId = "sect" + idx + let m = /^\t\.type\t(.*),@/m.exec(sect) + if (m) { + sectId = m[1] + } + + let outp = "" + for (const line of sect.split("\n")) { + if (line.startsWith(".Ltmp") || line.startsWith("\t.cfi_")) + continue + if (line.startsWith("\t.file\t")) { + m = /(\d+)\s+"([^"]+)"\s+"([^"]+)"/.exec(line) + if (!m) { + // console.error("Bad file line", line) + } else { + const folder = m[2].replace(/^\/rustc\/[^/]+/, sysroot + "/lib/rustlib/src/rust") + files[+m[1]] = folder + "/" + m[3] + } + continue + } + outp += line + "\n" + } + + sections[sectId] = outp + } + + 
const keys = Object.keys(sections).filter(k => k.includes(filter)) + if (keys.length > 1) { + const max = 50 + console.error("Multiple sections found for filter", filter, keys.slice(0, max).join("\n")) + if (keys.length > max) { + console.error("...") + } + return + } + if (keys.length === 0) { + console.error("No sections found for filter", filter) + return + } + + const filecontent = [] + + function fileLines(id) { + if (filecontent[id]) { + return filecontent[id] + } + const lines = fs.readFileSync(files[id], "utf8").split("\n") + filecontent[id] = lines + return lines + } + + let outp = "" + const labels = {} + for (let line of sections[keys[0]].split("\n")) { + if (line.startsWith("\t.loc\t")) { + const m = /\t.loc\t(\d+)\s+(\d+)\s+(\d+)/.exec(line) + const lineno = +m[2] + const lines = fileLines(+m[1]) + const filename = files[+m[1]] + let basename = filename.split("/").pop() + if (filename.startsWith(sysroot)) + basename = "[lib]" + basename + // outp += "// file://" + files[+m[1]] + "\n" + if (lines[lineno - 1] !== undefined) { + const tag = basename + ":" + lineno + outp += "// " + tag.padEnd(40, " ") + lines[lineno - 1] + "\n" + } + } else { + const m = /^(\.L\w+):/.exec(line) + if (m) { + labels[m[1]] = true + } + const words = line.split(/\s+/) + if (words.some(w => labels[w])) { + line += " // ===============================================> BACK" + } + outp += line + "\n" + } + } + + console.log("Section", keys[0], ":") + console.log(outp) +} + + +const args = process.argv.slice(2) +main(...args) diff --git a/controllers/aici_abi/disasm.sh b/controllers/aici_abi/disasm.sh index 1dad5260..41f9fb17 100755 --- a/controllers/aici_abi/disasm.sh +++ b/controllers/aici_abi/disasm.sh @@ -1,7 +1,9 @@ #!/bin/sh -RUSTFLAGS="--emit asm" cargo build --release --target x86_64-unknown-linux-gnu -F=`echo ../../target/x86_64-unknown-linux-gnu/release/deps/aici_abi-*.s` +TRG=`rustup show | head -1 | sed -e 's/.*: //'` +CRATE=`grep "^name =" Cargo.toml | head -1 | sed -e 's/.*= "//; s/"//'` +RUSTFLAGS="--emit asm" cargo build --release --target $TRG +F=`echo ../../target/$TRG/release/deps/$CRATE-*.s` # if $F has more than one file if [ `echo $F | wc -w` -gt 1 ]; then echo "More than one file found: $F; removing; try again" @@ -10,5 +12,6 @@ if [ `echo $F | wc -w` -gt 1 ]; then fi mkdir -p tmp - -rustfilt < $F > tmp/aici_abi.s +cp $F tmp/full.s +node annotate_asm.js tmp/full.s "$@" | rustfilt > tmp/func.s +ls -l tmp/func.s From 4992586d44404bc7bf6a4525503818b005139b1a Mon Sep 17 00:00:00 2001 From: Michal Moskal Date: Tue, 5 Mar 2024 01:14:02 +0000 Subject: [PATCH 161/301] use SimpleVob not Vob --- controllers/aici_abi/src/earley/from_guidance.rs | 5 +++-- controllers/aici_abi/src/earley/grammar.rs | 15 ++++++--------- controllers/aici_abi/src/earley/parser.rs | 2 +- 3 files changed, 10 insertions(+), 12 deletions(-) diff --git a/controllers/aici_abi/src/earley/from_guidance.rs b/controllers/aici_abi/src/earley/from_guidance.rs index 7caaacad..e331ff54 100644 --- a/controllers/aici_abi/src/earley/from_guidance.rs +++ b/controllers/aici_abi/src/earley/from_guidance.rs @@ -140,8 +140,6 @@ pub fn earley_test(trie: TokTrie) { println!("start!"); for _ in 0..NUM_REP { - #[cfg(not(target_arch = "wasm32"))] - let t0 = std::time::Instant::now(); let mut line = 1; let mut vob = trie.alloc_token_set(); @@ -149,6 +147,9 @@ pub fn earley_test(trie: TokTrie) { parser = Parser::new(cfg.compile()); let mut times = vec![]; + #[cfg(not(target_arch = "wasm32"))] + let t0 = std::time::Instant::now(); + for tok in 
&toks { let tok = *tok; let tt = std::time::Instant::now(); diff --git a/controllers/aici_abi/src/earley/grammar.rs b/controllers/aici_abi/src/earley/grammar.rs index 05ff1500..b86f766d 100644 --- a/controllers/aici_abi/src/earley/grammar.rs +++ b/controllers/aici_abi/src/earley/grammar.rs @@ -1,8 +1,9 @@ use std::fmt::Debug; +use crate::svob::SimpleVob; + use super::ByteSet; use rustc_hash::FxHashMap; -use vob::Vob; #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] pub struct SymIdx(u32); @@ -359,7 +360,7 @@ pub struct OptGrammar { terminals: Vec, symbols: Vec, rules: Vec, - terminals_by_byte: Vec, + terminals_by_byte: Vec, } impl OptGrammar { @@ -371,14 +372,10 @@ impl OptGrammar { &mut self.symbols[sym.0 as usize] } - pub fn terminals_by_byte(&self, b: u8) -> &Vob { + pub fn terminals_by_byte(&self, b: u8) -> &SimpleVob { &self.terminals_by_byte[b as usize] } - pub fn terminal_allowed(&self, b: u8, sym: OptSymIdx) -> bool { - self.terminals_by_byte[b as usize].get(sym.0 as usize) == Some(true) - } - pub fn sym_idx_at(&self, idx: RuleIdx) -> OptSymIdx { self.rules[idx.0 as usize] } @@ -480,10 +477,10 @@ impl OptGrammar { } for b in 0..=255 { - let mut v = Vob::from_elem(false, outp.terminals.len()); + let mut v = SimpleVob::alloc(outp.terminals.len()); for (i, bytes) in outp.terminals.iter().enumerate() { if bytes.contains(b as u8) { - v.set(i, true); + v.allow_token(i as u32); } } outp.terminals_by_byte.push(v); diff --git a/controllers/aici_abi/src/earley/parser.rs b/controllers/aici_abi/src/earley/parser.rs index 77f71090..e25fb15a 100644 --- a/controllers/aici_abi/src/earley/parser.rs +++ b/controllers/aici_abi/src/earley/parser.rs @@ -250,11 +250,11 @@ impl Parser { println!("from agenda: {}", self.item_to_string(&item)); } - let lhs = item.sym_idx(); let rule = item.rule_idx(); let after_dot = self.grammar.sym_idx_at(rule); if after_dot == OptSymIdx::NULL { + let lhs = item.sym_idx(); // complete self.is_accepting = self.is_accepting || lhs == self.grammar.start(); From 3e55ff643a4e207cb2f72d7988481c4a7e2c89b7 Mon Sep 17 00:00:00 2001 From: Michal Moskal Date: Tue, 5 Mar 2024 01:21:00 +0000 Subject: [PATCH 162/301] remove duplicate sym_idx from Item --- controllers/aici_abi/src/earley/grammar.rs | 15 +++++++++++++++ controllers/aici_abi/src/earley/parser.rs | 21 ++++++--------------- 2 files changed, 21 insertions(+), 15 deletions(-) diff --git a/controllers/aici_abi/src/earley/grammar.rs b/controllers/aici_abi/src/earley/grammar.rs index b86f766d..1e5d1c9c 100644 --- a/controllers/aici_abi/src/earley/grammar.rs +++ b/controllers/aici_abi/src/earley/grammar.rs @@ -360,10 +360,17 @@ pub struct OptGrammar { terminals: Vec, symbols: Vec, rules: Vec, + rule_idx_to_sym_idx: Vec, terminals_by_byte: Vec, } +const RULE_SHIFT: usize = 2; + impl OptGrammar { + pub fn sym_idx_of(&self, rule: RuleIdx) -> OptSymIdx { + self.rule_idx_to_sym_idx[rule.as_index() >> RULE_SHIFT] + } + pub fn sym_data(&self, sym: OptSymIdx) -> &OptSymbol { &self.symbols[sym.0 as usize] } @@ -404,6 +411,7 @@ impl OptGrammar { rules: vec![], }], rules: vec![], + rule_idx_to_sym_idx: vec![], terminals_by_byte: vec![], }; let mut sym_map = FxHashMap::default(); @@ -449,6 +457,13 @@ impl OptGrammar { } outp.rules.push(OptSymIdx::NULL); } + while outp.rules.len() % (1 << RULE_SHIFT) != 0 { + outp.rules.push(OptSymIdx::NULL); + } + let rlen = outp.rules.len() >> RULE_SHIFT; + while outp.rule_idx_to_sym_idx.len() < rlen { + outp.rule_idx_to_sym_idx.push(idx); + } } loop { diff --git 
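
Patch 161 above swaps the external vob bit-vector for the crate's own SimpleVob in terminals_by_byte: one bitset per input byte, with bit i set iff terminal i can consume that byte. A rough sketch of building such a table, assuming a plain u64-word bitset and range-shaped terminals (both simplifications; the real SimpleVob and ByteSet differ in detail):

    struct BitSet {
        words: Vec<u64>,
    }

    impl BitSet {
        fn alloc(n: usize) -> Self {
            BitSet { words: vec![0; (n + 63) / 64] }
        }
        fn set(&mut self, i: usize) {
            self.words[i / 64] |= 1u64 << (i % 64);
        }
        fn get(&self, i: usize) -> bool {
            self.words[i / 64] & (1u64 << (i % 64)) != 0
        }
    }

    // table[b].get(t) answers: may terminal t consume byte b?
    fn build_terminals_by_byte(terminals: &[std::ops::RangeInclusive<u8>]) -> Vec<BitSet> {
        (0u16..=255)
            .map(|b| {
                let mut v = BitSet::alloc(terminals.len());
                for (i, range) in terminals.iter().enumerate() {
                    if range.contains(&(b as u8)) {
                        v.set(i);
                    }
                }
                v
            })
            .collect()
    }
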
a/controllers/aici_abi/src/earley/parser.rs b/controllers/aici_abi/src/earley/parser.rs index e25fb15a..40ee8616 100644 --- a/controllers/aici_abi/src/earley/parser.rs +++ b/controllers/aici_abi/src/earley/parser.rs @@ -8,7 +8,6 @@ const DEBUG: bool = false; struct Item { rule_idx: RuleIdx, start: u32, - sym_idx: OptSymIdx, } #[derive(Debug, Default)] @@ -39,9 +38,8 @@ impl Row { } impl Item { - fn new(sym: OptSymIdx, rule: RuleIdx, start: usize) -> Self { + fn new(rule: RuleIdx, start: usize) -> Self { Item { - sym_idx: sym, rule_idx: rule, start: start.try_into().unwrap(), } @@ -51,16 +49,12 @@ impl Item { self.rule_idx } - fn sym_idx(&self) -> OptSymIdx { - self.sym_idx - } - fn start_pos(&self) -> usize { self.start as usize } fn advance_dot(&self) -> Self { - Item::new(self.sym_idx, self.rule_idx.advance(), self.start_pos()) + Item::new(self.rule_idx.advance(), self.start_pos()) } } @@ -171,7 +165,7 @@ impl Parser { is_accepting: false, }; for rule in r.grammar.rules_of(start).to_vec() { - r.scratch.add_unique(Item::new(start, rule, 0), "init"); + r.scratch.add_unique(Item::new(rule, 0), "init"); } let _ = r.push_row(); r @@ -184,10 +178,7 @@ impl Parser { fn item_to_string(&self, item: &Item) -> String { // let rule = self.grammar.rule_data(item.rule_idx()); // self.grammar.rule_to_string(rule, item.dot_pos()) - format!( - "item: rule: {:?}, dot: {:?}, start: {}", - item.rule_idx, item.sym_idx, item.start - ) + format!("item: rule: {:?}, start: {}", item.rule_idx, item.start) } // fn row_to_string(&self, row: &Row) -> String { @@ -254,7 +245,7 @@ impl Parser { let after_dot = self.grammar.sym_idx_at(rule); if after_dot == OptSymIdx::NULL { - let lhs = item.sym_idx(); + let lhs = self.grammar.sym_idx_of(item.rule_idx()); // complete self.is_accepting = self.is_accepting || lhs == self.grammar.start(); @@ -275,7 +266,7 @@ impl Parser { if !self.scratch.predicated_syms.contains(after_dot) { self.scratch.predicated_syms.insert(after_dot); for rule in &sym_data.rules { - let new_item = Item::new(after_dot, *rule, curr_idx); + let new_item = Item::new(*rule, curr_idx); self.scratch.add_unique(new_item, "predict"); } } From 41dbee17aa477eb0dc0808dac981f670a5e024a5 Mon Sep 17 00:00:00 2001 From: Michal Moskal Date: Tue, 5 Mar 2024 17:41:03 +0000 Subject: [PATCH 163/301] working on perf measurements --- controllers/aici_abi/Cargo.toml | 1 + controllers/aici_abi/src/bench.rs | 60 +++++++++++++++++++ .../aici_abi/src/earley/from_guidance.rs | 38 ++++++++++-- controllers/aici_abi/src/earley/parser.rs | 1 + controllers/aici_abi/src/lib.rs | 3 + 5 files changed, 98 insertions(+), 5 deletions(-) create mode 100644 controllers/aici_abi/src/bench.rs diff --git a/controllers/aici_abi/Cargo.toml b/controllers/aici_abi/Cargo.toml index ac8cf6eb..c0091116 100644 --- a/controllers/aici_abi/Cargo.toml +++ b/controllers/aici_abi/Cargo.toml @@ -19,6 +19,7 @@ lrtable = { version = "0.13.3", optional = true } vob = { version = "3.0.3", optional = true } rustc-hash = { version = "1.1.0", optional = true } quick-protobuf = { version = "0.8.1", optional = true } +perfcnt = "0.8.0" [features] default = ["cfg", "rx", "earley"] diff --git a/controllers/aici_abi/src/bench.rs b/controllers/aici_abi/src/bench.rs new file mode 100644 index 00000000..599d811d --- /dev/null +++ b/controllers/aici_abi/src/bench.rs @@ -0,0 +1,60 @@ +use perfcnt::{ + linux::{HardwareEventType, PerfCounterBuilderLinux}, + AbstractPerfCounter, +}; + +pub struct BenchmarkState { + pub times: Vec, + cnt: perfcnt::PerfCounter, +} + +impl 
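
Patch 162 above can drop the cached sym_idx from Item because the grammar now answers it from the rule index alone: each symbol's rule storage is padded to a multiple of 1<<RULE_SHIFT slots, so a quarter-size side table maps idx >> RULE_SHIFT to the owning symbol. A sketch with hypothetical names:

    const RULE_SHIFT: usize = 2;

    #[derive(Default)]
    struct PackedRules {
        slots: Vec<u16>, // NULL(0)-terminated rule bodies, flattened
        owner: Vec<u16>, // one entry per 1<<RULE_SHIFT slots: the LHS symbol
    }

    impl PackedRules {
        fn push_sym(&mut self, lhs: u16, bodies: &[&[u16]]) {
            for b in bodies {
                self.slots.extend_from_slice(b);
                self.slots.push(0); // NULL terminator after each body
            }
            while self.slots.len() % (1 << RULE_SHIFT) != 0 {
                self.slots.push(0); // pad so no symbol shares a group
            }
            while self.owner.len() < self.slots.len() >> RULE_SHIFT {
                self.owner.push(lhs);
            }
        }
        fn lhs_of(&self, slot_idx: usize) -> u16 {
            self.owner[slot_idx >> RULE_SHIFT] // O(1), no per-item cache needed
        }
    }

Shrinking Item this way also halves the bytes touched per agenda step, which is where the parser spends its time.
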
BenchmarkState { + pub fn new() -> Self { + let cnt = PerfCounterBuilderLinux::from_hardware_event(HardwareEventType::Instructions) + .finish() + .expect("Failed to create counter"); + + BenchmarkState { times: vec![], cnt } + } + + pub fn measure(&mut self, f: impl FnOnce()) { + let t0 = std::time::Instant::now(); + self.cnt.reset().unwrap(); + self.cnt.start().unwrap(); + f(); + self.cnt.stop().unwrap(); + let res = self.cnt.read().unwrap(); + let _res = t0.elapsed().as_nanos() as u64; + self.times.push(res); + } + + pub fn is_done(&self) -> bool { + self.times.len() >= 20 + } + + pub fn print(&self) { + let avg = self.times.iter().sum::() / self.times.len() as u64; + let min = *self.times.iter().min().unwrap(); + let max = *self.times.iter().max().unwrap(); + let (t10, median, b10) = { + let mut sorted = self.times.clone(); + sorted.sort(); + ( + sorted[sorted.len() / 10], + sorted[sorted.len() / 2], + sorted[sorted.len() * 9 / 10], + ) + }; + let to_m = |x| x as f64 / 1_000_000.0; + // println!("times: {:?}", self.times); + println!( + "Cycles: min:{:.3}-{:.3} med:{:.3} avg:{:.3} max:{:.3}-{:.3}", + to_m(min), + to_m(t10), + to_m(median), + to_m(avg), + to_m(b10), + to_m(max) + ); + } +} diff --git a/controllers/aici_abi/src/earley/from_guidance.rs b/controllers/aici_abi/src/earley/from_guidance.rs index e331ff54..e6014525 100644 --- a/controllers/aici_abi/src/earley/from_guidance.rs +++ b/controllers/aici_abi/src/earley/from_guidance.rs @@ -135,12 +135,16 @@ pub fn earley_test(trie: TokTrie) { } const COLLECT_TIMES: bool = false; - const NUM_REP: usize = if COLLECT_TIMES { 5 } else { 200 }; + const NUM_REP: usize = if COLLECT_TIMES { 5 } else { 1000 }; let mut durations = vec![]; println!("start!"); + let mut min_us = 100000; + let mut dur2 = vec![]; - for _ in 0..NUM_REP { + let mut btest = crate::bench::BenchmarkState::new(); + let max_tidx = 4; + for r in 0..NUM_REP { let mut line = 1; let mut vob = trie.alloc_token_set(); @@ -150,10 +154,21 @@ pub fn earley_test(trie: TokTrie) { #[cfg(not(target_arch = "wasm32"))] let t0 = std::time::Instant::now(); - for tok in &toks { + for (tidx, tok) in toks.iter().enumerate() { let tok = *tok; let tt = std::time::Instant::now(); - trie.compute_bias(&mut parser, &mut vob); + if tidx == max_tidx { + btest.measure(|| { + trie.compute_bias(&mut parser, &mut vob); + }); + } else { + trie.compute_bias(&mut parser, &mut vob); + } + + if r > 0 && tidx > max_tidx { + break; + } + // parser.print_stats(); if !vob.is_allowed(tok) { println!("reject, line={}, tok={:?}", line, trie.token_str(tok)); @@ -170,6 +185,13 @@ pub fn earley_test(trie: TokTrie) { // trie.token_set_dbg(&vob) // ); trie.append_token(&mut parser, tok); + let tm = tt.elapsed().as_micros() as u32; + if tm > 1000 { + dur2.push(tm); + } + if tm > 1000 && tm < min_us { + min_us = tm; + } if COLLECT_TIMES { times.push(tt.elapsed().as_micros() as u32); } @@ -182,15 +204,21 @@ pub fn earley_test(trie: TokTrie) { } } + btest.print(); + let mut total = Duration::ZERO; for d in &durations { total += *d; } + dur2.sort(); + println!( - "time: {:?} - {:?} - {:?}", + "time: {:?} - {:?} - {:?} {} {}", durations.iter().min().unwrap(), total / durations.len() as u32, durations.iter().max().unwrap(), + min_us, + dur2[dur2.len() / 2] ); } diff --git a/controllers/aici_abi/src/earley/parser.rs b/controllers/aici_abi/src/earley/parser.rs index 40ee8616..8000685c 100644 --- a/controllers/aici_abi/src/earley/parser.rs +++ b/controllers/aici_abi/src/earley/parser.rs @@ -4,6 +4,7 @@ use 
super::grammar::{OptGrammar, OptSymIdx, RuleIdx, SimpleHash}; const DEBUG: bool = false; +//#[repr(align(8))] #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] struct Item { rule_idx: RuleIdx, diff --git a/controllers/aici_abi/src/lib.rs b/controllers/aici_abi/src/lib.rs index 9307df1a..789699fb 100644 --- a/controllers/aici_abi/src/lib.rs +++ b/controllers/aici_abi/src/lib.rs @@ -19,6 +19,9 @@ pub mod rx; #[cfg(feature = "earley")] pub mod earley; +#[cfg(target_os = "linux")] +pub mod bench; + pub mod substring; pub type TokenId = bytes::TokenId; From 37d48088b48580e94c15d8aa00442854384db8c8 Mon Sep 17 00:00:00 2001 From: Michal Moskal Date: Tue, 5 Mar 2024 17:41:10 +0000 Subject: [PATCH 164/301] Revert "working on perf measurements" This reverts commit 41dbee17aa477eb0dc0808dac981f670a5e024a5. --- controllers/aici_abi/Cargo.toml | 1 - controllers/aici_abi/src/bench.rs | 60 ------------------- .../aici_abi/src/earley/from_guidance.rs | 38 ++---------- controllers/aici_abi/src/earley/parser.rs | 1 - controllers/aici_abi/src/lib.rs | 3 - 5 files changed, 5 insertions(+), 98 deletions(-) delete mode 100644 controllers/aici_abi/src/bench.rs diff --git a/controllers/aici_abi/Cargo.toml b/controllers/aici_abi/Cargo.toml index c0091116..ac8cf6eb 100644 --- a/controllers/aici_abi/Cargo.toml +++ b/controllers/aici_abi/Cargo.toml @@ -19,7 +19,6 @@ lrtable = { version = "0.13.3", optional = true } vob = { version = "3.0.3", optional = true } rustc-hash = { version = "1.1.0", optional = true } quick-protobuf = { version = "0.8.1", optional = true } -perfcnt = "0.8.0" [features] default = ["cfg", "rx", "earley"] diff --git a/controllers/aici_abi/src/bench.rs b/controllers/aici_abi/src/bench.rs deleted file mode 100644 index 599d811d..00000000 --- a/controllers/aici_abi/src/bench.rs +++ /dev/null @@ -1,60 +0,0 @@ -use perfcnt::{ - linux::{HardwareEventType, PerfCounterBuilderLinux}, - AbstractPerfCounter, -}; - -pub struct BenchmarkState { - pub times: Vec, - cnt: perfcnt::PerfCounter, -} - -impl BenchmarkState { - pub fn new() -> Self { - let cnt = PerfCounterBuilderLinux::from_hardware_event(HardwareEventType::Instructions) - .finish() - .expect("Failed to create counter"); - - BenchmarkState { times: vec![], cnt } - } - - pub fn measure(&mut self, f: impl FnOnce()) { - let t0 = std::time::Instant::now(); - self.cnt.reset().unwrap(); - self.cnt.start().unwrap(); - f(); - self.cnt.stop().unwrap(); - let res = self.cnt.read().unwrap(); - let _res = t0.elapsed().as_nanos() as u64; - self.times.push(res); - } - - pub fn is_done(&self) -> bool { - self.times.len() >= 20 - } - - pub fn print(&self) { - let avg = self.times.iter().sum::() / self.times.len() as u64; - let min = *self.times.iter().min().unwrap(); - let max = *self.times.iter().max().unwrap(); - let (t10, median, b10) = { - let mut sorted = self.times.clone(); - sorted.sort(); - ( - sorted[sorted.len() / 10], - sorted[sorted.len() / 2], - sorted[sorted.len() * 9 / 10], - ) - }; - let to_m = |x| x as f64 / 1_000_000.0; - // println!("times: {:?}", self.times); - println!( - "Cycles: min:{:.3}-{:.3} med:{:.3} avg:{:.3} max:{:.3}-{:.3}", - to_m(min), - to_m(t10), - to_m(median), - to_m(avg), - to_m(b10), - to_m(max) - ); - } -} diff --git a/controllers/aici_abi/src/earley/from_guidance.rs b/controllers/aici_abi/src/earley/from_guidance.rs index e6014525..e331ff54 100644 --- a/controllers/aici_abi/src/earley/from_guidance.rs +++ b/controllers/aici_abi/src/earley/from_guidance.rs @@ -135,16 +135,12 @@ pub fn earley_test(trie: TokTrie) { } const 
COLLECT_TIMES: bool = false; - const NUM_REP: usize = if COLLECT_TIMES { 5 } else { 1000 }; + const NUM_REP: usize = if COLLECT_TIMES { 5 } else { 200 }; let mut durations = vec![]; println!("start!"); - let mut min_us = 100000; - let mut dur2 = vec![]; - let mut btest = crate::bench::BenchmarkState::new(); - let max_tidx = 4; + for _ in 0..NUM_REP { - for r in 0..NUM_REP { let mut line = 1; let mut vob = trie.alloc_token_set(); @@ -154,21 +150,10 @@ pub fn earley_test(trie: TokTrie) { #[cfg(not(target_arch = "wasm32"))] let t0 = std::time::Instant::now(); - for (tidx, tok) in toks.iter().enumerate() { + for tok in &toks { let tok = *tok; let tt = std::time::Instant::now(); - if tidx == max_tidx { - btest.measure(|| { - trie.compute_bias(&mut parser, &mut vob); - }); - } else { - trie.compute_bias(&mut parser, &mut vob); - } - - if r > 0 && tidx > max_tidx { - break; - } - + trie.compute_bias(&mut parser, &mut vob); // parser.print_stats(); if !vob.is_allowed(tok) { println!("reject, line={}, tok={:?}", line, trie.token_str(tok)); @@ -185,13 +170,6 @@ pub fn earley_test(trie: TokTrie) { // trie.token_set_dbg(&vob) // ); trie.append_token(&mut parser, tok); - let tm = tt.elapsed().as_micros() as u32; - if tm > 1000 { - dur2.push(tm); - } - if tm > 1000 && tm < min_us { - min_us = tm; - } if COLLECT_TIMES { times.push(tt.elapsed().as_micros() as u32); } @@ -204,21 +182,15 @@ pub fn earley_test(trie: TokTrie) { } } - btest.print(); - let mut total = Duration::ZERO; for d in &durations { total += *d; } - dur2.sort(); - println!( - "time: {:?} - {:?} - {:?} {} {}", + "time: {:?} - {:?} - {:?}", durations.iter().min().unwrap(), total / durations.len() as u32, durations.iter().max().unwrap(), - min_us, - dur2[dur2.len() / 2] ); } diff --git a/controllers/aici_abi/src/earley/parser.rs b/controllers/aici_abi/src/earley/parser.rs index 8000685c..40ee8616 100644 --- a/controllers/aici_abi/src/earley/parser.rs +++ b/controllers/aici_abi/src/earley/parser.rs @@ -4,7 +4,6 @@ use super::grammar::{OptGrammar, OptSymIdx, RuleIdx, SimpleHash}; const DEBUG: bool = false; -//#[repr(align(8))] #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] struct Item { rule_idx: RuleIdx, diff --git a/controllers/aici_abi/src/lib.rs b/controllers/aici_abi/src/lib.rs index 789699fb..9307df1a 100644 --- a/controllers/aici_abi/src/lib.rs +++ b/controllers/aici_abi/src/lib.rs @@ -19,9 +19,6 @@ pub mod rx; #[cfg(feature = "earley")] pub mod earley; -#[cfg(target_os = "linux")] -pub mod bench; - pub mod substring; pub type TokenId = bytes::TokenId; From fe84229d360b825086bf3a08916a283c4819c110 Mon Sep 17 00:00:00 2001 From: Michal Moskal Date: Tue, 5 Mar 2024 19:37:08 +0000 Subject: [PATCH 165/301] more stats --- .../aici_abi/src/earley/from_guidance.rs | 59 ++++++++++++++----- 1 file changed, 45 insertions(+), 14 deletions(-) diff --git a/controllers/aici_abi/src/earley/from_guidance.rs b/controllers/aici_abi/src/earley/from_guidance.rs index e331ff54..49d2ede1 100644 --- a/controllers/aici_abi/src/earley/from_guidance.rs +++ b/controllers/aici_abi/src/earley/from_guidance.rs @@ -1,5 +1,3 @@ -use std::time::Duration; - use anyhow::Result; use quick_protobuf::MessageRead; use rustc_hash::FxHashSet; @@ -135,12 +133,14 @@ pub fn earley_test(trie: TokTrie) { } const COLLECT_TIMES: bool = false; - const NUM_REP: usize = if COLLECT_TIMES { 5 } else { 200 }; + const NUM_REP: usize = if COLLECT_TIMES { 5 } else { 500 }; let mut durations = vec![]; + let mut durations_us = vec![]; println!("start!"); - for _ in 0..NUM_REP { + let 
num_tok = 4; + for _ in 0..NUM_REP { let mut line = 1; let mut vob = trie.alloc_token_set(); @@ -150,10 +150,13 @@ pub fn earley_test(trie: TokTrie) { #[cfg(not(target_arch = "wasm32"))] let t0 = std::time::Instant::now(); - for tok in &toks { + for (idx, tok) in toks.iter().take(num_tok).enumerate() { let tok = *tok; let tt = std::time::Instant::now(); trie.compute_bias(&mut parser, &mut vob); + if idx == num_tok - 1 { + durations_us.push(tt.elapsed().as_micros() as u64); + } // parser.print_stats(); if !vob.is_allowed(tok) { println!("reject, line={}, tok={:?}", line, trie.token_str(tok)); @@ -175,22 +178,50 @@ pub fn earley_test(trie: TokTrie) { } } - durations.push(t0.elapsed()); + durations.push(t0.elapsed().as_micros() as u64); if COLLECT_TIMES { println!("times: {:?}", times); } } - let mut total = Duration::ZERO; - for d in &durations { - total += *d; - } + durations.sort(); + durations_us.sort(); println!( - "time: {:?} - {:?} - {:?}", - durations.iter().min().unwrap(), - total / durations.len() as u32, - durations.iter().max().unwrap(), + "time: {},{}", + vec_stats(&durations), + vec_stats(&durations_us), ); + + println!( + "time_us: {:?},{:?},{:?}", + durations_us.iter().min().unwrap(), + durations_us[durations_us.len() / 2], + durations_us.iter().max().unwrap(), + ); +} + +fn vec_stats(times: &[u64]) -> String { + let mut times = times.to_vec(); + times.sort(); + let sum0 = times.iter().sum::(); + let drop = times.len() / 10; + let len2 = times.len() - 2 * drop; + let sum1 = times[drop..times.len() - drop].iter().sum::(); + // t0,t10,t50,t90,t100,avg,avg90 + let stats = vec![ + times[0], + times[drop], + times[times.len() / 2], + times[times.len() - drop], + times[times.len() - 1], + sum0 / times.len() as u64, + sum1 / len2 as u64, + ]; + stats + .iter() + .map(|x| x.to_string()) + .collect::>() + .join(",") } From d57ff14b3b26889e64c277f4a807804a39ca92f4 Mon Sep 17 00:00:00 2001 From: Michal Moskal Date: Tue, 5 Mar 2024 19:53:16 +0000 Subject: [PATCH 166/301] minor speedup --- .../aici_abi/src/earley/from_guidance.rs | 40 ++----------------- controllers/aici_abi/src/earley/grammar.rs | 4 ++ controllers/aici_abi/src/earley/parser.rs | 24 ++++++----- 3 files changed, 22 insertions(+), 46 deletions(-) diff --git a/controllers/aici_abi/src/earley/from_guidance.rs b/controllers/aici_abi/src/earley/from_guidance.rs index 49d2ede1..94c6f1c2 100644 --- a/controllers/aici_abi/src/earley/from_guidance.rs +++ b/controllers/aici_abi/src/earley/from_guidance.rs @@ -188,40 +188,8 @@ pub fn earley_test(trie: TokTrie) { durations.sort(); durations_us.sort(); - println!( - "time: {},{}", - vec_stats(&durations), - vec_stats(&durations_us), - ); - - println!( - "time_us: {:?},{:?},{:?}", - durations_us.iter().min().unwrap(), - durations_us[durations_us.len() / 2], - durations_us.iter().max().unwrap(), - ); -} - -fn vec_stats(times: &[u64]) -> String { - let mut times = times.to_vec(); - times.sort(); - let sum0 = times.iter().sum::(); - let drop = times.len() / 10; - let len2 = times.len() - 2 * drop; - let sum1 = times[drop..times.len() - drop].iter().sum::(); - // t0,t10,t50,t90,t100,avg,avg90 - let stats = vec![ - times[0], - times[drop], - times[times.len() / 2], - times[times.len() - drop], - times[times.len() - 1], - sum0 / times.len() as u64, - sum1 / len2 as u64, - ]; - stats - .iter() - .map(|x| x.to_string()) - .collect::>() - .join(",") + let min_us = *durations_us.iter().min().unwrap(); + // println!("min_time_us: {:?}", min_us); + // for ~5ms 0.1ms is the precision we expect + 
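
The reporting change above keeps only the minimum across runs, per its comment that 0.1 ms precision suffices at the ~5 ms scale: for a deterministic workload the minimum is the most noise-robust point estimate. A sketch of that reporting as a hypothetical helper:

    fn min_time_ms(durations_us: &[u64]) -> f64 {
        let min_us = *durations_us.iter().min().unwrap();
        (min_us as f64 / 1000.0 * 10.0).round() / 10.0 // 0.1 ms resolution
    }
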
println!("min_time_ms: {:.1}", min_us as f64 / 1000.0); } diff --git a/controllers/aici_abi/src/earley/grammar.rs b/controllers/aici_abi/src/earley/grammar.rs index 1e5d1c9c..7b968710 100644 --- a/controllers/aici_abi/src/earley/grammar.rs +++ b/controllers/aici_abi/src/earley/grammar.rs @@ -338,6 +338,10 @@ impl SimpleHash for OptSymIdx { pub struct RuleIdx(u32); impl RuleIdx { + pub fn from_index(idx: u32) -> Self { + RuleIdx(idx) + } + pub fn advance(&self) -> RuleIdx { RuleIdx(self.0 + 1) } diff --git a/controllers/aici_abi/src/earley/parser.rs b/controllers/aici_abi/src/earley/parser.rs index 40ee8616..431d977d 100644 --- a/controllers/aici_abi/src/earley/parser.rs +++ b/controllers/aici_abi/src/earley/parser.rs @@ -6,8 +6,7 @@ const DEBUG: bool = false; #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] struct Item { - rule_idx: RuleIdx, - start: u32, + data: u64, } #[derive(Debug, Default)] @@ -40,29 +39,30 @@ impl Row { impl Item { fn new(rule: RuleIdx, start: usize) -> Self { Item { - rule_idx: rule, - start: start.try_into().unwrap(), + data: rule.as_index() as u64 | ((start as u64) << 32), } } fn rule_idx(&self) -> RuleIdx { - self.rule_idx + RuleIdx::from_index(self.data as u32) } fn start_pos(&self) -> usize { - self.start as usize + (self.data >> 32) as usize } fn advance_dot(&self) -> Self { - Item::new(self.rule_idx.advance(), self.start_pos()) + Item { + data: self.data + 1, + } } } impl SimpleHash for Item { fn simple_hash(&self) -> u32 { - (self.rule_idx.as_index() as u32) + (self.rule_idx().as_index() as u32) .wrapping_mul(16315967) - .wrapping_add((self.start as u32).wrapping_mul(33398653)) + .wrapping_add((self.start_pos() as u32).wrapping_mul(33398653)) } } @@ -178,7 +178,11 @@ impl Parser { fn item_to_string(&self, item: &Item) -> String { // let rule = self.grammar.rule_data(item.rule_idx()); // self.grammar.rule_to_string(rule, item.dot_pos()) - format!("item: rule: {:?}, start: {}", item.rule_idx, item.start) + format!( + "item: rule: {:?}, start: {}", + item.rule_idx(), + item.start_pos() + ) } // fn row_to_string(&self, row: &Row) -> String { From 4048e4f4c93a92760c2f20aead53e5c55ac99127 Mon Sep 17 00:00:00 2001 From: Michal Moskal Date: Tue, 5 Mar 2024 21:43:04 +0000 Subject: [PATCH 167/301] clone optgrammar --- controllers/aici_abi/src/earley/from_guidance.rs | 6 ++++-- controllers/aici_abi/src/earley/grammar.rs | 2 ++ 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/controllers/aici_abi/src/earley/from_guidance.rs b/controllers/aici_abi/src/earley/from_guidance.rs index 94c6f1c2..0a10e7a8 100644 --- a/controllers/aici_abi/src/earley/from_guidance.rs +++ b/controllers/aici_abi/src/earley/from_guidance.rs @@ -119,7 +119,9 @@ pub fn earley_test(trie: TokTrie) { let toks = trie.greedy_tokenize(input); println!("toks: {:?}", toks.len()); - let mut parser = Parser::new(cfg.compile()); + let grm = cfg.compile(); + + let mut parser = Parser::new(grm.clone()); let mut last_res = ParseResult::Reject; for b in input { last_res = parser.scan(*b); @@ -144,7 +146,7 @@ pub fn earley_test(trie: TokTrie) { let mut line = 1; let mut vob = trie.alloc_token_set(); - parser = Parser::new(cfg.compile()); + parser = Parser::new(grm.clone()); let mut times = vec![]; #[cfg(not(target_arch = "wasm32"))] diff --git a/controllers/aici_abi/src/earley/grammar.rs b/controllers/aici_abi/src/earley/grammar.rs index 7b968710..0da7cf5f 100644 --- a/controllers/aici_abi/src/earley/grammar.rs +++ b/controllers/aici_abi/src/earley/grammar.rs @@ -351,6 +351,7 @@ impl RuleIdx { } } 
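
Patch 166 above also packs Item into a single u64: rule index in the low half, start position in the high half, so advance_dot() becomes one add. A distilled version (`PackedItem` is an illustrative stand-in; it assumes, as the real layout does, that bumping the rule index never carries into the high half):

    #[derive(Clone, Copy, PartialEq, Eq)]
    struct PackedItem {
        data: u64, // low 32 bits: rule index; high 32 bits: start position
    }

    impl PackedItem {
        fn new(rule: u32, start: usize) -> Self {
            PackedItem { data: rule as u64 | ((start as u64) << 32) }
        }
        fn rule(&self) -> u32 {
            self.data as u32
        }
        fn start_pos(&self) -> usize {
            (self.data >> 32) as usize
        }
        fn advance_dot(&self) -> Self {
            PackedItem { data: self.data + 1 } // bumps only the rule half
        }
    }
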
+#[derive(Clone)] pub struct OptSymbol { pub idx: OptSymIdx, pub name: String, @@ -359,6 +360,7 @@ pub struct OptSymbol { pub rules: Vec, } +#[derive(Clone)] pub struct OptGrammar { start_symbol: OptSymIdx, terminals: Vec, From 62c2ab8296553ffd9eb6e021621dd1bac31542d9 Mon Sep 17 00:00:00 2001 From: Michal Moskal Date: Tue, 5 Mar 2024 21:56:31 +0000 Subject: [PATCH 168/301] clean up file structure --- controllers/aici_abi/Cargo.toml | 6 ++- .../src/earley/{from_guidance.rs => bench.rs} | 41 ++----------------- controllers/aici_abi/src/earley/mod.rs | 8 ++-- controllers/aici_abi/src/earley/parser.rs | 33 +++++++++++++++ 4 files changed, 45 insertions(+), 43 deletions(-) rename controllers/aici_abi/src/earley/{from_guidance.rs => bench.rs} (86%) diff --git a/controllers/aici_abi/Cargo.toml b/controllers/aici_abi/Cargo.toml index ac8cf6eb..0acd2a3c 100644 --- a/controllers/aici_abi/Cargo.toml +++ b/controllers/aici_abi/Cargo.toml @@ -18,13 +18,15 @@ lrpar = { version = "0.13.3", optional = true } lrtable = { version = "0.13.3", optional = true } vob = { version = "3.0.3", optional = true } rustc-hash = { version = "1.1.0", optional = true } -quick-protobuf = { version = "0.8.1", optional = true } + +[target.'cfg(not(target_arch = "wasm32"))'.dependencies] +quick-protobuf = { version = "0.8.1" } [features] default = ["cfg", "rx", "earley"] cfg = ["dep:cfgrammar", "dep:lrlex", "dep:lrpar", "dep:lrtable", "dep:vob", "dep:rustc-hash"] rx = ["dep:regex-automata"] -earley = ["rx", "dep:vob", "dep:rustc-hash", "dep:quick-protobuf"] +earley = ["rx", "dep:rustc-hash"] [[bin]] name = "yesno" diff --git a/controllers/aici_abi/src/earley/from_guidance.rs b/controllers/aici_abi/src/earley/bench.rs similarity index 86% rename from controllers/aici_abi/src/earley/from_guidance.rs rename to controllers/aici_abi/src/earley/bench.rs index 0a10e7a8..b2937a28 100644 --- a/controllers/aici_abi/src/earley/from_guidance.rs +++ b/controllers/aici_abi/src/earley/bench.rs @@ -3,10 +3,7 @@ use quick_protobuf::MessageRead; use rustc_hash::FxHashSet; use super::{guidance, ByteSet, Grammar, Parser}; -use crate::{ - earley::parser::ParseResult, - toktree::{Recognizer, SpecialToken, TokTrie}, -}; +use crate::earley::parser::ParseResult; pub fn earley_grm_from_guidance(bytes: &[u8]) -> Result { let mut reader = quick_protobuf::BytesReader::from_bytes(bytes); @@ -75,39 +72,7 @@ pub fn earley_grm_from_guidance(bytes: &[u8]) -> Result { Ok(grm) } -impl Recognizer for Parser { - fn pop_bytes(&mut self, num: usize) { - self.pop_rows(num); - } - - fn collapse(&mut self) { - // does nothing - we need to keep the entire state - } - - fn special_allowed(&mut self, tok: SpecialToken) -> bool { - if tok == SpecialToken::EndOfSentence { - self.is_accepting() - } else { - false - } - } - - fn trie_finished(&mut self) { - // do nothing? 
- } - - fn try_push_byte(&mut self, byte: u8) -> bool { - let res = self.scan(byte); - if res == ParseResult::Reject { - false - } else { - true - } - } -} - -#[allow(dead_code)] -pub fn earley_test(trie: TokTrie) { +pub fn earley_test(trie: crate::toktree::TokTrie) { let g_bytes = include_bytes!("../../grammars/json0.guidance"); let cfg = earley_grm_from_guidance(g_bytes).unwrap(); // println!("cfg0: {:?}", cfg); @@ -117,7 +82,7 @@ pub fn earley_test(trie: TokTrie) { let input = r#"{"name":"Joe","info":{"foo":10,"bar":"20"}}"#.as_bytes(); let toks = trie.greedy_tokenize(input); - println!("toks: {:?}", toks.len()); + println!("tokens: {:?}", toks.len()); let grm = cfg.compile(); diff --git a/controllers/aici_abi/src/earley/mod.rs b/controllers/aici_abi/src/earley/mod.rs index 65c7de23..105c4d32 100644 --- a/controllers/aici_abi/src/earley/mod.rs +++ b/controllers/aici_abi/src/earley/mod.rs @@ -1,10 +1,12 @@ mod byteset; -mod from_guidance; mod grammar; -mod guidance; mod parser; pub use byteset::ByteSet; -pub use from_guidance::earley_test; pub use parser::Parser; pub use grammar::Grammar; + +#[cfg(not(target_arch = "wasm32"))] +mod guidance; +#[cfg(not(target_arch = "wasm32"))] +pub mod bench; diff --git a/controllers/aici_abi/src/earley/parser.rs b/controllers/aici_abi/src/earley/parser.rs index 431d977d..267bbc92 100644 --- a/controllers/aici_abi/src/earley/parser.rs +++ b/controllers/aici_abi/src/earley/parser.rs @@ -1,5 +1,7 @@ use std::{fmt::Debug, hash::Hash, ops::Range, vec}; +use crate::toktree::{Recognizer, SpecialToken}; + use super::grammar::{OptGrammar, OptSymIdx, RuleIdx, SimpleHash}; const DEBUG: bool = false; @@ -297,3 +299,34 @@ impl Parser { } } } + +impl Recognizer for Parser { + fn pop_bytes(&mut self, num: usize) { + self.pop_rows(num); + } + + fn collapse(&mut self) { + // does nothing - we need to keep the entire state + } + + fn special_allowed(&mut self, tok: SpecialToken) -> bool { + if tok == SpecialToken::EndOfSentence { + self.is_accepting() + } else { + false + } + } + + fn trie_finished(&mut self) { + // do nothing? 
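
Patch 168 above moves the Recognizer impl next to the parser: the token-trie walk drives scan() byte by byte, and pop_bytes() rewinds rows pushed speculatively during the walk. A compressed sketch of that bridge, with trait and types simplified to stand-ins:

    enum Scan {
        Accept,
        Reject,
        Continue,
    }

    struct MiniParser {
        rows: usize,
    }

    impl MiniParser {
        fn scan(&mut self, _b: u8) -> Scan {
            self.rows += 1; // one Earley row per accepted byte
            Scan::Continue
        }
        // the trie walk calls these through the Recognizer trait
        fn try_push_byte(&mut self, b: u8) -> bool {
            !matches!(self.scan(b), Scan::Reject)
        }
        fn pop_bytes(&mut self, n: usize) {
            self.rows -= n; // rewind the speculative rows
        }
    }
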
+ } + + fn try_push_byte(&mut self, byte: u8) -> bool { + let res = self.scan(byte); + if res == ParseResult::Reject { + false + } else { + true + } + } +} From 0f16a82801ede89df4f2f6a92ab4b98916050980 Mon Sep 17 00:00:00 2001 From: Michal Moskal Date: Wed, 6 Mar 2024 16:40:11 +0000 Subject: [PATCH 169/301] byteset by ref --- controllers/aici_abi/src/earley/bench.rs | 4 ++-- controllers/aici_abi/src/earley/byteset.rs | 19 ++++++++++++++++++- controllers/aici_abi/src/earley/grammar.rs | 10 +++++----- 3 files changed, 25 insertions(+), 8 deletions(-) diff --git a/controllers/aici_abi/src/earley/bench.rs b/controllers/aici_abi/src/earley/bench.rs index b2937a28..b741869f 100644 --- a/controllers/aici_abi/src/earley/bench.rs +++ b/controllers/aici_abi/src/earley/bench.rs @@ -20,11 +20,11 @@ pub fn earley_grm_from_guidance(bytes: &[u8]) -> Result { } guidance::mod_GrammarFunction::OneOffunction_type::byte(n) => { assert!(n.byte.len() == 1); - grm.terminal(ByteSet::from_range(n.byte[0], n.byte[0])) + grm.terminal(&ByteSet::from_range(n.byte[0], n.byte[0])) } guidance::mod_GrammarFunction::OneOffunction_type::byte_range(n) => { assert!(n.byte_range.len() == 2); - grm.terminal(ByteSet::from_range(n.byte_range[0], n.byte_range[1])) + grm.terminal(&ByteSet::from_range(n.byte_range[0], n.byte_range[1])) } guidance::mod_GrammarFunction::OneOffunction_type::model_variable(n) => { grm.fresh_symbol(&n.name) diff --git a/controllers/aici_abi/src/earley/byteset.rs b/controllers/aici_abi/src/earley/byteset.rs index 414df548..dc578e1a 100644 --- a/controllers/aici_abi/src/earley/byteset.rs +++ b/controllers/aici_abi/src/earley/byteset.rs @@ -1,4 +1,4 @@ -use std::fmt::Display; +use std::fmt::{Debug, Display}; const BYTESET_LEN: usize = 8; @@ -7,6 +7,23 @@ pub struct ByteSet { mask: [u32; BYTESET_LEN], } +impl Debug for ByteSet { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "[")?; + let mut first = true; + for i in 0u32..=256 { + if i <= 0xff && self.contains(i as u8) { + if !first { + write!(f, ", ")?; + } + first = false; + write!(f, "{}", i)?; + } + } + write!(f, "]") + } +} + pub fn byte_to_string(b: u8) -> String { if b >= 0x7f { format!("x{:02x}", b) diff --git a/controllers/aici_abi/src/earley/grammar.rs b/controllers/aici_abi/src/earley/grammar.rs index 0da7cf5f..305c1f38 100644 --- a/controllers/aici_abi/src/earley/grammar.rs +++ b/controllers/aici_abi/src/earley/grammar.rs @@ -67,8 +67,8 @@ impl Grammar { sym.rules.push(Rule { lhs, rhs }); } - pub fn terminal(&mut self, bytes: ByteSet) -> SymIdx { - match self.terminals.get(&bytes) { + pub fn terminal(&mut self, bytes: &ByteSet) -> SymIdx { + match self.terminals.get(bytes) { Some(sym) => *sym, None => { let mut name = format!("T:{}", bytes); @@ -77,7 +77,7 @@ impl Grammar { } let sym = self.fresh_symbol(&name); self.sym_data_mut(sym).bytes = Some(bytes.clone()); - self.terminals.insert(bytes, sym); + self.terminals.insert(bytes.clone(), sym); sym } } @@ -114,7 +114,7 @@ impl Grammar { fn copy_from(&mut self, other: &Grammar, sym: SymIdx) -> SymIdx { let sym_data = other.sym_data(sym); if sym_data.is_terminal() { - self.terminal(sym_data.bytes.clone().unwrap()) + self.terminal(sym_data.bytes.as_ref().unwrap()) } else { self.symbol(&sym_data.name) } @@ -156,7 +156,7 @@ impl Grammar { let terminals = rules .iter() .map(|r| self.sym_data(r.rhs[i]).bytes.clone().unwrap()); - outp.terminal(ByteSet::from_sum(terminals)) + outp.terminal(&ByteSet::from_sum(terminals)) } else { outp.copy_from(self, *s) } From 
832107c6346bebdb518753204a4f9b571ab14fa2 Mon Sep 17 00:00:00 2001 From: Michal Moskal Date: Wed, 6 Mar 2024 16:52:35 +0000 Subject: [PATCH 170/301] rename things --- controllers/aici_abi/src/earley/grammar.rs | 74 +++++++++++----------- controllers/aici_abi/src/earley/parser.rs | 10 +-- 2 files changed, 43 insertions(+), 41 deletions(-) diff --git a/controllers/aici_abi/src/earley/grammar.rs b/controllers/aici_abi/src/earley/grammar.rs index 305c1f38..0d73b524 100644 --- a/controllers/aici_abi/src/earley/grammar.rs +++ b/controllers/aici_abi/src/earley/grammar.rs @@ -241,8 +241,8 @@ impl Grammar { .expand_shortcuts() } - pub fn compile(&self) -> OptGrammar { - OptGrammar::from_grammar(self) + pub fn compile(&self) -> CGrammar { + CGrammar::from_grammar(self) } pub fn fresh_symbol(&mut self, name0: &str) -> SymIdx { @@ -306,10 +306,10 @@ impl Debug for Grammar { } #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] -pub struct OptSymIdx(u16); +pub struct CSymIdx(u16); -impl OptSymIdx { - pub const NULL: OptSymIdx = OptSymIdx(0); +impl CSymIdx { + pub const NULL: CSymIdx = CSymIdx(0); pub fn as_index(&self) -> usize { self.0 as usize @@ -328,7 +328,7 @@ pub trait SimpleHash { } } -impl SimpleHash for OptSymIdx { +impl SimpleHash for CSymIdx { fn simple_hash(&self) -> u32 { (self.0 as u32).wrapping_mul(79667123) } @@ -338,6 +338,8 @@ impl SimpleHash for OptSymIdx { pub struct RuleIdx(u32); impl RuleIdx { + pub const NULL: RuleIdx = RuleIdx(0); + pub fn from_index(idx: u32) -> Self { RuleIdx(idx) } @@ -352,8 +354,8 @@ impl RuleIdx { } #[derive(Clone)] -pub struct OptSymbol { - pub idx: OptSymIdx, +pub struct CSymbol { + pub idx: CSymIdx, pub name: String, pub is_terminal: bool, pub is_nullable: bool, @@ -361,27 +363,27 @@ pub struct OptSymbol { } #[derive(Clone)] -pub struct OptGrammar { - start_symbol: OptSymIdx, +pub struct CGrammar { + start_symbol: CSymIdx, terminals: Vec, - symbols: Vec, - rules: Vec, - rule_idx_to_sym_idx: Vec, + symbols: Vec, + rules: Vec, + rule_idx_to_sym_idx: Vec, terminals_by_byte: Vec, } const RULE_SHIFT: usize = 2; -impl OptGrammar { - pub fn sym_idx_of(&self, rule: RuleIdx) -> OptSymIdx { +impl CGrammar { + pub fn sym_idx_of(&self, rule: RuleIdx) -> CSymIdx { self.rule_idx_to_sym_idx[rule.as_index() >> RULE_SHIFT] } - pub fn sym_data(&self, sym: OptSymIdx) -> &OptSymbol { + pub fn sym_data(&self, sym: CSymIdx) -> &CSymbol { &self.symbols[sym.0 as usize] } - fn sym_data_mut(&mut self, sym: OptSymIdx) -> &mut OptSymbol { + fn sym_data_mut(&mut self, sym: CSymIdx) -> &mut CSymbol { &mut self.symbols[sym.0 as usize] } @@ -389,34 +391,34 @@ impl OptGrammar { &self.terminals_by_byte[b as usize] } - pub fn sym_idx_at(&self, idx: RuleIdx) -> OptSymIdx { + pub fn sym_idx_at(&self, idx: RuleIdx) -> CSymIdx { self.rules[idx.0 as usize] } - pub fn start(&self) -> OptSymIdx { + pub fn start(&self) -> CSymIdx { self.start_symbol } - pub fn is_accepting(&self, sym: OptSymIdx, rule: RuleIdx) -> bool { - sym == self.start() && self.sym_idx_at(rule) == OptSymIdx::NULL + pub fn is_accepting(&self, sym: CSymIdx, rule: RuleIdx) -> bool { + sym == self.start() && self.sym_idx_at(rule) == CSymIdx::NULL } - pub fn rules_of(&self, sym: OptSymIdx) -> &[RuleIdx] { + pub fn rules_of(&self, sym: CSymIdx) -> &[RuleIdx] { &self.sym_data(sym).rules } fn from_grammar(grammar: &Grammar) -> Self { - let mut outp = OptGrammar { - start_symbol: OptSymIdx::NULL, + let mut outp = CGrammar { + start_symbol: CSymIdx::NULL, // replaced terminals: vec![ByteSet::new()], - symbols: vec![OptSymbol { - idx: 
OptSymIdx::NULL, + symbols: vec![CSymbol { + idx: CSymIdx::NULL, name: "NULL".to_string(), is_terminal: true, is_nullable: false, rules: vec![], }], - rules: vec![], + rules: vec![CSymIdx::NULL], // make sure RuleIdx::NULL is invalid rule_idx_to_sym_idx: vec![], terminals_by_byte: vec![], }; @@ -425,28 +427,28 @@ impl OptGrammar { let sym = grammar.sym_data(*sidx); outp.terminals.push(sym.bytes.clone().unwrap()); let idx = outp.symbols.len() as u16; - outp.symbols.push(OptSymbol { - idx: OptSymIdx(idx), + outp.symbols.push(CSymbol { + idx: CSymIdx(idx), name: sym.name.clone(), is_terminal: true, is_nullable: false, rules: vec![], }); - sym_map.insert(sym.idx, OptSymIdx(idx)); + sym_map.insert(sym.idx, CSymIdx(idx)); } for sym in &grammar.symbols { if sym.is_terminal() { continue; } let idx = outp.symbols.len() as u16; - outp.symbols.push(OptSymbol { - idx: OptSymIdx(idx), + outp.symbols.push(CSymbol { + idx: CSymIdx(idx), name: sym.name.clone(), is_terminal: false, is_nullable: sym.rules.iter().any(|r| r.rhs.is_empty()), rules: vec![], }); - sym_map.insert(sym.idx, OptSymIdx(idx)); + sym_map.insert(sym.idx, CSymIdx(idx)); } outp.start_symbol = sym_map[&grammar.start()]; for sym in &grammar.symbols { @@ -461,10 +463,10 @@ impl OptGrammar { for r in &rule.rhs { outp.rules.push(sym_map[r]); } - outp.rules.push(OptSymIdx::NULL); + outp.rules.push(CSymIdx::NULL); } while outp.rules.len() % (1 << RULE_SHIFT) != 0 { - outp.rules.push(OptSymIdx::NULL); + outp.rules.push(CSymIdx::NULL); } let rlen = outp.rules.len() >> RULE_SHIFT; while outp.rule_idx_to_sym_idx.len() < rlen { @@ -480,7 +482,7 @@ impl OptGrammar { } 'rules: for rule in sym.rules.iter() { let mut idx = rule.as_index(); - while outp.rules[idx] != OptSymIdx::NULL { + while outp.rules[idx] != CSymIdx::NULL { if outp.sym_data(outp.rules[idx]).is_nullable { to_null.push(sym.idx); break 'rules; diff --git a/controllers/aici_abi/src/earley/parser.rs b/controllers/aici_abi/src/earley/parser.rs index 267bbc92..37563f92 100644 --- a/controllers/aici_abi/src/earley/parser.rs +++ b/controllers/aici_abi/src/earley/parser.rs @@ -2,7 +2,7 @@ use std::{fmt::Debug, hash::Hash, ops::Range, vec}; use crate::toktree::{Recognizer, SpecialToken}; -use super::grammar::{OptGrammar, OptSymIdx, RuleIdx, SimpleHash}; +use super::grammar::{CGrammar, CSymIdx, RuleIdx, SimpleHash}; const DEBUG: bool = false; @@ -111,11 +111,11 @@ struct Scratch { row_start: usize, row_end: usize, items: Vec, - predicated_syms: SimpleSet, + predicated_syms: SimpleSet, } pub struct Parser { - grammar: OptGrammar, + grammar: CGrammar, scratch: Scratch, rows: Vec, stats: Stats, @@ -157,7 +157,7 @@ impl Scratch { } impl Parser { - pub fn new(grammar: OptGrammar) -> Self { + pub fn new(grammar: CGrammar) -> Self { let start = grammar.start(); let mut r = Parser { grammar, @@ -250,7 +250,7 @@ impl Parser { let rule = item.rule_idx(); let after_dot = self.grammar.sym_idx_at(rule); - if after_dot == OptSymIdx::NULL { + if after_dot == CSymIdx::NULL { let lhs = self.grammar.sym_idx_of(item.rule_idx()); // complete self.is_accepting = self.is_accepting || lhs == self.grammar.start(); From f2336b1472989ecc4137054ba1892e41dd624366 Mon Sep 17 00:00:00 2001 From: Michal Moskal Date: Wed, 6 Mar 2024 18:29:14 +0000 Subject: [PATCH 171/301] bugfix for parser --- controllers/aici_abi/src/earley/grammar.rs | 58 +++++++++++++++++++--- controllers/aici_abi/src/earley/parser.rs | 36 +++++++++----- controllers/aici_abi/src/toktree.rs | 1 + 3 files changed, 74 insertions(+), 21 deletions(-) diff --git 
a/controllers/aici_abi/src/earley/grammar.rs b/controllers/aici_abi/src/earley/grammar.rs index 0d73b524..304e4007 100644 --- a/controllers/aici_abi/src/earley/grammar.rs +++ b/controllers/aici_abi/src/earley/grammar.rs @@ -379,6 +379,20 @@ impl CGrammar { self.rule_idx_to_sym_idx[rule.as_index() >> RULE_SHIFT] } + pub fn rule_rhs(&self, rule: RuleIdx) -> (&[CSymIdx], usize) { + let idx = rule.as_index(); + let mut start = idx - 1; + while self.rules[start] != CSymIdx::NULL { + start -= 1; + } + start += 1; + let mut stop = idx; + while self.rules[stop] != CSymIdx::NULL { + stop += 1; + } + (&self.rules[start..stop], idx - start) + } + pub fn sym_data(&self, sym: CSymIdx) -> &CSymbol { &self.symbols[sym.0 as usize] } @@ -480,14 +494,14 @@ impl CGrammar { if sym.is_nullable { continue; } - 'rules: for rule in sym.rules.iter() { - let mut idx = rule.as_index(); - while outp.rules[idx] != CSymIdx::NULL { - if outp.sym_data(outp.rules[idx]).is_nullable { - to_null.push(sym.idx); - break 'rules; - } - idx += 1; + for rule in sym.rules.iter() { + if outp + .rule_rhs(*rule) + .0 + .iter() + .all(|elt| outp.sym_data(*elt).is_nullable) + { + to_null.push(sym.idx); } } } @@ -510,4 +524,32 @@ impl CGrammar { } outp } + + pub fn sym_name(&self, sym: CSymIdx) -> &str { + &self.symbols[sym.0 as usize].name + } + + pub fn rule_to_string(&self, rule: RuleIdx) -> String { + let lhs = self.sym_name(self.sym_idx_of(rule)); + let (rhs, dot) = self.rule_rhs(rule); + let mut rhs_str = rhs + .iter() + .enumerate() + .map(|(i, s)| { + format!( + "{}{}", + if i == dot { "(*) " } else { "" }, + self.sym_name(*s) + ) + }) + .collect::>() + .join(" "); + if rhs.is_empty() { + rhs_str.push_str("ϵ"); + } + if dot == rhs.len() { + rhs_str.push_str(" (*)"); + } + format!("{} ::= {}", lhs, rhs_str) + } } diff --git a/controllers/aici_abi/src/earley/parser.rs b/controllers/aici_abi/src/earley/parser.rs index 37563f92..07b946e5 100644 --- a/controllers/aici_abi/src/earley/parser.rs +++ b/controllers/aici_abi/src/earley/parser.rs @@ -120,6 +120,7 @@ pub struct Parser { rows: Vec, stats: Stats, is_accepting: bool, + last_collapse: usize, } impl Scratch { @@ -165,6 +166,7 @@ impl Parser { scratch: Scratch::default(), stats: Stats::default(), is_accepting: false, + last_collapse: 0, }; for rule in r.grammar.rules_of(start).to_vec() { r.scratch.add_unique(Item::new(rule, 0), "init"); @@ -178,23 +180,24 @@ impl Parser { } fn item_to_string(&self, item: &Item) -> String { - // let rule = self.grammar.rule_data(item.rule_idx()); - // self.grammar.rule_to_string(rule, item.dot_pos()) format!( - "item: rule: {:?}, start: {}", - item.rule_idx(), + "{} @{}", + self.grammar.rule_to_string(item.rule_idx()), item.start_pos() ) } - // fn row_to_string(&self, row: &Row) -> String { - // // let mut r = vec![format!("token: {}", byte_to_string(row.token))]; - // // for item in &row.items { - // // r.push(self.item_to_string(item)); - // // } - // // r.join("\n") + "\n" - // "todo".to_string() - // } + pub fn print_row(&self, row_idx: usize) { + let row = &self.rows[row_idx]; + println!("row {}", row_idx); + for i in row.item_indices() { + println!("{}", self.item_to_string(&self.scratch.items[i])); + } + } + + pub fn num_rows(&self) -> usize { + self.rows.len() + } #[inline(always)] pub fn scan(&mut self, b: u8) -> ParseResult { @@ -306,7 +309,14 @@ impl Recognizer for Parser { } fn collapse(&mut self) { - // does nothing - we need to keep the entire state + // this actually means "commit" - can no longer backtrack past this point + + if 
false { + for idx in self.last_collapse..self.num_rows() { + self.print_row(idx); + } + } + self.last_collapse = self.num_rows(); } fn special_allowed(&mut self, tok: SpecialToken) -> bool { diff --git a/controllers/aici_abi/src/toktree.rs b/controllers/aici_abi/src/toktree.rs index 72150d12..0c6a0add 100644 --- a/controllers/aici_abi/src/toktree.rs +++ b/controllers/aici_abi/src/toktree.rs @@ -430,6 +430,7 @@ impl TokTrie { } pub fn append_token(&self, r: &mut impl Recognizer, t: TokenId) { + // println!("append_token: {}", self.token_dbg(t)); let bytes = self.token(t); for &byte in bytes { r.push_byte(byte) From 4f8497f47945c69ddb2e9a8a203c39eaa1bff13e Mon Sep 17 00:00:00 2001 From: Michal Moskal Date: Thu, 7 Mar 2024 01:27:54 +0000 Subject: [PATCH 172/301] earley: minor perf work --- controllers/aici_abi/src/earley/parser.rs | 28 ++++++++++++++++++++--- 1 file changed, 25 insertions(+), 3 deletions(-) diff --git a/controllers/aici_abi/src/earley/parser.rs b/controllers/aici_abi/src/earley/parser.rs index 07b946e5..740f97e0 100644 --- a/controllers/aici_abi/src/earley/parser.rs +++ b/controllers/aici_abi/src/earley/parser.rs @@ -6,6 +6,9 @@ use super::grammar::{CGrammar, CSymIdx, RuleIdx, SimpleHash}; const DEBUG: bool = false; +// this may speed up more complex grammar but slows down simple ones (by 10%) +const PREDICTED_SYM_FILTER: bool = false; + #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] struct Item { data: u64, @@ -88,6 +91,7 @@ impl SimpleSet { self.items.clear(); } + #[inline(always)] fn insert(&mut self, item: T) { let mask = item.mask64(); if (self.hash & mask) != 0 && self.items.contains(&item) { @@ -97,6 +101,7 @@ impl SimpleSet { self.items.push(item); } + #[inline(always)] fn contains(&self, item: T) -> bool { if (item.mask64() & self.hash) == 0 { false @@ -104,6 +109,20 @@ impl SimpleSet { self.items.contains(&item) } } + + #[inline(always)] + fn should_insert(&mut self, item: T) -> bool { + if !PREDICTED_SYM_FILTER { + true + } else { + if self.contains(item) { + false + } else { + self.insert(item); + true + } + } + } } #[derive(Default)] @@ -145,7 +164,11 @@ impl Scratch { #[inline(always)] fn just_add(&mut self, item: Item) { self.ensure_items(self.row_end + 1); - self.items[self.row_end] = item; + // SAFETY: we just ensured that there is enough space + unsafe { + self.items.as_mut_ptr().add(self.row_end).write(item); + } + // self.items[self.row_end] = item; self.row_end += 1; } @@ -272,8 +295,7 @@ impl Parser { if sym_data.is_nullable { self.scratch.add_unique(item.advance_dot(), "null"); } - if !self.scratch.predicated_syms.contains(after_dot) { - self.scratch.predicated_syms.insert(after_dot); + if self.scratch.predicated_syms.should_insert(after_dot) { for rule in &sym_data.rules { let new_item = Item::new(*rule, curr_idx); self.scratch.add_unique(new_item, "predict"); From 9cc3cff3e6ad556a1dcebc50c9948474ed60b648 Mon Sep 17 00:00:00 2001 From: Michal Moskal Date: Thu, 14 Mar 2024 18:33:13 +0000 Subject: [PATCH 173/301] better grm formatting --- controllers/aici_abi/src/earley/byteset.rs | 27 +++++ controllers/aici_abi/src/earley/grammar.rs | 118 ++++++++++++++------- 2 files changed, 105 insertions(+), 40 deletions(-) diff --git a/controllers/aici_abi/src/earley/byteset.rs b/controllers/aici_abi/src/earley/byteset.rs index dc578e1a..94e7e49c 100644 --- a/controllers/aici_abi/src/earley/byteset.rs +++ b/controllers/aici_abi/src/earley/byteset.rs @@ -104,4 +104,31 @@ impl ByteSet { } r } + + pub fn num_bytes(&self) -> usize { + let mut r = 0; + for i in 
0..BYTESET_LEN { + r += self.mask[i].count_ones() as usize; + } + r + } + + pub fn first_byte(&self) -> Option { + for i in 0..BYTESET_LEN { + let m = self.mask[i]; + if m != 0 { + let bit = m.trailing_zeros() as usize; + return Some((i * 32 + bit) as u8); + } + } + None + } + + pub fn single_byte(&self) -> Option { + if self.num_bytes() != 1 { + None + } else { + self.first_byte() + } + } } diff --git a/controllers/aici_abi/src/earley/grammar.rs b/controllers/aici_abi/src/earley/grammar.rs index 304e4007..8a3543ea 100644 --- a/controllers/aici_abi/src/earley/grammar.rs +++ b/controllers/aici_abi/src/earley/grammar.rs @@ -32,6 +32,22 @@ impl Rule { } } +enum SymName { + Name(String), + Byte(u8), +} + +impl SymName { + fn from(name: &str, bytes: Option<&ByteSet>) -> Self { + if let Some(bytes) = bytes { + if let Some(b) = bytes.single_byte() { + return SymName::Byte(b); + } + } + SymName::Name(name.to_string()) + } +} + pub struct Grammar { symbols: Vec, symbol_by_name: FxHashMap, @@ -87,28 +103,18 @@ impl Grammar { &self.symbols[sym.0 as usize].name } - fn rule_to_string(&self, rule: &Rule, dot: usize) -> String { - let lhs = self.sym_name(rule.lhs()); - let mut rhs = rule - .rhs - .iter() - .enumerate() - .map(|(i, s)| { - format!( - "{}{}", - if i == dot { "(*) " } else { "" }, - self.sym_name(*s) - ) - }) - .collect::>() - .join(" "); - if rule.rhs.is_empty() { - rhs.push_str("ϵ"); - } - if dot == rule.rhs.len() { - rhs.push_str(" (*)"); - } - format!("{} ::= {}", lhs, rhs) + fn rule_to_string(&self, rule: &Rule, dot: Option) -> String { + rule_to_string( + self.sym_name(rule.lhs()), + rule.rhs + .iter() + .map(|s| { + let d = self.sym_data(*s); + SymName::from(&d.name, d.bytes.as_ref()) + }) + .collect(), + dot, + ) } fn copy_from(&mut self, other: &Grammar, sym: SymIdx) -> SymIdx { @@ -293,7 +299,7 @@ impl Debug for Grammar { num_rules += sym.rules.len(); } for rule in &sym.rules { - writeln!(f, "{}", self.rule_to_string(rule, usize::MAX))?; + writeln!(f, "{}", self.rule_to_string(rule, None))?; } } writeln!( @@ -532,24 +538,56 @@ impl CGrammar { pub fn rule_to_string(&self, rule: RuleIdx) -> String { let lhs = self.sym_name(self.sym_idx_of(rule)); let (rhs, dot) = self.rule_rhs(rule); - let mut rhs_str = rhs - .iter() - .enumerate() - .map(|(i, s)| { - format!( - "{}{}", - if i == dot { "(*) " } else { "" }, - self.sym_name(*s) - ) - }) - .collect::>() - .join(" "); - if rhs.is_empty() { - rhs_str.push_str("ϵ"); + rule_to_string( + lhs, + rhs.iter() + .map(|s| { + let d = self.sym_data(*s); + SymName::from( + &d.name, + if d.is_terminal { + Some(&self.terminals[d.idx.0 as usize]) + } else { + None + }, + ) + }) + .collect(), + Some(dot), + ) + } +} + +fn rule_to_string(lhs: &str, mut rhs: Vec, dot: Option) -> String { + if rhs.is_empty() { + rhs.push(SymName::Name("ϵ".to_string())); + if dot == Some(0) { + rhs.push(SymName::Name("•".to_string())); } - if dot == rhs.len() { - rhs_str.push_str(" (*)"); + } else if let Some(dot) = dot { + rhs.insert(dot, SymName::Name("•".to_string())); + } + let mut outp = Vec::new(); + let mut i = 0; + while i < rhs.len() { + match &rhs[i] { + SymName::Name(s) => { + outp.push(s.clone()); + i += 1; + } + SymName::Byte(_) => { + let mut text = Vec::new(); + while i < rhs.len() { + if let SymName::Byte(b) = rhs[i] { + text.push(b); + i += 1; + } else { + break; + } + } + outp.push(format!("{:?}", String::from_utf8_lossy(&text))); + } } - format!("{} ::= {}", lhs, rhs_str) } + format!("{} ::= {}", lhs, outp.join(" ")) } From 
0b5863440410977e39539327a415765fdeb492b1 Mon Sep 17 00:00:00 2001 From: Michal Moskal Date: Thu, 14 Mar 2024 21:18:12 +0000 Subject: [PATCH 174/301] fix grammar opt --- controllers/aici_abi/src/earley/grammar.rs | 31 +++++++++++++--------- 1 file changed, 18 insertions(+), 13 deletions(-) diff --git a/controllers/aici_abi/src/earley/grammar.rs b/controllers/aici_abi/src/earley/grammar.rs index 8a3543ea..2143cd56 100644 --- a/controllers/aici_abi/src/earley/grammar.rs +++ b/controllers/aici_abi/src/earley/grammar.rs @@ -126,6 +126,21 @@ impl Grammar { } } + fn rule_shape(&self, r: &Rule) -> Vec> { + let mut shape = Vec::new(); + let mut had_term = false; + for s in &r.rhs { + let sym = self.sym_data(*s); + if !had_term && sym.is_terminal() { + had_term = true; + shape.push(None); + } else { + shape.push(Some(*s)); + } + } + shape + } + fn collapse_terminals(&self) -> Self { let mut outp = Grammar::new(); for sym in &self.symbols { @@ -134,19 +149,8 @@ impl Grammar { } let mut rules_by_shape = FxHashMap::default(); for rule in &sym.rules { - let shape = rule - .rhs - .iter() - .map(|s| { - if self.sym_data(*s).is_terminal() { - None - } else { - Some(*s) - } - }) - .collect::>(); rules_by_shape - .entry(shape) + .entry(self.rule_shape(rule)) .or_insert_with(Vec::new) .push(rule); } @@ -280,6 +284,7 @@ impl Grammar { impl Debug for Grammar { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + writeln!(f, "Grammar:")?; for sym in &self.symbols { match sym.bytes { Some(ref bytes) if sym.name.starts_with("T@") => { @@ -589,5 +594,5 @@ fn rule_to_string(lhs: &str, mut rhs: Vec, dot: Option) -> Strin } } } - format!("{} ::= {}", lhs, outp.join(" ")) + format!("{:15} ⇦ {}", lhs, outp.join(" ")) } From 399b28c05e98d75ac1dce9305c19f9722768703d Mon Sep 17 00:00:00 2001 From: Michal Moskal Date: Thu, 14 Mar 2024 21:47:07 +0000 Subject: [PATCH 175/301] move earley parser to gctrl --- controllers/aici_abi/Cargo.toml | 3 +- controllers/aici_abi/src/earley/bench.rs | 162 ------ controllers/aici_abi/src/earley/byteset.rs | 134 ----- controllers/aici_abi/src/earley/grammar.rs | 598 -------------------- controllers/aici_abi/src/earley/guidance.rs | 456 --------------- controllers/aici_abi/src/earley/mod.rs | 12 - controllers/aici_abi/src/earley/parser.rs | 364 ------------ controllers/aici_abi/src/lib.rs | 3 - 8 files changed, 1 insertion(+), 1731 deletions(-) delete mode 100644 controllers/aici_abi/src/earley/bench.rs delete mode 100644 controllers/aici_abi/src/earley/byteset.rs delete mode 100644 controllers/aici_abi/src/earley/grammar.rs delete mode 100644 controllers/aici_abi/src/earley/guidance.rs delete mode 100644 controllers/aici_abi/src/earley/mod.rs delete mode 100644 controllers/aici_abi/src/earley/parser.rs diff --git a/controllers/aici_abi/Cargo.toml b/controllers/aici_abi/Cargo.toml index 0acd2a3c..79ce484f 100644 --- a/controllers/aici_abi/Cargo.toml +++ b/controllers/aici_abi/Cargo.toml @@ -23,10 +23,9 @@ rustc-hash = { version = "1.1.0", optional = true } quick-protobuf = { version = "0.8.1" } [features] -default = ["cfg", "rx", "earley"] +default = ["cfg", "rx"] cfg = ["dep:cfgrammar", "dep:lrlex", "dep:lrpar", "dep:lrtable", "dep:vob", "dep:rustc-hash"] rx = ["dep:regex-automata"] -earley = ["rx", "dep:rustc-hash"] [[bin]] name = "yesno" diff --git a/controllers/aici_abi/src/earley/bench.rs b/controllers/aici_abi/src/earley/bench.rs deleted file mode 100644 index b741869f..00000000 --- a/controllers/aici_abi/src/earley/bench.rs +++ /dev/null @@ -1,162 +0,0 @@ -use 
anyhow::Result; -use quick_protobuf::MessageRead; -use rustc_hash::FxHashSet; - -use super::{guidance, ByteSet, Grammar, Parser}; -use crate::earley::parser::ParseResult; - -pub fn earley_grm_from_guidance(bytes: &[u8]) -> Result { - let mut reader = quick_protobuf::BytesReader::from_bytes(bytes); - let gg = guidance::Grammar::from_reader(&mut reader, bytes).unwrap(); - let mut grm = Grammar::new(); - - let symbols = gg - .nodes - .iter() - .map(|n| match &n.function_type { - guidance::mod_GrammarFunction::OneOffunction_type::join(n) => grm.fresh_symbol(&n.name), - guidance::mod_GrammarFunction::OneOffunction_type::select(n) => { - grm.fresh_symbol(&n.name) - } - guidance::mod_GrammarFunction::OneOffunction_type::byte(n) => { - assert!(n.byte.len() == 1); - grm.terminal(&ByteSet::from_range(n.byte[0], n.byte[0])) - } - guidance::mod_GrammarFunction::OneOffunction_type::byte_range(n) => { - assert!(n.byte_range.len() == 2); - grm.terminal(&ByteSet::from_range(n.byte_range[0], n.byte_range[1])) - } - guidance::mod_GrammarFunction::OneOffunction_type::model_variable(n) => { - grm.fresh_symbol(&n.name) - } - guidance::mod_GrammarFunction::OneOffunction_type::None => { - panic!("None function type in guidance::Grammar") - } - }) - .collect::>(); - - let set = FxHashSet::from_iter(symbols.iter()); - assert!(set.len() == symbols.len(), "duplicate symbols"); - - for (n, sym) in gg.nodes.iter().zip(symbols.iter()) { - let lhs = *sym; - match &n.function_type { - guidance::mod_GrammarFunction::OneOffunction_type::join(n) => { - if n.nullable { - //println!("nullable join: {:?}", n.name); - } - let rhs = n.values.iter().map(|idx| symbols[*idx as usize]).collect(); - grm.add_rule(lhs, rhs); - } - guidance::mod_GrammarFunction::OneOffunction_type::select(n) => { - if n.nullable { - // println!("nullable sel: {:?} {:?}", n.name, n.values); - grm.add_rule(lhs, vec![]); - } - for v in &n.values { - grm.add_rule(lhs, vec![symbols[*v as usize]]); - } - } - guidance::mod_GrammarFunction::OneOffunction_type::byte(_) => {} - guidance::mod_GrammarFunction::OneOffunction_type::byte_range(_) => {} - guidance::mod_GrammarFunction::OneOffunction_type::model_variable(n) => { - // eos_token, bos_token etc - panic!("model_variable not implemented yet ({:?})", n.name); - } - guidance::mod_GrammarFunction::OneOffunction_type::None => panic!("???"), - } - } - - grm.add_rule(grm.start(), vec![symbols[0]]); - - Ok(grm) -} - -pub fn earley_test(trie: crate::toktree::TokTrie) { - let g_bytes = include_bytes!("../../grammars/json0.guidance"); - let cfg = earley_grm_from_guidance(g_bytes).unwrap(); - // println!("cfg0: {:?}", cfg); - let cfg = cfg.optimize(); - println!("cfg: {:?}", cfg); - - let input = r#"{"name":"Joe","info":{"foo":10,"bar":"20"}}"#.as_bytes(); - - let toks = trie.greedy_tokenize(input); - println!("tokens: {:?}", toks.len()); - - let grm = cfg.compile(); - - let mut parser = Parser::new(grm.clone()); - let mut last_res = ParseResult::Reject; - for b in input { - last_res = parser.scan(*b); - if last_res == ParseResult::Reject { - println!("reject"); - break; - } - } - if last_res != ParseResult::Accept { - println!("final non-accept"); - } - - const COLLECT_TIMES: bool = false; - const NUM_REP: usize = if COLLECT_TIMES { 5 } else { 500 }; - let mut durations = vec![]; - let mut durations_us = vec![]; - println!("start!"); - - let num_tok = 4; - - for _ in 0..NUM_REP { - let mut line = 1; - let mut vob = trie.alloc_token_set(); - - parser = Parser::new(grm.clone()); - let mut times = vec![]; - - 
#[cfg(not(target_arch = "wasm32"))] - let t0 = std::time::Instant::now(); - - for (idx, tok) in toks.iter().take(num_tok).enumerate() { - let tok = *tok; - let tt = std::time::Instant::now(); - trie.compute_bias(&mut parser, &mut vob); - if idx == num_tok - 1 { - durations_us.push(tt.elapsed().as_micros() as u64); - } - // parser.print_stats(); - if !vob.is_allowed(tok) { - println!("reject, line={}, tok={:?}", line, trie.token_str(tok)); - panic!(); - } - for b in trie.token(tok) { - if *b == b'\n' { - line += 1; - } - } - // println!( - // "TOK: {} ===> {}", - // trie.token_dbg(tok), - // trie.token_set_dbg(&vob) - // ); - trie.append_token(&mut parser, tok); - if COLLECT_TIMES { - times.push(tt.elapsed().as_micros() as u32); - } - } - - durations.push(t0.elapsed().as_micros() as u64); - - if COLLECT_TIMES { - println!("times: {:?}", times); - } - } - - durations.sort(); - durations_us.sort(); - - let min_us = *durations_us.iter().min().unwrap(); - // println!("min_time_us: {:?}", min_us); - // for ~5ms 0.1ms is the precision we expect - println!("min_time_ms: {:.1}", min_us as f64 / 1000.0); -} diff --git a/controllers/aici_abi/src/earley/byteset.rs b/controllers/aici_abi/src/earley/byteset.rs deleted file mode 100644 index 94e7e49c..00000000 --- a/controllers/aici_abi/src/earley/byteset.rs +++ /dev/null @@ -1,134 +0,0 @@ -use std::fmt::{Debug, Display}; - -const BYTESET_LEN: usize = 8; - -#[derive(Clone, PartialEq, Eq, Hash)] -pub struct ByteSet { - mask: [u32; BYTESET_LEN], -} - -impl Debug for ByteSet { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(f, "[")?; - let mut first = true; - for i in 0u32..=256 { - if i <= 0xff && self.contains(i as u8) { - if !first { - write!(f, ", ")?; - } - first = false; - write!(f, "{}", i)?; - } - } - write!(f, "]") - } -} - -pub fn byte_to_string(b: u8) -> String { - if b >= 0x7f { - format!("x{:02x}", b) - } else { - let b = b as char; - match b { - '_' | 'a'..='z' | 'A'..='Z' | '0'..='9' => format!("{}", b), - _ => format!("{:?}", b as char), - } - } -} - -impl Display for ByteSet { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - let mut start = None; - let mut first = true; - for i in 0u32..=256 { - if i <= 0xff && self.contains(i as u8) { - if start.is_none() { - start = Some(i); - } - } else { - if let Some(start) = start { - if !first { - write!(f, ";")?; - } - first = false; - write!(f, "{}", byte_to_string(start as u8))?; - if i - start > 1 { - write!(f, "-{}", byte_to_string((i - 1) as u8))?; - } - } - start = None; - } - } - Ok(()) - } -} - -impl ByteSet { - pub fn new() -> Self { - ByteSet { - mask: [0; BYTESET_LEN], - } - } - - pub fn from_sum<'a>(elts: impl Iterator) -> Self { - let mut r = ByteSet::new(); - for e in elts { - r.add_set(&e); - } - r - } - - pub fn add_set(&mut self, other: &ByteSet) { - for i in 0..BYTESET_LEN { - self.mask[i] |= other.mask[i]; - } - } - - pub fn add(&mut self, byte: u8) { - let idx = byte as usize / 32; - let bit = byte as usize % 32; - self.mask[idx] |= 1 << bit; - } - - pub fn contains(&self, byte: u8) -> bool { - let idx = byte as usize / 32; - let bit = byte as usize % 32; - self.mask[idx] & (1 << bit) != 0 - } - - pub fn from_range(start: u8, end: u8) -> Self { - let mut r = ByteSet::new(); - // TODO optimize - for b in start..=end { - r.add(b); - } - r - } - - pub fn num_bytes(&self) -> usize { - let mut r = 0; - for i in 0..BYTESET_LEN { - r += self.mask[i].count_ones() as usize; - } - r - } - - pub fn first_byte(&self) -> Option { - for i 
in 0..BYTESET_LEN { - let m = self.mask[i]; - if m != 0 { - let bit = m.trailing_zeros() as usize; - return Some((i * 32 + bit) as u8); - } - } - None - } - - pub fn single_byte(&self) -> Option { - if self.num_bytes() != 1 { - None - } else { - self.first_byte() - } - } -} diff --git a/controllers/aici_abi/src/earley/grammar.rs b/controllers/aici_abi/src/earley/grammar.rs deleted file mode 100644 index 2143cd56..00000000 --- a/controllers/aici_abi/src/earley/grammar.rs +++ /dev/null @@ -1,598 +0,0 @@ -use std::fmt::Debug; - -use crate::svob::SimpleVob; - -use super::ByteSet; -use rustc_hash::FxHashMap; - -#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] -pub struct SymIdx(u32); - -impl Symbol { - fn is_terminal(&self) -> bool { - self.bytes.is_some() - } -} - -struct Symbol { - idx: SymIdx, - name: String, - bytes: Option, - rules: Vec, -} - -struct Rule { - lhs: SymIdx, - rhs: Vec, -} - -impl Rule { - fn lhs(&self) -> SymIdx { - self.lhs - } -} - -enum SymName { - Name(String), - Byte(u8), -} - -impl SymName { - fn from(name: &str, bytes: Option<&ByteSet>) -> Self { - if let Some(bytes) = bytes { - if let Some(b) = bytes.single_byte() { - return SymName::Byte(b); - } - } - SymName::Name(name.to_string()) - } -} - -pub struct Grammar { - symbols: Vec, - symbol_by_name: FxHashMap, - terminals: FxHashMap, -} - -impl Grammar { - pub fn new() -> Self { - let mut r = Grammar { - symbols: vec![], - symbol_by_name: FxHashMap::default(), - terminals: FxHashMap::default(), - }; - let _ = r.symbol("_start"); - r - } - - pub fn start(&self) -> SymIdx { - self.symbols[0].idx - } - - fn sym_data(&self, sym: SymIdx) -> &Symbol { - &self.symbols[sym.0 as usize] - } - - fn sym_data_mut(&mut self, sym: SymIdx) -> &mut Symbol { - &mut self.symbols[sym.0 as usize] - } - - pub fn add_rule(&mut self, lhs: SymIdx, rhs: Vec) { - assert!(!self.sym_data(lhs).is_terminal()); - let sym = self.sym_data_mut(lhs); - sym.rules.push(Rule { lhs, rhs }); - } - - pub fn terminal(&mut self, bytes: &ByteSet) -> SymIdx { - match self.terminals.get(bytes) { - Some(sym) => *sym, - None => { - let mut name = format!("T:{}", bytes); - if name.len() > 40 { - name = format!("T@{}", self.terminals.len()); - } - let sym = self.fresh_symbol(&name); - self.sym_data_mut(sym).bytes = Some(bytes.clone()); - self.terminals.insert(bytes.clone(), sym); - sym - } - } - } - - pub fn sym_name(&self, sym: SymIdx) -> &str { - &self.symbols[sym.0 as usize].name - } - - fn rule_to_string(&self, rule: &Rule, dot: Option) -> String { - rule_to_string( - self.sym_name(rule.lhs()), - rule.rhs - .iter() - .map(|s| { - let d = self.sym_data(*s); - SymName::from(&d.name, d.bytes.as_ref()) - }) - .collect(), - dot, - ) - } - - fn copy_from(&mut self, other: &Grammar, sym: SymIdx) -> SymIdx { - let sym_data = other.sym_data(sym); - if sym_data.is_terminal() { - self.terminal(sym_data.bytes.as_ref().unwrap()) - } else { - self.symbol(&sym_data.name) - } - } - - fn rule_shape(&self, r: &Rule) -> Vec> { - let mut shape = Vec::new(); - let mut had_term = false; - for s in &r.rhs { - let sym = self.sym_data(*s); - if !had_term && sym.is_terminal() { - had_term = true; - shape.push(None); - } else { - shape.push(Some(*s)); - } - } - shape - } - - fn collapse_terminals(&self) -> Self { - let mut outp = Grammar::new(); - for sym in &self.symbols { - if sym.rules.is_empty() { - continue; - } - let mut rules_by_shape = FxHashMap::default(); - for rule in &sym.rules { - rules_by_shape - .entry(self.rule_shape(rule)) - .or_insert_with(Vec::new) - .push(rule); - } - 
let lhs = outp.copy_from(self, sym.idx); - for rules in rules_by_shape.values() { - let rhs = rules[0] - .rhs - .iter() - .enumerate() - .map(|(i, s)| { - let sym = self.sym_data(*s); - if sym.is_terminal() { - let terminals = rules - .iter() - .map(|r| self.sym_data(r.rhs[i]).bytes.clone().unwrap()); - outp.terminal(&ByteSet::from_sum(terminals)) - } else { - outp.copy_from(self, *s) - } - }) - .collect(); - outp.add_rule(lhs, rhs); - } - } - outp - } - - fn expand_shortcuts(&self) -> Self { - let mut use_count = vec![0; self.symbols.len()]; - for sym in &self.symbols { - for r in sym.rules.iter() { - for s in &r.rhs { - use_count[s.0 as usize] += 1; - } - } - } - - let mut repl = FxHashMap::default(); - for sym in &self.symbols { - if sym.idx == self.start() { - continue; - } - if sym.rules.len() == 1 - && (use_count[sym.idx.0 as usize] == 1 || sym.rules[0].rhs.len() == 1) - { - // eliminate sym.idx - repl.insert(sym.idx, sym.rules[0].rhs.clone()); - } - } - - // fix-point expand the mapping - loop { - let to_change = repl - .iter() - .filter_map(|(lhs, rhs)| { - let rhs2 = rhs - .iter() - .flat_map(|s| repl.get(s).cloned().unwrap_or_else(|| vec![*s])) - .collect::>(); - assert!(rhs2.iter().all(|s| *s != *lhs), "cyclic?"); - if *rhs != rhs2 { - Some((*lhs, rhs2)) - } else { - None - } - }) - .collect::>(); - if to_change.is_empty() { - break; - } - for (lhs, rhs) in to_change { - repl.insert(lhs, rhs); - } - } - - let mut outp = Grammar::new(); - for sym in &self.symbols { - if repl.contains_key(&sym.idx) { - continue; - } - let lhs = outp.copy_from(self, sym.idx); - for rule in &sym.rules { - let rhs = rule - .rhs - .iter() - .flat_map(|s| repl.get(s).cloned().unwrap_or_else(|| vec![*s])) - .map(|s| outp.copy_from(self, s)) - .collect(); - outp.add_rule(lhs, rhs); - } - } - outp - } - - pub fn optimize(&self) -> Self { - self.expand_shortcuts() - .collapse_terminals() - .expand_shortcuts() - } - - pub fn compile(&self) -> CGrammar { - CGrammar::from_grammar(self) - } - - pub fn fresh_symbol(&mut self, name0: &str) -> SymIdx { - let mut name = name0.to_string(); - let mut idx = 2; - while self.symbol_by_name.contains_key(&name) { - name = format!("{}#{}", name0, idx); - idx += 1; - } - - let idx = SymIdx(self.symbols.len() as u32); - self.symbols.push(Symbol { - name: name.clone(), - bytes: None, - idx, - rules: vec![], - }); - self.symbol_by_name.insert(name, idx); - idx - } - - pub fn symbol(&mut self, name: &str) -> SymIdx { - match self.symbol_by_name.get(name) { - Some(idx) => *idx, - None => self.fresh_symbol(name), - } - } -} - -impl Debug for Grammar { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - writeln!(f, "Grammar:")?; - for sym in &self.symbols { - match sym.bytes { - Some(ref bytes) if sym.name.starts_with("T@") => { - writeln!(f, "{} := {}", sym.name, bytes)? 
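
The `rule_shape`/`collapse_terminals` pair above merges alternative rules that differ only in their first terminal, unioning the corresponding byte sets. A minimal, self-contained sketch of that idea (not part of the patch): names here are hypothetical, and plain sorted `Vec<u8>` lists stand in for the crate's `ByteSet`.

use std::collections::HashMap;

#[derive(Clone, Debug, PartialEq, Eq, Hash)]
enum Sym {
    Terminal(Vec<u8>), // set of allowed bytes, kept sorted and deduped
    NonTerminal(String),
}

// Shape key, mirroring rule_shape: the first terminal becomes a wildcard
// (None); every later symbol is kept literally, so rules only merge when
// everything past that first terminal matches exactly.
fn shape(rhs: &[Sym]) -> Vec<Option<Sym>> {
    let mut had_term = false;
    rhs.iter()
        .map(|s| match s {
            Sym::Terminal(_) if !had_term => {
                had_term = true;
                None
            }
            other => Some(other.clone()),
        })
        .collect()
}

fn collapse(rules: Vec<Vec<Sym>>) -> Vec<Vec<Sym>> {
    let mut by_shape: HashMap<_, Vec<Vec<Sym>>> = HashMap::new();
    for r in rules {
        by_shape.entry(shape(&r)).or_default().push(r);
    }
    by_shape
        .into_values()
        .map(|group| {
            let mut merged = group[0].clone();
            for i in 0..merged.len() {
                if let Sym::Terminal(bytes) = &mut merged[i] {
                    // union the byte sets at the single wildcard position
                    for other in &group[1..] {
                        if let Sym::Terminal(b) = &other[i] {
                            bytes.extend(b);
                        }
                    }
                    bytes.sort();
                    bytes.dedup();
                    break;
                }
            }
            merged
        })
        .collect()
}

fn main() {
    // S ::= 'a' X  and  S ::= 'b' X  collapse into  S ::= ['a','b'] X
    let rules = vec![
        vec![Sym::Terminal(vec![b'a']), Sym::NonTerminal("X".into())],
        vec![Sym::Terminal(vec![b'b']), Sym::NonTerminal("X".into())],
    ];
    for r in collapse(rules) {
        println!("{:?}", r);
    }
}

Restricting the wildcard to the first terminal is exactly what patch 174 ("fix grammar opt") changed: wildcarding every terminal would cross-multiply byte sets at independent positions, so merging 'a' X 'b' with 'c' X 'd' would also admit 'a' X 'd', which is not in the original grammar.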
- } - _ => {} - } - } - let mut num_term = 0; - let mut num_rules = 0; - let mut num_non_term = 0; - for sym in &self.symbols { - if sym.is_terminal() { - num_term += 1; - } else { - num_non_term += 1; - num_rules += sym.rules.len(); - } - for rule in &sym.rules { - writeln!(f, "{}", self.rule_to_string(rule, None))?; - } - } - writeln!( - f, - "stats: {} terminals; {} non-terminals with {} rules\n", - num_term, num_non_term, num_rules - )?; - Ok(()) - } -} - -#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] -pub struct CSymIdx(u16); - -impl CSymIdx { - pub const NULL: CSymIdx = CSymIdx(0); - - pub fn as_index(&self) -> usize { - self.0 as usize - } -} - -pub trait SimpleHash { - fn simple_hash(&self) -> u32; - - fn mask64(&self) -> u64 { - 1 << (self.simple_hash() & 63) - } - - fn mask32(&self) -> u32 { - 1 << (self.simple_hash() & 31) - } -} - -impl SimpleHash for CSymIdx { - fn simple_hash(&self) -> u32 { - (self.0 as u32).wrapping_mul(79667123) - } -} - -#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] -pub struct RuleIdx(u32); - -impl RuleIdx { - pub const NULL: RuleIdx = RuleIdx(0); - - pub fn from_index(idx: u32) -> Self { - RuleIdx(idx) - } - - pub fn advance(&self) -> RuleIdx { - RuleIdx(self.0 + 1) - } - - pub fn as_index(&self) -> usize { - self.0 as usize - } -} - -#[derive(Clone)] -pub struct CSymbol { - pub idx: CSymIdx, - pub name: String, - pub is_terminal: bool, - pub is_nullable: bool, - pub rules: Vec, -} - -#[derive(Clone)] -pub struct CGrammar { - start_symbol: CSymIdx, - terminals: Vec, - symbols: Vec, - rules: Vec, - rule_idx_to_sym_idx: Vec, - terminals_by_byte: Vec, -} - -const RULE_SHIFT: usize = 2; - -impl CGrammar { - pub fn sym_idx_of(&self, rule: RuleIdx) -> CSymIdx { - self.rule_idx_to_sym_idx[rule.as_index() >> RULE_SHIFT] - } - - pub fn rule_rhs(&self, rule: RuleIdx) -> (&[CSymIdx], usize) { - let idx = rule.as_index(); - let mut start = idx - 1; - while self.rules[start] != CSymIdx::NULL { - start -= 1; - } - start += 1; - let mut stop = idx; - while self.rules[stop] != CSymIdx::NULL { - stop += 1; - } - (&self.rules[start..stop], idx - start) - } - - pub fn sym_data(&self, sym: CSymIdx) -> &CSymbol { - &self.symbols[sym.0 as usize] - } - - fn sym_data_mut(&mut self, sym: CSymIdx) -> &mut CSymbol { - &mut self.symbols[sym.0 as usize] - } - - pub fn terminals_by_byte(&self, b: u8) -> &SimpleVob { - &self.terminals_by_byte[b as usize] - } - - pub fn sym_idx_at(&self, idx: RuleIdx) -> CSymIdx { - self.rules[idx.0 as usize] - } - - pub fn start(&self) -> CSymIdx { - self.start_symbol - } - - pub fn is_accepting(&self, sym: CSymIdx, rule: RuleIdx) -> bool { - sym == self.start() && self.sym_idx_at(rule) == CSymIdx::NULL - } - - pub fn rules_of(&self, sym: CSymIdx) -> &[RuleIdx] { - &self.sym_data(sym).rules - } - - fn from_grammar(grammar: &Grammar) -> Self { - let mut outp = CGrammar { - start_symbol: CSymIdx::NULL, // replaced - terminals: vec![ByteSet::new()], - symbols: vec![CSymbol { - idx: CSymIdx::NULL, - name: "NULL".to_string(), - is_terminal: true, - is_nullable: false, - rules: vec![], - }], - rules: vec![CSymIdx::NULL], // make sure RuleIdx::NULL is invalid - rule_idx_to_sym_idx: vec![], - terminals_by_byte: vec![], - }; - let mut sym_map = FxHashMap::default(); - for (_, sidx) in &grammar.terminals { - let sym = grammar.sym_data(*sidx); - outp.terminals.push(sym.bytes.clone().unwrap()); - let idx = outp.symbols.len() as u16; - outp.symbols.push(CSymbol { - idx: CSymIdx(idx), - name: sym.name.clone(), - is_terminal: true, - is_nullable: false, - 
rules: vec![], - }); - sym_map.insert(sym.idx, CSymIdx(idx)); - } - for sym in &grammar.symbols { - if sym.is_terminal() { - continue; - } - let idx = outp.symbols.len() as u16; - outp.symbols.push(CSymbol { - idx: CSymIdx(idx), - name: sym.name.clone(), - is_terminal: false, - is_nullable: sym.rules.iter().any(|r| r.rhs.is_empty()), - rules: vec![], - }); - sym_map.insert(sym.idx, CSymIdx(idx)); - } - outp.start_symbol = sym_map[&grammar.start()]; - for sym in &grammar.symbols { - if sym.is_terminal() { - continue; - } - let idx = sym_map[&sym.idx]; - for rule in &sym.rules { - let curr = RuleIdx(outp.rules.len().try_into().unwrap()); - outp.sym_data_mut(idx).rules.push(curr); - // outp.rules.push(idx); - for r in &rule.rhs { - outp.rules.push(sym_map[r]); - } - outp.rules.push(CSymIdx::NULL); - } - while outp.rules.len() % (1 << RULE_SHIFT) != 0 { - outp.rules.push(CSymIdx::NULL); - } - let rlen = outp.rules.len() >> RULE_SHIFT; - while outp.rule_idx_to_sym_idx.len() < rlen { - outp.rule_idx_to_sym_idx.push(idx); - } - } - - loop { - let mut to_null = vec![]; - for sym in &outp.symbols { - if sym.is_nullable { - continue; - } - for rule in sym.rules.iter() { - if outp - .rule_rhs(*rule) - .0 - .iter() - .all(|elt| outp.sym_data(*elt).is_nullable) - { - to_null.push(sym.idx); - } - } - } - if to_null.is_empty() { - break; - } - for sym in to_null { - outp.sym_data_mut(sym).is_nullable = true; - } - } - - for b in 0..=255 { - let mut v = SimpleVob::alloc(outp.terminals.len()); - for (i, bytes) in outp.terminals.iter().enumerate() { - if bytes.contains(b as u8) { - v.allow_token(i as u32); - } - } - outp.terminals_by_byte.push(v); - } - outp - } - - pub fn sym_name(&self, sym: CSymIdx) -> &str { - &self.symbols[sym.0 as usize].name - } - - pub fn rule_to_string(&self, rule: RuleIdx) -> String { - let lhs = self.sym_name(self.sym_idx_of(rule)); - let (rhs, dot) = self.rule_rhs(rule); - rule_to_string( - lhs, - rhs.iter() - .map(|s| { - let d = self.sym_data(*s); - SymName::from( - &d.name, - if d.is_terminal { - Some(&self.terminals[d.idx.0 as usize]) - } else { - None - }, - ) - }) - .collect(), - Some(dot), - ) - } -} - -fn rule_to_string(lhs: &str, mut rhs: Vec, dot: Option) -> String { - if rhs.is_empty() { - rhs.push(SymName::Name("ϵ".to_string())); - if dot == Some(0) { - rhs.push(SymName::Name("•".to_string())); - } - } else if let Some(dot) = dot { - rhs.insert(dot, SymName::Name("•".to_string())); - } - let mut outp = Vec::new(); - let mut i = 0; - while i < rhs.len() { - match &rhs[i] { - SymName::Name(s) => { - outp.push(s.clone()); - i += 1; - } - SymName::Byte(_) => { - let mut text = Vec::new(); - while i < rhs.len() { - if let SymName::Byte(b) = rhs[i] { - text.push(b); - i += 1; - } else { - break; - } - } - outp.push(format!("{:?}", String::from_utf8_lossy(&text))); - } - } - } - format!("{:15} ⇦ {}", lhs, outp.join(" ")) -} diff --git a/controllers/aici_abi/src/earley/guidance.rs b/controllers/aici_abi/src/earley/guidance.rs deleted file mode 100644 index c6493592..00000000 --- a/controllers/aici_abi/src/earley/guidance.rs +++ /dev/null @@ -1,456 +0,0 @@ -// Automatically generated rust module for '_serialization.proto' file -// pb-rs _serialization.proto - -#![allow(non_snake_case)] -#![allow(non_upper_case_globals)] -#![allow(non_camel_case_types)] -#![allow(unused_imports)] -#![allow(unused_variables)] -#![allow(unknown_lints)] -#![allow(clippy::all)] -#![cfg_attr(rustfmt, rustfmt_skip)] - - -use std::borrow::Cow; -use std::collections::HashMap; -type KVMap = HashMap; 
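
The guidance module deleted below is machine-generated by pb-rs from _serialization.proto (per its header) and is consumed through quick-protobuf, as the bench helper removed above did. A minimal sketch of that read path, assuming the generated guidance module is in scope; the function name is hypothetical:

use quick_protobuf::{BytesReader, MessageRead};

// Decode a serialized guidance grammar. The generated types borrow from the
// input (their string/bytes fields are Cow), so `bytes` must outlive the
// returned message.
fn load_grammar<'a>(bytes: &'a [u8]) -> quick_protobuf::Result<guidance::Grammar<'a>> {
    let mut reader = BytesReader::from_bytes(bytes);
    guidance::Grammar::from_reader(&mut reader, bytes)
}

Regenerating the bindings after a .proto change is a matter of rerunning pb-rs _serialization.proto, which is why the file carries the do-not-edit header.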
-use quick_protobuf::{MessageInfo, MessageRead, MessageWrite, BytesReader, Writer, WriterBackend, Result}; -use quick_protobuf::sizeofs::*; -use super::*; - -#[allow(clippy::derive_partial_eq_without_eq)] -#[derive(Debug, Default, PartialEq, Clone)] -pub struct Grammar<'a> { - pub nodes: Vec>, -} - -impl<'a> MessageRead<'a> for Grammar<'a> { - fn from_reader(r: &mut BytesReader, bytes: &'a [u8]) -> Result { - let mut msg = Self::default(); - while !r.is_eof() { - match r.next_tag(bytes) { - Ok(10) => msg.nodes.push(r.read_message::(bytes)?), - Ok(t) => { r.read_unknown(bytes, t)?; } - Err(e) => return Err(e), - } - } - Ok(msg) - } -} - -impl<'a> MessageWrite for Grammar<'a> { - fn get_size(&self) -> usize { - 0 - + self.nodes.iter().map(|s| 1 + sizeof_len((s).get_size())).sum::() - } - - fn write_message(&self, w: &mut Writer) -> Result<()> { - for s in &self.nodes { w.write_with_tag(10, |w| w.write_message(s))?; } - Ok(()) - } -} - -#[allow(clippy::derive_partial_eq_without_eq)] -#[derive(Debug, Default, PartialEq, Clone)] -pub struct EngineCallResponse<'a> { - pub new_bytes: Cow<'a, [u8]>, - pub is_generated: bool, - pub new_bytes_prob: f32, - pub capture_groups: KVMap, Cow<'a, str>>, - pub capture_group_log_probs: KVMap, f32>, - pub new_token_count: i32, -} - -impl<'a> MessageRead<'a> for EngineCallResponse<'a> { - fn from_reader(r: &mut BytesReader, bytes: &'a [u8]) -> Result { - let mut msg = Self::default(); - while !r.is_eof() { - match r.next_tag(bytes) { - Ok(10) => msg.new_bytes = r.read_bytes(bytes).map(Cow::Borrowed)?, - Ok(16) => msg.is_generated = r.read_bool(bytes)?, - Ok(29) => msg.new_bytes_prob = r.read_float(bytes)?, - Ok(34) => { - let (key, value) = r.read_map(bytes, |r, bytes| Ok(r.read_string(bytes).map(Cow::Borrowed)?), |r, bytes| Ok(r.read_string(bytes).map(Cow::Borrowed)?))?; - msg.capture_groups.insert(key, value); - } - Ok(42) => { - let (key, value) = r.read_map(bytes, |r, bytes| Ok(r.read_string(bytes).map(Cow::Borrowed)?), |r, bytes| Ok(r.read_float(bytes)?))?; - msg.capture_group_log_probs.insert(key, value); - } - Ok(48) => msg.new_token_count = r.read_int32(bytes)?, - Ok(t) => { r.read_unknown(bytes, t)?; } - Err(e) => return Err(e), - } - } - Ok(msg) - } -} - -impl<'a> MessageWrite for EngineCallResponse<'a> { - fn get_size(&self) -> usize { - 0 - + if self.new_bytes == Cow::Borrowed(b"") { 0 } else { 1 + sizeof_len((&self.new_bytes).len()) } - + if self.is_generated == false { 0 } else { 1 + sizeof_varint(*(&self.is_generated) as u64) } - + if self.new_bytes_prob == 0f32 { 0 } else { 1 + 4 } - + self.capture_groups.iter().map(|(k, v)| 1 + sizeof_len(2 + sizeof_len((k).len()) + sizeof_len((v).len()))).sum::() - + self.capture_group_log_probs.iter().map(|(k, v)| 1 + sizeof_len(2 + sizeof_len((k).len()) + 4)).sum::() - + if self.new_token_count == 0i32 { 0 } else { 1 + sizeof_varint(*(&self.new_token_count) as u64) } - } - - fn write_message(&self, w: &mut Writer) -> Result<()> { - if self.new_bytes != Cow::Borrowed(b"") { w.write_with_tag(10, |w| w.write_bytes(&**&self.new_bytes))?; } - if self.is_generated != false { w.write_with_tag(16, |w| w.write_bool(*&self.is_generated))?; } - if self.new_bytes_prob != 0f32 { w.write_with_tag(29, |w| w.write_float(*&self.new_bytes_prob))?; } - for (k, v) in self.capture_groups.iter() { w.write_with_tag(34, |w| w.write_map(2 + sizeof_len((k).len()) + sizeof_len((v).len()), 10, |w| w.write_string(&**k), 18, |w| w.write_string(&**v)))?; } - for (k, v) in self.capture_group_log_probs.iter() { w.write_with_tag(42, |w| 
w.write_map(2 + sizeof_len((k).len()) + 4, 10, |w| w.write_string(&**k), 21, |w| w.write_float(*v)))?; } - if self.new_token_count != 0i32 { w.write_with_tag(48, |w| w.write_int32(*&self.new_token_count))?; } - Ok(()) - } -} - -#[allow(clippy::derive_partial_eq_without_eq)] -#[derive(Debug, Default, PartialEq, Clone)] -pub struct Byte<'a> { - pub byte: Cow<'a, [u8]>, - pub hidden: bool, - pub commit_point: bool, - pub nullable: bool, - pub capture_name: Cow<'a, str>, - pub temperature: f32, -} - -impl<'a> MessageRead<'a> for Byte<'a> { - fn from_reader(r: &mut BytesReader, bytes: &'a [u8]) -> Result { - let mut msg = Self::default(); - while !r.is_eof() { - match r.next_tag(bytes) { - Ok(10) => msg.byte = r.read_bytes(bytes).map(Cow::Borrowed)?, - Ok(16) => msg.hidden = r.read_bool(bytes)?, - Ok(24) => msg.commit_point = r.read_bool(bytes)?, - Ok(32) => msg.nullable = r.read_bool(bytes)?, - Ok(42) => msg.capture_name = r.read_string(bytes).map(Cow::Borrowed)?, - Ok(53) => msg.temperature = r.read_float(bytes)?, - Ok(t) => { r.read_unknown(bytes, t)?; } - Err(e) => return Err(e), - } - } - Ok(msg) - } -} - -impl<'a> MessageWrite for Byte<'a> { - fn get_size(&self) -> usize { - 0 - + if self.byte == Cow::Borrowed(b"") { 0 } else { 1 + sizeof_len((&self.byte).len()) } - + if self.hidden == false { 0 } else { 1 + sizeof_varint(*(&self.hidden) as u64) } - + if self.commit_point == false { 0 } else { 1 + sizeof_varint(*(&self.commit_point) as u64) } - + if self.nullable == false { 0 } else { 1 + sizeof_varint(*(&self.nullable) as u64) } - + if self.capture_name == "" { 0 } else { 1 + sizeof_len((&self.capture_name).len()) } - + if self.temperature == 0f32 { 0 } else { 1 + 4 } - } - - fn write_message(&self, w: &mut Writer) -> Result<()> { - if self.byte != Cow::Borrowed(b"") { w.write_with_tag(10, |w| w.write_bytes(&**&self.byte))?; } - if self.hidden != false { w.write_with_tag(16, |w| w.write_bool(*&self.hidden))?; } - if self.commit_point != false { w.write_with_tag(24, |w| w.write_bool(*&self.commit_point))?; } - if self.nullable != false { w.write_with_tag(32, |w| w.write_bool(*&self.nullable))?; } - if self.capture_name != "" { w.write_with_tag(42, |w| w.write_string(&**&self.capture_name))?; } - if self.temperature != 0f32 { w.write_with_tag(53, |w| w.write_float(*&self.temperature))?; } - Ok(()) - } -} - -#[allow(clippy::derive_partial_eq_without_eq)] -#[derive(Debug, Default, PartialEq, Clone)] -pub struct ByteRange<'a> { - pub byte_range: Cow<'a, [u8]>, - pub hidden: bool, - pub commit_point: bool, - pub capture_name: Cow<'a, str>, - pub temperature: f32, -} - -impl<'a> MessageRead<'a> for ByteRange<'a> { - fn from_reader(r: &mut BytesReader, bytes: &'a [u8]) -> Result { - let mut msg = Self::default(); - while !r.is_eof() { - match r.next_tag(bytes) { - Ok(10) => msg.byte_range = r.read_bytes(bytes).map(Cow::Borrowed)?, - Ok(24) => msg.hidden = r.read_bool(bytes)?, - Ok(32) => msg.commit_point = r.read_bool(bytes)?, - Ok(42) => msg.capture_name = r.read_string(bytes).map(Cow::Borrowed)?, - Ok(53) => msg.temperature = r.read_float(bytes)?, - Ok(t) => { r.read_unknown(bytes, t)?; } - Err(e) => return Err(e), - } - } - Ok(msg) - } -} - -impl<'a> MessageWrite for ByteRange<'a> { - fn get_size(&self) -> usize { - 0 - + if self.byte_range == Cow::Borrowed(b"") { 0 } else { 1 + sizeof_len((&self.byte_range).len()) } - + if self.hidden == false { 0 } else { 1 + sizeof_varint(*(&self.hidden) as u64) } - + if self.commit_point == false { 0 } else { 1 + sizeof_varint(*(&self.commit_point) as 
u64) } - + if self.capture_name == "" { 0 } else { 1 + sizeof_len((&self.capture_name).len()) } - + if self.temperature == 0f32 { 0 } else { 1 + 4 } - } - - fn write_message(&self, w: &mut Writer) -> Result<()> { - if self.byte_range != Cow::Borrowed(b"") { w.write_with_tag(10, |w| w.write_bytes(&**&self.byte_range))?; } - if self.hidden != false { w.write_with_tag(24, |w| w.write_bool(*&self.hidden))?; } - if self.commit_point != false { w.write_with_tag(32, |w| w.write_bool(*&self.commit_point))?; } - if self.capture_name != "" { w.write_with_tag(42, |w| w.write_string(&**&self.capture_name))?; } - if self.temperature != 0f32 { w.write_with_tag(53, |w| w.write_float(*&self.temperature))?; } - Ok(()) - } -} - -#[allow(clippy::derive_partial_eq_without_eq)] -#[derive(Debug, Default, PartialEq, Clone)] -pub struct Null { } - -impl<'a> MessageRead<'a> for Null { - fn from_reader(r: &mut BytesReader, _: &[u8]) -> Result { - r.read_to_end(); - Ok(Self::default()) - } -} - -impl MessageWrite for Null { } - -#[allow(clippy::derive_partial_eq_without_eq)] -#[derive(Debug, Default, PartialEq, Clone)] -pub struct ModelVariable<'a> { - pub name: Cow<'a, str>, - pub hidden: bool, - pub commit_point: bool, - pub capture_name: Cow<'a, str>, - pub nullable: bool, -} - -impl<'a> MessageRead<'a> for ModelVariable<'a> { - fn from_reader(r: &mut BytesReader, bytes: &'a [u8]) -> Result { - let mut msg = Self::default(); - while !r.is_eof() { - match r.next_tag(bytes) { - Ok(10) => msg.name = r.read_string(bytes).map(Cow::Borrowed)?, - Ok(16) => msg.hidden = r.read_bool(bytes)?, - Ok(24) => msg.commit_point = r.read_bool(bytes)?, - Ok(34) => msg.capture_name = r.read_string(bytes).map(Cow::Borrowed)?, - Ok(40) => msg.nullable = r.read_bool(bytes)?, - Ok(t) => { r.read_unknown(bytes, t)?; } - Err(e) => return Err(e), - } - } - Ok(msg) - } -} - -impl<'a> MessageWrite for ModelVariable<'a> { - fn get_size(&self) -> usize { - 0 - + if self.name == "" { 0 } else { 1 + sizeof_len((&self.name).len()) } - + if self.hidden == false { 0 } else { 1 + sizeof_varint(*(&self.hidden) as u64) } - + if self.commit_point == false { 0 } else { 1 + sizeof_varint(*(&self.commit_point) as u64) } - + if self.capture_name == "" { 0 } else { 1 + sizeof_len((&self.capture_name).len()) } - + if self.nullable == false { 0 } else { 1 + sizeof_varint(*(&self.nullable) as u64) } - } - - fn write_message(&self, w: &mut Writer) -> Result<()> { - if self.name != "" { w.write_with_tag(10, |w| w.write_string(&**&self.name))?; } - if self.hidden != false { w.write_with_tag(16, |w| w.write_bool(*&self.hidden))?; } - if self.commit_point != false { w.write_with_tag(24, |w| w.write_bool(*&self.commit_point))?; } - if self.capture_name != "" { w.write_with_tag(34, |w| w.write_string(&**&self.capture_name))?; } - if self.nullable != false { w.write_with_tag(40, |w| w.write_bool(*&self.nullable))?; } - Ok(()) - } -} - -#[allow(clippy::derive_partial_eq_without_eq)] -#[derive(Debug, Default, PartialEq, Clone)] -pub struct Join<'a> { - pub nullable: bool, - pub values: Vec, - pub name: Cow<'a, str>, - pub hidden: bool, - pub commit_point: bool, - pub capture_name: Cow<'a, str>, - pub max_tokens: i32, -} - -impl<'a> MessageRead<'a> for Join<'a> { - fn from_reader(r: &mut BytesReader, bytes: &'a [u8]) -> Result { - let mut msg = Self::default(); - while !r.is_eof() { - match r.next_tag(bytes) { - Ok(8) => msg.nullable = r.read_bool(bytes)?, - Ok(18) => msg.values = r.read_packed(bytes, |r, bytes| Ok(r.read_int32(bytes)?))?, - Ok(26) => msg.name = 
r.read_string(bytes).map(Cow::Borrowed)?, - Ok(32) => msg.hidden = r.read_bool(bytes)?, - Ok(40) => msg.commit_point = r.read_bool(bytes)?, - Ok(50) => msg.capture_name = r.read_string(bytes).map(Cow::Borrowed)?, - Ok(56) => msg.max_tokens = r.read_int32(bytes)?, - Ok(t) => { r.read_unknown(bytes, t)?; } - Err(e) => return Err(e), - } - } - Ok(msg) - } -} - -impl<'a> MessageWrite for Join<'a> { - fn get_size(&self) -> usize { - 0 - + if self.nullable == false { 0 } else { 1 + sizeof_varint(*(&self.nullable) as u64) } - + if self.values.is_empty() { 0 } else { 1 + sizeof_len(self.values.iter().map(|s| sizeof_varint(*(s) as u64)).sum::()) } - + if self.name == "" { 0 } else { 1 + sizeof_len((&self.name).len()) } - + if self.hidden == false { 0 } else { 1 + sizeof_varint(*(&self.hidden) as u64) } - + if self.commit_point == false { 0 } else { 1 + sizeof_varint(*(&self.commit_point) as u64) } - + if self.capture_name == "" { 0 } else { 1 + sizeof_len((&self.capture_name).len()) } - + if self.max_tokens == 0i32 { 0 } else { 1 + sizeof_varint(*(&self.max_tokens) as u64) } - } - - fn write_message(&self, w: &mut Writer) -> Result<()> { - if self.nullable != false { w.write_with_tag(8, |w| w.write_bool(*&self.nullable))?; } - w.write_packed_with_tag(18, &self.values, |w, m| w.write_int32(*m), &|m| sizeof_varint(*(m) as u64))?; - if self.name != "" { w.write_with_tag(26, |w| w.write_string(&**&self.name))?; } - if self.hidden != false { w.write_with_tag(32, |w| w.write_bool(*&self.hidden))?; } - if self.commit_point != false { w.write_with_tag(40, |w| w.write_bool(*&self.commit_point))?; } - if self.capture_name != "" { w.write_with_tag(50, |w| w.write_string(&**&self.capture_name))?; } - if self.max_tokens != 0i32 { w.write_with_tag(56, |w| w.write_int32(*&self.max_tokens))?; } - Ok(()) - } -} - -#[allow(clippy::derive_partial_eq_without_eq)] -#[derive(Debug, Default, PartialEq, Clone)] -pub struct Select<'a> { - pub nullable: bool, - pub values: Vec, - pub name: Cow<'a, str>, - pub hidden: bool, - pub commit_point: bool, - pub capture_name: Cow<'a, str>, - pub max_tokens: i32, - pub recursive: bool, -} - -impl<'a> MessageRead<'a> for Select<'a> { - fn from_reader(r: &mut BytesReader, bytes: &'a [u8]) -> Result { - let mut msg = Self::default(); - while !r.is_eof() { - match r.next_tag(bytes) { - Ok(8) => msg.nullable = r.read_bool(bytes)?, - Ok(18) => msg.values = r.read_packed(bytes, |r, bytes| Ok(r.read_int32(bytes)?))?, - Ok(26) => msg.name = r.read_string(bytes).map(Cow::Borrowed)?, - Ok(32) => msg.hidden = r.read_bool(bytes)?, - Ok(40) => msg.commit_point = r.read_bool(bytes)?, - Ok(50) => msg.capture_name = r.read_string(bytes).map(Cow::Borrowed)?, - Ok(56) => msg.max_tokens = r.read_int32(bytes)?, - Ok(64) => msg.recursive = r.read_bool(bytes)?, - Ok(t) => { r.read_unknown(bytes, t)?; } - Err(e) => return Err(e), - } - } - Ok(msg) - } -} - -impl<'a> MessageWrite for Select<'a> { - fn get_size(&self) -> usize { - 0 - + if self.nullable == false { 0 } else { 1 + sizeof_varint(*(&self.nullable) as u64) } - + if self.values.is_empty() { 0 } else { 1 + sizeof_len(self.values.iter().map(|s| sizeof_varint(*(s) as u64)).sum::()) } - + if self.name == "" { 0 } else { 1 + sizeof_len((&self.name).len()) } - + if self.hidden == false { 0 } else { 1 + sizeof_varint(*(&self.hidden) as u64) } - + if self.commit_point == false { 0 } else { 1 + sizeof_varint(*(&self.commit_point) as u64) } - + if self.capture_name == "" { 0 } else { 1 + sizeof_len((&self.capture_name).len()) } - + if self.max_tokens == 
0i32 { 0 } else { 1 + sizeof_varint(*(&self.max_tokens) as u64) } - + if self.recursive == false { 0 } else { 1 + sizeof_varint(*(&self.recursive) as u64) } - } - - fn write_message(&self, w: &mut Writer) -> Result<()> { - if self.nullable != false { w.write_with_tag(8, |w| w.write_bool(*&self.nullable))?; } - w.write_packed_with_tag(18, &self.values, |w, m| w.write_int32(*m), &|m| sizeof_varint(*(m) as u64))?; - if self.name != "" { w.write_with_tag(26, |w| w.write_string(&**&self.name))?; } - if self.hidden != false { w.write_with_tag(32, |w| w.write_bool(*&self.hidden))?; } - if self.commit_point != false { w.write_with_tag(40, |w| w.write_bool(*&self.commit_point))?; } - if self.capture_name != "" { w.write_with_tag(50, |w| w.write_string(&**&self.capture_name))?; } - if self.max_tokens != 0i32 { w.write_with_tag(56, |w| w.write_int32(*&self.max_tokens))?; } - if self.recursive != false { w.write_with_tag(64, |w| w.write_bool(*&self.recursive))?; } - Ok(()) - } -} - -#[allow(clippy::derive_partial_eq_without_eq)] -#[derive(Debug, Default, PartialEq, Clone)] -pub struct GrammarFunction<'a> { - pub function_type: guidance::mod_GrammarFunction::OneOffunction_type<'a>, -} - -impl<'a> MessageRead<'a> for GrammarFunction<'a> { - fn from_reader(r: &mut BytesReader, bytes: &'a [u8]) -> Result { - let mut msg = Self::default(); - while !r.is_eof() { - match r.next_tag(bytes) { - Ok(10) => msg.function_type = guidance::mod_GrammarFunction::OneOffunction_type::join(r.read_message::(bytes)?), - Ok(18) => msg.function_type = guidance::mod_GrammarFunction::OneOffunction_type::select(r.read_message::(bytes)?), - Ok(26) => msg.function_type = guidance::mod_GrammarFunction::OneOffunction_type::byte(r.read_message::(bytes)?), - Ok(34) => msg.function_type = guidance::mod_GrammarFunction::OneOffunction_type::byte_range(r.read_message::(bytes)?), - Ok(42) => msg.function_type = guidance::mod_GrammarFunction::OneOffunction_type::model_variable(r.read_message::(bytes)?), - Ok(t) => { r.read_unknown(bytes, t)?; } - Err(e) => return Err(e), - } - } - Ok(msg) - } -} - -impl<'a> MessageWrite for GrammarFunction<'a> { - fn get_size(&self) -> usize { - 0 - + match self.function_type { - guidance::mod_GrammarFunction::OneOffunction_type::join(ref m) => 1 + sizeof_len((m).get_size()), - guidance::mod_GrammarFunction::OneOffunction_type::select(ref m) => 1 + sizeof_len((m).get_size()), - guidance::mod_GrammarFunction::OneOffunction_type::byte(ref m) => 1 + sizeof_len((m).get_size()), - guidance::mod_GrammarFunction::OneOffunction_type::byte_range(ref m) => 1 + sizeof_len((m).get_size()), - guidance::mod_GrammarFunction::OneOffunction_type::model_variable(ref m) => 1 + sizeof_len((m).get_size()), - guidance::mod_GrammarFunction::OneOffunction_type::None => 0, - } } - - fn write_message(&self, w: &mut Writer) -> Result<()> { - match self.function_type { guidance::mod_GrammarFunction::OneOffunction_type::join(ref m) => { w.write_with_tag(10, |w| w.write_message(m))? }, - guidance::mod_GrammarFunction::OneOffunction_type::select(ref m) => { w.write_with_tag(18, |w| w.write_message(m))? }, - guidance::mod_GrammarFunction::OneOffunction_type::byte(ref m) => { w.write_with_tag(26, |w| w.write_message(m))? }, - guidance::mod_GrammarFunction::OneOffunction_type::byte_range(ref m) => { w.write_with_tag(34, |w| w.write_message(m))? }, - guidance::mod_GrammarFunction::OneOffunction_type::model_variable(ref m) => { w.write_with_tag(42, |w| w.write_message(m))? 
}, - guidance::mod_GrammarFunction::OneOffunction_type::None => {}, - } Ok(()) - } -} - -pub mod mod_GrammarFunction { - -use super::*; - -#[derive(Debug, PartialEq, Clone)] -pub enum OneOffunction_type<'a> { - join(guidance::Join<'a>), - select(guidance::Select<'a>), - byte(guidance::Byte<'a>), - byte_range(guidance::ByteRange<'a>), - model_variable(guidance::ModelVariable<'a>), - None, -} - -impl<'a> Default for OneOffunction_type<'a> { - fn default() -> Self { - OneOffunction_type::None - } -} - -} - diff --git a/controllers/aici_abi/src/earley/mod.rs b/controllers/aici_abi/src/earley/mod.rs deleted file mode 100644 index 105c4d32..00000000 --- a/controllers/aici_abi/src/earley/mod.rs +++ /dev/null @@ -1,12 +0,0 @@ -mod byteset; -mod grammar; -mod parser; - -pub use byteset::ByteSet; -pub use parser::Parser; -pub use grammar::Grammar; - -#[cfg(not(target_arch = "wasm32"))] -mod guidance; -#[cfg(not(target_arch = "wasm32"))] -pub mod bench; diff --git a/controllers/aici_abi/src/earley/parser.rs b/controllers/aici_abi/src/earley/parser.rs deleted file mode 100644 index 740f97e0..00000000 --- a/controllers/aici_abi/src/earley/parser.rs +++ /dev/null @@ -1,364 +0,0 @@ -use std::{fmt::Debug, hash::Hash, ops::Range, vec}; - -use crate::toktree::{Recognizer, SpecialToken}; - -use super::grammar::{CGrammar, CSymIdx, RuleIdx, SimpleHash}; - -const DEBUG: bool = false; - -// this may speed up more complex grammar but slows down simple ones (by 10%) -const PREDICTED_SYM_FILTER: bool = false; - -#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] -struct Item { - data: u64, -} - -#[derive(Debug, Default)] -pub struct Stats { - pub rows: usize, - pub empty_rows: usize, - pub nontrivial_scans: usize, - pub scan_items: usize, - pub all_items: usize, -} - -#[derive(Debug, Clone, Copy, PartialEq, Eq)] -pub enum ParseResult { - Accept, - Reject, - Continue, -} - -struct Row { - first_item: usize, - last_item: usize, -} - -impl Row { - fn item_indices(&self) -> Range { - self.first_item..self.last_item - } -} - -impl Item { - fn new(rule: RuleIdx, start: usize) -> Self { - Item { - data: rule.as_index() as u64 | ((start as u64) << 32), - } - } - - fn rule_idx(&self) -> RuleIdx { - RuleIdx::from_index(self.data as u32) - } - - fn start_pos(&self) -> usize { - (self.data >> 32) as usize - } - - fn advance_dot(&self) -> Self { - Item { - data: self.data + 1, - } - } -} - -impl SimpleHash for Item { - fn simple_hash(&self) -> u32 { - (self.rule_idx().as_index() as u32) - .wrapping_mul(16315967) - .wrapping_add((self.start_pos() as u32).wrapping_mul(33398653)) - } -} - -struct SimpleSet { - hash: u64, - items: Vec, -} - -impl Default for SimpleSet { - fn default() -> Self { - SimpleSet { - hash: 0, - items: vec![], - } - } -} - -impl SimpleSet { - fn clear(&mut self) { - self.hash = 0; - self.items.clear(); - } - - #[inline(always)] - fn insert(&mut self, item: T) { - let mask = item.mask64(); - if (self.hash & mask) != 0 && self.items.contains(&item) { - return; - } - self.hash |= mask; - self.items.push(item); - } - - #[inline(always)] - fn contains(&self, item: T) -> bool { - if (item.mask64() & self.hash) == 0 { - false - } else { - self.items.contains(&item) - } - } - - #[inline(always)] - fn should_insert(&mut self, item: T) -> bool { - if !PREDICTED_SYM_FILTER { - true - } else { - if self.contains(item) { - false - } else { - self.insert(item); - true - } - } - } -} - -#[derive(Default)] -struct Scratch { - row_start: usize, - row_end: usize, - items: Vec, - predicated_syms: SimpleSet, -} - -pub struct 
Parser { - grammar: CGrammar, - scratch: Scratch, - rows: Vec, - stats: Stats, - is_accepting: bool, - last_collapse: usize, -} - -impl Scratch { - fn new_row(&mut self, pos: usize) { - self.row_start = pos; - self.row_end = pos; - } - - fn row_len(&self) -> usize { - self.row_end - self.row_start - } - - #[inline(always)] - fn ensure_items(&mut self, n: usize) { - if self.items.len() < n { - let missing = n - self.items.len(); - self.items.reserve(missing); - unsafe { self.items.set_len(n) } - } - } - - #[inline(always)] - fn just_add(&mut self, item: Item) { - self.ensure_items(self.row_end + 1); - // SAFETY: we just ensured that there is enough space - unsafe { - self.items.as_mut_ptr().add(self.row_end).write(item); - } - // self.items[self.row_end] = item; - self.row_end += 1; - } - - #[inline(always)] - fn add_unique(&mut self, item: Item, _info: &str) { - if !self.items[self.row_start..self.row_end].contains(&item) { - self.just_add(item); - } - } -} - -impl Parser { - pub fn new(grammar: CGrammar) -> Self { - let start = grammar.start(); - let mut r = Parser { - grammar, - rows: vec![], - scratch: Scratch::default(), - stats: Stats::default(), - is_accepting: false, - last_collapse: 0, - }; - for rule in r.grammar.rules_of(start).to_vec() { - r.scratch.add_unique(Item::new(rule, 0), "init"); - } - let _ = r.push_row(); - r - } - - pub fn is_accepting(&self) -> bool { - self.is_accepting - } - - fn item_to_string(&self, item: &Item) -> String { - format!( - "{} @{}", - self.grammar.rule_to_string(item.rule_idx()), - item.start_pos() - ) - } - - pub fn print_row(&self, row_idx: usize) { - let row = &self.rows[row_idx]; - println!("row {}", row_idx); - for i in row.item_indices() { - println!("{}", self.item_to_string(&self.scratch.items[i])); - } - } - - pub fn num_rows(&self) -> usize { - self.rows.len() - } - - #[inline(always)] - pub fn scan(&mut self, b: u8) -> ParseResult { - let row_idx = self.rows.len() - 1; - let last = self.rows[row_idx].last_item; - let mut i = self.rows[row_idx].first_item; - let n = last - i; - self.scratch.ensure_items(last + n + 100); - - let allowed = self.grammar.terminals_by_byte(b); - - self.scratch.new_row(last); - - while i < last { - let item = self.scratch.items[i]; - let idx = self.grammar.sym_idx_at(item.rule_idx()).as_index(); - // idx == 0 => completed - if idx < allowed.len() && allowed[idx] { - self.scratch.just_add(item.advance_dot()); - } - i += 1; - } - self.push_row() - } - - pub fn pop_rows(&mut self, n: usize) { - unsafe { self.rows.set_len(self.rows.len() - n) } - // self.rows.drain(self.rows.len() - n..); - } - - pub fn print_stats(&mut self) { - println!("stats: {:?}", self.stats); - self.stats = Stats::default(); - } - - #[inline(always)] - fn push_row(&mut self) -> ParseResult { - let curr_idx = self.rows.len(); - let mut agenda_ptr = self.scratch.row_start; - - self.scratch.predicated_syms.clear(); - - self.stats.rows += 1; - self.is_accepting = false; - - while agenda_ptr < self.scratch.row_end { - let item = self.scratch.items[agenda_ptr]; - agenda_ptr += 1; - if DEBUG { - println!("from agenda: {}", self.item_to_string(&item)); - } - - let rule = item.rule_idx(); - let after_dot = self.grammar.sym_idx_at(rule); - - if after_dot == CSymIdx::NULL { - let lhs = self.grammar.sym_idx_of(item.rule_idx()); - // complete - self.is_accepting = self.is_accepting || lhs == self.grammar.start(); - - if item.start_pos() < curr_idx { - // if item.start_pos() == curr_idx, then we handled it above in the nullable check - for i in 
self.rows[item.start_pos()].item_indices() { - let item = self.scratch.items[i]; - if self.grammar.sym_idx_at(item.rule_idx()) == lhs { - self.scratch.add_unique(item.advance_dot(), "complete"); - } - } - } - } else { - let sym_data = self.grammar.sym_data(after_dot); - if sym_data.is_nullable { - self.scratch.add_unique(item.advance_dot(), "null"); - } - if self.scratch.predicated_syms.should_insert(after_dot) { - for rule in &sym_data.rules { - let new_item = Item::new(*rule, curr_idx); - self.scratch.add_unique(new_item, "predict"); - } - } - } - } - - let row_len = self.scratch.row_len(); - self.stats.all_items += row_len; - - if row_len == 0 { - assert!(!self.is_accepting); - return ParseResult::Reject; - } - - self.rows.push(Row { - first_item: self.scratch.row_start, - last_item: self.scratch.row_end, - }); - - if self.is_accepting { - ParseResult::Accept - } else { - ParseResult::Continue - } - } -} - -impl Recognizer for Parser { - fn pop_bytes(&mut self, num: usize) { - self.pop_rows(num); - } - - fn collapse(&mut self) { - // this actually means "commit" - can no longer backtrack past this point - - if false { - for idx in self.last_collapse..self.num_rows() { - self.print_row(idx); - } - } - self.last_collapse = self.num_rows(); - } - - fn special_allowed(&mut self, tok: SpecialToken) -> bool { - if tok == SpecialToken::EndOfSentence { - self.is_accepting() - } else { - false - } - } - - fn trie_finished(&mut self) { - // do nothing? - } - - fn try_push_byte(&mut self, byte: u8) -> bool { - let res = self.scan(byte); - if res == ParseResult::Reject { - false - } else { - true - } - } -} diff --git a/controllers/aici_abi/src/lib.rs b/controllers/aici_abi/src/lib.rs index 9307df1a..ae6a418f 100644 --- a/controllers/aici_abi/src/lib.rs +++ b/controllers/aici_abi/src/lib.rs @@ -16,9 +16,6 @@ mod lex; #[cfg(feature = "rx")] pub mod rx; -#[cfg(feature = "earley")] -pub mod earley; - pub mod substring; pub type TokenId = bytes::TokenId; From de384d49f49c63ebfd5b979a756505c433460a79 Mon Sep 17 00:00:00 2001 From: Michal Moskal Date: Thu, 14 Mar 2024 21:51:26 +0000 Subject: [PATCH 176/301] link modules in gctrl --- controllers/aici_abi/Cargo.toml | 3 --- 1 file changed, 3 deletions(-) diff --git a/controllers/aici_abi/Cargo.toml b/controllers/aici_abi/Cargo.toml index 79ce484f..638748d8 100644 --- a/controllers/aici_abi/Cargo.toml +++ b/controllers/aici_abi/Cargo.toml @@ -19,9 +19,6 @@ lrtable = { version = "0.13.3", optional = true } vob = { version = "3.0.3", optional = true } rustc-hash = { version = "1.1.0", optional = true } -[target.'cfg(not(target_arch = "wasm32"))'.dependencies] -quick-protobuf = { version = "0.8.1" } - [features] default = ["cfg", "rx"] cfg = ["dep:cfgrammar", "dep:lrlex", "dep:lrpar", "dep:lrtable", "dep:vob", "dep:rustc-hash"] From 6cc60f68a23a69dd26af77ecbd05a47d8fcb3ea2 Mon Sep 17 00:00:00 2001 From: Michal Moskal Date: Thu, 14 Mar 2024 22:07:52 +0000 Subject: [PATCH 177/301] organize files --- controllers/aici_abi/annotate_asm.js | 103 --------------------------- controllers/aici_abi/disasm.sh | 17 ----- 2 files changed, 120 deletions(-) delete mode 100644 controllers/aici_abi/annotate_asm.js delete mode 100755 controllers/aici_abi/disasm.sh diff --git a/controllers/aici_abi/annotate_asm.js b/controllers/aici_abi/annotate_asm.js deleted file mode 100644 index 26e1a260..00000000 --- a/controllers/aici_abi/annotate_asm.js +++ /dev/null @@ -1,103 +0,0 @@ -const child_process = require("child_process") -const fs = require("fs") - -const sysroot = 
child_process.execSync("rustc --print sysroot").toString().trim() - -function main(sname, filter) { - if (!filter) { - console.error("please pass filter arg") - return - } - - const sections = {} - const files = [] - let idx = 0 - for (const sect of fs.readFileSync(sname, "utf8").split("\n\n")) { - idx++ - let sectId = "sect" + idx - let m = /^\t\.type\t(.*),@/m.exec(sect) - if (m) { - sectId = m[1] - } - - let outp = "" - for (const line of sect.split("\n")) { - if (line.startsWith(".Ltmp") || line.startsWith("\t.cfi_")) - continue - if (line.startsWith("\t.file\t")) { - m = /(\d+)\s+"([^"]+)"\s+"([^"]+)"/.exec(line) - if (!m) { - // console.error("Bad file line", line) - } else { - const folder = m[2].replace(/^\/rustc\/[^/]+/, sysroot + "/lib/rustlib/src/rust") - files[+m[1]] = folder + "/" + m[3] - } - continue - } - outp += line + "\n" - } - - sections[sectId] = outp - } - - const keys = Object.keys(sections).filter(k => k.includes(filter)) - if (keys.length > 1) { - const max = 50 - console.error("Multiple sections found for filter", filter, keys.slice(0, max).join("\n")) - if (keys.length > max) { - console.error("...") - } - return - } - if (keys.length === 0) { - console.error("No sections found for filter", filter) - return - } - - const filecontent = [] - - function fileLines(id) { - if (filecontent[id]) { - return filecontent[id] - } - const lines = fs.readFileSync(files[id], "utf8").split("\n") - filecontent[id] = lines - return lines - } - - let outp = "" - const labels = {} - for (let line of sections[keys[0]].split("\n")) { - if (line.startsWith("\t.loc\t")) { - const m = /\t.loc\t(\d+)\s+(\d+)\s+(\d+)/.exec(line) - const lineno = +m[2] - const lines = fileLines(+m[1]) - const filename = files[+m[1]] - let basename = filename.split("/").pop() - if (filename.startsWith(sysroot)) - basename = "[lib]" + basename - // outp += "// file://" + files[+m[1]] + "\n" - if (lines[lineno - 1] !== undefined) { - const tag = basename + ":" + lineno - outp += "// " + tag.padEnd(40, " ") + lines[lineno - 1] + "\n" - } - } else { - const m = /^(\.L\w+):/.exec(line) - if (m) { - labels[m[1]] = true - } - const words = line.split(/\s+/) - if (words.some(w => labels[w])) { - line += " // ===============================================> BACK" - } - outp += line + "\n" - } - } - - console.log("Section", keys[0], ":") - console.log(outp) -} - - -const args = process.argv.slice(2) -main(...args) diff --git a/controllers/aici_abi/disasm.sh b/controllers/aici_abi/disasm.sh deleted file mode 100755 index 41f9fb17..00000000 --- a/controllers/aici_abi/disasm.sh +++ /dev/null @@ -1,17 +0,0 @@ -#!/bin/sh - -TRG=`rustup show | head -1 | sed -e 's/.*: //'` -CRATE=`grep "^name =" Cargo.toml | head -1 | sed -e 's/.*= "//; s/"//'` -RUSTFLAGS="--emit asm" cargo build --release --target $TRG -F=`echo ../../target/$TRG/release/deps/$CRATE-*.s` -# if $F has more than one file -if [ `echo $F | wc -w` -gt 1 ]; then - echo "More than one file found: $F; removing; try again" - rm -f $F - exit 1 -fi - -mkdir -p tmp -cp $F tmp/full.s -node annotate_asm.js tmp/full.s "$@" | rustfilt > tmp/func.s -ls -l tmp/func.s From d836bff65dc6a9cbedeec822e4be692c9265361c Mon Sep 17 00:00:00 2001 From: Michal Moskal Date: Fri, 15 Mar 2024 18:02:40 +0000 Subject: [PATCH 178/301] token-forcing --- controllers/aici_abi/src/toktree.rs | 26 ++++++++++++++++++++++++-- 1 file changed, 24 insertions(+), 2 deletions(-) diff --git a/controllers/aici_abi/src/toktree.rs b/controllers/aici_abi/src/toktree.rs index 0c6a0add..965e0b2d 100644 --- 
a/controllers/aici_abi/src/toktree.rs
+++ b/controllers/aici_abi/src/toktree.rs
@@ -52,6 +52,7 @@ pub struct TokTrie {
     token_offsets: Vec<u32>,
     token_data: Vec<u8>,
     nodes: Vec<TrieNode>,
+    max_token_len: usize,
 }
 
 #[repr(C)]
@@ -134,12 +135,17 @@ impl TokTrie {
         }
         let mut nodes = Vec::new();
         trie.serialize(&mut nodes, 0);
-        let r = TokTrie {
+        let mut r = TokTrie {
             info: info.clone(),
             token_offsets,
             token_data,
             nodes,
+            max_token_len: 0,
         };
+        r.max_token_len = (0..info.vocab_size)
+            .map(|idx| r.token(idx).len())
+            .max()
+            .unwrap();
         r.validate();
         r
     }
@@ -268,6 +274,13 @@ impl TokTrie {
         r
     }
 
+    pub fn has_extensions(&self, bytes: &[u8]) -> bool {
+        match self.child_at_bytes(self.root(), bytes) {
+            None => false,
+            Some(n) => n.subtree_size() > 1,
+        }
+    }
+
     pub fn token_id(&self, bytes: &[u8]) -> Option<TokenId> {
         let (tok, len) = self.prefix_token_id(bytes);
         // println!("tok_id {:?} {:?} {:?} ", bytes, tok, len);
@@ -306,16 +319,25 @@ impl TokTrie {
         let token_offsets = vec_from_bytes(&bytes[trie_end..offsets_end]);
         let token_data = vec_from_bytes(&bytes[offsets_end..]);
 
-        let r = TokTrie {
+        let mut r = TokTrie {
             info: hd.info,
             token_offsets,
             token_data,
             nodes,
+            max_token_len: 0,
         };
         r.validate();
+        r.max_token_len = (0..r.info.vocab_size)
+            .map(|idx| r.token(idx).len())
+            .max()
+            .unwrap();
         r
     }
 
+    pub fn max_token_len(&self) -> usize {
+        self.max_token_len
+    }
+
     fn validate_node(&self, n: &TrieNode, ep: usize, used: &mut [bool]) {
         if let Some(tok) = n.token_id() {
            assert!(tok < self.info.vocab_size);

From bb451eb40831121e1ce7fcf56817399238a418d1 Mon Sep 17 00:00:00 2001
From: Michal Moskal
Date: Fri, 15 Mar 2024 20:37:46 +0000
Subject: [PATCH 179/301] token healing

---
 controllers/aici_abi/src/toktree.rs | 17 +++++++++++++++++
 1 file changed, 17 insertions(+)

diff --git a/controllers/aici_abi/src/toktree.rs b/controllers/aici_abi/src/toktree.rs
index 965e0b2d..7d8b6e5f 100644
--- a/controllers/aici_abi/src/toktree.rs
+++ b/controllers/aici_abi/src/toktree.rs
@@ -210,6 +210,23 @@ impl TokTrie {
         vec![0.0; self.vocab_size() + 1]
     }
 
+    pub fn tokens_dbg(&self, toks: &[u32]) -> String {
+        format!(
+            "\"{}\"",
+            toks.iter()
+                .map(|t| {
+                    let s = self.token_dbg(*t);
+                    if s.starts_with("\"") {
+                        s[1..s.len() - 1].to_string()
+                    } else {
+                        format!("<{}>", s)
+                    }
+                })
+                .collect::<Vec<_>>()
+                .join("‿"),
+        )
+    }
+
     pub fn token_dbg(&self, idx: u32) -> String {
         if idx == self.info.tok_eos {
             "EOS".to_string()

From bbe951c3a19684c49dc2038ba69a4b8c5bad3277 Mon Sep 17 00:00:00 2001
From: Michal Moskal
Date: Fri, 15 Mar 2024 22:30:12 +0000
Subject: [PATCH 180/301] better token debugging in TokTrie

---
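Note: the rewritten tokens_dbg() below prepends a visible "‧" separator to every
token string and then trims the leading one. A rough standalone sketch of that
rendering idea, with toy token strings rather than the real TokTrie API:

    fn render_dbg(tokens: &[&str], sep: &str) -> String {
        // prepend the separator to every token, then drop the leading one
        let joined: String = tokens.iter().map(|t| format!("{}{}", sep, t)).collect();
        format!("\"{}\"", joined.trim_start_matches(sep))
    }

    fn main() {
        assert_eq!(render_dbg(&["Hello", " world"], "‧"), "\"Hello‧ world\"");
    }
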
 controllers/aici_abi/src/toktree.rs | 33 +++++++++++++++++++-----------
 1 file changed, 21 insertions(+), 12 deletions(-)

diff --git a/controllers/aici_abi/src/toktree.rs b/controllers/aici_abi/src/toktree.rs
index 7d8b6e5f..15fbc1af 100644
--- a/controllers/aici_abi/src/toktree.rs
+++ b/controllers/aici_abi/src/toktree.rs
@@ -211,20 +211,29 @@ impl TokTrie {
     }
 
     pub fn tokens_dbg(&self, toks: &[u32]) -> String {
-        format!(
-            "\"{}\"",
-            toks.iter()
-                .map(|t| {
-                    let s = self.token_dbg(*t);
-                    if s.starts_with("\"") {
-                        s[1..s.len() - 1].to_string()
+        let minimal = false;
+        let sep = "‧";
+        let joined = toks
+            .iter()
+            .map(|t| {
+                let s = self.token_dbg(*t);
+                if s.starts_with("\"") {
+                    let inner = s[1..s.len() - 1].to_string();
+                    let b = s.as_bytes();
+                    // for " [\w]..." and " " the sep in front is implicit
+                    if minimal && b[1] == b' ' && ((b[2] as char).is_alphanumeric() || b.len() == 3)
+                    {
+                        inner
                     } else {
-                        format!("<{}>", s)
+                        format!("{}{}", sep, inner)
                     }
-                })
-                .collect::<Vec<_>>()
-                .join("‿"),
-        )
+                } else {
+                    format!("≺{}≻", s)
+                }
+            })
+            .collect::<Vec<_>>()
+            .join("");
+        format!("\"{}\"", joined.trim_start_matches(sep))
     }
 
     pub fn token_dbg(&self, idx: u32) -> String {

From 814342f0883929e0c2cbf4cdff7151811d0e9728 Mon Sep 17 00:00:00 2001
From: Michal Moskal
Date: Mon, 18 Mar 2024 17:56:02 +0000
Subject: [PATCH 181/301] support for model variables

---
 controllers/aici_abi/src/toktree.rs | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/controllers/aici_abi/src/toktree.rs b/controllers/aici_abi/src/toktree.rs
index 15fbc1af..5e48878c 100644
--- a/controllers/aici_abi/src/toktree.rs
+++ b/controllers/aici_abi/src/toktree.rs
@@ -9,7 +9,7 @@ use crate::{
     svob::SimpleVob,
 };
 
-#[derive(Clone, Copy, Debug, PartialEq, Eq)]
+#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
 pub enum SpecialToken {
     Unknown,
     Padding,

From 1c4d60bab41925acbad4f714cf315f254e2b52d2 Mon Sep 17 00:00:00 2001
From: Michal Moskal
Date: Mon, 18 Mar 2024 21:30:56 +0000
Subject: [PATCH 182/301] account for duplicate tokens; see #78

---
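Note: the core idea of this change, as a tiny self-contained sketch (toy token
ids and a plain HashMap in place of the FxHashMap used below): when the
canonical token for some byte sequence is allowed, every duplicate token with
the same bytes must be allowed too.

    use std::collections::HashMap;

    fn apply_duplicates(allowed: &mut [bool], duplicates: &HashMap<u32, Vec<u32>>) {
        for (tok, dups) in duplicates {
            if allowed[*tok as usize] {
                for d in dups {
                    allowed[*d as usize] = true;
                }
            }
        }
    }

    fn main() {
        // assume tokens 3 and 7 decode to the same bytes, 3 being canonical
        let mut allowed = vec![false; 10];
        allowed[3] = true;
        let duplicates = HashMap::from([(3u32, vec![7u32])]);
        apply_duplicates(&mut allowed, &duplicates);
        assert!(allowed[7]);
    }
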
 controllers/aici_abi/src/svob.rs    | 10 +++
 controllers/aici_abi/src/toktree.rs | 94 +++++++++++++++++++++--------
 2 files changed, 79 insertions(+), 25 deletions(-)

diff --git a/controllers/aici_abi/src/svob.rs b/controllers/aici_abi/src/svob.rs
index 6513080d..c5fd1997 100644
--- a/controllers/aici_abi/src/svob.rs
+++ b/controllers/aici_abi/src/svob.rs
@@ -41,6 +41,16 @@ impl SimpleVob {
         self.data.iter().map(|x| x.count_ones() as usize).sum()
     }
 
+    pub fn negated(&self, size: usize) -> Self {
+        let mut r = Self::new();
+        r.data = self.data.iter().map(|x| !x).collect();
+        for i in size..r.len() {
+            // disallow tokens that are out of range
+            r.disallow_token(i as TokenId);
+        }
+        r
+    }
+
     pub unsafe fn as_ptr(&self) -> *const u32 {
         self.data.as_ptr()
     }
diff --git a/controllers/aici_abi/src/toktree.rs b/controllers/aici_abi/src/toktree.rs
index 5e48878c..4c87b206 100644
--- a/controllers/aici_abi/src/toktree.rs
+++ b/controllers/aici_abi/src/toktree.rs
@@ -1,6 +1,8 @@
 // use 8:24 encoding - num_ch:tok_id (ch_byte:ch_off)* - 8 bytes per tree node
 // special case num_ch=0xff -> num_ch=0x100
 
+use rustc_hash::FxHashMap;
+
 use crate::{
     bytes::{
         box_from_bytes, clone_as_bytes, clone_vec_as_bytes, vec_from_bytes, TokRxInfo, TokenId,
@@ -53,6 +55,7 @@ pub struct TokTrie {
     token_data: Vec<u8>,
     nodes: Vec<TrieNode>,
     max_token_len: usize,
+    token_duplicates: FxHashMap<TokenId, Vec<TokenId>>,
 }
 
 #[repr(C)]
@@ -126,7 +129,7 @@ impl TokTrie {
         assert!(info.vocab_size == words.len() as u32);
         for (idx, word) in words.iter().enumerate() {
             if word.len() > 0 {
-                trie.insert(word, idx as u32)
+                trie.insert(word, idx as u32);
             }
             assert!(word.len() < 0xff);
             let desc = (word.len() as u32) | ((token_data.len() as u32) << 8);
@@ -141,15 +144,27 @@ impl TokTrie {
             token_data,
             nodes,
             max_token_len: 0,
+            token_duplicates: FxHashMap::default(),
         };
-        r.max_token_len = (0..info.vocab_size)
-            .map(|idx| r.token(idx).len())
-            .max()
-            .unwrap();
-        r.validate();
+        r.finalize_ctor();
         r
     }
 
+    fn finalize_ctor(&mut self) {
+        for tok_id in 0..self.info.vocab_size {
+            let bytes = self.token(tok_id);
+            let tok_ids = self.greedy_tokenize(bytes);
+            self.max_token_len = std::cmp::max(self.max_token_len, bytes.len());
+            if tok_ids.len() == 1 && tok_ids[0] != tok_id {
+                self.token_duplicates
+                    .entry(tok_ids[0])
+                    .or_insert_with(Vec::new)
+                    .push(tok_id);
+            }
+        }
+        self.validate();
+    }
+
     fn node_offset(&self, n: &TrieNode) -> usize {
         let off = unsafe { (n as *const TrieNode).offset_from(self.root() as *const TrieNode) };
         assert!(off >= 0);
@@ -184,11 +199,14 @@ impl TokTrie {
     }
 
     pub fn token_set_dbg(&self, ts: &SimpleVob) -> String {
-        let num_set = ts.num_set();
+        let ts_neg = ts.negated(self.vocab_size());
+        let use_neg = ts_neg.num_set() * 20 < ts.num_set();
+        let ts1 = if use_neg { &ts_neg } else { &ts };
+        let num_set = ts1.num_set();
         let max_tok = std::cmp::min(100, num_set);
         let mut token_names = Vec::new();
         for idx in 0..self.vocab_size() {
-            if ts.is_allowed(idx as TokenId) {
+            if ts1.is_allowed(idx as TokenId) {
                 token_names.push(self.token_dbg(idx as TokenId));
                 if token_names.len() >= max_tok {
                     break;
@@ -199,9 +217,10 @@ impl TokTrie {
             token_names.push("...".to_string());
         }
         format!(
-            "TokenSet: {}/{}; {}",
-            num_set,
+            "TokenSet: {}/{}; {}{}",
+            ts.num_set(),
             self.vocab_size(),
+            if use_neg { "ALL EXCEPT " } else { "" },
            token_names.join(", ")
        )
    }
@@ -243,7 +262,21 @@ impl TokTrie {
             format!("OOB[{}]", idx)
         } else {
             // format!("{:?}[{}]", self.token_str(idx), idx)
-            format!("{:?}", self.token_str(idx))
+            let s = self.token_str(idx);
+            if s.len() == 0 {
+                format!("EMPTY[{}]", idx)
+            } else if !s.contains('\u{fffd}') {
+                format!("{:?}", s)
+            } else {
+                let bytes = self.token(idx);
+                format!(
+                    "HEX[{}]",
+                    bytes
+                        .iter()
+                        .map(|b| format!("{:02x}", b))
+                        .collect::<String>(),
+                )
+            }
         }
     }
 
@@ -351,12 +384,9 @@ impl TokTrie {
             token_data,
             nodes,
             max_token_len: 0,
+            token_duplicates: FxHashMap::default(),
         };
-        r.validate();
-        r.max_token_len = (0..r.info.vocab_size)
-            .map(|idx| r.token(idx).len())
-            .max()
-            .unwrap();
+        r.finalize_ctor();
         r
     }
 
@@ -422,13 +452,14 @@ impl TokTrie {
             assert!(bytes == self.token(tid));
             let root = self.root();
             if bytes.len() > 0 {
-                assert!(
-                    self.child_at_bytes(root, &bytes)
-                        .unwrap()
-                        .token_id()
-                        .unwrap()
-                        == tid
-                );
+                let tid2 = self
+                    .child_at_bytes(root, &bytes)
+                    .unwrap()
+                    .token_id()
+                    .unwrap();
+                if tid != tid2 {
+                    assert!(self.token_duplicates[&tid2].contains(&tid));
+                }
             }
         }
     }
@@ -468,7 +499,18 @@ impl TokTrie {
                 logits.allow_token(self.special_token(tok))
             }
         }
-        self.add_bias(r, logits)
+        self.add_bias(r, logits);
+        self.apply_duplicates(logits);
+    }
+
+    pub fn apply_duplicates(&self, logits: &mut SimpleVob) {
+        for (tok, dups) in &self.token_duplicates {
+            if logits.is_allowed(*tok) {
+                for &dup in dups {
+                    logits.allow_token(dup);
+                }
+            }
+        }
     }
 
     pub fn append_tokens(&self, r: &mut impl Recognizer, ts: &[TokenId]) {
@@ -575,7 +617,9 @@ impl TrieHash {
     }
     fn insert(&mut self, word: &[u8], token_id: u32) {
         if word.len() == 0 {
-            assert!(self.token_id == NO_TOKEN);
+            // Some tokenizers have duplicate tokens...
+            // we just override
+            // assert!(self.token_id == NO_TOKEN);
             self.token_id = token_id;
         } else {
             if self.children.len() == 0x100 {

From de384d49f49c63ebfd5b979a756505c433460a79 Mon Sep 17 00:00:00 2001
From: Michal Moskal
Date: Wed, 20 Mar 2024 00:10:46 +0000
Subject: [PATCH 183/301] working on hidden items

---
 controllers/aici_abi/src/toktree.rs | 3 +++
 1 file changed, 3 insertions(+)

diff --git a/controllers/aici_abi/src/toktree.rs b/controllers/aici_abi/src/toktree.rs
index 4c87b206..bed4730e 100644
--- a/controllers/aici_abi/src/toktree.rs
+++ b/controllers/aici_abi/src/toktree.rs
@@ -45,6 +45,8 @@ pub trait Recognizer {
     /// Called when iteration over the trie is finished
     /// Stack has exactly one element then.
fn trie_finished(&mut self); + /// Called when iteration over the trie is started + fn trie_started(&mut self) {} /// This combines `push_byte` and `byte_allowed` into one function for performance. fn try_push_byte(&mut self, byte: u8) -> bool; } @@ -546,6 +548,7 @@ impl TokTrie { #[inline(never)] pub fn add_bias(&self, r: &mut impl Recognizer, toks: &mut SimpleVob) { + r.trie_started(); let n = self.root(); let defl_tok = self.vocab_size() as u32; let off = self.node_offset(n); From 4c7ecba9860d966df469569cdcf8427ffa1c817a Mon Sep 17 00:00:00 2001 From: Michal Moskal Date: Thu, 21 Mar 2024 09:22:48 -0700 Subject: [PATCH 184/301] towards hidden nodes --- controllers/aici_abi/src/toktree.rs | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/controllers/aici_abi/src/toktree.rs b/controllers/aici_abi/src/toktree.rs index bed4730e..88bca4e9 100644 --- a/controllers/aici_abi/src/toktree.rs +++ b/controllers/aici_abi/src/toktree.rs @@ -495,13 +495,17 @@ impl TokTrie { } pub fn compute_bias(&self, r: &mut impl Recognizer, logits: &mut SimpleVob) { + self.compute_bias_ext(r, logits, &[]); + } + + pub fn compute_bias_ext(&self, r: &mut impl Recognizer, logits: &mut SimpleVob, start: &[u8]) { logits.set_all(false); for tok in vec![SpecialToken::EndOfSentence] { if r.special_allowed(tok) { logits.allow_token(self.special_token(tok)) } } - self.add_bias(r, logits); + self.add_bias(r, logits, start); self.apply_duplicates(logits); } @@ -547,9 +551,9 @@ impl TokTrie { } #[inline(never)] - pub fn add_bias(&self, r: &mut impl Recognizer, toks: &mut SimpleVob) { + pub fn add_bias(&self, r: &mut impl Recognizer, toks: &mut SimpleVob, start: &[u8]) { r.trie_started(); - let n = self.root(); + let n = self.child_at_bytes(self.root(), start).unwrap(); let defl_tok = self.vocab_size() as u32; let off = self.node_offset(n); let mut p = off + 1; From 171f8443db6f0ed544cb28905f57e0f6fe298cf2 Mon Sep 17 00:00:00 2001 From: Michal Moskal Date: Fri, 22 Mar 2024 00:34:23 +0000 Subject: [PATCH 185/301] earley fixes --- controllers/aici_abi/src/toktree.rs | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/controllers/aici_abi/src/toktree.rs b/controllers/aici_abi/src/toktree.rs index 88bca4e9..b7c50f7a 100644 --- a/controllers/aici_abi/src/toktree.rs +++ b/controllers/aici_abi/src/toktree.rs @@ -505,6 +505,15 @@ impl TokTrie { logits.allow_token(self.special_token(tok)) } } + // all prefixes of 'start' are also allowed + if start.len() > 0 { + for len in 1..start.len() - 1 { + let bytes = &start[0..len]; + if let Some(tok) = self.token_id(bytes) { + logits.allow_token(tok); + } + } + } self.add_bias(r, logits, start); self.apply_duplicates(logits); } From b93ae760cbb73709883edd73b6391674ad051397 Mon Sep 17 00:00:00 2001 From: Michal Moskal Date: Fri, 22 Mar 2024 22:15:53 +0000 Subject: [PATCH 186/301] add utility hex functions --- controllers/aici_abi/src/bytes.rs | 23 +++++++++++++++++++++++ controllers/aici_abi/src/host.rs | 13 ++++--------- controllers/aici_abi/src/toktree.rs | 11 +++-------- 3 files changed, 30 insertions(+), 17 deletions(-) diff --git a/controllers/aici_abi/src/bytes.rs b/controllers/aici_abi/src/bytes.rs index 1c471e6c..66ef0de0 100644 --- a/controllers/aici_abi/src/bytes.rs +++ b/controllers/aici_abi/src/bytes.rs @@ -1,5 +1,7 @@ use std::{mem::size_of, slice::from_raw_parts}; +use anyhow::{anyhow, Result}; + pub(crate) type TokenId = u32; #[repr(C)] @@ -62,3 +64,24 @@ pub fn limit_bytes(s: &[u8], max_len: usize) -> String { String::from_utf8_lossy(s).to_string() 
}
 }
+
+pub fn to_hex_string(bytes: &[u8]) -> String {
+    bytes
+        .iter()
+        .map(|b| format!("{:02x}", b))
+        .collect::<Vec<String>>()
+        .join("")
+}
+
+pub fn from_hex_string(s: &str) -> Result<Vec<u8>> {
+    let mut result = Vec::with_capacity(s.len() / 2);
+    let mut iter = s.chars();
+    while let Some(c1) = iter.next() {
+        let c2 = iter
+            .next()
+            .ok_or_else(|| anyhow!("expecting even number of chars"))?;
+        let byte = u8::from_str_radix(&format!("{}{}", c1, c2), 16)?;
+        result.push(byte);
+    }
+    Ok(result)
+}
diff --git a/controllers/aici_abi/src/host.rs b/controllers/aici_abi/src/host.rs
index d25665c5..eca93f06 100644
--- a/controllers/aici_abi/src/host.rs
+++ b/controllers/aici_abi/src/host.rs
@@ -83,7 +83,6 @@ pub fn arg_string() -> String {
     String::from_utf8_lossy(&arg_bytes()).to_string()
 }
 
-
 pub fn trie_bytes() -> Vec<u8> {
     #[cfg(target_arch = "wasm32")]
     return read_blob(unsafe { aici_host_token_trie() }, 0);
@@ -133,20 +132,16 @@ pub mod bin_string {
 pub mod hex_string {
     use serde::{Deserialize, Deserializer, Serialize, Serializer};
 
+    use crate::bytes::{from_hex_string, to_hex_string};
+
     pub fn serialize<S: Serializer>(v: &Vec<u8>, s: S) -> Result<S::Ok, S::Error> {
-        let hexstr = String::from_iter(v.iter().map(|b| format!("{:02x}", b)));
+        let hexstr = to_hex_string(v);
         String::serialize(&hexstr, s)
     }
 
     pub fn deserialize<'de, D: Deserializer<'de>>(d: D) -> Result<Vec<u8>, D::Error> {
         let hexstr = String::deserialize(d)?;
-        let mut res = Vec::new();
-        for i in 0..(hexstr.len() / 2) {
-            let b = u8::from_str_radix(&hexstr[2 * i..2 * i + 2], 16)
-                .map_err(serde::de::Error::custom)?;
-            res.push(b);
-        }
-        Ok(res)
+        from_hex_string(&hexstr).map_err(serde::de::Error::custom)
     }
 }
diff --git a/controllers/aici_abi/src/toktree.rs b/controllers/aici_abi/src/toktree.rs
index b7c50f7a..09a0b0cc 100644
--- a/controllers/aici_abi/src/toktree.rs
+++ b/controllers/aici_abi/src/toktree.rs
@@ -5,7 +5,8 @@ use rustc_hash::FxHashMap;
 
 use crate::{
     bytes::{
-        box_from_bytes, clone_as_bytes, clone_vec_as_bytes, vec_from_bytes, TokRxInfo, TokenId,
+        box_from_bytes, clone_as_bytes, clone_vec_as_bytes, to_hex_string, vec_from_bytes,
+        TokRxInfo, TokenId,
     },
     host::trie_bytes,
     svob::SimpleVob,
 };

From d836bff65dc6a9cbedeec822e4be692c9265361c Mon Sep 17 00:00:00 2001
From: Michal Moskal
Date: Sat, 23 Mar 2024 23:41:08 +0000
Subject: [PATCH 187/301] hidden items work

---
 controllers/aici_abi/src/toktree.rs | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/controllers/aici_abi/src/toktree.rs b/controllers/aici_abi/src/toktree.rs
index 09a0b0cc..9bde3089 100644
--- a/controllers/aici_abi/src/toktree.rs
+++ b/controllers/aici_abi/src/toktree.rs
@@ -202,11 +202,13 @@ impl TokTrie {
     }
 
     pub fn token_set_dbg(&self, ts: &SimpleVob) -> String {
+        let max_examples = 50;
+
         let ts_neg = ts.negated(self.vocab_size());
         let use_neg = ts_neg.num_set() * 20 < ts.num_set();
         let ts1 = if use_neg { &ts_neg } else { &ts };
         let num_set = ts1.num_set();
-        let max_tok = std::cmp::min(100, num_set);
+        let max_tok = std::cmp::min(max_examples, num_set);
         let mut token_names = Vec::new();
         for idx in 0..self.vocab_size() {
             if ts1.is_allowed(idx as TokenId) {

From ab557380d299cf2c56774fe5c803ed93581ba50b Mon Sep 17 00:00:00 2001
From: Michal Moskal
Date: Wed, 27 Mar 2024 22:06:31 +0000
Subject: [PATCH 188/301] removing pre/post - WIP

---
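Note: with pre_process() and post_process() gone, a controller implements only
mid_process(), and forcing tokens becomes a splice. A minimal sketch against
the types introduced below (wired up with aici_expose_all! as before; error
handling omitted):

    use aici_abi::{tokenize, AiciCtrl, MidProcessArg, MidProcessResult, TokenId};

    struct Forcer {
        emitted: bool,
    }

    impl AiciCtrl for Forcer {
        fn mid_process(&mut self, arg: MidProcessArg) -> MidProcessResult {
            if arg.has_eos() {
                return MidProcessResult::stop();
            }
            if !self.emitted {
                self.emitted = true;
                // append a fixed fragment without sampling
                let toks: Vec<TokenId> = tokenize("Hello!");
                MidProcessResult::splice(0, toks)
            } else {
                MidProcessResult::stop()
            }
        }
    }
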
 controllers/aici_abi/src/lib.rs        | 242 ++++++++++++++++---------------
 controllers/aici_abi/src/recognizer.rs |  17 +-
 controllers/aici_abi/src/yesno.rs      |  39 +----
 3 files changed, 130 insertions(+), 168 deletions(-)

diff --git a/controllers/aici_abi/src/lib.rs b/controllers/aici_abi/src/lib.rs
index ae6a418f..a062b90f 100644
--- a/controllers/aici_abi/src/lib.rs
+++ b/controllers/aici_abi/src/lib.rs
@@ -1,3 +1,4 @@
+use regex_automata::nfa::thompson::backtrack;
 use serde::{Deserialize, Serialize};
 use svob::SimpleVob;
 
@@ -21,13 +22,12 @@ pub mod substring;
 pub type TokenId = bytes::TokenId;
 
 pub use host::{
-    aici_stop, arg_bytes, arg_string, return_logit_bias, self_seq_id, tokenize, tokenize_bytes,
-    StorageCmd, StorageOp, StorageResp, VariableStorage,
+    aici_stop, arg_bytes, arg_string, self_seq_id, tokenize, tokenize_bytes, StorageCmd, StorageOp,
+    StorageResp, VariableStorage,
 };
 
 #[derive(Serialize, Deserialize, Debug)]
 pub struct InitPromptArg {
-    /// Typically just the start token if any.
     pub prompt: Vec<TokenId>,
 }
 
@@ -39,123 +39,127 @@ pub struct InitPromptResult {}
 pub struct SeqId(pub u32);
 
 #[derive(Serialize, Deserialize, Debug)]
-pub struct PreProcessArg {}
-
-#[derive(Serialize, Deserialize, Debug, Clone)]
-pub struct PreProcessResult {
-    /// If 0 - stop the sequence.
-    /// If 1 - just continue.
-    /// If more than 1 - fork the generation.
-    pub num_forks: usize,
-
-    pub suspend: bool,
-
-    /// If non-empty, the tokens may be appended and post_process() be called immediately,
-    /// skipping mid_process(); pre_process() is then typically called again.
-    pub ff_tokens: Vec<TokenId>,
+pub struct MidProcessArg {
+    /// Sampling result for the previous iteration.
+    /// For simple sampled token 't', backtrack==0 and tokens==[t].
+    /// For first request, backtrack==0 and tokens==[] (prompt is passed separately, before).
+    /// Can be more complex when splices are used.
+    pub backtrack: u32,
+    pub tokens: Vec<TokenId>,
+    ///
+    pub fork_group: Vec<SeqId>,
 }
 
-impl Default for PreProcessResult {
-    fn default() -> Self {
-        PreProcessResult {
-            num_forks: 1,
-            suspend: false,
-            ff_tokens: vec![],
-        }
+impl MidProcessArg {
+    pub fn has_eos(&self) -> bool {
+        let eos = host::eos_token();
+        self.tokens.iter().any(|t| *t == eos)
     }
-}
 
-#[derive(Serialize, Deserialize, Debug)]
-pub struct MidProcessArg {
-    /// fork_group.len() == num_forks.
-    /// Use host::self_seq_id() to get the ID of the current sequence.
-    pub fork_group: Vec<SeqId>,
+    pub fn save_tokens(&self, acc_tokens: &mut Vec<TokenId>) {
+        acc_tokens.truncate(acc_tokens.len() - self.backtrack as usize);
+        acc_tokens.extend_from_slice(&self.tokens);
+    }
 }
 
-#[derive(Serialize, Deserialize, Debug)]
-pub enum MidProcessResult {
-    /// Stop the current sequence.
-    /// Similar to strong bias to EOS.
-    Stop,
-
-    /// Sample next token in the current sequence
-    SampleWithBias {
-        #[serde(skip)]
-        allowed_tokens: SimpleVob,
-    },
-
-    /// First pop `backtrack` tokens,
-    /// then force next tokens to be generated to be `ff_tokens`.
-    /// `backtrack` count includes the token about to be generated from this step.
-    /// `backtrack` can be 0, and `ff_tokens` can be empty but not both.
-    Splice {
-        backtrack: u32,
-        ff_tokens: Vec<TokenId>,
-    },
+/*
+For example, if we're generating JSON, according to the following schema:
+{
+    "type": "object",
+    "properties": {
+        "name": {"type": "string"},
+        "age": {"type": "integer"}
+    }
 }
+Let's say we have generated: {"name": "something
+We would use a single splice:
+    when_sampled: ['"', '",', '", '],
+    backtrack: 1,
+    ff_tokens: tokenize('", "age": ')
+Which means: when any token starting with '"' is sampled, we remove it (backtrack: 1)
+and then append the next full fragment of JSON '", "age": '
+
+If the tokenizer has tokens like 'a"', 'b"' etc, then we would need many splices
+(there may be limits on how many we want to pass over the IPC boundary).
+*/
+
+/// Describes what to do after sampling.
+#[derive(Serialize, Deserialize, Clone, Debug)]
+pub struct Splice {
+    /// If one of the tokens in when_sampled is sampled, this sequence is appended.
+    /// When empty, this sequence is appended unconditionally, regardless of sampling.
+    pub when_sampled: Vec<TokenId>,
+    /// Backtrack this much before appending this sequence (this includes the sampled token if any).
+    pub backtrack: u32,
+    /// Append these tokens after backtracking.
+    pub ff_tokens: Vec<TokenId>,
+}
 
 #[derive(Serialize, Deserialize, Debug)]
-pub struct PostProcessArg {
-    /// Generally, issued after each token generated by the model.
-    /// `tokens` is typically just this one token, except for the
-    /// cases when fast-forward tokens are used.
-    pub tokens: Vec<TokenId>,
-
-    /// Typically 0.
-    pub backtrack: u32,
-}
-
-#[derive(Serialize, Deserialize, Debug)]
-pub struct PostProcessResult {
-    /// If true, stop the sequence.
-    pub stop: bool,
+pub struct Branch<S> {
+    /// If None, no sampling is performed.
+    /// If Some(set), only tokens from the set are allowed.
+    pub sample_mask: Option<S>,
+    /// Describes what to do after sampling.
+    /// If no sampling, there should be exactly one splice, with empty `when_sampled`.
+    pub splices: Vec<Splice>,
 }
 
-impl PostProcessResult {
-    pub fn stop() -> Self {
-        PostProcessResult { stop: true }
-    }
-
-    pub fn continue_() -> Self {
-        PostProcessResult { stop: false }
+impl<S> Branch<S> {
+    pub fn map_mask<F, T>(&self, f: F) -> Branch<T>
+    where
+        F: FnOnce(&S) -> T,
+    {
+        Branch {
+            sample_mask: self.sample_mask.as_ref().map(f),
+            splices: self.splices.clone(),
+        }
     }
+}
 
-    pub fn from_arg(arg: &PostProcessArg) -> Self {
-        let stop = arg.tokens.contains(&host::eos_token());
-        PostProcessResult { stop }
-    }
+#[derive(Debug)]
+pub struct MidProcessResult {
+    /// Fork the request into multiple branches.
+    /// Typically, exactly one branch is returned.
+    /// If multiple branches are returned, they are executed in parallel.
+    /// If no branches are returned, the request is terminated.
+    pub branches: Vec<Branch<SimpleVob>>,
 }
 
-impl PreProcessResult {
-    pub fn new(num_forks: usize) -> Self {
-        PreProcessResult {
-            num_forks,
-            suspend: false,
-            ff_tokens: vec![],
-        }
-    }
-    pub fn continue_() -> Self {
-        PreProcessResult::new(1)
+impl MidProcessResult {
+    pub fn stop() -> Self {
+        MidProcessResult { branches: vec![] }
     }
-    pub fn suspend() -> Self {
-        PreProcessResult {
-            num_forks: 1,
-            suspend: true,
-            ff_tokens: vec![],
+
+    pub fn sample(set: SimpleVob) -> Self {
+        MidProcessResult {
+            branches: vec![Branch {
+                sample_mask: Some(set),
+                splices: vec![],
+            }],
         }
     }
-    pub fn stop() -> Self {
-        PreProcessResult::new(0)
-    }
-    pub fn ff_tokens(toks: Vec<TokenId>) -> Self {
-        PreProcessResult {
-            num_forks: 1,
-            suspend: false,
-            ff_tokens: toks,
+
+    pub fn splice(backtrack: u32, ff_tokens: Vec<TokenId>) -> Self {
+        MidProcessResult {
+            branches: vec![Branch {
+                sample_mask: None,
+                splices: vec![Splice {
+                    when_sampled: vec![],
+                    backtrack,
+                    ff_tokens,
+                }],
+            }],
         }
     }
 }
 
+#[derive(Serialize, Deserialize)]
+pub struct ProcessResultOffset {
+    pub branches: Vec<Branch<usize>>,
+}
+
 pub trait AiciCtrl {
     /// Called with the initial prompt. ~1000ms time limit.
     /// By default ignore prompt.
@@ -164,20 +168,9 @@ pub trait AiciCtrl { InitPromptResult::default() } - /// Called before mid_process(), can return attention masks. ~1ms time limit. - /// Should be stateless. - fn pre_process(&mut self, _arg: PreProcessArg) -> PreProcessResult { - PreProcessResult::continue_() - } - /// This is the main entry point for the module. ~20ms time limit. fn mid_process(&mut self, arg: MidProcessArg) -> MidProcessResult; - /// Called after tokens are appended, after mid_process(). ~1ms time limit. - fn post_process(&mut self, arg: PostProcessArg) -> PostProcessResult { - PostProcessResult::from_arg(&arg) - } - // Internals fn aici_init_prompt(&mut self) { let arg: InitPromptArg = serde_json::from_slice(&host::process_arg_bytes()).unwrap(); @@ -186,35 +179,28 @@ pub trait AiciCtrl { host::return_process_result(&res_bytes); } - fn aici_pre_process(&mut self) { - let arg: PreProcessArg = serde_json::from_slice(&host::process_arg_bytes()).unwrap(); - let res = self.pre_process(arg); - let res_bytes = serde_json::to_vec(&res).unwrap(); - host::return_process_result(&res_bytes); - } - fn aici_mid_process(&mut self) { let arg: MidProcessArg = serde_json::from_slice(&host::process_arg_bytes()) .expect("aici_mid_process: failed to deserialize MidProcessArg"); let res = self.mid_process(arg); - match &res { - MidProcessResult::SampleWithBias { allowed_tokens } => { - if allowed_tokens.len() > 0 { - host::return_logit_bias(allowed_tokens); - } - } - _ => {} + if res.branches.len() > 1 { + panic!("aici_mid_process: multiple branches not yet supported"); } + let res = ProcessResultOffset { + branches: res + .branches + .into_iter() + .map(|b| { + b.map_mask(|vob| { + host::return_logit_bias(&vob); + 0 + }) + }) + .collect(), + }; let res_bytes = serde_json::to_vec(&res).expect("aici_mid_process: failed to serialize"); host::return_process_result(&res_bytes); } - - fn aici_post_process(&mut self) { - let arg: PostProcessArg = serde_json::from_slice(&host::process_arg_bytes()).unwrap(); - let res = self.post_process(arg); - let res_bytes = serde_json::to_vec(&res).unwrap(); - host::return_process_result(&res_bytes); - } } /// Expose method as extern "C", usage: @@ -244,9 +230,7 @@ macro_rules! expose { #[macro_export] macro_rules! 
aici_expose_all {
     ($struct_name:ident, $new:expr) => {
-        $crate::expose!($struct_name::aici_pre_process() -> ());
         $crate::expose!($struct_name::aici_mid_process() -> ());
-        $crate::expose!($struct_name::aici_post_process() -> ());
         $crate::expose!($struct_name::aici_init_prompt() -> ());
 
         #[no_mangle]
diff --git a/controllers/aici_abi/src/recognizer.rs b/controllers/aici_abi/src/recognizer.rs
index ec6d8767..7045da6e 100644
--- a/controllers/aici_abi/src/recognizer.rs
+++ b/controllers/aici_abi/src/recognizer.rs
@@ -1,6 +1,6 @@
 use crate::{
     toktree::{Recognizer, SpecialToken, TokTrie},
-    AiciCtrl, MidProcessArg, MidProcessResult, PostProcessArg, PostProcessResult,
+    AiciCtrl, MidProcessArg, MidProcessResult,
 };
 use std::fmt::Debug;
 
@@ -19,17 +19,14 @@ impl<R: Recognizer> AiciRecognizer<R> {
     }
 }
 
 impl<R: Recognizer> AiciCtrl for AiciRecognizer<R> {
-    fn mid_process(&mut self, _arg: MidProcessArg) -> MidProcessResult {
-        let mut set = self.trie.alloc_token_set();
-        self.trie.compute_bias(&mut self.rec, &mut set);
-        MidProcessResult::SampleWithBias {
-            allowed_tokens: set,
+    fn mid_process(&mut self, arg: MidProcessArg) -> MidProcessResult {
+        if arg.has_eos() {
+            return MidProcessResult::stop();
         }
-    }
-
-    fn post_process(&mut self, arg: PostProcessArg) -> PostProcessResult {
         self.trie.append_tokens(&mut self.rec, &arg.tokens);
-        PostProcessResult::from_arg(&arg)
+        let mut set = self.trie.alloc_token_set();
+        self.trie.compute_bias(&mut self.rec, &mut set);
+        MidProcessResult::sample(set)
     }
 }
diff --git a/controllers/aici_abi/src/yesno.rs b/controllers/aici_abi/src/yesno.rs
index 493f9ebf..dc16d2d2 100644
--- a/controllers/aici_abi/src/yesno.rs
+++ b/controllers/aici_abi/src/yesno.rs
@@ -1,12 +1,8 @@
-use aici_abi::{
-    arg_string, tokenize, toktree::TokTrie, AiciCtrl, MidProcessArg, MidProcessResult,
-    PostProcessArg, PostProcessResult, PreProcessArg, PreProcessResult, TokenId,
-};
+use aici_abi::{tokenize, toktree::TokTrie, AiciCtrl, MidProcessArg, MidProcessResult, TokenId};
 
 pub struct Runner {
     toktrie: TokTrie,
     tokens: Vec<TokenId>,
-    question: Vec<TokenId>,
     yes: TokenId,
     no: TokenId,
 }
@@ -19,7 +15,6 @@ impl Runner {
         Runner {
             toktrie: TokTrie::from_host(),
             tokens: Vec::new(),
-            question: tokenize(&(arg_string() + "\n")),
             yes,
             no,
         }
@@ -27,30 +22,16 @@ impl Runner {
 }
 
 impl AiciCtrl for Runner {
-    fn pre_process(&mut self, _arg: PreProcessArg) -> PreProcessResult {
-        if self.tokens.is_empty() {
-            PreProcessResult::ff_tokens(self.question.clone())
+    fn mid_process(&mut self, arg: MidProcessArg) -> MidProcessResult {
+        arg.save_tokens(&mut self.tokens);
+        if self.tokens.len() >= 1 {
+            // we only want the first token
+            MidProcessResult::stop()
         } else {
-            PreProcessResult::continue_()
-        }
-    }
-
-    fn mid_process(&mut self, _arg: MidProcessArg) -> MidProcessResult {
-        let mut set = self.toktrie.alloc_token_set();
-        set.allow_token(self.yes);
-        set.allow_token(self.no);
-        MidProcessResult::SampleWithBias {
-            allowed_tokens: set,
-        }
-    }
-
-    fn post_process(&mut self, arg: PostProcessArg) -> PostProcessResult {
-        // save our tokens
-        self.tokens.extend_from_slice(&arg.tokens);
-        if self.tokens.len() >= self.question.len() + 1 {
-            PostProcessResult::stop()
-        } else {
-            PostProcessResult::from_arg(&arg)
+            let mut set = self.toktrie.alloc_token_set();
+            set.allow_token(self.yes);
+            set.allow_token(self.no);
+            MidProcessResult::sample(set)
         }
     }
 }

From 57720bdf716af5c6054089418bd9dee779b196fb Mon Sep 17 00:00:00 2001
From: Michal Moskal
Date: Wed, 27 Mar 2024 16:39:44 -0700
Subject: [PATCH 189/301] pre/post code complete

---
 controllers/aici_abi/src/lib.rs | 10 +++++++++-
 1 file changed, 9 insertions(+), 1 deletion(-)

diff --git a/controllers/aici_abi/src/lib.rs b/controllers/aici_abi/src/lib.rs
index a062b90f..e7728397 100644
--- a/controllers/aici_abi/src/lib.rs
+++ b/controllers/aici_abi/src/lib.rs
@@ -1,4 +1,3 @@
-use regex_automata::nfa::thompson::backtrack;
 use serde::{Deserialize, Serialize};
 use svob::SimpleVob;
 
@@ -106,6 +105,15 @@ pub struct Branch<S> {
     pub splices: Vec<Splice>,
 }
 
+impl<S: Clone> Clone for Branch<S> {
+    fn clone(&self) -> Self {
+        Branch {
+            sample_mask: self.sample_mask.clone(),
+            splices: self.splices.clone(),
+        }
+    }
+}
+
 impl<S> Branch<S> {

From 7bb11e6b85bfc5f41ff3abeda27c6ae52a28cf4f Mon Sep 17 00:00:00 2001
From: Michal Moskal
Date: Fri, 29 Mar 2024 21:11:46 +0000
Subject: [PATCH 190/301] add limited forking

---
 controllers/aici_abi/src/lib.rs | 8 +++++---
 1 file changed, 5 insertions(+), 3 deletions(-)

diff --git a/controllers/aici_abi/src/lib.rs b/controllers/aici_abi/src/lib.rs
index e7728397..56fe9a31 100644
--- a/controllers/aici_abi/src/lib.rs
+++ b/controllers/aici_abi/src/lib.rs
@@ -191,15 +191,17 @@ pub trait AiciCtrl {
         let arg: MidProcessArg = serde_json::from_slice(&host::process_arg_bytes())
             .expect("aici_mid_process: failed to deserialize MidProcessArg");
         let res = self.mid_process(arg);
-        if res.branches.len() > 1 {
-            panic!("aici_mid_process: multiple branches not yet supported");
-        }
+        let mut used_logits = false;
         let res = ProcessResultOffset {
             branches: res
                 .branches
                 .into_iter()
                 .map(|b| {
                     b.map_mask(|vob| {
+                        if used_logits {
+                            panic!("aici_mid_process: multiple branches with sampling not yet supported");
+                        }
+                        used_logits = true;
                         host::return_logit_bias(&vob);
                         0
                     })

From 82a59ecea1e648c0c91ea61e4f8d50b6d1dfc2cc Mon Sep 17 00:00:00 2001
From: Michal Moskal
Date: Wed, 3 Apr 2024 19:40:30 +0000
Subject: [PATCH 191/301] make the gctrl work again

---
 controllers/aici_abi/src/toktree.rs | 6 +++++-
 1 file changed, 5 insertions(+), 1 deletion(-)

diff --git a/controllers/aici_abi/src/toktree.rs b/controllers/aici_abi/src/toktree.rs
index 9bde3089..76904d5f 100644
--- a/controllers/aici_abi/src/toktree.rs
+++ b/controllers/aici_abi/src/toktree.rs
@@ -191,6 +191,10 @@ impl TokTrie {
         }
     }
 
+    pub fn eos_token(&self) -> TokenId {
+        self.info.tok_eos
+    }
+
     pub fn vocab_size(&self) -> usize {
         self.info.vocab_size as usize
     }
@@ -504,7 +508,7 @@ impl TokTrie {
         }
         // all prefixes of 'start' are also allowed
         if start.len() > 0 {
-            for len in 1..start.len() - 1 {
+            for len in 1..=start.len() {
                 let bytes = &start[0..len];
                 if let Some(tok) = self.token_id(bytes) {
                     logits.allow_token(tok);

From 3b0dc0c52ae572e536c94b31fde6267354d53f71 Mon Sep 17 00:00:00 2001
From: Michal Moskal
Date: Fri, 5 Apr 2024 18:34:15 +0000
Subject: [PATCH 192/301] add aici_native crate

---
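Note: the new crate moves the HuggingFace-based tokenizer loading and logging
helpers out of the Wasm path. A minimal native use of the loader added below
(a sketch; from_pretrained downloads the tokenizer, so network access is
assumed):

    use aici_native::bintokens::find_tokenizer;

    fn main() -> anyhow::Result<()> {
        // short names such as "gpt2" or "llama" resolve through tokenizers()
        let bt = find_tokenizer("gpt2")?;
        let info = bt.tokrx_info();
        println!("vocab_size={} eos={}", info.vocab_size, info.tok_eos);
        Ok(())
    }
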
 controllers/aici_abi/src/host.rs         | 136 +++++++---
 controllers/aici_abi/src/lib.rs          |   3 +
 controllers/aici_native/Cargo.toml       |  20 ++
 controllers/aici_native/README.md        |   3 +
 controllers/aici_native/src/bintokens.rs | 326 +++++++++++++++++++++++
 controllers/aici_native/src/lib.rs       |   8 +
 controllers/aici_native/src/log.rs       | 100 +++++++
 controllers/aici_native/src/variables.rs |  56 ++++
 8 files changed, 620 insertions(+), 32 deletions(-)
 create mode 100644 controllers/aici_native/Cargo.toml
 create mode 100644 controllers/aici_native/README.md
 create mode 100644 controllers/aici_native/src/bintokens.rs
 create mode 100644 controllers/aici_native/src/lib.rs
 create mode 100644 controllers/aici_native/src/log.rs
 create mode 100644 controllers/aici_native/src/variables.rs

diff --git a/controllers/aici_abi/src/host.rs b/controllers/aici_abi/src/host.rs
index eca93f06..7e2c6730 100644
--- a/controllers/aici_abi/src/host.rs
+++ b/controllers/aici_abi/src/host.rs
@@ -69,14 +69,104 @@ fn init_panic() {
 #[no_mangle]
 pub extern "C" fn aici_init() {
     init_panic();
+    set_host(Box::new(WasmHost {}));
+}
+
+/**
+ * This is normally implemented straightforwardly by wasm callbacks.
+ * It can be overridden with set_host() when compiling to native.
+ */
+pub trait HostInterface {
+    fn arg_bytes(&self) -> Vec<u8>;
+    fn trie_bytes(&self) -> Vec<u8>;
+    fn return_logit_bias(&self, vob: &SimpleVob);
+    fn process_arg_bytes(&self) -> Vec<u8>;
+    fn return_process_result(&self, res: &[u8]);
+    fn storage_cmd(&self, cmd: StorageCmd) -> StorageResp;
+    fn tokenize_bytes(&self, s: &[u8]) -> Vec<TokenId>;
+    fn self_seq_id(&self) -> SeqId;
+    fn eos_token(&self) -> TokenId;
+    fn stop(&self) -> !;
+}
+
+static mut HOST: Option<Box<dyn HostInterface>> = None;
+
+struct WasmHost {}
+impl HostInterface for WasmHost {
+    fn arg_bytes(&self) -> Vec<u8> {
+        read_blob(unsafe { aici_host_module_arg() }, 1024)
+    }
+
+    fn trie_bytes(&self) -> Vec<u8> {
+        read_blob(unsafe { aici_host_token_trie() }, 0)
+    }
+
+    fn return_logit_bias(&self, vob: &SimpleVob) {
+        assert!(vob.len() > 0);
+        unsafe {
+            aici_host_return_logit_bias(vob.as_ptr());
+        }
+    }
+
+    fn process_arg_bytes(&self) -> Vec<u8> {
+        read_blob(unsafe { aici_host_process_arg() }, 1024)
+    }
+
+    fn return_process_result(&self, res: &[u8]) {
+        unsafe {
+            aici_host_return_process_result(res.as_ptr(), res.len() as u32);
+        }
+    }
+
+    fn storage_cmd(&self, cmd: StorageCmd) -> StorageResp {
+        let cmd_bytes = serde_json::to_vec(&cmd).unwrap();
+        let res_id = unsafe { aici_host_storage_cmd(cmd_bytes.as_ptr(), cmd_bytes.len() as u32) };
+        let resp_bytes = read_blob(res_id, 1024);
+        serde_json::from_slice(&resp_bytes).unwrap()
+    }
+
+    fn stop(&self) -> ! {
+        unsafe { aici_host_stop() };
+        panic!("didn't stop")
+    }
+
+    fn tokenize_bytes(&self, s: &[u8]) -> Vec<TokenId> {
+        let id = unsafe { aici_host_tokenize(s.as_ptr(), s.len() as u32) };
+        let r = read_blob(id, 4 * (s.len() / 3 + 10));
+        let res = vec_from_bytes(&r);
+        // println!(
+        //     "tokenize_bytes: {:?} -> {:?}",
+        //     String::from_utf8_lossy(s),
+        //     res
+        // );
+        res
+    }
+
+    fn self_seq_id(&self) -> SeqId {
+        unsafe { SeqId(aici_host_self_seq_id()) }
+    }
+
+    fn eos_token(&self) -> TokenId {
+        unsafe { aici_host_eos_token() }
+    }
+}
+
+fn get_host() -> &'static Box<dyn HostInterface> {
+    unsafe { HOST.as_ref().unwrap() }
+}
+
+pub fn set_host(host: Box<dyn HostInterface>) {
+    unsafe {
+        assert!(HOST.is_none());
+        HOST = Some(host);
+    }
 }
 
 pub fn arg_bytes() -> Vec<u8> {
-    #[cfg(target_arch = "wasm32")]
-    return read_blob(unsafe { aici_host_module_arg() }, 1024);
+    get_host().arg_bytes()
 
-    #[cfg(not(target_arch = "wasm32"))]
-    return std::fs::read("arg.json").unwrap();
+    // #[cfg(not(target_arch = "wasm32"))]
+    // return std::fs::read("arg.json").unwrap();
 }
 
 pub fn arg_string() -> String {
@@ -84,22 +174,17 @@ pub fn arg_string() -> String {
 }
 
 pub fn trie_bytes() -> Vec<u8> {
-    #[cfg(target_arch = "wasm32")]
-    return read_blob(unsafe { aici_host_token_trie() }, 0);
-
-    #[cfg(not(target_arch = "wasm32"))]
-    return std::fs::read("tokenizer.bin").unwrap();
+    get_host().trie_bytes()
+    // #[cfg(not(target_arch = "wasm32"))]
+    // return std::fs::read("tokenizer.bin").unwrap();
 }
 
 pub fn return_logit_bias(vob: &SimpleVob) {
-    assert!(vob.len() > 0);
-    unsafe {
-        aici_host_return_logit_bias(vob.as_ptr());
-    }
+    get_host().return_logit_bias(vob);
 }
 
 pub fn process_arg_bytes() -> Vec<u8> {
-    return read_blob(unsafe { aici_host_process_arg() }, 1024);
+    get_host().process_arg_bytes()
 }
 
 pub fn return_process_result(res: &[u8]) {
@@ -238,38 +323,25 @@ impl VariableStorage {
 
 /// Tokenize given byte string.
 pub fn tokenize_bytes(s: &[u8]) -> Vec<TokenId> {
-    let id = unsafe { aici_host_tokenize(s.as_ptr(), s.len() as u32) };
-    let r = read_blob(id, 4 * (s.len() / 3 + 10));
-    let res = vec_from_bytes(&r);
-    // println!(
-    //     "tokenize_bytes: {:?} -> {:?}",
-    //     String::from_utf8_lossy(s),
-    //     res
-    // );
-    res
+    get_host().tokenize_bytes(s)
 }
 
 /// Tokenize given UTF8 string.
 pub fn tokenize(s: &str) -> Vec<TokenId> {
-    let id = unsafe { aici_host_tokenize(s.as_ptr(), s.len() as u32) };
-    let r = read_blob(id, 4 * (s.len() / 3 + 10));
-    let res = vec_from_bytes(&r);
-    // println!("tokenize: {:?} -> {:?}", s, res);
-    res
+    get_host().tokenize_bytes(s.as_bytes())
}
 
 /// Return the ID of the current process.
 pub fn self_seq_id() -> SeqId {
-    unsafe { SeqId(aici_host_self_seq_id()) }
+    get_host().self_seq_id()
 }
 
 /// Return the ID of the EOS token.
 pub fn eos_token() -> TokenId {
-    unsafe { aici_host_eos_token() }
+    get_host().eos_token()
 }
 
 /// Stop the program - any error info is assumed to have been printed already.
 pub fn aici_stop() -> ! {
-    unsafe { aici_host_stop() };
-    panic!("didn't stop");
+    get_host().stop();
 }
diff --git a/controllers/aici_abi/src/lib.rs b/controllers/aici_abi/src/lib.rs
index 56fe9a31..abc8d269 100644
--- a/controllers/aici_abi/src/lib.rs
+++ b/controllers/aici_abi/src/lib.rs
@@ -25,6 +25,9 @@ pub use host::{
     StorageResp, VariableStorage,
 };
 
+#[cfg(not(target_arch = "wasm32"))]
+pub use host::{set_host, HostInterface};
+
 #[derive(Serialize, Deserialize, Debug)]
 pub struct InitPromptArg {
     pub prompt: Vec<TokenId>,
 }
diff --git a/controllers/aici_native/Cargo.toml b/controllers/aici_native/Cargo.toml
new file mode 100644
index 00000000..128d8f7e
--- /dev/null
+++ b/controllers/aici_native/Cargo.toml
@@ -0,0 +1,20 @@
+[package]
+name = "aici_native"
+version = "0.1.0"
+edition = "2021"
+
+[lib]
+name = "aici_native"
+
+[dependencies]
+aici_abi = { path = "../aici_abi" }
+vob = { version = "3.0.3", optional = true }
+serde = { version = "1.0.192", features = ["derive"] }
+serde_json = "1.0.108"
+anyhow = "1.0.75"
+rustc-hash = "1.1.0"
+base64 = "0.22.0"
+fxhash = "0.2.1"
+tokenizers = { version = "0.15.0", features = ["http"] }
+log = "0.4.21"
+flexi_logger = "0.28.0"
diff --git a/controllers/aici_native/README.md b/controllers/aici_native/README.md
new file mode 100644
index 00000000..205a0ae0
--- /dev/null
+++ b/controllers/aici_native/README.md
@@ -0,0 +1,3 @@
+# AICI native
+
+Utilities for building native (non-Wasm) AICI Controllers.
diff --git a/controllers/aici_native/src/bintokens.rs b/controllers/aici_native/src/bintokens.rs
new file mode 100644
index 00000000..910059e2
--- /dev/null
+++ b/controllers/aici_native/src/bintokens.rs
@@ -0,0 +1,326 @@
+use aici_abi::bytes::TokRxInfo;
+use anyhow::{anyhow, bail, Result};
+use fxhash::FxHashMap;
+use serde::{Deserialize, Serialize};
+use std::collections::BTreeMap;
+use tokenizers::{normalizers::Sequence, FromPretrainedParameters, NormalizerWrapper, Tokenizer};
+
+#[derive(Serialize, Deserialize)]
+pub struct ByteTokenizer {
+    pub hf_model: String,
+    pub hf_tokenizer: Tokenizer,
+    pub eos_token: u32,
+    pub vocab_size: u32,
+    token_bytes: Vec<Vec<u8>>,
+    pub special: BTreeMap<String, u32>,
+}
+
+pub struct TokenizerInfo {
+    pub name: &'static str,
+    pub description: &'static str,
+    pub hf_model: &'static str,
+    pub model_ids: &'static str,
+}
+
+pub fn tokenizers() -> Vec<TokenizerInfo> {
+    vec![
+        TokenizerInfo {
+            name: "gpt4",
+            description: "cl100k_base, used by GPT-4 and GPT-3.5",
+            hf_model: "Xenova/gpt-4",
+            model_ids: "gpt-4",
+        },
+        TokenizerInfo {
+            name: "llama16",
+            description: "same as llama, with 16 added tokens (used by 13B codellama)",
+            hf_model: "codellama/CodeLlama-13b-Instruct-hf",
+            model_ids: "codellama-13b",
+        },
+        TokenizerInfo {
+            name: "llama70",
+            description: "used by codellama-70b; with <step> token",
+            hf_model: "codellama/CodeLlama-70b-Instruct-hf",
+            model_ids: "codellama-70b",
+        },
+        TokenizerInfo {
+            name: "llama",
+            description: "used by Llama, CodeLlama, etc.",
+            hf_model: "codellama/CodeLlama-34b-Instruct-hf",
+            model_ids: "",
+        },
+        TokenizerInfo {
+            name: "orca",
+            description: "for microsoft/Orca models; similar to llama, with 3 tokens added for chat",
+            hf_model: "microsoft/Orca-2-13b@refs/pr/23",
+            model_ids: "",
+        },
+        TokenizerInfo {
+            name: "falcon",
+            description: "used by Falcon 7b, 40b, etc.",
+            hf_model: "tiiuae/falcon-7b",
+            model_ids: "",
+        },
+        TokenizerInfo {
+            name: "mistral",
+            description: "used by Mistral and Mixtral",
+            hf_model: "mistralai/Mistral-7B-Instruct-v0.2",
+            model_ids: "mixtral",
+        },
+        TokenizerInfo {
+            name: "mpt",
+            description: "MPT",
+            hf_model: "mosaicml/mpt-7b",
+            model_ids: "",
+        },
+        TokenizerInfo {
+            name: "phi",
+            description: "Phi 1.5 and Phi 2",
+            hf_model: "microsoft/phi-1_5",
+            model_ids: "",
+        },
+        TokenizerInfo {
+            name: "gpt2",
+            description: "GPT-2",
+            hf_model: "gpt2",
+            model_ids: "gpt-2",
+        },
+    ]
+}
+
+// useful when debugging this: https://www.cogsci.ed.ac.uk/~richard/utf-8.cgi
+
+fn is_self_mapped(c: char) -> bool {
+    match c {
+        '!'..='~' | '\u{00A1}'..='\u{00AC}' | '\u{00AE}'..='\u{00FF}' => true,
+        _ => false,
+    }
+}
+
+fn build_char_map() -> FxHashMap<char, u8> {
+    let mut res = FxHashMap::default();
+    let mut k = 0x100u32;
+    for byte in 0..=255u8 {
+        let c = byte as char;
+        if is_self_mapped(c) {
+            res.insert(c, byte);
+        } else {
+            res.insert(char::from_u32(k).unwrap(), byte);
+            k += 1;
+        }
+    }
+    res
+}
+
+pub fn list_tokenizers() -> String {
+    format!(
+        "Available tokenizers for -t or --tokenizer:\n{}\n{}",
+        tokenizers()
+            .iter()
+            .map(|t| format!("  -t {:16} {}", t.name, t.description))
+            .collect::<Vec<_>>()
+            .join("\n"),
+        "You can also use a HuggingFace model name, in format 'user/modelname'."
+    )
+}
+
+pub fn guess_tokenizer(model_name: &str) -> Option<String> {
+    let m = model_name.to_lowercase();
+    tokenizers()
+        .iter()
+        .find(|t| {
+            m.contains(&t.name)
+                || t.model_ids
+                    .split(',')
+                    .map(|x| x.trim())
+                    .filter(|x| x.len() > 0)
+                    .any(|x| m.contains(x))
+        })
+        .map(|t| t.name.to_string())
+}
+
+fn strip_suffix(sep: &str, s: &mut String) -> Option<String> {
+    let mut parts = s.splitn(2, sep);
+    let core = parts.next().unwrap().to_string();
+    let suff = parts.next().map(|s| s.to_string());
+    *s = core;
+    suff
+}
+
+pub fn test_tokenizers() {
+    for t in tokenizers() {
+        let t = find_tokenizer(t.name).unwrap();
+        println!("tokenizer: {} {}", t.hf_model, t.vocab_size);
+    }
+}
+
+pub fn find_tokenizer(mut name: &str) -> Result<ByteTokenizer> {
+    if !name.contains("/") {
+        for t in tokenizers() {
+            if t.name == name {
+                name = t.hf_model;
+                break;
+            }
+        }
+    }
+
+    log::info!("loading tokenizer: {}", name);
+
+    let mut name2 = name.to_string();
+    let mut args = FromPretrainedParameters::default();
+
+    match strip_suffix("@", &mut name2) {
+        Some(s) => args.revision = s,
+        None => {}
+    }
+
+    match Tokenizer::from_pretrained(name2, Some(args)) {
+        Err(e) => {
+            let msg = format!("can't load tokenizer {}: {}", name, e);
+            println!("{}\n{}", msg, list_tokenizers());
+            return Err(anyhow!("{}", msg));
+        }
+        Ok(t) => {
+            let bt = ByteTokenizer::from_tokenizer(t)?;
+            Ok(bt)
+        }
+    }
+}
+
+impl ByteTokenizer {
+    pub fn from_tokenizer(mut hft: Tokenizer) -> Result<ByteTokenizer> {
+        let mut is_byte_level = false;
+        let mut is_byte_fallback = false;
+        let mut space_ch = ' ';
+
+        // remove the "Prepend space"
+        if let Some(n) = hft.get_normalizer() {
+            let n = match n {
+                NormalizerWrapper::Sequence(x) => NormalizerWrapper::Sequence(Sequence::new(
+                    x.get_normalizers()
+                        .iter()
+                        .filter_map(|n| match n {
+                            NormalizerWrapper::Prepend(_) => None,
+                            _ => Some(n.clone()),
+                        })
+                        .collect(),
+                )),
+                _ => n.clone(),
+            };
+            hft.with_normalizer(n);
+        }
+
+        if let Some(d) = hft.get_decoder() {
+            // DecoderWrapper::Sequence() doesn't let one access the decoders
+            // so we resort to json munching
+            let v = serde_json::to_value(d).unwrap();
+            if v["type"].as_str() == Some("ByteLevel") {
+                is_byte_level = true;
+            } else if v["type"].as_str() == Some("Sequence") {
+                if let Some(decoders) = v["decoders"].as_array() {
+                    for decoder in decoders {
+                        if decoder["type"].as_str() == Some("ByteFallback") {
+                            is_byte_fallback = true;
+                        } else if decoder["type"].as_str() == Some("Replace")
+                            && decoder["content"].as_str() == Some(" ")
+                        {
+                            if let Some(s) = decoder["pattern"]["String"].as_str() {
+                                let s: Vec<char> = s.chars().collect();
+                                if s.len() == 1 {
+                                    space_ch = s[0];
+                                }
+                            }
+                        }
+                    }
+                }
+            }
+        }
+
+        if !is_byte_fallback && !is_byte_level {
+            bail!("can't determine decoder type: {:?}", hft.get_decoder());
+        }
+
+        let vocab_size = hft.get_vocab_size(true) as u32;
+        let added = hft.get_added_tokens_decoder();
+
+        let mut res = ByteTokenizer {
+            hf_model: "foobar".to_string(),
+            eos_token: 0,
+            vocab_size,
+            special: BTreeMap::new(),
+            token_bytes: (0..vocab_size).map(|_| Vec::new()).collect(),
+            hf_tokenizer: hft,
+        };
+
+        for (id, info) in added.iter() {
+            if info.special {
+                match info.content.as_str() {
+                    "</s>" | "<|endoftext|>" => res.eos_token = *id,
+                    _ => {}
+                }
+                res.special.insert(info.content.clone(), *id);
+            } else {
+                res.token_bytes[*id as usize] = info.content.clone().into_bytes();
+            }
+        }
+
+        let char_map = build_char_map();
+
+        for tok_id in 0..vocab_size {
+            if added.contains_key(&tok_id) {
+                continue;
+            }
+            if let Some(tok_name) = res.hf_tokenizer.id_to_token(tok_id) {
+                if is_byte_fallback {
+                    if tok_name.len() == 6 && tok_name.starts_with("<0x") && tok_name.ends_with(">")
+                    {
+                        // parse hex number from tok_name
+                        let hex_str = &tok_name[3..5];
+                        let byte = u8::from_str_radix(hex_str, 16).unwrap();
+                        res.token_bytes[tok_id as usize] = vec![byte];
+                    } else {
+                        assert!(!tok_name.starts_with("<0x"));
+                        let tok_name = tok_name.replace(space_ch, " ");
+                        res.token_bytes[tok_id as usize] = tok_name.as_bytes().to_vec();
+                    }
+                } else if is_byte_level {
+                    let bytes: Result<Vec<u8>> = tok_name
+                        .chars()
+                        .map(|c| {
+                            char_map
+                                .get(&c)
+                                .map(|c| *c)
+                                .ok_or_else(|| anyhow!("missing char: {}", c))
+                        })
+                        .collect();
+                    let bytes = match bytes {
+                        Ok(b) => b,
+                        Err(e) => {
+                            println!("error: {} for {:?}", e, tok_name);
+                            continue;
+                        }
+                    };
+
+                    res.token_bytes[tok_id as usize] = bytes;
+                } else {
+                    panic!();
+                }
+            } else {
+                log::warn!("missing token: {}", tok_id);
+            }
+        }
+
+        Ok(res)
+    }
+}
+
+impl ByteTokenizer {
+    pub fn tokrx_info(&self) -> TokRxInfo {
+        TokRxInfo {
+            vocab_size: self.vocab_size,
+            tok_eos: self.eos_token,
+        }
+    }
+    pub fn token_bytes(&self) -> Vec<Vec<u8>> {
+        self.token_bytes.clone()
+    }
+}
diff --git a/controllers/aici_native/src/lib.rs b/controllers/aici_native/src/lib.rs
new file mode 100644
index 00000000..8fd3fa2c
--- /dev/null
+++ b/controllers/aici_native/src/lib.rs
@@ -0,0 +1,8 @@
+pub mod bintokens;
+mod log;
+pub mod variables;
+
+pub use log::*;
+
+pub use fxhash::FxHashMap as HashMap;
+pub use fxhash::FxHashSet as HashSet;
diff --git a/controllers/aici_native/src/log.rs b/controllers/aici_native/src/log.rs
new file mode 100644
index 00000000..c8c6f8ad
--- /dev/null
+++ b/controllers/aici_native/src/log.rs
@@ -0,0 +1,100 @@
+use std::fmt::Write;
+
+use anyhow::Result;
+use flexi_logger::style;
+use flexi_logger::{DeferredNow, Logger, WriteMode};
+use log::Record;
+
+pub enum LogMode {
+    Normal,
+    Test,
+    Daemon,
+}
+
+struct LimitedWrite {
+    limit: usize,
+    dst: Vec<u8>,
+}
+
+impl Write for LimitedWrite {
+    fn write_str(&mut self, s: &str) -> std::fmt::Result {
+        if self.dst.len() > self.limit {
+            return Err(std::fmt::Error);
+        }
+        if self.dst.len() + s.len() < self.limit {
+            self.dst.extend_from_slice(s.as_bytes());
+            Ok(())
+        } else {
+            let remaining = self.limit - self.dst.len();
+            self.dst.extend_from_slice(&s.as_bytes()[..remaining]);
+            self.dst.extend_from_slice(b" (...)");
+            Err(std::fmt::Error)
+        }
+    }
+}
+
+fn args_to_str(limit: usize, args: &std::fmt::Arguments) -> String {
+    // let capacity = args.estimated_capacity();
+    let mut output = LimitedWrite {
+        limit,
+        dst: Vec::with_capacity(128),
+    };
+    if output.write_fmt(*args).is_err() {
+        assert!(output.dst.len() > limit);
+    }
+    match String::from_utf8(output.dst) {
+        Ok(s) => s,
+        Err(err) => String::from_utf8_lossy(err.as_bytes()).to_string(),
+    }
+}
+
+fn truncated_format(
+    w: &mut dyn std::io::Write,
+    _now: &mut DeferredNow,
+    record: &Record,
+) -> Result<(), std::io::Error> {
+    let level = record.level();
+    write!(
+        w,
+        "{} [{}] {}",
+        style(level).paint(level.to_string()),
+        record.module_path().unwrap_or(""),
+        style(level).paint(args_to_str(1000, record.args()))
+    )
+}
+
+fn daemon_format(
+    w: &mut dyn std::io::Write,
+    now: &mut DeferredNow,
+    record: &Record,
+) -> Result<(), std::io::Error> {
+    write!(
+        w,
+        "{} {} [{}] {}",
+        now.format("%Y-%m-%d %H:%M:%S%.3f"),
+        record.level(),
+        record.module_path().unwrap_or(""),
+        args_to_str(5000, record.args())
+    )
+}
+
+pub fn init_log(mode: LogMode) -> Result<()> {
+    let logger = match mode {
+        LogMode::Normal => Logger::try_with_env_or_str("info")?
+            .format(truncated_format)
+            .log_to_stdout(),
+        LogMode::Test => {
+            Logger::try_with_env_or_str("debug")?.write_mode(WriteMode::SupportCapture)
+        }
+        LogMode::Daemon => Logger::try_with_env_or_str("info")?
+            .format(daemon_format)
+            .log_to_stdout(),
+    };
+
+    logger.start()?;
+    Ok(())
+}
+
+pub fn setup_log() {
+    init_log(LogMode::Normal).expect("Failed to initialize log")
+}
diff --git a/controllers/aici_native/src/variables.rs b/controllers/aici_native/src/variables.rs
new file mode 100644
index 00000000..825d75f8
--- /dev/null
+++ b/controllers/aici_native/src/variables.rs
@@ -0,0 +1,56 @@
+use aici_abi::{StorageCmd, StorageOp, StorageResp};
+use fxhash::FxHashMap;
+
+#[derive(Default)]
+pub struct Variables {
+    pub variables: FxHashMap<String, (u64, Vec<u8>)>,
+}
+
+impl Variables {
+    pub fn process_cmd(&mut self, cmd: StorageCmd) -> StorageResp {
+        match cmd {
+            StorageCmd::ReadVar { name } => match self.variables.get(&name).map(|x| x.clone()) {
+                None => StorageResp::VariableMissing {},
+                Some((version, value)) => StorageResp::ReadVar { value, version },
+            },
+            StorageCmd::WriteVar {
+                name,
+                value,
+                when_version_is,
+                op,
+            } => {
+                let curr = self.variables.get(&name).map(|x| x.clone());
+                match curr {
+                    Some((prev_version, prev_val)) => match when_version_is {
+                        Some(v) if v != prev_version => StorageResp::ReadVar {
+                            version: prev_version,
+                            value: prev_val,
+                        },
+                        _ => {
+                            let value = match op {
+                                StorageOp::Append => {
+                                    let mut v = prev_val.clone();
+                                    v.extend(value);
+                                    v
+                                }
+                                StorageOp::Set => value,
+                            };
+                            let version = prev_version + 1;
+                            self.variables.insert(name, (version, value));
+                            StorageResp::WriteVar { version }
+                        }
+                    },
+
+                    None => match when_version_is {
+                        None => {
+                            self.variables.insert(name, (1, value));
+                            StorageResp::WriteVar { version: 1 }
+                        }
+                        Some(_) => StorageResp::VariableMissing {},
+                    },
+                }
+            }
+        }
+    }
+}

From c754b6b06d5a6532cc2dd437c437a4e9a470697f Mon Sep 17 00:00:00 2001
From: Michal Moskal
Date: Fri, 5 Apr 2024 18:54:25 +0000
Subject: [PATCH 193/301] remove unused deps
 use rustc-hash not fxhash everywhere

---
 controllers/aici_abi/Cargo.toml          | 4 +---
 controllers/aici_native/Cargo.toml       | 3 ---
 controllers/aici_native/src/bintokens.rs | 2 +-
 controllers/aici_native/src/lib.rs       | 4 ++--
 controllers/aici_native/src/variables.rs | 2 +-
 5 files changed, 5 insertions(+), 10 deletions(-)

diff --git a/controllers/aici_abi/Cargo.toml b/controllers/aici_abi/Cargo.toml
index 638748d8..cd30bc49 100644
--- a/controllers/aici_abi/Cargo.toml
+++ b/controllers/aici_abi/Cargo.toml
@@ -19,9 +19,6 @@ lrtable = { version = "0.13.3", optional = true }
 vob = { version = "3.0.3", optional = true }
 rustc-hash = { version = "1.1.0", optional = true }
 
-[target.'cfg(not(target_arch = "wasm32"))'.dependencies]
-quick-protobuf = { version = "0.8.1" }
-
 [features]
 default = ["cfg", "rx"]
 cfg = ["dep:cfgrammar", "dep:lrtable", "dep:vob", "dep:rustc-hash"]
 rx = ["dep:regex-automata"]
diff --git a/controllers/aici_native/Cargo.toml b/controllers/aici_native/Cargo.toml
index 128d8f7e..9dec1862 100644
--- a/controllers/aici_native/Cargo.toml
+++ b/controllers/aici_native/Cargo.toml
@@ -8,13 +8,10 @@ name = "aici_native"
 
 [dependencies]
 aici_abi = { path = "../aici_abi" }
-vob = { version = "3.0.3", optional = true }
 serde = { version = "1.0.192", features = ["derive"] }
 serde_json = "1.0.108"
 anyhow = "1.0.75"
 rustc-hash = "1.1.0"
-base64 = "0.22.0"
-fxhash = "0.2.1"
 tokenizers = { version = "0.15.0", features = ["http"] }
 log = "0.4.21"
 flexi_logger = "0.28.0"
diff --git a/controllers/aici_native/src/bintokens.rs b/controllers/aici_native/src/bintokens.rs
index 910059e2..31a4308e 100644
--- a/controllers/aici_native/src/bintokens.rs
+++ b/controllers/aici_native/src/bintokens.rs
@@ -1,6 +1,6 @@
 use aici_abi::bytes::TokRxInfo;
 use anyhow::{anyhow, bail, Result};
-use fxhash::FxHashMap;
+use rustc_hash::FxHashMap;
 use serde::{Deserialize, Serialize};
 use std::collections::BTreeMap;
 use tokenizers::{normalizers::Sequence, FromPretrainedParameters, NormalizerWrapper, Tokenizer};
diff --git a/controllers/aici_native/src/lib.rs b/controllers/aici_native/src/lib.rs
index 8fd3fa2c..6bdf9d43 100644
--- a/controllers/aici_native/src/lib.rs
+++ b/controllers/aici_native/src/lib.rs
@@ -4,5 +4,5 @@ pub mod variables;
 
 pub use log::*;
 
-pub use fxhash::FxHashMap as HashMap;
-pub use fxhash::FxHashSet as HashSet;
+pub use rustc_hash::FxHashMap as HashMap;
+pub use rustc_hash::FxHashSet as HashSet;
diff --git a/controllers/aici_native/src/variables.rs b/controllers/aici_native/src/variables.rs
index 825d75f8..4f0dc07b 100644
--- a/controllers/aici_native/src/variables.rs
+++ b/controllers/aici_native/src/variables.rs
@@ -1,5 +1,5 @@
 use aici_abi::{StorageCmd, StorageOp, StorageResp};
-use fxhash::FxHashMap;
+use rustc_hash::FxHashMap;
 
 #[derive(Default)]
 pub struct Variables {

From 2d2ee035134a2f6e661cbdb887429bc9ab9b0530 Mon Sep 17 00:00:00 2001
From: Michal Moskal
Date: Fri, 5 Apr 2024 23:56:28 +0000
Subject: [PATCH 194/301] add TokenizerEnv

---
use serde::{Deserialize, Serialize}; @@ -58,20 +59,60 @@ fn read_blob(blob: BlobId, prefetch_size: usize) -> Vec { buffer } +#[cfg(target_arch = "wasm32")] fn init_panic() { - #[cfg(target_arch = "wasm32")] std::panic::set_hook(Box::new(|info| { // skip 'run with `RUST_BACKTRACE=1`' message (not relevant for remote running) println!("{}", info); })) } +#[cfg(target_arch = "wasm32")] #[no_mangle] pub extern "C" fn aici_init() { init_panic(); set_host(Box::new(WasmHost {})); } +pub trait TokenizerEnv: Send { + fn stop(&self) -> !; + fn tok_trie(&self) -> &TokTrie; + fn tokenize_bytes(&self, s: &[u8]) -> Vec; + + fn tokenize(&self, s: &str) -> Vec { + self.tokenize_bytes(s.as_bytes()) + } + fn eos_token(&self) -> TokenId { + self.tok_trie().eos_token() + } +} + +pub struct WasmTokenizerEnv { + toktrie: TokTrie, +} + +impl Default for WasmTokenizerEnv { + fn default() -> Self { + WasmTokenizerEnv { + toktrie: TokTrie::from_bytes(&trie_bytes()), + } + } +} + +impl TokenizerEnv for WasmTokenizerEnv { + fn stop(&self) -> ! { + aici_stop() + } + + fn tok_trie(&self) -> &TokTrie { + &self.toktrie + } + + fn tokenize_bytes(&self, s: &[u8]) -> Vec { + tokenize_bytes(s) + } +} + /** * This is normally implemented straightforwardly by wasm callbacks. * It can be overridden with set_host() when compiling to native. diff --git a/controllers/aici_abi/src/lib.rs b/controllers/aici_abi/src/lib.rs index abc8d269..6be64b5d 100644 --- a/controllers/aici_abi/src/lib.rs +++ b/controllers/aici_abi/src/lib.rs @@ -22,7 +22,7 @@ pub type TokenId = bytes::TokenId; pub use host::{ aici_stop, arg_bytes, arg_string, self_seq_id, tokenize, tokenize_bytes, StorageCmd, StorageOp, - StorageResp, VariableStorage, + StorageResp, TokenizerEnv, VariableStorage, WasmTokenizerEnv, }; #[cfg(not(target_arch = "wasm32"))] diff --git a/controllers/aici_abi/src/toktree.rs b/controllers/aici_abi/src/toktree.rs index 76904d5f..bb274f18 100644 --- a/controllers/aici_abi/src/toktree.rs +++ b/controllers/aici_abi/src/toktree.rs @@ -52,6 +52,7 @@ pub trait Recognizer { fn try_push_byte(&mut self, byte: u8) -> bool; } +#[derive(Clone)] pub struct TokTrie { info: TokRxInfo, token_offsets: Vec, @@ -76,6 +77,7 @@ impl TokTrieHeader { const MAGIC: u32 = 0x558b6fd3; } +#[derive(Clone)] #[repr(C)] pub struct TrieNode { // byte:token diff --git a/controllers/aici_native/src/bintokens.rs b/controllers/aici_native/src/bintokens.rs index 31a4308e..691a0551 100644 --- a/controllers/aici_native/src/bintokens.rs +++ b/controllers/aici_native/src/bintokens.rs @@ -1,4 +1,4 @@ -use aici_abi::bytes::TokRxInfo; +use aici_abi::{bytes::TokRxInfo, toktree::TokTrie, TokenId, TokenizerEnv}; use anyhow::{anyhow, bail, Result}; use rustc_hash::FxHashMap; use serde::{Deserialize, Serialize}; @@ -324,3 +324,43 @@ impl ByteTokenizer { self.token_bytes.clone() } } + +pub struct ByteTokenizerEnv { + pub tokenizer: ByteTokenizer, + tok_trie: TokTrie, +} + +impl ByteTokenizerEnv { + pub fn load(tokenizer_name: &str) -> Result { + let tokenizer = find_tokenizer(tokenizer_name)?; + Ok(Self::new(tokenizer)) + } + pub fn new(tokenizer: ByteTokenizer) -> ByteTokenizerEnv { + let tok_trie = TokTrie::from(&tokenizer.tokrx_info(), &tokenizer.token_bytes()); + ByteTokenizerEnv { + tokenizer, + tok_trie, + } + } +} + +impl TokenizerEnv for ByteTokenizerEnv { + fn stop(&self) -> ! 
{ + panic!("stop called") + } + + fn tok_trie(&self) -> &TokTrie { + &self.tok_trie + } + + fn tokenize_bytes(&self, s: &[u8]) -> Vec { + let tokens = self + .tokenizer + .hf_tokenizer + .encode(String::from_utf8_lossy(s), false); + match tokens { + Err(e) => panic!("tokenize error: {e}"), + Ok(tokens) => Vec::from(tokens.get_ids()), + } + } +} From 462b68e81a1d44b169779c0283d39e9d0a5e8eb3 Mon Sep 17 00:00:00 2001 From: Michal Moskal Date: Sat, 6 Apr 2024 01:07:22 +0000 Subject: [PATCH 195/301] minor --- controllers/aici_abi/src/lib.rs | 7 ++++++- controllers/aici_native/src/bintokens.rs | 2 +- 2 files changed, 7 insertions(+), 2 deletions(-) diff --git a/controllers/aici_abi/src/lib.rs b/controllers/aici_abi/src/lib.rs index 6be64b5d..46efe891 100644 --- a/controllers/aici_abi/src/lib.rs +++ b/controllers/aici_abi/src/lib.rs @@ -59,7 +59,12 @@ impl MidProcessArg { } pub fn save_tokens(&self, acc_tokens: &mut Vec) { - acc_tokens.truncate(acc_tokens.len() - self.backtrack as usize); + let bt = self.backtrack as usize; + assert!( + bt <= acc_tokens.len(), + "attempting to backtrack past beginning" + ); + acc_tokens.truncate(acc_tokens.len() - bt); acc_tokens.extend_from_slice(&self.tokens); } } diff --git a/controllers/aici_native/src/bintokens.rs b/controllers/aici_native/src/bintokens.rs index 691a0551..660d17af 100644 --- a/controllers/aici_native/src/bintokens.rs +++ b/controllers/aici_native/src/bintokens.rs @@ -327,7 +327,7 @@ impl ByteTokenizer { pub struct ByteTokenizerEnv { pub tokenizer: ByteTokenizer, - tok_trie: TokTrie, + pub tok_trie: TokTrie, } impl ByteTokenizerEnv { From baf35d90e28d674a95c996636c684432f04de958 Mon Sep 17 00:00:00 2001 From: Michal Moskal Date: Mon, 8 Apr 2024 20:07:15 +0000 Subject: [PATCH 196/301] improvements to byte forcing --- controllers/aici_abi/src/toktree.rs | 55 ++++++++++++++++++++++++++--- 1 file changed, 51 insertions(+), 4 deletions(-) diff --git a/controllers/aici_abi/src/toktree.rs b/controllers/aici_abi/src/toktree.rs index bb274f18..91c971e1 100644 --- a/controllers/aici_abi/src/toktree.rs +++ b/controllers/aici_abi/src/toktree.rs @@ -562,6 +562,48 @@ impl TokTrie { ok } + /// Check if add_bias() would have returned any tokens. 
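+    /// (It walks the trie exactly like add_bias(), but returns as soon as the
+    /// first accepted token is found, so it is cheap when almost anything is
+    /// allowed.)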
+ #[inline(never)] + pub fn has_valid_extensions(&self, r: &mut impl Recognizer, start: &[u8]) -> bool { + let n = self.child_at_bytes(self.root(), start); + if n.is_none() { + return false; + } + let n = n.unwrap(); + r.trie_started(); + let off = self.node_offset(n); + let mut p = off + 1; + let endp = off + n.subtree_size(); + let mut ok = false; + let mut next_pop = 0; + while p < endp { + r.pop_bytes(next_pop); + let n = &self.nodes[p]; + let b = n.byte(); + if r.try_push_byte(b) { + if n.token_id().is_some() { + ok = true; + break; + } + next_pop = if n.subtree_size() == 1 { + n.num_parents() + } else { + 0 + }; + p += 1; + } else { + p += n.subtree_size(); + next_pop = n.num_parents() - 1; + } + } + if start.len() == 0 { + // if start was non-empty, trie_finished() is supposed to clean this up + r.pop_bytes(next_pop); + } + r.trie_finished(); + ok + } + #[inline(never)] pub fn add_bias(&self, r: &mut impl Recognizer, toks: &mut SimpleVob, start: &[u8]) { r.trie_started(); @@ -570,23 +612,28 @@ impl TokTrie { let off = self.node_offset(n); let mut p = off + 1; let endp = off + n.subtree_size(); + let mut next_pop = 0; while p < endp { + r.pop_bytes(next_pop); let n = &self.nodes[p]; let b = n.byte(); if r.try_push_byte(b) { toks.allow_token(n.token_id().unwrap_or(defl_tok)); - r.pop_bytes(if n.subtree_size() == 1 { + next_pop = if n.subtree_size() == 1 { n.num_parents() } else { 0 - }); - + }; p += 1; } else { p += n.subtree_size(); - r.pop_bytes(n.num_parents() - 1); + next_pop = n.num_parents() - 1; } } + if start.len() == 0 { + // if start was non-empty, trie_finished() is supposed to clean this up + r.pop_bytes(next_pop); + } r.trie_finished(); // revert the fake token toks.disallow_token(defl_tok); From 1e080ef73fe80db0416b281df759851b5e27d121 Mon Sep 17 00:00:00 2001 From: Michal Moskal Date: Mon, 8 Apr 2024 22:09:54 +0000 Subject: [PATCH 197/301] fix comment --- controllers/aici_abi/src/toktree.rs | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/controllers/aici_abi/src/toktree.rs b/controllers/aici_abi/src/toktree.rs index 91c971e1..5b8923fd 100644 --- a/controllers/aici_abi/src/toktree.rs +++ b/controllers/aici_abi/src/toktree.rs @@ -44,7 +44,8 @@ pub trait Recognizer { /// check if stack.top() transitions via tok to a viable state fn special_allowed(&mut self, tok: SpecialToken) -> bool; /// Called when iteration over the trie is finished - /// Stack has exactly one element then. + /// Stack has exactly one element then, except when iteration started from non-root node. + /// In that case, the stack may have more than one element, and trie_finished() needs to pop the excessive elements. fn trie_finished(&mut self); /// Called when iteration over the trie is started fn trie_started(&mut self) {} From 7a6785de7c3546480fb2334027ef8cc1d8fe89a3 Mon Sep 17 00:00:00 2001 From: Michal Moskal Date: Thu, 11 Apr 2024 21:36:08 +0000 Subject: [PATCH 198/301] updating python code to deal with no pre/post --- controllers/aici_abi/src/lib.rs | 1 + 1 file changed, 1 insertion(+) diff --git a/controllers/aici_abi/src/lib.rs b/controllers/aici_abi/src/lib.rs index 46efe891..a50879ab 100644 --- a/controllers/aici_abi/src/lib.rs +++ b/controllers/aici_abi/src/lib.rs @@ -173,6 +173,7 @@ impl MidProcessResult { #[derive(Serialize, Deserialize)] pub struct ProcessResultOffset { + /// Branches use byte offsets into the bias tensor. 
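+    /// (When crossing the WASM boundary, each branch's `sample_mask` is
+    /// represented by such an offset instead of the full token set.)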
 pub branches: Vec<Branch<usize>>,
 }

From fb3fd188ad5d5bc7c7815dc1e40cd558e7a35469 Mon Sep 17 00:00:00 2001
From: Michal Moskal
Date: Fri, 12 Apr 2024 18:16:13 +0000
Subject: [PATCH 199/301] update vllm

---
 controllers/aici_abi/src/lib.rs | 4 ++++
 1 file changed, 4 insertions(+)

diff --git a/controllers/aici_abi/src/lib.rs b/controllers/aici_abi/src/lib.rs
index a50879ab..cda39e4a 100644
--- a/controllers/aici_abi/src/lib.rs
+++ b/controllers/aici_abi/src/lib.rs
@@ -169,6 +169,10 @@ impl MidProcessResult {
             }],
         }
     }
+
+    pub fn noop() -> Self {
+        Self::splice(0, vec![])
+    }
 }

From b483e5752d0ad7cc874f3ac73ae91672a6dc643b Mon Sep 17 00:00:00 2001
From: Michal Moskal
Date: Fri, 12 Apr 2024 21:31:46 +0000
Subject: [PATCH 200/301] add get_config() host API (fork support detection so far)

---
 controllers/aici_abi/src/host.rs | 14 ++++++++++++++
 controllers/aici_abi/src/lib.rs  |  4 ++--
 2 files changed, 16 insertions(+), 2 deletions(-)

diff --git a/controllers/aici_abi/src/host.rs b/controllers/aici_abi/src/host.rs
index cc8f69de..07ac97f1 100644
--- a/controllers/aici_abi/src/host.rs
+++ b/controllers/aici_abi/src/host.rs
@@ -41,6 +41,9 @@ extern "C" {
     // This can be also obtained from the TokTrie.
fn aici_host_eos_token() -> TokenId; - // Get value of configuration parameters, like "forks". + // Get value of configuration parameters, like "fork". fn aici_host_get_config(src: *const u8, src_size: u32) -> i32; // Stop the program - any error info is assumed to have been printed already. From ed83f30dd45170fec39bba3a54597d439d48ffb1 Mon Sep 17 00:00:00 2001 From: Michal Moskal Date: Fri, 12 Apr 2024 23:23:48 +0000 Subject: [PATCH 202/301] fix declctrl --- controllers/aici_abi/src/lib.rs | 24 ++++++++++++++++-------- 1 file changed, 16 insertions(+), 8 deletions(-) diff --git a/controllers/aici_abi/src/lib.rs b/controllers/aici_abi/src/lib.rs index 9995152b..87864381 100644 --- a/controllers/aici_abi/src/lib.rs +++ b/controllers/aici_abi/src/lib.rs @@ -132,6 +132,21 @@ impl Branch { splices: self.splices.clone(), } } + + pub fn splice(backtrack: u32, ff_tokens: Vec) -> Self { + Branch { + sample_mask: None, + splices: vec![Splice { + when_sampled: vec![], + backtrack, + ff_tokens, + }], + } + } + + pub fn noop() -> Self { + Self::splice(0, vec![]) + } } #[derive(Debug)] @@ -159,14 +174,7 @@ impl MidProcessResult { pub fn splice(backtrack: u32, ff_tokens: Vec) -> Self { MidProcessResult { - branches: vec![Branch { - sample_mask: None, - splices: vec![Splice { - when_sampled: vec![], - backtrack, - ff_tokens, - }], - }], + branches: vec![Branch::splice(backtrack, ff_tokens)], } } From 371706835ed63a92ad48da091afebe7119d18407 Mon Sep 17 00:00:00 2001 From: Michal Moskal Date: Tue, 16 Apr 2024 02:04:01 +0000 Subject: [PATCH 203/301] allow local tokenizers --- controllers/aici_native/src/bintokens.rs | 24 +++++++++++++++--------- 1 file changed, 15 insertions(+), 9 deletions(-) diff --git a/controllers/aici_native/src/bintokens.rs b/controllers/aici_native/src/bintokens.rs index 660d17af..34090bdb 100644 --- a/controllers/aici_native/src/bintokens.rs +++ b/controllers/aici_native/src/bintokens.rs @@ -113,13 +113,14 @@ fn build_char_map() -> FxHashMap { pub fn list_tokenizers() -> String { format!( - "Available tokenizers for -t or --tokenizer:\n{}\n{}", + "Available tokenizers for -t or --tokenizer:\n{}\n{}\n{}", tokenizers() .iter() .map(|t| format!(" -t {:16} {}", t.name, t.description)) .collect::>() .join("\n"), - "You can also use a HuggingFace model name, in format 'user/modelname'." + "You can also use a HuggingFace model name, in format 'user/modelname',", + "or a local file in format './path/to/tokenizer.json'." 
     )
 }
@@ -165,15 +166,20 @@ pub fn find_tokenizer(mut name: &str) -> Result<ByteTokenizer> {
 
     log::info!("loading tokenizer: {}", name);
 
-    let mut name2 = name.to_string();
-    let mut args = FromPretrainedParameters::default();
+    let loaded = if name.starts_with(".") {
+        Tokenizer::from_file(name)
+    } else {
+        let mut name2 = name.to_string();
+        let mut args = FromPretrainedParameters::default();
 
-    match strip_suffix("@", &mut name2) {
-        Some(s) => args.revision = s,
-        None => {}
-    }
+        match strip_suffix("@", &mut name2) {
+            Some(s) => args.revision = s,
+            None => {}
+        }
+        Tokenizer::from_pretrained(name2, Some(args))
+    };
 
-    match Tokenizer::from_pretrained(name2, Some(args)) {
+    match loaded {
         Err(e) => {
             let msg = format!("can't load tokenizer {}: {}", name, e);
             println!("{}\n{}", msg, list_tokenizers());

From 222ca5a68f3da69c6e50b5071568866f878cd834 Mon Sep 17 00:00:00 2001
From: Michal Moskal
Date: Tue, 16 Apr 2024 23:44:07 +0000
Subject: [PATCH 204/301] add aicirt --logits-size

---
 controllers/aici_native/src/bintokens.rs | 15 +++++++++++++--
 1 file changed, 13 insertions(+), 2 deletions(-)

diff --git a/controllers/aici_native/src/bintokens.rs b/controllers/aici_native/src/bintokens.rs
index 34090bdb..95e06434 100644
--- a/controllers/aici_native/src/bintokens.rs
+++ b/controllers/aici_native/src/bintokens.rs
@@ -317,9 +317,7 @@ impl ByteTokenizer {
 
         Ok(res)
     }
-}
 
-impl ByteTokenizer {
     pub fn tokrx_info(&self) -> TokRxInfo {
         TokRxInfo {
             vocab_size: self.vocab_size,
@@ -329,6 +327,19 @@ impl ByteTokenizer {
     pub fn token_bytes(&self) -> Vec<Vec<u8>> {
         self.token_bytes.clone()
     }
+
+    pub fn add_missing_tokens(&mut self, vocab_size: usize) {
+        assert!(self.vocab_size == self.token_bytes.len() as u32);
+        assert!(vocab_size >= self.token_bytes.len());
+        assert!(vocab_size - self.token_bytes.len() <= 200);
+        while self.token_bytes.len() < vocab_size {
+            let idx = self.token_bytes.len();
+            let name = format!("<AddedToken_{idx}>");
+            self.token_bytes.push(name.as_bytes().to_vec());
+            self.vocab_size += 1;
+            self.special.insert(name, idx as u32);
+        }
+    }
 }
 
 pub struct ByteTokenizerEnv {

From 383bc4b6e7e4df3c1deb5f87da9badfe76b56c04 Mon Sep 17 00:00:00 2001
From: Michal Moskal
Date: Fri, 19 Apr 2024 00:02:29 +0000
Subject: [PATCH 205/301] add and use --bias-dtype

---
 controllers/aici_abi/src/host.rs | 12 ++++++------
 controllers/aici_abi/src/lib.rs  |  3 +--
 2 files changed, 7 insertions(+), 8 deletions(-)

diff --git a/controllers/aici_abi/src/host.rs b/controllers/aici_abi/src/host.rs
index 353e252e..d71dbfb8 100644
--- a/controllers/aici_abi/src/host.rs
+++ b/controllers/aici_abi/src/host.rs
@@ -30,7 +30,7 @@ extern "C" {
     fn aici_host_tokenize(src: *const u8, src_size: u32) -> BlobId;
 
     // Set logit bias based on bit-mask in src.
- fn aici_host_return_logit_bias(src: *const u32); + fn aici_host_return_logit_bias(src: *const u32) -> u32; fn aici_host_self_seq_id() -> u32; @@ -123,7 +123,7 @@ impl TokenizerEnv for WasmTokenizerEnv { pub trait HostInterface { fn arg_bytes(&self) -> Vec; fn trie_bytes(&self) -> Vec; - fn return_logit_bias(&self, vob: &SimpleVob); + fn return_logit_bias(&self, vob: &SimpleVob) -> u32; fn process_arg_bytes(&self) -> Vec; fn return_process_result(&self, res: &[u8]); fn storage_cmd(&self, cmd: StorageCmd) -> StorageResp; @@ -146,10 +146,10 @@ impl HostInterface for WasmHost { read_blob(unsafe { aici_host_token_trie() }, 0) } - fn return_logit_bias(&self, vob: &SimpleVob) { + fn return_logit_bias(&self, vob: &SimpleVob) -> u32 { assert!(vob.len() > 0); unsafe { - aici_host_return_logit_bias(vob.as_ptr()); + aici_host_return_logit_bias(vob.as_ptr()) } } @@ -230,8 +230,8 @@ pub fn trie_bytes() -> Vec { // return std::fs::read("tokenizer.bin").unwrap(); } -pub fn return_logit_bias(vob: &SimpleVob) { - get_host().return_logit_bias(vob); +pub fn return_logit_bias(vob: &SimpleVob) -> u32 { + get_host().return_logit_bias(vob) } pub fn process_arg_bytes() -> Vec { diff --git a/controllers/aici_abi/src/lib.rs b/controllers/aici_abi/src/lib.rs index 87864381..e96ff8bb 100644 --- a/controllers/aici_abi/src/lib.rs +++ b/controllers/aici_abi/src/lib.rs @@ -223,8 +223,7 @@ pub trait AiciCtrl { panic!("aici_mid_process: multiple branches with sampling not yet supported"); } used_logits = true; - host::return_logit_bias(&vob); - 0 + host::return_logit_bias(&vob) as usize }) }) .collect(), From 5d41a54bd1556cad6755839f198e0289a97a9df5 Mon Sep 17 00:00:00 2001 From: Michal Moskal Date: Sat, 20 Apr 2024 00:01:50 +0000 Subject: [PATCH 206/301] allow for prompt-rewriting in aici_init_prompt --- controllers/aici_abi/src/lib.rs | 17 ++++++++++++----- 1 file changed, 12 insertions(+), 5 deletions(-) diff --git a/controllers/aici_abi/src/lib.rs b/controllers/aici_abi/src/lib.rs index e96ff8bb..40df1c3b 100644 --- a/controllers/aici_abi/src/lib.rs +++ b/controllers/aici_abi/src/lib.rs @@ -33,8 +33,16 @@ pub struct InitPromptArg { pub prompt: Vec, } -#[derive(Serialize, Deserialize, Debug, Default)] -pub struct InitPromptResult {} +#[derive(Serialize, Deserialize, Debug)] +pub struct InitPromptResult { + pub prompt: Vec, +} + +impl InitPromptResult { + pub fn from_arg(arg: InitPromptArg) -> Self { + InitPromptResult { prompt: arg.prompt } + } +} #[repr(transparent)] #[derive(Serialize, Deserialize, Debug, PartialEq, Eq)] @@ -192,9 +200,8 @@ pub struct ProcessResultOffset { pub trait AiciCtrl { /// Called with the initial prompt. ~1000ms time limit. /// By default ignore prompt. - /// This is typically just the start token if any (REST API forces empty prompt). - fn init_prompt(&mut self, _arg: InitPromptArg) -> InitPromptResult { - InitPromptResult::default() + fn init_prompt(&mut self, arg: InitPromptArg) -> InitPromptResult { + InitPromptResult::from_arg(arg) } /// This is the main entry point for the module. ~20ms time limit. 
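A minimal sketch of how a controller might use the new prompt-rewriting hook
(not part of the patch series; `PrefixCtrl` and the prepended token id are
made up for illustration, and `mid_process` is assumed to be the trait's main
entry point as elsewhere in this series):

    use aici_abi::{AiciCtrl, InitPromptArg, InitPromptResult, MidProcessArg, MidProcessResult};

    struct PrefixCtrl; // hypothetical controller

    impl AiciCtrl for PrefixCtrl {
        fn init_prompt(&mut self, arg: InitPromptArg) -> InitPromptResult {
            // Rewrite the prompt: prepend an assumed system-prompt token
            // (the id 32000 is illustrative only).
            let mut prompt = vec![32000u32];
            prompt.extend(arg.prompt);
            InitPromptResult { prompt }
        }

        fn mid_process(&mut self, _arg: MidProcessArg) -> MidProcessResult {
            // Accept whatever the runtime sampled; no constraints in this sketch.
            MidProcessResult::noop()
        }
    }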
From e471aa899ae258c5b8b256987ef5e20d189ba8d4 Mon Sep 17 00:00:00 2001 From: Michal Moskal Date: Tue, 23 Apr 2024 00:47:03 +0000 Subject: [PATCH 207/301] limit re-tokenization --- controllers/aici_abi/src/toktree.rs | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) diff --git a/controllers/aici_abi/src/toktree.rs b/controllers/aici_abi/src/toktree.rs index 5b8923fd..6ec6371a 100644 --- a/controllers/aici_abi/src/toktree.rs +++ b/controllers/aici_abi/src/toktree.rs @@ -563,6 +563,25 @@ impl TokTrie { ok } + /// Return how many tokens and bytes need to chopped off tokens, + /// so that we do not limit all possible future tokenizations matching the recognizer. + pub fn chop_tokens(&self, r: &mut impl Recognizer, tokens: &[TokenId]) -> (usize, usize) { + let mut suff = Vec::new(); + let mut chop_tokens = 0; + let mut chop_bytes = 0; + for (idx, t) in tokens.iter().rev().enumerate() { + suff.splice(0..0, self.token(*t).iter().cloned()); + if suff.len() > self.max_token_len() { + break; + } + if self.has_valid_extensions(r, &suff) { + chop_tokens = idx + 1; + chop_bytes = suff.len(); + } + } + (chop_tokens, chop_bytes) + } + /// Check if add_bias() would have returned any tokens. #[inline(never)] pub fn has_valid_extensions(&self, r: &mut impl Recognizer, start: &[u8]) -> bool { From bc94bd790721ea0f39ad9d744562e8a9068277f0 Mon Sep 17 00:00:00 2001 From: Michal Moskal Date: Tue, 23 Apr 2024 22:20:57 +0000 Subject: [PATCH 208/301] allow longer tokens; fixes #98 --- controllers/aici_abi/src/recognizer.rs | 2 +- controllers/aici_abi/src/toktree.rs | 12 ++++++++---- 2 files changed, 9 insertions(+), 5 deletions(-) diff --git a/controllers/aici_abi/src/recognizer.rs b/controllers/aici_abi/src/recognizer.rs index 7045da6e..588059e8 100644 --- a/controllers/aici_abi/src/recognizer.rs +++ b/controllers/aici_abi/src/recognizer.rs @@ -50,7 +50,7 @@ pub struct StackRecognizer> { impl> StackRecognizer { pub fn from(rec: R) -> Self { - let stack = vec![rec.initial(); 130]; + let stack = vec![rec.initial(); 300]; StackRecognizer { rec, stack, diff --git a/controllers/aici_abi/src/toktree.rs b/controllers/aici_abi/src/toktree.rs index 6ec6371a..c58f1246 100644 --- a/controllers/aici_abi/src/toktree.rs +++ b/controllers/aici_abi/src/toktree.rs @@ -122,6 +122,9 @@ impl TrieNode { } } +// max length of token is 1023 bytes +const LEN_BITS: u32 = 10; + impl TokTrie { pub fn from_host() -> Self { let buffer = trie_bytes(); @@ -137,8 +140,9 @@ impl TokTrie { if word.len() > 0 { trie.insert(word, idx as u32); } - assert!(word.len() < 0xff); - let desc = (word.len() as u32) | ((token_data.len() as u32) << 8); + assert!(word.len() < (1 << LEN_BITS)); + assert!(token_data.len() < (1 << (32 - LEN_BITS))); + let desc = (word.len() as u32) | ((token_data.len() as u32) << LEN_BITS); token_offsets.push(desc); token_data.extend_from_slice(word); } @@ -292,8 +296,8 @@ impl TokTrie { pub fn token(&self, idx: u32) -> &[u8] { let off = self.token_offsets[idx as usize]; - let len = off & 0xff; - let off = (off >> 8) as usize; + let len = off & ((1 << LEN_BITS) - 1); + let off = (off >> LEN_BITS) as usize; &self.token_data[off..(off + len as usize)] } From 2966550f0f3b899318e71db529f290b2cb93679c Mon Sep 17 00:00:00 2001 From: Michal Moskal Date: Wed, 24 Apr 2024 00:06:19 +0000 Subject: [PATCH 209/301] report final text from guidance --- controllers/aici_abi/src/lib.rs | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/controllers/aici_abi/src/lib.rs b/controllers/aici_abi/src/lib.rs index 40df1c3b..5a6e66e3 100644 --- 
a/controllers/aici_abi/src/lib.rs +++ b/controllers/aici_abi/src/lib.rs @@ -189,6 +189,10 @@ impl MidProcessResult { pub fn noop() -> Self { Self::splice(0, vec![]) } + + pub fn is_stop(&self) -> bool { + self.branches.is_empty() + } } #[derive(Serialize, Deserialize)] From 72dd0a1bf1ee39bc3e0910801f0fe8a82f573e40 Mon Sep 17 00:00:00 2001 From: Michal Moskal Date: Wed, 24 Apr 2024 19:59:51 +0000 Subject: [PATCH 210/301] append/byte_allowed -> try_append in FunctionalRecognizer --- controllers/aici_abi/src/recognizer.rs | 47 +++++++++++--------------- controllers/aici_abi/src/rx.rs | 14 ++++---- controllers/aici_abi/src/substring.rs | 21 +++++++----- 3 files changed, 38 insertions(+), 44 deletions(-) diff --git a/controllers/aici_abi/src/recognizer.rs b/controllers/aici_abi/src/recognizer.rs index 588059e8..57eabc09 100644 --- a/controllers/aici_abi/src/recognizer.rs +++ b/controllers/aici_abi/src/recognizer.rs @@ -33,10 +33,8 @@ impl AiciCtrl for AiciRecognizer { pub trait FunctionalRecognizer { /// Initial state fn initial(&self) -> S; - /// Extend the recognizer with given byte. - fn append(&self, state: S, byte: u8) -> S; - /// Check if given byte is allowed in given state. - fn byte_allowed(&self, state: S, byte: u8) -> bool; + /// Extend the recognizer with given byte if allowed. + fn try_append(&self, state: S, byte: u8) -> Option; /// Check if given special token is allowed in given state. fn special_allowed(&self, state: S, tok: SpecialToken) -> bool; } @@ -62,6 +60,14 @@ impl> StackRecognizer { self.stack_ptr = 0; self.stack[0] = self.rec.initial(); } + + pub fn recognizer(&self) -> &R { + &self.rec + } + + pub fn recognizer_mut(&mut self) -> &mut R { + &mut self.rec + } } impl> Debug for StackRecognizer { @@ -73,24 +79,11 @@ impl> Debug for StackRecognizer> Recognizer for StackRecognizer { - #[inline(always)] - fn push_byte(&mut self, byte: u8) { - let state = self.stack[self.stack_ptr]; - let state = self.rec.append(state, byte); - self.stack_ptr += 1; - self.stack[self.stack_ptr] = state; - } - #[inline(always)] fn pop_bytes(&mut self, num: usize) { self.stack_ptr -= num; } - #[inline(always)] - fn byte_allowed(&mut self, byte: u8) -> bool { - self.rec.byte_allowed(self.stack[self.stack_ptr], byte) - } - fn trie_finished(&mut self) { // println!("{:?}", &self.stack[0..=self.stack_ptr]); assert!(self.stack_ptr == 0); @@ -107,11 +100,13 @@ impl> Recognizer for StackRecognizer #[inline(always)] fn try_push_byte(&mut self, byte: u8) -> bool { - if self.rec.byte_allowed(self.stack[self.stack_ptr], byte) { - self.push_byte(byte); - true - } else { - false + match self.rec.try_append(self.stack[self.stack_ptr], byte) { + Some(state) => { + self.stack_ptr += 1; + self.stack[self.stack_ptr] = state; + true + } + None => false, } } } @@ -124,12 +119,8 @@ impl FunctionalRecognizer<()> for AnythingGoes { () } - fn append(&self, state: (), _byte: u8) -> () { - state - } - - fn byte_allowed(&self, _state: (), _byte: u8) -> bool { - true + fn try_append(&self, state: (), _byte: u8) -> Option<()> { + Some(state) } fn special_allowed(&self, _state: (), _tok: SpecialToken) -> bool { diff --git a/controllers/aici_abi/src/rx.rs b/controllers/aici_abi/src/rx.rs index 04ebfbcf..a116e2df 100644 --- a/controllers/aici_abi/src/rx.rs +++ b/controllers/aici_abi/src/rx.rs @@ -45,13 +45,13 @@ impl FunctionalRecognizer for RecRx { } #[inline(always)] - fn append(&self, state: RecRxState, byte: u8) -> RecRxState { - self.dfa.next_state(state, byte) - } - - #[inline(always)] - fn byte_allowed(&self, state: 
RecRxState, byte: u8) -> bool { - !self.dfa.is_dead_state(self.dfa.next_state(state, byte)) + fn try_append(&self, state: RecRxState, byte: u8) -> Option { + let next = self.dfa.next_state(state, byte); + if self.dfa.is_dead_state(next) { + None + } else { + Some(next) + } } #[inline(always)] diff --git a/controllers/aici_abi/src/substring.rs b/controllers/aici_abi/src/substring.rs index d5d262f3..69308631 100644 --- a/controllers/aici_abi/src/substring.rs +++ b/controllers/aici_abi/src/substring.rs @@ -230,15 +230,9 @@ impl SubStrMatcher { SubStrState::SourceOffset(off) => self.append_to_src_off(off, byte), } } -} - -impl FunctionalRecognizer for SubStrMatcher { - fn initial(&self) -> SubStrState { - SubStrState::Node(0) - } #[inline(always)] - fn append(&self, state: SubStrState, byte: u8) -> SubStrState { + fn do_append(&self, state: SubStrState, byte: u8) -> SubStrState { let state = match state { SubStrState::Node(_) | SubStrState::SourceOffset(_) if self.end_str.as_bytes().first() == Some(&byte) @@ -251,10 +245,19 @@ impl FunctionalRecognizer for SubStrMatcher { self.append_inner(state, byte) } +} + +impl FunctionalRecognizer for SubStrMatcher { + fn initial(&self) -> SubStrState { + SubStrState::Node(0) + } #[inline(always)] - fn byte_allowed(&self, state: SubStrState, byte: u8) -> bool { - self.append(state, byte) != SubStrState::Dead + fn try_append(&self, state: SubStrState, byte: u8) -> Option { + match self.do_append(state, byte) { + SubStrState::Dead => None, + state => Some(state), + } } #[inline(always)] From 2e10bcc06288ebd1292b55ccdfeaf62287ee13d3 Mon Sep 17 00:00:00 2001 From: Michal Moskal Date: Wed, 24 Apr 2024 21:08:43 +0000 Subject: [PATCH 211/301] first draft of DynamicLexer --- controllers/aici_abi/src/dlex.rs | 266 +++++++++++++++++++++++++++++++ controllers/aici_abi/src/lib.rs | 2 + 2 files changed, 268 insertions(+) create mode 100644 controllers/aici_abi/src/dlex.rs diff --git a/controllers/aici_abi/src/dlex.rs b/controllers/aici_abi/src/dlex.rs new file mode 100644 index 00000000..02f04313 --- /dev/null +++ b/controllers/aici_abi/src/dlex.rs @@ -0,0 +1,266 @@ +use crate::{ + recognizer::{FunctionalRecognizer, StackRecognizer}, + svob::SimpleVob, + toktree::SpecialToken, +}; + +#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)] +pub struct NodeId(u32); + +impl NodeId { + const NULL: NodeId = NodeId(0); + const ROOT: NodeId = NodeId(1); +} + +#[derive(Debug, Default, Clone)] +pub struct NodeData { + pub is_terminal: bool, +} + +enum TrieNode { + Sparse { + data: NodeData, + children: Vec<(u8, NodeId)>, + }, + Dense { + data: NodeData, + children: Vec, + }, +} + +impl TrieNode { + fn new_dense(data: NodeData, children: &Vec<(u8, NodeId)>) -> Self { + let mut dense_children = vec![NodeId::NULL; 256]; + for (byte, node_id) in children { + dense_children[*byte as usize] = *node_id; + } + TrieNode::Dense { + data, + children: dense_children, + } + } + + fn new_leaf() -> Self { + TrieNode::Sparse { + data: NodeData::default(), + children: vec![], + } + } + + fn data(&self) -> &NodeData { + match self { + TrieNode::Sparse { data, .. } => data, + TrieNode::Dense { data, .. } => data, + } + } + + fn data_mut(&mut self) -> &mut NodeData { + match self { + TrieNode::Sparse { data, .. } => data, + TrieNode::Dense { data, .. 
} => data,
+        }
+    }
+}
+
+pub struct Trie {
+    nodes: Vec<TrieNode>,
+}
+
+impl Trie {
+    const MAX_SPARSE: usize = 8;
+
+    pub fn new() -> Self {
+        Trie {
+            nodes: vec![
+                TrieNode::new_leaf(),
+                TrieNode::new_dense(NodeData::default(), &vec![]),
+            ],
+        }
+    }
+
+    fn node(&self, node_id: NodeId) -> &TrieNode {
+        &self.nodes[node_id.0 as usize]
+    }
+
+    fn node_mut(&mut self, node_id: NodeId) -> &mut TrieNode {
+        &mut self.nodes[node_id.0 as usize]
+    }
+
+    pub fn node_data(&self, node_id: NodeId) -> &NodeData {
+        self.node(node_id).data()
+    }
+
+    pub fn root(&self) -> NodeId {
+        NodeId::ROOT
+    }
+
+    pub fn child_at(&self, start: NodeId, b: u8) -> Option<NodeId> {
+        match self.node(start) {
+            TrieNode::Sparse { children, .. } => {
+                children.iter().find_map(
+                    |&(byte, node_id)| {
+                        if byte == b {
+                            Some(node_id)
+                        } else {
+                            None
+                        }
+                    },
+                )
+            }
+            TrieNode::Dense { children, .. } => {
+                let node_id = children[b as usize];
+                if node_id == NodeId::NULL {
+                    None
+                } else {
+                    Some(node_id)
+                }
+            }
+        }
+    }
+
+    pub fn lookup(&self, start: NodeId, word: &[u8]) -> Option<NodeId> {
+        let mut node_id = start;
+        for &byte in word {
+            match self.child_at(node_id, byte) {
+                Some(child_id) => {
+                    node_id = child_id;
+                }
+                None => {
+                    return None;
+                }
+            }
+        }
+        Some(node_id)
+    }
+
+    pub fn add(&mut self, word: &[u8]) {
+        let mut node_id = NodeId::ROOT;
+        for &byte in word {
+            let new_node_id = NodeId(self.nodes.len() as u32);
+            let node = self.node_mut(node_id);
+            match node {
+                TrieNode::Sparse { data, children } => {
+                    match children.iter().find(|&&(b, _)| b == byte) {
+                        Some(&(_, child_id)) => {
+                            node_id = child_id;
+                        }
+                        None => {
+                            children.push((byte, new_node_id));
+                            if children.len() > Trie::MAX_SPARSE {
+                                self.nodes[node_id.0 as usize] =
+                                    TrieNode::new_dense(data.clone(), children);
+                            }
+                            self.nodes.push(TrieNode::new_leaf());
+                            node_id = new_node_id;
+                        }
+                    }
+                }
+                TrieNode::Dense { children, ..
} => { + node_id = children[byte as usize]; + if node_id == NodeId::NULL { + children[byte as usize] = new_node_id; + self.nodes.push(TrieNode::new_leaf()); + node_id = new_node_id; + } + } + } + } + + self.node_mut(node_id).data_mut().is_terminal = true; + } +} + +pub struct DynamicLexer { + trie: Trie, + id_start: SimpleVob, + id_body: SimpleVob, +} + +#[derive(Debug, Clone, Copy)] +pub struct DState { + node_id: NodeId, +} + +impl DState { + const ROOT: DState = DState { + node_id: NodeId::ROOT, + }; +} + +pub type DynamicLexerRec = StackRecognizer; + +impl DynamicLexer { + pub fn new(additional_id_chars: &Vec) -> Self { + let mut id_start = SimpleVob::alloc(0x100); + let mut id_body = SimpleVob::alloc(0x100); + for i in 0..=255u8 { + match i as char { + 'a'..='z' | 'A'..='Z' | '_' => { + id_start.allow_token(i as u32); + id_body.allow_token(i as u32); + } + '0'..='9' => { + id_body.allow_token(i as u32); + } + _ => {} + } + } + for &c in additional_id_chars { + id_start.allow_token(c as u32); + id_body.allow_token(c as u32); + } + DynamicLexer { + trie: Trie::new(), + id_start, + id_body, + } + } + + pub fn to_stack_recognizer(self) -> StackRecognizer { + StackRecognizer::from(self) + } + + pub fn add(&mut self, word: &[u8]) { + self.trie.add(word); + } +} + +impl FunctionalRecognizer for DynamicLexer { + fn initial(&self) -> DState { + DState::ROOT + } + + fn try_append(&self, state: DState, byte: u8) -> Option { + if state.node_id == NodeId::ROOT { + if self.id_start.is_allowed(byte as u32) { + match self.trie.child_at(state.node_id, byte) { + Some(node_id) => Some(DState { node_id }), + None => None, + } + } else { + Some(state) + } + } else { + if self.id_body.is_allowed(byte as u32) { + match self.trie.child_at(state.node_id, byte) { + Some(node_id) => Some(DState { node_id }), + None => None, + } + } else { + if self.trie.node_data(state.node_id).is_terminal { + Some(DState::ROOT) + } else { + None + } + } + } + } + + fn special_allowed(&self, state: DState, tok: SpecialToken) -> bool { + if tok == SpecialToken::EndOfSentence { + self.trie.node_data(state.node_id).is_terminal + } else { + false + } + } +} diff --git a/controllers/aici_abi/src/lib.rs b/controllers/aici_abi/src/lib.rs index 5a6e66e3..e4ccb2fd 100644 --- a/controllers/aici_abi/src/lib.rs +++ b/controllers/aici_abi/src/lib.rs @@ -16,6 +16,8 @@ mod lex; #[cfg(feature = "rx")] pub mod rx; +pub mod dlex; + pub mod substring; pub type TokenId = bytes::TokenId; From 2be07d8bb5750d38401d1f1cdf38b562d9bca750 Mon Sep 17 00:00:00 2001 From: Michal Moskal Date: Mon, 29 Apr 2024 20:23:15 +0000 Subject: [PATCH 212/301] Use bytes not strings in substr matcher; fixes #100 --- controllers/aici_abi/src/substring.rs | 26 +++++++++++++------------- 1 file changed, 13 insertions(+), 13 deletions(-) diff --git a/controllers/aici_abi/src/substring.rs b/controllers/aici_abi/src/substring.rs index 69308631..55f80eef 100644 --- a/controllers/aici_abi/src/substring.rs +++ b/controllers/aici_abi/src/substring.rs @@ -1,7 +1,7 @@ use std::fmt::Display; use crate::{ - bytes::limit_str, + bytes::limit_bytes, recognizer::{FunctionalRecognizer, StackRecognizer}, toktree::SpecialToken, }; @@ -14,7 +14,7 @@ enum Node { pub struct SubStrMatcher { end_str: String, - source: String, + source: Vec, nodes: Vec, } @@ -52,7 +52,7 @@ impl SubStrMatcher { serde_json::Value::Object(children_json) } Node::Leaf { source_offset } => { - json!(limit_str(&self.source[*source_offset..], 20)) + json!(limit_bytes(&self.source[*source_offset..], 20)) } } } @@ -77,7 
+77,7 @@ impl SubStrMatcher { "{:indent$}{}: {:?}", "", *source_offset, - limit_str(&self.source[*source_offset..], 20), + limit_bytes(&self.source[*source_offset..], 20), )?; } } @@ -86,13 +86,13 @@ impl SubStrMatcher { pub fn new(source: &str, end_str: &str) -> Self { let mut tmp = Self { - source: source.to_string() + " ", + source: (source.to_string() + " ").as_bytes().to_vec(), end_str: end_str.to_string(), nodes: vec![Node::Inner { children: vec![] }], }; tmp.add(0); for i in 0..tmp.source.len() { - if tmp.source.as_bytes()[i] == b' ' { + if tmp.source[i] == b' ' { tmp.add(i + 1); } } @@ -101,15 +101,15 @@ impl SubStrMatcher { tmp } - fn find(&self, s: &str) -> (usize, usize) { + fn find(&self, s: &[u8]) -> (usize, usize) { let mut node_idx = 0; - for (i, b) in s.bytes().enumerate() { + for (i, b) in s.iter().enumerate() { let node = &self.nodes[node_idx]; match node { Node::Inner { children } => { let mut found = false; for (c, idx) in children.iter() { - if *c == b { + if *c == *b { node_idx = *idx; found = true; break; @@ -137,7 +137,7 @@ impl SubStrMatcher { let num_nodes = self.nodes.len(); match &mut self.nodes[node_idx] { Node::Inner { children } => { - children.push((s1.as_bytes()[0], num_nodes)); + children.push((s1[0], num_nodes)); let n = add_node( &mut self.nodes, Node::Leaf { @@ -160,8 +160,8 @@ impl SubStrMatcher { } for i in 0..s1.len() { - let b1 = s1.as_bytes()[i]; - let b2 = s2.as_bytes()[i]; + let b1 = s1[i]; + let b2 = s2[i]; if b1 != b2 { let n1 = add_node( &mut self.nodes, @@ -196,7 +196,7 @@ impl SubStrMatcher { } fn append_to_src_off(&self, off: usize, byte: u8) -> SubStrState { - if off < self.source.len() && self.source.as_bytes()[off] == byte { + if off < self.source.len() && self.source[off] == byte { SubStrState::SourceOffset(off + 1) } else { SubStrState::Dead From 1981da0aac162e692d29a95a492941af7eafa138 Mon Sep 17 00:00:00 2001 From: Michal Moskal Date: Sat, 4 May 2024 00:08:35 +0000 Subject: [PATCH 213/301] fix EOS checking with byte prefix --- controllers/aici_abi/src/toktree.rs | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/controllers/aici_abi/src/toktree.rs b/controllers/aici_abi/src/toktree.rs index c58f1246..638c277c 100644 --- a/controllers/aici_abi/src/toktree.rs +++ b/controllers/aici_abi/src/toktree.rs @@ -508,9 +508,12 @@ impl TokTrie { pub fn compute_bias_ext(&self, r: &mut impl Recognizer, logits: &mut SimpleVob, start: &[u8]) { logits.set_all(false); - for tok in vec![SpecialToken::EndOfSentence] { - if r.special_allowed(tok) { - logits.allow_token(self.special_token(tok)) + if start.is_empty() { + // EOS is only allowed if there is no forced byte prefix + for tok in vec![SpecialToken::EndOfSentence] { + if r.special_allowed(tok) { + logits.allow_token(self.special_token(tok)) + } } } // all prefixes of 'start' are also allowed From ba69d35d8e8e04119010d859afd8a3ef297d90a8 Mon Sep 17 00:00:00 2001 From: Michal Moskal Date: Wed, 8 May 2024 17:00:40 +0000 Subject: [PATCH 214/301] allow control of temperature --- controllers/aici_abi/src/lib.rs | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/controllers/aici_abi/src/lib.rs b/controllers/aici_abi/src/lib.rs index e4ccb2fd..659feac3 100644 --- a/controllers/aici_abi/src/lib.rs +++ b/controllers/aici_abi/src/lib.rs @@ -118,6 +118,8 @@ pub struct Branch { /// If None, no sampling is performed. /// If Some(set), only tokens from the set are allowed. pub sample_mask: Option, + /// Override temperature for sampling. It may or may not be sticky. 
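+    /// ("Sticky" meaning the runtime is free to keep applying it to later
+    /// samples, so controllers should not assume it gets reset.)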
+    pub temperature: Option<f32>,
     /// Describes what to do after sampling.
     /// If no sampling, there should be exactly one splice, with empty `when_sampled`.
     pub splices: Vec<Splice>,
@@ -127,6 +129,7 @@ impl Clone for Branch {
     fn clone(&self) -> Self {
         Branch {
             sample_mask: self.sample_mask.clone(),
+            temperature: self.temperature,
             splices: self.splices.clone(),
         }
     }
@@ -139,6 +142,7 @@ impl Branch {
     {
         Branch {
             sample_mask: self.sample_mask.as_ref().map(f),
+            temperature: self.temperature,
             splices: self.splices.clone(),
         }
     }
@@ -146,6 +150,7 @@ impl Branch {
     pub fn splice(backtrack: u32, ff_tokens: Vec<TokenId>) -> Self {
         Branch {
             sample_mask: None,
+            temperature: None,
             splices: vec![Splice {
                 when_sampled: vec![],
                 backtrack,
@@ -174,9 +179,14 @@ impl MidProcessResult {
             }],
         }
     }
 
     pub fn sample(set: SimpleVob) -> Self {
+        Self::sample_with_temp(set, None)
+    }
+
+    pub fn sample_with_temp(set: SimpleVob, temperature: Option<f32>) -> Self {
         MidProcessResult {
             branches: vec![Branch {
                 sample_mask: Some(set),
+                temperature: temperature,
                 splices: vec![],
             }],
         }

From ca8c67f2051a3270ba4a3bd3762a3d1f2e23b87a Mon Sep 17 00:00:00 2001
From: Michal Moskal
Date: Fri, 10 May 2024 18:48:03 +0000
Subject: [PATCH 215/301] allow absolute model and tokenizer paths

---
 controllers/aici_native/src/bintokens.rs | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/controllers/aici_native/src/bintokens.rs b/controllers/aici_native/src/bintokens.rs
index 95e06434..da632c26 100644
--- a/controllers/aici_native/src/bintokens.rs
+++ b/controllers/aici_native/src/bintokens.rs
@@ -166,7 +166,7 @@ pub fn find_tokenizer(mut name: &str) -> Result<ByteTokenizer> {
 
     log::info!("loading tokenizer: {}", name);
 
-    let loaded = if name.starts_with(".") {
+    let loaded = if name.starts_with(".") || name.starts_with("/") {
         Tokenizer::from_file(name)
     } else {
         let mut name2 = name.to_string();

From 32233bd4e25a1c3a0aad3129c37c882357ef0dd7 Mon Sep 17 00:00:00 2001
From: Michal Moskal
Date: Tue, 21 May 2024 15:41:59 +0000
Subject: [PATCH 216/301] add exact size and iterators to SimpleVob; regex fixes

---
 controllers/aici_abi/src/rx.rs      |  64 +++++++++--
 controllers/aici_abi/src/svob.rs    | 159 ++++++++++++++++++++++++++--
 controllers/aici_abi/src/toktree.rs |  28 ++---
 3 files changed, 224 insertions(+), 27 deletions(-)

diff --git a/controllers/aici_abi/src/rx.rs b/controllers/aici_abi/src/rx.rs
index a116e2df..4d0397d0 100644
--- a/controllers/aici_abi/src/rx.rs
+++ b/controllers/aici_abi/src/rx.rs
@@ -1,7 +1,10 @@
+use std::error::Error;
+
 use crate::{
     recognizer::{FunctionalRecognizer, StackRecognizer},
     toktree::SpecialToken,
 };
+use anyhow::{bail, Result};
 use regex_automata::{
     dfa::{dense, Automaton},
     util::{primitives::StateID, syntax},
@@ -12,24 +15,65 @@ pub type RecRxState = StateID;
 
 #[derive(Clone)]
 pub struct RecRx {
     dfa: dense::DFA<Vec<u32>>,
+    info: String,
 }
 
 pub type RxStackRecognizer = StackRecognizer<StateID, RecRx>;
 
 impl RecRx {
-    pub fn from_rx(rx: &str) -> Self {
+    pub fn from_rx(rx: &str) -> Result<Self> {
         let rx = if rx.ends_with("$") {
             rx.to_string()
         } else {
             rx.to_string() + "$"
         };
+        let rx = if rx.starts_with("^") {
+            rx[1..].to_string()
+        } else {
+            rx
+        };
+        let t0 = std::time::Instant::now();
+        let size_mb = 3; // 3MB should be on the order of 50ms
+        let cfg = dense::Config::new()
+            .start_kind(regex_automata::dfa::StartKind::Anchored)
+            .dfa_size_limit(Some(size_mb << 20))
+            .determinize_size_limit(Some(size_mb << 20));
         let dfa = dense::Builder::new()
-            .configure(dense::Config::new().start_kind(regex_automata::dfa::StartKind::Anchored))
+            .configure(cfg)
.syntax(syntax::Config::new().unicode(false).utf8(false)) - .build(&rx) - .unwrap(); - println!("dfa: {} bytes", dfa.memory_usage()); - Self { dfa } + .build(&rx); + let dfa = match dfa { + Ok(dfa) => dfa, + Err(e) => { + if let Some(e) = e.source() { + if let Some(e) = e.source() { + bail!("error building dfa(2): {}", e) + } else { + bail!("error building dfa(1): {}", e) + } + } else { + bail!("error building dfa(0): {}", e) + } + } + }; + let time = t0.elapsed(); + let mb_per_s = dfa.memory_usage() as f64 / time.as_secs_f64() / 1024.0 / 1024.0; + let info = format!( + "dfa: {} bytes; time {:?}; {:.3} MB/s", + dfa.memory_usage(), + time, + mb_per_s + ); + + if let Err(e) = dfa.start_state(&anchored_start()) { + bail!("DFA has no start state; {}", e) + } + + Ok(Self { dfa, info }) + } + + pub fn info(&self) -> &str { + &self.info } pub fn to_stack_recognizer(self) -> RxStackRecognizer { @@ -37,11 +81,15 @@ impl RecRx { } } +fn anchored_start() -> regex_automata::util::start::Config { + regex_automata::util::start::Config::new().anchored(regex_automata::Anchored::Yes) +} + impl FunctionalRecognizer for RecRx { fn initial(&self) -> RecRxState { self.dfa - .universal_start_state(regex_automata::Anchored::Yes) - .expect("dfa has no universal start state; make sure it doesn't match empty string") + .start_state(&anchored_start()) + .expect("dfa has no start state") } #[inline(always)] diff --git a/controllers/aici_abi/src/svob.rs b/controllers/aici_abi/src/svob.rs index c5fd1997..a029d00d 100644 --- a/controllers/aici_abi/src/svob.rs +++ b/controllers/aici_abi/src/svob.rs @@ -4,6 +4,7 @@ use std::{fmt::Debug, ops::Index}; #[derive(Clone)] pub struct SimpleVob { data: Vec, + size: usize, } impl Debug for SimpleVob { @@ -20,11 +21,20 @@ impl Default for SimpleVob { } } +impl Into> for SimpleVob { + fn into(self) -> Vec { + self.data + } +} + const BITS: usize = 32; impl SimpleVob { pub fn new() -> Self { - Self { data: Vec::new() } + Self { + data: Vec::new(), + size: 0, + } } pub fn alloc(size: usize) -> Self { @@ -34,20 +44,25 @@ impl SimpleVob { } pub fn len(&self) -> usize { - self.data.len() * BITS + self.size } pub fn num_set(&self) -> usize { self.data.iter().map(|x| x.count_ones() as usize).sum() } - pub fn negated(&self, size: usize) -> Self { - let mut r = Self::new(); - r.data = self.data.iter().map(|x| !x).collect(); - for i in size..r.len() { + fn clear_excessive_bits(&mut self) { + for i in self.size..(self.data.len() * 32) { // disallow tokens that are out of range - r.disallow_token(i as TokenId); + self.disallow_token(i as TokenId); } + } + + pub fn negated(&self) -> Self { + let mut r = Self::new(); + r.data = self.data.iter().map(|x| !x).collect(); + r.size = self.size; + r.clear_excessive_bits(); r } @@ -55,6 +70,103 @@ impl SimpleVob { self.data.as_ptr() } + pub fn as_slice(&self) -> &[u32] { + &self.data + } + + #[inline(always)] + pub fn iter_set_entries(&self, mut f: impl FnMut(usize)) { + let src = self.as_slice(); + let numelts = self.size; + let max_len = numelts / 32; + for idx in 0..max_len { + let d = src[idx]; + // optimize for the two common cases + if d == 0 { + continue; + } else if d == u32::MAX { + for bit in 0..32 { + f(idx * 32 + bit); + } + } else { + for bit in 0..32 { + if d & (1 << bit) != 0 { + f(idx * 32 + bit); + } + } + } + } + // final few elts + for idx in (max_len * 32)..numelts { + if self.is_allowed(idx as TokenId) { + f(idx); + } + } + } + + #[inline(always)] + pub fn iter_unset_entries(&self, mut f: impl FnMut(usize)) { + let src = 
self.as_slice(); + let numelts = self.size; + let max_len = numelts / 32; + for idx in 0..max_len { + let d = src[idx]; + // optimize for the two common cases + if d == 0 { + for bit in 0..32 { + f(idx * 32 + bit); + } + } else if d == u32::MAX { + continue; + } else { + for bit in 0..32 { + if d & (1 << bit) == 0 { + f(idx * 32 + bit); + } + } + } + } + // final few elts + for idx in (max_len * 32)..numelts { + if !self.is_allowed(idx as TokenId) { + f(idx); + } + } + } + + #[inline(always)] + pub fn iter_entries(&self, mut f: impl FnMut(bool, usize)) { + let src = self.as_slice(); + let numelts = self.size; + let max_len = numelts / 32; + for idx in 0..max_len { + let d = src[idx]; + // optimize for the two common cases + if d == 0 { + for bit in 0..32 { + f(false, idx * 32 + bit); + } + } else if d == u32::MAX { + for bit in 0..32 { + f(true, idx * 32 + bit); + } + } else { + for bit in 0..32 { + f(d & (1 << bit) != 0, idx * 32 + bit); + } + } + } + // final few elts + for idx in (max_len * 32)..numelts { + f(self.is_allowed(idx as TokenId), idx); + } + } + + pub fn write_to(&self, buf: &mut [u8]) { + assert!(buf.len() == self.data.len() * 4); + bytemuck::cast_slice_mut(buf).copy_from_slice(&self.data); + } + #[inline(always)] pub fn allow_token(&mut self, tok: TokenId) { let idx = tok as usize; @@ -83,6 +195,7 @@ impl SimpleVob { let new_size = size / BITS + 1; assert!(new_size >= self.data.len()); self.data.resize(new_size, 0); + self.size = size; } #[inline(always)] @@ -96,6 +209,7 @@ impl SimpleVob { pub fn set_all(&mut self, val: bool) { let val = if val { !0 } else { 0 }; self.data.iter_mut().for_each(|x| *x = val); + self.clear_excessive_bits(); } pub fn apply_to(&self, logits: &mut [f32]) { @@ -111,6 +225,37 @@ impl SimpleVob { } } } + + pub fn iter(&self) -> SimpleVobIter { + SimpleVobIter { vob: self, idx: 0 } + } +} + +pub struct SimpleVobIter<'a> { + vob: &'a SimpleVob, + idx: usize, +} + +impl<'a> Iterator for SimpleVobIter<'a> { + type Item = u32; + + #[inline(always)] + fn next(&mut self) -> Option { + let mut bitoff = self.idx % BITS; + let mut dataoff = self.idx / BITS; + let data = &self.vob.data; + while dataoff < data.len() { + let d = data[dataoff] >> bitoff; + if d != 0 { + let idx = dataoff * BITS + d.trailing_zeros() as usize + bitoff; + self.idx = idx + 1; + return Some(idx as u32); + } + bitoff = 0; + dataoff += 1; + } + return None; + } } impl Index for SimpleVob { diff --git a/controllers/aici_abi/src/toktree.rs b/controllers/aici_abi/src/toktree.rs index 638c277c..dd2d9644 100644 --- a/controllers/aici_abi/src/toktree.rs +++ b/controllers/aici_abi/src/toktree.rs @@ -1,6 +1,7 @@ // use 8:24 encoding - num_ch:tok_id (ch_byte:ch_off)* - 8 bytes per tree node // special case num_ch=0xff -> num_ch=0x100 +use anyhow::Result; use rustc_hash::FxHashMap; use crate::{ @@ -8,7 +9,6 @@ use crate::{ box_from_bytes, clone_as_bytes, clone_vec_as_bytes, to_hex_string, vec_from_bytes, TokRxInfo, TokenId, }, - host::trie_bytes, svob::SimpleVob, }; @@ -51,6 +51,10 @@ pub trait Recognizer { fn trie_started(&mut self) {} /// This combines `push_byte` and `byte_allowed` into one function for performance. fn try_push_byte(&mut self, byte: u8) -> bool; + /// Check if there are any errors to be reported to the user. 
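+    /// The default implementation reports no error.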
+ fn get_error(&self) -> Option { + None + } } #[derive(Clone)] @@ -126,11 +130,6 @@ impl TrieNode { const LEN_BITS: u32 = 10; impl TokTrie { - pub fn from_host() -> Self { - let buffer = trie_bytes(); - Self::from_bytes(&buffer) - } - pub fn from(info: &TokRxInfo, words: &Vec>) -> Self { let mut trie = TrieHash::new(0xff); let mut token_offsets = Vec::new(); @@ -215,7 +214,7 @@ impl TokTrie { pub fn token_set_dbg(&self, ts: &SimpleVob) -> String { let max_examples = 50; - let ts_neg = ts.negated(self.vocab_size()); + let ts_neg = ts.negated(); let use_neg = ts_neg.num_set() * 20 < ts.num_set(); let ts1 = if use_neg { &ts_neg } else { &ts }; let num_set = ts1.num_set(); @@ -539,19 +538,24 @@ impl TokTrie { } } - pub fn append_tokens(&self, r: &mut impl Recognizer, ts: &[TokenId]) { + pub fn append_tokens(&self, r: &mut impl Recognizer, ts: &[TokenId]) -> Result<()> { for t in ts { - self.append_token(r, *t) + self.append_token(r, *t)?; } + Ok(()) } - pub fn append_token(&self, r: &mut impl Recognizer, t: TokenId) { + pub fn append_token(&self, r: &mut impl Recognizer, t: TokenId) -> Result<()> { // println!("append_token: {}", self.token_dbg(t)); let bytes = self.token(t); for &byte in bytes { - r.push_byte(byte) + if !r.try_push_byte(byte) { + r.collapse(); + return Err(anyhow::anyhow!("byte {:?} not allowed", byte as char)); + } } - r.collapse() + r.collapse(); + Ok(()) } pub fn token_allowed(&self, r: &mut impl Recognizer, t: TokenId) -> bool { From 39107c680ee4210fcea5b2038cad7def3b586bcb Mon Sep 17 00:00:00 2001 From: Michal Moskal Date: Tue, 21 May 2024 16:00:34 +0000 Subject: [PATCH 217/301] allow passing external size limit --- controllers/aici_abi/src/rx.rs | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/controllers/aici_abi/src/rx.rs b/controllers/aici_abi/src/rx.rs index 4d0397d0..627a14dc 100644 --- a/controllers/aici_abi/src/rx.rs +++ b/controllers/aici_abi/src/rx.rs @@ -21,7 +21,7 @@ pub struct RecRx { pub type RxStackRecognizer = StackRecognizer; impl RecRx { - pub fn from_rx(rx: &str) -> Result { + pub fn from_rx(rx: &str, size_limit: Option) -> Result { let rx = if rx.ends_with("$") { rx.to_string() } else { @@ -32,12 +32,13 @@ impl RecRx { } else { rx }; + // default to 16MB - it takes about 1s to build + let size_limit = size_limit.unwrap_or(16 << 20); let t0 = std::time::Instant::now(); - let size_mb = 3; // 3MB should be on the order of 50ms let cfg = dense::Config::new() .start_kind(regex_automata::dfa::StartKind::Anchored) - .dfa_size_limit(Some(size_mb << 20)) - .determinize_size_limit(Some(size_mb << 20)); + .dfa_size_limit(Some(size_limit)) + .determinize_size_limit(Some(size_limit)); let dfa = dense::Builder::new() .configure(cfg) .syntax(syntax::Config::new().unicode(false).utf8(false)) From e292d089140f31f4e06b0bac07e84c60a152a7d6 Mon Sep 17 00:00:00 2001 From: Michal Moskal Date: Tue, 21 May 2024 16:01:16 +0000 Subject: [PATCH 218/301] adapt to API changes --- controllers/aici_abi/Cargo.toml | 1 + controllers/aici_abi/src/cfg.rs | 15 ++++++++------- controllers/aici_abi/src/host.rs | 10 ++++------ controllers/aici_abi/src/lib.rs | 6 +++--- controllers/aici_abi/src/recognizer.rs | 7 +++---- controllers/aici_abi/src/yesno.rs | 4 ++-- 6 files changed, 21 insertions(+), 22 deletions(-) diff --git a/controllers/aici_abi/Cargo.toml b/controllers/aici_abi/Cargo.toml index cd30bc49..3878475e 100644 --- a/controllers/aici_abi/Cargo.toml +++ b/controllers/aici_abi/Cargo.toml @@ -16,6 +16,7 @@ cfgrammar = { version = "0.13.3", optional = true } 
lrtable = { version = "0.13.3", optional = true } vob = { version = "3.0.3", optional = true } rustc-hash = { version = "1.1.0", optional = true } +bytemuck = "1.16.0" [features] default = ["cfg", "rx"] diff --git a/controllers/aici_abi/src/cfg.rs b/controllers/aici_abi/src/cfg.rs index 675a0beb..56393c78 100644 --- a/controllers/aici_abi/src/cfg.rs +++ b/controllers/aici_abi/src/cfg.rs @@ -1,7 +1,8 @@ +use crate::host::host_trie; use crate::lex::{Lexer, LexerState, StateID, VobIdx, VobSet}; use crate::{ svob::SimpleVob, - toktree::{Recognizer, SpecialToken, TokTrie}, + toktree::{Recognizer, SpecialToken}, }; use anyhow::Result; use cfgrammar::{ @@ -170,8 +171,8 @@ impl CfgParser { .collect::>(); for ridx in grm.iter_rules() { - let rname = grm.rule_name_str(ridx); - if rname.to_uppercase() != rname { + let rule_name = grm.rule_name_str(ridx); + if rule_name.to_uppercase() != rule_name { continue; } for pidx in grm.rule_to_prods(ridx) { @@ -179,8 +180,8 @@ impl CfgParser { if let [Symbol::Token(tidx)] = toks { let idx = *tidx_to_pat_idx.get(&tidx).unwrap(); // this doesn't seem very useful - // friendly_pattern_names[idx] = rname.to_string(); - if rname == "SKIP" { + // friendly_pattern_names[idx] = rule_name.to_string(); + if rule_name == "SKIP" { skip_patterns.set(idx, true); } } @@ -506,7 +507,7 @@ pub fn cfg_test() -> Result<()> { let sample = include_bytes!("../grammars/sample.c"); if true { - let trie = TokTrie::from_host(); + let trie = host_trie(); let toks = trie.greedy_tokenize(sample); #[cfg(not(target_arch = "wasm32"))] @@ -537,7 +538,7 @@ pub fn cfg_test() -> Result<()> { ); cfg.viable_now(); } - trie.append_token(&mut cfg, tok); + trie.append_token(&mut cfg, tok).unwrap(); } #[cfg(not(target_arch = "wasm32"))] diff --git a/controllers/aici_abi/src/host.rs b/controllers/aici_abi/src/host.rs index d71dbfb8..2dd775dd 100644 --- a/controllers/aici_abi/src/host.rs +++ b/controllers/aici_abi/src/host.rs @@ -97,7 +97,7 @@ pub struct WasmTokenizerEnv { impl Default for WasmTokenizerEnv { fn default() -> Self { WasmTokenizerEnv { - toktrie: TokTrie::from_bytes(&trie_bytes()), + toktrie: host_trie(), } } } @@ -148,9 +148,7 @@ impl HostInterface for WasmHost { fn return_logit_bias(&self, vob: &SimpleVob) -> u32 { assert!(vob.len() > 0); - unsafe { - aici_host_return_logit_bias(vob.as_ptr()) - } + unsafe { aici_host_return_logit_bias(vob.as_ptr()) } } fn process_arg_bytes(&self) -> Vec { @@ -224,8 +222,8 @@ pub fn arg_string() -> String { String::from_utf8_lossy(&arg_bytes()).to_string() } -pub fn trie_bytes() -> Vec { - get_host().trie_bytes() +pub fn host_trie() -> TokTrie { + TokTrie::from_bytes(&get_host().trie_bytes()) // #[cfg(not(target_arch = "wasm32"))] // return std::fs::read("tokenizer.bin").unwrap(); } diff --git a/controllers/aici_abi/src/lib.rs b/controllers/aici_abi/src/lib.rs index 659feac3..a85ba120 100644 --- a/controllers/aici_abi/src/lib.rs +++ b/controllers/aici_abi/src/lib.rs @@ -23,7 +23,7 @@ pub mod substring; pub type TokenId = bytes::TokenId; pub use host::{ - aici_stop, arg_bytes, arg_string, get_config, self_seq_id, tokenize, tokenize_bytes, + aici_stop, arg_bytes, arg_string, get_config, host_trie, self_seq_id, tokenize, tokenize_bytes, StorageCmd, StorageOp, StorageResp, TokenizerEnv, VariableStorage, WasmTokenizerEnv, }; @@ -54,7 +54,7 @@ pub struct SeqId(pub u32); pub struct MidProcessArg { /// Sampling result for the previous iteration. /// For simple sampled token 't', backtrack==0 and tokens==[t]. 
- /// For first request, backtrack==0 and tokens==[] (prompt is passed separetely, before). + /// For first request, backtrack==0 and tokens==[] (prompt is passed separately, before). /// Can be more complex when splices are used. pub backtrack: u32, pub tokens: Vec, @@ -98,7 +98,7 @@ Which means: when any token starting with '"' is sampled, we remove it (backtrac and then append the next full fragment of JSON '", "age": ' If the tokenizers has tokens like 'a"', 'b"' etc, then we would need many splices -(there may be limits how many we want to pass over the IPC boundry). +(there may be limits how many we want to pass over the IPC boundary). */ /// Describes what to do after sampling. diff --git a/controllers/aici_abi/src/recognizer.rs b/controllers/aici_abi/src/recognizer.rs index 57eabc09..50bd0cdb 100644 --- a/controllers/aici_abi/src/recognizer.rs +++ b/controllers/aici_abi/src/recognizer.rs @@ -1,6 +1,5 @@ use crate::{ - toktree::{Recognizer, SpecialToken, TokTrie}, - AiciCtrl, MidProcessArg, MidProcessResult, + host::host_trie, toktree::{Recognizer, SpecialToken, TokTrie}, AiciCtrl, MidProcessArg, MidProcessResult }; use std::fmt::Debug; @@ -12,7 +11,7 @@ pub struct AiciRecognizer { impl AiciRecognizer { pub fn from_recognizer(rec: R) -> Self { AiciRecognizer { - trie: TokTrie::from_host(), + trie: host_trie(), rec, } } @@ -23,7 +22,7 @@ impl AiciCtrl for AiciRecognizer { if arg.has_eos() { return MidProcessResult::stop(); } - self.trie.append_tokens(&mut self.rec, &arg.tokens); + self.trie.append_tokens(&mut self.rec, &arg.tokens).unwrap(); let mut set = self.trie.alloc_token_set(); self.trie.compute_bias(&mut self.rec, &mut set); MidProcessResult::sample(set) diff --git a/controllers/aici_abi/src/yesno.rs b/controllers/aici_abi/src/yesno.rs index dc16d2d2..1e021e0d 100644 --- a/controllers/aici_abi/src/yesno.rs +++ b/controllers/aici_abi/src/yesno.rs @@ -1,4 +1,4 @@ -use aici_abi::{tokenize, toktree::TokTrie, AiciCtrl, MidProcessArg, MidProcessResult, TokenId}; +use aici_abi::{host_trie, tokenize, toktree::TokTrie, AiciCtrl, MidProcessArg, MidProcessResult, TokenId}; pub struct Runner { toktrie: TokTrie, @@ -13,7 +13,7 @@ impl Runner { let no = tokenize("No")[0]; // ignore user-passed arg Runner { - toktrie: TokTrie::from_host(), + toktrie: host_trie(), tokens: Vec::new(), yes, no, From 2892d6ac80e8d7590a4862f0fc6b05c056c76851 Mon Sep 17 00:00:00 2001 From: Michal Moskal Date: Tue, 21 May 2024 16:31:09 +0000 Subject: [PATCH 219/301] use bytemuck instead of unsafe code --- controllers/aici_abi/Cargo.toml | 1 + controllers/aici_abi/src/bytes.rs | 43 ++++++++--------------------- controllers/aici_abi/src/toktree.rs | 26 ++++++++--------- 3 files changed, 25 insertions(+), 45 deletions(-) diff --git a/controllers/aici_abi/Cargo.toml b/controllers/aici_abi/Cargo.toml index 3878475e..9635fc70 100644 --- a/controllers/aici_abi/Cargo.toml +++ b/controllers/aici_abi/Cargo.toml @@ -17,6 +17,7 @@ lrtable = { version = "0.13.3", optional = true } vob = { version = "3.0.3", optional = true } rustc-hash = { version = "1.1.0", optional = true } bytemuck = "1.16.0" +bytemuck_derive = "1.6.0" [features] default = ["cfg", "rx"] diff --git a/controllers/aici_abi/src/bytes.rs b/controllers/aici_abi/src/bytes.rs index 66ef0de0..7343a4e3 100644 --- a/controllers/aici_abi/src/bytes.rs +++ b/controllers/aici_abi/src/bytes.rs @@ -1,42 +1,27 @@ -use std::{mem::size_of, slice::from_raw_parts}; +use std::mem::size_of; use anyhow::{anyhow, Result}; +use bytemuck::{NoUninit, Pod}; +use bytemuck_derive::{Pod, 
Zeroable}; pub(crate) type TokenId = u32; +#[derive(Clone, Copy, PartialEq, Eq, Debug, Zeroable, Pod)] #[repr(C)] -#[derive(Clone, PartialEq, Eq, Debug)] pub struct TokRxInfo { pub vocab_size: u32, pub tok_eos: TokenId, } -pub fn clone_vec_as_bytes(input: &[T]) -> Vec { - unsafe { - let byte_slice = from_raw_parts(input.as_ptr() as *const u8, input.len() * size_of::()); - byte_slice.to_vec() - } -} - -pub fn clone_as_bytes(input: &T) -> Vec { - unsafe { - let byte_slice = from_raw_parts(input as *const T as *const u8, size_of::()); - byte_slice.to_vec() - } -} +#[derive(Clone, Copy, PartialEq, Eq, Debug, Zeroable, Pod)] +#[repr(C)] +pub struct U32Pair(pub u32, pub u32); -pub fn box_from_bytes(bytes: &[u8]) -> Box { - if bytes.len() != size_of::() { - panic!("T: got {} bytes, needed {}", bytes.len(), size_of::()); - } - let mut t: Box = Box::new(unsafe { std::mem::zeroed() }); - unsafe { - std::ptr::copy_nonoverlapping(bytes.as_ptr(), &mut *t as *mut T as *mut u8, size_of::()); - } - t +pub fn clone_vec_as_bytes(input: &[T]) -> Vec { + bytemuck::cast_slice(input).to_vec() } -pub fn vec_from_bytes(bytes: &[u8]) -> Vec { +pub fn vec_from_bytes(bytes: &[u8]) -> Vec { if bytes.len() % size_of::() != 0 { panic!( "vecT: got {} bytes, needed multiple of {}", @@ -44,13 +29,7 @@ pub fn vec_from_bytes(bytes: &[u8]) -> Vec { size_of::() ); } - let num_elements = bytes.len() / size_of::(); - let mut result = Vec::with_capacity(num_elements); - unsafe { - result.set_len(num_elements); - std::ptr::copy_nonoverlapping(bytes.as_ptr(), result.as_mut_ptr() as *mut u8, bytes.len()); - } - result + bytemuck::cast_slice(bytes).to_vec() } pub fn limit_str(s: &str, max_len: usize) -> String { diff --git a/controllers/aici_abi/src/toktree.rs b/controllers/aici_abi/src/toktree.rs index dd2d9644..18d9518b 100644 --- a/controllers/aici_abi/src/toktree.rs +++ b/controllers/aici_abi/src/toktree.rs @@ -2,13 +2,11 @@ // special case num_ch=0xff -> num_ch=0x100 use anyhow::Result; +use bytemuck_derive::{Pod, Zeroable}; use rustc_hash::FxHashMap; use crate::{ - bytes::{ - box_from_bytes, clone_as_bytes, clone_vec_as_bytes, to_hex_string, vec_from_bytes, - TokRxInfo, TokenId, - }, + bytes::{to_hex_string, vec_from_bytes, TokRxInfo, TokenId}, svob::SimpleVob, }; @@ -67,6 +65,7 @@ pub struct TokTrie { token_duplicates: FxHashMap>, } +#[derive(Clone, Copy, Zeroable, Pod)] #[repr(C)] pub struct TokTrieHeader { magic: u32, @@ -82,7 +81,7 @@ impl TokTrieHeader { const MAGIC: u32 = 0x558b6fd3; } -#[derive(Clone)] +#[derive(Clone, Copy, Zeroable, Pod)] #[repr(C)] pub struct TrieNode { // byte:token @@ -377,7 +376,8 @@ impl TokTrie { pub fn from_bytes(bytes: &[u8]) -> Self { let pref = std::mem::size_of::(); - let hd = *box_from_bytes::(&bytes[0..pref]); + let hd: &TokTrieHeader = bytemuck::from_bytes(&bytes[0..pref]); + assert!(hd.magic == TokTrieHeader::MAGIC); assert!(hd.hd_size as usize == pref); @@ -428,9 +428,9 @@ impl TokTrie { } pub fn serialize(&self) -> Vec { - let mut trie_data = clone_vec_as_bytes(&self.nodes); - let mut token_offsets = clone_vec_as_bytes(&self.token_offsets); - let mut token_data = clone_vec_as_bytes(&self.token_data); + let trie_data: &[u8] = bytemuck::cast_slice(&self.nodes); + let token_offsets: &[u8] = bytemuck::cast_slice(&self.token_offsets); + let token_data: &[u8] = bytemuck::cast_slice(&self.token_data); let hd = TokTrieHeader { magic: TokTrieHeader::MAGIC, @@ -442,10 +442,10 @@ impl TokTrie { align: [], }; - let mut bytes = clone_as_bytes(&hd); - bytes.append(&mut trie_data); - 
bytes.append(&mut token_offsets); - bytes.append(&mut token_data); + let mut bytes = bytemuck::bytes_of(&hd).to_vec(); + bytes.extend_from_slice(trie_data); + bytes.extend_from_slice(token_offsets); + bytes.extend_from_slice(token_data); bytes } From 3fca18d908161c811af299e559349b6af6d72769 Mon Sep 17 00:00:00 2001 From: Michal Moskal Date: Tue, 21 May 2024 18:21:32 +0000 Subject: [PATCH 220/301] rename confusing method --- controllers/aici_abi/src/cfg.rs | 6 +++--- controllers/aici_abi/src/lex.rs | 4 ++-- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/controllers/aici_abi/src/cfg.rs b/controllers/aici_abi/src/cfg.rs index 56393c78..22f2a3c6 100644 --- a/controllers/aici_abi/src/cfg.rs +++ b/controllers/aici_abi/src/cfg.rs @@ -192,8 +192,8 @@ impl CfgParser { let mut vobset = VobSet::new(); // all-zero has to be inserted first - let _all0 = vobset.get(&vob![false; patterns.len()]); - let all1 = vobset.get(&vob![true; patterns.len()]); + let _all0 = vobset.insert_or_get(&vob![false; patterns.len()]); + let all1 = vobset.insert_or_get(&vob![true; patterns.len()]); // TIME: 27ms let dfa = Lexer::from(patterns, &mut vobset); @@ -226,7 +226,7 @@ impl CfgParser { } } - vobset.get(&r) + vobset.insert_or_get(&r) }) .collect::>(); diff --git a/controllers/aici_abi/src/lex.rs b/controllers/aici_abi/src/lex.rs index b8679b66..33f1c6bb 100644 --- a/controllers/aici_abi/src/lex.rs +++ b/controllers/aici_abi/src/lex.rs @@ -67,7 +67,7 @@ impl VobSet { } } - pub fn get(&mut self, vob: &Vob) -> VobIdx { + pub fn insert_or_get(&mut self, vob: &Vob) -> VobIdx { if let Some(idx) = self.by_vob.get(vob) { return *idx; } @@ -214,7 +214,7 @@ impl Lexer { let mut vobidx_by_state_off = vec![VobIdx::all_zero(); 1 + (states_idx.iter().max().unwrap() >> shift)]; for (k, v) in reachable_patterns.iter() { - vobidx_by_state_off[k.as_usize() >> shift] = vobset.get(v); + vobidx_by_state_off[k.as_usize() >> shift] = vobset.insert_or_get(v); } println!("initial: {:?}; {} states", initial, states.len()); From c9ffe9ad692ab7b29b88383c31ed810b63b1adbe Mon Sep 17 00:00:00 2001 From: Michal Moskal Date: Tue, 21 May 2024 18:43:38 +0000 Subject: [PATCH 221/301] lex.rs ported to simplevob --- controllers/aici_abi/Cargo.toml | 2 +- controllers/aici_abi/src/svob.rs | 70 +++++++++++++++++++++++++------- 2 files changed, 57 insertions(+), 15 deletions(-) diff --git a/controllers/aici_abi/Cargo.toml b/controllers/aici_abi/Cargo.toml index 9635fc70..d37e78d3 100644 --- a/controllers/aici_abi/Cargo.toml +++ b/controllers/aici_abi/Cargo.toml @@ -11,7 +11,7 @@ name = "aici_abi" serde = { version = "1.0.192", features = ["derive"] } serde_json = "1.0.108" anyhow = "1.0.75" -regex-automata = { version = "0.4.3", default-features = false, features = ["std", "dfa", "syntax", "perf", "meta"], optional = true } +regex-automata = { version = "0.4.6", default-features = false, features = ["std", "dfa", "syntax", "perf", "meta"], optional = true } cfgrammar = { version = "0.13.3", optional = true } lrtable = { version = "0.13.3", optional = true } vob = { version = "3.0.3", optional = true } diff --git a/controllers/aici_abi/src/svob.rs b/controllers/aici_abi/src/svob.rs index a029d00d..09970c6f 100644 --- a/controllers/aici_abi/src/svob.rs +++ b/controllers/aici_abi/src/svob.rs @@ -1,5 +1,5 @@ use crate::TokenId; -use std::{fmt::Debug, ops::Index}; +use std::{fmt::Debug, hash::Hash, ops::Index}; #[derive(Clone)] pub struct SimpleVob { @@ -7,6 +7,21 @@ pub struct SimpleVob { size: usize, } +impl Hash for SimpleVob { + fn hash(&self, 
state: &mut H) { + self.size.hash(state); + self.data.hash(state); + } +} + +impl PartialEq for SimpleVob { + fn eq(&self, other: &Self) -> bool { + self.size == other.size && self.data == other.data + } +} + +impl Eq for SimpleVob {} + impl Debug for SimpleVob { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { f.debug_struct("SimpleVob") @@ -169,25 +184,22 @@ impl SimpleVob { #[inline(always)] pub fn allow_token(&mut self, tok: TokenId) { - let idx = tok as usize; - let byte_idx = idx / BITS; - let bit_idx = idx % BITS; - self.data[byte_idx] |= 1 << bit_idx; + self.set(tok as usize, true) } #[inline(always)] pub fn disallow_token(&mut self, tok: TokenId) { - let idx = tok as usize; - let byte_idx = idx / BITS; - let bit_idx = idx % BITS; - self.data[byte_idx] &= !(1 << bit_idx); + self.set(tok as usize, false) } - pub fn set(&mut self, tok: TokenId, val: bool) { + #[inline(always)] + pub fn set(&mut self, idx: usize, val: bool) { + let byte_idx = idx / BITS; + let bit_idx = idx % BITS; if val { - self.allow_token(tok); + self.data[byte_idx] |= 1 << bit_idx; } else { - self.disallow_token(tok); + self.data[byte_idx] &= !(1 << bit_idx); } } @@ -199,13 +211,17 @@ impl SimpleVob { } #[inline(always)] - pub fn is_allowed(&self, tok: TokenId) -> bool { - let idx = tok as usize; + pub fn get(&self, idx: usize) -> bool { let byte_idx = idx / 32; let bit_idx = idx % 32; (self.data[byte_idx] & (1 << bit_idx)) != 0 } + #[inline(always)] + pub fn is_allowed(&self, tok: TokenId) -> bool { + self.get(tok as usize) + } + pub fn set_all(&mut self, val: bool) { let val = if val { !0 } else { 0 }; self.data.iter_mut().for_each(|x| *x = val); @@ -229,6 +245,32 @@ impl SimpleVob { pub fn iter(&self) -> SimpleVobIter { SimpleVobIter { vob: self, idx: 0 } } + + pub fn or(&mut self, other: &SimpleVob) { + assert_eq!(self.size, other.size); + for (idx, v) in self.data.iter_mut().zip(other.data.iter()) { + *idx |= *v; + } + } + + pub fn and(&mut self, other: &SimpleVob) { + assert_eq!(self.size, other.size); + for (idx, v) in self.data.iter_mut().zip(other.data.iter()) { + *idx &= *v; + } + } + + pub fn is_zero(&self) -> bool { + self.data.iter().all(|x| *x == 0) + } + + pub fn and_is_zero(&self, other: &SimpleVob) -> bool { + assert_eq!(self.size, other.size); + self.data + .iter() + .zip(other.data.iter()) + .all(|(a, b)| *a & *b == 0) + } } pub struct SimpleVobIter<'a> { From a7fe1cc33a1f575bfe9672e26fc5c687972de42e Mon Sep 17 00:00:00 2001 From: Michal Moskal Date: Tue, 21 May 2024 20:22:18 +0000 Subject: [PATCH 222/301] more work on lexer --- controllers/aici_abi/src/svob.rs | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/controllers/aici_abi/src/svob.rs b/controllers/aici_abi/src/svob.rs index 09970c6f..90e55c89 100644 --- a/controllers/aici_abi/src/svob.rs +++ b/controllers/aici_abi/src/svob.rs @@ -271,6 +271,15 @@ impl SimpleVob { .zip(other.data.iter()) .all(|(a, b)| *a & *b == 0) } + + pub fn first_bit_set(&self) -> Option { + for (idx, v) in self.data.iter().enumerate() { + if *v != 0 { + return Some(idx * BITS + v.trailing_zeros() as usize); + } + } + None + } } pub struct SimpleVobIter<'a> { From 3043b3587c78ea610010123765beab382d41cc97 Mon Sep 17 00:00:00 2001 From: Michal Moskal Date: Wed, 22 May 2024 00:10:38 +0000 Subject: [PATCH 223/301] more work on lexer --- controllers/aici_abi/src/svob.rs | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/controllers/aici_abi/src/svob.rs b/controllers/aici_abi/src/svob.rs index 90e55c89..e6d99e27 100644 --- 
a/controllers/aici_abi/src/svob.rs
+++ b/controllers/aici_abi/src/svob.rs
@@ -272,6 +272,17 @@ impl SimpleVob {
             .all(|(a, b)| *a & *b == 0)
     }
 
+    pub fn first_bit_set_here_and_in(&self, other: &SimpleVob) -> Option<usize> {
+        assert_eq!(self.size, other.size);
+        for (idx, (a, b)) in self.data.iter().zip(other.data.iter()).enumerate() {
+            let v = *a & *b;
+            if v != 0 {
+                return Some(idx * BITS + v.trailing_zeros() as usize);
+            }
+        }
+        None
+    }
+
     pub fn first_bit_set(&self) -> Option<usize> {
         for (idx, v) in self.data.iter().enumerate() {
             if *v != 0 {

From 24f4a3fdd35a09b2c1bd949a7c6aef979b7507ec Mon Sep 17 00:00:00 2001
From: Michal Moskal
Date: Wed, 22 May 2024 18:58:43 +0000
Subject: [PATCH 224/301] add hashcons for future

---
 controllers/aici_abi/src/hashcons.rs | 83 ++++++++++++++++++++++++++++
 controllers/aici_abi/src/lib.rs      |  1 +
 2 files changed, 84 insertions(+)
 create mode 100644 controllers/aici_abi/src/hashcons.rs

diff --git a/controllers/aici_abi/src/hashcons.rs b/controllers/aici_abi/src/hashcons.rs
new file mode 100644
index 00000000..c3fec38b
--- /dev/null
+++ b/controllers/aici_abi/src/hashcons.rs
@@ -0,0 +1,83 @@
+use std::collections::HashMap;
+
+use bytemuck_derive::{Pod, Zeroable};
+
+#[derive(Clone, Copy, Zeroable, Pod)]
+#[repr(transparent)]
+pub struct HashNode(u32);
+
+pub struct HashConstructor {
+    data: Vec<u32>,
+    hash: HashMap<Vec<u32>, u32>,
+}
+
+pub struct HashNodeRef<'a> {
+    constructor: &'a HashConstructor,
+    node: HashNode,
+}
+
+impl<'a> HashNodeRef<'a> {
+    pub fn head(&self) -> u8 {
+        self.constructor.head(self.node)
+    }
+
+    pub fn children(&self) -> &[HashNode] {
+        self.constructor.children(self.node)
+    }
+
+    pub fn iter(&'a self) -> impl Iterator<Item = HashNodeRef<'a>> {
+        self.children().iter().map(move |&n| self.constructor.node_ref(n))
+    }
+}
+
+pub type HeadType = u8;
+
+impl HashConstructor {
+    fn mk_head(&self, head: HeadType, arity: usize) -> u32 {
+        assert!(arity < 1 << 20);
+        (head as u32) | ((arity as u32) << 8)
+    }
+
+    pub fn node_ref(&self, node: HashNode) -> HashNodeRef {
+        HashNodeRef {
+            constructor: self,
+            node,
+        }
+    }
+
+    pub fn mk(&mut self, head: HeadType, children: &[HashNode]) -> HashNode {
+        let mut data: Vec<u32> = Vec::with_capacity(1 + children.len());
+        data.push(self.mk_head(head, children.len()));
+        for child in children {
+            data.push(child.0);
+        }
+        if let Some(r) = self.hash.get(&data) {
+            HashNode(*r)
+        } else {
+            let r = self.data.len() as u32;
+            self.data.extend_from_slice(&data);
+            self.hash.insert(data, r);
+            HashNode(r)
+        }
+    }
+
+    pub fn get(&self, node: HashNode) -> (HeadType, &[HashNode]) {
+        let idx = node.0 as usize;
+        let head = self.data[idx];
+        let arity = (head >> 8) as usize;
+        let head = head as u8;
+        let children = bytemuck::cast_slice(&self.data[idx + 1..idx + 1 + arity]);
+        (head, children)
+    }
+
+    pub fn head(&self, node: HashNode) -> HeadType {
+        let idx = node.0 as usize;
+        (self.data[idx] & 0xff) as HeadType
+    }
+
+    pub fn children(&self, node: HashNode) -> &[HashNode] {
+        let idx = node.0 as usize;
+        let arity = (self.data[idx] >> 8) as usize;
+        bytemuck::cast_slice(&self.data[idx + 1..idx + 1 + arity])
+    }
+}
diff --git a/controllers/aici_abi/src/lib.rs b/controllers/aici_abi/src/lib.rs
index a85ba120..acdb8135 100644
--- a/controllers/aici_abi/src/lib.rs
+++ b/controllers/aici_abi/src/lib.rs
@@ -7,6 +7,7 @@ pub mod recognizer;
 pub mod rng;
 pub mod svob;
 pub mod toktree;
+pub mod hashcons;
 
 #[cfg(feature = "cfg")]
 pub mod cfg;

From 767aa5752511a215c94d71662e6776b82f4c6f97 Mon Sep 17 00:00:00 2001
From: Michal Moskal
Date: Wed, 22 May 2024 23:53:29 +0000
Subject: [PATCH 225/301] work on grammars --- controllers/aici_abi/src/hashcons.rs | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/controllers/aici_abi/src/hashcons.rs b/controllers/aici_abi/src/hashcons.rs index c3fec38b..826c5e12 100644 --- a/controllers/aici_abi/src/hashcons.rs +++ b/controllers/aici_abi/src/hashcons.rs @@ -11,12 +11,12 @@ pub struct HashConstructor { hash: HashMap, u32>, } -pub struct HashNodeRef<'a> { +pub struct BoundNode<'a> { constructor: &'a HashConstructor, node: HashNode, } -impl<'a> HashNodeRef<'a> { +impl<'a> BoundNode<'a> { pub fn head(&self) -> u8 { self.constructor.head(self.node) } @@ -25,8 +25,8 @@ impl<'a> HashNodeRef<'a> { self.constructor.children(self.node) } - pub fn iter(&'a self) -> impl Iterator> { - self.children().iter().map(move |&n| self.constructor.node_ref(n)) + pub fn iter(&'a self) -> impl Iterator> { + self.children().iter().map(move |&n| self.constructor.bind(n)) } } @@ -38,8 +38,8 @@ impl HashConstructor { (head as u32) | ((arity as u32) << 8) } - pub fn node_ref(&self, node: HashNode) -> HashNodeRef { - HashNodeRef { + pub fn bind(&self, node: HashNode) -> BoundNode { + BoundNode { constructor: self, node, } From d74eb85bc802849e3a53015884598d52f50178e0 Mon Sep 17 00:00:00 2001 From: Michal Moskal Date: Fri, 24 May 2024 15:55:34 +0000 Subject: [PATCH 226/301] simplify --- controllers/aici_abi/src/hashcons.rs | 102 ++++++++------------------- 1 file changed, 31 insertions(+), 71 deletions(-) diff --git a/controllers/aici_abi/src/hashcons.rs b/controllers/aici_abi/src/hashcons.rs index 826c5e12..0d58eb12 100644 --- a/controllers/aici_abi/src/hashcons.rs +++ b/controllers/aici_abi/src/hashcons.rs @@ -1,83 +1,43 @@ -use std::collections::HashMap; +use std::{collections::HashMap, rc::Rc}; -use bytemuck_derive::{Pod, Zeroable}; - -#[derive(Clone, Copy, Zeroable, Pod)] -#[repr(transparent)] -pub struct HashNode(u32); - -pub struct HashConstructor { - data: Vec, - hash: HashMap, u32>, +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +struct VecHolder { + data: Rc>, } -pub struct BoundNode<'a> { - constructor: &'a HashConstructor, - node: HashNode, -} - -impl<'a> BoundNode<'a> { - pub fn head(&self) -> u8 { - self.constructor.head(self.node) - } - - pub fn children(&self) -> &[HashNode] { - self.constructor.children(self.node) - } - - pub fn iter(&'a self) -> impl Iterator> { - self.children().iter().map(move |&n| self.constructor.bind(n)) - } +pub struct VecHashMap { + by_id: Vec, + by_data: HashMap, } -pub type HeadType = u8; - -impl HashConstructor { - fn mk_head(&self, head: HeadType, arity: usize) -> u32 { - assert!(arity < 1 << 20); - (head as u32) | ((arity as u32) << 8) - } - - pub fn bind(&self, node: HashNode) -> BoundNode { - BoundNode { - constructor: self, - node, +impl VecHashMap { + pub fn new() -> Self { + let mut r = VecHashMap { + by_id: Vec::new(), + by_data: HashMap::new(), + }; + r.insert(Vec::new()); + r + } + + pub fn insert(&mut self, data: Vec) -> u32 { + let holder = VecHolder { + data: Rc::new(data), + }; + if let Some(&id) = self.by_data.get(&holder) { + return id; } + let id = self.by_id.len() as u32; + self.by_id.push(holder.clone()); + self.by_data.insert(holder, id); + id } - pub fn mk(&mut self, head: HeadType, children: &[HashNode]) -> HashNode { - let mut data: Vec = Vec::with_capacity(1 + children.len()); - data.push(self.mk_head(head, children.len())); - for child in children { - data.push(child.0); - } - if let Some(r) = self.hash.get(&data) { - HashNode(*r) - } else { - let r = 
self.data.len() as u32; - self.data.extend_from_slice(&data); - self.hash.insert(data, r); - HashNode(r) - } - } - - pub fn get(&self, node: HashNode) -> (HeadType, &[HashNode]) { - let idx = node.0 as usize; - let head = self.data[idx]; - let arity = (head >> 8) as usize; - let head = head as u8; - let children = bytemuck::cast_slice(&self.data[idx + 1..idx + 1 + arity]); - (head, children) - } - - pub fn head(&self, node: HashNode) -> HeadType { - let idx = node.0 as usize; - (self.data[idx] & 0xff) as HeadType + pub fn get(&self, id: u32) -> Option<&[u32]> { + self.by_id.get(id as usize).map(|holder| &holder.data[..]) } - pub fn children(&self, node: HashNode) -> &[HashNode] { - let idx = node.0 as usize; - let arity = (self.data[idx] >> 8) as usize; - bytemuck::cast_slice(&self.data[idx + 1..idx + 1 + arity]) + pub fn len(&self) -> usize { + self.by_id.len() } } From d8ce4d9fec91c245a8385aa0cc316c0da3b0ae0f Mon Sep 17 00:00:00 2001 From: Michal Moskal Date: Sat, 25 May 2024 21:31:28 +0000 Subject: [PATCH 227/301] starting derivative regexps --- controllers/aici_abi/src/hashcons.rs | 43 ---------------------------- controllers/aici_abi/src/lib.rs | 1 - 2 files changed, 44 deletions(-) delete mode 100644 controllers/aici_abi/src/hashcons.rs diff --git a/controllers/aici_abi/src/hashcons.rs b/controllers/aici_abi/src/hashcons.rs deleted file mode 100644 index 0d58eb12..00000000 --- a/controllers/aici_abi/src/hashcons.rs +++ /dev/null @@ -1,43 +0,0 @@ -use std::{collections::HashMap, rc::Rc}; - -#[derive(Debug, Clone, PartialEq, Eq, Hash)] -struct VecHolder { - data: Rc>, -} - -pub struct VecHashMap { - by_id: Vec, - by_data: HashMap, -} - -impl VecHashMap { - pub fn new() -> Self { - let mut r = VecHashMap { - by_id: Vec::new(), - by_data: HashMap::new(), - }; - r.insert(Vec::new()); - r - } - - pub fn insert(&mut self, data: Vec) -> u32 { - let holder = VecHolder { - data: Rc::new(data), - }; - if let Some(&id) = self.by_data.get(&holder) { - return id; - } - let id = self.by_id.len() as u32; - self.by_id.push(holder.clone()); - self.by_data.insert(holder, id); - id - } - - pub fn get(&self, id: u32) -> Option<&[u32]> { - self.by_id.get(id as usize).map(|holder| &holder.data[..]) - } - - pub fn len(&self) -> usize { - self.by_id.len() - } -} diff --git a/controllers/aici_abi/src/lib.rs b/controllers/aici_abi/src/lib.rs index acdb8135..a85ba120 100644 --- a/controllers/aici_abi/src/lib.rs +++ b/controllers/aici_abi/src/lib.rs @@ -7,7 +7,6 @@ pub mod recognizer; pub mod rng; pub mod svob; pub mod toktree; -pub mod hashcons; #[cfg(feature = "cfg")] pub mod cfg; From d45fe5d02c840fac03db732791c9bb2426edb5fd Mon Sep 17 00:00:00 2001 From: Michal Moskal Date: Tue, 28 May 2024 01:28:40 +0000 Subject: [PATCH 228/301] fix multi-matching; docs --- controllers/aici_abi/src/svob.rs | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/controllers/aici_abi/src/svob.rs b/controllers/aici_abi/src/svob.rs index e6d99e27..926f0de9 100644 --- a/controllers/aici_abi/src/svob.rs +++ b/controllers/aici_abi/src/svob.rs @@ -52,6 +52,14 @@ impl SimpleVob { } } + pub fn from_slice(bits: &[bool]) -> Self { + let mut r = Self::alloc(bits.len()); + for (idx, b) in bits.iter().enumerate() { + r.set(idx, *b); + } + r + } + pub fn alloc(size: usize) -> Self { let mut r = Self::new(); r.resize(size); @@ -73,6 +81,14 @@ impl SimpleVob { } } + pub fn to_bin_string(&self) -> String { + let mut s = String::new(); + for i in 0..self.size { + s.push(if self.is_allowed(i as TokenId) { '1' } else { '0' }); + } + 
s + } + pub fn negated(&self) -> Self { let mut r = Self::new(); r.data = self.data.iter().map(|x| !x).collect(); From 13ee330ae839dedfed94baa7ee16b191aa9c1342 Mon Sep 17 00:00:00 2001 From: Michal Moskal Date: Wed, 5 Jun 2024 09:24:16 -0700 Subject: [PATCH 229/301] add tests and optimizations to derivre --- controllers/aici_abi/src/svob.rs | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/controllers/aici_abi/src/svob.rs b/controllers/aici_abi/src/svob.rs index 926f0de9..ff6d9abb 100644 --- a/controllers/aici_abi/src/svob.rs +++ b/controllers/aici_abi/src/svob.rs @@ -1,6 +1,7 @@ -use crate::TokenId; use std::{fmt::Debug, hash::Hash, ops::Index}; +pub type TokenId = u32; + #[derive(Clone)] pub struct SimpleVob { data: Vec, @@ -66,6 +67,12 @@ impl SimpleVob { r } + pub fn all_true(size: usize) -> Self { + let mut r = Self::alloc(size); + r.set_all(true); + r + } + pub fn len(&self) -> usize { self.size } From 1cf38d1a6ef3534bc6622a4ab412f8cdff131c12 Mon Sep 17 00:00:00 2001 From: Michal Moskal Date: Wed, 5 Jun 2024 17:30:30 +0000 Subject: [PATCH 230/301] refactor try_push_byte --- controllers/aici_abi/src/toktree.rs | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/controllers/aici_abi/src/toktree.rs b/controllers/aici_abi/src/toktree.rs index 18d9518b..b7d19b62 100644 --- a/controllers/aici_abi/src/toktree.rs +++ b/controllers/aici_abi/src/toktree.rs @@ -20,12 +20,6 @@ pub enum SpecialToken { } pub trait Recognizer { - /// If `stack.top()` transitions via `byte` to `X`, execute `stack.push(X)`. - fn push_byte(&mut self, byte: u8) { - if !self.try_push_byte(byte) { - panic!("byte {:?} not allowed", byte as char) - } - } /// for _ in 0..num { stack.pop() } fn pop_bytes(&mut self, num: usize); /// X = stack.top(); stack.empty(); stack.push(X) @@ -562,6 +556,7 @@ impl TokTrie { let bytes = self.token(t); let mut num = 0; let mut ok = true; + r.trie_started(); for &byte in bytes { if r.try_push_byte(byte) { num += 1; @@ -571,6 +566,7 @@ impl TokTrie { } } r.pop_bytes(num); + r.trie_finished(); ok } From 06d7763cefddf9cb45fd64187da076af0f8aff5f Mon Sep 17 00:00:00 2001 From: Michal Moskal Date: Thu, 6 Jun 2024 15:19:31 +0000 Subject: [PATCH 231/301] clean up lexer API --- controllers/aici_abi/src/toktree.rs | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/controllers/aici_abi/src/toktree.rs b/controllers/aici_abi/src/toktree.rs index b7d19b62..2ed71eda 100644 --- a/controllers/aici_abi/src/toktree.rs +++ b/controllers/aici_abi/src/toktree.rs @@ -476,6 +476,23 @@ impl TokTrie { None } + pub fn all_subtokens(&self, bytes: &[u8]) -> Vec { + let mut r = Vec::new(); + for i in 0..bytes.len() { + let mut n = self.root(); + for j in i..bytes.len() { + n = match self.child_at_byte(n, bytes[j]) { + Some(n) => n, + None => break, + }; + if let Some(tok) = n.token_id() { + r.push(tok); + } + } + } + r + } + pub fn node_children(&self, n: &TrieNode) -> NodeChildren { let off = self.node_offset(n); NodeChildren { From 1bd1e07e83859dad9f10256916edcb41bfa654e6 Mon Sep 17 00:00:00 2001 From: Michal Moskal Date: Fri, 7 Jun 2024 22:56:46 +0000 Subject: [PATCH 232/301] add a hack for tokenization of partial utf8 --- controllers/aici_abi/src/toktree.rs | 23 +++++++++++++++++++++++ controllers/aici_native/src/bintokens.rs | 16 ++++++++-------- 2 files changed, 31 insertions(+), 8 deletions(-) diff --git a/controllers/aici_abi/src/toktree.rs b/controllers/aici_abi/src/toktree.rs index 2ed71eda..cce449d9 100644 --- 
a/controllers/aici_abi/src/toktree.rs +++ b/controllers/aici_abi/src/toktree.rs @@ -335,6 +335,29 @@ impl TokTrie { r } + pub fn tokenize_with_greedy_fallback( + &self, + s: &[u8], + str_tokenize: impl FnOnce(&str) -> Vec, + ) -> Vec { + let utf8_str = String::from_utf8_lossy(s); + // if the string ends with a replacement character, remove them + let to_tokenize = if utf8_str.ends_with('\u{FFFD}') { + utf8_str.trim_end_matches('\u{FFFD}') + } else { + &utf8_str + }; + let mut r = str_tokenize(to_tokenize); + // if we didn't tokenize everything (because of the replacement character) + // we tokenize the suffix using greedy tokenizer that is happy with bytes + let last_tokenized = to_tokenize.len(); + if last_tokenized < s.len() { + let mut added = self.greedy_tokenize(&s[last_tokenized..]); + r.append(&mut added); + } + r + } + pub fn has_extensions(&self, bytes: &[u8]) -> bool { match self.child_at_bytes(self.root(), bytes) { None => false, diff --git a/controllers/aici_native/src/bintokens.rs b/controllers/aici_native/src/bintokens.rs index da632c26..d1205938 100644 --- a/controllers/aici_native/src/bintokens.rs +++ b/controllers/aici_native/src/bintokens.rs @@ -371,13 +371,13 @@ impl TokenizerEnv for ByteTokenizerEnv { } fn tokenize_bytes(&self, s: &[u8]) -> Vec { - let tokens = self - .tokenizer - .hf_tokenizer - .encode(String::from_utf8_lossy(s), false); - match tokens { - Err(e) => panic!("tokenize error: {e}"), - Ok(tokens) => Vec::from(tokens.get_ids()), - } + self.tok_trie.tokenize_with_greedy_fallback(s, |s| { + self.tokenizer + .hf_tokenizer + .encode(s, false) + .expect("tokenizer error") + .get_ids() + .to_vec() + }) } } From ba4ff745bbee99499286828ec68acb7de1473214 Mon Sep 17 00:00:00 2001 From: Michal Moskal Date: Tue, 25 Jun 2024 00:12:16 +0000 Subject: [PATCH 233/301] add one more EOS variation --- controllers/aici_native/src/bintokens.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/controllers/aici_native/src/bintokens.rs b/controllers/aici_native/src/bintokens.rs index d1205938..60ceda95 100644 --- a/controllers/aici_native/src/bintokens.rs +++ b/controllers/aici_native/src/bintokens.rs @@ -260,7 +260,7 @@ impl ByteTokenizer { for (id, info) in added.iter() { if info.special { match info.content.as_str() { - "" | "<|endoftext|>" => res.eos_token = *id, + "" | "<|endoftext|>" | "<|end_of_text|>" => res.eos_token = *id, _ => {} } res.special.insert(info.content.clone(), *id); From 89124a2f5cd4e0e6de891110dd8bac437e1d9946 Mon Sep 17 00:00:00 2001 From: Michal Moskal Date: Wed, 26 Jun 2024 00:28:02 +0000 Subject: [PATCH 234/301] improve dbg output --- controllers/aici_abi/src/toktree.rs | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/controllers/aici_abi/src/toktree.rs b/controllers/aici_abi/src/toktree.rs index cce449d9..33d5f656 100644 --- a/controllers/aici_abi/src/toktree.rs +++ b/controllers/aici_abi/src/toktree.rs @@ -213,8 +213,12 @@ impl TokTrie { let num_set = ts1.num_set(); let max_tok = std::cmp::min(max_examples, num_set); let mut token_names = Vec::new(); + // make sure we include EOS first if it's allowed + if ts1.is_allowed(self.info.tok_eos) { + token_names.push("EOS".to_string()); + } for idx in 0..self.vocab_size() { - if ts1.is_allowed(idx as TokenId) { + if idx as TokenId != self.info.tok_eos && ts1.is_allowed(idx as TokenId) { token_names.push(self.token_dbg(idx as TokenId)); if token_names.len() >= max_tok { break; From b30ab8b7f22e37a91cfbe58a94af0f076396bd32 Mon Sep 17 00:00:00 2001 From: Michal Moskal Date: 
Fri, 28 Jun 2024 00:45:01 +0000 Subject: [PATCH 235/301] add LLTokenizer.test_trace_tokens --- controllers/aici_abi/src/toktree.rs | 31 +++++++++++++++++------------ 1 file changed, 18 insertions(+), 13 deletions(-) diff --git a/controllers/aici_abi/src/toktree.rs b/controllers/aici_abi/src/toktree.rs index 33d5f656..dd846d8e 100644 --- a/controllers/aici_abi/src/toktree.rs +++ b/controllers/aici_abi/src/toktree.rs @@ -241,30 +241,35 @@ impl TokTrie { vec![0.0; self.vocab_size() + 1] } + pub fn test_trace_tokens(&self, toks: &[u32]) -> String { + toks.iter() + .map(|t| { + let s = self.token_dbg(*t); + if s.starts_with("\"") { + self.token_str(*t) + } else { + format!("≺{}≻", s) + } + }) + .collect::>() + .join("‧") + } + pub fn tokens_dbg(&self, toks: &[u32]) -> String { - let minimal = false; - let sep = "‧"; let joined = toks .iter() .map(|t| { let s = self.token_dbg(*t); if s.starts_with("\"") { - let inner = s[1..s.len() - 1].to_string(); - let b = s.as_bytes(); - // for " [\w]..." and " " the sep in front is implicit - if minimal && b[1] == b' ' && ((b[2] as char).is_alphanumeric() || b.len() == 3) - { - inner - } else { - format!("{}{}", sep, inner) - } + s[1..s.len() - 1].to_string() } else { format!("≺{}≻", s) } }) .collect::>() - .join(""); - format!("\"{}\"", joined.trim_start_matches(sep)) + .join("‧"); + + format!("\"{}\"", joined) } pub fn token_dbg(&self, idx: u32) -> String { From 1bb34fb01918f656606834a005b0c9d437d4f43b Mon Sep 17 00:00:00 2001 From: Michal Moskal Date: Fri, 28 Jun 2024 18:17:00 +0000 Subject: [PATCH 236/301] better handling of nested grammars --- controllers/aici_abi/src/svob.rs | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/controllers/aici_abi/src/svob.rs b/controllers/aici_abi/src/svob.rs index ff6d9abb..9c648827 100644 --- a/controllers/aici_abi/src/svob.rs +++ b/controllers/aici_abi/src/svob.rs @@ -276,6 +276,15 @@ impl SimpleVob { } } + /// self |= other & !minus + pub fn or_minus(&mut self, other: &SimpleVob, minus: &SimpleVob) { + assert_eq!(self.size, other.size); + assert_eq!(self.size, minus.size); + for ((slf, oth), mn) in self.data.iter_mut().zip(other.data.iter()).zip(minus.data.iter()) { + *slf |= *oth & !*mn; + } + } + pub fn and(&mut self, other: &SimpleVob) { assert_eq!(self.size, other.size); for (idx, v) in self.data.iter_mut().zip(other.data.iter()) { From 613a27c8b9935ee8e89da3c0060a05db5799aae7 Mon Sep 17 00:00:00 2001 From: Michal Moskal Date: Tue, 2 Jul 2024 22:54:47 +0000 Subject: [PATCH 237/301] revamp eos handling --- controllers/aici_abi/src/toktree.rs | 19 ++++++++++--------- 1 file changed, 10 insertions(+), 9 deletions(-) diff --git a/controllers/aici_abi/src/toktree.rs b/controllers/aici_abi/src/toktree.rs index dd846d8e..99271f01 100644 --- a/controllers/aici_abi/src/toktree.rs +++ b/controllers/aici_abi/src/toktree.rs @@ -558,15 +558,6 @@ impl TokTrie { } } } - // all prefixes of 'start' are also allowed - if start.len() > 0 { - for len in 1..=start.len() { - let bytes = &start[0..len]; - if let Some(tok) = self.token_id(bytes) { - logits.allow_token(tok); - } - } - } self.add_bias(r, logits, start); self.apply_duplicates(logits); } @@ -682,6 +673,16 @@ impl TokTrie { #[inline(never)] pub fn add_bias(&self, r: &mut impl Recognizer, toks: &mut SimpleVob, start: &[u8]) { + // all prefixes of 'start' are also allowed + if start.len() > 0 { + for len in 1..=start.len() { + let bytes = &start[0..len]; + if let Some(tok) = self.token_id(bytes) { + toks.allow_token(tok); + } + } + } + r.trie_started(); let n = 
self.child_at_bytes(self.root(), start).unwrap();
         let defl_tok = self.vocab_size() as u32;
 

From fecf72e0a7202edcfdc3d501e7caeb3e1916b0ef Mon Sep 17 00:00:00 2001
From: Michal Moskal
Date: Thu, 4 Jul 2024 00:01:58 +0000
Subject: [PATCH 238/301] start on toktrie split

---
 controllers/toktrie/Cargo.toml                |  15 +
 controllers/toktrie/README.md                 | 204 +++++++++++
 controllers/toktrie/implementation.md         | 153 +++++++++
 .../{aici_abi => toktrie}/src/bytes.rs        |   0
 controllers/toktrie/src/lib.rs                | 319 ++++++++++++++++++
 .../{aici_abi => toktrie}/src/recognizer.rs   |   3 +-
 controllers/{aici_abi => toktrie}/src/rng.rs  |   0
 controllers/{aici_abi => toktrie}/src/svob.rs |   0
 .../{aici_abi => toktrie}/src/toktree.rs      |   0
 9 files changed, 693 insertions(+), 1 deletion(-)
 create mode 100644 controllers/toktrie/Cargo.toml
 create mode 100644 controllers/toktrie/README.md
 create mode 100644 controllers/toktrie/implementation.md
 rename controllers/{aici_abi => toktrie}/src/bytes.rs (100%)
 create mode 100644 controllers/toktrie/src/lib.rs
 rename controllers/{aici_abi => toktrie}/src/recognizer.rs (96%)
 rename controllers/{aici_abi => toktrie}/src/rng.rs (100%)
 rename controllers/{aici_abi => toktrie}/src/svob.rs (100%)
 rename controllers/{aici_abi => toktrie}/src/toktree.rs (100%)

diff --git a/controllers/toktrie/Cargo.toml b/controllers/toktrie/Cargo.toml
new file mode 100644
index 00000000..20a8dd63
--- /dev/null
+++ b/controllers/toktrie/Cargo.toml
@@ -0,0 +1,15 @@
+[package]
+name = "toktrie"
+version = "0.1.0"
+edition = "2021"
+
+[lib]
+name = "toktrie"
+
+[dependencies]
+serde = { version = "1.0.192", features = ["derive"] }
+serde_json = "1.0.108"
+anyhow = "1.0.75"
+bytemuck = "1.16.0"
+bytemuck_derive = "1.6.0"
+rustc-hash = { version = "2.0.0" }
diff --git a/controllers/toktrie/README.md b/controllers/toktrie/README.md
new file mode 100644
index 00000000..a15dd336
--- /dev/null
+++ b/controllers/toktrie/README.md
@@ -0,0 +1,204 @@
+# aici_abi
+
+This crate specifies the application binary interface (ABI) for the AICI Controllers.
+It also provides higher-level interfaces for implementing controllers.
+
+## Low-level interface
+
+Conceptually, the lowest-level interface to an AICI constraint is this:
+
+```rust
+type TokenId = u32;
+type SeqId = u32;
+
+trait AiciCtrl {
+    /// Called with the initial prompt. ~1000ms time limit.
+    fn init_prompt(prompt: Vec<TokenId>);
+
+    /// Called before mid_process(), can fork or suspend. ~1ms.
+    fn pre_process() -> enum {
+        Stop,
+        Continue, // Same as Fork { num_forks: 1 }
+        Suspend, // skip this generation round
+        Fork { num_forks: u32 },
+    }
+
+    /// This is the main entry point for the module. ~20ms.
+    fn mid_process(fork_group: Vec<SeqId>) -> enum {
+        Stop,
+        SampleWithBias { bias: Vec<f32> },
+        Splice { backtrack: u32, ff_tokens: Vec<TokenId> }
+    };
+
+    /// Called after tokens are appended. ~1ms.
+    fn post_process(tokens: Vec<TokenId>) -> enum { Stop, Continue };
+}
+```
+
+Tokens depend on the tokenizer used (eg., for Llama there are 32000 tokens, and for GPT-4 there are ~100k).
+
+The actual binary interface is a bit more complicated, due
+to limitations in passing values to and from Wasm.
+A Wasm module instance is created for each token sequence.
+Also, when the sequence forks (as in beam search), the module instance is cloned.
+See the [AiciCtrl Rust trait](src/lib.rs) for details.
+
+A number of functions are exposed to the Wasm module.
+
+First, there are functions for accessing the current tokenizer:
+
+```rust
+/// Given a byte sequence, return a sequence of token Ids.
+fn tokenize_bytes(s: Vec) -> Vec; + +/// Represents trie of all tokens in the current tokenizer. +impl TokTrie { + /// Get Id for EOS token etc. + fn special_token(tok: SpecialToken) -> TokenId; + /// Number of tokens. + fn vocab_size() -> usize; + /// Convert token Id to bytes (often UTF-8 string). + fn token(token: TokenId) -> Vec; + /// Given a Recognizer, compute the set of allowed tokens. + fn compute_bias(rec: impl Recognizer) -> Vec; +} +``` + +Different forks in a sequence can communicate via shared variables: + +```rust +/// This can be looked up in fork_group. +fn self_seq_id() -> SeqId; + +trait VariableStorage { + fn get(name: str) -> Option>; + fn set(name: str, value: Vec); + fn append(name: str, value: Vec); +} +``` + +Additionally, the `stdout` and `stderr` file descriptors are captured by the runtime +and returned to user when streaming results. + +This interface may need to be extended in the future. + +## Byte stack interface + +The constraints are typically expressed on strings or bytes, not tokens. +To compute the set of tokens that match a string constraint, one needs go through all the possible tokens +and apply the constraint. +An efficient way to do this is walk a prefix tree (trie) of all tokens. +The `aici_abi` library implements this trie and exposes a way of filtering when provided with a constraints +implementing the [following interface](src/toktree.rs): + +```rust +pub trait Recognizer { + /// If `stack.top()` transitions via `byte` to `X`, execute `stack.push(X)`. + fn push_byte(&mut self, byte: u8); + /// for _ in 0..num { stack.pop() } + fn pop_bytes(&mut self, num: usize); + /// X = stack.top(); stack.empty(); stack.push(X) + fn collapse(&mut self); + /// check if stack.top() transitions via byte to a viable state + fn byte_allowed(&mut self, byte: u8) -> bool; + /// check if stack.top() transitions via tok to a viable state + fn special_allowed(&mut self, tok: SpecialToken) -> bool; + /// Called when iteration over the trie is finished + /// Stack has exactly one element then. + fn trie_finished(&mut self); + /// This combines `push_byte` and `byte_allowed` into one function for performance. + fn try_push_byte(&mut self, byte: u8) -> bool; +} +``` + +The `AiciRecognizer` struct converts `Recognizer` to `AiciCtrl`. + +## Functional byte interface + +The following interface can be transformed into `Recognizer` using `StackRecognizer` struct. + +```rust +pub trait FunctionalRecognizer { + /// Initial state + fn initial(&self) -> S; + /// Extend the recognizer with given byte. + fn append(&self, state: S, byte: u8) -> S; + /// Check if given byte is allowed in given state. + fn byte_allowed(&self, state: S, byte: u8) -> bool; + /// Check if given special token is allowed in given state. + fn special_allowed(&self, state: S, tok: SpecialToken) -> bool; +} +``` + +These three layers add up to about 40k of compiled code (Wasm). + +## Regular expressions + +The `FunctionalRecognizer` interface is implemented for regular expressions. +The `S` type is the state of the DFA (Deterministic Finite Automaton) that recognizes the regular expression, +then `append()` and `byte_allowed()` are the standard DFA operations, +while `special_allowed()` is only implemented for end-of-sequence token +(which is allowed when the current state is accepting). + +## LR(1) grammars + +The `Recognizer` interface is implemented for LR(1) grammars and DFA-based lexers. + +The grammar uses inline syntax for the lexer: + +- `"keyword"` or `'keyword'` for keywords; any string works, eg. 
`"+="`, `"while"`, ... +- `"/.../"` or `'/.../'` for regular expressions; you cannot have both `'` and `"` in the regex + Special `SKIP` rule is used to indicate tokens that need to be skipped by the LR(1) parser (eg., whitespace and comments) + +The lexer has a DFA which recognizes all regexps and keywords +(a big disjunction, but with additional machinery to disambiguate between different branches). +It goes byte by byte, until the DFA gets to a dead state (from which no match is possible). +Then it goes back one byte and checks for match. +It prefers keywords over regexps. +If no match is found, an error is reported, which requires careful design of the lexical part of the grammar +(eg., see how the `white-space` rule below is prefix of the `pre-processor` rule). + +For example, this is fragment of [grammar for C](./grammars/c.y): + +```yacc +%start translation_unit +%% + +SKIP + : "//\*[^*]*\*+([^/*][^*]*\*+)*//" // block comment + | "///.*/" // line comment + | "/\n[ \t\v\f]*#(.*\\\n)*.*/" // pre-processor + | "/\n?[ \t\v\f]*/" // white-space + ; + +IDENTIFIER: "/[a-zA-Z_][0-9a-zA-Z_]*/" ; + +CONSTANT + : "/0[xX][0-9a-fA-F]+[uUlL]*?/" + | "/0[0-9]+[uUlL]*?/" + ; + +STRING_LITERAL: '/"(\\.|[^\\"])*"/' ; + +primary_expression + : IDENTIFIER + | CONSTANT + | STRING_LITERAL + | "(" expression ")" + ; + +// ... + +enum_specifier + : "enum" "{" enumerator_list "}" + | "enum" IDENTIFIER "{" enumerator_list "}" + | "enum" IDENTIFIER + ; + +// ... + +translation_unit + : external_declaration + | translation_unit external_declaration + ; +``` diff --git a/controllers/toktrie/implementation.md b/controllers/toktrie/implementation.md new file mode 100644 index 00000000..bd766709 --- /dev/null +++ b/controllers/toktrie/implementation.md @@ -0,0 +1,153 @@ +# Implementation notes + +## Token trie + +The round nodes represent tokens, the square nodes do not have a corresponding token. + +The number (`num_parents`) specifies how many parents do you need to pop to get to the parent of the node which comes after our children in DFS order. + +We also keep the `token_id` and a `subtree_size` (which includes the node itself) in each node. +A bogus `token_id` is used for nodes that do not have a corresponding token. + +```mermaid +graph TD + root[ε, 0] -- a --> a((a, 1)) + root -- b --> b((b, 1)) + root -- c --> c((c, 1)) + a -- x --> ax((ax, 1)) + a -- y --> ay[ay, 1] + a -- z --> az((az, 2)) + az -- a --> azq((aza, 3)) + ay -- a --> ayq((aya, 1)) + ay -- b --> ayw((ayb, 2)) +``` + +Traversal algorithm - computing the set of tokens allowed by a stack-based recognizer. +The set is stored in `logits` array - entries with `0.0` are allowed. 
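+The array has one extra entry, at index `VOCAB_SIZE`, which serves as a sink for
+trie nodes that do not correspond to any token: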
+ +```rust +let mut logits = vec![-100.0; VOCAB_SIZE + 1]; +``` + +A simple version of traversal algorithm: + +```rust +fn traverse(n) { + // mark token as allowed; nodes without token use `token_id == VOCAB_SIZE` + logits[n.token_id] = 0.0; + for c in n.children { + // for every child that starts with an allowed byte + if byte_allowed(c.byte) { + push_byte(c.byte); + // traverse it + traverse(c); + pop_bytes(1); + } + } +} +``` + +Now, assume the tree is laid out in memory in DFS order: + +```rust +fn traverse(mut p) { + let endp = p + nodes[p].subtree_size; + p += 1; // move to first child + while p < endp { + let n = nodes[p]; + if byte_allowed(n.byte) { + push_byte(n.byte); + logits[n.token_id] = 0.0; + // p is moved by n.subtree_size + p = traverse(p); + pop_bytes(1); + } else { + p += n.subtree_size; + } + } +} +``` + +Now, we get rid of the recursion: + +```rust +let mut p = 0; +while p < nodes.len() { + let n = nodes[p]; + if byte_allowed(n.byte) { + push_byte(n.byte); + logits[n.token_id] = 0.0; + // if the node is a leaf, we need to pop all the parents + pop_bytes(if n.subtree_size == 1 { n.num_parents } else { 0 }); + // move to first child, or sibling if no children + p += 1; + } else { + // skip the children, and go to the sibling node + p += n.subtree_size; + // regardless if the node is a leaf, we need to pop all the parents + pop_bytes(n.num_parents - 1); + } +} +``` + +Note that the only branch that gets mis-predicted here is the `if byte_allowed(n.byte)`. +The `if` in argument to `pop_bytes` is compiled to bit operations, so it is branchless. + +## LR(1) parsing + +The LR(1) parsing consists of DFA-based lexer and the actual LR(1) parser. +DFA has a single number as the state, while the state of the LR(1) is a stack of numbers. +The LR(1) action is determined based on the next token from the lexer and the top of the stack. + +The `Recognizer` interface also has a concept of stack, however every entry on that +stack contains a DFA state and an LR(1) stack. + +Most of the time (~98.5% for the C grammar), pushing a byte involves only updating the DFA state, +while the LR(1) stack is copied unchanged (the memory is shared). + + +### Early error detection + +Consider the following invalid C program: + +```c +int 123456; +``` + +The lexer would produce `int` keyword, whitespace, `123456` constant and `;` keyword. +The parser would reject `123456`, however only after all six characters of it have been read. +This is too late for the LLM. + +To detect such errors early, we compute a set of reachable tokens for each DFA state. +For example, consider a DFA that recognizes `int`, `if`, `ID` (`/[a-z][a-z0-9]*/`) and `INTLIT` (`/[0-9]+/`). +The initial DFA state has a full set of tokens, while a state after `'i'` +has only `int`, `if`, and `ID`, +and a state after `'1'` includes only `INTLIT`. +In the picture below, each state is labelled by its reachable set, +and the token for which it is a match (if any) is postfixed with `*`. We only use lower-case letters and digits for simplicity. 
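+(For instance, after reading `in`, both `int` and `ID` are still reachable, so no
+error can be raised yet, while after `1x` the reachable set is empty.)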
+ +```mermaid +graph LR + 0["{int,if,ID,INTLIT}"] -- "[i]" --> i(("{int,if,ID*}")) + 0 -- "[a-z] - [i]" --> id(("{ID*}")) + 0 -- "[0-9]" --> const(("{INTLIT*}")) + const -- "[0-9]" --> const + const -- "[a-z]" --> bot["{}"] + i -- "[a-z0-9] - [nf]" --> id + id -- "[a-z0-9]" --> id + i -- "[n]" --> in(("{int,ID*}")) + in -- "[t]" --> int(("{int*,ID}")) + in -- "[a-z0-9] - [t]" --> id + int -- "[a-z0-9]" --> id + i -- "[f]" --> if(("{if*,ID}")) + if -- "[a-z0-9]" --> id +``` + +For each LR(1) automaton state we compute a set of viable tokens, i.e., ones that do +not immediately lead to an error. + +While parsing input, if the intersection of viable and reachable tokens is empty, we report an error. + +In the example above, the viable tokens after `int` do not include `INTLIT`, +and thus the parser fails immediately at `1`. + diff --git a/controllers/aici_abi/src/bytes.rs b/controllers/toktrie/src/bytes.rs similarity index 100% rename from controllers/aici_abi/src/bytes.rs rename to controllers/toktrie/src/bytes.rs diff --git a/controllers/toktrie/src/lib.rs b/controllers/toktrie/src/lib.rs new file mode 100644 index 00000000..a85ba120 --- /dev/null +++ b/controllers/toktrie/src/lib.rs @@ -0,0 +1,319 @@ +use serde::{Deserialize, Serialize}; +use svob::SimpleVob; + +pub mod bytes; +mod host; +pub mod recognizer; +pub mod rng; +pub mod svob; +pub mod toktree; + +#[cfg(feature = "cfg")] +pub mod cfg; +#[cfg(feature = "cfg")] +mod lex; + +#[cfg(feature = "rx")] +pub mod rx; + +pub mod dlex; + +pub mod substring; + +pub type TokenId = bytes::TokenId; + +pub use host::{ + aici_stop, arg_bytes, arg_string, get_config, host_trie, self_seq_id, tokenize, tokenize_bytes, + StorageCmd, StorageOp, StorageResp, TokenizerEnv, VariableStorage, WasmTokenizerEnv, +}; + +#[cfg(not(target_arch = "wasm32"))] +pub use host::{set_host, HostInterface}; + +#[derive(Serialize, Deserialize, Debug)] +pub struct InitPromptArg { + pub prompt: Vec, +} + +#[derive(Serialize, Deserialize, Debug)] +pub struct InitPromptResult { + pub prompt: Vec, +} + +impl InitPromptResult { + pub fn from_arg(arg: InitPromptArg) -> Self { + InitPromptResult { prompt: arg.prompt } + } +} + +#[repr(transparent)] +#[derive(Serialize, Deserialize, Debug, PartialEq, Eq)] +pub struct SeqId(pub u32); + +#[derive(Serialize, Deserialize, Debug)] +pub struct MidProcessArg { + /// Sampling result for the previous iteration. + /// For simple sampled token 't', backtrack==0 and tokens==[t]. + /// For first request, backtrack==0 and tokens==[] (prompt is passed separately, before). + /// Can be more complex when splices are used. 
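+    /// E.g., backtrack==1 with tokens==[t1, t2] means: remove the last sampled
+    /// token, then append t1 and t2 (cf. the `Splice` example further below).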
+    pub backtrack: u32,
+    pub tokens: Vec<TokenId>,
+    ///
+    pub fork_group: Vec<SeqId>,
+}
+
+impl MidProcessArg {
+    pub fn has_eos(&self) -> bool {
+        let eos = host::eos_token();
+        self.tokens.iter().any(|t| *t == eos)
+    }
+
+    pub fn save_tokens(&self, acc_tokens: &mut Vec<TokenId>) {
+        let bt = self.backtrack as usize;
+        assert!(
+            bt <= acc_tokens.len(),
+            "attempting to backtrack past beginning"
+        );
+        acc_tokens.truncate(acc_tokens.len() - bt);
+        acc_tokens.extend_from_slice(&self.tokens);
+    }
+}
+
+/*
+For example, if we're generating JSON, according to the following schema:
+{
+  "type": "object",
+  "properties": {
+    "name": {"type": "string"},
+    "age": {"type": "integer"}
+  }
+}
+
+Let's say we have generated: {"name": "something
+We would use a single splice:
+    when_sampled: ['"', '",', '", '],
+    backtrack: 1,
+    ff_tokens: tokenize('", "age": ')
+Which means: when any token starting with '"' is sampled, we remove it (backtrack: 1)
+and then append the next full fragment of JSON '", "age": '
+
+If the tokenizer has tokens like 'a"', 'b"' etc, then we would need many splices
+(there may be limits how many we want to pass over the IPC boundary).
+*/
+
+/// Describes what to do after sampling.
+#[derive(Serialize, Deserialize, Clone, Debug)]
+pub struct Splice {
+    /// If one of the tokens in when_sampled is sampled, this sequence is appended.
+    /// When empty, this sequence is appended unconditionally, regardless of sampling.
+    pub when_sampled: Vec<TokenId>,
+    /// Backtrack this much before appending this sequence (this includes sampled token if any).
+    pub backtrack: u32,
+    /// Append these tokens after backtracking.
+    pub ff_tokens: Vec<TokenId>,
+}
+
+#[derive(Serialize, Deserialize, Debug)]
+pub struct Branch<S> {
+    /// If None, no sampling is performed.
+    /// If Some(set), only tokens from the set are allowed.
+    pub sample_mask: Option<S>,
+    /// Override temperature for sampling. It may or may not be sticky.
+    pub temperature: Option<f32>,
+    /// Describes what to do after sampling.
+    /// If no sampling, there should be exactly one splice, with empty `when_sampled`.
+    pub splices: Vec<Splice>,
+}
+
+impl<S: Clone> Clone for Branch<S> {
+    fn clone(&self) -> Self {
+        Branch {
+            sample_mask: self.sample_mask.clone(),
+            temperature: self.temperature,
+            splices: self.splices.clone(),
+        }
+    }
+}
+
+impl<S> Branch<S> {
+    pub fn map_mask<T, F>(&self, f: F) -> Branch<T>
+    where
+        F: FnOnce(&S) -> T,
+    {
+        Branch {
+            sample_mask: self.sample_mask.as_ref().map(f),
+            temperature: self.temperature,
+            splices: self.splices.clone(),
+        }
+    }
+
+    pub fn splice(backtrack: u32, ff_tokens: Vec<TokenId>) -> Self {
+        Branch {
+            sample_mask: None,
+            temperature: None,
+            splices: vec![Splice {
+                when_sampled: vec![],
+                backtrack,
+                ff_tokens,
+            }],
+        }
+    }
+
+    pub fn noop() -> Self {
+        Self::splice(0, vec![])
+    }
+}
+
+#[derive(Debug)]
+pub struct MidProcessResult {
+    /// Fork the request into multiple branches.
+    /// Typically, exactly one branch is returned.
+    /// If multiple branches are returned, they are executed in parallel.
+    /// If no branches are returned, the request is terminated.
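+    /// Typically constructed with the helpers below, e.g. (sketch, where
+    /// `allowed` stands for some previously computed token set):
+    /// ```ignore
+    /// MidProcessResult::sample(allowed)        // constrain sampling to a set
+    /// MidProcessResult::splice(0, ff_tokens)   // force-append tokens
+    /// ```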
+    pub branches: Vec<Branch<SimpleVob>>,
+}
+
+impl MidProcessResult {
+    pub fn stop() -> Self {
+        MidProcessResult { branches: vec![] }
+    }
+
+    pub fn sample(set: SimpleVob) -> Self {
+        Self::sample_with_temp(set, None)
+    }
+
+    pub fn sample_with_temp(set: SimpleVob, temperature: Option<f32>) -> Self {
+        MidProcessResult {
+            branches: vec![Branch {
+                sample_mask: Some(set),
+                temperature: temperature,
+                splices: vec![],
+            }],
+        }
+    }
+
+    pub fn splice(backtrack: u32, ff_tokens: Vec<TokenId>) -> Self {
+        MidProcessResult {
+            branches: vec![Branch::splice(backtrack, ff_tokens)],
+        }
+    }
+
+    pub fn noop() -> Self {
+        Self::splice(0, vec![])
+    }
+
+    pub fn is_stop(&self) -> bool {
+        self.branches.is_empty()
+    }
+}
+
+#[derive(Serialize, Deserialize)]
+pub struct ProcessResultOffset {
+    /// Branches use byte offsets into the bias tensor.
+    pub branches: Vec<Branch<usize>>,
+}
+
+pub trait AiciCtrl {
+    /// Called with the initial prompt. ~1000ms time limit.
+    /// By default ignore prompt.
+    fn init_prompt(&mut self, arg: InitPromptArg) -> InitPromptResult {
+        InitPromptResult::from_arg(arg)
+    }
+
+    /// This is the main entry point for the module. ~20ms time limit.
+    fn mid_process(&mut self, arg: MidProcessArg) -> MidProcessResult;
+
+    // Internals
+    fn aici_init_prompt(&mut self) {
+        let arg: InitPromptArg = serde_json::from_slice(&host::process_arg_bytes()).unwrap();
+        let res = self.init_prompt(arg);
+        let res_bytes = serde_json::to_vec(&res).unwrap();
+        host::return_process_result(&res_bytes);
+    }
+
+    fn aici_mid_process(&mut self) {
+        let arg: MidProcessArg = serde_json::from_slice(&host::process_arg_bytes())
+            .expect("aici_mid_process: failed to deserialize MidProcessArg");
+        let res = self.mid_process(arg);
+        let mut used_logits = false;
+        let res = ProcessResultOffset {
+            branches: res
+                .branches
+                .into_iter()
+                .map(|b| {
+                    b.map_mask(|vob| {
+                        if used_logits {
+                            panic!("aici_mid_process: multiple branches with sampling not yet supported");
+                        }
+                        used_logits = true;
+                        host::return_logit_bias(&vob) as usize
+                    })
+                })
+                .collect(),
+        };
+        let res_bytes = serde_json::to_vec(&res).expect("aici_mid_process: failed to serialize");
+        host::return_process_result(&res_bytes);
+    }
+}
+
+/// Expose method as extern "C", usage:
+/// expose!(Foo::set_count(n: i32) -> i32);
+/// Generates "C" function:
+/// set_count(Foo *, i32) -> i32
+#[macro_export]
+macro_rules! expose {
+    ($struct_name:ident :: $method_name:ident ( $($arg:ident : $typ:ty),* ) -> $ret:ty) => {
+        #[no_mangle]
+        pub extern "C" fn $method_name(self_: *mut $struct_name, $($arg : $typ),*) -> $ret {
+            unsafe {
+                (&mut *self_).$method_name($($arg),*)
+            }
+        }
+    };
+    ($struct_name:ident :: $field:ident :: $method_name:ident ( $($arg:ident : $typ:ty),* ) -> $ret:ty) => {
+        #[no_mangle]
+        pub extern "C" fn $method_name(self_: *mut $struct_name, $($arg : $typ),*) -> $ret {
+            unsafe {
+                (&mut *self_).$field.$method_name($($arg),*)
+            }
+        }
+    };
+}
+
+#[macro_export]
+macro_rules! aici_expose_all {
+    ($struct_name:ident, $new:expr) => {
+        $crate::expose!($struct_name::aici_mid_process() -> ());
+        $crate::expose!($struct_name::aici_init_prompt() -> ());
+
+        #[no_mangle]
+        pub extern "C" fn aici_create() -> *mut $struct_name {
+            let b = Box::new($new);
+            Box::into_raw(b)
+        }
+
+        #[no_mangle]
+        pub extern "C" fn aici_panic() {
+            panic!("aici_panic()")
+        }
+    }
+}
+
+#[macro_export]
+macro_rules! 
include_bytes_aligned { + ($align_ty:ty, $path:literal) => {{ + #[repr(C)] // guarantee 'bytes' comes after '_align' + pub struct AlignedAs { + pub _align: [Align; 0], + pub bytes: Bytes, + } + + // this assignment is made possible by CoerceUnsized + static ALIGNED: &AlignedAs<$align_ty, [u8]> = &AlignedAs { + _align: [], + bytes: *include_bytes!($path), + }; + + &ALIGNED.bytes + }}; +} diff --git a/controllers/aici_abi/src/recognizer.rs b/controllers/toktrie/src/recognizer.rs similarity index 96% rename from controllers/aici_abi/src/recognizer.rs rename to controllers/toktrie/src/recognizer.rs index 50bd0cdb..54e7bd12 100644 --- a/controllers/aici_abi/src/recognizer.rs +++ b/controllers/toktrie/src/recognizer.rs @@ -1,5 +1,6 @@ use crate::{ - host::host_trie, toktree::{Recognizer, SpecialToken, TokTrie}, AiciCtrl, MidProcessArg, MidProcessResult + toktree::{Recognizer, SpecialToken, TokTrie}, + AiciCtrl, MidProcessArg, MidProcessResult, }; use std::fmt::Debug; diff --git a/controllers/aici_abi/src/rng.rs b/controllers/toktrie/src/rng.rs similarity index 100% rename from controllers/aici_abi/src/rng.rs rename to controllers/toktrie/src/rng.rs diff --git a/controllers/aici_abi/src/svob.rs b/controllers/toktrie/src/svob.rs similarity index 100% rename from controllers/aici_abi/src/svob.rs rename to controllers/toktrie/src/svob.rs diff --git a/controllers/aici_abi/src/toktree.rs b/controllers/toktrie/src/toktree.rs similarity index 100% rename from controllers/aici_abi/src/toktree.rs rename to controllers/toktrie/src/toktree.rs From f08f5f934dd8e9cebf53cb3c4d4892ec62e19ba9 Mon Sep 17 00:00:00 2001 From: Michal Moskal Date: Thu, 4 Jul 2024 00:14:10 +0000 Subject: [PATCH 239/301] make it build again --- controllers/aici_abi/Cargo.toml | 1 + controllers/aici_abi/src/host.rs | 7 +- controllers/aici_abi/src/lib.rs | 10 +- controllers/toktrie/src/lib.rs | 159 -------------------------- controllers/toktrie/src/recognizer.rs | 39 ++----- controllers/toktrie/src/toktree.rs | 2 +- 6 files changed, 16 insertions(+), 202 deletions(-) diff --git a/controllers/aici_abi/Cargo.toml b/controllers/aici_abi/Cargo.toml index d37e78d3..9d8027a3 100644 --- a/controllers/aici_abi/Cargo.toml +++ b/controllers/aici_abi/Cargo.toml @@ -8,6 +8,7 @@ rust-version = "1.75.0" name = "aici_abi" [dependencies] +toktrie = { path = "../toktrie" } serde = { version = "1.0.192", features = ["derive"] } serde_json = "1.0.108" anyhow = "1.0.75" diff --git a/controllers/aici_abi/src/host.rs b/controllers/aici_abi/src/host.rs index 2dd775dd..6e69c570 100644 --- a/controllers/aici_abi/src/host.rs +++ b/controllers/aici_abi/src/host.rs @@ -1,9 +1,4 @@ -use crate::{ - bytes::{vec_from_bytes, TokenId}, - svob::SimpleVob, - toktree::TokTrie, - SeqId, -}; +use crate::{bytes::vec_from_bytes, svob::SimpleVob, toktree::TokTrie, SeqId, TokenId}; use serde::{Deserialize, Serialize}; #[repr(transparent)] diff --git a/controllers/aici_abi/src/lib.rs b/controllers/aici_abi/src/lib.rs index a85ba120..cf51fa20 100644 --- a/controllers/aici_abi/src/lib.rs +++ b/controllers/aici_abi/src/lib.rs @@ -1,12 +1,10 @@ +pub use toktrie::{bytes, recognizer, rng, svob, toktree}; + use serde::{Deserialize, Serialize}; use svob::SimpleVob; -pub mod bytes; + mod host; -pub mod recognizer; -pub mod rng; -pub mod svob; -pub mod toktree; #[cfg(feature = "cfg")] pub mod cfg; @@ -20,7 +18,7 @@ pub mod dlex; pub mod substring; -pub type TokenId = bytes::TokenId; +pub type TokenId = toktrie::TokenId; pub use host::{ aici_stop, arg_bytes, arg_string, get_config, 
host_trie, self_seq_id, tokenize, tokenize_bytes, diff --git a/controllers/toktrie/src/lib.rs b/controllers/toktrie/src/lib.rs index a85ba120..d30e806c 100644 --- a/controllers/toktrie/src/lib.rs +++ b/controllers/toktrie/src/lib.rs @@ -2,54 +2,13 @@ use serde::{Deserialize, Serialize}; use svob::SimpleVob; pub mod bytes; -mod host; pub mod recognizer; pub mod rng; pub mod svob; pub mod toktree; -#[cfg(feature = "cfg")] -pub mod cfg; -#[cfg(feature = "cfg")] -mod lex; - -#[cfg(feature = "rx")] -pub mod rx; - -pub mod dlex; - -pub mod substring; - pub type TokenId = bytes::TokenId; -pub use host::{ - aici_stop, arg_bytes, arg_string, get_config, host_trie, self_seq_id, tokenize, tokenize_bytes, - StorageCmd, StorageOp, StorageResp, TokenizerEnv, VariableStorage, WasmTokenizerEnv, -}; - -#[cfg(not(target_arch = "wasm32"))] -pub use host::{set_host, HostInterface}; - -#[derive(Serialize, Deserialize, Debug)] -pub struct InitPromptArg { - pub prompt: Vec, -} - -#[derive(Serialize, Deserialize, Debug)] -pub struct InitPromptResult { - pub prompt: Vec, -} - -impl InitPromptResult { - pub fn from_arg(arg: InitPromptArg) -> Self { - InitPromptResult { prompt: arg.prompt } - } -} - -#[repr(transparent)] -#[derive(Serialize, Deserialize, Debug, PartialEq, Eq)] -pub struct SeqId(pub u32); - #[derive(Serialize, Deserialize, Debug)] pub struct MidProcessArg { /// Sampling result for the previous iteration. @@ -58,16 +17,9 @@ pub struct MidProcessArg { /// Can be more complex when splices are used. pub backtrack: u32, pub tokens: Vec, - /// - pub fork_group: Vec, } impl MidProcessArg { - pub fn has_eos(&self) -> bool { - let eos = host::eos_token(); - self.tokens.iter().any(|t| *t == eos) - } - pub fn save_tokens(&self, acc_tokens: &mut Vec) { let bt = self.backtrack as usize; assert!( @@ -206,114 +158,3 @@ impl MidProcessResult { self.branches.is_empty() } } - -#[derive(Serialize, Deserialize)] -pub struct ProcessResultOffset { - /// Branches use byte offsets into the bias tensor. - pub branches: Vec>, -} - -pub trait AiciCtrl { - /// Called with the initial prompt. ~1000ms time limit. - /// By default ignore prompt. - fn init_prompt(&mut self, arg: InitPromptArg) -> InitPromptResult { - InitPromptResult::from_arg(arg) - } - - /// This is the main entry point for the module. ~20ms time limit. - fn mid_process(&mut self, arg: MidProcessArg) -> MidProcessResult; - - // Internals - fn aici_init_prompt(&mut self) { - let arg: InitPromptArg = serde_json::from_slice(&host::process_arg_bytes()).unwrap(); - let res = self.init_prompt(arg); - let res_bytes = serde_json::to_vec(&res).unwrap(); - host::return_process_result(&res_bytes); - } - - fn aici_mid_process(&mut self) { - let arg: MidProcessArg = serde_json::from_slice(&host::process_arg_bytes()) - .expect("aici_mid_process: failed to deserialize MidProcessArg"); - let res = self.mid_process(arg); - let mut used_logits = false; - let res = ProcessResultOffset { - branches: res - .branches - .into_iter() - .map(|b| { - b.map_mask(|vob| { - if used_logits { - panic!("aici_mid_process: multiple branches with sampling not yet supported"); - } - used_logits = true; - host::return_logit_bias(&vob) as usize - }) - }) - .collect(), - }; - let res_bytes = serde_json::to_vec(&res).expect("aici_mid_process: failed to serialize"); - host::return_process_result(&res_bytes); - } -} - -/// Expose method as extern "C", usage: -/// expose!(Foo::set_count(n: i32) -> i32); -/// Generates "C" function: -/// set_count(Foo *, i32) -> i32 -#[macro_export] -macro_rules! 
expose { - ($struct_name:ident :: $method_name:ident ( $($arg:ident : $typ:ty),* ) -> $ret:ty) => { - #[no_mangle] - pub extern "C" fn $method_name(self_: *mut $struct_name, $($arg : $typ),*) -> $ret { - unsafe { - (&mut *self_).$method_name($($arg),*) - } - } - }; - ($struct_name:ident :: $field:ident :: $method_name:ident ( $($arg:ident : $typ:ty),* ) -> $ret:ty) => { - #[no_mangle] - pub extern "C" fn $method_name(self_: *mut $struct_name, $($arg : $typ),*) -> $ret { - unsafe { - (&mut *self_).$field.$method_name($($arg),*) - } - } - }; -} - -#[macro_export] -macro_rules! aici_expose_all { - ($struct_name:ident, $new:expr) => { - $crate::expose!($struct_name::aici_mid_process() -> ()); - $crate::expose!($struct_name::aici_init_prompt() -> ()); - - #[no_mangle] - pub extern "C" fn aici_create() -> *mut $struct_name { - let b = Box::new($new); - Box::into_raw(b) - } - - #[no_mangle] - pub extern "C" fn aici_panic() { - panic!("aici_panic()") - } - } -} - -#[macro_export] -macro_rules! include_bytes_aligned { - ($align_ty:ty, $path:literal) => {{ - #[repr(C)] // guarantee 'bytes' comes after '_align' - pub struct AlignedAs { - pub _align: [Align; 0], - pub bytes: Bytes, - } - - // this assignment is made possible by CoerceUnsized - static ALIGNED: &AlignedAs<$align_ty, [u8]> = &AlignedAs { - _align: [], - bytes: *include_bytes!($path), - }; - - &ALIGNED.bytes - }}; -} diff --git a/controllers/toktrie/src/recognizer.rs b/controllers/toktrie/src/recognizer.rs index 54e7bd12..115728b9 100644 --- a/controllers/toktrie/src/recognizer.rs +++ b/controllers/toktrie/src/recognizer.rs @@ -1,35 +1,6 @@ -use crate::{ - toktree::{Recognizer, SpecialToken, TokTrie}, - AiciCtrl, MidProcessArg, MidProcessResult, -}; +use crate::toktree::{Recognizer, SpecialToken}; use std::fmt::Debug; -pub struct AiciRecognizer { - pub trie: TokTrie, - pub rec: R, -} - -impl AiciRecognizer { - pub fn from_recognizer(rec: R) -> Self { - AiciRecognizer { - trie: host_trie(), - rec, - } - } -} - -impl AiciCtrl for AiciRecognizer { - fn mid_process(&mut self, arg: MidProcessArg) -> MidProcessResult { - if arg.has_eos() { - return MidProcessResult::stop(); - } - self.trie.append_tokens(&mut self.rec, &arg.tokens).unwrap(); - let mut set = self.trie.alloc_token_set(); - self.trie.compute_bias(&mut self.rec, &mut set); - MidProcessResult::sample(set) - } -} - pub trait FunctionalRecognizer { /// Initial state fn initial(&self) -> S; @@ -37,6 +8,10 @@ pub trait FunctionalRecognizer { fn try_append(&self, state: S, byte: u8) -> Option; /// Check if given special token is allowed in given state. fn special_allowed(&self, state: S, tok: SpecialToken) -> bool; + /// Get error message if recognizer is in error state. + fn get_error(&self, _state: S) -> Option { + None + } } #[derive(Clone)] @@ -98,6 +73,10 @@ impl> Recognizer for StackRecognizer self.rec.special_allowed(self.stack[self.stack_ptr], tok) } + fn get_error(&mut self) -> Option { + self.rec.get_error(self.stack[self.stack_ptr]) + } + #[inline(always)] fn try_push_byte(&mut self, byte: u8) -> bool { match self.rec.try_append(self.stack[self.stack_ptr], byte) { diff --git a/controllers/toktrie/src/toktree.rs b/controllers/toktrie/src/toktree.rs index 99271f01..ef2a3b4b 100644 --- a/controllers/toktrie/src/toktree.rs +++ b/controllers/toktrie/src/toktree.rs @@ -44,7 +44,7 @@ pub trait Recognizer { /// This combines `push_byte` and `byte_allowed` into one function for performance. 
fn try_push_byte(&mut self, byte: u8) -> bool; /// Check if there are any errors to be reported to the user. - fn get_error(&self) -> Option { + fn get_error(&mut self) -> Option { None } } From 22b9297a536ccdcd279171a1a724f4fa1b4e278e Mon Sep 17 00:00:00 2001 From: Michal Moskal Date: Thu, 4 Jul 2024 00:34:54 +0000 Subject: [PATCH 240/301] share more types --- controllers/aici_abi/src/lib.rs | 108 +++--------------- controllers/toktrie/README.md | 154 +------------------------- controllers/toktrie/implementation.md | 59 ---------- controllers/toktrie/src/lib.rs | 62 ++++------- 4 files changed, 34 insertions(+), 349 deletions(-) diff --git a/controllers/aici_abi/src/lib.rs b/controllers/aici_abi/src/lib.rs index cf51fa20..41448b41 100644 --- a/controllers/aici_abi/src/lib.rs +++ b/controllers/aici_abi/src/lib.rs @@ -3,7 +3,6 @@ pub use toktrie::{bytes, recognizer, rng, svob, toktree}; use serde::{Deserialize, Serialize}; use svob::SimpleVob; - mod host; #[cfg(feature = "cfg")] @@ -77,90 +76,7 @@ impl MidProcessArg { } } -/* -For example, if we're generating JSON, according to the following schema: -{ - "type": "object", - "properties": { - "name": {"type": "string"}, - "age": {"type": "integer"} - } -} - -Let's say we have generated: {"name": "something -We would use a single splice: - when_sampled: ['"', '",', '", '], - backtrack: 1, - ff_tokens: tokenize('", "age": ') -Which means: when any token starting with '"' is sampled, we remove it (backtrack: 1) -and then append the next full fragment of JSON '", "age": ' - -If the tokenizers has tokens like 'a"', 'b"' etc, then we would need many splices -(there may be limits how many we want to pass over the IPC boundary). -*/ - -/// Describes what to do after sampling. -#[derive(Serialize, Deserialize, Clone, Debug)] -pub struct Splice { - /// If one of the tokens in when_sampled is sampled, this sequence is appended. - /// When empty, this sequence is appended unconditionally, regardless of sampling. - pub when_sampled: Vec, - /// Backtrack this much before appending this sequence (this includes sampled token if any). - pub backtrack: u32, - /// Append these tokens after backtracking. - pub ff_tokens: Vec, -} - -#[derive(Serialize, Deserialize, Debug)] -pub struct Branch { - /// If None, no sampling is performed. - /// If Some(set), only tokens from the set are allowed. - pub sample_mask: Option, - /// Override temperature for sampling. It may or may not be sticky. - pub temperature: Option, - /// Describes what to do after sampling. - /// If no sampling, there should be exactly one splice, with empty `when_sampled`. 
- pub splices: Vec, -} - -impl Clone for Branch { - fn clone(&self) -> Self { - Branch { - sample_mask: self.sample_mask.clone(), - temperature: self.temperature, - splices: self.splices.clone(), - } - } -} - -impl Branch { - pub fn map_mask(&self, f: F) -> Branch - where - F: FnOnce(&S) -> T, - { - Branch { - sample_mask: self.sample_mask.as_ref().map(f), - temperature: self.temperature, - splices: self.splices.clone(), - } - } - - pub fn splice(backtrack: u32, ff_tokens: Vec) -> Self { - Branch { - sample_mask: None, - temperature: None, - splices: vec![Splice { - when_sampled: vec![], - backtrack, - ff_tokens, - }], - } - } - - pub fn noop() -> Self { - Self::splice(0, vec![]) - } -} +pub use toktrie::{Branch, Splice}; #[derive(Debug)] pub struct MidProcessResult { @@ -172,6 +88,16 @@ pub struct MidProcessResult { } impl MidProcessResult { + pub fn from_branch(branch: Branch) -> Self { + if branch.is_stop() { + Self::stop() + } else { + MidProcessResult { + branches: vec![branch], + } + } + } + pub fn stop() -> Self { MidProcessResult { branches: vec![] } } @@ -181,19 +107,11 @@ impl MidProcessResult { } pub fn sample_with_temp(set: SimpleVob, temperature: Option) -> Self { - MidProcessResult { - branches: vec![Branch { - sample_mask: Some(set), - temperature: temperature, - splices: vec![], - }], - } + Self::from_branch(Branch::sample(set, temperature)) } pub fn splice(backtrack: u32, ff_tokens: Vec) -> Self { - MidProcessResult { - branches: vec![Branch::splice(backtrack, ff_tokens)], - } + Self::from_branch(Branch::splice(backtrack, ff_tokens)) } pub fn noop() -> Self { diff --git a/controllers/toktrie/README.md b/controllers/toktrie/README.md index a15dd336..035599b5 100644 --- a/controllers/toktrie/README.md +++ b/controllers/toktrie/README.md @@ -1,86 +1,6 @@ -# aici_abi +# toktrie - Token utility library -This crate specifies the application binary interface (ABI) for the AICI Controllers. -It also provides higher-level interfaces for implementing controllers. - -## Low-level interface - -Conceptually, the lowest level interface to AICI constraint is this: - -```rust -type TokenId = u32; -type SeqId = u32; - -trait AiciCtrl { - /// Called with the initial prompt. ~1000ms time limit. - fn init_prompt(prompt: Vec); - - /// Called before mid_process(), can fork or suspend. ~1ms. - fn pre_process() -> enum { - Stop, - Continue, // Same as Fork { num_forks: 1 } - Suspend, // skip this generation round - Fork { num_forks: u32 }, - } - - /// This is the main entry point for the module. ~20ms. - fn mid_process(fork_group: Vec) -> enum { - Stop, - SampleWithBias { bias: Vec }, - Splice { backtrack: u32, ff_tokens: Vec } - }; - - /// Called after tokens are appended. ~1ms. - fn post_process(tokens: Vec) -> enum { Stop, Continue }; -} -``` - -Tokens depend on the tokenizer used (eg., for Llama there 32000 tokens, and for GPT-4 there is ~100k). - -The actual binary interface is a bit more complicated, due -to limitations in passing values to and from Wasm. -A Wasm module instance is created for each token sequence. -Also, when the sequence forks (as in beam search), the module instance is cloned. -See the [AiciCtrl Rust trait](src/lib.rs) for details. - -A number of functions are exposed to the Wasm module. - -First, there are functions for accessing the current tokenizer: - -```rust -/// Given a byte sequence, return a sequence of token Ids. -fn tokenize_bytes(s: Vec) -> Vec; - -/// Represents trie of all tokens in the current tokenizer. -impl TokTrie { - /// Get Id for EOS token etc. 
- fn special_token(tok: SpecialToken) -> TokenId; - /// Number of tokens. - fn vocab_size() -> usize; - /// Convert token Id to bytes (often UTF-8 string). - fn token(token: TokenId) -> Vec; - /// Given a Recognizer, compute the set of allowed tokens. - fn compute_bias(rec: impl Recognizer) -> Vec; -} -``` - -Different forks in a sequence can communicate via shared variables: - -```rust -/// This can be looked up in fork_group. -fn self_seq_id() -> SeqId; - -trait VariableStorage { - fn get(name: str) -> Option>; - fn set(name: str, value: Vec); - fn append(name: str, value: Vec); -} -``` - -Additionally, the `stdout` and `stderr` file descriptors are captured by the runtime -and returned to user when streaming results. - -This interface may need to be extended in the future. +This crate provides a utility library for working with tokens and token tries. ## Byte stack interface @@ -132,73 +52,3 @@ pub trait FunctionalRecognizer { These three layers add up to about 40k of compiled code (Wasm). -## Regular expressions - -The `FunctionalRecognizer` interface is implemented for regular expressions. -The `S` type is the state of the DFA (Deterministic Finite Automaton) that recognizes the regular expression, -then `append()` and `byte_allowed()` are the standard DFA operations, -while `special_allowed()` is only implemented for end-of-sequence token -(which is allowed when the current state is accepting). - -## LR(1) grammars - -The `Recognizer` interface is implemented for LR(1) grammars and DFA-based lexers. - -The grammar uses inline syntax for the lexer: - -- `"keyword"` or `'keyword'` for keywords; any string works, eg. `"+="`, `"while"`, ... -- `"/.../"` or `'/.../'` for regular expressions; you cannot have both `'` and `"` in the regex - Special `SKIP` rule is used to indicate tokens that need to be skipped by the LR(1) parser (eg., whitespace and comments) - -The lexer has a DFA which recognizes all regexps and keywords -(a big disjunction, but with additional machinery to disambiguate between different branches). -It goes byte by byte, until the DFA gets to a dead state (from which no match is possible). -Then it goes back one byte and checks for match. -It prefers keywords over regexps. -If no match is found, an error is reported, which requires careful design of the lexical part of the grammar -(eg., see how the `white-space` rule below is prefix of the `pre-processor` rule). - -For example, this is fragment of [grammar for C](./grammars/c.y): - -```yacc -%start translation_unit -%% - -SKIP - : "//\*[^*]*\*+([^/*][^*]*\*+)*//" // block comment - | "///.*/" // line comment - | "/\n[ \t\v\f]*#(.*\\\n)*.*/" // pre-processor - | "/\n?[ \t\v\f]*/" // white-space - ; - -IDENTIFIER: "/[a-zA-Z_][0-9a-zA-Z_]*/" ; - -CONSTANT - : "/0[xX][0-9a-fA-F]+[uUlL]*?/" - | "/0[0-9]+[uUlL]*?/" - ; - -STRING_LITERAL: '/"(\\.|[^\\"])*"/' ; - -primary_expression - : IDENTIFIER - | CONSTANT - | STRING_LITERAL - | "(" expression ")" - ; - -// ... - -enum_specifier - : "enum" "{" enumerator_list "}" - | "enum" IDENTIFIER "{" enumerator_list "}" - | "enum" IDENTIFIER - ; - -// ... - -translation_unit - : external_declaration - | translation_unit external_declaration - ; -``` diff --git a/controllers/toktrie/implementation.md b/controllers/toktrie/implementation.md index bd766709..29d5b28c 100644 --- a/controllers/toktrie/implementation.md +++ b/controllers/toktrie/implementation.md @@ -92,62 +92,3 @@ while p < nodes.len() { Note that the only branch that gets mis-predicted here is the `if byte_allowed(n.byte)`. 
The `if` in argument to `pop_bytes` is compiled to bit operations, so it is branchless. - -## LR(1) parsing - -The LR(1) parsing consists of DFA-based lexer and the actual LR(1) parser. -DFA has a single number as the state, while the state of the LR(1) is a stack of numbers. -The LR(1) action is determined based on the next token from the lexer and the top of the stack. - -The `Recognizer` interface also has a concept of stack, however every entry on that -stack contains a DFA state and an LR(1) stack. - -Most of the time (~98.5% for the C grammar), pushing a byte involves only updating the DFA state, -while the LR(1) stack is copied unchanged (the memory is shared). - - -### Early error detection - -Consider the following invalid C program: - -```c -int 123456; -``` - -The lexer would produce `int` keyword, whitespace, `123456` constant and `;` keyword. -The parser would reject `123456`, however only after all six characters of it have been read. -This is too late for the LLM. - -To detect such errors early, we compute a set of reachable tokens for each DFA state. -For example, consider a DFA that recognizes `int`, `if`, `ID` (`/[a-z][a-z0-9]*/`) and `INTLIT` (`/[0-9]+/`). -The initial DFA state has a full set of tokens, while a state after `'i'` -has only `int`, `if`, and `ID`, -and a state after `'1'` includes only `INTLIT`. -In the picture below, each state is labelled by its reachable set, -and the token for which it is a match (if any) is postfixed with `*`. We only use lower-case letters and digits for simplicity. - -```mermaid -graph LR - 0["{int,if,ID,INTLIT}"] -- "[i]" --> i(("{int,if,ID*}")) - 0 -- "[a-z] - [i]" --> id(("{ID*}")) - 0 -- "[0-9]" --> const(("{INTLIT*}")) - const -- "[0-9]" --> const - const -- "[a-z]" --> bot["{}"] - i -- "[a-z0-9] - [nf]" --> id - id -- "[a-z0-9]" --> id - i -- "[n]" --> in(("{int,ID*}")) - in -- "[t]" --> int(("{int*,ID}")) - in -- "[a-z0-9] - [t]" --> id - int -- "[a-z0-9]" --> id - i -- "[f]" --> if(("{if*,ID}")) - if -- "[a-z0-9]" --> id -``` - -For each LR(1) automaton state we compute a set of viable tokens, i.e., ones that do -not immediately lead to an error. - -While parsing input, if the intersection of viable and reachable tokens is empty, we report an error. - -In the example above, the viable tokens after `int` do not include `INTLIT`, -and thus the parser fails immediately at `1`. - diff --git a/controllers/toktrie/src/lib.rs b/controllers/toktrie/src/lib.rs index d30e806c..fcf426b3 100644 --- a/controllers/toktrie/src/lib.rs +++ b/controllers/toktrie/src/lib.rs @@ -1,5 +1,4 @@ use serde::{Deserialize, Serialize}; -use svob::SimpleVob; pub mod bytes; pub mod recognizer; @@ -10,7 +9,7 @@ pub mod toktree; pub type TokenId = bytes::TokenId; #[derive(Serialize, Deserialize, Debug)] -pub struct MidProcessArg { +pub struct StepArg { /// Sampling result for the previous iteration. /// For simple sampled token 't', backtrack==0 and tokens==[t]. /// For first request, backtrack==0 and tokens==[] (prompt is passed separately, before). 
@@ -19,7 +18,7 @@ pub struct MidProcessArg {
     /// Can be more complex when splices are used.
     pub backtrack: u32,
     pub tokens: Vec<TokenId>,
-    ///
-    pub fork_group: Vec<SeqId>,
 }
 
-impl MidProcessArg {
+impl StepArg {
-    pub fn has_eos(&self) -> bool {
-        let eos = host::eos_token();
-        self.tokens.iter().any(|t| *t == eos)
-    }
-
     pub fn save_tokens(&self, acc_tokens: &mut Vec<TokenId>) {
         let bt = self.backtrack as usize;
         assert!(
@@ -99,6 +98,18 @@ impl<S> Branch<S> {
         }
     }
 
+    pub fn stop() -> Self {
+        Branch {
+            sample_mask: None,
+            temperature: None,
+            splices: vec![],
+        }
+    }
+
+    pub fn is_stop(&self) -> bool {
+        self.sample_mask.is_none() && self.splices.is_empty()
+    }
+
     pub fn splice(backtrack: u32, ff_tokens: Vec<TokenId>) -> Self {
         Branch {
             sample_mask: None,
@@ -114,47 +125,12 @@ impl<S> Branch<S> {
     pub fn noop() -> Self {
         Self::splice(0, vec![])
     }
-}
-
-#[derive(Debug)]
-pub struct MidProcessResult {
-    /// Fork the request into multiple branches.
-    /// Typically, exactly one branch is returned.
-    /// If multiple branches are returned, they are executed in parallel.
-    /// If no branches are returned, the request is terminated.
-    pub branches: Vec<Branch<SimpleVob>>,
-}
-
-impl MidProcessResult {
-    pub fn stop() -> Self {
-        MidProcessResult { branches: vec![] }
-    }
-
-    pub fn sample(set: SimpleVob) -> Self {
-        Self::sample_with_temp(set, None)
-    }
-
-    pub fn sample_with_temp(set: SimpleVob, temperature: Option<f32>) -> Self {
-        MidProcessResult {
-            branches: vec![Branch {
-                sample_mask: Some(set),
-                temperature: temperature,
-                splices: vec![],
-            }],
-        }
-    }
-
-    pub fn splice(backtrack: u32, ff_tokens: Vec<TokenId>) -> Self {
-        MidProcessResult {
-            branches: vec![Branch::splice(backtrack, ff_tokens)],
+    pub fn sample(set: S, temperature: Option<f32>) -> Self {
+        Branch {
+            sample_mask: Some(set),
+            temperature,
+            splices: vec![],
         }
     }
-
-    pub fn noop() -> Self {
-        Self::splice(0, vec![])
-    }
-
-    pub fn is_stop(&self) -> bool {
-        self.branches.is_empty()
-    }
-}
From ab923e92dd97f296b206eba4c9a10722920093cf Mon Sep 17 00:00:00 2001
From: Michal Moskal
Date: Fri, 5 Jul 2024 17:21:17 +0000
Subject: [PATCH 241/301] move stuff around

---
 controllers/toktrie/src/bytes.rs   |  9 ---------
 controllers/toktrie/src/lib.rs     |  5 +++--
 controllers/toktrie/src/toktree.rs | 18 ++++++++++--------
 3 files changed, 13 insertions(+), 19 deletions(-)

diff --git a/controllers/toktrie/src/bytes.rs b/controllers/toktrie/src/bytes.rs
index 7343a4e3..6aa39fdb 100644
--- a/controllers/toktrie/src/bytes.rs
+++ b/controllers/toktrie/src/bytes.rs
@@ -4,15 +4,6 @@ use anyhow::{anyhow, Result};
 use bytemuck::{NoUninit, Pod};
 use bytemuck_derive::{Pod, Zeroable};
 
-pub(crate) type TokenId = u32;
-
-#[derive(Clone, Copy, PartialEq, Eq, Debug, Zeroable, Pod)]
-#[repr(C)]
-pub struct TokRxInfo {
-    pub vocab_size: u32,
-    pub tok_eos: TokenId,
-}
-
 #[derive(Clone, Copy, PartialEq, Eq, Debug, Zeroable, Pod)]
 #[repr(C)]
 pub struct U32Pair(pub u32, pub u32);
diff --git a/controllers/toktrie/src/lib.rs b/controllers/toktrie/src/lib.rs
index fcf426b3..a44dc7b7 100644
--- a/controllers/toktrie/src/lib.rs
+++ b/controllers/toktrie/src/lib.rs
@@ -4,9 +4,10 @@ pub mod bytes;
 pub mod recognizer;
 pub mod rng;
 pub mod svob;
-pub mod toktree;
+mod toktree;
 
-pub type TokenId = bytes::TokenId;
+pub use svob::{SimpleVob, SimpleVobIter};
+pub use toktree::{Recognizer, SpecialToken, TokRxInfo, TokTrie, TokenId, TokenizerEnv};
 
 #[derive(Serialize, Deserialize, Debug)]
 pub struct StepArg {
diff --git a/controllers/toktrie/src/toktree.rs b/controllers/toktrie/src/toktree.rs
index ef2a3b4b..4dc51349 100644
--- a/controllers/toktrie/src/toktree.rs
+++ b/controllers/toktrie/src/toktree.rs
@@ -6,10 +6,19 @@ use bytemuck_derive::{Pod, Zeroable};
 use rustc_hash::FxHashMap;
 
 use crate::{
-    bytes::{to_hex_string, vec_from_bytes, TokRxInfo, TokenId},
+    
bytes::{to_hex_string, vec_from_bytes}, svob::SimpleVob, }; +pub type TokenId = u32; + +#[derive(Clone, Copy, PartialEq, Eq, Debug, Zeroable, Pod)] +#[repr(C)] +pub struct TokRxInfo { + pub vocab_size: u32, + pub tok_eos: TokenId, +} + #[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)] pub enum SpecialToken { Unknown, @@ -737,13 +746,6 @@ impl<'a> Iterator for NodeChildren<'a> { } } -#[repr(C)] -pub struct TokenizerBin { - magic: u32, - tokens_bytes: u32, - tree_bytes: u32, -} - struct TrieHash { token_id: u32, byte: u8, From fd9423e921a58bb83713ee98de7af7c9a12b710d Mon Sep 17 00:00:00 2001 From: Michal Moskal Date: Fri, 5 Jul 2024 17:22:27 +0000 Subject: [PATCH 242/301] use toktrie name --- controllers/aici_abi/README.md | 2 +- controllers/aici_abi/src/cfg.rs | 2 +- controllers/aici_abi/src/dlex.rs | 2 +- controllers/aici_abi/src/host.rs | 2 +- controllers/aici_abi/src/lib.rs | 4 +++- controllers/aici_abi/src/rx.rs | 2 +- controllers/aici_abi/src/substring.rs | 2 +- controllers/aici_abi/src/yesno.rs | 2 +- controllers/aici_native/src/bintokens.rs | 2 +- 9 files changed, 11 insertions(+), 9 deletions(-) diff --git a/controllers/aici_abi/README.md b/controllers/aici_abi/README.md index a15dd336..0f81ce12 100644 --- a/controllers/aici_abi/README.md +++ b/controllers/aici_abi/README.md @@ -89,7 +89,7 @@ To compute the set of tokens that match a string constraint, one needs go throug and apply the constraint. An efficient way to do this is walk a prefix tree (trie) of all tokens. The `aici_abi` library implements this trie and exposes a way of filtering when provided with a constraints -implementing the [following interface](src/toktree.rs): +implementing the [following interface](src/toktrie.rs): ```rust pub trait Recognizer { diff --git a/controllers/aici_abi/src/cfg.rs b/controllers/aici_abi/src/cfg.rs index 22f2a3c6..d28c14b3 100644 --- a/controllers/aici_abi/src/cfg.rs +++ b/controllers/aici_abi/src/cfg.rs @@ -2,7 +2,7 @@ use crate::host::host_trie; use crate::lex::{Lexer, LexerState, StateID, VobIdx, VobSet}; use crate::{ svob::SimpleVob, - toktree::{Recognizer, SpecialToken}, + toktrie::{Recognizer, SpecialToken}, }; use anyhow::Result; use cfgrammar::{ diff --git a/controllers/aici_abi/src/dlex.rs b/controllers/aici_abi/src/dlex.rs index 02f04313..ce147f66 100644 --- a/controllers/aici_abi/src/dlex.rs +++ b/controllers/aici_abi/src/dlex.rs @@ -1,7 +1,7 @@ use crate::{ recognizer::{FunctionalRecognizer, StackRecognizer}, svob::SimpleVob, - toktree::SpecialToken, + toktrie::SpecialToken, }; #[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)] diff --git a/controllers/aici_abi/src/host.rs b/controllers/aici_abi/src/host.rs index 6e69c570..bea6854c 100644 --- a/controllers/aici_abi/src/host.rs +++ b/controllers/aici_abi/src/host.rs @@ -1,4 +1,4 @@ -use crate::{bytes::vec_from_bytes, svob::SimpleVob, toktree::TokTrie, SeqId, TokenId}; +use crate::{bytes::vec_from_bytes, svob::SimpleVob, toktrie::TokTrie, SeqId, TokenId}; use serde::{Deserialize, Serialize}; #[repr(transparent)] diff --git a/controllers/aici_abi/src/lib.rs b/controllers/aici_abi/src/lib.rs index 41448b41..26eca8ce 100644 --- a/controllers/aici_abi/src/lib.rs +++ b/controllers/aici_abi/src/lib.rs @@ -1,4 +1,6 @@ -pub use toktrie::{bytes, recognizer, rng, svob, toktree}; +pub use toktrie::{bytes, recognizer, rng, svob}; + +pub use toktrie; use serde::{Deserialize, Serialize}; use svob::SimpleVob; diff --git a/controllers/aici_abi/src/rx.rs b/controllers/aici_abi/src/rx.rs index 627a14dc..883fd05b 100644 --- 
a/controllers/aici_abi/src/rx.rs +++ b/controllers/aici_abi/src/rx.rs @@ -2,7 +2,7 @@ use std::error::Error; use crate::{ recognizer::{FunctionalRecognizer, StackRecognizer}, - toktree::SpecialToken, + toktrie::SpecialToken, }; use anyhow::{bail, Result}; use regex_automata::{ diff --git a/controllers/aici_abi/src/substring.rs b/controllers/aici_abi/src/substring.rs index 55f80eef..b8be55d5 100644 --- a/controllers/aici_abi/src/substring.rs +++ b/controllers/aici_abi/src/substring.rs @@ -3,7 +3,7 @@ use std::fmt::Display; use crate::{ bytes::limit_bytes, recognizer::{FunctionalRecognizer, StackRecognizer}, - toktree::SpecialToken, + toktrie::SpecialToken, }; use serde_json::json; diff --git a/controllers/aici_abi/src/yesno.rs b/controllers/aici_abi/src/yesno.rs index 1e021e0d..78b574c3 100644 --- a/controllers/aici_abi/src/yesno.rs +++ b/controllers/aici_abi/src/yesno.rs @@ -1,4 +1,4 @@ -use aici_abi::{host_trie, tokenize, toktree::TokTrie, AiciCtrl, MidProcessArg, MidProcessResult, TokenId}; +use aici_abi::{host_trie, tokenize, toktrie::TokTrie, AiciCtrl, MidProcessArg, MidProcessResult, TokenId}; pub struct Runner { toktrie: TokTrie, diff --git a/controllers/aici_native/src/bintokens.rs b/controllers/aici_native/src/bintokens.rs index 60ceda95..b4f2e203 100644 --- a/controllers/aici_native/src/bintokens.rs +++ b/controllers/aici_native/src/bintokens.rs @@ -1,4 +1,4 @@ -use aici_abi::{bytes::TokRxInfo, toktree::TokTrie, TokenId, TokenizerEnv}; +use aici_abi::{toktrie::TokRxInfo, toktrie::TokTrie, TokenId, TokenizerEnv}; use anyhow::{anyhow, bail, Result}; use rustc_hash::FxHashMap; use serde::{Deserialize, Serialize}; From ef749003d4ff4d704ae5462467862a076be4ec95 Mon Sep 17 00:00:00 2001 From: Michal Moskal Date: Fri, 5 Jul 2024 17:26:33 +0000 Subject: [PATCH 243/301] don't use 'svob' as name --- controllers/aici_abi/src/cfg.rs | 2 +- controllers/aici_abi/src/dlex.rs | 2 +- controllers/aici_abi/src/host.rs | 2 +- controllers/aici_abi/src/lib.rs | 5 ++--- controllers/toktrie/src/lib.rs | 2 +- controllers/toktrie/src/toktree.rs | 2 +- 6 files changed, 7 insertions(+), 8 deletions(-) diff --git a/controllers/aici_abi/src/cfg.rs b/controllers/aici_abi/src/cfg.rs index d28c14b3..c0fb412e 100644 --- a/controllers/aici_abi/src/cfg.rs +++ b/controllers/aici_abi/src/cfg.rs @@ -1,8 +1,8 @@ use crate::host::host_trie; use crate::lex::{Lexer, LexerState, StateID, VobIdx, VobSet}; use crate::{ - svob::SimpleVob, toktrie::{Recognizer, SpecialToken}, + SimpleVob, }; use anyhow::Result; use cfgrammar::{ diff --git a/controllers/aici_abi/src/dlex.rs b/controllers/aici_abi/src/dlex.rs index ce147f66..df275fb7 100644 --- a/controllers/aici_abi/src/dlex.rs +++ b/controllers/aici_abi/src/dlex.rs @@ -1,7 +1,7 @@ use crate::{ recognizer::{FunctionalRecognizer, StackRecognizer}, - svob::SimpleVob, toktrie::SpecialToken, + SimpleVob, }; #[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)] diff --git a/controllers/aici_abi/src/host.rs b/controllers/aici_abi/src/host.rs index bea6854c..19281468 100644 --- a/controllers/aici_abi/src/host.rs +++ b/controllers/aici_abi/src/host.rs @@ -1,4 +1,4 @@ -use crate::{bytes::vec_from_bytes, svob::SimpleVob, toktrie::TokTrie, SeqId, TokenId}; +use crate::{bytes::vec_from_bytes, toktrie::TokTrie, SeqId, SimpleVob, TokenId}; use serde::{Deserialize, Serialize}; #[repr(transparent)] diff --git a/controllers/aici_abi/src/lib.rs b/controllers/aici_abi/src/lib.rs index 26eca8ce..85ab7cdf 100644 --- a/controllers/aici_abi/src/lib.rs +++ b/controllers/aici_abi/src/lib.rs 
@@ -1,9 +1,8 @@ -pub use toktrie::{bytes, recognizer, rng, svob}; - pub use toktrie; +pub use toktrie::SimpleVob; +pub use toktrie::{bytes, recognizer, rng}; use serde::{Deserialize, Serialize}; -use svob::SimpleVob; mod host; diff --git a/controllers/toktrie/src/lib.rs b/controllers/toktrie/src/lib.rs index a44dc7b7..e8821c0c 100644 --- a/controllers/toktrie/src/lib.rs +++ b/controllers/toktrie/src/lib.rs @@ -3,7 +3,7 @@ use serde::{Deserialize, Serialize}; pub mod bytes; pub mod recognizer; pub mod rng; -pub mod svob; +mod svob; mod toktree; pub use svob::{SimpleVob, SimpleVobIter}; diff --git a/controllers/toktrie/src/toktree.rs b/controllers/toktrie/src/toktree.rs index 4dc51349..44d25007 100644 --- a/controllers/toktrie/src/toktree.rs +++ b/controllers/toktrie/src/toktree.rs @@ -7,7 +7,7 @@ use rustc_hash::FxHashMap; use crate::{ bytes::{to_hex_string, vec_from_bytes}, - svob::SimpleVob, + SimpleVob, }; pub type TokenId = u32; From 413b6edaaa3f35032e657395256fb6722193f5f0 Mon Sep 17 00:00:00 2001 From: Michal Moskal Date: Fri, 5 Jul 2024 17:31:53 +0000 Subject: [PATCH 244/301] remove text moved to toktrie crate --- controllers/aici_abi/README.md | 51 +------------- controllers/aici_abi/implementation.md | 93 -------------------------- 2 files changed, 3 insertions(+), 141 deletions(-) diff --git a/controllers/aici_abi/README.md b/controllers/aici_abi/README.md index 0f81ce12..b1df3c30 100644 --- a/controllers/aici_abi/README.md +++ b/controllers/aici_abi/README.md @@ -82,55 +82,10 @@ and returned to user when streaming results. This interface may need to be extended in the future. -## Byte stack interface +See the `toktrie` crate for general utilities for building constraints. +This crate implements a few constraints including regexes, LR(1) grammars, and +substrings. -The constraints are typically expressed on strings or bytes, not tokens. -To compute the set of tokens that match a string constraint, one needs go through all the possible tokens -and apply the constraint. -An efficient way to do this is walk a prefix tree (trie) of all tokens. -The `aici_abi` library implements this trie and exposes a way of filtering when provided with a constraints -implementing the [following interface](src/toktrie.rs): - -```rust -pub trait Recognizer { - /// If `stack.top()` transitions via `byte` to `X`, execute `stack.push(X)`. - fn push_byte(&mut self, byte: u8); - /// for _ in 0..num { stack.pop() } - fn pop_bytes(&mut self, num: usize); - /// X = stack.top(); stack.empty(); stack.push(X) - fn collapse(&mut self); - /// check if stack.top() transitions via byte to a viable state - fn byte_allowed(&mut self, byte: u8) -> bool; - /// check if stack.top() transitions via tok to a viable state - fn special_allowed(&mut self, tok: SpecialToken) -> bool; - /// Called when iteration over the trie is finished - /// Stack has exactly one element then. - fn trie_finished(&mut self); - /// This combines `push_byte` and `byte_allowed` into one function for performance. - fn try_push_byte(&mut self, byte: u8) -> bool; -} -``` - -The `AiciRecognizer` struct converts `Recognizer` to `AiciCtrl`. - -## Functional byte interface - -The following interface can be transformed into `Recognizer` using `StackRecognizer` struct. - -```rust -pub trait FunctionalRecognizer { - /// Initial state - fn initial(&self) -> S; - /// Extend the recognizer with given byte. - fn append(&self, state: S, byte: u8) -> S; - /// Check if given byte is allowed in given state. 
- fn byte_allowed(&self, state: S, byte: u8) -> bool; - /// Check if given special token is allowed in given state. - fn special_allowed(&self, state: S, tok: SpecialToken) -> bool; -} -``` - -These three layers add up to about 40k of compiled code (Wasm). ## Regular expressions diff --git a/controllers/aici_abi/implementation.md b/controllers/aici_abi/implementation.md index bd766709..1fadb63c 100644 --- a/controllers/aici_abi/implementation.md +++ b/controllers/aici_abi/implementation.md @@ -1,98 +1,5 @@ # Implementation notes -## Token trie - -The round nodes represent tokens, the square nodes do not have a corresponding token. - -The number (`num_parents`) specifies how many parents do you need to pop to get to the parent of the node which comes after our children in DFS order. - -We also keep the `token_id` and a `subtree_size` (which includes the node itself) in each node. -A bogus `token_id` is used for nodes that do not have a corresponding token. - -```mermaid -graph TD - root[ε, 0] -- a --> a((a, 1)) - root -- b --> b((b, 1)) - root -- c --> c((c, 1)) - a -- x --> ax((ax, 1)) - a -- y --> ay[ay, 1] - a -- z --> az((az, 2)) - az -- a --> azq((aza, 3)) - ay -- a --> ayq((aya, 1)) - ay -- b --> ayw((ayb, 2)) -``` - -Traversal algorithm - computing the set of tokens allowed by a stack-based recognizer. -The set is stored in `logits` array - entries with `0.0` are allowed. - -```rust -let mut logits = vec![-100.0; VOCAB_SIZE + 1]; -``` - -A simple version of traversal algorithm: - -```rust -fn traverse(n) { - // mark token as allowed; nodes without token use `token_id == VOCAB_SIZE` - logits[n.token_id] = 0.0; - for c in n.children { - // for every child that starts with an allowed byte - if byte_allowed(c.byte) { - push_byte(c.byte); - // traverse it - traverse(c); - pop_bytes(1); - } - } -} -``` - -Now, assume the tree is laid out in memory in DFS order: - -```rust -fn traverse(mut p) { - let endp = p + nodes[p].subtree_size; - p += 1; // move to first child - while p < endp { - let n = nodes[p]; - if byte_allowed(n.byte) { - push_byte(n.byte); - logits[n.token_id] = 0.0; - // p is moved by n.subtree_size - p = traverse(p); - pop_bytes(1); - } else { - p += n.subtree_size; - } - } -} -``` - -Now, we get rid of the recursion: - -```rust -let mut p = 0; -while p < nodes.len() { - let n = nodes[p]; - if byte_allowed(n.byte) { - push_byte(n.byte); - logits[n.token_id] = 0.0; - // if the node is a leaf, we need to pop all the parents - pop_bytes(if n.subtree_size == 1 { n.num_parents } else { 0 }); - // move to first child, or sibling if no children - p += 1; - } else { - // skip the children, and go to the sibling node - p += n.subtree_size; - // regardless if the node is a leaf, we need to pop all the parents - pop_bytes(n.num_parents - 1); - } -} -``` - -Note that the only branch that gets mis-predicted here is the `if byte_allowed(n.byte)`. -The `if` in argument to `pop_bytes` is compiled to bit operations, so it is branchless. - ## LR(1) parsing The LR(1) parsing consists of DFA-based lexer and the actual LR(1) parser. 
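The `FunctionalRecognizer` / `StackRecognizer` pair that the patches above move into the `toktrie` crate is easiest to see with a concrete constraint. Below is a minimal sketch (not part of any patch) of a recognizer that admits only ASCII digits, written against the trait shapes visible in the diffs above; the `DigitsOnly` type and its `u32` state are made up for illustration, and `SpecialToken::EndOfSentence` is assumed to be one of the `SpecialToken` variants.

```rust
use toktrie::recognizer::{FunctionalRecognizer, StackRecognizer};
use toktrie::SpecialToken;

// Hypothetical constraint: allow only ASCII digits; the state counts
// how many digits have been accepted so far.
#[derive(Clone)]
struct DigitsOnly;

impl FunctionalRecognizer<u32> for DigitsOnly {
    fn initial(&self) -> u32 {
        0
    }

    // Return the successor state, or None if `byte` is not allowed here.
    fn try_append(&self, state: u32, byte: u8) -> Option<u32> {
        if byte.is_ascii_digit() {
            Some(state + 1)
        } else {
            None
        }
    }

    // Permit EOS only once at least one digit has been produced.
    fn special_allowed(&self, state: u32, tok: SpecialToken) -> bool {
        tok == SpecialToken::EndOfSentence && state > 0
    }
}

// StackRecognizer adapts the byte-functional interface to the stack-based
// `Recognizer` trait that `TokTrie::compute_bias` walks over.
fn digits_recognizer() -> StackRecognizer<u32, DigitsOnly> {
    StackRecognizer::from(DigitsOnly)
}
```

The stack in `StackRecognizer` mirrors the trie traversal: `try_push_byte` pushes a successor state and `pop_bytes` rewinds when the traversal backs out of a token, so the functional recognizer itself stays immutable and cheap to clone.
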
From 134811d86c5d2aa73d30461933f5db02fa0e611e Mon Sep 17 00:00:00 2001
From: Michal Moskal
Date: Fri, 5 Jul 2024 17:48:29 +0000
Subject: [PATCH 245/301] clean up deps

---
 controllers/aici_abi/src/host.rs   | 14 +-------------
 controllers/aici_abi/src/lib.rs    |  4 ++--
 controllers/toktrie/src/lib.rs     |  4 +++-
 controllers/toktrie/src/toktree.rs | 13 +++++++++++++
 4 files changed, 19 insertions(+), 16 deletions(-)

diff --git a/controllers/aici_abi/src/host.rs b/controllers/aici_abi/src/host.rs
index 19281468..222666e8 100644
--- a/controllers/aici_abi/src/host.rs
+++ b/controllers/aici_abi/src/host.rs
@@ -1,5 +1,6 @@
 use crate::{bytes::vec_from_bytes, toktrie::TokTrie, SeqId, SimpleVob, TokenId};
 use serde::{Deserialize, Serialize};
+use toktrie::TokenizerEnv;
 
 #[repr(transparent)]
 #[derive(Clone, Copy, Debug, PartialEq, Eq)]
@@ -72,19 +73,6 @@ pub extern "C" fn aici_init() {
     set_host(Box::new(WasmHost {}));
 }
 
-pub trait TokenizerEnv: Send {
-    fn stop(&self) -> !;
-    fn tok_trie(&self) -> &TokTrie;
-    fn tokenize_bytes(&self, s: &[u8]) -> Vec<TokenId>;
-
-    fn tokenize(&self, s: &str) -> Vec<TokenId> {
-        self.tokenize_bytes(s.as_bytes())
-    }
-    fn eos_token(&self) -> TokenId {
-        self.tok_trie().eos_token()
-    }
-}
-
 pub struct WasmTokenizerEnv {
     toktrie: TokTrie,
 }
diff --git a/controllers/aici_abi/src/lib.rs b/controllers/aici_abi/src/lib.rs
index 85ab7cdf..72024825 100644
--- a/controllers/aici_abi/src/lib.rs
+++ b/controllers/aici_abi/src/lib.rs
@@ -1,6 +1,6 @@
 pub use toktrie;
-pub use toktrie::SimpleVob;
 pub use toktrie::{bytes, recognizer, rng};
+pub use toktrie::{SimpleVob, TokenizerEnv};
 
 use serde::{Deserialize, Serialize};
 
@@ -22,7 +22,7 @@ pub type TokenId = toktrie::TokenId;
 
 pub use host::{
     aici_stop, arg_bytes, arg_string, get_config, host_trie, self_seq_id, tokenize, tokenize_bytes,
-    StorageCmd, StorageOp, StorageResp, TokenizerEnv, VariableStorage, WasmTokenizerEnv,
+    StorageCmd, StorageOp, StorageResp, VariableStorage, WasmTokenizerEnv,
 };
diff --git a/controllers/toktrie/src/lib.rs b/controllers/toktrie/src/lib.rs
index e8821c0c..1e2bba29 100644
--- a/controllers/toktrie/src/lib.rs
+++ b/controllers/toktrie/src/lib.rs
@@ -7,7 +7,7 @@ mod svob;
 mod toktree;
 
 pub use svob::{SimpleVob, SimpleVobIter};
-pub use toktree::{Recognizer, SpecialToken, TokRxInfo, TokTrie, TokenId};
+pub use toktree::{Recognizer, SpecialToken, TokRxInfo, TokTrie, TokenId, TokenizerEnv};
 
 #[derive(Serialize, Deserialize, Debug)]
 pub struct StepArg {
@@ -135,3 +135,5 @@ impl<S> Branch<S> {
         }
     }
 }
+
+pub type StepResult = Branch<SimpleVob>;
\ No newline at end of file
diff --git a/controllers/toktrie/src/toktree.rs b/controllers/toktrie/src/toktree.rs
index 44d25007..495cd53d 100644
--- a/controllers/toktrie/src/toktree.rs
+++ b/controllers/toktrie/src/toktree.rs
@@ -58,6 +58,19 @@ pub trait Recognizer {
     }
 }
 
+pub trait TokenizerEnv: Send {
+    fn stop(&self) -> !;
+    fn tok_trie(&self) -> &TokTrie;
+    fn tokenize_bytes(&self, s: &[u8]) -> Vec<TokenId>;
+
+    fn tokenize(&self, s: &str) -> Vec<TokenId> {
+        self.tokenize_bytes(s.as_bytes())
+    }
+    fn eos_token(&self) -> TokenId {
+        self.tok_trie().eos_token()
+    }
+}
+
 #[derive(Clone)]
 pub struct TokTrie {
     info: TokRxInfo,
From 9fd01946d938d01c10852eeeb821f0c4d61acee9 Mon Sep 17 00:00:00 2001
From: Michal Moskal
Date: Fri, 5 Jul 2024 23:12:26 +0000
Subject: [PATCH 246/301] add toktrie_hf_tokenizers

---
 controllers/aici_native/Cargo.toml           |   1 +
 controllers/aici_native/src/bintokens.rs     | 232 +------------------
 controllers/toktrie_hf_tokenizers/Cargo.toml |  13 ++
controllers/toktrie_hf_tokenizers/src/lib.rs | 226 ++++++++++++++++++ 4 files changed, 243 insertions(+), 229 deletions(-) create mode 100644 controllers/toktrie_hf_tokenizers/Cargo.toml create mode 100644 controllers/toktrie_hf_tokenizers/src/lib.rs diff --git a/controllers/aici_native/Cargo.toml b/controllers/aici_native/Cargo.toml index 9dec1862..e98ef4d3 100644 --- a/controllers/aici_native/Cargo.toml +++ b/controllers/aici_native/Cargo.toml @@ -8,6 +8,7 @@ name = "aici_native" [dependencies] aici_abi = { path = "../aici_abi" } +toktrie_hf_tokenizers = { path = "../toktrie_hf_tokenizers" } serde = { version = "1.0.192", features = ["derive"] } serde_json = "1.0.108" anyhow = "1.0.75" diff --git a/controllers/aici_native/src/bintokens.rs b/controllers/aici_native/src/bintokens.rs index b4f2e203..e47094a8 100644 --- a/controllers/aici_native/src/bintokens.rs +++ b/controllers/aici_native/src/bintokens.rs @@ -1,19 +1,7 @@ -use aici_abi::{toktrie::TokRxInfo, toktrie::TokTrie, TokenId, TokenizerEnv}; -use anyhow::{anyhow, bail, Result}; -use rustc_hash::FxHashMap; -use serde::{Deserialize, Serialize}; -use std::collections::BTreeMap; -use tokenizers::{normalizers::Sequence, FromPretrainedParameters, NormalizerWrapper, Tokenizer}; +use anyhow::{anyhow, Result}; +use tokenizers::{FromPretrainedParameters, Tokenizer}; -#[derive(Serialize, Deserialize)] -pub struct ByteTokenizer { - pub hf_model: String, - pub hf_tokenizer: Tokenizer, - pub eos_token: u32, - pub vocab_size: u32, - token_bytes: Vec>, - pub special: BTreeMap, -} +pub use toktrie_hf_tokenizers::{ByteTokenizer, ByteTokenizerEnv}; pub struct TokenizerInfo { pub name: &'static str, @@ -87,30 +75,6 @@ pub fn tokenizers() -> Vec { ] } -// useful when debugging this: https://www.cogsci.ed.ac.uk/~richard/utf-8.cgi - -fn is_self_mapped(c: char) -> bool { - match c { - '!'..='~' | '\u{00A1}'..='\u{00AC}' | '\u{00AE}'..='\u{00FF}' => true, - _ => false, - } -} - -fn build_char_map() -> FxHashMap { - let mut res = FxHashMap::default(); - let mut k = 0x100u32; - for byte in 0..=255u8 { - let c = byte as char; - if is_self_mapped(c) { - res.insert(c, byte); - } else { - res.insert(char::from_u32(k).unwrap(), byte); - k += 1; - } - } - res -} - pub fn list_tokenizers() -> String { format!( "Available tokenizers for -t or --tokenizer:\n{}\n{}\n{}", @@ -191,193 +155,3 @@ pub fn find_tokenizer(mut name: &str) -> Result { } } } - -impl ByteTokenizer { - pub fn from_tokenizer(mut hft: Tokenizer) -> Result { - let mut is_byte_level = false; - let mut is_byte_fallback = false; - let mut space_ch = ' '; - - // remove the "Prepend space" - if let Some(n) = hft.get_normalizer() { - let n = match n { - NormalizerWrapper::Sequence(x) => NormalizerWrapper::Sequence(Sequence::new( - x.get_normalizers() - .iter() - .filter_map(|n| match n { - NormalizerWrapper::Prepend(_) => None, - _ => Some(n.clone()), - }) - .collect(), - )), - _ => n.clone(), - }; - hft.with_normalizer(n); - } - - if let Some(d) = hft.get_decoder() { - // DecoderWrapper::Sequence() doesn't let one access the decoders - // so we resort to json munching - let v = serde_json::to_value(d).unwrap(); - if v["type"].as_str() == Some("ByteLevel") { - is_byte_level = true; - } else if v["type"].as_str() == Some("Sequence") { - if let Some(decoders) = v["decoders"].as_array() { - for decoder in decoders { - if decoder["type"].as_str() == Some("ByteFallback") { - is_byte_fallback = true; - } else if decoder["type"].as_str() == Some("Replace") - && decoder["content"].as_str() == Some(" ") - { - if 
let Some(s) = decoder["pattern"]["String"].as_str() { - let s: Vec = s.chars().collect(); - if s.len() == 1 { - space_ch = s[0]; - } - } - } - } - } - } - } - - if !is_byte_fallback && !is_byte_level { - bail!("can't determine decoder type: {:?}", hft.get_decoder()); - } - - let vocab_size = hft.get_vocab_size(true) as u32; - let added = hft.get_added_tokens_decoder(); - - let mut res = ByteTokenizer { - hf_model: "foobar".to_string(), - eos_token: 0, - vocab_size, - special: BTreeMap::new(), - token_bytes: (0..vocab_size).map(|_| Vec::new()).collect(), - hf_tokenizer: hft, - }; - - for (id, info) in added.iter() { - if info.special { - match info.content.as_str() { - "" | "<|endoftext|>" | "<|end_of_text|>" => res.eos_token = *id, - _ => {} - } - res.special.insert(info.content.clone(), *id); - } else { - res.token_bytes[*id as usize] = info.content.clone().into_bytes(); - } - } - - let char_map = build_char_map(); - - for tok_id in 0..vocab_size { - if added.contains_key(&tok_id) { - continue; - } - if let Some(tok_name) = res.hf_tokenizer.id_to_token(tok_id) { - if is_byte_fallback { - if tok_name.len() == 6 && tok_name.starts_with("<0x") && tok_name.ends_with(">") - { - // parse hex number from tok_name - let hex_str = &tok_name[3..5]; - let byte = u8::from_str_radix(hex_str, 16).unwrap(); - res.token_bytes[tok_id as usize] = vec![byte]; - } else { - assert!(!tok_name.starts_with("<0x")); - let tok_name = tok_name.replace(space_ch, " "); - res.token_bytes[tok_id as usize] = tok_name.as_bytes().to_vec(); - } - } else if is_byte_level { - let bytes: Result> = tok_name - .chars() - .map(|c| { - char_map - .get(&c) - .map(|c| *c) - .ok_or_else(|| anyhow!("missing char: {}", c)) - }) - .collect(); - let bytes = match bytes { - Ok(b) => b, - Err(e) => { - println!("error: {} for {:?}", e, tok_name); - continue; - } - }; - - res.token_bytes[tok_id as usize] = bytes; - } else { - panic!(); - } - } else { - log::warn!("missing token: {}", tok_id); - } - } - - Ok(res) - } - - pub fn tokrx_info(&self) -> TokRxInfo { - TokRxInfo { - vocab_size: self.vocab_size, - tok_eos: self.eos_token, - } - } - pub fn token_bytes(&self) -> Vec> { - self.token_bytes.clone() - } - - pub fn add_missing_tokens(&mut self, vocab_size: usize) { - assert!(self.vocab_size == self.token_bytes.len() as u32); - assert!(vocab_size >= self.token_bytes.len()); - assert!(vocab_size - self.token_bytes.len() <= 200); - while self.token_bytes.len() < vocab_size { - let idx = self.token_bytes.len(); - let name = format!(""); - self.token_bytes.push(name.as_bytes().to_vec()); - self.vocab_size += 1; - self.special.insert(name, idx as u32); - } - } -} - -pub struct ByteTokenizerEnv { - pub tokenizer: ByteTokenizer, - pub tok_trie: TokTrie, -} - -impl ByteTokenizerEnv { - pub fn load(tokenizer_name: &str) -> Result { - let tokenizer = find_tokenizer(tokenizer_name)?; - Ok(Self::new(tokenizer)) - } - pub fn new(tokenizer: ByteTokenizer) -> ByteTokenizerEnv { - let tok_trie = TokTrie::from(&tokenizer.tokrx_info(), &tokenizer.token_bytes()); - ByteTokenizerEnv { - tokenizer, - tok_trie, - } - } -} - -impl TokenizerEnv for ByteTokenizerEnv { - fn stop(&self) -> ! 
{
-        panic!("stop called")
-    }
-
-    fn tok_trie(&self) -> &TokTrie {
-        &self.tok_trie
-    }
-
-    fn tokenize_bytes(&self, s: &[u8]) -> Vec<TokenId> {
-        self.tok_trie.tokenize_with_greedy_fallback(s, |s| {
-            self.tokenizer
-                .hf_tokenizer
-                .encode(s, false)
-                .expect("tokenizer error")
-                .get_ids()
-                .to_vec()
-        })
-    }
-}
diff --git a/controllers/toktrie_hf_tokenizers/Cargo.toml b/controllers/toktrie_hf_tokenizers/Cargo.toml
new file mode 100644
index 00000000..5de529c9
--- /dev/null
+++ b/controllers/toktrie_hf_tokenizers/Cargo.toml
@@ -0,0 +1,13 @@
+[package]
+name = "toktrie_hf_tokenizers"
+version = "0.1.0"
+edition = "2021"
+
+[dependencies]
+toktrie = { path = "../toktrie" }
+serde = { version = "1.0.192", features = ["derive"] }
+serde_json = "1.0.108"
+anyhow = "1.0.75"
+rustc-hash = { version = "2.0.0" }
+tokenizers = { version = "0.15.0", features = ["http"] }
+log = "0.4.21"
diff --git a/controllers/toktrie_hf_tokenizers/src/lib.rs b/controllers/toktrie_hf_tokenizers/src/lib.rs
new file mode 100644
index 00000000..3e434e7a
--- /dev/null
+++ b/controllers/toktrie_hf_tokenizers/src/lib.rs
@@ -0,0 +1,226 @@
+use anyhow::{anyhow, bail, Result};
+use rustc_hash::FxHashMap;
+use serde::{Deserialize, Serialize};
+use std::collections::BTreeMap;
+use tokenizers::{normalizers::Sequence, NormalizerWrapper, Tokenizer};
+use toktrie::{TokRxInfo, TokTrie, TokenId, TokenizerEnv};
+
+#[derive(Serialize, Deserialize)]
+pub struct ByteTokenizer {
+    pub hf_model: String,
+    pub hf_tokenizer: Tokenizer,
+    pub eos_token: u32,
+    pub vocab_size: u32,
+    token_bytes: Vec<Vec<u8>>,
+    pub special: BTreeMap<String, u32>,
+}
+
+// useful when debugging this: https://www.cogsci.ed.ac.uk/~richard/utf-8.cgi
+
+fn is_self_mapped(c: char) -> bool {
+    match c {
+        '!'..='~' | '\u{00A1}'..='\u{00AC}' | '\u{00AE}'..='\u{00FF}' => true,
+        _ => false,
+    }
+}
+
+fn build_char_map() -> FxHashMap<char, u8> {
+    let mut res = FxHashMap::default();
+    let mut k = 0x100u32;
+    for byte in 0..=255u8 {
+        let c = byte as char;
+        if is_self_mapped(c) {
+            res.insert(c, byte);
+        } else {
+            res.insert(char::from_u32(k).unwrap(), byte);
+            k += 1;
+        }
+    }
+    res
+}
+
+impl ByteTokenizer {
+    pub fn from_tokenizer(mut hft: Tokenizer) -> Result<ByteTokenizer> {
+        let mut is_byte_level = false;
+        let mut is_byte_fallback = false;
+        let mut space_ch = ' ';
+
+        // remove the "Prepend space"
+        if let Some(n) = hft.get_normalizer() {
+            let n = match n {
+                NormalizerWrapper::Sequence(x) => NormalizerWrapper::Sequence(Sequence::new(
+                    x.get_normalizers()
+                        .iter()
+                        .filter_map(|n| match n {
+                            NormalizerWrapper::Prepend(_) => None,
+                            _ => Some(n.clone()),
+                        })
+                        .collect(),
+                )),
+                _ => n.clone(),
+            };
+            hft.with_normalizer(n);
+        }
+
+        if let Some(d) = hft.get_decoder() {
+            // DecoderWrapper::Sequence() doesn't let one access the decoders
+            // so we resort to json munching
+            let v = serde_json::to_value(d).unwrap();
+            if v["type"].as_str() == Some("ByteLevel") {
+                is_byte_level = true;
+            } else if v["type"].as_str() == Some("Sequence") {
+                if let Some(decoders) = v["decoders"].as_array() {
+                    for decoder in decoders {
+                        if decoder["type"].as_str() == Some("ByteFallback") {
+                            is_byte_fallback = true;
+                        } else if decoder["type"].as_str() == Some("Replace")
+                            && decoder["content"].as_str() == Some(" ")
+                        {
+                            if let Some(s) = decoder["pattern"]["String"].as_str() {
+                                let s: Vec<char> = s.chars().collect();
+                                if s.len() == 1 {
+                                    space_ch = s[0];
+                                }
+                            }
+                        }
+                    }
+                }
+            }
+        }
+
+        if !is_byte_fallback && !is_byte_level {
+            bail!("can't determine decoder type: {:?}", hft.get_decoder());
+        }
+
+        let vocab_size = 
hft.get_vocab_size(true) as u32; + let added = hft.get_added_tokens_decoder(); + + let mut res = ByteTokenizer { + hf_model: "foobar".to_string(), + eos_token: 0, + vocab_size, + special: BTreeMap::new(), + token_bytes: (0..vocab_size).map(|_| Vec::new()).collect(), + hf_tokenizer: hft, + }; + + for (id, info) in added.iter() { + if info.special { + match info.content.as_str() { + "" | "<|endoftext|>" | "<|end_of_text|>" => res.eos_token = *id, + _ => {} + } + res.special.insert(info.content.clone(), *id); + } else { + res.token_bytes[*id as usize] = info.content.clone().into_bytes(); + } + } + + let char_map = build_char_map(); + + for tok_id in 0..vocab_size { + if added.contains_key(&tok_id) { + continue; + } + if let Some(tok_name) = res.hf_tokenizer.id_to_token(tok_id) { + if is_byte_fallback { + if tok_name.len() == 6 && tok_name.starts_with("<0x") && tok_name.ends_with(">") + { + // parse hex number from tok_name + let hex_str = &tok_name[3..5]; + let byte = u8::from_str_radix(hex_str, 16).unwrap(); + res.token_bytes[tok_id as usize] = vec![byte]; + } else { + assert!(!tok_name.starts_with("<0x")); + let tok_name = tok_name.replace(space_ch, " "); + res.token_bytes[tok_id as usize] = tok_name.as_bytes().to_vec(); + } + } else if is_byte_level { + let bytes: Result> = tok_name + .chars() + .map(|c| { + char_map + .get(&c) + .map(|c| *c) + .ok_or_else(|| anyhow!("missing char: {}", c)) + }) + .collect(); + let bytes = match bytes { + Ok(b) => b, + Err(e) => { + println!("error: {} for {:?}", e, tok_name); + continue; + } + }; + + res.token_bytes[tok_id as usize] = bytes; + } else { + panic!(); + } + } else { + log::warn!("missing token: {}", tok_id); + } + } + + Ok(res) + } + + pub fn tokrx_info(&self) -> TokRxInfo { + TokRxInfo { + vocab_size: self.vocab_size, + tok_eos: self.eos_token, + } + } + pub fn token_bytes(&self) -> Vec> { + self.token_bytes.clone() + } + + pub fn add_missing_tokens(&mut self, vocab_size: usize) { + assert!(self.vocab_size == self.token_bytes.len() as u32); + assert!(vocab_size >= self.token_bytes.len()); + assert!(vocab_size - self.token_bytes.len() <= 200); + while self.token_bytes.len() < vocab_size { + let idx = self.token_bytes.len(); + let name = format!(""); + self.token_bytes.push(name.as_bytes().to_vec()); + self.vocab_size += 1; + self.special.insert(name, idx as u32); + } + } +} + +pub struct ByteTokenizerEnv { + pub tokenizer: ByteTokenizer, + pub tok_trie: TokTrie, +} + +impl ByteTokenizerEnv { + pub fn new(tokenizer: ByteTokenizer) -> ByteTokenizerEnv { + let tok_trie = TokTrie::from(&tokenizer.tokrx_info(), &tokenizer.token_bytes()); + ByteTokenizerEnv { + tokenizer, + tok_trie, + } + } +} + +impl TokenizerEnv for ByteTokenizerEnv { + fn stop(&self) -> ! 
{ + panic!("stop called") + } + + fn tok_trie(&self) -> &TokTrie { + &self.tok_trie + } + + fn tokenize_bytes(&self, s: &[u8]) -> Vec { + self.tok_trie.tokenize_with_greedy_fallback(s, |s| { + self.tokenizer + .hf_tokenizer + .encode(s, false) + .expect("tokenizer error") + .get_ids() + .to_vec() + }) + } +} From 9fd01946d938d01c10852eeeb821f0c4d61acee9 Mon Sep 17 00:00:00 2001 From: "microsoft-github-operations[bot]" <55726097+microsoft-github-operations[bot]@users.noreply.github.com> Date: Fri, 5 Jul 2024 23:25:40 +0000 Subject: [PATCH 247/301] Initial commit --- .gitignore | 398 +++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 398 insertions(+) create mode 100644 .gitignore diff --git a/.gitignore b/.gitignore new file mode 100644 index 00000000..8a30d258 --- /dev/null +++ b/.gitignore @@ -0,0 +1,398 @@ +## Ignore Visual Studio temporary files, build results, and +## files generated by popular Visual Studio add-ons. +## +## Get latest from https://github.com/github/gitignore/blob/main/VisualStudio.gitignore + +# User-specific files +*.rsuser +*.suo +*.user +*.userosscache +*.sln.docstates + +# User-specific files (MonoDevelop/Xamarin Studio) +*.userprefs + +# Mono auto generated files +mono_crash.* + +# Build results +[Dd]ebug/ +[Dd]ebugPublic/ +[Rr]elease/ +[Rr]eleases/ +x64/ +x86/ +[Ww][Ii][Nn]32/ +[Aa][Rr][Mm]/ +[Aa][Rr][Mm]64/ +bld/ +[Bb]in/ +[Oo]bj/ +[Ll]og/ +[Ll]ogs/ + +# Visual Studio 2015/2017 cache/options directory +.vs/ +# Uncomment if you have tasks that create the project's static files in wwwroot +#wwwroot/ + +# Visual Studio 2017 auto generated files +Generated\ Files/ + +# MSTest test Results +[Tt]est[Rr]esult*/ +[Bb]uild[Ll]og.* + +# NUnit +*.VisualState.xml +TestResult.xml +nunit-*.xml + +# Build Results of an ATL Project +[Dd]ebugPS/ +[Rr]eleasePS/ +dlldata.c + +# Benchmark Results +BenchmarkDotNet.Artifacts/ + +# .NET Core +project.lock.json +project.fragment.lock.json +artifacts/ + +# ASP.NET Scaffolding +ScaffoldingReadMe.txt + +# StyleCop +StyleCopReport.xml + +# Files built by Visual Studio +*_i.c +*_p.c +*_h.h +*.ilk +*.meta +*.obj +*.iobj +*.pch +*.pdb +*.ipdb +*.pgc +*.pgd +*.rsp +*.sbr +*.tlb +*.tli +*.tlh +*.tmp +*.tmp_proj +*_wpftmp.csproj +*.log +*.tlog +*.vspscc +*.vssscc +.builds +*.pidb +*.svclog +*.scc + +# Chutzpah Test files +_Chutzpah* + +# Visual C++ cache files +ipch/ +*.aps +*.ncb +*.opendb +*.opensdf +*.sdf +*.cachefile +*.VC.db +*.VC.VC.opendb + +# Visual Studio profiler +*.psess +*.vsp +*.vspx +*.sap + +# Visual Studio Trace Files +*.e2e + +# TFS 2012 Local Workspace +$tf/ + +# Guidance Automation Toolkit +*.gpState + +# ReSharper is a .NET coding add-in +_ReSharper*/ +*.[Rr]e[Ss]harper +*.DotSettings.user + +# TeamCity is a build add-in +_TeamCity* + +# DotCover is a Code Coverage Tool +*.dotCover + +# AxoCover is a Code Coverage Tool +.axoCover/* +!.axoCover/settings.json + +# Coverlet is a free, cross platform Code Coverage Tool +coverage*.json +coverage*.xml +coverage*.info + +# Visual Studio code coverage results +*.coverage +*.coveragexml + +# NCrunch +_NCrunch_* +.*crunch*.local.xml +nCrunchTemp_* + +# MightyMoose +*.mm.* +AutoTest.Net/ + +# Web workbench (sass) +.sass-cache/ + +# Installshield output folder +[Ee]xpress/ + +# DocProject is a documentation generator add-in +DocProject/buildhelp/ +DocProject/Help/*.HxT +DocProject/Help/*.HxC +DocProject/Help/*.hhc +DocProject/Help/*.hhk +DocProject/Help/*.hhp +DocProject/Help/Html2 +DocProject/Help/html + +# Click-Once directory +publish/ + +# Publish Web Output 
+*.[Pp]ublish.xml +*.azurePubxml +# Note: Comment the next line if you want to checkin your web deploy settings, +# but database connection strings (with potential passwords) will be unencrypted +*.pubxml +*.publishproj + +# Microsoft Azure Web App publish settings. Comment the next line if you want to +# checkin your Azure Web App publish settings, but sensitive information contained +# in these scripts will be unencrypted +PublishScripts/ + +# NuGet Packages +*.nupkg +# NuGet Symbol Packages +*.snupkg +# The packages folder can be ignored because of Package Restore +**/[Pp]ackages/* +# except build/, which is used as an MSBuild target. +!**/[Pp]ackages/build/ +# Uncomment if necessary however generally it will be regenerated when needed +#!**/[Pp]ackages/repositories.config +# NuGet v3's project.json files produces more ignorable files +*.nuget.props +*.nuget.targets + +# Microsoft Azure Build Output +csx/ +*.build.csdef + +# Microsoft Azure Emulator +ecf/ +rcf/ + +# Windows Store app package directories and files +AppPackages/ +BundleArtifacts/ +Package.StoreAssociation.xml +_pkginfo.txt +*.appx +*.appxbundle +*.appxupload + +# Visual Studio cache files +# files ending in .cache can be ignored +*.[Cc]ache +# but keep track of directories ending in .cache +!?*.[Cc]ache/ + +# Others +ClientBin/ +~$* +*~ +*.dbmdl +*.dbproj.schemaview +*.jfm +*.pfx +*.publishsettings +orleans.codegen.cs + +# Including strong name files can present a security risk +# (https://github.com/github/gitignore/pull/2483#issue-259490424) +#*.snk + +# Since there are multiple workflows, uncomment next line to ignore bower_components +# (https://github.com/github/gitignore/pull/1529#issuecomment-104372622) +#bower_components/ + +# RIA/Silverlight projects +Generated_Code/ + +# Backup & report files from converting an old project file +# to a newer Visual Studio version. Backup files are not needed, +# because we have git ;-) +_UpgradeReport_Files/ +Backup*/ +UpgradeLog*.XML +UpgradeLog*.htm +ServiceFabricBackup/ +*.rptproj.bak + +# SQL Server files +*.mdf +*.ldf +*.ndf + +# Business Intelligence projects +*.rdl.data +*.bim.layout +*.bim_*.settings +*.rptproj.rsuser +*- [Bb]ackup.rdl +*- [Bb]ackup ([0-9]).rdl +*- [Bb]ackup ([0-9][0-9]).rdl + +# Microsoft Fakes +FakesAssemblies/ + +# GhostDoc plugin setting file +*.GhostDoc.xml + +# Node.js Tools for Visual Studio +.ntvs_analysis.dat +node_modules/ + +# Visual Studio 6 build log +*.plg + +# Visual Studio 6 workspace options file +*.opt + +# Visual Studio 6 auto-generated workspace file (contains which files were open etc.) +*.vbw + +# Visual Studio 6 auto-generated project file (contains which files were open etc.) 
+*.vbp + +# Visual Studio 6 workspace and project file (working project files containing files to include in project) +*.dsw +*.dsp + +# Visual Studio 6 technical files +*.ncb +*.aps + +# Visual Studio LightSwitch build output +**/*.HTMLClient/GeneratedArtifacts +**/*.DesktopClient/GeneratedArtifacts +**/*.DesktopClient/ModelManifest.xml +**/*.Server/GeneratedArtifacts +**/*.Server/ModelManifest.xml +_Pvt_Extensions + +# Paket dependency manager +.paket/paket.exe +paket-files/ + +# FAKE - F# Make +.fake/ + +# CodeRush personal settings +.cr/personal + +# Python Tools for Visual Studio (PTVS) +__pycache__/ +*.pyc + +# Cake - Uncomment if you are using it +# tools/** +# !tools/packages.config + +# Tabs Studio +*.tss + +# Telerik's JustMock configuration file +*.jmconfig + +# BizTalk build output +*.btp.cs +*.btm.cs +*.odx.cs +*.xsd.cs + +# OpenCover UI analysis results +OpenCover/ + +# Azure Stream Analytics local run output +ASALocalRun/ + +# MSBuild Binary and Structured Log +*.binlog + +# NVidia Nsight GPU debugger configuration file +*.nvuser + +# MFractors (Xamarin productivity tool) working folder +.mfractor/ + +# Local History for Visual Studio +.localhistory/ + +# Visual Studio History (VSHistory) files +.vshistory/ + +# BeatPulse healthcheck temp database +healthchecksdb + +# Backup folder for Package Reference Convert tool in Visual Studio 2017 +MigrationBackup/ + +# Ionide (cross platform F# VS Code tools) working folder +.ionide/ + +# Fody - auto-generated XML schema +FodyWeavers.xsd + +# VS Code files for those working on multiple tools +.vscode/* +!.vscode/settings.json +!.vscode/tasks.json +!.vscode/launch.json +!.vscode/extensions.json +*.code-workspace + +# Local History for Visual Studio Code +.history/ + +# Windows Installer files from build outputs +*.cab +*.msi +*.msix +*.msm +*.msp + +# JetBrains Rider +*.sln.iml From f765298116e30ffa15ae486623dd9e61931e3043 Mon Sep 17 00:00:00 2001 From: Microsoft Open Source Date: Fri, 5 Jul 2024 16:25:43 -0700 Subject: [PATCH 248/301] CODE_OF_CONDUCT.md committed --- CODE_OF_CONDUCT.md | 9 +++++++++ 1 file changed, 9 insertions(+) create mode 100644 CODE_OF_CONDUCT.md diff --git a/CODE_OF_CONDUCT.md b/CODE_OF_CONDUCT.md new file mode 100644 index 00000000..f9ba8cf6 --- /dev/null +++ b/CODE_OF_CONDUCT.md @@ -0,0 +1,9 @@ +# Microsoft Open Source Code of Conduct + +This project has adopted the [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/). + +Resources: + +- [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/) +- [Microsoft Code of Conduct FAQ](https://opensource.microsoft.com/codeofconduct/faq/) +- Contact [opencode@microsoft.com](mailto:opencode@microsoft.com) with questions or concerns From 9c8a8b81f632163a97b5b27f8fe6aa6353cd9aed Mon Sep 17 00:00:00 2001 From: Microsoft Open Source Date: Fri, 5 Jul 2024 16:25:43 -0700 Subject: [PATCH 249/301] LICENSE committed --- LICENSE | 21 +++++++++++++++++++++ 1 file changed, 21 insertions(+) create mode 100644 LICENSE diff --git a/LICENSE b/LICENSE new file mode 100644 index 00000000..9e841e7a --- /dev/null +++ b/LICENSE @@ -0,0 +1,21 @@ + MIT License + + Copyright (c) Microsoft Corporation. 
+ + Permission is hereby granted, free of charge, to any person obtaining a copy + of this software and associated documentation files (the "Software"), to deal + in the Software without restriction, including without limitation the rights + to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + copies of the Software, and to permit persons to whom the Software is + furnished to do so, subject to the following conditions: + + The above copyright notice and this permission notice shall be included in all + copies or substantial portions of the Software. + + THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + SOFTWARE From ee4b1e31ce2c5c2d412871aa6e007a87b0c57778 Mon Sep 17 00:00:00 2001 From: Microsoft Open Source Date: Fri, 5 Jul 2024 16:25:44 -0700 Subject: [PATCH 250/301] README.md committed --- README.md | 33 +++++++++++++++++++++++++++++++++ 1 file changed, 33 insertions(+) create mode 100644 README.md diff --git a/README.md b/README.md new file mode 100644 index 00000000..5cd7cecf --- /dev/null +++ b/README.md @@ -0,0 +1,33 @@ +# Project + +> This repo has been populated by an initial template to help get you started. Please +> make sure to update the content to build a great experience for community-building. + +As the maintainer of this project, please make a few updates: + +- Improving this README.MD file to provide a great experience +- Updating SUPPORT.MD with content about this project's support experience +- Understanding the security reporting process in SECURITY.MD +- Remove this section from the README + +## Contributing + +This project welcomes contributions and suggestions. Most contributions require you to agree to a +Contributor License Agreement (CLA) declaring that you have the right to, and actually do, grant us +the rights to use your contribution. For details, visit https://cla.opensource.microsoft.com. + +When you submit a pull request, a CLA bot will automatically determine whether you need to provide +a CLA and decorate the PR appropriately (e.g., status check, comment). Simply follow the instructions +provided by the bot. You will only need to do this once across all repos using our CLA. + +This project has adopted the [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/). +For more information see the [Code of Conduct FAQ](https://opensource.microsoft.com/codeofconduct/faq/) or +contact [opencode@microsoft.com](mailto:opencode@microsoft.com) with any additional questions or comments. + +## Trademarks + +This project may contain trademarks or logos for projects, products, or services. Authorized use of Microsoft +trademarks or logos is subject to and must follow +[Microsoft's Trademark & Brand Guidelines](https://www.microsoft.com/en-us/legal/intellectualproperty/trademarks/usage/general). +Use of Microsoft trademarks or logos in modified versions of this project must not cause confusion or imply Microsoft sponsorship. +Any use of third-party trademarks or logos are subject to those third-party's policies. 
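A quick aside on the byte-level decoder handling in the `toktrie_hf_tokenizers` patch above: GPT-2-style "ByteLevel" tokenizers store every token name as printable characters, mapping each self-representable byte to itself and assigning each remaining byte a successive code point starting at U+0100. The sketch below is a minimal, self-contained illustration of that invariant; it mirrors `is_self_mapped`/`build_char_map` from the patch but uses only `std` (plain `HashMap` instead of `FxHashMap`) and is not part of any patch in this series:

```rust
use std::collections::HashMap;

// Printable ASCII and most of Latin-1 represent themselves in token names.
fn is_self_mapped(c: char) -> bool {
    matches!(c, '!'..='~' | '\u{00A1}'..='\u{00AC}' | '\u{00AE}'..='\u{00FF}')
}

// Map the char used in a token name back to the byte it stands for.
fn build_char_map() -> HashMap<char, u8> {
    let mut res = HashMap::new();
    let mut k = 0x100u32; // non-self-mapped bytes get U+0100, U+0101, ...
    for byte in 0..=255u8 {
        let c = byte as char;
        if is_self_mapped(c) {
            res.insert(c, byte);
        } else {
            res.insert(char::from_u32(k).unwrap(), byte);
            k += 1;
        }
    }
    res
}

fn main() {
    let map = build_char_map();
    // 'Ġ' (U+0120) is the 33rd non-self-mapped byte (0x00..=0x1F, then 0x20),
    // so it decodes to the space byte 0x20; hence "Ġworld" means " world".
    let decoded: Vec<u8> = "Ġworld".chars().map(|c| map[&c]).collect();
    assert_eq!(decoded, b" world");
    println!("{}", String::from_utf8(decoded).unwrap());
}
```

This is the per-character decoding that `ByteTokenizer::from_tokenizer` applies to each token name when it detects a `ByteLevel` decoder; the `ByteFallback` path instead parses literal `<0xNN>` token names into single bytes.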
From 83229e414625a28a646a4638a5977a43b303a893 Mon Sep 17 00:00:00 2001 From: Microsoft Open Source Date: Fri, 5 Jul 2024 16:25:45 -0700 Subject: [PATCH 251/301] SUPPORT.md committed --- SUPPORT.md | 25 +++++++++++++++++++++++++ 1 file changed, 25 insertions(+) create mode 100644 SUPPORT.md diff --git a/SUPPORT.md b/SUPPORT.md new file mode 100644 index 00000000..291d4d43 --- /dev/null +++ b/SUPPORT.md @@ -0,0 +1,25 @@ +# TODO: The maintainer of this repo has not yet edited this file + +**REPO OWNER**: Do you want Customer Service & Support (CSS) support for this product/project? + +- **No CSS support:** Fill out this template with information about how to file issues and get help. +- **Yes CSS support:** Fill out an intake form at [aka.ms/onboardsupport](https://aka.ms/onboardsupport). CSS will work with/help you to determine next steps. +- **Not sure?** Fill out an intake as though the answer were "Yes". CSS will help you decide. + +*Then remove this first heading from this SUPPORT.MD file before publishing your repo.* + +# Support + +## How to file issues and get help + +This project uses GitHub Issues to track bugs and feature requests. Please search the existing +issues before filing new issues to avoid duplicates. For new issues, file your bug or +feature request as a new Issue. + +For help and questions about using this project, please **REPO MAINTAINER: INSERT INSTRUCTIONS HERE +FOR HOW TO ENGAGE REPO OWNERS OR COMMUNITY FOR HELP. COULD BE A STACK OVERFLOW TAG OR OTHER +CHANNEL. WHERE WILL YOU HELP PEOPLE?**. + +## Microsoft Support Policy + +Support for this **PROJECT or PRODUCT** is limited to the resources listed above. From 7313e361cb9c6296d307035904158a198e7ce13a Mon Sep 17 00:00:00 2001 From: Microsoft Open Source Date: Fri, 5 Jul 2024 16:25:45 -0700 Subject: [PATCH 252/301] SECURITY.md committed --- SECURITY.md | 41 +++++++++++++++++++++++++++++++++++++++++ 1 file changed, 41 insertions(+) create mode 100644 SECURITY.md diff --git a/SECURITY.md b/SECURITY.md new file mode 100644 index 00000000..b3c89efc --- /dev/null +++ b/SECURITY.md @@ -0,0 +1,41 @@ + + +## Security + +Microsoft takes the security of our software products and services seriously, which includes all source code repositories managed through our GitHub organizations, which include [Microsoft](https://github.com/Microsoft), [Azure](https://github.com/Azure), [DotNet](https://github.com/dotnet), [AspNet](https://github.com/aspnet) and [Xamarin](https://github.com/xamarin). + +If you believe you have found a security vulnerability in any Microsoft-owned repository that meets [Microsoft's definition of a security vulnerability](https://aka.ms/security.md/definition), please report it to us as described below. + +## Reporting Security Issues + +**Please do not report security vulnerabilities through public GitHub issues.** + +Instead, please report them to the Microsoft Security Response Center (MSRC) at [https://msrc.microsoft.com/create-report](https://aka.ms/security.md/msrc/create-report). + +If you prefer to submit without logging in, send email to [secure@microsoft.com](mailto:secure@microsoft.com). If possible, encrypt your message with our PGP key; please download it from the [Microsoft Security Response Center PGP Key page](https://aka.ms/security.md/msrc/pgp). + +You should receive a response within 24 hours. If for some reason you do not, please follow up via email to ensure we received your original message. Additional information can be found at [microsoft.com/msrc](https://www.microsoft.com/msrc). 
+ +Please include the requested information listed below (as much as you can provide) to help us better understand the nature and scope of the possible issue: + + * Type of issue (e.g. buffer overflow, SQL injection, cross-site scripting, etc.) + * Full paths of source file(s) related to the manifestation of the issue + * The location of the affected source code (tag/branch/commit or direct URL) + * Any special configuration required to reproduce the issue + * Step-by-step instructions to reproduce the issue + * Proof-of-concept or exploit code (if possible) + * Impact of the issue, including how an attacker might exploit the issue + +This information will help us triage your report more quickly. + +If you are reporting for a bug bounty, more complete reports can contribute to a higher bounty award. Please visit our [Microsoft Bug Bounty Program](https://aka.ms/security.md/msrc/bounty) page for more details about our active programs. + +## Preferred Languages + +We prefer all communications to be in English. + +## Policy + +Microsoft follows the principle of [Coordinated Vulnerability Disclosure](https://aka.ms/security.md/cvd). + + From a0a4389bdf91f2ad7cb8a04f3d0f5a8422452078 Mon Sep 17 00:00:00 2001 From: Michal Moskal Date: Fri, 5 Jul 2024 17:17:17 -0700 Subject: [PATCH 253/301] remove unneeded code --- controllers/aici_abi/.cargo/config.toml | 8 - controllers/aici_abi/Cargo.toml | 30 - controllers/aici_abi/README.md | 159 --- controllers/aici_abi/grammars/c.y | 442 ------- controllers/aici_abi/grammars/json0.guidance | Bin 1326 -> 0 bytes controllers/aici_abi/grammars/sample.c | 1245 ------------------ controllers/aici_abi/implementation.md | 60 - controllers/aici_abi/src/cfg.rs | 597 --------- controllers/aici_abi/src/dlex.rs | 266 ---- controllers/aici_abi/src/host.rs | 383 ------ controllers/aici_abi/src/lex.rs | 349 ----- controllers/aici_abi/src/lib.rs | 236 ---- controllers/aici_abi/src/rx.rs | 114 -- controllers/aici_abi/src/substring.rs | 277 ---- controllers/aici_abi/src/yesno.rs | 43 - controllers/aici_native/Cargo.toml | 18 - controllers/aici_native/README.md | 3 - controllers/aici_native/src/bintokens.rs | 157 --- controllers/aici_native/src/lib.rs | 8 - controllers/aici_native/src/log.rs | 100 -- controllers/aici_native/src/variables.rs | 56 - 21 files changed, 4551 deletions(-) delete mode 100644 controllers/aici_abi/.cargo/config.toml delete mode 100644 controllers/aici_abi/Cargo.toml delete mode 100644 controllers/aici_abi/README.md delete mode 100644 controllers/aici_abi/grammars/c.y delete mode 100644 controllers/aici_abi/grammars/json0.guidance delete mode 100644 controllers/aici_abi/grammars/sample.c delete mode 100644 controllers/aici_abi/implementation.md delete mode 100644 controllers/aici_abi/src/cfg.rs delete mode 100644 controllers/aici_abi/src/dlex.rs delete mode 100644 controllers/aici_abi/src/host.rs delete mode 100644 controllers/aici_abi/src/lex.rs delete mode 100644 controllers/aici_abi/src/lib.rs delete mode 100644 controllers/aici_abi/src/rx.rs delete mode 100644 controllers/aici_abi/src/substring.rs delete mode 100644 controllers/aici_abi/src/yesno.rs delete mode 100644 controllers/aici_native/Cargo.toml delete mode 100644 controllers/aici_native/README.md delete mode 100644 controllers/aici_native/src/bintokens.rs delete mode 100644 controllers/aici_native/src/lib.rs delete mode 100644 controllers/aici_native/src/log.rs delete mode 100644 controllers/aici_native/src/variables.rs diff --git a/controllers/aici_abi/.cargo/config.toml 
b/controllers/aici_abi/.cargo/config.toml
deleted file mode 100644
index e0b0d22a..00000000
--- a/controllers/aici_abi/.cargo/config.toml
+++ /dev/null
@@ -1,8 +0,0 @@
-[build]
-target = "wasm32-wasi"
-
-[profile.dev]
-strip = "debuginfo"
-
-[profile.release]
-strip = "debuginfo"
diff --git a/controllers/aici_abi/Cargo.toml b/controllers/aici_abi/Cargo.toml
deleted file mode 100644
index 9d8027a3..00000000
--- a/controllers/aici_abi/Cargo.toml
+++ /dev/null
@@ -1,30 +0,0 @@
-[package]
-name = "aici_abi"
-version = "0.1.0"
-edition = "2021"
-rust-version = "1.75.0"
-
-[lib]
-name = "aici_abi"
-
-[dependencies]
-toktrie = { path = "../toktrie" }
-serde = { version = "1.0.192", features = ["derive"] }
-serde_json = "1.0.108"
-anyhow = "1.0.75"
-regex-automata = { version = "0.4.6", default-features = false, features = ["std", "dfa", "syntax", "perf", "meta"], optional = true }
-cfgrammar = { version = "0.13.3", optional = true }
-lrtable = { version = "0.13.3", optional = true }
-vob = { version = "3.0.3", optional = true }
-rustc-hash = { version = "1.1.0", optional = true }
-bytemuck = "1.16.0"
-bytemuck_derive = "1.6.0"
-
-[features]
-default = ["cfg", "rx"]
-cfg = ["dep:cfgrammar", "dep:lrtable", "dep:vob", "dep:rustc-hash"]
-rx = ["dep:regex-automata"]
-
-[[bin]]
-name = "yesno"
-path = "src/yesno.rs"
diff --git a/controllers/aici_abi/README.md b/controllers/aici_abi/README.md
deleted file mode 100644
index b1df3c30..00000000
--- a/controllers/aici_abi/README.md
+++ /dev/null
@@ -1,159 +0,0 @@
-# aici_abi
-
-This crate specifies the application binary interface (ABI) for the AICI Controllers.
-It also provides higher-level interfaces for implementing controllers.
-
-## Low-level interface
-
-Conceptually, the lowest-level interface to an AICI constraint is this:
-
-```rust
-type TokenId = u32;
-type SeqId = u32;
-
-trait AiciCtrl {
-    /// Called with the initial prompt. ~1000ms time limit.
-    fn init_prompt(prompt: Vec<TokenId>);
-
-    /// Called before mid_process(), can fork or suspend. ~1ms.
-    fn pre_process() -> enum {
-        Stop,
-        Continue, // Same as Fork { num_forks: 1 }
-        Suspend, // skip this generation round
-        Fork { num_forks: u32 },
-    }
-
-    /// This is the main entry point for the module. ~20ms.
-    fn mid_process(fork_group: Vec<SeqId>) -> enum {
-        Stop,
-        SampleWithBias { bias: Vec<f32> },
-        Splice { backtrack: u32, ff_tokens: Vec<TokenId> }
-    };
-
-    /// Called after tokens are appended. ~1ms.
-    fn post_process(tokens: Vec<TokenId>) -> enum { Stop, Continue };
-}
-```
-
-Tokens depend on the tokenizer used (eg., for Llama there are 32000 tokens, and for GPT-4 there are ~100k).
-
-The actual binary interface is a bit more complicated, due
-to limitations in passing values to and from Wasm.
-A Wasm module instance is created for each token sequence.
-Also, when the sequence forks (as in beam search), the module instance is cloned.
-See the [AiciCtrl Rust trait](src/lib.rs) for details.
-
-A number of functions are exposed to the Wasm module.
-
-First, there are functions for accessing the current tokenizer:
-
-```rust
-/// Given a byte sequence, return a sequence of token Ids.
-fn tokenize_bytes(s: Vec<u8>) -> Vec<TokenId>;
-
-/// Represents trie of all tokens in the current tokenizer.
-impl TokTrie {
-    /// Get Id for EOS token etc.
-    fn special_token(tok: SpecialToken) -> TokenId;
-    /// Number of tokens.
-    fn vocab_size() -> usize;
-    /// Convert token Id to bytes (often UTF-8 string).
-    fn token(token: TokenId) -> Vec<u8>;
-    /// Given a Recognizer, compute the set of allowed tokens.
-    fn compute_bias(rec: impl Recognizer) -> Vec<bool>;
-}
-```
-
-Different forks in a sequence can communicate via shared variables:
-
-```rust
-/// This can be looked up in fork_group.
-fn self_seq_id() -> SeqId;
-
-trait VariableStorage {
-    fn get(name: str) -> Option<Vec<u8>>;
-    fn set(name: str, value: Vec<u8>);
-    fn append(name: str, value: Vec<u8>);
-}
-```
-
-Additionally, the `stdout` and `stderr` file descriptors are captured by the runtime
-and returned to the user when streaming results.
-
-This interface may need to be extended in the future.
-
-See the `toktrie` crate for general utilities for building constraints.
-This crate implements a few constraints including regexes, LR(1) grammars, and
-substrings.
-
-
-## Regular expressions
-
-The `FunctionalRecognizer` interface is implemented for regular expressions.
-The `S` type is the state of the DFA (Deterministic Finite Automaton) that recognizes the regular expression,
-then `append()` and `byte_allowed()` are the standard DFA operations,
-while `special_allowed()` is only implemented for the end-of-sequence token
-(which is allowed when the current state is accepting).
-
-## LR(1) grammars
-
-The `Recognizer` interface is implemented for LR(1) grammars and DFA-based lexers.
-
-The grammar uses inline syntax for the lexer:
-
-- `"keyword"` or `'keyword'` for keywords; any string works, eg. `"+="`, `"while"`, ...
-- `"/.../"` or `'/.../'` for regular expressions; you cannot have both `'` and `"` in the regex
-  The special `SKIP` rule is used to indicate tokens that need to be skipped by the LR(1) parser (eg., whitespace and comments)
-
-The lexer has a DFA which recognizes all regexps and keywords
-(a big disjunction, but with additional machinery to disambiguate between different branches).
-It goes byte by byte, until the DFA gets to a dead state (from which no match is possible).
-Then it goes back one byte and checks for a match.
-It prefers keywords over regexps.
-If no match is found, an error is reported, which requires careful design of the lexical part of the grammar
-(eg., see how the `white-space` rule below is a prefix of the `pre-processor` rule).
-
-For example, this is a fragment of [grammar for C](./grammars/c.y):
-
-```yacc
-%start translation_unit
-%%
-
-SKIP
-    : "//\*[^*]*\*+([^/*][^*]*\*+)*//" // block comment
-    | "///.*/" // line comment
-    | "/\n[ \t\v\f]*#(.*\\\n)*.*/" // pre-processor
-    | "/\n?[ \t\v\f]*/" // white-space
-    ;
-
-IDENTIFIER: "/[a-zA-Z_][0-9a-zA-Z_]*/" ;
-
-CONSTANT
-    : "/0[xX][0-9a-fA-F]+[uUlL]*?/"
-    | "/0[0-9]+[uUlL]*?/"
-    ;
-
-STRING_LITERAL: '/"(\\.|[^\\"])*"/' ;
-
-primary_expression
    : IDENTIFIER
-    | CONSTANT
-    | STRING_LITERAL
-    | "(" expression ")"
-    ;
-
-// ...
-
-enum_specifier
-    : "enum" "{" enumerator_list "}"
-    | "enum" IDENTIFIER "{" enumerator_list "}"
-    | "enum" IDENTIFIER
-    ;
-
-// ...
- -translation_unit - : external_declaration - | translation_unit external_declaration - ; -``` diff --git a/controllers/aici_abi/grammars/c.y b/controllers/aici_abi/grammars/c.y deleted file mode 100644 index 7397a971..00000000 --- a/controllers/aici_abi/grammars/c.y +++ /dev/null @@ -1,442 +0,0 @@ -// based on http://www.lysator.liu.se/c/ANSI-C-grammar-y.html - -%start translation_unit -%% - -SKIP - : "//\*[^*]*\*+([^/*][^*]*\*+)*//" // block comment - | "///.*/" // line comment - | "/\n[ \t\v\f]*#(.*\\\n)*.*/" // pre-processor - | "/\n?[ \t\v\f]*/" // white-space - ; - -IDENTIFIER: "/[a-zA-Z_][0-9a-zA-Z_]*/" ; - -TYPE_NAME: "/[a-zA-Z_][0-9a-zA-Z_]*_t/" ; - -CONSTANT - : "/0[xX][0-9a-fA-F]+[uUlL]*?/" - | "/0[0-9]+[uUlL]*?/" - | "/[0-9]+[uUlL]*?/" - | "/[a-zA-Z_]?'(\\.|[^\\'])+'/" - | "/[0-9]+[Ee][+-]?[0-9]+[flFL]?/" - | "/[0-9]*\\.[0-9]+([Ee][+-]?[0-9]+)?[flFL]?/" - | "/[0-9]+\\.[0-9]*([Ee][+-]?[0-9]+)?[flFL]?/" - ; - -STRING_LITERAL: '/[a-zA-Z_]?"(\\.|[^\\"])*"/' ; - -primary_expression - : IDENTIFIER - | CONSTANT - | STRING_LITERAL - | "(" expression ")" - ; - -postfix_expression - : primary_expression - | postfix_expression "[" expression "]" - | postfix_expression "(" ")" - | postfix_expression "(" argument_expression_list ")" - | postfix_expression "." IDENTIFIER - | postfix_expression "->" IDENTIFIER - | postfix_expression "++" - | postfix_expression "--" - ; - -argument_expression_list - : assignment_expression - | argument_expression_list "," assignment_expression - ; - -unary_expression - : postfix_expression - | "++" unary_expression - | "--" unary_expression - | unary_operator cast_expression - | "sizeof" unary_expression - | "sizeof" "(" type_name ")" - ; - -unary_operator - : "&" - | "*" - | "+" - | "-" - | "~" - | "!" - ; - -cast_expression - : unary_expression - | "(" type_name ")" cast_expression - ; - -multiplicative_expression - : cast_expression - | multiplicative_expression "*" cast_expression - | multiplicative_expression "/" cast_expression - | multiplicative_expression "%" cast_expression - ; - -additive_expression - : multiplicative_expression - | additive_expression "+" multiplicative_expression - | additive_expression "-" multiplicative_expression - ; - -shift_expression - : additive_expression - | shift_expression "<<" additive_expression - | shift_expression ">>" additive_expression - ; - -relational_expression - : shift_expression - | relational_expression "<" shift_expression - | relational_expression ">" shift_expression - | relational_expression "<=" shift_expression - | relational_expression ">=" shift_expression - ; - -equality_expression - : relational_expression - | equality_expression "==" relational_expression - | equality_expression "!=" relational_expression - ; - -and_expression - : equality_expression - | and_expression "&" equality_expression - ; - -exclusive_or_expression - : and_expression - | exclusive_or_expression "^" and_expression - ; - -inclusive_or_expression - : exclusive_or_expression - | inclusive_or_expression "|" exclusive_or_expression - ; - -logical_and_expression - : inclusive_or_expression - | logical_and_expression "&&" inclusive_or_expression - ; - -logical_or_expression - : logical_and_expression - | logical_or_expression "||" logical_and_expression - ; - -conditional_expression - : logical_or_expression - | logical_or_expression "?" 
expression ":" conditional_expression - ; - -assignment_expression - : conditional_expression - | unary_expression assignment_operator assignment_expression - ; - -assignment_operator - : "=" - | "*=" - | "/=" - | "%=" - | "+=" - | "-=" - | "<<=" - | ">>=" - | "&=" - | "^=" - | "|=" - ; - -expression - : assignment_expression - | expression "," assignment_expression - ; - -constant_expression - : conditional_expression - ; - -declaration - : declaration_specifiers ";" - | declaration_specifiers init_declarator_list ";" - ; - -declaration_specifiers - : storage_class_specifier - | storage_class_specifier declaration_specifiers - | type_specifier - | type_specifier declaration_specifiers - | type_qualifier - | type_qualifier declaration_specifiers - ; - -init_declarator_list - : init_declarator - | init_declarator_list "," init_declarator - ; - -init_declarator - : declarator - | declarator "=" initializer - ; - -storage_class_specifier - : "typedef" - | "extern" - | "static" - | "auto" - | "register" - | "inline" - ; - -type_specifier - : "void" - | "char" - | "short" - | "int" - | "long" - | "float" - | "double" - | "signed" - | "unsigned" - | "bool" - | struct_or_union_specifier - | enum_specifier - | TYPE_NAME - ; - -struct_or_union_specifier - : struct_or_union IDENTIFIER "{" struct_declaration_list "}" - | struct_or_union "{" struct_declaration_list "}" - | struct_or_union IDENTIFIER - ; - -struct_or_union - : "struct" - | "union" - ; - -struct_declaration_list - : struct_declaration - | struct_declaration_list struct_declaration - ; - -struct_declaration - : specifier_qualifier_list struct_declarator_list ";" - ; - -specifier_qualifier_list - : type_specifier specifier_qualifier_list - | type_specifier - | type_qualifier specifier_qualifier_list - | type_qualifier - ; - -struct_declarator_list - : struct_declarator - | struct_declarator_list "," struct_declarator - ; - -struct_declarator - : declarator - | ":" constant_expression - | declarator ":" constant_expression - ; - -enum_specifier - : "enum" "{" enumerator_list "}" - | "enum" IDENTIFIER "{" enumerator_list "}" - | "enum" IDENTIFIER - ; - -enumerator_list - : enumerator - | enumerator_list "," enumerator - ; - -enumerator - : IDENTIFIER - | IDENTIFIER "=" constant_expression - ; - -type_qualifier - : "const" - | "volatile" - ; - -declarator - : pointer direct_declarator - | direct_declarator - ; - -direct_declarator - : IDENTIFIER - | "(" declarator ")" - | direct_declarator "[" constant_expression "]" - | direct_declarator "[" "]" - | direct_declarator "(" parameter_type_list ")" - | direct_declarator "(" identifier_list ")" - | direct_declarator "(" ")" - ; - -pointer - : "*" - | "*" type_qualifier_list - | "*" pointer - | "*" type_qualifier_list pointer - ; - -type_qualifier_list - : type_qualifier - | type_qualifier_list type_qualifier - ; - - -parameter_type_list - : parameter_list - | parameter_list "," "..." 
- ; - -parameter_list - : parameter_declaration - | parameter_list "," parameter_declaration - ; - -parameter_declaration - : declaration_specifiers declarator - | declaration_specifiers abstract_declarator - | declaration_specifiers - ; - -identifier_list - : IDENTIFIER - | identifier_list "," IDENTIFIER - ; - -type_name - : specifier_qualifier_list - | specifier_qualifier_list abstract_declarator - ; - -abstract_declarator - : pointer - | direct_abstract_declarator - | pointer direct_abstract_declarator - ; - -direct_abstract_declarator - : "(" abstract_declarator ")" - | "[" "]" - | "[" constant_expression "]" - | direct_abstract_declarator "[" "]" - | direct_abstract_declarator "[" constant_expression "]" - | "(" ")" - | "(" parameter_type_list ")" - | direct_abstract_declarator "(" ")" - | direct_abstract_declarator "(" parameter_type_list ")" - ; - -initializer - : assignment_expression - | "." IDENTIFIER "=" assignment_expression - | "[" assignment_expression "]" "=" assignment_expression - | "{" initializer_list "}" - | "{" initializer_list "," "}" - ; - -initializer_list - : initializer - | initializer_list "," initializer - ; - -statement - : labeled_statement - | compound_statement - | expression_statement - | selection_statement - | iteration_statement - | jump_statement - ; - -labeled_statement - : IDENTIFIER ":" statement - | "case" constant_expression ":" statement - | "default" ":" statement - ; - -compound_statement - : "{" "}" - | "{" statement_list "}" - ; - -declaration_list - : declaration - | declaration_list declaration - ; - -statement_or_declaration - : statement - | declaration - ; - -statement_list - : statement_or_declaration - | statement_list statement_or_declaration - ; - -expression_statement - : ";" - | expression ";" - ; - -for_decl - : expression_statement - | declaration - ; - -selection_statement - : "if" "(" expression ")" statement - | "if" "(" expression ")" statement "else" statement - | "switch" "(" expression ")" statement - ; - -iteration_statement - : "while" "(" expression ")" statement - | "do" statement "while" "(" expression ")" ";" - | "for" "(" for_decl expression_statement ")" statement - | "for" "(" for_decl expression_statement expression ")" statement - ; - -jump_statement - : "goto" IDENTIFIER ";" - | "continue" ";" - | "break" ";" - | "return" ";" - | "return" expression ";" - ; - -translation_unit - : external_declaration - | translation_unit external_declaration - ; - -external_declaration - : function_definition - | declaration - ; - -function_definition - : declaration_specifiers declarator declaration_list compound_statement - | declaration_specifiers declarator compound_statement - | declarator declaration_list compound_statement - | declarator compound_statement - ; - -%% diff --git a/controllers/aici_abi/grammars/json0.guidance b/controllers/aici_abi/grammars/json0.guidance deleted file mode 100644 index bcad296f58677710f08530ebefa086ef3f84dbb6..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 1326 zcmZvc+j0^?5Qfvs3S>e!Ob}y>kzmXc2@s=bLZUg~F`jP~%VHHGDxip`SXSkQdzpU(4SAX5}_wM#TDx*@`CFfe19e=QP`0IB~DUK`pR-c$S{HX#QQB!$? 
z+k#VST1VVtF4I@2FG_t*WwjfP#V0xDdxLhr<3Y_pYJ&6I5ceVOK-`d61QSN^Kn~l~ z4PBj1WoC|_IC<*ynfbHlyruKY)s+huFRv%lI-8s2{9K{9uvnsUrKVOlHm@?7>3{rj zw@D>cLOY2hlP{-D6(^O`J`gQ|X{Hwdcb*{BOfLbhtRu`abp!@)SwWU#>H>R?Y?f&s zn1_s+UI8m0%QN+WEhC#_dJU|KtibdJ*gUc#(_3I$$QGFTz-q`AnQ()iA}cWsq3;xDOWQu8w8mmknp>_jnFEM=*Hls4ysYf~Ld#%<0 zSGk1^SC~Fy!data) { - devs_free(ctx, map->data); - map->data = NULL; - map->capacity = 0; - map->length = 0; - } -} - -static inline uint16_t *short_keys(devs_short_map_t *map) { - return (uint16_t *)(map->short_data + map->capacity); -} - -static value_t *lookup_short(devs_ctx_t *ctx, devs_short_map_t *map, uint16_t key) { - unsigned len = map->length; - uint16_t *keys = short_keys(map); - for (unsigned i = 0; i < len; i++) { - if (keys[i] == key) { - return &map->short_data[i]; - } - } - return NULL; -} - -static value_t *lookup(devs_ctx_t *ctx, devs_map_t *map, value_t key) { - if (!devs_is_string(ctx, key)) - return NULL; - - value_t *data = map->data; - uint32_t kh = devs_handle_value(key); - unsigned len2 = map->length * 2; - - // do a quick reference-only check - for (unsigned i = 0; i < len2; i += 2) { - // check the low bits first, since they are more likely to be different - if (devs_handle_value(data[i]) == kh && data[i].u64 == key.u64) { - return &data[i + 1]; - } - } - - // slow path - compare strings - unsigned ksz, csz; - const char *cp, *kp = devs_string_get_utf8(ctx, key, &ksz); - for (unsigned i = 0; i < len2; i += 2) { - cp = devs_string_get_utf8(ctx, data[i], &csz); - if (csz == ksz && memcmp(kp, cp, ksz) == 0) - return &data[i + 1]; - } - - // nothing found... - return NULL; -} - -static value_t proto_value(devs_ctx_t *ctx, const devs_builtin_proto_entry_t *p) { - unsigned idx = p->builtin_idx; - if (idx <= DEVS_BUILTIN_OBJECT___MAX) - return devs_builtin_object_value(ctx, idx); - JD_ASSERT(idx >= DEVS_FIRST_BUILTIN_FUNCTION); - return devs_value_from_handle(DEVS_HANDLE_TYPE_STATIC_FUNCTION, idx); -} - -unsigned devs_maplike_iter(devs_ctx_t *ctx, devs_maplike_t *src, void *userdata, - devs_map_iter_cb_t cb) { - if (devs_is_service_spec(ctx, src)) { - // Object.keys() etc or debugger inspection on compiled spec - // return empty for now, do not crash - return 0; - } else if (devs_is_builtin_proto(src)) { - const devs_builtin_proto_t *proto = (const devs_builtin_proto_t *)src; - const devs_builtin_proto_entry_t *p = proto->entries; - while (p->builtin_string_id) { - if (cb) - cb(ctx, userdata, devs_builtin_string(p->builtin_string_id), proto_value(ctx, p)); - p++; - } - return p - proto->entries; - } else { - JD_ASSERT(devs_is_map(src)); - devs_map_t *srcmap = (devs_map_t *)src; - unsigned len = srcmap->length; - - if (cb != NULL) { - unsigned len2 = srcmap->length * 2; - value_t *data = srcmap->data; - for (unsigned i = 0; i < len2; i += 2) { - cb(ctx, userdata, data[i], data[i + 1]); - } - } - - if (devs_gc_tag(srcmap) == DEVS_GC_TAG_HALF_STATIC_MAP) - len += devs_maplike_iter(ctx, srcmap->proto, userdata, cb); - - return len; - } -} - -void devs_map_copy_into(devs_ctx_t *ctx, devs_map_t *dst, devs_maplike_t *src) { - devs_maplike_iter(ctx, src, dst, (devs_map_iter_cb_t)devs_map_set); -} - -struct kv_ctx { - unsigned dp; - bool keys; - devs_array_t *arr; -}; - -static void kv_add(devs_ctx_t *ctx, void *userdata, value_t k, value_t v) { - struct kv_ctx *acc = userdata; - acc->arr->data[acc->dp++] = acc->keys ? 
k : v; -} - -bool devs_maplike_is_map(devs_ctx_t *ctx, devs_maplike_t *src) { - if (src == NULL || devs_is_builtin_proto(src) || devs_is_service_spec(ctx, src)) - return false; - JD_ASSERT(devs_is_map(src)); - return true; -} - -void devs_maplike_keys_or_values(devs_ctx_t *ctx, devs_maplike_t *src, devs_array_t *arr, - bool keys) { - struct kv_ctx acc = { - .dp = arr->length, - .arr = arr, - .keys = keys, - }; - - unsigned len = devs_maplike_iter(ctx, src, NULL, NULL); - - if (devs_array_insert(ctx, arr, acc.dp, len) != 0) - return; - - devs_maplike_iter(ctx, src, &acc, kv_add); -} - -static int grow_len(int capacity) { - int newlen = capacity * 10 / 8; - if (newlen < 4) - newlen = 4; - return newlen; -} - -void devs_map_set(devs_ctx_t *ctx, devs_map_t *map, value_t key, value_t v) { - value_t *tmp = lookup(ctx, map, key); - if (tmp != NULL) { - *tmp = v; - return; - } - - if (!devs_is_string(ctx, key)) { - devs_throw_expecting_error(ctx, DEVS_BUILTIN_STRING_STRING, key); - return; - } - - JD_ASSERT(map->capacity >= map->length); - - if (map->capacity == map->length) { - int newlen = grow_len(map->capacity); - tmp = devs_try_alloc(ctx, newlen * (2 * sizeof(value_t))); - if (!tmp) - return; - map->capacity = newlen; - if (map->length) { - memcpy(tmp, map->data, map->length * sizeof(value_t) * 2); - } - map->data = tmp; - jd_gc_unpin(ctx->gc, tmp); - } - - map->data[map->length * 2] = key; - map->data[map->length * 2 + 1] = v; - map->length++; -} - -void devs_short_map_set(devs_ctx_t *ctx, devs_short_map_t *map, uint16_t key, value_t v) { - value_t *tmp = lookup_short(ctx, map, key); - if (tmp != NULL) { - *tmp = v; - return; - } - - JD_ASSERT(map->capacity >= map->length); - - if (map->capacity == map->length) { - int newlen = grow_len(map->capacity); - tmp = devs_try_alloc(ctx, newlen * (sizeof(value_t) + sizeof(uint16_t))); - if (!tmp) - return; - uint16_t *srckeys = short_keys(map); - map->capacity = newlen; - if (map->length) { - memcpy(tmp, map->short_data, map->length * sizeof(value_t)); - memcpy(tmp + newlen, srckeys, map->length * sizeof(uint16_t)); - } - map->short_data = tmp; - jd_gc_unpin(ctx->gc, tmp); - } - - map->short_data[map->length] = v; - short_keys(map)[map->length] = key; - map->length++; -} - -int devs_map_delete(devs_ctx_t *ctx, devs_map_t *map, value_t key) { - value_t *tmp = lookup(ctx, map, key); - if (tmp == NULL) { - return -1; - } - - tmp--; - unsigned off = tmp - map->data; - unsigned trailing = map->length - off / 2 - 1; - map->length--; - if (trailing) - memmove(tmp, tmp + 2, trailing * 2 * sizeof(value_t)); - return 0; -} - -bool devs_is_service_spec(devs_ctx_t *ctx, const void *ptr) { - return (uintptr_t)((const uint8_t *)ptr - - (const uint8_t *)devs_img_get_service_spec(ctx->img, 0)) < - (sizeof(devs_service_spec_t) * ctx->img.header->num_service_specs); -} - -value_t devs_map_get(devs_ctx_t *ctx, devs_map_t *map, value_t key) { - value_t *tmp = lookup(ctx, map, key); - if (tmp == NULL) - return devs_undefined; - return *tmp; -} - -value_t devs_short_map_get(devs_ctx_t *ctx, devs_short_map_t *map, uint16_t key) { - value_t *tmp = lookup_short(ctx, map, key); - if (tmp == NULL) - return devs_undefined; - return *tmp; -} - -static const devs_builtin_proto_t *get_static_built_in_proto(devs_ctx_t *ctx, unsigned idx) { - JD_ASSERT(idx <= DEVS_BUILTIN_OBJECT___MAX); - if (devs_builtin_protos[idx].entries == NULL) - return NULL; // not there? 
- return &devs_builtin_protos[idx]; -} - -static const uint8_t builtin_proto_idx[] = { - [DEVS_BUILTIN_OBJECT_MATH] = 1, - [DEVS_BUILTIN_OBJECT_BUFFER_PROTOTYPE] = 2, - [DEVS_BUILTIN_OBJECT_ARRAY_PROTOTYPE] = 3, - [DEVS_BUILTIN_OBJECT_STRING_PROTOTYPE] = 4, - [DEVS_BUILTIN_OBJECT_DSREGISTER_PROTOTYPE] = 5, - [DEVS_BUILTIN_OBJECT_DSROLE_PROTOTYPE] = 6, - [DEVS_BUILTIN_OBJECT_DSEVENT_PROTOTYPE] = 7, - [DEVS_BUILTIN_OBJECT_DEVICESCRIPT] = 8, - [DEVS_BUILTIN_OBJECT_IMAGE_PROTOTYPE] = 9, - [DEVS_BUILTIN_OBJECT_BUFFER] = 10, - [DEVS_BUILTIN_OBJECT_GPIO_PROTOTYPE] = 11, - [DEVS_BUILTIN_OBJECT_GPIO] = 12, -}; -#define MAX_PROTO 12 - -devs_maplike_t *devs_get_builtin_object(devs_ctx_t *ctx, unsigned idx) { - if (idx < sizeof(builtin_proto_idx)) { - unsigned midx = builtin_proto_idx[idx]; - if (midx > 0) { - midx--; - if (ctx->_builtin_protos == NULL) { - ctx->_builtin_protos = devs_try_alloc(ctx, sizeof(void *) * MAX_PROTO); - ctx->_num_builtin_protos = MAX_PROTO; - if (ctx->_builtin_protos == NULL) - return NULL; // whoops - } - JD_ASSERT(midx < MAX_PROTO); - devs_map_t *m = ctx->_builtin_protos[midx]; - if (m == NULL) { - m = devs_any_try_alloc(ctx, DEVS_GC_TAG_HALF_STATIC_MAP, sizeof(devs_map_t)); - if (m != NULL) { - ctx->_builtin_protos[midx] = m; - m->proto = (devs_maplike_t *)get_static_built_in_proto(ctx, idx); - } - } - return (devs_maplike_t *)m; - } - } - - return (devs_maplike_t *)get_static_built_in_proto(ctx, idx); -} - -bool devs_static_streq(devs_ctx_t *ctx, unsigned stridx, const char *other, unsigned other_len) { - unsigned size; - const char *r = devs_img_get_utf8(ctx->img, stridx, &size); - if (other_len != size) - return false; - return memcmp(r, other, size) == 0; -} - -#define MAX_OFF_BITS (DEVS_PACK_SHIFT - DEVS_ROLE_BITS) - -value_t devs_value_from_service_spec_idx(devs_ctx_t *ctx, unsigned idx) { - return devs_value_from_handle(DEVS_HANDLE_TYPE_ROLE_MEMBER, - DEVS_ROLE_INVALID | (idx << DEVS_ROLE_BITS)); -} - -value_t devs_value_from_service_spec(devs_ctx_t *ctx, const devs_service_spec_t *spec) { - unsigned idx = spec - devs_img_get_service_spec(ctx->img, 0); - JD_ASSERT(idx < ctx->img.header->num_service_specs); - return devs_value_from_service_spec_idx(ctx, idx); -} - -value_t devs_value_from_packet_spec(devs_ctx_t *ctx, const devs_packet_spec_t *pkt) { - if (pkt == NULL) - return devs_undefined; - const uint32_t *baseoff = (const void *)devs_img_get_service_spec(ctx->img, 0); - uintptr_t off = (const uint32_t *)pkt - baseoff; - JD_ASSERT(off < (1 << MAX_OFF_BITS)); - return devs_value_from_handle(DEVS_HANDLE_TYPE_ROLE_MEMBER, - DEVS_ROLE_INVALID | (off << DEVS_ROLE_BITS)); -} - -int devs_value_to_service_spec_idx(devs_ctx_t *ctx, value_t v) { - if (devs_handle_type(v) != DEVS_HANDLE_TYPE_ROLE_MEMBER) - return -1; - unsigned off = devs_handle_value(v) >> DEVS_ROLE_BITS; - if (off < ctx->img.header->num_service_specs) - return off; - return -1; -} - -const devs_service_spec_t *devs_value_to_service_spec(devs_ctx_t *ctx, value_t v) { - int off = devs_value_to_service_spec_idx(ctx, v); - if (off < 0) - return NULL; - return devs_img_get_service_spec(ctx->img, off); -} - -const devs_packet_spec_t *devs_decode_role_packet(devs_ctx_t *ctx, value_t v, unsigned *roleidx) { - if (roleidx) - *roleidx = DEVS_ROLE_INVALID; - if (devs_handle_type(v) != DEVS_HANDLE_TYPE_ROLE_MEMBER) - return NULL; - if (devs_value_to_service_spec(ctx, v)) - return NULL; - uint32_t h = devs_handle_value(v); - if (roleidx) - *roleidx = h & DEVS_ROLE_MASK; - return devs_img_get_packet_spec(ctx->img, h 
>> DEVS_ROLE_BITS); -} - -int devs_spec_idx(devs_ctx_t *ctx, const devs_service_spec_t *spec) { - if (spec == NULL) - return -1; - unsigned idx = spec - devs_img_get_service_spec(ctx->img, 0); - JD_ASSERT(idx < ctx->img.header->num_service_specs); - return idx; -} - -const devs_service_spec_t *devs_role_spec_for_class(devs_ctx_t *ctx, uint32_t cls) { - for (unsigned i = 0; i < ctx->img.header->num_service_specs; ++i) { - const devs_service_spec_t *spec = devs_img_get_service_spec(ctx->img, i); - if (spec->service_class == cls) - return spec; - } - return NULL; -} - -int devs_packet_spec_parent(devs_ctx_t *ctx, const devs_packet_spec_t *pspec) { - int off = (uint8_t *)pspec - ctx->img.data - ctx->img.header->service_specs.start; - for (unsigned i = 0; i < ctx->img.header->num_service_specs; ++i) { - const devs_service_spec_t *spec = devs_img_get_service_spec(ctx->img, i); - int idx = off - 4 * spec->packets_offset; - if (0 <= idx && idx < (int)(spec->num_packets * sizeof(devs_packet_spec_t))) - return i; - } - JD_PANIC(); - return -1; -} - -const devs_service_spec_t *devs_role_spec(devs_ctx_t *ctx, unsigned roleidx) { - if (roleidx >= DEVS_ROLE_FIRST_SPEC) { - unsigned specidx = roleidx - DEVS_ROLE_FIRST_SPEC; - if (specidx >= ctx->img.header->num_service_specs) - return NULL; - return devs_img_get_service_spec(ctx->img, specidx); - } - - devs_role_t *r = devs_role(ctx, roleidx); - - if (!r) - return NULL; - - return devs_role_spec_for_class(ctx, r->jdrole->service_class); -} - -devs_role_t *devs_role_or_fail(devs_ctx_t *ctx, unsigned roleidx) { - devs_role_t *r = devs_role(ctx, roleidx); - if (r == NULL) - devs_invalid_program(ctx, 60130); - return r; -} - -jd_device_service_t *devs_role_service(devs_ctx_t *ctx, unsigned roleidx) { - devs_role_t *r = devs_role(ctx, roleidx); - if (r == NULL) - return NULL; - return r->jdrole->service; -} - -const char *devs_role_name(devs_ctx_t *ctx, unsigned idx) { - devs_role_t *r = devs_role(ctx, idx); - if (r == NULL) - return "???"; - return r->jdrole->name; -} - -const devs_service_spec_t *devs_get_base_spec(devs_ctx_t *ctx, const devs_service_spec_t *spec) { - if (spec->service_class == JD_SERVICE_CLASS_BASE) - return NULL; - int idx = spec->flags & DEVS_SERVICESPEC_FLAG_DERIVE_MASK; - JD_ASSERT(idx <= DEVS_SERVICESPEC_FLAG_DERIVE_LAST); - return devs_img_get_service_spec(ctx->img, idx); -} - -value_t devs_spec_lookup(devs_ctx_t *ctx, const devs_service_spec_t *spec, value_t key) { - while (spec) { - JD_ASSERT(devs_is_service_spec(ctx, spec)); - const devs_packet_spec_t *pkts = devs_img_get_packet_spec(ctx->img, spec->packets_offset); - unsigned num_packets = spec->num_packets; - - if (devs_handle_type(key) == DEVS_HANDLE_TYPE_IMG_BUFFERISH) { - unsigned kidx = devs_handle_value(key); - for (unsigned i = 0; i < num_packets; ++i) { - if (pkts[i].name_idx == kidx) - return devs_value_from_packet_spec(ctx, &pkts[i]); - } - } - - unsigned ksz; - const char *kptr = devs_string_get_utf8(ctx, key, &ksz); - if (ksz == 0) - return devs_undefined; - - for (unsigned i = 0; i < num_packets; ++i) { - if (devs_static_streq(ctx, pkts[i].name_idx, kptr, ksz)) - return devs_value_from_packet_spec(ctx, &pkts[i]); - } - - spec = devs_get_base_spec(ctx, spec); - } - - return devs_undefined; -} - -static value_t devs_proto_lookup(devs_ctx_t *ctx, const devs_builtin_proto_t *proto, value_t key) { - JD_ASSERT(devs_is_proto(proto)); - - while (proto) { - const devs_builtin_proto_entry_t *p = proto->entries; - - if (devs_handle_type(key) == DEVS_HANDLE_TYPE_IMG_BUFFERISH && 
- (devs_handle_value(key) >> DEVS_STRIDX__SHIFT) == DEVS_STRIDX_BUILTIN) { - unsigned kidx = devs_handle_value(key) & ((1 << DEVS_STRIDX__SHIFT) - 1); - while (p->builtin_string_id) { - if (p->builtin_string_id == kidx) - return proto_value(ctx, p); - p++; - } - } else { - unsigned ksz; - const char *kptr = devs_string_get_utf8(ctx, key, &ksz); - if (ksz != strlen(kptr)) - return devs_undefined; - while (p->builtin_string_id) { - if (strcmp(devs_builtin_string_by_idx(p->builtin_string_id), kptr) == 0) - return proto_value(ctx, p); - p++; - } - } - - proto = proto->parent; - } - - return devs_undefined; -} - -static value_t devs_function_bind_alloc(devs_ctx_t *ctx, value_t obj, value_t fn) { - devs_bound_function_t *res = - devs_any_try_alloc(ctx, DEVS_GC_TAG_BOUND_FUNCTION, sizeof(devs_bound_function_t)); - if (res == NULL) - return devs_undefined; - - res->this_val = obj; - res->func = fn; - return devs_value_from_gc_obj(ctx, res); -} - -static const devs_builtin_function_t *devs_get_property_desc(devs_ctx_t *ctx, value_t fn) { - int htp = devs_handle_type(fn); - - if (htp != DEVS_HANDLE_TYPE_STATIC_FUNCTION) - return NULL; - - unsigned fidx = devs_handle_value(fn); - - int bltin = fidx - DEVS_FIRST_BUILTIN_FUNCTION; - if (bltin >= 0) { - JD_ASSERT(bltin < devs_num_builtin_functions); - const devs_builtin_function_t *h = &devs_builtin_functions[bltin]; - if (h->flags & DEVS_BUILTIN_FLAG_IS_PROPERTY) { - JD_ASSERT(h->num_args == 0); - return h; - } - } - - return NULL; -} - -// if `fn` is a static function, return `(obj, fn)` tuple -// if `fn` is a role member and `obj` is role, return (a different) `(obj, fn)` tuple -// otherwise return `obj` -// it may allocate an object for the tuple, but typically it doesn't -value_t devs_function_bind(devs_ctx_t *ctx, value_t obj, value_t fn) { - int htp = devs_handle_type(fn); - - if (htp == DEVS_HANDLE_TYPE_ROLE_MEMBER && devs_handle_type(obj) == DEVS_HANDLE_TYPE_ROLE && - !devs_value_to_service_spec(ctx, fn)) { - uint32_t role = devs_handle_value(obj); - JD_ASSERT((role & DEVS_ROLE_MASK) == role); - role |= devs_handle_value(fn) & ~DEVS_ROLE_MASK; - return devs_value_from_handle(DEVS_HANDLE_TYPE_ROLE_MEMBER, role); - } - - if (htp == DEVS_HANDLE_TYPE_CLOSURE) - return devs_function_bind_alloc(ctx, obj, fn); - - if (htp != DEVS_HANDLE_TYPE_STATIC_FUNCTION) - return fn; - - const devs_builtin_function_t *h = devs_get_property_desc(ctx, fn); - if (h) - return h->handler.prop(ctx, obj); - - unsigned fidx = devs_handle_value(fn); - int otp = devs_handle_type(obj); - - if (fidx <= 0xffff) - switch (otp) { - case DEVS_HANDLE_TYPE_SPECIAL: - case DEVS_HANDLE_TYPE_FIBER: - case DEVS_HANDLE_TYPE_ROLE: - case DEVS_HANDLE_TYPE_ROLE_MEMBER: - case DEVS_HANDLE_TYPE_STATIC_FUNCTION: - case DEVS_HANDLE_TYPE_IMG_BUFFERISH: { - uint32_t hv = devs_handle_value(obj); - JD_ASSERT((((uint32_t)otp << DEVS_PACK_SHIFT) >> DEVS_PACK_SHIFT) == (uint32_t)otp); - JD_ASSERT((hv >> DEVS_PACK_SHIFT) == 0); - JD_ASSERT(devs_handle_high_value(obj) == 0); - return devs_value_from_handle(DEVS_HANDLE_TYPE_BOUND_FUNCTION_STATIC | (fidx << 4), - (otp << DEVS_PACK_SHIFT) | hv); - } - - case DEVS_HANDLE_TYPE_GC_OBJECT: - JD_ASSERT(devs_handle_high_value(obj) == 0); - return devs_value_from_handle(DEVS_HANDLE_TYPE_BOUND_FUNCTION | (fidx << 4), - devs_handle_value(obj)); - } - - return devs_function_bind_alloc(ctx, obj, fn); -} - -value_t devs_make_closure(devs_ctx_t *ctx, devs_activation_t *closure, unsigned fnidx) { - JD_ASSERT(fnidx <= 0xffff); - return devs_value_from_pointer(ctx, 
DEVS_HANDLE_TYPE_CLOSURE | (fnidx << 4), closure); -} - -static int devs_get_fnidx_core(devs_ctx_t *ctx, value_t src, value_t *this_val, - devs_activation_t **closure, int depth) { - *closure = NULL; - *this_val = devs_undefined; - - if (depth > 2) - return -1; - - uint32_t hv = devs_handle_value(src); - switch (devs_handle_type(src)) { - case DEVS_HANDLE_TYPE_STATIC_FUNCTION: - *this_val = devs_undefined; - return hv; - case DEVS_HANDLE_TYPE_BOUND_FUNCTION_STATIC: - *this_val = - devs_value_from_handle(hv >> DEVS_PACK_SHIFT, hv & ((1 << DEVS_PACK_SHIFT) - 1)); - return devs_handle_high_value(src); - case DEVS_HANDLE_TYPE_BOUND_FUNCTION: - *this_val = devs_value_from_handle(DEVS_HANDLE_TYPE_GC_OBJECT, hv); - return devs_handle_high_value(src); - case DEVS_HANDLE_TYPE_CLOSURE: - *closure = devs_handle_ptr_value(ctx, src); - return devs_handle_high_value(src); - case DEVS_HANDLE_TYPE_GC_OBJECT: { - devs_bound_function_t *bnd = devs_handle_ptr_value(ctx, src); - if (devs_gc_tag(bnd) == DEVS_GC_TAG_BOUND_FUNCTION) { - int r = devs_get_fnidx_core(ctx, bnd->func, this_val, closure, depth + 1); - *this_val = bnd->this_val; - return r; - } else { - return -1; - } - } - default: { - if (devs_is_nullish(src)) - return -1; - value_t f = devs_object_get_built_in_field(ctx, src, DEVS_BUILTIN_STRING___FUNC__); - if (devs_is_undefined(f)) - return -1; - else { - int r = devs_get_fnidx_core(ctx, f, this_val, closure, depth + 1); - *this_val = src; - return r; - } - } - } -} - -int devs_get_fnidx(devs_ctx_t *ctx, value_t src, value_t *this_val, devs_activation_t **closure) { - return devs_get_fnidx_core(ctx, src, this_val, closure, 0); -} - -#define ATTACH_RW 0x01 -#define ATTACH_ENUM 0x02 -#define ATTACH_DIRECT 0x04 - -static void throw_field_error_str(devs_ctx_t *ctx, unsigned attach_flags, const char *objdesc) { - const char *op = attach_flags & ATTACH_RW ? "setting" : "getting"; - char *objd = jd_strdup(objdesc); - - if (devs_is_undefined(ctx->diag_field)) - devs_throw_type_error(ctx, "%s fields of %s", op, objd); - else - devs_throw_type_error(ctx, "%s field '%s' of %s", op, devs_show_value(ctx, ctx->diag_field), - objd); - - jd_free(objd); -} - -static void throw_field_error(devs_ctx_t *ctx, unsigned attach_flags, value_t v) { - throw_field_error_str(ctx, attach_flags, devs_show_value(ctx, v)); -} - -static devs_maplike_t *devs_get_static_proto(devs_ctx_t *ctx, int tp, unsigned attach_flags) { - if ((attach_flags & (ATTACH_DIRECT | ATTACH_ENUM)) == ATTACH_ENUM) - return NULL; - - devs_maplike_t *r = devs_get_builtin_object(ctx, tp); - - // accessing prototype on static object - can't attach properties - if (attach_flags & ATTACH_RW) { - if (attach_flags & ATTACH_DIRECT) { - if (devs_is_builtin_proto(r)) { - throw_field_error_str(ctx, attach_flags, "a builtin frozen object"); - return NULL; - } else { - JD_ASSERT(devs_is_map(r)); - return r; - } - } else { - // note that in ES writing to string/... 
properties is no-op - // we make it an error - throw_field_error_str(ctx, attach_flags, "a primitive"); - return NULL; - } - } else { - return r; - } -} - -devs_map_t *devs_get_spec_proto(devs_ctx_t *ctx, uint32_t spec_idx) { - value_t r = devs_short_map_get(ctx, ctx->spec_protos, spec_idx); - if (!devs_is_undefined(r)) - return devs_value_to_gc_obj(ctx, r); - - devs_map_t *m = devs_any_try_alloc(ctx, DEVS_GC_TAG_HALF_STATIC_MAP, sizeof(devs_map_t)); - if (m == NULL) - return NULL; - value_t v = devs_value_from_gc_obj(ctx, m); - devs_value_pin(ctx, v); - m->proto = (const void *)devs_img_get_service_spec(ctx->img, spec_idx); - devs_short_map_set(ctx, ctx->spec_protos, spec_idx, v); - devs_value_unpin(ctx, v); - return m; -} - -devs_map_t *devs_get_role_proto(devs_ctx_t *ctx, unsigned roleidx) { - devs_role_t *r = devs_role(ctx, roleidx); - if (!r) - return NULL; - - const devs_service_spec_t *spec = devs_role_spec_for_class(ctx, r->jdrole->service_class); - int idx = devs_spec_idx(ctx, spec); - if (idx < 0) - return NULL; // ??? - - return devs_get_spec_proto(ctx, idx); -} - -static devs_maplike_t *devs_object_get_attached(devs_ctx_t *ctx, value_t v, unsigned attach_flags) { - static const uint8_t proto_by_object_type[] = { - [DEVS_OBJECT_TYPE_NUMBER] = DEVS_BUILTIN_OBJECT_NUMBER_PROTOTYPE, - [DEVS_OBJECT_TYPE_FIBER] = DEVS_BUILTIN_OBJECT_DSFIBER_PROTOTYPE, - [DEVS_OBJECT_TYPE_ROLE] = DEVS_BUILTIN_OBJECT_DSROLE_PROTOTYPE, - [DEVS_OBJECT_TYPE_FUNCTION] = DEVS_BUILTIN_OBJECT_FUNCTION_PROTOTYPE, - [DEVS_OBJECT_TYPE_STRING] = DEVS_BUILTIN_OBJECT_STRING_PROTOTYPE, - [DEVS_OBJECT_TYPE_BUFFER] = DEVS_BUILTIN_OBJECT_BUFFER_PROTOTYPE, - [DEVS_OBJECT_TYPE_IMAGE] = DEVS_BUILTIN_OBJECT_IMAGE_PROTOTYPE, - [DEVS_OBJECT_TYPE_BOOL] = DEVS_BUILTIN_OBJECT_BOOLEAN_PROTOTYPE, - [DEVS_OBJECT_TYPE_EXOTIC] = DEVS_BUILTIN_OBJECT_OBJECT_PROTOTYPE, - }; - - if (devs_is_null_or_undefined(v)) { - throw_field_error(ctx, attach_flags, v); - return NULL; - } - - int htp = devs_handle_type(v); - - if (htp == DEVS_HANDLE_TYPE_ROLE_MEMBER) { - unsigned roleidx; - int pt; - const devs_packet_spec_t *spec = devs_decode_role_packet(ctx, v, &roleidx); - if (roleidx == DEVS_ROLE_INVALID) - pt = devs_value_to_service_spec(ctx, v) ? 
DEVS_BUILTIN_OBJECT_DSSERVICESPEC_PROTOTYPE - : DEVS_BUILTIN_OBJECT_DSPACKETSPEC_PROTOTYPE; - else - switch (spec->code & DEVS_PACKETSPEC_CODE_MASK) { - case DEVS_PACKETSPEC_CODE_REGISTER: - pt = DEVS_BUILTIN_OBJECT_DSREGISTER_PROTOTYPE; - break; - case DEVS_PACKETSPEC_CODE_EVENT: - pt = DEVS_BUILTIN_OBJECT_DSEVENT_PROTOTYPE; - break; - case DEVS_PACKETSPEC_CODE_COMMAND: - pt = DEVS_BUILTIN_OBJECT_DSCOMMAND_PROTOTYPE; - break; - case DEVS_PACKETSPEC_CODE_REPORT: - pt = DEVS_BUILTIN_OBJECT_DSREPORT_PROTOTYPE; - break; - default: - JD_PANIC(); - } - return devs_get_static_proto(ctx, pt, attach_flags); - } - - if (htp == DEVS_HANDLE_TYPE_ROLE) { - unsigned roleidx = devs_handle_value(v); - devs_role_t *rl = devs_role(ctx, roleidx); - if (!rl) - return NULL; - const void *r = rl->attached; - if (r || (attach_flags & ATTACH_ENUM)) - return r; - r = devs_get_role_proto(ctx, roleidx); - if (!r) - return NULL; - if (attach_flags & ATTACH_RW) { - devs_map_t *m = devs_map_try_alloc(ctx, r); - rl->attached = m; - r = m; - } - return r; - } - - if (htp != DEVS_HANDLE_TYPE_GC_OBJECT) { - int pt = 0; - int tp = devs_value_typeof(ctx, v); - if (tp == DEVS_OBJECT_TYPE_MAP && devs_is_special(v)) { - uint32_t hv = devs_handle_value(v); - if (devs_handle_is_builtin(hv)) - return devs_get_static_proto(ctx, hv - DEVS_SPECIAL_BUILTIN_OBJ_FIRST, - attach_flags | ATTACH_DIRECT); - } - if (tp == DEVS_OBJECT_TYPE_FUNCTION) { - value_t this_val; - devs_activation_t *closure; - int fidx = devs_get_fnidx(ctx, v, &this_val, &closure); - if (fidx >= 0) { - value_t r = devs_short_map_get(ctx, ctx->fn_values, fidx); - if (devs_is_undefined(r) && attach_flags) { - r = devs_value_from_gc_obj( - ctx, - devs_map_try_alloc(ctx, devs_get_builtin_object( - ctx, DEVS_BUILTIN_OBJECT_FUNCTION_PROTOTYPE))); - if (!devs_is_undefined(r)) { - devs_value_pin(ctx, r); - devs_short_map_set(ctx, ctx->fn_values, fidx, r); - devs_value_unpin(ctx, r); - } - } - if (!devs_is_undefined(r)) - return devs_value_to_gc_obj(ctx, r); - } - } - if (tp < (int)sizeof(proto_by_object_type)) { - pt = proto_by_object_type[tp]; - } - JD_ASSERT(pt != 0); - return devs_get_static_proto(ctx, pt, attach_flags); - } - - devs_gc_object_t *obj = devs_handle_ptr_value(ctx, v); - - devs_map_t **attached; - int builtin; - - switch (devs_gc_tag(obj)) { - case DEVS_GC_TAG_BUFFER: - attached = &((devs_buffer_t *)obj)->attached; - builtin = DEVS_BUILTIN_OBJECT_BUFFER_PROTOTYPE; - break; - case DEVS_GC_TAG_IMAGE: - attached = &((devs_gimage_t *)obj)->attached; - builtin = DEVS_BUILTIN_OBJECT_IMAGE_PROTOTYPE; - break; - case DEVS_GC_TAG_ARRAY: - attached = &((devs_array_t *)obj)->attached; - builtin = DEVS_BUILTIN_OBJECT_ARRAY_PROTOTYPE; - break; - case DEVS_GC_TAG_PACKET: - attached = &((devs_packet_t *)obj)->attached; - builtin = DEVS_BUILTIN_OBJECT_DSPACKET_PROTOTYPE; - break; - case DEVS_GC_TAG_HALF_STATIC_MAP: - case DEVS_GC_TAG_MAP: - return (devs_maplike_t *)obj; - case DEVS_GC_TAG_STRING_JMP: - case DEVS_GC_TAG_STRING: - return devs_get_static_proto(ctx, DEVS_BUILTIN_OBJECT_STRING_PROTOTYPE, attach_flags); - case DEVS_GC_TAG_BOUND_FUNCTION: - return devs_get_static_proto(ctx, DEVS_BUILTIN_OBJECT_FUNCTION_PROTOTYPE, attach_flags); - case DEVS_GC_TAG_BUILTIN_PROTO: - case DEVS_GC_TAG_SHORT_MAP: - default: - JD_PANIC(); - break; - } - - devs_map_t *map = *attached; - - if (!map && (attach_flags & ATTACH_RW)) { - map = *attached = devs_map_try_alloc(ctx, devs_get_builtin_object(ctx, builtin)); - if (map == NULL) - return NULL; - } - - if (map || (attach_flags & 
ATTACH_ENUM)) - return (devs_maplike_t *)map; - else - return devs_get_builtin_object(ctx, builtin); -} - -devs_map_t *devs_object_get_attached_rw(devs_ctx_t *ctx, value_t v) { - const void *r = devs_object_get_attached(ctx, v, ATTACH_RW); - JD_ASSERT(r == NULL || devs_is_map(r)); - ctx->diag_field = devs_undefined; - return (void *)r; -} - -devs_maplike_t *devs_object_get_attached_ro(devs_ctx_t *ctx, value_t v) { - devs_maplike_t *r = devs_object_get_attached(ctx, v, 0); - ctx->diag_field = devs_undefined; - return r; -} - -devs_maplike_t *devs_object_get_attached_enum(devs_ctx_t *ctx, value_t v) { - devs_maplike_t *r = devs_object_get_attached(ctx, v, ATTACH_ENUM); - ctx->diag_field = devs_undefined; - return r; -} - -devs_maplike_t *devs_maplike_get_proto(devs_ctx_t *ctx, devs_maplike_t *obj) { - const void *res; - - if (devs_is_builtin_proto(obj)) { - res = ((const devs_builtin_proto_t *)obj)->parent; - } else if (devs_is_service_spec(ctx, obj)) { - res = devs_get_builtin_object(ctx, DEVS_BUILTIN_OBJECT_DSROLE_PROTOTYPE); - } else { - JD_ASSERT(devs_is_map(obj)); - devs_map_t *map = (devs_map_t *)obj; - return map->proto; - } - - if (res == NULL) - res = devs_get_builtin_object(ctx, DEVS_BUILTIN_OBJECT_OBJECT_PROTOTYPE); - if (obj == res) // Object.prototype.__proto__ == NULL - return NULL; - return res; -} - -devs_maplike_t *devs_get_prototype_field(devs_ctx_t *ctx, value_t cls) { - value_t cls_proto_val = devs_object_get_built_in_field(ctx, cls, DEVS_BUILTIN_STRING_PROTOTYPE); - if (devs_is_undefined(cls_proto_val)) { - if (!ctx->in_throw) - devs_throw_type_error(ctx, "no .prototype"); - return NULL; - } else { - devs_maplike_t *cls_proto = devs_object_get_attached_enum(ctx, cls_proto_val); - if (cls_proto == NULL) - devs_throw_type_error(ctx, "invalid .prototype"); - return cls_proto; - } -} - -bool devs_instance_of(devs_ctx_t *ctx, value_t obj, devs_maplike_t *cls_proto) { - if (cls_proto == NULL || devs_is_nullish(obj)) - return false; - - devs_maplike_t *proto = devs_object_get_attached_ro(ctx, obj); - devs_maplike_t *en = devs_object_get_attached_enum(ctx, obj); - if (proto && proto == en) - proto = devs_maplike_get_proto(ctx, proto); - if (proto == NULL) - return false; - - while (proto) { - if (cls_proto == proto) - return true; - proto = devs_maplike_get_proto(ctx, proto); - } - - return false; -} - -value_t devs_maplike_get_no_bind(devs_ctx_t *ctx, devs_maplike_t *proto, value_t key) { - value_t ptmp, *tmp = NULL; - - while (proto) { - devs_map_t *map; - if (devs_is_builtin_proto(proto)) { - ptmp = devs_proto_lookup(ctx, (const devs_builtin_proto_t *)proto, key); - tmp = &ptmp; - break; - } else if (devs_is_service_spec(ctx, proto)) { - ptmp = devs_spec_lookup(ctx, (const devs_service_spec_t *)proto, key); - if (!devs_is_undefined(ptmp)) { - tmp = &ptmp; - break; - } else { - proto = devs_get_builtin_object(ctx, DEVS_BUILTIN_OBJECT_DSROLE_PROTOTYPE); - continue; - } - } else { - JD_ASSERT(devs_is_map(proto)); - map = (devs_map_t *)proto; - tmp = lookup(ctx, map, key); - if (tmp) - break; - } - - proto = map->proto; - } - - if (tmp == NULL) - return devs_undefined; - return *tmp; -} - -value_t devs_object_get(devs_ctx_t *ctx, value_t obj, value_t key) { - ctx->diag_field = key; - value_t tmp = devs_maplike_get_no_bind(ctx, devs_object_get_attached_ro(ctx, obj), key); - return devs_function_bind(ctx, obj, tmp); -} - -value_t devs_object_get_built_in_field(devs_ctx_t *ctx, value_t obj, unsigned idx) { - value_t key = devs_builtin_string(idx); - ctx->diag_field = key; - value_t 
fn = devs_maplike_get_no_bind(ctx, devs_object_get_attached_ro(ctx, obj), key); - const devs_builtin_function_t *h = devs_get_property_desc(ctx, fn); - if (h) - return h->handler.prop(ctx, obj); - return fn; -} - -value_t devs_seq_get(devs_ctx_t *ctx, value_t seq, unsigned idx) { - if (idx > DEVS_MAX_ALLOC) - return devs_undefined; - - unsigned len; - const uint8_t *p = devs_bufferish_data(ctx, seq, &len); - if (p && idx < len) { - if (devs_is_string(ctx, seq)) { - int off = devs_string_index(ctx, seq, idx); - if (off < 0) - return devs_undefined; - p += off; - unsigned len = devs_utf8_code_point_length((const char *)p); - return devs_value_from_gc_obj(ctx, - devs_string_try_alloc_init(ctx, (const char *)p, len)); - } - return devs_value_from_int(p[idx]); - } - - devs_array_t *arr = devs_value_to_gc_obj(ctx, seq); - if (devs_gc_tag(arr) == DEVS_GC_TAG_ARRAY) { - if (idx < arr->length) - return arr->data[idx]; - } - - return devs_undefined; -} - -bool devs_looks_indexable(devs_ctx_t *ctx, value_t seq) { - return devs_is_array(ctx, seq) || devs_is_buffer(ctx, seq) || devs_is_string(ctx, seq); -} - -value_t devs_any_get(devs_ctx_t *ctx, value_t obj, value_t key) { - if (devs_is_number(key) && devs_looks_indexable(ctx, obj)) { - unsigned idx = devs_value_to_int(ctx, key); - return devs_seq_get(ctx, obj, idx); - } else if (devs_is_string(ctx, key)) { - return devs_object_get(ctx, obj, key); - } else { - key = devs_value_to_string(ctx, key); - devs_value_pin(ctx, key); - value_t res = devs_object_get(ctx, obj, key); - devs_value_unpin(ctx, key); - return res; - } -} - -void devs_any_set(devs_ctx_t *ctx, value_t obj, value_t key, value_t v) { - if (devs_is_number(key) && devs_looks_indexable(ctx, obj)) { - unsigned idx = devs_value_to_int(ctx, key); - devs_seq_set(ctx, obj, idx, v); - } else { - ctx->diag_field = key; - devs_map_t *map = devs_object_get_attached_rw(ctx, obj); - if (!map) - return; - if (devs_is_string(ctx, key)) - devs_map_set(ctx, map, key, v); - else { - key = devs_value_to_string(ctx, key); - devs_value_pin(ctx, key); - devs_map_set(ctx, map, key, v); - devs_value_unpin(ctx, key); - } - } -} - -static int array_ensure_len(devs_ctx_t *ctx, devs_array_t *arr, unsigned newlen) { - if (arr->capacity < newlen) { - newlen = grow_len(newlen); - value_t *newarr = devs_try_alloc(ctx, newlen * sizeof(value_t)); - if (newarr == NULL) - return -1; - if (arr->data) - memcpy(newarr, arr->data, sizeof(value_t) * arr->length); - arr->data = newarr; - arr->capacity = newlen; - jd_gc_unpin(ctx->gc, newarr); - } - return 0; -} - -void devs_array_set(devs_ctx_t *ctx, devs_array_t *arr, unsigned idx, value_t v) { - if (idx > DEVS_MAX_ALLOC / sizeof(value_t)) - devs_throw_too_big_error(ctx, DEVS_BUILTIN_STRING_ARRAY); - else { - if (array_ensure_len(ctx, arr, idx + 1) != 0) - return; - arr->data[idx] = v; - if (idx >= arr->length) - arr->length = idx + 1; - } -} - -void devs_array_pin_push(devs_ctx_t *ctx, devs_array_t *arr, value_t v) { - devs_value_pin(ctx, v); - devs_array_set(ctx, arr, arr->length, v); - devs_value_unpin(ctx, v); -} - -void devs_seq_set(devs_ctx_t *ctx, value_t seq, unsigned idx, value_t v) { - if (idx > DEVS_MAX_ALLOC) { - devs_throw_too_big_error(ctx, DEVS_BUILTIN_STRING_ARRAY); - } else if (devs_buffer_is_writable(ctx, seq)) { - unsigned len; - uint8_t *p = devs_buffer_data(ctx, seq, &len); - if (idx < len) { - p[idx] = devs_value_to_int(ctx, v) & 0xff; - } else { - devs_throw_range_error(ctx, "buffer write at %u, len=%u", idx, len); - } - } else { - devs_array_t *arr = 
devs_value_to_gc_obj(ctx, seq); - if (devs_gc_tag(arr) == DEVS_GC_TAG_ARRAY) { - devs_array_set(ctx, arr, idx, v); - } else { - devs_throw_expecting_error(ctx, DEVS_BUILTIN_STRING_ARRAY, seq); - } - } -} - -int devs_array_insert(devs_ctx_t *ctx, devs_array_t *arr, unsigned idx, int count) { - if (count > (int)(DEVS_MAX_ALLOC / sizeof(value_t))) { - devs_throw_too_big_error(ctx, DEVS_BUILTIN_STRING_ARRAY); - return -4; - } - - int newlen = arr->length + count; - if (newlen < 0) { - count = -arr->length; - newlen = 0; - } - - if (count == 0) - return 0; - - if (newlen > (int)(DEVS_MAX_ALLOC / sizeof(value_t))) { - devs_throw_too_big_error(ctx, DEVS_BUILTIN_STRING_ARRAY); - return -6; - } - - if (idx > arr->length) - idx = arr->length; - - if (array_ensure_len(ctx, arr, newlen)) - return -5; - - unsigned trailing = arr->length - idx; - - if (count < 0) { - count = -count; - memmove(arr->data + idx, arr->data + idx + count, sizeof(value_t) * (trailing - count)); - } else { - memmove(arr->data + idx + count, arr->data + idx, sizeof(value_t) * trailing); - memset(arr->data + idx, 0, count * sizeof(value_t)); - } - arr->length = newlen; - - return 0; -} - -int32_t devs_arg_int_defl(devs_ctx_t *ctx, unsigned idx, int32_t defl) { - value_t arg = devs_arg(ctx, idx); - if (devs_is_null_or_undefined(arg)) - return defl; - return devs_value_to_int(ctx, arg); -} - -int32_t devs_arg_int(devs_ctx_t *ctx, unsigned idx) { - return devs_value_to_int(ctx, devs_arg(ctx, idx)); -} - -bool devs_arg_bool(devs_ctx_t *ctx, unsigned idx) { - return devs_value_to_bool(ctx, devs_arg(ctx, idx)); -} - -double devs_arg_double(devs_ctx_t *ctx, unsigned idx) { - return devs_value_to_double(ctx, devs_arg(ctx, idx)); -} - -const char *devs_arg_utf8_with_conv(devs_ctx_t *ctx, unsigned idx, unsigned *sz) { - // store it on the stack, so it doesn't get GCed - ctx->the_stack[idx + 1] = devs_value_to_string(ctx, devs_arg(ctx, idx)); - return devs_string_get_utf8(ctx, devs_arg(ctx, idx), sz); -} - -void devs_ret_double(devs_ctx_t *ctx, double v) { - devs_ret(ctx, devs_value_from_double(v)); -} - -void devs_ret_int(devs_ctx_t *ctx, int v) { - devs_ret(ctx, devs_value_from_int(v)); -} - -void devs_ret_bool(devs_ctx_t *ctx, bool v) { - devs_ret(ctx, devs_value_from_bool(v)); -} - -void devs_ret_gc_ptr(devs_ctx_t *ctx, void *v) { - devs_ret(ctx, devs_value_from_gc_obj(ctx, v)); -} - -devs_map_t *devs_arg_self_map(devs_ctx_t *ctx) { - value_t s = devs_arg_self(ctx); - void *p = devs_value_to_gc_obj(ctx, s); - if (devs_is_map(p)) - return p; - devs_throw_type_error(ctx, "object expected"); - return NULL; -} - -void devs_setup_resume(devs_fiber_t *f, devs_resume_cb_t cb, void *userdata) { - if (devs_did_yield(f->ctx)) { - f->resume_cb = cb; - f->resume_data = userdata; - } else { - cb(f->ctx, userdata); - } -} - -bool devs_can_attach(devs_ctx_t *ctx, value_t v) { - switch (devs_value_typeof(ctx, v)) { - case DEVS_OBJECT_TYPE_MAP: - case DEVS_OBJECT_TYPE_ROLE: - case DEVS_OBJECT_TYPE_ARRAY: - case DEVS_OBJECT_TYPE_BUFFER: - case DEVS_OBJECT_TYPE_IMAGE: - return true; - default: - return false; - } -} - -value_t devs_builtin_object_value(devs_ctx_t *ctx, unsigned idx) { - if (idx > DEVS_BUILTIN_OBJECT___MAX) - return devs_undefined; - - devs_maplike_t *p = devs_get_builtin_object(ctx, idx); - if (devs_is_builtin_proto(p)) - return devs_value_from_handle(DEVS_HANDLE_TYPE_SPECIAL, - DEVS_SPECIAL_BUILTIN_OBJ_FIRST + idx); - else - return devs_value_from_gc_obj(ctx, (void *)p); -} - -value_t devs_maplike_to_value(devs_ctx_t *ctx, 
devs_maplike_t *obj) {
-    if (devs_is_builtin_proto(obj)) {
-        return devs_builtin_object_value(ctx,
-                                         (const devs_builtin_proto_t *)obj - devs_builtin_protos);
-    } else if (devs_is_service_spec(ctx, obj)) {
-        // this shouldn't happen
-        return devs_undefined;
-    } else {
-        JD_ASSERT(devs_is_map(obj));
-        devs_map_t *map = (devs_map_t *)obj;
-        if (devs_gc_tag(map) == DEVS_GC_TAG_HALF_STATIC_MAP && devs_is_builtin_proto(map->proto))
-            return devs_maplike_to_value(ctx, map->proto);
-        return devs_value_from_gc_obj(ctx, map);
-    }
-}
\ No newline at end of file
diff --git a/controllers/aici_abi/implementation.md b/controllers/aici_abi/implementation.md
deleted file mode 100644
index 1fadb63c..00000000
--- a/controllers/aici_abi/implementation.md
+++ /dev/null
@@ -1,60 +0,0 @@
-# Implementation notes
-
-## LR(1) parsing
-
-The LR(1) parsing consists of a DFA-based lexer and the actual LR(1) parser.
-The DFA has a single number as its state, while the state of the LR(1) parser is a stack of numbers.
-The LR(1) action is determined based on the next token from the lexer and the top of the stack.
-
-The `Recognizer` interface also has a concept of a stack; however, every entry on that
-stack contains a DFA state and an LR(1) stack.
-
-Most of the time (~98.5% for the C grammar), pushing a byte involves only updating the DFA state,
-while the LR(1) stack is copied unchanged (the memory is shared).
-
-
-### Early error detection
-
-Consider the following invalid C program:
-
-```c
-int 123456;
-```
-
-The lexer would produce the `int` keyword, whitespace, the `123456` constant, and the `;` token.
-The parser would reject `123456`, but only after all six of its characters have been read.
-This is too late for the LLM.
-
-To detect such errors early, we compute a set of reachable tokens for each DFA state.
-For example, consider a DFA that recognizes `int`, `if`, `ID` (`/[a-z][a-z0-9]*/`) and `INTLIT` (`/[0-9]+/`).
-The initial DFA state has a full set of tokens, while a state after `'i'`
-has only `int`, `if`, and `ID`,
-and a state after `'1'` includes only `INTLIT`.
-In the picture below, each state is labelled by its reachable set,
-and the token for which it is a match (if any) is postfixed with `*`. We only use lower-case letters and digits for simplicity.
-
-```mermaid
-graph LR
-    0["{int,if,ID,INTLIT}"] -- "[i]" --> i(("{int,if,ID*}"))
-    0 -- "[a-z] - [i]" --> id(("{ID*}"))
-    0 -- "[0-9]" --> const(("{INTLIT*}"))
-    const -- "[0-9]" --> const
-    const -- "[a-z]" --> bot["{}"]
-    i -- "[a-z0-9] - [nf]" --> id
-    id -- "[a-z0-9]" --> id
-    i -- "[n]" --> in(("{int,ID*}"))
-    in -- "[t]" --> int(("{int*,ID}"))
-    in -- "[a-z0-9] - [t]" --> id
-    int -- "[a-z0-9]" --> id
-    i -- "[f]" --> if(("{if*,ID}"))
-    if -- "[a-z0-9]" --> id
-```
-
-For each LR(1) automaton state we compute a set of viable tokens, i.e., ones that do
-not immediately lead to an error.
-
-While parsing input, if the intersection of viable and reachable tokens is empty, we report an error.
-
-In the example above, the viable tokens after `int` do not include `INTLIT`,
-and thus the parser fails immediately at `1`.
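
A minimal sketch of this check, assuming at most 64 token kinds so each set fits in a `u64` bit-mask; the `Dfa` struct and `advance_byte` below are illustrative names, not code from this repository — the actual `cfg.rs`/`lex.rs` further down intern these sets as `Vob` bit-vectors and test the intersection with `VobSet::and_is_zero`:

```rust
/// Sketch of early error detection; assumes at most 64 token kinds.
struct Dfa {
    next: Vec<[u16; 256]>, // state -> input byte -> next state
    reachable: Vec<u64>,   // state -> bit-set of tokens still matchable from it
}

/// Advance the lexer by one byte; reject as soon as no token the lexer can
/// still produce is viable for the parser in its current LR(1) state.
fn advance_byte(dfa: &Dfa, state: u16, viable_tokens: u64, byte: u8) -> Option<u16> {
    let next = dfa.next[state as usize][byte as usize];
    if dfa.reachable[next as usize] & viable_tokens == 0 {
        None // report the error now, not after the whole lexeme has been read
    } else {
        Some(next)
    }
}
```

In the `int 123456;` example, once the parser has shifted `int`, `viable_tokens` no longer contains `INTLIT`, so feeding the byte `'1'` moves the DFA to a state whose reachable set is `{INTLIT}`, the intersection becomes empty, and `advance_byte` returns `None` immediately.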
- diff --git a/controllers/aici_abi/src/cfg.rs b/controllers/aici_abi/src/cfg.rs deleted file mode 100644 index c0fb412e..00000000 --- a/controllers/aici_abi/src/cfg.rs +++ /dev/null @@ -1,597 +0,0 @@ -use crate::host::host_trie; -use crate::lex::{Lexer, LexerState, StateID, VobIdx, VobSet}; -use crate::{ - toktrie::{Recognizer, SpecialToken}, - SimpleVob, -}; -use anyhow::Result; -use cfgrammar::{ - yacc::{YaccGrammar, YaccKind}, - Span, Spanned, Symbol, TIdx, -}; -use lrtable::{from_yacc, Action, Minimiser, StIdx, StateTable}; -use rustc_hash::FxHashMap; -use std::{cell::RefCell, vec}; -use vob::{vob, Vob}; - -type StorageT = u32; -type PStack = Vec>; // Parse stack - -const LOG_PARSER: bool = false; - -#[derive(Debug, Clone, Copy)] -enum ParseResult { - Accept, - Error, - Continue, -} - -struct CfgStats { - yacc_actions: usize, - states_pushed: usize, -} - -pub struct CfgParser { - grm: YaccGrammar, - stable: StateTable, - lexer: Lexer, - byte_states: Vec, - pat_idx_to_tidx: Vec>, - vobset: VobSet, - stats: RefCell, - tidx_to_pat_idx: FxHashMap, usize>, - parse_stacks: Vec>>, - skip_patterns: Vob, - friendly_pattern_names: Vec, - viable_vobidx_by_state: Vec, -} - -fn is_rx(name: &str) -> bool { - name.len() > 2 && name.starts_with("/") && name.ends_with("/") -} - -fn quote_rx(name: &str) -> String { - name.chars() - .map(|ch| { - if ('0' <= ch && ch <= '9') - || ('a' <= ch && ch <= 'z') - || ('A' <= ch && ch <= 'Z') - || '<' == ch - || '>' == ch - { - ch.to_string() - } else { - format!("\\{}", ch) - } - }) - .collect::() -} - -pub(crate) fn parse_rx_token(name: &str) -> String { - if is_rx(name) { - name[1..name.len() - 1].to_string() - } else { - quote_rx(name) - } -} - -fn span_to_str(s: &Span, src: &str) -> String { - let mut line = 1; - let mut last_nl = 0; - for (idx, ch) in src.chars().enumerate() { - if idx == s.start() { - break; - } - if ch == '\n' { - line += 1; - last_nl = idx; - } - } - let column = s.start() - last_nl; - format!("({},{})", line, column) -} - -pub(crate) fn parse_yacc(yacc: &str) -> Result { - let grmkind = YaccKind::Original(cfgrammar::yacc::YaccOriginalActionKind::NoAction); - let grm = match YaccGrammar::new(grmkind, yacc) { - Ok(grm) => grm, - Err(e) => { - let err_str = e - .iter() - .map(|e| { - let spans = e - .spans() - .iter() - .map(|s| span_to_str(s, yacc)) - .collect::>() - .join(", "); - format!("{}: {}", spans, e) - }) - .collect::>() - .join("\n"); - anyhow::bail!("yacc grammar errors:\n{}", err_str); - } - }; - Ok(grm) -} - -impl CfgParser { - pub fn from_yacc(yacc: &str) -> Result { - let grm = parse_yacc(yacc)?; - // TIME: all these annotation are for native release x86 build for C grammar - // TIME: 27ms - let (sgraph, stable) = match from_yacc(&grm, Minimiser::Pager) { - Ok(r) => r, - Err(e) => { - if false { - // not sure this works: - anyhow::bail!("state table error:\n{e} on {:?}", grm.action(e.pidx)); - } - anyhow::bail!("state table error:\n{e}"); - } - }; - - if false { - println!("core\n{}\n\n", sgraph.pp(&grm, true)); - for pidx in grm.iter_pidxs() { - let prod = grm.prod(pidx); - println!("{:?} -> {}", prod, prod.len()); - } - } - - let mut pat_idx_to_tidx = grm - .iter_tidxs() - .filter(|tidx| grm.token_name(*tidx).is_some()) - .collect::>(); - - pat_idx_to_tidx.sort_by_key(|tidx| { - let name = grm.token_name(*tidx).unwrap(); - let l = name.len() as isize; - if is_rx(name) { - -l + 100000 - } else { - -l - } - }); - - let patterns = pat_idx_to_tidx - .iter() - .map(|tok| parse_rx_token(grm.token_name(*tok).unwrap())) - 
.collect::>(); - - let mut tidx_to_pat_idx = FxHashMap::default(); - for (idx, _tok) in patterns.iter().enumerate() { - tidx_to_pat_idx.insert(pat_idx_to_tidx[idx], idx); - } - - let mut skip_patterns = vob![false; patterns.len()]; - let friendly_pattern_names = pat_idx_to_tidx - .iter() - .map(|tok| grm.token_name(*tok).unwrap().to_string()) - .collect::>(); - - for ridx in grm.iter_rules() { - let rule_name = grm.rule_name_str(ridx); - if rule_name.to_uppercase() != rule_name { - continue; - } - for pidx in grm.rule_to_prods(ridx) { - let toks = grm.prod(*pidx); - if let [Symbol::Token(tidx)] = toks { - let idx = *tidx_to_pat_idx.get(&tidx).unwrap(); - // this doesn't seem very useful - // friendly_pattern_names[idx] = rule_name.to_string(); - if rule_name == "SKIP" { - skip_patterns.set(idx, true); - } - } - } - } - - println!("patterns: {:?}", friendly_pattern_names); - - let mut vobset = VobSet::new(); - // all-zero has to be inserted first - let _all0 = vobset.insert_or_get(&vob![false; patterns.len()]); - let all1 = vobset.insert_or_get(&vob![true; patterns.len()]); - - // TIME: 27ms - let dfa = Lexer::from(patterns, &mut vobset); - - let cfg_start = stable.start_state(); - let parse_stacks = vec![vec![cfg_start]]; - - let byte_state = ByteState { - lexer_state: dfa.file_start_state(), - parse_stack_idx: PStackIdx(0), - viable: all1, - }; - - let viable_vobidx_by_state = sgraph - .iter_stidxs() - .enumerate() - .map(|(idx, stidx)| { - assert!(idx == stidx.as_storaget() as usize); - - // skip patterns (whitespace) are always viable - let mut r = skip_patterns.clone(); - for tidx in stable.state_actions(stidx) { - match stable.action(stidx, tidx) { - Action::Error => {} - _ => { - if let Some(pat_idx) = tidx_to_pat_idx.get(&tidx) { - r.set(*pat_idx, true); - } - } - } - } - - vobset.insert_or_get(&r) - }) - .collect::>(); - - let mut cfg = CfgParser { - grm, - stable, - lexer: dfa, - byte_states: vec![byte_state], - pat_idx_to_tidx, - tidx_to_pat_idx, - viable_vobidx_by_state, - skip_patterns, - friendly_pattern_names, - parse_stacks, - vobset, - stats: RefCell::new(CfgStats { - yacc_actions: 0, - states_pushed: 0, - }), - }; - - cfg.vobset.pre_compute(); - - // compute viable set of initial tokens - cfg.byte_states[0].viable = cfg.viable_vobidx(cfg_start); - if LOG_PARSER { - println!( - "initial viable: {:?}", - cfg.vobset.resolve(cfg.byte_states[0].viable) - ); - } - - Ok(cfg) - } - - fn viable_vobidx(&self, stidx: StIdx) -> VobIdx { - self.viable_vobidx_by_state[stidx.as_storaget() as usize] - } - - #[allow(dead_code)] - fn friendly_token_name(&self, lexeme: TIdx) -> &str { - if let Some(pidx) = self.tidx_to_pat_idx.get(&lexeme) { - &self.friendly_pattern_names[*pidx] - } else if self.grm.eof_token_idx() == lexeme { - return ""; - } else { - return ""; - } - } - - fn parse_lexeme(&self, lexeme: TIdx, pstack: &mut PStack) -> ParseResult { - loop { - let stidx = *pstack.last().unwrap(); - - let act = self.stable.action(stidx, lexeme); - - if LOG_PARSER { - println!( - "parse: {:?} {:?} -> {:?}", - pstack, - self.friendly_token_name(lexeme), - act - ); - } - - match act { - Action::Reduce(pidx) => { - let ridx = self.grm.prod_to_rule(pidx); - let pop_idx = pstack.len() - self.grm.prod(pidx).len(); - pstack.drain(pop_idx..); - let prior = *pstack.last().unwrap(); - pstack.push(self.stable.goto(prior, ridx).unwrap()); - } - Action::Shift(state_id) => { - pstack.push(state_id); - return ParseResult::Continue; - } - Action::Accept => { - // only happens when lexeme is EOF - return 
ParseResult::Accept; - } - Action::Error => { - return ParseResult::Error; - } - } - } - } - - #[allow(dead_code)] - fn print_viable(&self, lbl: &str, vob: &Vob) { - println!("viable tokens {}:", lbl); - for (idx, b) in vob.iter().enumerate() { - if b { - println!(" {}: {}", idx, self.friendly_pattern_names[idx]); - } - } - } - - // None means EOF - #[inline(always)] - fn try_push(&mut self, byte: Option) -> Option { - let top = self.byte_states.last().unwrap().clone(); - if LOG_PARSER { - print!("try_push[{}]: ", self.byte_states.len()); - if let Some(b) = byte { - print!("{:?}", b as char) - } else { - print!("") - } - } - let (info, res) = match self.lexer.advance(top.lexer_state, byte) { - // Error? - None => ("lex-err", None), - // Just new state, no token - the hot path - Some((ls, None)) => ( - "lex", - self.mk_byte_state(ls, top.parse_stack_idx, top.viable), - ), - // New state and token generated - Some((ls, Some(pat_idx))) => ("parse", self.run_parser(pat_idx, &top, ls)), - }; - if LOG_PARSER { - println!( - " -> {} {}", - info, - if res.is_none() { "error" } else { "ok" } - ); - } - res - } - - fn pstack_for(&self, top: &ByteState) -> &PStack { - &self.parse_stacks[top.parse_stack_idx.0] - } - - fn push_pstack(&mut self, top: &ByteState, pstack: Vec>) -> PStackIdx { - let new_idx = PStackIdx(top.parse_stack_idx.0 + 1); - if self.parse_stacks.len() <= new_idx.0 { - self.parse_stacks.push(Vec::new()); - } - self.parse_stacks[new_idx.0] = pstack; - new_idx - } - - fn run_parser(&mut self, pat_idx: usize, top: &ByteState, ls: LexerState) -> Option { - { - let mut s = self.stats.borrow_mut(); - s.yacc_actions += 1; - } - if LOG_PARSER { - println!(); - } - let pstack = self.pstack_for(top); - if self.skip_patterns[pat_idx] { - let stidx = *pstack.last().unwrap(); - let viable = self.viable_vobidx(stidx); - //self.print_viable("reset", &viable); - if LOG_PARSER { - println!("parse: {:?} skip", pstack); - } - // reset viable states - they have been narrowed down to SKIP - self.mk_byte_state(ls, top.parse_stack_idx, viable) - } else { - let tidx = self.pat_idx_to_tidx[pat_idx]; - let mut pstack = pstack.clone(); - match self.parse_lexeme(tidx, &mut pstack) { - ParseResult::Accept => panic!("accept non EOF?"), - ParseResult::Continue => { - let stidx = *pstack.last().unwrap(); - let viable = self.viable_vobidx(stidx); - let new_idx = self.push_pstack(top, pstack); - self.mk_byte_state(ls, new_idx, viable) - } - ParseResult::Error => None, - } - } - } - - #[allow(dead_code)] - pub fn viable_now(&self) { - let v = self.byte_states.last().unwrap().viable; - self.print_viable("now", self.vobset.resolve(v)) - } - - pub fn get_stats(&self) -> String { - let mut s = self.stats.borrow_mut(); - let r = format!("yacc: {}/{}", s.yacc_actions, s.states_pushed); - s.yacc_actions = 0; - s.states_pushed = 0; - r - } - - fn mk_byte_state( - &self, - ls: LexerState, - pstack: PStackIdx, - viable: VobIdx, - ) -> Option { - { - let mut s = self.stats.borrow_mut(); - s.states_pushed += 1; - } - if self.vobset.and_is_zero(viable, ls.reachable) { - None - } else { - // print!( - // " {:?} {:?} ", - // self.vobset.resolve(viable), - // self.vobset.resolve(ls.reachable) - // ); - Some(ByteState { - lexer_state: ls.state, - parse_stack_idx: pstack, - viable, - }) - } - } -} - -#[derive(Clone, Copy)] -struct PStackIdx(usize); - -#[derive(Clone)] -struct ByteState { - lexer_state: StateID, - parse_stack_idx: PStackIdx, - viable: VobIdx, -} - -impl Recognizer for CfgParser { - fn pop_bytes(&mut self, num: usize) { 
- self.byte_states.truncate(self.byte_states.len() - num); - } - - fn collapse(&mut self) { - let final_state = self.byte_states.pop().unwrap(); - self.byte_states.clear(); - self.byte_states.push(final_state); - } - - fn special_allowed(&mut self, tok: SpecialToken) -> bool { - match tok { - SpecialToken::EndOfSentence => { - if let Some(st) = self.try_push(None) { - let tidx = self.grm.eof_token_idx(); - let mut pstack = self.pstack_for(&st).clone(); - match self.parse_lexeme(tidx, &mut pstack) { - ParseResult::Accept => true, - _ => false, - } - } else { - false - } - } - _ => false, - } - } - - fn trie_finished(&mut self) { - assert!(self.byte_states.len() == 1); - } - - #[inline(always)] - fn try_push_byte(&mut self, byte: u8) -> bool { - if let Some(st) = self.try_push(Some(byte)) { - self.byte_states.push(st); - true - } else { - false - } - } -} - -#[allow(dead_code)] -pub fn cfg_test() -> Result<()> { - let yacc_bytes = include_bytes!("../grammars/c.y"); - let mut cfg = CfgParser::from_yacc(&String::from_utf8_lossy(yacc_bytes)).unwrap(); - let sample = include_bytes!("../grammars/sample.c"); - - if true { - let trie = host_trie(); - let toks = trie.greedy_tokenize(sample); - - #[cfg(not(target_arch = "wasm32"))] - let t0 = std::time::Instant::now(); - - let mut line = 1; - let mut vob = SimpleVob::new(); - vob.resize(trie.vocab_size() + 1); - - for tok in &toks[0..1000] { - let tok = *tok; - trie.compute_bias(&mut cfg, &mut vob); - if !vob.is_allowed(tok) { - println!("reject, line={}, tok={:?}", line, trie.token_str(tok)); - panic!(); - } - for b in trie.token(tok) { - if *b == b'\n' { - line += 1; - } - } - if false { - println!( - "tok: {:?} {}; {}", - trie.token_str(tok), - vob.is_allowed(tok), - cfg.get_stats() - ); - cfg.viable_now(); - } - trie.append_token(&mut cfg, tok).unwrap(); - } - - #[cfg(not(target_arch = "wasm32"))] - println!("time: {:?} ", t0.elapsed()); - - println!("stats: {}", cfg.get_stats()); - } - - if false { - let mut rng = crate::rng::Rng::new(0); - let mut ok = true; - let mut idx = 0; - while idx < sample.len() { - let b = sample[idx]; - // println!("idx {} {:?}", idx, b as char); - let r = cfg.try_push_byte(b); - if !r { - ok = false; - println!( - "reject at\n{:?}\n{:?}", - String::from_utf8_lossy(&sample[idx.saturating_sub(50)..idx]), - String::from_utf8_lossy(&sample[idx..std::cmp::min(idx + 30, sample.len())]) - ); - break; - } - idx += 1; - - if false { - let max_pop = cfg.byte_states.len() - 1; - if max_pop > 0 && rng.gen_up_to(4) == 0 { - let num = rng.gen_up_to(max_pop - 1) + 1; - // println!("pop {} {}", num, cfg.byte_states.len()); - cfg.pop_bytes(num); - idx -= num; - } - - if rng.gen_up_to(10) == 0 { - // println!("collapse"); - cfg.collapse(); - } - } - } - - if ok { - if cfg.special_allowed(SpecialToken::EndOfSentence) { - println!("accept EOS"); - } else { - println!("reject EOS"); - } - } else { - println!("reject"); - } - } - - Ok(()) -} diff --git a/controllers/aici_abi/src/dlex.rs b/controllers/aici_abi/src/dlex.rs deleted file mode 100644 index df275fb7..00000000 --- a/controllers/aici_abi/src/dlex.rs +++ /dev/null @@ -1,266 +0,0 @@ -use crate::{ - recognizer::{FunctionalRecognizer, StackRecognizer}, - toktrie::SpecialToken, - SimpleVob, -}; - -#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)] -pub struct NodeId(u32); - -impl NodeId { - const NULL: NodeId = NodeId(0); - const ROOT: NodeId = NodeId(1); -} - -#[derive(Debug, Default, Clone)] -pub struct NodeData { - pub is_terminal: bool, -} - -enum TrieNode { - 
Sparse { - data: NodeData, - children: Vec<(u8, NodeId)>, - }, - Dense { - data: NodeData, - children: Vec, - }, -} - -impl TrieNode { - fn new_dense(data: NodeData, children: &Vec<(u8, NodeId)>) -> Self { - let mut dense_children = vec![NodeId::NULL; 256]; - for (byte, node_id) in children { - dense_children[*byte as usize] = *node_id; - } - TrieNode::Dense { - data, - children: dense_children, - } - } - - fn new_leaf() -> Self { - TrieNode::Sparse { - data: NodeData::default(), - children: vec![], - } - } - - fn data(&self) -> &NodeData { - match self { - TrieNode::Sparse { data, .. } => data, - TrieNode::Dense { data, .. } => data, - } - } - - fn data_mut(&mut self) -> &mut NodeData { - match self { - TrieNode::Sparse { data, .. } => data, - TrieNode::Dense { data, .. } => data, - } - } -} - -pub struct Trie { - nodes: Vec, -} - -impl Trie { - const MAX_SPARSE: usize = 8; - - pub fn new() -> Self { - Trie { - nodes: vec![ - TrieNode::new_leaf(), - TrieNode::new_dense(NodeData::default(), &vec![]), - ], - } - } - - fn node(&self, node_id: NodeId) -> &TrieNode { - &self.nodes[node_id.0 as usize] - } - - fn node_mut(&mut self, node_id: NodeId) -> &mut TrieNode { - &mut self.nodes[node_id.0 as usize] - } - - pub fn node_data(&self, node_id: NodeId) -> &NodeData { - self.node(node_id).data() - } - - pub fn root(&self) -> NodeId { - NodeId::ROOT - } - - pub fn child_at(&self, start: NodeId, b: u8) -> Option { - match self.node(start) { - TrieNode::Sparse { children, .. } => { - children.iter().find_map( - |&(byte, node_id)| { - if byte == b { - Some(node_id) - } else { - None - } - }, - ) - } - TrieNode::Dense { children, .. } => { - let node_id = children[b as usize]; - if node_id == NodeId::NULL { - None - } else { - Some(node_id) - } - } - } - } - - pub fn lookup(&self, start: NodeId, word: &[u8]) -> Option { - let mut node_id = start; - for &byte in word { - match self.child_at(node_id, byte) { - Some(child_id) => { - node_id = child_id; - } - None => { - return None; - } - } - } - Some(node_id) - } - - pub fn add(&mut self, word: &[u8]) { - let mut node_id = NodeId::ROOT; - for &byte in word { - let new_node_id = NodeId(self.nodes.len() as u32); - let node = self.node_mut(node_id); - match node { - TrieNode::Sparse { data, children } => { - match children.iter().find(|&&(b, _)| b == byte) { - Some(&(_, child_id)) => { - node_id = child_id; - } - None => { - children.push((byte, new_node_id)); - if children.len() > Trie::MAX_SPARSE { - self.nodes[node_id.0 as usize] = - TrieNode::new_dense(data.clone(), children); - } - self.nodes.push(TrieNode::new_leaf()); - node_id = new_node_id; - } - } - } - TrieNode::Dense { children, .. 
} => { - node_id = children[byte as usize]; - if node_id == NodeId::NULL { - children[byte as usize] = new_node_id; - self.nodes.push(TrieNode::new_leaf()); - node_id = new_node_id; - } - } - } - } - - self.node_mut(node_id).data_mut().is_terminal = true; - } -} - -pub struct DynamicLexer { - trie: Trie, - id_start: SimpleVob, - id_body: SimpleVob, -} - -#[derive(Debug, Clone, Copy)] -pub struct DState { - node_id: NodeId, -} - -impl DState { - const ROOT: DState = DState { - node_id: NodeId::ROOT, - }; -} - -pub type DynamicLexerRec = StackRecognizer; - -impl DynamicLexer { - pub fn new(additional_id_chars: &Vec) -> Self { - let mut id_start = SimpleVob::alloc(0x100); - let mut id_body = SimpleVob::alloc(0x100); - for i in 0..=255u8 { - match i as char { - 'a'..='z' | 'A'..='Z' | '_' => { - id_start.allow_token(i as u32); - id_body.allow_token(i as u32); - } - '0'..='9' => { - id_body.allow_token(i as u32); - } - _ => {} - } - } - for &c in additional_id_chars { - id_start.allow_token(c as u32); - id_body.allow_token(c as u32); - } - DynamicLexer { - trie: Trie::new(), - id_start, - id_body, - } - } - - pub fn to_stack_recognizer(self) -> StackRecognizer { - StackRecognizer::from(self) - } - - pub fn add(&mut self, word: &[u8]) { - self.trie.add(word); - } -} - -impl FunctionalRecognizer for DynamicLexer { - fn initial(&self) -> DState { - DState::ROOT - } - - fn try_append(&self, state: DState, byte: u8) -> Option { - if state.node_id == NodeId::ROOT { - if self.id_start.is_allowed(byte as u32) { - match self.trie.child_at(state.node_id, byte) { - Some(node_id) => Some(DState { node_id }), - None => None, - } - } else { - Some(state) - } - } else { - if self.id_body.is_allowed(byte as u32) { - match self.trie.child_at(state.node_id, byte) { - Some(node_id) => Some(DState { node_id }), - None => None, - } - } else { - if self.trie.node_data(state.node_id).is_terminal { - Some(DState::ROOT) - } else { - None - } - } - } - } - - fn special_allowed(&self, state: DState, tok: SpecialToken) -> bool { - if tok == SpecialToken::EndOfSentence { - self.trie.node_data(state.node_id).is_terminal - } else { - false - } - } -} diff --git a/controllers/aici_abi/src/host.rs b/controllers/aici_abi/src/host.rs deleted file mode 100644 index 222666e8..00000000 --- a/controllers/aici_abi/src/host.rs +++ /dev/null @@ -1,383 +0,0 @@ -use crate::{bytes::vec_from_bytes, toktrie::TokTrie, SeqId, SimpleVob, TokenId}; -use serde::{Deserialize, Serialize}; -use toktrie::TokenizerEnv; - -#[repr(transparent)] -#[derive(Clone, Copy, Debug, PartialEq, Eq)] -struct BlobId(u32); - -#[allow(dead_code)] -extern "C" { - // Read binary blob. - // Always returns the size of the blob, will write up to `size` bytes to `dst`. - fn aici_host_read_blob(blob: BlobId, dst: *mut u8, size: u32) -> u32; - - // Return the ID of TokTrie binary representation. - fn aici_host_token_trie() -> BlobId; - - // Return the ID of argument passed by the user. - fn aici_host_module_arg() -> BlobId; - - // Return the ID of argument passed to the process() function. - // It's a JSON serialization of Pre/Mid/PostProcessArg. - fn aici_host_process_arg() -> BlobId; - - // Tokenize given UTF8 string. The result is only valid until next call to this function. - fn aici_host_tokenize(src: *const u8, src_size: u32) -> BlobId; - - // Set logit bias based on bit-mask in src. 
- fn aici_host_return_logit_bias(src: *const u32) -> u32; - - fn aici_host_self_seq_id() -> u32; - - fn aici_host_return_process_result(res: *const u8, res_size: u32); - - fn aici_host_storage_cmd(cmd: *const u8, cmd_size: u32) -> BlobId; - - // This can be also obtained from the TokTrie. - fn aici_host_eos_token() -> TokenId; - - // Get value of configuration parameters, like "fork". - fn aici_host_get_config(src: *const u8, src_size: u32) -> i32; - - // Stop the program - any error info is assumed to have been printed already. - // Backtraces will be limited. - fn aici_host_stop(); -} - -// TODO: add -fn read_blob(blob: BlobId, prefetch_size: usize) -> Vec { - let mut buffer = vec![0u8; prefetch_size]; - let prefetch_size = prefetch_size as u32; - let size = unsafe { aici_host_read_blob(blob, buffer.as_mut_ptr(), prefetch_size) }; - buffer.resize(size as usize, 0); - if size > prefetch_size { - // didn't read everything; retry - unsafe { aici_host_read_blob(blob, buffer.as_mut_ptr(), size) }; - } - buffer -} - -#[cfg(target_arch = "wasm32")] -fn init_panic() { - std::panic::set_hook(Box::new(|info| { - // skip 'run with `RUST_BACKTRACE=1`' message (not relevant for remote running) - println!("{}", info); - })) -} - -#[cfg(target_arch = "wasm32")] -#[no_mangle] -pub extern "C" fn aici_init() { - init_panic(); - set_host(Box::new(WasmHost {})); -} - -pub struct WasmTokenizerEnv { - toktrie: TokTrie, -} - -impl Default for WasmTokenizerEnv { - fn default() -> Self { - WasmTokenizerEnv { - toktrie: host_trie(), - } - } -} - -impl TokenizerEnv for WasmTokenizerEnv { - fn stop(&self) -> ! { - aici_stop() - } - - fn tok_trie(&self) -> &TokTrie { - &self.toktrie - } - - fn tokenize_bytes(&self, s: &[u8]) -> Vec { - tokenize_bytes(s) - } -} - -/** - * This is normally implemented straightforwardly by wasm callbacks. - * It can be overridden with set_host() when compiling to native. - */ -pub trait HostInterface { - fn arg_bytes(&self) -> Vec; - fn trie_bytes(&self) -> Vec; - fn return_logit_bias(&self, vob: &SimpleVob) -> u32; - fn process_arg_bytes(&self) -> Vec; - fn return_process_result(&self, res: &[u8]); - fn storage_cmd(&self, cmd: StorageCmd) -> StorageResp; - fn tokenize_bytes(&self, s: &[u8]) -> Vec; - fn self_seq_id(&self) -> SeqId; - fn eos_token(&self) -> TokenId; - fn get_config(&self, name: &str) -> i32; - fn stop(&self) -> !; -} - -static mut HOST: Option> = None; - -struct WasmHost {} -impl HostInterface for WasmHost { - fn arg_bytes(&self) -> Vec { - read_blob(unsafe { aici_host_module_arg() }, 1024) - } - - fn trie_bytes(&self) -> Vec { - read_blob(unsafe { aici_host_token_trie() }, 0) - } - - fn return_logit_bias(&self, vob: &SimpleVob) -> u32 { - assert!(vob.len() > 0); - unsafe { aici_host_return_logit_bias(vob.as_ptr()) } - } - - fn process_arg_bytes(&self) -> Vec { - read_blob(unsafe { aici_host_process_arg() }, 1024) - } - - fn return_process_result(&self, res: &[u8]) { - unsafe { - aici_host_return_process_result(res.as_ptr(), res.len() as u32); - } - } - - fn storage_cmd(&self, cmd: StorageCmd) -> StorageResp { - let cmd_bytes = serde_json::to_vec(&cmd).unwrap(); - let res_id = unsafe { aici_host_storage_cmd(cmd_bytes.as_ptr(), cmd_bytes.len() as u32) }; - let resp_bytes = read_blob(res_id, 1024); - serde_json::from_slice(&resp_bytes).unwrap() - } - - fn stop(&self) -> ! 
{ - unsafe { aici_host_stop() }; - panic!("didn't stop") - } - - fn tokenize_bytes(&self, s: &[u8]) -> Vec { - let id = unsafe { aici_host_tokenize(s.as_ptr(), s.len() as u32) }; - let r = read_blob(id, 4 * (s.len() / 3 + 10)); - let res = vec_from_bytes(&r); - // println!( - // "tokenize_bytes: {:?} -> {:?}", - // String::from_utf8_lossy(s), - // res - // ); - res - } - - fn self_seq_id(&self) -> SeqId { - unsafe { SeqId(aici_host_self_seq_id()) } - } - - fn eos_token(&self) -> TokenId { - unsafe { aici_host_eos_token() } - } - - fn get_config(&self, name: &str) -> i32 { - let name_bytes = name.as_bytes(); - let res = unsafe { aici_host_get_config(name_bytes.as_ptr(), name_bytes.len() as u32) }; - res - } -} - -fn get_host() -> &'static Box { - unsafe { HOST.as_ref().unwrap() } -} - -pub fn set_host(host: Box) { - unsafe { - assert!(HOST.is_none()); - HOST = Some(host); - } -} - -pub fn arg_bytes() -> Vec { - get_host().arg_bytes() - - // #[cfg(not(target_arch = "wasm32"))] - // return std::fs::read("arg.json").unwrap(); -} - -pub fn arg_string() -> String { - String::from_utf8_lossy(&arg_bytes()).to_string() -} - -pub fn host_trie() -> TokTrie { - TokTrie::from_bytes(&get_host().trie_bytes()) - // #[cfg(not(target_arch = "wasm32"))] - // return std::fs::read("tokenizer.bin").unwrap(); -} - -pub fn return_logit_bias(vob: &SimpleVob) -> u32 { - get_host().return_logit_bias(vob) -} - -pub fn process_arg_bytes() -> Vec { - get_host().process_arg_bytes() -} - -pub fn return_process_result(res: &[u8]) { - unsafe { - aici_host_return_process_result(res.as_ptr(), res.len() as u32); - } -} - -pub fn get_config(name: &str) -> i32 { - get_host().get_config(name) -} - -#[derive(Serialize, Deserialize, Debug, Clone)] -pub enum StorageOp { - Set, - Append, -} - -#[allow(dead_code)] -pub mod bin_string { - use serde::{Deserialize, Deserializer, Serialize, Serializer}; - - pub fn serialize(v: &Vec, s: S) -> Result { - let binstr = String::from_iter(v.iter().map(|b| *b as char)); - String::serialize(&binstr, s) - } - - pub fn deserialize<'de, D: Deserializer<'de>>(d: D) -> Result, D::Error> { - let binstr = String::deserialize(d)?; - Ok(binstr.chars().map(|c| c as u8).collect()) - } -} - -pub mod hex_string { - use serde::{Deserialize, Deserializer, Serialize, Serializer}; - - use crate::bytes::{from_hex_string, to_hex_string}; - - pub fn serialize(v: &Vec, s: S) -> Result { - let hexstr = to_hex_string(v); - String::serialize(&hexstr, s) - } - - pub fn deserialize<'de, D: Deserializer<'de>>(d: D) -> Result, D::Error> { - let hexstr = String::deserialize(d)?; - from_hex_string(&hexstr).map_err(serde::de::Error::custom) - } -} - -#[derive(Serialize, Deserialize, Debug, Clone)] -pub enum StorageCmd { - /// Read variable. Returns StorageResp::ReadVar or StorageResp::VariableMissing. - ReadVar { name: String }, - - /// Write variable. - /// If `when_version_is == None`, always writes the variable and returns StorageResp::WriteVar. - /// Otherwise, if the variable has the specified version, it writes the variable - /// and returns StorageResp::WriteVar. - /// Otherwise (version conflict), returns either StorageResp::ReadVar or StorageResp::VariableMissing - /// just like ReadVar would. - WriteVar { - name: String, - #[serde(with = "hex_string")] - value: Vec, - op: StorageOp, - when_version_is: Option, - }, -} - -#[derive(Serialize, Deserialize, Debug)] -pub enum StorageResp { - /// Upon handling the request the variable had the specified value and version number. 
- ReadVar { - version: u64, - #[serde(with = "hex_string")] - value: Vec, - }, - /// Upon handling the request the variable was unset. - VariableMissing {}, - /// The variable has been written, and the new version is returned. - WriteVar { version: u64 }, -} - -pub fn storage_cmd(cmd: StorageCmd) -> StorageResp { - let cmd_bytes = serde_json::to_vec(&cmd).unwrap(); - let res_id = unsafe { aici_host_storage_cmd(cmd_bytes.as_ptr(), cmd_bytes.len() as u32) }; - let resp_bytes = read_blob(res_id, 1024); - serde_json::from_slice(&resp_bytes).unwrap() -} - -// Public APIs - -pub struct VariableStorage { - // no fields (yet?) -} - -impl VariableStorage { - /// Create a new instance of VariableStorage. It currently has no fields. - pub fn new() -> Self { - VariableStorage {} - } - - /// Read variable. Returns None if the variable is unset. - pub fn get(&self, name: &str) -> Option> { - self.get_with_version(name).map(|x| x.1) - } - - /// Write specified value to variable. - pub fn set(&self, name: &str, value: Vec) { - let _ver = self.write_var(name, value, StorageOp::Set); - } - - /// Append specified value to variable. - pub fn append(&self, name: &str, value: Vec) { - let _ver = self.write_var(name, value, StorageOp::Append); - } - - fn write_var(&self, name: &str, value: Vec, op: StorageOp) -> u64 { - match storage_cmd(StorageCmd::WriteVar { - name: name.to_string(), - value, - op, - when_version_is: None, - }) { - StorageResp::WriteVar { version } => version, - _ => panic!("unexpected response to write var"), - } - } - - fn get_with_version(&self, name: &str) -> Option<(u64, Vec)> { - match storage_cmd(StorageCmd::ReadVar { - name: name.to_string(), - }) { - StorageResp::ReadVar { version, value } => Some((version, value)), - StorageResp::VariableMissing {} => None, - StorageResp::WriteVar { .. } => panic!("unexpected response to read var"), - } - } -} - -/// Tokenize given byte string. -pub fn tokenize_bytes(s: &[u8]) -> Vec { - get_host().tokenize_bytes(s) -} - -/// Tokenize given UTF8 string. -pub fn tokenize(s: &str) -> Vec { - get_host().tokenize_bytes(s.as_bytes()) -} - -/// Return the ID of the current process. -pub fn self_seq_id() -> SeqId { - get_host().self_seq_id() -} - -/// Return the ID of the EOS token. -pub fn eos_token() -> TokenId { - get_host().eos_token() -} - -/// Stop the program - any error info is assumed to have been printed already. -pub fn aici_stop() -> ! 
{ - get_host().stop(); -} diff --git a/controllers/aici_abi/src/lex.rs b/controllers/aici_abi/src/lex.rs deleted file mode 100644 index 33f1c6bb..00000000 --- a/controllers/aici_abi/src/lex.rs +++ /dev/null @@ -1,349 +0,0 @@ -use regex_automata::{ - dfa::{dense, Automaton}, - util::syntax, -}; -use rustc_hash::FxHashMap; -use std::{hash::Hash, vec}; -use vob::{vob, Vob}; - -pub type PatIdx = usize; -pub type StateID = regex_automata::util::primitives::StateID; - -const LOG_LEXER: bool = false; - -// enabling this is slightly faster, but it requires ~ |lexer_states|*|parser_states| bits -const PRECOMPUTE_AND: bool = false; - -#[derive(Debug, Clone, Hash, PartialEq, Eq, Copy)] -pub struct LexerState { - pub state: StateID, - pub reachable: VobIdx, -} - -impl LexerState { - fn fake() -> Self { - LexerState { - state: StateID::default(), - reachable: VobIdx::all_zero(), - } - } -} - -#[derive(Debug, Clone, Hash, PartialEq, Eq, Copy)] -pub struct VobIdx { - v: u32, -} - -impl VobIdx { - pub fn new(v: usize) -> Self { - VobIdx { v: v as u32 } - } - - pub fn all_zero() -> Self { - VobIdx { v: 0 } - } - - pub fn as_usize(&self) -> usize { - self.v as usize - } - - pub fn is_zero(&self) -> bool { - self.v == 0 - } -} - -pub struct VobSet { - vobs: Vec, - by_vob: FxHashMap, - non_empty: Vob, -} - -impl VobSet { - pub fn new() -> Self { - VobSet { - vobs: Vec::new(), - by_vob: FxHashMap::default(), - non_empty: Vob::new(), - } - } - - pub fn insert_or_get(&mut self, vob: &Vob) -> VobIdx { - if let Some(idx) = self.by_vob.get(vob) { - return *idx; - } - let len = self.vobs.len(); - if len == 0 && !vob_is_zero(vob) { - panic!("first vob must be empty"); - } - let idx = VobIdx::new(len); - self.vobs.push(vob.clone()); - self.by_vob.insert(vob.clone(), idx); - idx - } - - pub fn resolve(&self, idx: VobIdx) -> &Vob { - &self.vobs[idx.as_usize()] - } - - pub fn and_is_zero(&self, a: VobIdx, b: VobIdx) -> bool { - if PRECOMPUTE_AND { - !self.non_empty[a.as_usize() * self.vobs.len() + b.as_usize()] - } else { - vob_and_is_zero(&self.vobs[a.as_usize()], &self.vobs[b.as_usize()]) - } - } - - pub fn pre_compute(&mut self) { - if PRECOMPUTE_AND { - let l = self.vobs.len(); - self.non_empty.resize(l * l, false); - for x in 0..self.vobs.len() { - for y in 0..=x { - if !vob_and_is_zero(&self.vobs[x], &self.vobs[y]) { - self.non_empty.set(x * l + y, true); - self.non_empty.set(y * l + x, true); - } - } - } - println!( - "vob set: {} VOBs, {} nonempty", - self.vobs.len(), - self.non_empty.len() - ); - } - } -} - -pub struct Lexer { - dfa: dense::DFA>, - initial: LexerState, - vobidx_by_state_off: Vec, -} - -impl Lexer { - pub fn from(patterns: Vec, vobset: &mut VobSet) -> Self { - // TIME: 4ms - let dfa = dense::Builder::new() - .configure( - dense::Config::new() - .start_kind(regex_automata::dfa::StartKind::Anchored) - .match_kind(regex_automata::MatchKind::All), - ) - .syntax(syntax::Config::new().unicode(false).utf8(false)) - .build_many(&patterns) - .unwrap(); - - println!( - "dfa: {} bytes, {} patterns", - dfa.memory_usage(), - patterns.len(), - ); - if false { - for p in &patterns { - println!(" {}", p) - } - } - - let anch = regex_automata::Anchored::Yes; - - let mut incoming = FxHashMap::default(); - let initial = dfa.universal_start_state(anch).unwrap(); - let mut todo = vec![initial]; - incoming.insert(initial, Vec::new()); - - // TIME: 1.5ms - while todo.len() > 0 { - let s = todo.pop().unwrap(); - for b in 0..=255 { - let s2 = dfa.next_state(s, b); - if !incoming.contains_key(&s2) { - todo.push(s2); - 
incoming.insert(s2, Vec::new()); - } - incoming.get_mut(&s2).unwrap().push(s); - } - } - - let states = incoming.keys().map(|x| *x).collect::>(); - let mut reachable_patterns = FxHashMap::default(); - - for s in &states { - let mut v = vob![false; patterns.len()]; - let s2 = dfa.next_eoi_state(*s); - if dfa.is_match_state(s2) { - for idx in 0..dfa.match_len(s2) { - let idx = dfa.match_pattern(s2, idx).as_usize(); - v.set(idx, true); - if LOG_LEXER { - println!(" match: {:?} {}", *s, patterns[idx]) - } - } - } - reachable_patterns.insert(*s, v); - } - - // TIME: 20ms - loop { - let mut num_set = 0; - - for s in &states { - let ours = reachable_patterns.get(s).unwrap().clone(); - for o in &incoming[s] { - let theirs = reachable_patterns.get(o).unwrap(); - let mut tmp = ours.clone(); - tmp |= theirs; - if tmp != *theirs { - num_set += 1; - reachable_patterns.insert(*o, tmp); - } - } - } - - if LOG_LEXER { - println!("iter {} {}", num_set, states.len()); - } - if num_set == 0 { - break; - } - } - - let mut states_idx = states.iter().map(|x| x.as_usize()).collect::>(); - states_idx.sort(); - - let shift = dfa.stride2(); - let mut vobidx_by_state_off = - vec![VobIdx::all_zero(); 1 + (states_idx.iter().max().unwrap() >> shift)]; - for (k, v) in reachable_patterns.iter() { - vobidx_by_state_off[k.as_usize() >> shift] = vobset.insert_or_get(v); - } - - println!("initial: {:?}; {} states", initial, states.len()); - - let mut lex = Lexer { - dfa, - vobidx_by_state_off, - initial: LexerState::fake(), - }; - - lex.initial = lex.mk_state(initial); - - if LOG_LEXER { - for s in &states { - if lex.is_dead(*s) { - println!("dead: {:?} {}", s, lex.dfa.is_dead_state(*s)); - } - } - - println!("reachable: {:#?}", reachable_patterns); - } - - lex - } - - pub fn file_start_state(&self) -> StateID { - self.initial.state - // pretend we've just seen a newline at the beginning of the file - // TODO: this should be configurable - // self.dfa.next_state(self.initial.state, b'\n') - } - - fn mk_state(&self, state: StateID) -> LexerState { - LexerState { - state, - reachable: self.reachable_tokens(state), - } - } - - fn is_dead(&self, state: StateID) -> bool { - self.reachable_tokens(state).is_zero() - } - - fn reachable_tokens(&self, state: StateID) -> VobIdx { - self.vobidx_by_state_off[state.as_usize() >> self.dfa.stride2()] - } - - fn get_token(&self, prev: StateID) -> Option { - let state = self.dfa.next_eoi_state(prev); - if !self.dfa.is_match_state(state) { - return None; - } - - // we take the first token that matched - // (eg., "while" will match both keyword and identifier, but keyword is first) - let pat_idx = (0..self.dfa.match_len(state)) - .map(|idx| self.dfa.match_pattern(state, idx).as_usize()) - .min() - .unwrap(); - - if LOG_LEXER { - println!("token: {}", pat_idx); - } - - Some(pat_idx) - } - - #[inline(always)] - pub fn advance(&self, prev: StateID, byte: Option) -> Option<(LexerState, Option)> { - let dfa = &self.dfa; - if let Some(byte) = byte { - let state = dfa.next_state(prev, byte); - if LOG_LEXER { - println!( - "lex: {:?} -{:?}-> {:?} d={}", - prev, - byte as char, - state, - self.is_dead(state), - ); - } - let v = self.reachable_tokens(state); - if v.is_zero() { - // if final_state is a match state, find the token that matched - let tok = self.get_token(prev); - if tok.is_none() { - None - } else { - let state = dfa.next_state(self.initial.state, byte); - if LOG_LEXER { - println!("lex0: {:?} -{:?}-> {:?}", self.initial, byte as char, state); - } - Some((self.mk_state(state), tok)) - } - } 
else { - Some(( - LexerState { - state, - reachable: v, - }, - None, - )) - } - } else { - let tok = self.get_token(prev); - if tok.is_none() { - None - } else { - Some((self.initial, tok)) - } - } - } -} - -fn vob_and_is_zero(a: &Vob, b: &Vob) -> bool { - debug_assert!(a.len() == b.len()); - for (a, b) in a.iter_storage().zip(b.iter_storage()) { - if a & b != 0 { - return false; - } - } - return true; -} - -fn vob_is_zero(v: &Vob) -> bool { - for b in v.iter_storage() { - if b != 0 { - return false; - } - } - true -} diff --git a/controllers/aici_abi/src/lib.rs b/controllers/aici_abi/src/lib.rs deleted file mode 100644 index 72024825..00000000 --- a/controllers/aici_abi/src/lib.rs +++ /dev/null @@ -1,236 +0,0 @@ -pub use toktrie; -pub use toktrie::{bytes, recognizer, rng}; -pub use toktrie::{SimpleVob, TokenizerEnv}; - -use serde::{Deserialize, Serialize}; - -mod host; - -#[cfg(feature = "cfg")] -pub mod cfg; -#[cfg(feature = "cfg")] -mod lex; - -#[cfg(feature = "rx")] -pub mod rx; - -pub mod dlex; - -pub mod substring; - -pub type TokenId = toktrie::TokenId; - -pub use host::{ - aici_stop, arg_bytes, arg_string, get_config, host_trie, self_seq_id, tokenize, tokenize_bytes, - StorageCmd, StorageOp, StorageResp, VariableStorage, WasmTokenizerEnv, -}; - -#[cfg(not(target_arch = "wasm32"))] -pub use host::{set_host, HostInterface}; - -#[derive(Serialize, Deserialize, Debug)] -pub struct InitPromptArg { - pub prompt: Vec, -} - -#[derive(Serialize, Deserialize, Debug)] -pub struct InitPromptResult { - pub prompt: Vec, -} - -impl InitPromptResult { - pub fn from_arg(arg: InitPromptArg) -> Self { - InitPromptResult { prompt: arg.prompt } - } -} - -#[repr(transparent)] -#[derive(Serialize, Deserialize, Debug, PartialEq, Eq)] -pub struct SeqId(pub u32); - -#[derive(Serialize, Deserialize, Debug)] -pub struct MidProcessArg { - /// Sampling result for the previous iteration. - /// For simple sampled token 't', backtrack==0 and tokens==[t]. - /// For first request, backtrack==0 and tokens==[] (prompt is passed separately, before). - /// Can be more complex when splices are used. - pub backtrack: u32, - pub tokens: Vec, - /// - pub fork_group: Vec, -} - -impl MidProcessArg { - pub fn has_eos(&self) -> bool { - let eos = host::eos_token(); - self.tokens.iter().any(|t| *t == eos) - } - - pub fn save_tokens(&self, acc_tokens: &mut Vec) { - let bt = self.backtrack as usize; - assert!( - bt <= acc_tokens.len(), - "attempting to backtrack past beginning" - ); - acc_tokens.truncate(acc_tokens.len() - bt); - acc_tokens.extend_from_slice(&self.tokens); - } -} - -pub use toktrie::{Branch, Splice}; - -#[derive(Debug)] -pub struct MidProcessResult { - /// Fork the request into multiple branches. - /// Typically, exactly one branch is returned. - /// If multiple branches are returned, they are executed in parallel. - /// If no branches are returned, the request is terminated. 
-    pub branches: Vec<Branch<SeqId>>,
-}
-
-impl MidProcessResult {
-    pub fn from_branch(branch: Branch<SeqId>) -> Self {
-        if branch.is_stop() {
-            Self::stop()
-        } else {
-            MidProcessResult {
-                branches: vec![branch],
-            }
-        }
-    }
-
-    pub fn stop() -> Self {
-        MidProcessResult { branches: vec![] }
-    }
-
-    pub fn sample(set: SimpleVob) -> Self {
-        Self::sample_with_temp(set, None)
-    }
-
-    pub fn sample_with_temp(set: SimpleVob, temperature: Option<f32>) -> Self {
-        Self::from_branch(Branch::sample(set, temperature))
-    }
-
-    pub fn splice(backtrack: u32, ff_tokens: Vec<TokenId>) -> Self {
-        Self::from_branch(Branch::splice(backtrack, ff_tokens))
-    }
-
-    pub fn noop() -> Self {
-        Self::splice(0, vec![])
-    }
-
-    pub fn is_stop(&self) -> bool {
-        self.branches.is_empty()
-    }
-}
-
-#[derive(Serialize, Deserialize)]
-pub struct ProcessResultOffset {
-    /// Branches use byte offsets into the bias tensor.
-    pub branches: Vec<Branch<usize>>,
-}
-
-pub trait AiciCtrl {
-    /// Called with the initial prompt. ~1000ms time limit.
-    /// The default implementation ignores the prompt.
-    fn init_prompt(&mut self, arg: InitPromptArg) -> InitPromptResult {
-        InitPromptResult::from_arg(arg)
-    }
-
-    /// This is the main entry point for the module. ~20ms time limit.
-    fn mid_process(&mut self, arg: MidProcessArg) -> MidProcessResult;
-
-    // Internals
-    fn aici_init_prompt(&mut self) {
-        let arg: InitPromptArg = serde_json::from_slice(&host::process_arg_bytes()).unwrap();
-        let res = self.init_prompt(arg);
-        let res_bytes = serde_json::to_vec(&res).unwrap();
-        host::return_process_result(&res_bytes);
-    }
-
-    fn aici_mid_process(&mut self) {
-        let arg: MidProcessArg = serde_json::from_slice(&host::process_arg_bytes())
-            .expect("aici_mid_process: failed to deserialize MidProcessArg");
-        let res = self.mid_process(arg);
-        let mut used_logits = false;
-        let res = ProcessResultOffset {
-            branches: res
-                .branches
-                .into_iter()
-                .map(|b| {
-                    b.map_mask(|vob| {
-                        if used_logits {
-                            panic!("aici_mid_process: multiple branches with sampling not yet supported");
-                        }
-                        used_logits = true;
-                        host::return_logit_bias(&vob) as usize
-                    })
-                })
-                .collect(),
-        };
-        let res_bytes = serde_json::to_vec(&res).expect("aici_mid_process: failed to serialize");
-        host::return_process_result(&res_bytes);
-    }
-}
-
-/// Expose a method as extern "C"; usage:
-///   expose!(Foo::set_count(n: i32) -> i32);
-/// generates the "C" function:
-///   set_count(Foo *, i32) -> i32
-#[macro_export]
-macro_rules! expose {
-    ($struct_name:ident :: $method_name:ident ( $($arg:ident : $typ:ty),* ) -> $ret:ty) => {
-        #[no_mangle]
-        pub extern "C" fn $method_name(self_: *mut $struct_name, $($arg : $typ),*) -> $ret {
-            unsafe {
-                (&mut *self_).$method_name($($arg),*)
-            }
-        }
-    };
-    ($struct_name:ident :: $field:ident :: $method_name:ident ( $($arg:ident : $typ:ty),* ) -> $ret:ty) => {
-        #[no_mangle]
-        pub extern "C" fn $method_name(self_: *mut $struct_name, $($arg : $typ),*) -> $ret {
-            unsafe {
-                (&mut *self_).$field.$method_name($($arg),*)
-            }
-        }
-    };
-}
-
-#[macro_export]
-macro_rules! aici_expose_all {
-    ($struct_name:ident, $new:expr) => {
-        $crate::expose!($struct_name::aici_mid_process() -> ());
-        $crate::expose!($struct_name::aici_init_prompt() -> ());
-
-        #[no_mangle]
-        pub extern "C" fn aici_create() -> *mut $struct_name {
-            let b = Box::new($new);
-            Box::into_raw(b)
-        }
-
-        #[no_mangle]
-        pub extern "C" fn aici_panic() {
-            panic!("aici_panic()")
-        }
-    }
-}
-
-#[macro_export]
-macro_rules!
include_bytes_aligned { - ($align_ty:ty, $path:literal) => {{ - #[repr(C)] // guarantee 'bytes' comes after '_align' - pub struct AlignedAs { - pub _align: [Align; 0], - pub bytes: Bytes, - } - - // this assignment is made possible by CoerceUnsized - static ALIGNED: &AlignedAs<$align_ty, [u8]> = &AlignedAs { - _align: [], - bytes: *include_bytes!($path), - }; - - &ALIGNED.bytes - }}; -} diff --git a/controllers/aici_abi/src/rx.rs b/controllers/aici_abi/src/rx.rs deleted file mode 100644 index 883fd05b..00000000 --- a/controllers/aici_abi/src/rx.rs +++ /dev/null @@ -1,114 +0,0 @@ -use std::error::Error; - -use crate::{ - recognizer::{FunctionalRecognizer, StackRecognizer}, - toktrie::SpecialToken, -}; -use anyhow::{bail, Result}; -use regex_automata::{ - dfa::{dense, Automaton}, - util::{primitives::StateID, syntax}, -}; - -pub type RecRxState = StateID; - -#[derive(Clone)] -pub struct RecRx { - dfa: dense::DFA>, - info: String, -} - -pub type RxStackRecognizer = StackRecognizer; - -impl RecRx { - pub fn from_rx(rx: &str, size_limit: Option) -> Result { - let rx = if rx.ends_with("$") { - rx.to_string() - } else { - rx.to_string() + "$" - }; - let rx = if rx.starts_with("^") { - rx[1..].to_string() - } else { - rx - }; - // default to 16MB - it takes about 1s to build - let size_limit = size_limit.unwrap_or(16 << 20); - let t0 = std::time::Instant::now(); - let cfg = dense::Config::new() - .start_kind(regex_automata::dfa::StartKind::Anchored) - .dfa_size_limit(Some(size_limit)) - .determinize_size_limit(Some(size_limit)); - let dfa = dense::Builder::new() - .configure(cfg) - .syntax(syntax::Config::new().unicode(false).utf8(false)) - .build(&rx); - let dfa = match dfa { - Ok(dfa) => dfa, - Err(e) => { - if let Some(e) = e.source() { - if let Some(e) = e.source() { - bail!("error building dfa(2): {}", e) - } else { - bail!("error building dfa(1): {}", e) - } - } else { - bail!("error building dfa(0): {}", e) - } - } - }; - let time = t0.elapsed(); - let mb_per_s = dfa.memory_usage() as f64 / time.as_secs_f64() / 1024.0 / 1024.0; - let info = format!( - "dfa: {} bytes; time {:?}; {:.3} MB/s", - dfa.memory_usage(), - time, - mb_per_s - ); - - if let Err(e) = dfa.start_state(&anchored_start()) { - bail!("DFA has no start state; {}", e) - } - - Ok(Self { dfa, info }) - } - - pub fn info(&self) -> &str { - &self.info - } - - pub fn to_stack_recognizer(self) -> RxStackRecognizer { - StackRecognizer::from(self) - } -} - -fn anchored_start() -> regex_automata::util::start::Config { - regex_automata::util::start::Config::new().anchored(regex_automata::Anchored::Yes) -} - -impl FunctionalRecognizer for RecRx { - fn initial(&self) -> RecRxState { - self.dfa - .start_state(&anchored_start()) - .expect("dfa has no start state") - } - - #[inline(always)] - fn try_append(&self, state: RecRxState, byte: u8) -> Option { - let next = self.dfa.next_state(state, byte); - if self.dfa.is_dead_state(next) { - None - } else { - Some(next) - } - } - - #[inline(always)] - fn special_allowed(&self, state: RecRxState, tok: SpecialToken) -> bool { - let state = self.dfa.next_eoi_state(state); - match tok { - SpecialToken::EndOfSentence => self.dfa.is_match_state(state), - _ => false, - } - } -} diff --git a/controllers/aici_abi/src/substring.rs b/controllers/aici_abi/src/substring.rs deleted file mode 100644 index b8be55d5..00000000 --- a/controllers/aici_abi/src/substring.rs +++ /dev/null @@ -1,277 +0,0 @@ -use std::fmt::Display; - -use crate::{ - bytes::limit_bytes, - recognizer::{FunctionalRecognizer, 
StackRecognizer}, - toktrie::SpecialToken, -}; -use serde_json::json; - -enum Node { - Inner { children: Vec<(u8, usize)> }, - Leaf { source_offset: usize }, -} - -pub struct SubStrMatcher { - end_str: String, - source: Vec, - nodes: Vec, -} - -#[derive(Debug, Clone, Copy, PartialEq, Eq)] -pub enum SubStrState { - Dead, - Node(usize), - SourceOffset(usize), - EndStrOffset(usize), -} - -pub type SubStrStackRecognizer = StackRecognizer; - -fn add_node(nodes: &mut Vec, n: Node) -> usize { - let idx = nodes.len(); - nodes.push(n); - idx -} - -impl Display for SubStrMatcher { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - self.pp(f, 0, 0) - } -} - -impl SubStrMatcher { - #[allow(dead_code)] - fn to_json(&self, node_idx: usize) -> serde_json::Value { - match &self.nodes[node_idx] { - Node::Inner { children } => { - let mut children_json = serde_json::Map::new(); - for (c, idx) in children.iter() { - children_json.insert(format!("{}", *c as char), self.to_json(*idx)); - } - serde_json::Value::Object(children_json) - } - Node::Leaf { source_offset } => { - json!(limit_bytes(&self.source[*source_offset..], 20)) - } - } - } - - fn pp( - &self, - f: &mut std::fmt::Formatter<'_>, - indent: usize, - node_idx: usize, - ) -> std::fmt::Result { - let node = &self.nodes[node_idx]; - match node { - Node::Inner { children } => { - for (c, idx) in children.iter() { - writeln!(f, "{:indent$}{:?} -> {}", "", *c as char, idx)?; - self.pp(f, indent + 1, *idx)?; - } - } - Node::Leaf { source_offset } => { - writeln!( - f, - "{:indent$}{}: {:?}", - "", - *source_offset, - limit_bytes(&self.source[*source_offset..], 20), - )?; - } - } - Ok(()) - } - - pub fn new(source: &str, end_str: &str) -> Self { - let mut tmp = Self { - source: (source.to_string() + " ").as_bytes().to_vec(), - end_str: end_str.to_string(), - nodes: vec![Node::Inner { children: vec![] }], - }; - tmp.add(0); - for i in 0..tmp.source.len() { - if tmp.source[i] == b' ' { - tmp.add(i + 1); - } - } - // println!("{}", tmp); - // println!("JSON: {}", serde_json::to_string(&tmp.to_json(0)).unwrap()); - tmp - } - - fn find(&self, s: &[u8]) -> (usize, usize) { - let mut node_idx = 0; - for (i, b) in s.iter().enumerate() { - let node = &self.nodes[node_idx]; - match node { - Node::Inner { children } => { - let mut found = false; - for (c, idx) in children.iter() { - if *c == *b { - node_idx = *idx; - found = true; - break; - } - } - if !found { - return (node_idx, i); - } - } - Node::Leaf { .. 
} => return (node_idx, i), - } - } - (node_idx, s.len()) - } - - fn add(&mut self, source_offset1: usize) { - let s1 = &self.source[source_offset1..]; - let (mut node_idx, offset) = self.find(s1); - if offset >= s1.len() { - return; - } - let source_offset1 = source_offset1 + offset; - let s1 = &self.source[source_offset1..]; - - let num_nodes = self.nodes.len(); - match &mut self.nodes[node_idx] { - Node::Inner { children } => { - children.push((s1[0], num_nodes)); - let n = add_node( - &mut self.nodes, - Node::Leaf { - source_offset: source_offset1 + 1, - }, - ); - assert!(n == num_nodes); - } - Node::Leaf { source_offset } => { - let source_offset2 = *source_offset; - let s2 = &self.source[source_offset2..]; - if s2.starts_with(s1) { - return; - } - if s1.starts_with(s2) { - self.nodes[node_idx] = Node::Leaf { - source_offset: source_offset1, - }; - return; - } - - for i in 0..s1.len() { - let b1 = s1[i]; - let b2 = s2[i]; - if b1 != b2 { - let n1 = add_node( - &mut self.nodes, - Node::Leaf { - source_offset: source_offset1 + i + 1, - }, - ); - let n2 = add_node( - &mut self.nodes, - Node::Leaf { - source_offset: source_offset2 + i + 1, - }, - ); - self.nodes[node_idx] = Node::Inner { - children: vec![(b1, n1), (b2, n2)], - }; - return; - } else { - let n1 = add_node(&mut self.nodes, Node::Inner { children: vec![] }); - self.nodes[node_idx] = Node::Inner { - children: vec![(b1, n1)], - }; - node_idx = n1; - } - } - } - } - } - - pub fn to_stack_recognizer(self) -> SubStrStackRecognizer { - StackRecognizer::from(self) - } - - fn append_to_src_off(&self, off: usize, byte: u8) -> SubStrState { - if off < self.source.len() && self.source[off] == byte { - SubStrState::SourceOffset(off + 1) - } else { - SubStrState::Dead - } - } - - fn append_inner(&self, state: SubStrState, byte: u8) -> SubStrState { - match state { - SubStrState::Dead => SubStrState::Dead, - SubStrState::EndStrOffset(off) => { - if off < self.end_str.len() && self.end_str.as_bytes()[off] == byte { - SubStrState::EndStrOffset(off + 1) - } else { - SubStrState::Dead - } - } - SubStrState::Node(state) => { - let node = &self.nodes[state]; - match node { - Node::Inner { children } => { - for (c, idx) in children.iter() { - if *c == byte { - return SubStrState::Node(*idx); - } - } - SubStrState::Dead - } - Node::Leaf { source_offset } => self.append_to_src_off(*source_offset, byte), - } - } - SubStrState::SourceOffset(off) => self.append_to_src_off(off, byte), - } - } - - #[inline(always)] - fn do_append(&self, state: SubStrState, byte: u8) -> SubStrState { - let state = match state { - SubStrState::Node(_) | SubStrState::SourceOffset(_) - if self.end_str.as_bytes().first() == Some(&byte) - && self.append_inner(state, b' ') != SubStrState::Dead => - { - SubStrState::EndStrOffset(0) - } - _ => state, - }; - - self.append_inner(state, byte) - } -} - -impl FunctionalRecognizer for SubStrMatcher { - fn initial(&self) -> SubStrState { - SubStrState::Node(0) - } - - #[inline(always)] - fn try_append(&self, state: SubStrState, byte: u8) -> Option { - match self.do_append(state, byte) { - SubStrState::Dead => None, - state => Some(state), - } - } - - #[inline(always)] - fn special_allowed(&self, state: SubStrState, tok: SpecialToken) -> bool { - match tok { - SpecialToken::EndOfSentence => { - let l = self.end_str.len(); - if l == 0 { - self.append_inner(state, b' ') != SubStrState::Dead - } else { - state == SubStrState::EndStrOffset(l) - } - } - _ => false, - } - } -} diff --git a/controllers/aici_abi/src/yesno.rs 
b/controllers/aici_abi/src/yesno.rs
deleted file mode 100644
index 78b574c3..00000000
--- a/controllers/aici_abi/src/yesno.rs
+++ /dev/null
@@ -1,43 +0,0 @@
-use aici_abi::{host_trie, tokenize, toktrie::TokTrie, AiciCtrl, MidProcessArg, MidProcessResult, TokenId};
-
-pub struct Runner {
-    toktrie: TokTrie,
-    tokens: Vec<TokenId>,
-    yes: TokenId,
-    no: TokenId,
-}
-
-impl Runner {
-    pub fn new() -> Self {
-        let yes = tokenize("Yes")[0];
-        let no = tokenize("No")[0];
-        // ignore user-passed arg
-        Runner {
-            toktrie: host_trie(),
-            tokens: Vec::new(),
-            yes,
-            no,
-        }
-    }
-}
-
-impl AiciCtrl for Runner {
-    fn mid_process(&mut self, arg: MidProcessArg) -> MidProcessResult {
-        arg.save_tokens(&mut self.tokens);
-        if self.tokens.len() >= 1 {
-            // we only want the first token
-            MidProcessResult::stop()
-        } else {
-            let mut set = self.toktrie.alloc_token_set();
-            set.allow_token(self.yes);
-            set.allow_token(self.no);
-            MidProcessResult::sample(set)
-        }
-    }
-}
-
-fn main() {
-    // test code here?
-}
-
-aici_abi::aici_expose_all!(Runner, Runner::new());
diff --git a/controllers/aici_native/Cargo.toml b/controllers/aici_native/Cargo.toml
deleted file mode 100644
index e98ef4d3..00000000
--- a/controllers/aici_native/Cargo.toml
+++ /dev/null
@@ -1,18 +0,0 @@
-[package]
-name = "aici_native"
-version = "0.1.0"
-edition = "2021"
-
-[lib]
-name = "aici_native"
-
-[dependencies]
-aici_abi = { path = "../aici_abi" }
-toktrie_hf_tokenizers = { path = "../toktrie_hf_tokenizers" }
-serde = { version = "1.0.192", features = ["derive"] }
-serde_json = "1.0.108"
-anyhow = "1.0.75"
-rustc-hash = "1.1.0"
-tokenizers = { version = "0.15.0", features = ["http"] }
-log = "0.4.21"
-flexi_logger = "0.28.0"
diff --git a/controllers/aici_native/README.md b/controllers/aici_native/README.md
deleted file mode 100644
index 205a0ae0..00000000
--- a/controllers/aici_native/README.md
+++ /dev/null
@@ -1,3 +0,0 @@
-# AICI native
-
-Utilities for building native (non-Wasm) AICI Controllers.
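As a counterpart to the sampling-based `yesno.rs` controller deleted above, here is a minimal hypothetical sketch of the other branch kind, `splice`, which fast-forwards fixed tokens instead of biasing logits. It assumes only the `aici_abi` API visible in this patch (`tokenize`, `save_tokens`, `MidProcessResult::{splice, stop}`, `aici_expose_all!`); `FixedRunner` and the spliced phrase are made up for illustration.

```rust
use aici_abi::{tokenize, AiciCtrl, MidProcessArg, MidProcessResult, TokenId};

// Hypothetical controller: on the first step, force-append a fixed phrase
// with a zero-backtrack splice; once those tokens are reported back, stop.
pub struct FixedRunner {
    tokens: Vec<TokenId>,
    done: bool,
}

impl AiciCtrl for FixedRunner {
    fn mid_process(&mut self, arg: MidProcessArg) -> MidProcessResult {
        // keep a local copy of the sequence in sync (handles backtracking too)
        arg.save_tokens(&mut self.tokens);
        if self.done {
            // the spliced tokens arrived back in arg.tokens; end the sequence
            MidProcessResult::stop()
        } else {
            self.done = true;
            // backtrack 0 tokens, then fast-forward the fixed phrase
            MidProcessResult::splice(0, tokenize("Hello, world!"))
        }
    }
}

aici_abi::aici_expose_all!(FixedRunner, FixedRunner { tokens: Vec::new(), done: false });
```

Unlike `sample`, a `splice` branch lets the runtime append the forced tokens without sampling them one by one; `noop()` in lib.rs above is simply the empty splice.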
diff --git a/controllers/aici_native/src/bintokens.rs b/controllers/aici_native/src/bintokens.rs deleted file mode 100644 index e47094a8..00000000 --- a/controllers/aici_native/src/bintokens.rs +++ /dev/null @@ -1,157 +0,0 @@ -use anyhow::{anyhow, Result}; -use tokenizers::{FromPretrainedParameters, Tokenizer}; - -pub use toktrie_hf_tokenizers::{ByteTokenizer, ByteTokenizerEnv}; - -pub struct TokenizerInfo { - pub name: &'static str, - pub description: &'static str, - pub hf_model: &'static str, - pub model_ids: &'static str, -} - -pub fn tokenizers() -> Vec { - vec![ - TokenizerInfo { - name: "gpt4", - description: "cl100k_base, used by GPT-4 and GPT-3.5", - hf_model: "Xenova/gpt-4", - model_ids: "gpt-4", - }, - TokenizerInfo { - name: "llama16", - description: "same as llama, with 16 added tokens (used by 13B codellama)", - hf_model: "codellama/CodeLlama-13b-Instruct-hf", - model_ids: "codellama-13b", - }, - TokenizerInfo { - name: "llama70", - description: "used by codellama-70b; with token", - hf_model: "codellama/CodeLlama-70b-Instruct-hf", - model_ids: "codellama-70b", - }, - TokenizerInfo { - name: "llama", - description: "used by Llama, CodeLlama, etc.", - hf_model: "codellama/CodeLlama-34b-Instruct-hf", - model_ids: "", - }, - TokenizerInfo { - name: "orca", - description: "llama", - hf_model: "microsoft/Orca-2-13b@refs/pr/23", - model_ids: "for microsoft/Orca models; similar to llama, with 3 tokens added for chat", - }, - TokenizerInfo { - name: "falcon", - description: "used by Falcon 7b, 40b, etc.", - hf_model: "tiiuae/falcon-7b", - model_ids: "", - }, - TokenizerInfo { - name: "mistral", - description: "used by Mistral and Mixtral", - hf_model: "mistralai/Mistral-7B-Instruct-v0.2", - model_ids: "mixtral", - }, - TokenizerInfo { - name: "mpt", - description: "MPT", - hf_model: "mosaicml/mpt-7b", - model_ids: "", - }, - TokenizerInfo { - name: "phi", - description: "Phi 1.5 and Phi 2", - hf_model: "microsoft/phi-1_5", - model_ids: "", - }, - TokenizerInfo { - name: "gpt2", - description: "GPT-2", - hf_model: "gpt2", - model_ids: "gpt-2", - }, - ] -} - -pub fn list_tokenizers() -> String { - format!( - "Available tokenizers for -t or --tokenizer:\n{}\n{}\n{}", - tokenizers() - .iter() - .map(|t| format!(" -t {:16} {}", t.name, t.description)) - .collect::>() - .join("\n"), - "You can also use a HuggingFace model name, in format 'user/modelname',", - "or a local file in format './path/to/tokenizer.json'." 
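-        // (the two help lines above are appended after the generated table)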
- ) -} - -pub fn guess_tokenizer(model_name: &str) -> Option { - let m = model_name.to_lowercase(); - tokenizers() - .iter() - .find(|t| { - m.contains(&t.name) - || t.model_ids - .split(',') - .map(|x| x.trim()) - .filter(|x| x.len() > 0) - .any(|x| m.contains(x)) - }) - .map(|t| t.name.to_string()) -} - -fn strip_suffix(sep: &str, s: &mut String) -> Option { - let mut parts = s.splitn(2, sep); - let core = parts.next().unwrap().to_string(); - let suff = parts.next().map(|s| s.to_string()); - *s = core; - suff -} - -pub fn test_tokenizers() { - for t in tokenizers() { - let t = find_tokenizer(t.name).unwrap(); - println!("tokenizer: {} {}", t.hf_model, t.vocab_size); - } -} - -pub fn find_tokenizer(mut name: &str) -> Result { - if !name.contains("/") { - for t in tokenizers() { - if t.name == name { - name = t.hf_model; - break; - } - } - } - - log::info!("loading tokenizer: {}", name); - - let loaded = if name.starts_with(".") || name.starts_with("/") { - Tokenizer::from_file(name) - } else { - let mut name2 = name.to_string(); - let mut args = FromPretrainedParameters::default(); - - match strip_suffix("@", &mut name2) { - Some(s) => args.revision = s, - None => {} - } - Tokenizer::from_pretrained(name2, Some(args)) - }; - - match loaded { - Err(e) => { - let msg = format!("can't load tokenizer {}: {}", name, e); - println!("{}\n{}", msg, list_tokenizers()); - return Err(anyhow!("{}", msg)); - } - Ok(t) => { - let bt = ByteTokenizer::from_tokenizer(t)?; - Ok(bt) - } - } -} diff --git a/controllers/aici_native/src/lib.rs b/controllers/aici_native/src/lib.rs deleted file mode 100644 index 6bdf9d43..00000000 --- a/controllers/aici_native/src/lib.rs +++ /dev/null @@ -1,8 +0,0 @@ -pub mod bintokens; -mod log; -pub mod variables; - -pub use log::*; - -pub use rustc_hash::FxHashMap as HashMap; -pub use rustc_hash::FxHashSet as HashSet; diff --git a/controllers/aici_native/src/log.rs b/controllers/aici_native/src/log.rs deleted file mode 100644 index c8c6f8ad..00000000 --- a/controllers/aici_native/src/log.rs +++ /dev/null @@ -1,100 +0,0 @@ -use std::fmt::Write; - -use anyhow::Result; -use flexi_logger::style; -use flexi_logger::{DeferredNow, Logger, WriteMode}; -use log::Record; - -pub enum LogMode { - Normal, - Test, - Daemon, -} - -struct LimitedWrite { - limit: usize, - dst: Vec, -} - -impl Write for LimitedWrite { - fn write_str(&mut self, s: &str) -> std::fmt::Result { - if self.dst.len() > self.limit { - return Err(std::fmt::Error); - } - if self.dst.len() + s.len() < self.limit { - self.dst.extend_from_slice(s.as_bytes()); - Ok(()) - } else { - let remaining = self.limit - self.dst.len(); - self.dst.extend_from_slice(&s.as_bytes()[..remaining]); - self.dst.extend_from_slice(b" (...)"); - Err(std::fmt::Error) - } - } -} - -fn args_to_str(limit: usize, args: &std::fmt::Arguments) -> String { - // let capacity = args.estimated_capacity(); - let mut output = LimitedWrite { - limit, - dst: Vec::with_capacity(128), - }; - if output.write_fmt(*args).is_err() { - assert!(output.dst.len() > limit); - } - match String::from_utf8(output.dst) { - Ok(s) => s, - Err(err) => String::from_utf8_lossy(err.as_bytes()).to_string(), - } -} - -fn truncated_format( - w: &mut dyn std::io::Write, - _now: &mut DeferredNow, - record: &Record, -) -> Result<(), std::io::Error> { - let level = record.level(); - write!( - w, - "{} [{}] {}", - style(level).paint(level.to_string()), - record.module_path().unwrap_or(""), - style(level).paint(args_to_str(1000, record.args())) - ) -} - -fn daemon_format( - w: &mut dyn 
std::io::Write, - now: &mut DeferredNow, - record: &Record, -) -> Result<(), std::io::Error> { - write!( - w, - "{} {} [{}] {}", - now.format("%Y-%m-%d %H:%M:%S%.3f"), - record.level(), - record.module_path().unwrap_or(""), - args_to_str(5000, record.args()) - ) -} - -pub fn init_log(mode: LogMode) -> Result<()> { - let logger = match mode { - LogMode::Normal => Logger::try_with_env_or_str("info")? - .format(truncated_format) - .log_to_stdout(), - LogMode::Test => { - Logger::try_with_env_or_str("debug")?.write_mode(WriteMode::SupportCapture) - } - LogMode::Daemon => Logger::try_with_env_or_str("info")? - .format(daemon_format) - .log_to_stdout(), - }; - - logger.start()?; - Ok(()) -} - -pub fn setup_log() { - init_log(LogMode::Normal).expect("Failed to initialize log") -} diff --git a/controllers/aici_native/src/variables.rs b/controllers/aici_native/src/variables.rs deleted file mode 100644 index 4f0dc07b..00000000 --- a/controllers/aici_native/src/variables.rs +++ /dev/null @@ -1,56 +0,0 @@ -use aici_abi::{StorageCmd, StorageOp, StorageResp}; -use rustc_hash::FxHashMap; - -#[derive(Default)] -pub struct Variables { - pub variables: FxHashMap)>, -} - -impl Variables { - pub fn process_cmd(&mut self, cmd: StorageCmd) -> StorageResp { - match cmd { - StorageCmd::ReadVar { name } => match self.variables.get(&name).map(|x| x.clone()) { - None => StorageResp::VariableMissing {}, - Some((version, value)) => StorageResp::ReadVar { value, version }, - }, - StorageCmd::WriteVar { - name, - value, - when_version_is, - op, - } => { - let curr = self.variables.get(&name).map(|x| x.clone()); - match curr { - Some((prev_version, prev_val)) => match when_version_is { - Some(v) if v != prev_version => StorageResp::ReadVar { - version: prev_version, - value: prev_val, - }, - _ => { - let value = match op { - StorageOp::Append => { - let mut v = prev_val.clone(); - v.extend(value); - v - } - StorageOp::Set => value, - }; - let version = prev_version + 1; - self.variables.insert(name, (version, value)); - StorageResp::WriteVar { version } - } - }, - - None => match when_version_is { - None => { - self.variables.insert(name, (1, value)); - StorageResp::WriteVar { version: 1 } - } - Some(_) => StorageResp::VariableMissing {}, - }, - } - } - } - } -} - From 431896ebd0194940aa0d73f2232129aa842885a6 Mon Sep 17 00:00:00 2001 From: Michal Moskal Date: Fri, 5 Jul 2024 17:21:11 -0700 Subject: [PATCH 254/301] move folders --- {controllers/toktrie => core}/Cargo.toml | 0 {controllers/toktrie => core}/README.md | 0 {controllers/toktrie => core}/implementation.md | 0 {controllers/toktrie => core}/src/bytes.rs | 0 {controllers/toktrie => core}/src/lib.rs | 0 {controllers/toktrie => core}/src/recognizer.rs | 0 {controllers/toktrie => core}/src/rng.rs | 0 {controllers/toktrie => core}/src/svob.rs | 0 {controllers/toktrie => core}/src/toktree.rs | 0 {controllers/toktrie_hf_tokenizers => hf_tokenizers}/Cargo.toml | 0 {controllers/toktrie_hf_tokenizers => hf_tokenizers}/src/lib.rs | 0 11 files changed, 0 insertions(+), 0 deletions(-) rename {controllers/toktrie => core}/Cargo.toml (100%) rename {controllers/toktrie => core}/README.md (100%) rename {controllers/toktrie => core}/implementation.md (100%) rename {controllers/toktrie => core}/src/bytes.rs (100%) rename {controllers/toktrie => core}/src/lib.rs (100%) rename {controllers/toktrie => core}/src/recognizer.rs (100%) rename {controllers/toktrie => core}/src/rng.rs (100%) rename {controllers/toktrie => core}/src/svob.rs (100%) rename {controllers/toktrie => 
core}/src/toktree.rs (100%) rename {controllers/toktrie_hf_tokenizers => hf_tokenizers}/Cargo.toml (100%) rename {controllers/toktrie_hf_tokenizers => hf_tokenizers}/src/lib.rs (100%) diff --git a/controllers/toktrie/Cargo.toml b/core/Cargo.toml similarity index 100% rename from controllers/toktrie/Cargo.toml rename to core/Cargo.toml diff --git a/controllers/toktrie/README.md b/core/README.md similarity index 100% rename from controllers/toktrie/README.md rename to core/README.md diff --git a/controllers/toktrie/implementation.md b/core/implementation.md similarity index 100% rename from controllers/toktrie/implementation.md rename to core/implementation.md diff --git a/controllers/toktrie/src/bytes.rs b/core/src/bytes.rs similarity index 100% rename from controllers/toktrie/src/bytes.rs rename to core/src/bytes.rs diff --git a/controllers/toktrie/src/lib.rs b/core/src/lib.rs similarity index 100% rename from controllers/toktrie/src/lib.rs rename to core/src/lib.rs diff --git a/controllers/toktrie/src/recognizer.rs b/core/src/recognizer.rs similarity index 100% rename from controllers/toktrie/src/recognizer.rs rename to core/src/recognizer.rs diff --git a/controllers/toktrie/src/rng.rs b/core/src/rng.rs similarity index 100% rename from controllers/toktrie/src/rng.rs rename to core/src/rng.rs diff --git a/controllers/toktrie/src/svob.rs b/core/src/svob.rs similarity index 100% rename from controllers/toktrie/src/svob.rs rename to core/src/svob.rs diff --git a/controllers/toktrie/src/toktree.rs b/core/src/toktree.rs similarity index 100% rename from controllers/toktrie/src/toktree.rs rename to core/src/toktree.rs diff --git a/controllers/toktrie_hf_tokenizers/Cargo.toml b/hf_tokenizers/Cargo.toml similarity index 100% rename from controllers/toktrie_hf_tokenizers/Cargo.toml rename to hf_tokenizers/Cargo.toml diff --git a/controllers/toktrie_hf_tokenizers/src/lib.rs b/hf_tokenizers/src/lib.rs similarity index 100% rename from controllers/toktrie_hf_tokenizers/src/lib.rs rename to hf_tokenizers/src/lib.rs From 6dfefaca36d3379bf00a3f3048b980fc163966df Mon Sep 17 00:00:00 2001 From: Michal Moskal Date: Fri, 5 Jul 2024 17:21:24 -0700 Subject: [PATCH 255/301] move files --- .gitignore | 400 +------------------- README.md | 59 ++- core/Cargo.lock | 122 ++++++ core/README.md | 54 --- core/implementation.md => implementation.md | 0 5 files changed, 175 insertions(+), 460 deletions(-) create mode 100644 core/Cargo.lock delete mode 100644 core/README.md rename core/implementation.md => implementation.md (100%) diff --git a/.gitignore b/.gitignore index 8a30d258..847709f2 100644 --- a/.gitignore +++ b/.gitignore @@ -1,398 +1,2 @@ -## Ignore Visual Studio temporary files, build results, and -## files generated by popular Visual Studio add-ons. 
-## -## Get latest from https://github.com/github/gitignore/blob/main/VisualStudio.gitignore - -# User-specific files -*.rsuser -*.suo -*.user -*.userosscache -*.sln.docstates - -# User-specific files (MonoDevelop/Xamarin Studio) -*.userprefs - -# Mono auto generated files -mono_crash.* - -# Build results -[Dd]ebug/ -[Dd]ebugPublic/ -[Rr]elease/ -[Rr]eleases/ -x64/ -x86/ -[Ww][Ii][Nn]32/ -[Aa][Rr][Mm]/ -[Aa][Rr][Mm]64/ -bld/ -[Bb]in/ -[Oo]bj/ -[Ll]og/ -[Ll]ogs/ - -# Visual Studio 2015/2017 cache/options directory -.vs/ -# Uncomment if you have tasks that create the project's static files in wwwroot -#wwwroot/ - -# Visual Studio 2017 auto generated files -Generated\ Files/ - -# MSTest test Results -[Tt]est[Rr]esult*/ -[Bb]uild[Ll]og.* - -# NUnit -*.VisualState.xml -TestResult.xml -nunit-*.xml - -# Build Results of an ATL Project -[Dd]ebugPS/ -[Rr]eleasePS/ -dlldata.c - -# Benchmark Results -BenchmarkDotNet.Artifacts/ - -# .NET Core -project.lock.json -project.fragment.lock.json -artifacts/ - -# ASP.NET Scaffolding -ScaffoldingReadMe.txt - -# StyleCop -StyleCopReport.xml - -# Files built by Visual Studio -*_i.c -*_p.c -*_h.h -*.ilk -*.meta -*.obj -*.iobj -*.pch -*.pdb -*.ipdb -*.pgc -*.pgd -*.rsp -*.sbr -*.tlb -*.tli -*.tlh -*.tmp -*.tmp_proj -*_wpftmp.csproj -*.log -*.tlog -*.vspscc -*.vssscc -.builds -*.pidb -*.svclog -*.scc - -# Chutzpah Test files -_Chutzpah* - -# Visual C++ cache files -ipch/ -*.aps -*.ncb -*.opendb -*.opensdf -*.sdf -*.cachefile -*.VC.db -*.VC.VC.opendb - -# Visual Studio profiler -*.psess -*.vsp -*.vspx -*.sap - -# Visual Studio Trace Files -*.e2e - -# TFS 2012 Local Workspace -$tf/ - -# Guidance Automation Toolkit -*.gpState - -# ReSharper is a .NET coding add-in -_ReSharper*/ -*.[Rr]e[Ss]harper -*.DotSettings.user - -# TeamCity is a build add-in -_TeamCity* - -# DotCover is a Code Coverage Tool -*.dotCover - -# AxoCover is a Code Coverage Tool -.axoCover/* -!.axoCover/settings.json - -# Coverlet is a free, cross platform Code Coverage Tool -coverage*.json -coverage*.xml -coverage*.info - -# Visual Studio code coverage results -*.coverage -*.coveragexml - -# NCrunch -_NCrunch_* -.*crunch*.local.xml -nCrunchTemp_* - -# MightyMoose -*.mm.* -AutoTest.Net/ - -# Web workbench (sass) -.sass-cache/ - -# Installshield output folder -[Ee]xpress/ - -# DocProject is a documentation generator add-in -DocProject/buildhelp/ -DocProject/Help/*.HxT -DocProject/Help/*.HxC -DocProject/Help/*.hhc -DocProject/Help/*.hhk -DocProject/Help/*.hhp -DocProject/Help/Html2 -DocProject/Help/html - -# Click-Once directory -publish/ - -# Publish Web Output -*.[Pp]ublish.xml -*.azurePubxml -# Note: Comment the next line if you want to checkin your web deploy settings, -# but database connection strings (with potential passwords) will be unencrypted -*.pubxml -*.publishproj - -# Microsoft Azure Web App publish settings. Comment the next line if you want to -# checkin your Azure Web App publish settings, but sensitive information contained -# in these scripts will be unencrypted -PublishScripts/ - -# NuGet Packages -*.nupkg -# NuGet Symbol Packages -*.snupkg -# The packages folder can be ignored because of Package Restore -**/[Pp]ackages/* -# except build/, which is used as an MSBuild target. 
-!**/[Pp]ackages/build/ -# Uncomment if necessary however generally it will be regenerated when needed -#!**/[Pp]ackages/repositories.config -# NuGet v3's project.json files produces more ignorable files -*.nuget.props -*.nuget.targets - -# Microsoft Azure Build Output -csx/ -*.build.csdef - -# Microsoft Azure Emulator -ecf/ -rcf/ - -# Windows Store app package directories and files -AppPackages/ -BundleArtifacts/ -Package.StoreAssociation.xml -_pkginfo.txt -*.appx -*.appxbundle -*.appxupload - -# Visual Studio cache files -# files ending in .cache can be ignored -*.[Cc]ache -# but keep track of directories ending in .cache -!?*.[Cc]ache/ - -# Others -ClientBin/ -~$* -*~ -*.dbmdl -*.dbproj.schemaview -*.jfm -*.pfx -*.publishsettings -orleans.codegen.cs - -# Including strong name files can present a security risk -# (https://github.com/github/gitignore/pull/2483#issue-259490424) -#*.snk - -# Since there are multiple workflows, uncomment next line to ignore bower_components -# (https://github.com/github/gitignore/pull/1529#issuecomment-104372622) -#bower_components/ - -# RIA/Silverlight projects -Generated_Code/ - -# Backup & report files from converting an old project file -# to a newer Visual Studio version. Backup files are not needed, -# because we have git ;-) -_UpgradeReport_Files/ -Backup*/ -UpgradeLog*.XML -UpgradeLog*.htm -ServiceFabricBackup/ -*.rptproj.bak - -# SQL Server files -*.mdf -*.ldf -*.ndf - -# Business Intelligence projects -*.rdl.data -*.bim.layout -*.bim_*.settings -*.rptproj.rsuser -*- [Bb]ackup.rdl -*- [Bb]ackup ([0-9]).rdl -*- [Bb]ackup ([0-9][0-9]).rdl - -# Microsoft Fakes -FakesAssemblies/ - -# GhostDoc plugin setting file -*.GhostDoc.xml - -# Node.js Tools for Visual Studio -.ntvs_analysis.dat -node_modules/ - -# Visual Studio 6 build log -*.plg - -# Visual Studio 6 workspace options file -*.opt - -# Visual Studio 6 auto-generated workspace file (contains which files were open etc.) -*.vbw - -# Visual Studio 6 auto-generated project file (contains which files were open etc.) 
-*.vbp - -# Visual Studio 6 workspace and project file (working project files containing files to include in project) -*.dsw -*.dsp - -# Visual Studio 6 technical files -*.ncb -*.aps - -# Visual Studio LightSwitch build output -**/*.HTMLClient/GeneratedArtifacts -**/*.DesktopClient/GeneratedArtifacts -**/*.DesktopClient/ModelManifest.xml -**/*.Server/GeneratedArtifacts -**/*.Server/ModelManifest.xml -_Pvt_Extensions - -# Paket dependency manager -.paket/paket.exe -paket-files/ - -# FAKE - F# Make -.fake/ - -# CodeRush personal settings -.cr/personal - -# Python Tools for Visual Studio (PTVS) -__pycache__/ -*.pyc - -# Cake - Uncomment if you are using it -# tools/** -# !tools/packages.config - -# Tabs Studio -*.tss - -# Telerik's JustMock configuration file -*.jmconfig - -# BizTalk build output -*.btp.cs -*.btm.cs -*.odx.cs -*.xsd.cs - -# OpenCover UI analysis results -OpenCover/ - -# Azure Stream Analytics local run output -ASALocalRun/ - -# MSBuild Binary and Structured Log -*.binlog - -# NVidia Nsight GPU debugger configuration file -*.nvuser - -# MFractors (Xamarin productivity tool) working folder -.mfractor/ - -# Local History for Visual Studio -.localhistory/ - -# Visual Studio History (VSHistory) files -.vshistory/ - -# BeatPulse healthcheck temp database -healthchecksdb - -# Backup folder for Package Reference Convert tool in Visual Studio 2017 -MigrationBackup/ - -# Ionide (cross platform F# VS Code tools) working folder -.ionide/ - -# Fody - auto-generated XML schema -FodyWeavers.xsd - -# VS Code files for those working on multiple tools -.vscode/* -!.vscode/settings.json -!.vscode/tasks.json -!.vscode/launch.json -!.vscode/extensions.json -*.code-workspace - -# Local History for Visual Studio Code -.history/ - -# Windows Installer files from build outputs -*.cab -*.msi -*.msix -*.msm -*.msp - -# JetBrains Rider -*.sln.iml +target +tmp diff --git a/README.md b/README.md index 5cd7cecf..5871fe4d 100644 --- a/README.md +++ b/README.md @@ -1,14 +1,57 @@ -# Project +# toktrie - Token utility library -> This repo has been populated by an initial template to help get you started. Please -> make sure to update the content to build a great experience for community-building. +This crate provides a utility library for working with tokens and token tries. -As the maintainer of this project, please make a few updates: +## Byte stack interface + +The constraints are typically expressed on strings or bytes, not tokens. +To compute the set of tokens that match a string constraint, one needs go through all the possible tokens +and apply the constraint. +An efficient way to do this is walk a prefix tree (trie) of all tokens. +This library implements this trie and exposes a way of filtering when provided with a constraints +implementing the [following interface](core/src/toktree.rs): + +```rust +pub trait Recognizer { + /// If `stack.top()` transitions via `byte` to `X`, execute `stack.push(X)`. + fn push_byte(&mut self, byte: u8); + /// for _ in 0..num { stack.pop() } + fn pop_bytes(&mut self, num: usize); + /// X = stack.top(); stack.empty(); stack.push(X) + fn collapse(&mut self); + /// check if stack.top() transitions via byte to a viable state + fn byte_allowed(&mut self, byte: u8) -> bool; + /// check if stack.top() transitions via tok to a viable state + fn special_allowed(&mut self, tok: SpecialToken) -> bool; + /// Called when iteration over the trie is finished + /// Stack has exactly one element then. 
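+    /// (that single element is the state the trie walk started from).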
+ fn trie_finished(&mut self); + /// This combines `push_byte` and `byte_allowed` into one function for performance. + fn try_push_byte(&mut self, byte: u8) -> bool; +} +``` + +The `AiciRecognizer` struct converts `Recognizer` to `AiciCtrl`. + +## Functional byte interface + +The following interface can be transformed into `Recognizer` using `StackRecognizer` struct. + +```rust +pub trait FunctionalRecognizer { + /// Initial state + fn initial(&self) -> S; + /// Extend the recognizer with given byte. + fn append(&self, state: S, byte: u8) -> S; + /// Check if given byte is allowed in given state. + fn byte_allowed(&self, state: S, byte: u8) -> bool; + /// Check if given special token is allowed in given state. + fn special_allowed(&self, state: S, tok: SpecialToken) -> bool; +} +``` + +These three layers add up to about 40k of compiled code (Wasm). -- Improving this README.MD file to provide a great experience -- Updating SUPPORT.MD with content about this project's support experience -- Understanding the security reporting process in SECURITY.MD -- Remove this section from the README ## Contributing diff --git a/core/Cargo.lock b/core/Cargo.lock new file mode 100644 index 00000000..aab2775f --- /dev/null +++ b/core/Cargo.lock @@ -0,0 +1,122 @@ +# This file is automatically @generated by Cargo. +# It is not intended for manual editing. +version = 3 + +[[package]] +name = "anyhow" +version = "1.0.86" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b3d1d046238990b9cf5bcde22a3fb3584ee5cf65fb2765f454ed428c7a0063da" + +[[package]] +name = "bytemuck" +version = "1.16.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b236fc92302c97ed75b38da1f4917b5cdda4984745740f153a5d3059e48d725e" + +[[package]] +name = "bytemuck_derive" +version = "1.7.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1ee891b04274a59bd38b412188e24b849617b2e45a0fd8d057deb63e7403761b" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "itoa" +version = "1.0.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "49f1f14873335454500d59611f1cf4a4b0f786f9ac11f4312a78e4cf2566695b" + +[[package]] +name = "proc-macro2" +version = "1.0.86" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5e719e8df665df0d1c8fbfd238015744736151d4445ec0836b8e628aae103b77" +dependencies = [ + "unicode-ident", +] + +[[package]] +name = "quote" +version = "1.0.36" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0fa76aaf39101c457836aec0ce2316dbdc3ab723cdda1c6bd4e6ad4208acaca7" +dependencies = [ + "proc-macro2", +] + +[[package]] +name = "rustc-hash" +version = "2.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "583034fd73374156e66797ed8e5b0d5690409c9226b22d87cb7f19821c05d152" + +[[package]] +name = "ryu" +version = "1.0.18" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f3cb5ba0dc43242ce17de99c180e96db90b235b8a9fdc9543c96d2209116bd9f" + +[[package]] +name = "serde" +version = "1.0.203" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7253ab4de971e72fb7be983802300c30b5a7f0c2e56fab8abfc6a214307c0094" +dependencies = [ + "serde_derive", +] + +[[package]] +name = "serde_derive" +version = "1.0.203" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "500cbc0ebeb6f46627f50f3f5811ccf6bf00643be300b4c3eabc0ef55dc5b5ba" +dependencies = [ + 
"proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "serde_json" +version = "1.0.120" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4e0d21c9a8cae1235ad58a00c11cb40d4b1e5c784f1ef2c537876ed6ffd8b7c5" +dependencies = [ + "itoa", + "ryu", + "serde", +] + +[[package]] +name = "syn" +version = "2.0.68" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "901fa70d88b9d6c98022e23b4136f9f3e54e4662c3bc1bd1d84a42a9a0f0c1e9" +dependencies = [ + "proc-macro2", + "quote", + "unicode-ident", +] + +[[package]] +name = "toktrie" +version = "0.1.0" +dependencies = [ + "anyhow", + "bytemuck", + "bytemuck_derive", + "rustc-hash", + "serde", + "serde_json", +] + +[[package]] +name = "unicode-ident" +version = "1.0.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3354b9ac3fae1ff6755cb6db53683adb661634f67557942dea4facebec0fee4b" diff --git a/core/README.md b/core/README.md deleted file mode 100644 index 035599b5..00000000 --- a/core/README.md +++ /dev/null @@ -1,54 +0,0 @@ -# toktrie - Token utility library - -This crate provides a utility library for working with tokens and token tries. - -## Byte stack interface - -The constraints are typically expressed on strings or bytes, not tokens. -To compute the set of tokens that match a string constraint, one needs go through all the possible tokens -and apply the constraint. -An efficient way to do this is walk a prefix tree (trie) of all tokens. -The `aici_abi` library implements this trie and exposes a way of filtering when provided with a constraints -implementing the [following interface](src/toktree.rs): - -```rust -pub trait Recognizer { - /// If `stack.top()` transitions via `byte` to `X`, execute `stack.push(X)`. - fn push_byte(&mut self, byte: u8); - /// for _ in 0..num { stack.pop() } - fn pop_bytes(&mut self, num: usize); - /// X = stack.top(); stack.empty(); stack.push(X) - fn collapse(&mut self); - /// check if stack.top() transitions via byte to a viable state - fn byte_allowed(&mut self, byte: u8) -> bool; - /// check if stack.top() transitions via tok to a viable state - fn special_allowed(&mut self, tok: SpecialToken) -> bool; - /// Called when iteration over the trie is finished - /// Stack has exactly one element then. - fn trie_finished(&mut self); - /// This combines `push_byte` and `byte_allowed` into one function for performance. - fn try_push_byte(&mut self, byte: u8) -> bool; -} -``` - -The `AiciRecognizer` struct converts `Recognizer` to `AiciCtrl`. - -## Functional byte interface - -The following interface can be transformed into `Recognizer` using `StackRecognizer` struct. - -```rust -pub trait FunctionalRecognizer { - /// Initial state - fn initial(&self) -> S; - /// Extend the recognizer with given byte. - fn append(&self, state: S, byte: u8) -> S; - /// Check if given byte is allowed in given state. - fn byte_allowed(&self, state: S, byte: u8) -> bool; - /// Check if given special token is allowed in given state. - fn special_allowed(&self, state: S, tok: SpecialToken) -> bool; -} -``` - -These three layers add up to about 40k of compiled code (Wasm). 
- diff --git a/core/implementation.md b/implementation.md similarity index 100% rename from core/implementation.md rename to implementation.md From 050253c56d4390414e7837b2e675bedc45430e66 Mon Sep 17 00:00:00 2001 From: Michal Moskal Date: Fri, 5 Jul 2024 17:23:05 -0700 Subject: [PATCH 256/301] build for hf_tokenizers --- hf_tokenizers/Cargo.lock | 1414 ++++++++++++++++++++++++++++++++++++++ hf_tokenizers/Cargo.toml | 2 +- 2 files changed, 1415 insertions(+), 1 deletion(-) create mode 100644 hf_tokenizers/Cargo.lock diff --git a/hf_tokenizers/Cargo.lock b/hf_tokenizers/Cargo.lock new file mode 100644 index 00000000..3adb495c --- /dev/null +++ b/hf_tokenizers/Cargo.lock @@ -0,0 +1,1414 @@ +# This file is automatically @generated by Cargo. +# It is not intended for manual editing. +version = 3 + +[[package]] +name = "adler" +version = "1.0.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f26201604c87b1e01bd3d98f8d5d9a8fcbb815e8cedb41ffccbeb4bf593a35fe" + +[[package]] +name = "aho-corasick" +version = "1.1.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8e60d3430d3a69478ad0993f19238d2df97c507009a52b3c10addcd7f6bcb916" +dependencies = [ + "memchr", +] + +[[package]] +name = "anstream" +version = "0.6.14" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "418c75fa768af9c03be99d17643f93f79bbba589895012a80e3452a19ddda15b" +dependencies = [ + "anstyle", + "anstyle-parse", + "anstyle-query", + "anstyle-wincon", + "colorchoice", + "is_terminal_polyfill", + "utf8parse", +] + +[[package]] +name = "anstyle" +version = "1.0.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "038dfcf04a5feb68e9c60b21c9625a54c2c0616e79b72b0fd87075a056ae1d1b" + +[[package]] +name = "anstyle-parse" +version = "0.2.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c03a11a9034d92058ceb6ee011ce58af4a9bf61491aa7e1e59ecd24bd40d22d4" +dependencies = [ + "utf8parse", +] + +[[package]] +name = "anstyle-query" +version = "1.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ad186efb764318d35165f1758e7dcef3b10628e26d41a44bc5550652e6804391" +dependencies = [ + "windows-sys 0.52.0", +] + +[[package]] +name = "anstyle-wincon" +version = "3.0.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "61a38449feb7068f52bb06c12759005cf459ee52bb4adc1d5a7c4322d716fb19" +dependencies = [ + "anstyle", + "windows-sys 0.52.0", +] + +[[package]] +name = "anyhow" +version = "1.0.86" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b3d1d046238990b9cf5bcde22a3fb3584ee5cf65fb2765f454ed428c7a0063da" + +[[package]] +name = "base64" +version = "0.13.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9e1b586273c5702936fe7b7d6896644d8be71e6314cfe09d3167c95f712589e8" + +[[package]] +name = "base64" +version = "0.22.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "72b3254f16251a8381aa12e40e3c4d2f0199f8c6508fbecb9d91f575e0fbb8c6" + +[[package]] +name = "bitflags" +version = "1.3.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bef38d45163c2f1dde094a7dfd33ccf595c92905c8f8f4fdc18d06fb1037718a" + +[[package]] +name = "bitflags" +version = "2.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b048fb63fd8b5923fc5aa7b340d8e156aec7ec02f0c78fa8a6ddc2613f6f71de" + +[[package]] +name = "bytemuck" 
+version = "1.16.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b236fc92302c97ed75b38da1f4917b5cdda4984745740f153a5d3059e48d725e" + +[[package]] +name = "bytemuck_derive" +version = "1.7.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1ee891b04274a59bd38b412188e24b849617b2e45a0fd8d057deb63e7403761b" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.68", +] + +[[package]] +name = "cc" +version = "1.0.104" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "74b6a57f98764a267ff415d50a25e6e166f3831a5071af4995296ea97d210490" + +[[package]] +name = "cfg-if" +version = "1.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd" + +[[package]] +name = "clap" +version = "4.5.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "84b3edb18336f4df585bc9aa31dd99c036dfa5dc5e9a2939a722a188f3a8970d" +dependencies = [ + "clap_builder", + "clap_derive", +] + +[[package]] +name = "clap_builder" +version = "4.5.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c1c09dd5ada6c6c78075d6fd0da3f90d8080651e2d6cc8eb2f1aaa4034ced708" +dependencies = [ + "anstream", + "anstyle", + "clap_lex", + "strsim 0.11.1", +] + +[[package]] +name = "clap_derive" +version = "4.5.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2bac35c6dafb060fd4d275d9a4ffae97917c13a6327903a8be2153cd964f7085" +dependencies = [ + "heck", + "proc-macro2", + "quote", + "syn 2.0.68", +] + +[[package]] +name = "clap_lex" +version = "0.7.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4b82cf0babdbd58558212896d1a4272303a57bdb245c2bf1147185fb45640e70" + +[[package]] +name = "colorchoice" +version = "1.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0b6a852b24ab71dffc585bcb46eaf7959d175cb865a7152e35b348d1b2960422" + +[[package]] +name = "console" +version = "0.15.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0e1f83fc076bd6dd27517eacdf25fef6c4dfe5f1d7448bafaaf3a26f13b5e4eb" +dependencies = [ + "encode_unicode", + "lazy_static", + "libc", + "unicode-width", + "windows-sys 0.52.0", +] + +[[package]] +name = "core-foundation" +version = "0.9.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "91e195e091a93c46f7102ec7818a2aa394e1e1771c3ab4825963fa03e45afb8f" +dependencies = [ + "core-foundation-sys", + "libc", +] + +[[package]] +name = "core-foundation-sys" +version = "0.8.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "06ea2b9bc92be3c2baa9334a323ebca2d6f074ff852cd1d7b11064035cd3868f" + +[[package]] +name = "crc32fast" +version = "1.4.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a97769d94ddab943e4510d138150169a2758b5ef3eb191a9ee688de3e23ef7b3" +dependencies = [ + "cfg-if", +] + +[[package]] +name = "crossbeam-deque" +version = "0.8.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "613f8cc01fe9cf1a3eb3d7f488fd2fa8388403e97039e2f73692932e291a770d" +dependencies = [ + "crossbeam-epoch", + "crossbeam-utils", +] + +[[package]] +name = "crossbeam-epoch" +version = "0.9.18" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5b82ac4a3c2ca9c3460964f020e1402edd5753411d7737aa39c3714ad1b5420e" +dependencies = [ + 
"crossbeam-utils", +] + +[[package]] +name = "crossbeam-utils" +version = "0.8.20" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "22ec99545bb0ed0ea7bb9b8e1e9122ea386ff8a48c0922e43f36d45ab09e0e80" + +[[package]] +name = "darling" +version = "0.14.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7b750cb3417fd1b327431a470f388520309479ab0bf5e323505daf0290cd3850" +dependencies = [ + "darling_core", + "darling_macro", +] + +[[package]] +name = "darling_core" +version = "0.14.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "109c1ca6e6b7f82cc233a97004ea8ed7ca123a9af07a8230878fcfda9b158bf0" +dependencies = [ + "fnv", + "ident_case", + "proc-macro2", + "quote", + "strsim 0.10.0", + "syn 1.0.109", +] + +[[package]] +name = "darling_macro" +version = "0.14.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a4aab4dbc9f7611d8b55048a3a16d2d010c2c8334e46304b40ac1cc14bf3b48e" +dependencies = [ + "darling_core", + "quote", + "syn 1.0.109", +] + +[[package]] +name = "derive_builder" +version = "0.12.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8d67778784b508018359cbc8696edb3db78160bab2c2a28ba7f56ef6932997f8" +dependencies = [ + "derive_builder_macro", +] + +[[package]] +name = "derive_builder_core" +version = "0.12.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c11bdc11a0c47bc7d37d582b5285da6849c96681023680b906673c5707af7b0f" +dependencies = [ + "darling", + "proc-macro2", + "quote", + "syn 1.0.109", +] + +[[package]] +name = "derive_builder_macro" +version = "0.12.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ebcda35c7a396850a55ffeac740804b40ffec779b98fffbb1738f4033f0ee79e" +dependencies = [ + "derive_builder_core", + "syn 1.0.109", +] + +[[package]] +name = "dirs" +version = "5.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "44c45a9d03d6676652bcb5e724c7e988de1acad23a711b5217ab9cbecbec2225" +dependencies = [ + "dirs-sys", +] + +[[package]] +name = "dirs-sys" +version = "0.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "520f05a5cbd335fae5a99ff7a6ab8627577660ee5cfd6a94a6a929b52ff0321c" +dependencies = [ + "libc", + "option-ext", + "redox_users", + "windows-sys 0.48.0", +] + +[[package]] +name = "either" +version = "1.13.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "60b1af1c220855b6ceac025d3f6ecdd2b7c4894bfe9cd9bda4fbb4bc7c0d4cf0" + +[[package]] +name = "encode_unicode" +version = "0.3.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a357d28ed41a50f9c765dbfe56cbc04a64e53e5fc58ba79fbc34c10ef3df831f" + +[[package]] +name = "errno" +version = "0.3.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "534c5cf6194dfab3db3242765c03bbe257cf92f22b38f6bc0c58d59108a820ba" +dependencies = [ + "libc", + "windows-sys 0.52.0", +] + +[[package]] +name = "esaxx-rs" +version = "0.1.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d817e038c30374a4bcb22f94d0a8a0e216958d4c3dcde369b1439fec4bdda6e6" +dependencies = [ + "cc", +] + +[[package]] +name = "fastrand" +version = "2.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9fc0510504f03c51ada170672ac806f1f105a88aa97a5281117e1ddc3368e51a" + +[[package]] +name = "flate2" +version = "1.0.30" +source = 
"registry+https://github.com/rust-lang/crates.io-index" +checksum = "5f54427cfd1c7829e2a139fcefea601bf088ebca651d2bf53ebc600eac295dae" +dependencies = [ + "crc32fast", + "miniz_oxide", +] + +[[package]] +name = "fnv" +version = "1.0.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3f9eec918d3f24069decb9af1554cad7c880e2da24a9afd88aca000531ab82c1" + +[[package]] +name = "foreign-types" +version = "0.3.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f6f339eb8adc052cd2ca78910fda869aefa38d22d5cb648e6485e4d3fc06f3b1" +dependencies = [ + "foreign-types-shared", +] + +[[package]] +name = "foreign-types-shared" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "00b0228411908ca8685dba7fc2cdd70ec9990a6e753e89b6ac91a84c40fbaf4b" + +[[package]] +name = "form_urlencoded" +version = "1.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e13624c2627564efccf4934284bdd98cbaa14e79b0b5a141218e507b3a823456" +dependencies = [ + "percent-encoding", +] + +[[package]] +name = "getrandom" +version = "0.2.15" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c4567c8db10ae91089c99af84c68c38da3ec2f087c3f82960bcdbf3656b6f4d7" +dependencies = [ + "cfg-if", + "libc", + "wasi", +] + +[[package]] +name = "heck" +version = "0.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2304e00983f87ffb38b55b444b5e3b60a884b5d30c0fca7d82fe33449bbe55ea" + +[[package]] +name = "hf-hub" +version = "0.3.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2b780635574b3d92f036890d8373433d6f9fc7abb320ee42a5c25897fc8ed732" +dependencies = [ + "dirs", + "indicatif", + "log", + "native-tls", + "rand", + "serde", + "serde_json", + "thiserror", + "ureq", +] + +[[package]] +name = "ident_case" +version = "1.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b9e0384b61958566e926dc50660321d12159025e767c18e043daf26b70104c39" + +[[package]] +name = "idna" +version = "0.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "634d9b1461af396cad843f47fdba5597a4f9e6ddd4bfb6ff5d85028c25cb12f6" +dependencies = [ + "unicode-bidi", + "unicode-normalization", +] + +[[package]] +name = "indicatif" +version = "0.17.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "763a5a8f45087d6bcea4222e7b72c291a054edf80e4ef6efd2a4979878c7bea3" +dependencies = [ + "console", + "instant", + "number_prefix", + "portable-atomic", + "unicode-width", +] + +[[package]] +name = "instant" +version = "0.1.13" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e0242819d153cba4b4b05a5a8f2a7e9bbf97b6055b2a002b395c96b5ff3c0222" +dependencies = [ + "cfg-if", +] + +[[package]] +name = "is_terminal_polyfill" +version = "1.70.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f8478577c03552c21db0e2724ffb8986a5ce7af88107e6be5d2ee6e158c12800" + +[[package]] +name = "itertools" +version = "0.11.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b1c173a5686ce8bfa551b3563d0c2170bf24ca44da99c7ca4bfdab5418c3fe57" +dependencies = [ + "either", +] + +[[package]] +name = "itertools" +version = "0.12.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ba291022dbbd398a455acf126c1e341954079855bc60dfdda641363bd6922569" +dependencies = [ + "either", +] + +[[package]] 
+name = "itoa" +version = "1.0.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "49f1f14873335454500d59611f1cf4a4b0f786f9ac11f4312a78e4cf2566695b" + +[[package]] +name = "lazy_static" +version = "1.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bbd2bcb4c963f2ddae06a2efc7e9f3591312473c50c6685e1f298068316e66fe" + +[[package]] +name = "libc" +version = "0.2.155" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "97b3888a4aecf77e811145cadf6eef5901f4782c53886191b2f693f24761847c" + +[[package]] +name = "libredox" +version = "0.1.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c0ff37bd590ca25063e35af745c343cb7a0271906fb7b37e4813e8f79f00268d" +dependencies = [ + "bitflags 2.6.0", + "libc", +] + +[[package]] +name = "linux-raw-sys" +version = "0.4.14" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "78b3ae25bc7c8c38cec158d1f2757ee79e9b3740fbc7ccf0e59e4b08d793fa89" + +[[package]] +name = "log" +version = "0.4.22" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a7a70ba024b9dc04c27ea2f0c0548feb474ec5c54bba33a7f72f873a39d07b24" + +[[package]] +name = "macro_rules_attribute" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8a82271f7bc033d84bbca59a3ce3e4159938cb08a9c3aebbe54d215131518a13" +dependencies = [ + "macro_rules_attribute-proc_macro", + "paste", +] + +[[package]] +name = "macro_rules_attribute-proc_macro" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b8dd856d451cc0da70e2ef2ce95a18e39a93b7558bedf10201ad28503f918568" + +[[package]] +name = "memchr" +version = "2.7.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "78ca9ab1a0babb1e7d5695e3530886289c18cf2f87ec19a575a0abdce112e3a3" + +[[package]] +name = "minimal-lexical" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "68354c5c6bd36d73ff3feceb05efa59b6acb7626617f4962be322a825e61f79a" + +[[package]] +name = "miniz_oxide" +version = "0.7.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b8a240ddb74feaf34a79a7add65a741f3167852fba007066dcac1ca548d89c08" +dependencies = [ + "adler", +] + +[[package]] +name = "monostate" +version = "0.1.13" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0d208407d7552cd041d8cdb69a1bc3303e029c598738177a3d87082004dc0e1e" +dependencies = [ + "monostate-impl", + "serde", +] + +[[package]] +name = "monostate-impl" +version = "0.1.13" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a7ce64b975ed4f123575d11afd9491f2e37bbd5813fbfbc0f09ae1fbddea74e0" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.68", +] + +[[package]] +name = "native-tls" +version = "0.2.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a8614eb2c83d59d1c8cc974dd3f920198647674a0a035e1af1fa58707e317466" +dependencies = [ + "libc", + "log", + "openssl", + "openssl-probe", + "openssl-sys", + "schannel", + "security-framework", + "security-framework-sys", + "tempfile", +] + +[[package]] +name = "nom" +version = "7.1.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d273983c5a657a70a3e8f2a01329822f3b8c8172b73826411a55751e404a0a4a" +dependencies = [ + "memchr", + "minimal-lexical", +] + +[[package]] +name = "number_prefix" +version = 
"0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "830b246a0e5f20af87141b25c173cd1b609bd7779a4617d6ec582abaf90870f3" + +[[package]] +name = "once_cell" +version = "1.19.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3fdb12b2476b595f9358c5161aa467c2438859caa136dec86c26fdd2efe17b92" + +[[package]] +name = "onig" +version = "6.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8c4b31c8722ad9171c6d77d3557db078cab2bd50afcc9d09c8b315c59df8ca4f" +dependencies = [ + "bitflags 1.3.2", + "libc", + "once_cell", + "onig_sys", +] + +[[package]] +name = "onig_sys" +version = "69.8.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7b829e3d7e9cc74c7e315ee8edb185bf4190da5acde74afd7fc59c35b1f086e7" +dependencies = [ + "cc", + "pkg-config", +] + +[[package]] +name = "openssl" +version = "0.10.64" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "95a0481286a310808298130d22dd1fef0fa571e05a8f44ec801801e84b216b1f" +dependencies = [ + "bitflags 2.6.0", + "cfg-if", + "foreign-types", + "libc", + "once_cell", + "openssl-macros", + "openssl-sys", +] + +[[package]] +name = "openssl-macros" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a948666b637a0f465e8564c73e89d4dde00d72d4d473cc972f390fc3dcee7d9c" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.68", +] + +[[package]] +name = "openssl-probe" +version = "0.1.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ff011a302c396a5197692431fc1948019154afc178baf7d8e37367442a4601cf" + +[[package]] +name = "openssl-sys" +version = "0.9.102" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c597637d56fbc83893a35eb0dd04b2b8e7a50c91e64e9493e398b5df4fb45fa2" +dependencies = [ + "cc", + "libc", + "pkg-config", + "vcpkg", +] + +[[package]] +name = "option-ext" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "04744f49eae99ab78e0d5c0b603ab218f515ea8cfe5a456d7629ad883a3b6e7d" + +[[package]] +name = "paste" +version = "1.0.15" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "57c0d7b74b563b49d38dae00a0c37d4d6de9b432382b2892f0574ddcae73fd0a" + +[[package]] +name = "percent-encoding" +version = "2.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e3148f5046208a5d56bcfc03053e3ca6334e51da8dfb19b6cdc8b306fae3283e" + +[[package]] +name = "pkg-config" +version = "0.3.30" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d231b230927b5e4ad203db57bbcbee2802f6bce620b1e4a9024a07d94e2907ec" + +[[package]] +name = "portable-atomic" +version = "1.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7170ef9988bc169ba16dd36a7fa041e5c4cbeb6a35b76d4c03daded371eae7c0" + +[[package]] +name = "ppv-lite86" +version = "0.2.17" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5b40af805b3121feab8a3c29f04d8ad262fa8e0561883e7653e024ae4479e6de" + +[[package]] +name = "proc-macro2" +version = "1.0.86" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5e719e8df665df0d1c8fbfd238015744736151d4445ec0836b8e628aae103b77" +dependencies = [ + "unicode-ident", +] + +[[package]] +name = "quote" +version = "1.0.36" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"0fa76aaf39101c457836aec0ce2316dbdc3ab723cdda1c6bd4e6ad4208acaca7" +dependencies = [ + "proc-macro2", +] + +[[package]] +name = "rand" +version = "0.8.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "34af8d1a0e25924bc5b7c43c079c942339d8f0a8b57c39049bef581b46327404" +dependencies = [ + "libc", + "rand_chacha", + "rand_core", +] + +[[package]] +name = "rand_chacha" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e6c10a63a0fa32252be49d21e7709d4d4baf8d231c2dbce1eaa8141b9b127d88" +dependencies = [ + "ppv-lite86", + "rand_core", +] + +[[package]] +name = "rand_core" +version = "0.6.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ec0be4795e2f6a28069bec0b5ff3e2ac9bafc99e6a9a7dc3547996c5c816922c" +dependencies = [ + "getrandom", +] + +[[package]] +name = "rayon" +version = "1.10.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b418a60154510ca1a002a752ca9714984e21e4241e804d32555251faf8b78ffa" +dependencies = [ + "either", + "rayon-core", +] + +[[package]] +name = "rayon-cond" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "059f538b55efd2309c9794130bc149c6a553db90e9d99c2030785c82f0bd7df9" +dependencies = [ + "either", + "itertools 0.11.0", + "rayon", +] + +[[package]] +name = "rayon-core" +version = "1.12.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1465873a3dfdaa8ae7cb14b4383657caab0b3e8a0aa9ae8e04b044854c8dfce2" +dependencies = [ + "crossbeam-deque", + "crossbeam-utils", +] + +[[package]] +name = "redox_users" +version = "0.4.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bd283d9651eeda4b2a83a43c1c91b266c40fd76ecd39a50a8c630ae69dc72891" +dependencies = [ + "getrandom", + "libredox", + "thiserror", +] + +[[package]] +name = "regex" +version = "1.10.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b91213439dad192326a0d7c6ee3955910425f441d7038e0d6933b0aec5c4517f" +dependencies = [ + "aho-corasick", + "memchr", + "regex-automata", + "regex-syntax", +] + +[[package]] +name = "regex-automata" +version = "0.4.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "38caf58cc5ef2fed281f89292ef23f6365465ed9a41b7a7754eb4e26496c92df" +dependencies = [ + "aho-corasick", + "memchr", + "regex-syntax", +] + +[[package]] +name = "regex-syntax" +version = "0.8.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7a66a03ae7c801facd77a29370b4faec201768915ac14a721ba36f20bc9c209b" + +[[package]] +name = "ring" +version = "0.17.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c17fa4cb658e3583423e915b9f3acc01cceaee1860e33d59ebae66adc3a2dc0d" +dependencies = [ + "cc", + "cfg-if", + "getrandom", + "libc", + "spin", + "untrusted", + "windows-sys 0.52.0", +] + +[[package]] +name = "rustc-hash" +version = "2.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "583034fd73374156e66797ed8e5b0d5690409c9226b22d87cb7f19821c05d152" + +[[package]] +name = "rustix" +version = "0.38.34" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "70dc5ec042f7a43c4a73241207cecc9873a06d45debb38b329f8541d85c2730f" +dependencies = [ + "bitflags 2.6.0", + "errno", + "libc", + "linux-raw-sys", + "windows-sys 0.52.0", +] + +[[package]] +name = "rustls" +version = "0.22.4" +source = 
"registry+https://github.com/rust-lang/crates.io-index" +checksum = "bf4ef73721ac7bcd79b2b315da7779d8fc09718c6b3d2d1b2d94850eb8c18432" +dependencies = [ + "log", + "ring", + "rustls-pki-types", + "rustls-webpki", + "subtle", + "zeroize", +] + +[[package]] +name = "rustls-pki-types" +version = "1.7.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "976295e77ce332211c0d24d92c0e83e50f5c5f046d11082cea19f3df13a3562d" + +[[package]] +name = "rustls-webpki" +version = "0.102.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f9a6fccd794a42c2c105b513a2f62bc3fd8f3ba57a4593677ceb0bd035164d78" +dependencies = [ + "ring", + "rustls-pki-types", + "untrusted", +] + +[[package]] +name = "ryu" +version = "1.0.18" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f3cb5ba0dc43242ce17de99c180e96db90b235b8a9fdc9543c96d2209116bd9f" + +[[package]] +name = "schannel" +version = "0.1.23" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fbc91545643bcf3a0bbb6569265615222618bdf33ce4ffbbd13c4bbd4c093534" +dependencies = [ + "windows-sys 0.52.0", +] + +[[package]] +name = "security-framework" +version = "2.11.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c627723fd09706bacdb5cf41499e95098555af3c3c29d014dc3c458ef6be11c0" +dependencies = [ + "bitflags 2.6.0", + "core-foundation", + "core-foundation-sys", + "libc", + "security-framework-sys", +] + +[[package]] +name = "security-framework-sys" +version = "2.11.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "317936bbbd05227752583946b9e66d7ce3b489f84e11a94a510b4437fef407d7" +dependencies = [ + "core-foundation-sys", + "libc", +] + +[[package]] +name = "serde" +version = "1.0.203" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7253ab4de971e72fb7be983802300c30b5a7f0c2e56fab8abfc6a214307c0094" +dependencies = [ + "serde_derive", +] + +[[package]] +name = "serde_derive" +version = "1.0.203" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "500cbc0ebeb6f46627f50f3f5811ccf6bf00643be300b4c3eabc0ef55dc5b5ba" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.68", +] + +[[package]] +name = "serde_json" +version = "1.0.120" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4e0d21c9a8cae1235ad58a00c11cb40d4b1e5c784f1ef2c537876ed6ffd8b7c5" +dependencies = [ + "itoa", + "ryu", + "serde", +] + +[[package]] +name = "smallvec" +version = "1.13.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3c5e1a9a646d36c3599cd173a41282daf47c44583ad367b8e6837255952e5c67" + +[[package]] +name = "spin" +version = "0.9.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6980e8d7511241f8acf4aebddbb1ff938df5eebe98691418c4468d0b72a96a67" + +[[package]] +name = "spm_precompiled" +version = "0.1.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5851699c4033c63636f7ea4cf7b7c1f1bf06d0cc03cfb42e711de5a5c46cf326" +dependencies = [ + "base64 0.13.1", + "nom", + "serde", + "unicode-segmentation", +] + +[[package]] +name = "strsim" +version = "0.10.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "73473c0e59e6d5812c5dfe2a064a6444949f089e20eec9a2e5506596494e4623" + +[[package]] +name = "strsim" +version = "0.11.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"7da8b5736845d9f2fcb837ea5d9e2628564b3b043a70948a3f0b778838c5fb4f" + +[[package]] +name = "subtle" +version = "2.6.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "13c2bddecc57b384dee18652358fb23172facb8a2c51ccc10d74c157bdea3292" + +[[package]] +name = "syn" +version = "1.0.109" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "72b64191b275b66ffe2469e8af2c1cfe3bafa67b529ead792a6d0160888b4237" +dependencies = [ + "proc-macro2", + "quote", + "unicode-ident", +] + +[[package]] +name = "syn" +version = "2.0.68" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "901fa70d88b9d6c98022e23b4136f9f3e54e4662c3bc1bd1d84a42a9a0f0c1e9" +dependencies = [ + "proc-macro2", + "quote", + "unicode-ident", +] + +[[package]] +name = "tempfile" +version = "3.10.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "85b77fafb263dd9d05cbeac119526425676db3784113aa9295c88498cbf8bff1" +dependencies = [ + "cfg-if", + "fastrand", + "rustix", + "windows-sys 0.52.0", +] + +[[package]] +name = "thiserror" +version = "1.0.61" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c546c80d6be4bc6a00c0f01730c08df82eaa7a7a61f11d656526506112cc1709" +dependencies = [ + "thiserror-impl", +] + +[[package]] +name = "thiserror-impl" +version = "1.0.61" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "46c3384250002a6d5af4d114f2845d37b57521033f30d5c3f46c4d70e1197533" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.68", +] + +[[package]] +name = "tinyvec" +version = "1.7.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ce6b6a2fb3a985e99cebfaefa9faa3024743da73304ca1c683a36429613d3d22" +dependencies = [ + "tinyvec_macros", +] + +[[package]] +name = "tinyvec_macros" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1f3ccbac311fea05f86f61904b462b55fb3df8837a366dfc601a0161d0532f20" + +[[package]] +name = "tokenizers" +version = "0.15.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3dd47962b0ba36e7fd33518fbf1754d136fd1474000162bbf2a8b5fcb2d3654d" +dependencies = [ + "aho-corasick", + "clap", + "derive_builder", + "esaxx-rs", + "getrandom", + "hf-hub", + "indicatif", + "itertools 0.12.1", + "lazy_static", + "log", + "macro_rules_attribute", + "monostate", + "onig", + "paste", + "rand", + "rayon", + "rayon-cond", + "regex", + "regex-syntax", + "serde", + "serde_json", + "spm_precompiled", + "thiserror", + "unicode-normalization-alignments", + "unicode-segmentation", + "unicode_categories", +] + +[[package]] +name = "toktrie" +version = "0.1.0" +dependencies = [ + "anyhow", + "bytemuck", + "bytemuck_derive", + "rustc-hash", + "serde", + "serde_json", +] + +[[package]] +name = "toktrie_hf_tokenizers" +version = "0.1.0" +dependencies = [ + "anyhow", + "log", + "rustc-hash", + "serde", + "serde_json", + "tokenizers", + "toktrie", +] + +[[package]] +name = "unicode-bidi" +version = "0.3.15" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "08f95100a766bf4f8f28f90d77e0a5461bbdb219042e7679bebe79004fed8d75" + +[[package]] +name = "unicode-ident" +version = "1.0.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3354b9ac3fae1ff6755cb6db53683adb661634f67557942dea4facebec0fee4b" + +[[package]] +name = "unicode-normalization" +version = "0.1.23" +source = 
"registry+https://github.com/rust-lang/crates.io-index" +checksum = "a56d1686db2308d901306f92a263857ef59ea39678a5458e7cb17f01415101f5" +dependencies = [ + "tinyvec", +] + +[[package]] +name = "unicode-normalization-alignments" +version = "0.1.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "43f613e4fa046e69818dd287fdc4bc78175ff20331479dab6e1b0f98d57062de" +dependencies = [ + "smallvec", +] + +[[package]] +name = "unicode-segmentation" +version = "1.11.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d4c87d22b6e3f4a18d4d40ef354e97c90fcb14dd91d7dc0aa9d8a1172ebf7202" + +[[package]] +name = "unicode-width" +version = "0.1.13" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0336d538f7abc86d282a4189614dfaa90810dfc2c6f6427eaf88e16311dd225d" + +[[package]] +name = "unicode_categories" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "39ec24b3121d976906ece63c9daad25b85969647682eee313cb5779fdd69e14e" + +[[package]] +name = "untrusted" +version = "0.9.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8ecb6da28b8a351d773b68d5825ac39017e680750f980f3a1a85cd8dd28a47c1" + +[[package]] +name = "ureq" +version = "2.9.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d11a831e3c0b56e438a28308e7c810799e3c118417f342d30ecec080105395cd" +dependencies = [ + "base64 0.22.1", + "flate2", + "log", + "native-tls", + "once_cell", + "rustls", + "rustls-pki-types", + "rustls-webpki", + "serde", + "serde_json", + "url", + "webpki-roots", +] + +[[package]] +name = "url" +version = "2.5.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "22784dbdf76fdde8af1aeda5622b546b422b6fc585325248a2bf9f5e41e94d6c" +dependencies = [ + "form_urlencoded", + "idna", + "percent-encoding", +] + +[[package]] +name = "utf8parse" +version = "0.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "06abde3611657adf66d383f00b093d7faecc7fa57071cce2578660c9f1010821" + +[[package]] +name = "vcpkg" +version = "0.2.15" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "accd4ea62f7bb7a82fe23066fb0957d48ef677f6eeb8215f372f52e48bb32426" + +[[package]] +name = "wasi" +version = "0.11.0+wasi-snapshot-preview1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9c8d87e72b64a3b4db28d11ce29237c246188f4f51057d65a7eab63b7987e423" + +[[package]] +name = "webpki-roots" +version = "0.26.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bd7c23921eeb1713a4e851530e9b9756e4fb0e89978582942612524cf09f01cd" +dependencies = [ + "rustls-pki-types", +] + +[[package]] +name = "windows-sys" +version = "0.48.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "677d2418bec65e3338edb076e806bc1ec15693c5d0104683f2efe857f61056a9" +dependencies = [ + "windows-targets 0.48.5", +] + +[[package]] +name = "windows-sys" +version = "0.52.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "282be5f36a8ce781fad8c8ae18fa3f9beff57ec1b52cb3de0789201425d9a33d" +dependencies = [ + "windows-targets 0.52.6", +] + +[[package]] +name = "windows-targets" +version = "0.48.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9a2fa6e2155d7247be68c096456083145c183cbbbc2764150dda45a87197940c" +dependencies = [ + "windows_aarch64_gnullvm 0.48.5", + "windows_aarch64_msvc 
0.48.5", + "windows_i686_gnu 0.48.5", + "windows_i686_msvc 0.48.5", + "windows_x86_64_gnu 0.48.5", + "windows_x86_64_gnullvm 0.48.5", + "windows_x86_64_msvc 0.48.5", +] + +[[package]] +name = "windows-targets" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9b724f72796e036ab90c1021d4780d4d3d648aca59e491e6b98e725b84e99973" +dependencies = [ + "windows_aarch64_gnullvm 0.52.6", + "windows_aarch64_msvc 0.52.6", + "windows_i686_gnu 0.52.6", + "windows_i686_gnullvm", + "windows_i686_msvc 0.52.6", + "windows_x86_64_gnu 0.52.6", + "windows_x86_64_gnullvm 0.52.6", + "windows_x86_64_msvc 0.52.6", +] + +[[package]] +name = "windows_aarch64_gnullvm" +version = "0.48.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2b38e32f0abccf9987a4e3079dfb67dcd799fb61361e53e2882c3cbaf0d905d8" + +[[package]] +name = "windows_aarch64_gnullvm" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "32a4622180e7a0ec044bb555404c800bc9fd9ec262ec147edd5989ccd0c02cd3" + +[[package]] +name = "windows_aarch64_msvc" +version = "0.48.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dc35310971f3b2dbbf3f0690a219f40e2d9afcf64f9ab7cc1be722937c26b4bc" + +[[package]] +name = "windows_aarch64_msvc" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "09ec2a7bb152e2252b53fa7803150007879548bc709c039df7627cabbd05d469" + +[[package]] +name = "windows_i686_gnu" +version = "0.48.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a75915e7def60c94dcef72200b9a8e58e5091744960da64ec734a6c6e9b3743e" + +[[package]] +name = "windows_i686_gnu" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8e9b5ad5ab802e97eb8e295ac6720e509ee4c243f69d781394014ebfe8bbfa0b" + +[[package]] +name = "windows_i686_gnullvm" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0eee52d38c090b3caa76c563b86c3a4bd71ef1a819287c19d586d7334ae8ed66" + +[[package]] +name = "windows_i686_msvc" +version = "0.48.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8f55c233f70c4b27f66c523580f78f1004e8b5a8b659e05a4eb49d4166cca406" + +[[package]] +name = "windows_i686_msvc" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "240948bc05c5e7c6dabba28bf89d89ffce3e303022809e73deaefe4f6ec56c66" + +[[package]] +name = "windows_x86_64_gnu" +version = "0.48.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "53d40abd2583d23e4718fddf1ebec84dbff8381c07cae67ff7768bbf19c6718e" + +[[package]] +name = "windows_x86_64_gnu" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "147a5c80aabfbf0c7d901cb5895d1de30ef2907eb21fbbab29ca94c5b08b1a78" + +[[package]] +name = "windows_x86_64_gnullvm" +version = "0.48.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0b7b52767868a23d5bab768e390dc5f5c55825b6d30b86c844ff2dc7414044cc" + +[[package]] +name = "windows_x86_64_gnullvm" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "24d5b23dc417412679681396f2b49f3de8c1473deb516bd34410872eff51ed0d" + +[[package]] +name = "windows_x86_64_msvc" +version = "0.48.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"ed94fce61571a4006852b7389a063ab983c02eb1bb37b47f8272ce92d06d9538" + +[[package]] +name = "windows_x86_64_msvc" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "589f6da84c646204747d1270a2a5661ea66ed1cced2631d546fdfb155959f9ec" + +[[package]] +name = "zeroize" +version = "1.8.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ced3678a2879b30306d323f4542626697a464a97c0a07c9aebf7ebca65cd4dde" diff --git a/hf_tokenizers/Cargo.toml b/hf_tokenizers/Cargo.toml index 5de529c9..bef35682 100644 --- a/hf_tokenizers/Cargo.toml +++ b/hf_tokenizers/Cargo.toml @@ -4,7 +4,7 @@ version = "0.1.0" edition = "2021" [dependencies] -toktrie = { path = "../toktrie" } +toktrie = { path = "../core" } serde = { version = "1.0.192", features = ["derive"] } serde_json = "1.0.108" anyhow = "1.0.75" From d2cfea79efb474e13b8c11a87a0c96b70bd2ec60 Mon Sep 17 00:00:00 2001 From: Michal Moskal Date: Fri, 5 Jul 2024 17:24:41 -0700 Subject: [PATCH 257/301] add build file --- .github/workflows/rust.yml | 24 ++++++++++++++++++++++++ 1 file changed, 24 insertions(+) create mode 100644 .github/workflows/rust.yml diff --git a/.github/workflows/rust.yml b/.github/workflows/rust.yml new file mode 100644 index 00000000..2df7c1fc --- /dev/null +++ b/.github/workflows/rust.yml @@ -0,0 +1,24 @@ +name: Rust + +on: + push: + branches: [ "main" ] + pull_request: + branches: [ "main" ] + +env: + CARGO_TERM_COLOR: always + +jobs: + build: + + runs-on: ubuntu-latest + + steps: + - uses: actions/checkout@v4 + - name: Build core + run: cargo build --verbose + working-directory: core + - name: Build for hf-tokenizers + run: cargo build --verbose + working-directory: hf_tokenizers From e096d8bb8d612f7e0c2a504a3d3ba535d181abdb Mon Sep 17 00:00:00 2001 From: Michal Moskal Date: Mon, 8 Jul 2024 22:56:04 +0000 Subject: [PATCH 258/301] add utility methods on Branch --- core/src/lib.rs | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/core/src/lib.rs b/core/src/lib.rs index 1e2bba29..ac05223c 100644 --- a/core/src/lib.rs +++ b/core/src/lib.rs @@ -99,6 +99,15 @@ impl Branch { } } + pub fn has_backtrack(&self) -> bool { + let max_bt = if self.sample_mask.is_none() { 0 } else { 1 }; + self.splices.iter().any(|s| s.backtrack > max_bt) + } + + pub fn has_ff_tokens(&self) -> bool { + self.splices.len() > 0 + } + pub fn stop() -> Self { Branch { sample_mask: None, @@ -136,4 +145,4 @@ impl Branch { } } -pub type StepResult = Branch; \ No newline at end of file +pub type StepResult = Branch; From 0af6a2806e931bebe8646876269cafd0b5bed64b Mon Sep 17 00:00:00 2001 From: Michal Moskal Date: Wed, 10 Jul 2024 02:50:44 +0000 Subject: [PATCH 259/301] Add `sampled` field to `StepArg` struct --- core/src/lib.rs | 2 ++ 1 file changed, 2 insertions(+) diff --git a/core/src/lib.rs b/core/src/lib.rs index ac05223c..f703c306 100644 --- a/core/src/lib.rs +++ b/core/src/lib.rs @@ -17,6 +17,8 @@ pub struct StepArg { /// Can be more complex when splices are used. pub backtrack: u32, pub tokens: Vec, + /// The token that was sampled (after applying the mask), before any splicing. 
+ pub sampled: Option, } impl StepArg { From a13ea99a9327761086537d7819211c0122d52540 Mon Sep 17 00:00:00 2001 From: Michal Moskal Date: Wed, 10 Jul 2024 18:46:13 +0000 Subject: [PATCH 260/301] add utility functions --- core/src/lib.rs | 34 +++++++++++++++++++++++++++++++++- 1 file changed, 33 insertions(+), 1 deletion(-) diff --git a/core/src/lib.rs b/core/src/lib.rs index f703c306..de86acb4 100644 --- a/core/src/lib.rs +++ b/core/src/lib.rs @@ -9,7 +9,7 @@ mod toktree; pub use svob::{SimpleVob, SimpleVobIter}; pub use toktree::{Recognizer, SpecialToken, TokRxInfo, TokTrie, TokenId, TokenizerEnv}; -#[derive(Serialize, Deserialize, Debug)] +#[derive(Serialize, Deserialize, Debug, Clone)] pub struct StepArg { /// Sampling result for the previous iteration. /// For simple sampled token 't', backtrack==0 and tokens==[t]. @@ -22,6 +22,14 @@ pub struct StepArg { } impl StepArg { + pub fn empty() -> Self { + StepArg { + backtrack: 0, + tokens: vec![], + sampled: None, + } + } + pub fn save_tokens(&self, acc_tokens: &mut Vec) { let bt = self.backtrack as usize; assert!( @@ -101,6 +109,30 @@ impl Branch { } } + pub fn find_splice(&self, sampled: TokenId) -> Option<&Splice> { + self.splices + .iter() + .find(|s| s.when_sampled.is_empty() || s.when_sampled.contains(&sampled)) + } + + pub fn spliced(&self, sampled: TokenId) -> Splice { + self.find_splice(sampled).cloned().unwrap_or_else(|| { + Splice { + when_sampled: vec![], + backtrack: 0, + ff_tokens: vec![sampled], + } + }) + } + + pub fn unconditional_splice(&self) -> Option<&Splice> { + if self.splices.len() == 1 && self.splices[0].when_sampled.is_empty() { + Some(&self.splices[0]) + } else { + None + } + } + pub fn has_backtrack(&self) -> bool { let max_bt = if self.sample_mask.is_none() { 0 } else { 1 }; self.splices.iter().any(|s| s.backtrack > max_bt) From 2bf689038069a954293e398ee40b028f1db0a32b Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Mon, 22 Jul 2024 18:15:36 +0000 Subject: [PATCH 261/301] Bump openssl from 0.10.64 to 0.10.66 in /hf_tokenizers Bumps [openssl](https://github.com/sfackler/rust-openssl) from 0.10.64 to 0.10.66. - [Release notes](https://github.com/sfackler/rust-openssl/releases) - [Commits](https://github.com/sfackler/rust-openssl/compare/openssl-v0.10.64...openssl-v0.10.66) --- updated-dependencies: - dependency-name: openssl dependency-type: indirect ... 
Signed-off-by: dependabot[bot] --- hf_tokenizers/Cargo.lock | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/hf_tokenizers/Cargo.lock b/hf_tokenizers/Cargo.lock index 3adb495c..ca52adf7 100644 --- a/hf_tokenizers/Cargo.lock +++ b/hf_tokenizers/Cargo.lock @@ -655,9 +655,9 @@ dependencies = [ [[package]] name = "openssl" -version = "0.10.64" +version = "0.10.66" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "95a0481286a310808298130d22dd1fef0fa571e05a8f44ec801801e84b216b1f" +checksum = "9529f4786b70a3e8c61e11179af17ab6188ad8d0ded78c5529441ed39d4bd9c1" dependencies = [ "bitflags 2.6.0", "cfg-if", @@ -687,9 +687,9 @@ checksum = "ff011a302c396a5197692431fc1948019154afc178baf7d8e37367442a4601cf" [[package]] name = "openssl-sys" -version = "0.9.102" +version = "0.9.103" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c597637d56fbc83893a35eb0dd04b2b8e7a50c91e64e9493e398b5df4fb45fa2" +checksum = "7f9e8deee91df40a943c71b917e5874b951d32a802526c85721ce3b776c929d6" dependencies = [ "cc", "libc", From 420a91ca48fae6b9395036e2e9468758051c761c Mon Sep 17 00:00:00 2001 From: Michal Moskal Date: Mon, 22 Jul 2024 14:37:53 -0700 Subject: [PATCH 262/301] Move InferenceCapabilities from aici --- core/src/lib.rs | 21 +++++++++++++++++++++ 1 file changed, 21 insertions(+) diff --git a/core/src/lib.rs b/core/src/lib.rs index de86acb4..e39bf9b0 100644 --- a/core/src/lib.rs +++ b/core/src/lib.rs @@ -9,6 +9,27 @@ mod toktree; pub use svob::{SimpleVob, SimpleVobIter}; pub use toktree::{Recognizer, SpecialToken, TokRxInfo, TokTrie, TokenId, TokenizerEnv}; +/// Defines what is allowed in Branch +#[derive(Serialize, Deserialize, Clone, Debug, Default)] +pub struct InferenceCapabilities { + /// Unconditional splice is allowed. + #[serde(default)] + pub ff_tokens: bool, + + /// Conditional (and unconditional) splices are allowed. + #[serde(default)] + pub conditional_ff_tokens: bool, + + /// Backtracking is allowed. + #[serde(default)] + pub backtrack: bool, + + /// More than one branch is allowed. + #[serde(default)] + pub fork: bool, +} + + #[derive(Serialize, Deserialize, Debug, Clone)] pub struct StepArg { /// Sampling result for the previous iteration. 
From 158139aff4aa0ec0c049a2870492b000bbc375b9 Mon Sep 17 00:00:00 2001
From: Michal Moskal
Date: Mon, 22 Jul 2024 16:20:21 -0700
Subject: [PATCH 263/301] chore: Add singleton_token_set method to TokTrie

---
 core/src/toktree.rs | 6 ++++++
 1 file changed, 6 insertions(+)

diff --git a/core/src/toktree.rs b/core/src/toktree.rs
index 495cd53d..54a10e1d 100644
--- a/core/src/toktree.rs
+++ b/core/src/toktree.rs
@@ -226,6 +226,12 @@ impl TokTrie {
         r
     }
 
+    pub fn singleton_token_set(&self, tok: TokenId) -> SimpleVob {
+        let mut r = self.alloc_token_set();
+        r.allow_token(tok);
+        r
+    }
+
     pub fn token_set_dbg(&self, ts: &SimpleVob) -> String {
         let max_examples = 50;
 

From ae506a08efc9d41928155d7e447011965e172aa6 Mon Sep 17 00:00:00 2001
From: Michal Moskal
Date: Wed, 24 Jul 2024 14:05:27 -0700
Subject: [PATCH 264/301] bugfix and small API usability

---
 core/src/svob.rs    |  9 ++++++++-
 core/src/toktree.rs | 14 ++++++++++----
 2 files changed, 18 insertions(+), 5 deletions(-)

diff --git a/core/src/svob.rs b/core/src/svob.rs
index 9c648827..d220c891 100644
--- a/core/src/svob.rs
+++ b/core/src/svob.rs
@@ -67,7 +67,7 @@ impl SimpleVob {
         r
     }
 
-    pub fn all_true(size: usize) -> Self {
+    pub fn alloc_ones(size: usize) -> Self {
         let mut r = Self::alloc(size);
         r.set_all(true);
         r
@@ -304,6 +304,13 @@ impl SimpleVob {
             .all(|(a, b)| *a & *b == 0)
     }
 
+    pub fn sub(&mut self, other: &SimpleVob) {
+        assert_eq!(self.size, other.size);
+        for (idx, v) in self.data.iter_mut().zip(other.data.iter()) {
+            *idx &= !*v;
+        }
+    }
+
     pub fn first_bit_set_here_and_in(&self, other: &SimpleVob) -> Option<usize> {
         assert_eq!(self.size, other.size);
         for (idx, (a, b)) in self.data.iter().zip(other.data.iter()).enumerate() {
diff --git a/core/src/toktree.rs b/core/src/toktree.rs
index 54a10e1d..f6b901e5 100644
--- a/core/src/toktree.rs
+++ b/core/src/toktree.rs
@@ -221,9 +221,7 @@ impl TokTrie {
     }
 
     pub fn alloc_token_set(&self) -> SimpleVob {
-        let mut r = SimpleVob::new();
-        r.resize(self.vocab_size() + 1);
-        r
+        SimpleVob::alloc(self.vocab_size() + 1)
     }
 
     pub fn singleton_token_set(&self, tok: TokenId) -> SimpleVob {
@@ -324,6 +322,9 @@ impl TokTrie {
     }
 
     pub fn token(&self, idx: u32) -> &[u8] {
+        if idx >= self.token_offsets.len() as u32 {
+            return &[];
+        }
         let off = self.token_offsets[idx as usize];
         let len = off & ((1 << LEN_BITS) - 1);
         let off = (off >> LEN_BITS) as usize;
@@ -711,8 +712,13 @@ impl TokTrie {
             }
         }
 
+        let n = self.child_at_bytes(self.root(), start);
+        if n.is_none() {
+            return;
+        }
+        let n = n.unwrap();
+
         r.trie_started();
-        let n = self.child_at_bytes(self.root(), start).unwrap();
         let defl_tok = self.vocab_size() as u32;
         let off = self.node_offset(n);
         let mut p = off + 1;

From b91819ad39b45b40a0b3fa0872081414154826e2 Mon Sep 17 00:00:00 2001
From: Michal Moskal
Date: Thu, 25 Jul 2024 16:40:09 -0700
Subject: [PATCH 265/301] clean up template info from SUPPORT.md

---
 SUPPORT.md | 16 ++--------------
 1 file changed, 2 insertions(+), 14 deletions(-)

diff --git a/SUPPORT.md b/SUPPORT.md
index 291d4d43..382f1b8b 100644
--- a/SUPPORT.md
+++ b/SUPPORT.md
@@ -1,13 +1,3 @@
-# TODO: The maintainer of this repo has not yet edited this file
-
-**REPO OWNER**: Do you want Customer Service & Support (CSS) support for this product/project?
-
-- **No CSS support:** Fill out this template with information about how to file issues and get help.
-- **Yes CSS support:** Fill out an intake form at [aka.ms/onboardsupport](https://aka.ms/onboardsupport). CSS will work with/help you to determine next steps.
-- **Not sure?** Fill out an intake as though the answer were "Yes". CSS will help you decide. - -*Then remove this first heading from this SUPPORT.MD file before publishing your repo.* - # Support ## How to file issues and get help @@ -16,10 +6,8 @@ This project uses GitHub Issues to track bugs and feature requests. Please searc issues before filing new issues to avoid duplicates. For new issues, file your bug or feature request as a new Issue. -For help and questions about using this project, please **REPO MAINTAINER: INSERT INSTRUCTIONS HERE -FOR HOW TO ENGAGE REPO OWNERS OR COMMUNITY FOR HELP. COULD BE A STACK OVERFLOW TAG OR OTHER -CHANNEL. WHERE WILL YOU HELP PEOPLE?**. +For help and questions about using this project, please use GitHub Discussions. ## Microsoft Support Policy -Support for this **PROJECT or PRODUCT** is limited to the resources listed above. +Support for this project is limited to the resources listed above. From d8d179ae2bbfe41fcd140d583781efd9fedfcfb0 Mon Sep 17 00:00:00 2001 From: Michal Moskal Date: Fri, 26 Jul 2024 10:44:11 -0700 Subject: [PATCH 266/301] don't expose the 'defl_tok' in add_bias() --- core/src/svob.rs | 21 +++++++++++++++++++-- core/src/toktree.rs | 2 +- 2 files changed, 20 insertions(+), 3 deletions(-) diff --git a/core/src/svob.rs b/core/src/svob.rs index d220c891..2f582744 100644 --- a/core/src/svob.rs +++ b/core/src/svob.rs @@ -73,6 +73,14 @@ impl SimpleVob { r } + pub fn alloc_with_capacity(size: usize, capacity: usize) -> Self { + let mut r = Self::new(); + assert!(size <= capacity); + r.resize(capacity); + r.size = size; + r + } + pub fn len(&self) -> usize { self.size } @@ -91,7 +99,11 @@ impl SimpleVob { pub fn to_bin_string(&self) -> String { let mut s = String::new(); for i in 0..self.size { - s.push(if self.is_allowed(i as TokenId) { '1' } else { '0' }); + s.push(if self.is_allowed(i as TokenId) { + '1' + } else { + '0' + }); } s } @@ -280,7 +292,12 @@ impl SimpleVob { pub fn or_minus(&mut self, other: &SimpleVob, minus: &SimpleVob) { assert_eq!(self.size, other.size); assert_eq!(self.size, minus.size); - for ((slf, oth), mn) in self.data.iter_mut().zip(other.data.iter()).zip(minus.data.iter()) { + for ((slf, oth), mn) in self + .data + .iter_mut() + .zip(other.data.iter()) + .zip(minus.data.iter()) + { *slf |= *oth & !*mn; } } diff --git a/core/src/toktree.rs b/core/src/toktree.rs index f6b901e5..3964b4b9 100644 --- a/core/src/toktree.rs +++ b/core/src/toktree.rs @@ -221,7 +221,7 @@ impl TokTrie { } pub fn alloc_token_set(&self) -> SimpleVob { - SimpleVob::alloc(self.vocab_size() + 1) + SimpleVob::alloc_with_capacity(self.vocab_size(), self.vocab_size() + 1) } pub fn singleton_token_set(&self, tok: TokenId) -> SimpleVob { From a3bb0afa8a133bf6ca7d79c8b2a8371294f3c9e0 Mon Sep 17 00:00:00 2001 From: Michal Moskal Date: Mon, 29 Jul 2024 18:39:44 -0700 Subject: [PATCH 267/301] optimize set_all() --- core/src/svob.rs | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/core/src/svob.rs b/core/src/svob.rs index 2f582744..114d7bac 100644 --- a/core/src/svob.rs +++ b/core/src/svob.rs @@ -258,9 +258,11 @@ impl SimpleVob { } pub fn set_all(&mut self, val: bool) { - let val = if val { !0 } else { 0 }; - self.data.iter_mut().for_each(|x| *x = val); - self.clear_excessive_bits(); + let bits = if val { !0 } else { 0 }; + self.data.iter_mut().for_each(|x| *x = bits); + if val { + self.clear_excessive_bits(); + } } pub fn apply_to(&self, logits: &mut [f32]) { From 7550e792ba9af7d22ee6a9bf4fd7631e1ca659f9 Mon Sep 17 00:00:00 2001 
From: Michal Moskal
Date: Tue, 30 Jul 2024 21:08:11 -0700
Subject: [PATCH 268/301] add utility functions/types

---
 core/src/lib.rs          |  2 +-
 core/src/toktree.rs      |  4 ++++
 hf_tokenizers/src/lib.rs | 47 +++++++++++++++++++++++++++++++++++++---
 3 files changed, 49 insertions(+), 4 deletions(-)

diff --git a/core/src/lib.rs b/core/src/lib.rs
index e39bf9b0..296e0538 100644
--- a/core/src/lib.rs
+++ b/core/src/lib.rs
@@ -7,7 +7,7 @@ mod svob;
 mod toktree;
 
 pub use svob::{SimpleVob, SimpleVobIter};
-pub use toktree::{Recognizer, SpecialToken, TokRxInfo, TokTrie, TokenId, TokenizerEnv};
+pub use toktree::{Recognizer, SpecialToken, TokRxInfo, TokTrie, TokenId, TokenizerEnv, TokEnv};
 
 /// Defines what is allowed in Branch
diff --git a/core/src/toktree.rs b/core/src/toktree.rs
index 3964b4b9..5829c57c 100644
--- a/core/src/toktree.rs
+++ b/core/src/toktree.rs
@@ -1,6 +1,8 @@
 // use 8:24 encoding - num_ch:tok_id (ch_byte:ch_off)* - 8 bytes per tree node
 // special case num_ch=0xff -> num_ch=0x100
 
+use std::sync::Arc;
+
 use anyhow::Result;
 use bytemuck_derive::{Pod, Zeroable};
 use rustc_hash::FxHashMap;
@@ -71,6 +73,8 @@ pub trait TokenizerEnv: Send {
     }
 }
 
+pub type TokEnv = Arc<dyn TokenizerEnv + Sync>;
+
 #[derive(Clone)]
 pub struct TokTrie {
     info: TokRxInfo,
diff --git a/hf_tokenizers/src/lib.rs b/hf_tokenizers/src/lib.rs
index 3e434e7a..781e78b9 100644
--- a/hf_tokenizers/src/lib.rs
+++ b/hf_tokenizers/src/lib.rs
@@ -1,9 +1,9 @@
 use anyhow::{anyhow, bail, Result};
 use rustc_hash::FxHashMap;
 use serde::{Deserialize, Serialize};
-use std::collections::BTreeMap;
-use tokenizers::{normalizers::Sequence, NormalizerWrapper, Tokenizer};
-use toktrie::{TokRxInfo, TokTrie, TokenId, TokenizerEnv};
+use std::{collections::BTreeMap, sync::Arc};
+use tokenizers::{normalizers::Sequence, FromPretrainedParameters, NormalizerWrapper, Tokenizer};
+use toktrie::{TokEnv, TokRxInfo, TokTrie, TokenId, TokenizerEnv};
 
 #[derive(Serialize, Deserialize)]
 pub struct ByteTokenizer {
@@ -39,7 +39,39 @@ fn build_char_map() -> FxHashMap<char, u8> {
     res
 }
 
+fn strip_suffix(sep: &str, s: &mut String) -> Option<String> {
+    let mut parts = s.splitn(2, sep);
+    let core = parts.next().unwrap().to_string();
+    let suff = parts.next().map(|s| s.to_string());
+    *s = core;
+    suff
+}
+
 impl ByteTokenizer {
+    pub fn from_name(name: &str) -> Result<Self> {
+        let loaded = if name.starts_with(".") || name.starts_with("/") {
+            Tokenizer::from_file(name)
+        } else {
+            let mut name2 = name.to_string();
+            let mut args = FromPretrainedParameters::default();
+            match strip_suffix("@", &mut name2) {
+                Some(s) => args.revision = s,
+                None => {}
+            }
+            Tokenizer::from_pretrained(name2, Some(args))
+        };
+
+        let tok = loaded.map_err(|e| anyhow!("error loading tokenizer: {}", e))?;
+
+        ByteTokenizer::from_tokenizer(tok)
+    }
+
+    pub fn from_file(name: &str) -> Result<Self> {
+        let tok =
+            Tokenizer::from_file(name).map_err(|e| anyhow!("error loading tokenizer: {}", e))?;
+        ByteTokenizer::from_tokenizer(tok)
+    }
+
     pub fn from_tokenizer(mut hft: Tokenizer) -> Result<Self> {
         let mut is_byte_level = false;
         let mut is_byte_fallback = false;
@@ -195,6 +227,11 @@ pub struct ByteTokenizerEnv {
 }
 
 impl ByteTokenizerEnv {
+    pub fn from_name(name: &str) -> Result<Self> {
+        let tokenizer = ByteTokenizer::from_name(name)?;
+        Ok(ByteTokenizerEnv::new(tokenizer))
+    }
+
     pub fn new(tokenizer: ByteTokenizer) -> ByteTokenizerEnv {
         let tok_trie = TokTrie::from(&tokenizer.tokrx_info(), &tokenizer.token_bytes());
         ByteTokenizerEnv {
@@ -202,6 +239,10 @@ impl ByteTokenizerEnv {
             tok_trie,
         }
     }
+
+    pub fn to_env(self) -> TokEnv {
+        Arc::new(self)
+    }
 }
 
 impl TokenizerEnv for ByteTokenizerEnv {

From 022b496ef5feebe4a0234df5ffa7c146067dc55d Mon Sep 17 00:00:00 2001
From: Michal Moskal
Date: Tue, 6 Aug 2024 17:20:59 +0000
Subject: [PATCH 269/301] chat mode support

---
 core/src/toktree.rs      | 56 +++++++++++++++++++++++++++++++++++++---
 hf_tokenizers/src/lib.rs | 22 +++++++---------
 2 files changed, 62 insertions(+), 16 deletions(-)

diff --git a/core/src/toktree.rs b/core/src/toktree.rs
index 5829c57c..643d14d0 100644
--- a/core/src/toktree.rs
+++ b/core/src/toktree.rs
@@ -16,9 +16,50 @@ pub type TokenId = u32;
 
 #[derive(Clone, Copy, PartialEq, Eq, Debug, Zeroable, Pod)]
 #[repr(C)]
+pub struct BinTokRxInfo {
+    pub vocab_size: u32,
+    pub tok_eos: TokenId,
+}
+
+#[derive(Clone, Copy, PartialEq, Eq, Debug)]
 pub struct TokRxInfo {
     pub vocab_size: u32,
     pub tok_eos: TokenId,
+    pub tok_bos: Option<TokenId>,
+    pub tok_pad: Option<TokenId>,
+    pub tok_unk: Option<TokenId>,
+    pub tok_end_of_turn: Option<TokenId>,
+}
+
+impl TokRxInfo {
+    pub fn new(vocab_size: u32, tok_eos: TokenId) -> Self {
+        TokRxInfo {
+            vocab_size,
+            tok_eos,
+            tok_bos: None,
+            tok_pad: None,
+            tok_unk: None,
+            tok_end_of_turn: None,
+        }
+    }
+
+    pub fn from_bin(info: &BinTokRxInfo) -> Self {
+        TokRxInfo {
+            vocab_size: info.vocab_size,
+            tok_eos: info.tok_eos,
+            tok_bos: None,
+            tok_pad: None,
+            tok_unk: None,
+            tok_end_of_turn: None,
+        }
+    }
+
+    pub fn to_bin(&self) -> BinTokRxInfo {
+        BinTokRxInfo {
+            vocab_size: self.vocab_size,
+            tok_eos: self.tok_eos,
+        }
+    }
 }
 
 #[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
@@ -28,6 +69,7 @@ pub enum SpecialToken {
     Separator,
     BeginningOfSentence,
     EndOfSentence,
+    EndOfTurn,
 }
 
 pub trait Recognizer {
@@ -93,7 +135,7 @@ pub struct TokTrieHeader {
     trie_bytes: u32,
     token_offset_bytes: u32,
     token_data_bytes: u32,
-    info: TokRxInfo,
+    info: BinTokRxInfo,
     align: [u32; 0],
 }
 
@@ -178,6 +220,14 @@ impl TokTrie {
         r
     }
 
+    pub fn build_chat_mode_trie(&self) -> Self {
+        let mut r = self.clone();
+        if let Some(t) = self.info.tok_end_of_turn {
+            r.info.tok_eos = t;
+        }
+        r
+    }
+
     fn finalize_ctor(&mut self) {
         for tok_id in 0..self.info.vocab_size {
             let bytes = self.token(tok_id);
@@ -447,7 +497,7 @@ impl TokTrie {
         let token_data = vec_from_bytes(&bytes[offsets_end..]);
 
         let mut r = TokTrie {
-            info: hd.info,
+            info: TokRxInfo::from_bin(&hd.info),
             token_offsets,
             token_data,
             nodes,
@@ -497,7 +547,7 @@ impl TokTrie {
             trie_bytes: trie_data.len() as u32,
             token_offset_bytes: token_offsets.len() as u32,
             token_data_bytes: trie_data.len() as u32,
-            info: self.info.clone(),
+            info: self.info.to_bin(),
             align: [],
         };
 
diff --git a/hf_tokenizers/src/lib.rs b/hf_tokenizers/src/lib.rs
index 781e78b9..73c9066a 100644
--- a/hf_tokenizers/src/lib.rs
+++ b/hf_tokenizers/src/lib.rs
@@ -1,16 +1,13 @@
 use anyhow::{anyhow, bail, Result};
 use rustc_hash::FxHashMap;
-use serde::{Deserialize, Serialize};
 use std::{collections::BTreeMap, sync::Arc};
 use tokenizers::{normalizers::Sequence, FromPretrainedParameters, NormalizerWrapper, Tokenizer};
 use toktrie::{TokEnv, TokRxInfo, TokTrie, TokenId, TokenizerEnv};
 
-#[derive(Serialize, Deserialize)]
 pub struct ByteTokenizer {
     pub hf_model: String,
     pub hf_tokenizer: Tokenizer,
-    pub eos_token: u32,
-    pub vocab_size: u32,
+    info: TokRxInfo,
     token_bytes: Vec<Vec<u8>>,
     pub special: BTreeMap<String, u32>,
 }
@@ -126,16 +123,18 @@ impl ByteTokenizer {
         let mut res = ByteTokenizer {
             hf_model: "foobar".to_string(),
-            eos_token: 0,
-            vocab_size,
+            info: TokRxInfo::new(vocab_size, 0),
             special: BTreeMap::new(),
             token_bytes: (0..vocab_size).map(|_| Vec::new()).collect(),
             hf_tokenizer: hft,
         };
 
         for (id, info) in added.iter() {
             if info.special {
                 match info.content.as_str() {
-                    "</s>" | "<|endoftext|>" | "<|end_of_text|>" => res.eos_token = *id,
+                    "</s>" | "<|endoftext|>" | "<|end_of_text|>" => res.info.tok_eos = *id,
+                    "<|end|>" | "<|eot_id|>" => res.info.tok_end_of_turn = Some(*id),
+                    "<unk>" | "<|unk|>" => res.info.tok_unk = Some(*id),
+                    "<pad>" | "<|pad|>" => res.info.tok_pad = Some(*id),
                     _ => {}
                 }
                 res.special.insert(info.content.clone(), *id);
@@ -198,24 +197,21 @@ impl ByteTokenizer {
     }
 
     pub fn tokrx_info(&self) -> TokRxInfo {
-        TokRxInfo {
-            vocab_size: self.vocab_size,
-            tok_eos: self.eos_token,
-        }
+        self.info.clone()
     }
     pub fn token_bytes(&self) -> Vec<Vec<u8>> {
         self.token_bytes.clone()
     }
 
     pub fn add_missing_tokens(&mut self, vocab_size: usize) {
-        assert!(self.vocab_size == self.token_bytes.len() as u32);
+        assert!(self.info.vocab_size == self.token_bytes.len() as u32);
         assert!(vocab_size >= self.token_bytes.len());
         assert!(vocab_size - self.token_bytes.len() <= 200);
         while self.token_bytes.len() < vocab_size {
             let idx = self.token_bytes.len();
             let name = format!("<AddedToken_{idx}>");
             self.token_bytes.push(name.as_bytes().to_vec());
-            self.vocab_size += 1;
+            self.info.vocab_size += 1;
             self.special.insert(name, idx as u32);
         }
     }

From ad0448ab7a041adb9eb239bfc635c64114d9db4a Mon Sep 17 00:00:00 2001
From: Michal Moskal
Date: Tue, 6 Aug 2024 20:47:36 +0000
Subject: [PATCH 270/301] allow vocab size override

---
 hf_tokenizers/src/lib.rs | 21 ++++++++++++++++-----
 1 file changed, 16 insertions(+), 5 deletions(-)

diff --git a/hf_tokenizers/src/lib.rs b/hf_tokenizers/src/lib.rs
index 73c9066a..61ca936c 100644
--- a/hf_tokenizers/src/lib.rs
+++ b/hf_tokenizers/src/lib.rs
@@ -223,17 +223,28 @@ pub struct ByteTokenizerEnv {
 }
 
 impl ByteTokenizerEnv {
-    pub fn from_name(name: &str) -> Result<Self> {
+    pub fn from_name(name: &str, n_vocab: Option<usize>) -> Result<Self> {
         let tokenizer = ByteTokenizer::from_name(name)?;
-        Ok(ByteTokenizerEnv::new(tokenizer))
+        ByteTokenizerEnv::new(tokenizer, n_vocab)
     }
 
-    pub fn new(tokenizer: ByteTokenizer) -> ByteTokenizerEnv {
+    pub fn new(tokenizer: ByteTokenizer, n_vocab: Option<usize>) -> Result<Self> {
+        let mut info = tokenizer.tokrx_info();
+        let mut token_bytes = tokenizer.token_bytes();
+        if let Some(n_vocab) = n_vocab {
+            if n_vocab < token_bytes.len() {
+                bail!("vocab size too small; {} vs {}", n_vocab, token_bytes.len());
+            }
+            while n_vocab > token_bytes.len() {
+                token_bytes.push(Vec::new());
+            }
+            info.vocab_size = n_vocab as u32;
+        }
         let tok_trie = TokTrie::from(&tokenizer.tokrx_info(), &tokenizer.token_bytes());
-        ByteTokenizerEnv {
+        Ok(ByteTokenizerEnv {
             tokenizer,
             tok_trie,
-        }
+        })
     }
 
     pub fn to_env(self) -> TokEnv {

From 37eef39f45f0ad104238f9979f1822d1764f93a3 Mon Sep 17 00:00:00 2001
From: Michal Moskal
Date: Tue, 6 Aug 2024 20:54:15 +0000
Subject: [PATCH 271/301] fix tokenizer construction

---
 hf_tokenizers/src/lib.rs | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/hf_tokenizers/src/lib.rs b/hf_tokenizers/src/lib.rs
index 61ca936c..c54ac598 100644
--- a/hf_tokenizers/src/lib.rs
+++ b/hf_tokenizers/src/lib.rs
@@ -240,7 +240,7 @@ impl ByteTokenizerEnv {
             }
             info.vocab_size = n_vocab as u32;
         }
-        let tok_trie = TokTrie::from(&tokenizer.tokrx_info(), &tokenizer.token_bytes());
+        let tok_trie = TokTrie::from(&info, &token_bytes);
         Ok(ByteTokenizerEnv {
             tokenizer,
             tok_trie,

From c1b18f647e95b460bd2aede214aede3b8c834f80 Mon Sep 17 00:00:00 2001
From: Michal Moskal
Date: Sat, 10 Aug 2024 00:00:09 +0000
Subject: [PATCH 272/301] add 
trie_stats --- core/src/toktree.rs | 102 ++++++++++++++++++++++++++++++++++++++++---- 1 file changed, 94 insertions(+), 8 deletions(-) diff --git a/core/src/toktree.rs b/core/src/toktree.rs index 643d14d0..98202559 100644 --- a/core/src/toktree.rs +++ b/core/src/toktree.rs @@ -754,7 +754,6 @@ impl TokTrie { ok } - #[inline(never)] pub fn add_bias(&self, r: &mut impl Recognizer, toks: &mut SimpleVob, start: &[u8]) { // all prefixes of 'start' are also allowed if start.len() > 0 { @@ -771,8 +770,20 @@ impl TokTrie { return; } let n = n.unwrap(); - r.trie_started(); + let next_pop = self.add_bias_inner(r, toks, n); + if start.len() == 0 { + // if start was non-empty, trie_finished() is supposed to clean this up + r.pop_bytes(next_pop); + } + r.trie_finished(); + // revert the fake token + let defl_tok = self.vocab_size() as u32; + toks.disallow_token(defl_tok); + } + + #[inline(never)] + fn add_bias_inner(&self, r: &mut impl Recognizer, toks: &mut SimpleVob, n: &TrieNode) -> usize { let defl_tok = self.vocab_size() as u32; let off = self.node_offset(n); let mut p = off + 1; @@ -795,13 +806,88 @@ impl TokTrie { next_pop = n.num_parents() - 1; } } - if start.len() == 0 { - // if start was non-empty, trie_finished() is supposed to clean this up - r.pop_bytes(next_pop); + next_pop + } + + fn count_until_depth(&self, depth: usize) -> usize { + let mut count = 0; + let mut stack = vec![(self.root(), 0)]; + while let Some((n, d)) = stack.pop() { + if d == depth { + continue; + } else { + for c in self.node_children(n) { + count += 1; + stack.push((c, d + 1)); + } + } } - r.trie_finished(); - // revert the fake token - toks.disallow_token(defl_tok); + count + } + + pub fn trie_stats(&self) -> String { + let mut nodes_histogram = vec![0; 256]; + + let mut token_nodes = 0; + + let n = self.root(); + let off = self.node_offset(n); + let mut p = off + 1; + let endp = off + n.subtree_size(); + while p < endp { + let n = &self.nodes[p]; + + if n.token_id().is_some() { + token_nodes += 1; + } + + let last_ch = self.next_node(n); + let mut ch_p = p + 1; + let mut num_children = 0; + + while ch_p < last_ch { + let ch = &self.nodes[ch_p]; + ch_p += ch.subtree_size(); + num_children += 1; + } + + nodes_histogram[std::cmp::min(9, num_children)] += 1; + + p += 1; + } + + let mut histogram = String::new(); + for (idx, num) in nodes_histogram.iter().enumerate() { + if *num > 0 { + if !histogram.is_empty() { + histogram.push_str(", "); + } + histogram.push_str(&format!("{}:{}", idx, num)); + } + } + + for n in self.node_children(self.root()) { + histogram.push_str(&format!( + "\n{} => {} {}", + n.byte(), + self.node_children(n).count(), + n.subtree_size() + )); + } + + for depth in 0..30 { + let count = self.count_until_depth(depth); + if count > 0 { + histogram.push_str(&format!("\ndepth {}: {} nodes", depth, count)); + } + } + + format!( + "{} nodes, {} token nodes,\n{}", + self.nodes.len(), + token_nodes, + histogram + ) } } From 6934722328ee1d3d679f95fcd5c669d47cee08f2 Mon Sep 17 00:00:00 2001 From: Michal Moskal Date: Sat, 10 Aug 2024 00:54:54 +0000 Subject: [PATCH 273/301] more stats --- core/src/toktree.rs | 22 ++++++++++++++-------- 1 file changed, 14 insertions(+), 8 deletions(-) diff --git a/core/src/toktree.rs b/core/src/toktree.rs index 98202559..02982c80 100644 --- a/core/src/toktree.rs +++ b/core/src/toktree.rs @@ -809,8 +809,9 @@ impl TokTrie { next_pop } - fn count_until_depth(&self, depth: usize) -> usize { + fn count_until_depth(&self, depth: usize) -> (usize, usize) { let mut count = 0; + let mut 
num_tokens = 0; let mut stack = vec![(self.root(), 0)]; while let Some((n, d)) = stack.pop() { if d == depth { @@ -818,11 +819,14 @@ impl TokTrie { } else { for c in self.node_children(n) { count += 1; + if c.token_id().is_some() { + num_tokens += 1; + } stack.push((c, d + 1)); } } } - count + (count, num_tokens) } pub fn trie_stats(&self) -> String { @@ -876,17 +880,19 @@ impl TokTrie { } for depth in 0..30 { - let count = self.count_until_depth(depth); - if count > 0 { - histogram.push_str(&format!("\ndepth {}: {} nodes", depth, count)); - } + let (count, num_tokens) = self.count_until_depth(depth); + histogram.push_str(&format!( + "\ndepth {}: {} nodes {} tokens", + depth, count, num_tokens + )); } format!( - "{} nodes, {} token nodes,\n{}", + "{}\n{} nodes, {} token nodes, {} token bytes", + histogram, self.nodes.len(), token_nodes, - histogram + self.token_data.len(), ) } } From 59641076bc86504317f07f99465a0f600e957fd3 Mon Sep 17 00:00:00 2001 From: Michal Moskal Date: Sat, 17 Aug 2024 00:51:25 +0000 Subject: [PATCH 274/301] add StepArg::from_splice --- core/src/lib.rs | 21 ++++++++++++++------- 1 file changed, 14 insertions(+), 7 deletions(-) diff --git a/core/src/lib.rs b/core/src/lib.rs index 296e0538..cb0edf8c 100644 --- a/core/src/lib.rs +++ b/core/src/lib.rs @@ -7,7 +7,7 @@ mod svob; mod toktree; pub use svob::{SimpleVob, SimpleVobIter}; -pub use toktree::{Recognizer, SpecialToken, TokRxInfo, TokTrie, TokenId, TokenizerEnv, TokEnv}; +pub use toktree::{Recognizer, SpecialToken, TokEnv, TokRxInfo, TokTrie, TokenId, TokenizerEnv}; /// Defines what is allowed in Branch #[derive(Serialize, Deserialize, Clone, Debug, Default)] @@ -23,13 +23,12 @@ pub struct InferenceCapabilities { /// Backtracking is allowed. #[serde(default)] pub backtrack: bool, - + /// More than one branch is allowed. #[serde(default)] pub fork: bool, } - #[derive(Serialize, Deserialize, Debug, Clone)] pub struct StepArg { /// Sampling result for the previous iteration. 
@@ -60,6 +59,14 @@ impl StepArg {
         acc_tokens.truncate(acc_tokens.len() - bt);
         acc_tokens.extend_from_slice(&self.tokens);
     }
+
+    pub fn from_splice(s: &Splice, sampled: Option<TokenId>) -> Self {
+        StepArg {
+            backtrack: s.backtrack,
+            tokens: s.ff_tokens.clone(),
+            sampled,
+        }
+    }
 }
 
 /*
@@ -137,13 +144,13 @@ impl Branch {
     }
 
     pub fn spliced(&self, sampled: TokenId) -> Splice {
-        self.find_splice(sampled).cloned().unwrap_or_else(|| {
-            Splice {
+        self.find_splice(sampled)
+            .cloned()
+            .unwrap_or_else(|| Splice {
                 when_sampled: vec![],
                 backtrack: 0,
                 ff_tokens: vec![sampled],
-            }
-        })
+            })
     }
 
     pub fn unconditional_splice(&self) -> Option<&Splice> {

From bafe0f49e4334cc82c8a0234270dfbac71697db1 Mon Sep 17 00:00:00 2001
From: Michal Moskal
Date: Fri, 16 Aug 2024 19:32:08 -0700
Subject: [PATCH 275/301] add TokTrie.sorted_tokens()

---
 core/src/lib.rs     |  4 +++-
 core/src/toktree.rs | 26 ++++++++++++++++++++++++++
 2 files changed, 29 insertions(+), 1 deletion(-)

diff --git a/core/src/lib.rs b/core/src/lib.rs
index cb0edf8c..f8785ff8 100644
--- a/core/src/lib.rs
+++ b/core/src/lib.rs
@@ -7,7 +7,9 @@ mod svob;
 mod toktree;
 
 pub use svob::{SimpleVob, SimpleVobIter};
-pub use toktree::{Recognizer, SpecialToken, TokEnv, TokRxInfo, TokTrie, TokenId, TokenizerEnv};
+pub use toktree::{
+    Recognizer, SpecialToken, TokEnv, TokRxInfo, TokTrie, TokenId, TokenizerEnv, TrieNode,
+};
diff --git a/core/src/toktree.rs b/core/src/toktree.rs
index 02982c80..bd18b881 100644
--- a/core/src/toktree.rs
+++ b/core/src/toktree.rs
@@ -809,6 +809,32 @@ impl TokTrie {
         next_pop
     }
 
+    pub fn sorted_tokens(&self) -> Vec<(u32, Vec<u8>)> {
+        let mut res = vec![];
+        let n = self.root();
+        let off = self.node_offset(n);
+        let mut p = off + 1;
+        let endp = off + n.subtree_size();
+        let mut next_pop = 0;
+        let mut bytes = vec![];
+        while p < endp {
+            bytes.drain(bytes.len() - next_pop..);
+            let n = &self.nodes[p];
+            let b = n.byte();
+            bytes.push(b);
+            if let Some(t) = n.token_id() {
+                res.push((t, bytes.clone()));
+            }
+            next_pop = if n.subtree_size() == 1 {
+                n.num_parents()
+            } else {
+                0
+            };
+            p += 1;
+        }
+        res
+    }
+
     fn count_until_depth(&self, depth: usize) -> (usize, usize) {
         let mut count = 0;
         let mut num_tokens = 0;

From 7cb20cd594428f5397f6159115d5e070e8c0935f Mon Sep 17 00:00:00 2001
From: Michal Moskal
Date: Sun, 18 Aug 2024 23:29:55 +0000
Subject: [PATCH 276/301] dial down the trie stats

---
 core/src/toktree.rs | 50 ++++++++++++++++++++++-----------------
 1 file changed, 31 insertions(+), 19 deletions(-)

diff --git a/core/src/toktree.rs b/core/src/toktree.rs
index bd18b881..0b795db9 100644
--- a/core/src/toktree.rs
+++ b/core/src/toktree.rs
@@ -887,38 +887,50 @@ impl TokTrie {
         }
 
         let mut histogram = String::new();
-        for (idx, num) in nodes_histogram.iter().enumerate() {
-            if *num > 0 {
-                if !histogram.is_empty() {
-                    histogram.push_str(", ");
+
+        if false {
+            for (idx, num) in nodes_histogram.iter().enumerate() {
+                if *num > 0 {
+                    if !histogram.is_empty() {
+                        histogram.push_str(", ");
+                    }
+                    histogram.push_str(&format!("{}:{}", idx, num));
                 }
-                histogram.push_str(&format!("{}:{}", idx, num));
             }
         }
 
-        for n in self.node_children(self.root()) {
-            histogram.push_str(&format!(
-                "\n{} => {} {}",
-                n.byte(),
-                self.node_children(n).count(),
-                n.subtree_size()
-            ));
+        if false {
+            for n in self.node_children(self.root()) {
+                histogram.push_str(&format!(
+                    "\n{} => {} {}",
+                    n.byte(),
+                    self.node_children(n).count(),
+                    n.subtree_size()
+                ));
+            }
+        }
+
+        if 
false { + for depth in 0..30 { + let (count, num_tokens) = self.count_until_depth(depth); + histogram.push_str(&format!( + "\ndepth {}: {} nodes {} tokens", + depth, count, num_tokens + )); + } } - for depth in 0..30 { - let (count, num_tokens) = self.count_until_depth(depth); - histogram.push_str(&format!( - "\ndepth {}: {} nodes {} tokens", - depth, count, num_tokens - )); + if histogram.len() > 0 { + histogram = format!("\n{}", histogram); } format!( - "{}\n{} nodes, {} token nodes, {} token bytes", + "{}{} nodes, {} token nodes, {} token bytes, {} max len", histogram, self.nodes.len(), token_nodes, self.token_data.len(), + self.max_token_len, ) } } From b217d0e4afbed8000c700335c7171c3d7011df02 Mon Sep 17 00:00:00 2001 From: Michal Moskal Date: Thu, 5 Sep 2024 15:42:15 +0000 Subject: [PATCH 277/301] add more impl notes --- implementation.md | 25 +++++++++++++++++++++++++ 1 file changed, 25 insertions(+) diff --git a/implementation.md b/implementation.md index 29d5b28c..50e7cca0 100644 --- a/implementation.md +++ b/implementation.md @@ -92,3 +92,28 @@ while p < nodes.len() { Note that the only branch that gets mis-predicted here is the `if byte_allowed(n.byte)`. The `if` in argument to `pop_bytes` is compiled to bit operations, so it is branchless. + +### Actual code + +See `add_bias_inner` in [toktree.rs](./core/src/toktree.rs). + +* it uses `try_push_byte()` which combines `byte_allowed()` and `push_byte()` +* it calls `pop_bytes()` at the beginning with a variable stored in previous iteration + +The following is a breakdown of all memory reads and writes, +when used with [llguidance](https://github.com/microsoft/llguidance), +see `try_push_byte()` in [parser.rs](https://github.com/microsoft/llguidance/blob/main/parser/src/earley/parser.rs#L1638). +This only considers the fast lexer path. + +* `pop_bytes()` - only register update (stack length) +* fetch current `TrieNode` (8 bytes) +* `try_push_byte()` - 3 reads, 1 write, see below +* updating token bit-mask - 1 read, 1 write + +The `try_push_byte()` function: + +* fetch lexer state from the stack (1 read) +* compute next DFA state: 1 read for alphabet compression if enabled, 1 read for transition table +* push lexer state to the stack (1 write) + +Together, this is 5 reads and 2 writes per node. From 76f77d124676bcad2899c29612c70d9a1b079f81 Mon Sep 17 00:00:00 2001 From: Michal Moskal Date: Thu, 5 Sep 2024 17:23:13 +0000 Subject: [PATCH 278/301] add more info on impl --- implementation.md | 31 ++++++++++++++++++++++--------- 1 file changed, 22 insertions(+), 9 deletions(-) diff --git a/implementation.md b/implementation.md index 50e7cca0..0e3f1f5f 100644 --- a/implementation.md +++ b/implementation.md @@ -97,23 +97,36 @@ The `if` in argument to `pop_bytes` is compiled to bit operations, so it is bran See `add_bias_inner` in [toktree.rs](./core/src/toktree.rs). -* it uses `try_push_byte()` which combines `byte_allowed()` and `push_byte()` -* it calls `pop_bytes()` at the beginning with a variable stored in previous iteration +- it uses `try_push_byte()` which combines `byte_allowed()` and `push_byte()` +- it calls `pop_bytes()` at the beginning with a variable stored in previous iteration The following is a breakdown of all memory reads and writes, when used with [llguidance](https://github.com/microsoft/llguidance), see `try_push_byte()` in [parser.rs](https://github.com/microsoft/llguidance/blob/main/parser/src/earley/parser.rs#L1638). This only considers the fast lexer path. 
-* `pop_bytes()` - only register update (stack length) -* fetch current `TrieNode` (8 bytes) -* `try_push_byte()` - 3 reads, 1 write, see below -* updating token bit-mask - 1 read, 1 write +- `pop_bytes()` - only register update (stack length) +- fetch current `TrieNode` (8 bytes) +- `try_push_byte()` - 3 reads, 1 write, see below +- updating token bit-mask - 1 read, 1 write The `try_push_byte()` function: -* fetch lexer state from the stack (1 read) -* compute next DFA state: 1 read for alphabet compression if enabled, 1 read for transition table -* push lexer state to the stack (1 write) +- fetch lexer state from the stack (1 read) +- compute next DFA state: 1 read for alphabet compression if enabled, 1 read for transition table +- push lexer state to the stack (1 write) Together, this is 5 reads and 2 writes per node. +There is at least one dependency chains of length 3 +(read lexer state -> compute dfa state -> write lexer state) +and another one with compression +(read byte -> compute compressed byte -> compute dfa state). + +On an AMD EPYC 7V13 a single node is processed in around 13 cycles; +this drops by 1 cycle if the alphabet compression is disabled +(likely only 1 because lexer stack fetch and alphabet compression fetch can be done in parallel). + +The 7V13 has 4 cycles L1 latency (32KB), 13 cycles L2 latency (512KB), +and 34 cycles L3 latency (16MB or so) [source](https://www.anandtech.com/show/14694/amd-rome-epyc-2nd-gen/7). + +Given the 4 cycle L1 and 3-deep dependency chain, 13 seems fairly optimal. From f4370794574a1b186c1bd472c553434fffe4aa36 Mon Sep 17 00:00:00 2001 From: Michal Moskal Date: Fri, 6 Sep 2024 15:25:32 +0000 Subject: [PATCH 279/301] update impl/hw notes --- implementation.md | 18 ++++++++++-------- 1 file changed, 10 insertions(+), 8 deletions(-) diff --git a/implementation.md b/implementation.md index 0e3f1f5f..74b5db9b 100644 --- a/implementation.md +++ b/implementation.md @@ -117,16 +117,18 @@ The `try_push_byte()` function: - push lexer state to the stack (1 write) Together, this is 5 reads and 2 writes per node. -There is at least one dependency chains of length 3 -(read lexer state -> compute dfa state -> write lexer state) -and another one with compression -(read byte -> compute compressed byte -> compute dfa state). +Dependency chain lengths are difficult to estimate, given the possible +speculation and out-of-order execution. -On an AMD EPYC 7V13 a single node is processed in around 13 cycles; +On an AMD EPYC 7V13 a single node is processed in around 13 cycles +(at 4.2 instructions per cycle); this drops by 1 cycle if the alphabet compression is disabled (likely only 1 because lexer stack fetch and alphabet compression fetch can be done in parallel). The 7V13 has 4 cycles L1 latency (32KB), 13 cycles L2 latency (512KB), -and 34 cycles L3 latency (16MB or so) [source](https://www.anandtech.com/show/14694/amd-rome-epyc-2nd-gen/7). - -Given the 4 cycle L1 and 3-deep dependency chain, 13 seems fairly optimal. +and 46 cycles L3 latency (up to 32MB per core, but shared). +It also has 6-wide uop dispatch. +Sources: +[EPYC Milan](https://www.anandtech.com/show/16529/amd-epyc-milan-review/4), +[Zen3](https://www.anandtech.com/show/16214/amd-zen-3-ryzen-deep-dive-review-5950x-5900x-5800x-and-5700x-tested/4), +[Zen2](https://www.anandtech.com/show/14694/amd-rome-epyc-2nd-gen/7) (shares L1/L2 specs). 
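
To make the per-node cost accounting in the notes above concrete, here is a minimal sketch of such a table-driven DFA step; the names (`LexerStack`, `alpha_map`, `trans`) are illustrative only, not the actual llguidance types:

```rust
// Sketch of the per-node lexer step described above, assuming a
// row-major transition table and an alphabet-compression map.
struct LexerStack {
    states: Vec<u32>,
}

impl LexerStack {
    /// Returns false if `byte` leads to the dead state (byte not allowed).
    fn try_push_byte(&mut self, alpha_map: &[u8], trans: &[u32], num_classes: usize, byte: u8) -> bool {
        let state = *self.states.last().unwrap(); // read 1: lexer state from stack
        let class = alpha_map[byte as usize] as usize; // read 2: alphabet compression
        let next = trans[state as usize * num_classes + class]; // read 3: transition table
        if next == u32::MAX {
            return false; // dead state: byte rejected, nothing pushed
        }
        self.states.push(next); // write 1: push new lexer state
        true
    }
}
```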
From 8828701d3b1c743472fe61bdf6dab12cdd726ab4 Mon Sep 17 00:00:00 2001
From: Michal Moskal
Date: Sat, 7 Sep 2024 12:42:55 -0700
Subject: [PATCH 280/301] add utility methods

---
 core/src/lib.rs | 26 ++++++++++++++++++++++++++
 1 file changed, 26 insertions(+)

diff --git a/core/src/lib.rs b/core/src/lib.rs
index f8785ff8..73f98066 100644
--- a/core/src/lib.rs
+++ b/core/src/lib.rs
@@ -69,6 +69,14 @@ impl StepArg {
             sampled,
         }
     }
+
+    pub fn from_sampled_token(tok: TokenId) -> Self {
+        StepArg {
+            backtrack: 0,
+            tokens: vec![tok],
+            sampled: Some(tok),
+        }
+    }
 }
 
 /*
@@ -105,6 +113,24 @@ pub struct Splice {
     pub ff_tokens: Vec<TokenId>,
 }
 
+impl Splice {
+    pub fn noop() -> Self {
+        Splice {
+            when_sampled: vec![],
+            backtrack: 0,
+            ff_tokens: vec![],
+        }
+    }
+
+    pub fn tokens(ff_tokens: Vec<TokenId>) -> Self {
+        Splice {
+            when_sampled: vec![],
+            backtrack: 0,
+            ff_tokens,
+        }
+    }
+}
+
 #[derive(Serialize, Deserialize, Debug)]
 pub struct Branch {
     /// If None, no sampling is performed.

From 5e7013ad05081e918809d4ecebb33db7c4aabc69 Mon Sep 17 00:00:00 2001
From: Michal Moskal
Date: Thu, 26 Sep 2024 16:49:15 +0000
Subject: [PATCH 281/301] add TokEnvWithTrie

---
 core/src/lib.rs     |  3 ++-
 core/src/toktree.rs | 35 +++++++++++++++++++++++++++++++----
 2 files changed, 33 insertions(+), 5 deletions(-)

diff --git a/core/src/lib.rs b/core/src/lib.rs
index 73f98066..31918f21 100644
--- a/core/src/lib.rs
+++ b/core/src/lib.rs
@@ -8,7 +8,8 @@ mod toktree;
 
 pub use svob::{SimpleVob, SimpleVobIter};
 pub use toktree::{
-    Recognizer, SpecialToken, TokEnv, TokRxInfo, TokTrie, TokenId, TokenizerEnv, TrieNode,
+    Recognizer, SpecialToken, TokEnv, TokEnvWithTrie, TokRxInfo, TokTrie, TokenId, TokenizerEnv,
+    TrieNode,
 };
 
 /// Defines what is allowed in Branch

diff --git a/core/src/toktree.rs b/core/src/toktree.rs
index 0b795db9..1f610999 100644
--- a/core/src/toktree.rs
+++ b/core/src/toktree.rs
@@ -117,6 +117,31 @@ pub trait TokenizerEnv: Send {
 
 pub type TokEnv = Arc<dyn TokenizerEnv + Sync>;
 
+pub struct TokEnvWithTrie {
+    base_env: TokEnv,
+    tok_trie: TokTrie,
+}
+
+impl TokEnvWithTrie {
+    pub fn new(base_env: TokEnv, tok_trie: TokTrie) -> Self {
+        Self { base_env, tok_trie }
+    }
+}
+
+impl TokenizerEnv for TokEnvWithTrie {
+    fn tok_trie(&self) -> &TokTrie {
+        &self.tok_trie
+    }
+
+    fn stop(&self) -> !
{
+        self.base_env.stop()
+    }
+
+    fn tokenize_bytes(&self, s: &[u8]) -> Vec<TokenId> {
+        self.base_env.tokenize_bytes(s)
+    }
+}
+
 #[derive(Clone)]
 pub struct TokTrie {
     info: TokRxInfo,
@@ -220,14 +245,16 @@ impl TokTrie {
         r
     }
 
-    pub fn build_chat_mode_trie(&self) -> Self {
+    pub fn with_eos_token(&self, eos_token: TokenId) -> Self {
         let mut r = self.clone();
-        if let Some(t) = self.info.tok_end_of_turn {
-            r.info.tok_eos = t;
-        }
+        r.info.tok_eos = eos_token;
         r
     }
 
+    pub fn build_chat_mode_trie(&self) -> Self {
+        self.with_eos_token(self.info.tok_end_of_turn.unwrap_or(self.info.tok_eos))
+    }
+
     fn finalize_ctor(&mut self) {
         for tok_id in 0..self.info.vocab_size {
             let bytes = self.token(tok_id);

From 1e39ea4a6116c841ae4293f7f45f851d57883cd7 Mon Sep 17 00:00:00 2001
From: Michal Moskal
Date: Wed, 23 Oct 2024 18:42:34 +0000
Subject: [PATCH 282/301] add TokTrie.with_info()

---
 core/src/toktree.rs | 9 ++++++++-
 1 file changed, 8 insertions(+), 1 deletion(-)

diff --git a/core/src/toktree.rs b/core/src/toktree.rs
index 1f610999..ce301b76 100644
--- a/core/src/toktree.rs
+++ b/core/src/toktree.rs
@@ -246,8 +246,15 @@ impl TokTrie {
     }
 
     pub fn with_eos_token(&self, eos_token: TokenId) -> Self {
+        self.with_info(TokRxInfo {
+            tok_eos: eos_token,
+            ..self.info.clone()
+        })
+    }
+
+    pub fn with_info(&self, info: TokRxInfo) -> Self {
         let mut r = self.clone();
-        r.info.tok_eos = eos_token;
+        r.info = info.clone();
         r
     }
 
From a2af4e056ecbc2058a3fe72ac4a6cb05693d600a Mon Sep 17 00:00:00 2001
From: Michal Moskal
Date: Sat, 26 Oct 2024 00:05:15 +0000
Subject: [PATCH 283/301] add SPECIAL_TOKEN_PREFIX_BYTE

---
 core/src/toktree.rs      | 90 ++++++++++++++++++++++++++++++++++++----
 hf_tokenizers/src/lib.rs | 24 +++++------
 2 files changed, 95 insertions(+), 19 deletions(-)

diff --git a/core/src/toktree.rs b/core/src/toktree.rs
index ce301b76..23be30c2 100644
--- a/core/src/toktree.rs
+++ b/core/src/toktree.rs
@@ -103,13 +103,48 @@ pub trait Recognizer {
 }
 
 pub trait TokenizerEnv: Send {
+    /// Stop the program; not used.
+    // TODO remove this
     fn stop(&self) -> !;
+
+    /// Associated trie.
     fn tok_trie(&self) -> &TokTrie;
+
+    /// Tokenize a given byte sequence.
+    /// It may or may not interpret <|special_tokens|> as special.
     fn tokenize_bytes(&self, s: &[u8]) -> Vec<TokenId>;
 
+    /// Tokenize a given byte sequence.
+    /// It will interpret text starting with SPECIAL_TOKEN_PREFIX_BYTE as special tokens.
+    fn tokenize_bytes_prefix(&self, s: &[u8]) -> Vec<TokenId> {
+        if s.contains(&TokTrie::SPECIAL_TOKEN_PREFIX_BYTE) {
+            let copy = s
+                .iter()
+                .filter_map(|&b| {
+                    if b == TokTrie::SPECIAL_TOKEN_PREFIX_BYTE {
+                        None
+                    } else {
+                        Some(b)
+                    }
+                })
+                .collect::<Vec<u8>>();
+            self.tokenize_bytes(&copy)
+        } else {
+            self.tokenize_bytes(s)
+        }
+    }
+
+    /// Tokenize a string coming from user. It may or may not interpret <|special_tokens|> as special.
     fn tokenize(&self, s: &str) -> Vec<TokenId> {
         self.tokenize_bytes(s.as_bytes())
     }
+
+    /// Tokenize a string. It will interpret <|special_tokens|> as special.
+ fn tokenize_special(&self, s: &str) -> Vec { + self.tokenize(s) + } + + /// End of sentence token fn eos_token(&self) -> TokenId { self.tok_trie().eos_token() } @@ -216,6 +251,8 @@ impl TrieNode { const LEN_BITS: u32 = 10; impl TokTrie { + pub const SPECIAL_TOKEN_PREFIX_BYTE: u8 = 0xff; + pub fn from(info: &TokRxInfo, words: &Vec>) -> Self { let mut trie = TrieHash::new(0xff); let mut token_offsets = Vec::new(); @@ -393,14 +430,19 @@ impl TokTrie { format!("OOB[{}]", idx) } else { // format!("{:?}[{}]", self.token_str(idx), idx) - let s = self.token_str(idx); - if s.len() == 0 { - format!("EMPTY[{}]", idx) - } else if !s.contains('\u{fffd}') { - format!("{:?}", s) + let bytes = self.token(idx); + if bytes.len() > 1 && bytes[0] == TokTrie::SPECIAL_TOKEN_PREFIX_BYTE { + String::from_utf8_lossy(&bytes[1..]).to_string() } else { - let bytes = self.token(idx); - format!("HEX[{}]", to_hex_string(bytes)) + let s = String::from_utf8_lossy(bytes); + if s.len() == 0 { + format!("EMPTY[{}]", idx) + } else if !s.contains('\u{fffd}') { + format!("{:?}", s) + } else { + let bytes = self.token(idx); + format!("HEX[{}]", to_hex_string(bytes)) + } } } } @@ -420,6 +462,14 @@ impl TokTrie { } pub fn decode(&self, tokens: &[TokenId]) -> Vec { + let mut bytes = self.decode_raw(tokens); + if bytes.contains(&TokTrie::SPECIAL_TOKEN_PREFIX_BYTE) { + bytes.retain(|&b| b != TokTrie::SPECIAL_TOKEN_PREFIX_BYTE); + } + bytes + } + + pub fn decode_raw(&self, tokens: &[TokenId]) -> Vec { tokens .iter() .flat_map(|t| self.token(*t).to_vec()) @@ -430,6 +480,32 @@ impl TokTrie { String::from_utf8_lossy(&self.decode(tokens)).to_string() } + pub fn get_special_token(&self, name: &str) -> Option { + self.child_at_byte(self.root(), TokTrie::SPECIAL_TOKEN_PREFIX_BYTE) + .and_then(|n| { + self.child_at_bytes(n, name.as_bytes()) + .and_then(|n| n.token_id()) + }) + } + + pub fn get_special_tokens(&self) -> Vec { + let mut res = Vec::new(); + let pref_node = self + .child_at_byte(self.root(), TokTrie::SPECIAL_TOKEN_PREFIX_BYTE) + .expect("missing special token prefix"); + let mut stack = vec![pref_node]; + while let Some(n) = stack.pop() { + for c in self.node_children(n) { + if let Some(tok) = c.token_id() { + res.push(tok); + } + stack.push(c); + } + } + res.remove(0); + res + } + pub fn greedy_tokenize(&self, bytes: &[u8]) -> Vec { let mut r = Vec::new(); if bytes.len() == 0 { diff --git a/hf_tokenizers/src/lib.rs b/hf_tokenizers/src/lib.rs index c54ac598..9dde8f0a 100644 --- a/hf_tokenizers/src/lib.rs +++ b/hf_tokenizers/src/lib.rs @@ -150,21 +150,22 @@ impl ByteTokenizer { let char_map = build_char_map(); for tok_id in 0..vocab_size { - if added.contains_key(&tok_id) { - continue; - } if let Some(tok_name) = res.hf_tokenizer.id_to_token(tok_id) { - if is_byte_fallback { + let bytes = if added.contains_key(&tok_id) { + let mut bytes = tok_name.as_bytes().to_vec(); + bytes.insert(0, TokTrie::SPECIAL_TOKEN_PREFIX_BYTE); + bytes + } else if is_byte_fallback { if tok_name.len() == 6 && tok_name.starts_with("<0x") && tok_name.ends_with(">") { // parse hex number from tok_name let hex_str = &tok_name[3..5]; let byte = u8::from_str_radix(hex_str, 16).unwrap(); - res.token_bytes[tok_id as usize] = vec![byte]; + vec![byte] } else { assert!(!tok_name.starts_with("<0x")); let tok_name = tok_name.replace(space_ch, " "); - res.token_bytes[tok_id as usize] = tok_name.as_bytes().to_vec(); + tok_name.as_bytes().to_vec() } } else if is_byte_level { let bytes: Result> = tok_name @@ -176,18 +177,17 @@ impl ByteTokenizer { .ok_or_else(|| 
anyhow!("missing char: {}", c)) }) .collect(); - let bytes = match bytes { + match bytes { Ok(b) => b, Err(e) => { - println!("error: {} for {:?}", e, tok_name); + log::warn!("error: {} for {:?}", e, tok_name); continue; } - }; - - res.token_bytes[tok_id as usize] = bytes; + } } else { panic!(); - } + }; + res.token_bytes[tok_id as usize] = bytes; } else { log::warn!("missing token: {}", tok_id); } From fffe4f7ca998b4ac9f21b81d3c60e592cb86ecee Mon Sep 17 00:00:00 2001 From: v-jkegler Date: Sun, 27 Oct 2024 19:54:46 -0400 Subject: [PATCH 284/301] Update README.md Fix typo --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index 5871fe4d..ab5ce844 100644 --- a/README.md +++ b/README.md @@ -8,7 +8,7 @@ The constraints are typically expressed on strings or bytes, not tokens. To compute the set of tokens that match a string constraint, one needs go through all the possible tokens and apply the constraint. An efficient way to do this is walk a prefix tree (trie) of all tokens. -This library implements this trie and exposes a way of filtering when provided with a constraints +This library implements this trie and exposes a way of filtering when provided with a constraint implementing the [following interface](core/src/toktree.rs): ```rust From b2715704f3b6f2e4e6b27565cb100bac67492523 Mon Sep 17 00:00:00 2001 From: v-jkegler Date: Mon, 4 Nov 2024 14:03:46 -0500 Subject: [PATCH 285/301] Expand comment on collapse() in toktree.rs Expand comment on collapse() in toktree.rs. --- core/src/toktree.rs | 2 ++ 1 file changed, 2 insertions(+) diff --git a/core/src/toktree.rs b/core/src/toktree.rs index 23be30c2..48f66f0b 100644 --- a/core/src/toktree.rs +++ b/core/src/toktree.rs @@ -75,6 +75,8 @@ pub enum SpecialToken { pub trait Recognizer { /// for _ in 0..num { stack.pop() } fn pop_bytes(&mut self, num: usize); + /// "Collapse" the stack so that it consists only of its former + /// top element. 
/// X = stack.top(); stack.empty(); stack.push(X) fn collapse(&mut self); /// check if stack.top() transitions via byte to a viable state From ad6be22981c3db42a89aa6ff6bf54ffa326054b8 Mon Sep 17 00:00:00 2001 From: Michal Moskal Date: Wed, 6 Nov 2024 11:07:42 -0800 Subject: [PATCH 286/301] update hf tokenizers --- core/src/bytes.rs | 2 +- hf_tokenizers/Cargo.toml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/core/src/bytes.rs b/core/src/bytes.rs index 6aa39fdb..88073161 100644 --- a/core/src/bytes.rs +++ b/core/src/bytes.rs @@ -2,7 +2,7 @@ use std::mem::size_of; use anyhow::{anyhow, Result}; use bytemuck::{NoUninit, Pod}; -use bytemuck_derive::{Pod, Zeroable}; +use bytemuck_derive::Zeroable; #[derive(Clone, Copy, PartialEq, Eq, Debug, Zeroable, Pod)] #[repr(C)] diff --git a/hf_tokenizers/Cargo.toml b/hf_tokenizers/Cargo.toml index bef35682..0536d99b 100644 --- a/hf_tokenizers/Cargo.toml +++ b/hf_tokenizers/Cargo.toml @@ -9,5 +9,5 @@ serde = { version = "1.0.192", features = ["derive"] } serde_json = "1.0.108" anyhow = "1.0.75" rustc-hash = { version = "2.0.0" } -tokenizers = { version = "0.15.0", features = ["http"] } +tokenizers = { version = "0.19.1", features = ["http"] } log = "0.4.21" From 36f9cf7e85c661cc78484c176d81d991399bb017 Mon Sep 17 00:00:00 2001 From: Michal Moskal Date: Wed, 6 Nov 2024 13:24:22 -0800 Subject: [PATCH 287/301] force newer bytemuck --- core/Cargo.lock | 47 +-- core/Cargo.toml | 4 +- core/src/bytes.rs | 2 +- hf_tokenizers/Cargo.lock | 671 ++++++++++++++++++++++++--------------- 4 files changed, 442 insertions(+), 282 deletions(-) diff --git a/core/Cargo.lock b/core/Cargo.lock index aab2775f..8add3ab3 100644 --- a/core/Cargo.lock +++ b/core/Cargo.lock @@ -4,21 +4,21 @@ version = 3 [[package]] name = "anyhow" -version = "1.0.86" +version = "1.0.93" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b3d1d046238990b9cf5bcde22a3fb3584ee5cf65fb2765f454ed428c7a0063da" +checksum = "4c95c10ba0b00a02636238b814946408b1322d5ac4760326e6fb8ec956d85775" [[package]] name = "bytemuck" -version = "1.16.1" +version = "1.19.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b236fc92302c97ed75b38da1f4917b5cdda4984745740f153a5d3059e48d725e" +checksum = "8334215b81e418a0a7bdb8ef0849474f40bb10c8b71f1c4ed315cff49f32494d" [[package]] name = "bytemuck_derive" -version = "1.7.0" +version = "1.8.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1ee891b04274a59bd38b412188e24b849617b2e45a0fd8d057deb63e7403761b" +checksum = "bcfcc3cd946cb52f0bbfdbbcfa2f4e24f75ebb6c0e1002f7c25904fada18b9ec" dependencies = [ "proc-macro2", "quote", @@ -31,20 +31,26 @@ version = "1.0.11" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "49f1f14873335454500d59611f1cf4a4b0f786f9ac11f4312a78e4cf2566695b" +[[package]] +name = "memchr" +version = "2.7.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "78ca9ab1a0babb1e7d5695e3530886289c18cf2f87ec19a575a0abdce112e3a3" + [[package]] name = "proc-macro2" -version = "1.0.86" +version = "1.0.89" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5e719e8df665df0d1c8fbfd238015744736151d4445ec0836b8e628aae103b77" +checksum = "f139b0662de085916d1fb67d2b4169d1addddda1919e696f3252b740b629986e" dependencies = [ "unicode-ident", ] [[package]] name = "quote" -version = "1.0.36" +version = "1.0.37" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = 
"0fa76aaf39101c457836aec0ce2316dbdc3ab723cdda1c6bd4e6ad4208acaca7" +checksum = "b5b9d34b8991d19d98081b46eacdd8eb58c6f2b201139f7c5f643cc155a633af" dependencies = [ "proc-macro2", ] @@ -63,18 +69,18 @@ checksum = "f3cb5ba0dc43242ce17de99c180e96db90b235b8a9fdc9543c96d2209116bd9f" [[package]] name = "serde" -version = "1.0.203" +version = "1.0.214" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7253ab4de971e72fb7be983802300c30b5a7f0c2e56fab8abfc6a214307c0094" +checksum = "f55c3193aca71c12ad7890f1785d2b73e1b9f63a0bbc353c08ef26fe03fc56b5" dependencies = [ "serde_derive", ] [[package]] name = "serde_derive" -version = "1.0.203" +version = "1.0.214" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "500cbc0ebeb6f46627f50f3f5811ccf6bf00643be300b4c3eabc0ef55dc5b5ba" +checksum = "de523f781f095e28fa605cdce0f8307e451cc0fd14e2eb4cd2e98a355b147766" dependencies = [ "proc-macro2", "quote", @@ -83,20 +89,21 @@ dependencies = [ [[package]] name = "serde_json" -version = "1.0.120" +version = "1.0.132" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4e0d21c9a8cae1235ad58a00c11cb40d4b1e5c784f1ef2c537876ed6ffd8b7c5" +checksum = "d726bfaff4b320266d395898905d0eba0345aae23b54aee3a737e260fd46db03" dependencies = [ "itoa", + "memchr", "ryu", "serde", ] [[package]] name = "syn" -version = "2.0.68" +version = "2.0.87" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "901fa70d88b9d6c98022e23b4136f9f3e54e4662c3bc1bd1d84a42a9a0f0c1e9" +checksum = "25aa4ce346d03a6dcd68dd8b4010bcb74e54e62c90c573f394c46eae99aba32d" dependencies = [ "proc-macro2", "quote", @@ -117,6 +124,6 @@ dependencies = [ [[package]] name = "unicode-ident" -version = "1.0.12" +version = "1.0.13" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3354b9ac3fae1ff6755cb6db53683adb661634f67557942dea4facebec0fee4b" +checksum = "e91b56cd4cadaeb79bbf1a5645f6b4f8dc5bde8834ad5894a8db35fda9efa1fe" diff --git a/core/Cargo.toml b/core/Cargo.toml index 20a8dd63..7d96dbcb 100644 --- a/core/Cargo.toml +++ b/core/Cargo.toml @@ -10,6 +10,6 @@ name = "toktrie" serde = { version = "1.0.192", features = ["derive"] } serde_json = "1.0.108" anyhow = "1.0.75" -bytemuck = "1.16.0" -bytemuck_derive = "1.6.0" +bytemuck = "1.19.0" +bytemuck_derive = "1.8.0" rustc-hash = { version = "2.0.0" } diff --git a/core/src/bytes.rs b/core/src/bytes.rs index 88073161..6aa39fdb 100644 --- a/core/src/bytes.rs +++ b/core/src/bytes.rs @@ -2,7 +2,7 @@ use std::mem::size_of; use anyhow::{anyhow, Result}; use bytemuck::{NoUninit, Pod}; -use bytemuck_derive::Zeroable; +use bytemuck_derive::{Pod, Zeroable}; #[derive(Clone, Copy, PartialEq, Eq, Debug, Zeroable, Pod)] #[repr(C)] diff --git a/hf_tokenizers/Cargo.lock b/hf_tokenizers/Cargo.lock index ca52adf7..1ef7281d 100644 --- a/hf_tokenizers/Cargo.lock +++ b/hf_tokenizers/Cargo.lock @@ -3,10 +3,10 @@ version = 3 [[package]] -name = "adler" -version = "1.0.2" +name = "adler2" +version = "2.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f26201604c87b1e01bd3d98f8d5d9a8fcbb815e8cedb41ffccbeb4bf593a35fe" +checksum = "512761e0bb2578dd7380c6baaa0f4ce03e84f95e960231d1dec8bf4d7d6e2627" [[package]] name = "aho-corasick" @@ -17,60 +17,11 @@ dependencies = [ "memchr", ] -[[package]] -name = "anstream" -version = "0.6.14" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "418c75fa768af9c03be99d17643f93f79bbba589895012a80e3452a19ddda15b" -dependencies = [ - 
"anstyle", - "anstyle-parse", - "anstyle-query", - "anstyle-wincon", - "colorchoice", - "is_terminal_polyfill", - "utf8parse", -] - -[[package]] -name = "anstyle" -version = "1.0.7" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "038dfcf04a5feb68e9c60b21c9625a54c2c0616e79b72b0fd87075a056ae1d1b" - -[[package]] -name = "anstyle-parse" -version = "0.2.4" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c03a11a9034d92058ceb6ee011ce58af4a9bf61491aa7e1e59ecd24bd40d22d4" -dependencies = [ - "utf8parse", -] - -[[package]] -name = "anstyle-query" -version = "1.1.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ad186efb764318d35165f1758e7dcef3b10628e26d41a44bc5550652e6804391" -dependencies = [ - "windows-sys 0.52.0", -] - -[[package]] -name = "anstyle-wincon" -version = "3.0.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "61a38449feb7068f52bb06c12759005cf459ee52bb4adc1d5a7c4322d716fb19" -dependencies = [ - "anstyle", - "windows-sys 0.52.0", -] - [[package]] name = "anyhow" -version = "1.0.86" +version = "1.0.93" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b3d1d046238990b9cf5bcde22a3fb3584ee5cf65fb2765f454ed428c7a0063da" +checksum = "4c95c10ba0b00a02636238b814946408b1322d5ac4760326e6fb8ec956d85775" [[package]] name = "base64" @@ -98,78 +49,41 @@ checksum = "b048fb63fd8b5923fc5aa7b340d8e156aec7ec02f0c78fa8a6ddc2613f6f71de" [[package]] name = "bytemuck" -version = "1.16.1" +version = "1.19.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b236fc92302c97ed75b38da1f4917b5cdda4984745740f153a5d3059e48d725e" +checksum = "8334215b81e418a0a7bdb8ef0849474f40bb10c8b71f1c4ed315cff49f32494d" [[package]] name = "bytemuck_derive" -version = "1.7.0" +version = "1.8.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1ee891b04274a59bd38b412188e24b849617b2e45a0fd8d057deb63e7403761b" +checksum = "bcfcc3cd946cb52f0bbfdbbcfa2f4e24f75ebb6c0e1002f7c25904fada18b9ec" dependencies = [ "proc-macro2", "quote", - "syn 2.0.68", + "syn", ] [[package]] -name = "cc" -version = "1.0.104" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "74b6a57f98764a267ff415d50a25e6e166f3831a5071af4995296ea97d210490" - -[[package]] -name = "cfg-if" -version = "1.0.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd" - -[[package]] -name = "clap" -version = "4.5.8" +name = "byteorder" +version = "1.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "84b3edb18336f4df585bc9aa31dd99c036dfa5dc5e9a2939a722a188f3a8970d" -dependencies = [ - "clap_builder", - "clap_derive", -] +checksum = "1fd0f2584146f6f2ef48085050886acf353beff7305ebd1ae69500e27c67f64b" [[package]] -name = "clap_builder" -version = "4.5.8" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c1c09dd5ada6c6c78075d6fd0da3f90d8080651e2d6cc8eb2f1aaa4034ced708" -dependencies = [ - "anstream", - "anstyle", - "clap_lex", - "strsim 0.11.1", -] - -[[package]] -name = "clap_derive" -version = "4.5.8" +name = "cc" +version = "1.1.36" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2bac35c6dafb060fd4d275d9a4ffae97917c13a6327903a8be2153cd964f7085" +checksum = "baee610e9452a8f6f0a1b6194ec09ff9e2d85dea54432acdae41aa0761c95d70" dependencies = [ - "heck", - "proc-macro2", - "quote", 
- "syn 2.0.68", + "shlex", ] [[package]] -name = "clap_lex" -version = "0.7.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4b82cf0babdbd58558212896d1a4272303a57bdb245c2bf1147185fb45640e70" - -[[package]] -name = "colorchoice" -version = "1.0.1" +name = "cfg-if" +version = "1.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0b6a852b24ab71dffc585bcb46eaf7959d175cb865a7152e35b348d1b2960422" +checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd" [[package]] name = "console" @@ -196,9 +110,9 @@ dependencies = [ [[package]] name = "core-foundation-sys" -version = "0.8.6" +version = "0.8.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "06ea2b9bc92be3c2baa9334a323ebca2d6f074ff852cd1d7b11064035cd3868f" +checksum = "773648b94d0e5d620f64f280777445740e61fe701025087ec8b57f45c791888b" [[package]] name = "crc32fast" @@ -236,9 +150,9 @@ checksum = "22ec99545bb0ed0ea7bb9b8e1e9122ea386ff8a48c0922e43f36d45ab09e0e80" [[package]] name = "darling" -version = "0.14.4" +version = "0.20.10" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7b750cb3417fd1b327431a470f388520309479ab0bf5e323505daf0290cd3850" +checksum = "6f63b86c8a8826a49b8c21f08a2d07338eec8d900540f8630dc76284be802989" dependencies = [ "darling_core", "darling_macro", @@ -246,58 +160,58 @@ dependencies = [ [[package]] name = "darling_core" -version = "0.14.4" +version = "0.20.10" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "109c1ca6e6b7f82cc233a97004ea8ed7ca123a9af07a8230878fcfda9b158bf0" +checksum = "95133861a8032aaea082871032f5815eb9e98cef03fa916ab4500513994df9e5" dependencies = [ "fnv", "ident_case", "proc-macro2", "quote", - "strsim 0.10.0", - "syn 1.0.109", + "strsim", + "syn", ] [[package]] name = "darling_macro" -version = "0.14.4" +version = "0.20.10" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a4aab4dbc9f7611d8b55048a3a16d2d010c2c8334e46304b40ac1cc14bf3b48e" +checksum = "d336a2a514f6ccccaa3e09b02d41d35330c07ddf03a62165fcec10bb561c7806" dependencies = [ "darling_core", "quote", - "syn 1.0.109", + "syn", ] [[package]] name = "derive_builder" -version = "0.12.0" +version = "0.20.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8d67778784b508018359cbc8696edb3db78160bab2c2a28ba7f56ef6932997f8" +checksum = "507dfb09ea8b7fa618fcf76e953f4f5e192547945816d5358edffe39f6f94947" dependencies = [ "derive_builder_macro", ] [[package]] name = "derive_builder_core" -version = "0.12.0" +version = "0.20.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c11bdc11a0c47bc7d37d582b5285da6849c96681023680b906673c5707af7b0f" +checksum = "2d5bcf7b024d6835cfb3d473887cd966994907effbe9227e8c8219824d06c4e8" dependencies = [ "darling", "proc-macro2", "quote", - "syn 1.0.109", + "syn", ] [[package]] name = "derive_builder_macro" -version = "0.12.0" +version = "0.20.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ebcda35c7a396850a55ffeac740804b40ffec779b98fffbb1738f4033f0ee79e" +checksum = "ab63b0e2bf4d5928aff72e83a7dace85d7bba5fe12dcc3c5a572d78caffd3f3c" dependencies = [ "derive_builder_core", - "syn 1.0.109", + "syn", ] [[package]] @@ -321,6 +235,17 @@ dependencies = [ "windows-sys 0.48.0", ] +[[package]] +name = "displaydoc" +version = "0.2.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"97369cbbc041bc366949bc74d34658d6cda5621039731c6310521892a3a20ae0" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + [[package]] name = "either" version = "1.13.0" @@ -354,15 +279,15 @@ dependencies = [ [[package]] name = "fastrand" -version = "2.1.0" +version = "2.1.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9fc0510504f03c51ada170672ac806f1f105a88aa97a5281117e1ddc3368e51a" +checksum = "e8c02a5121d4ea3eb16a80748c74f5549a5665e4c21333c6098f283870fbdea6" [[package]] name = "flate2" -version = "1.0.30" +version = "1.0.34" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5f54427cfd1c7829e2a139fcefea601bf088ebca651d2bf53ebc600eac295dae" +checksum = "a1b589b4dc103969ad3cf85c950899926ec64300a1a46d76c03a6072957036f0" dependencies = [ "crc32fast", "miniz_oxide", @@ -409,12 +334,6 @@ dependencies = [ "wasi", ] -[[package]] -name = "heck" -version = "0.5.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2304e00983f87ffb38b55b444b5e3b60a884b5d30c0fca7d82fe33449bbe55ea" - [[package]] name = "hf-hub" version = "0.3.2" @@ -432,6 +351,124 @@ dependencies = [ "ureq", ] +[[package]] +name = "icu_collections" +version = "1.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "db2fa452206ebee18c4b5c2274dbf1de17008e874b4dc4f0aea9d01ca79e4526" +dependencies = [ + "displaydoc", + "yoke", + "zerofrom", + "zerovec", +] + +[[package]] +name = "icu_locid" +version = "1.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "13acbb8371917fc971be86fc8057c41a64b521c184808a698c02acc242dbf637" +dependencies = [ + "displaydoc", + "litemap", + "tinystr", + "writeable", + "zerovec", +] + +[[package]] +name = "icu_locid_transform" +version = "1.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "01d11ac35de8e40fdeda00d9e1e9d92525f3f9d887cdd7aa81d727596788b54e" +dependencies = [ + "displaydoc", + "icu_locid", + "icu_locid_transform_data", + "icu_provider", + "tinystr", + "zerovec", +] + +[[package]] +name = "icu_locid_transform_data" +version = "1.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fdc8ff3388f852bede6b579ad4e978ab004f139284d7b28715f773507b946f6e" + +[[package]] +name = "icu_normalizer" +version = "1.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "19ce3e0da2ec68599d193c93d088142efd7f9c5d6fc9b803774855747dc6a84f" +dependencies = [ + "displaydoc", + "icu_collections", + "icu_normalizer_data", + "icu_properties", + "icu_provider", + "smallvec", + "utf16_iter", + "utf8_iter", + "write16", + "zerovec", +] + +[[package]] +name = "icu_normalizer_data" +version = "1.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f8cafbf7aa791e9b22bec55a167906f9e1215fd475cd22adfcf660e03e989516" + +[[package]] +name = "icu_properties" +version = "1.5.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "93d6020766cfc6302c15dbbc9c8778c37e62c14427cb7f6e601d849e092aeef5" +dependencies = [ + "displaydoc", + "icu_collections", + "icu_locid_transform", + "icu_properties_data", + "icu_provider", + "tinystr", + "zerovec", +] + +[[package]] +name = "icu_properties_data" +version = "1.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "67a8effbc3dd3e4ba1afa8ad918d5684b8868b3b26500753effea8d2eed19569" + +[[package]] +name = "icu_provider" +version = "1.5.0" +source = 
"registry+https://github.com/rust-lang/crates.io-index" +checksum = "6ed421c8a8ef78d3e2dbc98a973be2f3770cb42b606e3ab18d6237c4dfde68d9" +dependencies = [ + "displaydoc", + "icu_locid", + "icu_provider_macros", + "stable_deref_trait", + "tinystr", + "writeable", + "yoke", + "zerofrom", + "zerovec", +] + +[[package]] +name = "icu_provider_macros" +version = "1.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1ec89e9337638ecdc08744df490b221a7399bf8d164eb52a665454e60e075ad6" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + [[package]] name = "ident_case" version = "1.0.1" @@ -440,12 +477,23 @@ checksum = "b9e0384b61958566e926dc50660321d12159025e767c18e043daf26b70104c39" [[package]] name = "idna" -version = "0.5.0" +version = "1.0.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "634d9b1461af396cad843f47fdba5597a4f9e6ddd4bfb6ff5d85028c25cb12f6" +checksum = "686f825264d630750a544639377bae737628043f20d38bbc029e8f29ea968a7e" dependencies = [ - "unicode-bidi", - "unicode-normalization", + "idna_adapter", + "smallvec", + "utf8_iter", +] + +[[package]] +name = "idna_adapter" +version = "1.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "daca1df1c957320b2cf139ac61e7bd64fed304c5040df000a745aa1de3b4ef71" +dependencies = [ + "icu_normalizer", + "icu_properties", ] [[package]] @@ -470,12 +518,6 @@ dependencies = [ "cfg-if", ] -[[package]] -name = "is_terminal_polyfill" -version = "1.70.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f8478577c03552c21db0e2724ffb8986a5ce7af88107e6be5d2ee6e158c12800" - [[package]] name = "itertools" version = "0.11.0" @@ -508,9 +550,9 @@ checksum = "bbd2bcb4c963f2ddae06a2efc7e9f3591312473c50c6685e1f298068316e66fe" [[package]] name = "libc" -version = "0.2.155" +version = "0.2.161" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "97b3888a4aecf77e811145cadf6eef5901f4782c53886191b2f693f24761847c" +checksum = "8e9489c2807c139ffd9c1794f4af0ebe86a828db53ecdc7fea2111d0fed085d1" [[package]] name = "libredox" @@ -528,6 +570,12 @@ version = "0.4.14" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "78b3ae25bc7c8c38cec158d1f2757ee79e9b3740fbc7ccf0e59e4b08d793fa89" +[[package]] +name = "litemap" +version = "0.7.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "643cb0b8d4fcc284004d5fd0d67ccf61dfffadb7f75e1e71bc420f4688a3a704" + [[package]] name = "log" version = "0.4.22" @@ -564,11 +612,11 @@ checksum = "68354c5c6bd36d73ff3feceb05efa59b6acb7626617f4962be322a825e61f79a" [[package]] name = "miniz_oxide" -version = "0.7.4" +version = "0.8.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b8a240ddb74feaf34a79a7add65a741f3167852fba007066dcac1ca548d89c08" +checksum = "e2d80299ef12ff69b16a84bb182e3b9df68b5a91574d3d4fa6e41b65deec4df1" dependencies = [ - "adler", + "adler2", ] [[package]] @@ -589,7 +637,7 @@ checksum = "a7ce64b975ed4f123575d11afd9491f2e37bbd5813fbfbc0f09ae1fbddea74e0" dependencies = [ "proc-macro2", "quote", - "syn 2.0.68", + "syn", ] [[package]] @@ -627,9 +675,9 @@ checksum = "830b246a0e5f20af87141b25c173cd1b609bd7779a4617d6ec582abaf90870f3" [[package]] name = "once_cell" -version = "1.19.0" +version = "1.20.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3fdb12b2476b595f9358c5161aa467c2438859caa136dec86c26fdd2efe17b92" +checksum = 
"1261fe7e33c73b354eab43b1273a57c8f967d0391e80353e51f764ac02cf6775" [[package]] name = "onig" @@ -655,9 +703,9 @@ dependencies = [ [[package]] name = "openssl" -version = "0.10.66" +version = "0.10.68" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9529f4786b70a3e8c61e11179af17ab6188ad8d0ded78c5529441ed39d4bd9c1" +checksum = "6174bc48f102d208783c2c84bf931bb75927a617866870de8a4ea85597f871f5" dependencies = [ "bitflags 2.6.0", "cfg-if", @@ -676,7 +724,7 @@ checksum = "a948666b637a0f465e8564c73e89d4dde00d72d4d473cc972f390fc3dcee7d9c" dependencies = [ "proc-macro2", "quote", - "syn 2.0.68", + "syn", ] [[package]] @@ -687,9 +735,9 @@ checksum = "ff011a302c396a5197692431fc1948019154afc178baf7d8e37367442a4601cf" [[package]] name = "openssl-sys" -version = "0.9.103" +version = "0.9.104" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7f9e8deee91df40a943c71b917e5874b951d32a802526c85721ce3b776c929d6" +checksum = "45abf306cbf99debc8195b66b7346498d7b10c210de50418b5ccd7ceba08c741" dependencies = [ "cc", "libc", @@ -717,36 +765,39 @@ checksum = "e3148f5046208a5d56bcfc03053e3ca6334e51da8dfb19b6cdc8b306fae3283e" [[package]] name = "pkg-config" -version = "0.3.30" +version = "0.3.31" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d231b230927b5e4ad203db57bbcbee2802f6bce620b1e4a9024a07d94e2907ec" +checksum = "953ec861398dccce10c670dfeaf3ec4911ca479e9c02154b3a215178c5f566f2" [[package]] name = "portable-atomic" -version = "1.6.0" +version = "1.9.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7170ef9988bc169ba16dd36a7fa041e5c4cbeb6a35b76d4c03daded371eae7c0" +checksum = "cc9c68a3f6da06753e9335d63e27f6b9754dd1920d941135b7ea8224f141adb2" [[package]] name = "ppv-lite86" -version = "0.2.17" +version = "0.2.20" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5b40af805b3121feab8a3c29f04d8ad262fa8e0561883e7653e024ae4479e6de" +checksum = "77957b295656769bb8ad2b6a6b09d897d94f05c41b069aede1fcdaa675eaea04" +dependencies = [ + "zerocopy", +] [[package]] name = "proc-macro2" -version = "1.0.86" +version = "1.0.89" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5e719e8df665df0d1c8fbfd238015744736151d4445ec0836b8e628aae103b77" +checksum = "f139b0662de085916d1fb67d2b4169d1addddda1919e696f3252b740b629986e" dependencies = [ "unicode-ident", ] [[package]] name = "quote" -version = "1.0.36" +version = "1.0.37" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0fa76aaf39101c457836aec0ce2316dbdc3ab723cdda1c6bd4e6ad4208acaca7" +checksum = "b5b9d34b8991d19d98081b46eacdd8eb58c6f2b201139f7c5f643cc155a633af" dependencies = [ "proc-macro2", ] @@ -814,9 +865,9 @@ dependencies = [ [[package]] name = "redox_users" -version = "0.4.5" +version = "0.4.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bd283d9651eeda4b2a83a43c1c91b266c40fd76ecd39a50a8c630ae69dc72891" +checksum = "ba009ff324d1fc1b900bd1fdb31564febe58a8ccc8a6fdbb93b543d33b13ca43" dependencies = [ "getrandom", "libredox", @@ -825,9 +876,9 @@ dependencies = [ [[package]] name = "regex" -version = "1.10.5" +version = "1.11.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b91213439dad192326a0d7c6ee3955910425f441d7038e0d6933b0aec5c4517f" +checksum = "b544ef1b4eac5dc2db33ea63606ae9ffcfac26c1416a2806ae0bf5f56b201191" dependencies = [ "aho-corasick", "memchr", @@ -837,9 +888,9 @@ dependencies = [ [[package]] name = 
"regex-automata" -version = "0.4.7" +version = "0.4.8" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "38caf58cc5ef2fed281f89292ef23f6365465ed9a41b7a7754eb4e26496c92df" +checksum = "368758f23274712b504848e9d5a6f010445cc8b87a7cdb4d7cbee666c1288da3" dependencies = [ "aho-corasick", "memchr", @@ -848,9 +899,9 @@ dependencies = [ [[package]] name = "regex-syntax" -version = "0.8.4" +version = "0.8.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7a66a03ae7c801facd77a29370b4faec201768915ac14a721ba36f20bc9c209b" +checksum = "2b15c43186be67a4fd63bee50d0303afffcef381492ebe2c5d87f324e1b8815c" [[package]] name = "ring" @@ -875,9 +926,9 @@ checksum = "583034fd73374156e66797ed8e5b0d5690409c9226b22d87cb7f19821c05d152" [[package]] name = "rustix" -version = "0.38.34" +version = "0.38.39" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "70dc5ec042f7a43c4a73241207cecc9873a06d45debb38b329f8541d85c2730f" +checksum = "375116bee2be9ed569afe2154ea6a99dfdffd257f533f187498c2a8f5feaf4ee" dependencies = [ "bitflags 2.6.0", "errno", @@ -888,11 +939,12 @@ dependencies = [ [[package]] name = "rustls" -version = "0.22.4" +version = "0.23.16" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bf4ef73721ac7bcd79b2b315da7779d8fc09718c6b3d2d1b2d94850eb8c18432" +checksum = "eee87ff5d9b36712a58574e12e9f0ea80f915a5b0ac518d322b24a465617925e" dependencies = [ "log", + "once_cell", "ring", "rustls-pki-types", "rustls-webpki", @@ -902,15 +954,15 @@ dependencies = [ [[package]] name = "rustls-pki-types" -version = "1.7.0" +version = "1.10.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "976295e77ce332211c0d24d92c0e83e50f5c5f046d11082cea19f3df13a3562d" +checksum = "16f1201b3c9a7ee8039bcadc17b7e605e2945b27eee7631788c1bd2b0643674b" [[package]] name = "rustls-webpki" -version = "0.102.5" +version = "0.102.8" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f9a6fccd794a42c2c105b513a2f62bc3fd8f3ba57a4593677ceb0bd035164d78" +checksum = "64ca1bc8749bd4cf37b5ce386cc146580777b4e8572c7b97baf22c83f444bee9" dependencies = [ "ring", "rustls-pki-types", @@ -925,18 +977,18 @@ checksum = "f3cb5ba0dc43242ce17de99c180e96db90b235b8a9fdc9543c96d2209116bd9f" [[package]] name = "schannel" -version = "0.1.23" +version = "0.1.26" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fbc91545643bcf3a0bbb6569265615222618bdf33ce4ffbbd13c4bbd4c093534" +checksum = "01227be5826fa0690321a2ba6c5cd57a19cf3f6a09e76973b58e61de6ab9d1c1" dependencies = [ - "windows-sys 0.52.0", + "windows-sys 0.59.0", ] [[package]] name = "security-framework" -version = "2.11.0" +version = "2.11.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c627723fd09706bacdb5cf41499e95098555af3c3c29d014dc3c458ef6be11c0" +checksum = "897b2245f0b511c87893af39b033e5ca9cce68824c4d7e7630b5a1d339658d02" dependencies = [ "bitflags 2.6.0", "core-foundation", @@ -947,9 +999,9 @@ dependencies = [ [[package]] name = "security-framework-sys" -version = "2.11.0" +version = "2.12.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "317936bbbd05227752583946b9e66d7ce3b489f84e11a94a510b4437fef407d7" +checksum = "ea4a292869320c0272d7bc55a5a6aafaff59b4f63404a003887b679a2e05b4b6" dependencies = [ "core-foundation-sys", "libc", @@ -957,35 +1009,42 @@ dependencies = [ [[package]] name = "serde" -version = "1.0.203" +version = "1.0.214" source = 
"registry+https://github.com/rust-lang/crates.io-index" -checksum = "7253ab4de971e72fb7be983802300c30b5a7f0c2e56fab8abfc6a214307c0094" +checksum = "f55c3193aca71c12ad7890f1785d2b73e1b9f63a0bbc353c08ef26fe03fc56b5" dependencies = [ "serde_derive", ] [[package]] name = "serde_derive" -version = "1.0.203" +version = "1.0.214" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "500cbc0ebeb6f46627f50f3f5811ccf6bf00643be300b4c3eabc0ef55dc5b5ba" +checksum = "de523f781f095e28fa605cdce0f8307e451cc0fd14e2eb4cd2e98a355b147766" dependencies = [ "proc-macro2", "quote", - "syn 2.0.68", + "syn", ] [[package]] name = "serde_json" -version = "1.0.120" +version = "1.0.132" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4e0d21c9a8cae1235ad58a00c11cb40d4b1e5c784f1ef2c537876ed6ffd8b7c5" +checksum = "d726bfaff4b320266d395898905d0eba0345aae23b54aee3a737e260fd46db03" dependencies = [ "itoa", + "memchr", "ryu", "serde", ] +[[package]] +name = "shlex" +version = "1.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0fda2ff0d084019ba4d7c6f371c95d8fd75ce3524c3cb8fb653a3023f6323e64" + [[package]] name = "smallvec" version = "1.13.2" @@ -1011,10 +1070,10 @@ dependencies = [ ] [[package]] -name = "strsim" -version = "0.10.0" +name = "stable_deref_trait" +version = "1.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "73473c0e59e6d5812c5dfe2a064a6444949f089e20eec9a2e5506596494e4623" +checksum = "a8f112729512f8e442d81f95a8a7ddf2b7c6b8a1a6f509a95864142b30cab2d3" [[package]] name = "strsim" @@ -1030,9 +1089,9 @@ checksum = "13c2bddecc57b384dee18652358fb23172facb8a2c51ccc10d74c157bdea3292" [[package]] name = "syn" -version = "1.0.109" +version = "2.0.87" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "72b64191b275b66ffe2469e8af2c1cfe3bafa67b529ead792a6d0160888b4237" +checksum = "25aa4ce346d03a6dcd68dd8b4010bcb74e54e62c90c573f394c46eae99aba32d" dependencies = [ "proc-macro2", "quote", @@ -1040,71 +1099,66 @@ dependencies = [ ] [[package]] -name = "syn" -version = "2.0.68" +name = "synstructure" +version = "0.13.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "901fa70d88b9d6c98022e23b4136f9f3e54e4662c3bc1bd1d84a42a9a0f0c1e9" +checksum = "c8af7666ab7b6390ab78131fb5b0fce11d6b7a6951602017c35fa82800708971" dependencies = [ "proc-macro2", "quote", - "unicode-ident", + "syn", ] [[package]] name = "tempfile" -version = "3.10.1" +version = "3.13.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "85b77fafb263dd9d05cbeac119526425676db3784113aa9295c88498cbf8bff1" +checksum = "f0f2c9fc62d0beef6951ccffd757e241266a2c833136efbe35af6cd2567dca5b" dependencies = [ "cfg-if", "fastrand", + "once_cell", "rustix", - "windows-sys 0.52.0", + "windows-sys 0.59.0", ] [[package]] name = "thiserror" -version = "1.0.61" +version = "1.0.68" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c546c80d6be4bc6a00c0f01730c08df82eaa7a7a61f11d656526506112cc1709" +checksum = "02dd99dc800bbb97186339685293e1cc5d9df1f8fae2d0aecd9ff1c77efea892" dependencies = [ "thiserror-impl", ] [[package]] name = "thiserror-impl" -version = "1.0.61" +version = "1.0.68" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "46c3384250002a6d5af4d114f2845d37b57521033f30d5c3f46c4d70e1197533" +checksum = "a7c61ec9a6f64d2793d8a45faba21efbe3ced62a886d44c36a009b2b519b4c7e" dependencies = [ "proc-macro2", "quote", - "syn 2.0.68", + 
"syn", ] [[package]] -name = "tinyvec" -version = "1.7.0" +name = "tinystr" +version = "0.7.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ce6b6a2fb3a985e99cebfaefa9faa3024743da73304ca1c683a36429613d3d22" +checksum = "9117f5d4db391c1cf6927e7bea3db74b9a1c1add8f7eda9ffd5364f40f57b82f" dependencies = [ - "tinyvec_macros", + "displaydoc", + "zerovec", ] -[[package]] -name = "tinyvec_macros" -version = "0.1.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1f3ccbac311fea05f86f61904b462b55fb3df8837a366dfc601a0161d0532f20" - [[package]] name = "tokenizers" -version = "0.15.2" +version = "0.19.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3dd47962b0ba36e7fd33518fbf1754d136fd1474000162bbf2a8b5fcb2d3654d" +checksum = "e500fad1dd3af3d626327e6a3fe5050e664a6eaa4708b8ca92f1794aaf73e6fd" dependencies = [ "aho-corasick", - "clap", "derive_builder", "esaxx-rs", "getrandom", @@ -1156,26 +1210,11 @@ dependencies = [ "toktrie", ] -[[package]] -name = "unicode-bidi" -version = "0.3.15" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "08f95100a766bf4f8f28f90d77e0a5461bbdb219042e7679bebe79004fed8d75" - [[package]] name = "unicode-ident" -version = "1.0.12" +version = "1.0.13" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3354b9ac3fae1ff6755cb6db53683adb661634f67557942dea4facebec0fee4b" - -[[package]] -name = "unicode-normalization" -version = "0.1.23" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a56d1686db2308d901306f92a263857ef59ea39678a5458e7cb17f01415101f5" -dependencies = [ - "tinyvec", -] +checksum = "e91b56cd4cadaeb79bbf1a5645f6b4f8dc5bde8834ad5894a8db35fda9efa1fe" [[package]] name = "unicode-normalization-alignments" @@ -1188,15 +1227,15 @@ dependencies = [ [[package]] name = "unicode-segmentation" -version = "1.11.0" +version = "1.12.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d4c87d22b6e3f4a18d4d40ef354e97c90fcb14dd91d7dc0aa9d8a1172ebf7202" +checksum = "f6ccf251212114b54433ec949fd6a7841275f9ada20dddd2f29e9ceea4501493" [[package]] name = "unicode-width" -version = "0.1.13" +version = "0.1.14" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0336d538f7abc86d282a4189614dfaa90810dfc2c6f6427eaf88e16311dd225d" +checksum = "7dd6e30e90baa6f72411720665d41d89b9a3d039dc45b8faea1ddd07f617f6af" [[package]] name = "unicode_categories" @@ -1212,9 +1251,9 @@ checksum = "8ecb6da28b8a351d773b68d5825ac39017e680750f980f3a1a85cd8dd28a47c1" [[package]] name = "ureq" -version = "2.9.7" +version = "2.10.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d11a831e3c0b56e438a28308e7c810799e3c118417f342d30ecec080105395cd" +checksum = "b74fc6b57825be3373f7054754755f03ac3a8f5d70015ccad699ba2029956f4a" dependencies = [ "base64 0.22.1", "flate2", @@ -1223,7 +1262,6 @@ dependencies = [ "once_cell", "rustls", "rustls-pki-types", - "rustls-webpki", "serde", "serde_json", "url", @@ -1232,9 +1270,9 @@ dependencies = [ [[package]] name = "url" -version = "2.5.2" +version = "2.5.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "22784dbdf76fdde8af1aeda5622b546b422b6fc585325248a2bf9f5e41e94d6c" +checksum = "8d157f1b96d14500ffdc1f10ba712e780825526c03d9a49b4d0324b0d9113ada" dependencies = [ "form_urlencoded", "idna", @@ -1242,10 +1280,16 @@ dependencies = [ ] [[package]] -name = "utf8parse" -version = "0.2.2" +name = 
"utf16_iter" +version = "1.0.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "06abde3611657adf66d383f00b093d7faecc7fa57071cce2578660c9f1010821" +checksum = "c8232dd3cdaed5356e0f716d285e4b40b932ac434100fe9b7e0e8e935b9e6246" + +[[package]] +name = "utf8_iter" +version = "1.0.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b6c140620e7ffbb22c2dee59cafe6084a59b5ffc27a8859a5f0d494b5d52b6be" [[package]] name = "vcpkg" @@ -1261,9 +1305,9 @@ checksum = "9c8d87e72b64a3b4db28d11ce29237c246188f4f51057d65a7eab63b7987e423" [[package]] name = "webpki-roots" -version = "0.26.3" +version = "0.26.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bd7c23921eeb1713a4e851530e9b9756e4fb0e89978582942612524cf09f01cd" +checksum = "841c67bff177718f1d4dfefde8d8f0e78f9b6589319ba88312f567fc5841a958" dependencies = [ "rustls-pki-types", ] @@ -1286,6 +1330,15 @@ dependencies = [ "windows-targets 0.52.6", ] +[[package]] +name = "windows-sys" +version = "0.59.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1e38bc4d79ed67fd075bcc251a1c39b32a1776bbe92e5bef1f0bf1f8c531853b" +dependencies = [ + "windows-targets 0.52.6", +] + [[package]] name = "windows-targets" version = "0.48.5" @@ -1407,8 +1460,108 @@ version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "589f6da84c646204747d1270a2a5661ea66ed1cced2631d546fdfb155959f9ec" +[[package]] +name = "write16" +version = "1.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d1890f4022759daae28ed4fe62859b1236caebfc61ede2f63ed4e695f3f6d936" + +[[package]] +name = "writeable" +version = "0.5.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1e9df38ee2d2c3c5948ea468a8406ff0db0b29ae1ffde1bcf20ef305bcc95c51" + +[[package]] +name = "yoke" +version = "0.7.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6c5b1314b079b0930c31e3af543d8ee1757b1951ae1e1565ec704403a7240ca5" +dependencies = [ + "serde", + "stable_deref_trait", + "yoke-derive", + "zerofrom", +] + +[[package]] +name = "yoke-derive" +version = "0.7.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "28cc31741b18cb6f1d5ff12f5b7523e3d6eb0852bbbad19d73905511d9849b95" +dependencies = [ + "proc-macro2", + "quote", + "syn", + "synstructure", +] + +[[package]] +name = "zerocopy" +version = "0.7.35" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1b9b4fd18abc82b8136838da5d50bae7bdea537c574d8dc1a34ed098d6c166f0" +dependencies = [ + "byteorder", + "zerocopy-derive", +] + +[[package]] +name = "zerocopy-derive" +version = "0.7.35" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fa4f8080344d4671fb4e831a13ad1e68092748387dfc4f55e356242fae12ce3e" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "zerofrom" +version = "0.1.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "91ec111ce797d0e0784a1116d0ddcdbea84322cd79e5d5ad173daeba4f93ab55" +dependencies = [ + "zerofrom-derive", +] + +[[package]] +name = "zerofrom-derive" +version = "0.1.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0ea7b4a3637ea8669cedf0f1fd5c286a17f3de97b8dd5a70a6c167a1730e63a5" +dependencies = [ + "proc-macro2", + "quote", + "syn", + "synstructure", +] + [[package]] name = "zeroize" version = "1.8.1" source = 
"registry+https://github.com/rust-lang/crates.io-index" checksum = "ced3678a2879b30306d323f4542626697a464a97c0a07c9aebf7ebca65cd4dde" + +[[package]] +name = "zerovec" +version = "0.10.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "aa2b893d79df23bfb12d5461018d408ea19dfafe76c2c7ef6d4eba614f8ff079" +dependencies = [ + "yoke", + "zerofrom", + "zerovec-derive", +] + +[[package]] +name = "zerovec-derive" +version = "0.10.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6eafa6dfb17584ea3e2bd6e76e0cc15ad7af12b09abdd1ca55961bed9b1063c6" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] From 6172936f8c965d2050a53d14de0e3410ecc78ad1 Mon Sep 17 00:00:00 2001 From: Michal Moskal Date: Wed, 6 Nov 2024 13:45:41 -0800 Subject: [PATCH 288/301] fix compilation --- core/src/bytes.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/core/src/bytes.rs b/core/src/bytes.rs index 6aa39fdb..17f3878e 100644 --- a/core/src/bytes.rs +++ b/core/src/bytes.rs @@ -1,7 +1,7 @@ use std::mem::size_of; use anyhow::{anyhow, Result}; -use bytemuck::{NoUninit, Pod}; +use bytemuck::{NoUninit, Pod as PodTrait}; use bytemuck_derive::{Pod, Zeroable}; #[derive(Clone, Copy, PartialEq, Eq, Debug, Zeroable, Pod)] @@ -12,7 +12,7 @@ pub fn clone_vec_as_bytes(input: &[T]) -> Vec { bytemuck::cast_slice(input).to_vec() } -pub fn vec_from_bytes(bytes: &[u8]) -> Vec { +pub fn vec_from_bytes(bytes: &[u8]) -> Vec { if bytes.len() % size_of::() != 0 { panic!( "vecT: got {} bytes, needed multiple of {}", From 148399b250ec9bbf0fc97149534035d8fad71164 Mon Sep 17 00:00:00 2001 From: Michal Moskal Date: Wed, 20 Nov 2024 14:58:36 -0800 Subject: [PATCH 289/301] optimize decode_raw() --- core/src/toktree.rs | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/core/src/toktree.rs b/core/src/toktree.rs index 48f66f0b..192d5018 100644 --- a/core/src/toktree.rs +++ b/core/src/toktree.rs @@ -361,7 +361,7 @@ impl TokTrie { let max_examples = 50; let ts_neg = ts.negated(); - let use_neg = ts_neg.num_set() * 20 < ts.num_set(); + let use_neg = ts_neg.num_set() * 10 < ts.num_set(); let ts1 = if use_neg { &ts_neg } else { &ts }; let num_set = ts1.num_set(); let max_tok = std::cmp::min(max_examples, num_set); @@ -472,10 +472,12 @@ impl TokTrie { } pub fn decode_raw(&self, tokens: &[TokenId]) -> Vec { - tokens - .iter() - .flat_map(|t| self.token(*t).to_vec()) - .collect() + let mut res = Vec::new(); + res.reserve(tokens.len() * 6 + 32); // approximately + for &tok in tokens { + res.extend_from_slice(self.token(tok)); + } + res } pub fn decode_str(&self, tokens: &[TokenId]) -> String { From 6bb1d1e8f26afa0ddb96168494b4ee9efff00c18 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Mon, 25 Nov 2024 17:01:31 +0000 Subject: [PATCH 290/301] Bump rustls from 0.23.16 to 0.23.18 in /hf_tokenizers Bumps [rustls](https://github.com/rustls/rustls) from 0.23.16 to 0.23.18. - [Release notes](https://github.com/rustls/rustls/releases) - [Changelog](https://github.com/rustls/rustls/blob/main/CHANGELOG.md) - [Commits](https://github.com/rustls/rustls/compare/v/0.23.16...v/0.23.18) --- updated-dependencies: - dependency-name: rustls dependency-type: indirect ... 
From 6bb1d1e8f26afa0ddb96168494b4ee9efff00c18 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Mon, 25 Nov 2024 17:01:31 +0000 Subject: [PATCH 290/301] Bump rustls from 0.23.16 to 0.23.18 in /hf_tokenizers Bumps [rustls](https://github.com/rustls/rustls) from 0.23.16 to 0.23.18. - [Release notes](https://github.com/rustls/rustls/releases) - [Changelog](https://github.com/rustls/rustls/blob/main/CHANGELOG.md) - [Commits](https://github.com/rustls/rustls/compare/v/0.23.16...v/0.23.18) --- updated-dependencies: - dependency-name: rustls dependency-type: indirect ... Signed-off-by: dependabot[bot] --- hf_tokenizers/Cargo.lock | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/hf_tokenizers/Cargo.lock b/hf_tokenizers/Cargo.lock index 1ef7281d..6c1232d4 100644 --- a/hf_tokenizers/Cargo.lock +++ b/hf_tokenizers/Cargo.lock @@ -939,9 +939,9 @@ dependencies = [ [[package]] name = "rustls" -version = "0.23.16" +version = "0.23.18" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "eee87ff5d9b36712a58574e12e9f0ea80f915a5b0ac518d322b24a465617925e" +checksum = "9c9cc1d47e243d655ace55ed38201c19ae02c148ae56412ab8750e8f0166ab7f" dependencies = [ "log", "once_cell", From 12e73f8c95686afcbdcfea331b9a72e4be3e5424 Mon Sep 17 00:00:00 2001 From: Michal Moskal Date: Mon, 25 Nov 2024 13:36:44 -0800 Subject: [PATCH 291/301] improve default tokenize_bytes_prefix --- core/src/toktree.rs | 43 +++++++++++++++++++++++++++------------ 1 file changed, 31 insertions(+), 12 deletions(-) diff --git a/core/src/toktree.rs b/core/src/toktree.rs index 192d5018..85fb8a89 100644 --- a/core/src/toktree.rs +++ b/core/src/toktree.rs @@ -119,21 +119,36 @@ pub trait TokenizerEnv: Send { /// Tokenize a given byte sequence. /// It will interpret text starting with SPECIAL_TOKEN_PREFIX_BYTE as special tokens. fn tokenize_bytes_prefix(&self, s: &[u8]) -> Vec<TokenId> { - if s.contains(&TokTrie::SPECIAL_TOKEN_PREFIX_BYTE) { - let copy = s + let mut idx = 0; + let ff = TokTrie::SPECIAL_TOKEN_PREFIX_BYTE; + let mut result = Vec::new(); + let trie = self.tok_trie(); + while idx < s.len() { + let normal_len = s[idx..] .iter() - .filter_map(|&b| { - if b == TokTrie::SPECIAL_TOKEN_PREFIX_BYTE { - None - } else { - Some(b) + .position(|&x| x == ff) + .unwrap_or(s.len() - idx); + if normal_len != 0 { + result.extend_from_slice(&self.tokenize_bytes(&s[idx..idx + normal_len])); + idx += normal_len; + } + idx += 1; // skip ff + if idx + 3 < s.len() && s[idx] == '<' as u8 { + let spec_len = s[idx..std::cmp::min(s.len(), idx + 100)] + .iter() + .position(|&x| x == '>' as u8); + if let Some(mut spec_len) = spec_len { + spec_len += 1; + let spec_token = &s[idx - 1..idx + spec_len]; + if let Some(id) = trie.token_id_at_bytes(spec_token) { + result.push(id); + idx += spec_len; } - }) - .collect::<Vec<_>>(); - self.tokenize_bytes(&copy) - } else { - self.tokenize_bytes(s) + } + } } + + result } /// Tokenize a string coming from user. It may or may not interpret <|special_tokens|> as special. @@ -741,6 +756,10 @@ impl TokTrie { Some(n) } + pub fn token_id_at_bytes(&self, bytes: &[u8]) -> Option<TokenId> { + self.child_at_bytes(self.root(), bytes).and_then(|n| n.token_id()) + } + pub fn compute_bias(&self, r: &mut impl Recognizer, logits: &mut SimpleVob) { self.compute_bias_ext(r, logits, &[]); }
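
The new default above no longer strips the 0xFF prefix bytes. Instead it splits the input at each marker, tokenizes the ordinary byte runs, and resolves `<...>`-delimited special-token names directly in the trie, looking at most 100 bytes ahead for the closing `>`. A self-contained sketch of that scanning loop, with closures standing in for the tokenizer and the trie lookup (both are assumptions of the sketch):

```rust
const MARKER: u8 = 0xff;

/// Sketch of the splitting loop in the new default tokenize_bytes_prefix().
/// `tokenize` handles ordinary bytes; `special_id` resolves "\xFF<name>" spans.
fn split_on_marker(
    s: &[u8],
    tokenize: impl Fn(&[u8]) -> Vec<u32>,
    special_id: impl Fn(&[u8]) -> Option<u32>,
) -> Vec<u32> {
    let mut idx = 0;
    let mut result = Vec::new();
    while idx < s.len() {
        // tokenize the run of ordinary bytes up to the next marker
        let normal_len = s[idx..]
            .iter()
            .position(|&x| x == MARKER)
            .unwrap_or(s.len() - idx);
        if normal_len != 0 {
            result.extend(tokenize(&s[idx..idx + normal_len]));
            idx += normal_len;
        }
        idx += 1; // skip the marker byte itself
        if idx + 3 < s.len() && s[idx] == b'<' {
            // special names look like "<|eot_id|>"; find the closing '>'
            let window = &s[idx..s.len().min(idx + 100)];
            if let Some(mut spec_len) = window.iter().position(|&x| x == b'>') {
                spec_len += 1;
                if let Some(id) = special_id(&s[idx - 1..idx + spec_len]) {
                    result.push(id);
                    idx += spec_len;
                }
            }
        }
    }
    result
}
```

Note that if the name does not resolve to a token id, the marker byte is silently dropped and scanning resumes, which mirrors the behavior of the hunk above.
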
From f4c1f0492acc574d48d39a9955790d009df72d87 Mon Sep 17 00:00:00 2001 From: Michal Moskal Date: Mon, 25 Nov 2024 13:38:25 -0800 Subject: [PATCH 292/301] remove obsolete TokenizerEnv::stop() --- core/src/toktree.rs | 8 -------- hf_tokenizers/src/lib.rs | 4 ---- 2 files changed, 12 deletions(-) diff --git a/core/src/toktree.rs b/core/src/toktree.rs index 85fb8a89..07692e15 100644 --- a/core/src/toktree.rs +++ b/core/src/toktree.rs @@ -105,10 +105,6 @@ pub trait Recognizer { } pub trait TokenizerEnv: Send { - /// Stop the program; not used. - // TODO remove this - fn stop(&self) -> !; - /// Associated trie. fn tok_trie(&self) -> &TokTrie; @@ -185,10 +181,6 @@ impl TokenizerEnv for TokEnvWithTrie { &self.tok_trie } - fn stop(&self) -> ! { - self.base_env.stop() - } - fn tokenize_bytes(&self, s: &[u8]) -> Vec<TokenId> { self.base_env.tokenize_bytes(s) } diff --git a/hf_tokenizers/src/lib.rs b/hf_tokenizers/src/lib.rs index 9dde8f0a..60411af4 100644 --- a/hf_tokenizers/src/lib.rs +++ b/hf_tokenizers/src/lib.rs @@ -253,10 +253,6 @@ impl ByteTokenizerEnv { } impl TokenizerEnv for ByteTokenizerEnv { - fn stop(&self) -> ! { - panic!("stop called") - } - fn tok_trie(&self) -> &TokTrie { &self.tok_trie } From e22222c1a4b5b3b1c5cd7e20858b9592604cc590 Mon Sep 17 00:00:00 2001 From: Michal Moskal Date: Mon, 25 Nov 2024 13:42:15 -0800 Subject: [PATCH 293/301] add tokenize_is_approximate() method to TokenizerEnv trait --- core/src/toktree.rs | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/core/src/toktree.rs b/core/src/toktree.rs index 07692e15..37c8fe45 100644 --- a/core/src/toktree.rs +++ b/core/src/toktree.rs @@ -161,6 +161,13 @@ pub trait TokenizerEnv: Send { fn eos_token(&self) -> TokenId { self.tok_trie().eos_token() } + + /// If this returns true, this tokenizer may return non-canonical tokenizations + /// and should generally not be used for forcing tokens. + /// Typically, it will just use TokTrie::greedy_tokenize(). + fn tokenize_is_approximate(&self) -> bool { + false + } }
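
The hook above lets an environment advertise that its tokenizations may be non-canonical, so callers can avoid using it to force tokens. A hypothetical implementation sketch (the `ApproxAware` trait below is a cut-down stand-in for `TokenizerEnv`, and `GreedyEnv` is invented for illustration; neither is part of the patch):

```rust
/// Cut-down stand-in for the TokenizerEnv hook, for illustration only.
trait ApproxAware {
    fn tokenize_bytes(&self, s: &[u8]) -> Vec<u32>;
    /// Defaults to false, like the trait method added above.
    fn tokenize_is_approximate(&self) -> bool {
        false
    }
}

/// Hypothetical greedy longest-prefix tokenizer; greedy matching can
/// disagree with a model's canonical tokenization, so it opts in.
struct GreedyEnv {
    vocab: Vec<Vec<u8>>, // token id -> bytes; must cover every byte
}

impl ApproxAware for GreedyEnv {
    fn tokenize_bytes(&self, s: &[u8]) -> Vec<u32> {
        let mut out = Vec::new();
        let mut i = 0;
        while i < s.len() {
            // pick the longest vocab entry that is a prefix of the rest
            let (id, len) = self
                .vocab
                .iter()
                .enumerate()
                .filter(|(_, w)| !w.is_empty() && s[i..].starts_with(w.as_slice()))
                .map(|(id, w)| (id as u32, w.len()))
                .max_by_key(|&(_, len)| len)
                .expect("vocab must cover every byte");
            out.push(id);
            i += len;
        }
        out
    }

    fn tokenize_is_approximate(&self) -> bool {
        true
    }
}
```
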
From fafc88bbd19ba9db522ff112b0a7dcf8a610cc78 Mon Sep 17 00:00:00 2001 From: Michal Moskal Date: Mon, 25 Nov 2024 13:43:49 -0800 Subject: [PATCH 294/301] update CI workflow to use --locked flag for cargo build --- .github/workflows/rust.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/rust.yml b/.github/workflows/rust.yml index 2df7c1fc..369c923a 100644 --- a/.github/workflows/rust.yml +++ b/.github/workflows/rust.yml @@ -17,8 +17,8 @@ jobs: steps: - uses: actions/checkout@v4 - name: Build core - run: cargo build --verbose + run: cargo build --verbose --locked working-directory: core - name: Build for hf-tokenizers - run: cargo build --verbose + run: cargo build --verbose --locked working-directory: hf_tokenizers From 2a90302614b92fcb6fa7bd5bc9be8702742c770a Mon Sep 17 00:00:00 2001 From: Michal Moskal Date: Mon, 25 Nov 2024 14:01:59 -0800 Subject: [PATCH 295/301] update tokenize_is_approximate() to tokenize_is_canonical() --- core/src/toktree.rs | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/core/src/toktree.rs b/core/src/toktree.rs index 37c8fe45..1d5ef975 100644 --- a/core/src/toktree.rs +++ b/core/src/toktree.rs @@ -162,11 +162,11 @@ pub trait TokenizerEnv: Send { self.tok_trie().eos_token() } - /// If this returns true, this tokenizer may return non-canonical tokenizations - /// and should generally not be used for forcing tokens. - /// Typically, it will just use TokTrie::greedy_tokenize(). - fn tokenize_is_approximate(&self) -> bool { - false + /// If this returns true, this tokenizer always returns canonical tokenizations + /// and can be used for forcing tokens. + /// Non-canonical tokenizers will typically just use TokTrie::greedy_tokenize(). + fn tokenize_is_canonical(&self) -> bool { + true + } } From d8545d3392fc46683b95b416b51b13b86328f08f Mon Sep 17 00:00:00 2001 From: Michal Moskal Date: Wed, 27 Nov 2024 21:04:25 -0800 Subject: [PATCH 296/301] limit debug tokens to a maximum of 200 and indicate truncation in output --- core/src/toktree.rs | 20 ++++++++++++++++++-- 1 file changed, 18 insertions(+), 2 deletions(-) diff --git a/core/src/toktree.rs b/core/src/toktree.rs index 1d5ef975..ecdf707e 100644 --- a/core/src/toktree.rs +++ b/core/src/toktree.rs @@ -422,8 +422,16 @@ impl TokTrie { .join("‧") } + pub const MAX_DBG_TOKENS: usize = 200; + pub fn tokens_dbg(&self, toks: &[u32]) -> String { - let joined = toks + let (limited, toks) = if toks.len() > Self::MAX_DBG_TOKENS { + (true, &toks[0..Self::MAX_DBG_TOKENS]) + } else { + (false, toks) + }; + + let mut joined = toks .iter() .map(|t| { let s = self.token_dbg(*t); @@ -436,6 +444,10 @@ impl TokTrie { .collect::<Vec<_>>() .join("‧"); + if limited { + joined.push_str("…"); + } + format!("\"{}\"", joined) } @@ -516,6 +528,9 @@ impl TokTrie { for c in self.node_children(n) { if let Some(tok) = c.token_id() { res.push(tok); + if res.len() > Self::MAX_DBG_TOKENS + 1 { + break; + } } stack.push(c); } @@ -756,7 +771,8 @@ impl TokTrie { } pub fn token_id_at_bytes(&self, bytes: &[u8]) -> Option<TokenId> { - self.child_at_bytes(self.root(), bytes).and_then(|n| n.token_id()) + self.child_at_bytes(self.root(), bytes) + .and_then(|n| n.token_id()) }
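
The truncation in tokens_dbg() caps output at MAX_DBG_TOKENS entries and marks the cut with an ellipsis. The same pattern in isolation, as a sketch with a closure standing in for token_dbg() (an assumption of the sketch, not the patch's code):

```rust
const MAX_DBG_TOKENS: usize = 200;

/// Sketch of the capped debug join: format at most MAX_DBG_TOKENS
/// tokens and append an ellipsis when the input was truncated.
fn tokens_dbg(toks: &[u32], dbg_one: impl Fn(u32) -> String) -> String {
    let (limited, toks) = if toks.len() > MAX_DBG_TOKENS {
        (true, &toks[..MAX_DBG_TOKENS])
    } else {
        (false, toks)
    };
    let mut joined = toks
        .iter()
        .map(|&t| dbg_one(t))
        .collect::<Vec<_>>()
        .join("‧");
    if limited {
        joined.push_str("…");
    }
    format!("\"{}\"", joined)
}
```
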
From 848784ea95a89077c95e77b80861cd19d5abfce7 Mon Sep 17 00:00:00 2001 From: Michal Moskal Date: Fri, 29 Nov 2024 13:56:50 -0800 Subject: [PATCH 297/301] add docs for special token prefix --- core/src/toktree.rs | 1 + special_tokens.md | 27 +++++++++++++++++++++++++++ 2 files changed, 28 insertions(+) create mode 100644 special_tokens.md diff --git a/core/src/toktree.rs b/core/src/toktree.rs index ecdf707e..e518f3fd 100644 --- a/core/src/toktree.rs +++ b/core/src/toktree.rs @@ -267,6 +267,7 @@ const LEN_BITS: u32 = 10; impl TokTrie { + // see https://github.com/microsoft/toktrie/blob/main/special_tokens.md pub const SPECIAL_TOKEN_PREFIX_BYTE: u8 = 0xff; pub fn from(info: &TokRxInfo, words: &Vec<Vec<u8>>) -> Self { diff --git a/special_tokens.md b/special_tokens.md new file mode 100644 index 00000000..b69e23ad --- /dev/null +++ b/special_tokens.md @@ -0,0 +1,27 @@ +# Support for special tokens + +Tokenizers typically include special tokens, such as +`<|end_of_text|>`, `<|eot_id|>`, `<|python_tag|>`, `<|start_header_id|>`, etc. +This library is tasked with translating between the byte sequences +and tokens. +If you see bytes `<|eot_id|>` in the input, you may or may not want to treat them +as a special token. + +The library assumes that by default you want to treat them as bytes +(so they would be tokenized as `<|`, `eot`, `_`, `id`, `|>` or similar). +To indicate that you want to treat them as a special token, you need to +prefix them with byte 0xFF (255) (`TokTrie::SPECIAL_TOKEN_PREFIX_BYTE`). + +Byte FF is chosen because it is not a valid UTF-8 byte, so it should not normally +occur in regular inputs. +In Rust, you cannot have byte FF in `&str`, only in `&[u8]`. +In Python note the difference between `b"\xFF"` and `"\xFF".encode("utf-8")` +(or equivalently `"\u00FF".encode("utf-8")`), which is `b"\xC3\xBF"`. + +If you're constructing it manually, +the token array passed to the `TokTrie` constructor should include the special tokens +with the prefix byte FF. + +The llguidance library does not expose the FF bytes externally +(except for special `tokenize_bytes_prefix` methods), so you +generally don't need to worry about them, except when building the `TokTrie`.
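
To make the convention in special_tokens.md concrete: a sketch of assembling the word list for the `TokTrie` constructor, with ordinary tokens stored as raw bytes and special tokens prefixed with the 0xFF byte (the tiny vocabulary here is invented for illustration):

```rust
const SPECIAL_TOKEN_PREFIX_BYTE: u8 = 0xff;

/// Build a TokTrie-style word list: ordinary tokens as raw bytes,
/// special tokens with the 0xFF prefix byte prepended.
fn build_words(ordinary: &[&str], special: &[&str]) -> Vec<Vec<u8>> {
    let mut words: Vec<Vec<u8>> =
        ordinary.iter().map(|t| t.as_bytes().to_vec()).collect();
    for name in special {
        let mut bytes = vec![SPECIAL_TOKEN_PREFIX_BYTE];
        bytes.extend_from_slice(name.as_bytes());
        words.push(bytes);
    }
    words
}

fn main() {
    let words = build_words(&["hello", " world"], &["<|end_of_text|>"]);
    // FF cannot appear in valid UTF-8, so the prefixed entry can never
    // collide with ordinary text.
    assert_eq!(words[2][0], 0xff);
}
```
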
From cbdd3d717e2d1671a326d68f74d5054a5a79c0c9 Mon Sep 17 00:00:00 2001 From: Michal Moskal Date: Fri, 29 Nov 2024 14:15:37 -0800 Subject: [PATCH 298/301] rename SPECIAL_TOKEN_PREFIX_BYTE to SPECIAL_TOKEN_MARKER --- core/src/toktree.rs | 18 +++++++++--------- hf_tokenizers/src/lib.rs | 2 +- special_tokens.md | 11 +++++------ 3 files changed, 15 insertions(+), 16 deletions(-) diff --git a/core/src/toktree.rs b/core/src/toktree.rs index e518f3fd..0f43b078 100644 --- a/core/src/toktree.rs +++ b/core/src/toktree.rs @@ -113,10 +113,10 @@ pub trait TokenizerEnv: Send { fn tokenize_bytes(&self, s: &[u8]) -> Vec<TokenId>; /// Tokenize a given byte sequence. - /// It will interpret text starting with SPECIAL_TOKEN_PREFIX_BYTE as special tokens. - fn tokenize_bytes_prefix(&self, s: &[u8]) -> Vec<TokenId> { + /// It will interpret text starting with SPECIAL_TOKEN_MARKER as special tokens. + fn tokenize_bytes_marker(&self, s: &[u8]) -> Vec<TokenId> { let mut idx = 0; - let ff = TokTrie::SPECIAL_TOKEN_PREFIX_BYTE; + let ff = TokTrie::SPECIAL_TOKEN_MARKER; let mut result = Vec::new(); let trie = self.tok_trie(); while idx < s.len() { let normal_len = s[idx..] @@ -268,7 +268,7 @@ const LEN_BITS: u32 = 10; impl TokTrie { // see https://github.com/microsoft/toktrie/blob/main/special_tokens.md - pub const SPECIAL_TOKEN_PREFIX_BYTE: u8 = 0xff; + pub const SPECIAL_TOKEN_MARKER: u8 = 0xff; pub fn from(info: &TokRxInfo, words: &Vec<Vec<u8>>) -> Self { let mut trie = TrieHash::new(0xff); @@ -460,7 +460,7 @@ impl TokTrie { } else { // format!("{:?}[{}]", self.token_str(idx), idx) let bytes = self.token(idx); - if bytes.len() > 1 && bytes[0] == TokTrie::SPECIAL_TOKEN_PREFIX_BYTE { + if bytes.len() > 1 && bytes[0] == TokTrie::SPECIAL_TOKEN_MARKER { String::from_utf8_lossy(&bytes[1..]).to_string() } else { let s = String::from_utf8_lossy(bytes); @@ -492,8 +492,8 @@ pub fn decode(&self, tokens: &[TokenId]) -> Vec<u8> { let mut bytes = self.decode_raw(tokens); - if bytes.contains(&TokTrie::SPECIAL_TOKEN_PREFIX_BYTE) { - bytes.retain(|&b| b != TokTrie::SPECIAL_TOKEN_PREFIX_BYTE); + if bytes.contains(&TokTrie::SPECIAL_TOKEN_MARKER) { + bytes.retain(|&b| b != TokTrie::SPECIAL_TOKEN_MARKER); } bytes } @@ -512,7 +512,7 @@ } pub fn get_special_token(&self, name: &str) -> Option<TokenId> { - self.child_at_byte(self.root(), TokTrie::SPECIAL_TOKEN_PREFIX_BYTE) + self.child_at_byte(self.root(), TokTrie::SPECIAL_TOKEN_MARKER) .and_then(|n| { self.child_at_bytes(n, name.as_bytes()) .and_then(|n| n.token_id()) }) @@ -522,7 +522,7 @@ pub fn get_special_tokens(&self) -> Vec<TokenId> { let mut res = Vec::new(); let pref_node = self - .child_at_byte(self.root(), TokTrie::SPECIAL_TOKEN_PREFIX_BYTE) + .child_at_byte(self.root(), TokTrie::SPECIAL_TOKEN_MARKER) .expect("missing special token prefix"); let mut stack = vec![pref_node]; while let Some(n) = stack.pop() { diff --git a/hf_tokenizers/src/lib.rs b/hf_tokenizers/src/lib.rs index 60411af4..540625d4 100644 --- a/hf_tokenizers/src/lib.rs +++ b/hf_tokenizers/src/lib.rs @@ -153,7 +153,7 @@ impl ByteTokenizer { if let Some(tok_name) = res.hf_tokenizer.id_to_token(tok_id) { let bytes = if added.contains_key(&tok_id) { let mut bytes = tok_name.as_bytes().to_vec(); - bytes.insert(0, TokTrie::SPECIAL_TOKEN_PREFIX_BYTE); + bytes.insert(0, TokTrie::SPECIAL_TOKEN_MARKER); bytes } else if is_byte_fallback { if tok_name.len() == 6 && tok_name.starts_with("<0x") && tok_name.ends_with(">") diff --git a/special_tokens.md b/special_tokens.md index b69e23ad..c30f2e7e 100644 --- a/special_tokens.md +++ b/special_tokens.md @@ -10,18 +10,17 @@ as a special token. The library assumes that by default you want to treat them as bytes (so they would be tokenized as `<|`, `eot`, `_`, `id`, `|>` or similar). To indicate that you want to treat them as a special token, you need to -prefix them with byte 0xFF (255) (`TokTrie::SPECIAL_TOKEN_PREFIX_BYTE`). +prefix them with "marker" byte 0xFF (255) (`TokTrie::SPECIAL_TOKEN_MARKER`). -Byte FF is chosen because it is not a valid UTF-8 byte, so it should not normally +Byte FF is chosen as a marker because it is not a valid UTF-8 byte, so it should not normally occur in regular inputs. In Rust, you cannot have byte FF in `&str`, only in `&[u8]`. In Python note the difference between `b"\xFF"` and `"\xFF".encode("utf-8")` (or equivalently `"\u00FF".encode("utf-8")`), which is `b"\xC3\xBF"`. -If you're constructing it manually, -the token array passed to the `TokTrie` constructor should include the special tokens -with the prefix byte FF. +If you're constructing the token array for the `TokTrie` constructor manually, +it should include the special tokens prefixed with the marker byte FF. The llguidance library does not expose the FF bytes externally -(except for special `tokenize_bytes_prefix` methods), so you +(except for special `tokenize_bytes_marker` methods), so you generally don't need to worry about them, except when building the `TokTrie`. From f38eab5c52e895d512bb40aa2a177723ced24739 Mon Sep 17 00:00:00 2001 From: Michal Moskal Date: Sat, 30 Nov 2024 10:37:18 -0800 Subject: [PATCH 299/301] more loose versioning of hf tokenizers --- hf_tokenizers/Cargo.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/hf_tokenizers/Cargo.toml b/hf_tokenizers/Cargo.toml index 0536d99b..435114d4 100644 --- a/hf_tokenizers/Cargo.toml +++ b/hf_tokenizers/Cargo.toml @@ -9,5 +9,5 @@ serde = { version = "1.0.192", features = ["derive"] } serde_json = "1.0.108" anyhow = "1.0.75" rustc-hash = { version = "2.0.0" } -tokenizers = { version = "0.19.1", features = ["http"] } +tokenizers = { version = ">=0.19.1, <1.0.0", features = ["http"] } log = "0.4.21" From bb21d4114947024da83451e32116180140988a7f Mon Sep 17 00:00:00 2001 From: Michal Moskal Date: Sat, 30 Nov 2024 10:49:44 -0800 Subject: [PATCH 300/301] move files around before merge --- CODE_OF_CONDUCT.md | 9 ---- LICENSE | 21 ---------- SECURITY.md | 41 ------------------- SUPPORT.md | 13 ------ special_tokens.md => docs/special_tokens.md | 0 implementation.md => docs/toktrie.md | 0 {core => toktrie}/Cargo.lock | 0 {core => toktrie}/Cargo.toml | 0 README.md => toktrie/README.md | 0 {core => toktrie}/src/bytes.rs | 0 {core => toktrie}/src/lib.rs | 0 {core =>
toktrie}/src/recognizer.rs (100%) rename {core => toktrie}/src/rng.rs (100%) rename {core => toktrie}/src/svob.rs (100%) rename {core => toktrie}/src/toktree.rs (100%) rename {hf_tokenizers => toktrie_hf_tokenizers}/Cargo.lock (100%) rename {hf_tokenizers => toktrie_hf_tokenizers}/Cargo.toml (100%) rename {hf_tokenizers => toktrie_hf_tokenizers}/src/lib.rs (100%) diff --git a/CODE_OF_CONDUCT.md b/CODE_OF_CONDUCT.md deleted file mode 100644 index f9ba8cf6..00000000 --- a/CODE_OF_CONDUCT.md +++ /dev/null @@ -1,9 +0,0 @@ -# Microsoft Open Source Code of Conduct - -This project has adopted the [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/). - -Resources: - -- [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/) -- [Microsoft Code of Conduct FAQ](https://opensource.microsoft.com/codeofconduct/faq/) -- Contact [opencode@microsoft.com](mailto:opencode@microsoft.com) with questions or concerns diff --git a/LICENSE b/LICENSE deleted file mode 100644 index 9e841e7a..00000000 --- a/LICENSE +++ /dev/null @@ -1,21 +0,0 @@ - MIT License - - Copyright (c) Microsoft Corporation. - - Permission is hereby granted, free of charge, to any person obtaining a copy - of this software and associated documentation files (the "Software"), to deal - in the Software without restriction, including without limitation the rights - to use, copy, modify, merge, publish, distribute, sublicense, and/or sell - copies of the Software, and to permit persons to whom the Software is - furnished to do so, subject to the following conditions: - - The above copyright notice and this permission notice shall be included in all - copies or substantial portions of the Software. - - THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR - IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, - FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE - AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER - LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, - OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE - SOFTWARE diff --git a/SECURITY.md b/SECURITY.md deleted file mode 100644 index b3c89efc..00000000 --- a/SECURITY.md +++ /dev/null @@ -1,41 +0,0 @@ - - -## Security - -Microsoft takes the security of our software products and services seriously, which includes all source code repositories managed through our GitHub organizations, which include [Microsoft](https://github.com/Microsoft), [Azure](https://github.com/Azure), [DotNet](https://github.com/dotnet), [AspNet](https://github.com/aspnet) and [Xamarin](https://github.com/xamarin). - -If you believe you have found a security vulnerability in any Microsoft-owned repository that meets [Microsoft's definition of a security vulnerability](https://aka.ms/security.md/definition), please report it to us as described below. - -## Reporting Security Issues - -**Please do not report security vulnerabilities through public GitHub issues.** - -Instead, please report them to the Microsoft Security Response Center (MSRC) at [https://msrc.microsoft.com/create-report](https://aka.ms/security.md/msrc/create-report). - -If you prefer to submit without logging in, send email to [secure@microsoft.com](mailto:secure@microsoft.com). If possible, encrypt your message with our PGP key; please download it from the [Microsoft Security Response Center PGP Key page](https://aka.ms/security.md/msrc/pgp). 
- -You should receive a response within 24 hours. If for some reason you do not, please follow up via email to ensure we received your original message. Additional information can be found at [microsoft.com/msrc](https://www.microsoft.com/msrc). - -Please include the requested information listed below (as much as you can provide) to help us better understand the nature and scope of the possible issue: - - * Type of issue (e.g. buffer overflow, SQL injection, cross-site scripting, etc.) - * Full paths of source file(s) related to the manifestation of the issue - * The location of the affected source code (tag/branch/commit or direct URL) - * Any special configuration required to reproduce the issue - * Step-by-step instructions to reproduce the issue - * Proof-of-concept or exploit code (if possible) - * Impact of the issue, including how an attacker might exploit the issue - -This information will help us triage your report more quickly. - -If you are reporting for a bug bounty, more complete reports can contribute to a higher bounty award. Please visit our [Microsoft Bug Bounty Program](https://aka.ms/security.md/msrc/bounty) page for more details about our active programs. - -## Preferred Languages - -We prefer all communications to be in English. - -## Policy - -Microsoft follows the principle of [Coordinated Vulnerability Disclosure](https://aka.ms/security.md/cvd). - - diff --git a/SUPPORT.md b/SUPPORT.md deleted file mode 100644 index 382f1b8b..00000000 --- a/SUPPORT.md +++ /dev/null @@ -1,13 +0,0 @@ -# Support - -## How to file issues and get help - -This project uses GitHub Issues to track bugs and feature requests. Please search the existing -issues before filing new issues to avoid duplicates. For new issues, file your bug or -feature request as a new Issue. - -For help and questions about using this project, please use GitHub Discussions. - -## Microsoft Support Policy - -Support for this project is limited to the resources listed above. 
diff --git a/special_tokens.md b/docs/special_tokens.md similarity index 100% rename from special_tokens.md rename to docs/special_tokens.md diff --git a/implementation.md b/docs/toktrie.md similarity index 100% rename from implementation.md rename to docs/toktrie.md diff --git a/core/Cargo.lock b/toktrie/Cargo.lock similarity index 100% rename from core/Cargo.lock rename to toktrie/Cargo.lock diff --git a/core/Cargo.toml b/toktrie/Cargo.toml similarity index 100% rename from core/Cargo.toml rename to toktrie/Cargo.toml diff --git a/README.md b/toktrie/README.md similarity index 100% rename from README.md rename to toktrie/README.md diff --git a/core/src/bytes.rs b/toktrie/src/bytes.rs similarity index 100% rename from core/src/bytes.rs rename to toktrie/src/bytes.rs diff --git a/core/src/lib.rs b/toktrie/src/lib.rs similarity index 100% rename from core/src/lib.rs rename to toktrie/src/lib.rs diff --git a/core/src/recognizer.rs b/toktrie/src/recognizer.rs similarity index 100% rename from core/src/recognizer.rs rename to toktrie/src/recognizer.rs diff --git a/core/src/rng.rs b/toktrie/src/rng.rs similarity index 100% rename from core/src/rng.rs rename to toktrie/src/rng.rs diff --git a/core/src/svob.rs b/toktrie/src/svob.rs similarity index 100% rename from core/src/svob.rs rename to toktrie/src/svob.rs diff --git a/core/src/toktree.rs b/toktrie/src/toktree.rs similarity index 100% rename from core/src/toktree.rs rename to toktrie/src/toktree.rs diff --git a/hf_tokenizers/Cargo.lock b/toktrie_hf_tokenizers/Cargo.lock similarity index 100% rename from hf_tokenizers/Cargo.lock rename to toktrie_hf_tokenizers/Cargo.lock diff --git a/hf_tokenizers/Cargo.toml b/toktrie_hf_tokenizers/Cargo.toml similarity index 100% rename from hf_tokenizers/Cargo.toml rename to toktrie_hf_tokenizers/Cargo.toml diff --git a/hf_tokenizers/src/lib.rs b/toktrie_hf_tokenizers/src/lib.rs similarity index 100% rename from hf_tokenizers/src/lib.rs rename to toktrie_hf_tokenizers/src/lib.rs From eae9b7fca1b7db7598988e4547160167c64aab45 Mon Sep 17 00:00:00 2001 From: Michal Moskal Date: Sat, 30 Nov 2024 10:52:17 -0800 Subject: [PATCH 301/301] remove conflicting files --- .github/workflows/rust.yml | 24 ------------------------ .gitignore | 2 -- 2 files changed, 26 deletions(-) delete mode 100644 .github/workflows/rust.yml delete mode 100644 .gitignore diff --git a/.github/workflows/rust.yml b/.github/workflows/rust.yml deleted file mode 100644 index 369c923a..00000000 --- a/.github/workflows/rust.yml +++ /dev/null @@ -1,24 +0,0 @@ -name: Rust - -on: - push: - branches: [ "main" ] - pull_request: - branches: [ "main" ] - -env: - CARGO_TERM_COLOR: always - -jobs: - build: - - runs-on: ubuntu-latest - - steps: - - uses: actions/checkout@v4 - - name: Build core - run: cargo build --verbose --locked - working-directory: core - - name: Build for hf-tokenizers - run: cargo build --verbose --locked - working-directory: hf_tokenizers diff --git a/.gitignore b/.gitignore deleted file mode 100644 index 847709f2..00000000 --- a/.gitignore +++ /dev/null @@ -1,2 +0,0 @@ -target -tmp