Skip to content

Commit

Permalink
refactor: remove usage of unstable generic_const_exprs in starky (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
eightfilms authored Oct 24, 2023
1 parent 8af189b commit 8326db6
Show file tree
Hide file tree
Showing 11 changed files with 182 additions and 177 deletions.
67 changes: 67 additions & 0 deletions starky/src/evaluation_frame.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
/// A trait for viewing an evaluation frame of a STARK table.
///
/// It allows to access the current and next rows at a given step
/// and can be used to implement constraint evaluation both natively
/// and recursively.
pub trait StarkEvaluationFrame<T: Copy + Clone + Default, U: Copy + Clone + Default>:
Sized
{
/// The number of columns for the STARK table this evaluation frame views.
const COLUMNS: usize;
const PUBLIC_INPUTS: usize;

/// Returns the local values (i.e. current row) for this evaluation frame.
fn get_local_values(&self) -> &[T];
/// Returns the next values (i.e. next row) for this evaluation frame.
fn get_next_values(&self) -> &[T];

fn get_public_inputs(&self) -> &[U];

/// Outputs a new evaluation frame from the provided local and next values.
///
/// **NOTE**: Concrete implementations of this method SHOULD ensure that
/// the provided slices lengths match the `Self::COLUMNS` value.
fn from_values(lv: &[T], nv: &[T], pis: &[U]) -> Self;
}

pub struct StarkFrame<
T: Copy + Clone + Default,
U: Copy + Clone + Default,
const N: usize,
const N2: usize,
> {
local_values: [T; N],
next_values: [T; N],
public_inputs: [U; N2],
}

impl<T: Copy + Clone + Default, U: Copy + Clone + Default, const N: usize, const N2: usize>
StarkEvaluationFrame<T, U> for StarkFrame<T, U, N, N2>
{
const COLUMNS: usize = N;
const PUBLIC_INPUTS: usize = N2;

fn get_local_values(&self) -> &[T] {
&self.local_values
}

fn get_next_values(&self) -> &[T] {
&self.next_values
}

fn get_public_inputs(&self) -> &[U] {
&self.public_inputs
}

fn from_values(lv: &[T], nv: &[T], pis: &[U]) -> Self {
assert_eq!(lv.len(), Self::COLUMNS);
assert_eq!(nv.len(), Self::COLUMNS);
assert_eq!(pis.len(), Self::PUBLIC_INPUTS);

Self {
local_values: lv.try_into().unwrap(),
next_values: nv.try_into().unwrap(),
public_inputs: pis.try_into().unwrap(),
}
}
}
59 changes: 34 additions & 25 deletions starky/src/fibonacci_stark.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,14 @@ use plonky2::field::extension::{Extendable, FieldExtension};
use plonky2::field::packed::PackedField;
use plonky2::field::polynomial::PolynomialValues;
use plonky2::hash::hash_types::RichField;
use plonky2::iop::ext_target::ExtensionTarget;
use plonky2::plonk::circuit_builder::CircuitBuilder;

use crate::constraint_consumer::{ConstraintConsumer, RecursiveConstraintConsumer};
use crate::evaluation_frame::{StarkEvaluationFrame, StarkFrame};
use crate::permutation::PermutationPair;
use crate::stark::Stark;
use crate::util::trace_rows_to_poly_values;
use crate::vars::{StarkEvaluationTargets, StarkEvaluationVars};

/// Toy STARK system used for testing.
/// Computes a Fibonacci sequence with state `[x0, x1, i, j]` using the state transition
Expand Down Expand Up @@ -57,57 +58,67 @@ impl<F: RichField + Extendable<D>, const D: usize> FibonacciStark<F, D> {
}
}

const COLUMNS: usize = 4;
const PUBLIC_INPUTS: usize = 3;

