From 1c4d60bab41925acbad4f714cf315f254e2b52d2 Mon Sep 17 00:00:00 2001
From: Michal Moskal
Date: Mon, 18 Mar 2024 21:30:56 +0000
Subject: [PATCH] account for duplicate tokens; see #78

---
 controllers/aici_abi/src/svob.rs    | 10 +++
 controllers/aici_abi/src/toktree.rs | 94 +++++++++++++++++++++--------
 2 files changed, 79 insertions(+), 25 deletions(-)

diff --git a/controllers/aici_abi/src/svob.rs b/controllers/aici_abi/src/svob.rs
index 6513080d..c5fd1997 100644
--- a/controllers/aici_abi/src/svob.rs
+++ b/controllers/aici_abi/src/svob.rs
@@ -41,6 +41,16 @@ impl SimpleVob {
         self.data.iter().map(|x| x.count_ones() as usize).sum()
     }
 
+    pub fn negated(&self, size: usize) -> Self {
+        let mut r = Self::new();
+        r.data = self.data.iter().map(|x| !x).collect();
+        for i in size..r.len() {
+            // disallow tokens that are out of range
+            r.disallow_token(i as TokenId);
+        }
+        r
+    }
+
     pub unsafe fn as_ptr(&self) -> *const u32 {
         self.data.as_ptr()
     }
diff --git a/controllers/aici_abi/src/toktree.rs b/controllers/aici_abi/src/toktree.rs
index 5e48878c..4c87b206 100644
--- a/controllers/aici_abi/src/toktree.rs
+++ b/controllers/aici_abi/src/toktree.rs
@@ -1,6 +1,8 @@
 // use 8:24 encoding - num_ch:tok_id (ch_byte:ch_off)* - 8 bytes per tree node
 // special case num_ch=0xff -> num_ch=0x100
 
+use rustc_hash::FxHashMap;
+
 use crate::{
     bytes::{
         box_from_bytes, clone_as_bytes, clone_vec_as_bytes, vec_from_bytes, TokRxInfo, TokenId,
@@ -53,6 +55,7 @@ pub struct TokTrie {
     token_data: Vec<u8>,
     nodes: Vec<TrieNode>,
     max_token_len: usize,
+    token_duplicates: FxHashMap<TokenId, Vec<TokenId>>,
 }
 
 #[repr(C)]
@@ -126,7 +129,7 @@ impl TokTrie {
         assert!(info.vocab_size == words.len() as u32);
         for (idx, word) in words.iter().enumerate() {
             if word.len() > 0 {
-                trie.insert(word, idx as u32)
+                trie.insert(word, idx as u32);
             }
             assert!(word.len() < 0xff);
             let desc = (word.len() as u32) | ((token_data.len() as u32) << 8);
@@ -141,15 +144,27 @@ impl TokTrie {
             token_data,
             nodes,
             max_token_len: 0,
+            token_duplicates: FxHashMap::default(),
         };
-        r.max_token_len = (0..info.vocab_size)
-            .map(|idx| r.token(idx).len())
-            .max()
-            .unwrap();
-        r.validate();
+        r.finalize_ctor();
         r
     }
 
+    fn finalize_ctor(&mut self) {
+        for tok_id in 0..self.info.vocab_size {
+            let bytes = self.token(tok_id);
+            let tok_ids = self.greedy_tokenize(bytes);
+            self.max_token_len = std::cmp::max(self.max_token_len, bytes.len());
+            if tok_ids.len() == 1 && tok_ids[0] != tok_id {
+                self.token_duplicates
+                    .entry(tok_ids[0])
+                    .or_insert_with(Vec::new)
+                    .push(tok_id);
+            }
+        }
+        self.validate();
+    }
+
     fn node_offset(&self, n: &TrieNode) -> usize {
         let off = unsafe { (n as *const TrieNode).offset_from(self.root() as *const TrieNode) };
         assert!(off >= 0);
@@ -184,11 +199,14 @@ impl TokTrie {
     }
 
     pub fn token_set_dbg(&self, ts: &SimpleVob) -> String {
-        let num_set = ts.num_set();
+        let ts_neg = ts.negated(self.vocab_size());
+        let use_neg = ts_neg.num_set() * 20 < ts.num_set();
+        let ts1 = if use_neg { &ts_neg } else { ts };
+        let num_set = ts1.num_set();
         let max_tok = std::cmp::min(100, num_set);
         let mut token_names = Vec::new();
         for idx in 0..self.vocab_size() {
-            if ts.is_allowed(idx as TokenId) {
+            if ts1.is_allowed(idx as TokenId) {
                 token_names.push(self.token_dbg(idx as TokenId));
                 if token_names.len() >= max_tok {
                     break;
@@ -199,9 +217,10 @@ impl TokTrie {
             token_names.push("...".to_string());
         }
         format!(
-            "TokenSet: {}/{}; {}",
-            num_set,
+            "TokenSet: {}/{}; {}{}",
+            ts.num_set(),
             self.vocab_size(),
+            if use_neg { "ALL EXCEPT " } else { "" },
             token_names.join(", ")
         )
     }
@@ -243,7 +262,21 @@ impl TokTrie {
format!("OOB[{}]", idx) } else { // format!("{:?}[{}]", self.token_str(idx), idx) - format!("{:?}", self.token_str(idx)) + let s = self.token_str(idx); + if s.len() == 0 { + format!("EMPTY[{}]", idx) + } else if !s.contains('\u{fffd}') { + format!("{:?}", s) + } else { + let bytes = self.token(idx); + format!( + "HEX[{}]", + bytes + .iter() + .map(|b| format!("{:02x}", b)) + .collect::(), + ) + } } } @@ -351,12 +384,9 @@ impl TokTrie { token_data, nodes, max_token_len: 0, + token_duplicates: FxHashMap::default(), }; - r.validate(); - r.max_token_len = (0..r.info.vocab_size) - .map(|idx| r.token(idx).len()) - .max() - .unwrap(); + r.finalize_ctor(); r } @@ -422,13 +452,14 @@ impl TokTrie { assert!(bytes == self.token(tid)); let root = self.root(); if bytes.len() > 0 { - assert!( - self.child_at_bytes(root, &bytes) - .unwrap() - .token_id() - .unwrap() - == tid - ); + let tid2 = self + .child_at_bytes(root, &bytes) + .unwrap() + .token_id() + .unwrap(); + if tid != tid2 { + assert!(self.token_duplicates[&tid2].contains(&tid)); + } } } } @@ -468,7 +499,18 @@ impl TokTrie { logits.allow_token(self.special_token(tok)) } } - self.add_bias(r, logits) + self.add_bias(r, logits); + self.apply_duplicates(logits); + } + + pub fn apply_duplicates(&self, logits: &mut SimpleVob) { + for (tok, dups) in &self.token_duplicates { + if logits.is_allowed(*tok) { + for &dup in dups { + logits.allow_token(dup); + } + } + } } pub fn append_tokens(&self, r: &mut impl Recognizer, ts: &[TokenId]) { @@ -575,7 +617,9 @@ impl TrieHash { } fn insert(&mut self, word: &[u8], token_id: u32) { if word.len() == 0 { - assert!(self.token_id == NO_TOKEN); + // Some tokenizers have duplicate tokens... + // we just override + // assert!(self.token_id == NO_TOKEN); self.token_id = token_id; } else { if self.children.len() == 0x100 {