Skip to content

Commit c27cc7e

Browse files
SK-2813: revert get signed data tokens response to tuple
1 parent 7be6e4c commit c27cc7e

2 files changed

Lines changed: 100 additions & 106 deletions

File tree

skyflow/service_account/_utils.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -239,8 +239,6 @@ def generate_signed_data_tokens_from_creds(credentials, options):
239239
raise SkyflowError(SkyflowMessages.Error.INVALID_CREDENTIALS_STRING.value, invalid_input_error_code)
240240
return get_signed_tokens(json_credentials, options)
241241

242+
242243
def get_signed_data_token_response_object(signed_token, actual_token):
243-
return {
244-
ResponseField.TOKEN: actual_token,
245-
ResponseField.SIGNED_TOKEN: signed_token,
246-
}
244+
return actual_token, signed_token

tests/service_account/test__utils.py

Lines changed: 98 additions & 102 deletions
Original file line numberDiff line numberDiff line change
@@ -5,56 +5,52 @@
55
from unittest.mock import patch
66
import os
77
from skyflow.error import SkyflowError
8-
from skyflow.service_account import is_expired, generate_bearer_token, \
9-
generate_bearer_token_from_creds
8+
from skyflow.service_account import is_expired, generate_bearer_token, generate_bearer_token_from_creds
109
from skyflow.utils import SkyflowMessages
1110
from skyflow.service_account._utils import (
12-
get_service_account_token, get_signed_jwt, generate_signed_data_tokens,
13-
get_signed_data_token_response_object, generate_signed_data_tokens_from_creds,
14-
_validate_and_resolve_ctx, _normalize_credentials, get_signed_tokens,
11+
get_service_account_token,
12+
get_signed_jwt,
13+
generate_signed_data_tokens,
14+
get_signed_data_token_response_object,
15+
generate_signed_data_tokens_from_creds,
16+
_validate_and_resolve_ctx,
17+
_normalize_credentials,
18+
get_signed_tokens,
1519
)
1620

1721
creds_path = os.path.join(os.path.dirname(os.path.dirname(os.path.dirname(__file__))), "credentials.json")
18-
with open(creds_path, 'r') as file:
22+
with open(creds_path, "r") as file:
1923
credentials = json.load(file)
2024

2125
VALID_CREDENTIALS_STRING = json.dumps(credentials)
2226

23-
CREDENTIALS_WITHOUT_CLIENT_ID = {
24-
'privateKey': 'private_key'
25-
}
27+
CREDENTIALS_WITHOUT_CLIENT_ID = {"privateKey": "private_key"}
2628

27-
CREDENTIALS_WITHOUT_KEY_ID = {
28-
'privateKey': 'private_key',
29-
'clientID': 'client_id'
30-
}
29+
CREDENTIALS_WITHOUT_KEY_ID = {"privateKey": "private_key", "clientID": "client_id"}
3130

32-
CREDENTIALS_WITHOUT_TOKEN_URI = {
33-
'privateKey': 'private_key',
34-
'clientID': 'client_id',
35-
'keyID': 'key_id'
36-
}
31+
CREDENTIALS_WITHOUT_TOKEN_URI = {"privateKey": "private_key", "clientID": "client_id", "keyID": "key_id"}
3732

3833
VALID_SERVICE_ACCOUNT_CREDS = credentials
3934

4035
# Snake-case version of the real credentials (keys remapped to snake_case)
4136
SNAKE_CASE_CREDS = {
42-
'private_key': credentials['privateKey'],
43-
'client_id': credentials['clientID'],
44-
'key_id': credentials['keyID'],
45-
'token_uri': credentials['tokenURI'],
37+
"private_key": credentials["privateKey"],
38+
"client_id": credentials["clientID"],
39+
"key_id": credentials["keyID"],
40+
"token_uri": credentials["tokenURI"],
4641
}
4742

