Skip to content

Commit b79e020

Browse files
committed
Optimize digest auth
* Use transport middleware instead of client middleware to compatible with response auto unmarshal (SetSuccessResult / SetErrorResult). * Support cache, no need to re-authenticate the same site. * Deprecate request level digest auth. Signed-off-by: roc <roc@imroc.cc>
1 parent 82648f7 commit b79e020

3 files changed

Lines changed: 172 additions & 5 deletions

File tree

client.go

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ type Client struct {
4949
DebugLog bool
5050
AllowGetMethodPayload bool
5151
*Transport
52-
52+
digestAuth *digestAuth
5353
cookiejarFactory func() *cookiejar.Jar
5454
trace bool
5555
disableAutoReadResponse bool
@@ -842,10 +842,14 @@ func (c *Client) SetCommonBasicAuth(username, password string) *Client {
842842
// Information about Digest Access Authentication can be found in RFC7616:
843843
//
844844
// https://datatracker.ietf.org/doc/html/rfc7616
845-
//
846-
// See `Request.SetDigestAuth`
847845
func (c *Client) SetCommonDigestAuth(username, password string) *Client {
848-
c.OnAfterResponse(handleDigestAuthFunc(username, password))
846+
c.digestAuth = &digestAuth{
847+
Username: username,
848+
Password: password,
849+
HttpClient: c.httpClient,
850+
cache: make(map[string]*cchal),
851+
}
852+
c.Transport.WrapRoundTripFunc(c.digestAuth.HttpRoundTripWrapperFunc)
849853
return c
850854
}
851855

digest.go

Lines changed: 163 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,140 @@
11
package req
22

33
import (
4+
"bytes"
5+
"io"
46
"net/http"
7+
"sync"
58

69
"github.com/icholy/digest"
710
"github.com/imroc/req/v3/internal/header"
811
)
912

13+
// cchal is a cached challenge and the number of times it's been used.
14+
type cchal struct {
15+
c *digest.Challenge
16+
n int
17+
}
18+
19+
type digestAuth struct {
20+
Username string
21+
Password string
22+
HttpClient *http.Client
23+
cache map[string]*cchal
24+
cacheMu sync.Mutex
25+
}
26+
27+
func (da *digestAuth) digest(req *http.Request, chal *digest.Challenge, count int) (*digest.Credentials, error) {
28+
opt := digest.Options{
29+
Method: req.Method,
30+
URI: req.URL.RequestURI(),
31+
GetBody: req.GetBody,
32+
Count: count,
33+
Username: da.Username,
34+
Password: da.Password,
35+
}
36+
return digest.Digest(chal, opt)
37+
}
38+
39+
// challenge returns a cached challenge and count for the provided request
40+
func (da *digestAuth) challenge(req *http.Request) (*digest.Challenge, int, bool) {
41+
da.cacheMu.Lock()
42+
defer da.cacheMu.Unlock()
43+
host := req.URL.Hostname()
44+
cc, ok := da.cache[host]
45+
if !ok {
46+
return nil, 0, false
47+
}
48+
cc.n++
49+
return cc.c, cc.n, true
50+
}
51+
52+
// prepare attempts to find a cached challenge that matches the
53+
// requested domain, and use it to set the Authorization header
54+
func (da *digestAuth) prepare(req *http.Request) error {
55+
// add cookies
56+
if da.HttpClient.Jar != nil {
57+
for _, cookie := range da.HttpClient.Jar.Cookies(req.URL) {
58+
req.AddCookie(cookie)
59+
}
60+
}
61+
// add auth
62+
chal, count, ok := da.challenge(req)
63+
if !ok {
64+
return nil
65+
}
66+
cred, err := da.digest(req, chal, count)
67+
if err != nil {
68+
return err
69+
}
70+
if cred != nil {
71+
req.Header.Set("Authorization", cred.String())
72+
}
73+
return nil
74+
}
75+
76+
func (da *digestAuth) HttpRoundTripWrapperFunc(rt http.RoundTripper) HttpRoundTripFunc {
77+
return func(req *http.Request) (resp *http.Response, err error) {
78+
clone, err := cloner(req)
79+
if err != nil {
80+
return nil, err
81+
}
82+
83+
// make a copy of the request
84+
first, err := clone()
85+
if err != nil {
86+
return nil, err
87+
}
88+
89+
// prepare the first request using a cached challenge
90+
if err := da.prepare(first); err != nil {
91+
return nil, err
92+
}
93+
94+
// the first request will either succeed or return a 401
95+
res, err := rt.RoundTrip(first)
96+
if err != nil || res.StatusCode != http.StatusUnauthorized {
97+
return res, err
98+
}
99+
100+
// drain and close the first message body
101+
_, _ = io.Copy(io.Discard, res.Body)
102+
_ = res.Body.Close()
103+
104+
// find and cache the challenge
105+
host := req.URL.Hostname()
106+
chal, err := digest.FindChallenge(res.Header)
107+
if err != nil {
108+
// existing cached challenge didn't work, so remove it
109+
da.cacheMu.Lock()
110+
delete(da.cache, host)
111+
da.cacheMu.Unlock()
112+
if err == digest.ErrNoChallenge {
113+
return res, nil
114+
}
115+
return nil, err
116+
} else {
117+
// found new challenge, so cache it
118+
da.cacheMu.Lock()
119+
da.cache[host] = &cchal{c: chal}
120+
da.cacheMu.Unlock()
121+
}
122+
123+
// make a second copy of the request
124+
second, err := clone()
125+
if err != nil {
126+
return nil, err
127+
}
128+
129+
// prepare the second request based on the new challenge
130+
if err := da.prepare(second); err != nil {
131+
return nil, err
132+
}
133+
134+
return rt.RoundTrip(second)
135+
}
136+
}
137+
10138
// create response middleware for http digest authentication.
11139
func handleDigestAuthFunc(username, password string) ResponseMiddleware {
12140
return func(client *Client, resp *Response) error {
@@ -60,3 +188,38 @@ func createDigestAuth(req *http.Request, resp *http.Response, username, password
60188
}
61189
return cred.String(), nil
62190
}
191+
192+
// cloner returns a function which makes clones of the provided request
193+
func cloner(req *http.Request) (func() (*http.Request, error), error) {
194+
getbody := req.GetBody
195+
// if there's no GetBody function set we have to copy the body
196+
// into memory to use for future clones
197+
if getbody == nil {
198+
if req.Body == nil || req.Body == http.NoBody {
199+
getbody = func() (io.ReadCloser, error) {
200+
return http.NoBody, nil
201+
}
202+
} else {
203+
body, err := io.ReadAll(req.Body)
204+
if err != nil {
205+
return nil, err
206+
}
207+
if err := req.Body.Close(); err != nil {
208+
return nil, err
209+
}
210+
getbody = func() (io.ReadCloser, error) {
211+
return io.NopCloser(bytes.NewReader(body)), nil
212+
}
213+
}
214+
}
215+
return func() (*http.Request, error) {
216+
clone := req.Clone(req.Context())
217+
body, err := getbody()
218+
if err != nil {
219+
return nil, err
220+
}
221+
clone.Body = body
222+
clone.GetBody = getbody
223+
return clone, nil
224+
}, nil
225+
}

request.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -431,7 +431,7 @@ func (r *Request) SetBasicAuth(username, password string) *Request {
431431
//
432432
// https://datatracker.ietf.org/doc/html/rfc7616
433433
//
434-
// This method overrides the username and password set by method `Client.SetCommonDigestAuth`.
434+
// Deprecated: Use Client.SetCommonDigestAuth instead. Request level digest auth is not recommended,
435435
func (r *Request) SetDigestAuth(username, password string) *Request {
436436
r.OnAfterResponse(handleDigestAuthFunc(username, password))
437437
return r

0 commit comments

Comments
 (0)