From cffab7fd686b58ecf8d14a2392363f1b171649e1 Mon Sep 17 00:00:00 2001
From: Karolina Przerwa
Date: Tue, 27 Aug 2024 11:28:45 +0200
Subject: [PATCH] run: serialize arguments for celery task

* Builds the run arguments with the registered task's
  `build_task_arguments` when arguments are submitted, and falls back
  to the job's default args otherwise.
* Wraps the submitted `args` payload before schema loading and parses
  `custom_args` from JSON, reporting malformed JSON as a validation
  error.
* Removes the commented-out `default_args` column.

---
Review notes (ignored by `git am`):

* `wrap_args` no longer raises `KeyError` when the payload has no `args`
  key and no longer mutates the caller's payload.
* `pick_args` turns a `json.loads` failure into a `ValidationError` on
  `custom_args` instead of letting the exception escape.
* Dropped the hunk that appended a blank line at the end of `models.py`
  and kept the blank line before `return cls(...)` in `Run.create`.
* NOTE(review): `Run.create` reads `kwargs["args"].get("args", {})`. When
  `pick_args` replaces `args` with the parsed `custom_args`, only a
  nested `args` key of that payload reaches `build_task_arguments`;
  please confirm this is intended.
* NOTE(review): the post-image blob ids on the `index` lines are carried
  over from the original patch and do not match the amended content.

 invenio_jobs/models.py          | 13 ++++++++++---
 invenio_jobs/services/schema.py | 31 +++++++++++++++++++++++++++++--
 2 files changed, 39 insertions(+), 5 deletions(-)

diff --git a/invenio_jobs/models.py b/invenio_jobs/models.py
index 6e40c3b..78fb56e 100644
--- a/invenio_jobs/models.py
+++ b/invenio_jobs/models.py
@@ -49,7 +49,6 @@ class Job(db.Model, Timestamp):
     active = db.Column(db.Boolean, default=True, nullable=False)
     title = db.Column(db.String(255), nullable=False)
     description = db.Column(db.Text)
-    # default_args = db.Column(JSON, default=lambda: dict(), nullable=True)
     task = db.Column(db.String(255))
     default_queue = db.Column(db.String(64))
     schedule = db.Column(JSON, nullable=True)
@@ -146,13 +145,16 @@ def create(cls, job, **kwargs):
         """Create a new run."""
         if "args" not in kwargs:
             kwargs["args"] = cls.generate_args(job)
+        else:
+            task_arguments = deepcopy(kwargs["args"].get("args", {}))
+            kwargs["args"] = cls.generate_args(job, task_arguments=task_arguments)
         if "queue" not in kwargs:
             kwargs["queue"] = job.default_queue
 
         return cls(job=job, **kwargs)
 
     @classmethod
-    def generate_args(cls, job):
+    def generate_args(cls, job, task_arguments=None):
         """Generate new run args.
 
         We allow a templating mechanism to generate the args for the run. It's important
@@ -160,7 +162,12 @@ def generate_args(cls, job):
         classes or Python objects or functions. Otherwise, we risk that users could
         execute arbitrary code, or perform harmful DB operations (e.g. delete rows).
         """
-        args = deepcopy(job.default_args)
+        if task_arguments:
+            args = Task.get(job.task).build_task_arguments(
+                job_obj=job, **task_arguments
+            )
+        else:
+            args = deepcopy(job.default_args)
         args = json.dumps(args, indent=4, sort_keys=True, default=str)
         args = json.loads(args)
         return args
diff --git a/invenio_jobs/services/schema.py b/invenio_jobs/services/schema.py
index 60d02a2..a151322 100644
--- a/invenio_jobs/services/schema.py
+++ b/invenio_jobs/services/schema.py
@@ -8,12 +8,23 @@
 """Service schemas."""
 
 import inspect
+import json
 from copy import deepcopy
 from datetime import timezone
 
 from invenio_i18n import lazy_gettext as _
 from invenio_users_resources.services import schemas as user_schemas
-from marshmallow import EXCLUDE, Schema, fields, post_load, pre_dump, types, validate
+from marshmallow import (
+    EXCLUDE,
+    Schema,
+    ValidationError,
+    fields,
+    post_load,
+    pre_dump,
+    pre_load,
+    types,
+    validate,
+)
 from marshmallow_oneofschema import OneOfSchema
 from marshmallow_utils.fields import SanitizedUnicode, TZDateTime
 from marshmallow_utils.permissions import FieldPermissionsMixin
@@ -243,10 +254,26 @@ class Meta:
         dump_default=lambda: current_jobs.default_queue,
     )
 
+    @pre_load
+    def wrap_args(self, obj, many, **kwargs):
+        """Nest the submitted task arguments under an ``args`` key.
+
+        The payload is copied instead of modified in place, and a payload
+        without ``args`` is returned unchanged instead of raising ``KeyError``.
+        """
+        if "args" not in obj:
+            return obj
+        return {**obj, "args": {"args": obj["args"]}}
+
     @post_load
     def pick_args(self, obj, many, **kwargs):
         """Choose custom or default args."""
         custom_args = obj.pop("custom_args")
         if custom_args:
-            obj["args"] = custom_args
+            try:
+                obj["args"] = json.loads(custom_args)
+            except (TypeError, ValueError) as err:
+                raise ValidationError(
+                    "Custom args must be valid JSON.", field_name="custom_args"
+                ) from err
         return obj