From e4da12fc667dfd52e83e3b178be80f1ab8a17dfc Mon Sep 17 00:00:00 2001 From: Ray Gao Date: Sun, 15 Mar 2026 19:39:52 -0400 Subject: [PATCH 1/9] rebase hint multi observe --- .../native/circuit/cuda/src/poseidon2.cu | 37 ++++- .../native/circuit/src/extension/mod.rs | 11 +- .../native/circuit/src/poseidon2/air.rs | 126 +++++++++++---- .../native/circuit/src/poseidon2/chip.rs | 147 +++++++++++++----- .../native/circuit/src/poseidon2/columns.rs | 24 +-- .../native/circuit/src/poseidon2/execution.rs | 39 +++-- .../native/compiler/src/asm/compiler.rs | 6 +- .../native/compiler/src/asm/instruction.rs | 11 +- .../native/compiler/src/conversion/mod.rs | 6 +- .../native/compiler/src/ir/instructions.rs | 13 +- extensions/native/compiler/src/ir/poseidon.rs | 30 +++- .../native/recursion/src/challenger/duplex.rs | 2 +- 12 files changed, 326 insertions(+), 126 deletions(-) diff --git a/extensions/native/circuit/cuda/src/poseidon2.cu b/extensions/native/circuit/cuda/src/poseidon2.cu index fdbe0d3ce5..b39038a079 100644 --- a/extensions/native/circuit/cuda/src/poseidon2.cu +++ b/extensions/native/circuit/cuda/src/poseidon2.cu @@ -355,31 +355,52 @@ template struct Poseidon2Wrapper { if (specific[COL_INDEX(MultiObserveCols, is_first)] == Fp::one()) { uint32_t very_start_timestamp = row[COL_INDEX(Cols, very_first_timestamp)].asUInt32(); - for (uint32_t i = 0; i < 4; ++i) { + // 3 register reads at timestamps +0, +1, +2 + for (uint32_t i = 0; i < 3; ++i) { mem_fill_base( mem_helper, very_start_timestamp + i, specific.slice_from(COL_INDEX(MultiObserveCols, read_data[i].base)) ); } + // 1 context array read at timestamp +3 + mem_fill_base( + mem_helper, + very_start_timestamp + 3, + specific.slice_from(COL_INDEX(MultiObserveCols, read_ctx.base)) + ); + // 1 hint_id register read at timestamp +4 (reuse spare read_data[3] on head row) + mem_fill_base( + mem_helper, + very_start_timestamp + 4, + specific.slice_from(COL_INDEX(MultiObserveCols, read_data[3].base)) + ); } else { uint32_t start_timestamp = row[COL_INDEX(Cols, start_timestamp)].asUInt32(); uint32_t chunk_start = specific[COL_INDEX(MultiObserveCols, start_idx)].asUInt32(); uint32_t chunk_end = specific[COL_INDEX(MultiObserveCols, end_idx)].asUInt32(); + // is_hint = ctx[2] + uint32_t is_hint = + specific[COL_INDEX(MultiObserveCols, ctx[2])].asUInt32(); + uint32_t ts_per_element = 2 - is_hint; for (uint32_t j = chunk_start; j < chunk_end; ++j) { + if (!is_hint) { + // Non-hint mode: fill read_data aux + mem_fill_base( + mem_helper, + start_timestamp, + specific.slice_from(COL_INDEX(MultiObserveCols, read_data[j].base)) + ); + } + // Write timestamp: start_timestamp + (1 - is_hint) for non-hint, start_timestamp for hint mem_fill_base( mem_helper, - start_timestamp, - specific.slice_from(COL_INDEX(MultiObserveCols, read_data[j].base)) - ); - mem_fill_base( - mem_helper, - start_timestamp + 1, + start_timestamp + (1 - is_hint), specific.slice_from(COL_INDEX(MultiObserveCols, write_data[j].base)) ); - start_timestamp += 2; + start_timestamp += ts_per_element; } if (chunk_end >= CHUNK) { mem_fill_base( diff --git a/extensions/native/circuit/src/extension/mod.rs b/extensions/native/circuit/src/extension/mod.rs index 924d4927e8..0aab146a13 100644 --- a/extensions/native/circuit/src/extension/mod.rs +++ b/extensions/native/circuit/src/extension/mod.rs @@ -366,11 +366,6 @@ where inventory.add_executor_chip(fri_reduced_opening); inventory.next_air::, 1>>()?; - let poseidon2 = NativePoseidon2Chip::<_, 1>::new( - NativePoseidon2Filler::new(Poseidon2Config::default()), - mem_helper.clone(), - ); - inventory.add_executor_chip(poseidon2); let hint_bus = inventory.airs().system().hint_bridge.hint_bus(); let hint_space_provider = Arc::new(HintSpaceProviderChip::new( @@ -379,6 +374,12 @@ where timestamp_max_bits, )); + let poseidon2 = NativePoseidon2Chip::<_, 1>::new( + NativePoseidon2Filler::new(Poseidon2Config::default(), hint_space_provider.clone()), + mem_helper.clone(), + ); + inventory.add_executor_chip(poseidon2); + inventory.next_air::()?; inventory.add_periphery_chip(hint_space_provider.clone()); diff --git a/extensions/native/circuit/src/poseidon2/air.rs b/extensions/native/circuit/src/poseidon2/air.rs index 9e9cdf5ce8..b7f9f458bb 100644 --- a/extensions/native/circuit/src/poseidon2/air.rs +++ b/extensions/native/circuit/src/poseidon2/air.rs @@ -713,10 +713,16 @@ impl Air let &MultiObserveCols { pc, final_timestamp_increment, + state_ptr_register, + ctx_register, + input_ptr_register, + hint_id_register, state_ptr, + ctx_ptr, input_ptr, - init_pos, - len, + hint_id, + ctx, + read_ctx, is_first, is_last, curr_len, @@ -731,35 +737,38 @@ impl Air should_permute, write_sponge_state, write_final_idx, - input_register_1, - input_register_2, - input_register_3, - output_register, } = multi_observe_specific; + // Alias context values + let init_pos = ctx[0]; + let len = ctx[1]; + let is_hint = ctx[2]; + builder.when(multi_observe_row).assert_bool(is_first); builder.when(multi_observe_row).assert_bool(is_last); builder.when(multi_observe_row).assert_bool(should_permute); + builder.when(multi_observe_row).assert_bool(is_hint); self.execution_bridge .execute_and_increment_pc( AB::F::from_canonical_usize(MULTI_OBSERVE.global_opcode().as_usize()), [ - output_register.into(), - input_register_1.into(), - input_register_2.into(), + state_ptr_register.into(), + ctx_register.into(), + input_ptr_register.into(), self.address_space.into(), self.address_space.into(), - input_register_3.into(), + hint_id_register.into(), ], ExecutionState::new(pc, very_first_timestamp), final_timestamp_increment, ) .eval(builder, multi_observe_row * is_first); + // Head row: 3 register reads + 1 context array read + 1 hint_id register read self.memory_bridge .read( - MemoryAddress::new(self.address_space, output_register), + MemoryAddress::new(self.address_space, state_ptr_register), [state_ptr], very_first_timestamp, &read_data[0], @@ -768,8 +777,8 @@ impl Air self.memory_bridge .read( - MemoryAddress::new(self.address_space, input_register_1), - [init_pos], + MemoryAddress::new(self.address_space, ctx_register), + [ctx_ptr], very_first_timestamp + AB::F::ONE, &read_data[1], ) @@ -777,41 +786,73 @@ impl Air self.memory_bridge .read( - MemoryAddress::new(self.address_space, input_register_2), + MemoryAddress::new(self.address_space, input_ptr_register), [input_ptr], very_first_timestamp + AB::F::TWO, &read_data[2], ) .eval(builder, multi_observe_row * is_first); + // Read context array: [init_pos, len, is_hint, reserved] from ctx_ptr self.memory_bridge .read( - MemoryAddress::new(self.address_space, input_register_3), - [len], + MemoryAddress::new(self.address_space, ctx_ptr), + ctx, very_first_timestamp + AB::F::from_canonical_usize(3), + &read_ctx, + ) + .eval(builder, multi_observe_row * is_first); + + // Read hint_id from register (reuse spare read_data[3] on head row) + self.memory_bridge + .read( + MemoryAddress::new(self.address_space, hint_id_register), + [hint_id], + very_first_timestamp + AB::F::from_canonical_usize(4), &read_data[3], ) .eval(builder, multi_observe_row * is_first); + // ts_per_element = 2 - is_hint (non-hint: read+write=2, hint: write-only=1) + let is_hint_expr: AB::Expr = is_hint.into(); + let ts_per_element: AB::Expr = AB::Expr::TWO - is_hint_expr.clone(); for i in 0..CHUNK { - let i_var = AB::F::from_canonical_usize(i); + let i_var: AB::Expr = AB::F::from_canonical_usize(i).into(); + let start_idx_expr: AB::Expr = start_idx.into(); + let element_start_ts: AB::Expr = + start_timestamp.into() + (i_var.clone() - start_idx_expr.clone()) * ts_per_element.clone(); + + // Non-hint mode: read from memory self.memory_bridge .read( MemoryAddress::new( self.address_space, - input_ptr + curr_len + i_var - start_idx, + input_ptr + curr_len + i_var.clone() - start_idx_expr.clone(), ), [data[i]], - start_timestamp + i_var * AB::F::TWO - start_idx * AB::F::TWO, + element_start_ts.clone(), &read_data[i], ) - .eval(builder, multi_observe_row * aux_read_enabled[i]); + .eval( + builder, + multi_observe_row * aux_read_enabled[i] * (AB::Expr::ONE - is_hint_expr.clone()), + ); + // Hint mode: lookup from hint space + self.hint_bridge.lookup( + builder, + hint_id, + curr_len + i_var.clone() - start_idx_expr.clone(), + data[i], + multi_observe_row * aux_read_enabled[i] * is_hint_expr.clone(), + ); + + // Write to sponge state (always, for both modes) self.memory_bridge .write( MemoryAddress::new(self.address_space, state_ptr + i_var), [data[i]], - start_timestamp + i_var * AB::F::TWO - start_idx * AB::F::TWO + AB::F::ONE, + element_start_ts + (AB::Expr::ONE - is_hint_expr.clone()), &write_data[i], ) .eval(builder, multi_observe_row * aux_read_enabled[i]); @@ -885,7 +926,7 @@ impl Air .write( MemoryAddress::new(self.address_space, state_ptr), full_sponge_output, - start_timestamp + (end_idx - start_idx) * AB::F::TWO, + start_timestamp + (end_idx - start_idx) * (AB::Expr::TWO - is_hint_expr.clone()), &write_sponge_state, ) .eval(builder, multi_observe_row * should_permute); @@ -909,11 +950,12 @@ impl Air // final_idx = aux_read_enabled[CHUNK-1] * 0 + (1 - aux_read_enabled[CHUNK-1]) * end_idx let final_idx = aux_read_enabled[CHUNK - 1] * AB::Expr::ZERO + (AB::Expr::ONE - aux_read_enabled[CHUNK - 1]) * end_idx; + // Write final_idx back to ctx[0] (ctx_ptr address) self.memory_bridge .write( - MemoryAddress::new(self.address_space, input_register_1), + MemoryAddress::new(self.address_space, ctx_ptr), [final_idx], - start_timestamp + (end_idx - start_idx) * AB::F::TWO + should_permute, + start_timestamp + (end_idx - start_idx) * (AB::Expr::TWO - is_hint_expr) + should_permute, &write_final_idx, ) .eval(builder, multi_observe_row * is_last); @@ -962,41 +1004,59 @@ impl Air builder .when(next.multi_observe_row) .when(not(next_multi_observe_specific.is_first)) - .assert_eq(init_pos, next_multi_observe_specific.init_pos); + .assert_eq(init_pos, next_multi_observe_specific.ctx[0]); builder .when(next.multi_observe_row) .when(not(next_multi_observe_specific.is_first)) - .assert_eq(len, next_multi_observe_specific.len); + .assert_eq(len, next_multi_observe_specific.ctx[1]); + + builder + .when(next.multi_observe_row) + .when(not(next_multi_observe_specific.is_first)) + .assert_eq( + state_ptr_register, + next_multi_observe_specific.state_ptr_register, + ); builder .when(next.multi_observe_row) .when(not(next_multi_observe_specific.is_first)) .assert_eq( - input_register_1, - next_multi_observe_specific.input_register_1, + ctx_register, + next_multi_observe_specific.ctx_register, ); builder .when(next.multi_observe_row) .when(not(next_multi_observe_specific.is_first)) .assert_eq( - input_register_2, - next_multi_observe_specific.input_register_2, + input_ptr_register, + next_multi_observe_specific.input_ptr_register, ); + builder + .when(next.multi_observe_row) + .when(not(next_multi_observe_specific.is_first)) + .assert_eq(ctx_ptr, next_multi_observe_specific.ctx_ptr); + + builder + .when(next.multi_observe_row) + .when(not(next_multi_observe_specific.is_first)) + .assert_eq(hint_id, next_multi_observe_specific.hint_id); + builder .when(next.multi_observe_row) .when(not(next_multi_observe_specific.is_first)) .assert_eq( - input_register_3, - next_multi_observe_specific.input_register_3, + hint_id_register, + next_multi_observe_specific.hint_id_register, ); builder .when(next.multi_observe_row) .when(not(next_multi_observe_specific.is_first)) - .assert_eq(output_register, next_multi_observe_specific.output_register); + .assert_eq(is_hint, next_multi_observe_specific.ctx[2]); // Timestamp constraints builder diff --git a/extensions/native/circuit/src/poseidon2/chip.rs b/extensions/native/circuit/src/poseidon2/chip.rs index 770efc7307..7dee62f60d 100644 --- a/extensions/native/circuit/src/poseidon2/chip.rs +++ b/extensions/native/circuit/src/poseidon2/chip.rs @@ -1,4 +1,5 @@ use std::borrow::{Borrow, BorrowMut}; +use std::sync::Arc; use openvm_circuit::{ arch::*, @@ -22,6 +23,7 @@ use openvm_stark_backend::{ }; use crate::{ + hint_space_provider::HintSpaceProviderChip, mem_fill_helper, poseidon2::{ columns::{ @@ -45,6 +47,7 @@ pub struct NativePoseidon2Filler { // pre-computed Poseidon2 sub cols for dummy rows. empty_poseidon2_sub_cols: Vec, pub(super) subchip: Poseidon2SubChip, + pub hint_space_provider: Arc>, } impl NativePoseidon2Executor { @@ -71,12 +74,16 @@ pub(crate) fn compress( } impl NativePoseidon2Filler { - pub fn new(poseidon2_config: Poseidon2Config) -> Self { + pub fn new( + poseidon2_config: Poseidon2Config, + hint_space_provider: Arc>, + ) -> Self { let subchip = Poseidon2SubChip::new(poseidon2_config.constants); let empty_poseidon2_sub_cols = subchip.generate_trace(vec![[F::ZERO; CHUNK * 2]]).values; Self { empty_poseidon2_sub_cols, subchip, + hint_space_provider, } } } @@ -649,11 +656,11 @@ where } else if instruction.opcode == MULTI_OBSERVE.global_opcode() { let &Instruction { a: state_ptr_register, - b: init_pos_register, + b: ctx_register, c: input_ptr_register, d: register_address_space, e: data_address_space, - f: len_register, + f: hint_id_register, .. } = instruction; @@ -663,31 +670,50 @@ where ); assert_eq!(data_address_space, F::from_canonical_u32(AS::Native as u32)); - let [init_pos]: [F; 1] = - memory_read_native(state.memory.data(), init_pos_register.as_canonical_u32()); - let [input_len]: [F; 1] = - memory_read_native(state.memory.data(), len_register.as_canonical_u32()); + // Read ctx_ptr from register, then read context array from memory + let [ctx_ptr]: [F; 1] = + memory_read_native(state.memory.data(), ctx_register.as_canonical_u32()); + let ctx: [F; 4] = + memory_read_native(state.memory.data(), ctx_ptr.as_canonical_u32()); + let init_pos = ctx[0]; + let input_len = ctx[1]; + let is_hint = ctx[2].as_canonical_u32() != 0; + + // Read hint_id from register + let [hint_id]: [F; 1] = + memory_read_native(state.memory.data(), hint_id_register.as_canonical_u32()); + + // Get hint_space data if in hint mode + let hint_data: Vec = if is_hint { + state.streams.hint_space[hint_id.as_canonical_u32() as usize].clone() + } else { + vec![] + }; let mut len = input_len.as_canonical_u32() as usize; let mut pos = init_pos.as_canonical_u32() as usize; let mut chunks: Vec<(usize, usize)> = vec![]; - const NUM_HEAD_ACCESSES: usize = 4; + // 3 register reads + 1 context array read + 1 hint_id register read = 5 head accesses + const NUM_HEAD_ACCESSES: usize = 5; let mut final_timestamp_inc = NUM_HEAD_ACCESSES; + // In hint mode: 1 timestamp per element (write only) + // In non-hint mode: 2 timestamps per element (read + write) + let ts_per_element = if is_hint { 1 } else { 2 }; while len > 0 { if len >= (CHUNK - pos) { chunks.push((pos, CHUNK)); len -= CHUNK - pos; - final_timestamp_inc += 2 * (CHUNK - pos) + 1; + final_timestamp_inc += ts_per_element * (CHUNK - pos) + 1; pos = 0; } else { chunks.push((pos, pos + len)); - final_timestamp_inc += 2 * len; + final_timestamp_inc += ts_per_element * len; len = 0; pos += len; } } - final_timestamp_inc += 1; // write back to init_pos_register + final_timestamp_inc += 1; // write back to ctx[0] let allocated_rows = arena .alloc(MultiRowLayout::new(NativePoseidon2Metadata { @@ -698,14 +724,15 @@ where let head_multi_observe_cols: &mut MultiObserveCols = head_cols.specific[..MultiObserveCols::::width()].borrow_mut(); + // 3 register reads: state_ptr, ctx_ptr, input_ptr let [state_ptr]: [F; 1] = tracing_read_native_helper( state.memory, state_ptr_register.as_canonical_u32(), head_multi_observe_cols.read_data[0].as_mut(), ); - let [init_pos]: [F; 1] = tracing_read_native_helper( + let [ctx_ptr]: [F; 1] = tracing_read_native_helper( state.memory, - init_pos_register.as_canonical_u32(), + ctx_register.as_canonical_u32(), head_multi_observe_cols.read_data[1].as_mut(), ); let [input_ptr]: [F; 1] = tracing_read_native_helper( @@ -713,9 +740,16 @@ where input_ptr_register.as_canonical_u32(), head_multi_observe_cols.read_data[2].as_mut(), ); - let [input_len]: [F; 1] = tracing_read_native_helper( + // 1 context array read: [init_pos, len, is_hint, reserved] + let ctx: [F; 4] = tracing_read_native_helper( state.memory, - len_register.as_canonical_u32(), + ctx_ptr.as_canonical_u32(), + head_multi_observe_cols.read_ctx.as_mut(), + ); + // 1 hint_id register read (reuse spare read_data[3] on head row) + let [hint_id]: [F; 1] = tracing_read_native_helper( + state.memory, + hint_id_register.as_canonical_u32(), head_multi_observe_cols.read_data[3].as_mut(), ); @@ -727,14 +761,15 @@ where for (i, cols) in allocated_rows.iter_mut().enumerate() { let multi_observe_cols: &mut MultiObserveCols = cols.specific[..MultiObserveCols::::width()].borrow_mut(); - multi_observe_cols.input_register_1 = init_pos_register; - multi_observe_cols.input_register_2 = input_ptr_register; - multi_observe_cols.input_register_3 = len_register; - multi_observe_cols.output_register = state_ptr_register; - multi_observe_cols.init_pos = init_pos; - multi_observe_cols.input_ptr = input_ptr; + multi_observe_cols.state_ptr_register = state_ptr_register; + multi_observe_cols.ctx_register = ctx_register; + multi_observe_cols.input_ptr_register = input_ptr_register; + multi_observe_cols.hint_id_register = hint_id_register; multi_observe_cols.state_ptr = state_ptr; - multi_observe_cols.len = input_len; + multi_observe_cols.ctx_ptr = ctx_ptr; + multi_observe_cols.input_ptr = input_ptr; + multi_observe_cols.hint_id = hint_id; + multi_observe_cols.ctx = ctx; cols.multi_observe_row = F::ONE; cols.very_first_timestamp = init_timestamp; @@ -779,21 +814,28 @@ where multi_observe_cols.aux_before_end[j] = F::ONE; } for j in chunk_start..chunk_end { - let n_f: [F; 1] = tracing_read_native_helper( - state.memory, - input_ptr_u32 + input_idx as u32, - multi_observe_cols.read_data[j].as_mut(), - ); + let n_f: F = if is_hint { + // In hint mode: read from hint_space + hint_data[input_idx] + } else { + // In non-hint mode: read from memory via tracing read + let [v]: [F; 1] = tracing_read_native_helper( + state.memory, + input_ptr_u32 + input_idx as u32, + multi_observe_cols.read_data[j].as_mut(), + ); + v + }; multi_observe_cols.aux_read_enabled[j] = F::ONE; tracing_write_native_inplace( state.memory, state_ptr_u32 + j as u32, - n_f, + [n_f], &mut multi_observe_cols.write_data[j], ); - multi_observe_cols.data[j] = n_f[0]; + multi_observe_cols.data[j] = n_f; input_idx += 1; - cur_timestamp += 2; + cur_timestamp += ts_per_element as u32; } let permutation_input: [F; 16] = @@ -817,7 +859,7 @@ where let final_idx = F::from_canonical_usize(chunk_end % CHUNK); tracing_write_native_inplace( state.memory, - init_pos_register.as_canonical_u32(), + ctx_ptr.as_canonical_u32(), [final_idx], &mut multi_observe_cols.write_final_idx, ); @@ -1161,7 +1203,7 @@ impl NativePoseidon2Filler::width()].borrow_mut(); let start_timestamp_u32 = head_cols.very_first_timestamp.as_canonical_u32(); - // state_ptr, init_pos, input_ptr, len + // 3 register reads: state_ptr, ctx_ptr, input_ptr mem_fill_helper( mem_helper, start_timestamp_u32, @@ -1177,12 +1219,23 @@ impl NativePoseidon2Filler = chunk_slice @@ -1194,6 +1247,8 @@ impl NativePoseidon2Filler = chunk_slice[row_idx * width..(row_idx + 1) * width].borrow_mut(); @@ -1205,18 +1260,32 @@ impl NativePoseidon2Filler= CHUNK as u32 { diff --git a/extensions/native/circuit/src/poseidon2/columns.rs b/extensions/native/circuit/src/poseidon2/columns.rs index abb8db54a2..c8490fc285 100644 --- a/extensions/native/circuit/src/poseidon2/columns.rs +++ b/extensions/native/circuit/src/poseidon2/columns.rs @@ -211,16 +211,22 @@ pub struct MultiObserveCols { pub pc: T, pub final_timestamp_increment: T, - // Initial reads from registers - // They are same across same instance of multi_observe + // Register addresses + pub state_ptr_register: T, + pub ctx_register: T, + pub input_ptr_register: T, + pub hint_id_register: T, + + // Values read from registers pub state_ptr: T, + pub ctx_ptr: T, pub input_ptr: T, - pub init_pos: T, - pub len: T, - pub input_register_1: T, - pub input_register_2: T, - pub input_register_3: T, - pub output_register: T, + pub hint_id: T, + + // Context array values read from ctx_ptr + // ctx[0] = init_pos, ctx[1] = len, ctx[2] = is_hint, ctx[3] = reserved + pub ctx: [T; 4], + pub read_ctx: MemoryReadAuxCols, pub is_first: T, pub is_last: T, @@ -240,6 +246,6 @@ pub struct MultiObserveCols { pub should_permute: T, pub write_sponge_state: MemoryWriteAuxCols, - // Final write back and registers + // Final write back to ctx[0] pub write_final_idx: MemoryWriteAuxCols, } diff --git a/extensions/native/circuit/src/poseidon2/execution.rs b/extensions/native/circuit/src/poseidon2/execution.rs index a0c1fc72a2..a558205729 100644 --- a/extensions/native/circuit/src/poseidon2/execution.rs +++ b/extensions/native/circuit/src/poseidon2/execution.rs @@ -35,9 +35,9 @@ struct Pos2PreCompute<'a, F: Field, const SBOX_REGISTERS: usize> { #[repr(C)] struct MultiObservePreCompute<'a, F: Field, const SBOX_REGISTERS: usize> { subchip: &'a Poseidon2SubChip, - pub init_pos_register: u32, + pub ctx_register: u32, pub input_ptr_register: u32, - pub len_register: u32, + pub hint_id_register: u32, pub state_ptr_register: u32, } @@ -137,9 +137,9 @@ impl<'a, F: PrimeField32, const SBOX_REGISTERS: usize> NativePoseidon2Executor = if is_hint { + exec_state.streams.hint_space[hint_id_u32 as usize].clone() + } else { + vec![] + }; for (chunk_start, chunk_end) in observation_chunks { for j in chunk_start..chunk_end { - let [n_f]: [F; 1] = exec_state.vm_read(NATIVE_AS, input_ptr_u32 + input_idx); + let n_f = if is_hint { + hint_data[input_idx as usize] + } else { + let [v]: [F; 1] = exec_state.vm_read(NATIVE_AS, input_ptr_u32 + input_idx); + v + }; exec_state.vm_write(NATIVE_AS, sponge_ptr_u32 + (j as u32), &[n_f]); input_idx += 1; } @@ -634,9 +652,10 @@ unsafe fn execute_multi_observe_e12_impl< height += 1; } if let Some(final_idx) = final_idx { + // Write final_idx back to ctx[0] (overwriting init_pos in context array) exec_state.vm_write::( NATIVE_AS, - pre_compute.init_pos_register, + ctx_ptr.as_canonical_u32(), &[F::from_canonical_usize(final_idx)], ); } diff --git a/extensions/native/compiler/src/asm/compiler.rs b/extensions/native/compiler/src/asm/compiler.rs index 689bb1ebd0..14d84b4f1f 100644 --- a/extensions/native/compiler/src/asm/compiler.rs +++ b/extensions/native/compiler/src/asm/compiler.rs @@ -489,13 +489,13 @@ impl + TwoAdicField> AsmCo DslIr::HintBitsF(var, len) => { self.push(AsmInstruction::HintBits(var.fp(), len), debug_info); } - DslIr::Poseidon2MultiObserve(dst, init_pos, arr_ptr, len) => { + DslIr::Poseidon2MultiObserve(dst, ctx_ptr, arr_ptr, hint_id) => { self.push( AsmInstruction::Poseidon2MultiObserve( dst.fp(), - init_pos.fp(), + ctx_ptr.fp(), arr_ptr.fp(), - len.get_var().fp(), + hint_id.fp(), ), debug_info, ); diff --git a/extensions/native/compiler/src/asm/instruction.rs b/extensions/native/compiler/src/asm/instruction.rs index b715d97cf1..16b2be4b49 100644 --- a/extensions/native/compiler/src/asm/instruction.rs +++ b/extensions/native/compiler/src/asm/instruction.rs @@ -110,9 +110,10 @@ pub enum AsmInstruction { /// Halt. Halt, - /// Absorbs multiple base elements into a duplex transcript with Poseidon2 permutation - /// (sponge_state, init_pos, arr_ptr, len) - /// Returns the final index position of hash sponge + /// Absorbs multiple base elements into a duplex transcript with Poseidon2 permutation. + /// (sponge_state, ctx_ptr, arr_ptr, hint_id) + /// Context array at ctx_ptr: [init_pos, len, is_hint, reserved] + /// When is_hint=1, data is read from hint space using hint_id. Poseidon2MultiObserve(i32, i32, i32, i32), /// Perform a Poseidon2 permutation on state starting at address `lhs` @@ -350,11 +351,11 @@ impl> AsmInstruction { AsmInstruction::Trap => write!(f, "trap"), AsmInstruction::Halt => write!(f, "halt"), AsmInstruction::HintBits(src, len) => write!(f, "hint_bits ({})fp, {}", src, len), - AsmInstruction::Poseidon2MultiObserve(dst, init_pos, arr, len) => { + AsmInstruction::Poseidon2MultiObserve(dst, ctx, arr, hint_id) => { write!( f, "poseidon2_multi_observe ({})fp, ({})fp ({})fp ({})fp", - dst, init_pos, arr, len + dst, ctx, arr, hint_id ) } AsmInstruction::Poseidon2Permute(dst, lhs) => { diff --git a/extensions/native/compiler/src/conversion/mod.rs b/extensions/native/compiler/src/conversion/mod.rs index 0ff358ec70..61fd726d3f 100644 --- a/extensions/native/compiler/src/conversion/mod.rs +++ b/extensions/native/compiler/src/conversion/mod.rs @@ -441,15 +441,15 @@ fn convert_instruction>( AS::Native, AS::Native, )], - AsmInstruction::Poseidon2MultiObserve(dst, init, arr, len) => vec![ + AsmInstruction::Poseidon2MultiObserve(dst, ctx, arr, hint_id) => vec![ Instruction { opcode: options.opcode_with_offset(Poseidon2Opcode::MULTI_OBSERVE), a: i32_f(dst), - b: i32_f(init), + b: i32_f(ctx), c: i32_f(arr), d: AS::Native.to_field(), e: AS::Native.to_field(), - f: i32_f(len), + f: i32_f(hint_id), g: F::ZERO, } ], diff --git a/extensions/native/compiler/src/ir/instructions.rs b/extensions/native/compiler/src/ir/instructions.rs index a4932d2826..8658a8a06e 100644 --- a/extensions/native/compiler/src/ir/instructions.rs +++ b/extensions/native/compiler/src/ir/instructions.rs @@ -208,13 +208,14 @@ pub enum DslIr { /// Permutes an array of Bn254 elements using Poseidon2 (output = p2_permute(array)). Should /// only be used when target is a circuit. CircuitPoseidon2Permute([Var; 3]), - /// Absorbs an array of baby bear elements into a duplex transcript with Poseidon2 permutations - /// (output = p2_multi_observe(array, els)). + /// Absorbs an array of baby bear elements into a duplex transcript with Poseidon2 permutations. + /// Context values (init_pos, len, is_hint) are passed via a context array instead of separate registers. + /// When is_hint=1, data is read from hint space using hint_id instead of from input array pointer. Poseidon2MultiObserve( - Ptr, // sponge_state - Var, // initial input_ptr position - Ptr, // input array (base elements) - Usize, // len of els + Ptr, // sponge_state + Ptr, // ctx_ptr (context array: [init_pos, len, is_hint, reserved]) + Ptr, // input array (base elements; used when is_hint=0) + Var, // hint_id (hint space id; used when is_hint=1) ), // Miscellaneous instructions. diff --git a/extensions/native/compiler/src/ir/poseidon.rs b/extensions/native/compiler/src/ir/poseidon.rs index 6d32f89409..deb3b47f14 100644 --- a/extensions/native/compiler/src/ir/poseidon.rs +++ b/extensions/native/compiler/src/ir/poseidon.rs @@ -19,6 +19,7 @@ impl Builder { sponge_state: &Array>, input_ptr: Ptr, arr: &Array>, + hint_id: Option>, ) -> Usize { let buffer_size: Var = Var::uninit(self); self.assign(&buffer_size, C::N::from_canonical_usize(HASH_RATE)); @@ -35,15 +36,36 @@ impl Builder { let init_pos: Var = Var::uninit(self); self.assign(&init_pos, input_ptr.address - sponge_ptr.address); + let is_hint = hint_id.is_some(); + let hint_id_var: Var = if let Some(id) = hint_id { + id + } else { + let v: Var = Var::uninit(self); + self.assign(&v, C::N::ZERO); + v + }; + + // Allocate context array: [init_pos, len, is_hint, reserved] + let ctx = self.dyn_array::>(4usize); + self.set(&ctx, 0, init_pos); + self.set(&ctx, 1, len.get_var()); + self.set( + &ctx, + 2, + if is_hint { C::N::ONE } else { C::N::ZERO }, + ); + self.set(&ctx, 3, C::N::ZERO); + self.operations.push(DslIr::Poseidon2MultiObserve( *sponge_ptr, - init_pos, + ctx.ptr(), *ptr, - len.clone(), + hint_id_var, )); - // automatically updated by Poseidon2MultiObserve operation - Usize::Var(init_pos) + // Read back the updated init_pos from ctx[0] + let final_pos: Var = self.get(&ctx, 0); + Usize::Var(final_pos) } }, } diff --git a/extensions/native/recursion/src/challenger/duplex.rs b/extensions/native/recursion/src/challenger/duplex.rs index 440b14ec59..b45639dc31 100644 --- a/extensions/native/recursion/src/challenger/duplex.rs +++ b/extensions/native/recursion/src/challenger/duplex.rs @@ -81,7 +81,7 @@ impl DuplexChallengerVariable { // This is equivalent to calling `observe` multiple times, but more efficient. pub fn observe_slice_opt(&self, builder: &mut Builder, arr: &Array>) { builder.if_ne(arr.len(), Usize::from(0)).then(|builder| { - let next_pos = builder.poseidon2_multi_observe(&self.sponge_state, self.input_ptr, arr); + let next_pos = builder.poseidon2_multi_observe(&self.sponge_state, self.input_ptr, arr, None); builder.assign(&self.input_ptr, self.io_empty_ptr + next_pos.clone()); builder.if_ne(next_pos, Usize::from(0)).then_or_else( From 38a9ba482ddda7166e90cf671b577a60db91b02a Mon Sep 17 00:00:00 2001 From: Ray Gao Date: Sun, 15 Mar 2026 22:15:42 -0400 Subject: [PATCH 2/9] adjust degree --- .../native/circuit/src/poseidon2/air.rs | 67 +++++++++++-------- .../native/circuit/src/poseidon2/chip.rs | 4 ++ .../native/circuit/src/poseidon2/columns.rs | 6 ++ 3 files changed, 49 insertions(+), 28 deletions(-) diff --git a/extensions/native/circuit/src/poseidon2/air.rs b/extensions/native/circuit/src/poseidon2/air.rs index b7f9f458bb..11b9e0fede 100644 --- a/extensions/native/circuit/src/poseidon2/air.rs +++ b/extensions/native/circuit/src/poseidon2/air.rs @@ -94,6 +94,7 @@ impl Air inside_row, simple, multi_observe_row, + not_hint_multi_observe, end_inside_row, end_top_level, start_top_level, @@ -723,6 +724,7 @@ impl Air hint_id, ctx, read_ctx, + chunk_ts_count, is_first, is_last, curr_len, @@ -748,6 +750,16 @@ impl Air builder.when(multi_observe_row).assert_bool(is_last); builder.when(multi_observe_row).assert_bool(should_permute); builder.when(multi_observe_row).assert_bool(is_hint); + builder.assert_eq( + not_hint_multi_observe, + multi_observe_row * (AB::Expr::ONE - is_hint), + ); + let hint_multi_observe: AB::Expr = multi_observe_row - not_hint_multi_observe; + // chunk_ts_count = (end_idx - start_idx) * (2 - is_hint) + builder.when(multi_observe_row).assert_eq( + chunk_ts_count, + (end_idx - start_idx) * AB::F::TWO - (end_idx - start_idx) * is_hint, + ); self.execution_bridge .execute_and_increment_pc( @@ -813,49 +825,48 @@ impl Air ) .eval(builder, multi_observe_row * is_first); - // ts_per_element = 2 - is_hint (non-hint: read+write=2, hint: write-only=1) - let is_hint_expr: AB::Expr = is_hint.into(); - let ts_per_element: AB::Expr = AB::Expr::TWO - is_hint_expr.clone(); + // Per-element constraints for chunk rows. for i in 0..CHUNK { - let i_var: AB::Expr = AB::F::from_canonical_usize(i).into(); - let start_idx_expr: AB::Expr = start_idx.into(); - let element_start_ts: AB::Expr = - start_timestamp.into() + (i_var.clone() - start_idx_expr.clone()) * ts_per_element.clone(); + let i_var = AB::F::from_canonical_usize(i); + + // Hint mode: lookup from hint space. + self.hint_bridge.lookup( + builder, + hint_id, + curr_len + i_var - start_idx, + data[i], + hint_multi_observe.clone() * aux_read_enabled[i], + ); - // Non-hint mode: read from memory + // Non-hint mode: read from memory. self.memory_bridge .read( MemoryAddress::new( self.address_space, - input_ptr + curr_len + i_var.clone() - start_idx_expr.clone(), + input_ptr + curr_len + i_var - start_idx, ), [data[i]], - element_start_ts.clone(), + start_timestamp + i_var * AB::F::TWO - start_idx * AB::F::TWO, &read_data[i], ) - .eval( - builder, - multi_observe_row * aux_read_enabled[i] * (AB::Expr::ONE - is_hint_expr.clone()), - ); - - // Hint mode: lookup from hint space - self.hint_bridge.lookup( - builder, - hint_id, - curr_len + i_var.clone() - start_idx_expr.clone(), - data[i], - multi_observe_row * aux_read_enabled[i] * is_hint_expr.clone(), - ); + .eval(builder, not_hint_multi_observe * aux_read_enabled[i]); + self.memory_bridge + .write( + MemoryAddress::new(self.address_space, state_ptr + i_var), + [data[i]], + start_timestamp + i_var * AB::F::TWO - start_idx * AB::F::TWO + AB::F::ONE, + &write_data[i], + ) + .eval(builder, not_hint_multi_observe * aux_read_enabled[i]); - // Write to sponge state (always, for both modes) self.memory_bridge .write( MemoryAddress::new(self.address_space, state_ptr + i_var), [data[i]], - element_start_ts + (AB::Expr::ONE - is_hint_expr.clone()), + start_timestamp + i_var - start_idx, &write_data[i], ) - .eval(builder, multi_observe_row * aux_read_enabled[i]); + .eval(builder, hint_multi_observe.clone() * aux_read_enabled[i]); } for i in 0..(CHUNK - 1) { @@ -926,7 +937,7 @@ impl Air .write( MemoryAddress::new(self.address_space, state_ptr), full_sponge_output, - start_timestamp + (end_idx - start_idx) * (AB::Expr::TWO - is_hint_expr.clone()), + start_timestamp + chunk_ts_count, &write_sponge_state, ) .eval(builder, multi_observe_row * should_permute); @@ -955,7 +966,7 @@ impl Air .write( MemoryAddress::new(self.address_space, ctx_ptr), [final_idx], - start_timestamp + (end_idx - start_idx) * (AB::Expr::TWO - is_hint_expr) + should_permute, + start_timestamp + chunk_ts_count + should_permute, &write_final_idx, ) .eval(builder, multi_observe_row * is_last); diff --git a/extensions/native/circuit/src/poseidon2/chip.rs b/extensions/native/circuit/src/poseidon2/chip.rs index 7dee62f60d..061c6767db 100644 --- a/extensions/native/circuit/src/poseidon2/chip.rs +++ b/extensions/native/circuit/src/poseidon2/chip.rs @@ -771,7 +771,10 @@ where multi_observe_cols.hint_id = hint_id; multi_observe_cols.ctx = ctx; + // chunk_ts_count will be filled per-chunk row below + cols.multi_observe_row = F::ONE; + cols.not_hint_multi_observe = if is_hint { F::ZERO } else { F::ONE }; cols.very_first_timestamp = init_timestamp; if i == 0 { @@ -802,6 +805,7 @@ where multi_observe_cols.start_idx = F::from_canonical_usize(chunk_start); multi_observe_cols.end_idx = F::from_canonical_usize(chunk_end); + multi_observe_cols.chunk_ts_count = F::from_canonical_usize((chunk_end - chunk_start) * ts_per_element); multi_observe_cols.is_first = F::ZERO; multi_observe_cols.is_last = if i == num_chunks - 1 { F::ONE } else { F::ZERO }; diff --git a/extensions/native/circuit/src/poseidon2/columns.rs b/extensions/native/circuit/src/poseidon2/columns.rs index c8490fc285..67557a5f73 100644 --- a/extensions/native/circuit/src/poseidon2/columns.rs +++ b/extensions/native/circuit/src/poseidon2/columns.rs @@ -31,6 +31,10 @@ pub struct NativePoseidon2Cols { /// Indicates that this row is a multi_observe row. pub multi_observe_row: T, + /// Materialized column: multi_observe_row * (1 - is_hint). + /// Lives in main cols (not overlaid specific) so it is 0 on non-multi_observe rows. + pub not_hint_multi_observe: T, + /// Indicates the last row in an inside-row block. pub end_inside_row: T, /// Indicates the last row in a top-level block. @@ -228,6 +232,8 @@ pub struct MultiObserveCols { pub ctx: [T; 4], pub read_ctx: MemoryReadAuxCols, + pub chunk_ts_count: T, + pub is_first: T, pub is_last: T, pub curr_len: T, From 1f5c698d18ffadcf8d6652991ccdb801d2789511 Mon Sep 17 00:00:00 2001 From: Ray Gao Date: Mon, 16 Mar 2026 06:03:11 -0400 Subject: [PATCH 3/9] fix --- .../native/circuit/src/extension/cuda.rs | 10 +- .../native/circuit/src/extension/mod.rs | 30 ++--- .../native/circuit/src/poseidon2/chip.rs | 12 +- .../native/circuit/src/poseidon2/cuda.rs | 112 +++++++++++++++++- .../native/circuit/src/poseidon2/execution.rs | 6 + extensions/native/compiler/src/ir/poseidon.rs | 14 ++- .../native/recursion/src/challenger/duplex.rs | 2 +- 7 files changed, 160 insertions(+), 26 deletions(-) diff --git a/extensions/native/circuit/src/extension/cuda.rs b/extensions/native/circuit/src/extension/cuda.rs index 50a9cba86d..0646eda347 100644 --- a/extensions/native/circuit/src/extension/cuda.rs +++ b/extensions/native/circuit/src/extension/cuda.rs @@ -76,8 +76,6 @@ impl VmProverExtension inventory.add_executor_chip(fri_reduced_opening); inventory.next_air::>()?; - let poseidon2 = NativePoseidon2ChipGpu::<1>::new(range_checker.clone(), timestamp_max_bits); - inventory.add_executor_chip(poseidon2); let hint_air: &HintSpaceProviderAir = inventory.next_air::()?; let cpu_chip = Arc::new(HintSpaceProviderChip::new( @@ -85,6 +83,14 @@ impl VmProverExtension range_checker.clone(), timestamp_max_bits, )); + + let poseidon2 = NativePoseidon2ChipGpu::<1>::new_with_hint_space_provider( + range_checker.clone(), + timestamp_max_bits, + cpu_chip.clone(), + ); + inventory.add_executor_chip(poseidon2); + let provider_gpu = HintSpaceProviderChipGpu::new(cpu_chip.clone()); inventory.add_periphery_chip(provider_gpu); diff --git a/extensions/native/circuit/src/extension/mod.rs b/extensions/native/circuit/src/extension/mod.rs index 0aab146a13..23d28c10d3 100644 --- a/extensions/native/circuit/src/extension/mod.rs +++ b/extensions/native/circuit/src/extension/mod.rs @@ -271,15 +271,6 @@ where ); inventory.add_air(fri_reduced_opening); - let verify_batch = NativePoseidon2Air::<_, 1>::new( - exec_bridge, - memory_bridge, - hint_bridge, - VerifyBatchBus::new(inventory.new_bus_idx()), - Poseidon2Config::default(), - ); - inventory.add_air(verify_batch); - let hint_space_provider = HintSpaceProviderAir { hint_bus: hint_bridge.hint_bus(), lt_air: IsLtSubAir::new( @@ -289,6 +280,15 @@ where }; inventory.add_air(hint_space_provider); + let verify_batch = NativePoseidon2Air::<_, 1>::new( + exec_bridge, + memory_bridge, + hint_bridge, + VerifyBatchBus::new(inventory.new_bus_idx()), + Poseidon2Config::default(), + ); + inventory.add_air(verify_batch); + let tower_evaluate = NativeSumcheckAir::new(exec_bridge, memory_bridge, hint_bridge); inventory.add_air(tower_evaluate); @@ -365,8 +365,6 @@ where FriReducedOpeningChip::new(FriReducedOpeningFiller::new(), mem_helper.clone()); inventory.add_executor_chip(fri_reduced_opening); - inventory.next_air::, 1>>()?; - let hint_bus = inventory.airs().system().hint_bridge.hint_bus(); let hint_space_provider = Arc::new(HintSpaceProviderChip::new( hint_bus, @@ -374,17 +372,19 @@ where timestamp_max_bits, )); + inventory.next_air::()?; + inventory.add_periphery_chip(hint_space_provider.clone()); + + inventory.next_air::, 1>>()?; + let poseidon2 = NativePoseidon2Chip::<_, 1>::new( NativePoseidon2Filler::new(Poseidon2Config::default(), hint_space_provider.clone()), mem_helper.clone(), ); inventory.add_executor_chip(poseidon2); - inventory.next_air::()?; - inventory.add_periphery_chip(hint_space_provider.clone()); - let tower_verify = NativeSumcheckChip::new( - NativeSumcheckFiller::new(hint_space_provider), + NativeSumcheckFiller::new(hint_space_provider.clone()), mem_helper.clone(), ); inventory.add_executor_chip(tower_verify); diff --git a/extensions/native/circuit/src/poseidon2/chip.rs b/extensions/native/circuit/src/poseidon2/chip.rs index 061c6767db..9e9a6b07ba 100644 --- a/extensions/native/circuit/src/poseidon2/chip.rs +++ b/extensions/native/circuit/src/poseidon2/chip.rs @@ -679,6 +679,10 @@ where let input_len = ctx[1]; let is_hint = ctx[2].as_canonical_u32() != 0; + + // _debug + println!("=> is_hint: {:?}", is_hint); + // Read hint_id from register let [hint_id]: [F; 1] = memory_read_native(state.memory.data(), hint_id_register.as_canonical_u32()); @@ -830,6 +834,12 @@ where ); v }; + + // _debug + if is_hint { + println!("multi_observe hint mode: reading nf = {}", n_f); + } + multi_observe_cols.aux_read_enabled[j] = F::ONE; tracing_write_native_inplace( state.memory, @@ -1265,7 +1275,7 @@ impl NativePoseidon2Filler { pub range_checker: Arc, pub timestamp_max_bits: usize, + pub hint_space_provider: Option>, +} + +impl NativePoseidon2ChipGpu { + pub fn new(range_checker: Arc, timestamp_max_bits: usize) -> Self { + Self { + range_checker, + timestamp_max_bits, + hint_space_provider: None, + } + } + + pub fn new_with_hint_space_provider( + range_checker: Arc, + timestamp_max_bits: usize, + hint_space_provider: SharedHintSpaceProviderChip, + ) -> Self { + Self { + range_checker, + timestamp_max_bits, + hint_space_provider: Some(hint_space_provider), + } + } + + /// Scans multi-observe execution records to populate the hint provider with + /// (hint_id, offset, value) triples for hint-mode rows. + fn populate_hint_provider(&self, records: &[u8]) { + let Some(hint_space_provider) = &self.hint_space_provider else { + return; + }; + + let width = NativePoseidon2Cols::::width(); + let record_size = width * size_of::(); + if records.len() % record_size != 0 { + return; + } + let height = records.len() / record_size; + + let row_slice = unsafe { + let ptr = records.as_ptr() as *const F; + from_raw_parts(ptr, height * width) + }; + + let mut row_idx = 0; + while row_idx < height { + let start = row_idx * width; + let cols: &NativePoseidon2Cols = + row_slice[start..(start + width)].borrow(); + + if cols.multi_observe_row.is_one() { + let num_rows = cols.inner.export.as_canonical_u32() as usize; + if num_rows > 1 { + let head_multi_observe_cols: &MultiObserveCols = + cols.specific[..MultiObserveCols::::width()].borrow(); + let is_hint = head_multi_observe_cols.ctx[2] != F::ZERO; + if is_hint { + let hint_id = head_multi_observe_cols.hint_id; + for local_row in 1..num_rows { + let chunk_cols: &NativePoseidon2Cols = + row_slice[(row_idx + local_row) * width + ..(row_idx + local_row + 1) * width] + .borrow(); + let multi_observe_cols: &MultiObserveCols = chunk_cols.specific + [..MultiObserveCols::::width()] + .borrow(); + + let chunk_start = multi_observe_cols.start_idx.as_canonical_u32(); + let chunk_end = multi_observe_cols.end_idx.as_canonical_u32(); + let curr_len = multi_observe_cols.curr_len.as_canonical_u32(); + + for j in chunk_start..chunk_end { + let input_idx = curr_len + (j - chunk_start); + let val = multi_observe_cols.data[j as usize]; + hint_space_provider.request( + hint_id, + F::from_canonical_u32(input_idx), + val, + ); + } + } + } + } + row_idx += num_rows.max(1); + continue; + } + + if cols.simple.is_one() { + row_idx += 1; + } else { + let num_non_inside_row = cols.inner.export.as_canonical_u32() as usize; + let non_inside_start = start + (num_non_inside_row - 1) * width; + let last_non_inside_cols: &NativePoseidon2Cols = + row_slice[non_inside_start..(non_inside_start + width)].borrow(); + let total_num_row = last_non_inside_cols.inner.export.as_canonical_u32() as usize; + row_idx += total_num_row; + } + } + } } impl Chip @@ -28,6 +127,9 @@ impl Chip return get_empty_air_proving_ctx::(); } + // Populate hint space provider from multi-observe records before GPU upload. + self.populate_hint_provider(records); + // For Poseidon2, the records are already the trace rows // Use the columns width directly let width = NativePoseidon2Cols::::width(); diff --git a/extensions/native/circuit/src/poseidon2/execution.rs b/extensions/native/circuit/src/poseidon2/execution.rs index a558205729..4cec7ab2b4 100644 --- a/extensions/native/circuit/src/poseidon2/execution.rs +++ b/extensions/native/circuit/src/poseidon2/execution.rs @@ -640,6 +640,12 @@ unsafe fn execute_multi_observe_e12_impl< let [v]: [F; 1] = exec_state.vm_read(NATIVE_AS, input_ptr_u32 + input_idx); v }; + + // _debug + if is_hint { + println!("=> n_f: {n_f}"); + } + exec_state.vm_write(NATIVE_AS, sponge_ptr_u32 + (j as u32), &[n_f]); input_idx += 1; } diff --git a/extensions/native/compiler/src/ir/poseidon.rs b/extensions/native/compiler/src/ir/poseidon.rs index deb3b47f14..aadadf46ea 100644 --- a/extensions/native/compiler/src/ir/poseidon.rs +++ b/extensions/native/compiler/src/ir/poseidon.rs @@ -19,6 +19,7 @@ impl Builder { sponge_state: &Array>, input_ptr: Ptr, arr: &Array>, + input_len: Usize, hint_id: Option>, ) -> Usize { let buffer_size: Var = Var::uninit(self); @@ -32,7 +33,7 @@ impl Builder { Array::Fixed(_) => { panic!("Base elements input must be dynamic"); } - Array::Dyn(ptr, len) => { + Array::Dyn(ptr, _) => { let init_pos: Var = Var::uninit(self); self.assign(&init_pos, input_ptr.address - sponge_ptr.address); @@ -48,7 +49,7 @@ impl Builder { // Allocate context array: [init_pos, len, is_hint, reserved] let ctx = self.dyn_array::>(4usize); self.set(&ctx, 0, init_pos); - self.set(&ctx, 1, len.get_var()); + self.set(&ctx, 1, input_len.get_var()); self.set( &ctx, 2, @@ -56,6 +57,15 @@ impl Builder { ); self.set(&ctx, 3, C::N::ZERO); + + // _debug + let ctx1 = self.get(&ctx, 1); + let ctx2 = self.get(&ctx, 2); + self.print_debug(777); + self.print_v(ctx1); + self.print_v(ctx2); + + self.operations.push(DslIr::Poseidon2MultiObserve( *sponge_ptr, ctx.ptr(), diff --git a/extensions/native/recursion/src/challenger/duplex.rs b/extensions/native/recursion/src/challenger/duplex.rs index b45639dc31..10b9bc62e9 100644 --- a/extensions/native/recursion/src/challenger/duplex.rs +++ b/extensions/native/recursion/src/challenger/duplex.rs @@ -81,7 +81,7 @@ impl DuplexChallengerVariable { // This is equivalent to calling `observe` multiple times, but more efficient. pub fn observe_slice_opt(&self, builder: &mut Builder, arr: &Array>) { builder.if_ne(arr.len(), Usize::from(0)).then(|builder| { - let next_pos = builder.poseidon2_multi_observe(&self.sponge_state, self.input_ptr, arr, None); + let next_pos = builder.poseidon2_multi_observe(&self.sponge_state, self.input_ptr, arr, arr.len(), None); builder.assign(&self.input_ptr, self.io_empty_ptr + next_pos.clone()); builder.if_ne(next_pos, Usize::from(0)).then_or_else( From 3a006f3ac63ac81876874d712d1fa175a73171e6 Mon Sep 17 00:00:00 2001 From: Ray Gao Date: Mon, 16 Mar 2026 22:22:48 -0400 Subject: [PATCH 4/9] debug --- .../native/circuit/src/extension/mod.rs | 1 + .../native/circuit/src/poseidon2/chip.rs | 43 ++++++++++++++----- .../native/circuit/src/poseidon2/execution.rs | 5 --- extensions/native/compiler/src/ir/poseidon.rs | 9 ---- 4 files changed, 33 insertions(+), 25 deletions(-) diff --git a/extensions/native/circuit/src/extension/mod.rs b/extensions/native/circuit/src/extension/mod.rs index 23d28c10d3..4a930ddafe 100644 --- a/extensions/native/circuit/src/extension/mod.rs +++ b/extensions/native/circuit/src/extension/mod.rs @@ -383,6 +383,7 @@ where ); inventory.add_executor_chip(poseidon2); + inventory.next_air::()?; let tower_verify = NativeSumcheckChip::new( NativeSumcheckFiller::new(hint_space_provider.clone()), mem_helper.clone(), diff --git a/extensions/native/circuit/src/poseidon2/chip.rs b/extensions/native/circuit/src/poseidon2/chip.rs index 9e9a6b07ba..c0a559bb65 100644 --- a/extensions/native/circuit/src/poseidon2/chip.rs +++ b/extensions/native/circuit/src/poseidon2/chip.rs @@ -679,10 +679,6 @@ where let input_len = ctx[1]; let is_hint = ctx[2].as_canonical_u32() != 0; - - // _debug - println!("=> is_hint: {:?}", is_hint); - // Read hint_id from register let [hint_id]: [F; 1] = memory_read_native(state.memory.data(), hint_id_register.as_canonical_u32()); @@ -717,7 +713,9 @@ where pos += len; } } - final_timestamp_inc += 1; // write back to ctx[0] + // Final ctx[0] writeback always happens (including zero-length input + // where the head row is both the first and last row). + final_timestamp_inc += 1; let allocated_rows = arena .alloc(MultiRowLayout::new(NativePoseidon2Metadata { @@ -788,9 +786,26 @@ where multi_observe_cols.final_timestamp_increment = F::from_canonical_usize(final_timestamp_inc); multi_observe_cols.is_first = F::ONE; - multi_observe_cols.is_last = F::ZERO; + multi_observe_cols.is_last = if chunks.is_empty() { F::ONE } else { F::ZERO }; multi_observe_cols.curr_len = F::ZERO; multi_observe_cols.should_permute = F::ZERO; + if chunks.is_empty() { + // Zero-length input: head row is both first and last. + // Set start_timestamp to right after the 5 head reads, + // and write back init_pos (unchanged) to ctx_ptr[0]. + cols.start_timestamp = F::from_canonical_u32( + init_timestamp_u32 + NUM_HEAD_ACCESSES as u32, + ); + multi_observe_cols.start_idx = init_pos; + multi_observe_cols.end_idx = init_pos; + // state.memory.timestamp == init_ts + NUM_HEAD_ACCESSES here. + tracing_write_native_inplace( + state.memory, + ctx_ptr.as_canonical_u32(), + [init_pos], + &mut multi_observe_cols.write_final_idx, + ); + } } } @@ -835,11 +850,6 @@ where v }; - // _debug - if is_hint { - println!("multi_observe hint mode: reading nf = {}", n_f); - } - multi_observe_cols.aux_read_enabled[j] = F::ONE; tracing_write_native_inplace( state.memory, @@ -1318,6 +1328,17 @@ impl NativePoseidon2Filler = + chunk_slice[..width].borrow_mut(); + let head_mo: &mut MultiObserveCols = + head_c.specific[..MultiObserveCols::::width()].borrow_mut(); + let head_ts = head_c.start_timestamp.as_canonical_u32(); + mem_fill_helper(mem_helper, head_ts, head_mo.write_final_idx.as_mut()); + } } #[inline(always)] diff --git a/extensions/native/circuit/src/poseidon2/execution.rs b/extensions/native/circuit/src/poseidon2/execution.rs index 4cec7ab2b4..d41d911812 100644 --- a/extensions/native/circuit/src/poseidon2/execution.rs +++ b/extensions/native/circuit/src/poseidon2/execution.rs @@ -641,11 +641,6 @@ unsafe fn execute_multi_observe_e12_impl< v }; - // _debug - if is_hint { - println!("=> n_f: {n_f}"); - } - exec_state.vm_write(NATIVE_AS, sponge_ptr_u32 + (j as u32), &[n_f]); input_idx += 1; } diff --git a/extensions/native/compiler/src/ir/poseidon.rs b/extensions/native/compiler/src/ir/poseidon.rs index aadadf46ea..6310917c0d 100644 --- a/extensions/native/compiler/src/ir/poseidon.rs +++ b/extensions/native/compiler/src/ir/poseidon.rs @@ -57,15 +57,6 @@ impl Builder { ); self.set(&ctx, 3, C::N::ZERO); - - // _debug - let ctx1 = self.get(&ctx, 1); - let ctx2 = self.get(&ctx, 2); - self.print_debug(777); - self.print_v(ctx1); - self.print_v(ctx2); - - self.operations.push(DslIr::Poseidon2MultiObserve( *sponge_ptr, ctx.ptr(), From 065d59811285057fc8206336503b64febad9c99f Mon Sep 17 00:00:00 2001 From: Ray Gao Date: Tue, 17 Mar 2026 16:13:03 -0400 Subject: [PATCH 5/9] adjust --- extensions/native/circuit/cuda/src/poseidon2.cu | 6 ------ extensions/native/circuit/src/poseidon2/air.rs | 1 - extensions/native/circuit/src/poseidon2/chip.rs | 2 -- 3 files changed, 9 deletions(-) diff --git a/extensions/native/circuit/cuda/src/poseidon2.cu b/extensions/native/circuit/cuda/src/poseidon2.cu index b39038a079..749599d906 100644 --- a/extensions/native/circuit/cuda/src/poseidon2.cu +++ b/extensions/native/circuit/cuda/src/poseidon2.cu @@ -355,7 +355,6 @@ template struct Poseidon2Wrapper { if (specific[COL_INDEX(MultiObserveCols, is_first)] == Fp::one()) { uint32_t very_start_timestamp = row[COL_INDEX(Cols, very_first_timestamp)].asUInt32(); - // 3 register reads at timestamps +0, +1, +2 for (uint32_t i = 0; i < 3; ++i) { mem_fill_base( mem_helper, @@ -363,13 +362,11 @@ template struct Poseidon2Wrapper { specific.slice_from(COL_INDEX(MultiObserveCols, read_data[i].base)) ); } - // 1 context array read at timestamp +3 mem_fill_base( mem_helper, very_start_timestamp + 3, specific.slice_from(COL_INDEX(MultiObserveCols, read_ctx.base)) ); - // 1 hint_id register read at timestamp +4 (reuse spare read_data[3] on head row) mem_fill_base( mem_helper, very_start_timestamp + 4, @@ -381,20 +378,17 @@ template struct Poseidon2Wrapper { specific[COL_INDEX(MultiObserveCols, start_idx)].asUInt32(); uint32_t chunk_end = specific[COL_INDEX(MultiObserveCols, end_idx)].asUInt32(); - // is_hint = ctx[2] uint32_t is_hint = specific[COL_INDEX(MultiObserveCols, ctx[2])].asUInt32(); uint32_t ts_per_element = 2 - is_hint; for (uint32_t j = chunk_start; j < chunk_end; ++j) { if (!is_hint) { - // Non-hint mode: fill read_data aux mem_fill_base( mem_helper, start_timestamp, specific.slice_from(COL_INDEX(MultiObserveCols, read_data[j].base)) ); } - // Write timestamp: start_timestamp + (1 - is_hint) for non-hint, start_timestamp for hint mem_fill_base( mem_helper, start_timestamp + (1 - is_hint), diff --git a/extensions/native/circuit/src/poseidon2/air.rs b/extensions/native/circuit/src/poseidon2/air.rs index 11b9e0fede..8b3b02ebc2 100644 --- a/extensions/native/circuit/src/poseidon2/air.rs +++ b/extensions/native/circuit/src/poseidon2/air.rs @@ -777,7 +777,6 @@ impl Air ) .eval(builder, multi_observe_row * is_first); - // Head row: 3 register reads + 1 context array read + 1 hint_id register read self.memory_bridge .read( MemoryAddress::new(self.address_space, state_ptr_register), diff --git a/extensions/native/circuit/src/poseidon2/chip.rs b/extensions/native/circuit/src/poseidon2/chip.rs index c0a559bb65..e136c2fbb0 100644 --- a/extensions/native/circuit/src/poseidon2/chip.rs +++ b/extensions/native/circuit/src/poseidon2/chip.rs @@ -1330,8 +1330,6 @@ impl NativePoseidon2Filler = chunk_slice[..width].borrow_mut(); let head_mo: &mut MultiObserveCols = From 4f33e48406a6baf8b082dccb19fa2bbc74171205 Mon Sep 17 00:00:00 2001 From: Ray Gao Date: Tue, 17 Mar 2026 22:36:50 -0400 Subject: [PATCH 6/9] adjust cuda --- .../circuit/cuda/include/native/poseidon2.cuh | 15 +++++++++------ extensions/native/circuit/cuda/src/poseidon2.cu | 1 + 2 files changed, 10 insertions(+), 6 deletions(-) diff --git a/extensions/native/circuit/cuda/include/native/poseidon2.cuh b/extensions/native/circuit/cuda/include/native/poseidon2.cuh index 206c0e16c0..b794aacaef 100644 --- a/extensions/native/circuit/cuda/include/native/poseidon2.cuh +++ b/extensions/native/circuit/cuda/include/native/poseidon2.cuh @@ -65,14 +65,17 @@ template struct SimplePoseidonSpecificCols { template struct MultiObserveCols { T pc; T final_timestamp_increment; + T state_ptr_register; + T ctx_register; + T input_ptr_register; + T hint_id_register; T state_ptr; + T ctx_ptr; T input_ptr; - T init_pos; - T len; - T input_register_1; - T input_register_2; - T input_register_3; - T output_register; + T hint_id; + T ctx[4]; + MemoryReadAuxCols read_ctx; + T chunk_ts_count; T is_first; T is_last; T curr_len; diff --git a/extensions/native/circuit/cuda/src/poseidon2.cu b/extensions/native/circuit/cuda/src/poseidon2.cu index 749599d906..772a708f89 100644 --- a/extensions/native/circuit/cuda/src/poseidon2.cu +++ b/extensions/native/circuit/cuda/src/poseidon2.cu @@ -24,6 +24,7 @@ template struct NativePoseidon2Cols { T inside_row; T simple; T multi_observe_row; + T not_hint_multi_observe; T end_inside_row; T end_top_level; From 4a1272231ca0db1a81e89ac42141cce388c5b2f7 Mon Sep 17 00:00:00 2001 From: Ray Gao Date: Tue, 17 Mar 2026 22:59:54 -0400 Subject: [PATCH 7/9] fix cuda --- extensions/native/circuit/src/extension/cuda.rs | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/extensions/native/circuit/src/extension/cuda.rs b/extensions/native/circuit/src/extension/cuda.rs index 0646eda347..d3eb6da4fc 100644 --- a/extensions/native/circuit/src/extension/cuda.rs +++ b/extensions/native/circuit/src/extension/cuda.rs @@ -78,9 +78,13 @@ impl VmProverExtension inventory.next_air::>()?; let hint_air: &HintSpaceProviderAir = inventory.next_air::()?; + let cpu_range_checker = range_checker + .cpu_chip + .clone() + .expect("VariableRangeCheckerChipGPU is expected to be hybrid with cpu_chip"); let cpu_chip = Arc::new(HintSpaceProviderChip::new( hint_air.hint_bus, - range_checker.clone(), + cpu_range_checker, timestamp_max_bits, )); From 7e6d1be52421d33102980a4fb5e6a58ee3ee799f Mon Sep 17 00:00:00 2001 From: Ray Gao Date: Tue, 17 Mar 2026 23:30:01 -0400 Subject: [PATCH 8/9] fix cuda --- extensions/native/circuit/src/extension/cuda.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/extensions/native/circuit/src/extension/cuda.rs b/extensions/native/circuit/src/extension/cuda.rs index d3eb6da4fc..9777b53dc4 100644 --- a/extensions/native/circuit/src/extension/cuda.rs +++ b/extensions/native/circuit/src/extension/cuda.rs @@ -75,8 +75,6 @@ impl VmProverExtension FriReducedOpeningChipGpu::new(range_checker.clone(), timestamp_max_bits); inventory.add_executor_chip(fri_reduced_opening); - inventory.next_air::>()?; - let hint_air: &HintSpaceProviderAir = inventory.next_air::()?; let cpu_range_checker = range_checker .cpu_chip @@ -88,6 +86,8 @@ impl VmProverExtension timestamp_max_bits, )); + inventory.next_air::>()?; + let poseidon2 = NativePoseidon2ChipGpu::<1>::new_with_hint_space_provider( range_checker.clone(), timestamp_max_bits, From fbe927e773967769af5e0bedafdc9565255c326c Mon Sep 17 00:00:00 2001 From: Ray Gao Date: Wed, 18 Mar 2026 00:39:10 -0400 Subject: [PATCH 9/9] fix cuda --- extensions/native/circuit/src/extension/cuda.rs | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/extensions/native/circuit/src/extension/cuda.rs b/extensions/native/circuit/src/extension/cuda.rs index 9777b53dc4..0d476413ac 100644 --- a/extensions/native/circuit/src/extension/cuda.rs +++ b/extensions/native/circuit/src/extension/cuda.rs @@ -86,6 +86,9 @@ impl VmProverExtension timestamp_max_bits, )); + let provider_gpu = HintSpaceProviderChipGpu::new(cpu_chip.clone()); + inventory.add_periphery_chip(provider_gpu); + inventory.next_air::>()?; let poseidon2 = NativePoseidon2ChipGpu::<1>::new_with_hint_space_provider( @@ -95,9 +98,6 @@ impl VmProverExtension ); inventory.add_executor_chip(poseidon2); - let provider_gpu = HintSpaceProviderChipGpu::new(cpu_chip.clone()); - inventory.add_periphery_chip(provider_gpu); - inventory.next_air::()?; let sumcheck = NativeSumcheckChipGpu::new(range_checker.clone(), timestamp_max_bits, cpu_chip);