Skip to content

Commit

Permalink
chore: change all instances of iter to use zip
Browse files Browse the repository at this point in the history
  • Loading branch information
yi-sun committed Jan 18, 2025
1 parent 8980df0 commit 0805e65
Show file tree
Hide file tree
Showing 7 changed files with 245 additions and 375 deletions.
89 changes: 0 additions & 89 deletions extensions/native/compiler/src/ir/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -286,32 +286,6 @@ impl<C: Config> Builder<C> {
}
}

pub fn iter<'a, V: MemVariable<C>>(
&'a mut self,
array: &'a Array<C, V>,
) -> IteratorBuilder<'a, C, V> {
match array {
Array::Fixed(_) => IteratorBuilder {
start: RVar::zero(),
end: array.len().into(),
step_size: 1,
builder: self,
array,
},
Array::Dyn(ptr, len) => {
let len: RVar<C::N> = len.clone().into();
let end: Var<C::N> = self.eval(ptr.address + len * RVar::from(V::size_of()));
IteratorBuilder {
start: ptr.address.into(),
end: end.into(),
step_size: V::size_of(),
builder: self,
array,
}
}
}
}

pub fn zip<'a>(
&'a mut self,
arrays: &'a [Box<dyn ArrayLike<C> + 'a>],
Expand Down Expand Up @@ -825,69 +799,6 @@ impl<C: Config> ZippedPointerIteratorBuilder<'_, C> {
}
}

pub struct IteratorBuilder<'a, C: Config, V: MemVariable<C>> {
start: RVar<C::N>,
end: RVar<C::N>,
step_size: usize,
builder: &'a mut Builder<C>,
array: &'a Array<C, V>,
}

impl<C: Config, V: MemVariable<C>> IteratorBuilder<'_, C, V> {
pub fn for_each(&mut self, mut f: impl FnMut(V, &mut Builder<C>)) {
if self.start.is_const() && self.end.is_const() {
self.for_each_unrolled(|var, builder| {
f(var, builder);
});
return;
}
self.for_each_dynamic(|var, builder| {
f(var, builder);
});
}

fn for_each_unrolled(&mut self, mut f: impl FnMut(V, &mut Builder<C>)) {
let start = self.start.value();
let end = self.end.value();
for i in (start..end).step_by(self.step_size) {
let val = self.builder.get(self.array, i);
f(val, self.builder);
}
}

fn for_each_dynamic(&mut self, mut f: impl FnMut(V, &mut Builder<C>)) {
assert!(
!self.builder.flags.static_only,
"Cannot use dynamic loop in static mode"
);
let step_size = C::N::from_canonical_usize(self.step_size);
let loop_variable: Var<C::N> = self.builder.uninit();
let mut loop_body_builder = self.builder.create_sub_builder();
let val: V = loop_body_builder.uninit();
loop_body_builder.load(
val.clone(),
Ptr {
address: loop_variable,
},
MemIndex {
index: 0.into(),
offset: 0,
size: V::size_of(),
},
);
f(val, &mut loop_body_builder);
let loop_instructions = loop_body_builder.operations;
let op = DslIr::For(
self.start,
self.end,
step_size,
loop_variable,
loop_instructions,
);
self.builder.operations.push(op);
}
}

/// A builder for the DSL that handles for loops.
pub struct RangeBuilder<'a, C: Config> {
start: RVar<C::N>,
Expand Down
9 changes: 6 additions & 3 deletions extensions/native/compiler/src/ir/poseidon.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
use openvm_native_compiler_derive::compile_zip;
use openvm_stark_backend::p3_field::FieldAlgebra;

use super::{Array, Builder, Config, DslIr, Ext, Felt, MemIndex, Ptr, Usize, Var};
use super::{Array, ArrayLike, Builder, Config, DslIr, Ext, Felt, MemIndex, Ptr, Usize, Var};

