From d247c60498c220106e0fee43605b1b4dd499c552 Mon Sep 17 00:00:00 2001 From: Sophia Yang Date: Mon, 8 Jun 2026 11:23:27 -0400 Subject: [PATCH 01/10] initial idea --- .evergreen/remove-unimplemented-tests.sh | 1 - pymongo/asynchronous/client_session.py | 9 ++++++++- pymongo/asynchronous/pool.py | 14 ++++++++++++-- pymongo/synchronous/client_session.py | 9 ++++++++- pymongo/synchronous/pool.py | 14 ++++++++++++-- test/asynchronous/test_unified_format.py | 3 --- test/asynchronous/unified_format.py | 2 -- test/test_unified_format.py | 3 --- test/unified_format.py | 2 -- 9 files changed, 40 insertions(+), 17 deletions(-) diff --git a/.evergreen/remove-unimplemented-tests.sh b/.evergreen/remove-unimplemented-tests.sh index 794319679b..7982dd6b21 100755 --- a/.evergreen/remove-unimplemented-tests.sh +++ b/.evergreen/remove-unimplemented-tests.sh @@ -1,7 +1,6 @@ #!/bin/bash PYMONGO=$(dirname "$(cd "$(dirname "$0")" || exit; pwd)") -rm $PYMONGO/test/transactions/legacy/errors-client.json # PYTHON-1894 rm $PYMONGO/test/connection_monitoring/wait-queue-fairness.json # PYTHON-1873 rm $PYMONGO/test/discovery_and_monitoring/unified/pool-clear-application-error.json # PYTHON-4918 rm $PYMONGO/test/discovery_and_monitoring/unified/pool-clear-checkout-error.json # PYTHON-4918 diff --git a/pymongo/asynchronous/client_session.py b/pymongo/asynchronous/client_session.py index f9d778f648..8278c5da86 100644 --- a/pymongo/asynchronous/client_session.py +++ b/pymongo/asynchronous/client_session.py @@ -1059,6 +1059,14 @@ def _starting_transaction(self) -> bool: """True if this session is starting a multi-statement transaction.""" return self._transaction.starting() + @property + def _advance_transaction_state_on_response(self) -> None: + """Advance STARTING -> IN_PROGRESS after the first command has reached + the server response stage. Client-side errors must not advance transaction state. + """ + if self._transaction.state == _TxnState.STARTING: + self._transaction.state = _TxnState.IN_PROGRESS + @property def _pinned_address(self) -> Optional[_Address]: """The mongos address this transaction was created on.""" @@ -1135,7 +1143,6 @@ def _apply_to( if self._transaction.state == _TxnState.STARTING: # First command begins a new transaction. - self._transaction.state = _TxnState.IN_PROGRESS command["startTransaction"] = True assert self._transaction.opts diff --git a/pymongo/asynchronous/pool.py b/pymongo/asynchronous/pool.py index 7bc2a9f207..20b1fd2807 100644 --- a/pymongo/asynchronous/pool.py +++ b/pymongo/asynchronous/pool.py @@ -400,7 +400,7 @@ async def command( if self.op_msg_enabled: self._raise_if_not_writable(unacknowledged) try: - return await command( + result = await command( self, dbname, spec, @@ -424,7 +424,17 @@ async def command( exhaust_allowed=exhaust_allowed, write_concern=write_concern, ) - except (OperationFailure, NotPrimaryError): + if session and session.in_transaction: + session._advance_transaction_state_on_response() + return result + except (OperationFailure, NotPrimaryError) as exc: + if ( + session + and session.in_transaction + and session._starting_transaction + and not exc.has_error_label("RetryableWriteError") + ): + session._advance_transaction_state_on_response() raise # Catch socket.error, KeyboardInterrupt, CancelledError, etc. and close ourselves. except BaseException as error: diff --git a/pymongo/synchronous/client_session.py b/pymongo/synchronous/client_session.py index 240372f0cf..a6050f1b5f 100644 --- a/pymongo/synchronous/client_session.py +++ b/pymongo/synchronous/client_session.py @@ -1055,6 +1055,14 @@ def _starting_transaction(self) -> bool: """True if this session is starting a multi-statement transaction.""" return self._transaction.starting() + @property + def _advance_transaction_state_on_response(self) -> None: + """Advance STARTING -> IN_PROGRESS after the first command has reached + the server response stage. Client-side errors must not advance transaction state. + """ + if self._transaction.state == _TxnState.STARTING: + self._transaction.state = _TxnState.IN_PROGRESS + @property def _pinned_address(self) -> Optional[_Address]: """The mongos address this transaction was created on.""" @@ -1131,7 +1139,6 @@ def _apply_to( if self._transaction.state == _TxnState.STARTING: # First command begins a new transaction. - self._transaction.state = _TxnState.IN_PROGRESS command["startTransaction"] = True assert self._transaction.opts diff --git a/pymongo/synchronous/pool.py b/pymongo/synchronous/pool.py index 970989c594..be49d68504 100644 --- a/pymongo/synchronous/pool.py +++ b/pymongo/synchronous/pool.py @@ -400,7 +400,7 @@ def command( if self.op_msg_enabled: self._raise_if_not_writable(unacknowledged) try: - return command( + result = command( self, dbname, spec, @@ -424,7 +424,17 @@ def command( exhaust_allowed=exhaust_allowed, write_concern=write_concern, ) - except (OperationFailure, NotPrimaryError): + if session and session.in_transaction: + session._advance_transaction_state_on_response() + return result + except (OperationFailure, NotPrimaryError) as exc: + if ( + session + and session.in_transaction + and session._starting_transaction + and not exc.has_error_label("RetryableWriteError") + ): + session._advance_transaction_state_on_response() raise # Catch socket.error, KeyboardInterrupt, CancelledError, etc. and close ourselves. except BaseException as error: diff --git a/test/asynchronous/test_unified_format.py b/test/asynchronous/test_unified_format.py index 8136641236..a9b7981c16 100644 --- a/test/asynchronous/test_unified_format.py +++ b/test/asynchronous/test_unified_format.py @@ -39,9 +39,6 @@ os.path.join(TEST_PATH, "valid-pass"), module=__name__, class_name_prefix="UnifiedTestFormat", - expected_failures=[ - "Client side error in command starting transaction", # PYTHON-1894 - ], ) ) diff --git a/test/asynchronous/unified_format.py b/test/asynchronous/unified_format.py index f63f716726..0c606e9693 100644 --- a/test/asynchronous/unified_format.py +++ b/test/asynchronous/unified_format.py @@ -569,8 +569,6 @@ def maybe_skip_test(self, spec): class_name = self.__class__.__name__.lower() description = spec["description"].lower() - if "client side error in command starting transaction" in description: - self.skipTest("Implement PYTHON-1894") if "type=symbol" in description: self.skipTest("PyMongo does not support the symbol type") if "timeoutms applied to entire download" in description: diff --git a/test/test_unified_format.py b/test/test_unified_format.py index a55f810473..4e12102604 100644 --- a/test/test_unified_format.py +++ b/test/test_unified_format.py @@ -39,9 +39,6 @@ os.path.join(TEST_PATH, "valid-pass"), module=__name__, class_name_prefix="UnifiedTestFormat", - expected_failures=[ - "Client side error in command starting transaction", # PYTHON-1894 - ], ) ) diff --git a/test/unified_format.py b/test/unified_format.py index a7018c01d8..cca8993a0d 100644 --- a/test/unified_format.py +++ b/test/unified_format.py @@ -568,8 +568,6 @@ def maybe_skip_test(self, spec): class_name = self.__class__.__name__.lower() description = spec["description"].lower() - if "client side error in command starting transaction" in description: - self.skipTest("Implement PYTHON-1894") if "type=symbol" in description: self.skipTest("PyMongo does not support the symbol type") if "timeoutms applied to entire download" in description: From 69edc65320afb5f9ce40332eb610794e721bb640 Mon Sep 17 00:00:00 2001 From: Sophia Yang Date: Mon, 8 Jun 2026 11:47:05 -0400 Subject: [PATCH 02/10] fix types --- pymongo/asynchronous/client_session.py | 1 - pymongo/asynchronous/pool.py | 50 ++++++++++++++------------ pymongo/synchronous/client_session.py | 1 - pymongo/synchronous/pool.py | 50 ++++++++++++++------------ 4 files changed, 54 insertions(+), 48 deletions(-) diff --git a/pymongo/asynchronous/client_session.py b/pymongo/asynchronous/client_session.py index 8278c5da86..60b8a04aa4 100644 --- a/pymongo/asynchronous/client_session.py +++ b/pymongo/asynchronous/client_session.py @@ -1059,7 +1059,6 @@ def _starting_transaction(self) -> bool: """True if this session is starting a multi-statement transaction.""" return self._transaction.starting() - @property def _advance_transaction_state_on_response(self) -> None: """Advance STARTING -> IN_PROGRESS after the first command has reached the server response stage. Client-side errors must not advance transaction state. diff --git a/pymongo/asynchronous/pool.py b/pymongo/asynchronous/pool.py index 20b1fd2807..0c65db9f31 100644 --- a/pymongo/asynchronous/pool.py +++ b/pymongo/asynchronous/pool.py @@ -34,6 +34,7 @@ Optional, Sequence, Union, + cast, ) from bson import DEFAULT_CODEC_OPTIONS @@ -400,29 +401,32 @@ async def command( if self.op_msg_enabled: self._raise_if_not_writable(unacknowledged) try: - result = await command( - self, - dbname, - spec, - self.is_mongos, - read_preference, - codec_options, # type: ignore[arg-type] - session, - client, - check, - allowable_errors, - self.address, - listeners, - self.max_bson_size, - read_concern, - parse_write_concern_error=parse_write_concern_error, - collation=collation, - compression_ctx=self.compression_context, - use_op_msg=self.op_msg_enabled, - unacknowledged=unacknowledged, - user_fields=user_fields, - exhaust_allowed=exhaust_allowed, - write_concern=write_concern, + result = await cast( + dict[str, Any], + command( + self, + dbname, + spec, + self.is_mongos, + read_preference, + codec_options, # type: ignore[arg-type] + session, + client, + check, + allowable_errors, + self.address, + listeners, + self.max_bson_size, + read_concern, + parse_write_concern_error=parse_write_concern_error, + collation=collation, + compression_ctx=self.compression_context, + use_op_msg=self.op_msg_enabled, + unacknowledged=unacknowledged, + user_fields=user_fields, + exhaust_allowed=exhaust_allowed, + write_concern=write_concern, + ), ) if session and session.in_transaction: session._advance_transaction_state_on_response() diff --git a/pymongo/synchronous/client_session.py b/pymongo/synchronous/client_session.py index a6050f1b5f..d3b776b78e 100644 --- a/pymongo/synchronous/client_session.py +++ b/pymongo/synchronous/client_session.py @@ -1055,7 +1055,6 @@ def _starting_transaction(self) -> bool: """True if this session is starting a multi-statement transaction.""" return self._transaction.starting() - @property def _advance_transaction_state_on_response(self) -> None: """Advance STARTING -> IN_PROGRESS after the first command has reached the server response stage. Client-side errors must not advance transaction state. diff --git a/pymongo/synchronous/pool.py b/pymongo/synchronous/pool.py index be49d68504..0a198a0857 100644 --- a/pymongo/synchronous/pool.py +++ b/pymongo/synchronous/pool.py @@ -34,6 +34,7 @@ Optional, Sequence, Union, + cast, ) from bson import DEFAULT_CODEC_OPTIONS @@ -400,29 +401,32 @@ def command( if self.op_msg_enabled: self._raise_if_not_writable(unacknowledged) try: - result = command( - self, - dbname, - spec, - self.is_mongos, - read_preference, - codec_options, # type: ignore[arg-type] - session, - client, - check, - allowable_errors, - self.address, - listeners, - self.max_bson_size, - read_concern, - parse_write_concern_error=parse_write_concern_error, - collation=collation, - compression_ctx=self.compression_context, - use_op_msg=self.op_msg_enabled, - unacknowledged=unacknowledged, - user_fields=user_fields, - exhaust_allowed=exhaust_allowed, - write_concern=write_concern, + result = cast( + dict[str, Any], + command( + self, + dbname, + spec, + self.is_mongos, + read_preference, + codec_options, # type: ignore[arg-type] + session, + client, + check, + allowable_errors, + self.address, + listeners, + self.max_bson_size, + read_concern, + parse_write_concern_error=parse_write_concern_error, + collation=collation, + compression_ctx=self.compression_context, + use_op_msg=self.op_msg_enabled, + unacknowledged=unacknowledged, + user_fields=user_fields, + exhaust_allowed=exhaust_allowed, + write_concern=write_concern, + ), ) if session and session.in_transaction: session._advance_transaction_state_on_response() From b6173dfc8c4750ba6e270271e6dae1ca7938d43e Mon Sep 17 00:00:00 2001 From: Sophia Yang Date: Mon, 8 Jun 2026 11:50:58 -0400 Subject: [PATCH 03/10] typo --- pymongo/asynchronous/pool.py | 4 ++-- pymongo/synchronous/pool.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/pymongo/asynchronous/pool.py b/pymongo/asynchronous/pool.py index 0c65db9f31..ed645b519c 100644 --- a/pymongo/asynchronous/pool.py +++ b/pymongo/asynchronous/pool.py @@ -401,9 +401,9 @@ async def command( if self.op_msg_enabled: self._raise_if_not_writable(unacknowledged) try: - result = await cast( + result = result = cast( dict[str, Any], - command( + await command( self, dbname, spec, diff --git a/pymongo/synchronous/pool.py b/pymongo/synchronous/pool.py index 0a198a0857..0182237351 100644 --- a/pymongo/synchronous/pool.py +++ b/pymongo/synchronous/pool.py @@ -401,7 +401,7 @@ def command( if self.op_msg_enabled: self._raise_if_not_writable(unacknowledged) try: - result = cast( + result = result = cast( dict[str, Any], command( self, From 7d4ef9ca69f8fdeb5c417fa0ac29193680b903a0 Mon Sep 17 00:00:00 2001 From: Sophia Yang Date: Mon, 8 Jun 2026 11:52:26 -0400 Subject: [PATCH 04/10] typo2 --- pymongo/asynchronous/pool.py | 2 +- pymongo/synchronous/pool.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/pymongo/asynchronous/pool.py b/pymongo/asynchronous/pool.py index ed645b519c..0573a8b286 100644 --- a/pymongo/asynchronous/pool.py +++ b/pymongo/asynchronous/pool.py @@ -401,7 +401,7 @@ async def command( if self.op_msg_enabled: self._raise_if_not_writable(unacknowledged) try: - result = result = cast( + result = cast( dict[str, Any], await command( self, diff --git a/pymongo/synchronous/pool.py b/pymongo/synchronous/pool.py index 0182237351..0a198a0857 100644 --- a/pymongo/synchronous/pool.py +++ b/pymongo/synchronous/pool.py @@ -401,7 +401,7 @@ def command( if self.op_msg_enabled: self._raise_if_not_writable(unacknowledged) try: - result = result = cast( + result = cast( dict[str, Any], command( self, From 72d7e61ea7804a95a65e5ac25de1a9bf236bb93d Mon Sep 17 00:00:00 2001 From: Sophia Yang Date: Mon, 8 Jun 2026 12:42:43 -0400 Subject: [PATCH 05/10] call helper in server.run_operation() --- pymongo/asynchronous/pool.py | 53 +++++++++++++++------------------- pymongo/asynchronous/server.py | 2 ++ pymongo/synchronous/pool.py | 53 +++++++++++++++------------------- pymongo/synchronous/server.py | 2 ++ 4 files changed, 52 insertions(+), 58 deletions(-) diff --git a/pymongo/asynchronous/pool.py b/pymongo/asynchronous/pool.py index 0573a8b286..d9cbd94769 100644 --- a/pymongo/asynchronous/pool.py +++ b/pymongo/asynchronous/pool.py @@ -34,7 +34,6 @@ Optional, Sequence, Union, - cast, ) from bson import DEFAULT_CODEC_OPTIONS @@ -401,36 +400,32 @@ async def command( if self.op_msg_enabled: self._raise_if_not_writable(unacknowledged) try: - result = cast( - dict[str, Any], - await command( - self, - dbname, - spec, - self.is_mongos, - read_preference, - codec_options, # type: ignore[arg-type] - session, - client, - check, - allowable_errors, - self.address, - listeners, - self.max_bson_size, - read_concern, - parse_write_concern_error=parse_write_concern_error, - collation=collation, - compression_ctx=self.compression_context, - use_op_msg=self.op_msg_enabled, - unacknowledged=unacknowledged, - user_fields=user_fields, - exhaust_allowed=exhaust_allowed, - write_concern=write_concern, - ), - ) if session and session.in_transaction: session._advance_transaction_state_on_response() - return result + return await command( + self, + dbname, + spec, + self.is_mongos, + read_preference, + codec_options, # type: ignore[arg-type] + session, + client, + check, + allowable_errors, + self.address, + listeners, + self.max_bson_size, + read_concern, + parse_write_concern_error=parse_write_concern_error, + collation=collation, + compression_ctx=self.compression_context, + use_op_msg=self.op_msg_enabled, + unacknowledged=unacknowledged, + user_fields=user_fields, + exhaust_allowed=exhaust_allowed, + write_concern=write_concern, + ) except (OperationFailure, NotPrimaryError) as exc: if ( session diff --git a/pymongo/asynchronous/server.py b/pymongo/asynchronous/server.py index f212306174..a3b02a1783 100644 --- a/pymongo/asynchronous/server.py +++ b/pymongo/asynchronous/server.py @@ -204,6 +204,8 @@ async def run_operation( if more_to_come: reply = await conn.receive_message(None) else: + if operation.session is not None and operation.session._starting_transaction: + operation.session._advance_transaction_state_on_send() await conn.send_message(data, max_doc_size) reply = await conn.receive_message(request_id) diff --git a/pymongo/synchronous/pool.py b/pymongo/synchronous/pool.py index 0a198a0857..9cc6b0972c 100644 --- a/pymongo/synchronous/pool.py +++ b/pymongo/synchronous/pool.py @@ -34,7 +34,6 @@ Optional, Sequence, Union, - cast, ) from bson import DEFAULT_CODEC_OPTIONS @@ -401,36 +400,32 @@ def command( if self.op_msg_enabled: self._raise_if_not_writable(unacknowledged) try: - result = cast( - dict[str, Any], - command( - self, - dbname, - spec, - self.is_mongos, - read_preference, - codec_options, # type: ignore[arg-type] - session, - client, - check, - allowable_errors, - self.address, - listeners, - self.max_bson_size, - read_concern, - parse_write_concern_error=parse_write_concern_error, - collation=collation, - compression_ctx=self.compression_context, - use_op_msg=self.op_msg_enabled, - unacknowledged=unacknowledged, - user_fields=user_fields, - exhaust_allowed=exhaust_allowed, - write_concern=write_concern, - ), - ) if session and session.in_transaction: session._advance_transaction_state_on_response() - return result + return command( + self, + dbname, + spec, + self.is_mongos, + read_preference, + codec_options, # type: ignore[arg-type] + session, + client, + check, + allowable_errors, + self.address, + listeners, + self.max_bson_size, + read_concern, + parse_write_concern_error=parse_write_concern_error, + collation=collation, + compression_ctx=self.compression_context, + use_op_msg=self.op_msg_enabled, + unacknowledged=unacknowledged, + user_fields=user_fields, + exhaust_allowed=exhaust_allowed, + write_concern=write_concern, + ) except (OperationFailure, NotPrimaryError) as exc: if ( session diff --git a/pymongo/synchronous/server.py b/pymongo/synchronous/server.py index f57420918b..51d99fcc50 100644 --- a/pymongo/synchronous/server.py +++ b/pymongo/synchronous/server.py @@ -204,6 +204,8 @@ def run_operation( if more_to_come: reply = conn.receive_message(None) else: + if operation.session is not None and operation.session._starting_transaction: + operation.session._advance_transaction_state_on_send() conn.send_message(data, max_doc_size) reply = conn.receive_message(request_id) From 13dcdca61ea0c5c6848a8d719c050b250932e93c Mon Sep 17 00:00:00 2001 From: Sophia Yang Date: Mon, 8 Jun 2026 13:24:25 -0400 Subject: [PATCH 06/10] inline --- pymongo/asynchronous/server.py | 5 ++++- pymongo/synchronous/server.py | 5 ++++- 2 files changed, 8 insertions(+), 2 deletions(-) diff --git a/pymongo/asynchronous/server.py b/pymongo/asynchronous/server.py index a3b02a1783..90390f7200 100644 --- a/pymongo/asynchronous/server.py +++ b/pymongo/asynchronous/server.py @@ -27,6 +27,7 @@ ) from bson import _decode_all_selective +from pymongo.asynchronous.client_session import _TxnState from pymongo.asynchronous.helpers import _handle_reauth from pymongo.errors import NotPrimaryError, OperationFailure from pymongo.helpers_shared import _check_command_response @@ -204,8 +205,10 @@ async def run_operation( if more_to_come: reply = await conn.receive_message(None) else: + # Mark the transaction as in progress once the first transactional message is about to be sent, + # so local validation errors keep the session in STARTING, but post-send failures do not. if operation.session is not None and operation.session._starting_transaction: - operation.session._advance_transaction_state_on_send() + operation.session._transaction.state = _TxnState.IN_PROGRESS await conn.send_message(data, max_doc_size) reply = await conn.receive_message(request_id) diff --git a/pymongo/synchronous/server.py b/pymongo/synchronous/server.py index 51d99fcc50..610d7d60ca 100644 --- a/pymongo/synchronous/server.py +++ b/pymongo/synchronous/server.py @@ -38,6 +38,7 @@ ) from pymongo.message import _convert_exception, _GetMore, _OpMsg, _Query from pymongo.response import PinnedResponse, Response +from pymongo.synchronous.client_session import _TxnState from pymongo.synchronous.helpers import _handle_reauth if TYPE_CHECKING: @@ -204,8 +205,10 @@ def run_operation( if more_to_come: reply = conn.receive_message(None) else: + # Mark the transaction as in progress once the first transactional message is about to be sent, + # so local validation errors keep the session in STARTING, but post-send failures do not. if operation.session is not None and operation.session._starting_transaction: - operation.session._advance_transaction_state_on_send() + operation.session._transaction.state = _TxnState.IN_PROGRESS conn.send_message(data, max_doc_size) reply = conn.receive_message(request_id) From 3b56b6b64d2e972b8167bad276ebd7d928945443 Mon Sep 17 00:00:00 2001 From: Sophia Yang Date: Mon, 8 Jun 2026 14:57:21 -0400 Subject: [PATCH 07/10] Track first transactional send --- pymongo/asynchronous/client_session.py | 9 ++------- pymongo/asynchronous/mongo_client.py | 12 ++++++------ pymongo/asynchronous/pool.py | 16 +++++----------- pymongo/asynchronous/server.py | 1 + pymongo/synchronous/client_session.py | 9 ++------- pymongo/synchronous/mongo_client.py | 12 ++++++------ pymongo/synchronous/pool.py | 16 +++++----------- pymongo/synchronous/server.py | 1 + 8 files changed, 28 insertions(+), 48 deletions(-) diff --git a/pymongo/asynchronous/client_session.py b/pymongo/asynchronous/client_session.py index 60b8a04aa4..0fb9336354 100644 --- a/pymongo/asynchronous/client_session.py +++ b/pymongo/asynchronous/client_session.py @@ -433,6 +433,7 @@ def __init__(self, opts: Optional[TransactionOptions], client: AsyncMongoClient[ self.attempt = 0 self.client = client self.has_completed_command = False + self.has_sent_command = False def active(self) -> bool: return self.state in (_TxnState.STARTING, _TxnState.IN_PROGRESS) @@ -469,6 +470,7 @@ async def reset(self) -> None: self.recovery_token = None self.attempt = 0 self.has_completed_command = False + self.has_sent_command = False def __del__(self) -> None: if self.conn_mgr: @@ -1059,13 +1061,6 @@ def _starting_transaction(self) -> bool: """True if this session is starting a multi-statement transaction.""" return self._transaction.starting() - def _advance_transaction_state_on_response(self) -> None: - """Advance STARTING -> IN_PROGRESS after the first command has reached - the server response stage. Client-side errors must not advance transaction state. - """ - if self._transaction.state == _TxnState.STARTING: - self._transaction.state = _TxnState.IN_PROGRESS - @property def _pinned_address(self) -> Optional[_Address]: """The mongos address this transaction was created on.""" diff --git a/pymongo/asynchronous/mongo_client.py b/pymongo/asynchronous/mongo_client.py index 412a13ec70..b250142db9 100644 --- a/pymongo/asynchronous/mongo_client.py +++ b/pymongo/asynchronous/mongo_client.py @@ -2870,15 +2870,15 @@ async def run(self) -> T: self._last_error = exc self._attempt_number += 1 - # Revert back to starting state if we're in a transaction but haven't completed the first - # command. + # Revert back to starting state only if the first + # transactional command was never sent. if ( overloaded and self._session is not None and self._session.in_transaction ): transaction = self._session._transaction - if not transaction.has_completed_command: + if not transaction.has_sent_command: transaction.set_starting() transaction.attempt = 0 else: @@ -2921,11 +2921,11 @@ async def run(self) -> T: self._last_error = exc if self._last_error is None: self._last_error = exc - # Revert back to starting state if we're in a transaction but haven't completed the first - # command. + # Revert back to starting state only if the first + # transactional command was never sent. if overloaded and self._session is not None and self._session.in_transaction: transaction = self._session._transaction - if not transaction.has_completed_command: + if not transaction.has_sent_command: transaction.set_starting() transaction.attempt = 0 diff --git a/pymongo/asynchronous/pool.py b/pymongo/asynchronous/pool.py index d9cbd94769..083a530dbf 100644 --- a/pymongo/asynchronous/pool.py +++ b/pymongo/asynchronous/pool.py @@ -38,7 +38,7 @@ from bson import DEFAULT_CODEC_OPTIONS from pymongo import _csot, helpers_shared -from pymongo.asynchronous.client_session import _validate_session_write_concern +from pymongo.asynchronous.client_session import _TxnState, _validate_session_write_concern from pymongo.asynchronous.helpers import _handle_reauth from pymongo.asynchronous.network import command from pymongo.common import ( @@ -400,8 +400,9 @@ async def command( if self.op_msg_enabled: self._raise_if_not_writable(unacknowledged) try: - if session and session.in_transaction: - session._advance_transaction_state_on_response() + if session is not None and session._starting_transaction: + session._transaction.has_sent_command = True + session._transaction.state = _TxnState.IN_PROGRESS return await command( self, dbname, @@ -426,14 +427,7 @@ async def command( exhaust_allowed=exhaust_allowed, write_concern=write_concern, ) - except (OperationFailure, NotPrimaryError) as exc: - if ( - session - and session.in_transaction - and session._starting_transaction - and not exc.has_error_label("RetryableWriteError") - ): - session._advance_transaction_state_on_response() + except (OperationFailure, NotPrimaryError): raise # Catch socket.error, KeyboardInterrupt, CancelledError, etc. and close ourselves. except BaseException as error: diff --git a/pymongo/asynchronous/server.py b/pymongo/asynchronous/server.py index 90390f7200..c271a5d959 100644 --- a/pymongo/asynchronous/server.py +++ b/pymongo/asynchronous/server.py @@ -208,6 +208,7 @@ async def run_operation( # Mark the transaction as in progress once the first transactional message is about to be sent, # so local validation errors keep the session in STARTING, but post-send failures do not. if operation.session is not None and operation.session._starting_transaction: + operation.session._transaction.has_sent_command = True operation.session._transaction.state = _TxnState.IN_PROGRESS await conn.send_message(data, max_doc_size) reply = await conn.receive_message(request_id) diff --git a/pymongo/synchronous/client_session.py b/pymongo/synchronous/client_session.py index d3b776b78e..9a3a7e7f27 100644 --- a/pymongo/synchronous/client_session.py +++ b/pymongo/synchronous/client_session.py @@ -431,6 +431,7 @@ def __init__(self, opts: Optional[TransactionOptions], client: MongoClient[Any]) self.attempt = 0 self.client = client self.has_completed_command = False + self.has_sent_command = False def active(self) -> bool: return self.state in (_TxnState.STARTING, _TxnState.IN_PROGRESS) @@ -467,6 +468,7 @@ def reset(self) -> None: self.recovery_token = None self.attempt = 0 self.has_completed_command = False + self.has_sent_command = False def __del__(self) -> None: if self.conn_mgr: @@ -1055,13 +1057,6 @@ def _starting_transaction(self) -> bool: """True if this session is starting a multi-statement transaction.""" return self._transaction.starting() - def _advance_transaction_state_on_response(self) -> None: - """Advance STARTING -> IN_PROGRESS after the first command has reached - the server response stage. Client-side errors must not advance transaction state. - """ - if self._transaction.state == _TxnState.STARTING: - self._transaction.state = _TxnState.IN_PROGRESS - @property def _pinned_address(self) -> Optional[_Address]: """The mongos address this transaction was created on.""" diff --git a/pymongo/synchronous/mongo_client.py b/pymongo/synchronous/mongo_client.py index 2bd6f31b72..e400844dc8 100644 --- a/pymongo/synchronous/mongo_client.py +++ b/pymongo/synchronous/mongo_client.py @@ -2860,15 +2860,15 @@ def run(self) -> T: self._last_error = exc self._attempt_number += 1 - # Revert back to starting state if we're in a transaction but haven't completed the first - # command. + # Revert back to starting state only if the first + # transactional command was never sent. if ( overloaded and self._session is not None and self._session.in_transaction ): transaction = self._session._transaction - if not transaction.has_completed_command: + if not transaction.has_sent_command: transaction.set_starting() transaction.attempt = 0 else: @@ -2911,11 +2911,11 @@ def run(self) -> T: self._last_error = exc if self._last_error is None: self._last_error = exc - # Revert back to starting state if we're in a transaction but haven't completed the first - # command. + # Revert back to starting state only if the first + # transactional command was never sent. if overloaded and self._session is not None and self._session.in_transaction: transaction = self._session._transaction - if not transaction.has_completed_command: + if not transaction.has_sent_command: transaction.set_starting() transaction.attempt = 0 diff --git a/pymongo/synchronous/pool.py b/pymongo/synchronous/pool.py index 9cc6b0972c..1f129253dd 100644 --- a/pymongo/synchronous/pool.py +++ b/pymongo/synchronous/pool.py @@ -87,7 +87,7 @@ from pymongo.server_api import _add_to_command from pymongo.server_type import SERVER_TYPE from pymongo.socket_checker import SocketChecker -from pymongo.synchronous.client_session import _validate_session_write_concern +from pymongo.synchronous.client_session import _TxnState, _validate_session_write_concern from pymongo.synchronous.helpers import _handle_reauth from pymongo.synchronous.network import command @@ -400,8 +400,9 @@ def command( if self.op_msg_enabled: self._raise_if_not_writable(unacknowledged) try: - if session and session.in_transaction: - session._advance_transaction_state_on_response() + if session is not None and session._starting_transaction: + session._transaction.has_sent_command = True + session._transaction.state = _TxnState.IN_PROGRESS return command( self, dbname, @@ -426,14 +427,7 @@ def command( exhaust_allowed=exhaust_allowed, write_concern=write_concern, ) - except (OperationFailure, NotPrimaryError) as exc: - if ( - session - and session.in_transaction - and session._starting_transaction - and not exc.has_error_label("RetryableWriteError") - ): - session._advance_transaction_state_on_response() + except (OperationFailure, NotPrimaryError): raise # Catch socket.error, KeyboardInterrupt, CancelledError, etc. and close ourselves. except BaseException as error: diff --git a/pymongo/synchronous/server.py b/pymongo/synchronous/server.py index 610d7d60ca..1ad60eba4a 100644 --- a/pymongo/synchronous/server.py +++ b/pymongo/synchronous/server.py @@ -208,6 +208,7 @@ def run_operation( # Mark the transaction as in progress once the first transactional message is about to be sent, # so local validation errors keep the session in STARTING, but post-send failures do not. if operation.session is not None and operation.session._starting_transaction: + operation.session._transaction.has_sent_command = True operation.session._transaction.state = _TxnState.IN_PROGRESS conn.send_message(data, max_doc_size) reply = conn.receive_message(request_id) From c60b9488d3677b04f88f5832840cd7390b4d1777 Mon Sep 17 00:00:00 2001 From: Sophia Yang Date: Mon, 8 Jun 2026 15:25:27 -0400 Subject: [PATCH 08/10] fix bulk tests --- pymongo/asynchronous/bulk.py | 11 ++++++++++- pymongo/synchronous/bulk.py | 11 ++++++++++- 2 files changed, 20 insertions(+), 2 deletions(-) diff --git a/pymongo/asynchronous/bulk.py b/pymongo/asynchronous/bulk.py index 4a54f9eb3f..ad13b5210f 100644 --- a/pymongo/asynchronous/bulk.py +++ b/pymongo/asynchronous/bulk.py @@ -36,7 +36,11 @@ from bson.objectid import ObjectId from bson.raw_bson import RawBSONDocument from pymongo import _csot, common -from pymongo.asynchronous.client_session import AsyncClientSession, _validate_session_write_concern +from pymongo.asynchronous.client_session import ( + AsyncClientSession, + _TxnState, + _validate_session_write_concern, +) from pymongo.asynchronous.helpers import _handle_reauth from pymongo.bulk_shared import ( _COMMANDS, @@ -271,6 +275,11 @@ async def write_command( if bwc.publish: bwc._start(cmd, request_id, docs) try: + if bwc.session is not None and bwc.session._starting_transaction: + # Mark the transaction as in progress once the first + # transactional bulk message is about to go on the wire. + bwc.session._transaction.has_sent_command = True + bwc.session._transaction.state = _TxnState.IN_PROGRESS reply = await bwc.conn.write_command(request_id, msg, bwc.codec) # type: ignore[misc] duration = datetime.datetime.now() - bwc.start_time if _COMMAND_LOGGER.isEnabledFor(logging.DEBUG): diff --git a/pymongo/synchronous/bulk.py b/pymongo/synchronous/bulk.py index 22d6a7a76a..60ba44fd95 100644 --- a/pymongo/synchronous/bulk.py +++ b/pymongo/synchronous/bulk.py @@ -67,7 +67,11 @@ _randint, ) from pymongo.read_preferences import ReadPreference -from pymongo.synchronous.client_session import ClientSession, _validate_session_write_concern +from pymongo.synchronous.client_session import ( + ClientSession, + _TxnState, + _validate_session_write_concern, +) from pymongo.synchronous.helpers import _handle_reauth from pymongo.write_concern import WriteConcern @@ -271,6 +275,11 @@ def write_command( if bwc.publish: bwc._start(cmd, request_id, docs) try: + if bwc.session is not None and bwc.session._starting_transaction: + # Mark the transaction as in progress once the first + # transactional bulk message is about to go on the wire. + bwc.session._transaction.has_sent_command = True + bwc.session._transaction.state = _TxnState.IN_PROGRESS reply = bwc.conn.write_command(request_id, msg, bwc.codec) # type: ignore[misc] duration = datetime.datetime.now() - bwc.start_time if _COMMAND_LOGGER.isEnabledFor(logging.DEBUG): From 069e875c88fff0b75ac70a84e79bd52e4f5d81bc Mon Sep 17 00:00:00 2001 From: Sophia Yang Date: Mon, 8 Jun 2026 15:56:45 -0400 Subject: [PATCH 09/10] fix failure --- pymongo/asynchronous/client_bulk.py | 9 ++++++++- pymongo/asynchronous/mongo_client.py | 8 ++++---- pymongo/synchronous/client_bulk.py | 9 ++++++++- pymongo/synchronous/mongo_client.py | 8 ++++---- 4 files changed, 24 insertions(+), 10 deletions(-) diff --git a/pymongo/asynchronous/client_bulk.py b/pymongo/asynchronous/client_bulk.py index 015947d7ef..2188964200 100644 --- a/pymongo/asynchronous/client_bulk.py +++ b/pymongo/asynchronous/client_bulk.py @@ -35,7 +35,11 @@ from bson.objectid import ObjectId from bson.raw_bson import RawBSONDocument from pymongo import _csot, common -from pymongo.asynchronous.client_session import AsyncClientSession, _validate_session_write_concern +from pymongo.asynchronous.client_session import ( + AsyncClientSession, + _TxnState, + _validate_session_write_concern, +) from pymongo.asynchronous.collection import AsyncCollection from pymongo.asynchronous.command_cursor import AsyncCommandCursor from pymongo.asynchronous.database import AsyncDatabase @@ -258,6 +262,9 @@ async def write_command( if bwc.publish: bwc._start(cmd, request_id, op_docs, ns_docs) try: + if bwc.session is not None and bwc.session._starting_transaction: + bwc.session._transaction.has_sent_command = True + bwc.session._transaction.state = _TxnState.IN_PROGRESS reply = await bwc.conn.write_command(request_id, msg, bwc.codec) # type: ignore[misc, arg-type] duration = datetime.datetime.now() - bwc.start_time if _COMMAND_LOGGER.isEnabledFor(logging.DEBUG): diff --git a/pymongo/asynchronous/mongo_client.py b/pymongo/asynchronous/mongo_client.py index b250142db9..2e04fa4c24 100644 --- a/pymongo/asynchronous/mongo_client.py +++ b/pymongo/asynchronous/mongo_client.py @@ -2871,14 +2871,14 @@ async def run(self) -> T: self._attempt_number += 1 # Revert back to starting state only if the first - # transactional command was never sent. + # transactional command was never completed. if ( overloaded and self._session is not None and self._session.in_transaction ): transaction = self._session._transaction - if not transaction.has_sent_command: + if not transaction.has_completed_command: transaction.set_starting() transaction.attempt = 0 else: @@ -2922,10 +2922,10 @@ async def run(self) -> T: if self._last_error is None: self._last_error = exc # Revert back to starting state only if the first - # transactional command was never sent. + # transactional command was never completed. if overloaded and self._session is not None and self._session.in_transaction: transaction = self._session._transaction - if not transaction.has_sent_command: + if not transaction.has_completed_command: transaction.set_starting() transaction.attempt = 0 diff --git a/pymongo/synchronous/client_bulk.py b/pymongo/synchronous/client_bulk.py index 1134594ae9..bd4a193669 100644 --- a/pymongo/synchronous/client_bulk.py +++ b/pymongo/synchronous/client_bulk.py @@ -35,7 +35,11 @@ from bson.objectid import ObjectId from bson.raw_bson import RawBSONDocument from pymongo import _csot, common -from pymongo.synchronous.client_session import ClientSession, _validate_session_write_concern +from pymongo.synchronous.client_session import ( + ClientSession, + _TxnState, + _validate_session_write_concern, +) from pymongo.synchronous.collection import Collection from pymongo.synchronous.command_cursor import CommandCursor from pymongo.synchronous.database import Database @@ -258,6 +262,9 @@ def write_command( if bwc.publish: bwc._start(cmd, request_id, op_docs, ns_docs) try: + if bwc.session is not None and bwc.session._starting_transaction: + bwc.session._transaction.has_sent_command = True + bwc.session._transaction.state = _TxnState.IN_PROGRESS reply = bwc.conn.write_command(request_id, msg, bwc.codec) # type: ignore[misc, arg-type] duration = datetime.datetime.now() - bwc.start_time if _COMMAND_LOGGER.isEnabledFor(logging.DEBUG): diff --git a/pymongo/synchronous/mongo_client.py b/pymongo/synchronous/mongo_client.py index e400844dc8..dc3c1434e9 100644 --- a/pymongo/synchronous/mongo_client.py +++ b/pymongo/synchronous/mongo_client.py @@ -2861,14 +2861,14 @@ def run(self) -> T: self._attempt_number += 1 # Revert back to starting state only if the first - # transactional command was never sent. + # transactional command was never completed. if ( overloaded and self._session is not None and self._session.in_transaction ): transaction = self._session._transaction - if not transaction.has_sent_command: + if not transaction.has_completed_command: transaction.set_starting() transaction.attempt = 0 else: @@ -2912,10 +2912,10 @@ def run(self) -> T: if self._last_error is None: self._last_error = exc # Revert back to starting state only if the first - # transactional command was never sent. + # transactional command was never completed. if overloaded and self._session is not None and self._session.in_transaction: transaction = self._session._transaction - if not transaction.has_sent_command: + if not transaction.has_completed_command: transaction.set_starting() transaction.attempt = 0 From 78628a634e551ceed206377ea45171df44ceba8a Mon Sep 17 00:00:00 2001 From: Sophia Yang Date: Tue, 9 Jun 2026 09:30:52 -0400 Subject: [PATCH 10/10] refactor --- pymongo/asynchronous/bulk.py | 6 +----- pymongo/asynchronous/client_bulk.py | 4 +--- pymongo/asynchronous/client_session.py | 5 +++++ pymongo/asynchronous/pool.py | 5 ++--- pymongo/asynchronous/server.py | 6 +----- pymongo/synchronous/bulk.py | 6 +----- pymongo/synchronous/client_bulk.py | 4 +--- pymongo/synchronous/client_session.py | 5 +++++ pymongo/synchronous/pool.py | 5 ++--- pymongo/synchronous/server.py | 6 +----- 10 files changed, 20 insertions(+), 32 deletions(-) diff --git a/pymongo/asynchronous/bulk.py b/pymongo/asynchronous/bulk.py index ad13b5210f..8da5ffcb47 100644 --- a/pymongo/asynchronous/bulk.py +++ b/pymongo/asynchronous/bulk.py @@ -38,7 +38,6 @@ from pymongo import _csot, common from pymongo.asynchronous.client_session import ( AsyncClientSession, - _TxnState, _validate_session_write_concern, ) from pymongo.asynchronous.helpers import _handle_reauth @@ -276,10 +275,7 @@ async def write_command( bwc._start(cmd, request_id, docs) try: if bwc.session is not None and bwc.session._starting_transaction: - # Mark the transaction as in progress once the first - # transactional bulk message is about to go on the wire. - bwc.session._transaction.has_sent_command = True - bwc.session._transaction.state = _TxnState.IN_PROGRESS + bwc.session._transaction.set_in_progress() reply = await bwc.conn.write_command(request_id, msg, bwc.codec) # type: ignore[misc] duration = datetime.datetime.now() - bwc.start_time if _COMMAND_LOGGER.isEnabledFor(logging.DEBUG): diff --git a/pymongo/asynchronous/client_bulk.py b/pymongo/asynchronous/client_bulk.py index 2188964200..dcef4eea02 100644 --- a/pymongo/asynchronous/client_bulk.py +++ b/pymongo/asynchronous/client_bulk.py @@ -37,7 +37,6 @@ from pymongo import _csot, common from pymongo.asynchronous.client_session import ( AsyncClientSession, - _TxnState, _validate_session_write_concern, ) from pymongo.asynchronous.collection import AsyncCollection @@ -263,8 +262,7 @@ async def write_command( bwc._start(cmd, request_id, op_docs, ns_docs) try: if bwc.session is not None and bwc.session._starting_transaction: - bwc.session._transaction.has_sent_command = True - bwc.session._transaction.state = _TxnState.IN_PROGRESS + bwc.session._transaction.set_in_progress() reply = await bwc.conn.write_command(request_id, msg, bwc.codec) # type: ignore[misc, arg-type] duration = datetime.datetime.now() - bwc.start_time if _COMMAND_LOGGER.isEnabledFor(logging.DEBUG): diff --git a/pymongo/asynchronous/client_session.py b/pymongo/asynchronous/client_session.py index 0fb9336354..c5d5d7f298 100644 --- a/pymongo/asynchronous/client_session.py +++ b/pymongo/asynchronous/client_session.py @@ -444,6 +444,11 @@ def starting(self) -> bool: def set_starting(self) -> None: self.state = _TxnState.STARTING + def set_in_progress(self) -> None: + if self.state == _TxnState.STARTING: + self.has_sent_command = True + self.state = _TxnState.IN_PROGRESS + @property def pinned_conn(self) -> Optional[AsyncConnection]: if self.active() and self.conn_mgr: diff --git a/pymongo/asynchronous/pool.py b/pymongo/asynchronous/pool.py index 083a530dbf..0a08ddc789 100644 --- a/pymongo/asynchronous/pool.py +++ b/pymongo/asynchronous/pool.py @@ -38,7 +38,7 @@ from bson import DEFAULT_CODEC_OPTIONS from pymongo import _csot, helpers_shared -from pymongo.asynchronous.client_session import _TxnState, _validate_session_write_concern +from pymongo.asynchronous.client_session import _validate_session_write_concern from pymongo.asynchronous.helpers import _handle_reauth from pymongo.asynchronous.network import command from pymongo.common import ( @@ -401,8 +401,7 @@ async def command( self._raise_if_not_writable(unacknowledged) try: if session is not None and session._starting_transaction: - session._transaction.has_sent_command = True - session._transaction.state = _TxnState.IN_PROGRESS + session._transaction.set_in_progress() return await command( self, dbname, diff --git a/pymongo/asynchronous/server.py b/pymongo/asynchronous/server.py index c271a5d959..0bbbba252b 100644 --- a/pymongo/asynchronous/server.py +++ b/pymongo/asynchronous/server.py @@ -27,7 +27,6 @@ ) from bson import _decode_all_selective -from pymongo.asynchronous.client_session import _TxnState from pymongo.asynchronous.helpers import _handle_reauth from pymongo.errors import NotPrimaryError, OperationFailure from pymongo.helpers_shared import _check_command_response @@ -205,11 +204,8 @@ async def run_operation( if more_to_come: reply = await conn.receive_message(None) else: - # Mark the transaction as in progress once the first transactional message is about to be sent, - # so local validation errors keep the session in STARTING, but post-send failures do not. if operation.session is not None and operation.session._starting_transaction: - operation.session._transaction.has_sent_command = True - operation.session._transaction.state = _TxnState.IN_PROGRESS + operation.session._transaction.set_in_progress() await conn.send_message(data, max_doc_size) reply = await conn.receive_message(request_id) diff --git a/pymongo/synchronous/bulk.py b/pymongo/synchronous/bulk.py index 60ba44fd95..f6e1d1abe4 100644 --- a/pymongo/synchronous/bulk.py +++ b/pymongo/synchronous/bulk.py @@ -69,7 +69,6 @@ from pymongo.read_preferences import ReadPreference from pymongo.synchronous.client_session import ( ClientSession, - _TxnState, _validate_session_write_concern, ) from pymongo.synchronous.helpers import _handle_reauth @@ -276,10 +275,7 @@ def write_command( bwc._start(cmd, request_id, docs) try: if bwc.session is not None and bwc.session._starting_transaction: - # Mark the transaction as in progress once the first - # transactional bulk message is about to go on the wire. - bwc.session._transaction.has_sent_command = True - bwc.session._transaction.state = _TxnState.IN_PROGRESS + bwc.session._transaction.set_in_progress() reply = bwc.conn.write_command(request_id, msg, bwc.codec) # type: ignore[misc] duration = datetime.datetime.now() - bwc.start_time if _COMMAND_LOGGER.isEnabledFor(logging.DEBUG): diff --git a/pymongo/synchronous/client_bulk.py b/pymongo/synchronous/client_bulk.py index bd4a193669..400b1a2170 100644 --- a/pymongo/synchronous/client_bulk.py +++ b/pymongo/synchronous/client_bulk.py @@ -37,7 +37,6 @@ from pymongo import _csot, common from pymongo.synchronous.client_session import ( ClientSession, - _TxnState, _validate_session_write_concern, ) from pymongo.synchronous.collection import Collection @@ -263,8 +262,7 @@ def write_command( bwc._start(cmd, request_id, op_docs, ns_docs) try: if bwc.session is not None and bwc.session._starting_transaction: - bwc.session._transaction.has_sent_command = True - bwc.session._transaction.state = _TxnState.IN_PROGRESS + bwc.session._transaction.set_in_progress() reply = bwc.conn.write_command(request_id, msg, bwc.codec) # type: ignore[misc, arg-type] duration = datetime.datetime.now() - bwc.start_time if _COMMAND_LOGGER.isEnabledFor(logging.DEBUG): diff --git a/pymongo/synchronous/client_session.py b/pymongo/synchronous/client_session.py index 9a3a7e7f27..f4df500549 100644 --- a/pymongo/synchronous/client_session.py +++ b/pymongo/synchronous/client_session.py @@ -442,6 +442,11 @@ def starting(self) -> bool: def set_starting(self) -> None: self.state = _TxnState.STARTING + def set_in_progress(self) -> None: + if self.state == _TxnState.STARTING: + self.has_sent_command = True + self.state = _TxnState.IN_PROGRESS + @property def pinned_conn(self) -> Optional[Connection]: if self.active() and self.conn_mgr: diff --git a/pymongo/synchronous/pool.py b/pymongo/synchronous/pool.py index 1f129253dd..fe7577f5f3 100644 --- a/pymongo/synchronous/pool.py +++ b/pymongo/synchronous/pool.py @@ -87,7 +87,7 @@ from pymongo.server_api import _add_to_command from pymongo.server_type import SERVER_TYPE from pymongo.socket_checker import SocketChecker -from pymongo.synchronous.client_session import _TxnState, _validate_session_write_concern +from pymongo.synchronous.client_session import _validate_session_write_concern from pymongo.synchronous.helpers import _handle_reauth from pymongo.synchronous.network import command @@ -401,8 +401,7 @@ def command( self._raise_if_not_writable(unacknowledged) try: if session is not None and session._starting_transaction: - session._transaction.has_sent_command = True - session._transaction.state = _TxnState.IN_PROGRESS + session._transaction.set_in_progress() return command( self, dbname, diff --git a/pymongo/synchronous/server.py b/pymongo/synchronous/server.py index 1ad60eba4a..e2e5e8503c 100644 --- a/pymongo/synchronous/server.py +++ b/pymongo/synchronous/server.py @@ -38,7 +38,6 @@ ) from pymongo.message import _convert_exception, _GetMore, _OpMsg, _Query from pymongo.response import PinnedResponse, Response -from pymongo.synchronous.client_session import _TxnState from pymongo.synchronous.helpers import _handle_reauth if TYPE_CHECKING: @@ -205,11 +204,8 @@ def run_operation( if more_to_come: reply = conn.receive_message(None) else: - # Mark the transaction as in progress once the first transactional message is about to be sent, - # so local validation errors keep the session in STARTING, but post-send failures do not. if operation.session is not None and operation.session._starting_transaction: - operation.session._transaction.has_sent_command = True - operation.session._transaction.state = _TxnState.IN_PROGRESS + operation.session._transaction.set_in_progress() conn.send_message(data, max_doc_size) reply = conn.receive_message(request_id)