Skip to content

Commit

Permalink
use mul add in babybear felt ops (#1175)
Browse files Browse the repository at this point in the history
* use mul add in babybear felt ops

* test new reduce

* test new reduce

* fix bug

* fix bug again

* fixed reducing

* slightly optimize reduce

* optimize dsl code

* chore: halo2 profiling flag in SDK prover

* chore: some debug statements

* chore: more debug..

* chore: another debug

* chore: more debug

* fix: add bench-metric feature dependency

* fix: another Cargo.toml bench-metrics feature dependency

* chore: remove debugging statements

* feat: new flamegraph definitions

* chore: don't try to generate all 0 flamegraphs

* chore: metrics to sum during flamegraph generation

* generate benchmarks before reduce

* simple div optimization

* fix_lint

* document code

* rewrite comment

* Apply suggestions from code review

---------

Co-authored-by: Stephen Hwang <[email protected]>
Co-authored-by: Jonathan Wang <[email protected]>
  • Loading branch information
3 people authored Jan 20, 2025
1 parent d514bc1 commit 2da1713
Show file tree
Hide file tree
Showing 3 changed files with 119 additions and 35 deletions.
5 changes: 1 addition & 4 deletions benchmarks/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -54,10 +54,7 @@ pprof = { version = "0.13", features = [

[features]
default = ["parallel", "mimalloc", "bench-metrics"]
bench-metrics = [
"openvm-native-recursion/bench-metrics",
"openvm-native-compiler/bench-metrics",
]
bench-metrics = ["openvm-native-recursion/bench-metrics"]
profiling = ["openvm-sdk/profiling"]
aggregation = []
static-verifier = ["openvm-native-recursion/static-verifier"]
Expand Down
142 changes: 115 additions & 27 deletions extensions/native/compiler/src/constraints/halo2/baby_bear.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,10 @@ use snark_verifier_sdk::snark_verifier::{
};

pub(crate) const BABYBEAR_MAX_BITS: usize = 31;
// bits reserved so that if we do lazy range checking, we still have a valid result
// the first reserved bit is so that we can represent negative numbers
// the second is to accommodate lazy range checking
const RESERVED_HIGH_BITS: usize = 2;

#[derive(Copy, Clone, Debug)]
pub struct AssignedBabyBear {
Expand Down Expand Up @@ -66,6 +70,7 @@ impl BabyBearChip {
}

pub fn reduce(&self, ctx: &mut Context<Fr>, a: AssignedBabyBear) -> AssignedBabyBear {
debug_assert!(fe_to_bigint(a.value.value()).bits() as usize <= a.max_bits);
let (_, r) = signed_div_mod(&self.range, ctx, a.value, BabyBear::ORDER_U32, a.max_bits);
let r = AssignedBabyBear {
value: r,
Expand All @@ -78,17 +83,22 @@ impl BabyBearChip {
pub fn add(
&self,
ctx: &mut Context<Fr>,
a: AssignedBabyBear,
b: AssignedBabyBear,
mut a: AssignedBabyBear,
mut b: AssignedBabyBear,
) -> AssignedBabyBear {
if a.max_bits.max(b.max_bits) + 1 > Fr::CAPACITY as usize - RESERVED_HIGH_BITS {
a = self.reduce(ctx, a);
if a.max_bits.max(b.max_bits) + 1 > Fr::CAPACITY as usize - RESERVED_HIGH_BITS {
b = self.reduce(ctx, b);
}
}
let value = self.gate().add(ctx, a.value, b.value);
let max_bits = a.max_bits.max(b.max_bits) + 1;

let mut c = AssignedBabyBear { value, max_bits };
if c.max_bits >= Fr::CAPACITY as usize - 1 {
debug_assert_eq!(c.to_baby_bear(), a.to_baby_bear() + b.to_baby_bear());
if c.max_bits > Fr::CAPACITY as usize - RESERVED_HIGH_BITS {
c = self.reduce(ctx, c);
}
debug_assert_eq!(c.to_baby_bear(), a.to_baby_bear() + b.to_baby_bear());
c
}

Expand All @@ -105,15 +115,20 @@ impl BabyBearChip {
pub fn sub(
&self,
ctx: &mut Context<Fr>,
a: AssignedBabyBear,
b: AssignedBabyBear,
mut a: AssignedBabyBear,
mut b: AssignedBabyBear,
) -> AssignedBabyBear {
if a.max_bits.max(b.max_bits) + 1 > Fr::CAPACITY as usize - RESERVED_HIGH_BITS {
a = self.reduce(ctx, a);
if a.max_bits.max(b.max_bits) + 1 > Fr::CAPACITY as usize - RESERVED_HIGH_BITS {
b = self.reduce(ctx, b);
}
}
let value = self.gate().sub(ctx, a.value, b.value);
let max_bits = a.max_bits.max(b.max_bits) + 1;

let mut c = AssignedBabyBear { value, max_bits };
debug_assert_eq!(c.to_baby_bear(), a.to_baby_bear() - b.to_baby_bear());
if c.max_bits >= Fr::CAPACITY as usize - 1 {
if c.max_bits > Fr::CAPACITY as usize - RESERVED_HIGH_BITS {
c = self.reduce(ctx, c);
}
c
Expand All @@ -128,23 +143,56 @@ impl BabyBearChip {
if a.max_bits < b.max_bits {
std::mem::swap(&mut a, &mut b);
}
if a.max_bits + b.max_bits >= Fr::CAPACITY as usize - 1 {
if a.max_bits + b.max_bits > Fr::CAPACITY as usize - RESERVED_HIGH_BITS {
a = self.reduce(ctx, a);
if a.max_bits + b.max_bits >= Fr::CAPACITY as usize - 1 {
if a.max_bits + b.max_bits > Fr::CAPACITY as usize - RESERVED_HIGH_BITS {
b = self.reduce(ctx, b);
}
}
let value = self.gate().mul(ctx, a.value, b.value);
let max_bits = a.max_bits + b.max_bits;

let mut c = AssignedBabyBear { value, max_bits };
if c.max_bits >= Fr::CAPACITY as usize - 1 {
if c.max_bits > Fr::CAPACITY as usize - RESERVED_HIGH_BITS {
c = self.reduce(ctx, c);
}
debug_assert_eq!(c.to_baby_bear(), a.to_baby_bear() * b.to_baby_bear());
c
}

/// Computes `a * b + c` in BabyBear using a single halo2 `mul_add` gate.
pub fn mul_add(
    &self,
    ctx: &mut Context<Fr>,
    mut a: AssignedBabyBear,
    mut b: AssignedBabyBear,
    mut c: AssignedBabyBear,
) -> AssignedBabyBear {
    // Ensure `a` carries the larger bound so a single reduction of `a` is
    // tried first before falling back to also reducing `b`.
    if a.max_bits < b.max_bits {
        std::mem::swap(&mut a, &mut b);
    }
    // The product needs `a.max_bits + b.max_bits` bits and the final sum can
    // carry one more; reduce operands until the result fits within the native
    // field capacity minus the bits reserved for lazy range checking.
    if a.max_bits + b.max_bits + 1 > Fr::CAPACITY as usize - RESERVED_HIGH_BITS {
        a = self.reduce(ctx, a);
        if a.max_bits + b.max_bits + 1 > Fr::CAPACITY as usize - RESERVED_HIGH_BITS {
            b = self.reduce(ctx, b);
        }
    }
    // The addend is bounded separately: it only needs to fit after the +1 carry.
    if c.max_bits + 1 > Fr::CAPACITY as usize - RESERVED_HIGH_BITS {
        c = self.reduce(ctx, c)
    }
    // One gate computes a.value * b.value + c.value over the native field Fr.
    let value = self.gate().mul_add(ctx, a.value, b.value, c.value);
    // Worst-case bound: max(|a*b|, |c|) can at most double, hence one extra bit.
    let max_bits = c.max_bits.max(a.max_bits + b.max_bits) + 1;

    let mut d = AssignedBabyBear { value, max_bits };
    // Keep the tracked bound within capacity for subsequent lazy operations.
    if d.max_bits > Fr::CAPACITY as usize - RESERVED_HIGH_BITS {
        d = self.reduce(ctx, d);
    }
    // Witness-side sanity check that the assigned value matches BabyBear arithmetic.
    debug_assert_eq!(
        d.to_baby_bear(),
        a.to_baby_bear() * b.to_baby_bear() + c.to_baby_bear()
    );
    d
}

pub fn div(
&self,
ctx: &mut Context<Fr>,
Expand All @@ -156,19 +204,24 @@ impl BabyBearChip {

let mut c = self.load_witness(ctx, a.to_baby_bear() * b_inv);
// constraint a = b * c (mod p)
if a.max_bits >= Fr::CAPACITY as usize - 1 {
if a.max_bits > Fr::CAPACITY as usize - RESERVED_HIGH_BITS {
a = self.reduce(ctx, a);
}
if b.max_bits + c.max_bits >= Fr::CAPACITY as usize - 1 {
if b.max_bits + c.max_bits > Fr::CAPACITY as usize - RESERVED_HIGH_BITS {
b = self.reduce(ctx, b);
}
if b.max_bits + c.max_bits >= Fr::CAPACITY as usize - 1 {
if b.max_bits + c.max_bits > Fr::CAPACITY as usize - RESERVED_HIGH_BITS {
c = self.reduce(ctx, c);
}
let diff = self.gate().sub_mul(ctx, a.value, b.value, c.value);
let max_bits = a.max_bits.max(b.max_bits + c.max_bits) + 1;
let (_, r) = signed_div_mod(&self.range, ctx, diff, BabyBear::ORDER_U32, max_bits);
self.gate().assert_is_const(ctx, &r, &Fr::ZERO);
self.assert_zero(
ctx,
AssignedBabyBear {
value: diff,
max_bits,
},
);
debug_assert_eq!(c.to_baby_bear(), a.to_baby_bear() / b.to_baby_bear());
c
}
Expand All @@ -187,9 +240,39 @@ impl BabyBearChip {

pub fn assert_zero(&self, ctx: &mut Context<Fr>, a: AssignedBabyBear) {
debug_assert_eq!(a.to_baby_bear(), BabyBear::ZERO);
assert!(a.max_bits < Fr::CAPACITY as usize);
let (_, r) = signed_div_mod(&self.range, ctx, a.value, BabyBear::ORDER_U32, a.max_bits);
self.gate().assert_is_const(ctx, &r, &Fr::ZERO);
assert!(a.max_bits <= Fr::CAPACITY as usize - RESERVED_HIGH_BITS);
let a_num_bits = a.max_bits;
let b: BigUint = BabyBear::ORDER_U32.into();
let a_val = fe_to_bigint(a.value.value());
assert!(a_val.bits() <= a_num_bits as u64);
let (div, _) = a_val.div_mod_floor(&b.clone().into());
let div = bigint_to_fe(&div);
ctx.assign_region(
[
QuantumCell::Constant(Fr::ZERO),
QuantumCell::Constant(biguint_to_fe(&b)),
QuantumCell::Witness(div),
a.value.into(),
],
[0],
);
let div = ctx.get(-2);
// Constrain that `abs(div) <= 2 ** (2 ** a_num_bits / b).bits()`.
let bound = (BigUint::from(1u32) << (a_num_bits as u32)) / &b;
let shifted_div =
self.range
.gate()
.add(ctx, div, QuantumCell::Constant(biguint_to_fe(&bound)));
debug_assert!(*shifted_div.value() < biguint_to_fe(&(&bound * 2u32 + 1u32)));
// in this use case, we know that a.max_bits <= Fr::CAPACITY - RESERVED_HIGH_BITS, which means that
// bound has at most Fr::CAPACITY - RESERVED_HIGH_BITS - BABYBEAR_ORDER_BITS bits.
// In particular, 2 * bound + 1 has at most Fr::CAPACITY - RESERVED_HIGH_BITS - BABYBEAR_ORDER_BITS + 1 bits.
// Most notably, suppose we could fake |p * shifted_div - p * shifted_div'| < p, with both shifted_div and shifted_div'
// distinct and satisfying the range check. We note that |shifted_div-shifted_div'| < 1 << (Fr::CAPACITY - RESERVED_HIGH_BITS - BABYBEAR_ORDER_BITS + 1)
// In particular, even if we multiply by babybear, we have 0 < p * |shifted_div-shifted_div'| < 1 << (Fr::CAPACITY - RESERVED_HIGH_BITS + 2)
// it's pretty clear that this has no overlap with (-p, p), so we are safe.
self.range
.range_check(ctx, shifted_div, (bound * 2u32 + 1u32).bits() as usize);
}

pub fn assert_equal(&self, ctx: &mut Context<Fr>, a: AssignedBabyBear, b: AssignedBabyBear) {
Expand All @@ -207,7 +290,7 @@ impl BabyBearChip {
///
/// ## Assumptions
/// * `b != 0` and that `abs(a) < 2^a_max_bits`
/// * `a_max_bits < F::CAPACITY = F::NUM_BITS - 1`
/// * `a_max_bits < F::CAPACITY = F::NUM_BITS - RESERVED_HIGH_BITS`
/// * Unsafe behavior if `a_max_bits >= F::CAPACITY`
fn signed_div_mod<F>(
range: &RangeChip<F>,
Expand Down Expand Up @@ -236,13 +319,20 @@ where
);
let rem = ctx.get(-4);
let div = ctx.get(-2);
// Constrain that `abs(div) <= 2 ** a_num_bits / b`.
// Constrain that `abs(div) <= 2 ** (2 ** a_num_bits / b).bits()`.
let bound = (BigUint::from(1u32) << (a_num_bits as u32)) / &b;
let shifted_div = range
.gate()
.add(ctx, div, QuantumCell::Constant(biguint_to_fe(&bound)));
debug_assert!(*shifted_div.value() < biguint_to_fe(&(&bound * 2u32 + 1u32)));
range.check_big_less_than_safe(ctx, shifted_div, bound * 2u32 + 1u32);
// in this use case, we know that a.max_bits <= Fr::CAPACITY - RESERVED_HIGH_BITS, which means that
// bound has at most Fr::CAPACITY - RESERVED_HIGH_BITS - BABYBEAR_ORDER_BITS bits.
// In particular, 2 * bound + 1 has at most Fr::CAPACITY - RESERVED_HIGH_BITS - BABYBEAR_ORDER_BITS + 1 bits.
// Most notably, suppose we could fake |p * shifted_div - p * shifted_div'| < p, with both shifted_div and shifted_div'
// distinct and satisfying the range check. We note that |shifted_div-shifted_div'| < 1 << (Fr::CAPACITY - RESERVED_HIGH_BITS - BABYBEAR_ORDER_BITS + 1)
// In particular, even if we multiply by babybear, we have 0 < p * |shifted_div-shifted_div'| < 1 << (Fr::CAPACITY - RESERVED_HIGH_BITS + 2)
// it's pretty clear that this has no overlap with (-p, p), so we are safe.
range.range_check(ctx, shifted_div, (bound * 2u32 + 1u32).bits() as usize);
// Constrain that remainder is less than divisor (i.e. `r < b`).
debug_assert!(*rem.value() < biguint_to_fe(&b));
range.check_big_less_than_safe(ctx, rem, b);
Expand Down Expand Up @@ -396,8 +486,7 @@ impl BabyBearExt4Chip {
for i in 0..4 {
for j in 0..4 {
if i + j < coeffs.len() {
let tmp = self.base.mul(ctx, a.0[i], b.0[j]);
coeffs[i + j] = self.base.add(ctx, coeffs[i + j], tmp);
coeffs[i + j] = self.base.mul_add(ctx, a.0[i], b.0[j], coeffs[i + j]);
} else {
coeffs.push(self.base.mul(ctx, a.0[i], b.0[j]));
}
Expand All @@ -407,8 +496,7 @@ impl BabyBearExt4Chip {
.base
.load_constant(ctx, <BabyBear as BinomiallyExtendable<4>>::W);
for i in 4..7 {
let tmp = self.base.mul(ctx, coeffs[i], w);
coeffs[i - 4] = self.base.add(ctx, coeffs[i - 4], tmp);
coeffs[i - 4] = self.base.mul_add(ctx, coeffs[i], w, coeffs[i - 4]);
}
coeffs.truncate(4);
let c = AssignedBabyBearExt4(coeffs.try_into().unwrap());
Expand Down
7 changes: 3 additions & 4 deletions extensions/native/recursion/src/fri/two_adic_pcs.rs
Original file line number Diff line number Diff line change
Expand Up @@ -185,7 +185,6 @@ pub fn verify_two_adic_pcs<C: Config>(

let two_adic_generator = config.get_two_adic_generator(builder, log_height);
builder.cycle_tracker_start("exp-reverse-bits-len");

let index_bits_shifted_truncated =
index_bits_shifted.slice(builder, 0, log_height);
let two_adic_generator_exp = builder
Expand All @@ -198,16 +197,16 @@ pub fn verify_two_adic_pcs<C: Config>(
let ps_at_z = builder.iter_ptr_get(&mat_values, ptr_vec[1]);

builder.cycle_tracker_start("single-reduced-opening-eval");

if builder.flags.static_only {
let n: Ext<C::F, C::EF> = builder.constant(C::EF::ZERO);
builder.range(0, ps_at_z.len()).for_each(|t, builder| {
let p_at_x = builder.get(&mat_opening, t);
let p_at_z = builder.get(&ps_at_z, t);
let quotient = (p_at_z - p_at_x) / (z - x);

builder.assign(&cur_ro, cur_ro + cur_alpha_pow * quotient);
builder.assign(&n, cur_alpha_pow * (p_at_z - p_at_x) + n);
builder.assign(&cur_alpha_pow, cur_alpha_pow * alpha);
});
builder.assign(&cur_ro, cur_ro + n / (z - x));
} else {
let mat_ro = builder.fri_single_reduced_opening_eval(
alpha,
Expand Down

0 comments on commit 2da1713

Please sign in to comment.