From 6a768d4b43505874955902c3626edfaa401b3f18 Mon Sep 17 00:00:00 2001 From: Ilia Babanov Date: Fri, 6 Mar 2026 14:08:10 +0100 Subject: [PATCH] SSH: Check for Remote SSH extension in VS Code and Cursor Check if the required Remote SSH extension is installed and above a minimum version, and if not, offer to install it. --- experimental/ssh/internal/client/client.go | 3 + experimental/ssh/internal/vscode/run.go | 109 ++++++++-- experimental/ssh/internal/vscode/run_test.go | 198 +++++++++++++++++++ 3 files changed, 294 insertions(+), 16 deletions(-) diff --git a/experimental/ssh/internal/client/client.go b/experimental/ssh/internal/client/client.go index 3c4ed62155..a05c0adca7 100644 --- a/experimental/ssh/internal/client/client.go +++ b/experimental/ssh/internal/client/client.go @@ -211,6 +211,9 @@ func Run(ctx context.Context, client *databricks.WorkspaceClient, opts ClientOpt if err := vscode.CheckIDECommand(opts.IDE); err != nil { return err } + if err := vscode.CheckIDESSHExtension(ctx, opts.IDE); err != nil { + return err + } } // Check and update IDE settings for serverless mode, where we must set up diff --git a/experimental/ssh/internal/vscode/run.go b/experimental/ssh/internal/vscode/run.go index 373a54d8f4..fb88c32edd 100644 --- a/experimental/ssh/internal/vscode/run.go +++ b/experimental/ssh/internal/vscode/run.go @@ -5,8 +5,10 @@ import ( "fmt" "os" "os/exec" + "strings" "github.com/databricks/cli/libs/cmdio" + "golang.org/x/mod/semver" ) // Options as they can be set via --ide flag. @@ -16,27 +18,38 @@ const ( ) type ideDescriptor struct { - Option string - Command string - Name string - InstallURL string - AppName string + Option string + Command string + Name string + InstallURL string + AppName string + SSHExtensionID string + SSHExtensionName string + MinSSHExtensionVersion string } var vsCodeIDE = ideDescriptor{ - Option: VSCodeOption, - Command: "code", - Name: "VS Code", - InstallURL: "https://code.visualstudio.com/", - AppName: "Code", + Option: VSCodeOption, + Command: "code", + Name: "VS Code", + InstallURL: "https://code.visualstudio.com/", + AppName: "Code", + SSHExtensionID: "ms-vscode-remote.remote-ssh", + SSHExtensionName: "Remote - SSH", + // Earlier versions might work too, 0.120.0 is a safe not-too-old pick + MinSSHExtensionVersion: "0.120.0", } var cursorIDE = ideDescriptor{ - Option: CursorOption, - Command: "cursor", - Name: "Cursor", - InstallURL: "https://cursor.com/", - AppName: "Cursor", + Option: CursorOption, + Command: "cursor", + Name: "Cursor", + InstallURL: "https://cursor.com/", + AppName: "Cursor", + SSHExtensionID: "anysphere.remote-ssh", + SSHExtensionName: "Remote - SSH", + // Earlier versions don't support remote.SSH.serverPickPortsFromRange option + MinSSHExtensionVersion: "1.0.32", } func getIDE(option string) ideDescriptor { @@ -62,7 +75,71 @@ func CheckIDECommand(option string) error { return nil } -// LaunchIDE launches the IDE with a remote SSH connection. +// parseExtensionVersion finds the version of the given extension in the output +// of " --list-extensions --show-versions" (one "name@version" per line). +func parseExtensionVersion(output, extensionID string) (string, bool) { + for line := range strings.SplitSeq(output, "\n") { + name, version, ok := strings.Cut(strings.TrimSpace(line), "@") + if ok && name == extensionID { + return version, true + } + } + return "", false +} + +func isExtensionVersionAtLeast(version, minVersion string) bool { + v := "v" + version + return semver.IsValid(v) && semver.Compare(v, "v"+minVersion) >= 0 +} + +// CheckIDESSHExtension verifies that the required Remote SSH extension is installed +// with a compatible version, and offers to install/update it if not. +func CheckIDESSHExtension(ctx context.Context, option string) error { + ide := getIDE(option) + + out, err := exec.CommandContext(ctx, ide.Command, "--list-extensions", "--show-versions").Output() + if err != nil { + return fmt.Errorf("failed to list %s extensions: %w", ide.Name, err) + } + + version, found := parseExtensionVersion(string(out), ide.SSHExtensionID) + if found && isExtensionVersionAtLeast(version, ide.MinSSHExtensionVersion) { + return nil + } + + var msg string + if !found { + msg = fmt.Sprintf("Required extension %q is not installed in %s.", ide.SSHExtensionName, ide.Name) + } else { + msg = fmt.Sprintf("Extension %q version %s is installed, but version >= %s is required.", + ide.SSHExtensionName, version, ide.MinSSHExtensionVersion) + } + + if !cmdio.IsPromptSupported(ctx) { + return fmt.Errorf("%s Install it with: %s --install-extension %s", + msg, ide.Command, ide.SSHExtensionID) + } + + shouldInstall, err := cmdio.AskYesOrNo(ctx, msg+" Would you like to install it?") + if err != nil { + return fmt.Errorf("failed to prompt user: %w", err) + } + if !shouldInstall { + return fmt.Errorf("%s Install it with: %s --install-extension %s", + msg, ide.Command, ide.SSHExtensionID) + } + + cmdio.LogString(ctx, fmt.Sprintf("Installing %q...", ide.SSHExtensionName)) + installCmd := exec.CommandContext(ctx, ide.Command, "--install-extension", ide.SSHExtensionID, "--force") + installCmd.Stdout = os.Stdout + installCmd.Stderr = os.Stderr + if err := installCmd.Run(); err != nil { + return fmt.Errorf("failed to install extension %q: %w", ide.SSHExtensionName, err) + } + return nil +} + +// LaunchIDE launches the IDE with a remote SSH connection using special "ssh-remote" URI format. func LaunchIDE(ctx context.Context, ideOption, connectionName, userName, databricksUserName string) error { ide := getIDE(ideOption) diff --git a/experimental/ssh/internal/vscode/run_test.go b/experimental/ssh/internal/vscode/run_test.go index f50fa4f93d..54c9940b9b 100644 --- a/experimental/ssh/internal/vscode/run_test.go +++ b/experimental/ssh/internal/vscode/run_test.go @@ -1,11 +1,13 @@ package vscode import ( + "fmt" "os" "path/filepath" "runtime" "testing" + "github.com/databricks/cli/libs/cmdio" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -88,3 +90,199 @@ func TestCheckIDECommand_Found(t *testing.T) { }) } } + +func TestParseExtensionVersion(t *testing.T) { + tests := []struct { + name string + output string + extensionID string + wantVersion string + wantFound bool + minVersion string + wantAtLeast bool + }{ + { + name: "found and above minimum", + output: "ms-python.python@2024.1.1\nms-vscode-remote.remote-ssh@0.123.0\n", + extensionID: "ms-vscode-remote.remote-ssh", + wantVersion: "0.123.0", + wantFound: true, + minVersion: "0.120.0", + wantAtLeast: true, + }, + { + name: "found but below minimum", + output: "ms-vscode-remote.remote-ssh@0.100.0\n", + extensionID: "ms-vscode-remote.remote-ssh", + wantVersion: "0.100.0", + wantFound: true, + minVersion: "0.120.0", + wantAtLeast: false, + }, + { + name: "not found", + output: "ms-python.python@2024.1.1\n", + extensionID: "ms-vscode-remote.remote-ssh", + wantVersion: "", + wantFound: false, + }, + { + name: "empty output", + output: "", + extensionID: "ms-vscode-remote.remote-ssh", + wantVersion: "", + wantFound: false, + }, + { + name: "multiple extensions", + output: "ext.a@1.0.0\next.b@2.0.0\next.c@3.0.0\n", + extensionID: "ext.b", + wantVersion: "2.0.0", + wantFound: true, + minVersion: "1.0.0", + wantAtLeast: true, + }, + { + name: "prerelease is less than release", + output: "ms-vscode-remote.remote-ssh@0.120.0-beta.1\n", + extensionID: "ms-vscode-remote.remote-ssh", + wantVersion: "0.120.0-beta.1", + wantFound: true, + minVersion: "0.120.0", + wantAtLeast: false, + }, + { + name: "line with whitespace", + output: " ms-vscode-remote.remote-ssh@0.123.0 \n", + extensionID: "ms-vscode-remote.remote-ssh", + wantVersion: "0.123.0", + wantFound: true, + minVersion: "0.120.0", + wantAtLeast: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + version, found := parseExtensionVersion(tt.output, tt.extensionID) + assert.Equal(t, tt.wantFound, found) + assert.Equal(t, tt.wantVersion, version) + if found { + assert.Equal(t, tt.wantAtLeast, isExtensionVersionAtLeast(version, tt.minVersion)) + } + }) + } +} + +func TestIsExtensionVersionAtLeast(t *testing.T) { + tests := []struct { + name string + version string + minVersion string + want bool + }{ + {name: "above minimum", version: "0.123.0", minVersion: "0.120.0", want: true}, + {name: "exact minimum", version: "0.120.0", minVersion: "0.120.0", want: true}, + {name: "below minimum", version: "0.100.0", minVersion: "0.120.0", want: false}, + {name: "major version ahead", version: "1.0.0", minVersion: "0.120.0", want: true}, + {name: "prerelease below release", version: "0.120.0-beta.1", minVersion: "0.120.0", want: false}, + {name: "prerelease above prior release", version: "0.121.0-beta.1", minVersion: "0.120.0", want: true}, + {name: "two-component version is valid", version: "1.0", minVersion: "0.120.0", want: true}, + {name: "empty version", version: "", minVersion: "0.120.0", want: false}, + {name: "garbage version", version: "abc", minVersion: "0.120.0", want: false}, + {name: "four-component version is invalid", version: "0.120.0.1", minVersion: "0.120.0", want: false}, + {name: "cursor exact minimum", version: "1.0.32", minVersion: "1.0.32", want: true}, + {name: "cursor above minimum", version: "1.1.0", minVersion: "1.0.32", want: true}, + {name: "cursor below minimum", version: "1.0.31", minVersion: "1.0.32", want: false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + assert.Equal(t, tt.want, isExtensionVersionAtLeast(tt.version, tt.minVersion)) + }) + } +} + +// createFakeIDEExecutable creates a fake IDE command that outputs the given text +// when called with --list-extensions --show-versions. +func createFakeIDEExecutable(t *testing.T, dir, command, output string) { + t.Helper() + if runtime.GOOS == "windows" { + // Write output to a temp file and use "type" to print it, avoiding escaping issues. + payloadPath := filepath.Join(dir, command+"-payload.txt") + err := os.WriteFile(payloadPath, []byte(output), 0o644) + require.NoError(t, err) + script := fmt.Sprintf("@echo off\ntype \"%s\"\n", payloadPath) + err = os.WriteFile(filepath.Join(dir, command+".cmd"), []byte(script), 0o755) + require.NoError(t, err) + } else { + // Use printf (a shell builtin) instead of cat to avoid PATH issues in tests. + script := fmt.Sprintf("#!/bin/sh\nprintf '%%s' '%s'\n", output) + err := os.WriteFile(filepath.Join(dir, command), []byte(script), 0o755) + require.NoError(t, err) + } +} + +func TestCheckIDESSHExtension_UpToDate(t *testing.T) { + tmpDir := t.TempDir() + t.Setenv("PATH", tmpDir) + ctx, _ := cmdio.NewTestContextWithStdout(t.Context()) + + extensionOutput := "ms-python.python@2024.1.1\nms-vscode-remote.remote-ssh@0.123.0\n" + createFakeIDEExecutable(t, tmpDir, "code", extensionOutput) + + err := CheckIDESSHExtension(ctx, VSCodeOption) + assert.NoError(t, err) +} + +func TestCheckIDESSHExtension_ExactMinVersion(t *testing.T) { + tmpDir := t.TempDir() + t.Setenv("PATH", tmpDir) + ctx, _ := cmdio.NewTestContextWithStdout(t.Context()) + + extensionOutput := "ms-vscode-remote.remote-ssh@0.120.0\n" + createFakeIDEExecutable(t, tmpDir, "code", extensionOutput) + + err := CheckIDESSHExtension(ctx, VSCodeOption) + assert.NoError(t, err) +} + +func TestCheckIDESSHExtension_Missing(t *testing.T) { + tmpDir := t.TempDir() + t.Setenv("PATH", tmpDir) + ctx, _ := cmdio.NewTestContextWithStdout(t.Context()) + + extensionOutput := "ms-python.python@2024.1.1\n" + createFakeIDEExecutable(t, tmpDir, "code", extensionOutput) + + err := CheckIDESSHExtension(ctx, VSCodeOption) + require.Error(t, err) + assert.Contains(t, err.Error(), `"Remote - SSH"`) + assert.Contains(t, err.Error(), "not installed") +} + +func TestCheckIDESSHExtension_Outdated(t *testing.T) { + tmpDir := t.TempDir() + t.Setenv("PATH", tmpDir) + ctx, _ := cmdio.NewTestContextWithStdout(t.Context()) + + extensionOutput := "ms-vscode-remote.remote-ssh@0.100.0\n" + createFakeIDEExecutable(t, tmpDir, "code", extensionOutput) + + err := CheckIDESSHExtension(ctx, VSCodeOption) + require.Error(t, err) + assert.Contains(t, err.Error(), "0.100.0") + assert.Contains(t, err.Error(), ">= 0.120.0") +} + +func TestCheckIDESSHExtension_Cursor(t *testing.T) { + tmpDir := t.TempDir() + t.Setenv("PATH", tmpDir) + ctx, _ := cmdio.NewTestContextWithStdout(t.Context()) + + extensionOutput := "anysphere.remote-ssh@1.0.32\n" + createFakeIDEExecutable(t, tmpDir, "cursor", extensionOutput) + + err := CheckIDESSHExtension(ctx, CursorOption) + assert.NoError(t, err) +}