48-
SNAKE_CASE_CREDS_STRING = json.dumps({
49-
'private_key': credentials['privateKey'],
50-
'client_id': credentials['clientID'],
51-
'key_id': credentials['keyID'],
52-
'token_uri': credentials['tokenURI'],
53-
})
43+
SNAKE_CASE_CREDS_STRING = json.dumps(
44+
{
45+
"private_key": credentials["privateKey"],
46+
"client_id": credentials["clientID"],
47+
"key_id": credentials["keyID"],
48+
"token_uri": credentials["tokenURI"],
49+
}
50+
)
5451

5552

5653
class TestServiceAccountUtils(unittest.TestCase):
57-
5854
# ── is_expired ────────────────────────────────────────────────────────────
5955

6056
def test_is_expired_none_token(self):
@@ -144,33 +140,33 @@ def test_get_service_account_token_with_snake_case_creds(self):
144140

145141
def test_get_service_account_token_missing_private_key_snake(self):
146142
creds = {
147-
'client_id': 'id',
148-
'key_id': 'kid',
149-
'token_uri': 'https://example.com',
143+
"client_id": "id",
144+
"key_id": "kid",
145+
"token_uri": "https://example.com",
150146
}
151147
with self.assertRaises(SkyflowError) as context:
152148
get_service_account_token(creds, {}, None)
153149
self.assertEqual(context.exception.message, SkyflowMessages.Error.MISSING_PRIVATE_KEY.value)
154150

155151
def test_get_service_account_token_invalid_token_uri(self):
156152
creds = {
157-
'privateKey': 'key',
158-
'clientID': 'id',
159-
'keyID': 'kid',
160-
'tokenURI': 'not-a-url',
153+
"privateKey": "key",
154+
"clientID": "id",
155+
"keyID": "kid",
156+
"tokenURI": "not-a-url",
161157
}
162158
with self.assertRaises(SkyflowError) as context:
163159
get_service_account_token(creds, {}, None)
164160
self.assertEqual(context.exception.message, SkyflowMessages.Error.INVALID_TOKEN_URI.value)
165161

166162
def test_get_service_account_token_invalid_token_uri_in_options(self):
167163
creds = {
168-
'privateKey': 'key',
169-
'clientID': 'id',
170-
'keyID': 'kid',
171-
'tokenURI': 'https://valid-url.com',
164+
"privateKey": "key",
165+
"clientID": "id",
166+
"keyID": "kid",
167+
"tokenURI": "https://valid-url.com",
172168
}
173-
options = {'token_uri': 'not-a-valid-url'}
169+
options = {"token_uri": "not-a-valid-url"}
174170
with self.assertRaises(SkyflowError) as context:
175171
get_service_account_token(creds, options, None)
176172
self.assertEqual(context.exception.message, SkyflowMessages.Error.INVALID_TOKEN_URI.value)
@@ -182,14 +178,14 @@ def test_get_service_account_token_with_role_ids_formats_scope(self, mock_get_si
182178
"privateKey": "private_key",
183179
"clientID": "client_id",
184180
"keyID": "key_id",
185-
"tokenURI": "https://valid-url.com"
181+
"tokenURI": "https://valid-url.com",
186182
}
187183
options = {"role_ids": ["role1", "role2"]}
188184
mock_get_signed_jwt.return_value = "signed"
189185
mock_auth_api = mock_auth_client.return_value.get_auth_api.return_value
190-
mock_auth_api.authentication_service_get_auth_token.return_value = type("obj", (), {
191-
"access_token": "token", "token_type": "bearer"
192-
})
186+
mock_auth_api.authentication_service_get_auth_token.return_value = type(
187+
"obj", (), {"access_token": "token", "token_type": "bearer"}
188+
)
193189
access_token, token_type = get_service_account_token(creds, options, None)
194190
self.assertEqual(access_token, "token")
195191
self.assertEqual(token_type, "bearer")
@@ -204,16 +200,18 @@ def test_get_service_account_token_unauthorized_error(self, mock_get_signed_jwt,
204200
"privateKey": "private_key",
205201
"clientID": "client_id",
206202
"keyID": "key_id",
207-
"tokenURI": "https://valid-url.com"
203+
"tokenURI": "https://valid-url.com",
208204
}
209205
mock_get_signed_jwt.return_value = "signed"
210206
mock_auth_api = mock_auth_client.return_value.get_auth_api.return_value
211207
from skyflow.generated.rest.errors.unauthorized_error import UnauthorizedError
208+
212209
mock_auth_api.authentication_service_get_auth_token.side_effect = UnauthorizedError("unauthorized")
213210
with self.assertRaises(SkyflowError) as context:
214211
get_service_account_token(creds, {}, None)
215-
self.assertEqual(context.exception.message,
216-
SkyflowMessages.Error.UNAUTHORIZED_ERROR_IN_GETTING_BEARER_TOKEN.value)
212+
self.assertEqual(
213+
context.exception.message, SkyflowMessages.Error.UNAUTHORIZED_ERROR_IN_GETTING_BEARER_TOKEN.value
214+
)
217215

