diff --git a/include/sframe/result.h b/include/sframe/result.h index 5e9d775..5a896c3 100644 --- a/include/sframe/result.h +++ b/include/sframe/result.h @@ -48,9 +48,10 @@ class SFrameError const char* message_ = nullptr; }; -// Helper to convert SFrameError to appropriate exception type +#ifdef __cpp_exceptions void -throw_on_error(const SFrameError& error); +throw_sframe_error(const SFrameError& error); +#endif template class Result @@ -96,6 +97,17 @@ class Result bool is_err() const { return std::holds_alternative(data_); } +#ifdef __cpp_exceptions + T unwrap() + { + if (std::holds_alternative(data_)) { + throw_sframe_error(std::get(data_)); + } + + return std::move(std::get(data_)); + } +#endif + private: std::variant data_; }; @@ -135,24 +147,21 @@ class Result bool is_err() const { return error_.has_value(); } +#ifdef __cpp_exceptions + void unwrap() + { + if (error_.has_value()) { + throw_sframe_error(error_.value()); + } + } +#endif + private: std::optional error_; }; } // namespace SFRAME_NAMESPACE -// Unwrap a Result, throwing the corresponding exception on error. -// Use in functions that have NOT yet been migrated away from exceptions. -// Usage: const auto val = SFRAME_VALUE_OR_THROW(some_result_expr); -#define SFRAME_VALUE_OR_THROW(expr) \ - ([&]() { \ - auto _result = (expr); \ - if (_result.is_err()) { \ - SFRAME_NAMESPACE::throw_on_error(_result.error()); \ - } \ - return _result.value(); \ - }()) - // Unwrap a Result into `var`, propagating the error by early return. // Use in functions that already return Result. // Usage: SFRAME_VALUE_OR_RETURN(val, some_result_expr); diff --git a/include/sframe/sframe.h b/include/sframe/sframe.h index 6a6d52e..65b34ba 100644 --- a/include/sframe/sframe.h +++ b/include/sframe/sframe.h @@ -8,6 +8,10 @@ #include #include +#ifdef __cpp_exceptions +#include +#endif + #include // These constants define the size of certain internal data structures if @@ -28,6 +32,7 @@ namespace SFRAME_NAMESPACE { +#ifdef __cpp_exceptions struct crypto_error : std::runtime_error { crypto_error(); @@ -60,6 +65,7 @@ struct invalid_key_usage_error : std::runtime_error using parent = std::runtime_error; using parent::parent; }; +#endif enum class CipherSuite : uint16_t { @@ -111,15 +117,15 @@ class Context Context(CipherSuite suite); virtual ~Context(); - void add_key(KeyID kid, KeyUsage usage, input_bytes key); + Result add_key(KeyID kid, KeyUsage usage, input_bytes key); - output_bytes protect(KeyID key_id, - output_bytes ciphertext, - input_bytes plaintext, - input_bytes metadata); - output_bytes unprotect(output_bytes plaintext, - input_bytes ciphertext, - input_bytes metadata); + Result protect(KeyID key_id, + output_bytes ciphertext, + input_bytes plaintext, + input_bytes metadata); + Result unprotect(output_bytes plaintext, + input_bytes ciphertext, + input_bytes metadata); static constexpr size_t max_overhead = 17 + 16; static constexpr size_t max_metadata_size = 512; @@ -150,29 +156,30 @@ class MLSContext : protected Context MLSContext(CipherSuite suite_in, size_t epoch_bits_in); - void add_epoch(EpochID epoch_id, input_bytes sframe_epoch_secret); - void add_epoch(EpochID epoch_id, - input_bytes sframe_epoch_secret, - size_t sender_bits); + Result add_epoch(EpochID epoch_id, input_bytes sframe_epoch_secret); + Result add_epoch(EpochID epoch_id, + input_bytes sframe_epoch_secret, + size_t sender_bits); void purge_before(EpochID keeper); - output_bytes protect(EpochID epoch_id, - SenderID sender_id, - output_bytes ciphertext, - input_bytes plaintext, - input_bytes metadata); - output_bytes protect(EpochID epoch_id, - SenderID sender_id, - ContextID context_id, - output_bytes ciphertext, - input_bytes plaintext, - input_bytes metadata); - - output_bytes unprotect(output_bytes plaintext, - input_bytes ciphertext, - input_bytes metadata); + Result protect(EpochID epoch_id, + SenderID sender_id, + output_bytes ciphertext, + input_bytes plaintext, + input_bytes metadata); + Result protect(EpochID epoch_id, + SenderID sender_id, + ContextID context_id, + output_bytes ciphertext, + input_bytes plaintext, + input_bytes metadata); + + Result unprotect(output_bytes plaintext, + input_bytes ciphertext, + input_bytes metadata); private: + // NOLINTBEGIN(clang-analyzer-core.uninitialized.Assign) struct EpochKeys { static constexpr size_t max_secret_size = 64; @@ -184,20 +191,22 @@ class MLSContext : protected Context uint64_t max_sender_id; uint64_t max_context_id; - EpochKeys(EpochID full_epoch_in, - input_bytes sframe_epoch_secret_in, - size_t epoch_bits, - size_t sender_bits_in); + EpochKeys() = default; + static Result create(EpochID full_epoch_in, + input_bytes sframe_epoch_secret_in, + size_t epoch_bits, + size_t sender_bits_in); Result> base_key(CipherSuite suite, SenderID sender_id) const; }; + // NOLINTEND(clang-analyzer-core.uninitialized.Assign) void purge_epoch(EpochID epoch_id); - KeyID form_key_id(EpochID epoch_id, - SenderID sender_id, - ContextID context_id) const; - void ensure_key(KeyID key_id, KeyUsage usage); + Result form_key_id(EpochID epoch_id, + SenderID sender_id, + ContextID context_id) const; + Result ensure_key(KeyID key_id, KeyUsage usage); const size_t epoch_bits; const size_t epoch_mask; diff --git a/src/crypto_boringssl.cpp b/src/crypto_boringssl.cpp index 5d1d3d9..96dd5ba 100644 --- a/src/crypto_boringssl.cpp +++ b/src/crypto_boringssl.cpp @@ -17,10 +17,12 @@ namespace SFRAME_NAMESPACE { /// Convert between native identifiers / errors and OpenSSL ones /// +#ifdef __cpp_exceptions crypto_error::crypto_error() : std::runtime_error(ERR_error_string(ERR_get_error(), nullptr)) { } +#endif static Result openssl_digest_type(CipherSuite suite) diff --git a/src/crypto_openssl11.cpp b/src/crypto_openssl11.cpp index abea601..9b33785 100644 --- a/src/crypto_openssl11.cpp +++ b/src/crypto_openssl11.cpp @@ -23,10 +23,12 @@ using scoped_hmac_ctx = std::unique_ptr; /// Convert between native identifiers / errors and OpenSSL ones /// +#ifdef __cpp_exceptions crypto_error::crypto_error() : std::runtime_error(ERR_error_string(ERR_get_error(), nullptr)) { } +#endif static Result openssl_digest_type(CipherSuite suite) diff --git a/src/crypto_openssl3.cpp b/src/crypto_openssl3.cpp index 85bf3ea..36fb125 100644 --- a/src/crypto_openssl3.cpp +++ b/src/crypto_openssl3.cpp @@ -17,10 +17,12 @@ namespace SFRAME_NAMESPACE { /// Convert between native identifiers / errors and OpenSSL ones /// +#ifdef __cpp_exceptions crypto_error::crypto_error() : std::runtime_error(ERR_error_string(ERR_get_error(), nullptr)) { } +#endif static Result openssl_cipher(CipherSuite suite) diff --git a/src/result.cpp b/src/result.cpp index 038f82b..6fae585 100644 --- a/src/result.cpp +++ b/src/result.cpp @@ -1,12 +1,25 @@ -#include #include namespace SFRAME_NAMESPACE { +#ifdef __cpp_exceptions +unsupported_ciphersuite_error::unsupported_ciphersuite_error() + : std::runtime_error("Unsupported ciphersuite") +{ +} + +authentication_error::authentication_error() + : std::runtime_error("AEAD authentication failure") +{ +} + void -throw_on_error(const SFrameError& error) +throw_sframe_error(const SFrameError& error) { switch (error.type()) { + case SFrameErrorType::internal_error: + throw std::runtime_error(error.message() ? error.message() + : "SFrame internal error"); case SFrameErrorType::buffer_too_small_error: throw buffer_too_small_error(error.message()); case SFrameErrorType::invalid_parameter_error: @@ -19,9 +32,8 @@ throw_on_error(const SFrameError& error) throw authentication_error(); case SFrameErrorType::invalid_key_usage_error: throw invalid_key_usage_error(error.message()); - default: - throw std::runtime_error(error.message()); } } +#endif } // namespace SFRAME_NAMESPACE diff --git a/src/sframe.cpp b/src/sframe.cpp index e1f7e2b..529c81d 100644 --- a/src/sframe.cpp +++ b/src/sframe.cpp @@ -5,20 +5,6 @@ namespace SFRAME_NAMESPACE { -/// -/// Errors -/// - -unsupported_ciphersuite_error::unsupported_ciphersuite_error() - : std::runtime_error("Unsupported ciphersuite") -{ -} - -authentication_error::authentication_error() - : std::runtime_error("AEAD authentication failure") -{ -} - /// /// KeyRecord /// @@ -95,12 +81,13 @@ Context::Context(CipherSuite suite_in) Context::~Context() = default; -void +Result Context::add_key(KeyID key_id, KeyUsage usage, input_bytes base_key) { - keys.emplace(key_id, - SFRAME_VALUE_OR_THROW( - KeyRecord::from_base_key(suite, key_id, usage, base_key))); + SFRAME_VALUE_OR_RETURN( + record, KeyRecord::from_base_key(suite, key_id, usage, base_key)); + keys.emplace(key_id, record); + return Result::ok(); } static owned_bytes @@ -131,7 +118,7 @@ form_aad(const Header& header, input_bytes metadata) return aad; } -output_bytes +Result Context::protect(KeyID key_id, output_bytes ciphertext, input_bytes plaintext, @@ -144,25 +131,27 @@ Context::protect(KeyID key_id, const auto header = Header{ key_id, counter }; const auto header_data = header.encoded(); if (ciphertext.size() < header_data.size()) { - throw buffer_too_small_error("Ciphertext too small for SFrame header"); + return SFrameError(SFrameErrorType::buffer_too_small_error, + "Ciphertext too small for SFrame header"); } std::copy(header_data.begin(), header_data.end(), ciphertext.begin()); auto inner_ciphertext = ciphertext.subspan(header_data.size()); - auto final_ciphertext = SFRAME_VALUE_OR_THROW( + SFRAME_VALUE_OR_RETURN( + final_ciphertext, Context::protect_inner(header, inner_ciphertext, plaintext, metadata)); return ciphertext.first(header_data.size() + final_ciphertext.size()); } -output_bytes +Result Context::unprotect(output_bytes plaintext, input_bytes ciphertext, input_bytes metadata) { - const auto header = SFRAME_VALUE_OR_THROW(Header::parse(ciphertext)); + SFRAME_VALUE_OR_RETURN(header, Header::parse(ciphertext)); const auto inner_ciphertext = ciphertext.subspan(header.size()); - return SFRAME_VALUE_OR_THROW( - Context::unprotect_inner(header, plaintext, inner_ciphertext, metadata)); + return Context::unprotect_inner( + header, plaintext, inner_ciphertext, metadata); } Result @@ -220,13 +209,13 @@ MLSContext::MLSContext(CipherSuite suite_in, size_t epoch_bits_in) epoch_cache.resize(1 << epoch_bits_in); } -void +Result MLSContext::add_epoch(EpochID epoch_id, input_bytes sframe_epoch_secret) { - add_epoch(epoch_id, sframe_epoch_secret, 0); + return add_epoch(epoch_id, sframe_epoch_secret, 0); } -void +Result MLSContext::add_epoch(EpochID epoch_id, input_bytes sframe_epoch_secret, size_t sender_bits) @@ -238,7 +227,11 @@ MLSContext::add_epoch(EpochID epoch_id, purge_epoch(epoch->full_epoch); } - epoch.emplace(epoch_id, sframe_epoch_secret, epoch_bits, sender_bits); + SFRAME_VALUE_OR_RETURN( + new_epoch, + EpochKeys::create(epoch_id, sframe_epoch_secret, epoch_bits, sender_bits)); + epoch.emplace(std::move(new_epoch)); + return Result::ok(); } void @@ -252,7 +245,7 @@ MLSContext::purge_before(EpochID keeper) } } -output_bytes +Result MLSContext::protect(EpochID epoch_id, SenderID sender_id, output_bytes ciphertext, @@ -262,7 +255,7 @@ MLSContext::protect(EpochID epoch_id, return protect(epoch_id, sender_id, 0, ciphertext, plaintext, metadata); } -output_bytes +Result MLSContext::protect(EpochID epoch_id, SenderID sender_id, ContextID context_id, @@ -270,49 +263,55 @@ MLSContext::protect(EpochID epoch_id, input_bytes plaintext, input_bytes metadata) { - auto key_id = form_key_id(epoch_id, sender_id, context_id); - ensure_key(key_id, KeyUsage::protect); + SFRAME_VALUE_OR_RETURN(key_id, form_key_id(epoch_id, sender_id, context_id)); + SFRAME_VOID_OR_RETURN(ensure_key(key_id, KeyUsage::protect)); return Context::protect(key_id, ciphertext, plaintext, metadata); } -output_bytes +Result MLSContext::unprotect(output_bytes plaintext, input_bytes ciphertext, input_bytes metadata) { - const auto header = SFRAME_VALUE_OR_THROW(Header::parse(ciphertext)); + SFRAME_VALUE_OR_RETURN(header, Header::parse(ciphertext)); const auto inner_ciphertext = ciphertext.subspan(header.size()); - ensure_key(header.key_id, KeyUsage::unprotect); - return SFRAME_VALUE_OR_THROW( - Context::unprotect_inner(header, plaintext, inner_ciphertext, metadata)); + SFRAME_VOID_OR_RETURN(ensure_key(header.key_id, KeyUsage::unprotect)); + return Context::unprotect_inner( + header, plaintext, inner_ciphertext, metadata); } -MLSContext::EpochKeys::EpochKeys(MLSContext::EpochID full_epoch_in, - input_bytes sframe_epoch_secret_in, - size_t epoch_bits, - size_t sender_bits_in) - : full_epoch(full_epoch_in) - , sframe_epoch_secret(sframe_epoch_secret_in) - , sender_bits(sender_bits_in) +Result +MLSContext::EpochKeys::create(MLSContext::EpochID full_epoch_in, + input_bytes sframe_epoch_secret_in, + size_t epoch_bits, + size_t sender_bits_in) { static constexpr uint64_t one = 1; static constexpr size_t key_id_bits = 64; - if (sender_bits > key_id_bits - epoch_bits) { - throw invalid_parameter_error("Sender ID field too large"); + EpochKeys epoch_keys; + epoch_keys.full_epoch = full_epoch_in; + epoch_keys.sframe_epoch_secret = sframe_epoch_secret_in; + epoch_keys.sender_bits = sender_bits_in; + + if (epoch_keys.sender_bits > key_id_bits - epoch_bits) { + return SFrameError(SFrameErrorType::invalid_parameter_error, + "Sender ID field too large"); } // XXX(RLB) We use 0 as a signifier that the sender takes the rest of the key // ID, and context IDs are not allowed. This would be more explicit if we // used std::optional, but would require more modern C++. - if (sender_bits == 0) { - sender_bits = key_id_bits - epoch_bits; + if (epoch_keys.sender_bits == 0) { + epoch_keys.sender_bits = key_id_bits - epoch_bits; } - context_bits = key_id_bits - sender_bits - epoch_bits; - max_sender_id = (one << sender_bits) - 1; - max_context_id = (one << context_bits) - 1; + epoch_keys.context_bits = key_id_bits - epoch_keys.sender_bits - epoch_bits; + epoch_keys.max_sender_id = (one << epoch_keys.sender_bits) - 1; + epoch_keys.max_context_id = (one << epoch_keys.context_bits) - 1; + + return epoch_keys; } Result> @@ -336,7 +335,7 @@ MLSContext::purge_epoch(EpochID epoch_id) [&](const auto& epoch) { return (epoch & epoch_bits) == drop_bits; }); } -KeyID +Result MLSContext::form_key_id(EpochID epoch_id, SenderID sender_id, ContextID context_id) const @@ -344,15 +343,18 @@ MLSContext::form_key_id(EpochID epoch_id, auto epoch_index = epoch_id & epoch_mask; auto& epoch = epoch_cache[epoch_index]; if (!epoch) { - throw invalid_parameter_error("Unknown epoch"); + return SFrameError(SFrameErrorType::invalid_parameter_error, + "Unknown epoch"); } if (sender_id > epoch->max_sender_id) { - throw invalid_parameter_error("Sender ID overflow"); + return SFrameError(SFrameErrorType::invalid_parameter_error, + "Sender ID overflow"); } if (context_id > epoch->max_context_id) { - throw invalid_parameter_error("Context ID overflow"); + return SFrameError(SFrameErrorType::invalid_parameter_error, + "Context ID overflow"); } auto sender_part = uint64_t(sender_id) << epoch_bits; @@ -364,25 +366,25 @@ MLSContext::form_key_id(EpochID epoch_id, return KeyID(context_part | sender_part | epoch_index); } -void +Result MLSContext::ensure_key(KeyID key_id, KeyUsage usage) { // If the required key already exists, we are done const auto epoch_index = key_id & epoch_mask; auto& epoch = epoch_cache[epoch_index]; if (!epoch) { - throw invalid_parameter_error("Unknown epoch"); + return SFrameError(SFrameErrorType::invalid_parameter_error, + "Unknown epoch"); } if (keys.contains(key_id)) { - return; + return Result::ok(); } // Otherwise, derive a key and implant it const auto sender_id = key_id >> epoch_bits; - Context::add_key( - key_id, usage, SFRAME_VALUE_OR_THROW(epoch->base_key(suite, sender_id))); - return; + SFRAME_VALUE_OR_RETURN(base, epoch->base_key(suite, sender_id)); + return Context::add_key(key_id, usage, base); } } // namespace SFRAME_NAMESPACE diff --git a/test/common.h b/test/common.h index ad40edf..ccf2e67 100644 --- a/test/common.h +++ b/test/common.h @@ -1,4 +1,5 @@ #include +#include #include #include diff --git a/test/header.cpp b/test/header.cpp index 4557106..531b551 100644 --- a/test/header.cpp +++ b/test/header.cpp @@ -38,7 +38,7 @@ TEST_CASE("Header Known-Answer") for (const auto& tc : cases) { // Decode - const auto decoded = SFRAME_VALUE_OR_THROW(Header::parse(tc.encoding)); + const auto decoded = Header::parse(tc.encoding).unwrap(); REQUIRE(decoded.key_id == tc.key_id); REQUIRE(decoded.counter == tc.counter); REQUIRE(decoded.size() == tc.encoding.size()); diff --git a/test/sframe.cpp b/test/sframe.cpp index 0200634..d665026 100644 --- a/test/sframe.cpp +++ b/test/sframe.cpp @@ -39,14 +39,15 @@ TEST_CASE("SFrame Round-Trip") auto& key = pair.second; auto send = Context(suite); - send.add_key(kid, KeyUsage::protect, key); + send.add_key(kid, KeyUsage::protect, key).unwrap(); auto recv = Context(suite); - recv.add_key(kid, KeyUsage::unprotect, key); + recv.add_key(kid, KeyUsage::unprotect, key).unwrap(); for (int i = 0; i < rounds; i++) { - auto encrypted = to_bytes(send.protect(kid, ct_out, plaintext, {})); - auto decrypted = to_bytes(recv.unprotect(pt_out, encrypted, {})); + auto encrypted = + to_bytes(send.protect(kid, ct_out, plaintext, {}).unwrap()); + auto decrypted = to_bytes(recv.unprotect(pt_out, encrypted, {}).unwrap()); CHECK(decrypted == plaintext); } } @@ -81,18 +82,22 @@ TEST_CASE("MLS Round-Trip") for (MLSContext::EpochID epoch_id = 0; epoch_id < test_epochs; epoch_id++) { const auto sframe_epoch_secret = bytes(8, uint8_t(epoch_id)); - member_a.add_epoch(epoch_id, sframe_epoch_secret); - member_b.add_epoch(epoch_id, sframe_epoch_secret); + member_a.add_epoch(epoch_id, sframe_epoch_secret).unwrap(); + member_b.add_epoch(epoch_id, sframe_epoch_secret).unwrap(); for (int i = 0; i < epoch_rounds; i++) { auto encrypted_ab = - member_a.protect(epoch_id, sender_id_a, ct_out, plaintext, metadata); - auto decrypted_ab = member_b.unprotect(pt_out, encrypted_ab, metadata); + member_a.protect(epoch_id, sender_id_a, ct_out, plaintext, metadata) + .unwrap(); + auto decrypted_ab = + member_b.unprotect(pt_out, encrypted_ab, metadata).unwrap(); CHECK(plaintext == to_bytes(decrypted_ab)); auto encrypted_ba = - member_b.protect(epoch_id, sender_id_b, ct_out, plaintext, metadata); - auto decrypted_ba = member_a.unprotect(pt_out, encrypted_ba, metadata); + member_b.protect(epoch_id, sender_id_b, ct_out, plaintext, metadata) + .unwrap(); + auto decrypted_ba = + member_a.unprotect(pt_out, encrypted_ba, metadata).unwrap(); CHECK(plaintext == to_bytes(decrypted_ba)); } } @@ -132,31 +137,46 @@ TEST_CASE("MLS Round-Trip with context") for (MLSContext::EpochID epoch_id = 0; epoch_id < test_epochs; epoch_id++) { const auto sframe_epoch_secret = bytes(8, uint8_t(epoch_id)); - member_a_0.add_epoch(epoch_id, sframe_epoch_secret, sender_id_bits); - member_a_1.add_epoch(epoch_id, sframe_epoch_secret, sender_id_bits); - member_b.add_epoch(epoch_id, sframe_epoch_secret); + member_a_0.add_epoch(epoch_id, sframe_epoch_secret, sender_id_bits) + .unwrap(); + member_a_1.add_epoch(epoch_id, sframe_epoch_secret, sender_id_bits) + .unwrap(); + member_b.add_epoch(epoch_id, sframe_epoch_secret).unwrap(); for (int i = 0; i < epoch_rounds; i++) { - auto encrypted_ab_0 = member_a_0.protect( - epoch_id, sender_id_a, context_id_0, ct_out_0, plaintext, metadata); - auto decrypted_ab_0 = - to_bytes(member_b.unprotect(pt_out, encrypted_ab_0, metadata)); + auto encrypted_ab_0 = member_a_0 + .protect(epoch_id, + sender_id_a, + context_id_0, + ct_out_0, + plaintext, + metadata) + .unwrap(); + auto decrypted_ab_0 = to_bytes( + member_b.unprotect(pt_out, encrypted_ab_0, metadata).unwrap()); CHECK(plaintext == decrypted_ab_0); - auto encrypted_ab_1 = member_a_1.protect( - epoch_id, sender_id_a, context_id_1, ct_out_1, plaintext, metadata); - auto decrypted_ab_1 = - to_bytes(member_b.unprotect(pt_out, encrypted_ab_1, metadata)); + auto encrypted_ab_1 = member_a_1 + .protect(epoch_id, + sender_id_a, + context_id_1, + ct_out_1, + plaintext, + metadata) + .unwrap(); + auto decrypted_ab_1 = to_bytes( + member_b.unprotect(pt_out, encrypted_ab_1, metadata).unwrap()); CHECK(plaintext == decrypted_ab_1); CHECK(to_bytes(encrypted_ab_0) != to_bytes(encrypted_ab_1)); - auto encrypted_ba = member_b.protect( - epoch_id, sender_id_b, ct_out_0, plaintext, metadata); - auto decrypted_ba_0 = - to_bytes(member_a_0.unprotect(pt_out, encrypted_ba, metadata)); - auto decrypted_ba_1 = - to_bytes(member_a_1.unprotect(pt_out, encrypted_ba, metadata)); + auto encrypted_ba = + member_b.protect(epoch_id, sender_id_b, ct_out_0, plaintext, metadata) + .unwrap(); + auto decrypted_ba_0 = to_bytes( + member_a_0.unprotect(pt_out, encrypted_ba, metadata).unwrap()); + auto decrypted_ba_1 = to_bytes( + member_a_1.unprotect(pt_out, encrypted_ba, metadata).unwrap()); CHECK(plaintext == decrypted_ba_0); CHECK(plaintext == decrypted_ba_1); } @@ -182,30 +202,32 @@ TEST_CASE("MLS Failure after Purge") // Install epoch 1 and create a cipihertext const auto epoch_id_1 = MLSContext::EpochID(1); - member_a.add_epoch(epoch_id_1, sframe_epoch_secret_1); - member_b.add_epoch(epoch_id_1, sframe_epoch_secret_1); + member_a.add_epoch(epoch_id_1, sframe_epoch_secret_1).unwrap(); + member_b.add_epoch(epoch_id_1, sframe_epoch_secret_1).unwrap(); const auto enc_ab_1 = - member_a.protect(epoch_id_1, sender_id_a, ct_out, plaintext, metadata); + member_a.protect(epoch_id_1, sender_id_a, ct_out, plaintext, metadata) + .unwrap(); const auto enc_ab_1_data = to_bytes(enc_ab_1); // Install epoch 2 const auto epoch_id_2 = MLSContext::EpochID(2); - member_a.add_epoch(epoch_id_2, sframe_epoch_secret_2); - member_b.add_epoch(epoch_id_2, sframe_epoch_secret_2); + member_a.add_epoch(epoch_id_2, sframe_epoch_secret_2).unwrap(); + member_b.add_epoch(epoch_id_2, sframe_epoch_secret_2).unwrap(); // Purge epoch 1 and verify failure member_a.purge_before(epoch_id_2); member_b.purge_before(epoch_id_2); - CHECK_THROWS_AS( - member_a.protect(epoch_id_1, sender_id_a, ct_out, plaintext, metadata), - invalid_parameter_error); - CHECK_THROWS_AS(member_b.unprotect(pt_out, enc_ab_1_data, metadata), - invalid_parameter_error); + CHECK(member_a.protect(epoch_id_1, sender_id_a, ct_out, plaintext, metadata) + .error() + .type() == SFrameErrorType::invalid_parameter_error); + CHECK(member_b.unprotect(pt_out, enc_ab_1_data, metadata).error().type() == + SFrameErrorType::invalid_parameter_error); const auto enc_ab_2 = - member_a.protect(epoch_id_2, sender_id_a, ct_out, plaintext, metadata); - const auto dec_ab_2 = member_b.unprotect(pt_out, enc_ab_2, metadata); + member_a.protect(epoch_id_2, sender_id_a, ct_out, plaintext, metadata) + .unwrap(); + const auto dec_ab_2 = member_b.unprotect(pt_out, enc_ab_2, metadata).unwrap(); CHECK(plaintext == to_bytes(dec_ab_2)); } diff --git a/test/vectors.cpp b/test/vectors.cpp index e1a6a30..28cbbe7 100644 --- a/test/vectors.cpp +++ b/test/vectors.cpp @@ -66,7 +66,7 @@ struct HeaderTestVector void verify() const { // Decode - const auto decoded = SFRAME_VALUE_OR_THROW(Header::parse(encoded)); + const auto decoded = Header::parse(encoded).unwrap(); REQUIRE(decoded.key_id == kid); REQUIRE(decoded.counter == ctr); REQUIRE(decoded.size() == encoded.data.size()); @@ -103,14 +103,14 @@ struct AesCtrHmacTestVector { // Seal auto ciphertext = bytes(ct.data.size()); - const auto ct_out = SFRAME_VALUE_OR_THROW( - seal(cipher_suite, key, nonce, ciphertext, aad, pt)); + const auto ct_out = + seal(cipher_suite, key, nonce, ciphertext, aad, pt).unwrap(); REQUIRE(ct_out == ct); // Open auto plaintext = bytes(pt.data.size()); const auto pt_out = - SFRAME_VALUE_OR_THROW(open(cipher_suite, key, nonce, plaintext, aad, ct)); + open(cipher_suite, key, nonce, plaintext, aad, ct).unwrap(); REQUIRE(pt_out == pt); } }; @@ -148,16 +148,16 @@ struct SFrameTestVector { // Protect auto send_ctx = Context(cipher_suite); - send_ctx.add_key(kid, KeyUsage::protect, base_key); + send_ctx.add_key(kid, KeyUsage::protect, base_key).unwrap(); auto ct_data = owned_bytes<128>(); auto next_ctr = uint64_t(0); while (next_ctr < ctr) { - send_ctx.protect(kid, ct_data, pt, metadata); + send_ctx.protect(kid, ct_data, pt, metadata).unwrap(); next_ctr += 1; } - const auto ct_out = send_ctx.protect(kid, ct_data, pt, metadata); + const auto ct_out = send_ctx.protect(kid, ct_data, pt, metadata).unwrap(); const auto act_ct_hex = to_hex(ct_out); const auto exp_ct_hex = to_hex(ct); @@ -167,10 +167,10 @@ struct SFrameTestVector // Unprotect auto recv_ctx = Context(cipher_suite); - recv_ctx.add_key(kid, KeyUsage::unprotect, base_key); + recv_ctx.add_key(kid, KeyUsage::unprotect, base_key).unwrap(); auto pt_data = owned_bytes<128>(); - auto pt_out = recv_ctx.unprotect(pt_data, ct, metadata); + auto pt_out = recv_ctx.unprotect(pt_data, ct, metadata).unwrap(); REQUIRE(pt_out == pt); } };