From 622a7265d5fb25933e690a063903e8be64fe8809 Mon Sep 17 00:00:00 2001 From: vladvildanov Date: Thu, 4 Dec 2025 15:49:17 +0200 Subject: [PATCH 1/2] Added intrastructure and integration point with OTel --- redis/observability/__init__.py | 0 redis/observability/attributes.py | 296 ++++++++++ redis/observability/config.py | 169 ++++++ redis/observability/metrics.py | 506 +++++++++++++++++ redis/observability/providers.py | 333 +++++++++++ redis/observability/recorder.py | 505 +++++++++++++++++ tests/test_observability/__init__.py | 0 tests/test_observability/test_config.py | 348 ++++++++++++ tests/test_observability/test_provider.py | 457 +++++++++++++++ tests/test_observability/test_recorder.py | 649 ++++++++++++++++++++++ 10 files changed, 3263 insertions(+) create mode 100644 redis/observability/__init__.py create mode 100644 redis/observability/attributes.py create mode 100644 redis/observability/config.py create mode 100644 redis/observability/metrics.py create mode 100644 redis/observability/providers.py create mode 100644 redis/observability/recorder.py create mode 100644 tests/test_observability/__init__.py create mode 100644 tests/test_observability/test_config.py create mode 100644 tests/test_observability/test_provider.py create mode 100644 tests/test_observability/test_recorder.py diff --git a/redis/observability/__init__.py b/redis/observability/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/redis/observability/attributes.py b/redis/observability/attributes.py new file mode 100644 index 0000000000..095855d384 --- /dev/null +++ b/redis/observability/attributes.py @@ -0,0 +1,296 @@ +""" +OpenTelemetry semantic convention attributes for Redis. + +This module provides constants and helper functions for building OTel attributes +according to the semantic conventions for database clients. 
# Database semantic convention attribute keys.
# Reference: https://opentelemetry.io/docs/specs/semconv/database/
DB_SYSTEM = "db.system"
DB_NAMESPACE = "db.namespace"
DB_OPERATION_NAME = "db.operation.name"
DB_OPERATION_BATCH_SIZE = "db.operation.batch.size"
DB_RESPONSE_STATUS_CODE = "db.response.status_code"
DB_STORED_PROCEDURE_NAME = "db.stored_procedure.name"

# Error attributes
ERROR_TYPE = "error.type"

# Network attributes
NETWORK_PEER_ADDRESS = "network.peer.address"
NETWORK_PEER_PORT = "network.peer.port"

# Server attributes
SERVER_ADDRESS = "server.address"
SERVER_PORT = "server.port"

# Connection pool attributes
DB_CLIENT_CONNECTION_POOL_NAME = "db.client.connection.pool.name"
DB_CLIENT_CONNECTION_STATE = "db.client.connection.state"

# Redis-specific attributes
REDIS_CLIENT_LIBRARY = "redis.client.library"
REDIS_CLIENT_CONNECTION_PUBSUB = "redis.client.connection.pubsub"
REDIS_CLIENT_CONNECTION_CLOSE_REASON = "redis.client.connection.close.reason"
REDIS_CLIENT_CONNECTION_NOTIFICATION = "redis.client.connection.notification"
REDIS_CLIENT_OPERATION_RETRY_ATTEMPTS = "redis.client.operation.retry_attempts"
REDIS_CLIENT_OPERATION_BLOCKING = "redis.client.operation.blocking"
REDIS_CLIENT_PUBSUB_MESSAGE_DIRECTION = "redis.client.pubsub.message.direction"
REDIS_CLIENT_PUBSUB_CHANNEL = "redis.client.pubsub.channel"
REDIS_CLIENT_PUBSUB_SHARDED = "redis.client.pubsub.sharded"
REDIS_CLIENT_ERROR_INTERNAL = "redis.client.errors.internal"
REDIS_CLIENT_STREAM_NAME = "redis.client.stream.name"
REDIS_CLIENT_CONSUMER_GROUP = "redis.client.consumer_group"
REDIS_CLIENT_CONSUMER_NAME = "redis.client.consumer_name"


class ConnectionState(Enum):
    """State of a pooled connection (value is the semconv attribute value)."""
    IDLE = "idle"
    USED = "used"


class PubSubDirection(Enum):
    """Direction of a Pub/Sub message (value is the semconv attribute value)."""
    PUBLISH = "publish"
    RECEIVE = "receive"


class AttributeBuilder:
    """
    Helper class to build OTel semantic convention attributes for Redis
    operations.

    All builders return plain dictionaries so callers can merge them with
    ``dict.update`` before handing them to a metric instrument.
    """

    @staticmethod
    def build_base_attributes(
        server_address: Optional[str] = None,
        server_port: Optional[int] = None,
        db_namespace: Optional[int] = None,
    ) -> Dict[str, Any]:
        """
        Build base attributes common to all Redis operations.

        Args:
            server_address: Redis server address (FQDN or IP)
            server_port: Redis server port
            db_namespace: Redis database index

        Returns:
            Dictionary of base attributes; always includes ``db.system`` and
            ``redis.client.library``, optional fields only when provided.
        """
        attrs: Dict[str, Any] = {
            DB_SYSTEM: "redis",
            REDIS_CLIENT_LIBRARY: "redis-py",
        }

        if server_address is not None:
            attrs[SERVER_ADDRESS] = server_address

        if server_port is not None:
            attrs[SERVER_PORT] = server_port

        if db_namespace is not None:
            # db.namespace is a string attribute per semconv, so stringify the index.
            attrs[DB_NAMESPACE] = str(db_namespace)

        return attrs

    @staticmethod
    def build_operation_attributes(
        command_name: Optional[str] = None,
        batch_size: Optional[int] = None,
        response_status_code: Optional[str] = None,
        error_type: Optional[Exception] = None,
        network_peer_address: Optional[str] = None,
        network_peer_port: Optional[int] = None,
        stored_procedure_name: Optional[str] = None,
        retry_attempts: Optional[int] = None,
        is_blocking: Optional[bool] = None,
    ) -> Dict[str, Any]:
        """
        Build attributes for a Redis operation (command execution).

        Args:
            command_name: Redis command name (e.g., 'GET', 'SET', 'MULTI')
            batch_size: Number of commands in batch (for pipelines/transactions)
            response_status_code: Redis error prefix (e.g., 'ERR', 'WRONGTYPE')
            error_type: Error type if operation failed
            network_peer_address: Resolved peer address
            network_peer_port: Peer port number
            stored_procedure_name: Lua script name or SHA1 digest
            retry_attempts: Number of retry attempts made
            is_blocking: Whether the operation is a blocking command

        Returns:
            Dictionary of operation attributes
        """
        attrs: Dict[str, Any] = {}

        if command_name is not None:
            # Commands are normalized to upper case to keep cardinality low.
            attrs[DB_OPERATION_NAME] = command_name.upper()

        # Per semconv, batch size is only reported for actual batches (>= 2).
        if batch_size is not None and batch_size >= 2:
            attrs[DB_OPERATION_BATCH_SIZE] = batch_size

        if response_status_code is not None:
            attrs[DB_RESPONSE_STATUS_CODE] = response_status_code

        if error_type is not None:
            attrs[ERROR_TYPE] = AttributeBuilder.extract_error_type(error_type)

        if network_peer_address is not None:
            attrs[NETWORK_PEER_ADDRESS] = network_peer_address

        if network_peer_port is not None:
            attrs[NETWORK_PEER_PORT] = network_peer_port

        if stored_procedure_name is not None:
            attrs[DB_STORED_PROCEDURE_NAME] = stored_procedure_name

        # Only report retries that actually happened.
        if retry_attempts is not None and retry_attempts > 0:
            attrs[REDIS_CLIENT_OPERATION_RETRY_ATTEMPTS] = retry_attempts

        if is_blocking is not None:
            attrs[REDIS_CLIENT_OPERATION_BLOCKING] = is_blocking

        return attrs

    @staticmethod
    def build_connection_pool_attributes(
        pool_name: str,
        connection_state: Optional[ConnectionState] = None,
        is_pubsub: Optional[bool] = None,
    ) -> Dict[str, Any]:
        """
        Build attributes for connection pool metrics.

        Args:
            pool_name: Unique connection pool name
            connection_state: Connection state ('idle' or 'used')
            is_pubsub: Whether this is a PubSub connection

        Returns:
            Dictionary of connection pool attributes (base attributes included)
        """
        attrs: Dict[str, Any] = AttributeBuilder.build_base_attributes()
        attrs[DB_CLIENT_CONNECTION_POOL_NAME] = pool_name

        if connection_state is not None:
            attrs[DB_CLIENT_CONNECTION_STATE] = connection_state.value

        if is_pubsub is not None:
            attrs[REDIS_CLIENT_CONNECTION_PUBSUB] = is_pubsub

        return attrs

    @staticmethod
    def build_error_attributes(
        is_internal: bool = False,
        error_type: Optional[Exception] = None,
    ) -> Dict[str, Any]:
        """
        Build error attributes.

        Args:
            is_internal: Whether the error is internal (e.g., timeout, network error)
            error_type: The exception that occurred

        Returns:
            Dictionary of error attributes
        """
        attrs: Dict[str, Any] = {REDIS_CLIENT_ERROR_INTERNAL: is_internal}

        if error_type is not None:
            # Fix: record the exception class under error.type. The previous
            # implementation set db.response.status_code to None, which both
            # dropped the error type and clobbered any real status code when
            # merged into operation attributes via dict.update().
            attrs[ERROR_TYPE] = AttributeBuilder.extract_error_type(error_type)

        return attrs

    @staticmethod
    def build_pubsub_message_attributes(
        direction: PubSubDirection,
        channel: Optional[str] = None,
        sharded: Optional[bool] = None,
    ) -> Dict[str, Any]:
        """
        Build attributes for a PubSub message.

        Args:
            direction: Message direction ('publish' or 'receive')
            channel: Pub/Sub channel name
            sharded: True if sharded Pub/Sub channel

        Returns:
            Dictionary of PubSub message attributes (base attributes included)
        """
        attrs: Dict[str, Any] = AttributeBuilder.build_base_attributes()
        attrs[REDIS_CLIENT_PUBSUB_MESSAGE_DIRECTION] = direction.value

        if channel is not None:
            attrs[REDIS_CLIENT_PUBSUB_CHANNEL] = channel

        if sharded is not None:
            attrs[REDIS_CLIENT_PUBSUB_SHARDED] = sharded

        return attrs

    @staticmethod
    def build_streaming_attributes(
        stream_name: Optional[str] = None,
        consumer_group: Optional[str] = None,
        consumer_name: Optional[str] = None,
    ) -> Dict[str, Any]:
        """
        Build attributes for a streaming operation.

        Args:
            stream_name: Name of the stream
            consumer_group: Name of the consumer group
            consumer_name: Name of the consumer

        Returns:
            Dictionary of streaming attributes (base attributes included)
        """
        attrs: Dict[str, Any] = AttributeBuilder.build_base_attributes()

        if stream_name is not None:
            attrs[REDIS_CLIENT_STREAM_NAME] = stream_name

        if consumer_group is not None:
            attrs[REDIS_CLIENT_CONSUMER_GROUP] = consumer_group

        if consumer_name is not None:
            attrs[REDIS_CLIENT_CONSUMER_NAME] = consumer_name

        return attrs

    @staticmethod
    def extract_error_type(exception: Exception) -> str:
        """
        Extract error type from an exception.

        Args:
            exception: The exception that occurred

        Returns:
            Error type string (exception class name)
        """
        return type(exception).__name__

    @staticmethod
    def build_pool_name(
        server_address: str,
        server_port: int,
        db_namespace: int = 0,
    ) -> str:
        """
        Build a unique connection pool name.

        Args:
            server_address: Redis server address
            server_port: Redis server port
            db_namespace: Redis database index

        Returns:
            Unique pool name in format "address:port/db"
        """
        return f"{server_address}:{server_port}/{db_namespace}"
class MetricGroup(IntFlag):
    """Metric groups that can be enabled/disabled."""
    RESILIENCY = auto()
    CONNECTION_BASIC = auto()
    CONNECTION_ADVANCED = auto()
    COMMAND = auto()
    CSC = auto()
    STREAMING = auto()
    PUBSUB = auto()


class TelemetryOption(IntFlag):
    """Telemetry options to export."""
    METRICS = auto()
    TRACES = auto()
    LOGS = auto()


class OTelConfig:
    """
    Configuration for OpenTelemetry observability in redis-py.

    This class manages all OTel-related settings including metrics, traces
    (future), and logs (future).

    Args:
        enabled_telemetry: Enabled telemetry options to export
            (default: metrics). Traces and logs will be added in future phases.
        metrics_sample_percentage: Percentage of commands to sample
            (default: 100.0, range: 0.0-100.0)
        metric_groups: Groups of metrics that should be exported
            (default: COMMAND | CONNECTION_BASIC | RESILIENCY).
        include_commands: Explicit allowlist of commands to track
        exclude_commands: Blocklist of commands to track

    Raises:
        ValueError: If ``metrics_sample_percentage`` is outside 0.0-100.0.

    Note:
        Redis-py uses the global MeterProvider set by your application; set it
        up (including any Views for histogram bucket boundaries, e.g. for
        ``db.client.operation.duration``) before initializing observability:

            from opentelemetry import metrics
            from opentelemetry.sdk.metrics import MeterProvider

            provider = MeterProvider(views=views, metric_readers=[reader])
            metrics.set_meter_provider(provider)

            from redis.observability import get_observability_instance, OTelConfig
            otel = get_observability_instance()
            otel.init(OTelConfig())
    """

    DEFAULT_TELEMETRY = TelemetryOption.METRICS
    DEFAULT_METRIC_GROUPS = (
        MetricGroup.COMMAND | MetricGroup.CONNECTION_BASIC | MetricGroup.RESILIENCY
    )

    def __init__(
        self,
        # Core enablement
        enabled_telemetry: Optional[List[TelemetryOption]] = None,
        # Metrics-specific
        metrics_sample_percentage: float = 100.0,
        metric_groups: Optional[List[MetricGroup]] = None,
        # Redis-specific telemetry controls
        include_commands: Optional[List[str]] = None,
        exclude_commands: Optional[List[str]] = None,
    ):
        # Fold the list of options into a single flag value (default if None).
        if enabled_telemetry is None:
            self.enabled_telemetry = self.DEFAULT_TELEMETRY
        else:
            combined_telemetry = TelemetryOption(0)
            for option in enabled_telemetry:
                combined_telemetry |= option
            self.enabled_telemetry = combined_telemetry

        # Same folding for metric groups; default groups if None given.
        if metric_groups is None:
            self.metric_groups = self.DEFAULT_METRIC_GROUPS
        else:
            combined_groups = MetricGroup(0)
            for metric_group in metric_groups:
                combined_groups |= metric_group
            self.metric_groups = combined_groups

        # Validates range and assigns; single source of truth for validation.
        self.set_sample_percentage(metrics_sample_percentage)

        # Redis-specific controls: None allowlist means "track everything
        # except the blocklist"; the blocklist defaults to an empty set.
        self.include_commands = set(include_commands) if include_commands else None
        self.exclude_commands = set(exclude_commands) if exclude_commands else set()

    def is_enabled(self) -> bool:
        """Check if any observability feature is enabled."""
        return bool(self.enabled_telemetry)

    def set_sample_percentage(self, percentage: float) -> None:
        """
        Set the metrics sample percentage at runtime.

        This allows dynamic adjustment of sampling rate for high-throughput
        deployments.

        Args:
            percentage: Percentage of commands to sample (0.0-100.0)

        Raises:
            ValueError: If percentage is not in valid range

        Example:
            >>> config.set_sample_percentage(10.0)  # Sample 10% of commands
        """
        if not 0.0 <= percentage <= 100.0:
            raise ValueError(
                f"metrics_sample_percentage must be between 0.0 and 100.0, "
                f"got {percentage}"
            )
        self.metrics_sample_percentage = percentage

    def should_track_command(self, command_name: str) -> bool:
        """
        Determine if a command should be tracked based on include/exclude lists.

        Args:
            command_name: The Redis command name (e.g., 'GET', 'SET')

        Returns:
            True if the command should be tracked, False otherwise
        """
        command_upper = command_name.upper()

        # If include list is specified, only track commands in the list.
        if self.include_commands is not None:
            return command_upper in self.include_commands

        # Otherwise, track all commands except those in exclude list.
        return command_upper not in self.exclude_commands

    def __repr__(self) -> str:
        # Fix: the previous implementation returned a truncated string with a
        # dangling comma and no closing parenthesis.
        return (
            f"OTelConfig(enabled_telemetry={self.enabled_telemetry!r}, "
            f"metric_groups={self.metric_groups!r}, "
            f"metrics_sample_percentage={self.metrics_sample_percentage}, "
            f"include_commands={self.include_commands}, "
            f"exclude_commands={self.exclude_commands})"
        )
class RedisMetricsCollector:
    """
    Collects and records OpenTelemetry metrics for Redis operations.

    This class manages all metric instruments and provides methods to record
    various Redis operations including connection pool events, command
    execution, and cluster-specific operations.

    Instruments are created only for the metric groups enabled in ``config``.
    NOTE(review): calling a ``record_*`` method whose metric group was not
    enabled raises AttributeError (the instrument attribute is never created)
    — confirm callers only record for enabled groups.

    Args:
        meter: OpenTelemetry Meter instance
        config: OTel configuration object

    Raises:
        ImportError: If the OpenTelemetry API is not installed.
    """

    METER_NAME = "redis-py"
    METER_VERSION = "1.0.0"

    def __init__(self, meter: "Meter", config: OTelConfig):
        if not OTEL_AVAILABLE:
            raise ImportError(
                "OpenTelemetry API is not installed. "
                "Install it with: pip install opentelemetry-api"
            )

        self.meter = meter
        self.config = config
        self.attr_builder = AttributeBuilder()

        # Initialize instruments only for the enabled metric groups.
        if MetricGroup.RESILIENCY in self.config.metric_groups:
            self._init_resiliency_metrics()

        if MetricGroup.COMMAND in self.config.metric_groups:
            self._init_command_metrics()

        if MetricGroup.CONNECTION_BASIC in self.config.metric_groups:
            self._init_connection_basic_metrics()

        if MetricGroup.CONNECTION_ADVANCED in self.config.metric_groups:
            self._init_connection_advanced_metrics()

        if MetricGroup.PUBSUB in self.config.metric_groups:
            self._init_pubsub_metrics()

        if MetricGroup.STREAMING in self.config.metric_groups:
            self._init_streaming_metrics()

        logger.info("RedisMetricsCollector initialized")

    def _init_resiliency_metrics(self) -> None:
        """Initialize resiliency metric instruments."""
        self.client_errors = self.meter.create_counter(
            name="redis.client.errors",
            unit="{error}",
            description="A counter of all errors (both returned to the user and handled internally in the client library)",
        )

        self.maintenance_notifications = self.meter.create_counter(
            name="redis.client.maintenance.notifications",
            unit="{notification}",
            description="Tracks server-side maintenance notifications",
        )

    def _init_connection_basic_metrics(self) -> None:
        """Initialize basic connection metric instruments."""
        self.connection_count = self.meter.create_up_down_counter(
            name="db.client.connection.count",
            # Per OTel semconv the annotation is singular: {connection}.
            unit="{connection}",
            description="Current connections by state (idle/used)",
        )

        self.connection_create_time = self.meter.create_histogram(
            name="db.client.connection.create_time",
            # Durations use the UCUM unit "s" per OTel semconv, not "{seconds}".
            unit="s",
            description="Time to create a new connection",
        )

        self.connection_relaxed_timeout = self.meter.create_up_down_counter(
            name="redis.client.connection.relaxed_timeout",
            unit="{relaxation}",
            description="Counts up for relaxed timeout, counts down for unrelaxed timeout",
        )

        self.connection_handoff = self.meter.create_counter(
            name="redis.client.connection.handoff",
            unit="{handoff}",
            description="Connections that have been handed off (e.g., after a MOVING notification)",
        )

    def _init_connection_advanced_metrics(self) -> None:
        """Initialize advanced connection metric instruments."""
        self.connection_timeouts = self.meter.create_counter(
            name="db.client.connection.timeouts",
            unit="{timeout}",
            description="The number of connection timeouts that have occurred trying to obtain a connection from the pool.",
        )

        self.connection_wait_time = self.meter.create_histogram(
            name="db.client.connection.wait_time",
            unit="s",
            description="Time to obtain an open connection from the pool",
        )

        self.connection_use_time = self.meter.create_histogram(
            name="db.client.connection.use_time",
            unit="s",
            description="Time between borrowing and returning a connection",
        )

        self.connection_closed = self.meter.create_counter(
            name="redis.client.connection.closed",
            unit="{connection}",
            description="Total number of closed connections",
        )

    def _init_command_metrics(self) -> None:
        """Initialize command execution metric instruments."""
        self.operation_duration = self.meter.create_histogram(
            name="db.client.operation.duration",
            unit="s",
            description="Command execution duration",
        )

    def _init_pubsub_metrics(self) -> None:
        """Initialize PubSub metric instruments."""
        self.pubsub_messages = self.meter.create_counter(
            name="redis.client.pubsub.messages",
            unit="{message}",
            description="Tracks published and received messages",
        )

    def _init_streaming_metrics(self) -> None:
        """Initialize Streaming metric instruments."""
        self.stream_lag = self.meter.create_histogram(
            name="redis.client.stream.lag",
            unit="s",
            description="End-to-end lag per message, showing how stale are the messages when the application starts processing them.",
        )

    # Resiliency metric recording methods

    def record_error_count(
        self,
        server_address: str,
        server_port: int,
        network_peer_address: str,
        network_peer_port: int,
        error_type: Exception,
        retry_attempts: int,
    ):
        """
        Record error count.

        Args:
            server_address: Server address
            server_port: Server port
            network_peer_address: Network peer address
            network_peer_port: Network peer port
            error_type: Error type
            retry_attempts: Retry attempts
        """
        attrs = self.attr_builder.build_base_attributes(
            server_address=server_address,
            server_port=server_port,
        )
        attrs.update(
            self.attr_builder.build_operation_attributes(
                error_type=error_type,
                network_peer_address=network_peer_address,
                network_peer_port=network_peer_port,
                retry_attempts=retry_attempts,
            )
        )
        attrs.update(
            self.attr_builder.build_error_attributes(
                error_type=error_type,
            )
        )

        self.client_errors.add(1, attributes=attrs)

    def record_maint_notification_count(
        self,
        server_address: str,
        server_port: int,
        network_peer_address: str,
        network_peer_port: int,
        maint_notification: str,
    ):
        """
        Record maintenance notification count.

        Args:
            server_address: Server address
            server_port: Server port
            network_peer_address: Network peer address
            network_peer_port: Network peer port
            maint_notification: Maintenance notification
        """
        attrs = self.attr_builder.build_base_attributes(
            server_address=server_address,
            server_port=server_port,
        )
        attrs.update(
            self.attr_builder.build_operation_attributes(
                network_peer_address=network_peer_address,
                network_peer_port=network_peer_port,
            )
        )
        attrs[REDIS_CLIENT_CONNECTION_NOTIFICATION] = maint_notification

        self.maintenance_notifications.add(1, attributes=attrs)

    # Connection metric recording methods

    def record_connection_count(
        self,
        count: int,
        pool_name: str,
        state: ConnectionState,
        is_pubsub: bool,
    ) -> None:
        """
        Record current connection count by state.

        Args:
            count: Increment/Decrement
            pool_name: Connection pool name
            state: Connection state ('idle' or 'used')
            is_pubsub: Whether or not the connection is pubsub
        """
        attrs = self.attr_builder.build_connection_pool_attributes(
            pool_name=pool_name,
            connection_state=state,
            is_pubsub=is_pubsub,
        )
        self.connection_count.add(count, attributes=attrs)

    def record_connection_timeout(self, pool_name: str) -> None:
        """
        Record a connection timeout event.

        Args:
            pool_name: Connection pool name
        """
        attrs = self.attr_builder.build_connection_pool_attributes(pool_name=pool_name)
        self.connection_timeouts.add(1, attributes=attrs)

    def record_connection_create_time(
        self,
        pool_name: str,
        duration_seconds: float,
    ) -> None:
        """
        Record time taken to create a new connection.

        Args:
            pool_name: Connection pool name
            duration_seconds: Creation time in seconds
        """
        attrs = self.attr_builder.build_connection_pool_attributes(pool_name=pool_name)
        self.connection_create_time.record(duration_seconds, attributes=attrs)

    def record_connection_wait_time(
        self,
        pool_name: str,
        duration_seconds: float,
    ) -> None:
        """
        Record time taken to obtain a connection from the pool.

        Args:
            pool_name: Connection pool name
            duration_seconds: Wait time in seconds
        """
        attrs = self.attr_builder.build_connection_pool_attributes(pool_name=pool_name)
        self.connection_wait_time.record(duration_seconds, attributes=attrs)

    def record_connection_use_time(
        self,
        pool_name: str,
        duration_seconds: float,
    ) -> None:
        """
        Record time a connection was in use (borrowed from pool).

        Args:
            pool_name: Connection pool name
            duration_seconds: Use time in seconds
        """
        attrs = self.attr_builder.build_connection_pool_attributes(pool_name=pool_name)
        self.connection_use_time.record(duration_seconds, attributes=attrs)

    # Command execution metric recording methods

    def record_operation_duration(
        self,
        command_name: str,
        duration_seconds: float,
        server_address: Optional[str] = None,
        server_port: Optional[int] = None,
        db_namespace: Optional[int] = None,
        batch_size: Optional[int] = None,
        response_status_code: Optional[str] = None,
        error_type: Optional[Exception] = None,
        network_peer_address: Optional[str] = None,
        network_peer_port: Optional[int] = None,
        retry_attempts: Optional[int] = None,
        is_blocking: Optional[bool] = None,
    ) -> None:
        """
        Record command execution duration.

        Args:
            command_name: Redis command name (e.g., 'GET', 'SET', 'MULTI')
            duration_seconds: Execution time in seconds
            server_address: Redis server address
            server_port: Redis server port
            db_namespace: Redis database index
            batch_size: Number of commands in batch (for pipelines/transactions)
            response_status_code: Redis error prefix if operation failed
            error_type: Error type if operation failed
            network_peer_address: Resolved peer address
            network_peer_port: Peer port number
            retry_attempts: Number of retry attempts made
            is_blocking: Whether the operation is a blocking command
        """
        # Respect the include/exclude command configuration.
        if not self.config.should_track_command(command_name):
            return

        attrs = self.attr_builder.build_base_attributes(
            server_address=server_address,
            server_port=server_port,
            db_namespace=db_namespace,
        )
        attrs.update(
            self.attr_builder.build_operation_attributes(
                command_name=command_name,
                batch_size=batch_size,
                response_status_code=response_status_code,
                error_type=error_type,
                network_peer_address=network_peer_address,
                network_peer_port=network_peer_port,
                retry_attempts=retry_attempts,
                is_blocking=is_blocking,
            )
        )
        attrs.update(
            self.attr_builder.build_error_attributes(
                error_type=error_type,
            )
        )

        self.operation_duration.record(duration_seconds, attributes=attrs)

    def record_connection_closed(
        self,
        pool_name: str,
        close_reason: Optional[str] = None,
        error_type: Optional[Exception] = None,
    ) -> None:
        """
        Record a connection closed event.

        Args:
            pool_name: Connection pool name
            close_reason: Reason for closing (e.g., 'idle_timeout', 'error', 'shutdown')
            error_type: Error type if closed due to error
        """
        attrs = self.attr_builder.build_connection_pool_attributes(pool_name=pool_name)
        if close_reason:
            attrs[REDIS_CLIENT_CONNECTION_CLOSE_REASON] = close_reason
        if error_type:
            attrs[ERROR_TYPE] = AttributeBuilder.extract_error_type(error_type)
        self.connection_closed.add(1, attributes=attrs)

    def record_connection_relaxed_timeout(
        self,
        pool_name: str,
        maint_notification: str,
        relaxed: bool,
    ) -> None:
        """
        Record a connection timeout relaxation event.

        Args:
            pool_name: Connection pool name
            maint_notification: Maintenance notification type
            relaxed: True to count up (relaxed), False to count down (unrelaxed)
        """
        attrs = self.attr_builder.build_connection_pool_attributes(pool_name=pool_name)
        attrs[REDIS_CLIENT_CONNECTION_NOTIFICATION] = maint_notification
        self.connection_relaxed_timeout.add(1 if relaxed else -1, attributes=attrs)

    def record_connection_handoff(
        self,
        pool_name: str,
    ) -> None:
        """
        Record a connection handoff event (e.g., after MOVING notification).

        Args:
            pool_name: Connection pool name
        """
        attrs = self.attr_builder.build_connection_pool_attributes(pool_name=pool_name)
        self.connection_handoff.add(1, attributes=attrs)

    # PubSub metric recording methods

    def record_pubsub_message(
        self,
        direction: PubSubDirection,
        channel: Optional[str] = None,
        sharded: Optional[bool] = None,
    ) -> None:
        """
        Record a PubSub message (published or received).

        Args:
            direction: Message direction ('publish' or 'receive')
            channel: Pub/Sub channel name
            sharded: True if sharded Pub/Sub channel
        """
        attrs = self.attr_builder.build_pubsub_message_attributes(
            direction=direction,
            channel=channel,
            sharded=sharded,
        )
        self.pubsub_messages.add(1, attributes=attrs)

    # Streaming metric recording methods

    def record_streaming_lag(
        self,
        lag_seconds: float,
        stream_name: Optional[str] = None,
        consumer_group: Optional[str] = None,
        consumer_name: Optional[str] = None,
    ) -> None:
        """
        Record the lag of a streaming message.

        Args:
            lag_seconds: Lag in seconds
            stream_name: Stream name
            consumer_group: Consumer group name
            consumer_name: Consumer name
        """
        attrs = self.attr_builder.build_streaming_attributes(
            stream_name=stream_name,
            consumer_group=consumer_group,
            consumer_name=consumer_name,
        )
        self.stream_lag.record(lag_seconds, attributes=attrs)

    # Utility methods

    @staticmethod
    def monotonic_time() -> float:
        """
        Get monotonic time for duration measurements.

        Returns:
            Current monotonic time in seconds
        """
        return time.monotonic()

    def __repr__(self) -> str:
        return f"RedisMetricsCollector(meter={self.meter}, config={self.config})"
class OTelProviderManager:
    """
    Manages OpenTelemetry SDK providers and their lifecycle.

    This class handles:
    - Getting the global MeterProvider set by the application
    - Graceful shutdown (flush-only; the provider is owned by the application)

    Args:
        config: OTel configuration object
    """

    def __init__(self, config: OTelConfig):
        self.config = config
        # Cached reference to the application's global MeterProvider.
        self._meter_provider: Optional[MeterProvider] = None

    def get_meter_provider(self) -> Optional[MeterProvider]:
        """
        Get the global MeterProvider set by the application.

        Returns:
            MeterProvider instance or None if metrics are disabled

        Raises:
            ImportError: If OpenTelemetry is not installed
            RuntimeError: If metrics are enabled but no global MeterProvider is set
        """
        if not self.config.is_enabled():
            return None

        # Lazy import - only import OTel when metrics are enabled.
        try:
            from opentelemetry import metrics
            from opentelemetry.metrics import NoOpMeterProvider
        except ImportError:
            raise ImportError(
                "OpenTelemetry is not installed. Install it with:\n"
                " pip install opentelemetry-api opentelemetry-sdk opentelemetry-exporter-otlp-proto-http"
            )

        # Fetch and cache the global MeterProvider on first use.
        if self._meter_provider is None:
            self._meter_provider = metrics.get_meter_provider()

            # A NoOp provider means the application never called
            # set_meter_provider(); fail loudly with setup instructions.
            if isinstance(self._meter_provider, NoOpMeterProvider):
                raise RuntimeError(
                    "Metrics are enabled but no global MeterProvider is configured.\n"
                    "\n"
                    "Set up OpenTelemetry before initializing redis-py observability:\n"
                    "\n"
                    " from opentelemetry import metrics\n"
                    " from opentelemetry.sdk.metrics import MeterProvider\n"
                    " from opentelemetry.sdk.metrics.export import PeriodicExportingMetricReader\n"
                    " from opentelemetry.exporter.otlp.proto.http.metric_exporter import OTLPMetricExporter\n"
                    "\n"
                    " # Create exporter\n"
                    " exporter = OTLPMetricExporter(\n"
                    "     endpoint='http://localhost:4318/v1/metrics'\n"
                    " )\n"
                    "\n"
                    " # Create reader\n"
                    " reader = PeriodicExportingMetricReader(\n"
                    "     exporter=exporter,\n"
                    "     export_interval_millis=10000\n"
                    " )\n"
                    "\n"
                    " # Create and set global provider\n"
                    " provider = MeterProvider(metric_readers=[reader])\n"
                    " metrics.set_meter_provider(provider)\n"
                    "\n"
                    " # Now initialize redis-py observability\n"
                    " from redis.observability import get_observability_instance, OTelConfig\n"
                    " otel = get_observability_instance()\n"
                    " otel.init(OTelConfig())\n"
                )

            logger.info("Using global MeterProvider from application")

        return self._meter_provider

    def shutdown(self, timeout_millis: int = 30000) -> bool:
        """
        Shutdown observability and flush any pending metrics.

        Note: We don't shutdown the global MeterProvider since it's owned by
        the application. We only force flush pending metrics.

        Args:
            timeout_millis: Maximum time to wait for flush

        Returns:
            True if flush was successful, False otherwise
        """
        logger.debug("Flushing metrics before shutdown (not shutting down global MeterProvider)")
        return self.force_flush(timeout_millis=timeout_millis)

    def force_flush(self, timeout_millis: int = 30000) -> bool:
        """
        Force flush any pending metrics from the global MeterProvider.

        Args:
            timeout_millis: Maximum time to wait for flush

        Returns:
            True if flush was successful, False otherwise
        """
        if self._meter_provider is None:
            # Nothing was ever exported; trivially successful.
            return True

        try:
            logger.debug("Force flushing metrics from global MeterProvider")
            # Fix: propagate the SDK's flush result instead of always
            # returning True (MeterProvider.force_flush returns a bool).
            return bool(self._meter_provider.force_flush(timeout_millis=timeout_millis))
        except Exception as e:
            logger.error(f"Error flushing metrics: {e}")
            return False

    def __enter__(self):
        """Context manager entry."""
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Context manager exit - flush pending metrics."""
        self.shutdown()

    def __repr__(self) -> str:
        return f"OTelProviderManager(config={self.config})"