218216
@patch("skyflow.service_account._utils.AuthClient")
219217
@patch("skyflow.service_account._utils.get_signed_jwt")
@@ -222,7 +220,7 @@ def test_get_service_account_token_generic_exception(self, mock_get_signed_jwt,
222220
"privateKey": "private_key",
223221
"clientID": "client_id",
224222
"keyID": "key_id",
225-
"tokenURI": "https://valid-url.com"
223+
"tokenURI": "https://valid-url.com",
226224
}
227225
mock_get_signed_jwt.return_value = "signed"
228226
mock_auth_api = mock_auth_client.return_value.get_auth_api.return_value
@@ -266,9 +264,9 @@ def test_get_signed_data_token_response_object(self):
266264
token = "sample_token"
267265
signed_token = "signed_sample_token"
268266
response = get_signed_data_token_response_object(signed_token, token)
269-
self.assertIsInstance(response, dict)
270-
self.assertEqual(response["token"], token)
271-
self.assertEqual(response["signed_token"], signed_token)
267+
self.assertIsInstance(response, tuple)
268+
self.assertEqual(response[0], token)
269+
self.assertEqual(response[1], signed_token)
272270

273271
# ── get_signed_tokens ─────────────────────────────────────────────────────
274272

@@ -278,7 +276,7 @@ def test_get_signed_tokens_jwt_encode_exception(self, mock_jwt_encode):
278276
"privateKey": "private_key",
279277
"clientID": "client_id",
280278
"keyID": "key_id",
281-
"tokenURI": "https://valid-url.com"
279+
"tokenURI": "https://valid-url.com",
282280
}
283281
options = {"data_tokens": ["token1"]}
284282
with self.assertRaises(SkyflowError) as context:
@@ -290,16 +288,14 @@ def test_get_signed_tokens_returns_list_one_per_token(self):
290288
self.assertIsInstance(result, list)
291289
self.assertEqual(len(result), 2)
292290

293-
def test_get_signed_tokens_items_are_dicts_with_token_and_signed_token(self):
291+
def test_get_signed_tokens_items_are_tuples_with_token_and_signed_token(self):
294292
result = generate_signed_data_tokens(creds_path, {"data_tokens": ["token1", "token2"]})
295293
for item in result:
296-
self.assertIsInstance(item, dict)
297-
self.assertIn("token", item)
298-
self.assertIn("signed_token", item)
299-
self.assertEqual(result[0]["token"], "token1")
300-
self.assertEqual(result[1]["token"], "token2")
301-
self.assertTrue(result[0]["signed_token"].startswith("signed_token_"))
302-
self.assertTrue(result[1]["signed_token"].startswith("signed_token_"))
294+
self.assertIsInstance(item, tuple)
295+
self.assertEqual(result[0][0], "token1")
296+
self.assertEqual(result[1][0], "token2")
297+
self.assertTrue(result[0][1].startswith("signed_token_"))
298+
self.assertTrue(result[1][1].startswith("signed_token_"))
303299

