Skip to content

Commit 38527d1

Browse files
SK-2813: add unit tests
1 parent 9fbbaa4 commit 38527d1

3 files changed

Lines changed: 196 additions & 3 deletions

File tree

tests/service_account/test__utils.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,9 @@ class TestServiceAccountUtils(unittest.TestCase):
5757

5858
# ── is_expired ────────────────────────────────────────────────────────────
5959

60+
def test_is_expired_none_token(self):
61+
self.assertTrue(is_expired(None))
62+
6063
def test_is_expired_empty_token(self):
6164
self.assertTrue(is_expired(""))
6265

@@ -160,6 +163,18 @@ def test_get_service_account_token_invalid_token_uri(self):
160163
get_service_account_token(creds, {}, None)
161164
self.assertEqual(context.exception.message, SkyflowMessages.Error.INVALID_TOKEN_URI.value)
162165

166+
def test_get_service_account_token_invalid_token_uri_in_options(self):
167+
creds = {
168+
'privateKey': 'key',
169+
'clientID': 'id',
170+
'keyID': 'kid',
171+
'tokenURI': 'https://valid-url.com',
172+
}
173+
options = {'token_uri': 'not-a-valid-url'}
174+
with self.assertRaises(SkyflowError) as context:
175+
get_service_account_token(creds, options, None)
176+
self.assertEqual(context.exception.message, SkyflowMessages.Error.INVALID_TOKEN_URI.value)
177+
163178
@patch("skyflow.service_account._utils.AuthClient")
164179
@patch("skyflow.service_account._utils.get_signed_jwt")
165180
def test_get_service_account_token_with_role_ids_formats_scope(self, mock_get_signed_jwt, mock_auth_client):

tests/utils/test__utils.py

Lines changed: 85 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
import unittest
2-
from unittest.mock import patch, Mock
2+
from unittest.mock import patch, Mock, MagicMock, PropertyMock
33
import os
4-
from unittest.mock import MagicMock
54
from urllib.parse import quote
65
import tempfile, json
76
from requests import PreparedRequest
@@ -15,7 +14,7 @@
1514
parse_detokenize_response, parse_tokenize_response, parse_query_response, parse_invoke_connection_response, \
1615
handle_exception, validate_api_key, encode_column_values, parse_deidentify_text_response, \
1716
parse_reidentify_text_response, convert_detected_entity_to_entity_info
18-
from skyflow.utils._utils import parse_path_params, to_lowercase_keys, get_metrics, handle_json_error
17+
from skyflow.utils._utils import parse_path_params, to_lowercase_keys, get_metrics, handle_json_error, r_urlencode
1918
from skyflow.utils.enums import EnvUrls, Env, ContentType
2019
from skyflow.vault.connection import InvokeConnectionResponse
2120
from skyflow.vault.data import InsertResponse, DeleteResponse, GetResponse, QueryResponse
@@ -36,6 +35,13 @@ def test_get_credentials_env_variable(self):
3635
credentials_string = credentials.get('credentials_string')
3736
self.assertEqual(credentials_string, json.dumps(VALID_ENV_CREDENTIALS).replace('\n', '\\n'))
3837

38+
@patch("skyflow.utils._utils.dotenv.find_dotenv", return_value=None)
39+
@patch.dict(os.environ, {}, clear=True)
40+
def test_get_credentials_no_credentials_raises(self, mock_find_dotenv):
41+
with self.assertRaises(SkyflowError) as context:
42+
get_credentials(config_level_creds=None, common_skyflow_creds=None)
43+
self.assertEqual(context.exception.message, SkyflowMessages.Error.INVALID_CREDENTIALS.value)
44+
3945
def test_get_credentials_with_config_level_creds(self):
4046
test_creds = {"authToken": "test_token"}
4147
creds = get_credentials(config_level_creds=test_creds)
@@ -140,13 +146,54 @@ def test_to_lowercase_keys(self):
140146
expected_output = {"key1": "value1", "key2": "value2"}
141147
self.assertEqual(to_lowercase_keys(input_dict), expected_output)
142148

