diff --git a/extensions/native/circuit/cuda/include/native/poseidon2.cuh b/extensions/native/circuit/cuda/include/native/poseidon2.cuh index 206c0e16c0..b794aacaef 100644 --- a/extensions/native/circuit/cuda/include/native/poseidon2.cuh +++ b/extensions/native/circuit/cuda/include/native/poseidon2.cuh @@ -65,14 +65,17 @@ template struct SimplePoseidonSpecificCols { template struct MultiObserveCols { T pc; T final_timestamp_increment; + T state_ptr_register; + T ctx_register; + T input_ptr_register; + T hint_id_register; T state_ptr; + T ctx_ptr; T input_ptr; - T init_pos; - T len; - T input_register_1; - T input_register_2; - T input_register_3; - T output_register; + T hint_id; + T ctx[4]; + MemoryReadAuxCols read_ctx; + T chunk_ts_count; T is_first; T is_last; T curr_len; diff --git a/extensions/native/circuit/cuda/src/poseidon2.cu b/extensions/native/circuit/cuda/src/poseidon2.cu index fdbe0d3ce5..772a708f89 100644 --- a/extensions/native/circuit/cuda/src/poseidon2.cu +++ b/extensions/native/circuit/cuda/src/poseidon2.cu @@ -24,6 +24,7 @@ template struct NativePoseidon2Cols { T inside_row; T simple; T multi_observe_row; + T not_hint_multi_observe; T end_inside_row; T end_top_level; @@ -355,31 +356,46 @@ template struct Poseidon2Wrapper { if (specific[COL_INDEX(MultiObserveCols, is_first)] == Fp::one()) { uint32_t very_start_timestamp = row[COL_INDEX(Cols, very_first_timestamp)].asUInt32(); - for (uint32_t i = 0; i < 4; ++i) { + for (uint32_t i = 0; i < 3; ++i) { mem_fill_base( mem_helper, very_start_timestamp + i, specific.slice_from(COL_INDEX(MultiObserveCols, read_data[i].base)) ); } + mem_fill_base( + mem_helper, + very_start_timestamp + 3, + specific.slice_from(COL_INDEX(MultiObserveCols, read_ctx.base)) + ); + mem_fill_base( + mem_helper, + very_start_timestamp + 4, + specific.slice_from(COL_INDEX(MultiObserveCols, read_data[3].base)) + ); } else { uint32_t start_timestamp = row[COL_INDEX(Cols, start_timestamp)].asUInt32(); uint32_t chunk_start = specific[COL_INDEX(MultiObserveCols, start_idx)].asUInt32(); uint32_t chunk_end = specific[COL_INDEX(MultiObserveCols, end_idx)].asUInt32(); + uint32_t is_hint = + specific[COL_INDEX(MultiObserveCols, ctx[2])].asUInt32(); + uint32_t ts_per_element = 2 - is_hint; for (uint32_t j = chunk_start; j < chunk_end; ++j) { + if (!is_hint) { + mem_fill_base( + mem_helper, + start_timestamp, + specific.slice_from(COL_INDEX(MultiObserveCols, read_data[j].base)) + ); + } mem_fill_base( mem_helper, - start_timestamp, - specific.slice_from(COL_INDEX(MultiObserveCols, read_data[j].base)) - ); - mem_fill_base( - mem_helper, - start_timestamp + 1, + start_timestamp + (1 - is_hint), specific.slice_from(COL_INDEX(MultiObserveCols, write_data[j].base)) ); - start_timestamp += 2; + start_timestamp += ts_per_element; } if (chunk_end >= CHUNK) { mem_fill_base( diff --git a/extensions/native/circuit/src/extension/cuda.rs b/extensions/native/circuit/src/extension/cuda.rs index 50a9cba86d..0d476413ac 100644 --- a/extensions/native/circuit/src/extension/cuda.rs +++ b/extensions/native/circuit/src/extension/cuda.rs @@ -75,19 +75,29 @@ impl VmProverExtension FriReducedOpeningChipGpu::new(range_checker.clone(), timestamp_max_bits); inventory.add_executor_chip(fri_reduced_opening); - inventory.next_air::>()?; - let poseidon2 = NativePoseidon2ChipGpu::<1>::new(range_checker.clone(), timestamp_max_bits); - inventory.add_executor_chip(poseidon2); - let hint_air: &HintSpaceProviderAir = inventory.next_air::()?; + let cpu_range_checker = range_checker + .cpu_chip + .clone() + .expect("VariableRangeCheckerChipGPU is expected to be hybrid with cpu_chip"); let cpu_chip = Arc::new(HintSpaceProviderChip::new( hint_air.hint_bus, - range_checker.clone(), + cpu_range_checker, timestamp_max_bits, )); + let provider_gpu = HintSpaceProviderChipGpu::new(cpu_chip.clone()); inventory.add_periphery_chip(provider_gpu); + inventory.next_air::>()?; + + let poseidon2 = NativePoseidon2ChipGpu::<1>::new_with_hint_space_provider( + range_checker.clone(), + timestamp_max_bits, + cpu_chip.clone(), + ); + inventory.add_executor_chip(poseidon2); + inventory.next_air::()?; let sumcheck = NativeSumcheckChipGpu::new(range_checker.clone(), timestamp_max_bits, cpu_chip); diff --git a/extensions/native/circuit/src/extension/mod.rs b/extensions/native/circuit/src/extension/mod.rs index 924d4927e8..4a930ddafe 100644 --- a/extensions/native/circuit/src/extension/mod.rs +++ b/extensions/native/circuit/src/extension/mod.rs @@ -271,15 +271,6 @@ where ); inventory.add_air(fri_reduced_opening); - let verify_batch = NativePoseidon2Air::<_, 1>::new( - exec_bridge, - memory_bridge, - hint_bridge, - VerifyBatchBus::new(inventory.new_bus_idx()), - Poseidon2Config::default(), - ); - inventory.add_air(verify_batch); - let hint_space_provider = HintSpaceProviderAir { hint_bus: hint_bridge.hint_bus(), lt_air: IsLtSubAir::new( @@ -289,6 +280,15 @@ where }; inventory.add_air(hint_space_provider); + let verify_batch = NativePoseidon2Air::<_, 1>::new( + exec_bridge, + memory_bridge, + hint_bridge, + VerifyBatchBus::new(inventory.new_bus_idx()), + Poseidon2Config::default(), + ); + inventory.add_air(verify_batch); + let tower_evaluate = NativeSumcheckAir::new(exec_bridge, memory_bridge, hint_bridge); inventory.add_air(tower_evaluate); @@ -365,13 +365,6 @@ where FriReducedOpeningChip::new(FriReducedOpeningFiller::new(), mem_helper.clone()); inventory.add_executor_chip(fri_reduced_opening); - inventory.next_air::, 1>>()?; - let poseidon2 = NativePoseidon2Chip::<_, 1>::new( - NativePoseidon2Filler::new(Poseidon2Config::default()), - mem_helper.clone(), - ); - inventory.add_executor_chip(poseidon2); - let hint_bus = inventory.airs().system().hint_bridge.hint_bus(); let hint_space_provider = Arc::new(HintSpaceProviderChip::new( hint_bus, @@ -382,8 +375,17 @@ where inventory.next_air::()?; inventory.add_periphery_chip(hint_space_provider.clone()); + inventory.next_air::, 1>>()?; + + let poseidon2 = NativePoseidon2Chip::<_, 1>::new( + NativePoseidon2Filler::new(Poseidon2Config::default(), hint_space_provider.clone()), + mem_helper.clone(), + ); + inventory.add_executor_chip(poseidon2); + + inventory.next_air::()?; let tower_verify = NativeSumcheckChip::new( - NativeSumcheckFiller::new(hint_space_provider), + NativeSumcheckFiller::new(hint_space_provider.clone()), mem_helper.clone(), ); inventory.add_executor_chip(tower_verify); diff --git a/extensions/native/circuit/src/poseidon2/air.rs b/extensions/native/circuit/src/poseidon2/air.rs index 9e9cdf5ce8..8b3b02ebc2 100644 --- a/extensions/native/circuit/src/poseidon2/air.rs +++ b/extensions/native/circuit/src/poseidon2/air.rs @@ -94,6 +94,7 @@ impl Air inside_row, simple, multi_observe_row, + not_hint_multi_observe, end_inside_row, end_top_level, start_top_level, @@ -713,10 +714,17 @@ impl Air let &MultiObserveCols { pc, final_timestamp_increment, + state_ptr_register, + ctx_register, + input_ptr_register, + hint_id_register, state_ptr, + ctx_ptr, input_ptr, - init_pos, - len, + hint_id, + ctx, + read_ctx, + chunk_ts_count, is_first, is_last, curr_len, @@ -731,26 +739,38 @@ impl Air should_permute, write_sponge_state, write_final_idx, - input_register_1, - input_register_2, - input_register_3, - output_register, } = multi_observe_specific; + // Alias context values + let init_pos = ctx[0]; + let len = ctx[1]; + let is_hint = ctx[2]; + builder.when(multi_observe_row).assert_bool(is_first); builder.when(multi_observe_row).assert_bool(is_last); builder.when(multi_observe_row).assert_bool(should_permute); + builder.when(multi_observe_row).assert_bool(is_hint); + builder.assert_eq( + not_hint_multi_observe, + multi_observe_row * (AB::Expr::ONE - is_hint), + ); + let hint_multi_observe: AB::Expr = multi_observe_row - not_hint_multi_observe; + // chunk_ts_count = (end_idx - start_idx) * (2 - is_hint) + builder.when(multi_observe_row).assert_eq( + chunk_ts_count, + (end_idx - start_idx) * AB::F::TWO - (end_idx - start_idx) * is_hint, + ); self.execution_bridge .execute_and_increment_pc( AB::F::from_canonical_usize(MULTI_OBSERVE.global_opcode().as_usize()), [ - output_register.into(), - input_register_1.into(), - input_register_2.into(), + state_ptr_register.into(), + ctx_register.into(), + input_ptr_register.into(), self.address_space.into(), self.address_space.into(), - input_register_3.into(), + hint_id_register.into(), ], ExecutionState::new(pc, very_first_timestamp), final_timestamp_increment, @@ -759,7 +779,7 @@ impl Air self.memory_bridge .read( - MemoryAddress::new(self.address_space, output_register), + MemoryAddress::new(self.address_space, state_ptr_register), [state_ptr], very_first_timestamp, &read_data[0], @@ -768,8 +788,8 @@ impl Air self.memory_bridge .read( - MemoryAddress::new(self.address_space, input_register_1), - [init_pos], + MemoryAddress::new(self.address_space, ctx_register), + [ctx_ptr], very_first_timestamp + AB::F::ONE, &read_data[1], ) @@ -777,24 +797,47 @@ impl Air self.memory_bridge .read( - MemoryAddress::new(self.address_space, input_register_2), + MemoryAddress::new(self.address_space, input_ptr_register), [input_ptr], very_first_timestamp + AB::F::TWO, &read_data[2], ) .eval(builder, multi_observe_row * is_first); + // Read context array: [init_pos, len, is_hint, reserved] from ctx_ptr self.memory_bridge .read( - MemoryAddress::new(self.address_space, input_register_3), - [len], + MemoryAddress::new(self.address_space, ctx_ptr), + ctx, very_first_timestamp + AB::F::from_canonical_usize(3), + &read_ctx, + ) + .eval(builder, multi_observe_row * is_first); + + // Read hint_id from register (reuse spare read_data[3] on head row) + self.memory_bridge + .read( + MemoryAddress::new(self.address_space, hint_id_register), + [hint_id], + very_first_timestamp + AB::F::from_canonical_usize(4), &read_data[3], ) .eval(builder, multi_observe_row * is_first); + // Per-element constraints for chunk rows. for i in 0..CHUNK { let i_var = AB::F::from_canonical_usize(i); + + // Hint mode: lookup from hint space. + self.hint_bridge.lookup( + builder, + hint_id, + curr_len + i_var - start_idx, + data[i], + hint_multi_observe.clone() * aux_read_enabled[i], + ); + + // Non-hint mode: read from memory. self.memory_bridge .read( MemoryAddress::new( @@ -805,8 +848,7 @@ impl Air start_timestamp + i_var * AB::F::TWO - start_idx * AB::F::TWO, &read_data[i], ) - .eval(builder, multi_observe_row * aux_read_enabled[i]); - + .eval(builder, not_hint_multi_observe * aux_read_enabled[i]); self.memory_bridge .write( MemoryAddress::new(self.address_space, state_ptr + i_var), @@ -814,7 +856,16 @@ impl Air start_timestamp + i_var * AB::F::TWO - start_idx * AB::F::TWO + AB::F::ONE, &write_data[i], ) - .eval(builder, multi_observe_row * aux_read_enabled[i]); + .eval(builder, not_hint_multi_observe * aux_read_enabled[i]); + + self.memory_bridge + .write( + MemoryAddress::new(self.address_space, state_ptr + i_var), + [data[i]], + start_timestamp + i_var - start_idx, + &write_data[i], + ) + .eval(builder, hint_multi_observe.clone() * aux_read_enabled[i]); } for i in 0..(CHUNK - 1) { @@ -885,7 +936,7 @@ impl Air .write( MemoryAddress::new(self.address_space, state_ptr), full_sponge_output, - start_timestamp + (end_idx - start_idx) * AB::F::TWO, + start_timestamp + chunk_ts_count, &write_sponge_state, ) .eval(builder, multi_observe_row * should_permute); @@ -909,11 +960,12 @@ impl Air // final_idx = aux_read_enabled[CHUNK-1] * 0 + (1 - aux_read_enabled[CHUNK-1]) * end_idx let final_idx = aux_read_enabled[CHUNK - 1] * AB::Expr::ZERO + (AB::Expr::ONE - aux_read_enabled[CHUNK - 1]) * end_idx; + // Write final_idx back to ctx[0] (ctx_ptr address) self.memory_bridge .write( - MemoryAddress::new(self.address_space, input_register_1), + MemoryAddress::new(self.address_space, ctx_ptr), [final_idx], - start_timestamp + (end_idx - start_idx) * AB::F::TWO + should_permute, + start_timestamp + chunk_ts_count + should_permute, &write_final_idx, ) .eval(builder, multi_observe_row * is_last); @@ -962,41 +1014,59 @@ impl Air builder .when(next.multi_observe_row) .when(not(next_multi_observe_specific.is_first)) - .assert_eq(init_pos, next_multi_observe_specific.init_pos); + .assert_eq(init_pos, next_multi_observe_specific.ctx[0]); + + builder + .when(next.multi_observe_row) + .when(not(next_multi_observe_specific.is_first)) + .assert_eq(len, next_multi_observe_specific.ctx[1]); builder .when(next.multi_observe_row) .when(not(next_multi_observe_specific.is_first)) - .assert_eq(len, next_multi_observe_specific.len); + .assert_eq( + state_ptr_register, + next_multi_observe_specific.state_ptr_register, + ); builder .when(next.multi_observe_row) .when(not(next_multi_observe_specific.is_first)) .assert_eq( - input_register_1, - next_multi_observe_specific.input_register_1, + ctx_register, + next_multi_observe_specific.ctx_register, ); builder .when(next.multi_observe_row) .when(not(next_multi_observe_specific.is_first)) .assert_eq( - input_register_2, - next_multi_observe_specific.input_register_2, + input_ptr_register, + next_multi_observe_specific.input_ptr_register, ); + builder + .when(next.multi_observe_row) + .when(not(next_multi_observe_specific.is_first)) + .assert_eq(ctx_ptr, next_multi_observe_specific.ctx_ptr); + + builder + .when(next.multi_observe_row) + .when(not(next_multi_observe_specific.is_first)) + .assert_eq(hint_id, next_multi_observe_specific.hint_id); + builder .when(next.multi_observe_row) .when(not(next_multi_observe_specific.is_first)) .assert_eq( - input_register_3, - next_multi_observe_specific.input_register_3, + hint_id_register, + next_multi_observe_specific.hint_id_register, ); builder .when(next.multi_observe_row) .when(not(next_multi_observe_specific.is_first)) - .assert_eq(output_register, next_multi_observe_specific.output_register); + .assert_eq(is_hint, next_multi_observe_specific.ctx[2]); // Timestamp constraints builder diff --git a/extensions/native/circuit/src/poseidon2/chip.rs b/extensions/native/circuit/src/poseidon2/chip.rs index 770efc7307..e136c2fbb0 100644 --- a/extensions/native/circuit/src/poseidon2/chip.rs +++ b/extensions/native/circuit/src/poseidon2/chip.rs @@ -1,4 +1,5 @@ use std::borrow::{Borrow, BorrowMut}; +use std::sync::Arc; use openvm_circuit::{ arch::*, @@ -22,6 +23,7 @@ use openvm_stark_backend::{ }; use crate::{ + hint_space_provider::HintSpaceProviderChip, mem_fill_helper, poseidon2::{ columns::{ @@ -45,6 +47,7 @@ pub struct NativePoseidon2Filler { // pre-computed Poseidon2 sub cols for dummy rows. empty_poseidon2_sub_cols: Vec, pub(super) subchip: Poseidon2SubChip, + pub hint_space_provider: Arc>, } impl NativePoseidon2Executor { @@ -71,12 +74,16 @@ pub(crate) fn compress( } impl NativePoseidon2Filler { - pub fn new(poseidon2_config: Poseidon2Config) -> Self { + pub fn new( + poseidon2_config: Poseidon2Config, + hint_space_provider: Arc>, + ) -> Self { let subchip = Poseidon2SubChip::new(poseidon2_config.constants); let empty_poseidon2_sub_cols = subchip.generate_trace(vec![[F::ZERO; CHUNK * 2]]).values; Self { empty_poseidon2_sub_cols, subchip, + hint_space_provider, } } } @@ -649,11 +656,11 @@ where } else if instruction.opcode == MULTI_OBSERVE.global_opcode() { let &Instruction { a: state_ptr_register, - b: init_pos_register, + b: ctx_register, c: input_ptr_register, d: register_address_space, e: data_address_space, - f: len_register, + f: hint_id_register, .. } = instruction; @@ -663,31 +670,52 @@ where ); assert_eq!(data_address_space, F::from_canonical_u32(AS::Native as u32)); - let [init_pos]: [F; 1] = - memory_read_native(state.memory.data(), init_pos_register.as_canonical_u32()); - let [input_len]: [F; 1] = - memory_read_native(state.memory.data(), len_register.as_canonical_u32()); + // Read ctx_ptr from register, then read context array from memory + let [ctx_ptr]: [F; 1] = + memory_read_native(state.memory.data(), ctx_register.as_canonical_u32()); + let ctx: [F; 4] = + memory_read_native(state.memory.data(), ctx_ptr.as_canonical_u32()); + let init_pos = ctx[0]; + let input_len = ctx[1]; + let is_hint = ctx[2].as_canonical_u32() != 0; + + // Read hint_id from register + let [hint_id]: [F; 1] = + memory_read_native(state.memory.data(), hint_id_register.as_canonical_u32()); + + // Get hint_space data if in hint mode + let hint_data: Vec = if is_hint { + state.streams.hint_space[hint_id.as_canonical_u32() as usize].clone() + } else { + vec![] + }; let mut len = input_len.as_canonical_u32() as usize; let mut pos = init_pos.as_canonical_u32() as usize; let mut chunks: Vec<(usize, usize)> = vec![]; - const NUM_HEAD_ACCESSES: usize = 4; + // 3 register reads + 1 context array read + 1 hint_id register read = 5 head accesses + const NUM_HEAD_ACCESSES: usize = 5; let mut final_timestamp_inc = NUM_HEAD_ACCESSES; + // In hint mode: 1 timestamp per element (write only) + // In non-hint mode: 2 timestamps per element (read + write) + let ts_per_element = if is_hint { 1 } else { 2 }; while len > 0 { if len >= (CHUNK - pos) { chunks.push((pos, CHUNK)); len -= CHUNK - pos; - final_timestamp_inc += 2 * (CHUNK - pos) + 1; + final_timestamp_inc += ts_per_element * (CHUNK - pos) + 1; pos = 0; } else { chunks.push((pos, pos + len)); - final_timestamp_inc += 2 * len; + final_timestamp_inc += ts_per_element * len; len = 0; pos += len; } } - final_timestamp_inc += 1; // write back to init_pos_register + // Final ctx[0] writeback always happens (including zero-length input + // where the head row is both the first and last row). + final_timestamp_inc += 1; let allocated_rows = arena .alloc(MultiRowLayout::new(NativePoseidon2Metadata { @@ -698,14 +726,15 @@ where let head_multi_observe_cols: &mut MultiObserveCols = head_cols.specific[..MultiObserveCols::::width()].borrow_mut(); + // 3 register reads: state_ptr, ctx_ptr, input_ptr let [state_ptr]: [F; 1] = tracing_read_native_helper( state.memory, state_ptr_register.as_canonical_u32(), head_multi_observe_cols.read_data[0].as_mut(), ); - let [init_pos]: [F; 1] = tracing_read_native_helper( + let [ctx_ptr]: [F; 1] = tracing_read_native_helper( state.memory, - init_pos_register.as_canonical_u32(), + ctx_register.as_canonical_u32(), head_multi_observe_cols.read_data[1].as_mut(), ); let [input_ptr]: [F; 1] = tracing_read_native_helper( @@ -713,9 +742,16 @@ where input_ptr_register.as_canonical_u32(), head_multi_observe_cols.read_data[2].as_mut(), ); - let [input_len]: [F; 1] = tracing_read_native_helper( + // 1 context array read: [init_pos, len, is_hint, reserved] + let ctx: [F; 4] = tracing_read_native_helper( + state.memory, + ctx_ptr.as_canonical_u32(), + head_multi_observe_cols.read_ctx.as_mut(), + ); + // 1 hint_id register read (reuse spare read_data[3] on head row) + let [hint_id]: [F; 1] = tracing_read_native_helper( state.memory, - len_register.as_canonical_u32(), + hint_id_register.as_canonical_u32(), head_multi_observe_cols.read_data[3].as_mut(), ); @@ -727,16 +763,20 @@ where for (i, cols) in allocated_rows.iter_mut().enumerate() { let multi_observe_cols: &mut MultiObserveCols = cols.specific[..MultiObserveCols::::width()].borrow_mut(); - multi_observe_cols.input_register_1 = init_pos_register; - multi_observe_cols.input_register_2 = input_ptr_register; - multi_observe_cols.input_register_3 = len_register; - multi_observe_cols.output_register = state_ptr_register; - multi_observe_cols.init_pos = init_pos; - multi_observe_cols.input_ptr = input_ptr; + multi_observe_cols.state_ptr_register = state_ptr_register; + multi_observe_cols.ctx_register = ctx_register; + multi_observe_cols.input_ptr_register = input_ptr_register; + multi_observe_cols.hint_id_register = hint_id_register; multi_observe_cols.state_ptr = state_ptr; - multi_observe_cols.len = input_len; + multi_observe_cols.ctx_ptr = ctx_ptr; + multi_observe_cols.input_ptr = input_ptr; + multi_observe_cols.hint_id = hint_id; + multi_observe_cols.ctx = ctx; + + // chunk_ts_count will be filled per-chunk row below cols.multi_observe_row = F::ONE; + cols.not_hint_multi_observe = if is_hint { F::ZERO } else { F::ONE }; cols.very_first_timestamp = init_timestamp; if i == 0 { @@ -746,9 +786,26 @@ where multi_observe_cols.final_timestamp_increment = F::from_canonical_usize(final_timestamp_inc); multi_observe_cols.is_first = F::ONE; - multi_observe_cols.is_last = F::ZERO; + multi_observe_cols.is_last = if chunks.is_empty() { F::ONE } else { F::ZERO }; multi_observe_cols.curr_len = F::ZERO; multi_observe_cols.should_permute = F::ZERO; + if chunks.is_empty() { + // Zero-length input: head row is both first and last. + // Set start_timestamp to right after the 5 head reads, + // and write back init_pos (unchanged) to ctx_ptr[0]. + cols.start_timestamp = F::from_canonical_u32( + init_timestamp_u32 + NUM_HEAD_ACCESSES as u32, + ); + multi_observe_cols.start_idx = init_pos; + multi_observe_cols.end_idx = init_pos; + // state.memory.timestamp == init_ts + NUM_HEAD_ACCESSES here. + tracing_write_native_inplace( + state.memory, + ctx_ptr.as_canonical_u32(), + [init_pos], + &mut multi_observe_cols.write_final_idx, + ); + } } } @@ -767,6 +824,7 @@ where multi_observe_cols.start_idx = F::from_canonical_usize(chunk_start); multi_observe_cols.end_idx = F::from_canonical_usize(chunk_end); + multi_observe_cols.chunk_ts_count = F::from_canonical_usize((chunk_end - chunk_start) * ts_per_element); multi_observe_cols.is_first = F::ZERO; multi_observe_cols.is_last = if i == num_chunks - 1 { F::ONE } else { F::ZERO }; @@ -779,21 +837,29 @@ where multi_observe_cols.aux_before_end[j] = F::ONE; } for j in chunk_start..chunk_end { - let n_f: [F; 1] = tracing_read_native_helper( - state.memory, - input_ptr_u32 + input_idx as u32, - multi_observe_cols.read_data[j].as_mut(), - ); + let n_f: F = if is_hint { + // In hint mode: read from hint_space + hint_data[input_idx] + } else { + // In non-hint mode: read from memory via tracing read + let [v]: [F; 1] = tracing_read_native_helper( + state.memory, + input_ptr_u32 + input_idx as u32, + multi_observe_cols.read_data[j].as_mut(), + ); + v + }; + multi_observe_cols.aux_read_enabled[j] = F::ONE; tracing_write_native_inplace( state.memory, state_ptr_u32 + j as u32, - n_f, + [n_f], &mut multi_observe_cols.write_data[j], ); - multi_observe_cols.data[j] = n_f[0]; + multi_observe_cols.data[j] = n_f; input_idx += 1; - cur_timestamp += 2; + cur_timestamp += ts_per_element as u32; } let permutation_input: [F; 16] = @@ -817,7 +883,7 @@ where let final_idx = F::from_canonical_usize(chunk_end % CHUNK); tracing_write_native_inplace( state.memory, - init_pos_register.as_canonical_u32(), + ctx_ptr.as_canonical_u32(), [final_idx], &mut multi_observe_cols.write_final_idx, ); @@ -1161,7 +1227,7 @@ impl NativePoseidon2Filler::width()].borrow_mut(); let start_timestamp_u32 = head_cols.very_first_timestamp.as_canonical_u32(); - // state_ptr, init_pos, input_ptr, len + // 3 register reads: state_ptr, ctx_ptr, input_ptr mem_fill_helper( mem_helper, start_timestamp_u32, @@ -1177,12 +1243,23 @@ impl NativePoseidon2Filler = chunk_slice @@ -1194,6 +1271,8 @@ impl NativePoseidon2Filler = chunk_slice[row_idx * width..(row_idx + 1) * width].borrow_mut(); @@ -1205,18 +1284,32 @@ impl NativePoseidon2Filler= CHUNK as u32 { @@ -1235,6 +1328,15 @@ impl NativePoseidon2Filler = + chunk_slice[..width].borrow_mut(); + let head_mo: &mut MultiObserveCols = + head_c.specific[..MultiObserveCols::::width()].borrow_mut(); + let head_ts = head_c.start_timestamp.as_canonical_u32(); + mem_fill_helper(mem_helper, head_ts, head_mo.write_final_idx.as_mut()); + } } #[inline(always)] diff --git a/extensions/native/circuit/src/poseidon2/columns.rs b/extensions/native/circuit/src/poseidon2/columns.rs index abb8db54a2..67557a5f73 100644 --- a/extensions/native/circuit/src/poseidon2/columns.rs +++ b/extensions/native/circuit/src/poseidon2/columns.rs @@ -31,6 +31,10 @@ pub struct NativePoseidon2Cols { /// Indicates that this row is a multi_observe row. pub multi_observe_row: T, + /// Materialized column: multi_observe_row * (1 - is_hint). + /// Lives in main cols (not overlaid specific) so it is 0 on non-multi_observe rows. + pub not_hint_multi_observe: T, + /// Indicates the last row in an inside-row block. pub end_inside_row: T, /// Indicates the last row in a top-level block. @@ -211,16 +215,24 @@ pub struct MultiObserveCols { pub pc: T, pub final_timestamp_increment: T, - // Initial reads from registers - // They are same across same instance of multi_observe + // Register addresses + pub state_ptr_register: T, + pub ctx_register: T, + pub input_ptr_register: T, + pub hint_id_register: T, + + // Values read from registers pub state_ptr: T, + pub ctx_ptr: T, pub input_ptr: T, - pub init_pos: T, - pub len: T, - pub input_register_1: T, - pub input_register_2: T, - pub input_register_3: T, - pub output_register: T, + pub hint_id: T, + + // Context array values read from ctx_ptr + // ctx[0] = init_pos, ctx[1] = len, ctx[2] = is_hint, ctx[3] = reserved + pub ctx: [T; 4], + pub read_ctx: MemoryReadAuxCols, + + pub chunk_ts_count: T, pub is_first: T, pub is_last: T, @@ -240,6 +252,6 @@ pub struct MultiObserveCols { pub should_permute: T, pub write_sponge_state: MemoryWriteAuxCols, - // Final write back and registers + // Final write back to ctx[0] pub write_final_idx: MemoryWriteAuxCols, } diff --git a/extensions/native/circuit/src/poseidon2/cuda.rs b/extensions/native/circuit/src/poseidon2/cuda.rs index 0425cfda18..4bdf337a75 100644 --- a/extensions/native/circuit/src/poseidon2/cuda.rs +++ b/extensions/native/circuit/src/poseidon2/cuda.rs @@ -1,6 +1,5 @@ use std::{borrow::Borrow, mem::size_of, slice::from_raw_parts, sync::Arc}; -use derive_new::new; use openvm_circuit::{arch::DenseRecordArena, utils::next_power_of_two_or_zero}; use openvm_circuit_primitives::var_range::VariableRangeCheckerChipGPU; use openvm_cuda_backend::{ @@ -8,15 +7,115 @@ use openvm_cuda_backend::{ }; use openvm_cuda_common::copy::MemCopyH2D; use openvm_stark_backend::{prover::types::AirProvingContext, Chip}; -use p3_field::{Field, PrimeField32}; +use p3_field::{Field, FieldAlgebra, PrimeField32}; -use super::columns::NativePoseidon2Cols; -use crate::cuda_abi::poseidon2_cuda; +use super::columns::{MultiObserveCols, NativePoseidon2Cols}; +use crate::{ + cuda_abi::poseidon2_cuda, + hint_space_provider::SharedHintSpaceProviderChip, +}; -#[derive(new)] pub struct NativePoseidon2ChipGpu { pub range_checker: Arc, pub timestamp_max_bits: usize, + pub hint_space_provider: Option>, +} + +impl NativePoseidon2ChipGpu { + pub fn new(range_checker: Arc, timestamp_max_bits: usize) -> Self { + Self { + range_checker, + timestamp_max_bits, + hint_space_provider: None, + } + } + + pub fn new_with_hint_space_provider( + range_checker: Arc, + timestamp_max_bits: usize, + hint_space_provider: SharedHintSpaceProviderChip, + ) -> Self { + Self { + range_checker, + timestamp_max_bits, + hint_space_provider: Some(hint_space_provider), + } + } + + /// Scans multi-observe execution records to populate the hint provider with + /// (hint_id, offset, value) triples for hint-mode rows. + fn populate_hint_provider(&self, records: &[u8]) { + let Some(hint_space_provider) = &self.hint_space_provider else { + return; + }; + + let width = NativePoseidon2Cols::::width(); + let record_size = width * size_of::(); + if records.len() % record_size != 0 { + return; + } + let height = records.len() / record_size; + + let row_slice = unsafe { + let ptr = records.as_ptr() as *const F; + from_raw_parts(ptr, height * width) + }; + + let mut row_idx = 0; + while row_idx < height { + let start = row_idx * width; + let cols: &NativePoseidon2Cols = + row_slice[start..(start + width)].borrow(); + + if cols.multi_observe_row.is_one() { + let num_rows = cols.inner.export.as_canonical_u32() as usize; + if num_rows > 1 { + let head_multi_observe_cols: &MultiObserveCols = + cols.specific[..MultiObserveCols::::width()].borrow(); + let is_hint = head_multi_observe_cols.ctx[2] != F::ZERO; + if is_hint { + let hint_id = head_multi_observe_cols.hint_id; + for local_row in 1..num_rows { + let chunk_cols: &NativePoseidon2Cols = + row_slice[(row_idx + local_row) * width + ..(row_idx + local_row + 1) * width] + .borrow(); + let multi_observe_cols: &MultiObserveCols = chunk_cols.specific + [..MultiObserveCols::::width()] + .borrow(); + + let chunk_start = multi_observe_cols.start_idx.as_canonical_u32(); + let chunk_end = multi_observe_cols.end_idx.as_canonical_u32(); + let curr_len = multi_observe_cols.curr_len.as_canonical_u32(); + + for j in chunk_start..chunk_end { + let input_idx = curr_len + (j - chunk_start); + let val = multi_observe_cols.data[j as usize]; + hint_space_provider.request( + hint_id, + F::from_canonical_u32(input_idx), + val, + ); + } + } + } + } + row_idx += num_rows.max(1); + continue; + } + + if cols.simple.is_one() { + row_idx += 1; + } else { + let num_non_inside_row = cols.inner.export.as_canonical_u32() as usize; + let non_inside_start = start + (num_non_inside_row - 1) * width; + let last_non_inside_cols: &NativePoseidon2Cols = + row_slice[non_inside_start..(non_inside_start + width)].borrow(); + let total_num_row = last_non_inside_cols.inner.export.as_canonical_u32() as usize; + row_idx += total_num_row; + } + } + } } impl Chip @@ -28,6 +127,9 @@ impl Chip return get_empty_air_proving_ctx::(); } + // Populate hint space provider from multi-observe records before GPU upload. + self.populate_hint_provider(records); + // For Poseidon2, the records are already the trace rows // Use the columns width directly let width = NativePoseidon2Cols::::width(); diff --git a/extensions/native/circuit/src/poseidon2/execution.rs b/extensions/native/circuit/src/poseidon2/execution.rs index a0c1fc72a2..d41d911812 100644 --- a/extensions/native/circuit/src/poseidon2/execution.rs +++ b/extensions/native/circuit/src/poseidon2/execution.rs @@ -35,9 +35,9 @@ struct Pos2PreCompute<'a, F: Field, const SBOX_REGISTERS: usize> { #[repr(C)] struct MultiObservePreCompute<'a, F: Field, const SBOX_REGISTERS: usize> { subchip: &'a Poseidon2SubChip, - pub init_pos_register: u32, + pub ctx_register: u32, pub input_ptr_register: u32, - pub len_register: u32, + pub hint_id_register: u32, pub state_ptr_register: u32, } @@ -137,9 +137,9 @@ impl<'a, F: PrimeField32, const SBOX_REGISTERS: usize> NativePoseidon2Executor = if is_hint { + exec_state.streams.hint_space[hint_id_u32 as usize].clone() + } else { + vec![] + }; for (chunk_start, chunk_end) in observation_chunks { for j in chunk_start..chunk_end { - let [n_f]: [F; 1] = exec_state.vm_read(NATIVE_AS, input_ptr_u32 + input_idx); + let n_f = if is_hint { + hint_data[input_idx as usize] + } else { + let [v]: [F; 1] = exec_state.vm_read(NATIVE_AS, input_ptr_u32 + input_idx); + v + }; + exec_state.vm_write(NATIVE_AS, sponge_ptr_u32 + (j as u32), &[n_f]); input_idx += 1; } @@ -634,9 +653,10 @@ unsafe fn execute_multi_observe_e12_impl< height += 1; } if let Some(final_idx) = final_idx { + // Write final_idx back to ctx[0] (overwriting init_pos in context array) exec_state.vm_write::( NATIVE_AS, - pre_compute.init_pos_register, + ctx_ptr.as_canonical_u32(), &[F::from_canonical_usize(final_idx)], ); } diff --git a/extensions/native/compiler/src/asm/compiler.rs b/extensions/native/compiler/src/asm/compiler.rs index 689bb1ebd0..14d84b4f1f 100644 --- a/extensions/native/compiler/src/asm/compiler.rs +++ b/extensions/native/compiler/src/asm/compiler.rs @@ -489,13 +489,13 @@ impl + TwoAdicField> AsmCo DslIr::HintBitsF(var, len) => { self.push(AsmInstruction::HintBits(var.fp(), len), debug_info); } - DslIr::Poseidon2MultiObserve(dst, init_pos, arr_ptr, len) => { + DslIr::Poseidon2MultiObserve(dst, ctx_ptr, arr_ptr, hint_id) => { self.push( AsmInstruction::Poseidon2MultiObserve( dst.fp(), - init_pos.fp(), + ctx_ptr.fp(), arr_ptr.fp(), - len.get_var().fp(), + hint_id.fp(), ), debug_info, ); diff --git a/extensions/native/compiler/src/asm/instruction.rs b/extensions/native/compiler/src/asm/instruction.rs index b715d97cf1..16b2be4b49 100644 --- a/extensions/native/compiler/src/asm/instruction.rs +++ b/extensions/native/compiler/src/asm/instruction.rs @@ -110,9 +110,10 @@ pub enum AsmInstruction { /// Halt. Halt, - /// Absorbs multiple base elements into a duplex transcript with Poseidon2 permutation - /// (sponge_state, init_pos, arr_ptr, len) - /// Returns the final index position of hash sponge + /// Absorbs multiple base elements into a duplex transcript with Poseidon2 permutation. + /// (sponge_state, ctx_ptr, arr_ptr, hint_id) + /// Context array at ctx_ptr: [init_pos, len, is_hint, reserved] + /// When is_hint=1, data is read from hint space using hint_id. Poseidon2MultiObserve(i32, i32, i32, i32), /// Perform a Poseidon2 permutation on state starting at address `lhs` @@ -350,11 +351,11 @@ impl> AsmInstruction { AsmInstruction::Trap => write!(f, "trap"), AsmInstruction::Halt => write!(f, "halt"), AsmInstruction::HintBits(src, len) => write!(f, "hint_bits ({})fp, {}", src, len), - AsmInstruction::Poseidon2MultiObserve(dst, init_pos, arr, len) => { + AsmInstruction::Poseidon2MultiObserve(dst, ctx, arr, hint_id) => { write!( f, "poseidon2_multi_observe ({})fp, ({})fp ({})fp ({})fp", - dst, init_pos, arr, len + dst, ctx, arr, hint_id ) } AsmInstruction::Poseidon2Permute(dst, lhs) => { diff --git a/extensions/native/compiler/src/conversion/mod.rs b/extensions/native/compiler/src/conversion/mod.rs index 0ff358ec70..61fd726d3f 100644 --- a/extensions/native/compiler/src/conversion/mod.rs +++ b/extensions/native/compiler/src/conversion/mod.rs @@ -441,15 +441,15 @@ fn convert_instruction>( AS::Native, AS::Native, )], - AsmInstruction::Poseidon2MultiObserve(dst, init, arr, len) => vec![ + AsmInstruction::Poseidon2MultiObserve(dst, ctx, arr, hint_id) => vec![ Instruction { opcode: options.opcode_with_offset(Poseidon2Opcode::MULTI_OBSERVE), a: i32_f(dst), - b: i32_f(init), + b: i32_f(ctx), c: i32_f(arr), d: AS::Native.to_field(), e: AS::Native.to_field(), - f: i32_f(len), + f: i32_f(hint_id), g: F::ZERO, } ], diff --git a/extensions/native/compiler/src/ir/instructions.rs b/extensions/native/compiler/src/ir/instructions.rs index a4932d2826..8658a8a06e 100644 --- a/extensions/native/compiler/src/ir/instructions.rs +++ b/extensions/native/compiler/src/ir/instructions.rs @@ -208,13 +208,14 @@ pub enum DslIr { /// Permutes an array of Bn254 elements using Poseidon2 (output = p2_permute(array)). Should /// only be used when target is a circuit. CircuitPoseidon2Permute([Var; 3]), - /// Absorbs an array of baby bear elements into a duplex transcript with Poseidon2 permutations - /// (output = p2_multi_observe(array, els)). + /// Absorbs an array of baby bear elements into a duplex transcript with Poseidon2 permutations. + /// Context values (init_pos, len, is_hint) are passed via a context array instead of separate registers. + /// When is_hint=1, data is read from hint space using hint_id instead of from input array pointer. Poseidon2MultiObserve( - Ptr, // sponge_state - Var, // initial input_ptr position - Ptr, // input array (base elements) - Usize, // len of els + Ptr, // sponge_state + Ptr, // ctx_ptr (context array: [init_pos, len, is_hint, reserved]) + Ptr, // input array (base elements; used when is_hint=0) + Var, // hint_id (hint space id; used when is_hint=1) ), // Miscellaneous instructions. diff --git a/extensions/native/compiler/src/ir/poseidon.rs b/extensions/native/compiler/src/ir/poseidon.rs index 6d32f89409..6310917c0d 100644 --- a/extensions/native/compiler/src/ir/poseidon.rs +++ b/extensions/native/compiler/src/ir/poseidon.rs @@ -19,6 +19,8 @@ impl Builder { sponge_state: &Array>, input_ptr: Ptr, arr: &Array>, + input_len: Usize, + hint_id: Option>, ) -> Usize { let buffer_size: Var = Var::uninit(self); self.assign(&buffer_size, C::N::from_canonical_usize(HASH_RATE)); @@ -31,19 +33,40 @@ impl Builder { Array::Fixed(_) => { panic!("Base elements input must be dynamic"); } - Array::Dyn(ptr, len) => { + Array::Dyn(ptr, _) => { let init_pos: Var = Var::uninit(self); self.assign(&init_pos, input_ptr.address - sponge_ptr.address); + let is_hint = hint_id.is_some(); + let hint_id_var: Var = if let Some(id) = hint_id { + id + } else { + let v: Var = Var::uninit(self); + self.assign(&v, C::N::ZERO); + v + }; + + // Allocate context array: [init_pos, len, is_hint, reserved] + let ctx = self.dyn_array::>(4usize); + self.set(&ctx, 0, init_pos); + self.set(&ctx, 1, input_len.get_var()); + self.set( + &ctx, + 2, + if is_hint { C::N::ONE } else { C::N::ZERO }, + ); + self.set(&ctx, 3, C::N::ZERO); + self.operations.push(DslIr::Poseidon2MultiObserve( *sponge_ptr, - init_pos, + ctx.ptr(), *ptr, - len.clone(), + hint_id_var, )); - // automatically updated by Poseidon2MultiObserve operation - Usize::Var(init_pos) + // Read back the updated init_pos from ctx[0] + let final_pos: Var = self.get(&ctx, 0); + Usize::Var(final_pos) } }, } diff --git a/extensions/native/recursion/src/challenger/duplex.rs b/extensions/native/recursion/src/challenger/duplex.rs index 440b14ec59..10b9bc62e9 100644 --- a/extensions/native/recursion/src/challenger/duplex.rs +++ b/extensions/native/recursion/src/challenger/duplex.rs @@ -81,7 +81,7 @@ impl DuplexChallengerVariable { // This is equivalent to calling `observe` multiple times, but more efficient. pub fn observe_slice_opt(&self, builder: &mut Builder, arr: &Array>) { builder.if_ne(arr.len(), Usize::from(0)).then(|builder| { - let next_pos = builder.poseidon2_multi_observe(&self.sponge_state, self.input_ptr, arr); + let next_pos = builder.poseidon2_multi_observe(&self.sponge_state, self.input_ptr, arr, arr.len(), None); builder.assign(&self.input_ptr, self.io_empty_ptr + next_pos.clone()); builder.if_ne(next_pos, Usize::from(0)).then_or_else(