304300
def test_get_signed_tokens_returns_list_single_token(self):
305301
result = generate_signed_data_tokens(creds_path, {"data_tokens": ["token1"]})
@@ -396,14 +392,14 @@ def test_get_signed_tokens_with_snake_case_creds(self):
396392
# ── generate_signed_data_tokens (file path) ───────────────────────────────
397393

398394
def test_generate_signed_data_tokens_from_file_path(self):
399-
options = {"data_tokens": ["token1", "token2"], "ctx": 'ctx'}
395+
options = {"data_tokens": ["token1", "token2"], "ctx": "ctx"}
400396
result = generate_signed_data_tokens(creds_path, options)
401397
self.assertEqual(len(result), 2)
402398

403399
def test_generate_signed_data_tokens_from_invalid_file_path(self):
404400
options = {"data_tokens": ["token1", "token2"]}
405401
with self.assertRaises(SkyflowError) as context:
406-
generate_signed_data_tokens('credentials1.json', options)
402+
generate_signed_data_tokens("credentials1.json", options)
407403
self.assertEqual(context.exception.message, SkyflowMessages.Error.INVALID_CREDENTIAL_FILE_PATH.value)
408404

409405
def test_generate_signed_data_tokens_with_dict_ctx(self):
@@ -421,7 +417,7 @@ def test_generate_signed_data_tokens_from_creds(self):
421417
def test_generate_signed_data_tokens_from_creds_with_invalid_string(self):
422418
options = {"data_tokens": ["token1", "token2"]}
423419
with self.assertRaises(SkyflowError) as context:
424-
generate_signed_data_tokens_from_creds('{', options)
420+
generate_signed_data_tokens_from_creds("{", options)
425421
self.assertEqual(context.exception.message, SkyflowMessages.Error.INVALID_CREDENTIALS_STRING.value)
426422

427423
def test_generate_signed_data_tokens_from_creds_with_dict_ctx(self):
@@ -446,54 +442,54 @@ def test_generate_signed_data_tokens_from_creds_snake(self):
446442

447443
def test_normalize_credentials_snake_case(self):
448444
snake = {
449-
'private_key': 'pk',
450-
'client_id': 'cid',
451-
'key_id': 'kid',
452-
'token_uri': 'https://uri',
453-
'client_name': 'name',
445+
"private_key": "pk",
446+
"client_id": "cid",
447+
"key_id": "kid",
448+
"token_uri": "https://uri",
449+
"client_name": "name",
454450
}
455451
result = _normalize_credentials(snake)
456-
self.assertEqual(result['privateKey'], 'pk')
457-
self.assertEqual(result['clientID'], 'cid')
458-
self.assertEqual(result['keyID'], 'kid')
459-
self.assertEqual(result['tokenURI'], 'https://uri')
460-
self.assertEqual(result['clientName'], 'name')
461-
self.assertNotIn('private_key', result)
462-
self.assertNotIn('client_id', result)
463-
self.assertNotIn('key_id', result)
464-
self.assertNotIn('token_uri', result)
465-
self.assertNotIn('client_name', result)
452+
self.assertEqual(result["privateKey"], "pk")
453+
self.assertEqual(result["clientID"], "cid")
454+
self.assertEqual(result["keyID"], "kid")
455+
self.assertEqual(result["tokenURI"], "https://uri")
456+
self.assertEqual(result["clientName"], "name")
457+
self.assertNotIn("private_key", result)
458+
self.assertNotIn("client_id", result)
459+
self.assertNotIn("key_id", result)
460+
self.assertNotIn("token_uri", result)
461+
self.assertNotIn("client_name", result)
466462

