Skip to content

Commit 763af72

Browse files
SK-2777: update unit tests
1 parent 599bcf5 commit 763af72

2 files changed

Lines changed: 275 additions & 53 deletions

File tree

skyflow/vault/client/client.py

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -40,15 +40,13 @@ def initialize_client_configuration(self):
4040
self.__config.get("vault_id"),
4141
logger=self.__logger)
4242
self.__is_static_token = 'token' in self.__credentials or 'api_key' in self.__credentials
43-
token = self.get_bearer_token(self.__credentials)
43+
bearer_token = self.get_bearer_token(self.__credentials)
4444
if needs_reinit:
45-
self.initialize_api_client(self.__vault_url, token)
45+
self.initialize_api_client(self.__vault_url, bearer_token)
4646

47-
def initialize_api_client(self, vault_url, token):
48-
self.__api_client = Skyflow(
49-
base_url=vault_url,
50-
token=lambda: self.__bearer_token if self.__bearer_token else token,
51-
)
47+
def initialize_api_client(self, vault_url, bearer_token):
48+
token_provider = lambda: self.__bearer_token if self.__bearer_token else bearer_token # noqa: E731
49+
self.__api_client = Skyflow(base_url=vault_url, token=token_provider)
5250

5351
def get_records_api(self):
5452
return self.__api_client.records

tests/vault/client/test__client.py

Lines changed: 270 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import unittest
2-
from unittest.mock import patch, MagicMock
2+
from unittest.mock import patch, MagicMock, call
33
from skyflow.vault.client.client import VaultClient
44

55
CONFIG = {
@@ -12,11 +12,19 @@
1212
}
1313

1414
CREDENTIALS_WITH_API_KEY = {"api_key": "dummy_api_key"}
15+
CREDENTIALS_WITH_TOKEN = {"token": "dummy_static_token"}
16+
CREDENTIALS_WITH_PATH = {"path": "/some/path/credentials.json"}
17+
CREDENTIALS_WITH_STRING = {"credentials_string": '{"clientID": "x"}'}
18+
1519

1620
class TestVaultClient(unittest.TestCase):
1721
def setUp(self):
1822
self.vault_client = VaultClient(CONFIG)
1923

24+
# ------------------------------------------------------------------ #
25+
# Basic setters / getters #
26+
# ------------------------------------------------------------------ #
27+
2028
def test_set_common_skyflow_credentials(self):
2129
credentials = {"api_key": "dummy_api_key"}
2230
self.vault_client.set_common_skyflow_credentials(credentials)
@@ -28,73 +36,289 @@ def test_set_logger(self):
2836
self.assertEqual(self.vault_client.get_log_level(), "INFO")
2937
self.assertEqual(self.vault_client.get_logger(), mock_logger)
3038

39+
def test_get_vault_id(self):
40+
self.assertEqual(self.vault_client.get_vault_id(), CONFIG["vault_id"])
41+
42+
def test_get_config(self):
43+
self.assertEqual(self.vault_client.get_config(), CONFIG)
44+
45+
def test_get_common_skyflow_credentials(self):
46+
credentials = {"api_key": "dummy_api_key"}
47+
self.vault_client.set_common_skyflow_credentials(credentials)
48+
self.assertEqual(self.vault_client.get_common_skyflow_credentials(), credentials)
49+
50+
def test_get_log_level(self):
51+
self.vault_client.set_logger("DEBUG", MagicMock())
52+
self.assertEqual(self.vault_client.get_log_level(), "DEBUG")
53+
54+
def test_get_logger(self):
55+
mock_logger = MagicMock()
56+
self.vault_client.set_logger("INFO", mock_logger)
57+
self.assertEqual(self.vault_client.get_logger(), mock_logger)
58+
59+
# ------------------------------------------------------------------ #
60+
# initialize_client_configuration — first call (slow path) #
61+
# ------------------------------------------------------------------ #
62+
3163
@patch("skyflow.vault.client.client.get_credentials")
3264
@patch("skyflow.vault.client.client.get_vault_url")
3365
@patch("skyflow.vault.client.client.VaultClient.initialize_api_client")
34-
def test_initialize_client_configuration(self, mock_init_api_client, mock_get_vault_url, mock_get_credentials):
35-
mock_get_credentials.return_value = (CREDENTIALS_WITH_API_KEY)
66+
def test_initialize_client_configuration_first_call(
67+
self, mock_init_api_client, mock_get_vault_url, mock_get_credentials
68+
):
69+
mock_get_credentials.return_value = CREDENTIALS_WITH_API_KEY
3670
mock_get_vault_url.return_value = "https://test-vault-url.com"
3771

