Skip to content

Commit

Permalink
feat: positive tests passing
Browse files Browse the repository at this point in the history
  • Loading branch information
jonathanpwang committed Jan 23, 2025
1 parent ab7e0f8 commit b115924
Show file tree
Hide file tree
Showing 6 changed files with 68 additions and 88 deletions.
25 changes: 8 additions & 17 deletions crates/vm/src/arch/new_integration_api.rs
Original file line number Diff line number Diff line change
Expand Up @@ -45,12 +45,9 @@ pub trait VmAdapter<F>: BaseAir<F> + Clone {
/// Trait to be implemented on a struct that has enough information to determine
/// the adapter row width.
pub trait VmAdapterAir<AB: AirBuilder>: BaseAir<AB::F> {
type AirTx<'tx>
where
Self: 'tx,
AB: 'tx;
type AirTx;

fn air_tx<'a>(&self, local_adapter: &'a [AB::Var]) -> Self::AirTx<'a>;
fn air_tx(&self, local_adapter: &[AB::Var]) -> Self::AirTx;
}

/// Trait to be implemented on primitive chip to integrate with the machine.
Expand Down Expand Up @@ -95,12 +92,12 @@ pub trait VmCoreChip<F, A: VmAdapter<F>> {
}
}

/// The generic `TX` should be an `AirTx` type.
pub trait VmCoreAir<AB, TX>: BaseAirWithPublicValues<AB::F>
pub trait VmCoreAir<AB, A>: BaseAirWithPublicValues<AB::F>
where
AB: AirBuilder,
A: VmAdapterAir<AB>,
{
fn eval(&self, builder: &mut AB, local_core: &[AB::Var], tx: &mut TX);
fn eval(&self, builder: &mut AB, local_core: &[AB::Var], tx: &mut A::AirTx);

/// The offset the opcodes by this chip start from.
/// This is usually just `CorrespondingOpcode::CLASS_OFFSET`,
Expand Down Expand Up @@ -198,14 +195,8 @@ where
A: VmAdapterAir<SymbolicRapBuilder<Val<SC>>>
+ for<'a> VmAdapterAir<DebugConstraintBuilder<'a, SC>>, // AirRef bound
C::Air: Send + Sync + 'static,
C::Air: for<'tx> VmCoreAir<
SymbolicRapBuilder<Val<SC>>,
<A as VmAdapterAir<SymbolicRapBuilder<Val<SC>>>>::AirTx<'tx>,
>,
C::Air: for<'tx, 'a> VmCoreAir<
DebugConstraintBuilder<'a, SC>,
<A as VmAdapterAir<DebugConstraintBuilder<'a, SC>>>::AirTx<'tx>,
>,
C::Air: VmCoreAir<SymbolicRapBuilder<Val<SC>>, A>
+ for<'a> VmCoreAir<DebugConstraintBuilder<'a, SC>, A>,
{
fn air(&self) -> Arc<dyn AnyRap<SC>> {
let air: VmAirWrapper<A, C::Air> = VmAirWrapper {
Expand Down Expand Up @@ -301,7 +292,7 @@ impl<AB, A, C> Air<AB> for VmAirWrapper<A, C>
where
AB: AirBuilder,
A: VmAdapterAir<AB>,
C: for<'tx> VmCoreAir<AB, A::AirTx<'tx>>,
C: VmCoreAir<AB, A>,
{
fn eval(&self, builder: &mut AB) {
let main = builder.main();
Expand Down
3 changes: 0 additions & 3 deletions extensions/rv32im/circuit/src/base_alu/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,6 @@ use crate::adapters::Rv32BaseAluAdapterChip;
mod core;
pub use core::*;

#[cfg(test)]
mod tests;

pub type Rv32BaseAluChip<F> = VmChipWrapper<
F,
Rv32BaseAluAdapterChip<F>,
Expand Down
78 changes: 40 additions & 38 deletions extensions/rv32im/circuit/src/new_adapter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -79,26 +79,26 @@ impl<F> BaseAir<F> for Rv32RegisterAdapter {
}

impl<AB: AirBuilder> VmAdapterAir<AB> for Rv32RegisterAdapter {
type AirTx<'tx>
= Rv32RegisterAirTx<'tx, AB>
where
Self: 'tx,
AB: 'tx;
type AirTx = Rv32RegisterAirTx<AB>;

fn air_tx<'a>(&self, local_adapter: &'a [AB::Var]) -> Rv32RegisterAirTx<'a, AB> {
fn air_tx(&self, local_adapter: &[AB::Var]) -> Rv32RegisterAirTx<AB> {
Rv32RegisterAirTx {
port: self.port,
row_buffer: local_adapter,
row_buffer: local_adapter.to_vec(),
pos: 0,
cur_timestamp: None,
instr_multiplicity: AB::Expr::ZERO,
from_state: None,
}
}
}

pub struct Rv32RegisterAirTx<'a, AB: AirBuilder> {
pub struct Rv32RegisterAirTx<AB: AirBuilder> {
port: SystemPort,
row_buffer: &'a [AB::Var],
// We use Vec instead of slice because there are some lifetime issues around
// AB needing to outlive 'tx which Rust GATs can't handle yet.
row_buffer: Vec<AB::Var>,
pos: usize,
pub cur_timestamp: Option<AB::Expr>,
/// Multiplicity to use for program and execution bus
instr_multiplicity: AB::Expr,
Expand All @@ -107,7 +107,7 @@ pub struct Rv32RegisterAirTx<'a, AB: AirBuilder> {
// will be mutable borrow issues (you can't share a mutable reference)
}

impl<AB: AirBuilder> Drop for Rv32RegisterAirTx<'_, AB> {
impl<AB: AirBuilder> Drop for Rv32RegisterAirTx<AB> {
fn drop(&mut self) {
assert!(self.cur_timestamp.is_none(), "Transaction was never ended");
}
Expand Down Expand Up @@ -142,12 +142,12 @@ const READ_WIDTH: usize = size_of::<Rv32RegisterReadCols<u8>>();
const READ_IMM_WIDTH: usize = size_of::<Rv32RegOrImmReadCols<u8>>();
const WRITE_WIDTH: usize = size_of::<Rv32RegisterWriteCols<u8>>();

impl<AB: InteractionBuilder> AirTx<AB> for Rv32RegisterAirTx<'_, AB> {
impl<AB: InteractionBuilder> AirTx<AB> for Rv32RegisterAirTx<AB> {
fn start(&mut self, _builder: &mut AB, multiplicity: impl Into<AB::Expr>) {
self.instr_multiplicity = multiplicity.into();
let (local, remaining) = self.row_buffer.split_at(STATE_WIDTH);
self.row_buffer = remaining;
let from_state: &ExecutionState<AB::Var> = local.borrow();
let pos = self.pos;
let from_state: &ExecutionState<AB::Var> = self.row_buffer[pos..pos + STATE_WIDTH].borrow();
self.pos += STATE_WIDTH;
self.cur_timestamp = Some(from_state.timestamp.into());
self.from_state = Some(ExecutionState::new(from_state.pc, from_state.timestamp));
}
Expand Down Expand Up @@ -186,32 +186,28 @@ impl<AB: InteractionBuilder> AirTx<AB> for Rv32RegisterAirTx<'_, AB> {
}
}

impl<AB: AirBuilder> Rv32RegisterAirTx<'_, AB> {
impl<AB: AirBuilder> Rv32RegisterAirTx<AB> {
pub fn set_cur_timestamp(&mut self, timestamp: impl Into<AB::Expr>) {
self.cur_timestamp = Some(timestamp.into());
}

fn timestamp_pp(&mut self) -> AB::Expr {
let cur_timestamp = self.cur_timestamp.as_mut().unwrap();
let t = cur_timestamp.clone();
*cur_timestamp = cur_timestamp.clone() + AB::Expr::ONE;
t
}
}

impl<AB: InteractionBuilder> AirTxRead<AB, [AB::Expr; RV32_REGISTER_NUM_LIMBS]>
for Rv32RegisterAirTx<'_, AB>
for Rv32RegisterAirTx<AB>
{
fn read(
&mut self,
builder: &mut AB,
data: [AB::Expr; RV32_REGISTER_NUM_LIMBS],
multiplicity: impl Into<AB::Expr>,
) -> MemoryAddress<AB::Expr, AB::Expr> {
let (local, remaining) = self.row_buffer.split_at(READ_WIDTH);
self.row_buffer = remaining;
let local: &Rv32RegisterReadCols<AB::Var> = local.borrow();
let timestamp = self.timestamp_pp();
let pos = self.pos;
let local: &Rv32RegisterReadCols<AB::Var> = self.row_buffer[pos..pos + READ_WIDTH].borrow();
self.pos += READ_WIDTH;
// Annoyance: we cannot make self.timestamp_pp() a function due to selective mutable borrow. This may be possible in a newer Rust version.
let cur_timestamp = self.cur_timestamp.as_mut().unwrap();
let timestamp = cur_timestamp.clone();
*cur_timestamp = cur_timestamp.clone() + AB::Expr::ONE;
let addr = MemoryAddress::new(
AB::Expr::from_canonical_u32(RV32_REGISTER_AS),
local.ptr.into(),
Expand All @@ -225,7 +221,7 @@ impl<AB: InteractionBuilder> AirTxRead<AB, [AB::Expr; RV32_REGISTER_NUM_LIMBS]>
}

impl<AB: InteractionBuilder> AirTxMaybeRead<AB, [AB::Expr; RV32_REGISTER_NUM_LIMBS]>
for Rv32RegisterAirTx<'_, AB>
for Rv32RegisterAirTx<AB>
{
/// Memory bridge multiplicity is equal to `address_space`, which is 0 or 1.
/// In particular, dummy rows should set `address_space` to 0
Expand All @@ -236,9 +232,10 @@ impl<AB: InteractionBuilder> AirTxMaybeRead<AB, [AB::Expr; RV32_REGISTER_NUM_LIM
data: [AB::Expr; RV32_REGISTER_NUM_LIMBS],
_multiplicity: impl Into<AB::Expr>,
) -> MemoryAddress<AB::Expr, AB::Expr> {
let (local, remaining) = self.row_buffer.split_at(READ_IMM_WIDTH);
self.row_buffer = remaining;
let local: &Rv32RegOrImmReadCols<AB::Var> = local.borrow();
let pos = self.pos;
let local: &Rv32RegOrImmReadCols<AB::Var> =
self.row_buffer[pos..pos + READ_IMM_WIDTH].borrow();
self.pos += READ_IMM_WIDTH;

// if an immediate value, constrain that its 4-byte representation is correct
let rs_sign = data[2].clone();
Expand All @@ -253,7 +250,9 @@ impl<AB: InteractionBuilder> AirTxMaybeRead<AB, [AB::Expr; RV32_REGISTER_NUM_LIM
rs_sign.clone() * (AB::Expr::from_canonical_usize((1 << RV32_CELL_BITS) - 1) - rs_sign),
);

let timestamp = self.timestamp_pp();
let cur_timestamp = self.cur_timestamp.as_mut().unwrap();
let timestamp = cur_timestamp.clone();
*cur_timestamp = cur_timestamp.clone() + AB::Expr::ONE;
let addr = MemoryAddress::new(local.address_space.into(), local.ptr_or_imm.into());
self.port
.memory_bridge
Expand All @@ -264,18 +263,21 @@ impl<AB: InteractionBuilder> AirTxMaybeRead<AB, [AB::Expr; RV32_REGISTER_NUM_LIM
}

impl<AB: InteractionBuilder> AirTxWrite<AB, [AB::Expr; RV32_REGISTER_NUM_LIMBS]>
for Rv32RegisterAirTx<'_, AB>
for Rv32RegisterAirTx<AB>
{
fn write(
&mut self,
builder: &mut AB,
data: [AB::Expr; RV32_REGISTER_NUM_LIMBS],
multiplicity: impl Into<AB::Expr>,
) -> MemoryAddress<AB::Expr, AB::Expr> {
let (local, remaining) = self.row_buffer.split_at(WRITE_WIDTH);
self.row_buffer = remaining;
let local: &Rv32RegisterWriteCols<AB::Var> = local.borrow();
let timestamp = self.timestamp_pp();
let pos = self.pos;
let local: &Rv32RegisterWriteCols<AB::Var> =
self.row_buffer[pos..pos + WRITE_WIDTH].borrow();
self.pos += WRITE_WIDTH;
let cur_timestamp = self.cur_timestamp.as_mut().unwrap();
let timestamp = cur_timestamp.clone();
*cur_timestamp = cur_timestamp.clone() + AB::Expr::ONE;
let addr = MemoryAddress::new(
AB::Expr::from_canonical_u32(RV32_REGISTER_AS),
local.ptr.into(),
Expand Down Expand Up @@ -312,7 +314,7 @@ impl<F> ExecuteTx for Rv32RegisterExecuteTx<F> {
}

fn end(&mut self) -> u32 {
self.from_pc.unwrap() + DEFAULT_PC_STEP
self.from_pc.take().unwrap() + DEFAULT_PC_STEP
}
}

Expand Down
11 changes: 6 additions & 5 deletions extensions/rv32im/circuit/src/new_base_alu/core.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ use std::{

use openvm_circuit::{
arch::{
new_integration_api::{VmAdapter, VmCoreAir, VmCoreChip},
new_integration_api::{VmAdapter, VmAdapterAir, VmCoreAir, VmCoreChip},
AirTx, AirTxMaybeRead, AirTxRead, AirTxWrite, ExecuteTx, ExecuteTxMaybeRead, ExecuteTxRead,
ExecuteTxWrite, Result, TraceTx, TraceTxMaybeRead, TraceTxRead, TraceTxWrite,
},
Expand Down Expand Up @@ -60,16 +60,17 @@ impl<F: Field, const NUM_LIMBS: usize, const LIMB_BITS: usize> BaseAirWithPublic
{
}

impl<AB, TX, const NUM_LIMBS: usize, const LIMB_BITS: usize> VmCoreAir<AB, TX>
impl<AB, A, const NUM_LIMBS: usize, const LIMB_BITS: usize> VmCoreAir<AB, A>
for BaseAluCoreAir<NUM_LIMBS, LIMB_BITS>
where
AB: InteractionBuilder,
TX: AirTx<AB>
A: VmAdapterAir<AB>,
A::AirTx: AirTx<AB>
+ AirTxRead<AB, [AB::Expr; NUM_LIMBS]>
+ AirTxMaybeRead<AB, [AB::Expr; NUM_LIMBS]>
+ AirTxWrite<AB, [AB::Expr; NUM_LIMBS]>,
{
fn eval(&self, builder: &mut AB, local_core: &[AB::Var], tx: &mut TX) {
fn eval(&self, builder: &mut AB, local_core: &[AB::Var], tx: &mut A::AirTx) {
let cols: &BaseAluCoreCols<_, NUM_LIMBS, LIMB_BITS> = local_core.borrow();
let flags = [
cols.opcode_add_flag,
Expand Down Expand Up @@ -141,7 +142,7 @@ where
.eval(builder, is_valid.clone());
}

let expected_opcode = VmCoreAir::<AB, TX>::expr_to_global_expr(
let expected_opcode = VmCoreAir::<AB, A>::expr_to_global_expr(
self,
flags.iter().zip(BaseAluOpcode::iter()).fold(
AB::Expr::ZERO,
Expand Down
15 changes: 6 additions & 9 deletions extensions/rv32im/circuit/src/new_base_alu/mod.rs
Original file line number Diff line number Diff line change
@@ -1,16 +1,13 @@
use openvm_circuit::arch::VmChipWrapper;
use openvm_circuit::arch::new_integration_api::VmChipWrapper;

use super::adapters::{RV32_CELL_BITS, RV32_REGISTER_NUM_LIMBS};
use crate::adapters::Rv32BaseAluAdapterChip;
use crate::new_adapter::Rv32RegisterAdapter;

mod core;
pub use core::*;

// #[cfg(test)]
// mod tests;
#[cfg(test)]
mod tests;

pub type Rv32BaseAluChip<F> = VmChipWrapper<
F,
Rv32BaseAluAdapterChip<F>,
BaseAluCoreChip<RV32_REGISTER_NUM_LIMBS, RV32_CELL_BITS>,
>;
pub type Rv32BaseAluChip<F> =
VmChipWrapper<F, Rv32RegisterAdapter, BaseAluCoreChip<RV32_REGISTER_NUM_LIMBS, RV32_CELL_BITS>>;
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ use std::borrow::BorrowMut;
use openvm_circuit::{
arch::{
testing::{TestAdapterChip, VmChipTestBuilder},
ExecutionBridge, VmAdapterChip, VmChipWrapper, BITWISE_OP_LOOKUP_BUS,
ExecutionBridge, BITWISE_OP_LOOKUP_BUS,
},
utils::generate_long_number,
};
Expand All @@ -28,8 +28,9 @@ use rand::Rng;

use super::{core::run_alu, BaseAluCoreChip, Rv32BaseAluChip};
use crate::{
adapters::{Rv32BaseAluAdapterChip, RV32_CELL_BITS, RV32_REGISTER_NUM_LIMBS},
adapters::{RV32_CELL_BITS, RV32_REGISTER_NUM_LIMBS},
base_alu::BaseAluCoreCols,
new_adapter::Rv32RegisterAdapter,
test_utils::{generate_rv32_is_type_immediate, rv32_rand_write_register_or_imm},
};

Expand All @@ -49,11 +50,7 @@ fn run_rv32_alu_rand_test(opcode: BaseAluOpcode, num_ops: usize) {

let mut tester = VmChipTestBuilder::default();
let mut chip = Rv32BaseAluChip::<F>::new(
Rv32BaseAluAdapterChip::new(
tester.execution_bus(),
tester.program_bus(),
tester.memory_bridge(),
),
Rv32RegisterAdapter::new(tester.system_port(), 1, 1, 1),
BaseAluCoreChip::new(bitwise_chip.clone(), BaseAluOpcode::CLASS_OFFSET),
tester.offline_memory_mutex_arc(),
);
Expand Down Expand Up @@ -122,8 +119,7 @@ fn rv32_alu_and_rand_test() {
// A dummy adapter is used so memory interactions don't indirectly cause false passes.
//////////////////////////////////////////////////////////////////////////////////////

type Rv32BaseAluTestChip<F> =
VmChipWrapper<F, TestAdapterChip<F>, BaseAluCoreChip<RV32_REGISTER_NUM_LIMBS, RV32_CELL_BITS>>;
// TODO: make new prank adapter

#[allow(clippy::too_many_arguments)]
fn run_rv32_alu_negative_test(
Expand All @@ -137,12 +133,8 @@ fn run_rv32_alu_negative_test(
let bitwise_chip = SharedBitwiseOperationLookupChip::<RV32_CELL_BITS>::new(bitwise_bus);

let mut tester: VmChipTestBuilder<BabyBear> = VmChipTestBuilder::default();
let mut chip = Rv32BaseAluTestChip::<F>::new(
TestAdapterChip::new(
vec![[b.map(F::from_canonical_u32), c.map(F::from_canonical_u32)].concat()],
vec![None],
ExecutionBridge::new(tester.execution_bus(), tester.program_bus()),
),
let mut chip = Rv32BaseAluChip::<F>::new(
Rv32RegisterAdapter::new(tester.system_port(), 1, 1, 1),
BaseAluCoreChip::new(bitwise_chip.clone(), BaseAluOpcode::CLASS_OFFSET),
tester.offline_memory_mutex_arc(),
);
Expand All @@ -153,7 +145,7 @@ fn run_rv32_alu_negative_test(
);

let trace_width = chip.trace_width();
let adapter_width = BaseAir::<F>::width(chip.adapter.air());
let adapter_width = BaseAir::<F>::width(&chip.adapter);

if (opcode == BaseAluOpcode::ADD || opcode == BaseAluOpcode::SUB)
&& a.iter().all(|&a_val| a_val < (1 << RV32_CELL_BITS))
Expand Down

0 comments on commit b115924

Please sign in to comment.