From c47dfd3c03bbada6eddb977d39b1a8038eaf8622 Mon Sep 17 00:00:00 2001 From: Colin Roberts Date: Mon, 23 Dec 2024 11:09:28 -0700 Subject: [PATCH 01/10] WIP: initial start --- src/lib.rs | 1 + src/lwe/mod.rs | 167 +++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 168 insertions(+) create mode 100644 src/lwe/mod.rs diff --git a/src/lib.rs b/src/lib.rs index 7acc5a38..61ef6dc2 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -34,6 +34,7 @@ pub mod encryption; pub mod hashes; pub mod hmac; pub mod kzg; +pub mod lwe; pub mod multi_var_poly; pub mod polynomial; pub mod sumcheck; diff --git a/src/lwe/mod.rs b/src/lwe/mod.rs new file mode 100644 index 00000000..58ac9573 --- /dev/null +++ b/src/lwe/mod.rs @@ -0,0 +1,167 @@ +use rand::Rng; + +use crate::algebra::field::{extension::GaloisField, prime::PrimeField, Field, FiniteField}; + +// Small parameters for testing +// const N: usize = 4; // Dimension (use larger like 256 in practice) +// const Q: u32 = 97; // Modulus (prime, use larger in practice) +const K: usize = 1; // Parameter for binomial sampling + +#[derive(Debug, Clone)] +pub struct PublicKey { + a: [[PrimeField; N]; N], // Matrix A in Z_q^{n×n} + b: [PrimeField; N], // Vector b = As + e +} + +#[derive(Debug, Clone)] +pub struct SecretKey { + s: [PrimeField; N], // Secret vector in Z_q^n +} + +#[derive(Debug, Clone)] +pub struct Ciphertext { + u: Vec>, // First component + v: PrimeField, // Second component +} + +fn sample_binomial() -> PrimeField { + let mut rng = rand::thread_rng(); + let mut sum = PrimeField::::ZERO; + + for _ in 0..K { + // Add 1 or 0 for first term + if rng.gen::() { + sum = sum + PrimeField::::ONE; + } + // Subtract 1 or 0 for second term + if rng.gen::() { + sum = sum - PrimeField::::ONE; + } + } + sum +} + +pub fn keygen() -> (PublicKey, SecretKey) { + let mut rng = rand::thread_rng(); + + // Generate random matrix A + let mut a = [[PrimeField::::ZERO; N]; N]; + for i in 0..N { + for j in 0..N { + a[i][j] = PrimeField::from(rng.gen_range(0..Q)); + } + } + + // Sample secret vector s with small coefficients + let mut s = [PrimeField::::ZERO; N]; + for i in 0..N { + s[i] = sample_binomial(); + } + + // Generate error vector e + let mut e = [PrimeField::::ZERO; N]; + for i in 0..N { + e[i] = sample_binomial(); + } + + // Compute b = As + e + let mut b = [PrimeField::::ZERO; N]; + for i in 0..N { + let mut sum = PrimeField::::ZERO; + for j in 0..N { + sum += a[i][j] * s[j]; + } + sum += e[i]; + b[i] = sum; + } + + (PublicKey { a, b }, SecretKey { s }) +} + +// TODO: Should impl the `SymmetricEncryption` trait +pub fn encrypt(pk: &PublicKey, message: bool) -> Ciphertext { + let mut rng = rand::thread_rng(); + + // Sample random vector r + let mut r = vec![0; N]; + for i in 0..N { + r[i] = rng.gen_range(0..2) as u32; // Binary vector + } + + // Compute u = A^T r + let mut u = vec![0; N]; + for i in 0..N { + let mut sum = 0i32; + for j in 0..N { + sum += (pk.a[j][i] * r[j]) as i32; + } + u[i] = mod_q(sum); + } + + // Compute v = b^T r + ⌊q/2⌋m + let mut v = 0i32; + for i in 0..N { + v += (pk.b[i] * r[i]) as i32; + } + if message { + v += (Q / 2) as i32; + } + + Ciphertext { u, v: mod_q(v) } +} + +pub fn decrypt(sk: &SecretKey, ct: &Ciphertext) -> bool { + // Compute v - s^T u + let mut sum = ct.v as i32; + for i in 0..N { + sum -= (sk.s[i] * ct.u[i]) as i32; + } + sum = mod_q(sum) as i32; + + // Check if closer to 0 or ⌊q/2⌋ + let q_half = (Q / 2) as i32; + let mut dist_to_zero = sum.min(Q as i32 - sum); + let mut dist_to_q_half = ((sum - q_half).abs()).min((sum - q_half + Q as i32).abs()); + + dist_to_q_half < dist_to_zero +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_encryption_decryption() { + // Test multiple random instances + for _ in 0..100 { + let (pk, sk) = keygen(); + + // Test encryption and decryption of 0 + let ct_zero = encrypt(&pk, false); + assert_eq!(decrypt(&sk, &ct_zero), false); + + // Test encryption and decryption of 1 + let ct_one = encrypt(&pk, true); + assert_eq!(decrypt(&sk, &ct_one), true); + } + } + + #[test] + fn test_with_fixed_randomness() { + // This test would need a deterministic RNG to be meaningful + // In practice, you'd want to test with known test vectors + } + + #[test] + fn test_binomial_distribution() { + let mut counts = vec![0; 3]; // -1, 0, 1 for k=1 + for _ in 0..1000 { + let sample = sample_binomial(); + counts[(sample + 1) as usize] += 1; + } + // Check rough distribution - should be approximately 1/4, 1/2, 1/4 + for count in counts.iter() { + assert!(*count > 150); // Should be roughly 250 but allow some variance + } + } +} From e1122e0fd1c838e66f44ac08b8a2db4ead94cc52 Mon Sep 17 00:00:00 2001 From: Colin Roberts Date: Mon, 23 Dec 2024 12:19:11 -0700 Subject: [PATCH 02/10] feat: basic LWE --- src/algebra/field/prime/mod.rs | 2 +- src/algebra/ring.rs | 30 +++++ src/encryption/asymmetric/lwe/mod.rs | 186 +++++++++++++++++++++++++++ src/encryption/asymmetric/mod.rs | 14 ++ src/lib.rs | 1 - src/lwe/mod.rs | 167 ------------------------ 6 files changed, 231 insertions(+), 169 deletions(-) create mode 100644 src/algebra/ring.rs create mode 100644 src/encryption/asymmetric/lwe/mod.rs delete mode 100644 src/lwe/mod.rs diff --git a/src/algebra/field/prime/mod.rs b/src/algebra/field/prime/mod.rs index 8794f80f..131e76fa 100644 --- a/src/algebra/field/prime/mod.rs +++ b/src/algebra/field/prime/mod.rs @@ -36,7 +36,7 @@ pub type AESField = PrimeField<2>; /// The [`PrimeField`] struct represents elements of a field with prime order. The field is defined /// by a prime number `P`, and the elements are integers modulo `P`. -#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash, Default, PartialOrd)] +#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash, Default, PartialOrd, Ord)] pub struct PrimeField { pub(crate) value: usize, } diff --git a/src/algebra/ring.rs b/src/algebra/ring.rs new file mode 100644 index 00000000..e2760ad8 --- /dev/null +++ b/src/algebra/ring.rs @@ -0,0 +1,30 @@ +use std::{ + cmp::{Eq, PartialEq}, + ops::{Add, AddAssign, Mul, MulAssign, Neg, Sub, SubAssign}, +}; + +use super::Finite; + +pub trait Ring: + std::fmt::Debug + + Default + + Sized + + Copy + + Clone + + PartialEq + + Eq + + Add + + AddAssign + + Sub + + SubAssign + + Mul + + MulAssign + + Neg + + 'static { + const ZERO: Self; + const ONE: Self; +} + +// pub trait FiniteRing: Finite + Ring { +// const PRIMITIVE_ELEMENT: Self; +// } diff --git a/src/encryption/asymmetric/lwe/mod.rs b/src/encryption/asymmetric/lwe/mod.rs new file mode 100644 index 00000000..fc783c98 --- /dev/null +++ b/src/encryption/asymmetric/lwe/mod.rs @@ -0,0 +1,186 @@ +//! This module implements the Learning With Errors (LWE) public-key encryption scheme. +//! +//! LWE is a lattice-based cryptosystem whose security is based on the hardness of the +//! Learning With Errors problem, first introduced by Regev in 2005. The scheme encrypts +//! single bits and provides security against quantum computers. +//! +//! The implementation uses three type parameters: +//! - Q: The modulus (must be prime) +//! - N: The dimension of the lattice +//! - K: The parameter for the binomial distribution used for error sampling + +use rand::Rng; + +use super::AsymmetricEncryption; +use crate::algebra::field::{prime::PrimeField, Field}; + +/// An implementation of Learning With Errors (LWE) encryption +pub struct LWE { + private_key: PrivateKey, + public_key: PublicKey, +} + +/// The public key consisting of a random matrix A and vector b = As + e +#[derive(Debug, Clone)] +pub struct PublicKey { + /// Random matrix in Z_q^{n×n} + a: [[PrimeField; N]; N], + /// Vector b = As + e where s is secret and e is error + b: [PrimeField; N], +} + +/// The private key consisting of the secret vector s +#[derive(Debug, Clone)] +pub struct PrivateKey { + /// Secret vector with small coefficients + s: [PrimeField; N], +} + +/// A ciphertext consisting of vector u and scalar v +#[derive(Debug, Clone)] +pub struct Ciphertext { + /// First component u = A^T r + u: [PrimeField; N], + /// Second component v = b^T r + ⌊q/2⌋m + v: PrimeField, +} + +/// Sample from a centered binomial distribution with parameter K +/// +/// Returns a value in the range [-K, K] following a discrete approximation +/// of a Gaussian distribution. +pub fn sample_binomial() -> PrimeField { + let mut rng = rand::thread_rng(); + let mut sum = PrimeField::::ZERO; + + for _ in 0..K { + if rng.gen::() { + sum += PrimeField::::ONE; + } + if rng.gen::() { + sum -= PrimeField::::ONE; + } + } + sum +} + +impl LWE { + pub fn new() -> LWE { + let mut rng = rand::thread_rng(); + + // Generate random matrix A + let mut a = [[PrimeField::::ZERO; N]; N]; + for i in 0..N { + for j in 0..N { + a[i][j] = PrimeField::from(rng.gen_range(0..Q)); + } + } + + // Sample secret vector s with small coefficients + let mut s = [PrimeField::::ZERO; N]; + for i in 0..N { + s[i] = sample_binomial::(); + } + + // Generate error vector e + let mut e = [PrimeField::::ZERO; N]; + for i in 0..N { + e[i] = sample_binomial::(); + } + + // Compute b = As + e + let mut b = [PrimeField::::ZERO; N]; + for i in 0..N { + let mut sum = PrimeField::::ZERO; + for j in 0..N { + sum += a[i][j] * s[j]; + } + sum += e[i]; + b[i] = sum; + } + + Self { public_key: PublicKey { a, b }, private_key: PrivateKey { s } } + } +} + +impl AsymmetricEncryption for LWE { + type Ciphertext = Ciphertext; + type Plaintext = bool; + type PrivateKey = PrivateKey; + type PublicKey = PublicKey; + + fn encrypt(&self, plaintext: &Self::Plaintext) -> Self::Ciphertext { + let mut rng = rand::thread_rng(); + + // Sample random vector r (binary) + let mut r = [PrimeField::::ZERO; N]; + for i in 0..N { + r[i] = PrimeField::from(rng.gen_range(0..2)); + } + + // Compute u = A^T r + let mut u = [PrimeField::::ZERO; N]; + for i in 0..N { + for j in 0..N { + u[i] += self.public_key.a[j][i] * r[j]; + } + } + + // Compute v = b^T r + ⌊q/2⌋m + let mut v = PrimeField::::ZERO; + for i in 0..N { + v += self.public_key.b[i] * r[i]; + } + + if *plaintext { + v += PrimeField::from(Q / 2); + } + + Ciphertext { u, v } + } + + fn decrypt(&self, ct: &Self::Ciphertext) -> Self::Plaintext { + // Compute v - s^T u + let mut result = ct.v; + for i in 0..N { + result -= self.private_key.s[i] * ct.u[i]; + } + + // Get q/2 as a field element + let q_half = PrimeField::::from(Q / 2); + + // For distance to zero, we need min(x, q-x) + let dist_to_zero = result.min(-result); + + // For distance to q/2, we need min(|x - q/2|, |q/2 - x|) + let dist_to_q_half = if result >= q_half { + // If result ≥ q/2, distance is min(result - q/2, q - result + q/2) + (result - q_half).min(-result + q_half) + } else { + // If result < q/2, distance is min(q/2 - result, result + q/2) + (q_half - result).min(result + q_half) + }; + + dist_to_q_half < dist_to_zero + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_encryption_decryption() { + for _ in 0..100 { + let lwe = LWE::<97, 4, 2>::new(); + + // Test encryption and decryption of 0 + let ct_zero = lwe.encrypt(&false); + assert_eq!(lwe.decrypt(&ct_zero), false, "Failed decrypting 0"); + + // Test encryption and decryption of 1 + let ct_one = lwe.encrypt(&true); + assert_eq!(lwe.decrypt(&ct_one), true, "Failed decrypting 1"); + } + } +} diff --git a/src/encryption/asymmetric/mod.rs b/src/encryption/asymmetric/mod.rs index 62939b9a..1b9e310f 100644 --- a/src/encryption/asymmetric/mod.rs +++ b/src/encryption/asymmetric/mod.rs @@ -1,2 +1,16 @@ //! Contains implementation of asymmetric cryptographic primitives like RSA encryption. +pub mod lwe; pub mod rsa; + +pub trait AsymmetricEncryption { + type PublicKey; + type PrivateKey; + type Plaintext; + type Ciphertext; + + /// Encrypts plaintext using key and returns ciphertext + fn encrypt(&self, plaintext: &Self::Plaintext) -> Self::Ciphertext; + + /// Decrypts ciphertext using key and returns plaintext + fn decrypt(&self, ciphertext: &Self::Ciphertext) -> Self::Plaintext; +} diff --git a/src/lib.rs b/src/lib.rs index 61ef6dc2..7acc5a38 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -34,7 +34,6 @@ pub mod encryption; pub mod hashes; pub mod hmac; pub mod kzg; -pub mod lwe; pub mod multi_var_poly; pub mod polynomial; pub mod sumcheck; diff --git a/src/lwe/mod.rs b/src/lwe/mod.rs deleted file mode 100644 index 58ac9573..00000000 --- a/src/lwe/mod.rs +++ /dev/null @@ -1,167 +0,0 @@ -use rand::Rng; - -use crate::algebra::field::{extension::GaloisField, prime::PrimeField, Field, FiniteField}; - -// Small parameters for testing -// const N: usize = 4; // Dimension (use larger like 256 in practice) -// const Q: u32 = 97; // Modulus (prime, use larger in practice) -const K: usize = 1; // Parameter for binomial sampling - -#[derive(Debug, Clone)] -pub struct PublicKey { - a: [[PrimeField; N]; N], // Matrix A in Z_q^{n×n} - b: [PrimeField; N], // Vector b = As + e -} - -#[derive(Debug, Clone)] -pub struct SecretKey { - s: [PrimeField; N], // Secret vector in Z_q^n -} - -#[derive(Debug, Clone)] -pub struct Ciphertext { - u: Vec>, // First component - v: PrimeField, // Second component -} - -fn sample_binomial() -> PrimeField { - let mut rng = rand::thread_rng(); - let mut sum = PrimeField::::ZERO; - - for _ in 0..K { - // Add 1 or 0 for first term - if rng.gen::() { - sum = sum + PrimeField::::ONE; - } - // Subtract 1 or 0 for second term - if rng.gen::() { - sum = sum - PrimeField::::ONE; - } - } - sum -} - -pub fn keygen() -> (PublicKey, SecretKey) { - let mut rng = rand::thread_rng(); - - // Generate random matrix A - let mut a = [[PrimeField::::ZERO; N]; N]; - for i in 0..N { - for j in 0..N { - a[i][j] = PrimeField::from(rng.gen_range(0..Q)); - } - } - - // Sample secret vector s with small coefficients - let mut s = [PrimeField::::ZERO; N]; - for i in 0..N { - s[i] = sample_binomial(); - } - - // Generate error vector e - let mut e = [PrimeField::::ZERO; N]; - for i in 0..N { - e[i] = sample_binomial(); - } - - // Compute b = As + e - let mut b = [PrimeField::::ZERO; N]; - for i in 0..N { - let mut sum = PrimeField::::ZERO; - for j in 0..N { - sum += a[i][j] * s[j]; - } - sum += e[i]; - b[i] = sum; - } - - (PublicKey { a, b }, SecretKey { s }) -} - -// TODO: Should impl the `SymmetricEncryption` trait -pub fn encrypt(pk: &PublicKey, message: bool) -> Ciphertext { - let mut rng = rand::thread_rng(); - - // Sample random vector r - let mut r = vec![0; N]; - for i in 0..N { - r[i] = rng.gen_range(0..2) as u32; // Binary vector - } - - // Compute u = A^T r - let mut u = vec![0; N]; - for i in 0..N { - let mut sum = 0i32; - for j in 0..N { - sum += (pk.a[j][i] * r[j]) as i32; - } - u[i] = mod_q(sum); - } - - // Compute v = b^T r + ⌊q/2⌋m - let mut v = 0i32; - for i in 0..N { - v += (pk.b[i] * r[i]) as i32; - } - if message { - v += (Q / 2) as i32; - } - - Ciphertext { u, v: mod_q(v) } -} - -pub fn decrypt(sk: &SecretKey, ct: &Ciphertext) -> bool { - // Compute v - s^T u - let mut sum = ct.v as i32; - for i in 0..N { - sum -= (sk.s[i] * ct.u[i]) as i32; - } - sum = mod_q(sum) as i32; - - // Check if closer to 0 or ⌊q/2⌋ - let q_half = (Q / 2) as i32; - let mut dist_to_zero = sum.min(Q as i32 - sum); - let mut dist_to_q_half = ((sum - q_half).abs()).min((sum - q_half + Q as i32).abs()); - - dist_to_q_half < dist_to_zero -} - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn test_encryption_decryption() { - // Test multiple random instances - for _ in 0..100 { - let (pk, sk) = keygen(); - - // Test encryption and decryption of 0 - let ct_zero = encrypt(&pk, false); - assert_eq!(decrypt(&sk, &ct_zero), false); - - // Test encryption and decryption of 1 - let ct_one = encrypt(&pk, true); - assert_eq!(decrypt(&sk, &ct_one), true); - } - } - - #[test] - fn test_with_fixed_randomness() { - // This test would need a deterministic RNG to be meaningful - // In practice, you'd want to test with known test vectors - } - - #[test] - fn test_binomial_distribution() { - let mut counts = vec![0; 3]; // -1, 0, 1 for k=1 - for _ in 0..1000 { - let sample = sample_binomial(); - counts[(sample + 1) as usize] += 1; - } - // Check rough distribution - should be approximately 1/4, 1/2, 1/4 - for count in counts.iter() { - assert!(*count > 150); // Should be roughly 250 but allow some variance - } - } -} From ee086f2062a00cc5554e676457fbf0f65ae204a0 Mon Sep 17 00:00:00 2001 From: lonerapier Date: Mon, 20 Jan 2025 17:19:26 +0530 Subject: [PATCH 03/10] init kem --- src/kem/README.md | 23 +++++++++++ src/kem/kyber/README.md | 84 +++++++++++++++++++++++++++++++++++++++++ src/kem/kyber/mod.rs | 77 +++++++++++++++++++++++++++++++++++++ src/kem/mod.rs | 1 + src/lib.rs | 2 + 5 files changed, 187 insertions(+) create mode 100644 src/kem/README.md create mode 100644 src/kem/kyber/README.md create mode 100644 src/kem/kyber/mod.rs create mode 100644 src/kem/mod.rs diff --git a/src/kem/README.md b/src/kem/README.md new file mode 100644 index 00000000..e104ac3e --- /dev/null +++ b/src/kem/README.md @@ -0,0 +1,23 @@ +# Key Encpasulation Mechanism + +## ToC +- What is KEM? +- How is it useful? +- Security + +KEM is a cryptographic algorithm, that is used to establish a shared key between two parties on a untrusted channel. + +KEM consists of three algorithms: +- $\text{KeyGen}$: Generates encapsulation and decapsulation keys +- $\text{Encaps}$: Encapsulates key and ciphertext that is sent to other part to generate shared key +- $\text{Decaps}$: Generates shared key from decapsulation key + +```mermaid +flowchart TB +a[Alice]-->k[KeyGen] +k--"Decapsulation Key"-->d[Decaps] +k--"Encapsulation Key"-->e[Encaps] +e--"ciphertext"-->d +d-->al[Alice's shared key] +e-->bo[Bob's shared key] +``` \ No newline at end of file diff --git a/src/kem/kyber/README.md b/src/kem/kyber/README.md new file mode 100644 index 00000000..f5899c71 --- /dev/null +++ b/src/kem/kyber/README.md @@ -0,0 +1,84 @@ +# Module Lattice Based KEM (FIPS 203) + +## ToC +- [ ] Intrduction +- [ ] Preliminaries +- [ ] Parameters +- [ ] Encryption +- [ ] KEM +- [ ] Optimisation +- [ ] Performance +- [ ] Security + + +## Introduction + +> [!NOTE] +> Implementation and terminilogy follows FIPS 203 that formalises PQ-KEM standard for the internet. + +ML-KEM is based on Module Learning with Errors problems introduced by [Reg05][Reg05]. + +## Preliminaries + +## Implementation Details +- Polynomial Rings + - representation + - Arithmetic with Polynomials +- Matrices and Vectors + - representation + - Arithmetic +- cryptographic functions + - SHA3-256, SHA3-512 + - SHAKE128, SHAKE256 + - PRF: $\textsf{PRF}_{\eta}(s,b):=\text{SHAKE256}(s\|b,8\cdot 64 \cdot \eta)$ + - Hash function: + - $H: H(s) := \text{SHA3-256}(s)$, where $s\in \mathbb{B}^*$ + - $J: J(s) := \text{SHAKE256}(s,8\cdot 32)$, where $s\in \mathbb{B}^*$ + - $G:\ G(c \in \mathbb{B}^*) := \text{SHA3-512}(c) \in \mathbb{B}^{32}\times\mathbb{B}^{32}$ + - XOF: SHAKE128 +- General Functions: + - BitsToBytes, BytesToBits + - Compression, Decompression + - ByteEncode, ByteDecode +- Sampling algorithms: + - SampleNTT +- NTT, RevNTT +- MultiplyNTT +- BaseCaseMultiply + +## NTT: Number-Theroretic Transform + +- Special case of FFT performed over $\mathbb{Z}_q^*$. +- Most computationally intensive operation in encryption scheme is the multiplication of polynomials in $\mathcal{R}_{q,f}$ that is $\mathcal{O}(n^2)$, where $n$ is the degree of the polynomial. NTT allows to perform this in $\mathcal{O}(n\log n)$. +- Polynomial: $\mathbb{Z}_q[X]/(X^d+\alpha)$, where $d=2^k$ and $-\alpha$ is perfect square of $d\in\mathbb{Z}_q$ + - $X^d+\alpha\equiv(X^{d/2}-r)(X^{d/2}+r)\mod q$ +- Using, Chinese Remainder Theorem, $ab\in\mathbb{Z}_q/(X^d+\alpha)$ can be written as $ab \mod (X^{d/2}+r),\ ab \mod (X^{d/2}-r)$, and can be converted back to $\mod (X^d+\alpha)$. +- Doing this recursively, asymptotic complexity of the algorithm turns out to be $d\log d$ + +## Parameters +- $q=3329=2^8+1$ +- $f=X^{256}+1$ +- $k,\eta_1,\eta_2,d_u,d_v$ vary according to parameter sets. + +| | k | $\eta_1$ | $\eta_2$ | $d_u$ | $d_v$ | decryption error | pk size | ciphertext size | +| ---------- | --- | -------- | -------- | ----- | ----- | ---------------- | ------- | --------------- | +| Kyber-512 | 3 | 3 | 2 | 10 | 4 | $2^{−139}$ | 800B | 768B | +| Kyber-768 | 4 | 2 | 2 | 10 | 4 | $2^{−164}$ | 1184B | 1088B | +| Kyber-1024 | 5 | 2 | 2 | 11 | 5 | $2^{−174}$ | 1568B | 1568B | + +## CPA-secure Encryption Scheme +1. $\text{KeyGen}$ +2. + +## TODO + +- [ ] FIPS 202: SHA3-{256,512}, SHAKE{128,256} +- [ ] Rings, Polynomial Rings + - [ ] arithmetic for rings + - [ ] + +## References +- [Module-Lattice-Based Key-Encapsulation Mechanism Standard](https://nvlpubs.nist.gov/nistpubs/FIPS/NIST.FIPS.203.pdf) +- [RustCrypto/KEMs](https://github.com/RustCrypto/KEMs): const functions inspiration + +[Reg05]: \ No newline at end of file diff --git a/src/kem/kyber/mod.rs b/src/kem/kyber/mod.rs new file mode 100644 index 00000000..0b393d6e --- /dev/null +++ b/src/kem/kyber/mod.rs @@ -0,0 +1,77 @@ +use rand::Rng; + +use crate::{ + algebra::field::prime::PrimeField, + polynomial::{Basis, Monomial, Polynomial}, +}; + +mod auxiliary; +mod compress; +mod encode; +mod kpke; +mod ntt; +mod sampling; +// #[cfg(test)] mod tests; + +pub const MLKEM_Q: usize = 3329; +pub const MLKEM_N: usize = 256; + +pub struct KPke { + pub eta1: usize, + pub eta2: usize, + pub du: usize, + pub dv: usize, +} + +pub type MlKemField = PrimeField; + +pub struct Ntt; +impl Basis for Ntt { + type Data = (); +} + +pub type PolyRing = Polynomial; +pub type NttPolyRing = Polynomial; + +pub struct PolyVec { + // TODO: might need to be written as [Polynomial; K] + pub vec: [Polynomial; K], +} + +impl PolyVec { + pub fn new(vec: [Polynomial; K]) -> Self { Self { vec } } +} + +pub struct MatrixPolyVec { + pub vec: [PolyVec; K], +} + +// impl MatrixPolyVec { +// pub fn new(vec: [PolyVec; K]) -> Self { Self { vec } } +// } + +pub fn sample_ntt() -> Vec { + let mut rng = rand::thread_rng(); + (0..256).map(|_| rng.gen_range(0..256)).collect() +} + +impl KPke { + pub fn new(eta1: usize, eta2: usize, du: usize, dv: usize) -> Self { Self { eta1, eta2, du, dv } } + + pub fn pke_keygen(&self, d: [u8; 32]) { + // let (rho, sigma) = hash::g(&[d, [k]].concat()); + // let n = 0; + // let mut a_ntt = Vec::with_capacity(capacity) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn poly_ring() { + let coeffs = [MlKemField::new(1); 256]; + let poly: PolyRing<256> = PolyRing::new(coeffs); + } +} diff --git a/src/kem/mod.rs b/src/kem/mod.rs new file mode 100644 index 00000000..19c6bc24 --- /dev/null +++ b/src/kem/mod.rs @@ -0,0 +1 @@ +pub mod kyber; diff --git a/src/lib.rs b/src/lib.rs index 7acc5a38..2c561426 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -13,6 +13,7 @@ //! operations such as point addition, scalar multiplication, and pairing operations. //! - Compiler: Simple DSL to write circuits which can be compiled to polynomials used in PLONK. +#![feature(test)] #![allow(incomplete_features)] #![feature(effects)] #![feature(const_trait_impl)] @@ -33,6 +34,7 @@ pub mod dsa; pub mod encryption; pub mod hashes; pub mod hmac; +pub mod kem; pub mod kzg; pub mod multi_var_poly; pub mod polynomial; From 84130ca9be421b94c1a2657a52f8a39753cd2889 Mon Sep 17 00:00:00 2001 From: lonerapier Date: Mon, 20 Jan 2025 17:19:40 +0530 Subject: [PATCH 04/10] add compress --- src/kem/kyber/compress.rs | 68 +++++++++++++++++++++++++++++++++++++++ 1 file changed, 68 insertions(+) create mode 100644 src/kem/kyber/compress.rs diff --git a/src/kem/kyber/compress.rs b/src/kem/kyber/compress.rs new file mode 100644 index 00000000..771740cc --- /dev/null +++ b/src/kem/kyber/compress.rs @@ -0,0 +1,68 @@ +use super::MlKemField; +use crate::{ + algebra::Finite, + polynomial::{Monomial, Polynomial}, +}; + +/// Compresses a number x to a number in the range [0, 2^d) using the formula round((2^d / q) * x) +/// mod 2^d. +/// round(a / b) = floor((a + b/2) / b) +pub fn compress_fieldelement(x: &MlKemField) -> MlKemField { + // TODO: Implement using barrett reduction + let q_half = (MlKemField::ORDER + 1) >> 1; + MlKemField::new((((x.value << D) + q_half) / MlKemField::ORDER) % (1 << D)) +} + +/// Decompresses a number y to a number in the range [0, q) using the formula round((q / 2^d)) * y. +pub fn decompress_fieldelement(y: &MlKemField) -> MlKemField { + let d_pow_half = 1 << (D - 1); + let quotient = MlKemField::ORDER * y.value + d_pow_half; + MlKemField::new(quotient >> D) +} + +pub fn poly_compress( + poly: &Polynomial, +) -> [MlKemField; D] { + // TODO: remove unwrap + poly + .coefficients + .iter() + .map(compress_fieldelement::<8>) + .collect::>() + .try_into() + .unwrap() +} + +pub fn poly_decompress( + poly: &[MlKemField; D], +) -> Polynomial { + let mut coefficients = [MlKemField::default(); D]; + for (i, x) in poly.iter().enumerate() { + coefficients[i] = decompress_fieldelement::<8>(x); + } + Polynomial::::new(coefficients) +} + +// pub fn polyvec_compress( +// poly_vec: &PolyVec, +// ) -> [[MlKemField; D]; K] { +// let mut res = [[MlKemField::default(); D]; K]; + +// for (i, poly) in poly_vec.vec.iter().enumerate() { +// res[i] = poly_compress(poly); +// } + +// res +// } + +// pub fn polyvec_decompress( + +// ) + +#[test] +fn test_compress_decompress() { + let x = MlKemField::new(10); + let z = decompress_fieldelement::<8>(&x); + let y = compress_fieldelement::<8>(&z); + assert_eq!(x, y); +} From 4e86ea8f92bdb3d18a1f9ebad8bad8b1ec826b69 Mon Sep 17 00:00:00 2001 From: lonerapier Date: Mon, 20 Jan 2025 17:20:02 +0530 Subject: [PATCH 05/10] add auxiliary and compression functions --- src/kem/kyber/auxiliary.rs | 45 ++++++++++++++++ src/kem/kyber/encode.rs | 107 +++++++++++++++++++++++++++++++++++++ 2 files changed, 152 insertions(+) create mode 100644 src/kem/kyber/auxiliary.rs create mode 100644 src/kem/kyber/encode.rs diff --git a/src/kem/kyber/auxiliary.rs b/src/kem/kyber/auxiliary.rs new file mode 100644 index 00000000..aa3086d4 --- /dev/null +++ b/src/kem/kyber/auxiliary.rs @@ -0,0 +1,45 @@ +use sha3::{ + digest::{ExtendableOutput, Update, XofReader}, + Digest, Shake128, Shake256, +}; + +pub fn prf(s: [u8; 32], b: u8) -> [u8; 64 * eta] { + // concat s and b + let mut hasher = Shake256::default(); + hasher.update(&s); + hasher.update(&[b]); + let mut res = [0u8; 64 * eta]; + XofReader::read(&mut hasher.finalize_xof(), &mut res); + res +} + +pub fn h(s: &[u8]) -> [u8; 32] { sha3::Sha3_256::digest(s).into() } + +pub fn j(s: &[u8]) -> [u8; 32] { + let mut hasher = Shake256::default(); + hasher.update(s); + let mut reader = hasher.finalize_xof(); + let mut res = [0u8; 32]; + XofReader::read(&mut reader, &mut res); + res +} + +pub fn g(c: &[u8]) -> ([u8; 32], [u8; 32]) { + let res = sha3::Sha3_512::digest(c); + (res[..32].try_into().unwrap(), res[32..].try_into().unwrap()) +} + +pub struct Xof(Shake128); + +impl Xof { + pub fn init() -> Self { Self(Shake128::default()) } + + pub fn absorb(mut self, input: &[u8]) -> impl XofReader { + self.0.update(input); + self.0.finalize_xof() + } + + pub fn squeeze(reader: &mut impl XofReader, output: &mut [u8]) { + XofReader::read(reader, output); + } +} diff --git a/src/kem/kyber/encode.rs b/src/kem/kyber/encode.rs new file mode 100644 index 00000000..114b9de6 --- /dev/null +++ b/src/kem/kyber/encode.rs @@ -0,0 +1,107 @@ +use super::MlKemField; +use crate::algebra::field::Field; + +/// Encodes a field element into a byte array where each field element is represented by D bits. +/// Converts the field element into a binary representation and then packs the bits into bytes. +fn byte_encode(f: [MlKemField; 256]) -> [u8; 32 * D] +where [(); 256 * D]: { + let mut encoded_bits = [0u8; 256 * D]; + + for (i, x) in f.iter().enumerate() { + let mut val = x.value; + for j in 0..D { + encoded_bits[i * D + j] = (val & 1) as u8; + val >>= 1; + } + } + + let mut encoded_bytes = [0u8; 32 * D]; + for (i, chunk) in encoded_bits.chunks(8).enumerate() { + encoded_bytes[i] = chunk.iter().enumerate().fold(0, |acc, (j, &b)| acc | (b << j)); + } + + encoded_bytes +} + +/// Encodes a field element into a byte array where each field element is represented by D bits. +/// Converts the field element into a binary representation and then packs the bits into bytes. +fn byte_encode_optimized(f: [MlKemField; 256]) -> [u8; 32 * D] +where [(); 256 * D]: { + let mut encoded_bytes = [0u8; 32 * D]; + + // Process 8 field elements at a time to fill each D bytes + for chunk_idx in 0..32 { + let chunk_start = chunk_idx * 8; + let elements = &f[chunk_start..chunk_start + 8]; + + // For each bit position within the D bits + for bit_pos in 0..D { + let byte_idx = chunk_idx * D + bit_pos; + + // Collect bits from 8 elements into a single byte + encoded_bytes[byte_idx] = elements + .iter() + .enumerate() + .fold(0, |acc, (j, x)| acc | (((x.value >> bit_pos) & 1) as u8) << j); + } + } + + encoded_bytes +} + +/// Decodes a byte array into a field element where each field element is represented by D bits. +/// Unpacks the bytes into bits and then converts the bits into a field element. +fn byte_decode(encoded_bytes: [u8; 32 * D]) -> [MlKemField; 256] +where [(); 256 * D]: { + let mut encoded_bits = [0u8; 256 * D]; + for (i, &byte) in encoded_bytes.iter().enumerate() { + for j in 0..8 { + encoded_bits[i * 8 + j] = (byte >> j) & 1; + } + } + + let mut f = [MlKemField::ZERO; 256]; + for (i, chunk) in encoded_bits.chunks(D).enumerate() { + let mut val = 0; + for (j, &bit) in chunk.iter().enumerate() { + val |= (bit as usize) << j; + } + f[i].value = val; + } + + f +} + +#[cfg(test)] +mod tests { + use super::*; + extern crate test; + + fn generate_test_data() -> [MlKemField; 256] { + let mut data = [MlKemField { value: 0 }; 256]; + for i in 0..256 { + data[i].value = i as usize; + } + data + } + + #[test] + fn test_byte_encode() { + let f = generate_test_data(); + let encoded = byte_encode::<8>(f); + let encoded_optimized = byte_encode_optimized::<8>(f); + assert_eq!(encoded, encoded_optimized); + } + + #[bench] + fn bench_byte_encode(b: &mut test::Bencher) { + let f = generate_test_data(); + b.iter(|| byte_encode::<8>(f)); + } + + #[bench] + fn bench_byte_encode_optimized(b: &mut test::Bencher) { + let f = generate_test_data(); + b.iter(|| byte_encode_optimized::<8>(f)); + } +} From 1bd7fcc4aef99b0627073e9854462309e4275948 Mon Sep 17 00:00:00 2001 From: lonerapier Date: Mon, 20 Jan 2025 17:20:12 +0530 Subject: [PATCH 06/10] add sampling and ntt --- src/kem/kyber/ntt.rs | 112 ++++++++++++++++++++++++++++++++++++++ src/kem/kyber/sampling.rs | 48 ++++++++++++++++ 2 files changed, 160 insertions(+) create mode 100644 src/kem/kyber/ntt.rs create mode 100644 src/kem/kyber/sampling.rs diff --git a/src/kem/kyber/ntt.rs b/src/kem/kyber/ntt.rs new file mode 100644 index 00000000..b4111923 --- /dev/null +++ b/src/kem/kyber/ntt.rs @@ -0,0 +1,112 @@ +use std::ops::Mul; + +use super::{MlKemField, Ntt}; +use crate::{ + algebra::{field::Field, Finite}, + polynomial::{Monomial, Polynomial}, +}; + +const ZETA: usize = 17; +const fn bitrev_7(x: usize) -> usize { + (x >> 6) & 1 + | ((x >> 5) & 1) << 1 + | ((x >> 4) & 1) << 2 + | ((x >> 3) & 1) << 3 + | ((x >> 2) & 1) << 4 + | ((x >> 1) & 1) << 5 + | (x & 1) << 6 +} + +const ZETA_POW: [MlKemField; 128] = { + let mut i = 0; + let mut curr = 1; + let mut zeta_pow = [MlKemField::ZERO; 128]; + while i < 128 { + zeta_pow[i] = MlKemField::new(curr); + curr = curr * ZETA % MlKemField::ORDER; + i += 1; + } + + let mut zeta_pow_rev = [MlKemField::ZERO; 128]; + let mut i = 0; + while i < 128 { + zeta_pow_rev[i] = zeta_pow[bitrev_7(i)]; + i += 1; + } + + zeta_pow_rev +}; + +const GAMMA: [MlKemField; 128] = { + let mut gamma = [MlKemField::ZERO; 128]; + let mut i = 0; + while i < 128 { + gamma[i] = + MlKemField { value: (ZETA_POW[i].value * ZETA_POW[i].value * ZETA) % MlKemField::ORDER }; + i += 1; + } + + gamma +}; + +impl Polynomial { + pub fn ntt(self) -> Polynomial { + let mut f_ntt_coeffs = [MlKemField::ZERO; D]; + let mut i = 1; + let mut len = 128; + while len >= 2 { + for start in (0..256).step_by(len * 2) { + let zeta = ZETA_POW[i]; + i += 1; + for j in start..start + len { + let t = zeta * self.coefficients[j + len]; + f_ntt_coeffs[j + len] = self.coefficients[j] - t; + f_ntt_coeffs[j] = self.coefficients[j] + t; + } + } + len >>= 1; + } + + Polynomial:: { coefficients: f_ntt_coeffs, basis: Ntt } + } +} + +impl Mul> for Polynomial { + type Output = Self; + + fn mul(self, rhs: Polynomial) -> Self::Output { + let mut res_coeffs = [MlKemField::ZERO; 256]; + + for i in 0..128 { + let (a0, a1) = (self.coefficients[2 * i], self.coefficients[2 * i + 1]); + let (b0, b1) = (rhs.coefficients[2 * i], rhs.coefficients[2 * i + 1]); + let (c0, c1) = (a0 * b0 + GAMMA[i] * a1 * b1, a0 * b1 + a1 * b0); + res_coeffs[2 * i] = c0; + res_coeffs[2 * i + 1] = c1; + } + + Polynomial:: { coefficients: res_coeffs, basis: Ntt } + } +} + +impl Polynomial { + pub fn ntt_inv(self) -> Polynomial { + let mut f_coeffs = [MlKemField::ZERO; D]; + let mut i = 127; + let mut len = 2; + while len <= 128 { + for start in (0..256).step_by(2 * len) { + let zeta = ZETA_POW[i]; + i -= 1; + for j in start..start + len { + let t = self.coefficients[j]; + f_coeffs[j] = t + self.coefficients[j + len]; + f_coeffs[j + len] = zeta * (self.coefficients[j + len] - t); + } + } + len <<= 1; + } + + Polynomial:: { coefficients: f_coeffs, basis: Monomial } + } +} diff --git a/src/kem/kyber/sampling.rs b/src/kem/kyber/sampling.rs new file mode 100644 index 00000000..7b54e620 --- /dev/null +++ b/src/kem/kyber/sampling.rs @@ -0,0 +1,48 @@ +use sha3::digest::XofReader; + +use super::{auxiliary::Xof, MlKemField}; +use crate::algebra::{field::Field, Finite}; + +pub fn sample_ntt(input: &[u8]) -> [MlKemField; 256] { + assert!(input.len() == 34); + let mut ntt = [MlKemField::ZERO; 256]; + + let mut xof = Xof::init().absorb(input); + let mut j = 0; + while j < 256 { + let mut buf = [0u8; 3]; + Xof::squeeze(&mut xof, &mut buf); + + let d_1 = buf[0] as usize + ((buf[1] as usize & 0xf) << 8); + let d_2 = (buf[1] >> 4) as usize + ((buf[2] as usize) << 4); + + if d_1 < MlKemField::ORDER { + ntt[j] = MlKemField::new(d_1); + j += 1; + } + if d_2 < MlKemField::ORDER && j < 256 { + ntt[j] = MlKemField::new(d_2); + j += 1; + } + } + ntt +} + +pub fn sample_poly_cbd(seed: [u8; 64 * ETA]) -> [MlKemField; 256] +where [(); 64 * ETA * 8]: { + let mut bit_encode = [0u8; 64 * ETA * 8]; + for i in 0..64 * ETA { + for j in 0..8 { + bit_encode[i * 8 + j] = (seed[i] >> j) & 1; + } + } + + let mut res = [MlKemField::ZERO; 256]; + for i in 0..256 { + let x = (0..ETA).fold(0, |acc, j| acc + bit_encode[2 * i * ETA + j]); + let y = (0..ETA).fold(0, |acc, j| acc + bit_encode[(2 * i + 1) * ETA + j]); + res[i] = MlKemField::new((x - y) as usize); + } + + res +} From 27c91296a758477117e8ce13aef0c17558b2924f Mon Sep 17 00:00:00 2001 From: lonerapier Date: Sun, 9 Feb 2025 15:56:06 +0530 Subject: [PATCH 07/10] feat: add auxialiary, compress and encode functions --- src/kem/kyber/auxiliary.rs | 11 +++++-- src/kem/kyber/compress.rs | 60 +++++++++++++++++++++--------------- src/kem/kyber/encode.rs | 63 ++++++++++++++++++++++++++++---------- 3 files changed, 90 insertions(+), 44 deletions(-) diff --git a/src/kem/kyber/auxiliary.rs b/src/kem/kyber/auxiliary.rs index aa3086d4..1cbd84b2 100644 --- a/src/kem/kyber/auxiliary.rs +++ b/src/kem/kyber/auxiliary.rs @@ -1,14 +1,19 @@ +//! Contains auxiliary cryptographic functions for Kyber KEM. +//! - [`prf`] - Pseudorandom function +//! - [`h`], [`g`] - Hash function +//! - [`Xof`] - Extendable output function use sha3::{ digest::{ExtendableOutput, Update, XofReader}, Digest, Shake128, Shake256, }; -pub fn prf(s: [u8; 32], b: u8) -> [u8; 64 * eta] { - // concat s and b +pub fn prf(s: &[u8], b: u8) -> [u8; 64 * ETA] { + assert!(s.len() == 32); + let mut hasher = Shake256::default(); hasher.update(&s); hasher.update(&[b]); - let mut res = [0u8; 64 * eta]; + let mut res = [0u8; 64 * ETA]; XofReader::read(&mut hasher.finalize_xof(), &mut res); res } diff --git a/src/kem/kyber/compress.rs b/src/kem/kyber/compress.rs index 771740cc..10f78b6d 100644 --- a/src/kem/kyber/compress.rs +++ b/src/kem/kyber/compress.rs @@ -1,4 +1,4 @@ -use super::MlKemField; +use super::{MlKemField, PolyVec}; use crate::{ algebra::Finite, polynomial::{Monomial, Polynomial}, @@ -7,57 +7,69 @@ use crate::{ /// Compresses a number x to a number in the range [0, 2^d) using the formula round((2^d / q) * x) /// mod 2^d. /// round(a / b) = floor((a + b/2) / b) -pub fn compress_fieldelement(x: &MlKemField) -> MlKemField { +pub fn compress_fieldelement(x: &MlKemField) -> MlKemField { // TODO: Implement using barrett reduction let q_half = (MlKemField::ORDER + 1) >> 1; - MlKemField::new((((x.value << D) + q_half) / MlKemField::ORDER) % (1 << D)) + MlKemField::new((((x.value << d) + q_half) / MlKemField::ORDER) % (1 << d)) } /// Decompresses a number y to a number in the range [0, q) using the formula round((q / 2^d)) * y. -pub fn decompress_fieldelement(y: &MlKemField) -> MlKemField { - let d_pow_half = 1 << (D - 1); +pub fn decompress_fieldelement(y: &MlKemField) -> MlKemField { + let d_pow_half = 1 << (d - 1); let quotient = MlKemField::ORDER * y.value + d_pow_half; - MlKemField::new(quotient >> D) + MlKemField::new(quotient >> d) } -pub fn poly_compress( +pub fn poly_compress( poly: &Polynomial, -) -> [MlKemField; D] { +) -> Polynomial { // TODO: remove unwrap - poly + let coeffs = poly .coefficients .iter() - .map(compress_fieldelement::<8>) + .map(compress_fieldelement::) .collect::>() .try_into() - .unwrap() + .unwrap(); + + Polynomial::::new(coeffs) } -pub fn poly_decompress( +pub fn poly_decompress( poly: &[MlKemField; D], ) -> Polynomial { let mut coefficients = [MlKemField::default(); D]; for (i, x) in poly.iter().enumerate() { - coefficients[i] = decompress_fieldelement::<8>(x); + coefficients[i] = decompress_fieldelement::(x); } Polynomial::::new(coefficients) } -// pub fn polyvec_compress( -// poly_vec: &PolyVec, -// ) -> [[MlKemField; D]; K] { -// let mut res = [[MlKemField::default(); D]; K]; +pub fn polyvec_compress( + poly_vec: &PolyVec, +) -> PolyVec { + let mut res = Vec::with_capacity(K); -// for (i, poly) in poly_vec.vec.iter().enumerate() { -// res[i] = poly_compress(poly); -// } + for poly in poly_vec.vec.iter() { + res.push(poly_compress::(poly)); + } -// res -// } + let res = res.try_into().unwrap(); + PolyVec::new(res) +} -// pub fn polyvec_decompress( +pub fn polyvec_decompress( + poly_vec: &PolyVec, +) -> PolyVec { + let mut res = Vec::with_capacity(K); -// ) + for poly in poly_vec.vec.iter() { + res.push(poly_decompress::(&poly.coefficients)); + } + + let res = res.try_into().unwrap(); + PolyVec::new(res) +} #[test] fn test_compress_decompress() { diff --git a/src/kem/kyber/encode.rs b/src/kem/kyber/encode.rs index 114b9de6..d0202c87 100644 --- a/src/kem/kyber/encode.rs +++ b/src/kem/kyber/encode.rs @@ -1,21 +1,23 @@ -use super::MlKemField; -use crate::algebra::field::Field; +use super::{MlKemField, PolyVec}; +use crate::{ + algebra::field::Field, + polynomial::{Basis, Polynomial}, +}; -/// Encodes a field element into a byte array where each field element is represented by D bits. +/// Encodes a field element into a byte array where each field element is represented by d bits. /// Converts the field element into a binary representation and then packs the bits into bytes. -fn byte_encode(f: [MlKemField; 256]) -> [u8; 32 * D] -where [(); 256 * D]: { - let mut encoded_bits = [0u8; 256 * D]; +pub fn byte_encode(f: &[MlKemField; D]) -> Vec { + let mut encoded_bits = Vec::with_capacity(D * d); for (i, x) in f.iter().enumerate() { let mut val = x.value; - for j in 0..D { - encoded_bits[i * D + j] = (val & 1) as u8; + for j in 0..d { + encoded_bits[i * d + j] = (val & 1) as u8; val >>= 1; } } - let mut encoded_bytes = [0u8; 32 * D]; + let mut encoded_bytes = Vec::with_capacity(D / 8 * d); for (i, chunk) in encoded_bits.chunks(8).enumerate() { encoded_bytes[i] = chunk.iter().enumerate().fold(0, |acc, (j, &b)| acc | (b << j)); } @@ -23,6 +25,18 @@ where [(); 256 * D]: { encoded_bytes } +pub fn byte_encode_polyvec( + f: PolyVec, +) -> Vec { + let mut encoded_bytes = Vec::with_capacity(D / 8 * d * K); + for (i, poly) in f.vec.iter().enumerate() { + let encoded = byte_encode::(&poly.coefficients); + encoded_bytes[i * D / 8 * d..(i + 1) * D / 8 * d].copy_from_slice(&encoded); + } + + encoded_bytes +} + /// Encodes a field element into a byte array where each field element is represented by D bits. /// Converts the field element into a binary representation and then packs the bits into bytes. fn byte_encode_optimized(f: [MlKemField; 256]) -> [u8; 32 * D] @@ -51,20 +65,20 @@ where [(); 256 * D]: { /// Decodes a byte array into a field element where each field element is represented by D bits. /// Unpacks the bytes into bits and then converts the bits into a field element. -fn byte_decode(encoded_bytes: [u8; 32 * D]) -> [MlKemField; 256] -where [(); 256 * D]: { - let mut encoded_bits = [0u8; 256 * D]; +pub fn byte_decode(encoded_bytes: &[u8]) -> [MlKemField; D] { + let mut encoded_bits = Vec::with_capacity(256 * d); for (i, &byte) in encoded_bytes.iter().enumerate() { for j in 0..8 { - encoded_bits[i * 8 + j] = (byte >> j) & 1; + encoded_bits.push((byte >> j) & 1); } } - let mut f = [MlKemField::ZERO; 256]; + let mask: usize = (1 << d) - 1; + let mut f = [MlKemField::ZERO; D]; for (i, chunk) in encoded_bits.chunks(D).enumerate() { let mut val = 0; for (j, &bit) in chunk.iter().enumerate() { - val |= (bit as usize) << j; + val |= ((bit as usize) << j) & mask; } f[i].value = val; } @@ -72,6 +86,21 @@ where [(); 256 * D]: { f } +pub fn byte_decode_polyvec( + encoded_bytes: &[u8], + basis: B, +) -> PolyVec { + let mut f = Vec::with_capacity(K); + + for bytes in encoded_bytes.chunks(32 * d) { + let coeffs = byte_decode::(bytes.try_into().unwrap()); + f.push(Polynomial { coefficients: coeffs, basis: basis.clone() }) + } + + let f = f.try_into().unwrap(); + PolyVec::new(f) +} + #[cfg(test)] mod tests { use super::*; @@ -88,7 +117,7 @@ mod tests { #[test] fn test_byte_encode() { let f = generate_test_data(); - let encoded = byte_encode::<8>(f); + let encoded = byte_encode::<8, 256>(&f); let encoded_optimized = byte_encode_optimized::<8>(f); assert_eq!(encoded, encoded_optimized); } @@ -96,7 +125,7 @@ mod tests { #[bench] fn bench_byte_encode(b: &mut test::Bencher) { let f = generate_test_data(); - b.iter(|| byte_encode::<8>(f)); + b.iter(|| byte_encode::<8, 256>(&f)); } #[bench] From b44b89ccde48eafa90a874af31f649ae7054f044 Mon Sep 17 00:00:00 2001 From: lonerapier Date: Sun, 9 Feb 2025 15:56:36 +0530 Subject: [PATCH 08/10] feat: add sampling and ntt algorithms --- src/kem/kyber/ntt.rs | 31 ++++++++++++++++++++++++------- src/kem/kyber/sampling.rs | 13 ++++++++----- src/polynomial/arithmetic.rs | 21 +++++++++++++++++++++ src/polynomial/mod.rs | 4 ++-- 4 files changed, 55 insertions(+), 14 deletions(-) diff --git a/src/kem/kyber/ntt.rs b/src/kem/kyber/ntt.rs index b4111923..f0e00d3e 100644 --- a/src/kem/kyber/ntt.rs +++ b/src/kem/kyber/ntt.rs @@ -1,6 +1,6 @@ use std::ops::Mul; -use super::{MlKemField, Ntt}; +use super::{MlKemField, Ntt, PolyVec}; use crate::{ algebra::{field::Field, Finite}, polynomial::{Monomial, Polynomial}, @@ -71,13 +71,30 @@ impl Polynomial { } } -impl Mul> for Polynomial { - type Output = Self; +impl PolyVec { + pub fn ntt(self) -> PolyVec { + let ntt_vec = self.vec.iter().map(|poly| poly.ntt()).collect::>().try_into().unwrap(); - fn mul(self, rhs: Polynomial) -> Self::Output { - let mut res_coeffs = [MlKemField::ZERO; 256]; + PolyVec:: { vec: ntt_vec } + } +} + +impl PolyVec { + pub fn ntt_inv(self) -> PolyVec { + let ntt_inv_vec = + self.vec.iter().map(|poly| poly.ntt_inv()).collect::>().try_into().unwrap(); + + PolyVec { vec: ntt_inv_vec } + } +} + +impl Mul<&Polynomial> for &Polynomial { + type Output = Polynomial; + + fn mul(self, rhs: &Polynomial) -> Self::Output { + let mut res_coeffs = [MlKemField::ZERO; D]; - for i in 0..128 { + for i in 0..D >> 1 { let (a0, a1) = (self.coefficients[2 * i], self.coefficients[2 * i + 1]); let (b0, b1) = (rhs.coefficients[2 * i], rhs.coefficients[2 * i + 1]); let (c0, c1) = (a0 * b0 + GAMMA[i] * a1 * b1, a0 * b1 + a1 * b0); @@ -85,7 +102,7 @@ impl Mul> for Polynomial res_coeffs[2 * i + 1] = c1; } - Polynomial:: { coefficients: res_coeffs, basis: Ntt } + Polynomial:: { coefficients: res_coeffs, basis: Ntt } } } diff --git a/src/kem/kyber/sampling.rs b/src/kem/kyber/sampling.rs index 7b54e620..24450db1 100644 --- a/src/kem/kyber/sampling.rs +++ b/src/kem/kyber/sampling.rs @@ -1,13 +1,16 @@ -use sha3::digest::XofReader; - use super::{auxiliary::Xof, MlKemField}; use crate::algebra::{field::Field, Finite}; -pub fn sample_ntt(input: &[u8]) -> [MlKemField; 256] { - assert!(input.len() == 34); +pub fn sample_ntt(rho: &[u8], j: u8, i: u8) -> [MlKemField; 256] { + assert!(rho.len() == 32); + let mut input = [0u8; 34]; + input[..32].copy_from_slice(rho); + input[32] = j; + input[33] = i; + let mut ntt = [MlKemField::ZERO; 256]; - let mut xof = Xof::init().absorb(input); + let mut xof = Xof::init().absorb(&input); let mut j = 0; while j < 256 { let mut buf = [0u8; 3]; diff --git a/src/polynomial/arithmetic.rs b/src/polynomial/arithmetic.rs index fa3c99b7..6371e808 100644 --- a/src/polynomial/arithmetic.rs +++ b/src/polynomial/arithmetic.rs @@ -34,6 +34,27 @@ impl Add Add<&Polynomial> + for &Polynomial +{ + type Output = Polynomial; + + /// Implements addition of two polynomials by adding their coefficients. + /// Note: degree of first operand > deg of second operand. + fn add(self, rhs: &Polynomial) -> Self::Output { + let coefficients = self + .coefficients + .iter() + .zip(rhs.coefficients.iter().chain(std::iter::repeat(&F::ZERO))) + .map(|(&a, &b)| a + b) + .take(D) + .collect::>() + .try_into() + .unwrap_or_else(|v: Vec| panic!("Expected a Vec of length {} but it was {}", D, v.len())); + Self::Output { coefficients, basis: self.basis } + } +} + impl AddAssign> for Polynomial { diff --git a/src/polynomial/mod.rs b/src/polynomial/mod.rs index 378ccfc1..fbb2892c 100644 --- a/src/polynomial/mod.rs +++ b/src/polynomial/mod.rs @@ -17,7 +17,7 @@ //! - Includes Discrete Fourier Transform (DFT) for polynomials in the [`Monomial`] basis to convert //! into the [`Lagrange`] basis via evaluation at the roots of unity. -use std::array; +use std::{array, fmt::Debug}; use super::*; use crate::algebra::field::FiniteField; @@ -45,7 +45,7 @@ pub struct Polynomial { /// [`Basis`] trait is used to specify the basis of the polynomial. /// The basis can be [`Monomial`] or [`Lagrange`]. This is a type-state pattern for [`Polynomial`]. -pub trait Basis { +pub trait Basis: Debug + Clone { /// The associated data type for the basis. type Data; } From 9e01b74b71cff4175c2d7341da9db36d30c44289 Mon Sep 17 00:00:00 2001 From: lonerapier Date: Sun, 9 Feb 2025 15:57:23 +0530 Subject: [PATCH 09/10] feat: complete mlkem (testing contd..) --- src/kem/kyber/README.md | 6 +- src/kem/kyber/algebra.rs | 0 src/kem/kyber/kpke.rs | 10 + src/kem/kyber/mod.rs | 436 ++++++++++++++++++++++++++++++++++++++- src/kem/kyber/tests.rs | 0 5 files changed, 440 insertions(+), 12 deletions(-) create mode 100644 src/kem/kyber/algebra.rs create mode 100644 src/kem/kyber/kpke.rs create mode 100644 src/kem/kyber/tests.rs diff --git a/src/kem/kyber/README.md b/src/kem/kyber/README.md index f5899c71..b5735b9c 100644 --- a/src/kem/kyber/README.md +++ b/src/kem/kyber/README.md @@ -19,6 +19,10 @@ ML-KEM is based on Module Learning with Errors problems introduced by [Reg05][Reg05]. ## Preliminaries +- [ ] Why Polynomial Rings? +- [ ] Cyclotomic Polynomial Rings +- [ ] NTT +- [ ] Lattices ## Implementation Details - Polynomial Rings @@ -81,4 +85,4 @@ ML-KEM is based on Module Learning with Errors problems introduced by [Reg05][Re - [Module-Lattice-Based Key-Encapsulation Mechanism Standard](https://nvlpubs.nist.gov/nistpubs/FIPS/NIST.FIPS.203.pdf) - [RustCrypto/KEMs](https://github.com/RustCrypto/KEMs): const functions inspiration -[Reg05]: \ No newline at end of file +[Reg05]: diff --git a/src/kem/kyber/algebra.rs b/src/kem/kyber/algebra.rs new file mode 100644 index 00000000..e69de29b diff --git a/src/kem/kyber/kpke.rs b/src/kem/kyber/kpke.rs new file mode 100644 index 00000000..728340c0 --- /dev/null +++ b/src/kem/kyber/kpke.rs @@ -0,0 +1,10 @@ +// use super::{auxiliary::g, MatrixPolyVec}; + +// pub struct KpkeEncryptionKey([u8; 384*K+32]); +// pub struct KpkeDecryptionKey([u8; 384*K]); +// pub fn kpke_keygen(d: [u8; 32]) -> (KpkeEncryptionKey, KpkeDecryptionKey) { +// let (rho, sigma) = g(d, K); + +// let mut a_ntt = MatrixPolyVec:: + +// } diff --git a/src/kem/kyber/mod.rs b/src/kem/kyber/mod.rs index 0b393d6e..45b95732 100644 --- a/src/kem/kyber/mod.rs +++ b/src/kem/kyber/mod.rs @@ -1,3 +1,8 @@ +use std::ops::{Add, AddAssign, Mul}; + +use auxiliary::{g, h}; +use compress::{poly_compress, poly_decompress, polyvec_compress, polyvec_decompress}; +use encode::{byte_decode, byte_decode_polyvec, byte_encode, byte_encode_polyvec}; use rand::Rng; use crate::{ @@ -8,7 +13,7 @@ use crate::{ mod auxiliary; mod compress; mod encode; -mod kpke; +// mod kpke; mod ntt; mod sampling; // #[cfg(test)] mod tests; @@ -25,6 +30,7 @@ pub struct KPke { pub type MlKemField = PrimeField; +#[derive(Debug, Clone, Copy)] pub struct Ntt; impl Basis for Ntt { type Data = (); @@ -33,6 +39,35 @@ impl Basis for Ntt { pub type PolyRing = Polynomial; pub type NttPolyRing = Polynomial; +impl Add for &NttPolyRing { + type Output = NttPolyRing; + + fn add(self, rhs: Self) -> Self::Output { + let coeffs = self + .coefficients + .iter() + .zip(rhs.coefficients.iter()) + .map(|(&a, &b)| a + b) + .collect::>() + .try_into() + .unwrap(); + NttPolyRing::new(coeffs) + } +} + +impl AddAssign for NttPolyRing { + fn add_assign(&mut self, rhs: Self) { + for i in 0..D { + self.coefficients[i] += rhs.coefficients[i]; + } + } +} + +impl NttPolyRing { + pub fn new(coeffs: [MlKemField; D]) -> Self { Self { coefficients: coeffs, basis: Ntt } } +} + +#[derive(Debug, Clone, Copy)] pub struct PolyVec { // TODO: might need to be written as [Polynomial; K] pub vec: [Polynomial; K], @@ -42,26 +77,405 @@ impl PolyVec { pub fn new(vec: [Polynomial; K]) -> Self { Self { vec } } } +impl PolyVec { + pub fn dot_product(&self, rhs: &Self) -> NttPolyRing { + let mut result = NttPolyRing::new([MlKemField::ZERO; D]); + for i in 0..K { + result += &self.vec[i] * &rhs.vec[i]; + } + result + } +} + +impl Add for PolyVec { + type Output = Self; + + fn add(self, rhs: Self) -> Self::Output { + let vec = (0..K) + .map(|i| &self.vec[i] + &rhs.vec[i]) + .collect::>>() + .try_into() + .unwrap(); + Self::new(vec) + } +} + +impl Add for PolyVec { + type Output = Self; + + fn add(self, rhs: Self) -> Self::Output { + let vec = (0..K) + .map(|i| self.vec[i] + rhs.vec[i]) + .collect::>>() + .try_into() + .unwrap(); + Self::new(vec) + } +} + +// impl Iterator for PolyVec { +// type Item = Polynomial; + +// fn next(&mut self) -> Option { self.vec.iter().next().copied() } +// } + +/// A matrix of polynomial vectors of size K \times K. pub struct MatrixPolyVec { pub vec: [PolyVec; K], } -// impl MatrixPolyVec { -// pub fn new(vec: [PolyVec; K]) -> Self { Self { vec } } -// } +impl MatrixPolyVec { + /// Creates a new matrix of polynomial vectors. + pub fn new(vec: [PolyVec; K]) -> Self { Self { vec } } +} + +impl MatrixPolyVec { + pub fn transpose(&self) -> Self { + let poly_vec = (0..K) + .map(|col_idx| { + let vec = (0..K) + .map(|row_idx| self.vec[row_idx].vec[col_idx].clone()) + .collect::>>() + .try_into() + .unwrap(); + PolyVec::new(vec) + }) + .collect::>>() + .try_into() + .unwrap(); -pub fn sample_ntt() -> Vec { - let mut rng = rand::thread_rng(); - (0..256).map(|_| rng.gen_range(0..256)).collect() + Self::new(poly_vec) + } } +use crate::algebra::field::Field; + +impl Mul<&PolyVec> for MatrixPolyVec { + type Output = PolyVec; + + fn mul(self, rhs: &PolyVec) -> Self::Output { + let res = (0..K) + .map(|i| { + let mut sum = NttPolyRing::new([MlKemField::ZERO; D]); + for j in 0..K { + sum += &self.vec[i].vec[j] * &rhs.vec[j]; + } + sum + }) + .collect::>>() + .try_into() + .unwrap(); + + PolyVec::new(res) + } +} + +pub struct KPkeEncryptionKey([u8; 384 * K + 32]) +where [(); 384 * K + 32]:; +pub struct KPkeDecryptionKey([u8; 384 * K]) +where [(); 384 * K]:; + impl KPke { pub fn new(eta1: usize, eta2: usize, du: usize, dv: usize) -> Self { Self { eta1, eta2, du, dv } } - pub fn pke_keygen(&self, d: [u8; 32]) { - // let (rho, sigma) = hash::g(&[d, [k]].concat()); - // let n = 0; - // let mut a_ntt = Vec::with_capacity(capacity) + pub fn pke_keygen( + &self, + d: [u8; 32], + ) -> (KPkeEncryptionKey, KPkeDecryptionKey) + where + [(); 64 * eta1]:, + [(); 64 * eta1 * 8]:, + [(); 384 * K + 32]:, + [(); 384 * K]:, + { + let mut input = d.to_vec(); + input.push(K as u8); + let (rho, sigma) = auxiliary::g(&input); + + let mut n = 0; + + let a_hat_vec: [PolyVec; K] = (0..K) + .map(|i| { + let vec: [NttPolyRing; K] = (0..K) + .map(|j| { + let coeffs = sampling::sample_ntt(&rho, j as u8, i as u8); + NttPolyRing::new(coeffs) + }) + .collect::>>() + .try_into() + .unwrap(); + PolyVec::new(vec) + }) + .collect::>>() + .try_into() + .unwrap(); + let a_hat: MatrixPolyVec = MatrixPolyVec::new(a_hat_vec); + + let vec: [PolyRing; K] = (0..K) + .map(|_| { + let seed = auxiliary::prf::(&sigma, n); + n += 1; + let coeffs = sampling::sample_poly_cbd::(seed); + PolyRing::new(coeffs) + }) + .collect::>>() + .try_into() + .unwrap(); + let s: PolyVec = PolyVec::new(vec); + + let vec: [PolyRing; K] = (0..K) + .map(|_| { + let seed = auxiliary::prf::(&sigma, n); + n += 1; + let coeffs = sampling::sample_poly_cbd::(seed); + PolyRing::new(coeffs) + }) + .collect::>>() + .try_into() + .unwrap(); + let e: PolyVec = PolyVec::new(vec); + + let s_hat = s.ntt(); + let e_hat = e.ntt(); + + let t_hat = a_hat * &s_hat + e_hat; + + let ek = [byte_encode_polyvec::(t_hat), rho.to_vec()] + .concat() + .try_into() + .unwrap(); + let dk = byte_encode_polyvec::(s_hat).try_into().unwrap(); + + (KPkeEncryptionKey(ek), KPkeDecryptionKey(dk)) + } + + pub fn kpke_encrypt( + &self, + encryption_key: KPkeEncryptionKey, + m: [u8; 32], + r: [u8; 32], + ) -> Vec + where + [(); 384 * K + 32]:, + [(); 64 * eta1]:, + [(); 64 * eta1 * 8]:, + [(); 64 * eta2]:, + [(); 64 * eta2 * 8]:, + { + let mut n = 0; + let (t_hat, rho) = encryption_key.0.split_at(384 * K); + let t_hat = byte_decode_polyvec::(t_hat, Ntt); + let rho: [u8; 32] = rho.try_into().unwrap(); + + let a_hat_vec: [PolyVec; K] = (0..K) + .map(|i| { + let vec: [NttPolyRing; K] = (0..K) + .map(|j| { + let coeffs = sampling::sample_ntt(&rho, j as u8, i as u8); + NttPolyRing::new(coeffs) + }) + .collect::>>() + .try_into() + .unwrap(); + PolyVec::new(vec) + }) + .collect::>>() + .try_into() + .unwrap(); + let a_hat: MatrixPolyVec = MatrixPolyVec::new(a_hat_vec); + + let vec: [PolyRing; K] = (0..K) + .map(|_| { + let seed = auxiliary::prf::(&r, n); + n += 1; + let coeffs = sampling::sample_poly_cbd::(seed); + PolyRing::new(coeffs) + }) + .collect::>>() + .try_into() + .unwrap(); + let y: PolyVec = PolyVec::new(vec); + + let vec: [PolyRing; K] = (0..K) + .map(|_| { + let seed = auxiliary::prf::(&r, n); + n += 1; + let coeffs = sampling::sample_poly_cbd::(seed); + PolyRing::new(coeffs) + }) + .collect::>>() + .try_into() + .unwrap(); + let e1: PolyVec = PolyVec::new(vec); + + let e2 = PolyRing::new(sampling::sample_poly_cbd::(auxiliary::prf::(&r, n))); + + let y_hat = y.ntt(); + + let u = (a_hat.transpose() * &y_hat).ntt_inv() + e1; + + let mu = poly_decompress::(&byte_decode::<1, MLKEM_N>(&m)); + + let v = t_hat.dot_product(&y_hat).ntt_inv() + e2 + mu; + // let c1 = poly_compress::(&u); + let c1_compressed = polyvec_compress::(&u); + let c1 = byte_encode_polyvec::(c1_compressed); + let c2_compressed = poly_compress::(&v); + let c2 = byte_encode::(&c2_compressed.coefficients); + [c1, c2].concat() + } + + pub fn kpke_decrypt( + &self, + decryption_key: KPkeDecryptionKey, + c: Vec, + ) -> [u8; 32] + where + [(); 384 * K]:, + [(); 64 * eta1]:, + [(); 64 * eta1 * 8]:, + [(); 64 * eta2]:, + [(); 64 * eta2 * 8]:, + { + let (c1, c2) = c.split_at(384 * K); + let c1 = byte_decode_polyvec::(c1, Monomial); + let u_prime = polyvec_decompress::(&c1); + let v_prime = poly_decompress::(&byte_decode::(c2)); + + let s_ntt = byte_decode_polyvec::(&decryption_key.0, Ntt); + let w = v_prime - s_ntt.dot_product(&u_prime.ntt()).ntt_inv(); + let m = byte_encode::<1, MLKEM_N>(&poly_compress::(&w).coefficients); + m.try_into().unwrap() + } +} + +pub struct MlKemEncapsKey([u8; 384 * K + 32]) +where [(); 384 * K + 32]:; +pub struct MlKemDecapsKey +where [(); 384 * K + 32]: { + dk_pke: KPkeDecryptionKey, + ek_pke: KPkeEncryptionKey, + h: [u8; 32], + z: [u8; 32], +} + +pub struct MlKem { + kpke: KPke, +} + +impl MlKem { + fn keygen_internal( + &self, + d: [u8; 32], + z: [u8; 32], + ) -> (MlKemEncapsKey, MlKemDecapsKey) + where + [(); 384 * K + 32]:, + [(); 768 * K + 96]:, + [(); 64 * eta1 * 8]:, + { + let (ek_pke, dk_pke) = self.kpke.pke_keygen::(d); + let ek: [u8; 384 * K + 32] = ek_pke.0.clone(); + let h = h(&ek); + + (MlKemEncapsKey(ek), MlKemDecapsKey { dk_pke, ek_pke, h, z }) + } + + fn encaps_internal( + &self, + ek: MlKemEncapsKey, + m: [u8; 32], + ) -> ([u8; 32], Vec) + where + [(); 384 * K + 32]:, + [(); 64 * eta1]:, + [(); 64 * eta1 * 8]:, + [(); 64 * eta2]:, + [(); 64 * eta2 * 8]:, + { + let (k, r) = g([m, h(&ek.0)].concat().as_ref()); + let c = self.kpke.kpke_encrypt::(KPkeEncryptionKey(ek.0), m, r); + (k, c) + } + + fn decaps_internal( + &self, + dk: MlKemDecapsKey, + c: Vec, + ) -> [u8; 32] + where + [(); 384 * K + 32]:, + [(); 64 * eta1]:, + [(); 64 * eta1 * 8]:, + [(); 64 * eta2]:, + [(); 64 * eta2 * 8]:, + { + let m_prime = self.kpke.kpke_decrypt::(dk.dk_pke, c.clone()); + let (mut k_prime, r_prime) = g([m_prime, dk.h].concat().as_ref()); + + let mut pseudo_input = dk.z.to_vec(); + pseudo_input.extend(c.clone()); + let k_bar = auxiliary::j(&pseudo_input); + let c_prime = self.kpke.kpke_encrypt::(dk.ek_pke, m_prime, r_prime); + if c != c_prime { + k_prime = k_bar; + } + k_prime + } + + pub fn keygen( + &self, + ) -> Result<(MlKemEncapsKey, MlKemDecapsKey), String> + where + [(); 384 * K + 32]:, + [(); 768 * K + 96]:, + [(); 64 * eta1 * 8]:, { + // TODO: check if rng is good to use + let mut rng = rand::thread_rng(); + let d: [u8; 32] = rng.gen::<[u8; 32]>(); + let z = rand::thread_rng().gen::<[u8; 32]>(); + if d.is_empty() || z.is_empty() { + return Err("Error: keygen failed".to_string()); + } + Ok(self.keygen_internal::(d, z)) + } + + pub fn encaps( + &self, + ek: MlKemEncapsKey, + m: [u8; 32], + ) -> Result<([u8; 32], Vec), String> + where + [(); 384 * K + 32]:, + [(); 64 * eta1]:, + [(); 64 * eta1 * 8]:, + [(); 64 * eta2]:, + [(); 64 * eta2 * 8]:, + { + // TODO: run encaps key check + if m.is_empty() { + return Err("Error: encaps failed".to_string()); + } + Ok(self.encaps_internal::(ek, m)) + } + + pub fn decaps( + &self, + dk: MlKemDecapsKey, + c: Vec, + ) -> Result<[u8; 32], String> + where + [(); 384 * K + 32]:, + [(); 64 * eta1]:, + [(); 64 * eta1 * 8]:, + [(); 64 * eta2]:, + [(); 64 * eta2 * 8]:, + { + if c.len() != 32 * (du * K + dv) { + return Err("Error: invalid ciphertext".to_string()); + } + + Ok(self.decaps_internal::(dk, c)) } } diff --git a/src/kem/kyber/tests.rs b/src/kem/kyber/tests.rs new file mode 100644 index 00000000..e69de29b From 512af6bc9a08617bddcc73c6b2ba4fee44e9d5d3 Mon Sep 17 00:00:00 2001 From: lonerapier Date: Sun, 9 Feb 2025 17:29:34 +0530 Subject: [PATCH 10/10] use ronkathon sha3 primitive --- src/hashes/sha3.rs | 26 ++++++++++++------------ src/kem/kyber/auxiliary.rs | 41 +++++++++++++++++++------------------- src/kem/kyber/encode.rs | 4 ++-- src/kem/kyber/mod.rs | 2 +- src/kem/kyber/sampling.rs | 5 +++-- src/lib.rs | 1 - 6 files changed, 39 insertions(+), 40 deletions(-) diff --git a/src/hashes/sha3.rs b/src/hashes/sha3.rs index 30d5899f..4d366fa3 100644 --- a/src/hashes/sha3.rs +++ b/src/hashes/sha3.rs @@ -41,6 +41,19 @@ const RHO: [[u32; 5]; 5] = 27, 20, 39, 8, 14, ]]; +/// Type alias for SHA3-224. +pub type Sha3_224 = Sha3<28>; +/// Type alias for SHA3-256. +pub type Sha3_256 = Sha3<32>; +/// Type alias for SHA3-384. +pub type Sha3_384 = Sha3<48>; +/// Type alias for SHA3-512. +pub type Sha3_512 = Sha3<64>; +/// Type alias for SHAKE128. +pub type Shake128 = Shake<128>; +/// Type alias for SHAKE256. +pub type Shake256 = Shake<256>; + #[derive(Clone, Debug)] struct KeccakState { lanes: [[u64; 5]; 5], @@ -277,19 +290,6 @@ impl Shake { } } -/// Type alias for SHA3-224. -pub type Sha3_224 = Sha3<28>; -/// Type alias for SHA3-256. -pub type Sha3_256 = Sha3<32>; -/// Type alias for SHA3-384. -pub type Sha3_384 = Sha3<48>; -/// Type alias for SHA3-512. -pub type Sha3_512 = Sha3<64>; -/// Type alias for SHAKE128. -pub type Shake128 = Shake<128>; -/// Type alias for SHAKE256. -pub type Shake256 = Shake<256>; - #[cfg(test)] mod tests { use hex_literal::hex; diff --git a/src/kem/kyber/auxiliary.rs b/src/kem/kyber/auxiliary.rs index 1cbd84b2..1a60fc82 100644 --- a/src/kem/kyber/auxiliary.rs +++ b/src/kem/kyber/auxiliary.rs @@ -2,49 +2,48 @@ //! - [`prf`] - Pseudorandom function //! - [`h`], [`g`] - Hash function //! - [`Xof`] - Extendable output function -use sha3::{ - digest::{ExtendableOutput, Update, XofReader}, - Digest, Shake128, Shake256, -}; + +use crate::hashes::sha3::{Sha3_256, Sha3_512, Shake128, Shake256}; pub fn prf(s: &[u8], b: u8) -> [u8; 64 * ETA] { assert!(s.len() == 32); - let mut hasher = Shake256::default(); - hasher.update(&s); + let mut hasher = Shake256::new(); + hasher.update(s); hasher.update(&[b]); let mut res = [0u8; 64 * ETA]; - XofReader::read(&mut hasher.finalize_xof(), &mut res); + hasher.squeeze(&mut res); res } -pub fn h(s: &[u8]) -> [u8; 32] { sha3::Sha3_256::digest(s).into() } +pub fn h(s: &[u8]) -> [u8; 32] { + let mut hasher = Sha3_256::new(); + hasher.update(s); + hasher.finalize() +} pub fn j(s: &[u8]) -> [u8; 32] { - let mut hasher = Shake256::default(); + let mut hasher = Shake256::new(); hasher.update(s); - let mut reader = hasher.finalize_xof(); let mut res = [0u8; 32]; - XofReader::read(&mut reader, &mut res); + hasher.squeeze(&mut res); res } pub fn g(c: &[u8]) -> ([u8; 32], [u8; 32]) { - let res = sha3::Sha3_512::digest(c); - (res[..32].try_into().unwrap(), res[32..].try_into().unwrap()) + let mut hasher = Sha3_512::new(); + hasher.update(c); + let res = hasher.finalize(); + let (h0, h1) = res.split_at(32); + (h0.try_into().unwrap(), h1.try_into().unwrap()) } pub struct Xof(Shake128); impl Xof { - pub fn init() -> Self { Self(Shake128::default()) } + pub fn init() -> Self { Self(Shake128::new()) } - pub fn absorb(mut self, input: &[u8]) -> impl XofReader { - self.0.update(input); - self.0.finalize_xof() - } + pub fn absorb(&mut self, input: &[u8]) { self.0.update(input); } - pub fn squeeze(reader: &mut impl XofReader, output: &mut [u8]) { - XofReader::read(reader, output); - } + pub fn squeeze(&mut self, output: &mut [u8]) { self.0.squeeze(output); } } diff --git a/src/kem/kyber/encode.rs b/src/kem/kyber/encode.rs index d0202c87..212d9122 100644 --- a/src/kem/kyber/encode.rs +++ b/src/kem/kyber/encode.rs @@ -93,7 +93,7 @@ pub fn byte_decode_polyvec(bytes.try_into().unwrap()); + let coeffs = byte_decode::(bytes); f.push(Polynomial { coefficients: coeffs, basis: basis.clone() }) } @@ -109,7 +109,7 @@ mod tests { fn generate_test_data() -> [MlKemField; 256] { let mut data = [MlKemField { value: 0 }; 256]; for i in 0..256 { - data[i].value = i as usize; + data[i].value = i; } data } diff --git a/src/kem/kyber/mod.rs b/src/kem/kyber/mod.rs index 45b95732..749213ca 100644 --- a/src/kem/kyber/mod.rs +++ b/src/kem/kyber/mod.rs @@ -375,7 +375,7 @@ impl MlKem { [(); 64 * eta1 * 8]:, { let (ek_pke, dk_pke) = self.kpke.pke_keygen::(d); - let ek: [u8; 384 * K + 32] = ek_pke.0.clone(); + let ek: [u8; 384 * K + 32] = ek_pke.0; let h = h(&ek); (MlKemEncapsKey(ek), MlKemDecapsKey { dk_pke, ek_pke, h, z }) diff --git a/src/kem/kyber/sampling.rs b/src/kem/kyber/sampling.rs index 24450db1..eb48398f 100644 --- a/src/kem/kyber/sampling.rs +++ b/src/kem/kyber/sampling.rs @@ -10,11 +10,12 @@ pub fn sample_ntt(rho: &[u8], j: u8, i: u8) -> [MlKemField; 256] { let mut ntt = [MlKemField::ZERO; 256]; - let mut xof = Xof::init().absorb(&input); + let mut xof = Xof::init(); + xof.absorb(&input); let mut j = 0; while j < 256 { let mut buf = [0u8; 3]; - Xof::squeeze(&mut xof, &mut buf); + xof.squeeze(&mut buf); let d_1 = buf[0] as usize + ((buf[1] as usize & 0xf) << 8); let d_2 = (buf[1] >> 4) as usize + ((buf[2] as usize) << 4); diff --git a/src/lib.rs b/src/lib.rs index 2c561426..527e0c0a 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -22,7 +22,6 @@ #![feature(const_option)] #![feature(generic_const_exprs)] #![feature(specialization)] -#![feature(test)] #![warn(missing_docs)] pub mod algebra;