diff --git a/cmd/httpx/httpx.go b/cmd/httpx/httpx.go index 77e54f4c..78037c08 100644 --- a/cmd/httpx/httpx.go +++ b/cmd/httpx/httpx.go @@ -73,21 +73,26 @@ func main() { c := make(chan os.Signal, 1) signal.Notify(c, os.Interrupt) go func() { - for range c { - gologger.Info().Msgf("CTRL+C pressed: Exiting\n") - httpxRunner.Close() - if options.ShouldSaveResume() { - gologger.Info().Msgf("Creating resume file: %s\n", runner.DefaultResumeFile) - err := httpxRunner.SaveResumeConfig() - if err != nil { - gologger.Error().Msgf("Couldn't create resume file: %s\n", err) - } - } - os.Exit(1) - } + // First Ctrl+C: stop dispatching, let in-flight requests finish + <-c + gologger.Info().Msgf("CTRL+C pressed: Exiting\n") + httpxRunner.Interrupt() + // Second Ctrl+C: force exit + <-c + gologger.Info().Msgf("Forcing exit\n") + os.Exit(1) }() httpxRunner.RunEnumeration() + + if httpxRunner.IsInterrupted() && options.ShouldSaveResume() { + gologger.Info().Msgf("Creating resume file: %s\n", runner.DefaultResumeFile) + err := httpxRunner.SaveResumeConfig() + if err != nil { + gologger.Error().Msgf("Couldn't create resume file: %s\n", err) + } + } + httpxRunner.Close() } diff --git a/runner/runner.go b/runner/runner.go index 14b91625..8647c23f 100644 --- a/runner/runner.go +++ b/runner/runner.go @@ -97,12 +97,32 @@ type Runner struct { simHashes gcache.Cache[uint64, struct{}] // Include simHashes for efficient duplicate detection httpApiEndpoint *Server authProvider authprovider.AuthProvider + interruptCh chan struct{} } func (r *Runner) HTTPX() *httpx.HTTPX { return r.hp } +// Interrupt signals the runner to stop dispatching new items. +func (r *Runner) Interrupt() { + select { + case <-r.interruptCh: + default: + close(r.interruptCh) + } +} + +// IsInterrupted returns true if the runner was interrupted. +func (r *Runner) IsInterrupted() bool { + select { + case <-r.interruptCh: + return true + default: + return false + } +} + // picked based on try-fail but it seems to close to one it's used https://www.hackerfactor.com/blog/index.php?/archives/432-Looks-Like-It.html#c1992 var hammingDistanceThreshold int = 22 @@ -121,7 +141,8 @@ type pHashUrl struct { // New creates a new client for running enumeration process. func New(options *Options) (*Runner, error) { runner := &Runner{ - options: options, + options: options, + interruptCh: make(chan struct{}), } var err error if options.Wappalyzer != nil { @@ -664,6 +685,16 @@ func (r *Runner) streamInput() (chan string, error) { go func() { defer close(out) + // trySend sends item to out, returning false if interrupted + trySend := func(item string) bool { + select { + case <-r.interruptCh: + return false + case out <- item: + return true + } + } + if fileutil.FileExists(r.options.InputFile) { // check if input mode is specified for special format handling if format := r.getInputFormat(); format != nil { @@ -676,9 +707,9 @@ func (r *Runner) streamInput() (chan string, error) { if err := format.Parse(finput, func(item string) bool { item = strings.TrimSpace(item) if r.options.SkipDedupe || r.testAndSet(item) { - out <- item + return trySend(item) } - return true + return !r.IsInterrupted() }); err != nil { gologger.Error().Msgf("Could not parse input file '%s': %s\n", r.options.InputFile, err) return @@ -690,7 +721,9 @@ func (r *Runner) streamInput() (chan string, error) { } for item := range fchan { if r.options.SkipDedupe || r.testAndSet(item) { - out <- item + if !trySend(item) { + return + } } } } @@ -706,7 +739,9 @@ func (r *Runner) streamInput() (chan string, error) { } for item := range fchan { if r.options.SkipDedupe || r.testAndSet(item) { - out <- item + if !trySend(item) { + return + } } } } @@ -718,7 +753,9 @@ func (r *Runner) streamInput() (chan string, error) { } for item := range fchan { if r.options.SkipDedupe || r.testAndSet(item) { - out <- item + if !trySend(item) { + return + } } } } @@ -1402,6 +1439,12 @@ func (r *Runner) RunEnumeration() { wg, _ := syncutil.New(syncutil.WithSize(r.options.Threads)) processItem := func(k string) error { + select { + case <-r.interruptCh: + return nil + default: + } + if r.options.resumeCfg != nil { r.options.resumeCfg.current = k r.options.resumeCfg.currentIndex++ @@ -1447,6 +1490,9 @@ func (r *Runner) RunEnumeration() { if r.options.Stream { for item := range streamChan { + if r.IsInterrupted() { + break + } _ = processItem(item) } } else { diff --git a/runner/runner_test.go b/runner/runner_test.go index 10b8320b..832c4e36 100644 --- a/runner/runner_test.go +++ b/runner/runner_test.go @@ -15,6 +15,74 @@ import ( "github.com/stretchr/testify/require" ) +func TestRunner_resumeAfterInterrupt(t *testing.T) { + domains := []string{"a.com", "b.com", "c.com", "d.com", "e.com", "f.com", "g.com", "h.com", "i.com", "j.com"} + interruptAfter := 4 + + // --- Full scan (reference): process all domains without interrupt --- + rFull, err := New(&Options{}) + require.Nil(t, err, "could not create httpx runner") + rFull.options.resumeCfg = &ResumeCfg{} + var fullOutput []string + for _, d := range domains { + rFull.options.resumeCfg.current = d + rFull.options.resumeCfg.currentIndex++ + fullOutput = append(fullOutput, d) + } + + // --- Interrupted scan: process items, interrupt after interruptAfter --- + rInt, err := New(&Options{}) + require.Nil(t, err, "could not create httpx runner") + rInt.options.resumeCfg = &ResumeCfg{} + var interruptedOutput []string + for _, d := range domains { + // same check as processItem: bail out if interrupted + select { + case <-rInt.interruptCh: + continue + default: + } + + rInt.options.resumeCfg.current = d + rInt.options.resumeCfg.currentIndex++ + interruptedOutput = append(interruptedOutput, d) + + if len(interruptedOutput) == interruptAfter { + rInt.Interrupt() + } + } + + // simulate SaveResumeConfig: save the index after interrupt + savedIndex := rInt.options.resumeCfg.currentIndex + + // the saved index must equal exactly the number of items that were processed + require.Equal(t, interruptAfter, savedIndex, "resume index should equal number of completed items") + // every domain before the index must be in the interrupted output + require.Equal(t, domains[:interruptAfter], interruptedOutput, "interrupted output should contain exactly the first N domains") + + // --- Resumed scan: load saved index, skip already-processed items --- + rRes, err := New(&Options{}) + require.Nil(t, err, "could not create httpx runner") + rRes.options.resumeCfg = &ResumeCfg{Index: savedIndex} + var resumedOutput []string + for _, d := range domains { + // same resume-skip logic as processItem + rRes.options.resumeCfg.current = d + rRes.options.resumeCfg.currentIndex++ + if rRes.options.resumeCfg.currentIndex <= rRes.options.resumeCfg.Index { + continue + } + resumedOutput = append(resumedOutput, d) + } + + // every domain after the index must be in the resumed output + require.Equal(t, domains[interruptAfter:], resumedOutput, "resumed output should contain exactly the remaining domains") + + // union of interrupted + resumed must equal the full scan + combined := append(interruptedOutput, resumedOutput...) + require.Equal(t, fullOutput, combined, "interrupted + resumed should equal full scan") +} + func TestRunner_domain_targets(t *testing.T) { options := &Options{} r, err := New(options)