Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ENG-6654] indexer recover from connection errors #835

Merged
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
71 changes: 55 additions & 16 deletions share/search/daemon.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,16 @@
import contextlib
import collections
from collections.abc import Callable
import dataclasses
import logging
import queue
import random
import threading
import time

import amqp.exceptions
from django.conf import settings
import kombu
from kombu.mixins import ConsumerMixin
import sentry_sdk

Expand All @@ -27,6 +30,7 @@
MINIMUM_BACKOFF_FACTOR = 1.6 # unitless ratio
MAXIMUM_BACKOFF_FACTOR = 2.0 # unitless ratio
MAXIMUM_BACKOFF_TIMEOUT = 60 # seconds
CONNECTION_HEARTBEAT = 20 # seconds (see https://www.rabbitmq.com/docs/heartbeats#false-positives )


class TooFastSlowDown(Exception):
Expand All @@ -35,7 +39,10 @@ class TooFastSlowDown(Exception):

class IndexerDaemonControl:
def __init__(self, celery_app, *, daemonthread_context=None, stop_event=None):
self.celery_app = celery_app
self.kombu_connection = kombu.Connection(
celery_app.conf.broker_url, # use celery_app.conf for consistent config
heartbeat=CONNECTION_HEARTBEAT,
)
self.daemonthread_context = daemonthread_context
self._daemonthreads = []
# shared stop_event for all threads below
Expand All @@ -50,10 +57,16 @@ def start_daemonthreads_for_strategy(self, index_strategy):
)
# spin up daemonthreads, ready for messages
self._daemonthreads.extend(_daemon.start())
# assign a thread to pass messages to this daemon
threading.Thread(
target=CeleryMessageConsumer(self.celery_app, _daemon).run,
).start()
_consumer = KombuMessageConsumer(
kombu_connection=self.kombu_connection.clone(),
stop_event=self.stop_event,
index_strategy=index_strategy,
message_callback=_daemon.on_message,
)
# give the daemon a more robust callback for ack-ing
_daemon.ack_callback = _consumer.ensure_ack
# assign a thread for the consumer to receive and enqueue messages to this daemon
threading.Thread(target=_consumer.run).start()
return _daemon

def start_all_daemonthreads(self):
Expand All @@ -67,18 +80,16 @@ def stop_daemonthreads(self, *, wait=False):
_thread.join()


class CeleryMessageConsumer(ConsumerMixin):
class KombuMessageConsumer(ConsumerMixin):
PREFETCH_COUNT = 7500

# (from ConsumerMixin)
# should_stop: bool
should_stop: bool # (from ConsumerMixin)

def __init__(self, celery_app, indexer_daemon):
self.connection = celery_app.pool.acquire(block=True)
self.celery_app = celery_app
self.__stop_event = indexer_daemon.stop_event
self.__message_callback = indexer_daemon.on_message
self.__index_strategy = indexer_daemon.index_strategy
def __init__(self, *, kombu_connection, stop_event, message_callback, index_strategy):
    """Store the collaborators this consumer works with.

    All arguments are keyword-only.

    Args:
        kombu_connection: assigned to ``self.connection``, the attribute
            that ``consume`` and ``ensure_ack`` call ``ensure`` /
            ``autoretry`` on.
        stop_event: kept in a private attribute.
        message_callback: kept in a private attribute.
        index_strategy: kept in a private attribute; ``__repr__`` reads
            its ``name``.
    """
    # the assignments are independent of each other; order is irrelevant
    self.__index_strategy = index_strategy
    self.__message_callback = message_callback
    self.__stop_event = stop_event
    self.connection = kombu_connection

# overrides ConsumerMixin.run
def run(self):
Expand Down Expand Up @@ -112,9 +123,34 @@ def get_consumers(self, Consumer, channel):
def __repr__(self):
    """Return ``<ClassName(index-strategy-name)>``."""
    _class_name = self.__class__.__name__
    _strategy_name = self.__index_strategy.name
    return f'<{_class_name}({_strategy_name})>'

