add sample for C integration
mmoskal committed Oct 22, 2024
1 parent 10abb44 commit 1ac677b
Showing 5 changed files with 231 additions and 11 deletions.
11 changes: 11 additions & 0 deletions c_sample/Makefile
@@ -0,0 +1,11 @@
ifeq ($(wildcard ../../target),)
TARGET = ../target/release
else
TARGET = ../../target/release
endif

all:
	cd ../parser && cargo build --release
	c++ -W -Wall -std=c++20 -o $(TARGET)/c_sample c_sample.cpp -I../parser -L$(TARGET) -lllguidance_parser
	$(TARGET)/c_sample ../sample_parser/data/blog.schema.ll.json ../sample_parser/data/blog.sample.json

22 changes: 22 additions & 0 deletions c_sample/README.md
@@ -0,0 +1,22 @@
# llguidance C++ sample

This is a simple example of how to use the llguidance library in C++.

It reads a Guidance grammar from a JSON file, along with the text that we
pretend the LLM has generated, and then checks that the text conforms to the
grammar.

For a real integration (a minimal sketch follows this list):

- replace `bogus_tokenize()` with a real tokenizer for your LLM
- make sure you pass the list of tokens to `create_tokenizer()`
- for an incoming request, create a constraint based on data in the
request; make sure to handle errors returned by `llg_get_error()`
- while computing logits, run `llg_compute_mask()`
- sample with the returned mask
- pass the sampled token to `llg_commit_token()`
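
A minimal sketch of that loop, assuming a `tokenizer` built as in
`c_sample.cpp`; `compute_logits()` and `sample_with_mask()` are placeholders
for your inference engine and are not part of llguidance:

```cpp
#include "llguidance.h"

extern LlgTokenizer *tokenizer;          // shared between requests
extern float *compute_logits();          // placeholder: your LLM forward pass
extern uint32_t sample_with_mask(const float *logits, const uint32_t *mask,
                                 float temperature); // placeholder sampler

bool run_request(const char *grammar_json) {
  LlgConstraintInit init;
  llg_constraint_init_set_defaults(&init, tokenizer);
  LlgConstraint *c = llg_new_constraint(&init, grammar_json);
  if (llg_get_error(c)) {
    // report llg_get_error(c) back to the client
    llg_free_constraint(c);
    return false;
  }

  LlgMaskResult mask_res;
  for (;;) {
    // can run in parallel with the logit computation
    if (llg_compute_mask(c, &mask_res) != 0) {
      break;
    }
    if (mask_res.is_stop) {
      break; // the constraint forced EOS
    }
    uint32_t token = sample_with_mask(compute_logits(), mask_res.sample_mask,
                                      mask_res.temperature);
    LlgCommitResult commit_res;
    if (llg_commit_token(c, token, &commit_res) != 0) {
      break;
    }
    // with ff_tokens enabled, commit_res may contain extra fast-forward tokens
  }

  bool ok = llg_get_error(c) == nullptr;
  llg_free_constraint(c);
  return ok;
}
```

The tokenizer and the constraint are separate objects on purpose: the
tokenizer is shared across requests, while each request gets its own
constraint.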

## TODO

- [ ] extend to read JSON schema
- [ ] extend to allow simple regex as constraint
152 changes: 152 additions & 0 deletions c_sample/c_sample.cpp
@@ -0,0 +1,152 @@
#include <cstdio>
#include <cstdint>
#include <cstring>
#include <cstdlib>
#include <algorithm>
#include <vector>
#include <fstream>
#include <sstream>
#include <string>
#include <cassert>

#include "llguidance.h"

// Create an LlgTokenizer; tokens[token_id] is the byte sequence corresponding
// to the given token_id; see below for tokenize_fn
LlgTokenizer *create_tokenizer(std::vector<std::vector<uint8_t>> &tokens,
                               uint32_t tok_eos, LlgTokenizeFn tokenize_fn,
                               const void *tokenize_user_data) {
  auto token_lens = new uint32_t[tokens.size()];
  size_t total_size = 0;
  for (size_t i = 0; i < tokens.size(); i++) {
    token_lens[i] = tokens[i].size();
    total_size += token_lens[i];
  }
  auto token_bytes = new uint8_t[total_size];
  size_t offset = 0;
  for (size_t i = 0; i < tokens.size(); i++) {
    memcpy(token_bytes + offset, tokens[i].data(), token_lens[i]);
    offset += token_lens[i];
  }
  LlgTokenizerInit tok_init = {
      .vocab_size = (uint32_t)tokens.size(),
      .tok_eos = tok_eos,
      .token_lens = token_lens,
      .token_bytes = token_bytes,
      .tokenize_assumes_string = false,
      .tokenize_fn = tokenize_fn,
      .tokenize_user_data = tokenize_user_data,
  };
  return llg_new_tokenizer(&tok_init);
}

// This function assumes that each byte is a single token.
// You will want to replace this with a real tokenizer; it has to be thread-safe!
std::vector<uint32_t> bogus_tokenize(const uint8_t *bytes_ptr, size_t nbytes) {
  std::vector<uint32_t> token_ids;
  for (size_t i = 0; i < nbytes; i++) {
    token_ids.push_back(bytes_ptr[i]);
  }
  return token_ids;
}

// This wraps the C++-style bogus_tokenize() in the shape llg expects.
size_t tokenize_callback(const void *user_data, const uint8_t *bytes,
                         size_t bytes_len, uint32_t *output_tokens,
                         size_t output_tokens_len) {
  (void)user_data;
  auto tokens = bogus_tokenize(bytes, bytes_len);
  if (output_tokens_len > 0) {
    memcpy(output_tokens, tokens.data(),
           std::min(output_tokens_len, tokens.size()) * sizeof(uint32_t));
  }
  return tokens.size();
}

// This creates a tokenizer that treats each byte as a token.
LlgTokenizer *create_byte_tokenizer(void) {
  std::vector<std::vector<uint8_t>> tokens;
  // every byte is a token
  for (size_t i = 0; i < 256; i++) {
    tokens.push_back({(uint8_t)i});
  }
  const char *eos = "<EOS>";
  tokens.push_back(std::vector<uint8_t>(eos, eos + strlen(eos)));
  return create_tokenizer(tokens, tokens.size() - 1, tokenize_callback,
                          nullptr);
}

std::string read_file(const std::string &filePath) {
  std::ifstream file(filePath);
  std::stringstream buffer;
  buffer << file.rdbuf();
  return buffer.str();
}

void fail_constraint(LlgConstraint *c) {
  printf("Error: %s\n", llg_get_error(c));
  llg_free_constraint(c);
  exit(1);
}

int main(int argc, const char *argv[]) {
  // the tokenizer can (and should) be shared between constraints
  LlgTokenizer *tokenizer = create_byte_tokenizer();

  if (argc != 3) {
    printf("Usage: %s <schema.ll.json> <sample.json>\n", argv[0]);
    return 1;
  }

  auto schema_json = read_file(argv[1]);
  auto sample_json = read_file(argv[2]);

  LlgConstraintInit init;
  llg_constraint_init_set_defaults(&init, tokenizer);
  init.log_stderr_level = 0; // the default is 1 (warnings only); 0 silences them

  LlgConstraint *c = llg_new_constraint(&init, schema_json.c_str());
  // this is a very common place where errors can happen - for example,
  // the schema may be invalid
  if (llg_get_error(c)) {
    fail_constraint(c);
  }

  // we assume our "LLM" will generate these tokens
  auto tokens =
      bogus_tokenize((const uint8_t *)sample_json.c_str(), sample_json.size());

  LlgMaskResult mask_res;
  for (size_t i = 0; i < tokens.size(); i++) {
    // compute mask - this can be done in parallel with logit generation
    if (llg_compute_mask(c, &mask_res) != 0) {
      fail_constraint(c);
    }

    // here, we would normally sample constrained to mask_res.sample_mask
    // using mask_res.temperature
    uint32_t token = tokens[i];

    // make sure the token is in the mask
    assert(mask_res.sample_mask[token / 32] & (1 << (token % 32)));

    // here we commit the token
    // if "ff_tokens" are enabled, this can return more than one token
    // to fast-forward
    LlgCommitResult commit_res;
    if (llg_commit_token(c, tokens[i], &commit_res) != 0) {
      fail_constraint(c);
    }

    // we didn't enable ff_tokens, so the exact token that we passed should be
    // returned
    assert(commit_res.n_tokens == 1);
    assert(commit_res.tokens[0] == token);
  }

  if (llg_compute_mask(c, &mask_res) != 0) {
    fail_constraint(c);
  }
  // we assume the constraint will force EOS at the end of the input
  assert(mask_res.is_stop);

  printf("OK!\n");
  return 0;
}
13 changes: 10 additions & 3 deletions parser/llguidance.h
@@ -112,8 +112,10 @@ typedef struct LlgCommitResult {
* Tokenization function
* Will not write more than output_tokens_len tokens (which can be 0)
* Returns the total number of tokens (which can be more than output_tokens_len)
* This function has to be thread-safe!
*/
typedef size_t (*LlgTokenizeFn)(const uint8_t *bytes,
typedef size_t (*LlgTokenizeFn)(const void *user_data,
const uint8_t *bytes,
size_t bytes_len,
uint32_t *output_tokens,
size_t output_tokens_len);
@@ -144,13 +146,17 @@ typedef struct LlgTokenizerInit {
*/
bool tokenize_assumes_string;
/**
* Tokenization function, see TokenizeFn docs.
* Tokenization function, see LlgTokenizeFn docs.
* It should only tokenize the bytes and not add
* any <BOS> etc. It should also work on any byte sequence, including
* invalid UTF-8. If this is not the case, set tokenize_assumes_string to true.
* Either way, this function has to be thread-safe!
*/
LlgTokenizeFn tokenize_fn;
/**
* User data to pass to the tokenize_fn
*/
const void *tokenize_user_data;
} LlgTokenizerInit;

#ifdef __cplusplus
@@ -163,7 +169,8 @@ extern "C" {
* and all logging to the buffer (get with llg_flush_logs()).
* You need to set the tokenizer field manually.
*/
void llg_constraint_init_set_defaults(struct LlgConstraintInit *init);
void llg_constraint_init_set_defaults(struct LlgConstraintInit *init,
const struct LlgTokenizer *tokenizer);

/**
* Create a new constraint from a grammar JSON string
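
The signature changes above mean existing integrations need two updates: the
tokenize callback now receives the `tokenize_user_data` pointer from
`LlgTokenizerInit` as its first argument, and `llg_constraint_init_set_defaults()`
now takes the tokenizer explicitly. A minimal sketch of the caller side, with
`MyTokenizer` and `my_tok_encode()` as hypothetical placeholders for a real
tokenizer, might look like:

```cpp
#include <cstddef>
#include <cstdint>
#include "llguidance.h"

struct MyTokenizer; // placeholder for your tokenizer state
size_t my_tok_encode(const MyTokenizer *t, const uint8_t *bytes, size_t len,
                     uint32_t *out, size_t out_len); // placeholder encoder

// New-style callback: user_data is whatever was stored in
// LlgTokenizerInit.tokenize_user_data; it must be safe to call from any thread.
size_t my_tokenize_fn(const void *user_data, const uint8_t *bytes,
                      size_t bytes_len, uint32_t *output_tokens,
                      size_t output_tokens_len) {
  auto *t = static_cast<const MyTokenizer *>(user_data);
  // must return the total token count even if only output_tokens_len fit
  return my_tok_encode(t, bytes, bytes_len, output_tokens, output_tokens_len);
}

void setup(LlgTokenizerInit &tok_init, const MyTokenizer *my_tok,
           LlgConstraintInit &cinit, const LlgTokenizer *llg_tok) {
  tok_init.tokenize_fn = my_tokenize_fn;
  tok_init.tokenize_user_data = my_tok;
  // the tokenizer is now passed when setting constraint defaults
  llg_constraint_init_set_defaults(&cinit, llg_tok);
}
```

Passing `NULL` as `tokenize_user_data` keeps the old stateless behaviour, which
is what the byte-level sample above does.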
44 changes: 36 additions & 8 deletions parser/src/ffi.rs
@@ -1,9 +1,9 @@
use std::{
ffi::{c_char, CStr},
ffi::{c_char, c_void, CStr},
sync::Arc,
};

use anyhow::Result;
use anyhow::{bail, Result};
use toktrie::{InferenceCapabilities, TokEnv, TokRxInfo, TokTrie, TokenizerEnv};

use crate::{
@@ -14,17 +14,32 @@ use crate::{
struct CTokenizerInner {
trie: TokTrie,
tokenize_fn: LlgTokenizeFn,
tokenize_user_data: *const c_void,
tokenize_assumes_string: bool,
}
unsafe impl Send for CTokenizerInner {}
unsafe impl Sync for CTokenizerInner {}

impl CTokenizerInner {
fn raw_tokenize(&self, s: &[u8]) -> Vec<toktrie::TokenId> {
let mut res_toks = vec![0; s.len() / 4 + 5];
let n_toks = (self.tokenize_fn)(s.as_ptr(), s.len(), res_toks.as_mut_ptr(), res_toks.len());
let n_toks = (self.tokenize_fn)(
self.tokenize_user_data,
s.as_ptr(),
s.len(),
res_toks.as_mut_ptr(),
res_toks.len(),
);

if n_toks > res_toks.len() {
res_toks.resize(n_toks, 0);
(self.tokenize_fn)(s.as_ptr(), s.len(), res_toks.as_mut_ptr(), res_toks.len());
(self.tokenize_fn)(
self.tokenize_user_data,
s.as_ptr(),
s.len(),
res_toks.as_mut_ptr(),
res_toks.len(),
);
}

res_toks.truncate(n_toks);
@@ -77,6 +92,7 @@ impl LlgTokenizer {
trie,
tokenize_assumes_string: init.tokenize_assumes_string,
tokenize_fn: init.tokenize_fn,
tokenize_user_data: init.tokenize_user_data,
}),
}
}
@@ -91,7 +107,9 @@ pub type LlgToken = u32;
/// Tokenization function
/// Will not write more than output_tokens_len tokens (which can be 0)
/// Returns the total number of tokens (which can be more than output_tokens_len)
/// This function has to be thread-safe!
pub type LlgTokenizeFn = extern "C" fn(
user_data: *const c_void,
bytes: *const u8,
bytes_len: usize,
output_tokens: *mut u32,
@@ -119,12 +137,15 @@ pub struct LlgTokenizerInit {
/// TODO: the <BOS> bit not implemented yet
pub tokenize_assumes_string: bool,

/// Tokenization function, see TokenizeFn docs.
/// Tokenization function, see LlgTokenizeFn docs.
/// It should only tokenize the bytes and not add
/// any <BOS> etc. It should also work on any byte sequence, including
/// invalid UTF-8. If this is not the case, set tokenize_assumes_string to true.
/// Either way, this function has to be thread-safe!
pub tokenize_fn: LlgTokenizeFn,

/// User data to pass to the tokenize_fn
pub tokenize_user_data: *const c_void,
}

#[repr(C)]
Expand Down Expand Up @@ -193,6 +214,10 @@ impl LlgCommitResult {
}

fn new_constraint(init: &LlgConstraintInit, grammar_json: *const c_char) -> Result<Constraint> {
if init.tokenizer.is_null() {
bail!("Tokenizer is null");
}

let grammar_json = unsafe { CStr::from_ptr(grammar_json) }
.to_str()
.map_err(|_| anyhow::anyhow!("Invalid UTF-8 in grammar_json"))?;
@@ -244,10 +269,13 @@ impl LlgConstraint {
/// and all logging to the buffer (get with llg_flush_logs()).
/// You need to set the tokenizer field manually.
#[no_mangle]
pub extern "C" fn llg_constraint_init_set_defaults(init: &mut LlgConstraintInit) {
pub extern "C" fn llg_constraint_init_set_defaults(
init: &mut LlgConstraintInit,
tokenizer: *const LlgTokenizer,
) {
*init = LlgConstraintInit {
tokenizer: std::ptr::null(),
log_buffer_level: 2,
tokenizer,
log_buffer_level: 0,
log_stderr_level: 1,
ff_tokens_ok: false,
backtrack_ok: false,
