Skip to content
This repository has been archived by the owner on Nov 30, 2024. It is now read-only.

Commit

Permalink
add utility functions
Browse files Browse the repository at this point in the history
  • Loading branch information
mmoskal committed Jul 10, 2024
1 parent 0af6a28 commit a13ea99
Showing 1 changed file with 33 additions and 1 deletion.
34 changes: 33 additions & 1 deletion core/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ mod toktree;
pub use svob::{SimpleVob, SimpleVobIter};
pub use toktree::{Recognizer, SpecialToken, TokRxInfo, TokTrie, TokenId, TokenizerEnv};

#[derive(Serialize, Deserialize, Debug)]
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct StepArg {
/// Sampling result for the previous iteration.
/// For simple sampled token 't', backtrack==0 and tokens==[t].
Expand All @@ -22,6 +22,14 @@ pub struct StepArg {
}

impl StepArg {
pub fn empty() -> Self {
StepArg {
backtrack: 0,
tokens: vec![],
sampled: None,
}
}

pub fn save_tokens(&self, acc_tokens: &mut Vec<TokenId>) {
let bt = self.backtrack as usize;
assert!(
Expand Down Expand Up @@ -101,6 +109,30 @@ impl<S> Branch<S> {
}
}

pub fn find_splice(&self, sampled: TokenId) -> Option<&Splice> {
self.splices
.iter()
.find(|s| s.when_sampled.is_empty() || s.when_sampled.contains(&sampled))
}

pub fn spliced(&self, sampled: TokenId) -> Splice {
self.find_splice(sampled).cloned().unwrap_or_else(|| {
Splice {
when_sampled: vec![],
backtrack: 0,
ff_tokens: vec![sampled],
}
})
}

pub fn unconditional_splice(&self) -> Option<&Splice> {
if self.splices.len() == 1 && self.splices[0].when_sampled.is_empty() {
Some(&self.splices[0])
} else {
None
}
}

pub fn has_backtrack(&self) -> bool {
let max_bt = if self.sample_mask.is_none() { 0 } else { 1 };
self.splices.iter().any(|s| s.backtrack > max_bt)
Expand Down

0 comments on commit a13ea99

Please sign in to comment.