@@ -54,6 +54,38 @@ absl::Status LazyRnsPolynomial<ModularInt>::CheckFusedMulAddInPlaceParameters(
5454 return absl::OkStatus ();
5555}
5656
57+ template <typename ModularInt>
58+ absl::Status
59+ LazyRnsPolynomial<ModularInt>::CheckFusedMulSumAddInPlaceParameters(
60+ const RnsPolynomial<ModularInt>& a, const RnsPolynomial<ModularInt>& b,
61+ const RnsPolynomial<ModularInt>& c,
62+ absl::Span<const PrimeModulus<ModularInt>* const > moduli) {
63+ if (!a.IsNttForm () || !b.IsNttForm () || !c.IsNttForm ()) {
64+ return absl::InvalidArgumentError (
65+ " Polynomials `a`, `b`, and `c` must be in NTT form." );
66+ }
67+ int num_moduli = moduli.size ();
68+ if (a.NumModuli () != num_moduli || b.NumModuli () != num_moduli ||
69+ c.NumModuli () != num_moduli || coeff_vectors_.size () != num_moduli) {
70+ return absl::InvalidArgumentError (
71+ " Polynomials `a`, `b`, `c`, and this must all be defined wrt `moduli`" );
72+ }
73+ for (int i = 0 ; i < num_moduli; ++i) {
74+ if (moduli[i]->ModParams ()->log_modulus + 3 >= sizeof (Integer) * 8 ) {
75+ return absl::InvalidArgumentError (
76+ " Modulus is too large to perform fused multiply-sum-add." );
77+ }
78+ }
79+ int num_coeffs = coeff_vectors_[0 ].size ();
80+ if (a.NumCoeffs () != num_coeffs || b.NumCoeffs () != num_coeffs ||
81+ c.NumCoeffs () != num_coeffs) {
82+ return absl::InvalidArgumentError (
83+ " Polynomials `a`, `b`, and `c` must have the same number of "
84+ " coefficients as this lazy polynomial." );
85+ }
86+ return absl::OkStatus ();
87+ }
88+
5789template <typename ModularInt>
5890absl::Status LazyRnsPolynomial<ModularInt>::FusedMulAddInPlace(
5991 const RnsPolynomial<ModularInt>& a, const RnsPolynomial<ModularInt>& b,
@@ -118,6 +150,157 @@ absl::Status LazyRnsPolynomial<ModularInt64>::FusedMulAddInPlace(
118150 return absl::OkStatus ();
119151}
120152
153+ template <typename ModularInt>
154+ absl::Status LazyRnsPolynomial<ModularInt>::FusedMulSumAddInPlace(
155+ const RnsPolynomial<ModularInt>& a, const RnsPolynomial<ModularInt>& b,
156+ const RnsPolynomial<ModularInt>& c,
157+ absl::Span<const PrimeModulus<ModularInt>* const > moduli) {
158+ RLWE_RETURN_IF_ERROR (CheckFusedMulSumAddInPlaceParameters (a, b, c, moduli));
159+ if (current_level_ == maximum_level_) {
160+ Refresh (moduli);
161+ }
162+
163+ int num_moduli = moduli.size ();
164+ int num_coeffs = coeff_vectors_[0 ].size ();
165+ const auto & a_coeff_vectors = a.Coeffs ();
166+ const auto & b_coeff_vectors = b.Coeffs ();
167+ const auto & c_coeff_vectors = c.Coeffs ();
168+ for (int i = 0 ; i < num_moduli; ++i) {
169+ for (int j = 0 ; j < num_coeffs; ++j) {
170+ coeff_vectors_[i][j] +=
171+ static_cast <BigInt>(
172+ a_coeff_vectors[i][j].GetMontgomeryRepresentation () +
173+ b_coeff_vectors[i][j].GetMontgomeryRepresentation ()) *
174+ c_coeff_vectors[i][j].GetMontgomeryRepresentation ();
175+ }
176+ }
177+ current_level_++;
178+ return absl::OkStatus ();
179+ }
180+
181+ template <>
182+ absl::Status LazyRnsPolynomial<ModularInt32>::FusedMulSumAddInPlace(
183+ const RnsPolynomial<ModularInt32>& a, const RnsPolynomial<ModularInt32>& b,
184+ const RnsPolynomial<ModularInt32>& c,
185+ absl::Span<const PrimeModulus<ModularInt32>* const > moduli) {
186+ RLWE_RETURN_IF_ERROR (CheckFusedMulSumAddInPlaceParameters (a, b, c, moduli));
187+ if (current_level_ == maximum_level_) {
188+ Refresh (moduli);
189+ }
190+
191+ int num_moduli = moduli.size ();
192+ const auto & a_coeff_vectors = a.Coeffs ();
193+ const auto & b_coeff_vectors = b.Coeffs ();
194+ const auto & c_coeff_vectors = c.Coeffs ();
195+ for (int i = 0 ; i < num_moduli; ++i) {
196+ internal::BatchFusedMulSumAddMontgomeryRep<Uint32>(
197+ a_coeff_vectors[i], b_coeff_vectors[i], c_coeff_vectors[i],
198+ coeff_vectors_[i]);
199+ }
200+ current_level_++;
201+ return absl::OkStatus ();
202+ }
203+
204+ template <>
205+ absl::Status LazyRnsPolynomial<ModularInt64>::FusedMulSumAddInPlace(
206+ const RnsPolynomial<ModularInt64>& a, const RnsPolynomial<ModularInt64>& b,
207+ const RnsPolynomial<ModularInt64>& c,
208+ absl::Span<const PrimeModulus<ModularInt64>* const > moduli) {
209+ RLWE_RETURN_IF_ERROR (CheckFusedMulSumAddInPlaceParameters (a, b, c, moduli));
210+ if (current_level_ == maximum_level_) {
211+ Refresh (moduli);
212+ }
213+
214+ int num_moduli = moduli.size ();
215+ const auto & a_coeff_vectors = a.Coeffs ();
216+ const auto & b_coeff_vectors = b.Coeffs ();
217+ const auto & c_coeff_vectors = c.Coeffs ();
218+ for (int i = 0 ; i < num_moduli; ++i) {
219+ internal::BatchFusedMulSumAddMontgomeryRep<Uint64>(
220+ a_coeff_vectors[i], b_coeff_vectors[i], c_coeff_vectors[i],
221+ coeff_vectors_[i]);
222+ }
223+ current_level_++;
224+ return absl::OkStatus ();
225+ }
226+
227+ template <typename ModularInt>
228+ absl::Status LazyRnsPolynomial<ModularInt>::FusedMulDifferenceAddInPlace(
229+ const RnsPolynomial<ModularInt>& a, const RnsPolynomial<ModularInt>& b,
230+ const RnsPolynomial<ModularInt>& c,
231+ absl::Span<const PrimeModulus<ModularInt>* const > moduli) {
232+ RLWE_RETURN_IF_ERROR (CheckFusedMulSumAddInPlaceParameters (a, b, c, moduli));
233+ if (current_level_ == maximum_level_) {
234+ Refresh (moduli);
235+ }
236+
237+ int num_moduli = moduli.size ();
238+ int num_coeffs = coeff_vectors_[0 ].size ();
239+ const auto & a_coeff_vectors = a.Coeffs ();
240+ const auto & b_coeff_vectors = b.Coeffs ();
241+ const auto & c_coeff_vectors = c.Coeffs ();
242+ for (int i = 0 ; i < num_moduli; ++i) {
243+ const auto qi = moduli[i]->Modulus ();
244+ for (int j = 0 ; j < num_coeffs; ++j) {
245+ coeff_vectors_[i][j] +=
246+ static_cast <BigInt>(
247+ a_coeff_vectors[i][j].GetMontgomeryRepresentation () + qi -
248+ b_coeff_vectors[i][j].GetMontgomeryRepresentation ()) *
249+ c_coeff_vectors[i][j].GetMontgomeryRepresentation ();
250+ }
251+ }
252+ current_level_++;
253+ return absl::OkStatus ();
254+ }
255+
256+ template <>
257+ absl::Status LazyRnsPolynomial<ModularInt32>::FusedMulDifferenceAddInPlace(
258+ const RnsPolynomial<ModularInt32>& a, const RnsPolynomial<ModularInt32>& b,
259+ const RnsPolynomial<ModularInt32>& c,
260+ absl::Span<const PrimeModulus<ModularInt32>* const > moduli) {
261+ RLWE_RETURN_IF_ERROR (CheckFusedMulSumAddInPlaceParameters (a, b, c, moduli));
262+ if (current_level_ == maximum_level_) {
263+ Refresh (moduli);
264+ }
265+
266+ int num_moduli = moduli.size ();
267+ const auto & a_coeff_vectors = a.Coeffs ();
268+ const auto & b_coeff_vectors = b.Coeffs ();
269+ const auto & c_coeff_vectors = c.Coeffs ();
270+ for (int i = 0 ; i < num_moduli; ++i) {
271+ const auto qi = moduli[i]->Modulus ();
272+ internal::BatchFusedMulDifferenceAddMontgomeryRep<Uint32>(
273+ a_coeff_vectors[i], b_coeff_vectors[i], c_coeff_vectors[i], qi,
274+ coeff_vectors_[i]);
275+ }
276+ current_level_++;
277+ return absl::OkStatus ();
278+ }
279+
280+ template <>
281+ absl::Status LazyRnsPolynomial<ModularInt64>::FusedMulDifferenceAddInPlace(
282+ const RnsPolynomial<ModularInt64>& a, const RnsPolynomial<ModularInt64>& b,
283+ const RnsPolynomial<ModularInt64>& c,
284+ absl::Span<const PrimeModulus<ModularInt64>* const > moduli) {
285+ RLWE_RETURN_IF_ERROR (CheckFusedMulSumAddInPlaceParameters (a, b, c, moduli));
286+ if (current_level_ == maximum_level_) {
287+ Refresh (moduli);
288+ }
289+
290+ int num_moduli = moduli.size ();
291+ const auto & a_coeff_vectors = a.Coeffs ();
292+ const auto & b_coeff_vectors = b.Coeffs ();
293+ const auto & c_coeff_vectors = c.Coeffs ();
294+ for (int i = 0 ; i < num_moduli; ++i) {
295+ const auto qi = moduli[i]->Modulus ();
296+ internal::BatchFusedMulDifferenceAddMontgomeryRep<Uint64>(
297+ a_coeff_vectors[i], b_coeff_vectors[i], c_coeff_vectors[i], qi,
298+ coeff_vectors_[i]);
299+ }
300+ current_level_++;
301+ return absl::OkStatus ();
302+ }
303+
// Explicit instantiations for the supported Montgomery integer widths, so
// the template definitions in this translation unit are emitted and linkable.
template class LazyRnsPolynomial<MontgomeryInt<Uint16>>;
template class LazyRnsPolynomial<MontgomeryInt<Uint32>>;
template class LazyRnsPolynomial<MontgomeryInt<Uint64>>;