diff --git a/core/src/lib.rs b/core/src/lib.rs index 296e053..cb0edf8 100644 --- a/core/src/lib.rs +++ b/core/src/lib.rs @@ -7,7 +7,7 @@ mod svob; mod toktree; pub use svob::{SimpleVob, SimpleVobIter}; -pub use toktree::{Recognizer, SpecialToken, TokRxInfo, TokTrie, TokenId, TokenizerEnv, TokEnv}; +pub use toktree::{Recognizer, SpecialToken, TokEnv, TokRxInfo, TokTrie, TokenId, TokenizerEnv}; /// Defines what is allowed in Branch #[derive(Serialize, Deserialize, Clone, Debug, Default)] @@ -23,13 +23,12 @@ pub struct InferenceCapabilities { /// Backtracking is allowed. #[serde(default)] pub backtrack: bool, - + /// More than one branch is allowed. #[serde(default)] pub fork: bool, } - #[derive(Serialize, Deserialize, Debug, Clone)] pub struct StepArg { /// Sampling result for the previous iteration. @@ -60,6 +59,14 @@ impl StepArg { acc_tokens.truncate(acc_tokens.len() - bt); acc_tokens.extend_from_slice(&self.tokens); } + + pub fn from_splice(s: &Splice, sampled: Option) -> Self { + StepArg { + backtrack: s.backtrack, + tokens: s.ff_tokens.clone(), + sampled, + } + } } /* @@ -137,13 +144,13 @@ impl Branch { } pub fn spliced(&self, sampled: TokenId) -> Splice { - self.find_splice(sampled).cloned().unwrap_or_else(|| { - Splice { + self.find_splice(sampled) + .cloned() + .unwrap_or_else(|| Splice { when_sampled: vec![], backtrack: 0, ff_tokens: vec![sampled], - } - }) + }) } pub fn unconditional_splice(&self) -> Option<&Splice> {