Skip to content

Commit 51133d2

Browse files
b5li and copybara-github
authored and committed
Compute (a + b) * c lazily with SIMD optimizations
PiperOrigin-RevId: 910478896
1 parent 383e1aa commit 51133d2

6 files changed

Lines changed: 891 additions & 50 deletions

File tree

shell_encryption/rns/lazy_rns_polynomial.cc

Lines changed: 183 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,38 @@ absl::Status LazyRnsPolynomial<ModularInt>::CheckFusedMulAddInPlaceParameters(
5454
return absl::OkStatus();
5555
}
5656

57+
template <typename ModularInt>
58+
absl::Status
59+
LazyRnsPolynomial<ModularInt>::CheckFusedMulSumAddInPlaceParameters(
60+
const RnsPolynomial<ModularInt>& a, const RnsPolynomial<ModularInt>& b,
61+
const RnsPolynomial<ModularInt>& c,
62+
absl::Span<const PrimeModulus<ModularInt>* const> moduli) {
63+
if (!a.IsNttForm() || !b.IsNttForm() || !c.IsNttForm()) {
64+
return absl::InvalidArgumentError(
65+
"Polynomials `a`, `b`, and `c` must be in NTT form.");
66+
}
67+
int num_moduli = moduli.size();
68+
if (a.NumModuli() != num_moduli || b.NumModuli() != num_moduli ||
69+
c.NumModuli() != num_moduli || coeff_vectors_.size() != num_moduli) {
70+
return absl::InvalidArgumentError(
71+
"Polynomials `a`, `b`, `c`, and this must all be defined wrt `moduli`");
72+
}
73+
for (int i = 0; i < num_moduli; ++i) {
74+
if (moduli[i]->ModParams()->log_modulus + 3 >= sizeof(Integer) * 8) {
75+
return absl::InvalidArgumentError(
76+
"Modulus is too large to perform fused multiply-sum-add.");
77+
}
78+
}
79+
int num_coeffs = coeff_vectors_[0].size();
80+
if (a.NumCoeffs() != num_coeffs || b.NumCoeffs() != num_coeffs ||
81+
c.NumCoeffs() != num_coeffs) {
82+
return absl::InvalidArgumentError(
83+
"Polynomials `a`, `b`, and `c` must have the same number of "
84+
"coefficients as this lazy polynomial.");
85+
}
86+
return absl::OkStatus();
87+
}
88+
5789
template <typename ModularInt>
5890
absl::Status LazyRnsPolynomial<ModularInt>::FusedMulAddInPlace(
5991
const RnsPolynomial<ModularInt>& a, const RnsPolynomial<ModularInt>& b,
@@ -118,6 +150,157 @@ absl::Status LazyRnsPolynomial<ModularInt64>::FusedMulAddInPlace(
118150
return absl::OkStatus();
119151
}
120152

153+
template <typename ModularInt>
154+
absl::Status LazyRnsPolynomial<ModularInt>::FusedMulSumAddInPlace(
155+
const RnsPolynomial<ModularInt>& a, const RnsPolynomial<ModularInt>& b,
156+
const RnsPolynomial<ModularInt>& c,
157+
absl::Span<const PrimeModulus<ModularInt>* const> moduli) {
158+
RLWE_RETURN_IF_ERROR(CheckFusedMulSumAddInPlaceParameters(a, b, c, moduli));
159+
if (current_level_ == maximum_level_) {
160+
Refresh(moduli);
161+
}
162+
163+
int num_moduli = moduli.size();
164+
int num_coeffs = coeff_vectors_[0].size();
165+
const auto& a_coeff_vectors = a.Coeffs();
166+
const auto& b_coeff_vectors = b.Coeffs();
167+
const auto& c_coeff_vectors = c.Coeffs();
168+
for (int i = 0; i < num_moduli; ++i) {
169+
for (int j = 0; j < num_coeffs; ++j) {
170+
coeff_vectors_[i][j] +=
171+
static_cast<BigInt>(
172+
a_coeff_vectors[i][j].GetMontgomeryRepresentation() +
173+
b_coeff_vectors[i][j].GetMontgomeryRepresentation()) *
174+
c_coeff_vectors[i][j].GetMontgomeryRepresentation();
175+
}
176+
}
177+
current_level_++;
178+
return absl::OkStatus();
179+
}
180+
181+
template <>
182+
absl::Status LazyRnsPolynomial<ModularInt32>::FusedMulSumAddInPlace(
183+
const RnsPolynomial<ModularInt32>& a, const RnsPolynomial<ModularInt32>& b,
184+
const RnsPolynomial<ModularInt32>& c,
185+
absl::Span<const PrimeModulus<ModularInt32>* const> moduli) {
186+
RLWE_RETURN_IF_ERROR(CheckFusedMulSumAddInPlaceParameters(a, b, c, moduli));
187+
if (current_level_ == maximum_level_) {
188+
Refresh(moduli);
189+
}
190+
191+
int num_moduli = moduli.size();
192+
const auto& a_coeff_vectors = a.Coeffs();
193+
const auto& b_coeff_vectors = b.Coeffs();
194+
const auto& c_coeff_vectors = c.Coeffs();
195+
for (int i = 0; i < num_moduli; ++i) {
196+
internal::BatchFusedMulSumAddMontgomeryRep<Uint32>(
197+
a_coeff_vectors[i], b_coeff_vectors[i], c_coeff_vectors[i],
198+
coeff_vectors_[i]);
199+
}
200+
current_level_++;
201+
return absl::OkStatus();
202+
}
203+
204+
template <>
205+
absl::Status LazyRnsPolynomial<ModularInt64>::FusedMulSumAddInPlace(
206+
const RnsPolynomial<ModularInt64>& a, const RnsPolynomial<ModularInt64>& b,
207+
const RnsPolynomial<ModularInt64>& c,
208+
absl::Span<const PrimeModulus<ModularInt64>* const> moduli) {
209+
RLWE_RETURN_IF_ERROR(CheckFusedMulSumAddInPlaceParameters(a, b, c, moduli));
210+
if (current_level_ == maximum_level_) {
211+
Refresh(moduli);
212+
}
213+
214+
int num_moduli = moduli.size();
215+
const auto& a_coeff_vectors = a.Coeffs();
216+
const auto& b_coeff_vectors = b.Coeffs();
217+
const auto& c_coeff_vectors = c.Coeffs();
218+
for (int i = 0; i < num_moduli; ++i) {
219+
internal::BatchFusedMulSumAddMontgomeryRep<Uint64>(
220+
a_coeff_vectors[i], b_coeff_vectors[i], c_coeff_vectors[i],
221+
coeff_vectors_[i]);
222+
}
223+
current_level_++;
224+
return absl::OkStatus();
225+
}
226+
227+
template <typename ModularInt>
228+
absl::Status LazyRnsPolynomial<ModularInt>::FusedMulDifferenceAddInPlace(
229+
const RnsPolynomial<ModularInt>& a, const RnsPolynomial<ModularInt>& b,
230+
const RnsPolynomial<ModularInt>& c,
231+
absl::Span<const PrimeModulus<ModularInt>* const> moduli) {
232+
RLWE_RETURN_IF_ERROR(CheckFusedMulSumAddInPlaceParameters(a, b, c, moduli));
233+
if (current_level_ == maximum_level_) {
234+
Refresh(moduli);
235+
}
236+
237+
int num_moduli = moduli.size();
238+
int num_coeffs = coeff_vectors_[0].size();
239+
const auto& a_coeff_vectors = a.Coeffs();
240+
const auto& b_coeff_vectors = b.Coeffs();
241+
const auto& c_coeff_vectors = c.Coeffs();
242+
for (int i = 0; i < num_moduli; ++i) {
243+
const auto qi = moduli[i]->Modulus();
244+
for (int j = 0; j < num_coeffs; ++j) {
245+
coeff_vectors_[i][j] +=
246+
static_cast<BigInt>(
247+
a_coeff_vectors[i][j].GetMontgomeryRepresentation() + qi -
248+
b_coeff_vectors[i][j].GetMontgomeryRepresentation()) *
249+
c_coeff_vectors[i][j].GetMontgomeryRepresentation();
250+
}
251+
}
252+
current_level_++;
253+
return absl::OkStatus();
254+
}
255+
256+
template <>
257+
absl::Status LazyRnsPolynomial<ModularInt32>::FusedMulDifferenceAddInPlace(
258+
const RnsPolynomial<ModularInt32>& a, const RnsPolynomial<ModularInt32>& b,
259+
const RnsPolynomial<ModularInt32>& c,
260+
absl::Span<const PrimeModulus<ModularInt32>* const> moduli) {
261+
RLWE_RETURN_IF_ERROR(CheckFusedMulSumAddInPlaceParameters(a, b, c, moduli));
262+
if (current_level_ == maximum_level_) {
263+
Refresh(moduli);
264+
}
265+
266+
int num_moduli = moduli.size();
267+
const auto& a_coeff_vectors = a.Coeffs();
268+
const auto& b_coeff_vectors = b.Coeffs();
269+
const auto& c_coeff_vectors = c.Coeffs();
270+
for (int i = 0; i < num_moduli; ++i) {
271+
const auto qi = moduli[i]->Modulus();
272+
internal::BatchFusedMulDifferenceAddMontgomeryRep<Uint32>(
273+
a_coeff_vectors[i], b_coeff_vectors[i], c_coeff_vectors[i], qi,
274+
coeff_vectors_[i]);
275+
}
276+
current_level_++;
277+
return absl::OkStatus();
278+
}
279+
280+
template <>
281+
absl::Status LazyRnsPolynomial<ModularInt64>::FusedMulDifferenceAddInPlace(
282+
const RnsPolynomial<ModularInt64>& a, const RnsPolynomial<ModularInt64>& b,
283+
const RnsPolynomial<ModularInt64>& c,
284+
absl::Span<const PrimeModulus<ModularInt64>* const> moduli) {
285+
RLWE_RETURN_IF_ERROR(CheckFusedMulSumAddInPlaceParameters(a, b, c, moduli));
286+
if (current_level_ == maximum_level_) {
287+
Refresh(moduli);
288+
}
289+
290+
int num_moduli = moduli.size();
291+
const auto& a_coeff_vectors = a.Coeffs();
292+
const auto& b_coeff_vectors = b.Coeffs();
293+
const auto& c_coeff_vectors = c.Coeffs();
294+
for (int i = 0; i < num_moduli; ++i) {
295+
const auto qi = moduli[i]->Modulus();
296+
internal::BatchFusedMulDifferenceAddMontgomeryRep<Uint64>(
297+
a_coeff_vectors[i], b_coeff_vectors[i], c_coeff_vectors[i], qi,
298+
coeff_vectors_[i]);
299+
}
300+
current_level_++;
301+
return absl::OkStatus();
302+
}
303+
121304
template class LazyRnsPolynomial<MontgomeryInt<Uint16>>;
122305
template class LazyRnsPolynomial<MontgomeryInt<Uint32>>;
123306
template class LazyRnsPolynomial<MontgomeryInt<Uint64>>;

shell_encryption/rns/lazy_rns_polynomial.h

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -133,6 +133,18 @@ class LazyRnsPolynomial {
133133
const RnsPolynomial<ModularInt>& a, const RnsPolynomial<ModularInt>& b,
134134
absl::Span<const PrimeModulus<ModularInt>* const> moduli);
135135

136+
// Adds (a + b) * c (mod moduli) to this polynomial.
137+
absl::Status FusedMulSumAddInPlace(
138+
const RnsPolynomial<ModularInt>& a, const RnsPolynomial<ModularInt>& b,
139+
const RnsPolynomial<ModularInt>& c,
140+
absl::Span<const PrimeModulus<ModularInt>* const> moduli);
141+
142+
// Adds (a - b) * c (mod moduli) to this polynomial.
143+
absl::Status FusedMulDifferenceAddInPlace(
144+
const RnsPolynomial<ModularInt>& a, const RnsPolynomial<ModularInt>& b,
145+
const RnsPolynomial<ModularInt>& c,
146+
absl::Span<const PrimeModulus<ModularInt>* const> moduli);
147+
136148
private:
137149
explicit LazyRnsPolynomial(
138150
std::vector<hwy::AlignedVector<BigInt>> coeff_vectors,
@@ -183,6 +195,11 @@ class LazyRnsPolynomial {
183195
const RnsPolynomial<ModularInt>& a, const RnsPolynomial<ModularInt>& b,
184196
absl::Span<const PrimeModulus<ModularInt>* const> moduli);
185197

198+
absl::Status CheckFusedMulSumAddInPlaceParameters(
199+
const RnsPolynomial<ModularInt>& a, const RnsPolynomial<ModularInt>& b,
200+
const RnsPolynomial<ModularInt>& c,
201+
absl::Span<const PrimeModulus<ModularInt>* const> moduli);
202+
186203
// Coefficients of the polynomial modulo prime moduli.
187204
// Each vector corresponds to a prime modulus in moduli_ in the same order.
188205
std::vector<hwy::AlignedVector<BigInt>> coeff_vectors_;

0 commit comments

Comments (0)