Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 23 additions & 14 deletions include/sframe/result.h
Original file line number Diff line number Diff line change
Expand Up @@ -48,9 +48,10 @@ class SFrameError
const char* message_ = nullptr;
};

// Helper to convert SFrameError to appropriate exception type
#ifdef __cpp_exceptions
void
throw_on_error(const SFrameError& error);
throw_sframe_error(const SFrameError& error);
#endif

template<typename T>
class Result
Expand Down Expand Up @@ -96,6 +97,17 @@ class Result

bool is_err() const { return std::holds_alternative<SFrameError>(data_); }

#ifdef __cpp_exceptions
T unwrap()
{
if (std::holds_alternative<SFrameError>(data_)) {
throw_sframe_error(std::get<SFrameError>(data_));
}

return std::move(std::get<T>(data_));
}
#endif

private:
std::variant<T, SFrameError> data_;
};
Expand Down Expand Up @@ -135,24 +147,21 @@ class Result<void>

bool is_err() const { return error_.has_value(); }

#ifdef __cpp_exceptions
void unwrap()
{
if (error_.has_value()) {
throw_sframe_error(error_.value());
}
}
#endif

private:
std::optional<SFrameError> error_;
};

} // namespace SFRAME_NAMESPACE

// Unwrap a Result<T>, throwing the corresponding exception on error.
// Use in functions that have NOT yet been migrated away from exceptions.
// Usage: const auto val = SFRAME_VALUE_OR_THROW(some_result_expr);
#define SFRAME_VALUE_OR_THROW(expr) \
([&]() { \
auto _result = (expr); \
if (_result.is_err()) { \
SFRAME_NAMESPACE::throw_on_error(_result.error()); \
} \
return _result.value(); \
}())

// Unwrap a Result<T> into `var`, propagating the error by early return.
// Use in functions that already return Result<U>.
// Usage: SFRAME_VALUE_OR_RETURN(val, some_result_expr);
Expand Down
79 changes: 44 additions & 35 deletions include/sframe/sframe.h
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,10 @@
#include <sframe/result.h>
#include <sframe/vector.h>

#ifdef __cpp_exceptions
#include <stdexcept>
#endif

#include <namespace.h>

// These constants define the size of certain internal data structures if
Expand All @@ -28,6 +32,7 @@