+ + Example: + >>> from redis.observability.config import OTelConfig + >>> + >>> # Get singleton instance + >>> otel = get_observability_instance() + >>> + >>> # Initialize once at app startup + >>> otel.init(OTelConfig()) + >>> + >>> # All Redis clients now automatically collect metrics + >>> import redis + >>> r = redis.Redis(host='localhost', port=6379) + >>> r.set('key', 'value') # Metrics collected automatically + """ + + def __init__(self): + self._provider_manager: Optional[OTelProviderManager] = None + + def init(self, config: OTelConfig) -> "ObservabilityInstance": + """ + Initialize OpenTelemetry observability globally for all Redis clients. + + This should be called once at application startup. After initialization, + all Redis clients will automatically collect and export metrics without + needing any additional configuration. + + Safe to call multiple times - will shutdown previous instance before + initializing a new one. + + Args: + config: OTel configuration object + + Returns: + Self for method chaining + + Example: + >>> otel = get_observability_instance() + >>> otel.init(OTelConfig()) + """ + if self._provider_manager is not None: + logger.warning("Observability already initialized. Shutting down previous instance.") + self._provider_manager.shutdown() + + self._provider_manager = OTelProviderManager(config) + + logger.info("Observability initialized") + + return self + + def is_enabled(self) -> bool: + """ + Check if observability is enabled. + + Returns: + True if observability is initialized and metrics are enabled + + Example: + >>> otel = get_observability_instance() + >>> if otel.is_enabled(): + ... print("Metrics are being collected") + """ + return self._provider_manager is not None and self._provider_manager.config.is_enabled() + + def get_provider_manager(self) -> Optional[OTelProviderManager]: + """ + Get the provider manager instance. 
+ + Returns: + The provider manager, or None if not initialized + + Example: + >>> otel = get_observability_instance() + >>> manager = otel.get_provider_manager() + >>> if manager is not None: + ... print(f"Observability enabled: {manager.config.is_enabled()}") + """ + return self._provider_manager + + def shutdown(self, timeout_millis: int = 30000) -> bool: + """ + Shutdown observability and flush any pending metrics. + + This should be called at application shutdown to ensure all metrics + are exported before the application exits. + + Args: + timeout_millis: Maximum time to wait for shutdown + + Returns: + True if shutdown was successful + + Example: + >>> otel = get_observability_instance() + >>> # At application shutdown + >>> otel.shutdown() + """ + if self._provider_manager is None: + logger.debug("Observability not initialized, nothing to shutdown") + return True + + success = self._provider_manager.shutdown(timeout_millis) + self._provider_manager = None + logger.info("Observability shutdown") + + return success + + def force_flush(self, timeout_millis: int = 30000) -> bool: + """ + Force flush all pending metrics immediately. + + Useful for testing or when you want to ensure metrics are exported + before a specific point in your application. + + Args: + timeout_millis: Maximum time to wait for flush + + Returns: + True if flush was successful + + Example: + >>> otel = get_observability_instance() + >>> # Execute some Redis commands + >>> r.set('key', 'value') + >>> # Force flush metrics immediately + >>> otel.force_flush() + """ + if self._provider_manager is None: + logger.debug("Observability not initialized, nothing to flush") + return True + + return self._provider_manager.force_flush(timeout_millis) + + +# Global singleton instance +_observability_instance: Optional[ObservabilityInstance] = None + + +def get_observability_instance() -> ObservabilityInstance: + """ + Get the global observability singleton instance. 
+ + This is the Pythonic way to get the singleton instance. + + Returns: + The global ObservabilityInstance singleton + + Example: + >>> + >>> otel = get_observability_instance() + >>> otel.init(OTelConfig()) + """ + global _observability_instance + + if _observability_instance is None: + _observability_instance = ObservabilityInstance() + + return _observability_instance \ No newline at end of file diff --git a/redis/observability/recorder.py b/redis/observability/recorder.py new file mode 100644 index 0000000000..835e92aaf9 --- /dev/null +++ b/redis/observability/recorder.py @@ -0,0 +1,505 @@ +""" +Simple, clean API for recording observability metrics. + +This module provides a straightforward interface for Redis core code to record +metrics without needing to know about OpenTelemetry internals. + +Usage in Redis core code: + from redis.observability.recorder import record_operation_duration + + start_time = time.monotonic() + # ... execute Redis command ... + record_operation_duration( + command_name='SET', + duration_seconds=time.monotonic() - start_time, + server_address='localhost', + server_port=6379, + db_namespace='0', + error=None + ) +""" + +import time +from typing import Optional + +from redis.observability.attributes import PubSubDirection +from redis.observability.metrics import RedisMetricsCollector + +# Global metrics collector instance (lazy-initialized) +_metrics_collector: Optional[RedisMetricsCollector] = None + + +def record_operation_duration( + command_name: str, + duration_seconds: float, + server_address: Optional[str] = None, + server_port: Optional[int] = None, + db_namespace: Optional[str] = None, + error: Optional[Exception] = None, +) -> None: + """ + Record a Redis command execution duration. + + This is a simple, clean API that Redis core code can call directly. + If observability is not enabled, this returns immediately with zero overhead. 
+ + Args: + command_name: Redis command name (e.g., 'GET', 'SET') + duration_seconds: Command execution time in seconds + server_address: Redis server address + server_port: Redis server port + db_namespace: Redis database index + error: Exception if command failed, None if successful + + Example: + >>> start = time.monotonic() + >>> # ... execute command ... + >>> record_operation_duration('SET', time.monotonic() - start, 'localhost', 6379, '0') + """ + global _metrics_collector + + # Fast path: if collector not initialized, observability is disabled + if _metrics_collector is None: + # Try to initialize (only once) + _metrics_collector = _get_or_create_collector() + if _metrics_collector is None: + return # Observability not enabled + + # Determine error type and status + status_code = "ok" + if error is not None: + status_code = "error" + + # Record the metric + # try: + _metrics_collector.record_operation_duration( + command_name=command_name, + duration_seconds=duration_seconds, + server_address=server_address, + server_port=server_port, + db_namespace=db_namespace, + error_type=error, + response_status_code=status_code, + network_peer_address=server_address, + network_peer_port=server_port, + ) + # except Exception: + # # Don't let metric recording errors break Redis operations + # pass + + +def record_connection_create_time( + pool_name: str, + duration_seconds: float, +) -> None: + """ + Record connection creation time. + + Args: + pool_name: Connection pool identifier + duration_seconds: Time taken to create connection in seconds + + Example: + >>> start = time.monotonic() + >>> # ... create connection ... 
+
+    >>> record_connection_create_time('ConnectionPool', time.monotonic() - start)
+    """
+    global _metrics_collector
+
+    # Fast path: if collector not initialized, observability is disabled
+    if _metrics_collector is None:
+        _metrics_collector = _get_or_create_collector()
+        if _metrics_collector is None:
+            return
+
+    # try:
+    _metrics_collector.record_connection_create_time(
+        pool_name=pool_name,
+        duration_seconds=duration_seconds,
+    )
+    # except Exception:
+    #     pass
+
+
+def record_connection_count(
+    count: int,
+    pool_name: str,
+    state: str,
+    is_pubsub: bool = False,
+) -> None:
+    """
+    Record current connection count by state.
+
+    Args:
+        count: Signed delta to apply to the connection count (positive to increment, negative to decrement)
+        pool_name: Connection pool identifier
+        state: Connection state ('idle' or 'used')
+        is_pubsub: Whether or not the connection is pubsub
+
+    Example:
+        >>> record_connection_count(1, 'ConnectionPool', 'idle', False)
+    """
+    global _metrics_collector
+
+    if _metrics_collector is None:
+        _metrics_collector = _get_or_create_collector()
+        if _metrics_collector is None:
+            return
+
+    # try:
+    from redis.observability.attributes import ConnectionState
+    connection_state = ConnectionState.IDLE if state == 'idle' else ConnectionState.USED
+    _metrics_collector.record_connection_count(
+        count=count,
+        pool_name=pool_name,
+        state=connection_state,
+        is_pubsub=is_pubsub,
+    )
+    # except Exception:
+    #     pass
+
+
+def record_connection_timeout(
+    pool_name: str,
+) -> None:
+    """
+    Record a connection timeout event. 
+ + Args: + pool_name: Connection pool identifier + + Example: + >>> record_connection_timeout('ConnectionPool') + """ + global _metrics_collector + + if _metrics_collector is None: + _metrics_collector = _get_or_create_collector() + if _metrics_collector is None: + return + + # try: + _metrics_collector.record_connection_timeout( + pool_name=pool_name, + ) + # except Exception: + # pass + + +def record_connection_wait_time( + pool_name: str, + duration_seconds: float, +) -> None: + """ + Record time taken to obtain a connection from the pool. + + Args: + pool_name: Connection pool identifier + duration_seconds: Wait time in seconds + + Example: + >>> start = time.monotonic() + >>> # ... wait for connection from pool ... + >>> record_connection_wait_time('ConnectionPool', time.monotonic() - start) + """ + global _metrics_collector + + if _metrics_collector is None: + _metrics_collector = _get_or_create_collector() + if _metrics_collector is None: + return + + # try: + _metrics_collector.record_connection_wait_time( + pool_name=pool_name, + duration_seconds=duration_seconds, + ) + # except Exception: + # pass + + +def record_connection_use_time( + pool_name: str, + duration_seconds: float, +) -> None: + """ + Record time a connection was in use (borrowed from pool). + + Args: + pool_name: Connection pool identifier + duration_seconds: Use time in seconds + + Example: + >>> start = time.monotonic() + >>> # ... use connection ... 
+ >>> record_connection_use_time('ConnectionPool', time.monotonic() - start) + """ + global _metrics_collector + + if _metrics_collector is None: + _metrics_collector = _get_or_create_collector() + if _metrics_collector is None: + return + + # try: + _metrics_collector.record_connection_use_time( + pool_name=pool_name, + duration_seconds=duration_seconds, + ) + # except Exception: + # pass + + +def record_connection_closed( + pool_name: str, + close_reason: Optional[str] = None, + error_type: Optional[Exception] = None, +) -> None: + """ + Record a connection closed event. + + Args: + pool_name: Connection pool identifier + close_reason: Reason for closing (e.g., 'idle_timeout', 'error', 'shutdown') + error_type: Error type if closed due to error + + Example: + >>> record_connection_closed('ConnectionPool', 'idle_timeout') + """ + global _metrics_collector + + if _metrics_collector is None: + _metrics_collector = _get_or_create_collector() + if _metrics_collector is None: + return + + # try: + _metrics_collector.record_connection_closed( + pool_name=pool_name, + close_reason=close_reason, + error_type=error_type, + ) + # except Exception: + # pass + + +def record_connection_relaxed_timeout( + pool_name: str, + maint_notification: str, + relaxed: bool, +) -> None: + """ + Record a connection timeout relaxation event. 
+ + Args: + pool_name: Connection pool identifier + maint_notification: Maintenance notification type + relaxed: True to count up (relaxed), False to count down (unrelaxed) + + Example: + >>> record_connection_relaxed_timeout('ConnectionPool', 'MOVING', True) + """ + global _metrics_collector + + if _metrics_collector is None: + _metrics_collector = _get_or_create_collector() + if _metrics_collector is None: + return + + # try: + _metrics_collector.record_connection_relaxed_timeout( + pool_name=pool_name, + maint_notification=maint_notification, + relaxed=relaxed, + ) + # except Exception: + # pass + + +def record_connection_handoff( + pool_name: str, +) -> None: + """ + Record a connection handoff event (e.g., after MOVING notification). + + Args: + pool_name: Connection pool identifier + + Example: + >>> record_connection_handoff('ConnectionPool') + """ + global _metrics_collector + + if _metrics_collector is None: + _metrics_collector = _get_or_create_collector() + if _metrics_collector is None: + return + + # try: + _metrics_collector.record_connection_handoff( + pool_name=pool_name, + ) + # except Exception: + # pass + + +def record_error_count( + server_address: str, + server_port: int, + network_peer_address: str, + network_peer_port: int, + error_type: Exception, + retry_attempts: int, +) -> None: + """ + Record error count. 
+ + Args: + server_address: Server address + server_port: Server port + network_peer_address: Network peer address + network_peer_port: Network peer port + error_type: Error type (Exception) + retry_attempts: Retry attempts + + Example: + >>> record_error_count('localhost', 6379, 'localhost', 6379, ConnectionError(), 3) + """ + global _metrics_collector + + if _metrics_collector is None: + _metrics_collector = _get_or_create_collector() + if _metrics_collector is None: + return + + # try: + _metrics_collector.record_error_count( + server_address=server_address, + server_port=server_port, + network_peer_address=network_peer_address, + network_peer_port=network_peer_port, + error_type=error_type, + retry_attempts=retry_attempts, + ) + # except Exception: + # pass + + +def record_pubsub_message( + direction: PubSubDirection, + channel: Optional[str] = None, + sharded: Optional[bool] = None, +) -> None: + """ + Record a PubSub message (published or received). + + Args: + direction: Message direction ('publish' or 'receive') + channel: Pub/Sub channel name + sharded: True if sharded Pub/Sub channel + + Example: + >>> record_pubsub_message(PubSubDirection.PUBLISH, 'channel', False) + """ + global _metrics_collector + + if _metrics_collector is None: + _metrics_collector = _get_or_create_collector() + if _metrics_collector is None: + return + + # try: + _metrics_collector.record_pubsub_message( + direction=direction, + channel=channel, + sharded=sharded, + ) + # except Exception: + # pass + + +def record_streaming_lag( + lag_seconds: float, + stream_name: Optional[str] = None, + consumer_group: Optional[str] = None, + consumer_name: Optional[str] = None, +) -> None: + """ + Record the lag of a streaming message. 
+ + Args: + lag_seconds: Lag in seconds + stream_name: Stream name + consumer_group: Consumer group name + consumer_name: Consumer name + """ + global _metrics_collector + + if _metrics_collector is None: + _metrics_collector = _get_or_create_collector() + if _metrics_collector is None: + return + + # try: + _metrics_collector.record_streaming_lag( + lag_seconds=lag_seconds, + stream_name=stream_name, + consumer_group=consumer_group, + consumer_name=consumer_name, + ) + # except Exception: + # pass + + +def _get_or_create_collector() -> Optional[RedisMetricsCollector]: + """ + Get or create the global metrics collector. + + Returns: + RedisMetricsCollector instance if observability is enabled, None otherwise + """ + try: + from redis.observability.providers import get_provider_manager + from redis.observability.metrics import RedisMetricsCollector + + manager = get_provider_manager() + if manager is None or not manager.config.enable_metrics: + return None + + # Get meter from the global MeterProvider + meter = manager.get_meter_provider().get_meter( + RedisMetricsCollector.METER_NAME, + RedisMetricsCollector.METER_VERSION + ) + + return RedisMetricsCollector(meter, manager.config) + + except ImportError: + # Observability module not available + return None + except Exception: + # Any other error - don't break Redis operations + return None + + +def reset_collector() -> None: + """ + Reset the global collector (used for testing or re-initialization). + """ + global _metrics_collector + _metrics_collector = None + + +def is_enabled() -> bool: + """ + Check if observability is enabled. 
+ + Returns: + True if metrics are being collected, False otherwise + """ + global _metrics_collector + + if _metrics_collector is None: + _metrics_collector = _get_or_create_collector() + + return _metrics_collector is not None \ No newline at end of file diff --git a/tests/test_observability/__init__.py b/tests/test_observability/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/test_observability/test_config.py b/tests/test_observability/test_config.py new file mode 100644 index 0000000000..6d49c803dd --- /dev/null +++ b/tests/test_observability/test_config.py @@ -0,0 +1,348 @@ +""" +Unit tests for redis.observability.config module. + +These tests verify the OTelConfig class behavior including: +- Default configuration values +- Custom configuration via constructor parameters +- Validation of configuration values +- Command filtering (include/exclude lists) +- Runtime configuration changes +""" + +import pytest + +from redis.observability.config import OTelConfig, MetricGroup, TelemetryOption + + +class TestOTelConfigDefaults: + """Tests for OTelConfig default values.""" + + def test_default_enabled_telemetry(self): + """Test that default telemetry is METRICS only.""" + config = OTelConfig() + assert config.enabled_telemetry == TelemetryOption.METRICS + + def test_default_metric_groups(self): + """Test that default metric groups are COMMAND, CONNECTION_BASIC, RESILIENCY.""" + config = OTelConfig() + expected = MetricGroup.COMMAND | MetricGroup.CONNECTION_BASIC | MetricGroup.RESILIENCY + assert config.metric_groups == expected + + def test_default_sample_percentage(self): + """Test that default sample percentage is 100.0.""" + config = OTelConfig() + assert config.metrics_sample_percentage == 100.0 + + def test_default_include_commands_is_none(self): + """Test that include_commands is None by default.""" + config = OTelConfig() + assert config.include_commands is None + + def test_default_exclude_commands_is_empty_set(self): + 
"""Test that exclude_commands is empty set by default.""" + config = OTelConfig() + assert config.exclude_commands == set() + + def test_is_enabled_returns_true_by_default(self): + """Test that is_enabled returns True with default config.""" + config = OTelConfig() + assert config.is_enabled() is True + + +class TestOTelConfigEnabledTelemetry: + """Tests for enabled_telemetry configuration.""" + + def test_single_telemetry_option(self): + """Test setting a single telemetry option.""" + config = OTelConfig(enabled_telemetry=[TelemetryOption.METRICS]) + assert config.enabled_telemetry == TelemetryOption.METRICS + + def test_multiple_telemetry_options(self): + """Test setting multiple telemetry options.""" + config = OTelConfig( + enabled_telemetry=[TelemetryOption.METRICS, TelemetryOption.TRACES] + ) + assert TelemetryOption.METRICS in config.enabled_telemetry + assert TelemetryOption.TRACES in config.enabled_telemetry + + def test_all_telemetry_options(self): + """Test setting all telemetry options.""" + config = OTelConfig( + enabled_telemetry=[ + TelemetryOption.METRICS, + TelemetryOption.TRACES, + TelemetryOption.LOGS, + ] + ) + assert TelemetryOption.METRICS in config.enabled_telemetry + assert TelemetryOption.TRACES in config.enabled_telemetry + assert TelemetryOption.LOGS in config.enabled_telemetry + + def test_empty_telemetry_list_disables_all(self): + """Test that empty telemetry list disables all telemetry.""" + config = OTelConfig(enabled_telemetry=[]) + assert config.enabled_telemetry == TelemetryOption(0) + assert config.is_enabled() is False + + +class TestOTelConfigMetricGroups: + """Tests for metric_groups configuration.""" + + def test_single_metric_group(self): + """Test setting a single metric group.""" + config = OTelConfig(metric_groups=[MetricGroup.COMMAND]) + assert config.metric_groups == MetricGroup.COMMAND + + def test_multiple_metric_groups(self): + """Test setting multiple metric groups.""" + config = OTelConfig( + 
metric_groups=[MetricGroup.COMMAND, MetricGroup.PUBSUB] + ) + assert MetricGroup.COMMAND in config.metric_groups + assert MetricGroup.PUBSUB in config.metric_groups + + def test_all_metric_groups(self): + """Test setting all metric groups.""" + config = OTelConfig( + metric_groups=[ + MetricGroup.RESILIENCY, + MetricGroup.CONNECTION_BASIC, + MetricGroup.CONNECTION_ADVANCED, + MetricGroup.COMMAND, + MetricGroup.CSC, + MetricGroup.STREAMING, + MetricGroup.PUBSUB, + ] + ) + assert MetricGroup.RESILIENCY in config.metric_groups + assert MetricGroup.CONNECTION_BASIC in config.metric_groups + assert MetricGroup.CONNECTION_ADVANCED in config.metric_groups + assert MetricGroup.COMMAND in config.metric_groups + assert MetricGroup.CSC in config.metric_groups + assert MetricGroup.STREAMING in config.metric_groups + assert MetricGroup.PUBSUB in config.metric_groups + + def test_empty_metric_groups_list(self): + """Test that empty metric groups list results in no groups enabled.""" + config = OTelConfig(metric_groups=[]) + assert config.metric_groups == MetricGroup(0) + + +class TestOTelConfigSamplePercentage: + """Tests for metrics_sample_percentage configuration.""" + + def test_valid_sample_percentage_zero(self): + """Test that 0.0 sample percentage is valid.""" + config = OTelConfig(metrics_sample_percentage=0.0) + assert config.metrics_sample_percentage == 0.0 + + def test_valid_sample_percentage_hundred(self): + """Test that 100.0 sample percentage is valid.""" + config = OTelConfig(metrics_sample_percentage=100.0) + assert config.metrics_sample_percentage == 100.0 + + def test_valid_sample_percentage_middle(self): + """Test that middle value sample percentage is valid.""" + config = OTelConfig(metrics_sample_percentage=50.5) + assert config.metrics_sample_percentage == 50.5 + + def test_invalid_sample_percentage_negative(self): + """Test that negative sample percentage raises ValueError.""" + with pytest.raises(ValueError) as exc_info: + 
OTelConfig(metrics_sample_percentage=-1.0) + assert "metrics_sample_percentage must be between 0.0 and 100.0" in str(exc_info.value) + + def test_invalid_sample_percentage_over_hundred(self): + """Test that sample percentage over 100 raises ValueError.""" + with pytest.raises(ValueError) as exc_info: + OTelConfig(metrics_sample_percentage=100.1) + assert "metrics_sample_percentage must be between 0.0 and 100.0" in str(exc_info.value) + + +class TestOTelConfigSetSamplePercentage: + """Tests for set_sample_percentage method.""" + + def test_set_sample_percentage_valid(self): + """Test setting valid sample percentage at runtime.""" + config = OTelConfig() + config.set_sample_percentage(25.0) + assert config.metrics_sample_percentage == 25.0 + + def test_set_sample_percentage_zero(self): + """Test setting sample percentage to zero.""" + config = OTelConfig() + config.set_sample_percentage(0.0) + assert config.metrics_sample_percentage == 0.0 + + def test_set_sample_percentage_hundred(self): + """Test setting sample percentage to 100.""" + config = OTelConfig(metrics_sample_percentage=50.0) + config.set_sample_percentage(100.0) + assert config.metrics_sample_percentage == 100.0 + + def test_set_sample_percentage_invalid_negative(self): + """Test that setting negative sample percentage raises ValueError.""" + config = OTelConfig() + with pytest.raises(ValueError) as exc_info: + config.set_sample_percentage(-5.0) + assert "metrics_sample_percentage must be between 0.0 and 100.0" in str(exc_info.value) + + def test_set_sample_percentage_invalid_over_hundred(self): + """Test that setting sample percentage over 100 raises ValueError.""" + config = OTelConfig() + with pytest.raises(ValueError) as exc_info: + config.set_sample_percentage(150.0) + assert "metrics_sample_percentage must be between 0.0 and 100.0" in str(exc_info.value) + + +class TestOTelConfigIncludeCommands: + """Tests for include_commands configuration.""" + + def test_include_commands_single(self): + """Test 
include_commands with single command."""
+        config = OTelConfig(include_commands=['GET'])
+        assert config.include_commands == {'GET'}
+
+    def test_include_commands_multiple(self):
+        """Test include_commands with multiple commands."""
+        config = OTelConfig(include_commands=['GET', 'SET', 'DEL'])
+        assert config.include_commands == {'GET', 'SET', 'DEL'}
+
+    def test_include_commands_empty_list(self):
+        """Test include_commands with empty list is normalized to None (track all commands)."""
+        config = OTelConfig(include_commands=[])
+        assert config.include_commands == None
+
+
+class TestOTelConfigExcludeCommands:
+    """Tests for exclude_commands configuration."""
+
+    def test_exclude_commands_single(self):
+        """Test exclude_commands with single command."""
+        config = OTelConfig(exclude_commands=['PING'])
+        assert config.exclude_commands == {'PING'}
+
+    def test_exclude_commands_multiple(self):
+        """Test exclude_commands with multiple commands."""
+        config = OTelConfig(exclude_commands=['PING', 'INFO', 'DEBUG'])
+        assert config.exclude_commands == {'PING', 'INFO', 'DEBUG'}
+
+    def test_exclude_commands_empty_list(self):
+        """Test exclude_commands with empty list results in empty set."""
+        config = OTelConfig(exclude_commands=[])
+        assert config.exclude_commands == set()
+
+
+class TestOTelConfigShouldTrackCommand:
+    """Tests for should_track_command method."""
+
+    def test_should_track_command_default_tracks_all(self):
+        """Test that all commands are tracked by default."""
+        config = OTelConfig()
+        assert config.should_track_command('GET') is True
+        assert config.should_track_command('SET') is True
+        assert config.should_track_command('PING') is True
+
+    def test_should_track_command_case_insensitive(self):
+        """Test that command matching is case-insensitive."""
+        config = OTelConfig(include_commands=['GET'])
+        assert config.should_track_command('GET') is True
+        assert config.should_track_command('get') is True
+        assert config.should_track_command('Get') is True
+
+    def 
test_should_track_command_with_include_list(self): + """Test that only included commands are tracked.""" + config = OTelConfig(include_commands=['GET', 'SET']) + assert config.should_track_command('GET') is True + assert config.should_track_command('SET') is True + assert config.should_track_command('DEL') is False + assert config.should_track_command('PING') is False + + def test_should_track_command_with_exclude_list(self): + """Test that excluded commands are not tracked.""" + config = OTelConfig(exclude_commands=['PING', 'INFO']) + assert config.should_track_command('GET') is True + assert config.should_track_command('SET') is True + assert config.should_track_command('PING') is False + assert config.should_track_command('INFO') is False + + def test_should_track_command_include_takes_precedence(self): + """Test that include_commands takes precedence over exclude_commands.""" + # When include_commands is set, exclude_commands is ignored + config = OTelConfig( + include_commands=['GET', 'SET'], + exclude_commands=['GET'] # This should be ignored + ) + assert config.should_track_command('GET') is True + assert config.should_track_command('SET') is True + assert config.should_track_command('DEL') is False + + def test_should_track_command_empty_include_tracks_all(self): + """Test that empty include list tracks all commands.""" + config = OTelConfig(include_commands=[]) + assert config.should_track_command('GET') is True + assert config.should_track_command('SET') is True + + +class TestOTelConfigRepr: + """Tests for __repr__ method.""" + + def test_repr_contains_enabled_telemetry(self): + """Test that repr contains enabled_telemetry.""" + config = OTelConfig() + repr_str = repr(config) + assert 'enabled_telemetry' in repr_str + + +class TestMetricGroupEnum: + """Tests for MetricGroup IntFlag enum.""" + + def test_metric_group_values_are_unique(self): + """Test that all MetricGroup values are unique powers of 2.""" + values = [ + MetricGroup.RESILIENCY, + 
MetricGroup.CONNECTION_BASIC, + MetricGroup.CONNECTION_ADVANCED, + MetricGroup.COMMAND, + MetricGroup.CSC, + MetricGroup.STREAMING, + MetricGroup.PUBSUB, + ] + # Each value should be a power of 2 + for value in values: + assert value & (value - 1) == 0 # Power of 2 check + + def test_metric_group_can_be_combined(self): + """Test that MetricGroup values can be combined with bitwise OR.""" + combined = MetricGroup.COMMAND | MetricGroup.PUBSUB + assert MetricGroup.COMMAND in combined + assert MetricGroup.PUBSUB in combined + assert MetricGroup.STREAMING not in combined + + def test_metric_group_membership_check(self): + """Test checking membership in combined MetricGroup.""" + combined = MetricGroup.RESILIENCY | MetricGroup.CONNECTION_BASIC + assert bool(combined & MetricGroup.RESILIENCY) + assert bool(combined & MetricGroup.CONNECTION_BASIC) + assert not bool(combined & MetricGroup.COMMAND) + + +class TestTelemetryOptionEnum: + """Tests for TelemetryOption IntFlag enum.""" + + def test_telemetry_option_values_are_unique(self): + """Test that all TelemetryOption values are unique powers of 2.""" + values = [ + TelemetryOption.METRICS, + TelemetryOption.TRACES, + TelemetryOption.LOGS, + ] + for value in values: + assert value & (value - 1) == 0 # Power of 2 check + + def test_telemetry_option_can_be_combined(self): + """Test that TelemetryOption values can be combined with bitwise OR.""" + combined = TelemetryOption.METRICS | TelemetryOption.TRACES + assert TelemetryOption.METRICS in combined + assert TelemetryOption.TRACES in combined + assert TelemetryOption.LOGS not in combined diff --git a/tests/test_observability/test_provider.py b/tests/test_observability/test_provider.py new file mode 100644 index 0000000000..90d8f61fb6 --- /dev/null +++ b/tests/test_observability/test_provider.py @@ -0,0 +1,457 @@ +""" +Unit tests for redis.observability.providers module. 
+ +These tests verify the OTelProviderManager and ObservabilityInstance classes including: +- Provider initialization and configuration +- MeterProvider retrieval and validation +- Shutdown and force flush operations +- Singleton pattern behavior +- Context manager support +""" + +import pytest +from unittest.mock import Mock, patch, MagicMock + +from redis.observability.config import OTelConfig, TelemetryOption +from redis.observability.providers import ( + OTelProviderManager, + ObservabilityInstance, + get_observability_instance, +) + + +class TestOTelProviderManagerInit: + """Tests for OTelProviderManager initialization.""" + + def test_init_with_config(self): + """Test that OTelProviderManager initializes with config.""" + config = OTelConfig() + manager = OTelProviderManager(config) + + assert manager.config is config + assert manager._meter_provider is None + + def test_init_with_custom_config(self): + """Test initialization with custom config.""" + config = OTelConfig( + enabled_telemetry=[TelemetryOption.METRICS], + metrics_sample_percentage=50.0, + ) + manager = OTelProviderManager(config) + + assert manager.config.metrics_sample_percentage == 50.0 + + +class TestOTelProviderManagerGetMeterProvider: + """Tests for get_meter_provider method.""" + + def test_get_meter_provider_returns_none_when_disabled(self): + """Test that get_meter_provider returns None when telemetry is disabled.""" + config = OTelConfig(enabled_telemetry=[]) + manager = OTelProviderManager(config) + + result = manager.get_meter_provider() + + assert result is None + + def test_get_meter_provider_raises_when_no_global_provider(self): + """Test that get_meter_provider raises RuntimeError when no global provider is set.""" + config = OTelConfig() + manager = OTelProviderManager(config) + + with patch('opentelemetry.metrics') as mock_metrics: + from opentelemetry.metrics import NoOpMeterProvider + mock_metrics.get_meter_provider.return_value = NoOpMeterProvider() + + with 
pytest.raises(RuntimeError) as exc_info: + manager.get_meter_provider() + + assert "no global MeterProvider is configured" in str(exc_info.value) + + def test_get_meter_provider_returns_global_provider(self): + """Test that get_meter_provider returns the global MeterProvider.""" + config = OTelConfig() + manager = OTelProviderManager(config) + + mock_provider = Mock() + # Make sure it's not a NoOpMeterProvider + mock_provider.__class__.__name__ = 'MeterProvider' + + with patch('opentelemetry.metrics') as mock_metrics: + mock_metrics.get_meter_provider.return_value = mock_provider + result = manager.get_meter_provider() + + assert result is mock_provider + + def test_get_meter_provider_caches_provider(self): + """Test that get_meter_provider caches the provider.""" + config = OTelConfig() + manager = OTelProviderManager(config) + + mock_provider = Mock() + + with patch('opentelemetry.metrics') as mock_metrics: + mock_metrics.get_meter_provider.return_value = mock_provider + + # Call twice + result1 = manager.get_meter_provider() + result2 = manager.get_meter_provider() + + # Should only call get_meter_provider once due to caching + assert mock_metrics.get_meter_provider.call_count == 1 + assert result1 is result2 + + +class TestOTelProviderManagerShutdown: + """Tests for shutdown method.""" + + def test_shutdown_calls_force_flush(self): + """Test that shutdown calls force_flush.""" + config = OTelConfig() + manager = OTelProviderManager(config) + + with patch.object(manager, 'force_flush', return_value=True) as mock_flush: + result = manager.shutdown(timeout_millis=5000) + + mock_flush.assert_called_once_with(timeout_millis=5000) + assert result is True + + def test_shutdown_with_default_timeout(self): + """Test shutdown with default timeout.""" + config = OTelConfig() + manager = OTelProviderManager(config) + + with patch.object(manager, 'force_flush', return_value=True) as mock_flush: + manager.shutdown() + + 
mock_flush.assert_called_once_with(timeout_millis=30000) + + +class TestOTelProviderManagerForceFlush: + """Tests for force_flush method.""" + + def test_force_flush_returns_true_when_no_provider(self): + """Test that force_flush returns True when no provider is set.""" + config = OTelConfig() + manager = OTelProviderManager(config) + + result = manager.force_flush() + + assert result is True + + def test_force_flush_calls_provider_force_flush(self): + """Test that force_flush calls the provider's force_flush.""" + config = OTelConfig() + manager = OTelProviderManager(config) + + mock_provider = Mock() + manager._meter_provider = mock_provider + + result = manager.force_flush(timeout_millis=5000) + + mock_provider.force_flush.assert_called_once_with(timeout_millis=5000) + assert result is True + + def test_force_flush_returns_false_on_exception(self): + """Test that force_flush returns False when an exception occurs.""" + config = OTelConfig() + manager = OTelProviderManager(config) + + mock_provider = Mock() + mock_provider.force_flush.side_effect = Exception("Flush failed") + manager._meter_provider = mock_provider + + result = manager.force_flush() + + assert result is False + + +class TestOTelProviderManagerContextManager: + """Tests for context manager support.""" + + def test_context_manager_enter_returns_self(self): + """Test that __enter__ returns self.""" + config = OTelConfig() + manager = OTelProviderManager(config) + + result = manager.__enter__() + + assert result is manager + + def test_context_manager_exit_calls_shutdown(self): + """Test that __exit__ calls shutdown.""" + config = OTelConfig() + manager = OTelProviderManager(config) + + with patch.object(manager, 'shutdown') as mock_shutdown: + manager.__exit__(None, None, None) + + mock_shutdown.assert_called_once() + + def test_context_manager_with_statement(self): + """Test using OTelProviderManager with 'with' statement.""" + config = OTelConfig() + + with patch.object(OTelProviderManager, 
'shutdown') as mock_shutdown: + with OTelProviderManager(config) as manager: + assert manager.config is config + + mock_shutdown.assert_called_once() + + +class TestOTelProviderManagerRepr: + """Tests for __repr__ method.""" + + def test_repr_contains_config(self): + """Test that repr contains config information.""" + config = OTelConfig() + manager = OTelProviderManager(config) + + repr_str = repr(manager) + + assert 'OTelProviderManager' in repr_str + assert 'config=' in repr_str + + +class TestObservabilityInstanceInit: + """Tests for ObservabilityInstance initialization.""" + + def test_init_creates_empty_instance(self): + """Test that ObservabilityInstance initializes with no provider manager.""" + instance = ObservabilityInstance() + + assert instance._provider_manager is None + + def test_init_method_creates_provider_manager(self): + """Test that init() creates a provider manager.""" + instance = ObservabilityInstance() + config = OTelConfig() + + result = instance.init(config) + + assert instance._provider_manager is not None + assert instance._provider_manager.config is config + assert result is instance # Returns self for chaining + + def test_init_method_replaces_existing_manager(self): + """Test that init() replaces existing provider manager.""" + instance = ObservabilityInstance() + config1 = OTelConfig(metrics_sample_percentage=50.0) + config2 = OTelConfig(metrics_sample_percentage=75.0) + + instance.init(config1) + old_manager = instance._provider_manager + + with patch.object(old_manager, 'shutdown') as mock_shutdown: + instance.init(config2) + + mock_shutdown.assert_called_once() + assert instance._provider_manager.config.metrics_sample_percentage == 75.0 + + +class TestObservabilityInstanceIsEnabled: + """Tests for is_enabled method.""" + + def test_is_enabled_returns_false_when_not_initialized(self): + """Test that is_enabled returns False when not initialized.""" + instance = ObservabilityInstance() + + assert instance.is_enabled() is False + + 
def test_is_enabled_returns_true_when_initialized_and_enabled(self): + """Test that is_enabled returns True when initialized with enabled config.""" + instance = ObservabilityInstance() + config = OTelConfig() + + instance.init(config) + + assert instance.is_enabled() is True + + def test_is_enabled_returns_false_when_telemetry_disabled(self): + """Test that is_enabled returns False when telemetry is disabled.""" + instance = ObservabilityInstance() + config = OTelConfig(enabled_telemetry=[]) + + instance.init(config) + + assert instance.is_enabled() is False + + +class TestObservabilityInstanceGetProviderManager: + """Tests for get_provider_manager method.""" + + def test_get_provider_manager_returns_none_when_not_initialized(self): + """Test that get_provider_manager returns None when not initialized.""" + instance = ObservabilityInstance() + + assert instance.get_provider_manager() is None + + def test_get_provider_manager_returns_manager_when_initialized(self): + """Test that get_provider_manager returns the manager when initialized.""" + instance = ObservabilityInstance() + config = OTelConfig() + + instance.init(config) + + manager = instance.get_provider_manager() + assert manager is not None + assert manager.config is config + + +class TestObservabilityInstanceShutdown: + """Tests for shutdown method.""" + + def test_shutdown_returns_true_when_not_initialized(self): + """Test that shutdown returns True when not initialized.""" + instance = ObservabilityInstance() + + result = instance.shutdown() + + assert result is True + + def test_shutdown_calls_provider_manager_shutdown(self): + """Test that shutdown calls the provider manager's shutdown.""" + instance = ObservabilityInstance() + config = OTelConfig() + instance.init(config) + + with patch.object(instance._provider_manager, 'shutdown', return_value=True) as mock_shutdown: + result = instance.shutdown(timeout_millis=5000) + + mock_shutdown.assert_called_once_with(5000) + assert result is True + assert 
instance._provider_manager is None + + def test_shutdown_clears_provider_manager(self): + """Test that shutdown clears the provider manager.""" + instance = ObservabilityInstance() + config = OTelConfig() + instance.init(config) + + with patch.object(instance._provider_manager, 'shutdown', return_value=True): + instance.shutdown() + + assert instance._provider_manager is None + + +class TestObservabilityInstanceForceFlush: + """Tests for force_flush method.""" + + def test_force_flush_returns_true_when_not_initialized(self): + """Test that force_flush returns True when not initialized.""" + instance = ObservabilityInstance() + + result = instance.force_flush() + + assert result is True + + def test_force_flush_calls_provider_manager_force_flush(self): + """Test that force_flush calls the provider manager's force_flush.""" + instance = ObservabilityInstance() + config = OTelConfig() + instance.init(config) + + with patch.object(instance._provider_manager, 'force_flush', return_value=True) as mock_flush: + result = instance.force_flush(timeout_millis=5000) + + mock_flush.assert_called_once_with(5000) + assert result is True + + +class TestGetObservabilityInstance: + """Tests for get_observability_instance function.""" + + def test_get_observability_instance_returns_singleton(self): + """Test that get_observability_instance returns the same instance.""" + # Reset the global instance for this test + import redis.observability.providers as providers + original_instance = providers._observability_instance + + try: + providers._observability_instance = None + + instance1 = get_observability_instance() + instance2 = get_observability_instance() + + assert instance1 is instance2 + finally: + # Restore original instance + providers._observability_instance = original_instance + + def test_get_observability_instance_creates_new_if_none(self): + """Test that get_observability_instance creates a new instance if none exists.""" + import redis.observability.providers as providers 
+ original_instance = providers._observability_instance + + try: + providers._observability_instance = None + + instance = get_observability_instance() + + assert instance is not None + assert isinstance(instance, ObservabilityInstance) + finally: + providers._observability_instance = original_instance + + def test_get_observability_instance_returns_existing(self): + """Test that get_observability_instance returns existing instance.""" + import redis.observability.providers as providers + original_instance = providers._observability_instance + + try: + existing = ObservabilityInstance() + providers._observability_instance = existing + + instance = get_observability_instance() + + assert instance is existing + finally: + providers._observability_instance = original_instance + + +class TestObservabilityInstanceIntegration: + """Integration tests for ObservabilityInstance.""" + + def test_full_lifecycle(self): + """Test full lifecycle: init -> use -> shutdown.""" + instance = ObservabilityInstance() + config = OTelConfig() + + # Initialize + result = instance.init(config) + assert result is instance + assert instance.is_enabled() is True + + # Get provider manager + manager = instance.get_provider_manager() + assert manager is not None + + # Force flush (with mocked provider) + with patch.object(manager, 'force_flush', return_value=True): + flush_result = instance.force_flush() + assert flush_result is True + + # Shutdown + with patch.object(manager, 'shutdown', return_value=True): + shutdown_result = instance.shutdown() + assert shutdown_result is True + + assert instance._provider_manager is None + assert instance.is_enabled() is False + + def test_reinitialize_after_shutdown(self): + """Test that instance can be reinitialized after shutdown.""" + instance = ObservabilityInstance() + config1 = OTelConfig(metrics_sample_percentage=50.0) + config2 = OTelConfig(metrics_sample_percentage=75.0) + + # First initialization + instance.init(config1) + with 
patch.object(instance._provider_manager, 'shutdown', return_value=True): + instance.shutdown() + + # Second initialization + instance.init(config2) + + assert instance.is_enabled() is True + assert instance._provider_manager.config.metrics_sample_percentage == 75.0 diff --git a/tests/test_observability/test_recorder.py b/tests/test_observability/test_recorder.py new file mode 100644 index 0000000000..b78f964dce --- /dev/null +++ b/tests/test_observability/test_recorder.py @@ -0,0 +1,649 @@ +""" +Unit tests for redis.observability.recorder module. + +These tests verify that recorder functions correctly pass arguments through +to the underlying OTel Meter instruments (Counter, Histogram, UpDownCounter). +The MeterProvider is mocked to verify the actual integration point where +metrics are exported to OTel. +""" + +import pytest +from unittest.mock import MagicMock, patch + +from redis.observability import recorder +from redis.observability.attributes import ( + ConnectionState, + PubSubDirection, + # Connection pool attributes + DB_CLIENT_CONNECTION_POOL_NAME, + DB_CLIENT_CONNECTION_STATE, + REDIS_CLIENT_CONNECTION_PUBSUB, + # Server attributes + SERVER_ADDRESS, + SERVER_PORT, + # Database attributes + DB_NAMESPACE, + DB_OPERATION_NAME, + DB_OPERATION_BATCH_SIZE, + DB_RESPONSE_STATUS_CODE, + # Error attributes + ERROR_TYPE, + # Network attributes + NETWORK_PEER_ADDRESS, + NETWORK_PEER_PORT, + # Redis-specific attributes + REDIS_CLIENT_OPERATION_RETRY_ATTEMPTS, + REDIS_CLIENT_CONNECTION_CLOSE_REASON, + REDIS_CLIENT_CONNECTION_NOTIFICATION, + REDIS_CLIENT_PUBSUB_MESSAGE_DIRECTION, + REDIS_CLIENT_PUBSUB_CHANNEL, + REDIS_CLIENT_PUBSUB_SHARDED, + # Streaming attributes + REDIS_CLIENT_STREAM_NAME, + REDIS_CLIENT_CONSUMER_GROUP, + REDIS_CLIENT_CONSUMER_NAME, +) +from redis.observability.config import OTelConfig, MetricGroup +from redis.observability.recorder import record_operation_duration, record_connection_create_time, \ + record_connection_count, 
record_connection_timeout, record_connection_wait_time, record_connection_use_time, \ + record_connection_closed, record_connection_relaxed_timeout, record_connection_handoff, record_error_count, \ + record_pubsub_message, reset_collector, record_streaming_lag + + +class MockInstruments: + """Container for mock OTel instruments.""" + + def __init__(self): + # Counters + self.client_errors = MagicMock() + self.maintenance_notifications = MagicMock() + self.connection_timeouts = MagicMock() + self.connection_closed = MagicMock() + self.connection_handoff = MagicMock() + self.pubsub_messages = MagicMock() + + # UpDownCounters + self.connection_count = MagicMock() + self.connection_relaxed_timeout = MagicMock() + + # Histograms + self.connection_create_time = MagicMock() + self.connection_wait_time = MagicMock() + self.connection_use_time = MagicMock() + self.operation_duration = MagicMock() + self.stream_lag = MagicMock() + + +@pytest.fixture +def mock_instruments(): + """Create mock OTel instruments.""" + return MockInstruments() + + +@pytest.fixture +def mock_meter(mock_instruments): + """Create a mock Meter that returns our mock instruments.""" + meter = MagicMock() + + def create_counter_side_effect(name, **kwargs): + instrument_map = { + 'redis.client.errors': mock_instruments.client_errors, + 'redis.client.maintenance.notifications': mock_instruments.maintenance_notifications, + 'db.client.connection.timeouts': mock_instruments.connection_timeouts, + 'redis.client.connection.closed': mock_instruments.connection_closed, + 'redis.client.connection.handoff': mock_instruments.connection_handoff, + 'redis.client.pubsub.messages': mock_instruments.pubsub_messages, + } + return instrument_map.get(name, MagicMock()) + + def create_up_down_counter_side_effect(name, **kwargs): + instrument_map = { + 'db.client.connection.count': mock_instruments.connection_count, + 'redis.client.connection.relaxed_timeout': mock_instruments.connection_relaxed_timeout, + } + return 
instrument_map.get(name, MagicMock()) + + def create_histogram_side_effect(name, **kwargs): + instrument_map = { + 'db.client.connection.create_time': mock_instruments.connection_create_time, + 'db.client.connection.wait_time': mock_instruments.connection_wait_time, + 'db.client.connection.use_time': mock_instruments.connection_use_time, + 'db.client.operation.duration': mock_instruments.operation_duration, + 'redis.client.stream.lag': mock_instruments.stream_lag, + } + return instrument_map.get(name, MagicMock()) + + meter.create_counter.side_effect = create_counter_side_effect + meter.create_up_down_counter.side_effect = create_up_down_counter_side_effect + meter.create_histogram.side_effect = create_histogram_side_effect + + return meter + + +@pytest.fixture +def mock_config(): + """Create a config with all metric groups enabled.""" + config = OTelConfig( + metric_groups=[ + MetricGroup.RESILIENCY, + MetricGroup.CONNECTION_BASIC, + MetricGroup.CONNECTION_ADVANCED, + MetricGroup.COMMAND, + MetricGroup.PUBSUB, + MetricGroup.STREAMING, + ] + ) + return config + + +@pytest.fixture +def metrics_collector(mock_meter, mock_config): + """Create a real RedisMetricsCollector with mocked Meter.""" + with patch('redis.observability.metrics.OTEL_AVAILABLE', True): + from redis.observability.metrics import RedisMetricsCollector + collector = RedisMetricsCollector(mock_meter, mock_config) + return collector + + +@pytest.fixture +def setup_recorder(metrics_collector, mock_instruments): + """ + Setup the recorder module with our collector that has mocked instruments. 
+ """ + from redis.observability import recorder + + # Reset the global collector before test + recorder.reset_collector() + + # Patch _get_or_create_collector to return our collector with mocked instruments + with patch.object( + recorder, + '_get_or_create_collector', + return_value=metrics_collector + ): + yield mock_instruments + + # Reset after test + recorder.reset_collector() + + +class TestRecordOperationDuration: + """Tests for record_operation_duration - verifies Histogram.record() calls.""" + + def test_record_operation_duration_success(self, setup_recorder): + """Test that operation duration is recorded to the histogram with correct attributes.""" + + instruments = setup_recorder + + record_operation_duration( + command_name='SET', + duration_seconds=0.005, + server_address='localhost', + server_port=6379, + db_namespace='0', + error=None, + ) + + # Verify histogram.record() was called + instruments.operation_duration.record.assert_called_once() + call_args = instruments.operation_duration.record.call_args + + # Verify duration value + assert call_args[0][0] == 0.005 + + # Verify attributes + attrs = call_args[1]['attributes'] + assert attrs[SERVER_ADDRESS] == 'localhost' + assert attrs[SERVER_PORT] == 6379 + assert attrs[DB_NAMESPACE] == '0' + assert attrs[DB_OPERATION_NAME] == 'SET' + assert attrs[DB_RESPONSE_STATUS_CODE] == 'ok' + + def test_record_operation_duration_with_error(self, setup_recorder): + """Test that error information is included in attributes.""" + + instruments = setup_recorder + + error = ConnectionError("Connection refused") + record_operation_duration( + command_name='GET', + duration_seconds=0.001, + server_address='localhost', + server_port=6379, + error=error, + ) + + instruments.operation_duration.record.assert_called_once() + call_args = instruments.operation_duration.record.call_args + + attrs = call_args[1]['attributes'] + assert attrs[DB_OPERATION_NAME] == 'GET' + assert attrs[DB_RESPONSE_STATUS_CODE] is None + assert 
attrs[ERROR_TYPE] == 'ConnectionError' + + +class TestRecordConnectionCreateTime: + """Tests for record_connection_create_time - verifies Histogram.record() calls.""" + + def test_record_connection_create_time(self, setup_recorder): + """Test that connection creation time is recorded with pool name.""" + + instruments = setup_recorder + + record_connection_create_time( + pool_name='ConnectionPool', + duration_seconds=0.025, + ) + + instruments.connection_create_time.record.assert_called_once() + call_args = instruments.connection_create_time.record.call_args + + # Verify duration value + assert call_args[0][0] == 0.025 + + # Verify attributes + attrs = call_args[1]['attributes'] + assert attrs[DB_CLIENT_CONNECTION_POOL_NAME] == 'ConnectionPool' + + +class TestRecordConnectionCount: + """Tests for record_connection_count - verifies UpDownCounter.add() calls.""" + + def test_record_connection_count_idle_increment(self, setup_recorder): + """Test incrementing idle connection count.""" + + instruments = setup_recorder + + record_connection_count( + count=1, + pool_name='ConnectionPool', + state='idle', + is_pubsub=False, + ) + + instruments.connection_count.add.assert_called_once() + call_args = instruments.connection_count.add.call_args + + # Verify increment value + assert call_args[0][0] == 1 + + # Verify attributes + attrs = call_args[1]['attributes'] + assert attrs[DB_CLIENT_CONNECTION_POOL_NAME] == 'ConnectionPool' + assert attrs[DB_CLIENT_CONNECTION_STATE] == ConnectionState.IDLE.value + assert attrs[REDIS_CLIENT_CONNECTION_PUBSUB] is False + + def test_record_connection_count_used_decrement(self, setup_recorder): + """Test decrementing used connection count for pubsub.""" + + instruments = setup_recorder + + record_connection_count( + count=-1, + pool_name='ConnectionPool', + state='used', + is_pubsub=True, + ) + + instruments.connection_count.add.assert_called_once() + call_args = instruments.connection_count.add.call_args + + assert call_args[0][0] == -1 + 
attrs = call_args[1]['attributes'] + assert attrs[DB_CLIENT_CONNECTION_STATE] == ConnectionState.USED.value + assert attrs[REDIS_CLIENT_CONNECTION_PUBSUB] is True + + +class TestRecordConnectionTimeout: + """Tests for record_connection_timeout - verifies Counter.add() calls.""" + + def test_record_connection_timeout(self, setup_recorder): + """Test recording connection timeout event.""" + + instruments = setup_recorder + + record_connection_timeout( + pool_name='ConnectionPool', + ) + + instruments.connection_timeouts.add.assert_called_once() + call_args = instruments.connection_timeouts.add.call_args + + # Counter increments by 1 + assert call_args[0][0] == 1 + + attrs = call_args[1]['attributes'] + assert attrs[DB_CLIENT_CONNECTION_POOL_NAME] == 'ConnectionPool' + + +class TestRecordConnectionWaitTime: + """Tests for record_connection_wait_time - verifies Histogram.record() calls.""" + + def test_record_connection_wait_time(self, setup_recorder): + """Test recording connection wait time.""" + + instruments = setup_recorder + + record_connection_wait_time( + pool_name='ConnectionPool', + duration_seconds=0.010, + ) + + instruments.connection_wait_time.record.assert_called_once() + call_args = instruments.connection_wait_time.record.call_args + + assert call_args[0][0] == 0.010 + attrs = call_args[1]['attributes'] + assert attrs[DB_CLIENT_CONNECTION_POOL_NAME] == 'ConnectionPool' + + +class TestRecordConnectionUseTime: + """Tests for record_connection_use_time - verifies Histogram.record() calls.""" + + def test_record_connection_use_time(self, setup_recorder): + """Test recording connection use time.""" + + instruments = setup_recorder + + record_connection_use_time( + pool_name='ConnectionPool', + duration_seconds=0.050, + ) + + instruments.connection_use_time.record.assert_called_once() + call_args = instruments.connection_use_time.record.call_args + + assert call_args[0][0] == 0.050 + attrs = call_args[1]['attributes'] + assert 
attrs[DB_CLIENT_CONNECTION_POOL_NAME] == 'ConnectionPool' + + +class TestRecordConnectionClosed: + """Tests for record_connection_closed - verifies Counter.add() calls.""" + + def test_record_connection_closed_with_reason(self, setup_recorder): + """Test recording connection closed with reason.""" + + instruments = setup_recorder + + record_connection_closed( + pool_name='ConnectionPool', + close_reason='idle_timeout', + ) + + instruments.connection_closed.add.assert_called_once() + call_args = instruments.connection_closed.add.call_args + + assert call_args[0][0] == 1 + attrs = call_args[1]['attributes'] + assert attrs[DB_CLIENT_CONNECTION_POOL_NAME] == 'ConnectionPool' + assert attrs[REDIS_CLIENT_CONNECTION_CLOSE_REASON] == 'idle_timeout' + + def test_record_connection_closed_with_error(self, setup_recorder): + """Test recording connection closed with error type.""" + + instruments = setup_recorder + + error = ConnectionResetError("Connection reset by peer") + record_connection_closed( + pool_name='ConnectionPool', + close_reason='error', + error_type=error, + ) + + instruments.connection_closed.add.assert_called_once() + attrs = instruments.connection_closed.add.call_args[1]['attributes'] + assert attrs[REDIS_CLIENT_CONNECTION_CLOSE_REASON] == 'error' + assert attrs[ERROR_TYPE] == 'ConnectionResetError' + + +class TestRecordConnectionRelaxedTimeout: + """Tests for record_connection_relaxed_timeout - verifies UpDownCounter.add() calls.""" + + def test_record_connection_relaxed_timeout_relaxed(self, setup_recorder): + """Test recording relaxed timeout increments counter by 1.""" + + instruments = setup_recorder + + record_connection_relaxed_timeout( + pool_name='ConnectionPool', + maint_notification='MOVING', + relaxed=True, + ) + + instruments.connection_relaxed_timeout.add.assert_called_once() + call_args = instruments.connection_relaxed_timeout.add.call_args + + # relaxed=True means count up (+1) + assert call_args[0][0] == 1 + attrs = 
call_args[1]['attributes'] + assert attrs[DB_CLIENT_CONNECTION_POOL_NAME] == 'ConnectionPool' + assert attrs[REDIS_CLIENT_CONNECTION_NOTIFICATION] == 'MOVING' + + def test_record_connection_relaxed_timeout_unrelaxed(self, setup_recorder): + """Test recording unrelaxed timeout decrements counter by 1.""" + + instruments = setup_recorder + + record_connection_relaxed_timeout( + pool_name='ConnectionPool', + maint_notification='MIGRATING', + relaxed=False, + ) + + instruments.connection_relaxed_timeout.add.assert_called_once() + call_args = instruments.connection_relaxed_timeout.add.call_args + + # relaxed=False means count down (-1) + assert call_args[0][0] == -1 + attrs = call_args[1]['attributes'] + assert attrs[REDIS_CLIENT_CONNECTION_NOTIFICATION] == 'MIGRATING' + + +class TestRecordConnectionHandoff: + """Tests for record_connection_handoff - verifies Counter.add() calls.""" + + def test_record_connection_handoff(self, setup_recorder): + """Test recording connection handoff event.""" + + instruments = setup_recorder + + record_connection_handoff( + pool_name='ConnectionPool', + ) + + instruments.connection_handoff.add.assert_called_once() + call_args = instruments.connection_handoff.add.call_args + + assert call_args[0][0] == 1 + attrs = call_args[1]['attributes'] + assert attrs[DB_CLIENT_CONNECTION_POOL_NAME] == 'ConnectionPool' + + +class TestRecordErrorCount: + """Tests for record_error_count - verifies Counter.add() calls.""" + + def test_record_error_count(self, setup_recorder): + """Test recording error count with all attributes.""" + + instruments = setup_recorder + + error = ConnectionError("Connection refused") + record_error_count( + server_address='localhost', + server_port=6379, + network_peer_address='127.0.0.1', + network_peer_port=6379, + error_type=error, + retry_attempts=3, + ) + + instruments.client_errors.add.assert_called_once() + call_args = instruments.client_errors.add.call_args + + assert call_args[0][0] == 1 + attrs = 
call_args[1]['attributes'] + assert attrs[SERVER_ADDRESS] == 'localhost' + assert attrs[SERVER_PORT] == 6379 + assert attrs[NETWORK_PEER_ADDRESS] == '127.0.0.1' + assert attrs[NETWORK_PEER_PORT] == 6379 + assert attrs[ERROR_TYPE] == 'ConnectionError' + assert attrs[REDIS_CLIENT_OPERATION_RETRY_ATTEMPTS] == 3 + + +class TestRecordPubsubMessage: + """Tests for record_pubsub_message - verifies Counter.add() calls.""" + + def test_record_pubsub_message_publish(self, setup_recorder): + """Test recording published message.""" + + instruments = setup_recorder + + record_pubsub_message( + direction=PubSubDirection.PUBLISH, + channel='my-channel', + sharded=False, + ) + + instruments.pubsub_messages.add.assert_called_once() + call_args = instruments.pubsub_messages.add.call_args + + assert call_args[0][0] == 1 + attrs = call_args[1]['attributes'] + assert attrs[REDIS_CLIENT_PUBSUB_MESSAGE_DIRECTION] == PubSubDirection.PUBLISH.value + assert attrs[REDIS_CLIENT_PUBSUB_CHANNEL] == 'my-channel' + assert attrs[REDIS_CLIENT_PUBSUB_SHARDED] is False + + def test_record_pubsub_message_receive_sharded(self, setup_recorder): + """Test recording received message on sharded channel.""" + + instruments = setup_recorder + + record_pubsub_message( + direction=PubSubDirection.RECEIVE, + channel='sharded-channel', + sharded=True, + ) + + instruments.pubsub_messages.add.assert_called_once() + attrs = instruments.pubsub_messages.add.call_args[1]['attributes'] + assert attrs[REDIS_CLIENT_PUBSUB_MESSAGE_DIRECTION] == PubSubDirection.RECEIVE.value + assert attrs[REDIS_CLIENT_PUBSUB_CHANNEL] == 'sharded-channel' + assert attrs[REDIS_CLIENT_PUBSUB_SHARDED] is True + + +class TestRecordStreamingLag: + """Tests for record_streaming_lag - verifies Histogram.record() calls.""" + + def test_record_streaming_lag_with_all_attributes(self, setup_recorder): + """Test recording streaming lag with all attributes.""" + + instruments = setup_recorder + + record_streaming_lag( + lag_seconds=0.150, + 
stream_name='my-stream', + consumer_group='my-group', + consumer_name='consumer-1', + ) + + instruments.stream_lag.record.assert_called_once() + call_args = instruments.stream_lag.record.call_args + + # Verify lag value + assert call_args[0][0] == 0.150 + + # Verify attributes + attrs = call_args[1]['attributes'] + assert attrs[REDIS_CLIENT_STREAM_NAME] == 'my-stream' + assert attrs[REDIS_CLIENT_CONSUMER_GROUP] == 'my-group' + assert attrs[REDIS_CLIENT_CONSUMER_NAME] == 'consumer-1' + + def test_record_streaming_lag_minimal(self, setup_recorder): + """Test recording streaming lag with only required attributes.""" + + instruments = setup_recorder + + record_streaming_lag( + lag_seconds=0.025, + ) + + instruments.stream_lag.record.assert_called_once() + call_args = instruments.stream_lag.record.call_args + + # Verify lag value + assert call_args[0][0] == 0.025 + + def test_record_streaming_lag_with_stream_only(self, setup_recorder): + """Test recording streaming lag with stream name only.""" + + instruments = setup_recorder + + record_streaming_lag( + lag_seconds=0.500, + stream_name='events-stream', + ) + + instruments.stream_lag.record.assert_called_once() + attrs = instruments.stream_lag.record.call_args[1]['attributes'] + assert attrs[REDIS_CLIENT_STREAM_NAME] == 'events-stream' + + +class TestRecorderDisabled: + """Tests for recorder behavior when observability is disabled.""" + + def test_record_operation_duration_when_disabled(self): + """Test that recording does nothing when collector is None.""" + + reset_collector() + + with patch.object(recorder, '_get_or_create_collector', return_value=None): + # Should not raise any exception + record_operation_duration( + command_name='SET', + duration_seconds=0.005, + server_address='localhost', + server_port=6379, + ) + + reset_collector() + + def test_is_enabled_returns_false_when_disabled(self): + """Test is_enabled returns False when collector is None.""" + reset_collector() + + with patch.object(recorder, 
'_get_or_create_collector', return_value=None): + assert recorder.is_enabled() is False + + recorder.reset_collector() + + def test_all_record_functions_safe_when_disabled(self): + """Test that all record functions are safe to call when disabled.""" + + reset_collector() + + with patch.object(recorder, '_get_or_create_collector', return_value=None): + # None of these should raise + recorder.record_connection_create_time('pool', 0.1) + recorder.record_connection_count(1, 'pool', 'idle', False) + recorder.record_connection_timeout('pool') + recorder.record_connection_wait_time('pool', 0.1) + recorder.record_connection_use_time('pool', 0.1) + recorder.record_connection_closed('pool') + recorder.record_connection_relaxed_timeout('pool', 'MOVING', True) + recorder.record_connection_handoff('pool') + recorder.record_error_count('host', 6379, '127.0.0.1', 6379, Exception(), 0) + recorder.record_pubsub_message(PubSubDirection.PUBLISH) + recorder.record_streaming_lag(0.1, 'stream', 'group', 'consumer') + + recorder.reset_collector() + + +class TestResetCollector: + """Tests for reset_collector function.""" + + def test_reset_collector_clears_global(self): + """Test that reset_collector clears the global collector.""" + + reset_collector() + assert recorder._metrics_collector is None From 3bf95d20767ccd895c4e343851cee559a14e8446 Mon Sep 17 00:00:00 2001 From: vladvildanov Date: Fri, 5 Dec 2025 12:33:37 +0200 Subject: [PATCH 2/2] Added check for enabled metric groups --- redis/observability/metrics.py | 38 +++ redis/observability/recorder.py | 20 +- tests/test_observability/test_recorder.py | 313 +++++++++++++++++++++- 3 files changed, 356 insertions(+), 15 deletions(-) diff --git a/redis/observability/metrics.py b/redis/observability/metrics.py index 7a70d82a84..22991de6b6 100644 --- a/redis/observability/metrics.py +++ b/redis/observability/metrics.py @@ -190,6 +190,9 @@ def record_error_count( error_type: Error type retry_attempts: Retry attempts """ + if not hasattr(self, 
"client_errors"): + return + attrs = self.attr_builder.build_base_attributes( server_address=server_address, server_port=server_port, @@ -229,6 +232,9 @@ def record_maint_notification_count( network_peer_port: Network peer port maint_notification: Maintenance notification """ + if not hasattr(self, "maintenance_notifications"): + return + attrs = self.attr_builder.build_base_attributes( server_address=server_address, server_port=server_port, @@ -260,6 +266,9 @@ def record_connection_count( state: Connection state ('idle' or 'used') is_pubsub: Whether or not the connection is pubsub """ + if not hasattr(self, "connection_count"): + return + attrs = self.attr_builder.build_connection_pool_attributes( pool_name=pool_name, connection_state=state, @@ -274,6 +283,9 @@ def record_connection_timeout(self, pool_name: str) -> None: Args: pool_name: Connection pool name """ + if not hasattr(self, "connection_timeouts"): + return + attrs = self.attr_builder.build_connection_pool_attributes(pool_name=pool_name) self.connection_timeouts.add(1, attributes=attrs) @@ -289,6 +301,9 @@ def record_connection_create_time( pool_name: Connection pool name duration_seconds: Creation time in seconds """ + if not hasattr(self, "connection_create_time"): + return + attrs = self.attr_builder.build_connection_pool_attributes(pool_name=pool_name) self.connection_create_time.record(duration_seconds, attributes=attrs) @@ -304,6 +319,9 @@ def record_connection_wait_time( pool_name: Connection pool name duration_seconds: Wait time in seconds """ + if not hasattr(self, "connection_wait_time"): + return + attrs = self.attr_builder.build_connection_pool_attributes(pool_name=pool_name) self.connection_wait_time.record(duration_seconds, attributes=attrs) @@ -319,6 +337,9 @@ def record_connection_use_time( pool_name: Connection pool name duration_seconds: Use time in seconds """ + if not hasattr(self, "connection_use_time"): + return + attrs = 
self.attr_builder.build_connection_pool_attributes(pool_name=pool_name) self.connection_use_time.record(duration_seconds, attributes=attrs) @@ -357,6 +378,8 @@ def record_operation_duration( retry_attempts: Number of retry attempts made is_blocking: Whether the operation is a blocking command """ + if not hasattr(self, "operation_duration"): + return # Check if this command should be tracked if not self.config.should_track_command(command_name): @@ -404,6 +427,9 @@ def record_connection_closed( close_reason: Reason for closing (e.g., 'idle_timeout', 'error', 'shutdown') error_type: Error type if closed due to error """ + if not hasattr(self, "connection_closed"): + return + attrs = self.attr_builder.build_connection_pool_attributes(pool_name=pool_name) if close_reason: attrs[REDIS_CLIENT_CONNECTION_CLOSE_REASON] = close_reason @@ -425,6 +451,9 @@ def record_connection_relaxed_timeout( maint_notification: Maintenance notification type relaxed: True to count up (relaxed), False to count down (unrelaxed) """ + if not hasattr(self, "connection_relaxed_timeout"): + return + attrs = self.attr_builder.build_connection_pool_attributes(pool_name=pool_name) attrs[REDIS_CLIENT_CONNECTION_NOTIFICATION] = maint_notification self.connection_relaxed_timeout.add(1 if relaxed else -1, attributes=attrs) @@ -439,6 +468,9 @@ def record_connection_handoff( Args: pool_name: Connection pool name """ + if not hasattr(self, "connection_handoff"): + return + attrs = self.attr_builder.build_connection_pool_attributes(pool_name=pool_name) self.connection_handoff.add(1, attributes=attrs) @@ -458,6 +490,9 @@ def record_pubsub_message( channel: Pub/Sub channel name sharded: True if sharded Pub/Sub channel """ + if not hasattr(self, "pubsub_messages"): + return + attrs = self.attr_builder.build_pubsub_message_attributes( direction=direction, channel=channel, @@ -483,6 +518,9 @@ def record_streaming_lag( consumer_group: Consumer group name consumer_name: Consumer name """ + if not hasattr(self, 
"stream_lag"): + return + attrs = self.attr_builder.build_streaming_attributes( stream_name=stream_name, consumer_group=consumer_group, diff --git a/redis/observability/recorder.py b/redis/observability/recorder.py index 835e92aaf9..fee2d196bd 100644 --- a/redis/observability/recorder.py +++ b/redis/observability/recorder.py @@ -22,7 +22,7 @@ import time from typing import Optional -from redis.observability.attributes import PubSubDirection +from redis.observability.attributes import PubSubDirection, ConnectionState from redis.observability.metrics import RedisMetricsCollector # Global metrics collector instance (lazy-initialized) @@ -124,7 +124,7 @@ def record_connection_create_time( def record_connection_count( count: int, pool_name: str, - state: str, + state: ConnectionState, is_pubsub: bool = False, ) -> None: """ @@ -148,11 +148,10 @@ def record_connection_count( # try: from redis.observability.attributes import ConnectionState - connection_state = ConnectionState.IDLE if state == 'idle' else ConnectionState.USED _metrics_collector.record_connection_count( count=count, pool_name=pool_name, - state=connection_state, + state=state, is_pubsub=is_pubsub, ) # except Exception: @@ -408,14 +407,11 @@ def record_pubsub_message( if _metrics_collector is None: return - # try: - _metrics_collector.record_pubsub_message( - direction=direction, - channel=channel, - sharded=sharded, - ) - # except Exception: - # pass + _metrics_collector.record_pubsub_message( + direction=direction, + channel=channel, + sharded=sharded, + ) def record_streaming_lag( diff --git a/tests/test_observability/test_recorder.py b/tests/test_observability/test_recorder.py index b78f964dce..f39bedfc39 100644 --- a/tests/test_observability/test_recorder.py +++ b/tests/test_observability/test_recorder.py @@ -44,6 +44,7 @@ REDIS_CLIENT_CONSUMER_NAME, ) from redis.observability.config import OTelConfig, MetricGroup +from redis.observability.metrics import RedisMetricsCollector from 
redis.observability.recorder import record_operation_duration, record_connection_create_time, \ record_connection_count, record_connection_timeout, record_connection_wait_time, record_connection_use_time, \ record_connection_closed, record_connection_relaxed_timeout, record_connection_handoff, record_error_count, \ @@ -257,7 +258,7 @@ def test_record_connection_count_idle_increment(self, setup_recorder): record_connection_count( count=1, pool_name='ConnectionPool', - state='idle', + state=ConnectionState.IDLE, is_pubsub=False, ) @@ -281,7 +282,7 @@ def test_record_connection_count_used_decrement(self, setup_recorder): record_connection_count( count=-1, pool_name='ConnectionPool', - state='used', + state=ConnectionState.USED, is_pubsub=True, ) @@ -625,7 +626,7 @@ def test_all_record_functions_safe_when_disabled(self): with patch.object(recorder, '_get_or_create_collector', return_value=None): # None of these should raise recorder.record_connection_create_time('pool', 0.1) - recorder.record_connection_count(1, 'pool', 'idle', False) + recorder.record_connection_count(1, 'pool', ConnectionState.IDLE, False) recorder.record_connection_timeout('pool') recorder.record_connection_wait_time('pool', 0.1) recorder.record_connection_use_time('pool', 0.1) @@ -647,3 +648,309 @@ def test_reset_collector_clears_global(self): reset_collector() assert recorder._metrics_collector is None + + +class TestMetricGroupsDisabled: + """Tests for verifying metrics are not sent to Meter when their MetricGroup is disabled. + + These tests call recorder.record_*() functions and verify that no calls + are made to the underlying Meter instruments (.add() or .record()). 
+ """ + + def _create_collector_with_disabled_groups(self, mock_instruments, enabled_groups): + """Helper to create a collector with specific metric groups enabled.""" + mock_meter = MagicMock() + + def create_counter_side_effect(name, **kwargs): + instrument_map = { + 'redis.client.errors': mock_instruments.client_errors, + 'redis.client.maintenance.notifications': mock_instruments.maintenance_notifications, + 'db.client.connection.timeouts': mock_instruments.connection_timeouts, + 'redis.client.connection.closed': mock_instruments.connection_closed, + 'redis.client.connection.handoff': mock_instruments.connection_handoff, + 'redis.client.pubsub.messages': mock_instruments.pubsub_messages, + } + return instrument_map.get(name, MagicMock()) + + def create_up_down_counter_side_effect(name, **kwargs): + instrument_map = { + 'db.client.connection.count': mock_instruments.connection_count, + 'redis.client.connection.relaxed_timeout': mock_instruments.connection_relaxed_timeout, + } + return instrument_map.get(name, MagicMock()) + + def create_histogram_side_effect(name, **kwargs): + instrument_map = { + 'db.client.connection.create_time': mock_instruments.connection_create_time, + 'db.client.connection.wait_time': mock_instruments.connection_wait_time, + 'db.client.connection.use_time': mock_instruments.connection_use_time, + 'db.client.operation.duration': mock_instruments.operation_duration, + 'redis.client.stream.lag': mock_instruments.stream_lag, + } + return instrument_map.get(name, MagicMock()) + + mock_meter.create_counter.side_effect = create_counter_side_effect + mock_meter.create_up_down_counter.side_effect = create_up_down_counter_side_effect + mock_meter.create_histogram.side_effect = create_histogram_side_effect + + config = OTelConfig(metric_groups=enabled_groups) + + with patch('redis.observability.metrics.OTEL_AVAILABLE', True): + return RedisMetricsCollector(mock_meter, config) + + def 
test_record_operation_duration_no_meter_call_when_command_disabled(self): + """Test that record_operation_duration makes no Meter calls when COMMAND group is disabled.""" + instruments = MockInstruments() + collector = self._create_collector_with_disabled_groups( + instruments, + [MetricGroup.RESILIENCY] # No COMMAND + ) + + recorder.reset_collector() + with patch.object(recorder, '_get_or_create_collector', return_value=collector): + record_operation_duration( + command_name='SET', + duration_seconds=0.005, + server_address='localhost', + server_port=6379, + ) + + # Verify no call to the histogram's record method + instruments.operation_duration.record.assert_not_called() + + def test_record_connection_count_no_meter_call_when_connection_basic_disabled(self): + """Test that record_connection_count makes no Meter calls when CONNECTION_BASIC is disabled.""" + instruments = MockInstruments() + collector = self._create_collector_with_disabled_groups( + instruments, + [MetricGroup.COMMAND] # No CONNECTION_BASIC + ) + + recorder.reset_collector() + with patch.object(recorder, '_get_or_create_collector', return_value=collector): + record_connection_count( + count=1, + pool_name='test-pool', + state=ConnectionState.IDLE, + is_pubsub=False, + ) + + # Verify no call to the up_down_counter's add method + instruments.connection_count.add.assert_not_called() + + def test_record_connection_create_time_no_meter_call_when_connection_basic_disabled(self): + """Test that record_connection_create_time makes no Meter calls when CONNECTION_BASIC is disabled.""" + instruments = MockInstruments() + collector = self._create_collector_with_disabled_groups( + instruments, + [MetricGroup.COMMAND] # No CONNECTION_BASIC + ) + + recorder.reset_collector() + with patch.object(recorder, '_get_or_create_collector', return_value=collector): + record_connection_create_time( + pool_name='test-pool', + duration_seconds=0.050, + ) + + # Verify no call to the histogram's record method + 
instruments.connection_create_time.record.assert_not_called() + + def test_record_connection_wait_time_no_meter_call_when_connection_advanced_disabled(self): + """Test that record_connection_wait_time makes no Meter calls when CONNECTION_ADVANCED is disabled.""" + instruments = MockInstruments() + collector = self._create_collector_with_disabled_groups( + instruments, + [MetricGroup.COMMAND] # No CONNECTION_ADVANCED + ) + + recorder.reset_collector() + with patch.object(recorder, '_get_or_create_collector', return_value=collector): + record_connection_wait_time( + pool_name='test-pool', + duration_seconds=0.010, + ) + + # Verify no call to the histogram's record method + instruments.connection_wait_time.record.assert_not_called() + + def test_record_connection_use_time_no_meter_call_when_connection_advanced_disabled(self): + """Test that record_connection_use_time makes no Meter calls when CONNECTION_ADVANCED is disabled.""" + instruments = MockInstruments() + collector = self._create_collector_with_disabled_groups( + instruments, + [MetricGroup.COMMAND] # No CONNECTION_ADVANCED + ) + + recorder.reset_collector() + with patch.object(recorder, '_get_or_create_collector', return_value=collector): + record_connection_use_time( + pool_name='test-pool', + duration_seconds=0.100, + ) + + # Verify no call to the histogram's record method + instruments.connection_use_time.record.assert_not_called() + + def test_record_connection_closed_no_meter_call_when_connection_advanced_disabled(self): + """Test that record_connection_closed makes no Meter calls when CONNECTION_ADVANCED is disabled.""" + instruments = MockInstruments() + collector = self._create_collector_with_disabled_groups( + instruments, + [MetricGroup.COMMAND] # No CONNECTION_ADVANCED + ) + + recorder.reset_collector() + with patch.object(recorder, '_get_or_create_collector', return_value=collector): + record_connection_closed( + pool_name='test-pool', + close_reason='idle_timeout', + ) + + # Verify no call to the 
counter's add method + instruments.connection_closed.add.assert_not_called() + + def test_record_connection_relaxed_timeout_no_meter_call_when_connection_basic_disabled(self): + """Test that record_connection_relaxed_timeout makes no Meter calls when CONNECTION_BASIC is disabled.""" + instruments = MockInstruments() + collector = self._create_collector_with_disabled_groups( + instruments, + [MetricGroup.COMMAND] # No CONNECTION_BASIC + ) + + recorder.reset_collector() + with patch.object(recorder, '_get_or_create_collector', return_value=collector): + record_connection_relaxed_timeout( + pool_name='test-pool', + maint_notification='MOVING', + relaxed=True, + ) + + # Verify no call to the up_down_counter's add method + instruments.connection_relaxed_timeout.add.assert_not_called() + + def test_record_pubsub_message_no_meter_call_when_pubsub_disabled(self): + """Test that record_pubsub_message makes no Meter calls when PUBSUB group is disabled.""" + instruments = MockInstruments() + collector = self._create_collector_with_disabled_groups( + instruments, + [MetricGroup.COMMAND] # No PUBSUB + ) + + recorder.reset_collector() + with patch.object(recorder, '_get_or_create_collector', return_value=collector): + record_pubsub_message( + direction=PubSubDirection.PUBLISH, + channel='test-channel', + ) + + # Verify no call to the counter's add method + instruments.pubsub_messages.add.assert_not_called() + + def test_record_streaming_lag_no_meter_call_when_streaming_disabled(self): + """Test that record_streaming_lag makes no Meter calls when STREAMING group is disabled.""" + instruments = MockInstruments() + collector = self._create_collector_with_disabled_groups( + instruments, + [MetricGroup.COMMAND] # No STREAMING + ) + + recorder.reset_collector() + with patch.object(recorder, '_get_or_create_collector', return_value=collector): + record_streaming_lag( + lag_seconds=0.150, + stream_name='test-stream', + consumer_group='test-group', + consumer_name='test-consumer', + ) + 
+ # Verify no call to the histogram's record method + instruments.stream_lag.record.assert_not_called() + + def test_record_error_count_no_meter_call_when_resiliency_disabled(self): + """Test that record_error_count makes no Meter calls when RESILIENCY group is disabled.""" + instruments = MockInstruments() + collector = self._create_collector_with_disabled_groups( + instruments, + [MetricGroup.COMMAND] # No RESILIENCY + ) + + recorder.reset_collector() + with patch.object(recorder, '_get_or_create_collector', return_value=collector): + record_error_count( + server_address='localhost', + server_port=6379, + network_peer_address='127.0.0.1', + network_peer_port=6379, + error_type=Exception('test error'), + retry_attempts=0, + ) + + # Verify no call to the counter's add method + instruments.client_errors.add.assert_not_called() + + def test_all_record_functions_no_meter_calls_when_all_groups_disabled(self): + """Test that all record_* functions make no Meter calls when all groups are disabled.""" + instruments = MockInstruments() + collector = self._create_collector_with_disabled_groups( + instruments, + [] # No metric groups enabled + ) + + recorder.reset_collector() + with patch.object(recorder, '_get_or_create_collector', return_value=collector): + # Call all record functions + record_operation_duration('GET', 0.001, 'localhost', 6379) + record_connection_create_time('pool', 0.050) + record_connection_count(1, 'pool', ConnectionState.IDLE, False) + record_connection_timeout('pool') + record_connection_wait_time('pool', 0.010) + record_connection_use_time('pool', 0.100) + record_connection_closed('pool', 'shutdown') + record_connection_relaxed_timeout('pool', 'MOVING', True) + record_connection_handoff('pool') + record_error_count('localhost', 6379, '127.0.0.1', 6379, Exception('err'), 0) + record_pubsub_message(PubSubDirection.PUBLISH, 'channel') + record_streaming_lag(0.150, 'stream', 'group', 'consumer') + + # Verify no Meter instrument methods were called + 
instruments.operation_duration.record.assert_not_called() + instruments.connection_create_time.record.assert_not_called() + instruments.connection_count.add.assert_not_called() + instruments.connection_timeouts.add.assert_not_called() + instruments.connection_wait_time.record.assert_not_called() + instruments.connection_use_time.record.assert_not_called() + instruments.connection_closed.add.assert_not_called() + instruments.connection_relaxed_timeout.add.assert_not_called() + instruments.connection_handoff.add.assert_not_called() + instruments.client_errors.add.assert_not_called() + instruments.pubsub_messages.add.assert_not_called() + instruments.stream_lag.record.assert_not_called() + + def test_enabled_group_receives_meter_calls_disabled_group_does_not(self): + """Test that only enabled groups receive Meter calls.""" + instruments = MockInstruments() + collector = self._create_collector_with_disabled_groups( + instruments, + [MetricGroup.COMMAND, MetricGroup.PUBSUB] # Only COMMAND and PUBSUB enabled + ) + + recorder.reset_collector() + with patch.object(recorder, '_get_or_create_collector', return_value=collector): + # Call functions from enabled groups + record_operation_duration('GET', 0.001, 'localhost', 6379) + record_pubsub_message(PubSubDirection.PUBLISH, 'channel') + + # Call functions from disabled groups + record_connection_count(1, 'pool', ConnectionState.IDLE, False) + record_error_count('localhost', 6379, '127.0.0.1', 6379, Exception('err'), 0) + record_streaming_lag(0.150, 'stream', 'group', 'consumer') + + # Enabled groups should have received Meter calls + instruments.operation_duration.record.assert_called_once() + instruments.pubsub_messages.add.assert_called_once() + + # Disabled groups should NOT have received Meter calls + instruments.connection_count.add.assert_not_called() + instruments.client_errors.add.assert_not_called() + instruments.stream_lag.record.assert_not_called()