Skip to content
65 changes: 65 additions & 0 deletions mssql_python/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -238,3 +238,68 @@ def extract_auth_type(parsed_params: Dict[str, str]) -> Optional[str]:
"""
auth_value = parsed_params.get(_KEY_AUTHENTICATION, "").strip().lower()
return _AUTH_TYPE_MAP.get(auth_value)


def _get_token_from_credential(credential: object) -> str:
"""Internal: call credential.get_token() and return the raw JWT string.

Centralises the token-acquisition + error-wrapping logic that both
:func:`acquire_token_from_credential` and
:func:`acquire_raw_token_from_credential` need.

Raises:
RuntimeError: If token acquisition fails.
"""
try:
raw_token = credential.get_token("https://database.windows.net/.default").token
logger.info(
"_get_token_from_credential: Token acquired from %s - length=%d chars",
type(credential).__name__,
len(raw_token),
)
return raw_token
except Exception as e:
logger.error(
"_get_token_from_credential: Failed - credential=%s, error=%s",
type(credential).__name__,
str(e),
)
raise RuntimeError(
f"Failed to acquire token from credential " f"({type(credential).__name__}): {e}"
) from e


def acquire_token_from_credential(credential: object) -> bytes:
"""Acquire an ODBC token struct from a user-supplied credential object.

The credential must follow the Azure ``TokenCredential`` protocol — i.e.
have a ``.get_token(scope)`` method returning an object with a ``.token``
attribute (a raw JWT string).

Args:
credential: Any object with a ``.get_token(scope)`` method.

Returns:
bytes: ODBC-compatible token struct for ``SQL_COPT_SS_ACCESS_TOKEN``.

Raises:
RuntimeError: If token acquisition fails.
"""
return AADAuth.get_token_struct(_get_token_from_credential(credential))


def acquire_raw_token_from_credential(credential: object) -> str:
"""Acquire a raw JWT string from a user-supplied credential object.

Used by bulk copy, which needs the raw JWT rather than the ODBC struct.

Args:
credential: Any object with a ``.get_token(scope)`` method.

Returns:
str: Raw JWT token string.

Raises:
RuntimeError: If token acquisition fails.
"""
return _get_token_from_credential(credential)
30 changes: 29 additions & 1 deletion mssql_python/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -246,6 +246,7 @@ def __init__(
attrs_before: Optional[Dict[int, Union[int, str, bytes]]] = None,
timeout: int = 0,
native_uuid: Optional[bool] = None,
token_provider: Optional[object] = None,
**kwargs: Any,
) -> None:
"""
Expand Down Expand Up @@ -334,10 +335,37 @@ def __init__(
# fresh token; re-parsing self.connection_str at that point would miss
# them because UID is already gone.
self._credential_kwargs: Optional[Dict[str, str]] = None
# User-supplied token provider for custom Entra ID authentication.
# Stored so bulk copy can call .get_token() for a fresh JWT later.
self._token_provider = None

# Custom token_provider= parameter — takes priority, mutually exclusive
# with Authentication= in the connection string.
if token_provider is not None:
if _KEY_AUTHENTICATION in parsed_params:
raise ValueError(
"Cannot specify both 'token_provider' parameter and "
"'Authentication' in the connection string. "
"Use one or the other."
)
if not callable(getattr(token_provider, "get_token", None)):
raise TypeError(
f"token_provider must have a .get_token() method. "
f"Got {type(token_provider).__name__}."
)
from mssql_python.auth import acquire_token_from_credential

token = acquire_token_from_credential(token_provider)
self._attrs_before[ConstantsDDBC.SQL_COPT_SS_ACCESS_TOKEN.value] = token
self._token_provider = token_provider
# Strip sensitive params (UID/PWD/Trusted_Connection) since
# access-token auth is used — same as the Authentication= path.
sanitized = remove_sensitive_params(parsed_params)
self.connection_str = _ConnectionStringBuilder(sanitized).build()

# Handle Entra ID authentication if specified.
# The parsed dict is used directly — no re-parsing of the connection string.
if _KEY_AUTHENTICATION in parsed_params:
elif _KEY_AUTHENTICATION in parsed_params:
auth_type = process_auth_parameters(parsed_params)

if auth_type:
Expand Down
30 changes: 16 additions & 14 deletions mssql_python/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -467,6 +467,16 @@ def get_attribute_set_timing(attribute):

_CONNECTION_STRING_DRIVER_KEY = "Driver"
_CONNECTION_STRING_APP_KEY = "APP"
_CONNECTION_STRING_AUTH_KEY = "Authentication"
_CONNECTION_STRING_UID_KEY = "UID"
_CONNECTION_STRING_PWD_KEY = "PWD"
_CONNECTION_STRING_TRUSTED_CONNECTION_KEY = "Trusted_Connection"

# Aliases used by auth.py / connection.py — kept for readability.
_KEY_AUTHENTICATION = _CONNECTION_STRING_AUTH_KEY
_KEY_UID = _CONNECTION_STRING_UID_KEY
_KEY_PWD = _CONNECTION_STRING_PWD_KEY
_KEY_TRUSTED_CONNECTION = _CONNECTION_STRING_TRUSTED_CONNECTION_KEY

# Reserved connection string parameters that are controlled by the driver
# and cannot be set by users
Expand All @@ -486,16 +496,16 @@ def get_attribute_set_timing(attribute):
"address": "Server",
"addr": "Server",
# Authentication
"uid": "UID",
"pwd": "PWD",
"authentication": "Authentication",
"trusted_connection": "Trusted_Connection",
"uid": _CONNECTION_STRING_UID_KEY,
"pwd": _CONNECTION_STRING_PWD_KEY,
"authentication": _CONNECTION_STRING_AUTH_KEY,
"trusted_connection": _CONNECTION_STRING_TRUSTED_CONNECTION_KEY,
# Database
"database": "Database",
# Driver (always controlled by mssql-python)
"driver": "Driver",
"driver": _CONNECTION_STRING_DRIVER_KEY,
# Application name (always controlled by mssql-python)
"app": "APP",
"app": _CONNECTION_STRING_APP_KEY,
# Encryption and Security
"encrypt": "Encrypt",
"trustservercertificate": "TrustServerCertificate",
Expand All @@ -519,14 +529,6 @@ def get_attribute_set_timing(attribute):
"packetsize": "PacketSize",
}

