Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions client/auth/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ import (
type DefaultsConfig struct {
Defaults struct {
Server string `toml:"server"`
WsID string `toml:"wsID"`
} `toml:"defaults"`
}

Expand Down
46 changes: 46 additions & 0 deletions client/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,14 @@ import (
"time"

"github.com/BurntSushi/toml"
"github.com/google/uuid"
)

// .config/sqlrsync/defaults.toml
type DefaultsConfig struct {
Defaults struct {
Server string `toml:"server"`
WsID string `toml:"wsID"`
} `toml:"defaults"`
}

Expand Down Expand Up @@ -78,6 +80,14 @@ func LoadDefaultsConfig() (*DefaultsConfig, error) {
// Return default config if file doesn't exist
config := &DefaultsConfig{}
config.Defaults.Server = "wss://sqlrsync.com"
// Generate wsID if it doesn't exist
if err := generateAndSetWsID(config); err != nil {
return nil, fmt.Errorf("failed to generate wsID: %w", err)
}
// Save the new config with wsID
if err := SaveDefaultsConfig(config); err != nil {
return nil, fmt.Errorf("failed to save defaults config with wsID: %w", err)
}
return config, nil
}
return nil, fmt.Errorf("failed to read defaults config file %s: %w", path, err)
Expand All @@ -93,9 +103,45 @@ func LoadDefaultsConfig() (*DefaultsConfig, error) {
config.Defaults.Server = "wss://sqlrsync.com"
}

// Generate wsID if it doesn't exist
needsSave := false
if config.Defaults.WsID == "" {
if err := generateAndSetWsID(&config); err != nil {
return nil, fmt.Errorf("failed to generate wsID: %w", err)
}
needsSave = true
}

// Save config if we made changes
if needsSave {
if err := SaveDefaultsConfig(&config); err != nil {
return nil, fmt.Errorf("failed to save defaults config with wsID: %w", err)
}
}

return &config, nil
}

// generateAndSetWsID generates a new wsID (UUID + hostname) and sets it in the config
func generateAndSetWsID(config *DefaultsConfig) error {
hostname, err := os.Hostname()
if err != nil {
return fmt.Errorf("failed to get hostname: %w", err)
}

config.Defaults.WsID = hostname + ":" + uuid.New().String()
return nil
}

// GetWsID loads the defaults config and returns the wsID
func GetWsID() (string, error) {
config, err := LoadDefaultsConfig()
if err != nil {
return "", fmt.Errorf("failed to load defaults config: %w", err)
}
return config.Defaults.WsID, nil
}

