Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 23 additions & 23 deletions src/algorithms/mul_redc.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,36 +21,35 @@ pub fn mul_redc<const N: usize>(a: [u64; N], b: [u64; N], modulus: [u64; N], inv
// See <https://tches.iacr.org/index.php/TCHES/article/view/10972>
let mut result = [0; N];
let mut carry = false;
let has_top_carry = modulus[N - 1] >= 0x7fff_ffff_ffff_ffff;
for b in b {
let mut m = 0;
let mut carry_1 = 0;
let mut carry_2 = 0;
for i in 0..N {
// Add limb product
let mut carry_1;
let mut carry_2;

// Add limb product and compute reduction factor.
let (value, next_carry) = carrying_mul_add(a[0], b, result[0], 0);
carry_1 = next_carry;
let m = value.wrapping_mul(inv);
let (value, next_carry) = carrying_mul_add(modulus[0], m, value, 0);
carry_2 = next_carry;
debug_assert_eq!(value, 0);

for i in 1..N {
// Add limb product.
let (value, next_carry) = carrying_mul_add(a[i], b, result[i], carry_1);
carry_1 = next_carry;

if i == 0 {
// Compute reduction factor
m = value.wrapping_mul(inv);
}

// Add m * modulus to acc to clear next_result[0]
// Add m * modulus to acc and shift result.
let (value, next_carry) = carrying_mul_add(modulus[i], m, value, carry_2);
carry_2 = next_carry;

// Shift result
if i > 0 {
result[i - 1] = value;
} else {
debug_assert_eq!(value, 0);
}
result[i - 1] = value;
}

// Add carries
// Add carries.
let (value, next_carry) = carrying_add(carry_1, carry_2, carry);
result[N - 1] = value;
if modulus[N - 1] >= 0x7fff_ffff_ffff_ffff {
if has_top_carry {
carry = next_carry;
} else {
debug_assert!(!next_carry);
Expand All @@ -74,8 +73,9 @@ pub fn square_redc<const N: usize>(a: [u64; N], modulus: [u64; N], inv: u64) ->

let mut result = [0; N];
let mut carry_outer = 0;
let has_top_carry = modulus[N - 1] >= 0x3fff_ffff_ffff_ffff;
for i in 0..N {
// Add limb product
// Add limb product.
let (value, mut carry_lo) = carrying_mul_add(a[i], a[i], result[i], 0);
let mut carry_hi = false;
result[i] = value;
Expand All @@ -87,7 +87,7 @@ pub fn square_redc<const N: usize>(a: [u64; N], modulus: [u64; N], inv: u64) ->
carry_hi = next_carry_hi;
}

// Add m times modulus to result and shift one limb
// Add m times modulus to result and shift one limb.
let m = result[0].wrapping_mul(inv);
let (value, mut carry) = carrying_mul_add(m, modulus[0], result[0], 0);
debug_assert_eq!(value, 0);
Expand All @@ -97,8 +97,8 @@ pub fn square_redc<const N: usize>(a: [u64; N], modulus: [u64; N], inv: u64) ->
carry = next_carry;
}

// Add carries
if modulus[N - 1] >= 0x3fff_ffff_ffff_ffff {
// Add carries.
if has_top_carry {
let wide = (carry_outer as u128)
.wrapping_add(carry_lo as u128)
.wrapping_add((carry_hi as u128) << 64)
Expand Down
Loading