diff --git a/databricks-mcp-server/databricks_mcp_server/tools/genie.py b/databricks-mcp-server/databricks_mcp_server/tools/genie.py index 5a606ecf..0a622b3e 100644 --- a/databricks-mcp-server/databricks_mcp_server/tools/genie.py +++ b/databricks-mcp-server/databricks_mcp_server/tools/genie.py @@ -42,6 +42,7 @@ def create_or_update_genie( description: Optional[str] = None, sample_questions: Optional[List[str]] = None, space_id: Optional[str] = None, + serialized_space: Optional[str] = None, ) -> Dict[str, Any]: """ Create or update a Genie Space for SQL-based data exploration. @@ -58,6 +59,9 @@ def create_or_update_genie( description: Optional description of what the Genie space does sample_questions: Optional list of sample questions to help users space_id: Optional existing space_id to update instead of create + serialized_space: Optional JSON string containing full space configuration + (settings, instructions). Use this to import/clone a Genie space + exported via get_genie with include_serialized_space=True. Returns: Dictionary with: @@ -75,6 +79,14 @@ def create_or_update_genie( ... sample_questions=["What were total sales last month?"] ... ) {"space_id": "abc123...", "display_name": "Sales Analytics", "operation": "created", ...} + + Clone a space: + >>> source = get_genie(space_id="abc123", include_serialized_space=True) + >>> create_or_update_genie( + ... display_name="Sales Analytics (Copy)", + ... table_identifiers=source["table_identifiers"], + ... serialized_space=source["serialized_space"] + ... ) """ try: description = with_description_footer(description) @@ -99,6 +111,7 @@ def create_or_update_genie( warehouse_id=warehouse_id, table_identifiers=table_identifiers, sample_questions=sample_questions, + serialized_space=serialized_space, ) else: return {"error": f"Genie space {space_id} not found"} @@ -113,6 +126,7 @@ def create_or_update_genie( warehouse_id=warehouse_id, table_identifiers=table_identifiers, sample_questions=sample_questions, + serialized_space=serialized_space, ) space_id = existing.space_id else: @@ -121,6 +135,7 @@ def create_or_update_genie( warehouse_id=warehouse_id, table_identifiers=table_identifiers, description=description, + serialized_space=serialized_space, ) space_id = result.get("space_id", "") @@ -154,7 +169,10 @@ def create_or_update_genie( @mcp.tool -def get_genie(space_id: Optional[str] = None) -> Dict[str, Any]: +def get_genie( + space_id: Optional[str] = None, + include_serialized_space: bool = False, +) -> Dict[str, Any]: """ Get details of a Genie Space, or list all spaces. @@ -163,6 +181,9 @@ def get_genie(space_id: Optional[str] = None) -> Dict[str, Any]: Args: space_id: The Genie space ID. If omitted, lists all spaces. + include_serialized_space: If True, includes the serialized_space field + containing the full space configuration (settings, instructions). + Useful for exporting a space to clone or import elsewhere. Returns: Single space dict (if space_id provided) or {"spaces": [...]}. @@ -173,11 +194,15 @@ def get_genie(space_id: Optional[str] = None) -> Dict[str, Any]: >>> get_genie() {"spaces": [{"space_id": "abc123...", "title": "Sales Analytics", ...}, ...]} + + Export for cloning: + >>> get_genie("abc123...", include_serialized_space=True) + {"space_id": "abc123...", ..., "serialized_space": "{...}"} """ if space_id: try: manager = _get_manager() - result = manager.genie_get(space_id) + result = manager.genie_get(space_id, include_serialized_space=include_serialized_space) if not result: return {"error": f"Genie space {space_id} not found"} @@ -185,7 +210,7 @@ def get_genie(space_id: Optional[str] = None) -> Dict[str, Any]: questions_response = manager.genie_list_questions(space_id, question_type="SAMPLE_QUESTION") sample_questions = [q.get("question_text", "") for q in questions_response.get("curated_questions", [])] - return { + response = { "space_id": result.get("space_id", space_id), "display_name": result.get("display_name", ""), "description": result.get("description", ""), @@ -193,6 +218,11 @@ def get_genie(space_id: Optional[str] = None) -> Dict[str, Any]: "table_identifiers": result.get("table_identifiers", []), "sample_questions": sample_questions, } + + if include_serialized_space and result.get("serialized_space"): + response["serialized_space"] = result["serialized_space"] + + return response except Exception as e: return {"error": f"Failed to get Genie space {space_id}: {e}"} @@ -246,6 +276,192 @@ def delete_genie(space_id: str) -> Dict[str, Any]: return {"success": False, "space_id": space_id, "error": str(e)} +@mcp.tool +def clone_genie( + source_space_id: str, + new_display_name: str, + warehouse_id: Optional[str] = None, + description: Optional[str] = None, +) -> Dict[str, Any]: + """ + Clone a Genie Space by exporting its full configuration and creating a new space. + + Exports the source space (including settings, instructions, sample questions) + and imports it as a new space. Useful for promoting spaces across environments + or creating variants for different teams. + + Args: + source_space_id: The Genie space ID to clone from + new_display_name: Display name for the cloned space + warehouse_id: Optional warehouse ID for the clone. If not provided, + uses the same warehouse as the source space. + description: Optional description for the clone. If not provided, + uses the source space's description. + + Returns: + Dictionary with: + - space_id: The new cloned space ID + - display_name: The new display name + - source_space_id: The original space ID + - operation: 'cloned' + + Example: + >>> clone_genie( + ... source_space_id="abc123...", + ... new_display_name="Sales Analytics (Staging)", + ... ) + {"space_id": "def456...", "display_name": "Sales Analytics (Staging)", ...} + """ + try: + manager = _get_manager() + + source = manager.genie_get(source_space_id, include_serialized_space=True) + if not source: + return {"error": f"Source Genie space {source_space_id} not found"} + + target_warehouse = warehouse_id or source.get("warehouse_id") + if not target_warehouse: + target_warehouse = manager.get_best_warehouse_id() + if not target_warehouse: + return {"error": "No SQL warehouses available. Please provide a warehouse_id."} + + target_description = description or source.get("description", "") + + result = manager.genie_create( + display_name=new_display_name, + warehouse_id=target_warehouse, + table_identifiers=source.get("table_identifiers", []), + description=target_description, + serialized_space=source.get("serialized_space"), + ) + + new_space_id = result.get("space_id", "") + + try: + if new_space_id: + from ..manifest import track_resource + + track_resource( + resource_type="genie_space", + name=new_display_name, + resource_id=new_space_id, + ) + except Exception: + pass + + return { + "space_id": new_space_id, + "display_name": new_display_name, + "source_space_id": source_space_id, + "operation": "cloned", + "warehouse_id": target_warehouse, + "table_count": len(source.get("table_identifiers", [])), + } + + except Exception as e: + return {"error": f"Failed to clone Genie space: {e}"} + + +@mcp.tool +def manage_genie_instructions( + space_id: str, + action: str = "list", + instruction_type: Optional[str] = None, + title: Optional[str] = None, + content: Optional[str] = None, + instructions: Optional[List[Dict[str, str]]] = None, +) -> Dict[str, Any]: + """ + Manage instructions for a Genie Space (list, add text notes, add SQL examples). + + Instructions guide how Genie interprets questions and generates SQL. + Text instructions provide general guidance; SQL instructions provide + example queries that Genie can reference. + + Args: + space_id: The Genie space ID + action: One of: + - "list": List all instructions in the space + - "add_text": Add a text instruction/note + - "add_sql": Add a SQL query example + - "add_sql_function": Add a certified SQL function + - "add_batch": Add multiple SQL instructions at once + instruction_type: Not needed — determined by action + title: Title for the instruction (required for add_text, add_sql) + content: Content of the instruction (required for add_text, add_sql, + add_sql_function) + instructions: For add_batch: list of {"title": str, "content": str} + dicts to add as SQL instructions + + Returns: + For "list": {"instructions": [...]} with all instructions + For "add_*": The created instruction dict + For "add_batch": {"added": int, "results": [...]} + + Example: + >>> manage_genie_instructions(space_id="abc123", action="list") + {"instructions": [{"title": "...", "content": "...", "instruction_type": "..."}, ...]} + + >>> manage_genie_instructions( + ... space_id="abc123", + ... action="add_text", + ... title="Date handling", + ... content="When users say 'last month', use date_trunc('month', current_date()) - interval 1 month" + ... ) + + >>> manage_genie_instructions( + ... space_id="abc123", + ... action="add_sql", + ... title="Revenue by region", + ... content="SELECT region, SUM(amount) as revenue FROM sales GROUP BY region" + ... ) + + >>> manage_genie_instructions( + ... space_id="abc123", + ... action="add_batch", + ... instructions=[ + ... {"title": "Top customers", + ... "content": "SELECT customer, SUM(amount) FROM orders GROUP BY 1 LIMIT 10"}, + ... {"title": "Monthly trend", + ... "content": "SELECT date_trunc('month', order_date), COUNT(*) FROM orders GROUP BY 1"}, + ... ] + ... ) + """ + try: + manager = _get_manager() + + if action == "list": + result = manager.genie_list_instructions(space_id) + return {"instructions": result.get("instructions", [])} + + elif action == "add_text": + if not content: + return {"error": "content is required for add_text"} + return manager.genie_add_text_instruction(space_id, content=content, title=title or "Notes") + + elif action == "add_sql": + if not title or not content: + return {"error": "title and content are required for add_sql"} + return manager.genie_add_sql_instruction(space_id, title=title, content=content) + + elif action == "add_sql_function": + if not content: + return {"error": "content (function name) is required for add_sql_function"} + return manager.genie_add_sql_function(space_id, function_name=content) + + elif action == "add_batch": + if not instructions: + return {"error": "instructions list is required for add_batch"} + results = manager.genie_add_sql_instructions_batch(space_id, instructions) + return {"added": len(results), "results": results} + + else: + return {"error": f"Unknown action '{action}'. Use: list, add_text, add_sql, add_sql_function, add_batch"} + + except Exception as e: + return {"error": f"Failed to manage instructions for space {space_id}: {e}"} + + # ============================================================================ # Genie Conversation API Tools # ============================================================================ diff --git a/databricks-tools-core/databricks_tools_core/agent_bricks/manager.py b/databricks-tools-core/databricks_tools_core/agent_bricks/manager.py index 7b829bbc..dff46183 100644 --- a/databricks-tools-core/databricks_tools_core/agent_bricks/manager.py +++ b/databricks-tools-core/databricks_tools_core/agent_bricks/manager.py @@ -793,10 +793,19 @@ def mas_list_evaluation_runs( # Genie Space Operations # ======================================================================== - def genie_get(self, space_id: str) -> Optional[GenieSpaceDict]: - """Get Genie space by ID.""" + def genie_get(self, space_id: str, include_serialized_space: bool = False) -> Optional[GenieSpaceDict]: + """Get Genie space by ID. + + Args: + space_id: The Genie space ID + include_serialized_space: If True, includes the serialized_space field + containing full space configuration (settings, instructions, etc.) + """ try: - return self._get(f"/api/2.0/data-rooms/{space_id}") + params = {} + if include_serialized_space: + params["include_serialized_space"] = "true" + return self._get(f"/api/2.0/data-rooms/{space_id}", params=params or None) except Exception as e: if "does not exist" in str(e).lower() or "not found" in str(e).lower(): return None @@ -812,6 +821,7 @@ def genie_create( parent_folder_id: Optional[str] = None, create_dir: bool = True, run_as_type: str = "VIEWER", + serialized_space: Optional[str] = None, ) -> Dict[str, Any]: """Create a Genie space. @@ -824,6 +834,8 @@ def genie_create( parent_folder_id: Optional parent folder ID create_dir: Whether to create parent folder if missing run_as_type: Run as type (default: "VIEWER") + serialized_space: Optional JSON string containing full space configuration + (settings, instructions). Used to import/clone a Genie space. Returns: Created Genie space data @@ -838,6 +850,9 @@ def genie_create( "run_as_type": run_as_type, } + if serialized_space: + room_payload["serialized_space"] = serialized_space + if description: room_payload["description"] = description @@ -869,6 +884,7 @@ def genie_update( warehouse_id: Optional[str] = None, table_identifiers: Optional[List[str]] = None, sample_questions: Optional[List[str]] = None, + serialized_space: Optional[str] = None, ) -> Dict[str, Any]: """Update a Genie space. @@ -879,6 +895,8 @@ def genie_update( warehouse_id: Optional new warehouse ID table_identifiers: Optional new table identifiers sample_questions: Optional sample questions (replaces all existing) + serialized_space: Optional JSON string containing full space configuration + (settings, instructions). Replaces the existing configuration. Returns: Updated Genie space data @@ -913,6 +931,9 @@ def genie_update( if current_space.get(field): update_payload[field] = current_space[field] + if serialized_space: + update_payload["serialized_space"] = serialized_space + result = self._patch(f"/api/2.0/data-rooms/{space_id}", update_payload) if sample_questions is not None: diff --git a/databricks-tools-core/databricks_tools_core/auth.py b/databricks-tools-core/databricks_tools_core/auth.py index 21913983..c3db9fb4 100644 --- a/databricks-tools-core/databricks_tools_core/auth.py +++ b/databricks-tools-core/databricks_tools_core/auth.py @@ -160,9 +160,7 @@ def get_workspace_client() -> WorkspaceClient: # Cross-workspace: explicit token overrides env OAuth so tool operations # target the caller-specified workspace instead of the app's own workspace if force and host and token: - return tag_client( - WorkspaceClient(host=host, token=token, auth_type="pat", **product_kwargs) - ) + return tag_client(WorkspaceClient(host=host, token=token, auth_type="pat", **product_kwargs)) # In Databricks Apps (OAuth credentials in env), explicitly use OAuth M2M. # Setting auth_type="oauth-m2m" prevents the SDK from also reading @@ -185,9 +183,7 @@ def get_workspace_client() -> WorkspaceClient: # Development mode: use explicit token if provided if host and token: - return tag_client( - WorkspaceClient(host=host, token=token, auth_type="pat", **product_kwargs) - ) + return tag_client(WorkspaceClient(host=host, token=token, auth_type="pat", **product_kwargs)) if host: return tag_client(WorkspaceClient(host=host, **product_kwargs)) diff --git a/databricks-tools-core/tests/unit/test_sql.py b/databricks-tools-core/tests/unit/test_sql.py index d1b661c6..42137ba5 100644 --- a/databricks-tools-core/tests/unit/test_sql.py +++ b/databricks-tools-core/tests/unit/test_sql.py @@ -121,8 +121,7 @@ def test_executor_without_query_tags_omits_from_api(self, mock_get_client): assert "query_tags" not in call_kwargs -def _make_warehouse(id, name, state, creator_name="other@example.com", - enable_serverless_compute=False): +def _make_warehouse(id, name, state, creator_name="other@example.com", enable_serverless_compute=False): """Helper to create a mock warehouse object.""" w = mock.Mock() w.id = id @@ -141,33 +140,29 @@ class TestSortWithinTier: def test_serverless_first(self): """Serverless warehouses should come before classic ones.""" classic = _make_warehouse("c1", "Classic WH", State.RUNNING) - serverless = _make_warehouse("s1", "Serverless WH", State.RUNNING, - enable_serverless_compute=True) + serverless = _make_warehouse("s1", "Serverless WH", State.RUNNING, enable_serverless_compute=True) result = _sort_within_tier([classic, serverless], current_user=None) assert result[0].id == "s1" assert result[1].id == "c1" def test_serverless_before_user_owned(self): """Serverless should be preferred over user-owned classic.""" - classic_owned = _make_warehouse("c1", "My WH", State.RUNNING, - creator_name="me@example.com") - serverless_other = _make_warehouse("s1", "Other WH", State.RUNNING, - creator_name="other@example.com", - enable_serverless_compute=True) - result = _sort_within_tier([classic_owned, serverless_other], - current_user="me@example.com") + classic_owned = _make_warehouse("c1", "My WH", State.RUNNING, creator_name="me@example.com") + serverless_other = _make_warehouse( + "s1", "Other WH", State.RUNNING, creator_name="other@example.com", enable_serverless_compute=True + ) + result = _sort_within_tier([classic_owned, serverless_other], current_user="me@example.com") assert result[0].id == "s1" def test_serverless_user_owned_first(self): """Among serverless, user-owned should come first.""" - serverless_other = _make_warehouse("s1", "Other Serverless", State.RUNNING, - creator_name="other@example.com", - enable_serverless_compute=True) - serverless_owned = _make_warehouse("s2", "My Serverless", State.RUNNING, - creator_name="me@example.com", - enable_serverless_compute=True) - result = _sort_within_tier([serverless_other, serverless_owned], - current_user="me@example.com") + serverless_other = _make_warehouse( + "s1", "Other Serverless", State.RUNNING, creator_name="other@example.com", enable_serverless_compute=True + ) + serverless_owned = _make_warehouse( + "s2", "My Serverless", State.RUNNING, creator_name="me@example.com", enable_serverless_compute=True + ) + result = _sort_within_tier([serverless_other, serverless_owned], current_user="me@example.com") assert result[0].id == "s2" assert result[1].id == "s1" @@ -177,8 +172,7 @@ def test_empty_list(self): def test_no_current_user(self): """Without a current user, only serverless preference applies.""" classic = _make_warehouse("c1", "Classic", State.RUNNING) - serverless = _make_warehouse("s1", "Serverless", State.RUNNING, - enable_serverless_compute=True) + serverless = _make_warehouse("s1", "Serverless", State.RUNNING, enable_serverless_compute=True) result = _sort_within_tier([classic, serverless], current_user=None) assert result[0].id == "s1" @@ -186,14 +180,12 @@ def test_no_current_user(self): class TestGetBestWarehouseServerless: """Tests for serverless preference in get_best_warehouse.""" - @mock.patch("databricks_tools_core.sql.warehouse.get_current_username", - return_value="me@example.com") + @mock.patch("databricks_tools_core.sql.warehouse.get_current_username", return_value="me@example.com") @mock.patch("databricks_tools_core.sql.warehouse.get_workspace_client") def test_prefers_serverless_within_running_shared(self, mock_client_fn, mock_user): """Among running shared warehouses, serverless should be picked.""" classic_shared = _make_warehouse("c1", "Shared WH", State.RUNNING) - serverless_shared = _make_warehouse("s1", "Shared Serverless", State.RUNNING, - enable_serverless_compute=True) + serverless_shared = _make_warehouse("s1", "Shared Serverless", State.RUNNING, enable_serverless_compute=True) mock_client = mock.Mock() mock_client.warehouses.list.return_value = [classic_shared, serverless_shared] mock_client_fn.return_value = mock_client @@ -201,14 +193,12 @@ def test_prefers_serverless_within_running_shared(self, mock_client_fn, mock_use result = get_best_warehouse() assert result == "s1" - @mock.patch("databricks_tools_core.sql.warehouse.get_current_username", - return_value="me@example.com") + @mock.patch("databricks_tools_core.sql.warehouse.get_current_username", return_value="me@example.com") @mock.patch("databricks_tools_core.sql.warehouse.get_workspace_client") def test_prefers_serverless_within_running_other(self, mock_client_fn, mock_user): """Among running non-shared warehouses, serverless should be picked.""" classic = _make_warehouse("c1", "My WH", State.RUNNING) - serverless = _make_warehouse("s1", "Fast WH", State.RUNNING, - enable_serverless_compute=True) + serverless = _make_warehouse("s1", "Fast WH", State.RUNNING, enable_serverless_compute=True) mock_client = mock.Mock() mock_client.warehouses.list.return_value = [classic, serverless] mock_client_fn.return_value = mock_client @@ -216,14 +206,12 @@ def test_prefers_serverless_within_running_other(self, mock_client_fn, mock_user result = get_best_warehouse() assert result == "s1" - @mock.patch("databricks_tools_core.sql.warehouse.get_current_username", - return_value="me@example.com") + @mock.patch("databricks_tools_core.sql.warehouse.get_current_username", return_value="me@example.com") @mock.patch("databricks_tools_core.sql.warehouse.get_workspace_client") def test_tier_order_preserved_over_serverless(self, mock_client_fn, mock_user): """A running shared classic should still beat a stopped serverless.""" running_shared_classic = _make_warehouse("c1", "Shared WH", State.RUNNING) - stopped_serverless = _make_warehouse("s1", "Fast WH", State.STOPPED, - enable_serverless_compute=True) + stopped_serverless = _make_warehouse("s1", "Fast WH", State.STOPPED, enable_serverless_compute=True) mock_client = mock.Mock() mock_client.warehouses.list.return_value = [stopped_serverless, running_shared_classic] mock_client_fn.return_value = mock_client