Skip to content
Open
207 changes: 49 additions & 158 deletions pymongo/asynchronous/bulk.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,6 @@
from __future__ import annotations

import copy
import datetime
import logging
from collections.abc import MutableMapping
from itertools import islice
from typing import (
Expand All @@ -37,6 +35,10 @@
from bson.raw_bson import RawBSONDocument
from pymongo import _csot, common
from pymongo.asynchronous.client_session import AsyncClientSession, _validate_session_write_concern
from pymongo.asynchronous.command_runner import (
run_acknowledged_command,
run_unacknowledged_command,
)
from pymongo.asynchronous.helpers import _handle_reauth
from pymongo.bulk_shared import (
_COMMANDS,
Expand All @@ -53,18 +55,14 @@
from pymongo.errors import (
ConfigurationError,
InvalidOperation,
NotPrimaryError,
OperationFailure,
)
from pymongo.helpers_shared import _RETRYABLE_ERROR_CODES
from pymongo.logger import _COMMAND_LOGGER, _CommandStatusMessage, _debug_log
from pymongo.message import (
_DELETE,
_INSERT,
_UPDATE,
_BulkWriteContext,
_convert_exception,
_convert_write_result,
_EncryptedBulkWriteContext,
_randint,
)
Expand Down Expand Up @@ -250,81 +248,26 @@ async def write_command(
docs: list[Mapping[str, Any]],
client: AsyncMongoClient[Any],
) -> dict[str, Any]:
"""A proxy for SocketInfo.write_command that handles event publishing."""
"""Run a batch write command, returning the response as a dict."""
cmd[bwc.field] = docs
if _COMMAND_LOGGER.isEnabledFor(logging.DEBUG):
_debug_log(
_COMMAND_LOGGER,
message=_CommandStatusMessage.STARTED,
clientId=client._topology_settings._topology_id,
command=cmd,
commandName=next(iter(cmd)),
databaseName=bwc.db_name,
requestId=request_id,
operationId=request_id,
driverConnectionId=bwc.conn.id,
serverConnectionId=bwc.conn.server_connection_id,
serverHost=bwc.conn.address[0],
serverPort=bwc.conn.address[1],
serviceId=bwc.conn.service_id,
)
if bwc.publish:
bwc._start(cmd, request_id, docs)

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

_BulkWriteContext._start, _BulkWriteContext._succeed, _BulkWriteContext._fail, and _ClientBulkWriteContext._start are all dead code.

try:
reply = await bwc.conn.write_command(request_id, msg, bwc.codec) # type: ignore[misc]
duration = datetime.datetime.now() - bwc.start_time
if _COMMAND_LOGGER.isEnabledFor(logging.DEBUG):
_debug_log(
_COMMAND_LOGGER,
message=_CommandStatusMessage.SUCCEEDED,
clientId=client._topology_settings._topology_id,
durationMS=duration,
reply=reply,
commandName=next(iter(cmd)),
databaseName=bwc.db_name,
requestId=request_id,
operationId=request_id,
driverConnectionId=bwc.conn.id,
serverConnectionId=bwc.conn.server_connection_id,
serverHost=bwc.conn.address[0],
serverPort=bwc.conn.address[1],
serviceId=bwc.conn.service_id,
)
if bwc.publish:
bwc._succeed(request_id, reply, duration) # type: ignore[arg-type]
await client._process_response(reply, bwc.session) # type: ignore[arg-type]
except Exception as exc:
duration = datetime.datetime.now() - bwc.start_time
if isinstance(exc, (NotPrimaryError, OperationFailure)):
failure: _DocumentOut = exc.details # type: ignore[assignment]
else:
failure = _convert_exception(exc)
if _COMMAND_LOGGER.isEnabledFor(logging.DEBUG):
_debug_log(
_COMMAND_LOGGER,
message=_CommandStatusMessage.FAILED,
clientId=client._topology_settings._topology_id,
durationMS=duration,
failure=failure,
commandName=next(iter(cmd)),
databaseName=bwc.db_name,
requestId=request_id,
operationId=request_id,
driverConnectionId=bwc.conn.id,
serverConnectionId=bwc.conn.server_connection_id,
serverHost=bwc.conn.address[0],
serverPort=bwc.conn.address[1],
serviceId=bwc.conn.service_id,
isServerSideError=isinstance(exc, OperationFailure),
)

if bwc.publish:
bwc._fail(request_id, failure, duration)
# Process the response from the server.
if isinstance(exc, (NotPrimaryError, OperationFailure)):
await client._process_response(exc.details, bwc.session) # type: ignore[arg-type]
raise
return reply # type: ignore[return-value]
result_docs, _, _ = await run_acknowledged_command(
bwc.conn, # type: ignore[arg-type]
cmd,
bwc.db_name,
request_id,
msg,
client=client,
session=bwc.session, # type: ignore[arg-type]
listeners=bwc.listeners,
address=bwc.conn.address,
start=bwc.start_time,
codec_options=bwc.codec,
op_id=bwc.op_id,
command_name=bwc.name,
use_conn_transport=True,
decrypt_reply=False,
)
return result_docs[0]

async def unack_write(
self,
Expand All @@ -336,83 +279,31 @@ async def unack_write(
docs: list[Mapping[str, Any]],
client: AsyncMongoClient[Any],
) -> Optional[Mapping[str, Any]]:
"""A proxy for AsyncConnection.unack_write that handles event publishing."""
if _COMMAND_LOGGER.isEnabledFor(logging.DEBUG):
_debug_log(
_COMMAND_LOGGER,
message=_CommandStatusMessage.STARTED,
clientId=client._topology_settings._topology_id,
command=cmd,
commandName=next(iter(cmd)),
databaseName=bwc.db_name,
requestId=request_id,
operationId=request_id,
driverConnectionId=bwc.conn.id,
serverConnectionId=bwc.conn.server_connection_id,
serverHost=bwc.conn.address[0],
serverPort=bwc.conn.address[1],
serviceId=bwc.conn.service_id,
)
if bwc.publish:
cmd = bwc._start(cmd, request_id, docs)
try:
result = await bwc.conn.unack_write(msg, max_doc_size) # type: ignore[func-returns-value, misc, override]
duration = datetime.datetime.now() - bwc.start_time
if result is not None:
reply = _convert_write_result(bwc.name, cmd, result) # type: ignore[arg-type]

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

_convert_write_result is also dead code now.

else:
# Comply with APM spec.
reply = {"ok": 1}
if _COMMAND_LOGGER.isEnabledFor(logging.DEBUG):
_debug_log(
_COMMAND_LOGGER,
message=_CommandStatusMessage.SUCCEEDED,
clientId=client._topology_settings._topology_id,
durationMS=duration,
reply=reply,
commandName=next(iter(cmd)),
databaseName=bwc.db_name,
requestId=request_id,
operationId=request_id,
driverConnectionId=bwc.conn.id,
serverConnectionId=bwc.conn.server_connection_id,
serverHost=bwc.conn.address[0],
serverPort=bwc.conn.address[1],
serviceId=bwc.conn.service_id,
)
if bwc.publish:
bwc._succeed(request_id, reply, duration)
except Exception as exc:
duration = datetime.datetime.now() - bwc.start_time
if isinstance(exc, OperationFailure):
failure: _DocumentOut = _convert_write_result(bwc.name, cmd, exc.details) # type: ignore[arg-type]
elif isinstance(exc, NotPrimaryError):
failure = exc.details # type: ignore[assignment]
else:
failure = _convert_exception(exc)
if _COMMAND_LOGGER.isEnabledFor(logging.DEBUG):
_debug_log(
_COMMAND_LOGGER,
message=_CommandStatusMessage.FAILED,
clientId=client._topology_settings._topology_id,
durationMS=duration,
failure=failure,
commandName=next(iter(cmd)),
databaseName=bwc.db_name,
requestId=request_id,
operationId=request_id,
driverConnectionId=bwc.conn.id,
serverConnectionId=bwc.conn.server_connection_id,
serverHost=bwc.conn.address[0],
serverPort=bwc.conn.address[1],
serviceId=bwc.conn.service_id,
isServerSideError=isinstance(exc, OperationFailure),
)
if bwc.publish:
assert bwc.start_time is not None
bwc._fail(request_id, failure, duration)
raise
return result # type: ignore[return-value]
"""Send an unacknowledged batch write command."""
# Historically the STARTED log omits the documents while the published
# CommandStartedEvent includes them, so log ``cmd`` but publish a copy
# carrying the ``docs`` field.
published = dict(cmd)
published[bwc.field] = docs
await run_unacknowledged_command(
bwc.conn, # type: ignore[arg-type]
cmd,
bwc.db_name,
request_id,
msg,
client=client,
session=bwc.session, # type: ignore[arg-type]
listeners=bwc.listeners,
address=bwc.conn.address,
start=bwc.start_time,
codec_options=bwc.codec,
op_id=bwc.op_id,
command_name=bwc.name,
orig=published,
use_conn_transport=True,
max_doc_size=max_doc_size,
)
return None

async def _execute_batch_unack(
self,
Expand Down Expand Up @@ -484,7 +375,7 @@ async def _execute_command(
run = self.current_run

# AsyncConnection.command validates the session, but we use
# AsyncConnection.write_command
# run_acknowledged_command/run_unacknowledged_command.
conn.validate_session(client, session)
last_run = False

Expand Down
Loading
Loading