Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion flake.nix
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@

src = ./.;

vendorHash = "sha256-3eyTM8oYEkmAHshFGDTrbVWU106zvP48nDnrGtAta9M=";
vendorHash = "sha256-PE3kbfJQlvUeSPmLawxtVnqTEz+6EI+TS8dc7jphl7w=";

subPackages = [ "cmd/roborev" ];

Expand Down
1 change: 1 addition & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ require (
github.com/charmbracelet/lipgloss v1.1.1-0.20250404203927-76690c660834
github.com/charmbracelet/x/ansi v0.11.6
github.com/coder/acp-go-sdk v0.6.3
github.com/coreos/go-systemd/v22 v22.7.0
github.com/fsnotify/fsnotify v1.9.0
github.com/google/go-cmp v0.7.0
github.com/google/uuid v1.6.0
Expand Down
2 changes: 2 additions & 0 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,8 @@ github.com/clipperhouse/uax29/v2 v2.5.0 h1:x7T0T4eTHDONxFJsL94uKNKPHrclyFI0lm7+w
github.com/clipperhouse/uax29/v2 v2.5.0/go.mod h1:Wn1g7MK6OoeDT0vL+Q0SQLDz/KpfsVRgg6W7ihQeh4g=
github.com/coder/acp-go-sdk v0.6.3 h1:LsXQytehdjKIYJnoVWON/nf7mqbiarnyuyE3rrjBsXQ=
github.com/coder/acp-go-sdk v0.6.3/go.mod h1:yKzM/3R9uELp4+nBAwwtkS0aN1FOFjo11CNPy37yFko=
github.com/coreos/go-systemd/v22 v22.7.0 h1:LAEzFkke61DFROc7zNLX/WA2i5J8gYqe0rSj9KI28KA=
github.com/coreos/go-systemd/v22 v22.7.0/go.mod h1:xNUYtjHu2EDXbsxz1i41wouACIwT7Ybq9o0BQhMwD0w=
github.com/cpuguy83/go-md2man/v2 v2.0.6/go.mod h1:oOW0eioCTA6cOiMLiUPZOpcVxMig6NIQQ7OS05n1F4g=
github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E=
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
Expand Down
26 changes: 22 additions & 4 deletions internal/daemon/runtime.go
Original file line number Diff line number Diff line change
Expand Up @@ -435,7 +435,7 @@ func KillDaemon(info *RuntimeInfo) bool {

// CleanupZombieDaemons finds and kills all unresponsive daemons.
// Returns the number of zombies cleaned up.
func CleanupZombieDaemons() int {
func CleanupZombieDaemons(target DaemonEndpoint) int {
runtimes, err := ListAllRuntimes()
if err != nil {
return 0
Expand All @@ -448,7 +448,10 @@ func CleanupZombieDaemons() int {
// For Unix sockets, check PID liveness first to avoid slow HTTP probes
// against sockets whose owner process is already dead.
if ep.IsUnix() && info.PID > 0 && !isProcessAlive(info.PID) {
os.Remove(ep.Address)
if ep.Address != target.Address {
// Clean up non-matching sockets.
os.Remove(ep.Address)
}
if info.SourcePath != "" {
os.Remove(info.SourcePath)
} else {
Expand All @@ -463,8 +466,23 @@ func CleanupZombieDaemons() int {
continue
}

// Unresponsive - try to kill it
if KillDaemon(info) {
// Unresponsive — try to kill it. When the zombie's
// socket matches the target (e.g. a systemd-managed
// socket we're about to serve on), kill the process
// and clean up the runtime file but preserve the socket.
if ep.IsUnix() && ep.Address == target.Address {
if info.PID > 0 && !killProcess(info.PID) {
// Could not confirm kill; leave runtime
// metadata so the next attempt can retry.
continue
}
if info.SourcePath != "" {
os.Remove(info.SourcePath)
} else if info.PID > 0 {
RemoveRuntimeForPID(info.PID)
}
cleaned++
} else if KillDaemon(info) {
cleaned++
}
}
Expand Down
54 changes: 54 additions & 0 deletions internal/daemon/runtime_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -492,6 +492,60 @@ func TestIsDaemonAliveLegacyStatusCodes(t *testing.T) {
}
}

func TestCleanupZombieDaemonsPreservesTargetSocket(t *testing.T) {
// Regression test: when a zombie's socket matches the target
// (e.g. a systemd-managed socket), cleanup must remove the
// runtime file but preserve the socket — even when killProcess
// returns true because the PID was reused by a non-roborev
// process.
if runtime.GOOS == "windows" {
t.Skip("Unix sockets not supported on Windows")
}

dataDir := testenv.SetDataDir(t)
assert := assert.New(t)

// Create a real Unix socket as the "target" (stands in for the
// systemd-managed socket). Use a short path to stay under the
// Unix socket name length limit on macOS.
socketDir, err := os.MkdirTemp("/tmp", "rr-test-*")
require.NoError(t, err)
t.Cleanup(func() { os.RemoveAll(socketDir) })
socketPath := filepath.Join(socketDir, "d.sock")
ln, err := net.Listen("unix", socketPath)
require.NoError(t, err)
defer ln.Close()

target := DaemonEndpoint{Network: "unix", Address: socketPath}

// Write a stale runtime file that points at the target socket.
// Use the current PID so isProcessAlive returns true (the PID
// exists), but mock identifyProcess to say it's not roborev
// (simulating PID reuse).
stalePID := os.Getpid()
runtimeJSON, err := json.Marshal(map[string]any{
"pid": stalePID,
"addr": socketPath,
"port": 0,
"network": "unix",
"version": "stale",
})
require.NoError(t, err)
runtimePath := filepath.Join(
dataDir, fmt.Sprintf("daemon.%d.json", stalePID))
require.NoError(t, os.WriteFile(runtimePath, runtimeJSON, 0644))

mockIdentifyProcess(t, func(pid int) processIdentity {
return processNotRoborev
})

cleaned := CleanupZombieDaemons(target)

assert.Equal(1, cleaned, "should count stale daemon as cleaned")
assert.NoFileExists(runtimePath, "runtime file should be removed")
assert.FileExists(socketPath, "target socket must be preserved")
}

func TestRuntimeInfo_Endpoint(t *testing.T) {
assert := assert.New(t)

Expand Down
198 changes: 138 additions & 60 deletions internal/daemon/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,9 @@ import (
"sync"
"time"

"github.com/coreos/go-systemd/v22/activation"
"github.com/coreos/go-systemd/v22/daemon"

"github.com/roborev-dev/roborev/internal/agent"
"github.com/roborev-dev/roborev/internal/config"
"github.com/roborev-dev/roborev/internal/git"
Expand All @@ -27,19 +30,20 @@ import (

// Server is the HTTP API server for the daemon
type Server struct {
db *storage.DB
configWatcher *ConfigWatcher
broadcaster Broadcaster
workerPool *WorkerPool
httpServer *http.Server
syncWorker *storage.SyncWorker
ciPoller *CIPoller
hookRunner *HookRunner
errorLog *ErrorLog
activityLog *ActivityLog
startTime time.Time
endpointMu sync.Mutex // protects endpoint (written by Start, read by Stop)
endpoint DaemonEndpoint
db *storage.DB
configWatcher *ConfigWatcher
broadcaster Broadcaster
workerPool *WorkerPool
httpServer *http.Server
syncWorker *storage.SyncWorker
ciPoller *CIPoller
hookRunner *HookRunner
errorLog *ErrorLog
activityLog *ActivityLog
startTime time.Time
endpointMu sync.Mutex // protects endpoint (written by Start, read by Stop)
endpoint DaemonEndpoint
socketActivated bool // true if started via systemd socket activation

// Cached machine ID to avoid INSERT on every status request
machineIDMu sync.Mutex
Expand Down Expand Up @@ -124,13 +128,23 @@ func NewServer(db *storage.DB, cfg *config.Config, configPath string) *Server {
func (s *Server) Start(ctx context.Context) error {
cfg := s.configWatcher.Config()

ep, err := ParseEndpoint(cfg.ServerAddr)
// Check for socket activation before falling back to the config
listener, ep, err := getSystemdListener()
if err != nil {
return err
}
if listener != nil {
s.socketActivated = true
log.Printf("Using systemd socket activation on %s", ep)
} else {
ep, err = ParseEndpoint(cfg.ServerAddr)
if err != nil {
return err
}
}

// Clean up any zombie daemons first (there can be only one)
if cleaned := CleanupZombieDaemons(); cleaned > 0 {
if cleaned := CleanupZombieDaemons(ep); cleaned > 0 {
log.Printf("Cleaned up %d zombie daemon(s)", cleaned)
if s.activityLog != nil {
s.activityLog.Log(
Expand All @@ -143,6 +157,9 @@ func (s *Server) Start(ctx context.Context) error {

// Check if a responsive daemon is still running after cleanup
if info, err := GetAnyRunningDaemon(); err == nil && IsDaemonAlive(info.Endpoint()) {
if listener != nil {
_ = listener.Close()
}
return fmt.Errorf("daemon already running (pid %d on %s)", info.PID, info.Addr)
}

Expand All @@ -157,53 +174,54 @@ func (s *Server) Start(ctx context.Context) error {
// Continue without hot-reloading - not a fatal error
}

// Bind the listener before publishing runtime metadata so concurrent CLI
// invocations cannot race a half-started daemon and kill it as a zombie.
var listener net.Listener
if ep.IsUnix() {
socketPath := ep.Address
socketDir := filepath.Dir(socketPath)
if err := os.MkdirAll(socketDir, 0700); err != nil {
s.configWatcher.Stop()
return fmt.Errorf("create socket directory: %w", err)
}
// Verify the parent directory has safe permissions (owner-only)
if fi, err := os.Stat(socketDir); err == nil {
if perm := fi.Mode().Perm(); perm&0077 != 0 {
if !s.socketActivated {
// Bind the listener before publishing runtime metadata so concurrent CLI
// invocations cannot race a half-started daemon and kill it as a zombie.
if ep.IsUnix() {
socketPath := ep.Address
socketDir := filepath.Dir(socketPath)
if err := os.MkdirAll(socketDir, 0700); err != nil {
s.configWatcher.Stop()
return fmt.Errorf("socket directory %s has unsafe permissions %o (must not be group/world accessible)", socketDir, perm)
return fmt.Errorf("create socket directory: %w", err)
}
}
// Remove stale socket from a previous run
os.Remove(socketPath)
listener, err = ep.Listener()
if err != nil {
s.configWatcher.Stop()
return fmt.Errorf("listen on %s: %w", ep, err)
}
if err := os.Chmod(socketPath, 0600); err != nil {
_ = listener.Close()
s.configWatcher.Stop()
return fmt.Errorf("chmod socket: %w", err)
}
} else {
// TCP: find an available port first
addr, _, err := FindAvailablePort(ep.Address)
if err != nil {
s.configWatcher.Stop()
return fmt.Errorf("find available port: %w", err)
}
ep = DaemonEndpoint{Network: "tcp", Address: addr}
s.httpServer.Addr = addr
// Verify the parent directory has safe permissions (owner-only)
if fi, err := os.Stat(socketDir); err == nil {
if perm := fi.Mode().Perm(); perm&0077 != 0 {
s.configWatcher.Stop()
return fmt.Errorf("socket directory %s has unsafe permissions %o (must not be group/world accessible)", socketDir, perm)
}
}
// Remove stale socket from a previous run
os.Remove(socketPath)
listener, err = ep.Listener()
if err != nil {
s.configWatcher.Stop()
return fmt.Errorf("listen on %s: %w", ep, err)
}
if err := os.Chmod(socketPath, 0600); err != nil {
_ = listener.Close()
s.configWatcher.Stop()
return fmt.Errorf("chmod socket: %w", err)
}
} else {
// TCP: find an available port first
addr, _, err := FindAvailablePort(ep.Address)
if err != nil {
s.configWatcher.Stop()
return fmt.Errorf("find available port: %w", err)
}
ep = DaemonEndpoint{Network: "tcp", Address: addr}
s.httpServer.Addr = addr

listener, err = ep.Listener()
if err != nil {
s.configWatcher.Stop()
return fmt.Errorf("listen on %s: %w", ep, err)
listener, err = ep.Listener()
if err != nil {
s.configWatcher.Stop()
return fmt.Errorf("listen on %s: %w", ep, err)
}
// Update ep with actual bound address
ep = DaemonEndpoint{Network: "tcp", Address: listener.Addr().String()}
s.httpServer.Addr = ep.Address
}
// Update ep with actual bound address
ep = DaemonEndpoint{Network: "tcp", Address: listener.Addr().String()}
s.httpServer.Addr = ep.Address
}

s.endpointMu.Lock()
Expand Down Expand Up @@ -240,6 +258,10 @@ func (s *Server) Start(ctx context.Context) error {
log.Printf("Warning: failed to write runtime info: %v", err)
}

// Notify systemd that the daemon is ready. No-op when not running
// under systemd (NOTIFY_SOCKET is unset).
_, _ = daemon.SdNotify(false, daemon.SdNotifyReady)

// Log daemon start after runtime publication.
if s.activityLog != nil {
binary, _ := os.Executable()
Expand Down Expand Up @@ -342,6 +364,62 @@ func logHookWarnings(repos []storage.Repo) {
}
}

// getSystemdListener returns the listener and endpoint passed by systemd during
// socket activation, or (nil, empty, nil) if not running under socket activation.
// Validates the listener matches the daemon's local-only trust model.
func getSystemdListener() (net.Listener, DaemonEndpoint, error) {
listeners, err := activation.Listeners()
if err != nil {
return nil, DaemonEndpoint{}, fmt.Errorf("socket activation: %w", err)
}
if len(listeners) == 0 {
return nil, DaemonEndpoint{}, nil
}
if len(listeners) > 1 {
return nil, DaemonEndpoint{}, fmt.Errorf(
"socket activation: multiple sockets not supported")
}

listener := listeners[0]
if listener == nil {
return nil, DaemonEndpoint{}, fmt.Errorf(
"socket activation: unsupported socket type")
}
addr := listener.Addr().String()
if listener.Addr().Network() == "unix" {
if strings.HasPrefix(addr, "@") || strings.HasPrefix(addr, "\x00") {
_ = listener.Close()
return nil, DaemonEndpoint{}, fmt.Errorf(
"socket activation: abstract Unix sockets are not supported"+
" (got %q); use a filesystem path in ListenStream=", addr)
}
addr = "unix://" + addr
}
ep, err := ParseEndpoint(addr)
if err != nil {
// Errors on non-localhost, etc.
_ = listener.Close()
return nil, ep, err
}

// Ensure that Unix sockets have safe permissions.
if ep.IsUnix() {
fi, err := os.Stat(ep.Address)
if err != nil {
_ = listener.Close()
return nil, ep, fmt.Errorf("socket activation: %w", err)
}
if perm := fi.Mode().Perm(); perm&0077 != 0 {
_ = listener.Close()
return nil, ep, fmt.Errorf(
"socket activation: socket %q has unsafe permissions: %04o",
ep.Address, perm)
}
}

return listener, ep, nil
}

// Stop gracefully shuts down the server
func (s *Server) Stop() error {
// Log daemon stop before shutting down components
Expand All @@ -368,11 +446,11 @@ func (s *Server) Stop() error {
log.Printf("HTTP server shutdown error: %v", err)
}

// Clean up Unix domain socket
// Clean up Unix domain socket (if we created it)
s.endpointMu.Lock()
ep := s.endpoint
s.endpointMu.Unlock()
if ep.IsUnix() {
if ep.IsUnix() && !s.socketActivated {
os.Remove(ep.Address)
}

Expand Down
Loading