Commit: Import toktrie with history https://github.com/microsoft/toktrie/tree…

Showing 14 changed files with 4,241 additions and 0 deletions.

@@ -0,0 +1,26 @@

# Support for special tokens

Tokenizers typically include special tokens, such as
`<|end_of_text|>`, `<|eot_id|>`, `<|python_tag|>`, `<|start_header_id|>`, etc.
This library is tasked with translating between byte sequences
and tokens.
If you see the bytes `<|eot_id|>` in the input, you may or may not want to treat them
as a special token.

The library assumes that by default you want to treat them as bytes
(so they would be tokenized as `<|`, `eot`, `_`, `id`, `|>` or similar).
To indicate that you want to treat them as a special token, you need to
prefix them with the "marker" byte 0xFF (255) (`TokTrie::SPECIAL_TOKEN_MARKER`).

The byte 0xFF is chosen as the marker because it never occurs in valid UTF-8, so it should not normally
appear in regular inputs.
In Rust, you cannot have the byte 0xFF in a `&str`, only in a `&[u8]`.
In Python, note the difference between `b"\xFF"` and `"\xFF".encode("utf-8")`
(or equivalently `"\u00FF".encode("utf-8")`), which is `b"\xC3\xBF"`.
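
In Rust terms, a quick sanity check (not library code) of both facts:

```rust
// 0xFF never occurs in valid UTF-8, so it works as an out-of-band marker:
assert!(std::str::from_utf8(&[0xFF]).is_err());
// U+00FF ("ÿ") encodes as the two bytes C3 BF, never as a single FF byte:
assert_eq!("\u{00FF}".as_bytes(), b"\xC3\xBF".as_slice());
```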

If you're constructing the token array for the `TokTrie` constructor manually,
it should include the special tokens prefixed with the marker byte 0xFF.
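
For example, a minimal sketch of hand-building such an array (the variable names are illustrative, and the exact `TokTrie` constructor arguments are not shown here):

```rust
const MARKER: u8 = 0xFF; // TokTrie::SPECIAL_TOKEN_MARKER

let mut words: Vec<Vec<u8>> = vec![
    b"hello".to_vec(), // regular token: plain bytes
    b"<|".to_vec(),    // regular token that merely looks special
];
// special token: marker byte followed by its surface form
let mut eot = vec![MARKER];
eot.extend_from_slice(b"<|eot_id|>");
words.push(eot);
// `words`, indexed by token id, is what gets passed to the TokTrie constructor.
```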

The llguidance library does not expose the 0xFF bytes externally
(except in the special `tokenize_bytes_marker` methods), so you
generally don't need to worry about them, except when building the `TokTrie`.

@@ -0,0 +1,134 @@

# Implementation notes

## Token trie

In the diagram below, round nodes represent tokens and square nodes do not have a corresponding token.

The number (`num_parents`) specifies how many parents you need to pop to get to the parent of the node that comes after our children in DFS order.

We also keep the `token_id` and a `subtree_size` (which includes the node itself) in each node.
A bogus `token_id` is used for nodes that do not have a corresponding token.
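
In code, a node along these lines might look as follows (field names are illustrative; the actual `toktree.rs` encoding is more compact):

```rust
struct TrieNode {
    byte: u8,          // byte on the edge from the parent
    token_id: u32,     // bogus value (VOCAB_SIZE) if no token ends here
    subtree_size: u32, // nodes in this subtree, including this one
    num_parents: u32,  // parents to pop to reach the parent of the next DFS node
}
```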

```mermaid
graph TD
  root[ε, 0] -- a --> a((a, 1))
  root -- b --> b((b, 1))
  root -- c --> c((c, 1))
  a -- x --> ax((ax, 1))
  a -- y --> ay[ay, 1]
  a -- z --> az((az, 2))
  az -- a --> azq((aza, 3))
  ay -- a --> ayq((aya, 1))
  ay -- b --> ayw((ayb, 2))
```

Traversal algorithm: computing the set of tokens allowed by a stack-based recognizer.
The set is stored in the `logits` array; entries with `0.0` are allowed.

```rust
let mut logits = vec![-100.0; VOCAB_SIZE + 1];
```

A simple version of the traversal algorithm:

```rust
fn traverse(n) {
    // mark token as allowed; nodes without a token use `token_id == VOCAB_SIZE`
    logits[n.token_id] = 0.0;
    for c in n.children {
        // for every child that starts with an allowed byte
        if byte_allowed(c.byte) {
            push_byte(c.byte);
            // traverse it
            traverse(c);
            pop_bytes(1);
        }
    }
}
```

Now, assume the tree is laid out in memory in DFS order:

```rust
fn traverse(mut p) {
    let endp = p + nodes[p].subtree_size;
    p += 1; // move to the first child
    while p < endp {
        let n = nodes[p];
        if byte_allowed(n.byte) {
            push_byte(n.byte);
            logits[n.token_id] = 0.0;
            // p is advanced by n.subtree_size
            p = traverse(p);
            pop_bytes(1);
        } else {
            p += n.subtree_size;
        }
    }
    endp // return the index right after this subtree (used by `p = traverse(p)`)
}
```

Now, we get rid of the recursion:

```rust
let mut p = 0;
while p < nodes.len() {
    let n = nodes[p];
    if byte_allowed(n.byte) {
        push_byte(n.byte);
        logits[n.token_id] = 0.0;
        // if the node is a leaf, we need to pop all the parents
        pop_bytes(if n.subtree_size == 1 { n.num_parents } else { 0 });
        // move to the first child, or to the sibling if there are no children
        p += 1;
    } else {
        // skip the children, and go to the sibling node
        p += n.subtree_size;
        // regardless of whether the node is a leaf, we need to pop the parents
        pop_bytes(n.num_parents - 1);
    }
}
```

Note that the only branch that gets mis-predicted here is the `if byte_allowed(n.byte)`.
The `if` in the argument to `pop_bytes` is compiled to bit operations, so it is branchless.
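
Putting it together, here is a self-contained sketch of the same loop over the illustrative `TrieNode` from above. The recognizer is modeled as a plain predicate over the byte stack, and `nodes` is assumed to hold all nodes in DFS order with the root omitted; the real code in `toktree.rs` is structured differently.

```rust
fn compute_allowed(
    nodes: &[TrieNode],
    vocab_size: usize,
    byte_allowed: impl Fn(&[u8], u8) -> bool,
) -> Vec<f32> {
    let mut logits = vec![-100.0_f32; vocab_size + 1];
    let mut stack: Vec<u8> = Vec::new(); // bytes on the path from the root
    let mut p = 0;
    while p < nodes.len() {
        let n = &nodes[p];
        if byte_allowed(&stack, n.byte) {
            stack.push(n.byte);
            logits[n.token_id as usize] = 0.0;
            // branchless form of the conditional pop above;
            // the num_parents invariant guarantees this never underflows
            let pop = (n.subtree_size == 1) as usize * n.num_parents as usize;
            stack.truncate(stack.len() - pop);
            p += 1;
        } else {
            p += n.subtree_size as usize;
            stack.truncate(stack.len() - (n.num_parents as usize - 1));
        }
    }
    logits
}
```

On the example trie, with a predicate that allows only the bytes `a` and `z`, exactly the tokens `a`, `az`, and `aza` end up marked with `0.0`.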

### Actual code

See `add_bias_inner` in [toktree.rs](./core/src/toktree.rs).

- it uses `try_push_byte()`, which combines `byte_allowed()` and `push_byte()`
- it calls `pop_bytes()` at the beginning, with a value stored in the previous iteration

The following is a breakdown of all memory reads and writes
when used with [llguidance](https://github.com/microsoft/llguidance);
see `try_push_byte()` in [parser.rs](https://github.com/microsoft/llguidance/blob/main/parser/src/earley/parser.rs#L1638).
This only considers the fast lexer path.

- `pop_bytes()` - only a register update (stack length)
- fetch the current `TrieNode` (8 bytes)
- `try_push_byte()` - 3 reads, 1 write; see below
- updating the token bit-mask - 1 read, 1 write

The `try_push_byte()` function:

- fetch the lexer state from the stack (1 read)
- compute the next DFA state: 1 read for alphabet compression (if enabled), 1 read for the transition table
- push the new lexer state onto the stack (1 write)

Together, this is 5 reads and 2 writes per node.
Dependency chain lengths are difficult to estimate, given the possible
speculation and out-of-order execution.

On an AMD EPYC 7V13, a single node is processed in around 13 cycles
(at 4.2 instructions per cycle);
this drops by 1 cycle if alphabet compression is disabled
(likely only 1 because the lexer stack fetch and the alphabet compression fetch can happen in parallel).

The 7V13 has 4-cycle L1 latency (32KB), 13-cycle L2 latency (512KB),
and 46-cycle L3 latency (up to 32MB per core, but shared).
It also has 6-wide uop dispatch.
Sources:
[EPYC Milan](https://www.anandtech.com/show/16529/amd-epyc-milan-review/4),
[Zen3](https://www.anandtech.com/show/16214/amd-zen-3-ryzen-deep-dive-review-5950x-5900x-5800x-and-5700x-tested/4),
[Zen2](https://www.anandtech.com/show/14694/amd-rome-epyc-2nd-gen/7) (shares L1/L2 specs).

@@ -0,0 +1,15 @@

[package]
name = "toktrie"
version = "0.1.0"
edition = "2021"

[lib]
name = "toktrie"

[dependencies]
serde = { version = "1.0.192", features = ["derive"] }
serde_json = "1.0.108"
anyhow = "1.0.75"
bytemuck = "1.19.0"
bytemuck_derive = "1.8.0"
rustc-hash = { version = "2.0.0" }

@@ -0,0 +1,76 @@

# toktrie - Token utility library

This crate provides a utility library for working with tokens and token tries.

## Byte stack interface

The constraints are typically expressed on strings or bytes, not tokens.
To compute the set of tokens that match a string constraint, one needs to go through all the possible tokens
and apply the constraint.
An efficient way to do this is to walk a prefix tree (trie) of all tokens.
This library implements this trie and exposes a way of filtering when provided with a constraint
implementing the [following interface](core/src/toktree.rs):

```rust
pub trait Recognizer {
    /// If `stack.top()` transitions via `byte` to `X`, execute `stack.push(X)`.
    fn push_byte(&mut self, byte: u8);
    /// for _ in 0..num { stack.pop() }
    fn pop_bytes(&mut self, num: usize);
    /// X = stack.top(); stack.empty(); stack.push(X)
    fn collapse(&mut self);
    /// check if stack.top() transitions via byte to a viable state
    fn byte_allowed(&mut self, byte: u8) -> bool;
    /// check if stack.top() transitions via tok to a viable state
    fn special_allowed(&mut self, tok: SpecialToken) -> bool;
    /// Called when iteration over the trie is finished;
    /// the stack has exactly one element then.
    fn trie_finished(&mut self);
    /// This combines `push_byte` and `byte_allowed` into one function for performance.
    fn try_push_byte(&mut self, byte: u8) -> bool;
}
```

The `AiciRecognizer` struct converts a `Recognizer` into an `AiciCtrl`.

## Functional byte interface

The following interface can be transformed into a `Recognizer` using the `StackRecognizer` struct.

```rust
pub trait FunctionalRecognizer<S: Copy> {
    /// Initial state
    fn initial(&self) -> S;
    /// Extend the recognizer with given byte.
    fn append(&self, state: S, byte: u8) -> S;
    /// Check if given byte is allowed in given state.
    fn byte_allowed(&self, state: S, byte: u8) -> bool;
    /// Check if given special token is allowed in given state.
    fn special_allowed(&self, state: S, tok: SpecialToken) -> bool;
}
```
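
As a hypothetical example (not from the crate; the `SpecialToken::EndOfSentence` variant name is an assumption), a memoryless recognizer that admits only ASCII digits could look like this:

```rust
struct DigitRx;

impl FunctionalRecognizer<()> for DigitRx {
    fn initial(&self) -> () {}
    fn append(&self, _state: (), _byte: u8) -> () {}
    fn byte_allowed(&self, _state: (), byte: u8) -> bool {
        byte.is_ascii_digit()
    }
    fn special_allowed(&self, _state: (), tok: SpecialToken) -> bool {
        // allow generation to stop at any point
        matches!(tok, SpecialToken::EndOfSentence)
    }
}
```

Wrapped in `StackRecognizer`, this can then drive the trie traversal through the byte stack interface above.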

These three layers add up to about 40kB of compiled code (Wasm).

## Contributing

This project welcomes contributions and suggestions. Most contributions require you to agree to a
Contributor License Agreement (CLA) declaring that you have the right to, and actually do, grant us
the rights to use your contribution. For details, visit https://cla.opensource.microsoft.com.

When you submit a pull request, a CLA bot will automatically determine whether you need to provide
a CLA and decorate the PR appropriately (e.g., status check, comment). Simply follow the instructions
provided by the bot. You will only need to do this once across all repos using our CLA.

This project has adopted the [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/).
For more information see the [Code of Conduct FAQ](https://opensource.microsoft.com/codeofconduct/faq/) or
contact [[email protected]](mailto:[email protected]) with any additional questions or comments.

## Trademarks

This project may contain trademarks or logos for projects, products, or services. Authorized use of Microsoft
trademarks or logos is subject to and must follow
[Microsoft's Trademark & Brand Guidelines](https://www.microsoft.com/en-us/legal/intellectualproperty/trademarks/usage/general).
Use of Microsoft trademarks or logos in modified versions of this project must not cause confusion or imply Microsoft sponsorship.
Any use of third-party trademarks or logos is subject to those third parties' policies.