diff --git a/sql/snowflake_stream_state.sql b/sql/snowflake_stream_state.sql new file mode 100644 index 0000000..0782cf2 --- /dev/null +++ b/sql/snowflake_stream_state.sql @@ -0,0 +1,35 @@ +-- Snowflake Stream State Table +-- Stores server-confirmed completed batches for persistent job resumption +-- +-- This table tracks which batches have been successfully processed and confirmed +-- by the server (via checkpoint watermarks). This enables jobs to resume from +-- the correct position after interruption or failure. + +CREATE TABLE IF NOT EXISTS amp_stream_state ( + -- Job/Table identification + connection_name VARCHAR(255) NOT NULL, + table_name VARCHAR(255) NOT NULL, + network VARCHAR(100) NOT NULL, + + -- Batch identification (compact 16-char hex ID) + batch_id VARCHAR(16) NOT NULL, + + -- Block range covered by this batch + start_block BIGINT NOT NULL, + end_block BIGINT NOT NULL, + + -- Block hashes for reorg detection (optional) + end_hash VARCHAR(66), + start_parent_hash VARCHAR(66), + + -- Processing metadata + processed_at TIMESTAMP_NTZ NOT NULL DEFAULT CURRENT_TIMESTAMP(), + + -- Primary key ensures no duplicate batches + PRIMARY KEY (connection_name, table_name, network, batch_id) +); + +-- Comments for documentation +COMMENT ON TABLE amp_stream_state IS 'Persistent stream state for job resumption - tracks server-confirmed completed batches'; +COMMENT ON COLUMN amp_stream_state.batch_id IS 'Compact 16-character hex identifier generated from block range + hash'; +COMMENT ON COLUMN amp_stream_state.processed_at IS 'Timestamp when batch was marked as successfully processed'; diff --git a/src/amp/config/label_manager.py b/src/amp/config/label_manager.py new file mode 100644 index 0000000..39cc081 --- /dev/null +++ b/src/amp/config/label_manager.py @@ -0,0 +1,171 @@ +""" +Label Manager for CSV-based label datasets. + +This module provides functionality to register and manage CSV label datasets +that can be joined with streaming data during loading operations. +""" + +import logging +from typing import Dict, List, Optional + +import pyarrow as pa +import pyarrow.csv as csv + + +class LabelManager: + """ + Manages CSV label datasets for joining with streaming data. + + Labels are registered by name and loaded as PyArrow Tables for efficient + joining operations. This allows reuse of label datasets across multiple + queries and loaders. + + Example: + >>> manager = LabelManager() + >>> manager.add_label('token_labels', '/path/to/tokens.csv') + >>> label_table = manager.get_label('token_labels') + """ + + def __init__(self): + self._labels: Dict[str, pa.Table] = {} + self.logger = logging.getLogger(__name__) + + def add_label(self, name: str, csv_path: str, binary_columns: Optional[List[str]] = None) -> None: + """ + Load and register a CSV label dataset with automatic hex→binary conversion. + + Hex string columns (like Ethereum addresses) are automatically converted to + binary format for efficient storage and joining. This reduces memory usage + by ~50% and improves join performance. + + Args: + name: Unique name for this label dataset + csv_path: Path to the CSV file + binary_columns: List of column names containing hex addresses to convert to binary. + If None, auto-detects columns with 'address' in the name. 
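+
+        Illustrative sketch (the file path is the same placeholder used in the
+        class docstring; the column name and byte width are examples, not fixed
+        behavior beyond what the auto-detection above describes):
+
+            >>> manager = LabelManager()
+            >>> manager.add_label('token_labels', '/path/to/tokens.csv')
+            >>> # A 40-char hex 'address' column is auto-detected and stored as
+            >>> # fixed_size_binary (20 bytes), ready to join against binary stream keys.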
+ + Raises: + FileNotFoundError: If CSV file doesn't exist + ValueError: If CSV cannot be parsed or name already exists + """ + if name in self._labels: + self.logger.warning(f"Label '{name}' already exists, replacing with new data") + + try: + # Load CSV as PyArrow Table (initially as strings) + temp_table = csv.read_csv(csv_path, read_options=csv.ReadOptions(autogenerate_column_names=False)) + + # Force all columns to be strings initially + column_types = {col_name: pa.string() for col_name in temp_table.column_names} + convert_opts = csv.ConvertOptions(column_types=column_types) + label_table = csv.read_csv(csv_path, convert_options=convert_opts) + + # Auto-detect or use specified binary columns + if binary_columns is None: + # Auto-detect columns with 'address' in name (case-insensitive) + binary_columns = [col for col in label_table.column_names if 'address' in col.lower()] + + # Convert hex string columns to binary for efficiency + converted_columns = [] + for col_name in binary_columns: + if col_name not in label_table.column_names: + self.logger.warning(f"Binary column '{col_name}' not found in CSV, skipping") + continue + + hex_col = label_table.column(col_name) + + # Detect hex string format and convert to binary + # Sample first non-null value to determine format + sample_value = None + for v in hex_col.to_pylist()[:100]: # Check first 100 values + if v is not None: + sample_value = v + break + + if sample_value is None: + self.logger.warning(f"Column '{col_name}' has no non-null values, skipping conversion") + continue + + # Detect if it's a hex string (with or without 0x prefix) + if isinstance(sample_value, str) and all(c in '0123456789abcdefABCDEFx' for c in sample_value): + # Determine binary length from hex string + hex_str = sample_value[2:] if sample_value.startswith('0x') else sample_value + binary_length = len(hex_str) // 2 + + # Convert all values to binary (fixed-size to match streaming data) + def hex_to_binary(v): + if v is None: + return None + hex_str = v[2:] if v.startswith('0x') else v + return bytes.fromhex(hex_str) + + binary_values = pa.array( + [hex_to_binary(v) for v in hex_col.to_pylist()], + type=pa.binary( + binary_length + ), # Fixed-size binary to match server data (e.g., 20 bytes for addresses) + ) + + # Replace the column + label_table = label_table.set_column( + label_table.schema.get_field_index(col_name), col_name, binary_values + ) + converted_columns.append(f'{col_name} (hex→fixed_size_binary[{binary_length}])') + self.logger.info(f"Converted '{col_name}' from hex string to fixed_size_binary[{binary_length}]") + + self._labels[name] = label_table + + conversion_info = f', converted: {", ".join(converted_columns)}' if converted_columns else '' + self.logger.info( + f"Loaded label '{name}' from {csv_path}: " + f'{label_table.num_rows:,} rows, {len(label_table.schema)} columns ' + f'({", ".join(label_table.schema.names)}){conversion_info}' + ) + + except FileNotFoundError: + raise FileNotFoundError(f'Label CSV file not found: {csv_path}') + except Exception as e: + raise ValueError(f"Failed to load label CSV '{csv_path}': {e}") from e + + def get_label(self, name: str) -> Optional[pa.Table]: + """ + Get label table by name. + + Args: + name: Name of the label dataset + + Returns: + PyArrow Table containing label data, or None if not found + """ + return self._labels.get(name) + + def list_labels(self) -> List[str]: + """ + List all registered label names. 
+ + Returns: + List of label names + """ + return list(self._labels.keys()) + + def remove_label(self, name: str) -> bool: + """ + Remove a label dataset. + + Args: + name: Name of the label to remove + + Returns: + True if label was removed, False if it didn't exist + """ + if name in self._labels: + del self._labels[name] + self.logger.info(f"Removed label '{name}'") + return True + return False + + def clear(self) -> None: + """Remove all label datasets.""" + count = len(self._labels) + self._labels.clear() + self.logger.info(f'Cleared {count} label dataset(s)') diff --git a/src/amp/loaders/types.py b/src/amp/loaders/types.py index 4487f09..78bfa86 100644 --- a/src/amp/loaders/types.py +++ b/src/amp/loaders/types.py @@ -43,6 +43,15 @@ def __str__(self) -> str: return f'❌ Failed to load to {self.table_name}: {self.error}' +@dataclass +class LabelJoinConfig: + """Configuration for label joining operations""" + + label_name: str + label_key_column: str + stream_key_column: str + + @dataclass class LoadConfig: """Configuration for data loading operations""" diff --git a/src/amp/streaming/resilience.py b/src/amp/streaming/resilience.py new file mode 100644 index 0000000..dcb4b24 --- /dev/null +++ b/src/amp/streaming/resilience.py @@ -0,0 +1,177 @@ +""" +Resilience primitives for production-grade streaming. + +Provides retry logic, circuit breaker pattern, and adaptive back pressure +to handle transient failures, rate limiting, and service outages gracefully. +""" + +import logging +import random +import threading +import time +from dataclasses import dataclass +from typing import Optional + +logger = logging.getLogger(__name__) + + +@dataclass +class RetryConfig: + """Configuration for retry behavior with exponential backoff.""" + + enabled: bool = True + max_retries: int = 5 # More generous default for production durability + initial_backoff_ms: int = 2000 # Start with 2s delay + max_backoff_ms: int = 120000 # Cap at 2 minutes + backoff_multiplier: float = 2.0 + jitter: bool = True # Add randomness to prevent thundering herd + + +@dataclass +class BackPressureConfig: + """Configuration for adaptive back pressure / rate limiting.""" + + enabled: bool = True + initial_delay_ms: int = 0 + max_delay_ms: int = 5000 + adapt_on_429: bool = True # Slow down on rate limit responses + adapt_on_timeout: bool = True # Slow down on timeouts + recovery_factor: float = 0.9 # How fast to speed up after success (10% speedup) + + +class ErrorClassifier: + """Classify errors as transient (retryable) or permanent (fatal).""" + + TRANSIENT_PATTERNS = [ + 'timeout', + '429', + '503', + '504', + 'connection reset', + 'temporary failure', + 'service unavailable', + 'too many requests', + 'rate limit', + 'throttle', + 'connection error', + 'broken pipe', + 'connection refused', + 'timed out', + ] + + @staticmethod + def is_transient(error: str) -> bool: + """ + Determine if an error is transient and worth retrying. + + Args: + error: Error message or exception string + + Returns: + True if error appears transient, False if permanent + """ + if not error: + return False + + error_lower = error.lower() + return any(pattern in error_lower for pattern in ErrorClassifier.TRANSIENT_PATTERNS) + + +class ExponentialBackoff: + """ + Calculate exponential backoff delays with optional jitter. + + Jitter helps prevent thundering herd when many clients retry simultaneously. 
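+
+    Illustrative sketch of the schedule produced by the defaults above (jitter
+    disabled here so the numbers are deterministic):
+
+        >>> backoff = ExponentialBackoff(RetryConfig(jitter=False))
+        >>> [backoff.next_delay() for _ in range(6)]
+        [2.0, 4.0, 8.0, 16.0, 32.0, None]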
+ """ + + def __init__(self, config: RetryConfig): + self.config = config + self.attempt = 0 + + def next_delay(self) -> Optional[float]: + """ + Calculate next backoff delay in seconds. + + Returns: + Delay in seconds, or None if max retries exceeded + """ + if self.attempt >= self.config.max_retries: + return None + + # Exponential backoff: initial * (multiplier ^ attempt) + delay_ms = min( + self.config.initial_backoff_ms * (self.config.backoff_multiplier**self.attempt), + self.config.max_backoff_ms, + ) + + # Add jitter: randomize to 50-150% of calculated delay + if self.config.jitter: + delay_ms *= 0.5 + random.random() + + self.attempt += 1 + return delay_ms / 1000.0 + + def reset(self): + """Reset backoff state for new operation.""" + self.attempt = 0 + + +class AdaptiveRateLimiter: + """ + Adaptive rate limiting that adjusts delay based on error responses. + + Slows down when seeing rate limits (429) or timeouts. + Speeds up gradually when operations succeed. + """ + + def __init__(self, config: BackPressureConfig): + self.config = config + self.current_delay_ms = config.initial_delay_ms + self._lock = threading.Lock() + + def wait(self): + """Wait before next request (applies current delay).""" + if not self.config.enabled: + return + + delay_ms = self.current_delay_ms + if delay_ms > 0: + time.sleep(delay_ms / 1000.0) + + def record_success(self): + """Speed up gradually after a successful operation.""" + if not self.config.enabled: + return + + with self._lock: + # Speed up by recovery_factor (e.g., 10% faster per success) + # Can decrease all the way to zero - only delay when actually needed + self.current_delay_ms = max(0, self.current_delay_ms * self.config.recovery_factor) + + def record_rate_limit(self): + """Slow down significantly after rate limit response (429).""" + if not self.config.enabled or not self.config.adapt_on_429: + return + + with self._lock: + # Double the delay + 1 second penalty + self.current_delay_ms = min(self.current_delay_ms * 2 + 1000, self.config.max_delay_ms) + + logger.warning( + f'Rate limit detected (429). Adaptive back pressure increased delay to {self.current_delay_ms}ms.' + ) + + def record_timeout(self): + """Slow down moderately after timeout.""" + if not self.config.enabled or not self.config.adapt_on_timeout: + return + + with self._lock: + # 1.5x the delay + 500ms penalty + self.current_delay_ms = min(self.current_delay_ms * 1.5 + 500, self.config.max_delay_ms) + + logger.info(f'Timeout detected. Adaptive back pressure increased delay to {self.current_delay_ms}ms.') + + def get_current_delay(self) -> int: + """Get current delay in milliseconds (for monitoring).""" + return int(self.current_delay_ms) diff --git a/src/amp/streaming/state.py b/src/amp/streaming/state.py new file mode 100644 index 0000000..21d0a05 --- /dev/null +++ b/src/amp/streaming/state.py @@ -0,0 +1,475 @@ +""" +Unified stream state management for amp. + +This module replaces the separate checkpoint and processed_ranges systems with a +single unified mechanism that provides both resumability and idempotency. +""" + +import hashlib +import json +from abc import ABC, abstractmethod +from dataclasses import dataclass, field +from datetime import UTC, datetime +from typing import Dict, List, Optional, Set, Tuple + +from amp.streaming.types import BlockRange, ResumeWatermark + + +@dataclass(frozen=True, eq=True) +class BatchIdentifier: + """ + Unique identifier for a microbatch based on its block range and chain state. 
+ + This serves as the atomic unit of processing across the entire system: + - Used for idempotency checks (prevent duplicate processing) + - Stored as metadata in data tables (enable fast invalidation) + - Tracked in state store (for resume position calculation) + + The unique_id is a hash of the block range + block hashes, making it unique + across blockchain reorganizations (same range, different hash = different batch). + """ + + network: str + start_block: int + end_block: int + end_hash: str # Hash of the end block (required for uniqueness) + start_parent_hash: str = "" # Hash of block before start (optional for chain validation) + + @property + def unique_id(self) -> str: + """ + Generate a 16-character hex string as unique identifier. + + Uses SHA256 hash of canonical representation to ensure: + - Deterministic (same input always produces same ID) + - Collision-resistant (extremely unlikely to have duplicates) + - Compact (16 hex chars = 64 bits, suitable for indexing) + """ + canonical = ( + f"{self.network}:" + f"{self.start_block}:{self.end_block}:" + f"{self.end_hash}:{self.start_parent_hash}" + ) + return hashlib.sha256(canonical.encode()).hexdigest()[:16] + + @property + def position_key(self) -> Tuple[str, int, int]: + """Position-based key for range queries (network, start, end).""" + return (self.network, self.start_block, self.end_block) + + @classmethod + def from_block_range(cls, br: BlockRange) -> "BatchIdentifier": + """ + Create BatchIdentifier from a BlockRange metadata object. + + Supports two modes: + 1. Hash-based IDs: When BlockRange has server-provided block hash (streaming with reorg detection) + 2. Position-based IDs: When BlockRange lacks hash (parallel loads from regular queries) + + Both produce compact 16-char hex IDs, but position-based IDs are derived from + block range coordinates only, making them suitable for immutable historical data. + """ + if br.hash: + # Hash-based ID: Include server-provided block hash for reorg detection + end_hash = br.hash + else: + # Position-based ID: Generate synthetic hash from block range coordinates + # This provides same compact format without requiring server-provided hashes + import hashlib + canonical = f"{br.network}:{br.start}:{br.end}" + end_hash = hashlib.sha256(canonical.encode('utf-8')).hexdigest() + + return cls( + network=br.network, + start_block=br.start, + end_block=br.end, + end_hash=end_hash, + start_parent_hash=br.prev_hash or "", + ) + + def to_block_range(self) -> BlockRange: + """Convert back to BlockRange for server communication.""" + return BlockRange( + network=self.network, + start=self.start_block, + end=self.end_block, + hash=self.end_hash, + prev_hash=self.start_parent_hash or None, + ) + + def overlaps_or_after(self, from_block: int) -> bool: + """Check if this batch overlaps or comes after a given block number.""" + return self.end_block >= from_block + + +@dataclass +class ProcessedBatch: + """ + Record of a successfully processed batch with full metadata. + + This is the persistence format used by database-backed StreamStateStore + implementations. The in-memory store just uses BatchIdentifier directly. 
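+
+    Illustrative sketch (network, block numbers, and hash are made up); a
+    to_dict()/from_dict() round trip preserves the batch identity:
+
+        >>> bid = BatchIdentifier(network='ethereum', start_block=100, end_block=199, end_hash='0xabc')
+        >>> record = ProcessedBatch(batch_id=bid)
+        >>> ProcessedBatch.from_dict(record.to_dict()).batch_id == bid
+        True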
+ """ + + batch_id: BatchIdentifier + processed_at: datetime = field(default_factory=lambda: datetime.now(UTC)) + reorg_invalidation: bool = False # Marks batches deleted due to reorg + + def to_dict(self) -> dict: + """Serialize for database storage.""" + return { + "network": self.batch_id.network, + "start_block": self.batch_id.start_block, + "end_block": self.batch_id.end_block, + "end_hash": self.batch_id.end_hash, + "start_parent_hash": self.batch_id.start_parent_hash, + "unique_id": self.batch_id.unique_id, + "processed_at": self.processed_at.isoformat(), + "reorg_invalidation": self.reorg_invalidation, + } + + @classmethod + def from_dict(cls, data: dict) -> "ProcessedBatch": + """Deserialize from database storage.""" + batch_id = BatchIdentifier( + network=data["network"], + start_block=data["start_block"], + end_block=data["end_block"], + end_hash=data["end_hash"], + start_parent_hash=data.get("start_parent_hash", ""), + ) + return cls( + batch_id=batch_id, + processed_at=datetime.fromisoformat(data["processed_at"]), + reorg_invalidation=data.get("reorg_invalidation", False), + ) + + +class StreamStateStore(ABC): + """ + Abstract base class for unified stream state management. + + Replaces both CheckpointStore and ProcessedRangesStore with a single + mechanism that provides: + - Idempotency: Check if batches were already processed + - Resumability: Calculate resume position from processed batches + - Reorg handling: Invalidate batches affected by chain reorganizations + """ + + @abstractmethod + def is_processed( + self, connection_name: str, table_name: str, batch_ids: List[BatchIdentifier] + ) -> bool: + """ + Check if all given batches have already been processed. + + Used for idempotency - prevents duplicate processing of the same data. + Returns True only if ALL batches in the list are already processed. + """ + pass + + @abstractmethod + def mark_processed( + self, connection_name: str, table_name: str, batch_ids: List[BatchIdentifier] + ) -> None: + """ + Mark the given batches as successfully processed. + + Called after data has been committed to the target system. + """ + pass + + @abstractmethod + def get_resume_position( + self, connection_name: str, table_name: str, detect_gaps: bool = False + ) -> Optional[ResumeWatermark]: + """ + Calculate the resume position from processed batches. + + Args: + connection_name: Connection identifier + table_name: Destination table name + detect_gaps: If True, detect and return gaps in processed ranges. + If False, only return max processed position per network. + + Returns: + ResumeWatermark with ranges. When detect_gaps=True: + - Gap ranges: BlockRange(network, gap_start, gap_end, hash=None) + - Remaining range markers: BlockRange(network, max_block+1, max_block+1, hash=end_hash) + (start==end signals "process from here to max_block in config") + + When detect_gaps=False: + - Returns only the maximum processed block for each network + """ + pass + + @abstractmethod + def invalidate_from_block( + self, connection_name: str, table_name: str, network: str, from_block: int + ) -> List[BatchIdentifier]: + """ + Invalidate batches affected by a blockchain reorganization. + + Removes all batches for the given network where end_block >= from_block. + Returns the list of invalidated batch IDs for use in deleting data. + """ + pass + + @abstractmethod + def cleanup_before_block( + self, connection_name: str, table_name: str, network: str, before_block: int + ) -> None: + """ + Remove old batch records before a given block number. 
+ + Used for TTL-based cleanup to prevent unbounded state growth. + """ + pass + + +class InMemoryStreamStateStore(StreamStateStore): + """ + In-memory implementation of StreamStateStore. + + This is the default implementation that works immediately without any + database dependencies. State is lost on process restart, but provides + idempotency within a single session. + + Loaders can optionally implement persistent versions that survive restarts. + """ + + def __init__(self): + # Key: (connection_name, table_name, network) + # Value: Set of BatchIdentifier objects + self._state: Dict[Tuple[str, str, str], Set[BatchIdentifier]] = {} + + def _get_key( + self, connection_name: str, table_name: str, network: Optional[str] = None + ) -> Tuple[str, str, str] | List[Tuple[str, str, str]]: + """Get storage key(s) for the given parameters.""" + if network: + return (connection_name, table_name, network) + else: + # Return all keys for this connection/table across all networks + return [ + k + for k in self._state.keys() + if k[0] == connection_name and k[1] == table_name + ] + + def is_processed( + self, connection_name: str, table_name: str, batch_ids: List[BatchIdentifier] + ) -> bool: + """Check if all batches have been processed.""" + if not batch_ids: + return False + + # Group by network + by_network: Dict[str, List[BatchIdentifier]] = {} + for batch_id in batch_ids: + by_network.setdefault(batch_id.network, []).append(batch_id) + + # Check each network + for network, network_batch_ids in by_network.items(): + key = self._get_key(connection_name, table_name, network) + processed = self._state.get(key, set()) + + # All batches for this network must be in the processed set + for batch_id in network_batch_ids: + if batch_id not in processed: + return False + + return True + + def mark_processed( + self, connection_name: str, table_name: str, batch_ids: List[BatchIdentifier] + ) -> None: + """Mark batches as processed.""" + # Group by network + by_network: Dict[str, List[BatchIdentifier]] = {} + for batch_id in batch_ids: + by_network.setdefault(batch_id.network, []).append(batch_id) + + # Store in sets by network + for network, network_batch_ids in by_network.items(): + key = self._get_key(connection_name, table_name, network) + if key not in self._state: + self._state[key] = set() + + self._state[key].update(network_batch_ids) + + def get_resume_position( + self, connection_name: str, table_name: str, detect_gaps: bool = False + ) -> Optional[ResumeWatermark]: + """ + Calculate resume position from processed batches. 
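+
+        Illustrative sketch (connection/table names, network, and block numbers
+        are made up): with batches 0-99 and 200-299 processed, gap mode reports
+        the hole plus a start == end continuation marker after the highest
+        processed block:
+
+            >>> store = InMemoryStreamStateStore()
+            >>> store.mark_processed('conn', 'tbl', [
+            ...     BatchIdentifier('ethereum', 0, 99, '0xaa'),
+            ...     BatchIdentifier('ethereum', 200, 299, '0xbb'),
+            ... ])
+            >>> wm = store.get_resume_position('conn', 'tbl', detect_gaps=True)
+            >>> [(r.start, r.end) for r in wm.ranges]
+            [(100, 199), (300, 300)]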
+ + Args: + connection_name: Connection identifier + table_name: Destination table name + detect_gaps: If True, detect and return gaps in processed ranges + + Returns: + ResumeWatermark with gap and/or continuation ranges + """ + keys = self._get_key(connection_name, table_name) + if not isinstance(keys, list): + keys = [keys] + + if not detect_gaps: + # Simple mode: Return max processed position per network + return self._get_max_processed_position(keys) + + # Gap-aware mode: Detect gaps and combine with remaining range markers + gaps = self._detect_gaps_in_memory(keys) + max_positions = self._get_max_processed_position(keys) + + if not gaps and not max_positions: + return None + + all_ranges = [] + + # Add gap ranges + all_ranges.extend(gaps) + + # Add remaining range markers (after max processed block, to finish historical catch-up) + if max_positions: + for br in max_positions.ranges: + all_ranges.append( + BlockRange( + network=br.network, + start=br.end + 1, + end=br.end + 1, # Same value = marker for remaining unprocessed range + hash=br.hash, + prev_hash=br.prev_hash + ) + ) + + return ResumeWatermark(ranges=all_ranges) if all_ranges else None + + def _get_max_processed_position(self, keys: List[Tuple[str, str, str]]) -> Optional[ResumeWatermark]: + """Get max processed position for each network (simple mode).""" + # Find max block for each network + max_by_network: Dict[str, BatchIdentifier] = {} + + for key in keys: + network = key[2] + batches = self._state.get(key, set()) + + if batches: + # Find batch with highest end_block for this network + max_batch = max(batches, key=lambda b: b.end_block) + + if ( + network not in max_by_network + or max_batch.end_block > max_by_network[network].end_block + ): + max_by_network[network] = max_batch + + if not max_by_network: + return None + + # Convert to BlockRange list for ResumeWatermark + ranges = [batch_id.to_block_range() for batch_id in max_by_network.values()] + return ResumeWatermark(ranges=ranges) + + def _detect_gaps_in_memory(self, keys: List[Tuple[str, str, str]]) -> List[BlockRange]: + """Detect gaps in processed ranges using in-memory analysis.""" + gaps = [] + + for key in keys: + network = key[2] + batches = self._state.get(key, set()) + + if not batches: + continue + + # Sort batches by end_block + sorted_batches = sorted(batches, key=lambda b: b.end_block) + + # Find gaps between consecutive batches + for i in range(len(sorted_batches) - 1): + current_batch = sorted_batches[i] + next_batch = sorted_batches[i + 1] + + # Gap exists if next batch doesn't start immediately after current + if next_batch.start_block > current_batch.end_block + 1: + gaps.append( + BlockRange( + network=network, + start=current_batch.end_block + 1, + end=next_batch.start_block - 1, + hash=None, # Position-based for gaps + prev_hash=None + ) + ) + + return gaps + + def invalidate_from_block( + self, connection_name: str, table_name: str, network: str, from_block: int + ) -> List[BatchIdentifier]: + """Invalidate batches affected by reorg.""" + key = self._get_key(connection_name, table_name, network) + batches = self._state.get(key, set()) + + # Find batches that overlap or come after the reorg point + affected = [b for b in batches if b.overlaps_or_after(from_block)] + + # Remove from state + if affected: + self._state[key] = batches - set(affected) + + return affected + + def cleanup_before_block( + self, connection_name: str, table_name: str, network: str, before_block: int + ) -> None: + """Remove old batches before a given block.""" + key = 
self._get_key(connection_name, table_name, network) + batches = self._state.get(key, set()) + + # Keep only batches that end at or after the cutoff + kept = {b for b in batches if b.end_block >= before_block} + + if kept != batches: + self._state[key] = kept + + +class NullStreamStateStore(StreamStateStore): + """ + No-op implementation that disables state tracking. + + Used when state management is disabled entirely. All operations are no-ops, + providing no resumability or idempotency guarantees. + """ + + def is_processed( + self, connection_name: str, table_name: str, batch_ids: List[BatchIdentifier] + ) -> bool: + """Always return False (never skip processing).""" + return False + + def mark_processed( + self, connection_name: str, table_name: str, batch_ids: List[BatchIdentifier] + ) -> None: + """No-op.""" + pass + + def get_resume_position( + self, connection_name: str, table_name: str, detect_gaps: bool = False + ) -> Optional[ResumeWatermark]: + """Always return None (no resume position available).""" + return None + + def invalidate_from_block( + self, connection_name: str, table_name: str, network: str, from_block: int + ) -> List[BatchIdentifier]: + """Return empty list (nothing to invalidate).""" + return [] + + def cleanup_before_block( + self, connection_name: str, table_name: str, network: str, before_block: int + ) -> None: + """No-op.""" + pass diff --git a/src/amp/streaming/types.py b/src/amp/streaming/types.py index 1067a74..18c1074 100644 --- a/src/amp/streaming/types.py +++ b/src/amp/streaming/types.py @@ -17,6 +17,8 @@ class BlockRange: network: str start: int end: int + hash: Optional[str] = None # Block hash from server (for end block) + prev_hash: Optional[str] = None # Previous block hash (for chain validation) def __post_init__(self): if self.start > self.end: @@ -40,16 +42,55 @@ def merge_with(self, other: 'BlockRange') -> 'BlockRange': """Merge with another range on the same network""" if self.network != other.network: raise ValueError(f'Cannot merge ranges from different networks: {self.network} vs {other.network}') - return BlockRange(network=self.network, start=min(self.start, other.start), end=max(self.end, other.end)) + return BlockRange( + network=self.network, + start=min(self.start, other.start), + end=max(self.end, other.end), + hash=other.hash if other.end > self.end else self.hash, + prev_hash=self.prev_hash, # Keep original prev_hash + ) @classmethod def from_dict(cls, data: Dict[str, Any]) -> 'BlockRange': - """Create BlockRange from dictionary""" - return cls(network=data['network'], start=data['start'], end=data['end']) + """Create BlockRange from dictionary (supports both server and client formats) + + The server sends ranges with nested numbers: {"numbers": {"start": X, "end": Y}, ...} + But our to_dict() outputs flat format: {"start": X, "end": Y, ...} for simplicity. 
+ + Both formats must be supported because: + - Server → Client: Uses nested "numbers" format (confirmed 2025-10-23) + - Client → Storage: Uses flat format for checkpoints, watermarks, internal state + - Backward compatibility: Existing stored state uses flat format + """ + # Server format: {"numbers": {"start": X, "end": Y}, "network": ..., "hash": ..., "prev_hash": ...} + if 'numbers' in data: + numbers = data['numbers'] + return cls( + network=data['network'], + start=numbers.get('start') if isinstance(numbers, dict) else numbers['start'], + end=numbers.get('end') if isinstance(numbers, dict) else numbers['end'], + hash=data.get('hash'), + prev_hash=data.get('prev_hash'), + ) + else: + # Client/internal format: {"network": ..., "start": ..., "end": ...} + # Used by to_dict(), checkpoints, watermarks, and stored state + return cls( + network=data['network'], + start=data['start'], + end=data['end'], + hash=data.get('hash'), + prev_hash=data.get('prev_hash'), + ) def to_dict(self) -> Dict[str, Any]: - """Convert to dictionary""" - return {'network': self.network, 'start': self.start, 'end': self.end} + """Convert to dictionary (client format for simplicity)""" + result = {'network': self.network, 'start': self.start, 'end': self.end} + if self.hash is not None: + result['hash'] = self.hash + if self.prev_hash is not None: + result['prev_hash'] = self.prev_hash + return result @dataclass @@ -57,7 +98,7 @@ class BatchMetadata: """Metadata associated with a response batch""" ranges: List[BlockRange] - # Additional metadata fields can be added here + ranges_complete: bool = False # Marks safe checkpoint boundaries extra: Optional[Dict[str, Any]] = None @classmethod @@ -70,20 +111,30 @@ def from_flight_data(cls, metadata_bytes: bytes) -> 'BatchMetadata': else: metadata_str = metadata_bytes.decode('utf-8') metadata_dict = json.loads(metadata_str) + + # Parse block ranges ranges = [BlockRange.from_dict(r) for r in metadata_dict.get('ranges', [])] - extra = {k: v for k, v in metadata_dict.items() if k != 'ranges'} - return cls(ranges=ranges, extra=extra if extra else None) + + # Extract ranges_complete flag (server sends this at microbatch boundaries) + ranges_complete = metadata_dict.get('ranges_complete', False) + + # Store remaining fields in extra + extra = {k: v for k, v in metadata_dict.items() if k not in ('ranges', 'ranges_complete')} + + return cls(ranges=ranges, ranges_complete=ranges_complete, extra=extra if extra else None) except (json.JSONDecodeError, KeyError) as e: # Fallback to empty metadata if parsing fails - return cls(ranges=[], extra={'parse_error': str(e)}) + return cls(ranges=[], ranges_complete=False, extra={'parse_error': str(e)}) @dataclass class ResponseBatch: - """Response batch containing data and metadata""" + """Response batch containing data and metadata, optionally marking reorg events""" data: pa.RecordBatch metadata: BatchMetadata + is_reorg: bool = False # True if this is a reorg notification + invalidation_ranges: Optional[List[BlockRange]] = None # Ranges invalidated by reorg @property def num_rows(self) -> int: @@ -95,41 +146,23 @@ def networks(self) -> List[str]: """List of networks covered by this batch""" return list(set(r.network for r in self.metadata.ranges)) - -class ResponseBatchType(Enum): - """Type of response batch""" - - DATA = 'data' - REORG = 'reorg' - - -@dataclass -class ResponseBatchWithReorg: - """Response that can be either a data batch or a reorg notification""" - - batch_type: ResponseBatchType - data: Optional[ResponseBatch] = None - 
invalidation_ranges: Optional[List[BlockRange]] = None - - @property - def is_data(self) -> bool: - """True if this is a data batch""" - return self.batch_type == ResponseBatchType.DATA - - @property - def is_reorg(self) -> bool: - """True if this is a reorg notification""" - return self.batch_type == ResponseBatchType.REORG - @classmethod - def data_batch(cls, batch: ResponseBatch) -> 'ResponseBatchWithReorg': + def data_batch(cls, data: pa.RecordBatch, metadata: BatchMetadata) -> 'ResponseBatch': """Create a data batch response""" - return cls(batch_type=ResponseBatchType.DATA, data=batch) + return cls(data=data, metadata=metadata, is_reorg=False) @classmethod - def reorg_batch(cls, invalidation_ranges: List[BlockRange]) -> 'ResponseBatchWithReorg': - """Create a reorg notification response""" - return cls(batch_type=ResponseBatchType.REORG, invalidation_ranges=invalidation_ranges) + def reorg_batch(cls, invalidation_ranges: List[BlockRange]) -> 'ResponseBatch': + """Create a reorg notification response (with empty data)""" + # Create empty batch for reorg notifications + empty_batch = pa.record_batch([], schema=pa.schema([])) + empty_metadata = BatchMetadata(ranges=[]) + return cls( + data=empty_batch, + metadata=empty_metadata, + is_reorg=True, + invalidation_ranges=invalidation_ranges + ) @dataclass