diff --git a/NEXT_CHANGELOG.md b/NEXT_CHANGELOG.md index a6cccd2a5cb..9d0c9c2f634 100644 --- a/NEXT_CHANGELOG.md +++ b/NEXT_CHANGELOG.md @@ -5,6 +5,7 @@ ### Notable Changes ### CLI +* `ssh connect` now supports specifying a serverless usage policy with `--usage-policy-id` ([#5781](https://github.com/databricks/cli/pull/5781)). * `ssh connect` now accepts a `--base-environment` flag to run a serverless session on a custom base environment. It takes an `env.yaml` path, a `workspace-base-environments/...` resource ID, or a base environment display name, and is rejected together with `--environment-version` or `--cluster` ([#5706](https://github.com/databricks/cli/pull/5706)). * `databricks aitools install` is now plugin-first: it installs the Databricks plugin through each agent's own CLI (Claude Code, Codex, GitHub Copilot) instead of copying raw skill files. Agents without a plugin (OpenCode, Antigravity) still get skill files, and Cursor prints the `/add-plugin databricks` step. Use `--skills-only` to force raw skill files for every agent, or `--path ` to write skills to a directory ([#5738](https://github.com/databricks/cli/pull/5738)). diff --git a/experimental/ssh/cmd/connect.go b/experimental/ssh/cmd/connect.go index 972ccf4a81a..2c50b871902 100644 --- a/experimental/ssh/cmd/connect.go +++ b/experimental/ssh/cmd/connect.go @@ -41,6 +41,7 @@ Connect to a dedicated cluster: var environmentVersion int var baseEnvironment string var autoApprove bool + var usagePolicyID string cmd.Flags().StringVar(&clusterID, "cluster", "", "Databricks dedicated cluster ID") cmd.Flags().DurationVar(&shutdownDelay, "shutdown-delay", defaultShutdownDelay, "Delay before shutting down the server after the last client disconnects") @@ -50,6 +51,7 @@ Connect to a dedicated cluster: cmd.Flags().StringVar(&connectionName, "name", "", "Connection name to reuse across sessions (serverless only)") cmd.Flags().StringVar(&accelerator, "accelerator", "", "Serverless GPU accelerator type (GPU_1xA10 or GPU_8xH100)") cmd.Flags().StringVar(&ide, "ide", "", "Open remote IDE window (vscode or cursor)") + cmd.Flags().StringVar(&usagePolicyID, "usage-policy-id", "", "Usage policy ID for the serverless SSH server job (serverless only)") cmd.Flags().BoolVar(&proxyMode, "proxy", false, "ProxyCommand mode") cmd.Flags().MarkHidden("proxy") @@ -130,6 +132,7 @@ Connect to a dedicated cluster: BaseEnvironment: baseEnvironment, AdditionalArgs: args, AutoApprove: autoApprove, + UsagePolicyID: usagePolicyID, } if err := opts.Validate(); err != nil { return err diff --git a/experimental/ssh/cmd/server.go b/experimental/ssh/cmd/server.go index 47e16cdc649..21c651b2365 100644 --- a/experimental/ssh/cmd/server.go +++ b/experimental/ssh/cmd/server.go @@ -29,6 +29,7 @@ and proxies them to local SSH daemon processes.`, var secretScopeName string var authorizedKeySecretName string var serverless bool + var usagePolicyID string cmd.Flags().StringVar(&clusterID, "cluster", "", "Databricks cluster ID") cmd.MarkFlagRequired("cluster") @@ -43,6 +44,7 @@ and proxies them to local SSH daemon processes.`, cmd.Flags().DurationVar(&shutdownDelay, "shutdown-delay", defaultShutdownDelay, "Delay before shutting down after no pings from clients") cmd.Flags().StringVar(&version, "version", "", "Client version of the Databricks CLI") cmd.Flags().BoolVar(&serverless, "serverless", false, "Enable serverless mode for Jupyter initialization") + cmd.Flags().StringVar(&usagePolicyID, "usage-policy-id", "", "Usage policy ID the job was submitted with") cmd.PreRunE = func(cmd *cobra.Command, args []string) error { // The server can be executed under a directory with an invalid bundle configuration. @@ -71,6 +73,7 @@ and proxies them to local SSH daemon processes.`, DefaultPort: defaultServerPort, PortRange: serverPortRange, Serverless: serverless, + UsagePolicyID: usagePolicyID, } return server.Run(ctx, wsc, opts) } diff --git a/experimental/ssh/internal/client/client.go b/experimental/ssh/internal/client/client.go index 9fc20aa56df..cb9919cfa0f 100644 --- a/experimental/ssh/internal/client/client.go +++ b/experimental/ssh/internal/client/client.go @@ -119,6 +119,8 @@ type ClientOptions struct { BaseEnvironment string // If true, skip confirmation prompts for IDE extension install and IDE settings updates. AutoApprove bool + // Id of the usage policy to use for the serverless SSH server job. Serverless only. + UsagePolicyID string } func (o *ClientOptions) Validate() error { @@ -128,6 +130,9 @@ func (o *ClientOptions) Validate() error { if o.Accelerator != "" && o.ConnectionName == "" { return errors.New("--accelerator flag can only be used with serverless compute (--name flag)") } + if o.UsagePolicyID != "" && o.ClusterID != "" { + return errors.New("--usage-policy-id flag can only be used with serverless compute (--name flag)") + } if o.Accelerator != "" && o.Accelerator != "GPU_1xA10" && o.Accelerator != "GPU_8xH100" { return fmt.Errorf("invalid accelerator value: %q, expected %q or %q", o.Accelerator, "GPU_1xA10", "GPU_8xH100") } @@ -214,6 +219,9 @@ func (o *ClientOptions) ToProxyCommand() (string, error) { if o.Accelerator != "" { proxyCommand += " --accelerator=" + o.Accelerator } + if o.UsagePolicyID != "" { + proxyCommand += " --usage-policy-id=" + o.UsagePolicyID + } } else { proxyCommand = fmt.Sprintf("%q ssh connect --proxy --cluster=%s --auto-start-cluster=%t --shutdown-delay=%s", executablePath, o.ClusterID, o.AutoStartCluster, o.ShutdownDelay.String()) @@ -463,14 +471,25 @@ func ensureSSHConfigEntry(ctx context.Context, configPath, hostName, userName, k return nil } +// serverMetadata describes a running SSH server, combining the persisted workspace +// metadata with the user name validated live via Driver Proxy. +type serverMetadata struct { + Port int + UserName string + // ClusterID required for Driver Proxy connections. For serverless it comes from the persisted metadata. + ClusterID string + // UsagePolicyID the server was started with, used to decide whether a running server can be reused. + UsagePolicyID string +} + // getServerMetadata retrieves the server metadata from the workspace and validates it via Driver Proxy. // sessionID is the unique identifier for the session (cluster ID for dedicated clusters, connection name for serverless). // For dedicated clusters, clusterID should be the same as sessionID. // For serverless, clusterID is read from the workspace metadata. -func getServerMetadata(ctx context.Context, client *databricks.WorkspaceClient, sessionID, clusterID, version, liteswap string) (int, string, string, error) { +func getServerMetadata(ctx context.Context, client *databricks.WorkspaceClient, sessionID, clusterID, version, liteswap string) (serverMetadata, error) { wsMetadata, err := sshWorkspace.GetWorkspaceMetadata(ctx, client, version, sessionID) if err != nil { - return 0, "", "", errors.Join(errServerMetadata, err) + return serverMetadata{}, errors.Join(errServerMetadata, err) } log.Debugf(ctx, "Workspace metadata: %+v", wsMetadata) @@ -481,33 +500,38 @@ func getServerMetadata(ctx context.Context, client *databricks.WorkspaceClient, } if effectiveClusterID == "" { - return 0, "", "", errors.Join(errServerMetadata, errors.New("cluster ID not available in metadata")) + return serverMetadata{}, errors.Join(errServerMetadata, errors.New("cluster ID not available in metadata")) } req, err := newDriverProxyRequest(ctx, client, effectiveClusterID, wsMetadata.Port, "metadata", liteswap) if err != nil { - return 0, "", "", err + return serverMetadata{}, err } log.Debugf(ctx, "Metadata URL: %s", req.URL) httpClient := &http.Client{Transport: client.Config.HTTPTransport} resp, err := httpClient.Do(req) if err != nil { - return 0, "", "", err + return serverMetadata{}, err } defer resp.Body.Close() bodyBytes, err := io.ReadAll(resp.Body) if err != nil { - return 0, "", "", err + return serverMetadata{}, err } log.Debugf(ctx, "Metadata response: %s", string(bodyBytes)) log.Debugf(ctx, "Metadata response status code: %d", resp.StatusCode) if resp.StatusCode != http.StatusOK { - return 0, "", "", errors.Join(errServerMetadata, fmt.Errorf("server is not ok, status code %d", resp.StatusCode)) + return serverMetadata{}, errors.Join(errServerMetadata, fmt.Errorf("server is not ok, status code %d", resp.StatusCode)) } - return wsMetadata.Port, string(bodyBytes), effectiveClusterID, nil + return serverMetadata{ + Port: wsMetadata.Port, + UserName: string(bodyBytes), + ClusterID: effectiveClusterID, + UsagePolicyID: wsMetadata.UsagePolicyID, + }, nil } // newDriverProxyRequest builds an authenticated GET request to one of the SSH server's @@ -559,36 +583,10 @@ func fetchServerErrorLogs(ctx context.Context, client *databricks.WorkspaceClien return strings.TrimSpace(string(body)) } -// submitSSHTunnelJob submits the bootstrap job and waits for the SSH server task to start. -// It returns the job run ID (when known) so callers can fetch and surface the run's error -// details if the server never comes up. -func submitSSHTunnelJob(ctx context.Context, client *databricks.WorkspaceClient, version, secretScopeName string, opts ClientOptions) (int64, error) { +// Assemble the SubmitRun request that bootstraps the SSH server. +// Extracted from submitSSHTunnelJob so this logic can be unit tested. +func buildSSHServerSubmitRun(version, secretScopeName, jobNotebookPath, baseEnvironment string, opts ClientOptions) jobs.SubmitRun { sessionID := opts.SessionIdentifier() - contentDir, err := sshWorkspace.GetWorkspaceContentDir(ctx, client, version, sessionID) - if err != nil { - return 0, fmt.Errorf("failed to get workspace content directory: %w", err) - } - - err = client.Workspace.MkdirsByPath(ctx, contentDir) - if err != nil { - return 0, fmt.Errorf("failed to create directory in the remote workspace: %w", err) - } - - sshTunnelJobName := "ssh-server-bootstrap-" + sessionID - jobNotebookPath := filepath.ToSlash(filepath.Join(contentDir, "ssh-server-bootstrap")) - notebookContent := "# Databricks notebook source\n" + sshServerBootstrapScript - encodedContent := base64.StdEncoding.EncodeToString([]byte(notebookContent)) - - err = client.Workspace.Import(ctx, workspace.Import{ - Path: jobNotebookPath, - Format: workspace.ImportFormatSource, - Content: encodedContent, - Language: workspace.LanguagePython, - Overwrite: true, - }) - if err != nil { - return 0, fmt.Errorf("failed to create ssh-tunnel notebook: %w", err) - } baseParams := map[string]string{ "version": version, @@ -598,10 +596,11 @@ func submitSSHTunnelJob(ctx context.Context, client *databricks.WorkspaceClient, "maxClients": strconv.Itoa(opts.MaxClients), "sessionId": sessionID, "serverless": strconv.FormatBool(opts.IsServerlessMode()), + // Recorded in the server's metadata.json so reconnects can tell which usage policy + // the running server was started under. + "usagePolicyId": opts.UsagePolicyID, } - log.Infof(ctx, "Submitting a job to start the ssh server...") - task := jobs.SubmitTask{ TaskKey: sshServerTaskKey, NotebookTask: &jobs.NotebookTask{ @@ -614,7 +613,6 @@ func submitSSHTunnelJob(ctx context.Context, client *databricks.WorkspaceClient, if opts.IsServerlessMode() { task.EnvironmentKey = serverlessEnvironmentKey if opts.Accelerator != "" { - log.Infof(ctx, "Using accelerator: %s", opts.Accelerator) task.Compute = &jobs.Compute{ HardwareAccelerator: compute.HardwareAcceleratorType(opts.Accelerator), } @@ -624,20 +622,17 @@ func submitSSHTunnelJob(ctx context.Context, client *databricks.WorkspaceClient, } submitRequest := jobs.SubmitRun{ - RunName: sshTunnelJobName, + RunName: "ssh-server-bootstrap-" + sessionID, TimeoutSeconds: int(opts.ServerTimeout.Seconds()), Tasks: []jobs.SubmitTask{task}, + BudgetPolicyId: opts.UsagePolicyID, } if opts.IsServerlessMode() { // base_environment and environment_version are mutually exclusive: a custom // base environment carries its own version, so we don't also set one. var spec compute.Environment - if opts.BaseEnvironment != "" { - baseEnvironment, err := resolveBaseEnvironment(ctx, client, opts.BaseEnvironment) - if err != nil { - return 0, err - } + if baseEnvironment != "" { spec.BaseEnvironment = baseEnvironment } else { spec.EnvironmentVersion = strconv.Itoa(max(opts.EnvironmentVersion, minEnvironmentVersion)) @@ -650,6 +645,54 @@ func submitSSHTunnelJob(ctx context.Context, client *databricks.WorkspaceClient, } } + return submitRequest +} + +// submitSSHTunnelJob submits the bootstrap job and waits for the SSH server task to start. +// It returns the job run ID (when known) so callers can fetch and surface the run's error +// details if the server never comes up. +func submitSSHTunnelJob(ctx context.Context, client *databricks.WorkspaceClient, version, secretScopeName string, opts ClientOptions) (int64, error) { + sessionID := opts.SessionIdentifier() + contentDir, err := sshWorkspace.GetWorkspaceContentDir(ctx, client, version, sessionID) + if err != nil { + return 0, fmt.Errorf("failed to get workspace content directory: %w", err) + } + + err = client.Workspace.MkdirsByPath(ctx, contentDir) + if err != nil { + return 0, fmt.Errorf("failed to create directory in the remote workspace: %w", err) + } + + jobNotebookPath := filepath.ToSlash(filepath.Join(contentDir, "ssh-server-bootstrap")) + notebookContent := "# Databricks notebook source\n" + sshServerBootstrapScript + encodedContent := base64.StdEncoding.EncodeToString([]byte(notebookContent)) + + err = client.Workspace.Import(ctx, workspace.Import{ + Path: jobNotebookPath, + Format: workspace.ImportFormatSource, + Content: encodedContent, + Language: workspace.LanguagePython, + Overwrite: true, + }) + if err != nil { + return 0, fmt.Errorf("failed to create ssh-tunnel notebook: %w", err) + } + + log.Infof(ctx, "Submitting a job to start the ssh server...") + if opts.IsServerlessMode() && opts.Accelerator != "" { + log.Infof(ctx, "Using accelerator: %s", opts.Accelerator) + } + + var baseEnvironment string + if opts.IsServerlessMode() && opts.BaseEnvironment != "" { + baseEnvironment, err = resolveBaseEnvironment(ctx, client, opts.BaseEnvironment) + if err != nil { + return 0, err + } + } + + submitRequest := buildSSHServerSubmitRun(version, secretScopeName, jobNotebookPath, baseEnvironment, opts) + waiter, err := client.Jobs.Submit(ctx, submitRequest) if err != nil { return 0, fmt.Errorf("failed to submit job: %w", err) @@ -1046,18 +1089,31 @@ func hostKeyChangedHint(stderr, hostName, knownHostsFile string) string { "Remove the stale entry and reconnect:\n " + cmd } +func usagePolicyMatches(storedPolicy, requestedPolicy string) bool { + return requestedPolicy == "" || storedPolicy == requestedPolicy +} + func ensureSSHServerIsRunning(ctx context.Context, client *databricks.WorkspaceClient, version, secretScopeName string, opts ClientOptions) (string, int, string, error) { sessionID := opts.SessionIdentifier() // For dedicated clusters, use clusterID; for serverless, it will be read from metadata clusterID := opts.ClusterID - serverPort, userName, effectiveClusterID, err := getServerMetadata(ctx, client, sessionID, clusterID, version, opts.Liteswap) - if errors.Is(err, errServerMetadata) { + meta, err := getServerMetadata(ctx, client, sessionID, clusterID, version, opts.Liteswap) + if err != nil && !errors.Is(err, errServerMetadata) { + return "", 0, "", err + } + + // Start a new server when none is running, or when the running one was started under a + // different usage policy. A job's usage policy is fixed at submission, so we can't retarget + // the existing server; the new server overwrites metadata.json and the old one idles out via + // shutdownDelay. + needNewServer := err != nil || !usagePolicyMatches(meta.UsagePolicyID, opts.UsagePolicyID) + if needNewServer { cmdio.LogString(ctx, "Starting SSH server...") - runID, err := submitSSHTunnelJob(ctx, client, version, secretScopeName, opts) - if err != nil { - return "", 0, "", fmt.Errorf("failed to submit and start ssh server job: %w", err) + runID, submitErr := submitSSHTunnelJob(ctx, client, version, secretScopeName, opts) + if submitErr != nil { + return "", 0, "", fmt.Errorf("failed to submit and start ssh server job: %w", submitErr) } sp := cmdio.NewSpinner(ctx, cmdio.WithElapsedTime()) @@ -1068,7 +1124,13 @@ func ensureSSHServerIsRunning(ctx context.Context, client *databricks.WorkspaceC if ctx.Err() != nil { return "", 0, "", ctx.Err() } - serverPort, userName, effectiveClusterID, err = getServerMetadata(ctx, client, sessionID, clusterID, version, opts.Liteswap) + meta, err = getServerMetadata(ctx, client, sessionID, clusterID, version, opts.Liteswap) + // Accept only once metadata reflects the requested usage policy, so we don't latch + // onto a server a previous connection started under a different policy before our new + // server has overwritten metadata.json. + if err == nil && !usagePolicyMatches(meta.UsagePolicyID, opts.UsagePolicyID) { + err = fmt.Errorf("found a running SSH server with usage policy %q, waiting for the one with %q", meta.UsagePolicyID, opts.UsagePolicyID) + } if err == nil { cmdio.LogString(ctx, "Health check successful, starting ssh WebSocket connection...") break @@ -1085,11 +1147,9 @@ func ensureSSHServerIsRunning(ctx context.Context, client *databricks.WorkspaceC return "", 0, "", fmt.Errorf("failed to start the ssh server: %w\n%s", err, describeRunFailure(ctx, client, runID)) } } - } else if err != nil { - return "", 0, "", err } - return userName, serverPort, effectiveClusterID, nil + return meta.UserName, meta.Port, meta.ClusterID, nil } func logSshTunnelEvent(ctx context.Context, opts ClientOptions, isSuccess, isReconnect bool, serverStartTimeMs int64) { diff --git a/experimental/ssh/internal/client/client_test.go b/experimental/ssh/internal/client/client_test.go index 48fb6f0c1f4..c146f6bd813 100644 --- a/experimental/ssh/internal/client/client_test.go +++ b/experimental/ssh/internal/client/client_test.go @@ -111,6 +111,15 @@ func TestValidate(t *testing.T) { { name: "base environment with serverless GPU accelerator", opts: client.ClientOptions{ConnectionName: "my-conn", Accelerator: "GPU_1xA10", BaseEnvironment: "my-gpu-env"}, + }, + { + name: "usage policy with cluster ID", + opts: client.ClientOptions{ClusterID: "abc-123", UsagePolicyID: "pol-1"}, + wantErr: "--usage-policy-id flag can only be used with serverless compute (--name flag)", + }, + { + name: "usage policy with connection name", + opts: client.ClientOptions{ConnectionName: "my-conn", UsagePolicyID: "pol-1"}, }, } @@ -233,6 +242,11 @@ func TestToProxyCommand(t *testing.T) { opts: client.ClientOptions{ConnectionName: "my-conn", Accelerator: "GPU_1xA10", ShutdownDelay: 2 * time.Minute}, want: quoted + " ssh connect --proxy --name=my-conn --shutdown-delay=2m0s --accelerator=GPU_1xA10", }, + { + name: "serverless with usage policy", + opts: client.ClientOptions{ConnectionName: "my-conn", UsagePolicyID: "pol-1", ShutdownDelay: 2 * time.Minute}, + want: quoted + " ssh connect --proxy --name=my-conn --shutdown-delay=2m0s --usage-policy-id=pol-1", + }, { name: "with metadata", opts: client.ClientOptions{ClusterID: "abc-123", ServerMetadata: "user,2222,abc-123"}, diff --git a/experimental/ssh/internal/client/policy_internal_test.go b/experimental/ssh/internal/client/policy_internal_test.go new file mode 100644 index 00000000000..f501f0ba6e6 --- /dev/null +++ b/experimental/ssh/internal/client/policy_internal_test.go @@ -0,0 +1,26 @@ +package client + +import "testing" + +func TestUsagePolicyMatches(t *testing.T) { + tests := []struct { + name string + stored string + requested string + want bool + }{ + {name: "empty request matches any server", stored: "pol-1", requested: "", want: true}, + {name: "empty request matches server without policy", stored: "", requested: "", want: true}, + {name: "equal policies match", stored: "pol-1", requested: "pol-1", want: true}, + {name: "different policies do not match", stored: "pol-1", requested: "pol-2", want: false}, + {name: "request against server without policy does not match", stored: "", requested: "pol-1", want: false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := usagePolicyMatches(tt.stored, tt.requested); got != tt.want { + t.Errorf("usagePolicyMatches(%q, %q) = %v, want %v", tt.stored, tt.requested, got, tt.want) + } + }) + } +} diff --git a/experimental/ssh/internal/client/ssh-server-bootstrap.py b/experimental/ssh/internal/client/ssh-server-bootstrap.py index 87d0d2756fe..28a20f73688 100644 --- a/experimental/ssh/internal/client/ssh-server-bootstrap.py +++ b/experimental/ssh/internal/client/ssh-server-bootstrap.py @@ -26,6 +26,7 @@ dbutils.widgets.text("shutdownDelay", "10m") dbutils.widgets.text("sessionId", "") dbutils.widgets.text("serverless", "false") +dbutils.widgets.text("usagePolicyId", "") def cleanup(): @@ -126,6 +127,7 @@ def run_ssh_server(): if not session_id: raise RuntimeError("Session ID is required. Please provide it using the 'sessionId' widget.") serverless = dbutils.widgets.get("serverless") + usage_policy_id = dbutils.widgets.get("usagePolicyId") # Mark this process's WSFS command origin so workspace-file activity from the # remote SSH session is attributable @@ -172,6 +174,10 @@ def run_ssh_server(): "--log-file=stdout", ] + # Recorded in the server's metadata.json so reconnects can match the usage policy. + if usage_policy_id: + server_args.append(f"--usage-policy-id={usage_policy_id}") + # Tee the server output instead of inheriting stdout: the run-page logs remain the only # place to debug a RUNNING server, but on failure we attach the log tail to the exception # so "ssh connect" can print it (the Jobs run-output API has no stdout logs for notebook tasks). diff --git a/experimental/ssh/internal/client/submit_internal_test.go b/experimental/ssh/internal/client/submit_internal_test.go new file mode 100644 index 00000000000..5b1af006b1d --- /dev/null +++ b/experimental/ssh/internal/client/submit_internal_test.go @@ -0,0 +1,76 @@ +package client + +import ( + "testing" + "time" + + "github.com/databricks/databricks-sdk-go/service/compute" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestBuildSSHServerSubmitRun(t *testing.T) { + const notebookPath = "/Workspace/Users/me/.databricks/ssh-tunnel/v1/conn/ssh-server-bootstrap" + + t.Run("serverless with usage policy", func(t *testing.T) { + opts := ClientOptions{ + ConnectionName: "conn", + UsagePolicyID: "pol-1", + ServerTimeout: time.Hour, + EnvironmentVersion: 4, + } + got := buildSSHServerSubmitRun("v1", "scope", notebookPath, "", opts) + + // Usage policy flows onto the run and into the base params the server reads. + assert.Equal(t, "pol-1", got.BudgetPolicyId) + assert.Equal(t, "pol-1", got.Tasks[0].NotebookTask.BaseParameters["usagePolicyId"]) + assert.Equal(t, "true", got.Tasks[0].NotebookTask.BaseParameters["serverless"]) + + // Serverless runs on an environment, not an existing cluster. + assert.Equal(t, serverlessEnvironmentKey, got.Tasks[0].EnvironmentKey) + assert.Empty(t, got.Tasks[0].ExistingClusterId) + assert.Len(t, got.Environments, 1) + assert.Nil(t, got.Tasks[0].Compute) + }) + + t.Run("serverless with accelerator", func(t *testing.T) { + opts := ClientOptions{ + ConnectionName: "conn", + Accelerator: "GPU_1xA10", + ServerTimeout: time.Hour, + } + got := buildSSHServerSubmitRun("v1", "scope", notebookPath, "", opts) + + assert.Equal(t, compute.HardwareAcceleratorType("GPU_1xA10"), got.Tasks[0].Compute.HardwareAccelerator) + }) + + t.Run("serverless with base environment", func(t *testing.T) { + opts := ClientOptions{ + ConnectionName: "conn", + ServerTimeout: time.Hour, + EnvironmentVersion: 4, + BaseEnvironment: "my-env", + } + got := buildSSHServerSubmitRun("v1", "scope", notebookPath, "workspace-base-environments/dbe_123", opts) + + // A resolved base environment carries its own version, so environment_version is not set. + require.Len(t, got.Environments, 1) + assert.Equal(t, "workspace-base-environments/dbe_123", got.Environments[0].Spec.BaseEnvironment) + assert.Empty(t, got.Environments[0].Spec.EnvironmentVersion) + }) + + t.Run("dedicated cluster", func(t *testing.T) { + opts := ClientOptions{ + ClusterID: "abc-123", + ServerTimeout: time.Hour, + } + got := buildSSHServerSubmitRun("v1", "scope", notebookPath, "", opts) + + // Usage policy is serverless-only; a dedicated run carries none and targets the cluster. + assert.Empty(t, got.BudgetPolicyId) + assert.Empty(t, got.Tasks[0].NotebookTask.BaseParameters["usagePolicyId"]) + assert.Equal(t, "abc-123", got.Tasks[0].ExistingClusterId) + assert.Empty(t, got.Tasks[0].EnvironmentKey) + assert.Empty(t, got.Environments) + }) +} diff --git a/experimental/ssh/internal/server/server.go b/experimental/ssh/internal/server/server.go index b07b6863c00..a5e89a7b701 100644 --- a/experimental/ssh/internal/server/server.go +++ b/experimental/ssh/internal/server/server.go @@ -37,6 +37,9 @@ type ServerOptions struct { SessionID string // Serverless indicates whether the server is running on serverless compute. Serverless bool + // UsagePolicyID the job was submitted with. Persisted to metadata.json so reconnects + // can tell which usage policy the running server was started under. + UsagePolicyID string // The directory to store sshd configuration ConfigDir string // The name of the secrets scope to use for client and server keys @@ -66,8 +69,9 @@ func Run(ctx context.Context, client *databricks.WorkspaceClient, opts ServerOpt // Save metadata including ClusterID (required for Driver Proxy connections in serverless mode) metadata := &workspace.WorkspaceMetadata{ - Port: port, - ClusterID: opts.ClusterID, + Port: port, + ClusterID: opts.ClusterID, + UsagePolicyID: opts.UsagePolicyID, } err = workspace.SaveWorkspaceMetadata(ctx, client, opts.Version, opts.SessionID, metadata) if err != nil { diff --git a/experimental/ssh/internal/workspace/workspace.go b/experimental/ssh/internal/workspace/workspace.go index 0a28b684ebc..576e8a6df9f 100644 --- a/experimental/ssh/internal/workspace/workspace.go +++ b/experimental/ssh/internal/workspace/workspace.go @@ -19,6 +19,9 @@ type WorkspaceMetadata struct { Port int `json:"port"` // ClusterID is required for Driver Proxy websocket connections (for any compute type, including serverless) ClusterID string `json:"cluster_id,omitempty"` + // UsagePolicyID records the usage policy the server's job was submitted with, so a + // reconnect can tell whether a running server matches the requested usage policy. + UsagePolicyID string `json:"usage_policy_id,omitempty"` } func getWorkspaceRootDir(ctx context.Context, client *databricks.WorkspaceClient) (string, error) {