Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 24 additions & 0 deletions src/inc_encoding.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ use std::fmt::Debug;

use crate::MESSAGE_LENGTH;
use crate::serialization::Serializable;
use crate::symmetric::prf::Pseudorandom;

/// Trait to model incomparable encoding schemes.
/// These schemes allow to encode a message into a codeword.
Expand Down Expand Up @@ -45,6 +46,29 @@ pub trait IncomparableEncoding {
randomness: &Self::Randomness,
epoch: u32,
) -> Result<Vec<u8>, Self::Error>;

/// Deterministically search for the first randomness that yields a valid codeword.
///
/// Implementations may override this with a batched or SIMD-accelerated search.
fn grind<PRF>(
parameter: &Self::Parameter,
prf_key: &PRF::Key,
epoch: u32,
message: &[u8; MESSAGE_LENGTH],
) -> Option<(Self::Randomness, Vec<u8>)>
where
PRF: Pseudorandom,
PRF::Randomness: Into<Self::Randomness>,
{
for attempt in 0..Self::MAX_TRIES {
let randomness = PRF::get_randomness(prf_key, epoch, message, attempt as u64).into();
if let Ok(codeword) = Self::encode(parameter, message, &randomness, epoch) {
return Some((randomness, codeword));
}
}

None
}
}

pub mod target_sum;
97 changes: 97 additions & 0 deletions src/inc_encoding/target_sum.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ use super::IncomparableEncoding;
use crate::{MESSAGE_LENGTH, symmetric::message_hash::MessageHash};
use std::fmt::Debug;
use thiserror::Error;
use crate::symmetric::prf::Pseudorandom;

/// Specific errors that can occur during target sum encoding.
#[derive(Debug, Error)]
Expand Down Expand Up @@ -61,7 +62,32 @@ impl<MH: MessageHash, const TARGET_SUM: usize> IncomparableEncoding
randomness: &Self::Randomness,
epoch: u32,
) -> Result<Vec<u8>, Self::Error> {
// Compile-time parameter validation for Target Sum Encoding
//
// This encoding implements Construction 6 (IE for Target Sum Winternitz)
// from DKKW25. It maps a message to a codeword x ∈ C ⊆ Z_w^v, where:
//
// C = { (x_1, ..., x_v) ∈ {0, ..., w-1}^v | Σ x_i = T }
//
// The code C enforces the *incomparability* property (Definition 13):
// no two distinct codewords x, x' satisfy x_i ≥ x'_i for all i.
// This is critical for the security of the XMSS signature scheme.
//
// DKKW25: https://eprint.iacr.org/2025/055
// HHKTW26: https://eprint.iacr.org/2026/016
const {
// Representation constraints
//
// In the Generalized XMSS construction (DKKW25),
// each chain position and chain index is encoded as a single byte
// in the tweak function:
//
// tweak(ep, i, k) = (0x00 || ep || i || k)
// 8b ⌈log L⌉ ⌈log v⌉ w bits
//
// - Since chain_index `i` is stored as u8, we need v ≤ 256.
// - Since pos_in_chain `k` is stored as u8, we need w ≤ 256.
// - Codeword entries (chunks) are also stored as u8 in signatures.
// base and dimension must not be too large
assert!(
MH::BASE <= 1 << 8,
Expand All @@ -71,6 +97,36 @@ impl<MH: MessageHash, const TARGET_SUM: usize> IncomparableEncoding
MH::DIMENSION <= 1 << 8,
"Target Sum Encoding: Dimension must be at most 2^8"
);

// Encoding well-formedness
//
// Definition 13 (DKKW25): an incomparable encoding maps messages
// to codewords in {0, ..., w-1}^v. For the incomparability
// property to be meaningful, we need w ≥ 2 (otherwise every
// codeword is the zero vector, and distinct codewords cannot
// exist).
// assert!(
// MH::BASE >= 2,
// "Target Sum Encoding: Base must be at least 2"
// );

// Target sum range
//
// Construction 6 (DKKW25) defines the code:
//
// C = { x ∈ {0,...,w-1}^v | Σ x_i = T }
//
// For C to be non-empty, T must be achievable: each x_i can
// contribute at most w-1 to the sum, so T ≤ v*(w-1). The lower
// bound T ≥ 0 is guaranteed by the usize type.
//
// Choosing T close to v*(w-1)/2 (the expected sum of a uniform
// hash) maximizes |C| and minimizes the signing retry rate
// (Lemma 7, DKKW25).
// assert!(
// TARGET_SUM <= MH::DIMENSION * (MH::BASE - 1),
// "Target Sum Encoding: TARGET_SUM must be at most DIMENSION * (BASE - 1)"
// );
}

// apply the message hash first to get chunks
Expand All @@ -87,6 +143,19 @@ impl<MH: MessageHash, const TARGET_SUM: usize> IncomparableEncoding
})
}
}

