diff --git a/crates/fhe/src/bfv/keys/key_switching_key.rs b/crates/fhe/src/bfv/keys/key_switching_key.rs index c4ef1ab..a80f647 100644 --- a/crates/fhe/src/bfv/keys/key_switching_key.rs +++ b/crates/fhe/src/bfv/keys/key_switching_key.rs @@ -10,7 +10,7 @@ use fhe_math::{ rq::{Ntt, NttShoup, Poly, PowerBasis}, }; use fhe_traits::{DeserializeWithContext, Serialize}; -use itertools::{Itertools, izip}; +use itertools::izip; use num_bigint::BigUint; use rand::{CryptoRng, Rng, RngCore, SeedableRng}; use rand_chacha::ChaCha8Rng; @@ -305,25 +305,29 @@ impl KeySwitchingKey { .ilog2() as usize; let mut coefficients = p.coefficients().to_slice().unwrap().to_vec(); - let mut c2i = vec![]; let mask = (1u64 << self.log_base) - 1; - (0..log_modulus.div_ceil(self.log_base)).for_each(|_| { - c2i.push(coefficients.iter().map(|c| c & mask).collect_vec()); - coefficients.iter_mut().for_each(|c| *c >>= self.log_base); - }); let mut c0 = Poly::::zero(&self.ctx_ksk); let mut c1 = Poly::::zero(&self.ctx_ksk); - for (c2_i_coefficients, c0_i, c1_i) in izip!(c2i.iter(), self.c0.iter(), self.c1.iter()) { + + let decomposition_count = log_modulus.div_ceil(self.log_base); + let mut c2_i_buffer = vec![0u64; coefficients.len()]; + + for (c0_i, c1_i) in izip!(self.c0.iter(), self.c1.iter()).take(decomposition_count) { + for (dest, src) in izip!(c2_i_buffer.iter_mut(), coefficients.iter()) { + *dest = src & mask; + } let mut c2_i = unsafe { Poly::::create_constant_ntt_polynomial_with_lazy_coefficients_and_variable_time( - c2_i_coefficients.as_slice(), + &c2_i_buffer, &self.ctx_ksk, ) }; c0 += &(&c2_i * c0_i); c2_i *= c1_i; c1 += &c2_i; + + coefficients.iter_mut().for_each(|c| *c >>= self.log_base); } Ok((c0, c1)) } @@ -576,4 +580,5 @@ mod tests { } Ok(()) } + }