Skip to content

Commit

Permalink
chore: Use aux_cols_factory.generate_*_aux
Browse files Browse the repository at this point in the history
  • Loading branch information
zlangley committed Jan 23, 2025
1 parent 40dc754 commit 486a166
Show file tree
Hide file tree
Showing 25 changed files with 247 additions and 298 deletions.
56 changes: 17 additions & 39 deletions crates/vm/src/system/memory/controller/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -41,8 +41,8 @@ use crate::{
merkle::{MemoryMerkleBus, MemoryMerkleChip},
offline::{MemoryRecord, OfflineMemory, INITIAL_TIMESTAMP},
offline_checker::{
MemoryBridge, MemoryBus, MemoryReadAuxCols, MemoryReadOrImmediateAuxCols,
MemoryWriteAuxCols, AUX_LEN,
MemoryBaseAuxCols, MemoryBridge, MemoryBus, MemoryReadAuxCols,
MemoryReadOrImmediateAuxCols, MemoryWriteAuxCols, AUX_LEN,
},
online::{Memory, MemoryLogEntry},
persistent::PersistentBoundaryChip,
Expand Down Expand Up @@ -719,12 +719,7 @@ impl<F: PrimeField32> MemoryAuxColsFactory<F> {
!read.address_space.is_zero(),
"cannot make `MemoryReadAuxCols` for address space 0"
);
buffer.base.prev_timestamp = F::from_canonical_u32(read.prev_timestamp);
self.generate_timestamp_lt(
read.prev_timestamp,
read.timestamp,
&mut buffer.base.clk_lt_aux,
);
self.generate_base_aux(read, &mut buffer.base);
}

pub fn generate_read_or_immediate_aux(
Expand All @@ -736,26 +731,26 @@ impl<F: PrimeField32> MemoryAuxColsFactory<F> {
read.address_space,
(&mut buffer.is_zero_aux, &mut buffer.is_immediate),
);
buffer.base.prev_timestamp = F::from_canonical_u32(read.prev_timestamp);
self.generate_timestamp_lt(
read.prev_timestamp,
read.timestamp,
&mut buffer.base.clk_lt_aux,
);
self.generate_base_aux(read, &mut buffer.base);
}

pub fn generate_write_aux<const N: usize>(
&self,
write: &MemoryRecord<F>,
buffer: &mut MemoryWriteAuxCols<F, N>,
) {
let prev_data = write.prev_data.clone().unwrap();
buffer.prev_data = prev_data.try_into().unwrap();
buffer.base.prev_timestamp = F::from_canonical_u32(write.prev_timestamp);
buffer
.prev_data
.copy_from_slice(&write.prev_data.as_ref().unwrap());
self.generate_base_aux(write, &mut buffer.base);
}

pub fn generate_base_aux(&self, record: &MemoryRecord<F>, buffer: &mut MemoryBaseAuxCols<F>) {
buffer.prev_timestamp = F::from_canonical_u32(record.prev_timestamp);
self.generate_timestamp_lt(
write.prev_timestamp,
write.timestamp,
&mut buffer.base.clk_lt_aux,
record.prev_timestamp,
record.timestamp,
&mut buffer.timestamp_lt_aux,
);
}

Expand All @@ -772,7 +767,7 @@ impl<F: PrimeField32> MemoryAuxColsFactory<F> {
);
}

