224 changes: 224 additions & 0 deletions cmd/collect_batch.go
@@ -0,0 +1,224 @@
package cmd

import (
"bufio"
"encoding/json"
"fmt"
"io"
"net/http"
"net/url"
"os"
"path/filepath"
"strings"
"sync"
"time"

"github.com/itchyny/gojq"
"github.com/mattn/go-isatty"
"github.com/schollz/progressbar/v3"
"github.com/spf13/cobra"
)

// collectBatchCmd represents the collect batch command
var collectBatchCmd = NewCollectBatchCmd()

func init() {
    GetCollectCmd().AddCommand(GetCollectBatchCmd())
}

func GetCollectBatchCmd() *cobra.Command {
    return collectBatchCmd
}

func NewCollectBatchCmd() *cobra.Command {
    collectBatchCmd := &cobra.Command{
        Use:   "batch [file|-]",
        Short: "Batch collect from a list of URLs",
        Long:  `Collect multiple resources from a list of URLs provided in a file or via stdin.`,
        Args:  cobra.ExactArgs(1),
        RunE: func(cmd *cobra.Command, args []string) error {
            inputFile := args[0]
            outputDir, _ := cmd.Flags().GetString("output-dir")
            delay, _ := cmd.Flags().GetString("delay")
            skipExisting, _ := cmd.Flags().GetBool("continue")
            parallel, _ := cmd.Flags().GetInt("parallel")
            jqFilter, _ := cmd.Flags().GetString("jq")

            delayDuration, err := time.ParseDuration(delay)
            if err != nil {
                return fmt.Errorf("invalid delay duration: %w", err)
            }

            var reader io.Reader

            if inputFile == "-" {
                reader = os.Stdin
            } else {
                file, err := os.Open(inputFile)
                if err != nil {
                    return fmt.Errorf("error opening input file: %w", err)
                }
                defer file.Close()
                reader = file
            }

            urls, err := readURLs(reader, jqFilter)
            if err != nil {
                return fmt.Errorf("error reading urls: %w", err)
            }

            if err := os.MkdirAll(outputDir, os.ModePerm); err != nil {
                return fmt.Errorf("error creating output directory: %w", err)
            }

            urlsChan := make(chan string, len(urls))
            var wg sync.WaitGroup
            var bar *progressbar.ProgressBar
            var outMutex sync.Mutex

            if isatty.IsTerminal(os.Stdout.Fd()) {
                bar = progressbar.Default(int64(len(urls)))
            }

            for i := 0; i < parallel; i++ {
                wg.Add(1)
                go func() {
                    defer wg.Done()
                    for u := range urlsChan {
                        downloadURL(cmd, u, outputDir, skipExisting, delayDuration, bar, &outMutex)
                        if bar != nil {
                            bar.Add(1)
                        }
                    }
                }()
            }

            for _, u := range urls {
                urlsChan <- u
            }
            close(urlsChan)

            wg.Wait()

            return nil
        },
    }

    collectBatchCmd.Flags().IntP("parallel", "p", 1, "Number of concurrent downloads")
    collectBatchCmd.Flags().String("delay", "0s", "Delay between requests")
    collectBatchCmd.Flags().StringP("output-dir", "o", ".", "Base output directory")
    collectBatchCmd.Flags().Bool("continue", false, "Skip already collected files")
    collectBatchCmd.Flags().String("jq", "", "jq filter to extract URLs from JSON input")

    return collectBatchCmd
}

func downloadURL(cmd *cobra.Command, u, outputDir string, skipExisting bool, delayDuration time.Duration, bar *progressbar.ProgressBar, outMutex *sync.Mutex) {

critical

There is a critical race condition in this function. When multiple parallel workers try to download URLs that map to the same file path, they can interfere with each other, leading to corrupted files. For example, one worker might be writing to a file while another truncates it by calling os.Create.

To fix this, you should serialize all file operations for a given path. You can achieve this using a sync.Map to hold a mutex for each file path.

  1. Add this package-level variable:
    var fileLocks sync.Map
  2. Add locking logic inside downloadURL:
    func downloadURL(...) {
        // ... after getting filePath
        filePath := filepath.Join(outputDir, fileName)
    
        mu, _ := fileLocks.LoadOrStore(filePath, &sync.Mutex{})
        mu.(*sync.Mutex).Lock()
        defer mu.(*sync.Mutex).Unlock()
    
        // ... rest of the original function body
    }
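
For reference, a self-contained sketch of the same per-path locking pattern outside the command context (illustrative only; the writeFile helper, the file name, and the worker count are invented for the demo):

package main

import (
    "fmt"
    "os"
    "sync"
)

// fileLocks maps a file path to its own mutex so that only one goroutine
// at a time can create or write that path.
var fileLocks sync.Map

// writeFile serializes writes per path via the shared fileLocks map.
func writeFile(path string, data []byte) error {
    mu, _ := fileLocks.LoadOrStore(path, &sync.Mutex{})
    mu.(*sync.Mutex).Lock()
    defer mu.(*sync.Mutex).Unlock()
    return os.WriteFile(path, data, 0o644)
}

func main() {
    var wg sync.WaitGroup
    for i := 0; i < 4; i++ { // four workers racing on the same path
        wg.Add(1)
        go func(i int) {
            defer wg.Done()
            if err := writeFile("same-name.txt", []byte(fmt.Sprintf("worker %d\n", i))); err != nil {
                fmt.Println("write error:", err)
            }
        }(i)
    }
    wg.Wait()
}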

    fileName, err := getFileNameFromURL(u)
    if err != nil {
        logMessage(cmd, fmt.Sprintf("Skipping invalid URL %s: %v", u, err), bar, outMutex)
        return
    }
    filePath := filepath.Join(outputDir, fileName)

    if skipExisting {
        if _, err := os.Stat(filePath); err == nil {
            logMessage(cmd, fmt.Sprintf("Skipping already downloaded file: %s", filePath), bar, outMutex)
            return
        }
    }

    resp, err := http.Get(u)
    if err != nil {
        logMessage(cmd, fmt.Sprintf("Error downloading %s: %v", u, err), bar, outMutex)
        return
    }
    defer resp.Body.Close()

    out, err := os.Create(filePath)
    if err != nil {
        logMessage(cmd, fmt.Sprintf("Error creating file for %s: %v", u, err), bar, outMutex)
        return
    }
    defer out.Close()

    _, err = io.Copy(out, resp.Body)
    if err != nil {
        logMessage(cmd, fmt.Sprintf("Error saving content for %s: %v", u, err), bar, outMutex)
        return
    }

    logMessage(cmd, fmt.Sprintf("Downloaded %s to %s", u, filePath), bar, outMutex)

    if delayDuration > 0 {
        time.Sleep(delayDuration)
    }
Comment on lines +153 to +155

medium

The current implementation of --delay is not ideal for rate-limiting parallel downloads. It causes each worker to sleep independently, leading to request bursts rather than a steady global rate.
A better approach is to use a global rate limiter that all workers respect: remove this time.Sleep and have every worker wait on a shared rate-limiting mechanism (for example, a time.Ticker created in RunE) before starting each download.

Example:

// In RunE:
var limiter <-chan time.Time
if delayDuration > 0 {
    ticker := time.NewTicker(delayDuration)
    defer ticker.Stop()
    limiter = ticker.C
}
// ...
// In worker goroutine:
for u := range urlsChan {
    if limiter != nil {
        <-limiter
    }
    downloadURL(...)
    // ...
}
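
A self-contained variant of the same idea, for illustration (the 200ms interval, worker count, and job count are arbitrary choices for the demo):

package main

import (
    "fmt"
    "sync"
    "time"
)

func main() {
    jobs := make(chan int)
    var wg sync.WaitGroup

    // One ticker shared by every worker: at most one job starts per tick,
    // regardless of how many workers are running.
    ticker := time.NewTicker(200 * time.Millisecond)
    defer ticker.Stop()

    for w := 0; w < 3; w++ {
        wg.Add(1)
        go func(w int) {
            defer wg.Done()
            for j := range jobs {
                <-ticker.C // wait for the next global slot
                fmt.Printf("worker %d handled job %d at %s\n", w, j, time.Now().Format("15:04:05.000"))
            }
        }(w)
    }

    for j := 0; j < 6; j++ {
        jobs <- j
    }
    close(jobs)
    wg.Wait()
}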

}

func logMessage(cmd *cobra.Command, msg string, bar *progressbar.ProgressBar, outMutex *sync.Mutex) {
    if bar != nil {
        bar.Describe(msg)
    } else {
        outMutex.Lock()
        defer outMutex.Unlock()
        fmt.Fprintln(cmd.OutOrStdout(), msg)
    }
}

func readURLs(reader io.Reader, jqFilter string) ([]string, error) {
    if jqFilter == "" {
        var urls []string
        scanner := bufio.NewScanner(reader)
        for scanner.Scan() {
            line := strings.TrimSpace(scanner.Text())
            if line != "" {
                urls = append(urls, line)
            }
        }
        if err := scanner.Err(); err != nil {
            return nil, err
        }
        return urls, nil
    }

    query, err := gojq.Parse(jqFilter)
    if err != nil {
        return nil, fmt.Errorf("error parsing jq filter: %w", err)
    }

    var input interface{}
    decoder := json.NewDecoder(reader)
    if err := decoder.Decode(&input); err != nil {

medium

decoder.Decode(&input) reads the entire JSON content into memory before processing. This can lead to high memory consumption and performance issues with very large JSON files. Consider using a streaming JSON parser if large inputs are expected, to process the data more efficiently without loading it all into memory at once.
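
One possible direction, as a rough sketch: if the input is known to be a large top-level JSON array, encoding/json can decode it element by element instead of all at once. The streamStrings helper and the sample input below are hypothetical, and a jq filter would then have to be applied per element rather than to the whole document:

package main

import (
    "encoding/json"
    "fmt"
    "os"
    "strings"
)

// streamStrings decodes a top-level JSON array one element at a time,
// collecting only string elements, without buffering the whole document.
func streamStrings(dec *json.Decoder) ([]string, error) {
    if _, err := dec.Token(); err != nil { // consume the opening '['
        return nil, err
    }
    var out []string
    for dec.More() {
        var v interface{}
        if err := dec.Decode(&v); err != nil {
            return nil, err
        }
        if s, ok := v.(string); ok {
            out = append(out, s)
        }
    }
    if _, err := dec.Token(); err != nil { // consume the closing ']'
        return nil, err
    }
    return out, nil
}

func main() {
    input := `["https://example.com/a.json", "https://example.com/b.json"]`
    urls, err := streamStrings(json.NewDecoder(strings.NewReader(input)))
    if err != nil {
        fmt.Fprintln(os.Stderr, err)
        return
    }
    fmt.Println(urls)
}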

        return nil, fmt.Errorf("error decoding json: %w", err)
    }

    var urls []string
    iter := query.Run(input)
    for {
        v, ok := iter.Next()
        if !ok {
            break
        }
        if err, ok := v.(error); ok {
            return nil, fmt.Errorf("error executing jq filter: %w", err)
        }
        if s, ok := v.(string); ok {
            urls = append(urls, s)
        }
    }
    return urls, nil
}

func getFileNameFromURL(rawURL string) (string, error) {
    parsedURL, err := url.Parse(rawURL)
    if err != nil {
        return "", err
    }
    if parsedURL.Scheme != "http" && parsedURL.Scheme != "https" {
        return "", fmt.Errorf("invalid URL scheme: %s", parsedURL.Scheme)
    }
    if parsedURL.Path == "" || parsedURL.Path == "/" {
        return "index.html", nil
    }
    return filepath.Base(parsedURL.Path), nil
}