diff --git a/mssql_python/cursor.py b/mssql_python/cursor.py index cd55015e..244afa9c 100644 --- a/mssql_python/cursor.py +++ b/mssql_python/cursor.py @@ -161,7 +161,7 @@ def __init__(self, connection: "Connection", timeout: int = 0) -> None: self._skip_increment_for_next_fetch = ( False # Track if we need to skip incrementing the row index ) - self.messages = [] # Store diagnostic messages + self.messages: List[Tuple[str, str]] = [] # Store diagnostic messages def _is_unicode_string(self, param: str) -> bool: """ @@ -810,6 +810,25 @@ def _check_closed(self) -> None: ddbc_error="", ) + def _capture_diagnostics(self, ret: int) -> None: + """Append diagnostic messages to self.messages when the return code + indicates records may be present. + + Captures on SQL_SUCCESS_WITH_INFO (info/warning messages) and + SQL_NO_DATA (trailing diagnostics attached by SQLMoreResults, + e.g. a PRINT after the final result set). + + Skips SQL_SUCCESS to avoid the ~10 ms overhead of scanning the + driver's internal state when no records exist. SQL_ERROR is + handled separately by check_error() which extracts diagnostics + and raises. + """ + if self.hstmt and ret in ( + ddbc_sql_const.SQL_SUCCESS_WITH_INFO.value, + ddbc_sql_const.SQL_NO_DATA.value, + ): + self.messages.extend(ddbc_bindings.DDBCSQLGetAllDiagRecords(self.hstmt)) + def _ensure_pyarrow(self) -> Any: """ Import and return pyarrow or raise ImportError accordingly. @@ -1518,13 +1537,7 @@ def execute( # pylint: disable=too-many-locals,too-many-branches,too-many-state self._reset_cursor() raise - # Capture diagnostic messages only on SQL_SUCCESS_WITH_INFO. - # SQL_SUCCESS has no records — calling DDBCSQLGetAllDiagRecords on it - # costs ~10ms/call (driver scans internal state to find nothing). - # SQL_ERROR is already handled by check_error() above which extracts - # diagnostics and raises. - if ret == ddbc_sql_const.SQL_SUCCESS_WITH_INFO.value and self.hstmt: - self.messages.extend(ddbc_bindings.DDBCSQLGetAllDiagRecords(self.hstmt)) + self._capture_diagnostics(ret) self.last_executed_stmt = operation @@ -2773,12 +2786,16 @@ def batch_generator(): return pyarrow.RecordBatchReader.from_batches(schema, batch_generator()) - def nextset(self) -> Union[bool, None]: + def nextset(self) -> Optional[bool]: """ Skip to the next available result set. Returns: - True if there is another result set, None otherwise. + True if there is another result set, False otherwise. + Note: PEP 249 specifies True/None; we return True/False + for backward compatibility with existing callers and pyodbc + parity. The signature is Optional[bool] to keep the door + open for a future migration to True/None semantics. Raises: Error: If the previous call to execute did not produce any result set. @@ -2799,6 +2816,12 @@ def nextset(self) -> Union[bool, None]: ret = ddbc_bindings.DDBCSQLMoreResults(self.hstmt) check_error(ddbc_sql_const.SQL_HANDLE_STMT.value, self.hstmt, ret) + # Capture diagnostic messages (e.g. PRINT output) — handles both + # SQL_SUCCESS_WITH_INFO and SQL_NO_DATA (trailing PRINT after the + # final result set). Without this, messages from subsequent result + # sets are silently lost (GH-612). + self._capture_diagnostics(ret) + if ret == ddbc_sql_const.SQL_NO_DATA.value: logger.debug("nextset: No more result sets available") self._clear_rownumber() diff --git a/tests/test_004_cursor.py b/tests/test_004_cursor.py index 349bbc55..d39f42ae 100644 --- a/tests/test_004_cursor.py +++ b/tests/test_004_cursor.py @@ -6452,6 +6452,138 @@ def test_cursor_messages_with_error(cursor): assert "After error" in cursor.messages[0][1], "Message should be from after the error" +def test_cursor_messages_nextset_multiple_prints(cursor): + """Test that PRINT messages from subsequent result sets are captured via nextset(). + + Regression test for GH-612: PRINT messages after the first one were lost + because nextset() did not capture SQL_SUCCESS_WITH_INFO diagnostics. + """ + cursor.execute(""" + PRINT 'hi'; + PRINT 'ih'; + """) + + # First PRINT is captured by execute() + assert len(cursor.messages) == 1, "execute() should capture the first PRINT message" + assert "hi" in cursor.messages[0][1] + + # Advance to the next result set — should capture the second PRINT + assert cursor.nextset() + assert len(cursor.messages) == 1, "nextset() should capture the second PRINT message" + assert "ih" in cursor.messages[0][1] + + # No more result sets + assert not cursor.nextset() + + +def test_cursor_messages_nextset_print_with_select(cursor): + """Test PRINT messages interleaved with SELECT result sets via nextset(). + + Ensures messages are captured correctly when PRINT and SELECT are mixed. + Only messages collected from nextset() itself are checked so the test + fails if nextset() drops messages (even if fetchall() would mask it). + """ + cursor.execute(""" + PRINT 'before select'; + SELECT 1 AS val; + PRINT 'after select'; + """) + + # First PRINT captured by execute() + assert len(cursor.messages) >= 1 + assert "before select" in cursor.messages[0][1] + + nextset_messages = [] + all_rows = [] + + while cursor.nextset(): + # Collect only messages produced by nextset() — not by fetchall() + nextset_messages.extend(cursor.messages) + if cursor.description: + all_rows.extend(cursor.fetchall()) + + # Also collect messages from the final nextset() that returned False + # (trailing PRINT can attach to SQL_NO_DATA) + nextset_messages.extend(cursor.messages) + + # Verify the "after select" PRINT was captured by nextset(), not fetchall() + combined_text = " ".join(m[1] for m in nextset_messages) + assert "after select" in combined_text, "nextset() should capture the trailing PRINT message" + + # Verify the SELECT result was returned + assert len(all_rows) == 1 + assert all_rows[0][0] == 1 + + +def test_cursor_messages_nextset_three_prints(cursor): + """Test that three consecutive PRINT messages are all captured across nextset() calls.""" + cursor.execute(""" + PRINT 'msg1'; + PRINT 'msg2'; + PRINT 'msg3'; + """) + + # First PRINT captured by execute() + assert len(cursor.messages) == 1 + assert "msg1" in cursor.messages[0][1] + + # Second PRINT via nextset() + assert cursor.nextset() + assert len(cursor.messages) == 1 + assert "msg2" in cursor.messages[0][1] + + # Third PRINT via nextset() + assert cursor.nextset() + assert len(cursor.messages) == 1 + assert "msg3" in cursor.messages[0][1] + + # No more result sets + assert not cursor.nextset() + + +def test_cursor_messages_nextset_clears_previous(cursor): + """Test that nextset() clears messages from the previous result set.""" + cursor.execute(""" + PRINT 'first'; + PRINT 'second'; + """) + + assert len(cursor.messages) == 1 + assert "first" in cursor.messages[0][1] + + # After nextset(), messages should only contain the new message + assert cursor.nextset() + assert len(cursor.messages) == 1, "Previous messages should have been cleared" + assert "second" in cursor.messages[0][1] + assert not any("first" in m[1] for m in cursor.messages), "Old message should not persist" + + +def test_cursor_messages_nextset_trailing_print(cursor): + """Test that a trailing PRINT after the final SELECT is captured. + + The ODBC driver delivers the trailing PRINT as a separate result set + (SQL_SUCCESS_WITH_INFO), so nextset() returns True and captures the + message. A second nextset() then returns False (SQL_NO_DATA). + This is the most common customer pain point (GH-612). + """ + cursor.execute(""" + SELECT 1 AS val; + PRINT 'trailing'; + """) + + rows = cursor.fetchall() + assert len(rows) == 1 + assert rows[0][0] == 1 + + # The trailing PRINT is delivered as a separate result set + assert cursor.nextset() + assert len(cursor.messages) >= 1, "Trailing PRINT after final SELECT should be captured" + assert "trailing" in cursor.messages[0][1] + + # No more result sets + assert not cursor.nextset() + + def test_tables_setup(cursor, db_connection): """Create test objects for tables method testing""" try: