From 59533bf096b71c55cd113da04d0d3cb8df0b551e Mon Sep 17 00:00:00 2001 From: Xinding Wei Date: Mon, 2 Dec 2024 15:01:11 -0800 Subject: [PATCH 1/3] Optimize reduce_32 --- extensions/native/recursion/src/utils.rs | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/extensions/native/recursion/src/utils.rs b/extensions/native/recursion/src/utils.rs index 67fb45cbd..afbc2a4eb 100644 --- a/extensions/native/recursion/src/utils.rs +++ b/extensions/native/recursion/src/utils.rs @@ -40,8 +40,7 @@ pub fn reduce_32(builder: &mut Builder, vals: &[Felt]) -> Va let mut power = C::N::ONE; let result: Var = builder.eval(C::N::ZERO); for val in vals.iter() { - let bits = builder.num2bits_f_circuit(*val); - let val = builder.bits2num_v_circuit(&bits); + let val = builder.cast_felt_to_var(*val); builder.assign(&result, result + val * power); power *= C::N::from_canonical_usize(1usize << 32); } From 438ee3f6b2a41386cde466d90fe24a6915197878 Mon Sep 17 00:00:00 2001 From: Jonathan Wang <31040440+jonathanpwang@users.noreply.github.com> Date: Mon, 2 Dec 2024 19:51:19 -0500 Subject: [PATCH 2/3] safety: static compiler CastFV max_bits bound --- extensions/native/compiler/src/constraints/halo2/compiler.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/extensions/native/compiler/src/constraints/halo2/compiler.rs b/extensions/native/compiler/src/constraints/halo2/compiler.rs index 558c80c23..d64af0227 100644 --- a/extensions/native/compiler/src/constraints/halo2/compiler.rs +++ b/extensions/native/compiler/src/constraints/halo2/compiler.rs @@ -249,7 +249,7 @@ impl Halo2ConstraintCompiler { } DslIr::CastFV(a, b) => { let felt = felts[&b.0]; - let reduced_felt = if felt.max_bits > BABYBEAR_MAX_BITS { + let reduced_felt = if felt.max_bits >= BABYBEAR_MAX_BITS { f_chip.reduce(ctx, felt) } else { felt From 90d4139c88287053800d261b6e8f3ec7d6668564 Mon Sep 17 00:00:00 2001 From: Jonathan Wang <31040440+jonathanpwang@users.noreply.github.com> Date: Mon, 2 Dec 2024 20:08:35 -0500 Subject: [PATCH 3/3] opt: handle equal bits case --- extensions/native/compiler/Cargo.toml | 2 +- .../src/constraints/halo2/compiler.rs | 20 +++++++++++++++---- 2 files changed, 17 insertions(+), 5 deletions(-) diff --git a/extensions/native/compiler/Cargo.toml b/extensions/native/compiler/Cargo.toml index f5f0ddd56..26d64368e 100644 --- a/extensions/native/compiler/Cargo.toml +++ b/extensions/native/compiler/Cargo.toml @@ -45,7 +45,7 @@ rand = "0.8.5" axvm-circuit = { workspace = true, features = ["test-utils"] } [features] -default = ["parallel"] +default = ["parallel", "halo2-compiler"] halo2-compiler = ["dep:snark-verifier-sdk"] parallel = ["axvm-circuit/parallel"] bench-metrics = ["dep:metrics", "axvm-circuit/bench-metrics"] diff --git a/extensions/native/compiler/src/constraints/halo2/compiler.rs b/extensions/native/compiler/src/constraints/halo2/compiler.rs index d64af0227..baf80daf0 100644 --- a/extensions/native/compiler/src/constraints/halo2/compiler.rs +++ b/extensions/native/compiler/src/constraints/halo2/compiler.rs @@ -10,13 +10,15 @@ use axvm_circuit::metrics::cycle_tracker::CycleTracker; use itertools::Itertools; use p3_baby_bear::BabyBear; use p3_bn254_fr::Bn254Fr; -use p3_field::{ExtensionField, PrimeField}; +use p3_field::{ExtensionField, PrimeField, PrimeField32}; use snark_verifier_sdk::snark_verifier::{ halo2_base::{ - gates::{circuit::builder::BaseCircuitBuilder, GateInstructions, RangeChip}, + gates::{ + circuit::builder::BaseCircuitBuilder, GateInstructions, RangeChip, RangeInstructions, + }, halo2_proofs::halo2curves::bn256::Fr, utils::{biguint_to_fe, ScalarField}, - Context, + Context, QuantumCell, }, util::arithmetic::PrimeField as _, }; @@ -249,8 +251,18 @@ impl Halo2ConstraintCompiler { } DslIr::CastFV(a, b) => { let felt = felts[&b.0]; - let reduced_felt = if felt.max_bits >= BABYBEAR_MAX_BITS { + #[allow(clippy::comparison_chain)] + let reduced_felt = if felt.max_bits > BABYBEAR_MAX_BITS { f_chip.reduce(ctx, felt) + } else if felt.max_bits == BABYBEAR_MAX_BITS { + // Ensure cast is canonical + f_chip.range.check_less_than( + ctx, + felt.value, + QuantumCell::Constant(Fr::from(BabyBear::ORDER_U32 as u64)), + BABYBEAR_MAX_BITS, + ); + felt } else { felt };