Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
61 changes: 55 additions & 6 deletions databricks-mcp-server/databricks_mcp_server/tools/agent_bricks.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from databricks_tools_core.agent_bricks import (
AgentBricksManager,
EndpointStatus,
TileType,
get_tile_example_queue,
)
from databricks_tools_core.identity import with_description_footer
Expand Down Expand Up @@ -464,6 +465,38 @@ def _mas_delete(tile_id: str) -> Dict[str, Any]:
return {"success": False, "tile_id": tile_id, "error": str(e)}


def _ka_list() -> Dict[str, Any]:
"""List all Knowledge Assistants in the workspace."""
manager = _get_manager()
tiles = manager.list_all_agent_bricks(tile_type=TileType.KA)
kas = []
for tile in tiles:
kas.append(
{
"tile_id": tile.get("tile_id", ""),
"name": tile.get("name", ""),
"description": tile.get("description", ""),
}
)
return {"knowledge_assistants": kas, "count": len(kas)}


def _mas_list() -> Dict[str, Any]:
"""List all Supervisor Agents in the workspace."""
manager = _get_manager()
tiles = manager.list_all_agent_bricks(tile_type=TileType.MAS)
agents = []
for tile in tiles:
agents.append(
{
"tile_id": tile.get("tile_id", ""),
"name": tile.get("name", ""),
"description": tile.get("description", ""),
}
)
return {"supervisor_agents": agents, "count": len(agents)}