467463
def test_normalize_credentials_camel_case_unchanged(self):
468464
camel = {
469-
'privateKey': 'pk',
470-
'clientID': 'cid',
471-
'keyID': 'kid',
472-
'tokenURI': 'https://uri',
465+
"privateKey": "pk",
466+
"clientID": "cid",
467+
"keyID": "kid",
468+
"tokenURI": "https://uri",
473469
}
474470
result = _normalize_credentials(camel)
475471
self.assertEqual(result, camel)
476472

477473
def test_normalize_credentials_mixed_keys(self):
478474
mixed = {
479-
'private_key': 'pk',
480-
'clientID': 'cid',
481-
'key_id': 'kid',
482-
'tokenURI': 'https://uri',
475+
"private_key": "pk",
476+
"clientID": "cid",
477+
"key_id": "kid",
478+
"tokenURI": "https://uri",
483479
}
484480
result = _normalize_credentials(mixed)
485-
self.assertEqual(result['privateKey'], 'pk')
486-
self.assertEqual(result['clientID'], 'cid')
487-
self.assertEqual(result['keyID'], 'kid')
488-
self.assertEqual(result['tokenURI'], 'https://uri')
489-
self.assertNotIn('private_key', result)
490-
self.assertNotIn('key_id', result)
481+
self.assertEqual(result["privateKey"], "pk")
482+
self.assertEqual(result["clientID"], "cid")
483+
self.assertEqual(result["keyID"], "kid")
484+
self.assertEqual(result["tokenURI"], "https://uri")
485+
self.assertNotIn("private_key", result)
486+
self.assertNotIn("key_id", result)
491487

492488
def test_normalize_credentials_unknown_key_passes_through(self):
493-
creds = {'unknown_field': 'value', 'anotherField': 'val2'}
489+
creds = {"unknown_field": "value", "anotherField": "val2"}
494490
result = _normalize_credentials(creds)
495-
self.assertEqual(result['unknown_field'], 'value')
496-
self.assertEqual(result['anotherField'], 'val2')
491+
self.assertEqual(result["unknown_field"], "value")
492+
self.assertEqual(result["anotherField"], "val2")
497493

498494
def test_normalize_credentials_empty_dict(self):
499495
self.assertEqual(_normalize_credentials({}), {})
@@ -504,11 +500,11 @@ def test_validate_and_resolve_ctx_none(self):
504500
self.assertIsNone(_validate_and_resolve_ctx(None))
505501

506502
def test_validate_and_resolve_ctx_empty_string(self):
507-
self.assertIsNone(_validate_and_resolve_ctx(''))
508-
self.assertIsNone(_validate_and_resolve_ctx(' '))
503+
self.assertIsNone(_validate_and_resolve_ctx(""))
504+
self.assertIsNone(_validate_and_resolve_ctx(" "))
509505

510506
def test_validate_and_resolve_ctx_valid_string(self):
511-
self.assertEqual(_validate_and_resolve_ctx('user_12345'), 'user_12345')
507+
self.assertEqual(_validate_and_resolve_ctx("user_12345"), "user_12345")
512508

513509
def test_validate_and_resolve_ctx_empty_dict(self):
514510
self.assertIsNone(_validate_and_resolve_ctx({}))
@@ -577,9 +573,9 @@ def test_get_service_account_token_with_token_uri_option_override(self, mock_get
577573
options = {"token_uri": override_uri}
578574
mock_get_signed_jwt.return_value = "signed"
579575
mock_auth_api = mock_auth_client.return_value.get_auth_api.return_value
580-
mock_auth_api.authentication_service_get_auth_token.return_value = type("obj", (), {
581-
"access_token": "token", "token_type": "bearer"
582-
})
576+
mock_auth_api.authentication_service_get_auth_token.return_value = type(
577+
"obj", (), {"access_token": "token", "token_type": "bearer"}
578+
)
583579
get_service_account_token(creds, options, None)
584580
mock_get_signed_jwt.assert_called_once()
585581
call_args = mock_get_signed_jwt.call_args

0 commit comments

Comments
 (0)