# Canonical normalized key names produced by _ConnectionStringParser._normalize_params.
# Consumer code should reference these instead of hard-coding raw strings so that
# a rename in _ALLOWED_CONNECTION_STRING_PARAMS is caught at import time.
_KEY_AUTHENTICATION = "Authentication"
_KEY_UID = "UID"
_KEY_PWD = "PWD"
_KEY_TRUSTED_CONNECTION = "Trusted_Connection"


def get_info_constants() -> Dict[str, int]:
"""
Expand Down
19 changes: 18 additions & 1 deletion mssql_python/cursor.py
Original file line number Diff line number Diff line change
Expand Up @@ -2961,7 +2961,24 @@ def bulkcopy(
pycore_context = connstr_to_pycore_params(params)

# Token acquisition — only thing cursor must handle (needs azure-identity SDK)
if self.connection._auth_type:
if self.connection._token_provider is not None:
# User-supplied credential — use it directly for a fresh token.
from mssql_python.auth import acquire_raw_token_from_credential

try:
raw_token = acquire_raw_token_from_credential(self.connection._token_provider)
except RuntimeError as e:
raise RuntimeError(
f"Bulk copy failed: unable to acquire token " f"from custom credential: {e}"
) from e
pycore_context["access_token"] = raw_token
for key in ("authentication", "user_name", "password"):
pycore_context.pop(key, None)
logger.debug(
"Bulk copy: acquired fresh token from custom credential (%s)",
type(self.connection._token_provider).__name__,
)
elif self.connection._auth_type:
# Fresh token acquisition for mssql-py-core connection. credential
# kwargs (e.g. user-assigned MSI client_id) were captured by
# Connection.__init__ before remove_sensitive_params stripped UID
Expand Down
25 changes: 25 additions & 0 deletions mssql_python/db_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ def connect(
attrs_before: Optional[Dict[int, Union[int, str, bytes]]] = None,
timeout: int = 0,
native_uuid: Optional[bool] = None,
token_provider: Optional[object] = None,
**kwargs: Any,
) -> Connection:
"""
Expand All @@ -35,6 +36,29 @@ def connect(
This per-connection override is useful for migration from pyodbc:
connections that need string UUIDs can pass native_uuid=False, while the default (True)
returns native uuid.UUID objects.
token_provider (object, optional): A token provider for Microsoft Entra ID
authentication. This must be any object with a ``.get_token(scope)`` method that
returns an object with a ``.token`` attribute containing a raw JWT string — for
example, any ``azure-identity`` credential class such as
``DefaultAzureCredential``, ``AzureCliCredential``, ``ManagedIdentityCredential``,
``CertificateCredential``, etc.

When provided, the driver calls ``token_provider.get_token()`` to acquire an
access token for SQL Server, bypassing the built-in credential map.
Cannot be combined with ``Authentication=`` in the connection string.

For environment-portable code, prefer ``Authentication=ActiveDirectoryDefault``
in the connection string — ``DefaultAzureCredential`` automatically picks the
right credential per environment (CLI on dev, Managed Identity in prod).
Use ``token_provider=`` only when you need explicit control over token
acquisition (e.g., excluding specific providers, using a credential not in
the built-in map, or passing custom options to the credential constructor).

Example::

from azure.identity import AzureCliCredential
conn = mssql_python.connect("Server=s;Database=d",
token_provider=AzureCliCredential())
Keyword Args:
**kwargs: Additional key/value pairs for the connection string.
Below attributes are not implemented in the internal driver:
Expand All @@ -58,6 +82,7 @@ def connect(
attrs_before=attrs_before,
timeout=timeout,
native_uuid=native_uuid,
token_provider=token_provider,
**kwargs,
)
return conn
2 changes: 2 additions & 0 deletions mssql_python/mssql_python.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -248,6 +248,7 @@ class Connection:
attrs_before: Optional[Dict[int, Union[int, str, bytes]]] = None,
timeout: int = 0,
native_uuid: Optional[bool] = None,
token_provider: Optional[object] = None,
**kwargs: Any,
) -> None: ...

Expand Down Expand Up @@ -291,6 +292,7 @@ def connect(
attrs_before: Optional[Dict[int, Union[int, str, bytes]]] = None,
timeout: int = 0,
native_uuid: Optional[bool] = None,
token_provider: Optional[object] = None,
**kwargs: Any,
) -> Connection: ...

Expand Down
Loading
Loading