86 changes: 86 additions & 0 deletions mssql_python/cursor.py
@@ -25,7 +25,10 @@
from mssql_python import get_settings

if TYPE_CHECKING:
    import pyarrow  # type: ignore
    from mssql_python.connection import Connection
else:
    pyarrow = None

# Constants for string handling
MAX_INLINE_CHAR: int = (
@@ -2198,6 +2201,89 @@ def fetchall(self) -> List[Row]:
            # On error, don't increment rownumber - rethrow the error
            raise e

    def arrow_batch(self, batch_size: int = 8192) -> "pyarrow.RecordBatch":
        """
        Fetch a single pyarrow RecordBatch containing up to batch_size rows
        from the query result set.

        Args:
            batch_size: Maximum number of rows to fetch into the RecordBatch.

        Returns:
            A pyarrow RecordBatch object containing up to batch_size rows.
        """
        self._check_closed()  # Check if the cursor is closed
        if not self._has_result_set and self.description:
            self._reset_rownumber()

        try:
            import pyarrow
        except ImportError as e:
            raise ImportError(
                "pyarrow is required for arrow_batch(). Please install pyarrow."
            ) from e

        # The native layer fills `capsules` with the (schema, array) PyCapsule
        # pair defined by the Arrow C data interface.
        capsules = []
        ret = ddbc_bindings.DDBCSQLFetchArrowBatch(
            self.hstmt, capsules, max(batch_size, 0)
        )
        check_error(ddbc_sql_const.SQL_HANDLE_STMT.value, self.hstmt, ret)

        return pyarrow.RecordBatch._import_from_c_capsule(*capsules)

    def arrow(self, batch_size: int = 8192) -> "pyarrow.Table":
        """
        Fetch all remaining rows of the current result set as a pyarrow Table.

        Args:
            batch_size: Size of the RecordBatches that make up the Table.

        Returns:
            A pyarrow Table containing all remaining rows from the result set.
        """
        try:
            import pyarrow
        except ImportError as e:
            raise ImportError(
                "pyarrow is required for arrow(). Please install pyarrow."
            ) from e

        batches: list["pyarrow.RecordBatch"] = []
        while True:
            batch = self.arrow_batch(batch_size)
            # A short batch signals the end of the result set. A batch_size
            # of zero or less can never satisfy the first test, so it is
            # handled explicitly to avoid looping forever.
            if batch.num_rows < batch_size or batch_size <= 0:
                # Keep the final batch if it has rows, or if it is the only
                # batch (so the Table still carries the schema).
                if not batches or batch.num_rows > 0:
                    batches.append(batch)
                break
            batches.append(batch)
        return pyarrow.Table.from_batches(batches, schema=batches[0].schema)
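
Continuing the sketch above, arrow() materializes whatever is left of the result set in a single call; the to_pandas() conversion assumes pandas is installed alongside pyarrow:

    cursor.execute("SELECT id, name FROM users")  # hypothetical table
    table = cursor.arrow(batch_size=8192)  # one pyarrow.Table of all remaining rows
    df = table.to_pandas()                 # optional pandas hand-off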

    def arrow_reader(self, batch_size: int = 8192) -> "pyarrow.RecordBatchReader":
        """
        Fetch the result as a pyarrow RecordBatchReader, which lazily yields
        RecordBatches of the specified size until the current result set is
        exhausted.

        Args:
            batch_size: Size of the RecordBatches produced by the reader.

        Returns:
            A pyarrow RecordBatchReader for the result set.
        """
        try:
            import pyarrow
        except ImportError as e:
            raise ImportError(
                "pyarrow is required for arrow_reader(). Please install pyarrow."
            ) from e

        # A zero-row fetch returns an empty batch that still carries the
        # result set's schema, without consuming any rows.
        schema_batch = self.arrow_batch(0)
        schema = schema_batch.schema

        def batch_generator():
            # Keep fetching until an empty batch marks exhaustion.
            while (batch := self.arrow_batch(batch_size)).num_rows > 0:
                yield batch

        return pyarrow.RecordBatchReader.from_batches(schema, batch_generator())
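
And a streaming variant via the reader, which avoids materializing the full result set in memory; process() is a hypothetical per-batch callback:

    cursor.execute("SELECT id, name FROM users")  # hypothetical table
    for batch in cursor.arrow_reader(batch_size=2048):  # RecordBatchReader is iterable
        process(batch)  # hypothetical per-batch callback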

    def nextset(self) -> Union[bool, None]:
        """
        Skip to the next available result set.