From 106a28f0a392a47755e13d5df48c29aad258230f Mon Sep 17 00:00:00 2001 From: Michal Moskal Date: Sat, 28 Sep 2024 01:13:46 +0000 Subject: [PATCH] fixes in temperature handling --- parser/src/constraint.rs | 14 +++++++++----- parser/src/earley/parser.rs | 32 ++++++++++++++++++++------------ parser/src/tokenparser.rs | 2 +- 3 files changed, 30 insertions(+), 18 deletions(-) diff --git a/parser/src/constraint.rs b/parser/src/constraint.rs index 1a48c71a..c9729320 100644 --- a/parser/src/constraint.rs +++ b/parser/src/constraint.rs @@ -66,9 +66,6 @@ impl Constraint { } fn save_progress_and_result(&mut self, res: StepResult) { - if let Some(temp) = res.temperature { - self.temperature = temp; - } self.last_res = res; if self.log_json_progress { for p in self.reporter.get_progress(&mut self.parser, &self.last_res) { @@ -79,6 +76,13 @@ impl Constraint { self.parser.logger.write_buffer("\n"); } } + self.save_temperature(); + } + + fn save_temperature(&mut self) { + if let Some(temp) = self.parser.parser.temperature() { + self.temperature = temp; + } } /// You can call this first with the prompt from the user, when not @@ -89,7 +93,7 @@ impl Constraint { assert!(!self.started); self.started = true; let r = self.parser.process_prompt(prompt); - self.temperature = self.parser.parser.temperature(); + self.save_temperature(); r } @@ -131,7 +135,7 @@ impl Constraint { if !self.started { self.started = true; self.parser.start_without_prompt(); - self.temperature = self.parser.parser.temperature(); + self.save_temperature(); } if self.delayed_stop { diff --git a/parser/src/earley/parser.rs b/parser/src/earley/parser.rs index c0f05291..726cb2ce 100644 --- a/parser/src/earley/parser.rs +++ b/parser/src/earley/parser.rs @@ -18,7 +18,10 @@ use derivre::{RegexAst, StateID}; use serde::{Deserialize, Serialize}; use toktrie::{Recognizer, SimpleVob, SpecialToken, TokEnv, TokTrie, TokenId}; -use crate::{api::{GenGrammarOptions, ParserLimits}, earley::lexer::Lexer}; +use crate::{ + api::{GenGrammarOptions, ParserLimits}, + earley::lexer::Lexer, +}; use super::{ grammar::{CGrammar, CSymIdx, CSymbol, ModelVariable, RuleIdx}, @@ -620,17 +623,21 @@ impl ParserState { prefix_len + last_lexeme_visible_len } - pub fn temperature(&self) -> f32 { - let mut temp = 0.0f32; + pub fn temperature(&self) -> Option { + let mut temp = -1000.0f32; for data in self.after_dots_symdata() { if data.is_terminal { temp = temp.max(data.props.temperature); } } - if self.options.temperature.is_some() { - temp = temp.max(self.options.temperature.unwrap()); + if let Some(t) = self.options.temperature { + temp = temp.max(t); + } + if temp < 0.00000001 { + None + } else { + Some(temp) } - temp } pub fn apply_tokens( @@ -1130,12 +1137,12 @@ impl ParserState { } true } - + // scan() implements the version of Earley described in Kallmeyer 2018. // An important difference between the algorithm implemented here // and Kallmeyer's is that in scan(), the token scan is performed // first, while in Kallmeyer it is performed last. - + // lexeme body only used for captures (in definitive mode) // and debugging (lexeme.idx used always) fn scan(&mut self, lexeme: &Lexeme) -> bool { @@ -1173,7 +1180,7 @@ impl ParserState { // push_row() does the agenda processing. There is an agenda for // each Earley set (aka row). - + // lexeme only used for captures (in definitive mode) #[inline(always)] fn push_row(&mut self, curr_idx: usize, mut agenda_ptr: usize, lexeme: &Lexeme) -> bool { @@ -1252,7 +1259,8 @@ impl ParserState { } } } - } else { // ... if 'rule' is an incompletion + } else { + // ... if 'rule' is an incompletion let sym_data = self.grammar.sym_data(after_dot); if let Some(lx) = sym_data.lexeme { allowed_lexemes.set(lx.as_usize(), true); @@ -1280,7 +1288,7 @@ impl ParserState { let new_item = Item::new(*rule, curr_idx); self.scratch.add_unique(new_item, item_idx, "predict"); } - + // TODO the hidden stuff is no longer used if self.scratch.definitive && sym_data.props.hidden { for rule in &sym_data.rules { @@ -1825,7 +1833,7 @@ impl Parser { self.state.can_advance() } - pub fn temperature(&self) -> f32 { + pub fn temperature(&self) -> Option { self.state.temperature() } diff --git a/parser/src/tokenparser.rs b/parser/src/tokenparser.rs index 8fa72c3f..4fe3552f 100644 --- a/parser/src/tokenparser.rs +++ b/parser/src/tokenparser.rs @@ -679,7 +679,7 @@ impl TokenParser { return self.stop("", StopReason::NoExtensionBias); } - return StepResult::sample(set, Some(self.parser.temperature())); + return StepResult::sample(set, self.parser.temperature()); } fn maybe_push_parser(&mut self) -> Result<()> {