Skip to content

Commit 54e57cc

Browse files
authored
Merge branch 'main' into subrata-ms/cp1252_encoding
2 parents 9b4a5f2 + b786900 commit 54e57cc

6 files changed

Lines changed: 1736 additions & 0 deletions

File tree

mssql_python/cursor.py

Lines changed: 104 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,10 @@
3636
)
3737

3838
if TYPE_CHECKING:
39+
import pyarrow # type: ignore
3940
from mssql_python.connection import Connection
41+
else:
42+
pyarrow = None
4043

4144
# Constants for string handling
4245
MAX_INLINE_CHAR: int = (
@@ -775,6 +778,19 @@ def _check_closed(self) -> None:
775778
ddbc_error="",
776779
)
777780

781+
def _ensure_pyarrow(self) -> Any:
    """
    Lazily import pyarrow and return the module.

    Returns:
        The imported ``pyarrow`` module object.

    Raises:
        ImportError: If pyarrow is not installed, with a hint that it is
            required for the Arrow fetch methods.
    """
    try:
        import pyarrow
    except ImportError as e:
        # Re-raise with an actionable message, chaining the original cause.
        raise ImportError(
            "pyarrow is required for Arrow fetch methods. Please install pyarrow."
        ) from e
    return pyarrow
793+
778794
def setinputsizes(self, sizes: List[Union[int, tuple]]) -> None:
779795
"""
780796
Sets the type information to be used for parameters in execute and executemany.
@@ -2516,6 +2532,94 @@ def fetchall(self) -> List[Row]:
25162532
# On error, don't increment rownumber - rethrow the error
25172533
raise e
25182534

2535+
def arrow_batch(self, batch_size: int = 8192) -> "pyarrow.RecordBatch":
    """
    Fetch a single pyarrow Record Batch of the specified size from the
    query result set.

    Args:
        batch_size: Maximum number of rows to fetch in the Record Batch.
            Non-positive values are clamped to 0, which fetches an empty
            batch (useful for obtaining the schema without consuming rows).

    Returns:
        A pyarrow RecordBatch object containing up to batch_size rows.

    Raises:
        ImportError: If pyarrow is not installed.
    """
    self._check_closed()  # Check if the cursor is closed
    pyarrow = self._ensure_pyarrow()

    # NOTE(review): presumably resets row-position bookkeeping before the
    # first fetch of a result set — confirm this mirrors fetchmany/fetchall.
    if not self._has_result_set and self.description:
        self._reset_rownumber()

    # The native binding fills `capsules` with Arrow C Data Interface
    # capsules (schema + array) describing the fetched batch.
    capsules = []
    ret = ddbc_bindings.DDBCSQLFetchArrowBatch(self.hstmt, capsules, max(batch_size, 0))
    check_error(ddbc_sql_const.SQL_HANDLE_STMT.value, self.hstmt, ret)

    # Import the batch via the Arrow PyCapsule protocol (no data copy).
    batch = pyarrow.RecordBatch._import_from_c_capsule(*capsules)

    # Collect any driver diagnostics emitted during the fetch.
    if self.hstmt:
        self.messages.extend(ddbc_bindings.DDBCSQLGetAllDiagRecords(self.hstmt))

    # Update rownumber for the number of rows actually fetched
    num_fetched = batch.num_rows
    if num_fetched > 0 and self._has_result_set:
        self._next_row_index += num_fetched
        # _rownumber is the index of the last row returned (0-based).
        self._rownumber = self._next_row_index - 1

    # Centralize rowcount assignment after fetch
    if num_fetched == 0 and self._next_row_index == 0:
        self.rowcount = 0
    else:
        self.rowcount = self._next_row_index

    return batch
2574+
2575+
def arrow(self, batch_size: int = 8192) -> "pyarrow.Table":
    """
    Fetch the entire remaining result set as a pyarrow Table.

    Args:
        batch_size: Size of the Record Batches which make up the Table.

    Returns:
        A pyarrow Table containing all remaining rows from the result set.

    Raises:
        ImportError: If pyarrow is not installed.
    """
    self._check_closed()  # Check if the cursor is closed
    pyarrow = self._ensure_pyarrow()

    collected: list["pyarrow.RecordBatch"] = []
    exhausted = False
    while not exhausted:
        chunk = self.arrow_batch(batch_size)
        # A short batch means the result set is drained; a non-positive
        # batch_size produces exactly one (possibly empty) batch.
        exhausted = batch_size <= 0 or chunk.num_rows < batch_size
        # Keep every full batch, any non-empty final batch, and an empty
        # batch only when it is the sole batch (it still carries the schema).
        if not exhausted or chunk.num_rows > 0 or not collected:
            collected.append(chunk)
    return pyarrow.Table.from_batches(collected, schema=collected[0].schema)
2597+
2598+
def arrow_reader(self, batch_size: int = 8192) -> "pyarrow.RecordBatchReader":
    """
    Fetch the result as a pyarrow RecordBatchReader that lazily yields
    Record Batches of the specified size until the current result set is
    exhausted.

    Args:
        batch_size: Size of the Record Batches produced by the reader.

    Returns:
        A pyarrow RecordBatchReader for the result set.

    Raises:
        ImportError: If pyarrow is not installed.
    """
    self._check_closed()  # Check if the cursor is closed
    pyarrow = self._ensure_pyarrow()

    # A zero-sized fetch returns an empty batch that still carries the
    # schema, so the cursor position is not advanced.
    schema = self.arrow_batch(0).schema

    def _iter_batches():
        # Pull batches on demand; an empty batch signals exhaustion.
        while True:
            chunk = self.arrow_batch(batch_size)
            if chunk.num_rows <= 0:
                return
            yield chunk

    return pyarrow.RecordBatchReader.from_batches(schema, _iter_batches())
2622+
25192623
def nextset(self) -> Union[bool, None]:
25202624
"""
25212625
Skip to the next available result set.

mssql_python/mssql_python.pyi

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ Type stubs for mssql_python package - based on actual public API
77
from typing import Any, Dict, List, Optional, Union, Tuple, Sequence, Callable, Iterator
88
import datetime
99
import logging
10+
import pyarrow
1011

1112
# GLOBALS - DB-API 2.0 Required Module Globals
1213
# https://www.python.org/dev/peps/pep-0249/#module-interface
@@ -199,6 +200,11 @@ class Cursor:
199200
def setinputsizes(self, sizes: List[Union[int, Tuple[Any, ...]]]) -> None: ...
200201
def setoutputsize(self, size: int, column: Optional[int] = None) -> None: ...
201202

203+
# Arrow Extension Methods (requires pyarrow)
204+
def arrow_batch(self, batch_size: int = 8192) -> pyarrow.RecordBatch: ...
205+
def arrow(self, batch_size: int = 8192) -> pyarrow.Table: ...
206+
def arrow_reader(self, batch_size: int = 8192) -> pyarrow.RecordBatchReader: ...
207+
202208
# DB-API 2.0 Connection Object
203209
# https://www.python.org/dev/peps/pep-0249/#connection-objects
204210
class Connection:

0 commit comments

Comments
 (0)