55from unittest .mock import patch
66import os
77from skyflow .error import SkyflowError
8- from skyflow .service_account import is_expired , generate_bearer_token , \
9- generate_bearer_token_from_creds
8+ from skyflow .service_account import is_expired , generate_bearer_token , generate_bearer_token_from_creds
109from skyflow .utils import SkyflowMessages
1110from skyflow .service_account ._utils import (
12- get_service_account_token , get_signed_jwt , generate_signed_data_tokens ,
13- get_signed_data_token_response_object , generate_signed_data_tokens_from_creds ,
14- _validate_and_resolve_ctx , _normalize_credentials , get_signed_tokens ,
11+ get_service_account_token ,
12+ get_signed_jwt ,
13+ generate_signed_data_tokens ,
14+ get_signed_data_token_response_object ,
15+ generate_signed_data_tokens_from_creds ,
16+ _validate_and_resolve_ctx ,
17+ _normalize_credentials ,
18+ get_signed_tokens ,
1519)
1620
1721creds_path = os .path .join (os .path .dirname (os .path .dirname (os .path .dirname (__file__ ))), "credentials.json" )
18- with open (creds_path , 'r' ) as file :
22+ with open (creds_path , "r" ) as file :
1923 credentials = json .load (file )
2024
2125VALID_CREDENTIALS_STRING = json .dumps (credentials )
2226
23- CREDENTIALS_WITHOUT_CLIENT_ID = {
24- 'privateKey' : 'private_key'
25- }
27+ CREDENTIALS_WITHOUT_CLIENT_ID = {"privateKey" : "private_key" }
2628
27- CREDENTIALS_WITHOUT_KEY_ID = {
28- 'privateKey' : 'private_key' ,
29- 'clientID' : 'client_id'
30- }
29+ CREDENTIALS_WITHOUT_KEY_ID = {"privateKey" : "private_key" , "clientID" : "client_id" }
3130
32- CREDENTIALS_WITHOUT_TOKEN_URI = {
33- 'privateKey' : 'private_key' ,
34- 'clientID' : 'client_id' ,
35- 'keyID' : 'key_id'
36- }
31+ CREDENTIALS_WITHOUT_TOKEN_URI = {"privateKey" : "private_key" , "clientID" : "client_id" , "keyID" : "key_id" }
3732
3833VALID_SERVICE_ACCOUNT_CREDS = credentials
3934
4035# Snake-case version of the real credentials (keys remapped to snake_case)
4136SNAKE_CASE_CREDS = {
42- ' private_key' : credentials [' privateKey' ],
43- ' client_id' : credentials [' clientID' ],
44- ' key_id' : credentials [' keyID' ],
45- ' token_uri' : credentials [' tokenURI' ],
37+ " private_key" : credentials [" privateKey" ],
38+ " client_id" : credentials [" clientID" ],
39+ " key_id" : credentials [" keyID" ],
40+ " token_uri" : credentials [" tokenURI" ],
4641}
4742
48- SNAKE_CASE_CREDS_STRING = json .dumps ({
49- 'private_key' : credentials ['privateKey' ],
50- 'client_id' : credentials ['clientID' ],
51- 'key_id' : credentials ['keyID' ],
52- 'token_uri' : credentials ['tokenURI' ],
53- })
43+ SNAKE_CASE_CREDS_STRING = json .dumps (
44+ {
45+ "private_key" : credentials ["privateKey" ],
46+ "client_id" : credentials ["clientID" ],
47+ "key_id" : credentials ["keyID" ],
48+ "token_uri" : credentials ["tokenURI" ],
49+ }
50+ )
5451
5552
5653class TestServiceAccountUtils (unittest .TestCase ):
57-
5854 # ── is_expired ────────────────────────────────────────────────────────────
5955
6056 def test_is_expired_none_token (self ):
@@ -144,33 +140,33 @@ def test_get_service_account_token_with_snake_case_creds(self):
144140
145141 def test_get_service_account_token_missing_private_key_snake (self ):
146142 creds = {
147- ' client_id' : 'id' ,
148- ' key_id' : ' kid' ,
149- ' token_uri' : ' https://example.com' ,
143+ " client_id" : "id" ,
144+ " key_id" : " kid" ,
145+ " token_uri" : " https://example.com" ,
150146 }
151147 with self .assertRaises (SkyflowError ) as context :
152148 get_service_account_token (creds , {}, None )
153149 self .assertEqual (context .exception .message , SkyflowMessages .Error .MISSING_PRIVATE_KEY .value )
154150
155151 def test_get_service_account_token_invalid_token_uri (self ):
156152 creds = {
157- ' privateKey' : ' key' ,
158- ' clientID' : 'id' ,
159- ' keyID' : ' kid' ,
160- ' tokenURI' : ' not-a-url' ,
153+ " privateKey" : " key" ,
154+ " clientID" : "id" ,
155+ " keyID" : " kid" ,
156+ " tokenURI" : " not-a-url" ,
161157 }
162158 with self .assertRaises (SkyflowError ) as context :
163159 get_service_account_token (creds , {}, None )
164160 self .assertEqual (context .exception .message , SkyflowMessages .Error .INVALID_TOKEN_URI .value )
165161
166162 def test_get_service_account_token_invalid_token_uri_in_options (self ):
167163 creds = {
168- ' privateKey' : ' key' ,
169- ' clientID' : 'id' ,
170- ' keyID' : ' kid' ,
171- ' tokenURI' : ' https://valid-url.com' ,
164+ " privateKey" : " key" ,
165+ " clientID" : "id" ,
166+ " keyID" : " kid" ,
167+ " tokenURI" : " https://valid-url.com" ,
172168 }
173- options = {' token_uri' : ' not-a-valid-url' }
169+ options = {" token_uri" : " not-a-valid-url" }
174170 with self .assertRaises (SkyflowError ) as context :
175171 get_service_account_token (creds , options , None )
176172 self .assertEqual (context .exception .message , SkyflowMessages .Error .INVALID_TOKEN_URI .value )
@@ -182,14 +178,14 @@ def test_get_service_account_token_with_role_ids_formats_scope(self, mock_get_si
182178 "privateKey" : "private_key" ,
183179 "clientID" : "client_id" ,
184180 "keyID" : "key_id" ,
185- "tokenURI" : "https://valid-url.com"
181+ "tokenURI" : "https://valid-url.com" ,
186182 }
187183 options = {"role_ids" : ["role1" , "role2" ]}
188184 mock_get_signed_jwt .return_value = "signed"
189185 mock_auth_api = mock_auth_client .return_value .get_auth_api .return_value
190- mock_auth_api .authentication_service_get_auth_token .return_value = type ("obj" , (), {
191- "access_token" : "token" , "token_type" : "bearer"
192- } )
186+ mock_auth_api .authentication_service_get_auth_token .return_value = type (
187+ "obj" , (), { " access_token" : "token" , "token_type" : "bearer" }
188+ )
193189 access_token , token_type = get_service_account_token (creds , options , None )
194190 self .assertEqual (access_token , "token" )
195191 self .assertEqual (token_type , "bearer" )
@@ -204,16 +200,18 @@ def test_get_service_account_token_unauthorized_error(self, mock_get_signed_jwt,
204200 "privateKey" : "private_key" ,
205201 "clientID" : "client_id" ,
206202 "keyID" : "key_id" ,
207- "tokenURI" : "https://valid-url.com"
203+ "tokenURI" : "https://valid-url.com" ,
208204 }
209205 mock_get_signed_jwt .return_value = "signed"
210206 mock_auth_api = mock_auth_client .return_value .get_auth_api .return_value
211207 from skyflow .generated .rest .errors .unauthorized_error import UnauthorizedError
208+
212209 mock_auth_api .authentication_service_get_auth_token .side_effect = UnauthorizedError ("unauthorized" )
213210 with self .assertRaises (SkyflowError ) as context :
214211 get_service_account_token (creds , {}, None )
215- self .assertEqual (context .exception .message ,
216- SkyflowMessages .Error .UNAUTHORIZED_ERROR_IN_GETTING_BEARER_TOKEN .value )
212+ self .assertEqual (
213+ context .exception .message , SkyflowMessages .Error .UNAUTHORIZED_ERROR_IN_GETTING_BEARER_TOKEN .value
214+ )
217215
218216 @patch ("skyflow.service_account._utils.AuthClient" )
219217 @patch ("skyflow.service_account._utils.get_signed_jwt" )
@@ -222,7 +220,7 @@ def test_get_service_account_token_generic_exception(self, mock_get_signed_jwt,
222220 "privateKey" : "private_key" ,
223221 "clientID" : "client_id" ,
224222 "keyID" : "key_id" ,
225- "tokenURI" : "https://valid-url.com"
223+ "tokenURI" : "https://valid-url.com" ,
226224 }
227225 mock_get_signed_jwt .return_value = "signed"
228226 mock_auth_api = mock_auth_client .return_value .get_auth_api .return_value
@@ -266,9 +264,9 @@ def test_get_signed_data_token_response_object(self):
266264 token = "sample_token"
267265 signed_token = "signed_sample_token"
268266 response = get_signed_data_token_response_object (signed_token , token )
269- self .assertIsInstance (response , dict )
270- self .assertEqual (response ["token" ], token )
271- self .assertEqual (response ["signed_token" ], signed_token )
267+ self .assertIsInstance (response , tuple )
268+ self .assertEqual (response [0 ], token )
269+ self .assertEqual (response [1 ], signed_token )
272270
273271 # ── get_signed_tokens ─────────────────────────────────────────────────────
274272
@@ -278,7 +276,7 @@ def test_get_signed_tokens_jwt_encode_exception(self, mock_jwt_encode):
278276 "privateKey" : "private_key" ,
279277 "clientID" : "client_id" ,
280278 "keyID" : "key_id" ,
281- "tokenURI" : "https://valid-url.com"
279+ "tokenURI" : "https://valid-url.com" ,
282280 }
283281 options = {"data_tokens" : ["token1" ]}
284282 with self .assertRaises (SkyflowError ) as context :
@@ -290,16 +288,14 @@ def test_get_signed_tokens_returns_list_one_per_token(self):
290288 self .assertIsInstance (result , list )
291289 self .assertEqual (len (result ), 2 )
292290
293- def test_get_signed_tokens_items_are_dicts_with_token_and_signed_token (self ):
291+ def test_get_signed_tokens_items_are_tuples_with_token_and_signed_token (self ):
294292 result = generate_signed_data_tokens (creds_path , {"data_tokens" : ["token1" , "token2" ]})
295293 for item in result :
296- self .assertIsInstance (item , dict )
297- self .assertIn ("token" , item )
298- self .assertIn ("signed_token" , item )
299- self .assertEqual (result [0 ]["token" ], "token1" )
300- self .assertEqual (result [1 ]["token" ], "token2" )
301- self .assertTrue (result [0 ]["signed_token" ].startswith ("signed_token_" ))
302- self .assertTrue (result [1 ]["signed_token" ].startswith ("signed_token_" ))
294+ self .assertIsInstance (item , tuple )
295+ self .assertEqual (result [0 ][0 ], "token1" )
296+ self .assertEqual (result [1 ][0 ], "token2" )
297+ self .assertTrue (result [0 ][1 ].startswith ("signed_token_" ))
298+ self .assertTrue (result [1 ][1 ].startswith ("signed_token_" ))
303299
304300 def test_get_signed_tokens_returns_list_single_token (self ):
305301 result = generate_signed_data_tokens (creds_path , {"data_tokens" : ["token1" ]})
@@ -396,14 +392,14 @@ def test_get_signed_tokens_with_snake_case_creds(self):
396392 # ── generate_signed_data_tokens (file path) ───────────────────────────────
397393
398394 def test_generate_signed_data_tokens_from_file_path (self ):
399- options = {"data_tokens" : ["token1" , "token2" ], "ctx" : ' ctx' }
395+ options = {"data_tokens" : ["token1" , "token2" ], "ctx" : " ctx" }
400396 result = generate_signed_data_tokens (creds_path , options )
401397 self .assertEqual (len (result ), 2 )
402398
403399 def test_generate_signed_data_tokens_from_invalid_file_path (self ):
404400 options = {"data_tokens" : ["token1" , "token2" ]}
405401 with self .assertRaises (SkyflowError ) as context :
406- generate_signed_data_tokens (' credentials1.json' , options )
402+ generate_signed_data_tokens (" credentials1.json" , options )
407403 self .assertEqual (context .exception .message , SkyflowMessages .Error .INVALID_CREDENTIAL_FILE_PATH .value )
408404
409405 def test_generate_signed_data_tokens_with_dict_ctx (self ):
@@ -421,7 +417,7 @@ def test_generate_signed_data_tokens_from_creds(self):
421417 def test_generate_signed_data_tokens_from_creds_with_invalid_string (self ):
422418 options = {"data_tokens" : ["token1" , "token2" ]}
423419 with self .assertRaises (SkyflowError ) as context :
424- generate_signed_data_tokens_from_creds ('{' , options )
420+ generate_signed_data_tokens_from_creds ("{" , options )
425421 self .assertEqual (context .exception .message , SkyflowMessages .Error .INVALID_CREDENTIALS_STRING .value )
426422
427423 def test_generate_signed_data_tokens_from_creds_with_dict_ctx (self ):
@@ -446,54 +442,54 @@ def test_generate_signed_data_tokens_from_creds_snake(self):
446442
447443 def test_normalize_credentials_snake_case (self ):
448444 snake = {
449- ' private_key' : 'pk' ,
450- ' client_id' : ' cid' ,
451- ' key_id' : ' kid' ,
452- ' token_uri' : ' https://uri' ,
453- ' client_name' : ' name' ,
445+ " private_key" : "pk" ,
446+ " client_id" : " cid" ,
447+ " key_id" : " kid" ,
448+ " token_uri" : " https://uri" ,
449+ " client_name" : " name" ,
454450 }
455451 result = _normalize_credentials (snake )
456- self .assertEqual (result [' privateKey' ], 'pk' )
457- self .assertEqual (result [' clientID' ], ' cid' )
458- self .assertEqual (result [' keyID' ], ' kid' )
459- self .assertEqual (result [' tokenURI' ], ' https://uri' )
460- self .assertEqual (result [' clientName' ], ' name' )
461- self .assertNotIn (' private_key' , result )
462- self .assertNotIn (' client_id' , result )
463- self .assertNotIn (' key_id' , result )
464- self .assertNotIn (' token_uri' , result )
465- self .assertNotIn (' client_name' , result )
452+ self .assertEqual (result [" privateKey" ], "pk" )
453+ self .assertEqual (result [" clientID" ], " cid" )
454+ self .assertEqual (result [" keyID" ], " kid" )
455+ self .assertEqual (result [" tokenURI" ], " https://uri" )
456+ self .assertEqual (result [" clientName" ], " name" )
457+ self .assertNotIn (" private_key" , result )
458+ self .assertNotIn (" client_id" , result )
459+ self .assertNotIn (" key_id" , result )
460+ self .assertNotIn (" token_uri" , result )
461+ self .assertNotIn (" client_name" , result )
466462
467463 def test_normalize_credentials_camel_case_unchanged (self ):
468464 camel = {
469- ' privateKey' : 'pk' ,
470- ' clientID' : ' cid' ,
471- ' keyID' : ' kid' ,
472- ' tokenURI' : ' https://uri' ,
465+ " privateKey" : "pk" ,
466+ " clientID" : " cid" ,
467+ " keyID" : " kid" ,
468+ " tokenURI" : " https://uri" ,
473469 }
474470 result = _normalize_credentials (camel )
475471 self .assertEqual (result , camel )
476472
477473 def test_normalize_credentials_mixed_keys (self ):
478474 mixed = {
479- ' private_key' : 'pk' ,
480- ' clientID' : ' cid' ,
481- ' key_id' : ' kid' ,
482- ' tokenURI' : ' https://uri' ,
475+ " private_key" : "pk" ,
476+ " clientID" : " cid" ,
477+ " key_id" : " kid" ,
478+ " tokenURI" : " https://uri" ,
483479 }
484480 result = _normalize_credentials (mixed )
485- self .assertEqual (result [' privateKey' ], 'pk' )
486- self .assertEqual (result [' clientID' ], ' cid' )
487- self .assertEqual (result [' keyID' ], ' kid' )
488- self .assertEqual (result [' tokenURI' ], ' https://uri' )
489- self .assertNotIn (' private_key' , result )
490- self .assertNotIn (' key_id' , result )
481+ self .assertEqual (result [" privateKey" ], "pk" )
482+ self .assertEqual (result [" clientID" ], " cid" )
483+ self .assertEqual (result [" keyID" ], " kid" )
484+ self .assertEqual (result [" tokenURI" ], " https://uri" )
485+ self .assertNotIn (" private_key" , result )
486+ self .assertNotIn (" key_id" , result )
491487
492488 def test_normalize_credentials_unknown_key_passes_through (self ):
493- creds = {' unknown_field' : ' value' , ' anotherField' : ' val2' }
489+ creds = {" unknown_field" : " value" , " anotherField" : " val2" }
494490 result = _normalize_credentials (creds )
495- self .assertEqual (result [' unknown_field' ], ' value' )
496- self .assertEqual (result [' anotherField' ], ' val2' )
491+ self .assertEqual (result [" unknown_field" ], " value" )
492+ self .assertEqual (result [" anotherField" ], " val2" )
497493
498494 def test_normalize_credentials_empty_dict (self ):
499495 self .assertEqual (_normalize_credentials ({}), {})
@@ -504,11 +500,11 @@ def test_validate_and_resolve_ctx_none(self):
504500 self .assertIsNone (_validate_and_resolve_ctx (None ))
505501
506502 def test_validate_and_resolve_ctx_empty_string (self ):
507- self .assertIsNone (_validate_and_resolve_ctx ('' ))
508- self .assertIsNone (_validate_and_resolve_ctx (' ' ))
503+ self .assertIsNone (_validate_and_resolve_ctx ("" ))
504+ self .assertIsNone (_validate_and_resolve_ctx (" " ))
509505
510506 def test_validate_and_resolve_ctx_valid_string (self ):
511- self .assertEqual (_validate_and_resolve_ctx (' user_12345' ), ' user_12345' )
507+ self .assertEqual (_validate_and_resolve_ctx (" user_12345" ), " user_12345" )
512508
513509 def test_validate_and_resolve_ctx_empty_dict (self ):
514510 self .assertIsNone (_validate_and_resolve_ctx ({}))
@@ -577,9 +573,9 @@ def test_get_service_account_token_with_token_uri_option_override(self, mock_get
577573 options = {"token_uri" : override_uri }
578574 mock_get_signed_jwt .return_value = "signed"
579575 mock_auth_api = mock_auth_client .return_value .get_auth_api .return_value
580- mock_auth_api .authentication_service_get_auth_token .return_value = type ("obj" , (), {
581- "access_token" : "token" , "token_type" : "bearer"
582- } )
576+ mock_auth_api .authentication_service_get_auth_token .return_value = type (
577+ "obj" , (), { " access_token" : "token" , "token_type" : "bearer" }
578+ )
583579 get_service_account_token (creds , options , None )
584580 mock_get_signed_jwt .assert_called_once ()
585581 call_args = mock_get_signed_jwt .call_args
0 commit comments