Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions include/sframe/map.h
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,11 @@ class map : private vector<std::optional<std::pair<K, V>>, N>
return pos->value().second;
}

void erase(const K& key)
{
erase_if_key([key](const auto& other) { return other == key; });
}

template<typename F>
void erase_if_key(F&& f)
{
Expand Down
4 changes: 4 additions & 0 deletions include/sframe/sframe.h
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,7 @@ class Context
virtual ~Context();

Result<void> add_key(KeyID kid, KeyUsage usage, input_bytes key);
void remove_key(KeyID kid);

Result<output_bytes> protect(KeyID key_id,
output_bytes ciphertext,
Expand All @@ -134,6 +135,8 @@ class Context
CipherSuite suite;
map<KeyID, KeyRecord, SFRAME_MAX_KEYS> keys;

Result<void> require_key(KeyID key_id) const;

Result<output_bytes> protect_inner(const Header& header,
output_bytes ciphertext,
input_bytes plaintext,
Expand All @@ -160,6 +163,7 @@ class MLSContext : protected Context
Result<void> add_epoch(EpochID epoch_id,
input_bytes sframe_epoch_secret,
size_t sender_bits);
void remove_epoch(EpochID epoch_id);
void purge_before(EpochID keeper);

Result<output_bytes> protect(EpochID epoch_id,
Expand Down
30 changes: 30 additions & 0 deletions src/sframe.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,12 @@ Context::Context(CipherSuite suite_in)

Context::~Context() = default;

void
Context::remove_key(KeyID key_id)
{
keys.erase(key_id);
}

Result<void>
Context::add_key(KeyID key_id, KeyUsage usage, input_bytes base_key)
{
Expand Down Expand Up @@ -118,12 +124,23 @@ form_aad(const Header& header, input_bytes metadata)
return aad;
}

Result<void>
Context::require_key(KeyID key_id) const
{
if (!keys.contains(key_id)) {
return SFrameError(SFrameErrorType::invalid_parameter_error,
"Unknown key ID");
}
return Result<void>::ok();
}

Result<output_bytes>
Context::protect(KeyID key_id,
output_bytes ciphertext,
input_bytes plaintext,
input_bytes metadata)
{
SFRAME_VOID_OR_RETURN(require_key(key_id));
auto& key_record = keys.at(key_id);
const auto counter = key_record.counter;
key_record.counter += 1;
Expand Down Expand Up @@ -166,6 +183,7 @@ Context::protect_inner(const Header& header,
"Ciphertext too small for cipher overhead");
}

SFRAME_VOID_OR_RETURN(require_key(header.key_id));
const auto& key_and_salt = keys.at(header.key_id);

SFRAME_VALUE_OR_RETURN(aad, form_aad(header, metadata));
Expand All @@ -190,6 +208,7 @@ Context::unprotect_inner(const Header& header,
"Plaintext too small for decrypted value");
}

SFRAME_VOID_OR_RETURN(require_key(header.key_id));
const auto& key_and_salt = keys.at(header.key_id);

SFRAME_VALUE_OR_RETURN(aad, form_aad(header, metadata));
Expand Down Expand Up @@ -326,6 +345,17 @@ MLSContext::EpochKeys::base_key(CipherSuite ciphersuite,
ciphersuite, sframe_epoch_secret, enc_sender_id, hash_size);
}

void
MLSContext::remove_epoch(EpochID epoch_id)
{
purge_epoch(epoch_id);

const auto idx = epoch_id & epoch_mask;
if (idx < epoch_cache.size()) {
epoch_cache[idx].reset();
}
}

void
MLSContext::purge_epoch(EpochID epoch_id)
{
Expand Down
113 changes: 113 additions & 0 deletions test/sframe.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -231,3 +231,116 @@ TEST_CASE("MLS Failure after Purge")
const auto dec_ab_2 = member_b.unprotect(pt_out, enc_ab_2, metadata).unwrap();
CHECK(plaintext == to_bytes(dec_ab_2));
}

TEST_CASE("SFrame Context Remove Key")
{
const auto suite = CipherSuite::AES_GCM_128_SHA256;
const auto kid = KeyID(0x07);
const auto key = from_hex("000102030405060708090a0b0c0d0e0f");
const auto plaintext = from_hex("00010203");
const auto metadata = bytes{};

auto pt_out = bytes(plaintext.size());
auto ct_out = bytes(plaintext.size() + Context::max_overhead);

auto sender = Context(suite);
auto receiver = Context(suite);
sender.add_key(kid, KeyUsage::protect, key).unwrap();
receiver.add_key(kid, KeyUsage::unprotect, key).unwrap();

// Protect and unprotect succeed before removal
auto encrypted =
to_bytes(sender.protect(kid, ct_out, plaintext, metadata).unwrap());
auto decrypted =
to_bytes(receiver.unprotect(pt_out, encrypted, metadata).unwrap());
CHECK(decrypted == plaintext);

// Remove sender key and verify protect fails
sender.remove_key(kid);
CHECK(sender.protect(kid, ct_out, plaintext, metadata).error().type() ==
SFrameErrorType::invalid_parameter_error);

// Remove receiver key and verify unprotect fails
receiver.remove_key(kid);
CHECK(receiver.unprotect(pt_out, encrypted, metadata).error().type() ==
SFrameErrorType::invalid_parameter_error);

// Re-add keys and verify round-trip works again
sender.add_key(kid, KeyUsage::protect, key).unwrap();
receiver.add_key(kid, KeyUsage::unprotect, key).unwrap();

encrypted =
to_bytes(sender.protect(kid, ct_out, plaintext, metadata).unwrap());
decrypted =
to_bytes(receiver.unprotect(pt_out, encrypted, metadata).unwrap());
CHECK(decrypted == plaintext);
}

TEST_CASE("SFrame Context Remove Key - Nonexistent Key")
{
const auto suite = CipherSuite::AES_GCM_128_SHA256;

auto ctx = Context(suite);

// Removing a key that was never added should not throw
CHECK_NOTHROW(ctx.remove_key(KeyID(0x99)));
}

TEST_CASE("MLS Remove Epoch")
{
const auto suite = CipherSuite::AES_GCM_128_SHA256;
const auto epoch_bits = 2;
const auto metadata = from_hex("00010203");
const auto plaintext = from_hex("04050607");
const auto sender_id = MLSContext::SenderID(0xA0A0A0A0);
const auto sframe_epoch_secret_1 = bytes(32, 1);
const auto sframe_epoch_secret_2 = bytes(32, 2);

auto pt_out = bytes(plaintext.size());
auto ct_out = bytes(plaintext.size() + Context::max_overhead);

auto member_a = MLSContext(suite, epoch_bits);
auto member_b = MLSContext(suite, epoch_bits);

// Install epoch 1 and verify round-trip
const auto epoch_id_1 = MLSContext::EpochID(1);
member_a.add_epoch(epoch_id_1, sframe_epoch_secret_1);
member_b.add_epoch(epoch_id_1, sframe_epoch_secret_1);

auto enc =
member_a.protect(epoch_id_1, sender_id, ct_out, plaintext, metadata)
.unwrap();
auto enc_data = to_bytes(enc);
auto dec = to_bytes(member_b.unprotect(pt_out, enc_data, metadata).unwrap());
CHECK(plaintext == dec);

// Install epoch 2
const auto epoch_id_2 = MLSContext::EpochID(2);
member_a.add_epoch(epoch_id_2, sframe_epoch_secret_2);
member_b.add_epoch(epoch_id_2, sframe_epoch_secret_2);

// Remove only epoch 1 (not purge_before) and verify it fails
member_a.remove_epoch(epoch_id_1);
member_b.remove_epoch(epoch_id_1);

CHECK(member_a.protect(epoch_id_1, sender_id, ct_out, plaintext, metadata)
.error()
.type() == SFrameErrorType::invalid_parameter_error);
CHECK(member_b.unprotect(pt_out, enc_data, metadata).error().type() ==
SFrameErrorType::invalid_parameter_error);

// Epoch 2 should still work
enc = member_a.protect(epoch_id_2, sender_id, ct_out, plaintext, metadata)
.unwrap();
dec = to_bytes(member_b.unprotect(pt_out, enc, metadata).unwrap());
CHECK(plaintext == dec);

// Re-add epoch 1 with the same secret and verify it works again
member_a.add_epoch(epoch_id_1, sframe_epoch_secret_1);
member_b.add_epoch(epoch_id_1, sframe_epoch_secret_1);

enc = member_a.protect(epoch_id_1, sender_id, ct_out, plaintext, metadata)
.unwrap();
dec = to_bytes(member_b.unprotect(pt_out, enc, metadata).unwrap());
CHECK(plaintext == dec);
}
Loading