Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 0 additions & 34 deletions api/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -114,40 +114,6 @@ func (c *Client) ListInstancesWithIPUpdateCtx(ctx context.Context) ([]Instance,
return instances, nil
}

func (c *Client) GetLatestBinaryHashCtx(ctx context.Context) (string, error) {
metadataURL := "https://storage.googleapis.com/storage/v1/b/client-binary/o/client_linux_x86_64?alt=json"

req, err := http.NewRequest("GET", metadataURL, nil)
if err != nil {
return "", fmt.Errorf("failed to create request: %w", err)
}

resp, err := c.do(ctx, req)
if err != nil {
return "", fmt.Errorf("failed to make request: %w", err)
}
defer resp.Body.Close()

if resp.StatusCode != 200 {
body, _ := io.ReadAll(resp.Body)
return "", fmt.Errorf("API request failed with status %d: %s", resp.StatusCode, string(body))
}

body, err := io.ReadAll(resp.Body)
if err != nil {
return "", fmt.Errorf("failed to read response: %w", err)
}

var result struct {
Metadata map[string]string `json:"metadata"`
}
if err := json.Unmarshal(body, &result); err != nil {
return "", fmt.Errorf("failed to parse response: %w", err)
}

return result.Metadata["hash"], nil
}

func (c *Client) AddSSHKeyCtx(ctx context.Context, instanceID string) (*AddSSHKeyResponse, error) {
url := fmt.Sprintf("%s/instances/%s/add_key", c.baseURL, instanceID)

Expand Down
4 changes: 2 additions & 2 deletions api/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ func TestNewClient(t *testing.T) {
func TestCreateInstanceRequest(t *testing.T) {
req := CreateInstanceRequest{
CPUCores: 8,
GPUType: "t4",
GPUType: "a6000",
Template: "ubuntu-22.04",
NumGPUs: 1,
DiskSizeGB: 100,
Expand Down Expand Up @@ -143,7 +143,7 @@ func TestTemplateStruct(t *testing.T) {
Version: 1,
DefaultSpecs: ThunderTemplateDefaultSpecs{
Cores: 8,
GpuType: "t4",
GpuType: "a6000",
NumGpus: 1,
Storage: 100,
},
Expand Down
2 changes: 0 additions & 2 deletions api/types.go
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,4 @@ type ConnectClient interface {
ListInstances() ([]Instance, error)
ListInstancesWithIPUpdateCtx(ctx context.Context) ([]Instance, error)
AddSSHKeyCtx(ctx context.Context, instanceID string) (*AddSSHKeyResponse, error)
GetLatestBinaryHashCtx(ctx context.Context) (string, error)
GetNextDeviceID() (string, error)
}
235 changes: 10 additions & 225 deletions cmd/connect.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,13 @@ package cmd

import (
"context"
"encoding/base64"
"errors"
"fmt"
"os"
"os/exec"
"os/signal"
"runtime"
"strconv"
"strings"
"time"

tea "github.com/charmbracelet/bubbletea"
Expand Down Expand Up @@ -251,9 +249,6 @@ func runConnectWithOptions(instanceID string, tunnelPortsStr []string, debug boo
phase1Start := time.Now()
tui.SendPhaseUpdate(p, 0, tui.PhaseInProgress, "Fetching instances...", 0)

hashChan := make(chan string, 1)
hashErrChan := make(chan error, 1)

if runtime.GOOS == "windows" {
if err := checkWindowsOpenSSH(); err != nil {
return err
Expand All @@ -274,16 +269,6 @@ func runConnectWithOptions(instanceID string, tunnelPortsStr []string, debug boo
return nil
}

// Fetch binary hash in background
go func() {
hash, err := client.GetLatestBinaryHashCtx(ctx)
if err != nil {
hashErrChan <- err
return
}
hashChan <- hash
}()

if checkCancelled() {
return nil
}
Expand Down Expand Up @@ -327,13 +312,6 @@ func runConnectWithOptions(instanceID string, tunnelPortsStr []string, debug boo
port = 22
}

gpuCount := 1
if instance.NumGPUs != "" {
if count, err := strconv.Atoi(instance.NumGPUs); err == nil {
gpuCount = count
}
}

phaseTimings["instance_validation"] = time.Since(phase2Start)
tui.SendPhaseUpdate(p, 1, tui.PhaseCompleted, fmt.Sprintf("Found: %s (%s)", instance.Name, instance.IP), phaseTimings["instance_validation"])

Expand Down Expand Up @@ -491,211 +469,18 @@ func runConnectWithOptions(instanceID string, tunnelPortsStr []string, debug boo
return nil
}

// Get binary hash (already fetched in background)
var binaryHash string
select {
case hash := <-hashChan:
binaryHash = hash
case <-hashErrChan:
binaryHash = ""
case <-ctx.Done():
if checkCancelled() {
return nil
}
case <-time.After(2 * time.Second):
binaryHash = ""
}

// For production mode, check active sessions first (like VSCode extension) to skip operations if needed
var activeSessions int
var existingConfig *utils.ThunderConfig
var existingHash string
var canEarlyReturn bool

// Set up token on the instance (binary is now managed by the instance itself)
if instance.Mode == "production" {
var checkErr error
activeSessions, checkErr = utils.CheckActiveSessions(sshClient)
if checkErr != nil {
activeSessions = 0
}

if activeSessions > 1 {
tokenB64 := base64.StdEncoding.EncodeToString([]byte(config.Token))
combinedTokenCmd := fmt.Sprintf("sudo install -d -m 755 /home/ubuntu/.thunder && echo '%s' | base64 -d | sudo tee /home/ubuntu/.thunder/token > /dev/null && sudo chown ubuntu:ubuntu /home/ubuntu/.thunder/token && sudo chmod 600 /home/ubuntu/.thunder/token && sudo sed -i '/export TNR_API_TOKEN/d' /home/ubuntu/.bashrc || true && echo 'export TNR_API_TOKEN=\"$(cat /home/ubuntu/.thunder/token)\"' | sudo tee -a /home/ubuntu/.bashrc > /dev/null", tokenB64)
_, _ = utils.ExecuteSSHCommand(sshClient, combinedTokenCmd)
phaseTimings["instance_setup"] = time.Since(phase5Start)
tui.SendPhaseComplete(p, 4, phaseTimings["instance_setup"])
canEarlyReturn = true
} else {
// No active sessions - match VSCode extension: skip config/hash check, run cleanup (idempotent)
if err := utils.RemoveThunderVirtualization(sshClient, config.Token); err != nil {
shutdownTUI()
return fmt.Errorf("failed to remove Thunder virtualization: %w", err)
}
phaseTimings["instance_setup"] = time.Since(phase5Start)
tui.SendPhaseComplete(p, 4, phaseTimings["instance_setup"])
canEarlyReturn = true

}
}

// For prototyping mode, do full config/hash read in parallel
if instance.Mode != "production" || !canEarlyReturn {
// Clean up ld.so.preload early if binary is missing to prevent stderr pollution
_ = utils.CleanupLdSoPreloadIfBinaryMissing(sshClient)

type configResult struct {
config *utils.ThunderConfig
err error
}
type instanceHashResult struct {
hash string
err error
}

configChan := make(chan configResult, 1)
instanceHashChan := make(chan instanceHashResult, 1)

go func() {
config, err := utils.GetThunderConfig(sshClient)
configChan <- configResult{config: config, err: err}
}()

expectedHash := utils.NormalizeHash(binaryHash)
isValidHash := expectedHash != "" && len(expectedHash) == 32 && utils.IsHexString(expectedHash)
hashAlgorithm := utils.DetectHashAlgorithm(expectedHash)

if isValidHash {
go func() {
hash, err := utils.GetInstanceBinaryHash(sshClient, hashAlgorithm)
instanceHashChan <- instanceHashResult{hash: hash, err: err}
}()
} else {
instanceHashChan <- instanceHashResult{hash: "", err: nil}
}

configRes := <-configChan
hashRes := <-instanceHashChan

if configRes.err == nil {
existingConfig = configRes.config
}

if hashRes.err == nil {
existingHash = hashRes.hash
}

}

ranConfigurator := false

// Early return if GPU config and hash match
if !canEarlyReturn {
if instance.Mode == "prototyping" && existingConfig != nil && existingConfig.DeviceID != "" {
expectedHash := utils.NormalizeHash(binaryHash)
isValidHash := expectedHash != "" && len(expectedHash) == 32 && utils.IsHexString(expectedHash)
gpuTypeMatches := strings.EqualFold(existingConfig.GPUType, instance.GPUType)
gpuCountMatches := existingConfig.GPUCount == gpuCount
hashMatches := isValidHash && existingHash != "" && existingHash == expectedHash

if gpuTypeMatches && gpuCountMatches && hashMatches {
phaseTimings["instance_setup"] = time.Since(phase5Start)
tui.SendPhaseComplete(p, 4, phaseTimings["instance_setup"])
canEarlyReturn = true
ranConfigurator = true
}
}
}

// Skip token/bootstrap operations if GPU config matches (ConfigureThunderVirtualization handles token update)
skipTokenBootstrap := canEarlyReturn
skipActiveSessionsCheck := canEarlyReturn
if !canEarlyReturn && instance.Mode == "prototyping" && existingConfig != nil && existingConfig.DeviceID != "" {
gpuTypeMatches := strings.EqualFold(existingConfig.GPUType, instance.GPUType)
gpuCountMatches := existingConfig.GPUCount == gpuCount
if gpuTypeMatches && gpuCountMatches {
skipTokenBootstrap = true
skipActiveSessionsCheck = true
}
}

// For prototyping mode, handle token/bootstrap and active sessions check
if instance.Mode == "prototyping" && !canEarlyReturn {
if !skipTokenBootstrap {
// Combine token bootstrap and bashrc update into a single SSH command
tokenB64 := base64.StdEncoding.EncodeToString([]byte(config.Token))
combinedTokenCmd := fmt.Sprintf("sudo install -d -m 755 /home/ubuntu/.thunder && echo '%s' | base64 -d | sudo tee /home/ubuntu/.thunder/token > /dev/null && sudo chown ubuntu:ubuntu /home/ubuntu/.thunder/token && sudo chmod 600 /home/ubuntu/.thunder/token && sudo sed -i '/export TNR_API_TOKEN/d' /home/ubuntu/.bashrc || true && echo 'export TNR_API_TOKEN=\"$(cat /home/ubuntu/.thunder/token)\"' | sudo tee -a /home/ubuntu/.bashrc > /dev/null", tokenB64)
_, _ = utils.ExecuteSSHCommand(sshClient, combinedTokenCmd)
}

if !skipActiveSessionsCheck {
var checkErr error
activeSessions, checkErr = utils.CheckActiveSessions(sshClient)
if checkErr != nil {
activeSessions = 0
}
} else {
activeSessions = 0
}
} else if instance.Mode == "prototyping" {
activeSessions = 0
}

if !canEarlyReturn {
switch instance.Mode {
case "production":
tui.SendPhaseUpdate(p, 4, tui.PhaseInProgress, "Production mode detected, disabling Thunder virtualization...", 0)
if err := utils.RemoveThunderVirtualization(sshClient, config.Token); err != nil {
shutdownTUI()
return fmt.Errorf("failed to remove Thunder virtualization: %w", err)
}
default:
var deviceID string
if existingConfig != nil && existingConfig.DeviceID != "" {
deviceID = existingConfig.DeviceID
} else {
if newID, err := client.GetNextDeviceID(); err == nil {
deviceID = newID
}
}

switch {
case activeSessions > 1:
tui.SendPhaseUpdate(p, 4, tui.PhaseInProgress, fmt.Sprintf("Detected %d active SSH sessions, skipping binary update", activeSessions), 0)
case deviceID == "":
tui.SendPhaseUpdate(p, 4, tui.PhaseWarning, "Unable to determine device ID, skipping environment setup", 0)
default:
tui.SendPhaseUpdate(p, 4, tui.PhaseInProgress, "Updating Thunder binary and config if needed...", 0)
if err := utils.ConfigureThunderVirtualization(sshClient, instanceID, deviceID, instance.GPUType, gpuCount, config.Token, binaryHash, existingConfig); err != nil {
shutdownTUI()
return fmt.Errorf("failed to configure Thunder virtualization: %w", err)
}
ranConfigurator = true
}
tui.SendPhaseUpdate(p, 4, tui.PhaseInProgress, "Production mode detected, setting up token...", 0)
if err := utils.RemoveThunderVirtualization(sshClient, config.Token); err != nil {
shutdownTUI()
return fmt.Errorf("failed to set up token: %w", err)
}
}

if checkCancelled() {
return nil
}

if instance.Mode == "prototyping" && !ranConfigurator && binaryHash != "" {
tui.SendPhaseUpdate(p, 4, tui.PhaseInProgress, "Checking Thunder binary version...", 0)
expectedHash := utils.NormalizeHash(binaryHash)
hashAlgo := utils.DetectHashAlgorithm(expectedHash)

existingHash, hashErr := utils.GetInstanceBinaryHash(sshClient, hashAlgo)
existingHashNormalized := utils.NormalizeHash(existingHash)

if hashErr == nil && existingHashNormalized != "" && existingHashNormalized != expectedHash {
tui.SendPhaseUpdate(p, 4, tui.PhaseInProgress, "Binary outdated, updating in background...", 0)
deviceID := ""
if existingConfig != nil && existingConfig.DeviceID != "" {
deviceID = existingConfig.DeviceID
}
if deviceID != "" {
_ = utils.TriggerBackgroundSetup(sshClient, instanceID, deviceID, instance.GPUType, gpuCount, config.Token)
}
} else {
tui.SendPhaseUpdate(p, 4, tui.PhaseInProgress, "Setting up token...", 0)
if err := utils.SetupToken(sshClient, config.Token); err != nil {
shutdownTUI()
return fmt.Errorf("failed to set up token: %w", err)
}
}

Expand Down
20 changes: 1 addition & 19 deletions cmd/connect_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,12 +38,6 @@ type mockAPIClient struct {
addSSHKeyCalled int
addSSHKeyInstanceIDs []string

binaryHash string
binaryHashErr error

nextDeviceID string
nextDeviceIDErr error

mu sync.Mutex
}

Expand Down Expand Up @@ -72,18 +66,6 @@ func (m *mockAPIClient) AddSSHKeyCtx(ctx context.Context, instanceID string) (*a
return m.addSSHKeyResponse, m.addSSHKeyErr
}

func (m *mockAPIClient) GetLatestBinaryHashCtx(ctx context.Context) (string, error) {
m.mu.Lock()
defer m.mu.Unlock()
return m.binaryHash, m.binaryHashErr
}

func (m *mockAPIClient) GetNextDeviceID() (string, error) {
m.mu.Lock()
defer m.mu.Unlock()
return m.nextDeviceID, m.nextDeviceIDErr
}

// =============================================================================
// Mock SSH Client
// =============================================================================
Expand Down Expand Up @@ -114,7 +96,7 @@ func createTestInstance(id, uuid, name, ip, status, template, mode string, port
Mode: mode,
Port: port,
NumGPUs: "1",
GPUType: "t4",
GPUType: "a6000",
}
}

Expand Down
Loading
Loading