
Commit: chat mode support
mmoskal committed Aug 6, 2024
1 parent 7550e79 commit 022b496
Showing 2 changed files with 62 additions and 16 deletions.
56 changes: 53 additions & 3 deletions core/src/toktree.rs
@@ -16,9 +16,50 @@ pub type TokenId = u32;

#[derive(Clone, Copy, PartialEq, Eq, Debug, Zeroable, Pod)]
#[repr(C)]
pub struct BinTokRxInfo {
    pub vocab_size: u32,
    pub tok_eos: TokenId,
}

#[derive(Clone, Copy, PartialEq, Eq, Debug)]
pub struct TokRxInfo {
    pub vocab_size: u32,
    pub tok_eos: TokenId,
    pub tok_bos: Option<TokenId>,
    pub tok_pad: Option<TokenId>,
    pub tok_unk: Option<TokenId>,
    pub tok_end_of_turn: Option<TokenId>,
}

impl TokRxInfo {
    pub fn new(vocab_size: u32, tok_eos: TokenId) -> Self {
        TokRxInfo {
            vocab_size,
            tok_eos,
            tok_bos: None,
            tok_pad: None,
            tok_unk: None,
            tok_end_of_turn: None,
        }
    }

    pub fn from_bin(info: &BinTokRxInfo) -> Self {
        TokRxInfo {
            vocab_size: info.vocab_size,
            tok_eos: info.tok_eos,
            tok_bos: None,
            tok_pad: None,
            tok_unk: None,
            tok_end_of_turn: None,
        }
    }

    pub fn to_bin(&self) -> BinTokRxInfo {
        BinTokRxInfo {
            vocab_size: self.vocab_size,
            tok_eos: self.tok_eos,
        }
    }
}
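
The point of the split is that BinTokRxInfo stays Pod/repr(C) so it can be copied byte-for-byte into the serialized TokTrieHeader, while the richer TokRxInfo carries the new optional token ids in memory. One consequence worth noting: the optional ids do not survive a trip through the binary header. A minimal round-trip sketch (the structs are from this commit; the token ids and the main wrapper are illustrative only):

fn main() {
    // Constructor added in this commit: optional ids default to None.
    let mut info = TokRxInfo::new(32_000, 2);
    info.tok_end_of_turn = Some(32_007); // hypothetical <|eot_id|> id

    // Only the two Pod fields survive serialization...
    let bin = info.to_bin();
    assert_eq!(bin.vocab_size, 32_000);

    // ...so deserializing resets every optional id back to None.
    let restored = TokRxInfo::from_bin(&bin);
    assert_eq!(restored.tok_eos, 2);
    assert_eq!(restored.tok_end_of_turn, None);
}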

@@ -28,6 +69,7 @@
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum SpecialToken {
    Separator,
    BeginningOfSentence,
    EndOfSentence,
+   EndOfTurn,
}

pub trait Recognizer {
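
EndOfTurn joins the other sentence-level markers, and together with the new TokRxInfo::tok_end_of_turn field it lets a recognizer treat end-of-turn like any other special token. A hypothetical helper (not part of this commit) showing how the two additions fit together:

// Hypothetical helper: resolve a SpecialToken to a concrete token id
// using the extended TokRxInfo. Not part of this commit.
fn special_token_id(info: &TokRxInfo, st: SpecialToken) -> Option<TokenId> {
    match st {
        SpecialToken::EndOfSentence => Some(info.tok_eos),
        SpecialToken::BeginningOfSentence => info.tok_bos,
        SpecialToken::EndOfTurn => info.tok_end_of_turn,
        _ => None, // Separator and other variants carry no id in TokRxInfo
    }
}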
@@ -93,7 +135,7 @@ pub struct TokTrieHeader {
    trie_bytes: u32,
    token_offset_bytes: u32,
    token_data_bytes: u32,
-   info: TokRxInfo,
+   info: BinTokRxInfo,
    align: [u32; 0],
}

@@ -178,6 +220,14 @@ impl TokTrie {
        r
    }

+   pub fn build_chat_mode_trie(&self) -> Self {
+       let mut r = self.clone();
+       if let Some(t) = self.info.tok_end_of_turn {
+           r.info.tok_eos = t;
+       }
+       r
+   }

    fn finalize_ctor(&mut self) {
        for tok_id in 0..self.info.vocab_size {
            let bytes = self.token(tok_id);
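
build_chat_mode_trie is deliberately cheap: it clones the trie and, when the tokenizer defines an end-of-turn token, remaps EOS to it, so in chat mode the engine stops at the end of an assistant turn rather than at end-of-text. If no end-of-turn token is known, the clone behaves exactly like the original. A usage sketch (assuming an existing TokTrie value and that the crate exposes its TokRxInfo through an info() accessor):

// Sketch only, not part of this commit.
fn demo(trie: &TokTrie) {
    let chat_trie = trie.build_chat_mode_trie();
    // EOS is remapped to end-of-turn when one is known; otherwise the
    // chat-mode trie behaves exactly like the original.
    let expected = trie.info().tok_end_of_turn.unwrap_or(trie.info().tok_eos);
    assert_eq!(chat_trie.info().tok_eos, expected);
}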
@@ -447,7 +497,7 @@ impl TokTrie {
        let token_data = vec_from_bytes(&bytes[offsets_end..]);

        let mut r = TokTrie {
-           info: hd.info,
+           info: TokRxInfo::from_bin(&hd.info),
            token_offsets,
            token_data,
            nodes,
@@ -497,7 +547,7 @@ impl TokTrie {
            trie_bytes: trie_data.len() as u32,
            token_offset_bytes: token_offsets.len() as u32,
            token_data_bytes: trie_data.len() as u32,
-           info: self.info.clone(),
+           info: self.info.to_bin(),
            align: [],
        };

22 changes: 9 additions & 13 deletions hf_tokenizers/src/lib.rs
@@ -1,16 +1,13 @@
use anyhow::{anyhow, bail, Result};
use rustc_hash::FxHashMap;
use serde::{Deserialize, Serialize};
use std::{collections::BTreeMap, sync::Arc};
use tokenizers::{normalizers::Sequence, FromPretrainedParameters, NormalizerWrapper, Tokenizer};
use toktrie::{TokEnv, TokRxInfo, TokTrie, TokenId, TokenizerEnv};

#[derive(Serialize, Deserialize)]
pub struct ByteTokenizer {
    pub hf_model: String,
    pub hf_tokenizer: Tokenizer,
-   pub eos_token: u32,
-   pub vocab_size: u32,
+   info: TokRxInfo,
    token_bytes: Vec<Vec<u8>>,
    pub special: BTreeMap<String, u32>,
}
@@ -129,8 +126,7 @@ impl ByteTokenizer {

        let mut res = ByteTokenizer {
            hf_model: "foobar".to_string(),
-           eos_token: 0,
-           vocab_size,
+           info: TokRxInfo::new(vocab_size, 0),
            special: BTreeMap::new(),
            token_bytes: (0..vocab_size).map(|_| Vec::new()).collect(),
            hf_tokenizer: hft,
@@ -139,7 +135,10 @@
        for (id, info) in added.iter() {
            if info.special {
                match info.content.as_str() {
-                   "</s>" | "<|endoftext|>" | "<|end_of_text|>" => res.eos_token = *id,
+                   "</s>" | "<|endoftext|>" | "<|end_of_text|>" => res.info.tok_eos = *id,
+                   "<|end|>" | "<|eot_id|>" => res.info.tok_end_of_turn = Some(*id),
+                   "<unk>" | "<|unk|>" => res.info.tok_unk = Some(*id),
+                   "<pad>" | "<|pad|>" => res.info.tok_pad = Some(*id),
                    _ => {}
                }
                res.special.insert(info.content.clone(), *id);
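
Detection here is purely string-based, so it only covers tokenizers that follow the common naming conventions: <|eot_id|> (Llama-3 style) and <|end|> (Phi style) mark end-of-turn, while </s>, <|endoftext|> and <|end_of_text|> remain plain EOS. A hypothetical check of what the loop above produces for a Llama-3-style tokenizer (assuming the crate's from_name constructor; the model name is illustrative):

// Hypothetical check, not part of this commit; assumes a from_name
// constructor and access to the model's tokenizer files.
fn check_llama3_special_tokens() -> anyhow::Result<()> {
    let tok = ByteTokenizer::from_name("meta-llama/Meta-Llama-3-8B-Instruct")?;
    let info = tok.tokrx_info();
    // <|end_of_text|> stays the default EOS, while <|eot_id|> should land
    // in the new tok_end_of_turn field.
    assert!(info.tok_end_of_turn.is_some());
    assert_ne!(Some(info.tok_eos), info.tok_end_of_turn);
    Ok(())
}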
@@ -198,24 +197,21 @@ impl ByteTokenizer {
    }

    pub fn tokrx_info(&self) -> TokRxInfo {
-       TokRxInfo {
-           vocab_size: self.vocab_size,
-           tok_eos: self.eos_token,
-       }
+       self.info.clone()
    }
    pub fn token_bytes(&self) -> Vec<Vec<u8>> {
        self.token_bytes.clone()
    }

    pub fn add_missing_tokens(&mut self, vocab_size: usize) {
-       assert!(self.vocab_size == self.token_bytes.len() as u32);
+       assert!(self.info.vocab_size == self.token_bytes.len() as u32);
        assert!(vocab_size >= self.token_bytes.len());
        assert!(vocab_size - self.token_bytes.len() <= 200);
        while self.token_bytes.len() < vocab_size {
            let idx = self.token_bytes.len();
            let name = format!("<AddedToken_{idx}>");
            self.token_bytes.push(name.as_bytes().to_vec());
-           self.vocab_size += 1;
+           self.info.vocab_size += 1;
            self.special.insert(name, idx as u32);
        }
    }
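
add_missing_tokens pads the byte-level vocabulary up to the model's embedding-table size, since some models reserve more embedding rows than the tokenizer defines; the <= 200 assert is a sanity cap so a wildly wrong target size fails fast. A hedged usage sketch (the model name and sizes are illustrative; from_name is assumed):

// Sketch only: pad the tokenizer vocab up to the model's embedding rows.
fn pad_to_model(vocab_rows: usize) -> anyhow::Result<()> {
    let mut tok = ByteTokenizer::from_name("meta-llama/Meta-Llama-3-8B")?;
    tok.add_missing_tokens(vocab_rows);
    assert_eq!(tok.tokrx_info().vocab_size, vocab_rows as u32);
    Ok(())
}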