fn grind<PRF>(
parameter: &Self::Parameter,
prf_key: &PRF::Key,
epoch: u32,
message: &[u8; MESSAGE_LENGTH],
) -> Option<(Self::Randomness, Vec<u8>)>
where
PRF: Pseudorandom,
PRF::Randomness: Into<Self::Randomness>,
{
MH::grind_target_sum::<PRF, TARGET_SUM>(parameter, prf_key, epoch, message, Self::MAX_TRIES)
}
}

#[cfg(test)]
Expand All @@ -95,12 +164,14 @@ mod tests {
use crate::F;
use crate::array::FieldArray;
use crate::symmetric::message_hash::poseidon::PoseidonMessageHash445;
use crate::symmetric::prf::{Pseudorandom, shake_to_field::ShakePRFtoF};
use p3_field::PrimeField32;
use proptest::prelude::*;
use rand::RngExt;

const TEST_TARGET_SUM: usize = 115;
type TestTargetSumEncoding = TargetSumEncoding<PoseidonMessageHash445, TEST_TARGET_SUM>;
type TestPRF = ShakePRFtoF<4, 4>;

#[test]
fn test_successful_encoding_fixed_message() {
Expand Down Expand Up @@ -180,6 +251,32 @@ mod tests {
panic!("failed to find successful encoding after 1000 attempts");
}

#[test]
fn test_grind_matches_first_successful_attempt() {
let mut rng = rand::rng();
let parameter: FieldArray<4> = FieldArray(rng.random());
let message: [u8; 32] = rng.random();
let epoch = 7u32;
let prf_key = TestPRF::key_gen(&mut rng);

let expected = (0..TestTargetSumEncoding::MAX_TRIES).find_map(|attempt| {
let randomness = TestPRF::get_randomness(&prf_key, epoch, &message, attempt as u64);
TestTargetSumEncoding::encode(&parameter, &message, &randomness.into(), epoch)
.ok()
.map(|chunks| (randomness.into(), chunks))
});

let actual =
<TestTargetSumEncoding as IncomparableEncoding>::grind::<TestPRF>(
&parameter,
&prf_key,
epoch,
&message,
);

assert_eq!(actual, expected);
}

proptest! {
#[test]
fn proptest_encoding_determinism_and_error_reporting(
Expand Down
36 changes: 5 additions & 31 deletions src/signature/generalized_xmss.rs
Original file line number Diff line number Diff line change
Expand Up @@ -814,39 +814,13 @@ where
let path = combined_path(&sk.top_tree, bottom_tree, epoch);

// now, we need to encode our message using the incomparable encoding.
// we retry until we get a valid codeword, or until we give up.
let max_tries = IE::MAX_TRIES;
let mut attempts = 0;
let mut x = None;
let mut rho = None;
while attempts < max_tries {
// get a randomness and try to encode the message. Note: we get the randomness from the PRF
// which ensures that signing is deterministic. The PRF is applied to the message and the epoch.
// While the intention is that users of the scheme never call sign twice with the same (epoch, sk) pair,
// this deterministic approach ensures that calling sign twice is fine, as long as the message stays the same.
let curr_rho = PRF::get_randomness(&sk.prf_key, epoch, message, attempts as u64).into();
let curr_x = IE::encode(&sk.parameter.into(), message, &curr_rho, epoch);

// check if we have found a valid codeword, and if so, stop searching
if curr_x.is_ok() {
rho = Some(curr_rho);
x = curr_x.ok();
break;
}

attempts += 1;
}

// if we have not found a valid codeword, return an error
if x.is_none() {
// this search stays deterministic: we always return the first successful PRF counter.
let Some((rho, x)) = IE::grind::<PRF>(&sk.parameter.into(), &sk.prf_key, epoch, message)
else {
return Err(SigningError::EncodingAttemptsExceeded {
attempts: max_tries,
attempts: IE::MAX_TRIES,
});
}

// otherwise, unwrap x and rho
let x = x.unwrap();
let rho = rho.unwrap();
};

// we will include rho in the signature, and
// we use x to determine how far the signer walks in the chains
Expand Down
29 changes: 29 additions & 0 deletions src/symmetric/message_hash.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ use rand::RngExt;

use crate::MESSAGE_LENGTH;
use crate::serialization::Serializable;
use crate::symmetric::prf::Pseudorandom;

pub use poseidon::encode_message;

Expand Down Expand Up @@ -39,6 +40,34 @@ pub trait MessageHash {
randomness: &Self::Randomness,
message: &[u8; MESSAGE_LENGTH],
) -> Result<Vec<u8>, Self::Error>;

/// Search deterministically for the first randomness whose chunks hit `TARGET_SUM`.
///
/// Implementations may override this with a batched or SIMD-accelerated search.
fn grind_target_sum<PRF, const TARGET_SUM: usize>(
parameter: &Self::Parameter,
prf_key: &PRF::Key,
epoch: u32,
message: &[u8; MESSAGE_LENGTH],
max_tries: usize,
) -> Option<(Self::Randomness, Vec<u8>)>
where
PRF: Pseudorandom,
PRF::Randomness: Into<Self::Randomness>,
{
for attempt in 0..max_tries {
let randomness = PRF::get_randomness(prf_key, epoch, message, attempt as u64).into();
let Ok(chunks) = Self::apply(parameter, epoch, &randomness, message) else {
continue;
};

if chunks.iter().map(|&chunk| chunk as usize).sum::<usize>() == TARGET_SUM {
return Some((randomness, chunks));
}
}

None
}
}

pub mod aborting;
Expand Down
61 changes: 58 additions & 3 deletions src/symmetric/message_hash/poseidon.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,8 @@
use core::array;
use std::convert::Infallible;

use num_bigint::BigUint;
use p3_field::PrimeCharacteristicRing;
use p3_field::PrimeField;
use p3_field::PrimeField64;
use p3_field::{PackedValue, PrimeCharacteristicRing, PrimeField, PrimeField64};
use serde::{Serialize, de::DeserializeOwned};

use super::MessageHash;
Expand All @@ -12,7 +11,10 @@ use crate::MESSAGE_LENGTH;
use crate::TWEAK_SEPARATOR_FOR_MESSAGE_HASH;
use crate::array::FieldArray;
use crate::poseidon1_24;
use crate::simd_utils::pack_array;
use crate::symmetric::prf::Pseudorandom;
use crate::symmetric::tweak_hash::poseidon::poseidon_compress;
use crate::PackedF;

/// Function to encode a message as an array of field elements
pub fn encode_message<const MSG_LEN_FE: usize>(message: &[u8; MESSAGE_LENGTH]) -> [F; MSG_LEN_FE] {
Expand Down Expand Up @@ -241,6 +243,59 @@ where

Ok(decode_to_chunks::<DIMENSION, BASE, HASH_LEN_FE>(&hash_fe).to_vec())
}

fn grind_target_sum<PRF, const TARGET_SUM: usize>(
parameter: &Self::Parameter,
prf_key: &PRF::Key,
epoch: u32,
message: &[u8; MESSAGE_LENGTH],
max_tries: usize,
) -> Option<(Self::Randomness, Vec<u8>)>
where
PRF: Pseudorandom,
PRF::Randomness: Into<Self::Randomness>,
{
let perm = poseidon1_24();
let lanes = PackedF::WIDTH;

let packed_message: [PackedF; MSG_LEN_FE] =
encode_message::<MSG_LEN_FE>(message).map(PackedF::from);
let packed_parameter: [PackedF; PARAMETER_LEN] = array::from_fn(|i| PackedF::from(parameter[i]));
let packed_epoch: [PackedF; TWEAK_LEN_FE] = encode_epoch::<TWEAK_LEN_FE>(epoch).map(PackedF::from);

for batch_start in (0..max_tries).step_by(lanes) {
let valid_lanes = (max_tries - batch_start).min(lanes);
let randomnesses: [FieldArray<RAND_LEN_FE>; PackedF::WIDTH] = array::from_fn(|lane| {
let attempt = batch_start + lane.min(valid_lanes.saturating_sub(1));
PRF::get_randomness(prf_key, epoch, message, attempt as u64).into()
});
let packed_randomness = pack_array(&randomnesses);

let combined_input_len = MSG_LEN_FE + PARAMETER_LEN + TWEAK_LEN_FE + RAND_LEN_FE;
let mut combined_input = Vec::with_capacity(combined_input_len);
combined_input.extend_from_slice(&packed_message);
combined_input.extend_from_slice(&packed_parameter);
combined_input.extend_from_slice(&packed_epoch);
combined_input.extend_from_slice(&packed_randomness);

let packed_hash =
poseidon_compress::<PackedF, _, 24, HASH_LEN_FE>(&perm, &combined_input);

let mut unpacked_hashes = [FieldArray([F::ZERO; HASH_LEN_FE]); PackedF::WIDTH];
PackedF::unpack_into(&packed_hash, FieldArray::as_raw_slice_mut(&mut unpacked_hashes));

for lane in 0..valid_lanes {
let chunks =
decode_to_chunks::<DIMENSION, BASE, HASH_LEN_FE>(&unpacked_hashes[lane].0)
.to_vec();
if chunks.iter().map(|&chunk| chunk as usize).sum::<usize>() == TARGET_SUM {
return Some((randomnesses[lane], chunks));
}
}
}

None
}
}

// Example instantiations
Expand Down