3872
self.vault_client.initialize_client_configuration()
3973

40-
mock_get_credentials.assert_called_once_with(CONFIG["credentials"], None, logger=None)
41-
mock_get_vault_url.assert_called_once_with(CONFIG["cluster_id"], CONFIG["env"], CONFIG["vault_id"], logger=None)
74+
mock_get_credentials.assert_called_once_with(
75+
CONFIG["credentials"], None, logger=None
76+
)
77+
mock_get_vault_url.assert_called_once_with(
78+
CONFIG["cluster_id"], CONFIG["env"], CONFIG["vault_id"], logger=None
79+
)
4280
mock_init_api_client.assert_called_once()
4381

44-
@patch("skyflow.vault.client.client.Skyflow")
45-
def test_initialize_api_client(self, mock_api_client):
46-
self.vault_client.initialize_api_client("https://test-vault-url.com", "dummy_token")
47-
mock_api_client.assert_called_once_with(base_url="https://test-vault-url.com", token="dummy_token")
82+
# ------------------------------------------------------------------ #
83+
# initialize_client_configuration — fast path (static token) #
84+
# ------------------------------------------------------------------ #
4885

49-
def test_get_records_api(self):
86+
@patch("skyflow.vault.client.client.get_credentials")
87+
@patch("skyflow.vault.client.client.get_vault_url")
88+
@patch("skyflow.vault.client.client.VaultClient.initialize_api_client")
89+
def test_initialize_client_configuration_fast_path_api_key(
90+
self, mock_init_api_client, mock_get_vault_url, mock_get_credentials
91+
):
92+
"""Once initialized with api_key, subsequent calls skip all work."""
93+
mock_get_credentials.return_value = CREDENTIALS_WITH_API_KEY
94+
mock_get_vault_url.return_value = "https://test-vault-url.com"
95+
# Side-effect simulates initialize_api_client actually setting __api_client
96+
mock_init_api_client.side_effect = lambda *_: setattr(
97+
self.vault_client, "_VaultClient__api_client", MagicMock()
98+
)
99+
100+
self.vault_client.initialize_client_configuration() # first call — slow path
101+
mock_get_credentials.reset_mock()
102+
mock_get_vault_url.reset_mock()
103+
mock_init_api_client.reset_mock()
104+
105+
self.vault_client.initialize_client_configuration() # second call — fast path
106+
107+
mock_get_credentials.assert_not_called()
108+
mock_get_vault_url.assert_not_called()
109+
mock_init_api_client.assert_not_called()
110+
111+
@patch("skyflow.vault.client.client.get_credentials")
112+
@patch("skyflow.vault.client.client.get_vault_url")
113+
@patch("skyflow.vault.client.client.VaultClient.initialize_api_client")
114+
def test_initialize_client_configuration_fast_path_static_token(
115+
self, mock_init_api_client, mock_get_vault_url, mock_get_credentials
116+
):
117+
"""Once initialized with a static token, subsequent calls skip all work."""
118+
mock_get_credentials.return_value = CREDENTIALS_WITH_TOKEN
119+
mock_get_vault_url.return_value = "https://test-vault-url.com"
120+
mock_init_api_client.side_effect = lambda *_: setattr(
121+
self.vault_client, "_VaultClient__api_client", MagicMock()
122+
)
123+
124+
self.vault_client.initialize_client_configuration()
125+
mock_get_credentials.reset_mock()
126+
mock_get_vault_url.reset_mock()
127+
mock_init_api_client.reset_mock()
128+
129+
self.vault_client.initialize_client_configuration()
130+
131+
mock_get_credentials.assert_not_called()
132+
mock_get_vault_url.assert_not_called()
133+
mock_init_api_client.assert_not_called()
134+
135+
# ------------------------------------------------------------------ #
136+
# initialize_client_configuration — fast path (service account) #
137+
# ------------------------------------------------------------------ #
138+
139+
@patch("skyflow.vault.client.client.is_expired", return_value=False)
140+
@patch("skyflow.vault.client.client.get_credentials")
141+
@patch("skyflow.vault.client.client.get_vault_url")
142+
@patch("skyflow.vault.client.client.VaultClient.initialize_api_client")
143+
def test_initialize_client_configuration_fast_path_valid_sa_token(
144+
self, mock_init_api_client, mock_get_vault_url, mock_get_credentials, mock_is_expired
145+
):
146+
"""Service account with a still-valid token skips get_bearer_token entirely."""
147+
mock_get_credentials.return_value = CREDENTIALS_WITH_PATH
148+
mock_get_vault_url.return_value = "https://test-vault-url.com"
149+
150+
# Seed the cached bearer token as if first call already ran
50151
self.vault_client._VaultClient__api_client = MagicMock()
51-
self.vault_client._VaultClient__api_client.records = MagicMock()
52-
records_api = self.vault_client.get_records_api()
53-
self.assertIsNotNone(records_api)
152+
self.vault_client._VaultClient__is_static_token = False
153+
self.vault_client._VaultClient__bearer_token = "cached_sa_token"
154+
self.vault_client._VaultClient__credentials = CREDENTIALS_WITH_PATH
54155

