Skip to content

Commit d2df75b

Browse files
Fixed bug when using multiple cloud native authentication plugins for
connections.
1 parent bb838dc commit d2df75b

File tree

3 files changed

+40
-20
lines changed

3 files changed

+40
-20
lines changed

doc/src/release_notes.rst

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,10 @@ Common Changes
3131
#) Added Session Token-based authentication support when using
3232
:ref:`OCI Cloud Native Authentication <cloudnativeauthoci>`
3333
(`issue 527 <https://github.com/oracle/python-oracledb/issues/527>`__).
34+
#) Fixed bug when using multiple
35+
:ref:`cloud native authentication <tokenauth>` plugins for connections.
36+
Note that an invalid ``auth_type`` parameter will no longer raise an
37+
exception but will simply be ignored.
3438
#) Updated the `Jupyter notebook samples <https://github.com/oracle/
3539
python-oracledb/tree/main/samples/notebooks>`__ to cover recent
3640
python-oracledb features.

src/oracledb/plugins/azure_tokens.py

Lines changed: 24 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -28,22 +28,14 @@
2828
# Methods that generates an OAuth2 access token using the MSAL SDK
2929
# -----------------------------------------------------------------------------
3030

31+
import enum
32+
3133
import msal
3234
import oracledb
3335

3436

35-
def generate_token(token_auth_config, refresh=False):
36-
"""
37-
Generates an Azure access token based on provided credentials.
38-
"""
39-
user_auth_type = token_auth_config.get("auth_type") or ""
40-
auth_type = user_auth_type.lower()
41-
if auth_type == "azureserviceprincipal":
42-
return _service_principal_credentials(token_auth_config)
43-
else:
44-
raise ValueError(
45-
f"Unrecognized auth_type authentication method: {user_auth_type}"
46-
)
37+
class AuthType(str, enum.Enum):
38+
AzureServicePrincipal = "AzureServicePrincipal".lower()
4739

4840

4941
def _service_principal_credentials(token_auth_config):
@@ -65,11 +57,30 @@ def _service_principal_credentials(token_auth_config):
6557
return auth_response["access_token"]
6658

6759

60+
def generate_token(token_auth_config, refresh=False):
61+
"""
62+
Generates an Azure access token based on provided credentials.
63+
"""
64+
auth_type = token_auth_config["auth_type"].lower()
65+
if auth_type == AuthType.AzureServicePrincipal:
66+
return _service_principal_credentials(token_auth_config)
67+
68+
69+
def has_azure_auth_type(extra_auth_params):
70+
"""
71+
Validates that extra_auth_params contains a valid 'auth_type'
72+
"""
73+
if extra_auth_params is None:
74+
return False
75+
auth_type = extra_auth_params.get("auth_type")
76+
return auth_type is not None and auth_type.lower() in AuthType
77+
78+
6879
def azure_token_hook(params: oracledb.ConnectParams):
6980
"""
7081
Azure-specific hook for generating a token.
7182
"""
72-
if params.extra_auth_params is not None:
83+
if has_azure_auth_type(params.extra_auth_params):
7384

7485
def token_callback(refresh):
7586
return generate_token(params.extra_auth_params, refresh)

src/oracledb/plugins/oci_tokens.py

Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -190,8 +190,7 @@ def generate_token(token_auth_config, refresh=False):
190190
"""
191191
Generates an OCI access token based on provided credentials.
192192
"""
193-
user_auth_type = token_auth_config.get("auth_type") or ""
194-
auth_type = user_auth_type.lower()
193+
auth_type = token_auth_config["auth_type"].lower()
195194
if auth_type == AuthType.ConfigFileAuthentication:
196195
return _config_file_based_authentication(token_auth_config)
197196
elif auth_type == AuthType.InstancePrincipal:
@@ -202,17 +201,23 @@ def generate_token(token_auth_config, refresh=False):
202201
return _security_token_simple_authentication(token_auth_config)
203202
elif auth_type == AuthType.SimpleAuthentication:
204203
return _simple_authentication(token_auth_config)
205-
else:
206-
raise ValueError(
207-
f"Unrecognized auth_type authentication method {user_auth_type}"
208-
)
204+
205+
206+
def has_oci_auth_type(extra_auth_params):
207+
"""
208+
Validates that extra_auth_params contains a valid 'auth_type'
209+
"""
210+
if extra_auth_params is None:
211+
return False
212+
auth_type = extra_auth_params.get("auth_type")
213+
return auth_type is not None and auth_type.lower() in AuthType
209214

210215

211216
def oci_token_hook(params: oracledb.ConnectParams):
212217
"""
213218
OCI-specific hook for generating a token.
214219
"""
215-
if params.extra_auth_params is not None:
220+
if has_oci_auth_type(params.extra_auth_params):
216221

217222
def token_callback(refresh):
218223
return generate_token(params.extra_auth_params, refresh)

0 commit comments

Comments
 (0)