From 97f04502a9e4dad98a8bfdc11253d502aa11c730 Mon Sep 17 00:00:00 2001 From: Adam Mohammed A Latif Date: Thu, 26 Mar 2026 16:21:39 +0000 Subject: [PATCH] fix: reduce PRF-to-field bias to meet 128-bit security target The previous 16-byte (128-bit) modular reduction gave statistical distance ~2^{-97} per field element, below the 2^{-128} target matching SHAKE128's security level. Per RFC 9380 (hash-to-field), L = ceil((ceil(log2(p)) + k) / 8) where k is the security parameter. For KoalaBear (31-bit prime) at k = 128: L = ceil((31 + 128) / 8) = 20 bytes. This change: - Increases PRF_BYTES_PER_FE from 16 to 20 - Adds a reduce_160_to_field() helper using native u128 arithmetic (no BigUint in the hot path) - Adds tests: BigUint reference equivalence, RFC byte-count derivation, and boundary values Deterministic-output breaking change: existing keys will produce different chain starts. Expected for pre-production research code. Partially addresses #10. --- src/symmetric/prf/shake_to_field.rs | 83 +++++++++++++++++++++++------ 1 file changed, 68 insertions(+), 15 deletions(-) diff --git a/src/symmetric/prf/shake_to_field.rs b/src/symmetric/prf/shake_to_field.rs index 28c794f..f5ff2da 100644 --- a/src/symmetric/prf/shake_to_field.rs +++ b/src/symmetric/prf/shake_to_field.rs @@ -1,15 +1,39 @@ use crate::F; use super::Pseudorandom; -use p3_field::PrimeCharacteristicRing; +use p3_field::{PrimeCharacteristicRing, PrimeField64}; use serde::{Serialize, de::DeserializeOwned}; use sha3::{ Shake128, digest::{ExtendableOutput, Update, XofReader}, }; -/// Number of pseudorandom bytes to generate one pseudorandom field element -const PRF_BYTES_PER_FE: usize = 16; +/// Number of pseudorandom bytes to generate one pseudorandom field element. +/// +/// Per RFC 9380 (hash-to-field), L = ceil((ceil(log2(p)) + k) / 8) where k is +/// the security parameter. For KoalaBear (p = 2^31 - 2^24 + 1, ceil(log2(p)) = 31) +/// and k = 128 (matching SHAKE128): L = ceil((31 + 128) / 8) = 20. +/// +/// This gives a statistical distance from uniform of at most p / 2^161 < 2^{-129}, +/// meeting the 128-bit security target. +const PRF_BYTES_PER_FE: usize = 20; + +/// Reduce a 160-bit big-endian value to a field element with negligible bias. +/// +/// Splits the 20-byte input into a 128-bit high part and a 32-bit low part, +/// then computes (hi * 2^32 + lo) mod p using native u128 arithmetic. +#[inline] +fn reduce_160_to_field(buf: &[u8; PRF_BYTES_PER_FE]) -> F { + let hi = u128::from_be_bytes(buf[..16].try_into().unwrap()); + let lo = u32::from_be_bytes(buf[16..20].try_into().unwrap()) as u128; + + let p = F::ORDER_U64 as u128; + let hi_mod = hi % p; + let two_32_mod_p = (1u128 << 32) % p; + + let reduced = (hi_mod * two_32_mod_p + lo) % p; + F::from_u64(reduced as u64) +} const KEY_LENGTH: usize = 32; // 32 bytes const PRF_DOMAIN_SEP: [u8; 16] = [ @@ -61,14 +85,9 @@ where // Mapping bytes to field elements std::array::from_fn(|_| { - // Buffer to store the output let mut buf = [0u8; PRF_BYTES_PER_FE]; - - // Read the extended output into the buffer xof_reader.read(&mut buf); - - // Mapping bytes to a field element - F::from_u128(u128::from_be_bytes(buf)) + reduce_160_to_field(&buf) }) } @@ -105,14 +124,9 @@ where // Mapping bytes to field elements std::array::from_fn(|_| { - // Buffer to store the output let mut buf = [0u8; PRF_BYTES_PER_FE]; - - // Read the extended output into the buffer xof_reader.read(&mut buf); - - // Mapping bytes to a field element - F::from_u128(u128::from_be_bytes(buf)) + reduce_160_to_field(&buf) }) } } @@ -121,6 +135,8 @@ where mod tests { use super::*; use crate::MESSAGE_LENGTH; + use num_bigint::BigUint; + use p3_field::PrimeField64; use proptest::prelude::*; const DOMAIN_LEN: usize = 4; @@ -209,5 +225,42 @@ mod tests { let other_epoch = PRF::get_randomness(&key, epoch.wrapping_add(1), &msg, counter1); prop_assert_ne!(result1, other_epoch); } + + #[test] + fn proptest_reduce_160_matches_bigint_reference( + bytes in prop::array::uniform20(any::()) + ) { + let fast = reduce_160_to_field(&bytes); + + let value = BigUint::from_bytes_be(&bytes); + let p = BigUint::from(F::ORDER_U64); + let expected_u64: u64 = (value % p).try_into().unwrap(); + let reference = F::from_u64(expected_u64); + + prop_assert_eq!(fast, reference); + } + } + + #[test] + fn test_prf_bytes_per_fe_matches_rfc9380() { + let ceil_log2_p = 64 - (F::ORDER_U64 - 1).leading_zeros() as usize; + let k = 128; + let expected_l = (ceil_log2_p + k).div_ceil(8); + assert_eq!( + PRF_BYTES_PER_FE, expected_l, + "PRF_BYTES_PER_FE should be L = ceil((ceil(log2(p)) + k) / 8) per RFC 9380" + ); + } + + #[test] + fn test_reduce_160_boundary_values() { + let all_zeros = [0u8; PRF_BYTES_PER_FE]; + assert_eq!(reduce_160_to_field(&all_zeros), F::from_u64(0)); + + let all_ones = [0xff; PRF_BYTES_PER_FE]; + let value = BigUint::from_bytes_be(&all_ones); + let p = BigUint::from(F::ORDER_U64); + let expected: u64 = (value % p).try_into().unwrap(); + assert_eq!(reduce_160_to_field(&all_ones), F::from_u64(expected)); } }