55-
def test_get_tokens_api(self):
156+
self.vault_client.initialize_client_configuration()
157+
158+
mock_get_credentials.assert_not_called()
159+
mock_get_vault_url.assert_not_called()
160+
mock_init_api_client.assert_not_called()
161+
162+
# ------------------------------------------------------------------ #
163+
# initialize_client_configuration — token expiry (no client reinit) #
164+
# ------------------------------------------------------------------ #
165+
166+
@patch("skyflow.vault.client.client.generate_bearer_token", return_value=("new_sa_token", None))
167+
@patch("skyflow.vault.client.client.is_expired", return_value=True)
168+
@patch("skyflow.vault.client.client.get_credentials")
169+
@patch("skyflow.vault.client.client.get_vault_url")
170+
@patch("skyflow.vault.client.client.VaultClient.initialize_api_client")
171+
def test_initialize_client_configuration_expired_token_no_reinit(
172+
self, mock_init_api_client, mock_get_vault_url, mock_get_credentials,
173+
mock_is_expired, mock_generate_bearer_token
174+
):
175+
"""Expired service account token is regenerated in-place; httpx client is NOT recreated."""
176+
mock_get_credentials.return_value = CREDENTIALS_WITH_PATH
177+
mock_get_vault_url.return_value = "https://test-vault-url.com"
178+
179+
# Client already initialized — simulate warm state with an expired token
56180
self.vault_client._VaultClient__api_client = MagicMock()
57-
self.vault_client._VaultClient__api_client.tokens = MagicMock()
58-
tokens_api = self.vault_client.get_tokens_api()
59-
self.assertIsNotNone(tokens_api)
181+
self.vault_client._VaultClient__is_static_token = False
182+
self.vault_client._VaultClient__bearer_token = "expired_sa_token"
183+
self.vault_client._VaultClient__credentials = CREDENTIALS_WITH_PATH
60184

61-
def test_get_query_api(self):
185+
self.vault_client.initialize_client_configuration()
186+
187+
# Token was regenerated
188+
mock_generate_bearer_token.assert_called_once()
189+
self.assertEqual(
190+
self.vault_client._VaultClient__bearer_token, "new_sa_token"
191+
)
192+
# httpx client was NOT recreated
193+
mock_init_api_client.assert_not_called()
194+
195+
# ------------------------------------------------------------------ #
196+
# initialize_client_configuration — config update forces reinit #
197+
# ------------------------------------------------------------------ #
198+
199+
@patch("skyflow.vault.client.client.get_credentials")
200+
@patch("skyflow.vault.client.client.get_vault_url")
201+
@patch("skyflow.vault.client.client.VaultClient.initialize_api_client")
202+
def test_initialize_client_configuration_reinit_after_update_config(
203+
self, mock_init_api_client, mock_get_vault_url, mock_get_credentials
204+
):
205+
"""update_config() marks the client stale; next call must recreate it."""
206+
mock_get_credentials.return_value = CREDENTIALS_WITH_API_KEY
207+
mock_get_vault_url.return_value = "https://test-vault-url.com"
208+
209+
# Simulate already-initialized client
62210
self.vault_client._VaultClient__api_client = MagicMock()
63-
self.vault_client._VaultClient__api_client.query = MagicMock()
64-
query_api = self.vault_client.get_query_api()
65-
self.assertIsNotNone(query_api)
211+
self.vault_client._VaultClient__is_static_token = True
66212

