diff --git a/include/sframe/map.h b/include/sframe/map.h index 58d16aa..b6f8617 100644 --- a/include/sframe/map.h +++ b/include/sframe/map.h @@ -58,6 +58,11 @@ class map : private vector>, N> return pos->value().second; } + void erase(const K& key) + { + erase_if_key([key](const auto& other) { return other == key; }); + } + template void erase_if_key(F&& f) { diff --git a/include/sframe/sframe.h b/include/sframe/sframe.h index 65b34ba..397d03e 100644 --- a/include/sframe/sframe.h +++ b/include/sframe/sframe.h @@ -118,6 +118,7 @@ class Context virtual ~Context(); Result add_key(KeyID kid, KeyUsage usage, input_bytes key); + void remove_key(KeyID kid); Result protect(KeyID key_id, output_bytes ciphertext, @@ -134,6 +135,8 @@ class Context CipherSuite suite; map keys; + Result require_key(KeyID key_id) const; + Result protect_inner(const Header& header, output_bytes ciphertext, input_bytes plaintext, @@ -160,6 +163,7 @@ class MLSContext : protected Context Result add_epoch(EpochID epoch_id, input_bytes sframe_epoch_secret, size_t sender_bits); + void remove_epoch(EpochID epoch_id); void purge_before(EpochID keeper); Result protect(EpochID epoch_id, diff --git a/src/sframe.cpp b/src/sframe.cpp index 529c81d..b389cfc 100644 --- a/src/sframe.cpp +++ b/src/sframe.cpp @@ -81,6 +81,12 @@ Context::Context(CipherSuite suite_in) Context::~Context() = default; +void +Context::remove_key(KeyID key_id) +{ + keys.erase(key_id); +} + Result Context::add_key(KeyID key_id, KeyUsage usage, input_bytes base_key) { @@ -118,12 +124,23 @@ form_aad(const Header& header, input_bytes metadata) return aad; } +Result +Context::require_key(KeyID key_id) const +{ + if (!keys.contains(key_id)) { + return SFrameError(SFrameErrorType::invalid_parameter_error, + "Unknown key ID"); + } + return Result::ok(); +} + Result Context::protect(KeyID key_id, output_bytes ciphertext, input_bytes plaintext, input_bytes metadata) { + SFRAME_VOID_OR_RETURN(require_key(key_id)); auto& key_record = keys.at(key_id); const auto counter = key_record.counter; key_record.counter += 1; @@ -166,6 +183,7 @@ Context::protect_inner(const Header& header, "Ciphertext too small for cipher overhead"); } + SFRAME_VOID_OR_RETURN(require_key(header.key_id)); const auto& key_and_salt = keys.at(header.key_id); SFRAME_VALUE_OR_RETURN(aad, form_aad(header, metadata)); @@ -190,6 +208,7 @@ Context::unprotect_inner(const Header& header, "Plaintext too small for decrypted value"); } + SFRAME_VOID_OR_RETURN(require_key(header.key_id)); const auto& key_and_salt = keys.at(header.key_id); SFRAME_VALUE_OR_RETURN(aad, form_aad(header, metadata)); @@ -326,6 +345,17 @@ MLSContext::EpochKeys::base_key(CipherSuite ciphersuite, ciphersuite, sframe_epoch_secret, enc_sender_id, hash_size); } +void +MLSContext::remove_epoch(EpochID epoch_id) +{ + purge_epoch(epoch_id); + + const auto idx = epoch_id & epoch_mask; + if (idx < epoch_cache.size()) { + epoch_cache[idx].reset(); + } +} + void MLSContext::purge_epoch(EpochID epoch_id) { diff --git a/test/sframe.cpp b/test/sframe.cpp index d665026..9f8adf8 100644 --- a/test/sframe.cpp +++ b/test/sframe.cpp @@ -231,3 +231,116 @@ TEST_CASE("MLS Failure after Purge") const auto dec_ab_2 = member_b.unprotect(pt_out, enc_ab_2, metadata).unwrap(); CHECK(plaintext == to_bytes(dec_ab_2)); } + +TEST_CASE("SFrame Context Remove Key") +{ + const auto suite = CipherSuite::AES_GCM_128_SHA256; + const auto kid = KeyID(0x07); + const auto key = from_hex("000102030405060708090a0b0c0d0e0f"); + const auto plaintext = from_hex("00010203"); + const auto metadata = bytes{}; + + auto pt_out = bytes(plaintext.size()); + auto ct_out = bytes(plaintext.size() + Context::max_overhead); + + auto sender = Context(suite); + auto receiver = Context(suite); + sender.add_key(kid, KeyUsage::protect, key).unwrap(); + receiver.add_key(kid, KeyUsage::unprotect, key).unwrap(); + + // Protect and unprotect succeed before removal + auto encrypted = + to_bytes(sender.protect(kid, ct_out, plaintext, metadata).unwrap()); + auto decrypted = + to_bytes(receiver.unprotect(pt_out, encrypted, metadata).unwrap()); + CHECK(decrypted == plaintext); + + // Remove sender key and verify protect fails + sender.remove_key(kid); + CHECK(sender.protect(kid, ct_out, plaintext, metadata).error().type() == + SFrameErrorType::invalid_parameter_error); + + // Remove receiver key and verify unprotect fails + receiver.remove_key(kid); + CHECK(receiver.unprotect(pt_out, encrypted, metadata).error().type() == + SFrameErrorType::invalid_parameter_error); + + // Re-add keys and verify round-trip works again + sender.add_key(kid, KeyUsage::protect, key).unwrap(); + receiver.add_key(kid, KeyUsage::unprotect, key).unwrap(); + + encrypted = + to_bytes(sender.protect(kid, ct_out, plaintext, metadata).unwrap()); + decrypted = + to_bytes(receiver.unprotect(pt_out, encrypted, metadata).unwrap()); + CHECK(decrypted == plaintext); +} + +TEST_CASE("SFrame Context Remove Key - Nonexistent Key") +{ + const auto suite = CipherSuite::AES_GCM_128_SHA256; + + auto ctx = Context(suite); + + // Removing a key that was never added should not throw + CHECK_NOTHROW(ctx.remove_key(KeyID(0x99))); +} + +TEST_CASE("MLS Remove Epoch") +{ + const auto suite = CipherSuite::AES_GCM_128_SHA256; + const auto epoch_bits = 2; + const auto metadata = from_hex("00010203"); + const auto plaintext = from_hex("04050607"); + const auto sender_id = MLSContext::SenderID(0xA0A0A0A0); + const auto sframe_epoch_secret_1 = bytes(32, 1); + const auto sframe_epoch_secret_2 = bytes(32, 2); + + auto pt_out = bytes(plaintext.size()); + auto ct_out = bytes(plaintext.size() + Context::max_overhead); + + auto member_a = MLSContext(suite, epoch_bits); + auto member_b = MLSContext(suite, epoch_bits); + + // Install epoch 1 and verify round-trip + const auto epoch_id_1 = MLSContext::EpochID(1); + member_a.add_epoch(epoch_id_1, sframe_epoch_secret_1); + member_b.add_epoch(epoch_id_1, sframe_epoch_secret_1); + + auto enc = + member_a.protect(epoch_id_1, sender_id, ct_out, plaintext, metadata) + .unwrap(); + auto enc_data = to_bytes(enc); + auto dec = to_bytes(member_b.unprotect(pt_out, enc_data, metadata).unwrap()); + CHECK(plaintext == dec); + + // Install epoch 2 + const auto epoch_id_2 = MLSContext::EpochID(2); + member_a.add_epoch(epoch_id_2, sframe_epoch_secret_2); + member_b.add_epoch(epoch_id_2, sframe_epoch_secret_2); + + // Remove only epoch 1 (not purge_before) and verify it fails + member_a.remove_epoch(epoch_id_1); + member_b.remove_epoch(epoch_id_1); + + CHECK(member_a.protect(epoch_id_1, sender_id, ct_out, plaintext, metadata) + .error() + .type() == SFrameErrorType::invalid_parameter_error); + CHECK(member_b.unprotect(pt_out, enc_data, metadata).error().type() == + SFrameErrorType::invalid_parameter_error); + + // Epoch 2 should still work + enc = member_a.protect(epoch_id_2, sender_id, ct_out, plaintext, metadata) + .unwrap(); + dec = to_bytes(member_b.unprotect(pt_out, enc, metadata).unwrap()); + CHECK(plaintext == dec); + + // Re-add epoch 1 with the same secret and verify it works again + member_a.add_epoch(epoch_id_1, sframe_epoch_secret_1); + member_b.add_epoch(epoch_id_1, sframe_epoch_secret_1); + + enc = member_a.protect(epoch_id_1, sender_id, ct_out, plaintext, metadata) + .unwrap(); + dec = to_bytes(member_b.unprotect(pt_out, enc, metadata).unwrap()); + CHECK(plaintext == dec); +}