Skip to content

Commit 5181a0e

Browse files
committed
updating file
1 parent 2c9c3d7 commit 5181a0e

5 files changed

Lines changed: 85 additions & 50 deletions

File tree

mssql_python/pybind/connection/connection.cpp

Lines changed: 67 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -26,54 +26,63 @@ Connection::Connection(const std::wstring& conn_str, bool autocommit)
2626
DriverLoader::getInstance().loadDriver();
2727
}
2828
SQLRETURN ret = SQLAllocHandle_ptr(SQL_HANDLE_ENV, SQL_NULL_HANDLE, &env);
29-
checkError(ret, "Failed to allocate environment handle");
29+
checkError(ret);
3030
_envHandle = std::make_shared<SqlHandle>(SQL_HANDLE_ENV, env);
3131

3232
LOG("Setting environment attributes");
3333
ret = SQLSetEnvAttr_ptr(_envHandle->get(), SQL_ATTR_ODBC_VERSION, (void*)SQL_OV_ODBC3_80, 0);
34-
checkError(ret, "Failed to set environment attribute");
34+
checkError(ret);
3535
}
3636
allocateDbcHandle();
3737
}
3838

3939
Connection::~Connection() {
40-
disconnect(); // fallback if app forgets to disconnect
40+
disconnect(); // fallback if user forgets to disconnect
4141
}
4242

4343
// Allocates connection handle
4444
void Connection::allocateDbcHandle() {
4545
SQLHANDLE dbc = nullptr;
4646
LOG("Allocate SQL Connection Handle");
4747
SQLRETURN ret = SQLAllocHandle_ptr(SQL_HANDLE_DBC, _envHandle->get(), &dbc);
48-
checkError(ret, "Failed to allocate connection handle");
48+
checkError(ret);
4949
_dbcHandle = std::make_shared<SqlHandle>(SQL_HANDLE_DBC, dbc);
5050
}
5151

