diff --git a/.github/workflows/linter.yml b/.github/workflows/linter.yml index 55f06a2..ce247b0 100644 --- a/.github/workflows/linter.yml +++ b/.github/workflows/linter.yml @@ -13,7 +13,7 @@ jobs: - uses: actions/checkout@v6 - uses: actions/setup-go@v6 with: - go-version: stable + go-version-file: 'go.mod' - name: Install task uses: jaxxstorm/action-install-gh-release@v2.1.0 with: diff --git a/cmd/logwrap/integration_test.go b/cmd/logwrap/integration_test.go index cc1e19e..abe23e5 100644 --- a/cmd/logwrap/integration_test.go +++ b/cmd/logwrap/integration_test.go @@ -20,6 +20,7 @@ import ( "github.com/sgaunet/logwrap/pkg/processor" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "go.uber.org/goleak" ) // testBinaryPath holds the path to the compiled logwrap binary for subprocess tests. @@ -49,9 +50,9 @@ func TestMain(m *testing.M) { os.Exit(1) } - exitCode := m.Run() - os.RemoveAll(tmpDir) - os.Exit(exitCode) + goleak.VerifyTestMain(m, goleak.Cleanup(func(exitCode int) { + os.RemoveAll(tmpDir) + })) } // runPipeline constructs the full logwrap pipeline with a thread-safe captured diff --git a/cmd/logwrap/main.go b/cmd/logwrap/main.go index 589c992..0abbda1 100644 --- a/cmd/logwrap/main.go +++ b/cmd/logwrap/main.go @@ -26,10 +26,12 @@ var ( ) const ( - exitCodeSIGINT = 130 // 128 + 2 (SIGINT) - exitCodeSIGTERM = 143 // 128 + 15 (SIGTERM) + signalExitCodeBase = 128 // UNIX convention: 128 + signal number + exitCodeSIGINT = signalExitCodeBase + 2 // SIGINT + exitCodeSIGTERM = signalExitCodeBase + 15 // SIGTERM gracefulShutdownTimeout = 5 * time.Second processorWaitTimeout = 3 * time.Second + killTimeout = 2 * time.Second usage = `LogWrap - Command execution wrapper with configurable log prefixes Usage: @@ -292,6 +294,16 @@ func run(cfg *config.Config, command []string) int { procOpts = append(procOpts, processor.WithFilter(f)) } + // Set up signal handling before starting the child process to avoid + // a race where a signal arrives after Start() but before Notify(), + // which would use Go's default handler (os.Exit) and orphan the child. + sigChan := make(chan os.Signal, 1) + signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM) + + ctx, ctxCancel := context.WithCancel(context.Background()) + defer ctxCancel() + + procOpts = append(procOpts, processor.WithContext(ctx)) proc := processor.New(form, os.Stdout, procOpts...) if err := exec.Start(); err != nil { @@ -301,12 +313,7 @@ func run(cfg *config.Config, command []string) int { stdout, stderr := exec.GetStreams() - // Set up signal handling - sigChan := make(chan os.Signal, 1) - signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM) - // Start stream processing in background - ctx := context.Background() processingDone := make(chan error, 1) go func() { processingDone <- proc.ProcessStreams(ctx, stdout, stderr) @@ -372,7 +379,16 @@ func handleSignalShutdown(exec *executor.Executor, proc *processor.Processor, si fmt.Fprintf(os.Stderr, "Warning: failed to kill process: %v\n", err) } proc.Stop() - return <-cmdDone // Wait for process to actually die + // Wait for process to die, with a hard timeout to avoid hanging + // indefinitely if the process is in an unkillable state (e.g., D state on Linux). + killTimer := time.NewTimer(killTimeout) + defer killTimer.Stop() + select { + case cmdErr := <-cmdDone: + return cmdErr + case <-killTimer.C: + return nil + } } } @@ -391,7 +407,7 @@ func waitForProcessing(proc *processor.Processor, processingDone chan error) { } } -func determineExitCode(exec *executor.Executor, receivedSignal os.Signal, _ error) int { +func determineExitCode(exec *executor.Executor, receivedSignal os.Signal, cmdErr error) int { // If we received a signal, use signal-based exit code if receivedSignal != nil { switch receivedSignal { @@ -399,9 +415,20 @@ func determineExitCode(exec *executor.Executor, receivedSignal os.Signal, _ erro return exitCodeSIGINT case syscall.SIGTERM: return exitCodeSIGTERM + default: + if sig, ok := receivedSignal.(syscall.Signal); ok { + return signalExitCodeBase + int(sig) + } + + return 1 } } - // Otherwise use command's exit code + // If the command failed with a non-exit error (e.g., I/O error, context error), + // the executor's exit code stays at 0. Use 1 to avoid masking the failure. + if cmdErr != nil && exec.GetExitCode() == 0 { + return 1 + } + return exec.GetExitCode() } diff --git a/go.mod b/go.mod index 4852694..5fd57e7 100644 --- a/go.mod +++ b/go.mod @@ -4,7 +4,12 @@ go 1.25.0 require gopkg.in/yaml.v3 v3.0.1 -require github.com/itchyny/timefmt-go v0.1.8 +require ( + github.com/itchyny/timefmt-go v0.1.8 + go.uber.org/goleak v1.3.0 +) + +require github.com/kr/text v0.2.0 // indirect require ( github.com/davecgh/go-spew v1.1.1 // indirect diff --git a/go.sum b/go.sum index 6eeb4a1..f4725df 100644 --- a/go.sum +++ b/go.sum @@ -1,12 +1,20 @@ +github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/itchyny/timefmt-go v0.1.8 h1:1YEo1JvfXeAHKdjelbYr/uCuhkybaHCeTkH8Bo791OI= github.com/itchyny/timefmt-go v0.1.8/go.mod h1:5E46Q+zj7vbTgWY8o5YkMeYb4I6GeWLFnetPy5oBrAI= +github.com/kr/pretty v0.1.0 h1:L/CwN0zerZDmRFUapSPitk6f+Q3+0za1rQkzVuMiMFI= +github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo= +github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= +github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U= github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U= -gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= +go.uber.org/goleak v1.3.0 h1:2K3zAYmnTNqV73imy9J1T3WC+gmCePx2hEGkimedGto= +go.uber.org/goleak v1.3.0/go.mod h1:CoHD4mav9JJNrW/WLlf7HGZPjdw8EucARQHekz1X6bE= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127 h1:qIbj1fsPNlZgppZ+VLlY7N33q108Sa+fhmuc+sWQYwY= +gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/pkg/apperrors/testmain_test.go b/pkg/apperrors/testmain_test.go new file mode 100644 index 0000000..05ce2cb --- /dev/null +++ b/pkg/apperrors/testmain_test.go @@ -0,0 +1,11 @@ +package apperrors + +import ( + "testing" + + "go.uber.org/goleak" +) + +func TestMain(m *testing.M) { + goleak.VerifyTestMain(m) +} diff --git a/pkg/config/config.go b/pkg/config/config.go index ca2cad9..786f5c6 100644 --- a/pkg/config/config.go +++ b/pkg/config/config.go @@ -62,6 +62,7 @@ import ( "io" "os" "path/filepath" + "slices" "strings" "github.com/sgaunet/logwrap/pkg/apperrors" @@ -321,7 +322,7 @@ func parseCLIFlags(args []string) (*CLIFlags, error) { } func applyCLIOverrides(config *Config, flags *CLIFlags) { - if flags.Template != nil && *flags.Template != "" { + if flags.setFlags["template"] { config.Prefix.Template = *flags.Template } if flags.setFlags["utc"] { @@ -330,7 +331,7 @@ func applyCLIOverrides(config *Config, flags *CLIFlags) { if flags.setFlags["colors"] { config.Prefix.Colors.Enabled = *flags.ColorsEnabled } - if flags.OutputFormat != nil && *flags.OutputFormat != "" { + if flags.setFlags["format"] { config.Output.Format = *flags.OutputFormat } } @@ -369,9 +370,10 @@ func FindConfigFile() string { // - Path traversal: rejects paths containing ".." after filepath.Clean // - File type: only .yaml and .yml extensions are accepted (case-insensitive) func validateConfigPath(configFile string) error { - // Prevent path traversal attacks + // Prevent path traversal attacks by checking for ".." as a path component, + // not a substring — filenames like "backup..yaml" are valid. cleaned := filepath.Clean(configFile) - if strings.Contains(cleaned, "..") { + if slices.Contains(strings.Split(cleaned, string(filepath.Separator)), "..") { return apperrors.ErrPathTraversal } diff --git a/pkg/config/config_test.go b/pkg/config/config_test.go index bc833eb..968d785 100644 --- a/pkg/config/config_test.go +++ b/pkg/config/config_test.go @@ -490,7 +490,7 @@ func TestApplyCLIOverrides(t *testing.T) { TimestampUTC: &utc, ColorsEnabled: &colors, OutputFormat: &format, - setFlags: map[string]bool{"utc": true, "colors": true}, + setFlags: map[string]bool{"utc": true, "colors": true, "template": true, "format": true}, } // Apply overrides diff --git a/pkg/config/testmain_test.go b/pkg/config/testmain_test.go new file mode 100644 index 0000000..9be74f1 --- /dev/null +++ b/pkg/config/testmain_test.go @@ -0,0 +1,11 @@ +package config + +import ( + "testing" + + "go.uber.org/goleak" +) + +func TestMain(m *testing.M) { + goleak.VerifyTestMain(m) +} diff --git a/pkg/config/validation.go b/pkg/config/validation.go index 93145b9..c68ddff 100644 --- a/pkg/config/validation.go +++ b/pkg/config/validation.go @@ -374,6 +374,10 @@ func isValidLogLevel(level string, validLevels []string) bool { // that assigns levels to lines. Without detection, all lines have // an empty detected level and level filters silently drop everything. func (c *Config) validateFilter() error { + if !c.Filter.Enabled { + return nil + } + validLevels := []string{"TRACE", "DEBUG", "INFO", "WARN", "ERROR", "FATAL"} if !c.LogLevel.Detection.Enabled { @@ -387,19 +391,19 @@ func (c *Config) validateFilter() error { if err := validateFilterLevelNames(c.Filter.ExcludeLevels, "exclude_levels", validLevels); err != nil { return err } - if slices.Contains(c.Filter.ExcludePatterns, "") { - return fmt.Errorf("%w in exclude_patterns", apperrors.ErrEmptyFilterPattern) - } - if slices.Contains(c.Filter.IncludePatterns, "") { - return fmt.Errorf("%w in include_patterns", apperrors.ErrEmptyFilterPattern) - } - if err := validateRegexPatterns(c.Filter.ExcludePatterns, "exclude_patterns"); err != nil { + if err := validateFilterPatterns(c.Filter.ExcludePatterns, "exclude_patterns"); err != nil { return err } - if err := validateRegexPatterns(c.Filter.IncludePatterns, "include_patterns"); err != nil { - return err + return validateFilterPatterns(c.Filter.IncludePatterns, "include_patterns") +} + +// validateFilterPatterns checks that a pattern list contains no empty strings +// and that all entries are valid regular expressions. +func validateFilterPatterns(patterns []string, field string) error { + if slices.Contains(patterns, "") { + return fmt.Errorf("%w in %s", apperrors.ErrEmptyFilterPattern, field) } - return nil + return validateRegexPatterns(patterns, field) } // validateFilterLevelNames checks that all level names in the list are valid diff --git a/pkg/config/validation_test.go b/pkg/config/validation_test.go index 471c361..b796f0b 100644 --- a/pkg/config/validation_test.go +++ b/pkg/config/validation_test.go @@ -1057,6 +1057,7 @@ func TestConfig_ValidateFilter_EmptyPatterns(t *testing.T) { t.Parallel() cfg := getDefaultConfig() + cfg.Filter.Enabled = true cfg.Filter.ExcludePatterns = tt.exclude cfg.Filter.IncludePatterns = tt.include @@ -1094,6 +1095,7 @@ func TestConfig_ValidateFilter_LevelsRequireDetection(t *testing.T) { t.Parallel() cfg := getDefaultConfig() + cfg.Filter.Enabled = true cfg.LogLevel.Detection.Enabled = tt.detection if !tt.detection { cfg.LogLevel.Detection.Keywords = nil @@ -1131,6 +1133,7 @@ func TestConfig_ValidateFilter_InvalidRegex(t *testing.T) { t.Parallel() cfg := getDefaultConfig() + cfg.Filter.Enabled = true cfg.Filter.ExcludePatterns = tt.exclude cfg.Filter.IncludePatterns = tt.include @@ -1170,6 +1173,7 @@ func TestConfig_ValidateFilter_InvalidLevelNames(t *testing.T) { t.Parallel() cfg := getDefaultConfig() + cfg.Filter.Enabled = true cfg.Filter.IncludeLevels = tt.include cfg.Filter.ExcludeLevels = tt.exclude diff --git a/pkg/executor/executor.go b/pkg/executor/executor.go index d9e136a..4073ddd 100644 --- a/pkg/executor/executor.go +++ b/pkg/executor/executor.go @@ -44,7 +44,9 @@ import ( "os" "os/exec" "path/filepath" + "slices" "strings" + "sync/atomic" "syscall" "time" @@ -68,8 +70,8 @@ type Executor struct { stderrPipe io.ReadCloser commandName string // stored for error messages exitCode int - isStarted bool - isFinished bool + isStarted atomic.Bool + isFinished atomic.Bool } // New creates a new Executor instance for the given command. @@ -88,7 +90,10 @@ func New(command []string) (*Executor, error) { // Send SIGTERM (not SIGKILL) when the context is cancelled. // If the process doesn't exit within WaitDelay, Go escalates to SIGKILL. cmd.Cancel = func() error { - return cmd.Process.Signal(syscall.SIGTERM) + if cmd.Process != nil { + return cmd.Process.Signal(syscall.SIGTERM) + } + return nil } cmd.WaitDelay = gracefulStopDelay cmd.Stdin = os.Stdin @@ -120,7 +125,7 @@ func New(command []string) (*Executor, error) { // Start begins execution of the command. func (e *Executor) Start() error { - if e.isStarted { + if e.isStarted.Load() { return appErrors.ErrExecutorStarted } @@ -128,32 +133,50 @@ func (e *Executor) Start() error { return fmt.Errorf("failed to start command %q: %w", e.commandName, err) } - e.isStarted = true + e.isStarted.Store(true) return nil } // Wait waits for the command to complete and returns any error. func (e *Executor) Wait() error { - if !e.isStarted { + if !e.isStarted.Load() { return appErrors.ErrExecutorNotStarted } - if e.isFinished { + if e.isFinished.Load() { return nil } err := e.cmd.Wait() - e.isFinished = true if err != nil { var exitError *exec.ExitError - if errors.As(err, &exitError) { + + switch { + // ErrWaitDelay means the process exited but its pipes weren't fully + // drained before WaitDelay expired (e.g., a grandchild holds them open). + // The process itself succeeded, so treat this as a normal exit. + case errors.Is(err, exec.ErrWaitDelay): + e.isFinished.Store(true) + return nil + + case errors.As(err, &exitError): e.exitCode = resolveExitCode(exitError) - } else { + + // Context cancellation can race with the process exiting. If the + // process already exited, extract its real exit code instead of + // treating context.Canceled as a generic failure. + case errors.Is(err, context.Canceled) && e.cmd.ProcessState != nil: + e.exitCode = e.cmd.ProcessState.ExitCode() + + default: + e.isFinished.Store(true) return fmt.Errorf("command %q execution failed: %w", e.commandName, err) } } + e.isFinished.Store(true) + return nil } @@ -183,14 +206,14 @@ func (e *Executor) GetExitCode() int { // IsFinished returns true if the command has finished execution. func (e *Executor) IsFinished() bool { - return e.isFinished + return e.isFinished.Load() } // Stop gracefully terminates the command using SIGTERM. // Context cancellation triggers the custom Cancel function (SIGTERM). // If the process doesn't exit within WaitDelay, Go escalates to SIGKILL. func (e *Executor) Stop() error { - if !e.isStarted || e.isFinished { + if !e.isStarted.Load() || e.isFinished.Load() { return nil } @@ -200,7 +223,7 @@ func (e *Executor) Stop() error { // Kill forcefully terminates the command with SIGKILL. func (e *Executor) Kill() error { - if !e.isStarted || e.isFinished { + if !e.isStarted.Load() || e.isFinished.Load() { return nil } @@ -238,10 +261,14 @@ func (e *Executor) Cleanup() { // Commands run with the current user's privileges. Callers are responsible // for validating commands before passing them to logwrap. func validateCommand(command string) error { + // Check the raw path before filepath.Clean, which normalizes away ".." + // in absolute paths (e.g., "/../etc/passwd" → "/etc/passwd"). + if slices.Contains(strings.Split(command, string(filepath.Separator)), "..") { + return appErrors.ErrCommandPathTraversal + } cleaned := filepath.Clean(command) - if strings.Contains(cleaned, "..") { + if slices.Contains(strings.Split(cleaned, string(filepath.Separator)), "..") { return appErrors.ErrCommandPathTraversal } - return nil } \ No newline at end of file diff --git a/pkg/executor/testmain_test.go b/pkg/executor/testmain_test.go new file mode 100644 index 0000000..604922f --- /dev/null +++ b/pkg/executor/testmain_test.go @@ -0,0 +1,11 @@ +package executor_test + +import ( + "testing" + + "go.uber.org/goleak" +) + +func TestMain(m *testing.M) { + goleak.VerifyTestMain(m) +} diff --git a/pkg/filter/filter.go b/pkg/filter/filter.go index 217b0ee..7cf12d7 100644 --- a/pkg/filter/filter.go +++ b/pkg/filter/filter.go @@ -99,6 +99,15 @@ func (f *Filter) passesLevelFilter(line string) bool { } detectedLevel := f.detectLevel(strings.ToUpper(line)) + + // Lines with no detected level always pass the level filter. + // Only lines with a recognized level keyword are subject to + // include/exclude rules. This prevents plain output (e.g., + // "Starting server...") from being silently dropped. + if detectedLevel == "" { + return true + } + if len(f.includeLevels) > 0 && !f.includeLevels[detectedLevel] { return false } diff --git a/pkg/filter/filter_test.go b/pkg/filter/filter_test.go index 52bae54..c537919 100644 --- a/pkg/filter/filter_test.go +++ b/pkg/filter/filter_test.go @@ -152,7 +152,7 @@ func TestFilter_IncludeLevels(t *testing.T) { {"warn line", "WARN: disk space low", true}, {"info line", "INFO: started", false}, {"debug line", "DEBUG: variable dump", false}, - {"no keyword", "regular message", false}, + {"no keyword", "regular message", true}, // lines without a detected level pass through } for _, tt := range tests { diff --git a/pkg/filter/testmain_test.go b/pkg/filter/testmain_test.go new file mode 100644 index 0000000..3c4b28b --- /dev/null +++ b/pkg/filter/testmain_test.go @@ -0,0 +1,11 @@ +package filter + +import ( + "testing" + + "go.uber.org/goleak" +) + +func TestMain(m *testing.M) { + goleak.VerifyTestMain(m) +} diff --git a/pkg/formatter/formatter.go b/pkg/formatter/formatter.go index 2bf9823..415db1a 100644 --- a/pkg/formatter/formatter.go +++ b/pkg/formatter/formatter.go @@ -63,6 +63,7 @@ import ( "time" "github.com/itchyny/timefmt-go" + "github.com/sgaunet/logwrap/pkg/apperrors" "github.com/sgaunet/logwrap/pkg/config" "github.com/sgaunet/logwrap/pkg/processor" ) @@ -120,10 +121,22 @@ func New(cfg *config.Config) (*DefaultFormatter, error) { colors := make(map[string]string) if cfg.Prefix.Colors.Enabled { + infoCode, err := getColorCode(cfg.Prefix.Colors.Info) + if err != nil { + return nil, fmt.Errorf("invalid info color: %w", err) + } + errorCode, err := getColorCode(cfg.Prefix.Colors.Error) + if err != nil { + return nil, fmt.Errorf("invalid error color: %w", err) + } + timestampCode, err := getColorCode(cfg.Prefix.Colors.Timestamp) + if err != nil { + return nil, fmt.Errorf("invalid timestamp color: %w", err) + } colors = map[string]string{ - "info": getColorCode(cfg.Prefix.Colors.Info), - "error": getColorCode(cfg.Prefix.Colors.Error), - "timestamp": getColorCode(cfg.Prefix.Colors.Timestamp), + "info": infoCode, + "error": errorCode, + "timestamp": timestampCode, "reset": "\033[0m", } } @@ -134,10 +147,19 @@ func New(cfg *config.Config) (*DefaultFormatter, error) { userInfo: userInfo, pid: os.Getpid(), colors: colors, - templateUsesLine: strings.Contains(cfg.Prefix.Template, ".Line"), + templateUsesLine: templateReferencesLine(cfg.Prefix.Template), }, nil } +// templateReferencesLine reports whether the template string uses the .Line +// field, accounting for Go template whitespace-trim syntax ({{- and {{). +func templateReferencesLine(tmpl string) bool { + return strings.Contains(tmpl, "{{.Line") || + strings.Contains(tmpl, "{{ .Line") || + strings.Contains(tmpl, "{{- .Line") || + strings.Contains(tmpl, "{{-.Line") +} + // FormatLine formats a log line according to the configured output format. func (f *DefaultFormatter) FormatLine(line string, streamType processor.StreamType) string { data := f.buildTemplateData(line, streamType) @@ -388,19 +410,23 @@ func (f *DefaultFormatter) applyTimestampColor(text, color string) string { return text } -func getColorCode(colorName string) string { - colors := map[string]string{ - "black": "\033[30m", - "red": "\033[31m", - "green": "\033[32m", - "yellow": "\033[33m", - "blue": "\033[34m", - "magenta": "\033[35m", - "cyan": "\033[36m", - "white": "\033[37m", - "none": "", - "": "", - } - - return colors[strings.ToLower(colorName)] +var colorCodes = map[string]string{ + "black": "\033[30m", + "red": "\033[31m", + "green": "\033[32m", + "yellow": "\033[33m", + "blue": "\033[34m", + "magenta": "\033[35m", + "cyan": "\033[36m", + "white": "\033[37m", + "none": "", + "": "", +} + +func getColorCode(colorName string) (string, error) { + code, ok := colorCodes[strings.ToLower(colorName)] + if !ok { + return "", fmt.Errorf("%w: %q", apperrors.ErrInvalidColor, colorName) + } + return code, nil } \ No newline at end of file diff --git a/pkg/formatter/formatter_test.go b/pkg/formatter/formatter_test.go index 4d70fe1..faa019c 100644 --- a/pkg/formatter/formatter_test.go +++ b/pkg/formatter/formatter_test.go @@ -746,30 +746,36 @@ func TestGetColorCode(t *testing.T) { t.Parallel() tests := []struct { - name string + name string colorName string - expected string + expected string + wantErr bool }{ - {"black", "black", "\033[30m"}, - {"red", "red", "\033[31m"}, - {"green", "green", "\033[32m"}, - {"yellow", "yellow", "\033[33m"}, - {"blue", "blue", "\033[34m"}, - {"magenta", "magenta", "\033[35m"}, - {"cyan", "cyan", "\033[36m"}, - {"white", "white", "\033[37m"}, - {"none", "none", ""}, - {"empty", "", ""}, - {"invalid", "invalid", ""}, - {"case insensitive", "RED", "\033[31m"}, + {"black", "black", "\033[30m", false}, + {"red", "red", "\033[31m", false}, + {"green", "green", "\033[32m", false}, + {"yellow", "yellow", "\033[33m", false}, + {"blue", "blue", "\033[34m", false}, + {"magenta", "magenta", "\033[35m", false}, + {"cyan", "cyan", "\033[36m", false}, + {"white", "white", "\033[37m", false}, + {"none", "none", "", false}, + {"empty", "", "", false}, + {"invalid", "invalid", "", true}, + {"case insensitive", "RED", "\033[31m", false}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { t.Parallel() - result := getColorCode(tt.colorName) - assert.Equal(t, tt.expected, result) + result, err := getColorCode(tt.colorName) + if tt.wantErr { + assert.Error(t, err) + } else { + assert.NoError(t, err) + assert.Equal(t, tt.expected, result) + } }) } } diff --git a/pkg/formatter/testmain_test.go b/pkg/formatter/testmain_test.go new file mode 100644 index 0000000..11dff5d --- /dev/null +++ b/pkg/formatter/testmain_test.go @@ -0,0 +1,11 @@ +package formatter + +import ( + "testing" + + "go.uber.org/goleak" +) + +func TestMain(m *testing.M) { + goleak.VerifyTestMain(m) +} diff --git a/pkg/processor/processor.go b/pkg/processor/processor.go index cf2c28e..2b0a8d5 100644 --- a/pkg/processor/processor.go +++ b/pkg/processor/processor.go @@ -103,34 +103,30 @@ type LineFilter interface { // Processor handles real-time processing of command output streams. type Processor struct { - formatter Formatter - filter LineFilter - output io.Writer - wg sync.WaitGroup - errors []error - mutex sync.Mutex - cancel context.CancelFunc - stopCh chan struct{} - stopOnce sync.Once + formatter Formatter + filter LineFilter + output io.Writer + wg sync.WaitGroup + errors []error + mutex sync.Mutex + parentDone <-chan struct{} // closed when parent context is cancelled; nil if no WithContext + stopCh chan struct{} + readers []io.Reader // stored so Stop() can close them to unblock scanners + stopOnce sync.Once } // Option defines a function that configures a Processor. type Option func(*Processor) -// WithContext sets a cancellable context for the processor. -// The derived context's cancel function is called when Stop() is invoked, -// and the done channel is used to propagate cancellation to ProcessStreams. +// WithContext enables context-based cancellation for the processor. +// When Stop() is called or the parent context is cancelled, it signals +// ProcessStreams to cancel processing. No goroutines are created until +// ProcessStreams is called, so it is safe to create a processor with +// WithContext and never call ProcessStreams. func WithContext(ctx context.Context) Option { return func(p *Processor) { - derived, cancel := context.WithCancel(ctx) //nolint:gosec // G118 - cancel is called via Stop() - p.cancel = cancel + p.parentDone = ctx.Done() p.stopCh = make(chan struct{}) - - // Monitor the derived context and close stopCh when it's done - go func() { - <-derived.Done() - close(p.stopCh) - }() } } @@ -147,7 +143,6 @@ func New(formatter Formatter, output io.Writer, opts ...Option) *Processor { p := &Processor{ formatter: formatter, output: output, - cancel: func() {}, errors: make([]error, 0), } @@ -164,20 +159,14 @@ func (p *Processor) ProcessStreams(ctx context.Context, stdout, stderr io.Reader return pkgerrors.ErrReadersNil } - // If WithContext was used, merge the stop channel so either the passed - // ctx or Stop() can cancel processing. - if p.stopCh != nil { - var mergedCancel context.CancelFunc - ctx, mergedCancel = context.WithCancel(ctx) - defer mergedCancel() + p.mutex.Lock() + p.readers = []io.Reader{stdout, stderr} + p.mutex.Unlock() - go func() { - select { - case <-p.stopCh: - mergedCancel() - case <-ctx.Done(): - } - }() + if p.stopCh != nil { + var cancel context.CancelFunc + ctx, cancel = p.setupCancellation(ctx) + defer cancel() } const streamCount = 2 @@ -199,19 +188,39 @@ func (p *Processor) ProcessStreams(ctx context.Context, stdout, stderr io.Reader p.wg.Wait() - if len(p.errors) > 0 { - return fmt.Errorf("%w: %v", pkgerrors.ErrProcessingErrors, p.errors) + // Clear reader references so Stop() won't close them — the executor + // owns these pipes and will close them via Cleanup(). + p.mutex.Lock() + p.readers = nil + p.mutex.Unlock() + + p.Stop() // Clean up cancellation goroutines + + if errs := p.GetErrors(); len(errs) > 0 { + return fmt.Errorf("%w: %v", pkgerrors.ErrProcessingErrors, errs) } return nil } -// Stop cancels the processor context to stop stream processing. +// Stop signals the processor to stop stream processing. // Safe to call multiple times - subsequent calls are no-ops. +// If the readers implement io.Closer, they are closed to unblock +// any in-progress scanner.Scan() calls. func (p *Processor) Stop() { p.stopOnce.Do(func() { - if p.cancel != nil { - p.cancel() + if p.stopCh != nil { + close(p.stopCh) + } + + p.mutex.Lock() + readers := p.readers + p.mutex.Unlock() + + for _, r := range readers { + if c, ok := r.(io.Closer); ok { + _ = c.Close() + } } }) } @@ -230,6 +239,7 @@ func (p *Processor) Wait(timeout time.Duration) error { return nil case <-time.After(timeout): p.Stop() + <-done return fmt.Errorf("%w after %v", pkgerrors.ErrProcessorTimeout, timeout) } } @@ -244,6 +254,32 @@ func (p *Processor) GetErrors() []error { return errors } +// setupCancellation wires stopCh and parent context into the given ctx. +// Returns the enhanced ctx and a cleanup function that must be deferred. +func (p *Processor) setupCancellation(ctx context.Context) (context.Context, context.CancelFunc) { + ctx, mergedCancel := context.WithCancel(ctx) //nolint:gosec // G118 - cancel is returned to caller + + go func() { + select { + case <-p.stopCh: + mergedCancel() + case <-ctx.Done(): + } + }() + + if p.parentDone != nil { + go func() { + select { + case <-p.parentDone: + p.Stop() + case <-p.stopCh: + } + }() + } + + return ctx, mergedCancel +} + // processStream reads lines from a single stream using [bufio.Scanner]. // // Scanner buffer configuration: @@ -287,12 +323,6 @@ func (p *Processor) processStream(ctx context.Context, stream io.Reader, streamT scanner.Buffer(buf, maxScannerSize) for scanner.Scan() { - select { - case <-ctx.Done(): - return fmt.Errorf("context cancelled: %w", ctx.Err()) - default: - } - line := scanner.Text() if p.filter != nil && !p.filter.ShouldInclude(line) { @@ -304,6 +334,14 @@ func (p *Processor) processStream(ctx context.Context, stream io.Reader, streamT if _, err := p.output.Write([]byte(formattedLine + "\n")); err != nil { return fmt.Errorf("failed to write to output: %w", err) } + + // Check for context cancellation after writing the line, not before, + // so that already-scanned lines are never silently dropped. + select { + case <-ctx.Done(): + return nil + default: + } } if err := scanner.Err(); err != nil { @@ -322,13 +360,12 @@ func (p *Processor) processStream(ctx context.Context, stream io.Reader, streamT } // isExpectedStreamError returns true for errors that occur during normal -// process shutdown: EOF, closed file descriptors, and closed pipes. +// process shutdown: closed file descriptors and closed pipes. +// Note: bufio.Scanner.Err() never returns io.EOF (it returns nil at EOF), +// and errors.Is already unwraps *os.PathError chains, so only os.ErrClosed +// needs to be checked. func isExpectedStreamError(err error) bool { - if errors.Is(err, io.EOF) || errors.Is(err, os.ErrClosed) { - return true - } - var pathErr *os.PathError - return errors.As(err, &pathErr) && errors.Is(pathErr.Err, os.ErrClosed) + return errors.Is(err, os.ErrClosed) } func (p *Processor) addError(err error) { diff --git a/pkg/processor/processor_test.go b/pkg/processor/processor_test.go index 39de5bf..c89791f 100644 --- a/pkg/processor/processor_test.go +++ b/pkg/processor/processor_test.go @@ -273,9 +273,9 @@ func TestProcessor_ProcessStreams_ContextCancellation(t *testing.T) { err := p.ProcessStreams(ctx, slowReader, strings.NewReader("")) - // Should return an error due to context cancellation - assert.Error(t, err) - assert.Contains(t, err.Error(), "processing errors occurred") + // Context cancellation is a normal shutdown path, not an error. + // Some lines may have been processed before cancellation. + assert.NoError(t, err) } func TestProcessor_ProcessStreams_FormatterError(t *testing.T) { diff --git a/pkg/processor/testmain_test.go b/pkg/processor/testmain_test.go new file mode 100644 index 0000000..43f8ac2 --- /dev/null +++ b/pkg/processor/testmain_test.go @@ -0,0 +1,11 @@ +package processor_test + +import ( + "testing" + + "go.uber.org/goleak" +) + +func TestMain(m *testing.M) { + goleak.VerifyTestMain(m) +}