diff --git a/mssql_python/pybind/connection/connection.cpp b/mssql_python/pybind/connection/connection.cpp index aaf172b9..711ae2fb 100644 --- a/mssql_python/pybind/connection/connection.cpp +++ b/mssql_python/pybind/connection/connection.cpp @@ -191,7 +191,12 @@ void Connection::commit() { } updateLastUsed(); LOG("Committing transaction"); - SQLRETURN ret = SQLEndTran_ptr(SQL_HANDLE_DBC, _dbcHandle->get(), SQL_COMMIT); + SQLRETURN ret; + { + // Release the GIL during the blocking SQLEndTran network round-trip. + py::gil_scoped_release release; + ret = SQLEndTran_ptr(SQL_HANDLE_DBC, _dbcHandle->get(), SQL_COMMIT); + } checkError(ret); } @@ -201,7 +206,12 @@ void Connection::rollback() { } updateLastUsed(); LOG("Rolling back transaction"); - SQLRETURN ret = SQLEndTran_ptr(SQL_HANDLE_DBC, _dbcHandle->get(), SQL_ROLLBACK); + SQLRETURN ret; + { + // Release the GIL during the blocking SQLEndTran network round-trip. + py::gil_scoped_release release; + ret = SQLEndTran_ptr(SQL_HANDLE_DBC, _dbcHandle->get(), SQL_ROLLBACK); + } checkError(ret); } diff --git a/mssql_python/pybind/ddbc_bindings.cpp b/mssql_python/pybind/ddbc_bindings.cpp index 9d007653..f0a5de75 100644 --- a/mssql_python/pybind/ddbc_bindings.cpp +++ b/mssql_python/pybind/ddbc_bindings.cpp @@ -1409,6 +1409,8 @@ SQLRETURN SQLGetTypeInfo_Wrapper(SqlHandlePtr StatementHandle, SQLSMALLINT DataT ThrowStdException("SQLGetTypeInfo function not loaded"); } + // Release the GIL during the blocking ODBC catalog call + py::gil_scoped_release release; return SQLGetTypeInfo_ptr(StatementHandle->get(), DataType); } @@ -1431,6 +1433,8 @@ SQLRETURN SQLProcedures_wrap(SqlHandlePtr StatementHandle, const py::object& cat std::vector schemaBuf = WStringToSQLWCHAR(schema); std::vector procedureBuf = WStringToSQLWCHAR(procedure); + // Release the GIL during the blocking ODBC catalog call + py::gil_scoped_release release; return SQLProcedures_ptr( StatementHandle->get(), catalog.empty() ? nullptr : catalogBuf.data(), catalog.empty() ? 0 : SQL_NTS, schema.empty() ? nullptr : schemaBuf.data(), @@ -1438,6 +1442,7 @@ SQLRETURN SQLProcedures_wrap(SqlHandlePtr StatementHandle, const py::object& cat procedure.empty() ? 0 : SQL_NTS); #else // Windows implementation + py::gil_scoped_release release; return SQLProcedures_ptr( StatementHandle->get(), catalog.empty() ? nullptr : (SQLWCHAR*)catalog.c_str(), catalog.empty() ? 0 : SQL_NTS, schema.empty() ? nullptr : (SQLWCHAR*)schema.c_str(), @@ -1476,6 +1481,8 @@ SQLRETURN SQLForeignKeys_wrap(SqlHandlePtr StatementHandle, const py::object& pk std::vector fkSchemaBuf = WStringToSQLWCHAR(fkSchema); std::vector fkTableBuf = WStringToSQLWCHAR(fkTable); + // Release the GIL during the blocking ODBC catalog call + py::gil_scoped_release release; return SQLForeignKeys_ptr( StatementHandle->get(), pkCatalog.empty() ? nullptr : pkCatalogBuf.data(), pkCatalog.empty() ? 0 : SQL_NTS, pkSchema.empty() ? nullptr : pkSchemaBuf.data(), @@ -1486,6 +1493,7 @@ SQLRETURN SQLForeignKeys_wrap(SqlHandlePtr StatementHandle, const py::object& pk fkTable.empty() ? 0 : SQL_NTS); #else // Windows implementation + py::gil_scoped_release release; return SQLForeignKeys_ptr( StatementHandle->get(), pkCatalog.empty() ? nullptr : (SQLWCHAR*)pkCatalog.c_str(), pkCatalog.empty() ? 0 : SQL_NTS, pkSchema.empty() ? nullptr : (SQLWCHAR*)pkSchema.c_str(), @@ -1513,6 +1521,8 @@ SQLRETURN SQLPrimaryKeys_wrap(SqlHandlePtr StatementHandle, const py::object& ca std::vector schemaBuf = WStringToSQLWCHAR(schema); std::vector tableBuf = WStringToSQLWCHAR(table); + // Release the GIL during the blocking ODBC catalog call + py::gil_scoped_release release; return SQLPrimaryKeys_ptr( StatementHandle->get(), catalog.empty() ? nullptr : catalogBuf.data(), catalog.empty() ? 0 : SQL_NTS, schema.empty() ? nullptr : schemaBuf.data(), @@ -1520,6 +1530,7 @@ SQLRETURN SQLPrimaryKeys_wrap(SqlHandlePtr StatementHandle, const py::object& ca table.empty() ? 0 : SQL_NTS); #else // Windows implementation + py::gil_scoped_release release; return SQLPrimaryKeys_ptr( StatementHandle->get(), catalog.empty() ? nullptr : (SQLWCHAR*)catalog.c_str(), catalog.empty() ? 0 : SQL_NTS, schema.empty() ? nullptr : (SQLWCHAR*)schema.c_str(), @@ -1545,6 +1556,8 @@ SQLRETURN SQLStatistics_wrap(SqlHandlePtr StatementHandle, const py::object& cat std::vector schemaBuf = WStringToSQLWCHAR(schema); std::vector tableBuf = WStringToSQLWCHAR(table); + // Release the GIL during the blocking ODBC catalog call + py::gil_scoped_release release; return SQLStatistics_ptr( StatementHandle->get(), catalog.empty() ? nullptr : catalogBuf.data(), catalog.empty() ? 0 : SQL_NTS, schema.empty() ? nullptr : schemaBuf.data(), @@ -1552,6 +1565,7 @@ SQLRETURN SQLStatistics_wrap(SqlHandlePtr StatementHandle, const py::object& cat table.empty() ? 0 : SQL_NTS, unique, reserved); #else // Windows implementation + py::gil_scoped_release release; return SQLStatistics_ptr( StatementHandle->get(), catalog.empty() ? nullptr : (SQLWCHAR*)catalog.c_str(), catalog.empty() ? 0 : SQL_NTS, schema.empty() ? nullptr : (SQLWCHAR*)schema.c_str(), @@ -1580,6 +1594,8 @@ SQLRETURN SQLColumns_wrap(SqlHandlePtr StatementHandle, const py::object& catalo std::vector tableBuf = WStringToSQLWCHAR(tableStr); std::vector columnBuf = WStringToSQLWCHAR(columnStr); + // Release the GIL during the blocking ODBC catalog call + py::gil_scoped_release release; return SQLColumns_ptr( StatementHandle->get(), catalogStr.empty() ? nullptr : catalogBuf.data(), catalogStr.empty() ? 0 : SQL_NTS, schemaStr.empty() ? nullptr : schemaBuf.data(), @@ -1588,6 +1604,7 @@ SQLRETURN SQLColumns_wrap(SqlHandlePtr StatementHandle, const py::object& catalo columnStr.empty() ? 0 : SQL_NTS); #else // Windows implementation + py::gil_scoped_release release; return SQLColumns_ptr( StatementHandle->get(), catalogStr.empty() ? nullptr : (SQLWCHAR*)catalogStr.c_str(), catalogStr.empty() ? 0 : SQL_NTS, @@ -1728,7 +1745,14 @@ SQLRETURN SQLExecDirect_wrap(SqlHandlePtr StatementHandle, const std::wstring& Q #else queryPtr = const_cast(Query.c_str()); #endif - SQLRETURN ret = SQLExecDirect_ptr(StatementHandle->get(), queryPtr, SQL_NTS); + SQLRETURN ret; + { + // Release the GIL during the blocking ODBC call so that other Python + // threads (e.g. asyncio event loop, heartbeat threads) can run while + // SQL Server executes the query. See issue #540. + py::gil_scoped_release release; + ret = SQLExecDirect_ptr(StatementHandle->get(), queryPtr, SQL_NTS); + } if (!SQL_SUCCEEDED(ret)) { LOG("SQLExecDirect: Query execution failed - SQLRETURN=%d", ret); } @@ -1800,8 +1824,13 @@ SQLRETURN SQLTables_wrap(SqlHandlePtr StatementHandle, const std::wstring& catal } #endif - SQLRETURN ret = SQLTables_ptr(StatementHandle->get(), catalogPtr, catalogLen, schemaPtr, - schemaLen, tablePtr, tableLen, tableTypePtr, tableTypeLen); + SQLRETURN ret; + { + // Release the GIL during the blocking ODBC catalog call + py::gil_scoped_release release; + ret = SQLTables_ptr(StatementHandle->get(), catalogPtr, catalogLen, schemaPtr, + schemaLen, tablePtr, tableLen, tableTypePtr, tableTypeLen); + } LOG("SQLTables: Catalog metadata query %s - SQLRETURN=%d", SQL_SUCCEEDED(ret) ? "succeeded" : "failed", ret); @@ -1858,7 +1887,11 @@ SQLRETURN SQLExecute_wrap(const SqlHandlePtr statementHandle, // is the fastest way to submit a SQL statement for one-time execution // according to DDBC documentation - // https://learn.microsoft.com/en-us/sql/odbc/reference/syntax/sqlexecdirect-function?view=sql-server-ver16 - rc = SQLExecDirect_ptr(hStmt, queryPtr, SQL_NTS); + { + // Release the GIL during the blocking ODBC call + py::gil_scoped_release release; + rc = SQLExecDirect_ptr(hStmt, queryPtr, SQL_NTS); + } if (!SQL_SUCCEEDED(rc) && rc != SQL_NO_DATA) { LOG("SQLExecute: Direct execution failed (non-parameterized query) " "- SQLRETURN=%d", @@ -1872,7 +1905,11 @@ SQLRETURN SQLExecute_wrap(const SqlHandlePtr statementHandle, // element assert(isStmtPrepared.size() == 1); if (usePrepare) { - rc = SQLPrepare_ptr(hStmt, queryPtr, SQL_NTS); + { + // Release the GIL during the blocking SQLPrepare network call. + py::gil_scoped_release release; + rc = SQLPrepare_ptr(hStmt, queryPtr, SQL_NTS); + } if (!SQL_SUCCEEDED(rc)) { LOG("SQLExecute: SQLPrepare failed - SQLRETURN=%d, " "statement_handle=%p", @@ -1904,12 +1941,27 @@ SQLRETURN SQLExecute_wrap(const SqlHandlePtr statementHandle, return rc; } - rc = SQLExecute_ptr(hStmt); + { + // Release the GIL during the blocking SQLExecute network call. + py::gil_scoped_release release; + rc = SQLExecute_ptr(hStmt); + } if (rc == SQL_NEED_DATA) { LOG("SQLExecute: SQL_NEED_DATA received - Starting DAE " "(Data-At-Execution) loop for large parameter streaming"); SQLPOINTER paramToken = nullptr; - while ((rc = SQLParamData_ptr(hStmt, ¶mToken)) == SQL_NEED_DATA) { + // For DAE, release the GIL only around individual ODBC calls; + // Python type inspection of the parameter happens between calls + // and requires the GIL. + auto paramData = [&](SQLPOINTER* tok) { + py::gil_scoped_release release; + return SQLParamData_ptr(hStmt, tok); + }; + auto putData = [&](SQLPOINTER data, SQLLEN len) { + py::gil_scoped_release release; + return SQLPutData_ptr(hStmt, data, len); + }; + while ((rc = paramData(¶mToken)) == SQL_NEED_DATA) { // Finding the paramInfo that matches the returned token const ParamInfo* matchedInfo = nullptr; for (auto& info : paramInfos) { @@ -1923,7 +1975,7 @@ SQLRETURN SQLExecute_wrap(const SqlHandlePtr statementHandle, } const py::object& pyObj = matchedInfo->dataPtr; if (pyObj.is_none()) { - SQLPutData_ptr(hStmt, nullptr, 0); + putData(nullptr, 0); continue; } if (py::isinstance(pyObj)) { @@ -1949,8 +2001,8 @@ SQLRETURN SQLExecute_wrap(const SqlHandlePtr statementHandle, ThrowStdException("Chunk size exceeds maximum " "allowed by SQLLEN"); } - rc = SQLPutData_ptr(hStmt, (SQLPOINTER)(dataPtr + offset), - static_cast(lenBytes)); + rc = putData((SQLPOINTER)(dataPtr + offset), + static_cast(lenBytes)); if (!SQL_SUCCEEDED(rc)) { LOG("SQLExecute: SQLPutData failed for " "SQL_C_WCHAR chunk - offset=%zu", @@ -1984,8 +2036,8 @@ SQLRETURN SQLExecute_wrap(const SqlHandlePtr statementHandle, while (offset < totalBytes) { size_t len = std::min(chunkBytes, totalBytes - offset); - rc = SQLPutData_ptr(hStmt, (SQLPOINTER)(dataPtr + offset), - static_cast(len)); + rc = putData((SQLPOINTER)(dataPtr + offset), + static_cast(len)); if (!SQL_SUCCEEDED(rc)) { LOG("SQLExecute: SQLPutData failed for " "SQL_C_CHAR chunk - offset=%zu", @@ -2006,8 +2058,8 @@ SQLRETURN SQLExecute_wrap(const SqlHandlePtr statementHandle, const size_t chunkSize = DAE_CHUNK_SIZE; for (size_t offset = 0; offset < totalBytes; offset += chunkSize) { size_t len = std::min(chunkSize, totalBytes - offset); - rc = SQLPutData_ptr(hStmt, (SQLPOINTER)(dataPtr + offset), - static_cast(len)); + rc = putData((SQLPOINTER)(dataPtr + offset), + static_cast(len)); if (!SQL_SUCCEEDED(rc)) { LOG("SQLExecute: SQLPutData failed for " "binary/bytes chunk - offset=%zu", @@ -2732,7 +2784,12 @@ SQLRETURN SQLExecuteMany_wrap(const SqlHandlePtr statementHandle, const std::wst queryPtr = const_cast(query.c_str()); LOG("SQLExecuteMany: Using wide string query directly"); #endif - RETCODE rc = SQLPrepare_ptr(hStmt, queryPtr, SQL_NTS); + RETCODE rc; + { + // Release the GIL during the blocking SQLPrepare network call. + py::gil_scoped_release release; + rc = SQLPrepare_ptr(hStmt, queryPtr, SQL_NTS); + } if (!SQL_SUCCEEDED(rc)) { LOG("SQLExecuteMany: SQLPrepare failed - rc=%d", rc); return rc; @@ -2773,7 +2830,11 @@ SQLRETURN SQLExecuteMany_wrap(const SqlHandlePtr statementHandle, const std::wst } LOG("SQLExecuteMany: PARAMSET_SIZE set to %zu", paramSetSize); - rc = SQLExecute_ptr(hStmt); + { + // Release the GIL during the blocking SQLExecute network call. + py::gil_scoped_release release; + rc = SQLExecute_ptr(hStmt); + } LOG("SQLExecuteMany: SQLExecute completed - rc=%d", rc); return rc; } else { @@ -2793,12 +2854,20 @@ SQLRETURN SQLExecuteMany_wrap(const SqlHandlePtr statementHandle, const std::wst } LOG("SQLExecuteMany: Parameters bound for row %zu", rowIndex); - rc = SQLExecute_ptr(hStmt); + { + // Release the GIL during the blocking SQLExecute network call. + py::gil_scoped_release release; + rc = SQLExecute_ptr(hStmt); + } LOG("SQLExecuteMany: SQLExecute for row %zu - initial_rc=%d", rowIndex, rc); size_t dae_chunk_count = 0; while (rc == SQL_NEED_DATA) { SQLPOINTER token; - rc = SQLParamData_ptr(hStmt, &token); + { + // Release the GIL around the blocking SQLParamData call. + py::gil_scoped_release release; + rc = SQLParamData_ptr(hStmt, &token); + } LOG("SQLExecuteMany: SQLParamData called - chunk=%zu, rc=%d, " "token=%p", dae_chunk_count, rc, token); @@ -2821,7 +2890,10 @@ SQLRETURN SQLExecuteMany_wrap(const SqlHandlePtr statementHandle, const std::wst LOG("SQLExecuteMany: Sending string DAE data - chunk=%zu, " "length=%lld", dae_chunk_count, static_cast(data_len)); - rc = SQLPutData_ptr(hStmt, (SQLPOINTER)data.c_str(), data_len); + rc = [&] { + py::gil_scoped_release release; + return SQLPutData_ptr(hStmt, (SQLPOINTER)data.c_str(), data_len); + }(); if (!SQL_SUCCEEDED(rc) && rc != SQL_NEED_DATA) { LOG("SQLExecuteMany: SQLPutData(string) failed - " "chunk=%zu, rc=%d", @@ -2834,7 +2906,10 @@ SQLRETURN SQLExecuteMany_wrap(const SqlHandlePtr statementHandle, const std::wst LOG("SQLExecuteMany: Sending bytes/bytearray DAE data - " "chunk=%zu, length=%lld", dae_chunk_count, static_cast(data_len)); - rc = SQLPutData_ptr(hStmt, (SQLPOINTER)data.c_str(), data_len); + rc = [&] { + py::gil_scoped_release release; + return SQLPutData_ptr(hStmt, (SQLPOINTER)data.c_str(), data_len); + }(); if (!SQL_SUCCEEDED(rc) && rc != SQL_NEED_DATA) { LOG("SQLExecuteMany: SQLPutData(bytes) failed - " "chunk=%zu, rc=%d", @@ -2943,6 +3018,8 @@ SQLRETURN SQLSpecialColumns_wrap(SqlHandlePtr StatementHandle, SQLSMALLINT ident std::vector schemaBuf = WStringToSQLWCHAR(schema); std::vector tableBuf = WStringToSQLWCHAR(table); + // Release the GIL during the blocking ODBC catalog call + py::gil_scoped_release release; return SQLSpecialColumns_ptr( StatementHandle->get(), identifierType, catalog.empty() ? nullptr : catalogBuf.data(), catalog.empty() ? 0 : SQL_NTS, schema.empty() ? nullptr : schemaBuf.data(), @@ -2950,6 +3027,7 @@ SQLRETURN SQLSpecialColumns_wrap(SqlHandlePtr StatementHandle, SQLSMALLINT ident table.empty() ? 0 : SQL_NTS, scope, nullable); #else // Windows implementation + py::gil_scoped_release release; return SQLSpecialColumns_ptr( StatementHandle->get(), identifierType, catalog.empty() ? nullptr : (SQLWCHAR*)catalog.c_str(), catalog.empty() ? 0 : SQL_NTS, @@ -2967,6 +3045,8 @@ SQLRETURN SQLFetch_wrap(SqlHandlePtr StatementHandle) { DriverLoader::getInstance().loadDriver(); // Load the driver } + // Release the GIL during the blocking ODBC call + py::gil_scoped_release release; return SQLFetch_ptr(StatementHandle->get()); } @@ -2981,7 +3061,11 @@ py::object FetchLobColumnData(SQLHSTMT hStmt, SQLUSMALLINT colIndex, SQLSMALLINT ++loopCount; std::vector chunk(DAE_CHUNK_SIZE, 0); SQLLEN actualRead = 0; - ret = SQLGetData_ptr(hStmt, colIndex, cType, chunk.data(), DAE_CHUNK_SIZE, &actualRead); + { + // Release the GIL during blocking SQLGetData LOB streaming + py::gil_scoped_release release; + ret = SQLGetData_ptr(hStmt, colIndex, cType, chunk.data(), DAE_CHUNK_SIZE, &actualRead); + } if (ret == SQL_ERROR || !SQL_SUCCEEDED(ret) && ret != SQL_SUCCESS_WITH_INFO) { std::ostringstream oss; @@ -3741,7 +3825,12 @@ SQLRETURN SQLFetchScroll_wrap(SqlHandlePtr StatementHandle, SQLSMALLINT FetchOri SQLFreeStmt_ptr(StatementHandle->get(), SQL_UNBIND); // Perform scroll operation - SQLRETURN ret = SQLFetchScroll_ptr(StatementHandle->get(), FetchOrientation, FetchOffset); + SQLRETURN ret; + { + // Release the GIL during the blocking ODBC fetch + py::gil_scoped_release release; + ret = SQLFetchScroll_ptr(StatementHandle->get(), FetchOrientation, FetchOffset); + } // If successful and caller wants data, retrieve it if (SQL_SUCCEEDED(ret) && row_data.size() == 0) { @@ -3929,7 +4018,12 @@ SQLRETURN FetchBatchData(SQLHSTMT hStmt, ColumnBuffers& buffers, py::list& colum const std::vector& lobColumns, const std::string& charEncoding = "utf-8") { LOG("FetchBatchData: Fetching data in batches"); - SQLRETURN ret = SQLFetchScroll_ptr(hStmt, SQL_FETCH_NEXT, 0); + SQLRETURN ret; + { + // Release the GIL during the blocking ODBC fetch + py::gil_scoped_release release; + ret = SQLFetchScroll_ptr(hStmt, SQL_FETCH_NEXT, 0); + } if (ret == SQL_NO_DATA) { LOG("FetchBatchData: No data to fetch"); return ret; @@ -4404,7 +4498,11 @@ SQLRETURN FetchMany_wrap(SqlHandlePtr StatementHandle, py::list& rows, int fetch "SQLGetData path", lobColumns.size()); while (numRowsFetched < (SQLULEN)fetchSize) { - ret = SQLFetch_ptr(hStmt); + { + // Release GIL during the blocking fetch + py::gil_scoped_release release; + ret = SQLFetch_ptr(hStmt); + } if (ret == SQL_NO_DATA) break; if (!SQL_SUCCEEDED(ret)) @@ -4789,7 +4887,11 @@ SQLRETURN FetchArrowBatch_wrap( // Adjust fetch size for final batch to avoid overfetching fetchStateGuard.setRowArraySize(spaceLeftInArrowBatch); } - ret = SQLFetch_ptr(hStmt); + { + // Release GIL during the blocking ODBC fetch + py::gil_scoped_release release; + ret = SQLFetch_ptr(hStmt); + } if (ret == SQL_NO_DATA) { ret = SQL_SUCCESS; // Normal completion break; @@ -5541,7 +5643,11 @@ SQLRETURN FetchAll_wrap(SqlHandlePtr StatementHandle, py::list& rows, "SQLGetData path", lobColumns.size()); while (true) { - ret = SQLFetch_ptr(hStmt); + { + // Release GIL during the blocking fetch + py::gil_scoped_release release; + ret = SQLFetch_ptr(hStmt); + } if (ret == SQL_NO_DATA) break; if (!SQL_SUCCEEDED(ret)) @@ -5656,7 +5762,11 @@ SQLRETURN FetchOne_wrap(SqlHandlePtr StatementHandle, py::list& row, SQLFreeStmt_ptr(hStmt, SQL_UNBIND); // Assume hStmt is already allocated and a query has been executed - ret = SQLFetch_ptr(hStmt); + { + // Release the GIL during the blocking ODBC fetch + py::gil_scoped_release release; + ret = SQLFetch_ptr(hStmt); + } if (SQL_SUCCEEDED(ret)) { // Retrieve column count SQLSMALLINT colCount = SQLNumResultCols_wrap(StatementHandle); @@ -5680,6 +5790,8 @@ SQLRETURN SQLMoreResults_wrap(SqlHandlePtr StatementHandle) { DriverLoader::getInstance().loadDriver(); // Load the driver } + // Release the GIL during the blocking ODBC call + py::gil_scoped_release release; return SQLMoreResults_ptr(StatementHandle->get()); } diff --git a/tests/test_001_globals.py b/tests/test_001_globals.py index c18769f1..9e44e011 100644 --- a/tests/test_001_globals.py +++ b/tests/test_001_globals.py @@ -152,20 +152,35 @@ def worker(): mssql_python.lowercase = original_lowercase -def test_lowercase_concurrent_access_with_db(db_connection): +def test_lowercase_concurrent_access_with_db(conn_str, db_connection): """ Tests concurrent modification of the 'lowercase' setting while simultaneously creating cursors and executing queries. This simulates a real-world race condition. + + Each reader thread uses its own dedicated connection. Sharing a single ODBC + connection (HDBC) across threads is not supported by the driver without MARS + (the package advertises threadsafety=1 and cursor methods are documented as + not thread-safe), so the cross-thread parallelism we want to exercise here is + "one connection per thread" — not "many cursors on one connection". """ + if not conn_str: + pytest.skip("DB_CONNECTION_STRING not set") + original_lowercase = mssql_python.lowercase stop_event = threading.Event() errors = [] - # Create a temporary table for the test + # Use a global temp table so each thread's own connection can see it. + # A unique suffix avoids collisions with parallel test runs. + table_name = f"##pytest_thread_test_{random.randint(0, 2**31 - 1)}" + + # Create the global temp table on the shared db_connection. Because at least + # one connection (db_connection) keeps a reference to it for the whole test, + # the table persists for the duration of the reader threads' lifetimes. cursor = None try: cursor = db_connection.cursor() - cursor.execute("CREATE TABLE #pytest_thread_test (COLUMN_NAME INT)") + cursor.execute(f"CREATE TABLE {table_name} (COLUMN_NAME INT)") db_connection.commit() except Exception as e: pytest.fail(f"Failed to create test table: {e}") @@ -186,25 +201,40 @@ def writer(): break def reader(): - """Continuously creates cursors and checks for valid description casing.""" - while not stop_event.is_set(): - cursor = None + """Opens its own connection and continuously creates cursors and checks + for valid description casing.""" + local_conn = None + try: + local_conn = mssql_python.connect(conn_str) + except Exception as e: + errors.append(f"Reader thread connect error: {e}") + return + try: + while not stop_event.is_set(): + cursor = None + try: + cursor = local_conn.cursor() + cursor.execute(f"SELECT * FROM {table_name}") + + # The lock ensures the description is generated atomically. + # We just need to check if the result is one of the two valid states. + col_name = cursor.description[0][0] + + if col_name not in ("COLUMN_NAME", "column_name"): + errors.append( + f"Invalid column name '{col_name}' found. Race condition likely." + ) + except Exception as e: + errors.append(f"Reader thread error: {e}") + break + finally: + if cursor: + cursor.close() + finally: try: - cursor = db_connection.cursor() - cursor.execute("SELECT * FROM #pytest_thread_test") - - # The lock ensures the description is generated atomically. - # We just need to check if the result is one of the two valid states. - col_name = cursor.description[0][0] - - if col_name not in ("COLUMN_NAME", "column_name"): - errors.append(f"Invalid column name '{col_name}' found. Race condition likely.") - except Exception as e: - errors.append(f"Reader thread error: {e}") - break - finally: - if cursor: - cursor.close() + local_conn.close() + except Exception: + pass # Start threads writer_thread = threading.Thread(target=writer) @@ -227,7 +257,7 @@ def reader(): cursor = None try: cursor = db_connection.cursor() - cursor.execute("DROP TABLE #pytest_thread_test") + cursor.execute(f"DROP TABLE {table_name}") db_connection.commit() except Exception as e: # Log cleanup error but don't fail the test for it diff --git a/tests/test_022_concurrent_query_gil_release.py b/tests/test_022_concurrent_query_gil_release.py new file mode 100644 index 00000000..4bc09dc2 --- /dev/null +++ b/tests/test_022_concurrent_query_gil_release.py @@ -0,0 +1,210 @@ +""" +Copyright (c) Microsoft Corporation. +Licensed under the MIT license. + +Functional tests for the statement-, fetch- and transaction-level GIL +release added in PR #541 (covering ``SQLExecute`` / ``SQLExecDirect`` / +``SQLFetch`` / ``SQLEndTran`` in ``mssql_python/pybind/ddbc_bindings.cpp`` +and ``mssql_python/pybind/connection/connection.cpp``). + +These are **not** performance/stress tests — they assert a binary +correctness property (the GIL must be released around blocking ODBC calls) +using a conservative threshold that doesn't depend on hardware speed: + +* with the GIL released, a Python heartbeat thread keeps ticking while + another thread sits in ``cursor.execute("WAITFOR DELAY '00:00:02'")`` + — without release the heartbeat is fully starved. +* same property holds across an explicit ``commit()`` (covers the + ``SQLEndTran`` GIL-release path). + +A wall-clock "N threads finish in ~one WAITFOR worth of time" assertion +was deliberately *not* added here — it depends on the SQL Server +scheduler/container CPU allocation and is too flaky for the functional +suite. That style of test lives in ``test_021_concurrent_connection_perf.py`` +under ``@pytest.mark.stress``. + +A 2-second server-side WAITFOR is short enough to keep these in the +default functional suite (~5s total) while still producing an unambiguous +signal that survives normal CI jitter. +""" + +import os +import time +import threading + +import pytest +import mssql_python +from mssql_python import connect + +WAITFOR_SECONDS = 2 +WAITFOR_SQL = f"WAITFOR DELAY '00:00:0{WAITFOR_SECONDS}'" + + +@pytest.fixture(scope="module") +def conn_str(): + """Get connection string from environment.""" + conn_str = os.getenv("DB_CONNECTION_STRING") + if not conn_str: + pytest.skip("DB_CONNECTION_STRING environment variable not set") + return conn_str + + +def _run_waitfor(conn_str: str) -> float: + """Open a fresh connection, run WAITFOR, return elapsed seconds.""" + conn = connect(conn_str) + try: + cursor = conn.cursor() + try: + start = time.perf_counter() + cursor.execute(WAITFOR_SQL) + return time.perf_counter() - start + finally: + cursor.close() + finally: + conn.close() + + +# ============================================================================ +# Heartbeat: a Python thread must keep running while another thread blocks +# inside a server-side WAITFOR. This is the canonical repro from PR #541. +# ============================================================================ + + +def test_query_does_not_block_other_python_threads(conn_str): + """ + While one thread executes a 2-second ``WAITFOR DELAY``, a second pure-Python + thread must continue to run. If the GIL were held across SQLExecDirect, the + heartbeat would not advance until the WAITFOR returned. + """ + mssql_python.pooling(enabled=False) + + heartbeat_interval = 0.05 # 50ms ticks + expected_min_ticks = int(WAITFOR_SECONDS / heartbeat_interval * 0.5) # 50% of theoretical max + + stop_event = threading.Event() + tick_count = [0] + query_error = [] + + def heartbeat(): + while not stop_event.is_set(): + tick_count[0] += 1 + time.sleep(heartbeat_interval) + + def run_query(): + try: + _run_waitfor(conn_str) + except Exception as exc: + query_error.append(str(exc)) + + hb = threading.Thread(target=heartbeat, daemon=True) + qt = threading.Thread(target=run_query, daemon=True) + + # Snapshot ticks just before/after the query so we measure ticks that + # happened *during* the blocking ODBC call, not before/after. + hb.start() + time.sleep(0.1) # let heartbeat warm up + ticks_before = tick_count[0] + + qt.start() + qt.join(timeout=WAITFOR_SECONDS + 30) + ticks_after = tick_count[0] + stop_event.set() + hb.join(timeout=5) + + assert not qt.is_alive(), "Query thread did not finish in time" + assert not query_error, f"Query thread error: {query_error}" + + ticks_during = ticks_after - ticks_before + print( + f"\n[HEARTBEAT] ticks during {WAITFOR_SECONDS}s WAITFOR: {ticks_during} " + f"(expected >= {expected_min_ticks})" + ) + + assert ticks_during >= expected_min_ticks, ( + f"Heartbeat thread was starved during cursor.execute(WAITFOR). " + f"Got {ticks_during} ticks, expected >= {expected_min_ticks}. " + f"This indicates the GIL was not released around the blocking ODBC call." + ) + + +# ============================================================================ +# Transaction: SQLEndTran (commit/rollback) is also wrapped in PR #541. Make +# sure a heartbeat can run while a long server-side WAITFOR holds an open +# transaction that is then committed. +# ============================================================================ + + +def test_commit_does_not_block_other_python_threads(conn_str): + """ + Smoke test for the SQLEndTran GIL-release added to ``Connection::commit`` + and ``Connection::rollback``. A heartbeat must keep ticking across an + explicit commit on a connection that just executed a (short) WAITFOR. + + SQLEndTran on a localhost connection is typically sub-millisecond, so we + can't reliably measure starvation from it alone. Instead we just assert + that the commit completes and the heartbeat made meaningful progress + over the whole transaction, including the WAITFOR. + """ + mssql_python.pooling(enabled=False) + + heartbeat_interval = 0.05 + stop_event = threading.Event() + tick_count = [0] + txn_error = [] + + def heartbeat(): + while not stop_event.is_set(): + tick_count[0] += 1 + time.sleep(heartbeat_interval) + + # Open the connection on the main thread *before* starting the heartbeat + # measurement window. The Python-side wrapper around connect() (connstr + # parsing, handle alloc, attr setup, etc.) legitimately holds the GIL, + # and including it in the window would give false starvation signals — + # especially on macOS CI where scheduler jitter is larger. We want to + # measure ticks across cursor.execute(WAITFOR) + commit only. + conn = connect(conn_str) + + def run_txn(): + try: + try: + cursor = conn.cursor() + try: + cursor.execute(WAITFOR_SQL) + finally: + cursor.close() + conn.commit() + finally: + conn.close() + except Exception as exc: + txn_error.append(str(exc)) + + hb = threading.Thread(target=heartbeat, daemon=True) + tt = threading.Thread(target=run_txn, daemon=True) + + hb.start() + time.sleep(0.1) + ticks_before = tick_count[0] + + tt.start() + tt.join(timeout=WAITFOR_SECONDS + 30) + ticks_after = tick_count[0] + stop_event.set() + hb.join(timeout=5) + + assert not tt.is_alive(), "Transaction thread did not finish in time" + assert not txn_error, f"Transaction thread error: {txn_error}" + + ticks_during = ticks_after - ticks_before + # 40% of theoretical max gives margin against macOS CI scheduler noise + # (sleep(0.05) overshoot + GIL re-acquisition latency) while still + # catching real GIL starvation, which would yield <= ~2 ticks. + expected_min_ticks = int(WAITFOR_SECONDS / heartbeat_interval * 0.4) + print( + f"\n[HEARTBEAT] ticks during WAITFOR+commit: {ticks_during} " + f"(expected >= {expected_min_ticks})" + ) + assert ticks_during >= expected_min_ticks, ( + f"Heartbeat thread was starved across cursor.execute+commit. " + f"Got {ticks_during} ticks, expected >= {expected_min_ticks}." + )