Skip to content

Commit

Permalink
rename things to Llg*
Browse files Browse the repository at this point in the history
  • Loading branch information
mmoskal committed Sep 8, 2024
1 parent c434c96 commit 9cdd254
Show file tree
Hide file tree
Showing 2 changed files with 43 additions and 37 deletions.
1 change: 1 addition & 0 deletions parser/build.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ fn main() {
.with_config(config)
.with_include_guard("LLGUIDANCE_H")
.with_crate(crate_dir)
.rename_item("ParserLimits", "LlgParserLimits")
.generate()
.map_or_else(
|error| match error {
Expand Down
79 changes: 42 additions & 37 deletions parser/src/ffi.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ use crate::{

struct CTokenizerInner {
trie: TokTrie,
tokenize_fn: TokenizeFn,
tokenize_fn: LlgTokenizeFn,
tokenize_assumes_string: bool,
}

Expand Down Expand Up @@ -51,12 +51,12 @@ impl TokenizerEnv for CTokenizerInner {
}
}

pub struct CTokenizer {
pub struct LlgTokenizer {
token_env: TokEnv,
}

impl CTokenizer {
fn from_init(init: &TokenizerInit) -> Self {
impl LlgTokenizer {
fn from_init(init: &LlgTokenizerInit) -> Self {
let token_lens =
unsafe { std::slice::from_raw_parts(init.token_lens, init.vocab_size as usize) };
let total_len = token_lens.iter().sum::<u32>();
Expand All @@ -72,7 +72,7 @@ impl CTokenizer {
}
let trie = TokTrie::from(&TokRxInfo::new(init.vocab_size, init.tok_eos), &tokens);

CTokenizer {
LlgTokenizer {
token_env: Arc::new(CTokenizerInner {
trie,
tokenize_assumes_string: init.tokenize_assumes_string,
Expand All @@ -86,26 +86,26 @@ impl CTokenizer {
}
}

pub type TokenId = u32;
pub type LlgToken = u32;

/// Tokenization function
/// Will not write more than output_tokens_len tokens (which can be 0)
/// Returns the total number of tokens (which can be more than output_tokens_len)
pub type TokenizeFn = extern "C" fn(
pub type LlgTokenizeFn = extern "C" fn(
bytes: *const u8,
bytes_len: usize,
output_tokens: *mut u32,
output_tokens_len: usize,
) -> usize;

#[repr(C)]
pub struct TokenizerInit {
pub struct LlgTokenizerInit {
/// The number of tokens in the vocabulary
pub vocab_size: u32,

/// The token ID for the end of sentence token
/// For chat mode, set it to end-of-turn token
pub tok_eos: TokenId,
pub tok_eos: LlgToken,

/// An array of the lengths of the token strings (vocab_size elements)
pub token_lens: *const u32,
Expand All @@ -124,13 +124,13 @@ pub struct TokenizerInit {
/// any <BOS> etc. It should also work on any byte sequence, including
/// invalid UTF-8. If this is not the case, set tokenize_assumes_string to true.
/// Either way, this function has to be thread-safe!
pub tokenize_fn: TokenizeFn,
pub tokenize_fn: LlgTokenizeFn,
}

#[repr(C)]
pub struct ConstraintInit {
pub struct LlgConstraintInit {
/// The tokenizer to use, created with llg_new_tokenizer()
pub tokenizer: *const CTokenizer,
pub tokenizer: *const LlgTokenizer,
/// The log level for the buffer that is kept inside of the constraint
/// 0 - no logging, 1 - warnings only, 2 - info
pub log_buffer_level: u32,
Expand All @@ -147,14 +147,14 @@ pub struct ConstraintInit {
pub limits: ParserLimits,
}

pub struct CConstraint {
pub struct LlgConstraint {
local_error: Option<String>,
last_logs: String,
constraint: Option<Constraint>,
}

#[repr(C)]
pub struct CMaskResult {
pub struct LlgMaskResult {
/// One bit per vocab token
/// This is valid until any call to llg_*() on the current constraint
pub sample_mask: *const u32,
Expand All @@ -166,7 +166,7 @@ pub struct CMaskResult {

/// Represents result from llg_commit_token()
#[repr(C)]
pub struct CCommitResult {
pub struct LlgCommitResult {
/// The tokens to append to the output if any
/// This is valid until any call to llg_*() on the current constraint
pub tokens: *const u32,
Expand All @@ -176,7 +176,7 @@ pub struct CCommitResult {
pub is_stop: bool,
}

fn new_constraint(init: &ConstraintInit, grammar_json: *const c_char) -> Result<Constraint> {
fn new_constraint(init: &LlgConstraintInit, grammar_json: *const c_char) -> Result<Constraint> {
let grammar_json = unsafe { CStr::from_ptr(grammar_json) }
.to_str()
.map_err(|_| anyhow::anyhow!("Invalid UTF-8 in grammar_json"))?;
Expand All @@ -201,7 +201,7 @@ fn new_constraint(init: &ConstraintInit, grammar_json: *const c_char) -> Result<
Ok(Constraint::new(tok_parser))
}

impl CConstraint {
impl LlgConstraint {
fn get_error(&self) -> *const c_char {
match &self.local_error {
Some(e) => e.as_ptr() as *const c_char,
Expand All @@ -228,8 +228,8 @@ impl CConstraint {
/// and all logging to the buffer (get with llg_flush_logs()).
/// You need to set the tokenizer field manually.
#[no_mangle]
pub extern "C" fn llg_constraint_init_set_defaults(init: &mut ConstraintInit) {
*init = ConstraintInit {
pub extern "C" fn llg_constraint_init_set_defaults(init: &mut LlgConstraintInit) {
*init = LlgConstraintInit {
tokenizer: std::ptr::null(),
log_buffer_level: 2,
log_stderr_level: 1,
Expand All @@ -243,10 +243,10 @@ pub extern "C" fn llg_constraint_init_set_defaults(init: &mut ConstraintInit) {
/// Always returns a non-null value. Call llg_get_error() on the result to check for errors.
#[no_mangle]
pub extern "C" fn llg_new_constraint(
init: &ConstraintInit,
init: &LlgConstraintInit,
grammar_json: *const c_char,
) -> *mut CConstraint {
let mut res = CConstraint {
) -> *mut LlgConstraint {
let mut res = LlgConstraint {
local_error: None,
constraint: None,
last_logs: "\x00".to_string(),
Expand All @@ -264,7 +264,7 @@ pub extern "C" fn llg_new_constraint(
/// After it returns a non-null value, it will always return it until the constraint is freed
/// using llg_free_constraint() (at which point the pointer will be invalid).
#[no_mangle]
pub extern "C" fn llg_get_error(cc: &CConstraint) -> *const c_char {
pub extern "C" fn llg_get_error(cc: &LlgConstraint) -> *const c_char {
cc.get_error()
}

Expand All @@ -273,11 +273,11 @@ pub extern "C" fn llg_get_error(cc: &CConstraint) -> *const c_char {
/// Returns 0 on success and -1 on error (use llg_get_error() to get the exact error).
/// When 0 is returned, the result is written to *res_p.
#[no_mangle]
pub extern "C" fn llg_compute_mask(cc: &mut CConstraint, res_p: *mut CMaskResult) -> i32 {
pub extern "C" fn llg_compute_mask(cc: &mut LlgConstraint, res_p: *mut LlgMaskResult) -> i32 {
if let Some(constraint) = &mut cc.constraint {
match constraint.compute_mask() {
Ok(r) => {
let r = CMaskResult {
let r = LlgMaskResult {
sample_mask: r
.sample_mask
.as_ref()
Expand All @@ -299,27 +299,27 @@ pub extern "C" fn llg_compute_mask(cc: &mut CConstraint, res_p: *mut CMaskResult
/// When 0 is returned, the result is written to *res_p.
#[no_mangle]
pub extern "C" fn llg_commit_token(
cc: &mut CConstraint,
token: TokenId,
res_p: *mut CCommitResult,
cc: &mut LlgConstraint,
token: LlgToken,
res_p: *mut LlgCommitResult,
) -> i32 {
if let Some(constraint) = &mut cc.constraint {
let trie = constraint.parser.token_env.tok_trie();
let token = if token < trie.vocab_size() as TokenId {
let token = if token < trie.vocab_size() as LlgToken {
Some(token)
} else {
None
};
match constraint.commit_token(token) {
Ok(r) => {
let res = if let Some(s) = r.unconditional_splice() {
CCommitResult {
LlgCommitResult {
tokens: s.ff_tokens.as_ptr(),
n_tokens: s.ff_tokens.len() as u32,
is_stop: r.is_stop(),
}
} else {
CCommitResult {
LlgCommitResult {
tokens: std::ptr::null(),
n_tokens: 0,
is_stop: r.is_stop(),
Expand All @@ -335,22 +335,22 @@ pub extern "C" fn llg_commit_token(

/// Construct a new tokenizer from the given LlgTokenizerInit
#[no_mangle]
pub extern "C" fn llg_new_tokenizer(tok_init: &TokenizerInit) -> *mut CTokenizer {
let tok = CTokenizer::from_init(tok_init);
pub extern "C" fn llg_new_tokenizer(tok_init: &LlgTokenizerInit) -> *mut LlgTokenizer {
let tok = LlgTokenizer::from_init(tok_init);
Box::into_raw(Box::new(tok))
}

/// Free the tokenizer. Should *NOT* be called while there are still constraints using it.
#[no_mangle]
pub extern "C" fn llg_free_tokenizer(tok: *mut CTokenizer) {
pub extern "C" fn llg_free_tokenizer(tok: *mut LlgTokenizer) {
unsafe {
drop(Box::from_raw(tok));
}
}

/// Free the constraint
#[no_mangle]
pub extern "C" fn llg_free_constraint(cc: *mut CConstraint) {
pub extern "C" fn llg_free_constraint(cc: *mut LlgConstraint) {
unsafe {
drop(Box::from_raw(cc));
}
Expand All @@ -361,9 +361,14 @@ pub extern "C" fn llg_free_constraint(cc: *mut CConstraint) {
/// The logs are kept in the constraint until the next call to this function
/// or until the constraint is freed.
#[no_mangle]
pub extern "C" fn llg_flush_logs(cc: &mut CConstraint) -> *const c_char {
pub extern "C" fn llg_flush_logs(cc: &mut LlgConstraint) -> *const c_char {
if let Some(constraint) = &mut cc.constraint {
cc.last_logs = constraint.flush_logs();
let s = constraint.flush_logs();
if s.contains('\0') {
cc.last_logs = s.replace('\0', "\\0");
} else {
cc.last_logs = s;
}
cc.last_logs.push('\0');
}
cc.last_logs.as_ptr() as *const c_char
Expand Down

0 comments on commit 9cdd254

Please sign in to comment.