Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
215 changes: 185 additions & 30 deletions internal/rows/arrowbased/batchloader.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,9 @@ import (
"context"
"fmt"
"io"
"math"
"math/rand"
"strconv"
"strings"
"time"

Expand Down Expand Up @@ -193,6 +196,9 @@ func (bi *cloudIPCStreamIterator) Next() (io.Reader, error) {
minTimeToExpiry: bi.cfg.MinTimeToExpiry,
speedThresholdMbps: bi.cfg.CloudFetchSpeedThresholdMbps,
httpClient: bi.httpClient,
retryMax: bi.cfg.RetryMax,
retryWaitMin: bi.cfg.RetryWaitMin,
retryWaitMax: bi.cfg.RetryWaitMax,
}
task.Run()
bi.downloadTasks.Enqueue(task)
Expand Down Expand Up @@ -252,6 +258,9 @@ type cloudFetchDownloadTask struct {
resultChan chan cloudFetchDownloadTaskResult
speedThresholdMbps float64
httpClient *http.Client
retryMax int
retryWaitMin time.Duration
retryWaitMax time.Duration
}

func (cft *cloudFetchDownloadTask) GetResult() (io.Reader, int64, error) {
Expand Down Expand Up @@ -295,20 +304,32 @@ func (cft *cloudFetchDownloadTask) Run() {
cft.link.RowCount,
)
downloadStart := time.Now()
data, err := fetchBatchBytes(cft.ctx, cft.link, cft.minTimeToExpiry, cft.speedThresholdMbps, cft.httpClient)
rawBody, err := fetchBatchBytes(
cft.ctx,
cft.link,
cft.minTimeToExpiry,
cft.speedThresholdMbps,
cft.httpClient,
cft.retryMax,
cft.retryWaitMin,
cft.retryWaitMax,
)
if err != nil {
cft.sendResult(cloudFetchDownloadTaskResult{data: nil, err: err})
return
}

// Read all data into memory before closing
buf, err := io.ReadAll(getReader(data, cft.useLz4Compression))
data.Close() //nolint:errcheck,gosec // G104: close after reading data
downloadMs := time.Since(downloadStart).Milliseconds()
if err != nil {
cft.sendResult(cloudFetchDownloadTaskResult{data: nil, err: err})
return
buf := rawBody
if cft.useLz4Compression {
// Decompression sits outside the retry loop: malformed LZ4 is data
// corruption, not a transient network condition.
buf, err = io.ReadAll(lz4.NewReader(bytes.NewReader(rawBody)))
if err != nil {
cft.sendResult(cloudFetchDownloadTaskResult{data: nil, err: err})
return
}
}
downloadMs := time.Since(downloadStart).Milliseconds()

logger.Debug().Msgf(
"CloudFetch: downloaded data for link at offset %d row count %d",
Expand Down Expand Up @@ -350,43 +371,177 @@ func logCloudFetchSpeed(fullURL string, contentLength int64, duration time.Durat
}
}

// fetchBatchBytes downloads a single CloudFetch result link and returns the
// raw response body, still compressed if the server used LZ4. Connection-time
// failures, retryable HTTP statuses, and mid-stream body read failures are
// retried up to retryMax times with exponential backoff and equal jitter.
// Decompression and IPC parsing stay with the caller because those failures are
// not transient network conditions.
//
// Link expiry is rechecked after each backoff: a long retry chain may outlive
// a presigned URL, and continuing past expiry is guaranteed to fail.
func fetchBatchBytes(
ctx context.Context,
link *cli_service.TSparkArrowResultLink,
minTimeToExpiry time.Duration,
speedThresholdMbps float64,
httpClient *http.Client,
) (io.ReadCloser, error) {
if isLinkExpired(link.ExpiryTime, minTimeToExpiry) {
return nil, errors.New(dbsqlerr.ErrLinkExpired)
retryMax int,
retryWaitMin time.Duration,
retryWaitMax time.Duration,
) ([]byte, error) {
if retryMax < 0 {
retryMax = 0
}

var (
lastErr error
lastStatus int
lastRetryAfter string
)

for attempt := 0; attempt <= retryMax; attempt++ {
if attempt > 0 {
wait := cloudFetchBackoff(attempt, retryWaitMin, retryWaitMax, lastRetryAfter)
logger.Debug().Msgf(
"CloudFetch: retrying download of link at offset %d (attempt %d/%d) in %v; lastStatus=%d lastErr=%v",
link.StartRowOffset, attempt, retryMax, wait, lastStatus, lastErr,
)
t := time.NewTimer(wait)
select {
case <-ctx.Done():
if !t.Stop() {
<-t.C
}
return nil, ctx.Err()
case <-t.C:
}
}

// Check link expiry *after* backoff: a long retry chain may outlive a
// presigned URL, and there's no point spending another HTTP attempt
// (or another retry) on a link we know will be rejected.
if isLinkExpired(link.ExpiryTime, minTimeToExpiry) {
return nil, errors.New(dbsqlerr.ErrLinkExpired)
}

req, err := http.NewRequestWithContext(ctx, "GET", link.FileLink, nil)
if err != nil {
return nil, err
}
if link.HttpHeaders != nil {
for key, value := range link.HttpHeaders {
req.Header.Set(key, value)
}
}

startTime := time.Now()
res, err := httpClient.Do(req)
if err != nil {
// Caller cancellation is terminal; otherwise treat transport errors
// (TCP RST, TLS timeout, etc.) as transient.
if ctx.Err() != nil {
return nil, ctx.Err()
}
lastErr = err
lastStatus = 0
lastRetryAfter = ""
continue
}

if res.StatusCode == http.StatusOK {
// Read the full body inside the retry loop so truncated 200 OK
// responses are retried just like header-time failures.
buf, readErr := io.ReadAll(res.Body)
res.Body.Close() //nolint:errcheck,gosec // G104: close after drain
if readErr != nil {
if ctx.Err() != nil {
return nil, ctx.Err()
}
lastErr = readErr
lastStatus = 0
lastRetryAfter = ""
continue
}
logCloudFetchSpeed(link.FileLink, int64(len(buf)), time.Since(startTime), speedThresholdMbps)
return buf, nil
}

// Drain and close so the underlying connection can be reused.
_, _ = io.Copy(io.Discard, res.Body)
res.Body.Close() //nolint:errcheck,gosec // G104: closing after drain

lastStatus = res.StatusCode
lastErr = nil
lastRetryAfter = res.Header.Get("Retry-After")

if !isCloudFetchRetryableStatus(res.StatusCode) {
msg := fmt.Sprintf("%s: %s %d", errArrowRowsCloudFetchDownloadFailure, "HTTP error", res.StatusCode)
return nil, dbsqlerrint.NewDriverError(ctx, msg, nil)
}
}

// TODO: Retry on HTTP errors
req, err := http.NewRequestWithContext(ctx, "GET", link.FileLink, nil)
if err != nil {
return nil, err
if lastStatus != 0 {
// lastErr is nil here by construction: the HTTP-status branch above
// explicitly clears it on every iteration. The status code is captured
// in msg, so there's no underlying error to wrap.
return nil, dbsqlerrint.NewDriverError(ctx, fmt.Sprintf("%s: %s %d (after %d retries)", errArrowRowsCloudFetchDownloadFailure, "HTTP error", lastStatus, retryMax), nil)
}
msg := fmt.Sprintf("%s: %v (after %d retries)", errArrowRowsCloudFetchDownloadFailure, lastErr, retryMax)
return nil, dbsqlerrint.NewDriverError(ctx, msg, lastErr)
}

// cloudFetchRetryableStatuses lists HTTP status codes from object storage that
// indicate transient conditions and warrant a retry. Mirrors AWS S3 guidance
// for SlowDown (503) / InternalError (500) plus the general 408/429/502/504.
var cloudFetchRetryableStatuses = map[int]struct{}{
http.StatusRequestTimeout: {}, // 408
http.StatusTooManyRequests: {}, // 429
http.StatusInternalServerError: {}, // 500
http.StatusBadGateway: {}, // 502
http.StatusServiceUnavailable: {}, // 503
http.StatusGatewayTimeout: {}, // 504
}

if link.HttpHeaders != nil {
for key, value := range link.HttpHeaders {
req.Header.Set(key, value)
func isCloudFetchRetryableStatus(status int) bool {
_, ok := cloudFetchRetryableStatuses[status]
return ok
}

// cloudFetchBackoff returns the wait before retry attempt N (1-based). The
// base delay is exponential — waitMin * 2^(attempt-1) capped at waitMax — with
// equal jitter applied: the actual sleep is uniformly distributed in
// [base/2, base]. Equal jitter (rather than no jitter) is used to spread
// synchronized retries across the up-to-MaxDownloadThreads concurrent
// downloads, which would otherwise hammer the storage endpoint in lockstep
// after a region-wide blip. If the server returned a parseable integer
// Retry-After header, that value (in seconds) is honored instead, capped at
// waitMax. HTTP-date Retry-After values are ignored — same as the Thrift
// client's backoff.
func cloudFetchBackoff(attempt int, waitMin, waitMax time.Duration, retryAfter string) time.Duration {
if retryAfter != "" {
if secs, err := strconv.ParseInt(retryAfter, 10, 64); err == nil && secs >= 0 {
d := time.Duration(secs) * time.Second
if d > waitMax {
return waitMax
}
return d
}
}

startTime := time.Now()
res, err := httpClient.Do(req)
if err != nil {
return nil, err
expo := float64(waitMin) * math.Pow(2, float64(attempt-1))
if expo > float64(waitMax) || math.IsInf(expo, 0) {
expo = float64(waitMax)
}
if res.StatusCode != http.StatusOK {
msg := fmt.Sprintf("%s: %s %d", errArrowRowsCloudFetchDownloadFailure, "HTTP error", res.StatusCode)
return nil, dbsqlerrint.NewDriverError(ctx, msg, err)
base := time.Duration(expo)
if base <= 0 {
return 0
}

// Log download speed metrics
logCloudFetchSpeed(link.FileLink, res.ContentLength, time.Since(startTime), speedThresholdMbps)

return res.Body, nil
half := base / 2
if half <= 0 {
return base
}
return half + time.Duration(rand.Int63n(int64(half))) //nolint:gosec // G404: jitter only, non-cryptographic
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This uses the global rand. Concurrent calls across MaxDownloadThreads goroutines will contend on the global rand mutex. Probably
immaterial at typical MaxDownloadThreads values, but for many-thousands-concurrent-downloads workloads (which is the stated motivation), worth swapping in a per-goroutine rand.Rand or math/rand/v2's Uint64N.

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Won't fix — at MaxDownloadThreads=10 (default) the global rand mutex is held for sub-microsecond Int63n draws, and only once per backoff event (not per request). Per-goroutine rand.Rand is non-trivial code for no measurable gain. Happy to revisit if a future config tunes MaxDownloadThreads significantly higher.

}

func getReader(r io.Reader, useLz4Compression bool) io.Reader {
Expand Down
Loading
Loading