Skip to content

Commit

Permalink
Optimize alpha_pow computation
Browse files Browse the repository at this point in the history
  • Loading branch information
nyunyunyunyu committed Jan 21, 2025
1 parent ab01b0c commit 50bdb48
Show file tree
Hide file tree
Showing 3 changed files with 232 additions and 168 deletions.
48 changes: 32 additions & 16 deletions extensions/native/compiler/src/ir/builder.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
use std::{iter::Zip, vec::IntoIter};

use backtrace::Backtrace;
use itertools::izip;
use openvm_native_compiler_derive::iter_zip;
use openvm_stark_backend::p3_field::FieldAlgebra;
use serde::{Deserialize, Serialize};
Expand Down Expand Up @@ -275,13 +276,22 @@ impl<C: Config> Builder<C> {
&mut self,
start: impl Into<RVar<C::N>>,
end: impl Into<RVar<C::N>>,
) -> IteratorBuilder<C> {
self.range_with_step(start, end, C::N::ONE)
}
/// Evaluate a block of operations over a range from start to end with a custom step.
pub fn range_with_step(
&mut self,
start: impl Into<RVar<C::N>>,
end: impl Into<RVar<C::N>>,
step: C::N,
) -> IteratorBuilder<C> {
let start = start.into();
let end0 = end.into();
IteratorBuilder {
starts: vec![start],
end0,
step_sizes: vec![1],
step_sizes: vec![step],
builder: self,
}
}
Expand All @@ -295,7 +305,7 @@ impl<C: Config> Builder<C> {
IteratorBuilder {
starts: vec![RVar::zero(); arrays.len()],
end0: arrays[0].len().into(),
step_sizes: vec![1; arrays.len()],
step_sizes: vec![C::N::ONE; arrays.len()],
builder: self,
}
} else if arrays.iter().all(|array| !array.is_fixed()) {
Expand All @@ -311,7 +321,10 @@ impl<C: Config> Builder<C> {
self.eval(arrays[0].ptr().address + len * RVar::from(size));
end.into()
},
step_sizes: arrays.iter().map(|array| array.element_size_of()).collect(),
step_sizes: arrays
.iter()
.map(|array| C::N::from_canonical_usize(array.element_size_of()))
.collect(),
builder: self,
}
} else {
Expand Down Expand Up @@ -735,7 +748,7 @@ impl<C: Config> IfBuilder<'_, C> {
pub struct IteratorBuilder<'a, C: Config> {
starts: Vec<RVar<C::N>>,
end0: RVar<C::N>,
step_sizes: Vec<usize>,
step_sizes: Vec<C::N>,
builder: &'a mut Builder<C>,
}

Expand All @@ -757,12 +770,20 @@ impl<C: Config> IteratorBuilder<'_, C> {
}

fn for_each_unrolled(&mut self, mut f: impl FnMut(Vec<RVar<C::N>>, &mut Builder<C>)) {
let starts: Vec<usize> = self.starts.iter().map(|start| start.value()).collect();
let end0 = self.end0.value();

for i in (starts[0]..end0).step_by(self.step_sizes[0]) {
let ptrs = vec![i.into(); self.starts.len()];
f(ptrs, self.builder);
let mut ptrs: Vec<_> = self
.starts
.iter()
.map(|start| start.field_value())
.collect();
let end0 = self.end0.field_value();
while ptrs[0] != end0 {
f(
ptrs.iter().map(|ptr| RVar::Const(*ptr)).collect(),
self.builder,
);
for (ptr, step_size) in izip!(&mut ptrs, &self.step_sizes) {
*ptr += *step_size;
}
}
}

Expand All @@ -772,11 +793,6 @@ impl<C: Config> IteratorBuilder<'_, C> {
"Cannot use dynamic loop in static mode"
);

let step_sizes = self
.step_sizes
.iter()
.map(|s| C::N::from_canonical_usize(*s))
.collect();
let loop_variables: Vec<Var<C::N>> = (0..self.starts.len())
.map(|_| self.builder.uninit())
.collect();
Expand All @@ -791,7 +807,7 @@ impl<C: Config> IteratorBuilder<'_, C> {
let op = DslIr::ZipFor(
self.starts.clone(),
self.end0,
step_sizes,
self.step_sizes.clone(),
loop_variables,
loop_instructions,
);
Expand Down
6 changes: 6 additions & 0 deletions extensions/native/compiler/src/ir/symbolic.rs
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,12 @@ impl<N: PrimeField> RVar<N> {
_ => panic!("RVar::value() called on non-const value"),
}
}
pub fn field_value(&self) -> N {
match self {
RVar::Const(c) => *c,
_ => panic!("RVar::field_value() called on non-const value"),
}
}
}

impl<N: Field> Hash for SymbolicVar<N> {
Expand Down
Loading

0 comments on commit 50bdb48

Please sign in to comment.