Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor: remove usage of unstable generic_const_exprs in starky #1300

Merged
merged 3 commits into from
Oct 24, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
67 changes: 67 additions & 0 deletions starky/src/evaluation_frame.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
/// A trait for viewing an evaluation frame of a STARK table.
///
/// It allows to access the current and next rows at a given step
/// and can be used to implement constraint evaluation both natively
/// and recursively.
pub trait StarkEvaluationFrame<T: Copy + Clone + Default, U: Copy + Clone + Default>:
Sized
{
/// The number of columns for the STARK table this evaluation frame views.
const COLUMNS: usize;
const PUBLIC_INPUTS: usize;

/// Returns the local values (i.e. current row) for this evaluation frame.
fn get_local_values(&self) -> &[T];
/// Returns the next values (i.e. next row) for this evaluation frame.
fn get_next_values(&self) -> &[T];

fn get_public_inputs(&self) -> &[U];

/// Outputs a new evaluation frame from the provided local and next values.
///
/// **NOTE**: Concrete implementations of this method SHOULD ensure that
/// the provided slices lengths match the `Self::COLUMNS` value.
fn from_values(lv: &[T], nv: &[T], pis: &[U]) -> Self;
}

pub struct StarkFrame<
T: Copy + Clone + Default,
U: Copy + Clone + Default,
const N: usize,
const N2: usize,
> {
local_values: [T; N],
next_values: [T; N],
public_inputs: [U; N2],
}

impl<T: Copy + Clone + Default, U: Copy + Clone + Default, const N: usize, const N2: usize>
StarkEvaluationFrame<T, U> for StarkFrame<T, U, N, N2>
{
const COLUMNS: usize = N;
const PUBLIC_INPUTS: usize = N2;

fn get_local_values(&self) -> &[T] {
&self.local_values
}

fn get_next_values(&self) -> &[T] {
&self.next_values
}

fn get_public_inputs(&self) -> &[U] {
&self.public_inputs
}

fn from_values(lv: &[T], nv: &[T], pis: &[U]) -> Self {
assert_eq!(lv.len(), Self::COLUMNS);
assert_eq!(nv.len(), Self::COLUMNS);
assert_eq!(pis.len(), Self::PUBLIC_INPUTS);

Self {
local_values: lv.try_into().unwrap(),
next_values: nv.try_into().unwrap(),
public_inputs: pis.try_into().unwrap(),
}
}
}
59 changes: 34 additions & 25 deletions starky/src/fibonacci_stark.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,14 @@ use plonky2::field::extension::{Extendable, FieldExtension};
use plonky2::field::packed::PackedField;
use plonky2::field::polynomial::PolynomialValues;
use plonky2::hash::hash_types::RichField;
use plonky2::iop::ext_target::ExtensionTarget;
use plonky2::plonk::circuit_builder::CircuitBuilder;

use crate::constraint_consumer::{ConstraintConsumer, RecursiveConstraintConsumer};
use crate::evaluation_frame::{StarkEvaluationFrame, StarkFrame};
use crate::permutation::PermutationPair;
use crate::stark::Stark;
use crate::util::trace_rows_to_poly_values;
use crate::vars::{StarkEvaluationTargets, StarkEvaluationVars};

/// Toy STARK system used for testing.
/// Computes a Fibonacci sequence with state `[x0, x1, i, j]` using the state transition
Expand Down Expand Up @@ -57,57 +58,67 @@ impl<F: RichField + Extendable<D>, const D: usize> FibonacciStark<F, D> {
}
}

const COLUMNS: usize = 4;
const PUBLIC_INPUTS: usize = 3;

