Merge pull request #298 from CybercentreCanada/persistent-service-update
Persistent service update
cccs-douglass committed Sep 16, 2021
2 parents af5f2d2 + 8ff7f2b commit bd00c2f
Showing 7 changed files with 74 additions and 70 deletions.
23 changes: 12 additions & 11 deletions assemblyline_core/dispatching/dispatcher.py
@@ -1,3 +1,4 @@
+from __future__ import annotations
import uuid
import os
import threading
@@ -57,9 +58,9 @@ class Action(enum.IntEnum):
class DispatchAction:
kind: Action
sid: str = dataclasses.field(compare=False)
-sha: str = dataclasses.field(compare=False, default=None)
-service_name: str = dataclasses.field(compare=False, default=None)
-worker_id: str = dataclasses.field(compare=False, default=None)
+sha: Optional[str] = dataclasses.field(compare=False, default=None)
+service_name: Optional[str] = dataclasses.field(compare=False, default=None)
+worker_id: Optional[str] = dataclasses.field(compare=False, default=None)
data: Any = dataclasses.field(compare=False, default=None)


@@ -103,7 +104,7 @@ def __init__(self, submission, completed_queue):
self.dropped_files = set()

self.service_results: Dict[Tuple[str, str], ResultSummary] = {}
-self.service_errors: Dict[Tuple[str, str], dict] = {}
+self.service_errors: Dict[Tuple[str, str], str] = {}
self.service_attempts: Dict[Tuple[str, str], int] = defaultdict(int)
self.queue_keys: Dict[Tuple[str, str], bytes] = {}
self.running_services: Set[Tuple[str, str]] = set()
@@ -198,8 +199,8 @@ def __init__(self, datastore=None, redis=None, redis_persist=None, logger=None,
self.apm_client = elasticapm.Client(server_url=self.config.core.metrics.apm_server.server_url,
service_name="dispatcher")

-self._service_timeouts = TimeoutTable()
-self._submission_timeouts = TimeoutTable()
+self._service_timeouts: TimeoutTable[Tuple[str, str, str], str] = TimeoutTable()
+self._submission_timeouts: TimeoutTable[str, None] = TimeoutTable()

# Setup queues for work to be divided into
self.process_queues: List[PriorityQueue[DispatchAction]] = [PriorityQueue() for _ in range(RESULT_THREADS)]
@@ -404,7 +405,7 @@ def dispatch_file(self, task: SubmissionTask, sha256: str) -> bool:

# Go through each round of the schedule removing complete/failed services
# Break when we find a stage that still needs processing
-outstanding = {}
+outstanding: dict[str, Service] = {}
started_stages = []
with elasticapm.capture_span('check_result_table'):
while schedule and not outstanding:
@@ -1156,8 +1157,8 @@ def handle_commands(self):

command = DispatcherCommandMessage(message)
if command.kind == CREATE_WATCH:
-payload: CreateWatch = command.payload()
-self.setup_watch_queue(payload.submission, payload.queue_name)
+watch_payload: CreateWatch = command.payload()
+self.setup_watch_queue(watch_payload.submission, watch_payload.queue_name)
elif command.kind == LIST_OUTSTANDING:
payload: ListOutstanding = command.payload()
self.list_outstanding(payload.submission, payload.response_queue)
@@ -1192,7 +1193,7 @@ def setup_watch_queue(self, sid, queue_name):
@elasticapm.capture_span(span_type='dispatcher')
def list_outstanding(self, sid: str, queue_name: str):
response_queue = NamedQueue(queue_name, host=self.redis)
-outstanding = defaultdict(int)
+outstanding: defaultdict[str, int] = defaultdict(int)
task = self.tasks.get(sid)
if task:
for sha, service_name in list(task.queue_keys.keys()):
@@ -1290,7 +1291,7 @@ def timeout_backstop(self):

def recover_submission(self, sid: str, message: str) -> bool:
# Make sure we can load the submission body
-submission: Submission = self.datastore.submission.get_if_exists(sid)
+submission: Optional[Submission] = self.datastore.submission.get_if_exists(sid)
if not submission:
return False
if submission.state != 'submitted':
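Most of the changes above tighten annotations rather than behaviour: fields that default to None become Optional, the two timeout tables gain type parameters, and the second payload variable is renamed so each name keeps a single type. A minimal sketch of the Optional rule (the class and field names here are illustrative, not taken from the commit):

    from __future__ import annotations

    import dataclasses
    from typing import Optional


    @dataclasses.dataclass(order=True)
    class QueuedAction:
        # Only `priority` takes part in comparisons, so a priority queue can
        # order these actions without inspecting the metadata fields below.
        priority: int
        sid: str = dataclasses.field(compare=False, default='')
        # A None default needs an annotation that admits None; a bare `str`
        # here is exactly what strict type checkers flag as an error.
        sha: Optional[str] = dataclasses.field(compare=False, default=None)


    action = QueuedAction(priority=1, sid='example-sid')
    assert action.sha is None  # unset until the action targets a specific file
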
15 changes: 8 additions & 7 deletions assemblyline_core/dispatching/schedules.py
@@ -1,4 +1,5 @@
-from typing import List, Dict, cast
+from __future__ import annotations
+from typing import Dict, cast

import logging
import os
@@ -26,7 +27,7 @@ def __init__(self, datastore: AssemblylineDatastore, config: Config, redis):
self.services = cast(Dict[str, Service], CachedObject(self._get_services))
self.service_stage = get_service_stage_hash(redis)

-def build_schedule(self, submission: Submission, file_type: str) -> List[Dict[str, Service]]:
+def build_schedule(self, submission: Submission, file_type: str) -> list[dict[str, Service]]:
all_services = dict(self.services)

# Load the selected and excluded services by category
@@ -38,7 +39,7 @@ def build_schedule(self, submission: Submission, file_type: str) -> List[Dict[str, Service]]:
selected = self.expand_categories(submission.params.services.selected)

# Add all selected, accepted, and not rejected services to the schedule
-schedule: List[Dict[str, Service]] = [{} for _ in self.config.services.stages]
+schedule: list[dict[str, Service]] = [{} for _ in self.config.services.stages]
services = list(set(selected) - set(excluded) - set(runtime_excluded))
selected = []
skipped = []
@@ -61,7 +62,7 @@ def build_schedule(self, submission: Submission, file_type: str) -> List[Dict[str, Service]]:

return schedule

-def expand_categories(self, services: List[str]) -> List[str]:
+def expand_categories(self, services: list[str]) -> list[str]:
"""Expands the names of service categories found in the list of services.
Args:
@@ -74,7 +75,7 @@ def expand_categories(self, services: List[str]) -> List[str]:
categories = self.categories()

found_services = []
-seen_categories = set()
+seen_categories: set[str] = set()
while services:
name = services.pop()

@@ -94,8 +95,8 @@ def expand_categories(self, services: List[str]) -> List[str]:
# Use set to remove duplicates, set is more efficient in batches
return list(set(found_services))

-def categories(self) -> Dict[str, List[str]]:
-all_categories = {}
+def categories(self) -> Dict[str, list[str]]:
+all_categories: dict[str, list[str]] = {}
for service in self.services.values():
try:
all_categories[service.category].append(service.name)
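The recurring pattern in this file is PEP 585 builtin generics (list[...], dict[...]) made safe by the postponed-evaluation import: once annotations are plain strings they are never evaluated, so the lowercase forms parse even on interpreters that predate builtin-generic subscripting. A short sketch of why typing.Dict still has to be imported (the data here is invented; only the idiom comes from the diff):

    from __future__ import annotations  # PEP 563: annotations become strings

    from typing import Dict, cast


    def expand_categories(groups: dict[str, list[str]], names: list[str]) -> list[str]:
        # Lowercase generics are fine in annotations under the future import,
        # even on Python 3.7/3.8 where dict[str, ...] would fail at runtime.
        found: list[str] = []
        for name in names:
            found.extend(groups.get(name, [name]))
        return sorted(set(found))


    # cast() evaluates its first argument at runtime, so that expression must
    # be a real typing object on older interpreters -- which is why `Dict`
    # survives the import cleanup.
    services = cast(Dict[str, object], {'Extract': object()})
    print(expand_categories({'Static': ['Extract']}, ['Static', 'AntiVirus']))
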
13 changes: 7 additions & 6 deletions assemblyline_core/dispatching/timeout.py
@@ -1,27 +1,28 @@
"""
A data structure encapsulating the timeout logic for the dispatcher.
"""
+from __future__ import annotations
import queue
import time
from queue import PriorityQueue
from dataclasses import dataclass, field
-from typing import TypeVar, Dict
+from typing import TypeVar, Generic, Hashable

-KeyType = TypeVar('KeyType')
+KeyType = TypeVar('KeyType', bound=Hashable)
DataType = TypeVar('DataType')


@dataclass(order=True)
-class TimeoutItem:
+class TimeoutItem(Generic[KeyType, DataType]):
expiry: float
key: KeyType = field(compare=False)
data: DataType = field(compare=False)


-class TimeoutTable:
+class TimeoutTable(Generic[KeyType, DataType]):
def __init__(self):
self.timeout_queue: PriorityQueue[TimeoutItem] = PriorityQueue()
-self.event_data: Dict[KeyType, TimeoutItem] = {}
+self.event_data: dict[KeyType, TimeoutItem] = {}

def set(self, key: KeyType, timeout: float, data: DataType):
# If a timeout is set repeatedly with the same key, only the last one will count
@@ -37,7 +38,7 @@ def clear(self, key: KeyType):
def __contains__(self, item):
return item in self.event_data

-def timeouts(self) -> Dict[KeyType, DataType]:
+def timeouts(self) -> dict[KeyType, DataType]:
found = {}
try:
now = time.time()
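Making TimeoutItem and TimeoutTable generic lets each owner pin down its key and payload types, which is what the dispatcher now does with TimeoutTable[Tuple[str, str, str], str] and TimeoutTable[str, None]. A condensed sketch of the idiom (a trimmed-down table, not the full implementation):

    from __future__ import annotations

    import time
    from dataclasses import dataclass, field
    from queue import PriorityQueue
    from typing import Generic, Hashable, TypeVar

    KeyType = TypeVar('KeyType', bound=Hashable)  # keys must be dict-usable
    DataType = TypeVar('DataType')


    @dataclass(order=True)
    class TimeoutItem(Generic[KeyType, DataType]):
        expiry: float  # the only field considered when ordering items
        key: KeyType = field(compare=False)
        data: DataType = field(compare=False)


    class TimeoutTable(Generic[KeyType, DataType]):
        def __init__(self) -> None:
            self.timeout_queue: PriorityQueue[TimeoutItem[KeyType, DataType]] = PriorityQueue()

        def set(self, key: KeyType, timeout: float, data: DataType) -> None:
            self.timeout_queue.put(TimeoutItem(time.time() + timeout, key, data))


    # Keys are (sid, sha256, service_name) triples, payloads are worker ids.
    table: TimeoutTable[tuple[str, str, str], str] = TimeoutTable()
    table.set(('sid', 'sha256', 'Extract'), 30.0, 'worker-0')
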
14 changes: 7 additions & 7 deletions assemblyline_core/ingester/ingester.py
@@ -49,8 +49,8 @@
_retry_delay = 60 * 4 # Wait 4 minutes to retry
_max_time = 2 * 24 * 60 * 60 # Wait 2 days for responses.
HOUR_IN_SECONDS = 60 * 60
-INGEST_THREADS = environ.get('INGESTER_INGEST_THREADS', 1)
-SUBMIT_THREADS = environ.get('INGESTER_SUBMIT_THREADS', 4)
+INGEST_THREADS = int(environ.get('INGESTER_INGEST_THREADS', 1))
+SUBMIT_THREADS = int(environ.get('INGESTER_SUBMIT_THREADS', 4))


def must_drop(length: int, maximum: int) -> bool:
@@ -79,11 +79,11 @@ def must_drop(length: int, maximum: int) -> bool:
def determine_resubmit_selected(selected: List[str], resubmit_to: List[str]) -> Optional[List[str]]:
resubmit_selected = None

-selected = set(selected)
-resubmit_to = set(resubmit_to)
+_selected = set(selected)
+_resubmit_to = set(resubmit_to)

-if not selected.issuperset(resubmit_to):
-resubmit_selected = sorted(selected.union(resubmit_to))
+if not _selected.issuperset(_resubmit_to):
+resubmit_selected = sorted(_selected.union(_resubmit_to))

return resubmit_selected

@@ -196,7 +196,7 @@ def __init__(self, datastore=None, logger=None, classification=None, redis=None,
self.retry_queue = PriorityQueue('m-retry', self.redis_persist)

# Internal, timeout watch queue
-self.timeout_queue = PriorityQueue('m-timeout', self.redis)
+self.timeout_queue: PriorityQueue[str] = PriorityQueue('m-timeout', self.redis)

# Internal, queue for processing duplicates
# When a duplicate file is detected (same cache key => same file, and same
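Two small correctness fixes here: os.environ values are always strings, so the thread counts have to be coerced before use, and the set() conversions get fresh names instead of rebinding the List[str] parameters to a different type mid-function. The first is worth spelling out (same idiom as the diff, shown standalone):

    from os import environ

    # environ.get returns the *string* value whenever the variable is set, so
    # without int() a deployment exporting INGESTER_INGEST_THREADS=4 would hand
    # a str to range() below; int() makes both branches the same type.
    INGEST_THREADS = int(environ.get('INGESTER_INGEST_THREADS', 1))

    for worker_index in range(INGEST_THREADS):
        print(f'starting ingest worker {worker_index}')
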
26 changes: 13 additions & 13 deletions assemblyline_core/scaler/controllers/kubernetes_ctl.py
@@ -10,9 +10,9 @@
from typing import Optional, Tuple

import urllib3
-import kubernetes
-from kubernetes import client, config
-from kubernetes.client import ExtensionsV1beta1Deployment, ExtensionsV1beta1DeploymentSpec, V1PodTemplateSpec, \
+
+from kubernetes import client, config, watch
+from kubernetes.client import V1Deployment, V1DeploymentSpec, V1PodTemplateSpec, \
V1PodSpec, V1ObjectMeta, V1Volume, V1Container, V1VolumeMount, V1EnvVar, V1ConfigMapVolumeSource, \
V1PersistentVolumeClaimVolumeSource, V1LabelSelector, V1ResourceRequirements, V1PersistentVolumeClaim, \
V1PersistentVolumeClaimSpec, V1NetworkPolicy, V1NetworkPolicySpec, V1NetworkPolicyEgressRule, V1NetworkPolicyPeer, \
@@ -43,7 +43,7 @@
}


-class TypelessWatch(kubernetes.watch.Watch):
+class TypelessWatch(watch.Watch):
"""A kubernetes watch object that doesn't marshal the response."""

def get_return_type(self, func):
@@ -476,16 +476,16 @@ def _create_containers(self, service_name: str, deployment_name: str, container_
)]

def _create_deployment(self, service_name: str, deployment_name: str, docker_config: DockerConfig,
-shutdown_seconds: int, scale: int, labels:dict[str,str]=None, 
+shutdown_seconds: int, scale: int, labels:dict[str,str]=None,
volumes:list[V1Volume]=None, mounts:list[V1VolumeMount]=None,
core_mounts:bool=False, change_key:str=''):
-# Build a cache key to check for changes, just trying to only patch what changed 
+# Build a cache key to check for changes, just trying to only patch what changed
# will still potentially result in a lot of restarts due to different kubernetes
# systems returning differently formatted data
change_key = (
-deployment_name + change_key + str(docker_config) + str(shutdown_seconds) + 
+deployment_name + change_key + str(docker_config) + str(shutdown_seconds) +
str(sorted((labels or {}).items())) + str(volumes) + str(mounts) + str(core_mounts)
-) 
+)

# Check if a deployment already exists, and if it does check if it has the same change key set
replace = None
Expand All @@ -498,7 +498,7 @@ def _create_deployment(self, service_name: str, deployment_name: str, docker_con
except ApiException as error:
if error.status != 404:
raise

# If we have been given a username or password for the registry, we have to
# update it, if we haven't been, make sure its been cleaned up in the system
# so we don't leave passwords lying around
Expand Down Expand Up @@ -543,7 +543,7 @@ def _create_deployment(self, service_name: str, deployment_name: str, docker_con
all_labels['section'] = 'core'
all_labels.update(labels or {})

-# Build set of volumes, first the global mounts, then the core specific ones, 
+# Build set of volumes, first the global mounts, then the core specific ones,
# then the ones specific to this container only
all_volumes: list[V1Volume] = []
all_mounts: list[V1VolumeMount] = []
@@ -574,14 +574,14 @@ def _create_deployment(self, service_name: str, deployment_name: str, docker_config: DockerConfig,
spec=pod,
)

-spec = ExtensionsV1beta1DeploymentSpec(
+spec = V1DeploymentSpec(
replicas=int(scale),
revision_history_limit=0,
selector=V1LabelSelector(match_labels=all_labels),
template=template,
)

-deployment = ExtensionsV1beta1Deployment(
+deployment = V1Deployment(
kind="Deployment",
metadata=metadata,
spec=spec,
@@ -645,7 +645,7 @@ def stop_container(self, service_name, container_id):

def restart(self, service):
self._create_deployment(service.name, self._deployment_name(service.name), service.container_config,
-service.shutdown_seconds, self.get_target(service.name), 
+service.shutdown_seconds, self.get_target(service.name),
change_key=service.config_blob)

def get_running_container_names(self):
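The substantive change in this file is the move off the removed extensions/v1beta1 Deployment models onto their apps/v1 equivalents; the paired identical lines above are trailing-whitespace cleanup. A rough sketch of building a Deployment with the new models, assuming a standard kubernetes client setup (the names and namespace below are illustrative, not from the commit):

    from kubernetes import client, config

    config.load_kube_config()  # or load_incluster_config() when run in a pod

    labels = {'app': 'example-service'}
    deployment = client.V1Deployment(
        kind='Deployment',
        metadata=client.V1ObjectMeta(name='example-service', labels=labels),
        spec=client.V1DeploymentSpec(
            replicas=1,
            revision_history_limit=0,
            selector=client.V1LabelSelector(match_labels=labels),
            template=client.V1PodTemplateSpec(
                metadata=client.V1ObjectMeta(labels=labels),
                spec=client.V1PodSpec(containers=[
                    client.V1Container(name='main', image='example:latest'),
                ]),
            ),
        ),
    )

    # Deployments are served from the apps/v1 API group in current clusters.
    client.AppsV1Api().create_namespaced_deployment(namespace='al', body=deployment)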