From b31c5b3e148cb242f680755468ace56305232b7c Mon Sep 17 00:00:00 2001 From: luffykai Date: Tue, 21 Jan 2025 14:09:18 -0800 Subject: [PATCH 1/7] support custom segmentation strategy --- crates/vm/src/arch/config.rs | 9 ----- crates/vm/src/arch/segment.rs | 68 ++++++++++++++++++++++++----------- crates/vm/src/arch/vm.rs | 20 +++++++++-- 3 files changed, 64 insertions(+), 33 deletions(-) diff --git a/crates/vm/src/arch/config.rs b/crates/vm/src/arch/config.rs index d789e040f..cb60ca1e7 100644 --- a/crates/vm/src/arch/config.rs +++ b/crates/vm/src/arch/config.rs @@ -18,7 +18,6 @@ use super::{ }; use crate::system::memory::BOUNDARY_AIR_OFFSET; -const DEFAULT_MAX_SEGMENT_LEN: usize = (1 << 22) - 100; // sbox is decomposed to have this max degree for Poseidon2. We set to 3 so quotient_degree = 2 // allows log_blowup = 1 const DEFAULT_POSEIDON2_MAX_CONSTRAINT_DEGREE: usize = 3; @@ -86,8 +85,6 @@ pub struct SystemConfig { /// cannot read public values directly, but they can decommit the public values from the memory /// merkle root. pub num_public_values: usize, - /// When continuations are enabled, a heuristic used to determine when to segment execution. - pub max_segment_len: usize, /// Whether to collect detailed profiling metrics. /// **Warning**: this slows down the runtime. 
pub profiling: bool, @@ -110,7 +107,6 @@ impl SystemConfig { continuation_enabled: false, memory_config, num_public_values, - max_segment_len: DEFAULT_MAX_SEGMENT_LEN, profiling: false, } } @@ -135,11 +131,6 @@ impl SystemConfig { self } - pub fn with_max_segment_len(mut self, max_segment_len: usize) -> Self { - self.max_segment_len = max_segment_len; - self - } - pub fn with_profiling(mut self) -> Self { self.profiling = true; self diff --git a/crates/vm/src/arch/segment.rs b/crates/vm/src/arch/segment.rs index f889dff76..702e56dfa 100644 --- a/crates/vm/src/arch/segment.rs +++ b/crates/vm/src/arch/segment.rs @@ -1,3 +1,5 @@ +use std::sync::Arc; + use backtrace::Backtrace; use openvm_instructions::{ exe::FnBounds, @@ -27,6 +29,40 @@ use crate::{ /// Check segment every 100 instructions. const SEGMENT_CHECK_INTERVAL: usize = 100; +const DEFAULT_MAX_SEGMENT_LEN: usize = (1 << 22) - 100; +// a heuristic number for the maximum number of cells per chip in a segment +// a few reasons for this number: +// 1. `VmAirWrapper` is +// the chip with the most cells in a segment from the reth-benchmark. +// 2. `VmAirWrapper`: +// its trace width is 36 and its after challenge trace width is 80. +const DEFAULT_MAX_CELLS_PER_CHIP_IN_SEGMENT: usize = DEFAULT_MAX_SEGMENT_LEN * 120; + +pub trait SegmentationStrategy { + fn should_segment(&self, trace_heights: &[usize], trace_cells: &[usize]) -> bool; +} + +/// Default segmentation strategy: segment if any chip's height or cells exceed the limits. 
+pub struct DefaultSegmentationStrategy; + +impl SegmentationStrategy for DefaultSegmentationStrategy { + fn should_segment(&self, trace_heights: &[usize], trace_cells: &[usize]) -> bool { + for (i, &height) in trace_heights.iter().enumerate() { + if height > DEFAULT_MAX_SEGMENT_LEN { + tracing::info!("Should segment because chip {} has height {}", i, height); + return true; + } + } + for (i, &num_cells) in trace_cells.iter().enumerate() { + if num_cells > DEFAULT_MAX_CELLS_PER_CHIP_IN_SEGMENT { + tracing::info!("Should segment because chip {} has {} cells", i, num_cells,); + return true; + } + } + false + } +} + pub struct ExecutionSegment where F: PrimeField32, @@ -37,8 +73,6 @@ where pub since_last_segment_check: usize, - /// Air names for debug purposes only. - pub(crate) air_names: Vec, /// Metrics collected for this execution segment alone. #[cfg(feature = "bench-metrics")] pub(crate) metrics: VmMetrics, @@ -70,12 +104,10 @@ impl> ExecutionSegment { if let Some(initial_memory) = initial_memory { chip_complex.set_initial_memory(initial_memory); } - let air_names = chip_complex.air_names(); Self { chip_complex, final_memory: None, - air_names, #[cfg(feature = "bench-metrics")] metrics: VmMetrics { fn_bounds, @@ -99,11 +131,9 @@ impl> ExecutionSegment { }; chip_complex.set_program(program); chip_complex.load_state(vm_chip_complex_state); - let air_names = chip_complex.air_names(); Self { chip_complex, final_memory: None, - air_names, #[cfg(feature = "bench-metrics")] metrics: Default::default(), since_last_segment_check: 0, @@ -125,6 +155,7 @@ impl> ExecutionSegment { pub fn execute_from_pc( &mut self, mut pc: u32, + segment_strategy: Option>, ) -> Result { let mut timestamp = self.chip_complex.memory_controller().timestamp(); let mut prev_backtrace: Option = None; @@ -226,7 +257,7 @@ impl> ExecutionSegment { #[cfg(feature = "bench-metrics")] self.update_instruction_metrics(pc, opcode, dsl_instr); - if self.should_segment() { + if 
self.should_segment(segment_strategy.clone()) { self.chip_complex .connector_chip_mut() .end(ExecutionState::new(pc, timestamp), None); @@ -269,26 +300,21 @@ impl> ExecutionSegment { /// Returns bool of whether to switch to next segment or not. This is called every clock cycle inside of Core trace generation. /// /// Default config: switch if any runtime chip height exceeds 1<<20 - 100 - fn should_segment(&mut self) -> bool { + fn should_segment(&mut self, segment_strategy: Option>) -> bool { + if segment_strategy.is_none() { + return false; + } + let segment_strategy = segment_strategy.unwrap(); // Avoid checking segment too often. if self.since_last_segment_check != SEGMENT_CHECK_INTERVAL { self.since_last_segment_check += 1; return false; } self.since_last_segment_check = 0; - let heights = self.chip_complex.dynamic_trace_heights(); - for (i, height) in heights.enumerate() { - if height > self.system_config().max_segment_len { - tracing::info!( - "Should segment because chip {} has height {}", - self.air_names[i], - height - ); - return true; - } - } - - false + segment_strategy.should_segment( + &self.chip_complex.current_trace_heights(), + &self.chip_complex.current_trace_cells(), + ) } pub fn current_trace_cells(&self) -> Vec { diff --git a/crates/vm/src/arch/vm.rs b/crates/vm/src/arch/vm.rs index da4c8fc09..b556ed06a 100644 --- a/crates/vm/src/arch/vm.rs +++ b/crates/vm/src/arch/vm.rs @@ -15,7 +15,10 @@ use openvm_stark_backend::{ use thiserror::Error; use tracing::info_span; -use super::{ExecutionError, VmComplexTraceHeights, VmConfig, CONNECTOR_AIR_ID, MERKLE_AIR_ID}; +use super::{ + DefaultSegmentationStrategy, ExecutionError, SegmentationStrategy, VmComplexTraceHeights, + VmConfig, CONNECTOR_AIR_ID, MERKLE_AIR_ID, +}; use crate::{ arch::segment::ExecutionSegment, system::{ @@ -58,6 +61,7 @@ impl From>> for Streams { pub struct VmExecutor { pub config: VC, pub overridden_heights: Option, + pub segmentation_strategy: Arc, _marker: PhantomData, } @@ -94,13 
+98,19 @@ where config: VC, overridden_heights: Option, ) -> Self { + let segmentation_strategy = Arc::new(DefaultSegmentationStrategy); Self { config, overridden_heights, + segmentation_strategy, _marker: Default::default(), } } + pub fn set_custom_segmentation_strategy(&mut self, strategy: Arc) { + self.segmentation_strategy = strategy; + } + pub fn continuation_enabled(&self) -> bool { self.config.system().continuation_enabled } @@ -128,7 +138,9 @@ where loop { // Used to add `segment` label to metrics let _span = info_span!("execute_segment", segment = segments.len()).entered(); - let state = metrics_span("execute_time_ms", || segment.execute_from_pc(pc))?; + let state = metrics_span("execute_time_ms", || { + segment.execute_from_pc(pc, Some(self.segmentation_strategy.clone())) + })?; pc = state.pc; if state.is_terminated { @@ -351,7 +363,9 @@ where if let Some(overridden_heights) = self.overridden_heights.as_ref() { segment.set_override_trace_heights(overridden_heights.clone()); } - metrics_span("execute_time_ms", || segment.execute_from_pc(pc_start))?; + metrics_span("execute_time_ms", || { + segment.execute_from_pc(pc_start, None) + })?; Ok(segment) } } From d9bf83f0c35b6e9832e3e0c54efd8ca0cda1f82f Mon Sep 17 00:00:00 2001 From: luffykai Date: Tue, 21 Jan 2025 14:45:08 -0800 Subject: [PATCH 2/7] make it compile --- benchmarks/src/bin/fib_e2e.rs | 5 +-- crates/sdk/tests/integration_test.rs | 1 - crates/vm/src/arch/segment.rs | 18 ++++++++-- crates/vm/src/arch/vm.rs | 2 +- crates/vm/src/metrics/mod.rs | 2 +- crates/vm/tests/integration_test.rs | 4 +-- extensions/native/circuit/src/extension.rs | 3 +- extensions/native/circuit/src/utils.rs | 4 +-- extensions/rv32im/circuit/src/extension.rs | 40 +++++++++++----------- 9 files changed, 42 insertions(+), 37 deletions(-) diff --git a/benchmarks/src/bin/fib_e2e.rs b/benchmarks/src/bin/fib_e2e.rs index 32cd9442a..d3f376d42 100644 --- a/benchmarks/src/bin/fib_e2e.rs +++ b/benchmarks/src/bin/fib_e2e.rs @@ -24,10 +24,7 
@@ async fn main() -> Result<()> { // Must be larger than RangeTupleCheckerAir.height == 524288 let max_segment_length = args.max_segment_length.unwrap_or(1_000_000); - let app_config = args.app_config(Rv32ImConfig::with_public_values_and_segment_len( - NUM_PUBLIC_VALUES, - max_segment_length, - )); + let app_config = args.app_config(Rv32ImConfig::with_public_values(NUM_PUBLIC_VALUES)); let agg_config = args.agg_config(); let sdk = Sdk; diff --git a/crates/sdk/tests/integration_test.rs b/crates/sdk/tests/integration_test.rs index 7fa7cd768..57c5ce96e 100644 --- a/crates/sdk/tests/integration_test.rs +++ b/crates/sdk/tests/integration_test.rs @@ -127,7 +127,6 @@ fn small_test_app_config(app_log_blowup: usize) -> AppConfig { .into(), app_vm_config: NativeConfig::new( SystemConfig::default() - .with_max_segment_len(200) .with_continuations() .with_public_values(NUM_PUB_VALUES), Native, diff --git a/crates/vm/src/arch/segment.rs b/crates/vm/src/arch/segment.rs index 702e56dfa..9c2d7dc38 100644 --- a/crates/vm/src/arch/segment.rs +++ b/crates/vm/src/arch/segment.rs @@ -43,18 +43,30 @@ pub trait SegmentationStrategy { } /// Default segmentation strategy: segment if any chip's height or cells exceed the limits. 
-pub struct DefaultSegmentationStrategy; +pub struct DefaultSegmentationStrategy { + max_segment_len: usize, + max_cells_per_chip_in_segment: usize, +} + +impl Default for DefaultSegmentationStrategy { + fn default() -> Self { + Self { + max_segment_len: DEFAULT_MAX_SEGMENT_LEN, + max_cells_per_chip_in_segment: DEFAULT_MAX_CELLS_PER_CHIP_IN_SEGMENT, + } + } +} impl SegmentationStrategy for DefaultSegmentationStrategy { fn should_segment(&self, trace_heights: &[usize], trace_cells: &[usize]) -> bool { for (i, &height) in trace_heights.iter().enumerate() { - if height > DEFAULT_MAX_SEGMENT_LEN { + if height > self.max_segment_len { tracing::info!("Should segment because chip {} has height {}", i, height); return true; } } for (i, &num_cells) in trace_cells.iter().enumerate() { - if num_cells > DEFAULT_MAX_CELLS_PER_CHIP_IN_SEGMENT { + if num_cells > self.max_cells_per_chip_in_segment { tracing::info!("Should segment because chip {} has {} cells", i, num_cells,); return true; } diff --git a/crates/vm/src/arch/vm.rs b/crates/vm/src/arch/vm.rs index b556ed06a..edeec5962 100644 --- a/crates/vm/src/arch/vm.rs +++ b/crates/vm/src/arch/vm.rs @@ -98,7 +98,7 @@ where config: VC, overridden_heights: Option, ) -> Self { - let segmentation_strategy = Arc::new(DefaultSegmentationStrategy); + let segmentation_strategy = Arc::new(DefaultSegmentationStrategy::default()); Self { config, overridden_heights, diff --git a/crates/vm/src/metrics/mod.rs b/crates/vm/src/metrics/mod.rs index c9815f5c3..e42a1cdbc 100644 --- a/crates/vm/src/metrics/mod.rs +++ b/crates/vm/src/metrics/mod.rs @@ -49,7 +49,7 @@ where let executor = self.chip_complex.inventory.get_executor(opcode).unwrap(); let opcode_name = executor.get_opcode_name(opcode.as_usize()); self.metrics.update_trace_cells( - &self.air_names, + &self.chip_complex.air_names(), self.current_trace_cells(), opcode_name, dsl_instr, diff --git a/crates/vm/tests/integration_test.rs b/crates/vm/tests/integration_test.rs index 
0468806ba..ce4bc1bd8 100644 --- a/crates/vm/tests/integration_test.rs +++ b/crates/vm/tests/integration_test.rs @@ -443,7 +443,7 @@ fn test_vm_continuations() { let n = 200000; let program = gen_continuation_test_program(n); let config = NativeConfig { - system: SystemConfig::new(3, MemoryConfig::default(), 0).with_max_segment_len(200000), + system: SystemConfig::new(3, MemoryConfig::default(), 0), native: Default::default(), } .with_continuations(); @@ -473,7 +473,7 @@ fn test_vm_continuations_recover_state() { let n = 2000; let program = gen_continuation_test_program(n); let config = NativeConfig { - system: SystemConfig::new(3, MemoryConfig::default(), 0).with_max_segment_len(500), + system: SystemConfig::new(3, MemoryConfig::default(), 0), native: Default::default(), } .with_continuations(); diff --git a/extensions/native/circuit/src/extension.rs b/extensions/native/circuit/src/extension.rs index 91a91cd8a..f73b4ab0e 100644 --- a/extensions/native/circuit/src/extension.rs +++ b/extensions/native/circuit/src/extension.rs @@ -67,8 +67,7 @@ impl NativeConfig { ..Default::default() }, num_public_values, - ) - .with_max_segment_len((1 << 24) - 100), + ), native: Default::default(), } } diff --git a/extensions/native/circuit/src/utils.rs b/extensions/native/circuit/src/utils.rs index 0c188998c..481c54b57 100644 --- a/extensions/native/circuit/src/utils.rs +++ b/extensions/native/circuit/src/utils.rs @@ -5,9 +5,7 @@ use openvm_stark_sdk::p3_baby_bear::BabyBear; use crate::{Native, NativeConfig}; pub fn execute_program(program: Program, input_stream: impl Into>) { - let system_config = SystemConfig::default() - .with_public_values(4) - .with_max_segment_len((1 << 25) - 100); + let system_config = SystemConfig::default().with_public_values(4); let config = NativeConfig::new(system_config, Native); let executor = VmExecutor::::new(config); diff --git a/extensions/rv32im/circuit/src/extension.rs b/extensions/rv32im/circuit/src/extension.rs index 9669b84a8..ff7c2e979 
100644 --- a/extensions/rv32im/circuit/src/extension.rs +++ b/extensions/rv32im/circuit/src/extension.rs @@ -83,17 +83,17 @@ impl Rv32IConfig { } } - pub fn with_public_values_and_segment_len(public_values: usize, segment_len: usize) -> Self { - let system = SystemConfig::default() - .with_continuations() - .with_public_values(public_values) - .with_max_segment_len(segment_len); - Self { - system, - base: Default::default(), - io: Default::default(), - } - } + // pub fn with_public_values_and_segment_len(public_values: usize, segment_len: usize) -> Self { + // let system = SystemConfig::default() + // .with_continuations() + // .with_public_values(public_values) + // .with_max_segment_len(segment_len); + // Self { + // system, + // base: Default::default(), + // io: Default::default(), + // } + // } } impl Rv32ImConfig { @@ -107,15 +107,15 @@ impl Rv32ImConfig { } } - pub fn with_public_values_and_segment_len(public_values: usize, segment_len: usize) -> Self { - let inner = Rv32IConfig::with_public_values_and_segment_len(public_values, segment_len); - Self { - system: inner.system, - base: inner.base, - mul: Default::default(), - io: Default::default(), - } - } + // pub fn with_public_values_and_segment_len(public_values: usize, segment_len: usize) -> Self { + // let inner = Rv32IConfig::with_public_values_and_segment_len(public_values, segment_len); + // Self { + // system: inner.system, + // base: inner.base, + // mul: Default::default(), + // io: Default::default(), + // } + // } } // ============ Extension Implementations ============ From e59218f901dd622760291474ed7e12c5ac320d27 Mon Sep 17 00:00:00 2001 From: luffykai Date: Tue, 21 Jan 2025 19:29:54 -0800 Subject: [PATCH 3/7] fix test --- benchmarks/src/bin/fib_e2e.rs | 1 + crates/sdk/examples/sdk_evm.rs | 1 + crates/sdk/src/lib.rs | 6 ++++- crates/sdk/src/prover/app.rs | 3 +-- crates/sdk/src/prover/mod.rs | 16 +++++++++++--- crates/sdk/src/prover/stark.rs | 4 ++-- crates/sdk/src/prover/vm/local.rs | 17 
+++++++++++--- crates/sdk/tests/integration_test.rs | 12 +++++++--- crates/vm/src/arch/segment.rs | 9 ++++++++ crates/vm/src/utils/stark_utils.rs | 33 ++++++++++++++++++++++++++++ 10 files changed, 88 insertions(+), 14 deletions(-) diff --git a/benchmarks/src/bin/fib_e2e.rs b/benchmarks/src/bin/fib_e2e.rs index d3f376d42..fbdd95d5a 100644 --- a/benchmarks/src/bin/fib_e2e.rs +++ b/benchmarks/src/bin/fib_e2e.rs @@ -55,6 +55,7 @@ async fn main() -> Result<()> { run_with_metric_collection("OUTPUT_PATH", || { let mut e2e_prover = ContinuationProver::new(&halo2_params_reader, app_pk, app_committed_exe, full_agg_pk); + e2e_prover.set_max_segment_len(max_segment_length); e2e_prover.set_program_name("fib_e2e"); let _proof = e2e_prover.generate_proof_for_evm(stdin); }); diff --git a/crates/sdk/examples/sdk_evm.rs b/crates/sdk/examples/sdk_evm.rs index 41b91c71c..ecefed411 100644 --- a/crates/sdk/examples/sdk_evm.rs +++ b/crates/sdk/examples/sdk_evm.rs @@ -105,6 +105,7 @@ fn main() -> Result<(), Box> { app_pk, app_committed_exe, agg_pk, + None, stdin, )?; diff --git a/crates/sdk/src/lib.rs b/crates/sdk/src/lib.rs index 67048a003..fa39e5d09 100644 --- a/crates/sdk/src/lib.rs +++ b/crates/sdk/src/lib.rs @@ -190,13 +190,17 @@ impl Sdk { app_pk: Arc>, app_exe: Arc, agg_pk: AggProvingKey, + max_segment_len: Option, inputs: StdIn, ) -> Result where VC::Executor: Chip, VC::Periphery: Chip, { - let e2e_prover = ContinuationProver::new(reader, app_pk, app_exe, agg_pk); + let mut e2e_prover = ContinuationProver::new(reader, app_pk, app_exe, agg_pk); + if let Some(max_segment_len) = max_segment_len { + e2e_prover.set_max_segment_len(max_segment_len); + } let proof = e2e_prover.generate_proof_for_evm(inputs); Ok(proof) } diff --git a/crates/sdk/src/prover/app.rs b/crates/sdk/src/prover/app.rs index 2a0123ccc..de22b3b37 100644 --- a/crates/sdk/src/prover/app.rs +++ b/crates/sdk/src/prover/app.rs @@ -17,8 +17,7 @@ use crate::{ #[derive(Getters)] pub struct AppProver { pub program_name: 
Option, - #[getset(get = "pub")] - app_prover: VmLocalProver, + pub app_prover: VmLocalProver, } impl AppProver { diff --git a/crates/sdk/src/prover/mod.rs b/crates/sdk/src/prover/mod.rs index 08c602f03..25b99beb7 100644 --- a/crates/sdk/src/prover/mod.rs +++ b/crates/sdk/src/prover/mod.rs @@ -1,6 +1,6 @@ use std::sync::Arc; -use openvm_circuit::arch::VmConfig; +use openvm_circuit::arch::{DefaultSegmentationStrategy, VmConfig}; use openvm_native_recursion::halo2::EvmProof; use openvm_stark_sdk::openvm_stark_backend::Chip; @@ -29,8 +29,8 @@ use crate::{ }; pub struct ContinuationProver { - stark_prover: StarkProver, - halo2_prover: Halo2Prover, + pub stark_prover: StarkProver, + pub halo2_prover: Halo2Prover, } impl ContinuationProver { @@ -59,6 +59,16 @@ impl ContinuationProver { self } + pub fn set_max_segment_len(&mut self, max_segment_len: usize) -> &mut Self { + self.stark_prover + .app_prover + .app_prover + .set_custom_segmentation_strategy(Arc::new( + DefaultSegmentationStrategy::new_with_max_segment_len(max_segment_len), + )); + self + } + pub fn generate_proof_for_evm(&self, input: StdIn) -> EvmProof where VC: VmConfig, diff --git a/crates/sdk/src/prover/stark.rs b/crates/sdk/src/prover/stark.rs index 87f5aa2cd..07f1e060f 100644 --- a/crates/sdk/src/prover/stark.rs +++ b/crates/sdk/src/prover/stark.rs @@ -10,8 +10,8 @@ use crate::{ }; pub struct StarkProver { - app_prover: AppProver, - agg_prover: AggStarkProver, + pub app_prover: AppProver, + pub agg_prover: AggStarkProver, } impl StarkProver { pub fn new( diff --git a/crates/sdk/src/prover/vm/local.rs b/crates/sdk/src/prover/vm/local.rs index a8f6889da..bc517ccf0 100644 --- a/crates/sdk/src/prover/vm/local.rs +++ b/crates/sdk/src/prover/vm/local.rs @@ -3,8 +3,8 @@ use std::{marker::PhantomData, sync::Arc}; use async_trait::async_trait; use openvm_circuit::{ arch::{ - hasher::poseidon2::vm_poseidon2_hasher, SingleSegmentVmExecutor, Streams, VirtualMachine, - VmComplexTraceHeights, VmConfig, + 
hasher::poseidon2::vm_poseidon2_hasher, segment::SegmentationStrategy, + SingleSegmentVmExecutor, Streams, VirtualMachine, VmComplexTraceHeights, VmConfig, }, system::{memory::tree::public_values::UserPublicValuesProof, program::trace::VmCommittedExe}, }; @@ -25,6 +25,7 @@ pub struct VmLocalProver> { pub pk: Arc>, pub committed_exe: Arc>, overridden_heights: Option, + segmentation_strategy: Option>, _marker: PhantomData, } @@ -34,6 +35,7 @@ impl> VmLocalProver pk, committed_exe, overridden_heights: None, + segmentation_strategy: None, _marker: PhantomData, } } @@ -47,6 +49,7 @@ impl> VmLocalProver pk, committed_exe, overridden_heights, + segmentation_strategy: None, _marker: PhantomData, } } @@ -55,6 +58,10 @@ impl> VmLocalProver self.overridden_heights = Some(overridden_heights); } + pub fn set_custom_segmentation_strategy(&mut self, strategy: Arc) { + self.segmentation_strategy = Some(strategy); + } + pub fn vm_config(&self) -> &VC { &self.pk.vm_config } @@ -74,11 +81,15 @@ where fn prove(&self, input: impl Into>>) -> ContinuationVmProof { assert!(self.pk.vm_config.system().continuation_enabled); let e = E::new(self.pk.fri_params); - let vm = VirtualMachine::new_with_overridden_trace_heights( + let mut vm = VirtualMachine::new_with_overridden_trace_heights( e, self.pk.vm_config.clone(), self.overridden_heights.clone(), ); + if let Some(segmentation_strategy) = self.segmentation_strategy.clone() { + vm.executor + .set_custom_segmentation_strategy(segmentation_strategy); + } let results = vm .execute_and_generate_with_cached_program(self.committed_exe.clone(), input) .unwrap(); diff --git a/crates/sdk/tests/integration_test.rs b/crates/sdk/tests/integration_test.rs index 57c5ce96e..416b94e73 100644 --- a/crates/sdk/tests/integration_test.rs +++ b/crates/sdk/tests/integration_test.rs @@ -3,8 +3,8 @@ use std::{borrow::Borrow, path::PathBuf, sync::Arc}; use openvm_build::GuestOptions; use openvm_circuit::{ arch::{ - hasher::poseidon2::vm_poseidon2_hasher, 
ExecutionError, SingleSegmentVmExecutor, - SystemConfig, VmConfig, VmExecutor, + hasher::poseidon2::vm_poseidon2_hasher, DefaultSegmentationStrategy, ExecutionError, + SingleSegmentVmExecutor, SystemConfig, VmConfig, VmExecutor, }, system::{memory::tree::public_values::UserPublicValuesProof, program::trace::VmCommittedExe}, }; @@ -153,7 +153,11 @@ fn test_public_values_and_leaf_verification() { let leaf_committed_exe = app_pk.leaf_committed_exe.clone(); let app_engine = BabyBearPoseidon2Engine::new(app_pk.app_vm_pk.fri_params); - let app_vm = VmExecutor::new(app_pk.app_vm_pk.vm_config.clone()); + + let segmentation_strategy = + Arc::new(DefaultSegmentationStrategy::new_with_max_segment_len(200)); + let mut app_vm = VmExecutor::new(app_pk.app_vm_pk.vm_config.clone()); + app_vm.set_custom_segmentation_strategy(segmentation_strategy); let app_vm_result = app_vm .execute_and_generate_with_cached_program(app_committed_exe.clone(), vec![]) .unwrap(); @@ -349,6 +353,7 @@ fn test_static_verifier_custom_pv_handler() { Arc::new(app_pk), app_committed_exe, agg_pk, + Some(200), StdIn::default(), ) .unwrap(); @@ -378,6 +383,7 @@ fn test_e2e_proof_generation_and_verification() { Arc::new(app_pk), app_committed_exe_for_test(app_log_blowup), agg_pk, + Some(200), StdIn::default(), ) .unwrap(); diff --git a/crates/vm/src/arch/segment.rs b/crates/vm/src/arch/segment.rs index 9c2d7dc38..542ee8269 100644 --- a/crates/vm/src/arch/segment.rs +++ b/crates/vm/src/arch/segment.rs @@ -57,6 +57,15 @@ impl Default for DefaultSegmentationStrategy { } } +impl DefaultSegmentationStrategy { + pub fn new_with_max_segment_len(max_segment_len: usize) -> Self { + Self { + max_segment_len, + max_cells_per_chip_in_segment: max_segment_len * 120, + } + } +} + impl SegmentationStrategy for DefaultSegmentationStrategy { fn should_segment(&self, trace_heights: &[usize], trace_cells: &[usize]) -> bool { for (i, &height) in trace_heights.iter().enumerate() { diff --git a/crates/vm/src/utils/stark_utils.rs 
b/crates/vm/src/utils/stark_utils.rs index fbbdc3529..e0bc40c3d 100644 --- a/crates/vm/src/utils/stark_utils.rs +++ b/crates/vm/src/utils/stark_utils.rs @@ -1,3 +1,5 @@ +use std::sync::Arc; + use openvm_instructions::{exe::VmExe, program::Program}; use openvm_stark_backend::{ config::{StarkGenericConfig, Val}, @@ -15,6 +17,7 @@ use openvm_stark_sdk::{ }; use crate::arch::{ + segment::DefaultSegmentationStrategy, vm::{VirtualMachine, VmExecutor}, Streams, VmConfig, VmMemoryState, }; @@ -54,6 +57,36 @@ where final_memory } +/// Executes the VM and returns the final memory state. +pub fn air_test_with_custom_segmentation( + config: VC, + exe: impl Into>, + input: impl Into>, + min_segments: usize, + max_segment_len: usize, +) -> Option> +where + VC: VmConfig, + VC::Executor: Chip, + VC::Periphery: Chip, +{ + setup_tracing(); + let engine = BabyBearPoseidon2Engine::new(FriParameters::standard_fast()); + let mut vm = VirtualMachine::new(engine, config); + vm.executor.set_custom_segmentation_strategy(Arc::new( + DefaultSegmentationStrategy::new_with_max_segment_len(max_segment_len), + )); + let pk = vm.keygen(); + let mut result = vm.execute_and_generate(exe, input).unwrap(); + let final_memory = result.final_memory.take(); + let proofs = vm.prove(&pk, result); + + assert!(proofs.len() >= min_segments); + vm.verify(&pk.get_vk(), proofs) + .expect("segment proofs should verify"); + final_memory +} + // TODO[jpw]: this should be deleted once tests switch to new API /// Generates the VM STARK circuit, in the form of AIRs and traces, but does not /// do any proving. Output is the payload of everything the prover needs. 
From eb571c9de61cbda4e651351270dd7e5be4702d5d Mon Sep 17 00:00:00 2001 From: luffykai Date: Tue, 21 Jan 2025 21:00:04 -0800 Subject: [PATCH 4/7] fix --- crates/cli/src/commands/prove.rs | 10 ++++++++-- crates/vm/tests/integration_test.rs | 18 +++++++++++------- 2 files changed, 19 insertions(+), 9 deletions(-) diff --git a/crates/cli/src/commands/prove.rs b/crates/cli/src/commands/prove.rs index bf263d2c1..fded55894 100644 --- a/crates/cli/src/commands/prove.rs +++ b/crates/cli/src/commands/prove.rs @@ -84,8 +84,14 @@ impl ProveCmd { let agg_pk = read_agg_pk_from_file(DEFAULT_AGG_PK_PATH).map_err(|e| { eyre::eyre!("Failed to read aggregation proving key: {}\nPlease run 'cargo openvm setup' first", e) })?; - let evm_proof = - Sdk.generate_evm_proof(&params_reader, app_pk, committed_exe, agg_pk, input)?; + let evm_proof = Sdk.generate_evm_proof( + &params_reader, + app_pk, + committed_exe, + agg_pk, + None, + input, + )?; write_evm_proof_to_file(evm_proof, output)?; } } diff --git a/crates/vm/tests/integration_test.rs b/crates/vm/tests/integration_test.rs index ce4bc1bd8..3dd5fb9d7 100644 --- a/crates/vm/tests/integration_test.rs +++ b/crates/vm/tests/integration_test.rs @@ -4,10 +4,10 @@ use derive_more::derive::From; use openvm_circuit::{ arch::{ hasher::{poseidon2::vm_poseidon2_hasher, Hasher}, - ChipId, ExecutionSegment, ExitCode, MemoryConfig, SingleSegmentVmExecutor, Streams, - SystemConfig, SystemExecutor, SystemPeriphery, SystemTraceHeights, VirtualMachine, - VmChipComplex, VmComplexTraceHeights, VmConfig, VmExecutorResult, VmInventoryError, - VmInventoryTraceHeights, + ChipId, DefaultSegmentationStrategy, ExecutionSegment, ExitCode, MemoryConfig, + SingleSegmentVmExecutor, Streams, SystemConfig, SystemExecutor, SystemPeriphery, + SystemTraceHeights, VirtualMachine, VmChipComplex, VmComplexTraceHeights, VmConfig, + VmExecutorResult, VmInventoryError, VmInventoryTraceHeights, }, derive::{AnyEnum, InstructionExecutor, VmConfig}, system::{ @@ -17,7 +17,7 @@ use
openvm_circuit::{ }, program::trace::VmCommittedExe, }, - utils::{air_test, air_test_with_min_segments}, + utils::{air_test, air_test_with_custom_segmentation, air_test_with_min_segments}, }; use openvm_circuit_primitives_derive::{Chip, ChipUsageGetter}; use openvm_instructions::{ @@ -459,7 +459,8 @@ fn test_vm_continuations() { }; let memory_dimensions = config.system.memory_config.memory_dimensions(); - let final_state = air_test_with_min_segments(config, program, vec![], 3).unwrap(); + let final_state = + air_test_with_custom_segmentation(config, program, vec![], 3, 200000).unwrap(); let hasher = vm_poseidon2_hasher(); let num_public_values = 8; let pv_proof = @@ -478,7 +479,10 @@ fn test_vm_continuations_recover_state() { } .with_continuations(); let engine = BabyBearPoseidon2Engine::new(FriParameters::standard_fast()); - let vm = VirtualMachine::new(engine, config.clone()); + let mut vm = VirtualMachine::new(engine, config.clone()); + vm.executor.set_custom_segmentation_strategy(Arc::new( + DefaultSegmentationStrategy::new_with_max_segment_len(500), + )); let pk = vm.keygen(); let segments = vm .executor From 4cf96652e6a382cdd4e57452220aa75f53d8b7be Mon Sep 17 00:00:00 2001 From: luffykai Date: Wed, 22 Jan 2025 04:34:01 -0800 Subject: [PATCH 5/7] fix --- crates/vm/src/arch/segment.rs | 34 ++++++++++++++++++---- extensions/native/circuit/src/utils.rs | 11 +++++-- extensions/rv32im/circuit/src/extension.rs | 22 -------------- 3 files changed, 38 insertions(+), 29 deletions(-) diff --git a/crates/vm/src/arch/segment.rs b/crates/vm/src/arch/segment.rs index 542ee8269..21610ff09 100644 --- a/crates/vm/src/arch/segment.rs +++ b/crates/vm/src/arch/segment.rs @@ -39,7 +39,12 @@ const DEFAULT_MAX_SEGMENT_LEN: usize = (1 << 22) - 100; const DEFAULT_MAX_CELLS_PER_CHIP_IN_SEGMENT: usize = DEFAULT_MAX_SEGMENT_LEN * 120; pub trait SegmentationStrategy { - fn should_segment(&self, trace_heights: &[usize], trace_cells: &[usize]) -> bool; + fn should_segment( + &self, + 
air_names: &[String], + trace_heights: &[usize], + trace_cells: &[usize], + ) -> bool; } /// Default segmentation strategy: segment if any chip's height or cells exceed the limits. @@ -67,16 +72,31 @@ impl DefaultSegmentationStrategy { } impl SegmentationStrategy for DefaultSegmentationStrategy { - fn should_segment(&self, trace_heights: &[usize], trace_cells: &[usize]) -> bool { + fn should_segment( + &self, + air_names: &[String], + trace_heights: &[usize], + trace_cells: &[usize], + ) -> bool { for (i, &height) in trace_heights.iter().enumerate() { if height > self.max_segment_len { - tracing::info!("Should segment because chip {} has height {}", i, height); + tracing::info!( + "Should segment because chip {} (name: {}) has height {}", + i, + air_names[i], + height + ); return true; } } for (i, &num_cells) in trace_cells.iter().enumerate() { if num_cells > self.max_cells_per_chip_in_segment { - tracing::info!("Should segment because chip {} has {} cells", i, num_cells,); + tracing::info!( + "Should segment because chip {} (name: {}) has {} cells", + i, + air_names[i], + num_cells + ); return true; } } @@ -333,7 +353,11 @@ impl> ExecutionSegment { } self.since_last_segment_check = 0; segment_strategy.should_segment( - &self.chip_complex.current_trace_heights(), + &self.chip_complex.air_names(), + &self + .chip_complex + .dynamic_trace_heights() + .collect::>(), &self.chip_complex.current_trace_cells(), ) } diff --git a/extensions/native/circuit/src/utils.rs b/extensions/native/circuit/src/utils.rs index 481c54b57..fd38719d6 100644 --- a/extensions/native/circuit/src/utils.rs +++ b/extensions/native/circuit/src/utils.rs @@ -1,4 +1,8 @@ -use openvm_circuit::arch::{Streams, SystemConfig, VmExecutor}; +use std::sync::Arc; + +use openvm_circuit::arch::{ + segment::DefaultSegmentationStrategy, Streams, SystemConfig, VmExecutor, +}; use openvm_instructions::program::Program; use openvm_stark_sdk::p3_baby_bear::BabyBear; @@ -7,7 +11,10 @@ use crate::{Native, 
NativeConfig}; pub fn execute_program(program: Program, input_stream: impl Into>) { let system_config = SystemConfig::default().with_public_values(4); let config = NativeConfig::new(system_config, Native); - let executor = VmExecutor::::new(config); + let mut executor = VmExecutor::::new(config); + executor.set_custom_segmentation_strategy(Arc::new( + DefaultSegmentationStrategy::new_with_max_segment_len(500), + )); executor.execute(program, input_stream).unwrap(); } diff --git a/extensions/rv32im/circuit/src/extension.rs b/extensions/rv32im/circuit/src/extension.rs index ff7c2e979..57ea0538f 100644 --- a/extensions/rv32im/circuit/src/extension.rs +++ b/extensions/rv32im/circuit/src/extension.rs @@ -82,18 +82,6 @@ impl Rv32IConfig { io: Default::default(), } } - - // pub fn with_public_values_and_segment_len(public_values: usize, segment_len: usize) -> Self { - // let system = SystemConfig::default() - // .with_continuations() - // .with_public_values(public_values) - // .with_max_segment_len(segment_len); - // Self { - // system, - // base: Default::default(), - // io: Default::default(), - // } - // } } impl Rv32ImConfig { @@ -106,16 +94,6 @@ impl Rv32ImConfig { io: Default::default(), } } - - // pub fn with_public_values_and_segment_len(public_values: usize, segment_len: usize) -> Self { - // let inner = Rv32IConfig::with_public_values_and_segment_len(public_values, segment_len); - // Self { - // system: inner.system, - // base: inner.base, - // mul: Default::default(), - // io: Default::default(), - // } - // } } // ============ Extension Implementations ============ From 7bf930dd6b7dfb41f198f649cd5d47e966c4f905 Mon Sep 17 00:00:00 2001 From: luffykai Date: Wed, 22 Jan 2025 04:35:17 -0800 Subject: [PATCH 6/7] fix --- extensions/native/circuit/src/utils.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/extensions/native/circuit/src/utils.rs b/extensions/native/circuit/src/utils.rs index fd38719d6..f5a18887a 100644 --- 
a/extensions/native/circuit/src/utils.rs +++ b/extensions/native/circuit/src/utils.rs @@ -13,7 +13,7 @@ pub fn execute_program(program: Program, input_stream: impl Into::new(config); executor.set_custom_segmentation_strategy(Arc::new( - DefaultSegmentationStrategy::new_with_max_segment_len(500), + DefaultSegmentationStrategy::new_with_max_segment_len((1 << 25) - 100), )); executor.execute(program, input_stream).unwrap(); From edd0f0602941e956440bc3b5d9cde79914e6389e Mon Sep 17 00:00:00 2001 From: luffykai Date: Wed, 22 Jan 2025 14:06:26 -0800 Subject: [PATCH 7/7] serde works --- crates/vm/src/arch/config.rs | 23 +++++++++++++++++++++-- crates/vm/src/arch/segment.rs | 3 ++- 2 files changed, 23 insertions(+), 3 deletions(-) diff --git a/crates/vm/src/arch/config.rs b/crates/vm/src/arch/config.rs index cb60ca1e7..ab612b9bc 100644 --- a/crates/vm/src/arch/config.rs +++ b/crates/vm/src/arch/config.rs @@ -1,3 +1,5 @@ +use std::sync::Arc; + use derive_new::new; use openvm_circuit::system::memory::MemoryTraceHeights; use openvm_instructions::program::DEFAULT_MAX_NUM_PUBLIC_VALUES; @@ -13,8 +15,9 @@ pub use super::testing::{ POSEIDON2_DIRECT_BUS, RANGE_TUPLE_CHECKER_BUS, READ_INSTRUCTION_BUS, }; use super::{ - AnyEnum, InstructionExecutor, SystemComplex, SystemExecutor, SystemPeriphery, VmChipComplex, - VmInventoryError, PUBLIC_VALUES_AIR_ID, + segment::SegmentationStrategy, AnyEnum, DefaultSegmentationStrategy, InstructionExecutor, + SystemComplex, SystemExecutor, SystemPeriphery, VmChipComplex, VmInventoryError, + PUBLIC_VALUES_AIR_ID, }; use crate::system::memory::BOUNDARY_AIR_OFFSET; @@ -63,6 +66,15 @@ impl Default for MemoryConfig { } } +#[derive(Debug, Clone)] +pub struct SegmentationStrategyArc(Arc); + +impl Default for SegmentationStrategyArc { + fn default() -> Self { + Self(Arc::new(DefaultSegmentationStrategy::default())) + } +} + /// System-level configuration for the virtual machine. 
Contains all configuration parameters that /// are managed by the architecture, including configuration for continuations support. #[derive(Debug, Clone, Serialize, Deserialize)] @@ -88,6 +100,11 @@ pub struct SystemConfig { /// Whether to collect detailed profiling metrics. /// **Warning**: this slows down the runtime. pub profiling: bool, + /// Segmentation strategy + /// This field is skipped in serde as it's only used in execution and + /// not needed after any serialize/deserialize. + #[serde(skip)] + pub segmentation_strategy: SegmentationStrategyArc, } #[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)] @@ -102,12 +119,14 @@ impl SystemConfig { memory_config: MemoryConfig, num_public_values: usize, ) -> Self { + let segmentation_strategy = SegmentationStrategyArc::default(); Self { max_constraint_degree, continuation_enabled: false, memory_config, num_public_values, profiling: false, + segmentation_strategy, } } diff --git a/crates/vm/src/arch/segment.rs b/crates/vm/src/arch/segment.rs index 21610ff09..b42e4e38a 100644 --- a/crates/vm/src/arch/segment.rs +++ b/crates/vm/src/arch/segment.rs @@ -38,7 +38,7 @@ const DEFAULT_MAX_SEGMENT_LEN: usize = (1 << 22) - 100; // its trace width is 36 and its after challenge trace width is 80. const DEFAULT_MAX_CELLS_PER_CHIP_IN_SEGMENT: usize = DEFAULT_MAX_SEGMENT_LEN * 120; -pub trait SegmentationStrategy { +pub trait SegmentationStrategy: std::fmt::Debug { fn should_segment( &self, air_names: &[String], @@ -48,6 +48,7 @@ pub trait SegmentationStrategy { } /// Default segmentation strategy: segment if any chip's height or cells exceed the limits. +#[derive(Debug)] pub struct DefaultSegmentationStrategy { max_segment_len: usize, max_cells_per_chip_in_segment: usize,