diff --git a/assemblyline_core/scaler/controllers/docker_ctl.py b/assemblyline_core/scaler/controllers/docker_ctl.py
index e3594c8e..b6892376 100644
--- a/assemblyline_core/scaler/controllers/docker_ctl.py
+++ b/assemblyline_core/scaler/controllers/docker_ctl.py
@@ -243,7 +243,7 @@ def memory_info(self):
         self.log.debug(f'Total Memory available {mem}/{self._info["MemTotal"]/mega}')
         return mem, total_mem

-    def get_target(self, service_name):
+    def get_target(self, service_name: str) -> int:
         """Get how many instances of a service we expect to be running.

         Since we start our containers with 'restart always' we just need to count how many
diff --git a/assemblyline_core/scaler/controllers/interface.py b/assemblyline_core/scaler/controllers/interface.py
index 5f22c6c5..b3222dcb 100644
--- a/assemblyline_core/scaler/controllers/interface.py
+++ b/assemblyline_core/scaler/controllers/interface.py
@@ -24,11 +24,11 @@ def cpu_info(self):
         """Return free and total memory in the system."""
         raise NotImplementedError()

-    def free_cpu(self):
+    def free_cpu(self) -> float:
         """Number of cores available for reservation."""
         return self.cpu_info()[0]

-    def free_memory(self):
+    def free_memory(self) -> float:
         """Megabytes of RAM that has not been reserved."""
         return self.memory_info()[0]

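The Docker controller's `get_target` keeps no state of its own: with `restart: always` containers, counting what is running is the same as counting what is wanted. A minimal sketch of that counting approach, assuming the docker SDK and an illustrative `assemblyline_service` label — the label name is hypothetical, not necessarily what the controller actually uses:

```python
# Hypothetical sketch of counting running containers per service via labels.
import docker


def get_target(service_name: str) -> int:
    """With 'restart always' containers, the number of labelled containers
    running is the number of instances we expect."""
    client = docker.from_env()
    # filters={'label': 'key=value'} narrows the listing to one service
    containers = client.containers.list(filters={'label': f'assemblyline_service={service_name}'})
    return len(containers)
```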
diff --git a/assemblyline_core/scaler/controllers/kubernetes_ctl.py b/assemblyline_core/scaler/controllers/kubernetes_ctl.py
index e3c06dbe..06541db1 100644
--- a/assemblyline_core/scaler/controllers/kubernetes_ctl.py
+++ b/assemblyline_core/scaler/controllers/kubernetes_ctl.py
@@ -7,7 +7,7 @@
 import os
 import threading
 import weakref
-from typing import Dict, List, Optional, Tuple
+from typing import Optional, Tuple

 import urllib3
 import kubernetes
@@ -48,7 +48,7 @@ def get_return_type(self, func):
         return None


-def median(values: List[float]) -> float:
+def median(values: list[float]) -> float:
     if len(values) == 0:
         return 0
     return values[len(values)//2]
@@ -149,15 +149,15 @@ def __init__(self, logger, namespace, prefix, priority, cpu_reservation, labels=
         self.cpu_reservation: float = max(0.0, min(cpu_reservation, 1.0))
         self.logger = logger
         self.log_level: str = log_level
-        self._labels: Dict[str, str] = labels or {}
+        self._labels: dict[str, str] = labels or {}
         self.apps_api = client.AppsV1Api()
         self.api = client.CoreV1Api()
         self.net_api = client.NetworkingV1Api()
         self.namespace: str = namespace
-        self.config_volumes: Dict[str, V1Volume] = {}
-        self.config_mounts: Dict[str, V1VolumeMount] = {}
-        self.core_config_volumes: Dict[str, V1Volume] = {}
-        self.core_config_mounts: Dict[str, V1VolumeMount] = {}
+        self.config_volumes: dict[str, V1Volume] = {}
+        self.config_mounts: dict[str, V1VolumeMount] = {}
+        self.core_config_volumes: dict[str, V1Volume] = {}
+        self.core_config_mounts: dict[str, V1VolumeMount] = {}
         self._external_profiles = weakref.WeakValueDictionary()
         self._service_limited_env: dict[str, dict[str, str]] = defaultdict(dict)
@@ -191,7 +191,7 @@ def __init__(self, logger, namespace, prefix, priority, cpu_reservation, labels=
         pod_background = threading.Thread(target=self._loop_forever(self._monitor_pods), daemon=True)
         pod_background.start()

-        self._deployment_targets: Dict[str, int] = {}
+        self._deployment_targets: dict[str, int] = {}
         deployment_background = threading.Thread(target=self._loop_forever(self._monitor_deployments), daemon=True)
         deployment_background.start()
@@ -434,7 +434,7 @@ def memory_info(self):
         return self._node_pool_max_ram - self._pod_used_ram, self._node_pool_max_ram

     @staticmethod
-    def _create_metadata(deployment_name: str, labels: Dict[str, str]):
+    def _create_metadata(deployment_name: str, labels: dict[str, str]):
         return V1ObjectMeta(name=deployment_name, labels=labels)

     def _create_volumes(self, core_mounts=False):
@@ -585,7 +585,7 @@ def get_target(self, service_name: str) -> int:
         """Get the target for running instances of a service."""
         return self._deployment_targets.get(service_name, 0)

-    def get_targets(self) -> Dict[str, int]:
+    def get_targets(self) -> dict[str, int]:
         """Get the target for running instances of all services."""
         return self._deployment_targets
@@ -674,8 +674,20 @@ def start_stateful_container(self, service_name: str, container_name: str,
             ))
             mounts.append(V1VolumeMount(mount_path=volume_spec.mount_path, name=mount_name))

+        # Read the key being used for the deployment instance or generate a new one
+        try:
+            instance_key = uuid.uuid4().hex
+            old_deployment = self.apps_api.read_namespaced_deployment(deployment_name, self.namespace)
+            for container in old_deployment.spec.template.spec.containers:
+                for env in container.env:
+                    if env.name == 'AL_INSTANCE_KEY':
+                        instance_key = env.value
+                        break
+        except ApiException as error:
+            if error.status != 404:
+                raise
+
         # Setup the deployment itself
-        instance_key = uuid.uuid4().hex
         labels['container'] = container_name
         spec.container.environment.append({'name': 'AL_INSTANCE_KEY', 'value': instance_key})
         self._create_deployment(service_name, deployment_name, spec.container,
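One caveat about the `median` helper whose annotation changed above: it indexes the midpoint of the list as given, without sorting, so it only yields a true median when callers pass values that are already ordered. For unsorted samples the standard library computes it directly:

```python
import statistics

values = [5.0, 1.0, 3.0]
print(statistics.median(values))   # 3.0 -- the true median
print(values[len(values) // 2])    # 1.0 -- just the middle element as stored
```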
""" - def __init__(self, size=10): + def __init__(self, size: int = 10): self.pool = concurrent.futures.ThreadPoolExecutor(size) - self.futures = [] + self.futures: list[concurrent.futures.Future[Any]] = [] def __enter__(self): return self @@ -106,8 +109,8 @@ class ServiceProfile: This includes how the service should be run, and conditions related to the scaling of the service. """ - def __init__(self, name, container_config: DockerConfig, config_hash=0, min_instances=0, max_instances=None, - growth: float = 600, shrink: Optional[float] = None, backlog=500, queue=None, shutdown_seconds=30): + def __init__(self, name: str, container_config: DockerConfig, config_hash:int=0, min_instances:int=0, max_instances:int=None, + growth: float = 600, shrink: Optional[float] = None, backlog:int=500, queue=None, shutdown_seconds:int=30): """ :param name: Name of the service to manage :param container_config: Instructions on how to start this service @@ -127,8 +130,9 @@ def __init__(self, name, container_config: DockerConfig, config_hash=0, min_inst self.config_hash = config_hash # How many instances we want, and can have - self.min_instances = self._min_instances = max(0, int(min_instances)) - self._max_instances = max(0, int(max_instances)) if max_instances else float('inf') + self.min_instances: int = max(0, int(min_instances)) + self._min_instances: int = self.min_instances + self._max_instances: float = max(0, int(max_instances)) if max_instances else float('inf') self.desired_instances: int = 0 self.target_instances: int = 0 self.running_instances: int = 0 @@ -160,12 +164,12 @@ def instance_limit(self): return self._max_instances @property - def max_instances(self): + def max_instances(self) -> int: # Adjust the max_instances based on the number that is already requested # this keeps the scaler from running way ahead with its demands when resource caps are reached return min(self._max_instances, self.target_instances + 2) - def update(self, delta, instances, backlog, duty_cycle): + def update(self, delta: float, instances: int, backlog: int, duty_cycle: float): self.last_update = time.time() self.running_instances = instances self.queue_length = backlog @@ -231,8 +235,10 @@ def __init__(self, config=None, datastore=None, redis=None, redis_persist=None): self.scaler_timeout_queue = NamedQueue(SCALER_TIMEOUT_QUEUE, host=self.redis_persist) self.error_count_lock = threading.Lock() - self.error_count: Dict[str, List[float]] = {} + self.error_count: dict[str, list[float]] = {} self.status_table = ExpiringHash(SERVICE_STATE_HASH, host=self.redis, ttl=30*60) + self.service_change_watcher = EventWatcher(self.redis, deserializer=ServiceChange.deserialize) + self.service_change_watcher.register('changes.services.*', self._handle_service_change_event) labels = { 'app': 'assemblyline', @@ -268,7 +274,7 @@ def __init__(self, config=None, datastore=None, redis=None, redis_persist=None): self.controller.global_mounts.append((CLASSIFICATION_HOST_PATH, '/etc/assemblyline/classification.yml')) # Information about services - self.profiles: Dict[str, ServiceProfile] = {} + self.profiles: dict[str, ServiceProfile] = {} self.profiles_lock = threading.RLock() # Prepare a single threaded scheduler @@ -311,6 +317,7 @@ def add_service(self, profile: ServiceProfile): self.controller.add_profile(profile, scale=profile.desired_instances) def try_run(self): + self.service_change_watcher.start() self.maintain_threads({ 'Log Container Events': self.log_container_events, 'Process Timeouts': self.process_timeouts, @@ -322,102 
@@ -322,102 +329,28 @@ def try_run(self):

     def stop(self):
         super().stop()
+        self.service_change_watcher.stop()
         self.controller.stop()

+    def _handle_service_change_event(self, data: ServiceChange):
+        if data.operation == Operation.Removed:
+            self.log.info(f'Service appears to be deleted, removing {data.name}')
+            stage = self.get_service_stage(data.name)
+            self.stop_service(data.name, stage)
+        else:
+            self._sync_service(self.datastore.get_service_with_delta(data.name))
+
     def sync_services(self):
         while self.running:
             with apm_span(self.apm_client, 'sync_services'):
-                default_settings = self.config.core.scaler.service_defaults
-                image_variables = defaultdict(str)
-                image_variables.update(self.config.services.image_variables)
-
                 with self.profiles_lock:
                     current_services = set(self.profiles.keys())
-                discovered_services = []
+                discovered_services: list[str] = []

                 # Get all the service data
                 for service in self.datastore.list_all_services(full=True):
-                    service: Service = service
-                    name = service.name
-                    stage = self.get_service_stage(service.name)
-                    discovered_services.append(name)
-
-                    # noinspection PyBroadException
-                    try:
-                        if service.enabled and stage == ServiceStage.Off:
-                            # Enable this service's dependencies
-                            self.controller.prepare_network(service.name, service.docker_config.allow_internet_access)
-                            for _n, dependency in service.dependencies.items():
-                                dependency.container.image = Template(dependency.container.image) \
-                                    .safe_substitute(image_variables)
-                                self.controller.start_stateful_container(
-                                    service_name=service.name,
-                                    container_name=_n,
-                                    spec=dependency,
-                                    labels={'dependency_for': service.name},
-                                )
-
-                            # Move to the next service stage
-                            if service.update_config and service.update_config.wait_for_update:
-                                self._service_stage_hash.set(name, ServiceStage.Update)
-                            else:
-                                self._service_stage_hash.set(name, ServiceStage.Running)
-
-                        if not service.enabled:
-                            self.stop_service(service.name, stage)
-                            continue
-
-                        # Check that all enabled services are enabled
-                        if service.enabled and stage == ServiceStage.Running:
-                            # Compute a hash of service properties not include in the docker config, that
-                            # should still result in a service being restarted when changed
-                            config_hash = hash(str(sorted(service.config.items())))
-                            config_hash = hash((config_hash, str(service.submission_params)))
-
-                            # Build the docker config for the service, we are going to either create it or
-                            # update it so we need to know what the current configuration is either way
-                            docker_config = service.docker_config
-                            docker_config.image = Template(docker_config.image).safe_substitute(image_variables)
-                            set_keys = set(var.name for var in docker_config.environment)
-                            for var in default_settings.environment:
-                                if var.name not in set_keys:
-                                    docker_config.environment.append(var)
-
-                            # Add the service to the list of services being scaled
-                            with self.profiles_lock:
-                                if name not in self.profiles:
-                                    self.log.info(f'Adding {service.name} to scaling')
-                                    self.add_service(ServiceProfile(
-                                        name=name,
-                                        min_instances=default_settings.min_instances,
-                                        growth=default_settings.growth,
-                                        shrink=default_settings.shrink,
-                                        config_hash=config_hash,
-                                        backlog=default_settings.backlog,
-                                        max_instances=service.licence_count,
-                                        container_config=docker_config,
-                                        queue=get_service_queue(name, self.redis),
-                                        # Give service an extra 30 seconds to upload results
-                                        shutdown_seconds=service.timeout + 30,
-                                    ))
-
-                                # Update RAM, CPU, licence requirements for running services
-                                else:
-                                    profile = self.profiles[name]
-                                    if service.licence_count == 0:
-                                        profile._max_instances = float('inf')
-                                    else:
-                                        profile._max_instances = service.licence_count
-
-                                    if profile.container_config != docker_config or profile.config_hash != config_hash:
-                                        self.log.info(f"Updating deployment information for {name}")
-                                        profile.container_config = docker_config
-                                        profile.config_hash = config_hash
-                                        self.controller.restart(profile)
-                                        self.log.info(f"Deployment information for {name} replaced")
-
-                    except Exception:
-                        self.log.exception(f"Error applying service settings from: {service.name}")
-                        self.handle_service_error(service.name)
+                    self._sync_service(service)
+                    discovered_services.append(service.name)

                 # Find any services we have running, that are no longer in the database and remove them
                 for stray_service in current_services - set(discovered_services):
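The stray-service cleanup at the end of `sync_services` is plain set arithmetic over service names; a worked example with hypothetical service names:

```python
# Services we currently manage locally vs. what the database just reported
current_services = {'Extract', 'PDFId', 'OldService'}
discovered_services = ['Extract', 'PDFId']

strays = current_services - set(discovered_services)
print(strays)  # {'OldService'}: still running here, but deleted from the database
```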
@@ -427,8 +360,95 @@ def sync_services(self):

             self.sleep(SERVICE_SYNC_INTERVAL)

+    def _sync_service(self, service: Service):
+        name = service.name
+        stage = self.get_service_stage(service.name)
+        default_settings = self.config.core.scaler.service_defaults
+        image_variables: defaultdict[str, str] = defaultdict(str)
+        image_variables.update(self.config.services.image_variables)
+
+        def prepare_container(docker_config: DockerConfig) -> DockerConfig:
+            docker_config.image = Template(docker_config.image).safe_substitute(image_variables)
+            set_keys = set(var.name for var in docker_config.environment)
+            for var in default_settings.environment:
+                if var.name not in set_keys:
+                    docker_config.environment.append(var)
+            return docker_config
+
+        # noinspection PyBroadException
+        try:
+            if service.enabled and (stage == ServiceStage.Off or name not in self.profiles):
+                # Enable this service's dependencies
+                self.controller.prepare_network(service.name, service.docker_config.allow_internet_access)
+                for _n, dependency in service.dependencies.items():
+                    dependency.container = prepare_container(dependency.container)
+                    self.controller.start_stateful_container(
+                        service_name=service.name,
+                        container_name=_n,
+                        spec=dependency,
+                        labels={'dependency_for': service.name}
+                    )
+
+                # Move to the next service stage
+                if service.update_config and service.update_config.wait_for_update:
+                    self._service_stage_hash.set(name, ServiceStage.Update)
+                else:
+                    self._service_stage_hash.set(name, ServiceStage.Running)
+
+            if not service.enabled:
+                self.stop_service(service.name, stage)
+                return
+
+            # Check that all enabled services are properly configured
+            if service.enabled and stage == ServiceStage.Running:
+                # Compute a hash of service properties not included in the docker config, that
+                # should still result in a service being restarted when changed
+                config_hash = hash(str(sorted(service.config.items())))
+                config_hash = hash((config_hash, str(service.submission_params)))
+
+                # Build the docker config for the service, we are going to either create it or
+                # update it so we need to know what the current configuration is either way
+                docker_config = prepare_container(service.docker_config)
+
+                # Add the service to the list of services being scaled
+                with self.profiles_lock:
+                    if name not in self.profiles:
+                        self.log.info(f'Adding {service.name} to scaling')
+                        self.add_service(ServiceProfile(
+                            name=name,
+                            min_instances=default_settings.min_instances,
+                            growth=default_settings.growth,
+                            shrink=default_settings.shrink,
+                            config_hash=config_hash,
+                            backlog=default_settings.backlog,
+                            max_instances=service.licence_count,
+                            container_config=docker_config,
+                            queue=get_service_queue(name, self.redis),
+                            # Give service an extra 30 seconds to upload results
+                            shutdown_seconds=service.timeout + 30
+                        ))
+
+                    # Update RAM, CPU, licence requirements for running services
+                    else:
+                        profile = self.profiles[name]
+                        if service.licence_count == 0:
+                            profile._max_instances = float('inf')
+                        else:
+                            profile._max_instances = service.licence_count
+
+                        if profile.container_config != docker_config or profile.config_hash != config_hash:
+                            self.log.info(f"Updating deployment information for {name}")
+                            profile.container_config = docker_config
+                            profile.config_hash = config_hash
+                            self.controller.restart(profile)
+                            self.log.info(f"Deployment information for {name} replaced")
+
+        except Exception:
+            self.log.exception(f"Error applying service settings from: {service.name}")
+            self.handle_service_error(service.name)
+
     @elasticapm.capture_span(span_type=APM_SPAN_TYPE)
-    def stop_service(self, name, current_stage):
+    def stop_service(self, name: str, current_stage: ServiceStage):
         if current_stage != ServiceStage.Off:
             # Disable this service's dependencies
             self.controller.stop_containers(labels={
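`prepare_container` applies the image variables with `string.Template.safe_substitute`. Because `image_variables` is a `defaultdict(str)`, placeholders with no configured value collapse to an empty string rather than being left in place, which is worth knowing when debugging image names. The `${REGISTRY}` variable below is illustrative, not one Assemblyline defines:

```python
from collections import defaultdict
from string import Template

image_variables = defaultdict(str, {'REGISTRY': 'registry.example.com/'})

print(Template('${REGISTRY}assemblyline-service-extract:stable').safe_substitute(image_variables))
# registry.example.com/assemblyline-service-extract:stable

# An unconfigured placeholder disappears instead of raising, because
# defaultdict(str) supplies '' for missing keys:
print(Template('${UNSET}assemblyline-service-extract:stable').safe_substitute(image_variables))
# assemblyline-service-extract:stable
```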
@@ -453,7 +473,7 @@ def update_scaling(self):
         # Figure out what services are expected to be running and how many
         with elasticapm.capture_span('read_profiles'):
             with self.profiles_lock:
-                all_profiles: Dict[str, ServiceProfile] = copy.deepcopy(self.profiles)
+                all_profiles: dict[str, ServiceProfile] = copy.deepcopy(self.profiles)
             raw_targets = self.controller.get_targets()
             targets = {_p.name: raw_targets.get(_p.name, 0) for _p in all_profiles.values()}
@@ -496,7 +516,7 @@ def update_scaling(self):
                     free_memory = self.controller.free_memory()

                 #
-                def trim(prof: List[ServiceProfile]):
+                def trim(prof: list[ServiceProfile]):
                     prof = [_p for _p in prof if _p.desired_instances > targets[_p.name]]
                     drop = [_p for _p in prof if _p.cpu > free_cpu or _p.ram > free_memory]
                     if drop:
@@ -505,7 +525,7 @@ def trim(prof: List[ServiceProfile]):
                     prof = [_p for _p in prof if _p.cpu <= free_cpu and _p.ram <= free_memory]
                     return prof

-                remaining_profiles: List[ServiceProfile] = trim(list(all_profiles.values()))
+                remaining_profiles: list[ServiceProfile] = trim(list(all_profiles.values()))
                 # The target values up until now should be in sync with the container orchestrator
                 # create a copy, so we can track which ones change in the following loop
                 old_targets = dict(targets)
@@ -533,7 +553,7 @@ def trim(prof: List[ServiceProfile]):
                         pool.call(self.controller.set_target, name, value)

     @elasticapm.capture_span(span_type=APM_SPAN_TYPE)
-    def handle_service_error(self, service_name):
+    def handle_service_error(self, service_name: str):
         """Handle an error occurring in the *analysis* service.

         Errors for core systems should simply be logged, and a best effort to continue made.
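The `trim` helper is the allocation guard in `update_scaling`: only services that want to grow are considered, and anything whose per-instance cost no longer fits in the free CPU/RAM is dropped before targets are raised. A self-contained sketch with stand-in profile objects (field names mirror the `ServiceProfile` attributes used above):

```python
from dataclasses import dataclass


@dataclass
class Prof:
    name: str
    desired_instances: int
    cpu: float   # cores per instance
    ram: float   # megabytes per instance


targets = {'big': 1, 'idle': 0}
free_cpu, free_memory = 2.0, 1024.0


def trim(prof: list[Prof]) -> list[Prof]:
    # Keep services that want more instances than currently targeted...
    prof = [_p for _p in prof if _p.desired_instances > targets[_p.name]]
    # ...and whose next instance still fits in the free resources
    return [_p for _p in prof if _p.cpu <= free_cpu and _p.ram <= free_memory]


profiles = [Prof('big', 3, 1.0, 512.0), Prof('idle', 0, 0.5, 256.0)]
print([p.name for p in trim(profiles)])  # ['big']
```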
diff --git a/assemblyline_core/server_base.py b/assemblyline_core/server_base.py
index 90d93318..cb9af0c6 100644
--- a/assemblyline_core/server_base.py
+++ b/assemblyline_core/server_base.py
@@ -21,6 +21,7 @@

 if TYPE_CHECKING:
     from assemblyline.datastore.helper import AssemblylineDatastore
+    from assemblyline.odm.models.config import Config

 SHUTDOWN_SECONDS_LIMIT = 10
@@ -40,7 +41,7 @@ def __init__(self, component_name: str, logger: logging.Logger = None,
                  shutdown_timeout: float = SHUTDOWN_SECONDS_LIMIT, config=None):
         super().__init__(name=component_name)
         al_log.init_logging(component_name)
-        self.config = config or forge.get_config()
+        self.config: Config = config or forge.get_config()

         self.running = None
         self.stopping = threading.Event()
@@ -232,6 +233,12 @@ def stop(self):
         super().stop()
         self.main_loop_exit.wait(30)

+    def sleep(self, timeout: float):
+        self.stopping.wait(timeout)
+        return self.running
+
+
 def log_crashes(fn):
     @functools.wraps(fn)
     def with_logs(*args, **kwargs):
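The new `ServerBase.sleep` waits on the `stopping` event instead of calling `time.sleep`, so even the half-hour sync interval above never delays shutdown, and the return value reports whether the server is still running. A stand-in class demonstrating the pattern:

```python
import threading


class Worker:
    """Minimal stand-in for the interruptible-sleep pattern ServerBase.sleep enables."""

    def __init__(self):
        self.stopping = threading.Event()
        self.running = True

    def sleep(self, timeout: float) -> bool:
        # Event.wait returns as soon as stop() sets the flag, or after timeout
        self.stopping.wait(timeout)
        return self.running

    def stop(self):
        self.running = False
        self.stopping.set()


w = Worker()
threading.Timer(0.1, w.stop).start()
print(w.sleep(3600))  # False, ~0.1s later, instead of blocking for an hour
```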