Commit 27b7e19

mask computation time histograms

mmoskal committed Dec 23, 2024
1 parent e2bf313 commit 27b7e19

Showing 3 changed files with 80 additions and 26 deletions.
1 change: 1 addition & 0 deletions README.md
@@ -48,6 +48,7 @@ The library implements a context-free grammar parser using Earley’s algorithm
 
 Recently released [XGrammar](https://github.com/mlc-ai/xgrammar) follows an approach similar to llama.cpp (an explicit stack-based, character-level parser), with additional pre-computation of certain token masks, similar to Outlines. The pre-computation often runs into seconds, and sometimes minutes. If the pre-computation works well for a given input, the masks are computed quickly (under 50us for half of the masks we tested); however, if it doesn't fit the particular input,
 the mask computation times can run into seconds.
+The average computation time for masks under 10ms was 277us; however, over 3% of masks took longer than 10ms (with an average time of over 1s).
 
 In llguidance, the full mask computation for a typical JSON schema takes about 1.5ms (for a 128k tokenizer).
 However, very often the ["slicer" optimization](./docs/optimizations.md) applies,
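The summary statistics quoted in the README line above can be derived from a list of per-mask timings along these lines (an illustrative sketch, not part of the commit; `mask_times_us` is made-up sample data):

```python
# Sketch: split per-mask timings (in microseconds) at the 10ms cutoff
# used in the README and compute the two conditional averages.
# `mask_times_us` is hypothetical data, not measured results.
mask_times_us = [42, 120, 277, 9_500, 15_000, 1_200_000]

under = [t for t in mask_times_us if t < 10_000]
over = [t for t in mask_times_us if t >= 10_000]

print(f"avg under 10ms: {sum(under) // len(under)}us")
print(f"share over 10ms: {100 * len(over) / len(mask_times_us):.1f}%")
print(f"avg over 10ms: {sum(over) // len(over)}us")
```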
97 changes: 75 additions & 22 deletions json_stats/scripts/xgr/xgr_combine.py
@@ -1,16 +1,24 @@
 #!/usr/bin/env python3
 
 import json
+import math
 import glob
 
 output_path = "tmp/xgr/"
+llg = False
 
 
 class Stats:
     def __init__(self) -> None:
         self.ttfm_us = 0
         self.max_ttfm_us = 0
         self.masks_us = 0
+        self.masks_us_under_10ms = 0
+        self.num_masks_under_10ms = 0
+        self.avg_masks_under_10ms = 0
+        self.masks_us_over_10ms = 0
+        self.num_masks_over_10ms = 0
+        self.avg_masks_over_10ms = 0
         self.max_mask_us = 0
         self.num_tokens = 0
         self.num_schemas = 0
@@ -21,6 +29,7 @@ def __init__(self) -> None:
         self.num_valid_tests = 0
         self.num_invalid_tests = 0
 
+
 def log_fraction_plot(times: list[int]):
     times.sort()
     cutoff = 1
@@ -35,41 +44,85 @@ def log_fraction_plot(times: list[int]):
             count += 1
     return csv
 
+
+def histogram_position(us: int):
+    return int(math.floor(math.log10(max(1, us - 1))))
+
+def us_to_str(us: int):
+    if us < 1000:
+        return f"{us}us"
+    if us < 1000000:
+        return f"{us//1000}ms"
+    return f"{us//1000000}s"
+
+
 def main():
-    files = glob.glob(output_path + "*.json")
-    files = sorted(files)
+    if llg:
+        files = ["tmp/llg_results.json"]
+    else:
+        files = glob.glob(output_path + "*.json")
+        files = sorted(files)
     stats = Stats()
     ttfm_us = []
     all_masks_us = []
+    histogram_us = [0] * 10
+    histogram_num = [0] * 10
     for f in files:
         with open(f) as f:
             data = json.load(f)
-        if "num_tests" not in data:
-            continue
-        stats.num_schemas += 1
-        stats.num_tests += data["num_tests"]
-        if "compile_error" in data:
-            stats.num_compilation_errors += 1
-        else:
-            stats.ttfm_us += data["ttfm_us"]
-            ttfm_us.append(data["ttfm_us"])
-            stats.max_ttfm_us = max(data["max_ttfm_us"], stats.max_ttfm_us)
-            stats.masks_us += data["masks_us"]
-            stats.max_mask_us = max(data["max_mask_us"], stats.max_mask_us)
-            stats.num_tokens += data["num_tokens"]
-            if "validation_error" in data:
-                stats.num_validation_errors += 1
+        elts = [data]
+        if isinstance(data, list):
+            elts = data
+        for data in elts:
+            if "num_tests" not in data:
+                continue
+            stats.num_schemas += 1
+            stats.num_tests += data["num_tests"]
+            if "compile_error" in data:
+                stats.num_compilation_errors += 1
-            else:
-                stats.num_schemas_ok += 1
-                stats.num_valid_tests += data["num_valid_tests"]
-                stats.num_invalid_tests += data["num_invalid_tests"]
-                all_masks_us.extend(data["all_mask_us"])
+            stats.ttfm_us += data["ttfm_us"]
+            ttfm_us.append(data["ttfm_us"])
+            stats.max_ttfm_us = max(data["max_ttfm_us"], stats.max_ttfm_us)
+            if "masks_us" in data:
+                stats.masks_us += data["masks_us"]
+                stats.max_mask_us = max(data["max_mask_us"], stats.max_mask_us)
+                stats.num_tokens += data["num_tokens"]
+                if "validation_error" in data:
+                    stats.num_validation_errors += 1
+                else:
+                    stats.num_schemas_ok += 1
+                    stats.num_valid_tests += data["num_valid_tests"]
+                    stats.num_invalid_tests += data["num_invalid_tests"]
+                    all_masks_us.extend(data["all_mask_us"])
+                    for us in data["all_mask_us"]:
+                        p = histogram_position(us)
+                        histogram_us[p] += us
+                        histogram_num[p] += 1
+                        if us < 10000:
+                            stats.masks_us_under_10ms += us
+                            stats.num_masks_under_10ms += 1
+                        else:
+                            stats.masks_us_over_10ms += us
+                            stats.num_masks_over_10ms += 1
+    stats.avg_masks_under_10ms = stats.masks_us_under_10ms // stats.num_masks_under_10ms
+    stats.avg_masks_over_10ms = stats.masks_us_over_10ms // stats.num_masks_over_10ms
     print(json.dumps(stats.__dict__, indent=2))
     with open("tmp/xgr_ttfm_us.csv", "w") as f:
         f.write(log_fraction_plot(ttfm_us))
     with open("tmp/xgr_masks_us.csv", "w") as f:
         f.write(log_fraction_plot(all_masks_us))
 
+    num_masks = sum(histogram_num)
+    h_csv = "above us,frac\n"
+    for i in range(10)[1:]:
+        frac = sum(histogram_num[i:]) * 100 / num_masks
+        h_csv += f"{us_to_str(10**i):10}"
+        h_csv += f","
+        h_csv += f"{frac:1.15}"
+        h_csv += f"\n"
+    with open("tmp/xgr_histogram.csv", "w") as f:
+        f.write(h_csv)
+    print(h_csv)
 
 
 main()
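The core of the new histogram logic is `histogram_position`, which maps a time in microseconds to a power-of-ten bucket; the final loop then reports, for each cutoff from 10us to 1s, the percentage of masks that took longer. A standalone illustration (the function is copied from the script above; the driver loop is added here for demonstration):

```python
import math

def histogram_position(us: int):
    # bucket i covers times in (10**i, 10**(i+1)] microseconds;
    # the "us - 1" keeps exact powers of ten in the lower bucket
    return int(math.floor(math.log10(max(1, us - 1))))

# boundary behaviour: 10us stays in bucket 0, 11us moves to bucket 1
for us in [1, 10, 11, 100, 101, 10_000, 1_000_000]:
    print(f"{us:>8}us -> bucket {histogram_position(us)}")
```

Note that `avg_masks_under_10ms` and `avg_masks_over_10ms` use integer division, so the script assumes at least one mask on each side of the 10ms cutoff.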
8 changes: 4 additions & 4 deletions json_stats/src/json_stats.rs
@@ -115,8 +115,8 @@ struct LlgResult {
     #[serde(skip)]
     slow_mask_us_a: [usize; MASK_STEPS],
 
-    #[serde(skip)]
-    all_masks_us: Vec<usize>,
+    // #[serde(skip)]
+    all_mask_us: Vec<usize>,
 
     #[serde(skip_serializing_if = "Option::is_none")]
     compile_error: Option<String>,
@@ -276,7 +276,7 @@ impl TestEnv {
         let us = t0.elapsed().as_micros() as usize;
         let pstats = parser.last_step_stats();
 
-        stats.all_masks_us.push(us);
+        stats.all_mask_us.push(us);
 
         // && pstats.lexer_cost < 7 * us as u64
         if self.cli.csv && us > 1000 {
@@ -728,7 +728,7 @@ fn main() {
             total.llg.num_parsers += 1;
 
             all_ttfm_us.push(llg.ttfm_us);
-            all_masks_us.extend_from_slice(&llg.all_masks_us);
+            all_masks_us.extend_from_slice(&llg.all_mask_us);
         }
 
         total.llg.ttfm_us += llg.ttfm_us;
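With `#[serde(skip)]` commented out, the per-mask timing vector `all_mask_us` is now serialized into the results JSON, which is what the `llg` branch of `xgr_combine.py` above reads back. A minimal sketch of consuming that field (assuming a results file of the list-of-records shape the script handles, e.g. `tmp/llg_results.json`):

```python
import json

# Assumes each record may carry an "all_mask_us" array of per-mask
# times in microseconds, as serialized by json_stats.rs above.
with open("tmp/llg_results.json") as f:
    data = json.load(f)

records = data if isinstance(data, list) else [data]
all_mask_us = [us for rec in records for us in rec.get("all_mask_us", [])]
if all_mask_us:
    print(f"{len(all_mask_us)} masks, max {max(all_mask_us)}us")
```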
