diff --git a/flake.nix b/flake.nix index a9690339..7180c0fd 100644 --- a/flake.nix +++ b/flake.nix @@ -19,7 +19,7 @@ src = ./.; - vendorHash = "sha256-3eyTM8oYEkmAHshFGDTrbVWU106zvP48nDnrGtAta9M="; + vendorHash = "sha256-PE3kbfJQlvUeSPmLawxtVnqTEz+6EI+TS8dc7jphl7w="; subPackages = [ "cmd/roborev" ]; diff --git a/go.mod b/go.mod index 43af4000..1f1cbf82 100644 --- a/go.mod +++ b/go.mod @@ -10,6 +10,7 @@ require ( github.com/charmbracelet/lipgloss v1.1.1-0.20250404203927-76690c660834 github.com/charmbracelet/x/ansi v0.11.6 github.com/coder/acp-go-sdk v0.6.3 + github.com/coreos/go-systemd/v22 v22.7.0 github.com/fsnotify/fsnotify v1.9.0 github.com/google/go-cmp v0.7.0 github.com/google/uuid v1.6.0 diff --git a/go.sum b/go.sum index 71a4a896..858bf968 100644 --- a/go.sum +++ b/go.sum @@ -40,6 +40,8 @@ github.com/clipperhouse/uax29/v2 v2.5.0 h1:x7T0T4eTHDONxFJsL94uKNKPHrclyFI0lm7+w github.com/clipperhouse/uax29/v2 v2.5.0/go.mod h1:Wn1g7MK6OoeDT0vL+Q0SQLDz/KpfsVRgg6W7ihQeh4g= github.com/coder/acp-go-sdk v0.6.3 h1:LsXQytehdjKIYJnoVWON/nf7mqbiarnyuyE3rrjBsXQ= github.com/coder/acp-go-sdk v0.6.3/go.mod h1:yKzM/3R9uELp4+nBAwwtkS0aN1FOFjo11CNPy37yFko= +github.com/coreos/go-systemd/v22 v22.7.0 h1:LAEzFkke61DFROc7zNLX/WA2i5J8gYqe0rSj9KI28KA= +github.com/coreos/go-systemd/v22 v22.7.0/go.mod h1:xNUYtjHu2EDXbsxz1i41wouACIwT7Ybq9o0BQhMwD0w= github.com/cpuguy83/go-md2man/v2 v2.0.6/go.mod h1:oOW0eioCTA6cOiMLiUPZOpcVxMig6NIQQ7OS05n1F4g= github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= diff --git a/internal/daemon/runtime.go b/internal/daemon/runtime.go index c814548e..42d68b1e 100644 --- a/internal/daemon/runtime.go +++ b/internal/daemon/runtime.go @@ -435,7 +435,7 @@ func KillDaemon(info *RuntimeInfo) bool { // CleanupZombieDaemons finds and kills all unresponsive daemons. // Returns the number of zombies cleaned up. -func CleanupZombieDaemons() int { +func CleanupZombieDaemons(target DaemonEndpoint) int { runtimes, err := ListAllRuntimes() if err != nil { return 0 @@ -448,7 +448,10 @@ func CleanupZombieDaemons() int { // For Unix sockets, check PID liveness first to avoid slow HTTP probes // against sockets whose owner process is already dead. if ep.IsUnix() && info.PID > 0 && !isProcessAlive(info.PID) { - os.Remove(ep.Address) + if ep.Address != target.Address { + // Clean up non-matching sockets. + os.Remove(ep.Address) + } if info.SourcePath != "" { os.Remove(info.SourcePath) } else { @@ -463,8 +466,23 @@ func CleanupZombieDaemons() int { continue } - // Unresponsive - try to kill it - if KillDaemon(info) { + // Unresponsive — try to kill it. When the zombie's + // socket matches the target (e.g. a systemd-managed + // socket we're about to serve on), kill the process + // and clean up the runtime file but preserve the socket. + if ep.IsUnix() && ep.Address == target.Address { + if info.PID > 0 && !killProcess(info.PID) { + // Could not confirm kill; leave runtime + // metadata so the next attempt can retry. + continue + } + if info.SourcePath != "" { + os.Remove(info.SourcePath) + } else if info.PID > 0 { + RemoveRuntimeForPID(info.PID) + } + cleaned++ + } else if KillDaemon(info) { cleaned++ } } diff --git a/internal/daemon/runtime_test.go b/internal/daemon/runtime_test.go index 0e52ba60..b0e38e43 100644 --- a/internal/daemon/runtime_test.go +++ b/internal/daemon/runtime_test.go @@ -492,6 +492,60 @@ func TestIsDaemonAliveLegacyStatusCodes(t *testing.T) { } } +func TestCleanupZombieDaemonsPreservesTargetSocket(t *testing.T) { + // Regression test: when a zombie's socket matches the target + // (e.g. a systemd-managed socket), cleanup must remove the + // runtime file but preserve the socket — even when killProcess + // returns true because the PID was reused by a non-roborev + // process. + if runtime.GOOS == "windows" { + t.Skip("Unix sockets not supported on Windows") + } + + dataDir := testenv.SetDataDir(t) + assert := assert.New(t) + + // Create a real Unix socket as the "target" (stands in for the + // systemd-managed socket). Use a short path to stay under the + // Unix socket name length limit on macOS. + socketDir, err := os.MkdirTemp("/tmp", "rr-test-*") + require.NoError(t, err) + t.Cleanup(func() { os.RemoveAll(socketDir) }) + socketPath := filepath.Join(socketDir, "d.sock") + ln, err := net.Listen("unix", socketPath) + require.NoError(t, err) + defer ln.Close() + + target := DaemonEndpoint{Network: "unix", Address: socketPath} + + // Write a stale runtime file that points at the target socket. + // Use the current PID so isProcessAlive returns true (the PID + // exists), but mock identifyProcess to say it's not roborev + // (simulating PID reuse). + stalePID := os.Getpid() + runtimeJSON, err := json.Marshal(map[string]any{ + "pid": stalePID, + "addr": socketPath, + "port": 0, + "network": "unix", + "version": "stale", + }) + require.NoError(t, err) + runtimePath := filepath.Join( + dataDir, fmt.Sprintf("daemon.%d.json", stalePID)) + require.NoError(t, os.WriteFile(runtimePath, runtimeJSON, 0644)) + + mockIdentifyProcess(t, func(pid int) processIdentity { + return processNotRoborev + }) + + cleaned := CleanupZombieDaemons(target) + + assert.Equal(1, cleaned, "should count stale daemon as cleaned") + assert.NoFileExists(runtimePath, "runtime file should be removed") + assert.FileExists(socketPath, "target socket must be preserved") +} + func TestRuntimeInfo_Endpoint(t *testing.T) { assert := assert.New(t) diff --git a/internal/daemon/server.go b/internal/daemon/server.go index eaa7f928..fc6866ce 100644 --- a/internal/daemon/server.go +++ b/internal/daemon/server.go @@ -17,6 +17,9 @@ import ( "sync" "time" + "github.com/coreos/go-systemd/v22/activation" + "github.com/coreos/go-systemd/v22/daemon" + "github.com/roborev-dev/roborev/internal/agent" "github.com/roborev-dev/roborev/internal/config" "github.com/roborev-dev/roborev/internal/git" @@ -27,19 +30,20 @@ import ( // Server is the HTTP API server for the daemon type Server struct { - db *storage.DB - configWatcher *ConfigWatcher - broadcaster Broadcaster - workerPool *WorkerPool - httpServer *http.Server - syncWorker *storage.SyncWorker - ciPoller *CIPoller - hookRunner *HookRunner - errorLog *ErrorLog - activityLog *ActivityLog - startTime time.Time - endpointMu sync.Mutex // protects endpoint (written by Start, read by Stop) - endpoint DaemonEndpoint + db *storage.DB + configWatcher *ConfigWatcher + broadcaster Broadcaster + workerPool *WorkerPool + httpServer *http.Server + syncWorker *storage.SyncWorker + ciPoller *CIPoller + hookRunner *HookRunner + errorLog *ErrorLog + activityLog *ActivityLog + startTime time.Time + endpointMu sync.Mutex // protects endpoint (written by Start, read by Stop) + endpoint DaemonEndpoint + socketActivated bool // true if started via systemd socket activation // Cached machine ID to avoid INSERT on every status request machineIDMu sync.Mutex @@ -124,13 +128,23 @@ func NewServer(db *storage.DB, cfg *config.Config, configPath string) *Server { func (s *Server) Start(ctx context.Context) error { cfg := s.configWatcher.Config() - ep, err := ParseEndpoint(cfg.ServerAddr) + // Check for socket activation before falling back to the config + listener, ep, err := getSystemdListener() if err != nil { return err } + if listener != nil { + s.socketActivated = true + log.Printf("Using systemd socket activation on %s", ep) + } else { + ep, err = ParseEndpoint(cfg.ServerAddr) + if err != nil { + return err + } + } // Clean up any zombie daemons first (there can be only one) - if cleaned := CleanupZombieDaemons(); cleaned > 0 { + if cleaned := CleanupZombieDaemons(ep); cleaned > 0 { log.Printf("Cleaned up %d zombie daemon(s)", cleaned) if s.activityLog != nil { s.activityLog.Log( @@ -143,6 +157,9 @@ func (s *Server) Start(ctx context.Context) error { // Check if a responsive daemon is still running after cleanup if info, err := GetAnyRunningDaemon(); err == nil && IsDaemonAlive(info.Endpoint()) { + if listener != nil { + _ = listener.Close() + } return fmt.Errorf("daemon already running (pid %d on %s)", info.PID, info.Addr) } @@ -157,53 +174,54 @@ func (s *Server) Start(ctx context.Context) error { // Continue without hot-reloading - not a fatal error } - // Bind the listener before publishing runtime metadata so concurrent CLI - // invocations cannot race a half-started daemon and kill it as a zombie. - var listener net.Listener - if ep.IsUnix() { - socketPath := ep.Address - socketDir := filepath.Dir(socketPath) - if err := os.MkdirAll(socketDir, 0700); err != nil { - s.configWatcher.Stop() - return fmt.Errorf("create socket directory: %w", err) - } - // Verify the parent directory has safe permissions (owner-only) - if fi, err := os.Stat(socketDir); err == nil { - if perm := fi.Mode().Perm(); perm&0077 != 0 { + if !s.socketActivated { + // Bind the listener before publishing runtime metadata so concurrent CLI + // invocations cannot race a half-started daemon and kill it as a zombie. + if ep.IsUnix() { + socketPath := ep.Address + socketDir := filepath.Dir(socketPath) + if err := os.MkdirAll(socketDir, 0700); err != nil { s.configWatcher.Stop() - return fmt.Errorf("socket directory %s has unsafe permissions %o (must not be group/world accessible)", socketDir, perm) + return fmt.Errorf("create socket directory: %w", err) } - } - // Remove stale socket from a previous run - os.Remove(socketPath) - listener, err = ep.Listener() - if err != nil { - s.configWatcher.Stop() - return fmt.Errorf("listen on %s: %w", ep, err) - } - if err := os.Chmod(socketPath, 0600); err != nil { - _ = listener.Close() - s.configWatcher.Stop() - return fmt.Errorf("chmod socket: %w", err) - } - } else { - // TCP: find an available port first - addr, _, err := FindAvailablePort(ep.Address) - if err != nil { - s.configWatcher.Stop() - return fmt.Errorf("find available port: %w", err) - } - ep = DaemonEndpoint{Network: "tcp", Address: addr} - s.httpServer.Addr = addr + // Verify the parent directory has safe permissions (owner-only) + if fi, err := os.Stat(socketDir); err == nil { + if perm := fi.Mode().Perm(); perm&0077 != 0 { + s.configWatcher.Stop() + return fmt.Errorf("socket directory %s has unsafe permissions %o (must not be group/world accessible)", socketDir, perm) + } + } + // Remove stale socket from a previous run + os.Remove(socketPath) + listener, err = ep.Listener() + if err != nil { + s.configWatcher.Stop() + return fmt.Errorf("listen on %s: %w", ep, err) + } + if err := os.Chmod(socketPath, 0600); err != nil { + _ = listener.Close() + s.configWatcher.Stop() + return fmt.Errorf("chmod socket: %w", err) + } + } else { + // TCP: find an available port first + addr, _, err := FindAvailablePort(ep.Address) + if err != nil { + s.configWatcher.Stop() + return fmt.Errorf("find available port: %w", err) + } + ep = DaemonEndpoint{Network: "tcp", Address: addr} + s.httpServer.Addr = addr - listener, err = ep.Listener() - if err != nil { - s.configWatcher.Stop() - return fmt.Errorf("listen on %s: %w", ep, err) + listener, err = ep.Listener() + if err != nil { + s.configWatcher.Stop() + return fmt.Errorf("listen on %s: %w", ep, err) + } + // Update ep with actual bound address + ep = DaemonEndpoint{Network: "tcp", Address: listener.Addr().String()} + s.httpServer.Addr = ep.Address } - // Update ep with actual bound address - ep = DaemonEndpoint{Network: "tcp", Address: listener.Addr().String()} - s.httpServer.Addr = ep.Address } s.endpointMu.Lock() @@ -240,6 +258,10 @@ func (s *Server) Start(ctx context.Context) error { log.Printf("Warning: failed to write runtime info: %v", err) } + // Notify systemd that the daemon is ready. No-op when not running + // under systemd (NOTIFY_SOCKET is unset). + _, _ = daemon.SdNotify(false, daemon.SdNotifyReady) + // Log daemon start after runtime publication. if s.activityLog != nil { binary, _ := os.Executable() @@ -342,6 +364,62 @@ func logHookWarnings(repos []storage.Repo) { } } +// getSystemdListener returns the listener and endpoint passed by systemd during +// socket activation, or (nil, empty, nil) if not running under socket activation. +// Validates the listener matches the daemon's local-only trust model. +func getSystemdListener() (net.Listener, DaemonEndpoint, error) { + listeners, err := activation.Listeners() + if err != nil { + return nil, DaemonEndpoint{}, fmt.Errorf("socket activation: %w", err) + } + if len(listeners) == 0 { + return nil, DaemonEndpoint{}, nil + } + if len(listeners) > 1 { + return nil, DaemonEndpoint{}, fmt.Errorf( + "socket activation: multiple sockets not supported") + } + + listener := listeners[0] + if listener == nil { + return nil, DaemonEndpoint{}, fmt.Errorf( + "socket activation: unsupported socket type") + } + addr := listener.Addr().String() + if listener.Addr().Network() == "unix" { + if strings.HasPrefix(addr, "@") || strings.HasPrefix(addr, "\x00") { + _ = listener.Close() + return nil, DaemonEndpoint{}, fmt.Errorf( + "socket activation: abstract Unix sockets are not supported"+ + " (got %q); use a filesystem path in ListenStream=", addr) + } + addr = "unix://" + addr + } + ep, err := ParseEndpoint(addr) + if err != nil { + // Errors on non-localhost, etc. + _ = listener.Close() + return nil, ep, err + } + + // Ensure that Unix sockets have safe permissions. + if ep.IsUnix() { + fi, err := os.Stat(ep.Address) + if err != nil { + _ = listener.Close() + return nil, ep, fmt.Errorf("socket activation: %w", err) + } + if perm := fi.Mode().Perm(); perm&0077 != 0 { + _ = listener.Close() + return nil, ep, fmt.Errorf( + "socket activation: socket %q has unsafe permissions: %04o", + ep.Address, perm) + } + } + + return listener, ep, nil +} + // Stop gracefully shuts down the server func (s *Server) Stop() error { // Log daemon stop before shutting down components @@ -368,11 +446,11 @@ func (s *Server) Stop() error { log.Printf("HTTP server shutdown error: %v", err) } - // Clean up Unix domain socket + // Clean up Unix domain socket (if we created it) s.endpointMu.Lock() ep := s.endpoint s.endpointMu.Unlock() - if ep.IsUnix() { + if ep.IsUnix() && !s.socketActivated { os.Remove(ep.Address) }