Skip to content

Commit

Permalink
[perf] Optimize alpha_pow for Static Verifier (#1261)
Browse files Browse the repository at this point in the history
* Revert alpha_pow computation in static verifier

* Add cycle tracker span

* Fix lint

* Use k=23 for static verifier by default

* Fix bug

* Remove unused branch

* Change MAX_LOG_WIDTH to 31
  • Loading branch information
nyunyunyunyu authored Jan 24, 2025
1 parent d45b226 commit 157f819
Show file tree
Hide file tree
Showing 2 changed files with 63 additions and 23 deletions.
2 changes: 1 addition & 1 deletion benchmarks/src/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ impl BenchmarkCli {
},
},
halo2_config: Halo2Config {
verifier_k: self.halo2_outer_k.unwrap_or(24),
verifier_k: self.halo2_outer_k.unwrap_or(23),
wrapper_k: self.halo2_wrapper_k,
profiling: self.profiling,
},
Expand Down
84 changes: 62 additions & 22 deletions extensions/native/recursion/src/fri/two_adic_pcs.rs
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,26 @@ pub fn verify_two_adic_pcs<C: Config>(

let log_max_height =
builder.eval_expr(proof.commit_phase_commits.len() + RVar::from(log_blowup));

builder.cycle_tracker_start("pre-compute-alpha-pows");
// Only used in dynamic mode.
let round_alpha_pows = compute_round_alpha_pows(builder, rounds.clone(), alpha);
// Only used in static mode.
let alpha_pows = if builder.flags.static_only {
let widths = get_round_widths(builder, &rounds);
let max_width = *widths.iter().max().unwrap();
let mut ret = Vec::with_capacity(max_width + 1);
ret.push(C::EF::ONE.cons());
for i in 1..=max_width {
let curr = builder.eval(ret[i - 1].clone() * alpha);
builder.ext_reduce_circuit(curr);
ret.push(curr.into());
}
ret
} else {
vec![]
};
builder.cycle_tracker_end("pre-compute-alpha-pows");

iter_zip!(builder, proof.query_proofs).for_each(|ptr_vec, builder| {
let query_proof = builder.iter_ptr_get(&proof.query_proofs, ptr_vec[0]);
Expand Down Expand Up @@ -105,7 +124,12 @@ pub fn verify_two_adic_pcs<C: Config>(
|ptr_vec, builder| {
let batch_opening = builder.iter_ptr_get(&query_proof.input_proof, ptr_vec[0]);
let round = builder.iter_ptr_get(&rounds, ptr_vec[1]);
let mat_alpha_pows = builder.iter_ptr_get(&round_alpha_pows, ptr_vec[2]);
let mat_alpha_pows = if builder.flags.static_only {
// Static verifier uses a different way to compute `alpha_pows` but we need to return a placeholder here.
builder.array(0)
} else {
builder.iter_ptr_get(&round_alpha_pows, ptr_vec[2])
};
let batch_commit = round.batch_commit;
let mats = round.mats;
let permutation = round.permutation;
Expand Down Expand Up @@ -205,7 +229,11 @@ pub fn verify_two_adic_pcs<C: Config>(
|ptr_vec, builder| {
let mat_opening = builder.iter_ptr_get(&opened_values, ptr_vec[0]);
let mat = builder.iter_ptr_get(&mats, ptr_vec[1]);
let mat_alpha_pow = builder.iter_ptr_get(&mat_alpha_pows, ptr_vec[2]);
let mat_alpha_pow = if builder.flags.static_only {
builder.uninit()
} else {
builder.iter_ptr_get(&mat_alpha_pows, ptr_vec[2])
};
let mat_points = mat.points;
let mat_values = mat.values;
let domain = mat.domain;
Expand All @@ -230,13 +258,16 @@ pub fn verify_two_adic_pcs<C: Config>(
if builder.flags.static_only {
let n: Ext<C::F, C::EF> = builder.constant(C::EF::ZERO);
let width = ps_at_z.len().value();
for t in (0..width).rev() {
for (t, alpha_pow) in alpha_pows.iter().take(width).enumerate() {
let p_at_x = builder.get(&mat_opening, t);
let p_at_z = builder.get(&ps_at_z, t);
builder.assign(&n, n * alpha + (p_at_z - p_at_x));
builder.assign(&n, n + (p_at_z - p_at_x) * alpha_pow.clone());
}
builder.assign(&cur_ro, cur_ro + n / (z - x) * cur_alpha_pow);
builder.assign(&cur_alpha_pow, cur_alpha_pow * mat_alpha_pow);
builder.assign(
&cur_alpha_pow,
cur_alpha_pow * alpha_pows[width].clone(),
);
} else {
let mat_ro = builder.fri_single_reduced_opening_eval(
alpha,
Expand Down Expand Up @@ -375,9 +406,13 @@ fn compute_round_alpha_pows<C: Config>(
rounds: Array<C, TwoAdicPcsRoundVariable<C>>,
alpha: Ext<C::F, C::EF>,
) -> Array<C, Array<C, Ext<C::F, C::EF>>> {
// Static verifier uses a different way to compute `alpha_pows` but we need to return a placeholder here.
if builder.flags.static_only {
return builder.array(0);
}
// Max log of matrix width
// TODO: this should be determined by VK.
const MAX_LOG_WIDTH: usize = 15;
const MAX_LOG_WIDTH: usize = 31;
let pow_of_alpha: Array<C, Ext<_, _>> = builder.array(MAX_LOG_WIDTH);
let current: Ext<_, _> = builder.eval(alpha);
for i in 0..MAX_LOG_WIDTH {
Expand All @@ -390,24 +425,11 @@ fn compute_round_alpha_pows<C: Config>(
let round = builder.iter_ptr_get(&rounds, ptr_vec[0]);
let mat_alpha_pows: Array<C, Ext<_, _>> = builder.array(round.mats.len());
iter_zip!(builder, round.mats, mat_alpha_pows).for_each(|ptr_vec, builder| {
assert!(!builder.flags.static_only);
let mat = builder.iter_ptr_get(&round.mats, ptr_vec[0]);
let local = builder.get(&mat.values, 0);
let width = local.len();
let mat_alpha_pow: Ext<_, _> = if builder.flags.static_only {
let width = width.value();
assert!(width < 1 << MAX_LOG_WIDTH);
let mut expr = C::EF::ONE.cons();
for i in 0..MAX_LOG_WIDTH {
if width & (1 << i) != 0 {
expr *= builder.get(&pow_of_alpha, i);
}
}
let ret: Ext<_, _> = builder.eval(expr);
// Minimize max_bits so following computation becomes cheaper.
builder.ext_reduce_circuit(ret);
ret
} else {
let width = width.get_var();
let mat_alpha_pow: Ext<_, _> = {
let width = local.len().get_var();
// This is dynamic only so safe to cast.
let width_f = builder.unsafe_cast_var_to_felt(width);
let bits = builder.num2bits_f(width_f, MAX_LOG_WIDTH as u32);
Expand All @@ -427,6 +449,24 @@ fn compute_round_alpha_pows<C: Config>(
});
round_alpha_pows
}

// Get widths of all matrices in rounds.
fn get_round_widths<C: Config>(
builder: &mut Builder<C>,
rounds: &Array<C, TwoAdicPcsRoundVariable<C>>,
) -> Vec<usize> {
assert!(builder.flags.static_only);
let mut ret = Vec::new();
for i in 0..rounds.len().value() {
let round = builder.get(rounds, i);
for j in 0..round.mats.len().value() {
let mat = builder.get(&round.mats, j);
let local = builder.get(&mat.values, 0);
ret.push(local.len().value());
}
}
ret
}
pub mod tests {
use std::cmp::Reverse;

Expand Down

0 comments on commit 157f819

Please sign in to comment.