52-
void Connection::connect() {
52+
SQLRETURN Connection::connect(const py::dict& attrs_before) {
5353
LOG("Connecting to database");
54+
// Apply access token before connect
55+
if (!attrs_before.is_none() && py::len(attrs_before) > 0) {
56+
LOG("Apply attributes before connect");
57+
applyAttrsBefore(attrs_before);
58+
if (_autocommit) {
59+
setAutocommit(_autocommit);
60+
}
61+
}
5462
SQLRETURN ret = SQLDriverConnect_ptr(
5563
_dbcHandle->get(), nullptr,
5664
(SQLWCHAR*)_connStr.c_str(), SQL_NTS,
5765
nullptr, 0, nullptr, SQL_DRIVER_NOPROMPT);
58-
checkError(ret, "SQLDriverConnect failed");
59-
setAutocommit(_autocommit);
66+
checkError(ret);
6067
}
6168

6269
void Connection::disconnect() {
6370
if (_dbcHandle) {
6471
LOG("Disconnecting from database");
6572
SQLRETURN ret = SQLDisconnect_ptr(_dbcHandle->get());
66-
checkError(ret, "Failed to disconnect from database");
73+
checkError(ret);
6774
_dbcHandle.reset(); // triggers SQLFreeHandle via destructor, if last owner
6875
}
6976
else {
7077
LOG("No connection handle to disconnect");
7178
}
7279
}
7380

74-
void Connection::checkError(SQLRETURN ret, const std::string& msg) const{
75-
if (ret != SQL_SUCCESS && ret != SQL_SUCCESS_WITH_INFO) {
76-
throw std::runtime_error("[ODBC Error] " + msg);
81+
void Connection::checkError(SQLRETURN ret) const{
82+
if (!SQL_SUCCEEDED(ret)) {
83+
ErrorInfo err = SQLCheckError_Wrap(SQL_HANDLE_DBC, _dbcHandle, ret);
84+
std::string errorMsg = std::string(err.ddbcErrorMsg.begin(), err.ddbcErrorMsg.end());
85+
ThrowStdException(errorMsg);
7786
}
7887
}
7988

@@ -83,7 +92,7 @@ void Connection::commit() {
8392
}
8493
LOG("Committing transaction");
8594
SQLRETURN ret = SQLEndTran_ptr(SQL_HANDLE_DBC, _dbcHandle->get(), SQL_COMMIT);
86-
checkError(ret, "Failed to commit transaction");
95+
checkError(ret);
8796
}
8897

8998
void Connection::rollback() {
@@ -92,7 +101,7 @@ void Connection::rollback() {
92101
}
93102
LOG("Rolling back transaction");
94103
SQLRETURN ret = SQLEndTran_ptr(SQL_HANDLE_DBC, _dbcHandle->get(), SQL_ROLLBACK);
95-
checkError(ret, "Failed to rollback transaction");
104+
checkError(ret);
96105
}
97106

98107
void Connection::setAutocommit(bool enable) {
@@ -102,7 +111,7 @@ void Connection::setAutocommit(bool enable) {
102111
SQLINTEGER value = enable ? SQL_AUTOCOMMIT_ON : SQL_AUTOCOMMIT_OFF;
103112
LOG("Set SQL Connection Attribute");
104113
SQLRETURN ret = SQLSetConnectAttr_ptr(_dbcHandle->get(), SQL_ATTR_AUTOCOMMIT, (SQLPOINTER)value, 0);
105-
checkError(ret, "Failed to set autocommit attribute");
114+
checkError(ret);
106115
_autocommit = enable;
107116
}
108117

@@ -114,7 +123,7 @@ bool Connection::getAutocommit() const {
114123
SQLINTEGER value;
115124
SQLINTEGER string_length;
116125
SQLRETURN ret = SQLGetConnectAttr_ptr(_dbcHandle->get(), SQL_ATTR_AUTOCOMMIT, &value, sizeof(value), &string_length);
117-
checkError(ret, "Failed to get autocommit attribute");
126+
checkError(ret);
118127
return value == SQL_AUTOCOMMIT_ON;
119128
}
120129

@@ -125,32 +134,54 @@ SqlHandlePtr Connection::allocStatementHandle() {
125134
LOG("Allocating statement handle");
126135
SQLHANDLE stmt = nullptr;
127136
SQLRETURN ret = SQLAllocHandle_ptr(SQL_HANDLE_STMT, _dbcHandle->get(), &stmt);
128-
checkError(ret, "Failed to allocate statement handle");
137+
checkError(ret);
129138
return std::make_shared<SqlHandle>(SQL_HANDLE_STMT, stmt);
130139
}
131140

132-
SqlHandlePtr Connection::getSharedEnvHandle() {
133-
static std::once_flag flag;
134-
static SqlHandlePtr env_handle;
135141

136-
std::call_once(flag, []() {
137-
LOG("Allocating environment handle");
138-
SQLHANDLE env = nullptr;
139-
if (!SQLAllocHandle_ptr) {
140-
LOG("Function pointers not initialized, loading driver");
141-
DriverLoader::getInstance().loadDriver();
142-
}
143-
SQLRETURN ret = SQLAllocHandle_ptr(SQL_HANDLE_ENV, SQL_NULL_HANDLE, &env);
144-
if (!SQL_SUCCEEDED(ret)) {
145-
throw std::runtime_error("Failed to allocate environment handle");
142+
SQLRETURN Connection::setAttribute(SQLINTEGER attribute, py::object value) {
143+
LOG("Setting SQL attribute");
144+
SQLPOINTER ptr = nullptr;
145+
SQLINTEGER length = 0;
146+
147+
if (py::isinstance<py::int_>(value)) {
148+
int intValue = value.cast<int>();
149+
ptr = reinterpret_cast<SQLPOINTER>(static_cast<uintptr_t>(intValue));
150+
length = SQL_IS_INTEGER;
151+
} else if (py::isinstance<py::bytes>(value) || py::isinstance<py::bytearray>(value)) {
152+
static std::vector<std::string> buffers;
153+
buffers.emplace_back(value.cast<std::string>());
154+
ptr = const_cast<char*>(buffers.back().c_str());
155+
length = static_cast<SQLINTEGER>(buffers.back().size());
156+
} else {
157+
LOG("Unsupported attribute value type");
158+
return SQL_ERROR;
159+
}
160+
161+
SQLRETURN ret = SQLSetConnectAttr_ptr(_dbcHandle->get(), attribute, ptr, length);
162+
if (!SQL_SUCCEEDED(ret)) {
163+
LOG("Failed to set attribute");
164+
}
165+
else {
166+
LOG("Set attribute successfully");
167+
}
168+
return ret;
169+
}
170+
171+
void Connection::applyAttrsBefore(const py::dict& attrs) {
172+
for (const auto& item : attrs) {
173+
int key;
174+
try {
175+
key = py::cast<int>(item.first);
176+
} catch (...) {
177+
continue;
146178
}
147-
env_handle = std::make_shared<SqlHandle>(SQL_HANDLE_ENV, env);
148179

149-
LOG("Setting environment attributes");
150-
ret = SQLSetEnvAttr_ptr(env_handle->get(), SQL_ATTR_ODBC_VERSION, (void*)SQL_OV_ODBC3_80, 0);
151-
if (!SQL_SUCCEEDED(ret)) {
152-
throw std::runtime_error("Failed to set environment attribute");
180+
if (key == SQL_COPT_SS_ACCESS_TOKEN) {
181+
SQLRETURN ret = setAttribute(key, py::reinterpret_borrow<py::object>(item.second));
182+
if (!SQL_SUCCEEDED(ret)) {
183+
throw std::runtime_error("Failed to set access token before connect");
184+
}
153185
}
154-
});
155-
return env_handle;
186+
}
156187
}

mssql_python/pybind/connection/connection.h

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ class Connection {
1717
~Connection();
1818

1919
// Establish the connection using the stored connection string.
20-
void connect();
20+
SQLRETURN connect(const py::dict& attrs_before = py::dict());
2121

2222
// Disconnect and free the connection handle.
2323
void disconnect();
@@ -39,7 +39,9 @@ class Connection {
3939

4040
private:
4141
void allocateDbcHandle();
42-
void checkError(SQLRETURN ret, const std::string& msg) const;
42+
void checkError(SQLRETURN ret) const;
43+
SQLRETURN setAttribute(SQLINTEGER attribute, py::object value);
44+
void applyAttrsBefore(const py::dict& attrs_before);
4345

4446
std::wstring _connStr;
4547
bool _usePool = false;

mssql_python/pybind/ddbc_bindings.cpp

Lines changed: 2 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -95,12 +95,6 @@ struct ColumnBuffers {
9595
indicators(numCols, std::vector<SQLLEN>(fetchSize)) {}
9696
};
9797

98-
// This struct is used to relay error info obtained from SQLDiagRec API to the Python module
99-
struct ErrorInfo {
100-
std::wstring sqlState;
101-
std::wstring ddbcErrorMsg;
102-
};
103-
10498
//-------------------------------------------------------------------------------------------------
10599
// Function pointer initialization
106100
//-------------------------------------------------------------------------------------------------
@@ -1927,8 +1921,8 @@ PYBIND11_MODULE(ddbc_bindings, m) {
19271921
.def("free", &SqlHandle::free, "Free the handle");
19281922
py::class_<Connection>(m, "Connection")
19291923
.def(py::init<const std::wstring&, bool>(), py::arg("conn_str"), py::arg("autocommit") = false)
1930-
.def("connect", &Connection::connect, py::arg("attrs_before") = py::dict(), "Establish a connection to the database")
1931-
.def("close", &Connection::close, "Close the connection")
1924+
.def("connect", &Connection::connect)
1925+
.def("close", &Connection::disconnect, "Close the connection")
19321926
.def("commit", &Connection::commit, "Commit the current transaction")
19331927
.def("rollback", &Connection::rollback, "Rollback the current transaction")
19341928
.def("set_autocommit", &Connection::setAutocommit)

mssql_python/pybind/ddbc_bindings.h

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -121,6 +121,7 @@ extern SQLGetDiagRecFunc SQLGetDiagRec_ptr;
121121
template <typename... Args>
122122
void LOG(const std::string& formatString, Args&&... args);
123123

124+
124125
// Throws a std::runtime_error with the given message
125126
void ThrowStdException(const std::string& message);
126127

@@ -169,3 +170,10 @@ class SqlHandle {
169170
SQLHANDLE _handle;
170171
};
171172
using SqlHandlePtr = std::shared_ptr<SqlHandle>;
173+
174+
// This struct is used to relay error info obtained from SQLDiagRec API to the Python module
175+
struct ErrorInfo {
176+
std::wstring sqlState;
177+
std::wstring ddbcErrorMsg;
178+
};
179+
ErrorInfo SQLCheckError_Wrap(SQLSMALLINT handleType, SqlHandlePtr handle, SQLRETURN retcode);

tests/test_005_exceptions.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -124,7 +124,7 @@ def test_foreign_key_constraint_error(cursor, db_connection):
124124
drop_table_if_exists(cursor, "pytest_parent_table")
125125
db_connection.commit()
126126

127-
# def test_connection_error(db_connection):
128-
# with pytest.raises(OperationalError) as excinfo:
129-
# Connection("InvalidConnectionString")
130-
# assert "Client unable to establish connection" in str(excinfo.value)
127+
def test_connection_error(db_connection):
128+
with pytest.raises(RuntimeError) as excinfo:
129+
Connection("InvalidConnectionString")
130+
assert "Neither DSN nor SERVER keyword supplied" in str(excinfo.value)

0 commit comments

Comments
 (0)