// TODO[jpw]: delete these functions and use the ones that fill buffers above.
/// In general, prefer `generate_read_aux` which writes in-place rather than this function.
pub fn make_read_aux_cols(&self, read: &MemoryRecord<F>) -> MemoryReadAuxCols<F> {
assert!(
!read.address_space.is_zero(),
Expand All @@ -784,24 +779,7 @@ impl<F: PrimeField32> MemoryAuxColsFactory<F> {
)
}

pub fn make_read_or_immediate_aux_cols(
&self,
read: &MemoryRecord<F>,
) -> MemoryReadOrImmediateAuxCols<F> {
let mut inv = F::ZERO;
let mut is_zero = F::ZERO;
IsZeroSubAir.generate_subrow(read.address_space, (&mut inv, &mut is_zero));
let timestamp_lt_cols =
self.generate_timestamp_lt_cols(read.prev_timestamp, read.timestamp);

MemoryReadOrImmediateAuxCols::new(
F::from_canonical_u32(read.prev_timestamp),
is_zero,
inv,
timestamp_lt_cols,
)
}

/// In general, prefer `generate_write_aux` which writes in-place rather than this function.
pub fn make_write_aux_cols<const N: usize>(
&self,
write: &MemoryRecord<F>,
Expand Down
2 changes: 1 addition & 1 deletion crates/vm/src/system/memory/offline_checker/bridge.rs
Original file line number Diff line number Diff line change
Expand Up @@ -282,7 +282,7 @@ impl MemoryOfflineChecker {
) {
let lt_io = AssertLessThanIo::new(base.prev_timestamp, timestamp.clone(), enabled);
self.timestamp_lt_air
.eval(builder, (lt_io, &base.clk_lt_aux.lower_decomp));
.eval(builder, (lt_io, &base.timestamp_lt_aux.lower_decomp));
}

/// At the core, eval_bulk_access is a bunch of push_sends and push_receives.
Expand Down
87 changes: 17 additions & 70 deletions crates/vm/src/system/memory/offline_checker/columns.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
use openvm_circuit_primitives::is_less_than::LessThanAuxCols;
use openvm_circuit_primitives_derive::AlignedBorrow;
use openvm_stark_backend::p3_field::{FieldAlgebra, PrimeField32};
use openvm_stark_backend::p3_field::PrimeField32;

use crate::system::memory::offline_checker::bridge::AUX_LEN;

Expand All @@ -14,24 +14,28 @@ use crate::system::memory::offline_checker::bridge::AUX_LEN;
#[derive(Clone, Copy, Debug, AlignedBorrow)]
pub struct MemoryBaseAuxCols<T> {
/// The previous timestamps in which the cells were accessed.
pub(crate) prev_timestamp: T,
pub(in crate::system::memory) prev_timestamp: T,
/// The auxiliary columns to perform the less than check.
pub(crate) clk_lt_aux: LessThanAuxCols<T, AUX_LEN>,
pub(in crate::system::memory) timestamp_lt_aux: LessThanAuxCols<T, AUX_LEN>,
}

#[repr(C)]
#[derive(Clone, Copy, Debug, AlignedBorrow)]
pub struct MemoryWriteAuxCols<T, const N: usize> {
pub base: MemoryBaseAuxCols<T>,
pub prev_data: [T; N],
pub(in crate::system::memory) base: MemoryBaseAuxCols<T>,
pub(in crate::system::memory) prev_data: [T; N],
}

impl<const N: usize, T> MemoryWriteAuxCols<T, N> {
pub fn new(prev_data: [T; N], prev_timestamp: T, lt_aux: LessThanAuxCols<T, AUX_LEN>) -> Self {
pub(in crate::system::memory) fn new(
prev_data: [T; N],
prev_timestamp: T,
lt_aux: LessThanAuxCols<T, AUX_LEN>,
) -> Self {
Self {
base: MemoryBaseAuxCols {
prev_timestamp,
clk_lt_aux: lt_aux,
timestamp_lt_aux: lt_aux,
},
prev_data,
}
Expand All @@ -48,49 +52,25 @@ impl<const N: usize, T> MemoryWriteAuxCols<T, N> {
}
}

impl<const N: usize, F: FieldAlgebra> MemoryWriteAuxCols<F, N> {
pub const fn disabled() -> Self {
Self {
base: MemoryBaseAuxCols {
prev_timestamp: F::ZERO,
clk_lt_aux: LessThanAuxCols {
lower_decomp: [F::ZERO; AUX_LEN],
},
},
prev_data: [F::ZERO; N],
}
}
}

/// The auxiliary columns for a memory read operation with block size `N`.
/// These columns should be automatically managed by the memory controller.
/// To fully constrain a memory read, in addition to these columns,
/// the address space, pointer, and data must be provided.
#[repr(C)]
#[derive(Clone, Copy, Debug, AlignedBorrow)]
pub struct MemoryReadAuxCols<T> {
pub(crate) base: MemoryBaseAuxCols<T>,
pub(in crate::system::memory) base: MemoryBaseAuxCols<T>,
}

impl<F: PrimeField32> MemoryReadAuxCols<F> {
pub fn new(prev_timestamp: u32, clk_lt_aux: LessThanAuxCols<F, AUX_LEN>) -> Self {
pub(in crate::system::memory) fn new(
prev_timestamp: u32,
timestamp_lt_aux: LessThanAuxCols<F, AUX_LEN>,
) -> Self {
Self {
base: MemoryBaseAuxCols {
prev_timestamp: F::from_canonical_u32(prev_timestamp),
clk_lt_aux,
},
}
}
}

impl<F: FieldAlgebra + Copy> MemoryReadAuxCols<F> {
pub const fn disabled() -> Self {
Self {
base: MemoryBaseAuxCols {
prev_timestamp: F::ZERO,
clk_lt_aux: LessThanAuxCols {
lower_decomp: [F::ZERO; AUX_LEN],
},
timestamp_lt_aux,
},
}
}
Expand All @@ -103,36 +83,3 @@ pub struct MemoryReadOrImmediateAuxCols<T> {
pub(crate) is_immediate: T,
pub(crate) is_zero_aux: T,
}

impl<T> MemoryReadOrImmediateAuxCols<T> {
pub fn new(
prev_timestamp: T,
is_immediate: T,
is_zero_aux: T,
clk_lt_aux: LessThanAuxCols<T, AUX_LEN>,
) -> Self {
Self {
base: MemoryBaseAuxCols {
prev_timestamp,
clk_lt_aux,
},
is_immediate,
is_zero_aux,
}
}
}

impl<F: FieldAlgebra + Copy> MemoryReadOrImmediateAuxCols<F> {
pub const fn disabled() -> Self {
MemoryReadOrImmediateAuxCols {
base: MemoryBaseAuxCols {
prev_timestamp: F::ZERO,
clk_lt_aux: LessThanAuxCols {
lower_decomp: [F::ZERO; AUX_LEN],
},
},
is_immediate: F::ZERO,
is_zero_aux: F::ZERO,
}
}
}
10 changes: 5 additions & 5 deletions crates/vm/src/system/memory/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -168,27 +168,27 @@ fn generate_trace<F: PrimeField32>(

match (record.data.len(), &record.prev_data) {
(1, &None) => {
row.read_1_aux = aux_factory.make_read_aux_cols(&record);
aux_factory.generate_read_aux(&record, &mut row.read_1_aux);
row.data_1 = record.data.try_into().unwrap();
row.is_read_1 = F::ONE;
}
(1, &Some(_)) => {
row.write_1_aux = aux_factory.make_write_aux_cols(&record);
aux_factory.generate_write_aux(&record, &mut row.write_1_aux);
row.data_1 = record.data.try_into().unwrap();
row.is_write_1 = F::ONE;
}
(4, &None) => {
row.read_4_aux = aux_factory.make_read_aux_cols(&record);
aux_factory.generate_read_aux(&record, &mut row.read_4_aux);
row.data_4 = record.data.try_into().unwrap();
row.is_read_4 = F::ONE;
}
(4, &Some(_)) => {
row.write_4_aux = aux_factory.make_write_aux_cols(&record);
aux_factory.generate_write_aux(&record, &mut row.write_4_aux);
row.data_4 = record.data.try_into().unwrap();
row.is_write_4 = F::ONE;
}
(MAX, &None) => {
row.read_max_aux = aux_factory.make_read_aux_cols(&record);
aux_factory.generate_read_aux(&record, &mut row.read_max_aux);
row.data_max = record.data.try_into().unwrap();
row.is_read_max = F::ONE;
}
Expand Down
32 changes: 16 additions & 16 deletions crates/vm/src/system/native_adapter/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -279,22 +279,22 @@ impl<F: PrimeField32, const R: usize, const W: usize> VmAdapterChip<F>

row_slice.from_state = write_record.from_state.map(F::from_canonical_u32);

row_slice.reads_aux = read_record.reads.map(|(id, _)| {
let record = memory.record_by_id(id);
let address = MemoryAddress::new(record.address_space, record.pointer);
NativeAdapterReadCols {
address,
read_aux: aux_cols_factory.make_read_or_immediate_aux_cols(record),
}
});
row_slice.writes_aux = write_record.writes.map(|(id, _)| {
let record = memory.record_by_id(id);
let address = MemoryAddress::new(record.address_space, record.pointer);
NativeAdapterWriteCols {
address,
write_aux: aux_cols_factory.make_write_aux_cols(record),
}
});
for (i, read) in read_record.reads.iter().enumerate() {
let (id, _) = read;
let record = memory.record_by_id(*id);
aux_cols_factory
.generate_read_or_immediate_aux(record, &mut row_slice.reads_aux[i].read_aux);
row_slice.reads_aux[i].address =
MemoryAddress::new(record.address_space, record.pointer);
}

for (i, write) in write_record.writes.iter().enumerate() {
let (id, _) = write;
let record = memory.record_by_id(*id);
aux_cols_factory.generate_write_aux(record, &mut row_slice.writes_aux[i].write_aux);
row_slice.writes_aux[i].address =
MemoryAddress::new(record.address_space, record.pointer);
}
}

fn air(&self) -> &Self::Air {
Expand Down
18 changes: 10 additions & 8 deletions extensions/keccak256/circuit/src/trace.rs
Original file line number Diff line number Diff line change
Expand Up @@ -219,15 +219,17 @@ where
);
}
for (i, id) in register_reads.into_iter().enumerate() {
// TODO[jpw] make_read_aux_cols should directly write into slice
first_row.mem_oc.register_aux[i] =
aux_cols_factory.make_read_aux_cols(memory.record_by_id(id));
aux_cols_factory.generate_read_aux(
memory.record_by_id(id),
&mut first_row.mem_oc.register_aux[i],
);
}
}
for (i, id) in block.reads.into_iter().enumerate() {
// TODO[jpw] make_read_aux_cols should directly write into slice
first_row.mem_oc.absorb_reads[i] =
aux_cols_factory.make_read_aux_cols(memory.record_by_id(id));
aux_cols_factory.generate_read_aux(
memory.record_by_id(id),
&mut first_row.mem_oc.absorb_reads[i],
);
}

let last_row: &mut KeccakVmCols<Val<SC>> =
Expand All @@ -239,8 +241,8 @@ where
for (i, record_id) in digest_writes.into_iter().enumerate() {
// TODO: these aux columns are only used for the last row - can we share them with aux reads in first row?
let record = memory.record_by_id(record_id);
last_row.mem_oc.digest_writes[i] =
aux_cols_factory.make_write_aux_cols(record);
aux_cols_factory
.generate_write_aux(record, &mut last_row.mem_oc.digest_writes[i]);
}
}
});
Expand Down
13 changes: 6 additions & 7 deletions extensions/native/circuit/src/adapters/branch_native_adapter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -187,14 +187,13 @@ impl<F: PrimeField32> VmAdapterChip<F> for BranchNativeAdapterChip<F> {
let aux_cols_factory = memory.aux_cols_factory();

row_slice.from_state = write_record.map(F::from_canonical_u32);
row_slice.reads_aux = read_record.reads.map(|x| {
for (i, x) in read_record.reads.iter().enumerate() {
let read = memory.record_by_id(x.0);
let address = MemoryAddress::new(read.address_space, read.pointer);
BranchNativeAdapterReadCols {
address,
read_aux: aux_cols_factory.make_read_or_immediate_aux_cols(read),
}
});

row_slice.reads_aux[i].address = MemoryAddress::new(read.address_space, read.pointer);
aux_cols_factory
.generate_read_or_immediate_aux(read, &mut row_slice.reads_aux[i].read_aux);
}
}

fn air(&self) -> &Self::Air {
Expand Down
4 changes: 2 additions & 2 deletions extensions/native/circuit/src/adapters/convert_adapter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -216,8 +216,8 @@ impl<F: PrimeField32, const READ_SIZE: usize, const WRITE_SIZE: usize> VmAdapter
row_slice.b_pointer = read.pointer;
row_slice.b_as = read.address_space;

row_slice.reads_aux = [aux_cols_factory.make_read_aux_cols(read)];
row_slice.writes_aux = [aux_cols_factory.make_write_aux_cols(write)];
aux_cols_factory.generate_read_aux(read, &mut row_slice.reads_aux[0]);
aux_cols_factory.generate_write_aux(write, &mut row_slice.writes_aux[0]);
}

fn air(&self) -> &Self::Air {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -171,7 +171,7 @@ impl<F: PrimeField32> VmAdapterChip<F> for JalNativeAdapterChip<F> {
row_slice.from_state = write_record.from_state.map(F::from_canonical_u32);
row_slice.a_pointer = write.pointer;
row_slice.a_as = write.address_space;
row_slice.writes_aux = aux_cols_factory.make_write_aux_cols(write);
aux_cols_factory.generate_write_aux(write, &mut row_slice.writes_aux);
}

fn air(&self) -> &Self::Air {
Expand Down
Loading

0 comments on commit 486a166

Please sign in to comment.