From 022b496ef5feebe4a0234df5ffa7c146067dc55d Mon Sep 17 00:00:00 2001
From: Michal Moskal
Date: Tue, 6 Aug 2024 17:20:59 +0000
Subject: [PATCH] chat mode support

---
 core/src/toktree.rs      | 56 +++++++++++++++++++++++++++++++++++++---
 hf_tokenizers/src/lib.rs | 22 +++++++---------
 2 files changed, 62 insertions(+), 16 deletions(-)

diff --git a/core/src/toktree.rs b/core/src/toktree.rs
index 5829c57..643d14d 100644
--- a/core/src/toktree.rs
+++ b/core/src/toktree.rs
@@ -16,9 +16,50 @@ pub type TokenId = u32;
 
 #[derive(Clone, Copy, PartialEq, Eq, Debug, Zeroable, Pod)]
 #[repr(C)]
+pub struct BinTokRxInfo {
+    pub vocab_size: u32,
+    pub tok_eos: TokenId,
+}
+
+#[derive(Clone, Copy, PartialEq, Eq, Debug)]
 pub struct TokRxInfo {
     pub vocab_size: u32,
     pub tok_eos: TokenId,
+    pub tok_bos: Option<TokenId>,
+    pub tok_pad: Option<TokenId>,
+    pub tok_unk: Option<TokenId>,
+    pub tok_end_of_turn: Option<TokenId>,
+}
+
+impl TokRxInfo {
+    pub fn new(vocab_size: u32, tok_eos: TokenId) -> Self {
+        TokRxInfo {
+            vocab_size,
+            tok_eos,
+            tok_bos: None,
+            tok_pad: None,
+            tok_unk: None,
+            tok_end_of_turn: None,
+        }
+    }
+
+    pub fn from_bin(info: &BinTokRxInfo) -> Self {
+        TokRxInfo {
+            vocab_size: info.vocab_size,
+            tok_eos: info.tok_eos,
+            tok_bos: None,
+            tok_pad: None,
+            tok_unk: None,
+            tok_end_of_turn: None,
+        }
+    }
+
+    pub fn to_bin(&self) -> BinTokRxInfo {
+        BinTokRxInfo {
+            vocab_size: self.vocab_size,
+            tok_eos: self.tok_eos,
+        }
+    }
 }
 
 #[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
@@ -28,6 +69,7 @@ pub enum SpecialToken {
     Separator,
     BeginningOfSentence,
     EndOfSentence,
+    EndOfTurn,
 }
 
 pub trait Recognizer {
@@ -93,7 +135,7 @@ pub struct TokTrieHeader {
     trie_bytes: u32,
     token_offset_bytes: u32,
     token_data_bytes: u32,
-    info: TokRxInfo,
+    info: BinTokRxInfo,
     align: [u32; 0],
 }
 
@@ -178,6 +220,14 @@ impl TokTrie {
         r
     }
 
+    pub fn build_chat_mode_trie(&self) -> Self {
+        let mut r = self.clone();
+        if let Some(t) = self.info.tok_end_of_turn {
+            r.info.tok_eos = t;
+        }
+        r
+    }
+
     fn finalize_ctor(&mut self) {
         for tok_id in 0..self.info.vocab_size {
             let bytes = self.token(tok_id);
@@ -447,7 +497,7 @@ impl TokTrie {
         let token_data = vec_from_bytes(&bytes[offsets_end..]);
 
         let mut r = TokTrie {
-            info: hd.info,
+            info: TokRxInfo::from_bin(&hd.info),
             token_offsets,
             token_data,
             nodes,
@@ -497,7 +547,7 @@ impl TokTrie {
             trie_bytes: trie_data.len() as u32,
             token_offset_bytes: token_offsets.len() as u32,
             token_data_bytes: trie_data.len() as u32,
-            info: self.info.clone(),
+            info: self.info.to_bin(),
             align: [],
         };
 
diff --git a/hf_tokenizers/src/lib.rs b/hf_tokenizers/src/lib.rs
index 781e78b..73c9066 100644
--- a/hf_tokenizers/src/lib.rs
+++ b/hf_tokenizers/src/lib.rs
@@ -1,16 +1,13 @@
 use anyhow::{anyhow, bail, Result};
 use rustc_hash::FxHashMap;
-use serde::{Deserialize, Serialize};
 use std::{collections::BTreeMap, sync::Arc};
 use tokenizers::{normalizers::Sequence, FromPretrainedParameters, NormalizerWrapper, Tokenizer};
 use toktrie::{TokEnv, TokRxInfo, TokTrie, TokenId, TokenizerEnv};
 
-#[derive(Serialize, Deserialize)]
 pub struct ByteTokenizer {
     pub hf_model: String,
     pub hf_tokenizer: Tokenizer,
-    pub eos_token: u32,
-    pub vocab_size: u32,
+    info: TokRxInfo,
     token_bytes: Vec<Vec<u8>>,
     pub special: BTreeMap<String, u32>,
 }
@@ -129,8 +126,7 @@ impl ByteTokenizer {
         let mut res = ByteTokenizer {
             hf_model: "foobar".to_string(),
-            eos_token: 0,
-            vocab_size,
+            info: TokRxInfo::new(vocab_size, 0),
             special: BTreeMap::new(),
             token_bytes: (0..vocab_size).map(|_| Vec::new()).collect(),
             hf_tokenizer: hft,
         };
@@ -139,7 +135,10 @@ impl ByteTokenizer {
         for (id, info) in added.iter() {
             if info.special {
                 match info.content.as_str() {
-                    "</s>" | "<|endoftext|>" | "<|end_of_text|>" => res.eos_token = *id,
+                    "</s>" | "<|endoftext|>" | "<|end_of_text|>" => res.info.tok_eos = *id,
+                    "<|end|>" | "<|eot_id|>" => res.info.tok_end_of_turn = Some(*id),
+                    "<unk>" | "<|unk|>" => res.info.tok_unk = Some(*id),
+                    "<pad>" | "<|pad|>" => res.info.tok_pad = Some(*id),
                     _ => {}
                 }
                 res.special.insert(info.content.clone(), *id);
@@ -198,24 +197,21 @@ impl ByteTokenizer {
     }
 
     pub fn tokrx_info(&self) -> TokRxInfo {
-        TokRxInfo {
-            vocab_size: self.vocab_size,
-            tok_eos: self.eos_token,
-        }
+        self.info.clone()
     }
 
     pub fn token_bytes(&self) -> Vec<Vec<u8>> {
         self.token_bytes.clone()
     }
 
     pub fn add_missing_tokens(&mut self, vocab_size: usize) {
-        assert!(self.vocab_size == self.token_bytes.len() as u32);
+        assert!(self.info.vocab_size == self.token_bytes.len() as u32);
         assert!(vocab_size >= self.token_bytes.len());
         assert!(vocab_size - self.token_bytes.len() <= 200);
         while self.token_bytes.len() < vocab_size {
             let idx = self.token_bytes.len();
             let name = format!("<AddedToken_{idx}>");
             self.token_bytes.push(name.as_bytes().to_vec());
-            self.vocab_size += 1;
+            self.info.vocab_size += 1;
             self.special.insert(name, idx as u32);
         }
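-- 
Notes for reviewers (not part of the patch):

TokRxInfo is split from BinTokRxInfo because TokTrieHeader is read and
written as raw bytes, so it must keep the fixed #[repr(C)] layout that the
Zeroable/Pod derives require; the new Option<TokenId> fields would break
that. A consequence is that a to_bin()/from_bin() round-trip drops the
optional fields. A minimal sketch of that behavior (the token IDs are made
up):

    use toktrie::TokRxInfo;

    fn bin_roundtrip() {
        let mut info = TokRxInfo::new(32000, 2);
        info.tok_end_of_turn = Some(4);

        let bin = info.to_bin();              // fixed-layout form stored in TokTrieHeader
        let back = TokRxInfo::from_bin(&bin); // vocab_size and tok_eos survive
        assert_eq!(back.tok_eos, info.tok_eos);
        assert_eq!(back.tok_end_of_turn, None); // optional fields reset to None
    }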
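The chat-mode entry point is build_chat_mode_trie(): it clones the trie
and, when an end-of-turn token was detected (e.g. <|eot_id|> or <|end|>),
promotes it to tok_eos, so a chat session stops at the end of a turn
rather than at end-of-text. A usage sketch, assuming a trie built through
the crate's existing TokTrie::from(info, words) constructor (token bytes
and IDs here are illustrative only):

    use toktrie::{TokRxInfo, TokTrie};

    fn chat_trie(token_bytes: Vec<Vec<u8>>) -> TokTrie {
        // Completion mode: stop on <|end_of_text|> (id 2, made up).
        let mut info = TokRxInfo::new(token_bytes.len() as u32, 2);
        // Chat templates end each turn with <|eot_id|> (id 4, made up).
        info.tok_end_of_turn = Some(4);

        let trie = TokTrie::from(&info, &token_bytes);
        // Clone with tok_eos swapped to the end-of-turn token; if none
        // was detected, the clone is returned unchanged.
        trie.build_chat_mode_trie()
    }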