From 3c4d6d23e8ab225d42f11cb3bc2527b98c544b8c Mon Sep 17 00:00:00 2001 From: dicethedev Date: Fri, 24 Apr 2026 14:13:44 +0100 Subject: [PATCH] Add SIMD target-sum grinding for Poseidon signing --- src/inc_encoding.rs | 24 +++++++ src/inc_encoding/target_sum.rs | 97 ++++++++++++++++++++++++++ src/signature/generalized_xmss.rs | 36 ++-------- src/symmetric/message_hash.rs | 29 ++++++++ src/symmetric/message_hash/poseidon.rs | 61 +++++++++++++++- 5 files changed, 213 insertions(+), 34 deletions(-) diff --git a/src/inc_encoding.rs b/src/inc_encoding.rs index f5803e2..b935b88 100644 --- a/src/inc_encoding.rs +++ b/src/inc_encoding.rs @@ -3,6 +3,7 @@ use std::fmt::Debug; use crate::MESSAGE_LENGTH; use crate::serialization::Serializable; +use crate::symmetric::prf::Pseudorandom; /// Trait to model incomparable encoding schemes. /// These schemes allow to encode a message into a codeword. @@ -45,6 +46,29 @@ pub trait IncomparableEncoding { randomness: &Self::Randomness, epoch: u32, ) -> Result, Self::Error>; + + /// Deterministically search for the first randomness that yields a valid codeword. + /// + /// Implementations may override this with a batched or SIMD-accelerated search. + fn grind( + parameter: &Self::Parameter, + prf_key: &PRF::Key, + epoch: u32, + message: &[u8; MESSAGE_LENGTH], + ) -> Option<(Self::Randomness, Vec)> + where + PRF: Pseudorandom, + PRF::Randomness: Into, + { + for attempt in 0..Self::MAX_TRIES { + let randomness = PRF::get_randomness(prf_key, epoch, message, attempt as u64).into(); + if let Ok(codeword) = Self::encode(parameter, message, &randomness, epoch) { + return Some((randomness, codeword)); + } + } + + None + } } pub mod target_sum; diff --git a/src/inc_encoding/target_sum.rs b/src/inc_encoding/target_sum.rs index 7360719..e2b2be1 100644 --- a/src/inc_encoding/target_sum.rs +++ b/src/inc_encoding/target_sum.rs @@ -2,6 +2,7 @@ use super::IncomparableEncoding; use crate::{MESSAGE_LENGTH, symmetric::message_hash::MessageHash}; use std::fmt::Debug; use thiserror::Error; +use crate::symmetric::prf::Pseudorandom; /// Specific errors that can occur during target sum encoding. #[derive(Debug, Error)] @@ -61,7 +62,32 @@ impl IncomparableEncoding randomness: &Self::Randomness, epoch: u32, ) -> Result, Self::Error> { + // Compile-time parameter validation for Target Sum Encoding + // + // This encoding implements Construction 6 (IE for Target Sum Winternitz) + // from DKKW25. It maps a message to a codeword x ∈ C ⊆ Z_w^v, where: + // + // C = { (x_1, ..., x_v) ∈ {0, ..., w-1}^v | Σ x_i = T } + // + // The code C enforces the *incomparability* property (Definition 13): + // no two distinct codewords x, x' satisfy x_i ≥ x'_i for all i. + // This is critical for the security of the XMSS signature scheme. + // + // DKKW25: https://eprint.iacr.org/2025/055 + // HHKTW26: https://eprint.iacr.org/2026/016 const { + // Representation constraints + // + // In the Generalized XMSS construction (DKKW25), + // each chain position and chain index is encoded as a single byte + // in the tweak function: + // + // tweak(ep, i, k) = (0x00 || ep || i || k) + // 8b ⌈log L⌉ ⌈log v⌉ w bits + // + // - Since chain_index `i` is stored as u8, we need v ≤ 256. + // - Since pos_in_chain `k` is stored as u8, we need w ≤ 256. + // - Codeword entries (chunks) are also stored as u8 in signatures. // base and dimension must not be too large assert!( MH::BASE <= 1 << 8, @@ -71,6 +97,36 @@ impl IncomparableEncoding MH::DIMENSION <= 1 << 8, "Target Sum Encoding: Dimension must be at most 2^8" ); + + // Encoding well-formedness + // + // Definition 13 (DKKW25): an incomparable encoding maps messages + // to codewords in {0, ..., w-1}^v. For the incomparability + // property to be meaningful, we need w ≥ 2 (otherwise every + // codeword is the zero vector, and distinct codewords cannot + // exist). + // assert!( + // MH::BASE >= 2, + // "Target Sum Encoding: Base must be at least 2" + // ); + + // Target sum range + // + // Construction 6 (DKKW25) defines the code: + // + // C = { x ∈ {0,...,w-1}^v | Σ x_i = T } + // + // For C to be non-empty, T must be achievable: each x_i can + // contribute at most w-1 to the sum, so T ≤ v*(w-1). The lower + // bound T ≥ 0 is guaranteed by the usize type. + // + // Choosing T close to v*(w-1)/2 (the expected sum of a uniform + // hash) maximizes |C| and minimizes the signing retry rate + // (Lemma 7, DKKW25). + // assert!( + // TARGET_SUM <= MH::DIMENSION * (MH::BASE - 1), + // "Target Sum Encoding: TARGET_SUM must be at most DIMENSION * (BASE - 1)" + // ); } // apply the message hash first to get chunks @@ -87,6 +143,19 @@ impl IncomparableEncoding }) } } + + fn grind( + parameter: &Self::Parameter, + prf_key: &PRF::Key, + epoch: u32, + message: &[u8; MESSAGE_LENGTH], + ) -> Option<(Self::Randomness, Vec)> + where + PRF: Pseudorandom, + PRF::Randomness: Into, + { + MH::grind_target_sum::(parameter, prf_key, epoch, message, Self::MAX_TRIES) + } } #[cfg(test)] @@ -95,12 +164,14 @@ mod tests { use crate::F; use crate::array::FieldArray; use crate::symmetric::message_hash::poseidon::PoseidonMessageHash445; + use crate::symmetric::prf::{Pseudorandom, shake_to_field::ShakePRFtoF}; use p3_field::PrimeField32; use proptest::prelude::*; use rand::RngExt; const TEST_TARGET_SUM: usize = 115; type TestTargetSumEncoding = TargetSumEncoding; + type TestPRF = ShakePRFtoF<4, 4>; #[test] fn test_successful_encoding_fixed_message() { @@ -180,6 +251,32 @@ mod tests { panic!("failed to find successful encoding after 1000 attempts"); } + #[test] + fn test_grind_matches_first_successful_attempt() { + let mut rng = rand::rng(); + let parameter: FieldArray<4> = FieldArray(rng.random()); + let message: [u8; 32] = rng.random(); + let epoch = 7u32; + let prf_key = TestPRF::key_gen(&mut rng); + + let expected = (0..TestTargetSumEncoding::MAX_TRIES).find_map(|attempt| { + let randomness = TestPRF::get_randomness(&prf_key, epoch, &message, attempt as u64); + TestTargetSumEncoding::encode(¶meter, &message, &randomness.into(), epoch) + .ok() + .map(|chunks| (randomness.into(), chunks)) + }); + + let actual = + ::grind::( + ¶meter, + &prf_key, + epoch, + &message, + ); + + assert_eq!(actual, expected); + } + proptest! { #[test] fn proptest_encoding_determinism_and_error_reporting( diff --git a/src/signature/generalized_xmss.rs b/src/signature/generalized_xmss.rs index 72ab34f..46dc073 100644 --- a/src/signature/generalized_xmss.rs +++ b/src/signature/generalized_xmss.rs @@ -814,39 +814,13 @@ where let path = combined_path(&sk.top_tree, bottom_tree, epoch); // now, we need to encode our message using the incomparable encoding. - // we retry until we get a valid codeword, or until we give up. - let max_tries = IE::MAX_TRIES; - let mut attempts = 0; - let mut x = None; - let mut rho = None; - while attempts < max_tries { - // get a randomness and try to encode the message. Note: we get the randomness from the PRF - // which ensures that signing is deterministic. The PRF is applied to the message and the epoch. - // While the intention is that users of the scheme never call sign twice with the same (epoch, sk) pair, - // this deterministic approach ensures that calling sign twice is fine, as long as the message stays the same. - let curr_rho = PRF::get_randomness(&sk.prf_key, epoch, message, attempts as u64).into(); - let curr_x = IE::encode(&sk.parameter.into(), message, &curr_rho, epoch); - - // check if we have found a valid codeword, and if so, stop searching - if curr_x.is_ok() { - rho = Some(curr_rho); - x = curr_x.ok(); - break; - } - - attempts += 1; - } - - // if we have not found a valid codeword, return an error - if x.is_none() { + // this search stays deterministic: we always return the first successful PRF counter. + let Some((rho, x)) = IE::grind::(&sk.parameter.into(), &sk.prf_key, epoch, message) + else { return Err(SigningError::EncodingAttemptsExceeded { - attempts: max_tries, + attempts: IE::MAX_TRIES, }); - } - - // otherwise, unwrap x and rho - let x = x.unwrap(); - let rho = rho.unwrap(); + }; // we will include rho in the signature, and // we use x to determine how far the signer walks in the chains diff --git a/src/symmetric/message_hash.rs b/src/symmetric/message_hash.rs index 0ba2181..9b6dd0c 100644 --- a/src/symmetric/message_hash.rs +++ b/src/symmetric/message_hash.rs @@ -4,6 +4,7 @@ use rand::RngExt; use crate::MESSAGE_LENGTH; use crate::serialization::Serializable; +use crate::symmetric::prf::Pseudorandom; pub use poseidon::encode_message; @@ -39,6 +40,34 @@ pub trait MessageHash { randomness: &Self::Randomness, message: &[u8; MESSAGE_LENGTH], ) -> Result, Self::Error>; + + /// Search deterministically for the first randomness whose chunks hit `TARGET_SUM`. + /// + /// Implementations may override this with a batched or SIMD-accelerated search. + fn grind_target_sum( + parameter: &Self::Parameter, + prf_key: &PRF::Key, + epoch: u32, + message: &[u8; MESSAGE_LENGTH], + max_tries: usize, + ) -> Option<(Self::Randomness, Vec)> + where + PRF: Pseudorandom, + PRF::Randomness: Into, + { + for attempt in 0..max_tries { + let randomness = PRF::get_randomness(prf_key, epoch, message, attempt as u64).into(); + let Ok(chunks) = Self::apply(parameter, epoch, &randomness, message) else { + continue; + }; + + if chunks.iter().map(|&chunk| chunk as usize).sum::() == TARGET_SUM { + return Some((randomness, chunks)); + } + } + + None + } } pub mod aborting; diff --git a/src/symmetric/message_hash/poseidon.rs b/src/symmetric/message_hash/poseidon.rs index f4fe81e..6062a8b 100644 --- a/src/symmetric/message_hash/poseidon.rs +++ b/src/symmetric/message_hash/poseidon.rs @@ -1,9 +1,8 @@ +use core::array; use std::convert::Infallible; use num_bigint::BigUint; -use p3_field::PrimeCharacteristicRing; -use p3_field::PrimeField; -use p3_field::PrimeField64; +use p3_field::{PackedValue, PrimeCharacteristicRing, PrimeField, PrimeField64}; use serde::{Serialize, de::DeserializeOwned}; use super::MessageHash; @@ -12,7 +11,10 @@ use crate::MESSAGE_LENGTH; use crate::TWEAK_SEPARATOR_FOR_MESSAGE_HASH; use crate::array::FieldArray; use crate::poseidon1_24; +use crate::simd_utils::pack_array; +use crate::symmetric::prf::Pseudorandom; use crate::symmetric::tweak_hash::poseidon::poseidon_compress; +use crate::PackedF; /// Function to encode a message as an array of field elements pub fn encode_message(message: &[u8; MESSAGE_LENGTH]) -> [F; MSG_LEN_FE] { @@ -241,6 +243,59 @@ where Ok(decode_to_chunks::(&hash_fe).to_vec()) } + + fn grind_target_sum( + parameter: &Self::Parameter, + prf_key: &PRF::Key, + epoch: u32, + message: &[u8; MESSAGE_LENGTH], + max_tries: usize, + ) -> Option<(Self::Randomness, Vec)> + where + PRF: Pseudorandom, + PRF::Randomness: Into, + { + let perm = poseidon1_24(); + let lanes = PackedF::WIDTH; + + let packed_message: [PackedF; MSG_LEN_FE] = + encode_message::(message).map(PackedF::from); + let packed_parameter: [PackedF; PARAMETER_LEN] = array::from_fn(|i| PackedF::from(parameter[i])); + let packed_epoch: [PackedF; TWEAK_LEN_FE] = encode_epoch::(epoch).map(PackedF::from); + + for batch_start in (0..max_tries).step_by(lanes) { + let valid_lanes = (max_tries - batch_start).min(lanes); + let randomnesses: [FieldArray; PackedF::WIDTH] = array::from_fn(|lane| { + let attempt = batch_start + lane.min(valid_lanes.saturating_sub(1)); + PRF::get_randomness(prf_key, epoch, message, attempt as u64).into() + }); + let packed_randomness = pack_array(&randomnesses); + + let combined_input_len = MSG_LEN_FE + PARAMETER_LEN + TWEAK_LEN_FE + RAND_LEN_FE; + let mut combined_input = Vec::with_capacity(combined_input_len); + combined_input.extend_from_slice(&packed_message); + combined_input.extend_from_slice(&packed_parameter); + combined_input.extend_from_slice(&packed_epoch); + combined_input.extend_from_slice(&packed_randomness); + + let packed_hash = + poseidon_compress::(&perm, &combined_input); + + let mut unpacked_hashes = [FieldArray([F::ZERO; HASH_LEN_FE]); PackedF::WIDTH]; + PackedF::unpack_into(&packed_hash, FieldArray::as_raw_slice_mut(&mut unpacked_hashes)); + + for lane in 0..valid_lanes { + let chunks = + decode_to_chunks::(&unpacked_hashes[lane].0) + .to_vec(); + if chunks.iter().map(|&chunk| chunk as usize).sum::() == TARGET_SUM { + return Some((randomnesses[lane], chunks)); + } + } + } + + None + } } // Example instantiations