From a6d09b6fccfefa82e21a86cf4462269d3b43cf6e Mon Sep 17 00:00:00 2001
From: "google-labs-jules[bot]" <161369871+google-labs-jules[bot]@users.noreply.github.com>
Date: Wed, 4 Feb 2026 01:38:04 +0000
Subject: [PATCH] Optimize BFV dot_product_scalar by removing redundant iterator traversals

Previously, `dot_product_scalar` cloned the input iterators multiple times (for counting, parameter validation, and part length validation) before finally iterating for computation. This added unnecessary overhead.

This change collects the zipped input iterators into a `Vec<(&Ciphertext, &Plaintext)>` once. Validation and computation then iterate over this vector. This avoids redundant passes and iterator cloning.

Performance benchmarks (`bfv_optimized_ops`) show mixed results: allocation overhead dominates for very fast operations, but there are significant improvements (up to 36%) for some medium-sized workloads (e.g., size=1000, degree=2048). The change also improves code clarity/safety by validating on the effective input set.

Co-authored-by: tlepoint <1345502+tlepoint@users.noreply.github.com> --- crates/fhe/src/bfv/ops/dot_product.rs | 25 ++++++++++--------------- 1 file changed, 10 insertions(+), 15 deletions(-) diff --git a/crates/fhe/src/bfv/ops/dot_product.rs b/crates/fhe/src/bfv/ops/dot_product.rs index ca4408f1..0cffef70 100644 --- a/crates/fhe/src/bfv/ops/dot_product.rs +++ b/crates/fhe/src/bfv/ops/dot_product.rs @@ -1,5 +1,3 @@ -use std::cmp::min; - use fhe_math::rq::{Ntt, Poly, dot_product as poly_dot_product, traits::TryConvertFrom}; use itertools::{Itertools, izip}; use ndarray::{Array, Array2}; @@ -58,25 +56,20 @@ where I: Iterator + Clone, J: Iterator + Clone, { - let count = min(ct.clone().count(), pt.clone().count()); - if count == 0 { + let inputs = izip!(ct, pt).collect_vec(); + if inputs.is_empty() { return Err(Error::DefaultError( "At least one iterator is empty".to_string(), )); } - let ct_first = ct.clone().next().unwrap(); + let (ct_first, _) = inputs[0]; let ctx = ct_first[0].ctx(); - if izip!(ct.clone(), pt.clone()).any(|(cti, pti)| { + if inputs.iter().any(|(cti, pti)| { cti.par != ct_first.par || pti.par != ct_first.par || cti.len() != ct_first.len() }) { return Err(Error::DefaultError("Mismatched parameters".to_string())); } - if ct.clone().any(|cti| cti.len() != ct_first.len()) { - return Err(Error::DefaultError( - "Mismatched number of parts in the ciphertexts".to_string(), - )); - } let max_acc = ctx .moduli() @@ -85,14 +78,16 @@ where .collect_vec(); let min_of_max = max_acc.iter().min().unwrap(); - if count as u128 > *min_of_max { + if inputs.len() as u128 > *min_of_max { // Too many ciphertexts for the optimized method, instead, we call // `poly_dot_product`. 
let c = (0..ct_first.len()) .map(|i| { poly_dot_product( - ct.clone().map(|cti| unsafe { cti.get_unchecked(i) }), - pt.clone().map(|pti| &pti.poly_ntt), + inputs + .iter() + .map(|(cti, _)| unsafe { cti.get_unchecked(i) }), + inputs.iter().map(|(_, pti)| &pti.poly_ntt), ) .map_err(Error::MathError) }) @@ -106,7 +101,7 @@ where }) } else { let mut acc = Array::zeros((ct_first.len(), ctx.moduli().len(), ct_first.par.degree())); - for (ciphertext, plaintext) in izip!(ct, pt) { + for (ciphertext, plaintext) in inputs { let pt_coefficients = plaintext.poly_ntt.coefficients(); for (mut acci, ci) in izip!(acc.outer_iter_mut(), ciphertext.iter()) { let ci_coefficients = ci.coefficients();