149+
def test_r_urlencode_with_list_input(self):
150+
pairs = {}
151+
r_urlencode([], pairs, ["a", "b"])
152+
self.assertIn("[0]", pairs)
153+
self.assertIn("[1]", pairs)
154+
self.assertEqual(pairs["[0]"], "a")
155+
self.assertEqual(pairs["[1]"], "b")
156+
157+
def test_r_urlencode_with_tuple_input(self):
158+
pairs = {}
159+
r_urlencode([], pairs, ("x", "y"))
160+
self.assertIn("[0]", pairs)
161+
self.assertEqual(pairs["[0]"], "x")
162+
143163
def test_get_metrics(self):
144164
metrics = get_metrics()
145165
self.assertIn('sdk_name_version', metrics)
146166
self.assertIn('sdk_client_device_model', metrics)
147167
self.assertIn('sdk_client_os_details', metrics)
148168
self.assertIn('sdk_runtime_details', metrics)
149169

170+
def test_get_metrics_platform_node_exception(self):
171+
import skyflow.utils._utils as utils_module
172+
utils_module._CACHED_METRICS.clear()
173+
with patch("skyflow.utils._utils.platform") as mock_platform:
174+
mock_platform.node.side_effect = OSError("no node")
175+
metrics = utils_module.get_metrics()
176+
self.assertEqual(metrics["sdk_client_device_model"], "")
177+
utils_module._CACHED_METRICS.clear()
178+
179+
def test_get_metrics_sys_attribute_exception(self):
180+
import skyflow.utils._utils as utils_module
181+
utils_module._CACHED_METRICS.clear()
182+
183+
class _RaisingSys:
184+
@property
185+
def platform(self):
186+
raise RuntimeError("no platform")
187+
@property
188+
def version(self):
189+
raise RuntimeError("no version")
190+
191+
with patch("skyflow.utils._utils.sys", _RaisingSys()):
192+
metrics = utils_module.get_metrics()
193+
self.assertEqual(metrics["sdk_client_os_details"], "")
194+
self.assertIn("sdk_runtime_details", metrics)
195+
utils_module._CACHED_METRICS.clear()
196+
150197

151198
def test_construct_invoke_connection_request_valid(self):
152199
mock_connection_request = Mock()
@@ -244,6 +291,16 @@ def test_construct_invoke_connection_request_with_form_date_content_type(self):
244291

245292
self.assertIsInstance(result, PreparedRequest)
246293