67-
def test_get_vault_id(self):
68-
self.assertEqual(self.vault_client.get_vault_id(), CONFIG["vault_id"])
213+
self.vault_client.update_config({"cluster_id": "new_cluster"})
214+
self.vault_client.initialize_client_configuration()
215+
216+
mock_get_credentials.assert_called_once()
217+
mock_get_vault_url.assert_called_once()
218+
mock_init_api_client.assert_called_once()
219+
220+
# ------------------------------------------------------------------ #
221+
# initialize_api_client — lambda token provider #
222+
# ------------------------------------------------------------------ #
223+
224+
@patch("skyflow.vault.client.client.Skyflow")
225+
def test_initialize_api_client_passes_callable_token(self, mock_skyflow):
226+
"""initialize_api_client must pass a callable (lambda) as token, not a string."""
227+
self.vault_client.initialize_api_client("https://test-vault-url.com", "initial_token")
228+
229+
args, kwargs = mock_skyflow.call_args
230+
self.assertEqual(kwargs["base_url"], "https://test-vault-url.com")
231+
self.assertTrue(callable(kwargs["token"]), "token must be a callable (lambda)")
232+
233+
@patch("skyflow.vault.client.client.Skyflow")
234+
def test_initialize_api_client_lambda_returns_cached_bearer_token(self, mock_skyflow):
235+
"""Lambda returns __bearer_token when it is set (interceptor behaviour)."""
236+
self.vault_client._VaultClient__bearer_token = "refreshed_token"
237+
self.vault_client.initialize_api_client("https://test-vault-url.com", "initial_token")
238+
239+
_, kwargs = mock_skyflow.call_args
240+
self.assertEqual(kwargs["token"](), "refreshed_token")
241+
242+
@patch("skyflow.vault.client.client.Skyflow")
243+
def test_initialize_api_client_lambda_falls_back_to_initial_token(self, mock_skyflow):
244+
"""Lambda falls back to the initial token when __bearer_token is None."""
245+
self.vault_client._VaultClient__bearer_token = None
246+
self.vault_client.initialize_api_client("https://test-vault-url.com", "initial_token")
247+
248+
_, kwargs = mock_skyflow.call_args
249+
self.assertEqual(kwargs["token"](), "initial_token")
250+
251+
# ------------------------------------------------------------------ #
252+
# get_bearer_token #
253+
# ------------------------------------------------------------------ #
254+
255+
def test_get_bearer_token_with_api_key(self):
256+
result = self.vault_client.get_bearer_token(CREDENTIALS_WITH_API_KEY)
257+
self.assertEqual(result, "dummy_api_key")
258+
259+
def test_get_bearer_token_with_static_token(self):
260+
result = self.vault_client.get_bearer_token(CREDENTIALS_WITH_TOKEN)
261+
self.assertEqual(result, "dummy_static_token")
262+
263+
@patch("skyflow.vault.client.client.generate_bearer_token", return_value=("sa_token", None))
264+
def test_get_bearer_token_generates_from_path_on_first_call(self, mock_generate):
265+
result = self.vault_client.get_bearer_token(CREDENTIALS_WITH_PATH)
266+
mock_generate.assert_called_once()
267+
self.assertEqual(result, "sa_token")
268+
self.assertEqual(self.vault_client._VaultClient__bearer_token, "sa_token")
269+
270+
@patch("skyflow.vault.client.client.generate_bearer_token_from_creds", return_value=("sa_token_str", None))
271+
@patch("skyflow.vault.client.client.log_info")
272+
def test_get_bearer_token_generates_from_credentials_string(self, mock_log, mock_generate):
273+
result = self.vault_client.get_bearer_token(CREDENTIALS_WITH_STRING)
274+
mock_generate.assert_called_once()
275+
self.assertEqual(result, "sa_token_str")
276+
277+
@patch("skyflow.vault.client.client.generate_bearer_token", return_value=("new_token", None))
278+
@patch("skyflow.vault.client.client.is_expired", return_value=True)
279+
@patch("skyflow.vault.client.client.log_info")
280+
def test_get_bearer_token_regenerates_on_expiry(self, mock_log, mock_is_expired, mock_generate):
281+
"""Expired token is regenerated silently — no exception raised."""
282+
self.vault_client._VaultClient__bearer_token = "expired_token"
283+
result = self.vault_client.get_bearer_token(CREDENTIALS_WITH_PATH)
284+
mock_generate.assert_called_once()
285+
self.assertEqual(result, "new_token")
69286

