11package mcp
22
33import (
4+ "context"
45 "encoding/json"
56 "errors"
67 "fmt"
8+ "os"
9+ "os/signal"
710 "strings"
11+ "syscall"
12+ "time"
813
914 "github.com/databricks/cli/cmd/root"
1015 "github.com/databricks/cli/experimental/aitools/lib/middlewares"
1116 "github.com/databricks/cli/experimental/aitools/lib/session"
1217 "github.com/databricks/cli/libs/cmdctx"
1318 "github.com/databricks/cli/libs/cmdio"
19+ "github.com/databricks/cli/libs/log"
1420 "github.com/databricks/databricks-sdk-go/service/sql"
1521 "github.com/spf13/cobra"
1622)
1723
// Timing knobs for asynchronous statement execution.
const (
	// pollIntervalInitial is the delay before the first status poll; later
	// polls wait progressively longer.
	pollIntervalInitial = time.Second

	// pollIntervalMax is the upper bound the delay between status polls is
	// allowed to grow to.
	pollIntervalMax = 5 * time.Second

	// cancelTimeout bounds how long the best-effort server-side cancellation
	// request may take.
	cancelTimeout = 10 * time.Second
)
34+
1835func newQueryCmd () * cobra.Command {
36+ var warehouseID string
37+
1938 cmd := & cobra.Command {
2039 Use : "query SQL" ,
2140 Short : "Execute SQL against a Databricks warehouse" ,
2241 Long : `Execute a SQL statement against a Databricks SQL warehouse and return results.
2342
24- The command auto-detects an available warehouse unless DATABRICKS_WAREHOUSE_ID is set.
43+ The command auto-detects an available warehouse unless --warehouse is set
44+ or the DATABRICKS_WAREHOUSE_ID environment variable is configured.
2545
2646Output includes the query results as JSON and row count.` ,
27- Example : ` databricks experimental aitools query "SELECT * FROM samples.nyctaxi.trips LIMIT 5"` ,
47+ Example : ` databricks experimental aitools tools query "SELECT * FROM samples.nyctaxi.trips LIMIT 5"
48+ databricks experimental aitools tools query --warehouse abc123 "SELECT 1"` ,
2849 Args : cobra .ExactArgs (1 ),
2950 PreRunE : root .MustWorkspaceClient ,
3051 RunE : func (cmd * cobra.Command , args []string ) error {
@@ -36,31 +57,14 @@ Output includes the query results as JSON and row count.`,
3657 return errors .New ("SQL statement is required" )
3758 }
3859
39- // set up session with client for middleware compatibility
40- sess := session .NewSession ()
41- sess .Set (middlewares .DatabricksClientKey , w )
42- ctx = session .WithSession (ctx , sess )
43-
44- warehouseID , err := middlewares .GetWarehouseID (ctx , true )
60+ wID , err := resolveWarehouseID (ctx , w , warehouseID )
4561 if err != nil {
4662 return err
4763 }
4864
49- resp , err := w .StatementExecution .ExecuteAndWait (ctx , sql.ExecuteStatementRequest {
50- WarehouseId : warehouseID ,
51- Statement : sqlStatement ,
52- WaitTimeout : "50s" ,
53- })
65+ resp , err := executeAndPoll (ctx , w .StatementExecution , wID , sqlStatement )
5466 if err != nil {
55- return fmt .Errorf ("execute statement: %w" , err )
56- }
57-
58- if resp .Status != nil && resp .Status .State == sql .StatementStateFailed {
59- errMsg := "query failed"
60- if resp .Status .Error != nil {
61- errMsg = resp .Status .Error .Message
62- }
63- return errors .New (errMsg )
67+ return err
6468 }
6569
6670 output , err := formatQueryResult (resp )
@@ -73,13 +77,178 @@ Output includes the query results as JSON and row count.`,
7377 },
7478 }
7579
80+ cmd .Flags ().StringVarP (& warehouseID , "warehouse" , "w" , "" , "SQL warehouse ID to use for execution" )
81+
7682 return cmd
7783}
7884
85+ // resolveWarehouseID returns the warehouse ID to use for query execution.
86+ // Priority: explicit flag > middleware auto-detection (env var > server default > first running).
87+ func resolveWarehouseID (ctx context.Context , w any , flagValue string ) (string , error ) {
88+ if flagValue != "" {
89+ return flagValue , nil
90+ }
91+
92+ sess := session .NewSession ()
93+ sess .Set (middlewares .DatabricksClientKey , w )
94+ ctx = session .WithSession (ctx , sess )
95+
96+ return middlewares .GetWarehouseID (ctx , true )
97+ }
98+
99+ // executeAndPoll submits a SQL statement asynchronously and polls until completion.
100+ // It shows a spinner in interactive mode and supports Ctrl+C cancellation.
101+ func executeAndPoll (ctx context.Context , api sql.StatementExecutionInterface , warehouseID , statement string ) (* sql.StatementResponse , error ) {
102+ // Submit asynchronously to get the statement ID immediately for cancellation.
103+ resp , err := api .ExecuteStatement (ctx , sql.ExecuteStatementRequest {
104+ WarehouseId : warehouseID ,
105+ Statement : statement ,
106+ WaitTimeout : "0s" ,
107+ })
108+ if err != nil {
109+ return nil , fmt .Errorf ("execute statement: %w" , err )
110+ }
111+
112+ statementID := resp .StatementId
113+
114+ // Check if it completed immediately.
115+ if isTerminalState (resp .Status ) {
116+ return resp , checkFailedState (resp .Status )
117+ }
118+
119+ // Set up Ctrl+C: signal cancels the poll context, cleanup is unified below.
120+ pollCtx , pollCancel := context .WithCancel (ctx )
121+ defer pollCancel ()
122+
123+ sigCh := make (chan os.Signal , 1 )
124+ signal .Notify (sigCh , os .Interrupt , syscall .SIGTERM )
125+ defer signal .Stop (sigCh )
126+
127+ go func () {
128+ select {
129+ case <- sigCh :
130+ log .Infof (ctx , "Received interrupt, cancelling query %s" , statementID )
131+ pollCancel ()
132+ case <- pollCtx .Done ():
133+ }
134+ }()
135+
136+ // cancelStatement performs best-effort server-side cancellation.
137+ // Called on any poll exit due to context cancellation (signal or parent).
138+ cancelStatement := func () {
139+ cancelCtx , cancel := context .WithTimeout (context .Background (), cancelTimeout )
140+ defer cancel ()
141+ if err := api .CancelExecution (cancelCtx , sql.CancelExecutionRequest {
142+ StatementId : statementID ,
143+ }); err != nil {
144+ log .Warnf (ctx , "Failed to cancel statement %s: %v" , statementID , err )
145+ }
146+ }
147+
148+ // Spinner for interactive feedback, updated every second via ticker.
149+ sp := cmdio .NewSpinner (pollCtx )
150+ defer sp .Close ()
151+ start := time .Now ()
152+ sp .Update ("Executing query..." )
153+
154+ ticker := time .NewTicker (1 * time .Second )
155+ defer ticker .Stop ()
156+ go func () {
157+ for {
158+ select {
159+ case <- pollCtx .Done ():
160+ return
161+ case <- ticker .C :
162+ elapsed := time .Since (start ).Truncate (time .Second )
163+ sp .Update (fmt .Sprintf ("Executing query... (%s elapsed)" , elapsed ))
164+ }
165+ }
166+ }()
167+
168+ // Poll with additive backoff: 1s, 2s, 3s, 4s, 5s (capped).
169+ interval := pollIntervalInitial
170+ for {
171+ select {
172+ case <- pollCtx .Done ():
173+ cancelStatement ()
174+ cmdio .LogString (ctx , "Query cancelled." )
175+ return nil , root .ErrAlreadyPrinted
176+ case <- time .After (interval ):
177+ }
178+
179+ log .Debugf (ctx , "Polling statement %s: %s elapsed" , statementID , time .Since (start ).Truncate (time .Second ))
180+
181+ pollResp , err := api .GetStatementByStatementId (pollCtx , statementID )
182+ if err != nil {
183+ if pollCtx .Err () != nil {
184+ cancelStatement ()
185+ cmdio .LogString (ctx , "Query cancelled." )
186+ return nil , root .ErrAlreadyPrinted
187+ }
188+ return nil , fmt .Errorf ("poll statement status: %w" , err )
189+ }
190+
191+ if isTerminalState (pollResp .Status ) {
192+ sp .Close ()
193+ if err := checkFailedState (pollResp .Status ); err != nil {
194+ return nil , err
195+ }
196+ return & sql.StatementResponse {
197+ StatementId : pollResp .StatementId ,
198+ Status : pollResp .Status ,
199+ Manifest : pollResp .Manifest ,
200+ Result : pollResp .Result ,
201+ }, nil
202+ }
203+
204+ interval = min (interval + time .Second , pollIntervalMax )
205+ }
206+ }
207+
208+ // isTerminalState returns true if the statement has reached a final state.
209+ func isTerminalState (status * sql.StatementStatus ) bool {
210+ if status == nil {
211+ return false
212+ }
213+ switch status .State {
214+ case sql .StatementStateSucceeded , sql .StatementStateFailed ,
215+ sql .StatementStateCanceled , sql .StatementStateClosed :
216+ return true
217+ case sql .StatementStatePending , sql .StatementStateRunning :
218+ return false
219+ }
220+ return false
221+ }
222+
223+ // checkFailedState returns an error if the statement is in a non-success terminal state.
224+ func checkFailedState (status * sql.StatementStatus ) error {
225+ if status == nil {
226+ return nil
227+ }
228+ switch status .State {
229+ case sql .StatementStateFailed :
230+ msg := "query failed"
231+ if status .Error != nil {
232+ msg = fmt .Sprintf ("query failed: %s %s" , status .Error .ErrorCode , status .Error .Message )
233+ if strings .Contains (status .Error .Message , "UNRESOLVED_MAP_KEY" ) {
234+ msg += "\n \n Hint: your shell may have stripped quotes from the SQL string. " +
235+ "Use single quotes for map keys (e.g. info['key']) or pass the query via --file."
236+ }
237+ }
238+ return errors .New (msg )
239+ case sql .StatementStateCanceled :
240+ return errors .New ("query was cancelled" )
241+ case sql .StatementStateClosed :
242+ return errors .New ("query was closed before results could be fetched" )
243+ case sql .StatementStatePending , sql .StatementStateRunning , sql .StatementStateSucceeded :
244+ return nil
245+ }
246+ return nil
247+ }
248+
79249// cleanSQL removes surrounding quotes, empty lines, and SQL comments.
80250func cleanSQL (s string ) string {
81251 s = strings .TrimSpace (s )
82- // remove surrounding quotes if present
83252 if (strings .HasPrefix (s , `"` ) && strings .HasSuffix (s , `"` )) ||
84253 (strings .HasPrefix (s , `'` ) && strings .HasSuffix (s , `'` )) {
85254 s = s [1 : len (s )- 1 ]
@@ -88,12 +257,12 @@ func cleanSQL(s string) string {
88257 var lines []string
89258 for _ , line := range strings .Split (s , "\n " ) {
90259 line = strings .TrimSpace (line )
91- // skip empty lines and single-line comments
92260 if line == "" || strings .HasPrefix (line , "--" ) {
93261 continue
94262 }
95263 lines = append (lines , line )
96264 }
265+
97266 return strings .Join (lines , "\n " )
98267}
99268
@@ -105,15 +274,13 @@ func formatQueryResult(resp *sql.StatementResponse) (string, error) {
105274 return sb .String (), nil
106275 }
107276
108- // get column names
109277 var columns []string
110278 if resp .Manifest .Schema != nil {
111279 for _ , col := range resp .Manifest .Schema .Columns {
112280 columns = append (columns , col .Name )
113281 }
114282 }
115283
116- // format as JSON array for consistency with Neon API
117284 var rows []map [string ]any
118285 if resp .Result .DataArray != nil {
119286 for _ , row := range resp .Result .DataArray {
0 commit comments