diff --git a/src/requests/models.py b/src/requests/models.py index 59b5615960..9f25d8c939 100644 --- a/src/requests/models.py +++ b/src/requests/models.py @@ -678,11 +678,12 @@ def prepare_auth( auth = url_auth if any(url_auth) else None if auth: - if isinstance(auth, tuple) and len(auth) == 2: # type: ignore[arg-type] # pyright widens tuple from Callable in AuthType + if callable(auth): + auth_handler = cast("Callable[..., PreparedRequest]", auth) + elif isinstance(auth, tuple) and len(auth) == 2: # type: ignore[arg-type] # pyright widens tuple from Callable in AuthType # special-case basic HTTP auth auth_handler = HTTPBasicAuth(*auth) # type: ignore[arg-type] # pyright widens tuple from Callable in AuthType else: - # TODO: can be fixed by flipping the conditionals auth_handler = cast("Callable[..., PreparedRequest]", auth) # Allow auth to make its changes. diff --git a/tests/test_requests.py b/tests/test_requests.py index 571535fe79..c46ff5da39 100644 --- a/tests/test_requests.py +++ b/tests/test_requests.py @@ -2124,6 +2124,34 @@ def test_basic_auth_str_is_always_native(self, username, password, auth_str): assert isinstance(s, builtin_str) assert s == auth_str + def test_callable_namedtuple_auth_uses_call_not_tuple_branch(self): + """A callable that is also a 2-tuple subclass (e.g. namedtuple) must + have its __call__ invoked, not be silently downgraded to HTTPBasicAuth. + + Regression: prepare_auth previously checked isinstance(auth, tuple) + before callable(auth), so any AuthBase subclass that also inherited + from a 2-field namedtuple would have its __call__ bypassed and its + fields extracted as Basic Auth credentials instead. + """ + from collections import namedtuple + from requests.auth import AuthBase + + _Base = namedtuple("_TokenAuth", ["token", "scheme"]) + + class TokenAuth(_Base, AuthBase): + def __call__(self, r): + r.headers["Authorization"] = f"{self.scheme} {self.token}" + return r + + auth = TokenAuth(token="my-secret", scheme="Bearer") + assert isinstance(auth, tuple) # confirm the ambiguous case + assert callable(auth) + + p = requests.Request("GET", "http://example.com", auth=auth).prepare() + + assert p.headers["Authorization"] == "Bearer my-secret" + assert not p.headers["Authorization"].startswith("Basic") + def test_requests_history_is_saved(self, httpbin): r = requests.get(httpbin("redirect/5")) total = r.history[-1].history