impl<F: RichField + Extendable<D>, const D: usize> Stark<F, D> for FibonacciStark<F, D> {
const COLUMNS: usize = 4;
const PUBLIC_INPUTS: usize = 3;
type EvaluationFrame<FE, P, const D2: usize> = StarkFrame<P, P::Scalar, COLUMNS, PUBLIC_INPUTS>
where
FE: FieldExtension<D2, BaseField = F>,
P: PackedField<Scalar = FE>;

type EvaluationFrameTarget =
StarkFrame<ExtensionTarget<D>, ExtensionTarget<D>, COLUMNS, PUBLIC_INPUTS>;

fn eval_packed_generic<FE, P, const D2: usize>(
&self,
vars: StarkEvaluationVars<FE, P, { Self::COLUMNS }, { Self::PUBLIC_INPUTS }>,
vars: &Self::EvaluationFrame<FE, P, D2>,
yield_constr: &mut ConstraintConsumer<P>,
) where
FE: FieldExtension<D2, BaseField = F>,
P: PackedField<Scalar = FE>,
{
let local_values = vars.get_local_values();
let next_values = vars.get_next_values();
let public_inputs = vars.get_public_inputs();

// Check public inputs.
yield_constr
.constraint_first_row(vars.local_values[0] - vars.public_inputs[Self::PI_INDEX_X0]);
yield_constr
.constraint_first_row(vars.local_values[1] - vars.public_inputs[Self::PI_INDEX_X1]);
yield_constr
.constraint_last_row(vars.local_values[1] - vars.public_inputs[Self::PI_INDEX_RES]);
yield_constr.constraint_first_row(local_values[0] - public_inputs[Self::PI_INDEX_X0]);
yield_constr.constraint_first_row(local_values[1] - public_inputs[Self::PI_INDEX_X1]);
yield_constr.constraint_last_row(local_values[1] - public_inputs[Self::PI_INDEX_RES]);

// x0' <- x1
yield_constr.constraint_transition(vars.next_values[0] - vars.local_values[1]);
yield_constr.constraint_transition(next_values[0] - local_values[1]);
// x1' <- x0 + x1
yield_constr.constraint_transition(
vars.next_values[1] - vars.local_values[0] - vars.local_values[1],
);
yield_constr.constraint_transition(next_values[1] - local_values[0] - local_values[1]);
}

fn eval_ext_circuit(
&self,
builder: &mut CircuitBuilder<F, D>,
vars: StarkEvaluationTargets<D, { Self::COLUMNS }, { Self::PUBLIC_INPUTS }>,
vars: &Self::EvaluationFrameTarget,
yield_constr: &mut RecursiveConstraintConsumer<F, D>,
) {
let local_values = vars.get_local_values();
let next_values = vars.get_next_values();
let public_inputs = vars.get_public_inputs();
// Check public inputs.
let pis_constraints = [
builder.sub_extension(vars.local_values[0], vars.public_inputs[Self::PI_INDEX_X0]),
builder.sub_extension(vars.local_values[1], vars.public_inputs[Self::PI_INDEX_X1]),
builder.sub_extension(vars.local_values[1], vars.public_inputs[Self::PI_INDEX_RES]),
builder.sub_extension(local_values[0], public_inputs[Self::PI_INDEX_X0]),
builder.sub_extension(local_values[1], public_inputs[Self::PI_INDEX_X1]),
builder.sub_extension(local_values[1], public_inputs[Self::PI_INDEX_RES]),
];
yield_constr.constraint_first_row(builder, pis_constraints[0]);
yield_constr.constraint_first_row(builder, pis_constraints[1]);
yield_constr.constraint_last_row(builder, pis_constraints[2]);

// x0' <- x1
let first_col_constraint = builder.sub_extension(vars.next_values[0], vars.local_values[1]);
let first_col_constraint = builder.sub_extension(next_values[0], local_values[1]);
yield_constr.constraint_transition(builder, first_col_constraint);
// x1' <- x0 + x1
let second_col_constraint = {
let tmp = builder.sub_extension(vars.next_values[1], vars.local_values[0]);
builder.sub_extension(tmp, vars.local_values[1])
let tmp = builder.sub_extension(next_values[1], local_values[0]);
builder.sub_extension(tmp, local_values[1])
};
yield_constr.constraint_transition(builder, second_col_constraint);
}
Expand Down Expand Up @@ -165,7 +176,7 @@ mod tests {
stark,
&config,
trace,
public_inputs,
&public_inputs,
&mut TimingTree::default(),
)?;

Expand Down Expand Up @@ -213,7 +224,7 @@ mod tests {
stark,
&config,
trace,
public_inputs,
&public_inputs,
&mut TimingTree::default(),
)?;
verify_stark_proof(stark, proof.clone(), &config)?;
Expand All @@ -235,8 +246,6 @@ mod tests {
) -> Result<()>
where
InnerC::Hasher: AlgebraicHasher<F>,
[(); S::COLUMNS]:,
[(); S::PUBLIC_INPUTS]:,
{
let circuit_config = CircuitConfig::standard_recursion_config();
let mut builder = CircuitBuilder::<F, D>::new(circuit_config);
Expand Down
4 changes: 1 addition & 3 deletions starky/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
#![allow(incomplete_features)]
#![allow(clippy::too_many_arguments)]
#![allow(clippy::type_complexity)]
#![feature(generic_const_exprs)]
#![cfg_attr(not(feature = "std"), no_std)]

extern crate alloc;
Expand All @@ -10,6 +8,7 @@ mod get_challenges;

pub mod config;
pub mod constraint_consumer;
pub mod evaluation_frame;
pub mod permutation;
pub mod proof;
pub mod prover;
Expand All @@ -18,7 +17,6 @@ pub mod stark;
pub mod stark_testing;
pub mod util;
pub mod vanishing_poly;
pub mod vars;
pub mod verifier;

#[cfg(test)]
Expand Down
18 changes: 9 additions & 9 deletions starky/src/permutation.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,8 @@ use plonky2_maybe_rayon::*;

use crate::config::StarkConfig;
use crate::constraint_consumer::{ConstraintConsumer, RecursiveConstraintConsumer};
use crate::evaluation_frame::StarkEvaluationFrame;
use crate::stark::Stark;
use crate::vars::{StarkEvaluationTargets, StarkEvaluationVars};

/// A pair of lists of columns, `lhs` and `rhs`, that should be permutations of one another.
/// In particular, there should exist some permutation `pi` such that for any `i`,
Expand Down Expand Up @@ -262,17 +262,17 @@ where
pub(crate) fn eval_permutation_checks<F, FE, P, S, const D: usize, const D2: usize>(
stark: &S,
config: &StarkConfig,
vars: StarkEvaluationVars<FE, P, { S::COLUMNS }, { S::PUBLIC_INPUTS }>,
vars: &S::EvaluationFrame<FE, P, D2>,
permutation_data: PermutationCheckVars<F, FE, P, D2>,
consumer: &mut ConstraintConsumer<P>,
) where
F: RichField + Extendable<D>,
FE: FieldExtension<D2, BaseField = F>,
P: PackedField<Scalar = FE>,
S: Stark<F, D>,
[(); S::COLUMNS]:,
[(); S::PUBLIC_INPUTS]:,
{
let local_values = vars.get_local_values();

let PermutationCheckVars {
local_zs,
next_zs,
Expand Down Expand Up @@ -306,7 +306,7 @@ pub(crate) fn eval_permutation_checks<F, FE, P, S, const D: usize, const D2: usi
let mut factor = ReducingFactor::new(*beta);
let (lhs, rhs): (Vec<_>, Vec<_>) = column_pairs
.iter()
.map(|&(i, j)| (vars.local_values[i], vars.local_values[j]))
.map(|&(i, j)| (local_values[i], local_values[j]))
.unzip();
(
factor.reduce_ext(lhs.into_iter()) + FE::from_basefield(*gamma),
Expand All @@ -330,15 +330,15 @@ pub(crate) fn eval_permutation_checks_circuit<F, S, const D: usize>(
builder: &mut CircuitBuilder<F, D>,
stark: &S,
config: &StarkConfig,
vars: StarkEvaluationTargets<D, { S::COLUMNS }, { S::PUBLIC_INPUTS }>,
vars: &S::EvaluationFrameTarget,
permutation_data: PermutationCheckDataTarget<D>,
consumer: &mut RecursiveConstraintConsumer<F, D>,
) where
F: RichField + Extendable<D>,
S: Stark<F, D>,
[(); S::COLUMNS]:,
[(); S::PUBLIC_INPUTS]:,
{
let local_values = vars.get_local_values();

let PermutationCheckDataTarget {
local_zs,
next_zs,
Expand Down Expand Up @@ -376,7 +376,7 @@ pub(crate) fn eval_permutation_checks_circuit<F, S, const D: usize>(
let mut factor = ReducingFactorTarget::new(beta_ext);
let (lhs, rhs): (Vec<_>, Vec<_>) = column_pairs
.iter()
.map(|&(i, j)| (vars.local_values[i], vars.local_values[j]))
.map(|&(i, j)| (local_values[i], local_values[j]))
.unzip();
let reduced_lhs = factor.reduce(&lhs, builder);
let reduced_rhs = factor.reduce(&rhs, builder);
Expand Down
30 changes: 11 additions & 19 deletions starky/src/prover.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,28 +20,26 @@ use plonky2_maybe_rayon::*;

use crate::config::StarkConfig;
use crate::constraint_consumer::ConstraintConsumer;
use crate::evaluation_frame::StarkEvaluationFrame;
use crate::permutation::{
compute_permutation_z_polys, get_n_permutation_challenge_sets, PermutationChallengeSet,
PermutationCheckVars,
};
use crate::proof::{StarkOpeningSet, StarkProof, StarkProofWithPublicInputs};
use crate::stark::Stark;
use crate::vanishing_poly::eval_vanishing_poly;
use crate::vars::StarkEvaluationVars;

pub fn prove<F, C, S, const D: usize>(
stark: S,
config: &StarkConfig,
trace_poly_values: Vec<PolynomialValues<F>>,
public_inputs: [F; S::PUBLIC_INPUTS],
public_inputs: &[F],
timing: &mut TimingTree,
) -> Result<StarkProofWithPublicInputs<F, C, D>>
where
F: RichField + Extendable<D>,
C: GenericConfig<D, F = F>,
S: Stark<F, D>,
[(); S::COLUMNS]:,
[(); S::PUBLIC_INPUTS]:,
{
let degree = trace_poly_values[0].len();
let degree_bits = log2_strict(degree);
Expand Down Expand Up @@ -202,7 +200,7 @@ fn compute_quotient_polys<'a, F, P, C, S, const D: usize>(
PolynomialBatch<F, C, D>,
Vec<PermutationChallengeSet<F>>,
)>,
public_inputs: [F; S::PUBLIC_INPUTS],
public_inputs: &[F],
alphas: Vec<F>,
degree_bits: usize,
config: &StarkConfig,
Expand All @@ -212,8 +210,6 @@ where
P: PackedField<Scalar = F>,
C: GenericConfig<D, F = F>,
S: Stark<F, D>,
[(); S::COLUMNS]:,
[(); S::PUBLIC_INPUTS]:,
{
let degree = 1 << degree_bits;
let rate_bits = config.fri_config.rate_bits;
Expand All @@ -236,12 +232,8 @@ where
let z_h_on_coset = ZeroPolyOnCoset::<F>::new(degree_bits, quotient_degree_bits);

// Retrieve the LDE values at index `i`.
let get_trace_values_packed = |i_start| -> [P; S::COLUMNS] {
trace_commitment
.get_lde_values_packed(i_start, step)
.try_into()
.unwrap()
};
let get_trace_values_packed =
|i_start| -> Vec<P> { trace_commitment.get_lde_values_packed(i_start, step) };

// Last element of the subgroup.
let last = F::primitive_root_of_unity(degree_bits).inverse();
Expand Down Expand Up @@ -272,11 +264,11 @@ where
lagrange_basis_first,
lagrange_basis_last,
);
let vars = StarkEvaluationVars {
local_values: &get_trace_values_packed(i_start),
next_values: &get_trace_values_packed(i_next_start),
public_inputs: &public_inputs,
};
let vars = S::EvaluationFrame::from_values(
&get_trace_values_packed(i_start),
&get_trace_values_packed(i_next_start),
public_inputs,
);
let permutation_check_data = permutation_zs_commitment_challenges.as_ref().map(
|(permutation_zs_commitment, permutation_challenge_sets)| PermutationCheckVars {
local_zs: permutation_zs_commitment.get_lde_values_packed(i_start, step),
Expand All @@ -287,7 +279,7 @@ where
eval_vanishing_poly::<F, F, P, S, D, 1>(
stark,
config,
vars,
&vars,
permutation_check_data,
&mut consumer,
);
Expand Down
Loading

0 comments on commit 8326db6

Please sign in to comment.