Skip to content

Commit

Permalink
chore: Use aux_cols_factory.generate_*_aux
Browse files Browse the repository at this point in the history
  • Loading branch information
zlangley committed Jan 23, 2025
1 parent 40dc754 commit c2e0557
Show file tree
Hide file tree
Showing 18 changed files with 173 additions and 148 deletions.
13 changes: 11 additions & 2 deletions crates/vm/src/system/memory/controller/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -41,8 +41,8 @@ use crate::{
merkle::{MemoryMerkleBus, MemoryMerkleChip},
offline::{MemoryRecord, OfflineMemory, INITIAL_TIMESTAMP},
offline_checker::{
MemoryBridge, MemoryBus, MemoryReadAuxCols, MemoryReadOrImmediateAuxCols,
MemoryWriteAuxCols, AUX_LEN,
MemoryBaseAuxCols, MemoryBridge, MemoryBus, MemoryReadAuxCols,
MemoryReadOrImmediateAuxCols, MemoryWriteAuxCols, AUX_LEN,
},
online::{Memory, MemoryLogEntry},
persistent::PersistentBoundaryChip,
Expand Down Expand Up @@ -759,6 +759,15 @@ impl<F: PrimeField32> MemoryAuxColsFactory<F> {
);
}

pub fn generate_base_aux(&self, record: &MemoryRecord<F>, buffer: &mut MemoryBaseAuxCols<F>) {
buffer.prev_timestamp = F::from_canonical_u32(record.prev_timestamp);
self.generate_timestamp_lt(
record.prev_timestamp,
record.timestamp,
&mut buffer.clk_lt_aux,
);
}