# ============================================================================
# Consolidated MCP Tools
# ============================================================================
Expand All @@ -486,13 +519,14 @@ def manage_ka(
questions from indexed documents (PDFs, text files, etc.).

Actions:
- list: List all Knowledge Assistants in the workspace
- create_or_update: Create or update a KA (requires name, volume_path)
- get: Get KA details by tile_id
- find_by_name: Find a KA by exact name
- delete: Delete a KA by tile_id

Args:
action: "create_or_update", "get", "find_by_name", or "delete"
action: "list", "create_or_update", "get", "find_by_name", or "delete"
name: Name for the KA (for create_or_update, find_by_name)
volume_path: Path to the volume folder containing documents
(e.g., "/Volumes/catalog/schema/volume/folder") (for create_or_update)
Expand All @@ -504,12 +538,16 @@ def manage_ka(

Returns:
Dict with operation result. Varies by action:
- list: knowledge_assistants (list of {tile_id, name, description}), count
- create_or_update: tile_id, name, operation, endpoint_status, examples_queued
- get: tile_id, name, description, endpoint_status, knowledge_sources, examples_count
- find_by_name: found, tile_id, name, endpoint_name, endpoint_status
- delete: success, tile_id

Example:
>>> manage_ka(action="list")
{"knowledge_assistants": [{"tile_id": "...", "name": "...", ...}], "count": 3}

>>> manage_ka(
... action="create_or_update",
... name="HR Policy Assistant",
Expand All @@ -527,7 +565,9 @@ def manage_ka(
"""
action = action.lower()

if action == "create_or_update":
if action == "list":
return _ka_list()
elif action == "create_or_update":
return _ka_create_or_update(
name=name,
volume_path=volume_path,
Expand All @@ -543,7 +583,8 @@ def manage_ka(
elif action == "delete":
return _ka_delete(tile_id=tile_id)
else:
return {"error": f"Invalid action '{action}'. Must be one of: create_or_update, get, find_by_name, delete"}
valid = "list, create_or_update, get, find_by_name, delete"
return {"error": f"Invalid action '{action}'. Must be one of: {valid}"}


@mcp.tool
Expand All @@ -564,13 +605,14 @@ def manage_mas(
Genie spaces, Knowledge Assistants, UC functions, and external MCP servers as agents.

Actions:
- list: List all Supervisor Agents in the workspace
- create_or_update: Create or update a Supervisor Agent (requires name, agents)
- get: Get Supervisor Agent details by tile_id
- find_by_name: Find a Supervisor Agent by exact name
- delete: Delete a Supervisor Agent by tile_id

Args:
action: "create_or_update", "get", "find_by_name", or "delete"
action: "list", "create_or_update", "get", "find_by_name", or "delete"
name: Name for the Supervisor Agent (for create_or_update, find_by_name)
agents: List of agent configurations (for create_or_update). Each agent requires:
- name: Agent identifier (used internally for routing)
Expand All @@ -591,12 +633,16 @@ def manage_mas(

Returns:
Dict with operation result. Varies by action:
- list: supervisor_agents (list of {tile_id, name, description}), count
- create_or_update: tile_id, name, operation, endpoint_status, agents_count
- get: tile_id, name, description, endpoint_status, agents, examples_count
- find_by_name: found, tile_id, name, endpoint_status, agents_count
- delete: success, tile_id

Example:
>>> manage_mas(action="list")
{"supervisor_agents": [{"tile_id": "...", "name": "...", ...}], "count": 2}

>>> manage_mas(
... action="create_or_update",
... name="Customer Support MAS",
Expand Down Expand Up @@ -625,7 +671,9 @@ def manage_mas(
"""
action = action.lower()

if action == "create_or_update":
if action == "list":
return _mas_list()
elif action == "create_or_update":
return _mas_create_or_update(
name=name,
agents=agents,
Expand All @@ -641,4 +689,5 @@ def manage_mas(
elif action == "delete":
return _mas_delete(tile_id=tile_id)
else:
return {"error": f"Invalid action '{action}'. Must be one of: create_or_update, get, find_by_name, delete"}
valid = "list, create_or_update, get, find_by_name, delete"
return {"error": f"Invalid action '{action}'. Must be one of: {valid}"}
8 changes: 2 additions & 6 deletions databricks-tools-core/databricks_tools_core/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,9 +160,7 @@ def get_workspace_client() -> WorkspaceClient:
# Cross-workspace: explicit token overrides env OAuth so tool operations
# target the caller-specified workspace instead of the app's own workspace
if force and host and token:
return tag_client(
WorkspaceClient(host=host, token=token, auth_type="pat", **product_kwargs)
)
return tag_client(WorkspaceClient(host=host, token=token, auth_type="pat", **product_kwargs))

# In Databricks Apps (OAuth credentials in env), explicitly use OAuth M2M.
# Setting auth_type="oauth-m2m" prevents the SDK from also reading
Expand All @@ -185,9 +183,7 @@ def get_workspace_client() -> WorkspaceClient:

# Development mode: use explicit token if provided
if host and token:
return tag_client(
WorkspaceClient(host=host, token=token, auth_type="pat", **product_kwargs)
)
return tag_client(WorkspaceClient(host=host, token=token, auth_type="pat", **product_kwargs))

if host:
return tag_client(WorkspaceClient(host=host, **product_kwargs))
Expand Down
54 changes: 21 additions & 33 deletions databricks-tools-core/tests/unit/test_sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,8 +121,7 @@ def test_executor_without_query_tags_omits_from_api(self, mock_get_client):
assert "query_tags" not in call_kwargs


def _make_warehouse(id, name, state, creator_name="other@example.com",
enable_serverless_compute=False):
def _make_warehouse(id, name, state, creator_name="other@example.com", enable_serverless_compute=False):
"""Helper to create a mock warehouse object."""
w = mock.Mock()
w.id = id
Expand All @@ -141,33 +140,29 @@ class TestSortWithinTier:
def test_serverless_first(self):
"""Serverless warehouses should come before classic ones."""
classic = _make_warehouse("c1", "Classic WH", State.RUNNING)
serverless = _make_warehouse("s1", "Serverless WH", State.RUNNING,
enable_serverless_compute=True)
serverless = _make_warehouse("s1", "Serverless WH", State.RUNNING, enable_serverless_compute=True)
result = _sort_within_tier([classic, serverless], current_user=None)
assert result[0].id == "s1"
assert result[1].id == "c1"

def test_serverless_before_user_owned(self):
"""Serverless should be preferred over user-owned classic."""
classic_owned = _make_warehouse("c1", "My WH", State.RUNNING,
creator_name="me@example.com")
serverless_other = _make_warehouse("s1", "Other WH", State.RUNNING,
creator_name="other@example.com",
enable_serverless_compute=True)
result = _sort_within_tier([classic_owned, serverless_other],
current_user="me@example.com")
classic_owned = _make_warehouse("c1", "My WH", State.RUNNING, creator_name="me@example.com")
serverless_other = _make_warehouse(
"s1", "Other WH", State.RUNNING, creator_name="other@example.com", enable_serverless_compute=True
)
result = _sort_within_tier([classic_owned, serverless_other], current_user="me@example.com")
assert result[0].id == "s1"

def test_serverless_user_owned_first(self):
"""Among serverless, user-owned should come first."""
serverless_other = _make_warehouse("s1", "Other Serverless", State.RUNNING,
creator_name="other@example.com",
enable_serverless_compute=True)
serverless_owned = _make_warehouse("s2", "My Serverless", State.RUNNING,
creator_name="me@example.com",
enable_serverless_compute=True)
result = _sort_within_tier([serverless_other, serverless_owned],
current_user="me@example.com")
serverless_other = _make_warehouse(
"s1", "Other Serverless", State.RUNNING, creator_name="other@example.com", enable_serverless_compute=True
)
serverless_owned = _make_warehouse(
"s2", "My Serverless", State.RUNNING, creator_name="me@example.com", enable_serverless_compute=True
)
result = _sort_within_tier([serverless_other, serverless_owned], current_user="me@example.com")
assert result[0].id == "s2"
assert result[1].id == "s1"

Expand All @@ -177,53 +172,46 @@ def test_empty_list(self):
def test_no_current_user(self):
"""Without a current user, only serverless preference applies."""
classic = _make_warehouse("c1", "Classic", State.RUNNING)
serverless = _make_warehouse("s1", "Serverless", State.RUNNING,
enable_serverless_compute=True)
serverless = _make_warehouse("s1", "Serverless", State.RUNNING, enable_serverless_compute=True)
result = _sort_within_tier([classic, serverless], current_user=None)
assert result[0].id == "s1"


class TestGetBestWarehouseServerless:
"""Tests for serverless preference in get_best_warehouse."""

@mock.patch("databricks_tools_core.sql.warehouse.get_current_username",
return_value="me@example.com")
@mock.patch("databricks_tools_core.sql.warehouse.get_current_username", return_value="me@example.com")
@mock.patch("databricks_tools_core.sql.warehouse.get_workspace_client")
def test_prefers_serverless_within_running_shared(self, mock_client_fn, mock_user):
"""Among running shared warehouses, serverless should be picked."""
classic_shared = _make_warehouse("c1", "Shared WH", State.RUNNING)
serverless_shared = _make_warehouse("s1", "Shared Serverless", State.RUNNING,
enable_serverless_compute=True)
serverless_shared = _make_warehouse("s1", "Shared Serverless", State.RUNNING, enable_serverless_compute=True)
mock_client = mock.Mock()
mock_client.warehouses.list.return_value = [classic_shared, serverless_shared]
mock_client_fn.return_value = mock_client

result = get_best_warehouse()
assert result == "s1"

@mock.patch("databricks_tools_core.sql.warehouse.get_current_username",
return_value="me@example.com")
@mock.patch("databricks_tools_core.sql.warehouse.get_current_username", return_value="me@example.com")
@mock.patch("databricks_tools_core.sql.warehouse.get_workspace_client")
def test_prefers_serverless_within_running_other(self, mock_client_fn, mock_user):
"""Among running non-shared warehouses, serverless should be picked."""
classic = _make_warehouse("c1", "My WH", State.RUNNING)
serverless = _make_warehouse("s1", "Fast WH", State.RUNNING,
enable_serverless_compute=True)
serverless = _make_warehouse("s1", "Fast WH", State.RUNNING, enable_serverless_compute=True)
mock_client = mock.Mock()
mock_client.warehouses.list.return_value = [classic, serverless]
mock_client_fn.return_value = mock_client

result = get_best_warehouse()
assert result == "s1"

@mock.patch("databricks_tools_core.sql.warehouse.get_current_username",
return_value="me@example.com")
@mock.patch("databricks_tools_core.sql.warehouse.get_current_username", return_value="me@example.com")
@mock.patch("databricks_tools_core.sql.warehouse.get_workspace_client")
def test_tier_order_preserved_over_serverless(self, mock_client_fn, mock_user):
"""A running shared classic should still beat a stopped serverless."""
running_shared_classic = _make_warehouse("c1", "Shared WH", State.RUNNING)
stopped_serverless = _make_warehouse("s1", "Fast WH", State.STOPPED,
enable_serverless_compute=True)
stopped_serverless = _make_warehouse("s1", "Fast WH", State.STOPPED, enable_serverless_compute=True)
mock_client = mock.Mock()
mock_client.warehouses.list.return_value = [stopped_serverless, running_shared_classic]
mock_client_fn.return_value = mock_client
Expand Down