11import unittest
2- from unittest .mock import patch , Mock
2+ from unittest .mock import patch , Mock , MagicMock , PropertyMock
33import os
4- from unittest .mock import MagicMock
54from urllib .parse import quote
65import tempfile , json
76from requests import PreparedRequest
1514 parse_detokenize_response , parse_tokenize_response , parse_query_response , parse_invoke_connection_response , \
1615 handle_exception , validate_api_key , encode_column_values , parse_deidentify_text_response , \
1716 parse_reidentify_text_response , convert_detected_entity_to_entity_info
18- from skyflow .utils ._utils import parse_path_params , to_lowercase_keys , get_metrics , handle_json_error
17+ from skyflow .utils ._utils import parse_path_params , to_lowercase_keys , get_metrics , handle_json_error , r_urlencode
1918from skyflow .utils .enums import EnvUrls , Env , ContentType
2019from skyflow .vault .connection import InvokeConnectionResponse
2120from skyflow .vault .data import InsertResponse , DeleteResponse , GetResponse , QueryResponse
@@ -36,6 +35,13 @@ def test_get_credentials_env_variable(self):
3635 credentials_string = credentials .get ('credentials_string' )
3736 self .assertEqual (credentials_string , json .dumps (VALID_ENV_CREDENTIALS ).replace ('\n ' , '\\ n' ))
3837
38+ @patch ("skyflow.utils._utils.dotenv.find_dotenv" , return_value = None )
39+ @patch .dict (os .environ , {}, clear = True )
40+ def test_get_credentials_no_credentials_raises (self , mock_find_dotenv ):
41+ with self .assertRaises (SkyflowError ) as context :
42+ get_credentials (config_level_creds = None , common_skyflow_creds = None )
43+ self .assertEqual (context .exception .message , SkyflowMessages .Error .INVALID_CREDENTIALS .value )
44+
3945 def test_get_credentials_with_config_level_creds (self ):
4046 test_creds = {"authToken" : "test_token" }
4147 creds = get_credentials (config_level_creds = test_creds )
@@ -140,13 +146,54 @@ def test_to_lowercase_keys(self):
140146 expected_output = {"key1" : "value1" , "key2" : "value2" }
141147 self .assertEqual (to_lowercase_keys (input_dict ), expected_output )
142148
149+ def test_r_urlencode_with_list_input (self ):
150+ pairs = {}
151+ r_urlencode ([], pairs , ["a" , "b" ])
152+ self .assertIn ("[0]" , pairs )
153+ self .assertIn ("[1]" , pairs )
154+ self .assertEqual (pairs ["[0]" ], "a" )
155+ self .assertEqual (pairs ["[1]" ], "b" )
156+
157+ def test_r_urlencode_with_tuple_input (self ):
158+ pairs = {}
159+ r_urlencode ([], pairs , ("x" , "y" ))
160+ self .assertIn ("[0]" , pairs )
161+ self .assertEqual (pairs ["[0]" ], "x" )
162+
143163 def test_get_metrics (self ):
144164 metrics = get_metrics ()
145165 self .assertIn ('sdk_name_version' , metrics )
146166 self .assertIn ('sdk_client_device_model' , metrics )
147167 self .assertIn ('sdk_client_os_details' , metrics )
148168 self .assertIn ('sdk_runtime_details' , metrics )
149169
170+ def test_get_metrics_platform_node_exception (self ):
171+ import skyflow .utils ._utils as utils_module
172+ utils_module ._CACHED_METRICS .clear ()
173+ with patch ("skyflow.utils._utils.platform" ) as mock_platform :
174+ mock_platform .node .side_effect = OSError ("no node" )
175+ metrics = utils_module .get_metrics ()
176+ self .assertEqual (metrics ["sdk_client_device_model" ], "" )
177+ utils_module ._CACHED_METRICS .clear ()
178+
179+ def test_get_metrics_sys_attribute_exception (self ):
180+ import skyflow .utils ._utils as utils_module
181+ utils_module ._CACHED_METRICS .clear ()
182+
183+ class _RaisingSys :
184+ @property
185+ def platform (self ):
186+ raise RuntimeError ("no platform" )
187+ @property
188+ def version (self ):
189+ raise RuntimeError ("no version" )
190+
191+ with patch ("skyflow.utils._utils.sys" , _RaisingSys ()):
192+ metrics = utils_module .get_metrics ()
193+ self .assertEqual (metrics ["sdk_client_os_details" ], "" )
194+ self .assertIn ("sdk_runtime_details" , metrics )
195+ utils_module ._CACHED_METRICS .clear ()
196+
150197
151198 def test_construct_invoke_connection_request_valid (self ):
152199 mock_connection_request = Mock ()
@@ -244,6 +291,16 @@ def test_construct_invoke_connection_request_with_form_date_content_type(self):
244291
245292 self .assertIsInstance (result , PreparedRequest )
246293
294+ def test_parse_insert_response_with_tokens_continue_on_error (self ):
295+ api_response = Mock ()
296+ api_response .headers = {"x-request-id" : "req-1" }
297+ api_response .data = Mock (responses = [
298+ {"Status" : 200 , "Body" : {"records" : [{"skyflow_id" : "id1" , "tokens" : {"col1" : "tok1" }}]}},
299+ ])
300+ result = parse_insert_response (api_response , continue_on_error = True )
301+ self .assertEqual (result .inserted_fields [0 ]["col1" ], "tok1" )
302+ self .assertEqual (result .inserted_fields [0 ]["skyflow_id" ], "id1" )
303+
247304 def test_parse_insert_response (self ):
248305 api_response = Mock ()
249306 api_response .headers = {"x-request-id" : "12345" , "content-type" : "application/json" }
@@ -423,6 +480,31 @@ def test_parse_invoke_connection_response_http_error_with_json_error_message(sel
423480 self .assertEqual (context .exception .message , "Not Found" )
424481 self .assertEqual (context .exception .request_id , "1234" )
425482
483+ @patch ("requests.Response" )
484+ def test_parse_invoke_connection_response_with_error_from_client_header (self , mock_response ):
485+ from requests .models import HTTPError
486+ mock_response .status_code = 400
487+ mock_response .content = json .dumps ({
488+ "error" : {
489+ "message" : "Client error" ,
490+ "http_code" : 400 ,
491+ "http_status" : "Bad Request" ,
492+ "grpc_code" : 3 ,
493+ "details" : None ,
494+ }
495+ }).encode ("utf-8" )
496+ mock_response .headers = {
497+ "x-request-id" : "rid-1" ,
498+ "error-from-client" : "true" ,
499+ }
500+ mock_response .raise_for_status .side_effect = HTTPError ("400" )
501+ with self .assertRaises (SkyflowError ) as context :
502+ parse_invoke_connection_response (mock_response )
503+ err = context .exception
504+ self .assertEqual (err .message , "Client error" )
505+ self .assertIsNotNone (err .details )
506+ self .assertTrue (any (d .get ("error_from_client" ) is True for d in err .details ))
507+
426508 @patch ("requests.Response" )
427509 def test_parse_invoke_connection_response_http_error_without_json_error_message (self , mock_response ):
428510 mock_response .status_code = 500
0 commit comments