11import unittest
2- from unittest .mock import patch , MagicMock
2+ from unittest .mock import patch , MagicMock , call
33from skyflow .vault .client .client import VaultClient
44
55CONFIG = {
1212}
1313
1414CREDENTIALS_WITH_API_KEY = {"api_key" : "dummy_api_key" }
15+ CREDENTIALS_WITH_TOKEN = {"token" : "dummy_static_token" }
16+ CREDENTIALS_WITH_PATH = {"path" : "/some/path/credentials.json" }
17+ CREDENTIALS_WITH_STRING = {"credentials_string" : '{"clientID": "x"}' }
18+
1519
1620class TestVaultClient (unittest .TestCase ):
1721 def setUp (self ):
1822 self .vault_client = VaultClient (CONFIG )
1923
24+ # ------------------------------------------------------------------ #
25+ # Basic setters / getters #
26+ # ------------------------------------------------------------------ #
27+
2028 def test_set_common_skyflow_credentials (self ):
2129 credentials = {"api_key" : "dummy_api_key" }
2230 self .vault_client .set_common_skyflow_credentials (credentials )
@@ -28,73 +36,289 @@ def test_set_logger(self):
2836 self .assertEqual (self .vault_client .get_log_level (), "INFO" )
2937 self .assertEqual (self .vault_client .get_logger (), mock_logger )
3038
39+ def test_get_vault_id (self ):
40+ self .assertEqual (self .vault_client .get_vault_id (), CONFIG ["vault_id" ])
41+
42+ def test_get_config (self ):
43+ self .assertEqual (self .vault_client .get_config (), CONFIG )
44+
45+ def test_get_common_skyflow_credentials (self ):
46+ credentials = {"api_key" : "dummy_api_key" }
47+ self .vault_client .set_common_skyflow_credentials (credentials )
48+ self .assertEqual (self .vault_client .get_common_skyflow_credentials (), credentials )
49+
50+ def test_get_log_level (self ):
51+ self .vault_client .set_logger ("DEBUG" , MagicMock ())
52+ self .assertEqual (self .vault_client .get_log_level (), "DEBUG" )
53+
54+ def test_get_logger (self ):
55+ mock_logger = MagicMock ()
56+ self .vault_client .set_logger ("INFO" , mock_logger )
57+ self .assertEqual (self .vault_client .get_logger (), mock_logger )
58+
59+ # ------------------------------------------------------------------ #
60+ # initialize_client_configuration — first call (slow path) #
61+ # ------------------------------------------------------------------ #
62+
3163 @patch ("skyflow.vault.client.client.get_credentials" )
3264 @patch ("skyflow.vault.client.client.get_vault_url" )
3365 @patch ("skyflow.vault.client.client.VaultClient.initialize_api_client" )
34- def test_initialize_client_configuration (self , mock_init_api_client , mock_get_vault_url , mock_get_credentials ):
35- mock_get_credentials .return_value = (CREDENTIALS_WITH_API_KEY )
66+ def test_initialize_client_configuration_first_call (
67+ self , mock_init_api_client , mock_get_vault_url , mock_get_credentials
68+ ):
69+ mock_get_credentials .return_value = CREDENTIALS_WITH_API_KEY
3670 mock_get_vault_url .return_value = "https://test-vault-url.com"
3771
3872 self .vault_client .initialize_client_configuration ()
3973
40- mock_get_credentials .assert_called_once_with (CONFIG ["credentials" ], None , logger = None )
41- mock_get_vault_url .assert_called_once_with (CONFIG ["cluster_id" ], CONFIG ["env" ], CONFIG ["vault_id" ], logger = None )
74+ mock_get_credentials .assert_called_once_with (
75+ CONFIG ["credentials" ], None , logger = None
76+ )
77+ mock_get_vault_url .assert_called_once_with (
78+ CONFIG ["cluster_id" ], CONFIG ["env" ], CONFIG ["vault_id" ], logger = None
79+ )
4280 mock_init_api_client .assert_called_once ()
4381
44- @patch ("skyflow.vault.client.client.Skyflow" )
45- def test_initialize_api_client (self , mock_api_client ):
46- self .vault_client .initialize_api_client ("https://test-vault-url.com" , "dummy_token" )
47- mock_api_client .assert_called_once_with (base_url = "https://test-vault-url.com" , token = "dummy_token" )
82+ # ------------------------------------------------------------------ #
83+ # initialize_client_configuration — fast path (static token) #
84+ # ------------------------------------------------------------------ #
4885
49- def test_get_records_api (self ):
86+ @patch ("skyflow.vault.client.client.get_credentials" )
87+ @patch ("skyflow.vault.client.client.get_vault_url" )
88+ @patch ("skyflow.vault.client.client.VaultClient.initialize_api_client" )
89+ def test_initialize_client_configuration_fast_path_api_key (
90+ self , mock_init_api_client , mock_get_vault_url , mock_get_credentials
91+ ):
92+ """Once initialized with api_key, subsequent calls skip all work."""
93+ mock_get_credentials .return_value = CREDENTIALS_WITH_API_KEY
94+ mock_get_vault_url .return_value = "https://test-vault-url.com"
95+ # Side-effect simulates initialize_api_client actually setting __api_client
96+ mock_init_api_client .side_effect = lambda * _ : setattr (
97+ self .vault_client , "_VaultClient__api_client" , MagicMock ()
98+ )
99+
100+ self .vault_client .initialize_client_configuration () # first call — slow path
101+ mock_get_credentials .reset_mock ()
102+ mock_get_vault_url .reset_mock ()
103+ mock_init_api_client .reset_mock ()
104+
105+ self .vault_client .initialize_client_configuration () # second call — fast path
106+
107+ mock_get_credentials .assert_not_called ()
108+ mock_get_vault_url .assert_not_called ()
109+ mock_init_api_client .assert_not_called ()
110+
111+ @patch ("skyflow.vault.client.client.get_credentials" )
112+ @patch ("skyflow.vault.client.client.get_vault_url" )
113+ @patch ("skyflow.vault.client.client.VaultClient.initialize_api_client" )
114+ def test_initialize_client_configuration_fast_path_static_token (
115+ self , mock_init_api_client , mock_get_vault_url , mock_get_credentials
116+ ):
117+ """Once initialized with a static token, subsequent calls skip all work."""
118+ mock_get_credentials .return_value = CREDENTIALS_WITH_TOKEN
119+ mock_get_vault_url .return_value = "https://test-vault-url.com"
120+ mock_init_api_client .side_effect = lambda * _ : setattr (
121+ self .vault_client , "_VaultClient__api_client" , MagicMock ()
122+ )
123+
124+ self .vault_client .initialize_client_configuration ()
125+ mock_get_credentials .reset_mock ()
126+ mock_get_vault_url .reset_mock ()
127+ mock_init_api_client .reset_mock ()
128+
129+ self .vault_client .initialize_client_configuration ()
130+
131+ mock_get_credentials .assert_not_called ()
132+ mock_get_vault_url .assert_not_called ()
133+ mock_init_api_client .assert_not_called ()
134+
135+ # ------------------------------------------------------------------ #
136+ # initialize_client_configuration — fast path (service account) #
137+ # ------------------------------------------------------------------ #
138+
139+ @patch ("skyflow.vault.client.client.is_expired" , return_value = False )
140+ @patch ("skyflow.vault.client.client.get_credentials" )
141+ @patch ("skyflow.vault.client.client.get_vault_url" )
142+ @patch ("skyflow.vault.client.client.VaultClient.initialize_api_client" )
143+ def test_initialize_client_configuration_fast_path_valid_sa_token (
144+ self , mock_init_api_client , mock_get_vault_url , mock_get_credentials , mock_is_expired
145+ ):
146+ """Service account with a still-valid token skips get_bearer_token entirely."""
147+ mock_get_credentials .return_value = CREDENTIALS_WITH_PATH
148+ mock_get_vault_url .return_value = "https://test-vault-url.com"
149+
150+ # Seed the cached bearer token as if first call already ran
50151 self .vault_client ._VaultClient__api_client = MagicMock ()
51- self .vault_client ._VaultClient__api_client . records = MagicMock ()
52- records_api = self .vault_client .get_records_api ()
53- self .assertIsNotNone ( records_api )
152+ self .vault_client ._VaultClient__is_static_token = False
153+ self .vault_client ._VaultClient__bearer_token = "cached_sa_token"
154+ self .vault_client . _VaultClient__credentials = CREDENTIALS_WITH_PATH
54155
55- def test_get_tokens_api (self ):
156+ self .vault_client .initialize_client_configuration ()
157+
158+ mock_get_credentials .assert_not_called ()
159+ mock_get_vault_url .assert_not_called ()
160+ mock_init_api_client .assert_not_called ()
161+
162+ # ------------------------------------------------------------------ #
163+ # initialize_client_configuration — token expiry (no client reinit) #
164+ # ------------------------------------------------------------------ #
165+
166+ @patch ("skyflow.vault.client.client.generate_bearer_token" , return_value = ("new_sa_token" , None ))
167+ @patch ("skyflow.vault.client.client.is_expired" , return_value = True )
168+ @patch ("skyflow.vault.client.client.get_credentials" )
169+ @patch ("skyflow.vault.client.client.get_vault_url" )
170+ @patch ("skyflow.vault.client.client.VaultClient.initialize_api_client" )
171+ def test_initialize_client_configuration_expired_token_no_reinit (
172+ self , mock_init_api_client , mock_get_vault_url , mock_get_credentials ,
173+ mock_is_expired , mock_generate_bearer_token
174+ ):
175+ """Expired service account token is regenerated in-place; httpx client is NOT recreated."""
176+ mock_get_credentials .return_value = CREDENTIALS_WITH_PATH
177+ mock_get_vault_url .return_value = "https://test-vault-url.com"
178+
179+ # Client already initialized — simulate warm state with an expired token
56180 self .vault_client ._VaultClient__api_client = MagicMock ()
57- self .vault_client ._VaultClient__api_client . tokens = MagicMock ()
58- tokens_api = self .vault_client .get_tokens_api ()
59- self .assertIsNotNone ( tokens_api )
181+ self .vault_client ._VaultClient__is_static_token = False
182+ self .vault_client ._VaultClient__bearer_token = "expired_sa_token"
183+ self .vault_client . _VaultClient__credentials = CREDENTIALS_WITH_PATH
60184
61- def test_get_query_api (self ):
185+ self .vault_client .initialize_client_configuration ()
186+
187+ # Token was regenerated
188+ mock_generate_bearer_token .assert_called_once ()
189+ self .assertEqual (
190+ self .vault_client ._VaultClient__bearer_token , "new_sa_token"
191+ )
192+ # httpx client was NOT recreated
193+ mock_init_api_client .assert_not_called ()
194+
195+ # ------------------------------------------------------------------ #
196+ # initialize_client_configuration — config update forces reinit #
197+ # ------------------------------------------------------------------ #
198+
199+ @patch ("skyflow.vault.client.client.get_credentials" )
200+ @patch ("skyflow.vault.client.client.get_vault_url" )
201+ @patch ("skyflow.vault.client.client.VaultClient.initialize_api_client" )
202+ def test_initialize_client_configuration_reinit_after_update_config (
203+ self , mock_init_api_client , mock_get_vault_url , mock_get_credentials
204+ ):
205+ """update_config() marks the client stale; next call must recreate it."""
206+ mock_get_credentials .return_value = CREDENTIALS_WITH_API_KEY
207+ mock_get_vault_url .return_value = "https://test-vault-url.com"
208+
209+ # Simulate already-initialized client
62210 self .vault_client ._VaultClient__api_client = MagicMock ()
63- self .vault_client ._VaultClient__api_client .query = MagicMock ()
64- query_api = self .vault_client .get_query_api ()
65- self .assertIsNotNone (query_api )
211+ self .vault_client ._VaultClient__is_static_token = True
66212
67- def test_get_vault_id (self ):
68- self .assertEqual (self .vault_client .get_vault_id (), CONFIG ["vault_id" ])
213+ self .vault_client .update_config ({"cluster_id" : "new_cluster" })
214+ self .vault_client .initialize_client_configuration ()
215+
216+ mock_get_credentials .assert_called_once ()
217+ mock_get_vault_url .assert_called_once ()
218+ mock_init_api_client .assert_called_once ()
219+
220+ # ------------------------------------------------------------------ #
221+ # initialize_api_client — lambda token provider #
222+ # ------------------------------------------------------------------ #
223+
224+ @patch ("skyflow.vault.client.client.Skyflow" )
225+ def test_initialize_api_client_passes_callable_token (self , mock_skyflow ):
226+ """initialize_api_client must pass a callable (lambda) as token, not a string."""
227+ self .vault_client .initialize_api_client ("https://test-vault-url.com" , "initial_token" )
228+
229+ args , kwargs = mock_skyflow .call_args
230+ self .assertEqual (kwargs ["base_url" ], "https://test-vault-url.com" )
231+ self .assertTrue (callable (kwargs ["token" ]), "token must be a callable (lambda)" )
232+
233+ @patch ("skyflow.vault.client.client.Skyflow" )
234+ def test_initialize_api_client_lambda_returns_cached_bearer_token (self , mock_skyflow ):
235+ """Lambda returns __bearer_token when it is set (interceptor behaviour)."""
236+ self .vault_client ._VaultClient__bearer_token = "refreshed_token"
237+ self .vault_client .initialize_api_client ("https://test-vault-url.com" , "initial_token" )
238+
239+ _ , kwargs = mock_skyflow .call_args
240+ self .assertEqual (kwargs ["token" ](), "refreshed_token" )
241+
242+ @patch ("skyflow.vault.client.client.Skyflow" )
243+ def test_initialize_api_client_lambda_falls_back_to_initial_token (self , mock_skyflow ):
244+ """Lambda falls back to the initial token when __bearer_token is None."""
245+ self .vault_client ._VaultClient__bearer_token = None
246+ self .vault_client .initialize_api_client ("https://test-vault-url.com" , "initial_token" )
247+
248+ _ , kwargs = mock_skyflow .call_args
249+ self .assertEqual (kwargs ["token" ](), "initial_token" )
250+
251+ # ------------------------------------------------------------------ #
252+ # get_bearer_token #
253+ # ------------------------------------------------------------------ #
254+
255+ def test_get_bearer_token_with_api_key (self ):
256+ result = self .vault_client .get_bearer_token (CREDENTIALS_WITH_API_KEY )
257+ self .assertEqual (result , "dummy_api_key" )
258+
259+ def test_get_bearer_token_with_static_token (self ):
260+ result = self .vault_client .get_bearer_token (CREDENTIALS_WITH_TOKEN )
261+ self .assertEqual (result , "dummy_static_token" )
262+
263+ @patch ("skyflow.vault.client.client.generate_bearer_token" , return_value = ("sa_token" , None ))
264+ def test_get_bearer_token_generates_from_path_on_first_call (self , mock_generate ):
265+ result = self .vault_client .get_bearer_token (CREDENTIALS_WITH_PATH )
266+ mock_generate .assert_called_once ()
267+ self .assertEqual (result , "sa_token" )
268+ self .assertEqual (self .vault_client ._VaultClient__bearer_token , "sa_token" )
269+
270+ @patch ("skyflow.vault.client.client.generate_bearer_token_from_creds" , return_value = ("sa_token_str" , None ))
271+ @patch ("skyflow.vault.client.client.log_info" )
272+ def test_get_bearer_token_generates_from_credentials_string (self , mock_log , mock_generate ):
273+ result = self .vault_client .get_bearer_token (CREDENTIALS_WITH_STRING )
274+ mock_generate .assert_called_once ()
275+ self .assertEqual (result , "sa_token_str" )
276+
277+ @patch ("skyflow.vault.client.client.generate_bearer_token" , return_value = ("new_token" , None ))
278+ @patch ("skyflow.vault.client.client.is_expired" , return_value = True )
279+ @patch ("skyflow.vault.client.client.log_info" )
280+ def test_get_bearer_token_regenerates_on_expiry (self , mock_log , mock_is_expired , mock_generate ):
281+ """Expired token is regenerated silently — no exception raised."""
282+ self .vault_client ._VaultClient__bearer_token = "expired_token"
283+ result = self .vault_client .get_bearer_token (CREDENTIALS_WITH_PATH )
284+ mock_generate .assert_called_once ()
285+ self .assertEqual (result , "new_token" )
69286
70287 @patch ("skyflow.vault.client.client.generate_bearer_token" )
71- @patch ("skyflow.vault.client.client.generate_bearer_token_from_creds" )
288+ @patch ("skyflow.vault.client.client.is_expired" , return_value = False )
72289 @patch ("skyflow.vault.client.client.log_info" )
73- def test_get_bearer_token_with_api_key (self , mock_log_info , mock_generate_bearer_token ,
74- mock_generate_bearer_token_from_creds ):
75- token = self .vault_client .get_bearer_token (CREDENTIALS_WITH_API_KEY )
76- self .assertEqual (token , CREDENTIALS_WITH_API_KEY ["api_key" ])
77-
78- def test_update_config (self ):
79- new_config = {"credentials" : "new_credentials" }
80- self .vault_client .update_config (new_config )
290+ def test_get_bearer_token_reuses_valid_cached_token (self , mock_log , mock_is_expired , mock_generate ):
291+ """Valid cached token is reused without calling generate_bearer_token."""
292+ self .vault_client ._VaultClient__bearer_token = "valid_token"
293+ result = self .vault_client .get_bearer_token (CREDENTIALS_WITH_PATH )
294+ mock_generate .assert_not_called ()
295+ self .assertEqual (result , "valid_token" )
296+
297+ # ------------------------------------------------------------------ #
298+ # update_config #
299+ # ------------------------------------------------------------------ #
300+
301+ def test_update_config_sets_flag (self ):
302+ self .vault_client .update_config ({"credentials" : "new_credentials" })
81303 self .assertTrue (self .vault_client ._VaultClient__is_config_updated )
82304 self .assertEqual (self .vault_client .get_config ()["credentials" ], "new_credentials" )
83305
84- def test_get_config (self ):
85- self .assertEqual (self .vault_client .get_config (), CONFIG )
306+ # ------------------------------------------------------------------ #
307+ # API accessor stubs #
308+ # ------------------------------------------------------------------ #
86309
87- def test_get_common_skyflow_credentials (self ):
88- credentials = {"api_key" : "dummy_api_key" }
89- self .vault_client .set_common_skyflow_credentials (credentials )
90- self .assertEqual (self .vault_client .get_common_skyflow_credentials (), credentials )
310+ def test_get_records_api (self ):
311+ self .vault_client ._VaultClient__api_client = MagicMock ()
312+ self .assertIsNotNone (self .vault_client .get_records_api ())
91313
92- def test_get_log_level (self ):
93- log_level = "DEBUG"
94- self .vault_client .set_logger (log_level , MagicMock ())
95- self .assertEqual (self .vault_client .get_log_level (), log_level )
314+ def test_get_tokens_api (self ):
315+ self .vault_client ._VaultClient__api_client = MagicMock ()
316+ self .assertIsNotNone (self .vault_client .get_tokens_api ())
96317
97- def test_get_logger (self ):
98- mock_logger = MagicMock ()
99- self .vault_client .set_logger ("INFO" , mock_logger )
100- self .assertEqual (self .vault_client .get_logger (), mock_logger )
318+ def test_get_query_api (self ):
319+ self .vault_client ._VaultClient__api_client = MagicMock ()
320+ self .assertIsNotNone (self .vault_client .get_query_api ())
321+
322+
323+ if __name__ == "__main__" :
324+ unittest .main ()
0 commit comments