Skip to content

Commit 8a708aa

Browse files
committed
support custom auth
1 parent 1be35a6 commit 8a708aa

File tree

4 files changed

+97
-61
lines changed

4 files changed

+97
-61
lines changed

src/flareio/api_client.py

Lines changed: 17 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -1,28 +1,30 @@
11
import os
22

3-
from datetime import datetime
43
from datetime import timedelta
54
from http.cookiejar import DefaultCookiePolicy
65
from urllib.parse import urljoin
76
from urllib.parse import urlparse
87

98
import requests
9+
from requests.auth import AuthBase
1010

1111
from requests.adapters import HTTPAdapter
1212
from urllib3.util import Retry
1313

1414
import typing as t
1515

16-
from flareio.exceptions import TokenError
16+
from flareio.auth import _FlareApiKeyAuth
1717
from flareio.models import ScrollEventsResult
1818
from flareio.ratelimit import Limiter
1919
from flareio.version import __version__ as _flareio_version
2020

2121

2222
_API_DOMAIN_DEFAULT: str = "api.flare.io"
23-
_ALLOWED_API_DOMAINS: t.Tuple[str, ...] = (
24-
_API_DOMAIN_DEFAULT,
25-
"api.eu.flare.io",
23+
_ALLOWED_API_DOMAINS: frozenset[str] = frozenset(
24+
{
25+
_API_DOMAIN_DEFAULT,
26+
"api.eu.flare.io",
27+
}
2628
)
2729

2830

@@ -34,7 +36,7 @@ def __init__(
3436
tenant_id: t.Optional[int] = None,
3537
session: t.Optional[requests.Session] = None,
3638
api_domain: t.Optional[str] = None,
37-
_disable_auth: bool = False,
39+
_auth: t.Optional[AuthBase] = None,
3840
_enable_beta_features: bool = False,
3941
) -> None:
4042
if not api_key:
@@ -49,14 +51,16 @@ def __init__(
4951
raise Exception("Custom API domains considered a beta feature.")
5052
self._api_domain: str = api_domain
5153

52-
self._api_key: str = api_key
53-
self._tenant_id: t.Optional[int] = tenant_id
54-
55-
self._api_token: t.Optional[str] = None
56-
self._api_token_exp: t.Optional[datetime] = None
57-
self._disable_auth: bool = _disable_auth
5854
self._session = session or self._create_session()
5955

56+
_auth = _auth or _FlareApiKeyAuth(
57+
api_key=api_key,
58+
api_domain=self._api_domain,
59+
tenant_id=tenant_id,
60+
session=self._session,
61+
)
62+
self._session.auth = _auth
63+
6064
@classmethod
6165
def from_env(cls) -> "FlareApiClient":
6266
api_key: t.Optional[str] = os.environ.get("FLARE_API_KEY")
@@ -110,41 +114,7 @@ def _create_retry() -> Retry:
110114
return retry
111115

112116
def generate_token(self) -> str:
113-
payload: t.Optional[dict] = None
114-
115-
if self._tenant_id is not None:
116-
payload = {
117-
"tenant_id": self._tenant_id,
118-
}
119-
120-
resp = self._session.post(
121-
f"https://{self._api_domain}/tokens/generate",
122-
json=payload,
123-
headers={
124-
"Authorization": self._api_key,
125-
},
126-
)
127-
try:
128-
resp.raise_for_status()
129-
except Exception as ex:
130-
raise TokenError("Failed to fetch API Token") from ex
131-
token: str = resp.json()["token"]
132-
133-
self._api_token = token
134-
self._api_token_exp = datetime.now() + timedelta(minutes=45)
135-
136-
return token
137-
138-
def _auth_headers(self) -> dict:
139-
if self._disable_auth:
140-
return dict()
141-
api_token: t.Optional[str] = self._api_token
142-
if not api_token or (
143-
self._api_token_exp and self._api_token_exp < datetime.now()
144-
):
145-
api_token = self.generate_token()
146-
147-
return {"Authorization": f"Bearer {api_token}"}
117+
return self._auth.generate_token()
148118

149119
def _request(
150120
self,
@@ -163,11 +133,6 @@ def _request(
163133
f"Client was used to access {netloc=} at {url=}. Only the domain {self._api_domain} is supported."
164134
)
165135

166-
headers = {
167-
**(headers or {}),
168-
**self._auth_headers(),
169-
}
170-
171136
return self._session.request(
172137
method=method,
173138
url=url,

src/flareio/auth.py

Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,71 @@
1+
from datetime import datetime
2+
from datetime import timedelta
3+
4+
import requests
5+
6+
from requests.auth import AuthBase
7+
8+
import typing as t
9+
10+
from flareio.exceptions import TokenError
11+
12+
13+
class _FlareApiKeyAuth(AuthBase):
14+
def __init__(
15+
self,
16+
*,
17+
api_key: str,
18+
api_domain: str,
19+
tenant_id: t.Optional[int] = None,
20+
session: requests.Session,
21+
) -> None:
22+
self._api_key: str = api_key
23+
self._api_domain: str = api_domain
24+
self._tenant_id: t.Optional[int] = tenant_id
25+
self._session: requests.Session = session
26+
27+
self._api_token: t.Optional[str] = None
28+
self._api_token_exp: t.Optional[datetime] = None
29+
30+
def generate_token(self) -> str:
31+
payload: t.Optional[dict] = None
32+
33+
if self._tenant_id is not None:
34+
payload = {
35+
"tenant_id": self._tenant_id,
36+
}
37+
38+
resp = self._session.post(
39+
f"https://{self._api_domain}/tokens/generate",
40+
json=payload,
41+
headers={
42+
"Authorization": self._api_key,
43+
},
44+
)
45+
try:
46+
resp.raise_for_status()
47+
except Exception as ex:
48+
raise TokenError("Failed to fetch API Token") from ex
49+
token: str = resp.json()["token"]
50+
51+
self._api_token = token
52+
self._api_token_exp = datetime.now() + timedelta(minutes=45)
53+
54+
return token
55+
56+
def __call__(
57+
self,
58+
r: requests.PreparedRequest,
59+
) -> requests.PreparedRequest:
60+
# Token generation uses API key auth, don't override it.
61+
if "Authorization" in r.headers:
62+
return r
63+
64+
# Lazy token refresh.
65+
if not self._api_token or (
66+
self._api_token_exp and self._api_token_exp < datetime.now()
67+
):
68+
self.generate_token()
69+
70+
r.headers["Authorization"] = f"Bearer {self._api_token}"
71+
return r

tests/test_api_client_creation.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -38,8 +38,8 @@ def test_create_client_empty_api_key() -> None:
3838

3939
def test_generate_token() -> None:
4040
client = get_test_client(authenticated=False)
41-
assert client._api_token is None
42-
assert client._api_token_exp is None
41+
assert client._auth._api_token is None
42+
assert client._auth._api_token_exp is None
4343
with requests_mock.Mocker() as mocker:
4444
mocker.register_uri(
4545
"POST",
@@ -53,9 +53,9 @@ def test_generate_token() -> None:
5353
token = client.generate_token()
5454
assert token == "test-token-hello"
5555

56-
assert client._api_token == "test-token-hello"
57-
assert client._api_token_exp
58-
assert client._api_token_exp >= datetime.now()
56+
assert client._auth._api_token == "test-token-hello"
57+
assert client._auth._api_token_exp
58+
assert client._auth._api_token_exp >= datetime.now()
5959

6060
assert mocker.last_request.url == "https://api.flare.io/tokens/generate"
6161
assert mocker.last_request.text is None
@@ -64,8 +64,8 @@ def test_generate_token() -> None:
6464

6565
def test_generate_token_error() -> None:
6666
client = get_test_client(authenticated=False)
67-
assert client._api_token is None
68-
assert client._api_token_exp is None
67+
assert client._auth._api_token is None
68+
assert client._auth._api_token_exp is None
6969

7070
with requests_mock.Mocker() as mocker:
7171
mocker.register_uri(

tests/test_api_client_endpoints.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,8 @@
66

77
def test_wrapped_methods() -> None:
88
client = get_test_client(authenticated=False)
9-
assert client._api_token is None
10-
assert client._api_token_exp is None
9+
assert client._auth._api_token is None
10+
assert client._auth._api_token_exp is None
1111

1212
# POST: This one will generate since its the first one.
1313
with requests_mock.Mocker() as mocker:

0 commit comments

Comments
 (0)