Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

chore: Use aux_cols_factory.generate_*_aux #1263

Merged
merged 1 commit into from
Jan 23, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
56 changes: 17 additions & 39 deletions crates/vm/src/system/memory/controller/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -41,8 +41,8 @@ use crate::{
merkle::{MemoryMerkleBus, MemoryMerkleChip},
offline::{MemoryRecord, OfflineMemory, INITIAL_TIMESTAMP},
offline_checker::{
MemoryBridge, MemoryBus, MemoryReadAuxCols, MemoryReadOrImmediateAuxCols,
MemoryWriteAuxCols, AUX_LEN,
MemoryBaseAuxCols, MemoryBridge, MemoryBus, MemoryReadAuxCols,
MemoryReadOrImmediateAuxCols, MemoryWriteAuxCols, AUX_LEN,
},
online::{Memory, MemoryLogEntry},
persistent::PersistentBoundaryChip,
Expand Down Expand Up @@ -719,12 +719,7 @@ impl<F: PrimeField32> MemoryAuxColsFactory<F> {
!read.address_space.is_zero(),
"cannot make `MemoryReadAuxCols` for address space 0"
);
buffer.base.prev_timestamp = F::from_canonical_u32(read.prev_timestamp);
self.generate_timestamp_lt(
read.prev_timestamp,
read.timestamp,
&mut buffer.base.clk_lt_aux,
);
self.generate_base_aux(read, &mut buffer.base);
}

pub fn generate_read_or_immediate_aux(
Expand All @@ -736,26 +731,26 @@ impl<F: PrimeField32> MemoryAuxColsFactory<F> {
read.address_space,
(&mut buffer.is_zero_aux, &mut buffer.is_immediate),
);
buffer.base.prev_timestamp = F::from_canonical_u32(read.prev_timestamp);
self.generate_timestamp_lt(
read.prev_timestamp,
read.timestamp,
&mut buffer.base.clk_lt_aux,
);
self.generate_base_aux(read, &mut buffer.base);
}

pub fn generate_write_aux<const N: usize>(
&self,
write: &MemoryRecord<F>,
buffer: &mut MemoryWriteAuxCols<F, N>,
) {
let prev_data = write.prev_data.clone().unwrap();
buffer.prev_data = prev_data.try_into().unwrap();
buffer.base.prev_timestamp = F::from_canonical_u32(write.prev_timestamp);
buffer
.prev_data
.copy_from_slice(write.prev_data.as_ref().unwrap());
self.generate_base_aux(write, &mut buffer.base);
}

pub fn generate_base_aux(&self, record: &MemoryRecord<F>, buffer: &mut MemoryBaseAuxCols<F>) {
buffer.prev_timestamp = F::from_canonical_u32(record.prev_timestamp);
self.generate_timestamp_lt(
write.prev_timestamp,
write.timestamp,
&mut buffer.base.clk_lt_aux,
record.prev_timestamp,
record.timestamp,
&mut buffer.timestamp_lt_aux,
);
}

Expand All @@ -772,7 +767,7 @@ impl<F: PrimeField32> MemoryAuxColsFactory<F> {
);
}