func SaveDefaultsConfig(config *DefaultsConfig) error {
path, err := GetDefaultsPath()
if err != nil {
Expand Down
1 change: 1 addition & 0 deletions client/go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ go 1.24.5
require (
github.com/BurntSushi/toml v1.5.0
github.com/fatih/color v1.18.0
github.com/google/uuid v1.6.0
github.com/gorilla/websocket v1.5.0
github.com/spf13/cobra v1.8.0
github.com/sqlrsync/sqlrsync.com/bridge v0.0.0-00010101000000-000000000000
Expand Down
2 changes: 2 additions & 0 deletions client/go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@ github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/fatih/color v1.18.0 h1:S8gINlzdQ840/4pfAwic/ZE0djQEH3wM94VfqLTZcOM=
github.com/fatih/color v1.18.0/go.mod h1:4FelSpRwEGDpQ12mAdzqdOukCy4u8WUtOY6lkT/6HfU=
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/gorilla/websocket v1.5.0 h1:PPwGk2jz7EePpoHN/+ClbZu8SPxiqlu12wZP/3sWmnc=
github.com/gorilla/websocket v1.5.0/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE=
github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8=
Expand Down
11 changes: 10 additions & 1 deletion client/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ import (
"github.com/sqlrsync/sqlrsync.com/sync"
)

var VERSION = "0.0.4"
var VERSION = "0.0.5"
var (
serverURL string
verbose bool
Expand Down Expand Up @@ -145,6 +145,13 @@ func runSync(cmd *cobra.Command, args []string) error {
visibility = 1
}

// Get workspace ID for client identification
wsID, err := GetWsID()
if err != nil {
logger.Warn("Failed to get workspace ID", zap.Error(err))
wsID = "" // Continue with empty wsID
}

// Create sync coordinator
coordinator := sync.NewCoordinator(&sync.CoordinatorConfig{
ServerURL: serverURL,
Expand All @@ -162,6 +169,8 @@ func runSync(cmd *cobra.Command, args []string) error {
DryRun: dryRun,
Logger: logger,
Verbose: verbose,
WsID: wsID, // Add websocket ID
ClientVersion: VERSION,
})

// Execute the operation
Expand Down
66 changes: 61 additions & 5 deletions client/remote/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -402,11 +402,13 @@ type Config struct {
InspectionDepth int // How many bytes to inspect (default: 32)
PingPong bool
AuthToken string
ClientVersion string // version of the client software
SendKeyRequest bool // the -sqlrsync file doesn't exist, so make a token

SendConfigCmd bool // we don't have the version number or remote path
LocalHostname string
LocalAbsolutePath string
WsID string // Workspace ID for X-ClientID header

// Progress tracking
ProgressConfig *ProgressConfig
Expand Down Expand Up @@ -685,6 +687,14 @@ func (c *Client) Connect() error {

headers.Set("Authorization", c.config.AuthToken)

headers.Set("X-ClientVersion", c.config.ClientVersion);

if c.config.WsID != "" {
headers.Set("X-ClientID", c.config.WsID)
} else {
c.logger.Fatal("No wsID provided for X-ClientID header")
}

if c.config.LocalHostname != "" {
headers.Set("X-LocalHostname", c.config.LocalHostname)
}
Expand All @@ -703,12 +713,58 @@ func (c *Client) Connect() error {

conn, response, err := dialer.DialContext(connectCtx, u.String(), headers)
if err != nil {
if response != nil && response.Body != nil {
respStr, _ := io.ReadAll(response.Body)
response.Body.Close()
return fmt.Errorf("%s", respStr)
if response != nil {
// Extract detailed error information from the response
statusCode := response.StatusCode
statusText := response.Status

var respBodyStr string
if response.Body != nil {
respBytes, readErr := io.ReadAll(response.Body)
response.Body.Close()
if readErr == nil {
respBodyStr = strings.TrimSpace(string(respBytes))
}
}

// Create a clean error message
var errorMsg strings.Builder
errorMsg.WriteString(fmt.Sprintf("HTTP %d (%s)", statusCode, statusText))

if respBodyStr != "" {
errorMsg.WriteString(fmt.Sprintf(": %s", respBodyStr))
}

return fmt.Errorf("%s", errorMsg.String())
}
return fmt.Errorf("failed to connect to WebSocket: %w", err)

// Handle cases where response is nil (e.g., network errors, bad handshake)
var errorMsg strings.Builder
errorMsg.WriteString("Failed to connect to WebSocket")

// Analyze the error type and provide helpful context
errorStr := err.Error()
if strings.Contains(errorStr, "bad handshake") {
errorMsg.WriteString(" - WebSocket handshake failed")
errorMsg.WriteString("\nThis could be due to:")
errorMsg.WriteString("\n• Invalid server URL or endpoint")
errorMsg.WriteString("\n• Server not supporting WebSocket connections")
errorMsg.WriteString("\n• Network connectivity issues")
errorMsg.WriteString("\n• Authentication problems")
} else if strings.Contains(errorStr, "timeout") {
errorMsg.WriteString(" - Connection timeout")
errorMsg.WriteString("\nThe server may be overloaded or unreachable")
} else if strings.Contains(errorStr, "refused") {
errorMsg.WriteString(" - Connection refused")
errorMsg.WriteString("\nThe server may be down or the port may be blocked")
} else if strings.Contains(errorStr, "no such host") {
errorMsg.WriteString(" - DNS resolution failed")
errorMsg.WriteString("\nCheck the server hostname in your configuration")
}

errorMsg.WriteString(fmt.Sprintf("\nOriginal error: %v", err))

return fmt.Errorf("%s", errorMsg.String())
}
defer response.Body.Close()

Expand Down
69 changes: 63 additions & 6 deletions client/subscription/manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"net/url"
"strings"
Expand Down Expand Up @@ -33,12 +34,14 @@ type Message struct {
Timestamp time.Time `json:"timestamp"`
}

// Config holds subscription manager configuration
type Config struct {
// ManagerConfig holds subscription manager configuration
type ManagerConfig struct {
ServerURL string
ReplicaPath string
AuthToken string
ReplicaID string
WsID string // websocket ID for client identification
ClientVersion string // version of the client software
Logger *zap.Logger
MaxReconnectAttempts int // Maximum number of reconnect attempts (0 = infinite)
InitialReconnectDelay time.Duration // Initial delay before first reconnect
Expand All @@ -53,7 +56,7 @@ type Config struct {
// MaxReconnectDelay is reached. Reconnection attempts continue indefinitely unless
// MaxReconnectAttempts is set to a positive value.
type Manager struct {
config *Config
config *ManagerConfig
logger *zap.Logger
conn *websocket.Conn
ctx context.Context
Expand All @@ -72,7 +75,7 @@ type Manager struct {
}

// NewManager creates a new subscription manager
func NewManager(config *Config) *Manager {
func NewManager(config *ManagerConfig) *Manager {
ctx, cancel := context.WithCancel(context.Background())

// Set default reconnection parameters if not provided
Expand Down Expand Up @@ -201,15 +204,69 @@ func (m *Manager) doConnect() error {
headers.Set("X-ReplicaID", m.config.ReplicaID)
}

headers.Set("X-ClientVersion", m.config.ClientVersion)
headers.Set("X-ClientID", m.config.WsID)

dialer := websocket.Dialer{
HandshakeTimeout: 10 * time.Second,
}

m.logger.Debug("Dialing WebSocket", zap.String("url", u.String()))

conn, _, err := dialer.DialContext(m.ctx, u.String(), headers)
conn, response, err := dialer.DialContext(m.ctx, u.String(), headers)
if err != nil {
return fmt.Errorf("failed to connect to subscription service: %w", err)
if response != nil {
// Extract detailed error information from the response
statusCode := response.StatusCode
statusText := response.Status

var respBodyStr string
if response.Body != nil {
respBytes, readErr := io.ReadAll(response.Body)
response.Body.Close()
if readErr == nil {
respBodyStr = strings.TrimSpace(string(respBytes))
}
}

// Create a clean error message
var errorMsg strings.Builder
errorMsg.WriteString(fmt.Sprintf("HTTP %d (%s)", statusCode, statusText))

if respBodyStr != "" {
errorMsg.WriteString(fmt.Sprintf(": %s", respBodyStr))
}

return fmt.Errorf("%s", errorMsg.String())
}

// Handle cases where response is nil (e.g., network errors, bad handshake)
var errorMsg strings.Builder
errorMsg.WriteString("Failed to connect to subscription service")

// Analyze the error type and provide helpful context
errorStr := err.Error()
if strings.Contains(errorStr, "bad handshake") {
errorMsg.WriteString(" - WebSocket handshake failed")
errorMsg.WriteString("\nThis could be due to:")
errorMsg.WriteString("\n• Invalid server URL or endpoint")
errorMsg.WriteString("\n• Server not supporting WebSocket connections")
errorMsg.WriteString("\n• Network connectivity issues")
errorMsg.WriteString("\n• Authentication problems")
} else if strings.Contains(errorStr, "timeout") {
errorMsg.WriteString(" - Connection timeout")
errorMsg.WriteString("\nThe server may be overloaded or unreachable")
} else if strings.Contains(errorStr, "refused") {
errorMsg.WriteString(" - Connection refused")
errorMsg.WriteString("\nThe server may be down or the port may be blocked")
} else if strings.Contains(errorStr, "no such host") {
errorMsg.WriteString(" - DNS resolution failed")
errorMsg.WriteString("\nCheck the server hostname in your configuration")
}

errorMsg.WriteString(fmt.Sprintf("\nOriginal error: %v", err))

return fmt.Errorf("%s", errorMsg.String())
}

m.mu.Lock()
Expand Down
12 changes: 10 additions & 2 deletions client/sync/coordinator.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,8 @@ type CoordinatorConfig struct {
DryRun bool
Logger *zap.Logger
Verbose bool
WsID string // Websocket ID for client identification
ClientVersion string // version of the client software
}

// Coordinator manages sync operations and subscriptions
Expand Down Expand Up @@ -240,11 +242,13 @@ func (c *Coordinator) executeSubscribe() error {
}

// Create subscription manager with reconnection configuration
c.subManager = subscription.NewManager(&subscription.Config{
c.subManager = subscription.NewManager(&subscription.ManagerConfig{
ServerURL: authResult.ServerURL,
ReplicaPath: authResult.RemotePath,
AuthToken: authResult.AccessToken,
ReplicaID: authResult.ReplicaID,
WsID: c.config.WsID,
ClientVersion: c.config.ClientVersion,
Logger: c.logger.Named("subscription"),
MaxReconnectAttempts: 20, // Infinite reconnect attempts
InitialReconnectDelay: 5 * time.Second, // Start with 5 seconds delay
Expand Down Expand Up @@ -373,6 +377,8 @@ func (c *Coordinator) executePull(isSubscription bool) error {
Version: version,
SendConfigCmd: true,
SendKeyRequest: c.authResolver.CheckNeedsDashFile(c.config.LocalPath, remotePath),
WsID: c.config.WsID, // Add websocket ID
ClientVersion: c.config.ClientVersion,
//ProgressCallback: remote.DefaultProgressCallback(remote.FormatSimple),
ProgressCallback: nil,
ProgressConfig: &remote.ProgressConfig{
Expand Down Expand Up @@ -499,7 +505,9 @@ func (c *Coordinator) executePush() error {
SendConfigCmd: true,
SetVisibility: c.config.SetVisibility,
CommitMessage: c.config.CommitMessage,
ProgressCallback: nil, //remote.DefaultProgressCallback(remote.FormatSimple),
WsID: c.config.WsID, // Add websocket ID
ClientVersion: c.config.ClientVersion,
ProgressCallback: nil, //remote.DefaultProgressCallback(remote.FormatSimple),
ProgressConfig: &remote.ProgressConfig{
Enabled: true,
Format: remote.FormatSimple,
Expand Down
Loading