fn generate_timestamp_lt(
&self,
prev_timestamp: u32,
Expand Down
6 changes: 3 additions & 3 deletions crates/vm/src/system/memory/offline_checker/columns.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,9 @@ use crate::system::memory::offline_checker::bridge::AUX_LEN;
#[derive(Clone, Copy, Debug, AlignedBorrow)]
pub struct MemoryBaseAuxCols<T> {
/// The previous timestamps in which the cells were accessed.
pub(crate) prev_timestamp: T,
pub(in crate::system::memory) prev_timestamp: T,
/// The auxiliary columns to perform the less than check.
pub(crate) clk_lt_aux: LessThanAuxCols<T, AUX_LEN>,
pub(in crate::system::memory) clk_lt_aux: LessThanAuxCols<T, AUX_LEN>,
}

#[repr(C)]
Expand Down Expand Up @@ -69,7 +69,7 @@ impl<const N: usize, F: FieldAlgebra> MemoryWriteAuxCols<F, N> {
#[repr(C)]
#[derive(Clone, Copy, Debug, AlignedBorrow)]
pub struct MemoryReadAuxCols<T> {
pub(crate) base: MemoryBaseAuxCols<T>,
pub(in crate::system::memory) base: MemoryBaseAuxCols<T>,
}

impl<F: PrimeField32> MemoryReadAuxCols<F> {
Expand Down
10 changes: 5 additions & 5 deletions crates/vm/src/system/memory/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -168,27 +168,27 @@ fn generate_trace<F: PrimeField32>(

match (record.data.len(), &record.prev_data) {
(1, &None) => {
row.read_1_aux = aux_factory.make_read_aux_cols(&record);
aux_factory.generate_read_aux(&record, &mut row.read_1_aux);
row.data_1 = record.data.try_into().unwrap();
row.is_read_1 = F::ONE;
}
(1, &Some(_)) => {
row.write_1_aux = aux_factory.make_write_aux_cols(&record);
aux_factory.generate_write_aux(&record, &mut row.write_1_aux);
row.data_1 = record.data.try_into().unwrap();
row.is_write_1 = F::ONE;
}
(4, &None) => {
row.read_4_aux = aux_factory.make_read_aux_cols(&record);
aux_factory.generate_read_aux(&record, &mut row.read_4_aux);
row.data_4 = record.data.try_into().unwrap();
row.is_read_4 = F::ONE;
}
(4, &Some(_)) => {
row.write_4_aux = aux_factory.make_write_aux_cols(&record);
aux_factory.generate_write_aux(&record, &mut row.write_4_aux);
row.data_4 = record.data.try_into().unwrap();
row.is_write_4 = F::ONE;
}
(MAX, &None) => {
row.read_max_aux = aux_factory.make_read_aux_cols(&record);
aux_factory.generate_read_aux(&record, &mut row.read_max_aux);
row.data_max = record.data.try_into().unwrap();
row.is_read_max = F::ONE;
}
Expand Down
32 changes: 16 additions & 16 deletions crates/vm/src/system/native_adapter/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -279,22 +279,22 @@ impl<F: PrimeField32, const R: usize, const W: usize> VmAdapterChip<F>

row_slice.from_state = write_record.from_state.map(F::from_canonical_u32);

row_slice.reads_aux = read_record.reads.map(|(id, _)| {
let record = memory.record_by_id(id);
let address = MemoryAddress::new(record.address_space, record.pointer);
NativeAdapterReadCols {
address,
read_aux: aux_cols_factory.make_read_or_immediate_aux_cols(record),
}
});
row_slice.writes_aux = write_record.writes.map(|(id, _)| {
let record = memory.record_by_id(id);
let address = MemoryAddress::new(record.address_space, record.pointer);
NativeAdapterWriteCols {
address,
write_aux: aux_cols_factory.make_write_aux_cols(record),
}
});
for (i, read) in read_record.reads.iter().enumerate() {
let (id, _) = read;
let record = memory.record_by_id(*id);
aux_cols_factory
.generate_read_or_immediate_aux(record, &mut row_slice.reads_aux[i].read_aux);
row_slice.reads_aux[i].address =
MemoryAddress::new(record.address_space, record.pointer);
}

for (i, write) in write_record.writes.iter().enumerate() {
let (id, _) = write;
let record = memory.record_by_id(*id);
aux_cols_factory.generate_write_aux(record, &mut row_slice.writes_aux[i].write_aux);
row_slice.writes_aux[i].address =
MemoryAddress::new(record.address_space, record.pointer);
}
}

fn air(&self) -> &Self::Air {
Expand Down
14 changes: 8 additions & 6 deletions extensions/keccak256/circuit/src/trace.rs
Original file line number Diff line number Diff line change
Expand Up @@ -219,15 +219,17 @@ where
);
}
for (i, id) in register_reads.into_iter().enumerate() {
// TODO[jpw] make_read_aux_cols should directly write into slice
first_row.mem_oc.register_aux[i] =
aux_cols_factory.make_read_aux_cols(memory.record_by_id(id));
aux_cols_factory.generate_read_aux(
memory.record_by_id(id),
&mut first_row.mem_oc.register_aux[i],
);
}
}
for (i, id) in block.reads.into_iter().enumerate() {
// TODO[jpw] make_read_aux_cols should directly write into slice
first_row.mem_oc.absorb_reads[i] =
aux_cols_factory.make_read_aux_cols(memory.record_by_id(id));
aux_cols_factory.generate_read_aux(
memory.record_by_id(id),
&mut first_row.mem_oc.absorb_reads[i],
);
}

let last_row: &mut KeccakVmCols<Val<SC>> =
Expand Down
4 changes: 2 additions & 2 deletions extensions/native/circuit/src/adapters/convert_adapter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -216,8 +216,8 @@ impl<F: PrimeField32, const READ_SIZE: usize, const WRITE_SIZE: usize> VmAdapter
row_slice.b_pointer = read.pointer;
row_slice.b_as = read.address_space;

row_slice.reads_aux = [aux_cols_factory.make_read_aux_cols(read)];
row_slice.writes_aux = [aux_cols_factory.make_write_aux_cols(write)];
aux_cols_factory.generate_read_aux(read, &mut row_slice.reads_aux[0]);
aux_cols_factory.generate_write_aux(write, &mut row_slice.writes_aux[0]);
}

fn air(&self) -> &Self::Air {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -311,18 +311,18 @@ impl<F: PrimeField32, const NUM_CELLS: usize> VmAdapterChip<F>

let data_read = read_record.data_read.map(|read| memory.record_by_id(read));
if let Some(data_read) = data_read {
cols.data_read_aux_cols = aux_cols_factory.make_read_aux_cols(data_read);
} else {
cols.data_read_aux_cols = MemoryReadAuxCols::disabled();
aux_cols_factory.generate_read_aux(data_read, &mut cols.data_read_aux_cols);
}

let write = memory.record_by_id(write_record.write_id);
cols.data_write_as = write.address_space;
cols.data_write_pointer = write.pointer;

cols.pointer_read_aux_cols =
aux_cols_factory.make_read_aux_cols(memory.record_by_id(read_record.pointer_read));
cols.data_write_aux_cols = aux_cols_factory.make_write_aux_cols(write);
aux_cols_factory.generate_read_aux(
memory.record_by_id(read_record.pointer_read),
&mut cols.pointer_read_aux_cols,
);
aux_cols_factory.generate_write_aux(write, &mut cols.data_write_aux_cols);
}

fn air(&self) -> &Self::Air {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -223,12 +223,9 @@ impl<F: PrimeField32, const N: usize> VmAdapterChip<F> for NativeVectorizedAdapt
row_slice.b_pointer = b_record.pointer;
row_slice.c_pointer = c_record.pointer;
row_slice.c_as = c_record.address_space;

row_slice.reads_aux = [
aux_cols_factory.make_read_aux_cols(b_record),
aux_cols_factory.make_read_aux_cols(c_record),
];
row_slice.writes_aux = [aux_cols_factory.make_write_aux_cols(a_record)];
aux_cols_factory.generate_read_aux(b_record, &mut row_slice.reads_aux[0]);
aux_cols_factory.generate_read_aux(c_record, &mut row_slice.reads_aux[1]);
aux_cols_factory.generate_write_aux(a_record, &mut row_slice.writes_aux[0]);
}

fn air(&self) -> &Self::Air {
Expand Down
94 changes: 59 additions & 35 deletions extensions/native/circuit/src/poseidon2/trace.rs
Original file line number Diff line number Diff line change
Expand Up @@ -93,8 +93,9 @@ impl<F: PrimeField32, const SBOX_REGISTERS: usize> NativePoseidon2Chip<F, SBOX_R

specific.end_timestamp =
F::from_canonical_usize(read_root_is_on_right.timestamp as usize + (2 + CHUNK));
specific.reads =
reads.map(|read| aux_cols_factory.make_read_aux_cols(memory.record_by_id(read)));
for (i, read) in reads.iter().enumerate() {
aux_cols_factory.generate_read_aux(memory.record_by_id(*read), &mut specific.reads[i]);
}
cols.initial_opened_index = F::from_canonical_usize(opened_index);
specific.final_opened_index = F::from_canonical_usize(opened_index - 1);
specific.height = F::from_canonical_usize(height);
Expand All @@ -105,10 +106,14 @@ impl<F: PrimeField32, const SBOX_REGISTERS: usize> NativePoseidon2Chip<F, SBOX_R
specific.index_base_pointer = parent.index_base_pointer;

specific.proof_index = F::from_canonical_usize(proof_index);
specific.read_initial_height_or_root_is_on_right =
aux_cols_factory.make_read_aux_cols(read_root_is_on_right);
specific.read_final_height_or_sibling_array_start =
aux_cols_factory.make_read_aux_cols(read_sibling_array_start);
aux_cols_factory.generate_read_aux(
read_root_is_on_right,
&mut specific.read_initial_height_or_root_is_on_right,
);
aux_cols_factory.generate_read_aux(
read_sibling_array_start,
&mut specific.read_final_height_or_sibling_array_start,
);
specific.root_is_on_right = F::from_bool(root_is_on_right);
specific.sibling_array_start = read_sibling_array_start.data[0];
}
Expand Down Expand Up @@ -146,20 +151,32 @@ impl<F: PrimeField32, const SBOX_REGISTERS: usize> NativePoseidon2Chip<F, SBOX_R
specific.index_register = instruction.e;
specific.commit_register = instruction.f;
specific.commit_pointer = commit_pointer;
specific.dim_base_pointer_read =
aux_cols_factory.make_read_aux_cols(memory.record_by_id(dim_base_pointer_read));
specific.opened_base_pointer_read =
aux_cols_factory.make_read_aux_cols(memory.record_by_id(opened_base_pointer_read));
specific.opened_length_read =
aux_cols_factory.make_read_aux_cols(memory.record_by_id(opened_length_read));
specific.sibling_base_pointer_read =
aux_cols_factory.make_read_aux_cols(memory.record_by_id(sibling_base_pointer_read));
specific.index_base_pointer_read =
aux_cols_factory.make_read_aux_cols(memory.record_by_id(index_base_pointer_read));
specific.commit_pointer_read =
aux_cols_factory.make_read_aux_cols(memory.record_by_id(commit_pointer_read));
specific.commit_read =
aux_cols_factory.make_read_aux_cols(memory.record_by_id(commit_read));
aux_cols_factory.generate_read_aux(
memory.record_by_id(dim_base_pointer_read),
&mut specific.dim_base_pointer_read,
);
aux_cols_factory.generate_read_aux(
memory.record_by_id(opened_base_pointer_read),
&mut specific.opened_base_pointer_read,
);
aux_cols_factory.generate_read_aux(
memory.record_by_id(opened_length_read),
&mut specific.opened_length_read,
);
aux_cols_factory.generate_read_aux(
memory.record_by_id(sibling_base_pointer_read),
&mut specific.sibling_base_pointer_read,
);
aux_cols_factory.generate_read_aux(
memory.record_by_id(index_base_pointer_read),
&mut specific.index_base_pointer_read,
);
aux_cols_factory.generate_read_aux(
memory.record_by_id(commit_pointer_read),
&mut specific.commit_pointer_read,
);
aux_cols_factory
.generate_read_aux(memory.record_by_id(commit_read), &mut specific.commit_read);
}
#[allow(clippy::too_many_arguments)]
fn incorporate_row_record_to_row(
Expand Down Expand Up @@ -220,10 +237,14 @@ impl<F: PrimeField32, const SBOX_REGISTERS: usize> NativePoseidon2Chip<F, SBOX_R
specific.index_base_pointer = parent.index_base_pointer;

specific.proof_index = F::from_canonical_usize(proof_index);
specific.read_initial_height_or_root_is_on_right =
aux_cols_factory.make_read_aux_cols(initial_height_read);
specific.read_final_height_or_sibling_array_start =
aux_cols_factory.make_read_aux_cols(final_height_read);
aux_cols_factory.generate_read_aux(
initial_height_read,
&mut specific.read_initial_height_or_root_is_on_right,
);
aux_cols_factory.generate_read_aux(
final_height_read,
&mut specific.read_final_height_or_sibling_array_start,
);
}
#[allow(clippy::too_many_arguments)]
fn inside_row_record_to_row(
Expand Down Expand Up @@ -269,11 +290,13 @@ impl<F: PrimeField32, const SBOX_REGISTERS: usize> NativePoseidon2Chip<F, SBOX_R
row_pointer,
row_end,
} = record;
cell.read = aux_cols_factory.make_read_aux_cols(memory.record_by_id(read));
aux_cols_factory.generate_read_aux(memory.record_by_id(read), &mut cell.read);
cell.opened_index = F::from_canonical_usize(opened_index);
if let Some(read_row_pointer_and_length) = read_row_pointer_and_length {
cell.read_row_pointer_and_length = aux_cols_factory
.make_read_aux_cols(memory.record_by_id(read_row_pointer_and_length));
aux_cols_factory.generate_read_aux(
memory.record_by_id(read_row_pointer_and_length),
&mut cell.read_row_pointer_and_length,
);
}
cell.row_pointer = F::from_canonical_usize(row_pointer);
cell.row_end = F::from_canonical_usize(row_end);
Expand Down Expand Up @@ -419,19 +442,20 @@ impl<F: PrimeField32, const SBOX_REGISTERS: usize> NativePoseidon2Chip<F, SBOX_R
specific.output_pointer = output_pointer;
specific.input_pointer_1 = input_pointer_1;
specific.input_pointer_2 = input_pointer_2;
specific.read_output_pointer = aux_cols_factory.make_read_aux_cols(read_output_pointer);
specific.read_input_pointer_1 = aux_cols_factory.make_read_aux_cols(read_input_pointer_1);
specific.read_data_1 = aux_cols_factory.make_read_aux_cols(read_data_1);
specific.read_data_2 = aux_cols_factory.make_read_aux_cols(read_data_2);
specific.write_data_1 = aux_cols_factory.make_write_aux_cols(write_data_1);
aux_cols_factory.generate_read_aux(read_output_pointer, &mut specific.read_output_pointer);
aux_cols_factory
.generate_read_aux(read_input_pointer_1, &mut specific.read_input_pointer_1);
aux_cols_factory.generate_read_aux(read_data_1, &mut specific.read_data_1);
aux_cols_factory.generate_read_aux(read_data_2, &mut specific.read_data_2);
aux_cols_factory.generate_write_aux(write_data_1, &mut specific.write_data_1);