70287
@patch("skyflow.vault.client.client.generate_bearer_token")
71-
@patch("skyflow.vault.client.client.generate_bearer_token_from_creds")
288+
@patch("skyflow.vault.client.client.is_expired", return_value=False)
72289
@patch("skyflow.vault.client.client.log_info")
73-
def test_get_bearer_token_with_api_key(self, mock_log_info, mock_generate_bearer_token,
74-
mock_generate_bearer_token_from_creds):
75-
token = self.vault_client.get_bearer_token(CREDENTIALS_WITH_API_KEY)
76-
self.assertEqual(token, CREDENTIALS_WITH_API_KEY["api_key"])
77-
78-
def test_update_config(self):
79-
new_config = {"credentials": "new_credentials"}
80-
self.vault_client.update_config(new_config)
290+
def test_get_bearer_token_reuses_valid_cached_token(self, mock_log, mock_is_expired, mock_generate):
291+
"""Valid cached token is reused without calling generate_bearer_token."""
292+
self.vault_client._VaultClient__bearer_token = "valid_token"
293+
result = self.vault_client.get_bearer_token(CREDENTIALS_WITH_PATH)
294+
mock_generate.assert_not_called()
295+
self.assertEqual(result, "valid_token")
296+
297+
# ------------------------------------------------------------------ #
298+
# update_config #
299+
# ------------------------------------------------------------------ #
300+
301+
def test_update_config_sets_flag(self):
302+
self.vault_client.update_config({"credentials": "new_credentials"})
81303
self.assertTrue(self.vault_client._VaultClient__is_config_updated)
82304
self.assertEqual(self.vault_client.get_config()["credentials"], "new_credentials")
83305

84-
def test_get_config(self):
85-
self.assertEqual(self.vault_client.get_config(), CONFIG)
306+
# ------------------------------------------------------------------ #
307+
# API accessor stubs #
308+
# ------------------------------------------------------------------ #
86309

87-
def test_get_common_skyflow_credentials(self):
88-
credentials = {"api_key": "dummy_api_key"}
89-
self.vault_client.set_common_skyflow_credentials(credentials)
90-
self.assertEqual(self.vault_client.get_common_skyflow_credentials(), credentials)
310+
def test_get_records_api(self):
311+
self.vault_client._VaultClient__api_client = MagicMock()
312+
self.assertIsNotNone(self.vault_client.get_records_api())
91313

92-
def test_get_log_level(self):
93-
log_level = "DEBUG"
94-
self.vault_client.set_logger(log_level, MagicMock())
95-
self.assertEqual(self.vault_client.get_log_level(), log_level)
314+
def test_get_tokens_api(self):
315+
self.vault_client._VaultClient__api_client = MagicMock()
316+
self.assertIsNotNone(self.vault_client.get_tokens_api())
96317

97-
def test_get_logger(self):
98-
mock_logger = MagicMock()
99-
self.vault_client.set_logger("INFO", mock_logger)
100-
self.assertEqual(self.vault_client.get_logger(), mock_logger)
318+
def test_get_query_api(self):
319+
self.vault_client._VaultClient__api_client = MagicMock()
320+
self.assertIsNotNone(self.vault_client.get_query_api())
321+
322+
323+
if __name__ == "__main__":
324+
unittest.main()

0 commit comments

Comments
 (0)