Skip to content

Commit

Permalink
fixes in temperature handling
Browse files Browse the repository at this point in the history
  • Loading branch information
mmoskal committed Sep 28, 2024
1 parent ce013ae commit 106a28f
Show file tree
Hide file tree
Showing 3 changed files with 30 additions and 18 deletions.
14 changes: 9 additions & 5 deletions parser/src/constraint.rs
Original file line number Diff line number Diff line change
Expand Up @@ -66,9 +66,6 @@ impl Constraint {
}

fn save_progress_and_result(&mut self, res: StepResult) {
if let Some(temp) = res.temperature {
self.temperature = temp;
}
self.last_res = res;
if self.log_json_progress {
for p in self.reporter.get_progress(&mut self.parser, &self.last_res) {
Expand All @@ -79,6 +76,13 @@ impl Constraint {
self.parser.logger.write_buffer("\n");
}
}
self.save_temperature();
}

/// Refreshes the cached temperature from the underlying parser.
///
/// The parser reports its temperature as an `Option`; when it reports
/// `None`, the value already stored on `self` is kept as-is.
fn save_temperature(&mut self) {
    let reported = self.parser.parser.temperature();
    // Fall back to the current value so a missing temperature is a no-op.
    self.temperature = reported.unwrap_or(self.temperature);
}

/// You can call this first with the prompt from the user, when not
Expand All @@ -89,7 +93,7 @@ impl Constraint {
assert!(!self.started);
self.started = true;
let r = self.parser.process_prompt(prompt);
self.temperature = self.parser.parser.temperature();
self.save_temperature();
r
}

Expand Down Expand Up @@ -131,7 +135,7 @@ impl Constraint {
if !self.started {
self.started = true;
self.parser.start_without_prompt();
self.temperature = self.parser.parser.temperature();
self.save_temperature();
}

if self.delayed_stop {
Expand Down
32 changes: 20 additions & 12 deletions parser/src/earley/parser.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,10 @@ use derivre::{RegexAst, StateID};
use serde::{Deserialize, Serialize};
use toktrie::{Recognizer, SimpleVob, SpecialToken, TokEnv, TokTrie, TokenId};

use crate::{api::{GenGrammarOptions, ParserLimits}, earley::lexer::Lexer};
use crate::{
api::{GenGrammarOptions, ParserLimits},
earley::lexer::Lexer,
};

use super::{
grammar::{CGrammar, CSymIdx, CSymbol, ModelVariable, RuleIdx},
Expand Down Expand Up @@ -620,17 +623,21 @@ impl ParserState {
prefix_len + last_lexeme_visible_len
}

pub fn temperature(&self) -> f32 {
let mut temp = 0.0f32;
pub fn temperature(&self) -> Option<f32> {
let mut temp = -1000.0f32;
for data in self.after_dots_symdata() {
if data.is_terminal {
temp = temp.max(data.props.temperature);
}
}
if self.options.temperature.is_some() {
temp = temp.max(self.options.temperature.unwrap());
if let Some(t) = self.options.temperature {
temp = temp.max(t);
}
if temp < 0.00000001 {
None
} else {
Some(temp)
}
temp
}

pub fn apply_tokens(
Expand Down Expand Up @@ -1130,12 +1137,12 @@ impl ParserState {
}
true
}

// scan() implements the version of Earley described in Kallmeyer 2018.
// An important difference between the algorithm implemented here
// and Kallmeyer's is that in scan(), the token scan is performed
// first, while in Kallmeyer it is performed last.

// lexeme body only used for captures (in definitive mode)
// and debugging (lexeme.idx used always)
fn scan(&mut self, lexeme: &Lexeme) -> bool {
Expand Down Expand Up @@ -1173,7 +1180,7 @@ impl ParserState {

// push_row() does the agenda processing. There is an agenda for
// each Earley set (aka row).

// lexeme only used for captures (in definitive mode)
#[inline(always)]
fn push_row(&mut self, curr_idx: usize, mut agenda_ptr: usize, lexeme: &Lexeme) -> bool {
Expand Down Expand Up @@ -1252,7 +1259,8 @@ impl ParserState {
}
}
}
} else { // ... if 'rule' is an incompletion
} else {
// ... if 'rule' is an incompletion
let sym_data = self.grammar.sym_data(after_dot);
if let Some(lx) = sym_data.lexeme {
allowed_lexemes.set(lx.as_usize(), true);
Expand Down Expand Up @@ -1280,7 +1288,7 @@ impl ParserState {
let new_item = Item::new(*rule, curr_idx);
self.scratch.add_unique(new_item, item_idx, "predict");
}

// TODO the hidden stuff is no longer used
if self.scratch.definitive && sym_data.props.hidden {
for rule in &sym_data.rules {
Expand Down Expand Up @@ -1825,7 +1833,7 @@ impl Parser {
self.state.can_advance()
}

pub fn temperature(&self) -> f32 {
pub fn temperature(&self) -> Option<f32> {
self.state.temperature()
}

Expand Down
2 changes: 1 addition & 1 deletion parser/src/tokenparser.rs
Original file line number Diff line number Diff line change
Expand Up @@ -679,7 +679,7 @@ impl TokenParser {
return self.stop("", StopReason::NoExtensionBias);
}

return StepResult::sample(set, Some(self.parser.temperature()));
return StepResult::sample(set, self.parser.temperature());
}

fn maybe_push_parser(&mut self) -> Result<()> {
Expand Down

0 comments on commit 106a28f

Please sign in to comment.