Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

chore: clean up for loops and breaks #1234

Merged
merged 17 commits into from
Jan 21, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion crates/sdk/src/verifier/common/non_leaf.rs
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,8 @@ impl<C: Config> NonLeafVerifierVariables<C> {
let pvs = VmVerifierPvs::<Felt<C::F>>::uninit(builder);
let leaf_verifier_commit = array::from_fn(|_| builder.uninit());

builder.range(0, proofs.len()).for_each(|i, builder| {
builder.range(0, proofs.len()).for_each(|i_vec, builder| {
let i = i_vec[0];
let proof = builder.get(proofs, i);
assert_required_air_for_agg_vm_present(builder, &proof);
let proof_vm_pvs = self.verify_internal_or_leaf_verifier_proof(builder, &proof);
Expand Down
3 changes: 2 additions & 1 deletion crates/sdk/src/verifier/leaf/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,8 @@ impl LeafVmVerifierConfig {

builder.cycle_tracker_start("VerifyProofs");
let pvs = VmVerifierPvs::<Felt<F>>::uninit(&mut builder);
builder.range(0, proofs.len()).for_each(|i, builder| {
builder.range(0, proofs.len()).for_each(|i_vec, builder| {
let i = i_vec[0];
let proof = builder.get(&proofs, i);
assert_required_air_for_app_vm_present(builder, &proof);
StarkVerifier::verify::<DuplexChallengerVariable<C>>(
Expand Down
4 changes: 2 additions & 2 deletions crates/sdk/src/verifier/leaf/vars.rs
Original file line number Diff line number Diff line change
Expand Up @@ -44,10 +44,10 @@ impl Hintable<C> for UserPublicValuesRootProof<F> {
fn read(builder: &mut Builder<C>) -> Self::HintVariable {
let len = builder.hint_var();
let sibling_hashes = builder.array(len);
builder.range(0, len).for_each(|i, builder| {
builder.range(0, len).for_each(|i_vec, builder| {
// FIXME: add hint support for slices.
let hash = array::from_fn(|_| builder.hint_felt());
builder.set_value(&sibling_hashes, i, hash);
builder.set_value(&sibling_hashes, i_vec[0], hash);
});
let public_values_commit = array::from_fn(|_| builder.hint_felt());
Self::HintVariable {
Expand Down
13 changes: 1 addition & 12 deletions extensions/native/compiler/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -48,15 +48,4 @@ In **static programs**, only constant branches are allowed.
When both `start` and `end` of a loop are constant, the loop is a constant loop. The loop body will be unrolled. This
optimization saves 1 instruction per iteration.

In **static programs**, only constant loops are allowed.

## Break Support
**!!Attention!!**: Break support for constant loops is not perfect. It comes with some restrictions that
developers need to be aware of:

- If you want to use `break` in a possibly constant loop, you need to use `.for_each_may_break` instead of `.for_each`.
- If you want to use `break` in a branch inside a loop, you need to use `.then_may_break`/`.then_or_else_may_break`
instead of `.then`/`.then_or_else`.
- Inside a **constant loop**, you cannot use a **non-constant branch** to break.


In **static programs**, only constant loops are allowed.
10 changes: 5 additions & 5 deletions extensions/native/compiler/derive/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -169,24 +169,24 @@ pub fn hintable_derive(input: TokenStream) -> TokenStream {
}
}

struct CompileZipArgs {
struct IterZipArgs {
builder: Expr,
args: Punctuated<Expr, Token![,]>,
}

impl Parse for CompileZipArgs {
impl Parse for IterZipArgs {
fn parse(input: ParseStream) -> syn::Result<Self> {
let builder = input.parse()?;
let _: Token![,] = input.parse()?;
let args = Punctuated::parse_terminated(input)?;

Ok(CompileZipArgs { builder, args })
Ok(IterZipArgs { builder, args })
}
}

#[proc_macro]
pub fn compile_zip(input: TokenStream) -> TokenStream {
let CompileZipArgs { builder, args } = parse_macro_input!(input as CompileZipArgs);
pub fn iter_zip(input: TokenStream) -> TokenStream {
let IterZipArgs { builder, args } = parse_macro_input!(input as IterZipArgs);
let array_elements = args.iter().map(|arg| {
quote! {
Box::new(#arg.clone()) as Box<dyn ArrayLike<_>>
Expand Down
168 changes: 2 additions & 166 deletions extensions/native/compiler/src/asm/compiler.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
use alloc::{collections::BTreeMap, vec};
use std::collections::BTreeSet;

use openvm_circuit::arch::instructions::instruction::DebugInfo;
use openvm_stark_backend::p3_field::{ExtensionField, Field, PrimeField32, TwoAdicField};
Expand Down Expand Up @@ -29,10 +28,6 @@ pub(crate) const STACK_TOP: i32 = HEAP_START_ADDRESS - 64;
// #[derive(Debug, Clone, Default)]
pub struct AsmCompiler<F, EF> {
basic_blocks: Vec<BasicBlock<F, EF>>,
break_label: Option<F>,
break_label_map: BTreeMap<F, F>,
break_counter: usize,
contains_break: BTreeSet<F>,
function_labels: BTreeMap<String, F>,
trap_label: F,
word_size: usize,
Expand Down Expand Up @@ -74,25 +69,12 @@ impl<F: PrimeField32 + TwoAdicField, EF: ExtensionField<F> + TwoAdicField> AsmCo
/// Creates a compiler for the given `word_size`.
///
/// The compiler starts out with one freshly constructed basic block, no
/// registered function labels, no active break label and a break counter of
/// zero. The trap label is initialised to `F::ONE`.
pub fn new(word_size: usize) -> Self {
    let entry_block = BasicBlock::new();
    Self {
        word_size,
        trap_label: F::ONE,
        function_labels: BTreeMap::new(),
        basic_blocks: vec![entry_block],
        // Break bookkeeping: nothing is allocated or pending yet.
        break_counter: 0,
        break_label: None,
        break_label_map: BTreeMap::new(),
        contains_break: BTreeSet::new(),
    }
}

/// Allocates a fresh break label and makes it the active one.
///
/// The label is the current value of `break_counter` converted to a field
/// element; the counter is then advanced so the next call yields a different
/// label. The returned label is also stored in `self.break_label`.
pub fn new_break_label(&mut self) -> F {
    let fresh = F::from_canonical_usize(self.break_counter);
    self.break_counter += 1;
    self.break_label = Some(fresh);
    fresh
}

/// Builds the operations into assembly instructions.
pub fn build(&mut self, operations: TracedVec<DslIr<AsmConfig<F, EF>>>) {
if self.block_label().is_zero() {
Expand Down Expand Up @@ -365,22 +347,6 @@ impl<F: PrimeField32 + TwoAdicField, EF: ExtensionField<F> + TwoAdicField> AsmCo
);
}
}
DslIr::Break => {
let label = self.break_label.expect("No break label set");
let current_block = self.block_label();
self.contains_break.insert(current_block);
self.push(AsmInstruction::Break(label), debug_info);
}
DslIr::For(start, end, step_size, loop_var, block) => {
let for_compiler = ForCompiler {
compiler: self,
start,
end,
step_size,
loop_var,
};
for_compiler.for_each(move |_, builder| builder.build(block), debug_info);
}
DslIr::ZipFor(starts, end0, step_sizes, loop_vars, block) => {
let zip_for_compiler = ZipForCompiler {
compiler: self,
Expand Down Expand Up @@ -815,7 +781,8 @@ impl<F: PrimeField32 + TwoAdicField, EF: ExtensionField<F> + TwoAdicField> IfCom
}
}

// Zipped for loop -- loop extends over the first entry in starts and ends
// Zipped for loop -- loop extends over the first entry in starts and end0
// ATTENTION: starting with starts[0] > end0 will lead to undefined behavior.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is a bit scary in the VM case: is there any way we could prevent this at compile time?
Is it basically only used to iterate through arrays and in the range(0, ?) cases (where it's safe)?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Correct, only via builder.zip() and builder.range()

pub struct ZipForCompiler<'a, F: Field, EF> {
compiler: &'a mut AsmCompiler<F, EF>,
starts: Vec<RVar<F>>,
Expand Down Expand Up @@ -856,8 +823,6 @@ impl<F: PrimeField32 + TwoAdicField, EF: ExtensionField<F> + TwoAdicField>
});

let loop_call_label = self.compiler.block_label();
let break_label = self.compiler.new_break_label();
self.compiler.break_label = Some(break_label);

self.compiler.basic_block();
let loop_label = self.compiler.block_label();
Expand Down Expand Up @@ -898,135 +863,6 @@ impl<F: PrimeField32 + TwoAdicField, EF: ExtensionField<F> + TwoAdicField>
.push_to_block(loop_call_label, instr, debug_info.clone());

self.compiler.basic_block();
let label = self.compiler.block_label();
self.compiler.break_label_map.insert(break_label, label);

for block in self.compiler.contains_break.iter() {
for instruction in self.compiler.basic_blocks[block.as_canonical_u32() as usize]
.0
.iter_mut()
{
if let AsmInstruction::Break(l) = instruction {
if *l == break_label {
*instruction = AsmInstruction::j(label);
}
}
}
}
}
}

/// A builder for a for loop.
///
/// Emits the assembly for a single-variable counting loop: the loop variable
/// is initialised from `start`, incremented by `step_size` after every
/// execution of the body, and the body is re-entered for as long as the loop
/// variable is not equal to `end`.
///
/// SAFETY: Starting with end < start will lead to undefined behavior.
/// The exit test is an inequality comparison (`Bne`/`BneI`) against `end`, so
/// the loop only stops when the loop variable becomes exactly equal to `end`.
pub struct ForCompiler<'a, F: Field, EF> {
/// The assembly compiler the loop's instructions and basic blocks are emitted into.
compiler: &'a mut AsmCompiler<F, EF>,
/// Initial value of the loop variable (a constant or a runtime variable).
start: RVar<F>,
/// Value of the loop variable at which the loop stops (a constant or a runtime variable).
end: RVar<F>,
/// Amount added to the loop variable after each execution of the body.
step_size: F,
/// The variable that holds the current iteration value.
loop_var: Var<F>,
}

impl<F: PrimeField32 + TwoAdicField, EF: ExtensionField<F> + TwoAdicField> ForCompiler<'_, F, EF> {
    /// Emits the loop and compiles its body by calling `f`.
    ///
    /// Generated block structure:
    /// - the calling block sets the loop variable and jumps to the loop condition,
    /// - a body block runs `f` and then increments the loop variable,
    /// - a condition block jumps back to the body while the loop variable differs
    ///   from `end`,
    /// - a fresh block after the loop, which is also the target of every `break`
    ///   belonging to this loop.
    ///
    /// `debug_info` is attached to every instruction emitted by the loop scaffolding.
    pub(super) fn for_each(
        mut self,
        f: impl FnOnce(Var<F>, &mut AsmCompiler<F, EF>),
        debug_info: Option<DebugInfo>,
    ) {
        // Set the loop variable to the start of the range.
        self.set_loop_var(debug_info.clone());

        // Remember the block that enters the loop; its jump to the loop
        // condition is appended once the condition block exists.
        let loop_call_label = self.compiler.block_label();

        // Install a break label for this loop, remembering the label of the
        // enclosing loop (if any). `new_break_label` already stores the new
        // label in `compiler.break_label`.
        let enclosing_break_label = self.compiler.break_label;
        let break_label = self.compiler.new_break_label();

        // A basic block for the loop body; its label is the back-edge target.
        self.compiler.basic_block();
        let loop_label = self.compiler.block_label();

        // The loop body.
        f(self.loop_var, self.compiler);

        // Re-activate the enclosing loop's break label. Without this, a `break`
        // emitted after a nested loop would be tagged with the nested loop's
        // already resolved label and would never be rewritten into a jump.
        self.compiler.break_label = enclosing_break_label;

        // Increment the loop variable.
        self.compiler.push(
            AsmInstruction::AddFI(self.loop_var.fp(), self.loop_var.fp(), self.step_size),
            debug_info.clone(),
        );

        // A basic block for the loop condition: jump back to the body while
        // the condition still holds.
        self.compiler.basic_block();
        self.jump_to_loop_body(loop_label, debug_info.clone());

        // Make the calling block jump straight to the loop condition, so the
        // condition is evaluated before the first iteration.
        let condition_label = self.compiler.block_label();
        self.compiler.push_to_block(
            loop_call_label,
            AsmInstruction::j(condition_label),
            debug_info,
        );

        // Initialize the after-loop block and resolve the break label to it.
        self.compiler.basic_block();
        let after_loop_label = self.compiler.block_label();
        self.compiler
            .break_label_map
            .insert(break_label, after_loop_label);

        // Replace this loop's break instructions with jumps to the after-loop block.
        for block in self.compiler.contains_break.iter() {
            for instruction in self.compiler.basic_blocks[block.as_canonical_u32() as usize]
                .0
                .iter_mut()
            {
                if let AsmInstruction::Break(l) = instruction {
                    if *l == break_label {
                        *instruction = AsmInstruction::j(after_loop_label);
                    }
                }
            }
        }
    }

    /// Emits the instruction that initialises the loop variable from `start`:
    /// an immediate load for a constant start, a copy for a variable start.
    fn set_loop_var(&mut self, debug_info: Option<DebugInfo>) {
        let init = match self.start {
            RVar::Const(start) => AsmInstruction::ImmF(self.loop_var.fp(), start),
            RVar::Val(var) => AsmInstruction::CopyF(self.loop_var.fp(), var.fp()),
        };
        self.compiler.push(init, debug_info);
    }

    /// Emits the back-edge: a branch to `loop_label` taken while the loop
    /// variable is not equal to `end` (immediate or variable comparison).
    fn jump_to_loop_body(&mut self, loop_label: F, debug_info: Option<DebugInfo>) {
        let branch = match self.end {
            RVar::Const(end) => AsmInstruction::BneI(loop_label, self.loop_var.fp(), end),
            RVar::Val(end) => AsmInstruction::Bne(loop_label, self.loop_var.fp(), end.fp()),
        };
        self.compiler.push(branch, debug_info);
    }
}

Expand Down
4 changes: 0 additions & 4 deletions extensions/native/compiler/src/asm/instruction.rs
Original file line number Diff line number Diff line change
Expand Up @@ -108,9 +108,6 @@ pub enum AsmInstruction<F, EF> {
/// Halt.
Halt,

/// Break(label)
Break(F),

/// Perform a Poseidon2 permutation on state starting at address `lhs`
/// and store new state at `rhs`.
/// (a, b) are pointers to (lhs, rhs).
Expand Down Expand Up @@ -159,7 +156,6 @@ impl<F: PrimeField32, EF: ExtensionField<F>> AsmInstruction<F, EF> {

pub fn fmt(&self, labels: &BTreeMap<F, String>, f: &mut fmt::Formatter) -> fmt::Result {
match self {
AsmInstruction::Break(_) => panic!("Unresolved break instruction"),
AsmInstruction::LoadFI(dst, src, var_index, size, offset) => {
write!(
f,
Expand Down
1 change: 0 additions & 1 deletion extensions/native/compiler/src/conversion/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,6 @@ fn convert_instruction<F: PrimeField32, EF: ExtensionField<F>>(
options: &CompilerOptions,
) -> Program<F> {
let instructions = match instruction {
AsmInstruction::Break(_) => panic!("Unresolved break instruction"),
AsmInstruction::LoadFI(dst, src, index, size, offset) => vec![
// mem[dst] <- mem[mem[src] + index * size + offset]
inst(
Expand Down
Loading