diff --git a/libp2p/kad_dht/provider_store.py b/libp2p/kad_dht/provider_store.py index 3da4ff9a5..5d9973c7b 100644 --- a/libp2p/kad_dht/provider_store.py +++ b/libp2p/kad_dht/provider_store.py @@ -226,6 +226,8 @@ async def _send_add_provider(self, peer_id: ID, key: bytes) -> bool: True if the message was successfully sent and acknowledged """ + stream = None + try: result = False # Open a stream to the peer @@ -298,7 +300,8 @@ async def _send_add_provider(self, peer_id: ID, key: bytes) -> bool: logger.warning(f"Error sending ADD_PROVIDER to {peer_id}: {e}") finally: - await stream.close() + if stream is not None: + await stream.close() return result async def find_providers(self, key: bytes, count: int = 20) -> list[PeerInfo]: diff --git a/newsfragments/1335.bugfix.rst b/newsfragments/1335.bugfix.rst new file mode 100644 index 000000000..f6f0fdd11 --- /dev/null +++ b/newsfragments/1335.bugfix.rst @@ -0,0 +1 @@ +Fixed Kademlia DHT provider advertisement so a failure to open the ADD_PROVIDER stream no longer raises a secondary ``UnboundLocalError``; the underlying connection error is logged instead. diff --git a/tests/core/kad_dht/test_unit_provider_store.py b/tests/core/kad_dht/test_unit_provider_store.py index 560c56e5e..abe34a84d 100644 --- a/tests/core/kad_dht/test_unit_provider_store.py +++ b/tests/core/kad_dht/test_unit_provider_store.py @@ -564,6 +564,23 @@ async def test_provide_skip_local_peer(self): # Should only call _send_add_provider once (skip local peer) assert mock_send.call_count == 1 + @pytest.mark.trio + async def test_send_add_provider_new_stream_failure(self): + """Test _send_add_provider when new_stream fails before stream is opened.""" + mock_host = Mock() + local_peer_id = ID.from_base58("QmTest123") + remote_peer_id = ID.from_base58("QmPeer1") + + mock_host.get_id.return_value = local_peer_id + mock_host.get_addrs.return_value = [Multiaddr("/ip4/127.0.0.1/tcp/8000")] + mock_host.new_stream = AsyncMock(side_effect=Exception("Stream failed")) + + store = ProviderStore(host=mock_host) + result = await store._send_add_provider(remote_peer_id, b"test_key") + + assert result is False + mock_host.new_stream.assert_awaited_once() + @pytest.mark.trio async def test_find_providers_no_host(self): """Test find_providers() returns empty list when no host."""