diff --git a/CMakeLists.txt b/CMakeLists.txt index 5759660..ad45086 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -20,6 +20,7 @@ find_package(blake3 CONFIG REQUIRED) set(EXTENSION_SOURCES src/crypto_extension.cpp src/crypto_hash.cpp + src/crypto_enc.cpp src/query_farm_telemetry.cpp ) diff --git a/src/crypto_enc.cpp b/src/crypto_enc.cpp new file mode 100644 index 0000000..ccf5b5c --- /dev/null +++ b/src/crypto_enc.cpp @@ -0,0 +1,341 @@ + + +#include "crypto_extension.hpp" +#include "crypto_enc.hpp" +#include "duckdb.hpp" +#include "duckdb/common/exception.hpp" +#include "duckdb/common/string_util.hpp" +#include "duckdb/function/scalar_function.hpp" +#include +#include +#include +#include +#include +#include +#include +#include + +namespace duckdb +{ + + constexpr int MODE_ENCRYPT = 1; + constexpr int MODE_DECRYPT = 0; + constexpr size_t GCM_TAG_SIZE = 16; + + void throwOpensslError(const std::string &prefix) + { + unsigned long err = ERR_get_error(); + char err_buf[256]; + ERR_error_string_n(err, err_buf, sizeof(err_buf)); + std::string error_msg = prefix + ": " + err_buf; + throw InternalException(error_msg); + } + + std::string generate_random_string(size_t len) + { + std::string out(len, '\0'); // allocate len bytes + if (RAND_bytes(reinterpret_cast(&out[0]), len) != 1) + { + throwOpensslError("RAND_bytes failed"); + } + return out; + } + + class EvpCipherContext + { + public: + EvpCipherContext(const EVP_CIPHER *cipher_) : cipher(cipher_), ctx(EVP_CIPHER_CTX_new(), + [](EVP_CIPHER_CTX *p) + { + if (p) + { + EVP_CIPHER_CTX_free(p); + } + }) + { + } + + bool NeedsTag() + { + return (EVP_CIPHER_get_flags(cipher) & EVP_CIPH_FLAG_AEAD_CIPHER) != 0; + } + + int IvLength() const + { + return EVP_CIPHER_iv_length(cipher); + } + + void SetIv(const unsigned char *iv_) + { + iv = iv_; + } + + void SetKey(const unsigned char *key_) + { + key = key_; + } + + void SetTag(unsigned char *tag_) + { + in_tag = tag_; + } + + std::string GetTag() + { + + std::string tag; + tag.resize(GCM_TAG_SIZE); + if (!EVP_CIPHER_CTX_ctrl(ctx.get(), EVP_CTRL_AEAD_GET_TAG, GCM_TAG_SIZE, reinterpret_cast(&tag[0]))) + { + throwOpensslError("Failed to get tag"); + } + return tag; + } + + void Init(int mode) + { + + if (!EVP_CipherInit_ex(ctx.get(), cipher, nullptr, nullptr, nullptr, mode)) + { + throwOpensslError("Failed to initialize cipher"); + } + + if (!EVP_CIPHER_CTX_ctrl(ctx.get(), EVP_CTRL_AEAD_SET_IVLEN, IvLength(), nullptr)) + { + throwOpensslError("Failed to set iv length"); + } + + if (!EVP_CipherInit_ex(ctx.get(), cipher, nullptr, key, iv, mode)) + { + throwOpensslError("Failed to initialize cipher"); + } + + if (NeedsTag() && mode == MODE_DECRYPT) + { + if (in_tag == nullptr) + { + throw InvalidInputException("Decryption mode for AEAD cipher requires a tag to be set"); + } + // Decryption mode, set expected tag + if (!EVP_CIPHER_CTX_ctrl(ctx.get(), EVP_CTRL_AEAD_SET_TAG, 16, in_tag)) + { + throwOpensslError("Failed to set expected tag"); + } + } + } + + std::string Update(const std::string &data) + { + return Update(reinterpret_cast(data.data()), data.size()); + } + + std::string Update(const unsigned char *data, size_t size) + { + int outlen = 0; + std::string out; + out.resize(size + EVP_CIPHER_block_size(cipher)); + if (EVP_CipherUpdate(ctx.get(), reinterpret_cast(&out[0]), &outlen, + data, size) != 1) + { + throwOpensslError("Failed to update cipher"); + } + out.resize(outlen); + return out; + } + + std::string Finalize() + { + + std::string result; + result.resize(EVP_CIPHER_block_size(cipher)); + int result_len = 0; + if (EVP_CipherFinal_ex(ctx.get(), reinterpret_cast(&result[0]), &result_len) != 1) + { + throwOpensslError("Failed to finalize cipher"); + } + + result.resize(result_len); + return result; + } + + EvpCipherContext(const EvpCipherContext &) = delete; + EvpCipherContext &operator=(const EvpCipherContext &) = delete; + + // Enable move operations (needed because copy is deleted) + EvpCipherContext(EvpCipherContext &&) = default; + EvpCipherContext &operator=(EvpCipherContext &&) = default; + + private: + // We don't own these + const EVP_CIPHER *cipher = nullptr; + const unsigned char *iv = nullptr; + const unsigned char *key = nullptr; + unsigned char *in_tag = nullptr; + + std::unique_ptr ctx; + }; + + // This is dramatically simplefied wrapper for an encrypted text blob + // It expects to be used with the Plaintext class + // It will do AES AEAD + class CipherText + { + + public: + /// @brief + /// @param cipher + /// @param ciphertext - this is the raw ciphertext bytes, it will be iv + encrypted bytes + tag + CipherText(const EVP_CIPHER *cipher, std::string ciphertext_) : ciphertext(ciphertext_), ctx(cipher) + { + } + + static CipherText Encrypt(const EVP_CIPHER *cipher, const unsigned char *key, const std::string &plaintext) + { + EvpCipherContext ctx(cipher); + + std::string iv = generate_random_string(ctx.IvLength()); + + ctx.SetKey(key); + ctx.SetIv(reinterpret_cast(iv.data())); + ctx.Init(MODE_ENCRYPT); // 1 for encrypt + + std::string encrypted = ctx.Update(plaintext); + encrypted += ctx.Finalize(); + + std::string ciphertext = iv + encrypted; + if (ctx.NeedsTag()) + { + std::string tag = ctx.GetTag(); + ciphertext += tag; + } + + return CipherText(cipher, ciphertext); + } + + const char *GetTag() const + { + // Tag is last 16 bytes + return ciphertext.data() + ciphertext.size() - 16; + } + + const char *GetIv() const + { + // IV is first iv_length bytes + return ciphertext.data(); + } + + const char *GetEncryptedData() const + { + // Encrypted data is between iv and tag + return ciphertext.data() + ctx.IvLength(); + } + + // std::string Update(const unsigned char *data, size_t size) + + std::string Decrypt(const unsigned char *key) + { + ctx.SetKey(key); + ctx.SetIv(reinterpret_cast(GetIv())); + ctx.SetTag(reinterpret_cast(const_cast(GetTag()))); + ctx.Init(MODE_DECRYPT); + std::string decrypted = ctx.Update(reinterpret_cast(GetEncryptedData()), + ciphertext.size() - ctx.IvLength() - 16); + decrypted += ctx.Finalize(); + return decrypted; + } + + std::string &GetValue() + { + return ciphertext; + } + + CipherText(const CipherText &) = delete; + CipherText &operator=(const CipherText &) = delete; + + // Enable move operations (needed because copy is deleted) + CipherText(CipherText &&) = default; + CipherText &operator=(CipherText &&) = default; + + std::string ciphertext; + EvpCipherContext ctx; + std::string tag = ""; + }; + + // // TODO: handle key and iv size checks + // // TODO: handle padding if necessary + // // TODO: handle output buffer size + // // TODO: handle different modes of operation + // evp_ctx.Init(cipher, key, iv); + // evp_ctx.Update(data, data_len); + // evp_ctx.Finalize(); + // } + + inline void CryptoScalarFun(DataChunk &args, ExpressionState &state, Vector &result, int mode) + { + // This is called with three arguments: + // 1. The cipher name + // 2. The key + // 3. The value + // + + auto &hash_function_name_vector = args.data[0]; + auto &key_vector = args.data[1]; + auto &value_vector = args.data[2]; + + TernaryExecutor::Execute( + hash_function_name_vector, key_vector, value_vector, result, args.size(), + [&](string_t cipher_name, string_t key, string_t value) + { + string algorithm(cipher_name.GetData(), cipher_name.GetSize()); + string key_str(key.GetData(), key.GetSize()); + string value_str(value.GetData(), value.GetSize()); + + // TODO: handle iv properly + // TODO: only allow the algorithm to be set once instead of per row + std::string algo_lower = StringUtil::Lower(algorithm); + const EVP_CIPHER *cipher = EVP_get_cipherbyname(algo_lower.c_str()); + + if (!cipher) + { + throw InvalidInputException("Invalid ciphername '" + algorithm + "'"); + } + + if (mode == 0) + { // decrypt + CipherText ct(cipher, value_str); + std::string decrypted = ct.Decrypt(reinterpret_cast(key_str.data())); + return StringVector::AddStringOrBlob(result, string_t(decrypted.data(), decrypted.size())); + } + else if (mode == 1) + { // encrypt + CipherText ct = CipherText::Encrypt(cipher, reinterpret_cast(key_str.data()), value_str); + return StringVector::AddStringOrBlob(result, string_t(ct.GetValue().data(), ct.GetValue().size())); + } + else + { + throw InternalException("Invalid mode for CryptoScalarFun"); + } + }); + } + + inline void CryptoScalarEncryptFun(DataChunk &args, ExpressionState &state, Vector &result) + { + CryptoScalarFun(args, state, result, MODE_ENCRYPT); + } + + inline void CryptoScalarDecryptFun(DataChunk &args, ExpressionState &state, Vector &result) + { + CryptoScalarFun(args, state, result, MODE_DECRYPT); + } + + void LoadCipherInternal(ExtensionLoader &loader) + { + // crypto_hash accepts VARCHAR for algorithm name and ANY type for the data to hash + auto crypto_encrypt_scalar_function = ScalarFunction("crypto_encrypt", {LogicalType::VARCHAR, LogicalType::BLOB, LogicalType::ANY}, LogicalType::BLOB, CryptoScalarEncryptFun); + loader.RegisterFunction(crypto_encrypt_scalar_function); + + auto crypto_decrypt_scalar_function = ScalarFunction("crypto_decrypt", {LogicalType::VARCHAR, LogicalType::BLOB, LogicalType::ANY}, LogicalType::BLOB, CryptoScalarDecryptFun); + loader.RegisterFunction(crypto_decrypt_scalar_function); + } + +}; \ No newline at end of file diff --git a/src/crypto_extension.cpp b/src/crypto_extension.cpp index 554611f..8f6dcdf 100644 --- a/src/crypto_extension.cpp +++ b/src/crypto_extension.cpp @@ -2,6 +2,7 @@ #include "crypto_extension.hpp" #include "crypto_hash.hpp" +#include "crypto_enc.hpp" #include "duckdb.hpp" #include "duckdb/common/exception.hpp" #include "duckdb/common/string_util.hpp" @@ -22,10 +23,10 @@ namespace duckdb namespace { // RAII wrapper for EVP_MD_CTX - class EVPContext + class EVPMDContext { public: - EVPContext() + EVPMDContext() { ctx = EVP_MD_CTX_new(); if (ctx == nullptr) @@ -34,7 +35,7 @@ namespace duckdb } } - ~EVPContext() + ~EVPMDContext() { if (ctx) { @@ -42,8 +43,8 @@ namespace duckdb } } - EVPContext(const EVPContext &) = delete; - EVPContext &operator=(const EVPContext &) = delete; + EVPMDContext(const EVPMDContext &) = delete; + EVPMDContext &operator=(const EVPMDContext &) = delete; void Init(const EVP_MD *md) { @@ -102,7 +103,7 @@ namespace duckdb } // Helper to hash a single list element with EVP - void HashListElementEVP(EVPContext &evp_ctx, const LogicalType &child_type, + void HashListElementEVP(EVPMDContext &evp_ctx, const LogicalType &child_type, UnifiedVectorFormat &child_format, idx_t child_idx) { if (child_type == LogicalType::VARCHAR || child_type == LogicalType::BLOB) @@ -213,7 +214,7 @@ namespace duckdb } else { - EVPContext evp_ctx; + EVPMDContext evp_ctx; evp_ctx.Init(md); for (idx_t list_idx = 0; list_idx < list_entry.length; list_idx++) @@ -660,6 +661,8 @@ namespace duckdb RegisterHashAggType(agg_set, LogicalType::FLOAT); RegisterHashAggType(agg_set, LogicalType::DOUBLE); + loader.RegisterFunction(agg_set); + // crypto_hash_agg: Aggregate function for hashing multiple rows CreateAggregateFunctionInfo crypto_hash_agg_info(agg_set); crypto_hash_agg_info.descriptions.push_back({ @@ -673,13 +676,13 @@ namespace duckdb {"cryptography", "hash", "aggregate"} // categories }); loader.RegisterFunction(crypto_hash_agg_info); - + LoadCipherInternal(loader); QueryFarmSendTelemetry(loader, "crypto", "2025120201"); } void CryptoExtension::Load(ExtensionLoader &loader) { - LoadInternal(loader); + LoadInternal(loader); } std::string CryptoExtension::Name() { diff --git a/src/include/crypto_enc.hpp b/src/include/crypto_enc.hpp new file mode 100644 index 0000000..34bcd92 --- /dev/null +++ b/src/include/crypto_enc.hpp @@ -0,0 +1,13 @@ +#pragma once + +#include "duckdb.hpp" +#include "duckdb/common/exception.hpp" +#include +#include +#include + +namespace duckdb +{ + void LoadCipherInternal(ExtensionLoader &loader); + void throwOpensslError(const std::string &prefix); +}; diff --git a/test/sql/crypto_enc.test b/test/sql/crypto_enc.test new file mode 100644 index 0000000..fed8900 --- /dev/null +++ b/test/sql/crypto_enc.test @@ -0,0 +1,44 @@ +# name: test/sql/crypto_enc.test +# description: test crypto extension encryption functions +# group: [crypto] + +# Before we load the extension, this will fail +statement error +SELECT crypto_encrypt('aes-256-cbc', '01234567012345670123456701234567', 'blah blah blah'); +---- +Catalog Error: Scalar Function with name crypto_encrypt does not exist! + +# Require statement will ensure this test is run with this extension loaded +require crypto + +# Confirm the extension works - returns BLOBs that can be converted to hex +query I +SELECT to_hex(crypto_encrypt('aes-256-cbc', '01234567012345670123456701234567', 'blah blah blah')) != to_hex('blah blah blah'); +---- +true + +# Confirm the extension works - returns BLOBs that can be converted to hex +query I +SELECT crypto_decrypt('aes-256-gcm', '01234567012345670123456701234567', crypto_encrypt('aes-256-gcm', '01234567012345670123456701234567', 'blah blah blah')); +---- +blah blah blah + +# Confirm the extension works - returns BLOBs that can be converted to hex +query I +SELECT to_hex(crypto_encrypt('aes-256-gcm', '01234567012345670123456701234567', 'blah blah blah')) != to_hex('blah blah blah'); +---- +true + +# Error on invalid cipher +statement error +SELECT crypto_encrypt('aes-2x56-cbcs', '01234567012345670123456701234567', 'blah blah blah'); +---- +Invalid Input Error: Invalid ciphername 'aes-2x56-cbcs' + + +# Confirm the extension works - returns BLOBs that can be converted to hex +# query I +# SELECT crypto_decrypt('aes-256-gcm', '01234567012345670123456701234567', crypto_encrypt('aes-256-gcm', '01234567012345670123456701234A67', 'blah blah blah')); +# ---- +# blah blah blah +