diff --git a/examples/agent_feature_store/feature_repo/feature_store.yaml b/examples/agent_feature_store/feature_repo/feature_store.yaml index 7cc89d536c8..5a4cafbb00b 100644 --- a/examples/agent_feature_store/feature_repo/feature_store.yaml +++ b/examples/agent_feature_store/feature_repo/feature_store.yaml @@ -16,6 +16,11 @@ offline_store: entity_key_serialization_version: 3 +mlflow: + enabled: true + tracking_uri: "http://localhost:5000" + enable_tracing: true + feature_server: type: mcp enabled: true diff --git a/pyproject.toml b/pyproject.toml index 9f25f82ad0e..0efd1403e63 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -104,6 +104,13 @@ mssql = ["ibis-framework[mssql]>=10.0.0"] oracle = ["ibis-framework[oracle]>=10.0.0"] mysql = ["pymysql", "types-PyMySQL"] openlineage = ["openlineage-python>=1.40.0"] +mlflow = [ + "mlflow>=2.14.0", + "opentelemetry-api>=1.28.0", + "opentelemetry-sdk>=1.28.0", + "opentelemetry-instrumentation-fastapi>=0.49b0", + "opentelemetry-instrumentation-httpx>=0.49b0", +] opentelemetry = ["prometheus_client", "psutil"] spark = ["pyspark>=4.0.0"] trino = ["trino>=0.305.0,<0.400.0", "regex"] diff --git a/sdk/python/feast/feature_server.py b/sdk/python/feast/feature_server.py index 43fb8485316..d25594c913f 100644 --- a/sdk/python/feast/feature_server.py +++ b/sdk/python/feast/feature_server.py @@ -251,6 +251,39 @@ async def load_static_artifacts(app: FastAPI, store): logger.warning(f"Failed to load static artifacts: {e}") +def _instrument_app_for_tracing(app: FastAPI, store: "feast.FeatureStore") -> None: + """Add OTEL instrumentation to FastAPI if tracing is enabled. + + This enables automatic extraction of ``traceparent`` HTTP headers from + incoming requests, creating server spans that link to the caller's trace. + This is the Tier 3 bridge: when an agent sends traceparent, server spans + become children of the agent's trace tree. + """ + mlflow_cfg = store.config.mlflow + if mlflow_cfg is None or not mlflow_cfg.enabled or not mlflow_cfg.enable_tracing: + return + + from feast.tracing import _is_embedded_store + + tracking_uri = mlflow_cfg.get_tracking_uri() + if _is_embedded_store(store) and tracking_uri and tracking_uri.startswith("http"): + logger.info( + "Skipping FastAPI OTEL instrumentation (embedded store + HTTP tracking)" + ) + return + + try: + from opentelemetry.instrumentation.fastapi import FastAPIInstrumentor + + FastAPIInstrumentor.instrument_app(app) + logger.info("FastAPI OTEL instrumentation enabled for trace propagation") + except ImportError: + logger.debug( + "opentelemetry-instrumentation-fastapi not installed; " + "cross-process trace linking disabled" + ) + + def get_app( store: "feast.FeatureStore", registry_ttl_sec: int = DEFAULT_FEATURE_SERVER_REGISTRY_TTL, @@ -360,44 +393,64 @@ async def lifespan(app: FastAPI): app = FastAPI(lifespan=lifespan) + _instrument_app_for_tracing(app, store) + @app.post( "/get-online-features", dependencies=[Depends(inject_user_details)], response_model=OnlineFeaturesResponse, ) - async def get_online_features(request: GetOnlineFeaturesRequest) -> Any: - with feast_metrics.track_request_latency( - "/get-online-features", - ) as metrics_ctx: - features = await _get_features(request, store) - feat_count, fv_count = _resolve_feature_counts(features) - metrics_ctx.feature_count = feat_count - metrics_ctx.feature_view_count = fv_count - - entity_count = len(next(iter(request.entities.values()), [])) - feast_metrics.track_online_features_entities(entity_count) - - read_params = dict( - features=features, - entity_rows=request.entities, - full_feature_names=request.full_feature_names, - include_feature_view_version_metadata=request.include_feature_view_version_metadata, - ) + async def get_online_features( + request: GetOnlineFeaturesRequest, raw_request: Request + ) -> Any: + from feast.tracing import traced_tool_span - if store._get_provider().async_supported.online.read: - response = await store.get_online_features_async(**read_params) # type: ignore - else: - response = await run_in_threadpool( - lambda: store.get_online_features(**read_params) # type: ignore + session_id = raw_request.headers.get("mcp-session-id", "") + feature_refs = ",".join(request.features) if request.features else "" + entity_count = len(next(iter(request.entities.values()), [])) + + with traced_tool_span( + store, + "feast.get_online_features", + attributes={ + "feast.mcp_session_id": session_id, + "feast.feature_refs": feature_refs, + "feast.entity_count": str(entity_count), + "feast.project": store.config.project, + "feast.retrieval_type": "online", + }, + ): + with feast_metrics.track_request_latency( + "/get-online-features", + ) as metrics_ctx: + features = await _get_features(request, store) + feat_count, fv_count = _resolve_feature_counts(features) + metrics_ctx.feature_count = feat_count + metrics_ctx.feature_view_count = fv_count + + feast_metrics.track_online_features_entities(entity_count) + + read_params = dict( + features=features, + entity_rows=request.entities, + full_feature_names=request.full_feature_names, + include_feature_view_version_metadata=request.include_feature_view_version_metadata, ) - response_dict = await run_in_threadpool( - MessageToDict, - response.proto, - preserving_proto_field_name=True, - float_precision=18, - ) - return response_dict + if store._get_provider().async_supported.online.read: + response = await store.get_online_features_async(**read_params) # type: ignore + else: + response = await run_in_threadpool( + lambda: store.get_online_features(**read_params) # type: ignore + ) + + response_dict = await run_in_threadpool( + MessageToDict, + response.proto, + preserving_proto_field_name=True, + float_precision=18, + ) + return response_dict @app.post( "/retrieve-online-documents", @@ -405,41 +458,58 @@ async def get_online_features(request: GetOnlineFeaturesRequest) -> Any: response_model=OnlineFeaturesResponse, ) async def retrieve_online_documents( - request: GetOnlineDocumentsRequest, + request: GetOnlineDocumentsRequest, raw_request: Request ) -> Any: - with feast_metrics.track_request_latency("/retrieve-online-documents"): - logger.warning( - "This endpoint is in alpha and will be moved to /get-online-features when stable." - ) - features = await _get_features(request, store) + from feast.tracing import traced_tool_span - read_params = dict( - features=features, - query=request.query, - top_k=request.top_k, - ) - if request.api_version == 2 and request.query_string is not None: - read_params["query_string"] = request.query_string + session_id = raw_request.headers.get("mcp-session-id", "") + feature_refs = ",".join(request.features) if request.features else "" + top_k = str(request.top_k) if request.top_k else "" - if request.api_version == 2: - read_params["include_feature_view_version_metadata"] = ( - request.include_feature_view_version_metadata - ) - response = await run_in_threadpool( - lambda: store.retrieve_online_documents_v2(**read_params) # type: ignore + with traced_tool_span( + store, + "feast.retrieve_online_documents", + attributes={ + "feast.mcp_session_id": session_id, + "feast.feature_refs": feature_refs, + "feast.top_k": top_k, + "feast.project": store.config.project, + "feast.retrieval_type": "document", + }, + ): + with feast_metrics.track_request_latency("/retrieve-online-documents"): + logger.warning( + "This endpoint is in alpha and will be moved to /get-online-features when stable." ) - else: - response = await run_in_threadpool( - lambda: store.retrieve_online_documents(**read_params) # type: ignore + features = await _get_features(request, store) + + read_params = dict( + features=features, + query=request.query, + top_k=request.top_k, ) + if request.api_version == 2 and request.query_string is not None: + read_params["query_string"] = request.query_string - response_dict = await run_in_threadpool( - MessageToDict, - response.proto, - preserving_proto_field_name=True, - float_precision=18, - ) - return response_dict + if request.api_version == 2: + read_params["include_feature_view_version_metadata"] = ( + request.include_feature_view_version_metadata + ) + response = await run_in_threadpool( + lambda: store.retrieve_online_documents_v2(**read_params) # type: ignore + ) + else: + response = await run_in_threadpool( + lambda: store.retrieve_online_documents(**read_params) # type: ignore + ) + + response_dict = await run_in_threadpool( + MessageToDict, + response.proto, + preserving_proto_field_name=True, + float_precision=18, + ) + return response_dict @app.post("/push", dependencies=[Depends(inject_user_details)]) async def push(request: PushFeaturesRequest) -> Response: @@ -550,19 +620,34 @@ async def _get_feast_object( ) @app.post("/write-to-online-store", dependencies=[Depends(inject_user_details)]) - async def write_to_online_store(request: WriteToFeatureStoreRequest) -> None: - df = pd.DataFrame(request.df) - feature_view_name = request.feature_view_name - allow_registry_cache = request.allow_registry_cache - resource = await _get_feast_object(feature_view_name, allow_registry_cache) - assert_permissions(resource=resource, actions=[AuthzedAction.WRITE_ONLINE]) - await run_in_threadpool( - store.write_to_online_store, - feature_view_name=feature_view_name, - df=df, - allow_registry_cache=allow_registry_cache, - transform_on_write=request.transform_on_write, - ) + async def write_to_online_store( + request: WriteToFeatureStoreRequest, raw_request: Request + ) -> None: + from feast.tracing import traced_tool_span + + session_id = raw_request.headers.get("mcp-session-id", "") + + with traced_tool_span( + store, + "feast.write_to_online_store", + attributes={ + "feast.mcp_session_id": session_id, + "feast.feature_view": request.feature_view_name, + "feast.project": store.config.project, + }, + ): + df = pd.DataFrame(request.df) + feature_view_name = request.feature_view_name + allow_registry_cache = request.allow_registry_cache + resource = await _get_feast_object(feature_view_name, allow_registry_cache) + assert_permissions(resource=resource, actions=[AuthzedAction.WRITE_ONLINE]) + await run_in_threadpool( + store.write_to_online_store, + feature_view_name=feature_view_name, + df=df, + allow_registry_cache=allow_registry_cache, + transform_on_write=request.transform_on_write, + ) @app.get("/health") async def health(): diff --git a/sdk/python/feast/infra/mcp_servers/mcp_server.py b/sdk/python/feast/infra/mcp_servers/mcp_server.py index 972023cdd12..ecf513b9de2 100644 --- a/sdk/python/feast/infra/mcp_servers/mcp_server.py +++ b/sdk/python/feast/infra/mcp_servers/mcp_server.py @@ -37,13 +37,20 @@ def add_mcp_support_to_app(app, store: FeatureStore, config) -> Optional["FastAp return None try: - # Create MCP server from the FastAPI app + # Create MCP server from the FastAPI app. + # Forward mcp-session-id so endpoint handlers can tag spans with it. mcp = FastApiMCP( app, name=getattr(config, "mcp_server_name", "feast-feature-store"), description="Feast Feature Store MCP Server - Access feature store data and operations through MCP", + headers=["authorization", "mcp-session-id"], ) + # Instrument the internal httpx client with OTEL so that trace + # context propagates from the /mcp server span into the internal + # ASGI calls to the actual FastAPI endpoints (Tier 3 enabler). + _instrument_mcp_http_client(mcp) + transport = getattr(config, "mcp_transport", "sse") if transport == "http": mount_http = getattr(mcp, "mount_http", None) @@ -83,3 +90,28 @@ def add_mcp_support_to_app(app, store: FeatureStore, config) -> Optional["FastAp except Exception as e: logger.error(f"Failed to initialize MCP integration: {e}", exc_info=True) return None + + +def _instrument_mcp_http_client(mcp: "FastApiMCP") -> None: + """Instrument fastapi_mcp's internal httpx client with OTEL. + + This ensures that when fastapi_mcp makes internal ASGI calls to + the actual FastAPI endpoints, the current OTEL trace context + (from the /mcp server span) is propagated via traceparent headers. + Without this, the endpoint spans would be orphaned from the + incoming trace. + """ + try: + from opentelemetry.instrumentation.httpx import HTTPXClientInstrumentor + + http_client = getattr(mcp, "_http_client", None) + if http_client is not None: + HTTPXClientInstrumentor.instrument_client(http_client) + logger.info("MCP internal httpx client instrumented for trace propagation") + else: + logger.debug("Could not access fastapi_mcp internal httpx client") + except ImportError: + logger.debug( + "opentelemetry-instrumentation-httpx not installed; " + "internal trace propagation disabled" + ) diff --git a/sdk/python/feast/mlflow_integration/config.py b/sdk/python/feast/mlflow_integration/config.py new file mode 100644 index 00000000000..27344666a9f --- /dev/null +++ b/sdk/python/feast/mlflow_integration/config.py @@ -0,0 +1,69 @@ +import os +from typing import Optional + +from pydantic import StrictBool, StrictInt, StrictStr + +from feast.repo_config import FeastBaseModel + +MLFLOW_TAG_MAX_LENGTH = 5000 +MLFLOW_TAG_TRUNCATION_LIMIT = MLFLOW_TAG_MAX_LENGTH - 10 +MLFLOW_TAG_TRUNCATION_SLICE = MLFLOW_TAG_MAX_LENGTH - 13 + +MLFLOW_PARAM_MAX_LENGTH = 500 +MLFLOW_PARAM_TRUNCATION_LIMIT = MLFLOW_PARAM_MAX_LENGTH - 10 +MLFLOW_PARAM_TRUNCATION_SLICE = MLFLOW_PARAM_MAX_LENGTH - 13 + +DEFAULT_ENTITY_DF_MAX_ROWS = 100_000 + + +def resolve_tracking_uri(configured_uri: Optional[str] = None) -> Optional[str]: + """Return the effective MLflow tracking URI. + + Priority: + 1. Explicitly configured URI from feature_store.yaml + 2. MLFLOW_TRACKING_URI environment variable (MLflow's native convention) + 3. None — let MLflow fall back to its own defaults (local ./mlruns) + """ + if configured_uri: + return configured_uri + return os.environ.get("MLFLOW_TRACKING_URI") + + +class MlflowConfig(FeastBaseModel): + enabled: StrictBool = False + """ bool: Whether MLflow integration is enabled. Defaults to False. """ + + tracking_uri: Optional[StrictStr] = None + """ str: MLflow tracking URI. When not set, the MLFLOW_TRACKING_URI + environment variable is used. If neither is set, MLflow falls back + to its own default (local ./mlruns directory). """ + + auto_log: StrictBool = True + """ bool: Automatically log feature retrieval metadata to the active + MLflow run when get_historical_features or get_online_features is + called. Defaults to True. """ + + auto_log_entity_df: StrictBool = False + """ bool: When True, the input entity_df (or SQL query) is recorded in + the MLflow run. Defaults to False. """ + + entity_df_max_rows: StrictInt = DEFAULT_ENTITY_DF_MAX_ROWS + """ int: Maximum number of entity DataFrame rows to save as an MLflow + artifact. DataFrames exceeding this limit are skipped to avoid + OOM and slow uploads. Defaults to 100000. """ + + log_operations: StrictBool = False + """ bool: Log feast apply and materialize operations to a separate + MLflow experiment. Opt-in to avoid noise. Defaults to False. """ + + ops_experiment_suffix: StrictStr = "-feast-ops" + """ str: Suffix appended to the project name to form the MLflow + experiment name for operation logs. Defaults to '-feast-ops'. """ + + enable_tracing: StrictBool = True + """ bool: When True and mlflow.enabled=True, initialize OTEL TracerProvider + with MlflowSpanExporter for distributed tracing. Defaults to True. """ + + def get_tracking_uri(self) -> Optional[str]: + """Resolve the effective tracking URI for this config instance.""" + return resolve_tracking_uri(self.tracking_uri) diff --git a/sdk/python/feast/repo_config.py b/sdk/python/feast/repo_config.py index 3fbcb9ec498..c9885324a7a 100644 --- a/sdk/python/feast/repo_config.py +++ b/sdk/python/feast/repo_config.py @@ -355,6 +355,9 @@ class RepoConfig(FeastBaseModel): openlineage_config: Optional[OpenLineageConfig] = Field(None, alias="openlineage") """ Configuration for OpenLineage data lineage integration (optional). """ + mlflow: Optional[Any] = None + """ MlflowConfig: MLflow integration and tracing configuration (optional). """ + def __init__(self, **data: Any): super().__init__(**data) @@ -395,6 +398,12 @@ def __init__(self, **data: Any): if "openlineage" in data: self.openlineage_config = data["openlineage"] + # Initialize MLflow configuration + if "mlflow" in data and isinstance(data["mlflow"], dict): + from feast.mlflow_integration.config import MlflowConfig + + self.mlflow = MlflowConfig(**data["mlflow"]) + if self.entity_key_serialization_version < 3: warnings.warn( "The serialization version below 3 are deprecated. " diff --git a/sdk/python/feast/tracing.py b/sdk/python/feast/tracing.py new file mode 100644 index 00000000000..ab4fd7cb024 --- /dev/null +++ b/sdk/python/feast/tracing.py @@ -0,0 +1,168 @@ +""" +Feast MCP server tracing via OpenTelemetry + MLflow. + +Initializes an OTEL TracerProvider lazily (post-fork safe for gunicorn). +Two export modes, auto-detected: + + 1. File-based (Milvus-lite safe): + Spans export to local ./mlruns via MlflowSpanExporter. + No HTTP calls, no background threads, no segfault risk. + + 2. HTTP-based (production, non-embedded stores): + Spans export to a remote MLflow tracking server. + Enables cross-process trace stitching (Tier 3). + +Auto-detection logic: + - If online_store uses embedded milvus-lite (has ``path`` but no ``host``) + AND tracking_uri is HTTP → force file-based mode with a warning. + - Otherwise, use tracking_uri from config. +""" + +from __future__ import annotations + +import logging +from contextlib import contextmanager +from typing import TYPE_CHECKING, Any, Dict, Iterator, Optional + +if TYPE_CHECKING: + from feast.feature_store import FeatureStore + +_logger = logging.getLogger(__name__) + +_initialized = False +_enabled = False +_tracer: Any = None + + +def _is_embedded_store(store: "FeatureStore") -> bool: + """Detect if the online store is an embedded C++ runtime (milvus-lite). + + Milvus-lite embeds a C++ runtime in-process. Combined with MLflow's HTTP + export (background threads + network I/O) in a gunicorn-forked worker, + this causes segfaults. We detect this case so we can disable HTTP export. + """ + online_cfg = store.config.online_config + if online_cfg is None: + return False + store_type = "" + if isinstance(online_cfg, dict): + store_type = online_cfg.get("type", "") + has_path = bool(online_cfg.get("path")) + has_host = bool(online_cfg.get("host")) + else: + store_type = getattr(online_cfg, "type", "") or "" + has_path = bool(getattr(online_cfg, "path", None)) + has_host = bool(getattr(online_cfg, "host", None)) + + if "milvus" not in store_type.lower(): + return False + return has_path and not has_host + + +def _lazy_init(store: "FeatureStore") -> bool: + """Initialize OTEL TracerProvider on first use (post-fork safe). + + Called from ``traced_tool_span`` the first time a traced endpoint is + hit. At this point the code runs inside the gunicorn worker, so + importing heavy libraries and setting up providers is safe. + """ + global _initialized, _enabled, _tracer + + if _initialized: + return _enabled + _initialized = True + + mlflow_cfg = store.config.mlflow + if mlflow_cfg is None or not mlflow_cfg.enabled: + return False + if not mlflow_cfg.enable_tracing: + return False + + try: + from opentelemetry import trace + from opentelemetry.sdk.trace import TracerProvider + from opentelemetry.sdk.trace.export import SimpleSpanProcessor + except ImportError: + _logger.debug("opentelemetry-sdk not installed; tracing disabled") + return False + + tracking_uri = mlflow_cfg.get_tracking_uri() + is_embedded = _is_embedded_store(store) + is_http_uri = bool(tracking_uri and tracking_uri.startswith("http")) + + if is_embedded and is_http_uri: + _logger.warning( + "Embedded online store detected (milvus-lite). " + "Forcing file-based tracing to avoid segfault. " + "Use a remote Milvus cluster (online_store.host) for HTTP tracing." + ) + tracking_uri = None + + try: + import mlflow + + if tracking_uri: + mlflow.set_tracking_uri(tracking_uri) + mlflow.set_experiment(store.config.project) + except Exception as e: + _logger.warning("Failed to configure MLflow experiment: %s", e) + + try: + from mlflow.tracing.export import MlflowV3SpanExporter as Exporter + except ImportError: + try: + from mlflow.tracing.export import MlflowSpanExporter as Exporter + except ImportError: + _logger.debug("MLflow span exporter not available; tracing disabled") + return False + + provider = TracerProvider() + provider.add_span_processor(SimpleSpanProcessor(Exporter())) + trace.set_tracer_provider(provider) + _tracer = trace.get_tracer("feast.server") + _enabled = True + + mode = "HTTP" if is_http_uri and not is_embedded else "file" + _logger.info( + "Feast tracing initialized (mode=%s, project=%s)", mode, store.config.project + ) + return True + + +@contextmanager +def traced_tool_span( + store: "FeatureStore", + name: str, + attributes: Optional[Dict[str, str]] = None, +) -> Iterator[Any]: + """Context manager that creates a traced span for one tool call. + + Lazily initializes tracing on first use. If tracing is disabled + or unavailable the body runs with ``span=None``. + + The span is created within the current OTEL context, so if + FastAPI OTEL instrumentation has set a parent span (from an + incoming ``traceparent`` header), this span becomes a child of + that parent — enabling cross-process trace linking. + """ + if not _lazy_init(store): + yield None + return + + try: + from opentelemetry import context, trace + + current_ctx = context.get_current() + span = _tracer.start_span(name, context=current_ctx) + if attributes: + for k, v in attributes.items(): + span.set_attribute(k, v) + token = context.attach(trace.set_span_in_context(span)) + try: + yield span + finally: + span.end() + context.detach(token) + except Exception as exc: + _logger.debug("Traced tool span failed: %s", exc) + yield None diff --git a/sdk/python/feast/tracing_context.py b/sdk/python/feast/tracing_context.py new file mode 100644 index 00000000000..560fb1da05a --- /dev/null +++ b/sdk/python/feast/tracing_context.py @@ -0,0 +1,85 @@ +""" +Thread-local trace-scoped feature accumulator. + +During an agent→Feast round-trip the server-side retrieval pushes feature refs +into this buffer. After the response returns, the agent-side +``FeastSpanProcessor`` reads it to tag the LLM span with +``feast.context_features``. +""" + +from __future__ import annotations + +import threading +from contextlib import contextmanager +from dataclasses import dataclass, field +from typing import Dict, Iterator, List, Optional, Set + + +@dataclass +class FeastTraceContext: + """Accumulated feature-retrieval metadata for the current trace.""" + + feature_refs: List[str] = field(default_factory=list) + feature_views: Set[str] = field(default_factory=set) + feature_service: Optional[str] = None + retrieval_span_ids: List[str] = field(default_factory=list) + + def push_retrieval( + self, + feature_refs: List[str], + feature_service: Optional[str] = None, + span_id: Optional[str] = None, + ) -> None: + """Record one retrieval's metadata into this context.""" + self.feature_refs.extend(feature_refs) + for ref in feature_refs: + parts = ref.split(":", 1) + if len(parts) == 2: + self.feature_views.add(parts[0]) + if feature_service: + self.feature_service = feature_service + if span_id: + self.retrieval_span_ids.append(span_id) + + def get_context_attributes(self) -> Dict[str, str]: + """Return span-attribute-ready dict of accumulated metadata.""" + if not self.feature_refs: + return {} + deduplicated = sorted(set(self.feature_refs)) + attrs: Dict[str, str] = { + "feast.context_features": ",".join(deduplicated), + "feast.context_feature_count": str(len(deduplicated)), + } + if self.feature_views: + attrs["feast.context_feature_views"] = ",".join( + sorted(self.feature_views) + ) + if self.feature_service: + attrs["feast.context_feature_service"] = self.feature_service + return attrs + + def clear(self) -> None: + self.feature_refs.clear() + self.feature_views.clear() + self.feature_service = None + self.retrieval_span_ids.clear() + + +_thread_local = threading.local() + + +def get_current_context() -> Optional[FeastTraceContext]: + """Return the active ``FeastTraceContext`` for this thread, or ``None``.""" + return getattr(_thread_local, "feast_ctx", None) + + +@contextmanager +def feast_trace_scope() -> Iterator[FeastTraceContext]: + """Context manager that creates and cleans up a ``FeastTraceContext``.""" + ctx = FeastTraceContext() + _thread_local.feast_ctx = ctx + try: + yield ctx + finally: + ctx.clear() + _thread_local.feast_ctx = None diff --git a/sdk/python/feast/tracing_hooks.py b/sdk/python/feast/tracing_hooks.py new file mode 100644 index 00000000000..53aa54bf1b8 --- /dev/null +++ b/sdk/python/feast/tracing_hooks.py @@ -0,0 +1,50 @@ +""" +MLflow SpanProcessor that tags LLM spans with Feast feature context. + +Install via ``install_feast_span_processor()`` which registers a callable +with ``mlflow.tracing.configure(span_processors=[...])``. The processor +only fires on ``LLM`` / ``CHAT_MODEL`` span types and reads the thread-local +``FeastTraceContext`` populated during feature retrieval. +""" + +from __future__ import annotations + +import logging +from typing import Any + +_logger = logging.getLogger(__name__) + +_LLM_SPAN_TYPES = {"LLM", "CHAT_MODEL"} + + +def feast_span_processor(span: Any) -> None: + """Callable for ``mlflow.tracing.configure(span_processors=[...])``.""" + span_type = getattr(span, "span_type", None) + if span_type not in _LLM_SPAN_TYPES: + return + + from feast.tracing_context import get_current_context + + ctx = get_current_context() + if ctx is None or not ctx.feature_refs: + return + + attrs = ctx.get_context_attributes() + for key, value in attrs.items(): + try: + span.set_attribute(key, value) + except Exception: + _logger.debug("Failed to set attribute %s on LLM span", key) + + +def install_feast_span_processor() -> None: + """Register ``feast_span_processor`` with the MLflow tracing system.""" + try: + import mlflow.tracing + + mlflow.tracing.configure(span_processors=[feast_span_processor]) + _logger.debug("Feast span processor installed") + except (ImportError, AttributeError): + _logger.debug( + "mlflow.tracing.configure not available; span processor not installed" + ) diff --git a/sdk/python/tests/integration/test_tracing_integration.py b/sdk/python/tests/integration/test_tracing_integration.py new file mode 100644 index 00000000000..afec6e89757 --- /dev/null +++ b/sdk/python/tests/integration/test_tracing_integration.py @@ -0,0 +1,235 @@ +# Copyright 2026 The Feast Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Integration tests for OTEL tracing with MLflow backend.""" + +from __future__ import annotations + +from datetime import datetime, timedelta + +import numpy as np +import pandas as pd +import pytest + +from feast import Entity, FeatureService, FeatureStore, FeatureView, Field, FileSource +from feast.infra.online_stores.sqlite import SqliteOnlineStoreConfig +from feast.repo_config import RepoConfig +from feast.types import Float32, Int64 + +mlflow = pytest.importorskip("mlflow", reason="mlflow is not installed") + +try: + from opentelemetry import trace as otel_trace # noqa: F401 + from opentelemetry.sdk.trace import TracerProvider # noqa: F401 + from opentelemetry.sdk.trace.export.in_memory import ( + InMemorySpanExporter, # type: ignore[import-untyped] # noqa: F401 + ) + + HAS_OTEL = True +except ImportError: + HAS_OTEL = False + +pytestmark = pytest.mark.skipif(not HAS_OTEL, reason="opentelemetry-sdk not installed") + + +import feast.mlflow # noqa: E402 +from feast.mlflow_integration import MlflowConfig # noqa: E402 + +# --------------------------------------------------------------------------- +# Fixtures +# --------------------------------------------------------------------------- + + +@pytest.fixture(autouse=True) +def _isolate_globals(): + """Reset module-level state between tests.""" + feast.mlflow._client = None + feast.mlflow._registered_store = None + + import feast.tracing + + feast.tracing._initialized = False + feast.tracing._enabled = False + feast.tracing._tracer = None + yield + feast.tracing._initialized = False + feast.tracing._enabled = False + feast.tracing._tracer = None + + +@pytest.fixture() +def tracking_uri(tmp_path): + uri = str(tmp_path / "mlruns") + mlflow.set_tracking_uri(uri) + mlflow.set_experiment("test_tracing") + yield uri + mlflow.set_tracking_uri("") + + +@pytest.fixture() +def driver_parquet(tmp_path): + data_dir = tmp_path / "data" + data_dir.mkdir() + + end = datetime.now().replace(microsecond=0, second=0, minute=0) + start = end - timedelta(days=7) + timestamps = pd.date_range(start, end, freq="h") + driver_ids = [1001, 1002] + + np.random.seed(42) + rows = [ + { + "driver_id": did, + "event_timestamp": ts, + "created": ts, + "conv_rate": float(np.random.uniform(0, 1)), + "acc_rate": float(np.random.uniform(0, 1)), + "avg_daily_trips": int(np.random.randint(1, 100)), + } + for ts in timestamps + for did in driver_ids + ] + df = pd.DataFrame(rows) + path = str(data_dir / "driver_stats.parquet") + df.to_parquet(path) + return tmp_path, path + + +@pytest.fixture() +def feast_objects(driver_parquet): + _, parquet_path = driver_parquet + + driver = Entity(name="driver", join_keys=["driver_id"]) + source = FileSource( + name="driver_stats_source", + path=parquet_path, + timestamp_field="event_timestamp", + created_timestamp_column="created", + ) + fv = FeatureView( + name="driver_hourly_stats", + entities=[driver], + ttl=timedelta(days=7), + schema=[ + Field(name="conv_rate", dtype=Float32), + Field(name="acc_rate", dtype=Float32), + Field(name="avg_daily_trips", dtype=Int64), + ], + online=True, + source=source, + ) + fs = FeatureService(name="driver_activity_v1", features=[fv]) + return driver, source, fv, fs + + +def _make_store(tmp_path, tracking_uri, *, enable_tracing=True): + data_dir = tmp_path / "data" + data_dir.mkdir(exist_ok=True) + + config = RepoConfig( + project="test_tracing", + provider="local", + registry=str(data_dir / "registry.db"), + online_store=SqliteOnlineStoreConfig(path=str(data_dir / "online.db")), + entity_key_serialization_version=3, + mlflow=MlflowConfig( + enabled=True, + tracking_uri=tracking_uri, + auto_log=True, + enable_tracing=enable_tracing, + ), + ) + return FeatureStore(config=config) + + +@pytest.fixture() +def store_with_tracing(driver_parquet, tracking_uri, feast_objects): + tmp_path, _ = driver_parquet + store = _make_store(tmp_path, tracking_uri, enable_tracing=True) + store.apply(list(feast_objects)) + store.materialize( + start_date=datetime.now() - timedelta(days=7), + end_date=datetime.now(), + ) + return store + + +# --------------------------------------------------------------------------- +# Tests +# --------------------------------------------------------------------------- + + +class TestTracingInitialization: + def test_lazy_init_enables_when_configured(self, store_with_tracing): + import feast.tracing + + result = feast.tracing._lazy_init(store_with_tracing) + assert result is True + assert feast.tracing._enabled is True + + def test_lazy_init_disabled_when_tracing_off( + self, driver_parquet, tracking_uri, feast_objects + ): + import feast.tracing + + tmp_path, _ = driver_parquet + store = _make_store(tmp_path, tracking_uri, enable_tracing=False) + result = feast.tracing._lazy_init(store) + assert result is False + + +class TestOnlineFeatureTracing: + def test_trace_context_populated_after_retrieval(self, store_with_tracing): + from feast.tracing_context import feast_trace_scope + + with feast_trace_scope() as ctx: + store_with_tracing.get_online_features( + features=["driver_hourly_stats:conv_rate"], + entity_rows=[{"driver_id": 1001}], + ) + assert len(ctx.feature_refs) > 0 + assert any("conv_rate" in ref for ref in ctx.feature_refs) + + def test_get_online_features_returns_data(self, store_with_tracing): + response = store_with_tracing.get_online_features( + features=[ + "driver_hourly_stats:conv_rate", + "driver_hourly_stats:acc_rate", + ], + entity_rows=[{"driver_id": 1001}], + ) + result = response.to_dict() + assert "driver_id" in result + assert "conv_rate" in result + + +class TestHistoricalFeatureTracing: + def test_trace_context_populated_after_historical_retrieval( + self, store_with_tracing + ): + from feast.tracing_context import feast_trace_scope + + entity_df = pd.DataFrame( + { + "driver_id": [1001], + "event_timestamp": [datetime.now() - timedelta(hours=1)], + } + ) + + with feast_trace_scope() as ctx: + store_with_tracing.get_historical_features( + entity_df=entity_df, + features=["driver_hourly_stats:conv_rate"], + ).to_df() + assert len(ctx.feature_refs) > 0 diff --git a/sdk/python/tests/unit/test_tracing.py b/sdk/python/tests/unit/test_tracing.py new file mode 100644 index 00000000000..fb07799b7de --- /dev/null +++ b/sdk/python/tests/unit/test_tracing.py @@ -0,0 +1,326 @@ +# Copyright 2026 The Feast Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Unit tests for feast.tracing, feast.tracing_context, and feast.tracing_hooks.""" + +from __future__ import annotations + +from types import SimpleNamespace +from unittest.mock import MagicMock, patch + +# --------------------------------------------------------------------------- +# tracing_context tests +# --------------------------------------------------------------------------- + + +class TestFeastTraceContext: + def test_push_retrieval_accumulates_refs(self): + from feast.tracing_context import FeastTraceContext + + ctx = FeastTraceContext() + ctx.push_retrieval( + feature_refs=["fv1:f1", "fv1:f2"], + feature_service="my_service", + span_id="span-1", + ) + assert ctx.feature_refs == ["fv1:f1", "fv1:f2"] + assert ctx.feature_views == {"fv1"} + assert ctx.feature_service == "my_service" + assert ctx.retrieval_span_ids == ["span-1"] + + def test_push_retrieval_multiple_calls(self): + from feast.tracing_context import FeastTraceContext + + ctx = FeastTraceContext() + ctx.push_retrieval(["fv1:f1"], "svc1", "s1") + ctx.push_retrieval(["fv2:f3"], "svc2", "s2") + assert len(ctx.feature_refs) == 2 + assert ctx.feature_views == {"fv1", "fv2"} + assert ctx.feature_service == "svc2" + assert ctx.retrieval_span_ids == ["s1", "s2"] + + def test_get_context_attributes_empty(self): + from feast.tracing_context import FeastTraceContext + + ctx = FeastTraceContext() + assert ctx.get_context_attributes() == {} + + def test_get_context_attributes_populated(self): + from feast.tracing_context import FeastTraceContext + + ctx = FeastTraceContext() + ctx.push_retrieval(["fv1:f2", "fv1:f1", "fv2:f3"], "my_svc") + + attrs = ctx.get_context_attributes() + assert attrs["feast.context_features"] == "fv1:f1,fv1:f2,fv2:f3" + assert attrs["feast.context_feature_count"] == "3" + assert "fv1" in attrs["feast.context_feature_views"] + assert "fv2" in attrs["feast.context_feature_views"] + assert attrs["feast.context_feature_service"] == "my_svc" + + def test_get_context_attributes_deduplicates(self): + from feast.tracing_context import FeastTraceContext + + ctx = FeastTraceContext() + ctx.push_retrieval(["fv1:f1", "fv1:f1"]) + attrs = ctx.get_context_attributes() + assert attrs["feast.context_feature_count"] == "1" + + def test_clear(self): + from feast.tracing_context import FeastTraceContext + + ctx = FeastTraceContext() + ctx.push_retrieval(["fv1:f1"], "svc", "s1") + ctx.clear() + assert ctx.feature_refs == [] + assert ctx.feature_views == set() + assert ctx.feature_service is None + assert ctx.retrieval_span_ids == [] + + +class TestFeastTraceScope: + def test_scope_creates_and_cleans_up(self): + from feast.tracing_context import feast_trace_scope, get_current_context + + assert get_current_context() is None + with feast_trace_scope() as ctx: + assert get_current_context() is ctx + ctx.push_retrieval(["fv1:f1"]) + assert len(ctx.feature_refs) == 1 + assert get_current_context() is None + + def test_nested_scope_replaces_outer(self): + from feast.tracing_context import feast_trace_scope, get_current_context + + with feast_trace_scope() as outer: + outer.push_retrieval(["fv1:f1"]) + with feast_trace_scope() as inner: + assert get_current_context() is inner + assert inner.feature_refs == [] + assert get_current_context() is None + + +# --------------------------------------------------------------------------- +# tracing_hooks tests +# --------------------------------------------------------------------------- + + +class TestFeastSpanProcessor: + def test_skips_non_llm_span(self): + from feast.tracing_hooks import feast_span_processor + + span = MagicMock() + span.span_type = "RETRIEVER" + feast_span_processor(span) + span.set_attribute.assert_not_called() + + def test_skips_when_no_context(self): + from feast.tracing_hooks import feast_span_processor + + span = MagicMock() + span.span_type = "LLM" + feast_span_processor(span) + span.set_attribute.assert_not_called() + + def test_tags_llm_span_with_feast_context(self): + from feast.tracing_context import feast_trace_scope + from feast.tracing_hooks import feast_span_processor + + span = MagicMock() + span.span_type = "LLM" + + with feast_trace_scope() as ctx: + ctx.push_retrieval(["fv1:f1", "fv2:f2"], "my_svc") + feast_span_processor(span) + + calls = {c.args[0]: c.args[1] for c in span.set_attribute.call_args_list} + assert "feast.context_features" in calls + assert "fv1:f1" in calls["feast.context_features"] + assert "fv2:f2" in calls["feast.context_features"] + assert calls["feast.context_feature_service"] == "my_svc" + + def test_tags_chat_model_span(self): + from feast.tracing_context import feast_trace_scope + from feast.tracing_hooks import feast_span_processor + + span = MagicMock() + span.span_type = "CHAT_MODEL" + + with feast_trace_scope() as ctx: + ctx.push_retrieval(["fv1:f1"]) + feast_span_processor(span) + + span.set_attribute.assert_called() + + +class TestInstallFeastSpanProcessor: + def test_install_calls_configure(self): + try: + import mlflow.tracing + + has_mlflow_tracing = True + except (ImportError, AttributeError): + has_mlflow_tracing = False + + if has_mlflow_tracing: + with patch.object(mlflow.tracing, "configure") as mock_configure: + from feast.tracing_hooks import install_feast_span_processor + + install_feast_span_processor() + mock_configure.assert_called_once() + else: + mock_tracing = MagicMock() + with patch.dict("sys.modules", {"mlflow.tracing": mock_tracing}): + import importlib + + import feast.tracing_hooks + + importlib.reload(feast.tracing_hooks) + feast.tracing_hooks.install_feast_span_processor() + mock_tracing.configure.assert_called_once() + + def test_install_graceful_on_missing_mlflow(self): + from feast.tracing_hooks import install_feast_span_processor + + try: + import mlflow.tracing + + with patch.object(mlflow.tracing, "configure", side_effect=AttributeError): + install_feast_span_processor() + except (ImportError, AttributeError): + install_feast_span_processor() + + +# --------------------------------------------------------------------------- +# tracing.py tests +# --------------------------------------------------------------------------- + + +class TestLazyInit: + def setup_method(self): + import feast.tracing + + feast.tracing._initialized = False + feast.tracing._enabled = False + feast.tracing._tracer = None + + def test_disabled_when_no_mlflow_config(self): + import feast.tracing + + feast.tracing._initialized = False + store = MagicMock() + store.config.mlflow = None + assert feast.tracing._lazy_init(store) is False + + def test_disabled_when_mlflow_not_enabled(self): + import feast.tracing + + feast.tracing._initialized = False + store = MagicMock() + store.config.mlflow.enabled = False + assert feast.tracing._lazy_init(store) is False + + def test_disabled_when_enable_tracing_false(self): + import feast.tracing + + feast.tracing._initialized = False + store = MagicMock() + store.config.mlflow.enabled = True + store.config.mlflow.enable_tracing = False + assert feast.tracing._lazy_init(store) is False + + def test_idempotent(self): + import feast.tracing + + feast.tracing._initialized = False + store = MagicMock() + store.config.mlflow = None + feast.tracing._lazy_init(store) + feast.tracing._lazy_init(store) + # Should only set _initialized once (no crash on second call) + assert feast.tracing._initialized is True + + +class TestIsEmbeddedStore: + def test_milvus_with_path_no_host(self): + from feast.tracing import _is_embedded_store + + store = MagicMock() + store.config.online_config = SimpleNamespace( + type="milvus", path="data/online.db", host=None + ) + assert _is_embedded_store(store) is True + + def test_milvus_with_host(self): + from feast.tracing import _is_embedded_store + + store = MagicMock() + store.config.online_config = SimpleNamespace( + type="milvus", path=None, host="localhost" + ) + assert _is_embedded_store(store) is False + + def test_sqlite_not_embedded(self): + from feast.tracing import _is_embedded_store + + store = MagicMock() + store.config.online_config = SimpleNamespace( + type="sqlite", path="data/online.db" + ) + assert _is_embedded_store(store) is False + + def test_none_config(self): + from feast.tracing import _is_embedded_store + + store = MagicMock() + store.config.online_config = None + assert _is_embedded_store(store) is False + + def test_dict_config(self): + from feast.tracing import _is_embedded_store + + store = MagicMock() + store.config.online_config = {"type": "milvus", "path": "data/db", "host": ""} + assert _is_embedded_store(store) is True + + +class TestTracedToolSpan: + def test_noop_when_disabled(self): + import feast.tracing + + feast.tracing._initialized = False + store = MagicMock() + store.config.mlflow = None + + with feast.tracing.traced_tool_span(store, "test.span") as span: + assert span is None + + +# --------------------------------------------------------------------------- +# MlflowConfig.enable_tracing field test +# --------------------------------------------------------------------------- + + +class TestMlflowConfigEnableTracing: + def test_default_is_true(self): + from feast.mlflow_integration.config import MlflowConfig + + cfg = MlflowConfig() + assert cfg.enable_tracing is True + + def test_can_disable(self): + from feast.mlflow_integration.config import MlflowConfig + + cfg = MlflowConfig(enable_tracing=False) + assert cfg.enable_tracing is False