diff --git a/mssql_python/cursor.py b/mssql_python/cursor.py index fd9d7b32..a95e5cbf 100644 --- a/mssql_python/cursor.py +++ b/mssql_python/cursor.py @@ -25,7 +25,10 @@ from mssql_python import get_settings if TYPE_CHECKING: + import pyarrow # type: ignore from mssql_python.connection import Connection +else: + pyarrow = None # Constants for string handling MAX_INLINE_CHAR: int = ( @@ -2198,6 +2201,89 @@ def fetchall(self) -> List[Row]: # On error, don't increment rownumber - rethrow the error raise e + def arrow_batch(self, batch_size: int = 8192) -> "pyarrow.RecordBatch": + """ + Fetch a single pyarrow Record Batch of the specified size from the + query result set. + + Args: + batch_size: Maximum number of rows to fetch in the Record Batch. + + Returns: + A pyarrow RecordBatch object containing up to batch_size rows. + """ + self._check_closed() # Check if the cursor is closed + if not self._has_result_set and self.description: + self._reset_rownumber() + + try: + import pyarrow + except ImportError as e: + raise ImportError( + "pyarrow is required for arrow_batch(). Please install pyarrow." + ) from e + + capsules = [] + ret = ddbc_bindings.DDBCSQLFetchArrowBatch(self.hstmt, capsules, max(batch_size, 0)) + check_error(ddbc_sql_const.SQL_HANDLE_STMT.value, self.hstmt, ret) + + batch = pyarrow.RecordBatch._import_from_c_capsule(*capsules) + return batch + + def arrow(self, batch_size: int = 8192) -> "pyarrow.Table": + """ + Fetch the entire result as a pyarrow Table. + + Args: + batch_size: Size of the Record Batches which make up the Table. + + Returns: + A pyarrow Table containing all remaining rows from the result set. + """ + try: + import pyarrow + except ImportError as e: + raise ImportError("pyarrow is required for arrow(). Please install pyarrow.") from e + + batches: list["pyarrow.RecordBatch"] = [] + while True: + batch = self.arrow_batch(batch_size) + if batch.num_rows < batch_size or batch_size <= 0: + if not batches or batch.num_rows > 0: + batches.append(batch) + break + batches.append(batch) + return pyarrow.Table.from_batches(batches, schema=batches[0].schema) + + def arrow_reader(self, batch_size: int = 8192) -> "pyarrow.RecordBatchReader": + """ + Fetch the result as a pyarrow RecordBatchReader, which yields Record + Batches of the specified size until the current result set is + exhausted. + + Args: + batch_size: Size of the Record Batches produced by the reader. + + Returns: + A pyarrow RecordBatchReader for the result set. + """ + try: + import pyarrow + except ImportError as e: + raise ImportError( + "pyarrow is required for arrow_reader(). Please install pyarrow." + ) from e + + # Fetch schema without advancing cursor + schema_batch = self.arrow_batch(0) + schema = schema_batch.schema + + def batch_generator(): + while (batch := self.arrow_batch(batch_size)).num_rows > 0: + yield batch + + return pyarrow.RecordBatchReader.from_batches(schema, batch_generator()) + def nextset(self) -> Union[bool, None]: """ Skip to the next available result set. 
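For context (not part of the patch): a minimal usage sketch of the three new entry points added to `Cursor`. The connection string and query are illustrative only; any SQL Server connection works.

```python
import mssql_python

conn = mssql_python.connect("Server=localhost;Database=master;Trusted_Connection=yes;")
cursor = conn.cursor()

# One pyarrow.RecordBatch of at most 1,000 rows
batch = cursor.execute("SELECT name, object_id FROM sys.objects").arrow_batch(1000)

# The whole result set materialized as a pyarrow.Table (built from 8,192-row batches by default)
table = cursor.execute("SELECT name, object_id FROM sys.objects").arrow()

# A pyarrow.RecordBatchReader; batches are fetched lazily as the reader is consumed
reader = cursor.execute("SELECT name, object_id FROM sys.objects").arrow_reader(batch_size=4096)
for record_batch in reader:
    print(record_batch.num_rows)
```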
diff --git a/mssql_python/pybind/ddbc_bindings.cpp b/mssql_python/pybind/ddbc_bindings.cpp index 31cdc514..810c5b9c 100644 --- a/mssql_python/pybind/ddbc_bindings.cpp +++ b/mssql_python/pybind/ddbc_bindings.cpp
@@ -157,6 +157,82 @@ struct NumericData {
     }
 };
+struct ArrowArrayPrivateData {
+    std::unique_ptr<uint8_t[]> valid;
+
+    std::unique_ptr<uint8_t[]> uint8Val;
+    std::unique_ptr<int16_t[]> int16Val;
+    std::unique_ptr<int32_t[]> int32Val;
+    std::unique_ptr<int64_t[]> int64Val;
+    std::unique_ptr<double[]> float64Val;
+    std::unique_ptr<uint8_t[]> bitVal;
+    std::unique_ptr<int32_t[]> varVal;
+    std::unique_ptr<int32_t[]> dateVal;
+    std::unique_ptr<int64_t[]> tsMicroVal;
+    std::unique_ptr<int32_t[]> timeSecondVal;
+    std::unique_ptr<__int128_t[]> decimalVal;
+
+    std::vector<char> varData;
+
+    // First buffer is the validity bitmap.
+    // Second buffer is one of the value buffers above.
+    // Third buffer is the varData buffer for variable-length types.
+    std::array<const void*, 3> buffers;
+
+    // Points to one of the typed *Val buffers above. Since the buffer pointers
+    // don't change, this can be set once during batch initialization.
+    void* ptrValueBuffer;
+};
+
+struct ArrowSchemaPrivateData {
+    std::unique_ptr<char[]> name;
+    std::unique_ptr<char[]> format;
+};
+
+#ifndef ARROW_C_DATA_INTERFACE
+#define ARROW_C_DATA_INTERFACE
+
+#define ARROW_FLAG_DICTIONARY_ORDERED 1
+#define ARROW_FLAG_NULLABLE 2
+#define ARROW_FLAG_MAP_KEYS_SORTED 4
+
+struct ArrowSchema {
+    // Array type description
+    const char* format;
+    const char* name;
+    const char* metadata;
+    int64_t flags;
+    int64_t n_children;
+    struct ArrowSchema** children;
+    struct ArrowSchema* dictionary;
+
+    // Release callback
+    void (*release)(struct ArrowSchema*);
+    // Opaque producer-specific data.
+    // Only our child arrays set this, so we can give it the correct type.
+    ArrowSchemaPrivateData* private_data;
+};
+
+struct ArrowArray {
+    // Array data description
+    int64_t length;
+    int64_t null_count;
+    int64_t offset;
+    int64_t n_buffers;
+    int64_t n_children;
+    const void** buffers;
+    struct ArrowArray** children;
+    struct ArrowArray* dictionary;
+
+    // Release callback
+    void (*release)(struct ArrowArray*);
+    // Opaque producer-specific data.
+    // Only our child arrays set this, so we can give it the correct type.
+    ArrowArrayPrivateData* private_data;
+};
+
+#endif // ARROW_C_DATA_INTERFACE
+
 //-------------------------------------------------------------------------------------------------
 // Function pointer initialization
 //-------------------------------------------------------------------------------------------------
@@ -3926,6 +4002,1037 @@ SQLRETURN FetchMany_wrap(SqlHandlePtr StatementHandle, py::list& rows, int fetchSize
     return ret;
 }
+// GetDataVar - Progressively fetches variable-length column data using SQLGetData.
+//
+// Calls SQLGetData repeatedly, reallocating the buffer as needed, until all data is retrieved.
+// Handles both fixed-size and unknown-size (SQL_NO_TOTAL) responses from the driver.
+//
+// @param hStmt: Statement handle
+// @param colNumber: 1-based column index
+// @param cType: SQL C data type (SQL_C_CHAR, SQL_C_WCHAR, or SQL_C_BINARY)
+// @param dataVec: Reference to the vector that will hold the fetched data (resized as needed)
+// @param indicator: Pointer to the indicator value (SQL_NULL_DATA for NULL, or the data length)
+//
+// @return SQLRETURN: SQL_SUCCESS on success, or an error code on failure
+template <typename T>
+SQLRETURN GetDataVar(SQLHSTMT hStmt,
+                     SQLUSMALLINT colNumber,
+                     SQLSMALLINT cType,
+                     std::vector<T>& dataVec,
+                     SQLLEN* indicator) {
+    size_t start = 0;
+    size_t end = 0;
+
+    // Determine null terminator size based on data type
+    size_t sizeNullTerminator = 0;
+    switch (cType) {
+        case SQL_C_WCHAR:
+        case SQL_C_CHAR:
+            sizeNullTerminator = 1;
+            break;
+        case SQL_C_BINARY:
+            sizeNullTerminator = 0;
+            break;
+        default:
+            ThrowStdException("GetDataVar only supports SQL_C_CHAR, SQL_C_WCHAR, and SQL_C_BINARY");
+    }
+
+    // Ensure the initial buffer has space for at least the null terminator
+    if (dataVec.size() < sizeNullTerminator) {
+        dataVec.resize(sizeNullTerminator);
+    }
+
+    while (true) {
+        SQLLEN localInd = 0;
+        SQLRETURN ret = SQLGetData_ptr(
+            hStmt,
+            colNumber,
+            cType,
+            reinterpret_cast<SQLPOINTER>(dataVec.data() + start),
+            sizeof(T) * (dataVec.size() - start),  // Available buffer size from start position
+            &localInd
+        );
+
+        // Handle NULL data
+        if (localInd == SQL_NULL_DATA) {
+            *indicator = SQL_NULL_DATA;
+            return SQL_SUCCESS;
+        }
+
+        // Check for errors (excluding SQL_SUCCESS_WITH_INFO, which means more data is available)
+        if (ret == SQL_ERROR || ret == SQL_INVALID_HANDLE) {
+            return ret;
+        }
+
+        // SQL_SUCCESS or SQL_NO_DATA means we got all the data
+        if (ret == SQL_SUCCESS || ret == SQL_NO_DATA) {
+            if (localInd >= 0) {
+                *indicator = static_cast<SQLLEN>(start) * sizeof(T) + localInd;
+            } else {
+                *indicator = localInd;  // Preserve SQL_NO_TOTAL or other negative values
+            }
+            break;
+        }
+
+        // SQL_SUCCESS_WITH_INFO means the buffer was too small; continue fetching
+        if (ret == SQL_SUCCESS_WITH_INFO) {
+            // Determine how much more space we need
+            if (localInd < 0) {
+                // SQL_NO_TOTAL: driver doesn't know the total size, double the buffer
+                end = dataVec.size() * 2;
+            } else {
+                // Driver returned the total size: allocate exactly what we need
+                assert(localInd % sizeof(T) == 0);
+                end = start + static_cast<size_t>(localInd) / sizeof(T) + sizeNullTerminator;
+            }
+
+            // The next read starts where the null terminator would have been placed
+            start = dataVec.size() - sizeNullTerminator;
+
+            // Resize buffer for the next iteration
+            dataVec.resize(end);
+        } else {
+            // Unexpected return code
+            return ret;
+        }
+    }
+
+    return SQL_SUCCESS;
+}
+
+int32_t dateAsDayCount(SQLUSMALLINT year, SQLUSMALLINT month, SQLUSMALLINT day) {
+    // Convert SQL_DATE_STRUCT to Arrow Date32 (days since epoch)
+    std::tm tm_date = {};
+    tm_date.tm_year = year - 1900;  // tm_year is years since 1900
+    tm_date.tm_mon = month - 1;     // tm_mon is 0-11
+    tm_date.tm_mday = day;
+
+    std::time_t time_since_epoch = std::mktime(&tm_date);
+    if (time_since_epoch == -1) {
+        LOG("Failed to convert SQL_DATE_STRUCT to time_t");
+        ThrowStdException("Date conversion error");
+    }
+    // Sanity check against timezone issues. Since we only provide the date, this has to be true.
+    assert(time_since_epoch % 86400 == 0);
+    // Calculate days since epoch
+    return time_since_epoch / 86400;
+}
+
+SQLRETURN FetchArrowBatch_wrap(
+    SqlHandlePtr StatementHandle,
+    py::list& capsules,
+    ssize_t arrowBatchSize
+) {
+    ssize_t fetchSize = arrowBatchSize;
+    SQLRETURN ret;
+    SQLHSTMT hStmt = StatementHandle->get();
+    // Retrieve column count
+    SQLSMALLINT numCols = SQLNumResultCols_wrap(StatementHandle);
+    if (numCols <= 0) {
+        ThrowStdException("No active result set. Cannot fetch Arrow batch.");
+    }
+
+    // Retrieve column metadata
+    py::list columnNames;
+    ret = SQLDescribeCol_wrap(StatementHandle, columnNames);
+    if (!SQL_SUCCEEDED(ret)) {
+        LOG("Failed to get column descriptions");
+        return ret;
+    }
+
+    bool hasLobColumns = false;
+
+    std::vector<SQLSMALLINT> dataTypes(numCols);
+    std::vector<SQLULEN> columnSizes(numCols);
+    std::vector<bool> columnNullable(numCols);
+    std::vector<bool> columnVarLen(numCols, false);
+    std::vector<int64_t> nullCounts(numCols, 0);
+
+    std::vector<std::unique_ptr<ArrowArrayPrivateData>> arrowArrayPrivateData(numCols);
+    std::vector<std::unique_ptr<ArrowSchemaPrivateData>> arrowSchemaPrivateData(numCols);
+    for (SQLSMALLINT i = 0; i < numCols; i++) {
+        arrowArrayPrivateData[i] = std::make_unique<ArrowArrayPrivateData>();
+        auto& arrowColumnProducer = arrowArrayPrivateData[i];
+        arrowSchemaPrivateData[i] = std::make_unique<ArrowSchemaPrivateData>();
+
+        auto colMeta = columnNames[i].cast<py::dict>();
+        SQLSMALLINT dataType = colMeta["DataType"].cast<SQLSMALLINT>();
+        SQLULEN columnSize = colMeta["ColumnSize"].cast<SQLULEN>();
+        SQLSMALLINT nullable = colMeta["Nullable"].cast<SQLSMALLINT>();
+
+        dataTypes[i] = dataType;
+        columnSizes[i] = columnSize;
+        columnNullable[i] = (nullable != SQL_NO_NULLS);
+
+        if ((dataType == SQL_WVARCHAR || dataType == SQL_WLONGVARCHAR ||
+             dataType == SQL_VARCHAR || dataType == SQL_LONGVARCHAR ||
+             dataType == SQL_VARBINARY || dataType == SQL_LONGVARBINARY || dataType == SQL_SS_XML) &&
+            (columnSize == 0 || columnSize == SQL_NO_TOTAL || columnSize > SQL_MAX_LOB_SIZE)) {
+            hasLobColumns = true;
+            if (fetchSize > 1) {
+                fetchSize = 1;  // LOBs require row-by-row fetch
+            }
+        }
+
+        std::string columnName = colMeta["ColumnName"].cast<std::string>();
+        size_t nameLen = columnName.length() + 1;
+        arrowSchemaPrivateData[i]->name = std::make_unique<char[]>(nameLen);
+        std::memcpy(arrowSchemaPrivateData[i]->name.get(), columnName.c_str(), nameLen);
+
+        std::string format = "";
+        switch(dataType) {
+            case SQL_CHAR:
+            case SQL_VARCHAR:
+            case SQL_LONGVARCHAR:
+            case SQL_SS_XML:
+            case SQL_WCHAR:
+            case SQL_WVARCHAR:
+            case SQL_WLONGVARCHAR:
+            case SQL_GUID:
+                format = "u";
+                arrowColumnProducer->varVal = std::make_unique<int32_t[]>(arrowBatchSize + 1);
+                arrowColumnProducer->varData.resize(arrowBatchSize * 42);
+                columnVarLen[i] = true;
+                // start at offset 0
+                arrowColumnProducer->varVal[0] = 0;
+                arrowColumnProducer->ptrValueBuffer = arrowColumnProducer->varVal.get();
+                break;
+            case SQL_BINARY:
+            case SQL_VARBINARY:
+            case SQL_LONGVARBINARY:
+                format = "z";
+                arrowColumnProducer->varVal = std::make_unique<int32_t[]>(arrowBatchSize + 1);
+                arrowColumnProducer->varData.resize(arrowBatchSize * 42);
+                columnVarLen[i] = true;
+                // start at offset 0
+                arrowColumnProducer->varVal[0] = 0;
+                arrowColumnProducer->ptrValueBuffer = arrowColumnProducer->varVal.get();
+                break;
+            case SQL_TINYINT:
+                format = "C";
+                arrowColumnProducer->uint8Val = std::make_unique<uint8_t[]>(arrowBatchSize);
+                arrowColumnProducer->ptrValueBuffer = arrowColumnProducer->uint8Val.get();
+                break;
+            case SQL_SMALLINT:
+                format = "s";
+                arrowColumnProducer->int16Val = std::make_unique<int16_t[]>(arrowBatchSize);
+                arrowColumnProducer->ptrValueBuffer = arrowColumnProducer->int16Val.get();
+                break;
+            case SQL_INTEGER:
+                format = "i";
+                arrowColumnProducer->int32Val = std::make_unique<int32_t[]>(arrowBatchSize);
+                arrowColumnProducer->ptrValueBuffer = arrowColumnProducer->int32Val.get();
+                break;
+            case SQL_BIGINT:
+                format = "l";
+                arrowColumnProducer->int64Val = std::make_unique<int64_t[]>(arrowBatchSize);
+                arrowColumnProducer->ptrValueBuffer = arrowColumnProducer->int64Val.get();
+                break;
+            case SQL_REAL:
+            case SQL_FLOAT:
+            case SQL_DOUBLE:
+                format = "g";
+                arrowColumnProducer->float64Val = std::make_unique<double[]>(arrowBatchSize);
+                arrowColumnProducer->ptrValueBuffer = arrowColumnProducer->float64Val.get();
+                break;
+            case SQL_DECIMAL:
+            case SQL_NUMERIC: {
+                std::ostringstream formatStream;
+                formatStream << "d:" << columnSize << "," << colMeta["DecimalDigits"].cast<SQLSMALLINT>();
+                std::string formatStr = formatStream.str();
+                size_t formatLen = formatStr.length() + 1;
+                arrowSchemaPrivateData[i]->format = std::make_unique<char[]>(formatLen);
+                std::memcpy(arrowSchemaPrivateData[i]->format.get(), formatStr.c_str(), formatLen);
+                format = arrowSchemaPrivateData[i]->format.get();
+                arrowColumnProducer->decimalVal = std::make_unique<__int128_t[]>(arrowBatchSize);
+                arrowColumnProducer->ptrValueBuffer = arrowColumnProducer->decimalVal.get();
+                break;
+            }
+            case SQL_TIMESTAMP:
+            case SQL_TYPE_TIMESTAMP:
+            case SQL_DATETIME:
+                format = "tsu:";
+                arrowColumnProducer->tsMicroVal = std::make_unique<int64_t[]>(arrowBatchSize);
+                arrowColumnProducer->ptrValueBuffer = arrowColumnProducer->tsMicroVal.get();
+                break;
+            case SQL_SS_TIMESTAMPOFFSET:
+                format = "tsu:+00:00";
+                arrowColumnProducer->tsMicroVal = std::make_unique<int64_t[]>(arrowBatchSize);
+                arrowColumnProducer->ptrValueBuffer = arrowColumnProducer->tsMicroVal.get();
+                break;
+            case SQL_TYPE_DATE:
+                format = "tdD";
+                arrowColumnProducer->dateVal = std::make_unique<int32_t[]>(arrowBatchSize);
+                arrowColumnProducer->ptrValueBuffer = arrowColumnProducer->dateVal.get();
+                break;
+            case SQL_TIME:
+            case SQL_TYPE_TIME:
+            case SQL_SS_TIME2:
+                format = "tts";
+                arrowColumnProducer->timeSecondVal = std::make_unique<int32_t[]>(arrowBatchSize);
+                arrowColumnProducer->ptrValueBuffer = arrowColumnProducer->timeSecondVal.get();
+                break;
+            case SQL_BIT:
+                format = "b";
+                arrowColumnProducer->bitVal = std::make_unique<uint8_t[]>((arrowBatchSize + 7) / 8);
+                std::memset(arrowColumnProducer->bitVal.get(), 0, (arrowBatchSize + 7) / 8);
+                arrowColumnProducer->ptrValueBuffer = arrowColumnProducer->bitVal.get();
+                break;
+            default:
+                std::string columnName = colMeta["ColumnName"].cast<std::string>();
+                std::ostringstream errorString;
+                errorString << "Unsupported data type for Arrow batch fetch for column - " << columnName
+                            << ", Type - " << dataType << ", column ID - " << (i + 1);
+                LOG(errorString.str().c_str());
+                ThrowStdException(errorString.str());
+                break;
+        }
+
+        // Store the format string if it has not been stored yet.
+        // For non-decimal types, format points to a static string literal at this point.
+        if (!arrowSchemaPrivateData[i]->format) {
+            size_t formatLen = format.length() + 1;
+            arrowSchemaPrivateData[i]->format = std::make_unique<char[]>(formatLen);
+            std::memcpy(arrowSchemaPrivateData[i]->format.get(), format.c_str(), formatLen);
+        }
+
+        arrowColumnProducer->valid = std::make_unique<uint8_t[]>((arrowBatchSize + 7) / 8);
+        // Initialize the validity bitmap to all valid
+        std::memset(arrowColumnProducer->valid.get(), 0xFF, (arrowBatchSize + 7) / 8);
+    }
+
+    if (fetchSize > 1) {
+        // An overly large fetch size doesn't seem to help performance
+        SQLSMALLINT searchStart = 64;
+        if (arrowBatchSize < 64) {
+            searchStart = static_cast<SQLSMALLINT>(arrowBatchSize);
+        }
+        for (SQLSMALLINT maybeNewSize = searchStart; maybeNewSize >= 1; maybeNewSize -= 1) {
+            if (arrowBatchSize % maybeNewSize == 0) {
+                fetchSize = maybeNewSize;
+                break;
+            }
+        }
+    }
+
+    // Initialize column buffers
+    ColumnBuffers buffers(numCols, fetchSize);
+
+    if (!hasLobColumns && fetchSize > 0) {
+        // Bind columns
+        ret = SQLBindColums(hStmt, buffers, columnNames, numCols, fetchSize);
+        if (!SQL_SUCCEEDED(ret)) {
+            LOG("Error when binding columns");
+            return ret;
+        }
+    }
+
+    SQLULEN numRowsFetched;
+    SQLSetStmtAttr_ptr(hStmt, SQL_ATTR_ROW_ARRAY_SIZE, (SQLPOINTER)(intptr_t)fetchSize, 0);
+    SQLSetStmtAttr_ptr(hStmt, SQL_ATTR_ROWS_FETCHED_PTR, &numRowsFetched, 0);
+
+    size_t idxRowArrow = 0;
+    // arrowBatchSize % fetchSize == 0 ensures that any followup (even non-Arrow) fetches
+    // start with a fresh batch
+    assert(fetchSize == 0 || arrowBatchSize % fetchSize == 0);
+    assert(fetchSize <= arrowBatchSize);
+
+    while (idxRowArrow < arrowBatchSize) {
+        ret = SQLFetch_ptr(hStmt);
+        if (ret == SQL_NO_DATA) {
+            ret = SQL_SUCCESS;  // Normal completion
+            break;
+        }
+        if (!SQL_SUCCEEDED(ret)) {
+            LOG("Error while fetching rows in batches");
+            return ret;
+        }
+        // numRowsFetched is the SQL_ATTR_ROWS_FETCHED_PTR attribute.
+ // It'll be populated by SQLFetch + assert(numRowsFetched + idxRowArrow <= static_cast(arrowBatchSize)); + for (SQLULEN idxRowSql = 0; idxRowSql < numRowsFetched; idxRowSql++) { + for (SQLUSMALLINT idxCol = 0; idxCol < numCols; idxCol++) { + auto& arrowColumnProducer = arrowArrayPrivateData[idxCol]; + auto dataType = dataTypes[idxCol]; + auto columnSize = columnSizes[idxCol]; + + if (hasLobColumns) { + assert(idxRowSql == 0 && "GetData only works one row at a time"); + + switch(dataType) { + case SQL_BINARY: + case SQL_VARBINARY: + case SQL_LONGVARBINARY: { + ret = GetDataVar( + hStmt, + idxCol + 1, + SQL_C_BINARY, + buffers.charBuffers[idxCol], + buffers.indicators[idxCol].data() + ); + if (!SQL_SUCCEEDED(ret)) { + LOG("Error fetching LOB for column %d", idxCol + 1); + return ret; + } + break; + } + case SQL_CHAR: + case SQL_VARCHAR: + case SQL_LONGVARCHAR: { + ret = GetDataVar( + hStmt, + idxCol + 1, + SQL_C_CHAR, + buffers.charBuffers[idxCol], + buffers.indicators[idxCol].data() + ); + if (!SQL_SUCCEEDED(ret)) { + LOG("Error fetching LOB for column %d", idxCol + 1); + return ret; + } + break; + } + case SQL_SS_XML: + case SQL_WCHAR: + case SQL_WVARCHAR: + case SQL_WLONGVARCHAR: { + ret = GetDataVar( + hStmt, + idxCol + 1, + SQL_C_WCHAR, + buffers.wcharBuffers[idxCol], + buffers.indicators[idxCol].data() + ); + if (!SQL_SUCCEEDED(ret)) { + LOG("Error fetching binary data for column %d", idxCol + 1); + return ret; + } + break; + } + case SQL_INTEGER: { + buffers.intBuffers[idxCol].resize(1); + ret = SQLGetData_ptr( + hStmt, idxCol + 1, SQL_C_SLONG, + buffers.intBuffers[idxCol].data(), + sizeof(SQLINTEGER), + buffers.indicators[idxCol].data() + ); + if (!SQL_SUCCEEDED(ret)) { + LOG("Error fetching SLONG data for column %d", idxCol + 1); + return ret; + } + break; + } + case SQL_SMALLINT: { + buffers.smallIntBuffers[idxCol].resize(1); + ret = SQLGetData_ptr( + hStmt, idxCol + 1, SQL_C_SSHORT, + buffers.smallIntBuffers[idxCol].data(), + sizeof(SQLSMALLINT), + buffers.indicators[idxCol].data() + ); + if (!SQL_SUCCEEDED(ret)) { + LOG("Error fetching SSHORT data for column %d", idxCol + 1); + return ret; + } + break; + } + case SQL_TINYINT: { + buffers.charBuffers[idxCol].resize(1); + ret = SQLGetData_ptr( + hStmt, idxCol + 1, SQL_C_TINYINT, + buffers.charBuffers[idxCol].data(), + sizeof(SQLCHAR), + buffers.indicators[idxCol].data() + ); + if (!SQL_SUCCEEDED(ret)) { + LOG("Error fetching TINYINT data for column %d", idxCol + 1); + return ret; + } + break; + } + case SQL_BIT: { + buffers.charBuffers[idxCol].resize(1); + ret = SQLGetData_ptr( + hStmt, idxCol + 1, SQL_C_BIT, + buffers.charBuffers[idxCol].data(), + sizeof(SQLCHAR), + buffers.indicators[idxCol].data() + ); + if (!SQL_SUCCEEDED(ret)) { + LOG("Error fetching BIT data for column %d", idxCol + 1); + return ret; + } + break; + } + case SQL_REAL: { + buffers.realBuffers[idxCol].resize(1); + ret = SQLGetData_ptr( + hStmt, idxCol + 1, SQL_C_FLOAT, + buffers.realBuffers[idxCol].data(), + sizeof(SQLREAL), + buffers.indicators[idxCol].data() + ); + if (!SQL_SUCCEEDED(ret)) { + LOG("Error fetching FLOAT data for column %d", idxCol + 1); + return ret; + } + break; + } + case SQL_DECIMAL: + case SQL_NUMERIC: { + buffers.charBuffers[idxCol].resize(MAX_DIGITS_IN_NUMERIC); + ret = SQLGetData_ptr( + hStmt, idxCol + 1, SQL_C_CHAR, + buffers.charBuffers[idxCol].data(), + MAX_DIGITS_IN_NUMERIC * sizeof(SQLCHAR), + buffers.indicators[idxCol].data() + ); + if (!SQL_SUCCEEDED(ret)) { + LOG("Error fetching CHAR data for column %d", idxCol + 1); + return 
ret; + } + break; + } + case SQL_DOUBLE: + case SQL_FLOAT: { + buffers.doubleBuffers[idxCol].resize(1); + ret = SQLGetData_ptr( + hStmt, idxCol + 1, SQL_C_DOUBLE, + buffers.doubleBuffers[idxCol].data(), + sizeof(SQLDOUBLE), + buffers.indicators[idxCol].data() + ); + if (!SQL_SUCCEEDED(ret)) { + LOG("Error fetching DOUBLE data for column %d", idxCol + 1); + return ret; + } + break; + } + case SQL_TIMESTAMP: + case SQL_TYPE_TIMESTAMP: + case SQL_DATETIME: { + buffers.timestampBuffers[idxCol].resize(1); + ret = SQLGetData_ptr( + hStmt, idxCol + 1, SQL_C_TYPE_TIMESTAMP, + buffers.timestampBuffers[idxCol].data(), + sizeof(SQL_TIMESTAMP_STRUCT), + buffers.indicators[idxCol].data() + ); + if (!SQL_SUCCEEDED(ret)) { + LOG("Error fetching TYPE_TIMESTAMP data for column %d", idxCol + 1); + return ret; + } + break; + } + case SQL_BIGINT: { + buffers.bigIntBuffers[idxCol].resize(1); + ret = SQLGetData_ptr( + hStmt, idxCol + 1, SQL_C_SBIGINT, + buffers.bigIntBuffers[idxCol].data(), + sizeof(SQLBIGINT), + buffers.indicators[idxCol].data() + ); + if (!SQL_SUCCEEDED(ret)) { + LOG("Error fetching SBIGINT data for column %d", idxCol + 1); + return ret; + } + break; + } + case SQL_TYPE_DATE: { + buffers.dateBuffers[idxCol].resize(1); + ret = SQLGetData_ptr( + hStmt, idxCol + 1, SQL_C_TYPE_DATE, + buffers.dateBuffers[idxCol].data(), + sizeof(SQL_DATE_STRUCT), + buffers.indicators[idxCol].data() + ); + if (!SQL_SUCCEEDED(ret)) { + LOG("Error fetching TYPE_DATE data for column %d", idxCol + 1); + return ret; + } + break; + } + case SQL_TIME: + case SQL_TYPE_TIME: + case SQL_SS_TIME2: { + buffers.timeBuffers[idxCol].resize(1); + ret = SQLGetData_ptr( + hStmt, idxCol + 1, SQL_C_TYPE_TIME, + buffers.timeBuffers[idxCol].data(), + sizeof(SQL_TIME_STRUCT), + buffers.indicators[idxCol].data() + ); + if (!SQL_SUCCEEDED(ret)) { + LOG("Error fetching TYPE_TIME data for column %d", idxCol + 1); + return ret; + } + break; + } + case SQL_GUID: { + buffers.guidBuffers[idxCol].resize(1); + ret = SQLGetData_ptr( + hStmt, idxCol + 1, SQL_C_GUID, + buffers.guidBuffers[idxCol].data(), + sizeof(SQLGUID), + buffers.indicators[idxCol].data() + ); + if (!SQL_SUCCEEDED(ret)) { + LOG("Error fetching GUID data for column %d", idxCol + 1); + return ret; + } + break; + } + case SQL_SS_TIMESTAMPOFFSET: { + buffers.datetimeoffsetBuffers[idxCol].resize(1); + ret = SQLGetData_ptr( + hStmt, idxCol + 1, SQL_C_SS_TIMESTAMPOFFSET, + buffers.datetimeoffsetBuffers[idxCol].data(), + sizeof(DateTimeOffset), + buffers.indicators[idxCol].data() + ); + if (!SQL_SUCCEEDED(ret)) { + LOG("Error fetching SS_TIMESTAMPOFFSET data for column %d", idxCol + 1); + return ret; + } + break; + } + default: { + std::ostringstream errorString; + errorString << "Unsupported data type for column ID - " << (idxCol + 1) + << ", Type - " << dataType; + LOG("SQLGetData: %s", errorString.str().c_str()); + ThrowStdException(errorString.str()); + break; + } + } + } + + SQLLEN dataLen = buffers.indicators[idxCol][idxRowSql]; + + if (dataLen == SQL_NULL_DATA) { + // Mark as null in validity bitmap + size_t bytePos = idxRowArrow / 8; + size_t bitPos = idxRowArrow % 8; + arrowColumnProducer->valid[bytePos] &= ~(1 << bitPos); + + // Value buffer for variable length data types needs to be set appropriately + // as it will be used by the next non null value + switch (dataType) + { + case SQL_CHAR: + case SQL_VARCHAR: + case SQL_LONGVARCHAR: + case SQL_SS_XML: + case SQL_WCHAR: + case SQL_WVARCHAR: + case SQL_WLONGVARCHAR: + case SQL_GUID: + case SQL_BINARY: + case SQL_VARBINARY: + 
case SQL_LONGVARBINARY: + arrowColumnProducer->varVal[idxRowArrow + 1] = arrowColumnProducer->varVal[idxRowArrow]; + break; + default: + break; + } + + nullCounts[idxCol] += 1; + continue; + } else if (dataLen < 0) { + // Negative value is unexpected, log column index, SQL type & raise exception + LOG("Unexpected negative data length. Column ID - {}, SQL Type - {}, Data Length - {}", idxCol + 1, dataType, dataLen); + ThrowStdException("Unexpected negative data length."); + } + + switch (dataType) { + case SQL_BINARY: + case SQL_VARBINARY: + case SQL_LONGVARBINARY: { + uint64_t fetchBufferSize = columnSize /* bytes are not null terminated */; + auto target_vec = &arrowColumnProducer->varData; + auto start = arrowColumnProducer->varVal[idxRowArrow]; + while (target_vec->size() < start + dataLen) { + target_vec->resize(target_vec->size() * 2); + } + + std::memcpy(&(*target_vec)[start], &buffers.charBuffers[idxCol][idxRowSql * fetchBufferSize], dataLen); + arrowColumnProducer->varVal[idxRowArrow + 1] = start + dataLen; + break; + } + case SQL_CHAR: + case SQL_VARCHAR: + case SQL_LONGVARCHAR: { + uint64_t fetchBufferSize = columnSize + 1 /* null-termination */; + auto target_vec = &arrowColumnProducer->varData; + auto start = arrowColumnProducer->varVal[idxRowArrow]; + while (target_vec->size() < start + dataLen) { + target_vec->resize(target_vec->size() * 2); + } + + std::memcpy(&(*target_vec)[start], &buffers.charBuffers[idxCol][idxRowSql * fetchBufferSize], dataLen); + arrowColumnProducer->varVal[idxRowArrow + 1] = start + dataLen; + break; + } + case SQL_SS_XML: + case SQL_WCHAR: + case SQL_WVARCHAR: + case SQL_WLONGVARCHAR: { + assert(dataLen % sizeof(SQLWCHAR) == 0); + auto dataLenW = dataLen / sizeof(SQLWCHAR); + auto wcharSource = &buffers.wcharBuffers[idxCol][idxRowSql * (columnSize + 1)]; + auto start = arrowColumnProducer->varVal[idxRowArrow]; + auto target_vec = &arrowColumnProducer->varData; +#if defined(_WIN32) + // Convert wide string + int dataLenConverted = WideCharToMultiByte(CP_UTF8, 0, wcharSource, dataLenW, NULL, 0, NULL, NULL); + while (target_vec->size() < start + dataLenConverted) { + target_vec->resize(target_vec->size() * 2); + } + WideCharToMultiByte(CP_UTF8, 0, wcharSource, dataLenW, &(*target_vec)[start], dataLenConverted, NULL, NULL); + arrowColumnProducer->varVal[idxRowArrow + 1] = start + dataLenConverted; +#else + // On Unix, use the SQLWCHARToWString utility and then convert to UTF-8 + std::string utf8str = WideToUTF8(SQLWCHARToWString(wcharSource, dataLenW)); + while (target_vec->size() < start + utf8str.size()) { + target_vec->resize(target_vec->size() * 2); + } + std::memcpy(&(*target_vec)[start], utf8str.data(), utf8str.size()); + arrowColumnProducer->varVal[idxRowArrow + 1] = start + utf8str.size(); +#endif + break; + } + case SQL_GUID: { + // GUID is stored as a 36-character string in Arrow (e.g., "550e8400-e29b-41d4-a716-446655440000") + // Each GUID is exactly 36 bytes in UTF-8 + auto target_vec = &arrowColumnProducer->varData; + auto start = arrowColumnProducer->varVal[idxRowArrow]; + + // Ensure buffer has space for the GUID string + null terminator + while (target_vec->size() < start + 37) { + target_vec->resize(target_vec->size() * 2); + } + + // Get the GUID from the buffer + const SQLGUID& guidValue = buffers.guidBuffers[idxCol][idxRowSql]; + + // Convert GUID to string format: xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx + snprintf(reinterpret_cast(&target_vec->data()[start]), 37, + "%08X-%04X-%04X-%02X%02X-%02X%02X%02X%02X%02X%02X", + guidValue.Data1, + 
guidValue.Data2, + guidValue.Data3, + guidValue.Data4[0], guidValue.Data4[1], + guidValue.Data4[2], guidValue.Data4[3], + guidValue.Data4[4], guidValue.Data4[5], + guidValue.Data4[6], guidValue.Data4[7]); + + // Update offset for next row, ignoring null terminator + arrowColumnProducer->varVal[idxRowArrow + 1] = start + 36; + break; + } + case SQL_TINYINT: + arrowColumnProducer->uint8Val[idxRowArrow] = buffers.charBuffers[idxCol][idxRowSql]; + break; + case SQL_SMALLINT: + arrowColumnProducer->int16Val[idxRowArrow] = buffers.smallIntBuffers[idxCol][idxRowSql]; + break; + case SQL_INTEGER: + arrowColumnProducer->int32Val[idxRowArrow] = buffers.intBuffers[idxCol][idxRowSql]; + break; + case SQL_BIGINT: + arrowColumnProducer->int64Val[idxRowArrow] = buffers.bigIntBuffers[idxCol][idxRowSql]; + break; + case SQL_REAL: + case SQL_FLOAT: + case SQL_DOUBLE: + arrowColumnProducer->float64Val[idxRowArrow] = buffers.doubleBuffers[idxCol][idxRowSql]; + break; + case SQL_DECIMAL: + case SQL_NUMERIC: { + assert(dataLen <= MAX_DIGITS_IN_NUMERIC); + __int128_t decimalValue = 0; + auto start = idxRowSql * MAX_DIGITS_IN_NUMERIC; + int sign = 1; + for (SQLULEN idx = start; idx < start + dataLen; idx++) { + char digitChar = buffers.charBuffers[idxCol][idx]; + if (digitChar == '-') { + sign = -1; + } else if (digitChar >= '0' && digitChar <= '9') { + decimalValue = decimalValue * 10 + (digitChar - '0'); + } + } + arrowColumnProducer->decimalVal[idxRowArrow] = decimalValue * sign; + break; + } + case SQL_TIMESTAMP: + case SQL_TYPE_TIMESTAMP: + case SQL_DATETIME: { + SQL_TIMESTAMP_STRUCT sql_value = buffers.timestampBuffers[idxCol][idxRowSql]; + int64_t days = dateAsDayCount( + sql_value.year, + sql_value.month, + sql_value.day + ); + arrowColumnProducer->tsMicroVal[idxRowArrow] = + days * 86400 * 1000000 + + static_cast(sql_value.hour) * 3600 * 1000000 + + static_cast(sql_value.minute) * 60 * 1000000 + + static_cast(sql_value.second) * 1000000 + + static_cast(sql_value.fraction) / 1000; + break; + } + case SQL_SS_TIMESTAMPOFFSET: { + DateTimeOffset sql_value = buffers.datetimeoffsetBuffers[idxCol][idxRowSql]; + int64_t days = dateAsDayCount( + sql_value.year, + sql_value.month, + sql_value.day + ); + arrowColumnProducer->tsMicroVal[idxRowArrow] = + days * 86400 * 1000000 + + (static_cast(sql_value.hour) - static_cast(sql_value.timezone_hour)) * 3600 * 1000000 + + (static_cast(sql_value.minute) - static_cast(sql_value.timezone_minute)) * 60 * 1000000 + + static_cast(sql_value.second) * 1000000 + + static_cast(sql_value.fraction) / 1000; + break; + } + case SQL_TYPE_DATE: + arrowColumnProducer->dateVal[idxRowArrow] = dateAsDayCount( + buffers.dateBuffers[idxCol][idxRowSql].year, + buffers.dateBuffers[idxCol][idxRowSql].month, + buffers.dateBuffers[idxCol][idxRowSql].day + ); + break; + case SQL_TIME: + case SQL_TYPE_TIME: + case SQL_SS_TIME2: { + // NOTE: SQL_SS_TIME2 supports fractional seconds, but SQL_C_TYPE_TIME does not. + // To fully support SQL_SS_TIME2, the corresponding c-type should be used. 
+                        const SQL_TIME_STRUCT& timeValue = buffers.timeBuffers[idxCol][idxRowSql];
+                        arrowColumnProducer->timeSecondVal[idxRowArrow] =
+                            static_cast<int32_t>(timeValue.hour) * 3600 +
+                            static_cast<int32_t>(timeValue.minute) * 60 +
+                            static_cast<int32_t>(timeValue.second);
+                        break;
+                    }
+                    case SQL_BIT: {
+                        // SQL_BIT is stored as a single bit in Arrow's bitmap format
+                        // Get the boolean value from the buffer
+                        bool bitValue = buffers.charBuffers[idxCol][idxRowSql] != 0;
+
+                        // Set the bit in the Arrow bitmap
+                        size_t byteIndex = idxRowArrow / 8;
+                        size_t bitIndex = idxRowArrow % 8;
+
+                        if (bitValue) {
+                            // Set bit to 1
+                            arrowColumnProducer->bitVal[byteIndex] |= (1 << bitIndex);
+                        } else {
+                            // Clear bit to 0
+                            arrowColumnProducer->bitVal[byteIndex] &= ~(1 << bitIndex);
+                        }
+                        break;
+                    }
+                    default: {
+                        std::ostringstream errorString;
+                        errorString << "Unsupported data type for column ID - " << (idxCol + 1)
+                                    << ", Type - " << dataType;
+                        LOG(errorString.str().c_str());
+                        ThrowStdException(errorString.str());
+                        break;
+                    }
+                }
+            }
+            idxRowArrow++;
+        }
+    }
+
+    // Reset attributes before returning to avoid using stack pointers later
+    SQLSetStmtAttr_ptr(hStmt, SQL_ATTR_ROW_ARRAY_SIZE, (SQLPOINTER)1, 0);
+    SQLSetStmtAttr_ptr(hStmt, SQL_ATTR_ROWS_FETCHED_PTR, NULL, 0);
+
+    // Transfer ownership of buffers to the batch ArrowSchema
+    // First, allocate memory for the necessary structures
+    auto arrowSchemaBatch = std::make_unique<ArrowSchema>();
+
+    auto arrowSchemaBatchChildren = std::make_unique<ArrowSchema*[]>(numCols);
+    auto arrowSchemaBatchChildPointers = std::make_unique<std::unique_ptr<ArrowSchema>[]>(numCols);
+    for (SQLSMALLINT i = 0; i < numCols; i++) {
+        arrowSchemaBatchChildPointers[i] = std::make_unique<ArrowSchema>();
+    }
+
+    // Second, transfer ownership to arrowSchemaBatch.
+    // No unhandled exceptions may occur until the pycapsule owns arrowSchemaBatch, to avoid memory leaks.
+
+    for (SQLSMALLINT i = 0; i < numCols; i++) {
+        *arrowSchemaBatchChildPointers[i] = {
+            .format = arrowSchemaPrivateData[i]->format.get(),
+            .name = arrowSchemaPrivateData[i]->name.get(),
+            .metadata = nullptr,
+            .flags = static_cast<int64_t>(columnNullable[i] ? ARROW_FLAG_NULLABLE : 0),
+            .n_children = 0,
+            .children = nullptr,
+            .dictionary = nullptr,
+            .release = [](ArrowSchema* schema) {
+                assert(schema != nullptr);
+                assert(schema->release != nullptr);
+                assert(schema->private_data != nullptr);
+                assert(schema->children == nullptr && schema->n_children == 0);
+                delete schema->private_data;  // Frees format and name
+                schema->release = nullptr;
+            },
+            .private_data = arrowSchemaPrivateData[i].release(),
+        };
+    }
+
+    for (SQLSMALLINT i = 0; i < numCols; i++) {
+        arrowSchemaBatchChildren[i] = arrowSchemaBatchChildPointers[i].release();
+    }
+
+    *arrowSchemaBatch = ArrowSchema{
+        .format = "+s",
+        .name = "",
+        .metadata = nullptr,
+        .flags = 0,
+        .n_children = numCols,
+        .children = arrowSchemaBatchChildren.release(),
+        .dictionary = nullptr,
+        .release = [](ArrowSchema* schema) {
+            // format and name are string literals, no need to free
+            assert(schema != nullptr);
+            assert(schema->release != nullptr);
+            assert(schema->private_data == nullptr);
+            assert(schema->children != nullptr);
+            assert(schema->n_children > 0);
+            for (int64_t i = 0; i < schema->n_children; ++i) {
+                if (schema->children[i]) {
+                    if (schema->children[i]->release) {
+                        schema->children[i]->release(schema->children[i]);
+                    }
+                    delete schema->children[i];
+                }
+            }
+            delete[] schema->children;
+            schema->release = nullptr;
+        },
+        .private_data = nullptr,
+    };
+
+    // Finally, transfer ownership of arrowSchemaBatch and its pointer to the pycapsule
+    py::capsule arrowSchemaBatchCapsule;
+    try {
+        arrowSchemaBatchCapsule = py::capsule(arrowSchemaBatch.get(), "arrow_schema", [](void* ptr) {
+            auto arrowSchema = static_cast<ArrowSchema*>(ptr);
+            if (arrowSchema->release) {
+                arrowSchema->release(arrowSchema);
+            }
+            delete arrowSchema;
+        });
+    } catch (...) {
+        arrowSchemaBatch->release(arrowSchemaBatch.get());
+        throw;
+    }
+    arrowSchemaBatch.release();
+    capsules.append(arrowSchemaBatchCapsule);
+
+    // Transfer ownership of buffers to the batch ArrowArray
+    // First, allocate memory for the necessary structures
+    auto arrowArrayBatch = std::make_unique<ArrowArray>();
+
+    auto arrowArrayBatchBuffers = std::make_unique<const void*[]>(1);
+    arrowArrayBatchBuffers[0] = nullptr;
+
+    auto arrowArrayBatchChildren = std::make_unique<ArrowArray*[]>(numCols);
+    auto arrowArrayBatchChildPointers = std::make_unique<std::unique_ptr<ArrowArray>[]>(numCols);
+    for (SQLSMALLINT i = 0; i < numCols; i++) {
+        arrowArrayBatchChildPointers[i] = std::make_unique<ArrowArray>();
+    }
+
+    // Second, transfer ownership to arrowArrayBatch.
+    // No unhandled exceptions may occur until the pycapsule owns arrowArrayBatch, to avoid memory leaks.
+
+    for (SQLUSMALLINT col = 0; col < numCols; col++) {
+        auto dataType = dataTypes[col];
+        arrowArrayPrivateData[col]->buffers[0] = arrowArrayPrivateData[col]->valid.get();
+        arrowArrayPrivateData[col]->buffers[1] = arrowArrayPrivateData[col]->ptrValueBuffer;
+        arrowArrayPrivateData[col]->buffers[2] = arrowArrayPrivateData[col]->varData.data();
+
+        *arrowArrayBatchChildPointers[col] = {
+            .length = static_cast<int64_t>(idxRowArrow),
+            .null_count = nullCounts[col],
+            .offset = 0,
+            .n_buffers = columnVarLen[col] ?
3 : 2, + .n_children = 0, + .buffers = (const void**)arrowArrayPrivateData[col]->buffers.data(), + .children = nullptr, + .release = [](ArrowArray* array) { + assert(array != nullptr); + assert(array->private_data != nullptr); + assert(array->release != nullptr); + assert(array->children == nullptr); + assert(array->n_children == 0); + delete array->private_data; // Frees all buffer entries + assert(array->buffers != nullptr); + array->release = nullptr; + }, + .private_data = arrowArrayPrivateData[col].release(), + }; + } + + for (SQLSMALLINT i = 0; i < numCols; i++) { + arrowArrayBatchChildren[i] = arrowArrayBatchChildPointers[i].release(); + } + + *arrowArrayBatch = ArrowArray{ + .length = static_cast(idxRowArrow), + .n_buffers = 1, + .n_children = numCols, + .buffers = arrowArrayBatchBuffers.release(), + .children = arrowArrayBatchChildren.release(), + .release = [](ArrowArray* array) { + assert(array != nullptr); + assert(array->private_data == nullptr); + assert(array->release != nullptr); + assert(array->children != nullptr); + assert(array->n_children > 0); + for (int64_t i = 0; i < array->n_children; ++i) { + if (array->children[i]) { + if (array->children[i]->release) { + array->children[i]->release(array->children[i]); + } + delete array->children[i]; + } + } + delete[] array->children; + assert(array->buffers != nullptr); + assert(array->n_buffers == 1); + assert(array->buffers[0] == nullptr); + delete[] array->buffers; + array->release = nullptr; + }, + }; + + // Finally, transfer ownership of arrowArrayBatch and its pointer to pycapsule + py::capsule arrowArrayBatchCapsule; + try { + arrowArrayBatchCapsule = py::capsule(arrowArrayBatch.get(), "arrow_array", [](void* ptr) { + auto arrowArray = static_cast(ptr); + if (arrowArray->release) { + arrowArray->release(arrowArray); + } + delete arrowArray; + }); + } catch (...) { + arrowArrayBatch->release(arrowArrayBatch.get()); + throw; + } + arrowArrayBatch.release(); + capsules.append(arrowArrayBatchCapsule); + + return ret; +} + // FetchAll_wrap - Fetches all rows of data from the result set. 
// // @param StatementHandle: Handle to the statement from which data is to be @@ -4232,6 +5339,7 @@ PYBIND11_MODULE(ddbc_bindings, m) { m.def("DDBCSQLFetchMany", &FetchMany_wrap, py::arg("StatementHandle"), py::arg("rows"), py::arg("fetchSize") = 1, "Fetch many rows from the result set"); m.def("DDBCSQLFetchAll", &FetchAll_wrap, "Fetch all rows from the result set"); + m.def("DDBCSQLFetchArrowBatch", &FetchArrowBatch_wrap, "Fetch an arrow batch of given length from the result set"); m.def("DDBCSQLFreeHandle", &SQLFreeHandle_wrap, "Free a handle"); m.def("DDBCSQLCheckError", &SQLCheckError_Wrap, "Check for driver errors"); m.def("DDBCSQLGetAllDiagRecords", &SQLGetAllDiagRecords, diff --git a/requirements.txt b/requirements.txt index 0951f7d0..4cd60771 100644 --- a/requirements.txt +++ b/requirements.txt @@ -4,6 +4,7 @@ pytest-cov coverage unittest-xml-reporting psutil +pyarrow # Build dependencies pybind11 diff --git a/tests/test_004_cursor.py b/tests/test_004_cursor.py index a37b2b6a..c82afaa4 100644 --- a/tests/test_004_cursor.py +++ b/tests/test_004_cursor.py @@ -18,6 +18,14 @@ import re from conftest import is_azure_sql_connection +try: + import pyarrow as pa + import pyarrow.parquet as pq + import io +except ImportError: + pa = None + pq = None + # Setup test table TEST_TABLE = """ @@ -14764,3 +14772,277 @@ def test_close(db_connection): pytest.fail(f"Cursor close test failed: {e}") finally: cursor = db_connection.cursor() + + +def get_arrow_test_data(include_lobs: bool, batch_length: int): + arrow_test_data = [ + (pa.uint8(), "tinyint", [1, 2, None, 4, 5, 0, 2**8 - 1]), + (pa.int16(), "smallint", [1, 2, None, 4, 5, -(2**15), 2**15 - 1]), + (pa.int32(), "int", [1, 2, None, 4, 5, 0, -(2**31), 2**31 - 1]), + (pa.int64(), "bigint", [1, 2, None, 4, 5, 0, -(2**63), 2**63 - 1]), + (pa.float64(), "float", [1.0, 2.5, None, 4.25, 5.125]), + ( + pa.decimal128(precision=10, scale=2), + "decimal(10, 2)", + [ + decimal.Decimal("1.23"), + None, + decimal.Decimal("0.25"), + decimal.Decimal("-99999999.99"), + decimal.Decimal("99999999.99"), + ], + ), + ( + pa.decimal128(precision=38, scale=10), + "decimal(38, 10)", + [ + decimal.Decimal("1.1234567890"), + None, + decimal.Decimal("0"), + decimal.Decimal("1.0000000001"), + decimal.Decimal("-9999999999999999999999999999.9999999999"), + decimal.Decimal("9999999999999999999999999999.9999999999"), + ], + ), + (pa.bool_(), "bit", [True, None, False]), + (pa.binary(), "binary(9)", [b"asdfghjkl", None, b"lkjhgfdsa"]), + (pa.string(), "varchar(100)", ["asdfghjkl", None, "lkjhgfdsa"]), + (pa.string(), "nvarchar(100)", ["asdfghjkl", None, "lkjhgfdsa"]), + (pa.string(), "uniqueidentifier", ["58185E0D-3A91-44D8-BC46-7107217E0A6D", None]), + (pa.date32(), "date", [date(1, 1, 1), None, date(2345, 12, 31), date(9999, 12, 31)]), + ( + pa.time32("s"), + "time(0)", + [time(12, 0, 5, 0), None, time(23, 59, 59, 0), time(0, 0, 0, 0)], + ), + ( + pa.time32("s"), + "time(7)", + [time(12, 0, 5, 0), None, time(23, 59, 59, 0), time(0, 0, 0, 0)], + ), + ( + pa.timestamp("us"), + "datetime2(0)", + [datetime(2025, 1, 1, 12, 0, 5, 0), None, datetime(2345, 12, 31, 23, 59, 59, 0)], + ), + ( + pa.timestamp("us"), + "datetime2(3)", + [datetime(2025, 1, 1, 12, 0, 5, 123_000), None, datetime(2345, 12, 31, 23, 59, 59, 0)], + ), + ( + pa.timestamp("us"), + "datetime2(6)", + [datetime(2025, 1, 1, 12, 0, 5, 123_456), None, datetime(2345, 12, 31, 23, 59, 59, 0)], + ), + ( + pa.timestamp("us"), + "datetime2(7)", + [datetime(2025, 1, 1, 12, 0, 5, 123_456), None, datetime(2145, 12, 31, 23, 59, 
59, 0)], + ), + ( + pa.timestamp("us"), + "datetime2(2)", + [datetime(2025, 1, 1, 12, 0, 5, 0), None, datetime(2145, 12, 31, 23, 59, 59, 0)], + ), + ] + + if include_lobs: + arrow_test_data += [ + (pa.string(), "nvarchar(max)", ["hey", None, "ho"]), + (pa.string(), "varchar(max)", ["hey", None, "ho"]), + (pa.binary(), "varbinary(max)", [b"hey", None, b"ho"]), + ] + + for ix in range(len(arrow_test_data)): + while True: + T, sql_type, vals = arrow_test_data[ix] + if len(vals) >= batch_length: + arrow_test_data[ix] = (T, sql_type, vals[:batch_length]) + break + arrow_test_data[ix] = (T, sql_type, vals + vals) + + return arrow_test_data + + +def _test_arrow_test_data(cursor: mssql_python.Cursor, arrow_test_data, fetch_length=500): + cols = [] + for i_col, (pa_type, sql_type, values) in enumerate(arrow_test_data): + rows = [] + for value in values: + if type(value) is bool: + value = int(value) + if type(value) is bytes: + value = value.decode() + if value is None: + value = "null" + else: + value = f"'{value}'" + rows.append(f"col_{i_col} = cast({value} as {sql_type})") + cols.append(rows) + + selects = [] + for row in zip(*cols): + selects.append(f"select {', '.join(col for col in row)}") + full_query = "\nunion all\n".join(selects) + ret = cursor.execute(full_query).arrow_batch(fetch_length) + for i_col, col in enumerate(ret): + expected_data = arrow_test_data[i_col][2][:fetch_length] + for i_row, (v_expected, v_actual) in enumerate( + zip(expected_data, col.to_pylist(), strict=True) + ): + assert ( + v_expected == v_actual + ), f"Mismatch in column {i_col}, row {i_row}: expected {v_expected}, got {v_actual}" + # check that null counts match + expected_null_count = sum(1 for v in expected_data if v is None) + actual_null_count = col.null_count + assert expected_null_count == actual_null_count, (expected_null_count, actual_null_count) + for i_col, (pa_type, sql_type, values) in enumerate(arrow_test_data): + field = ret.schema.field(i_col) + assert ( + field.name == f"col_{i_col}" + ), f"Column {i_col} name mismatch: expected col_{i_col}, got {field.name}" + assert field.type.equals( + pa_type + ), f"Column {i_col} type mismatch: expected {pa_type}, got {field.type}" + + # Validate that Parquet serialization/deserialization does not detect any issues + tbl = pa.Table.from_batches([ret]) + # for some reason parquet converts seconds to milliseconds in time32 + for i_col, col in enumerate(tbl.columns): + if col.type == pa.time32("s"): + tbl = tbl.set_column( + i_col, + tbl.schema.field(i_col).name, + col.cast(pa.time32("ms")), + ) + buffer = io.BytesIO() + pq.write_table(tbl, buffer) + buffer.seek(0) + read_table = pq.read_table(buffer) + assert read_table.equals(tbl) + + +@pytest.mark.skipif(pa is None, reason="pyarrow is not installed") +def test_arrow_lob_wide(cursor: mssql_python.Cursor): + "Take the SQLGetData branch for a wide table." + arrow_test_data = get_arrow_test_data(include_lobs=True, batch_length=123) + _test_arrow_test_data(cursor, arrow_test_data) + + +@pytest.mark.skipif(pa is None, reason="pyarrow is not installed") +def test_arrow_nolob_wide(cursor: mssql_python.Cursor): + "Test the SQLBindData branch for a wide table." + arrow_test_data = get_arrow_test_data(include_lobs=False, batch_length=123) + _test_arrow_test_data(cursor, arrow_test_data) + + +@pytest.mark.skipif(pa is None, reason="pyarrow is not installed") +def test_arrow_single_column(cursor: mssql_python.Cursor): + "Test each datatype as a single column fetch." 
+ arrow_test_data = get_arrow_test_data(include_lobs=True, batch_length=123) + for col_data in arrow_test_data: + _test_arrow_test_data(cursor, [col_data]) + + +@pytest.mark.skipif(pa is None, reason="pyarrow is not installed") +def test_arrow_empty_fetch(cursor: mssql_python.Cursor): + "Test each datatype as a single column fetch of length 0." + arrow_test_data = get_arrow_test_data(include_lobs=True, batch_length=123) + for col_data in arrow_test_data: + _test_arrow_test_data(cursor, [col_data], fetch_length=0) + + +@pytest.mark.skipif(pa is None, reason="pyarrow is not installed") +def test_arrow_table_batchsize_negative(cursor: mssql_python.Cursor): + tbl = cursor.execute("select 1 a").arrow(batch_size=-42) + assert type(tbl) is pa.Table + assert tbl.num_rows == 0 + assert tbl.num_columns == 1 + assert cursor.fetchone()[0] == 1 + + +@pytest.mark.skipif(pa is None, reason="pyarrow is not installed") +def test_arrow_empty_result_set(cursor: mssql_python.Cursor): + "Test fetching from an empty result set." + cursor.execute("select 1 where 1 = 0") + batch = cursor.arrow_batch(10) + assert batch.num_rows == 0 + assert batch.num_columns == 1 + cursor.execute("select cast(N'' as nvarchar(max)) where 1 = 0") + batch = cursor.arrow_batch(10) + assert batch.num_rows == 0 + assert batch.num_columns == 1 + cursor.execute("select 1, cast(N'' as nvarchar(max)) where 1 = 0") + batch = cursor.arrow_batch(10) + assert batch.num_rows == 0 + assert batch.num_columns == 2 + + +@pytest.mark.skipif(pa is None, reason="pyarrow is not installed") +def test_arrow_no_result_set(cursor: mssql_python.Cursor): + "Test fetching when there is no result set." + cursor.execute("declare @a int") + with pytest.raises(Exception, match=".*No active result set.*"): + cursor.arrow_batch(10) + + +@pytest.mark.skipif(pa is None, reason="pyarrow is not installed") +def test_arrow_datetimeoffset(cursor: mssql_python.Cursor): + "Datetimeoffset converts correctly to utc" + cursor.execute( + "declare @dt datetimeoffset(0) = '2345-02-03 12:34:56 +00:00';\n" + "select @dt, @dt at time zone 'Pacific Standard Time';\n" + ) + batch = cursor.arrow_batch(10) + assert batch.num_rows == 1 + assert batch.num_columns == 2 + for col in batch.columns: + assert pa.types.is_timestamp(col.type) + assert col.type.tz == "+00:00", col.type.tz + assert col.to_pylist() == [ + datetime(2345, 2, 3, 12, 34, 56, tzinfo=timezone.utc), + ] + + +@pytest.mark.skipif(pa is None, reason="pyarrow is not installed") +def test_arrow_schema_nullable(cursor: mssql_python.Cursor): + "Test that the schema is nullable." 
+ cursor.execute("select 1 a, null b") + batch = cursor.arrow_batch(10) + assert batch.num_rows == 1 + assert batch.num_columns == 2 + assert not batch.schema.field(0).nullable + assert batch.schema.field(1).nullable + assert batch.schema.field(0).name == "a" + assert batch.schema.field(1).name == "b" + + +@pytest.mark.skipif(pa is None, reason="pyarrow is not installed") +def test_arrow_table(cursor: mssql_python.Cursor): + tbl = cursor.execute("select top 11 1 a from sys.objects").arrow(batch_size=5) + assert type(tbl) is pa.Table + assert tbl.num_rows == 11 + assert tbl.num_columns == 1 + assert [len(b) for b in tbl.to_batches()] == [5, 5, 1] + + +@pytest.mark.skipif(pa is None, reason="pyarrow is not installed") +def test_arrow_reader(cursor: mssql_python.Cursor): + reader = cursor.execute("select top 11 1 a from sys.objects").arrow_reader(batch_size=4) + assert type(reader) is pa.RecordBatchReader + batches = list(reader) + assert [len(b) for b in batches] == [4, 4, 3] + assert sum(len(b) for b in batches) == 11 + + +@pytest.mark.skipif(pa is None, reason="pyarrow is not installed") +def test_arrow_long_string(cursor: mssql_python.Cursor): + "Make sure resizing the data buffer works" + long_string = "A" * 100000 # 100k characters + cursor.execute("select cast(? as nvarchar(max))", (long_string,)) + batch = cursor.arrow_batch(10) + assert batch.num_rows == 1 + assert batch.num_columns == 1 + assert batch.column(0).to_pylist() == [long_string]