Skip to content

Commit 3626913

Browse files
SK-2777: Public Release - Update Fern client re-initialisation (#240)
* SK-2777: Update Fern client re-initialisation (#239)
1 parent ad095e4 commit 3626913

11 files changed

Lines changed: 343 additions & 116 deletions

File tree

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77

88
if sys.version_info < (3, 8):
99
raise RuntimeError("skyflow requires Python 3.8+")
10-
current_version = '2.0.0'
10+
current_version = '2.0.0.dev0+e2aa629'
1111

1212
setup(
1313
name='skyflow',

skyflow/error/_skyflow_error.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,5 +15,4 @@ def __init__(self,
1515
self.http_status = http_status if http_status else SkyflowMessages.HttpStatus.BAD_REQUEST.value
1616
self.details = details
1717
self.request_id = request_id
18-
log_error(message, http_code, request_id, grpc_code, http_status, details)
1918
super().__init__()

skyflow/utils/_skyflow_messages.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ class Error(Enum):
4747
EMPTY_CREDENTIALS_TOKEN = f"{error_prefix} Initialization failed. Invalid token.Specify a valid credentials token."
4848
INVALID_CREDENTIALS_TOKEN_IN_CONFIG = f"{error_prefix} Initialization failed. Invalid credentials token for {{}} with id {{}}. Expected token to be a string."
4949
INVALID_CREDENTIALS_TOKEN = f"{error_prefix} Initialization failed. Invalid credentials token. Expected token to be a string."
50-
EXPIRED_TOKEN = f"${error_prefix} Initialization failed. Given token is expired. Specify a valid credentials token."
50+
EXPIRED_TOKEN = f"{error_prefix} Initialization failed. Given token is expired. Specify a valid credentials token."
5151
EMPTY_API_KEY_IN_CONFIG = f"{error_prefix} Initialization failed. Invalid api key for {{}} with id {{}}.Specify a valid api key."
5252
EMPTY_API_KEY= f"{error_prefix} Initialization failed. Invalid api key.Specify a valid api key."
5353
INVALID_API_KEY_IN_CONFIG = f"{error_prefix} Initialization failed. Invalid api key for {{}} with id {{}}. Expected api key to be a string."

skyflow/utils/_utils.py

Lines changed: 15 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -30,26 +30,18 @@
3030
invalid_input_error_code = SkyflowMessages.ErrorCodes.INVALID_INPUT.value
3131

3232
def get_credentials(config_level_creds = None, common_skyflow_creds = None, logger = None):
33-
dotenv.load_dotenv()
34-
dotenv_path = dotenv.find_dotenv(usecwd=True)
35-
if dotenv_path:
36-
load_dotenv(dotenv_path)
37-
env_skyflow_credentials = os.getenv("SKYFLOW_CREDENTIALS")
3833
if config_level_creds:
3934
return config_level_creds
4035
if common_skyflow_creds:
4136
return common_skyflow_creds
37+
dotenv_path = dotenv.find_dotenv(usecwd=True)
38+
if dotenv_path:
39+
load_dotenv(dotenv_path)
40+
env_skyflow_credentials = os.getenv("SKYFLOW_CREDENTIALS")
4241
if env_skyflow_credentials:
43-
env_skyflow_credentials.strip()
44-
try:
45-
env_creds = env_skyflow_credentials.replace('\n', '\\n')
46-
return {
47-
'credentials_string': env_creds
48-
}
49-
except json.JSONDecodeError:
50-
raise SkyflowError(SkyflowMessages.Error.INVALID_JSON_FORMAT_IN_CREDENTIALS_ENV.value, invalid_input_error_code)
51-
else:
52-
raise SkyflowError(SkyflowMessages.Error.INVALID_CREDENTIALS.value, invalid_input_error_code)
42+
env_creds = env_skyflow_credentials.strip().replace('\n', '\\n')
43+
return {'credentials_string': env_creds}
44+
raise SkyflowError(SkyflowMessages.Error.INVALID_CREDENTIALS.value, invalid_input_error_code)
5345

5446
def validate_api_key(api_key: str, logger = None) -> bool:
5547
if len(api_key) != 42:
@@ -185,8 +177,12 @@ def get_data_from_content_type(data, content_type):
185177
return converted_data, files
186178

187179

180+
_CACHED_METRICS: dict = {}
181+
188182
def get_metrics():
189-
sdk_name_version = "skyflow-python@" + SDK_VERSION
183+
global _CACHED_METRICS
184+
if _CACHED_METRICS:
185+
return _CACHED_METRICS
190186

191187
try:
192188
sdk_client_device_model = platform.node()
@@ -203,13 +199,13 @@ def get_metrics():
203199
except Exception:
204200
sdk_runtime_details = ""
205201

206-
details_dic = {
207-
'sdk_name_version': sdk_name_version,
202+
_CACHED_METRICS = {
203+
'sdk_name_version': "skyflow-python@" + SDK_VERSION,
208204
'sdk_client_device_model': sdk_client_device_model,
209205
'sdk_client_os_details': sdk_client_os_details,
210206
'sdk_runtime_details': "Python " + sdk_runtime_details,
211207
}
212-
return details_dic
208+
return _CACHED_METRICS
213209

214210
def parse_insert_response(api_response, continue_on_error):
215211
# Retrieve the headers and data from the API response

skyflow/utils/_version.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
SDK_VERSION = '2.0.0'
1+
SDK_VERSION = '2.0.0.dev0+e2aa629'

skyflow/utils/validations/_validations.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -122,8 +122,8 @@ def validate_credentials(logger, credentials, config_id_type=None, config_id=Non
122122
)
123123
if is_expired(credentials.get("token"), logger):
124124
raise SkyflowError(
125-
SkyflowMessages.Error.INVALID_CREDENTIALS_TOKEN.value.format(config_id_type, config_id)
126-
if config_id_type and config_id else SkyflowMessages.Error.INVALID_CREDENTIALS_TOKEN.value,
125+
SkyflowMessages.Error.EXPIRED_TOKEN.value
126+
if config_id_type and config_id else SkyflowMessages.Error.EXPIRED_TOKEN.value,
127127
invalid_input_error_code
128128
)
129129
elif "api_key" in credentials:
@@ -389,7 +389,7 @@ def validate_deidentify_file_request(logger, request: DeidentifyFileRequest):
389389
if hasattr(request, 'wait_time') and request.wait_time is not None:
390390
if not isinstance(request.wait_time, (int, float)):
391391
raise SkyflowError(SkyflowMessages.Error.INVALID_WAIT_TIME.value, invalid_input_error_code)
392-
if request.wait_time < 0 and request.wait_time > 64:
392+
if request.wait_time < 0 or request.wait_time > 64:
393393
raise SkyflowError(SkyflowMessages.Error.WAIT_TIME_GREATER_THEN_64.value, invalid_input_error_code)
394394

395395
def validate_insert_request(logger, request):

skyflow/vault/client/client.py

Lines changed: 26 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,9 @@ def __init__(self, config):
1414
self.__logger = None
1515
self.__is_config_updated = False
1616
self.__bearer_token = None
17+
self.__credentials = None
18+
self.__vault_url = None
19+
self.__is_static_token = None
1720

1821
def set_common_skyflow_credentials(self, credentials):
1922
self.__common_skyflow_credentials = credentials
@@ -23,16 +26,27 @@ def set_logger(self, log_level, logger):
2326
self.__logger = logger
2427

2528
def initialize_client_configuration(self):
26-
credentials = get_credentials(self.__config.get("credentials"), self.__common_skyflow_credentials, logger = self.__logger)
27-
token = self.get_bearer_token(credentials)
28-
vault_url = get_vault_url(self.__config.get("cluster_id"),
29-
self.__config.get("env"),
30-
self.__config.get("vault_id"),
31-
logger = self.__logger)
32-
self.initialize_api_client(vault_url, token)
33-
34-
def initialize_api_client(self, vault_url, token):
35-
self.__api_client = Skyflow(base_url=vault_url, token=token)
29+
if self.__api_client is not None and not self.__is_config_updated:
30+
if self.__is_static_token:
31+
return
32+
if self.__bearer_token is not None and not is_expired(self.__bearer_token):
33+
return
34+
35+
needs_reinit = self.__api_client is None or self.__is_config_updated
36+
if needs_reinit:
37+
self.__credentials = get_credentials(self.__config.get("credentials"), self.__common_skyflow_credentials, logger=self.__logger)
38+
self.__vault_url = get_vault_url(self.__config.get("cluster_id"),
39+
self.__config.get("env"),
40+
self.__config.get("vault_id"),
41+
logger=self.__logger)
42+
self.__is_static_token = 'token' in self.__credentials or 'api_key' in self.__credentials
43+
bearer_token = self.get_bearer_token(self.__credentials)
44+
if needs_reinit:
45+
self.initialize_api_client(self.__vault_url, bearer_token)
46+
47+
def initialize_api_client(self, vault_url, bearer_token):
48+
token_provider = lambda: self.__bearer_token if self.__bearer_token else bearer_token # noqa: E731
49+
self.__api_client = Skyflow(base_url=vault_url, token=token_provider)
3650

3751
def get_records_api(self):
3852
return self.__api_client.records
@@ -63,11 +77,10 @@ def get_bearer_token(self, credentials):
6377
"ctx": self.__config.get("ctx")
6478
}
6579

66-
if self.__bearer_token is None or self.__is_config_updated:
80+
if self.__bearer_token is None or self.__is_config_updated or is_expired(self.__bearer_token):
6781
if 'path' in credentials:
68-
path = credentials.get("path")
6982
self.__bearer_token, _ = generate_bearer_token(
70-
path,
83+
credentials.get("path"),
7184
options,
7285
self.__logger
7386
)
@@ -83,10 +96,6 @@ def get_bearer_token(self, credentials):
8396
else:
8497
log_info(SkyflowMessages.Info.REUSE_BEARER_TOKEN.value, self.__logger)
8598

86-
if is_expired(self.__bearer_token):
87-
self.__is_config_updated = True
88-
raise SyntaxError(SkyflowMessages.Error.EXPIRED_TOKEN.value, SkyflowMessages.ErrorCodes.INVALID_INPUT.value)
89-
9099
return self.__bearer_token
91100

92101
def update_config(self, config):

skyflow/vault/controller/_detect.py

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,7 @@ def __poll_for_processed_file(self, run_id, max_wait_time=64):
6262
current_wait_time = 1 # Start with 1 second
6363
try:
6464
while True:
65-
response = files_api.get_run(run_id, vault_id=self.__vault_client.get_vault_id(), request_options=self.__get_headers()).data
65+
response = files_api.get_run(run_id, vault_id=self.__vault_client.get_vault_id(), request_options={'additional_headers': self.__get_headers()}).data
6666
status = response.status
6767
if status == 'IN_PROGRESS':
6868
if current_wait_time >= max_wait_time:
@@ -228,7 +228,7 @@ def deidentify_text(self, request: DeidentifyTextRequest) -> DeidentifyTextRespo
228228
restrict_regex=deidentify_text_body['restrict_regex'],
229229
token_type=deidentify_text_body['token_type'],
230230
transformations=deidentify_text_body['transformations'],
231-
request_options=self.__get_headers()
231+
request_options={'additional_headers': self.__get_headers()}
232232
)
233233
deidentify_text_response = parse_deidentify_text_response(api_response)
234234
log_info(SkyflowMessages.Info.DEIDENTIFY_TEXT_SUCCESS.value, self.__vault_client.get_logger())
@@ -252,7 +252,7 @@ def reidentify_text(self, request: ReidentifyTextRequest) -> ReidentifyTextRespo
252252
vault_id=self.__vault_client.get_vault_id(),
253253
text=reidentify_text_body['text'],
254254
format=reidentify_text_body['format'],
255-
request_options=self.__get_headers()
255+
request_options={'additional_headers': self.__get_headers()}
256256
)
257257
reidentify_text_response = parse_reidentify_text_response(api_response)
258258
log_info(SkyflowMessages.Info.REIDENTIFY_TEXT_SUCCESS.value, self.__vault_client.get_logger())
@@ -296,7 +296,7 @@ def deidentify_file(self, request: DeidentifyFileRequest):
296296
'allow_regex': request.allow_regex_list,
297297
'restrict_regex': request.restrict_regex_list,
298298
'transformations': self.__get_transformations(request),
299-
'request_options': self.__get_headers()
299+
'request_options': {'additional_headers': self.__get_headers()}
300300
}
301301

302302
elif file_extension in ['mp3', 'wav']:
@@ -316,7 +316,7 @@ def deidentify_file(self, request: DeidentifyFileRequest):
316316
'bleep_frequency': getattr(request, 'bleep', None).frequency if getattr(request, 'bleep', None) is not None else None,
317317
'bleep_start_padding': getattr(request, 'bleep', None).start_padding if getattr(request, 'bleep', None) is not None else None,
318318
'bleep_stop_padding': getattr(request, 'bleep', None).stop_padding if getattr(request, 'bleep', None) is not None else None,
319-
'request_options': self.__get_headers()
319+
'request_options': {'additional_headers': self.__get_headers()}
320320
}
321321

