diff --git a/config/system_config_demo.yml b/config/system_config_demo.yml index 1b20af6c..71b2c0a4 100644 --- a/config/system_config_demo.yml +++ b/config/system_config_demo.yml @@ -125,6 +125,7 @@ std_datasets: path: workloads/IMDB_100GB/adhoc_test/ bootstrap_vdbe_path: config/vdbe_demo/imdb_extended_vdbes.json +# bootstrap_vdbe_path: config/vdbe_demo/imdb_editable_vdbes.json aurora_max_query_factor: 4.0 aurora_max_query_factor_replace: 10000.0 diff --git a/src/brad/connection/factory.py b/src/brad/connection/factory.py index 5bf95cf6..90a621c4 100644 --- a/src/brad/connection/factory.py +++ b/src/brad/connection/factory.py @@ -28,6 +28,10 @@ async def connect_to( if config.stub_mode_path() is not None: return cls.connect_to_stub(config) + # HACK: Schema aliasing for convenience. + if schema_name is not None and schema_name == "imdb_editable_100g": + schema_name = "imdb_extended_100g" + connection_details = config.get_connection_details(engine) if engine == Engine.Redshift: cluster = directory.redshift_cluster() @@ -153,6 +157,10 @@ async def connect_to_sidecar( if config.stub_mode_path() is not None: return cls.connect_to_stub(config) + # HACK: Schema aliasing for convenience. + if schema_name is not None and schema_name == "imdb_editable_100g": + schema_name = "imdb_extended_100g" + connection_details = config.get_sidecar_db_details() if ( _USE_PSYCOPG_KEY in connection_details diff --git a/src/brad/daemon/daemon.py b/src/brad/daemon/daemon.py index e1f5fb5f..8b43223a 100644 --- a/src/brad/daemon/daemon.py +++ b/src/brad/daemon/daemon.py @@ -27,6 +27,8 @@ InternalCommandResponse, NewBlueprint, NewBlueprintAck, + ReconcileVirtualInfrastructure, + ReconcileVirtualInfrastructureAck, ) from brad.daemon.monitor import Monitor from brad.daemon.system_event_logger import SystemEventLogger @@ -37,7 +39,8 @@ from brad.data_stats.postgres_estimator import PostgresEstimator from brad.data_stats.stub_estimator import StubEstimator from brad.data_sync.execution.executor import DataSyncExecutor -from brad.front_end.start_front_end import start_front_end +from brad.front_end.start_front_end import start_front_end, start_vdbe_front_end +from brad.front_end.vdbe.vdbe_front_end import BradVdbeFrontEnd from brad.planner.abstract import BlueprintPlanner from brad.planner.compare.provider import ( BlueprintComparatorProvider, @@ -69,8 +72,10 @@ from brad.routing.tree_based.forest_policy import ForestPolicy from brad.row_list import RowList from brad.utils.time_periods import period_start, universal_now +from brad.utils.mailbox import Mailbox from brad.ui.manager import UiManager from brad.vdbe.manager import VdbeManager +from brad.vdbe.models import VirtualInfrastructure logger = logging.getLogger(__name__) @@ -134,9 +139,11 @@ def __init__( self._vdbe_manager: Optional[VdbeManager] = VdbeManager.load_from( load_vdbe_path, starting_port=9876, + apply_infra=self._apply_virtual_infra, ) else: self._vdbe_manager = None + self._vdbe_process: Optional[_VdbeFrontEndProcess] = None # This is used to hold references to internal command tasks we create. # https://docs.python.org/3/library/asyncio-task.html#asyncio.create_task @@ -343,6 +350,31 @@ async def _run_setup(self) -> None: for fe in self._front_ends: fe.process.start() + if self._vdbe_manager is not None: + v_input_queue = self._process_manager.Queue() + v_output_queue = self._process_manager.Queue() + process = mp.Process( + target=start_vdbe_front_end, + args=( + self._config, + self._schema_name, + self._path_to_system_config, + self._debug_mode, + self._blueprint_mgr.get_directory(), + self._vdbe_manager.infra(), + v_input_queue, + v_output_queue, + ), + ) + self._vdbe_process = _VdbeFrontEndProcess( + process, v_input_queue, v_output_queue + ) + reader_task = asyncio.create_task( + self._read_vdbe_messages(self._vdbe_process) + ) + self._vdbe_process.message_reader_task = reader_task + self._vdbe_process.process.start() + if ( self._config.routing_policy == RoutingPolicy.ForestTableSelectivity or self._config.routing_policy == RoutingPolicy.Default @@ -375,6 +407,15 @@ async def _run_teardown(self) -> None: if fe.message_reader_task is not None: fe.output_queue.put(Sentinel(fe_index)) + if self._vdbe_process is not None: + logger.info("Telling the VDBE front end to shut down...") + self._vdbe_process.input_queue.put( + ShutdownFrontEnd(BradVdbeFrontEnd.NUMERIC_IDENTIFIER) + ) + self._vdbe_process.output_queue.put( + Sentinel(BradVdbeFrontEnd.NUMERIC_IDENTIFIER) + ) + if self._timed_sync_task is not None: self._timed_sync_task.cancel() self._timed_sync_task = None @@ -394,6 +435,11 @@ async def _run_teardown(self) -> None: fe.process.join() self._front_ends.clear() + if self._vdbe_process is not None: + logger.info("Waiting for the VDBE front end to shut down...") + self._vdbe_process.process.join() + self._vdbe_process = None + async def _read_front_end_messages(self, front_end: "_FrontEndProcess") -> None: """ Waits for messages from the specified front end process and processes them. @@ -465,6 +511,83 @@ async def _read_front_end_messages(self, front_end: "_FrontEndProcess") -> None: front_end.fe_index, ) + async def _read_vdbe_messages(self, vdbe_process: "_VdbeFrontEndProcess") -> None: + loop = asyncio.get_running_loop() + while True: + try: + message = await loop.run_in_executor( + None, vdbe_process.output_queue.get + ) + if message.fe_index != BradVdbeFrontEnd.NUMERIC_IDENTIFIER: + logger.warning( + "Received message with invalid front end index. Expected %d. Received %d.", + BradVdbeFrontEnd.NUMERIC_IDENTIFIER, + message.fe_index, + ) + continue + + if isinstance(message, NewBlueprintAck): + if self._transition_orchestrator is None: + logger.error( + "Received blueprint ack message but no transition is in progress. Version: %d, Front end: %d", + message.version, + message.fe_index, + ) + continue + + # Sanity check. + next_version = self._transition_orchestrator.next_version() + if next_version != message.version: + logger.error( + "Received a blueprint ack for a mismatched version. Received %d, Expected %d", + message.version, + next_version, + ) + continue + + logger.info( + "Received blueprint ack. Version: %d, Front end: %d", + message.version, + message.fe_index, + ) + + self._transition_orchestrator.decrement_waiting_for_front_ends() + if self._transition_orchestrator.waiting_for_front_ends() == 0: + # Schedule the second half of the transition. + self._transition_task = asyncio.create_task( + self._run_transition_part_two() + ) + + elif isinstance(message, ReconcileVirtualInfrastructureAck): + logger.info( + "Received reconcile ack from VDBE front end. Added %d, Removed %d", + message.num_added, + message.num_removed, + ) + vdbe_process.mailbox.on_new_message((None,)) + + else: + logger.debug( + "Received unexpected message from front end %d: %s", + BradVdbeFrontEnd.NUMERIC_IDENTIFIER, + str(message), + ) + + except Exception as ex: + if not isinstance(ex, asyncio.CancelledError): + logger.exception( + "Unexpected error when handling front end message. Front end: %d", + BradVdbeFrontEnd.NUMERIC_IDENTIFIER, + ) + + async def _apply_virtual_infra(self, virtual_infra: VirtualInfrastructure) -> None: + """ + Used by the VDBE manager to apply a change to the virtual infrastructure + on the VDBE front end. This returns after the change is applied. + """ + assert self._vdbe_process is not None + await self._vdbe_process.mailbox.send_recv(virtual_infra) + async def _handle_new_blueprint( self, blueprint: Blueprint, score: Score, trigger: Optional[Trigger] ) -> None: @@ -857,9 +980,6 @@ def update_monitor_sources(): "Notifying %d front ends about the new blueprint.", len(self._front_ends), ) - self._transition_orchestrator.set_waiting_for_front_ends( - len(self._front_ends) - ) for fe in self._front_ends: fe.input_queue.put( NewBlueprint( @@ -869,6 +989,18 @@ def update_monitor_sources(): ) ) + total_wait = len(self._front_ends) + if self._vdbe_process is not None: + self._vdbe_process.input_queue.put( + NewBlueprint( + BradVdbeFrontEnd.NUMERIC_IDENTIFIER, + tm.next_version, + self._blueprint_mgr.get_directory(), + ) + ) + total_wait += 1 + + self._transition_orchestrator.set_waiting_for_front_ends(total_wait) self._transition_task = None # We finish the transition after all front ends acknowledge that they @@ -999,3 +1131,29 @@ def __init__( self.input_queue = input_queue self.output_queue = output_queue self.message_reader_task: Optional[asyncio.Task] = None + + +class _VdbeFrontEndProcess: + """ + Used to manage state associated with the VDBE front end process. + """ + + def __init__( + self, + process: mp.Process, + input_queue: queue.Queue, + output_queue: queue.Queue, + ) -> None: + self.process = process + self.input_queue = input_queue + self.output_queue = output_queue + self.message_reader_task: Optional[asyncio.Task] = None + self.mailbox: Mailbox[VirtualInfrastructure, Tuple] = Mailbox( + do_send_msg=self._send_message + ) + + async def _send_message(self, infra: VirtualInfrastructure) -> None: + logger.debug("Sending reconcile VDBE IPC message") + msg = ReconcileVirtualInfrastructure(BradVdbeFrontEnd.NUMERIC_IDENTIFIER, infra) + loop = asyncio.get_running_loop() + await loop.run_in_executor(None, self.input_queue.put, msg) diff --git a/src/brad/daemon/messages.py b/src/brad/daemon/messages.py index b902490f..91401fde 100644 --- a/src/brad/daemon/messages.py +++ b/src/brad/daemon/messages.py @@ -1,8 +1,10 @@ +from typing import Tuple, List from ddsketch import DDSketch from ddsketch.pb.proto import DDSketchProto, pb as ddspb from brad.provisioning.directory import Directory from brad.row_list import RowList +from brad.vdbe.models import VirtualInfrastructure class IpcMessage: @@ -88,6 +90,40 @@ def query_latency_sketch(self) -> DDSketch: return DDSketchProto.from_proto(pb_sketch) +class VdbeMetricsReport(IpcMessage): + """ + Sent from the VDBE front end to the daemon to report BRAD's client-side metrics. + """ + + @classmethod + def from_data( + cls, + fe_index: int, + latency_sketches: List[Tuple[int, DDSketch]], + ) -> "VdbeMetricsReport": + serialized_sketches = [ + (vdbe_id, DDSketchProto.to_proto(sketch).SerializeToString()) + for vdbe_id, sketch in latency_sketches + ] + return cls(fe_index, latency_sketches=serialized_sketches) + + def __init__( + self, + fe_index: int, + latency_sketches: List[Tuple[int, bytes]], + ) -> None: + super().__init__(fe_index) + self.serialized_latency_sketches = latency_sketches + + def query_latency_sketches(self) -> List[Tuple[int, DDSketch]]: + results = [] + for vdbe_id, serialized_sketch in self.serialized_latency_sketches: + pb_sketch = ddspb.DDSketch() + pb_sketch.ParseFromString(serialized_sketch) + results.append((vdbe_id, DDSketchProto.from_proto(pb_sketch))) + return results + + class InternalCommandRequest(IpcMessage): """ Sent from the front end to the daemon to handle an internal command. @@ -108,6 +144,27 @@ def __init__(self, fe_index: int, response: RowList) -> None: self.response = response +class ReconcileVirtualInfrastructure(IpcMessage): + """ + Sent from the daemon to the VDBE front end to update its virtual infrastructure. + """ + + def __init__(self, fe_index: int, virtual_infra: VirtualInfrastructure) -> None: + super().__init__(fe_index) + self.virtual_infra = virtual_infra + + +class ReconcileVirtualInfrastructureAck(IpcMessage): + """ + Sent from the VDBE front end back to the daemon to acknowledge the virtual infrastructure update. + """ + + def __init__(self, fe_index: int, num_added: int, num_removed: int) -> None: + super().__init__(fe_index) + self.num_added = num_added + self.num_removed = num_removed + + class ShutdownFrontEnd(IpcMessage): """ Sent from the daemon to the front end indicating that it should shut down. diff --git a/src/brad/exec/cli.py b/src/brad/exec/cli.py index 7d7ad7e9..c2ec48b2 100644 --- a/src/brad/exec/cli.py +++ b/src/brad/exec/cli.py @@ -2,7 +2,7 @@ import pathlib import readline import time -from typing import List +from typing import List, Tuple from tabulate import tabulate import brad @@ -15,29 +15,31 @@ def register_command(subparsers): "cli", help="Start a BRAD client session.", ) - parser.add_argument( - "--host", - type=str, - default="localhost", - help="The host where the BRAD server is running.", - ) - parser.add_argument( - "--port", - type=int, - default=6583, - help="The port on which BRAD is listening for connections.", - ) parser.add_argument( "-c", "--command", type=str, help="Run a single SQL query (or internal command) and exit.", ) + parser.add_argument( + "endpoint", + nargs="?", + help="The BRAD endpoint to connect to. Defaults to localhost:6583.", + default="localhost:6583", + ) parser.set_defaults(func=main) +def parse_endpoint(endpoint: str) -> Tuple[str, int]: + parts = endpoint.split(":") + if len(parts) != 2: + raise ValueError("Invalid endpoint format.") + return parts[0], int(parts[1]) + + def run_command(args) -> None: - with BradGrpcClient(args.host, args.port) as client: + host, port = parse_endpoint(args.endpoint) + with BradGrpcClient(host, port) as client: run_query(client, args.command) @@ -132,11 +134,12 @@ def main(args) -> None: run_command(args) return + host, port = parse_endpoint(args.endpoint) print("BRAD Interactive Shell v{}".format(brad.__version__)) print() - print("Connecting to BRAD at {}:{}...".format(args.host, args.port)) + print("Connecting to BRAD VDBE at {}:{}...".format(host, port)) - with BradGrpcClient(args.host, args.port) as client: + with BradGrpcClient(host, port) as client: print("Connected!") print() print("Terminate all SQL queries with a semicolon (;). Hit Ctrl-D to exit.") diff --git a/src/brad/front_end/session.py b/src/brad/front_end/session.py index 416e2515..90d582a3 100644 --- a/src/brad/front_end/session.py +++ b/src/brad/front_end/session.py @@ -77,7 +77,11 @@ async def close(self): class SessionManager: def __init__( - self, config: ConfigFile, blueprint_mgr: "BlueprintManager", schema_name: str + self, + config: ConfigFile, + blueprint_mgr: "BlueprintManager", + schema_name: str, + for_vdbes: bool = False, ) -> None: self._config = config self._blueprint_mgr = blueprint_mgr @@ -89,6 +93,7 @@ def __init__( # project. For now we assume that we always operate against one schema # and that it is provided up front when starting BRAD. self._schema_name = schema_name + self._for_vdbes = for_vdbes async def create_new_session(self) -> Tuple[SessionId, Session]: logger.debug("Creating a new session...") @@ -114,7 +119,7 @@ async def create_new_session(self) -> Tuple[SessionId, Session]: # Create an estimator if needed. The estimator should be # session-specific since it currently depends on a DB connection. routing_policy_override = self._config.routing_policy - if ( + if not self._for_vdbes and ( routing_policy_override == RoutingPolicy.ForestTableSelectivity or routing_policy_override == RoutingPolicy.Default ): diff --git a/src/brad/front_end/start_front_end.py b/src/brad/front_end/start_front_end.py index d66c4958..5709fd21 100644 --- a/src/brad/front_end/start_front_end.py +++ b/src/brad/front_end/start_front_end.py @@ -5,8 +5,10 @@ from brad.config.file import ConfigFile from brad.front_end.front_end import BradFrontEnd +from brad.front_end.vdbe.vdbe_front_end import BradVdbeFrontEnd from brad.provisioning.directory import Directory from brad.utils import set_up_logging +from brad.vdbe.models import VirtualInfrastructure logger = logging.getLogger(__name__) @@ -59,6 +61,57 @@ def start_front_end( logger.info("BRAD front end %d has shut down.", fe_index) +def start_vdbe_front_end( + config: ConfigFile, + schema_name: str, + path_to_system_config: str, + debug_mode: bool, + directory: Directory, + initial_infra: VirtualInfrastructure, + input_queue: mp.Queue, + output_queue: mp.Queue, +) -> None: + """ + Schedule this method to run in a child process to launch a BRAD front + end server. + """ + set_up_logging( + filename=config.front_end_log_file(BradVdbeFrontEnd.NUMERIC_IDENTIFIER), + debug_mode=debug_mode, + ) + + event_loop = asyncio.new_event_loop() + event_loop.set_debug(enabled=debug_mode) + asyncio.set_event_loop(event_loop) + + # Signal handlers are inherited from the parent server process. We want + # to ignore these signals since we receive a shutdown signal from the + # daemon directly. + for sig in [signal.SIGTERM, signal.SIGINT]: + event_loop.add_signal_handler(sig, _noop) + # This is useful for debugging purposes. + event_loop.add_signal_handler(signal.SIGUSR1, _drop_into_pdb) + event_loop.set_exception_handler(_handle_exception) + + try: + front_end = BradVdbeFrontEnd( + config, + schema_name, + path_to_system_config, + debug_mode, + directory, + initial_infra, + input_queue, + output_queue, + ) + event_loop.create_task(front_end.serve_forever()) + logger.info("BRAD VDBE front end is starting...") + event_loop.run_forever() + finally: + event_loop.close() + logger.info("BRAD VDBE front end has shut down.") + + def _handle_exception(event_loop, context): message = context.get("exception", context["message"]) logging.error("Encountered uncaught exception: %s", message) diff --git a/src/brad/front_end/vdbe/__init__.py b/src/brad/front_end/vdbe/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/brad/front_end/vdbe/vdbe_endpoint_manager.py b/src/brad/front_end/vdbe/vdbe_endpoint_manager.py new file mode 100644 index 00000000..16c4814e --- /dev/null +++ b/src/brad/front_end/vdbe/vdbe_endpoint_manager.py @@ -0,0 +1,158 @@ +import grpc +import json +import logging +from typing import Callable, Optional, Tuple, Dict, AsyncIterable, Any, Set, Awaitable + +import brad.proto_gen.brad_pb2_grpc as brad_grpc +from brad.connection.schema import Schema +from brad.front_end.brad_interface import BradInterface +from brad.front_end.grpc import BradGrpc +from brad.front_end.session import SessionManager, SessionId +from brad.row_list import RowList +from brad.utils.json_decimal_encoder import DecimalEncoder +from brad.vdbe.manager import VdbeFrontEndManager + +logger = logging.getLogger(__name__) + + +# (query_string, vdbe_id, session_id, debug_info) -> (rows, schema) +QueryHandler = Callable[ + [str, int, SessionId, Dict[str, Any]], Awaitable[Tuple[RowList, Optional[Schema]]] +] + + +class VdbeEndpointManager: + """ + Used to start/stop VDBE endpoints. Right now, we only support the BRAD gRPC + interface. + """ + + def __init__( + self, + *, + vdbe_mgr: VdbeFrontEndManager, + session_mgr: SessionManager, + handler: QueryHandler, + ) -> None: + self._vdbe_mgr = vdbe_mgr + self._session_mgr = session_mgr + self._handler = handler + self._endpoints: Dict[int, Tuple[int, grpc.aio.Server, VdbeGrpcInterface]] = {} + + async def initialize(self) -> None: + for engine in self._vdbe_mgr.engines(): + endpoint = engine.endpoint + if endpoint is None: + logger.warning( + "Engine %s (ID: %d) has no endpoint. Skipping adding VDBE endpoint.", + engine.internal_id, + engine.name, + ) + continue + port = int(endpoint.split(":")[1]) + await self.add_vdbe_endpoint(port, engine.internal_id) + + async def shutdown(self) -> None: + known_ids = list(self._endpoints.keys()) + for vdbe_id in known_ids: + await self.remove_vdbe_endpoint(vdbe_id) + + async def add_vdbe_endpoint(self, port: int, vdbe_id: int) -> None: + query_service = VdbeGrpcInterface( + vdbe_id=vdbe_id, handler=self._handler, session_mgr=self._session_mgr + ) + grpc_server = grpc.aio.server() + brad_grpc.add_BradServicer_to_server(BradGrpc(query_service), grpc_server) + grpc_server.add_insecure_port(f"0.0.0.0:{port}") + await grpc_server.start() + logger.info( + "Added VDBE endpoint for ID %d. Listening on port %d.", vdbe_id, port + ) + self._endpoints[vdbe_id] = (port, grpc_server, query_service) + + async def remove_vdbe_endpoint(self, vdbe_id: int) -> None: + try: + port, grpc_server, query_service = self._endpoints[vdbe_id] + await query_service.end_all_sessions() + # See `brad.front_end.BradFrontEnd.serve_forever`. + grpc_server.__del__() + del self._endpoints[vdbe_id] + logger.info("Removed VDBE endpoint for ID %d (was port %d).", vdbe_id, port) + + except KeyError: + logger.error( + "Tried to remove VDBE endpoint for ID %d, but it was not found.", + vdbe_id, + ) + + async def reconcile(self) -> Tuple[int, int]: + to_add = [] + to_remove = [] + seen_ids = set() + + for engine in self._vdbe_mgr.engines(): + if engine.internal_id not in self._endpoints: + to_add.append(engine) + seen_ids.add(engine.internal_id) + + for vdbe_id in self._endpoints.keys(): + if vdbe_id not in seen_ids: + to_remove.append(vdbe_id) + + for vdbe in to_add: + if vdbe.endpoint is None: + logger.warning( + "VDBE %s (ID: %d) has no endpoint. Skipping adding VDBE endpoint.", + vdbe.name, + vdbe.internal_id, + ) + continue + port = int(vdbe.endpoint.split(":")[1]) + await self.add_vdbe_endpoint(port, vdbe.internal_id) + + for vdbe_id in to_remove: + await self.remove_vdbe_endpoint(vdbe_id) + + return len(to_add), len(to_remove) + + +class VdbeGrpcInterface(BradInterface): + def __init__( + self, *, vdbe_id: int, handler: QueryHandler, session_mgr: SessionManager + ) -> None: + self._vdbe_id = vdbe_id + self._session_mgr = session_mgr + self._handler = handler + self._our_sessions: Set[SessionId] = set() + + async def start_session(self) -> SessionId: + session_id, _ = await self._session_mgr.create_new_session() + self._our_sessions.add(session_id) + return session_id + + def run_query( + self, session_id: SessionId, query: str, debug_info: Dict[str, Any] + ) -> AsyncIterable[bytes]: + # Purposefully not implemented - this is a legacy interface. + raise NotImplementedError + + async def run_query_json( + self, session_id: SessionId, query: str, debug_info: Dict[str, Any] + ) -> str: + """ + Returns query results encoded as a JSON string. + + This method may throw an error to indicate a problem with the query. + """ + results, _ = await self._handler(query, self._vdbe_id, session_id, debug_info) + return json.dumps(results, cls=DecimalEncoder, default=str) + + async def end_session(self, session_id: SessionId) -> None: + await self._session_mgr.end_session(session_id) + self._our_sessions.remove(session_id) + + async def end_all_sessions(self) -> None: + our_sessions = self._our_sessions.copy() + self._our_sessions.clear() + for session_id in our_sessions: + await self._session_mgr.end_session(session_id) diff --git a/src/brad/front_end/vdbe/vdbe_front_end.py b/src/brad/front_end/vdbe/vdbe_front_end.py new file mode 100644 index 00000000..c6fd66b6 --- /dev/null +++ b/src/brad/front_end/vdbe/vdbe_front_end.py @@ -0,0 +1,574 @@ +import asyncio +import logging +import ssl +import multiprocessing as mp +import redshift_connector.error as redshift_errors +import psycopg +import struct +from typing import Optional, Dict, Any, Tuple +from datetime import timedelta +from ddsketch import DDSketch +import pyodbc + +from brad.asset_manager import AssetManager +from brad.blueprint.manager import BlueprintManager +from brad.config.engine import Engine +from brad.config.file import ConfigFile +from brad.connection.connection import ConnectionFailed +from brad.connection.schema import Schema +from brad.daemon.monitor import Monitor +from brad.daemon.messages import ( + ShutdownFrontEnd, + Sentinel, + VdbeMetricsReport, + NewBlueprint, + NewBlueprintAck, + ReconcileVirtualInfrastructure, + ReconcileVirtualInfrastructureAck, +) +from brad.front_end.errors import QueryError +from brad.front_end.session import SessionManager, SessionId +from brad.front_end.watchdog import Watchdog +from brad.provisioning.directory import Directory +from brad.row_list import RowList +from brad.utils import log_verbose, create_custom_logger +from brad.utils.rand_exponential_backoff import RandomizedExponentialBackoff +from brad.utils.run_time_reservoir import RunTimeReservoir +from brad.utils.time_periods import universal_now +from brad.vdbe.manager import VdbeFrontEndManager +from brad.vdbe.models import VirtualInfrastructure +from brad.front_end.vdbe.vdbe_endpoint_manager import VdbeEndpointManager + +logger = logging.getLogger(__name__) + +LINESEP = "\n".encode() + + +class BradVdbeFrontEnd: + NUMERIC_IDENTIFIER = 10101 + + def __init__( + self, + config: ConfigFile, + schema_name: str, + path_to_system_config: str, + debug_mode: bool, + initial_directory: Directory, + initial_infra: VirtualInfrastructure, + input_queue: mp.Queue, + output_queue: mp.Queue, + ): + self._main_thread_loop: Optional[asyncio.AbstractEventLoop] = None + + self._config = config + self._schema_name = schema_name + self._debug_mode = debug_mode + + # Used for IPC with the daemon. Eventually we will use RPC to + # communicate with the daemon. But there's currently no need for + # something fancy here. + # Used for messages sent from the daemon to this front end server. + self._input_queue = input_queue + # Used for messages sent from this front end server to the daemon. + self._output_queue = output_queue + + self._assets = AssetManager(self._config) + self._blueprint_mgr = BlueprintManager( + self._config, + self._assets, + self._schema_name, + # This is provided by the daemon. We want to avoid hitting the AWS + # cluster metadata APIs when starting up the front end(s). + initial_directory=initial_directory, + ) + self._path_to_system_config = path_to_system_config + self._monitor: Optional[Monitor] = None + + # Used to track query performance. + self._query_run_times = RunTimeReservoir[float]( + self._config.front_end_query_latency_buffer_size + ) + + self._sessions = SessionManager( + self._config, self._blueprint_mgr, self._schema_name, for_vdbes=True + ) + self._daemon_messages_task: Optional[asyncio.Task[None]] = None + + self._reset_latency_sketches() + self._brad_metrics_reporting_task: Optional[asyncio.Task[None]] = None + + # Used to re-establish engine connections. + self._reestablish_connections_task: Optional[asyncio.Task[None]] = None + + # Used for logging transient errors that are too verbose. + main_log_file = config.front_end_log_file(self.NUMERIC_IDENTIFIER) + if main_log_file is not None: + verbose_log_file = ( + main_log_file.parent + / f"brad_front_end_verbose_{self.NUMERIC_IDENTIFIER}.log" + ) + self._verbose_logger: Optional[logging.Logger] = create_custom_logger( + "fe_verbose", str(verbose_log_file) + ) + self._verbose_logger.info("Verbose logger enabled.") + else: + self._verbose_logger = None + + # Used for debug purposes. + # We print the system state if the front end becomes unresponsive for >= 5 mins. + self._watchdog = Watchdog( + check_period=timedelta(minutes=1), take_action_after=timedelta(minutes=5) + ) + self._ping_watchdog_task: Optional[asyncio.Task[None]] = None + + self._is_stub_mode = self._config.stub_mode_path() is not None + + self._vdbe_mgr = VdbeFrontEndManager(initial_infra) + self._endpoint_mgr = VdbeEndpointManager( + vdbe_mgr=self._vdbe_mgr, + session_mgr=self._sessions, + handler=self._run_query_impl, + ) + self._shutdown_event = asyncio.Event() + + async def serve_forever(self): + await self._run_setup() + try: + # The server is shut down when we receive a shutdown message. + await self._shutdown_event.wait() + finally: + logger.info("BRAD VDBE front end is shutting down...") + await self._run_teardown() + logger.info("BRAD VDBE front end _run_teardown() complete.") + + async def _run_setup(self) -> None: + self._main_thread_loop = asyncio.get_running_loop() + + # The directory will have been populated by the daemon. + await self._blueprint_mgr.load(skip_directory_refresh=True) + logger.info("Using blueprint: %s", self._blueprint_mgr.get_blueprint()) + + if self._monitor is not None: + self._monitor.set_up_metrics_sources() + await self._monitor.fetch_latest() + + # Start the metrics reporting task. + self._brad_metrics_reporting_task = asyncio.create_task( + self._report_metrics_to_daemon() + ) + + # Used to handle messages from the daemon. + self._daemon_messages_task = asyncio.create_task(self._read_daemon_messages()) + + self._watchdog.start(self._main_thread_loop) + self._ping_watchdog_task = asyncio.create_task(self._ping_watchdog()) + + # Start all VDBE endpoints. + await self._endpoint_mgr.initialize() + + async def _run_teardown(self): + # Stop all VDBE endpoints (this will also end the sessions). + await self._endpoint_mgr.shutdown() + + # Important for unblocking our message reader thread. + self._input_queue.put(Sentinel(self.NUMERIC_IDENTIFIER)) + + if self._daemon_messages_task is not None: + self._daemon_messages_task.cancel() + self._daemon_messages_task = None + + if self._brad_metrics_reporting_task is not None: + self._brad_metrics_reporting_task.cancel() + self._brad_metrics_reporting_task = None + + self._watchdog.stop() + if self._ping_watchdog_task is not None: + self._ping_watchdog_task.cancel() + self._ping_watchdog_task = None + + async def start_session(self) -> SessionId: + rand_backoff = None + while True: + try: + session_id, _ = await self._sessions.create_new_session() + if self._verbose_logger is not None: + self._verbose_logger.info( + "New session started %d", session_id.value() + ) + return session_id + except ConnectionFailed: + if rand_backoff is None: + rand_backoff = RandomizedExponentialBackoff( + max_retries=20, base_delay_s=0.5, max_delay_s=10.0 + ) + time_to_wait = rand_backoff.wait_time_s() + if time_to_wait is None: + logger.exception( + "Failed to start a new session due to a repeated " + "connection failure (10 retries)." + ) + raise + await asyncio.sleep(time_to_wait) + # Defensively refresh the blueprint and directory before + # retrying. Maybe we are getting outdated endpoint information + # from AWS. + await self._blueprint_mgr.load() + + async def end_session(self, session_id: SessionId) -> None: + await self._sessions.end_session(session_id) + if self._verbose_logger is not None: + self._verbose_logger.info("Session ended %d", session_id.value()) + + def set_shutdown(self) -> None: + self._shutdown_event.set() + + async def _run_query_impl( + self, + query: str, + vdbe_id: int, + session_id: SessionId, + debug_info: Dict[str, Any], + retrieve_schema: bool = False, + ) -> Tuple[RowList, Optional[Schema]]: + session = self._sessions.get_session(session_id) + if session is None: + raise QueryError( + "Invalid session id {}".format(str(session_id)), is_transient=False + ) + + vdbe = self._vdbe_mgr.engine_by_id(vdbe_id) + if vdbe is None: + raise QueryError( + "Invalid VDBE id {}".format(str(vdbe_id)), is_transient=False + ) + + try: + # Remove any trailing or leading whitespace. Remove the trailing + # semicolon if it exists. + # NOTE: BRAD does not yet support having multiple + # semicolon-separated queries in one request. + query = self._clean_query_str(query) + + # TODO: Validate table accesses. + engine_to_use = vdbe.mapped_to + + log_verbose( + logger, + "[S%d] Routing '%s' to %s", + session_id.value(), + query, + engine_to_use, + ) + debug_info["executor"] = engine_to_use + + # 3. Actually execute the query. + try: + connection = session.engines.get_reader_connection(engine_to_use) + cursor = connection.cursor_sync() + # HACK: To work around dialect differences between + # Athena/Aurora/Redshift for now. This should be replaced by + # a more robust translation layer. + if engine_to_use == Engine.Athena and "ascii" in query: + translated_query = query.replace("ascii", "codepoint") + else: + translated_query = query + start = universal_now() + await cursor.execute(translated_query) + end = universal_now() + except ( + pyodbc.ProgrammingError, + pyodbc.Error, + pyodbc.OperationalError, + redshift_errors.InterfaceError, + ssl.SSLEOFError, # Occurs during Redshift restarts. + IndexError, # Occurs during Redshift restarts. + struct.error, # Occurs during Redshift restarts. + psycopg.Error, + psycopg.OperationalError, + psycopg.ProgrammingError, + ) as ex: + is_transient_error = False + if connection.is_connection_lost_error(ex): + connection.mark_connection_lost() + self._schedule_reestablish_connections() + is_transient_error = True + # N.B. We still pass the error to the client. The client + # should retry the query (later on we can add more graceful + # handling here). + + # Error when executing the query. + raise QueryError.from_exception(ex, is_transient_error) + + # Decide whether to log the query. + run_time_s = end - start + run_time_s_float = run_time_s.total_seconds() + # TODO: Should be per VDBE. + self._query_latency_sketch.add(run_time_s_float) + + # Extract and return the results, if any. + try: + result_row_limit = self._config.result_row_limit() + if result_row_limit is not None: + results = [] + for _ in range(result_row_limit): + row = cursor.fetchone_sync() + if row is None: + break + results.append(tuple(row)) + log_verbose( + logger, + "Responded with %d rows (limited to %d rows).", + len(results), + ) + else: + # Using `fetchall_sync()` is lower overhead than the async interface. + results = [tuple(row) for row in cursor.fetchall_sync()] + log_verbose(logger, "Responded with %d rows.", len(results)) + return ( + results, + (cursor.result_schema(results) if retrieve_schema else None), + ) + except (pyodbc.ProgrammingError, psycopg.ProgrammingError): + log_verbose(logger, "No rows produced.") + return ([], Schema.empty() if retrieve_schema else None) + except ( + pyodbc.Error, + pyodbc.OperationalError, + psycopg.Error, + psycopg.OperationalError, + ) as ex: + is_transient_error = False + if connection.is_connection_lost_error(ex): + connection.mark_connection_lost() + self._schedule_reestablish_connections() + is_transient_error = True + + raise QueryError.from_exception(ex, is_transient_error) + + except QueryError as ex: + # This is an expected exception. We catch and re-raise it here to + # avoid triggering the handler below. + logger.debug("Query error: %s", repr(ex)) + if self._verbose_logger is not None: + if ex.is_transient(): + self._verbose_logger.exception("Transient error") + else: + self._verbose_logger.exception("Non-transient error") + raise + except Exception as ex: + logger.exception("Encountered unexpected exception when handling request.") + raise QueryError.from_exception(ex) + + async def _read_daemon_messages(self) -> None: + assert self._input_queue is not None + loop = asyncio.get_running_loop() + while True: + try: + message = await loop.run_in_executor(None, self._input_queue.get) + if message.fe_index != self.NUMERIC_IDENTIFIER: + logger.warning( + "Received message with invalid front end index. Expected %d. Received %d.", + self.NUMERIC_IDENTIFIER, + message.fe_index, + ) + continue + + if isinstance(message, ShutdownFrontEnd): + logger.debug("The BRAD VDBE front end is initiating a shut down...") + loop.create_task(_orchestrate_shutdown(self)) + break + + elif isinstance(message, NewBlueprint): + logger.info( + "Received notification to update to blueprint version %d", + message.version, + ) + # This refreshes any cached state that depends on the old blueprint. + await self._run_blueprint_update( + message.version, message.updated_directory + ) + # Tell the daemon that we have updated. + self._output_queue.put( + NewBlueprintAck(self.NUMERIC_IDENTIFIER, message.version), + block=False, + ) + logger.info( + "Acknowledged update to blueprint version %d", message.version + ) + + elif isinstance(message, ReconcileVirtualInfrastructure): + self._vdbe_mgr.update_infra(message.virtual_infra) + num_added, num_removed = await self._endpoint_mgr.reconcile() + self._output_queue.put( + ReconcileVirtualInfrastructureAck( + self.NUMERIC_IDENTIFIER, num_added, num_removed + ), + block=False, + ) + logger.info( + "Acknowledged virtual infrastructure update. Added %d and removed %d.", + num_added, + num_removed, + ) + + else: + logger.info("Received message from the daemon: %s", message) + except Exception as ex: + if not isinstance(ex, asyncio.CancelledError): + logger.exception( + "Unexpected error when handling message from the daemon." + ) + + async def _report_metrics_to_daemon(self) -> None: + try: + # We want to stagger the reports across the front ends to avoid + # overwhelming the daemon. + await asyncio.sleep(0.1 * 10) + + while True: + # Ideally we adjust for delays here too. + await asyncio.sleep( + self._config.front_end_metrics_reporting_period_seconds + ) + + # If the input queue is full, we just drop this message. + metrics_report = VdbeMetricsReport.from_data( + self.NUMERIC_IDENTIFIER, + [(0, self._query_latency_sketch)], + ) + self._output_queue.put_nowait(metrics_report) + + query_p90 = self._query_latency_sketch.get_quantile_value(0.9) + if query_p90 is not None: + logger.debug("Query latency p90 (s): %.4f", query_p90) + + self._reset_latency_sketches() + + except Exception as ex: + if not isinstance(ex, asyncio.CancelledError): + # This should be a fatal error. + logger.exception("Unexpected error in the metrics reporting task.") + + def _clean_query_str(self, raw_sql: str) -> str: + sql = raw_sql.strip() + if sql.endswith(";"): + sql = sql[:-1] + return sql.strip() + + async def _ping_watchdog(self) -> None: + try: + while True: + await asyncio.sleep(60.0) # TODO: Hardcoded + self._watchdog.ping() + except Exception as ex: + if not isinstance(ex, asyncio.CancelledError): + logger.exception("Watchdog ping task encountered exception.") + + async def _run_blueprint_update( + self, version: int, updated_directory: Directory + ) -> None: + await self._blueprint_mgr.load(skip_directory_refresh=True) + self._blueprint_mgr.get_directory().update_to_directory(updated_directory) + active_version = self._blueprint_mgr.get_active_blueprint_version() + if version != active_version: + logger.error( + "Retrieved active blueprint version (%d) is not the same as the notified version (%d).", + active_version, + version, + ) + return + + directory = self._blueprint_mgr.get_directory() + logger.info("Loaded new directory: %s", directory) + if self._monitor is not None: + self._monitor.update_metrics_sources() + await self._sessions.add_and_refresh_connections() + # NOTE: This will cause any pending queries on the to-be-removed + # connections to be cancelled. We consider this behavior to be + # acceptable. + await self._sessions.remove_connections() + logger.info("Completed transition to blueprint version %d", version) + + def _schedule_reestablish_connections(self) -> None: + if self._reestablish_connections_task is not None: + return + self._reestablish_connections_task = asyncio.create_task( + self._do_reestablish_connections() + ) + + async def _do_reestablish_connections(self) -> None: + try: + # FIXME: This approach is not ideal because we introduce concurrent + # access to the session manager. + rand_backoff = None + + while True: + if self._verbose_logger is not None: + self._verbose_logger.info( + "Attempting to re-establish lost connections." + ) + + report = await self._sessions.reestablish_connections() + + if self._verbose_logger is not None: + self._verbose_logger.info("%s", str(report)) + + if report.all_succeeded(): + logger.debug("Re-established connections successfully.") + if self._verbose_logger is not None: + self._verbose_logger.debug( + "Re-established connections successfully." + ) + self._reestablish_connections_task = None + break + + if rand_backoff is None: + rand_backoff = RandomizedExponentialBackoff( + max_retries=100, + base_delay_s=1.0, + max_delay_s=timedelta(minutes=1).total_seconds(), + ) + + wait_time = rand_backoff.wait_time_s() + if wait_time is None: + logger.warning( + "Abandoning connection re-establishment due to too many failures" + ) + # N.B. We purposefully do not clear the + # `_reestablish_connections_task` variable. + break + else: + await asyncio.sleep(wait_time) + + # N.B. We should not refresh the blueprint/directory here + # because it can lead to AWS throttling. The directory only + # changes during a blueprint transition; the daemon always + # provides the latest directory to the front end on a + # transition. + + except: # pylint: disable=bare-except + logger.exception("Unexpected failure when reestablishing connections.") + self._reestablish_connections_task = None + + def _reset_latency_sketches(self) -> None: + # TODO: Store per VDBE. + sketch_rel_accuracy = 0.01 + self._query_latency_sketch = DDSketch(relative_accuracy=sketch_rel_accuracy) + + +async def _orchestrate_shutdown(fe: BradVdbeFrontEnd) -> None: + fe.set_shutdown() + tasks = [t for t in asyncio.all_tasks() if t is not asyncio.current_task()] + for task in tasks: + task.cancel() + await asyncio.gather(*tasks, return_exceptions=True) + + # We need to do this a second time because stopping the grpc server(s) + # creates additional tasks that need to be awaited. Unfortunately, their API + # does not return a future so we cannot wait for them in the correct + # shutdown spot. + remaining = [t for t in asyncio.all_tasks() if t is not asyncio.current_task()] + await asyncio.gather(*remaining, return_exceptions=True) + + loop = asyncio.get_event_loop() + loop.stop() diff --git a/src/brad/ui/manager_impl.py b/src/brad/ui/manager_impl.py index 9e068640..4e08181b 100644 --- a/src/brad/ui/manager_impl.py +++ b/src/brad/ui/manager_impl.py @@ -219,7 +219,7 @@ async def get_predicted_changes(args: PredictedChangesArgs) -> DisplayableBluepr @app.post("/api/1/vdbe") -def create_vdbe(engine: CreateVirtualEngineArgs) -> VirtualEngine: +async def create_vdbe(engine: CreateVirtualEngineArgs) -> VirtualEngine: assert manager is not None assert manager.vdbe_mgr is not None @@ -231,11 +231,11 @@ def create_vdbe(engine: CreateVirtualEngineArgs) -> VirtualEngine: if engine.p90_latency_slo_ms <= 0: raise HTTPException(400, "p90_latency_slo_ms must be positive.") - return manager.vdbe_mgr.add_engine(engine) + return await manager.vdbe_mgr.add_engine(engine) @app.put("/api/1/vdbe") -def update_vdbe(engine: VirtualEngine) -> VirtualEngine: +async def update_vdbe(engine: VirtualEngine) -> VirtualEngine: assert manager is not None assert manager.vdbe_mgr is not None @@ -248,18 +248,18 @@ def update_vdbe(engine: VirtualEngine) -> VirtualEngine: raise HTTPException(400, "p90_latency_slo_ms must be positive.") try: - return manager.vdbe_mgr.update_engine(engine) + return await manager.vdbe_mgr.update_engine(engine) except ValueError as ex: raise HTTPException(400, str(ex)) from ex @app.delete("/api/1/vdbe/{engine_id}") -def delete_vdbe(engine_id: int) -> None: +async def delete_vdbe(engine_id: int) -> None: assert manager is not None assert manager.vdbe_mgr is not None try: - manager.vdbe_mgr.delete_engine(engine_id) + await manager.vdbe_mgr.delete_engine(engine_id) except ValueError as ex: raise HTTPException(400, str(ex)) from ex diff --git a/src/brad/vdbe/manager.py b/src/brad/vdbe/manager.py index 33f5170a..4b316137 100644 --- a/src/brad/vdbe/manager.py +++ b/src/brad/vdbe/manager.py @@ -1,5 +1,5 @@ import pathlib -from typing import List, Optional +from typing import List, Optional, Callable, Awaitable from brad.vdbe.models import ( VirtualInfrastructure, VirtualEngine, @@ -14,18 +14,26 @@ class VdbeManager: @classmethod def load_from( - cls, serialized_infra_json: pathlib.Path, starting_port: int + cls, + serialized_infra_json: pathlib.Path, + starting_port: int, + apply_infra: Callable[[VirtualInfrastructure], Awaitable[None]], ) -> "VdbeManager": with open(serialized_infra_json, "r", encoding="utf-8") as f: infra = VirtualInfrastructure.model_validate_json(f.read()) hostname = _get_hostname() - return cls(infra, hostname, starting_port) + return cls(infra, hostname, starting_port, apply_infra) def __init__( - self, infra: VirtualInfrastructure, hostname: Optional[str], starting_port: int + self, + infra: VirtualInfrastructure, + hostname: Optional[str], + starting_port: int, + apply_infra: Callable[[VirtualInfrastructure], Awaitable[None]], ) -> None: self._infra = infra self._hostname = hostname + self._apply_infra = apply_infra self._next_port = starting_port self._next_id = 1 for engine in self._infra.engines: @@ -43,7 +51,7 @@ def infra(self) -> VirtualInfrastructure: def engines(self) -> List[VirtualEngine]: return self._infra.engines - def add_engine(self, create: CreateVirtualEngineArgs) -> VirtualEngine: + async def add_engine(self, create: CreateVirtualEngineArgs) -> VirtualEngine: engine = VirtualEngine( internal_id=self._next_id, name=create.name, @@ -60,22 +68,25 @@ def add_engine(self, create: CreateVirtualEngineArgs) -> VirtualEngine: engine.endpoint = f"{self._hostname}:{self._assign_port()}" self._infra.engines.append(engine) + await self._apply_infra(self._infra) return engine - def update_engine(self, engine: VirtualEngine) -> VirtualEngine: + async def update_engine(self, engine: VirtualEngine) -> VirtualEngine: if engine.endpoint is None and self._hostname is not None: engine.endpoint = f"{self._hostname}:{self._assign_port()}" for i in range(len(self._infra.engines)): if self._infra.engines[i].internal_id == engine.internal_id: self._infra.engines[i] = engine + await self._apply_infra(self._infra) return engine raise ValueError(f"Engine with id {engine.internal_id} not found") - def delete_engine(self, engine_id: int) -> None: + async def delete_engine(self, engine_id: int) -> None: for engine in self._infra.engines: if engine.internal_id == engine_id: self._infra.engines.remove(engine) + await self._apply_infra(self._infra) return raise ValueError(f"Engine with id {engine_id} not found") @@ -85,6 +96,27 @@ def _assign_port(self) -> int: return port +class VdbeFrontEndManager: + """ + Used on the front end. Provides a read-only view of the current VDBE state. + """ + + def __init__(self, initial_infra: VirtualInfrastructure) -> None: + self._infra = initial_infra + + def engines(self) -> List[VirtualEngine]: + return self._infra.engines + + def engine_by_id(self, engine_id: int) -> Optional[VirtualEngine]: + for engine in self._infra.engines: + if engine.internal_id == engine_id: + return engine + return None + + def update_infra(self, new_infra: VirtualInfrastructure) -> None: + self._infra = new_infra + + def _get_hostname() -> str: import socket