diff --git a/mssql_python/connection.py b/mssql_python/connection.py index 0d9b4692..01730c4b 100644 --- a/mssql_python/connection.py +++ b/mssql_python/connection.py @@ -1114,6 +1114,271 @@ def clear_output_converters(self) -> None: self._conn.clear_output_converters() logger.info("Cleared all output converters") + # ---- Session Context API ---- + + _SESSION_KEY_MAX_LEN: int = 128 + _SESSION_VALUE_MAX_LEN: int = 4000 + + def set_session_context( + self, + *, + read_only: bool = False, + **context: Optional[str], + ) -> None: + """ + Set session-level metadata on the current connection. + + This stores name-value pairs in the SQL Server session context via + ``sp_set_session_context``, making them visible to: + + * ``SESSION_CONTEXT()`` in T-SQL queries, triggers, and stored procedures + * Extended Events sessions that capture session context + * ``sys.dm_exec_sessions`` (for *application_name*) + * Audit specifications that reference session context + + Only keys that are passed will be set. Calling this method again + merges new values with previously-set ones; to clear a key pass + ``None`` as its value. + + Well-known keys (optional, not enforced): + ``application_name``, ``module_name``, ``action_name``, ``user_id`` + + Args: + read_only: If ``True``, the keys become read-only for the + remainder of the session — subsequent calls cannot change them. + **context: Key-value pairs to store in the session context. + Pass ``None`` as a value to clear a key. + + Raises: + InterfaceError: If the connection is closed. + ProgrammingError: If a key or value exceeds length limits. + DatabaseError: If ``sp_set_session_context`` execution fails. + + Example:: + + conn.set_session_context( + application_name="BillingAPI", + module_name="InvoiceProcessor", + action_name="GenerateInvoice", + user_id="123", + ) + # Values are now readable in T-SQL: + # SELECT SESSION_CONTEXT(N'application_name') + + # Clear a key: + conn.set_session_context(user_id=None) + """ + if self._closed: + raise InterfaceError( + driver_error="Connection is closed", + ddbc_error="Cannot set session context on a closed connection", + ) + + if not context: + return # nothing to do + + # Validate lengths + for key, value in context.items(): + if not isinstance(key, str) or not key: + raise ProgrammingError( + driver_error="Invalid session context key", + ddbc_error="Session context key must be a non-empty string", + ) + if len(key) > self._SESSION_KEY_MAX_LEN: + raise ProgrammingError( + driver_error="Session context key too long", + ddbc_error=( + f"Session context key exceeds {self._SESSION_KEY_MAX_LEN} characters" + ), + ) + if value is not None: + if not isinstance(value, str): + raise ProgrammingError( + driver_error="Invalid session context value", + ddbc_error="Session context values must be strings or None", + ) + if len(value) > self._SESSION_VALUE_MAX_LEN: + raise ProgrammingError( + driver_error="Session context value too long", + ddbc_error=( + f"Session context value exceeds {self._SESSION_VALUE_MAX_LEN} characters" + ), + ) + + # Initialize local cache if first call + if not hasattr(self, "_session_context"): + self._session_context: Dict[str, str] = {} + if not hasattr(self, "_session_context_read_only_keys"): + self._session_context_read_only_keys: set = set() + + # Reject attempts to clear read-only keys via value=None + for key, value in context.items(): + if value is None and key in self._session_context_read_only_keys: + raise ProgrammingError( + driver_error="Cannot clear read-only session context key", + ddbc_error=( + f"Session context key '{key}' was set with read_only=True " + "and cannot be cleared" + ), + ) + + # Build a single batch of sp_set_session_context calls to execute + # in one round trip instead of N separate calls. + batch_parts: list[str] = [] + batch_params: list = [] + for key, value in context.items(): + if read_only: + batch_parts.append( + "EXEC sp_set_session_context @key=?, @value=?, @read_only=1" + ) + else: + batch_parts.append( + "EXEC sp_set_session_context @key=?, @value=?" + ) + batch_params.append(key) + batch_params.append(value) + + cursor = self.cursor() + try: + cursor.execute("; ".join(batch_parts), *batch_params) + finally: + cursor.close() + + # Update local cache after successful execution + for key, value in context.items(): + if value is None: + self._session_context.pop(key, None) + else: + self._session_context[key] = value + if read_only: + self._session_context_read_only_keys.add(key) + logger.debug("Set session context: %s", sanitize_user_input(key)) + + logger.info( + "Session context set with %d key(s): %s", + len(context), + ", ".join(sanitize_user_input(k) for k in context), + ) + + def get_session_context(self) -> Dict[str, str]: + """ + Return the current session context for keys previously set via + :meth:`set_session_context`. + + For each known key the driver queries the server with + ``SELECT SESSION_CONTEXT(N'')``, so the returned values + reflect any server-side mutations (e.g. by triggers or stored + procedures) that occurred after the initial ``set_session_context`` + call. + + Returns: + dict: A ``{key: value}`` mapping. Keys whose server-side + value is ``NULL`` are omitted. + + Raises: + InterfaceError: If the connection is closed. + DatabaseError: If the server query fails. + """ + if self._closed: + raise InterfaceError( + driver_error="Connection is closed", + ddbc_error="Cannot get session context on a closed connection", + ) + if not hasattr(self, "_session_context") or not self._session_context: + return {} + + # Query each known key from the server in one batch using + # parameterised queries to avoid SQL injection. + # SESSION_CONTEXT() requires nvarchar, so we CAST the parameter. + keys = list(self._session_context.keys()) + select_parts = ["SELECT SESSION_CONTEXT(CAST(? AS nvarchar(128)))" for _ in keys] + cursor = self.cursor() + try: + cursor.execute("; ".join(select_parts), *keys) + result: Dict[str, str] = {} + for key in keys: + row = cursor.fetchone() + if row is not None and row[0] is not None: + result[key] = str(row[0]) + # Advance to next result set for the next SELECT. + if not cursor.nextset(): + break + return result + finally: + cursor.close() + + def clear_session_context(self, *keys: str) -> None: + """ + Clear one or more session context keys by sending ``NULL`` to the server. + + If no keys are provided, all non-read-only keys that were previously + set via :meth:`set_session_context` are cleared. + + Args: + *keys: Key names to clear. If omitted, all clearable keys are + cleared. + + Raises: + InterfaceError: If the connection is closed. + ProgrammingError: If a specified key was set with ``read_only=True``. + DatabaseError: If ``sp_set_session_context`` execution fails. + + Example:: + + conn.clear_session_context("user_id") # clear one key + conn.clear_session_context("user_id", "action_name") # clear several + conn.clear_session_context() # clear all + """ + if self._closed: + raise InterfaceError( + driver_error="Connection is closed", + ddbc_error="Cannot clear session context on a closed connection", + ) + if not hasattr(self, "_session_context") or not self._session_context: + return + + read_only_keys: set = getattr(self, "_session_context_read_only_keys", set()) + + if keys: + # Validate that none of the requested keys are read-only. + read_only_requested = read_only_keys & set(keys) + if read_only_requested: + raise ProgrammingError( + driver_error="Cannot clear read-only session context keys", + ddbc_error=( + "The following keys are read-only and cannot be cleared: " + + ", ".join(sorted(read_only_requested)) + ), + ) + to_clear = [k for k in keys if k in self._session_context] + else: + to_clear = [k for k in self._session_context if k not in read_only_keys] + + if not to_clear: + return + + # Build a batch to NULL-out every key in one round trip. + batch_parts: list[str] = [] + batch_params: list = [] + for key in to_clear: + batch_parts.append("EXEC sp_set_session_context @key=?, @value=NULL") + batch_params.append(key) + + cursor = self.cursor() + try: + cursor.execute("; ".join(batch_parts), *batch_params) + finally: + cursor.close() + + for key in to_clear: + self._session_context.pop(key, None) + + logger.info( + "Cleared %d session context key(s): %s", + len(to_clear), + ", ".join(sanitize_user_input(k) for k in to_clear), + ) + def execute(self, sql: str, *args: Any) -> Cursor: """ Creates a new Cursor object, calls its execute method, and returns the new cursor. @@ -1636,6 +1901,25 @@ def close(self) -> None: # references self._cursors.clear() + # If pooling is enabled, clear session context so the next consumer + # of this ODBC handle does not inherit leftover metadata. + if self._pooling and hasattr(self, "_session_context") and self._session_context: + read_only_keys = getattr(self, "_session_context_read_only_keys", set()) + if read_only_keys & set(self._session_context): + logger.warning( + "Pooled connection has read-only session context keys that " + "cannot be cleared: %s. These values will persist for the " + "next consumer of this connection.", + ", ".join(sorted(read_only_keys & set(self._session_context))), + ) + try: + self.clear_session_context() + except Exception: + logger.warning( + "Failed to clear session context before pool return", + exc_info=True, + ) + # Close the connection even if cursor cleanup had issues try: if self._conn: @@ -1721,4 +2005,4 @@ def __del__(self) -> None: self.close() except Exception as e: # Dont raise exceptions from __del__ to avoid issues during garbage collection - logger.warning(f"Error during connection cleanup: {e}") + logger.warning(f"Error during connection cleanup: {e}") \ No newline at end of file diff --git a/tests/test_023_audit_context.py b/tests/test_023_audit_context.py new file mode 100644 index 00000000..c5c6ac80 --- /dev/null +++ b/tests/test_023_audit_context.py @@ -0,0 +1,267 @@ +""" +Tests for the session context API +(set_session_context / get_session_context / clear_session_context). + +Functions: +- test_set_and_get_session_context: Set named fields and verify via server query. +- test_session_context_server_roundtrip: Verify values are readable via SESSION_CONTEXT(). +- test_session_context_extra_keys: Test arbitrary extra key-value pairs. +- test_session_context_merge: Successive calls merge, not replace. +- test_session_context_empty_call: Calling with no arguments is a no-op. +- test_session_context_clear_value: Setting a key to None clears it server-side. +- test_session_context_read_only: read_only=True prevents subsequent changes. +- test_session_context_closed_connection: Raises InterfaceError when connection is closed. +- test_session_context_key_too_long: Raises ProgrammingError for oversized keys. +- test_session_context_value_too_long: Raises ProgrammingError for oversized values. +- test_session_context_non_string_value: Raises ProgrammingError for non-string values. +- test_clear_session_context_single_key: Clear one key. +- test_clear_session_context_multiple_keys: Clear several keys. +- test_clear_session_context_all: Clear all keys. +- test_clear_session_context_read_only_raises: Clearing a read-only key raises ProgrammingError. +- test_clear_session_context_closed: Raises InterfaceError on closed connection. +- test_clear_session_context_noop: No-op when nothing has been set. +- test_get_session_context_reflects_server: Getter fetches live server values. +- test_pool_return_clears_context: Session context is cleared when pooled connection is closed. +""" + +import pytest +from mssql_python import connect +from mssql_python.exceptions import InterfaceError, ProgrammingError, DatabaseError + + +@pytest.fixture() +def audit_conn(conn_str): + """Dedicated connection for session context tests (module-scoped fixtures + would share session state, so we create a fresh connection per test).""" + conn = connect(conn_str) + yield conn + conn.close() + + +class TestSessionContext: + """Tests for Connection.set_session_context / get_session_context.""" + + def test_set_and_get_session_context(self, audit_conn): + """Named fields are reflected in the local cache.""" + audit_conn.set_session_context( + application_name="BillingAPI", + module_name="InvoiceProcessor", + action_name="GenerateInvoice", + user_id="123", + ) + ctx = audit_conn.get_session_context() + assert ctx["application_name"] == "BillingAPI" + assert ctx["module_name"] == "InvoiceProcessor" + assert ctx["action_name"] == "GenerateInvoice" + assert ctx["user_id"] == "123" + + def test_session_context_server_roundtrip(self, audit_conn): + """Values set via set_session_context are readable with SESSION_CONTEXT().""" + audit_conn.set_session_context(application_name="RoundTrip", user_id="42") + cursor = audit_conn.cursor() + try: + cursor.execute("SELECT SESSION_CONTEXT(N'application_name')") + row = cursor.fetchone() + assert row[0] == "RoundTrip" + + cursor.execute("SELECT SESSION_CONTEXT(N'user_id')") + row = cursor.fetchone() + assert row[0] == "42" + finally: + cursor.close() + + def test_session_context_extra_keys(self, audit_conn): + """Arbitrary extra keys are stored via sp_set_session_context.""" + audit_conn.set_session_context(tenant_id="ACME", correlation_id="abc-def") + ctx = audit_conn.get_session_context() + assert ctx["tenant_id"] == "ACME" + assert ctx["correlation_id"] == "abc-def" + + # Verify server-side + cursor = audit_conn.cursor() + try: + cursor.execute("SELECT SESSION_CONTEXT(N'tenant_id')") + assert cursor.fetchone()[0] == "ACME" + finally: + cursor.close() + + def test_session_context_merge(self, audit_conn): + """Successive calls merge values, not replace.""" + audit_conn.set_session_context(application_name="App1") + audit_conn.set_session_context(module_name="Mod1") + ctx = audit_conn.get_session_context() + assert ctx["application_name"] == "App1" + assert ctx["module_name"] == "Mod1" + + def test_session_context_overwrite(self, audit_conn): + """A second call with the same key overwrites the previous value.""" + audit_conn.set_session_context(action_name="First") + audit_conn.set_session_context(action_name="Second") + assert audit_conn.get_session_context()["action_name"] == "Second" + + def test_session_context_empty_call(self, audit_conn): + """Calling with no arguments is a silent no-op.""" + audit_conn.set_session_context() + assert audit_conn.get_session_context() == {} + + def test_session_context_clear_value(self, audit_conn): + """Setting a key to None clears it (sends NULL to the server).""" + audit_conn.set_session_context(user_id="99") + audit_conn.set_session_context(user_id=None) + assert "user_id" not in audit_conn.get_session_context() + + def test_session_context_read_only(self, audit_conn): + """read_only=True makes the key immutable for the session.""" + audit_conn.set_session_context(action_name="Locked", read_only=True) + # Attempting to change a read-only key should raise a DatabaseError + # from SQL Server (error 15664). + with pytest.raises(DatabaseError): + audit_conn.set_session_context(action_name="Changed") + + def test_session_context_read_only_clear_via_none(self, audit_conn): + """Setting a read-only key to None raises ProgrammingError.""" + audit_conn.set_session_context(action_name="Locked", read_only=True) + with pytest.raises(ProgrammingError): + audit_conn.set_session_context(action_name=None) + + def test_session_context_closed_connection_set(self, audit_conn): + """set_session_context raises InterfaceError on a closed connection.""" + audit_conn.close() + with pytest.raises(InterfaceError): + audit_conn.set_session_context(application_name="X") + + def test_session_context_closed_connection_get(self, audit_conn): + """get_session_context raises InterfaceError on a closed connection.""" + audit_conn.close() + with pytest.raises(InterfaceError): + audit_conn.get_session_context() + + def test_session_context_key_max_length(self, audit_conn): + """A key at exactly 128 characters is accepted.""" + key = "k" * 128 + audit_conn.set_session_context(**{key: "val"}) + ctx = audit_conn.get_session_context() + assert ctx[key] == "val" + + def test_session_context_key_one_over_max(self, audit_conn): + """A key at 129 characters is rejected.""" + with pytest.raises(ProgrammingError): + audit_conn.set_session_context(**{"k" * 129: "v"}) + + def test_session_context_key_too_long(self, audit_conn): + """Keys longer than 128 characters are rejected.""" + with pytest.raises(ProgrammingError): + audit_conn.set_session_context(**{"x" * 200: "v"}) + + def test_session_context_value_max_length(self, audit_conn): + """A value at exactly 4000 characters is accepted.""" + val = "v" * 4000 + audit_conn.set_session_context(user_id=val) + ctx = audit_conn.get_session_context() + assert ctx["user_id"] == val + + def test_session_context_value_one_over_max(self, audit_conn): + """A value at 4001 characters is rejected.""" + with pytest.raises(ProgrammingError): + audit_conn.set_session_context(user_id="v" * 4001) + + def test_session_context_non_string_value(self, audit_conn): + """Non-string values are rejected with ProgrammingError.""" + with pytest.raises(ProgrammingError): + audit_conn.set_session_context(user_id=123) # type: ignore[arg-type] + + def test_get_session_context_returns_fresh(self, audit_conn): + """get_session_context queries the server and returns a fresh dict each time.""" + audit_conn.set_session_context(application_name="Fresh") + ctx1 = audit_conn.get_session_context() + ctx2 = audit_conn.get_session_context() + assert ctx1 == ctx2 + assert ctx1 is not ctx2 # distinct dict objects + + def test_get_session_context_reflects_server(self, audit_conn): + """Getter fetches live values from the server, not stale cache.""" + audit_conn.set_session_context(user_id="original") + # Mutate the value directly via T-SQL (bypassing the Python API) + cursor = audit_conn.cursor() + try: + cursor.execute( + "EXEC sp_set_session_context @key=N'user_id', @value=N'mutated'" + ) + finally: + cursor.close() + # The getter should reflect the server-side mutation + ctx = audit_conn.get_session_context() + assert ctx["user_id"] == "mutated" + + # ---- clear_session_context tests ---- + + def test_clear_session_context_single_key(self, audit_conn): + """Clearing a single key removes it from the server.""" + audit_conn.set_session_context(user_id="1", module_name="Mod") + audit_conn.clear_session_context("user_id") + ctx = audit_conn.get_session_context() + assert "user_id" not in ctx + assert ctx["module_name"] == "Mod" + + def test_clear_session_context_multiple_keys(self, audit_conn): + """Clearing multiple keys removes them all.""" + audit_conn.set_session_context( + user_id="1", module_name="Mod", action_name="Act" + ) + audit_conn.clear_session_context("user_id", "action_name") + ctx = audit_conn.get_session_context() + assert "user_id" not in ctx + assert "action_name" not in ctx + assert ctx["module_name"] == "Mod" + + def test_clear_session_context_all(self, audit_conn): + """Calling with no args clears all non-read-only keys.""" + audit_conn.set_session_context(user_id="1", module_name="Mod") + audit_conn.clear_session_context() + assert audit_conn.get_session_context() == {} + + def test_clear_session_context_read_only_raises(self, audit_conn): + """Explicitly clearing a read-only key raises ProgrammingError.""" + audit_conn.set_session_context(user_id="locked", read_only=True) + with pytest.raises(ProgrammingError): + audit_conn.clear_session_context("user_id") + + def test_clear_session_context_all_skips_read_only(self, audit_conn): + """clear_session_context() without args skips read-only keys.""" + audit_conn.set_session_context(user_id="locked", read_only=True) + audit_conn.set_session_context(module_name="clearable") + audit_conn.clear_session_context() + ctx = audit_conn.get_session_context() + assert ctx["user_id"] == "locked" + assert "module_name" not in ctx + + def test_clear_session_context_closed(self, audit_conn): + """clear_session_context raises InterfaceError on a closed connection.""" + audit_conn.close() + with pytest.raises(InterfaceError): + audit_conn.clear_session_context("user_id") + + def test_clear_session_context_noop(self, audit_conn): + """Clearing when nothing has been set is a silent no-op.""" + audit_conn.clear_session_context() # should not raise + + # ---- Pool return tests ---- + + def test_pool_return_clears_context(self, conn_str): + """When pooling is enabled, close() clears session context server-side.""" + conn = connect(conn_str) + conn.set_session_context(application_name="PoolTest", module_name="Mod") + # Simulate pooling enabled + conn._pooling = True + conn.close() + # After close the cache should be empty + assert not getattr(conn, "_session_context", {}) + + def test_pool_return_skips_without_pooling(self, conn_str): + """Without pooling, close() does not attempt to clear session context.""" + conn = connect(conn_str) + conn.set_session_context(application_name="NoPools") + conn._pooling = False + conn.close() + # Cache is left as-is (object is closed, no pool reuse) + assert conn._session_context.get("application_name") == "NoPools" \ No newline at end of file