Skip to content

Commit

Permalink
add experimental advance_parser()
Browse files Browse the repository at this point in the history
  • Loading branch information
mmoskal committed Aug 17, 2024
1 parent 3c767a6 commit 80edb0e
Show file tree
Hide file tree
Showing 3 changed files with 47 additions and 1 deletion.
2 changes: 1 addition & 1 deletion parser/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ version = "0.1.6"
edition = "2021"

[dependencies]
toktrie = { git = "https://github.com/microsoft/toktrie", rev = "6934722328ee1d3d679f95fcd5c669d47cee08f2" }
toktrie = { git = "https://github.com/microsoft/toktrie", rev = "59641076bc86504317f07f99465a0f600e957fd3" }
derivre = { git = "https://github.com/microsoft/derivre", rev = "ad363698cc95d7e63c5116aa114596f18dc79385" }
serde = { version = "1.0.192", features = ["derive"] }
serde_json = "1.0.108"
Expand Down
9 changes: 9 additions & 0 deletions python/llguidance/_lib.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -121,3 +121,12 @@ class LLInterpreter:
list of tokens to append.
If mid_process() returned None, this should be called immediately with None.
"""

def advance_parser(
self,
sampled_token: Optional[TokenId]) -> Tuple[int, List[TokenId]]:
"""
Like post_process(), but goes further.
This is experimental and breaks tests when used instead of post_process().

Args:
sampled_token: the token that was sampled, or None.
NOTE(review): the Rust implementation only accepts None when the
previous result carries an unconditional splice; otherwise it
raises ValueError.

Returns:
A (backtrack, tokens) tuple; presumably the number of tokens to
backtrack and the list of tokens to append, as for post_process()
- confirm against callers.

Raises:
ValueError: if called twice in a row, after a stop result, or with
None when a sampled token is expected.
"""

37 changes: 37 additions & 0 deletions rust/src/py.rs
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,43 @@ impl LLInterpreter {
}
}

/// Experimental counterpart of `post_process()` that also advances the parser.
///
/// When `self.last_result` carries an unconditional splice, that splice is
/// recorded in `self.step_arg` and returned unchanged. Otherwise
/// `sampled_token` is handed to `self.inner.advance_parser()`, and the splice
/// it returns (if any) is recorded in `self.step_arg` after being adjusted to
/// account for the sampled token.
///
/// Returns `(backtrack, tokens)`. When the parser returns no splice the
/// result is `(0, vec![])` and `self.step_arg` is left untouched.
///
/// # Errors
///
/// Returns a `PyValueError` when:
/// - `self.step_arg` already holds tokens or a sampled token,
/// - `self.last_result` is a stop result,
/// - `sampled_token` is `None` and there is no unconditional splice.
fn advance_parser(&mut self, sampled_token: Option<TokenId>) -> PyResult<(u32, Vec<TokenId>)> {
    // step_arg must be empty on entry.
    // NOTE(review): presumably it is reset by mid_process() - confirm.
    if !self.step_arg.tokens.is_empty() || self.step_arg.sampled.is_some() {
        return Err(PyValueError::new_err(
            "advance_parser() or post_process() called twice",
        ));
    }

    if self.last_result.is_stop() {
        return Err(PyValueError::new_err("advance_parser() called after stop"));
    }

    // An unconditional splice is applied regardless of what was sampled.
    if let Some(s) = self.last_result.unconditional_splice() {
        self.step_arg = StepArg::from_splice(s, sampled_token);
        return Ok((s.backtrack, s.ff_tokens.clone()));
    }

    let tok = sampled_token.ok_or_else(|| PyValueError::new_err("Expecting sampled token"))?;
    let arg = StepArg {
        backtrack: 0,
        tokens: vec![tok],
        sampled: sampled_token,
    };
    // TODO this may generate progress entries that we should return
    match self.inner.advance_parser(arg) {
        Some(splice) => {
            self.step_arg = StepArg::from_splice(&splice, sampled_token);
            if self.step_arg.backtrack > 0 {
                self.step_arg.backtrack -= 1; // the sampled token was ignored
            } else {
                // No backtracking requested: report the sampled token first,
                // followed by the tokens from the splice.
                self.step_arg.tokens.insert(0, tok);
            }
            Ok((self.step_arg.backtrack, self.step_arg.tokens.clone()))
        }
        // it's stop, really; let the next mid_process() call handle it
        None => Ok((0, vec![])),
    }
}

fn post_process(&mut self, sampled_token: Option<TokenId>) -> PyResult<(u32, Vec<TokenId>)> {
let splice = if let Some(t) = sampled_token {
self.last_result.spliced(t)
Expand Down

0 comments on commit 80edb0e

Please sign in to comment.