diff --git a/extensions/native/compiler/src/ir/builder.rs b/extensions/native/compiler/src/ir/builder.rs index 301a033961..703e19fbce 100644 --- a/extensions/native/compiler/src/ir/builder.rs +++ b/extensions/native/compiler/src/ir/builder.rs @@ -1,6 +1,7 @@ use std::{iter::Zip, vec::IntoIter}; use backtrace::Backtrace; +use itertools::izip; use openvm_native_compiler_derive::iter_zip; use openvm_stark_backend::p3_field::FieldAlgebra; use serde::{Deserialize, Serialize}; @@ -275,13 +276,22 @@ impl Builder { &mut self, start: impl Into>, end: impl Into>, + ) -> IteratorBuilder { + self.range_with_step(start, end, C::N::ONE) + } + /// Evaluate a block of operations over a range from start to end with a custom step. + pub fn range_with_step( + &mut self, + start: impl Into>, + end: impl Into>, + step: C::N, ) -> IteratorBuilder { let start = start.into(); let end0 = end.into(); IteratorBuilder { starts: vec![start], end0, - step_sizes: vec![1], + step_sizes: vec![step], builder: self, } } @@ -295,7 +305,7 @@ impl Builder { IteratorBuilder { starts: vec![RVar::zero(); arrays.len()], end0: arrays[0].len().into(), - step_sizes: vec![1; arrays.len()], + step_sizes: vec![C::N::ONE; arrays.len()], builder: self, } } else if arrays.iter().all(|array| !array.is_fixed()) { @@ -311,7 +321,10 @@ impl Builder { self.eval(arrays[0].ptr().address + len * RVar::from(size)); end.into() }, - step_sizes: arrays.iter().map(|array| array.element_size_of()).collect(), + step_sizes: arrays + .iter() + .map(|array| C::N::from_canonical_usize(array.element_size_of())) + .collect(), builder: self, } } else { @@ -735,7 +748,7 @@ impl IfBuilder<'_, C> { pub struct IteratorBuilder<'a, C: Config> { starts: Vec>, end0: RVar, - step_sizes: Vec, + step_sizes: Vec, builder: &'a mut Builder, } @@ -757,12 +770,20 @@ impl IteratorBuilder<'_, C> { } fn for_each_unrolled(&mut self, mut f: impl FnMut(Vec>, &mut Builder)) { - let starts: Vec = self.starts.iter().map(|start| start.value()).collect(); - let 
end0 = self.end0.value(); - - for i in (starts[0]..end0).step_by(self.step_sizes[0]) { - let ptrs = vec![i.into(); self.starts.len()]; - f(ptrs, self.builder); + let mut ptrs: Vec<_> = self + .starts + .iter() + .map(|start| start.field_value()) + .collect(); + let end0 = self.end0.field_value(); + while ptrs[0] != end0 { + f( + ptrs.iter().map(|ptr| RVar::Const(*ptr)).collect(), + self.builder, + ); + for (ptr, step_size) in izip!(&mut ptrs, &self.step_sizes) { + *ptr += *step_size; + } } } @@ -772,11 +793,6 @@ impl IteratorBuilder<'_, C> { "Cannot use dynamic loop in static mode" ); - let step_sizes = self - .step_sizes - .iter() - .map(|s| C::N::from_canonical_usize(*s)) - .collect(); let loop_variables: Vec> = (0..self.starts.len()) .map(|_| self.builder.uninit()) .collect(); @@ -791,7 +807,7 @@ impl IteratorBuilder<'_, C> { let op = DslIr::ZipFor( self.starts.clone(), self.end0, - step_sizes, + self.step_sizes.clone(), loop_variables, loop_instructions, ); diff --git a/extensions/native/compiler/src/ir/symbolic.rs b/extensions/native/compiler/src/ir/symbolic.rs index 2aeca05ef9..b92093b114 100644 --- a/extensions/native/compiler/src/ir/symbolic.rs +++ b/extensions/native/compiler/src/ir/symbolic.rs @@ -119,6 +119,12 @@ impl RVar { _ => panic!("RVar::value() called on non-const value"), } } + pub fn field_value(&self) -> N { + match self { + RVar::Const(c) => *c, + _ => panic!("RVar::field_value() called on non-const value"), + } + } } impl Hash for SymbolicVar { diff --git a/extensions/native/recursion/src/fri/two_adic_pcs.rs b/extensions/native/recursion/src/fri/two_adic_pcs.rs index 63abf59dc2..c65e0d4d7c 100644 --- a/extensions/native/recursion/src/fri/two_adic_pcs.rs +++ b/extensions/native/recursion/src/fri/two_adic_pcs.rs @@ -66,17 +66,14 @@ pub fn verify_two_adic_pcs( challenger.observe_slice(builder, final_poly_elem_felts); }); - let num_query_proofs = proof.query_proofs.len().clone(); - builder - .if_ne(num_query_proofs, 
RVar::from(config.num_queries)) - .then(|builder| { - builder.error(); - }); + // **ATTENTION**: always check shape of user inputs. + builder.assert_eq::>(proof.query_proofs.len(), RVar::from(config.num_queries)); challenger.check_witness(builder, config.proof_of_work_bits, proof.pow_witness); let log_max_height = builder.eval_expr(proof.commit_phase_commits.len() + RVar::from(log_blowup)); + let round_alpha_pows = compute_round_alpha_pows(builder, rounds.clone(), alpha); iter_zip!(builder, proof.query_proofs).for_each(|ptr_vec, builder| { let query_proof = builder.iter_ptr_get(&proof.query_proofs, ptr_vec[0]); @@ -99,159 +96,151 @@ pub fn verify_two_adic_pcs( builder.set_value(&alpha_pow, j, one_ef); } } - let mut alpha_pow_cache = Vec::new(); - - iter_zip!(builder, query_proof.input_proof, rounds).for_each(|ptr_vec, builder| { - let batch_opening = builder.iter_ptr_get(&query_proof.input_proof, ptr_vec[0]); - let round = builder.iter_ptr_get(&rounds, ptr_vec[1]); - let batch_commit = round.batch_commit; - let mats = round.mats; - let permutation = round.permutation; - let to_perm_index = |builder: &mut Builder<_>, k: RVar<_>| { - // Always no permutation in static mode - if builder.flags.static_only { - builder.eval(k) - } else { - let ret: Usize<_> = builder.uninit(); - builder.if_eq(permutation.len(), RVar::zero()).then_or_else( - |builder| { - builder.assign(&ret, k); - }, - |builder| { - let value = builder.get(&permutation, k); - builder.assign(&ret, value); - }, - ); - ret - } - }; - - let log_batch_max_height: Usize<_> = { - let log_batch_max_index = to_perm_index(builder, RVar::zero()); - let mat = builder.get(&mats, log_batch_max_index); - let domain = mat.domain; - builder.eval(domain.log_n + RVar::from(log_blowup)) - }; - - let batch_dims: Array> = builder.array(mats.len()); - // `verify_batch` requires `permed_opened_values` to be in the committed order. 
- let permed_opened_values = builder.array(batch_opening.opened_values.len()); - builder.range(0, mats.len()).for_each(|k_vec, builder| { - let k = k_vec[0]; - let mat_index = to_perm_index(builder, k); - - let mat = builder.get(&mats, mat_index.clone()); - let domain = mat.domain; - let dim = DimensionsVariable:: { - height: builder.eval(domain.size() * RVar::from(blowup)), - }; - builder.set_value(&batch_dims, k, dim); - let opened_value = builder.get(&batch_opening.opened_values, mat_index); - builder.set_value(&permed_opened_values, k, opened_value); - }); - - let permed_opened_values = NestedOpenedValues::Felt(permed_opened_values); - - let bits_reduced: Usize<_> = builder.eval(log_max_height - log_batch_max_height); - let index_bits_shifted_v1 = index_bits.shift(builder, bits_reduced); - - builder.cycle_tracker_start("verify-batch"); - verify_batch::( - builder, - &batch_commit, - batch_dims, - index_bits_shifted_v1, - &permed_opened_values, - &batch_opening.opening_proof, - ); - builder.cycle_tracker_end("verify-batch"); - - builder.cycle_tracker_start("compute-reduced-opening"); - // `verify_challenges` requires `opened_values` to be in the original order. 
- let opened_values = batch_opening.opened_values; - - iter_zip!(builder, opened_values, mats).for_each(|ptr_vec, builder| { - let mat_opening = builder.iter_ptr_get(&opened_values, ptr_vec[0]); - let mat = builder.iter_ptr_get(&mats, ptr_vec[1]); - let mat_points = mat.points; - let mat_values = mat.values; - let domain = mat.domain; - let log2_domain_size = domain.log_n; - let log_height = builder.eval_expr(log2_domain_size + RVar::from(log_blowup)); - - let cur_ro = builder.get(&ro, log_height); - let cur_alpha_pow = builder.get(&alpha_pow, log_height); - - let bits_reduced: Usize<_> = builder.eval(log_max_height - log_height); - let index_bits_shifted = index_bits.shift(builder, bits_reduced.clone()); - - let two_adic_generator = config.get_two_adic_generator(builder, log_height); - builder.cycle_tracker_start("exp-reverse-bits-len"); - - let index_bits_shifted_truncated = index_bits_shifted.slice(builder, 0, log_height); - let two_adic_generator_exp = - builder.exp_bits_big_endian(two_adic_generator, &index_bits_shifted_truncated); - builder.cycle_tracker_end("exp-reverse-bits-len"); - let x: Felt = builder.eval(two_adic_generator_exp * g); - - iter_zip!(builder, mat_points, mat_values).for_each(|ptr_vec, builder| { - let z: Ext = builder.iter_ptr_get(&mat_points, ptr_vec[0]); - let ps_at_z = builder.iter_ptr_get(&mat_values, ptr_vec[1]); - - builder.cycle_tracker_start("single-reduced-opening-eval"); - + // **ATTENTION**: always check shape of user inputs. 
+ builder.assert_eq::>(query_proof.input_proof.len(), rounds.len()); + iter_zip!(builder, query_proof.input_proof, rounds, round_alpha_pows).for_each( + |ptr_vec, builder| { + let batch_opening = builder.iter_ptr_get(&query_proof.input_proof, ptr_vec[0]); + let round = builder.iter_ptr_get(&rounds, ptr_vec[1]); + let mat_alpha_pows = builder.iter_ptr_get(&round_alpha_pows, ptr_vec[2]); + let batch_commit = round.batch_commit; + let mats = round.mats; + let permutation = round.permutation; + // `verify_challenges` requires `opened_values` to be in the original order. + let opened_values = batch_opening.opened_values; + // **ATTENTION**: always check shape of user inputs. + builder.assert_eq::>(opened_values.len(), rounds.len()); + let to_perm_index = |builder: &mut Builder<_>, k: RVar<_>| { + // Always no permutation in static mode if builder.flags.static_only { - let n: Ext = builder.constant(C::EF::ZERO); - builder.range(0, ps_at_z.len()).for_each(|t_vec, builder| { - let t = t_vec[0]; - let p_at_x = builder.get(&mat_opening, t); - let p_at_z = builder.get(&ps_at_z, t); - - if ptr_vec[0].value() == 0 { - if t.value() == 0 && alpha_pow_cache.is_empty() { - alpha_pow_cache.push(builder.constant(C::EF::ONE)); - } else if t.value() >= alpha_pow_cache.len() { - let next: Ext<_, _> = builder.uninit(); - alpha_pow_cache.push(next); - builder.assign( - &alpha_pow_cache[t.value()], - alpha_pow_cache[t.value() - 1] * alpha, - ); - } - } - builder.assign(&n, (p_at_z - p_at_x) * alpha_pow_cache[t.value()] + n); - }); - if ps_at_z.len().value() >= alpha_pow_cache.len() { - let next: Ext<_, _> = builder.uninit(); - alpha_pow_cache.push(next); - builder.assign( - &alpha_pow_cache[ps_at_z.len().value()], - alpha_pow_cache[ps_at_z.len().value() - 1] * alpha, - ); - } - builder.assign(&cur_ro, cur_ro + cur_alpha_pow * n / (z - x)); - builder.assign( - &cur_alpha_pow, - cur_alpha_pow * alpha_pow_cache[ps_at_z.len().value()], - ); + builder.eval(k) } else { - let mat_ro = 
builder.fri_single_reduced_opening_eval( - alpha, - cur_alpha_pow, - &mat_opening, - &ps_at_z, + let ret: Usize<_> = builder.uninit(); + builder.if_eq(permutation.len(), RVar::zero()).then_or_else( + |builder| { + builder.assign(&ret, k); + }, + |builder| { + let value = builder.get(&permutation, k); + builder.assign(&ret, value); + }, ); - builder.assign(&cur_ro, cur_ro + (mat_ro / (z - x))); + ret } + }; - builder.cycle_tracker_end("single-reduced-opening-eval"); + let log_batch_max_height: Usize<_> = { + let log_batch_max_index = to_perm_index(builder, RVar::zero()); + let mat = builder.get(&mats, log_batch_max_index); + let domain = mat.domain; + builder.eval(domain.log_n + RVar::from(log_blowup)) + }; + + let batch_dims: Array> = builder.array(mats.len()); + // `verify_batch` requires `permed_opened_values` to be in the committed order. + let permed_opened_values = builder.array(opened_values.len()); + builder.range(0, mats.len()).for_each(|k_vec, builder| { + let k = k_vec[0]; + let mat_index = to_perm_index(builder, k); + + let mat = builder.get(&mats, mat_index.clone()); + let domain = mat.domain; + let dim = DimensionsVariable:: { + height: builder.eval(domain.size() * RVar::from(blowup)), + }; + builder.set_value(&batch_dims, k, dim); + let opened_value = builder.get(&opened_values, mat_index); + builder.set_value(&permed_opened_values, k, opened_value); }); - builder.set_value(&ro, log_height, cur_ro); - builder.set_value(&alpha_pow, log_height, cur_alpha_pow); - }); - builder.cycle_tracker_end("compute-reduced-opening"); - }); + let permed_opened_values = NestedOpenedValues::Felt(permed_opened_values); + + let bits_reduced: Usize<_> = builder.eval(log_max_height - log_batch_max_height); + let index_bits_shifted_v1 = index_bits.shift(builder, bits_reduced); + + builder.cycle_tracker_start("verify-batch"); + verify_batch::( + builder, + &batch_commit, + batch_dims, + index_bits_shifted_v1, + &permed_opened_values, + &batch_opening.opening_proof, + ); + 
builder.cycle_tracker_end("verify-batch"); + + builder.cycle_tracker_start("compute-reduced-opening"); + + iter_zip!(builder, opened_values, mats, mat_alpha_pows).for_each( + |ptr_vec, builder| { + let mat_opening = builder.iter_ptr_get(&opened_values, ptr_vec[0]); + let mat = builder.iter_ptr_get(&mats, ptr_vec[1]); + let mat_alpha_pow = builder.iter_ptr_get(&mat_alpha_pows, ptr_vec[2]); + let mat_points = mat.points; + let mat_values = mat.values; + let domain = mat.domain; + let log2_domain_size = domain.log_n; + let log_height = + builder.eval_expr(log2_domain_size + RVar::from(log_blowup)); + + let cur_ro = builder.get(&ro, log_height); + let cur_alpha_pow = builder.get(&alpha_pow, log_height); + + let bits_reduced: Usize<_> = builder.eval(log_max_height - log_height); + let index_bits_shifted = index_bits.shift(builder, bits_reduced.clone()); + + let two_adic_generator = config.get_two_adic_generator(builder, log_height); + builder.cycle_tracker_start("exp-reverse-bits-len"); + + let index_bits_shifted_truncated = + index_bits_shifted.slice(builder, 0, log_height); + let two_adic_generator_exp = builder + .exp_bits_big_endian(two_adic_generator, &index_bits_shifted_truncated); + builder.cycle_tracker_end("exp-reverse-bits-len"); + let x: Felt = builder.eval(two_adic_generator_exp * g); + + iter_zip!(builder, mat_points, mat_values).for_each(|ptr_vec, builder| { + let z: Ext = builder.iter_ptr_get(&mat_points, ptr_vec[0]); + let ps_at_z = builder.iter_ptr_get(&mat_values, ptr_vec[1]); + + builder.cycle_tracker_start("single-reduced-opening-eval"); + + if builder.flags.static_only { + let n: Ext = builder.constant(C::EF::ZERO); + builder + .range_with_step(ps_at_z.len(), 0, C::N::NEG_ONE) + .for_each(|t_vec, builder| { + let t = t_vec[0]; + let p_at_x = builder.get(&mat_opening, t); + let p_at_z = builder.get(&ps_at_z, t); + builder.assign(&n, n * alpha + (p_at_z - p_at_x)); + }); + builder.assign(&cur_ro, cur_ro + cur_alpha_pow * n / (z - x)); + 
builder.assign(&cur_alpha_pow, cur_alpha_pow * mat_alpha_pow); + } else { + // TODO: this is just for testing the correctness. Will remove later. + let expected_alpha_pow: Ext<_, _> = + builder.eval(cur_alpha_pow * mat_alpha_pow); + let mat_ro = builder.fri_single_reduced_opening_eval( + alpha, + cur_alpha_pow, + &mat_opening, + &ps_at_z, + ); + builder.assert_ext_eq(expected_alpha_pow, cur_alpha_pow); + builder.assign(&cur_ro, cur_ro + (mat_ro / (z - x))); + } + + builder.cycle_tracker_end("single-reduced-opening-eval"); + }); + + builder.set_value(&ro, log_height, cur_ro); + builder.set_value(&alpha_pow, log_height, cur_alpha_pow); + }, + ); + builder.cycle_tracker_end("compute-reduced-opening"); + }, + ); let folded_eval = verify_query( builder, @@ -363,6 +352,59 @@ where } } +#[allow(clippy::type_complexity)] +fn compute_round_alpha_pows( + builder: &mut Builder, + rounds: Array>, + alpha: Ext, +) -> Array>> { + // Max log of matrix width + // TODO: this should be determined by VK. + const MAX_LOG_WIDTH: usize = 15; + let pow_of_alpha = builder.array(MAX_LOG_WIDTH); + let current: Ext<_, _> = builder.eval(alpha); + for i in 0..MAX_LOG_WIDTH { + builder.set_value(&pow_of_alpha, i, current); + builder.assign(&current, current * current); + } + let round_alpha_pows: Array>> = builder.array(rounds.len()); + iter_zip!(builder, rounds, round_alpha_pows).for_each(|ptr_vec, builder| { + let round = builder.iter_ptr_get(&rounds, ptr_vec[0]); + let mat_alpha_pows: Array> = builder.array(round.mats.len()); + iter_zip!(builder, round.mats, mat_alpha_pows).for_each(|ptr_vec, builder| { + let mat = builder.iter_ptr_get(&round.mats, ptr_vec[0]); + let local = builder.get(&mat.values, 0); + let width = local.len(); + let mat_alpha_pow: Ext<_, _> = if builder.flags.static_only { + let width = width.value(); + let mut ret = C::EF::ONE.cons(); + for i in 0..MAX_LOG_WIDTH { + if width & (1 << i) != 0 { + ret *= builder.get(&pow_of_alpha, i); + } + } + builder.eval(ret) + } else { + let 
width = width.get_var(); + // This is dynamic only so safe to cast. + let width_f = builder.unsafe_cast_var_to_felt(width); + let bits = builder.num2bits_f(width_f, MAX_LOG_WIDTH as u32); + let ret: Ext<_, _> = builder.eval(C::EF::ONE.cons()); + for i in 0..MAX_LOG_WIDTH { + let bit = builder.get(&bits, i); + builder.if_eq(bit, RVar::one()).then(|builder| { + let to_mul = builder.get(&pow_of_alpha, i); + builder.assign(&ret, ret * to_mul); + }); + } + ret + }; + builder.iter_ptr_set(&mat_alpha_pows, ptr_vec[1], mat_alpha_pow); + }); + builder.iter_ptr_set(&round_alpha_pows, ptr_vec[1], mat_alpha_pows); + }); + round_alpha_pows +} pub mod tests { use std::cmp::Reverse;