pub const DIGEST_SIZE: usize = 8;
pub const HASH_RATE: usize = 8;
Expand Down Expand Up @@ -88,8 +89,10 @@ impl<C: Config> Builder<C> {
let address = self.eval(state.ptr().address);
let start: Var<_> = self.eval(address);
let end: Var<_> = self.eval(address + C::N::from_canonical_usize(HASH_RATE));
self.iter(array).for_each(|subarray, builder| {
builder.iter(&subarray).for_each(|element, builder| {
compile_zip!(self, array).for_each(|idx_vec, builder| {
let subarray = builder.iter_ptr_get(&array, idx_vec[0]);
compile_zip!(builder, subarray).for_each(|ptr_vec, builder| {
let element = builder.iter_ptr_get(&subarray, ptr_vec[0]);
builder.cycle_tracker_start("poseidon2-hash-setup");
builder.store(
Ptr { address },
Expand Down
7 changes: 5 additions & 2 deletions extensions/native/compiler/src/ir/utils.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
use std::ops::{Add, Mul};

use openvm_native_compiler_derive::compile_zip;
use openvm_stark_backend::p3_field::{FieldAlgebra, FieldExtensionAlgebra, PrimeField};

use super::{
Array, Builder, CanSelect, Config, DslIr, Ext, Felt, MemIndex, RVar, SymbolicExt, Var, Variable,
Array, ArrayLike, Builder, CanSelect, Config, DslIr, Ext, Felt, MemIndex, RVar, SymbolicExt,
Var, Variable,
};

pub const NUM_LIMBS: usize = 32;
Expand Down Expand Up @@ -88,7 +90,8 @@ impl<C: Config> Builder<C> {
let one_var: V = self.eval(V::Expression::ONE);

// Implements a square-and-multiply algorithm.
self.iter(power_bits).for_each(|bit, builder| {
compile_zip!(self, power_bits).for_each(|ptr_vec, builder| {
let bit = builder.iter_ptr_get(&power_bits, ptr_vec[0]);
builder.assign(&result, result * result);
let mul = V::select(builder, bit, power_f, one_var);
builder.assign(&result, result * mul);
Expand Down
45 changes: 0 additions & 45 deletions extensions/native/compiler/tests/for_loops.rs
Original file line number Diff line number Diff line change
Expand Up @@ -49,51 +49,6 @@ fn test_compiler_for_loops() {
execute_program(program, vec![]);
}

#[test]
fn test_compiler_iter_fixed() {
let mut builder = AsmBuilder::<F, EF>::default();
let zero: Var<_> = builder.eval(F::ZERO);
let one: Var<_> = builder.eval(F::ONE);
let two: Var<_> = builder.eval(F::TWO);
let arr = builder.vec(vec![zero, one, two]);
let x: Var<_> = builder.eval(F::ZERO);
let count: Var<_> = builder.eval(F::ZERO);
builder.iter(&arr).for_each(|val: Var<_>, builder| {
builder.assign(&x, x + val);
builder.assign(&count, count + F::ONE);
});
builder.assert_var_eq(count, F::from_canonical_usize(3));
builder.assert_var_eq(x, F::from_canonical_usize(3));
builder.halt();

let program = builder.compile_isa();
execute_program(program, vec![]);
}

#[test]
fn test_compiler_iter_dyn() {
let mut builder = AsmBuilder::<F, EF>::default();
let zero: Var<_> = builder.eval(F::ZERO);
let one: Var<_> = builder.eval(F::ONE);
let two: Var<_> = builder.eval(F::TWO);
let arr = builder.dyn_array(3);
builder.set(&arr, 0, zero);
builder.set(&arr, 1, one);
builder.set(&arr, 2, two);
let x: Var<_> = builder.eval(F::ZERO);
let count: Var<_> = builder.eval(F::ZERO);
builder.iter(&arr).for_each(|val: Var<_>, builder| {
builder.assign(&x, x + val);
builder.assign(&count, count + F::ONE);
});
builder.assert_var_eq(count, F::from_canonical_usize(3));
builder.assert_var_eq(x, F::from_canonical_usize(3));
builder.halt();

let program = builder.compile_isa();
execute_program(program, vec![]);
}

#[test]
fn test_compiler_zip_fixed() {
let mut builder = AsmBuilder::<F, EF>::default();
Expand Down
15 changes: 8 additions & 7 deletions extensions/native/recursion/src/challenger/duplex.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
use openvm_native_compiler::{
ir::{RVar, DIGEST_SIZE, PERMUTATION_WIDTH},
prelude::{Array, Builder, Config, Ext, Felt, Var},
prelude::{Array, ArrayLike, Builder, Config, Ext, Felt, Var},
};
use openvm_native_compiler_derive::compile_zip;
use openvm_stark_backend::p3_field::{Field, FieldAlgebra};

use crate::{
Expand Down Expand Up @@ -146,11 +147,10 @@ impl<C: Config> DuplexChallengerVariable<C> {
self.observe(builder, witness);
let element_bits = self.sample_bits(builder, RVar::from(nb_bits));
let element_bits_truncated = element_bits.slice(builder, 0, nb_bits);
builder
.iter(&element_bits_truncated)
.for_each(|element, builder| {
builder.assert_var_eq(element, C::N::ZERO);
});
compile_zip!(builder, element_bits_truncated).for_each(|ptr_vec, builder| {
let element = builder.iter_ptr_get(&element_bits_truncated, ptr_vec[0]);
builder.assert_var_eq(element, C::N::ZERO);
});
}
}

Expand All @@ -160,7 +160,8 @@ impl<C: Config> CanObserveVariable<C, Felt<C::F>> for DuplexChallengerVariable<C
}

fn observe_slice(&mut self, builder: &mut Builder<C>, values: Array<C, Felt<C::F>>) {
builder.iter(&values).for_each(|element, builder| {
compile_zip!(builder, values).for_each(|ptr_vec, builder| {
let element = builder.iter_ptr_get(&values, ptr_vec[0]);
self.observe(builder, element);
});
}
Expand Down
Loading

0 comments on commit 0805e65

Please sign in to comment.