impl<F: RichField + Extendable<D>, const D: usize> Stark<F, D> for FibonacciStark<F, D> {
const COLUMNS: usize = 4;
const PUBLIC_INPUTS: usize = 3;
type EvaluationFrame<FE, P, const D2: usize> = StarkFrame<P, P::Scalar, COLUMNS, PUBLIC_INPUTS>
where
FE: FieldExtension<D2, BaseField = F>,
P: PackedField<Scalar = FE>;

type EvaluationFrameTarget =
StarkFrame<ExtensionTarget<D>, ExtensionTarget<D>, COLUMNS, PUBLIC_INPUTS>;

fn eval_packed_generic<FE, P, const D2: usize>(
&self,
vars: StarkEvaluationVars<FE, P, { Self::COLUMNS }, { Self::PUBLIC_INPUTS }>,
vars: &Self::EvaluationFrame<FE, P, D2>,
yield_constr: &mut ConstraintConsumer<P>,
) where
FE: FieldExtension<D2, BaseField = F>,
P: PackedField<Scalar = FE>,
{
let local_values = vars.get_local_values();
let next_values = vars.get_next_values();
let public_inputs = vars.get_public_inputs();

// Check public inputs.
yield_constr
.constraint_first_row(vars.local_values[0] - vars.public_inputs[Self::PI_INDEX_X0]);
yield_constr
.constraint_first_row(vars.local_values[1] - vars.public_inputs[Self::PI_INDEX_X1]);
yield_constr
.constraint_last_row(vars.local_values[1] - vars.public_inputs[Self::PI_INDEX_RES]);
yield_constr.constraint_first_row(local_values[0] - public_inputs[Self::PI_INDEX_X0]);
yield_constr.constraint_first_row(local_values[1] - public_inputs[Self::PI_INDEX_X1]);
yield_constr.constraint_last_row(local_values[1] - public_inputs[Self::PI_INDEX_RES]);

// x0' <- x1
yield_constr.constraint_transition(vars.next_values[0] - vars.local_values[1]);
yield_constr.constraint_transition(next_values[0] - local_values[1]);
// x1' <- x0 + x1
yield_constr.constraint_transition(
vars.next_values[1] - vars.local_values[0] - vars.local_values[1],
);
yield_constr.constraint_transition(next_values[1] - local_values[0] - local_values[1]);
}

fn eval_ext_circuit(
&self,
builder: &mut CircuitBuilder<F, D>,
vars: StarkEvaluationTargets<D, { Self::COLUMNS }, { Self::PUBLIC_INPUTS }>,
vars: &Self::EvaluationFrameTarget,
yield_constr: &mut RecursiveConstraintConsumer<F, D>,
) {
let local_values = vars.get_local_values();
let next_values = vars.get_next_values();
let public_inputs = vars.get_public_inputs();
// Check public inputs.
let pis_constraints = [
builder.sub_extension(vars.local_values[0], vars.public_inputs[Self::PI_INDEX_X0]),
builder.sub_extension(vars.local_values[1], vars.public_inputs[Self::PI_INDEX_X1]),
builder.sub_extension(vars.local_values[1], vars.public_inputs[Self::PI_INDEX_RES]),
builder.sub_extension(local_values[0], public_inputs[Self::PI_INDEX_X0]),
builder.sub_extension(local_values[1], public_inputs[Self::PI_INDEX_X1]),
builder.sub_extension(local_values[1], public_inputs[Self::PI_INDEX_RES]),
];
yield_constr.constraint_first_row(builder, pis_constraints[0]);
yield_constr.constraint_first_row(builder, pis_constraints[1]);
yield_constr.constraint_last_row(builder, pis_constraints[2]);

// x0' <- x1
let first_col_constraint = builder.sub_extension(vars.next_values[0], vars.local_values[1]);
let first_col_constraint = builder.sub_extension(next_values[0], local_values[1]);
yield_constr.constraint_transition(builder, first_col_constraint);
// x1' <- x0 + x1
let second_col_constraint = {
let tmp = builder.sub_extension(vars.next_values[1], vars.local_values[0]);
builder.sub_extension(tmp, vars.local_values[1])
let tmp = builder.sub_extension(next_values[1], local_values[0]);
builder.sub_extension(tmp, local_values[1])
};
yield_constr.constraint_transition(builder, second_col_constraint);
}
Expand Down Expand Up @@ -165,7 +176,7 @@ mod tests {
stark,
&config,
trace,
public_inputs,
&public_inputs,
&mut TimingTree::default(),
)?;

Expand Down Expand Up @@ -213,7 +224,7 @@ mod tests {
stark,
&config,
trace,
public_inputs,
&public_inputs,
&mut TimingTree::default(),
)?;
verify_stark_proof(stark, proof.clone(), &config)?;
Expand All @@ -235,8 +246,6 @@ mod tests {
) -> Result<()>
where
InnerC::Hasher: AlgebraicHasher<F>,
[(); S::COLUMNS]:,
[(); S::PUBLIC_INPUTS]:,
{
let circuit_config = CircuitConfig::standard_recursion_config();
let mut builder = CircuitBuilder::<F, D>::new(circuit_config);
Expand Down
4 changes: 1 addition & 3 deletions starky/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
#![allow(incomplete_features)]
#![allow(clippy::too_many_arguments)]
#![allow(clippy::type_complexity)]
#![feature(generic_const_exprs)]
#![cfg_attr(not(feature = "std"), no_std)]

extern crate alloc;
Expand All @@ -10,6 +8,7 @@ mod get_challenges;

pub mod config;
pub mod constraint_consumer;
pub mod evaluation_frame;
pub mod permutation;
pub mod proof;
pub mod prover;
Expand All @@ -18,7 +17,6 @@ pub mod stark;
pub mod stark_testing;
pub mod util;
pub mod vanishing_poly;
pub mod vars;
pub mod verifier;

#[cfg(test)]
Expand Down
18 changes: 9 additions & 9 deletions starky/src/permutation.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,8 @@ use plonky2_maybe_rayon::*;

use crate::config::StarkConfig;
use crate::constraint_consumer::{ConstraintConsumer, RecursiveConstraintConsumer};
use crate::evaluation_frame::StarkEvaluationFrame;
use crate::stark::Stark;
use crate::vars::{StarkEvaluationTargets, StarkEvaluationVars};

/// A pair of lists of columns, `lhs` and `rhs`, that should be permutations of one another.
/// In particular, there should exist some permutation `pi` such that for any `i`,
Expand Down Expand Up @@ -262,17 +262,17 @@ where
pub(crate) fn eval_permutation_checks<F, FE, P, S, const D: usize, const D2: usize>(
stark: &S,
config: &StarkConfig,
vars: StarkEvaluationVars<FE, P, { S::COLUMNS }, { S::PUBLIC_INPUTS }>,
vars: &S::EvaluationFrame<FE, P, D2>,
permutation_data: PermutationCheckVars<F, FE, P, D2>,
consumer: &mut ConstraintConsumer<P>,
) where
F: RichField + Extendable<D>,
FE: FieldExtension<D2, BaseField = F>,
P: PackedField<Scalar = FE>,
S: Stark<F, D>,
[(); S::COLUMNS]:,
[(); S::PUBLIC_INPUTS]:,
{
let local_values = vars.get_local_values();

let PermutationCheckVars {
local_zs,
next_zs,
Expand Down Expand Up @@ -306,7 +306,7 @@ pub(crate) fn eval_permutation_checks<F, FE, P, S, const D: usize, const D2: usi
let mut factor = ReducingFactor::new(*beta);
let (lhs, rhs): (Vec<_>, Vec<_>) = column_pairs
.iter()
.map(|&(i, j)| (vars.local_values[i], vars.local_values[j]))
.map(|&(i, j)| (local_values[i], local_values[j]))
.unzip();
(
factor.reduce_ext(lhs.into_iter()) + FE::from_basefield(*gamma),
Expand All @@ -330,15 +330,15 @@ pub(crate) fn eval_permutation_checks_circuit<F, S, const D: usize>(
builder: &mut CircuitBuilder<F, D>,
stark: &S,
config: &StarkConfig,
vars: StarkEvaluationTargets<D, { S::COLUMNS }, { S::PUBLIC_INPUTS }>,
vars: &S::EvaluationFrameTarget,
permutation_data: PermutationCheckDataTarget<D>,
consumer: &mut RecursiveConstraintConsumer<F, D>,
) where
F: RichField + Extendable<D>,
S: Stark<F, D>,
[(); S::COLUMNS]:,
[(); S::PUBLIC_INPUTS]:,
{
let local_values = vars.get_local_values();

let PermutationCheckDataTarget {
local_zs,
next_zs,
Expand Down Expand Up @@ -376,7 +376,7 @@ pub(crate) fn eval_permutation_checks_circuit<F, S, const D: usize>(
let mut factor = ReducingFactorTarget::new(beta_ext);
let (lhs, rhs): (Vec<_>, Vec<_>) = column_pairs
.iter()
.map(|&(i, j)| (vars.local_values[i], vars.local_values[j]))
.map(|&(i, j)| (local_values[i], local_values[j]))
.unzip();
let reduced_lhs = factor.reduce(&lhs, builder);
let reduced_rhs = factor.reduce(&rhs, builder);
Expand Down
30 changes: 11 additions & 19 deletions starky/src/prover.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,28 +20,26 @@ use plonky2_maybe_rayon::*;

use crate::config::StarkConfig;
use crate::constraint_consumer::ConstraintConsumer;
use crate::evaluation_frame::StarkEvaluationFrame;
use crate::permutation::{
compute_permutation_z_polys, get_n_permutation_challenge_sets, PermutationChallengeSet,
PermutationCheckVars,
};
use crate::proof::{StarkOpeningSet, StarkProof, StarkProofWithPublicInputs};
use crate::stark::Stark;
use crate::vanishing_poly::eval_vanishing_poly;
use crate::vars::StarkEvaluationVars;

pub fn prove<F, C, S, const D: usize>(
stark: S,
config: &StarkConfig,
trace_poly_values: Vec<PolynomialValues<F>>,
public_inputs: [F; S::PUBLIC_INPUTS],
public_inputs: &[F],
timing: &mut TimingTree,
) -> Result<StarkProofWithPublicInputs<F, C, D>>
where
F: RichField + Extendable<D>,
C: GenericConfig<D, F = F>,
S: Stark<F, D>,
[(); S::COLUMNS]:,
[(); S::PUBLIC_INPUTS]:,
{
let degree = trace_poly_values[0].len();
let degree_bits = log2_strict(degree);
Expand Down Expand Up @@ -202,7 +200,7 @@ fn compute_quotient_polys<'a, F, P, C, S, const D: usize>(
PolynomialBatch<F, C, D>,
Vec<PermutationChallengeSet<F>>,
)>,
public_inputs: [F; S::PUBLIC_INPUTS],
public_inputs: &[F],
alphas: Vec<F>,
degree_bits: usize,
config: &StarkConfig,
Expand All @@ -212,8 +210,6 @@ where
P: PackedField<Scalar = F>,
C: GenericConfig<D, F = F>,
S: Stark<F, D>,
[(); S::COLUMNS]:,
[(); S::PUBLIC_INPUTS]:,
{
let degree = 1 << degree_bits;
let rate_bits = config.fri_config.rate_bits;
Expand All @@ -236,12 +232,8 @@ where
let z_h_on_coset = ZeroPolyOnCoset::<F>::new(degree_bits, quotient_degree_bits);

// Retrieve the LDE values at index `i`.
let get_trace_values_packed = |i_start| -> [P; S::COLUMNS] {
trace_commitment
.get_lde_values_packed(i_start, step)
.try_into()
.unwrap()
};
let get_trace_values_packed =
|i_start| -> Vec<P> { trace_commitment.get_lde_values_packed(i_start, step) };

// Last element of the subgroup.
let last = F::primitive_root_of_unity(degree_bits).inverse();
Expand Down Expand Up @@ -272,11 +264,11 @@ where
lagrange_basis_first,
lagrange_basis_last,
);
let vars = StarkEvaluationVars {
local_values: &get_trace_values_packed(i_start),
next_values: &get_trace_values_packed(i_next_start),
public_inputs: &public_inputs,
};
let vars = S::EvaluationFrame::from_values(
&get_trace_values_packed(i_start),
&get_trace_values_packed(i_next_start),
public_inputs,
);
let permutation_check_data = permutation_zs_commitment_challenges.as_ref().map(
|(permutation_zs_commitment, permutation_challenge_sets)| PermutationCheckVars {
local_zs: permutation_zs_commitment.get_lde_values_packed(i_start, step),
Expand All @@ -287,7 +279,7 @@ where
eval_vanishing_poly::<F, F, P, S, D, 1>(
stark,
config,
vars,
&vars,
permutation_check_data,
&mut consumer,
);
Expand Down
Loading
Loading