diff --git a/README.md b/README.md
index e46ecedd5..14c27a99d 100644
--- a/README.md
+++ b/README.md
@@ -81,9 +81,7 @@ cargo run --release -- fancy-aggregation
 
 ### XMSS
 
-Currently, we use an [XMSS](crates/xmss/xmss.md) with hash digests of 4 field elements ≈ 124 bits. Tweaks and public parameters ensure domain separation. An analysis in the ROM (resp. QROM), inspired by the section 3.1 of [Tight adaptive reprogramming in the QROM](https://arxiv.org/pdf/2010.15103) would lead to ≈ 124 (resp. 62) bits of classical (resp. quantum) security. Going to 128 / 64 bits of classical / quantum security, i.e. NIST level 1 (in the ROM/QROM), is an ongoing effort. It requires either:
-- hash digests of 5 field elements (drawback: we need to double the hash chain length from 8 to 16 if we want to stay below one IPv6 MTU = 1280 bytes)
-- a new prime, close to 32 bits (typically p = 125.2^25 + 1) or 64 bits ([goldilocks](https://2π.com/22/goldilocks/))
+Currently, we use an [XMSS](crates/xmss/xmss.md) with hash digests of 4 field elements ≈ 128 bits. Tweaks and public parameters ensure domain separation. An analysis in the ROM (resp. QROM), inspired by Section 3.1 of [Tight adaptive reprogramming in the QROM](https://arxiv.org/pdf/2010.15103), would lead to ≈ 128 (resp. 64) bits of classical (resp. quantum) security. Note that a security analysis in the ROM / QROM is not the most conservative. In particular, [eprint 2025/055](https://eprint.iacr.org/2025/055.pdf)'s security proof holds in the standard model (at the cost of bigger hash digests): the implementation is available in the [leanSig](https://github.com/leanEthereum/leanSig) repository. A compatible version of leanMultisig can be found in the [devnet4](https://github.com/leanEthereum/leanMultisig/tree/devnet4) branch.
diff --git a/crates/backend/field/src/exponentiation.rs b/crates/backend/field/src/exponentiation.rs
index 2e9f567e4..3a71e8f7a 100644
--- a/crates/backend/field/src/exponentiation.rs
+++ b/crates/backend/field/src/exponentiation.rs
@@ -8,7 +8,7 @@ pub(crate) const fn bits_u64(n: u64) -> usize {
 
 /// Compute the exponential `x -> x^1420470955` using a custom addition chain.
 ///
-/// This map computes the third root of `x` if `x` is a member of the field `KoalaBear`.
+/// This map computes the third root of `x` if `x` is a member of the old KoalaBear field (p = 2^31 - 2^24 + 1).
 /// This follows from the computation: `3 * 1420470955 = 2*(2^31 - 2^24) + 1 = 1 mod (p - 1)`.
 #[must_use]
 pub fn exp_1420470955<R: PrimeCharacteristicRing>(val: R) -> R {
@@ -30,3 +30,29 @@ pub fn exp_1420470955<R: PrimeCharacteristicRing>(val: R) -> R {
     let p1010100101010101010101010101010 = p101010010101010101010101010101.square();
     p1010100101010101010101010101010 * p1
 }
+
+/// Compute the exponential `x -> x^2796202667` using a custom addition chain.
+///
+/// This map computes the third root of `x` if `x` is a member of the KoalaBear field (p = 125 * 2^25 + 1).
+/// This follows from the computation: `3 * 2796202667 = 2*(p - 1) + 1 = 1 mod (p - 1)`.
+#[must_use]
+pub fn exp_2796202667<R: PrimeCharacteristicRing>(val: R) -> R {
+    // 2796202667 = 10100110101010101010101010101011_2
+    // Uses 30 Squares + 12 Multiplications => 42 Operations total.
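+    // Hypothetical sanity check of the exponent (not part of the chain itself):
+    // with p = 125 * 2^25 + 1, (3 * 2796202667) mod (p - 1) = 8388608001 mod 4194304000 = 1,
+    // so cubing the output of this map returns the input.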
+    let p1 = val;
+    let p10 = p1.square();
+    let p11 = p10 * p1;
+    let p101 = p10 * p11;
+    let p1010 = p101.square();
+    let p10100 = p1010.square();
+    let p101001 = p10100.square() * p1;
+    let p10100110 = p101001.exp_power_of_2(2) * p10;
+    let p101001101 = p10100110.square() * p1;
+    let p10100110101 = p101001101.exp_power_of_2(2) * p1;
+    let p1010011010101 = p10100110101.exp_power_of_2(2) * p1;
+    let p10100110101010101 = p1010011010101.exp_power_of_2(4) * p101;
+    let p101001101010101010101 = p10100110101010101.exp_power_of_2(4) * p101;
+    let p1010011010101010101010101 = p101001101010101010101.exp_power_of_2(4) * p101;
+    let p10100110101010101010101010101 = p1010011010101010101010101.exp_power_of_2(4) * p101;
+    p10100110101010101010101010101.exp_power_of_2(3) * p11
+}
diff --git a/crates/backend/field/src/packed/aarch64_neon.rs b/crates/backend/field/src/packed/aarch64_neon.rs
index 9e65bde11..67227c669 100644
--- a/crates/backend/field/src/packed/aarch64_neon.rs
+++ b/crates/backend/field/src/packed/aarch64_neon.rs
@@ -25,72 +25,57 @@ fn uint32x4_to_array(input: uint32x4_t) -> [u32; 4] {
 /// Add the packed vectors `a` and `b` modulo `p`.
 ///
-/// This allows us to add 4 elements at once.
-///
-/// Assumes that `p` is less than `2^31` and `a + b <= 2P`.
-/// If the inputs are not in this range, the result may be incorrect.
-/// The result will be in the range `[0, P]` and equal to `(a + b) mod p`.
-/// It will be equal to `P` if and only if `a + b = 2P` so provided `a + b < 2P`
-/// the result is guaranteed to be less than `P`.
+/// Assumes `a, b` are in `[0, P)` where `P < 2^32`. The result will be in `[0, P)`.
+/// Works for any P < 2^32, including P > 2^31 where a + b may overflow u32.
 #[inline]
 #[must_use]
 pub fn uint32x4_mod_add(a: uint32x4_t, b: uint32x4_t, p: uint32x4_t) -> uint32x4_t {
-    // We want this to compile to:
-    //      add   t.4s, a.4s, b.4s
-    //      sub   u.4s, t.4s, P.4s
-    //      umin  res.4s, t.4s, u.4s
-    // throughput: .75 cyc/vec (5.33 els/cyc)
-    // latency: 6 cyc
-
-    // See field/src/packed/x86_64_avx.rs for a proof of correctness of this algorithm.
-
+    // Uses saturating add to detect "a + b >= P" in one comparison:
+    // sat = min(a+b, 2^32-1). If a+b >= 2^32, sat = 2^32-1 >= P. If a+b < 2^32, sat = a+b.
+    // Either way, sat >= P iff a+b >= P.
+    //
+    //      add   t.4s, a.4s, b.4s        // wrapping add
+    //      sub   u.4s, t.4s, P.4s        // wrapping sub P
+    //      uqadd sat.4s, a.4s, b.4s      // saturating add
+    //      cmhs  mask.4s, sat.4s, P.4s   // sat >= P ?
+    //      bsl   mask.4s, u.4s, t.4s     // select
+    // throughput: 1.25 cyc/vec (3.2 els/cyc)
+    // latency: 8 cyc
     unsafe {
-        // Safety: If this code got compiled then NEON intrinsics are available.
         let t = aarch64::vaddq_u32(a, b);
         let u = aarch64::vsubq_u32(t, p);
-        aarch64::vminq_u32(t, u)
+        let sat = aarch64::vqaddq_u32(a, b); // saturating: min(a+b, 2^32-1)
+        let mask = aarch64::vcgeq_u32(sat, p); // sat >= P iff a+b >= P
+        aarch64::vbslq_u32(mask, u, t)
     }
 }
 
 /// Subtract the packed vectors `a` and `b` modulo `p`.
 ///
-/// This allows us to subtract 4 elements at once.
-///
-/// Assumes that `p` is less than `2^31` and `|a - b| <= P`.
-/// If the inputs are not in this range, the result may be incorrect.
-/// The result will be in the range `[0, P]` and equal to `(a - b) mod p`.
-/// It will be equal to `P` if and only if `a - b = P` so provided `a - b < P`
-/// the result is guaranteed to be less than `P`.
+/// Assumes `a, b` are in `[0, P)` where `P < 2^32`. The result will be in `[0, P)`.
+/// Works for any P < 2^32, including P > 2^31.
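+///
+/// A scalar model of the same algorithm, one lane (hypothetical, for illustration only):
+/// ```ignore
+/// fn mod_sub_scalar(a: u32, b: u32, p: u32) -> u32 {
+///     let t = a.wrapping_sub(b);
+///     if b > a { t.wrapping_add(p) } else { t } // add P back when the subtraction borrowed
+/// }
+/// ```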
#[inline] #[must_use] pub fn uint32x4_mod_sub(a: uint32x4_t, b: uint32x4_t, p: uint32x4_t) -> uint32x4_t { - // We want this to compile to: - // sub t.4s, a.4s, b.4s - // add u.4s, t.4s, P.4s - // umin res.4s, t.4s, u.4s - // throughput: .75 cyc/vec (5.33 els/cyc) - // latency: 6 cyc - - // See field/src/packed/x86_64_avx.rs for a proof of correctness of this algorithm. - + // Algorithm: t = a - b (wrapping). If a < b (borrow), result = t + P; otherwise result = t. + // + // sub t.4s, a.4s, b.4s + // cmhi borrow.4s, b.4s, a.4s // b > a means borrow + // and corr.4s, borrow.4s, P.4s + // add res.4s, t.4s, corr.4s + // throughput: 1 cyc/vec (4 els/cyc) + // latency: 8 cyc unsafe { - // Safety: If this code got compiled then NEON intrinsics are available. let t = aarch64::vsubq_u32(a, b); - let u = aarch64::vaddq_u32(t, p); - aarch64::vminq_u32(t, u) + let borrow = aarch64::vcgtq_u32(b, a); // b > a means borrow + let corr = aarch64::vandq_u32(borrow, p); + aarch64::vaddq_u32(t, corr) } } /// Add two arrays of integers modulo `P` using packings. /// -/// Assumes that `P` is less than `2^31` and `a + b <= 2P` for all array pairs `a, b`. -/// If the inputs are not in this range, the result may be incorrect. -/// The result will be in the range `[0, P]` and equal to `(a + b) mod P`. -/// It will be equal to `P` if and only if `a + b = 2P` so provided `a + b < 2P` -/// the result is guaranteed to be less than `P`. -/// -/// Scalar add is assumed to be a function which implements `a + b % P` with the -/// same specifications as above. +/// Assumes `a, b` are in `[0, P)` where `P < 2^32`. Works for P > 2^31. /// /// TODO: Add support for extensions of degree 2,3,6,7. #[inline(always)] @@ -152,14 +137,7 @@ pub fn packed_mod_add( /// Subtract two arrays of integers modulo `P` using packings. /// -/// Assumes that `p` is less than `2^31` and `|a - b| <= P`. -/// If the inputs are not in this range, the result may be incorrect. -/// The result will be in the range `[0, P]` and equal to `(a - b) mod p`. -/// It will be equal to `P` if and only if `a - b = P` so provided `a - b < P` -/// the result is guaranteed to be less than `P`. -/// -/// Scalar sub is assumed to be a function which implements `a - b % P` with the -/// same specifications as above. +/// Assumes `a, b` are in `[0, P)` where `P < 2^32`. Works for P > 2^31. /// /// TODO: Add support for extensions of degree 2,3,6,7. #[inline(always)] diff --git a/crates/backend/field/src/packed/x86_64_avx.rs b/crates/backend/field/src/packed/x86_64_avx.rs index 66a9ab241..d16698510 100644 --- a/crates/backend/field/src/packed/x86_64_avx.rs +++ b/crates/backend/field/src/packed/x86_64_avx.rs @@ -7,189 +7,130 @@ use core::arch::x86_64::__m512i; use core::arch::x86_64::{self, __m128i, __m256i}; use core::mem::transmute; -// Goal: Compute r = lhs + rhs mod P for lhs, rhs <= P < 2^31 -// Output should mostly lie in [0, P) but is allowed to equal P if lhs = rhs = P. +// Modular addition/subtraction for P < 2^32. // -// Let t := lhs + rhs. Clearly t \in [0, 2P] -// Define u := (t - P) mod 2^32 and r := min(t, u) (Note that it is crucial this is an unsigned min) -// We argue by cases. -// - If t is in [0, P), then due to wraparound, u is in [2^32 - P, 2^32 - 1). As -// 2^32 - P > P - 1, we conclude that r = t lies in the correct range. -// - If t is in [P, 2 P], then u is in [0, P] and r = u lies in the correct range. -// As both t and u are both equal to lhs + rhs mod P, we conclude that -// r = (lhs + rhs) mod P and lies in the correct range. 
+// For add: t = a + b (wrapping u32). If overflow occurred or t >= P, subtract P. +// Overflow detection: t < a (unsigned) means carry. // -// An identical idea works for subtraction. -// Set t := lhs - rhs, u := t + P and output r := min(t, u). +// For sub: t = a - b (wrapping u32). If borrow (a < b), add P. +// These algorithms work for any P < 2^32. -/// Add the packed vectors `a` and `b` modulo `p`. +/// Add the packed vectors `a` and `b` modulo `p` (SSE, 4 elements). /// -/// This allows us to add 4 elements at once. -/// -/// Assumes that `p` is less than `2^31` and `a + b <= 2P`. -/// If the inputs are not in this range, the result may be incorrect. -/// The result will be in the range `[0, P]` and equal to `(a + b) mod p`. -/// It will be equal to `P` if and only if `a + b = 2P`. +/// Assumes `a, b` in `[0, P)` where `P < 2^32`. Works for P > 2^31. #[inline(always)] #[must_use] fn mm128_mod_add(a: __m128i, b: __m128i, p: __m128i) -> __m128i { - // We want this to compile to: - // paddd t, lhs, rhs - // psubd u, t, P - // pminud res, t, u - // throughput: 1 cyc/vec (8 els/cyc) - // latency: 3 cyc - unsafe { let t = x86_64::_mm_add_epi32(a, b); let u = x86_64::_mm_sub_epi32(t, p); - x86_64::_mm_min_epu32(t, u) + // Detect carry: flip sign bits for unsigned comparison + let flip = x86_64::_mm_set1_epi32(i32::MIN); + let a_f = x86_64::_mm_xor_si128(a, flip); + let t_f = x86_64::_mm_xor_si128(t, flip); + let overflow = x86_64::_mm_cmpgt_epi32(a_f, t_f); // a > t unsigned → overflow + let t_f2 = t_f; // reuse + let p_m1_f = x86_64::_mm_xor_si128(x86_64::_mm_sub_epi32(p, x86_64::_mm_set1_epi32(1)), flip); + let geq_p = x86_64::_mm_cmpgt_epi32(t_f2, p_m1_f); // t > p-1 unsigned → t >= p + let mask = x86_64::_mm_or_si128(overflow, geq_p); + // blend: (mask & u) | (~mask & t) + let sel_u = x86_64::_mm_and_si128(mask, u); + let sel_t = x86_64::_mm_andnot_si128(mask, t); + x86_64::_mm_or_si128(sel_u, sel_t) } } -/// Subtract the packed vectors `a` and `b` modulo `p`. +/// Subtract the packed vectors `a` and `b` modulo `p` (SSE, 4 elements). /// -/// This allows us to subtract 4 elements at once. -/// -/// Assumes that `p` is less than `2^31` and `|a - b| <= P`. -/// If the inputs are not in this range, the result may be incorrect. -/// The result will be in the range `[0, P]` and equal to `(a - b) mod p`. -/// It will be equal to `P` if and only if `a - b = P` so provided `a - b < P` -/// the result is guaranteed to be less than `P`. +/// Assumes `a, b` in `[0, P)` where `P < 2^32`. Works for P > 2^31. #[inline(always)] #[must_use] fn mm128_mod_sub(a: __m128i, b: __m128i, p: __m128i) -> __m128i { - // We want this to compile to: - // psubd t, lhs, rhs - // paddd u, t, P - // pminud res, t, u - // throughput: 1 cyc/vec (8 els/cyc) - // latency: 3 cyc - unsafe { let t = x86_64::_mm_sub_epi32(a, b); - let u = x86_64::_mm_add_epi32(t, p); - x86_64::_mm_min_epu32(t, u) + // Detect borrow: b > a unsigned + let flip = x86_64::_mm_set1_epi32(i32::MIN); + let a_f = x86_64::_mm_xor_si128(a, flip); + let b_f = x86_64::_mm_xor_si128(b, flip); + let borrow = x86_64::_mm_cmpgt_epi32(b_f, a_f); + let corr = x86_64::_mm_and_si128(borrow, p); + x86_64::_mm_add_epi32(t, corr) } } -/// Add the packed vectors `a` and `b` modulo `p`. +/// Add the packed vectors `lhs` and `rhs` modulo `p` (AVX2, 8 elements). /// -/// This allows us to add 8 elements at once. -/// -/// Assumes that `p` is less than `2^31` and `a + b <= 2P`. -/// If the inputs are not in this range, the result may be incorrect. 
-/// The result will be in the range `[0, P]` and equal to `(a + b) mod p`. -/// It will be equal to `P` if and only if `a + b = 2P` so provided `a + b < 2P` -/// the result is guaranteed to be less than `P`. +/// Assumes `a, b` in `[0, P)` where `P < 2^32`. Works for P > 2^31. #[inline(always)] #[must_use] pub fn mm256_mod_add(lhs: __m256i, rhs: __m256i, p: __m256i) -> __m256i { - // We want this to compile to: - // vpaddd t, lhs, rhs - // vpsubd u, t, P - // vpminud res, t, u - // throughput: 1 cyc/vec (8 els/cyc) - // latency: 3 cyc - unsafe { let t = x86_64::_mm256_add_epi32(lhs, rhs); let u = x86_64::_mm256_sub_epi32(t, p); - x86_64::_mm256_min_epu32(t, u) + let flip = x86_64::_mm256_set1_epi32(i32::MIN); + let lhs_f = x86_64::_mm256_xor_si256(lhs, flip); + let t_f = x86_64::_mm256_xor_si256(t, flip); + let overflow = x86_64::_mm256_cmpgt_epi32(lhs_f, t_f); + let p_m1_f = x86_64::_mm256_xor_si256(x86_64::_mm256_sub_epi32(p, x86_64::_mm256_set1_epi32(1)), flip); + let geq_p = x86_64::_mm256_cmpgt_epi32(t_f, p_m1_f); + let mask = x86_64::_mm256_or_si256(overflow, geq_p); + x86_64::_mm256_blendv_epi8(t, u, mask) } } -/// Subtract the packed vectors `a` and `b` modulo `p`. -/// -/// This allows us to subtract 8 elements at once. +/// Subtract the packed vectors `lhs` and `rhs` modulo `p` (AVX2, 8 elements). /// -/// Assumes that `p` is less than `2^31` and `|a - b| <= P`. -/// If the inputs are not in this range, the result may be incorrect. -/// The result will be in the range `[0, P]` and equal to `(a - b) mod p`. -/// It will be equal to `P` if and only if `a - b = P` so provided `a - b < P` -/// the result is guaranteed to be less than `P`. +/// Assumes `a, b` in `[0, P)` where `P < 2^32`. Works for P > 2^31. #[inline(always)] #[must_use] pub fn mm256_mod_sub(lhs: __m256i, rhs: __m256i, p: __m256i) -> __m256i { - // We want this to compile to: - // vpsubd t, lhs, rhs - // vpaddd u, t, P - // vpminud res, t, u - // throughput: 1 cyc/vec (8 els/cyc) - // latency: 3 cyc - unsafe { let t = x86_64::_mm256_sub_epi32(lhs, rhs); - let u = x86_64::_mm256_add_epi32(t, p); - x86_64::_mm256_min_epu32(t, u) + let flip = x86_64::_mm256_set1_epi32(i32::MIN); + let lhs_f = x86_64::_mm256_xor_si256(lhs, flip); + let rhs_f = x86_64::_mm256_xor_si256(rhs, flip); + let borrow = x86_64::_mm256_cmpgt_epi32(rhs_f, lhs_f); + let corr = x86_64::_mm256_and_si256(borrow, p); + x86_64::_mm256_add_epi32(t, corr) } } #[cfg(target_feature = "avx512f")] -/// Add the packed vectors `a` and `b` modulo `p`. -/// -/// This allows us to add 16 elements at once. +/// Add the packed vectors `lhs` and `rhs` modulo `p` (AVX-512, 16 elements). /// -/// Assumes that `p` is less than `2^31` and `a + b <= 2P`. -/// If the inputs are not in this range, the result may be incorrect. -/// The result will be in the range `[0, P]` and equal to `(a + b) mod p`. -/// It will be equal to `P` if and only if `a + b = 2P` so provided `a + b < 2P` -/// the result is guaranteed to be less than `P`. +/// Assumes `a, b` in `[0, P)` where `P < 2^32`. Works for P > 2^31. 
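+///
+/// A scalar model of one lane (hypothetical, for illustration only):
+/// ```ignore
+/// fn mod_add_scalar(a: u32, b: u32, p: u32) -> u32 {
+///     let t = a.wrapping_add(b);
+///     // a carry out of u32 (t < a) or t >= p both mean we must subtract p
+///     if t < a || t >= p { t.wrapping_sub(p) } else { t }
+/// }
+/// ```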
#[inline(always)] #[must_use] pub fn mm512_mod_add(lhs: __m512i, rhs: __m512i, p: __m512i) -> __m512i { - // We want this to compile to: - // vpaddd t, lhs, rhs - // vpsubd u, t, P - // vpminud res, t, u - // throughput: 1.5 cyc/vec (10.67 els/cyc) - // latency: 3 cyc - unsafe { let t = x86_64::_mm512_add_epi32(lhs, rhs); let u = x86_64::_mm512_sub_epi32(t, p); - x86_64::_mm512_min_epu32(t, u) + // AVX-512 has native unsigned comparison + let overflow = x86_64::_mm512_cmplt_epu32_mask(t, lhs); // t < lhs → overflow + let geq_p = x86_64::_mm512_cmpge_epu32_mask(t, p); // t >= P + let mask = overflow | geq_p; + x86_64::_mm512_mask_mov_epi32(t, mask, u) } } #[cfg(target_feature = "avx512f")] -/// Subtract the packed vectors `a` and `b` modulo `p`. -/// -/// This allows us to subtract 16 elements at once. +/// Subtract the packed vectors `lhs` and `rhs` modulo `p` (AVX-512, 16 elements). /// -/// Assumes that `p` is less than `2^31` and `|a - b| <= P`. -/// If the inputs are not in this range, the result may be incorrect. -/// The result will be in the range `[0, P]` and equal to `(a - b) mod p`. -/// It will be equal to `P` if and only if `a - b = P` so provided `a - b < P` -/// the result is guaranteed to be less than `P`. +/// Assumes `a, b` in `[0, P)` where `P < 2^32`. Works for P > 2^31. #[inline(always)] #[must_use] pub fn mm512_mod_sub(lhs: __m512i, rhs: __m512i, p: __m512i) -> __m512i { - // We want this to compile to: - // vpsubd t, lhs, rhs - // vpaddd u, t, P - // vpminud res, t, u - // throughput: 1.5 cyc/vec (10.67 els/cyc) - // latency: 3 cyc - unsafe { - // Safety: If this code got compiled then AVX-512F intrinsics are available. let t = x86_64::_mm512_sub_epi32(lhs, rhs); - let u = x86_64::_mm512_add_epi32(t, p); - x86_64::_mm512_min_epu32(t, u) + let borrow = x86_64::_mm512_cmplt_epu32_mask(lhs, rhs); // lhs < rhs → borrow + let p_masked = x86_64::_mm512_maskz_mov_epi32(borrow, p); + x86_64::_mm512_add_epi32(t, p_masked) } } /// Add two arrays of integers modulo `P` using packings. /// -/// Assumes that `P` is less than `2^31` and `a + b <= 2P` for all array pairs `a, b`. -/// If the inputs are not in this range, the result may be incorrect. -/// The result will be in the range `[0, P]` and equal to `(a + b) mod P`. -/// It will be equal to `P` if and only if `a + b = 2P` so provided `a + b < 2P` -/// the result is guaranteed to be less than `P`. -/// -/// Scalar add is assumed to be a function which implements `a + b % P` with the -/// same specifications as above. +/// Assumes `a, b` are in `[0, P)` where `P < 2^32`. Works for P > 2^31. /// /// TODO: Add support for extensions of degree 2,3,6,7. #[inline(always)] @@ -248,14 +189,7 @@ pub fn packed_mod_add( /// Subtract two arrays of integers modulo `P` using packings. /// -/// Assumes that `p` is less than `2^31` and `|a - b| <= P`. -/// If the inputs are not in this range, the result may be incorrect. -/// The result will be in the range `[0, P]` and equal to `(a - b) mod p`. -/// It will be equal to `P` if and only if `a - b = P` so provided `a - b < P` -/// the result is guaranteed to be less than `P`. -/// -/// Scalar sub is assumed to be a function which implements `a - b % P` with the -/// same specifications as above. +/// Assumes `a, b` are in `[0, P)` where `P < 2^32`. Works for P > 2^31. /// /// TODO: Add support for extensions of degree 2,3,6,7. 
#[inline(always)] diff --git a/crates/backend/koala-bear/src/aarch64_neon/packing.rs b/crates/backend/koala-bear/src/aarch64_neon/packing.rs index 1dfc6bb6b..1769aa369 100644 --- a/crates/backend/koala-bear/src/aarch64_neon/packing.rs +++ b/crates/backend/koala-bear/src/aarch64_neon/packing.rs @@ -10,9 +10,9 @@ use crate::KoalaBearParameters; const WIDTH: usize = 4; impl MontyParametersNeon for KoalaBearParameters { - const PACKED_P: uint32x4_t = unsafe { transmute::<[u32; WIDTH], _>([0x7f000001; WIDTH]) }; - // This MU is the same 0x88000001 as elsewhere, just interpreted as an `i32`. - const PACKED_MU: int32x4_t = unsafe { transmute::<[i32; WIDTH], _>([-0x7effffff; WIDTH]) }; + const PACKED_P: uint32x4_t = unsafe { transmute::<[u32; WIDTH], _>([0xfa000001; WIDTH]) }; + // MU = 0x06000001, which fits in positive i32. + const PACKED_MU: int32x4_t = unsafe { transmute::<[i32; WIDTH], _>([0x06000001_i32; WIDTH]) }; } pub type PackedKoalaBearNeon = PackedMontyField31Neon; diff --git a/crates/backend/koala-bear/src/koala_bear.rs b/crates/backend/koala-bear/src/koala_bear.rs index 116c3d098..f10df49f3 100644 --- a/crates/backend/koala-bear/src/koala_bear.rs +++ b/crates/backend/koala-bear/src/koala_bear.rs @@ -5,24 +5,23 @@ use crate::monty_31::{ TwoAdicData, }; use field::PrimeCharacteristicRing; -use field::exponentiation::exp_1420470955; +use field::exponentiation::exp_2796202667; -/// The prime field `2^31 - 2^24 + 1`, a.k.a. the Koala Bear field. +/// The prime field `125 * 2^25 + 1 = 4194304001`, a.k.a. the Koala Bear field. pub type KoalaBear = MontyField31; #[derive(Copy, Clone, Default, Debug, Eq, Hash, PartialEq)] pub struct KoalaBearParameters; impl MontyParameters for KoalaBearParameters { - /// The KoalaBear prime: 2^31 - 2^24 + 1 - /// This is a 31-bit prime with the highest possible two adicity if we additionally demand that - /// the cube map (x -> x^3) is an automorphism of the multiplicative group. - /// It's not unique, as there is one other option with equal 2 adicity: 2^30 + 2^27 + 2^24 + 1. - /// There is also one 29-bit prime with higher two adicity which might be appropriate for some applications: 2^29 - 2^26 + 1. - const PRIME: u32 = 0x7f000001; + /// The KoalaBear prime: 125 * 2^25 + 1 = 4194304001 + /// This is a 32-bit prime with two-adicity 25 and the cube map (x -> x^3) is an + /// automorphism of the multiplicative group (since gcd(3, p-1) = 1). + /// Note: the sum of 2 elements does NOT fit in a u32, requiring u64 intermediates for addition. + const PRIME: u32 = 0xfa000001; const MONTY_BITS: u32 = 32; - const MONTY_MU: u32 = 0x81000001; + const MONTY_MU: u32 = 0x06000001; } impl PackedMontyParameters for KoalaBearParameters {} @@ -34,32 +33,32 @@ impl FieldParameters for KoalaBearParameters { } impl RelativelyPrimePower<3> for KoalaBearParameters { - /// In the field `KoalaBear`, `a^{1/3}` is equal to a^{1420470955}. + /// In the field `KoalaBear`, `a^{1/3}` is equal to a^{2796202667}. /// - /// This follows from the calculation `3 * 1420470955 = 2*(2^31 - 2^24) + 1 = 1 mod (p - 1)`. + /// This follows from the calculation `3 * 2796202667 = 8388608001 = 2*(125 * 2^25) + 1 = 1 mod (p - 1)`. 
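+    ///
+    /// Hypothetical property check: for any `x`, cubing the output returns the input,
+    /// i.e. `exp_root_d(x) * exp_root_d(x) * exp_root_d(x) == x`.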
fn exp_root_d(val: R) -> R { - exp_1420470955(val) + exp_2796202667(val) } } impl TwoAdicData for KoalaBearParameters { - const TWO_ADICITY: usize = 24; + const TWO_ADICITY: usize = 25; type ArrayLike = &'static [KoalaBear]; const TWO_ADIC_GENERATORS: Self::ArrayLike = &KoalaBear::new_array([ - 0x1, 0x7f000000, 0x7e010002, 0x6832fe4a, 0x8dbd69c, 0xa28f031, 0x5c4a5b99, 0x29b75a80, 0x17668b8a, 0x27ad539b, - 0x334d48c7, 0x7744959c, 0x768fc6fa, 0x303964b2, 0x3e687d4d, 0x45a60e61, 0x6e2f4d7a, 0x163bd499, 0x6c4a8a45, - 0x143ef899, 0x514ddcad, 0x484ef19b, 0x205d63c3, 0x68e7dd49, 0x6ac49f88, + 0x1, 0xfa000000, 0x304096c9, 0x894b5390, 0x6b52061e, 0xad3c2449, 0x15fe844d, 0x78c80fc6, 0x6f53c3e8, + 0xbde222a7, 0xb8d15cfe, 0xeda3085d, 0x796cdd9b, 0xdb8107f4, 0x5e491875, 0xcf40ad0, 0x2526aeba, 0x1df6fb4c, + 0x2f221af1, 0x40593728, 0xcd1100d7, 0x64a4ed0b, 0x9782cd0e, 0xaf03bc88, 0x99d352d0, 0x4633584e, ]); - const ROOTS_8: Self::ArrayLike = &KoalaBear::new_array([0x1, 0x6832fe4a, 0x7e010002, 0x174e3650]); - const INV_ROOTS_8: Self::ArrayLike = &KoalaBear::new_array([0x1, 0x67b1c9b1, 0xfeffff, 0x16cd01b7]); + const ROOTS_8: Self::ArrayLike = &KoalaBear::new_array([0x1, 0x894b5390, 0x304096c9, 0x2e9b3a27]); + const INV_ROOTS_8: Self::ArrayLike = &KoalaBear::new_array([0x1, 0xcb64c5da, 0xc9bf6938, 0x70b4ac71]); const ROOTS_16: Self::ArrayLike = &KoalaBear::new_array([ - 0x1, 0x8dbd69c, 0x6832fe4a, 0x27ae21e2, 0x7e010002, 0x3a89a025, 0x174e3650, 0x27dfce22, + 0x1, 0x6b52061e, 0x894b5390, 0x39f910ef, 0x304096c9, 0x33c5a441, 0x2e9b3a27, 0x9d09df4b, ]); const INV_ROOTS_16: Self::ArrayLike = &KoalaBear::new_array([ - 0x1, 0x572031df, 0x67b1c9b1, 0x44765fdc, 0xfeffff, 0x5751de1f, 0x16cd01b7, 0x76242965, + 0x1, 0x5cf620b6, 0xcb64c5da, 0xc63a5bc0, 0xc9bf6938, 0xc006ef12, 0x70b4ac71, 0x8eadf9e3, ]); } diff --git a/crates/backend/koala-bear/src/monty_31/aarch64_neon/packing.rs b/crates/backend/koala-bear/src/monty_31/aarch64_neon/packing.rs index 2c2078d9b..2be8595dc 100644 --- a/crates/backend/koala-bear/src/monty_31/aarch64_neon/packing.rs +++ b/crates/backend/koala-bear/src/monty_31/aarch64_neon/packing.rs @@ -1,9 +1,7 @@ // Credits: Plonky3 (https://github.com/Plonky3/Plonky3) (MIT and Apache-2.0 licenses). use alloc::vec::Vec; -use core::arch::aarch64::{self, int32x4_t, uint32x4_t}; -use core::arch::asm; -use core::hint::unreachable_unchecked; +use core::arch::aarch64::{self, int32x4_t, uint32x4_t, uint64x2_t}; use core::iter::{Product, Sum}; use core::mem::transmute; use core::ops::{Add, AddAssign, Div, DivAssign, Mul, MulAssign, Neg, Sub, SubAssign}; @@ -73,21 +71,6 @@ impl PackedMontyField31Neon { } } - /// Get an arch-specific vector representing the packed values. - #[inline] - #[must_use] - pub(crate) fn to_signed_vector(self) -> int32x4_t { - unsafe { - // Safety: `MontyField31` is `repr(transparent)` so it can be transmuted to `u32` furthermore - // the u32 is guaranteed to be less than `2^31` so it can be safely reinterpreted as an `i32`. It - // follows that `[MontyField31; WIDTH]` can be transmuted to `[i32; WIDTH]`, which can be - // transmuted to `int32x4_t`, since arrays are guaranteed to be contiguous in memory. - // Finally `PackedMontyField31Neon` is `repr(transparent)` so it can be transmuted to - // `[MontyField31; WIDTH]`. - transmute(self) - } - } - /// Make a packed field vector from an arch-specific vector. /// /// SAFETY: The caller must ensure that each element of `vector` represents a valid `MontyField31`. 
@@ -165,11 +148,11 @@ impl Mul for PackedMontyField31Neon { type Output = Self; #[inline] fn mul(self, rhs: Self) -> Self { - let lhs = self.to_signed_vector(); - let rhs = rhs.to_signed_vector(); - let res = mul::(lhs, rhs); + let lhs = self.to_vector(); + let rhs = rhs.to_vector(); + let res = mul_unsigned::(lhs, rhs); unsafe { - // Safety: `mul` returns values in canonical form when given values in canonical form. + // Safety: `mul_unsigned` returns values in canonical form when given values in canonical form. Self::from_vector(res) } } @@ -206,10 +189,10 @@ impl PrimeCharacteristicRing for PackedMontyField31Neon #[inline] fn cube(&self) -> Self { - let val = self.to_signed_vector(); - let res = cube::(val); + let val = self.to_vector(); + let res = cube_unsigned::(val); unsafe { - // Safety: `cube` returns values in canonical form when given values in canonical form. + // Safety: `cube_unsigned` returns values in canonical form when given values in canonical form. Self::from_vector(res) } } @@ -236,19 +219,17 @@ impl PrimeCharacteristicRing for PackedMontyField31Neon 3 => self.cube(), 4 => self.square().square(), 5 => { - let val = self.to_signed_vector(); + let val = self.to_vector(); unsafe { - // Safety: `exp_5` returns values in canonical form when given values in canonical form. - let res = exp_5::(val); + let res = exp_5_unsigned::(val); Self::from_vector(res) } } 6 => self.square().cube(), 7 => { - let val = self.to_signed_vector(); + let val = self.to_vector(); unsafe { - // Safety: `exp_7` returns values in canonical form when given values in canonical form. - let res = exp_7::(val); + let res = exp_7_unsigned::(val); Self::from_vector(res) } } @@ -287,32 +268,6 @@ impl, const D: u64> PermutationMon } } -/// No-op. Prevents the compiler from deducing the value of the vector. -/// -/// Similar to `core::hint::black_box`, it can be used to stop the compiler applying undesirable -/// "optimizations". Unlike the built-in `black_box`, it does not force the value to be written to -/// and then read from the stack. -#[inline] -#[must_use] -fn confuse_compiler(x: uint32x4_t) -> uint32x4_t { - let y; - unsafe { - asm!( - "/*{0:v}*/", - inlateout(vreg) x => y, - options(nomem, nostack, preserves_flags, pure), - ); - // Below tells the compiler the semantics of this so it can still do constant folding, etc. - // You may ask, doesn't it defeat the point of the inline asm block to tell the compiler - // what it does? The answer is that we still inhibit the transform we want to avoid, so - // apparently not. Idk, LLVM works in mysterious ways. - if transmute::(x) != transmute::(y) { - unreachable_unchecked(); - } - } - y -} - // MONTGOMERY MULTIPLICATION // This implementation is based on [1] but with changes. The reduction is as follows: // @@ -341,223 +296,93 @@ fn confuse_compiler(x: uint32x4_t) -> uint32x4_t { // [1] Modern Computer Arithmetic, Richard Brent and Paul Zimmermann, Cambridge University Press, // 2010, algorithm 2.7. -#[inline] -#[must_use] -fn mulby_mu(val: int32x4_t) -> int32x4_t { - // We want this to compile to: - // mul res.4s, val.4s, MU.4s - // throughput: .25 cyc/vec (16 els/cyc) - // latency: 3 cyc - - unsafe { aarch64::vmulq_s32(val, MPNeon::PACKED_MU) } -} - -#[inline] -#[must_use] -fn get_c_hi(lhs: int32x4_t, rhs: int32x4_t) -> int32x4_t { - // We want this to compile to: - // sqdmulh c_hi.4s, lhs.4s, rhs.4s - // throughput: .25 cyc/vec (16 els/cyc) - // latency: 3 cyc - - unsafe { - // Get bits 31, ..., 62 of C. 
Note that `sqdmulh` saturates when the product doesn't fit in - // an `i63`, but this cannot happen here due to our bounds on `lhs` and `rhs`. - aarch64::vqdmulhq_s32(lhs, rhs) - } -} - -#[inline] -#[must_use] -fn get_qp_hi(lhs: int32x4_t, mu_rhs: int32x4_t) -> int32x4_t { - // We want this to compile to: - // mul q.4s, lhs.4s, mu_rhs.4s - // sqdmulh qp_hi.4s, q.4s, P.4s - // throughput: .5 cyc/vec (8 els/cyc) - // latency: 6 cyc - - unsafe { - // Form `Q`. - let q = aarch64::vmulq_s32(lhs, mu_rhs); - - // Gets bits 31, ..., 62 of Q P. Again, saturation is not an issue because `P` is not - // -2**31. - aarch64::vqdmulhq_s32(q, aarch64::vreinterpretq_s32_u32(MPNeon::PACKED_P)) - } -} - -/// Multiply MontyField31 field elements. +/// Montgomery reduction of a 64-bit product to canonical form [0, P). /// -/// # Safety -/// Inputs must be signed 32-bit integers in the range [-P, P]. -/// Outputs will be a unsigned 32-bit integers in canonical form [0, ..., P). -#[inline] -#[must_use] -fn mul(lhs: int32x4_t, rhs: int32x4_t) -> uint32x4_t { - // We want this to compile to: - // sqdmulh c_hi.4s, lhs.4s, rhs.4s - // mul mu_rhs.4s, rhs.4s, MU.4s - // mul q.4s, lhs.4s, mu_rhs.4s - // sqdmulh qp_hi.4s, q.4s, P.4s - // shsub res.4s, c_hi.4s, qp_hi.4s - // cmgt underflow.4s, qp_hi.4s, c_hi.4s - // mls res.4s, underflow.4s, P.4s - // throughput: 1.75 cyc/vec (2.29 els/cyc) - // latency: (lhs->) 11 cyc, (rhs->) 14 cyc - - unsafe { - let mu_rhs = mulby_mu::(rhs); - let d = mul_with_precomp::(lhs, rhs, mu_rhs); - - // Safe as mul_with_precomp:: returns integers in [0, P) - aarch64::vreinterpretq_u32_s32(d) - } -} - -/// Multiply MontyField31 field elements using precomputation. -/// -/// Allows us to reuse `mu_rhs`. +/// Given C (64-bit unsigned per lane, split into low and high halves), +/// computes D = (C - Q*P) / 2^32 mod P where Q = C*MU mod 2^32. /// /// # Safety -/// Both `lhs` and `rhs` must be signed 32-bit integers in the range [-P, P]. -/// `mu_rhs` must be equal to `MPNeon::PACKED_MU * rhs mod 2^32` -/// -/// Output will be signed 32-bit integers either in (-P, P) if CANONICAL is set to false -/// or in [0, P) if CANONICAL is set to true. +/// C must be < 2^32 * P per lane (guaranteed for a single product of values in [0, P)). #[inline] #[must_use] -fn mul_with_precomp( - lhs: int32x4_t, - rhs: int32x4_t, - mu_rhs: int32x4_t, -) -> int32x4_t { - // If CANONICAL: - // We want this to compile to: - // sqdmulh c_hi.4s, lhs.4s, rhs.4s - // mul q.4s, lhs.4s, mu_rhs.4s - // sqdmulh qp_hi.4s, q.4s, P.4s - // shsub res.4s, c_hi.4s, qp_hi.4s - // cmgt underflow.4s, qp_hi.4s, c_hi.4s - // mls res.4s, underflow.4s, P.4s - // - // throughput: 1.5 cyc/vec (2.66 els/cyc) - // latency: 11 cyc +unsafe fn monty_reduce_neon(c_l: uint64x2_t, c_h: uint64x2_t) -> uint32x4_t { + // Montgomery reduction: D = (C - Q*P) / 2^32, then canonicalize D ∈ (-P, P) → [0, P). // - // If !CANONICAL: - // We want this to compile to: - // sqdmulh c_hi.4s, lhs.4s, rhs.4s - // mul q.4s, lhs.4s, mu_rhs.4s - // sqdmulh qp_hi.4s, q.4s, P.4s - // shsub res.4s, c_hi.4s, qp_hi.4s + // Key trick: since C_lo ≡ (qP)_lo (mod 2^32) by construction, the 64-bit subtraction + // d = C - qP has zero low 32 bits and the borrow propagation only affects the high word. + // So: d_hi = c_hi - qp_hi (u32 wrapping), and borrow ↔ d_hi > c_hi (unsigned). 
// - // throughput: 1 cyc/vec (4 els/cyc) - // latency: 8 cyc + // vuzp1 c_lo, c_l, c_h // extract low 32 bits + // vuzp2 c_hi, c_l, c_h // extract high 32 bits + // vmul q, c_lo, MU // q = c_lo * MU mod 2^32 + // vmlsl d_l, c_l, q_lo, P_lo // d_l = c_l - q_lo*P_lo (64-bit) + // vmlsl2 d_h, c_h, q, P // d_h = c_h - q_hi*P_hi (64-bit) + // vuzp2 d_hi, d_l, d_h // extract D_u32 + // cmhi borrow, d_hi, c_hi // borrow: d_hi > c_hi (unsigned 32-bit) + // and corr, borrow, P + // add result, d_hi, corr // + // 9 instructions, throughput ~2.25 cyc/vec. unsafe { - let c_hi = get_c_hi(lhs, rhs); - let qp_hi = get_qp_hi::(lhs, mu_rhs); - let d = aarch64::vhsubq_s32(c_hi, qp_hi); - - // This branch will be removed by the compiler. - if CANONICAL { - // We reduce d to canonical form. d is negative iff `c_hi > qp_hi`, so if that's the - // case then we add P. Note that if `c_hi > qp_hi` then `underflow` is -1, so we must - // _subtract_ `underflow` * P. - let underflow = aarch64::vcltq_s32(c_hi, qp_hi); - - // As underflow and MPNeon::PACKED_P are unsigned we use the unsigned version of multiply - // and subtract. Note that on bits, the signed and unsigned versions are literally identical. - let reduced = aarch64::vmlsq_u32( - aarch64::vreinterpretq_u32_s32(d), - confuse_compiler(underflow), - MPNeon::PACKED_P, - ); - - // We convert back to int32x4_t to match the function output. - aarch64::vreinterpretq_s32_u32(reduced) - } else { - d - } - } -} + let c_lo = aarch64::vuzp1q_u32(aarch64::vreinterpretq_u32_u64(c_l), aarch64::vreinterpretq_u32_u64(c_h)); + let c_hi = aarch64::vuzp2q_u32(aarch64::vreinterpretq_u32_u64(c_l), aarch64::vreinterpretq_u32_u64(c_h)); -/// Take cube of MontyField31 field elements. -/// -/// # Safety -/// Inputs must be signed 32-bit integers in the range [-P, P]. -/// Outputs will be a unsigned 32-bit integers in canonical form [0, ..., P). -#[inline] -#[must_use] -fn cube(val: int32x4_t) -> uint32x4_t { - // throughput: 2.75 cyc/vec (1.45 els/cyc) - // latency: 22 cyc + let mu = aarch64::vreinterpretq_u32_s32(MPNeon::PACKED_MU); + let q = aarch64::vmulq_u32(c_lo, mu); - unsafe { - let mu_val = mulby_mu::(val); + let d_l = aarch64::vmlsl_u32(c_l, aarch64::vget_low_u32(q), aarch64::vget_low_u32(MPNeon::PACKED_P)); + let d_h = aarch64::vmlsl_high_u32(c_h, q, MPNeon::PACKED_P); - let val_2 = mul_with_precomp::(val, val, mu_val); - let val_3 = mul_with_precomp::(val_2, val, mu_val); + let d_hi = aarch64::vuzp2q_u32(aarch64::vreinterpretq_u32_u64(d_l), aarch64::vreinterpretq_u32_u64(d_h)); - // Safe as mul_with_precomp:: returns integers in [0, P) - aarch64::vreinterpretq_u32_s32(val_3) + // Borrow ↔ d_hi > c_hi (unsigned 32-bit): the low 32 bits cancel in C - qP, + // so the u64 borrow equals the u32 high-word borrow. + let borrow = aarch64::vcgtq_u32(d_hi, c_hi); + let corr = aarch64::vandq_u32(borrow, MPNeon::PACKED_P); + aarch64::vaddq_u32(d_hi, corr) } } -/// Take the fifth power of the MontyField31 field elements. +/// Multiply MontyField31 field elements (unsigned, works for P > 2^31). /// -/// # Safety -/// Inputs must be signed 32-bit integers in the range [-P, P]. -/// Outputs will be a unsigned 32-bit integers in canonical form [0, ..., P). +/// Inputs must be unsigned 32-bit integers in [0, P). +/// Outputs will be unsigned 32-bit integers in canonical form [0, P). 
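+///
+/// Scalar model of one lane (hypothetical; assumes `mu = P^{-1} mod 2^32`, the convention
+/// used by `PACKED_MU` here, and a single product so that `c < 2^32 * P`):
+/// ```ignore
+/// fn monty_mul_scalar(a: u32, b: u32, p: u32, mu: u32) -> u32 {
+///     let c = a as u64 * b as u64;                 // c < P^2 < 2^32 * P
+///     let q = (c as u32).wrapping_mul(mu);         // q ≡ c * P^{-1} (mod 2^32)
+///     let qp = q as u64 * p as u64;                // qp ≡ c (mod 2^32)
+///     let d = (c.wrapping_sub(qp) >> 32) as u32;   // (c - qp) / 2^32, in (-P, P)
+///     if qp > c { d.wrapping_add(p) } else { d }   // canonicalize to [0, P)
+/// }
+/// ```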
#[inline] #[must_use] -fn exp_5(val: int32x4_t) -> uint32x4_t { - // throughput: 4 cyc/vec (1 els/cyc) - // latency: 30 cyc - +fn mul_unsigned(lhs: uint32x4_t, rhs: uint32x4_t) -> uint32x4_t { unsafe { - let mu_val = mulby_mu::(val); - - let val_2 = mul_with_precomp::(val, val, mu_val); - - // mu_val_2 and val_3 can be computed in parallel. - let mu_val_2 = mulby_mu::(val_2); - let val_3 = mul_with_precomp::(val_2, val, mu_val); - - let val_5 = mul_with_precomp::(val_3, val_2, mu_val_2); - - // Safe as mul_with_precomp:: returns integers in [0, P) - aarch64::vreinterpretq_u32_s32(val_5) + // Widening multiply: C = lhs * rhs (64-bit per lane) + let c_l = aarch64::vmull_u32(aarch64::vget_low_u32(lhs), aarch64::vget_low_u32(rhs)); + let c_h = aarch64::vmull_high_u32(lhs, rhs); + monty_reduce_neon::(c_l, c_h) } } -/// Take the seventh power of the MontyField31 field elements. -/// -/// # Safety -/// Inputs must be signed 32-bit integers in the range [-P, P]. -/// Outputs will be a unsigned 32-bit integers in canonical form [0, ..., P). +/// Take cube of MontyField31 field elements (unsigned path). #[inline] #[must_use] -fn exp_7(val: int32x4_t) -> uint32x4_t { - // throughput: 5.25 cyc/vec (0.76 els/cyc) - // latency: 33 cyc - - unsafe { - let mu_val = mulby_mu::(val); - - let val_2 = mul_with_precomp::(val, val, mu_val); - - // mu_val_2, val_4 and val_3, mu_val_3 can be computed in parallel. - let mu_val_2 = mulby_mu::(val_2); - let val_3 = mul_with_precomp::(val_2, val, mu_val); - - let mu_val_3 = mulby_mu::(val_3); - let val_4 = mul_with_precomp::(val_2, val_2, mu_val_2); +fn cube_unsigned(val: uint32x4_t) -> uint32x4_t { + let val_2 = mul_unsigned::(val, val); + mul_unsigned::(val_2, val) +} - let val_7 = mul_with_precomp::(val_4, val_3, mu_val_3); +/// Take the fifth power (unsigned path). +#[inline] +#[must_use] +fn exp_5_unsigned(val: uint32x4_t) -> uint32x4_t { + let val_2 = mul_unsigned::(val, val); + let val_3 = mul_unsigned::(val_2, val); + mul_unsigned::(val_3, val_2) +} - // Safe as mul_with_precomp:: returns integers in [0, P) - aarch64::vreinterpretq_u32_s32(val_7) - } +/// Take the seventh power (unsigned path). +#[inline] +#[must_use] +fn exp_7_unsigned(val: uint32x4_t) -> uint32x4_t { + let val_2 = mul_unsigned::(val, val); + let val_3 = mul_unsigned::(val_2, val); + let val_4 = mul_unsigned::(val_2, val_2); + mul_unsigned::(val_4, val_3) } /// Negate a vector of Monty31 field elements in canonical form. @@ -601,21 +426,14 @@ unsafe impl PackedField for PackedMontyField31Neon { general_dot_product::<_, _, _, N>(coeffs, vecs) } - /// Fused `(self - rhs) * scalar` that skips the intermediate modular reduction on the - /// subtraction. This is valid because `vsubq_u32(x, y)` with `x, y ∈ [0, P)` produces a - /// value that, reinterpreted as `i32`, lies in `(-P, P)` — exactly the signed input range - /// that Montgomery multiplication accepts. + /// Fused `(self - rhs) * scalar`. #[inline] fn fused_sub_mul(self, rhs: Self, scalar: Self::Scalar) -> Self { - unsafe { - // Unreduced subtraction: result in (-P, P) when reinterpreted as signed. - let diff = aarch64::vreinterpretq_s32_u32(aarch64::vsubq_u32(self.to_vector(), rhs.to_vector())); - let scalar_vec: Self = scalar.into(); - let scalar_s = scalar_vec.to_signed_vector(); - let res = mul::(diff, scalar_s); - // Safety: `mul` returns values in canonical form [0, P). - Self::from_vector(res) - } + // With P > 2^31, we can't use the signed multiplication shortcut. + // Fall back to sub + mul. 
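+        // (With a 31-bit prime, canonical values fit in a non-negative i32, so an unreduced
+        // u32 difference could be reinterpreted as a signed value in (-P, P) and fed directly
+        // into the signed Montgomery multiply. With P > 2^31 the top bit of a canonical value
+        // may already be set, so that reinterpretation is no longer valid.)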
+ let diff = self - rhs; + let scalar_packed: Self = scalar.into(); + diff * scalar_packed } } @@ -640,53 +458,60 @@ where RHS: IntoVec
<P>
, { unsafe { - // Accumulate the full 64-bit sum C = l0*r0 + l1*r1. - - // Low half (Lanes 0 & 1) - let mut sum_l = aarch64::vmull_u32( + // Accumulate C = l0*r0 + l1*r1 in u64 (may overflow for P > 2^31). + // + // For P > 2^31: each product < P^2 ≈ 2^63.9, so 2 products can exceed 2^64. + // We detect the u64 overflow and correct the Montgomery result afterwards. + // Overflow correction: true_sum = u64_sum + 2^64, so D_true = D_naive + 2^32. + // In the field: D_true ≡ D_naive + (2^32 mod P) = D_naive + (2^32 - P). + // Since D_naive ∈ [0, P) and (2^32 - P) < P, the sum is in [0, 2P). + // One conditional subtraction of P yields [0, P). + + // Low half: accumulate with overflow detection + let prod0_l = aarch64::vmull_u32( aarch64::vget_low_u32(lhs[0].into_vec()), aarch64::vget_low_u32(rhs[0].into_vec()), ); - sum_l = aarch64::vmlal_u32( - sum_l, + let sum_l = aarch64::vmlal_u32( + prod0_l, aarch64::vget_low_u32(lhs[1].into_vec()), aarch64::vget_low_u32(rhs[1].into_vec()), ); + let over_l = aarch64::vcltq_u64(sum_l, prod0_l); // overflow: sum < prev - // High half (Lanes 2 & 3) - let mut sum_h = aarch64::vmull_high_u32(lhs[0].into_vec(), rhs[0].into_vec()); - sum_h = aarch64::vmlal_high_u32(sum_h, lhs[1].into_vec(), rhs[1].into_vec()); + // High half: same + let prod0_h = aarch64::vmull_high_u32(lhs[0].into_vec(), rhs[0].into_vec()); + let sum_h = aarch64::vmlal_high_u32(prod0_h, lhs[1].into_vec(), rhs[1].into_vec()); + let over_h = aarch64::vcltq_u64(sum_h, prod0_h); - // Split C into 32-bit low halves per lane: c_lo = C mod 2^{32} + // Montgomery reduction using 32-bit high-word borrow trick let c_lo = aarch64::vuzp1q_u32( aarch64::vreinterpretq_u32_u64(sum_l), aarch64::vreinterpretq_u32_u64(sum_h), ); - - // q ≡ c_lo ⋅ μ (mod 2^{32}), with μ = −P^{-1} (mod 2^{32}). + let c_hi = aarch64::vuzp2q_u32( + aarch64::vreinterpretq_u32_u64(sum_l), + aarch64::vreinterpretq_u32_u64(sum_h), + ); let q = aarch64::vmulq_u32(c_lo, aarch64::vreinterpretq_u32_s32(P::PACKED_MU)); - - // Compute d = (C - q⋅P) / B using multiply-subtract-long instructions. - // - // This combines the multiplication q⋅P and subtraction C - q⋅P in one step. let d_l = aarch64::vmlsl_u32(sum_l, aarch64::vget_low_u32(q), aarch64::vget_low_u32(P::PACKED_P)); let d_h = aarch64::vmlsl_high_u32(sum_h, q, P::PACKED_P); + let d_hi = aarch64::vuzp2q_u32(aarch64::vreinterpretq_u32_u64(d_l), aarch64::vreinterpretq_u32_u64(d_h)); + let borrow = aarch64::vcgtq_u32(d_hi, c_hi); + let mut d = aarch64::vaddq_u32(d_hi, aarch64::vandq_u32(borrow, P::PACKED_P)); + + // Overflow correction: add (2^32 - P) where u64 overflow occurred + let over = aarch64::vuzp2q_u32( + aarch64::vreinterpretq_u32_u64(over_l), + aarch64::vreinterpretq_u32_u64(over_h), + ); + let neg_p = aarch64::vdupq_n_u32(0u32.wrapping_sub(P::PRIME)); // 2^32 - P + d = aarch64::vaddq_u32(d, aarch64::vandq_u32(over, neg_p)); - // Extract the high 32 bits (the division by B = 2^32) from d_l and d_h. - let d = aarch64::vuzp2q_u32(aarch64::vreinterpretq_u32_u64(d_l), aarch64::vreinterpretq_u32_u64(d_h)); - - // Canonicalize d from (-P, P) to [0, P) branchlessly. - // - // The `vmlsq_u32` instruction computes `a - (b * c)`. - // - If `d` is negative (interpreted as unsigned, it's >= 2^31), the mask is `-1` (all 1s), - // so we compute `d - (-1 * P) = d + P`. - // - If `d` is non-negative, the mask is `0`, so we compute `d - (0 * P) = d`. - // - // Check if d >= 2^31 (i.e., negative when interpreted as signed). 
- let underflow = aarch64::vcgeq_u32(d, aarch64::vdupq_n_u32(1u32 << 31)); - let canonical_res = aarch64::vmlsq_u32(d, underflow, P::PACKED_P); + // Final reduction from [0, 2P) → [0, P) + let geq_p = aarch64::vcgeq_u32(d, P::PACKED_P); + let canonical_res = aarch64::vsubq_u32(d, aarch64::vandq_u32(geq_p, P::PACKED_P)); - // Safety: The result is now in canonical form [0, P). PackedMontyField31Neon::from_vector(canonical_res) } } @@ -701,279 +526,28 @@ where { assert_eq!(lhs.len(), N); assert_eq!(rhs.len(), N); + // For P > 2^31, we accumulate at most 2 products per Montgomery reduction (via dot_product_2 + // with u64 overflow correction), then sum results with field additions. match N { 0 => PackedMontyField31Neon::
<P>
::ZERO, 1 => lhs[0].into() * rhs[0].into(), 2 => unsafe { dot_product_2(&[lhs[0], lhs[1]], &[rhs[0], rhs[1]]) }, - 3 => { - let lhs_packed = [ - lhs[0].into(), - lhs[1].into(), - lhs[2].into(), - PackedMontyField31Neon::
<P>
::ZERO, - ]; - let rhs_packed = [ - rhs[0].into(), - rhs[1].into(), - rhs[2].into(), - PackedMontyField31Neon::
<P>
::ZERO, - ]; - unsafe { dot_product_4(&lhs_packed, &rhs_packed) } - } - 4 => unsafe { dot_product_4(&[lhs[0], lhs[1], lhs[2], lhs[3]], &[rhs[0], rhs[1], rhs[2], rhs[3]]) }, - 5 => unsafe { - dot_product_5( - &[lhs[0], lhs[1], lhs[2], lhs[3], lhs[4]], - &[rhs[0], rhs[1], rhs[2], rhs[3], rhs[4]], - ) - }, - 64 => { - let sum_4s: [PackedMontyField31Neon
<P>
; 16] = core::array::from_fn(|i| { - let start = i * 4; - unsafe { - dot_product_4( - &[lhs[start], lhs[start + 1], lhs[start + 2], lhs[start + 3]], - &[rhs[start], rhs[start + 1], rhs[start + 2], rhs[start + 3]], - ) - } - }); - PackedMontyField31Neon::
<P>
::sum_array::<16>(&sum_4s) - } _ => { - // Initialize accumulator with the first chunk of 4. - let mut acc = - unsafe { dot_product_4(&[lhs[0], lhs[1], lhs[2], lhs[3]], &[rhs[0], rhs[1], rhs[2], rhs[3]]) }; - - // Loop over the rest of the full chunks of 4. - for i in (4..N).step_by(4) { - if i + 3 < N { - acc += unsafe { - dot_product_4( - &[lhs[i], lhs[i + 1], lhs[i + 2], lhs[i + 3]], - &[rhs[i], rhs[i + 1], rhs[i + 2], rhs[i + 3]], - ) - }; - } + // Process pairs using dot_product_2 (amortizes 1 monty reduction over 2 products) + let mut acc: PackedMontyField31Neon
<P>
= unsafe { dot_product_2(&[lhs[0], lhs[1]], &[rhs[0], rhs[1]]) }; + let mut i = 2; + while i + 1 < N { + acc += unsafe { dot_product_2(&[lhs[i], lhs[i + 1]], &[rhs[i], rhs[i + 1]]) }; + i += 2; } - - // Handle the remainder recursively by creating new arrays and calling self. - match N % 4 { - 0 => acc, - 1 => { - let rem_start = N - 1; - let lhs_rem: [_; 1] = core::array::from_fn(|i| lhs[rem_start + i]); - let rhs_rem: [_; 1] = core::array::from_fn(|i| rhs[rem_start + i]); - acc + general_dot_product::<_, _, _, 1>(&lhs_rem, &rhs_rem) - } - 2 => { - let rem_start = N - 2; - let lhs_rem: [_; 2] = core::array::from_fn(|i| lhs[rem_start + i]); - let rhs_rem: [_; 2] = core::array::from_fn(|i| rhs[rem_start + i]); - acc + general_dot_product::<_, _, _, 2>(&lhs_rem, &rhs_rem) - } - 3 => { - let rem_start = N - 3; - let lhs_rem: [_; 3] = core::array::from_fn(|i| lhs[rem_start + i]); - let rhs_rem: [_; 3] = core::array::from_fn(|i| rhs[rem_start + i]); - acc + general_dot_product::<_, _, _, 3>(&lhs_rem, &rhs_rem) - } - _ => unreachable!(), + if i < N { + acc += lhs[i].into() * rhs[i].into(); } + acc } } } -/// Compute the elementary function `l0*r0 + l1*r1 + l2*r2 + l3*r3` given eight inputs -/// in canonical form. -/// -/// If the inputs are not in canonical form, the result is undefined. -#[inline] -unsafe fn dot_product_4(lhs: &[LHS; 4], rhs: &[RHS; 4]) -> PackedMontyField31Neon
<P>
-where - P: FieldParameters + MontyParametersNeon, - LHS: IntoVec
<P>
, - RHS: IntoVec
<P>
, -{ - unsafe { - // Accumulate the full 64-bit sum C = Σ lhs_i ⋅ rhs_i. - - // Low half (Lanes 0 & 1) - let mut sum_l = aarch64::vmull_u32( - aarch64::vget_low_u32(lhs[0].into_vec()), - aarch64::vget_low_u32(rhs[0].into_vec()), - ); - sum_l = aarch64::vmlal_u32( - sum_l, - aarch64::vget_low_u32(lhs[1].into_vec()), - aarch64::vget_low_u32(rhs[1].into_vec()), - ); - sum_l = aarch64::vmlal_u32( - sum_l, - aarch64::vget_low_u32(lhs[2].into_vec()), - aarch64::vget_low_u32(rhs[2].into_vec()), - ); - sum_l = aarch64::vmlal_u32( - sum_l, - aarch64::vget_low_u32(lhs[3].into_vec()), - aarch64::vget_low_u32(rhs[3].into_vec()), - ); - - // High half (Lanes 2 & 3) - let mut sum_h = aarch64::vmull_high_u32(lhs[0].into_vec(), rhs[0].into_vec()); - sum_h = aarch64::vmlal_high_u32(sum_h, lhs[1].into_vec(), rhs[1].into_vec()); - sum_h = aarch64::vmlal_high_u32(sum_h, lhs[2].into_vec(), rhs[2].into_vec()); - sum_h = aarch64::vmlal_high_u32(sum_h, lhs[3].into_vec(), rhs[3].into_vec()); - - // Split C into 32-bit halves per lane: - // - c_lo = C mod 2^{32}, - // - c_hi = C >> 32. - let c_lo = aarch64::vuzp1q_u32( - aarch64::vreinterpretq_u32_u64(sum_l), - aarch64::vreinterpretq_u32_u64(sum_h), - ); - let c_hi = aarch64::vuzp2q_u32( - aarch64::vreinterpretq_u32_u64(sum_l), - aarch64::vreinterpretq_u32_u64(sum_h), - ); - - // Since C < 4P^2 and P < 2^{31}, we have c_hi < 2P. - // We want to compute: c_hi' ∈ [0,P) satisfying c_hi' = c_hi mod P. - let c_hi_sub = aarch64::vsubq_u32(c_hi, P::PACKED_P); - let c_hi_prime = aarch64::vminq_u32(c_hi, c_hi_sub); - - // q ≡ c_lo ⋅ μ (mod 2^{32}), with μ = −P^{-1} (mod 2^{32}). - let q = aarch64::vmulq_u32(c_lo, aarch64::vreinterpretq_u32_s32(P::PACKED_MU)); - - // Compute (q⋅P)_hi = high 32 bits of q⋅P per lane (exact unsigned widening multiply). - let qp_l = aarch64::vmull_u32(aarch64::vget_low_u32(q), aarch64::vget_low_u32(P::PACKED_P)); - let qp_h = aarch64::vmull_high_u32(q, P::PACKED_P); - let qp_hi = aarch64::vuzp2q_u32( - aarch64::vreinterpretq_u32_u64(qp_l), - aarch64::vreinterpretq_u32_u64(qp_h), - ); - - let d = aarch64::vsubq_u32(c_hi_prime, qp_hi); - - // Canonicalize d from (-P, P) to [0, P) branchlessly. - // - // The `vmlsq_u32` instruction computes `a - (b * c)`. - // - If `d` is negative, the mask is `-1` (all 1s), so we compute `d - (-1 * P) = d + P`. - // - If `d` is non-negative, the mask is `0`, so we compute `d - (0 * P) = d`. - let underflow = aarch64::vcltq_u32(c_hi_prime, qp_hi); - let canonical_res = aarch64::vmlsq_u32(d, underflow, P::PACKED_P); - - // Safety: The result is now in canonical form [0, P). - PackedMontyField31Neon::from_vector(canonical_res) - } -} - -/// Compute the elementary function `l0*r0 + l1*r1 + l2*r2 + l3*r3 + l4*r4` given ten inputs -/// in canonical form. -/// -/// If the inputs are not in canonical form, the result is undefined. -#[inline] -unsafe fn dot_product_5(lhs: &[LHS; 5], rhs: &[RHS; 5]) -> PackedMontyField31Neon
<P>
-where - P: FieldParameters + MontyParametersNeon, - LHS: IntoVec
<P>
, - RHS: IntoVec
<P>
, -{ - unsafe { - // Materialize all vectors once. - let lhs0 = lhs[0].into_vec(); - let rhs0 = rhs[0].into_vec(); - let lhs1 = lhs[1].into_vec(); - let rhs1 = rhs[1].into_vec(); - let lhs2 = lhs[2].into_vec(); - let rhs2 = rhs[2].into_vec(); - let lhs3 = lhs[3].into_vec(); - let rhs3 = rhs[3].into_vec(); - let lhs4 = lhs[4].into_vec(); - let rhs4 = rhs[4].into_vec(); - - // Group A: accumulate terms 0-2 in wide form. Safe: 3*(P-1)^2 < 2^64. - - // Low half (Lanes 0 & 1) - let mut sum_al = aarch64::vmull_u32(aarch64::vget_low_u32(lhs0), aarch64::vget_low_u32(rhs0)); - sum_al = aarch64::vmlal_u32(sum_al, aarch64::vget_low_u32(lhs1), aarch64::vget_low_u32(rhs1)); - sum_al = aarch64::vmlal_u32(sum_al, aarch64::vget_low_u32(lhs2), aarch64::vget_low_u32(rhs2)); - - // High half (Lanes 2 & 3) - let mut sum_ah = aarch64::vmull_high_u32(lhs0, rhs0); - sum_ah = aarch64::vmlal_high_u32(sum_ah, lhs1, rhs1); - sum_ah = aarch64::vmlal_high_u32(sum_ah, lhs2, rhs2); - - // Group B: accumulate terms 3-4 in wide form. Safe: 2*(P-1)^2 < 2^64. - - // Low half (Lanes 0 & 1) - let mut sum_bl = aarch64::vmull_u32(aarch64::vget_low_u32(lhs3), aarch64::vget_low_u32(rhs3)); - sum_bl = aarch64::vmlal_u32(sum_bl, aarch64::vget_low_u32(lhs4), aarch64::vget_low_u32(rhs4)); - - // High half (Lanes 2 & 3) - let mut sum_bh = aarch64::vmull_high_u32(lhs3, rhs3); - sum_bh = aarch64::vmlal_high_u32(sum_bh, lhs4, rhs4); - - // Split each group into 32-bit c_lo, c_hi. - let c_lo_a = aarch64::vuzp1q_u32( - aarch64::vreinterpretq_u32_u64(sum_al), - aarch64::vreinterpretq_u32_u64(sum_ah), - ); - let c_hi_a = aarch64::vuzp2q_u32( - aarch64::vreinterpretq_u32_u64(sum_al), - aarch64::vreinterpretq_u32_u64(sum_ah), - ); - let c_lo_b = aarch64::vuzp1q_u32( - aarch64::vreinterpretq_u32_u64(sum_bl), - aarch64::vreinterpretq_u32_u64(sum_bh), - ); - let c_hi_b = aarch64::vuzp2q_u32( - aarch64::vreinterpretq_u32_u64(sum_bl), - aarch64::vreinterpretq_u32_u64(sum_bh), - ); - - // Reduce group A's c_hi from [0, 2P) to [0, P). Group B's c_hi < P needs no reduction. - let c_hi_a_sub = aarch64::vsubq_u32(c_hi_a, P::PACKED_P); - let c_hi_a_red = aarch64::vminq_u32(c_hi_a, c_hi_a_sub); - - // Merge the two groups with carry propagation. - // - // c_lo = c_lo_a + c_lo_b (wrapping u32 add). - let c_lo = aarch64::vaddq_u32(c_lo_a, c_lo_b); - // carry = -1 (all 1s) if c_lo wrapped, 0 otherwise. - let carry = aarch64::vcltq_u32(c_lo, c_lo_a); - - // c_hi_sum ∈ [0, 2P-2]. - let c_hi_sum = aarch64::vaddq_u32(c_hi_a_red, c_hi_b); - // Subtracting -1 adds 1; subtracting 0 is a no-op. c_hi ∈ [0, 2P-1]. - let c_hi = aarch64::vsubq_u32(c_hi_sum, carry); - // Conditional subtract by P → c_hi_prime ∈ [0, P-1]. - let c_hi_sub = aarch64::vsubq_u32(c_hi, P::PACKED_P); - let c_hi_prime = aarch64::vminq_u32(c_hi, c_hi_sub); - - // Montgomery reduction (identical to dot_product_4). - // - // q ≡ c_lo ⋅ μ (mod 2^{32}). - let q = aarch64::vmulq_u32(c_lo, aarch64::vreinterpretq_u32_s32(P::PACKED_MU)); - - // (q⋅P)_hi = high 32 bits of q⋅P. - let qp_l = aarch64::vmull_u32(aarch64::vget_low_u32(q), aarch64::vget_low_u32(P::PACKED_P)); - let qp_h = aarch64::vmull_high_u32(q, P::PACKED_P); - let qp_hi = aarch64::vuzp2q_u32( - aarch64::vreinterpretq_u32_u64(qp_l), - aarch64::vreinterpretq_u32_u64(qp_h), - ); - - let d = aarch64::vsubq_u32(c_hi_prime, qp_hi); - - // Canonicalize d from (-P, P) to [0, P) branchlessly. 
- let underflow = aarch64::vcltq_u32(c_hi_prime, qp_hi); - let canonical_res = aarch64::vmlsq_u32(d, underflow, P::PACKED_P); - - // Safety: The result is now in canonical form [0, P). - PackedMontyField31Neon::from_vector(canonical_res) - } -} - /// Multiplication by a base field element in a binomial extension field. #[inline] pub fn base_mul_packed( @@ -1023,14 +597,14 @@ pub fn base_mul_packed( /// Outputs will be unsigned 32-bit integers in canonical form `[0, P)`. #[inline(always)] #[must_use] -pub(crate) fn exp_small(val: int32x4_t) -> uint32x4_t +pub(crate) fn exp_small(val: uint32x4_t) -> uint32x4_t where PMP: PackedMontyParameters + FieldParameters, { match D { - 3 => cube::(val), - 5 => exp_5::(val), - 7 => exp_7::(val), + 3 => cube_unsigned::(val), + 5 => exp_5_unsigned::(val), + 7 => exp_7_unsigned::(val), _ => panic!("No exp function for given D"), } } diff --git a/crates/backend/koala-bear/src/monty_31/aarch64_neon/poseidon_helpers.rs b/crates/backend/koala-bear/src/monty_31/aarch64_neon/poseidon_helpers.rs index 61333abca..abe9972ec 100644 --- a/crates/backend/koala-bear/src/monty_31/aarch64_neon/poseidon_helpers.rs +++ b/crates/backend/koala-bear/src/monty_31/aarch64_neon/poseidon_helpers.rs @@ -2,14 +2,12 @@ //! NEON helpers shared by Poseidon1 permutations. -use core::arch::aarch64::{self, int32x4_t, uint32x4_t}; +use core::arch::aarch64::{self, uint32x4_t}; use core::mem::transmute; use super::exp_small; -use crate::{FieldParameters, MontyParameters, PackedMontyField31Neon, PackedMontyParameters, RelativelyPrimePower}; - -// Convenience alias to match the naming used for the AVX2/AVX512 helpers. -pub(crate) use convert_to_vec_neg_form_neon as convert_to_vec_neg_form; +use crate::{FieldParameters, PackedMontyField31Neon, PackedMontyParameters, RelativelyPrimePower}; +use field::uint32x4_mod_add; /// A specialized representation of the Poseidon state for a width of 16. /// @@ -35,40 +33,23 @@ impl InternalLayer16 { } } -/// Converts a scalar constant into a packed NEON vector in "negative form" (`c - P`). +/// Converts a scalar constant into a packed NEON vector (canonical unsigned form). #[inline(always)] -pub(crate) fn convert_to_vec_neg_form_neon(input: i32) -> int32x4_t { - unsafe { - let input_sub_p = input - (MP::PRIME as i32); - aarch64::vdupq_n_s32(input_sub_p) - } +pub(crate) fn convert_to_vec_neon(input: u32) -> uint32x4_t { + unsafe { aarch64::vdupq_n_u32(input) } } -/// Performs the fused AddRoundConstant and S-Box operation `x -> (x + c)^D`. +/// Performs the AddRoundConstant and S-Box operation `x -> (x + c)^D`. /// /// `val` must contain elements in canonical form `[0, P)`. -/// `rc` must contain round constants in negative form `[-P, 0)`. -pub(crate) fn add_rc_and_sbox(val: &mut PackedMontyField31Neon, rc: int32x4_t) +/// `rc` must contain round constants in canonical form `[0, P)`. +pub(crate) fn add_rc_and_sbox(val: &mut PackedMontyField31Neon, rc: uint32x4_t) where PMP: PackedMontyParameters + FieldParameters + RelativelyPrimePower, { unsafe { - let vec_val_s = val.to_signed_vector(); - let val_plus_rc = aarch64::vaddq_s32(vec_val_s, rc); + let val_plus_rc = uint32x4_mod_add(val.to_vector(), rc, PMP::PACKED_P); let output = exp_small::(val_plus_rc); *val = PackedMontyField31Neon::::from_vector(output); } } - -/// Applies the S-Box `x -> x^D` to a packed vector. Output is in canonical form. 
-#[inline(always)] -pub(crate) fn sbox(val: PackedMontyField31Neon) -> PackedMontyField31Neon -where - PMP: PackedMontyParameters + FieldParameters + RelativelyPrimePower, -{ - unsafe { - let signed = val.to_signed_vector(); - let out = exp_small::(signed); - PackedMontyField31Neon::::from_vector(out) - } -} diff --git a/crates/backend/koala-bear/src/monty_31/data_traits.rs b/crates/backend/koala-bear/src/monty_31/data_traits.rs index 168d2488b..448c6609d 100644 --- a/crates/backend/koala-bear/src/monty_31/data_traits.rs +++ b/crates/backend/koala-bear/src/monty_31/data_traits.rs @@ -10,7 +10,7 @@ use crate::MontyField31; /// MontyParameters contains the prime P along with constants needed to convert elements into and out of MONTY form. /// The MONTY constant is assumed to be a power of 2. pub trait MontyParameters: Copy + Clone + Default + Debug + Eq + PartialEq + Sync + Send + Hash + 'static { - // A 31-bit prime. + // A prime that fits in a u32. May be larger than 2^31. const PRIME: u32; // The log_2 of our MONTY constant. diff --git a/crates/backend/koala-bear/src/monty_31/monty_31.rs b/crates/backend/koala-bear/src/monty_31/monty_31.rs index 3fbb7fb40..a651a2fb1 100644 --- a/crates/backend/koala-bear/src/monty_31/monty_31.rs +++ b/crates/backend/koala-bear/src/monty_31/monty_31.rs @@ -24,8 +24,8 @@ use serde::{Deserialize, Deserializer, Serialize}; use utils::{flatten_to_base, gcd_inversion_prime_field_32}; use crate::monty_31::utils::{ - from_monty, halve_u32, large_monty_reduce, monty_add, monty_reduce, monty_reduce_u128, monty_sub, to_monty, - to_monty_64, to_monty_64_signed, to_monty_signed, + from_monty, halve_u32, monty_add, monty_reduce, monty_sub, to_monty, to_monty_64, to_monty_64_signed, + to_monty_signed, }; use crate::{FieldParameters, MontyParameters, RelativelyPrimePower, TwoAdicData}; @@ -140,10 +140,9 @@ impl Distribution> for StandardUniform { #[inline] fn sample(&self, rng: &mut R) -> MontyField31 { loop { - let next_u31 = rng.next_u32() >> 1; - let is_canonical = next_u31 < FP::PRIME; - if is_canonical { - return MontyField31::new_monty(next_u31); + let next = rng.next_u32(); + if next < FP::PRIME { + return MontyField31::new_monty(next); } } } @@ -247,118 +246,25 @@ impl PrimeCharacteristicRing for MontyField31 { #[inline] fn dot_product(lhs: &[Self; N], rhs: &[Self; N]) -> Self { assert!(N as u64 <= (1 << 34)); - // This code relies on assumptions about the relative size of the - // prime and the monty parameter. If these are changes this needs to be checked. debug_assert!(FP::MONTY_BITS == 32); - debug_assert!((FP::PRIME as u64) < (1 << 31)); match N { 0 => Self::ZERO, 1 => lhs[0] * rhs[0], - 2 => { - // As all values are < P < 2^31, the products are < P^2 < 2^31P. - // Hence, summing two together we stay below MONTY*P which means - // monty_reduce will produce a valid result. - let u64_prod_sum = - (lhs[0].value as u64) * (rhs[0].value as u64) + (lhs[1].value as u64) * (rhs[1].value as u64); - Self::new_monty(monty_reduce::(u64_prod_sum)) - } - 3 => { - // As all values are < P < 2^31, the products are < P^2 < 2^31P. - // Hence, summing three together will be less than 2 * MONTY * P - let u64_prod_sum = (lhs[0].value as u64) * (rhs[0].value as u64) - + (lhs[1].value as u64) * (rhs[1].value as u64) - + (lhs[2].value as u64) * (rhs[2].value as u64); - Self::new_monty(large_monty_reduce::(u64_prod_sum)) - } - 4 => { - // As all values are < P < 2^31, the products are < P^2 < 2^31P. - // Hence, summing four together will be less than 2 * MONTY * P. 
- let u64_prod_sum = (lhs[0].value as u64) * (rhs[0].value as u64) - + (lhs[1].value as u64) * (rhs[1].value as u64) - + (lhs[2].value as u64) * (rhs[2].value as u64) - + (lhs[3].value as u64) * (rhs[3].value as u64); - Self::new_monty(large_monty_reduce::(u64_prod_sum)) - } - 5 => { - let head_sum = (lhs[0].value as u64) * (rhs[0].value as u64) - + (lhs[1].value as u64) * (rhs[1].value as u64) - + (lhs[2].value as u64) * (rhs[2].value as u64) - + (lhs[3].value as u64) * (rhs[3].value as u64); - let tail_sum = (lhs[4].value as u64) * (rhs[4].value as u64); - // head_sum < 4*P^2, tail_sum < P^2. - let head_sum_corr = head_sum.wrapping_sub((FP::PRIME as u64) << FP::MONTY_BITS); - // head_sum.min(head_sum_corr) is guaranteed to be < 2*P^2. - // Hence sum < 4P^2 < 2 * MONTY * P - let sum = head_sum.min(head_sum_corr) + tail_sum; - Self::new_monty(large_monty_reduce::(sum)) - } - 6 => { - let head_sum = (lhs[0].value as u64) * (rhs[0].value as u64) - + (lhs[1].value as u64) * (rhs[1].value as u64) - + (lhs[2].value as u64) * (rhs[2].value as u64) - + (lhs[3].value as u64) * (rhs[3].value as u64); - let tail_sum = - (lhs[4].value as u64) * (rhs[4].value as u64) + (lhs[5].value as u64) * (rhs[5].value as u64); - // head_sum < 4*P^2, tail_sum < 2*P^2. - let head_sum_corr = head_sum.wrapping_sub((FP::PRIME as u64) << FP::MONTY_BITS); - // head_sum.min(head_sum_corr) is guaranteed to be < 2*P^2. - // Hence sum < 4P^2 < 2 * MONTY * P - let sum = head_sum.min(head_sum_corr) + tail_sum; - Self::new_monty(large_monty_reduce::(sum)) - } - 7 => { - let head_sum = (lhs[0].value as u64) * (rhs[0].value as u64) - + (lhs[1].value as u64) * (rhs[1].value as u64) - + (lhs[2].value as u64) * (rhs[2].value as u64) - + (lhs[3].value as u64) * (rhs[3].value as u64); - let tail_sum = (lhs[4].value as u64) * (rhs[4].value as u64) - + lhs[5].value as u64 * (rhs[5].value as u64) - + lhs[6].value as u64 * (rhs[6].value as u64); - // head_sum, tail_sum are guaranteed to be < 4*P^2. - let head_sum_corr = head_sum.wrapping_sub((FP::PRIME as u64) << FP::MONTY_BITS); - let tail_sum_corr = tail_sum.wrapping_sub((FP::PRIME as u64) << FP::MONTY_BITS); - // head_sum.min(head_sum_corr), tail_sum.min(tail_sum_corr) is guaranteed to be < 2*P^2. - // Hence sum < 4P^2 < 2 * MONTY * P - let sum = head_sum.min(head_sum_corr) + tail_sum.min(tail_sum_corr); - Self::new_monty(large_monty_reduce::(sum)) - } - 8 => { - let head_sum = (lhs[0].value as u64) * (rhs[0].value as u64) - + (lhs[1].value as u64) * (rhs[1].value as u64) - + (lhs[2].value as u64) * (rhs[2].value as u64) - + (lhs[3].value as u64) * (rhs[3].value as u64); - let tail_sum = (lhs[4].value as u64) * (rhs[4].value as u64) - + lhs[5].value as u64 * (rhs[5].value as u64) - + lhs[6].value as u64 * (rhs[6].value as u64) - + lhs[7].value as u64 * (rhs[7].value as u64); - // head_sum, tail_sum are guaranteed to be < 4*P^2. - let head_sum_corr = head_sum.wrapping_sub((FP::PRIME as u64) << FP::MONTY_BITS); - let tail_sum_corr = tail_sum.wrapping_sub((FP::PRIME as u64) << FP::MONTY_BITS); - // head_sum.min(head_sum_corr), tail_sum.min(tail_sum_corr) is guaranteed to be < 2*P^2. - // Hence sum < 4P^2 < 2 * MONTY * P - let sum = head_sum.min(head_sum_corr) + tail_sum.min(tail_sum_corr); - Self::new_monty(large_monty_reduce::(sum)) - } _ => { - // For large enough N, we accumulate into a u128. This helps the compiler as it lets - // it do a lot of computation in parallel as it knows that summing u128's is associative. 
-                let acc_u128 = lhs
-                    .chunks(4)
-                    .zip(rhs.chunks(4))
-                    .map(|(l, r)| {
-                        // As all values are < P < 2^31, the products are < P^2 < 2^31P.
-                        // Hence, summing four together will not overflow a u64 but will be
-                        // larger than 2^32P.
-                        let u64_prod_sum = l
-                            .iter()
-                            .zip(r)
-                            .map(|(l, r)| (l.value as u64) * (r.value as u64))
-                            .sum::();
-                        u64_prod_sum as u128
-                    })
+                // Each product is < P^2 < 2^64, but with P > 2^31 the sum of even two
+                // products can already overflow a u64, so for all N >= 2 we widen each
+                // product to u128 and accumulate exactly.
+                let acc_u128: u128 = lhs
+                    .iter()
+                    .zip(rhs.iter())
+                    .map(|(l, r)| (l.value as u64 as u128) * (r.value as u64 as u128))
+                    .sum();
-                // As N <= 2^34 by the earlier assertion, acc_u128 <= 2^34 * P^2 < 2^34 * 2^62 < 2^96.
-                Self::new_monty(monty_reduce_u128::(acc_u128))
+                // As N <= 2^34, acc_u128 <= 2^34 * P^2 < 2^34 * 2^64 = 2^98.
+                // monty_reduce_u128 requires input < 2^96, so reduce mod P first.
+                let acc_reduced = (acc_u128 % FP::PRIME as u128) as u64;
+                Self::new_monty(monty_reduce::(acc_reduced))
            }
        }
    }
@@ -397,8 +303,8 @@ impl Field for MontyField31 {
        return None;
    }

-        // The number of bits of FP::PRIME. By the very name of MontyField31 this should always be 31.
-        const NUM_PRIME_BITS: u32 = 31;
+        // The number of bits of FP::PRIME.
+        const NUM_PRIME_BITS: u32 = 32;

        // Get the inverse using a gcd algorithm.
        // We use `val` to denote the input to `gcd_inversion_prime_field_32` and `R = 2^{MONTY_BITS}`
@@ -695,7 +601,7 @@ impl Sum for MontyField31 {
        // This is faster than iter.reduce(|x, y| x + y).unwrap_or(Self::ZERO) for iterators of length > 2.
        // There might be a faster reduction method possible for lengths <= 16 which avoids %.

-        // This sum will not overflow so long as iter.len() < 2^33.
+        // This sum will not overflow so long as iter.len() < 2^32 (each value < P < 2^32).
        let sum = iter.map(|x| x.value as u64).sum::();
        Self::new_monty((sum % FP::PRIME as u64) as u32)
    }
diff --git a/crates/backend/koala-bear/src/monty_31/utils.rs b/crates/backend/koala-bear/src/monty_31/utils.rs
index 2ea56c658..60e8ba6f3 100644
--- a/crates/backend/koala-bear/src/monty_31/utils.rs
+++ b/crates/backend/koala-bear/src/monty_31/utils.rs
@@ -55,29 +55,28 @@ pub(crate) const fn from_monty(x: u32) -> u32 {

 /// Add two integers modulo `P = MP::PRIME`.
 ///
-/// Assumes that `P` is less than `2^31` and `a + b <= 2P` for all array pairs `a, b`.
-/// If the inputs are not in this range, the result may be incorrect.
-/// The result will be in the range `[0, P]` and equal to `(a + b) mod P`.
-/// It will be equal to `P` if and only if `a + b = 2P` so provided `a + b < 2P`
-/// the result is guaranteed to be less than `P`.
+/// Assumes `a, b` are in `[0, P)`. The result will be in `[0, P)`.
+/// Uses a u64 intermediate since `a + b` may exceed u32 when `P > 2^31`.
 #[inline]
 #[must_use]
 pub const fn monty_add(lhs: u32, rhs: u32) -> u32 {
-    let mut sum = lhs + rhs;
-    let (corr_sum, over) = sum.overflowing_sub(MP::PRIME);
-    if !over {
-        sum = corr_sum;
+    // Compute both sum and sum - P, then select (the compiler can lower this
+    // to a conditional move).
+    let sum = lhs as u64 + rhs as u64;
+    let reduced = sum.wrapping_sub(MP::PRIME as u64);
+    // If sum >= P, reduced = sum - P fits in u32 and is the correct answer.
+    // If sum < P, the subtraction wraps, so bits 32..63 of reduced are set.
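+    // Worked example with P = 0xFA00_0001 (the current prime): for
+    // lhs = rhs = P - 1, sum = 2P - 2 = 0x1_F400_0000 exceeds u32::MAX,
+    // and reduced = P - 2 = 0xF9FF_FFFF, which the selection below picks
+    // since it fits in 32 bits.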
+    // Select: if the subtraction did not wrap (reduced < 2^32), then sum >= P and
+    // reduced is the answer; otherwise sum < P and sum itself is already canonical.
+    if reduced < (1u64 << 32) {
+        reduced as u32
+    } else {
+        sum as u32
    }
-    sum
 }

 /// Subtract two integers modulo `P = MP::PRIME`.
 ///
-/// Assumes that `P` is less than `2^31` and `|a - b| <= P` for all array pairs `a, b`.
-/// If the inputs are not in this range, the result may be incorrect.
-/// The result will be in the range `[0, P]` and equal to `(a - b) mod P`.
-/// It will be equal to `P` if and only if `a - b = P` so provided `a - b < P`
-/// the result is guaranteed to be less than `P`.
+/// Assumes `a, b` are in `[0, P)`. The result will be in `[0, P)`.
+/// Works for any P < 2^32: when `lhs < rhs`, `wrapping_add(P)` corrects via mod-2^32 arithmetic.
 #[inline]
 #[must_use]
 pub fn monty_sub(lhs: u32, rhs: u32) -> u32 {
@@ -87,7 +86,7 @@ pub fn monty_sub(lhs: u32, rhs: u32) -> u32 {
    diff
 }

-/// Given an element `x` from a 31 bit field `F` compute `x/2`.
+/// Given an element `x` from a field `F` with `P < 2^32`, compute `x/2`.
 /// The input must be in `[0, P)`.
 /// The output will also be in `[0, P)`.
 #[inline]
@@ -131,8 +130,14 @@ pub(crate) const fn monty_reduce(x: u64) -> u32 {
 /// The output will be in [0, P).
 ///
 /// This is slower than `monty_reduce` but has a larger input range.
+///
+/// Note: for `P > 2^31` the input range `[0, 2 * MONTY * P) = [0, 2^33 * P)`
+/// no longer fits in a u64, so this helper is currently unused. Kept for
+/// reference in case the dot_product fast paths are re-introduced under a
+/// future representation (e.g. u128 accumulators).
 #[inline]
 #[must_use]
+#[allow(dead_code)]
 pub(crate) const fn large_monty_reduce(x: u64) -> u32 {
    // t = x * MONTY_MU mod MONTY
    let t = x.wrapping_mul(MP::MONTY_MU as u64) & (MP::MONTY_MASK as u64);
@@ -196,6 +201,7 @@ pub(crate) const fn large_monty_reduce(x: u64) -> u32 {
 /// - One conditional subtraction for the final modular addition.
 ///
 /// No 128-bit division or modulo is ever performed.
+#[allow(dead_code)]
 pub(crate) const fn monty_reduce_u128(x: u128) -> u32 {
    // Split the 128-bit input into its two limbs.
    //
@@ -211,12 +217,12 @@ pub(crate) const fn monty_reduce_u128(x: u128) -> u32 {
    // the Montgomery reduction helper.
    //
    // Range analysis:
-    //   R*P = 2^32 * P < 2^63    (P is a 31-bit prime)
-    //   2*R*P < 2^64
+    //   R*P = 2^32 * P < 2^64    (P is a 32-bit prime)
+    //   2*R*P < 2^65
    //   lo <= 2^64 - 1    (arbitrary u64)
    //
    // So the low limb can exceed the accepted range by at most one copy of 2*R*P.
-    // Subtracting it once is always enough because 2^64 < 4*R*P for any 31-bit prime.
+    // Subtracting it once is always enough because 2^64 < 4*R*P for any prime P > 2^30.
    //
    // Correctness: 2*R*P is a multiple of P, so this subtraction
    // does not change the residue modulo P.
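To make the two scalar patterns above concrete, here is a minimal standalone sketch for a prime close to 2^32, here P = 0xFA00_0001 = 125 * 2^25 + 1. The helper names (`mod_add`, `dot_mod`) are hypothetical, not the crate's API, and the sketch works on plain residues rather than Montgomery representatives; the overflow reasoning is the same: widen to u64 for a single addition, and to u128 for a dot product.

const P: u32 = 0xFA00_0001; // 125 * 2^25 + 1, a 32-bit prime (> 2^31)

/// (a + b) mod P for canonical a, b in [0, P). The sum can exceed u32::MAX
/// because P > 2^31, so compare and subtract at u64 width.
fn mod_add(a: u32, b: u32) -> u32 {
    let sum = a as u64 + b as u64;
    if sum >= P as u64 { (sum - P as u64) as u32 } else { sum as u32 }
}

/// Sum of a_i * b_i mod P. A single product is < P^2 < 2^64, but already the
/// sum of two products can overflow a u64 for this prime, hence u128.
fn dot_mod(a: &[u32], b: &[u32]) -> u32 {
    let acc: u128 = a.iter().zip(b).map(|(&x, &y)| x as u128 * y as u128).sum();
    (acc % P as u128) as u32
}

fn main() {
    // 2P - 2 wraps a u32 but reduces to P - 2.
    assert_eq!(mod_add(P - 1, P - 1), P - 2);
    // (P - 1) ≡ -1, so the dot product 2 * (P - 1)^2 ≡ 2 (mod P).
    assert_eq!(dot_mod(&[P - 1, P - 1], &[P - 1, P - 1]), 2);
}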
diff --git a/crates/backend/koala-bear/src/monty_31/x86_64_avx2/packing.rs b/crates/backend/koala-bear/src/monty_31/x86_64_avx2/packing.rs index 7eff533e0..f5a4cd086 100644 --- a/crates/backend/koala-bear/src/monty_31/x86_64_avx2/packing.rs +++ b/crates/backend/koala-bear/src/monty_31/x86_64_avx2/packing.rs @@ -2,7 +2,6 @@ use alloc::vec::Vec; use core::arch::x86_64::{self, __m256i}; -use core::array; use core::iter::{Product, Sum}; use core::mem::transmute; use core::ops::{Add, AddAssign, Div, DivAssign, Mul, MulAssign, Neg, Sub, SubAssign}; @@ -20,7 +19,7 @@ use rand::Rng; use rand::distr::{Distribution, StandardUniform}; use utils::reconstitute_from_base; -use crate::{FieldParameters, MontyField31, PackedMontyParameters, RelativelyPrimePower, halve_avx2, signed_add_avx2}; +use crate::{FieldParameters, MontyField31, PackedMontyParameters, RelativelyPrimePower, halve_avx2}; const WIDTH: usize = 8; @@ -129,11 +128,9 @@ impl Mul for PackedMontyField31AVX2 { fn mul(self, rhs: Self) -> Self { let lhs = self.to_vector(); let rhs = rhs.to_vector(); - let t = mul::(lhs, rhs); - let res = red_signed_to_canonical::(t); + let res = mul::(lhs, rhs); unsafe { - // Safety: `mul` returns values in signed form when given values in canonical form. - // Then `red_signed_to_canonical` reduces values from signed form to canonical form. + // Safety: `mul` returns values in canonical form when given values in canonical form. Self::from_vector(res) } } @@ -170,23 +167,12 @@ impl PrimeCharacteristicRing for PackedMontyField31AVX2 #[inline] fn cube(&self) -> Self { - let val = self.to_vector(); - unsafe { - // Safety: `apply_func_to_even_odd` returns values in canonical form when given values in canonical form. - let res = apply_func_to_even_odd::(val, packed_exp_3::); - Self::from_vector(res) - } - } - - #[inline] - fn xor(&self, rhs: &Self) -> Self { - let lhs = self.to_vector(); - let rhs = rhs.to_vector(); - let res = xor::(lhs, rhs); - unsafe { - // Safety: `xor` returns values in canonical form when given values in canonical form. - Self::from_vector(res) - } + // The optimized `apply_func_to_even_odd` + `packed_exp_3` path keeps an + // intermediate "signed in (-P, P)" value packed in i32 lanes. That + // representation is unambiguous only when `P < 2^31`. The current 32-bit + // prime (`P = 0xfa000001`) violates that, so route through the canonical + // unsigned Mul, mirroring AVX512. + *self * self.square() } #[inline] @@ -199,38 +185,16 @@ impl PrimeCharacteristicRing for PackedMontyField31AVX2 Self::from_vector(res) } } + // `xor` is left to the default trait implementation. The optimised + // packed routine assumes `2*lhs` and `2^31 - rhs` fit in a u32 without + // wrap-around, which only holds for `P < 2^31`. With the current 32-bit + // prime the wrap can push `lhs * rhs` past `2^32 * P` and break the + // Montgomery reduction. - #[inline(always)] - fn exp_const_u64(&self) -> Self { - // We provide specialised code for the powers 3, 5, 7 as these turn up regularly. - // The other powers could be specialised similarly but we ignore this for now. - // These ideas could also be used to speed up the more generic exp_u64. - match POWER { - 0 => Self::ONE, - 1 => *self, - 2 => self.square(), - 3 => self.cube(), - 4 => self.square().square(), - 5 => { - let val = self.to_vector(); - unsafe { - // Safety: `apply_func_to_even_odd` returns values in canonical form when given values in canonical form. 
-                let res = apply_func_to_even_odd::(val, packed_exp_5::);
-                Self::from_vector(res)
-            }
-        }
-        6 => self.square().cube(),
-        7 => {
-            let val = self.to_vector();
-            unsafe {
-                // Safety: `apply_func_to_even_odd` returns values in canonical form when given values in canonical form.
-                let res = apply_func_to_even_odd::(val, packed_exp_7::);
-                Self::from_vector(res)
-            }
-        }
-        _ => self.exp_u64(POWER),
-        }
-    }
+    // The `packed_exp_*` helpers (`apply_func_to_even_odd` + signed Montgomery)
+    // assume `P < 2^31`. With the current 32-bit prime they would overflow, so we
+    // fall back to the default `exp_const_u64` implementation, which composes
+    // canonical `square` / `cube` calls. Mirrors AVX512.

    #[inline(always)]
    fn dot_product(u: &[Self; N], v: &[Self; N]) -> Self {
@@ -280,38 +244,39 @@ impl Algebra> for PackedMontyField31AVX2

-/// Perform a partial Montgomery reduction on each 64 bit element.
-/// Input must lie in {0, ..., 2^32P}.
-/// The output will lie in {-P, ..., P} and be stored in the upper 32 bits.
+/// Perform a partial Montgomery reduction on each 64 bit element.
+/// Input must lie in {0, ..., 2^32P}.
+/// The output will be canonical, in {0, ..., P - 1}, stored in the upper 32 bits.
+///
+/// With the current prime (`P > 2^31`) we cannot represent the partial reduction
+/// "signed in (-P, P)" inside an `i32`, so we canonicalize directly: detect the
+/// underflow `top32(input) < top32(q_p)` (unsigned compare) and conditionally add `P`.
#[inline]
#[must_use]
-fn partial_monty_red_unsigned_to_signed(input: __m256i) -> __m256i {
+fn partial_monty_red_unsigned_to_canonical(input: __m256i) -> __m256i {
    unsafe {
        let q = x86_64::_mm256_mul_epu32(input, MPAVX2::PACKED_MU);
        let q_p = x86_64::_mm256_mul_epu32(q, MPAVX2::PACKED_P);
-        // By construction, the bottom 32 bits of input and q_p are equal.
-        // Thus _mm256_sub_epi32 and _mm256_sub_epi64 should act identically.
-        // However for some reason, the compiler gets confused if we use _mm256_sub_epi64
-        // and outputs a load of nonsense, see: https://godbolt.org/z/3W8M7Tv84.
-        x86_64::_mm256_sub_epi32(input, q_p)
-    }
-}
+        // `_mm256_sub_epi32` is fine here: the low 32 bits of `input` and `q_p` cancel by
+        // construction, so the result of this 32-bit subtraction at the upper-32-bit
+        // positions equals `top32(input) - top32(q_p)` (mod 2^32).
+        let raw = x86_64::_mm256_sub_epi32(input, q_p);

-/// Perform a partial Montgomery reduction on each 64 bit element.
-/// Input must lie in {-2^{31}P, ..., 2^31P}.
-/// The output will lie in {-P, ..., P} and be stored in the upper 32 bits.
-#[inline]
-#[must_use]
-fn partial_monty_red_signed_to_signed(input: __m256i) -> __m256i {
-    unsafe {
-        let q = x86_64::_mm256_mul_epi32(input, MPAVX2::PACKED_MU);
-        let q_p = x86_64::_mm256_mul_epi32(q, MPAVX2::PACKED_P);
+        // Detect underflow: `top32(input) < top32(q_p)` as unsigned.
+        // AVX2 has no native unsigned compare; emulate via `xor i32::MIN` + signed cmpgt.
+        let flip = x86_64::_mm256_set1_epi32(i32::MIN);
+        let input_f = x86_64::_mm256_xor_si256(input, flip);
+        let q_p_f = x86_64::_mm256_xor_si256(q_p, flip);
+        let underflow = x86_64::_mm256_cmpgt_epi32(q_p_f, input_f);
+        let corr = x86_64::_mm256_and_si256(underflow, MPAVX2::PACKED_P);

-        // Unlike the previous case the compiler output is essentially identical
-        // between _mm256_sub_epi32 and _mm256_sub_epi64. We use _mm256_sub_epi32
-        // again just for consistency.
-        x86_64::_mm256_sub_epi32(input, q_p)
+        // The compare runs on every 32-bit lane, but only the upper-32-bit positions of
+        // each 64-bit lane matter: that is where the subtraction result lives, and since
+        // the low halves of `input` and `q_p` are equal, no borrow propagates into it.
+        // At the lower-32-bit positions the operands agree, so the mask (and `corr`) is
+        // zero there; those lanes of `raw` are discarded by the `blend_evn_odd` /
+        // `movehdup` callers anyway.
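+        // Sign-flip example: as unsigned, 0xF000_0000 > 0x1000_0000. After xoring with
+        // 0x8000_0000 they become 0x7000_0000 and 0x9000_0000, i.e. positive vs negative
+        // as i32, so the signed `cmpgt` reproduces the unsigned ordering.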
+ x86_64::_mm256_add_epi32(raw, corr) } } @@ -336,55 +301,16 @@ fn blend_evn_odd(evn: __m256i, odd: __m256i) -> __m256i { } } -/// Given a vector of signed field elements, return a vector of elements in canonical form. -/// -/// Inputs must be signed 32-bit integers lying in (-P, ..., P). If they do not lie in -/// this range, the output is undefined. -#[inline(always)] -#[must_use] -fn red_signed_to_canonical(input: __m256i) -> __m256i { - unsafe { - // We want this to compile to: - // vpaddd corr, input, P - // vpminud res, input, corr - // throughput: 0.67 cyc/vec (12 els/cyc) - // latency: 2 cyc - - // We want to return input mod P where input lies in (-2^31 <) -P + 1, ..., P - 1 (< 2^31). - // It suffices to return input if input >= 0 and input + P otherwise. - // - // Let corr := (input + P) mod 2^32 and res := unsigned_min(input, corr). - // If input is in 0, ..., P - 1, then corr is in P, ..., 2 P - 1 and res = input. - // Otherwise, input is in -P + 1, ..., -1; corr is in 1, ..., P - 1 (< P) and res = corr. - // Hence, res is input if input < P and input + P otherwise, as desired. - let corr = x86_64::_mm256_add_epi32(input, MPAVX2::PACKED_P); - x86_64::_mm256_min_epu32(input, corr) - } -} - /// Multiply the MontyField31 field elements in the even index entries. /// lhs[2i], rhs[2i] must be unsigned 32-bit integers such that /// lhs[2i] * rhs[2i] lies in {0, ..., 2^32P}. -/// The output will lie in {-P, ..., P} and be stored in output[2i + 1]. +/// The output is canonical ([0, P)) and stored in output[2i + 1]. #[inline] #[must_use] fn monty_mul(lhs: __m256i, rhs: __m256i) -> __m256i { unsafe { let prod = x86_64::_mm256_mul_epu32(lhs, rhs); - partial_monty_red_unsigned_to_signed::(prod) - } -} - -/// Multiply the MontyField31 field elements in the even index entries. -/// lhs[2i], rhs[2i] must be signed 32-bit integers such that -/// lhs[2i] * rhs[2i] lies in {-2^31P, ..., 2^31P}. -/// The output will lie in {-P, ..., P} stored in output[2i + 1]. -#[inline] -#[must_use] -fn monty_mul_signed(lhs: __m256i, rhs: __m256i) -> __m256i { - unsafe { - let prod = x86_64::_mm256_mul_epi32(lhs, rhs); - partial_monty_red_signed_to_signed::(prod) + partial_monty_red_unsigned_to_canonical::(prod) } } @@ -396,7 +322,7 @@ fn movehdup_epi32(x: __m256i) -> __m256i { unsafe { x86_64::_mm256_castps_si256(x86_64::_mm256_movehdup_ps(x86_64::_mm256_castsi256_ps(x))) } } -/// Multiply unsigned vectors of field elements returning a vector of signed integers lying in (-P, P). +/// Multiply unsigned vectors of field elements returning a vector of canonical results in [0, P). /// /// Inputs are allowed to not be in canonical form however they must obey the bound `lhs*rhs < 2^32P`. If this bound /// is broken, the output is undefined. @@ -468,195 +394,13 @@ impl IntoM256 for MontyField31 { } } -/// Compute the elementary function `l0*r0 + l1*r1` given four inputs -/// in canonical form. -/// -/// If the inputs are not in canonical form, the result is undefined. -#[inline] -#[must_use] -fn dot_product_2, RHS: IntoM256>( - lhs: [LHS; 2], - rhs: [RHS; 2], -) -> __m256i { - // The following analysis treats all input arrays as being arrays of PackedMontyField31AVX2. - // If one of the arrays contains MontyField31, we get to avoid the initial vmovshdup. - // - // We improve the throughput by combining the monty reductions together. 
As all inputs are - // `< P < 2^{31}`, `l0*r0 + l1*r1 < 2P^2 < 2^{32}P` so the montgomery reduction - // algorithm can be applied to the sum of the products instead of to each product individually. - // - // We want this to compile to: - // vmovshdup lhs_odd0, lhs0 - // vmovshdup rhs_odd0, rhs0 - // vmovshdup lhs_odd1, lhs1 - // vmovshdup rhs_odd1, rhs1 - // vpmuludq prod_evn0, lhs0, rhs0 - // vpmuludq prod_odd0, lhs_odd0, rhs_odd0 - // vpmuludq prod_evn1, lhs1, rhs1 - // vpmuludq prod_odd1, lhs_odd1, rhs_odd1 - // vpaddq prod_evn, prod_evn0, prod_evn1 - // vpaddq prod_odd, prod_odd0, prod_odd1 - // vpmuludq q_evn, prod_evn, MU - // vpmuludq q_odd, prod_odd, MU - // vpmuludq q_P_evn, q_evn, P - // vpmuludq q_P_odd, q_odd, P - // vpsubq d_evn, prod_evn, q_P_evn - // vpsubq d_odd, prod_odd, q_P_odd - // vmovshdup d_evn_hi, d_evn - // vpblendd t, d_evn_hi, d_odd, aah - // vpaddd u, t, P - // vpminud res, t, u - // throughput: 6.67 cyc/vec (1.20 els/cyc) - // latency: 21 cyc - unsafe { - let lhs_evn0 = lhs[0].as_m256i(); - let lhs_odd0 = lhs[0].as_shifted_m256i(); - let lhs_evn1 = lhs[1].as_m256i(); - let lhs_odd1 = lhs[1].as_shifted_m256i(); - - let rhs_evn0 = rhs[0].as_m256i(); - let rhs_odd0 = rhs[0].as_shifted_m256i(); - let rhs_evn1 = rhs[1].as_m256i(); - let rhs_odd1 = rhs[1].as_shifted_m256i(); - - let mul_evn0 = x86_64::_mm256_mul_epu32(lhs_evn0, rhs_evn0); - let mul_evn1 = x86_64::_mm256_mul_epu32(lhs_evn1, rhs_evn1); - let mul_odd0 = x86_64::_mm256_mul_epu32(lhs_odd0, rhs_odd0); - let mul_odd1 = x86_64::_mm256_mul_epu32(lhs_odd1, rhs_odd1); - - let dot_evn = x86_64::_mm256_add_epi64(mul_evn0, mul_evn1); - let dot_odd = x86_64::_mm256_add_epi64(mul_odd0, mul_odd1); - - let red_evn = partial_monty_red_unsigned_to_signed::(dot_evn); - let red_odd = partial_monty_red_unsigned_to_signed::(dot_odd); - - let t = blend_evn_odd(red_evn, red_odd); - red_signed_to_canonical::(t) - } -} - -/// Compute the elementary function `l0*r0 + l1*r1 + l2*r2 + l3*r3` given eight inputs -/// in canonical form. -/// -/// If the inputs are not in canonical form, the result is undefined. -#[inline] -#[must_use] -#[allow(private_bounds)] -pub fn dot_product_4, RHS: IntoM256>( - lhs: [LHS; 4], - rhs: [RHS; 4], -) -> __m256i { - // The following analysis treats all input arrays as being arrays of PackedMontyField31AVX2. - // If one of the arrays contains MontyField31, we get to avoid the initial vmovshdup. - // - // Similarly to dot_product_2, we improve throughput by combining monty reductions however in this case - // we will need to slightly adjust the reduction algorithm. - // - // As all inputs are `< P < 2^{31}`, the sum satisfies: `C = l0*r0 + l1*r1 + l2*r2 + l3*r3 < 4P^2 < 2*2^{32}P`. - // Start by computing Q := μ C mod B as usual. - // We can't proceed as normal however as 2*2^{32}P > C - QP > -2^{32}P which doesn't fit into an i64. - // Instead we do a reduction on C, defining C' = if C < 2^{32}P: {C} else {C - 2^{32}P} - // From here we proceed with the standard montgomery reduction with C replaced by C'. It works identically - // with the Q we already computed as C = C' mod B. 
- // - // We want this to compile to: - // vmovshdup lhs_odd0, lhs0 - // vmovshdup rhs_odd0, rhs0 - // vmovshdup lhs_odd1, lhs1 - // vmovshdup rhs_odd1, rhs1 - // vmovshdup lhs_odd2, lhs2 - // vmovshdup rhs_odd2, rhs2 - // vmovshdup lhs_odd3, lhs3 - // vmovshdup rhs_odd3, rhs3 - // vpmuludq prod_evn0, lhs0, rhs0 - // vpmuludq prod_odd0, lhs_odd0, rhs_odd0 - // vpmuludq prod_evn1, lhs1, rhs1 - // vpmuludq prod_odd1, lhs_odd1, rhs_odd1 - // vpmuludq prod_evn2, lhs2, rhs2 - // vpmuludq prod_odd2, lhs_odd2, rhs_odd2 - // vpmuludq prod_evn3, lhs3, rhs3 - // vpmuludq prod_odd3, lhs_odd3, rhs_odd3 - // vpaddq prod_evn01, prod_evn0, prod_evn1 - // vpaddq prod_odd01, prod_odd0, prod_odd1 - // vpaddq prod_evn23, prod_evn2, prod_evn3 - // vpaddq prod_odd23, prod_odd2, prod_odd3 - // vpaddq dot_evn, prod_evn01, prod_evn23 - // vpaddq dot_odd, prod_odd01, prod_odd23 - // vmovshdup dot_evn_hi, dot_evn - // vpblendd dot, dot_evn_hi, dot_odd, aah - // vpmuludq q_evn, dot_evn, MU - // vpmuludq q_odd, dot_odd, MU - // vpmuludq q_P_evn, q_evn, P - // vpmuludq q_P_odd, q_odd, P - // vmovshdup q_P_evn_hi, q_P_evn - // vpblendd q_P, q_P_evn_hi, q_P_odd, aah - // vpsubq dot_sub, dot, P - // vpminud dot_prime, dot, dot_sub - // vpsubq t, dot_prime, q_P - // vpaddd u, t, P - // vpminud res, t, u - // throughput: 11.67 cyc/vec (0.69 els/cyc) - // latency: 22 cyc - unsafe { - let lhs_evn0 = lhs[0].as_m256i(); - let lhs_odd0 = lhs[0].as_shifted_m256i(); - let lhs_evn1 = lhs[1].as_m256i(); - let lhs_odd1 = lhs[1].as_shifted_m256i(); - let lhs_evn2 = lhs[2].as_m256i(); - let lhs_odd2 = lhs[2].as_shifted_m256i(); - let lhs_evn3 = lhs[3].as_m256i(); - let lhs_odd3 = lhs[3].as_shifted_m256i(); - - let rhs_evn0 = rhs[0].as_m256i(); - let rhs_odd0 = rhs[0].as_shifted_m256i(); - let rhs_evn1 = rhs[1].as_m256i(); - let rhs_odd1 = rhs[1].as_shifted_m256i(); - let rhs_evn2 = rhs[2].as_m256i(); - let rhs_odd2 = rhs[2].as_shifted_m256i(); - let rhs_evn3 = rhs[3].as_m256i(); - let rhs_odd3 = rhs[3].as_shifted_m256i(); - - let mul_evn0 = x86_64::_mm256_mul_epu32(lhs_evn0, rhs_evn0); - let mul_evn1 = x86_64::_mm256_mul_epu32(lhs_evn1, rhs_evn1); - let mul_evn2 = x86_64::_mm256_mul_epu32(lhs_evn2, rhs_evn2); - let mul_evn3 = x86_64::_mm256_mul_epu32(lhs_evn3, rhs_evn3); - let mul_odd0 = x86_64::_mm256_mul_epu32(lhs_odd0, rhs_odd0); - let mul_odd1 = x86_64::_mm256_mul_epu32(lhs_odd1, rhs_odd1); - let mul_odd2 = x86_64::_mm256_mul_epu32(lhs_odd2, rhs_odd2); - let mul_odd3 = x86_64::_mm256_mul_epu32(lhs_odd3, rhs_odd3); - - let dot_evn01 = x86_64::_mm256_add_epi64(mul_evn0, mul_evn1); - let dot_odd01 = x86_64::_mm256_add_epi64(mul_odd0, mul_odd1); - let dot_evn23 = x86_64::_mm256_add_epi64(mul_evn2, mul_evn3); - let dot_odd23 = x86_64::_mm256_add_epi64(mul_odd2, mul_odd3); - - let dot_evn = x86_64::_mm256_add_epi64(dot_evn01, dot_evn23); - let dot_odd = x86_64::_mm256_add_epi64(dot_odd01, dot_odd23); - - // We only care about the top 32 bits of dot_evn/odd. 
- // They currently lie in [0, 2P] so we reduce them to [0, P) - let dot = blend_evn_odd(dot_evn, dot_odd); - let dot_sub = x86_64::_mm256_sub_epi32(dot, PMP::PACKED_P); - let dot_prime = x86_64::_mm256_min_epu32(dot, dot_sub); - - let q_evn = x86_64::_mm256_mul_epu32(dot_evn, PMP::PACKED_MU); - let q_p_evn = x86_64::_mm256_mul_epu32(q_evn, PMP::PACKED_P); - let q_odd = x86_64::_mm256_mul_epu32(dot_odd, PMP::PACKED_MU); - let q_p_odd = x86_64::_mm256_mul_epu32(q_odd, PMP::PACKED_P); - - // Similarly we only need to care about the top 32 bits of q_p_odd/evn - let q_p = blend_evn_odd(q_p_evn, q_p_odd); - - let t = x86_64::_mm256_sub_epi32(dot_prime, q_p); - red_signed_to_canonical::(t) - } -} +// `dot_product_2` and `dot_product_4` were specialised batched-Montgomery dot products that +// accumulated two or four 64-bit products into a single u64 before reduction. That bound only +// holds for `P < 2^31`. With the current 32-bit prime (`0xfa000001`) two products already +// overflow u64, so the helpers have been removed. `general_dot_product` accumulates each +// product through the canonical-returning `Mul` instead. /// A general fast dot product implementation. -/// -/// Maximises the number of calls to `dot_product_4` for dot products involving vectors of length -/// more than 4. The length 64 occurs commonly enough it's useful to have a custom implementation -/// which lets it use a slightly better summation algorithm with lower latency. #[inline(always)] fn general_dot_product, RHS: IntoM256, const N: usize>( lhs: &[LHS], @@ -664,140 +408,14 @@ fn general_dot_product, RHS: IntoM256 ) -> PackedMontyField31AVX2 { assert_eq!(lhs.len(), N); assert_eq!(rhs.len(), N); - match N { - 0 => PackedMontyField31AVX2::::ZERO, - 1 => (lhs[0]).into() * (rhs[0]).into(), - 2 => { - let res = dot_product_2([lhs[0], lhs[1]], [rhs[0], rhs[1]]); - unsafe { - // Safety: `dot_product_2` returns values in canonical form when given values in canonical form. - PackedMontyField31AVX2::::from_vector(res) - } - } - 3 => { - let lhs2 = lhs[2]; - let rhs2 = rhs[2]; - let res = dot_product_2([lhs[0], lhs[1]], [rhs[0], rhs[1]]); - unsafe { - // Safety: `dot_product_2` returns values in canonical form when given values in canonical form. - PackedMontyField31AVX2::::from_vector(res) + (lhs2.into() * rhs2.into()) - } - } - 4 => { - let res = dot_product_4([lhs[0], lhs[1], lhs[2], lhs[3]], [rhs[0], rhs[1], rhs[2], rhs[3]]); - unsafe { - // Safety: `dot_product_4` returns values in canonical form when given values in canonical form. - PackedMontyField31AVX2::::from_vector(res) - } - } - 64 => { - let sum_4s: [PackedMontyField31AVX2; 16] = array::from_fn(|i| { - let res = dot_product_4( - [lhs[4 * i], lhs[4 * i + 1], lhs[4 * i + 2], lhs[4 * i + 3]], - [rhs[4 * i], rhs[4 * i + 1], rhs[4 * i + 2], rhs[4 * i + 3]], - ); - unsafe { - // Safety: `dot_product_4` returns values in canonical form when given values in canonical form. - PackedMontyField31AVX2::::from_vector(res) - } - }); - PackedMontyField31AVX2::::sum_array::<16>(&sum_4s) - } - _ => { - let mut acc = { - let res = dot_product_4([lhs[0], lhs[1], lhs[2], lhs[3]], [rhs[0], rhs[1], rhs[2], rhs[3]]); - unsafe { - // Safety: `dot_product_4` returns values in canonical form when given values in canonical form. 
- PackedMontyField31AVX2::::from_vector(res) - } - }; - for i in (4..(N - 3)).step_by(4) { - let res = dot_product_4( - [lhs[i], lhs[i + 1], lhs[i + 2], lhs[i + 3]], - [rhs[i], rhs[i + 1], rhs[i + 2], rhs[i + 3]], - ); - unsafe { - // Safety: `dot_product_4` returns values in canonical form when given values in canonical form. - acc += PackedMontyField31AVX2::::from_vector(res) - } - } - match N & 3 { - 0 => acc, - 1 => acc + general_dot_product::<_, _, _, 1>(&lhs[(4 * (N / 4))..], &rhs[(4 * (N / 4))..]), - 2 => acc + general_dot_product::<_, _, _, 2>(&lhs[(4 * (N / 4))..], &rhs[(4 * (N / 4))..]), - 3 => acc + general_dot_product::<_, _, _, 3>(&lhs[(4 * (N / 4))..], &rhs[(4 * (N / 4))..]), - _ => unreachable!(), - } - } + if N == 0 { + return PackedMontyField31AVX2::::ZERO; } -} - -/// Square the MontyField31 field elements in the even index entries. -/// Inputs must be signed 32-bit integers. -/// Outputs will be a signed integer in (-P, ..., P) copied into both the even and odd indices. -#[inline] -#[must_use] -fn shifted_square(input: __m256i) -> __m256i { - // Note that we do not need a restriction on the size of input[i]^2 as - // 2^30 < P and |i32| <= 2^31 and so => input[i]^2 <= 2^62 < 2^32P. - unsafe { - let square = x86_64::_mm256_mul_epi32(input, input); - let square_red = partial_monty_red_unsigned_to_signed::(square); - movehdup_epi32(square_red) - } -} - -/// Compute the elementary arithmetic generalization of `xor`, namely `xor(l, r) = l + r - 2lr` of -/// vectors in canonical form. -/// -/// Inputs are assumed to be in canonical form, if the inputs are not in canonical form, the result is undefined. -#[inline] -#[must_use] -fn xor(lhs: __m256i, rhs: __m256i) -> __m256i { - // Refactor the expression as r + 2l(1/2 - r). As MONTY_CONSTANT = 2^32, the internal - // representation 1/2 is 2^31 mod P so the product in the above expression is represented - // as 2l(2^31 - r). As 0 < 2l, 2^31 - r < 2^32 and 2l(2^31 - r) < 2^32P, we can compute - // the factors as 32 bit integers and then multiply and monty reduce as usual. - // - // We want this to compile to: - // vpaddd lhs_double, lhs, lhs - // vpsubd sub_rhs, rhs, (1 << 31) - // vmovshdup lhs_odd, lhs_double - // vmovshdup rhs_odd, sub_rhs - // vpmuludq prod_evn, lhs_double, sub_rhs - // vpmuludq prod_odd, lhs_odd, rhs_odd - // vpmuludq q_evn, prod_evn, MU - // vpmuludq q_odd, prod_odd, MU - // vpmuludq q_P_evn, q_evn, P - // vpmuludq q_P_odd, q_odd, P - // vpsubq d_evn, prod_evn, q_P_evn - // vpsubq d_odd, prod_odd, q_P_odd - // vmovshdup d_evn_hi, d_evn - // vpblendd t, d_evn_hi, d_odd, aah - // vpsignd pos_neg_P, P, t - // vpaddd sum, rhs, t - // vpsubd sum_corr, sum, pos_neg_P - // vpminud res, sum, sum_corr - // throughput: 6 cyc/vec (1.33 els/cyc) - // latency: 22 cyc - unsafe { - // 0 <= 2*lhs < 2P - let double_lhs = x86_64::_mm256_add_epi32(lhs, lhs); - - // Note that 2^31 is represented as an i32 as (-2^31). - // Compiler should realise this is a constant. - let half = x86_64::_mm256_set1_epi32(-1 << 31); - - // 0 < 2^31 - rhs < 2^31 - let half_sub_rhs = x86_64::_mm256_sub_epi32(half, rhs); - - // 2*lhs (2^31 - rhs) < 2P 2^31 < 2^32P so we can use the multiplication function. - let mul_res = mul::(double_lhs, half_sub_rhs); - - // As -P < mul_res < P and 0 <= rhs < P, we can use signed add - // which saves an instruction over reducing mul_res and adding in the usual way. 
-        signed_add_avx2::(rhs, mul_res)
    }
 }

 /// Compute the elementary arithmetic generalization of `andnot`, namely `andn(l, r) = (1 - l)r` of
@@ -830,102 +448,55 @@ fn andn(lhs: __m256i, rhs: __m256i) -> __m256i {
    // throughput: 5 cyc/vec (1.6 els/cyc)
    // latency: 20 cyc
    unsafe {
-        // We use 2^32 - P instead of 2^32 to avoid having to worry about 0's in lhs.
-
-        // Compiler should realise that this is a constant.
-        let neg_p = x86_64::_mm256_sub_epi32(x86_64::_mm256_setzero_si256(), MPAVX2::PACKED_P);
-        let neg_lhs = x86_64::_mm256_sub_epi32(neg_p, lhs);
-
-        // 2*lhs (2^31 - rhs) < 2P 2^31 < 2^32P so we can use the multiplication function.
-        let mul_res = mul::(neg_lhs, rhs);
+        // M(1) = 2^32 mod P. For the current prime (2^31 < P < 2^32) a single
+        // subtraction reduces it, so M(1) = 2^32 - P is canonical. We need
+        // `neg_lhs = M(1) - lhs (mod P)`, i.e. a canonical value in [0, P).
+        //
+        // For P < 2^31 the old `(2^32 - P) - lhs` computed mod 2^32 never wraps
+        // (lhs < P <= 2^32 - P), so it stays congruent to `M(1) - lhs` mod P and
+        // below `2^32 - P`, which keeps the product bound `neg_lhs * rhs < 2^32 * P`.
+        // For P > 2^31, M(1) = 2^32 - P < P, so `lhs` may exceed M(1) and the
+        // subtraction wraps mod 2^32, producing a value in [M(1)+1, 2^32) that is
+        // no longer congruent to `M(1) - lhs` mod P (it differs by `2^32 mod P`).
+        //
+        // Use `mm256_mod_sub` to do the subtraction correctly mod P for any
+        // `P < 2^32`.
+        let one_monty = x86_64::_mm256_sub_epi32(x86_64::_mm256_setzero_si256(), MPAVX2::PACKED_P);
+        let neg_lhs = mm256_mod_sub(one_monty, lhs, MPAVX2::PACKED_P);

-        // As -P < mul_res < P we just need to reduce elements to canonical form.
-        red_signed_to_canonical::(mul_res)
+        // `neg_lhs` is canonical in [0, P), so `neg_lhs * rhs < 2^32 * P` and
+        // we can apply the standard Montgomery reduction. `mul` returns a
+        // canonical result.
+        mul::(neg_lhs, rhs)
    }
 }

-/// Cube the MontyField31 field elements in the even index entries.
-/// Inputs must be signed 32-bit integers in [-P, ..., P].
-/// Outputs will be a signed integer in (-P, ..., P) stored in the odd indices.
-#[inline]
-#[must_use]
-pub(crate) fn packed_exp_3(input: __m256i) -> __m256i {
-    let square = shifted_square::(input);
-    monty_mul_signed::(square, input)
-}
-
-/// Take the fifth power of the MontyField31 field elements in the even index entries.
-/// Inputs must be signed 32-bit integers in [-P, ..., P].
-/// Outputs will be a signed integer in (-P, ..., P) stored in the odd indices.
-#[inline]
-#[must_use]
-pub(crate) fn packed_exp_5(input: __m256i) -> __m256i {
-    let square = shifted_square::(input);
-    let quad = shifted_square::(square);
-    monty_mul_signed::(quad, input)
-}
-
-/// Take the seventh power of the MontyField31 field elements in the even index entries.
-/// Inputs must lie in [-P, ..., P].
-/// Outputs will also lie in (-P, ..., P) stored in the odd indices.
-#[inline]
-#[must_use]
-pub(crate) fn packed_exp_7(input: __m256i) -> __m256i {
-    let square = shifted_square::(input);
-    let cube = monty_mul_signed::(square, input);
-    let cube_shifted = movehdup_epi32(cube);
-    let quad = shifted_square::(square);
-
-    monty_mul_signed::(quad, cube_shifted)
-}
-
-/// Apply func to the even and odd indices of the input vector.
-/// func should only depend in the 32 bit entries in the even indices.
-/// The output of func must lie in (-P, ..., P) and be stored in the odd indices.
-/// The even indices of the output of func will not be read.
-/// The input should conform to the requirements of `func`.
-#[inline] -#[must_use] -pub(crate) unsafe fn apply_func_to_even_odd( - input: __m256i, - func: fn(__m256i) -> __m256i, -) -> __m256i { - let input_evn = input; - let input_odd = movehdup_epi32(input); - - let d_evn = func(input_evn); - let d_odd = func(input_odd); - - let t = blend_evn_odd(d_evn, d_odd); - red_signed_to_canonical::(t) -} +// `packed_exp_*`, `apply_func_to_even_odd`, `monty_mul_signed`, +// `partial_monty_red_signed_to_signed`, `shifted_square`, and +// `red_signed_to_canonical` were used to implement the optimised cube / +// exp_5 / exp_7 paths. They all relied on a "signed in `(-P, P)` stored as +// i32" intermediate representation, which is unambiguous only when +// `P < 2^31`. With the current 32-bit prime they were producing wrong +// results, so the optimised paths have been removed and `cube` / +// `exp_const_u64` now go through the canonical-returning `Mul` directly. /// Negate a vector of MontyField31 field elements in canonical form. /// If the inputs are not in canonical form, the result is undefined. #[inline] #[must_use] fn neg(val: __m256i) -> __m256i { - // We want this to compile to: - // vpsubd t, P, val - // vpsignd res, t, val - // throughput: .67 cyc/vec (12 els/cyc) - // latency: 2 cyc - - // The vpsignd instruction is poorly named, because it doesn't _return_ or _copy_ the sign of - // anything, but _multiplies_ x by the sign of y (treating both as signed integers). In other - // words, - // { x if y >s 0, - // vpsignd(x, y) := { 0 if y = 0, - // { -x mod 2^32 if y s 0. If val = 0, then res = vpsignd(t, 0) = 0, as desired. Otherwise, - // res = vpsignd(t, val) = t passes t through. + // We want to return `(P - val) mod P`, i.e. `0` if `val == 0` else `P - val`. + // + // The previous implementation used `vpsignd(P - val, val)`, which exploits the + // fact that `val` is non-negative as a signed i32 when `P < 2^31`. With the + // current 32-bit prime (`P > 2^31`) `val` can have its sign bit set and + // `vpsignd` would return `-(P - val)` instead of `(P - val)`. So we use a + // mask-based approach that works for any `P < 2^32`. unsafe { // Safety: If this code got compiled then AVX2 intrinsics are available. let t = x86_64::_mm256_sub_epi32(MPAVX2::PACKED_P, val); - x86_64::_mm256_sign_epi32(t, val) + let zero = x86_64::_mm256_setzero_si256(); + let is_zero = x86_64::_mm256_cmpeq_epi32(val, zero); + // `andnot(is_zero, t)` = `(!is_zero) & t` = `t` when `val != 0`, else `0`. + x86_64::_mm256_andnot_si256(is_zero, t) } } diff --git a/crates/backend/koala-bear/src/monty_31/x86_64_avx2/poseidon_helpers.rs b/crates/backend/koala-bear/src/monty_31/x86_64_avx2/poseidon_helpers.rs index 396a4635a..0e1fd16b3 100644 --- a/crates/backend/koala-bear/src/monty_31/x86_64_avx2/poseidon_helpers.rs +++ b/crates/backend/koala-bear/src/monty_31/x86_64_avx2/poseidon_helpers.rs @@ -1,84 +1,9 @@ // Credits: Plonky3 (https://github.com/Plonky3/Plonky3) (MIT and Apache-2.0 licenses). //! AVX2 helpers shared by Poseidon1 permutations. - -use core::arch::x86_64::{self, __m256i}; -use core::mem::transmute; - -use super::{apply_func_to_even_odd, packed_exp_3, packed_exp_5, packed_exp_7}; -use crate::{MontyParameters, PackedMontyField31AVX2, PackedMontyParameters}; - -/// A specialized representation of the Poseidon state for a width of 16. -/// -/// Splits the state into `s0` (undergoes S-box) and `s_hi` (undergoes only linear transforms), -/// enabling instruction-level parallelism between the two independent data paths. 
-#[derive(Clone, Copy)]
-#[repr(C)]
-pub struct InternalLayer16 {
-    pub(crate) s0: PackedMontyField31AVX2,
-    pub(crate) s_hi: [__m256i; 15],
-}
-
-impl InternalLayer16 {
-    #[inline]
-    pub(crate) unsafe fn to_packed_field_array(self) -> [PackedMontyField31AVX2; 16] {
-        unsafe { transmute(self) }
-    }
-
-    #[inline]
-    #[must_use]
-    pub(crate) fn from_packed_field_array(vector: [PackedMontyField31AVX2; 16]) -> Self {
-        unsafe { transmute(vector) }
-    }
-}
-
-/// Use hard coded methods to compute `x -> x^D` for the even index entries and small `D`.
-/// Inputs should be signed 32-bit integers in `[-P, ..., P]`.
-/// Outputs will also be signed integers in `(-P, ..., P)` stored in the odd indices.
-#[inline(always)]
-#[must_use]
-pub(crate) fn exp_small(val: __m256i) -> __m256i {
-    match D {
-        3 => packed_exp_3::(val),
-        5 => packed_exp_5::(val),
-        7 => packed_exp_7::(val),
-        _ => panic!("No exp function for given D"),
-    }
-}
-
-/// Converts a scalar constant into a packed AVX2 vector in "negative form" (`c - P`).
-#[inline(always)]
-pub(crate) fn convert_to_vec_neg_form(input: i32) -> __m256i {
-    let input_sub_p = input - (MP::PRIME as i32);
-    unsafe { x86_64::_mm256_set1_epi32(input_sub_p) }
-}
-
-/// Performs the fused AddRoundConstant and S-Box operation `x -> (x + c)^D`.
-///
-/// `val` must contain elements in canonical form `[0, P)`.
-/// `rc` must contain round constants in negative form `[-P, 0)`.
-#[inline(always)]
-pub(crate) fn add_rc_and_sbox(
-    val: &mut PackedMontyField31AVX2,
-    rc: __m256i,
-) {
-    unsafe {
-        let vec_val = val.to_vector();
-        let val_plus_rc = x86_64::_mm256_add_epi32(vec_val, rc);
-        let output = apply_func_to_even_odd::(val_plus_rc, exp_small::);
-        *val = PackedMontyField31AVX2::::from_vector(output);
-    }
-}
-
-/// Applies the S-Box `x -> x^D` to a packed vector. Output is in canonical form.
-#[inline(always)]
-#[must_use]
-pub(crate) fn sbox(
-    val: PackedMontyField31AVX2,
-) -> PackedMontyField31AVX2 {
-    unsafe {
-        let vec = val.to_vector();
-        let out = apply_func_to_even_odd::(vec, exp_small::);
-        PackedMontyField31AVX2::::from_vector(out)
-    }
-}
+//!
+//! The optimised batched-S-box helpers (`exp_small`, `add_rc_and_sbox`,
+//! `sbox`, `InternalLayer16`) relied on a "signed in (-P, P)" intermediate
+//! stored as i32, which is unambiguous only for `P < 2^31`. They have been
+//! removed for the current 32-bit prime; AVX2 Poseidon falls back to the
+//! generic `permute_generic` path, which uses the canonical `Mul`.
diff --git a/crates/backend/koala-bear/src/monty_31/x86_64_avx2/utils.rs b/crates/backend/koala-bear/src/monty_31/x86_64_avx2/utils.rs
index 8c6f01f09..456816ad1 100644
--- a/crates/backend/koala-bear/src/monty_31/x86_64_avx2/utils.rs
+++ b/crates/backend/koala-bear/src/monty_31/x86_64_avx2/utils.rs
@@ -3,7 +3,7 @@
 use core::arch::x86_64::{self, __m256i};
 use core::mem::transmute;

-use crate::{MontyParameters, MontyParametersAVX2, TwoAdicData};
+use crate::{FieldParameters, MontyParameters, MontyParametersAVX2, TwoAdicData};

 // Godbolt file showing that these all compile to the expected instructions. (Potentially plus a few memory ops):
 // https://godbolt.org/z/9P71nYrqh
@@ -12,28 +12,20 @@ use crate::{FieldParameters, MontyParameters, MontyParametersAVX2, TwoAdicData};
 ///
 /// If the inputs are not in canonical form, the result is undefined.
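+/// For odd `val` the identity is `val/2 mod P = (val >> 1) + (P + 1)/2`: e.g. with
+/// P = 0xFA00_0001 and val = 3 this gives 1 + 0x7D00_0001 = 0x7D00_0002, and indeed
+/// 2 * 0x7D00_0002 = P + 3 ≡ 3 (mod P).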
 #[inline(always)]
-pub(crate) fn halve_avx2(input: __m256i) -> __m256i {
+pub(crate) fn halve_avx2(input: __m256i) -> __m256i {
    /*
-        We want this to compile to:
-            vpand    least_bit, val, ONE
-            vpsrld   t, val, 1
-            vpsignd  maybe_half, HALF, least_bit
-            vpaddd   res, t, maybe_half
-        throughput: 1.33 cyc/vec
-        latency: 3 cyc
-
        Given an element val in [0, P), we want to compute val/2 mod P.
        If val is even: val/2 mod P = val/2 = val >> 1.
        If val is odd: val/2 mod P = (val + P)/2 = (val >> 1) + (P + 1)/2
    */
    unsafe {
-        // Safety: If this code got compiled then AVX2 intrinsics are available.
        const ONE: __m256i = unsafe { transmute([1u32; 8]) };
-        let half = x86_64::_mm256_set1_epi32((MP::PRIME as i32 + 1) / 2); // Compiler realises this is constant.
+        // HALF_P_PLUS_1 = (P + 1) / 2, computed correctly at u32 level.
+        // For P = 0xFA000001: HALF_P_PLUS_1 = 0x7D000001, which fits in a positive i32.
+        let half = x86_64::_mm256_set1_epi32(FP::HALF_P_PLUS_1 as i32);

-        let least_bit = x86_64::_mm256_and_si256(input, ONE); // Determine the parity of val.
+        let least_bit = x86_64::_mm256_and_si256(input, ONE);
        let t = x86_64::_mm256_srli_epi32::<1>(input);
-        // This does nothing when least_bit = 1 and sets the corresponding entry to 0 when least_bit = 0
        let maybe_half = x86_64::_mm256_sign_epi32(half, least_bit);
        x86_64::_mm256_add_epi32(t, maybe_half)
    }
@@ -45,45 +37,46 @@ pub(crate) fn halve_avx2(input: __m256i) -> __m256i {
 ///
 /// This function is not symmetric in the inputs. The caller must ensure that inputs
 /// conform to the expected representation. Each element of lhs must lie in [0, P) and
-/// each element of rhs in (-P, P).
+/// each element of rhs must encode a value D with -P < D < P, stored as D (u32) when
+/// D >= 0 and as 2^32 + D (u32 wrap-around) when D < 0.
+///
+/// Caution: for P > 2^31 the two ranges [0, P) and (2^32 - P, 2^32) overlap, so this
+/// encoding is ambiguous unless the caller additionally guarantees |D| < 2^32 - P.
+/// Under that restriction (which is automatic for P <= 2^31, where the ranges are
+/// disjoint) the decode rule used below, "D < 0 iff rhs_u32 >= P", is valid.
+///
+/// The output is in [0, P) canonical form.
 #[inline(always)]
 pub(crate) unsafe fn signed_add_avx2(lhs: __m256i, rhs: __m256i) -> __m256i {
-    /*
-        We want this to compile to:
-            vpsignd  pos_neg_P, P, rhs
-            vpaddd   sum, lhs, rhs
-            vpsubd   sum_corr, sum, pos_neg_P
-            vpminud  res, sum, sum_corr
-        throughput: 1.33 cyc/vec
-        latency: 3 cyc
-
-        While this is more expensive than an add, it is cheaper than reducing the rhs to a canonical value and then adding.
-
-        We give a short proof that the output is correct:
-
-        Let t = lhs + rhs mod 2^32, we want to return t mod P while correcting for any possible wraparound.
-        We make use of the fact wrapping addition acts identically on signed and unsigned inputs.
-
-        If rhs is positive, lhs + rhs < 2P < 2^32 and so we interpret t as a unsigned 32 bit integer.
-        In this case, t mod P = min_{u32}(t, t - P) where min_{u32} takes the min treating both inputs as unsigned 32 bit integers.
-        This works as if t >= P then t - P < t and if t < P then, due to wraparound, t - P outputs t - P + 2^32 > t.
-        If rhs is negative, -2^31 < -P < lhs + rhs < P < 2^31 so we interpret t as a signed 32 bit integer.
-        In this case t mod P = min_{u32}(t, t + P)
-        This works as if t > 0 then t < t + P and if t < 0 then due to wraparound when we interpret t as an unsigned integer it becomes
-        2^32 + t > t + P.
-        if rhs = 0 then we can just return t = lhs as it is already in the desired range.
-    */
+    // For P > 2^31 we cannot use the signed-min trick above. Instead:
+    // 1. Canonicalize rhs from the signed encoding to [0, P): detect negative
+    //    values and add P.
+    // 2. Do a standard canonical mod-P add.
+    //
+    // Encoding recap: if D >= 0, rhs_u32 = D lies in [0, P);
+    //                 if D < 0,  rhs_u32 = 2^32 + D lies in (2^32 - P, 2^32).
+    // For P > 2^31 these ranges overlap, so we rely on the caller's guarantee
+    // |D| < 2^32 - P, which shrinks them to [0, 2^32 - P) and (P, 2^32) and makes
+    // "rhs_u32 >= P iff D < 0" a valid test. For P <= 2^31 the ranges are already
+    // disjoint and the same test works unconditionally.
    unsafe {
-        // If rhs > 0 set the value to P, if rhs < 0 set it to -P and if rhs = 0 set it to 0.
-        let pos_neg_p = x86_64::_mm256_sign_epi32(MPAVX2::PACKED_P, rhs);
-
-        // Compute t = lhs + rhs
-        let sum = x86_64::_mm256_add_epi32(lhs, rhs);
-
-        // sum_corr = (t - P) if rhs > 0, t + P if rhs < 0 and t if rhs = 0 as desired.
-        let sum_corr = x86_64::_mm256_sub_epi32(sum, pos_neg_p);
-
-        x86_64::_mm256_min_epu32(sum, sum_corr)
+        // Detect negative: rhs >= P as unsigned, emulated via the sign-flip trick.
+        let flip = x86_64::_mm256_set1_epi32(i32::MIN);
+        let rhs_f = x86_64::_mm256_xor_si256(rhs, flip);
+        let p_m1_f = x86_64::_mm256_xor_si256(
+            x86_64::_mm256_sub_epi32(MPAVX2::PACKED_P, x86_64::_mm256_set1_epi32(1)),
+            flip,
+        );
+        let is_neg = x86_64::_mm256_cmpgt_epi32(rhs_f, p_m1_f); // rhs > P - 1 unsigned => negative
+        let corr = x86_64::_mm256_and_si256(is_neg, MPAVX2::PACKED_P);
+        let rhs_canon = x86_64::_mm256_add_epi32(rhs, corr); // canonicalize: add P if negative
+
+        // Standard modular add of two canonical values.
+        let t = x86_64::_mm256_add_epi32(lhs, rhs_canon);
+        let u = x86_64::_mm256_sub_epi32(t, MPAVX2::PACKED_P);
+        // Reduce when the 32-bit add wrapped (t < lhs as unsigned) or when t >= P.
+        let lhs_f = x86_64::_mm256_xor_si256(lhs, flip);
+        let t_f = x86_64::_mm256_xor_si256(t, flip);
+        let overflow = x86_64::_mm256_cmpgt_epi32(lhs_f, t_f);
+        let geq_p = x86_64::_mm256_cmpgt_epi32(t_f, p_m1_f);
+        let mask = x86_64::_mm256_or_si256(overflow, geq_p);
+        x86_64::_mm256_blendv_epi8(t, u, mask)
    }
 }
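The encoding and decode rule are easy to sanity-check in scalar form. The sketch below uses hypothetical names (`encode`, `signed_add`), not the crate's API, and assumes the current prime together with the |D| < 2^32 - P guarantee discussed above.

const P: u32 = 0xFA00_0001;
const M1: u32 = P.wrapping_neg(); // 2^32 - P = 0x05FF_FFFF

/// Encode d with |d| < 2^32 - P as a wrapping u32 (two's complement gives 2^32 + d for d < 0).
fn encode(d: i64) -> u32 {
    assert!(d.unsigned_abs() < M1 as u64);
    d as u32
}

/// Scalar analogue of `signed_add_avx2`: canonicalize rhs, then add mod P.
fn signed_add(lhs: u32, rhs: u32) -> u32 {
    // Decode rule: under the |d| < 2^32 - P guarantee, negative iff rhs >= P.
    let rhs_canon = if rhs >= P { rhs.wrapping_add(P) } else { rhs };
    ((lhs as u64 + rhs_canon as u64) % P as u64) as u32
}

fn main() {
    let lhs = 123_456_789u32;
    assert_eq!(signed_add(lhs, encode(-5)), lhs - 5);
    assert_eq!(signed_add(lhs, encode(42)), lhs + 42);
}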
diff --git a/crates/backend/koala-bear/src/monty_31/x86_64_avx512/mod.rs b/crates/backend/koala-bear/src/monty_31/x86_64_avx512/mod.rs
index 993e232b9..0279432f7 100644
--- a/crates/backend/koala-bear/src/monty_31/x86_64_avx512/mod.rs
+++ b/crates/backend/koala-bear/src/monty_31/x86_64_avx512/mod.rs
@@ -1,9 +1,8 @@
 // Credits: Plonky3 (https://github.com/Plonky3/Plonky3) (MIT and Apache-2.0 licenses).

 mod packing;
-pub(crate) mod poseidon_helpers;
+mod poseidon_helpers;
 mod utils;

 pub use packing::*;
-pub(crate) use poseidon_helpers::*;
 pub use utils::*;
diff --git a/crates/backend/koala-bear/src/monty_31/x86_64_avx512/packing.rs b/crates/backend/koala-bear/src/monty_31/x86_64_avx512/packing.rs
index fff1e6ef1..bede3a659 100644
--- a/crates/backend/koala-bear/src/monty_31/x86_64_avx512/packing.rs
+++ b/crates/backend/koala-bear/src/monty_31/x86_64_avx512/packing.rs
@@ -7,7 +7,6 @@
 use alloc::vec::Vec;
 use core::arch::asm;
 use core::arch::x86_64::{self, __m256i, __m512i, __mmask8, __mmask16};
-use core::array;
 use core::hint::unreachable_unchecked;
 use core::iter::{Product, Sum};
 use core::mem::transmute;
@@ -199,24 +198,18 @@ impl PrimeCharacteristicRing for PackedMontyField31AVX512
    #[inline]
    fn cube(&self) -> Self {
-        let val = self.to_vector();
-        unsafe {
-            // Safety: `apply_func_to_even_odd` returns values in canonical form when given values in canonical form.
-            let res = apply_func_to_even_odd::(val, packed_exp_3::);
-            Self::from_vector(res)
-        }
+        // The optimized `packed_exp_3` path assumes inputs fit in [-P, P] as
+        // signed 32-bit integers (i.e. P < 2^31).
The current 32-bit prime + // (P = 0xfa000001 ≈ 2^32) violates that assumption, so fall back to + // unsigned Montgomery multiplication. + *self * self.square() } - #[inline] - fn xor(&self, rhs: &Self) -> Self { - let lhs = self.to_vector(); - let rhs = rhs.to_vector(); - let res = xor::(lhs, rhs); - unsafe { - // Safety: `xor` returns values in canonical form when given values in canonical form. - Self::from_vector(res) - } - } + // `xor` is left to the default trait implementation. The optimised + // packed routine assumes `2*lhs` and `2^31 - rhs` fit in a u32 without + // wrap-around, which only holds for `P < 2^31`. With the current 32-bit + // prime the wrap can push `lhs * rhs` past `2^32 * P` and break the + // Montgomery reduction. #[inline] fn andn(&self, rhs: &Self) -> Self { @@ -229,37 +222,10 @@ impl PrimeCharacteristicRing for PackedMontyField31AVX512(&self) -> Self { - // We provide specialised code for the powers 3, 5, 7 as these turn up regularly. - // The other powers could be specialised similarly but we ignore this for now. - // These ideas could also be used to speed up the more generic exp_u64. - match POWER { - 0 => Self::ONE, - 1 => *self, - 2 => self.square(), - 3 => self.cube(), - 4 => self.square().square(), - 5 => { - let val = self.to_vector(); - unsafe { - // Safety: `apply_func_to_even_odd` returns values in canonical form when given values in canonical form. - let res = apply_func_to_even_odd::(val, packed_exp_5::); - Self::from_vector(res) - } - } - 6 => self.square().cube(), - 7 => { - let val = self.to_vector(); - unsafe { - // Safety: `apply_func_to_even_odd` returns values in canonical form when given values in canonical form. - let res = apply_func_to_even_odd::(val, packed_exp_7::); - Self::from_vector(res) - } - } - _ => self.exp_u64(POWER), - } - } + // The `packed_exp_*` helpers assume P < 2^31. With the current 32-bit + // prime they would overflow, so we fall back to the default + // `exp_const_u64` implementation, which composes correct `square` / + // `cube` calls. #[inline(always)] fn dot_product(u: &[Self; N], v: &[Self; N]) -> Self { @@ -364,41 +330,11 @@ fn confuse_compiler_256(x: __m256i) -> __m256i { // [1] Modern Computer Arithmetic, Richard Brent and Paul Zimmermann, Cambridge University Press, // 2010, algorithm 2.7. -/// Perform a partial Montgomery reduction on each 64 bit element. -/// Input must lie in {0, ..., 2^32P}. -/// The output will lie in {-P, ..., P} and be stored in the upper 32 bits. -#[inline] -#[must_use] -fn partial_monty_red_unsigned_to_signed(input: __m512i) -> __m512i { - unsafe { - // We throw a confuse compiler here to prevent the compiler from - // using vpmullq instead of vpmuludq in the computations for q_p. - // vpmullq has both higher latency and lower throughput. - let q = confuse_compiler(x86_64::_mm512_mul_epu32(input, MPAVX512::PACKED_MU)); - let q_p = x86_64::_mm512_mul_epu32(q, MPAVX512::PACKED_P); - - // This could equivalently be _mm512_sub_epi64 - x86_64::_mm512_sub_epi32(input, q_p) - } -} - -/// Perform a partial Montgomery reduction on each 64 bit element. -/// Input must lie in {-2^{31}P, ..., 2^31P}. -/// The output will lie in {-P, ..., P} and be stored in the upper 32 bits. -#[inline] -#[must_use] -fn partial_monty_red_signed_to_signed(input: __m512i) -> __m512i { - unsafe { - // We throw a confuse compiler here to prevent the compiler from - // using vpmullq instead of vpmuludq in the computations for q_p. - // vpmullq has both higher latency and lower throughput. 
- let q = confuse_compiler(x86_64::_mm512_mul_epi32(input, MPAVX512::PACKED_MU)); - let q_p = x86_64::_mm512_mul_epi32(q, MPAVX512::PACKED_P); - - // This could equivalently be _mm512_sub_epi64 - x86_64::_mm512_sub_epi32(input, q_p) - } -} +// `partial_monty_red_unsigned_to_signed` / `partial_monty_red_signed_to_signed` +// (and their consumers `shifted_square`, `packed_exp_*`, `apply_func_to_even_odd`) +// produced an intermediate "signed in (-P, P)" value stored as i32, which is +// unambiguous only when `P < 2^31`. For the current 32-bit prime they have +// been removed; `cube` / `exp_const_u64` go through canonical Mul instead. /// Viewing the input as a vector of 16 `u32`s, copy the odd elements into the even elements below /// them. In other words, for all `0 <= i < 8`, set the even elements according to @@ -625,59 +561,6 @@ fn mul_256(lhs: __m256i, rhs: i32) -> __m256i { } /// Compute the elementary arithmetic generalization of `xor`, namely `xor(l, r) = l + r - 2lr` of -/// vectors in canonical form. -/// -/// Inputs are assumed to be in canonical form, if the inputs are not in canonical form, the result is undefined. -#[inline] -#[must_use] -fn xor(lhs: __m512i, rhs: __m512i) -> __m512i { - // Refactor the expression as r + 2l(1/2 - r). As MONTY_CONSTANT = 2^32, the internal - // representation 1/2 is 2^31 mod P so the product in the above expression is represented - // as 2l(2^31 - r). As 0 < 2l, 2^31 - r < 2^32 and 2l(2^31 - r) < 2^32P, we can compute - // the factors as 32 bit integers and then multiply and monty reduce as usual. - // - // We want this to compile to: - // vpaddd lhs_double, lhs, lhs - // vpsubd sub_rhs, (1 << 31), rhs - // vmovshdup lhs_odd, lhs_double - // vmovshdup rhs_odd, sub_rhs - // vpmuludq prod_evn, lhs_double, sub_rhs - // vpmuludq prod_hi, lhs_odd, rhs_odd - // vpmuludq q_evn, prod_evn, MU - // vpmuludq q_odd, prod_hi, MU - // vmovshdup prod_hi{EVENS}, prod_evn - // vpmuludq q_p_evn, q_evn, P - // vpmuludq q_p_hi, q_odd, P - // vmovshdup q_p_hi{EVENS}, q_p_evn - // vpcmpltud underflow, prod_hi, q_p_hi - // vpsubd res, prod_hi, q_p_hi - // vpaddd res{underflow}, res, P - // vpaddd sum, rhs, t - // vpsubd sum_corr, sum, pos_neg_P - // vpminud res, sum, sum_corr - // throughput: 9 cyc/vec (1.77 els/cyc) - // latency: 25 cyc - unsafe { - // 0 <= 2*lhs < 2P - let double_lhs = x86_64::_mm512_add_epi32(lhs, lhs); - - // Note that 2^31 is represented as an i32 as (-2^31). - // Compiler should realise this is a constant. - let half = x86_64::_mm512_set1_epi32(-1 << 31); - - // 0 < 2^31 - rhs < 2^31 - let half_sub_rhs = x86_64::_mm512_sub_epi32(half, rhs); - - // 2*lhs (2^31 - rhs) < 2P 2^31 < 2^32P so we can use the multiplication function. - let mul_res = mul::(double_lhs, half_sub_rhs); - - // Unfortunately, AVX512 has no equivalent of vpsignd so we can't do the same - // signed_add trick as in the AVX2 case. Instead we get a reduced value from mul - // and add on rhs in the standard way. - mm512_mod_add(rhs, mul_res, MPAVX512::PACKED_P) - } -} - /// Compute the elementary arithmetic generalization of `andnot`, namely `andn(l, r) = (1 - l)r` of /// vectors in canonical form. /// @@ -707,128 +590,25 @@ fn andn(lhs: __m512i, rhs: __m512i) -> __m512i // throughput: 7 cyc/vec (2.3 els/cyc) // latency: 22 cyc unsafe { - // We use 2^32 - P instead of 2^32 to avoid having to worry about 0's in lhs. - - // Compiler should realise that this is a constant. 
- let neg_p = x86_64::_mm512_sub_epi32(x86_64::_mm512_setzero_epi32(), MPAVX512::PACKED_P); - let neg_lhs = x86_64::_mm512_sub_epi32(neg_p, lhs); - - // 2*lhs (2^31 - rhs) < 2P 2^31 < 2^32P so we can use the multiplication function. + // M(1) = 2^32 mod P = 2^32 - P (since P < 2^32). We need + // `neg_lhs = M(1) - lhs (mod P)`, i.e. a canonical value in [0, P). + // + // For P < 2^31 the simple `(2^32 - P) - lhs` computed mod 2^32 is + // already in [0, M(1)] ⊂ [0, P) because lhs < P < M(1). For P > 2^31, + // M(1) = 2^32 - P < P, so `lhs` may exceed M(1) and the subtraction + // wraps mod 2^32, producing a value in [M(1)+1, 2^32) that is no + // longer congruent to `M(1) - lhs` mod P (it differs by `2^32 mod P`). + // + // Use `mm512_mod_sub` to do the subtraction correctly mod P for any + // `P < 2^32`. + let one_monty = x86_64::_mm512_sub_epi32(x86_64::_mm512_setzero_epi32(), MPAVX512::PACKED_P); + let neg_lhs = mm512_mod_sub(one_monty, lhs, MPAVX512::PACKED_P); + + // `neg_lhs` is canonical in [0, P), so `neg_lhs * rhs < 2^32 * P`. mul::(neg_lhs, rhs) } } -/// Square the MontyField31 elements in the even index entries. -/// Inputs must be signed 32-bit integers in [-P, ..., P]. -/// Outputs will be a signed integer in (-P, ..., P) copied into both the even and odd indices. -#[inline] -#[must_use] -fn shifted_square(input: __m512i) -> __m512i { - // Note that we do not need a restriction on the size of input[i]^2 as - // 2^30 < P and |i32| <= 2^31 and so => input[i]^2 <= 2^62 < 2^32P. - unsafe { - let square = x86_64::_mm512_mul_epi32(input, input); - let square_red = partial_monty_red_unsigned_to_signed::(square); - movehdup_epi32(square_red) - } -} - -/// Cube the MontyField31 elements in the even index entries. -/// Inputs must be signed 32-bit integers in [-P, ..., P]. -/// Outputs will be signed integers in (-P^2, ..., P^2). -#[inline] -#[must_use] -pub(crate) fn packed_exp_3(input: __m512i) -> __m512i { - unsafe { - let square = shifted_square::(input); - x86_64::_mm512_mul_epi32(square, input) - } -} - -/// Take the fifth power of the MontyField31 elements in the even index entries. -/// Inputs must be signed 32-bit integers in [-P, ..., P]. -/// Outputs will be signed integers in (-P^2, ..., P^2). -#[inline] -#[must_use] -pub(crate) fn packed_exp_5(input: __m512i) -> __m512i { - unsafe { - let square = shifted_square::(input); - let quad = shifted_square::(square); - x86_64::_mm512_mul_epi32(quad, input) - } -} - -/// Take the seventh power of the MontyField31 elements in the even index entries. -/// Inputs must lie in [-P, ..., P]. -/// Outputs will be signed integers in (-P^2, ..., P^2). -#[inline] -#[must_use] -pub(crate) fn packed_exp_7(input: __m512i) -> __m512i { - unsafe { - let square = shifted_square::(input); - let cube_raw = x86_64::_mm512_mul_epi32(square, input); - let cube_red = partial_monty_red_signed_to_signed::(cube_raw); - let cube = movehdup_epi32(cube_red); - let quad = shifted_square::(square); - x86_64::_mm512_mul_epi32(quad, cube) - } -} - -/// Apply func to the even and odd indices of the input vector. -/// -/// func should only depend in the 32 bit entries in the even indices. -/// The input should conform to the requirements of `func`. -/// The output of func must lie in (-P^2, ..., P^2) after which -/// apply_func_to_even_odd will reduce the outputs to lie in [0, P) -/// and recombine the odd and even parts. 
-#[inline] -#[must_use] -pub(crate) unsafe fn apply_func_to_even_odd( - input: __m512i, - func: fn(__m512i) -> __m512i, -) -> __m512i { - unsafe { - let input_evn = input; - let input_odd = movehdup_epi32(input); - - let output_even = func(input_evn); - let output_odd = func(input_odd); - - // We need to recombine these even and odd parts and, at the same time reduce back to - // an output in [0, P). - - // We throw a confuse compiler here to prevent the compiler from - // using vpmullq instead of vpmuludq in the computations for q_p. - // vpmullq has both higher latency and lower throughput. - let q_evn = confuse_compiler(x86_64::_mm512_mul_epi32(output_even, MPAVX512::PACKED_MU)); - let q_odd = confuse_compiler(x86_64::_mm512_mul_epi32(output_odd, MPAVX512::PACKED_MU)); - - // Get all the high halves as one vector: this is `(lhs * rhs) >> 32`. - // NB: `vpermt2d` may feel like a more intuitive choice here, but it has much higher - // latency. - let output_hi = mask_movehdup_epi32(output_odd, EVENS, output_even); - - // Normally we'd want to mask to perform % 2**32, but the instruction below only reads the - // low 32 bits anyway. - let q_p_evn = x86_64::_mm512_mul_epi32(q_evn, MPAVX512::PACKED_P); - let q_p_odd = x86_64::_mm512_mul_epi32(q_odd, MPAVX512::PACKED_P); - - // We can ignore all the low halves of `q_p` as they cancel out. Get all the high halves as - // one vector. - let q_p_hi = mask_movehdup_epi32(q_p_odd, EVENS, q_p_evn); - - // Subtraction `output_hi - q_p_hi` modulo `P`. - // NB: Normally we'd `vpaddd P` and take the `vpminud`, but `vpminud` runs on port 0, which - // is already under a lot of pressure performing multiplications. To relieve this pressure, - // we check for underflow to generate a mask, and then conditionally add `P`. The underflow - // check runs on port 5, increasing our throughput, although it does cost us an additional - // cycle of latency. - let underflow = x86_64::_mm512_cmplt_epi32_mask(output_hi, q_p_hi); - let t = x86_64::_mm512_sub_epi32(output_hi, q_p_hi); - x86_64::_mm512_mask_add_epi32(t, underflow, t, MPAVX512::PACKED_P) - } -} - /// Negate a vector of MontyField31 elements in canonical form. /// If the inputs are not in canonical form, the result is undefined. #[inline] @@ -982,145 +762,19 @@ pub fn dot_product_2, RHS: IntoM5 } } -/// Compute the elementary function `l0*r0 + l1*r1 + l2*r2 + l3*r3` given eight inputs -/// in canonical form. -/// -/// If the inputs are not in canonical form, the result is undefined. -#[inline] -#[must_use] -fn dot_product_4, RHS: IntoM512>( - lhs: [LHS; 4], - rhs: [RHS; 4], -) -> __m512i { - // The following analysis treats all input arrays as being arrays of PackedMontyField31AVX512. - // If one of the arrays contains MontyField31, we get to avoid the initial vmovshdup. - // - // Similarly to dot_product_2, we improve throughput by combining monty reductions however in this case - // we will need to slightly adjust the reduction algorithm. - // - // As all inputs are `< P < 2^{31}`, the sum satisfies: `C = l0*r0 + l1*r1 + l2*r2 + l3*r3 < 4P^2 < 2*2^{32}P`. - // Start by computing Q := μ C mod B as usual. - // We can't proceed as normal however as 2*2^{32}P > C - QP > -2^{32}P which doesn't fit into an i64. - // Instead we do a reduction on C, defining C' = if C < 2^{32}P: {C} else {C - 2^{32}P} - // From here we proceed with the standard montgomery reduction with C replaced by C'. It works identically - // with the Q we already computed as C = C' mod B. 
- // - // We want this to compile to: - // vmovshdup lhs_odd0, lhs0 - // vmovshdup rhs_odd0, rhs0 - // vmovshdup lhs_odd1, lhs1 - // vmovshdup rhs_odd1, rhs1 - // vmovshdup lhs_odd2, lhs2 - // vmovshdup rhs_odd2, rhs2 - // vmovshdup lhs_odd3, lhs3 - // vmovshdup rhs_odd3, rhs3 - // vpmuludq prod_evn0, lhs0, rhs0 - // vpmuludq prod_odd0, lhs_odd0, rhs_odd0 - // vpmuludq prod_evn1, lhs1, rhs1 - // vpmuludq prod_odd1, lhs_odd1, rhs_odd1 - // vpmuludq prod_evn2, lhs2, rhs2 - // vpmuludq prod_odd2, lhs_odd2, rhs_odd2 - // vpmuludq prod_evn3, lhs3, rhs3 - // vpmuludq prod_odd3, lhs_odd3, rhs_odd3 - // vpaddq dot_evn01, prod_evn0, prod_evn1 - // vpaddq dot_odd01, prod_odd0, prod_odd1 - // vpaddq dot_evn23, prod_evn2, prod_evn3 - // vpaddq dot_odd23, prod_odd2, prod_odd3 - // vpaddq dot_evn, dot_evn01, dot_evn23 - // vpaddq dot, dot_odd01, dot_odd23 - // vpmuludq q_evn, dot_evn, MU - // vpmuludq q_odd, dot, MU - // vpmuludq q_P_evn, q_evn, P - // vpmuludq q_P, q_odd, P - // vmovshdup dot{EVENS} dot_evn - // vpcmpleud over_P, P, dot - // vpsubd dot{underflow}, dot, P - // vmovshdup q_P{EVENS} q_P_evn - // vpcmpltud underflow, dot, q_P - // vpsubd res, dot, q_P - // vpaddd res{underflow}, res, P - // throughput: 16.5 cyc/vec (0.97 els/cyc) - // latency: 23 cyc - unsafe { - let lhs_evn0 = lhs[0].as_m512i(); - let lhs_odd0 = lhs[0].as_shifted_m512i(); - let lhs_evn1 = lhs[1].as_m512i(); - let lhs_odd1 = lhs[1].as_shifted_m512i(); - let lhs_evn2 = lhs[2].as_m512i(); - let lhs_odd2 = lhs[2].as_shifted_m512i(); - let lhs_evn3 = lhs[3].as_m512i(); - let lhs_odd3 = lhs[3].as_shifted_m512i(); - - let rhs_evn0 = rhs[0].as_m512i(); - let rhs_odd0 = rhs[0].as_shifted_m512i(); - let rhs_evn1 = rhs[1].as_m512i(); - let rhs_odd1 = rhs[1].as_shifted_m512i(); - let rhs_evn2 = rhs[2].as_m512i(); - let rhs_odd2 = rhs[2].as_shifted_m512i(); - let rhs_evn3 = rhs[3].as_m512i(); - let rhs_odd3 = rhs[3].as_shifted_m512i(); - - let mul_evn0 = x86_64::_mm512_mul_epu32(lhs_evn0, rhs_evn0); - let mul_evn1 = x86_64::_mm512_mul_epu32(lhs_evn1, rhs_evn1); - let mul_evn2 = x86_64::_mm512_mul_epu32(lhs_evn2, rhs_evn2); - let mul_evn3 = x86_64::_mm512_mul_epu32(lhs_evn3, rhs_evn3); - let mul_odd0 = x86_64::_mm512_mul_epu32(lhs_odd0, rhs_odd0); - let mul_odd1 = x86_64::_mm512_mul_epu32(lhs_odd1, rhs_odd1); - let mul_odd2 = x86_64::_mm512_mul_epu32(lhs_odd2, rhs_odd2); - let mul_odd3 = x86_64::_mm512_mul_epu32(lhs_odd3, rhs_odd3); - - let dot_evn01 = x86_64::_mm512_add_epi64(mul_evn0, mul_evn1); - let dot_odd01 = x86_64::_mm512_add_epi64(mul_odd0, mul_odd1); - let dot_evn23 = x86_64::_mm512_add_epi64(mul_evn2, mul_evn3); - let dot_odd23 = x86_64::_mm512_add_epi64(mul_odd2, mul_odd3); - - let dot_evn = x86_64::_mm512_add_epi64(dot_evn01, dot_evn23); - let dot_odd = x86_64::_mm512_add_epi64(dot_odd01, dot_odd23); - - // We throw a confuse compiler here to prevent the compiler from - // using vpmullq instead of vpmuludq in the computations for q_p. - // vpmullq has both higher latency and lower throughput. - let q_evn = confuse_compiler(x86_64::_mm512_mul_epu32(dot_evn, PMP::PACKED_MU)); - let q_odd = confuse_compiler(x86_64::_mm512_mul_epu32(dot_odd, PMP::PACKED_MU)); - - // Get all the high halves as one vector: this is `dot(lhs, rhs) >> 32`. - // NB: `vpermt2d` may feel like a more intuitive choice here, but it has much higher - // latency. 
- let dot = mask_movehdup_epi32(dot_odd, EVENS, dot_evn); - - // The elements in dot lie in [0, 2P) so we need to reduce them to [0, P) - // NB: Normally we'd `vpsubq P` and take the `vpminud`, but `vpminud` runs on port 0, which - // is already under a lot of pressure performing multiplications. To relieve this pressure, - // we check for underflow to generate a mask, and then conditionally add `P`. - let over_p = x86_64::_mm512_cmple_epu32_mask(PMP::PACKED_P, dot); - let dot_corr = x86_64::_mm512_mask_sub_epi32(dot, over_p, dot, PMP::PACKED_P); - - // Normally we'd want to mask to perform % 2**32, but the instruction below only reads the - // low 32 bits anyway. - let q_p_evn = x86_64::_mm512_mul_epu32(q_evn, PMP::PACKED_P); - let q_p_odd = x86_64::_mm512_mul_epu32(q_odd, PMP::PACKED_P); - - // We can ignore all the low halves of `q_p` as they cancel out. Get all the high halves as - // one vector. - let q_p = mask_movehdup_epi32(q_p_odd, EVENS, q_p_evn); - - // Subtraction `prod_hi - q_p_hi` modulo `P`. - // NB: Normally we'd `vpaddd P` and take the `vpminud`, but `vpminud` runs on port 0, which - // is already under a lot of pressure performing multiplications. To relieve this pressure, - // we check for underflow to generate a mask, and then conditionally add `P`. The underflow - // check runs on port 5, increasing our throughput, although it does cost us an additional - // cycle of latency. - let underflow = x86_64::_mm512_cmplt_epu32_mask(dot_corr, q_p); - let t = x86_64::_mm512_sub_epi32(dot_corr, q_p); - x86_64::_mm512_mask_add_epi32(t, underflow, t, PMP::PACKED_P) - } -} +// `dot_product_4` was a specialised batched-Montgomery dot product that +// accumulated four 64-bit products into a single u64 before reduction. That +// bound only holds for `P < 2^31`; for the current 32-bit prime even two +// products overflow u64. `general_dot_product` falls through to scalar +// accumulation via the canonical Mul instead. /// A general fast dot product implementation. /// -/// Maximises the number of calls to `dot_product_4` for dot products involving vectors of length -/// more than 4. The length 64 occurs commonly enough it's useful to have a custom implementation -/// which lets it use a slightly better summation algorithm with lower latency. +/// `dot_product_2` / `dot_product_4` would normally amortise one Montgomery +/// reduction over multiple products, but they accumulate two/four 64-bit +/// products into a single u64 — sound only for `P < 2^31`. For the current +/// 32-bit prime (`0xfa000001`) even two products can overflow u64, so we +/// reduce each product individually and accumulate the canonical results. #[inline(always)] fn general_dot_product, RHS: IntoM512, const N: usize>( lhs: &[LHS], @@ -1128,72 +782,14 @@ fn general_dot_product, RHS: IntoM512 ) -> PackedMontyField31AVX512 { assert_eq!(lhs.len(), N); assert_eq!(rhs.len(), N); - match N { - 0 => PackedMontyField31AVX512::::ZERO, - 1 => (lhs[0]).into() * (rhs[0]).into(), - 2 => { - let res = dot_product_2([lhs[0], lhs[1]], [rhs[0], rhs[1]]); - unsafe { - // Safety: `dot_product_2` returns values in canonical form when given values in canonical form. - PackedMontyField31AVX512::::from_vector(res) - } - } - 3 => { - let lhs2 = lhs[2]; - let rhs2 = rhs[2]; - let res = dot_product_2([lhs[0], lhs[1]], [rhs[0], rhs[1]]); - unsafe { - // Safety: `dot_product_2` returns values in canonical form when given values in canonical form. 
- PackedMontyField31AVX512::::from_vector(res) + (lhs2.into() * rhs2.into()) - } - } - 4 => { - let res = dot_product_4([lhs[0], lhs[1], lhs[2], lhs[3]], [rhs[0], rhs[1], rhs[2], rhs[3]]); - unsafe { - // Safety: `dot_product_4` returns values in canonical form when given values in canonical form. - PackedMontyField31AVX512::::from_vector(res) - } - } - 64 => { - let sum_4s: [PackedMontyField31AVX512; 16] = array::from_fn(|i| { - let res = dot_product_4( - [lhs[4 * i], lhs[4 * i + 1], lhs[4 * i + 2], lhs[4 * i + 3]], - [rhs[4 * i], rhs[4 * i + 1], rhs[4 * i + 2], rhs[4 * i + 3]], - ); - unsafe { - // Safety: `dot_product_4` returns values in canonical form when given values in canonical form. - PackedMontyField31AVX512::::from_vector(res) - } - }); - PackedMontyField31AVX512::::sum_array::<16>(&sum_4s) - } - _ => { - let mut acc = { - let res = dot_product_4([lhs[0], lhs[1], lhs[2], lhs[3]], [rhs[0], rhs[1], rhs[2], rhs[3]]); - unsafe { - // Safety: `dot_product_4` returns values in canonical form when given values in canonical form. - PackedMontyField31AVX512::::from_vector(res) - } - }; - for i in (4..(N - 3)).step_by(4) { - let res = dot_product_4( - [lhs[i], lhs[i + 1], lhs[i + 2], lhs[i + 3]], - [rhs[i], rhs[i + 1], rhs[i + 2], rhs[i + 3]], - ); - unsafe { - // Safety: `dot_product_4` returns values in canonical form when given values in canonical form. - acc += PackedMontyField31AVX512::::from_vector(res) - } - } - match N & 3 { - 0 => acc, - 1 => acc + general_dot_product::<_, _, _, 1>(&lhs[(4 * (N / 4))..], &rhs[(4 * (N / 4))..]), - 2 => acc + general_dot_product::<_, _, _, 2>(&lhs[(4 * (N / 4))..], &rhs[(4 * (N / 4))..]), - 3 => acc + general_dot_product::<_, _, _, 3>(&lhs[(4 * (N / 4))..], &rhs[(4 * (N / 4))..]), - _ => unreachable!(), - } - } + if N == 0 { + return PackedMontyField31AVX512::::ZERO; + } + let mut acc: PackedMontyField31AVX512 = lhs[0].into() * rhs[0].into(); + for i in 1..N { + acc += lhs[i].into() * rhs[i].into(); } + acc } impl_packed_value!( diff --git a/crates/backend/koala-bear/src/monty_31/x86_64_avx512/poseidon_helpers.rs b/crates/backend/koala-bear/src/monty_31/x86_64_avx512/poseidon_helpers.rs index c7d396487..f73b9cddb 100644 --- a/crates/backend/koala-bear/src/monty_31/x86_64_avx512/poseidon_helpers.rs +++ b/crates/backend/koala-bear/src/monty_31/x86_64_avx512/poseidon_helpers.rs @@ -1,83 +1,9 @@ // Credits: Plonky3 (https://github.com/Plonky3/Plonky3) (MIT and Apache-2.0 licenses). //! AVX512 helpers shared by Poseidon1 permutations. - -use core::arch::x86_64::{self, __m512i}; -use core::mem::transmute; - -use super::{apply_func_to_even_odd, packed_exp_3, packed_exp_5, packed_exp_7}; -use crate::{MontyParameters, PackedMontyField31AVX512, PackedMontyParameters}; - -/// A specialized representation of the Poseidon state for a width of 16. -/// -/// Splits the state into `s0` (undergoes S-box) and `s_hi` (undergoes only linear transforms), -/// enabling instruction-level parallelism between the two independent data paths. 
-#[derive(Clone, Copy)] -#[repr(C)] -pub struct InternalLayer16 { - pub(crate) s0: PackedMontyField31AVX512, - pub(crate) s_hi: [__m512i; 15], -} - -impl InternalLayer16 { - #[inline] - pub(crate) unsafe fn to_packed_field_array(self) -> [PackedMontyField31AVX512; 16] { - unsafe { transmute(self) } - } - - #[inline] - #[must_use] - pub(crate) fn from_packed_field_array(vector: [PackedMontyField31AVX512; 16]) -> Self { - unsafe { transmute(vector) } - } -} - -/// Use hard coded methods to compute `x -> x^D` for the even index entries and small `D`. -/// Inputs should be signed 32-bit integers in `[-P, ..., P]`. -/// Outputs will also be signed integers in `(-P, ..., P)` stored in the odd indices. -#[inline(always)] -#[must_use] -pub(crate) fn exp_small(val: __m512i) -> __m512i { - match D { - 3 => packed_exp_3::(val), - 5 => packed_exp_5::(val), - 7 => packed_exp_7::(val), - _ => panic!("No exp function for given D"), - } -} - -/// Converts a scalar constant into a packed AVX512 vector in "negative form" (`c - P`). -#[inline(always)] -pub(crate) fn convert_to_vec_neg_form(input: i32) -> __m512i { - let input_sub_p = input - (MP::PRIME as i32); - unsafe { x86_64::_mm512_set1_epi32(input_sub_p) } -} - -/// Performs the fused AddRoundConstant and S-Box operation `x -> (x + c)^D`. -/// -/// `val` must contain elements in canonical form `[0, P)`. -/// `rc` must contain round constants in negative form `[-P, 0)`. -#[inline(always)] -pub(crate) fn add_rc_and_sbox( - val: &mut PackedMontyField31AVX512, - rc: __m512i, -) { - unsafe { - let vec_val = val.to_vector(); - let val_plus_rc = x86_64::_mm512_add_epi32(vec_val, rc); - let output = apply_func_to_even_odd::(val_plus_rc, exp_small::); - *val = PackedMontyField31AVX512::::from_vector(output); - } -} - -/// Applies the S-Box `x -> x^D` to a packed vector. Output is in canonical form. -#[inline(always)] -pub(crate) fn sbox( - val: PackedMontyField31AVX512, -) -> PackedMontyField31AVX512 { - unsafe { - let vec = val.to_vector(); - let out = apply_func_to_even_odd::(vec, exp_small::); - PackedMontyField31AVX512::::from_vector(out) - } -} +//! +//! The optimised batched-S-box helpers (`exp_small`, `add_rc_and_sbox`, +//! `sbox`, `InternalLayer16`) relied on a "signed in (-P, P)" intermediate +//! stored as i32, which is unambiguous only for `P < 2^31`. They have +//! been removed for the current 32-bit prime; AVX512 Poseidon falls +//! through the generic `permute_generic` path which uses canonical Mul. diff --git a/crates/backend/koala-bear/src/monty_31/x86_64_avx512/utils.rs b/crates/backend/koala-bear/src/monty_31/x86_64_avx512/utils.rs index 7557af334..9078a175f 100644 --- a/crates/backend/koala-bear/src/monty_31/x86_64_avx512/utils.rs +++ b/crates/backend/koala-bear/src/monty_31/x86_64_avx512/utils.rs @@ -3,7 +3,7 @@ use core::arch::x86_64::{self, __m512i}; use core::mem::transmute; -use crate::{MontyParameters, PackedMontyParameters, TwoAdicData}; +use crate::{FieldParameters, PackedMontyParameters, TwoAdicData}; // Godbolt file showing that these all compile to the expected instructions. (Potentially plus a few memory ops): // https://godbolt.org/z/dvW7r1zjj @@ -12,27 +12,19 @@ use crate::{MontyParameters, PackedMontyParameters, TwoAdicData}; /// /// If the inputs are not in canonical form, the result is undefined. 
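+/// For the 32-bit prime (P = 0xfa000001), `HALF_P_PLUS_1 = 0x7d000001` is still
+/// below `2^31`, so the `as i32` cast below keeps the same value.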
#[inline(always)] -pub(crate) fn halve_avx512(input: __m512i) -> __m512i { +pub(crate) fn halve_avx512(input: __m512i) -> __m512i { /* - We want this to compile to: - vptestmd least_bit, val, ONE - vpsrld res, val, 1 - vpaddd res{least_bit}, res, maybe_half - throughput: 2 cyc/vec - latency: 4 cyc - Given an element val in [0, P), we want to compute val/2 mod P. If val is even: val/2 mod P = val/2 = val >> 1. If val is odd: val/2 mod P = (val + P)/2 = (val >> 1) + (P + 1)/2 */ unsafe { - // Safety: If this code got compiled then AVX512 intrinsics are available. const ONE: __m512i = unsafe { transmute([1u32; 16]) }; - let half = x86_64::_mm512_set1_epi32((MP::PRIME as i32 + 1) / 2); // Compiler realises this is constant. + // HALF_P_PLUS_1 = (P + 1) / 2, computed correctly at u32 level. + let half = x86_64::_mm512_set1_epi32(FP::HALF_P_PLUS_1 as i32); - let least_bit = x86_64::_mm512_test_epi32_mask(input, ONE); // Determine the parity of val. + let least_bit = x86_64::_mm512_test_epi32_mask(input, ONE); let t = x86_64::_mm512_srli_epi32::<1>(input); - // This does nothing when least_bit = 1 and sets the corresponding entry to 0 when least_bit = 0 x86_64::_mm512_mask_add_epi32(t, least_bit, t, half) } } diff --git a/crates/backend/koala-bear/src/poseidon1_koalabear_16.rs b/crates/backend/koala-bear/src/poseidon1_koalabear_16.rs index 41be54ff5..f1eac4d9e 100644 --- a/crates/backend/koala-bear/src/poseidon1_koalabear_16.rs +++ b/crates/backend/koala-bear/src/poseidon1_koalabear_16.rs @@ -25,50 +25,26 @@ const MDS_CIRC_COL: [KoalaBear; 16] = KoalaBear::new_array([1, 3, 13, 22, 67, 2, // Forward twiddles for 16-point FFT: W_k = omega^k // ========================================================================= -#[cfg(any( - all(target_arch = "aarch64", target_feature = "neon"), - all(target_arch = "x86_64", target_feature = "avx2") -))] -const W1: KoalaBear = KoalaBear::new(0x08dbd69c); -#[cfg(any( - all(target_arch = "aarch64", target_feature = "neon"), - all(target_arch = "x86_64", target_feature = "avx2") -))] -const W2: KoalaBear = KoalaBear::new(0x6832fe4a); -#[cfg(any( - all(target_arch = "aarch64", target_feature = "neon"), - all(target_arch = "x86_64", target_feature = "avx2") -))] -const W3: KoalaBear = KoalaBear::new(0x27ae21e2); -#[cfg(any( - all(target_arch = "aarch64", target_feature = "neon"), - all(target_arch = "x86_64", target_feature = "avx2") -))] -const W4: KoalaBear = KoalaBear::new(0x7e010002); -#[cfg(any( - all(target_arch = "aarch64", target_feature = "neon"), - all(target_arch = "x86_64", target_feature = "avx2") -))] -const W5: KoalaBear = KoalaBear::new(0x3a89a025); -#[cfg(any( - all(target_arch = "aarch64", target_feature = "neon"), - all(target_arch = "x86_64", target_feature = "avx2") -))] -const W6: KoalaBear = KoalaBear::new(0x174e3650); -#[cfg(any( - all(target_arch = "aarch64", target_feature = "neon"), - all(target_arch = "x86_64", target_feature = "avx2") -))] -const W7: KoalaBear = KoalaBear::new(0x27dfce22); +#[cfg(all(target_arch = "aarch64", target_feature = "neon"))] +const W1: KoalaBear = KoalaBear::new(0x6b52061e); +#[cfg(all(target_arch = "aarch64", target_feature = "neon"))] +const W2: KoalaBear = KoalaBear::new(0x894b5390); +#[cfg(all(target_arch = "aarch64", target_feature = "neon"))] +const W3: KoalaBear = KoalaBear::new(0x39f910ef); +#[cfg(all(target_arch = "aarch64", target_feature = "neon"))] +const W4: KoalaBear = KoalaBear::new(0x304096c9); +#[cfg(all(target_arch = "aarch64", target_feature = "neon"))] +const W5: KoalaBear = 
KoalaBear::new(0x33c5a441); +#[cfg(all(target_arch = "aarch64", target_feature = "neon"))] +const W6: KoalaBear = KoalaBear::new(0x2e9b3a27); +#[cfg(all(target_arch = "aarch64", target_feature = "neon"))] +const W7: KoalaBear = KoalaBear::new(0x9d09df4b); // ========================================================================= // 16-point FFT / IFFT (radix-2, fully unrolled, in-place) // ========================================================================= -#[cfg(any( - all(target_arch = "aarch64", target_feature = "neon"), - all(target_arch = "x86_64", target_feature = "avx2") -))] +#[cfg(all(target_arch = "aarch64", target_feature = "neon"))] #[inline(always)] fn bt>(v: &mut [R; 16], lo: usize, hi: usize) { let (a, b) = (v[lo], v[hi]); @@ -76,10 +52,7 @@ fn bt>(v: &mut [R; 16], lo: usize, hi: usize) { v[hi] = a - b; } -#[cfg(any( - all(target_arch = "aarch64", target_feature = "neon"), - all(target_arch = "x86_64", target_feature = "avx2") -))] +#[cfg(all(target_arch = "aarch64", target_feature = "neon"))] #[inline(always)] fn dit>(v: &mut [R; 16], lo: usize, hi: usize, t: KoalaBear) { let a = v[lo]; @@ -88,10 +61,7 @@ fn dit>(v: &mut [R; 16], lo: usize, hi: usize, t: KoalaBea v[hi] = a - tb; } -#[cfg(any( - all(target_arch = "aarch64", target_feature = "neon"), - all(target_arch = "x86_64", target_feature = "avx2") -))] +#[cfg(all(target_arch = "aarch64", target_feature = "neon"))] #[inline(always)] fn neg_dif>(v: &mut [R; 16], lo: usize, hi: usize, t: KoalaBear) { let (a, b) = (v[lo], v[hi]); @@ -99,10 +69,7 @@ fn neg_dif>(v: &mut [R; 16], lo: usize, hi: usize, t: Koal v[hi] = (b - a) * t; } -#[cfg(any( - all(target_arch = "aarch64", target_feature = "neon"), - all(target_arch = "x86_64", target_feature = "avx2") -))] +#[cfg(all(target_arch = "aarch64", target_feature = "neon"))] #[inline(always)] fn dif_ifft_16_mut>(f: &mut [R; 16]) { bt(f, 0, 8); @@ -139,10 +106,7 @@ fn dif_ifft_16_mut>(f: &mut [R; 16]) { bt(f, 14, 15); } -#[cfg(any( - all(target_arch = "aarch64", target_feature = "neon"), - all(target_arch = "x86_64", target_feature = "avx2") -))] +#[cfg(all(target_arch = "aarch64", target_feature = "neon"))] #[inline(always)] fn dit_fft_16_mut>(f: &mut [R; 16]) { bt(f, 0, 1); @@ -327,7 +291,7 @@ pub fn mds_circ_16>(stat } // ========================================================================= -// Sparse matrix decomposition helpers (for SIMD partial rounds) +// Sparse matrix decomposition helpers (for NEON partial rounds) // ========================================================================= /// Dense NxN matrix multiplication: C = A * B. @@ -543,50 +507,21 @@ struct Precomputed { /// Length = POSEIDON1_PARTIAL_ROUNDS - 1. sparse_round_constants: Vec, - // --- SIMD pre-packed constants (NEON / AVX2 / AVX512) --- - #[cfg(any( - all(target_arch = "aarch64", target_feature = "neon"), - all(target_arch = "x86_64", target_feature = "avx2") - ))] - simd: SimdPrecomputed, + // --- NEON pre-packed constants --- + #[cfg(all(target_arch = "aarch64", target_feature = "neon"))] + neon: NeonPrecomputed, } -/// Arch-specific raw vector type used for negative-form round constants. 
#[cfg(all(target_arch = "aarch64", target_feature = "neon"))] -type Rc = core::arch::aarch64::int32x4_t; -#[cfg(all(target_arch = "x86_64", target_feature = "avx2", not(target_feature = "avx512f")))] -type Rc = core::arch::x86_64::__m256i; -#[cfg(all(target_arch = "x86_64", target_feature = "avx512f"))] -type Rc = core::arch::x86_64::__m512i; - -/// Common field parameters type used across all architectures. -#[cfg(any( - all(target_arch = "aarch64", target_feature = "neon"), - all(target_arch = "x86_64", target_feature = "avx2") -))] -type FP = crate::KoalaBearParameters; - -/// Arch-specific packed KoalaBear type. -#[cfg(all(target_arch = "aarch64", target_feature = "neon"))] -type PackedKB = crate::PackedKoalaBearNeon; -#[cfg(all(target_arch = "x86_64", target_feature = "avx2", not(target_feature = "avx512f")))] -type PackedKB = crate::PackedKoalaBearAVX2; -#[cfg(all(target_arch = "x86_64", target_feature = "avx512f"))] -type PackedKB = crate::PackedKoalaBearAVX512; - -#[cfg(any( - all(target_arch = "aarch64", target_feature = "neon"), - all(target_arch = "x86_64", target_feature = "avx2") -))] -struct SimdPrecomputed { - /// Initial full round constants in negative form (only first 3 rounds; +struct NeonPrecomputed { + /// Initial full round constants in negative NEON form (only first 3 rounds; /// the 4th is fused with the partial round entry). - packed_initial_rc: [[Rc; 16]; POSEIDON1_HALF_FULL_ROUNDS - 1], - /// Terminal full round constants in negative form. - packed_terminal_rc: [[Rc; 16]; POSEIDON1_HALF_FULL_ROUNDS], - /// Pre-packed sparse first rows. + packed_initial_rc: [[core::arch::aarch64::uint32x4_t; 16]; POSEIDON1_HALF_FULL_ROUNDS - 1], + /// Terminal full round constants in canonical form. + packed_terminal_rc: [[core::arch::aarch64::uint32x4_t; 16]; POSEIDON1_HALF_FULL_ROUNDS], + /// Pre-packed sparse first rows as PackedKoalaBearNeon. packed_sparse_first_row: [[PackedKB; 16]; POSEIDON1_PARTIAL_ROUNDS], - /// Pre-packed v vectors. + /// Pre-packed v vectors as PackedKoalaBearNeon. packed_sparse_v: [[PackedKB; 16]; POSEIDON1_PARTIAL_ROUNDS], /// Pre-packed scalar round constants for partial rounds 0..RP-2. packed_round_constants: [PackedKB; POSEIDON1_PARTIAL_ROUNDS - 1], @@ -595,22 +530,24 @@ struct SimdPrecomputed { packed_fused_mi_mds: [[PackedKB; 16]; 16], /// Fused bias: m_i * first_round_constants. packed_fused_bias: [PackedKB; 16], - /// Last initial round constant in negative form (for fused add_rc_and_sbox). - packed_last_initial_rc: [Rc; 16], + /// Last initial round constant in canonical form (for add_rc_and_sbox). + packed_last_initial_rc: [core::arch::aarch64::uint32x4_t; 16], /// Pre-packed eigenvalues * INV16 for FFT MDS (absorbs /16 normalization). 
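+    /// (Equal to `dif_ifft_16_mut(MDS_CIRC_COL)` scaled by `16^{-1}`, i.e. stored in the
+    /// order produced by the DIF step of `mds_fft_neon`.)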
packed_lambda_over_16: [PackedKB; 16], } -#[cfg(any( - all(target_arch = "aarch64", target_feature = "neon"), - all(target_arch = "x86_64", target_feature = "avx2") -))] -impl std::fmt::Debug for SimdPrecomputed { +#[cfg(all(target_arch = "aarch64", target_feature = "neon"))] +impl std::fmt::Debug for NeonPrecomputed { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - f.debug_struct("SimdPrecomputed").finish_non_exhaustive() + f.debug_struct("NeonPrecomputed").finish_non_exhaustive() } } +#[cfg(all(target_arch = "aarch64", target_feature = "neon"))] +type FP = crate::KoalaBearParameters; +#[cfg(all(target_arch = "aarch64", target_feature = "neon"))] +type PackedKB = crate::PackedKoalaBearNeon; + static PRECOMPUTED: OnceLock = OnceLock::new(); fn precomputed() -> &'static Precomputed { @@ -634,29 +571,27 @@ fn precomputed() -> &'static Precomputed { .map(|w| core::array::from_fn(|i| if i == 0 { mds_0_0 } else { w[i - 1] })) .collect(); - // --- SIMD pre-packed constants (same layout for NEON / AVX2 / AVX512) --- - #[cfg(any( - all(target_arch = "aarch64", target_feature = "neon"), - all(target_arch = "x86_64", target_feature = "avx2") - ))] - let simd = { - use crate::convert_to_vec_neg_form; + // --- NEON pre-packed constants --- + #[cfg(all(target_arch = "aarch64", target_feature = "neon"))] + let neon = { + use crate::PackedMontyField31Neon; + use crate::convert_to_vec_neon; - let pack = |c: KoalaBear| PackedKB::from(c); - let neg_form = |c: KoalaBear| convert_to_vec_neg_form::(c.value as i32); + let pack = |c: KoalaBear| PackedMontyField31Neon::::from(c); + let canon_form = |c: KoalaBear| convert_to_vec_neon(c.value); // Initial full round constants (only first 3; 4th is fused). let init_rc = poseidon1_initial_constants(); - let packed_initial_rc: [[Rc; 16]; POSEIDON1_HALF_FULL_ROUNDS - 1] = - core::array::from_fn(|r| init_rc[r].map(neg_form)); + let packed_initial_rc: [[core::arch::aarch64::uint32x4_t; 16]; POSEIDON1_HALF_FULL_ROUNDS - 1] = + core::array::from_fn(|r| init_rc[r].map(canon_form)); - // Last initial round constant (for fused add_rc_and_sbox before partial rounds). - let packed_last_initial_rc = init_rc[POSEIDON1_HALF_FULL_ROUNDS - 1].map(neg_form); + // Last initial round constant (for add_rc_and_sbox before partial rounds). + let packed_last_initial_rc = init_rc[POSEIDON1_HALF_FULL_ROUNDS - 1].map(canon_form); // Terminal full round constants. let term_rc = poseidon1_final_constants(); - let packed_terminal_rc: [[Rc; 16]; POSEIDON1_HALF_FULL_ROUNDS] = - core::array::from_fn(|r| term_rc[r].map(neg_form)); + let packed_terminal_rc: [[core::arch::aarch64::uint32x4_t; 16]; POSEIDON1_HALF_FULL_ROUNDS] = + core::array::from_fn(|r| term_rc[r].map(canon_form)); // Pre-packed sparse constants (fixed-size arrays). let packed_sparse_first_row: [[PackedKB; 16]; POSEIDON1_PARTIAL_ROUNDS] = @@ -677,10 +612,10 @@ fn precomputed() -> &'static Precomputed { // Pre-packed eigenvalues * INV16 (absorbs /16 into eigenvalues). 
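+        // For p = 0xfa000001 = 4194304001, 16^{-1} mod p = (1 + 15 * p) / 16 = 3932160001.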
let mut lambda_br = MDS_CIRC_COL; dif_ifft_16_mut(&mut lambda_br); - let inv16 = KoalaBear::new(1997537281); // 16^{-1} mod p + let inv16 = KoalaBear::new(3932160001); // 16^{-1} mod p let packed_lambda_over_16: [PackedKB; 16] = core::array::from_fn(|i| pack(lambda_br[i] * inv16)); - SimdPrecomputed { + NeonPrecomputed { packed_initial_rc, packed_terminal_rc, packed_sparse_first_row, @@ -699,11 +634,8 @@ fn precomputed() -> &'static Precomputed { sparse_first_row, sparse_v, sparse_round_constants: scalar_round_constants, - #[cfg(any( - all(target_arch = "aarch64", target_feature = "neon"), - all(target_arch = "x86_64", target_feature = "avx2") - ))] - simd, + #[cfg(all(target_arch = "aarch64", target_feature = "neon"))] + neon, } }) } @@ -715,118 +647,118 @@ fn precomputed() -> &'static Precomputed { const POSEIDON1_RC: [[KoalaBear; 16]; POSEIDON1_N_ROUNDS] = KoalaBear::new_2d_array([ // Initial full rounds (4) [ - 0x7ee56a48, 0x11367045, 0x12e41941, 0x7ebbc12b, 0x1970b7d5, 0x662b60e8, 0x3e4990c6, 0x679f91f5, 0x350813bb, - 0x00874ad4, 0x28a0081a, 0x18fa5872, 0x5f25b071, 0x5e5d5998, 0x5e6fd3e7, 0x5b2e2660, + 0x16563297, 0x3e669663, 0xad043724, 0x90ce7116, 0xa078ea37, 0x0626b030, 0xec75cf01, 0x3aab2bf5, 0x591c1f83, + 0x3c9a00ec, 0x73c17410, 0x0b7b4103, 0x00f0d14d, 0x59c6c3d4, 0x569c4787, 0x29f72a6a, ], [ - 0x6f1837bf, 0x3fe6182b, 0x1edd7ac5, 0x57470d00, 0x43d486d5, 0x1982c70f, 0x0ea53af9, 0x61d6165b, 0x51639c00, - 0x2dec352c, 0x2950e531, 0x2d2cb947, 0x08256cef, 0x1a0109f6, 0x1f51faf3, 0x5cef1c62, + 0xe7e6145a, 0xa62dfa56, 0xa72459c4, 0xc12e3d21, 0xcfbaf427, 0x79471bf4, 0x35ea3fd2, 0x0e871da2, 0xb8142f08, + 0x3623564d, 0x607e0747, 0x6c2c3f6a, 0xbaa4cd3d, 0x069d6b42, 0xb98024b9, 0x68831d77, ], [ - 0x3d65e50e, 0x33d91626, 0x133d5a1e, 0x0ff49b0d, 0x38900cd1, 0x2c22cc3f, 0x28852bb2, 0x06c65a02, 0x7b2cf7bc, - 0x68016e1a, 0x15e16bc0, 0x5248149a, 0x6dd212a0, 0x18d6830a, 0x5001be82, 0x64dac34e, + 0xbc4cdd4a, 0x872186a1, 0x90c3888e, 0x7e783120, 0xb3575851, 0x334e7976, 0xbaa135c3, 0x3eb628cc, 0x58712a8d, + 0xd8c18178, 0x602d3f41, 0xcb1b656a, 0x5bd99496, 0xf1f622ca, 0xc363a444, 0x97c4b605, ], [ - 0x5902b287, 0x426583a0, 0x0c921632, 0x3fe028a5, 0x245f8e49, 0x43bb297e, 0x7873dbd9, 0x3cc987df, 0x286bb4ce, - 0x640a8dcd, 0x512a8e36, 0x03a4cf55, 0x481837a2, 0x03d6da84, 0x73726ac7, 0x760e7fdf, + 0x5a07e287, 0xc80962f3, 0x2e37f2d9, 0x26d67c3e, 0xcf96a0a0, 0x1fcde991, 0x641f2cd8, 0x1ef5c127, 0xaddf98c2, + 0xa5c96cb9, 0xca07faa8, 0x73d4e273, 0x8385bfa3, 0x15b22541, 0xcbfde37c, 0x32475586, ], // Partial rounds (20) [ - 0x54dfeb5d, 0x7d40afd6, 0x722cb316, 0x106a4573, 0x45a7ccdb, 0x44061375, 0x154077a5, 0x45744faa, 0x4eb5e5ee, - 0x3794e83f, 0x47c7093c, 0x5694903c, 0x69cb6299, 0x373df84c, 0x46a0df58, 0x46b8758a, + 0xe49f77a8, 0xc751c2d1, 0x76eb1c5a, 0x8b8672cd, 0x1cb44174, 0xf6fb3e5e, 0xdaabb8d4, 0x18e7bd3d, 0x2a560e97, + 0xe04179a3, 0xe21d103d, 0x58d0e0ab, 0x303d1e0c, 0x3f250c3c, 0xc86d7c1e, 0x3ece5680, ], [ - 0x3241ebcb, 0x0b09d233, 0x1af42357, 0x1e66cec2, 0x43e7dc24, 0x259a5d61, 0x27e85a3b, 0x1b9133fa, 0x343e5628, - 0x485cd4c2, 0x16e269f5, 0x165b60c6, 0x25f683d9, 0x124f81f9, 0x174331f9, 0x77344dc5, + 0x16b1f692, 0xf16b95b7, 0x26701f28, 0x85cf7a29, 0xe40a08a9, 0x9636d8d1, 0xeb72850a, 0x48143205, 0x73afec23, + 0x7fd31200, 0xd1fc60c6, 0xce74b7f9, 0xe0acb139, 0x59eeab1f, 0x18afb097, 0x42d20438, ], [ - 0x5a821dba, 0x5fc4177f, 0x54153bf5, 0x5e3f1194, 0x3bdbf191, 0x088c84a3, 0x68256c9b, 0x3c90bbc6, 0x6846166a, - 0x03f4238d, 0x463335fb, 0x5e3d3551, 0x6e59ae6f, 0x32d06cc0, 0x596293f3, 0x6c87edb2, + 0xcaf14d2f, 0xf26c3fe2, 
0x2adc5737, 0xe1e64e39, 0x735ae3e8, 0xbb1e61db, 0xb7c7035c, 0xae582686, 0xebfaf2bd, + 0x0e0610a8, 0x8cd868de, 0xc68befce, 0x6cc077e3, 0x73b6bf4f, 0x8e8404bb, 0x056c9356, ], [ - 0x08fc60b5, 0x34bcca80, 0x24f007f3, 0x62731c6f, 0x1e1db6c6, 0x0ca409bb, 0x585c1e78, 0x56e94edc, 0x16d22734, - 0x18e11467, 0x7b2c3730, 0x770075e4, 0x35d1b18c, 0x22be3db5, 0x4fb1fbb7, 0x477cb3ed, + 0xe2458a7a, 0x895a2ec9, 0xe3667424, 0x68d98c49, 0x044ed7f5, 0x8139cd81, 0x32fe0a12, 0xa9f206ee, 0x7ec87d9b, + 0xa4854877, 0xdd8cec50, 0x4c5d8b26, 0xe00c0ef4, 0x5da03424, 0xceb8fe3d, 0x60d7a6e2, ], [ - 0x7d5311c6, 0x5b62ae7d, 0x559c5fa8, 0x77f15048, 0x3211570b, 0x490fef6a, 0x77ec311f, 0x2247171b, 0x4e0ac711, - 0x2edf69c9, 0x3b5a8850, 0x65809421, 0x5619b4aa, 0x362019a7, 0x6bf9d4ed, 0x5b413dff, + 0x45439320, 0x99febd74, 0xf5913c65, 0x4e96d0aa, 0xc7653152, 0x90cd4ba6, 0x8889cbcd, 0x6638fae2, 0xdcfc4c5e, + 0x8af12cc1, 0x59307544, 0xe1e6bbb3, 0xf9bfa656, 0x72ae4f2c, 0xd5c9598d, 0x3ad1558f, ], [ - 0x617e181e, 0x5e7ab57b, 0x33ad7833, 0x3466c7ca, 0x6488dff4, 0x71f068f4, 0x056e891f, 0x04f1eccc, 0x663257d5, - 0x671e31b9, 0x5871987c, 0x280c109e, 0x2a227761, 0x350a25e9, 0x5b91b1c4, 0x7a073546, + 0x2a3ca52b, 0x99b09e2f, 0x2f7eecd5, 0x520ae1bb, 0x64587b54, 0xf8562fb2, 0xd7770959, 0x60406484, 0x530479ec, + 0x12a21d02, 0xac8bad07, 0xf67994c0, 0x0cc0472e, 0x18c6d644, 0x1b664e25, 0xe6e8b908, ], [ - 0x01826270, 0x53a67720, 0x0ed4b074, 0x34cf0c4e, 0x6e751e88, 0x29bd5f59, 0x49ec32df, 0x7693452b, 0x3cf09e58, - 0x6ba0e2bf, 0x7ab93acf, 0x3ce597df, 0x536e3d42, 0x147a808d, 0x5e32eb56, 0x5a203323, + 0xb2983e38, 0x47455654, 0x17526a2f, 0x6a2d789f, 0xe74a8306, 0xe0fb4aad, 0x9cd3cd7b, 0x614971b7, 0xd97aac83, + 0x3e042505, 0x3125d6a4, 0x562bac89, 0x95736c1d, 0xd58f393b, 0x58cd5712, 0x90841a10, ], [ - 0x50965766, 0x6d44b7c5, 0x6698636a, 0x57b84f9f, 0x554b61b9, 0x6da0ab28, 0x1585b6ac, 0x6705a2b4, 0x152872f6, - 0x0f4409fd, 0x23a9dd60, 0x6f2b18d4, 0x65ac9fd4, 0x2f0efbea, 0x591e67fd, 0x217ca19b, + 0x31eadbdb, 0x2ea63f0b, 0xb7911a51, 0x39001e47, 0x89687d2d, 0x77f8f4db, 0x4077716d, 0xe74357fd, 0x02f591df, + 0x1b9d1ab6, 0xc10be6ba, 0x9d5ad139, 0x3af0f7c7, 0x5a63730e, 0x606a08dd, 0x8e896d67, ], [ - 0x469c90ca, 0x03d60ef5, 0x4ea7857e, 0x07c86a4f, 0x288ed461, 0x2fe51b22, 0x7e293614, 0x2c4beb85, 0x5b0b7d11, - 0x1e17dff6, 0x089beae1, 0x0a5acf1a, 0x2fc33d8f, 0x60422dc6, 0x6e1dc939, 0x635351b9, + 0x3cafe446, 0x39b28d39, 0xcd38a868, 0x56c4c0ed, 0x692a0c89, 0x662f9fed, 0x8a312370, 0x65998ae3, 0x5ab80205, + 0xc2b3941d, 0xcf73f6ba, 0x105cd3df, 0x5cd190f4, 0x6f9d0294, 0x7c96a0cd, 0x41c10f0d, ], [ - 0x55522fc0, 0x3eb94ef7, 0x2a24a65c, 0x2e139c76, 0x51391144, 0x78cc0742, 0x579538f9, 0x44de9aae, 0x3c2f1e2e, - 0x195747be, 0x2496339c, 0x650b2e39, 0x52899665, 0x6cb35558, 0x0f461c1c, 0x70f6b270, + 0x672380d2, 0x30215c93, 0xce6489d2, 0xd9759689, 0x92102223, 0x79ff0e1b, 0x8180a72a, 0x1d1d6066, 0x1257b35f, + 0xb130f5f7, 0x14b2bfe1, 0xc1eb9e61, 0x9a032f74, 0x3922de50, 0x22ac5b30, 0x52152927, ], [ - 0x3faaa36f, 0x62e3348a, 0x672167cb, 0x394c880b, 0x2a46ba82, 0x63ffb74a, 0x1cf875d6, 0x53d12772, 0x036a4552, - 0x3bdd9f2b, 0x02f72c24, 0x02b6006c, 0x077fe158, 0x1f9d6ea4, 0x20904d6f, 0x5d6534fa, + 0xd03de930, 0xd30282b3, 0x96637fce, 0xf5b04144, 0x5c659f82, 0x3f257b92, 0xf471e073, 0x64d033e2, 0x489b2abb, + 0x25fcf9fd, 0x7e528a59, 0x042f3e11, 0x55a5b0eb, 0x6ceb509d, 0x331c2be5, 0xc59c27ab, ], [ - 0x066d8974, 0x6198f1f4, 0x26301ab4, 0x41f274c2, 0x00eac15c, 0x28b54b47, 0x2339739d, 0x48c6281c, 0x4ed935fc, - 0x3f9187fa, 0x4a1930a6, 0x3ad4d736, 0x0f3f1889, 0x635a388f, 0x2862c145, 0x277ed1e8, + 
0x4fe6f17b, 0xd38fd166, 0x807c01b6, 0x81a24eb2, 0xbdc42437, 0x9f81c56c, 0xafeedd34, 0x3bf25434, 0x2bacd3b8, + 0x53b1e339, 0x39d7e7d9, 0xd79c8cd0, 0x889aa21b, 0x6de4d734, 0x2a3486a9, 0xc6bda81d, ], [ - 0x4db23cad, 0x1f1b11f5, 0x1f3dba2b, 0x1c26eb4e, 0x0f7f5546, 0x6cd024b0, 0x67c47902, 0x793b8900, 0x0e8a283c, - 0x4590b7ea, 0x6f567a2b, 0x5dc97300, 0x15247bc6, 0x50567fcb, 0x133eff84, 0x547dc2ef, + 0x1b7b6b5d, 0x0b3a9554, 0xda9f90a6, 0xbf99a500, 0xa6fdac71, 0xc647ce05, 0x08c13cba, 0x1afb4dba, 0x924e49e0, + 0xd945c1b8, 0xb27617db, 0x8925b96f, 0x11276cf2, 0x2a04b3ea, 0xe740b35c, 0x8e599926, ], [ - 0x34eb3dbb, 0x12402317, 0x66c6ae49, 0x174338b6, 0x24251008, 0x1b514927, 0x062d98d6, 0x7af30bbc, 0x26af15e8, - 0x70d907a3, 0x5dfc5cac, 0x731f27ec, 0x53aa7d3f, 0x63ab0ec6, 0x216053f4, 0x18796b39, + 0xf2172870, 0x4637c7d7, 0x1f3f2c3b, 0x523dc658, 0x8b9be7a3, 0x5c3edbc1, 0x710f3163, 0x817dcef3, 0x9bb5931d, + 0xc9fedd06, 0x5452d0ff, 0x3c3fb0bb, 0x46c83153, 0xd782c351, 0x5d16354f, 0x486081bc, ], [ - 0x19156afd, 0x5eea6973, 0x6704c6a9, 0x0dce002b, 0x331169c0, 0x714d7178, 0x3ddaffaf, 0x7e464957, 0x20ca59ea, - 0x679820c9, 0x42ef21a1, 0x798ea089, 0x14a74fa3, 0x0c06cf18, 0x6a4c8d52, 0x620f6d81, + 0x7ef8fd4c, 0x5bd0b96b, 0x507531e5, 0x921f01bf, 0x57aae86e, 0x50f668df, 0x9cd98fa7, 0xb8ef69b5, 0x33f83af2, + 0xef0e4b26, 0x3f502d46, 0x723a0147, 0x7b2df793, 0x9e0dab32, 0x2108bd31, 0x0e64b870, ], [ - 0x2220901a, 0x5277bb90, 0x230bf95e, 0x0ad8847a, 0x5e96e8b6, 0x77b4056e, 0x70a50d2c, 0x5f0eed59, 0x3646c4df, - 0x10eb9a87, 0x21eed6b7, 0x534add36, 0x6e3e7421, 0x2b25810e, 0x1d8f707b, 0x45318a1a, + 0x2005322b, 0x34b37a14, 0x326a4764, 0xc23709a9, 0xc2877e3f, 0x98d3bc14, 0x071198ef, 0x9dd541db, 0xa47318c5, + 0xca4336f1, 0xde35cddb, 0x94dd1390, 0xa74cfcc5, 0x71396cbd, 0x08456022, 0x040cbdb3, ], [ - 0x677f8ff2, 0x0258c9e0, 0x4cd02a00, 0x2e24ff15, 0x634a715d, 0x4ac01e59, 0x601511e1, 0x26e9c01a, 0x4c165c6e, - 0x57cd1140, 0x3ac6543b, 0x6787d847, 0x037dfbf9, 0x6dd9d079, 0x4d24b281, 0x2a6f407d, + 0x837945ce, 0x9e49261c, 0x28632ae3, 0xe2ebb5e1, 0x13035665, 0x059623df, 0x97dc3043, 0xa04168fc, 0x2a936478, + 0x0047f358, 0xf04d99cc, 0x7ea282bc, 0xdb61c7a1, 0x36213a96, 0x967c85a3, 0x7f9822e0, ], [ - 0x0131df8e, 0x4b8a7896, 0x23700858, 0x2cf5e534, 0x12aafc3f, 0x54568d03, 0x1a250735, 0x5331686d, 0x4ce76d91, - 0x799c1a8c, 0x2b7a8ac9, 0x60aee672, 0x74f7421c, 0x3c42146d, 0x26d369c5, 0x4ae54a12, + 0xb6954f09, 0x07292e31, 0x02091a8d, 0xa304c184, 0x70ea38a0, 0xd3053ca7, 0x00b561ee, 0xc70e3fcb, 0xc82103f8, + 0xdd6355cf, 0x5f0b2b85, 0x6194184e, 0x64fb4fdf, 0x2aaf8ca7, 0x40c2422b, 0x176a2fc8, ], [ - 0x7eea16d1, 0x5ce3eae8, 0x69f28994, 0x262b8642, 0x610d4cc4, 0x5e1af21c, 0x1a8526d0, 0x316b127b, 0x3576fe5d, - 0x02d968a0, 0x4ba00f51, 0x40bed993, 0x377fb907, 0x7859216e, 0x1931d9d1, 0x53b0934e, + 0x7cc97de7, 0x63be86dc, 0x08f0ca90, 0x3071a41e, 0xd56e3a1f, 0xf220dce4, 0x5424c61e, 0xc14b44d7, 0xe4f646df, + 0x6d7be7ad, 0x4b29772e, 0x07ba3bce, 0x397a901c, 0xd710cf8c, 0x018d1e0b, 0x6829ef3d, ], [ - 0x71914ff7, 0x4eabae6c, 0x7196468e, 0x164b3cc2, 0x58cb66c0, 0x4c147307, 0x6b3afccd, 0x4236518b, 0x4ad85605, - 0x291382e1, 0x1e89b6cf, 0x5e16c3a8, 0x2e675921, 0x24300954, 0x05e555c3, 0x78880a24, + 0x9ba21d4c, 0xed64b8db, 0x4aaec707, 0x6d6ae164, 0x3c0746a4, 0xc15cdc64, 0x36e921d7, 0x30c898cc, 0x7c981c6e, + 0x871c3b04, 0x7050a49b, 0x819149a2, 0x08bc501d, 0xc26ceeae, 0x3d78eddc, 0xf2884eca, ], // Terminal full rounds (4) [ - 0x763a3125, 0x4f53b240, 0x18b7fa43, 0x2bbe8a73, 0x1c9a12f2, 0x3f6fd40d, 0x0e1d4ec4, 0x1361c64d, 0x09a8f470, - 0x03d23a40, 0x109ad290, 0x28c2fb88, 
0x3b6498f2, 0x74d8be57, 0x6a4277d2, 0x18c2b3d4, + 0x4602cc03, 0xa906d37f, 0x4f1b5c39, 0xc46d832b, 0x189335a1, 0xaa11ab5e, 0xec647d5a, 0x1cae1926, 0x9e51dd38, + 0xbf44201e, 0x371adb90, 0x7a544001, 0x58d3f484, 0x195ec3a6, 0x45776d19, 0x09e98d4a, ], [ - 0x6252c30c, 0x07cc2560, 0x209fe15b, 0x52a55fac, 0x4df19eb7, 0x02521116, 0x5e414ff1, 0x3cd9a1f4, 0x005aad15, - 0x27a53f00, 0x72bbe9cb, 0x71d8bd7d, 0x4194b79a, 0x48e87a72, 0x3341553c, 0x63d34faa, + 0x29f2e1d8, 0x2d7f058c, 0xf25f4a33, 0x4352dfef, 0xa74c0aef, 0x52ba24ca, 0x677b305b, 0xf2941d7d, 0xda68d6e0, + 0x32502a90, 0x0fedb550, 0xf5b7cb9b, 0xcad9d395, 0x793f2d86, 0xa49167fa, 0x8a86b259, ], [ - 0x132a01e3, 0x3833e2d9, 0x49726e04, 0x054957f8, 0x7b71bce4, 0x73eec57d, 0x556e5533, 0x1fa93fde, 0x346a8ca8, - 0x1162dfde, 0x5c30d028, 0x094a4294, 0x3052dcda, 0x37988498, 0x51f06b97, 0x65848779, + 0xabb033c5, 0xe1562215, 0x64e88ed0, 0xb9194068, 0xaf17ebfb, 0xee8377ad, 0xcc7cefea, 0x2522c0b2, 0xa507ae8e, + 0x6eeb4ced, 0x7980c048, 0x25a6f40d, 0xdd443b41, 0x8412e868, 0xbd05f0f4, 0x8c804a4e, ], [ - 0x7599b0d4, 0x436fdabc, 0x66c5b77d, 0x40c86a9e, 0x27e7055b, 0x6d0dd9d8, 0x7e5598b5, 0x1a4d04f3, 0x5e3b2bc7, - 0x533b5b2f, 0x3e33a125, 0x664d71ce, 0x382e6c2a, 0x24c4eb6e, 0x13f246f7, 0x07e2d7ef, + 0xbaad5dad, 0x2bdbe1f0, 0x8dfe8a3f, 0xa5b6f683, 0x0de5ca68, 0x5af48e3d, 0x5d895c2f, 0xf656679d, 0xa3d98a66, + 0xb5e70bc2, 0x678a0ba2, 0x05441e51, 0x5785e092, 0x59734838, 0x4118c3c6, 0xe2e29ed7, ], ]); @@ -938,27 +870,26 @@ impl Poseidon1KoalaBear16 { mds_circ_16(state); } - /// SIMD fast path (NEON / AVX2 / AVX512) using: + /// NEON-specific fast path using: /// - Fused AddRC+S-box (`add_rc_and_sbox`) for full rounds /// - `InternalLayer16` split for ILP between S-box and dot product in partial rounds /// - Pre-packed sparse matrix constants - #[cfg(any( - all(target_arch = "aarch64", target_feature = "neon"), - all(target_arch = "x86_64", target_feature = "avx2") - ))] + #[cfg(all(target_arch = "aarch64", target_feature = "neon"))] #[inline(always)] - fn permute_simd(&self, state: &mut [PackedKB; 16]) { - use crate::{InternalLayer16, add_rc_and_sbox, sbox}; + fn permute_neon(&self, state: &mut [PackedKB; 16]) { + use crate::PackedMontyField31Neon; + use crate::exp_small; + use crate::{InternalLayer16, add_rc_and_sbox}; use core::mem::transmute; - let simd = &self.pre.simd; - let lambda16 = &simd.packed_lambda_over_16; + let neon = &self.pre.neon; + let lambda16 = &neon.packed_lambda_over_16; /// FFT MDS: state = C * state. /// Uses lambda/16 eigenvalues so no separate /16 step needed. /// C * x = DIT_FFT((lambda/16) ⊙ DIF_IFFT(x)) #[inline(always)] - fn mds_fft(state: &mut [PackedKB; 16], lambda16: &[PackedKB; 16]) { + fn mds_fft_neon(state: &mut [PackedKB; 16], lambda16: &[PackedKB; 16]) { dif_ifft_16_mut(state); for i in 0..16 { state[i] *= lambda16[i]; @@ -967,11 +898,11 @@ impl Poseidon1KoalaBear16 { } // --- Initial full rounds (first 3 of 4) --- - for round_constants in &simd.packed_initial_rc { + for round_constants in &neon.packed_initial_rc { for (s, &rc) in state.iter_mut().zip(round_constants.iter()) { add_rc_and_sbox::(s, rc); } - mds_fft(state, lambda16); + mds_fft_neon(state, lambda16); } // --- Last initial full round: AddRC + S-box, then fused (m_i * MDS) --- @@ -979,12 +910,13 @@ impl Poseidon1KoalaBear16 { // = (m_i * MDS) * state + m_i * first_rc // Saves one full FFT MDS call. 
{ - for (s, &rc) in state.iter_mut().zip(simd.packed_last_initial_rc.iter()) { + for (s, &rc) in state.iter_mut().zip(neon.packed_last_initial_rc.iter()) { add_rc_and_sbox::(s, rc); } let input = *state; for (i, state_i) in state.iter_mut().enumerate() { - *state_i = PackedKB::dot_product(&input, &simd.packed_fused_mi_mds[i]) + simd.packed_fused_bias[i]; + *state_i = PackedMontyField31Neon::::dot_product(&input, &neon.packed_fused_mi_mds[i]) + + neon.packed_fused_bias[i]; } } @@ -994,25 +926,29 @@ impl Poseidon1KoalaBear16 { for r in 0..POSEIDON1_PARTIAL_ROUNDS { // PATH A (high latency): S-box on s0 only. - split.s0 = sbox::(split.s0); + unsafe { + let s0_vec = split.s0.to_vector(); + let s0_sboxed = exp_small::(s0_vec); + split.s0 = PackedMontyField31Neon::from_vector(s0_sboxed); + } // Add scalar round constant (except last round). if r < POSEIDON1_PARTIAL_ROUNDS - 1 { - split.s0 += simd.packed_round_constants[r]; + split.s0 += neon.packed_round_constants[r]; } // PATH B (can overlap with S-box): partial dot product on s_hi. let s_hi: &[PackedKB; 15] = unsafe { transmute(&split.s_hi) }; - let first_row = &simd.packed_sparse_first_row[r]; + let first_row = &neon.packed_sparse_first_row[r]; let first_row_hi: &[PackedKB; 15] = first_row[1..].try_into().unwrap(); - let partial_dot = PackedKB::dot_product(s_hi, first_row_hi); + let partial_dot = PackedMontyField31Neon::::dot_product(s_hi, first_row_hi); // SERIAL: complete s0 = first_row[0] * s0 + partial_dot. let s0_val = split.s0; split.s0 = s0_val * first_row[0] + partial_dot; // Rank-1 update: s_hi[j] += s0_old * v[j]. - let v = &simd.packed_sparse_v[r]; + let v = &neon.packed_sparse_v[r]; let s_hi_mut: &mut [PackedKB; 15] = unsafe { transmute(&mut split.s_hi) }; for j in 0..15 { s_hi_mut[j] += s0_val * v[j]; @@ -1023,11 +959,11 @@ impl Poseidon1KoalaBear16 { } // --- Terminal full rounds --- - for round_constants in &simd.packed_terminal_rc { + for round_constants in &neon.packed_terminal_rc { for (s, &rc) in state.iter_mut().zip(round_constants.iter()) { add_rc_and_sbox::(s, rc); } - mds_fft(state, lambda16); + mds_fft_neon(state, lambda16); } } @@ -1038,7 +974,7 @@ impl Poseidon1KoalaBear16 { state: &mut [R; 16], ) { let initial = *state; - // Use permute_mut so the SIMD fast path is dispatched when applicable. + // Use permute_mut for NEON dispatch. Permutation::permute_mut(self, state); for (s, init) in state.iter_mut().zip(initial) { *s += init; @@ -1050,16 +986,14 @@ impl + InjectiveMonomial<3> + Send + Sync + 'static> Permu for Poseidon1KoalaBear16 { fn permute_mut(&self, input: &mut [R; 16]) { - // On targets with a SIMD fast path, dispatch to it when R is the arch-specific packed type. - #[cfg(any( - all(target_arch = "aarch64", target_feature = "neon"), - all(target_arch = "x86_64", target_feature = "avx2") - ))] + // On aarch64+neon, dispatch to the NEON fast path when R is PackedKoalaBearNeon. + #[cfg(all(target_arch = "aarch64", target_feature = "neon"))] { if std::any::TypeId::of::() == std::any::TypeId::of::() { - // SAFETY: TypeId confirms R == PackedKB; PackedKB is repr(transparent). - let simd_state: &mut [PackedKB; 16] = unsafe { &mut *(input as *mut [R; 16] as *mut [PackedKB; 16]) }; - self.permute_simd(simd_state); + // SAFETY: We have just confirmed via TypeId that R == PackedKB. + // Both types have the same size and alignment (PackedKB is repr(transparent)). 
+ let neon_state: &mut [PackedKB; 16] = unsafe { &mut *(input as *mut [R; 16] as *mut [PackedKB; 16]) }; + self.permute_neon(neon_state); return; } } @@ -1101,8 +1035,8 @@ mod tests { assert_eq!( vals, vec![ - 610090613, 935319874, 1893335292, 796792199, 356405232, 552237741, 55134556, 1215104204, 1823723405, - 1133298033, 1780633798, 1453946561, 710069176, 1128629550, 1917333254, 1175481618, + 2472545174, 2494465264, 1378828411, 2159817276, 990840178, 4077691891, 367747210, 1296698476, + 2559737505, 2863680013, 1095349934, 2118207550, 3744526966, 2462370130, 4189625406, 3376325618, ] ); } diff --git a/crates/backend/koala-bear/src/quintic_extension/extension.rs b/crates/backend/koala-bear/src/quintic_extension/extension.rs index d600f9275..b06fa588c 100644 --- a/crates/backend/koala-bear/src/quintic_extension/extension.rs +++ b/crates/backend/koala-bear/src/quintic_extension/extension.rs @@ -23,7 +23,7 @@ use super::packed_extension::PackedQuinticExtensionField; use crate::QuinticExtendable; /// Quintic Extension Field (degree 5), specifically designed for Koala-Bear -/// Irreducible polynomial: X^5 + X^2 - 1 +/// Irreducible polynomial: X^5 + 2 #[derive(Copy, Clone, Eq, PartialEq, Hash, Debug, Serialize, Deserialize, PartialOrd, Ord)] #[repr(transparent)] // Needed to make various casts safe. #[must_use] @@ -526,27 +526,34 @@ impl TwoAdicField for QuinticExtensionField } } -/// Quintic extension field multiplication in F[X]/(X^5 + X^2 - 1). -#[inline(always)] -pub fn quintic_mul>( +/// Quintic extension field multiplication in F[X]/(X^5 + 2). +/// +/// X^5 = -2, so higher-degree terms are reduced by replacing X^5 with -2. +#[inline] +pub fn quintic_mul + Neg>( a: &[T; 5], b: &[T; 5], dot_product: impl Fn(&[T; 5], &[T; 5]) -> T, ) -> [T; 5] { - let b_0_m3 = b[0] - b[3]; - let b_1_m4 = b[1] - b[4]; - let b_4_m2 = b[4] - b[2]; + // Precompute -2*b[i] = -(b[i] + b[i]) + let n2b1 = -(b[1] + b[1]); + let n2b2 = -(b[2] + b[2]); + let n2b3 = -(b[3] + b[3]); + let n2b4 = -(b[4] + b[4]); [ - dot_product(a, &[b[0], b[4], b[3], b[2], b_1_m4]), - dot_product(a, &[b[1], b[0], b[4], b[3], b[2]]), - dot_product(a, &[b[2], b_1_m4, b_0_m3, b_4_m2, b[3] - b_1_m4]), - dot_product(a, &[b[3], b[2], b_1_m4, b_0_m3, b_4_m2]), - dot_product(a, &[b[4], b[3], b[2], b_1_m4, b_0_m3]), + dot_product(a, &[b[0], n2b4, n2b3, n2b2, n2b1]), + dot_product(a, &[b[1], b[0], n2b4, n2b3, n2b2]), + dot_product(a, &[b[2], b[1], b[0], n2b4, n2b3]), + dot_product(a, &[b[3], b[2], b[1], b[0], n2b4]), + dot_product(a, &[b[4], b[3], b[2], b[1], b[0]]), ] } -#[inline(always)] +/// Squaring in F[X]/(X^5 + 2). +/// +/// X^5 = -2, so c[i+5] terms get multiplied by -2. 
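+///
+/// Writing the unreduced square as coefficients `c[0..9]`, the reduced result is
+/// `res[i] = c[i] - 2 * c[i + 5]` for `i < 4` and `res[4] = c[4]`.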
+#[inline] pub(crate) fn quintic_square(a: &[R; 5], res: &mut [R; 5]) where F: Field, @@ -565,20 +572,25 @@ where let a3_square = a[3].square(); let a4_square = a[4].square(); - // Constant term = a0^2 + 2*a1*a4 + 2*a2*a3 - a4^2 - res[0] = R::dot_product(&[a[0], two_a1], &[a[0], a[4]]) + two_a2_a3 - a4_square; + // For X^5 + 2: c[5] = 2*a1*a4 + 2*a2*a3, reduced by -2 + // c[6] = 2*a2*a4 + a3^2, reduced by -2 + // c[7] = 2*a3*a4, reduced by -2 + // c[8] = a4^2, reduced by -2 + + // Constant term = a0^2 - 2*(2*a1*a4 + 2*a2*a3) = a0^2 - 4*a1*a4 - 4*a2*a3 + res[0] = a[0].square() - (two_a1_a4 + two_a2_a3).double(); - // Linear term = 2*a0*a1 + a3^2 + 2*a2*a4 - res[1] = two_a0 * a[1] + a3_square + two_a2_a4; + // Linear term = 2*a0*a1 - 2*(2*a2*a4 + a3^2) = 2*a0*a1 - 4*a2*a4 - 2*a3^2 + res[1] = two_a0 * a[1] - two_a2_a4.double() - a3_square.double(); - // Square term = a1^2 + 2*a0*a2 - 2*a1*a4 - 2*a2*a3 + 2*a3*a4 + a4^2 - res[2] = a[1].square() + two_a0 * a[2] - two_a1_a4 - two_a2_a3 + two_a3_a4 + a4_square; + // Quadratic term = 2*a0*a2 + a1^2 - 2*(2*a3*a4) = 2*a0*a2 + a1^2 - 4*a3*a4 + res[2] = two_a0 * a[2] + a[1].square() - two_a3_a4.double(); - // Cubic term = 2*a0*a3 + 2*a1*a2 - a3^2 - 2*a2*a4 + a4^2 - res[3] = R::dot_product(&[two_a0, two_a1], &[a[3], a[2]]) - a3_square - two_a2_a4 + a4_square; + // Cubic term = 2*a0*a3 + 2*a1*a2 - 2*a4^2 + res[3] = R::dot_product(&[two_a0, two_a1], &[a[3], a[2]]) - a4_square.double(); - // Quartic term = a2^2 + 2*a0*a4 + 2*a1*a3 - 2*a3*a4 - res[4] = R::dot_product(&[two_a0, two_a1], &[a[4], a[3]]) + a[2].square() - two_a3_a4; + // Quartic term = 2*a0*a4 + 2*a1*a3 + a2^2 + res[4] = R::dot_product(&[two_a0, two_a1], &[a[4], a[3]]) + a[2].square(); } #[inline] @@ -588,16 +600,17 @@ fn quintic_inv(a: &QuinticExtensionField) -> QuinticExt let a_exp_q_plus_q_sq = (*a * a_exp_q).frobenius(); let prod_conj = a_exp_q_plus_q_sq * a_exp_q_plus_q_sq.repeated_frobenius(2); - // norm = a * prod_conj is in the base field, so only compute that - // coefficient rather than the full product. + // norm = a * prod_conj is in the base field, so only compute the + // constant coefficient (for X^5 + 2: c[0] = a0*b0 - 2*(a1*b4 + a2*b3 + a3*b2 + a4*b1)). + let two = F::TWO; let norm = F::dot_product::<5>( &a.value, &[ prod_conj.value[0], - prod_conj.value[4], - prod_conj.value[3], - prod_conj.value[2], - prod_conj.value[1] - prod_conj.value[4], + -(two * prod_conj.value[4]), + -(two * prod_conj.value[3]), + -(two * prod_conj.value[2]), + -(two * prod_conj.value[1]), ], ); diff --git a/crates/backend/koala-bear/src/quintic_extension/mod.rs b/crates/backend/koala-bear/src/quintic_extension/mod.rs index 6ccdca4f7..785da450f 100644 --- a/crates/backend/koala-bear/src/quintic_extension/mod.rs +++ b/crates/backend/koala-bear/src/quintic_extension/mod.rs @@ -16,34 +16,37 @@ pub type QuinticExtensionFieldKB = QuinticExtensionField; pub type PackedQuinticExtensionFieldKB = PackedQuinticExtensionField::Packing>; impl QuinticExtendable for KoalaBear { + /// Frobenius matrix for X^5 + 2 over F_p where p = 4194304001. + /// Since X^5 + 2 is a binomial extension and p ≡ 1 (mod 5), the Frobenius is diagonal: + /// X^{ip} = w^i * X^i where w = (-2)^((p-1)/5). 
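+    /// Row i below therefore has a single nonzero entry, w^{i+1}, in column i+1
+    /// (the image of X^{i+1} under one application of the Frobenius).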
const FROBENIUS_MATRIX: [[Self; 5]; 4] = [ [ - Self::new(1576402667), - Self::new(1173144480), - Self::new(1567662457), - Self::new(1206866823), - Self::new(2428146), + Self::new(0), + Self::new(561393150), + Self::new(0), + Self::new(0), + Self::new(0), ], [ - Self::new(1680345488), - Self::new(1381986), - Self::new(615237464), - Self::new(1380104858), - Self::new(295431824), + Self::new(0), + Self::new(0), + Self::new(1307621960), + Self::new(0), + Self::new(0), ], [ - Self::new(441230756), - Self::new(323126830), - Self::new(704986542), - Self::new(1445620072), - Self::new(503505220), + Self::new(0), + Self::new(0), + Self::new(0), + Self::new(1448665303), + Self::new(0), ], [ - Self::new(1364444097), - Self::new(1144738982), - Self::new(2008416047), - Self::new(143367062), - Self::new(1027410849), + Self::new(0), + Self::new(0), + Self::new(0), + Self::new(0), + Self::new(876623587), ], ]; @@ -92,7 +95,7 @@ impl QuinticExtendableAlgebra for KoalaBear { } } -/// Trait for fields that support binomial extension of the form: `F[X]/(X^5 + X^2 - 1)` +/// Trait for fields that support quintic extension of the form: `F[X]/(X^5 + 2)` pub trait QuinticExtendable: Field + QuinticExtendableAlgebra { const FROBENIUS_MATRIX: [[Self; 5]; 4]; diff --git a/crates/backend/koala-bear/src/quintic_extension/packing.rs b/crates/backend/koala-bear/src/quintic_extension/packing.rs index 9a7e6ce23..00884aa3c 100644 --- a/crates/backend/koala-bear/src/quintic_extension/packing.rs +++ b/crates/backend/koala-bear/src/quintic_extension/packing.rs @@ -8,187 +8,66 @@ use crate::KoalaBear; all(target_arch = "aarch64", target_feature = "neon"), all(target_arch = "x86_64", target_feature = "avx2",) )))] -#[inline(always)] +#[inline] pub(crate) fn quintic_mul_packed(a: &[KoalaBear; 5], b: &[KoalaBear; 5], res: &mut [KoalaBear; 5]) { use field::PrimeCharacteristicRing; *res = super::extension::quintic_mul(a, b, KoalaBear::dot_product::<5>); } #[cfg(all(target_arch = "x86_64", target_feature = "avx2", not(target_feature = "avx512f")))] -/// Multiplication in a quintic binomial extension field. -#[inline(always)] +/// Multiplication in a quintic binomial extension field F[X]/(X^5 + 2). +/// +/// The packed (vectorized) AVX dot-product helpers used by previous +/// implementations assume P < 2^31. With the current 32-bit prime +/// (P = 0xfa000001 ≈ 2^32), summing two 32-bit-prime products in a u64 +/// can overflow, so we route this multiplication through the scalar +/// path which uses u128 accumulation. +#[inline] pub(crate) fn quintic_mul_packed(a: &[KoalaBear; 5], b: &[KoalaBear; 5], res: &mut [KoalaBear; 5]) { - // TODO: This could likely be optimised further with more effort. - // in particular it would benefit from a custom AVX2 implementation. 
- - use crate::PackedMontyField31AVX2; use field::PrimeCharacteristicRing; - - // Constant term = a0*b0 + a1*b4 + a2*b3 + a3*b2 + a4*b1 - a4*b4 - // Linear term = a0*b1 + a1*b0 + a2*b4 + a3*b3 + a4*b2 - // Square term = a0*b2 + a1*b1 - a1*b4 + a2*b0 - a2*b3 + a3*b4 - a3*b2 + a4*b3 - a4*b1 + a4*b4 - // Cubic term = a0*b3 + a1*b2 + a2*b1 - a2*b4 + a3*b0 - a3*b3 + a4*b4 - a4*b2 - // Quartic term = a0*b4 + a1*b3 + a2*b2 + a3*b1 - a3*b4 + a4*b0 - a4*b3 - - let zero = KoalaBear::ZERO; - let b0_minus_b3 = b[0] - b[3]; - let b1_minus_b4 = b[1] - b[4]; - let b4_minus_b2 = b[4] - b[2]; - let b3_plus_b4_minus_b_1 = b[3] - b1_minus_b4; - - let lhs = [ - PackedMontyField31AVX2([a[0], a[0], a[0], a[0], a[0], a[4], a[4], a[4]]), - PackedMontyField31AVX2([a[1], a[1], a[1], a[1], a[1], zero, zero, zero]), - PackedMontyField31AVX2([a[2], a[2], a[2], a[2], a[2], zero, zero, zero]), - PackedMontyField31AVX2([a[3], a[3], a[3], a[3], a[3], zero, zero, zero]), - ]; - let rhs = [ - PackedMontyField31AVX2([b[0], b[1], b[2], b[3], b[4], b1_minus_b4, b[2], b3_plus_b4_minus_b_1]), - PackedMontyField31AVX2([b[4], b[0], b1_minus_b4, b[2], b[3], zero, zero, zero]), - PackedMontyField31AVX2([b[3], b[4], b0_minus_b3, b1_minus_b4, b[2], zero, zero, zero]), - PackedMontyField31AVX2([b[2], b[3], b4_minus_b2, b0_minus_b3, b1_minus_b4, zero, zero, zero]), - ]; - - let dot_res = unsafe { PackedMontyField31AVX2::from_vector(crate::dot_product_4(lhs, rhs)) }; - - // We managed to compute 3 of the extra terms in the last 3 places of the dot product. - // This leaves us with 2 terms remaining we need to compute manually. - let extra1 = b4_minus_b2 * a[4]; - let extra2 = b0_minus_b3 * a[4]; - - let extra_addition = PackedMontyField31AVX2([ - dot_res.0[5], - dot_res.0[6], - dot_res.0[7], - extra1, - extra2, - zero, - zero, - zero, - ]); - let total = dot_res + extra_addition; - - res.copy_from_slice(&total.0[..5]); + *res = super::extension::quintic_mul(a, b, KoalaBear::dot_product::<5>); } #[cfg(all(target_arch = "x86_64", target_feature = "avx512f"))] -/// Multiplication in a quintic binomial extension field. -#[inline(always)] +/// Multiplication in a quintic binomial extension field F[X]/(X^5 + 2). +/// +/// The packed (vectorized) AVX dot-product helpers used by previous +/// implementations assume P < 2^31. With the current 32-bit prime +/// (P = 0xfa000001 ≈ 2^32), summing two 32-bit-prime products in a u64 +/// can overflow, so we route this multiplication through the scalar +/// path which uses u128 accumulation. +#[inline] pub(crate) fn quintic_mul_packed(a: &[KoalaBear; 5], b: &[KoalaBear; 5], res: &mut [KoalaBear; 5]) { - use crate::{PackedMontyField31AVX512, dot_product_2}; use field::PrimeCharacteristicRing; - - // TODO: It's plausible that this could be improved by folding the computation of packed_b into - // the custom AVX512 implementation. Moreover, AVX512 is really a bit to large so we are wasting a lot - // of space. A custom implementation which mixes AVX512 and AVX2 code might well be able to - // improve one that is here. 
- let zero = KoalaBear::ZERO; - let b0_minus_b3 = b[0] - b[3]; - let b1_minus_b4 = b[1] - b[4]; - let b4_minus_b2 = b[4] - b[2]; - let b3_plus_b4_minus_b_1 = b[3] - b1_minus_b4; - - // Constant term = a0*b0 + a1*b4 + a2*b3 + a3*b2 + a4*b1 - a4*b4 - // Linear term = a0*b1 + a1*b0 + a2*b4 + a3*b3 + a4*b2 - // Square term = a0*b2 + a1*b1 - a1*b4 + a2*b0 - a2*b3 + a3*b4 - a3*b2 + a4*b3 - a4*b1 + a4*b4 - // Cubic term = a0*b3 + a1*b2 + a2*b1 - a2*b4 + a3*b0 - a3*b3 + a4*b4 - a4*b2 - // Quartic term = a0*b4 + a1*b3 + a2*b2 + a3*b1 - a3*b4 + a4*b0 - a4*b3 - - // Each packed vector can do 8 multiplications at once. As we have - // 25 multiplications to do we will need to use at least 3 packed vectors - // but we might as well use 4 so we can make use of dot_product_2. - // TODO: This can probably be improved by using a custom function. - let lhs = [ - PackedMontyField31AVX512([ - a[0], a[2], a[0], a[2], a[0], a[2], a[0], a[2], a[0], a[2], a[4], a[4], a[4], a[4], a[4], zero, - ]), - PackedMontyField31AVX512([ - a[1], a[3], a[1], a[3], a[1], a[3], a[1], a[3], a[1], a[3], zero, zero, zero, zero, zero, zero, - ]), - ]; - let rhs = [ - PackedMontyField31AVX512([ - b[0], - b[3], - b[1], - b[4], - b[2], - b0_minus_b3, - b[3], - b1_minus_b4, - b[4], - b[2], - b1_minus_b4, - b[2], - b3_plus_b4_minus_b_1, - b4_minus_b2, - b0_minus_b3, - zero, - ]), - PackedMontyField31AVX512([ - b[4], - b[2], - b[0], - b[3], - b1_minus_b4, - b4_minus_b2, - b[2], - b0_minus_b3, - b[3], - b1_minus_b4, - zero, - zero, - zero, - zero, - zero, - zero, - ]), - ]; - - let dot = unsafe { PackedMontyField31AVX512::from_vector(dot_product_2(lhs, rhs)).0 }; - - let sumand1 = PackedMontyField31AVX512::from_monty_array([dot[0], dot[2], dot[4], dot[6], dot[8]]); - let sumand2 = PackedMontyField31AVX512::from_monty_array([dot[1], dot[3], dot[5], dot[7], dot[9]]); - let sumand3 = PackedMontyField31AVX512::from_monty_array([dot[10], dot[11], dot[12], dot[13], dot[14]]); - let sum = sumand1 + sumand2 + sumand3; - - res.copy_from_slice(&sum.0[..5]); + *res = super::extension::quintic_mul(a, b, KoalaBear::dot_product::<5>); } #[cfg(all(target_arch = "aarch64", target_feature = "neon"))] -/// Multiplication in a quintic binomial extension field. -#[inline(always)] +/// Multiplication in quintic extension field F[X]/(X^5 + 2). +#[inline] pub(crate) fn quintic_mul_packed(a: &[KoalaBear; 5], b: &[KoalaBear; 5], res: &mut [KoalaBear; 5]) { - // TODO: This could be optimised further with a custom NEON implementation. 
- use crate::PackedMontyField31Neon; use field::PrimeCharacteristicRing; - let b0_minus_b3 = b[0] - b[3]; - let b1_minus_b4 = b[1] - b[4]; - let b4_minus_b2 = b[4] - b[2]; - let b3_plus_b4_minus_b_1 = b[3] - b1_minus_b4; + // For X^5 + 2: X^5 = -2 + let neg_2b1 = -(b[1] + b[1]); + let neg_2b2 = -(b[2] + b[2]); + let neg_2b3 = -(b[3] + b[3]); + let neg_2b4 = -(b[4] + b[4]); - // Constant term = a0*b0 + a1*b4 + a2*b3 + a3*b2 + a4*b1 - a4*b4 - // Linear term = a0*b1 + a1*b0 + a2*b4 + a3*b3 + a4*b2 - // Square term = a0*b2 + a1*b1 - a1*b4 + a2*b0 - a2*b3 + a3*b4 - a3*b2 + a4*b3 - a4*b1 + a4*b4 - // Cubic term = a0*b3 + a1*b2 + a2*b1 - a2*b4 + a3*b0 - a3*b3 + a4*b4 - a4*b2 - // Quartic term = a0*b4 + a1*b3 + a2*b2 + a3*b1 - a3*b4 + a4*b0 - a4*b3 let lhs: [PackedMontyField31Neon; 5] = [a[0].into(), a[1].into(), a[2].into(), a[3].into(), a[4].into()]; let rhs = [ PackedMontyField31Neon([b[0], b[1], b[2], b[3]]), - PackedMontyField31Neon([b[4], b[0], b1_minus_b4, b[2]]), - PackedMontyField31Neon([b[3], b[4], b0_minus_b3, b1_minus_b4]), - PackedMontyField31Neon([b[2], b[3], b4_minus_b2, b0_minus_b3]), - PackedMontyField31Neon([b1_minus_b4, b[2], b3_plus_b4_minus_b_1, b4_minus_b2]), + PackedMontyField31Neon([neg_2b4, b[0], b[1], b[2]]), + PackedMontyField31Neon([neg_2b3, neg_2b4, b[0], b[1]]), + PackedMontyField31Neon([neg_2b2, neg_2b3, neg_2b4, b[0]]), + PackedMontyField31Neon([neg_2b1, neg_2b2, neg_2b3, neg_2b4]), ]; let dot = PackedMontyField31Neon::dot_product(&lhs, &rhs).0; - res[..4].copy_from_slice(&dot); - res[4] = KoalaBear::dot_product::<5>( - &[a[0], a[1], a[2], a[3], a[4]], - &[b[4], b[3], b[2], b1_minus_b4, b0_minus_b3], - ); + + // result[4] = dot(a, [b4, b3, b2, b1, b0]) (no -2 terms) + res[4] = KoalaBear::dot_product::<5>(&[a[0], a[1], a[2], a[3], a[4]], &[b[4], b[3], b[2], b[1], b[0]]); } diff --git a/crates/backend/koala-bear/src/x86_64_avx2/packing.rs b/crates/backend/koala-bear/src/x86_64_avx2/packing.rs index 396a99d77..df762de23 100644 --- a/crates/backend/koala-bear/src/x86_64_avx2/packing.rs +++ b/crates/backend/koala-bear/src/x86_64_avx2/packing.rs @@ -11,6 +11,6 @@ pub type PackedKoalaBearAVX2 = PackedMontyField31AVX2; const WIDTH: usize = 8; impl MontyParametersAVX2 for KoalaBearParameters { - const PACKED_P: __m256i = unsafe { transmute::<[u32; WIDTH], _>([0x7f000001; WIDTH]) }; - const PACKED_MU: __m256i = unsafe { transmute::<[u32; WIDTH], _>([0x81000001; WIDTH]) }; + const PACKED_P: __m256i = unsafe { transmute::<[u32; WIDTH], _>([0xfa000001; WIDTH]) }; + const PACKED_MU: __m256i = unsafe { transmute::<[u32; WIDTH], _>([0x06000001; WIDTH]) }; } diff --git a/crates/backend/koala-bear/src/x86_64_avx512/packing.rs b/crates/backend/koala-bear/src/x86_64_avx512/packing.rs index 192d70ad0..3f59aabd2 100644 --- a/crates/backend/koala-bear/src/x86_64_avx512/packing.rs +++ b/crates/backend/koala-bear/src/x86_64_avx512/packing.rs @@ -11,6 +11,6 @@ pub type PackedKoalaBearAVX512 = PackedMontyField31AVX512; const WIDTH: usize = 16; impl MontyParametersAVX512 for KoalaBearParameters { - const PACKED_P: __m512i = unsafe { transmute::<[u32; WIDTH], _>([0x7f000001; WIDTH]) }; - const PACKED_MU: __m512i = unsafe { transmute::<[u32; WIDTH], _>([0x81000001; WIDTH]) }; + const PACKED_P: __m512i = unsafe { transmute::<[u32; WIDTH], _>([0xfa000001; WIDTH]) }; + const PACKED_MU: __m512i = unsafe { transmute::<[u32; WIDTH], _>([0x06000001; WIDTH]) }; } diff --git a/crates/lean_compiler/tests/test_data/program_15.py b/crates/lean_compiler/tests/test_data/program_15.py index 
6ea149c2a..6a71c0b70 100644 --- a/crates/lean_compiler/tests/test_data/program_15.py +++ b/crates/lean_compiler/tests/test_data/program_15.py @@ -10,7 +10,7 @@ def main(): i, j, k = func_1(x, y) assert i == 2 assert j == 3 - assert k == 2130706432 + assert k == 4194304000 g = Array(8) h = Array(8) diff --git a/crates/lean_compiler/tests/test_data/program_30.py b/crates/lean_compiler/tests/test_data/program_30.py index 0348faa65..819af3b25 100644 --- a/crates/lean_compiler/tests/test_data/program_30.py +++ b/crates/lean_compiler/tests/test_data/program_30.py @@ -9,7 +9,7 @@ def main(): for i in unroll(0, 2): res = f1(ARR[i]) buff[i + 1] = res - assert buff[2] == 1390320454 + assert buff[2] == 1365390346 return diff --git a/crates/rec_aggregation/src/compilation.rs b/crates/rec_aggregation/src/compilation.rs index 7d8b11eef..043f24aa7 100644 --- a/crates/rec_aggregation/src/compilation.rs +++ b/crates/rec_aggregation/src/compilation.rs @@ -11,7 +11,10 @@ use std::sync::OnceLock; use sub_protocols::{N_VARS_TO_SEND_GKR_COEFFS, min_stacked_n_vars, total_whir_statements}; use tracing::instrument; use utils::Counter; -use xmss::{LOG_LIFETIME, MESSAGE_LEN_FE, PUBLIC_PARAM_LEN_FE, RANDOMNESS_LEN_FE, TARGET_SUM, V, W, XMSS_DIGEST_LEN}; +use xmss::{ + ENCODING_NUM_FINAL_ZEROS, LOG_LIFETIME, MESSAGE_LEN_FE, PUBLIC_PARAM_LEN_FE, RANDOMNESS_LEN_FE, TARGET_SUM, V, W, + XMSS_DIGEST_LEN, +}; use crate::{MERKLE_LEVELS_PER_CHUNK_FOR_SLOT, N_MERKLE_CHUNKS_FOR_SLOT, NUM_REPEATED_ONES, ZERO_VEC_LEN}; @@ -349,6 +352,10 @@ fn build_replacements( replacements.insert("V_PLACEHOLDER".to_string(), V.to_string()); replacements.insert("W_PLACEHOLDER".to_string(), W.to_string()); replacements.insert("TARGET_SUM_PLACEHOLDER".to_string(), TARGET_SUM.to_string()); + replacements.insert( + "ENCODING_NUM_FINAL_ZEROS_PLACEHOLDER".to_string(), + ENCODING_NUM_FINAL_ZEROS.to_string(), + ); replacements.insert("LOG_LIFETIME_PLACEHOLDER".to_string(), LOG_LIFETIME.to_string()); replacements.insert("MESSAGE_LEN_PLACEHOLDER".to_string(), MESSAGE_LEN_FE.to_string()); replacements.insert("RANDOMNESS_LEN_PLACEHOLDER".to_string(), RANDOMNESS_LEN_FE.to_string()); diff --git a/crates/rec_aggregation/utils.py b/crates/rec_aggregation/utils.py index 3999f4548..98eb5d8e8 100644 --- a/crates/rec_aggregation/utils.py +++ b/crates/rec_aggregation/utils.py @@ -1,10 +1,10 @@ from snark_lib import * from hashing import * -F_BITS = 31 # koala-bear = 31 bits +F_BITS = 32 # koala-bear (32-bit prime experiment) = 32 bits -TWO_ADICITY = 24 -ROOT = 1791270792 # of order 2^TWO_ADICITY +TWO_ADICITY = 25 +ROOT = 1177770062 # of order 2^TWO_ADICITY (= 0x4633584e) @inline @@ -447,17 +447,14 @@ def sum_2_ef_fractions(a_num, a_den, b_num, b_den): return sum_num, common_den -# p = 2^31 - 2^24 + 1 -# in binary: p = 1111111000000000000000000000001 -# p - 1 = 1111111000000000000000000000000 -# p - 2 = 1111110111111111111111111111111 -# p - 3 = 1111110111111111111111111111110 -# ... +# p = 0xfa000001 = 250 * 2^24 + 1 +# in binary: p = 11111010_00000000_00000000_00000001 +# p - 1 = 11111010_00000000_00000000_00000000 # Any field element (< p) is either: -# - 1111111 | 00...00 -# - not(1111111) | xx...xx +# - 11111010 | 00...00 +# - not(11111010) | xx...xx def checked_decompose_bits(a): - # return a pointer to the 31 bits of a + # return a pointer to the 32 bits of a # .. 
and the first 24 partial sums of these bits bits = Array(F_BITS) hint_decompose_bits(a, bits, F_BITS, LITTLE_ENDIAN) @@ -468,13 +465,13 @@ def checked_decompose_bits(a): partial_sums_24[0] = bits[0] for i in unroll(1, 24): partial_sums_24[i] = partial_sums_24[i - 1] + bits[i] * 2**i - sum_7: Mut = bits[24] - for i in unroll(1, 7): - sum_7 += bits[24 + i] * 2**i - if sum_7 == 127: + sum_8: Mut = bits[24] + for i in unroll(1, 8): + sum_8 += bits[24 + i] * 2**i + if sum_8 == 250: assert partial_sums_24[23] == 0 - assert a == partial_sums_24[23] + sum_7 * 2**24 + assert a == partial_sums_24[23] + sum_8 * 2**24 return bits, partial_sums_24 @@ -514,12 +511,12 @@ def decompose_and_verify_merkle_query(a, domain_size, prev_root, num_chunks): for i in unroll(1, 6): partial_sum += nibbles[i] * 16**i - # p = 2^31 - 2^24 + 1, so 2^24 * 127 = p - 1 ≡ -1 (mod p), hence inv(2^24) = -127. - # Deduce top7 from the identity partial_sum + top7 * 2^24 == a: - # top7 = (a - partial_sum) * inv(2^24) = (partial_sum - a) * 127 - top7 = (partial_sum - a) * 127 - assert top7 < 2**7 - if top7 == 2**7 - 1: + # p = 250 * 2^24 + 1, so 2^24 * 250 ≡ -1 (mod p), hence inv(2^24) = -250. + # Deduce top from the identity partial_sum + top * 2^24 == a: + # top = (a - partial_sum) * inv(2^24) = (partial_sum - a) * 250 + top = (partial_sum - a) * 250 + assert top < 251 + if top == 250: assert partial_sum == 0 leaf_data = Array(num_chunks * DIGEST_LEN) diff --git a/crates/rec_aggregation/xmss_aggregate.py b/crates/rec_aggregation/xmss_aggregate.py index 39090e6c6..5897ad9c7 100644 --- a/crates/rec_aggregation/xmss_aggregate.py +++ b/crates/rec_aggregation/xmss_aggregate.py @@ -5,6 +5,7 @@ W = W_PLACEHOLDER CHAIN_LENGTH = 2**W TARGET_SUM = TARGET_SUM_PLACEHOLDER +ENCODING_NUM_FINAL_ZEROS = ENCODING_NUM_FINAL_ZEROS_PLACEHOLDER LOG_LIFETIME = LOG_LIFETIME_PLACEHOLDER MESSAGE_LEN = MESSAGE_LEN_PLACEHOLDER RANDOMNESS_LEN = RANDOMNESS_LEN_PLACEHOLDER @@ -15,7 +16,7 @@ WOTS_SIG_SIZE = RANDOMNESS_LEN + V * XMSS_DIGEST_LEN # wots_public_key pair stride: each pair occupies 10 cells `[leading_0 | tip_a(4) | tip_b(4) | trailing_0]`. In order to be able to use copy_5 on both sides. WOTS_PK_PAIR_STRIDE = 2 + 2 * XMSS_DIGEST_LEN -NUM_ENCODING_FE = div_ceil(V, (24 / W)) +NUM_ENCODING_FE = div_ceil(V * W + ENCODING_NUM_FINAL_ZEROS, 24) MERKLE_LEVELS_PER_CHUNK = MERKLE_LEVELS_PER_CHUNK_PLACEHOLDER N_MERKLE_CHUNKS = LOG_LIFETIME / MERKLE_LEVELS_PER_CHUNK INNER_PUB_MEM_SIZE = 2**INNER_PUBLIC_MEMORY_LOG_SIZE # = DIGEST_LEN @@ -71,12 +72,21 @@ def xmss_verify(pub_key, message, merkle_chunks): for j in unroll(1, 24 / (2 * W)): partial_sum += encoding[i * (24 / (2 * W)) + j] * (CHAIN_LENGTH**2) ** j - # p = 2^31 - 2^24 + 1 = 127.2^24 + 1, so inv(2^24) = -127 (mod p). + # p = 250 * 2^24 + 1, so inv(2^24) = -250 (mod p). # Deduce remaining_i from partial_sum + remaining_i * 2^24 == encoding_fe[i]: - # remaining_i = (encoding_fe[i] - partial_sum) * inv(2^24) = (partial_sum - encoding_fe[i]) * 127 - remaining_i = (partial_sum - encoding_fe[i]) * 127 - assert remaining_i < 127 # ensures uniformity + prevent overflow - + # remaining_i = (encoding_fe[i] - partial_sum) * inv(2^24) = (partial_sum - encoding_fe[i]) * 250 + remaining_i = (partial_sum - encoding_fe[i]) * 250 + assert remaining_i < 250 # ensures uniformity + prevent overflow + + # encoding[V/2] is the 2*W-bit chunk holding bits V*W..V*W + 2*W - 1 of the + # encoding digest, just past the V*W bits used for chain indices. 
Force its + # ENCODING_NUM_FINAL_ZEROS lowest bits to zero, lifting the encoding constraint + # from V*W to V*W + ENCODING_NUM_FINAL_ZEROS bits. + match_range( + encoding[V / 2], + range(0, CHAIN_LENGTH**2), + lambda n: assert_encoding_gap_zero(n), + ) debug_assert(V % 2 == 0) wots_public_key = Array((V / 2) * WOTS_PK_PAIR_STRIDE) @@ -179,6 +189,13 @@ def chain_hash_pair( return +@inline +def assert_encoding_gap_zero(n): + if n % (2**ENCODING_NUM_FINAL_ZEROS) != 0: + assert False + return + + @inline def wots_pk_hash(wots_public_key, public_param): N_CHUNKS = V / 2 diff --git a/crates/xmss/src/lib.rs b/crates/xmss/src/lib.rs index 3a41f2164..4d165b231 100644 --- a/crates/xmss/src/lib.rs +++ b/crates/xmss/src/lib.rs @@ -22,6 +22,9 @@ pub const W: usize = 3; pub const CHAIN_LENGTH: usize = 1 << W; pub const NUM_CHAIN_HASHES: usize = 110; pub const TARGET_SUM: usize = V * (CHAIN_LENGTH - 1) - NUM_CHAIN_HASHES; +pub const ENCODING_NUM_FINAL_ZEROS: usize = 2; +const _: () = assert!(V * W + ENCODING_NUM_FINAL_ZEROS <= 8 * 24); // fits in the 8×24-bit encoding digest +const _: () = assert!(ENCODING_NUM_FINAL_ZEROS <= 2 * W); // fits in a single 2W-bit pair chunk (zkDSL assumption) pub const RANDOMNESS_LEN_FE: usize = 6; pub const MESSAGE_LEN_FE: usize = 8; pub const PUBLIC_PARAM_LEN_FE: usize = 4; diff --git a/crates/xmss/src/wots.rs b/crates/xmss/src/wots.rs index bce2e6a4c..acd887244 100644 --- a/crates/xmss/src/wots.rs +++ b/crates/xmss/src/wots.rs @@ -169,10 +169,16 @@ pub fn wots_encode( // ensures uniformity of encoding return None; } - let all_indices: Vec<_> = compressed + let all_bits: Vec = compressed .iter() .flat_map(|kb| to_little_endian_bits(kb.to_usize(), 24)) - .collect::>() + .collect(); + + if all_bits[V * W..][..ENCODING_NUM_FINAL_ZEROS].iter().any(|&b| b) { + return None; + } + + let all_indices: Vec = all_bits .chunks_exact(W) .take(V) .map(|chunk| { diff --git a/crates/xmss/xmss.md b/crates/xmss/xmss.md index 1a03e7d4c..0884627eb 100644 --- a/crates/xmss/xmss.md +++ b/crates/xmss/xmss.md @@ -2,7 +2,7 @@ ## Field -KoalaBear (p = 2^31 - 2^24 + 1). +p = 125.2^25 + 1 ## Hash function @@ -21,6 +21,7 @@ KoalaBear (p = 2^31 - 2^24 + 1). - `v = 42`: number of hash chains - `w = 3`, `chain_length = 2^w = 8` - `target_sum = 184`: a WOTS encoding `(e_0, ..., e_{v-1})` is valid iff each `e_i < chain_length` and `sum(e_i) = target_sum`. The signer grinds `randomness` until the encoding is valid (avoids checksum chains). +- `encoding_num_final_zeros = 2`: the `encoding_num_final_zeros` bits of the encoding digest immediately after the `v · w = 126` chain-encoding bits must be zero (for 128-bits security) ## XMSS @@ -30,7 +31,7 @@ KoalaBear (p = 2^31 - 2^24 + 1). Inputs: public key `(merkle_root, pp)`, message `msg`, slot `s`, signature `(randomness, chain_tips, merkle_proof)`. -1. **Encode**: compute the 8-limb digest `D = H(H(msg | randomness | tweak_encoding(s)) | pp | 0000)`. Reject if any limb of `D` equals `-1` (for a uniform sampling). For each limb, take its canonical representative's low 24 bits in little-endian order, concatenate to get 192 bits, then take the first `v · w = 126` bits split into `v = 42` little-endian chunks of `w = 3` bits → encoding `(e_0, ..., e_{v-1})` with each `e_i ∈ [0, chain_length)`. Reject if `sum(e_i) ≠ target_sum`. +1. **Encode**: compute the 8-limb digest `D = H(H(msg | randomness | tweak_encoding(s)) | pp | 0000)`. Reject if any limb of `D` equals `-1` (for a uniform sampling). 
For each limb, take its canonical representative's low 24 bits in little-endian order, concatenate to get 192 bits. Reject if any of the `encoding_num_final_zeros` bits at positions `v · w .. v · w + encoding_num_final_zeros` is non-zero. Take the first `v · w = 126` bits split into `v = 42` little-endian chunks of `w = 3` bits → encoding `(e_0, ..., e_{v-1})` with each `e_i ∈ [0, chain_length)`. Reject if `sum(e_i) ≠ target_sum`. 2. **Recover WOTS public key**: for each `i`, walk chain `i` from `chain_tips[i]` for `chain_length - 1 - e_i` steps, where each step is `H(tweak_chain(i, step, s) | 00 | previous_value | pp | 0000)` truncated to `n`. 3. **Hash WOTS public key**: T-sponge with replacement over the `v` recovered chain ends, with IV `[tweak_wots_pk(s) | 00 | pp]`, ingesting two chain end digests at a time. Output is the Merkle leaf. 4. **Walk Merkle path**: for `level = 0..log_lifetime`, combine the current node with `merkle_proof[level]` (left/right determined by bit `level` of `s`) via `H(tweak_merkle(level+1, parent_index) | 00 | pp | left | right)` truncated to `n`. @@ -39,7 +40,7 @@ Inputs: public key `(merkle_root, pp)`, message `msg`, slot `s`, signature `(ran ## Security -target = 123,9 ≈ 124 bits of classical security in the ROM, and ≈ 62 bits of quantum security in the QROM, with an analysis inspired by the section 3.1 of [Tight adaptive reprogramming in the QROM](https://arxiv.org/pdf/2010.15103). TODO write the complete proof. +target = 127.84 ≈ 128 bits of classical security in the ROM, and ≈ 64 bits of quantum security in the QROM, with an analysis inspired by the section 3.1 of [Tight adaptive reprogramming in the QROM](https://arxiv.org/pdf/2010.15103). TODO write the complete proof. ## Signature size