diff --git a/src/androidTest/java/com/afkanerd/.DS_Store b/src/androidTest/java/com/afkanerd/.DS_Store new file mode 100644 index 0000000..e4cabc0 Binary files /dev/null and b/src/androidTest/java/com/afkanerd/.DS_Store differ diff --git a/src/androidTest/java/com/afkanerd/smswithoutborders/.DS_Store b/src/androidTest/java/com/afkanerd/smswithoutborders/.DS_Store new file mode 100644 index 0000000..68565d0 Binary files /dev/null and b/src/androidTest/java/com/afkanerd/smswithoutborders/.DS_Store differ diff --git a/src/androidTest/java/com/afkanerd/smswithoutborders/libsignal_doubleratchet/SecurityRSATest.java b/src/androidTest/java/com/afkanerd/smswithoutborders/libsignal_doubleratchet/SecurityRSATest.java index ec9ebd2..98487f6 100644 --- a/src/androidTest/java/com/afkanerd/smswithoutborders/libsignal_doubleratchet/SecurityRSATest.java +++ b/src/androidTest/java/com/afkanerd/smswithoutborders/libsignal_doubleratchet/SecurityRSATest.java @@ -43,12 +43,12 @@ public void testCanStoreAndEncrypt() throws NoSuchAlgorithmException, NoSuchProv // .build()); // // KeyPair keyPair = kpg.generateKeyPair(); - PublicKey publicKey = SecurityRSA.generateKeyPair(keystoreAlias, 2048); - KeyPair keyPair = KeystoreHelpers.getKeyPairFromKeystore(keystoreAlias); - - SecretKey secretKey = SecurityAES.generateSecretKey(256); - byte[] cipherText = SecurityRSA.encrypt(keyPair.getPublic(), secretKey.getEncoded()); - byte[] plainText = SecurityRSA.decrypt(keyPair.getPrivate(), cipherText); - assertArrayEquals(secretKey.getEncoded(), plainText); +// PublicKey publicKey = SecurityRSA.generateKeyPair(keystoreAlias, 2048); +// KeyPair keyPair = KeystoreHelpers.getKeyPairFromKeystore(keystoreAlias); +// +// SecretKey secretKey = SecurityAES.generateSecretKey(256); +// byte[] cipherText = SecurityRSA.encrypt(keyPair.getPublic(), secretKey.getEncoded()); +// byte[] plainText = SecurityRSA.decrypt(keyPair.getPrivate(), cipherText); +// assertArrayEquals(secretKey.getEncoded(), plainText); } } diff --git a/src/androidTest/java/com/afkanerd/smswithoutborders/libsignal_doubleratchet/SecurityX25519Test.kt b/src/androidTest/java/com/afkanerd/smswithoutborders/libsignal_doubleratchet/SecurityX25519Test.kt index 52a1f89..b8fa7c2 100644 --- a/src/androidTest/java/com/afkanerd/smswithoutborders/libsignal_doubleratchet/SecurityX25519Test.kt +++ b/src/androidTest/java/com/afkanerd/smswithoutborders/libsignal_doubleratchet/SecurityX25519Test.kt @@ -1,12 +1,50 @@ package com.afkanerd.smswithoutborders.libsignal_doubleratchet +import android.security.keystore.KeyGenParameterSpec +import android.security.keystore.KeyProperties import androidx.test.filters.SmallTest -import junit.framework.TestCase.assertEquals import org.junit.Assert.assertArrayEquals import org.junit.Test +import java.security.KeyPairGenerator +import java.security.KeyStore +import java.security.Signature @SmallTest class SecurityX25519Test { + @Test + fun keystoreEd25519() { + val keystoreAlias = "keystoreAlias" + val kpg: KeyPairGenerator = KeyPairGenerator.getInstance( + KeyProperties.KEY_ALGORITHM_EC, + "AndroidKeyStore" + ) + val parameterSpec: KeyGenParameterSpec = KeyGenParameterSpec.Builder( + keystoreAlias, + KeyProperties.PURPOSE_SIGN or KeyProperties.PURPOSE_VERIFY + ).run { + setDigests(KeyProperties.DIGEST_SHA256, KeyProperties.DIGEST_SHA512) + build() + } + + kpg.initialize(parameterSpec) + val kp = kpg.generateKeyPair() + + val ks: KeyStore = KeyStore.getInstance("AndroidKeyStore").apply { + load(null) + } + val entry: KeyStore.Entry = ks.getEntry(keystoreAlias, null) + if (entry !is KeyStore.PrivateKeyEntry) { + throw Exception("No instance of keystore") + } + + val data = "Hello world".encodeToByteArray() + val signature: ByteArray = Signature.getInstance("SHA256withECDSA").run { + initSign(entry.privateKey) + update(data) + sign() + } + + } @Test fun sharedSecret() { diff --git a/src/androidTest/java/com/afkanerd/smswithoutborders/libsignal_doubleratchet/libsignal/RatchetsTest.kt b/src/androidTest/java/com/afkanerd/smswithoutborders/libsignal_doubleratchet/libsignal/RatchetsTest.kt index 63a8fc3..63c31af 100644 --- a/src/androidTest/java/com/afkanerd/smswithoutborders/libsignal_doubleratchet/libsignal/RatchetsTest.kt +++ b/src/androidTest/java/com/afkanerd/smswithoutborders/libsignal_doubleratchet/libsignal/RatchetsTest.kt @@ -5,6 +5,7 @@ import androidx.core.util.component1 import androidx.core.util.component2 import androidx.test.filters.SmallTest import androidx.test.platform.app.InstrumentationRegistry +import com.afkanerd.smswithoutborders.libsignal_doubleratchet.CryptoHelpers import com.afkanerd.smswithoutborders.libsignal_doubleratchet.SecurityCurve25519 import org.junit.Assert.assertArrayEquals import org.junit.Test @@ -14,6 +15,108 @@ import java.security.SecureRandom class RatchetsTest { var context: Context = InstrumentationRegistry.getInstrumentation().targetContext + @Test + fun completeRatchetHETest() { + val aliceEphemeralKeyPair = SecurityCurve25519() + val aliceEphemeralHeaderKeyPair = SecurityCurve25519() + val aliceEphemeralNextHeaderKeyPair = SecurityCurve25519() + + val bobStaticKeyPair = SecurityCurve25519() + val bobEphemeralKeyPair = SecurityCurve25519() + val bobEphemeralHeaderKeyPair = SecurityCurve25519() + val bobEphemeralNextHeaderKeyPair = SecurityCurve25519() + + val aliceNonce = CryptoHelpers.generateRandomBytes(16) + val bobNonce = CryptoHelpers.generateRandomBytes(16) + + val (aliceSk, aliceSkH, aliceSkNh) = SecurityCurve25519(aliceEphemeralKeyPair.privateKey) + .agreeWithAuthAndNonce( + authenticationPublicKey = bobStaticKeyPair.generateKey(), + authenticationPrivateKey = null, + headerPrivateKey = aliceEphemeralHeaderKeyPair.privateKey, + nextHeaderPrivateKey = aliceEphemeralNextHeaderKeyPair.privateKey, + publicKey = bobEphemeralKeyPair.generateKey(), + headerPublicKey = bobEphemeralHeaderKeyPair.generateKey(), + nextHeaderPublicKey = bobEphemeralNextHeaderKeyPair.generateKey(), + salt = "RelaySMS v1".encodeToByteArray(), + nonce1 = aliceNonce, + nonce2 = bobNonce, + info = "RelaySMS C2S DR v1".encodeToByteArray() + ) + + val (bobSk, bobSkH, bobSkNh) = SecurityCurve25519(bobEphemeralKeyPair.privateKey) + .agreeWithAuthAndNonce( + authenticationPublicKey = null, + authenticationPrivateKey = bobStaticKeyPair.privateKey, + headerPrivateKey = bobEphemeralHeaderKeyPair.privateKey, + nextHeaderPrivateKey = bobEphemeralNextHeaderKeyPair.privateKey, + publicKey = aliceEphemeralKeyPair.generateKey(), + headerPublicKey = aliceEphemeralHeaderKeyPair.generateKey(), + nextHeaderPublicKey = aliceEphemeralNextHeaderKeyPair.generateKey(), + salt = "RelaySMS v1".encodeToByteArray(), + nonce1 = aliceNonce, + nonce2 = bobNonce, + info = "RelaySMS C2S DR v1".encodeToByteArray() + ) + + assertArrayEquals(aliceSk, bobSk) + assertArrayEquals(aliceSkH, bobSkH) + assertArrayEquals(aliceSkNh, bobSkNh) + + val aliceState = States() + RatchetsHE.ratchetInitAlice( + state = aliceState, + SK = aliceSk, + bobDhPublicKey = bobEphemeralKeyPair.generateKey(), + sharedHka = aliceSkH, + sharedNhkb = aliceSkNh + ) + + val bobState = States() + RatchetsHE.ratchetInitBob( + state = bobState, + SK = bobSk, + bobDhPublicKeypair = bobEphemeralKeyPair.getKeypair(), + sharedHka = bobSkH, + sharedNhkb = bobSkNh + ) + + val originalText = SecureRandom.getSeed(32); + val (encHeader, aliceCipherText) = RatchetsHE.ratchetEncrypt( + aliceState, + originalText, + bobStaticKeyPair.generateKey() + ) + + var encHeader1: ByteArray? = null + var aliceCipherText1: ByteArray? = null + for(i in 1..10) { + val (encHeader2, aliceCipherText2) = RatchetsHE.ratchetEncrypt( + aliceState, + originalText, + bobStaticKeyPair.generateKey() + ) + encHeader1 = encHeader2 + aliceCipherText1 = aliceCipherText2 + } + + val bobPlainText = RatchetsHE.ratchetDecrypt( + state = bobState, + encHeader = encHeader, + cipherText = aliceCipherText, + AD = bobStaticKeyPair.generateKey() + ) + + val bobPlainText1 = RatchetsHE.ratchetDecrypt( + state = bobState, + encHeader = encHeader1!!, + cipherText = aliceCipherText1!!, + AD = bobStaticKeyPair.generateKey() + ) + + assertArrayEquals(originalText, bobPlainText) + assertArrayEquals(originalText, bobPlainText1) + } @Test fun completeRatchetTest() { @@ -48,7 +151,7 @@ class RatchetsTest { val bobPlainText1 = Ratchets.ratchetDecrypt(bobState, header1, aliceCipherText1, bob.generateKey()) - println(bobState.serializedStates) + println(bobState.serialize()) assertArrayEquals(originalText, bobPlainText) assertArrayEquals(originalText, bobPlainText1) diff --git a/src/androidTest/java/com/afkanerd/smswithoutborders/libsignal_doubleratchet/libsignal/StateTest.kt b/src/androidTest/java/com/afkanerd/smswithoutborders/libsignal_doubleratchet/libsignal/StateTest.kt index 42e3d84..42c254e 100644 --- a/src/androidTest/java/com/afkanerd/smswithoutborders/libsignal_doubleratchet/libsignal/StateTest.kt +++ b/src/androidTest/java/com/afkanerd/smswithoutborders/libsignal_doubleratchet/libsignal/StateTest.kt @@ -2,6 +2,7 @@ package com.afkanerd.smswithoutborders.libsignal_doubleratchet.libsignal import androidx.test.filters.SmallTest import junit.framework.TestCase.assertEquals +import kotlinx.serialization.json.Json import org.junit.Test import java.security.SecureRandom @@ -12,11 +13,8 @@ class StateTest { val state = States() state.DHs = android.util.Pair(SecureRandom.getSeed(32), SecureRandom.getSeed(32)) - val serializedStates = state.serializedStates - println("Encoded values: $serializedStates") - val state1 = States(serializedStates) - println(state1.serializedStates) - - assertEquals(state, state1) + val serializedStates = Json.encodeToString(state) + val deserializedStates = Json.decodeFromString(serializedStates) + assertEquals(state, deserializedStates) } } \ No newline at end of file diff --git a/src/main/java/com/afkanerd/.DS_Store b/src/main/java/com/afkanerd/.DS_Store new file mode 100644 index 0000000..fce4432 Binary files /dev/null and b/src/main/java/com/afkanerd/.DS_Store differ diff --git a/src/main/java/com/afkanerd/smswithoutborders/.DS_Store b/src/main/java/com/afkanerd/smswithoutborders/.DS_Store new file mode 100644 index 0000000..68565d0 Binary files /dev/null and b/src/main/java/com/afkanerd/smswithoutborders/.DS_Store differ diff --git a/src/main/java/com/afkanerd/smswithoutborders/libsignal_doubleratchet/.DS_Store b/src/main/java/com/afkanerd/smswithoutborders/libsignal_doubleratchet/.DS_Store new file mode 100644 index 0000000..d661d7f Binary files /dev/null and b/src/main/java/com/afkanerd/smswithoutborders/libsignal_doubleratchet/.DS_Store differ diff --git a/src/main/java/com/afkanerd/smswithoutborders/libsignal_doubleratchet/CryptoHelpers.java b/src/main/java/com/afkanerd/smswithoutborders/libsignal_doubleratchet/CryptoHelpers.java index df6ed86..e84b2af 100644 --- a/src/main/java/com/afkanerd/smswithoutborders/libsignal_doubleratchet/CryptoHelpers.java +++ b/src/main/java/com/afkanerd/smswithoutborders/libsignal_doubleratchet/CryptoHelpers.java @@ -1,7 +1,5 @@ package com.afkanerd.smswithoutborders.libsignal_doubleratchet; -import android.util.Base64; - import com.google.common.primitives.Bytes; import java.security.GeneralSecurityException; @@ -30,7 +28,7 @@ public static byte[] getCipherMacParameters(String ALGO, byte[] mk) throws Gener public static Mac buildVerificationHash(byte[] authKey, byte[] AD, byte[] cipherText) throws GeneralSecurityException { Mac mac = CryptoHelpers.HMAC256(authKey); - byte[] updatedParams = Bytes.concat(AD, cipherText); + byte[] updatedParams = (AD == null) ? cipherText : Bytes.concat(AD, cipherText); mac.update(updatedParams); return mac; } @@ -85,9 +83,7 @@ public static Mac HMAC256(byte[] data) throws GeneralSecurityException { public static byte[] generateRandomBytes(int length) { SecureRandom random = new SecureRandom(); - byte[] bytes = new - - byte[length]; + byte[] bytes = new byte[length]; random.nextBytes(bytes); return bytes; } diff --git a/src/main/java/com/afkanerd/smswithoutborders/libsignal_doubleratchet/EncryptionController.kt b/src/main/java/com/afkanerd/smswithoutborders/libsignal_doubleratchet/EncryptionController.kt index 548a0c3..f73bd34 100644 --- a/src/main/java/com/afkanerd/smswithoutborders/libsignal_doubleratchet/EncryptionController.kt +++ b/src/main/java/com/afkanerd/smswithoutborders/libsignal_doubleratchet/EncryptionController.kt @@ -203,7 +203,7 @@ object EncryptionController { android.util.Pair(keypair.second, keypair.first) ) } - else state = States(String(currentState)) + else state = States.deserialize(String(currentState)) val keypair = context.getKeypairValues(address) var decryptedText: String? @@ -215,7 +215,7 @@ object EncryptionController { keypair.first )) context.saveBinaryDataEncrypted(keystore, - state.serializedStates.encodeToByteArray()) + state.serialize().encodeToByteArray()) } catch(e: Exception) { throw e } @@ -253,7 +253,7 @@ object EncryptionController { val sk = context.calculateSharedSecret(address, publicKeyBytes) Ratchets.ratchetInitAlice(state, sk, publicKeyBytes) } - else state = States(String(currentState)) + else state = States.deserialize(String(currentState)) val ratchetOutput = Ratchets.ratchetEncrypt(state, text.encodeToByteArray(), publicKeyBytes) @@ -264,7 +264,7 @@ object EncryptionController { ratchetOutput.second ) context.saveBinaryDataEncrypted(keystore, - state.serializedStates.encodeToByteArray()) + state.serialize().encodeToByteArray()) Base64.encodeToString(message, Base64.DEFAULT) } catch(e: Exception) { throw e diff --git a/src/main/java/com/afkanerd/smswithoutborders/libsignal_doubleratchet/SecurityCurve25519.kt b/src/main/java/com/afkanerd/smswithoutborders/libsignal_doubleratchet/SecurityCurve25519.kt index 9382dbb..a4fc59c 100644 --- a/src/main/java/com/afkanerd/smswithoutborders/libsignal_doubleratchet/SecurityCurve25519.kt +++ b/src/main/java/com/afkanerd/smswithoutborders/libsignal_doubleratchet/SecurityCurve25519.kt @@ -7,10 +7,109 @@ class SecurityCurve25519(val privateKey: ByteArray = Curve25519.generateRandomKe return Curve25519.publicKey(this.privateKey) } - fun calculateSharedSecret(publicKey: ByteArray): ByteArray { + private fun agreeWithAuthAndNonceImpl( + authenticationPublicKey: ByteArray?, + authenticationPrivateKey: ByteArray?, + publicKey: ByteArray, + salt: ByteArray, + info: ByteArray, + handshakeSalt: ByteArray, + privateKey: ByteArray? = null, + ): ByteArray { + val privateKey = privateKey ?: this.privateKey + val dh1 = if(authenticationPrivateKey == null) + Curve25519.sharedSecret(privateKey, authenticationPublicKey) + else + Curve25519.sharedSecret(authenticationPrivateKey, publicKey) + val dh2 = Curve25519.sharedSecret(privateKey, publicKey) + var ck = CryptoHelpers.HKDF( + "HMACSHA256", + handshakeSalt, + salt, + info, + 32, + 1 + )[0] + ck = CryptoHelpers.HKDF( + "HMACSHA256", + dh1, + ck, + info, + 32, + 1 + )[0] + return CryptoHelpers.HKDF( + "HMACSHA256", + dh2, + ck, + info, + 32, + 1 + )[0] + } + + fun agreeWithAuthAndNonce( + authenticationPublicKey: ByteArray?, + authenticationPrivateKey: ByteArray?, + headerPrivateKey: ByteArray, + nextHeaderPrivateKey: ByteArray, + publicKey: ByteArray, + headerPublicKey: ByteArray, + nextHeaderPublicKey: ByteArray, + salt: ByteArray, + nonce1: ByteArray, + nonce2: ByteArray, + info: ByteArray, + ): Triple { + val handshakeSalt = nonce1 + nonce2 + val headerInfo = "RelaySMS C2S DRHE v1".encodeToByteArray() + + val rootKey = agreeWithAuthAndNonceImpl( + authenticationPublicKey = authenticationPublicKey, + authenticationPrivateKey = authenticationPrivateKey, + publicKey = publicKey, + salt = salt, + info = info, + handshakeSalt = handshakeSalt, + ) + + val headerKey = agreeWithAuthAndNonceImpl( + authenticationPublicKey = authenticationPublicKey, + authenticationPrivateKey = authenticationPrivateKey, + publicKey = headerPublicKey, + salt = salt, + info = headerInfo, + handshakeSalt = handshakeSalt, + privateKey = headerPrivateKey + ) + + val nextHeaderKey = agreeWithAuthAndNonceImpl( + authenticationPublicKey = authenticationPublicKey, + authenticationPrivateKey = authenticationPrivateKey, + publicKey = nextHeaderPublicKey, + salt = salt, + info = headerInfo, + handshakeSalt = handshakeSalt, + privateKey = nextHeaderPrivateKey + ) + + return Triple(rootKey, headerKey, nextHeaderKey) + } + + fun calculateSharedSecret( + publicKey: ByteArray, + salt: ByteArray? = null, + info: ByteArray? = "x25591_key_exchange".encodeToByteArray(), + ): ByteArray { val sharedKey = Curve25519.sharedSecret(this.privateKey, publicKey) - return CryptoHelpers.HKDF("HMACSHA256", sharedKey, null, - "x25591_key_exchange".encodeToByteArray(), 32, 1)[0] + return CryptoHelpers.HKDF( + "HMACSHA256", + sharedKey, + salt, + info, + 32, + 1 + )[0] } fun getKeypair(): android.util.Pair { diff --git a/src/main/java/com/afkanerd/smswithoutborders/libsignal_doubleratchet/extensions/context.kt b/src/main/java/com/afkanerd/smswithoutborders/libsignal_doubleratchet/extensions/context.kt index 7b2c70c..fbc7020 100644 --- a/src/main/java/com/afkanerd/smswithoutborders/libsignal_doubleratchet/extensions/context.kt +++ b/src/main/java/com/afkanerd/smswithoutborders/libsignal_doubleratchet/extensions/context.kt @@ -11,20 +11,14 @@ import androidx.datastore.preferences.preferencesDataStore import com.afkanerd.smswithoutborders.libsignal_doubleratchet.SecurityAES import com.afkanerd.smswithoutborders.libsignal_doubleratchet.SecurityRSA import com.google.gson.Gson -import kotlinx.coroutines.flow.Flow import kotlinx.coroutines.flow.first -import kotlinx.coroutines.flow.map import java.io.IOException -import java.security.KeyFactory import java.security.KeyPair import java.security.KeyStore import java.security.KeyStoreException import java.security.NoSuchAlgorithmException import java.security.UnrecoverableEntryException import java.security.cert.CertificateException -import java.security.spec.PKCS8EncodedKeySpec -import java.security.spec.X509EncodedKeySpec -import javax.crypto.SecretKey import javax.crypto.spec.SecretKeySpec val Context.dataStore: DataStore by preferencesDataStore(name = "secure_comms") @@ -134,8 +128,7 @@ suspend fun Context.saveBinaryDataEncrypted( @Throws suspend fun Context.getEncryptedBinaryData(keystoreAlias: String): ByteArray? { val keyValue = stringPreferencesKey(keystoreAlias) - val data = dataStore.data.first()[keyValue] - if(data == null) return null + val data = dataStore.data.first()[keyValue] ?: return null val savedBinaryData = Gson().fromJson(data, SavedBinaryData::class.java) diff --git a/src/main/java/com/afkanerd/smswithoutborders/libsignal_doubleratchet/libsignal/Protocols.java b/src/main/java/com/afkanerd/smswithoutborders/libsignal_doubleratchet/libsignal/Protocols.java index 976ae94..3776adc 100644 --- a/src/main/java/com/afkanerd/smswithoutborders/libsignal_doubleratchet/libsignal/Protocols.java +++ b/src/main/java/com/afkanerd/smswithoutborders/libsignal_doubleratchet/libsignal/Protocols.java @@ -13,6 +13,7 @@ import java.io.IOException; import java.security.GeneralSecurityException; +import java.util.ArrayList; import javax.crypto.Mac; @@ -29,14 +30,36 @@ */ public class Protocols { final static int HKDF_LEN = 32; - final static int HKDF_NUM_KEYS = 2; final static String ALGO = "HMACSHA512"; + final static byte[] KDF_RK_HE_INFO = "RelaySMS C2S DR Ratchet v1".getBytes(); public static Pair GENERATE_DH() { SecurityCurve25519 securityCurve25519 = new SecurityCurve25519(); return new Pair<>(securityCurve25519.getPrivateKey(), securityCurve25519.generateKey()); } + /** + * + * @param dhPair This private key (keypair required in Android if supported) + * @param peerPublicKey + * @return + * @throws GeneralSecurityException + * @throws IOException + * @throws InterruptedException + */ + public static byte[] DH_HE( + Pair dhPair, + byte[] peerPublicKey, + byte[] info + ) { + SecurityCurve25519 securityCurve25519 = new SecurityCurve25519(dhPair.first); + return securityCurve25519.calculateSharedSecret( + peerPublicKey, + null, + info + ); + } + /** * * @param dhPair This private key (keypair required in Android if supported) @@ -48,12 +71,26 @@ public static Pair GENERATE_DH() { */ public static byte[] DH(Pair dhPair, byte[] peerPublicKey) { SecurityCurve25519 securityCurve25519 = new SecurityCurve25519(dhPair.first); - return securityCurve25519.calculateSharedSecret(peerPublicKey); + return securityCurve25519.calculateSharedSecret( + peerPublicKey, + null, + "x25591_key_exchange".getBytes() + ); + } + + public static byte[][] KDF_RK_HE( + byte[] rk, + byte[] dhOut + ) throws GeneralSecurityException { + int numKeys = 3; + byte[] info = "SMSWithoutBorders DRHE v2".getBytes(); + return CryptoHelpers.HKDF(ALGO, dhOut, rk, info, HKDF_LEN, numKeys); } public static Pair KDF_RK(byte[] rk, byte[] dhOut) throws GeneralSecurityException { + int numKeys = 2; byte[] info = "KDF_RK".getBytes(); - byte[][] hkdfOutput = CryptoHelpers.HKDF(ALGO, dhOut, rk, info, HKDF_LEN, HKDF_NUM_KEYS); + byte[][] hkdfOutput = CryptoHelpers.HKDF(ALGO, dhOut, rk, info, HKDF_LEN, numKeys); return new Pair<>(hkdfOutput[0], hkdfOutput[1]); } @@ -65,6 +102,24 @@ public static Pair KDF_CK(byte[] ck) throws GeneralSecurityExcep return new Pair<>(_ck, mk); } + public static byte[] HENCRYPT( + byte[] mk, + byte[] plainText + ) throws Throwable { + byte[] hkdfOutput = getCipherMacParameters(ALGO, mk); + byte[] key = new byte[32]; + byte[] authenticationKey = new byte[32]; + byte[] iv = new byte[16]; + + System.arraycopy(hkdfOutput, 0, key, 0, 32); + System.arraycopy(hkdfOutput, 32, authenticationKey, 0, 32); + System.arraycopy(hkdfOutput, 64, iv, 0, 16); + + byte[] cipherText = SecurityAES.encryptAES256CBC(plainText, key, iv); + byte[] mac = buildVerificationHash(authenticationKey, null, cipherText).doFinal(); + return Bytes.concat(cipherText, mac); + } + public static byte[] ENCRYPT(byte[] mk, byte[] plainText, byte[] associated_data) throws Throwable { byte[] hkdfOutput = getCipherMacParameters(ALGO, mk); byte[] key = new byte[32]; @@ -80,6 +135,20 @@ public static byte[] ENCRYPT(byte[] mk, byte[] plainText, byte[] associated_data return Bytes.concat(cipherText, mac); } + public static byte[] HDECRYPT( + byte[] mk, + byte[] cipherText + ) throws Throwable { + cipherText = verifyCipherText(ALGO, mk, cipherText, null); + + byte[] hkdfOutput = getCipherMacParameters(ALGO, mk); + byte[] key = new byte[32]; + byte[] iv = new byte[16]; + System.arraycopy(hkdfOutput, 0, key, 0, 32); + System.arraycopy(hkdfOutput, 64, iv, 0, 16); + + return SecurityAES.decryptAES256CBC(cipherText, key, iv); + } public static byte[] DECRYPT(byte[] mk, byte[] cipherText, byte[] associated_data) throws Throwable { cipherText = verifyCipherText(ALGO, mk, cipherText, associated_data); @@ -92,6 +161,10 @@ public static byte[] DECRYPT(byte[] mk, byte[] cipherText, byte[] associated_dat return SecurityAES.decryptAES256CBC(cipherText, key, iv); } + public static byte[] CONCAT_HE(byte[] AD, byte[] headers) throws IOException { + return Bytes.concat(AD, headers); + } + public static byte[] CONCAT(byte[] AD, Headers headers) throws IOException { return Bytes.concat(AD, headers.getSerialized()); } diff --git a/src/main/java/com/afkanerd/smswithoutborders/libsignal_doubleratchet/libsignal/RatchetsHE.kt b/src/main/java/com/afkanerd/smswithoutborders/libsignal_doubleratchet/libsignal/RatchetsHE.kt new file mode 100644 index 0000000..8d93d95 --- /dev/null +++ b/src/main/java/com/afkanerd/smswithoutborders/libsignal_doubleratchet/libsignal/RatchetsHE.kt @@ -0,0 +1,213 @@ +package com.afkanerd.smswithoutborders.libsignal_doubleratchet.libsignal + +import android.util.Pair +import androidx.core.util.component1 +import androidx.core.util.component2 +import com.afkanerd.smswithoutborders.libsignal_doubleratchet.libsignal.Protocols.CONCAT_HE +import com.afkanerd.smswithoutborders.libsignal_doubleratchet.libsignal.Protocols.DECRYPT +import com.afkanerd.smswithoutborders.libsignal_doubleratchet.libsignal.Protocols.ENCRYPT +import com.afkanerd.smswithoutborders.libsignal_doubleratchet.libsignal.Protocols.GENERATE_DH +import com.afkanerd.smswithoutborders.libsignal_doubleratchet.libsignal.Protocols.HDECRYPT +import com.afkanerd.smswithoutborders.libsignal_doubleratchet.libsignal.Protocols.HENCRYPT +import com.afkanerd.smswithoutborders.libsignal_doubleratchet.libsignal.Protocols.KDF_CK +import com.afkanerd.smswithoutborders.libsignal_doubleratchet.libsignal.Protocols.KDF_RK_HE + +object RatchetsHE { + + const val MAX_SKIP: Int = 100 + + fun ratchetInitAlice( + state: States, + SK: ByteArray, + bobDhPublicKey: ByteArray, + sharedHka: ByteArray, + sharedNhkb: ByteArray, + ) { + state.DHRs = GENERATE_DH() + state.DHRr = bobDhPublicKey + + val kdfRkHEOutputs = KDF_RK_HE(SK, + Protocols.DH_HE( + state.DHRs, + state.DHRr, + Protocols.KDF_RK_HE_INFO + ) + ) + state.RK = kdfRkHEOutputs[0] + state.CKs = kdfRkHEOutputs[1] + state.NHKs = kdfRkHEOutputs[2] + + state.CKr = null + state.Ns = 0 + state.Nr = 0 + state.PN = 0 + state.MKSKIPPED = mutableMapOf() + state.HKs = sharedHka + state.HKr = null + state.NHKr = sharedNhkb + } + + fun ratchetInitBob( + state: States, + SK: ByteArray, + bobDhPublicKeypair: Pair, + sharedHka: ByteArray, + sharedNhkb: ByteArray, + ) { + state.DHRs = bobDhPublicKeypair + state.DHRr = null + state.RK = SK + state.CKs = null + state.CKr = null + state.Ns = 0 + state.Nr = 0 + state.PN = 0 + state.MKSKIPPED = mutableMapOf() + state.HKs = null + state.NHKs = sharedNhkb + state.HKr = null + state.NHKr = sharedHka + } + + fun ratchetEncrypt( + state: States, + plaintext: ByteArray, + AD: ByteArray, + ) : Pair { + val kdfCk = KDF_CK(state.CKs) + state.CKs = kdfCk.first + val mk = kdfCk.second + val header = Headers(state.DHRs, state.PN, state.Ns) + val encHeader = HENCRYPT(state.HKs, header.serialized) + state.Ns += 1 + return Pair(encHeader, + ENCRYPT(mk, plaintext, CONCAT_HE(AD, encHeader))) + } + + fun ratchetDecrypt( + state: States, + encHeader: ByteArray, + cipherText: ByteArray, + AD: ByteArray, + ): ByteArray { + val plaintext = trySkippedMessageKeys(state, encHeader, cipherText, AD) + if(plaintext != null) + return plaintext + + val (header, dhRatchet) = decryptHeader(state, encHeader) + if(dhRatchet) { + skipMessageKeys(state, header.PN) + DHRatchetHE(state, header) + } + + skipMessageKeys(state, header.N) + val kdfCk = KDF_CK(state.CKr) + state.CKr = kdfCk.first + val mk = kdfCk.second + state.Nr += 1 + return DECRYPT(mk, cipherText, CONCAT_HE(AD, encHeader)) + } + + private fun skipMessageKeys( + state: States, + until: Int, + ) { + if(state.Nr + MAX_SKIP < until) + throw Exception("MAX_SKIP Exceeded") + + state.CKr?.let{ + while(state.Nr < until) { + val kdfCk = KDF_CK(state.CKr) + state.CKr = kdfCk.first + val mk = kdfCk.second + state.MKSKIPPED[Pair(state.HKr, state.Nr)] = mk + state.Nr += 1 + } + } + } + + private fun trySkippedMessageKeys( + state: States, + encHeader: ByteArray, + ciphertext: ByteArray, + AD: ByteArray + ) : ByteArray? { + state.MKSKIPPED.forEach { + val hk = it.key.first + val n = it.key.second + val mk = it.value + + val header = HDECRYPT(hk, encHeader).run { + Headers.deSerializeHeader(this) + } + if(header != null && header.N == n) { + state.MKSKIPPED.remove(it.key) + return DECRYPT(mk, ciphertext, CONCAT_HE(AD, encHeader)) + } + } + + return null + } + + private fun decryptHeader( + state: States, + encHeader: ByteArray + ) : Pair { + var header: Headers? = null + try { + header = HDECRYPT(state.HKr, encHeader).run { + Headers.deSerializeHeader(this) + } + } catch(e: Exception) { + e.printStackTrace() + } + + header?.let { + return Pair(header, false) + } + + header = HDECRYPT(state.NHKr, encHeader).run { + Headers.deSerializeHeader(this) + } + header?.let { + return Pair(header, true) + } + throw Exception("Generic error decrypting header...") + } + + private fun DHRatchetHE( + state: States, + header: Headers + ) { + state.PN = state.Ns + state.Ns = 0 + state.Nr = 0 + state.HKs = state.NHKs + state.HKr = state.NHKr + state.DHRr = header.dh + + var kdfRkHEOutputs = KDF_RK_HE(state.RK, + Protocols.DH_HE( + state.DHRs, + state.DHRr, + Protocols.KDF_RK_HE_INFO + ) + ) + state.RK = kdfRkHEOutputs[0] + state.CKr = kdfRkHEOutputs[1] + state.NHKr = kdfRkHEOutputs[2] + + state.DHRs = GENERATE_DH() + + kdfRkHEOutputs = KDF_RK_HE(state.RK, + Protocols.DH_HE( + state.DHRs, + state.DHRr, + Protocols.KDF_RK_HE_INFO + ) + ) + state.RK = kdfRkHEOutputs[0] + state.CKs = kdfRkHEOutputs[1] + state.NHKs = kdfRkHEOutputs[2] + } +} \ No newline at end of file diff --git a/src/main/java/com/afkanerd/smswithoutborders/libsignal_doubleratchet/libsignal/States.java b/src/main/java/com/afkanerd/smswithoutborders/libsignal_doubleratchet/libsignal/States.java deleted file mode 100644 index 1625b57..0000000 --- a/src/main/java/com/afkanerd/smswithoutborders/libsignal_doubleratchet/libsignal/States.java +++ /dev/null @@ -1,165 +0,0 @@ -package com.afkanerd.smswithoutborders.libsignal_doubleratchet.libsignal; - -import android.util.Log; -import android.util.Pair; -import android.util.Base64; - -import androidx.annotation.Nullable; - -import com.google.gson.Gson; -import com.google.gson.GsonBuilder; -import com.google.gson.JsonArray; -import com.google.gson.JsonElement; -import com.google.gson.JsonObject; -import com.google.gson.JsonPrimitive; -import com.google.gson.JsonSerializationContext; - -import com.google.gson.JsonSerializer; - -import org.json.JSONArray; -import org.json.JSONException; -import org.json.JSONObject; - -import java.lang.reflect.Type; -import java.security.KeyPair; -import java.security.NoSuchAlgorithmException; -import java.security.PublicKey; -import java.security.spec.InvalidKeySpecException; -import java.util.Arrays; -import java.util.HashMap; -import java.util.Map; - -public class States { - public Pair DHs; - public byte[] DHr; - public byte[] RK; - public byte[] CKs; - public byte[] CKr; - - public int Ns = 0; - public int Nr = 0; - public int PN = 0; - - public Map, byte[]> MKSKIPPED = new HashMap<>(); - - public States(String states) throws JSONException { - if(states == null) - return; - - JSONObject jsonObject = new JSONObject(states); - if(jsonObject.has("DHs")) { - String[] encodedValues = jsonObject.getString("DHs").split(" "); - this.DHs = new Pair<>(android.util.Base64.decode(encodedValues[0], Base64.NO_WRAP), - android.util.Base64.decode(encodedValues[1], Base64.NO_WRAP)); - } - if(jsonObject.has("DHr")) - this.DHr = Base64.decode(jsonObject.getString("DHr"), Base64.NO_WRAP); - - if(jsonObject.has("RK")) - this.RK = Base64.decode(jsonObject.getString("RK"), Base64.NO_WRAP); - if(jsonObject.has("CKs")) - this.CKs = Base64.decode(jsonObject.get("CKs").toString(), Base64.NO_WRAP); - if(jsonObject.has("CKr")) - this.CKr = Base64.decode(jsonObject.getString("CKr"), Base64.NO_WRAP); - this.Ns = jsonObject.getInt("Ns"); - this.Nr = jsonObject.getInt("Nr"); - this.PN = jsonObject.getInt("PN"); - - JSONArray mkskipped = jsonObject.getJSONArray("MKSKIPPED"); - for(int i=0;i(pubkey, pair.getInt(StatesMKSKIPPED.N)), - Base64.decode(pair.getString(StatesMKSKIPPED.MK), Base64.NO_WRAP)); - } - } - - public static byte[] getADForHeaders(States states, Headers headers) { - for(Map.Entry, byte[]> entry : states.MKSKIPPED.entrySet()) { - if(entry.getKey().second == (headers.PN + headers.N)) - return entry.getKey().first; - } - - return null; - } - - public States() { - } - - public String getSerializedStates() { - GsonBuilder gsonBuilder = new GsonBuilder(); - gsonBuilder.registerTypeAdapter(KeyPair.class, new StatesKeyPairSerializer()); - gsonBuilder.registerTypeAdapter(PublicKey.class, new StatesPublicKeySerializer()); - gsonBuilder.registerTypeAdapter(byte[].class, new StatesBytesSerializer()); - gsonBuilder.registerTypeAdapter(Pair.class, new PairStatesBytesSerializer()); - gsonBuilder.registerTypeAdapter(Map.class, new StatesMKSKIPPED()); - gsonBuilder.setPrettyPrinting() - .disableHtmlEscaping(); - - Gson gson = gsonBuilder.create(); - return gson.toJson(this); - } - - @Override - public boolean equals(@Nullable Object obj) { - if(obj instanceof States state) { - return state.getSerializedStates().equals(this.getSerializedStates()); - } - return false; - } - - public static class StatesKeyPairSerializer implements JsonSerializer { - @Override - public JsonElement serialize(KeyPair src, Type typeOfSrc, JsonSerializationContext context) { - return new JsonPrimitive( - Base64.encodeToString(src.getPublic().getEncoded(), Base64.NO_WRAP)); - } - } - - public static class StatesPublicKeySerializer implements JsonSerializer { - @Override - public JsonElement serialize(PublicKey src, Type typeOfSrc, JsonSerializationContext context) { - return new JsonPrimitive(Base64.encodeToString(src.getEncoded(), Base64.NO_WRAP)); - } - } - - public static class PairStatesBytesSerializer implements JsonSerializer> { - @Override - public JsonElement serialize(Pair src, Type typeOfSrc, JsonSerializationContext context) { - return new JsonPrimitive( Base64.encodeToString(src.first, Base64.NO_WRAP) + " " + - Base64.encodeToString(src.second, Base64.NO_WRAP)); - } - } - - public static class StatesBytesSerializer implements JsonSerializer { - @Override - public JsonElement serialize(byte[] src, Type typeOfSrc, JsonSerializationContext context) { - return new JsonPrimitive( Base64.encodeToString(src, Base64.NO_WRAP)); - } - } - - - public static class StatesMKSKIPPED implements JsonSerializer, byte[]>> { - public final static String PUBLIC_KEY = "PUBLIC_KEY"; - public final static String N = "N"; - public final static String MK = "MK"; - - @Override - public JsonElement serialize(Map, byte[]> src, Type typeOfSrc, JsonSerializationContext context) { - JsonArray jsonArray = new JsonArray(); - for(Map.Entry, byte[]> entry: src.entrySet()) { - String publicKey = Base64.encodeToString(entry.getKey().first, Base64.NO_WRAP); - Integer n = entry.getKey().second; - - JsonObject jsonObject1 = new JsonObject(); - jsonObject1.addProperty(PUBLIC_KEY, publicKey); - jsonObject1.addProperty(N, n); - jsonObject1.addProperty(MK, Base64.encodeToString(entry.getValue(), Base64.NO_WRAP)); - - jsonArray.add(jsonObject1); - } - return jsonArray; - } - } - -} diff --git a/src/main/java/com/afkanerd/smswithoutborders/libsignal_doubleratchet/libsignal/States.kt b/src/main/java/com/afkanerd/smswithoutborders/libsignal_doubleratchet/libsignal/States.kt new file mode 100644 index 0000000..79817ed --- /dev/null +++ b/src/main/java/com/afkanerd/smswithoutborders/libsignal_doubleratchet/libsignal/States.kt @@ -0,0 +1,61 @@ +package com.afkanerd.smswithoutborders.libsignal_doubleratchet.libsignal + +import android.util.Pair +import kotlinx.serialization.json.Json + +data class States( + @JvmField + var DHs: Pair? = null, + + @JvmField + var DHr: ByteArray? = null, + + @JvmField + var RK: ByteArray? = null, + + @JvmField + var CKs: ByteArray? = null, + + @JvmField + var CKr: ByteArray? = null, + + @JvmField + var Ns: Int = 0, + + @JvmField + var Nr: Int = 0, + + @JvmField + var PN: Int = 0, + + @JvmField + var DHRs: Pair? = null, + + @JvmField + var DHRr: ByteArray? = null, + + @JvmField + var HKs: ByteArray? = null, + + @JvmField + var HKr: ByteArray? = null, + + @JvmField + var NHKs: ByteArray? = null, + + @JvmField + var NHKr: ByteArray? = null, + + @JvmField + var MKSKIPPED: MutableMap, ByteArray> = mutableMapOf() +) { + fun serialize(): String { + return Json.encodeToString(this) + } + + companion object { + fun deserialize(input: String): States { + return Json.decodeFromString(input) + } + } +} \ No newline at end of file