diff --git a/.evergreen/remove-unimplemented-tests.sh b/.evergreen/remove-unimplemented-tests.sh index 794319679b..7982dd6b21 100755 --- a/.evergreen/remove-unimplemented-tests.sh +++ b/.evergreen/remove-unimplemented-tests.sh @@ -1,7 +1,6 @@ #!/bin/bash PYMONGO=$(dirname "$(cd "$(dirname "$0")" || exit; pwd)") -rm $PYMONGO/test/transactions/legacy/errors-client.json # PYTHON-1894 rm $PYMONGO/test/connection_monitoring/wait-queue-fairness.json # PYTHON-1873 rm $PYMONGO/test/discovery_and_monitoring/unified/pool-clear-application-error.json # PYTHON-4918 rm $PYMONGO/test/discovery_and_monitoring/unified/pool-clear-checkout-error.json # PYTHON-4918 diff --git a/pymongo/asynchronous/bulk.py b/pymongo/asynchronous/bulk.py index 4a54f9eb3f..8da5ffcb47 100644 --- a/pymongo/asynchronous/bulk.py +++ b/pymongo/asynchronous/bulk.py @@ -36,7 +36,10 @@ from bson.objectid import ObjectId from bson.raw_bson import RawBSONDocument from pymongo import _csot, common -from pymongo.asynchronous.client_session import AsyncClientSession, _validate_session_write_concern +from pymongo.asynchronous.client_session import ( + AsyncClientSession, + _validate_session_write_concern, +) from pymongo.asynchronous.helpers import _handle_reauth from pymongo.bulk_shared import ( _COMMANDS, @@ -271,6 +274,8 @@ async def write_command( if bwc.publish: bwc._start(cmd, request_id, docs) try: + if bwc.session is not None and bwc.session._starting_transaction: + bwc.session._transaction.set_in_progress() reply = await bwc.conn.write_command(request_id, msg, bwc.codec) # type: ignore[misc] duration = datetime.datetime.now() - bwc.start_time if _COMMAND_LOGGER.isEnabledFor(logging.DEBUG): diff --git a/pymongo/asynchronous/client_bulk.py b/pymongo/asynchronous/client_bulk.py index 015947d7ef..dcef4eea02 100644 --- a/pymongo/asynchronous/client_bulk.py +++ b/pymongo/asynchronous/client_bulk.py @@ -35,7 +35,10 @@ from bson.objectid import ObjectId from bson.raw_bson import RawBSONDocument from pymongo import _csot, common -from pymongo.asynchronous.client_session import AsyncClientSession, _validate_session_write_concern +from pymongo.asynchronous.client_session import ( + AsyncClientSession, + _validate_session_write_concern, +) from pymongo.asynchronous.collection import AsyncCollection from pymongo.asynchronous.command_cursor import AsyncCommandCursor from pymongo.asynchronous.database import AsyncDatabase @@ -258,6 +261,8 @@ async def write_command( if bwc.publish: bwc._start(cmd, request_id, op_docs, ns_docs) try: + if bwc.session is not None and bwc.session._starting_transaction: + bwc.session._transaction.set_in_progress() reply = await bwc.conn.write_command(request_id, msg, bwc.codec) # type: ignore[misc, arg-type] duration = datetime.datetime.now() - bwc.start_time if _COMMAND_LOGGER.isEnabledFor(logging.DEBUG): diff --git a/pymongo/asynchronous/client_session.py b/pymongo/asynchronous/client_session.py index f9d778f648..c5d5d7f298 100644 --- a/pymongo/asynchronous/client_session.py +++ b/pymongo/asynchronous/client_session.py @@ -433,6 +433,7 @@ def __init__(self, opts: Optional[TransactionOptions], client: AsyncMongoClient[ self.attempt = 0 self.client = client self.has_completed_command = False + self.has_sent_command = False def active(self) -> bool: return self.state in (_TxnState.STARTING, _TxnState.IN_PROGRESS) @@ -443,6 +444,11 @@ def starting(self) -> bool: def set_starting(self) -> None: self.state = _TxnState.STARTING + def set_in_progress(self) -> None: + if self.state == _TxnState.STARTING: + self.has_sent_command = True + self.state = _TxnState.IN_PROGRESS + @property def pinned_conn(self) -> Optional[AsyncConnection]: if self.active() and self.conn_mgr: @@ -469,6 +475,7 @@ async def reset(self) -> None: self.recovery_token = None self.attempt = 0 self.has_completed_command = False + self.has_sent_command = False def __del__(self) -> None: if self.conn_mgr: @@ -1135,7 +1142,6 @@ def _apply_to( if self._transaction.state == _TxnState.STARTING: # First command begins a new transaction. - self._transaction.state = _TxnState.IN_PROGRESS command["startTransaction"] = True assert self._transaction.opts diff --git a/pymongo/asynchronous/mongo_client.py b/pymongo/asynchronous/mongo_client.py index 412a13ec70..2e04fa4c24 100644 --- a/pymongo/asynchronous/mongo_client.py +++ b/pymongo/asynchronous/mongo_client.py @@ -2870,8 +2870,8 @@ async def run(self) -> T: self._last_error = exc self._attempt_number += 1 - # Revert back to starting state if we're in a transaction but haven't completed the first - # command. + # Revert back to starting state only if the first + # transactional command was never completed. if ( overloaded and self._session is not None @@ -2921,8 +2921,8 @@ async def run(self) -> T: self._last_error = exc if self._last_error is None: self._last_error = exc - # Revert back to starting state if we're in a transaction but haven't completed the first - # command. + # Revert back to starting state only if the first + # transactional command was never completed. if overloaded and self._session is not None and self._session.in_transaction: transaction = self._session._transaction if not transaction.has_completed_command: diff --git a/pymongo/asynchronous/pool.py b/pymongo/asynchronous/pool.py index 475f4bfa99..d973e97601 100644 --- a/pymongo/asynchronous/pool.py +++ b/pymongo/asynchronous/pool.py @@ -395,6 +395,8 @@ async def command( unacknowledged = bool(write_concern and not write_concern.acknowledged) self._raise_if_not_writable(unacknowledged) try: + if session is not None and session._starting_transaction: + session._transaction.set_in_progress() return await command( self, dbname, diff --git a/pymongo/asynchronous/server.py b/pymongo/asynchronous/server.py index c93eeb413f..39d422d038 100644 --- a/pymongo/asynchronous/server.py +++ b/pymongo/asynchronous/server.py @@ -204,6 +204,8 @@ async def run_operation( if more_to_come: reply = await conn.receive_message(None) else: + if operation.session is not None and operation.session._starting_transaction: + operation.session._transaction.set_in_progress() await conn.send_message(data, max_doc_size) reply = await conn.receive_message(request_id) diff --git a/pymongo/synchronous/bulk.py b/pymongo/synchronous/bulk.py index 22d6a7a76a..f6e1d1abe4 100644 --- a/pymongo/synchronous/bulk.py +++ b/pymongo/synchronous/bulk.py @@ -67,7 +67,10 @@ _randint, ) from pymongo.read_preferences import ReadPreference -from pymongo.synchronous.client_session import ClientSession, _validate_session_write_concern +from pymongo.synchronous.client_session import ( + ClientSession, + _validate_session_write_concern, +) from pymongo.synchronous.helpers import _handle_reauth from pymongo.write_concern import WriteConcern @@ -271,6 +274,8 @@ def write_command( if bwc.publish: bwc._start(cmd, request_id, docs) try: + if bwc.session is not None and bwc.session._starting_transaction: + bwc.session._transaction.set_in_progress() reply = bwc.conn.write_command(request_id, msg, bwc.codec) # type: ignore[misc] duration = datetime.datetime.now() - bwc.start_time if _COMMAND_LOGGER.isEnabledFor(logging.DEBUG): diff --git a/pymongo/synchronous/client_bulk.py b/pymongo/synchronous/client_bulk.py index 1134594ae9..400b1a2170 100644 --- a/pymongo/synchronous/client_bulk.py +++ b/pymongo/synchronous/client_bulk.py @@ -35,7 +35,10 @@ from bson.objectid import ObjectId from bson.raw_bson import RawBSONDocument from pymongo import _csot, common -from pymongo.synchronous.client_session import ClientSession, _validate_session_write_concern +from pymongo.synchronous.client_session import ( + ClientSession, + _validate_session_write_concern, +) from pymongo.synchronous.collection import Collection from pymongo.synchronous.command_cursor import CommandCursor from pymongo.synchronous.database import Database @@ -258,6 +261,8 @@ def write_command( if bwc.publish: bwc._start(cmd, request_id, op_docs, ns_docs) try: + if bwc.session is not None and bwc.session._starting_transaction: + bwc.session._transaction.set_in_progress() reply = bwc.conn.write_command(request_id, msg, bwc.codec) # type: ignore[misc, arg-type] duration = datetime.datetime.now() - bwc.start_time if _COMMAND_LOGGER.isEnabledFor(logging.DEBUG): diff --git a/pymongo/synchronous/client_session.py b/pymongo/synchronous/client_session.py index 240372f0cf..f4df500549 100644 --- a/pymongo/synchronous/client_session.py +++ b/pymongo/synchronous/client_session.py @@ -431,6 +431,7 @@ def __init__(self, opts: Optional[TransactionOptions], client: MongoClient[Any]) self.attempt = 0 self.client = client self.has_completed_command = False + self.has_sent_command = False def active(self) -> bool: return self.state in (_TxnState.STARTING, _TxnState.IN_PROGRESS) @@ -441,6 +442,11 @@ def starting(self) -> bool: def set_starting(self) -> None: self.state = _TxnState.STARTING + def set_in_progress(self) -> None: + if self.state == _TxnState.STARTING: + self.has_sent_command = True + self.state = _TxnState.IN_PROGRESS + @property def pinned_conn(self) -> Optional[Connection]: if self.active() and self.conn_mgr: @@ -467,6 +473,7 @@ def reset(self) -> None: self.recovery_token = None self.attempt = 0 self.has_completed_command = False + self.has_sent_command = False def __del__(self) -> None: if self.conn_mgr: @@ -1131,7 +1138,6 @@ def _apply_to( if self._transaction.state == _TxnState.STARTING: # First command begins a new transaction. - self._transaction.state = _TxnState.IN_PROGRESS command["startTransaction"] = True assert self._transaction.opts diff --git a/pymongo/synchronous/mongo_client.py b/pymongo/synchronous/mongo_client.py index 2bd6f31b72..dc3c1434e9 100644 --- a/pymongo/synchronous/mongo_client.py +++ b/pymongo/synchronous/mongo_client.py @@ -2860,8 +2860,8 @@ def run(self) -> T: self._last_error = exc self._attempt_number += 1 - # Revert back to starting state if we're in a transaction but haven't completed the first - # command. + # Revert back to starting state only if the first + # transactional command was never completed. if ( overloaded and self._session is not None @@ -2911,8 +2911,8 @@ def run(self) -> T: self._last_error = exc if self._last_error is None: self._last_error = exc - # Revert back to starting state if we're in a transaction but haven't completed the first - # command. + # Revert back to starting state only if the first + # transactional command was never completed. if overloaded and self._session is not None and self._session.in_transaction: transaction = self._session._transaction if not transaction.has_completed_command: diff --git a/pymongo/synchronous/pool.py b/pymongo/synchronous/pool.py index 938eca42bd..add7a1b1f5 100644 --- a/pymongo/synchronous/pool.py +++ b/pymongo/synchronous/pool.py @@ -395,6 +395,8 @@ def command( unacknowledged = bool(write_concern and not write_concern.acknowledged) self._raise_if_not_writable(unacknowledged) try: + if session is not None and session._starting_transaction: + session._transaction.set_in_progress() return command( self, dbname, diff --git a/pymongo/synchronous/server.py b/pymongo/synchronous/server.py index 73d782092e..5297a9e297 100644 --- a/pymongo/synchronous/server.py +++ b/pymongo/synchronous/server.py @@ -204,6 +204,8 @@ def run_operation( if more_to_come: reply = conn.receive_message(None) else: + if operation.session is not None and operation.session._starting_transaction: + operation.session._transaction.set_in_progress() conn.send_message(data, max_doc_size) reply = conn.receive_message(request_id) diff --git a/test/asynchronous/test_unified_format.py b/test/asynchronous/test_unified_format.py index 8136641236..a9b7981c16 100644 --- a/test/asynchronous/test_unified_format.py +++ b/test/asynchronous/test_unified_format.py @@ -39,9 +39,6 @@ os.path.join(TEST_PATH, "valid-pass"), module=__name__, class_name_prefix="UnifiedTestFormat", - expected_failures=[ - "Client side error in command starting transaction", # PYTHON-1894 - ], ) ) diff --git a/test/asynchronous/unified_format.py b/test/asynchronous/unified_format.py index 0491e0024b..dbbc8b7f90 100644 --- a/test/asynchronous/unified_format.py +++ b/test/asynchronous/unified_format.py @@ -570,8 +570,6 @@ def maybe_skip_test(self, spec): class_name = self.__class__.__name__.lower() description = spec["description"].lower() - if "client side error in command starting transaction" in description: - self.skipTest("Implement PYTHON-1894") if "type=symbol" in description: self.skipTest("PyMongo does not support the symbol type") if "timeoutms applied to entire download" in description: diff --git a/test/test_unified_format.py b/test/test_unified_format.py index a55f810473..4e12102604 100644 --- a/test/test_unified_format.py +++ b/test/test_unified_format.py @@ -39,9 +39,6 @@ os.path.join(TEST_PATH, "valid-pass"), module=__name__, class_name_prefix="UnifiedTestFormat", - expected_failures=[ - "Client side error in command starting transaction", # PYTHON-1894 - ], ) ) diff --git a/test/unified_format.py b/test/unified_format.py index 74df1cabb0..1c7904f42d 100644 --- a/test/unified_format.py +++ b/test/unified_format.py @@ -569,8 +569,6 @@ def maybe_skip_test(self, spec): class_name = self.__class__.__name__.lower() description = spec["description"].lower() - if "client side error in command starting transaction" in description: - self.skipTest("Implement PYTHON-1894") if "type=symbol" in description: self.skipTest("PyMongo does not support the symbol type") if "timeoutms applied to entire download" in description: