diff --git a/crates/fhe/src/bfv/ops/dot_product.rs b/crates/fhe/src/bfv/ops/dot_product.rs index ca4408f1..0cffef70 100644 --- a/crates/fhe/src/bfv/ops/dot_product.rs +++ b/crates/fhe/src/bfv/ops/dot_product.rs @@ -1,5 +1,3 @@ -use std::cmp::min; - use fhe_math::rq::{Ntt, Poly, dot_product as poly_dot_product, traits::TryConvertFrom}; use itertools::{Itertools, izip}; use ndarray::{Array, Array2}; @@ -58,25 +56,20 @@ where I: Iterator + Clone, J: Iterator + Clone, { - let count = min(ct.clone().count(), pt.clone().count()); - if count == 0 { + let inputs = izip!(ct, pt).collect_vec(); + if inputs.is_empty() { return Err(Error::DefaultError( "At least one iterator is empty".to_string(), )); } - let ct_first = ct.clone().next().unwrap(); + let (ct_first, _) = inputs[0]; let ctx = ct_first[0].ctx(); - if izip!(ct.clone(), pt.clone()).any(|(cti, pti)| { + if inputs.iter().any(|(cti, pti)| { cti.par != ct_first.par || pti.par != ct_first.par || cti.len() != ct_first.len() }) { return Err(Error::DefaultError("Mismatched parameters".to_string())); } - if ct.clone().any(|cti| cti.len() != ct_first.len()) { - return Err(Error::DefaultError( - "Mismatched number of parts in the ciphertexts".to_string(), - )); - } let max_acc = ctx .moduli() @@ -85,14 +78,16 @@ where .collect_vec(); let min_of_max = max_acc.iter().min().unwrap(); - if count as u128 > *min_of_max { + if inputs.len() as u128 > *min_of_max { // Too many ciphertexts for the optimized method, instead, we call // `poly_dot_product`. let c = (0..ct_first.len()) .map(|i| { poly_dot_product( - ct.clone().map(|cti| unsafe { cti.get_unchecked(i) }), - pt.clone().map(|pti| &pti.poly_ntt), + inputs + .iter() + .map(|(cti, _)| unsafe { cti.get_unchecked(i) }), + inputs.iter().map(|(_, pti)| &pti.poly_ntt), ) .map_err(Error::MathError) }) @@ -106,7 +101,7 @@ where }) } else { let mut acc = Array::zeros((ct_first.len(), ctx.moduli().len(), ct_first.par.degree())); - for (ciphertext, plaintext) in izip!(ct, pt) { + for (ciphertext, plaintext) in inputs { let pt_coefficients = plaintext.poly_ntt.coefficients(); for (mut acci, ci) in izip!(acc.outer_iter_mut(), ciphertext.iter()) { let ci_coefficients = ci.coefficients();