From ef7e21d016d7ae8282f31e4c442d3b1ac9b84909 Mon Sep 17 00:00:00 2001 From: Michal Moskal Date: Mon, 23 Dec 2024 15:06:58 +0000 Subject: [PATCH] experiment in compressing masks --- json_stats/src/json_stats.rs | 3 ++ parser/src/earley/parser.rs | 1 + parser/src/earley/slicer.rs | 55 ++++++++++++++++++++++++++++++++++-- 3 files changed, 57 insertions(+), 2 deletions(-) diff --git a/json_stats/src/json_stats.rs b/json_stats/src/json_stats.rs index 3abd4bc..8e013fb 100644 --- a/json_stats/src/json_stats.rs +++ b/json_stats/src/json_stats.rs @@ -85,6 +85,8 @@ struct LlgResult { max_mask_us: usize, #[serde(skip_serializing_if = "is_zero")] slicer_leftover_us: usize, + #[serde(skip_serializing_if = "is_zero")] + compressed_mask_size: usize, one: usize, @@ -382,6 +384,7 @@ impl TestEnv { let m = parser.parser.metrics_mut(); stats.slicer_leftover_us += m.slicer_leftover_us; + stats.compressed_mask_size += m.compressed_mask_size; let lx = parser.parser.lexer_stats(); stats.max_lexer_states = std::cmp::max(stats.max_lexer_states, lx.num_states); diff --git a/parser/src/earley/parser.rs b/parser/src/earley/parser.rs index aa120b3..e09e0e2 100644 --- a/parser/src/earley/parser.rs +++ b/parser/src/earley/parser.rs @@ -125,6 +125,7 @@ pub struct ParserMetrics { pub rand: XorShift, pub message: String, pub slicer_leftover_us: usize, + pub compressed_mask_size: usize, } impl ParserStats { diff --git a/parser/src/earley/slicer.rs b/parser/src/earley/slicer.rs index bf1b0dd..b777ba4 100644 --- a/parser/src/earley/slicer.rs +++ b/parser/src/earley/slicer.rs @@ -9,7 +9,7 @@ use crate::{ toktrie::{SimpleVob, TokEnv, TokTrie, TokenId}, }; -use super::parser::ITEM_TRACE; +use super::{parser::ITEM_TRACE, ParserMetrics}; struct TokenizerSlice { idx: usize, @@ -187,6 +187,7 @@ impl BiasComputer for SlicedBiasComputer { fn compute_bias<'b>(&self, rec: &mut ParserRecognizer<'b>, start: &[u8]) -> SimpleVob { let mut set = self.trie().alloc_token_set(); let lexer_state = rec.lexer_state(); + if self.slices.len() > 0 && start.is_empty() && rec.lexer_mut().subsume_possible(lexer_state) @@ -208,9 +209,11 @@ impl BiasComputer for SlicedBiasComputer { if slice_matches.iter().all(|&x| x == false) { // if nothing matches, just run the full trie self.wildcard_slice.add_bias(rec, &mut set, start); + apply_metrics(rec.metrics_mut(), &set); debug!("no slice matches; {} tokens", set.num_set()); } else { // otherwise, apply the matching slices, and compute the rest + let mut acc = self.trie().alloc_token_set(); for (i, slice) in self.slices.iter().enumerate() { if slice_matches[i] { rec.stats_mut().slices_applied += 1; @@ -219,7 +222,8 @@ impl BiasComputer for SlicedBiasComputer { // assert!(slice.regex == ""); let c0 = if DEBUG { set.num_set() } else { 0 }; let t0 = std::time::Instant::now(); - slice.trie.add_bias(rec, &mut set, start); + slice.trie.add_bias(rec, &mut acc, start); + set.or(&acc); let us = t0.elapsed().as_micros() as usize; rec.metrics_mut().slicer_leftover_us += us; debug!("slice matches #{}; {} tokens", i, set.num_set() - c0); @@ -234,9 +238,11 @@ impl BiasComputer for SlicedBiasComputer { // } } } + apply_metrics(rec.metrics_mut(), &acc); } } else { self.wildcard_slice.add_bias(rec, &mut set, start); + apply_metrics(rec.metrics_mut(), &set); debug!("slicer disabled; {} tokens", set.num_set()); } @@ -249,3 +255,48 @@ impl BiasComputer for SlicedBiasComputer { self.tok_env.tok_trie() } } + +fn apply_metrics(parser_metrics: &mut ParserMetrics, mask: &SimpleVob) { + //let size = compress_mask(&mask).len(); + let size = std::cmp::min(mask.num_set() * 2, mask.len() / 8); + parser_metrics.compressed_mask_size += size; +} + +fn compress_mask(s: &SimpleVob) -> Vec { + let mut res: Vec = vec![]; + let mut num_zero = 0; + for &d in s.as_slice() { + let num_bits = d.count_ones(); + if num_bits == 0 { + if num_zero < 32 { + num_zero += 1; + continue; + } + } + if num_zero > 0 { + res.push(num_zero + 32); + num_zero = 0; + } + if num_bits == 1 { + res.push(d.leading_zeros() as u8); + } else if num_bits == 2 { + res.push(d.leading_zeros() as u8); + res.push(d.leading_zeros() as u8); + } else if num_bits == 3 { + res.push(d.leading_zeros() as u8); + res.push(d.leading_zeros() as u8); + res.push(d.leading_zeros() as u8); + } else if false && num_bits == 31 { + res.push(d.leading_ones() as u8); + } else if num_bits == 32 { + res.push(60); + } else { + res.push(61); + res.push(d as u8); + res.push((d >> 8) as u8); + res.push((d >> 16) as u8); + res.push((d >> 24) as u8); + } + } + res +}