Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ jobs:

- name: Run ruff check
run: |
uvx ruff check \
uvx ruff@0.11.0 check \
--select=E,F,B,PIE \
--ignore=E401,E402,F401,F403,B017,B904,ANN,TCH \
--line-length=120 \
Expand All @@ -26,7 +26,7 @@ jobs:
- name: Check formatting
run: |
uvx ruff format --check \
uvx ruff@0.11.0 format --check \
--line-length=120 \
--target-version=py311 \
databricks-tools-core/ databricks-mcp-server/ .test/src/
Expand Down
8 changes: 2 additions & 6 deletions databricks-tools-core/databricks_tools_core/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,9 +160,7 @@ def get_workspace_client() -> WorkspaceClient:
# Cross-workspace: explicit token overrides env OAuth so tool operations
# target the caller-specified workspace instead of the app's own workspace
if force and host and token:
return tag_client(
WorkspaceClient(host=host, token=token, auth_type="pat", **product_kwargs)
)
return tag_client(WorkspaceClient(host=host, token=token, auth_type="pat", **product_kwargs))

# In Databricks Apps (OAuth credentials in env), explicitly use OAuth M2M.
# Setting auth_type="oauth-m2m" prevents the SDK from also reading
Expand All @@ -185,9 +183,7 @@ def get_workspace_client() -> WorkspaceClient:

# Development mode: use explicit token if provided
if host and token:
return tag_client(
WorkspaceClient(host=host, token=token, auth_type="pat", **product_kwargs)
)
return tag_client(WorkspaceClient(host=host, token=token, auth_type="pat", **product_kwargs))

if host:
return tag_client(WorkspaceClient(host=host, **product_kwargs))
Expand Down
54 changes: 21 additions & 33 deletions databricks-tools-core/tests/unit/test_sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,8 +121,7 @@ def test_executor_without_query_tags_omits_from_api(self, mock_get_client):
assert "query_tags" not in call_kwargs


def _make_warehouse(id, name, state, creator_name="other@example.com",
enable_serverless_compute=False):
def _make_warehouse(id, name, state, creator_name="other@example.com", enable_serverless_compute=False):
"""Helper to create a mock warehouse object."""
w = mock.Mock()
w.id = id
Expand All @@ -141,33 +140,29 @@ class TestSortWithinTier:
def test_serverless_first(self):
"""Serverless warehouses should come before classic ones."""
classic = _make_warehouse("c1", "Classic WH", State.RUNNING)
serverless = _make_warehouse("s1", "Serverless WH", State.RUNNING,
enable_serverless_compute=True)
serverless = _make_warehouse("s1", "Serverless WH", State.RUNNING, enable_serverless_compute=True)
result = _sort_within_tier([classic, serverless], current_user=None)
assert result[0].id == "s1"
assert result[1].id == "c1"

def test_serverless_before_user_owned(self):
"""Serverless should be preferred over user-owned classic."""
classic_owned = _make_warehouse("c1", "My WH", State.RUNNING,
creator_name="me@example.com")
serverless_other = _make_warehouse("s1", "Other WH", State.RUNNING,
creator_name="other@example.com",
enable_serverless_compute=True)
result = _sort_within_tier([classic_owned, serverless_other],
current_user="me@example.com")
classic_owned = _make_warehouse("c1", "My WH", State.RUNNING, creator_name="me@example.com")
serverless_other = _make_warehouse(
"s1", "Other WH", State.RUNNING, creator_name="other@example.com", enable_serverless_compute=True
)
result = _sort_within_tier([classic_owned, serverless_other], current_user="me@example.com")
assert result[0].id == "s1"

