Skip to content

Commit 079bf57

Browse files
committed
Merge remote-tracking branch 'origin/main' into saumya/pool-c++
2 parents d373e9b + fc99220 commit 079bf57

6 files changed

Lines changed: 138 additions & 9 deletions

File tree

mssql_python/bcp_options.py

Lines changed: 121 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,121 @@
1+
from dataclasses import dataclass, field
2+
from typing import List, Optional, Literal
3+
4+
5+
@dataclass
6+
class ColumnFormat:
7+
"""
8+
Represents the format of a column in a bulk copy operation.
9+
Attributes:
10+
prefix_len (int): Option: (format_file) or (prefix_len, data_len).
11+
The length of the prefix for fixed-length data types. Must be non-negative.
12+
data_len (int): Option: (format_file) or (prefix_len, data_len).
13+
The length of the data. Must be non-negative.
14+
field_terminator (Optional[bytes]): Option: (-t). The field terminator string.
15+
e.g., b',' for comma-separated values.
16+
row_terminator (Optional[bytes]): Option: (-r). The row terminator string.
17+
e.g., b'\\n' for newline-terminated rows.
18+
server_col (int): Option: (format_file) or (server_col). The 1-based column number
19+
in the SQL Server table. Defaults to 1, representing the first column.
20+
Must be a positive integer.
21+
file_col (int): Option: (format_file) or (file_col). The 1-based column number
22+
in the data file. Defaults to 1, representing the first column.
23+
Must be a positive integer.
24+
"""
25+
26+
prefix_len: int
27+
data_len: int
28+
field_terminator: Optional[bytes] = None
29+
row_terminator: Optional[bytes] = None
30+
server_col: int = 1
31+
file_col: int = 1
32+
33+
def __post_init__(self):
34+
if self.prefix_len < 0:
35+
raise ValueError("prefix_len must be a non-negative integer.")
36+
if self.data_len < 0:
37+
raise ValueError("data_len must be a non-negative integer.")
38+
if self.server_col <= 0:
39+
raise ValueError("server_col must be a positive integer (1-based).")
40+
if self.file_col <= 0:
41+
raise ValueError("file_col must be a positive integer (1-based).")
42+
if self.field_terminator is not None and not isinstance(
43+
self.field_terminator, bytes
44+
):
45+
raise TypeError("field_terminator must be bytes or None.")
46+
if self.row_terminator is not None and not isinstance(
47+
self.row_terminator, bytes
48+
):
49+
raise TypeError("row_terminator must be bytes or None.")
50+
51+
52+
@dataclass
53+
class BCPOptions:
54+
"""
55+
Represents the options for a bulk copy operation.
56+
Attributes:
57+
direction (Literal[str]): 'in' or 'out'. Option: (-i or -o).
58+
data_file (str): The data file. Option: (positional argument).
59+
error_file (Optional[str]): The error file. Option: (-e).
60+
format_file (Optional[str]): The format file to use for 'in'/'out'. Option: (-f).
61+
batch_size (Optional[int]): The batch size. Option: (-b).
62+
max_errors (Optional[int]): The maximum number of errors allowed. Option: (-m).
63+
first_row (Optional[int]): The first row to process. Option: (-F).
64+
last_row (Optional[int]): The last row to process. Option: (-L).
65+
code_page (Optional[str]): The code page. Option: (-C).
66+
keep_identity (bool): Keep identity values. Option: (-E).
67+
keep_nulls (bool): Keep null values. Option: (-k).
68+
hints (Optional[str]): Additional hints. Option: (-h).
69+
bulk_mode (str): Bulk mode ('native', 'char', 'unicode'). Option: (-n, -c, -w).
70+
Defaults to "native".
71+
columns (List[ColumnFormat]): Column formats.
72+
"""
73+
74+
direction: Literal["in", "out"]
75+
data_file: str # data_file is mandatory for 'in' and 'out'
76+
error_file: Optional[str] = None
77+
format_file: Optional[str] = None
78+
# write_format_file is removed as 'format' direction is not actively supported
79+
batch_size: Optional[int] = None
80+
max_errors: Optional[int] = None
81+
first_row: Optional[int] = None
82+
last_row: Optional[int] = None
83+
code_page: Optional[str] = None
84+
keep_identity: bool = False
85+
keep_nulls: bool = False
86+
hints: Optional[str] = None
87+
bulk_mode: Literal["native", "char", "unicode"] = "native"
88+
columns: List[ColumnFormat] = field(default_factory=list)
89+
90+
def __post_init__(self):
91+
if self.direction not in ["in", "out"]:
92+
raise ValueError("direction must be 'in' or 'out'.")
93+
if not self.data_file:
94+
raise ValueError("data_file must be provided and non-empty for 'in' or 'out' directions.")
95+
if self.error_file is None or not self.error_file: # Making error_file mandatory for in/out
96+
raise ValueError("error_file must be provided and non-empty for 'in' or 'out' directions.")
97+
98+
if self.format_file is not None and not self.format_file:
99+
raise ValueError("format_file, if provided, must not be an empty string.")
100+
if self.batch_size is not None and self.batch_size <= 0:
101+
raise ValueError("batch_size must be a positive integer.")
102+
if self.max_errors is not None and self.max_errors < 0:
103+
raise ValueError("max_errors must be a non-negative integer.")
104+
if self.first_row is not None and self.first_row <= 0:
105+
raise ValueError("first_row must be a positive integer.")
106+
if self.last_row is not None and self.last_row <= 0:
107+
raise ValueError("last_row must be a positive integer.")
108+
if self.last_row is not None and self.first_row is None:
109+
raise ValueError("first_row must be specified if last_row is specified.")
110+
if (
111+
self.first_row is not None
112+
and self.last_row is not None
113+
and self.last_row < self.first_row
114+
):
115+
raise ValueError("last_row must be greater than or equal to first_row.")
116+
if self.code_page is not None and not self.code_page:
117+
raise ValueError("code_page, if provided, must not be an empty string.")
118+
if self.hints is not None and not self.hints:
119+
raise ValueError("hints, if provided, must not be an empty string.")
120+
if self.bulk_mode not in ["native", "char", "unicode"]:
121+
raise ValueError("bulk_mode must be 'native', 'char', or 'unicode'.")

