diff --git a/src/algebra/field/prime/mod.rs b/src/algebra/field/prime/mod.rs index 8794f80f..131e76fa 100644 --- a/src/algebra/field/prime/mod.rs +++ b/src/algebra/field/prime/mod.rs @@ -36,7 +36,7 @@ pub type AESField = PrimeField<2>; /// The [`PrimeField`] struct represents elements of a field with prime order. The field is defined /// by a prime number `P`, and the elements are integers modulo `P`. -#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash, Default, PartialOrd)] +#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash, Default, PartialOrd, Ord)] pub struct PrimeField { pub(crate) value: usize, } diff --git a/src/algebra/ring.rs b/src/algebra/ring.rs new file mode 100644 index 00000000..e2760ad8 --- /dev/null +++ b/src/algebra/ring.rs @@ -0,0 +1,30 @@ +use std::{ + cmp::{Eq, PartialEq}, + ops::{Add, AddAssign, Mul, MulAssign, Neg, Sub, SubAssign}, +}; + +use super::Finite; + +pub trait Ring: + std::fmt::Debug + + Default + + Sized + + Copy + + Clone + + PartialEq + + Eq + + Add + + AddAssign + + Sub + + SubAssign + + Mul + + MulAssign + + Neg + + 'static { + const ZERO: Self; + const ONE: Self; +} + +// pub trait FiniteRing: Finite + Ring { +// const PRIMITIVE_ELEMENT: Self; +// } diff --git a/src/encryption/asymmetric/lwe/mod.rs b/src/encryption/asymmetric/lwe/mod.rs new file mode 100644 index 00000000..fc783c98 --- /dev/null +++ b/src/encryption/asymmetric/lwe/mod.rs @@ -0,0 +1,186 @@ +//! This module implements the Learning With Errors (LWE) public-key encryption scheme. +//! +//! LWE is a lattice-based cryptosystem whose security is based on the hardness of the +//! Learning With Errors problem, first introduced by Regev in 2005. The scheme encrypts +//! single bits and provides security against quantum computers. +//! +//! The implementation uses three type parameters: +//! - Q: The modulus (must be prime) +//! - N: The dimension of the lattice +//! - K: The parameter for the binomial distribution used for error sampling + +use rand::Rng; + +use super::AsymmetricEncryption; +use crate::algebra::field::{prime::PrimeField, Field}; + +/// An implementation of Learning With Errors (LWE) encryption +pub struct LWE { + private_key: PrivateKey, + public_key: PublicKey, +} + +/// The public key consisting of a random matrix A and vector b = As + e +#[derive(Debug, Clone)] +pub struct PublicKey { + /// Random matrix in Z_q^{n×n} + a: [[PrimeField; N]; N], + /// Vector b = As + e where s is secret and e is error + b: [PrimeField; N], +} + +/// The private key consisting of the secret vector s +#[derive(Debug, Clone)] +pub struct PrivateKey { + /// Secret vector with small coefficients + s: [PrimeField; N], +} + +/// A ciphertext consisting of vector u and scalar v +#[derive(Debug, Clone)] +pub struct Ciphertext { + /// First component u = A^T r + u: [PrimeField; N], + /// Second component v = b^T r + ⌊q/2⌋m + v: PrimeField, +} + +/// Sample from a centered binomial distribution with parameter K +/// +/// Returns a value in the range [-K, K] following a discrete approximation +/// of a Gaussian distribution. +pub fn sample_binomial() -> PrimeField { + let mut rng = rand::thread_rng(); + let mut sum = PrimeField::::ZERO; + + for _ in 0..K { + if rng.gen::() { + sum += PrimeField::::ONE; + } + if rng.gen::() { + sum -= PrimeField::::ONE; + } + } + sum +} + +impl LWE { + pub fn new() -> LWE { + let mut rng = rand::thread_rng(); + + // Generate random matrix A + let mut a = [[PrimeField::::ZERO; N]; N]; + for i in 0..N { + for j in 0..N { + a[i][j] = PrimeField::from(rng.gen_range(0..Q)); + } + } + + // Sample secret vector s with small coefficients + let mut s = [PrimeField::::ZERO; N]; + for i in 0..N { + s[i] = sample_binomial::(); + } + + // Generate error vector e + let mut e = [PrimeField::::ZERO; N]; + for i in 0..N { + e[i] = sample_binomial::(); + } + + // Compute b = As + e + let mut b = [PrimeField::::ZERO; N]; + for i in 0..N { + let mut sum = PrimeField::::ZERO; + for j in 0..N { + sum += a[i][j] * s[j]; + } + sum += e[i]; + b[i] = sum; + } + + Self { public_key: PublicKey { a, b }, private_key: PrivateKey { s } } + } +} + +impl AsymmetricEncryption for LWE { + type Ciphertext = Ciphertext; + type Plaintext = bool; + type PrivateKey = PrivateKey; + type PublicKey = PublicKey; + + fn encrypt(&self, plaintext: &Self::Plaintext) -> Self::Ciphertext { + let mut rng = rand::thread_rng(); + + // Sample random vector r (binary) + let mut r = [PrimeField::::ZERO; N]; + for i in 0..N { + r[i] = PrimeField::from(rng.gen_range(0..2)); + } + + // Compute u = A^T r + let mut u = [PrimeField::::ZERO; N]; + for i in 0..N { + for j in 0..N { + u[i] += self.public_key.a[j][i] * r[j]; + } + } + + // Compute v = b^T r + ⌊q/2⌋m + let mut v = PrimeField::::ZERO; + for i in 0..N { + v += self.public_key.b[i] * r[i]; + } + + if *plaintext { + v += PrimeField::from(Q / 2); + } + + Ciphertext { u, v } + } + + fn decrypt(&self, ct: &Self::Ciphertext) -> Self::Plaintext { + // Compute v - s^T u + let mut result = ct.v; + for i in 0..N { + result -= self.private_key.s[i] * ct.u[i]; + } + + // Get q/2 as a field element + let q_half = PrimeField::::from(Q / 2); + + // For distance to zero, we need min(x, q-x) + let dist_to_zero = result.min(-result); + + // For distance to q/2, we need min(|x - q/2|, |q/2 - x|) + let dist_to_q_half = if result >= q_half { + // If result ≥ q/2, distance is min(result - q/2, q - result + q/2) + (result - q_half).min(-result + q_half) + } else { + // If result < q/2, distance is min(q/2 - result, result + q/2) + (q_half - result).min(result + q_half) + }; + + dist_to_q_half < dist_to_zero + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_encryption_decryption() { + for _ in 0..100 { + let lwe = LWE::<97, 4, 2>::new(); + + // Test encryption and decryption of 0 + let ct_zero = lwe.encrypt(&false); + assert_eq!(lwe.decrypt(&ct_zero), false, "Failed decrypting 0"); + + // Test encryption and decryption of 1 + let ct_one = lwe.encrypt(&true); + assert_eq!(lwe.decrypt(&ct_one), true, "Failed decrypting 1"); + } + } +} diff --git a/src/encryption/asymmetric/mod.rs b/src/encryption/asymmetric/mod.rs index 62939b9a..1b9e310f 100644 --- a/src/encryption/asymmetric/mod.rs +++ b/src/encryption/asymmetric/mod.rs @@ -1,2 +1,16 @@ //! Contains implementation of asymmetric cryptographic primitives like RSA encryption. +pub mod lwe; pub mod rsa; + +pub trait AsymmetricEncryption { + type PublicKey; + type PrivateKey; + type Plaintext; + type Ciphertext; + + /// Encrypts plaintext using key and returns ciphertext + fn encrypt(&self, plaintext: &Self::Plaintext) -> Self::Ciphertext; + + /// Decrypts ciphertext using key and returns plaintext + fn decrypt(&self, ciphertext: &Self::Ciphertext) -> Self::Plaintext; +} diff --git a/src/hashes/sha3.rs b/src/hashes/sha3.rs index 30d5899f..4d366fa3 100644 --- a/src/hashes/sha3.rs +++ b/src/hashes/sha3.rs @@ -41,6 +41,19 @@ const RHO: [[u32; 5]; 5] = 27, 20, 39, 8, 14, ]]; +/// Type alias for SHA3-224. +pub type Sha3_224 = Sha3<28>; +/// Type alias for SHA3-256. +pub type Sha3_256 = Sha3<32>; +/// Type alias for SHA3-384. +pub type Sha3_384 = Sha3<48>; +/// Type alias for SHA3-512. +pub type Sha3_512 = Sha3<64>; +/// Type alias for SHAKE128. +pub type Shake128 = Shake<128>; +/// Type alias for SHAKE256. +pub type Shake256 = Shake<256>; + #[derive(Clone, Debug)] struct KeccakState { lanes: [[u64; 5]; 5], @@ -277,19 +290,6 @@ impl Shake { } } -/// Type alias for SHA3-224. -pub type Sha3_224 = Sha3<28>; -/// Type alias for SHA3-256. -pub type Sha3_256 = Sha3<32>; -/// Type alias for SHA3-384. -pub type Sha3_384 = Sha3<48>; -/// Type alias for SHA3-512. -pub type Sha3_512 = Sha3<64>; -/// Type alias for SHAKE128. -pub type Shake128 = Shake<128>; -/// Type alias for SHAKE256. -pub type Shake256 = Shake<256>; - #[cfg(test)] mod tests { use hex_literal::hex; diff --git a/src/kem/README.md b/src/kem/README.md new file mode 100644 index 00000000..e104ac3e --- /dev/null +++ b/src/kem/README.md @@ -0,0 +1,23 @@ +# Key Encpasulation Mechanism + +## ToC +- What is KEM? +- How is it useful? +- Security + +KEM is a cryptographic algorithm, that is used to establish a shared key between two parties on a untrusted channel. + +KEM consists of three algorithms: +- $\text{KeyGen}$: Generates encapsulation and decapsulation keys +- $\text{Encaps}$: Encapsulates key and ciphertext that is sent to other part to generate shared key +- $\text{Decaps}$: Generates shared key from decapsulation key + +```mermaid +flowchart TB +a[Alice]-->k[KeyGen] +k--"Decapsulation Key"-->d[Decaps] +k--"Encapsulation Key"-->e[Encaps] +e--"ciphertext"-->d +d-->al[Alice's shared key] +e-->bo[Bob's shared key] +``` \ No newline at end of file diff --git a/src/kem/kyber/README.md b/src/kem/kyber/README.md new file mode 100644 index 00000000..b5735b9c --- /dev/null +++ b/src/kem/kyber/README.md @@ -0,0 +1,88 @@ +# Module Lattice Based KEM (FIPS 203) + +## ToC +- [ ] Intrduction +- [ ] Preliminaries +- [ ] Parameters +- [ ] Encryption +- [ ] KEM +- [ ] Optimisation +- [ ] Performance +- [ ] Security + + +## Introduction + +> [!NOTE] +> Implementation and terminilogy follows FIPS 203 that formalises PQ-KEM standard for the internet. + +ML-KEM is based on Module Learning with Errors problems introduced by [Reg05][Reg05]. + +## Preliminaries +- [ ] Why Polynomial Rings? +- [ ] Cyclotomic Polynomial Rings +- [ ] NTT +- [ ] Lattices + +## Implementation Details +- Polynomial Rings + - representation + - Arithmetic with Polynomials +- Matrices and Vectors + - representation + - Arithmetic +- cryptographic functions + - SHA3-256, SHA3-512 + - SHAKE128, SHAKE256 + - PRF: $\textsf{PRF}_{\eta}(s,b):=\text{SHAKE256}(s\|b,8\cdot 64 \cdot \eta)$ + - Hash function: + - $H: H(s) := \text{SHA3-256}(s)$, where $s\in \mathbb{B}^*$ + - $J: J(s) := \text{SHAKE256}(s,8\cdot 32)$, where $s\in \mathbb{B}^*$ + - $G:\ G(c \in \mathbb{B}^*) := \text{SHA3-512}(c) \in \mathbb{B}^{32}\times\mathbb{B}^{32}$ + - XOF: SHAKE128 +- General Functions: + - BitsToBytes, BytesToBits + - Compression, Decompression + - ByteEncode, ByteDecode +- Sampling algorithms: + - SampleNTT +- NTT, RevNTT +- MultiplyNTT +- BaseCaseMultiply + +## NTT: Number-Theroretic Transform + +- Special case of FFT performed over $\mathbb{Z}_q^*$. +- Most computationally intensive operation in encryption scheme is the multiplication of polynomials in $\mathcal{R}_{q,f}$ that is $\mathcal{O}(n^2)$, where $n$ is the degree of the polynomial. NTT allows to perform this in $\mathcal{O}(n\log n)$. +- Polynomial: $\mathbb{Z}_q[X]/(X^d+\alpha)$, where $d=2^k$ and $-\alpha$ is perfect square of $d\in\mathbb{Z}_q$ + - $X^d+\alpha\equiv(X^{d/2}-r)(X^{d/2}+r)\mod q$ +- Using, Chinese Remainder Theorem, $ab\in\mathbb{Z}_q/(X^d+\alpha)$ can be written as $ab \mod (X^{d/2}+r),\ ab \mod (X^{d/2}-r)$, and can be converted back to $\mod (X^d+\alpha)$. +- Doing this recursively, asymptotic complexity of the algorithm turns out to be $d\log d$ + +## Parameters +- $q=3329=2^8+1$ +- $f=X^{256}+1$ +- $k,\eta_1,\eta_2,d_u,d_v$ vary according to parameter sets. + +| | k | $\eta_1$ | $\eta_2$ | $d_u$ | $d_v$ | decryption error | pk size | ciphertext size | +| ---------- | --- | -------- | -------- | ----- | ----- | ---------------- | ------- | --------------- | +| Kyber-512 | 3 | 3 | 2 | 10 | 4 | $2^{−139}$ | 800B | 768B | +| Kyber-768 | 4 | 2 | 2 | 10 | 4 | $2^{−164}$ | 1184B | 1088B | +| Kyber-1024 | 5 | 2 | 2 | 11 | 5 | $2^{−174}$ | 1568B | 1568B | + +## CPA-secure Encryption Scheme +1. $\text{KeyGen}$ +2. + +## TODO + +- [ ] FIPS 202: SHA3-{256,512}, SHAKE{128,256} +- [ ] Rings, Polynomial Rings + - [ ] arithmetic for rings + - [ ] + +## References +- [Module-Lattice-Based Key-Encapsulation Mechanism Standard](https://nvlpubs.nist.gov/nistpubs/FIPS/NIST.FIPS.203.pdf) +- [RustCrypto/KEMs](https://github.com/RustCrypto/KEMs): const functions inspiration + +[Reg05]: diff --git a/src/kem/kyber/algebra.rs b/src/kem/kyber/algebra.rs new file mode 100644 index 00000000..e69de29b diff --git a/src/kem/kyber/auxiliary.rs b/src/kem/kyber/auxiliary.rs new file mode 100644 index 00000000..1a60fc82 --- /dev/null +++ b/src/kem/kyber/auxiliary.rs @@ -0,0 +1,49 @@ +//! Contains auxiliary cryptographic functions for Kyber KEM. +//! - [`prf`] - Pseudorandom function +//! - [`h`], [`g`] - Hash function +//! - [`Xof`] - Extendable output function + +use crate::hashes::sha3::{Sha3_256, Sha3_512, Shake128, Shake256}; + +pub fn prf(s: &[u8], b: u8) -> [u8; 64 * ETA] { + assert!(s.len() == 32); + + let mut hasher = Shake256::new(); + hasher.update(s); + hasher.update(&[b]); + let mut res = [0u8; 64 * ETA]; + hasher.squeeze(&mut res); + res +} + +pub fn h(s: &[u8]) -> [u8; 32] { + let mut hasher = Sha3_256::new(); + hasher.update(s); + hasher.finalize() +} + +pub fn j(s: &[u8]) -> [u8; 32] { + let mut hasher = Shake256::new(); + hasher.update(s); + let mut res = [0u8; 32]; + hasher.squeeze(&mut res); + res +} + +pub fn g(c: &[u8]) -> ([u8; 32], [u8; 32]) { + let mut hasher = Sha3_512::new(); + hasher.update(c); + let res = hasher.finalize(); + let (h0, h1) = res.split_at(32); + (h0.try_into().unwrap(), h1.try_into().unwrap()) +} + +pub struct Xof(Shake128); + +impl Xof { + pub fn init() -> Self { Self(Shake128::new()) } + + pub fn absorb(&mut self, input: &[u8]) { self.0.update(input); } + + pub fn squeeze(&mut self, output: &mut [u8]) { self.0.squeeze(output); } +} diff --git a/src/kem/kyber/compress.rs b/src/kem/kyber/compress.rs new file mode 100644 index 00000000..10f78b6d --- /dev/null +++ b/src/kem/kyber/compress.rs @@ -0,0 +1,80 @@ +use super::{MlKemField, PolyVec}; +use crate::{ + algebra::Finite, + polynomial::{Monomial, Polynomial}, +}; + +/// Compresses a number x to a number in the range [0, 2^d) using the formula round((2^d / q) * x) +/// mod 2^d. +/// round(a / b) = floor((a + b/2) / b) +pub fn compress_fieldelement(x: &MlKemField) -> MlKemField { + // TODO: Implement using barrett reduction + let q_half = (MlKemField::ORDER + 1) >> 1; + MlKemField::new((((x.value << d) + q_half) / MlKemField::ORDER) % (1 << d)) +} + +/// Decompresses a number y to a number in the range [0, q) using the formula round((q / 2^d)) * y. +pub fn decompress_fieldelement(y: &MlKemField) -> MlKemField { + let d_pow_half = 1 << (d - 1); + let quotient = MlKemField::ORDER * y.value + d_pow_half; + MlKemField::new(quotient >> d) +} + +pub fn poly_compress( + poly: &Polynomial, +) -> Polynomial { + // TODO: remove unwrap + let coeffs = poly + .coefficients + .iter() + .map(compress_fieldelement::) + .collect::>() + .try_into() + .unwrap(); + + Polynomial::::new(coeffs) +} + +pub fn poly_decompress( + poly: &[MlKemField; D], +) -> Polynomial { + let mut coefficients = [MlKemField::default(); D]; + for (i, x) in poly.iter().enumerate() { + coefficients[i] = decompress_fieldelement::(x); + } + Polynomial::::new(coefficients) +} + +pub fn polyvec_compress( + poly_vec: &PolyVec, +) -> PolyVec { + let mut res = Vec::with_capacity(K); + + for poly in poly_vec.vec.iter() { + res.push(poly_compress::(poly)); + } + + let res = res.try_into().unwrap(); + PolyVec::new(res) +} + +pub fn polyvec_decompress( + poly_vec: &PolyVec, +) -> PolyVec { + let mut res = Vec::with_capacity(K); + + for poly in poly_vec.vec.iter() { + res.push(poly_decompress::(&poly.coefficients)); + } + + let res = res.try_into().unwrap(); + PolyVec::new(res) +} + +#[test] +fn test_compress_decompress() { + let x = MlKemField::new(10); + let z = decompress_fieldelement::<8>(&x); + let y = compress_fieldelement::<8>(&z); + assert_eq!(x, y); +} diff --git a/src/kem/kyber/encode.rs b/src/kem/kyber/encode.rs new file mode 100644 index 00000000..212d9122 --- /dev/null +++ b/src/kem/kyber/encode.rs @@ -0,0 +1,136 @@ +use super::{MlKemField, PolyVec}; +use crate::{ + algebra::field::Field, + polynomial::{Basis, Polynomial}, +}; + +/// Encodes a field element into a byte array where each field element is represented by d bits. +/// Converts the field element into a binary representation and then packs the bits into bytes. +pub fn byte_encode(f: &[MlKemField; D]) -> Vec { + let mut encoded_bits = Vec::with_capacity(D * d); + + for (i, x) in f.iter().enumerate() { + let mut val = x.value; + for j in 0..d { + encoded_bits[i * d + j] = (val & 1) as u8; + val >>= 1; + } + } + + let mut encoded_bytes = Vec::with_capacity(D / 8 * d); + for (i, chunk) in encoded_bits.chunks(8).enumerate() { + encoded_bytes[i] = chunk.iter().enumerate().fold(0, |acc, (j, &b)| acc | (b << j)); + } + + encoded_bytes +} + +pub fn byte_encode_polyvec( + f: PolyVec, +) -> Vec { + let mut encoded_bytes = Vec::with_capacity(D / 8 * d * K); + for (i, poly) in f.vec.iter().enumerate() { + let encoded = byte_encode::(&poly.coefficients); + encoded_bytes[i * D / 8 * d..(i + 1) * D / 8 * d].copy_from_slice(&encoded); + } + + encoded_bytes +} + +/// Encodes a field element into a byte array where each field element is represented by D bits. +/// Converts the field element into a binary representation and then packs the bits into bytes. +fn byte_encode_optimized(f: [MlKemField; 256]) -> [u8; 32 * D] +where [(); 256 * D]: { + let mut encoded_bytes = [0u8; 32 * D]; + + // Process 8 field elements at a time to fill each D bytes + for chunk_idx in 0..32 { + let chunk_start = chunk_idx * 8; + let elements = &f[chunk_start..chunk_start + 8]; + + // For each bit position within the D bits + for bit_pos in 0..D { + let byte_idx = chunk_idx * D + bit_pos; + + // Collect bits from 8 elements into a single byte + encoded_bytes[byte_idx] = elements + .iter() + .enumerate() + .fold(0, |acc, (j, x)| acc | (((x.value >> bit_pos) & 1) as u8) << j); + } + } + + encoded_bytes +} + +/// Decodes a byte array into a field element where each field element is represented by D bits. +/// Unpacks the bytes into bits and then converts the bits into a field element. +pub fn byte_decode(encoded_bytes: &[u8]) -> [MlKemField; D] { + let mut encoded_bits = Vec::with_capacity(256 * d); + for (i, &byte) in encoded_bytes.iter().enumerate() { + for j in 0..8 { + encoded_bits.push((byte >> j) & 1); + } + } + + let mask: usize = (1 << d) - 1; + let mut f = [MlKemField::ZERO; D]; + for (i, chunk) in encoded_bits.chunks(D).enumerate() { + let mut val = 0; + for (j, &bit) in chunk.iter().enumerate() { + val |= ((bit as usize) << j) & mask; + } + f[i].value = val; + } + + f +} + +pub fn byte_decode_polyvec( + encoded_bytes: &[u8], + basis: B, +) -> PolyVec { + let mut f = Vec::with_capacity(K); + + for bytes in encoded_bytes.chunks(32 * d) { + let coeffs = byte_decode::(bytes); + f.push(Polynomial { coefficients: coeffs, basis: basis.clone() }) + } + + let f = f.try_into().unwrap(); + PolyVec::new(f) +} + +#[cfg(test)] +mod tests { + use super::*; + extern crate test; + + fn generate_test_data() -> [MlKemField; 256] { + let mut data = [MlKemField { value: 0 }; 256]; + for i in 0..256 { + data[i].value = i; + } + data + } + + #[test] + fn test_byte_encode() { + let f = generate_test_data(); + let encoded = byte_encode::<8, 256>(&f); + let encoded_optimized = byte_encode_optimized::<8>(f); + assert_eq!(encoded, encoded_optimized); + } + + #[bench] + fn bench_byte_encode(b: &mut test::Bencher) { + let f = generate_test_data(); + b.iter(|| byte_encode::<8, 256>(&f)); + } + + #[bench] + fn bench_byte_encode_optimized(b: &mut test::Bencher) { + let f = generate_test_data(); + b.iter(|| byte_encode_optimized::<8>(f)); + } +} diff --git a/src/kem/kyber/kpke.rs b/src/kem/kyber/kpke.rs new file mode 100644 index 00000000..728340c0 --- /dev/null +++ b/src/kem/kyber/kpke.rs @@ -0,0 +1,10 @@ +// use super::{auxiliary::g, MatrixPolyVec}; + +// pub struct KpkeEncryptionKey([u8; 384*K+32]); +// pub struct KpkeDecryptionKey([u8; 384*K]); +// pub fn kpke_keygen(d: [u8; 32]) -> (KpkeEncryptionKey, KpkeDecryptionKey) { +// let (rho, sigma) = g(d, K); + +// let mut a_ntt = MatrixPolyVec:: + +// } diff --git a/src/kem/kyber/mod.rs b/src/kem/kyber/mod.rs new file mode 100644 index 00000000..749213ca --- /dev/null +++ b/src/kem/kyber/mod.rs @@ -0,0 +1,491 @@ +use std::ops::{Add, AddAssign, Mul}; + +use auxiliary::{g, h}; +use compress::{poly_compress, poly_decompress, polyvec_compress, polyvec_decompress}; +use encode::{byte_decode, byte_decode_polyvec, byte_encode, byte_encode_polyvec}; +use rand::Rng; + +use crate::{ + algebra::field::prime::PrimeField, + polynomial::{Basis, Monomial, Polynomial}, +}; + +mod auxiliary; +mod compress; +mod encode; +// mod kpke; +mod ntt; +mod sampling; +// #[cfg(test)] mod tests; + +pub const MLKEM_Q: usize = 3329; +pub const MLKEM_N: usize = 256; + +pub struct KPke { + pub eta1: usize, + pub eta2: usize, + pub du: usize, + pub dv: usize, +} + +pub type MlKemField = PrimeField; + +#[derive(Debug, Clone, Copy)] +pub struct Ntt; +impl Basis for Ntt { + type Data = (); +} + +pub type PolyRing = Polynomial; +pub type NttPolyRing = Polynomial; + +impl Add for &NttPolyRing { + type Output = NttPolyRing; + + fn add(self, rhs: Self) -> Self::Output { + let coeffs = self + .coefficients + .iter() + .zip(rhs.coefficients.iter()) + .map(|(&a, &b)| a + b) + .collect::>() + .try_into() + .unwrap(); + NttPolyRing::new(coeffs) + } +} + +impl AddAssign for NttPolyRing { + fn add_assign(&mut self, rhs: Self) { + for i in 0..D { + self.coefficients[i] += rhs.coefficients[i]; + } + } +} + +impl NttPolyRing { + pub fn new(coeffs: [MlKemField; D]) -> Self { Self { coefficients: coeffs, basis: Ntt } } +} + +#[derive(Debug, Clone, Copy)] +pub struct PolyVec { + // TODO: might need to be written as [Polynomial; K] + pub vec: [Polynomial; K], +} + +impl PolyVec { + pub fn new(vec: [Polynomial; K]) -> Self { Self { vec } } +} + +impl PolyVec { + pub fn dot_product(&self, rhs: &Self) -> NttPolyRing { + let mut result = NttPolyRing::new([MlKemField::ZERO; D]); + for i in 0..K { + result += &self.vec[i] * &rhs.vec[i]; + } + result + } +} + +impl Add for PolyVec { + type Output = Self; + + fn add(self, rhs: Self) -> Self::Output { + let vec = (0..K) + .map(|i| &self.vec[i] + &rhs.vec[i]) + .collect::>>() + .try_into() + .unwrap(); + Self::new(vec) + } +} + +impl Add for PolyVec { + type Output = Self; + + fn add(self, rhs: Self) -> Self::Output { + let vec = (0..K) + .map(|i| self.vec[i] + rhs.vec[i]) + .collect::>>() + .try_into() + .unwrap(); + Self::new(vec) + } +} + +// impl Iterator for PolyVec { +// type Item = Polynomial; + +// fn next(&mut self) -> Option { self.vec.iter().next().copied() } +// } + +/// A matrix of polynomial vectors of size K \times K. +pub struct MatrixPolyVec { + pub vec: [PolyVec; K], +} + +impl MatrixPolyVec { + /// Creates a new matrix of polynomial vectors. + pub fn new(vec: [PolyVec; K]) -> Self { Self { vec } } +} + +impl MatrixPolyVec { + pub fn transpose(&self) -> Self { + let poly_vec = (0..K) + .map(|col_idx| { + let vec = (0..K) + .map(|row_idx| self.vec[row_idx].vec[col_idx].clone()) + .collect::>>() + .try_into() + .unwrap(); + PolyVec::new(vec) + }) + .collect::>>() + .try_into() + .unwrap(); + + Self::new(poly_vec) + } +} + +use crate::algebra::field::Field; + +impl Mul<&PolyVec> for MatrixPolyVec { + type Output = PolyVec; + + fn mul(self, rhs: &PolyVec) -> Self::Output { + let res = (0..K) + .map(|i| { + let mut sum = NttPolyRing::new([MlKemField::ZERO; D]); + for j in 0..K { + sum += &self.vec[i].vec[j] * &rhs.vec[j]; + } + sum + }) + .collect::>>() + .try_into() + .unwrap(); + + PolyVec::new(res) + } +} + +pub struct KPkeEncryptionKey([u8; 384 * K + 32]) +where [(); 384 * K + 32]:; +pub struct KPkeDecryptionKey([u8; 384 * K]) +where [(); 384 * K]:; + +impl KPke { + pub fn new(eta1: usize, eta2: usize, du: usize, dv: usize) -> Self { Self { eta1, eta2, du, dv } } + + pub fn pke_keygen( + &self, + d: [u8; 32], + ) -> (KPkeEncryptionKey, KPkeDecryptionKey) + where + [(); 64 * eta1]:, + [(); 64 * eta1 * 8]:, + [(); 384 * K + 32]:, + [(); 384 * K]:, + { + let mut input = d.to_vec(); + input.push(K as u8); + let (rho, sigma) = auxiliary::g(&input); + + let mut n = 0; + + let a_hat_vec: [PolyVec; K] = (0..K) + .map(|i| { + let vec: [NttPolyRing; K] = (0..K) + .map(|j| { + let coeffs = sampling::sample_ntt(&rho, j as u8, i as u8); + NttPolyRing::new(coeffs) + }) + .collect::>>() + .try_into() + .unwrap(); + PolyVec::new(vec) + }) + .collect::>>() + .try_into() + .unwrap(); + let a_hat: MatrixPolyVec = MatrixPolyVec::new(a_hat_vec); + + let vec: [PolyRing; K] = (0..K) + .map(|_| { + let seed = auxiliary::prf::(&sigma, n); + n += 1; + let coeffs = sampling::sample_poly_cbd::(seed); + PolyRing::new(coeffs) + }) + .collect::>>() + .try_into() + .unwrap(); + let s: PolyVec = PolyVec::new(vec); + + let vec: [PolyRing; K] = (0..K) + .map(|_| { + let seed = auxiliary::prf::(&sigma, n); + n += 1; + let coeffs = sampling::sample_poly_cbd::(seed); + PolyRing::new(coeffs) + }) + .collect::>>() + .try_into() + .unwrap(); + let e: PolyVec = PolyVec::new(vec); + + let s_hat = s.ntt(); + let e_hat = e.ntt(); + + let t_hat = a_hat * &s_hat + e_hat; + + let ek = [byte_encode_polyvec::(t_hat), rho.to_vec()] + .concat() + .try_into() + .unwrap(); + let dk = byte_encode_polyvec::(s_hat).try_into().unwrap(); + + (KPkeEncryptionKey(ek), KPkeDecryptionKey(dk)) + } + + pub fn kpke_encrypt( + &self, + encryption_key: KPkeEncryptionKey, + m: [u8; 32], + r: [u8; 32], + ) -> Vec + where + [(); 384 * K + 32]:, + [(); 64 * eta1]:, + [(); 64 * eta1 * 8]:, + [(); 64 * eta2]:, + [(); 64 * eta2 * 8]:, + { + let mut n = 0; + let (t_hat, rho) = encryption_key.0.split_at(384 * K); + let t_hat = byte_decode_polyvec::(t_hat, Ntt); + let rho: [u8; 32] = rho.try_into().unwrap(); + + let a_hat_vec: [PolyVec; K] = (0..K) + .map(|i| { + let vec: [NttPolyRing; K] = (0..K) + .map(|j| { + let coeffs = sampling::sample_ntt(&rho, j as u8, i as u8); + NttPolyRing::new(coeffs) + }) + .collect::>>() + .try_into() + .unwrap(); + PolyVec::new(vec) + }) + .collect::>>() + .try_into() + .unwrap(); + let a_hat: MatrixPolyVec = MatrixPolyVec::new(a_hat_vec); + + let vec: [PolyRing; K] = (0..K) + .map(|_| { + let seed = auxiliary::prf::(&r, n); + n += 1; + let coeffs = sampling::sample_poly_cbd::(seed); + PolyRing::new(coeffs) + }) + .collect::>>() + .try_into() + .unwrap(); + let y: PolyVec = PolyVec::new(vec); + + let vec: [PolyRing; K] = (0..K) + .map(|_| { + let seed = auxiliary::prf::(&r, n); + n += 1; + let coeffs = sampling::sample_poly_cbd::(seed); + PolyRing::new(coeffs) + }) + .collect::>>() + .try_into() + .unwrap(); + let e1: PolyVec = PolyVec::new(vec); + + let e2 = PolyRing::new(sampling::sample_poly_cbd::(auxiliary::prf::(&r, n))); + + let y_hat = y.ntt(); + + let u = (a_hat.transpose() * &y_hat).ntt_inv() + e1; + + let mu = poly_decompress::(&byte_decode::<1, MLKEM_N>(&m)); + + let v = t_hat.dot_product(&y_hat).ntt_inv() + e2 + mu; + // let c1 = poly_compress::(&u); + let c1_compressed = polyvec_compress::(&u); + let c1 = byte_encode_polyvec::(c1_compressed); + let c2_compressed = poly_compress::(&v); + let c2 = byte_encode::(&c2_compressed.coefficients); + [c1, c2].concat() + } + + pub fn kpke_decrypt( + &self, + decryption_key: KPkeDecryptionKey, + c: Vec, + ) -> [u8; 32] + where + [(); 384 * K]:, + [(); 64 * eta1]:, + [(); 64 * eta1 * 8]:, + [(); 64 * eta2]:, + [(); 64 * eta2 * 8]:, + { + let (c1, c2) = c.split_at(384 * K); + let c1 = byte_decode_polyvec::(c1, Monomial); + let u_prime = polyvec_decompress::(&c1); + let v_prime = poly_decompress::(&byte_decode::(c2)); + + let s_ntt = byte_decode_polyvec::(&decryption_key.0, Ntt); + let w = v_prime - s_ntt.dot_product(&u_prime.ntt()).ntt_inv(); + let m = byte_encode::<1, MLKEM_N>(&poly_compress::(&w).coefficients); + m.try_into().unwrap() + } +} + +pub struct MlKemEncapsKey([u8; 384 * K + 32]) +where [(); 384 * K + 32]:; +pub struct MlKemDecapsKey +where [(); 384 * K + 32]: { + dk_pke: KPkeDecryptionKey, + ek_pke: KPkeEncryptionKey, + h: [u8; 32], + z: [u8; 32], +} + +pub struct MlKem { + kpke: KPke, +} + +impl MlKem { + fn keygen_internal( + &self, + d: [u8; 32], + z: [u8; 32], + ) -> (MlKemEncapsKey, MlKemDecapsKey) + where + [(); 384 * K + 32]:, + [(); 768 * K + 96]:, + [(); 64 * eta1 * 8]:, + { + let (ek_pke, dk_pke) = self.kpke.pke_keygen::(d); + let ek: [u8; 384 * K + 32] = ek_pke.0; + let h = h(&ek); + + (MlKemEncapsKey(ek), MlKemDecapsKey { dk_pke, ek_pke, h, z }) + } + + fn encaps_internal( + &self, + ek: MlKemEncapsKey, + m: [u8; 32], + ) -> ([u8; 32], Vec) + where + [(); 384 * K + 32]:, + [(); 64 * eta1]:, + [(); 64 * eta1 * 8]:, + [(); 64 * eta2]:, + [(); 64 * eta2 * 8]:, + { + let (k, r) = g([m, h(&ek.0)].concat().as_ref()); + let c = self.kpke.kpke_encrypt::(KPkeEncryptionKey(ek.0), m, r); + (k, c) + } + + fn decaps_internal( + &self, + dk: MlKemDecapsKey, + c: Vec, + ) -> [u8; 32] + where + [(); 384 * K + 32]:, + [(); 64 * eta1]:, + [(); 64 * eta1 * 8]:, + [(); 64 * eta2]:, + [(); 64 * eta2 * 8]:, + { + let m_prime = self.kpke.kpke_decrypt::(dk.dk_pke, c.clone()); + let (mut k_prime, r_prime) = g([m_prime, dk.h].concat().as_ref()); + + let mut pseudo_input = dk.z.to_vec(); + pseudo_input.extend(c.clone()); + let k_bar = auxiliary::j(&pseudo_input); + let c_prime = self.kpke.kpke_encrypt::(dk.ek_pke, m_prime, r_prime); + if c != c_prime { + k_prime = k_bar; + } + k_prime + } + + pub fn keygen( + &self, + ) -> Result<(MlKemEncapsKey, MlKemDecapsKey), String> + where + [(); 384 * K + 32]:, + [(); 768 * K + 96]:, + [(); 64 * eta1 * 8]:, { + // TODO: check if rng is good to use + let mut rng = rand::thread_rng(); + let d: [u8; 32] = rng.gen::<[u8; 32]>(); + let z = rand::thread_rng().gen::<[u8; 32]>(); + if d.is_empty() || z.is_empty() { + return Err("Error: keygen failed".to_string()); + } + Ok(self.keygen_internal::(d, z)) + } + + pub fn encaps( + &self, + ek: MlKemEncapsKey, + m: [u8; 32], + ) -> Result<([u8; 32], Vec), String> + where + [(); 384 * K + 32]:, + [(); 64 * eta1]:, + [(); 64 * eta1 * 8]:, + [(); 64 * eta2]:, + [(); 64 * eta2 * 8]:, + { + // TODO: run encaps key check + if m.is_empty() { + return Err("Error: encaps failed".to_string()); + } + Ok(self.encaps_internal::(ek, m)) + } + + pub fn decaps( + &self, + dk: MlKemDecapsKey, + c: Vec, + ) -> Result<[u8; 32], String> + where + [(); 384 * K + 32]:, + [(); 64 * eta1]:, + [(); 64 * eta1 * 8]:, + [(); 64 * eta2]:, + [(); 64 * eta2 * 8]:, + { + if c.len() != 32 * (du * K + dv) { + return Err("Error: invalid ciphertext".to_string()); + } + + Ok(self.decaps_internal::(dk, c)) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn poly_ring() { + let coeffs = [MlKemField::new(1); 256]; + let poly: PolyRing<256> = PolyRing::new(coeffs); + } +} diff --git a/src/kem/kyber/ntt.rs b/src/kem/kyber/ntt.rs new file mode 100644 index 00000000..f0e00d3e --- /dev/null +++ b/src/kem/kyber/ntt.rs @@ -0,0 +1,129 @@ +use std::ops::Mul; + +use super::{MlKemField, Ntt, PolyVec}; +use crate::{ + algebra::{field::Field, Finite}, + polynomial::{Monomial, Polynomial}, +}; + +const ZETA: usize = 17; +const fn bitrev_7(x: usize) -> usize { + (x >> 6) & 1 + | ((x >> 5) & 1) << 1 + | ((x >> 4) & 1) << 2 + | ((x >> 3) & 1) << 3 + | ((x >> 2) & 1) << 4 + | ((x >> 1) & 1) << 5 + | (x & 1) << 6 +} + +const ZETA_POW: [MlKemField; 128] = { + let mut i = 0; + let mut curr = 1; + let mut zeta_pow = [MlKemField::ZERO; 128]; + while i < 128 { + zeta_pow[i] = MlKemField::new(curr); + curr = curr * ZETA % MlKemField::ORDER; + i += 1; + } + + let mut zeta_pow_rev = [MlKemField::ZERO; 128]; + let mut i = 0; + while i < 128 { + zeta_pow_rev[i] = zeta_pow[bitrev_7(i)]; + i += 1; + } + + zeta_pow_rev +}; + +const GAMMA: [MlKemField; 128] = { + let mut gamma = [MlKemField::ZERO; 128]; + let mut i = 0; + while i < 128 { + gamma[i] = + MlKemField { value: (ZETA_POW[i].value * ZETA_POW[i].value * ZETA) % MlKemField::ORDER }; + i += 1; + } + + gamma +}; + +impl Polynomial { + pub fn ntt(self) -> Polynomial { + let mut f_ntt_coeffs = [MlKemField::ZERO; D]; + let mut i = 1; + let mut len = 128; + while len >= 2 { + for start in (0..256).step_by(len * 2) { + let zeta = ZETA_POW[i]; + i += 1; + for j in start..start + len { + let t = zeta * self.coefficients[j + len]; + f_ntt_coeffs[j + len] = self.coefficients[j] - t; + f_ntt_coeffs[j] = self.coefficients[j] + t; + } + } + len >>= 1; + } + + Polynomial:: { coefficients: f_ntt_coeffs, basis: Ntt } + } +} + +impl PolyVec { + pub fn ntt(self) -> PolyVec { + let ntt_vec = self.vec.iter().map(|poly| poly.ntt()).collect::>().try_into().unwrap(); + + PolyVec:: { vec: ntt_vec } + } +} + +impl PolyVec { + pub fn ntt_inv(self) -> PolyVec { + let ntt_inv_vec = + self.vec.iter().map(|poly| poly.ntt_inv()).collect::>().try_into().unwrap(); + + PolyVec { vec: ntt_inv_vec } + } +} + +impl Mul<&Polynomial> for &Polynomial { + type Output = Polynomial; + + fn mul(self, rhs: &Polynomial) -> Self::Output { + let mut res_coeffs = [MlKemField::ZERO; D]; + + for i in 0..D >> 1 { + let (a0, a1) = (self.coefficients[2 * i], self.coefficients[2 * i + 1]); + let (b0, b1) = (rhs.coefficients[2 * i], rhs.coefficients[2 * i + 1]); + let (c0, c1) = (a0 * b0 + GAMMA[i] * a1 * b1, a0 * b1 + a1 * b0); + res_coeffs[2 * i] = c0; + res_coeffs[2 * i + 1] = c1; + } + + Polynomial:: { coefficients: res_coeffs, basis: Ntt } + } +} + +impl Polynomial { + pub fn ntt_inv(self) -> Polynomial { + let mut f_coeffs = [MlKemField::ZERO; D]; + let mut i = 127; + let mut len = 2; + while len <= 128 { + for start in (0..256).step_by(2 * len) { + let zeta = ZETA_POW[i]; + i -= 1; + for j in start..start + len { + let t = self.coefficients[j]; + f_coeffs[j] = t + self.coefficients[j + len]; + f_coeffs[j + len] = zeta * (self.coefficients[j + len] - t); + } + } + len <<= 1; + } + + Polynomial:: { coefficients: f_coeffs, basis: Monomial } + } +} diff --git a/src/kem/kyber/sampling.rs b/src/kem/kyber/sampling.rs new file mode 100644 index 00000000..eb48398f --- /dev/null +++ b/src/kem/kyber/sampling.rs @@ -0,0 +1,52 @@ +use super::{auxiliary::Xof, MlKemField}; +use crate::algebra::{field::Field, Finite}; + +pub fn sample_ntt(rho: &[u8], j: u8, i: u8) -> [MlKemField; 256] { + assert!(rho.len() == 32); + let mut input = [0u8; 34]; + input[..32].copy_from_slice(rho); + input[32] = j; + input[33] = i; + + let mut ntt = [MlKemField::ZERO; 256]; + + let mut xof = Xof::init(); + xof.absorb(&input); + let mut j = 0; + while j < 256 { + let mut buf = [0u8; 3]; + xof.squeeze(&mut buf); + + let d_1 = buf[0] as usize + ((buf[1] as usize & 0xf) << 8); + let d_2 = (buf[1] >> 4) as usize + ((buf[2] as usize) << 4); + + if d_1 < MlKemField::ORDER { + ntt[j] = MlKemField::new(d_1); + j += 1; + } + if d_2 < MlKemField::ORDER && j < 256 { + ntt[j] = MlKemField::new(d_2); + j += 1; + } + } + ntt +} + +pub fn sample_poly_cbd(seed: [u8; 64 * ETA]) -> [MlKemField; 256] +where [(); 64 * ETA * 8]: { + let mut bit_encode = [0u8; 64 * ETA * 8]; + for i in 0..64 * ETA { + for j in 0..8 { + bit_encode[i * 8 + j] = (seed[i] >> j) & 1; + } + } + + let mut res = [MlKemField::ZERO; 256]; + for i in 0..256 { + let x = (0..ETA).fold(0, |acc, j| acc + bit_encode[2 * i * ETA + j]); + let y = (0..ETA).fold(0, |acc, j| acc + bit_encode[(2 * i + 1) * ETA + j]); + res[i] = MlKemField::new((x - y) as usize); + } + + res +} diff --git a/src/kem/kyber/tests.rs b/src/kem/kyber/tests.rs new file mode 100644 index 00000000..e69de29b diff --git a/src/kem/mod.rs b/src/kem/mod.rs new file mode 100644 index 00000000..19c6bc24 --- /dev/null +++ b/src/kem/mod.rs @@ -0,0 +1 @@ +pub mod kyber; diff --git a/src/lib.rs b/src/lib.rs index 7acc5a38..527e0c0a 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -13,6 +13,7 @@ //! operations such as point addition, scalar multiplication, and pairing operations. //! - Compiler: Simple DSL to write circuits which can be compiled to polynomials used in PLONK. +#![feature(test)] #![allow(incomplete_features)] #![feature(effects)] #![feature(const_trait_impl)] @@ -21,7 +22,6 @@ #![feature(const_option)] #![feature(generic_const_exprs)] #![feature(specialization)] -#![feature(test)] #![warn(missing_docs)] pub mod algebra; @@ -33,6 +33,7 @@ pub mod dsa; pub mod encryption; pub mod hashes; pub mod hmac; +pub mod kem; pub mod kzg; pub mod multi_var_poly; pub mod polynomial; diff --git a/src/polynomial/arithmetic.rs b/src/polynomial/arithmetic.rs index fa3c99b7..6371e808 100644 --- a/src/polynomial/arithmetic.rs +++ b/src/polynomial/arithmetic.rs @@ -34,6 +34,27 @@ impl Add Add<&Polynomial> + for &Polynomial +{ + type Output = Polynomial; + + /// Implements addition of two polynomials by adding their coefficients. + /// Note: degree of first operand > deg of second operand. + fn add(self, rhs: &Polynomial) -> Self::Output { + let coefficients = self + .coefficients + .iter() + .zip(rhs.coefficients.iter().chain(std::iter::repeat(&F::ZERO))) + .map(|(&a, &b)| a + b) + .take(D) + .collect::>() + .try_into() + .unwrap_or_else(|v: Vec| panic!("Expected a Vec of length {} but it was {}", D, v.len())); + Self::Output { coefficients, basis: self.basis } + } +} + impl AddAssign> for Polynomial { diff --git a/src/polynomial/mod.rs b/src/polynomial/mod.rs index 378ccfc1..fbb2892c 100644 --- a/src/polynomial/mod.rs +++ b/src/polynomial/mod.rs @@ -17,7 +17,7 @@ //! - Includes Discrete Fourier Transform (DFT) for polynomials in the [`Monomial`] basis to convert //! into the [`Lagrange`] basis via evaluation at the roots of unity. -use std::array; +use std::{array, fmt::Debug}; use super::*; use crate::algebra::field::FiniteField; @@ -45,7 +45,7 @@ pub struct Polynomial { /// [`Basis`] trait is used to specify the basis of the polynomial. /// The basis can be [`Monomial`] or [`Lagrange`]. This is a type-state pattern for [`Polynomial`]. -pub trait Basis { +pub trait Basis: Debug + Clone { /// The associated data type for the basis. type Data; }