def consume(self, *args, **kwargs):
    """Invoke the parent class's ``consume`` through ``Connection.ensure``.

    Positional and keyword arguments are forwarded unchanged, and whatever
    the wrapped call returns is returned as-is.
    """
    # wrapping follows the consumer guidance in kombu's failover guide:
    # https://docs.celeryq.dev/projects/kombu/en/stable/userguide/failover.html#consumer
    # NOTE(review): if the parent `consume` is a generator function, `ensure`
    # would only guard creating the generator, not iterating it -- confirm
    _ensured_consume = self.connection.ensure(self.connection, super().consume)
    return _ensured_consume(*args, **kwargs)

def ensure_ack(self, daemon_message: messages.DaemonMessage):
    """Ack ``daemon_message``, falling back to a retried ack on failure.

    First tries ``daemon_message.ack()``. If that raises the builtin
    ``ConnectionError`` or ``amqp.exceptions.ConnectionError``, the ack is
    re-attempted via ``kombu.Connection.autoretry`` on ``self.connection``:
    ``basic_ack`` is sent with the message's delivery tag on the channel
    that ``autoretry`` passes in, and that channel is closed afterwards
    whether or not the ack succeeded.

    Any other exception from ``daemon_message.ack()`` propagates untouched.
    """
    try:
        daemon_message.ack()
    except (ConnectionError, amqp.exceptions.ConnectionError):
        # the connection the message arrived on is presumed unusable;
        # let `autoretry` revive a connection and supply a channel

        def _do_ack(*, channel):
            # NOTE(review): AMQP delivery tags are scoped to the channel
            # that delivered the message -- confirm the broker accepts
            # this ack on a different channel after a reconnect
            try:
                channel.basic_ack(daemon_message.kombu_message.delivery_tag)
            finally:
                channel.close()

        _retrying_ack = self.connection.autoretry(_do_ack)
        _retrying_ack()


def _default_ack_callback(daemon_message: messages.DaemonMessage) -> None:
    """Ack the given message directly by calling its own ``ack()``.

    This is the value ``IndexerDaemon.__init__`` assigns to
    ``ack_callback``.
    """
    daemon_message.ack()


class IndexerDaemon:
MAX_LOCAL_QUEUE_SIZE = 5000
ack_callback: Callable[[messages.DaemonMessage], None]

def __init__(self, index_strategy, *, stop_event=None, daemonthread_context=None):
self.stop_event = (
Expand All @@ -126,6 +162,7 @@ def __init__(self, index_strategy, *, stop_event=None, daemonthread_context=None
self.__daemonthread_context = daemonthread_context or contextlib.nullcontext
self.__local_message_queues = {}
self.__started = False
self.ack_callback = _default_ack_callback

def start(self) -> list[threading.Thread]:
if self.__started:
Expand Down Expand Up @@ -154,6 +191,7 @@ def start_typed_loop_and_queue(self, message_type) -> threading.Thread:
local_message_queue=_queue_from_rabbit_to_daemon,
log_prefix=f'{repr(self)} MessageHandlingLoop: ',
daemonthread_context=self.__daemonthread_context,
ack_callback=self.ack_callback,
)
return _handling_loop.start_thread()

Expand Down Expand Up @@ -186,7 +224,8 @@ class MessageHandlingLoop:
stop_event: threading.Event
local_message_queue: queue.Queue
log_prefix: str
daemonthread_context: contextlib.AbstractContextManager
daemonthread_context: Callable[[], contextlib.AbstractContextManager]
ack_callback: Callable[[messages.DaemonMessage], None]
_leftover_daemon_messages_by_target_id = None

def __post_init__(self):
Expand Down Expand Up @@ -270,7 +309,7 @@ def _handle_some_messages(self):
sentry_sdk.capture_message('error handling message', extras={'message_response': message_response})
target_id = message_response.index_message.target_id
for daemon_message in daemon_messages_by_target_id.pop(target_id, ()):
daemon_message.ack() # finally set it free
self.ack_callback(daemon_message)
if daemon_messages_by_target_id: # should be empty by now
logger.error('%sUnhandled messages?? %s', self.log_prefix, len(daemon_messages_by_target_id))
sentry_sdk.capture_message(
Expand Down
Loading