diff --git a/databricks_mcp/tests/integration_tests/conftest.py b/databricks_mcp/tests/integration_tests/conftest.py index bf654a00..ad2a3ebb 100644 --- a/databricks_mcp/tests/integration_tests/conftest.py +++ b/databricks_mcp/tests/integration_tests/conftest.py @@ -180,6 +180,35 @@ def cached_vs_call_result(vs_mcp_client, cached_vs_tools_list): return vs_mcp_client.call_tool(tool.name, {param_name: "test"}) +# ============================================================================= +# DBSQL Fixtures +# ============================================================================= + + +@pytest.fixture(scope="session") +def dbsql_mcp_url(workspace_client): + """Construct MCP URL for the DBSQL server.""" + base_url = workspace_client.config.host + return f"{base_url}/api/2.0/mcp/sql" + + +@pytest.fixture(scope="session") +def dbsql_mcp_client(dbsql_mcp_url, workspace_client): + """DatabricksMCPClient pointed at the DBSQL server.""" + return DatabricksMCPClient(dbsql_mcp_url, workspace_client) + + +@pytest.fixture(scope="session") +def cached_dbsql_tools_list(dbsql_mcp_client): + """Cache the DBSQL list_tools() result; skip if DBSQL MCP endpoint unavailable.""" + try: + tools = dbsql_mcp_client.list_tools() + except ExceptionGroup as e: # ty: ignore[unresolved-reference] + _skip_if_not_found(e, "DBSQL MCP endpoint not available in workspace") + assert tools, "DBSQL list_tools() returned no tools" + return tools + + # ============================================================================= # Genie Fixtures # ============================================================================= diff --git a/databricks_mcp/tests/integration_tests/test_mcp_core.py b/databricks_mcp/tests/integration_tests/test_mcp_core.py index c6cbf83a..58303502 100644 --- a/databricks_mcp/tests/integration_tests/test_mcp_core.py +++ b/databricks_mcp/tests/integration_tests/test_mcp_core.py @@ -126,6 +126,170 @@ def test_call_tool_returns_result_with_content(self, cached_genie_call_result): assert len(cached_genie_call_result.content) > 0 +# ============================================================================= +# DBSQL +# ============================================================================= + + +@pytest.mark.integration +class TestMCPClientDBSQL: + """Verify list_tools() and call_tool() against a live DBSQL MCP server.""" + + def test_list_tools_returns_expected_tools(self, cached_dbsql_tools_list): + tool_names = [t.name for t in cached_dbsql_tools_list] + for expected in ["execute_sql", "execute_sql_read_only", "poll_sql_result"]: + assert expected in tool_names, f"Expected tool '{expected}' not found in {tool_names}" + + def test_call_tool_execute_sql_read_only(self, dbsql_mcp_client, cached_dbsql_tools_list): + """execute_sql_read_only with SHOW CATALOGS should return results.""" + result = dbsql_mcp_client.call_tool("execute_sql_read_only", {"query": "SHOW CATALOGS"}) + assert isinstance(result, CallToolResult) + assert result.content, "SHOW CATALOGS should return content" + assert len(result.content) > 0 + + +# ============================================================================= +# Raw streamable_http_client +# ============================================================================= + + +@pytest.mark.integration +class TestRawStreamableHttpClient: + """Verify DatabricksOAuthClientProvider works with the raw MCP SDK streamable_http_client. + + This tests the low-level path: httpx.AsyncClient + DatabricksOAuthClientProvider + + streamable_http_client + ClientSession, without going through DatabricksMCPClient. + """ + + @pytest.mark.asyncio + async def test_uc_function_list_and_call(self, uc_function_url, workspace_client): + """list_tools + call_tool via raw streamable_http_client for UC functions.""" + import httpx + from mcp import ClientSession + from mcp.client.streamable_http import streamable_http_client + + from databricks_mcp import DatabricksOAuthClientProvider + + async with httpx.AsyncClient( + auth=DatabricksOAuthClientProvider(workspace_client), + follow_redirects=True, + timeout=httpx.Timeout(120.0, read=120.0), + ) as http_client: + async with streamable_http_client(uc_function_url, http_client=http_client) as ( + read_stream, + write_stream, + _, + ): + async with ClientSession(read_stream, write_stream) as session: + await session.initialize() + + # list_tools + tools_response = await session.list_tools() + tools = tools_response.tools + assert len(tools) > 0 + tool_names = [t.name for t in tools] + assert any("echo_message" in name for name in tool_names) + + # call_tool + tool_name = next(n for n in tool_names if "echo_message" in n) + result = await session.call_tool(tool_name, {"message": "raw_client_test"}) + assert result.content + assert "raw_client_test" in str(result.content[0].text) + + @pytest.mark.asyncio + async def test_vs_list_tools(self, vs_mcp_url, workspace_client): + """list_tools via raw streamable_http_client for Vector Search.""" + import httpx + from mcp import ClientSession + from mcp.client.streamable_http import streamable_http_client + + from databricks_mcp import DatabricksOAuthClientProvider + + async with httpx.AsyncClient( + auth=DatabricksOAuthClientProvider(workspace_client), + follow_redirects=True, + timeout=httpx.Timeout(120.0, read=120.0), + ) as http_client: + async with streamable_http_client(vs_mcp_url, http_client=http_client) as ( + read_stream, + write_stream, + _, + ): + async with ClientSession(read_stream, write_stream) as session: + await session.initialize() + tools_response = await session.list_tools() + assert len(tools_response.tools) > 0 + + @pytest.mark.asyncio + async def test_dbsql_list_and_call(self, dbsql_mcp_url, workspace_client): + """list_tools + call_tool via raw streamable_http_client for DBSQL.""" + import httpx + from mcp import ClientSession + from mcp.client.streamable_http import streamable_http_client + + from databricks_mcp import DatabricksOAuthClientProvider + + async with httpx.AsyncClient( + auth=DatabricksOAuthClientProvider(workspace_client), + follow_redirects=True, + timeout=httpx.Timeout(120.0, read=120.0), + ) as http_client: + async with streamable_http_client(dbsql_mcp_url, http_client=http_client) as ( + read_stream, + write_stream, + _, + ): + async with ClientSession(read_stream, write_stream) as session: + await session.initialize() + + tools_response = await session.list_tools() + tools = tools_response.tools + tool_names = [t.name for t in tools] + assert "execute_sql_read_only" in tool_names + + result = await session.call_tool( + "execute_sql_read_only", {"query": "SHOW CATALOGS"} + ) + assert result.content + assert len(result.content) > 0 + + @pytest.mark.asyncio + async def test_genie_list_and_call(self, genie_mcp_url, workspace_client): + """list_tools + call_tool via raw streamable_http_client for Genie.""" + import httpx + from mcp import ClientSession + from mcp.client.streamable_http import streamable_http_client + + from databricks_mcp import DatabricksOAuthClientProvider + + async with httpx.AsyncClient( + auth=DatabricksOAuthClientProvider(workspace_client), + follow_redirects=True, + timeout=httpx.Timeout(120.0, read=120.0), + ) as http_client: + async with streamable_http_client(genie_mcp_url, http_client=http_client) as ( + read_stream, + write_stream, + _, + ): + async with ClientSession(read_stream, write_stream) as session: + await session.initialize() + + tools_response = await session.list_tools() + tools = tools_response.tools + assert len(tools) > 0 + + # Call the first tool (query_space_*) + tool = tools[0] + properties = tool.inputSchema.get("properties", {}) + param_name = next(iter(properties), "query") + result = await session.call_tool( + tool.name, {param_name: "How many rows are there?"} + ) + assert result.content + assert len(result.content) > 0 + + # ============================================================================= # Error paths # ============================================================================= diff --git a/integrations/langchain/tests/integration_tests/test_chat_models.py b/integrations/langchain/tests/integration_tests/test_chat_models.py index 9102a26b..cd1ce3f6 100644 --- a/integrations/langchain/tests/integration_tests/test_chat_models.py +++ b/integrations/langchain/tests/integration_tests/test_chat_models.py @@ -27,7 +27,7 @@ from langgraph.checkpoint.memory import MemorySaver from langgraph.graph import START, StateGraph from langgraph.graph.message import add_messages -from langgraph.prebuilt import ToolNode, create_react_agent, tools_condition +from langgraph.prebuilt import ToolNode, tools_condition from pydantic import BaseModel, Field from typing_extensions import TypedDict @@ -58,7 +58,7 @@ def test_chat_databricks_invoke(model): response = chat.invoke("How to learn Java? Start the response by 'To learn Java,'") assert isinstance(response, AIMessage) assert response.content == "To learn " - assert 20 <= response.response_metadata["prompt_tokens"] <= 30 + assert 15 <= response.response_metadata["prompt_tokens"] <= 60 assert 1 <= response.response_metadata["completion_tokens"] <= 10 expected_total = ( response.response_metadata["prompt_tokens"] @@ -94,6 +94,8 @@ def test_chat_databricks_invoke(model): @pytest.mark.foundation_models @pytest.mark.parametrize("model", _FOUNDATION_MODELS) def test_chat_databricks_invoke_multiple_completions(model): + if "claude" in model: + pytest.skip("Anthropic does not support n > 1") chat = ChatDatabricks( model=model, temperature=0.5, @@ -129,8 +131,15 @@ def on_llm_new_token(self, *args, **kwargs): assert all("Python" not in chunk.content for chunk in chunks) assert callback.chunk_counts == len(chunks) - last_chunk = chunks[-1] - assert last_chunk.response_metadata["finish_reason"] == "stop" + # finish_reason may be on the last content chunk, not necessarily chunks[-1] + # (a usage-only chunk may follow when stream_options is enabled) + finish_reasons = [ + chunk.response_metadata.get("finish_reason") + for chunk in chunks + if chunk.response_metadata.get("finish_reason") + ] + assert len(finish_reasons) >= 1, "Expected at least one chunk with finish_reason" + assert finish_reasons[-1] in ("stop", "end_turn") @pytest.mark.foundation_models @@ -159,12 +168,23 @@ def on_llm_new_token(self, *args, **kwargs): assert all("Python" not in chunk.content for chunk in chunks) assert callback.chunk_counts == len(chunks) - last_chunk = chunks[-1] - assert last_chunk.response_metadata["finish_reason"] == "stop" - assert last_chunk.usage_metadata is not None - assert last_chunk.usage_metadata["input_tokens"] > 0 - assert last_chunk.usage_metadata["output_tokens"] > 0 - assert last_chunk.usage_metadata["total_tokens"] > 0 + # finish_reason may be on the last content chunk, not necessarily chunks[-1] + # (a usage-only chunk may follow when stream_options is enabled) + finish_reasons = [ + chunk.response_metadata.get("finish_reason") + for chunk in chunks + if chunk.response_metadata.get("finish_reason") + ] + assert len(finish_reasons) >= 1, "Expected at least one chunk with finish_reason" + assert finish_reasons[-1] in ("stop", "end_turn") + + # Usage may not be on the last chunk — find chunks that have it + usage_chunks = [c for c in chunks if c.usage_metadata is not None] + assert len(usage_chunks) >= 1, "Expected at least one chunk with usage_metadata" + usage = usage_chunks[-1].usage_metadata + assert usage["input_tokens"] > 0 + assert usage["output_tokens"] > 0 + assert usage["total_tokens"] > 0 @pytest.mark.asyncio @@ -368,6 +388,8 @@ def test_chat_databricks_with_structured_output(model, schema, method): if schema is None and method == "function_calling": pytest.skip("Cannot use function_calling without schema") + if method == "json_mode" and "claude" in model: + pytest.skip("Anthropic does not support json_object response format") structured_llm = llm.with_structured_output(schema, method=method) @@ -432,21 +454,6 @@ def multiply(a: int, b: int) -> int: return a * b -@pytest.mark.foundation_models -@pytest.mark.parametrize("model", _FOUNDATION_MODELS) -def test_chat_databricks_langgraph(model): - model = ChatDatabricks( - model=model, - temperature=0, - max_tokens=100, - ) - tools = [add, multiply] - - app = create_react_agent(model, tools) - response = app.invoke({"messages": [("human", "What is (10 + 5) * 3?")]}) - assert "45" in response["messages"][-1].content - - @pytest.mark.foundation_models @pytest.mark.parametrize("model", _FOUNDATION_MODELS) def test_chat_databricks_langgraph_with_memory(model): @@ -495,7 +502,11 @@ def chatbot(state: State): config={"configurable": {"thread_id": "1"}}, ) - assert "40" in response["messages"][-1].content + # The LLM should reference the result of subtracting 5 from 45 + final = response["messages"][-1].content + assert any(x in final for x in ["40", "subtract", "minus"]), ( + f"Expected reference to subtraction result in: {final[:200]}" + ) @pytest.mark.st_endpoints @@ -751,48 +762,6 @@ def test_chat_databricks_utf8_encoding(model): assert "blåbær" in full_content.lower() -def test_chat_databricks_with_timeout_and_retries(): - """Test that ChatDatabricks can be initialized with timeout and max_retries parameters.""" - from unittest.mock import Mock, patch - - # Mock the OpenAI client - mock_openai_client = Mock() - mock_workspace_client = Mock() - mock_workspace_client.serving_endpoints.get_open_ai_client.return_value = mock_openai_client - - with patch("databricks.sdk.WorkspaceClient", return_value=mock_workspace_client): - # Create ChatDatabricks with timeout and max_retries - chat = ChatDatabricks( - model="databricks-meta-llama-3-3-70b-instruct", timeout=45.0, max_retries=3 - ) - - # Verify the parameters are set correctly - assert chat.timeout == 45.0 - assert chat.max_retries == 3 - - # Verify the client was configured with these parameters - assert chat.client == mock_openai_client - - # Test with workspace_client parameter - with patch( - "databricks_langchain.chat_models.get_openai_client", return_value=mock_openai_client - ) as mock_get_client: - chat_with_ws = ChatDatabricks( - model="databricks-meta-llama-3-3-70b-instruct", - workspace_client=mock_workspace_client, - timeout=30.0, - max_retries=2, - ) - - # Verify get_openai_client was called with all parameters - mock_get_client.assert_called_once_with( - workspace_client=mock_workspace_client, timeout=30.0, max_retries=2 - ) - - assert chat_with_ws.timeout == 30.0 - assert chat_with_ws.max_retries == 2 - - def test_chat_databricks_with_gpt_oss(): """ API ref: https://docs.databricks.com/aws/en/machine-learning/foundation-model-apis/api-reference#contentitem @@ -804,6 +773,10 @@ def test_chat_databricks_with_gpt_oss(): assert isinstance(response.content, str) +@pytest.mark.skipif( + os.environ.get("RUN_DOGFOOD_TESTS", "").lower() != "true", + reason="Requires dogfood workspace. Set RUN_DOGFOOD_TESTS=true to run.", +) def test_chat_databricks_custom_outputs(): llm = ChatDatabricks(model="agents_ml-bbqiu-codegen", use_responses_api=True) response = llm.invoke( @@ -813,6 +786,10 @@ def test_chat_databricks_custom_outputs(): assert response.custom_outputs["key"] == "value" # type: ignore[attr-defined] +@pytest.mark.skipif( + os.environ.get("RUN_DOGFOOD_TESTS", "").lower() != "true", + reason="Requires dogfood workspace. Set RUN_DOGFOOD_TESTS=true to run.", +) def test_chat_databricks_custom_outputs_stream(): llm = ChatDatabricks(model="agents_ml-bbqiu-mcp-openai", use_responses_api=True) response = llm.stream( @@ -824,10 +801,6 @@ def test_chat_databricks_custom_outputs_stream(): def test_chat_databricks_token_count(): - import mlflow - - mlflow.set_experiment("4435237072766312") - mlflow.langchain.autolog() llm = ChatDatabricks(model="databricks-gpt-oss-120b") response = llm.invoke("What is the 100th fibonacci number?") assert response.content is not None @@ -840,16 +813,15 @@ def test_chat_databricks_token_count(): + response.response_metadata["completion_tokens"] ) + # Usage may not be on the last chunk — find chunks that have it chunks = list(llm.stream("What is the 100th fibonacci number?")) - last_chunk = chunks[-1] - assert last_chunk.usage_metadata is not None - assert last_chunk.usage_metadata["input_tokens"] > 0 - assert last_chunk.usage_metadata["output_tokens"] > 0 - assert last_chunk.usage_metadata["total_tokens"] > 0 - assert ( - last_chunk.usage_metadata["total_tokens"] - == last_chunk.usage_metadata["input_tokens"] + last_chunk.usage_metadata["output_tokens"] - ) + usage_chunks = [c for c in chunks if c.usage_metadata is not None] + assert len(usage_chunks) >= 1, "Expected at least one chunk with usage_metadata" + usage = usage_chunks[-1].usage_metadata + assert usage["input_tokens"] > 0 + assert usage["output_tokens"] > 0 + assert usage["total_tokens"] > 0 + assert usage["total_tokens"] == usage["input_tokens"] + usage["output_tokens"] def test_chat_databricks_gpt5_stream_with_usage(): @@ -874,14 +846,8 @@ def test_chat_databricks_gpt5_stream_with_usage(): ) ) """ - from databricks.sdk import WorkspaceClient - - # Use dogfood profile to access GPT-5 - workspace_client = WorkspaceClient(profile=DATABRICKS_CLI_PROFILE) - llm = ChatDatabricks( - endpoint="gpt-5", - workspace_client=workspace_client, + endpoint="databricks-gpt-5", max_tokens=100, stream_usage=True, ) @@ -1030,7 +996,7 @@ def _verify_responses_usage_metadata_keys(lc_usage, openai_usage): if openai_usage.output_tokens_details is not None: assert "output_token_details" in lc_usage if openai_usage.output_tokens_details.reasoning_tokens is not None: - assert "reasoning_tokens" in lc_usage["output_token_details"] + assert "reasoning" in lc_usage["output_token_details"] @pytest.mark.foundation_models diff --git a/integrations/langchain/tests/unit_tests/test_chat_models.py b/integrations/langchain/tests/unit_tests/test_chat_models.py index dc7aa81e..a7a29d0d 100644 --- a/integrations/langchain/tests/unit_tests/test_chat_models.py +++ b/integrations/langchain/tests/unit_tests/test_chat_models.py @@ -2071,3 +2071,21 @@ def test_chat_databricks_responses_api_invoke_returns_usage_metadata(): assert usage_metadata["total_tokens"] == 150 assert usage_metadata["input_token_details"]["cache_read"] == 25 assert usage_metadata["output_token_details"]["reasoning"] == 10 + + +def test_chat_databricks_with_timeout_and_retries(): + """Test that ChatDatabricks can be initialized with timeout and max_retries parameters.""" + mock_openai_client = Mock() + + with patch( + "databricks_langchain.chat_models.get_openai_client", return_value=mock_openai_client + ) as mock_get_client: + chat = ChatDatabricks( + model="databricks-meta-llama-3-3-70b-instruct", timeout=45.0, max_retries=3 + ) + assert chat.timeout == 45.0 + assert chat.max_retries == 3 + assert chat.client == mock_openai_client + mock_get_client.assert_called_once_with( + workspace_client=None, timeout=45.0, max_retries=3 + ) diff --git a/pyproject.toml b/pyproject.toml index 2a9be2b5..4bd3fa3e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -125,3 +125,9 @@ root = ["./src", "./tests"] [tool.ty.src] include = ["./src", "./tests"] +exclude = [ + # Fixture code and deploy scripts — not SDK code, uses MLflow/agents types that don't resolve + "./tests/integration_tests/obo/model_serving_fixture", + "./tests/integration_tests/obo/app_fixture", + "./tests/integration_tests/obo/deploy_serving_agent.py", +] diff --git a/src/databricks_ai_bridge/model_serving_obo_credential_strategy.py b/src/databricks_ai_bridge/model_serving_obo_credential_strategy.py index d5ed73e4..0772b0de 100644 --- a/src/databricks_ai_bridge/model_serving_obo_credential_strategy.py +++ b/src/databricks_ai_bridge/model_serving_obo_credential_strategy.py @@ -159,7 +159,7 @@ class ModelServingUserCredentials(CredentialsStrategy): In any other environments, the class defaults to the DefaultCredentialStrategy. To use this credential strategy, instantiate the WorkspaceClient with the ModelServingUserCredentials strategy as follows: - user_client = WorkspaceClient(credential_strategy = ModelServingUserCredentials()) + user_client = WorkspaceClient(credentials_strategy=ModelServingUserCredentials()) """ def __init__(self): diff --git a/src/databricks_ai_bridge/test_utils/fmapi.py b/src/databricks_ai_bridge/test_utils/fmapi.py index cca1e513..24bc3aaf 100644 --- a/src/databricks_ai_bridge/test_utils/fmapi.py +++ b/src/databricks_ai_bridge/test_utils/fmapi.py @@ -32,6 +32,8 @@ "databricks-gpt-5-1-codex-mini", # Responses API only, no Chat Completions support "databricks-gpt-5-2-codex", # Responses API only, no Chat Completions support "databricks-gpt-5-3-codex", # Responses API only, no Chat Completions support + "databricks-gpt-5-4", # Requires /v1/responses for tool calling, not /v1/chat/completions + "databricks-gemini-3-1-flash-lite", # Requires thought_signature on function calls } # Additional models skipped only in LangChain tests diff --git a/tests/integration_tests/conftest.py b/tests/integration_tests/conftest.py index 24bdb875..0afb21f8 100644 --- a/tests/integration_tests/conftest.py +++ b/tests/integration_tests/conftest.py @@ -66,3 +66,4 @@ def pytest_configure(config): "markers", "behavior: mark test as behavior validation (search results)" ) config.addinivalue_line("markers", "slow: mark test as slow (may take > 30 seconds)") + config.addinivalue_line("markers", "obo: mark test as OBO credential flow test") diff --git a/tests/integration_tests/obo/app_fixture/agent_server/__init__.py b/tests/integration_tests/obo/app_fixture/agent_server/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/integration_tests/obo/app_fixture/agent_server/agent.py b/tests/integration_tests/obo/app_fixture/agent_server/agent.py new file mode 100644 index 00000000..2b0b9cac --- /dev/null +++ b/tests/integration_tests/obo/app_fixture/agent_server/agent.py @@ -0,0 +1,72 @@ +from typing import AsyncGenerator + +import mlflow +from agents import Agent, Runner, function_tool, set_default_openai_api, set_default_openai_client +from agents.tracing import set_trace_processors +from databricks_openai import AsyncDatabricksOpenAI +from mlflow.genai.agent_server import invoke, stream +from mlflow.types.responses import ( + ResponsesAgentRequest, + ResponsesAgentResponse, + ResponsesAgentStreamEvent, +) + +from agent_server.utils import ( + get_user_workspace_client, + process_agent_stream_events, +) + +set_default_openai_client(AsyncDatabricksOpenAI()) +set_default_openai_api("chat_completions") +set_trace_processors([]) +mlflow.openai.autolog() + +NAME = "agent-obo-test" +SYSTEM_PROMPT = ( + "You are a helpful assistant. When asked who the user is, " + "call the whoami tool and return the raw result." +) +MODEL = "databricks-claude-sonnet-4-6" + + +def _make_whoami_tool(user_wc): + """Create a whoami tool that uses the given workspace client.""" + + @function_tool + def whoami() -> str: + """Returns the identity of the current user.""" + me = user_wc.current_user.me() + return me.user_name + + return whoami + + +def create_agent(tools) -> Agent: + return Agent( + name=NAME, + instructions=SYSTEM_PROMPT, + model=MODEL, + tools=tools, + ) + + +@invoke() +async def invoke(request: ResponsesAgentRequest) -> ResponsesAgentResponse: + user_wc = get_user_workspace_client() + whoami_tool = _make_whoami_tool(user_wc) + agent = create_agent([whoami_tool]) + messages = [i.model_dump() for i in request.input] + result = await Runner.run(agent, messages) + return ResponsesAgentResponse(output=[item.to_input_item() for item in result.new_items]) + + +@stream() +async def stream(request: dict) -> AsyncGenerator[ResponsesAgentStreamEvent, None]: + user_wc = get_user_workspace_client() + whoami_tool = _make_whoami_tool(user_wc) + agent = create_agent([whoami_tool]) + messages = [i.model_dump() for i in request.input] + result = Runner.run_streamed(agent, input=messages) + + async for event in process_agent_stream_events(result.stream_events()): + yield event diff --git a/tests/integration_tests/obo/app_fixture/agent_server/start_server.py b/tests/integration_tests/obo/app_fixture/agent_server/start_server.py new file mode 100644 index 00000000..59550964 --- /dev/null +++ b/tests/integration_tests/obo/app_fixture/agent_server/start_server.py @@ -0,0 +1,18 @@ +from dotenv import load_dotenv +from mlflow.genai.agent_server import AgentServer, setup_mlflow_git_based_version_tracking + +load_dotenv(dotenv_path=".env", override=True) + +import agent_server.agent # noqa: E402, F401 + +agent_server = AgentServer("ResponsesAgent", enable_chat_proxy=True) +app = agent_server.app # noqa: F841 + +try: + setup_mlflow_git_based_version_tracking() +except Exception: + pass + + +def main(): + agent_server.run(app_import_string="agent_server.start_server:app") diff --git a/tests/integration_tests/obo/app_fixture/agent_server/utils.py b/tests/integration_tests/obo/app_fixture/agent_server/utils.py new file mode 100644 index 00000000..01c51581 --- /dev/null +++ b/tests/integration_tests/obo/app_fixture/agent_server/utils.py @@ -0,0 +1,79 @@ +import json +import logging +import os +from typing import AsyncGenerator, AsyncIterator, Optional +from uuid import uuid4 + +from agents.result import StreamEvent +from databricks.sdk import WorkspaceClient +from mlflow.genai.agent_server import get_request_headers +from mlflow.types.responses import ResponsesAgentStreamEvent + + +def get_databricks_host(workspace_client: WorkspaceClient | None = None) -> Optional[str]: + workspace_client = workspace_client or WorkspaceClient() + try: + return workspace_client.config.host + except Exception as e: + logging.exception(f"Error getting databricks host from env: {e}") + return None + + +def get_user_workspace_client() -> WorkspaceClient: + """Get a WorkspaceClient authenticated as the requesting user via OBO. + + Reads the x-forwarded-access-token header injected by the Databricks Apps + proxy when user authorization scopes are configured on the app. + Falls back to the app's default credentials if the header is absent. + """ + headers = get_request_headers() + token = headers.get("x-forwarded-access-token") + if not token: + logging.warning( + "No x-forwarded-access-token header found. " + "Ensure user authorization scopes are configured on the app. " + "Available headers: %s", + list(headers.keys()), + ) + return WorkspaceClient() + host = get_databricks_host() + # Temporarily clear app SP credentials from env to avoid + # "more than one authorization method" conflict in the SDK + old_id = os.environ.pop("DATABRICKS_CLIENT_ID", None) + old_secret = os.environ.pop("DATABRICKS_CLIENT_SECRET", None) + try: + wc = WorkspaceClient(host=host, token=token) + finally: + if old_id is not None: + os.environ["DATABRICKS_CLIENT_ID"] = old_id + if old_secret is not None: + os.environ["DATABRICKS_CLIENT_SECRET"] = old_secret + return wc + + +async def process_agent_stream_events( + async_stream: AsyncIterator[StreamEvent], +) -> AsyncGenerator[ResponsesAgentStreamEvent, None]: + curr_item_id = str(uuid4()) + async for event in async_stream: + if event.type == "raw_response_event": + event_data = event.data.model_dump() + if event_data["type"] == "response.output_item.added": + curr_item_id = str(uuid4()) + event_data["item"]["id"] = curr_item_id + elif event_data.get("item") is not None and event_data["item"].get("id") is not None: + event_data["item"]["id"] = curr_item_id + elif event_data.get("item_id") is not None: + event_data["item_id"] = curr_item_id + yield event_data + elif event.type == "run_item_stream_event" and event.item.type == "tool_call_output_item": + output = event.item.to_input_item() + if not isinstance(output.get("output"), str): + try: + output["output"] = json.dumps(output.get("output")) + except (TypeError, ValueError): + output["output"] = str(output.get("output")) + yield ResponsesAgentStreamEvent( + type="response.output_item.done", + item=output, + ) diff --git a/tests/integration_tests/obo/app_fixture/app.yaml b/tests/integration_tests/obo/app_fixture/app.yaml new file mode 100644 index 00000000..9abd4e00 --- /dev/null +++ b/tests/integration_tests/obo/app_fixture/app.yaml @@ -0,0 +1,13 @@ +command: ["uv", "run", "start-app"] + +env: + - name: MLFLOW_TRACKING_URI + value: "databricks" + - name: MLFLOW_REGISTRY_URI + value: "databricks-uc" + - name: API_PROXY + value: "http://localhost:8000/invocations" + - name: CHAT_APP_PORT + value: "3000" + - name: CHAT_PROXY_TIMEOUT_SECONDS + value: "300" diff --git a/tests/integration_tests/obo/app_fixture/pyproject.toml b/tests/integration_tests/obo/app_fixture/pyproject.toml new file mode 100644 index 00000000..d34c17cd --- /dev/null +++ b/tests/integration_tests/obo/app_fixture/pyproject.toml @@ -0,0 +1,25 @@ +[project] +name = "agent-obo-test" +version = "0.1.0" +description = "Minimal OBO test agent for databricks-ai-bridge integration tests" +requires-python = ">=3.11" +dependencies = [ + "fastapi>=0.115.12", + "uvicorn>=0.34.2", + "databricks-openai>=0.9.0", + "databricks-agents", + "mlflow>=3.9.0", + "openai-agents>=0.4.1", + "python-dotenv", +] + +[build-system] +requires = ["hatchling"] +build-backend = "hatchling.build" + +[tool.hatch.build.targets.wheel] +packages = ["agent_server", "scripts"] + +[project.scripts] +start-app = "scripts.start_app:main" +start-server = "agent_server.start_server:main" diff --git a/tests/integration_tests/obo/app_fixture/scripts/__init__.py b/tests/integration_tests/obo/app_fixture/scripts/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/integration_tests/obo/app_fixture/scripts/start_app.py b/tests/integration_tests/obo/app_fixture/scripts/start_app.py new file mode 100644 index 00000000..6082c835 --- /dev/null +++ b/tests/integration_tests/obo/app_fixture/scripts/start_app.py @@ -0,0 +1,38 @@ +#!/usr/bin/env python3 +"""Simplified start script for CI deployment (backend only, no UI).""" + +import argparse +import subprocess +import sys +import threading + +from dotenv import load_dotenv + + +def main(): + load_dotenv(dotenv_path=".env", override=True) + + parser = argparse.ArgumentParser() + parser.add_argument("--no-ui", action="store_true", default=True) + args, backend_args = parser.parse_known_args() + + cmd = ["uv", "run", "start-server"] + backend_args + proc = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, text=True) + + def monitor(): + for line in iter(proc.stdout.readline, ""): + print(line.rstrip()) # noqa: T201 + + thread = threading.Thread(target=monitor, daemon=True) + thread.start() + + try: + proc.wait() + except KeyboardInterrupt: + proc.terminate() + + sys.exit(proc.returncode or 0) + + +if __name__ == "__main__": + main() diff --git a/tests/integration_tests/obo/deploy_serving_agent.py b/tests/integration_tests/obo/deploy_serving_agent.py new file mode 100644 index 00000000..a3250ec4 --- /dev/null +++ b/tests/integration_tests/obo/deploy_serving_agent.py @@ -0,0 +1,102 @@ +""" +Deploy the whoami OBO agent to a Model Serving endpoint. + +Run manually or on a weekly schedule to keep the endpoint on the latest SDK. + +Environment Variables: + DATABRICKS_HOST - Workspace URL + DATABRICKS_CLIENT_ID - Service principal client ID + DATABRICKS_CLIENT_SECRET - Service principal client secret + OBO_TEST_SERVING_ENDPOINT - Target serving endpoint name (optional override) + OBO_TEST_WAREHOUSE_ID - SQL warehouse ID +""" + +import logging +import os +import tempfile +from pathlib import Path + +import mlflow +from databricks.sdk import WorkspaceClient +from mlflow.models.auth_policy import AuthPolicy, SystemAuthPolicy, UserAuthPolicy +from mlflow.models.resources import DatabricksServingEndpoint, DatabricksSQLWarehouse + +log = logging.getLogger(__name__) + +# Must match the constants in whoami_serving_agent.py +LLM_ENDPOINT_NAME = "databricks-claude-sonnet-4-6" +SQL_WAREHOUSE_ID = os.environ["OBO_TEST_WAREHOUSE_ID"] + +UC_CATALOG = "integration_testing" +UC_SCHEMA = "databricks_ai_bridge_mcp_test" +UC_MODEL_NAME_SHORT = "test_endpoint_dhruv" +UC_MODEL_NAME = f"{UC_CATALOG}.{UC_SCHEMA}.{UC_MODEL_NAME_SHORT}" + + +def main(): + w = WorkspaceClient() + log.info("Workspace: %s", w.config.host) + + mlflow.set_registry_uri("databricks-uc") + + experiment_name = f"/Users/{w.current_user.me().user_name}/obo-serving-agent-deploy" + mlflow.set_experiment(experiment_name) + + # Copy agent file to a temp dir, injecting the warehouse ID + agent_source = Path(__file__).parent / "model_serving_fixture" / "whoami_serving_agent.py" + with tempfile.TemporaryDirectory() as tmp: + agent_file = Path(tmp) / "agent.py" + content = agent_source.read_text() + content = content.replace( + 'SQL_WAREHOUSE_ID = "" # Injected by deploy_serving_agent.py at log time', + f'SQL_WAREHOUSE_ID = "{SQL_WAREHOUSE_ID}"', + ) + agent_file.write_text(content) + + system_policy = SystemAuthPolicy( + resources=[ + DatabricksServingEndpoint(endpoint_name=LLM_ENDPOINT_NAME), + DatabricksSQLWarehouse(warehouse_id=SQL_WAREHOUSE_ID), + ] + ) + user_policy = UserAuthPolicy( + api_scopes=[ + "sql.statement-execution", + "sql.warehouses", + "serving.serving-endpoints", + ] + ) + + input_example = { + "input": [{"role": "user", "content": "Who am I?"}], + } + + with mlflow.start_run(): + logged_agent_info = mlflow.pyfunc.log_model( + name="agent", + python_model=str(agent_file), + input_example=input_example, + auth_policy=AuthPolicy( + system_auth_policy=system_policy, + user_auth_policy=user_policy, + ), + pip_requirements=[ + "databricks-openai", + "databricks-ai-bridge", + "databricks-sdk", + ], + ) + log.info("Logged model: %s", logged_agent_info.model_uri) + + registered = mlflow.register_model(logged_agent_info.model_uri, UC_MODEL_NAME) + log.info("Registered: %s version %s", UC_MODEL_NAME, registered.version) + + from databricks import agents + + agents.deploy(UC_MODEL_NAME, registered.version, scale_to_zero=True) + log.info("Deployment initiated (scale_to_zero=True)") + + +if __name__ == "__main__": + logging.basicConfig(level=logging.INFO) + main() diff --git a/tests/integration_tests/obo/model_serving_fixture/whoami_serving_agent.py b/tests/integration_tests/obo/model_serving_fixture/whoami_serving_agent.py new file mode 100644 index 00000000..f726f2d0 --- /dev/null +++ b/tests/integration_tests/obo/model_serving_fixture/whoami_serving_agent.py @@ -0,0 +1,153 @@ +""" +Minimal OBO whoami agent for Model Serving. + +Calls the whoami() UC function via SQL Statement Execution API +using ModelServingUserCredentials to act as the invoking user. + +This file gets logged as an MLflow model artifact via: + mlflow.pyfunc.log_model(python_model="whoami_serving_agent.py", ...) +""" + +import json +from typing import Any, Callable, Generator +from uuid import uuid4 + +import mlflow +from databricks.sdk import WorkspaceClient +from databricks.sdk.service.sql import StatementState +from mlflow.entities import SpanType +from mlflow.pyfunc import ResponsesAgent +from mlflow.types.responses import ( + ResponsesAgentRequest, + ResponsesAgentResponse, + ResponsesAgentStreamEvent, + output_to_responses_items_stream, + to_chat_completions_input, +) +from openai import OpenAI +from pydantic import BaseModel + +from databricks_ai_bridge import ModelServingUserCredentials + +LLM_ENDPOINT_NAME = "databricks-claude-sonnet-4-6" +SQL_WAREHOUSE_ID = "" # Injected by deploy_serving_agent.py at log time + + +class ToolInfo(BaseModel): + name: str + spec: dict + exec_fn: Callable + + +def create_whoami_tool(user_client: WorkspaceClient) -> ToolInfo: + @mlflow.trace(span_type=SpanType.TOOL) + def execute_whoami(**kwargs) -> str: + try: + response = user_client.statement_execution.execute_statement( + warehouse_id=SQL_WAREHOUSE_ID, + statement="SELECT integration_testing.databricks_ai_bridge_mcp_test.whoami() as result", + wait_timeout="30s", + ) + if response.status.state == StatementState.SUCCEEDED: + if response.result and response.result.data_array: + return str(response.result.data_array[0][0]) + return "No result returned" + return f"Query failed with state: {response.status.state}" + except Exception as e: + return f"Error calling whoami: {e}" + + tool_spec = { + "type": "function", + "function": { + "name": "whoami", + "description": "Returns information about the current user", + "parameters": {"type": "object", "properties": {}, "required": []}, + }, + } + return ToolInfo(name="whoami", spec=tool_spec, exec_fn=execute_whoami) + + +class ToolCallingAgent(ResponsesAgent): + def __init__(self, llm_endpoint: str, warehouse_id: str): + self.llm_endpoint = llm_endpoint + self.warehouse_id = warehouse_id + self._tools_dict = None + + def get_tool_specs(self) -> list[dict]: + if self._tools_dict is None: + return [] + return [t.spec for t in self._tools_dict.values()] + + @mlflow.trace(span_type=SpanType.TOOL) + def execute_tool(self, tool_name: str, args: dict) -> Any: + return self._tools_dict[tool_name].exec_fn(**args) + + def call_llm( + self, messages: list[dict[str, Any]], user_client: WorkspaceClient + ) -> Generator[dict[str, Any], None, None]: + client: OpenAI = user_client.serving_endpoints.get_open_ai_client() + for chunk in client.chat.completions.create( + model=self.llm_endpoint, + messages=to_chat_completions_input(messages), + tools=self.get_tool_specs(), + stream=True, + ): + chunk_dict = chunk.to_dict() + if len(chunk_dict.get("choices", [])) > 0: + yield chunk_dict + + def handle_tool_call( + self, tool_call: dict[str, Any], messages: list[dict[str, Any]] + ) -> ResponsesAgentStreamEvent: + try: + args = json.loads(tool_call.get("arguments", "{}")) + except Exception: + args = {} + result = str(self.execute_tool(tool_name=tool_call["name"], args=args)) + output = self.create_function_call_output_item(tool_call["call_id"], result) + messages.append(output) + return ResponsesAgentStreamEvent(type="response.output_item.done", item=output) + + def call_and_run_tools( + self, + messages: list[dict[str, Any]], + user_client: WorkspaceClient, + max_iter: int = 10, + ) -> Generator[ResponsesAgentStreamEvent, None, None]: + for _ in range(max_iter): + last_msg = messages[-1] + if last_msg.get("role") == "assistant": + return + elif last_msg.get("type") == "function_call": + yield self.handle_tool_call(last_msg, messages) + else: + yield from output_to_responses_items_stream( + chunks=self.call_llm(messages, user_client), aggregator=messages + ) + yield ResponsesAgentStreamEvent( + type="response.output_item.done", + item=self.create_text_output_item("Max iterations reached.", str(uuid4())), + ) + + def predict(self, request: ResponsesAgentRequest) -> ResponsesAgentResponse: + user_client = WorkspaceClient(credentials_strategy=ModelServingUserCredentials()) + outputs = [ + event.item + for event in self.predict_stream(request, user_client) + if event.type == "response.output_item.done" + ] + return ResponsesAgentResponse(output=outputs) + + def predict_stream( + self, request: ResponsesAgentRequest, user_client: WorkspaceClient = None + ) -> Generator[ResponsesAgentStreamEvent, None, None]: + if user_client is None: + user_client = WorkspaceClient(credentials_strategy=ModelServingUserCredentials()) + whoami_tool = create_whoami_tool(user_client) + self._tools_dict = {whoami_tool.name: whoami_tool} + messages = to_chat_completions_input([i.model_dump() for i in request.input]) + yield from self.call_and_run_tools(messages=messages, user_client=user_client) + + +AGENT = ToolCallingAgent(llm_endpoint=LLM_ENDPOINT_NAME, warehouse_id=SQL_WAREHOUSE_ID) +mlflow.models.set_model(AGENT) diff --git a/tests/integration_tests/obo/test_obo_credential_flow.py b/tests/integration_tests/obo/test_obo_credential_flow.py new file mode 100644 index 00000000..149fd428 --- /dev/null +++ b/tests/integration_tests/obo/test_obo_credential_flow.py @@ -0,0 +1,222 @@ +""" +End-to-end integration tests for OBO (On-Behalf-Of) credential flows. + +Invokes pre-deployed agents (Model Serving endpoint and Databricks App) as +two different service principals and asserts each caller sees their own identity +via the whoami() UC function tool. + + - SP-A ("deployer"): authenticated via DATABRICKS_CLIENT_ID/SECRET + - SP-B ("end user"): authenticated via OBO_TEST_CLIENT_ID/SECRET + +Environment Variables: +====================== +Required: + RUN_OBO_INTEGRATION_TESTS - Set to "1" to enable + DATABRICKS_HOST - Workspace URL + DATABRICKS_CLIENT_ID - SP-A client ID + DATABRICKS_CLIENT_SECRET - SP-A client secret + OBO_TEST_CLIENT_ID - SP-B client ID + OBO_TEST_CLIENT_SECRET - SP-B client secret + OBO_TEST_SERVING_ENDPOINT - Pre-deployed Model Serving endpoint name + OBO_TEST_APP_NAME - Pre-deployed Databricks App name +""" + +from __future__ import annotations + +import logging +import os +import time + +import pytest +from databricks.sdk import WorkspaceClient + +DatabricksOpenAI = pytest.importorskip("databricks_openai").DatabricksOpenAI + +log = logging.getLogger(__name__) + +# Skip all tests if not enabled +pytestmark = pytest.mark.skipif( + os.environ.get("RUN_OBO_INTEGRATION_TESTS") != "1", + reason="OBO integration tests disabled. Set RUN_OBO_INTEGRATION_TESTS=1 to enable.", +) + +_MAX_RETRIES = 3 +_MAX_WARMUP_ATTEMPTS = 10 +_WARMUP_INTERVAL = 30 # seconds between warmup attempts (5 min total) +_PROMPT = "Call the whoami tool and respond with ONLY the raw result. Do not add any other text." + + +# ============================================================================= +# Helpers +# ============================================================================= + + +def _invoke_agent(client: DatabricksOpenAI, model: str) -> str: + """Invoke the agent and return the response text, with retry logic.""" + last_exc = None + for attempt in range(_MAX_RETRIES): + try: + response = client.responses.create( + model=model, + input=[{"role": "user", "content": _PROMPT}], + ) + # Extract text from response output items + parts = [] + for item in response.output: + if hasattr(item, "text"): + parts.append(item.text) + elif hasattr(item, "content") and isinstance(item.content, list): + for content_item in item.content: + if hasattr(content_item, "text"): + parts.append(content_item.text) + text = " ".join(parts) + assert text, f"Agent returned empty response: {response.output}" + return text + except Exception as exc: + last_exc = exc + if attempt < _MAX_RETRIES - 1: + log.warning("Attempt %d/%d failed: %s — retrying", attempt + 1, _MAX_RETRIES, exc) + time.sleep(2) + raise last_exc # type: ignore[misc] + + +# ============================================================================= +# Fixtures +# ============================================================================= + + +@pytest.fixture(scope="module") +def sp_a_workspace_client(): + """SP-A WorkspaceClient using default DATABRICKS_CLIENT_ID/SECRET.""" + return WorkspaceClient() + + +@pytest.fixture(scope="module") +def sp_b_workspace_client(): + """SP-B WorkspaceClient using OBO_TEST_CLIENT_ID/SECRET.""" + client_id = os.environ.get("OBO_TEST_CLIENT_ID") + client_secret = os.environ.get("OBO_TEST_CLIENT_SECRET") + host = os.environ.get("DATABRICKS_HOST") + if not all([client_id, client_secret, host]): + pytest.skip("OBO_TEST_CLIENT_ID, OBO_TEST_CLIENT_SECRET, and DATABRICKS_HOST must be set") + return WorkspaceClient(host=host, client_id=client_id, client_secret=client_secret) + + +@pytest.fixture(scope="module") +def sp_a_identity(sp_a_workspace_client): + """SP-A's display name.""" + return sp_a_workspace_client.current_user.me().display_name + + +@pytest.fixture(scope="module") +def sp_b_identity(): + """SP-B's client ID — the value whoami()/current_user() returns for an SP.""" + return os.environ["OBO_TEST_CLIENT_ID"] + + +@pytest.fixture(scope="module") +def sp_a_client(sp_a_workspace_client): + """DatabricksOpenAI client authenticated as SP-A.""" + return DatabricksOpenAI(workspace_client=sp_a_workspace_client) + + +@pytest.fixture(scope="module") +def sp_b_client(sp_b_workspace_client): + """DatabricksOpenAI client authenticated as SP-B.""" + return DatabricksOpenAI(workspace_client=sp_b_workspace_client) + + +@pytest.fixture(scope="module") +def serving_endpoint(): + """Pre-deployed Model Serving endpoint name.""" + name = os.environ.get("OBO_TEST_SERVING_ENDPOINT") + if not name: + pytest.skip("OBO_TEST_SERVING_ENDPOINT must be set") + return name + + +@pytest.fixture(scope="module") +def serving_endpoint_ready(sp_a_client, serving_endpoint): + """Warm up the serving endpoint (may be scaled to zero) before tests.""" + for attempt in range(_MAX_WARMUP_ATTEMPTS): + try: + sp_a_client.responses.create( + model=serving_endpoint, + input=[{"role": "user", "content": "ping"}], + ) + log.info("Serving endpoint is warm after %d attempt(s)", attempt + 1) + return + except Exception as exc: + log.info( + "Warmup attempt %d/%d: %s — waiting %ds", + attempt + 1, + _MAX_WARMUP_ATTEMPTS, + exc, + _WARMUP_INTERVAL, + ) + time.sleep(_WARMUP_INTERVAL) + pytest.fail( + f"Serving endpoint '{serving_endpoint}' did not scale up within " + f"{_MAX_WARMUP_ATTEMPTS * _WARMUP_INTERVAL}s" + ) + + +@pytest.fixture(scope="module") +def app_name(): + """Pre-deployed Databricks App name.""" + name = os.environ.get("OBO_TEST_APP_NAME") + if not name: + pytest.skip("OBO_TEST_APP_NAME must be set") + return name + + +# ============================================================================= +# Tests: Model Serving OBO +# ============================================================================= + + +@pytest.mark.obo +class TestModelServingOBO: + """Invoke a pre-deployed Model Serving agent as two different SPs.""" + + def test_sp_a_and_sp_b_see_different_identities( + self, sp_a_client, sp_b_client, serving_endpoint, serving_endpoint_ready + ): + sp_a_response = _invoke_agent(sp_a_client, serving_endpoint) + sp_b_response = _invoke_agent(sp_b_client, serving_endpoint) + assert sp_a_response != sp_b_response, ( + "SP-A and SP-B should see different identities from whoami()" + ) + + def test_sp_b_sees_own_identity( + self, sp_b_client, sp_b_identity, serving_endpoint, serving_endpoint_ready + ): + response = _invoke_agent(sp_b_client, serving_endpoint) + assert sp_b_identity in response, ( + f"Expected SP-B identity '{sp_b_identity}' in response, got: {response}" + ) + + +# ============================================================================= +# Tests: Databricks Apps OBO +# ============================================================================= + + +@pytest.mark.obo +class TestAppsOBO: + """Invoke a pre-deployed Databricks App agent as two different SPs.""" + + def test_sp_a_and_sp_b_see_different_identities(self, sp_a_client, sp_b_client, app_name): + model = f"apps/{app_name}" + sp_a_response = _invoke_agent(sp_a_client, model) + sp_b_response = _invoke_agent(sp_b_client, model) + assert sp_a_response != sp_b_response, ( + "SP-A and SP-B should see different identities from whoami()" + ) + + def test_sp_b_sees_own_identity(self, sp_b_client, sp_b_identity, app_name): + model = f"apps/{app_name}" + response = _invoke_agent(sp_b_client, model) + assert sp_b_identity in response, ( + f"Expected SP-B identity '{sp_b_identity}' in response, got: {response}" + )