Skip to content

Commit 9fc330c

Browse files
authored
RateLimitRoundtripper: Fix mutex leak and not respecting context cancellation (#2298)
- Make sure that all exit paths from functions correctly release the mutex. - Add a sleep function that respects context cancellation, and use it in-place of time.Sleep.
1 parent 68ee84d commit 9fc330c

File tree

2 files changed

+52
-3
lines changed

2 files changed

+52
-3
lines changed

github/transport.go

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ package github
22

33
import (
44
"bytes"
5+
"context"
56
"errors"
67
"io"
78
"log"
@@ -66,7 +67,7 @@ func (rlt *RateLimitTransport) RoundTrip(req *http.Request) (*http.Response, err
6667
// for read and write requests. See isWriteMethod for the distinction between them.
6768
if rlt.nextRequestDelay > 0 {
6869
log.Printf("[DEBUG] Sleeping %s between operations", rlt.nextRequestDelay)
69-
time.Sleep(rlt.nextRequestDelay)
70+
sleep(req.Context(), rlt.nextRequestDelay)
7071
}
7172

7273
rlt.nextRequestDelay = rlt.calculateNextDelay(req.Method)
@@ -82,6 +83,7 @@ func (rlt *RateLimitTransport) RoundTrip(req *http.Request) (*http.Response, err
8283
// See https://github.com/google/go-github/pull/986
8384
r1, r2, err := drainBody(resp.Body)
8485
if err != nil {
86+
rlt.smartLock(false)
8587
return nil, err
8688
}
8789
resp.Body = r1
@@ -95,7 +97,7 @@ func (rlt *RateLimitTransport) RoundTrip(req *http.Request) (*http.Response, err
9597
retryAfter := arlErr.GetRetryAfter()
9698
log.Printf("[WARN] Abuse detection mechanism triggered, sleeping for %s before retrying",
9799
retryAfter)
98-
time.Sleep(retryAfter)
100+
sleep(req.Context(), retryAfter)
99101
rlt.smartLock(false)
100102
return rlt.RoundTrip(req)
101103
}
@@ -106,7 +108,7 @@ func (rlt *RateLimitTransport) RoundTrip(req *http.Request) (*http.Response, err
106108
retryAfter := time.Until(rlErr.Rate.Reset.Time)
107109
log.Printf("[WARN] Rate limit %d reached, sleeping for %s (until %s) before retrying",
108110
rlErr.Rate.Limit, retryAfter, time.Now().Add(retryAfter))
109-
time.Sleep(retryAfter)
111+
sleep(req.Context(), retryAfter)
110112
rlt.smartLock(false)
111113
return rlt.RoundTrip(req)
112114
}
@@ -116,6 +118,17 @@ func (rlt *RateLimitTransport) RoundTrip(req *http.Request) (*http.Response, err
116118
return resp, nil
117119
}
118120

121+
// sleep is used an alternative to time.Sleep that supports cancellation via the passed context.Context.
122+
func sleep(ctx context.Context, dur time.Duration) {
123+
t := time.NewTimer(dur)
124+
defer t.Stop()
125+
126+
select {
127+
case <-t.C:
128+
case <-ctx.Done():
129+
}
130+
}
131+
119132
// smartLock wraps the mutex locking system and performs its operation via a boolean input for locking and unlocking.
120133
// It also skips the locking when parallelRequests is set to true since, in this case, the lock is not needed.
121134
func (rlt *RateLimitTransport) smartLock(lock bool) {

github/transport_test.go

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -160,6 +160,42 @@ func TestRateLimitTransport_abuseLimit_get(t *testing.T) {
160160
}
161161
}
162162

163+
func TestRateLimitTransport_abuseLimit_get_cancelled(t *testing.T) {
164+
ts := githubApiMock([]*mockResponse{
165+
{
166+
ExpectedUri: "/repos/test/blah",
167+
ResponseBody: `{
168+
"message": "You have triggered an abuse detection mechanism and have been temporarily blocked from content creation. Please retry your request again later.",
169+
"documentation_url": "https://developer.github.com/v3/#abuse-rate-limits"
170+
}`,
171+
StatusCode: 403,
172+
ResponseHeaders: map[string]string{
173+
"Retry-After": "10",
174+
},
175+
},
176+
})
177+
defer ts.Close()
178+
179+
httpClient := http.DefaultClient
180+
httpClient.Transport = NewRateLimitTransport(http.DefaultTransport)
181+
182+
client := github.NewClient(httpClient)
183+
u, _ := url.Parse(ts.URL + "/")
184+
client.BaseURL = u
185+
186+
ctx, cancel := context.WithTimeout(t.Context(), 100*time.Millisecond)
187+
defer cancel()
188+
189+
start := time.Now()
190+
_, _, err := client.Repositories.Get(ctx, "test", "blah")
191+
if !errors.Is(err, context.DeadlineExceeded) {
192+
t.Fatalf("Expected context deadline exceeded, got: %v", err)
193+
}
194+
if time.Since(start) > time.Second {
195+
t.Fatalf("Waited for longer than expected: %s", time.Since(start))
196+
}
197+
}
198+
163199
func TestRateLimitTransport_abuseLimit_post(t *testing.T) {
164200
ts := githubApiMock([]*mockResponse{
165201
{

0 commit comments

Comments
 (0)