Skip to content

Commit 83ebbe1

Browse files
SK-2681: update validation for ctx for bearer token generation
1 parent 6e5b60f commit 83ebbe1

2 files changed

Lines changed: 34 additions & 3 deletions

File tree

skyflow/service_account/_utils.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,8 @@ def _validate_and_resolve_ctx(ctx):
3535
invalid_input_error_code
3636
)
3737
return ctx
38+
if isinstance(ctx, (bool, int, float)):
39+
return ctx
3840
raise SkyflowError(
3941
SkyflowMessages.Error.INVALID_CTX_TYPE.value,
4042
invalid_input_error_code

tests/service_account/test__utils.py

Lines changed: 32 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -114,6 +114,27 @@ def test_get_signed_jwt_invalid_format(self, mock_jwt_encode):
114114
get_signed_jwt({}, "client_id", "key_id", "token_uri", "private_key", None)
115115
self.assertEqual(context.exception.message, SkyflowMessages.Error.JWT_INVALID_FORMAT.value)
116116

117+
@patch("skyflow.service_account._utils.jwt.encode")
118+
def test_get_signed_jwt_with_valid_string_ctx(self, mock_jwt_encode):
119+
mock_jwt_encode.return_value = "mock_token"
120+
get_signed_jwt({"ctx": "valid_ctx"}, "client_id", "key_id", "token_uri", "private_key", None)
121+
payload = mock_jwt_encode.call_args.kwargs["payload"]
122+
self.assertEqual(payload["ctx"], "valid_ctx")
123+
124+
@patch("skyflow.service_account._utils.jwt.encode")
125+
def test_get_signed_jwt_with_valid_dict_ctx(self, mock_jwt_encode):
126+
mock_jwt_encode.return_value = "mock_token"
127+
get_signed_jwt({"ctx": {"role": "admin"}}, "client_id", "key_id", "token_uri", "private_key", None)
128+
payload = mock_jwt_encode.call_args.kwargs["payload"]
129+
self.assertEqual(payload["ctx"], {"role": "admin"})
130+
131+
@patch("skyflow.service_account._utils.jwt.encode")
132+
def test_get_signed_jwt_with_empty_string_ctx_not_added(self, mock_jwt_encode):
133+
mock_jwt_encode.return_value = "mock_token"
134+
get_signed_jwt({"ctx": ""}, "client_id", "key_id", "token_uri", "private_key", None)
135+
payload = mock_jwt_encode.call_args.kwargs["payload"]
136+
self.assertNotIn("ctx", payload)
137+
117138
def test_get_signed_data_token_response_object(self):
118139
token = "sample_token"
119140
signed_token = "signed_sample_token"
@@ -183,9 +204,17 @@ def test_validate_and_resolve_ctx_dict_with_invalid_key_dot(self):
183204
with self.assertRaises(SkyflowError):
184205
_validate_and_resolve_ctx(ctx)
185206

186-
def test_validate_and_resolve_ctx_invalid_type_int(self):
187-
with self.assertRaises(SkyflowError):
188-
_validate_and_resolve_ctx(42)
207+
def test_validate_and_resolve_ctx_valid_type_int(self):
208+
self.assertEqual(_validate_and_resolve_ctx(42), 42)
209+
210+
def test_validate_and_resolve_ctx_valid_type_float(self):
211+
self.assertEqual(_validate_and_resolve_ctx(3.14), 3.14)
212+
213+
def test_validate_and_resolve_ctx_valid_type_bool_true(self):
214+
self.assertEqual(_validate_and_resolve_ctx(True), True)
215+
216+
def test_validate_and_resolve_ctx_valid_type_bool_false(self):
217+
self.assertEqual(_validate_and_resolve_ctx(False), False)
189218

190219
def test_validate_and_resolve_ctx_invalid_type_list(self):
191220
with self.assertRaises(SkyflowError):

0 commit comments

Comments
 (0)