|
36 | 36 | ) |
37 | 37 |
|
38 | 38 | if TYPE_CHECKING: |
| 39 | + import pyarrow # type: ignore |
39 | 40 | from mssql_python.connection import Connection |
| 41 | +else: |
| 42 | + pyarrow = None |
40 | 43 |
|
41 | 44 | # Constants for string handling |
42 | 45 | MAX_INLINE_CHAR: int = ( |
@@ -775,6 +778,19 @@ def _check_closed(self) -> None: |
775 | 778 | ddbc_error="", |
776 | 779 | ) |
777 | 780 |
|
| 781 | + def _ensure_pyarrow(self) -> Any: |
| 782 | + """ |
| 783 | + Import and return pyarrow or raise ImportError accordingly. |
| 784 | + """ |
| 785 | + try: |
| 786 | + import pyarrow |
| 787 | + |
| 788 | + return pyarrow |
| 789 | + except ImportError as e: |
| 790 | + raise ImportError( |
| 791 | + "pyarrow is required for Arrow fetch methods. Please install pyarrow." |
| 792 | + ) from e |
| 793 | + |
778 | 794 | def setinputsizes(self, sizes: List[Union[int, tuple]]) -> None: |
779 | 795 | """ |
780 | 796 | Sets the type information to be used for parameters in execute and executemany. |
@@ -2516,6 +2532,94 @@ def fetchall(self) -> List[Row]: |
2516 | 2532 | # On error, don't increment rownumber - rethrow the error |
2517 | 2533 | raise e |
2518 | 2534 |
|
    def arrow_batch(self, batch_size: int = 8192) -> "pyarrow.RecordBatch":
        """
        Fetch a single pyarrow Record Batch of the specified size from the
        query result set.

        Args:
            batch_size: Maximum number of rows to fetch in the Record Batch.
                Non-positive values are clamped to 0, which fetches no rows
                but still returns a (schema-bearing) empty batch.

        Returns:
            A pyarrow RecordBatch object containing up to batch_size rows.

        Raises:
            ImportError: If pyarrow is not installed.
        """
        self._check_closed()  # Check if the cursor is closed
        pyarrow = self._ensure_pyarrow()

        # NOTE(review): presumably re-synchronizes position tracking when a
        # result set has a description but no active row state — confirm this
        # matches the convention used by the other fetch* methods.
        if not self._has_result_set and self.description:
            self._reset_rownumber()

        # The native layer populates `capsules` with Arrow C Data Interface
        # capsules; they are imported zero-copy into a RecordBatch below.
        capsules = []
        ret = ddbc_bindings.DDBCSQLFetchArrowBatch(self.hstmt, capsules, max(batch_size, 0))
        check_error(ddbc_sql_const.SQL_HANDLE_STMT.value, self.hstmt, ret)

        # _import_from_c_capsule consumes (schema, array) capsules produced
        # by the native fetch call above.
        batch = pyarrow.RecordBatch._import_from_c_capsule(*capsules)

        # Collect any driver diagnostic records emitted during the fetch.
        if self.hstmt:
            self.messages.extend(ddbc_bindings.DDBCSQLGetAllDiagRecords(self.hstmt))

        # Update rownumber for the number of rows actually fetched
        num_fetched = batch.num_rows
        if num_fetched > 0 and self._has_result_set:
            self._next_row_index += num_fetched
            self._rownumber = self._next_row_index - 1

        # Centralize rowcount assignment after fetch
        if num_fetched == 0 and self._next_row_index == 0:
            self.rowcount = 0
        else:
            self.rowcount = self._next_row_index

        return batch
| 2575 | + def arrow(self, batch_size: int = 8192) -> "pyarrow.Table": |
| 2576 | + """ |
| 2577 | + Fetch the entire result as a pyarrow Table. |
| 2578 | +
|
| 2579 | + Args: |
| 2580 | + batch_size: Size of the Record Batches which make up the Table. |
| 2581 | +
|
| 2582 | + Returns: |
| 2583 | + A pyarrow Table containing all remaining rows from the result set. |
| 2584 | + """ |
| 2585 | + self._check_closed() # Check if the cursor is closed |
| 2586 | + pyarrow = self._ensure_pyarrow() |
| 2587 | + |
| 2588 | + batches: list["pyarrow.RecordBatch"] = [] |
| 2589 | + while True: |
| 2590 | + batch = self.arrow_batch(batch_size) |
| 2591 | + if batch.num_rows < batch_size or batch_size <= 0: |
| 2592 | + if not batches or batch.num_rows > 0: |
| 2593 | + batches.append(batch) |
| 2594 | + break |
| 2595 | + batches.append(batch) |
| 2596 | + return pyarrow.Table.from_batches(batches, schema=batches[0].schema) |
| 2597 | + |
| 2598 | + def arrow_reader(self, batch_size: int = 8192) -> "pyarrow.RecordBatchReader": |
| 2599 | + """ |
| 2600 | + Fetch the result as a pyarrow RecordBatchReader, which yields Record |
| 2601 | + Batches of the specified size until the current result set is |
| 2602 | + exhausted. |
| 2603 | +
|
| 2604 | + Args: |
| 2605 | + batch_size: Size of the Record Batches produced by the reader. |
| 2606 | +
|
| 2607 | + Returns: |
| 2608 | + A pyarrow RecordBatchReader for the result set. |
| 2609 | + """ |
| 2610 | + self._check_closed() # Check if the cursor is closed |
| 2611 | + pyarrow = self._ensure_pyarrow() |
| 2612 | + |
| 2613 | + # Fetch schema without advancing cursor |
| 2614 | + schema_batch = self.arrow_batch(0) |
| 2615 | + schema = schema_batch.schema |
| 2616 | + |
| 2617 | + def batch_generator(): |
| 2618 | + while (batch := self.arrow_batch(batch_size)).num_rows > 0: |
| 2619 | + yield batch |
| 2620 | + |
| 2621 | + return pyarrow.RecordBatchReader.from_batches(schema, batch_generator()) |
| 2622 | + |
2519 | 2623 | def nextset(self) -> Union[bool, None]: |
2520 | 2624 | """ |
2521 | 2625 | Skip to the next available result set. |
|
0 commit comments