From c9e27d10cd6865793cc38d80cff3143a3ac1c063 Mon Sep 17 00:00:00 2001 From: Peter Dedene Date: Thu, 19 Feb 2026 13:39:11 +0100 Subject: [PATCH 1/2] feat(pty): add PTY support for interactive TUI applications Enable interactive TUI apps (vim, htop, etc.) to work correctly through claw-wrap with proper colors and cursor key handling. Changes: - Add creack/pty and golang.org/x/term dependencies - Extend protocol with WinSize, UsePTY, and window size messages - Include PTY flag in HMAC signature to prevent downgrade attacks - Add use_pty config option for tools (opt-in per tool) - Implement PTY allocation in daemon executor with SIGWINCH forwarding - Add raw terminal mode and window size tracking in wrapper - Wrapper always requests PTY when terminal attached; daemon decides based on tool config (graceful fallback to pipes) Security: PTY flag signed in HMAC, process group isolation preserved, output redaction applies to PTY stream, audit logging handles PTY output. --- go.mod | 4 +- go.sum | 6 + internal/auth/auth.go | 26 +++- internal/auth/auth_test.go | 79 ++++++++++ internal/config/config.go | 1 + internal/config/config_test.go | 60 ++++++++ internal/daemon/daemon.go | 2 +- internal/daemon/executor.go | 256 ++++++++++++++++++++++++++++++--- internal/protocol/protocol.go | 31 ++-- internal/wrapper/wrapper.go | 183 +++++++++++++++++++++-- 10 files changed, 598 insertions(+), 50 deletions(-) diff --git a/go.mod b/go.mod index 8dd27d8..efe4c25 100644 --- a/go.mod +++ b/go.mod @@ -4,9 +4,11 @@ go 1.24.4 require ( filippo.io/age v1.3.1 + github.com/creack/pty v1.1.24 github.com/elazarl/goproxy v1.8.1 github.com/itchyny/gojq v0.12.18 - golang.org/x/sys v0.40.0 + golang.org/x/sys v0.41.0 + golang.org/x/term v0.40.0 gopkg.in/yaml.v3 v3.0.1 ) diff --git a/go.sum b/go.sum index bdc097d..65dabdf 100644 --- a/go.sum +++ b/go.sum @@ -6,6 +6,8 @@ filippo.io/hpke v0.4.0 h1:p575VVQ6ted4pL+it6M00V/f2qTZITO0zgmdKCkd5+A= filippo.io/hpke v0.4.0/go.mod h1:EmAN849/P3qdeK+PCMkDpDm83vRHM5cDipBJ8xbQLVY= github.com/coder/websocket v1.8.14 h1:9L0p0iKiNOibykf283eHkKUHHrpG7f65OE3BhhO7v9g= github.com/coder/websocket v1.8.14/go.mod h1:NX3SzP+inril6yawo5CQXx8+fk145lPDC6pumgx0mVg= +github.com/creack/pty v1.1.24 h1:bJrF4RRfyJnbTJqzRLHzcGaZK1NeM5kTC9jGgovnR1s= +github.com/creack/pty v1.1.24/go.mod h1:08sCNb52WyoAwi2QDyzUCTgcvVFhUzewun7wtTfvcwE= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/elazarl/goproxy v1.8.1 h1:/qGpPJGgIPOTZ7IoIQvjavocp//qYSe9LQnIGCgRY5k= @@ -24,6 +26,10 @@ golang.org/x/net v0.47.0 h1:Mx+4dIFzqraBXUugkia1OOvlD6LemFo1ALMHjrXDOhY= golang.org/x/net v0.47.0/go.mod h1:/jNxtkgq5yWUGYkaZGqo27cfGZ1c5Nen03aYrrKpVRU= golang.org/x/sys v0.40.0 h1:DBZZqJ2Rkml6QMQsZywtnjnnGvHza6BTfYFWY9kjEWQ= golang.org/x/sys v0.40.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks= +golang.org/x/sys v0.41.0 h1:Ivj+2Cp/ylzLiEU89QhWblYnOE9zerudt9Ftecq2C6k= +golang.org/x/sys v0.41.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks= +golang.org/x/term v0.40.0 h1:36e4zGLqU4yhjlmxEaagx2KuYbJq3EwY8K943ZsHcvg= +golang.org/x/term v0.40.0/go.mod h1:w2P8uVp06p2iyKKuvXIm7N/y0UCRt3UfJTfZ7oOpglM= golang.org/x/text v0.31.0 h1:aC8ghyu4JhP8VojJ2lEHBnochRno1sgL6nEi9WGFGMM= golang.org/x/text v0.31.0/go.mod h1:tKRAlv61yKIjGGHX/4tP1LTbc13YSec1pxVEWXzfoeM= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= diff --git a/internal/auth/auth.go b/internal/auth/auth.go index c5b3a60..f685ac3 100644 --- a/internal/auth/auth.go +++ b/internal/auth/auth.go @@ -161,6 +161,14 @@ func ComputeHMAC(secret []byte, timestamp, tool, cwd string, args []string, nonc // timestamp + tool + json(args) + cwd + json(env canonical) + nonce. // Fields are separated by newlines to prevent boundary confusion. func ComputeHMACWithEnv(secret []byte, timestamp, tool, cwd string, args []string, env map[string]string, nonce string) (string, error) { + return ComputeHMACWithPTY(secret, timestamp, tool, cwd, args, env, nonce, false) +} + +// ComputeHMACWithPTY computes an HMAC-SHA256 signature including the PTY flag. +// Message format: timestamp + tool + json(args) + cwd + json(env canonical) + nonce + ptyFlag +// where ptyFlag is "1" if usePTY is true, "0" otherwise. +// Fields are separated by newlines to prevent boundary confusion. +func ComputeHMACWithPTY(secret []byte, timestamp, tool, cwd string, args []string, env map[string]string, nonce string, usePTY bool) (string, error) { // Serialize args as JSON for consistent encoding argsJSON, err := json.Marshal(args) if err != nil { @@ -172,8 +180,14 @@ func ComputeHMACWithEnv(secret []byte, timestamp, tool, cwd string, args []strin return "", fmt.Errorf("failed to marshal env: %w", err) } + // PTY flag as string for consistent encoding + ptyFlag := "0" + if usePTY { + ptyFlag = "1" + } + // Build the message to sign — fields separated by \n to prevent boundary confusion - message := timestamp + "\n" + tool + "\n" + string(argsJSON) + "\n" + cwd + "\n" + envJSON + "\n" + nonce + message := timestamp + "\n" + tool + "\n" + string(argsJSON) + "\n" + cwd + "\n" + envJSON + "\n" + nonce + "\n" + ptyFlag // Compute HMAC-SHA256 mac := hmac.New(sha256.New, secret) @@ -193,14 +207,22 @@ func VerifyHMAC(secret []byte, timestamp, tool, cwd string, args []string, nonce } // VerifyHMACWithEnv verifies the provided HMAC signature with env included. +// For backward compatibility, assumes usePTY=false. func VerifyHMACWithEnv(secret []byte, timestamp, tool, cwd string, args []string, env map[string]string, nonce, providedHMAC string) error { + return VerifyHMACWithPTY(secret, timestamp, tool, cwd, args, env, nonce, false, providedHMAC) +} + +// VerifyHMACWithPTY verifies the provided HMAC signature with PTY flag included. +// It uses constant-time comparison to prevent timing attacks and validates +// that the timestamp is within the allowed freshness window. +func VerifyHMACWithPTY(secret []byte, timestamp, tool, cwd string, args []string, env map[string]string, nonce string, usePTY bool, providedHMAC string) error { // First validate timestamp freshness if err := ValidateTimestamp(timestamp); err != nil { return err } // Compute expected HMAC - expectedHMAC, err := ComputeHMACWithEnv(secret, timestamp, tool, cwd, args, env, nonce) + expectedHMAC, err := ComputeHMACWithPTY(secret, timestamp, tool, cwd, args, env, nonce, usePTY) if err != nil { return fmt.Errorf("failed to compute expected HMAC: %w", err) } diff --git a/internal/auth/auth_test.go b/internal/auth/auth_test.go index 07824cc..3e26ebf 100644 --- a/internal/auth/auth_test.go +++ b/internal/auth/auth_test.go @@ -635,3 +635,82 @@ func TestLoadSecret_RejectsSymlink(t *testing.T) { t.Errorf("LoadSecret() error = %v, want symlink-related error", err) } } + +func TestComputeHMACWithPTY_DifferentPTYFlag(t *testing.T) { + // PTY flag should change the HMAC signature + secret := []byte("test-secret-key-for-hmac-testing") + timestamp := strconv.FormatInt(time.Now().Unix(), 10) + tool := "test-tool" + cwd := "/home/user" + args := []string{"arg1", "arg2"} + env := map[string]string{"FOO": "bar"} + nonce := "dGVzdC1ub25jZS0xMjM0" // base64("test-nonce-1234") + + // Compute HMAC with PTY false + hmacNoPTY, err := ComputeHMACWithPTY(secret, timestamp, tool, cwd, args, env, nonce, false) + if err != nil { + t.Fatalf("ComputeHMACWithPTY(usePTY=false) error = %v", err) + } + + // Compute HMAC with PTY true + hmacWithPTY, err := ComputeHMACWithPTY(secret, timestamp, tool, cwd, args, env, nonce, true) + if err != nil { + t.Fatalf("ComputeHMACWithPTY(usePTY=true) error = %v", err) + } + + // Signatures must be different + if hmacNoPTY == hmacWithPTY { + t.Error("HMACs should differ based on PTY flag") + } +} + +func TestVerifyHMACWithPTY_PTYFlagMismatch(t *testing.T) { + // HMAC computed with usePTY=true should fail verification with usePTY=false + secret := []byte("test-secret-key-for-hmac-testing") + timestamp := strconv.FormatInt(time.Now().Unix(), 10) + tool := "test-tool" + cwd := "/home/user" + args := []string{"arg1"} + nonce := "dGVzdC1ub25jZS0xMjM0" + + // Compute with PTY=true + sig, err := ComputeHMACWithPTY(secret, timestamp, tool, cwd, args, nil, nonce, true) + if err != nil { + t.Fatalf("ComputeHMACWithPTY() error = %v", err) + } + + // Verify with PTY=true should pass + if err := VerifyHMACWithPTY(secret, timestamp, tool, cwd, args, nil, nonce, true, sig); err != nil { + t.Errorf("VerifyHMACWithPTY(usePTY=true) rejected valid signature: %v", err) + } + + // Verify with PTY=false should fail + if err := VerifyHMACWithPTY(secret, timestamp, tool, cwd, args, nil, nonce, false, sig); err == nil { + t.Error("VerifyHMACWithPTY(usePTY=false) should reject signature computed with usePTY=true") + } +} + +func TestComputeHMACWithEnv_BackwardCompatibleWithPTYFalse(t *testing.T) { + // ComputeHMACWithEnv should produce same result as ComputeHMACWithPTY(usePTY=false) + secret := []byte("test-secret-key-for-hmac-testing") + timestamp := strconv.FormatInt(time.Now().Unix(), 10) + tool := "test-tool" + cwd := "/home/user" + args := []string{"arg1", "arg2"} + env := map[string]string{"VAR1": "value1"} + nonce := "dGVzdC1ub25jZS0xMjM0" + + hmacEnv, err := ComputeHMACWithEnv(secret, timestamp, tool, cwd, args, env, nonce) + if err != nil { + t.Fatalf("ComputeHMACWithEnv() error = %v", err) + } + + hmacPTYFalse, err := ComputeHMACWithPTY(secret, timestamp, tool, cwd, args, env, nonce, false) + if err != nil { + t.Fatalf("ComputeHMACWithPTY() error = %v", err) + } + + if hmacEnv != hmacPTYFalse { + t.Error("ComputeHMACWithEnv should equal ComputeHMACWithPTY(usePTY=false) for backward compatibility") + } +} diff --git a/internal/config/config.go b/internal/config/config.go index cb8ebbc..5ca0940 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -183,6 +183,7 @@ type ToolDef struct { RedactOutput []ToolRedactRule `yaml:"redact_output,omitempty"` ConfigFile *ConfigFileDef `yaml:"config_file,omitempty"` UseProxy bool `yaml:"use_proxy,omitempty"` // Enable HTTP proxy for this tool + UsePTY bool `yaml:"use_pty,omitempty"` // Enable PTY mode for interactive TUI apps } // ToolRedactRule defines an output redaction rule for tool stdout/stderr. diff --git a/internal/config/config_test.go b/internal/config/config_test.go index b4db55b..ee3a5d8 100644 --- a/internal/config/config_test.go +++ b/internal/config/config_test.go @@ -946,6 +946,66 @@ func TestValidate_FullValidConfig(t *testing.T) { } } +func TestValidate_UsePTY_ValidConfig(t *testing.T) { + cfg := &Config{ + Tools: map[string]ToolDef{ + "vim": { + Binary: "/usr/bin/vim", + UsePTY: true, + }, + "grep": { + Binary: "/usr/bin/grep", + UsePTY: false, + }, + }, + } + err := cfg.Validate() + if err != nil { + t.Errorf("Validate() unexpected error for PTY config: %v", err) + } + + // Verify the flag is preserved + if !cfg.Tools["vim"].UsePTY { + t.Error("UsePTY should be true for vim") + } + if cfg.Tools["grep"].UsePTY { + t.Error("UsePTY should be false for grep") + } +} + +func TestLoad_UsePTY_FromYAML(t *testing.T) { + yaml := ` +tools: + vim: + binary: /usr/bin/vim + use_pty: true + grep: + binary: /usr/bin/grep +` + tmpFile, err := os.CreateTemp("", "config-pty-*.yaml") + if err != nil { + t.Fatalf("Failed to create temp file: %v", err) + } + defer os.Remove(tmpFile.Name()) + + if _, err := tmpFile.WriteString(yaml); err != nil { + t.Fatalf("Failed to write config: %v", err) + } + tmpFile.Close() + + cfg, err := Load(tmpFile.Name()) + if err != nil { + t.Fatalf("Load() error = %v", err) + } + + if !cfg.Tools["vim"].UsePTY { + t.Error("UsePTY should be true for vim (from YAML)") + } + if cfg.Tools["grep"].UsePTY { + t.Error("UsePTY should default to false for grep") + } +} + func TestValidate_ConfigFilePathTraversalRejected(t *testing.T) { cfg := &Config{ Credentials: map[string]CredentialDef{ diff --git a/internal/daemon/daemon.go b/internal/daemon/daemon.go index c0f4529..4c119fe 100644 --- a/internal/daemon/daemon.go +++ b/internal/daemon/daemon.go @@ -613,7 +613,7 @@ func (d *Daemon) handleProxyRequest(conn net.Conn, data []byte, cfg *config.Conf return } - if err := auth.VerifyHMACWithEnv(d.secret, req.Timestamp, req.Tool, req.Cwd, req.Args, req.Env, req.Nonce, req.HMAC); err != nil { + if err := auth.VerifyHMACWithPTY(d.secret, req.Timestamp, req.Tool, req.Cwd, req.Args, req.Env, req.Nonce, req.UsePTY, req.HMAC); err != nil { d.metrics.Inc("auth_fail") log.Printf("[WARN] deny reason=auth_failed tool=%s err=%v", req.Tool, err) d.sendProxyError(conn, "authentication failed") diff --git a/internal/daemon/executor.go b/internal/daemon/executor.go index ef926fa..a5fd6a1 100644 --- a/internal/daemon/executor.go +++ b/internal/daemon/executor.go @@ -21,6 +21,8 @@ import ( "syscall" "time" + "github.com/creack/pty" + "claw-wrap/internal/audit" "claw-wrap/internal/config" "claw-wrap/internal/credentials" @@ -145,6 +147,12 @@ type ToolExecutor struct { startTime time.Time stdoutHash hash.Hash stderrHash hash.Hash + + // PTY mode fields + usePTY bool // effective PTY mode (req.UsePTY && tool.UsePTY) + ptyMaster *os.File // PTY master fd (nil if not using PTY) + ptyBuf *OutputBuffer + ptyHash hash.Hash } // NewToolExecutor creates a new ToolExecutor for the given request. @@ -433,7 +441,14 @@ func (e *ToolExecutor) setupConfigFile() error { } // startProcess spawns the tool in a new process group. +// If PTY mode is enabled, delegates to startProcessWithPTY. func (e *ToolExecutor) startProcess(env []string) error { + // Check if PTY mode should be used + e.usePTY = e.req.UsePTY && e.tool.UsePTY + if e.usePTY { + return e.startProcessWithPTY(env) + } + e.cmd = exec.CommandContext(e.ctx, e.tool.Binary, e.req.Args...) e.cmd.Dir = e.req.Cwd e.cmd.Env = env @@ -495,6 +510,150 @@ func (e *ToolExecutor) startProcess(env []string) error { return nil } +// startProcessWithPTY spawns the tool with a pseudo-terminal. +// This enables interactive TUI applications to work correctly with colors and cursor control. +func (e *ToolExecutor) startProcessWithPTY(env []string) error { + e.cmd = exec.CommandContext(e.ctx, e.tool.Binary, e.req.Args...) + e.cmd.Dir = e.req.Cwd + e.cmd.Env = env + + // Start with PTY - pty.Start handles setting up stdin/stdout/stderr + ptmx, err := pty.Start(e.cmd) + if err != nil { + return fmt.Errorf("pty start: %w", err) + } + e.ptyMaster = ptmx + + // Get process group ID (process is its own leader in PTY mode) + e.pgid = e.cmd.Process.Pid + + // Set initial window size if provided + if e.req.WindowSize != nil { + if err := pty.Setsize(ptmx, &pty.Winsize{ + Rows: e.req.WindowSize.Rows, + Cols: e.req.WindowSize.Cols, + }); err != nil { + log.Printf("[WARN] set initial window size: %v", err) + } + } + + // Create single output buffer for PTY (stdout and stderr are merged) + // PTY mode always streams inline - no file buffering threshold + e.ptyBuf = NewOutputBuffer("stdout", 0, e.maxOutSz, e.sendMessage) + if len(e.tool.RedactOutput) > 0 { + e.ptyBuf.SetRedactor(NewOutputRedactor(e.tool.RedactOutput)) + } + + // Wire up SHA256 hasher for audit output hash + if auditCfg := e.cfg.GetAuditConfig(); auditCfg != nil && auditCfg.Enabled && auditCfg.GetIncludeOutputHash() { + e.ptyHash = sha256.New() + e.ptyBuf.SetTee(e.ptyHash) + } + + // Start I/O pumpers + e.pumperWg.Add(1) // Single pumper for PTY output + + go e.ptyOutputPumper(ptmx) + go e.ptyInputPumper(ptmx) // stdin pumper runs independently + + return nil +} + +// ptyOutputPumper reads from PTY master and writes to the output buffer. +func (e *ToolExecutor) ptyOutputPumper(r io.Reader) { + defer e.pumperWg.Done() + + buf := make([]byte, 32*1024) // 32KB buffer + for { + n, err := r.Read(buf) + if n > 0 { + if writeErr := e.ptyBuf.Write(buf[:n]); writeErr != nil { + if errors.Is(writeErr, ErrOutputLimitExceeded) { + log.Printf("[WARN] pty output: %v, killing process", writeErr) + e.killProcessGroup(syscall.SIGKILL) + return + } + log.Printf("[WARN] pty output write: %v", writeErr) + } + } + if err != nil { + if err != io.EOF { + log.Printf("[DEBUG] pty output read: %v", err) + } + return + } + } +} + +// ptyInputPumper reads WrapperMessages from the connection and writes to PTY master. +func (e *ToolExecutor) ptyInputPumper(ptmx *os.File) { + reader := framing.NewNDJSONReaderWithLimit(e.conn, e.msgSize) + + for { + if e.readMsgTO > 0 { + _ = e.conn.SetReadDeadline(time.Now().Add(e.readMsgTO)) + } + + var msg protocol.WrapperMessage + if err := reader.Read(&msg); err != nil { + _ = e.conn.SetReadDeadline(time.Time{}) + if err != io.EOF { + if nerr, ok := err.(net.Error); ok && nerr.Timeout() { + log.Printf("[WARN] pty stdin/control read timeout after %v", e.readMsgTO) + } + log.Printf("[DEBUG] pty stdin read: %v", err) + } + return + } + _ = e.conn.SetReadDeadline(time.Time{}) + + if err := e.handlePTYWrapperMessage(&msg, ptmx); err != nil { + log.Printf("[WARN] handle pty wrapper message: %v", err) + } + } +} + +// handlePTYWrapperMessage processes a message from the wrapper in PTY mode. +func (e *ToolExecutor) handlePTYWrapperMessage(msg *protocol.WrapperMessage, ptmx *os.File) error { + switch msg.Type { + case protocol.MsgTypeStdin: + if msg.EOF { + // In PTY mode, we don't close the master - just stop writing + return nil + } + // Decode and write data to PTY master + data, err := base64.StdEncoding.DecodeString(msg.Data) + if err != nil { + return fmt.Errorf("decode stdin: %w", err) + } + if _, err := ptmx.Write(data); err != nil { + return fmt.Errorf("write pty: %w", err) + } + + case protocol.MsgTypeSignal: + if err := e.forwardSignal(msg.Signal); err != nil { + return fmt.Errorf("forward signal: %w", err) + } + + case protocol.MsgTypeWinSize: + if err := pty.Setsize(ptmx, &pty.Winsize{ + Rows: msg.Rows, + Cols: msg.Cols, + }); err != nil { + return fmt.Errorf("set window size: %w", err) + } + + case protocol.MsgTypeCleanup: + // Compatibility no-op + log.Printf("[DEBUG] ignoring client cleanup request in PTY mode") + + default: + log.Printf("[WARN] unknown wrapper message type in PTY mode: %s", msg.Type) + } + + return nil +} + // runIOLoop waits for the process to complete and handles timeout. func (e *ToolExecutor) runIOLoop() error { // Wait for process in a goroutine @@ -737,27 +896,49 @@ func (e *ToolExecutor) sendDone(exitCode int, timeout bool) { // finalizeOutput closes output buffers and streams any file-buffered output. func (e *ToolExecutor) finalizeOutput() error { - // Finalize stdout buffer - if stdoutPath, err := e.stdoutBuf.Finalize(); err != nil { - return fmt.Errorf("finalize stdout: %w", err) - } else if stdoutPath != "" { - if err := e.streamFile(stdoutPath, protocol.MsgTypeStdout); err != nil { - return fmt.Errorf("stream stdout file: %w", err) + // PTY mode uses single buffer + if e.usePTY { + if e.ptyBuf != nil { + if ptyPath, err := e.ptyBuf.Finalize(); err != nil { + return fmt.Errorf("finalize pty output: %w", err) + } else if ptyPath != "" { + if err := e.streamFile(ptyPath, protocol.MsgTypeStdout); err != nil { + return fmt.Errorf("stream pty file: %w", err) + } + if err := os.Remove(ptyPath); err != nil && !os.IsNotExist(err) { + log.Printf("[WARN] cleanup pty temp file %s: %v", ptyPath, err) + } + } } - if err := os.Remove(stdoutPath); err != nil && !os.IsNotExist(err) { - log.Printf("[WARN] cleanup stdout temp file %s: %v", stdoutPath, err) + return nil + } + + // Pipe mode uses separate stdout/stderr buffers + // Finalize stdout buffer + if e.stdoutBuf != nil { + if stdoutPath, err := e.stdoutBuf.Finalize(); err != nil { + return fmt.Errorf("finalize stdout: %w", err) + } else if stdoutPath != "" { + if err := e.streamFile(stdoutPath, protocol.MsgTypeStdout); err != nil { + return fmt.Errorf("stream stdout file: %w", err) + } + if err := os.Remove(stdoutPath); err != nil && !os.IsNotExist(err) { + log.Printf("[WARN] cleanup stdout temp file %s: %v", stdoutPath, err) + } } } // Finalize stderr buffer - if stderrPath, err := e.stderrBuf.Finalize(); err != nil { - return fmt.Errorf("finalize stderr: %w", err) - } else if stderrPath != "" { - if err := e.streamFile(stderrPath, protocol.MsgTypeStderr); err != nil { - return fmt.Errorf("stream stderr file: %w", err) - } - if err := os.Remove(stderrPath); err != nil && !os.IsNotExist(err) { - log.Printf("[WARN] cleanup stderr temp file %s: %v", stderrPath, err) + if e.stderrBuf != nil { + if stderrPath, err := e.stderrBuf.Finalize(); err != nil { + return fmt.Errorf("finalize stderr: %w", err) + } else if stderrPath != "" { + if err := e.streamFile(stderrPath, protocol.MsgTypeStderr); err != nil { + return fmt.Errorf("stream stderr file: %w", err) + } + if err := os.Remove(stderrPath); err != nil && !os.IsNotExist(err) { + log.Printf("[WARN] cleanup stderr temp file %s: %v", stderrPath, err) + } } } @@ -862,12 +1043,18 @@ func (e *ToolExecutor) cleanup() { // Cancel context e.cancel() - // Close stdin pipe if still open + // Close stdin pipe if still open (pipe mode) if e.stdinPipe != nil { e.stdinPipe.Close() e.stdinPipe = nil } + // Close PTY master if still open (PTY mode) + if e.ptyMaster != nil { + e.ptyMaster.Close() + e.ptyMaster = nil + } + // Kill process group if still running if e.pgid != 0 { // Best effort kill @@ -881,6 +1068,9 @@ func (e *ToolExecutor) cleanup() { if e.stderrBuf != nil { e.stderrBuf.Cleanup() } + if e.ptyBuf != nil { + e.ptyBuf.Cleanup() + } // Remove config dir if e.configDir != "" { @@ -901,6 +1091,19 @@ func (e *ToolExecutor) emitAuditEntry(exitCode int, timeout bool) { return } + // Calculate output bytes based on mode + var outputBytes int64 + if e.usePTY && e.ptyBuf != nil { + outputBytes = e.ptyBuf.Accumulated() + } else { + if e.stdoutBuf != nil { + outputBytes += e.stdoutBuf.Accumulated() + } + if e.stderrBuf != nil { + outputBytes += e.stderrBuf.Accumulated() + } + } + entry := audit.Entry{ Timestamp: e.startTime.UTC().Format(time.RFC3339), Tool: e.req.Tool, @@ -908,7 +1111,7 @@ func (e *ToolExecutor) emitAuditEntry(exitCode int, timeout bool) { CallerPID: e.callerPID, CallerExe: e.callerExe, ExitCode: exitCode, - OutputBytes: e.stdoutBuf.Accumulated() + e.stderrBuf.Accumulated(), + OutputBytes: outputBytes, } if timeout { entry.Timeout = true @@ -919,11 +1122,18 @@ func (e *ToolExecutor) emitAuditEntry(exitCode int, timeout bool) { if auditCfg.GetIncludeDuration() { entry.DurationMs = time.Since(e.startTime).Milliseconds() } - if auditCfg.GetIncludeOutputHash() && e.stdoutHash != nil { - combined := sha256.New() - combined.Write(e.stdoutHash.Sum(nil)) - combined.Write(e.stderrHash.Sum(nil)) - entry.OutputHash = fmt.Sprintf("sha256:%x", combined.Sum(nil)) + // Calculate output hash based on mode + if auditCfg.GetIncludeOutputHash() { + if e.usePTY && e.ptyHash != nil { + entry.OutputHash = fmt.Sprintf("sha256:%x", e.ptyHash.Sum(nil)) + } else if e.stdoutHash != nil { + combined := sha256.New() + combined.Write(e.stdoutHash.Sum(nil)) + if e.stderrHash != nil { + combined.Write(e.stderrHash.Sum(nil)) + } + entry.OutputHash = fmt.Sprintf("sha256:%x", combined.Sum(nil)) + } } if err := e.auditLogger.Log(entry); err != nil { diff --git a/internal/protocol/protocol.go b/internal/protocol/protocol.go index a4f7195..b35d1da 100644 --- a/internal/protocol/protocol.go +++ b/internal/protocol/protocol.go @@ -2,18 +2,26 @@ package protocol // ProtocolVersion is the wire protocol version for wrapper/daemon requests. -const ProtocolVersion = 3 +const ProtocolVersion = 4 + +// WinSize represents terminal window dimensions for PTY mode. +type WinSize struct { + Rows uint16 `json:"rows"` + Cols uint16 `json:"cols"` +} // ProxyRequest is sent by wrapper to request tool execution (NDJSON format) type ProxyRequest struct { - Version int `json:"version,omitempty"` - Tool string `json:"tool"` - Args []string `json:"args"` - Cwd string `json:"cwd"` - Timestamp string `json:"timestamp"` - Nonce string `json:"nonce"` - HMAC string `json:"hmac"` - Env map[string]string `json:"env,omitempty"` + Version int `json:"version,omitempty"` + Tool string `json:"tool"` + Args []string `json:"args"` + Cwd string `json:"cwd"` + Timestamp string `json:"timestamp"` + Nonce string `json:"nonce"` + HMAC string `json:"hmac"` + Env map[string]string `json:"env,omitempty"` + UsePTY bool `json:"use_pty,omitempty"` + WindowSize *WinSize `json:"window_size,omitempty"` } // ResponseMessage is sent by daemon during execution (length-prefixed) @@ -29,11 +37,13 @@ type ResponseMessage struct { // WrapperMessage is sent by wrapper during execution (NDJSON format) type WrapperMessage struct { - Type string `json:"type"` // stdin, signal, cleanup + Type string `json:"type"` // stdin, signal, cleanup, winsize Data string `json:"data,omitempty"` // base64 for stdin EOF bool `json:"eof,omitempty"` // stdin EOF Signal string `json:"signal,omitempty"` // SIGINT, SIGTERM, SIGHUP Files []string `json:"files,omitempty"` // cleanup paths + Rows uint16 `json:"rows,omitempty"` // terminal rows (for winsize) + Cols uint16 `json:"cols,omitempty"` // terminal cols (for winsize) } // Message type constants @@ -46,6 +56,7 @@ const ( MsgTypeStdin = "stdin" MsgTypeSignal = "signal" MsgTypeCleanup = "cleanup" + MsgTypeWinSize = "winsize" ) // ValidSignals for signal forwarding diff --git a/internal/wrapper/wrapper.go b/internal/wrapper/wrapper.go index 26f200c..0865578 100644 --- a/internal/wrapper/wrapper.go +++ b/internal/wrapper/wrapper.go @@ -13,6 +13,8 @@ import ( "syscall" "time" + "golang.org/x/term" + "claw-wrap/internal/auth" "claw-wrap/internal/framing" "claw-wrap/internal/paths" @@ -64,42 +66,63 @@ func (w *Wrapper) RunTool(toolName string, args []string) error { return fmt.Errorf("get cwd: %w", err) } - // 3. Compute timestamp, nonce, and HMAC + // 3. Detect terminal and request PTY if available + // Always request PTY when stdin is a terminal - daemon decides based on tool config + usePTY := false + var winSize *protocol.WinSize + stdinFd := int(os.Stdin.Fd()) + if term.IsTerminal(stdinFd) { + usePTY = true + cols, rows, _ := term.GetSize(stdinFd) + if cols > 0 && rows > 0 { + winSize = &protocol.WinSize{ + Rows: uint16(rows), + Cols: uint16(cols), + } + } + } + + // 4. Compute timestamp, nonce, and HMAC (includes PTY flag in signature) timestamp := strconv.FormatInt(time.Now().Unix(), 10) nonce, err := auth.GenerateNonce() if err != nil { return fmt.Errorf("generate nonce: %w", err) } var reqEnv map[string]string - hmac, err := auth.ComputeHMACWithEnv(secret, timestamp, toolName, cwd, args, reqEnv, nonce) + hmac, err := auth.ComputeHMACWithPTY(secret, timestamp, toolName, cwd, args, reqEnv, nonce, usePTY) if err != nil { return fmt.Errorf("compute hmac: %w", err) } - // 4. Connect to socket + // 5. Connect to socket conn, err := net.Dial("unix", w.socketPath) if err != nil { return fmt.Errorf("connect: %w", err) } defer conn.Close() - // 5. Send ProxyRequest (NDJSON) + // 6. Send ProxyRequest (NDJSON) req := &protocol.ProxyRequest{ - Version: protocol.ProtocolVersion, - Tool: toolName, - Args: args, - Cwd: cwd, - Timestamp: timestamp, - Nonce: nonce, - HMAC: hmac, - Env: reqEnv, + Version: protocol.ProtocolVersion, + Tool: toolName, + Args: args, + Cwd: cwd, + Timestamp: timestamp, + Nonce: nonce, + HMAC: hmac, + Env: reqEnv, + UsePTY: usePTY, + WindowSize: winSize, } ndjson := framing.NewNDJSONWriter(conn) if err := ndjson.Write(req); err != nil { return fmt.Errorf("send request: %w", err) } - // 6. Enter I/O loop + // 7. Enter I/O loop (PTY mode uses raw terminal) + if usePTY { + return w.ioLoopPTY(conn, ndjson, stdinFd) + } return w.ioLoop(conn, ndjson) } @@ -211,6 +234,140 @@ func (w *Wrapper) ioLoop(conn net.Conn, ndjson *framing.NDJSONWriter) error { } } +// ioLoopPTY handles I/O in PTY mode with raw terminal and SIGWINCH forwarding. +func (w *Wrapper) ioLoopPTY(conn net.Conn, ndjson *framing.NDJSONWriter, stdinFd int) error { + // Put terminal in raw mode + oldState, err := term.MakeRaw(stdinFd) + if err != nil { + // Fall back to regular mode if raw mode fails + return w.ioLoop(conn, ndjson) + } + defer term.Restore(stdinFd, oldState) + + decoder := framing.NewDecoder(conn) + + // Channels for coordination + stdinCh := make(chan []byte, 16) + stdinEOF := make(chan struct{}) + signalCh := make(chan os.Signal, 1) + winSizeCh := make(chan os.Signal, 1) + doneCh := make(chan struct{}) + var exitCode int + + // Start stdin reader goroutine + go func() { + buf := make([]byte, 32*1024) // 32KB chunks + for { + n, err := os.Stdin.Read(buf) + if n > 0 { + chunk := make([]byte, n) + copy(chunk, buf[:n]) + select { + case stdinCh <- chunk: + case <-doneCh: + return + } + } + if err != nil { + close(stdinEOF) + return + } + } + }() + + // Register signal handlers + signal.Notify(signalCh, syscall.SIGINT, syscall.SIGTERM, syscall.SIGHUP) + signal.Notify(winSizeCh, syscall.SIGWINCH) + defer signal.Stop(signalCh) + defer signal.Stop(winSizeCh) + + // Main loop + responseCh := make(chan *protocol.ResponseMessage) + errCh := make(chan error) + + // Response reader goroutine + go func() { + for { + var msg protocol.ResponseMessage + if err := decoder.Decode(&msg); err != nil { + if err != io.EOF { + errCh <- err + } + return + } + select { + case responseCh <- &msg: + case <-doneCh: + return + } + } + }() + + for { + select { + case msg := <-responseCh: + done, err := w.handleResponse(msg) + if err != nil { + return err + } + if done { + exitCode = msg.ExitCode + close(doneCh) + // Restore terminal before exit + term.Restore(stdinFd, oldState) + os.Exit(exitCode) + } + + case err := <-errCh: + close(doneCh) + return fmt.Errorf("read response: %w", err) + + case chunk := <-stdinCh: + msg := &protocol.WrapperMessage{ + Type: protocol.MsgTypeStdin, + Data: base64.StdEncoding.EncodeToString(chunk), + } + if err := ndjson.Write(msg); err != nil { + return fmt.Errorf("send stdin: %w", err) + } + + case <-stdinEOF: + msg := &protocol.WrapperMessage{ + Type: protocol.MsgTypeStdin, + EOF: true, + } + ndjson.Write(msg) + stdinEOF = nil // Disable this case + + case sig := <-signalCh: + sigName := "SIGTERM" + switch sig { + case syscall.SIGINT: + sigName = "SIGINT" + case syscall.SIGHUP: + sigName = "SIGHUP" + } + msg := &protocol.WrapperMessage{ + Type: protocol.MsgTypeSignal, + Signal: sigName, + } + ndjson.Write(msg) + + case <-winSizeCh: + // Get new window size and forward to daemon + cols, rows, _ := term.GetSize(stdinFd) + if cols > 0 && rows > 0 { + msg := &protocol.WrapperMessage{ + Type: protocol.MsgTypeWinSize, + Rows: uint16(rows), + Cols: uint16(cols), + } + ndjson.Write(msg) + } + } + } +} + func (w *Wrapper) handleResponse(msg *protocol.ResponseMessage) (done bool, err error) { switch msg.Type { case protocol.MsgTypeStdout: From fe9a7e62231d39bd8a3d065f296f239b6061bfef Mon Sep 17 00:00:00 2001 From: Peter Dedene Date: Sat, 21 Feb 2026 11:35:36 +0100 Subject: [PATCH 2/2] docs: clarify PTY terminal restore and process group comments Address PR review feedback: - Explain why both defer and explicit term.Restore() are needed (os.Exit doesn't run defers) - Clarify that pty.Start() calls Setsid() internally for process group creation --- internal/daemon/executor.go | 3 ++- internal/wrapper/wrapper.go | 4 +++- 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/internal/daemon/executor.go b/internal/daemon/executor.go index a5fd6a1..0d335c0 100644 --- a/internal/daemon/executor.go +++ b/internal/daemon/executor.go @@ -524,7 +524,8 @@ func (e *ToolExecutor) startProcessWithPTY(env []string) error { } e.ptyMaster = ptmx - // Get process group ID (process is its own leader in PTY mode) + // Get process group ID. pty.Start() calls Setsid() internally, + // making the child its own session leader and process group leader. e.pgid = e.cmd.Process.Pid // Set initial window size if provided diff --git a/internal/wrapper/wrapper.go b/internal/wrapper/wrapper.go index 0865578..ac8917b 100644 --- a/internal/wrapper/wrapper.go +++ b/internal/wrapper/wrapper.go @@ -242,6 +242,8 @@ func (w *Wrapper) ioLoopPTY(conn net.Conn, ndjson *framing.NDJSONWriter, stdinFd // Fall back to regular mode if raw mode fails return w.ioLoop(conn, ndjson) } + // Restore terminal on all exit paths. Note: os.Exit() doesn't run defers, + // so we call Restore explicitly before os.Exit and use defer for error returns. defer term.Restore(stdinFd, oldState) decoder := framing.NewDecoder(conn) @@ -313,7 +315,7 @@ func (w *Wrapper) ioLoopPTY(conn net.Conn, ndjson *framing.NDJSONWriter, stdinFd if done { exitCode = msg.ExitCode close(doneCh) - // Restore terminal before exit + // Restore terminal before os.Exit (defers don't run on os.Exit) term.Restore(stdinFd, oldState) os.Exit(exitCode) }