From 6b0489e0b730df3767639cc1fd0ae2716bff295c Mon Sep 17 00:00:00 2001 From: Zach Langley Date: Tue, 7 Jan 2025 16:27:40 -0500 Subject: [PATCH] chore: Use aux_cols_factory.generate_*_aux --- crates/vm/src/system/memory/controller/mod.rs | 56 ++++------- .../system/memory/offline_checker/columns.rs | 6 +- crates/vm/src/system/memory/tests.rs | 10 +- crates/vm/src/system/native_adapter/mod.rs | 32 +++---- extensions/keccak256/circuit/src/trace.rs | 18 ++-- .../src/adapters/branch_native_adapter.rs | 13 ++- .../circuit/src/adapters/convert_adapter.rs | 4 +- .../src/adapters/jal_native_adapter.rs | 2 +- .../src/adapters/loadstore_native_adapter.rs | 12 +-- .../src/adapters/native_vectorized_adapter.rs | 9 +- .../native/circuit/src/poseidon2/trace.rs | 94 ++++++++++++------- extensions/rv32-adapters/src/eq_mod.rs | 19 ++-- extensions/rv32-adapters/src/heap_branch.rs | 21 +++-- extensions/rv32-adapters/src/vec_heap.rs | 36 ++++--- .../rv32-adapters/src/vec_heap_two_reads.rs | 39 ++++---- extensions/rv32im/circuit/src/adapters/alu.rs | 14 +-- .../rv32im/circuit/src/adapters/branch.rs | 6 +- .../rv32im/circuit/src/adapters/hintstore.rs | 4 +- .../rv32im/circuit/src/adapters/jalr.rs | 19 ++-- .../rv32im/circuit/src/adapters/loadstore.rs | 10 +- extensions/rv32im/circuit/src/adapters/mul.rs | 8 +- .../rv32im/circuit/src/adapters/rdwrite.rs | 6 +- .../sha256/circuit/src/sha256_chip/trace.rs | 19 ++-- 23 files changed, 231 insertions(+), 226 deletions(-) diff --git a/crates/vm/src/system/memory/controller/mod.rs b/crates/vm/src/system/memory/controller/mod.rs index eb33f15e6..9c6e8c991 100644 --- a/crates/vm/src/system/memory/controller/mod.rs +++ b/crates/vm/src/system/memory/controller/mod.rs @@ -41,8 +41,8 @@ use crate::{ merkle::{MemoryMerkleBus, MemoryMerkleChip}, offline::{MemoryRecord, OfflineMemory, INITIAL_TIMESTAMP}, offline_checker::{ - MemoryBridge, MemoryBus, MemoryReadAuxCols, MemoryReadOrImmediateAuxCols, - MemoryWriteAuxCols, AUX_LEN, + MemoryBaseAuxCols, MemoryBridge, MemoryBus, MemoryReadAuxCols, + MemoryReadOrImmediateAuxCols, MemoryWriteAuxCols, AUX_LEN, }, online::{Memory, MemoryLogEntry}, persistent::PersistentBoundaryChip, @@ -719,12 +719,7 @@ impl MemoryAuxColsFactory { !read.address_space.is_zero(), "cannot make `MemoryReadAuxCols` for address space 0" ); - buffer.base.prev_timestamp = F::from_canonical_u32(read.prev_timestamp); - self.generate_timestamp_lt( - read.prev_timestamp, - read.timestamp, - &mut buffer.base.clk_lt_aux, - ); + self.generate_base_aux(read, &mut buffer.base); } pub fn generate_read_or_immediate_aux( @@ -736,12 +731,7 @@ impl MemoryAuxColsFactory { read.address_space, (&mut buffer.is_zero_aux, &mut buffer.is_immediate), ); - buffer.base.prev_timestamp = F::from_canonical_u32(read.prev_timestamp); - self.generate_timestamp_lt( - read.prev_timestamp, - read.timestamp, - &mut buffer.base.clk_lt_aux, - ); + self.generate_base_aux(read, &mut buffer.base); } pub fn generate_write_aux( @@ -749,13 +739,18 @@ impl MemoryAuxColsFactory { write: &MemoryRecord, buffer: &mut MemoryWriteAuxCols, ) { - let prev_data = write.prev_data.clone().unwrap(); - buffer.prev_data = prev_data.try_into().unwrap(); - buffer.base.prev_timestamp = F::from_canonical_u32(write.prev_timestamp); + buffer + .prev_data + .copy_from_slice(&write.prev_data.as_ref().unwrap()); + self.generate_base_aux(write, &mut buffer.base); + } + + pub fn generate_base_aux(&self, record: &MemoryRecord, buffer: &mut MemoryBaseAuxCols) { + buffer.prev_timestamp = F::from_canonical_u32(record.prev_timestamp); self.generate_timestamp_lt( - write.prev_timestamp, - write.timestamp, - &mut buffer.base.clk_lt_aux, + record.prev_timestamp, + record.timestamp, + &mut buffer.clk_lt_aux, ); } @@ -772,7 +767,7 @@ impl MemoryAuxColsFactory { ); } - // TODO[jpw]: delete these functions and use the ones that fill buffers above. + /// In general, prefer `generate_read_aux` which writes in-place rather than this function. pub fn make_read_aux_cols(&self, read: &MemoryRecord) -> MemoryReadAuxCols { assert!( !read.address_space.is_zero(), @@ -784,24 +779,7 @@ impl MemoryAuxColsFactory { ) } - pub fn make_read_or_immediate_aux_cols( - &self, - read: &MemoryRecord, - ) -> MemoryReadOrImmediateAuxCols { - let mut inv = F::ZERO; - let mut is_zero = F::ZERO; - IsZeroSubAir.generate_subrow(read.address_space, (&mut inv, &mut is_zero)); - let timestamp_lt_cols = - self.generate_timestamp_lt_cols(read.prev_timestamp, read.timestamp); - - MemoryReadOrImmediateAuxCols::new( - F::from_canonical_u32(read.prev_timestamp), - is_zero, - inv, - timestamp_lt_cols, - ) - } - + /// In general, prefer `generate_write_aux` which writes in-place rather than this function. pub fn make_write_aux_cols( &self, write: &MemoryRecord, diff --git a/crates/vm/src/system/memory/offline_checker/columns.rs b/crates/vm/src/system/memory/offline_checker/columns.rs index 7f1397111..06495e48e 100644 --- a/crates/vm/src/system/memory/offline_checker/columns.rs +++ b/crates/vm/src/system/memory/offline_checker/columns.rs @@ -14,9 +14,9 @@ use crate::system::memory::offline_checker::bridge::AUX_LEN; #[derive(Clone, Copy, Debug, AlignedBorrow)] pub struct MemoryBaseAuxCols { /// The previous timestamps in which the cells were accessed. - pub(crate) prev_timestamp: T, + pub(in crate::system::memory) prev_timestamp: T, /// The auxiliary columns to perform the less than check. - pub(crate) clk_lt_aux: LessThanAuxCols, + pub(in crate::system::memory) clk_lt_aux: LessThanAuxCols, } #[repr(C)] @@ -69,7 +69,7 @@ impl MemoryWriteAuxCols { #[repr(C)] #[derive(Clone, Copy, Debug, AlignedBorrow)] pub struct MemoryReadAuxCols { - pub(crate) base: MemoryBaseAuxCols, + pub(in crate::system::memory) base: MemoryBaseAuxCols, } impl MemoryReadAuxCols { diff --git a/crates/vm/src/system/memory/tests.rs b/crates/vm/src/system/memory/tests.rs index 094b2204c..51811c889 100644 --- a/crates/vm/src/system/memory/tests.rs +++ b/crates/vm/src/system/memory/tests.rs @@ -168,27 +168,27 @@ fn generate_trace( match (record.data.len(), &record.prev_data) { (1, &None) => { - row.read_1_aux = aux_factory.make_read_aux_cols(&record); + aux_factory.generate_read_aux(&record, &mut row.read_1_aux); row.data_1 = record.data.try_into().unwrap(); row.is_read_1 = F::ONE; } (1, &Some(_)) => { - row.write_1_aux = aux_factory.make_write_aux_cols(&record); + aux_factory.generate_write_aux(&record, &mut row.write_1_aux); row.data_1 = record.data.try_into().unwrap(); row.is_write_1 = F::ONE; } (4, &None) => { - row.read_4_aux = aux_factory.make_read_aux_cols(&record); + aux_factory.generate_read_aux(&record, &mut row.read_4_aux); row.data_4 = record.data.try_into().unwrap(); row.is_read_4 = F::ONE; } (4, &Some(_)) => { - row.write_4_aux = aux_factory.make_write_aux_cols(&record); + aux_factory.generate_write_aux(&record, &mut row.write_4_aux); row.data_4 = record.data.try_into().unwrap(); row.is_write_4 = F::ONE; } (MAX, &None) => { - row.read_max_aux = aux_factory.make_read_aux_cols(&record); + aux_factory.generate_read_aux(&record, &mut row.read_max_aux); row.data_max = record.data.try_into().unwrap(); row.is_read_max = F::ONE; } diff --git a/crates/vm/src/system/native_adapter/mod.rs b/crates/vm/src/system/native_adapter/mod.rs index 47679cfab..0d393b4b1 100644 --- a/crates/vm/src/system/native_adapter/mod.rs +++ b/crates/vm/src/system/native_adapter/mod.rs @@ -279,22 +279,22 @@ impl VmAdapterChip row_slice.from_state = write_record.from_state.map(F::from_canonical_u32); - row_slice.reads_aux = read_record.reads.map(|(id, _)| { - let record = memory.record_by_id(id); - let address = MemoryAddress::new(record.address_space, record.pointer); - NativeAdapterReadCols { - address, - read_aux: aux_cols_factory.make_read_or_immediate_aux_cols(record), - } - }); - row_slice.writes_aux = write_record.writes.map(|(id, _)| { - let record = memory.record_by_id(id); - let address = MemoryAddress::new(record.address_space, record.pointer); - NativeAdapterWriteCols { - address, - write_aux: aux_cols_factory.make_write_aux_cols(record), - } - }); + for (i, read) in read_record.reads.iter().enumerate() { + let (id, _) = read; + let record = memory.record_by_id(*id); + aux_cols_factory + .generate_read_or_immediate_aux(record, &mut row_slice.reads_aux[i].read_aux); + row_slice.reads_aux[i].address = + MemoryAddress::new(record.address_space, record.pointer); + } + + for (i, write) in write_record.writes.iter().enumerate() { + let (id, _) = write; + let record = memory.record_by_id(*id); + aux_cols_factory.generate_write_aux(record, &mut row_slice.writes_aux[i].write_aux); + row_slice.writes_aux[i].address = + MemoryAddress::new(record.address_space, record.pointer); + } } fn air(&self) -> &Self::Air { diff --git a/extensions/keccak256/circuit/src/trace.rs b/extensions/keccak256/circuit/src/trace.rs index f21e46b5f..c2302d0fe 100644 --- a/extensions/keccak256/circuit/src/trace.rs +++ b/extensions/keccak256/circuit/src/trace.rs @@ -219,15 +219,17 @@ where ); } for (i, id) in register_reads.into_iter().enumerate() { - // TODO[jpw] make_read_aux_cols should directly write into slice - first_row.mem_oc.register_aux[i] = - aux_cols_factory.make_read_aux_cols(memory.record_by_id(id)); + aux_cols_factory.generate_read_aux( + memory.record_by_id(id), + &mut first_row.mem_oc.register_aux[i], + ); } } for (i, id) in block.reads.into_iter().enumerate() { - // TODO[jpw] make_read_aux_cols should directly write into slice - first_row.mem_oc.absorb_reads[i] = - aux_cols_factory.make_read_aux_cols(memory.record_by_id(id)); + aux_cols_factory.generate_read_aux( + memory.record_by_id(id), + &mut first_row.mem_oc.absorb_reads[i], + ); } let last_row: &mut KeccakVmCols> = @@ -239,8 +241,8 @@ where for (i, record_id) in digest_writes.into_iter().enumerate() { // TODO: these aux columns are only used for the last row - can we share them with aux reads in first row? let record = memory.record_by_id(record_id); - last_row.mem_oc.digest_writes[i] = - aux_cols_factory.make_write_aux_cols(record); + aux_cols_factory + .generate_write_aux(record, &mut last_row.mem_oc.digest_writes[i]); } } }); diff --git a/extensions/native/circuit/src/adapters/branch_native_adapter.rs b/extensions/native/circuit/src/adapters/branch_native_adapter.rs index 1bfcbb173..7a6a0e672 100644 --- a/extensions/native/circuit/src/adapters/branch_native_adapter.rs +++ b/extensions/native/circuit/src/adapters/branch_native_adapter.rs @@ -187,14 +187,13 @@ impl VmAdapterChip for BranchNativeAdapterChip { let aux_cols_factory = memory.aux_cols_factory(); row_slice.from_state = write_record.map(F::from_canonical_u32); - row_slice.reads_aux = read_record.reads.map(|x| { + for (i, x) in read_record.reads.iter().enumerate() { let read = memory.record_by_id(x.0); - let address = MemoryAddress::new(read.address_space, read.pointer); - BranchNativeAdapterReadCols { - address, - read_aux: aux_cols_factory.make_read_or_immediate_aux_cols(read), - } - }); + + row_slice.reads_aux[i].address = MemoryAddress::new(read.address_space, read.pointer); + aux_cols_factory + .generate_read_or_immediate_aux(read, &mut row_slice.reads_aux[i].read_aux); + } } fn air(&self) -> &Self::Air { diff --git a/extensions/native/circuit/src/adapters/convert_adapter.rs b/extensions/native/circuit/src/adapters/convert_adapter.rs index 4757ce257..62200f566 100644 --- a/extensions/native/circuit/src/adapters/convert_adapter.rs +++ b/extensions/native/circuit/src/adapters/convert_adapter.rs @@ -216,8 +216,8 @@ impl VmAdapter row_slice.b_pointer = read.pointer; row_slice.b_as = read.address_space; - row_slice.reads_aux = [aux_cols_factory.make_read_aux_cols(read)]; - row_slice.writes_aux = [aux_cols_factory.make_write_aux_cols(write)]; + aux_cols_factory.generate_read_aux(read, &mut row_slice.reads_aux[0]); + aux_cols_factory.generate_write_aux(write, &mut row_slice.writes_aux[0]); } fn air(&self) -> &Self::Air { diff --git a/extensions/native/circuit/src/adapters/jal_native_adapter.rs b/extensions/native/circuit/src/adapters/jal_native_adapter.rs index 3f445b74d..b6f3588b7 100644 --- a/extensions/native/circuit/src/adapters/jal_native_adapter.rs +++ b/extensions/native/circuit/src/adapters/jal_native_adapter.rs @@ -171,7 +171,7 @@ impl VmAdapterChip for JalNativeAdapterChip { row_slice.from_state = write_record.from_state.map(F::from_canonical_u32); row_slice.a_pointer = write.pointer; row_slice.a_as = write.address_space; - row_slice.writes_aux = aux_cols_factory.make_write_aux_cols(write); + aux_cols_factory.generate_write_aux(write, &mut row_slice.writes_aux); } fn air(&self) -> &Self::Air { diff --git a/extensions/native/circuit/src/adapters/loadstore_native_adapter.rs b/extensions/native/circuit/src/adapters/loadstore_native_adapter.rs index 64cac42ce..af1ca15fd 100644 --- a/extensions/native/circuit/src/adapters/loadstore_native_adapter.rs +++ b/extensions/native/circuit/src/adapters/loadstore_native_adapter.rs @@ -311,18 +311,18 @@ impl VmAdapterChip let data_read = read_record.data_read.map(|read| memory.record_by_id(read)); if let Some(data_read) = data_read { - cols.data_read_aux_cols = aux_cols_factory.make_read_aux_cols(data_read); - } else { - cols.data_read_aux_cols = MemoryReadAuxCols::disabled(); + aux_cols_factory.generate_read_aux(data_read, &mut cols.data_read_aux_cols); } let write = memory.record_by_id(write_record.write_id); cols.data_write_as = write.address_space; cols.data_write_pointer = write.pointer; - cols.pointer_read_aux_cols = - aux_cols_factory.make_read_aux_cols(memory.record_by_id(read_record.pointer_read)); - cols.data_write_aux_cols = aux_cols_factory.make_write_aux_cols(write); + aux_cols_factory.generate_read_aux( + memory.record_by_id(read_record.pointer_read), + &mut cols.pointer_read_aux_cols, + ); + aux_cols_factory.generate_write_aux(write, &mut cols.data_write_aux_cols); } fn air(&self) -> &Self::Air { diff --git a/extensions/native/circuit/src/adapters/native_vectorized_adapter.rs b/extensions/native/circuit/src/adapters/native_vectorized_adapter.rs index 1741270e3..d2bff7837 100644 --- a/extensions/native/circuit/src/adapters/native_vectorized_adapter.rs +++ b/extensions/native/circuit/src/adapters/native_vectorized_adapter.rs @@ -223,12 +223,9 @@ impl VmAdapterChip for NativeVectorizedAdapt row_slice.b_pointer = b_record.pointer; row_slice.c_pointer = c_record.pointer; row_slice.c_as = c_record.address_space; - - row_slice.reads_aux = [ - aux_cols_factory.make_read_aux_cols(b_record), - aux_cols_factory.make_read_aux_cols(c_record), - ]; - row_slice.writes_aux = [aux_cols_factory.make_write_aux_cols(a_record)]; + aux_cols_factory.generate_read_aux(b_record, &mut row_slice.reads_aux[0]); + aux_cols_factory.generate_read_aux(c_record, &mut row_slice.reads_aux[1]); + aux_cols_factory.generate_write_aux(a_record, &mut row_slice.writes_aux[0]); } fn air(&self) -> &Self::Air { diff --git a/extensions/native/circuit/src/poseidon2/trace.rs b/extensions/native/circuit/src/poseidon2/trace.rs index e51a2d784..b294cfcac 100644 --- a/extensions/native/circuit/src/poseidon2/trace.rs +++ b/extensions/native/circuit/src/poseidon2/trace.rs @@ -93,8 +93,9 @@ impl NativePoseidon2Chip NativePoseidon2Chip NativePoseidon2Chip NativePoseidon2Chip NativePoseidon2Chip NativePoseidon2Chip VmAdapterC let row_slice: &mut Rv32HeapBranchAdapterCols<_, NUM_READS, READ_SIZE> = row_slice.borrow_mut(); row_slice.from_state = write_record.map(F::from_canonical_u32); + let rs_reads = read_record.rs_reads.map(|r| memory.record_by_id(r)); - row_slice.rs_ptr = array::from_fn(|i| rs_reads[i].pointer); - row_slice.rs_val = array::from_fn(|i| rs_reads[i].data.clone().try_into().unwrap()); - row_slice.rs_read_aux = - array::from_fn(|i| aux_cols_factory.make_read_aux_cols(rs_reads[i])); - row_slice.heap_read_aux = read_record - .heap_reads - .map(|r| aux_cols_factory.make_read_aux_cols(memory.record_by_id(r))); + + for (i, rs_read) in rs_reads.iter().enumerate() { + row_slice.rs_ptr[i] = rs_read.pointer; + row_slice.rs_val[i].copy_from_slice(&rs_read.data); + aux_cols_factory.generate_read_aux(rs_read, &mut row_slice.rs_read_aux[i]); + } + + for (i, heap_read) in read_record.heap_reads.iter().enumerate() { + let record = memory.record_by_id(*heap_read); + aux_cols_factory.generate_read_aux(record, &mut row_slice.heap_read_aux[i]); + } // Range checks: let need_range_check: Vec = rs_reads diff --git a/extensions/rv32-adapters/src/vec_heap.rs b/extensions/rv32-adapters/src/vec_heap.rs index 36e987b36..3ac019dc8 100644 --- a/extensions/rv32-adapters/src/vec_heap.rs +++ b/extensions/rv32-adapters/src/vec_heap.rs @@ -1,5 +1,5 @@ use std::{ - array::{self, from_fn}, + array::from_fn, borrow::{Borrow, BorrowMut}, iter::{once, zip}, marker::PhantomData, @@ -502,19 +502,27 @@ pub(super) fn vec_heap_generate_trace_row_impl< .collect::>(); row_slice.rd_ptr = rd.pointer; - row_slice.rs_ptr = array::from_fn(|i| rs[i].pointer); - - row_slice.rd_val = array::from_fn(|i| rd.data[i]); - row_slice.rs_val = array::from_fn(|j| array::from_fn(|i| rs[j].data[i])); - - row_slice.rs_read_aux = array::from_fn(|i| aux_cols_factory.make_read_aux_cols(rs[i])); - row_slice.rd_read_aux = aux_cols_factory.make_read_aux_cols(rd); - row_slice.reads_aux = read_record - .reads - .map(|r| r.map(|x| aux_cols_factory.make_read_aux_cols(memory.record_by_id(x)))); - row_slice.writes_aux = write_record - .writes - .map(|w| aux_cols_factory.make_write_aux_cols(memory.record_by_id(w))); + row_slice.rd_val.copy_from_slice(&rd.data); + + for (i, r) in rs.iter().enumerate() { + row_slice.rs_ptr[i] = r.pointer; + row_slice.rs_val[i].copy_from_slice(&r.data); + aux_cols_factory.generate_read_aux(r, &mut row_slice.rs_read_aux[i]); + } + + aux_cols_factory.generate_read_aux(rd, &mut row_slice.rd_read_aux); + + for (i, reads) in read_record.reads.iter().enumerate() { + for (j, &x) in reads.iter().enumerate() { + let record = memory.record_by_id(x); + aux_cols_factory.generate_read_aux(record, &mut row_slice.reads_aux[i][j]); + } + } + + for (i, &w) in write_record.writes.iter().enumerate() { + let record = memory.record_by_id(w); + aux_cols_factory.generate_write_aux(record, &mut row_slice.writes_aux[i]); + } // Range checks: let need_range_check: Vec = rs diff --git a/extensions/rv32-adapters/src/vec_heap_two_reads.rs b/extensions/rv32-adapters/src/vec_heap_two_reads.rs index a96864f8c..e5eea6d2f 100644 --- a/extensions/rv32-adapters/src/vec_heap_two_reads.rs +++ b/extensions/rv32-adapters/src/vec_heap_two_reads.rs @@ -1,5 +1,5 @@ use std::{ - array::{self, from_fn}, + array::from_fn, borrow::{Borrow, BorrowMut}, iter::zip, marker::PhantomData, @@ -531,23 +531,28 @@ pub(super) fn vec_heap_two_reads_generate_trace_row_impl< row_slice.rs1_ptr = rs1.pointer; row_slice.rs2_ptr = rs2.pointer; - row_slice.rd_val = array::from_fn(|i| rd.data[i]); - row_slice.rs1_val = array::from_fn(|i| rs1.data[i]); - row_slice.rs2_val = array::from_fn(|i| rs2.data[i]); - - row_slice.rs1_read_aux = aux_cols_factory.make_read_aux_cols(rs1); - row_slice.rs2_read_aux = aux_cols_factory.make_read_aux_cols(rs2); - row_slice.rd_read_aux = aux_cols_factory.make_read_aux_cols(rd); - row_slice.reads1_aux = read_record - .reads1 - .map(|r| aux_cols_factory.make_read_aux_cols(memory.record_by_id(r))); - row_slice.reads2_aux = read_record - .reads2 - .map(|r| aux_cols_factory.make_read_aux_cols(memory.record_by_id(r))); - row_slice.writes_aux = write_record - .writes - .map(|w| aux_cols_factory.make_write_aux_cols(memory.record_by_id(w))); + row_slice.rd_val.copy_from_slice(&rd.data); + row_slice.rs1_val.copy_from_slice(&rs1.data); + row_slice.rs2_val.copy_from_slice(&rs2.data); + aux_cols_factory.generate_read_aux(rs1, &mut row_slice.rs1_read_aux); + aux_cols_factory.generate_read_aux(rs2, &mut row_slice.rs2_read_aux); + aux_cols_factory.generate_read_aux(rd, &mut row_slice.rd_read_aux); + + for (i, r) in read_record.reads1.iter().enumerate() { + let record = memory.record_by_id(*r); + aux_cols_factory.generate_read_aux(record, &mut row_slice.reads1_aux[i]); + } + + for (i, r) in read_record.reads2.iter().enumerate() { + let record = memory.record_by_id(*r); + aux_cols_factory.generate_read_aux(record, &mut row_slice.reads2_aux[i]); + } + + for (i, w) in write_record.writes.iter().enumerate() { + let record = memory.record_by_id(*w); + aux_cols_factory.generate_write_aux(record, &mut row_slice.writes_aux[i]); + } // Range checks: let need_range_check = [ &read_record.rs1, diff --git a/extensions/rv32im/circuit/src/adapters/alu.rs b/extensions/rv32im/circuit/src/adapters/alu.rs index de8dff1ff..097e19a28 100644 --- a/extensions/rv32im/circuit/src/adapters/alu.rs +++ b/extensions/rv32im/circuit/src/adapters/alu.rs @@ -302,19 +302,15 @@ impl VmAdapterChip for Rv32BaseAluAdapterChip { if let Some(rs2) = rs2 { row_slice.rs2 = rs2.pointer; row_slice.rs2_as = rs2.address_space; - row_slice.reads_aux = [ - aux_cols_factory.make_read_aux_cols(rs1), - aux_cols_factory.make_read_aux_cols(rs2), - ]; + aux_cols_factory.generate_read_aux(rs1, &mut row_slice.reads_aux[0]); + aux_cols_factory.generate_read_aux(rs2, &mut row_slice.reads_aux[1]); } else { row_slice.rs2 = read_record.rs2_imm; row_slice.rs2_as = F::ZERO; - row_slice.reads_aux = [ - aux_cols_factory.make_read_aux_cols(rs1), - MemoryReadAuxCols::::disabled(), - ]; + aux_cols_factory.generate_read_aux(rs1, &mut row_slice.reads_aux[0]); + // row_slice.reads_aux[1] is disabled } - row_slice.writes_aux = aux_cols_factory.make_write_aux_cols(rd); + aux_cols_factory.generate_write_aux(rd, &mut row_slice.writes_aux); } fn air(&self) -> &Self::Air { diff --git a/extensions/rv32im/circuit/src/adapters/branch.rs b/extensions/rv32im/circuit/src/adapters/branch.rs index eeadd7e44..25338e559 100644 --- a/extensions/rv32im/circuit/src/adapters/branch.rs +++ b/extensions/rv32im/circuit/src/adapters/branch.rs @@ -216,10 +216,8 @@ impl VmAdapterChip for Rv32BranchAdapterChip { let rs2 = memory.record_by_id(read_record.rs2); row_slice.rs1_ptr = rs1.pointer; row_slice.rs2_ptr = rs2.pointer; - row_slice.reads_aux = [ - aux_cols_factory.make_read_aux_cols(rs1), - aux_cols_factory.make_read_aux_cols(rs2), - ] + aux_cols_factory.generate_read_aux(rs1, &mut row_slice.reads_aux[0]); + aux_cols_factory.generate_read_aux(rs2, &mut row_slice.reads_aux[1]); } fn air(&self) -> &Self::Air { diff --git a/extensions/rv32im/circuit/src/adapters/hintstore.rs b/extensions/rv32im/circuit/src/adapters/hintstore.rs index fe658e46c..0df8d630d 100644 --- a/extensions/rv32im/circuit/src/adapters/hintstore.rs +++ b/extensions/rv32im/circuit/src/adapters/hintstore.rs @@ -319,14 +319,14 @@ impl VmAdapterChip for Rv32HintStoreAdapterChip { adapter_cols.from_state = write_record.from_state.map(F::from_canonical_u32); let rs1 = memory.record_by_id(read_record.rs1_record); adapter_cols.rs1_data = rs1.data.clone().try_into().unwrap(); - adapter_cols.rs1_aux_cols = aux_cols_factory.make_read_aux_cols(rs1); + aux_cols_factory.generate_read_aux(rs1, &mut adapter_cols.rs1_aux_cols); adapter_cols.rs1_ptr = read_record.rs1_ptr; adapter_cols.imm = read_record.imm; adapter_cols.imm_sign = F::from_bool(read_record.imm_sign); adapter_cols.mem_ptr_limbs = read_record.mem_ptr_limbs.map(F::from_canonical_u32); let rd = memory.record_by_id(write_record.record_id); - adapter_cols.write_aux = aux_cols_factory.make_write_aux_cols(rd); + aux_cols_factory.generate_write_aux(rd, &mut adapter_cols.write_aux); } fn air(&self) -> &Self::Air { diff --git a/extensions/rv32im/circuit/src/adapters/jalr.rs b/extensions/rv32im/circuit/src/adapters/jalr.rs index 42669a6fb..e8f8a3f7f 100644 --- a/extensions/rv32im/circuit/src/adapters/jalr.rs +++ b/extensions/rv32im/circuit/src/adapters/jalr.rs @@ -242,18 +242,13 @@ impl VmAdapterChip for Rv32JalrAdapterChip { adapter_cols.from_state = write_record.from_state.map(F::from_canonical_u32); let rs1 = memory.record_by_id(read_record.rs1); adapter_cols.rs1_ptr = rs1.pointer; - adapter_cols.rs1_aux_cols = aux_cols_factory.make_read_aux_cols(rs1); - ( - adapter_cols.rd_ptr, - adapter_cols.rd_aux_cols, - adapter_cols.needs_write, - ) = match write_record.rd_id { - Some(id) => { - let rd = memory.record_by_id(id); - (rd.pointer, aux_cols_factory.make_write_aux_cols(rd), F::ONE) - } - None => (F::ZERO, MemoryWriteAuxCols::disabled(), F::ZERO), - }; + aux_cols_factory.generate_read_aux(rs1, &mut adapter_cols.rs1_aux_cols); + if let Some(id) = write_record.rd_id { + let rd = memory.record_by_id(id); + adapter_cols.rd_ptr = rd.pointer; + adapter_cols.needs_write = F::ONE; + aux_cols_factory.generate_write_aux(rd, &mut adapter_cols.rd_aux_cols); + } } fn air(&self) -> &Self::Air { diff --git a/extensions/rv32im/circuit/src/adapters/loadstore.rs b/extensions/rv32im/circuit/src/adapters/loadstore.rs index f403fcc35..302d51bb0 100644 --- a/extensions/rv32im/circuit/src/adapters/loadstore.rs +++ b/extensions/rv32im/circuit/src/adapters/loadstore.rs @@ -467,18 +467,16 @@ impl VmAdapterChip for Rv32LoadStoreAdapterChip { adapter_cols.from_state = write_record.from_state.map(F::from_canonical_u32); let rs1 = memory.record_by_id(read_record.rs1_record); adapter_cols.rs1_data = rs1.data.clone().try_into().unwrap(); - adapter_cols.rs1_aux_cols = aux_cols_factory.make_read_aux_cols(rs1); + aux_cols_factory.generate_read_aux(rs1, &mut adapter_cols.rs1_aux_cols); adapter_cols.rs1_ptr = read_record.rs1_ptr; adapter_cols.rd_rs2_ptr = write_record.rd_rs2_ptr; - adapter_cols.read_data_aux = - aux_cols_factory.make_read_aux_cols(memory.record_by_id(read_record.read)); + let read = memory.record_by_id(read_record.read); + aux_cols_factory.generate_read_aux(read, &mut adapter_cols.read_data_aux); adapter_cols.imm = read_record.imm; adapter_cols.imm_sign = F::from_bool(read_record.imm_sign); adapter_cols.mem_ptr_limbs = read_record.mem_ptr_limbs.map(F::from_canonical_u32); let write = memory.record_by_id(write_record.write_id); - adapter_cols.write_base_aux = aux_cols_factory - .make_write_aux_cols::(write) - .get_base(); + aux_cols_factory.generate_base_aux(write, &mut adapter_cols.write_base_aux); adapter_cols.mem_as = read_record.mem_as; } diff --git a/extensions/rv32im/circuit/src/adapters/mul.rs b/extensions/rv32im/circuit/src/adapters/mul.rs index 96cf8155f..c941b36a2 100644 --- a/extensions/rv32im/circuit/src/adapters/mul.rs +++ b/extensions/rv32im/circuit/src/adapters/mul.rs @@ -245,11 +245,9 @@ impl VmAdapterChip for Rv32MultAdapterChip { let rs2 = memory.record_by_id(read_record.rs2); row_slice.rs1_ptr = rs1.pointer; row_slice.rs2_ptr = rs2.pointer; - row_slice.reads_aux = [ - aux_cols_factory.make_read_aux_cols(rs1), - aux_cols_factory.make_read_aux_cols(rs2), - ]; - row_slice.writes_aux = aux_cols_factory.make_write_aux_cols(rd); + aux_cols_factory.generate_read_aux(rs1, &mut row_slice.reads_aux[0]); + aux_cols_factory.generate_read_aux(rs2, &mut row_slice.reads_aux[1]); + aux_cols_factory.generate_write_aux(rd, &mut row_slice.writes_aux); } fn air(&self) -> &Self::Air { diff --git a/extensions/rv32im/circuit/src/adapters/rdwrite.rs b/extensions/rv32im/circuit/src/adapters/rdwrite.rs index a45348948..3db28592d 100644 --- a/extensions/rv32im/circuit/src/adapters/rdwrite.rs +++ b/extensions/rv32im/circuit/src/adapters/rdwrite.rs @@ -295,7 +295,7 @@ impl VmAdapterChip for Rv32RdWriteAdapterChip { adapter_cols.from_state = write_record.from_state.map(F::from_canonical_u32); let rd = memory.record_by_id(write_record.rd_id.unwrap()); adapter_cols.rd_ptr = rd.pointer; - adapter_cols.rd_aux_cols = aux_cols_factory.make_write_aux_cols(rd); + aux_cols_factory.generate_write_aux(rd, &mut adapter_cols.rd_aux_cols); } fn air(&self) -> &Self::Air { @@ -360,10 +360,8 @@ impl VmAdapterChip for Rv32CondRdWriteAdapterChip { if let Some(rd_id) = write_record.rd_id { let rd = memory.record_by_id(rd_id); adapter_cols.inner.rd_ptr = rd.pointer; - adapter_cols.inner.rd_aux_cols = aux_cols_factory.make_write_aux_cols(rd); + aux_cols_factory.generate_write_aux(rd, &mut adapter_cols.inner.rd_aux_cols); adapter_cols.needs_write = F::ONE; - } else { - adapter_cols.needs_write = F::ZERO; } } diff --git a/extensions/sha256/circuit/src/sha256_chip/trace.rs b/extensions/sha256/circuit/src/sha256_chip/trace.rs index aae1ab0ea..fca1ac01c 100644 --- a/extensions/sha256/circuit/src/sha256_chip/trace.rs +++ b/extensions/sha256/circuit/src/sha256_chip/trace.rs @@ -152,8 +152,8 @@ where if row < 4 { read_ptr += read_size; cur_timestamp += Val::::ONE; - cols.read_aux = - memory_aux_cols_factory.make_read_aux_cols(block_reads[row]); + memory_aux_cols_factory + .generate_read_aux(block_reads[row], &mut cols.read_aux); if (row + 1) * SHA256_READ_SIZE <= message_left { cols.control.pad_flags = get_flag_pt_array( @@ -221,13 +221,14 @@ where cols.dst_ptr = dst_read.data.clone().try_into().unwrap(); cols.src_ptr = src_read.data.clone().try_into().unwrap(); cols.len_data = len_read.data.clone().try_into().unwrap(); - cols.register_reads_aux = [ - memory_aux_cols_factory.make_read_aux_cols(dst_read), - memory_aux_cols_factory.make_read_aux_cols(src_read), - memory_aux_cols_factory.make_read_aux_cols(len_read), - ]; - cols.writes_aux = - memory_aux_cols_factory.make_write_aux_cols(digest_write); + memory_aux_cols_factory + .generate_read_aux(dst_read, &mut cols.register_reads_aux[0]); + memory_aux_cols_factory + .generate_read_aux(src_read, &mut cols.register_reads_aux[1]); + memory_aux_cols_factory + .generate_read_aux(len_read, &mut cols.register_reads_aux[2]); + memory_aux_cols_factory + .generate_write_aux(digest_write, &mut cols.writes_aux); } cols.control.padding_occurred = Val::::from_bool(has_padding_occurred);