From d6c7d30a165704d30cc20cfe0fc7588d6687e64b Mon Sep 17 00:00:00 2001 From: DaniPopes <57450786+DaniPopes@users.noreply.github.com> Date: Sat, 7 Feb 2026 05:47:19 +0100 Subject: [PATCH] perf: clean up mul_redc inner loop --- src/algorithms/mul_redc.rs | 46 +++++++++++++++++++------------------- 1 file changed, 23 insertions(+), 23 deletions(-) diff --git a/src/algorithms/mul_redc.rs b/src/algorithms/mul_redc.rs index 05ba4a67..559ce1c6 100644 --- a/src/algorithms/mul_redc.rs +++ b/src/algorithms/mul_redc.rs @@ -21,36 +21,35 @@ pub fn mul_redc(a: [u64; N], b: [u64; N], modulus: [u64; N], inv // See let mut result = [0; N]; let mut carry = false; + let has_top_carry = modulus[N - 1] >= 0x7fff_ffff_ffff_ffff; for b in b { - let mut m = 0; - let mut carry_1 = 0; - let mut carry_2 = 0; - for i in 0..N { - // Add limb product + let mut carry_1; + let mut carry_2; + + // Add limb product and compute reduction factor. + let (value, next_carry) = carrying_mul_add(a[0], b, result[0], 0); + carry_1 = next_carry; + let m = value.wrapping_mul(inv); + let (value, next_carry) = carrying_mul_add(modulus[0], m, value, 0); + carry_2 = next_carry; + debug_assert_eq!(value, 0); + + for i in 1..N { + // Add limb product. let (value, next_carry) = carrying_mul_add(a[i], b, result[i], carry_1); carry_1 = next_carry; - if i == 0 { - // Compute reduction factor - m = value.wrapping_mul(inv); - } - - // Add m * modulus to acc to clear next_result[0] + // Add m * modulus to acc and shift result. let (value, next_carry) = carrying_mul_add(modulus[i], m, value, carry_2); carry_2 = next_carry; - // Shift result - if i > 0 { - result[i - 1] = value; - } else { - debug_assert_eq!(value, 0); - } + result[i - 1] = value; } - // Add carries + // Add carries. let (value, next_carry) = carrying_add(carry_1, carry_2, carry); result[N - 1] = value; - if modulus[N - 1] >= 0x7fff_ffff_ffff_ffff { + if has_top_carry { carry = next_carry; } else { debug_assert!(!next_carry); @@ -74,8 +73,9 @@ pub fn square_redc(a: [u64; N], modulus: [u64; N], inv: u64) -> let mut result = [0; N]; let mut carry_outer = 0; + let has_top_carry = modulus[N - 1] >= 0x3fff_ffff_ffff_ffff; for i in 0..N { - // Add limb product + // Add limb product. let (value, mut carry_lo) = carrying_mul_add(a[i], a[i], result[i], 0); let mut carry_hi = false; result[i] = value; @@ -87,7 +87,7 @@ pub fn square_redc(a: [u64; N], modulus: [u64; N], inv: u64) -> carry_hi = next_carry_hi; } - // Add m times modulus to result and shift one limb + // Add m times modulus to result and shift one limb. let m = result[0].wrapping_mul(inv); let (value, mut carry) = carrying_mul_add(m, modulus[0], result[0], 0); debug_assert_eq!(value, 0); @@ -97,8 +97,8 @@ pub fn square_redc(a: [u64; N], modulus: [u64; N], inv: u64) -> carry = next_carry; } - // Add carries - if modulus[N - 1] >= 0x3fff_ffff_ffff_ffff { + // Add carries. + if has_top_carry { let wide = (carry_outer as u128) .wrapping_add(carry_lo as u128) .wrapping_add((carry_hi as u128) << 64)