Skip to content

Commit

Permalink
chore: cleanup num2bits (#1231)
Browse files Browse the repository at this point in the history
* chore: remove unused num2bits_v

* fix: lint

* fix: prevent overflow in num2bits_f

* chore: restrict to F = BabyBear

* fix: fix type assert

* fix: constants

* chore: update comment

* chore: update comments

* chore: remove usused DslIr::HintBitsV

* chore: remove unused DslIr::LessThan
  • Loading branch information
yi-sun authored Jan 17, 2025
1 parent 7e86717 commit ee7129a
Show file tree
Hide file tree
Showing 5 changed files with 38 additions and 72 deletions.
3 changes: 0 additions & 3 deletions extensions/native/compiler/src/asm/compiler.rs
Original file line number Diff line number Diff line change
Expand Up @@ -513,9 +513,6 @@ impl<F: PrimeField32 + TwoAdicField, EF: ExtensionField<F> + TwoAdicField> AsmCo
DslIr::HintBitsF(var, len) => {
self.push(AsmInstruction::HintBits(var.fp(), len), debug_info);
}
DslIr::HintBitsV(var, len) => {
self.push(AsmInstruction::HintBits(var.fp(), len), debug_info);
}
DslIr::Poseidon2PermuteBabyBear(dst, src) => match (dst, src) {
(Array::Dyn(dst, _), Array::Dyn(src, _)) => self.push(
AsmInstruction::Poseidon2Permute(dst.fp(), src.fp()),
Expand Down
68 changes: 36 additions & 32 deletions extensions/native/compiler/src/ir/bits.rs
Original file line number Diff line number Diff line change
@@ -1,35 +1,11 @@
use std::any::TypeId;

use openvm_stark_backend::p3_field::FieldAlgebra;
use openvm_stark_sdk::p3_baby_bear::BabyBear;

use super::{Array, Builder, Config, DslIr, Felt, MemIndex, Var};

impl<C: Config> Builder<C> {
/// Converts a variable to bits.
pub fn num2bits_v(&mut self, num: Var<C::N>, num_bits: u32) -> Array<C, Var<C::N>> {
self.push(DslIr::HintBitsV(num, num_bits));

let output = self.dyn_array::<Var<_>>(num_bits as usize);

let sum: Var<_> = self.eval(C::N::ZERO);
for i in 0..num_bits as usize {
let index = MemIndex {
index: i.into(),
offset: 0,
size: 1,
};
self.push(DslIr::StoreHintWord(output.ptr(), index));

let bit = self.get(&output, i);
self.assert_var_eq(bit * (bit - C::N::ONE), C::N::ZERO);
self.assign(&sum, sum + bit * C::N::from_canonical_u32(1 << i));
}

// FIXME: There is an edge case where the witnessed bits may slightly overflow and cause
// the output to be incorrect.
self.assert_var_eq(sum, num);

output
}

/// Converts a variable to bits inside a circuit.
pub fn num2bits_v_circuit(&mut self, num: Var<C::N>, bits: usize) -> Vec<Var<C::N>> {
let mut output = Vec::new();
Expand All @@ -42,13 +18,21 @@ impl<C: Config> Builder<C> {
output
}

/// Converts a felt to bits.
/// Converts a felt to bits. Will result in a failed assertion if `num` has more than `num_bits` bits.
/// Only works for C::F = BabyBear
pub fn num2bits_f(&mut self, num: Felt<C::F>, num_bits: u32) -> Array<C, Var<C::N>> {
self.push(DslIr::HintBitsF(num, num_bits));
assert!(TypeId::of::<C::F>() == TypeId::of::<BabyBear>());

self.push(DslIr::HintBitsF(num, num_bits));
let output = self.dyn_array::<Felt<_>>(num_bits as usize);

let sum: Felt<_> = self.eval(C::F::ZERO);
// if `num_bits >= 27`, this will be used to compute b_0 + ... + b_26 * 2^26
// otherwise, this will be 0
let prefix_sum: Felt<_> = self.eval(C::F::ZERO);
// if `num_bits >= 27`, this will be used to compute b_27 + ... + b_30
// otherwise, this will be 0
let suffix_bit_sum: Felt<_> = self.eval(C::F::ZERO);
for i in 0..num_bits as usize {
let index = MemIndex {
index: i.into(),
Expand All @@ -60,12 +44,32 @@ impl<C: Config> Builder<C> {
let bit = self.get(&output, i);
self.assert_felt_eq(bit * (bit - C::F::ONE), C::F::ZERO);
self.assign(&sum, sum + bit * C::F::from_canonical_u32(1 << i));
if i == 26 {
self.assign(&prefix_sum, sum);
}
if i > 26 {
self.assign(&suffix_bit_sum, suffix_bit_sum + bit);
}
}

// FIXME: There is an edge case where the witnessed bits may slightly overflow and cause
// the output to be incorrect.
self.assert_felt_eq(sum, num);

// Check that the bits represent the number without overflow.
// If F is BabyBear, then any element of F can be represented either as:
// * 2^30 + ... + 2^x + y for y in [0, 2^(x - 1)) and 27 < x <= 30
// * 2^30 + ... + 2^27
// * y for y in [0, 2^27)
// To check that bits `b[0], ..., b[30]` represent `num = b[0] + ... + b[30] * 2^30` without overflow,
// we may check that:
// * if `num_bits < 27`, then `b[30] = 0`, so overflow is impossible.
// In this case, `suffix_bit_sum = 0`, so the check below passes.
// * if `num_bits >= 27`, then we must check:
// if `suffix_bit_sum = b[27] + ... + b[30] = 4`, then `prefix_sum = b[0] + ... + b[26] * 2^26 = 0`
let suffix_bit_sum_var = self.cast_felt_to_var(suffix_bit_sum);
self.if_eq(suffix_bit_sum_var, C::N::from_canonical_u32(4))
.then(|builder| {
builder.assert_felt_eq(prefix_sum, C::F::ZERO);
});

// Cast Array<C, Felt<C::F>> to Array<C, Var<C::N>>
Array::Dyn(output.ptr(), output.len())
}
Expand Down
6 changes: 0 additions & 6 deletions extensions/native/compiler/src/ir/instructions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -215,8 +215,6 @@ pub enum DslIr<C: Config> {
/// Prepare next input vector (preceded by its length) for hinting.
HintInputVec(),
/// Prepare bit decomposition for hinting.
HintBitsV(Var<C::N>, u32),
/// Prepare bit decomposition for hinting.
HintBitsF(Felt<C::F>, u32),

StoreHintWord(Ptr<C::N>, MemIndex<C::N>),
Expand Down Expand Up @@ -264,10 +262,6 @@ pub enum DslIr<C: Config> {
Ext<C::F, C::EF>,
),

// Debugging instructions.
/// Executes less than (var = var < var). This operation is NOT constrained.
LessThan(Var<C::N>, Var<C::N>, Var<C::N>),

/// Start the cycle tracker used by a block of code annotated by the string input. Calling this with the same
/// string will end the open cycle tracker instance and start a new one with an increasing numeric postfix.
CycleTrackerStart(String),
Expand Down
31 changes: 1 addition & 30 deletions extensions/native/compiler/tests/hint.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,5 @@
use openvm_native_circuit::execute_program;
use openvm_native_compiler::{
asm::AsmBuilder,
ir::{Felt, RVar, Var},
};
use openvm_native_compiler::{asm::AsmBuilder, ir::Felt};
use openvm_stark_backend::p3_field::{extension::BinomialExtensionField, Field, FieldAlgebra};
use openvm_stark_sdk::p3_baby_bear::BabyBear;

Expand Down Expand Up @@ -34,29 +31,3 @@ fn test_hint_bits_felt() {
println!("{}", program);
execute_program(program, vec![]);
}

#[test]
fn test_hint_bits_var() {
let mut builder = AsmBuilder::<F, EF>::default();

let var: Var<_> = builder.constant(F::from_canonical_u32(5));
let bits = builder.num2bits_v(var, F::bits() as u32);

let x = builder.get(&bits, RVar::zero());
builder.assert_var_eq(x, F::ONE);
let x = builder.get(&bits, RVar::one());
builder.assert_var_eq(x, F::ZERO);
let x = builder.get(&bits, 2);
builder.assert_var_eq(x, F::ONE);

for i in 3..31 {
let x = builder.get(&bits, i);
builder.assert_var_eq(x, F::ZERO);
}

builder.halt();

let program = builder.compile_isa();
println!("{}", program);
execute_program(program, vec![]);
}
2 changes: 1 addition & 1 deletion extensions/native/recursion/src/halo2/tests/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ fn test_publish() {
}

#[test]
fn test_num2bits_v() {
fn test_num2bits_v_circuit() {
let mut builder = Builder::<OuterConfig>::default();
builder.flags.static_only = true;
let mut value_u32 = 1345237507;
Expand Down

0 comments on commit ee7129a

Please sign in to comment.