Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[fix] support execute until segment #1271

Merged
merged 3 commits into from
Jan 23, 2025
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
131 changes: 81 additions & 50 deletions crates/vm/src/arch/vm.rs
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,17 @@ pub struct VmExecutorResult<SC: StarkGenericConfig> {
pub final_memory: Option<VmMemoryState<Val<SC>>>,
}

pub struct VmExecutorNextSegmentState<F: PrimeField32> {
pub memory: MemoryImage<F>,
pub input: Streams<F>,
pub pc: u32,
}

pub struct VmExecutorOneSegmentResult<F: PrimeField32, VC: VmConfig<F>> {
pub segment: ExecutionSegment<F, VC>,
pub next_state: Option<VmExecutorNextSegmentState<F>>,
}

impl<F, VC> VmExecutor<F, VC>
where
F: PrimeField32,
Expand Down Expand Up @@ -110,70 +121,90 @@ where
exe: impl Into<VmExe<F>>,
input: impl Into<Streams<F>>,
) -> Result<Vec<ExecutionSegment<F, VC>>, ExecutionError> {
let mem_config = self.config.system().memory_config;
let exe = exe.into();
let streams = input.into();
let mut streams = input.into();
let mut segments = vec![];
let mem_config = self.config.system().memory_config;
let mut memory = AddressMap::from_iter(
mem_config.as_offset,
1 << mem_config.as_height,
1 << mem_config.pointer_max_bits,
exe.init_memory.clone(),
);
let mut pc = exe.pc_start;
let mut segment_idx = 0;

loop {
let _span = info_span!("execute_segment", segment = segment_idx).entered();
let one_segment_result =
self.execute_until_segment(exe.clone(), memory, streams, pc)?;
segments.push(one_segment_result.segment);
if one_segment_result.next_state.is_none() {
break;
}
let next_state = one_segment_result.next_state.unwrap();
memory = next_state.memory;
pc = next_state.pc;
streams = next_state.input;
segment_idx += 1;
}
tracing::debug!("Number of continuation segments: {}", segments.len());

Ok(segments)
}

/// Executes a program until a segmentation happens.
/// Returns the last segment and the vm state for next segment.
/// This is so that the tracegen and proving of this segment can be immediately started (on a separate machine).
pub fn execute_until_segment(
&self,
exe: impl Into<VmExe<F>>,
memory: MemoryImage<F>,
input: impl Into<Streams<F>>,
pc: u32,
) -> Result<VmExecutorOneSegmentResult<F, VC>, ExecutionError> {
let exe = exe.into();
let streams = input.into();
let mut segment = ExecutionSegment::new(
&self.config,
exe.program.clone(),
streams,
Some(AddressMap::from_iter(
mem_config.as_offset,
1 << mem_config.as_height,
1 << mem_config.pointer_max_bits,
exe.init_memory.clone(),
)),
Some(memory),
exe.fn_bounds.clone(),
);
if let Some(overridden_heights) = self.overridden_heights.as_ref() {
segment.set_override_trace_heights(overridden_heights.clone());
}
let mut pc = exe.pc_start;

loop {
// Used to add `segment` label to metrics
let _span = info_span!("execute_segment", segment = segments.len()).entered();
let state = metrics_span("execute_time_ms", || segment.execute_from_pc(pc))?;
pc = state.pc;
let state = metrics_span("execute_time_ms", || segment.execute_from_pc(pc))?;

if state.is_terminated {
break;
}

assert!(
self.continuation_enabled(),
"multiple segments require to enable continuations"
);

assert_eq!(
pc,
segment.chip_complex.connector_chip().boundary_states[1]
.unwrap()
.pc
);

let final_memory = mem::take(&mut segment.final_memory)
.expect("final memory should be set in continuations segment");
let streams = segment.chip_complex.take_streams();

segments.push(segment);

segment = ExecutionSegment::new(
&self.config,
exe.program.clone(),
streams,
Some(final_memory),
exe.fn_bounds.clone(),
);
if let Some(overridden_heights) = self.overridden_heights.as_ref() {
segment.set_override_trace_heights(overridden_heights.clone());
}
if state.is_terminated {
return Ok(VmExecutorOneSegmentResult {
segment,
next_state: None,
});
}
segments.push(segment);
tracing::debug!("Number of continuation segments: {}", segments.len());

Ok(segments)
assert!(
self.continuation_enabled(),
"multiple segments require to enable continuations"
);
assert_eq!(
state.pc,
segment.chip_complex.connector_chip().boundary_states[1]
.unwrap()
.pc
);
let final_memory = mem::take(&mut segment.final_memory)
.expect("final memory should be set in continuations segment");
let streams = segment.chip_complex.take_streams();
Ok(VmExecutorOneSegmentResult {
segment,
next_state: Some(VmExecutorNextSegmentState {
memory: final_memory,
input: streams,
pc: state.pc,
}),
})
}

pub fn execute(
Expand Down