322322
elif file_extension == 'pdf':
@@ -331,7 +331,7 @@ def deidentify_file(self, request: DeidentifyFileRequest):
331331
'restrict_regex': request.restrict_regex_list,
332332
'max_resolution': getattr(request, 'max_resolution', None),
333333
'density': getattr(request, 'pixel_density', None),
334-
'request_options': self.__get_headers()
334+
'request_options': {'additional_headers': self.__get_headers()}
335335
}
336336

337337
elif file_extension in ['jpeg', 'jpg', 'png', 'bmp', 'tif', 'tiff']:
@@ -347,7 +347,7 @@ def deidentify_file(self, request: DeidentifyFileRequest):
347347
'masking_method': getattr(request, 'masking_method', None),
348348
'output_ocr_text': getattr(request, 'output_ocr_text', None),
349349
'output_processed_image': getattr(request, 'output_processed_image', None),
350-
'request_options': self.__get_headers()
350+
'request_options': {'additional_headers': self.__get_headers()}
351351
}
352352

353353
elif file_extension in ['ppt', 'pptx']:
@@ -360,7 +360,7 @@ def deidentify_file(self, request: DeidentifyFileRequest):
360360
'token_type': self.__get_token_format(request),
361361
'allow_regex': request.allow_regex_list,
362362
'restrict_regex': request.restrict_regex_list,
363-
'request_options': self.__get_headers()
363+
'request_options': {'additional_headers': self.__get_headers()}
364364
}
365365

