Import toktrie with history https://github.com/microsoft/toktrie/tree…

mmoskal committed Nov 30, 2024
2 parents 8727abd + eae9b7f commit eb61519
Showing 14 changed files with 4,241 additions and 0 deletions.
26 changes: 26 additions & 0 deletions docs/special_tokens.md
# Support for special tokens

Tokenizers typically include special tokens, such as
`<|end_of_text|>`, `<|eot_id|>`, `<|python_tag|>`, `<|start_header_id|>`, etc.
This library is tasked with translating between byte sequences
and tokens.
If the bytes `<|eot_id|>` appear in the input, you may or may not want to treat them
as a special token.

The library assumes that by default you want to treat them as bytes
(so they would be tokenized as `<|`, `eot`, `_`, `id`, `|>` or similar).
To indicate that you want to treat them as a special token, you need to
prefix them with the "marker" byte 0xFF (255) (`TokTrie::SPECIAL_TOKEN_MARKER`).

Byte 0xFF is chosen as the marker because it is not a valid UTF-8 byte, so it should not normally
occur in regular inputs.
In Rust, you cannot have byte 0xFF in a `&str`, only in a `&[u8]`.
In Python, note the difference between `b"\xFF"` (the single byte 0xFF) and `"\xFF".encode("utf-8")`
(or equivalently `"\u00FF".encode("utf-8")`), which is the two-byte sequence `b"\xC3\xBF"`.
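
The same distinction can be illustrated in Rust (a minimal, self-contained sketch):

```rust
fn main() {
    // The marker is the single raw byte 0xFF; it can live in a &[u8] but not a &str.
    let marker: &[u8] = &[0xFF];
    // The Unicode code point U+00FF encodes to two bytes in UTF-8:
    let encoded: &[u8] = "\u{00FF}".as_bytes();
    assert_eq!(encoded, &[0xC3u8, 0xBF]);
    assert_ne!(marker, encoded);
}
```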

If you're constructing the token array for the `TokTrie` constructor manually,
it should include the special tokens prefixed with the marker byte 0xFF.
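
For illustration, a vocabulary with one special token might be assembled like this (a sketch; the word list and the `build_vocab` helper are hypothetical, only the marker value comes from the library):

```rust
// Sketch only: the token list is made up.
const SPECIAL_TOKEN_MARKER: u8 = 0xFF; // TokTrie::SPECIAL_TOKEN_MARKER

fn build_vocab() -> Vec<Vec<u8>> {
    // Regular tokens are stored as plain bytes:
    let mut words: Vec<Vec<u8>> = ["hello", " world"]
        .iter()
        .map(|s| s.as_bytes().to_vec())
        .collect();
    // Special tokens get the 0xFF marker prefix:
    let mut eot = vec![SPECIAL_TOKEN_MARKER];
    eot.extend_from_slice(b"<|eot_id|>");
    words.push(eot);
    words
}
```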

The llguidance library does not expose the 0xFF bytes externally
(except for the special `tokenize_bytes_marker` methods), so you
generally don't need to worry about them, except when building the `TokTrie`.
134 changes: 134 additions & 0 deletions docs/toktrie.md
# Implementation notes

## Token trie

The round nodes represent tokens; the square nodes do not have a corresponding token.

The number (`num_parents`) specifies how many parents you need to pop to get to the parent of the node that comes after our children in DFS order.

We also keep the `token_id` and a `subtree_size` (which includes the node itself) in each node.
A bogus `token_id` is used for nodes that do not have a corresponding token.
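
Conceptually, each node carries the following fields (an unpacked sketch for clarity; the actual `TrieNode` packs them more tightly):

```rust
// Unpacked sketch of a node; the real TrieNode in toktree.rs bit-packs these fields.
struct TrieNode {
    byte: u8,          // byte on the edge from the parent
    token_id: u32,     // bogus value for nodes without a token
    subtree_size: u32, // nodes in this subtree, including self
    num_parents: u32,  // parents to pop once this subtree ends in DFS order
}
```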

```mermaid
graph TD
root[ε, 0] -- a --> a((a, 1))
root -- b --> b((b, 1))
root -- c --> c((c, 1))
a -- x --> ax((ax, 1))
a -- y --> ay[ay, 1]
a -- z --> az((az, 2))
az -- a --> azq((aza, 3))
ay -- a --> ayq((aya, 1))
ay -- b --> ayw((ayb, 2))
```
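
For example, the tree above would be laid out in DFS order as follows (a sketch; the `num_parents` values match the labels in the diagram):

```rust
// index: node  (subtree_size, num_parents)
//  0: ε    (10, -)   root, no token
//  1: a    ( 7, 1)
//  2: ax   ( 1, 1)   leaf
//  3: ay   ( 3, 1)   no token
//  4: aya  ( 1, 1)   leaf
//  5: ayb  ( 1, 2)   leaf; pops back to a
//  6: az   ( 2, 2)
//  7: aza  ( 1, 3)   leaf; pops back to the root
//  8: b    ( 1, 1)   leaf
//  9: c    ( 1, 1)   leaf
```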

Traversal algorithm: computing the set of tokens allowed by a stack-based recognizer.
The set is stored in the `logits` array; entries with `0.0` are allowed.

```rust
let mut logits = vec![-100.0; VOCAB_SIZE + 1];
```

A simple version of the traversal algorithm:

```rust
fn traverse(n: &TrieNode) {
    // mark token as allowed; nodes without a token use `token_id == VOCAB_SIZE`
    logits[n.token_id] = 0.0;
    for c in n.children {
        // for every child that starts with an allowed byte
        if byte_allowed(c.byte) {
            push_byte(c.byte);
            // traverse it
            traverse(c);
            pop_bytes(1);
        }
    }
}
```

Now, assume the tree is laid out in memory in DFS order (as in the layout sketch above):

```rust
fn traverse(mut p: usize) -> usize {
    let endp = p + nodes[p].subtree_size;
    p += 1; // move to first child
    while p < endp {
        let n = nodes[p];
        if byte_allowed(n.byte) {
            push_byte(n.byte);
            logits[n.token_id] = 0.0;
            // p is moved by n.subtree_size
            p = traverse(p);
            pop_bytes(1);
        } else {
            p += n.subtree_size;
        }
    }
    p // p == endp here; the caller resumes at our next sibling
}
```

Now, we get rid of the recursion:

```rust
let mut p = 0;
while p < nodes.len() {
    let n = nodes[p];
    if byte_allowed(n.byte) {
        push_byte(n.byte);
        logits[n.token_id] = 0.0;
        // if the node is a leaf, we need to pop all the parents
        pop_bytes(if n.subtree_size == 1 { n.num_parents } else { 0 });
        // move to first child, or to the sibling if there are no children
        p += 1;
    } else {
        // skip the children and go to the sibling node
        p += n.subtree_size;
        // regardless of whether the node is a leaf, we need to pop all the parents
        pop_bytes(n.num_parents - 1);
    }
}
```

Note that the only branch that gets mispredicted here is `if byte_allowed(n.byte)`.
The `if` in the argument to `pop_bytes` is compiled to bit operations, so it is branchless.
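
One possible branchless lowering of that argument (a sketch, assuming `num_parents` is a `usize`):

```rust
// `(n.subtree_size == 1)` is 0 or 1, so the multiply selects
// between 0 and n.num_parents without a branch.
let to_pop = (n.subtree_size == 1) as usize * n.num_parents;
pop_bytes(to_pop);
```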

### Actual code

See `add_bias_inner` in [toktree.rs](./core/src/toktree.rs).

- it uses `try_push_byte()`, which combines `byte_allowed()` and `push_byte()`
- it calls `pop_bytes()` at the beginning of each iteration, with a value stored in the previous iteration

The following is a breakdown of all memory reads and writes
when used with [llguidance](https://github.com/microsoft/llguidance);
see `try_push_byte()` in [parser.rs](https://github.com/microsoft/llguidance/blob/main/parser/src/earley/parser.rs#L1638).
This only considers the fast lexer path.

- `pop_bytes()` - only register update (stack length)
- fetch current `TrieNode` (8 bytes)
- `try_push_byte()` - 3 reads, 1 write, see below
- updating token bit-mask - 1 read, 1 write

The `try_push_byte()` function:

- fetch lexer state from the stack (1 read)
- compute the next DFA state: 1 read for alphabet compression (if enabled), 1 read for the transition table (sketched below)
- push lexer state to the stack (1 write)
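
The two reads of the DFA step might look roughly like this (a sketch with hypothetical table names; the actual lexer in llguidance is structured differently):

```rust
// Hypothetical sketch of the two table reads described above.
fn dfa_next(trans: &[u32], classes: &[u8], n_classes: usize, state: u32, byte: u8) -> u32 {
    let class = classes[byte as usize] as usize; // read 1: alphabet compression
    trans[state as usize * n_classes + class]    // read 2: transition table
}
```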

Together, this is 5 reads and 2 writes per node.
Dependency chain lengths are difficult to estimate, given the possible
speculation and out-of-order execution.

On an AMD EPYC 7V13, a single node is processed in around 13 cycles
(at 4.2 instructions per cycle);
this drops by 1 cycle if alphabet compression is disabled
(likely only 1 because the lexer stack fetch and the alphabet compression fetch can be done in parallel).

The 7V13 has 4 cycles L1 latency (32KB), 13 cycles L2 latency (512KB),
and 46 cycles L3 latency (up to 32MB per core, but shared).
It also has 6-wide uop dispatch.
Sources:
[EPYC Milan](https://www.anandtech.com/show/16529/amd-epyc-milan-review/4),
[Zen3](https://www.anandtech.com/show/16214/amd-zen-3-ryzen-deep-dive-review-5950x-5900x-5800x-and-5700x-tested/4),
[Zen2](https://www.anandtech.com/show/14694/amd-rome-epyc-2nd-gen/7) (shares L1/L2 specs).
129 changes: 129 additions & 0 deletions toktrie/Cargo.lock

15 changes: 15 additions & 0 deletions toktrie/Cargo.toml
```toml
[package]
name = "toktrie"
version = "0.1.0"
edition = "2021"

[lib]
name = "toktrie"

[dependencies]
serde = { version = "1.0.192", features = ["derive"] }
serde_json = "1.0.108"
anyhow = "1.0.75"
bytemuck = "1.19.0"
bytemuck_derive = "1.8.0"
rustc-hash = { version = "2.0.0" }
```
76 changes: 76 additions & 0 deletions toktrie/README.md
# toktrie - Token utility library

This crate provides a utility library for working with tokens and token tries.

## Byte stack interface

Constraints are typically expressed on strings or bytes, not tokens.
To compute the set of tokens that match a string constraint, one needs to go through all the possible tokens
and apply the constraint.
An efficient way to do this is to walk a prefix tree (trie) of all tokens.
This library implements this trie and exposes a way of filtering when provided with a constraint
implementing the [following interface](core/src/toktree.rs):

```rust
pub trait Recognizer {
    /// If `stack.top()` transitions via `byte` to `X`, execute `stack.push(X)`.
    fn push_byte(&mut self, byte: u8);
    /// for _ in 0..num { stack.pop() }
    fn pop_bytes(&mut self, num: usize);
    /// X = stack.top(); stack.empty(); stack.push(X)
    fn collapse(&mut self);
    /// check if stack.top() transitions via byte to a viable state
    fn byte_allowed(&mut self, byte: u8) -> bool;
    /// check if stack.top() transitions via tok to a viable state
    fn special_allowed(&mut self, tok: SpecialToken) -> bool;
    /// Called when iteration over the trie is finished.
    /// Stack has exactly one element then.
    fn trie_finished(&mut self);
    /// This combines `push_byte` and `byte_allowed` into one function for performance.
    fn try_push_byte(&mut self, byte: u8) -> bool;
}
```

The `AiciRecognizer` struct converts `Recognizer` to `AiciCtrl`.

## Functional byte interface

The following interface can be transformed into a `Recognizer` using the `StackRecognizer` struct.

```rust
pub trait FunctionalRecognizer<S: Copy> {
    /// Initial state
    fn initial(&self) -> S;
    /// Extend the recognizer with given byte.
    fn append(&self, state: S, byte: u8) -> S;
    /// Check if given byte is allowed in given state.
    fn byte_allowed(&self, state: S, byte: u8) -> bool;
    /// Check if given special token is allowed in given state.
    fn special_allowed(&self, state: S, tok: SpecialToken) -> bool;
}
```
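
For example, a recognizer accepting only ASCII digits might look like this (a sketch; it assumes a `SpecialToken::EndOfSentence` variant and the `StackRecognizer` wrapper mentioned above):

```rust
// Sketch: accepts any string of ASCII digits.
struct Digits;

impl FunctionalRecognizer<u32> for Digits {
    fn initial(&self) -> u32 {
        0 // state counts bytes consumed; any `Copy` state works
    }
    fn append(&self, state: u32, _byte: u8) -> u32 {
        state + 1
    }
    fn byte_allowed(&self, _state: u32, byte: u8) -> bool {
        byte.is_ascii_digit()
    }
    fn special_allowed(&self, _state: u32, tok: SpecialToken) -> bool {
        // assumption: end-of-sequence is allowed at any point
        matches!(tok, SpecialToken::EndOfSentence)
    }
}

// let rec = StackRecognizer::from(Digits); // then walk the trie with `rec`
```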

These three layers add up to about 40 kB of compiled code (Wasm).


## Contributing

This project welcomes contributions and suggestions. Most contributions require you to agree to a
Contributor License Agreement (CLA) declaring that you have the right to, and actually do, grant us
the rights to use your contribution. For details, visit https://cla.opensource.microsoft.com.

When you submit a pull request, a CLA bot will automatically determine whether you need to provide
a CLA and decorate the PR appropriately (e.g., status check, comment). Simply follow the instructions
provided by the bot. You will only need to do this once across all repos using our CLA.

This project has adopted the [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/).
For more information see the [Code of Conduct FAQ](https://opensource.microsoft.com/codeofconduct/faq/) or
contact [[email protected]](mailto:[email protected]) with any additional questions or comments.

## Trademarks

This project may contain trademarks or logos for projects, products, or services. Authorized use of Microsoft
trademarks or logos is subject to and must follow
[Microsoft's Trademark & Brand Guidelines](https://www.microsoft.com/en-us/legal/intellectualproperty/trademarks/usage/general).
Use of Microsoft trademarks or logos in modified versions of this project must not cause confusion or imply Microsoft sponsorship.
Any use of third-party trademarks or logos is subject to those third parties' policies.