From c48c27101b6ed2aa9832697cc27837cd39db348a Mon Sep 17 00:00:00 2001 From: Steven Tan Date: Tue, 10 Mar 2026 10:56:44 +0800 Subject: [PATCH] Add list action to manage_ka and manage_mas MCP tools Adds a "list" action to both manage_ka and manage_mas tools, enabling users to discover all Knowledge Assistants and Supervisor Agents in their workspace. Uses the existing list_all_agent_bricks method with TileType filtering. Also fixes pre-existing ruff format issues in auth.py and test_sql.py. --- .../tools/agent_bricks.py | 61 +++++++++++++++++-- .../databricks_tools_core/auth.py | 8 +-- databricks-tools-core/tests/unit/test_sql.py | 54 +++++++--------- 3 files changed, 78 insertions(+), 45 deletions(-) diff --git a/databricks-mcp-server/databricks_mcp_server/tools/agent_bricks.py b/databricks-mcp-server/databricks_mcp_server/tools/agent_bricks.py index 3dfb21c9..ae5d5598 100644 --- a/databricks-mcp-server/databricks_mcp_server/tools/agent_bricks.py +++ b/databricks-mcp-server/databricks_mcp_server/tools/agent_bricks.py @@ -8,6 +8,7 @@ from databricks_tools_core.agent_bricks import ( AgentBricksManager, EndpointStatus, + TileType, get_tile_example_queue, ) from databricks_tools_core.identity import with_description_footer @@ -464,6 +465,38 @@ def _mas_delete(tile_id: str) -> Dict[str, Any]: return {"success": False, "tile_id": tile_id, "error": str(e)} +def _ka_list() -> Dict[str, Any]: + """List all Knowledge Assistants in the workspace.""" + manager = _get_manager() + tiles = manager.list_all_agent_bricks(tile_type=TileType.KA) + kas = [] + for tile in tiles: + kas.append( + { + "tile_id": tile.get("tile_id", ""), + "name": tile.get("name", ""), + "description": tile.get("description", ""), + } + ) + return {"knowledge_assistants": kas, "count": len(kas)} + + +def _mas_list() -> Dict[str, Any]: + """List all Supervisor Agents in the workspace.""" + manager = _get_manager() + tiles = manager.list_all_agent_bricks(tile_type=TileType.MAS) + agents = [] + for tile in tiles: + agents.append( + { + "tile_id": tile.get("tile_id", ""), + "name": tile.get("name", ""), + "description": tile.get("description", ""), + } + ) + return {"supervisor_agents": agents, "count": len(agents)} + + # ============================================================================ # Consolidated MCP Tools # ============================================================================ @@ -486,13 +519,14 @@ def manage_ka( questions from indexed documents (PDFs, text files, etc.). Actions: + - list: List all Knowledge Assistants in the workspace - create_or_update: Create or update a KA (requires name, volume_path) - get: Get KA details by tile_id - find_by_name: Find a KA by exact name - delete: Delete a KA by tile_id Args: - action: "create_or_update", "get", "find_by_name", or "delete" + action: "list", "create_or_update", "get", "find_by_name", or "delete" name: Name for the KA (for create_or_update, find_by_name) volume_path: Path to the volume folder containing documents (e.g., "/Volumes/catalog/schema/volume/folder") (for create_or_update) @@ -504,12 +538,16 @@ def manage_ka( Returns: Dict with operation result. Varies by action: + - list: knowledge_assistants (list of {tile_id, name, description}), count - create_or_update: tile_id, name, operation, endpoint_status, examples_queued - get: tile_id, name, description, endpoint_status, knowledge_sources, examples_count - find_by_name: found, tile_id, name, endpoint_name, endpoint_status - delete: success, tile_id Example: + >>> manage_ka(action="list") + {"knowledge_assistants": [{"tile_id": "...", "name": "...", ...}], "count": 3} + >>> manage_ka( ... action="create_or_update", ... name="HR Policy Assistant", @@ -527,7 +565,9 @@ def manage_ka( """ action = action.lower() - if action == "create_or_update": + if action == "list": + return _ka_list() + elif action == "create_or_update": return _ka_create_or_update( name=name, volume_path=volume_path, @@ -543,7 +583,8 @@ def manage_ka( elif action == "delete": return _ka_delete(tile_id=tile_id) else: - return {"error": f"Invalid action '{action}'. Must be one of: create_or_update, get, find_by_name, delete"} + valid = "list, create_or_update, get, find_by_name, delete" + return {"error": f"Invalid action '{action}'. Must be one of: {valid}"} @mcp.tool @@ -564,13 +605,14 @@ def manage_mas( Genie spaces, Knowledge Assistants, UC functions, and external MCP servers as agents. Actions: + - list: List all Supervisor Agents in the workspace - create_or_update: Create or update a Supervisor Agent (requires name, agents) - get: Get Supervisor Agent details by tile_id - find_by_name: Find a Supervisor Agent by exact name - delete: Delete a Supervisor Agent by tile_id Args: - action: "create_or_update", "get", "find_by_name", or "delete" + action: "list", "create_or_update", "get", "find_by_name", or "delete" name: Name for the Supervisor Agent (for create_or_update, find_by_name) agents: List of agent configurations (for create_or_update). Each agent requires: - name: Agent identifier (used internally for routing) @@ -591,12 +633,16 @@ def manage_mas( Returns: Dict with operation result. Varies by action: + - list: supervisor_agents (list of {tile_id, name, description}), count - create_or_update: tile_id, name, operation, endpoint_status, agents_count - get: tile_id, name, description, endpoint_status, agents, examples_count - find_by_name: found, tile_id, name, endpoint_status, agents_count - delete: success, tile_id Example: + >>> manage_mas(action="list") + {"supervisor_agents": [{"tile_id": "...", "name": "...", ...}], "count": 2} + >>> manage_mas( ... action="create_or_update", ... name="Customer Support MAS", @@ -625,7 +671,9 @@ def manage_mas( """ action = action.lower() - if action == "create_or_update": + if action == "list": + return _mas_list() + elif action == "create_or_update": return _mas_create_or_update( name=name, agents=agents, @@ -641,4 +689,5 @@ def manage_mas( elif action == "delete": return _mas_delete(tile_id=tile_id) else: - return {"error": f"Invalid action '{action}'. Must be one of: create_or_update, get, find_by_name, delete"} + valid = "list, create_or_update, get, find_by_name, delete" + return {"error": f"Invalid action '{action}'. Must be one of: {valid}"} diff --git a/databricks-tools-core/databricks_tools_core/auth.py b/databricks-tools-core/databricks_tools_core/auth.py index 21913983..c3db9fb4 100644 --- a/databricks-tools-core/databricks_tools_core/auth.py +++ b/databricks-tools-core/databricks_tools_core/auth.py @@ -160,9 +160,7 @@ def get_workspace_client() -> WorkspaceClient: # Cross-workspace: explicit token overrides env OAuth so tool operations # target the caller-specified workspace instead of the app's own workspace if force and host and token: - return tag_client( - WorkspaceClient(host=host, token=token, auth_type="pat", **product_kwargs) - ) + return tag_client(WorkspaceClient(host=host, token=token, auth_type="pat", **product_kwargs)) # In Databricks Apps (OAuth credentials in env), explicitly use OAuth M2M. # Setting auth_type="oauth-m2m" prevents the SDK from also reading @@ -185,9 +183,7 @@ def get_workspace_client() -> WorkspaceClient: # Development mode: use explicit token if provided if host and token: - return tag_client( - WorkspaceClient(host=host, token=token, auth_type="pat", **product_kwargs) - ) + return tag_client(WorkspaceClient(host=host, token=token, auth_type="pat", **product_kwargs)) if host: return tag_client(WorkspaceClient(host=host, **product_kwargs)) diff --git a/databricks-tools-core/tests/unit/test_sql.py b/databricks-tools-core/tests/unit/test_sql.py index d1b661c6..42137ba5 100644 --- a/databricks-tools-core/tests/unit/test_sql.py +++ b/databricks-tools-core/tests/unit/test_sql.py @@ -121,8 +121,7 @@ def test_executor_without_query_tags_omits_from_api(self, mock_get_client): assert "query_tags" not in call_kwargs -def _make_warehouse(id, name, state, creator_name="other@example.com", - enable_serverless_compute=False): +def _make_warehouse(id, name, state, creator_name="other@example.com", enable_serverless_compute=False): """Helper to create a mock warehouse object.""" w = mock.Mock() w.id = id @@ -141,33 +140,29 @@ class TestSortWithinTier: def test_serverless_first(self): """Serverless warehouses should come before classic ones.""" classic = _make_warehouse("c1", "Classic WH", State.RUNNING) - serverless = _make_warehouse("s1", "Serverless WH", State.RUNNING, - enable_serverless_compute=True) + serverless = _make_warehouse("s1", "Serverless WH", State.RUNNING, enable_serverless_compute=True) result = _sort_within_tier([classic, serverless], current_user=None) assert result[0].id == "s1" assert result[1].id == "c1" def test_serverless_before_user_owned(self): """Serverless should be preferred over user-owned classic.""" - classic_owned = _make_warehouse("c1", "My WH", State.RUNNING, - creator_name="me@example.com") - serverless_other = _make_warehouse("s1", "Other WH", State.RUNNING, - creator_name="other@example.com", - enable_serverless_compute=True) - result = _sort_within_tier([classic_owned, serverless_other], - current_user="me@example.com") + classic_owned = _make_warehouse("c1", "My WH", State.RUNNING, creator_name="me@example.com") + serverless_other = _make_warehouse( + "s1", "Other WH", State.RUNNING, creator_name="other@example.com", enable_serverless_compute=True + ) + result = _sort_within_tier([classic_owned, serverless_other], current_user="me@example.com") assert result[0].id == "s1" def test_serverless_user_owned_first(self): """Among serverless, user-owned should come first.""" - serverless_other = _make_warehouse("s1", "Other Serverless", State.RUNNING, - creator_name="other@example.com", - enable_serverless_compute=True) - serverless_owned = _make_warehouse("s2", "My Serverless", State.RUNNING, - creator_name="me@example.com", - enable_serverless_compute=True) - result = _sort_within_tier([serverless_other, serverless_owned], - current_user="me@example.com") + serverless_other = _make_warehouse( + "s1", "Other Serverless", State.RUNNING, creator_name="other@example.com", enable_serverless_compute=True + ) + serverless_owned = _make_warehouse( + "s2", "My Serverless", State.RUNNING, creator_name="me@example.com", enable_serverless_compute=True + ) + result = _sort_within_tier([serverless_other, serverless_owned], current_user="me@example.com") assert result[0].id == "s2" assert result[1].id == "s1" @@ -177,8 +172,7 @@ def test_empty_list(self): def test_no_current_user(self): """Without a current user, only serverless preference applies.""" classic = _make_warehouse("c1", "Classic", State.RUNNING) - serverless = _make_warehouse("s1", "Serverless", State.RUNNING, - enable_serverless_compute=True) + serverless = _make_warehouse("s1", "Serverless", State.RUNNING, enable_serverless_compute=True) result = _sort_within_tier([classic, serverless], current_user=None) assert result[0].id == "s1" @@ -186,14 +180,12 @@ def test_no_current_user(self): class TestGetBestWarehouseServerless: """Tests for serverless preference in get_best_warehouse.""" - @mock.patch("databricks_tools_core.sql.warehouse.get_current_username", - return_value="me@example.com") + @mock.patch("databricks_tools_core.sql.warehouse.get_current_username", return_value="me@example.com") @mock.patch("databricks_tools_core.sql.warehouse.get_workspace_client") def test_prefers_serverless_within_running_shared(self, mock_client_fn, mock_user): """Among running shared warehouses, serverless should be picked.""" classic_shared = _make_warehouse("c1", "Shared WH", State.RUNNING) - serverless_shared = _make_warehouse("s1", "Shared Serverless", State.RUNNING, - enable_serverless_compute=True) + serverless_shared = _make_warehouse("s1", "Shared Serverless", State.RUNNING, enable_serverless_compute=True) mock_client = mock.Mock() mock_client.warehouses.list.return_value = [classic_shared, serverless_shared] mock_client_fn.return_value = mock_client @@ -201,14 +193,12 @@ def test_prefers_serverless_within_running_shared(self, mock_client_fn, mock_use result = get_best_warehouse() assert result == "s1" - @mock.patch("databricks_tools_core.sql.warehouse.get_current_username", - return_value="me@example.com") + @mock.patch("databricks_tools_core.sql.warehouse.get_current_username", return_value="me@example.com") @mock.patch("databricks_tools_core.sql.warehouse.get_workspace_client") def test_prefers_serverless_within_running_other(self, mock_client_fn, mock_user): """Among running non-shared warehouses, serverless should be picked.""" classic = _make_warehouse("c1", "My WH", State.RUNNING) - serverless = _make_warehouse("s1", "Fast WH", State.RUNNING, - enable_serverless_compute=True) + serverless = _make_warehouse("s1", "Fast WH", State.RUNNING, enable_serverless_compute=True) mock_client = mock.Mock() mock_client.warehouses.list.return_value = [classic, serverless] mock_client_fn.return_value = mock_client @@ -216,14 +206,12 @@ def test_prefers_serverless_within_running_other(self, mock_client_fn, mock_user result = get_best_warehouse() assert result == "s1" - @mock.patch("databricks_tools_core.sql.warehouse.get_current_username", - return_value="me@example.com") + @mock.patch("databricks_tools_core.sql.warehouse.get_current_username", return_value="me@example.com") @mock.patch("databricks_tools_core.sql.warehouse.get_workspace_client") def test_tier_order_preserved_over_serverless(self, mock_client_fn, mock_user): """A running shared classic should still beat a stopped serverless.""" running_shared_classic = _make_warehouse("c1", "Shared WH", State.RUNNING) - stopped_serverless = _make_warehouse("s1", "Fast WH", State.STOPPED, - enable_serverless_compute=True) + stopped_serverless = _make_warehouse("s1", "Fast WH", State.STOPPED, enable_serverless_compute=True) mock_client = mock.Mock() mock_client.warehouses.list.return_value = [stopped_serverless, running_shared_classic] mock_client_fn.return_value = mock_client