366366
elif file_extension in ['csv', 'xls', 'xlsx']:
@@ -373,7 +373,7 @@ def deidentify_file(self, request: DeidentifyFileRequest):
373373
'token_type': self.__get_token_format(request),
374374
'allow_regex': request.allow_regex_list,
375375
'restrict_regex': request.restrict_regex_list,
376-
'request_options': self.__get_headers()
376+
'request_options': {'additional_headers': self.__get_headers()}
377377
}
378378

379379
elif file_extension in ['doc', 'docx']:
@@ -386,7 +386,7 @@ def deidentify_file(self, request: DeidentifyFileRequest):
386386
'token_type': self.__get_token_format(request),
387387
'allow_regex': request.allow_regex_list,
388388
'restrict_regex': request.restrict_regex_list,
389-
'request_options': self.__get_headers()
389+
'request_options': {'additional_headers': self.__get_headers()}
390390
}
391391

392392
elif file_extension in ['json', 'xml']:
@@ -400,7 +400,7 @@ def deidentify_file(self, request: DeidentifyFileRequest):
400400
'allow_regex': request.allow_regex_list,
401401
'restrict_regex': request.restrict_regex_list,
402402
'transformations': self.__get_transformations(request),
403-
'request_options': self.__get_headers()
403+
'request_options': {'additional_headers': self.__get_headers()}
404404
}
405405

406406
else:
@@ -414,7 +414,7 @@ def deidentify_file(self, request: DeidentifyFileRequest):
414414
'allow_regex': request.allow_regex_list,
415415
'restrict_regex': request.restrict_regex_list,
416416
'transformations': self.__get_transformations(request),
417-
'request_options': self.__get_headers()
417+
'request_options': {'additional_headers': self.__get_headers()}
418418
}
419419

420420
log_info(SkyflowMessages.Info.DETECT_FILE_REQUEST_RESOLVED.value, self.__vault_client.get_logger())
@@ -448,7 +448,7 @@ def get_detect_run(self, request: GetDetectRunRequest):
448448
response = files_api.get_run(
449449
run_id,
450450
vault_id=self.__vault_client.get_vault_id(),
451-
request_options=self.__get_headers()
451+
request_options={'additional_headers': self.__get_headers()}
452452
)
453453
if response.data.status == 'IN_PROGRESS':
454454
parsed_response = self.__parse_deidentify_file_response(DeidentifyFileResponse(run_id=run_id, status='IN_PROGRESS'))

0 commit comments

Comments
 (0)