diff --git a/src/llmq/signing_shares.cpp b/src/llmq/signing_shares.cpp index f337d63ff19e..4849e2940ee9 100644 --- a/src/llmq/signing_shares.cpp +++ b/src/llmq/signing_shares.cpp @@ -22,10 +22,21 @@ #include +#include #include namespace llmq { +namespace { +constexpr size_t MAX_SESSIONS_PER_PEER_FACTOR{4}; +constexpr size_t MIN_SESSIONS_PER_PEER{100}; + +size_t GetMaxSessionsForPeer(const Consensus::LLMQParams& params) +{ + return std::max(size_t(params.size) * MAX_SESSIONS_PER_PEER_FACTOR, MIN_SESSIONS_PER_PEER); +} +} // namespace + void CSigShare::UpdateKey() { key.first = this->buildSignHash().Get(); @@ -130,9 +141,32 @@ CSigSharesNodeState::Session& CSigSharesNodeState::GetOrCreateSessionFromAnn(con if (s.announced.inv.empty()) { InitSession(s, signHash, ann); } + s.receivedAnnouncement = true; return s; } +bool CSigSharesNodeState::CanCreateSessionFromAnn(const llmq::CSigSesAnn& ann, size_t maxSessions) const +{ + return sessions.count(ann.buildSignHash().Get()) != 0 || GetAnnouncementSessionCount(ann.getLlmqType()) < maxSessions; +} + +size_t CSigSharesNodeState::GetSessionCount() const +{ + return sessions.size(); +} + +size_t CSigSharesNodeState::GetSessionCount(Consensus::LLMQType llmqType) const +{ + return std::ranges::count_if(sessions, [&](const auto& kv) { return kv.second.llmqType == llmqType; }); +} + +size_t CSigSharesNodeState::GetAnnouncementSessionCount(Consensus::LLMQType llmqType) const +{ + return std::ranges::count_if(sessions, [&](const auto& kv) { + return kv.second.receivedAnnouncement && kv.second.llmqType == llmqType; + }); +} + CSigSharesNodeState::Session* CSigSharesNodeState::GetSessionBySignHash(const uint256& signHash) { auto it = sessions.find(signHash); @@ -206,7 +240,8 @@ void CSigSharesManager::UnregisterRecoveryInterface() bool CSigSharesManager::ProcessMessageSigSesAnn(const CNode& pfrom, const CSigSesAnn& ann) { auto llmqType = ann.getLlmqType(); - if (!Params().GetLLMQ(llmqType).has_value()) { + const auto& llmq_params_opt = Params().GetLLMQ(llmqType); + if (!llmq_params_opt.has_value()) { return false; } if (ann.getSessionId() == UNINITIALIZED_SESSION_ID || ann.getQuorumHash().IsNull() || ann.getId().IsNull() || ann.getMsgHash().IsNull()) { @@ -225,7 +260,15 @@ bool CSigSharesManager::ProcessMessageSigSesAnn(const CNode& pfrom, const CSigSe LOCK(cs); auto& nodeState = nodeStates[pfrom.GetId()]; + const size_t maxSessions = GetMaxSessionsForPeer(*llmq_params_opt); + if (!nodeState.CanCreateSessionFromAnn(ann, maxSessions)) { + LogPrint(BCLog::LLMQ_SIGS, "CSigSharesManager::%s -- too many sessions. cnt=%d, max=%d, llmqType=%d, node=%d\n", + __func__, nodeState.GetAnnouncementSessionCount(llmqType), maxSessions, static_cast(llmqType), pfrom.GetId()); + return true; + } + const uint256 signHash = ann.buildSignHash().Get(); auto& session = nodeState.GetOrCreateSessionFromAnn(ann); + timeSeenForSessions.try_emplace(signHash, GetTime().count()); nodeState.sessionByRecvId.erase(session.recvSessionId); nodeState.sessionByRecvId.erase(ann.getSessionId()); session.recvSessionId = ann.getSessionId(); @@ -1247,6 +1290,11 @@ void CSigSharesManager::Cleanup() doneSessions.emplace(sigShare.GetSignHash()); } }); + for (const auto& [signHash, _] : timeSeenForSessions) { + if (doneSessions.count(signHash) == 0 && sigman.HasRecoveredSigForSession(signHash)) { + doneSessions.emplace(signHash); + } + } for (const auto& signHash : doneSessions) { RemoveSigSharesForSession(signHash); } diff --git a/src/llmq/signing_shares.h b/src/llmq/signing_shares.h index 82c1e88de585..da2f39043c1b 100644 --- a/src/llmq/signing_shares.h +++ b/src/llmq/signing_shares.h @@ -326,8 +326,9 @@ class CSigSharesNodeState CSigSharesInv announced; CSigSharesInv requested; CSigSharesInv knows; + + bool receivedAnnouncement{false}; }; - // TODO limit number of sessions per node Uint256HashMap sessions; std::unordered_map sessionByRecvId; @@ -339,6 +340,10 @@ class CSigSharesNodeState Session& GetOrCreateSessionFromShare(const CSigShare& sigShare); Session& GetOrCreateSessionFromAnn(const CSigSesAnn& ann); + [[nodiscard]] bool CanCreateSessionFromAnn(const CSigSesAnn& ann, size_t maxSessions) const; + [[nodiscard]] size_t GetSessionCount() const; + [[nodiscard]] size_t GetSessionCount(Consensus::LLMQType llmqType) const; + [[nodiscard]] size_t GetAnnouncementSessionCount(Consensus::LLMQType llmqType) const; Session* GetSessionBySignHash(const uint256& signHash); Session* GetSessionByRecvId(uint32_t sessionId); bool GetSessionInfoByRecvId(uint32_t sessionId, SessionInfo& retInfo); diff --git a/src/test/llmq_utils_tests.cpp b/src/test/llmq_utils_tests.cpp index da67f4a5189a..ea51601a8b5a 100644 --- a/src/test/llmq_utils_tests.cpp +++ b/src/test/llmq_utils_tests.cpp @@ -7,6 +7,7 @@ #include #include +#include #include #include #include @@ -23,6 +24,85 @@ BOOST_FIXTURE_TEST_SUITE(llmq_utils_tests, BasicTestingSetup) BOOST_AUTO_TEST_CASE(trivially_passes) { BOOST_CHECK(true); } +static CSigSesAnn MakeSigSesAnn(uint32_t session_id, uint32_t nonce, Consensus::LLMQType llmq_type = Consensus::LLMQType::LLMQ_50_60) +{ + return CSigSesAnn{session_id, llmq_type, GetTestQuorumHash(1), GetTestQuorumHash(2), GetTestQuorumHash(nonce)}; +} + +static CSigShare MakeSigShare(uint32_t nonce, Consensus::LLMQType llmq_type = Consensus::LLMQType::LLMQ_50_60) +{ + CSigShare sig_share{llmq_type, GetTestQuorumHash(1), GetTestQuorumHash(2), GetTestQuorumHash(nonce), 1, CBLSLazySignature{}}; + sig_share.UpdateKey(); + return sig_share; +} + +BOOST_AUTO_TEST_CASE(sig_ses_ann_respects_session_limit_but_allows_refresh) +{ + CSigSharesNodeState node_state; + + const CSigSesAnn ann1{MakeSigSesAnn(1, 1)}; + const CSigSesAnn ann2{MakeSigSesAnn(2, 2)}; + const CSigSesAnn ann3{MakeSigSesAnn(3, 3)}; + constexpr size_t max_sessions{2}; + + BOOST_CHECK(node_state.CanCreateSessionFromAnn(ann1, max_sessions)); + node_state.GetOrCreateSessionFromAnn(ann1); + BOOST_CHECK_EQUAL(node_state.GetSessionCount(), 1U); + BOOST_CHECK_EQUAL(node_state.GetAnnouncementSessionCount(Consensus::LLMQType::LLMQ_50_60), 1U); + + BOOST_CHECK(node_state.CanCreateSessionFromAnn(ann2, max_sessions)); + node_state.GetOrCreateSessionFromAnn(ann2); + BOOST_CHECK_EQUAL(node_state.GetSessionCount(), max_sessions); + BOOST_CHECK_EQUAL(node_state.GetAnnouncementSessionCount(Consensus::LLMQType::LLMQ_50_60), max_sessions); + + BOOST_CHECK(!node_state.CanCreateSessionFromAnn(ann3, max_sessions)); + + const CSigSesAnn ann1_refresh{4, Consensus::LLMQType::LLMQ_50_60, ann1.getQuorumHash(), ann1.getId(), ann1.getMsgHash()}; + BOOST_CHECK(node_state.CanCreateSessionFromAnn(ann1_refresh, max_sessions)); + node_state.GetOrCreateSessionFromAnn(ann1_refresh); + BOOST_CHECK_EQUAL(node_state.GetSessionCount(), max_sessions); + BOOST_CHECK_EQUAL(node_state.GetAnnouncementSessionCount(Consensus::LLMQType::LLMQ_50_60), max_sessions); +} + +BOOST_AUTO_TEST_CASE(sig_ses_ann_limit_ignores_send_only_sessions) +{ + CSigSharesNodeState node_state; + + constexpr size_t max_sessions{1}; + const CSigShare sig_share{MakeSigShare(1)}; + const CSigSesAnn ann{MakeSigSesAnn(1, 2)}; + + node_state.GetOrCreateSessionFromShare(sig_share); + BOOST_CHECK_EQUAL(node_state.GetSessionCount(Consensus::LLMQType::LLMQ_50_60), 1U); + BOOST_CHECK_EQUAL(node_state.GetAnnouncementSessionCount(Consensus::LLMQType::LLMQ_50_60), 0U); + + BOOST_CHECK(node_state.CanCreateSessionFromAnn(ann, max_sessions)); + node_state.GetOrCreateSessionFromAnn(ann); + BOOST_CHECK_EQUAL(node_state.GetSessionCount(Consensus::LLMQType::LLMQ_50_60), 2U); + BOOST_CHECK_EQUAL(node_state.GetAnnouncementSessionCount(Consensus::LLMQType::LLMQ_50_60), 1U); +} + +BOOST_AUTO_TEST_CASE(sig_ses_ann_limit_is_per_llmq_type) +{ + CSigSharesNodeState node_state; + + constexpr size_t max_sessions{1}; + const CSigSesAnn ann1{MakeSigSesAnn(1, 1)}; + const CSigSesAnn ann2{MakeSigSesAnn(2, 2)}; + const CSigSesAnn other_type_ann{MakeSigSesAnn(3, 3, Consensus::LLMQType::LLMQ_400_60)}; + + BOOST_CHECK(node_state.CanCreateSessionFromAnn(ann1, max_sessions)); + node_state.GetOrCreateSessionFromAnn(ann1); + BOOST_CHECK_EQUAL(node_state.GetSessionCount(), 1U); + BOOST_CHECK_EQUAL(node_state.GetSessionCount(Consensus::LLMQType::LLMQ_50_60), 1U); + + BOOST_CHECK(!node_state.CanCreateSessionFromAnn(ann2, max_sessions)); + BOOST_CHECK(node_state.CanCreateSessionFromAnn(other_type_ann, max_sessions)); + node_state.GetOrCreateSessionFromAnn(other_type_ann); + BOOST_CHECK_EQUAL(node_state.GetSessionCount(), 2U); + BOOST_CHECK_EQUAL(node_state.GetSessionCount(Consensus::LLMQType::LLMQ_400_60), 1U); +} + BOOST_AUTO_TEST_CASE(deterministic_outbound_connection_test) { // Test deterministic behavior