namespace SFRAME_NAMESPACE {

#ifdef __cpp_exceptions
struct crypto_error : std::runtime_error
{
crypto_error();
Expand Down Expand Up @@ -60,6 +65,7 @@ struct invalid_key_usage_error : std::runtime_error
using parent = std::runtime_error;
using parent::parent;
};
#endif

enum class CipherSuite : uint16_t
{
Expand Down Expand Up @@ -111,15 +117,15 @@ class Context
Context(CipherSuite suite);
virtual ~Context();

void add_key(KeyID kid, KeyUsage usage, input_bytes key);
Result<void> add_key(KeyID kid, KeyUsage usage, input_bytes key);

output_bytes protect(KeyID key_id,
output_bytes ciphertext,
input_bytes plaintext,
input_bytes metadata);
output_bytes unprotect(output_bytes plaintext,
input_bytes ciphertext,
input_bytes metadata);
Result<output_bytes> protect(KeyID key_id,
output_bytes ciphertext,
input_bytes plaintext,
input_bytes metadata);
Result<output_bytes> unprotect(output_bytes plaintext,
input_bytes ciphertext,
input_bytes metadata);

static constexpr size_t max_overhead = 17 + 16;
static constexpr size_t max_metadata_size = 512;
Expand Down Expand Up @@ -150,29 +156,30 @@ class MLSContext : protected Context

MLSContext(CipherSuite suite_in, size_t epoch_bits_in);

void add_epoch(EpochID epoch_id, input_bytes sframe_epoch_secret);
void add_epoch(EpochID epoch_id,
input_bytes sframe_epoch_secret,
size_t sender_bits);
Result<void> add_epoch(EpochID epoch_id, input_bytes sframe_epoch_secret);
Result<void> add_epoch(EpochID epoch_id,
input_bytes sframe_epoch_secret,
size_t sender_bits);
void purge_before(EpochID keeper);

output_bytes protect(EpochID epoch_id,
SenderID sender_id,
output_bytes ciphertext,
input_bytes plaintext,
input_bytes metadata);
output_bytes protect(EpochID epoch_id,
SenderID sender_id,
ContextID context_id,
output_bytes ciphertext,
input_bytes plaintext,
input_bytes metadata);

output_bytes unprotect(output_bytes plaintext,
input_bytes ciphertext,
input_bytes metadata);
Result<output_bytes> protect(EpochID epoch_id,
SenderID sender_id,
output_bytes ciphertext,
input_bytes plaintext,
input_bytes metadata);
Result<output_bytes> protect(EpochID epoch_id,
SenderID sender_id,
ContextID context_id,
output_bytes ciphertext,
input_bytes plaintext,
input_bytes metadata);

Result<output_bytes> unprotect(output_bytes plaintext,
input_bytes ciphertext,
input_bytes metadata);

private:
// NOLINTBEGIN(clang-analyzer-core.uninitialized.Assign)
struct EpochKeys
{
static constexpr size_t max_secret_size = 64;
Expand All @@ -184,20 +191,22 @@ class MLSContext : protected Context
uint64_t max_sender_id;
uint64_t max_context_id;

EpochKeys(EpochID full_epoch_in,
input_bytes sframe_epoch_secret_in,
size_t epoch_bits,
size_t sender_bits_in);
EpochKeys() = default;
static Result<EpochKeys> create(EpochID full_epoch_in,
input_bytes sframe_epoch_secret_in,
size_t epoch_bits,
size_t sender_bits_in);
Result<owned_bytes<max_secret_size>> base_key(CipherSuite suite,
SenderID sender_id) const;
};
// NOLINTEND(clang-analyzer-core.uninitialized.Assign)

void purge_epoch(EpochID epoch_id);

KeyID form_key_id(EpochID epoch_id,
SenderID sender_id,
ContextID context_id) const;
void ensure_key(KeyID key_id, KeyUsage usage);
Result<KeyID> form_key_id(EpochID epoch_id,
SenderID sender_id,
ContextID context_id) const;
Result<void> ensure_key(KeyID key_id, KeyUsage usage);

const size_t epoch_bits;
const size_t epoch_mask;
Expand Down
2 changes: 2 additions & 0 deletions src/crypto_boringssl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,12 @@ namespace SFRAME_NAMESPACE {
/// Convert between native identifiers / errors and OpenSSL ones
///

#ifdef __cpp_exceptions
crypto_error::crypto_error()
: std::runtime_error(ERR_error_string(ERR_get_error(), nullptr))
{
}
#endif

static Result<const EVP_MD*>
openssl_digest_type(CipherSuite suite)
Expand Down
2 changes: 2 additions & 0 deletions src/crypto_openssl11.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,10 +23,12 @@ using scoped_hmac_ctx = std::unique_ptr<HMAC_CTX, decltype(&HMAC_CTX_free)>;
/// Convert between native identifiers / errors and OpenSSL ones
///

#ifdef __cpp_exceptions
crypto_error::crypto_error()
: std::runtime_error(ERR_error_string(ERR_get_error(), nullptr))
{
}
#endif

static Result<const EVP_MD*>
openssl_digest_type(CipherSuite suite)
Expand Down
2 changes: 2 additions & 0 deletions src/crypto_openssl3.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,12 @@ namespace SFRAME_NAMESPACE {
/// Convert between native identifiers / errors and OpenSSL ones
///

#ifdef __cpp_exceptions
crypto_error::crypto_error()
: std::runtime_error(ERR_error_string(ERR_get_error(), nullptr))
{
}
#endif

static Result<const EVP_CIPHER*>
openssl_cipher(CipherSuite suite)
Expand Down
20 changes: 16 additions & 4 deletions src/result.cpp
Original file line number Diff line number Diff line change
@@ -1,12 +1,25 @@
#include <sframe/result.h>
#include <sframe/sframe.h>

namespace SFRAME_NAMESPACE {

#ifdef __cpp_exceptions
unsupported_ciphersuite_error::unsupported_ciphersuite_error()
: std::runtime_error("Unsupported ciphersuite")
{
}

authentication_error::authentication_error()
: std::runtime_error("AEAD authentication failure")
{
}

void
throw_on_error(const SFrameError& error)
throw_sframe_error(const SFrameError& error)
{
switch (error.type()) {
case SFrameErrorType::internal_error:
throw std::runtime_error(error.message() ? error.message()
: "SFrame internal error");
case SFrameErrorType::buffer_too_small_error:
throw buffer_too_small_error(error.message());
case SFrameErrorType::invalid_parameter_error:
Expand All @@ -19,9 +32,8 @@ throw_on_error(const SFrameError& error)
throw authentication_error();
case SFrameErrorType::invalid_key_usage_error:
throw invalid_key_usage_error(error.message());
default:
throw std::runtime_error(error.message());
}
}
#endif

} // namespace SFRAME_NAMESPACE
Loading
Loading