diff --git a/lib/cpp/README.md b/lib/cpp/README.md index 74983aeb1f..a9959af8c2 100644 --- a/lib/cpp/README.md +++ b/lib/cpp/README.md @@ -135,6 +135,11 @@ TSimpleServer, TThreadedServer, and TThreadPoolServer. There are two main classes TSSLSocketFactory and TSSLSocket. Instances of TSSLSocket are always created from TSSLSocketFactory. +The default TSSLSocketFactory context uses OpenSSL's version-flexible TLS +method and sets TLS 1.2 as the minimum negotiated protocol version. Applications +that need a different protocol range can provide a custom SSLContext factory and +adjust the OpenSSL context options before creating sockets. + ## How to use SSL APIs See the TestClient.cpp and TestServer.cpp files for examples. @@ -319,4 +324,3 @@ assertion or a core instead of undefined behavior. The lifetime of a TSSLSocket up too early. If the static boolean is set to disable openssl initialization and cleanup and leave it up to the consuming application, this requirement is not needed. (THRIFT-4164) - diff --git a/lib/cpp/src/thrift/transport/TSSLSocket.cpp b/lib/cpp/src/thrift/transport/TSSLSocket.cpp index f97962ea5b..ac88acd926 100644 --- a/lib/cpp/src/thrift/transport/TSSLSocket.cpp +++ b/lib/cpp/src/thrift/transport/TSSLSocket.cpp @@ -204,11 +204,12 @@ SSLContext::SSLContext(const SSLProtocol& protocol) { } SSL_CTX_set_mode(ctx_, SSL_MODE_AUTO_RETRY); - // Disable horribly insecure SSLv2 and SSLv3 protocols but allow a handshake - // with older clients so they get a graceful denial. + // Keep version-flexible negotiation for current protocol versions while setting + // the default protocol floor at TLSv1.2. if (protocol == SSLTLS) { - SSL_CTX_set_options(ctx_, SSL_OP_NO_SSLv2); - SSL_CTX_set_options(ctx_, SSL_OP_NO_SSLv3); // THRIFT-3164 + SSL_CTX_set_options(ctx_, + SSL_OP_NO_SSLv2 | SSL_OP_NO_SSLv3 | SSL_OP_NO_TLSv1 + | SSL_OP_NO_TLSv1_1); } } @@ -884,6 +885,33 @@ bool TSSLSocketFactory::manualOpenSSLInitialization_ = false; bool TSSLSocketFactory::didWeInitializeOpenSSL_ = false; TSSLSocketFactory::TSSLSocketFactory(SSLProtocol protocol) : server_(false) { + initializeOpenSSLState(); + try { + ctx_ = std::make_shared(protocol); + } catch (...) { + cleanupOpenSSLState(); + throw; + } +} + +TSSLSocketFactory::TSSLSocketFactory(const SSLContextFactory& contextFactory) : server_(false) { + if (!contextFactory) { + throw TSSLException("SSLContextFactory must not be empty"); + } + initializeOpenSSLState(); + try { + std::shared_ptr ctx = contextFactory(); + if (ctx == nullptr) { + throw TSSLException("SSLContextFactory must not return null"); + } + ctx_ = ctx; + } catch (...) { + cleanupOpenSSLState(); + throw; + } +} + +void TSSLSocketFactory::initializeOpenSSLState() { Guard guard(mutex_); if (count_ == 0) { if (!manualOpenSSLInitialization_) { @@ -893,13 +921,18 @@ TSSLSocketFactory::TSSLSocketFactory(SSLProtocol protocol) : server_(false) { randomize(); } count_++; - ctx_ = std::make_shared(protocol); } TSSLSocketFactory::~TSSLSocketFactory() { + cleanupOpenSSLState(); +} + +void TSSLSocketFactory::cleanupOpenSSLState() { Guard guard(mutex_); ctx_.reset(); - count_--; + if (count_ > 0) { + count_--; + } if (count_ == 0 && didWeInitializeOpenSSL_) { cleanupOpenSSL(); didWeInitializeOpenSSL_ = false; diff --git a/lib/cpp/src/thrift/transport/TSSLSocket.h b/lib/cpp/src/thrift/transport/TSSLSocket.h index 45502bd5ec..e6b8992451 100644 --- a/lib/cpp/src/thrift/transport/TSSLSocket.h +++ b/lib/cpp/src/thrift/transport/TSSLSocket.h @@ -23,6 +23,7 @@ // Put this first to avoid WIN32 build failure #include +#include #include #include #include @@ -33,9 +34,10 @@ namespace transport { class AccessManager; class SSLContext; +typedef std::function()> SSLContextFactory; enum SSLProtocol { - SSLTLS = 0, // Supports SSLv2 and SSLv3 handshake but only negotiates at TLSv1_0 or later. + SSLTLS = 0, // Supports version-flexible TLS negotiation with TLSv1_2 as the default floor. //SSLv2 = 1, // HORRIBLY INSECURE! SSLv3 = 2, // Supports SSLv3 only - also horribly insecure! TLSv1_0 = 3, // Supports TLSv1_0 or later. @@ -210,6 +212,12 @@ class TSSLSocketFactory { * @param protocol The SSL/TLS protocol to use. */ TSSLSocketFactory(SSLProtocol protocol = SSLTLS); + /** + * Constructor + * + * @param contextFactory Function invoked during construction to return a custom OpenSSL context. + */ + TSSLSocketFactory(const SSLContextFactory& contextFactory); virtual ~TSSLSocketFactory(); /** * Create an instance of TSSLSocket with a fresh new socket. @@ -328,6 +336,8 @@ class TSSLSocketFactory { static uint64_t count_; static bool manualOpenSSLInitialization_; static bool didWeInitializeOpenSSL_; // in that case we also perform de-init + void initializeOpenSSLState(); + void cleanupOpenSSLState(); void setup(std::shared_ptr ssl); static int passwordCallback(char* password, int size, int, void* data); }; diff --git a/lib/cpp/test/SecurityFromBufferTest.cpp b/lib/cpp/test/SecurityFromBufferTest.cpp index 32f2378d38..08f76b3f25 100644 --- a/lib/cpp/test/SecurityFromBufferTest.cpp +++ b/lib/cpp/test/SecurityFromBufferTest.cpp @@ -214,11 +214,11 @@ BOOST_AUTO_TEST_CASE(ssl_security_matrix) { { // server = SSLTLS SSLv2 SSLv3 TLSv1_0 TLSv1_1 TLSv1_2 // client - /* SSLTLS */ { true, false, false, true, true, true }, + /* SSLTLS */ { true, false, false, false, false, true }, /* SSLv2 */ { false, false, false, false, false, false }, /* SSLv3 */ { false, false, true, false, false, false }, - /* TLSv1_0 */ { true, false, false, true, false, false }, - /* TLSv1_1 */ { true, false, false, false, true, false }, + /* TLSv1_0 */ { false, false, false, true, false, false }, + /* TLSv1_1 */ { false, false, false, false, true, false }, /* TLSv1_2 */ { true, false, false, false, false, true } }; diff --git a/lib/cpp/test/SecurityTest.cpp b/lib/cpp/test/SecurityTest.cpp index cc71f04793..d64b0da844 100644 --- a/lib/cpp/test/SecurityTest.cpp +++ b/lib/cpp/test/SecurityTest.cpp @@ -34,6 +34,8 @@ #endif using apache::thrift::transport::TSSLServerSocket; +using apache::thrift::transport::SSLContextFactory; +using apache::thrift::transport::TSSLException; using apache::thrift::transport::TServerTransport; using apache::thrift::transport::TSSLSocket; using apache::thrift::transport::TSSLSocketFactory; @@ -226,6 +228,78 @@ struct SecurityFixture BOOST_FIXTURE_TEST_SUITE(BOOST_TEST_MODULE, SecurityFixture) +BOOST_AUTO_TEST_CASE(default_ssl_context_options) +{ + apache::thrift::transport::SSLContext context; + const auto options = SSL_CTX_get_options(context.get()); + + if (SSL_OP_NO_SSLv2 != 0) { + BOOST_CHECK((options & SSL_OP_NO_SSLv2) != 0); + } + if (SSL_OP_NO_SSLv3 != 0) { + BOOST_CHECK((options & SSL_OP_NO_SSLv3) != 0); + } + if (SSL_OP_NO_TLSv1 != 0) { + BOOST_CHECK((options & SSL_OP_NO_TLSv1) != 0); + } + if (SSL_OP_NO_TLSv1_1 != 0) { + BOOST_CHECK((options & SSL_OP_NO_TLSv1_1) != 0); + } +} + +BOOST_AUTO_TEST_CASE(custom_ssl_context_options) +{ + class CustomSSLContext : public apache::thrift::transport::SSLContext + { + public: + CustomSSLContext() : SSLContext() + { + SSL_CTX_clear_options(get(), SSL_OP_NO_TLSv1_1); + } + }; + + std::shared_ptr context; + TSSLSocketFactory factory([&context]() { + context = std::make_shared(); + return context; + }); + const auto options = SSL_CTX_get_options(context->get()); + + if (SSL_OP_NO_TLSv1 != 0) { + BOOST_CHECK((options & SSL_OP_NO_TLSv1) != 0); + } + if (SSL_OP_NO_TLSv1_1 != 0) { + BOOST_CHECK((options & SSL_OP_NO_TLSv1_1) == 0); + } + context.reset(); +} + +BOOST_AUTO_TEST_CASE(custom_ssl_context_factory_validation) +{ + try + { + SSLContextFactory contextFactory; + TSSLSocketFactory factory(contextFactory); + BOOST_FAIL("Expected empty SSLContextFactory to throw"); + } + catch (const TSSLException& ex) + { + BOOST_CHECK_EQUAL("SSLContextFactory must not be empty", std::string(ex.what())); + } + + try + { + TSSLSocketFactory factory([]() { + return std::shared_ptr(); + }); + BOOST_FAIL("Expected null SSLContextFactory result to throw"); + } + catch (const TSSLException& ex) + { + BOOST_CHECK_EQUAL("SSLContextFactory must not return null", std::string(ex.what())); + } +} + BOOST_AUTO_TEST_CASE(ssl_security_matrix) { try @@ -236,11 +310,11 @@ BOOST_AUTO_TEST_CASE(ssl_security_matrix) { // server = SSLTLS SSLv2 SSLv3 TLSv1_0 TLSv1_1 TLSv1_2 // client - /* SSLTLS */ { true, false, false, true, true, true }, + /* SSLTLS */ { true, false, false, false, false, true }, /* SSLv2 */ { false, false, false, false, false, false }, /* SSLv3 */ { false, false, true, false, false, false }, - /* TLSv1_0 */ { true, false, false, true, false, false }, - /* TLSv1_1 */ { true, false, false, false, true, false }, + /* TLSv1_0 */ { false, false, false, true, false, false }, + /* TLSv1_1 */ { false, false, false, false, true, false }, /* TLSv1_2 */ { true, false, false, false, false, true } };