Skip to content

Commit db8ce96

Browse files
committed
add customizable auth
1 parent 26f0845 commit db8ce96

File tree

4 files changed

+55
-15
lines changed

4 files changed

+55
-15
lines changed

src/flareio/api_client.py

Lines changed: 22 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
import requests
1010

1111
from requests.adapters import HTTPAdapter
12+
from requests.auth import AuthBase
1213
from urllib3.util import Retry
1314

1415
import typing as t
@@ -34,7 +35,7 @@ def __init__(
3435
tenant_id: t.Optional[int] = None,
3536
session: t.Optional[requests.Session] = None,
3637
api_domain: t.Optional[str] = None,
37-
_disable_auth: bool = False,
38+
_auth: AuthBase | None = None,
3839
_enable_beta_features: bool = False,
3940
) -> None:
4041
if not api_key:
@@ -52,9 +53,9 @@ def __init__(
5253
self._api_key: str = api_key
5354
self._tenant_id: t.Optional[int] = tenant_id
5455

56+
self._auth: t.Optional[AuthBase] = _auth
5557
self._api_token: t.Optional[str] = None
5658
self._api_token_exp: t.Optional[datetime] = None
57-
self._disable_auth: bool = _disable_auth
5859
self._session = session or self._create_session()
5960

6061
@classmethod
@@ -135,16 +136,24 @@ def generate_token(self) -> str:
135136

136137
return token
137138

138-
def _auth_headers(self) -> dict:
139-
if self._disable_auth:
140-
return dict()
139+
def _apply_auth(
140+
self,
141+
*,
142+
request: requests.PreparedRequest,
143+
) -> requests.PreparedRequest:
144+
if self._auth:
145+
self._auth(request)
146+
return request
147+
141148
api_token: t.Optional[str] = self._api_token
142149
if not api_token or (
143150
self._api_token_exp and self._api_token_exp < datetime.now()
144151
):
145152
api_token = self.generate_token()
146153

147-
return {"Authorization": f"Bearer {api_token}"}
154+
request.headers["Authorization"] = f"Bearer {api_token}"
155+
156+
return request
148157

149158
def _request(
150159
self,
@@ -163,19 +172,20 @@ def _request(
163172
f"Client was used to access {netloc=} at {url=}. Only the domain {self._api_domain} is supported."
164173
)
165174

166-
headers = {
167-
**(headers or {}),
168-
**self._auth_headers(),
169-
}
170-
171-
return self._session.request(
175+
request = requests.Request(
172176
method=method,
173177
url=url,
174178
params=params,
175179
json=json,
176180
headers=headers,
177181
)
178182

183+
prepared = self._session.prepare_request(request)
184+
prepared = self._apply_auth(request=prepared)
185+
resp = self._session.send(prepared)
186+
187+
return resp
188+
179189
def post(
180190
self,
181191
url: str,

src/flareio/auth.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
from requests import PreparedRequest
2+
from requests.auth import AuthBase
3+
4+
5+
class _StaticHeadersAuth(AuthBase):
6+
def __init__(
7+
self,
8+
*,
9+
headers: dict[str, str],
10+
) -> None:
11+
self._headers: dict[str, str] = headers
12+
13+
def __call__(
14+
self,
15+
r: PreparedRequest,
16+
) -> PreparedRequest:
17+
r.headers.update(self._headers)
18+
return r
19+
20+
21+
class _EmptyAuth(AuthBase):
22+
def __call__(
23+
self,
24+
r: PreparedRequest,
25+
) -> PreparedRequest:
26+
return r

tests/test_api_client_endpoints.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@
33

44
from .utils import get_test_client
55

6+
from flareio.auth import _EmptyAuth
7+
68

79
def test_wrapped_methods() -> None:
810
client = get_test_client(authenticated=False)
@@ -126,7 +128,7 @@ def test_bad_domain() -> None:
126128
def test_disable_auth_does_not_call_generate() -> None:
127129
client = get_test_client(
128130
authenticated=False,
129-
_disable_auth=True,
131+
_auth=_EmptyAuth(),
130132
)
131133
with requests_mock.Mocker() as mocker:
132134
mocker.register_uri(

tests/utils.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
import requests_mock
22

3+
from requests.auth import AuthBase
4+
35
import typing as t
46

57
from flareio import FlareApiClient
@@ -11,14 +13,14 @@ def get_test_client(
1113
authenticated: bool = True,
1214
api_domain: t.Optional[str] = None,
1315
_enable_beta_features: bool = False,
14-
_disable_auth: bool = False,
16+
_auth: t.Optional[AuthBase] = None,
1517
) -> FlareApiClient:
1618
client = FlareApiClient(
1719
api_key="test-api-key",
1820
tenant_id=tenant_id,
1921
api_domain=api_domain,
2022
_enable_beta_features=_enable_beta_features,
21-
_disable_auth=_disable_auth,
23+
_auth=_auth,
2224
)
2325

2426
if authenticated:

0 commit comments

Comments
 (0)