Skip to content
Closed
17 changes: 15 additions & 2 deletions cmd/gau/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"io"
"os"
"sync"
"time"

"github.com/lc/gau/v2/pkg/output"
"github.com/lc/gau/v2/runner"
Expand All @@ -14,6 +15,8 @@ import (
)

func main() {
startTime := time.Now()

cfg, err := flags.New().ReadInConfig()
if err != nil {
log.Warnf("error reading config: %v", err)
Expand Down Expand Up @@ -43,12 +46,13 @@ func main() {
}

var writeWg sync.WaitGroup
var urlCount int64
writeWg.Add(1)
go func(out io.Writer, JSON bool) {
defer writeWg.Done()
if JSON {
output.WriteURLsJSON(out, results, config.Blacklist, config.RemoveParameters)
} else if err = output.WriteURLs(out, results, config.Blacklist, config.RemoveParameters); err != nil {
output.WriteURLsJSON(out, results, config.Blacklist, config.RemoveParameters, &urlCount)
} else if err = output.WriteURLs(out, results, config.Blacklist, config.RemoveParameters, &urlCount); err != nil {
log.Fatalf("error writing results: %v\n", err)
}
}(out, config.JSON)
Expand Down Expand Up @@ -85,4 +89,13 @@ func main() {

// wait for writer to finish output
writeWg.Wait()

// Calculate duration
duration := time.Since(startTime)

// Log summary
log.Infof("=== Gau Execution Summary ===")
log.Infof("Total URLs: %d", urlCount)
log.Infof("Duration: %v", duration)
log.Infof("=============================")
}
90 changes: 87 additions & 3 deletions pkg/httpclient/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,10 @@ package httpclient

import (
"errors"
"fmt"
"math"
"math/rand"
"strings"
"time"

"github.com/valyala/fasthttp"
Expand All @@ -12,8 +15,23 @@ var (
ErrNilResponse = errors.New("unexpected nil response")
ErrNon200Response = errors.New("API responded with non-200 status code")
ErrBadRequest = errors.New("API responded with 400 status code")
ErrRateLimited = errors.New("API rate limited")
)

// StatusCodeError is an error type that carries an HTTP status code
type StatusCodeError struct {
Code int
Msg string
}

func (e *StatusCodeError) Error() string {
return fmt.Sprintf("%s (status code: %d)", e.Msg, e.Code)
}

func (e *StatusCodeError) Unwrap() error {
return errors.New(e.Msg)
}

type Header struct {
Key string
Value string
Expand All @@ -39,6 +57,28 @@ func MakeRequest(c *fasthttp.Client, url string, maxRetries uint, timeout uint,
req.Header.Set("Accept", "*/*")
req.SetRequestURI(url)
respBody, err = doReq(c, req, timeout)

// Check if we should retry based on error type
if err != nil {
// Exponential backoff: 1s, 2s, 4s, 8s, 16s... with cap at 30s
backoffDuration := time.Duration(math.Pow(2, float64(retries-i))) * time.Second
if backoffDuration > 30*time.Second {
backoffDuration = 30 * time.Second
}
if i > 0 && shouldRetry(err) {
time.Sleep(backoffDuration)
continue
}
}

// Check for rate limit (429) or bad request (400) from error
if err != nil {
statusCode := getStatusCodeFromError(err)
if statusCode == 429 || statusCode == 400 {
return nil, ErrRateLimited
}
}

if err == nil {
break
}
Expand All @@ -49,6 +89,48 @@ func MakeRequest(c *fasthttp.Client, url string, maxRetries uint, timeout uint,
return respBody, nil
}

// shouldRetry determines if an error should trigger a retry
func shouldRetry(err error) bool {
if err == nil {
return false
}
// Network errors that should trigger retry
errMsg := err.Error()
retryableErrors := []string{
"connection refused",
"connection reset",
"connection timed out",
"no such host",
"timeout",
"server closed connection",
"network is unreachable",
"i/o timeout",
}
for _, pattern := range retryableErrors {
if containsIgnoreCase(errMsg, pattern) {
return true
}
}
return false
}

// containsIgnoreCase checks if s contains substr (case-insensitive)
func containsIgnoreCase(s, substr string) bool {
return strings.Contains(strings.ToLower(s), strings.ToLower(substr))
}

// getStatusCodeFromError attempts to extract status code from error
func getStatusCodeFromError(err error) int {
if err == nil {
return 0
}
var statusErr *StatusCodeError
if errors.As(err, &statusErr) {
return statusErr.Code
}
return 0
}

// doReq handles http requests
func doReq(c *fasthttp.Client, req *fasthttp.Request, timeout uint) ([]byte, error) {
resp := fasthttp.AcquireResponse()
Expand All @@ -58,10 +140,12 @@ func doReq(c *fasthttp.Client, req *fasthttp.Request, timeout uint) ([]byte, err
return nil, err
}
if resp.StatusCode() != 200 {
if resp.StatusCode() == 400 {
return nil, ErrBadRequest
errMsg := fmt.Sprintf("API responded with status code %d", resp.StatusCode())
// Return wrapped error with status code for proper handling
return nil, &StatusCodeError{
Code: resp.StatusCode(),
Msg: errMsg,
}
return nil, ErrNon200Response
}
if resp.Body() == nil {
return nil, ErrNilResponse
Expand Down
30 changes: 23 additions & 7 deletions pkg/output/output.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,10 @@ package output
import (
"io"
"net/url"
"os"
"path"
"strings"
"sync/atomic"

mapset "github.com/deckarep/golang-set/v2"
jsoniter "github.com/json-iterator/go"
Expand All @@ -15,49 +17,63 @@ type JSONResult struct {
Url string `json:"url"`
}

func WriteURLs(writer io.Writer, results <-chan string, blacklistMap mapset.Set[string], RemoveParameters bool) error {
func WriteURLs(writer io.Writer, results <-chan string, blacklistMap mapset.Set[string], RemoveParameters bool, urlCount *int64) error {
lastURL := mapset.NewThreadUnsafeSet[string]()
for result := range results {
buf := bytebufferpool.Get()
u, err := url.Parse(result)
if err != nil {
continue
}
if path.Ext(u.Path) != "" && blacklistMap.Contains(strings.ToLower(path.Ext(u.Path))) {
ext := strings.TrimPrefix(strings.ToLower(path.Ext(u.Path)), ".")
if ext != "" && blacklistMap.Contains(ext) {
continue
}

if RemoveParameters && !lastURL.Contains(u.Host+u.Path) {
continue
if RemoveParameters {
if lastURL.Contains(u.Host + u.Path) {
continue // already seen this endpoint, skip duplicate params
}
lastURL.Add(u.Host + u.Path)
}
lastURL.Add(u.Host + u.Path)

buf.B = append(buf.B, []byte(result)...)
buf.B = append(buf.B, "\n"...)
_, err = writer.Write(buf.B)
if err != nil {
return err
}
atomic.AddInt64(urlCount, 1)
// Real-time flush: sync stdout after each write to prevent data loss
if writer == os.Stdout {
os.Stdout.Sync()
}
bytebufferpool.Put(buf)
}
return nil
}

func WriteURLsJSON(writer io.Writer, results <-chan string, blacklistMap mapset.Set[string], RemoveParameters bool) {
func WriteURLsJSON(writer io.Writer, results <-chan string, blacklistMap mapset.Set[string], RemoveParameters bool, urlCount *int64) {
var jr JSONResult
enc := jsoniter.NewEncoder(writer)
for result := range results {
u, err := url.Parse(result)
if err != nil {
continue
}
if blacklistMap.Contains(strings.ToLower(path.Ext(u.Path))) {
ext := strings.TrimPrefix(strings.ToLower(path.Ext(u.Path)), ".")
if ext != "" && blacklistMap.Contains(ext) {
continue
}
jr.Url = result
if err := enc.Encode(jr); err != nil {
// todo: handle this error
continue
}
atomic.AddInt64(urlCount, 1)
// Real-time flush: sync stdout after each write to prevent data loss
if writer == os.Stdout {
os.Stdout.Sync()
}
}
}
Loading