294+
def test_parse_insert_response_with_tokens_continue_on_error(self):
295+
api_response = Mock()
296+
api_response.headers = {"x-request-id": "req-1"}
297+
api_response.data = Mock(responses=[
298+
{"Status": 200, "Body": {"records": [{"skyflow_id": "id1", "tokens": {"col1": "tok1"}}]}},
299+
])
300+
result = parse_insert_response(api_response, continue_on_error=True)
301+
self.assertEqual(result.inserted_fields[0]["col1"], "tok1")
302+
self.assertEqual(result.inserted_fields[0]["skyflow_id"], "id1")
303+
247304
def test_parse_insert_response(self):
248305
api_response = Mock()
249306
api_response.headers = {"x-request-id": "12345", "content-type": "application/json"}
@@ -423,6 +480,31 @@ def test_parse_invoke_connection_response_http_error_with_json_error_message(sel
423480
self.assertEqual(context.exception.message, "Not Found")
424481
self.assertEqual(context.exception.request_id, "1234")
425482

483+
@patch("requests.Response")
484+
def test_parse_invoke_connection_response_with_error_from_client_header(self, mock_response):
485+
from requests.models import HTTPError
486+
mock_response.status_code = 400
487+
mock_response.content = json.dumps({
488+
"error": {
489+
"message": "Client error",
490+
"http_code": 400,
491+
"http_status": "Bad Request",
492+
"grpc_code": 3,
493+
"details": None,
494+
}
495+
}).encode("utf-8")
496+
mock_response.headers = {
497+
"x-request-id": "rid-1",
498+
"error-from-client": "true",
499+
}
500+
mock_response.raise_for_status.side_effect = HTTPError("400")
501+
with self.assertRaises(SkyflowError) as context:
502+
parse_invoke_connection_response(mock_response)
503+
err = context.exception
504+
self.assertEqual(err.message, "Client error")
505+
self.assertIsNotNone(err.details)
506+
self.assertTrue(any(d.get("error_from_client") is True for d in err.details))
507+
426508
@patch("requests.Response")
427509
def test_parse_invoke_connection_response_http_error_without_json_error_message(self, mock_response):
428510
mock_response.status_code = 500

tests/vault/controller/test__detect.py

Lines changed: 96 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
from unittest.mock import Mock, patch, MagicMock
33
import base64
44
import os
5+
import tempfile
56
from skyflow.error import SkyflowError
67
from skyflow.generated.rest import WordCharacterCount
78
from skyflow.utils import SkyflowMessages
@@ -707,3 +708,98 @@ def test_deidentify_file_using_file_path(self, mock_open, mock_basename, mock_ba
707708
self.assertIsNone(result.page_count)
708709
self.assertIsNone(result.slide_count)
709710
self.assertEqual(result.entities, [])
711+
712+
def test_poll_for_processed_file_exception(self):
713+
files_api = Mock()
714+
files_api.with_raw_response = files_api
715+
files_api.get_run.side_effect = Exception("poll error")
716+
self.vault_client.get_detect_file_api.return_value = files_api
717+
with self.assertRaises(Exception):
718+
self.detect._Detect__poll_for_processed_file("runid", max_wait_time=5)
719+
720+
def test_save_output_directory_not_exists(self):
721+
output = Mock()
722+
output.processedFile = base64.b64encode(b"data").decode()
723+
output.processedFileType = "redacted_file"
724+
output.processedFileExtension = "txt"
725+
response = Mock()
726+
response.output = [output]
727+
with patch("skyflow.vault.controller._detect.os.path.exists", return_value=False):
728+
self.detect._Detect__save_deidentify_file_response_output(
729+
response, "/nonexistent_dir", "file.txt", "file"
730+
)
731+
732+
def test_save_output_second_non_redacted_item(self):
733+
with tempfile.TemporaryDirectory() as tmp_dir:
734+
output1 = Mock()
735+
output1.processedFile = base64.b64encode(b"data1").decode()
736+
output1.processedFileType = "redacted_file"
737+
output1.processedFileExtension = "txt"
738+
output2 = Mock()
739+
output2.processedFile = base64.b64encode(b"data2").decode()
740+
output2.processedFileType = "entities"
741+
output2.processedFileExtension = "json"
742+
response = Mock()
743+
response.output = [output1, output2]
744+
self.detect._Detect__save_deidentify_file_response_output(
745+
response, tmp_dir, "original.txt", "original"
746+
)
747+
748+
def test_save_output_path_traversal_blocked(self):
749+
output = Mock()
750+
output.processedFile = base64.b64encode(b"data").decode()
751+
output.processedFileType = "redacted_file"
752+
output.processedFileExtension = "txt"
753+
response = Mock()
754+
response.output = [output]
755+
call_count = [0]
756+
757+
def fake_realpath(p):
758+
call_count[0] += 1
759+
if call_count[0] == 1:
760+
return "/safe_dir"
761+
return "/outside/path"
762+
763+
with patch("skyflow.vault.controller._detect.os.path.exists", return_value=True), \
764+
patch("skyflow.vault.controller._detect.os.path.realpath", side_effect=fake_realpath):
765+
self.detect._Detect__save_deidentify_file_response_output(
766+
response, "/safe_dir", "file.txt", "file"
767+
)
768+
769+
def test_save_output_write_exception(self):
770+
with tempfile.TemporaryDirectory() as tmp_dir:
771+
output = Mock()
772+
output.processedFile = base64.b64encode(b"data").decode()
773+
output.processedFileType = "redacted_file"
774+
output.processedFileExtension = "txt"
775+
response = Mock()
776+
response.output = [output]
777+
with patch("skyflow.vault.controller._detect.base64.b64decode",
778+
side_effect=Exception("decode error")), \
779+
self.assertRaises(Exception):
780+
self.detect._Detect__save_deidentify_file_response_output(
781+
response, tmp_dir, "file.txt", "file"
782+
)
783+
784+
@patch("skyflow.vault.controller._detect.validate_deidentify_file_request")
785+
@patch("skyflow.vault.controller._detect.base64")
786+
def test_deidentify_file_api_error_inside_try(self, mock_base64, mock_validate):
787+
file_content = b"test content"
788+
file_obj = Mock()
789+
file_obj.read.return_value = file_content
790+
file_obj.name = "test.txt"
791+
mock_base64.b64encode.return_value.decode.return_value = "encoded"
792+
req = DeidentifyFileRequest(file=FileInput(file=file_obj))
793+
req.entities = []
794+
req.token_format = None
795+
req.allow_regex_list = []
796+
req.restrict_regex_list = []
797+
req.transformations = None
798+
req.output_directory = None
799+
req.wait_time = None
800+
files_api = Mock()
801+
files_api.with_raw_response = files_api
802+
files_api.deidentify_text.side_effect = Exception("API error inside try")
803+
self.vault_client.get_detect_file_api.return_value = files_api
804+
with self.assertRaises(Exception):
805+
self.detect.deidentify_file(req)

0 commit comments

Comments
 (0)