From 6258838933dd2fd0a51934185c903058df123d7b Mon Sep 17 00:00:00 2001 From: Xinding Wei Date: Wed, 22 Jan 2025 22:49:10 +0800 Subject: [PATCH] Refactor FriReducedOpeningChip --- Cargo.lock | 1 + crates/vm/src/arch/testing/mod.rs | 12 + crates/vm/src/utils/stark_utils.rs | 8 +- crates/vm/tests/integration_test.rs | 2 +- extensions/native/circuit/Cargo.toml | 1 + extensions/native/circuit/src/fri/mod.rs | 898 ++++++++++-------- extensions/native/circuit/src/fri/tests.rs | 42 +- extensions/native/circuit/src/utils.rs | 3 +- .../native/compiler/src/asm/compiler.rs | 3 +- .../native/compiler/src/asm/instruction.rs | 10 +- .../native/compiler/src/conversion/mod.rs | 4 +- extensions/native/compiler/src/ir/fri.rs | 2 - .../native/compiler/src/ir/instructions.rs | 3 +- .../native/compiler/tests/fri_ro_eval.rs | 9 +- .../native/recursion/src/fri/two_adic_pcs.rs | 9 +- 15 files changed, 584 insertions(+), 423 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index c135ac1d5d..127c328d9c 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3952,6 +3952,7 @@ dependencies = [ "rand", "serde", "serde-big-array", + "static_assertions", "strum", "test-case", "test-log", diff --git a/crates/vm/src/arch/testing/mod.rs b/crates/vm/src/arch/testing/mod.rs index 0636147c57..4d6a9f9190 100644 --- a/crates/vm/src/arch/testing/mod.rs +++ b/crates/vm/src/arch/testing/mod.rs @@ -256,6 +256,18 @@ impl VmChipTestBuilder { let tester = tester.load(self.execution); tester.load(self.program) } + pub fn build_babybear_poseidon2(self) -> VmChipTester { + self.memory + .controller + .borrow_mut() + .finalize(None::<&mut Poseidon2PeripheryChip>); + let tester = VmChipTester { + memory: Some(self.memory), + ..Default::default() + }; + let tester = tester.load(self.execution); + tester.load(self.program) + } } impl Default for VmChipTestBuilder { diff --git a/crates/vm/src/utils/stark_utils.rs b/crates/vm/src/utils/stark_utils.rs index fbbdc35297..2a0bf1ab98 100644 --- a/crates/vm/src/utils/stark_utils.rs +++ b/crates/vm/src/utils/stark_utils.rs @@ -41,7 +41,13 @@ where VC::Periphery: Chip, { setup_tracing(); - let engine = BabyBearPoseidon2Engine::new(FriParameters::standard_fast()); + let mut log_blowup = 1; + while config.system().max_constraint_degree > (1 << log_blowup) + 1 { + log_blowup += 1; + } + let engine = BabyBearPoseidon2Engine::new( + FriParameters::standard_with_100_bits_conjectured_security(log_blowup), + ); let vm = VirtualMachine::new(engine, config); let pk = vm.keygen(); let mut result = vm.execute_and_generate(exe, input).unwrap(); diff --git a/crates/vm/tests/integration_test.rs b/crates/vm/tests/integration_test.rs index 0468806ba2..c954702d4b 100644 --- a/crates/vm/tests/integration_test.rs +++ b/crates/vm/tests/integration_test.rs @@ -333,7 +333,7 @@ fn test_vm_initial_memory() { fn test_vm_1_persistent() { let engine = BabyBearPoseidon2Engine::new(FriParameters::standard_fast()); let config = NativeConfig { - system: SystemConfig::new(3, MemoryConfig::new(1, 1, 16, 10, 6, 64, 1024), 0), + system: SystemConfig::new(5, MemoryConfig::new(1, 1, 16, 10, 6, 64, 1024), 0), native: Default::default(), } .with_continuations(); diff --git a/extensions/native/circuit/Cargo.toml b/extensions/native/circuit/Cargo.toml index 1bb5a04676..8dc8fcd0c8 100644 --- a/extensions/native/circuit/Cargo.toml +++ b/extensions/native/circuit/Cargo.toml @@ -30,6 +30,7 @@ eyre.workspace = true serde.workspace = true serde-big-array.workspace = true bitcode.workspace = true +static_assertions.workspace = true [dev-dependencies] openvm-stark-sdk = { workspace = true } diff --git a/extensions/native/circuit/src/fri/mod.rs b/extensions/native/circuit/src/fri/mod.rs index 7b757d06fd..fb3792996e 100644 --- a/extensions/native/circuit/src/fri/mod.rs +++ b/extensions/native/circuit/src/fri/mod.rs @@ -1,26 +1,23 @@ +use core::ops::Deref; use std::{ - array::{self, from_fn}, + array, borrow::{Borrow, BorrowMut}, + mem::offset_of, sync::{Arc, Mutex}, }; +use itertools::{zip_eq, Itertools}; use openvm_circuit::{ arch::{ExecutionBridge, ExecutionBus, ExecutionError, ExecutionState, InstructionExecutor}, system::{ memory::{ - offline_checker::{ - MemoryBaseAuxCols, MemoryBridge, MemoryReadAuxCols, MemoryWriteAuxCols, - }, + offline_checker::{MemoryBridge, MemoryReadAuxCols, MemoryWriteAuxCols}, MemoryAddress, MemoryAuxColsFactory, MemoryController, OfflineMemory, RecordId, }, program::ProgramBus, }, }; -use openvm_circuit_primitives::{ - is_zero::{IsZeroIo, IsZeroSubAir}, - utils::{assert_array_eq, next_power_of_two_or_zero, not}, - SubAir, TraceSubRowGenerator, -}; +use openvm_circuit_primitives::utils::next_power_of_two_or_zero; use openvm_circuit_primitives_derive::AlignedBorrow; use openvm_instructions::{instruction::Instruction, program::DEFAULT_PC_STEP, LocalOpcode}; use openvm_native_compiler::FriOpcode::FRI_REDUCED_OPENING; @@ -36,264 +33,399 @@ use openvm_stark_backend::{ Chip, ChipUsageGetter, Stateful, }; use serde::{Deserialize, Serialize}; +use static_assertions::const_assert_eq; -use super::field_extension::{FieldExtension, EXT_DEG}; +use crate::field_extension::{FieldExtension, EXT_DEG}; #[cfg(test)] mod tests; #[repr(C)] -#[derive(AlignedBorrow)] -pub struct FriReducedOpeningCols { - pub enabled: T, - - pub pc: T, - pub start_timestamp: T, - - pub a_ptr_ptr: T, - pub b_ptr_ptr: T, - pub result_ptr: T, - pub addr_space: T, - pub length_ptr: T, - pub alpha_ptr: T, - pub alpha_pow_ptr: T, - - pub a_ptr_aux: MemoryReadAuxCols, - pub b_ptr_aux: MemoryReadAuxCols, - pub a_aux: MemoryReadAuxCols, - pub b_aux: MemoryReadAuxCols, - pub result_aux: MemoryWriteAuxCols, - pub length_aux: MemoryReadAuxCols, - pub alpha_aux: MemoryReadAuxCols, - pub alpha_pow_aux: MemoryBaseAuxCols, - - pub a_ptr: T, - pub b_ptr: T, - pub a: T, - pub b: [T; EXT_DEG], - pub alpha: [T; EXT_DEG], - pub alpha_pow_original: [T; EXT_DEG], - pub alpha_pow_current: [T; EXT_DEG], - - pub idx: T, - pub idx_is_zero: T, - pub is_zero_aux: T, - pub current: [T; EXT_DEG], +#[derive(Debug, AlignedBorrow)] +struct WorkloadCols { + prefix: PrefixCols, + + a_aux: MemoryReadAuxCols, + b: [T; EXT_DEG], + b_aux: MemoryReadAuxCols, +} +const WL_WIDTH: usize = WorkloadCols::::width(); +const_assert_eq!(WL_WIDTH, 26); + +#[repr(C)] +#[derive(Debug, AlignedBorrow)] +struct Instruction1Cols { + prefix: PrefixCols, + + pc: T, + + a_ptr_ptr: T, + a_ptr_aux: MemoryReadAuxCols, + + b_ptr_ptr: T, + b_ptr_aux: MemoryReadAuxCols, +} +const INS_1_WIDTH: usize = Instruction1Cols::::width(); +const_assert_eq!(INS_1_WIDTH, 25); +const_assert_eq!( + offset_of!(WorkloadCols, prefix), + offset_of!(Instruction1Cols, prefix) +); + +#[repr(C)] +#[derive(Debug, AlignedBorrow)] +struct Instruction2Cols { + general: GeneralCols, + // is_first = 0 means the second instruction row. + is_first: T, + + result_ptr: T, + result_aux: MemoryWriteAuxCols, + + length_ptr: T, + length_aux: MemoryReadAuxCols, + + alpha_ptr: T, + alpha_aux: MemoryReadAuxCols, +} +const INS_2_WIDTH: usize = Instruction2Cols::::width(); +const_assert_eq!(INS_2_WIDTH, 20); +const_assert_eq!( + offset_of!(WorkloadCols, prefix) + offset_of!(PrefixCols, general), + offset_of!(Instruction2Cols, general) +); +const_assert_eq!( + offset_of!(Instruction1Cols, prefix) + offset_of!(PrefixCols, a_or_is_first), + offset_of!(Instruction2Cols, is_first) +); + +const fn const_max(a: usize, b: usize) -> usize { + [a, b][(a < b) as usize] +} +pub const OVERALL_WIDTH: usize = const_max(const_max(WL_WIDTH, INS_1_WIDTH), INS_2_WIDTH); +const_assert_eq!(OVERALL_WIDTH, 26); + +#[repr(C)] +#[derive(Debug, AlignedBorrow)] +struct GeneralCols { + enabled: T, + is_ins_row: T, + timestamp: T, +} +const GENERAL_WIDTH: usize = GeneralCols::::width(); +const_assert_eq!(GENERAL_WIDTH, 3); + +#[repr(C)] +#[derive(Debug, AlignedBorrow)] +struct DataCols { + addr_space: T, + a_ptr: T, + b_ptr: T, + idx: T, + result: [T; EXT_DEG], + alpha: [T; EXT_DEG], } +#[allow(dead_code)] +const DATA_WIDTH: usize = DataCols::::width(); +const_assert_eq!(DATA_WIDTH, 12); + +/// Prefix of `WorkloadCols` and `Instruction1Cols` +#[repr(C)] +#[derive(Debug, AlignedBorrow)] +struct PrefixCols { + general: GeneralCols, + /// WorkloadCols uses this column as `a`. Instruction1Cols uses this column as `is_first` which + /// indicates whether this is the first row of an instruction row. This is to save a column. + a_or_is_first: T, + data: DataCols, +} +const PREFIX_WIDTH: usize = PrefixCols::::width(); +const_assert_eq!(PREFIX_WIDTH, 16); #[derive(Copy, Clone, Debug)] -pub struct FriReducedOpeningAir { - pub execution_bridge: ExecutionBridge, - pub memory_bridge: MemoryBridge, +struct FriReducedOpeningAir { + execution_bridge: ExecutionBridge, + memory_bridge: MemoryBridge, } impl BaseAir for FriReducedOpeningAir { fn width(&self) -> usize { - FriReducedOpeningCols::::width() + OVERALL_WIDTH } } impl BaseAirWithPublicValues for FriReducedOpeningAir {} impl PartitionedBaseAir for FriReducedOpeningAir {} - impl Air for FriReducedOpeningAir { fn eval(&self, builder: &mut AB) { let main = builder.main(); let local = main.row_slice(0); - let local: &FriReducedOpeningCols = (*local).borrow(); let next = main.row_slice(1); - let next: &FriReducedOpeningCols = (*next).borrow(); + let local_slice = local.deref(); + let next_slice = next.deref(); + self.eval_general(builder, local_slice, next_slice); + self.eval_workload_row(builder, local_slice, next_slice); + self.eval_instruction1_row(builder, local_slice, next_slice); + self.eval_instruction2_row(builder, local_slice, next_slice); + } +} - let &FriReducedOpeningCols { - enabled, - pc, - start_timestamp, - a_ptr_ptr, - b_ptr_ptr, - result_ptr, - addr_space, - length_ptr, - alpha_ptr, - alpha_pow_ptr, - a_ptr, - b_ptr, - a, - b, - alpha, - alpha_pow_original, - alpha_pow_current, - idx, - idx_is_zero, - is_zero_aux, - current, - a_ptr_aux, - b_ptr_aux, - a_aux, - b_aux, - result_aux, - length_aux, - alpha_aux, - alpha_pow_aux, - } = local; - - let is_first = idx_is_zero; - let is_last = next.idx_is_zero; - - builder.assert_bool(enabled); - // transition constraints - let mut when_is_not_last = builder.when(not(is_last)); - - let next_alpha_pow_times_b = FieldExtension::multiply(next.alpha_pow_current, next.b); - for i in 0..EXT_DEG { - when_is_not_last.assert_eq( - next.current[i], - next_alpha_pow_times_b[i].clone() - (next.alpha_pow_current[i] * next.a) - + current[i], - ); +impl FriReducedOpeningAir { + fn eval_general( + &self, + builder: &mut AB, + local_slice: &[AB::Var], + next_slice: &[AB::Var], + ) { + let local: &GeneralCols = local_slice[..GENERAL_WIDTH].borrow(); + let next: &GeneralCols = next_slice[..GENERAL_WIDTH].borrow(); + builder.assert_bool(local.enabled); + builder.assert_bool(local.is_ins_row); + { + // All enabled rows must be before disabled rows. + let mut when_transition = builder.when_transition(); + let mut when_disabled = when_transition.when_ne(local.enabled, AB::Expr::ONE); + when_disabled.assert_zero(next.enabled); } + } - assert_array_eq(&mut when_is_not_last, next.alpha, alpha); - assert_array_eq( - &mut when_is_not_last, - next.alpha_pow_original, - alpha_pow_original, - ); - assert_array_eq( - &mut when_is_not_last, - next.alpha_pow_current, - FieldExtension::multiply(alpha, alpha_pow_current), - ); - when_is_not_last.assert_eq(next.idx, idx + AB::Expr::ONE); - when_is_not_last.assert_eq(next.enabled, enabled); - when_is_not_last.assert_eq(next.start_timestamp, start_timestamp); - - // first row constraint - assert_array_eq( - &mut builder.when(is_first), - alpha_pow_current, - alpha_pow_original, - ); - - let alpha_pow_times_b = FieldExtension::multiply(alpha_pow_current, b); - for i in 0..EXT_DEG { - builder.when(is_first).assert_eq( - current[i], - alpha_pow_times_b[i].clone() - (alpha_pow_current[i] * a), + fn eval_workload_row( + &self, + builder: &mut AB, + local_slice: &[AB::Var], + next_slice: &[AB::Var], + ) { + let local: &WorkloadCols = local_slice[..WL_WIDTH].borrow(); + let next: &PrefixCols = next_slice[..PREFIX_WIDTH].borrow(); + let local_data = &local.prefix.data; + let start_timestamp = next.general.timestamp; + let multiplicity = + local.prefix.general.enabled * (AB::Expr::ONE - local.prefix.general.is_ins_row); + // a_ptr/b_ptr/length/result + let ptr_reads = AB::F::from_canonical_usize(4); + // read a + self.memory_bridge + .read( + MemoryAddress::new(local_data.addr_space, next.data.a_ptr), + [local.prefix.a_or_is_first], + start_timestamp + ptr_reads, + &local.a_aux, + ) + .eval(builder, multiplicity.clone()); + // read b + self.memory_bridge + .read( + MemoryAddress::new(local_data.addr_space, next.data.b_ptr), + local.b, + start_timestamp + ptr_reads + AB::Expr::ONE, + &local.b_aux, + ) + .eval(builder, multiplicity); + { + let mut when_transition = builder.when_transition(); + let mut not_ins_row = + when_transition.when_ne(local.prefix.general.is_ins_row, AB::F::ONE); + let mut builder = not_ins_row.when(next.general.enabled); + // ATTENTION: degree of builder is 2 + // local.timestamp = next.timestamp + 2 + builder.assert_eq( + local.prefix.general.timestamp, + start_timestamp + AB::Expr::TWO, ); + // local.idx = next.idx + 1 + builder.assert_eq(local_data.idx + AB::Expr::ONE, next.data.idx); + // local.alpha = next.alpha + assert_array_eq(&mut builder, local_data.alpha, next.data.alpha); + // local.addr_space = next.addr_space + builder.assert_eq(local_data.addr_space, next.data.addr_space); + // local.a_ptr = next.a_ptr + 1 + builder.assert_eq(local_data.a_ptr, next.data.a_ptr + AB::F::ONE); + // local.b_ptr = next.b_ptr + EXT_DEG + builder.assert_eq( + local_data.b_ptr, + next.data.b_ptr + AB::F::from_canonical_usize(EXT_DEG), + ); + // local.timestamp = next.timestamp + 2 + builder.assert_eq( + local.prefix.general.timestamp, + next.general.timestamp + AB::Expr::TWO, + ); + // local.result * local.alpha + local.b - local.a = next.result + let mut expected_result = FieldExtension::multiply(local_data.result, local_data.alpha); + expected_result + .iter_mut() + .zip(local.b.iter()) + .for_each(|(e, b)| { + *e += (*b).into(); + }); + expected_result[0] -= local.prefix.a_or_is_first.into(); + assert_array_eq(&mut builder, expected_result, next.data.result); } + { + let mut next_ins = builder.when(next.general.is_ins_row); + let mut local_non_ins = + next_ins.when_ne(local.prefix.general.is_ins_row, AB::Expr::ONE); + // The row after a workload row can only be the first instruction row. + local_non_ins.assert_one(next.a_or_is_first); + } + { + let mut when_first_row = builder.when_first_row(); + let mut when_enabled = when_first_row.when(local.prefix.general.enabled); + // First row must be a workload row. + when_enabled.assert_zero(local.prefix.general.is_ins_row); + // Workload rows must start with the first element. + when_enabled.assert_zero(local.prefix.data.idx); + // local.result is all 0s. + assert_array_eq( + &mut when_enabled, + local.prefix.data.result, + [AB::Expr::ZERO; EXT_DEG], + ); + } + } - // is zero constraint - let is_zero_io = IsZeroIo::new(idx.into(), idx_is_zero.into(), AB::Expr::ONE); - IsZeroSubAir.eval(builder, (is_zero_io, is_zero_aux)); - - // length will only be used on the last row, so it equals 1 + idx - let length = AB::Expr::ONE + idx; - let num_initial_accesses = AB::F::from_canonical_usize(4); - let num_loop_accesses = AB::Expr::TWO * length.clone(); - let num_final_accesses = AB::F::TWO; - - // execution interaction - let total_accesses = num_loop_accesses.clone() + num_initial_accesses + num_final_accesses; + fn eval_instruction1_row( + &self, + builder: &mut AB, + local_slice: &[AB::Var], + next_slice: &[AB::Var], + ) { + let local: &Instruction1Cols = local_slice[..INS_1_WIDTH].borrow(); + let next: &Instruction2Cols = next_slice[..INS_2_WIDTH].borrow(); + let mut is_ins_row = builder.when(local.prefix.general.is_ins_row); + let mut is_first_ins = is_ins_row.when(local.prefix.a_or_is_first); + let mut next_enabled = is_first_ins.when(next.general.enabled); + // ATTENTION: degree of next_enabled is 3 + next_enabled.assert_zero(next.is_first); + next_enabled.assert_one(next.general.is_ins_row); + + let local_data = &local.prefix.data; + let length = local.prefix.data.idx; + let multiplicity = local.prefix.general.enabled + * local.prefix.general.is_ins_row + * local.prefix.a_or_is_first; + let start_timestamp = local.prefix.general.timestamp; + // 4 reads + let write_timestamp = + start_timestamp + AB::Expr::TWO * length + AB::Expr::from_canonical_usize(4); + let end_timestamp = write_timestamp.clone() + AB::Expr::ONE; self.execution_bridge .execute( AB::F::from_canonical_usize(FRI_REDUCED_OPENING.global_opcode().as_usize()), [ - a_ptr_ptr, - b_ptr_ptr, - result_ptr, - addr_space, - length_ptr, - alpha_ptr, - alpha_pow_ptr, + local.a_ptr_ptr, + local.b_ptr_ptr, + next.result_ptr, + local_data.addr_space, + next.length_ptr, + next.alpha_ptr, ], - ExecutionState::new(pc, start_timestamp), + ExecutionState::new(local.pc, local.prefix.general.timestamp), ExecutionState::::new( - AB::Expr::from_canonical_u32(DEFAULT_PC_STEP) + pc, - start_timestamp + total_accesses, + AB::Expr::from_canonical_u32(DEFAULT_PC_STEP) + local.pc, + end_timestamp.clone(), ), ) - .eval(builder, enabled * is_last); - - // initial reads + .eval(builder, multiplicity.clone()); + // Read alpha self.memory_bridge .read( - MemoryAddress::new(addr_space, alpha_ptr), - alpha, + MemoryAddress::new(local_data.addr_space, next.alpha_ptr), + local_data.alpha, start_timestamp, - &alpha_aux, + &next.alpha_aux, ) - .eval(builder, enabled * is_last); + .eval(builder, multiplicity.clone()); + // Read length self.memory_bridge .read( - MemoryAddress::new(addr_space, length_ptr), + MemoryAddress::new(local_data.addr_space, next.length_ptr), [length], - start_timestamp + AB::F::ONE, - &length_aux, - ) - .eval(builder, enabled * is_last); - self.memory_bridge - .read( - MemoryAddress::new(addr_space, a_ptr_ptr), - [a_ptr], - start_timestamp + AB::F::TWO, - &a_ptr_aux, + start_timestamp + AB::Expr::ONE, + &next.length_aux, ) - .eval(builder, enabled * is_last); + .eval(builder, multiplicity.clone()); + // Read a_ptr self.memory_bridge .read( - MemoryAddress::new(addr_space, b_ptr_ptr), - [b_ptr], - start_timestamp + AB::F::from_canonical_usize(3), - &b_ptr_aux, + MemoryAddress::new(local_data.addr_space, local.a_ptr_ptr), + [local_data.a_ptr], + start_timestamp + AB::Expr::TWO, + &local.a_ptr_aux, ) - .eval(builder, enabled * is_last); - - // general reads - let timestamp = start_timestamp + num_initial_accesses + (idx * AB::F::TWO); + .eval(builder, multiplicity.clone()); + // Read b_ptr self.memory_bridge .read( - MemoryAddress::new(addr_space, a_ptr + idx), - [a], - timestamp.clone(), - &a_aux, + MemoryAddress::new(local_data.addr_space, local.b_ptr_ptr), + [local_data.b_ptr], + start_timestamp + AB::Expr::from_canonical_u32(3), + &local.b_ptr_aux, ) - .eval(builder, enabled); - self.memory_bridge - .read( - MemoryAddress::new( - addr_space, - b_ptr + (idx * AB::F::from_canonical_usize(EXT_DEG)), - ), - b, - timestamp + AB::F::ONE, - &b_aux, - ) - .eval(builder, enabled); - - // final writes - let timestamp = start_timestamp + num_initial_accesses + num_loop_accesses.clone(); - self.memory_bridge - .write( - MemoryAddress::new(addr_space, alpha_pow_ptr), - FieldExtension::multiply(alpha, alpha_pow_current), - timestamp.clone(), - &MemoryWriteAuxCols { - base: alpha_pow_aux, - prev_data: alpha_pow_original, - }, - ) - .eval(builder, enabled * is_last); + .eval(builder, multiplicity.clone()); self.memory_bridge .write( - MemoryAddress::new(addr_space, result_ptr), - current, - timestamp + AB::F::ONE, - &result_aux, + MemoryAddress::new(local_data.addr_space, next.result_ptr), + local_data.result, + write_timestamp, + &next.result_aux, ) - .eval(builder, enabled * is_last); + .eval(builder, multiplicity.clone()); + } + + fn eval_instruction2_row( + &self, + builder: &mut AB, + local_slice: &[AB::Var], + next_slice: &[AB::Var], + ) { + let local: &Instruction2Cols = local_slice[..INS_2_WIDTH].borrow(); + let next: &WorkloadCols = next_slice[..WL_WIDTH].borrow(); + + { + let mut last_row = builder.when_last_row(); + let mut enabled = last_row.when(local.general.enabled); + // If the last row is enabled, it must be the second row of an instruction row. This + // is a safeguard for edge cases. + enabled.assert_one(local.general.is_ins_row); + enabled.assert_one(local.is_first); + } + { + let mut when_transition = builder.when_transition(); + let mut is_ins_row = when_transition.when(local.general.is_ins_row); + let mut not_first_row = is_ins_row.when_ne(local.is_first, AB::Expr::ONE); + // when_transition is necessary to check the next row is enabled. + let mut enabled = not_first_row.when(next.prefix.general.enabled); + // The next row must be a workload row. + enabled.assert_zero(next.prefix.general.is_ins_row); + // The next row must have idx = 0. + enabled.assert_zero(next.prefix.data.idx); + // next.result is all 0s + assert_array_eq( + &mut enabled, + next.prefix.data.result, + [AB::Expr::ZERO; EXT_DEG], + ); + } + } +} + +fn assert_array_eq, I2: Into, const N: usize>( + builder: &mut AB, + x: [I1; N], + y: [I2; N], +) { + for (x, y) in zip_eq(x, y) { + builder.assert_eq(x, y); } } +fn elem_to_ext(elem: F) -> [F; EXT_DEG] { + let mut ret = [F::ZERO; EXT_DEG]; + ret[0] = elem; + ret +} + #[derive(Serialize, Deserialize)] #[serde(bound = "F: Field")] pub struct FriReducedOpeningRecord { @@ -306,18 +438,22 @@ pub struct FriReducedOpeningRecord { pub b_ptr_read: RecordId, pub a_reads: Vec, pub b_reads: Vec, - pub alpha_pow_original: [F; EXT_DEG], - pub alpha_pow_write: RecordId, pub result_write: RecordId, } +impl FriReducedOpeningRecord { + fn get_height(&self) -> usize { + // 2 for instruction rows + self.a_reads.len() + 2 + } +} + pub struct FriReducedOpeningChip { air: FriReducedOpeningAir, records: Vec>, height: usize, offline_memory: Arc>>, } - impl FriReducedOpeningChip { pub fn new( execution_bus: ExecutionBus, @@ -337,13 +473,6 @@ impl FriReducedOpeningChip { } } } - -fn elem_to_ext(elem: F) -> [F; EXT_DEG] { - let mut ret = [F::ZERO; EXT_DEG]; - ret[0] = elem; - ret -} - impl InstructionExecutor for FriReducedOpeningChip { fn execute( &mut self, @@ -358,7 +487,6 @@ impl InstructionExecutor for FriReducedOpeningChip { d: addr_space, e: length_ptr, f: alpha_ptr, - g: alpha_pow_ptr, .. } = instruction; @@ -368,10 +496,6 @@ impl InstructionExecutor for FriReducedOpeningChip { let b_ptr_read = memory.read_cell(addr_space, b_ptr_ptr); let alpha = alpha_read.1; - let alpha_pow_original = from_fn(|i| { - memory.unsafe_read_cell(addr_space, alpha_pow_ptr + F::from_canonical_usize(i)) - }); - let mut alpha_pow = alpha_pow_original; let length = length_read.1.as_canonical_u32() as usize; let a_ptr = a_ptr_read.1; let b_ptr = b_ptr_read.1; @@ -382,23 +506,25 @@ impl InstructionExecutor for FriReducedOpeningChip { for i in 0..length { let a_read = memory.read_cell(addr_space, a_ptr + F::from_canonical_usize(i)); - let b_read = memory.read(addr_space, b_ptr + F::from_canonical_usize(4 * i)); + let b_read = + memory.read::(addr_space, b_ptr + F::from_canonical_usize(EXT_DEG * i)); a_reads.push(a_read); b_reads.push(b_read); + } + + for (a_read, b_read) in a_reads.iter().rev().zip_eq(b_reads.iter().rev()) { let a = a_read.1; let b = b_read.1; + // result = result * alpha + (b - a) result = FieldExtension::add( - result, - FieldExtension::multiply(FieldExtension::subtract(b, elem_to_ext(a)), alpha_pow), + FieldExtension::multiply(result, alpha), + FieldExtension::subtract(b, elem_to_ext(a)), ); - alpha_pow = FieldExtension::multiply(alpha, alpha_pow); } - let (alpha_pow_write, prev_data) = memory.write(addr_space, alpha_pow_ptr, alpha_pow); - debug_assert_eq!(prev_data, alpha_pow_original); let (result_write, _) = memory.write(addr_space, result_ptr, result); - self.records.push(FriReducedOpeningRecord { + let record = FriReducedOpeningRecord { pc: F::from_canonical_u32(from_state.pc), start_timestamp: F::from_canonical_u32(from_state.timestamp), instruction: instruction.clone(), @@ -408,12 +534,10 @@ impl InstructionExecutor for FriReducedOpeningChip { b_ptr_read: b_ptr_read.0, a_reads: a_reads.into_iter().map(|r| r.0).collect(), b_reads: b_reads.into_iter().map(|r| r.0).collect(), - alpha_pow_original, - alpha_pow_write, result_write, - }); - - self.height += length; + }; + self.height += record.get_height(); + self.records.push(record); Ok(ExecutionState { pc: from_state.pc + DEFAULT_PC_STEP, @@ -427,6 +551,132 @@ impl InstructionExecutor for FriReducedOpeningChip { } } +fn record_to_rows( + record: FriReducedOpeningRecord, + aux_cols_factory: &MemoryAuxColsFactory, + slice: &mut [F], + memory: &OfflineMemory, +) { + let Instruction { + a: a_ptr_ptr, + b: b_ptr_ptr, + c: result_ptr, + d: addr_space, + e: length_ptr, + f: alpha_ptr, + .. + } = record.instruction; + + let length_read = memory.record_by_id(record.length_read); + let alpha_read = memory.record_by_id(record.alpha_read); + let a_ptr_read = memory.record_by_id(record.a_ptr_read); + let b_ptr_read = memory.record_by_id(record.b_ptr_read); + + let length = length_read.data[0].as_canonical_u32() as usize; + let alpha: [F; EXT_DEG] = array::from_fn(|i| alpha_read.data[i]); + let a_ptr = a_ptr_read.data[0]; + let b_ptr = b_ptr_read.data[0]; + + let mut result = [F::ZERO; EXT_DEG]; + + let alpha_aux = aux_cols_factory.make_read_aux_cols(alpha_read); + let length_aux = aux_cols_factory.make_read_aux_cols(length_read); + let a_ptr_aux = aux_cols_factory.make_read_aux_cols(a_ptr_read); + let b_ptr_aux = aux_cols_factory.make_read_aux_cols(b_ptr_read); + + let result_aux = aux_cols_factory.make_write_aux_cols(memory.record_by_id(record.result_write)); + + // WorkloadCols + for (i, (&a_record_id, &b_record_id)) in record + .a_reads + .iter() + .rev() + .zip_eq(record.b_reads.iter().rev()) + .enumerate() + { + let a_read = memory.record_by_id(a_record_id); + let b_read = memory.record_by_id(b_record_id); + let a = a_read.data[0]; + let b: [F; EXT_DEG] = array::from_fn(|i| b_read.data[i]); + + let start = i * OVERALL_WIDTH; + let cols: &mut WorkloadCols = slice[start..start + WL_WIDTH].borrow_mut(); + *cols = WorkloadCols { + prefix: PrefixCols { + general: GeneralCols { + enabled: F::ONE, + is_ins_row: F::ZERO, + timestamp: record.start_timestamp + F::from_canonical_usize((length - i) * 2), + }, + a_or_is_first: a, + data: DataCols { + addr_space, + a_ptr: a_ptr + F::from_canonical_usize(length - i), + b_ptr: b_ptr + F::from_canonical_usize((length - i) * EXT_DEG), + idx: F::from_canonical_usize(i), + result, + alpha, + }, + }, + a_aux: aux_cols_factory.make_read_aux_cols(a_read), + b, + b_aux: aux_cols_factory.make_read_aux_cols(b_read), + }; + // result = result * alpha + (b - a) + result = FieldExtension::add( + FieldExtension::multiply(result, alpha), + FieldExtension::subtract(b, elem_to_ext(a)), + ); + } + // Instruction1Cols + { + let start = length * OVERALL_WIDTH; + let cols: &mut Instruction1Cols = slice[start..start + INS_1_WIDTH].borrow_mut(); + *cols = Instruction1Cols { + prefix: PrefixCols { + general: GeneralCols { + enabled: F::ONE, + is_ins_row: F::ONE, + timestamp: record.start_timestamp, + }, + a_or_is_first: F::ONE, + data: DataCols { + addr_space, + a_ptr, + b_ptr, + idx: F::from_canonical_usize(length), + result, + alpha, + }, + }, + pc: record.pc, + a_ptr_ptr, + a_ptr_aux, + b_ptr_ptr, + b_ptr_aux, + }; + } + // Instruction2Cols + { + let start = (length + 1) * OVERALL_WIDTH; + let cols: &mut Instruction2Cols = slice[start..start + INS_2_WIDTH].borrow_mut(); + *cols = Instruction2Cols { + general: GeneralCols { + enabled: F::ONE, + is_ins_row: F::ONE, + timestamp: record.start_timestamp, + }, + is_first: F::ZERO, + result_ptr, + result_aux, + length_ptr, + length_aux, + alpha_ptr, + alpha_aux, + }; + } +} + impl ChipUsageGetter for FriReducedOpeningChip { fn air_name(&self) -> String { "FriReducedOpeningAir".to_string() @@ -437,141 +687,7 @@ impl ChipUsageGetter for FriReducedOpeningChip { } fn trace_width(&self) -> usize { - FriReducedOpeningCols::::width() - } -} - -impl FriReducedOpeningChip { - fn record_to_rows( - record: FriReducedOpeningRecord, - aux_cols_factory: &MemoryAuxColsFactory, - slice: &mut [F], - memory: &OfflineMemory, - ) { - let width = FriReducedOpeningCols::::width(); - - let Instruction { - a: a_ptr_ptr, - b: b_ptr_ptr, - c: result_ptr, - d: addr_space, - e: length_ptr, - f: alpha_ptr, - g: alpha_pow_ptr, - .. - } = record.instruction; - - let length_read = memory.record_by_id(record.length_read); - let alpha_read = memory.record_by_id(record.alpha_read); - let a_ptr_read = memory.record_by_id(record.a_ptr_read); - let b_ptr_read = memory.record_by_id(record.b_ptr_read); - - let length = length_read.data[0].as_canonical_u32() as usize; - let alpha: [F; EXT_DEG] = array::from_fn(|i| alpha_read.data[i]); - let a_ptr = a_ptr_read.data[0]; - let b_ptr = b_ptr_read.data[0]; - - let mut alpha_pow_current = record.alpha_pow_original; - let mut current = [F::ZERO; EXT_DEG]; - - let alpha_aux = aux_cols_factory.make_read_aux_cols(alpha_read); - let length_aux = aux_cols_factory.make_read_aux_cols(length_read); - let a_ptr_aux = aux_cols_factory.make_read_aux_cols(a_ptr_read); - let b_ptr_aux = aux_cols_factory.make_read_aux_cols(b_ptr_read); - - let alpha_pow_aux = aux_cols_factory - .make_write_aux_cols::(memory.record_by_id(record.alpha_pow_write)) - .get_base(); - let result_aux = - aux_cols_factory.make_write_aux_cols(memory.record_by_id(record.result_write)); - - for i in 0..length { - let a_read = memory.record_by_id(record.a_reads[i]); - let b_read = memory.record_by_id(record.b_reads[i]); - let a = a_read.data[0]; - let b: [F; EXT_DEG] = array::from_fn(|i| b_read.data[i]); - current = FieldExtension::add( - current, - FieldExtension::multiply( - FieldExtension::subtract(b, elem_to_ext(a)), - alpha_pow_current, - ), - ); - - let mut idx_is_zero = F::ZERO; - let mut is_zero_aux = F::ZERO; - - let idx = F::from_canonical_usize(i); - IsZeroSubAir.generate_subrow(idx, (&mut is_zero_aux, &mut idx_is_zero)); - - let cols: &mut FriReducedOpeningCols = - slice[i * width..(i + 1) * width].borrow_mut(); - *cols = FriReducedOpeningCols { - enabled: F::ONE, - pc: record.pc, - a_ptr_ptr, - b_ptr_ptr, - result_ptr, - addr_space, - length_ptr, - alpha_ptr, - alpha_pow_ptr, - start_timestamp: record.start_timestamp, - a_ptr_aux, - b_ptr_aux, - a_aux: aux_cols_factory.make_read_aux_cols(a_read), - b_aux: aux_cols_factory.make_read_aux_cols(b_read), - alpha_aux, - length_aux, - alpha_pow_aux, - result_aux, - a_ptr, - b_ptr, - a, - b, - alpha, - alpha_pow_original: record.alpha_pow_original, - alpha_pow_current, - idx, - idx_is_zero, - is_zero_aux, - current, - }; - - alpha_pow_current = FieldExtension::multiply(alpha, alpha_pow_current); - } - } - - fn generate_trace(self) -> RowMajorMatrix { - let width = self.trace_width(); - let height = next_power_of_two_or_zero(self.height); - let mut flat_trace = F::zero_vec(width * height); - - let memory = self.offline_memory.lock().unwrap(); - - let aux_cols_factory = memory.aux_cols_factory(); - - let mut idx = 0; - for record in self.records { - let length = record.a_reads.len(); - Self::record_to_rows( - record, - &aux_cols_factory, - &mut flat_trace[idx..idx + (length * width)], - &memory, - ); - idx += length * width; - } - // In padding rows, need idx_is_zero = 1 so IsZero constraints pass, and also because next.idx_is_zero is used - // to determine the last row per instruction, so the last non-padding row needs next.idx_is_zero = 1 - flat_trace[self.height * width..] - .par_chunks_mut(width) - .for_each(|row| { - let row: &mut FriReducedOpeningCols = row.borrow_mut(); - row.idx_is_zero = F::ONE; - }); - - RowMajorMatrix::new(flat_trace, width) + OVERALL_WIDTH } } @@ -583,17 +699,51 @@ where Arc::new(self.air) } fn generate_air_proof_input(self) -> AirProofInput { - AirProofInput::simple_no_pis(self.air(), self.generate_trace()) + let air = self.air(); + let height = next_power_of_two_or_zero(self.height); + let mut flat_trace = Val::::zero_vec(OVERALL_WIDTH * height); + let chunked_trace = { + let sizes: Vec<_> = self + .records + .par_iter() + .map(|record| OVERALL_WIDTH * record.get_height()) + .collect(); + variable_chunks_mut(&mut flat_trace, &sizes) + }; + + let memory = self.offline_memory.lock().unwrap(); + let aux_cols_factory = memory.aux_cols_factory(); + + self.records + .into_par_iter() + .zip_eq(chunked_trace.into_par_iter()) + .for_each(|(record, slice)| { + record_to_rows(record, &aux_cols_factory, slice, &memory); + }); + + let matrix = RowMajorMatrix::new(flat_trace, OVERALL_WIDTH); + AirProofInput::simple_no_pis(air, matrix) } } impl Stateful> for FriReducedOpeningChip { fn load_state(&mut self, state: Vec) { self.records = bitcode::deserialize(&state).unwrap(); - self.height = self.records.iter().map(|record| record.a_reads.len()).sum(); + self.height = self.records.iter().map(|record| record.get_height()).sum(); } fn store_state(&self) -> Vec { bitcode::serialize(&self.records).unwrap() } } + +fn variable_chunks_mut<'a, T>(mut slice: &'a mut [T], sizes: &[usize]) -> Vec<&'a mut [T]> { + let mut result = Vec::with_capacity(sizes.len()); + for &size in sizes { + // split_at_mut guarantees disjoint slices + let (left, right) = slice.split_at_mut(size); + result.push(left); + slice = right; // move forward for the next chunk + } + result +} diff --git a/extensions/native/circuit/src/fri/tests.rs b/extensions/native/circuit/src/fri/tests.rs index 1fd138ddde..de96861360 100644 --- a/extensions/native/circuit/src/fri/tests.rs +++ b/extensions/native/circuit/src/fri/tests.rs @@ -7,20 +7,23 @@ use openvm_stark_backend::{ utils::disable_debug_builder, verifier::VerificationError, }; -use openvm_stark_sdk::{p3_baby_bear::BabyBear, utils::create_seeded_rng}; +use openvm_stark_sdk::{ + config::{baby_bear_poseidon2::BabyBearPoseidon2Engine, FriParameters}, + engine::StarkFriEngine, + p3_baby_bear::BabyBear, + utils::create_seeded_rng, +}; use rand::Rng; -use super::{ - super::field_extension::FieldExtension, elem_to_ext, FriReducedOpeningChip, - FriReducedOpeningCols, EXT_DEG, -}; +use super::{super::field_extension::FieldExtension, elem_to_ext, FriReducedOpeningChip, EXT_DEG}; +use crate::OVERALL_WIDTH; fn compute_fri_mat_opening( alpha: [F; EXT_DEG], - mut alpha_pow: [F; EXT_DEG], a: &[F], b: &[[F; EXT_DEG]], -) -> ([F; EXT_DEG], [F; EXT_DEG]) { +) -> [F; EXT_DEG] { + let mut alpha_pow: [F; EXT_DEG] = elem_to_ext(F::ONE); let mut result = [F::ZERO; EXT_DEG]; for (&a, &b) in a.iter().zip_eq(b) { result = FieldExtension::add( @@ -29,7 +32,7 @@ fn compute_fri_mat_opening( ); alpha_pow = FieldExtension::multiply(alpha, alpha_pow); } - (alpha_pow, result) + result } #[test] @@ -60,19 +63,17 @@ fn fri_mat_opening_air_test() { for _ in 0..num_ops { let alpha = gen_ext!(); let length = rng.gen_range(length_range()); - let alpha_pow_initial = gen_ext!(); let a = (0..length) .map(|_| BabyBear::from_canonical_u32(rng.gen_range(elem_range()))) .collect_vec(); let b = (0..length).map(|_| gen_ext!()).collect_vec(); - let (alpha_pow_final, result) = compute_fri_mat_opening(alpha, alpha_pow_initial, &a, &b); + let result = compute_fri_mat_opening(alpha, &a, &b); let alpha_pointer = gen_pointer(&mut rng, 4); let length_pointer = gen_pointer(&mut rng, 1); let a_pointer_pointer = gen_pointer(&mut rng, 1); let b_pointer_pointer = gen_pointer(&mut rng, 1); - let alpha_pow_pointer = gen_pointer(&mut rng, 4); let result_pointer = gen_pointer(&mut rng, 4); let a_pointer = gen_pointer(&mut rng, 1); let b_pointer = gen_pointer(&mut rng, 4); @@ -100,7 +101,6 @@ fn fri_mat_opening_air_test() { b_pointer_pointer, BabyBear::from_canonical_usize(b_pointer), ); - tester.write(address_space, alpha_pow_pointer, alpha_pow_initial); for i in 0..length { tester.write_cell(address_space, a_pointer + i, a[i]); tester.write(address_space, b_pointer + (4 * i), b[i]); @@ -117,19 +117,21 @@ fn fri_mat_opening_air_test() { address_space, length_pointer, alpha_pointer, - alpha_pow_pointer, ], ), ); - assert_eq!( - alpha_pow_final, - tester.read(address_space, alpha_pow_pointer) - ); assert_eq!(result, tester.read(address_space, result_pointer)); } - let mut tester = tester.build().load(chip).finalize(); - tester.simple_test().expect("Verification failed"); + let mut tester = tester.build_babybear_poseidon2().load(chip).finalize(); + // Degree needs >= 4 + tester + .test::(|| { + BabyBearPoseidon2Engine::new( + FriParameters::standard_with_100_bits_conjectured_security(2), + ) + }) + .expect("Verification failed"); disable_debug_builder(); // negative test pranking each value @@ -137,7 +139,7 @@ fn fri_mat_opening_air_test() { // TODO: better way to modify existing traces in tester let trace = tester.air_proof_inputs[2].raw.common_main.as_mut().unwrap(); let old_trace = trace.clone(); - for width in 0..FriReducedOpeningCols::::width() + for width in 0..OVERALL_WIDTH /* num operands */ { let prank_value = BabyBear::from_canonical_u32(rng.gen_range(1..=100)); diff --git a/extensions/native/circuit/src/utils.rs b/extensions/native/circuit/src/utils.rs index 0c188998c7..09635e1dc3 100644 --- a/extensions/native/circuit/src/utils.rs +++ b/extensions/native/circuit/src/utils.rs @@ -7,7 +7,8 @@ use crate::{Native, NativeConfig}; pub fn execute_program(program: Program, input_stream: impl Into>) { let system_config = SystemConfig::default() .with_public_values(4) - .with_max_segment_len((1 << 25) - 100); + .with_max_segment_len((1 << 25) - 100) + .with_profiling(); let config = NativeConfig::new(system_config, Native); let executor = VmExecutor::::new(config); diff --git a/extensions/native/compiler/src/asm/compiler.rs b/extensions/native/compiler/src/asm/compiler.rs index 3d2911165e..086136fb0d 100644 --- a/extensions/native/compiler/src/asm/compiler.rs +++ b/extensions/native/compiler/src/asm/compiler.rs @@ -547,7 +547,7 @@ impl + TwoAdicField> AsmCo DslIr::Halt => { self.push(AsmInstruction::Halt, debug_info); } - DslIr::FriReducedOpening(alpha, curr_alpha_pow, at_x_array, at_z_array, result) => { + DslIr::FriReducedOpening(alpha, at_x_array, at_z_array, result) => { self.push( AsmInstruction::FriReducedOpening( at_x_array.ptr().fp(), @@ -560,7 +560,6 @@ impl + TwoAdicField> AsmCo Usize::Var(len) => len.fp(), }, alpha.fp(), - curr_alpha_pow.fp(), ), debug_info, ); diff --git a/extensions/native/compiler/src/asm/instruction.rs b/extensions/native/compiler/src/asm/instruction.rs index 665ba70eb8..bc60bd3115 100644 --- a/extensions/native/compiler/src/asm/instruction.rs +++ b/extensions/native/compiler/src/asm/instruction.rs @@ -116,8 +116,8 @@ pub enum AsmInstruction { /// (a, b, c) are memory pointers to (dst, lhs, rhs) Poseidon2Compress(i32, i32, i32), - /// (a, b, res, len, alpha, alpha_pow) - FriReducedOpening(i32, i32, i32, i32, i32, i32), + /// (a, b, res, len, alpha) + FriReducedOpening(i32, i32, i32, i32, i32), /// (dim, opened, opened_length, sibling, index, commit) /// opened values are field elements @@ -358,11 +358,11 @@ impl> AsmInstruction { AsmInstruction::CycleTrackerEnd() => { write!(f, "cycle_tracker_end") } - AsmInstruction::FriReducedOpening(a, b, res, len, alpha, alpha_pow) => { + AsmInstruction::FriReducedOpening(a, b, res, len, alpha) => { write!( f, - "fri_mat_opening ({})fp, ({})fp, ({})fp, ({})fp, ({})fp, ({})fp", - a, b, res, len, alpha, alpha_pow + "fri_mat_opening ({})fp, ({})fp, ({})fp, ({})fp, ({})fp", + a, b, res, len, alpha ) } AsmInstruction::VerifyBatchFelt(dim, opened, opened_length, sibling, index, commit) => { diff --git a/extensions/native/compiler/src/conversion/mod.rs b/extensions/native/compiler/src/conversion/mod.rs index c3f7c5365c..80029c7c38 100644 --- a/extensions/native/compiler/src/conversion/mod.rs +++ b/extensions/native/compiler/src/conversion/mod.rs @@ -464,7 +464,7 @@ fn convert_instruction>( AS::Native, AS::Native, )], - AsmInstruction::FriReducedOpening(a, b, res, len, alpha, alpha_pow) => vec![Instruction { + AsmInstruction::FriReducedOpening(a, b, res, len, alpha) => vec![Instruction { opcode: options.opcode_with_offset(FriOpcode::FRI_REDUCED_OPENING), a: i32_f(a), b: i32_f(b), @@ -472,7 +472,7 @@ fn convert_instruction>( d: AS::Native.to_field(), e: i32_f(len), f: i32_f(alpha), - g: i32_f(alpha_pow), + g: F::ZERO, }], AsmInstruction::VerifyBatchFelt(dim, opened, opened_length, sibling, index, commit) => vec![Instruction { opcode: options.opcode_with_offset(VerifyBatchOpcode::VERIFY_BATCH), diff --git a/extensions/native/compiler/src/ir/fri.rs b/extensions/native/compiler/src/ir/fri.rs index 91de10762d..f96bb14de0 100644 --- a/extensions/native/compiler/src/ir/fri.rs +++ b/extensions/native/compiler/src/ir/fri.rs @@ -4,14 +4,12 @@ impl Builder { pub fn fri_single_reduced_opening_eval( &mut self, alpha: Ext, - curr_alpha_pow: Ext, at_x_array: &Array>, at_z_array: &Array>, ) -> Ext { let result = self.uninit(); self.operations.push(crate::ir::DslIr::FriReducedOpening( alpha, - curr_alpha_pow, at_x_array.clone(), at_z_array.clone(), result, diff --git a/extensions/native/compiler/src/ir/instructions.rs b/extensions/native/compiler/src/ir/instructions.rs index 8a6f8dfd90..1cf5b1fc1b 100644 --- a/extensions/native/compiler/src/ir/instructions.rs +++ b/extensions/native/compiler/src/ir/instructions.rs @@ -248,9 +248,8 @@ pub enum DslIr { CircuitExt2Felt([Felt; 4], Ext), /// Converts a slice of felts to an ext. Should only be used when target is a circuit. CircuitFelts2Ext([Felt; 4], Ext), - /// FriReducedOpening(alpha, curr_alpha_pow, at_x_array, at_z_array, result) + /// FriReducedOpening(alpha, at_x_array, at_z_array, result) FriReducedOpening( - Ext, Ext, Array>, Array>, diff --git a/extensions/native/compiler/tests/fri_ro_eval.rs b/extensions/native/compiler/tests/fri_ro_eval.rs index 138653e087..30c4c2f726 100644 --- a/extensions/native/compiler/tests/fri_ro_eval.rs +++ b/extensions/native/compiler/tests/fri_ro_eval.rs @@ -53,7 +53,6 @@ fn test_single_reduced_opening_eval() { builder.assign(&cur_alpha_pow, cur_alpha_pow * alpha); }); let expected_result = cur_ro; - let expected_final_alpha_pow = cur_alpha_pow; // prints don't work? /*builder.print_e(expected_result); @@ -64,19 +63,15 @@ fn test_single_reduced_opening_eval() { let ext_1210 = builder.constant(EF::from_base_slice(&[F::ONE, F::TWO, F::ONE, F::ZERO])); builder.print_e(ext_1210);*/ - let cur_alpha_pow: Ext<_, _> = builder.uninit(); builder.assign(&cur_alpha_pow, initial_alpha_pow); - let single_ro_eval_res = - builder.fri_single_reduced_opening_eval(alpha, cur_alpha_pow, &mat_opening, &ps_at_z); - let actual_final_alpha_pow = cur_alpha_pow; + let single_ro_eval_res = builder.fri_single_reduced_opening_eval(alpha, &mat_opening, &ps_at_z); let actual_result: Ext<_, _> = builder.uninit(); - builder.assign(&actual_result, single_ro_eval_res / (z - x)); + builder.assign(&actual_result, single_ro_eval_res * cur_alpha_pow / (z - x)); //builder.print_e(actual_result); //builder.print_e(actual_final_alpha_pow); builder.assert_ext_eq(expected_result, actual_result); - builder.assert_ext_eq(expected_final_alpha_pow, actual_final_alpha_pow); builder.halt(); diff --git a/extensions/native/recursion/src/fri/two_adic_pcs.rs b/extensions/native/recursion/src/fri/two_adic_pcs.rs index 777f10dafc..b1d95d412c 100644 --- a/extensions/native/recursion/src/fri/two_adic_pcs.rs +++ b/extensions/native/recursion/src/fri/two_adic_pcs.rs @@ -235,17 +235,14 @@ pub fn verify_two_adic_pcs( builder.assign(&cur_ro, cur_ro + cur_alpha_pow * n / (z - x)); builder.assign(&cur_alpha_pow, cur_alpha_pow * mat_alpha_pow); } else { - // TODO: this is just for testing the correctness. Will remove later. - let expected_alpha_pow: Ext<_, _> = - builder.eval(cur_alpha_pow * mat_alpha_pow); let mat_ro = builder.fri_single_reduced_opening_eval( alpha, - cur_alpha_pow, &mat_opening, &ps_at_z, ); - builder.assert_ext_eq(expected_alpha_pow, cur_alpha_pow); - builder.assign(&cur_ro, cur_ro + (mat_ro / (z - x))); + builder + .assign(&cur_ro, cur_ro + (mat_ro * cur_alpha_pow / (z - x))); + builder.assign(&cur_alpha_pow, cur_alpha_pow * mat_alpha_pow); } builder.cycle_tracker_end("single-reduced-opening-eval");