// TODO[jpw]: delete these functions and use the ones that fill buffers above.
/// In general, prefer `generate_read_aux` which writes in-place rather than this function.
pub fn make_read_aux_cols(&self, read: &MemoryRecord<F>) -> MemoryReadAuxCols<F> {
assert!(
!read.address_space.is_zero(),
Expand All @@ -784,24 +779,7 @@ impl<F: PrimeField32> MemoryAuxColsFactory<F> {
)
}

pub fn make_read_or_immediate_aux_cols(
&self,
read: &MemoryRecord<F>,
) -> MemoryReadOrImmediateAuxCols<F> {
let mut inv = F::ZERO;
let mut is_zero = F::ZERO;
IsZeroSubAir.generate_subrow(read.address_space, (&mut inv, &mut is_zero));
let timestamp_lt_cols =
self.generate_timestamp_lt_cols(read.prev_timestamp, read.timestamp);

MemoryReadOrImmediateAuxCols::new(
F::from_canonical_u32(read.prev_timestamp),
is_zero,
inv,
timestamp_lt_cols,
)
}

/// In general, prefer `generate_write_aux` which writes in-place rather than this function.
pub fn make_write_aux_cols<const N: usize>(
&self,
write: &MemoryRecord<F>,
Expand Down
2 changes: 1 addition & 1 deletion crates/vm/src/system/memory/offline_checker/bridge.rs
Original file line number Diff line number Diff line change
Expand Up @@ -282,7 +282,7 @@ impl MemoryOfflineChecker {
) {
let lt_io = AssertLessThanIo::new(base.prev_timestamp, timestamp.clone(), enabled);
self.timestamp_lt_air
.eval(builder, (lt_io, &base.clk_lt_aux.lower_decomp));
.eval(builder, (lt_io, &base.timestamp_lt_aux.lower_decomp));
}

/// At the core, eval_bulk_access is a bunch of push_sends and push_receives.
Expand Down
87 changes: 17 additions & 70 deletions crates/vm/src/system/memory/offline_checker/columns.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

use openvm_circuit_primitives::is_less_than::LessThanAuxCols;
use openvm_circuit_primitives_derive::AlignedBorrow;
use openvm_stark_backend::p3_field::{FieldAlgebra, PrimeField32};
use openvm_stark_backend::p3_field::PrimeField32;

use crate::system::memory::offline_checker::bridge::AUX_LEN;

Expand All @@ -14,24 +14,28 @@ use crate::system::memory::offline_checker::bridge::AUX_LEN;
#[derive(Clone, Copy, Debug, AlignedBorrow)]
pub struct MemoryBaseAuxCols<T> {
/// The previous timestamps in which the cells were accessed.
pub(crate) prev_timestamp: T,
pub(in crate::system::memory) prev_timestamp: T,
/// The auxiliary columns to perform the less than check.
pub(crate) clk_lt_aux: LessThanAuxCols<T, AUX_LEN>,
pub(in crate::system::memory) timestamp_lt_aux: LessThanAuxCols<T, AUX_LEN>,
}

#[repr(C)]
#[derive(Clone, Copy, Debug, AlignedBorrow)]
pub struct MemoryWriteAuxCols<T, const N: usize> {
pub base: MemoryBaseAuxCols<T>,
pub prev_data: [T; N],
pub(in crate::system::memory) base: MemoryBaseAuxCols<T>,
pub(in crate::system::memory) prev_data: [T; N],
}

impl<const N: usize, T> MemoryWriteAuxCols<T, N> {
pub fn new(prev_data: [T; N], prev_timestamp: T, lt_aux: LessThanAuxCols<T, AUX_LEN>) -> Self {
pub(in crate::system::memory) fn new(
prev_data: [T; N],
prev_timestamp: T,
lt_aux: LessThanAuxCols<T, AUX_LEN>,
) -> Self {
Self {
base: MemoryBaseAuxCols {
prev_timestamp,
clk_lt_aux: lt_aux,
timestamp_lt_aux: lt_aux,
},
prev_data,
}
Expand All @@ -48,49 +52,25 @@ impl<const N: usize, T> MemoryWriteAuxCols<T, N> {
}
}

impl<const N: usize, F: FieldAlgebra> MemoryWriteAuxCols<F, N> {
pub const fn disabled() -> Self {
Self {
base: MemoryBaseAuxCols {
prev_timestamp: F::ZERO,
clk_lt_aux: LessThanAuxCols {
lower_decomp: [F::ZERO; AUX_LEN],
},
},
prev_data: [F::ZERO; N],
}
}
}

/// The auxiliary columns for a memory read operation with block size `N`.
/// These columns should be automatically managed by the memory controller.
/// To fully constrain a memory read, in addition to these columns,
/// the address space, pointer, and data must be provided.
#[repr(C)]
#[derive(Clone, Copy, Debug, AlignedBorrow)]
pub struct MemoryReadAuxCols<T> {
pub(crate) base: MemoryBaseAuxCols<T>,
pub(in crate::system::memory) base: MemoryBaseAuxCols<T>,
}

impl<F: PrimeField32> MemoryReadAuxCols<F> {
pub fn new(prev_timestamp: u32, clk_lt_aux: LessThanAuxCols<F, AUX_LEN>) -> Self {
pub(in crate::system::memory) fn new(
prev_timestamp: u32,
timestamp_lt_aux: LessThanAuxCols<F, AUX_LEN>,
) -> Self {
Self {
base: MemoryBaseAuxCols {
prev_timestamp: F::from_canonical_u32(prev_timestamp),
clk_lt_aux,
},
}
}
}

impl<F: FieldAlgebra + Copy> MemoryReadAuxCols<F> {
pub const fn disabled() -> Self {
Self {
base: MemoryBaseAuxCols {
prev_timestamp: F::ZERO,
clk_lt_aux: LessThanAuxCols {
lower_decomp: [F::ZERO; AUX_LEN],
},
timestamp_lt_aux,
},
}
}
Expand All @@ -103,36 +83,3 @@ pub struct MemoryReadOrImmediateAuxCols<T> {
pub(crate) is_immediate: T,
pub(crate) is_zero_aux: T,
}

impl<T> MemoryReadOrImmediateAuxCols<T> {
pub fn new(
prev_timestamp: T,
is_immediate: T,
is_zero_aux: T,
clk_lt_aux: LessThanAuxCols<T, AUX_LEN>,
) -> Self {
Self {
base: MemoryBaseAuxCols {
prev_timestamp,
clk_lt_aux,
},
is_immediate,
is_zero_aux,
}
}
}

impl<F: FieldAlgebra + Copy> MemoryReadOrImmediateAuxCols<F> {
pub const fn disabled() -> Self {
MemoryReadOrImmediateAuxCols {
base: MemoryBaseAuxCols {
prev_timestamp: F::ZERO,
clk_lt_aux: LessThanAuxCols {
lower_decomp: [F::ZERO; AUX_LEN],
},
},
is_immediate: F::ZERO,
is_zero_aux: F::ZERO,
}
}
}
10 changes: 5 additions & 5 deletions crates/vm/src/system/memory/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -168,27 +168,27 @@ fn generate_trace<F: PrimeField32>(

match (record.data.len(), &record.prev_data) {
(1, &None) => {
row.read_1_aux = aux_factory.make_read_aux_cols(&record);
aux_factory.generate_read_aux(&record, &mut row.read_1_aux);
row.data_1 = record.data.try_into().unwrap();
row.is_read_1 = F::ONE;
}
(1, &Some(_)) => {
row.write_1_aux = aux_factory.make_write_aux_cols(&record);
aux_factory.generate_write_aux(&record, &mut row.write_1_aux);
row.data_1 = record.data.try_into().unwrap();
row.is_write_1 = F::ONE;
}
(4, &None) => {
row.read_4_aux = aux_factory.make_read_aux_cols(&record);
aux_factory.generate_read_aux(&record, &mut row.read_4_aux);
row.data_4 = record.data.try_into().unwrap();
row.is_read_4 = F::ONE;
}
(4, &Some(_)) => {
row.write_4_aux = aux_factory.make_write_aux_cols(&record);
aux_factory.generate_write_aux(&record, &mut row.write_4_aux);
row.data_4 = record.data.try_into().unwrap();
row.is_write_4 = F::ONE;
}
(MAX, &None) => {
row.read_max_aux = aux_factory.make_read_aux_cols(&record);
aux_factory.generate_read_aux(&record, &mut row.read_max_aux);
row.data_max = record.data.try_into().unwrap();
row.is_read_max = F::ONE;
}
Expand Down
32 changes: 16 additions & 16 deletions crates/vm/src/system/native_adapter/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -279,22 +279,22 @@ impl<F: PrimeField32, const R: usize, const W: usize> VmAdapterChip<F>

row_slice.from_state = write_record.from_state.map(F::from_canonical_u32);

row_slice.reads_aux = read_record.reads.map(|(id, _)| {
let record = memory.record_by_id(id);
let address = MemoryAddress::new(record.address_space, record.pointer);
NativeAdapterReadCols {
address,
read_aux: aux_cols_factory.make_read_or_immediate_aux_cols(record),
}
});
row_slice.writes_aux = write_record.writes.map(|(id, _)| {
let record = memory.record_by_id(id);
let address = MemoryAddress::new(record.address_space, record.pointer);
NativeAdapterWriteCols {
address,
write_aux: aux_cols_factory.make_write_aux_cols(record),
}
});
for (i, read) in read_record.reads.iter().enumerate() {
let (id, _) = read;
let record = memory.record_by_id(*id);
aux_cols_factory
.generate_read_or_immediate_aux(record, &mut row_slice.reads_aux[i].read_aux);
row_slice.reads_aux[i].address =
MemoryAddress::new(record.address_space, record.pointer);
}

for (i, write) in write_record.writes.iter().enumerate() {
let (id, _) = write;
let record = memory.record_by_id(*id);
aux_cols_factory.generate_write_aux(record, &mut row_slice.writes_aux[i].write_aux);
row_slice.writes_aux[i].address =
MemoryAddress::new(record.address_space, record.pointer);
}
}

fn air(&self) -> &Self::Air {
Expand Down
18 changes: 10 additions & 8 deletions extensions/keccak256/circuit/src/trace.rs
Original file line number Diff line number Diff line change
Expand Up @@ -219,15 +219,17 @@ where
);
}
for (i, id) in register_reads.into_iter().enumerate() {
// TODO[jpw] make_read_aux_cols should directly write into slice
first_row.mem_oc.register_aux[i] =
aux_cols_factory.make_read_aux_cols(memory.record_by_id(id));
aux_cols_factory.generate_read_aux(
memory.record_by_id(id),
&mut first_row.mem_oc.register_aux[i],
);
}
}
for (i, id) in block.reads.into_iter().enumerate() {
// TODO[jpw] make_read_aux_cols should directly write into slice
first_row.mem_oc.absorb_reads[i] =
aux_cols_factory.make_read_aux_cols(memory.record_by_id(id));
aux_cols_factory.generate_read_aux(
memory.record_by_id(id),
&mut first_row.mem_oc.absorb_reads[i],
);
}

let last_row: &mut KeccakVmCols<Val<SC>> =
Expand All @@ -239,8 +241,8 @@ where
for (i, record_id) in digest_writes.into_iter().enumerate() {
// TODO: these aux columns are only used for the last row - can we share them with aux reads in first row?
let record = memory.record_by_id(record_id);
last_row.mem_oc.digest_writes[i] =
aux_cols_factory.make_write_aux_cols(record);
aux_cols_factory
.generate_write_aux(record, &mut last_row.mem_oc.digest_writes[i]);
}
}
});
Expand Down
13 changes: 6 additions & 7 deletions extensions/native/circuit/src/adapters/branch_native_adapter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -187,14 +187,13 @@ impl<F: PrimeField32> VmAdapterChip<F> for BranchNativeAdapterChip<F> {
let aux_cols_factory = memory.aux_cols_factory();

row_slice.from_state = write_record.map(F::from_canonical_u32);
row_slice.reads_aux = read_record.reads.map(|x| {
for (i, x) in read_record.reads.iter().enumerate() {
let read = memory.record_by_id(x.0);
let address = MemoryAddress::new(read.address_space, read.pointer);
BranchNativeAdapterReadCols {
address,
read_aux: aux_cols_factory.make_read_or_immediate_aux_cols(read),
}
});

row_slice.reads_aux[i].address = MemoryAddress::new(read.address_space, read.pointer);
aux_cols_factory
.generate_read_or_immediate_aux(read, &mut row_slice.reads_aux[i].read_aux);
}
}

fn air(&self) -> &Self::Air {
Expand Down
4 changes: 2 additions & 2 deletions extensions/native/circuit/src/adapters/convert_adapter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -216,8 +216,8 @@ impl<F: PrimeField32, const READ_SIZE: usize, const WRITE_SIZE: usize> VmAdapter
row_slice.b_pointer = read.pointer;
row_slice.b_as = read.address_space;

row_slice.reads_aux = [aux_cols_factory.make_read_aux_cols(read)];
row_slice.writes_aux = [aux_cols_factory.make_write_aux_cols(write)];
aux_cols_factory.generate_read_aux(read, &mut row_slice.reads_aux[0]);
aux_cols_factory.generate_write_aux(write, &mut row_slice.writes_aux[0]);
}

fn air(&self) -> &Self::Air {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -171,7 +171,7 @@ impl<F: PrimeField32> VmAdapterChip<F> for JalNativeAdapterChip<F> {
row_slice.from_state = write_record.from_state.map(F::from_canonical_u32);
row_slice.a_pointer = write.pointer;
row_slice.a_as = write.address_space;
row_slice.writes_aux = aux_cols_factory.make_write_aux_cols(write);
aux_cols_factory.generate_write_aux(write, &mut row_slice.writes_aux);
}

fn air(&self) -> &Self::Air {
Expand Down
Loading