if opcode == COMP_POS2.global_opcode() {
let read_input_pointer_2 = memory.record_by_id(read_input_pointer_2.unwrap());
specific.read_input_pointer_2 =
aux_cols_factory.make_read_aux_cols(read_input_pointer_2);
aux_cols_factory
.generate_read_aux(read_input_pointer_2, &mut specific.read_input_pointer_2);
} else {
let write_data_2 = memory.record_by_id(write_data_2.unwrap());
specific.write_data_2 = aux_cols_factory.make_write_aux_cols(write_data_2);
aux_cols_factory.generate_write_aux(write_data_2, &mut specific.write_data_2);
}
}

Expand Down
19 changes: 11 additions & 8 deletions extensions/rv32-adapters/src/eq_mod.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use std::{
array::{self, from_fn},
array::from_fn,
borrow::{Borrow, BorrowMut},
marker::PhantomData,
};
Expand Down Expand Up @@ -397,16 +397,19 @@ impl<
row_slice.from_state = write_record.from_state.map(F::from_canonical_u32);

let rs = read_record.rs.map(|r| memory.record_by_id(r));
row_slice.rs_ptr = array::from_fn(|i| rs[i].pointer);
row_slice.rs_val = array::from_fn(|i| rs[i].data.clone().try_into().unwrap());
row_slice.rs_read_aux = array::from_fn(|i| aux_cols_factory.make_read_aux_cols(rs[i]));
row_slice.heap_read_aux = read_record
.reads
.map(|r| r.map(|x| aux_cols_factory.make_read_aux_cols(memory.record_by_id(x))));
for (i, r) in rs.iter().enumerate() {
row_slice.rs_ptr[i] = r.pointer;
row_slice.rs_val[i] = r.data.clone().try_into().unwrap();
aux_cols_factory.generate_read_aux(r, &mut row_slice.rs_read_aux[i]);
for (j, x) in read_record.reads[i].iter().enumerate() {
let read = memory.record_by_id(*x);
aux_cols_factory.generate_read_aux(read, &mut row_slice.heap_read_aux[i][j]);
}
}

let rd = memory.record_by_id(write_record.rd_id);
row_slice.rd_ptr = rd.pointer;
row_slice.writes_aux = aux_cols_factory.make_write_aux_cols(rd);
aux_cols_factory.generate_write_aux(rd, &mut row_slice.writes_aux);

// Range checks
let need_range_check: [u32; 2] = from_fn(|i| {
Expand Down
Loading

0 comments on commit c2e0557

Please sign in to comment.