diff --git a/examples/bitswap/bitswap.py b/examples/bitswap/bitswap.py index 1a9c31cac..24437a772 100755 --- a/examples/bitswap/bitswap.py +++ b/examples/bitswap/bitswap.py @@ -76,9 +76,8 @@ async def run_provider(file_path: str, port: int = 0): # Create host host = new_host() - async with host.run(listen_addrs=listen_addrs): - peer_id = host.get_id() - logger.info(f"Peer ID: {peer_id}") + async with host.run(listen_addrs=listen_addrs), trio.open_nursery() as nursery: + logger.info(f"Peer ID: {host.get_id()}") # Get actual listening addresses addrs = host.get_addrs() @@ -91,7 +90,8 @@ async def run_provider(file_path: str, port: int = 0): await bitswap.start() logger.info("✓ Bitswap started") - # Create Merkle DAG + # Set nursery so bitswap can spawn background tasks + bitswap.set_nursery(nursery) dag = MerkleDag(bitswap) logger.info("") @@ -198,13 +198,14 @@ async def run_client( # Create host host = new_host() - async with host.run(listen_addrs=listen_addrs): + async with host.run(listen_addrs=listen_addrs), trio.open_nursery() as nursery: logger.info(f"Client Peer ID: {host.get_id()}") # Start Bitswap bitswap = BitswapClient(host) await bitswap.start() logger.info("✓ Bitswap started") + bitswap.set_nursery(nursery) try: # Connect to provider @@ -214,7 +215,6 @@ async def run_client( await host.connect(peer_info) logger.info("✓ Connected") - # Create Merkle DAG dag = MerkleDag(bitswap) logger.info("") @@ -232,7 +232,7 @@ def progress_callback(current: int, total: int, status: str): # Fetch file with automatic filename extraction try: file_data, filename = await dag.fetch_file( - root_cid, progress_callback=progress_callback + root_cid, progress_callback=progress_callback, timeout=120.0 ) # Show fetch statistics @@ -284,18 +284,18 @@ def progress_callback(current: int, total: int, status: str): logger.info("=" * 70) logger.info(f"Size: {format_size(len(file_data))}") - # Determine output filename + # Determine output filename (priority: metadata > generated) if filename: - output_filename = filename - logger.info(f"Filename: {filename} (from metadata)") + final_filename = filename + logger.info(f"Filename: {final_filename} (from metadata)") else: - output_filename = ( + final_filename = ( f"file_{format_cid_for_display(root_cid, max_len=16)}.bin" ) - logger.info(f"Filename: {output_filename} (no metadata)") + logger.info(f"Filename: {final_filename} (generated from CID)") # Handle filename conflicts - output_file = output_path / output_filename + output_file = output_path / final_filename if output_file.exists(): stem = output_file.stem suffix = output_file.suffix @@ -315,7 +315,9 @@ def progress_callback(current: int, total: int, status: str): except Exception as e: logger.error(f"Failed: {e}") logger.exception("Full traceback:") + raise finally: + pass # Nursery will cleanup background tasks await bitswap.stop() diff --git a/libp2p/bitswap/__init__.py b/libp2p/bitswap/__init__.py index 756ad5793..dcad9d1aa 100644 --- a/libp2p/bitswap/__init__.py +++ b/libp2p/bitswap/__init__.py @@ -31,7 +31,8 @@ New code should prefer the object-returning variants above. 
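The example changes above establish an ordering worth calling out: the trio nursery is opened together with host.run(), and it is handed to Bitswap via set_nursery() only after start(), so background tasks stay supervised for the whole session. Below is a condensed sketch of that pattern (listen-address handling and logging trimmed; the serve_file name and the MerkleDag import path are illustrative assumptions, not part of this patch):

import trio
from multiaddr import Multiaddr

from libp2p import new_host
from libp2p.bitswap import BitswapClient
from libp2p.bitswap.dag import MerkleDag  # import path assumed from this diff


async def serve_file(file_path: str) -> None:
    host = new_host()
    listen = [Multiaddr("/ip4/0.0.0.0/tcp/0")]
    # Keep the nursery open for the whole session so tasks spawned via
    # set_nursery() remain supervised until shutdown.
    async with host.run(listen_addrs=listen), trio.open_nursery() as nursery:
        bitswap = BitswapClient(host)
        await bitswap.start()
        bitswap.set_nursery(nursery)
        try:
            dag = MerkleDag(bitswap)
            root_cid = await dag.add_file(file_path)
            print(f"Root CID: {root_cid.hex()}, serving until interrupted")
            await trio.sleep_forever()
        finally:
            await bitswap.stop()
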
""" -from .block_store import BlockStore, MemoryBlockStore +from .block_service import BlockService +from .block_store import BlockStore, FilesystemBlockStore, MemoryBlockStore from .cid import ( CID_V0, CID_V1, @@ -65,12 +66,29 @@ MessageTooLargeError, TimeoutError, ) +from .wantlist import ( + BitswapMessage, + BlockPresence, + BlockPresenceType, + Wantlist, + WantlistEntry, + WantType, +) __all__ = [ # Core "BitswapClient", + "BlockService", "BlockStore", "MemoryBlockStore", + "FilesystemBlockStore", + # Messages + "BitswapMessage", + "BlockPresence", + "BlockPresenceType", + "Wantlist", + "WantlistEntry", + "WantType", # CID types "CIDInput", "CIDObject", diff --git a/libp2p/bitswap/block_service.py b/libp2p/bitswap/block_service.py new file mode 100644 index 000000000..c4e452d9a --- /dev/null +++ b/libp2p/bitswap/block_service.py @@ -0,0 +1,196 @@ +""" +BlockService: transparent local→network fallback for block retrieval. + +Sits between MerkleDag and BitswapClient, providing: + - Local-first lookup (no network cost if block is already stored) + - Automatic caching of network-fetched blocks into the local store + - Peer announcement when new blocks are stored locally + - A clean abstraction so MerkleDag is not hardwired to BitswapClient +""" + +from __future__ import annotations + +from collections.abc import Sequence +import logging +from typing import TYPE_CHECKING + +from .block_store import BlockStore +from .cid import CIDInput, cid_to_bytes, format_cid_for_display, parse_cid + +if TYPE_CHECKING: + from libp2p.peer.id import ID as PeerID + + from .client import BitswapClient + +logger = logging.getLogger(__name__) + + +class BlockService: + """ + Combines a local BlockStore with a BitswapClient into one unified interface. + + get_block() flow: + 1. Check local BlockStore → return immediately if found (no network) + 2. Fetch via BitswapClient → goes to the network + 3. Auto-cache the result → store locally so next call is free + + put_block() flow: + 1. Write to local BlockStore + 2. Call bitswap.add_block() so peers who have this CID in their + wantlist are notified and can receive it + + This is a drop-in wrapper: MerkleDag can use BlockService instead of + calling bitswap directly, and the behaviour is identical but with the + caching and announcement benefits added transparently. + + Example: + >>> store = FilesystemBlockStore("./blocks") + >>> service = BlockService(store, bitswap) + >>> dag = MerkleDag(bitswap, block_service=service) + + """ + + def __init__(self, store: BlockStore, bitswap: BitswapClient) -> None: + self.store = store + self.bitswap = bitswap + + async def get_block( + self, + cid: CIDInput, + peer_id: PeerID | None = None, + timeout: float = 30.0, + ) -> bytes | None: + """ + Get a block. Checks local store first, then fetches from network. + Any block fetched from the network is automatically cached locally. + + Args: + cid: The CID of the block to retrieve + peer_id: Optional specific peer to fetch from (passed to bitswap) + timeout: Network timeout in seconds + + Returns: + Block data bytes, or None if not found anywhere + + """ + cid_bytes = cid_to_bytes(cid) + cid_obj = parse_cid(cid_bytes) + + # 1. Local lookup — instant, no network cost + data = await self.store.get_block(cid_obj) + if data is not None: + logger.debug( + f"BlockService: local hit {format_cid_for_display(cid_obj, max_len=12)}" + ) + return data + + # 2. 
Network fetch via Bitswap + logger.debug( + f"BlockService: local miss, fetching from network " + f"{format_cid_for_display(cid_obj, max_len=12)}" + ) + try: + data = await self.bitswap.get_block(cid_bytes, peer_id, timeout) + except Exception as e: + logger.warning(f"BlockService: network fetch failed: {e}") + return None + + if data is not None: + # 3. Auto-cache locally — future requests for this block are free + await self.store.put_block(cid_obj, data) + logger.debug( + f"BlockService: cached fetched block " + f"{format_cid_for_display(cid_obj, max_len=12)}" + ) + + return data + + async def put_block(self, cid: CIDInput, data: bytes) -> None: + """ + Store a block locally and announce it to waiting peers. + + Calling bitswap.add_block() both writes to bitswap's own store AND + notifies any peers who have this CID in their pending wantlist. + We also write to our own store so get_block() local-hits on it. + + Args: + cid: The CID of the block + data: The block data bytes + + """ + cid_obj = parse_cid(cid_to_bytes(cid)) + + # Write to our local store + await self.store.put_block(cid_obj, data) + + # add_block() writes to bitswap's internal store AND calls + # _notify_peers_about_block() for any peers waiting on this CID + await self.bitswap.add_block(cid_obj, data) + + logger.debug( + f"BlockService: stored and announced " + f"{format_cid_for_display(cid_obj, max_len=12)}" + ) + + async def get_blocks_batch( + self, + cids: Sequence[CIDInput], + peer_id: PeerID | None = None, + timeout: float = 30.0, + batch_size: int = 32, + ) -> dict[bytes, bytes]: + """ + Batch-fetch multiple blocks. Local hits are returned immediately; + only missing blocks go to the network. All network-fetched blocks + are auto-cached locally. + + Args: + cids: List of CIDs to fetch + peer_id: Optional specific peer to fetch from + timeout: Network timeout in seconds + batch_size: Wantlist batch size passed to bitswap + + Returns: + Dict mapping cid_bytes -> block_data for all found blocks + + """ + results: dict[bytes, bytes] = {} + missing_cids: list[CIDInput] = [] + + # Local pass first + for cid in cids: + cid_bytes = cid_to_bytes(cid) + cid_obj = parse_cid(cid_bytes) + data = await self.store.get_block(cid_obj) + if data is not None: + results[cid_bytes] = data + else: + missing_cids.append(cid) + + if not missing_cids: + logger.debug(f"BlockService.get_blocks_batch: all {len(cids)} blocks local") + return results + + local_hits = len(cids) - len(missing_cids) + logger.debug( + f"BlockService.get_blocks_batch: {local_hits} local hits, " + f"{len(missing_cids)} fetching from network" + ) + + # Network pass for missing blocks + network_results = await self.bitswap.get_blocks_batch( + missing_cids, peer_id=peer_id, timeout=timeout, batch_size=batch_size + ) + + # Auto-cache all network-fetched blocks + for cid_bytes, data in network_results.items(): + cid_obj = parse_cid(cid_bytes) + await self.store.put_block(cid_obj, data) + results[cid_bytes] = data + + return results + + @property + def block_store(self) -> BlockStore: + """Expose the underlying BlockStore (used by MerkleDag internals).""" + return self.store diff --git a/libp2p/bitswap/block_store.py b/libp2p/bitswap/block_store.py index 12eee5aab..bc36269ce 100644 --- a/libp2p/bitswap/block_store.py +++ b/libp2p/bitswap/block_store.py @@ -3,6 +3,9 @@ """ from abc import ABC, abstractmethod +from pathlib import Path + +import trio from .cid import CIDInput, CIDObject, parse_cid @@ -118,3 +121,99 @@ def get_all_cids(self) -> list[bytes]: def size(self) -> 
int:
        """Get the number of blocks in the store."""
        return len(self._blocks)
+
+
+class FilesystemBlockStore(BlockStore):
+    """
+    Filesystem-based block store. Persists blocks to disk as files.
+
+    Each block is stored as a file at:
+        <base_path>/<cid_str[:2]>/<cid_str[2:]>
+
+    This two-level directory structure avoids having too many files in a
+    single directory and matches the layout used by py-ipfs-lite.
+
+    Args:
+        base_path: Root directory for block storage. Created if it doesn't exist.
+
+    Example:
+        >>> store = FilesystemBlockStore("/var/lib/myapp/blocks")
+        >>> bitswap = BitswapClient(host, store)
+        >>> # Blocks now survive process restarts!
+
+        >>> # Drop-in replacement for MemoryBlockStore:
+        >>> # store = MemoryBlockStore()  # before
+        >>> store = FilesystemBlockStore("./blocks")  # after — persistent
+
+    """
+
+    def __init__(self, base_path: str | Path) -> None:
+        """Initialize the filesystem block store."""
+        self._path = Path(base_path)
+        self._path.mkdir(parents=True, exist_ok=True)
+
+    def _cid_to_path(self, cid: CIDInput) -> Path:
+        """Convert a CID to a filesystem path using 2-char prefix directories."""
+        cid_str = str(parse_cid(cid))
+        # e.g. "bafybeiabc..." → <base_path>/ba/fybeiabc...
+        return self._path / cid_str[:2] / cid_str[2:]
+
+    async def get_block(self, cid: CIDInput) -> bytes | None:
+        """Get a block by CID. Returns None if not found on disk."""
+        path = self._cid_to_path(cid)
+        if not path.exists():
+            return None
+        return await trio.to_thread.run_sync(path.read_bytes)
+
+    async def put_block(self, cid: CIDInput, data: bytes) -> None:
+        """Write a block to disk."""
+        path = self._cid_to_path(cid)
+        await trio.to_thread.run_sync(
+            lambda: path.parent.mkdir(parents=True, exist_ok=True)
+        )
+        await trio.to_thread.run_sync(path.write_bytes, data)
+
+    async def has_block(self, cid: CIDInput) -> bool:
+        """Check if a block file exists on disk."""
+        return self._cid_to_path(cid).exists()
+
+    async def delete_block(self, cid: CIDInput) -> None:
+        """Delete a block file from disk."""
+        path = self._cid_to_path(cid)
+        if path.exists():
+            await trio.to_thread.run_sync(path.unlink)
+
+    def get_all_cids(self) -> list[bytes]:
+        """Return all stored CIDs as bytes by scanning the directory tree."""
+        cids: list[bytes] = []
+        if not self._path.exists():
+            return cids
+        for subdir in self._path.iterdir():
+            if not subdir.is_dir():
+                continue
+            for entry in subdir.iterdir():
+                if not entry.is_file():
+                    continue
+                cid_str = subdir.name + entry.name
+                try:
+                    cid_obj = parse_cid(cid_str)
+                    cids.append(cid_obj.buffer)
+                except Exception:
+                    pass  # skip files that aren't valid CIDs
+        return cids
+
+    def size(self) -> int:
+        """Return the number of stored blocks."""
+        if not self._path.exists():
+            return 0
+        return sum(
+            1
+            for d in self._path.iterdir()
+            if d.is_dir()
+            for f in d.iterdir()
+            if f.is_file()
+        )
+
+    def base_path(self) -> Path:
+        """Return the root directory where blocks are stored."""
+        return self._path
diff --git a/libp2p/bitswap/chunker.py b/libp2p/bitswap/chunker.py
index 10cb869b0..4739da3de 100644
--- a/libp2p/bitswap/chunker.py
+++ b/libp2p/bitswap/chunker.py
@@ -7,10 +7,13 @@
 """
 from collections.abc import Callable, Iterator
+import io
 from pathlib import Path
 
 # Default chunk size: 63 KB (py-libp2p accepts less than 64 KB)
-DEFAULT_CHUNK_SIZE = 63 * 1024
+# 63 KB minus 32 bytes to leave room for the dag-pb leaf envelope overhead,
+# ensuring wrapped blocks never exceed MAX_BLOCK_SIZE (63 * 1024).
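+# Worked numbers: 63 * 1024 = 64512, so raw chunks are 64512 - 32 = 64480
+# bytes. The UnixFS/dag-pb leaf envelope adds roughly 14 bytes of protobuf
+# framing (see the MAX_BLOCK_SIZE note in config.py), so a wrapped leaf is
+# about 64494 bytes, safely under MAX_BLOCK_SIZE (64512).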
+DEFAULT_CHUNK_SIZE = 63 * 1024 - 32 def chunk_bytes(data: bytes, chunk_size: int = DEFAULT_CHUNK_SIZE) -> list[bytes]: @@ -82,6 +85,49 @@ def chunk_file(file_path: str, chunk_size: int = DEFAULT_CHUNK_SIZE) -> Iterator yield chunk +def chunk_stream( + stream: io.IOBase, chunk_size: int = DEFAULT_CHUNK_SIZE +) -> Iterator[bytes]: + """ + Stream chunks from any readable io.IOBase object. + + Memory efficient — reads one chunk at a time without loading the + entire content into memory. Works with any Python stream: + open() file handles, BytesIO, GzipFile, BZ2File, network sockets, + or any object that implements io.IOBase.read(). + + Args: + stream: Any readable io.IOBase (open(), BytesIO, GzipFile, etc.) + chunk_size: Size of each chunk in bytes + + Yields: + Chunks of up to chunk_size bytes. The final chunk may be smaller. + + Example: + >>> import io + >>> data = b"hello world " * 100000 + >>> chunks = list(chunk_stream(io.BytesIO(data), chunk_size=256*1024)) + >>> print(f"Split into {len(chunks)} chunks") + + >>> # From a real file handle + >>> with open("movie.mp4", "rb") as f: + ... for chunk in chunk_stream(f): + ... process(chunk) + + >>> # From a gzip stream (decompress on-the-fly) + >>> import gzip + >>> with gzip.open("archive.gz", "rb") as f: + ... for chunk in chunk_stream(f): + ... process(chunk) + + """ + while True: + chunk = stream.read(chunk_size) + if not chunk: + break + yield chunk + + def estimate_chunk_count(file_size: int, chunk_size: int = DEFAULT_CHUNK_SIZE) -> int: """ Estimate number of chunks for a given file size. diff --git a/libp2p/bitswap/cid.py b/libp2p/bitswap/cid.py index 9f21d90de..0056d0710 100644 --- a/libp2p/bitswap/cid.py +++ b/libp2p/bitswap/cid.py @@ -209,7 +209,16 @@ def parse_cid(value: CIDInput) -> CIDv0 | CIDv1: return value if isinstance(value, bytes): - return make_cid(value) + try: + return make_cid(value) + except ValueError: + # make_cid(bytes) fails for raw CIDv0 buffers (multihash bytes). + # CIDv0 is simply a bare multihash, so try constructing directly. + try: + return CIDv0(value) + except Exception: + pass + raise if isinstance(value, str): cid_str = value.strip() diff --git a/libp2p/bitswap/client.py b/libp2p/bitswap/client.py index 96913567f..3d3acefc0 100644 --- a/libp2p/bitswap/client.py +++ b/libp2p/bitswap/client.py @@ -15,6 +15,7 @@ from libp2p.custom_types import TProtocol from libp2p.network.stream.exceptions import StreamEOF from libp2p.peer.id import ID as PeerID +from libp2p.peer.peerinfo import PeerInfo # noqa: F401 from .block_store import BlockStore, MemoryBlockStore from .cid import ( @@ -43,6 +44,7 @@ ) from .messages import create_message, create_wantlist_entry from .pb.bitswap_pb2 import Message +from .provider_query import ProviderQueryManager logger = logging.getLogger(__name__) @@ -60,6 +62,7 @@ def __init__( host: IHost, block_store: BlockStore | None = None, protocol_version: str = BITSWAP_PROTOCOL_V120, + provider_query_manager: ProviderQueryManager | None = None, ): """ Initialize Bitswap client. @@ -68,11 +71,18 @@ def __init__( host: The libp2p host block_store: Block storage backend (defaults to in-memory) protocol_version: Preferred protocol version (defaults to v1.2.0) + provider_query_manager: Optional ProviderQueryManager for automatic + DHT-based provider discovery. When supplied, + ``get_block()`` will query the DHT for providers before + broadcasting to all connected peers. 
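+
+        Example (a sketch; DHT construction and bootstrap are elided, and
+        ``cid`` is a placeholder, see provider_query.py):
+            >>> dht = KadDHT(host)
+            >>> pqm = ProviderQueryManager(dht)
+            >>> bitswap = BitswapClient(host, provider_query_manager=pqm)
+            >>> data = await bitswap.get_block(cid)  # DHT-discovers a provider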
""" self.host = host self.block_store = block_store or MemoryBlockStore() self.protocol_version = protocol_version + self.provider_query_manager: ProviderQueryManager | None = ( + provider_query_manager + ) self._wantlist: dict[ CIDObject, dict[str, Any] ] = {} # CID -> {priority, want_type, send_dont_have} @@ -153,6 +163,88 @@ async def add_block(self, cid: CIDInput, data: bytes) -> None: # Notify peers who wanted this block await self._notify_peers_about_block(cid_obj, data) + async def get_blocks_batch( + self, + cids: list[CIDInput], + peer_id: PeerID | None = None, + timeout: float = DEFAULT_TIMEOUT, + batch_size: int = 32, + ) -> dict[bytes, bytes]: + """ + Fetch multiple blocks in batches using a single wantlist per batch. + + Sends all CIDs in one wantlist message, waits for all responses on the + same stream. This avoids opening hundreds of individual streams which + causes Kubo to send GO_AWAY. + + Args: + cids: List of CIDs to fetch + peer_id: Optional specific peer to request from + timeout: Timeout per batch in seconds + batch_size: How many CIDs to request per wantlist message + + Returns: + Dict mapping cid_bytes -> block_data for all successfully fetched blocks + + """ + results: dict[bytes, bytes] = {} + cid_objs = [parse_cid(c) for c in cids] + + # Check local store first + remaining: list[CIDObject] = [] + for cid_obj in cid_objs: + data = await self.block_store.get_block(cid_obj) + if data is not None: + results[cid_obj.buffer] = data + else: + remaining.append(cid_obj) + + if not remaining: + return results + + # Process in batches to avoid overwhelming the peer + for batch_start in range(0, len(remaining), batch_size): + batch = remaining[batch_start : batch_start + batch_size] + + # Register pending events for all CIDs in batch + for cid_obj in batch: + if cid_obj not in self._pending_requests: + self._pending_requests[cid_obj] = trio.Event() + await self.want_block(cid_obj, send_dont_have=True) + + # Send all CIDs in a single wantlist to the peer + if peer_id: + await self._send_wantlist_to_peer(peer_id, batch) + else: + await self._broadcast_wantlist(batch) + + # Wait for all blocks in this batch + try: + with trio.fail_after(timeout): + for cid_obj in batch: + if cid_obj in self._pending_requests: + await self._pending_requests[cid_obj].wait() + except trio.TooSlowError: + msg = f"Batch timeout: {len(batch)} blocks, got partial results" + logger.warning(msg) + + # Collect results and clean up + for cid_obj in batch: + data = await self.block_store.get_block(cid_obj) + if data is not None: + results[cid_obj.buffer] = data + else: + cid_str = format_cid_for_display(cid_obj) + logger.warning(f"Block not received: {cid_str}") + + # Cleanup + if cid_obj in self._pending_requests: + del self._pending_requests[cid_obj] + if cid_obj in self._wantlist: + del self._wantlist[cid_obj] + + return results + async def get_block( self, cid: CIDInput, @@ -162,9 +254,15 @@ async def get_block( """ Get a block, fetching from peers if not available locally. + If a ``ProviderQueryManager`` was supplied at construction time and no + explicit ``peer_id`` is given, the manager is consulted first to + discover which peers have the block via the DHT. The first discovered + provider is used; if none is found the request falls back to + broadcasting to all connected peers. + Args: cid: The CID of the block to fetch - peer_id: Optional specific peer to request from + peer_id: Optional peer to request from; DHT discovery is skipped when set. 
timeout: Timeout in seconds Returns: @@ -177,12 +275,31 @@ async def get_block( """ cid_obj = parse_cid(cid) - # Check local store first + # 1. Check local store first data = await self.block_store.get_block(cid_obj) if data is not None: return data - # Request from network + # 2. If no explicit peer given, try DHT provider discovery + if peer_id is None and self.provider_query_manager is not None: + try: + providers = await self.provider_query_manager.find_providers_single( + cid, timeout=min(5.0, timeout / 2) + ) + if providers: + peer_id = providers[0] + logger.debug( + "DHT discovered provider %s for %s", + peer_id, + format_cid_for_display(cid_obj, max_len=12), + ) + except Exception as exc: + logger.debug( + "Provider query failed, falling back to broadcast: %s", + exc, + ) + + # 3. Request from network (specific peer or broadcast) return await self._request_block(cid_obj, peer_id, timeout) async def want_block( @@ -286,10 +403,8 @@ async def _request_block( # Send wantlist to peers if peer_id: - logger.info(f" → Sending wantlist to peer {peer_id}") await self._send_wantlist_to_peer(peer_id, [cid]) else: - logger.info(" → Broadcasting wantlist") await self._broadcast_wantlist([cid]) # Wait for block to arrive @@ -650,25 +765,91 @@ async def _process_wantlist( # Send DontHave (v1.2.0) presences_to_send.append((entry_cid, False)) - # Send responses + # Send responses in batches to stay under MAX_MESSAGE_SIZE + # and Noise protocol limit (65535 bytes) if blocks_to_send_v100 or blocks_to_send_v110 or presences_to_send: - response_msg = create_message( - blocks_v100=blocks_to_send_v100 if blocks_to_send_v100 else None, - blocks_v110=blocks_to_send_v110 if blocks_to_send_v110 else None, - block_presences=presences_to_send if presences_to_send else None, - ) - logger.debug(f"Sending response message to {peer_id} on stream {stream}") - await self._write_message(stream, response_msg) - logger.debug(f"Response message sent to {peer_id}") - - if blocks_to_send_v100 or blocks_to_send_v110: - count = len(blocks_to_send_v100) + len(blocks_to_send_v110) - logger.debug(f"Sent {count} blocks to peer {peer_id}") + # Send blocks in batches + if blocks_to_send_v100: + await self._send_blocks_in_batches_v100( + blocks_to_send_v100, peer_id, stream + ) + if blocks_to_send_v110: + await self._send_blocks_in_batches_v110( + blocks_to_send_v110, peer_id, stream + ) + # Send presences (usually small, can send all at once) if presences_to_send: + presence_msg = create_message(block_presences=presences_to_send) + await self._write_message(stream, presence_msg) logger.debug( f"Sent {len(presences_to_send)} block presences to peer {peer_id}" ) + async def _send_blocks_in_batches_v100( + self, blocks: list[bytes], peer_id: PeerID, stream: INetStream + ) -> None: + """Send blocks in batches to stay under message size limit.""" + # Noise protocol limit is 65535 bytes per message + # Reserve some space for protobuf overhead + MAX_BATCH_SIZE = 60000 # ~60KB per message for safety + + batch: list[bytes] = [] + batch_size = 0 + + for block_data in blocks: + block_size = len(block_data) + + # If adding this block would exceed limit, send current batch first + if batch and (batch_size + block_size > MAX_BATCH_SIZE): + msg = create_message(blocks_v100=batch) + await self._write_message(stream, msg) + logger.debug(f"Sent batch of {len(batch)} blocks to peer {peer_id}") + batch = [] + batch_size = 0 + + batch.append(block_data) + batch_size += block_size + + # Send remaining blocks + if batch: + msg = 
create_message(blocks_v100=batch) + await self._write_message(stream, msg) + logger.debug(f"Sent final batch of {len(batch)} blocks to peer {peer_id}") + + async def _send_blocks_in_batches_v110( + self, + blocks: list[tuple[bytes, bytes]], + peer_id: PeerID, + stream: INetStream, + ) -> None: + """Send blocks (v1.1.0+ format) in batches to stay under message size limit.""" + # Noise protocol limit is 65535 bytes per message + # Reserve some space for protobuf overhead + MAX_BATCH_SIZE = 60000 # ~60KB per message for safety + + batch: list[tuple[bytes, bytes]] = [] + batch_size = 0 + + for prefix, block_data in blocks: + block_size = len(prefix) + len(block_data) + + # If adding this block would exceed limit, send current batch first + if batch and (batch_size + block_size > MAX_BATCH_SIZE): + msg = create_message(blocks_v110=batch) + await self._write_message(stream, msg) + logger.debug(f"Sent batch of {len(batch)} blocks to peer {peer_id}") + batch = [] + batch_size = 0 + + batch.append((prefix, block_data)) + batch_size += block_size + + # Send remaining blocks + if batch: + msg = create_message(blocks_v110=batch) + await self._write_message(stream, msg) + logger.debug(f"Sent final batch of {len(batch)} blocks to peer {peer_id}") + async def _process_blocks_v100(self, blocks: list[bytes], peer_id: PeerID) -> None: """ Process received blocks (v1.0.0 format). diff --git a/libp2p/bitswap/config.py b/libp2p/bitswap/config.py index 87ba26e0e..6fc3f2bfb 100644 --- a/libp2p/bitswap/config.py +++ b/libp2p/bitswap/config.py @@ -22,8 +22,9 @@ # Maximum message size (4MiB as per spec) MAX_MESSAGE_SIZE = 4 * 1024 * 1024 -# Maximum block size (63 KB - matches DEFAULT_CHUNK_SIZE in chunker.py) +# Maximum block size (63 KB - after DAG-PB/UnixFS encoding) # py-libp2p stream limit is ~64 KB, so we use 63 KB to be safe +# Note: Raw chunk data should be smaller to account for DAG-PB overhead (~14 bytes) MAX_BLOCK_SIZE = 63 * 1024 # Default timeout for operations (in seconds) diff --git a/libp2p/bitswap/dag.py b/libp2p/bitswap/dag.py index 98ce469db..9283fdcf9 100644 --- a/libp2p/bitswap/dag.py +++ b/libp2p/bitswap/dag.py @@ -9,22 +9,24 @@ from collections.abc import Awaitable, Callable import inspect +import io import logging from typing import Union from libp2p.peer.id import ID as PeerID +from .block_service import BlockService from .block_store import BlockStore from .chunker import ( DEFAULT_CHUNK_SIZE, chunk_bytes, chunk_file, + chunk_stream, estimate_chunk_count, get_file_size, ) from .cid import ( CODEC_DAG_PB, - CODEC_RAW, CIDInput, cid_to_bytes, compute_cid_v1, @@ -33,11 +35,13 @@ ) from .client import BitswapClient from .dag_pb import ( - create_file_node, + balanced_layout, + create_leaf_node, decode_dag_pb, is_directory_node, is_file_node, ) +from .errors import BlockNotFoundError logger = logging.getLogger(__name__) @@ -98,17 +102,97 @@ class MerkleDag: """ - def __init__(self, bitswap: BitswapClient, block_store: BlockStore | None = None): + def __init__( + self, + bitswap: BitswapClient, + block_store: BlockStore | None = None, + block_service: BlockService | None = None, + ): """ Initialize Merkle DAG manager. Args: bitswap: Bitswap client for block exchange block_store: Optional block store (uses bitswap's store if None) + block_service: Optional BlockService for transparent local→network + fallback with auto-caching. When provided, all block + reads/writes go through it instead of bitswap directly. 
+ Construct with: BlockService(your_store, bitswap) """ self.bitswap = bitswap self.block_store = block_store or bitswap.block_store + # If a BlockService is provided use it; otherwise fall back to + # calling bitswap directly (existing behaviour, no regression). + self._service: BlockService | None = block_service + + # ── private routing helpers ─────────────────────────────────────────────── + + async def _put_block(self, cid: CIDInput, data: bytes) -> None: + """Store a block. Routes through BlockService when available.""" + if self._service is not None: + await self._service.put_block(cid, data) + else: + await self.bitswap.add_block(cid, data) + + async def _get_block( + self, + cid: CIDInput, + peer_id: PeerID | None = None, + timeout: float = 30.0, + ) -> bytes: + """Fetch a block. Routes through BlockService when available.""" + if self._service is not None: + data = await self._service.get_block(cid, peer_id=peer_id, timeout=timeout) + if data is None: + from .cid import cid_to_bytes, format_cid_for_display + + raise BlockNotFoundError( + f"Block not found: {format_cid_for_display(cid_to_bytes(cid))}" + ) + return data + return await self.bitswap.get_block(cid, peer_id, timeout) + + async def _get_blocks_batch( + self, + cids: list[CIDInput], + peer_id: PeerID | None = None, + timeout: float = 30.0, + batch_size: int = 32, + ) -> dict[bytes, bytes]: + """Batch-fetch blocks. Routes through BlockService when available.""" + if self._service is not None: + return await self._service.get_blocks_batch( + cids, peer_id=peer_id, timeout=timeout, batch_size=batch_size + ) + # Check if the client supports native batch fetching + get_blocks_batch: Callable[..., Awaitable[dict[bytes, bytes]]] | None = getattr( + self.bitswap, "get_blocks_batch", None + ) + if get_blocks_batch is not None and callable(get_blocks_batch): + try: + result = await get_blocks_batch( + cids, peer_id=peer_id, timeout=timeout, batch_size=batch_size + ) + # Ensure the result is a plain dict (not a coroutine from a mock) + if isinstance(result, dict): + return result + except Exception: + pass + # Fall back to individual _get_block calls + results: dict[bytes, bytes] = {} + for cid in cids: + from .cid import cid_to_bytes + + cid_bytes = cid_to_bytes(cid) + try: + data = await self._get_block( + cid_bytes, peer_id=peer_id, timeout=timeout + ) + results[cid_bytes] = data + except Exception: + pass + return results async def add_file( self, @@ -154,16 +238,17 @@ async def add_file( logger.debug(f"Using chunk size: {chunk_size} bytes") - # If file is small enough, store as single RAW block + # If file is small enough, store as single dag-pb leaf block if file_size <= chunk_size: logger.debug("File fits in single block") with open(file_path, "rb") as f: data = f.read() - cid = compute_cid_v1(data, codec=CODEC_RAW) + leaf_block = create_leaf_node(data) + cid = compute_cid_v1(leaf_block, codec=CODEC_DAG_PB) - await self.bitswap.add_block(cid, data) + await self._put_block(cid, leaf_block) if progress_callback: await _call_progress_callback( @@ -187,7 +272,7 @@ async def add_file( dir_data = create_directory_node([(filename, cid, file_size)]) dir_cid = compute_cid_v1(dir_data, codec=CODEC_DAG_PB) - await self.bitswap.add_block(dir_cid, dir_data) + await self._put_block(dir_cid, dir_data) logger.info( f"Created directory wrapper. 
Directory CID: " @@ -202,19 +287,18 @@ async def add_file( logger.debug(f"Chunking file into ~{estimated_chunks} chunks") logger.info("=== Starting file chunking process ===") - chunks_data: list[tuple[bytes, int]] = [] + # leaf_triples: (cid_bytes, leaf_block_bytes, raw_data_size) + leaf_triples: list[tuple[bytes, bytes, int]] = [] bytes_processed = 0 # Process file in chunks (memory efficient) for i, chunk_data in enumerate(chunk_file(file_path, chunk_size)): - # Compute CID for chunk - chunk_cid = compute_cid_v1(chunk_data, codec=CODEC_RAW) + # Wrap chunk in UnixFS dag-pb leaf (matches Kubo's RawLeaves=false) + leaf_block = create_leaf_node(chunk_data) + chunk_cid = compute_cid_v1(leaf_block, codec=CODEC_DAG_PB) - # Store chunk - await self.bitswap.add_block(chunk_cid, chunk_data) - - # Track chunk info - chunks_data.append((chunk_cid, len(chunk_data))) + await self._put_block(chunk_cid, leaf_block) + leaf_triples.append((chunk_cid, leaf_block, len(chunk_data))) bytes_processed += len(chunk_data) # Progress callback @@ -226,43 +310,36 @@ async def add_file( f"chunking ({i + 1} chunks)", ) - # Enhanced logging with full CID logger.info( f"Chunk {i + 1}: CID={format_cid_for_display(chunk_cid)}, " f"Size={len(chunk_data)} bytes, " f"Progress={bytes_processed}/{file_size}" ) logger.debug( - f"Stored chunk {i}: {format_cid_for_display(chunk_cid, max_len=16)} " + f"Stored leaf {i}: {format_cid_for_display(chunk_cid, max_len=16)} " f"({len(chunk_data)} bytes)" ) - # Create root node with links to all chunks + # Build balanced DAG tree (max 174 links/node, matches Kubo) if progress_callback: await _call_progress_callback( progress_callback, file_size, file_size, "creating root node" ) - root_data = create_file_node(chunks_data) - root_cid = compute_cid_v1(root_data, codec=CODEC_DAG_PB) - await self.bitswap.add_block(root_cid, root_data) + root_cid, root_data = balanced_layout(leaf_triples) + await self._put_block(root_cid, root_data) # Enhanced logging for root CID logger.info("=== File chunking completed ===") logger.info( f"Root CID: {format_cid_for_display(root_cid)} " - f"(Links to {len(chunks_data)} chunks)" + f"(Balanced DAG over {len(leaf_triples)} leaves)" ) logger.info(f"Total file size: {file_size} bytes") - logger.info("=== Chunk CIDs ===") - for i, (chunk_cid, chunk_size) in enumerate(chunks_data): - logger.info( - f" Chunk {i}: {format_cid_for_display(chunk_cid)} ({chunk_size} bytes)" - ) logger.info("=" * 50) logger.info( - f"Added file with {len(chunks_data)} chunks. " + f"Added file with {len(leaf_triples)} leaves. " f"Root CID: {format_cid_for_display(root_cid, max_len=16)}" ) @@ -283,7 +360,7 @@ async def add_file( # Create directory node with single entry pointing to the file dir_data = create_directory_node([(filename, root_cid, file_size)]) dir_cid = compute_cid_v1(dir_data, codec=CODEC_DAG_PB) - await self.bitswap.add_block(dir_cid, dir_data) + await self._put_block(dir_cid, dir_data) logger.info( "Created directory wrapper. 
Directory CID: " @@ -322,10 +399,11 @@ async def add_bytes( if chunk_size is None: chunk_size = DEFAULT_CHUNK_SIZE - # If data is small, store as single block + # If data is small, store as single dag-pb leaf block if file_size <= chunk_size: - cid = compute_cid_v1(data, codec=CODEC_RAW) - await self.bitswap.add_block(cid, data) + leaf_block = create_leaf_node(data) + cid = compute_cid_v1(leaf_block, codec=CODEC_DAG_PB) + await self._put_block(cid, leaf_block) if progress_callback: await _call_progress_callback( @@ -334,17 +412,18 @@ async def add_bytes( return cid - # Chunk the data + # Chunk the data and wrap each chunk as a dag-pb leaf chunks = chunk_bytes(data, chunk_size) - chunks_data: list[tuple[bytes, int]] = [] + leaf_triples: list[tuple[bytes, bytes, int]] = [] for i, chunk_data in enumerate(chunks): - chunk_cid = compute_cid_v1(chunk_data, codec=CODEC_RAW) - await self.bitswap.add_block(chunk_cid, chunk_data) - chunks_data.append((chunk_cid, len(chunk_data))) + leaf_block = create_leaf_node(chunk_data) + chunk_cid = compute_cid_v1(leaf_block, codec=CODEC_DAG_PB) + await self._put_block(chunk_cid, leaf_block) + leaf_triples.append((chunk_cid, leaf_block, len(chunk_data))) if progress_callback: - bytes_processed = sum(size for _, size in chunks_data) + bytes_processed = sum(s for _, _, s in leaf_triples) await _call_progress_callback( progress_callback, bytes_processed, @@ -352,10 +431,9 @@ async def add_bytes( f"chunking ({i + 1}/{len(chunks)})", ) - # Create root node - root_data = create_file_node(chunks_data) - root_cid = compute_cid_v1(root_data, codec=CODEC_DAG_PB) - await self.bitswap.add_block(root_cid, root_data) + # Build balanced DAG tree + root_cid, root_data = balanced_layout(leaf_triples) + await self._put_block(root_cid, root_data) if progress_callback: await _call_progress_callback( @@ -364,6 +442,94 @@ async def add_bytes( return root_cid + async def add_stream( + self, + stream: io.IOBase, + chunk_size: int | None = None, + progress_callback: ProgressCallback | None = None, + ) -> bytes: + """ + Add data from any io.IOBase stream to the DAG. + + More flexible than add_file() (accepts any stream, not just file paths) + and more memory efficient than add_bytes() (reads one chunk at a time, + so total memory usage is O(chunk_size) regardless of file size). + + Args: + stream: Any readable io.IOBase — open() handles, BytesIO, + GzipFile, BZ2File, network streams, pipes, etc. + chunk_size: Optional chunk size in bytes (auto-selected if None) + progress_callback: Optional callback(current, total, status). + Note: total is unknown for streams, so current + is reported as bytes processed so far. + + Returns: + Root CID bytes of the stored DAG + + Example: + >>> import io + >>> root_cid = await dag.add_stream(io.BytesIO(b"hello world")) + + >>> # Memory-efficient large file (no full read into RAM) + >>> with open("movie.mp4", "rb") as f: + ... root_cid = await dag.add_stream(f) + + >>> # Decompress and add in one pass + >>> import gzip + >>> with gzip.open("archive.gz", "rb") as f: + ... root_cid = await dag.add_stream(f) + + >>> # With BlockService for persistent caching + >>> service = BlockService(FilesystemBlockStore("./blocks"), bitswap) + >>> dag = MerkleDag(bitswap, block_service=service) + >>> with open("large.bin", "rb") as f: + ... 
root_cid = await dag.add_stream(f) # cached to disk + + """ + if chunk_size is None: + chunk_size = DEFAULT_CHUNK_SIZE + + leaf_triples: list[tuple[bytes, bytes, int]] = [] + bytes_processed = 0 + + for i, chunk_data in enumerate(chunk_stream(stream, chunk_size)): + leaf_block = create_leaf_node(chunk_data) + chunk_cid = compute_cid_v1(leaf_block, codec=CODEC_DAG_PB) + await self._put_block(chunk_cid, leaf_block) + leaf_triples.append((chunk_cid, leaf_block, len(chunk_data))) + bytes_processed += len(chunk_data) + + if progress_callback: + # total is unknown for streams — report bytes processed so far + await _call_progress_callback( + progress_callback, + bytes_processed, + bytes_processed, + f"chunking ({i + 1} chunks, {bytes_processed} bytes)", + ) + + # Empty stream — store a single empty leaf + if not leaf_triples: + leaf_block = create_leaf_node(b"") + cid = compute_cid_v1(leaf_block, codec=CODEC_DAG_PB) + await self._put_block(cid, leaf_block) + return cid + + # Single chunk — return the leaf CID directly (no root node needed) + if len(leaf_triples) == 1: + return leaf_triples[0][0] + + # Multiple chunks — build balanced DAG tree + root_cid, root_data = balanced_layout(leaf_triples) + await self._put_block(root_cid, root_data) + + if progress_callback: + await _call_progress_callback( + progress_callback, bytes_processed, bytes_processed, "completed" + ) + + return root_cid + async def fetch_file( self, root_cid: CIDInput, @@ -417,160 +583,265 @@ async def fetch_file( """ root_cid_bytes = cid_to_bytes(root_cid) - logger.info( - f"Fetching file: {format_cid_for_display(root_cid_bytes, max_len=16)}" - ) - logger.info( - "=== Starting file fetch for CID: " - f"{format_cid_for_display(root_cid_bytes)} ===" - ) + logger.info(f"Fetching file: {format_cid_for_display(root_cid_bytes)}") - # Get root block - root_data = await self.bitswap.get_block(root_cid_bytes, peer_id, timeout) - - # Verify root block + # Step 1: Fetch the root block + root_data = await self._get_block(root_cid_bytes, peer_id, timeout) if not verify_cid(root_cid_bytes, root_data): - raise ValueError( - "Root block verification failed: " - f"{format_cid_for_display(root_cid_bytes)}" - ) + root_cid_str = format_cid_for_display(root_cid_bytes) + raise ValueError(f"Root block CID verification failed: {root_cid_str}") - # Check if it's a directory wrapper (IPFS-standard way for filename) + # Step 2: Handle directory wrapper + # (produced by `ipfs add --wrap-with-directory`) filename = None actual_file_cid = root_cid_bytes actual_file_data = root_data if is_directory_node(root_data): - logger.info("Root is a directory node, extracting file entry...") - links, _ = decode_dag_pb(root_data) - - if links: - # Get the first (and typically only) file entry - first_link = links[0] - filename = first_link.name if first_link.name else None + logger.info("Root is a directory node — extracting filename and file CID") + dir_links, _ = decode_dag_pb(root_data) + if dir_links: + first_link = dir_links[0] + filename = first_link.name or None actual_file_cid = first_link.cid - - logger.info(f"Extracted filename: {filename}") - logger.info( - f"Actual file CID: " - f"{format_cid_for_display(actual_file_cid, max_len=16)}" - ) - - # Fetch the actual file block - actual_file_data = await self.bitswap.get_block( + logger.info(f"Filename from directory: {filename!r}") + actual_file_data = await self._get_block( actual_file_cid, peer_id, timeout ) - if not verify_cid(actual_file_cid, actual_file_data): - raise ValueError( - "File block verification 
failed: " - f"{format_cid_for_display(actual_file_cid)}" - ) - - # Now process the actual file data - # Check if it's a DAG-PB file node - if is_file_node(actual_file_data): - logger.debug("Root is a DAG-PB file node, resolving chunks...") - - # Decode to get links and metadata - links, unixfs_data = decode_dag_pb(actual_file_data) - - if not links: - # File with inline data (small file) - logger.debug("File has inline data") - file_data = ( - unixfs_data.data if unixfs_data and unixfs_data.data else b"" - ) - - # Notify progress callback with metadata - if progress_callback: - await _call_progress_callback( - progress_callback, - len(file_data), - len(file_data), - f"metadata: size={len(file_data)}, chunks=0", - ) - - return file_data, filename - - # File with multiple chunks - total_size = unixfs_data.filesize if unixfs_data else 0 - logger.debug(f"File has {len(links)} chunks, total size: {total_size}") - logger.info( - f"Fetching multi-chunk file: {len(links)} chunks, {total_size} bytes" - ) - logger.info("=== Chunk CIDs to fetch ===") - for i, link in enumerate(links): - logger.info( - f" Chunk {i}: {format_cid_for_display(link.cid)} " - f"({link.size} bytes)" - ) - logger.info("=" * 50) - - # Notify progress callback with file metadata at the start + f_cid_str = format_cid_for_display(actual_file_cid) + err_msg = f"File block CID verification failed: {f_cid_str}" + raise ValueError(err_msg) + + # Step 3: Handle raw block (not a DAG-PB node at all) + if not is_file_node(actual_file_data): + logger.info(f"Root is a raw block: {len(actual_file_data)} bytes") + return actual_file_data, filename + + # Step 4: Parse the file node + top_links, top_unixfs = decode_dag_pb(actual_file_data) + filesize = top_unixfs.filesize if top_unixfs else 0 + total_size = filesize or sum(lnk.size for lnk in top_links) + msg = f"File node: {len(top_links)} top-level links, total size={total_size}" + logger.info(f"{msg} bytes") + + # Step 5: Small file with inline data (no links) + if not top_links: + file_data = top_unixfs.data if top_unixfs and top_unixfs.data else b"" + logger.info(f"Inline file data: {len(file_data)} bytes") if progress_callback: + data_len = len(file_data) await _call_progress_callback( - progress_callback, - 0, - total_size, - f"metadata: size={total_size}, chunks={len(links)}", - ) - - file_data = b"" - bytes_fetched = 0 - - # Fetch each chunk - for i, link in enumerate(links): - if progress_callback: - await _call_progress_callback( - progress_callback, - bytes_fetched, - total_size, - f"fetching chunk {i + 1}/{len(links)}", - ) - - logger.info( - f"Fetching chunk {i + 1}/{len(links)}: " - f"CID={format_cid_for_display(link.cid)}" + progress_callback, data_len, data_len, "completed" ) + return file_data, filename - # Fetch chunk - chunk_data = await self.bitswap.get_block(link.cid, peer_id, timeout) - - # Verify chunk - if not verify_cid(link.cid, chunk_data): - raise ValueError( - f"Chunk verification failed: {format_cid_for_display(link.cid)}" - ) - - file_data += chunk_data - bytes_fetched += len(chunk_data) + # Step 6: Collect all leaf CIDs without opening streams + # Strategy: Recursively batch-fetch all DAG nodes + # then traverse locally to collect leaves + + top_len = len(top_links) + msg1 = f"[DAG] Recursively batch-fetching DAG tree ({top_len} top links)..." + logger.info(msg1) + msg2 = f"[FETCH] Recursively batch-fetching DAG tree ({top_len} top links)..." 
+ print(msg2, flush=True) + + # Map to store ALL fetched blocks (both intermediate and leaves) + all_blocks_map: dict[bytes, bytes] = {} + + async def _batch_fetch_tree(cid_list: list[bytes], depth: int) -> None: + """Recursively batch-fetch a level of DAG nodes and queue their children.""" + if not cid_list: + return + + c_count = len(cid_list) + msg1 = f"[DAG] Depth {depth}: batch-fetching {c_count} blocks..." + logger.info(msg1) + msg2 = f"[FETCH] Depth {depth}: batch-fetching {c_count} blocks..." + print(msg2, flush=True) + + # Batch-fetch this level's blocks + level_blocks = await self._get_blocks_batch( + list(cid_list), peer_id=peer_id, timeout=timeout, batch_size=32 + ) + logger.info(f"[DAG] Depth {depth}: ✓ received {len(level_blocks)} blocks") + all_blocks_map.update(level_blocks) + + # Collect child CIDs for recursion + child_cids: list[bytes] = [] + for cid_bytes in cid_list: + block_data = level_blocks.get(cid_bytes) + if block_data is None: + c_str = format_cid_for_display(cid_bytes) + msg = f"[DAG] Depth {depth}: block {c_str} missing after" + logger.warning(f"{msg} fetch") + continue + + if is_file_node(block_data): + node_links, _ = decode_dag_pb(block_data) + cid_str = format_cid_for_display(cid_bytes) + msg = f"[DAG] Depth {depth}: {cid_str} has {len(node_links)}" + logger.debug(f"{msg} children") + for link in node_links: + child_cids.append(link.cid) + + # Recursively fetch next level if there are children + if child_cids: + ch_count = len(child_cids) + msg = f"[DAG] Depth {depth}: found {ch_count} child CIDs" + logger.info(f"{msg}, fetching next level...") + await _batch_fetch_tree(child_cids, depth + 1) + + # Starting from the top-level links + await _batch_fetch_tree([top_link.cid for top_link in top_links], depth=1) + blocks_count = len(all_blocks_map) + logger.info(f"[DAG] ✓ Tree fetch complete: {blocks_count} total blocks") + print(f"[FETCH] ✓ Tree fetch complete: {blocks_count} total blocks", flush=True) + + # Now traverse locally to collect leaf CIDs in order + ordered_leaf_cids: list[bytes] = [] + + def _collect_leaves_local(cid_bytes: bytes, depth: int = 1) -> None: + """Traverse locally-fetched blocks to collect leaf CIDs.""" + block_data = all_blocks_map.get(cid_bytes) + if block_data is None: + cid_str = format_cid_for_display(cid_bytes) + logger.warning(f"[DAG] Depth {depth}: block {cid_str} not in map") + return + + if not is_file_node(block_data): + # Raw block - it's a leaf + logger.debug(f"[DAG] Depth {depth}: raw block (leaf)") + ordered_leaf_cids.append(cid_bytes) + return + + node_links, _ = decode_dag_pb(block_data) + logger.debug(f"[DAG] Depth {depth}: {len(node_links)} links") + + if not node_links: + # Leaf node (no children, data is inline in UnixFS) + logger.debug(f"[DAG] Depth {depth}: file node with inline data (leaf)") + ordered_leaf_cids.append(cid_bytes) + return + + # Intermediate node - recursively process children + for j, child_link in enumerate(node_links): + c_idx = j + 1 + c_tot = len(node_links) + msg = f"[DAG] Depth {depth}: processing child {c_idx}/{c_tot}" + logger.debug(msg) + _collect_leaves_local(child_link.cid, depth + 1) + + # Traverse each top-level block + for i, top_link in enumerate(top_links): + logger.info(f"[DAG] Traversing top-level {i + 1}/{len(top_links)}...") + _collect_leaves_local(top_link.cid, depth=1) + + logger.info(f"[DAG] ✓ Collected {len(ordered_leaf_cids)} leaf blocks") + + # Step 7: Batch-fetch all leaf blocks + # (single wantlist per batch → avoids GO_AWAY) + if progress_callback: + await 
_call_progress_callback( + progress_callback, + 0, + total_size, + f"fetching {len(ordered_leaf_cids)} leaf blocks in batches", + ) - logger.info( - f"✓ Chunk {i + 1} fetched and verified: " - f"{len(chunk_data)} bytes (total: {bytes_fetched}/{total_size})" - ) - logger.debug( - f"Fetched chunk {i + 1}/{len(links)}: " - f"{format_cid_for_display(link.cid, max_len=16)} " - f"({len(chunk_data)} bytes)" - ) + l_count = len(ordered_leaf_cids) + msg1 = f"[DAG] Starting batch fetch of {l_count} leaves with batch_size=32" + logger.info(f"{msg1}, timeout={timeout}s") + msg2 = ( + f"[FETCH] Batch fetching {l_count} leaves " + f"(batch_size=32, timeout={timeout}s)" + ) + print(msg2, flush=True) + + # First try to get blocks from the already-fetched tree + block_map: dict[bytes, bytes] = {} + missing_cids: list[CIDInput] = [] + for leaf_cid in ordered_leaf_cids: + leaf_data = all_blocks_map.get(leaf_cid) + if leaf_data is not None: + block_map[leaf_cid] = leaf_data + else: + missing_cids.append(leaf_cid) + + # If some leaves weren't in the tree fetch, fetch them now + if missing_cids: + logger.info(f"[DAG] Fetching {len(missing_cids)} missing leaves") + fetched_blocks = await self._get_blocks_batch( + missing_cids, peer_id=peer_id, timeout=timeout, batch_size=32 + ) + block_map.update(fetched_blocks) + + logger.info(f"[DAG] ✓ Batch fetch complete: {len(block_map)} blocks received") + print(f"[FETCH] ✓ Batch fetch complete: {len(block_map)} blocks", flush=True) + + # Step 8: Reassemble data in order + # extracting UnixFS inline data from leaf nodes + file_data = b"" + bytes_fetched = 0 + missing_blocks: list[bytes] = [] + for idx, leaf_cid in enumerate(ordered_leaf_cids): + leaf_raw = block_map.get(bytes(leaf_cid)) + if leaf_raw is None: + l_idx = idx + 1 + t_leaves = len(ordered_leaf_cids) + c_str = format_cid_for_display(leaf_cid) + msg = f"[DAG] Leaf block {l_idx}/{t_leaves} MISSING: {c_str}" + logger.error(msg) + print(f"[FETCH] ✗ Leaf {l_idx}/{t_leaves} MISSING", flush=True) + missing_blocks.append(leaf_cid) + continue + + # Extract data: leaf blocks are UnixFS file nodes with inline data + if is_file_node(leaf_raw): + _, leaf_unixfs = decode_dag_pb(leaf_raw) + if leaf_unixfs is not None and leaf_unixfs.data: + chunk = leaf_unixfs.data + else: + chunk = b"" + chunk_len = len(chunk) + msg = f"[DAG] Leaf {idx + 1}: extracted {chunk_len} bytes" + logger.debug(f"{msg} from file node") + else: + chunk = leaf_raw + logger.debug(f"[DAG] Leaf {idx + 1}: raw block {len(chunk)} bytes") + + file_data += chunk + bytes_fetched += len(chunk) + + if (idx + 1) % 10 == 0 or idx == len(ordered_leaf_cids) - 1: + i_p = idx + 1 + t_l = len(ordered_leaf_cids) + p_str = f"{bytes_fetched}/{total_size} bytes" + logger.info(f"[DAG] Reassembled {i_p}/{t_l} leaves: {p_str}") + print(f"[FETCH] Reassembled {i_p}/{t_l} leaves: {p_str}", flush=True) if progress_callback: await _call_progress_callback( - progress_callback, total_size, total_size, "completed" + progress_callback, bytes_fetched, total_size, "downloading" ) - logger.info("=== File fetch completed ===") - logger.info(f"Total bytes fetched: {len(file_data)}") - logger.info(f"All {len(links)} chunks verified successfully") - logger.info("=" * 50) - logger.info(f"Fetched file: {len(file_data)} bytes") - return file_data, filename + if missing_blocks: + missing_count = len(missing_blocks) + logger.error(f"[DAG] ✗ {missing_count} blocks missing after batch fetch!") + missing_list = [format_cid_for_display(cid) for cid in missing_blocks[:5]] + msg = f"{missing_count} leaf 
blocks missing: {missing_list}..." + raise BlockNotFoundError(msg) + + if progress_callback: + await _call_progress_callback( + progress_callback, total_size, total_size, "completed" + ) - # Not a DAG-PB file node - return as raw data - logger.debug("Root is a raw block, returning directly") - return actual_file_data, filename + file_len = len(file_data) + msg = f"[DAG] ✓ File fetch complete: {file_len} bytes, filename={filename!r}" + logger.info(msg) + print(f"[FETCH] ✓ DOWNLOAD COMPLETE: {file_len} bytes", flush=True) + return file_data, filename async def get_file_info( self, root_cid: CIDInput, peer_id: PeerID | None = None, timeout: float = 30.0 @@ -597,7 +868,7 @@ async def get_file_info( """ # Get root block root_cid_bytes = cid_to_bytes(root_cid) - root_data = await self.bitswap.get_block(root_cid_bytes, peer_id, timeout) + root_data = await self._get_block(root_cid_bytes, peer_id, timeout) # Check if it's a DAG-PB file node if is_file_node(root_data): diff --git a/libp2p/bitswap/dag_pb.py b/libp2p/bitswap/dag_pb.py index 74bbcddc2..164add080 100644 --- a/libp2p/bitswap/dag_pb.py +++ b/libp2p/bitswap/dag_pb.py @@ -9,13 +9,26 @@ from dataclasses import dataclass, field import logging -from .cid import CIDInput, cid_to_bytes -from .pb.dag_pb_pb2 import PBNode +from .cid import CODEC_DAG_PB, CIDInput, cid_to_bytes, compute_cid_v1 +from .pb.dag_pb_pb2 import PBLink, PBNode from .pb.unixfs_pb2 import Data as PBUnixFSData +# Maximum links per internal DAG-PB node — matches Go's balanced.Layout default +MAX_LINKS_PER_NODE = 174 + logger = logging.getLogger(__name__) +def _encode_varint(value: int) -> bytes: + """Encode an unsigned integer as a protobuf varint.""" + buf = [] + while value > 0x7F: + buf.append((value & 0x7F) | 0x80) + value >>= 7 + buf.append(value & 0x7F) + return bytes(buf) + + def _normalize_link_cid(cid: CIDInput) -> bytes: """Normalize CID input for DAG links while preserving raw-bytes compatibility.""" if isinstance(cid, bytes): @@ -103,38 +116,39 @@ def encode_dag_pb(links: list[Link], unixfs_data: UnixFSData | None = None) -> b >>> encoded = encode_dag_pb(links, data) """ - # Create PBNode - pb_node = PBNode() + # DAG-PB canonical format requires Links (field 2) BEFORE Data (field 1). + # Standard protobuf SerializeToString() emits fields in field-number order + # (Data=1 first, Links=2 second), producing different bytes and a different + # CID than Kubo for the same logical content. + # We manually construct the wire format to enforce the correct ordering. - # Add links + result = b"" + + # 1. Serialize each Link first — field 2, wire type 2 (length-delimited) = tag 0x12 for link in links: - pb_link = pb_node.Links.add() + pb_link = PBLink() pb_link.Hash = link.cid pb_link.Name = link.name pb_link.Tsize = link.size + link_bytes = pb_link.SerializeToString() + result += b"\x12" + _encode_varint(len(link_bytes)) + link_bytes - # Add UnixFS data if provided - if unixfs_data: - # Create UnixFS data structure + # 2. 
Serialize Data after Links — field 1, wire type 2 = tag 0x0a + if unixfs_data is not None: pb_unixfs = PBUnixFSData() pb_unixfs.Type = UnixFSData.TYPE_MAP[unixfs_data.type] # type: ignore[assignment] pb_unixfs.Data = unixfs_data.data pb_unixfs.filesize = unixfs_data.filesize - - # Add blocksizes for blocksize in unixfs_data.blocksizes: pb_unixfs.blocksizes.append(blocksize) - if unixfs_data.hash_type: pb_unixfs.hashType = unixfs_data.hash_type if unixfs_data.fanout: pb_unixfs.fanout = unixfs_data.fanout + data_bytes = pb_unixfs.SerializeToString() + result += b"\x0a" + _encode_varint(len(data_bytes)) + data_bytes - # Serialize UnixFS data and add to PBNode - pb_node.Data = pb_unixfs.SerializeToString() - - # Serialize PBNode - return pb_node.SerializeToString() + return result def decode_dag_pb(data: bytes) -> tuple[list[Link], UnixFSData | None]: @@ -282,3 +296,94 @@ def get_file_size(data: bytes) -> int: if unixfs_data and unixfs_data.type == "file": return unixfs_data.filesize return 0 + + +def create_leaf_node(data: bytes) -> bytes: + """ + Create a DAG-PB leaf node for a single file chunk. + + Wraps raw bytes in UnixFS Data(type=File, data=chunk, filesize=len(chunk)) + inside a PBNode with no links. This matches Kubo's default behaviour + (RawLeaves=false), ensuring leaf CIDs are byte-identical to those + produced by `ipfs add`. + + Args: + data: Raw chunk bytes (may be empty for an empty file) + + Returns: + Encoded DAG-PB bytes, suitable for storage as a dag-pb block + + """ + unixfs_data = UnixFSData(type="file", data=data, filesize=len(data)) + return encode_dag_pb([], unixfs_data) + + +def balanced_layout( + leaves: list[tuple[bytes, bytes, int]], + max_links: int = MAX_LINKS_PER_NODE, +) -> tuple[bytes, bytes]: + """ + Build a balanced Merkle DAG from a flat list of leaf blocks. + + Groups leaves into batches of `max_links` (default 174), creates an + internal DAG-PB node for each batch, then repeats level by level until + a single root remains. Matches Go's balanced.Layout exactly. + + Args: + leaves: List of (cid_bytes, block_bytes, file_data_size) tuples where + - cid_bytes: CID of the leaf block as raw bytes + - block_bytes: The encoded dag-pb leaf block bytes + - file_data_size: Size of the raw file data inside this leaf + (i.e. 
len(original chunk), NOT len(block)) + max_links: Max links per internal node (default 174, matches Kubo) + + Returns: + (root_cid_bytes, root_block_bytes) + + Raises: + ValueError: If leaves is empty + + """ + if not leaves: + raise ValueError("Cannot build balanced layout from empty leaf list") + + if len(leaves) == 1: + return leaves[0][0], leaves[0][1] + + # Each level entry: (cid_bytes, block_bytes, file_data_size, cumulative_block_size) + # cumulative_block_size = len(this block) + sum(children's cumulative sizes) + level: list[tuple[bytes, bytes, int, int]] = [ + (cid, blk, fsize, len(blk)) for cid, blk, fsize in leaves + ] + + while len(level) > 1: + next_level: list[tuple[bytes, bytes, int, int]] = [] + for i in range(0, len(level), max_links): + batch = level[i : i + max_links] + if len(batch) == 1: + next_level.append(batch[0]) + continue + + # Build internal node: links to each child, UnixFS blocksizes + internal_links: list[Link] = [] + blocksizes: list[int] = [] + total_filesize = 0 + total_cum = 0 + for cid_b, _, fsize, cum in batch: + # Tsize = cumulative block size of the subtree rooted at this child + internal_links.append(Link(cid=cid_b, name="", size=cum)) + blocksizes.append(fsize) + total_filesize += fsize + total_cum += cum + + unixfs_data = UnixFSData( + type="file", filesize=total_filesize, blocksizes=blocksizes + ) + internal_block = encode_dag_pb(internal_links, unixfs_data) + internal_cid = compute_cid_v1(internal_block, codec=CODEC_DAG_PB) + # cumulative size = own block + sum of children's cumulative sizes + cum_size = len(internal_block) + total_cum + next_level.append((internal_cid, internal_block, total_filesize, cum_size)) + level = next_level + + return level[0][0], level[0][1] diff --git a/libp2p/bitswap/messages.py b/libp2p/bitswap/messages.py index 8eea6535d..0c4264bce 100644 --- a/libp2p/bitswap/messages.py +++ b/libp2p/bitswap/messages.py @@ -4,16 +4,20 @@ """ from collections.abc import Sequence +from typing import TYPE_CHECKING, Union from .cid import CIDInput, cid_to_bytes from .pb.bitswap_pb2 import Message +if TYPE_CHECKING: + from .wantlist import WantType + def create_wantlist_entry( block_cid: CIDInput, priority: int = 1, cancel: bool = False, - want_type: int = 0, # 0 = Block, 1 = Have (v1.2.0) + want_type: Union[int, "WantType"] = 0, # 0 = Block, 1 = Have (v1.2.0) send_dont_have: bool = False, # v1.2.0 ) -> Message.Wantlist.Entry: """ @@ -36,8 +40,12 @@ def create_wantlist_entry( entry.block = cid_to_bytes(block_cid) entry.priority = priority entry.cancel = cancel - # Type checkers don't like int assignment to enum, but protobuf accepts it - entry.wantType = want_type # type: ignore[assignment] # v1.2.0 field + # Handle both int and WantType enum + if isinstance(want_type, int): + entry.wantType = want_type # type: ignore[assignment] + else: + # Extract .value from WantType enum + entry.wantType = want_type.value # type: ignore[assignment] entry.sendDontHave = send_dont_have # v1.2.0 field return entry diff --git a/libp2p/bitswap/provider_query.py b/libp2p/bitswap/provider_query.py new file mode 100644 index 000000000..47fcf98ad --- /dev/null +++ b/libp2p/bitswap/provider_query.py @@ -0,0 +1,457 @@ +""" +Provider Query Manager for Bitswap. + +This module provides DHT integration for automatic provider discovery with +caching, parallelization, and error handling. It's a critical component for +enabling automatic peer discovery in Bitswap without manual peer specification. 
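+
+Typical wiring (a sketch; DHT construction and bootstrap are elided):
+
+    dht = KadDHT(host)
+    manager = ProviderQueryManager(dht)
+    bitswap = BitswapClient(host, provider_query_manager=manager)
+    # get_block() now consults the DHT for providers before broadcasting.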
+""" + +from __future__ import annotations + +from collections.abc import Sequence +from dataclasses import dataclass, field +import logging +import time +from typing import TYPE_CHECKING + +import trio + +from libp2p.peer.id import ID as PeerID + +from .cid import CIDInput, cid_to_bytes, format_cid_for_display + +if TYPE_CHECKING: + from libp2p.kad_dht.kad_dht import KadDHT + +logger = logging.getLogger(__name__) + + +@dataclass +class ProviderCacheEntry: + """ + Cached provider information for a CID. + + Attributes: + providers: List of peer IDs that provide this content + timestamp: When this entry was cached + ttl: Time-to-live in seconds (how long the cache is valid) + + """ + + providers: list[PeerID] + timestamp: float = field(default_factory=time.time) + ttl: float = 300 # 5 minutes default + + def is_expired(self) -> bool: + """Check if this cache entry has expired.""" + return (time.time() - self.timestamp) > self.ttl + + def age(self) -> float: + """Get the age of this cache entry in seconds.""" + return time.time() - self.timestamp + + +class ProviderCache: + """ + LRU cache for provider records with TTL support. + + Caches DHT provider query results to reduce network load and improve + performance for repeated queries. + """ + + def __init__(self, max_size: int = 1000, default_ttl: float = 300): + """ + Initialize provider cache. + + Args: + max_size: Maximum number of entries to cache + default_ttl: Default time-to-live in seconds + + """ + self.max_size = max_size + self.default_ttl: float = default_ttl + self._cache: dict[bytes, ProviderCacheEntry] = {} + self._access_order: list[bytes] = [] # For LRU tracking + + def get(self, cid_bytes: bytes) -> list[PeerID] | None: + """ + Get cached providers for a CID. + + Args: + cid_bytes: CID as bytes + + Returns: + List of provider peer IDs if cached and not expired, None otherwise + + """ + if cid_bytes not in self._cache: + return None + + entry = self._cache[cid_bytes] + + # Check if expired + if entry.is_expired(): + self._remove(cid_bytes) + return None + + # Update access order (LRU) + self._mark_accessed(cid_bytes) + + return entry.providers + + def put( + self, + cid_bytes: bytes, + providers: list[PeerID], + ttl: float | None = None, + ) -> None: + """ + Cache providers for a CID. 
+ + Args: + cid_bytes: CID as bytes + providers: List of provider peer IDs + ttl: Optional custom TTL (uses default if not specified) + + """ + # Evict oldest entry if cache is full + if len(self._cache) >= self.max_size and cid_bytes not in self._cache: + self._evict_oldest() + + # Store entry + entry = ProviderCacheEntry( + providers=providers, + timestamp=time.time(), + ttl=ttl or self.default_ttl, + ) + self._cache[cid_bytes] = entry + self._mark_accessed(cid_bytes) + + def _mark_accessed(self, cid_bytes: bytes) -> None: + """Mark a cache entry as recently accessed (for LRU).""" + # Remove from current position if exists + if cid_bytes in self._access_order: + self._access_order.remove(cid_bytes) + # Add to end (most recently used) + self._access_order.append(cid_bytes) + + def _evict_oldest(self) -> None: + """Evict the least recently used cache entry.""" + if not self._access_order: + return + oldest = self._access_order.pop(0) + self._remove(oldest) + + def _remove(self, cid_bytes: bytes) -> None: + """Remove an entry from the cache.""" + if cid_bytes in self._cache: + del self._cache[cid_bytes] + if cid_bytes in self._access_order: + self._access_order.remove(cid_bytes) + + def clear(self) -> None: + """Clear all cache entries.""" + self._cache.clear() + self._access_order.clear() + + def cleanup_expired(self) -> int: + """ + Remove all expired entries from the cache. + + Returns: + Number of entries removed + + """ + expired = [ + cid_bytes for cid_bytes, entry in self._cache.items() if entry.is_expired() + ] + + for cid_bytes in expired: + self._remove(cid_bytes) + + return len(expired) + + def size(self) -> int: + """Get current cache size.""" + return len(self._cache) + + def stats(self) -> dict[str, int]: + """ + Get cache statistics. + + Returns: + Dictionary with cache statistics + + """ + return { + "size": len(self._cache), + "max_size": self.max_size, + "expired": sum(1 for e in self._cache.values() if e.is_expired()), + } + + +class ProviderQueryManager: + """ + Manages DHT provider queries with caching and parallelization. + + This component integrates Bitswap with the Kademlia DHT to automatically + discover which peers have specific content. It provides: + + - Automatic provider discovery via DHT + - Parallel queries for multiple CIDs + - Provider caching to reduce DHT load + - Configurable limits and timeouts + - Error handling and retry logic + + Example: + >>> dht = KadDHT(host) + >>> manager = ProviderQueryManager(dht) + >>> providers = await manager.find_providers([cid1, cid2]) + >>> print(f"Found {len(providers)} provider mappings") + + """ + + def __init__( + self, + dht: KadDHT, + max_providers: int = 10, + cache_ttl: float = 300, # 5 minutes + cache_size: int = 1000, + max_concurrent_queries: int = 20, + ): + """ + Initialize Provider Query Manager. 
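+
+        At most ``max_concurrent_queries`` DHT lookups run at once; the
+        limit is enforced with a ``trio.Semaphore``, so excess queries
+        wait for a slot instead of being dropped.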
+ + Args: + dht: Kademlia DHT instance for provider queries + max_providers: Maximum number of providers to return per CID + cache_ttl: Cache time-to-live in seconds + cache_size: Maximum number of CIDs to cache + max_concurrent_queries: Maximum parallel DHT queries + + """ + self.dht = dht + self.max_providers = max_providers + self.cache = ProviderCache(max_size=cache_size, default_ttl=cache_ttl) + self.query_semaphore = trio.Semaphore(max_concurrent_queries) + + # Statistics + self._stats = { + "queries": 0, + "cache_hits": 0, + "cache_misses": 0, + "errors": 0, + "providers_found": 0, + } + + async def find_providers( + self, + cids: Sequence[CIDInput], + timeout: float = 5.0, + use_cache: bool = True, + ) -> dict[bytes, list[PeerID]]: + """ + Find providers for multiple CIDs in parallel. + + This is the main entry point for provider discovery. It: + 1. Checks cache for each CID + 2. Queries DHT in parallel for cache misses + 3. Updates cache with results + 4. Returns combined results + + Args: + cids: List of CIDs to find providers for + timeout: Timeout per DHT query in seconds + use_cache: Whether to use cached results + + Returns: + Dictionary mapping CID bytes to list of provider peer IDs + + Example: + >>> cids = [cid1, cid2, cid3] + >>> results = await manager.find_providers(cids) + >>> for cid_bytes, providers in results.items(): + ... n = len(providers) + ... print(f"CID {cid_bytes.hex()[:8]}... has {n} providers") + + """ + results: dict[bytes, list[PeerID]] = {} + missing: list[tuple[CIDInput, bytes]] = [] + + # Phase 1: Check cache + for cid in cids: + cid_bytes = cid_to_bytes(cid) + + if use_cache: + cached = self.cache.get(cid_bytes) + if cached is not None: + results[cid_bytes] = cached + self._stats["cache_hits"] += 1 + logger.debug( + f"Cache hit for {format_cid_for_display(cid, max_len=12)}: " + f"{len(cached)} providers" + ) + continue + + # Not in cache or cache disabled + missing.append((cid, cid_bytes)) + self._stats["cache_misses"] += 1 + + if not missing: + logger.debug(f"All {len(cids)} CIDs found in cache") + return results + + logger.info( + f"Querying DHT for {len(missing)} CIDs (cache hits: {len(results)})" + ) + + # Phase 2: Query DHT in parallel for missing CIDs + async with trio.open_nursery() as nursery: + for cid, cid_bytes in missing: + nursery.start_soon( + self._query_single, + cid, + cid_bytes, + results, + timeout, + ) + + logger.info( + f"Provider discovery complete: {len(results)}/{len(cids)} CIDs resolved" + ) + + return results + + async def _query_single( + self, + cid: CIDInput, + cid_bytes: bytes, + results: dict[bytes, list[PeerID]], + timeout: float, + ) -> None: + """ + Query DHT for providers of a single CID. + + This method is called concurrently for each CID. It uses a semaphore + to limit parallelism and handles errors gracefully. 
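+
+        Timeouts and DHT errors are counted in ``_stats`` and logged but
+        never re-raised, so one failed lookup cannot cancel its sibling
+        queries in the caller's nursery.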
+
+        Args:
+            cid: CID to query (for display)
+            cid_bytes: CID as bytes (for DHT query)
+            results: Shared results dictionary to update
+            timeout: Query timeout in seconds
+
+        """
+        async with self.query_semaphore:
+            self._stats["queries"] += 1
+
+            try:
+                with trio.fail_after(timeout):
+                    # Perform a network DHT provider lookup (not a local-store read)
+                    provider_infos = await self.dht.provider_store.find_providers(
+                        cid_bytes, self.max_providers
+                    )
+
+                # Extract peer IDs from PeerInfo objects
+                providers = [info.peer_id for info in provider_infos]
+
+                # Limit to max_providers
+                if len(providers) > self.max_providers:
+                    providers = providers[: self.max_providers]
+
+                if providers:
+                    # Update results
+                    results[cid_bytes] = providers
+
+                    # Update cache with remote results
+                    self.cache.put(cid_bytes, providers)
+
+                    # Update stats
+                    self._stats["providers_found"] += len(providers)
+
+                    logger.debug(
+                        f"Found {len(providers)} providers for "
+                        f"{format_cid_for_display(cid, max_len=12)}"
+                    )
+                else:
+                    logger.debug(
+                        f"No providers found for "
+                        f"{format_cid_for_display(cid, max_len=12)}"
+                    )
+
+            except trio.TooSlowError:
+                self._stats["errors"] += 1
+                logger.warning(
+                    f"DHT query timeout for {format_cid_for_display(cid, max_len=12)}"
+                )
+            except Exception as e:
+                self._stats["errors"] += 1
+                cid_disp = format_cid_for_display(cid, max_len=12)
+                logger.error(f"DHT query error for {cid_disp}: {e}")
+
+    async def find_providers_single(
+        self,
+        cid: CIDInput,
+        timeout: float = 5.0,
+        use_cache: bool = True,
+    ) -> list[PeerID]:
+        """
+        Find providers for a single CID (convenience method).
+
+        Args:
+            cid: CID to find providers for
+            timeout: Query timeout in seconds
+            use_cache: Whether to use cached results
+
+        Returns:
+            List of provider peer IDs
+
+        Example:
+            >>> providers = await manager.find_providers_single(cid)
+            >>> for peer_id in providers:
+            ...     print(f"Provider: {peer_id}")
+
+        """
+        results = await self.find_providers([cid], timeout, use_cache)
+        cid_bytes = cid_to_bytes(cid)
+        return results.get(cid_bytes, [])
+
+    def get_stats(self) -> dict[str, int]:
+        """
+        Get provider query statistics.
+
+        Returns:
+            Dictionary with statistics:
+            - queries: Total DHT queries made
+            - cache_hits: Number of cache hits
+            - cache_misses: Number of cache misses
+            - errors: Number of query errors
+            - providers_found: Total providers discovered
+            - size / max_size / expired: Current cache statistics, merged
+              in from ``ProviderCache.stats()``
+
+        Example:
+            >>> stats = manager.get_stats()
+            >>> print(f"Cache hit rate: {stats['cache_hits'] / stats['queries']:.1%}")
+
+        """
+        stats = self._stats.copy()
+        stats.update(self.cache.stats())
+        return stats
+
+    def clear_cache(self) -> None:
+        """Clear the provider cache."""
+        self.cache.clear()
+        logger.info("Provider cache cleared")
+
+    async def cleanup_expired_cache(self) -> int:
+        """
+        Remove expired entries from cache.
+
+        Returns:
+            Number of entries removed
+
+        """
+        removed = self.cache.cleanup_expired()
+        if removed > 0:
+            logger.debug(f"Removed {removed} expired cache entries")
+        return removed
diff --git a/libp2p/bitswap/wantlist.py b/libp2p/bitswap/wantlist.py
new file mode 100644
index 000000000..8c3f80519
--- /dev/null
+++ b/libp2p/bitswap/wantlist.py
@@ -0,0 +1,367 @@
+"""
+Typed dataclass wrappers for Bitswap wantlist entries and messages.
+
+Provides a clean, self-documenting Python API over the raw protobuf
+Message format. All types here are plain Python dataclasses; protobuf
+only enters at the boundary, via BitswapMessage.to_proto() and
+BitswapMessage.from_proto().
+ +Usage: + from libp2p.bitswap.wantlist import ( + WantType, BlockPresenceType, + WantlistEntry, Wantlist, + BlockPresence, BitswapMessage, + ) + + # Build a wantlist + wl = Wantlist() + wl.add(my_cid, want_type=WantType.Block, send_dont_have=True) + wl.add(other_cid, want_type=WantType.Have) + + # Build a full message + msg = BitswapMessage() + msg.add_want(my_cid, want_type=WantType.Block) + msg.add_block(root_cid, block_data) + msg.add_have(peer_cid) + msg.add_dont_have(missing_cid) +""" + +from __future__ import annotations + +from dataclasses import dataclass, field +from enum import Enum + +from .cid import CIDInput, cid_to_bytes +from .pb.bitswap_pb2 import Message as PBMessage + +# ── enums ───────────────────────────────────────────────────────────────────── + + +class WantType(Enum): + """ + Type of want request (Bitswap 1.2.0 wantType field). + + Block = 0 → "Send me the full block bytes." + Have = 1 → "Just tell me if you have it (HAVE/DONT_HAVE response)." + Cheaper than Block — useful for presence checks before + committing to a full block transfer. + """ + + Block = 0 + Have = 1 + + +class BlockPresenceType(Enum): + """ + Type of block presence response (Bitswap 1.2.0 BlockPresence.type field). + + Have = 0 → Peer has the block and can send it. + DontHave = 1 → Peer does not have the block. + """ + + Have = 0 + DontHave = 1 + + +# ── wantlist dataclasses ────────────────────────────────────────────────────── + + +@dataclass +class WantlistEntry: + """ + A single entry in a Bitswap wantlist. + + Prefer constructing via WantlistEntry.from_cid() which normalises + any CIDInput form to raw bytes. + + Attributes: + cid: CID of the requested block as raw bytes. + priority: Request urgency. Higher = more urgent. Default 1. + cancel: True to cancel a previously sent want for this CID. + want_type: WantType.Block (full data) or WantType.Have (presence). + send_dont_have: If True, ask the peer to send an explicit DontHave + response when it doesn't have the block. + + """ + + cid: bytes + priority: int = 1 + cancel: bool = False + want_type: WantType = WantType.Block + send_dont_have: bool = False + + @classmethod + def from_cid( + cls, + cid: CIDInput, + priority: int = 1, + cancel: bool = False, + want_type: WantType = WantType.Block, + send_dont_have: bool = False, + ) -> WantlistEntry: + """Create a WantlistEntry from any CIDInput form.""" + return cls( + cid=cid_to_bytes(cid), + priority=priority, + cancel=cancel, + want_type=want_type, + send_dont_have=send_dont_have, + ) + + +@dataclass +class Wantlist: + """ + A collection of wantlist entries. + + Attributes: + entries: List of WantlistEntry items. + full: True = this replaces the peer's entire wantlist. + False (default) = delta update, adds/cancels entries. 
+ + Example: + >>> wl = Wantlist() + >>> wl.add(cid1, want_type=WantType.Block, send_dont_have=True) + >>> wl.add(cid2, want_type=WantType.Have) + >>> wl.cancel(cid3) + >>> print(len(wl)) # 3 + + """ + + entries: list[WantlistEntry] = field(default_factory=list) + full: bool = False + + def add( + self, + cid: CIDInput, + priority: int = 1, + want_type: WantType = WantType.Block, + send_dont_have: bool = False, + ) -> None: + """Add a want entry for the given CID.""" + self.entries.append( + WantlistEntry.from_cid( + cid, + priority=priority, + want_type=want_type, + send_dont_have=send_dont_have, + ) + ) + + def cancel(self, cid: CIDInput) -> None: + """Add a cancel entry for a previously wanted CID.""" + self.entries.append(WantlistEntry.from_cid(cid, cancel=True)) + + def contains(self, cid: CIDInput) -> bool: + """Return True if any non-cancel entry exists for this CID.""" + cid_bytes = cid_to_bytes(cid) + return any(e.cid == cid_bytes and not e.cancel for e in self.entries) + + def __len__(self) -> int: + return len(self.entries) + + def __bool__(self) -> bool: + return bool(self.entries) + + +# ── message dataclasses ─────────────────────────────────────────────────────── + + +@dataclass +class BlockPresence: + """ + A HAVE or DONT_HAVE response for a specific CID (Bitswap 1.2.0). + + Use the class-method constructors for convenience: + BlockPresence.have(cid) + BlockPresence.dont_have(cid) + """ + + cid: bytes + type: BlockPresenceType + + @classmethod + def have(cls, cid: CIDInput) -> BlockPresence: + """Create a HAVE response.""" + return cls(cid=cid_to_bytes(cid), type=BlockPresenceType.Have) + + @classmethod + def dont_have(cls, cid: CIDInput) -> BlockPresence: + """Create a DONT_HAVE response.""" + return cls(cid=cid_to_bytes(cid), type=BlockPresenceType.DontHave) + + +@dataclass +class BitswapMessage: + """ + High-level typed representation of a Bitswap protocol message. + + Wraps the three main message components with typed fields and + convenience builder methods. Does not depend on protobuf directly — + convert to/from protobuf using to_proto() / from_proto(). + + Attributes: + wantlist: Optional wantlist (want/cancel entries). + blocks: List of (cid_bytes, block_data) block payloads. + block_presences: List of HAVE/DONT_HAVE presence responses. + pending_bytes: Bytes queued to send (v1.2.0 flow-control hint). + + Properties: + is_want True if the message contains want entries. + has_blocks True if the message contains block payloads. + has_presences True if the message contains HAVE/DONT_HAVE entries. 
+ + Example: + >>> msg = BitswapMessage() + >>> msg.add_want(cid1, want_type=WantType.Block, send_dont_have=True) + >>> msg.add_want(cid2, want_type=WantType.Have) + >>> msg.add_block(root_cid, data) + >>> msg.add_have(cid3) + >>> msg.add_dont_have(cid4) + >>> assert msg.is_want and msg.has_blocks and msg.has_presences + + """ + + wantlist: Wantlist | None = None + blocks: list[tuple[bytes, bytes]] = field(default_factory=list) # (cid, data) + block_presences: list[BlockPresence] = field(default_factory=list) + pending_bytes: int = 0 + + # ── read-only properties ────────────────────────────────────────────────── + + @property + def is_want(self) -> bool: + """True if this message contains wantlist entries.""" + return self.wantlist is not None and bool(self.wantlist) + + @property + def has_blocks(self) -> bool: + """True if this message carries block payloads.""" + return bool(self.blocks) + + @property + def has_presences(self) -> bool: + """True if this message carries HAVE/DONT_HAVE responses.""" + return bool(self.block_presences) + + # ── builder methods ─────────────────────────────────────────────────────── + + def add_want( + self, + cid: CIDInput, + priority: int = 1, + want_type: WantType = WantType.Block, + send_dont_have: bool = False, + ) -> None: + """Add a want entry. Creates the wantlist if not yet present.""" + if self.wantlist is None: + self.wantlist = Wantlist() + self.wantlist.add( + cid, + priority=priority, + want_type=want_type, + send_dont_have=send_dont_have, + ) + + def cancel_want(self, cid: CIDInput) -> None: + """Add a cancel entry for a previously wanted CID.""" + if self.wantlist is None: + self.wantlist = Wantlist() + self.wantlist.cancel(cid) + + def add_block(self, cid: CIDInput, data: bytes) -> None: + """Add a block payload to this message.""" + self.blocks.append((cid_to_bytes(cid), data)) + + def add_have(self, cid: CIDInput) -> None: + """Add a HAVE presence response.""" + self.block_presences.append(BlockPresence.have(cid)) + + def add_dont_have(self, cid: CIDInput) -> None: + """Add a DONT_HAVE presence response.""" + self.block_presences.append(BlockPresence.dont_have(cid)) + + # ── protobuf conversion ─────────────────────────────────────────────────── + + def to_proto(self) -> PBMessage: + """ + Convert to a raw protobuf Message object (pb.bitswap_pb2.Message). + + Returns: + A populated protobuf Message ready for serialisation. + + """ + proto = PBMessage() + + if self.wantlist is not None: + for entry in self.wantlist.entries: + pb_entry = proto.wantlist.entries.add() + pb_entry.block = entry.cid + pb_entry.priority = entry.priority + pb_entry.cancel = entry.cancel + pb_entry.wantType = entry.want_type.value # type: ignore[assignment] + pb_entry.sendDontHave = entry.send_dont_have + proto.wantlist.full = self.wantlist.full + + for cid_bytes, data in self.blocks: + from .cid import get_cid_prefix + + pb_block = proto.payload.add() + pb_block.prefix = get_cid_prefix(cid_bytes) + pb_block.data = data + + for presence in self.block_presences: + pb_presence = proto.blockPresences.add() + pb_presence.cid = presence.cid + pb_presence.type = presence.type.value # type: ignore[assignment] + + if self.pending_bytes: + proto.pendingBytes = self.pending_bytes + + return proto + + @classmethod + def from_proto(cls, proto: PBMessage) -> BitswapMessage: + """ + Build a BitswapMessage from a raw protobuf Message object. + + Args: + proto: A pb.bitswap_pb2.Message instance. + + Returns: + A populated BitswapMessage dataclass. 
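+
+        Example (a sketch; ``wire`` stands in for bytes read off a stream):
+            >>> proto = PBMessage.FromString(wire)
+            >>> msg = BitswapMessage.from_proto(proto)
+            >>> msg.is_want, msg.has_blocks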
+ + """ + from .cid import reconstruct_cid_from_prefix_and_data + + msg = cls() + + if proto.HasField("wantlist") and proto.wantlist.entries: + wl = Wantlist(full=proto.wantlist.full) + for e in proto.wantlist.entries: + wl.entries.append( + WantlistEntry( + cid=bytes(e.block), + priority=e.priority, + cancel=e.cancel, + want_type=WantType(e.wantType), + send_dont_have=e.sendDontHave, + ) + ) + msg.wantlist = wl + + for pb_block in proto.payload: + cid_bytes = reconstruct_cid_from_prefix_and_data( + bytes(pb_block.prefix), bytes(pb_block.data) + ) + msg.blocks.append((cid_bytes, bytes(pb_block.data))) + + for pb_presence in proto.blockPresences: + msg.block_presences.append( + BlockPresence( + cid=bytes(pb_presence.cid), + type=BlockPresenceType(pb_presence.type), + ) + ) + + msg.pending_bytes = proto.pendingBytes + return msg diff --git a/libp2p/kad_dht/__init__.py b/libp2p/kad_dht/__init__.py index 690d37bae..cf58e878f 100644 --- a/libp2p/kad_dht/__init__.py +++ b/libp2p/kad_dht/__init__.py @@ -7,6 +7,7 @@ from .kad_dht import ( KadDHT, + DHTMode, ) from .peer_routing import ( PeerRouting, @@ -23,6 +24,7 @@ __all__ = [ "KadDHT", + "DHTMode", "RoutingTable", "PeerRouting", "ValueStore", diff --git a/libp2p/kad_dht/kad_dht.py b/libp2p/kad_dht/kad_dht.py index 01aa23afc..bb11f1cb6 100644 --- a/libp2p/kad_dht/kad_dht.py +++ b/libp2p/kad_dht/kad_dht.py @@ -1058,7 +1058,7 @@ async def query_one(peer: ID) -> None: values = [rec.value for _p, rec in valid_records] best_idx = self.validator.select(key, values) logger.debug( - f"Selected best value at index {best_idx}using validator.select()" + f"Selected best value at index {best_idx} using validator.select()" ) best_peer, best_rec = valid_records[best_idx] @@ -1074,7 +1074,7 @@ async def query_one(peer: ID) -> None: if outdated_peers: logger.debug( - f"Propagating best value to {len(outdated_peers)}" + f"Propagating best value to {len(outdated_peers)} " "peers with outdated values" ) diff --git a/libp2p/kad_dht/pb/kademlia.proto b/libp2p/kad_dht/pb/kademlia.proto index 8d66cca5c..93fe526c3 100644 --- a/libp2p/kad_dht/pb/kademlia.proto +++ b/libp2p/kad_dht/pb/kademlia.proto @@ -4,6 +4,11 @@ message Record { bytes key = 1; bytes value = 2; string timeReceived = 5; + // author is the serialized public key of the record author (for unsigned records) + optional bytes author = 3; + // signature is the Ed25519/Secp256k1 signature over the record + // signing payload: "libp2p-record:" + key + value + optional bytes signature = 4; }; message Message { @@ -39,4 +44,3 @@ message Message { optional bytes senderRecord = 11; // Envelope(PeerRecord) encoded } -` diff --git a/libp2p/kad_dht/pb/kademlia_pb2.py b/libp2p/kad_dht/pb/kademlia_pb2.py index e41bb5292..19b4c2ca2 100644 --- a/libp2p/kad_dht/pb/kademlia_pb2.py +++ b/libp2p/kad_dht/pb/kademlia_pb2.py @@ -1,22 +1,12 @@ # -*- coding: utf-8 -*- # Generated by the protocol buffer compiler. DO NOT EDIT! 
-# NO CHECKED-IN PROTOBUF GENCODE # source: libp2p/kad_dht/pb/kademlia.proto -# Protobuf Python Version: 5.29.3 +# Protobuf Python Version: 4.25.1 """Generated protocol buffer code.""" from google.protobuf import descriptor as _descriptor from google.protobuf import descriptor_pool as _descriptor_pool -from google.protobuf import runtime_version as _runtime_version from google.protobuf import symbol_database as _symbol_database from google.protobuf.internal import builder as _builder -_runtime_version.ValidateProtobufRuntimeVersion( - _runtime_version.Domain.PUBLIC, - 5, - 29, - 3, - '', - 'libp2p/kad_dht/pb/kademlia.proto' -) # @@protoc_insertion_point(imports) _sym_db = _symbol_database.Default() @@ -24,21 +14,21 @@ -DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n libp2p/kad_dht/pb/kademlia.proto\":\n\x06Record\x12\x0b\n\x03key\x18\x01 \x01(\x0c\x12\r\n\x05value\x18\x02 \x01(\x0c\x12\x14\n\x0ctimeReceived\x18\x05 \x01(\t\"\xa2\x04\n\x07Message\x12\"\n\x04type\x18\x01 \x01(\x0e\x32\x14.Message.MessageType\x12\x17\n\x0f\x63lusterLevelRaw\x18\n \x01(\x05\x12\x0b\n\x03key\x18\x02 \x01(\x0c\x12\x17\n\x06record\x18\x03 \x01(\x0b\x32\x07.Record\x12\"\n\x0b\x63loserPeers\x18\x08 \x03(\x0b\x32\r.Message.Peer\x12$\n\rproviderPeers\x18\t \x03(\x0b\x32\r.Message.Peer\x12\x19\n\x0csenderRecord\x18\x0b \x01(\x0cH\x00\x88\x01\x01\x1az\n\x04Peer\x12\n\n\x02id\x18\x01 \x01(\x0c\x12\r\n\x05\x61\x64\x64rs\x18\x02 \x03(\x0c\x12+\n\nconnection\x18\x03 \x01(\x0e\x32\x17.Message.ConnectionType\x12\x19\n\x0csignedRecord\x18\x04 \x01(\x0cH\x00\x88\x01\x01\x42\x0f\n\r_signedRecord\"i\n\x0bMessageType\x12\r\n\tPUT_VALUE\x10\x00\x12\r\n\tGET_VALUE\x10\x01\x12\x10\n\x0c\x41\x44\x44_PROVIDER\x10\x02\x12\x11\n\rGET_PROVIDERS\x10\x03\x12\r\n\tFIND_NODE\x10\x04\x12\x08\n\x04PING\x10\x05\"W\n\x0e\x43onnectionType\x12\x11\n\rNOT_CONNECTED\x10\x00\x12\r\n\tCONNECTED\x10\x01\x12\x0f\n\x0b\x43\x41N_CONNECT\x10\x02\x12\x12\n\x0e\x43\x41NNOT_CONNECT\x10\x03\x42\x0f\n\r_senderRecordb\x06proto3') +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n libp2p/kad_dht/pb/kademlia.proto\"\x80\x01\n\x06Record\x12\x0b\n\x03key\x18\x01 \x01(\x0c\x12\r\n\x05value\x18\x02 \x01(\x0c\x12\x14\n\x0ctimeReceived\x18\x05 \x01(\t\x12\x13\n\x06\x61uthor\x18\x03 \x01(\x0cH\x00\x88\x01\x01\x12\x16\n\tsignature\x18\x04 \x01(\x0cH\x01\x88\x01\x01\x42\t\n\x07_authorB\x0c\n\n_signature\"\xa2\x04\n\x07Message\x12\"\n\x04type\x18\x01 \x01(\x0e\x32\x14.Message.MessageType\x12\x17\n\x0f\x63lusterLevelRaw\x18\n \x01(\x05\x12\x0b\n\x03key\x18\x02 \x01(\x0c\x12\x17\n\x06record\x18\x03 \x01(\x0b\x32\x07.Record\x12\"\n\x0b\x63loserPeers\x18\x08 \x03(\x0b\x32\r.Message.Peer\x12$\n\rproviderPeers\x18\t \x03(\x0b\x32\r.Message.Peer\x12\x19\n\x0csenderRecord\x18\x0b \x01(\x0cH\x00\x88\x01\x01\x1az\n\x04Peer\x12\n\n\x02id\x18\x01 \x01(\x0c\x12\r\n\x05\x61\x64\x64rs\x18\x02 \x03(\x0c\x12+\n\nconnection\x18\x03 \x01(\x0e\x32\x17.Message.ConnectionType\x12\x19\n\x0csignedRecord\x18\x04 \x01(\x0cH\x00\x88\x01\x01\x42\x0f\n\r_signedRecord\"i\n\x0bMessageType\x12\r\n\tPUT_VALUE\x10\x00\x12\r\n\tGET_VALUE\x10\x01\x12\x10\n\x0c\x41\x44\x44_PROVIDER\x10\x02\x12\x11\n\rGET_PROVIDERS\x10\x03\x12\r\n\tFIND_NODE\x10\x04\x12\x08\n\x04PING\x10\x05\"W\n\x0e\x43onnectionType\x12\x11\n\rNOT_CONNECTED\x10\x00\x12\r\n\tCONNECTED\x10\x01\x12\x0f\n\x0b\x43\x41N_CONNECT\x10\x02\x12\x12\n\x0e\x43\x41NNOT_CONNECT\x10\x03\x42\x0f\n\r_senderRecordb\x06proto3') _globals = globals() _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) 
_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'libp2p.kad_dht.pb.kademlia_pb2', _globals) -if not _descriptor._USE_C_DESCRIPTORS: - DESCRIPTOR._loaded_options = None - _globals['_RECORD']._serialized_start=36 - _globals['_RECORD']._serialized_end=94 - _globals['_MESSAGE']._serialized_start=97 - _globals['_MESSAGE']._serialized_end=643 - _globals['_MESSAGE_PEER']._serialized_start=308 - _globals['_MESSAGE_PEER']._serialized_end=430 - _globals['_MESSAGE_MESSAGETYPE']._serialized_start=432 - _globals['_MESSAGE_MESSAGETYPE']._serialized_end=537 - _globals['_MESSAGE_CONNECTIONTYPE']._serialized_start=539 - _globals['_MESSAGE_CONNECTIONTYPE']._serialized_end=626 +if _descriptor._USE_C_DESCRIPTORS == False: + DESCRIPTOR._options = None + _globals['_RECORD']._serialized_start=37 + _globals['_RECORD']._serialized_end=165 + _globals['_MESSAGE']._serialized_start=168 + _globals['_MESSAGE']._serialized_end=714 + _globals['_MESSAGE_PEER']._serialized_start=379 + _globals['_MESSAGE_PEER']._serialized_end=501 + _globals['_MESSAGE_MESSAGETYPE']._serialized_start=503 + _globals['_MESSAGE_MESSAGETYPE']._serialized_end=608 + _globals['_MESSAGE_CONNECTIONTYPE']._serialized_start=610 + _globals['_MESSAGE_CONNECTIONTYPE']._serialized_end=697 # @@protoc_insertion_point(module_scope) diff --git a/libp2p/kad_dht/pb/kademlia_pb2.pyi b/libp2p/kad_dht/pb/kademlia_pb2.pyi index 641ae66ae..ae32c2361 100644 --- a/libp2p/kad_dht/pb/kademlia_pb2.pyi +++ b/libp2p/kad_dht/pb/kademlia_pb2.pyi @@ -1,144 +1,74 @@ -""" -@generated by mypy-protobuf. Do not edit manually! -isort:skip_file -""" - -import builtins -import collections.abc -import google.protobuf.descriptor -import google.protobuf.internal.containers -import google.protobuf.internal.enum_type_wrapper -import google.protobuf.message -import sys -import typing - -if sys.version_info >= (3, 10): - import typing as typing_extensions -else: - import typing_extensions - -DESCRIPTOR: google.protobuf.descriptor.FileDescriptor - -@typing.final -class Record(google.protobuf.message.Message): - DESCRIPTOR: google.protobuf.descriptor.Descriptor - - KEY_FIELD_NUMBER: builtins.int - VALUE_FIELD_NUMBER: builtins.int - TIMERECEIVED_FIELD_NUMBER: builtins.int - key: builtins.bytes - value: builtins.bytes - timeReceived: builtins.str - def __init__( - self, - *, - key: builtins.bytes = ..., - value: builtins.bytes = ..., - timeReceived: builtins.str = ..., - ) -> None: ... - def ClearField(self, field_name: typing.Literal["key", b"key", "timeReceived", b"timeReceived", "value", b"value"]) -> None: ... - -global___Record = Record - -@typing.final -class Message(google.protobuf.message.Message): - DESCRIPTOR: google.protobuf.descriptor.Descriptor - - class _MessageType: - ValueType = typing.NewType("ValueType", builtins.int) - V: typing_extensions.TypeAlias = ValueType - - class _MessageTypeEnumTypeWrapper(google.protobuf.internal.enum_type_wrapper._EnumTypeWrapper[Message._MessageType.ValueType], builtins.type): - DESCRIPTOR: google.protobuf.descriptor.EnumDescriptor - PUT_VALUE: Message._MessageType.ValueType # 0 - GET_VALUE: Message._MessageType.ValueType # 1 - ADD_PROVIDER: Message._MessageType.ValueType # 2 - GET_PROVIDERS: Message._MessageType.ValueType # 3 - FIND_NODE: Message._MessageType.ValueType # 4 - PING: Message._MessageType.ValueType # 5 - - class MessageType(_MessageType, metaclass=_MessageTypeEnumTypeWrapper): ... 
- PUT_VALUE: Message.MessageType.ValueType # 0 - GET_VALUE: Message.MessageType.ValueType # 1 - ADD_PROVIDER: Message.MessageType.ValueType # 2 - GET_PROVIDERS: Message.MessageType.ValueType # 3 - FIND_NODE: Message.MessageType.ValueType # 4 - PING: Message.MessageType.ValueType # 5 - - class _ConnectionType: - ValueType = typing.NewType("ValueType", builtins.int) - V: typing_extensions.TypeAlias = ValueType - - class _ConnectionTypeEnumTypeWrapper(google.protobuf.internal.enum_type_wrapper._EnumTypeWrapper[Message._ConnectionType.ValueType], builtins.type): - DESCRIPTOR: google.protobuf.descriptor.EnumDescriptor - NOT_CONNECTED: Message._ConnectionType.ValueType # 0 - CONNECTED: Message._ConnectionType.ValueType # 1 - CAN_CONNECT: Message._ConnectionType.ValueType # 2 - CANNOT_CONNECT: Message._ConnectionType.ValueType # 3 - - class ConnectionType(_ConnectionType, metaclass=_ConnectionTypeEnumTypeWrapper): ... - NOT_CONNECTED: Message.ConnectionType.ValueType # 0 - CONNECTED: Message.ConnectionType.ValueType # 1 - CAN_CONNECT: Message.ConnectionType.ValueType # 2 - CANNOT_CONNECT: Message.ConnectionType.ValueType # 3 - - @typing.final - class Peer(google.protobuf.message.Message): - DESCRIPTOR: google.protobuf.descriptor.Descriptor - - ID_FIELD_NUMBER: builtins.int - ADDRS_FIELD_NUMBER: builtins.int - CONNECTION_FIELD_NUMBER: builtins.int - SIGNEDRECORD_FIELD_NUMBER: builtins.int - id: builtins.bytes - connection: global___Message.ConnectionType.ValueType - signedRecord: builtins.bytes - """Envelope(PeerRecord) encoded""" - @property - def addrs(self) -> google.protobuf.internal.containers.RepeatedScalarFieldContainer[builtins.bytes]: ... - def __init__( - self, - *, - id: builtins.bytes = ..., - addrs: collections.abc.Iterable[builtins.bytes] | None = ..., - connection: global___Message.ConnectionType.ValueType = ..., - signedRecord: builtins.bytes | None = ..., - ) -> None: ... - def HasField(self, field_name: typing.Literal["_signedRecord", b"_signedRecord", "signedRecord", b"signedRecord"]) -> builtins.bool: ... - def ClearField(self, field_name: typing.Literal["_signedRecord", b"_signedRecord", "addrs", b"addrs", "connection", b"connection", "id", b"id", "signedRecord", b"signedRecord"]) -> None: ... - def WhichOneof(self, oneof_group: typing.Literal["_signedRecord", b"_signedRecord"]) -> typing.Literal["signedRecord"] | None: ... - - TYPE_FIELD_NUMBER: builtins.int - CLUSTERLEVELRAW_FIELD_NUMBER: builtins.int - KEY_FIELD_NUMBER: builtins.int - RECORD_FIELD_NUMBER: builtins.int - CLOSERPEERS_FIELD_NUMBER: builtins.int - PROVIDERPEERS_FIELD_NUMBER: builtins.int - SENDERRECORD_FIELD_NUMBER: builtins.int - type: global___Message.MessageType.ValueType - clusterLevelRaw: builtins.int - key: builtins.bytes - senderRecord: builtins.bytes - """Envelope(PeerRecord) encoded""" - @property - def record(self) -> global___Record: ... - @property - def closerPeers(self) -> google.protobuf.internal.containers.RepeatedCompositeFieldContainer[global___Message.Peer]: ... - @property - def providerPeers(self) -> google.protobuf.internal.containers.RepeatedCompositeFieldContainer[global___Message.Peer]: ... 
- def __init__( - self, - *, - type: global___Message.MessageType.ValueType = ..., - clusterLevelRaw: builtins.int = ..., - key: builtins.bytes = ..., - record: global___Record | None = ..., - closerPeers: collections.abc.Iterable[global___Message.Peer] | None = ..., - providerPeers: collections.abc.Iterable[global___Message.Peer] | None = ..., - senderRecord: builtins.bytes | None = ..., - ) -> None: ... - def HasField(self, field_name: typing.Literal["_senderRecord", b"_senderRecord", "record", b"record", "senderRecord", b"senderRecord"]) -> builtins.bool: ... - def ClearField(self, field_name: typing.Literal["_senderRecord", b"_senderRecord", "closerPeers", b"closerPeers", "clusterLevelRaw", b"clusterLevelRaw", "key", b"key", "providerPeers", b"providerPeers", "record", b"record", "senderRecord", b"senderRecord", "type", b"type"]) -> None: ... - def WhichOneof(self, oneof_group: typing.Literal["_senderRecord", b"_senderRecord"]) -> typing.Literal["senderRecord"] | None: ... - -global___Message = Message +from google.protobuf.internal import containers as _containers +from google.protobuf.internal import enum_type_wrapper as _enum_type_wrapper +from google.protobuf import descriptor as _descriptor +from google.protobuf import message as _message +from typing import Any, ClassVar as _ClassVar, Iterable as _Iterable, Mapping as _Mapping, Optional as _Optional, Union as _Union + +DESCRIPTOR: _descriptor.FileDescriptor + +class Record(_message.Message): + __slots__ = ("key", "value", "timeReceived", "author", "signature") + KEY_FIELD_NUMBER: _ClassVar[int] + VALUE_FIELD_NUMBER: _ClassVar[int] + TIMERECEIVED_FIELD_NUMBER: _ClassVar[int] + AUTHOR_FIELD_NUMBER: _ClassVar[int] + SIGNATURE_FIELD_NUMBER: _ClassVar[int] + key: bytes + value: bytes + timeReceived: str + author: bytes + signature: bytes + def __init__(self, key: _Optional[bytes] = ..., value: _Optional[bytes] = ..., timeReceived: _Optional[str] = ..., author: _Optional[bytes] = ..., signature: _Optional[bytes] = ...) -> None: ... 
+ +class Message(_message.Message): + __slots__ = ("type", "clusterLevelRaw", "key", "record", "closerPeers", "providerPeers", "senderRecord") + class MessageType(int, metaclass=_enum_type_wrapper.EnumTypeWrapper): + __slots__ = () + PUT_VALUE: _ClassVar[Message.MessageType] + GET_VALUE: _ClassVar[Message.MessageType] + ADD_PROVIDER: _ClassVar[Message.MessageType] + GET_PROVIDERS: _ClassVar[Message.MessageType] + FIND_NODE: _ClassVar[Message.MessageType] + PING: _ClassVar[Message.MessageType] + PUT_VALUE: Message.MessageType + GET_VALUE: Message.MessageType + ADD_PROVIDER: Message.MessageType + GET_PROVIDERS: Message.MessageType + FIND_NODE: Message.MessageType + PING: Message.MessageType + class ConnectionType(int, metaclass=_enum_type_wrapper.EnumTypeWrapper): + __slots__ = () + NOT_CONNECTED: _ClassVar[Message.ConnectionType] + CONNECTED: _ClassVar[Message.ConnectionType] + CAN_CONNECT: _ClassVar[Message.ConnectionType] + CANNOT_CONNECT: _ClassVar[Message.ConnectionType] + NOT_CONNECTED: Message.ConnectionType + CONNECTED: Message.ConnectionType + CAN_CONNECT: Message.ConnectionType + CANNOT_CONNECT: Message.ConnectionType + class Peer(_message.Message): + __slots__ = ("id", "addrs", "connection", "signedRecord") + ID_FIELD_NUMBER: _ClassVar[int] + ADDRS_FIELD_NUMBER: _ClassVar[int] + CONNECTION_FIELD_NUMBER: _ClassVar[int] + SIGNEDRECORD_FIELD_NUMBER: _ClassVar[int] + id: bytes + addrs: _containers.RepeatedScalarFieldContainer[bytes] + connection: Message.ConnectionType + signedRecord: bytes + def __init__(self, id: _Optional[bytes] = ..., addrs: _Optional[_Iterable[bytes]] = ..., connection: _Optional[_Union[Message.ConnectionType, str]] = ..., signedRecord: _Optional[bytes] = ...) -> None: ... + TYPE_FIELD_NUMBER: _ClassVar[int] + CLUSTERLEVELRAW_FIELD_NUMBER: _ClassVar[int] + KEY_FIELD_NUMBER: _ClassVar[int] + RECORD_FIELD_NUMBER: _ClassVar[int] + CLOSERPEERS_FIELD_NUMBER: _ClassVar[int] + PROVIDERPEERS_FIELD_NUMBER: _ClassVar[int] + SENDERRECORD_FIELD_NUMBER: _ClassVar[int] + type: Message.MessageType + clusterLevelRaw: int + key: bytes + record: Record + closerPeers: _containers.RepeatedCompositeFieldContainer[Message.Peer] + providerPeers: _containers.RepeatedCompositeFieldContainer[Message.Peer] + senderRecord: bytes + def __init__(self, type: _Optional[_Union[Message.MessageType, str]] = ..., clusterLevelRaw: _Optional[int] = ..., key: _Optional[bytes] = ..., record: _Optional[_Union[Record, _Mapping[str, Any]]] = ..., closerPeers: _Optional[_Iterable[_Union[Message.Peer, _Mapping[str, Any]]]] = ..., providerPeers: _Optional[_Iterable[_Union[Message.Peer, _Mapping[str, Any]]]] = ..., senderRecord: _Optional[bytes] = ...) -> None: ... diff --git a/libp2p/kad_dht/value_store.py b/libp2p/kad_dht/value_store.py index 90cd77ae4..459e7487e 100644 --- a/libp2p/kad_dht/value_store.py +++ b/libp2p/kad_dht/value_store.py @@ -20,6 +20,7 @@ ID, ) from libp2p.peer.peerstore import env_to_send_in_RPC +from libp2p.records.record import make_signed_put_record from .common import ( DEFAULT_TTL, @@ -65,14 +66,17 @@ def put(self, key: bytes, value: bytes, validity: float = 0.0) -> None: None """ - from libp2p.records.record import make_put_record - if validity == 0.0: validity = time.time() + DEFAULT_TTL logger.debug( "Storing value for key %s... 
with validity %s", key.hex(), validity ) - record = make_put_record(key, value) + + # Create a signed record using the host's private key + private_key = self.host.get_private_key() + record = make_signed_put_record(key, value, private_key) + + # Set timeReceived when storing locally record.timeReceived = str(time.time()) self.store[key] = (record, validity) @@ -123,11 +127,20 @@ async def _store_at_peer(self, peer_id: ID, key: bytes, value: bytes) -> bool: envelope_bytes, _ = env_to_send_in_RPC(self.host) message.senderRecord = envelope_bytes - # Set message fields + # Build the outbound record from the locally-stored signed record when + # available (normal put() path), otherwise sign the record now so the + # outbound message always carries signature and author fields. + local_entry = self.store.get(key) + if local_entry is not None: + signed_record, _ = local_entry + message.record.CopyFrom(signed_record) + else: + private_key = self.host.get_private_key() + signed_record = make_signed_put_record(key, value, private_key) + message.record.CopyFrom(signed_record) message.key = key - message.record.key = key - message.record.value = value - message.record.timeReceived = str(time.time()) + # Note: timeReceived will be set by the receiving peer when storing + message.record.ClearField("timeReceived") # Serialize and send the protobuf message with length prefix proto_bytes = message.SerializeToString() @@ -320,6 +333,10 @@ async def _get_from_peer( logger.debug( f"Received value for key {key.hex()} from peer {peer_id}" ) + + # Update timeReceived to current time (when we received it locally) + response.record.timeReceived = str(time.time()) + return response.record if return_record else response.record.value # Handle case where value is not found but peer infos are returned diff --git a/libp2p/peer/envelope.py b/libp2p/peer/envelope.py index 1fcbb1c75..9a7f6466f 100644 --- a/libp2p/peer/envelope.py +++ b/libp2p/peer/envelope.py @@ -1,7 +1,7 @@ from typing import Any, cast import multiaddr -from multicodec import Code, get_codec, get_prefix +from multicodec import Code, get_prefix from multicodec.code_table import LIBP2P_PEER_RECORD from libp2p.crypto.ed25519 import Ed25519PublicKey @@ -12,6 +12,7 @@ import libp2p.peer.pb.envelope_pb2 as pb import libp2p.peer.pb.peer_record_pb2 as record_pb from libp2p.peer.peer_record import ( + PEER_RECORD_ENVELOPE_PAYLOAD_TYPE, PeerRecord, peer_record_from_protobuf, unmarshal_record, @@ -19,9 +20,10 @@ from libp2p.utils.varint import encode_uvarint ENVELOPE_DOMAIN = "libp2p-peer-record" -# Multicodec-based codec for peer records +# Multicodec Code object (for internal use / comparison only) PEER_RECORD_CODE: Code = LIBP2P_PEER_RECORD -PEER_RECORD_CODEC: bytes = get_prefix(str(PEER_RECORD_CODE)) +# Wire-format payload type bytes — matches go-libp2p: []byte{0x03, 0x01} +PEER_RECORD_CODEC: bytes = PEER_RECORD_ENVELOPE_PAYLOAD_TYPE class Envelope: @@ -40,7 +42,9 @@ class Envelope: """ public_key: PublicKey - payload_type_code: Code + # payload_type is stored as raw bytes (wire format), matching go-libp2p. + # For PeerRecord envelopes this is bytes([0x03, 0x01]), NOT varint-encoded. 
+ _payload_type: bytes raw_payload: bytes signature: bytes @@ -56,28 +60,42 @@ def __init__( ): self.public_key = public_key - # Normalise payload_type to a Code instance + # Normalise payload_type to raw bytes if isinstance(payload_type, bytes): - try: - codec_name = get_codec(payload_type) - self.payload_type_code = Code.from_string(codec_name) - except Exception as e: - raise ValueError(f"Invalid codec: {e}") + # Already raw bytes — use as-is (this is the go-libp2p wire format) + self._payload_type = payload_type elif isinstance(payload_type, str): - try: - self.payload_type_code = Code.from_string(payload_type) - except Exception as e: - raise ValueError(f"Invalid codec: {e}") + # Treat as codec name, encode to raw prefix bytes + self._payload_type = get_prefix(payload_type) + elif isinstance(payload_type, Code): + if payload_type == PEER_RECORD_CODE: + # Use the go-libp2p compatible raw bytes, not varint + self._payload_type = PEER_RECORD_ENVELOPE_PAYLOAD_TYPE + else: + self._payload_type = get_prefix(str(payload_type)) else: - self.payload_type_code = payload_type + self._payload_type = bytes(payload_type) self.raw_payload = raw_payload self.signature = signature @property def payload_type(self) -> bytes: - """Return the multicodec-prefixed payload type.""" - return get_prefix(str(self.payload_type_code)) + """Return the raw payload type bytes (wire format).""" + return self._payload_type + + @property + def payload_type_code(self) -> Code: + """Return the multicodec Code for this payload type (best-effort).""" + return PEER_RECORD_CODE + + @payload_type_code.setter + def payload_type_code(self, value: Code) -> None: + """Update the raw payload_type bytes from a Code value.""" + if value == PEER_RECORD_CODE: + self._payload_type = PEER_RECORD_ENVELOPE_PAYLOAD_TYPE + else: + self._payload_type = get_prefix(str(value)) def marshal_envelope(self) -> bytes: """ @@ -125,10 +143,9 @@ def record(self) -> PeerRecord: return self._cached_record try: - if self.payload_type_code != PEER_RECORD_CODE: + if self._payload_type != PEER_RECORD_ENVELOPE_PAYLOAD_TYPE: raise ValueError( - f"Unsupported payload type in envelope: " - f"{self.payload_type_code.name}" + f"Unsupported payload type in envelope: {self._payload_type.hex()}" ) msg = record_pb.PeerRecord() msg.ParseFromString(self.raw_payload) @@ -154,7 +171,7 @@ def equal(self, other: Any) -> bool: if isinstance(other, Envelope): return ( self.public_key.__eq__(other.public_key) - and self.payload_type_code == other.payload_type_code + and self._payload_type == other._payload_type and self.signature == other.signature and self.raw_payload == other.raw_payload ) @@ -217,7 +234,7 @@ def seal_record(record: PeerRecord, private_key: PrivateKey) -> Envelope: return Envelope( public_key=private_key.get_public_key(), - payload_type=PEER_RECORD_CODE, + payload_type=PEER_RECORD_ENVELOPE_PAYLOAD_TYPE, raw_payload=payload, signature=signature, ) diff --git a/libp2p/peer/peer_record.py b/libp2p/peer/peer_record.py index 0fff196f0..26676f983 100644 --- a/libp2p/peer/peer_record.py +++ b/libp2p/peer/peer_record.py @@ -4,7 +4,7 @@ from typing import Any from multiaddr import Multiaddr -from multicodec import Code, get_prefix +from multicodec import Code from multicodec.code_table import LIBP2P_PEER_RECORD from libp2p.abc import IPeerRecord @@ -14,7 +14,10 @@ PEER_RECORD_ENVELOPE_DOMAIN = "libp2p-peer-record" PEER_RECORD_ENVELOPE_CODE: Code = LIBP2P_PEER_RECORD -PEER_RECORD_ENVELOPE_PAYLOAD_TYPE = get_prefix(str(PEER_RECORD_ENVELOPE_CODE)) +# go-libp2p uses raw 
bytes [0x03, 0x01] for the peer-record payload type +# (NOT varint-encoded). See: https://github.com/libp2p/go-libp2p/blob/master/core/peer/record.go +# PeerRecordEnvelopePayloadType = []byte{0x03, 0x01} +PEER_RECORD_ENVELOPE_PAYLOAD_TYPE = bytes([0x03, 0x01]) _last_timestamp_lock = threading.Lock() _last_timestamp: int = 0 diff --git a/libp2p/records/record.py b/libp2p/records/record.py index 8644e3c09..87dd96b1c 100644 --- a/libp2p/records/record.py +++ b/libp2p/records/record.py @@ -1,4 +1,6 @@ +from libp2p.crypto.keys import PrivateKey from libp2p.kad_dht.pb import kademlia_pb2 as record_pb2 +from libp2p.records.utils import sign_record def make_put_record(key: bytes, value: bytes) -> record_pb2.Record: @@ -17,3 +19,35 @@ def make_put_record(key: bytes, value: bytes) -> record_pb2.Record: record.key = key record.value = value return record + + +def make_signed_put_record( + key: bytes, value: bytes, private_key: PrivateKey +) -> record_pb2.Record: + """ + Create a signed Record object with the specified key, value, and signature. + + The record is signed using the libp2p record signing convention: + signature = sign("libp2p-record:" + key + value) + + This matches go-libp2p's record signing behavior for DHT PUT_VALUE. + + Args: + key (bytes): The key for the record. + value (bytes): The value to associate with the key in the record. + private_key (PrivateKey): The private key to sign the record with. + + Returns: + record_pb2.Record: A signed Record object. + + """ + record = record_pb2.Record() + record.key = key + record.value = value + + # Sign the record + signature, author_public_key = sign_record(private_key, key, value) + record.signature = signature + record.author = author_public_key + + return record diff --git a/libp2p/records/utils.py b/libp2p/records/utils.py index 82161beb3..2dcc6620f 100644 --- a/libp2p/records/utils.py +++ b/libp2p/records/utils.py @@ -1,7 +1,94 @@ +from libp2p.crypto.ed25519 import Ed25519PublicKey +from libp2p.crypto.keys import PrivateKey, PublicKey +from libp2p.crypto.pb import crypto_pb2 +from libp2p.crypto.rsa import RSAPublicKey +from libp2p.crypto.secp256k1 import Secp256k1PublicKey + + class InvalidRecordType(Exception): pass +def _unmarshal_public_key(data: bytes) -> PublicKey: + """ + Deserialize a ``crypto_pb2.PublicKey`` protobuf into a concrete + ``PublicKey`` instance. + + Kept private to this module to avoid the circular import that arises + when importing from ``libp2p.records.pubkey`` (which itself imports + from this module). + """ + proto_key = crypto_pb2.PublicKey.FromString(data) + key_type = proto_key.key_type + key_data = proto_key.data + + if key_type == crypto_pb2.KeyType.RSA: + return RSAPublicKey.from_bytes(key_data) + elif key_type == crypto_pb2.KeyType.Ed25519: + return Ed25519PublicKey.from_bytes(key_data) + elif key_type == crypto_pb2.KeyType.Secp256k1: + return Secp256k1PublicKey.from_bytes(key_data) + else: + raise ValueError(f"Unsupported key type: {key_type}") + + +def sign_record( + private_key: PrivateKey, key: bytes, value: bytes +) -> tuple[bytes, bytes]: + """ + Sign a DHT record using the given private key. + + The signature is computed over "libp2p-record:" + key + value. 
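+
+    The second element of the returned tuple is the signer's public key
+    serialised as a ``crypto_pb2.PublicKey`` protobuf, so a verifier can
+    reconstruct the key without knowing its type in advance.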
+
+    Args:
+        private_key: The private key to sign with
+        key: The record key
+        value: The record value
+
+    Returns:
+        tuple[bytes, bytes]: A tuple of (signature, author_public_key_bytes)
+
+    """
+    signing_payload = b"libp2p-record:" + key + value
+    signature = private_key.sign(signing_payload)
+    public_key = private_key.get_public_key()
+    # Serialize as a protobuf-wrapped PublicKey so that verify_record (and
+    # remote peers) can reconstruct the key without knowing its type in advance.
+    author_bytes = public_key.serialize()
+    return signature, author_bytes
+
+
+def verify_record(
+    signature: bytes, author_public_key: bytes, key: bytes, value: bytes
+) -> bool:
+    """
+    Verify a signed DHT record.
+
+    Supports all key types that libp2p serialises in a protobuf PublicKey
+    envelope (Ed25519, RSA, Secp256k1). The author field is treated as a
+    serialised ``crypto_pb2.PublicKey`` message and dispatched through
+    ``_unmarshal_public_key`` so that non-Ed25519 peers are not silently
+    rejected.
+
+    Args:
+        signature: The record signature
+        author_public_key: The serialized public key of the author
+            (``crypto_pb2.PublicKey`` protobuf bytes)
+        key: The record key
+        value: The record value
+
+    Returns:
+        bool: True if the signature is valid, False otherwise
+
+    """
+    try:
+        public_key = _unmarshal_public_key(author_public_key)
+        signing_payload = b"libp2p-record:" + key + value
+        return public_key.verify(signing_payload, signature)
+    except Exception:
+        return False
+
+
 def split_key(key: str) -> tuple[str, str]:
     """
     Split a record key into its type and the rest. The key must start with
diff --git a/newsfragments/1321.feature.rst b/newsfragments/1321.feature.rst
new file mode 100644
index 000000000..af0c9d04a
--- /dev/null
+++ b/newsfragments/1321.feature.rst
@@ -0,0 +1,10 @@
+Comprehensive Bitswap overhaul for Kubo compatibility and performance:
+
+- **Batch block fetching** — send multiple CIDs in a single wantlist message.
+- **Kubo-compatible DAG-PB encoding** — produce identical CIDs to Kubo's ``ipfs add``.
+- **FilesystemBlockStore** — persistent storage surviving process restarts.
+- **BlockService** — local-first lookup with automatic block caching and announcement.
+- **Streaming support** — ``chunk_stream`` and ``MerkleDag.add_stream`` for efficient DAG building.
+- **Bitswap 1.2.0 wantlist API** — ``WantType``, ``BlockPresence``, ``WantlistEntry``, ``BitswapMessage``.
+- **DHT record signing/verification** — Kubo-compatible provider and value record signing.
+- **ProviderQueryManager** — automatic DHT-based peer discovery in ``BitswapClient.get_block()`` with LRU caching.
diff --git a/tests/core/bitswap/test_block_service.py b/tests/core/bitswap/test_block_service.py
new file mode 100644
index 000000000..f4754dd7c
--- /dev/null
+++ b/tests/core/bitswap/test_block_service.py
@@ -0,0 +1,236 @@
+"""
+Test BlockService — transparent local→network fallback with auto-caching.
+ +Run with: + python test_block_service.py +""" + +from unittest.mock import AsyncMock, MagicMock + +import trio + +from libp2p.bitswap.block_service import BlockService +from libp2p.bitswap.block_store import MemoryBlockStore +from libp2p.bitswap.cid import CODEC_RAW, compute_cid_v1 +from libp2p.bitswap.client import BitswapClient + + +def make_block(content: bytes): + cid = compute_cid_v1(content, codec=CODEC_RAW) + return cid, content + + +def ok(label): + print(f" OK {label}") + + +# ── helpers ─────────────────────────────────────────────────────────────────── + + +def make_service(network_blocks: dict | None = None): + """ + Build a BlockService with a real MemoryBlockStore and a mock BitswapClient. + network_blocks: cid_bytes -> data that the mock 'network' can return. + """ + store = MemoryBlockStore() + mock_bitswap = MagicMock(spec=BitswapClient) + mock_bitswap.block_store = store + network_blocks = network_blocks or {} + + async def fake_get_block(cid, peer_id=None, timeout=30.0): + return network_blocks.get(bytes(cid)) + + async def fake_add_block(cid, data): + pass # just accept it + + async def fake_get_blocks_batch(cids, peer_id=None, timeout=30.0, batch_size=32): + return { + bytes(c): network_blocks[bytes(c)] + for c in cids + if bytes(c) in network_blocks + } + + mock_bitswap.get_block = AsyncMock(side_effect=fake_get_block) + mock_bitswap.add_block = AsyncMock(side_effect=fake_add_block) + mock_bitswap.get_blocks_batch = AsyncMock(side_effect=fake_get_blocks_batch) + + service = BlockService(store, mock_bitswap) + return service, store, mock_bitswap + + +# ── tests ───────────────────────────────────────────────────────────────────── + + +async def test_local_hit_no_network(): + print("\n[1] Local hit — network is never called") + cid, data = make_block(b"already stored locally") + service, store, mock_bitswap = make_service() + + # Pre-populate local store + await store.put_block(cid, data) + + result = await service.get_block(cid) + assert result == data + ok("get_block returns local data") + + mock_bitswap.get_block.assert_not_called() + ok("network (bitswap.get_block) was NOT called") + + +async def test_local_miss_goes_to_network(): + print("\n[2] Local miss — fetches from network") + cid, data = make_block(b"only on the network") + service, store, mock_bitswap = make_service(network_blocks={bytes(cid): data}) + + result = await service.get_block(cid) + assert result == data + ok("get_block returns network data") + + mock_bitswap.get_block.assert_called_once() + ok("network (bitswap.get_block) was called exactly once") + + +async def test_auto_cache_after_network_fetch(): + print("\n[3] Auto-cache — network-fetched block stored locally") + cid, data = make_block(b"fetch and cache me") + service, store, mock_bitswap = make_service(network_blocks={bytes(cid): data}) + + # First call: local miss → network fetch → auto-cache + result1 = await service.get_block(cid) + assert result1 == data + + # Verify it's now in the local store + cached = await store.get_block(cid) + assert cached == data + ok("block is in local store after first network fetch") + + # Second call: must be a local hit, no second network call + result2 = await service.get_block(cid) + assert result2 == data + assert mock_bitswap.get_block.call_count == 1 # still only 1 network call + ok("second get_block is a local hit (network called only once total)") + + +async def test_put_block_stores_and_announces(): + print("\n[4] put_block — stores locally AND calls bitswap.add_block") + cid, data = 
make_block(b"new block to store") + service, store, mock_bitswap = make_service() + + await service.put_block(cid, data) + + # Must be in local store + cached = await store.get_block(cid) + assert cached == data + ok("block is in local store after put_block") + + # Must have called bitswap.add_block (announces to waiting peers) + mock_bitswap.add_block.assert_called_once() + ok("bitswap.add_block was called (peers notified)") + + +async def test_get_blocks_batch_local_hits_skip_network(): + print("\n[5] get_blocks_batch — local hits skip network") + blocks = [make_block(f"block {i}".encode()) for i in range(5)] + service, store, mock_bitswap = make_service() + + # Store all 5 locally + for cid, data in blocks: + await store.put_block(cid, data) + + cids: list[bytes] = [cid for cid, _ in blocks] + results = await service.get_blocks_batch(cids) + + assert len(results) == 5 + ok("all 5 blocks returned from local store") + mock_bitswap.get_blocks_batch.assert_not_called() + ok("network batch fetch was NOT called") + + +async def test_get_blocks_batch_partial_local(): + print("\n[6] get_blocks_batch — partial local, rest from network") + local_blocks = [make_block(f"local {i}".encode()) for i in range(3)] + net_blocks = [make_block(f"remote {i}".encode()) for i in range(2)] + network_dict = {bytes(cid): data for cid, data in net_blocks} + + service, store, mock_bitswap = make_service(network_blocks=network_dict) + + # Store only local blocks + for cid, data in local_blocks: + await store.put_block(cid, data) + + all_cids: list[bytes] = [cid for cid, _ in local_blocks + net_blocks] + results = await service.get_blocks_batch(all_cids) + + assert len(results) == 5 + ok("all 5 blocks returned (3 local + 2 network)") + mock_bitswap.get_blocks_batch.assert_called_once() + ok("network batch fetch called exactly once (only for 2 missing blocks)") + + # Network blocks must now be cached locally + for cid, data in net_blocks: + cached = await store.get_block(cid) + assert cached == data + ok("network-fetched blocks are now cached locally") + + +async def test_missing_block_returns_none(): + print("\n[7] get_block returns None when block not found anywhere") + cid, _ = make_block(b"this block does not exist") + service, store, mock_bitswap = make_service(network_blocks={}) # empty network + + result = await service.get_block(cid) + assert result is None + ok("get_block returns None for unknown block") + + +async def test_merkledag_uses_block_service(): + print("\n[8] MerkleDag.add_bytes routes through BlockService") + from libp2p.bitswap.dag import MerkleDag + + service, store, mock_bitswap = make_service() + dag = MerkleDag(mock_bitswap, block_service=service) + + data = b"hello block service" * 100 + root_cid = await dag.add_bytes(data) + + # All blocks must be in the local store via BlockService + cached = await store.get_block(root_cid) + assert cached is not None + ok("root block is in local store via BlockService") + + # bitswap.add_block was called (for peer announcement) + assert mock_bitswap.add_block.called + ok("bitswap.add_block was called for peer announcement") + + # MerkleDag without BlockService still works (no regression) + service2, store2, mock_bitswap2 = make_service() + dag2 = MerkleDag(mock_bitswap2) # no block_service + root_cid2 = await dag2.add_bytes(data) + assert root_cid2 is not None + ok("MerkleDag without BlockService still works (no regression)") + + +# ── main ────────────────────────────────────────────────────────────────────── + + +async def main(): + print("=" * 60) + 
print("BlockService — Test Suite") + print("=" * 60) + + await test_local_hit_no_network() + await test_local_miss_goes_to_network() + await test_auto_cache_after_network_fetch() + await test_put_block_stores_and_announces() + await test_get_blocks_batch_local_hits_skip_network() + await test_get_blocks_batch_partial_local() + await test_missing_block_returns_none() + await test_merkledag_uses_block_service() + + print("\n" + "=" * 60) + print("All tests passed!") + print("=" * 60) + + +if __name__ == "__main__": + trio.run(main) diff --git a/tests/core/bitswap/test_dag.py b/tests/core/bitswap/test_dag.py index e94fb0f1a..d1144f707 100644 --- a/tests/core/bitswap/test_dag.py +++ b/tests/core/bitswap/test_dag.py @@ -52,6 +52,8 @@ class TestAddBytes: @pytest.mark.trio async def test_add_small_bytes(self): """Test adding small data (single block).""" + from libp2p.bitswap.dag_pb import create_leaf_node + # Setup mock_client = MagicMock(spec=BitswapClient) mock_client.block_store = MemoryBlockStore() @@ -66,13 +68,18 @@ async def test_add_small_bytes(self): # Verify assert root_cid is not None assert len(root_cid) > 0 - assert verify_cid(root_cid, data) - # Should be single block (RAW codec) + # Small data is stored as a dag-pb leaf node (not raw codec) + leaf_block = create_leaf_node(data) + expected_cid = compute_cid_v1(leaf_block, codec=CODEC_DAG_PB) + assert root_cid == expected_cid + assert verify_cid(root_cid, leaf_block) + + # Should be single block (DAG-PB codec) mock_client.add_block.assert_called_once() call_args = mock_client.add_block.call_args assert call_args[0][0] == root_cid # CID - assert call_args[0][1] == data # Data + assert call_args[0][1] == leaf_block # dag-pb wrapped data @pytest.mark.trio async def test_add_large_bytes(self): @@ -161,9 +168,15 @@ async def test_add_small_file(self): assert root_cid is not None mock_client.add_block.assert_called_once() - # Should be single RAW block + # Small file is stored as a dag-pb leaf node + from libp2p.bitswap.dag_pb import create_leaf_node + call_args = mock_client.add_block.call_args - assert verify_cid(call_args[0][0], data) + stored_cid = call_args[0][0] + stored_block = call_args[0][1] + leaf_block = create_leaf_node(data) + assert stored_block == leaf_block + assert verify_cid(stored_cid, leaf_block) finally: Path(temp_path).unlink() @@ -285,16 +298,22 @@ async def test_fetch_small_file(self, cid_input_kind: str): @pytest.mark.trio async def test_fetch_chunked_file(self): """Test fetching multi-chunk file.""" - # Create chunks + from libp2p.bitswap.dag_pb import create_leaf_node + + # Create dag-pb leaf blocks (matching what add_bytes/add_file produces) chunk1 = b"chunk1" * 1000 chunk2 = b"chunk2" * 1000 chunk3 = b"chunk3" * 1000 - cid1 = compute_cid_v1(chunk1, codec=CODEC_RAW) - cid2 = compute_cid_v1(chunk2, codec=CODEC_RAW) - cid3 = compute_cid_v1(chunk3, codec=CODEC_RAW) + leaf1 = create_leaf_node(chunk1) + leaf2 = create_leaf_node(chunk2) + leaf3 = create_leaf_node(chunk3) - # Create DAG-PB root node + cid1 = compute_cid_v1(leaf1, codec=CODEC_DAG_PB) + cid2 = compute_cid_v1(leaf2, codec=CODEC_DAG_PB) + cid3 = compute_cid_v1(leaf3, codec=CODEC_DAG_PB) + + # Create DAG-PB root node linking to the leaves chunks_data = [ (cid1, len(chunk1)), (cid2, len(chunk2)), @@ -308,11 +327,11 @@ def get_block_side_effect(cid, peer_id, timeout): if cid == root_cid: return root_data elif cid == cid1: - return chunk1 + return leaf1 elif cid == cid2: - return chunk2 + return leaf2 elif cid == cid3: - return chunk3 + return leaf3 raise 
ValueError(f"Unknown CID: {cid.hex()}") mock_client = MagicMock(spec=BitswapClient) @@ -324,23 +343,30 @@ def get_block_side_effect(cid, peer_id, timeout): # Fetch fetched_data, filename = await dag.fetch_file(root_cid, timeout=30.0) - # Verify + # Verify reconstructed data expected_data = chunk1 + chunk2 + chunk3 assert fetched_data == expected_data assert filename is None # File node without directory wrapper - # Should have fetched root + 3 chunks + # root fetch (1) + tree-level batch fallback (3) = 4 + # Leaves are already fetched during tree traversal, + # no separate leaf fetch needed assert mock_client.get_block.call_count == 4 @pytest.mark.trio async def test_fetch_file_with_progress(self): """Test fetching with progress callback.""" - # Create chunked file + from libp2p.bitswap.dag_pb import create_leaf_node + + # Create dag-pb leaf blocks (matching what add_bytes/add_file produces) chunk1 = b"x" * 1000 chunk2 = b"y" * 1000 - cid1 = compute_cid_v1(chunk1, codec=CODEC_RAW) - cid2 = compute_cid_v1(chunk2, codec=CODEC_RAW) + leaf1 = create_leaf_node(chunk1) + leaf2 = create_leaf_node(chunk2) + + cid1 = compute_cid_v1(leaf1, codec=CODEC_DAG_PB) + cid2 = compute_cid_v1(leaf2, codec=CODEC_DAG_PB) root_data = create_file_node([(cid1, len(chunk1)), (cid2, len(chunk2))]) root_cid = compute_cid_v1(root_data, codec=CODEC_DAG_PB) @@ -350,9 +376,9 @@ def get_block_side_effect(cid, peer_id, timeout): if cid == root_cid: return root_data elif cid == cid1: - return chunk1 + return leaf1 elif cid == cid2: - return chunk2 + return leaf2 mock_client = MagicMock(spec=BitswapClient) mock_client.block_store = MemoryBlockStore() @@ -370,8 +396,8 @@ def progress_callback(current, total, status): # Verify progress assert len(progress_calls) > 0 - # Should report progress for each chunk - assert any("fetching chunk" in call[2] for call in progress_calls) + # Implementation emits "downloading" per leaf and "completed" at end + assert any(call[2] in ("downloading", "completed") for call in progress_calls) # Last call should be completion assert progress_calls[-1][2] == "completed" diff --git a/tests/core/bitswap/test_filesystem_blockstore.py b/tests/core/bitswap/test_filesystem_blockstore.py new file mode 100644 index 000000000..8596f26c1 --- /dev/null +++ b/tests/core/bitswap/test_filesystem_blockstore.py @@ -0,0 +1,207 @@ +""" +Manual test for FilesystemBlockStore. + +Tests: + 1. Basic put/get/has/delete round-trip + 2. Persistence: blocks survive store re-creation (simulates process restart) + 3. get_all_cids: scans the directory tree and returns all stored CIDs + 4. 
Drop-in replacement: swapping MemoryBlockStore → FilesystemBlockStore + +Run with: + python test_filesystem_blockstore.py + or + pytest test_filesystem_blockstore.py +""" + +from pathlib import Path +import shutil +import tempfile + +import pytest +import trio + +from libp2p.bitswap.block_store import FilesystemBlockStore, MemoryBlockStore +from libp2p.bitswap.cid import CODEC_RAW, cid_to_text, compute_cid_v1 + +# ── helpers ────────────────────────────────────────────────────────────────── + + +def make_block(content: bytes) -> tuple[bytes, bytes]: + """Return (cid_bytes, data) for a raw block.""" + cid = compute_cid_v1(content, codec=CODEC_RAW) + return cid, content + + +def pass_fail(label: str, ok: bool) -> None: + icon = "✅" if ok else "❌" + print(f" {icon} {label}") + if not ok: + raise AssertionError(f"FAILED: {label}") + + +# ── pytest fixtures ─────────────────────────────────────────────────────────── + + +@pytest.fixture +def store_path(tmp_path): + """Provide a fresh temporary directory path for each test.""" + return str(tmp_path) + + +# ── tests ───────────────────────────────────────────────────────────────────── + + +@pytest.mark.trio +async def test_basic_round_trip(store_path: str) -> None: + print("\n[1] Basic put / get / has / delete") + store = FilesystemBlockStore(store_path) + + cid, data = make_block(b"hello filesystem blockstore") + + # has_block → False before put + pass_fail("has_block returns False before put", not await store.has_block(cid)) + + # put_block + await store.put_block(cid, data) + pass_fail("block file exists on disk after put", store._cid_to_path(cid).exists()) + + # get_block + fetched = await store.get_block(cid) + pass_fail("get_block returns correct data", fetched == data) + + # has_block → True after put + pass_fail("has_block returns True after put", await store.has_block(cid)) + + # delete_block + await store.delete_block(cid) + pass_fail("block file gone after delete", not store._cid_to_path(cid).exists()) + pass_fail("get_block returns None after delete", await store.get_block(cid) is None) + + +@pytest.mark.trio +async def test_persistence(store_path: str) -> None: + print("\n[2] Persistence across store re-creation (simulates process restart)") + + # Write with first instance + store1 = FilesystemBlockStore(store_path) + cid1, data1 = make_block(b"block that should survive restart") + cid2, data2 = make_block(b"another persistent block") + await store1.put_block(cid1, data1) + await store1.put_block(cid2, data2) + pass_fail("2 blocks written by store1", store1.size() == 2) + + # Create a brand-new store object pointing to the same path + # (simulates a process restart) + store2 = FilesystemBlockStore(store_path) + pass_fail( + "store2 sees block1 written by store1", await store2.get_block(cid1) == data1 + ) + pass_fail( + "store2 sees block2 written by store1", await store2.get_block(cid2) == data2 + ) + pass_fail("store2.size() == 2", store2.size() == 2) + + print(f" Block directory: {store2.base_path()}") + print(f" CID1: {cid_to_text(cid1)}") + print(f" CID2: {cid_to_text(cid2)}") + + +@pytest.mark.trio +async def test_get_all_cids(store_path: str) -> None: + print("\n[3] get_all_cids scans directory tree") + store = FilesystemBlockStore(store_path) + + blocks = [make_block(f"block {i}".encode()) for i in range(5)] + for cid, data in blocks: + await store.put_block(cid, data) + + all_cids = store.get_all_cids() + pass_fail(f"get_all_cids returns {len(blocks)} CIDs", len(all_cids) == len(blocks)) + + stored_set = {bytes(c) for c in 
all_cids} + for cid, _ in blocks: + pass_fail( + f"CID {cid_to_text(cid)[:20]}... is in get_all_cids", + bytes(cid) in stored_set, + ) + + +@pytest.mark.trio +async def test_get_missing_returns_none(store_path: str) -> None: + print("\n[4] get_block returns None for missing CID") + store = FilesystemBlockStore(store_path) + cid, _ = make_block(b"this block was never stored") + result = await store.get_block(cid) + pass_fail("get_block returns None for unknown CID", result is None) + + +@pytest.mark.trio +async def test_drop_in_for_memory_store(store_path: str) -> None: + print("\n[5] Drop-in replacement for MemoryBlockStore") + + async def use_store(store) -> bytes: + """Same code works for both store types.""" + cid, data = make_block(b"drop-in replacement test") + await store.put_block(cid, data) + return await store.get_block(cid) + + mem_result = await use_store(MemoryBlockStore()) + fs_result = await use_store(FilesystemBlockStore(store_path)) + + pass_fail( + "MemoryBlockStore and FilesystemBlockStore return same data", + mem_result == fs_result, + ) + + +@pytest.mark.trio +async def test_directory_structure(store_path: str) -> None: + print("\n[6] 2-char prefix directory structure") + store = FilesystemBlockStore(store_path) + cid, data = make_block(b"check directory layout") + await store.put_block(cid, data) + + cid_str = cid_to_text(cid) + expected_dir = Path(store_path) / cid_str[:2] + expected_file = expected_dir / cid_str[2:] + + pass_fail(f"2-char prefix dir '{cid_str[:2]}' exists", expected_dir.is_dir()) + pass_fail( + f"block file '{cid_str[2:8]}...' exists inside prefix dir", + expected_file.exists(), + ) + pass_fail("file contents match original data", expected_file.read_bytes() == data) + + print(f" Path: {expected_file}") + + +# ── main ────────────────────────────────────────────────────────────────────── + + +async def main() -> None: + print("=" * 60) + print("FilesystemBlockStore — Manual Test Suite") + print("=" * 60) + + # Each test gets its own temp directory so they don't interfere + dirs = [tempfile.mkdtemp(prefix="fs_blockstore_test_") for _ in range(6)] + + try: + await test_basic_round_trip(dirs[0]) + await test_persistence(dirs[1]) + await test_get_all_cids(dirs[2]) + await test_get_missing_returns_none(dirs[3]) + await test_drop_in_for_memory_store(dirs[4]) + await test_directory_structure(dirs[5]) + + print("\n" + "=" * 60) + print("✅ All tests passed!") + print("=" * 60) + + finally: + for d in dirs: + shutil.rmtree(d, ignore_errors=True) + + +if __name__ == "__main__": + trio.run(main) diff --git a/tests/core/bitswap/test_io_stream.py b/tests/core/bitswap/test_io_stream.py new file mode 100644 index 000000000..bd1ecdecb --- /dev/null +++ b/tests/core/bitswap/test_io_stream.py @@ -0,0 +1,297 @@ +""" +Test io.IOBase input support — chunk_stream() and MerkleDag.add_stream(). + +Run with: + python test_io_stream.py +""" + +import gzip +import io +import os +import tempfile + +import trio + +from libp2p.bitswap.block_store import MemoryBlockStore +from libp2p.bitswap.chunker import DEFAULT_CHUNK_SIZE, chunk_stream +from libp2p.bitswap.cid import cid_to_text +from libp2p.bitswap.dag_pb import decode_dag_pb, is_file_node + + +def ok(label): + print(f" OK {label}") + + +# ── 1. 
chunk_stream basics ──────────────────────────────────────────────────── + + +def test_chunk_stream_bytesio(): + print("\n[1] chunk_stream — BytesIO") + data = b"x" * (DEFAULT_CHUNK_SIZE * 3 + 100) # 3 full + 1 partial chunk + chunks = list(chunk_stream(io.BytesIO(data), DEFAULT_CHUNK_SIZE)) + assert len(chunks) == 4 + assert b"".join(chunks) == data + assert len(chunks[0]) == DEFAULT_CHUNK_SIZE + assert len(chunks[-1]) == 100 + ok(f"4 chunks, sizes: {[len(c) for c in chunks]}") + + +def test_chunk_stream_empty(): + print("\n[2] chunk_stream — empty stream yields nothing") + chunks = list(chunk_stream(io.BytesIO(b""))) + assert chunks == [] + ok("empty stream yields no chunks") + + +def test_chunk_stream_file_handle(): + print("\n[3] chunk_stream — real file handle") + data = b"file handle test " * 5000 + with tempfile.NamedTemporaryFile(delete=False) as f: + f.write(data) + tmp = f.name + try: + with open(tmp, "rb") as fh: + chunks = list(chunk_stream(fh)) + assert b"".join(chunks) == data + ok(f"file handle: {len(chunks)} chunks, {len(data)} bytes total") + finally: + os.unlink(tmp) + + +def test_chunk_stream_gzip(): + print("\n[4] chunk_stream — gzip stream (decompress on-the-fly)") + original = b"compressed data " * 10000 + buf = io.BytesIO() + with gzip.GzipFile(fileobj=buf, mode="wb") as gz: + gz.write(original) + buf.seek(0) + + with gzip.GzipFile(fileobj=buf, mode="rb") as gz: + chunks = list(chunk_stream(gz)) + + assert b"".join(chunks) == original + ok(f"gzip stream: {len(chunks)} chunks, {len(original)} bytes decompressed") + + +def test_chunk_stream_matches_chunk_bytes(): + print("\n[5] chunk_stream produces same chunks as chunk_bytes") + from libp2p.bitswap.chunker import chunk_bytes + + data = os.urandom(DEFAULT_CHUNK_SIZE * 5 + 777) + stream_chunks = list(chunk_stream(io.BytesIO(data))) + bytes_chunks = chunk_bytes(data) + assert stream_chunks == bytes_chunks + ok(f"chunk_stream == chunk_bytes for {len(data)} bytes of random data") + + +# ── 2. 
MerkleDag.add_stream ─────────────────────────────────────────────────── + + +async def test_add_stream_bytesio(): + print("\n[6] add_stream — BytesIO produces same CID as add_bytes") + from unittest.mock import AsyncMock, MagicMock + + from libp2p.bitswap.client import BitswapClient + from libp2p.bitswap.dag import MerkleDag + + store = MemoryBlockStore() + mock = MagicMock(spec=BitswapClient) + mock.block_store = store + stored: dict[bytes, bytes] = {} + + async def add_block(cid, data): + stored[bytes(cid)] = data + + mock.add_block = AsyncMock(side_effect=add_block) + + dag = MerkleDag(mock) + data = b"same content " * 5000 + + cid_bytes = await dag.add_bytes(data) + stored.clear() + cid_stream = await dag.add_stream(io.BytesIO(data)) + + assert bytes(cid_bytes) == bytes(cid_stream), ( + f"CIDs differ:\n add_bytes: {cid_to_text(cid_bytes)}\n" + f" add_stream: {cid_to_text(cid_stream)}" + ) + ok(f"add_stream CID == add_bytes CID: {cid_to_text(cid_stream)[:30]}...") + + +async def test_add_stream_empty(): + print("\n[7] add_stream — empty stream stores single empty leaf") + from unittest.mock import AsyncMock, MagicMock + + from libp2p.bitswap.client import BitswapClient + from libp2p.bitswap.dag import MerkleDag + + store = MemoryBlockStore() + mock = MagicMock(spec=BitswapClient) + mock.block_store = store + stored: dict[bytes, bytes] = {} + + async def add_block(cid, data): + stored[bytes(cid)] = data + + mock.add_block = AsyncMock(side_effect=add_block) + + dag = MerkleDag(mock) + await dag.add_stream(io.BytesIO(b"")) + + assert len(stored) == 1 + block = list(stored.values())[0] + assert is_file_node(block) + _, unixfs = decode_dag_pb(block) + assert unixfs is not None + assert unixfs.filesize == 0 + ok("empty stream → 1 empty dag-pb leaf block stored") + + +async def test_add_stream_single_chunk(): + print("\n[8] add_stream — single chunk returns leaf CID directly (no root node)") + from unittest.mock import AsyncMock, MagicMock + + from libp2p.bitswap.client import BitswapClient + from libp2p.bitswap.dag import MerkleDag + + store = MemoryBlockStore() + mock = MagicMock(spec=BitswapClient) + mock.block_store = store + stored: dict[bytes, bytes] = {} + + async def add_block(cid, data): + stored[bytes(cid)] = data + + mock.add_block = AsyncMock(side_effect=add_block) + + dag = MerkleDag(mock) + data = b"small enough to be one chunk" + root_cid = await dag.add_stream(io.BytesIO(data)) + + assert len(stored) == 1, f"expected 1 block, got {len(stored)}" + block = stored[bytes(root_cid)] + _, unixfs = decode_dag_pb(block) + assert unixfs is not None + assert unixfs.data == data + ok("single chunk: leaf CID returned directly, inline data correct") + + +async def test_add_stream_gzip(): + print("\n[9] add_stream — gzip stream decompresses and adds correctly") + from unittest.mock import AsyncMock, MagicMock + + from libp2p.bitswap.client import BitswapClient + from libp2p.bitswap.dag import MerkleDag + + original = b"gzip content " * 20000 # ~260 KB — 2 chunks after decompress + + buf = io.BytesIO() + with gzip.GzipFile(fileobj=buf, mode="wb") as gz: + gz.write(original) + compressed_size = buf.tell() + buf.seek(0) + + store = MemoryBlockStore() + mock = MagicMock(spec=BitswapClient) + mock.block_store = store + stored: dict[bytes, bytes] = {} + + async def add_block(cid, data): + stored[bytes(cid)] = data + + mock.add_block = AsyncMock(side_effect=add_block) + + dag = MerkleDag(mock) + + with gzip.GzipFile(fileobj=buf, mode="rb") as gz: + root_cid = await dag.add_stream(gz) + + # 
Reassemble all leaf data + root_block = stored[bytes(root_cid)] + links, _ = decode_dag_pb(root_block) + reassembled = b"" + for link in links: + leaf = stored[bytes(link.cid)] + _, leaf_unixfs = decode_dag_pb(leaf) + assert leaf_unixfs is not None + reassembled += leaf_unixfs.data + + assert reassembled == original + ok( + f"gzip stream: {compressed_size} compressed → {len(original)} bytes added " + f"in {len(links)} chunks" + ) + + +async def test_add_stream_vs_add_file_same_cid(): + print("\n[10] add_stream(open(f)) produces same CID as add_file(path)") + from unittest.mock import AsyncMock, MagicMock + + from libp2p.bitswap.client import BitswapClient + from libp2p.bitswap.dag import MerkleDag + + data = b"compare stream vs file " * 8000 # ~176 KB, 3 chunks + + with tempfile.NamedTemporaryFile(delete=False) as f: + f.write(data) + tmp = f.name + + try: + + def make_dag(): + store = MemoryBlockStore() + mock = MagicMock(spec=BitswapClient) + mock.block_store = store + stored = {} + + async def add_block(cid, d): + stored[bytes(cid)] = d + + mock.add_block = AsyncMock(side_effect=add_block) + return MerkleDag(mock) + + dag1 = make_dag() + cid_file = await dag1.add_file(tmp, wrap_with_directory=False) + + dag2 = make_dag() + with open(tmp, "rb") as fh: + cid_stream = await dag2.add_stream(fh) + + assert bytes(cid_file) == bytes(cid_stream), ( + f"CIDs differ:\n add_file: {cid_to_text(cid_file)}\n" + f" add_stream: {cid_to_text(cid_stream)}" + ) + ok(f"add_file == add_stream CID: {cid_to_text(cid_file)[:30]}...") + finally: + os.unlink(tmp) + + +# ── main ────────────────────────────────────────────────────────────────────── + + +async def main(): + print("=" * 60) + print("io.IOBase Input Support — Test Suite") + print("=" * 60) + + # sync tests + test_chunk_stream_bytesio() + test_chunk_stream_empty() + test_chunk_stream_file_handle() + test_chunk_stream_gzip() + test_chunk_stream_matches_chunk_bytes() + + # async tests + await test_add_stream_bytesio() + await test_add_stream_empty() + await test_add_stream_single_chunk() + await test_add_stream_gzip() + await test_add_stream_vs_add_file_same_cid() + + print("\n" + "=" * 60) + print("All tests passed!") + print("=" * 60) + + +if __name__ == "__main__": + trio.run(main) diff --git a/tests/core/bitswap/test_provider_query.py b/tests/core/bitswap/test_provider_query.py new file mode 100644 index 000000000..8617cc6eb --- /dev/null +++ b/tests/core/bitswap/test_provider_query.py @@ -0,0 +1,450 @@ +""" +Tests for ProviderQueryManager and its integration with BitswapClient. 
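+
+A typical wiring sketch (only calls that these tests themselves exercise;
+not a published example):
+
+    pqm = ProviderQueryManager(dht)  # dht exposes provider_store.find_providers
+    client = BitswapClient(host, block_store=store, provider_query_manager=pqm)
+    providers = await pqm.find_providers_single(cid, timeout=5.0)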
+ +Covers: +- ProviderCacheEntry – TTL, expiry +- ProviderCache – LRU eviction, TTL, cleanup, stats +- ProviderQueryManager – single/batch queries, cache hit/miss, + max_providers cap, error handling, stats +- BitswapClient integration – provider_query_manager wired at construction, + get_block() uses DHT discovery +""" + +from __future__ import annotations + +import time +from unittest.mock import Mock + +import pytest +import trio + +from libp2p.bitswap.block_store import MemoryBlockStore +from libp2p.bitswap.cid import cid_to_bytes, compute_cid_v0, parse_cid +from libp2p.bitswap.client import BitswapClient +from libp2p.bitswap.provider_query import ( + ProviderCache, + ProviderCacheEntry, + ProviderQueryManager, +) +from libp2p.peer.id import ID as PeerID +from libp2p.peer.peerinfo import PeerInfo + +# ── helpers ─────────────────────────────────────────────────────────────────── + +PEER_A = PeerID.from_base58("QmNnooDu7bfjPFoTZYxMNLWUQJyrVwtbZg5gBMjTezGAJN") +PEER_B = PeerID.from_base58("QmaCpDMGvV2BGHeYERUEnRQAwe3N8SzbUtfsmvsqQLuvuJ") +PEER_C = PeerID.from_base58("QmSoLV4Bbm51jM9C4gDYZQ9Cy3U6aXMJDAbzgu2fzaDs64") + +SAMPLE_PEERS = [PEER_A, PEER_B, PEER_C] + +CID_1 = parse_cid(compute_cid_v0(b"block-one")) +CID_2 = parse_cid(compute_cid_v0(b"block-two")) +CID_3 = parse_cid(compute_cid_v0(b"block-three")) + +SAMPLE_CIDS = [CID_1, CID_2, CID_3] + + +def _mock_dht(return_peers: list[PeerID] | None = None) -> Mock: + """ + Return a mock DHT whose provider_store.find_providers returns *return_peers*. + + find_providers is the async network lookup path; get_providers is the + local-store read that ProviderQueryManager no longer calls directly. + """ + dht = Mock() + dht.provider_store = Mock() + peer_infos = [PeerInfo(p, []) for p in (return_peers or [])] + + async def _async_find_providers(key: bytes, count: int = 20) -> list[PeerInfo]: + return peer_infos[:count] + + dht.provider_store.find_providers = Mock(side_effect=_async_find_providers) + return dht + + +# ═════════════════════════════════════════════════════════════════════════════ +# ProviderCacheEntry +# ═════════════════════════════════════════════════════════════════════════════ + + +class TestProviderCacheEntry: + def test_fresh_entry_not_expired(self) -> None: + entry = ProviderCacheEntry(providers=SAMPLE_PEERS, ttl=300) + assert not entry.is_expired() + assert entry.age() < 1.0 + + def test_entry_with_past_timestamp_is_expired(self) -> None: + entry = ProviderCacheEntry( + providers=SAMPLE_PEERS, + timestamp=time.time() - 10, + ttl=5, + ) + assert entry.is_expired() + + def test_default_ttl_applied(self) -> None: + entry = ProviderCacheEntry(providers=[PEER_A]) + assert entry.ttl == 300 + + +# ═════════════════════════════════════════════════════════════════════════════ +# ProviderCache +# ═════════════════════════════════════════════════════════════════════════════ + + +class TestProviderCache: + def test_put_and_get(self) -> None: + cache = ProviderCache(max_size=10, default_ttl=60) + cache.put(b"k1", SAMPLE_PEERS) + assert cache.get(b"k1") == SAMPLE_PEERS + + def test_miss_returns_none(self) -> None: + cache = ProviderCache() + assert cache.get(b"no-such-key") is None + + def test_expired_entry_returns_none(self) -> None: + cache = ProviderCache(max_size=10, default_ttl=300) + cache.put(b"k1", SAMPLE_PEERS, ttl=0.01) + time.sleep(0.05) + assert cache.get(b"k1") is None + + def test_lru_evicts_oldest(self) -> None: + cache = ProviderCache(max_size=3, default_ttl=300) + cache.put(b"a", [PEER_A]) + cache.put(b"b", [PEER_B]) + 
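+        # cache now holds a and b; the next put brings it to max_size=3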
cache.put(b"c", [PEER_C]) + cache.get(b"a") # mark 'a' recently used + cache.put(b"d", [PEER_A]) # 'b' should be evicted + assert cache.get(b"b") is None + assert cache.get(b"a") is not None + assert cache.get(b"d") is not None + + def test_clear_empties_cache(self) -> None: + cache = ProviderCache(max_size=10, default_ttl=300) + cache.put(b"k1", [PEER_A]) + cache.put(b"k2", [PEER_B]) + cache.clear() + assert cache.size() == 0 + + def test_cleanup_expired_removes_stale(self) -> None: + cache = ProviderCache(max_size=10, default_ttl=300) + cache.put(b"stale", [PEER_A], ttl=0.01) + cache.put(b"fresh", [PEER_B], ttl=300) + time.sleep(0.05) + removed = cache.cleanup_expired() + assert removed == 1 + assert cache.size() == 1 + + def test_stats_keys_present(self) -> None: + cache = ProviderCache(max_size=5, default_ttl=300) + cache.put(b"k", [PEER_A]) + stats = cache.stats() + assert {"size", "max_size", "expired"} <= stats.keys() + assert stats["size"] == 1 + assert stats["max_size"] == 5 + + +# ═════════════════════════════════════════════════════════════════════════════ +# ProviderQueryManager +# ═════════════════════════════════════════════════════════════════════════════ + + +class TestProviderQueryManager: + @pytest.mark.trio + async def test_cache_miss_queries_dht(self) -> None: + dht = _mock_dht(return_peers=[PEER_A]) + mgr = ProviderQueryManager(dht) + + providers = await mgr.find_providers_single(CID_1, timeout=5.0) + + assert providers == [PEER_A] + stats = mgr.get_stats() + assert stats["queries"] == 1 + assert stats["cache_misses"] == 1 + assert stats["cache_hits"] == 0 + assert stats["providers_found"] == 1 + # Verify the async network path was used, not the local store read + dht.provider_store.find_providers.assert_called_once() + + @pytest.mark.trio + async def test_cache_hit_skips_dht(self) -> None: + dht = _mock_dht() + mgr = ProviderQueryManager(dht) + mgr.cache.put(cid_to_bytes(CID_1), [PEER_B]) + + providers = await mgr.find_providers_single(CID_1) + + assert providers == [PEER_B] + dht.provider_store.find_providers.assert_not_called() + assert mgr.get_stats()["cache_hits"] == 1 + + @pytest.mark.trio + async def test_second_call_uses_cache(self) -> None: + dht = _mock_dht(return_peers=[PEER_A]) + mgr = ProviderQueryManager(dht) + + await mgr.find_providers_single(CID_1) # miss + await mgr.find_providers_single(CID_1) # hit + + stats = mgr.get_stats() + assert stats["queries"] == 1 # no extra DHT call + assert stats["cache_hits"] == 1 + + @pytest.mark.trio + async def test_max_providers_cap(self) -> None: + dht = _mock_dht(return_peers=SAMPLE_PEERS) + mgr = ProviderQueryManager(dht, max_providers=1) + + providers = await mgr.find_providers_single(CID_1) + assert len(providers) == 1 + + @pytest.mark.trio + async def test_no_providers_returns_empty(self) -> None: + dht = _mock_dht(return_peers=[]) + mgr = ProviderQueryManager(dht) + providers = await mgr.find_providers_single(CID_1) + assert providers == [] + + @pytest.mark.trio + async def test_dht_error_increments_errors(self) -> None: + dht = _mock_dht() + + async def _raise(*_args: object, **_kwargs: object) -> None: + raise RuntimeError("dht down") + + dht.provider_store.find_providers = Mock(side_effect=_raise) + mgr = ProviderQueryManager(dht) + + providers = await mgr.find_providers_single(CID_1, timeout=5.0) + + assert providers == [] + assert mgr.get_stats()["errors"] == 1 + + @pytest.mark.trio + async def test_batch_all_cache_hits(self) -> None: + dht = _mock_dht() + mgr = ProviderQueryManager(dht) + for cid in 
SAMPLE_CIDS: + mgr.cache.put(cid_to_bytes(cid), [PEER_A]) + + results = await mgr.find_providers(SAMPLE_CIDS) + + assert len(results) == 3 + dht.provider_store.find_providers.assert_not_called() + + @pytest.mark.trio + async def test_batch_partial_cache(self) -> None: + dht = _mock_dht(return_peers=[PEER_B]) + mgr = ProviderQueryManager(dht) + # Pre-cache only first CID + mgr.cache.put(cid_to_bytes(CID_1), [PEER_A]) + + results = await mgr.find_providers(SAMPLE_CIDS) + + assert len(results) == 3 + # Only 2 DHT calls (CID_2 and CID_3 are cache misses) + assert dht.provider_store.find_providers.call_count == 2 + + @pytest.mark.trio + async def test_use_cache_false_always_queries_dht(self) -> None: + dht = _mock_dht(return_peers=[PEER_A]) + mgr = ProviderQueryManager(dht) + mgr.cache.put(cid_to_bytes(CID_1), [PEER_B]) # pre-populated + + providers = await mgr.find_providers_single(CID_1, use_cache=False) + + # DHT was queried despite cache having an entry + dht.provider_store.find_providers.assert_called_once() + assert providers == [PEER_A] + + @pytest.mark.trio + async def test_clear_cache_forces_new_query(self) -> None: + dht = _mock_dht(return_peers=[PEER_A]) + mgr = ProviderQueryManager(dht) + + await mgr.find_providers_single(CID_1) # miss → cached + await mgr.find_providers_single(CID_1) # hit + mgr.clear_cache() + await mgr.find_providers_single(CID_1) # miss again + + assert mgr.get_stats()["cache_misses"] == 2 + assert dht.provider_store.find_providers.call_count == 2 + + @pytest.mark.trio + async def test_cleanup_expired_cache(self) -> None: + dht = _mock_dht() + mgr = ProviderQueryManager(dht) + mgr.cache.put(cid_to_bytes(CID_1), [PEER_A], ttl=0.01) + mgr.cache.put(cid_to_bytes(CID_2), [PEER_B], ttl=300) + await trio.sleep(0.05) + + removed = await mgr.cleanup_expired_cache() + + assert removed == 1 + assert mgr.cache.size() == 1 + + def test_get_stats_initial_values(self) -> None: + mgr = ProviderQueryManager(_mock_dht()) + stats = mgr.get_stats() + assert stats["queries"] == 0 + assert stats["cache_hits"] == 0 + assert stats["cache_misses"] == 0 + assert stats["errors"] == 0 + assert stats["providers_found"] == 0 + + @pytest.mark.trio + async def test_empty_cid_list(self) -> None: + mgr = ProviderQueryManager(_mock_dht()) + assert await mgr.find_providers([]) == {} + + +# ═════════════════════════════════════════════════════════════════════════════ +# BitswapClient integration +# ═════════════════════════════════════════════════════════════════════════════ + + +class TestBitswapClientProviderQueryIntegration: + """Verify that BitswapClient wires ProviderQueryManager into get_block().""" + + def _make_client( + self, + mock_host: Mock, + pqm: ProviderQueryManager | None = None, + ) -> BitswapClient: + store = MemoryBlockStore() + return BitswapClient(mock_host, block_store=store, provider_query_manager=pqm) + + def test_provider_query_manager_stored_on_client(self, mock_host: Mock) -> None: + dht = _mock_dht() + pqm = ProviderQueryManager(dht) + client = self._make_client(mock_host, pqm) + assert client.provider_query_manager is pqm + + def test_no_pqm_by_default(self, mock_host: Mock) -> None: + client = self._make_client(mock_host) + assert client.provider_query_manager is None + + @pytest.mark.trio + async def test_get_block_returns_local_without_dht(self, mock_host: Mock) -> None: + """Local cache hit must never touch the DHT.""" + dht = _mock_dht(return_peers=[PEER_A]) + pqm = ProviderQueryManager(dht) + client = self._make_client(mock_host, pqm) + + block_data = b"local 
block" + cid = parse_cid(compute_cid_v0(block_data)) + await client.block_store.put_block(cid, block_data) + + result = await client.block_store.get_block(cid) + assert result == block_data + # DHT must not have been consulted + dht.provider_store.find_providers.assert_not_called() + + @pytest.mark.trio + async def test_get_block_uses_pqm_to_pick_peer(self, mock_host: Mock) -> None: + """ + When the block is not local, get_block() should call + provider_query_manager.find_providers_single() and use the + returned peer_id. + """ + discovered_peer = PEER_A + block_data = b"remote block" + cid = parse_cid(compute_cid_v0(block_data)) + + dht = _mock_dht(return_peers=[discovered_peer]) + pqm = ProviderQueryManager(dht) + client = self._make_client(mock_host, pqm) + + # Patch _request_block so we can inspect the peer_id it receives + captured: dict[str, object] = {} + + async def _fake_request(cid_obj, peer_id, timeout): # noqa: ANN001 + captured["peer_id"] = peer_id + return block_data + + client._request_block = _fake_request # type: ignore[method-assign] + + result = await client.get_block(cid) + + assert result == block_data + assert captured["peer_id"] == discovered_peer + + @pytest.mark.trio + async def test_get_block_falls_back_to_broadcast_when_no_providers( + self, mock_host: Mock + ) -> None: + """ + When the DHT returns no providers, get_block() must still call + _request_block with peer_id=None (broadcast fallback). + """ + dht = _mock_dht(return_peers=[]) + pqm = ProviderQueryManager(dht) + client = self._make_client(mock_host, pqm) + + block_data = b"broadcast block" + cid = parse_cid(compute_cid_v0(block_data)) + + captured: dict[str, object] = {} + + async def _fake_request(cid_obj, peer_id, timeout): # noqa: ANN001 + captured["peer_id"] = peer_id + return block_data + + client._request_block = _fake_request # type: ignore[method-assign] + + result = await client.get_block(cid) + + assert result == block_data + assert captured["peer_id"] is None # broadcast + + @pytest.mark.trio + async def test_explicit_peer_id_skips_pqm(self, mock_host: Mock) -> None: + """An explicit peer_id argument must bypass DHT discovery.""" + dht = _mock_dht(return_peers=[PEER_B]) + pqm = ProviderQueryManager(dht) + client = self._make_client(mock_host, pqm) + + block_data = b"explicit peer block" + cid = parse_cid(compute_cid_v0(block_data)) + + captured: dict[str, object] = {} + + async def _fake_request(cid_obj, peer_id, timeout): # noqa: ANN001 + captured["peer_id"] = peer_id + return block_data + + client._request_block = _fake_request # type: ignore[method-assign] + + await client.get_block(cid, peer_id=PEER_A) + + # DHT must NOT have been called + dht.provider_store.get_providers.assert_not_called() + # The explicit peer_id must be passed through unchanged + assert captured["peer_id"] == PEER_A + + @pytest.mark.trio + async def test_pqm_error_falls_back_gracefully(self, mock_host: Mock) -> None: + """A crashing PQM must not prevent the block fetch from proceeding.""" + dht = _mock_dht() + + async def _raise(*_args: object, **_kwargs: object) -> None: + raise RuntimeError("dht exploded") + + dht.provider_store.find_providers = Mock(side_effect=_raise) + pqm = ProviderQueryManager(dht) + client = self._make_client(mock_host, pqm) + + block_data = b"fallback block" + cid = parse_cid(compute_cid_v0(block_data)) + + captured: dict[str, object] = {} + + async def _fake_request(cid_obj, peer_id, timeout): # noqa: ANN001 + captured["peer_id"] = peer_id + return block_data + + client._request_block = 
_fake_request # type: ignore[method-assign] + + result = await client.get_block(cid) + + assert result == block_data + assert captured["peer_id"] is None # graceful broadcast fallback diff --git a/tests/core/bitswap/test_unixfs_encoding.py b/tests/core/bitswap/test_unixfs_encoding.py new file mode 100644 index 000000000..cff119430 --- /dev/null +++ b/tests/core/bitswap/test_unixfs_encoding.py @@ -0,0 +1,255 @@ +""" +Test that add_file / add_bytes now produce dag-pb leaf blocks (UnixFS-wrapped) +and that balanced_layout builds the correct tree structure. + +Run with: + python test_unixfs_encoding.py +""" + +import os +import tempfile + +import trio + +from libp2p.bitswap.block_store import MemoryBlockStore +from libp2p.bitswap.cid import CODEC_DAG_PB, CODEC_RAW, cid_to_text, compute_cid_v1 +from libp2p.bitswap.dag_pb import ( + MAX_LINKS_PER_NODE, + balanced_layout, + create_leaf_node, + decode_dag_pb, + is_file_node, +) + + +def ok(label): + print(f" OK {label}") + + +def fail(label, detail=""): + raise AssertionError(f"FAIL {label} {detail}") + + +# ── 1. create_leaf_node wraps data in dag-pb + UnixFS ──────────────────────── +def test_create_leaf_node(): + print("\n[1] create_leaf_node") + data = b"hello leaf" + leaf = create_leaf_node(data) + + # Must be a valid dag-pb file node + assert is_file_node(leaf), "leaf must be a dag-pb file node" + ok("create_leaf_node produces a dag-pb file node") + + # Decode and check inline data + links, unixfs = decode_dag_pb(leaf) + assert links == [], "leaf must have no links" + assert unixfs is not None + assert unixfs.data == data, f"inline data mismatch: {unixfs.data!r} != {data!r}" + assert unixfs.filesize == len(data) + ok(f"leaf contains inline data ({len(data)} bytes), filesize={unixfs.filesize}") + + # CID must be dag-pb, not raw + cid = compute_cid_v1(leaf, codec=CODEC_DAG_PB) + raw_cid = compute_cid_v1(data, codec=CODEC_RAW) + assert bytes(cid) != bytes(raw_cid), "dag-pb leaf CID must differ from raw CID" + ok(f"leaf CID is dag-pb (not raw): {cid_to_text(cid)[:30]}...") + + # Empty leaf + empty_leaf = create_leaf_node(b"") + _, empty_unixfs = decode_dag_pb(empty_leaf) + assert empty_unixfs is not None + assert empty_unixfs.filesize == 0 + ok("empty leaf node is valid") + + +# ── 2. balanced_layout single leaf ─────────────────────────────────────────── +def test_balanced_layout_single(): + print("\n[2] balanced_layout — single leaf returns leaf unchanged") + data = b"only chunk" + leaf = create_leaf_node(data) + cid = compute_cid_v1(leaf, codec=CODEC_DAG_PB) + + root_cid, root_block = balanced_layout([(cid, leaf, len(data))]) + assert bytes(root_cid) == bytes(cid) + assert root_block == leaf + ok("single leaf: root_cid == leaf_cid") + + +# ── 3. 
balanced_layout two leaves ──────────────────────────────────────────── +def test_balanced_layout_two_leaves(): + print("\n[3] balanced_layout — two leaves builds one root") + leaves = [] + for i in range(2): + data = f"chunk {i}".encode() * 100 + leaf = create_leaf_node(data) + cid = compute_cid_v1(leaf, codec=CODEC_DAG_PB) + leaves.append((cid, leaf, len(data))) + + root_cid, root_block = balanced_layout(leaves) + + # Root must be a dag-pb file node with 2 links + assert is_file_node(root_block) + links, unixfs = decode_dag_pb(root_block) + assert len(links) == 2, f"expected 2 links, got {len(links)}" + assert unixfs is not None + assert unixfs.filesize == sum(s for _, _, s in leaves) + assert len(unixfs.blocksizes) == 2 + ok(f"root has 2 links, filesize={unixfs.filesize}, blocksizes={unixfs.blocksizes}") + + +# ── 4. balanced_layout 175 leaves builds 2-level tree ──────────────────────── +def test_balanced_layout_two_levels(): + print("\n[4] balanced_layout — 175 leaves builds 2-level tree (174 + 1)") + n = MAX_LINKS_PER_NODE + 1 # 175 + chunk_size = 100 + leaves = [] + for i in range(n): + data = bytes([i % 256]) * chunk_size + leaf = create_leaf_node(data) + cid = compute_cid_v1(leaf, codec=CODEC_DAG_PB) + leaves.append((cid, leaf, chunk_size)) + + root_cid, root_block = balanced_layout(leaves) + links, unixfs = decode_dag_pb(root_block) + + # Root should link to 2 internal nodes (174 + 1) + assert len(links) == 2, f"expected 2 top-level links, got {len(links)}" + assert unixfs is not None + assert unixfs.filesize == n * chunk_size + ok("175 leaves → root has 2 links (174-leaf node + 1-leaf node)") + ok(f"root filesize = {unixfs.filesize} = 175 * {chunk_size}") + + +# ── 5. balanced_layout 174 leaves stays flat ───────────────────────────────── +def test_balanced_layout_flat(): + print("\n[5] balanced_layout — exactly 174 leaves stays flat (1 level)") + n = MAX_LINKS_PER_NODE # 174 + leaves = [] + for i in range(n): + data = bytes([i % 256]) * 50 + leaf = create_leaf_node(data) + cid = compute_cid_v1(leaf, codec=CODEC_DAG_PB) + leaves.append((cid, leaf, 50)) + + root_cid, root_block = balanced_layout(leaves) + links, unixfs = decode_dag_pb(root_block) + + assert len(links) == 174, f"expected 174 direct links, got {len(links)}" + ok("174 leaves → flat root with 174 direct links") + + +# ── 6. 
add_file produces dag-pb leaves (not raw) via MerkleDag ─────────────── +async def test_add_file_produces_dag_pb_leaves(): + print("\n[6] MerkleDag.add_file produces dag-pb leaf blocks") + from unittest.mock import AsyncMock, MagicMock + + from libp2p.bitswap.client import BitswapClient + from libp2p.bitswap.dag import MerkleDag + + store = MemoryBlockStore() + mock_client = MagicMock(spec=BitswapClient) + mock_client.block_store = store + stored: dict[bytes, bytes] = {} + + async def add_block_impl(cid, data): + stored[bytes(cid)] = data + + mock_client.add_block = AsyncMock(side_effect=add_block_impl) + + dag = MerkleDag(mock_client) + + # Write a 3-chunk file + chunk_size = 63 * 1024 + content = b"x" * (chunk_size * 3 - 7) # 3 chunks + with tempfile.NamedTemporaryFile(delete=False) as f: + f.write(content) + tmp = f.name + + try: + root_cid = await dag.add_file( + tmp, chunk_size=chunk_size, wrap_with_directory=False + ) + finally: + os.unlink(tmp) + + # Every stored block must be a dag-pb file node (no raw blocks) + raw_blocks = [] + for cid_bytes, block_data in stored.items(): + if not is_file_node(block_data): + raw_blocks.append(cid_to_text(cid_bytes)[:20]) + + assert raw_blocks == [], f"Found non-dag-pb blocks: {raw_blocks}" + ok(f"All {len(stored)} stored blocks are dag-pb file nodes (no raw blocks)") + + # Root must link to 3 leaves + root_block = stored[bytes(root_cid)] + links, unixfs = decode_dag_pb(root_block) + assert len(links) == 3, f"expected 3 links on root, got {len(links)}" + assert unixfs is not None + assert unixfs.filesize == len(content) + ok(f"root has 3 links, filesize={unixfs.filesize}") + + # Each leaf must contain inline UnixFS data + for link in links: + leaf_block = stored[bytes(link.cid)] + leaf_links, leaf_unixfs = decode_dag_pb(leaf_block) + assert leaf_links == [], "leaf must have no links" + assert leaf_unixfs is not None and leaf_unixfs.data != b"" + ok("each leaf contains inline UnixFS data") + + +# ── 7. 
add_bytes produces dag-pb leaves ────────────────────────────────────── +async def test_add_bytes_produces_dag_pb_leaves(): + print("\n[7] MerkleDag.add_bytes produces dag-pb leaf blocks") + from unittest.mock import AsyncMock, MagicMock + + from libp2p.bitswap.client import BitswapClient + from libp2p.bitswap.dag import MerkleDag + + store = MemoryBlockStore() + mock_client = MagicMock(spec=BitswapClient) + mock_client.block_store = store + stored: dict[bytes, bytes] = {} + + async def add_block_impl(cid, data): + stored[bytes(cid)] = data + + mock_client.add_block = AsyncMock(side_effect=add_block_impl) + + dag = MerkleDag(mock_client) + content = b"y" * (63 * 1024 * 2 + 500) # 3 chunks + root_cid = await dag.add_bytes(content) + + raw_blocks = [cid_to_text(c)[:20] for c, d in stored.items() if not is_file_node(d)] + assert raw_blocks == [], f"Found non-dag-pb blocks: {raw_blocks}" + ok(f"All {len(stored)} stored blocks are dag-pb file nodes") + + root_block = stored[bytes(root_cid)] + links, unixfs = decode_dag_pb(root_block) + assert len(links) == 3 + assert unixfs is not None + assert unixfs.filesize == len(content) + ok(f"root has 3 links, filesize={unixfs.filesize}") + + +# ── main ────────────────────────────────────────────────────────────────────── +async def main(): + print("=" * 60) + print("UnixFSFile / Balanced DAG — Test Suite") + print("=" * 60) + + test_create_leaf_node() + test_balanced_layout_single() + test_balanced_layout_two_leaves() + test_balanced_layout_two_levels() + test_balanced_layout_flat() + await test_add_file_produces_dag_pb_leaves() + await test_add_bytes_produces_dag_pb_leaves() + + print("\n" + "=" * 60) + print("All tests passed!") + print("=" * 60) + + +if __name__ == "__main__": + trio.run(main) diff --git a/tests/core/bitswap/test_wantlist.py b/tests/core/bitswap/test_wantlist.py new file mode 100644 index 000000000..a632fc80b --- /dev/null +++ b/tests/core/bitswap/test_wantlist.py @@ -0,0 +1,281 @@ +""" +Test Wantlist / Message dataclasses. 
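+
+Typical round-trip, using only the calls the tests below make (a sketch,
+not a published example):
+
+    msg = BitswapMessage()
+    msg.add_want(cid, want_type=WantType.Block, send_dont_have=True)
+    msg.add_block(cid2, data)
+    restored = BitswapMessage.from_proto(msg.to_proto())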
+ +Run with: + python test_wantlist.py +""" + +from libp2p.bitswap.cid import CODEC_RAW, cid_to_bytes, compute_cid_v1 +from libp2p.bitswap.messages import create_wantlist_entry +from libp2p.bitswap.wantlist import ( + BitswapMessage, + BlockPresence, + BlockPresenceType, + Wantlist, + WantlistEntry, + WantType, +) + + +def make_cid(content: bytes) -> bytes: + return cid_to_bytes(compute_cid_v1(content, codec=CODEC_RAW)) + + +def ok(label): + print(f" OK {label}") + + +# ── WantType enum ───────────────────────────────────────────────────────────── + + +def test_want_type_values(): + print("\n[1] WantType enum values match protobuf") + assert WantType.Block.value == 0 + assert WantType.Have.value == 1 + ok("WantType.Block == 0, WantType.Have == 1") + + +# ── WantlistEntry ───────────────────────────────────────────────────────────── + + +def test_wantlist_entry_from_cid(): + print("\n[2] WantlistEntry.from_cid normalises any CIDInput") + cid = compute_cid_v1(b"entry test", codec=CODEC_RAW) + cid_bytes = cid_to_bytes(cid) + + # from bytes + e1 = WantlistEntry.from_cid(cid_bytes) + assert e1.cid == cid_bytes + assert e1.want_type == WantType.Block + assert e1.priority == 1 + assert not e1.cancel + ok("from bytes — defaults correct") + + # from CIDObject + e2 = WantlistEntry.from_cid(cid, want_type=WantType.Have, send_dont_have=True) + assert e2.want_type == WantType.Have + assert e2.send_dont_have + ok("from CIDObject — WantType.Have, send_dont_have=True") + + # cancel entry + e3 = WantlistEntry.from_cid(cid_bytes, cancel=True) + assert e3.cancel + ok("cancel entry") + + +# ── Wantlist ────────────────────────────────────────────────────────────────── + + +def test_wantlist_add_cancel_contains(): + print("\n[3] Wantlist.add / cancel / contains") + cid1 = make_cid(b"block 1") + cid2 = make_cid(b"block 2") + cid3 = make_cid(b"block 3") + + wl = Wantlist() + assert len(wl) == 0 + assert not wl + + wl.add(cid1, want_type=WantType.Block, send_dont_have=True) + wl.add(cid2, want_type=WantType.Have) + wl.cancel(cid3) + + assert len(wl) == 3 + assert bool(wl) + ok("len(wl) == 3 after 2 adds + 1 cancel") + + assert wl.contains(cid1) + assert wl.contains(cid2) + assert not wl.contains(cid3) # cancel entry → not "contained" + ok("contains() returns True for non-cancel entries only") + + # Check entry fields + e1 = wl.entries[0] + assert e1.want_type == WantType.Block + assert e1.send_dont_have + e2 = wl.entries[1] + assert e2.want_type == WantType.Have + e3 = wl.entries[2] + assert e3.cancel + ok("entry fields correct (want_type, send_dont_have, cancel)") + + +def test_wantlist_full_flag(): + print("\n[4] Wantlist.full flag") + wl = Wantlist(full=True) + assert wl.full + ok("full=True preserved") + + +# ── BlockPresence ───────────────────────────────────────────────────────────── + + +def test_block_presence(): + print("\n[5] BlockPresence constructors") + cid = make_cid(b"presence test") + + have = BlockPresence.have(cid) + assert have.cid == cid + assert have.type == BlockPresenceType.Have + ok("BlockPresence.have()") + + dont = BlockPresence.dont_have(cid) + assert dont.cid == cid + assert dont.type == BlockPresenceType.DontHave + ok("BlockPresence.dont_have()") + + assert BlockPresenceType.Have.value == 0 + assert BlockPresenceType.DontHave.value == 1 + ok("BlockPresenceType values match protobuf (Have=0, DontHave=1)") + + +# ── BitswapMessage ──────────────────────────────────────────────────────────── + + +def test_bitswap_message_properties(): + print("\n[6] BitswapMessage builder + properties") 
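+    # Exercises the full builder surface (add_want / add_block / add_have /
+    # add_dont_have) and the derived is_want / has_blocks / has_presences flags.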
+ cid1 = make_cid(b"want me") + cid2 = make_cid(b"block data") + cid3 = make_cid(b"i have this") + cid4 = make_cid(b"i dont have this") + data = b"actual block content" + + msg = BitswapMessage() + assert not msg.is_want + assert not msg.has_blocks + assert not msg.has_presences + + msg.add_want(cid1, want_type=WantType.Block, send_dont_have=True) + assert msg.is_want + ok("is_want True after add_want()") + + msg.add_block(cid2, data) + assert msg.has_blocks + assert msg.blocks[0] == (cid2, data) + ok("has_blocks True after add_block()") + + msg.add_have(cid3) + msg.add_dont_have(cid4) + assert msg.has_presences + assert len(msg.block_presences) == 2 + assert msg.block_presences[0].type == BlockPresenceType.Have + assert msg.block_presences[1].type == BlockPresenceType.DontHave + ok("has_presences True, HAVE and DONT_HAVE entries correct") + + +def test_bitswap_message_cancel_want(): + print("\n[7] BitswapMessage.cancel_want()") + cid = make_cid(b"cancel me") + msg = BitswapMessage() + msg.cancel_want(cid) + assert msg.is_want + assert msg.wantlist is not None + assert msg.wantlist.entries[0].cancel + ok("cancel_want() adds cancel entry") + + +# ── to_proto / from_proto round-trip ───────────────────────────────────────── + + +def test_to_proto_from_proto_roundtrip(): + print("\n[8] BitswapMessage to_proto() / from_proto() round-trip") + cid1 = make_cid(b"want block") + cid2 = make_cid(b"block payload") + cid3 = make_cid(b"have this") + data = b"block payload data" + + original = BitswapMessage() + original.add_want(cid1, want_type=WantType.Block, send_dont_have=True) + original.add_block(cid2, data) + original.add_have(cid3) + original.add_dont_have(make_cid(b"dont have")) + + proto = original.to_proto() + restored = BitswapMessage.from_proto(proto) + + # Wantlist + assert restored.wantlist is not None + assert len(restored.wantlist.entries) == 1 + e = restored.wantlist.entries[0] + assert e.cid == cid1 + assert e.want_type == WantType.Block + assert e.send_dont_have + ok("wantlist entry round-trips correctly") + + # Block payload + assert len(restored.blocks) == 1 + restored_cid, restored_data = restored.blocks[0] + assert restored_data == data + ok("block payload round-trips correctly") + + # Block presences + assert len(restored.block_presences) == 2 + assert restored.block_presences[0].type == BlockPresenceType.Have + assert restored.block_presences[1].type == BlockPresenceType.DontHave + ok("block presences round-trip correctly") + + +# ── backward compat: create_wantlist_entry accepts int OR WantType ──────────── + + +def test_create_wantlist_entry_backward_compat(): + print("\n[9] create_wantlist_entry — backward compat (int OR WantType)") + cid = make_cid(b"compat test") + + # Old style: raw int + e_int = create_wantlist_entry(cid, want_type=0) + assert e_int.wantType == 0 + ok("want_type=0 (int) still works") + + e_int2 = create_wantlist_entry(cid, want_type=1) + assert e_int2.wantType == 1 + ok("want_type=1 (int) still works") + + # New style: WantType enum + e_enum = create_wantlist_entry(cid, want_type=WantType.Block) + assert e_enum.wantType == 0 + ok("want_type=WantType.Block works") + + e_enum2 = create_wantlist_entry(cid, want_type=WantType.Have) + assert e_enum2.wantType == 1 + ok("want_type=WantType.Have works") + + +# ── public API exports ──────────────────────────────────────────────────────── + + +def test_public_exports(): + print("\n[10] All types exported from libp2p.bitswap") + from libp2p.bitswap import ( + WantType, + ) + + assert WantType.Block.value == 0 + 
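+    # the enum values double as wire values, so they must match the protobuf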
assert WantType.Have.value == 1 + ok( + "WantType, WantlistEntry, Wantlist, BlockPresence, BlockPresenceType, " + "BitswapMessage all importable from libp2p.bitswap" + ) + + +# ── main ────────────────────────────────────────────────────────────────────── + +if __name__ == "__main__": + print("=" * 60) + print("Wantlist / Message Dataclasses — Test Suite") + print("=" * 60) + + test_want_type_values() + test_wantlist_entry_from_cid() + test_wantlist_add_cancel_contains() + test_wantlist_full_flag() + test_block_presence() + test_bitswap_message_properties() + test_bitswap_message_cancel_want() + test_to_proto_from_proto_roundtrip() + test_create_wantlist_entry_backward_compat() + test_public_exports() + + print("\n" + "=" * 60) + print("All tests passed!") + print("=" * 60) diff --git a/tests/core/kad_dht/test_kad_dht_quorum_sliding_window.py b/tests/core/kad_dht/test_kad_dht_quorum_sliding_window.py index 87e669cc0..b1be7cdb8 100644 --- a/tests/core/kad_dht/test_kad_dht_quorum_sliding_window.py +++ b/tests/core/kad_dht/test_kad_dht_quorum_sliding_window.py @@ -41,6 +41,7 @@ def _make_dht() -> KadDHT: host = MagicMock() key_pair = create_new_key_pair() host.get_id.return_value = ID.from_pubkey(key_pair.public_key) + host.get_private_key.return_value = key_pair.private_key host.get_addrs.return_value = [Multiaddr("/ip4/127.0.0.1/tcp/8000")] host.get_peerstore.return_value = MagicMock() host.new_stream = AsyncMock() diff --git a/tests/core/kad_dht/test_unit_value_store.py b/tests/core/kad_dht/test_unit_value_store.py index bdaaacd9c..6a5d7d4a7 100644 --- a/tests/core/kad_dht/test_unit_value_store.py +++ b/tests/core/kad_dht/test_unit_value_store.py @@ -15,6 +15,7 @@ import pytest +from libp2p.crypto.secp256k1 import create_new_key_pair from libp2p.kad_dht.value_store import ( DEFAULT_TTL, ValueStore, @@ -24,8 +25,11 @@ ) from libp2p.records.record import make_put_record +# Create a real key pair for signing +key_pair = create_new_key_pair() mock_host = Mock() -peer_id = ID.from_base58("QmTest123") +mock_host.get_private_key.return_value = key_pair.private_key +peer_id = ID.from_pubkey(key_pair.public_key) class TestValueStore: @@ -445,6 +449,178 @@ async def test_store_at_peer_local_peer(self): assert result is True + @pytest.mark.trio + async def test_store_at_peer_propagates_signature_and_author(self): + """ + _store_at_peer must include signature and author from the locally-stored + signed record in the outbound PUT_VALUE message. + + This ensures signed-record authenticity is preserved when replicating + values to remote peers, matching go-libp2p interoperability requirements. 
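+
+        Wire framing asserted below: varint(len(body)) || body, where body is
+        a serialized kademlia PUT_VALUE Message, so written[0] is the length
+        prefix and written[1] is the protobuf bytes.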
+ """ + import varint + + from libp2p.kad_dht.pb.kademlia_pb2 import Message + + # Build a host with a real key pair so put() creates a genuine signed record + kp = create_new_key_pair() + remote_peer_id = ID.from_base58("QmRemote123456789") + local_peer_id = ID.from_pubkey(kp.public_key) + + # Capture the bytes written to the mock stream + written: list[bytes] = [] + + mock_stream = Mock() + + async def _write(data: bytes) -> None: + written.append(data) + + async def _read(n: int) -> bytes: + # Simulate a minimal valid PUT_VALUE acknowledgement + resp = Message() + resp.type = Message.MessageType.PUT_VALUE + resp.key = b"test_key" + raw = resp.SerializeToString() + length = varint.encode(len(raw)) + # Return one byte at a time for the varint reader, then the body + full = length + raw + if not hasattr(_read, "_buf"): + _read._buf = iter(full) # type: ignore[attr-defined] + byte_val = next(_read._buf, b"") # type: ignore[attr-defined] + return bytes([byte_val]) if isinstance(byte_val, int) else byte_val + + mock_stream.write = Mock(side_effect=_write) + mock_stream.read = Mock(side_effect=_read) + mock_stream.close = Mock(return_value=None) + + # Patch close to be awaitable + async def _close() -> None: + pass + + mock_stream.close = _close + + h = Mock() + h.get_private_key.return_value = kp.private_key + h.get_peerstore.return_value = Mock() + + # env_to_send_in_RPC is called; return empty bytes to keep test simple + from libp2p.peer.peerstore import env_to_send_in_RPC + + original_env = env_to_send_in_RPC + + import libp2p.kad_dht.value_store as vs_module + + vs_module.env_to_send_in_RPC = Mock(return_value=(b"", None)) # type: ignore[attr-defined] + + async def _new_stream(*_args: object, **_kwargs: object) -> object: + return mock_stream + + h.new_stream = _new_stream + + try: + store = ValueStore(host=h, local_peer_id=local_peer_id) + key = b"test_key" + value = b"test_value" + + # Store locally first (creates signed record) + store.put(key, value) + + # Confirm the local record has signature and author set + local_record, _ = store.store[key] + assert local_record.signature, "put() must produce a non-empty signature" + assert local_record.author, "put() must populate the author field" + + # Now replicate to a remote peer + await store._store_at_peer(remote_peer_id, key, value) + + # Reconstruct the serialized message from what was written + # written[0] is the varint length prefix, written[1] is the proto body + assert len(written) >= 2, "Expected varint + proto body to be written" + sent_msg = Message() + sent_msg.ParseFromString(written[1]) + + assert sent_msg.HasField("record"), "Outbound message must contain a record" + assert sent_msg.record.signature == local_record.signature, ( + "Outbound record must carry the signature from the signed record" + ) + assert sent_msg.record.author == local_record.author, ( + "Outbound record must carry the author from the signed record" + ) + finally: + vs_module.env_to_send_in_RPC = original_env # type: ignore[attr-defined] + + @pytest.mark.trio + async def test_store_at_peer_signs_record_without_prior_put(self): + """ + When _store_at_peer is called without a prior put() (e.g. the get_value + propagation path), it must still produce a signed outbound record — + never a bare unsigned one. 
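+
+        The mocked host only provides get_private_key(), so any signature on
+        the outbound record must have been produced inside _store_at_peer.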
+ """ + import varint + + from libp2p.kad_dht.pb.kademlia_pb2 import Message + + kp = create_new_key_pair() + remote_peer_id = ID.from_base58("QmRemote999") + local_peer_id = ID.from_pubkey(kp.public_key) + + written: list[bytes] = [] + + async def _write(data: bytes) -> None: + written.append(data) + + mock_stream = Mock() + resp = Message() + resp.type = Message.MessageType.PUT_VALUE + resp.key = b"bare_key" + raw = resp.SerializeToString() + resp_bytes = varint.encode(len(raw)) + raw + resp_iter = iter(resp_bytes) + + async def _read(n: int) -> bytes: + byte_val = next(resp_iter, b"") + return bytes([byte_val]) if isinstance(byte_val, int) else byte_val + + mock_stream.write = Mock(side_effect=_write) + mock_stream.read = Mock(side_effect=_read) + + async def _close() -> None: + pass + + mock_stream.close = _close + + h = Mock() + h.get_private_key.return_value = kp.private_key + + import libp2p.kad_dht.value_store as vs_module + + original_env = vs_module.env_to_send_in_RPC + vs_module.env_to_send_in_RPC = Mock(return_value=(b"", None)) # type: ignore[attr-defined] + + async def _new_stream(*_args: object, **_kwargs: object) -> object: + return mock_stream + + h.new_stream = _new_stream + + try: + store = ValueStore(host=h, local_peer_id=local_peer_id) + key = b"bare_key" + value = b"bare_value" + + # Do NOT call store.put() — _store_at_peer must sign the record itself + await store._store_at_peer(remote_peer_id, key, value) + + assert len(written) >= 2 + sent_msg = Message() + sent_msg.ParseFromString(written[1]) + assert sent_msg.record.key == key + assert sent_msg.record.value == value + # The record must be signed even without a prior put() + assert sent_msg.record.signature, "record must be signed inline" + assert sent_msg.record.author, "record must carry author field" + finally: + vs_module.env_to_send_in_RPC = original_env # type: ignore[attr-defined] + @pytest.mark.trio async def test_get_from_peer_local_peer(self): """Test _get_from_peer returns None when querying local peer.""" diff --git a/tests/core/records/test_validator.py b/tests/core/records/test_validator.py index 4a0efc0f7..9faf3bb6c 100644 --- a/tests/core/records/test_validator.py +++ b/tests/core/records/test_validator.py @@ -5,7 +5,12 @@ from libp2p.peer.id import ID from libp2p.records.pubkey import PublicKeyValidator, unmarshal_public_key from libp2p.records.record import make_put_record -from libp2p.records.utils import InvalidRecordType, split_key +from libp2p.records.utils import ( + InvalidRecordType, + sign_record, + split_key, + verify_record, +) from libp2p.records.validator import NamespacedValidator, Validator bad_paths = [ @@ -243,3 +248,85 @@ def select(self, key: str, values: list[bytes]) -> int: # Non-namespaced key uses custom fallback that rejects with pytest.raises(ValueError, match="Rejected by fallback"): validators.validate("plain-key", b"value") + + +# ───────────────────────────────────────────────────────────────────────────── +# verify_record — multi-key-type coverage +# ───────────────────────────────────────────────────────────────────────────── + + +class TestVerifyRecord: + """ + verify_record must accept signatures from every key type that libp2p + serialises via crypto_pb2.PublicKey (Ed25519, Secp256k1, RSA). + + Previously the implementation hard-coded Ed25519PublicKey.from_bytes, + causing it to silently return False for RSA and Secp256k1 peers and + breaking DHT interoperability with non-Ed25519 nodes. 
+ """ + + def _round_trip(self, key_pair) -> None: # noqa: ANN001 + """Sign with *key_pair* and assert verify_record returns True.""" + key = b"/test/mykey" + value = b"hello world" + sig, author = sign_record(key_pair.private_key, key, value) + assert verify_record(sig, author, key, value), ( + f"verify_record returned False for key type " + f"{key_pair.private_key.get_type()}" + ) + + def _tampered_fails(self, key_pair) -> None: # noqa: ANN001 + """Tampered payload must make verify_record return False.""" + key = b"/test/mykey" + value = b"hello world" + sig, author = sign_record(key_pair.private_key, key, value) + assert not verify_record(sig, author, key, b"tampered"), ( + f"verify_record accepted tampered value for key type " + f"{key_pair.private_key.get_type()}" + ) + + def test_ed25519_valid_signature(self) -> None: + from libp2p.crypto.ed25519 import create_new_key_pair as ed_kp + + self._round_trip(ed_kp()) + + def test_ed25519_tampered_value_rejected(self) -> None: + from libp2p.crypto.ed25519 import create_new_key_pair as ed_kp + + self._tampered_fails(ed_kp()) + + def test_secp256k1_valid_signature(self) -> None: + from libp2p.crypto.secp256k1 import create_new_key_pair as secp_kp + + self._round_trip(secp_kp()) + + def test_secp256k1_tampered_value_rejected(self) -> None: + from libp2p.crypto.secp256k1 import create_new_key_pair as secp_kp + + self._tampered_fails(secp_kp()) + + def test_rsa_valid_signature(self) -> None: + from libp2p.crypto.rsa import create_new_key_pair as rsa_kp + + self._round_trip(rsa_kp()) + + def test_rsa_tampered_value_rejected(self) -> None: + from libp2p.crypto.rsa import create_new_key_pair as rsa_kp + + self._tampered_fails(rsa_kp()) + + def test_garbage_author_bytes_returns_false(self) -> None: + """Completely invalid author bytes must return False, not raise.""" + assert not verify_record(b"sig", b"not-a-valid-protobuf", b"key", b"value") + + def test_wrong_key_returns_false(self) -> None: + """Signature verified against a different key must return False.""" + from libp2p.crypto.ed25519 import create_new_key_pair as ed_kp + + kp1 = ed_kp() + kp2 = ed_kp() + key = b"/test/k" + value = b"v" + sig, _ = sign_record(kp1.private_key, key, value) + _, author2 = sign_record(kp2.private_key, key, value) + assert not verify_record(sig, author2, key, value)