Skip to content

Commit 19bd82a

Browse files
bewithgauravSumit Sarabhai
authored andcommitted
Merged PR 5306: Cursor rowcount, description + Tests
#### AI description (iteration 1) #### PR Classification New feature and tests #### PR Summary This pull request implements new features for the `Cursor` class and adds corresponding tests. - Implemented `Cursor.rowcount` to track the number of affected rows in `mssql_python/cursor.py`. - Added tests for `Cursor.rowcount` in `tests/test_cursor.py`. - Updated `main.py` to include examples of using `Cursor.rowcount`. - Modified `.coveragerc` to exclude test files from coverage. <!-- GitOpsUserAgent=GitOps.Apps.Server.pullrequestcopilot --> Related work items: #32727, #32728, #33076
1 parent de214ca commit 19bd82a

6 files changed

Lines changed: 173 additions & 8 deletions

File tree

.coveragerc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
[run]
22
omit =
33
mssql_python/testing_ddbc_bindings.py
4+
tests/*
45

56
[report]
67
# Add any report-specific settings here, if needed

mssql_python/constants.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -104,3 +104,4 @@ class ConstantsODBC(Enum):
104104
SQL_PARAM_OUTPUT = 2
105105
SQL_PARAM_INPUT_OUTPUT = 3
106106
SQL_C_WCHAR = -8
107+
SQL_NULLABLE = 1

mssql_python/cursor.py

Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -409,6 +409,62 @@ def _create_parameter_types_list(self, parameter, ParamInfo, parameters_list, i)
409409
paraminfo.decimalDigits = decimal_digits
410410
return paraminfo
411411

412+
def _initialize_description(self):
413+
"""
414+
Initialize the description attribute using SQLDescribeCol.
415+
"""
416+
col_metadata = []
417+
ret = ddbc_bindings.DDBCSQLDescribeCol(self.hstmt.value, col_metadata)
418+
check_error(odbc_sql_const.SQL_HANDLE_STMT.value, self.hstmt.value, ret)
419+
420+
self.description = [
421+
(
422+
col["ColumnName"],
423+
self._map_data_type(col["DataType"]),
424+
None,
425+
col["ColumnSize"],
426+
col["ColumnSize"],
427+
col["DecimalDigits"],
428+
col["Nullable"] == odbc_sql_const.SQL_NULLABLE.value,
429+
)
430+
for col in col_metadata
431+
]
432+
433+
def _map_data_type(self, sql_type):
434+
"""
435+
Map SQL data type to Python data type.
436+
437+
Args:
438+
sql_type: SQL data type.
439+
440+
Returns:
441+
Corresponding Python data type.
442+
"""
443+
sql_to_python_type = {
444+
odbc_sql_const.SQL_INTEGER.value: int,
445+
odbc_sql_const.SQL_VARCHAR.value: str,
446+
odbc_sql_const.SQL_WVARCHAR.value: str,
447+
odbc_sql_const.SQL_CHAR.value: str,
448+
odbc_sql_const.SQL_WCHAR.value: str,
449+
odbc_sql_const.SQL_FLOAT.value: float,
450+
odbc_sql_const.SQL_DOUBLE.value: float,
451+
odbc_sql_const.SQL_DECIMAL.value: decimal.Decimal,
452+
odbc_sql_const.SQL_NUMERIC.value: decimal.Decimal,
453+
odbc_sql_const.SQL_DATE.value: datetime.date,
454+
odbc_sql_const.SQL_TIMESTAMP.value: datetime.datetime,
455+
odbc_sql_const.SQL_TIME.value: datetime.time,
456+
odbc_sql_const.SQL_BIT.value: bool,
457+
odbc_sql_const.SQL_TINYINT.value: int,
458+
odbc_sql_const.SQL_SMALLINT.value: int,
459+
odbc_sql_const.SQL_BIGINT.value: int,
460+
odbc_sql_const.SQL_BINARY.value: bytes,
461+
odbc_sql_const.SQL_VARBINARY.value: bytes,
462+
odbc_sql_const.SQL_LONGVARBINARY.value: bytes,
463+
odbc_sql_const.SQL_GUID.value: uuid.UUID,
464+
# Add more mappings as needed
465+
}
466+
return sql_to_python_type.get(sql_type, str)
467+
412468
def execute(self, operation: str, *parameters, use_prepare: bool = True, reset_cursor: bool = True):
413469
"""
414470
Prepare and execute a database operation (query or command).
@@ -452,6 +508,12 @@ def execute(self, operation: str, *parameters, use_prepare: bool = True, reset_c
452508
self.is_stmt_prepared, use_prepare)
453509
check_error(odbc_sql_const.SQL_HANDLE_STMT.value, self.hstmt.value, ret)
454510
self.last_executed_stmt = operation
511+
512+
# Update rowcount after execution
513+
self.rowcount = ddbc_bindings.DDBCSQLRowCount(self.hstmt.value)
514+
515+
# Initialize description after execution
516+
self._initialize_description()
455517
except Exception as e:
456518
if ENABLE_LOGGING:
457519
logging.error("An error occurred while executing query: %s", e)
@@ -475,6 +537,7 @@ def executemany(self, operation: str, seq_of_parameters: list) -> None:
475537
self._reset_cursor()
476538

477539
first_execution = True
540+
total_rowcount = 0
478541
for parameters in seq_of_parameters:
479542
# Execute the operation with the current set of parameters without
480543
# Converting the parameters to a list
@@ -491,6 +554,13 @@ def executemany(self, operation: str, seq_of_parameters: list) -> None:
491554
prepare_stmt = False
492555
# Execute statement with one parameter set
493556
self.execute(operation, parameters, use_prepare=prepare_stmt, reset_cursor=False)
557+
if self.rowcount != -1:
558+
# Rowcount would get updated inside execute method, add it to the current rowcount
559+
total_rowcount += self.rowcount
560+
else:
561+
total_rowcount = -1
562+
# Update the rowcount after all executions
563+
self.rowcount = total_rowcount
494564
except Exception as e:
495565
if ENABLE_LOGGING:
496566
logging.error("An error occurred while executing multiple queries: %s", e)

mssql_python/msvcp140.dll

562 KB
Binary file not shown.

mssql_python/pybind/ddbc_bindings.cpp

Lines changed: 30 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -196,6 +196,12 @@ SQLFreeStmtFunc SQLFreeStmt_ptr = nullptr;
196196
// Diagnostic record function pointer
197197
SQLGetDiagRecFunc SQLGetDiagRec_ptr = nullptr;
198198

199+
// Function pointer typedef for SQLRowCount
200+
typedef SQLRETURN (*SQLRowCountFunc)(SQLHSTMT, SQLLEN*);
201+
202+
// Function pointer for SQLRowCount
203+
SQLRowCountFunc SQLRowCount_ptr = nullptr;
204+
199205
namespace {
200206
// Helper to load the driver
201207
bool LoadDriver() {
@@ -247,6 +253,9 @@ bool LoadDriver() {
247253
// Diagnostic record function Loading
248254
SQLGetDiagRec_ptr = (SQLGetDiagRecFunc)GetProcAddress(hModule, "SQLGetDiagRecW");
249255

256+
// Load SQLRowCount function
257+
SQLRowCount_ptr = (SQLRowCountFunc)GetProcAddress(hModule, "SQLRowCount");
258+
250259
#ifdef _DEBUG
251260
std::cout << "Driver loaded successfully." << std::endl;
252261
#endif
@@ -256,7 +265,8 @@ bool LoadDriver() {
256265
SQLBindParameter_ptr && SQLExecute_ptr && SQLFetch_ptr && SQLFetchScroll_ptr &&
257266
SQLGetData_ptr && SQLNumResultCols_ptr && SQLBindCol_ptr && SQLDescribeCol_ptr &&
258267
SQLMoreResults_ptr && SQLColAttribute_ptr && SQLColAttribute_ptr && SQLEndTran_ptr &&
259-
SQLFreeHandle_ptr && SQLDisconnect_ptr && SQLFreeStmt_ptr && SQLGetDiagRec_ptr;
268+
SQLFreeHandle_ptr && SQLDisconnect_ptr && SQLFreeStmt_ptr && SQLGetDiagRec_ptr &&
269+
SQLRowCount_ptr;
260270
}
261271

262272
// TODO: Add more nuanced exception classes
@@ -1598,6 +1608,23 @@ SQLRETURN SQLDisconnect_wrap(intptr_t ConnectionHandle) {
15981608
return SQLDisconnect_ptr(reinterpret_cast<SQLHDBC>(ConnectionHandle));
15991609
}
16001610

1611+
// Wrap SQLRowCount
1612+
SQLLEN SQLRowCount_wrap(intptr_t StatementHandle) {
1613+
if (!SQLRowCount_ptr && !LoadDriver()) {
1614+
return -1;
1615+
}
1616+
1617+
SQLLEN rowCount;
1618+
SQLRETURN ret = SQLRowCount_ptr(reinterpret_cast<SQLHSTMT>(StatementHandle), &rowCount);
1619+
if (!SQL_SUCCEEDED(ret)) {
1620+
std::cerr << "SQLRowCount failed with error code: " << ret << std::endl;
1621+
return -1;
1622+
}
1623+
1624+
std::cout << "SQLRowCount returned: " << rowCount << std::endl; // Debug print
1625+
return rowCount;
1626+
}
1627+
16011628
// Bind the functions to the module
16021629
PYBIND11_MODULE(ddbc_bindings, m) {
16031630
m.doc() = "msodbcsql driver api bindings for Python"; // optional module docstring
@@ -1629,7 +1656,7 @@ PYBIND11_MODULE(ddbc_bindings, m) {
16291656
m.def("DDBCSQLGetConnectionAttr", &SQLGetConnectionAttr_wrap,
16301657
"Get an attribute that governs aspects of connections");
16311658
m.def("DDBCSQLDriverConnect", &SQLDriverConnect_wrap,
1632-
"Connect to a data source with a connection string");
1659+
"Connect to a data source with a connection string");
16331660
m.def("DDBCSQLExecDirect", &SQLExecDirect_wrap, "Execute a SQL query directly");
16341661
m.def("DDBCSQLExecute", &SQLExecute_wrap, "Prepare and execute T-SQL statements");
16351662
m.def("DDBCSQLFetch", &SQLFetch_wrap, "Fetch the next row from the result set");
@@ -1647,4 +1674,5 @@ PYBIND11_MODULE(ddbc_bindings, m) {
16471674
m.def("DDBCSQLFreeHandle", &SQLFreeHandle_wrap, "Free a handle");
16481675
m.def("DDBCSQLDisconnect", &SQLDisconnect_wrap, "Disconnect from a data source");
16491676
m.def("DDBCSQLCheckError", &SQLCheckError_Wrap, "Check for driver errors");
1677+
m.def("DDBCSQLRowCount", &SQLRowCount_wrap, "Get the number of rows affected by the last statement");
16501678
}

tests/test_cursor.py

Lines changed: 71 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,7 @@ def test_cursor(cursor):
6868
def test_insert_id_column(cursor, db_connection):
6969
"""Test inserting data into the id column"""
7070
try:
71+
drop_table_if_exists(cursor, "single_column")
7172
cursor.execute("CREATE TABLE single_column (id INTEGER PRIMARY KEY)")
7273
db_connection.commit()
7374
cursor.execute("INSERT INTO single_column (id) VALUES (?)", [1])
@@ -328,12 +329,64 @@ def test_parametrized_insert(cursor, db_connection, data):
328329
except Exception as e:
329330
pytest.fail(f"Parameterized data insertion failed: {e}")
330331

331-
# def test_rowcount(cursor, db_connection):
332-
# """Test rowcount"""
333-
# cursor.execute("SELECT * FROM all_data_types")
334-
# assert cursor.rowcount == -1, "Affected Rowcount should be -1"
335-
# cursor.execute("UPDATE all_data_types SET wvarchar_column = 'updated' where id > 2")
336-
# assert cursor.rowcount == 2, "Affected Rowcount should be 2"
332+
def test_rowcount(cursor, db_connection):
333+
"""Test rowcount after insert operations"""
334+
try:
335+
cursor.execute("CREATE TABLE test_rowcount (id INT IDENTITY(1,1) PRIMARY KEY, name NVARCHAR(100))")
336+
db_connection.commit()
337+
338+
cursor.execute("INSERT INTO test_rowcount (name) VALUES ('JohnDoe1');")
339+
assert cursor.rowcount == 1, "Rowcount should be 1 after first insert"
340+
341+
cursor.execute("INSERT INTO test_rowcount (name) VALUES ('JohnDoe2');")
342+
assert cursor.rowcount == 1, "Rowcount should be 1 after second insert"
343+
344+
cursor.execute("INSERT INTO test_rowcount (name) VALUES ('JohnDoe3');")
345+
assert cursor.rowcount == 1, "Rowcount should be 1 after third insert"
346+
347+
cursor.execute("""
348+
INSERT INTO test_rowcount (name)
349+
VALUES
350+
('JohnDoe4'),
351+
('JohnDoe5'),
352+
('JohnDoe6');
353+
""")
354+
assert cursor.rowcount == 3, "Rowcount should be 3 after inserting multiple rows"
355+
356+
cursor.execute("SELECT * FROM test_rowcount;")
357+
assert cursor.rowcount == -1, "Rowcount should be -1 after a SELECT statement"
358+
359+
db_connection.commit()
360+
except Exception as e:
361+
pytest.fail(f"Rowcount test failed: {e}")
362+
finally:
363+
cursor.execute("DROP TABLE test_rowcount")
364+
db_connection.commit()
365+
366+
def test_rowcount_executemany(cursor, db_connection):
367+
"""Test rowcount after executemany operations"""
368+
try:
369+
cursor.execute("CREATE TABLE test_rowcount (id INT IDENTITY(1,1) PRIMARY KEY, name NVARCHAR(100))")
370+
db_connection.commit()
371+
372+
data = [
373+
('JohnDoe1',),
374+
('JohnDoe2',),
375+
('JohnDoe3',)
376+
]
377+
378+
cursor.executemany("INSERT INTO test_rowcount (name) VALUES (?)", data)
379+
assert cursor.rowcount == 3, "Rowcount should be 3 after executemany insert"
380+
381+
cursor.execute("SELECT * FROM test_rowcount;")
382+
assert cursor.rowcount == -1, "Rowcount should be -1 after a SELECT statement"
383+
384+
db_connection.commit()
385+
except Exception as e:
386+
pytest.fail(f"Rowcount executemany test failed: {e}")
387+
finally:
388+
cursor.execute("DROP TABLE test_rowcount")
389+
db_connection.commit()
337390

338391
def test_fetchone(cursor):
339392
"""Test fetching a single row"""
@@ -585,6 +638,18 @@ def test_drop_tables_for_join(cursor, db_connection):
585638
except Exception as e:
586639
pytest.fail(f"Failed to drop tables for join operations: {e}")
587640

641+
def test_cursor_description(cursor):
642+
"""Test cursor description"""
643+
cursor.execute("SELECT database_id, name FROM sys.databases;")
644+
description = cursor.description
645+
expected_description = [
646+
('database_id', int, None, 10, 10, 0, False),
647+
('name', str, None, 128, 128, 0, False)
648+
]
649+
assert len(description) == len(expected_description), "Description length mismatch"
650+
for desc, expected in zip(description, expected_description):
651+
assert desc == expected, f"Description mismatch: {desc} != {expected}"
652+
588653
def test_close(cursor):
589654
"""Test closing the cursor"""
590655
cursor.close()

0 commit comments

Comments
 (0)