mssql_python/connection.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
from mssql_python.constants import ConstantsDDBC as ddbc_sql_const
1212
from mssql_python.helpers import add_driver_to_connection_str, check_error
1313
from mssql_python import ddbc_bindings
14+
# from mssql_python.pooling import PoolingManager
1415

1516
logger = get_logger()
1617

@@ -57,8 +58,9 @@ def __init__(self, connection_str: str = "", autocommit: bool = False, attrs_bef
5758
connection_str, **kwargs
5859
)
5960
self._attrs_before = attrs_before or {}
60-
self._conn = ddbc_bindings.Connection(self.connection_str, autocommit)
61-
self._conn.connect(self._attrs_before)
61+
# self._pooling = PoolingManager.is_enabled()
62+
self._pooling = False
63+
self._conn = ddbc_bindings.Connection(self.connection_str, self._pooling, self._attrs_before)
6264
self.setautocommit(autocommit)
6365

6466
def _construct_connection_string(self, connection_str: str = "", **kwargs) -> str:

mssql_python/pybind/connection/connection.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -285,4 +285,4 @@ SqlHandlePtr ConnectionHandle::allocStatementHandle() {
285285
ThrowStdException("Connection object is not initialized");
286286
}
287287
return _conn->allocStatementHandle();
288-
}
288+
}

mssql_python/pybind/connection/connection.h

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -33,17 +33,13 @@ class Connection {
3333

3434
// Check whether autocommit is enabled.
3535
bool getAutocommit() const;
36-
3736
bool isAlive() const;
38-
3937
bool reset();
40-
4138
void updateLastUsed();
42-
4339
std::chrono::steady_clock::time_point lastUsed() const;
4440

4541
// Allocate a new statement handle on this connection.
46-
SqlHandlePtr allocStatementHandle();
42+
SqlHandlePtr allocStatementHandle();
4743

4844
private:
4945
void allocateDbcHandle();

mssql_python/pybind/connection/connection_pool.cpp

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,9 @@
1+
// Copyright (c) Microsoft Corporation.
2+
// Licensed under the MIT license.
3+
4+
// INFO|TODO - Note that is file is Windows specific right now. Making it arch agnostic will be
5+
// taken up in future.
6+
17
#include "connection_pool.h"
28
#include <iostream>
39
#include <exception>
@@ -36,6 +42,7 @@ std::shared_ptr<Connection> ConnectionPool::acquire(const std::wstring& connStr,
3642
if (_current_size < _max_size) {
3743
auto conn = std::make_shared<Connection>(connStr, true);
3844
conn->connect(attrs_before);
45+
++_current_size;
3946
return conn;
4047
} else {
4148
LOG("Cannot acquire connection: pool size limit reached");

mssql_python/pybind/ddbc_bindings.cpp

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1877,8 +1877,11 @@ SQLLEN SQLRowCount_wrap(SqlHandlePtr StatementHandle) {
18771877
return rowCount;
18781878
}
18791879

1880+
static std::once_flag pooling_init_flag;
18801881
void enable_pooling(int maxSize, int idleTimeout) {
1881-
ConnectionPoolManager::getInstance().configure(maxSize, idleTimeout);
1882+
std::call_once(pooling_init_flag, [&]() {
1883+
ConnectionPoolManager::getInstance().configure(maxSize, idleTimeout);
1884+
});
18821885
}
18831886

18841887
// Architecture-specific defines

0 commit comments

Comments
 (0)