Skip to content

Commit be07f2b

Browse files
authored
Merge pull request #85 from k-wasniowski/sframe-key-api-removal
sframe key api removal
2 parents 2049052 + e9b6131 commit be07f2b

4 files changed

Lines changed: 152 additions & 0 deletions

File tree

include/sframe/map.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,11 @@ class map : private vector<std::optional<std::pair<K, V>>, N>
5858
return pos->value().second;
5959
}
6060

61+
void erase(const K& key)
62+
{
63+
erase_if_key([key](const auto& other) { return other == key; });
64+
}
65+
6166
template<typename F>
6267
void erase_if_key(F&& f)
6368
{

include/sframe/sframe.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -118,6 +118,7 @@ class Context
118118
virtual ~Context();
119119

120120
Result<void> add_key(KeyID kid, KeyUsage usage, input_bytes key);
121+
void remove_key(KeyID kid);
121122

122123
Result<output_bytes> protect(KeyID key_id,
123124
output_bytes ciphertext,
@@ -134,6 +135,8 @@ class Context
134135
CipherSuite suite;
135136
map<KeyID, KeyRecord, SFRAME_MAX_KEYS> keys;
136137

138+
Result<void> require_key(KeyID key_id) const;
139+
137140
Result<output_bytes> protect_inner(const Header& header,
138141
output_bytes ciphertext,
139142
input_bytes plaintext,
@@ -160,6 +163,7 @@ class MLSContext : protected Context
160163
Result<void> add_epoch(EpochID epoch_id,
161164
input_bytes sframe_epoch_secret,
162165
size_t sender_bits);
166+
void remove_epoch(EpochID epoch_id);
163167
void purge_before(EpochID keeper);
164168

165169
Result<output_bytes> protect(EpochID epoch_id,

src/sframe.cpp

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,12 @@ Context::Context(CipherSuite suite_in)
8181

8282
Context::~Context() = default;
8383

84+
void
85+
Context::remove_key(KeyID key_id)
86+
{
87+
keys.erase(key_id);
88+
}
89+
8490
Result<void>
8591
Context::add_key(KeyID key_id, KeyUsage usage, input_bytes base_key)
8692
{
@@ -118,12 +124,23 @@ form_aad(const Header& header, input_bytes metadata)
118124
return aad;
119125
}
120126

127+
Result<void>
128+
Context::require_key(KeyID key_id) const
129+
{
130+
if (!keys.contains(key_id)) {
131+
return SFrameError(SFrameErrorType::invalid_parameter_error,
132+
"Unknown key ID");
133+
}
134+
return Result<void>::ok();
135+
}
136+
121137
Result<output_bytes>
122138
Context::protect(KeyID key_id,
123139
output_bytes ciphertext,
124140
input_bytes plaintext,
125141
input_bytes metadata)
126142
{
143+
SFRAME_VOID_OR_RETURN(require_key(key_id));
127144
auto& key_record = keys.at(key_id);
128145
const auto counter = key_record.counter;
129146
key_record.counter += 1;
@@ -166,6 +183,7 @@ Context::protect_inner(const Header& header,
166183
"Ciphertext too small for cipher overhead");
167184
}
168185

186+
SFRAME_VOID_OR_RETURN(require_key(header.key_id));
169187
const auto& key_and_salt = keys.at(header.key_id);
170188

171189
SFRAME_VALUE_OR_RETURN(aad, form_aad(header, metadata));
@@ -190,6 +208,7 @@ Context::unprotect_inner(const Header& header,
190208
"Plaintext too small for decrypted value");
191209
}
192210

211+
SFRAME_VOID_OR_RETURN(require_key(header.key_id));
193212
const auto& key_and_salt = keys.at(header.key_id);
194213

195214
SFRAME_VALUE_OR_RETURN(aad, form_aad(header, metadata));
@@ -326,6 +345,17 @@ MLSContext::EpochKeys::base_key(CipherSuite ciphersuite,
326345
ciphersuite, sframe_epoch_secret, enc_sender_id, hash_size);
327346
}
328347

348+
void
349+
MLSContext::remove_epoch(EpochID epoch_id)
350+
{
351+
purge_epoch(epoch_id);
352+
353+
const auto idx = epoch_id & epoch_mask;
354+
if (idx < epoch_cache.size()) {
355+
epoch_cache[idx].reset();
356+
}
357+
}
358+
329359
void
330360
MLSContext::purge_epoch(EpochID epoch_id)
331361
{

test/sframe.cpp

Lines changed: 113 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -231,3 +231,116 @@ TEST_CASE("MLS Failure after Purge")
231231
const auto dec_ab_2 = member_b.unprotect(pt_out, enc_ab_2, metadata).unwrap();
232232
CHECK(plaintext == to_bytes(dec_ab_2));
233233
}
234+
235+
TEST_CASE("SFrame Context Remove Key")
236+
{
237+
const auto suite = CipherSuite::AES_GCM_128_SHA256;
238+
const auto kid = KeyID(0x07);
239+
const auto key = from_hex("000102030405060708090a0b0c0d0e0f");
240+
const auto plaintext = from_hex("00010203");
241+
const auto metadata = bytes{};
242+
243+
auto pt_out = bytes(plaintext.size());
244+
auto ct_out = bytes(plaintext.size() + Context::max_overhead);
245+
246+
auto sender = Context(suite);
247+
auto receiver = Context(suite);
248+
sender.add_key(kid, KeyUsage::protect, key).unwrap();
249+
receiver.add_key(kid, KeyUsage::unprotect, key).unwrap();
250+
251+
// Protect and unprotect succeed before removal
252+
auto encrypted =
253+
to_bytes(sender.protect(kid, ct_out, plaintext, metadata).unwrap());
254+
auto decrypted =
255+
to_bytes(receiver.unprotect(pt_out, encrypted, metadata).unwrap());
256+
CHECK(decrypted == plaintext);
257+
258+
// Remove sender key and verify protect fails
259+
sender.remove_key(kid);
260+
CHECK(sender.protect(kid, ct_out, plaintext, metadata).error().type() ==
261+
SFrameErrorType::invalid_parameter_error);
262+
263+
// Remove receiver key and verify unprotect fails
264+
receiver.remove_key(kid);
265+
CHECK(receiver.unprotect(pt_out, encrypted, metadata).error().type() ==
266+
SFrameErrorType::invalid_parameter_error);
267+
268+
// Re-add keys and verify round-trip works again
269+
sender.add_key(kid, KeyUsage::protect, key).unwrap();
270+
receiver.add_key(kid, KeyUsage::unprotect, key).unwrap();
271+
272+
encrypted =
273+
to_bytes(sender.protect(kid, ct_out, plaintext, metadata).unwrap());
274+
decrypted =
275+
to_bytes(receiver.unprotect(pt_out, encrypted, metadata).unwrap());
276+
CHECK(decrypted == plaintext);
277+
}
278+
279+
TEST_CASE("SFrame Context Remove Key - Nonexistent Key")
280+
{
281+
const auto suite = CipherSuite::AES_GCM_128_SHA256;
282+
283+
auto ctx = Context(suite);
284+
285+
// Removing a key that was never added should not throw
286+
CHECK_NOTHROW(ctx.remove_key(KeyID(0x99)));
287+
}
288+
289+
TEST_CASE("MLS Remove Epoch")
290+
{
291+
const auto suite = CipherSuite::AES_GCM_128_SHA256;
292+
const auto epoch_bits = 2;
293+
const auto metadata = from_hex("00010203");
294+
const auto plaintext = from_hex("04050607");
295+
const auto sender_id = MLSContext::SenderID(0xA0A0A0A0);
296+
const auto sframe_epoch_secret_1 = bytes(32, 1);
297+
const auto sframe_epoch_secret_2 = bytes(32, 2);
298+
299+
auto pt_out = bytes(plaintext.size());
300+
auto ct_out = bytes(plaintext.size() + Context::max_overhead);
301+
302+
auto member_a = MLSContext(suite, epoch_bits);
303+
auto member_b = MLSContext(suite, epoch_bits);
304+
305+
// Install epoch 1 and verify round-trip
306+
const auto epoch_id_1 = MLSContext::EpochID(1);
307+
member_a.add_epoch(epoch_id_1, sframe_epoch_secret_1);
308+
member_b.add_epoch(epoch_id_1, sframe_epoch_secret_1);
309+
310+
auto enc =
311+
member_a.protect(epoch_id_1, sender_id, ct_out, plaintext, metadata)
312+
.unwrap();
313+
auto enc_data = to_bytes(enc);
314+
auto dec = to_bytes(member_b.unprotect(pt_out, enc_data, metadata).unwrap());
315+
CHECK(plaintext == dec);
316+
317+
// Install epoch 2
318+
const auto epoch_id_2 = MLSContext::EpochID(2);
319+
member_a.add_epoch(epoch_id_2, sframe_epoch_secret_2);
320+
member_b.add_epoch(epoch_id_2, sframe_epoch_secret_2);
321+
322+
// Remove only epoch 1 (not purge_before) and verify it fails
323+
member_a.remove_epoch(epoch_id_1);
324+
member_b.remove_epoch(epoch_id_1);
325+
326+
CHECK(member_a.protect(epoch_id_1, sender_id, ct_out, plaintext, metadata)
327+
.error()
328+
.type() == SFrameErrorType::invalid_parameter_error);
329+
CHECK(member_b.unprotect(pt_out, enc_data, metadata).error().type() ==
330+
SFrameErrorType::invalid_parameter_error);
331+
332+
// Epoch 2 should still work
333+
enc = member_a.protect(epoch_id_2, sender_id, ct_out, plaintext, metadata)
334+
.unwrap();
335+
dec = to_bytes(member_b.unprotect(pt_out, enc, metadata).unwrap());
336+
CHECK(plaintext == dec);
337+
338+
// Re-add epoch 1 with the same secret and verify it works again
339+
member_a.add_epoch(epoch_id_1, sframe_epoch_secret_1);
340+
member_b.add_epoch(epoch_id_1, sframe_epoch_secret_1);
341+
342+
enc = member_a.protect(epoch_id_1, sender_id, ct_out, plaintext, metadata)
343+
.unwrap();
344+
dec = to_bytes(member_b.unprotect(pt_out, enc, metadata).unwrap());
345+
CHECK(plaintext == dec);
346+
}

0 commit comments

Comments
 (0)