
Commit

account for duplicate tokens; see guidance-ai#78
mmoskal committed Mar 18, 2024
1 parent 814342f commit 1c4d60b
Showing 2 changed files with 79 additions and 25 deletions.
10 changes: 10 additions & 0 deletions controllers/aici_abi/src/svob.rs
@@ -41,6 +41,16 @@ impl SimpleVob {
self.data.iter().map(|x| x.count_ones() as usize).sum()
}

pub fn negated(&self, size: usize) -> Self {
let mut r = Self::new();
r.data = self.data.iter().map(|x| !x).collect();
for i in size..r.len() {
// disallow tokens that are out of range
r.disallow_token(i as TokenId);
}
r
}

pub unsafe fn as_ptr(&self) -> *const u32 {
self.data.as_ptr()
}
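Not part of the commit — a self-contained sketch of what the new `negated(size)` helper computes, assuming the same packed 32-bit word layout that `SimpleVob` uses: flip every bit, then clear everything at or past `size` so out-of-range token ids remain disallowed.

// Illustrative sketch only; the real type lives in svob.rs.
fn negated(data: &[u32], size: usize) -> Vec<u32> {
    let mut r: Vec<u32> = data.iter().map(|x| !x).collect();
    for i in size..r.len() * 32 {
        r[i / 32] &= !(1u32 << (i % 32)); // disallow tokens that are out of range
    }
    r
}

fn main() {
    let allowed = vec![0b1010u32]; // tokens 1 and 3 allowed, vocab size 4
    let neg = negated(&allowed, 4);
    assert_eq!(neg[0], 0b0101); // tokens 0 and 2; nothing past the vocab size
}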
94 changes: 69 additions & 25 deletions controllers/aici_abi/src/toktree.rs
@@ -1,6 +1,8 @@
// use 8:24 encoding - num_ch:tok_id (ch_byte:ch_off)* - 8 bytes per tree node
// special case num_ch=0xff -> num_ch=0x100

use rustc_hash::FxHashMap;

use crate::{
bytes::{
box_from_bytes, clone_as_bytes, clone_vec_as_bytes, vec_from_bytes, TokRxInfo, TokenId,
@@ -53,6 +55,7 @@ pub struct TokTrie {
token_data: Vec<u8>,
nodes: Vec<TrieNode>,
max_token_len: usize,
token_duplicates: FxHashMap<TokenId, Vec<TokenId>>,
}

#[repr(C)]
@@ -126,7 +129,7 @@ impl TokTrie {
assert!(info.vocab_size == words.len() as u32);
for (idx, word) in words.iter().enumerate() {
if word.len() > 0 {
trie.insert(word, idx as u32)
trie.insert(word, idx as u32);
}
assert!(word.len() < 0xff);
let desc = (word.len() as u32) | ((token_data.len() as u32) << 8);
@@ -141,15 +144,27 @@
token_data,
nodes,
max_token_len: 0,
token_duplicates: FxHashMap::default(),
};
r.max_token_len = (0..info.vocab_size)
.map(|idx| r.token(idx).len())
.max()
.unwrap();
r.validate();
r.finalize_ctor();
r
}

fn finalize_ctor(&mut self) {
for tok_id in 0..self.info.vocab_size {
let bytes = self.token(tok_id);
let tok_ids = self.greedy_tokenize(bytes);
self.max_token_len = std::cmp::max(self.max_token_len, bytes.len());
if tok_ids.len() == 1 && tok_ids[0] != tok_id {
self.token_duplicates
.entry(tok_ids[0])
.or_insert_with(Vec::new)
.push(tok_id);
}
}
self.validate();
}

fn node_offset(&self, n: &TrieNode) -> usize {
let off = unsafe { (n as *const TrieNode).offset_from(self.root() as *const TrieNode) };
assert!(off >= 0);
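For orientation, a toy model (not the commit's code; `greedy_tokenize`, the vocabulary, and the helper below are stand-ins) of the duplicate scan that `finalize_ctor` performs above: re-tokenize each token's bytes, and whenever that yields a single but different token id, record the original id as a duplicate of the canonical one. The real code keeps the map in an `FxHashMap`; a std `HashMap` is used here.

use std::collections::HashMap;

fn build_duplicates(
    vocab: &[Vec<u8>],
    greedy_tokenize: impl Fn(&[u8]) -> Vec<u32>,
) -> HashMap<u32, Vec<u32>> {
    let mut dups: HashMap<u32, Vec<u32>> = HashMap::new();
    for (tok_id, bytes) in vocab.iter().enumerate() {
        let tok_id = tok_id as u32;
        let ids = greedy_tokenize(bytes);
        // A single, different id means these bytes resolve to another
        // (canonical) token; remember this one as its duplicate.
        if ids.len() == 1 && ids[0] != tok_id {
            dups.entry(ids[0]).or_default().push(tok_id);
        }
    }
    dups
}

fn main() {
    // Tokens 1 and 3 share the same bytes; the greedy tokenizer picks id 1.
    let vocab: Vec<Vec<u8>> = vec![b"a".to_vec(), b"dup".to_vec(), b"c".to_vec(), b"dup".to_vec()];
    let greedy = |bytes: &[u8]| vec![vocab.iter().position(|w| w == bytes).unwrap() as u32];
    let dups = build_duplicates(&vocab, greedy);
    assert_eq!(dups[&1], vec![3]);
}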
@@ -184,11 +199,14 @@ impl TokTrie {
}

pub fn token_set_dbg(&self, ts: &SimpleVob) -> String {
let num_set = ts.num_set();
let ts_neg = ts.negated(self.vocab_size());
let use_neg = ts_neg.num_set() * 20 < ts.num_set();
let ts1 = if use_neg { &ts_neg } else { &ts };
let num_set = ts1.num_set();
let max_tok = std::cmp::min(100, num_set);
let mut token_names = Vec::new();
for idx in 0..self.vocab_size() {
if ts.is_allowed(idx as TokenId) {
if ts1.is_allowed(idx as TokenId) {
token_names.push(self.token_dbg(idx as TokenId));
if token_names.len() >= max_tok {
break;
@@ -199,9 +217,10 @@
token_names.push("...".to_string());
}
format!(
"TokenSet: {}/{}; {}",
num_set,
"TokenSet: {}/{}; {}{}",
ts.num_set(),
self.vocab_size(),
if use_neg { "ALL EXCEPT " } else { "" },
token_names.join(", ")
)
}
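As a side note (illustrative only, names made up): the debug printer above switches to an "ALL EXCEPT" rendering when the complement of the set is much smaller than the set itself, roughly:

// Print the complement only when it is at least ~20x smaller.
fn use_negated_form(num_allowed: usize, num_disallowed: usize) -> bool {
    num_disallowed * 20 < num_allowed
}

fn main() {
    assert!(use_negated_form(32000, 5)); // "ALL EXCEPT <5 tokens>"
    assert!(!use_negated_form(100, 50)); // list the allowed tokens directly
}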
@@ -243,7 +262,21 @@ impl TokTrie {
format!("OOB[{}]", idx)
} else {
// format!("{:?}[{}]", self.token_str(idx), idx)
format!("{:?}", self.token_str(idx))
let s = self.token_str(idx);
if s.len() == 0 {
format!("EMPTY[{}]", idx)
} else if !s.contains('\u{fffd}') {
format!("{:?}", s)
} else {
let bytes = self.token(idx);
format!(
"HEX[{}]",
bytes
.iter()
.map(|b| format!("{:02x}", b))
.collect::<String>(),
)
}
}
}
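A standalone sketch of the fallback formatting added above (the helper name is made up; the real logic lives inside `token_dbg`): empty tokens print as `EMPTY[id]`, and tokens whose bytes are not valid UTF-8 (the lossy string contains U+FFFD) print as hex.

// Illustrative helper, not the trie's API.
fn dbg_token(idx: u32, bytes: &[u8]) -> String {
    let s = String::from_utf8_lossy(bytes);
    if bytes.is_empty() {
        format!("EMPTY[{}]", idx)
    } else if !s.contains('\u{fffd}') {
        format!("{:?}", s)
    } else {
        // Not valid UTF-8: print the raw bytes as hex.
        let hex: String = bytes.iter().map(|b| format!("{:02x}", b)).collect();
        format!("HEX[{}]", hex)
    }
}

fn main() {
    assert_eq!(dbg_token(0, b""), "EMPTY[0]");
    assert_eq!(dbg_token(1, b"hi"), "\"hi\"");
    assert_eq!(dbg_token(2, &[0xff]), "HEX[ff]");
}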

@@ -351,12 +384,9 @@ impl TokTrie {
token_data,
nodes,
max_token_len: 0,
token_duplicates: FxHashMap::default(),
};
r.validate();
r.max_token_len = (0..r.info.vocab_size)
.map(|idx| r.token(idx).len())
.max()
.unwrap();
r.finalize_ctor();
r
}

@@ -422,13 +452,14 @@ impl TokTrie {
assert!(bytes == self.token(tid));
let root = self.root();
if bytes.len() > 0 {
assert!(
self.child_at_bytes(root, &bytes)
.unwrap()
.token_id()
.unwrap()
== tid
);
let tid2 = self
.child_at_bytes(root, &bytes)
.unwrap()
.token_id()
.unwrap();
if tid != tid2 {
assert!(self.token_duplicates[&tid2].contains(&tid));
}
}
}
}
@@ -468,7 +499,18 @@ impl TokTrie {
logits.allow_token(self.special_token(tok))
}
}
self.add_bias(r, logits)
self.add_bias(r, logits);
self.apply_duplicates(logits);
}

pub fn apply_duplicates(&self, logits: &mut SimpleVob) {
for (tok, dups) in &self.token_duplicates {
if logits.is_allowed(*tok) {
for &dup in dups {
logits.allow_token(dup);
}
}
}
}

pub fn append_tokens(&self, r: &mut impl Recognizer, ts: &[TokenId]) {
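A minimal standalone model of the new `apply_duplicates` pass (std `HashSet`/`HashMap` stand in for `SimpleVob` and `FxHashMap`; not the commit's API): after `add_bias` computes the allowed set, every duplicate of an allowed canonical token is allowed as well, so ids that decode to the same bytes are never masked out merely because the trie maps them all to one node.

use std::collections::{HashMap, HashSet};

fn apply_duplicates(allowed: &mut HashSet<u32>, duplicates: &HashMap<u32, Vec<u32>>) {
    for (tok, dups) in duplicates {
        // If the canonical token is allowed, allow all of its duplicates too.
        if allowed.contains(tok) {
            allowed.extend(dups.iter().copied());
        }
    }
}

fn main() {
    let mut allowed: HashSet<u32> = [1].into_iter().collect();
    let duplicates: HashMap<u32, Vec<u32>> = [(1, vec![3, 9])].into_iter().collect();
    apply_duplicates(&mut allowed, &duplicates);
    assert!(allowed.contains(&3) && allowed.contains(&9));
}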
@@ -575,7 +617,9 @@ impl TrieHash {
}
fn insert(&mut self, word: &[u8], token_id: u32) {
if word.len() == 0 {
assert!(self.token_id == NO_TOKEN);
// Some tokenizers have duplicate tokens...
// we just override
// assert!(self.token_id == NO_TOKEN);
self.token_id = token_id;
} else {
if self.children.len() == 0x100 {
