From d2f93be57baf45a8e0876e23a986aaadfd70513b Mon Sep 17 00:00:00 2001 From: Vanshika Vanshika Date: Thu, 14 May 2026 19:03:54 +0530 Subject: [PATCH 1/2] mlflow-tracing Signed-off-by: Vanshika Vanshika rh-pre-commit.version: 2.3.2 rh-pre-commit.check-secrets: ENABLED --- sdk/python/feast/mlflow_integration/config.py | 69 ++++ sdk/python/feast/tracing.py | 84 ++++ sdk/python/feast/tracing_context.py | 85 ++++ sdk/python/feast/tracing_hooks.py | 50 +++ .../integration/test_tracing_integration.py | 220 +++++++++++ sdk/python/tests/unit/test_tracing.py | 366 ++++++++++++++++++ 6 files changed, 874 insertions(+) create mode 100644 sdk/python/feast/mlflow_integration/config.py create mode 100644 sdk/python/feast/tracing.py create mode 100644 sdk/python/feast/tracing_context.py create mode 100644 sdk/python/feast/tracing_hooks.py create mode 100644 sdk/python/tests/integration/test_tracing_integration.py create mode 100644 sdk/python/tests/unit/test_tracing.py diff --git a/sdk/python/feast/mlflow_integration/config.py b/sdk/python/feast/mlflow_integration/config.py new file mode 100644 index 00000000000..27344666a9f --- /dev/null +++ b/sdk/python/feast/mlflow_integration/config.py @@ -0,0 +1,69 @@ +import os +from typing import Optional + +from pydantic import StrictBool, StrictInt, StrictStr + +from feast.repo_config import FeastBaseModel + +MLFLOW_TAG_MAX_LENGTH = 5000 +MLFLOW_TAG_TRUNCATION_LIMIT = MLFLOW_TAG_MAX_LENGTH - 10 +MLFLOW_TAG_TRUNCATION_SLICE = MLFLOW_TAG_MAX_LENGTH - 13 + +MLFLOW_PARAM_MAX_LENGTH = 500 +MLFLOW_PARAM_TRUNCATION_LIMIT = MLFLOW_PARAM_MAX_LENGTH - 10 +MLFLOW_PARAM_TRUNCATION_SLICE = MLFLOW_PARAM_MAX_LENGTH - 13 + +DEFAULT_ENTITY_DF_MAX_ROWS = 100_000 + + +def resolve_tracking_uri(configured_uri: Optional[str] = None) -> Optional[str]: + """Return the effective MLflow tracking URI. + + Priority: + 1. Explicitly configured URI from feature_store.yaml + 2. MLFLOW_TRACKING_URI environment variable (MLflow's native convention) + 3. 
None — let MLflow fall back to its own defaults (local ./mlruns) + """ + if configured_uri: + return configured_uri + return os.environ.get("MLFLOW_TRACKING_URI") + + +class MlflowConfig(FeastBaseModel): + enabled: StrictBool = False + """ bool: Whether MLflow integration is enabled. Defaults to False. """ + + tracking_uri: Optional[StrictStr] = None + """ str: MLflow tracking URI. When not set, the MLFLOW_TRACKING_URI + environment variable is used. If neither is set, MLflow falls back + to its own default (local ./mlruns directory). """ + + auto_log: StrictBool = True + """ bool: Automatically log feature retrieval metadata to the active + MLflow run when get_historical_features or get_online_features is + called. Defaults to True. """ + + auto_log_entity_df: StrictBool = False + """ bool: When True, the input entity_df (or SQL query) is recorded in + the MLflow run. Defaults to False. """ + + entity_df_max_rows: StrictInt = DEFAULT_ENTITY_DF_MAX_ROWS + """ int: Maximum number of entity DataFrame rows to save as an MLflow + artifact. DataFrames exceeding this limit are skipped to avoid + OOM and slow uploads. Defaults to 100000. """ + + log_operations: StrictBool = False + """ bool: Log feast apply and materialize operations to a separate + MLflow experiment. Opt-in to avoid noise. Defaults to False. """ + + ops_experiment_suffix: StrictStr = "-feast-ops" + """ str: Suffix appended to the project name to form the MLflow + experiment name for operation logs. Defaults to '-feast-ops'. """ + + enable_tracing: StrictBool = True + """ bool: When True and mlflow.enabled=True, initialize OTEL TracerProvider + with MlflowSpanExporter for distributed tracing. Defaults to True. 
""" + + def get_tracking_uri(self) -> Optional[str]: + """Resolve the effective tracking URI for this config instance.""" + return resolve_tracking_uri(self.tracking_uri) diff --git a/sdk/python/feast/tracing.py b/sdk/python/feast/tracing.py new file mode 100644 index 00000000000..73ce6595ecb --- /dev/null +++ b/sdk/python/feast/tracing.py @@ -0,0 +1,84 @@ +""" +MLflow tracing for Feast feature server. + +Each MCP tool call produces one trace with a single span containing +full Feast metadata (project, features, entity counts, etc.). + +Tracing is initialized lazily on first span creation (post-fork in +gunicorn workers) to avoid conflicts with forked processes. Uses local +file-based tracing (``./mlruns``) which is safe with embedded stores +like Milvus-lite. View traces with ``mlflow ui`` from the server dir. +""" + +from __future__ import annotations + +import logging +from contextlib import contextmanager +from typing import TYPE_CHECKING, Any, Dict, Iterator, Optional + +if TYPE_CHECKING: + from feast.feature_store import FeatureStore + +_logger = logging.getLogger(__name__) + +_initialized = False +_enabled = False + + +def _lazy_init(store: "FeatureStore") -> bool: + """Initialize tracing on first use (post-fork safe). + + Called from ``traced_tool_span`` the first time a traced endpoint + is hit. At this point the code is running inside the gunicorn + worker, so importing MLflow and starting spans is safe. 
+ """ + global _initialized, _enabled + + if _initialized: + return _enabled + _initialized = True + + mlflow_cfg = getattr(store.config, "mlflow", None) + if mlflow_cfg is None or not mlflow_cfg.enabled: + return False + if not getattr(mlflow_cfg, "enable_tracing", True): + return False + + try: + import mlflow + + if hasattr(mlflow, "start_span"): + _enabled = True + _logger.info("Feast tracing initialized (MLflow native, post-fork)") + return True + except ImportError: + pass + + return False + + +@contextmanager +def traced_tool_span( + store: "FeatureStore", + name: str, + attributes: Optional[Dict[str, str]] = None, +) -> Iterator[Any]: + """Context manager that creates a traced span for one tool call. + + Lazily initializes tracing on first use. If tracing is disabled + or unavailable the body runs with ``span=None``. + """ + if not _lazy_init(store): + yield None + return + + try: + import mlflow + + with mlflow.start_span(name=name) as span: + if attributes: + span.set_attributes(attributes) + yield span + except Exception as exc: + _logger.debug("Traced tool span failed: %s", exc) + yield None diff --git a/sdk/python/feast/tracing_context.py b/sdk/python/feast/tracing_context.py new file mode 100644 index 00000000000..560fb1da05a --- /dev/null +++ b/sdk/python/feast/tracing_context.py @@ -0,0 +1,85 @@ +""" +Thread-local trace-scoped feature accumulator. + +During an agent→Feast round-trip the server-side retrieval pushes feature refs +into this buffer. After the response returns, the agent-side +``FeastSpanProcessor`` reads it to tag the LLM span with +``feast.context_features``. 
+""" + +from __future__ import annotations + +import threading +from contextlib import contextmanager +from dataclasses import dataclass, field +from typing import Dict, Iterator, List, Optional, Set + + +@dataclass +class FeastTraceContext: + """Accumulated feature-retrieval metadata for the current trace.""" + + feature_refs: List[str] = field(default_factory=list) + feature_views: Set[str] = field(default_factory=set) + feature_service: Optional[str] = None + retrieval_span_ids: List[str] = field(default_factory=list) + + def push_retrieval( + self, + feature_refs: List[str], + feature_service: Optional[str] = None, + span_id: Optional[str] = None, + ) -> None: + """Record one retrieval's metadata into this context.""" + self.feature_refs.extend(feature_refs) + for ref in feature_refs: + parts = ref.split(":", 1) + if len(parts) == 2: + self.feature_views.add(parts[0]) + if feature_service: + self.feature_service = feature_service + if span_id: + self.retrieval_span_ids.append(span_id) + + def get_context_attributes(self) -> Dict[str, str]: + """Return span-attribute-ready dict of accumulated metadata.""" + if not self.feature_refs: + return {} + deduplicated = sorted(set(self.feature_refs)) + attrs: Dict[str, str] = { + "feast.context_features": ",".join(deduplicated), + "feast.context_feature_count": str(len(deduplicated)), + } + if self.feature_views: + attrs["feast.context_feature_views"] = ",".join( + sorted(self.feature_views) + ) + if self.feature_service: + attrs["feast.context_feature_service"] = self.feature_service + return attrs + + def clear(self) -> None: + self.feature_refs.clear() + self.feature_views.clear() + self.feature_service = None + self.retrieval_span_ids.clear() + + +_thread_local = threading.local() + + +def get_current_context() -> Optional[FeastTraceContext]: + """Return the active ``FeastTraceContext`` for this thread, or ``None``.""" + return getattr(_thread_local, "feast_ctx", None) + + +@contextmanager +def feast_trace_scope() -> 
Iterator[FeastTraceContext]: + """Context manager that creates and cleans up a ``FeastTraceContext``.""" + ctx = FeastTraceContext() + _thread_local.feast_ctx = ctx + try: + yield ctx + finally: + ctx.clear() + _thread_local.feast_ctx = None diff --git a/sdk/python/feast/tracing_hooks.py b/sdk/python/feast/tracing_hooks.py new file mode 100644 index 00000000000..53aa54bf1b8 --- /dev/null +++ b/sdk/python/feast/tracing_hooks.py @@ -0,0 +1,50 @@ +""" +MLflow SpanProcessor that tags LLM spans with Feast feature context. + +Install via ``install_feast_span_processor()`` which registers a callable +with ``mlflow.tracing.configure(span_processors=[...])``. The processor +only fires on ``LLM`` / ``CHAT_MODEL`` span types and reads the thread-local +``FeastTraceContext`` populated during feature retrieval. +""" + +from __future__ import annotations + +import logging +from typing import Any + +_logger = logging.getLogger(__name__) + +_LLM_SPAN_TYPES = {"LLM", "CHAT_MODEL"} + + +def feast_span_processor(span: Any) -> None: + """Callable for ``mlflow.tracing.configure(span_processors=[...])``.""" + span_type = getattr(span, "span_type", None) + if span_type not in _LLM_SPAN_TYPES: + return + + from feast.tracing_context import get_current_context + + ctx = get_current_context() + if ctx is None or not ctx.feature_refs: + return + + attrs = ctx.get_context_attributes() + for key, value in attrs.items(): + try: + span.set_attribute(key, value) + except Exception: + _logger.debug("Failed to set attribute %s on LLM span", key) + + +def install_feast_span_processor() -> None: + """Register ``feast_span_processor`` with the MLflow tracing system.""" + try: + import mlflow.tracing + + mlflow.tracing.configure(span_processors=[feast_span_processor]) + _logger.debug("Feast span processor installed") + except (ImportError, AttributeError): + _logger.debug( + "mlflow.tracing.configure not available; span processor not installed" + ) diff --git 
a/sdk/python/tests/integration/test_tracing_integration.py b/sdk/python/tests/integration/test_tracing_integration.py new file mode 100644 index 00000000000..20a86072b95 --- /dev/null +++ b/sdk/python/tests/integration/test_tracing_integration.py @@ -0,0 +1,220 @@ +# Copyright 2026 The Feast Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Integration tests for OTEL tracing with MLflow backend.""" + +from __future__ import annotations + +from datetime import datetime, timedelta + +import numpy as np +import pandas as pd +import pytest + +from feast import Entity, FeatureService, FeatureStore, FeatureView, Field, FileSource +from feast.infra.online_stores.sqlite import SqliteOnlineStoreConfig +from feast.repo_config import RepoConfig +from feast.types import Float32, Int64 + +mlflow = pytest.importorskip("mlflow", reason="mlflow is not installed") + +try: + from opentelemetry import trace as otel_trace # noqa: F401 + from opentelemetry.sdk.trace import TracerProvider # noqa: F401 + from opentelemetry.sdk.trace.export.in_memory import InMemorySpanExporter # type: ignore[import-untyped] # noqa: F401 + + HAS_OTEL = True +except ImportError: + HAS_OTEL = False + +pytestmark = pytest.mark.skipif(not HAS_OTEL, reason="opentelemetry-sdk not installed") + + +import feast.mlflow # noqa: E402 +from feast.mlflow_integration import MlflowConfig # noqa: E402 + +# --------------------------------------------------------------------------- +# Fixtures +# 
--------------------------------------------------------------------------- + + +@pytest.fixture(autouse=True) +def _isolate_globals(): + """Reset module-level state between tests.""" + feast.mlflow._client = None + feast.mlflow._registered_store = None + + import feast.tracing + + feast.tracing._tracer_initialized = False + yield + feast.tracing._tracer_initialized = False + + +@pytest.fixture() +def tracking_uri(tmp_path): + uri = str(tmp_path / "mlruns") + mlflow.set_tracking_uri(uri) + mlflow.set_experiment("test_tracing") + yield uri + mlflow.set_tracking_uri("") + + +@pytest.fixture() +def driver_parquet(tmp_path): + data_dir = tmp_path / "data" + data_dir.mkdir() + + end = datetime.now().replace(microsecond=0, second=0, minute=0) + start = end - timedelta(days=7) + timestamps = pd.date_range(start, end, freq="h") + driver_ids = [1001, 1002] + + np.random.seed(42) + rows = [ + { + "driver_id": did, + "event_timestamp": ts, + "created": ts, + "conv_rate": float(np.random.uniform(0, 1)), + "acc_rate": float(np.random.uniform(0, 1)), + "avg_daily_trips": int(np.random.randint(1, 100)), + } + for ts in timestamps + for did in driver_ids + ] + df = pd.DataFrame(rows) + path = str(data_dir / "driver_stats.parquet") + df.to_parquet(path) + return tmp_path, path + + +@pytest.fixture() +def feast_objects(driver_parquet): + _, parquet_path = driver_parquet + + driver = Entity(name="driver", join_keys=["driver_id"]) + source = FileSource( + name="driver_stats_source", + path=parquet_path, + timestamp_field="event_timestamp", + created_timestamp_column="created", + ) + fv = FeatureView( + name="driver_hourly_stats", + entities=[driver], + ttl=timedelta(days=7), + schema=[ + Field(name="conv_rate", dtype=Float32), + Field(name="acc_rate", dtype=Float32), + Field(name="avg_daily_trips", dtype=Int64), + ], + online=True, + source=source, + ) + fs = FeatureService(name="driver_activity_v1", features=[fv]) + return driver, source, fv, fs + + +def _make_store(tmp_path, 
tracking_uri, *, enable_tracing=True): + data_dir = tmp_path / "data" + data_dir.mkdir(exist_ok=True) + + config = RepoConfig( + project="test_tracing", + provider="local", + registry=str(data_dir / "registry.db"), + online_store=SqliteOnlineStoreConfig(path=str(data_dir / "online.db")), + entity_key_serialization_version=3, + mlflow=MlflowConfig( + enabled=True, + tracking_uri=tracking_uri, + auto_log=True, + enable_tracing=enable_tracing, + ), + ) + return FeatureStore(config=config) + + +@pytest.fixture() +def store_with_tracing(driver_parquet, tracking_uri, feast_objects): + tmp_path, _ = driver_parquet + store = _make_store(tmp_path, tracking_uri, enable_tracing=True) + store.apply(list(feast_objects)) + store.materialize( + start_date=datetime.now() - timedelta(days=7), + end_date=datetime.now(), + ) + return store + + +# --------------------------------------------------------------------------- +# Tests +# --------------------------------------------------------------------------- + + +class TestTracingInitialization: + def test_tracer_is_set_when_enabled(self, store_with_tracing): + assert store_with_tracing._tracer is not None + + def test_tracer_is_none_when_disabled(self, driver_parquet, tracking_uri, feast_objects): + tmp_path, _ = driver_parquet + store = _make_store(tmp_path, tracking_uri, enable_tracing=False) + assert store._tracer is None + + +class TestOnlineFeatureTracing: + def test_trace_context_populated_after_retrieval(self, store_with_tracing): + from feast.tracing_context import feast_trace_scope + + with feast_trace_scope() as ctx: + store_with_tracing.get_online_features( + features=["driver_hourly_stats:conv_rate"], + entity_rows=[{"driver_id": 1001}], + ) + assert len(ctx.feature_refs) > 0 + assert any("conv_rate" in ref for ref in ctx.feature_refs) + + def test_get_online_features_returns_data(self, store_with_tracing): + response = store_with_tracing.get_online_features( + features=[ + "driver_hourly_stats:conv_rate", + 
"driver_hourly_stats:acc_rate", + ], + entity_rows=[{"driver_id": 1001}], + ) + result = response.to_dict() + assert "driver_id" in result + assert "conv_rate" in result + + +class TestHistoricalFeatureTracing: + def test_trace_context_populated_after_historical_retrieval( + self, store_with_tracing + ): + from feast.tracing_context import feast_trace_scope + + entity_df = pd.DataFrame( + { + "driver_id": [1001], + "event_timestamp": [datetime.now() - timedelta(hours=1)], + } + ) + + with feast_trace_scope() as ctx: + store_with_tracing.get_historical_features( + entity_df=entity_df, + features=["driver_hourly_stats:conv_rate"], + ).to_df() + assert len(ctx.feature_refs) > 0 diff --git a/sdk/python/tests/unit/test_tracing.py b/sdk/python/tests/unit/test_tracing.py new file mode 100644 index 00000000000..de96d7af2d0 --- /dev/null +++ b/sdk/python/tests/unit/test_tracing.py @@ -0,0 +1,366 @@ +# Copyright 2026 The Feast Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Unit tests for feast.tracing, feast.tracing_context, and feast.tracing_hooks.""" + +from __future__ import annotations + +from types import SimpleNamespace +from unittest.mock import MagicMock, patch + +import pytest + + +# --------------------------------------------------------------------------- +# tracing_context tests +# --------------------------------------------------------------------------- + + +class TestFeastTraceContext: + def test_push_retrieval_accumulates_refs(self): + from feast.tracing_context import FeastTraceContext + + ctx = FeastTraceContext() + ctx.push_retrieval( + feature_refs=["fv1:f1", "fv1:f2"], + feature_service="my_service", + span_id="span-1", + ) + assert ctx.feature_refs == ["fv1:f1", "fv1:f2"] + assert ctx.feature_views == {"fv1"} + assert ctx.feature_service == "my_service" + assert ctx.retrieval_span_ids == ["span-1"] + + def test_push_retrieval_multiple_calls(self): + from feast.tracing_context import FeastTraceContext + + ctx = FeastTraceContext() + ctx.push_retrieval(["fv1:f1"], "svc1", "s1") + ctx.push_retrieval(["fv2:f3"], "svc2", "s2") + assert len(ctx.feature_refs) == 2 + assert ctx.feature_views == {"fv1", "fv2"} + assert ctx.feature_service == "svc2" + assert ctx.retrieval_span_ids == ["s1", "s2"] + + def test_get_context_attributes_empty(self): + from feast.tracing_context import FeastTraceContext + + ctx = FeastTraceContext() + assert ctx.get_context_attributes() == {} + + def test_get_context_attributes_populated(self): + from feast.tracing_context import FeastTraceContext + + ctx = FeastTraceContext() + ctx.push_retrieval(["fv1:f2", "fv1:f1", "fv2:f3"], "my_svc") + + attrs = ctx.get_context_attributes() + assert attrs["feast.context_features"] == "fv1:f1,fv1:f2,fv2:f3" + assert attrs["feast.context_feature_count"] == "3" + assert "fv1" in attrs["feast.context_feature_views"] + assert "fv2" in attrs["feast.context_feature_views"] + assert attrs["feast.context_feature_service"] == "my_svc" + + def 
test_get_context_attributes_deduplicates(self): + from feast.tracing_context import FeastTraceContext + + ctx = FeastTraceContext() + ctx.push_retrieval(["fv1:f1", "fv1:f1"]) + attrs = ctx.get_context_attributes() + assert attrs["feast.context_feature_count"] == "1" + + def test_clear(self): + from feast.tracing_context import FeastTraceContext + + ctx = FeastTraceContext() + ctx.push_retrieval(["fv1:f1"], "svc", "s1") + ctx.clear() + assert ctx.feature_refs == [] + assert ctx.feature_views == set() + assert ctx.feature_service is None + assert ctx.retrieval_span_ids == [] + + +class TestFeastTraceScope: + def test_scope_creates_and_cleans_up(self): + from feast.tracing_context import feast_trace_scope, get_current_context + + assert get_current_context() is None + with feast_trace_scope() as ctx: + assert get_current_context() is ctx + ctx.push_retrieval(["fv1:f1"]) + assert len(ctx.feature_refs) == 1 + assert get_current_context() is None + + def test_nested_scope_replaces_outer(self): + from feast.tracing_context import feast_trace_scope, get_current_context + + with feast_trace_scope() as outer: + outer.push_retrieval(["fv1:f1"]) + with feast_trace_scope() as inner: + assert get_current_context() is inner + assert inner.feature_refs == [] + assert get_current_context() is None + + +# --------------------------------------------------------------------------- +# tracing_hooks tests +# --------------------------------------------------------------------------- + + +class TestFeastSpanProcessor: + def test_skips_non_llm_span(self): + from feast.tracing_hooks import feast_span_processor + + span = MagicMock() + span.span_type = "RETRIEVER" + feast_span_processor(span) + span.set_attribute.assert_not_called() + + def test_skips_when_no_context(self): + from feast.tracing_hooks import feast_span_processor + + span = MagicMock() + span.span_type = "LLM" + feast_span_processor(span) + span.set_attribute.assert_not_called() + + def 
test_tags_llm_span_with_feast_context(self): + from feast.tracing_context import feast_trace_scope + from feast.tracing_hooks import feast_span_processor + + span = MagicMock() + span.span_type = "LLM" + + with feast_trace_scope() as ctx: + ctx.push_retrieval(["fv1:f1", "fv2:f2"], "my_svc") + feast_span_processor(span) + + calls = {c.args[0]: c.args[1] for c in span.set_attribute.call_args_list} + assert "feast.context_features" in calls + assert "fv1:f1" in calls["feast.context_features"] + assert "fv2:f2" in calls["feast.context_features"] + assert calls["feast.context_feature_service"] == "my_svc" + + def test_tags_chat_model_span(self): + from feast.tracing_context import feast_trace_scope + from feast.tracing_hooks import feast_span_processor + + span = MagicMock() + span.span_type = "CHAT_MODEL" + + with feast_trace_scope() as ctx: + ctx.push_retrieval(["fv1:f1"]) + feast_span_processor(span) + + span.set_attribute.assert_called() + + +class TestInstallFeastSpanProcessor: + def test_install_calls_configure(self): + try: + import mlflow.tracing + + has_mlflow_tracing = True + except (ImportError, AttributeError): + has_mlflow_tracing = False + + if has_mlflow_tracing: + with patch.object(mlflow.tracing, "configure") as mock_configure: + from feast.tracing_hooks import install_feast_span_processor + + install_feast_span_processor() + mock_configure.assert_called_once() + else: + mock_tracing = MagicMock() + with patch.dict("sys.modules", {"mlflow.tracing": mock_tracing}): + import importlib + + import feast.tracing_hooks + + importlib.reload(feast.tracing_hooks) + feast.tracing_hooks.install_feast_span_processor() + mock_tracing.configure.assert_called_once() + + def test_install_graceful_on_missing_mlflow(self): + from feast.tracing_hooks import install_feast_span_processor + + try: + import mlflow.tracing + + with patch.object( + mlflow.tracing, "configure", side_effect=AttributeError + ): + install_feast_span_processor() + except (ImportError, 
AttributeError): + install_feast_span_processor() + + +# --------------------------------------------------------------------------- +# tracing.py tests +# --------------------------------------------------------------------------- + + +class TestInitTracing: + def setup_method(self): + import feast.tracing + + feast.tracing._tracer_initialized = False + + def test_returns_none_when_otel_missing(self): + with patch.dict("sys.modules", {"opentelemetry": None}): + import importlib + + import feast.tracing + + importlib.reload(feast.tracing) + feast.tracing._tracer_initialized = False + result = feast.tracing.init_tracing(MagicMock()) + assert result is None + + def test_returns_tracer_with_mocked_deps(self): + mock_trace = MagicMock() + mock_tracer = MagicMock() + mock_trace.get_tracer.return_value = mock_tracer + mock_provider_cls = MagicMock() + mock_processor_cls = MagicMock() + mock_exporter_cls = MagicMock() + + modules = { + "opentelemetry": MagicMock(trace=mock_trace), + "opentelemetry.trace": mock_trace, + "opentelemetry.sdk": MagicMock(), + "opentelemetry.sdk.trace": MagicMock(TracerProvider=mock_provider_cls), + "opentelemetry.sdk.trace.export": MagicMock(SimpleSpanProcessor=mock_processor_cls), + "mlflow": MagicMock(), + "mlflow.tracing": MagicMock(), + "mlflow.tracing.export": MagicMock(MlflowSpanExporter=mock_exporter_cls), + } + + with patch.dict("sys.modules", modules): + import importlib + + import feast.tracing + + importlib.reload(feast.tracing) + feast.tracing._tracer_initialized = False + + result = feast.tracing.init_tracing(MagicMock()) + assert result is mock_tracer + mock_trace.set_tracer_provider.assert_called_once() + + def test_idempotent(self): + mock_trace = MagicMock() + mock_tracer = MagicMock() + mock_trace.get_tracer.return_value = mock_tracer + + modules = { + "opentelemetry": MagicMock(trace=mock_trace), + "opentelemetry.trace": mock_trace, + "opentelemetry.sdk": MagicMock(), + "opentelemetry.sdk.trace": 
MagicMock(TracerProvider=MagicMock()), + "opentelemetry.sdk.trace.export": MagicMock(SimpleSpanProcessor=MagicMock()), + "mlflow": MagicMock(), + "mlflow.tracing": MagicMock(), + "mlflow.tracing.export": MagicMock(MlflowSpanExporter=MagicMock()), + } + + with patch.dict("sys.modules", modules): + import importlib + + import feast.tracing + + importlib.reload(feast.tracing) + feast.tracing._tracer_initialized = False + + feast.tracing.init_tracing(MagicMock()) + feast.tracing.init_tracing(MagicMock()) + mock_trace.set_tracer_provider.assert_called_once() + + +class TestTracedDecorator: + def test_noop_without_tracer(self): + from feast.tracing import traced + + @traced("test.span") + def my_method(self_): + return 42 + + obj = SimpleNamespace() + assert my_method(obj) == 42 + + def test_creates_span_with_tracer(self): + from feast.tracing import traced + + mock_tracer = MagicMock() + mock_span = MagicMock() + mock_tracer.start_as_current_span.return_value.__enter__ = MagicMock( + return_value=mock_span + ) + mock_tracer.start_as_current_span.return_value.__exit__ = MagicMock( + return_value=False + ) + + @traced("test.span") + def my_method(self_): + return 99 + + obj = SimpleNamespace(_tracer=mock_tracer, project="test_project") + result = my_method(obj) + assert result == 99 + mock_tracer.start_as_current_span.assert_called_once_with("test.span") + mock_span.set_attribute.assert_called_with("feast.project", "test_project") + + def test_records_exception(self): + from feast.tracing import traced + + mock_tracer = MagicMock() + mock_span = MagicMock() + mock_tracer.start_as_current_span.return_value.__enter__ = MagicMock( + return_value=mock_span + ) + mock_tracer.start_as_current_span.return_value.__exit__ = MagicMock( + return_value=False + ) + + @traced("test.span") + def bad_method(self_): + raise ValueError("boom") + + obj = SimpleNamespace(_tracer=mock_tracer, project="p") + + with patch("feast.tracing._StatusCode_ERROR") as mock_status: + with 
pytest.raises(ValueError, match="boom"): + bad_method(obj) + mock_span.record_exception.assert_called_once() + mock_span.set_status.assert_called_once() + + +# --------------------------------------------------------------------------- +# MlflowConfig.enable_tracing field test +# --------------------------------------------------------------------------- + + +class TestMlflowConfigEnableTracing: + def test_default_is_true(self): + from feast.mlflow_integration.config import MlflowConfig + + cfg = MlflowConfig() + assert cfg.enable_tracing is True + + def test_can_disable(self): + from feast.mlflow_integration.config import MlflowConfig + + cfg = MlflowConfig(enable_tracing=False) + assert cfg.enable_tracing is False From 8fc6d889d8af4d6d910409a55ebd05cfa6b869d0 Mon Sep 17 00:00:00 2001 From: Vanshika Vanshika Date: Tue, 19 May 2026 17:00:59 +0530 Subject: [PATCH 2/2] cross-process trace Signed-off-by: Vanshika Vanshika rh-pre-commit.version: 2.3.2 rh-pre-commit.check-secrets: ENABLED --- .../feature_repo/feature_store.yaml | 5 + pyproject.toml | 7 + sdk/python/feast/feature_server.py | 229 ++++++++++++------ .../feast/infra/mcp_servers/mcp_server.py | 34 ++- sdk/python/feast/repo_config.py | 9 + sdk/python/feast/tracing.py | 134 ++++++++-- .../integration/test_tracing_integration.py | 29 ++- sdk/python/tests/unit/test_tracing.py | 184 ++++++-------- 8 files changed, 414 insertions(+), 217 deletions(-) diff --git a/examples/agent_feature_store/feature_repo/feature_store.yaml b/examples/agent_feature_store/feature_repo/feature_store.yaml index 7cc89d536c8..5a4cafbb00b 100644 --- a/examples/agent_feature_store/feature_repo/feature_store.yaml +++ b/examples/agent_feature_store/feature_repo/feature_store.yaml @@ -16,6 +16,11 @@ offline_store: entity_key_serialization_version: 3 +mlflow: + enabled: true + tracking_uri: "http://localhost:5000" + enable_tracing: true + feature_server: type: mcp enabled: true diff --git a/pyproject.toml b/pyproject.toml index 
9f25f82ad0e..0efd1403e63 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -104,6 +104,13 @@ mssql = ["ibis-framework[mssql]>=10.0.0"] oracle = ["ibis-framework[oracle]>=10.0.0"] mysql = ["pymysql", "types-PyMySQL"] openlineage = ["openlineage-python>=1.40.0"] +mlflow = [ + "mlflow>=2.14.0", + "opentelemetry-api>=1.28.0", + "opentelemetry-sdk>=1.28.0", + "opentelemetry-instrumentation-fastapi>=0.49b0", + "opentelemetry-instrumentation-httpx>=0.49b0", +] opentelemetry = ["prometheus_client", "psutil"] spark = ["pyspark>=4.0.0"] trino = ["trino>=0.305.0,<0.400.0", "regex"] diff --git a/sdk/python/feast/feature_server.py b/sdk/python/feast/feature_server.py index f60eeb9d87d..dd2453e4917 100644 --- a/sdk/python/feast/feature_server.py +++ b/sdk/python/feast/feature_server.py @@ -251,6 +251,39 @@ async def load_static_artifacts(app: FastAPI, store): logger.warning(f"Failed to load static artifacts: {e}") +def _instrument_app_for_tracing(app: FastAPI, store: "feast.FeatureStore") -> None: + """Add OTEL instrumentation to FastAPI if tracing is enabled. + + This enables automatic extraction of ``traceparent`` HTTP headers from + incoming requests, creating server spans that link to the caller's trace. + This is the Tier 3 bridge: when an agent sends traceparent, server spans + become children of the agent's trace tree. 
+ """ + mlflow_cfg = store.config.mlflow + if mlflow_cfg is None or not mlflow_cfg.enabled or not mlflow_cfg.enable_tracing: + return + + from feast.tracing import _is_embedded_store + + tracking_uri = mlflow_cfg.get_tracking_uri() + if _is_embedded_store(store) and tracking_uri and tracking_uri.startswith("http"): + logger.info( + "Skipping FastAPI OTEL instrumentation (embedded store + HTTP tracking)" + ) + return + + try: + from opentelemetry.instrumentation.fastapi import FastAPIInstrumentor + + FastAPIInstrumentor.instrument_app(app) + logger.info("FastAPI OTEL instrumentation enabled for trace propagation") + except ImportError: + logger.debug( + "opentelemetry-instrumentation-fastapi not installed; " + "cross-process trace linking disabled" + ) + + def get_app( store: "feast.FeatureStore", registry_ttl_sec: int = DEFAULT_FEATURE_SERVER_REGISTRY_TTL, @@ -363,44 +396,64 @@ async def lifespan(app: FastAPI): app = FastAPI(lifespan=lifespan) + _instrument_app_for_tracing(app, store) + @app.post( "/get-online-features", dependencies=[Depends(inject_user_details)], response_model=OnlineFeaturesResponse, ) - async def get_online_features(request: GetOnlineFeaturesRequest) -> Any: - with feast_metrics.track_request_latency( - "/get-online-features", - ) as metrics_ctx: - features = await _get_features(request, store) - feat_count, fv_count = _resolve_feature_counts(features) - metrics_ctx.feature_count = feat_count - metrics_ctx.feature_view_count = fv_count - - entity_count = len(next(iter(request.entities.values()), [])) - feast_metrics.track_online_features_entities(entity_count) - - read_params = dict( - features=features, - entity_rows=request.entities, - full_feature_names=request.full_feature_names, - include_feature_view_version_metadata=request.include_feature_view_version_metadata, - ) + async def get_online_features( + request: GetOnlineFeaturesRequest, raw_request: Request + ) -> Any: + from feast.tracing import traced_tool_span - if 
store._get_provider().async_supported.online.read: - response = await store.get_online_features_async(**read_params) # type: ignore - else: - response = await run_in_threadpool( - lambda: store.get_online_features(**read_params) # type: ignore + session_id = raw_request.headers.get("mcp-session-id", "") + feature_refs = ",".join(request.features) if request.features else "" + entity_count = len(next(iter(request.entities.values()), [])) + + with traced_tool_span( + store, + "feast.get_online_features", + attributes={ + "feast.mcp_session_id": session_id, + "feast.feature_refs": feature_refs, + "feast.entity_count": str(entity_count), + "feast.project": store.config.project, + "feast.retrieval_type": "online", + }, + ): + with feast_metrics.track_request_latency( + "/get-online-features", + ) as metrics_ctx: + features = await _get_features(request, store) + feat_count, fv_count = _resolve_feature_counts(features) + metrics_ctx.feature_count = feat_count + metrics_ctx.feature_view_count = fv_count + + feast_metrics.track_online_features_entities(entity_count) + + read_params = dict( + features=features, + entity_rows=request.entities, + full_feature_names=request.full_feature_names, + include_feature_view_version_metadata=request.include_feature_view_version_metadata, ) - response_dict = await run_in_threadpool( - MessageToDict, - response.proto, - preserving_proto_field_name=True, - float_precision=18, - ) - return response_dict + if store._get_provider().async_supported.online.read: + response = await store.get_online_features_async(**read_params) # type: ignore + else: + response = await run_in_threadpool( + lambda: store.get_online_features(**read_params) # type: ignore + ) + + response_dict = await run_in_threadpool( + MessageToDict, + response.proto, + preserving_proto_field_name=True, + float_precision=18, + ) + return response_dict @app.post( "/retrieve-online-documents", @@ -408,41 +461,58 @@ async def get_online_features(request: GetOnlineFeaturesRequest) 
-> Any: response_model=OnlineFeaturesResponse, ) async def retrieve_online_documents( - request: GetOnlineDocumentsRequest, + request: GetOnlineDocumentsRequest, raw_request: Request ) -> Any: - with feast_metrics.track_request_latency("/retrieve-online-documents"): - logger.warning( - "This endpoint is in alpha and will be moved to /get-online-features when stable." - ) - features = await _get_features(request, store) + from feast.tracing import traced_tool_span - read_params = dict( - features=features, - query=request.query, - top_k=request.top_k, - ) - if request.api_version == 2 and request.query_string is not None: - read_params["query_string"] = request.query_string + session_id = raw_request.headers.get("mcp-session-id", "") + feature_refs = ",".join(request.features) if request.features else "" + top_k = str(request.top_k) if request.top_k else "" - if request.api_version == 2: - read_params["include_feature_view_version_metadata"] = ( - request.include_feature_view_version_metadata - ) - response = await run_in_threadpool( - lambda: store.retrieve_online_documents_v2(**read_params) # type: ignore + with traced_tool_span( + store, + "feast.retrieve_online_documents", + attributes={ + "feast.mcp_session_id": session_id, + "feast.feature_refs": feature_refs, + "feast.top_k": top_k, + "feast.project": store.config.project, + "feast.retrieval_type": "document", + }, + ): + with feast_metrics.track_request_latency("/retrieve-online-documents"): + logger.warning( + "This endpoint is in alpha and will be moved to /get-online-features when stable." 
) - else: - response = await run_in_threadpool( - lambda: store.retrieve_online_documents(**read_params) # type: ignore + features = await _get_features(request, store) + + read_params = dict( + features=features, + query=request.query, + top_k=request.top_k, ) + if request.api_version == 2 and request.query_string is not None: + read_params["query_string"] = request.query_string - response_dict = await run_in_threadpool( - MessageToDict, - response.proto, - preserving_proto_field_name=True, - float_precision=18, - ) - return response_dict + if request.api_version == 2: + read_params["include_feature_view_version_metadata"] = ( + request.include_feature_view_version_metadata + ) + response = await run_in_threadpool( + lambda: store.retrieve_online_documents_v2(**read_params) # type: ignore + ) + else: + response = await run_in_threadpool( + lambda: store.retrieve_online_documents(**read_params) # type: ignore + ) + + response_dict = await run_in_threadpool( + MessageToDict, + response.proto, + preserving_proto_field_name=True, + float_precision=18, + ) + return response_dict @app.post("/push", dependencies=[Depends(inject_user_details)]) async def push(request: PushFeaturesRequest) -> Response: @@ -553,19 +623,34 @@ async def _get_feast_object( ) @app.post("/write-to-online-store", dependencies=[Depends(inject_user_details)]) - async def write_to_online_store(request: WriteToFeatureStoreRequest) -> None: - df = pd.DataFrame(request.df) - feature_view_name = request.feature_view_name - allow_registry_cache = request.allow_registry_cache - resource = await _get_feast_object(feature_view_name, allow_registry_cache) - assert_permissions(resource=resource, actions=[AuthzedAction.WRITE_ONLINE]) - await run_in_threadpool( - store.write_to_online_store, - feature_view_name=feature_view_name, - df=df, - allow_registry_cache=allow_registry_cache, - transform_on_write=request.transform_on_write, - ) + async def write_to_online_store( + request: WriteToFeatureStoreRequest, 
raw_request: Request + ) -> None: + from feast.tracing import traced_tool_span + + session_id = raw_request.headers.get("mcp-session-id", "") + + with traced_tool_span( + store, + "feast.write_to_online_store", + attributes={ + "feast.mcp_session_id": session_id, + "feast.feature_view": request.feature_view_name, + "feast.project": store.config.project, + }, + ): + df = pd.DataFrame(request.df) + feature_view_name = request.feature_view_name + allow_registry_cache = request.allow_registry_cache + resource = await _get_feast_object(feature_view_name, allow_registry_cache) + assert_permissions(resource=resource, actions=[AuthzedAction.WRITE_ONLINE]) + await run_in_threadpool( + store.write_to_online_store, + feature_view_name=feature_view_name, + df=df, + allow_registry_cache=allow_registry_cache, + transform_on_write=request.transform_on_write, + ) @app.get("/health") async def health(): diff --git a/sdk/python/feast/infra/mcp_servers/mcp_server.py b/sdk/python/feast/infra/mcp_servers/mcp_server.py index 972023cdd12..ecf513b9de2 100644 --- a/sdk/python/feast/infra/mcp_servers/mcp_server.py +++ b/sdk/python/feast/infra/mcp_servers/mcp_server.py @@ -37,13 +37,20 @@ def add_mcp_support_to_app(app, store: FeatureStore, config) -> Optional["FastAp return None try: - # Create MCP server from the FastAPI app + # Create MCP server from the FastAPI app. + # Forward mcp-session-id so endpoint handlers can tag spans with it. mcp = FastApiMCP( app, name=getattr(config, "mcp_server_name", "feast-feature-store"), description="Feast Feature Store MCP Server - Access feature store data and operations through MCP", + headers=["authorization", "mcp-session-id"], ) + # Instrument the internal httpx client with OTEL so that trace + # context propagates from the /mcp server span into the internal + # ASGI calls to the actual FastAPI endpoints (Tier 3 enabler). 
+ _instrument_mcp_http_client(mcp) + transport = getattr(config, "mcp_transport", "sse") if transport == "http": mount_http = getattr(mcp, "mount_http", None) @@ -83,3 +90,28 @@ def add_mcp_support_to_app(app, store: FeatureStore, config) -> Optional["FastAp except Exception as e: logger.error(f"Failed to initialize MCP integration: {e}", exc_info=True) return None + + +def _instrument_mcp_http_client(mcp: "FastApiMCP") -> None: + """Instrument fastapi_mcp's internal httpx client with OTEL. + + This ensures that when fastapi_mcp makes internal ASGI calls to + the actual FastAPI endpoints, the current OTEL trace context + (from the /mcp server span) is propagated via traceparent headers. + Without this, the endpoint spans would be orphaned from the + incoming trace. + """ + try: + from opentelemetry.instrumentation.httpx import HTTPXClientInstrumentor + + http_client = getattr(mcp, "_http_client", None) + if http_client is not None: + HTTPXClientInstrumentor.instrument_client(http_client) + logger.info("MCP internal httpx client instrumented for trace propagation") + else: + logger.debug("Could not access fastapi_mcp internal httpx client") + except ImportError: + logger.debug( + "opentelemetry-instrumentation-httpx not installed; " + "internal trace propagation disabled" + ) diff --git a/sdk/python/feast/repo_config.py b/sdk/python/feast/repo_config.py index 3fbcb9ec498..c9885324a7a 100644 --- a/sdk/python/feast/repo_config.py +++ b/sdk/python/feast/repo_config.py @@ -355,6 +355,9 @@ class RepoConfig(FeastBaseModel): openlineage_config: Optional[OpenLineageConfig] = Field(None, alias="openlineage") """ Configuration for OpenLineage data lineage integration (optional). """ + mlflow: Optional[Any] = None + """ MlflowConfig: MLflow integration and tracing configuration (optional). 
""" + def __init__(self, **data: Any): super().__init__(**data) @@ -395,6 +398,12 @@ def __init__(self, **data: Any): if "openlineage" in data: self.openlineage_config = data["openlineage"] + # Initialize MLflow configuration + if "mlflow" in data and isinstance(data["mlflow"], dict): + from feast.mlflow_integration.config import MlflowConfig + + self.mlflow = MlflowConfig(**data["mlflow"]) + if self.entity_key_serialization_version < 3: warnings.warn( "The serialization version below 3 are deprecated. " diff --git a/sdk/python/feast/tracing.py b/sdk/python/feast/tracing.py index 73ce6595ecb..ab4fd7cb024 100644 --- a/sdk/python/feast/tracing.py +++ b/sdk/python/feast/tracing.py @@ -1,13 +1,21 @@ """ -MLflow tracing for Feast feature server. +Feast MCP server tracing via OpenTelemetry + MLflow. -Each MCP tool call produces one trace with a single span containing -full Feast metadata (project, features, entity counts, etc.). +Initializes an OTEL TracerProvider lazily (post-fork safe for gunicorn). +Two export modes, auto-detected: -Tracing is initialized lazily on first span creation (post-fork in -gunicorn workers) to avoid conflicts with forked processes. Uses local -file-based tracing (``./mlruns``) which is safe with embedded stores -like Milvus-lite. View traces with ``mlflow ui`` from the server dir. + 1. File-based (Milvus-lite safe): + Spans export to local ./mlruns via MlflowSpanExporter. + No HTTP calls, no background threads, no segfault risk. + + 2. HTTP-based (production, non-embedded stores): + Spans export to a remote MLflow tracking server. + Enables cross-process trace stitching (Tier 3). + +Auto-detection logic: + - If online_store uses embedded milvus-lite (has ``path`` but no ``host``) + AND tracking_uri is HTTP → force file-based mode with a warning. + - Otherwise, use tracking_uri from config. 
""" from __future__ import annotations @@ -23,38 +31,102 @@ _initialized = False _enabled = False +_tracer: Any = None + + +def _is_embedded_store(store: "FeatureStore") -> bool: + """Detect if the online store is an embedded C++ runtime (milvus-lite). + + Milvus-lite embeds a C++ runtime in-process. Combined with MLflow's HTTP + export (background threads + network I/O) in a gunicorn-forked worker, + this causes segfaults. We detect this case so we can disable HTTP export. + """ + online_cfg = store.config.online_config + if online_cfg is None: + return False + store_type = "" + if isinstance(online_cfg, dict): + store_type = online_cfg.get("type", "") + has_path = bool(online_cfg.get("path")) + has_host = bool(online_cfg.get("host")) + else: + store_type = getattr(online_cfg, "type", "") or "" + has_path = bool(getattr(online_cfg, "path", None)) + has_host = bool(getattr(online_cfg, "host", None)) + + if "milvus" not in store_type.lower(): + return False + return has_path and not has_host def _lazy_init(store: "FeatureStore") -> bool: - """Initialize tracing on first use (post-fork safe). + """Initialize OTEL TracerProvider on first use (post-fork safe). - Called from ``traced_tool_span`` the first time a traced endpoint - is hit. At this point the code is running inside the gunicorn - worker, so importing MLflow and starting spans is safe. + Called from ``traced_tool_span`` the first time a traced endpoint is + hit. At this point the code runs inside the gunicorn worker, so + importing heavy libraries and setting up providers is safe. 
""" - global _initialized, _enabled + global _initialized, _enabled, _tracer if _initialized: return _enabled _initialized = True - mlflow_cfg = getattr(store.config, "mlflow", None) + mlflow_cfg = store.config.mlflow if mlflow_cfg is None or not mlflow_cfg.enabled: return False - if not getattr(mlflow_cfg, "enable_tracing", True): + if not mlflow_cfg.enable_tracing: + return False + + try: + from opentelemetry import trace + from opentelemetry.sdk.trace import TracerProvider + from opentelemetry.sdk.trace.export import SimpleSpanProcessor + except ImportError: + _logger.debug("opentelemetry-sdk not installed; tracing disabled") return False + tracking_uri = mlflow_cfg.get_tracking_uri() + is_embedded = _is_embedded_store(store) + is_http_uri = bool(tracking_uri and tracking_uri.startswith("http")) + + if is_embedded and is_http_uri: + _logger.warning( + "Embedded online store detected (milvus-lite). " + "Forcing file-based tracing to avoid segfault. " + "Use a remote Milvus cluster (online_store.host) for HTTP tracing." 
+        )
+        tracking_uri = None
+
     try:
         import mlflow
 
-        if hasattr(mlflow, "start_span"):
-            _enabled = True
-            _logger.info("Feast tracing initialized (MLflow native, post-fork)")
-            return True
+        if tracking_uri:
+            mlflow.set_tracking_uri(tracking_uri)
+        mlflow.set_experiment(store.config.project)
+    except Exception as e:
+        _logger.warning("Failed to configure MLflow experiment: %s", e)
+
+    try:
+        from mlflow.tracing.export import MlflowV3SpanExporter as Exporter
     except ImportError:
-        pass
+        try:
+            from mlflow.tracing.export import MlflowSpanExporter as Exporter
+        except ImportError:
+            _logger.debug("MLflow span exporter not available; tracing disabled")
+            return False
+
+    provider = TracerProvider()
+    provider.add_span_processor(SimpleSpanProcessor(Exporter()))
+    trace.set_tracer_provider(provider)
+    _tracer = trace.get_tracer("feast.server")
+    _enabled = True
 
-    return False
+    mode = "HTTP" if is_http_uri and not is_embedded else "file"
+    _logger.info(
+        "Feast tracing initialized (mode=%s, project=%s)", mode, store.config.project
+    )
+    return True
 
 
 @contextmanager
@@ -67,18 +139,39 @@ def traced_tool_span(
 
     Lazily initializes tracing on first use. If tracing is disabled or
     unavailable the body runs with ``span=None``.
+
+    The span is created within the current OTEL context, so if
+    FastAPI OTEL instrumentation has set a parent span (from an
+    incoming ``traceparent`` header), this span becomes a child of
+    that parent — enabling cross-process trace linking.
+
+    Only failures while *creating* the span are swallowed; an exception
+    raised by the caller's ``with`` body propagates unchanged after the
+    span has been ended.
+    """
     if not _lazy_init(store):
         yield None
         return
 
     try:
-        import mlflow
-
-        with mlflow.start_span(name=name) as span:
-            if attributes:
-                span.set_attributes(attributes)
-            yield span
+        from opentelemetry import context, trace
+
+        current_ctx = context.get_current()
+        span = _tracer.start_span(name, context=current_ctx)
+        if attributes:
+            for k, v in attributes.items():
+                span.set_attribute(k, v)
+        token = context.attach(trace.set_span_in_context(span))
     except Exception as exc:
         _logger.debug("Traced tool span failed: %s", exc)
         yield None
+        return
+
+    # Yield outside the setup ``try``: a @contextmanager generator may yield
+    # only once, so the caller's exception must never reach the handler above
+    # (a second ``yield`` there raises "generator didn't stop after throw()").
+    try:
+        yield span
+    finally:
+        span.end()
+        context.detach(token)
diff --git a/sdk/python/tests/integration/test_tracing_integration.py b/sdk/python/tests/integration/test_tracing_integration.py
index 20a86072b95..afec6e89757 100644
--- a/sdk/python/tests/integration/test_tracing_integration.py
+++ b/sdk/python/tests/integration/test_tracing_integration.py
@@ -32,7 +32,9 @@
 try:
     from opentelemetry import trace as otel_trace  # noqa: F401
     from opentelemetry.sdk.trace import TracerProvider  # noqa: F401
-    from opentelemetry.sdk.trace.export.in_memory import InMemorySpanExporter  # type: ignore[import-untyped]  # noqa: F401
+    from opentelemetry.sdk.trace.export.in_memory import (
+        InMemorySpanExporter,  # type: ignore[import-untyped]  # noqa: F401
+    )
 
     HAS_OTEL = True
 except ImportError:
@@ -57,9 +59,13 @@ def _isolate_globals():
 
     import feast.tracing
 
-    feast.tracing._tracer_initialized = False
+    feast.tracing._initialized = False
+    feast.tracing._enabled = False
+    feast.tracing._tracer = None
     yield
-    feast.tracing._tracer_initialized = False
+    feast.tracing._initialized = False
+    feast.tracing._enabled = False
+    feast.tracing._tracer = None
 
 
 @pytest.fixture()
@@ -165,13 +171,22 @@ def store_with_tracing(driver_parquet, tracking_uri, feast_objects):
 
 
 class TestTracingInitialization:
-    def test_tracer_is_set_when_enabled(self, store_with_tracing):
-        assert store_with_tracing._tracer is not None
+    def test_lazy_init_enables_when_configured(self,
store_with_tracing): + import feast.tracing + + result = feast.tracing._lazy_init(store_with_tracing) + assert result is True + assert feast.tracing._enabled is True + + def test_lazy_init_disabled_when_tracing_off( + self, driver_parquet, tracking_uri, feast_objects + ): + import feast.tracing - def test_tracer_is_none_when_disabled(self, driver_parquet, tracking_uri, feast_objects): tmp_path, _ = driver_parquet store = _make_store(tmp_path, tracking_uri, enable_tracing=False) - assert store._tracer is None + result = feast.tracing._lazy_init(store) + assert result is False class TestOnlineFeatureTracing: diff --git a/sdk/python/tests/unit/test_tracing.py b/sdk/python/tests/unit/test_tracing.py index de96d7af2d0..fb07799b7de 100644 --- a/sdk/python/tests/unit/test_tracing.py +++ b/sdk/python/tests/unit/test_tracing.py @@ -19,9 +19,6 @@ from types import SimpleNamespace from unittest.mock import MagicMock, patch -import pytest - - # --------------------------------------------------------------------------- # tracing_context tests # --------------------------------------------------------------------------- @@ -199,9 +196,7 @@ def test_install_graceful_on_missing_mlflow(self): try: import mlflow.tracing - with patch.object( - mlflow.tracing, "configure", side_effect=AttributeError - ): + with patch.object(mlflow.tracing, "configure", side_effect=AttributeError): install_feast_span_processor() except (ImportError, AttributeError): install_feast_span_processor() @@ -212,139 +207,104 @@ def test_install_graceful_on_missing_mlflow(self): # --------------------------------------------------------------------------- -class TestInitTracing: +class TestLazyInit: def setup_method(self): import feast.tracing - feast.tracing._tracer_initialized = False - - def test_returns_none_when_otel_missing(self): - with patch.dict("sys.modules", {"opentelemetry": None}): - import importlib - - import feast.tracing - - importlib.reload(feast.tracing) - feast.tracing._tracer_initialized = 
False - result = feast.tracing.init_tracing(MagicMock()) - assert result is None + feast.tracing._initialized = False + feast.tracing._enabled = False + feast.tracing._tracer = None - def test_returns_tracer_with_mocked_deps(self): - mock_trace = MagicMock() - mock_tracer = MagicMock() - mock_trace.get_tracer.return_value = mock_tracer - mock_provider_cls = MagicMock() - mock_processor_cls = MagicMock() - mock_exporter_cls = MagicMock() + def test_disabled_when_no_mlflow_config(self): + import feast.tracing - modules = { - "opentelemetry": MagicMock(trace=mock_trace), - "opentelemetry.trace": mock_trace, - "opentelemetry.sdk": MagicMock(), - "opentelemetry.sdk.trace": MagicMock(TracerProvider=mock_provider_cls), - "opentelemetry.sdk.trace.export": MagicMock(SimpleSpanProcessor=mock_processor_cls), - "mlflow": MagicMock(), - "mlflow.tracing": MagicMock(), - "mlflow.tracing.export": MagicMock(MlflowSpanExporter=mock_exporter_cls), - } + feast.tracing._initialized = False + store = MagicMock() + store.config.mlflow = None + assert feast.tracing._lazy_init(store) is False - with patch.dict("sys.modules", modules): - import importlib + def test_disabled_when_mlflow_not_enabled(self): + import feast.tracing - import feast.tracing + feast.tracing._initialized = False + store = MagicMock() + store.config.mlflow.enabled = False + assert feast.tracing._lazy_init(store) is False - importlib.reload(feast.tracing) - feast.tracing._tracer_initialized = False + def test_disabled_when_enable_tracing_false(self): + import feast.tracing - result = feast.tracing.init_tracing(MagicMock()) - assert result is mock_tracer - mock_trace.set_tracer_provider.assert_called_once() + feast.tracing._initialized = False + store = MagicMock() + store.config.mlflow.enabled = True + store.config.mlflow.enable_tracing = False + assert feast.tracing._lazy_init(store) is False def test_idempotent(self): - mock_trace = MagicMock() - mock_tracer = MagicMock() - mock_trace.get_tracer.return_value = 
mock_tracer - - modules = { - "opentelemetry": MagicMock(trace=mock_trace), - "opentelemetry.trace": mock_trace, - "opentelemetry.sdk": MagicMock(), - "opentelemetry.sdk.trace": MagicMock(TracerProvider=MagicMock()), - "opentelemetry.sdk.trace.export": MagicMock(SimpleSpanProcessor=MagicMock()), - "mlflow": MagicMock(), - "mlflow.tracing": MagicMock(), - "mlflow.tracing.export": MagicMock(MlflowSpanExporter=MagicMock()), - } - - with patch.dict("sys.modules", modules): - import importlib - - import feast.tracing + import feast.tracing - importlib.reload(feast.tracing) - feast.tracing._tracer_initialized = False + feast.tracing._initialized = False + store = MagicMock() + store.config.mlflow = None + feast.tracing._lazy_init(store) + feast.tracing._lazy_init(store) + # Should only set _initialized once (no crash on second call) + assert feast.tracing._initialized is True - feast.tracing.init_tracing(MagicMock()) - feast.tracing.init_tracing(MagicMock()) - mock_trace.set_tracer_provider.assert_called_once() +class TestIsEmbeddedStore: + def test_milvus_with_path_no_host(self): + from feast.tracing import _is_embedded_store -class TestTracedDecorator: - def test_noop_without_tracer(self): - from feast.tracing import traced + store = MagicMock() + store.config.online_config = SimpleNamespace( + type="milvus", path="data/online.db", host=None + ) + assert _is_embedded_store(store) is True - @traced("test.span") - def my_method(self_): - return 42 + def test_milvus_with_host(self): + from feast.tracing import _is_embedded_store - obj = SimpleNamespace() - assert my_method(obj) == 42 + store = MagicMock() + store.config.online_config = SimpleNamespace( + type="milvus", path=None, host="localhost" + ) + assert _is_embedded_store(store) is False - def test_creates_span_with_tracer(self): - from feast.tracing import traced + def test_sqlite_not_embedded(self): + from feast.tracing import _is_embedded_store - mock_tracer = MagicMock() - mock_span = MagicMock() - 
mock_tracer.start_as_current_span.return_value.__enter__ = MagicMock( - return_value=mock_span - ) - mock_tracer.start_as_current_span.return_value.__exit__ = MagicMock( - return_value=False + store = MagicMock() + store.config.online_config = SimpleNamespace( + type="sqlite", path="data/online.db" ) + assert _is_embedded_store(store) is False - @traced("test.span") - def my_method(self_): - return 99 + def test_none_config(self): + from feast.tracing import _is_embedded_store - obj = SimpleNamespace(_tracer=mock_tracer, project="test_project") - result = my_method(obj) - assert result == 99 - mock_tracer.start_as_current_span.assert_called_once_with("test.span") - mock_span.set_attribute.assert_called_with("feast.project", "test_project") + store = MagicMock() + store.config.online_config = None + assert _is_embedded_store(store) is False - def test_records_exception(self): - from feast.tracing import traced + def test_dict_config(self): + from feast.tracing import _is_embedded_store - mock_tracer = MagicMock() - mock_span = MagicMock() - mock_tracer.start_as_current_span.return_value.__enter__ = MagicMock( - return_value=mock_span - ) - mock_tracer.start_as_current_span.return_value.__exit__ = MagicMock( - return_value=False - ) + store = MagicMock() + store.config.online_config = {"type": "milvus", "path": "data/db", "host": ""} + assert _is_embedded_store(store) is True - @traced("test.span") - def bad_method(self_): - raise ValueError("boom") - obj = SimpleNamespace(_tracer=mock_tracer, project="p") +class TestTracedToolSpan: + def test_noop_when_disabled(self): + import feast.tracing + + feast.tracing._initialized = False + store = MagicMock() + store.config.mlflow = None - with patch("feast.tracing._StatusCode_ERROR") as mock_status: - with pytest.raises(ValueError, match="boom"): - bad_method(obj) - mock_span.record_exception.assert_called_once() - mock_span.set_status.assert_called_once() + with feast.tracing.traced_tool_span(store, "test.span") as span: + 
assert span is None # ---------------------------------------------------------------------------