def test_serverless_user_owned_first(self):
"""Among serverless, user-owned should come first."""
serverless_other = _make_warehouse("s1", "Other Serverless", State.RUNNING,
creator_name="other@example.com",
enable_serverless_compute=True)
serverless_owned = _make_warehouse("s2", "My Serverless", State.RUNNING,
creator_name="me@example.com",
enable_serverless_compute=True)
result = _sort_within_tier([serverless_other, serverless_owned],
current_user="me@example.com")
serverless_other = _make_warehouse(
"s1", "Other Serverless", State.RUNNING, creator_name="other@example.com", enable_serverless_compute=True
)
serverless_owned = _make_warehouse(
"s2", "My Serverless", State.RUNNING, creator_name="me@example.com", enable_serverless_compute=True
)
result = _sort_within_tier([serverless_other, serverless_owned], current_user="me@example.com")
assert result[0].id == "s2"
assert result[1].id == "s1"

Expand All @@ -177,53 +172,46 @@ def test_empty_list(self):
def test_no_current_user(self):
"""Without a current user, only serverless preference applies."""
classic = _make_warehouse("c1", "Classic", State.RUNNING)
serverless = _make_warehouse("s1", "Serverless", State.RUNNING,
enable_serverless_compute=True)
serverless = _make_warehouse("s1", "Serverless", State.RUNNING, enable_serverless_compute=True)
result = _sort_within_tier([classic, serverless], current_user=None)
assert result[0].id == "s1"


class TestGetBestWarehouseServerless:
"""Tests for serverless preference in get_best_warehouse."""

@mock.patch("databricks_tools_core.sql.warehouse.get_current_username",
return_value="me@example.com")
@mock.patch("databricks_tools_core.sql.warehouse.get_current_username", return_value="me@example.com")
@mock.patch("databricks_tools_core.sql.warehouse.get_workspace_client")
def test_prefers_serverless_within_running_shared(self, mock_client_fn, mock_user):
"""Among running shared warehouses, serverless should be picked."""
classic_shared = _make_warehouse("c1", "Shared WH", State.RUNNING)
serverless_shared = _make_warehouse("s1", "Shared Serverless", State.RUNNING,
enable_serverless_compute=True)
serverless_shared = _make_warehouse("s1", "Shared Serverless", State.RUNNING, enable_serverless_compute=True)
mock_client = mock.Mock()
mock_client.warehouses.list.return_value = [classic_shared, serverless_shared]
mock_client_fn.return_value = mock_client

result = get_best_warehouse()
assert result == "s1"

@mock.patch("databricks_tools_core.sql.warehouse.get_current_username",
return_value="me@example.com")
@mock.patch("databricks_tools_core.sql.warehouse.get_current_username", return_value="me@example.com")
@mock.patch("databricks_tools_core.sql.warehouse.get_workspace_client")
def test_prefers_serverless_within_running_other(self, mock_client_fn, mock_user):
"""Among running non-shared warehouses, serverless should be picked."""
classic = _make_warehouse("c1", "My WH", State.RUNNING)
serverless = _make_warehouse("s1", "Fast WH", State.RUNNING,
enable_serverless_compute=True)
serverless = _make_warehouse("s1", "Fast WH", State.RUNNING, enable_serverless_compute=True)
mock_client = mock.Mock()
mock_client.warehouses.list.return_value = [classic, serverless]
mock_client_fn.return_value = mock_client

result = get_best_warehouse()
assert result == "s1"

@mock.patch("databricks_tools_core.sql.warehouse.get_current_username",
return_value="me@example.com")
@mock.patch("databricks_tools_core.sql.warehouse.get_current_username", return_value="me@example.com")
@mock.patch("databricks_tools_core.sql.warehouse.get_workspace_client")
def test_tier_order_preserved_over_serverless(self, mock_client_fn, mock_user):
"""A running shared classic should still beat a stopped serverless."""
running_shared_classic = _make_warehouse("c1", "Shared WH", State.RUNNING)
stopped_serverless = _make_warehouse("s1", "Fast WH", State.STOPPED,
enable_serverless_compute=True)
stopped_serverless = _make_warehouse("s1", "Fast WH", State.STOPPED, enable_serverless_compute=True)
mock_client = mock.Mock()
mock_client.warehouses.list.return_value = [stopped_serverless, running_shared_classic]
mock_client_fn.return_value = mock_client
Expand Down
Loading