diff --git a/crates/vm/src/arch/vm.rs b/crates/vm/src/arch/vm.rs index 9369d7b4b..ca7c98a64 100644 --- a/crates/vm/src/arch/vm.rs +++ b/crates/vm/src/arch/vm.rs @@ -74,6 +74,17 @@ pub struct VmExecutorResult { pub final_memory: Option>>, } +pub struct VmExecutorNextSegmentState { + pub memory: MemoryImage, + pub input: Streams, + pub pc: u32, +} + +pub struct VmExecutorOneSegmentResult> { + pub segment: ExecutionSegment, + pub next_state: Option>, +} + impl VmExecutor where F: PrimeField32, @@ -110,70 +121,90 @@ where exe: impl Into>, input: impl Into>, ) -> Result>, ExecutionError> { + let mem_config = self.config.system().memory_config; let exe = exe.into(); - let streams = input.into(); + let mut streams = input.into(); let mut segments = vec![]; - let mem_config = self.config.system().memory_config; + let mut memory = AddressMap::from_iter( + mem_config.as_offset, + 1 << mem_config.as_height, + 1 << mem_config.pointer_max_bits, + exe.init_memory.clone(), + ); + let mut pc = exe.pc_start; + let mut segment_idx = 0; + + loop { + let _span = info_span!("execute_segment", segment = segment_idx).entered(); + let one_segment_result = + self.execute_until_segment(exe.clone(), memory, streams, pc)?; + segments.push(one_segment_result.segment); + if one_segment_result.next_state.is_none() { + break; + } + let next_state = one_segment_result.next_state.unwrap(); + memory = next_state.memory; + pc = next_state.pc; + streams = next_state.input; + segment_idx += 1; + } + tracing::debug!("Number of continuation segments: {}", segments.len()); + + Ok(segments) + } + + /// Executes a program until a segmentation happens. + /// Returns the last segment and the vm state for next segment. + /// This is so that the tracegen and proving of this segment can be immediately started (on a separate machine). + pub fn execute_until_segment( + &self, + exe: impl Into>, + memory: MemoryImage, + input: impl Into>, + pc: u32, + ) -> Result, ExecutionError> { + let exe = exe.into(); + let streams = input.into(); let mut segment = ExecutionSegment::new( &self.config, exe.program.clone(), streams, - Some(AddressMap::from_iter( - mem_config.as_offset, - 1 << mem_config.as_height, - 1 << mem_config.pointer_max_bits, - exe.init_memory.clone(), - )), + Some(memory), exe.fn_bounds.clone(), ); if let Some(overridden_heights) = self.overridden_heights.as_ref() { segment.set_override_trace_heights(overridden_heights.clone()); } - let mut pc = exe.pc_start; - - loop { - // Used to add `segment` label to metrics - let _span = info_span!("execute_segment", segment = segments.len()).entered(); - let state = metrics_span("execute_time_ms", || segment.execute_from_pc(pc))?; - pc = state.pc; + let state = metrics_span("execute_time_ms", || segment.execute_from_pc(pc))?; - if state.is_terminated { - break; - } - - assert!( - self.continuation_enabled(), - "multiple segments require to enable continuations" - ); - - assert_eq!( - pc, - segment.chip_complex.connector_chip().boundary_states[1] - .unwrap() - .pc - ); - - let final_memory = mem::take(&mut segment.final_memory) - .expect("final memory should be set in continuations segment"); - let streams = segment.chip_complex.take_streams(); - - segments.push(segment); - - segment = ExecutionSegment::new( - &self.config, - exe.program.clone(), - streams, - Some(final_memory), - exe.fn_bounds.clone(), - ); - if let Some(overridden_heights) = self.overridden_heights.as_ref() { - segment.set_override_trace_heights(overridden_heights.clone()); - } + if state.is_terminated { + return Ok(VmExecutorOneSegmentResult { + segment, + next_state: None, + }); } - segments.push(segment); - tracing::debug!("Number of continuation segments: {}", segments.len()); - Ok(segments) + assert!( + self.continuation_enabled(), + "multiple segments require to enable continuations" + ); + assert_eq!( + state.pc, + segment.chip_complex.connector_chip().boundary_states[1] + .unwrap() + .pc + ); + let final_memory = mem::take(&mut segment.final_memory) + .expect("final memory should be set in continuations segment"); + let streams = segment.chip_complex.take_streams(); + Ok(VmExecutorOneSegmentResult { + segment, + next_state: Some(VmExecutorNextSegmentState { + memory: final_memory, + input: streams, + pc: state.pc, + }), + }) } pub fn execute(