diff --git a/ingredients_tasks/celary.py b/ingredients_tasks/celary.py index fdd81f3..cb1fbcd 100644 --- a/ingredients_tasks/celary.py +++ b/ingredients_tasks/celary.py @@ -34,12 +34,12 @@ def connect(self): 'confirm_publish': True }, task_acks_late=True, - task_reject_on_worker_last=True, + task_reject_on_worker_lost=True, task_ignore_result=True, task_store_errors_even_if_ignored=False, task_soft_time_limit=300, # 5 minutes task_time_limit=600, # 10 minutes - worker_prefetch_multiplier=1, + worker_prefetch_multiplier=1, # One worker process can only do one type of task at a time include=include, task_queues=task_queues, task_routes=task_routes @@ -53,12 +53,12 @@ def populate_tasks(self, *args): task_queues = set() task_routes = {} - from ingredients_tasks.tasks.tasks import ImageTask, InstanceTask, NetworkTask + from ingredients_tasks.tasks.tasks import NetworkTask, ImageTask for task_module in args: include.append(task_module.__name__) for name, method in inspect.getmembers(task_module): - if method in [ImageTask, InstanceTask, NetworkTask]: + if method in [NetworkTask, ImageTask]: continue if hasattr(method, 'apply_async'): task_queues.add(Queue(name, exchange=Exchange(task_module.__name__), routing_key=name)) diff --git a/ingredients_tasks/tasks/image.py b/ingredients_tasks/tasks/image.py index 0179126..4108300 100644 --- a/ingredients_tasks/tasks/image.py +++ b/ingredients_tasks/tasks/image.py @@ -3,42 +3,49 @@ from ingredients_db.models.images import ImageState from ingredients_tasks.tasks.tasks import ImageTask +from ingredients_tasks.vmware import VMWareClient logger = get_task_logger(__name__) @celery.shared_task(base=ImageTask, bind=True, max_retires=2, default_retry_delay=5) def create_image(self, **kwargs): - vmware_image = self.vmware_session.get_image(self.image.file_name) + image = self.request.image + with VMWareClient.client_session() as vmware: + vmware_image = vmware.get_image(image.file_name) - if vmware_image is None: - raise ValueError("Could not find image file") + if vmware_image is None: + raise ValueError("Could not find image file") - self.image.state = ImageState.CREATED + image.state = ImageState.CREATED @celery.shared_task(base=ImageTask, bind=True, max_retires=2, default_retry_delay=5) def delete_image(self, **kwargs): - vmware_image = self.vmware_session.get_image(self.image.file_name) + image = self.request.image + with VMWareClient.client_session() as vmware: + vmware_image = vmware.get_image(image.file_name) - if vmware_image is not None: - self.vmware_session.delete_image(vmware_image) - else: - logger.warning("Tried to delete image %s but couldn't find its backing file" % self.image.id) + if vmware_image is not None: + vmware.delete_image(vmware_image) + else: + logger.warning("Tried to delete image %s but couldn't find its backing file" % str(image.id)) - self.image.state = ImageState.DELETED + image.state = ImageState.DELETED - self.db_session.delete(self.image) + self.request.session.delete(image) @celery.shared_task(base=ImageTask, bind=True, max_retires=2, default_retry_delay=5) def convert_vm(self, **kwargs): - vmware_vm = self.vmware_session.get_vm(self.image.file_name) + image = self.request.image + with VMWareClient.client_session() as vmware: + vmware_vm = vmware.get_vm(image.file_name) - if vmware_vm is None: - raise LookupError( - 'Could not find backing vm for image %s when trying to convert to template.' % str(self.instance.id)) + if vmware_vm is None: + raise LookupError( + 'Could not find backing vm for image %s when trying to convert to template.' % str(image.id)) - self.vmware_session.template_vm(vmware_vm) + vmware.template_vm(vmware_vm) - self.image.state = ImageState.CREATED + image.state = ImageState.CREATED diff --git a/ingredients_tasks/tasks/instance.py b/ingredients_tasks/tasks/instance.py index a476385..006c062 100644 --- a/ingredients_tasks/tasks/instance.py +++ b/ingredients_tasks/tasks/instance.py @@ -6,128 +6,141 @@ from ingredients_db.models.instance import InstanceState from ingredients_db.models.network import Network from ingredients_db.models.network_port import NetworkPort +from ingredients_tasks.omapi import OmapiClient from ingredients_tasks.tasks.tasks import InstanceTask +from ingredients_tasks.vmware import VMWareClient logger = get_task_logger(__name__) @celery.shared_task(base=InstanceTask, bind=True, max_retires=2, default_retry_delay=5) def create_instance(self, **kwargs): - if self.instance.image_id is None: + instance = self.request.instance + if instance.image_id is None: raise ValueError("Image turned NULL before the instance could be created") try: - image = self.db_session.query(Image).filter(Image.id == self.instance.image_id).one() + image = self.request.session.query(Image).filter(Image.id == instance.image_id).one() except NoResultFound: raise LookupError("Image got deleted before the instance could be created") - vmware_image = self.vmware_session.get_image(image.file_name) + with VMWareClient.client_session() as vmware: + vmware_image = vmware.get_image(image.file_name) - if vmware_image is None: - raise LookupError("Could not find image file to clone") + if vmware_image is None: + raise LookupError("Could not find image file to clone") - old_vmware_vm = self.vmware_session.get_vm(str(self.instance.id)) - if old_vmware_vm is not None: - # A backing vm with the same id exists (how?! the task should have failed) so we probably should delete it - logger.info( - 'A backing vm with the id of %s already exists so it is going to be deleted.' % str(self.instance.id)) - self.vmware_session.delete_vm(old_vmware_vm) + old_vmware_vm = vmware.get_vm(str(instance.id)) + if old_vmware_vm is not None: + # A backing vm with the same id exists (how?! the task should have failed) so we probably should delete it + logger.info( + 'A backing vm with the id of %s already exists so it is going to be deleted.' % str(instance.id)) + vmware.delete_vm(old_vmware_vm) - # We need a nested transaction because we need to lock the network so we can calculate the next free ip - # Without a nested transaction the lock will last for the total time of the task which could be several minutes - # this will block the api from creating new network_ports. With nested we only block for the time needed to - # calculate the next available ip address which is at most O(n) time with n being the number of - # ip addresses in the cidr - with self.database.session() as nested_session: - network_port = nested_session.query(NetworkPort).filter( - NetworkPort.id == self.instance.network_port_id).first() + # We need a nested transaction because we need to lock the network so we can calculate the next free ip + # Without a nested transaction the lock will last for the total time of the task which could be several minutes + # this will block the api from creating new network_ports. With nested we only block for the time needed to + # calculate the next available ip address which is at most O(n) time with n being the number of + # ip addresses in the cidr + with self.database.session() as nested_session: + network_port = nested_session.query(NetworkPort).filter(NetworkPort.id == instance.network_port_id).first() - network = nested_session.query(Network).filter(Network.id == network_port.network_id).with_for_update().first() + network = nested_session.query(Network).filter( + Network.id == network_port.network_id).with_for_update().first() - logger.info('Allocating IP address for instance %s' % str(self.instance.id)) - if network_port.ip_address is not None: - # An ip address was already allocated (how?! the task should have failed) so let's reset it - network_port.ip_address = None + logger.info('Allocating IP address for instance %s' % str(instance.id)) + if network_port.ip_address is not None: + # An ip address was already allocated (how?! the task should have failed) so let's reset it + network_port.ip_address = None - ip_address = network.next_free_address(nested_session) - if ip_address is None: - raise IndexError("Could not allocate a free ip address. Is the pool full?") - network_port.ip_address = ip_address - logger.info('Allocated IP address %s for instance %s' % (str(ip_address), str(self.instance.id))) + ip_address = network.next_free_address(nested_session) + if ip_address is None: + raise IndexError("Could not allocate a free ip address. Is the pool full?") + network_port.ip_address = ip_address + logger.info('Allocated IP address %s for instance %s' % (str(ip_address), str(instance.id))) - port_group = self.vmware_session.get_port_group(network.port_group) - if port_group is None: - raise LookupError("Cloud not find port group to connect to") - nested_session.commit() + port_group = vmware.get_port_group(network.port_group) + if port_group is None: + raise LookupError("Cloud not find port group to connect to") + nested_session.commit() - logger.info('Creating backing vm for instance %s' % str(self.instance.id)) - vmware_vm = self.vmware_session.create_vm(vm_name=str(self.instance.id), image=vmware_image, port_group=port_group) + logger.info('Creating backing vm for instance %s' % str(instance.id)) + vmware_vm = vmware.create_vm(vm_name=str(instance.id), image=vmware_image, port_group=port_group) - nic_mac = self.vmware_session.find_vm_mac(vmware_vm) - if nic_mac is None: - raise LookupError("Could not find mac address of nic") + nic_mac = vmware.find_vm_mac(vmware_vm) + if nic_mac is None: + raise LookupError("Could not find mac address of nic") - logger.info('Telling DHCP about our IP for instance %s' % str(self.instance.id)) - self.omapi_session.add_host(str(ip_address), nic_mac) + logger.info('Telling DHCP about our IP for instance %s' % str(instance.id)) + with OmapiClient.client_session() as omapi: + omapi.add_host(str(ip_address), nic_mac) - logger.info('Powering on backing vm for instance %s' % str(self.instance.id)) - self.vmware_session.power_on_vm(vmware_vm) + logger.info('Powering on backing vm for instance %s' % str(instance.id)) + vmware.power_on_vm(vmware_vm) - self.instance.state = InstanceState.ACTIVE + instance.state = InstanceState.ACTIVE @celery.shared_task(base=InstanceTask, bind=True, max_retires=2, default_retry_delay=5) def delete_instance(self, delete_backing: bool, **kwargs): + instance = self.request.instance if delete_backing: - vmware_vm = self.vmware_session.get_vm(str(self.instance.id)) + with VMWareClient.client_session() as vmware: + vmware_vm = vmware.get_vm(str(instance.id)) - if vmware_vm is None: - logger.warning('Could not find backing vm for instance %s when trying to delete.' % str(self.instance.id)) - else: - logger.info('Deleting backing vm for instance %s' % str(self.instance.id)) - self.vmware_session.power_off_vm(vmware_vm) - self.vmware_session.delete_vm(vmware_vm) + if vmware_vm is None: + logger.warning( + 'Could not find backing vm for instance %s when trying to delete.' % str(instance.id)) + else: + logger.info('Deleting backing vm for instance %s' % str(instance.id)) + vmware.power_off_vm(vmware_vm) + vmware.delete_vm(vmware_vm) - network_port = self.db_session.query(NetworkPort).filter( - NetworkPort.id == self.instance.network_port_id).first() + network_port = self.request.session.query(NetworkPort).filter(NetworkPort.id == instance.network_port_id).first() - self.instance.state = InstanceState.DELETED - self.db_session.delete(self.instance) - self.db_session.delete(network_port) + instance.state = InstanceState.DELETED + self.request.session.delete(instance) + self.request.session.delete(network_port) @celery.shared_task(base=InstanceTask, bind=True, max_retires=2, default_retry_delay=5) def stop_instance(self, hard=False, timeout=60, **kwargs): - vmware_vm = self.vmware_session.get_vm(str(self.instance.id)) + instance = self.request.instance + with VMWareClient.client_session() as vmware: + vmware_vm = vmware.get_vm(str(instance.id)) - if vmware_vm is None: - raise LookupError('Could not find backing vm for instance %s when trying to stop.' % str(self.instance.id)) + if vmware_vm is None: + raise LookupError('Could not find backing vm for instance %s when trying to stop.' % str(instance.id)) - self.vmware_session.power_off_vm(vmware_vm, hard=hard, timeout=timeout) + vmware.power_off_vm(vmware_vm, hard=hard, timeout=timeout) - self.instance.state = InstanceState.STOPPED + instance.state = InstanceState.STOPPED @celery.shared_task(base=InstanceTask, bind=True, max_retires=2, default_retry_delay=5) def start_instance(self, **kwargs): - vmware_vm = self.vmware_session.get_vm(str(self.instance.id)) + instance = self.request.instance + with VMWareClient.client_session() as vmware: + vmware_vm = vmware.get_vm(str(instance.id)) - if vmware_vm is None: - raise LookupError('Could not find backing vm for instance %s when trying to start.' % str(self.instance.id)) + if vmware_vm is None: + raise LookupError('Could not find backing vm for instance %s when trying to start.' % str(instance.id)) - self.vmware_session.power_on_vm(vmware_vm) + vmware.power_on_vm(vmware_vm) - self.instance.state = InstanceState.ACTIVE + instance.state = InstanceState.ACTIVE @celery.shared_task(base=InstanceTask, bind=True, max_retires=2, default_retry_delay=5) def restart_instance(self, hard=False, timeout=60, **kwargs): - vmware_vm = self.vmware_session.get_vm(str(self.instance.id)) + instance = self.request.instance + with VMWareClient.client_session() as vmware: + vmware_vm = vmware.get_vm(str(instance.id)) - if vmware_vm is None: - raise LookupError('Could not find backing vm for instance %s when trying to restart.' % str(self.instance.id)) + if vmware_vm is None: + raise LookupError('Could not find backing vm for instance %s when trying to restart.' % str(instance.id)) - self.vmware_session.power_off_vm(vmware_vm, hard=hard, timeout=timeout) - self.vmware_session.power_on_vm(vmware_vm) + vmware.power_off_vm(vmware_vm, hard=hard, timeout=timeout) + vmware.power_on_vm(vmware_vm) - self.instance.state = InstanceState.ACTIVE + instance.state = InstanceState.ACTIVE diff --git a/ingredients_tasks/tasks/network.py b/ingredients_tasks/tasks/network.py index 4919436..95ba732 100644 --- a/ingredients_tasks/tasks/network.py +++ b/ingredients_tasks/tasks/network.py @@ -2,13 +2,16 @@ from ingredients_db.models.network import NetworkState from ingredients_tasks.tasks.tasks import NetworkTask +from ingredients_tasks.vmware import VMWareClient @celery.shared_task(base=NetworkTask, bind=True, max_retires=2, default_retry_delay=5) def create_network(self, **kwargs): - port_group = self.vmware_session.get_port_group(self.network.port_group) + network = self.request.network + with VMWareClient.client_session() as vmware: + port_group = vmware.get_port_group(network.port_group) - if port_group is None: - raise ValueError("Could not find port group") + if port_group is None: + raise ValueError("Could not find port group") - self.network.state = NetworkState.CREATED + network.state = NetworkState.CREATED diff --git a/ingredients_tasks/tasks/tasks.py b/ingredients_tasks/tasks/tasks.py index 50318da..c156655 100644 --- a/ingredients_tasks/tasks/tasks.py +++ b/ingredients_tasks/tasks/tasks.py @@ -5,26 +5,18 @@ from celery.utils.log import get_task_logger from simple_settings import settings from sqlalchemy.exc import OperationalError, IntegrityError, DataError, ProgrammingError -from sqlalchemy.orm.exc import NoResultFound from sqlalchemy_utils.types.arrow import arrow from ingredients_db.database import Database -from ingredients_db.models.images import Image, ImageState +from ingredients_db.models.images import ImageState, Image from ingredients_db.models.instance import InstanceState, Instance from ingredients_db.models.network import Network, NetworkState from ingredients_db.models.task import TaskState, Task -from ingredients_tasks.omapi import OmapiClient -from ingredients_tasks.vmware import VMWareClient logger = get_task_logger(__name__) -class BaseMixin(object): - def __init__(self): - pass - - -class DatabaseMixin(BaseMixin): +class DBTask(celery.Task): def __init__(self): super().__init__() # TODO: find another place to put this db object (needs to not load the settings object) @@ -33,168 +25,99 @@ def __init__(self): self.database = Database(settings.DATABASE_HOST, settings.DATABASE_PORT, settings.DATABASE_USERNAME, settings.DATABASE_PASSWORD, settings.DATABASE_DB, settings.DATABASE_POOL_SIZE) self.database.connect() - self.db_session_manager = None - self.db_session = None - def setup_db_session(self): - self.db_session_manager = self.database.session() - self.db_session = self.db_session_manager.__enter__() + def __call__(self, *args, **kwargs): + + def commit_database(task, session): + task.stopped_at = arrow.now() - def after_return(self, status, retval, task_id, args, kwargs, einfo): - if self.db_session is not None: - try: - self.db_session.flush() + try: # Try to commit, if error log and force quit + session.flush() except (IntegrityError, DataError, ProgrammingError): - logger.exception("Error flushing transaction to database. This is probably due to a bug somewhere") + logger.exception( + "Error flushing transaction to database. This is probably due to a bug somewhere") os.killpg(os.getpgrp(), 9) - self.db_session.commit() - self.db_session_manager.__exit__(None, None, None) - - def on_failure(self, exc, task_id, args, kwargs, einfo): - if isinstance(exc, OperationalError): - # Rerun the task again in 60 seconds - self.retry(countdown=60, max_retries=sys.maxsize, throw=False) - - -class VMWareMixin(BaseMixin): - def __init__(self): - super().__init__() - self.vmware_session_manager = None - self.vmware_session = None - - def setup_vmware_session(self): - self.vmware_session_manager = VMWareClient.client_session() - self.vmware_session = self.vmware_session_manager.__enter__() - - def after_return(self, status, retval, task_id, args, kwargs, einfo): - self.vmware_session_manager.__exit__(None, None, None) - - def on_failure(self, exc, task_id, args, kwargs, einfo): - # We shouldn't retry vmware errors, just let it fail. + return + session.commit() + + try: # Try to do db stuff and catch OperationalError + with self.database.session() as session: + self.request.session = session + task = session.query(Task).filter(Task.id == self.request.id).first() + if task is None: # We might be faster than the db so retry + raise self.retry() + if task.stopped_at is not None: # Task has already ran + raise ValueError("Task has already stopped, cannot do it again.") + + self.on_database(session) # Load more stuff into the request + + try: + super().__call__(*args, **kwargs) + task.state = TaskState.COMPLETED + except Exception as exc: # There was an error during the call + try: # Set task to error and reraise + task.state = TaskState.ERROR + task.error_message = str(exc.msg) if hasattr(exc, 'msg') else str( + exc) # VMWare errors are stupid + + self.on_task_failure() + + raise exc + finally: # Commit the error + commit_database(task, session) + finally: # Set task to stopped and commit + commit_database(task, session) + + except OperationalError: # There some some sort of connection error so keep retrying + raise self.retry(countdown=60, max_retries=sys.maxsize) + + def on_database(self, session): pass - -class OmapiMixin(BaseMixin): - def __init__(self): - super().__init__() - self.omapi_session_manager = None - self.omapi_session = None - - def setup_omapi_session(self): - self.omapi_session_manager = OmapiClient.client_session() - self.omapi_session = self.omapi_session_manager.__enter__() - - def after_return(self, status, retval, task_id, args, kwargs, einfo): - self.omapi_session_manager.__exit__(None, None, None) - - def on_failure(self, exc, task_id, args, kwargs, einfo): - # We shouldn't retry omapi errors, just let it fail. + def on_task_failure(self): pass -class TaskStateMixin(DatabaseMixin): - def __init__(self): - super().__init__() - self.task = None - - def setup_task(self, task_id, kwargs): - self.setup_db_session() - try: - self.task = self.db_session.query(Task).filter(Task.id == task_id).one() - if self.task.stopped_at is not None: - self.task = None # Force a failure because the task has already stopped - raise ValueError("Task has already stopped, cannot do it again.") - except NoResultFound as exc: # We might be faster than the db so retry - raise self.retry() - - def after_return(self, status, retval, task_id, args, kwargs, einfo): - if self.task is not None: - self.task.stopped_at = arrow.now() - super().after_return(status, retval, task_id, args, kwargs, einfo) - - def on_success(self, retval, task_id, args, kwargs): - self.task.state = TaskState.COMPLETED - - def on_failure(self, exc, task_id, args, kwargs, einfo): - super().on_failure(exc, task_id, args, kwargs, einfo) - if self.task is not None: - if hasattr(exc, 'msg'): # VMWare errors are stupid - self.task.error_message = exc.msg - else: - self.task.error_message = str(exc) - self.task.state = TaskState.ERROR - - -class ImageTask(TaskStateMixin, VMWareMixin, celery.Task): - def __init__(self): - super().__init__() - self.image = None - - def __call__(self, *args, **kwargs): - self.setup_task(self.request.id, kwargs) - try: - self.image = self.db_session.query(Image).filter(Image.id == kwargs['image_id']).one() - except NoResultFound as exc: # We might be faster than the db so retry +class NetworkTask(DBTask): + def on_database(self, session): + self.request.network = None + network = session.query(Network).filter(Network.id == self.request.kwargs['network_id']).first() + if network is None: # We might be faster than the db so retry raise self.retry() - self.setup_vmware_session() - super().__call__(*args, **kwargs) - - def after_return(self, status, retval, task_id, args, kwargs, einfo): - super().after_return(status, retval, task_id, args, kwargs, einfo) - def on_failure(self, exc, task_id, args, kwargs, einfo): - super().on_failure(exc, task_id, args, kwargs, einfo) - if self.image is not None: - self.image.state = ImageState.ERROR + self.request.network = network + def on_task_failure(self): + if self.request.network is not None: + self.request.network = NetworkState.ERROR -class InstanceTask(TaskStateMixin, VMWareMixin, OmapiMixin, celery.Task): - def __init__(self): - super().__init__() - self.instance = None - def __call__(self, *args, **kwargs): - self.setup_task(self.request.id, kwargs) - try: - self.instance = self.db_session.query(Instance).filter( - Instance.id == kwargs['instance_id']).one() - except NoResultFound as exc: # We might be faster than the db so retry +class ImageTask(DBTask): + def on_database(self, session): + self.request.image = None + image = session.query(Image).filter(Image.id == self.request.kwargs['image_id']).first() + if image is None: # We might be faster than the db so retry raise self.retry() - self.setup_vmware_session() - self.setup_omapi_session() - super().__call__(*args, **kwargs) - def after_return(self, status, retval, task_id, args, kwargs, einfo): - super().after_return(status, retval, task_id, args, kwargs, einfo) + self.request.image = image - def on_failure(self, exc, task_id, args, kwargs, einfo): - super().on_failure(exc, task_id, args, kwargs, einfo) - if self.instance is not None: - self.instance.state = InstanceState.ERROR + def on_task_failure(self): + if self.request.image is not None: + self.request.image = ImageState.ERROR -class NetworkTask(TaskStateMixin, VMWareMixin, celery.Task): - def __init__(self): - super().__init__() - self.network = None - - def __call__(self, *args, **kwargs): - self.setup_task(self.request.id, kwargs) - try: - self.network = self.db_session.query(Network).filter( - Network.id == kwargs['network_id']).one() - except NoResultFound as exc: # We might be faster than the db so retry +class InstanceTask(DBTask): + def on_database(self, session): + self.request.instance = None + instance = session.query(Instance).filter(Instance.id == self.request.kwargs['instance_id']).first() + if instance is None: # We might be faster than the db so retry raise self.retry() - self.setup_vmware_session() - super().__call__(*args, **kwargs) - def after_return(self, status, retval, task_id, args, kwargs, einfo): - super().after_return(status, retval, task_id, args, kwargs, einfo) + self.request.instance = instance - def on_failure(self, exc, task_id, args, kwargs, einfo): - super().on_failure(exc, task_id, args, kwargs, einfo) - if self.network is not None: - self.network.state = NetworkState.ERROR + def on_task_failure(self): + if self.request.instance is not None: + self.request.instance = InstanceState.ERROR def create_task(session, entity, celery_task, signature=False, **kwargs):