diff --git a/cmd/commands/upgrade_win.go b/cmd/commands/upgrade_win.go index 9c069b6c..c018d61c 100644 --- a/cmd/commands/upgrade_win.go +++ b/cmd/commands/upgrade_win.go @@ -20,6 +20,7 @@ package commands import ( "os" + "strings" "syscall" "unsafe" @@ -58,11 +59,23 @@ func runAsAdmin() error { return err } + // Build arguments string from os.Args (skip the executable name) + // This ensures all flags like --debug are passed through + args := "" + if len(os.Args) > 1 { + // Join all arguments starting from index 1 + quotedArgs := make([]string, 0, len(os.Args)-1) + for _, arg := range os.Args[1:] { + quotedArgs = append(quotedArgs, escapeArg(arg)) + } + args = strings.Join(quotedArgs, " ") + } + verb := "runas" cwd, _ := syscall.UTF16PtrFromString(".") - arg, _ := syscall.UTF16PtrFromString(SelfUpgradeName) + arg, _ := syscall.UTF16PtrFromString(args) run := windows.NewLazySystemDLL("shell32.dll").NewProc("ShellExecuteW") - run.Call( + ret, _, _ := run.Call( 0, uintptr(unsafe.Pointer(syscall.StringToUTF16Ptr(verb))), uintptr(unsafe.Pointer(syscall.StringToUTF16Ptr(exePath))), @@ -70,6 +83,55 @@ func runAsAdmin() error { uintptr(unsafe.Pointer(cwd)), 1, ) + // ShellExecuteW returns a value > 32 on success + if ret <= 32 { + return syscall.Errno(ret) + } + // Exit the current process since we've successfully launched the elevated one + // This function never returns normally after successful elevation os.Exit(0) - return nil + return nil // unreachable but required for compilation +} + +// escapeArg escapes a command-line argument according to Windows rules. +// Based on https://docs.microsoft.com/en-us/archive/blogs/twistylittlepassagesallalike/everyone-quotes-command-line-arguments-the-wrong-way +func escapeArg(arg string) string { + // If the argument doesn't contain special characters, return as-is + if !strings.ContainsAny(arg, " \t\n\"") { + return arg + } + + // Build the escaped argument + var b strings.Builder + b.WriteByte('"') + + for i := 0; i < len(arg); { + // Count consecutive backslashes + backslashes := 0 + for i < len(arg) && arg[i] == '\\' { + backslashes++ + i++ + } + + if i >= len(arg) { + // Backslashes at the end need to be doubled (they precede the closing quote) + b.WriteString(strings.Repeat("\\", backslashes*2)) + break + } + + if arg[i] == '"' { + // Backslashes before a quote need to be doubled, and the quote needs to be escaped + b.WriteString(strings.Repeat("\\", backslashes*2)) + b.WriteString("\\\"") + i++ + } else { + // Regular backslashes (not before a quote) are literal + b.WriteString(strings.Repeat("\\", backslashes)) + b.WriteByte(arg[i]) + i++ + } + } + + b.WriteByte('"') + return b.String() }