Skip to content

Commit 2c9c3d7

Browse files
committed
Merge branch 'saumya/conn_implementation' into saumya/integratec++class
2 parents e302c0c + 6a077c6 commit 2c9c3d7

4 files changed

Lines changed: 87 additions & 140 deletions

File tree

mssql_python/pybind/connection/connection.cpp

Lines changed: 66 additions & 117 deletions
Original file line numberDiff line numberDiff line change
@@ -10,176 +10,125 @@
1010

1111
#define SQL_COPT_SS_ACCESS_TOKEN 1256 // Custom attribute ID for access token
1212

13+
SqlHandlePtr Connection::_envHandle = nullptr;
1314
//-------------------------------------------------------------------------------------------------
1415
// Implements the Connection class declared in connection.h.
1516
// This class wraps low-level ODBC operations like connect/disconnect,
1617
// transaction control, and autocommit configuration.
1718
//-------------------------------------------------------------------------------------------------
1819
Connection::Connection(const std::wstring& conn_str, bool autocommit)
19-
: _conn_str(conn_str) , _autocommit(autocommit) {}
20-
21-
Connection::~Connection() {}
22-
23-
SQLRETURN Connection::connect(const py::dict& attrs_before) {
24-
allocDbcHandle();
25-
// Apply access token before connect
26-
if (!attrs_before.is_none() && py::len(attrs_before) > 0) {
27-
LOG("Apply attributes before connect");
28-
applyAttrsBefore(attrs_before);
29-
if (_autocommit) {
30-
setAutocommit(_autocommit);
20+
: _connStr(conn_str) , _autocommit(autocommit) {
21+
if (!_envHandle) {
22+
LOG("Allocating environment handle");
23+
SQLHANDLE env = nullptr;
24+
if (!SQLAllocHandle_ptr) {
25+
LOG("Function pointers not initialized, loading driver");
26+
DriverLoader::getInstance().loadDriver();
3127
}
28+
SQLRETURN ret = SQLAllocHandle_ptr(SQL_HANDLE_ENV, SQL_NULL_HANDLE, &env);
29+
checkError(ret, "Failed to allocate environment handle");
30+
_envHandle = std::make_shared<SqlHandle>(SQL_HANDLE_ENV, env);
31+
32+
LOG("Setting environment attributes");
33+
ret = SQLSetEnvAttr_ptr(_envHandle->get(), SQL_ATTR_ODBC_VERSION, (void*)SQL_OV_ODBC3_80, 0);
34+
checkError(ret, "Failed to set environment attribute");
3235
}
33-
return connectToDb();
36+
allocateDbcHandle();
3437
}
3538

36-
// Allocates DBC handle
37-
void Connection::allocDbcHandle() {
39+
Connection::~Connection() {
40+
disconnect(); // fallback if app forgets to disconnect
41+
}
42+
43+
// Allocates connection handle
44+
void Connection::allocateDbcHandle() {
3845
SQLHANDLE dbc = nullptr;
3946
LOG("Allocate SQL Connection Handle");
40-
SQLRETURN ret = SQLAllocHandle_ptr(SQL_HANDLE_DBC, getSharedEnvHandle()->get(), &dbc);
41-
if (!SQL_SUCCEEDED(ret)) {
42-
throw std::runtime_error("Failed to allocate connection handle");
43-
}
44-
_dbc_handle = std::make_shared<SqlHandle>(SQL_HANDLE_DBC, dbc);
47+
SQLRETURN ret = SQLAllocHandle_ptr(SQL_HANDLE_DBC, _envHandle->get(), &dbc);
48+
checkError(ret, "Failed to allocate connection handle");
49+
_dbcHandle = std::make_shared<SqlHandle>(SQL_HANDLE_DBC, dbc);
4550
}
4651

47-
// Connects to the database
48-
SQLRETURN Connection::connectToDb() {
52+
void Connection::connect() {
4953
LOG("Connecting to database");
50-
SQLRETURN ret = SQLDriverConnect_ptr(_dbc_handle->get(), nullptr,
51-
(SQLWCHAR*)_conn_str.c_str(), SQL_NTS,
52-
nullptr, 0, nullptr, SQL_DRIVER_NOPROMPT);
53-
if (!SQL_SUCCEEDED(ret)) {
54-
ThrowStdException("Client unable to establish connection");
55-
}
56-
LOG("Connected to database successfully");
57-
return ret;
54+
SQLRETURN ret = SQLDriverConnect_ptr(
55+
_dbcHandle->get(), nullptr,
56+
(SQLWCHAR*)_connStr.c_str(), SQL_NTS,
57+
nullptr, 0, nullptr, SQL_DRIVER_NOPROMPT);
58+
checkError(ret, "SQLDriverConnect failed");
59+
setAutocommit(_autocommit);
5860
}
5961

60-
SQLRETURN Connection::close() {
61-
if (!_dbc_handle) {
62-
LOG("No connection handle to close");
63-
return SQL_SUCCESS;
62+
void Connection::disconnect() {
63+
if (_dbcHandle) {
64+
LOG("Disconnecting from database");
65+
SQLRETURN ret = SQLDisconnect_ptr(_dbcHandle->get());
66+
checkError(ret, "Failed to disconnect from database");
67+
_dbcHandle.reset(); // triggers SQLFreeHandle via destructor, if last owner
6468
}
65-
LOG("Disconnect from MSSQL");
66-
if (!SQLDisconnect_ptr) {
67-
LOG("Function pointer not initialized. Loading the driver.");
68-
DriverLoader::getInstance().loadDriver();
69+
else {
70+
LOG("No connection handle to disconnect");
6971
}
72+
}
7073

71-
SQLRETURN ret = SQLDisconnect_ptr(_dbc_handle->get());
72-
_dbc_handle->free();
73-
return ret;
74+
void Connection::checkError(SQLRETURN ret, const std::string& msg) const{
75+
if (ret != SQL_SUCCESS && ret != SQL_SUCCESS_WITH_INFO) {
76+
throw std::runtime_error("[ODBC Error] " + msg);
77+
}
7478
}
7579

76-
SQLRETURN Connection::commit() {
77-
if (!_dbc_handle) {
80+
void Connection::commit() {
81+
if (!_dbcHandle) {
7882
throw std::runtime_error("Connection handle not allocated");
7983
}
8084
LOG("Committing transaction");
81-
SQLRETURN ret = SQLEndTran_ptr(SQL_HANDLE_DBC, _dbc_handle->get(), SQL_COMMIT);
82-
if (!SQL_SUCCEEDED(ret)) {
83-
throw std::runtime_error("Failed to commit transaction");
84-
}
85-
return ret;
85+
SQLRETURN ret = SQLEndTran_ptr(SQL_HANDLE_DBC, _dbcHandle->get(), SQL_COMMIT);
86+
checkError(ret, "Failed to commit transaction");
8687
}
8788

88-
SQLRETURN Connection::rollback() {
89-
if (!_dbc_handle) {
89+
void Connection::rollback() {
90+
if (!_dbcHandle) {
9091
throw std::runtime_error("Connection handle not allocated");
9192
}
9293
LOG("Rolling back transaction");
93-
SQLRETURN ret = SQLEndTran_ptr(SQL_HANDLE_DBC, _dbc_handle->get(), SQL_ROLLBACK);
94-
if (!SQL_SUCCEEDED(ret)) {
95-
throw std::runtime_error("Failed to rollback transaction");
96-
}
97-
return ret;
94+
SQLRETURN ret = SQLEndTran_ptr(SQL_HANDLE_DBC, _dbcHandle->get(), SQL_ROLLBACK);
95+
checkError(ret, "Failed to rollback transaction");
9896
}
9997

100-
SQLRETURN Connection::setAutocommit(bool enable) {
101-
SQLINTEGER value = enable ? SQL_AUTOCOMMIT_ON : SQL_AUTOCOMMIT_OFF;
102-
LOG("Set SQL Connection Attribute - Autocommit");
103-
SQLRETURN ret = SQLSetConnectAttr_ptr(_dbc_handle->get(), SQL_ATTR_AUTOCOMMIT, (SQLPOINTER)value, 0);
104-
if (!SQL_SUCCEEDED(ret)) {
105-
throw std::runtime_error("Failed to set autocommit mode.");
98+
void Connection::setAutocommit(bool enable) {
99+
if (!_dbcHandle) {
100+
throw std::runtime_error("Connection handle not allocated");
106101
}
102+
SQLINTEGER value = enable ? SQL_AUTOCOMMIT_ON : SQL_AUTOCOMMIT_OFF;
103+
LOG("Set SQL Connection Attribute");
104+
SQLRETURN ret = SQLSetConnectAttr_ptr(_dbcHandle->get(), SQL_ATTR_AUTOCOMMIT, (SQLPOINTER)value, 0);
105+
checkError(ret, "Failed to set autocommit attribute");
107106
_autocommit = enable;
108-
return ret;
109107
}
110108

111109
bool Connection::getAutocommit() const {
112-
if (!_dbc_handle) {
110+
if (!_dbcHandle) {
113111
throw std::runtime_error("Connection handle not allocated");
114112
}
115113
LOG("Get SQL Connection Attribute");
116114
SQLINTEGER value;
117115
SQLINTEGER string_length;
118-
SQLGetConnectAttr_ptr(_dbc_handle->get(), SQL_ATTR_AUTOCOMMIT, &value, sizeof(value), &string_length);
116+
SQLRETURN ret = SQLGetConnectAttr_ptr(_dbcHandle->get(), SQL_ATTR_AUTOCOMMIT, &value, sizeof(value), &string_length);
117+
checkError(ret, "Failed to get autocommit attribute");
119118
return value == SQL_AUTOCOMMIT_ON;
120119
}
121120

122121
SqlHandlePtr Connection::allocStatementHandle() {
123-
if (!_dbc_handle) {
122+
if (!_dbcHandle) {
124123
throw std::runtime_error("Connection handle not allocated");
125124
}
126125
LOG("Allocating statement handle");
127126
SQLHANDLE stmt = nullptr;
128-
SQLRETURN ret = SQLAllocHandle_ptr(SQL_HANDLE_STMT, _dbc_handle->get(), &stmt);
129-
if (!SQL_SUCCEEDED(ret)) {
130-
throw std::runtime_error("Failed to allocate statement handle");
131-
}
127+
SQLRETURN ret = SQLAllocHandle_ptr(SQL_HANDLE_STMT, _dbcHandle->get(), &stmt);
128+
checkError(ret, "Failed to allocate statement handle");
132129
return std::make_shared<SqlHandle>(SQL_HANDLE_STMT, stmt);
133130
}
134131

135-
SQLRETURN Connection::setAttribute(SQLINTEGER attribute, py::object value) {
136-
LOG("Setting SQL attribute");
137-
138-
SQLPOINTER ptr = nullptr;
139-
SQLINTEGER length = 0;
140-
141-
if (py::isinstance<py::int_>(value)) {
142-
int intValue = value.cast<int>();
143-
ptr = reinterpret_cast<SQLPOINTER>(static_cast<uintptr_t>(intValue));
144-
length = SQL_IS_INTEGER;
145-
} else if (py::isinstance<py::bytes>(value) || py::isinstance<py::bytearray>(value)) {
146-
static std::vector<std::string> buffers;
147-
buffers.emplace_back(value.cast<std::string>());
148-
ptr = const_cast<char*>(buffers.back().c_str());
149-
length = static_cast<SQLINTEGER>(buffers.back().size());
150-
} else {
151-
LOG("Unsupported attribute value type");
152-
return SQL_ERROR;
153-
}
154-
155-
SQLRETURN ret = SQLSetConnectAttr_ptr(_dbc_handle->get(), attribute, ptr, length);
156-
if (!SQL_SUCCEEDED(ret)) {
157-
LOG("Failed to set attribute");
158-
}
159-
else {
160-
LOG("Set attribute successfully");
161-
}
162-
return ret;
163-
}
164-
165-
void Connection::applyAttrsBefore(const py::dict& attrs) {
166-
for (const auto& item : attrs) {
167-
int key;
168-
try {
169-
key = py::cast<int>(item.first);
170-
} catch (...) {
171-
continue;
172-
}
173-
174-
if (key == SQL_COPT_SS_ACCESS_TOKEN) {
175-
SQLRETURN ret = setAttribute(key, py::reinterpret_borrow<py::object>(item.second));
176-
if (!SQL_SUCCEEDED(ret)) {
177-
throw std::runtime_error("Failed to set access token before connect");
178-
}
179-
}
180-
}
181-
}
182-
183132
SqlHandlePtr Connection::getSharedEnvHandle() {
184133
static std::once_flag flag;
185134
static SqlHandlePtr env_handle;

mssql_python/pybind/connection/connection.h

Lines changed: 16 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,7 @@
44
// INFO|TODO - Note that is file is Windows specific right now. Making it arch agnostic will be
55
// taken up in future.
66

7-
#ifndef CONNECTION_H
8-
#define CONNECTION_H
9-
7+
#pragma once
108
#include "ddbc_bindings.h"
119

1210
// Represents a single ODBC database connection.
@@ -19,19 +17,19 @@ class Connection {
1917
~Connection();
2018

2119
// Establish the connection using the stored connection string.
22-
SQLRETURN connect(const py::dict& attrs_before = py::dict());
20+
void connect();
2321

24-
// Close the connection and free resources.
25-
SQLRETURN close();
22+
// Disconnect and free the connection handle.
23+
void disconnect();
2624

2725
// Commit the current transaction.
28-
SQLRETURN commit();
26+
void commit();
2927

3028
// Rollback the current transaction.
31-
SQLRETURN rollback();
29+
void rollback();
3230

3331
// Enable or disable autocommit mode.
34-
SQLRETURN setAutocommit(bool value);
32+
void setAutocommit(bool value);
3533

3634
// Check whether autocommit is enabled.
3735
bool getAutocommit() const;
@@ -40,16 +38,13 @@ class Connection {
4038
SqlHandlePtr allocStatementHandle();
4139

4240
private:
43-
void allocDbcHandle();
44-
SQLRETURN connectToDb();
45-
46-
std::wstring _conn_str;
47-
SqlHandlePtr _dbc_handle;
48-
bool _autocommit = false;
49-
50-
static SqlHandlePtr getSharedEnvHandle();
51-
SQLRETURN setAttribute(SQLINTEGER attribute, pybind11::object value);
52-
void applyAttrsBefore(const pybind11::dict& attrs);
53-
};
41+
void allocateDbcHandle();
42+
void checkError(SQLRETURN ret, const std::string& msg) const;
5443

55-
#endif // CONNECTION_H
44+
std::wstring _connStr;
45+
bool _usePool = false;
46+
bool _autocommit = true;
47+
SqlHandlePtr _dbcHandle;
48+
49+
static SqlHandlePtr _envHandle;
50+
};

mssql_python/pybind/ddbc_bindings.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -619,10 +619,10 @@ DriverLoader& DriverLoader::getInstance() {
619619
}
620620

621621
void DriverLoader::loadDriver() {
622-
if (!m_driverLoaded) {
622+
std::call_once(m_onceFlag, [this]() {
623623
LoadDriverOrThrowException();
624624
m_driverLoaded = true;
625-
}
625+
});
626626
}
627627

628628
// SqlHandle definition

mssql_python/pybind/ddbc_bindings.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
#include <sql.h>
1414
#include <sqlext.h>
1515
#include <memory>
16+
#include <mutex>
1617

1718
#include <pybind11/chrono.h>
1819
#include <pybind11/complex.h>
@@ -145,7 +146,9 @@ class DriverLoader {
145146
DriverLoader();
146147
DriverLoader(const DriverLoader&) = delete;
147148
DriverLoader& operator=(const DriverLoader&) = delete;
149+
148150
bool m_driverLoaded;
151+
std::once_flag m_onceFlag;
149152
};
150153

151154
//-------------------------------------------------------------------------------------------------

0 commit comments

Comments
 (0)