Skip to content

Commit

Permalink
global: support Jinja templating for job args
Browse files Browse the repository at this point in the history
* Closes #36.
  • Loading branch information
slint committed Jun 13, 2024
1 parent 47a91bd commit 91ec6cc
Show file tree
Hide file tree
Showing 6 changed files with 204 additions and 26 deletions.
22 changes: 21 additions & 1 deletion invenio_jobs/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@
from sqlalchemy_utils.types import ChoiceType, JSONType, UUIDType
from werkzeug.utils import cached_property

from .utils import eval_tpl_str, walk_values

JSON = (
db.JSON()
.with_variant(postgresql.JSONB(none_as_null=True), "postgresql")
Expand All @@ -44,12 +46,17 @@ class Job(db.Model, Timestamp):
default_args = db.Column(JSON, default=lambda: dict(), nullable=True)
schedule = db.Column(JSON, nullable=True)

# TODO: See if we move this to an API class
@property
def last_run(self):
    """Return the most recently created run of this job, or None."""
    newest_first = self.runs.order_by(Run.created.desc())
    return newest_first.first()

def last_run_with_status(self, status):
    """Return the most recent run of this job with *status*, or None."""
    matching = self.runs.filter(Run.status == status)
    return matching.order_by(Run.created.desc()).first()

@property
def parsed_schedule(self):
"""Return schedule parsed as crontab or timedelta."""
Expand Down Expand Up @@ -109,6 +116,19 @@ def started_by(self):
args = db.Column(JSON, default=lambda: dict(), nullable=True)
queue = db.Column(db.String(64), nullable=False)

@classmethod
def create(cls, job, **kwargs):
    """Create a new run for *job*.

    Unless explicitly passed in ``kwargs``, ``args`` are derived from the
    job's ``default_args`` with every string value rendered as a Jinja
    template (template context: the ``job`` and the run's ``kwargs``), and
    ``queue`` falls back to the job's ``default_queue``.
    """
    if "args" not in kwargs:
        # ``default_args`` is a nullable column — fall back to an empty
        # dict so template evaluation below cannot crash on None.
        args = deepcopy(job.default_args) or {}
        ctx = {"job": job, "run": kwargs}
        walk_values(args, lambda val: eval_tpl_str(val, ctx))
        kwargs["args"] = args
    if "queue" not in kwargs:
        kwargs["queue"] = job.default_queue

    return cls(job=job, **kwargs)


class Task:
"""Celery Task model."""
Expand Down
32 changes: 12 additions & 20 deletions invenio_jobs/services/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,11 @@

import traceback
import uuid
from typing import Any

from celery.beat import ScheduleEntry, Scheduler, logger
from invenio_db import db
from sqlalchemy import and_

from invenio_jobs.models import Job, Run, Task
from invenio_jobs.models import Job, Run
from invenio_jobs.tasks import execute_run


Expand Down Expand Up @@ -49,27 +47,23 @@ class RunScheduler(Scheduler):
Entry = JobEntry
entries = {}

def __init__(self, *args: Any, **kwargs: Any) -> None:
"""Initialize the database scheduler."""
super().__init__(*args, **kwargs)

@property
def schedule(self):
    """Return the currently scheduled entries (a dict keyed by job ID)."""
    return self.entries

# Celery internal override
#
# Celery overrides
#
def setup_schedule(self):
    """Set up the initial schedule by syncing jobs from the database."""
    self.sync()

# Celery internal override
def reserve(self, entry):
    """Advance *entry* to its next execution time and store the result."""
    advanced = next(entry)
    self.schedule[entry.job.id] = advanced
    return advanced

# Celery internal override
def apply_entry(self, entry, producer=None):
"""Create and apply a JobEntry."""
with self.app.flask_app.app_context():
Expand All @@ -93,26 +87,24 @@ def apply_entry(self, entry, producer=None):
else:
logger.debug("%s sent.", entry.task)

# Celery internal override
def sync(self):
    """Sync active jobs that have a schedule from the DB into the scheduler."""
    # TODO: Should we also have a cleanup task for "stale" runs (e.g. a run
    # with status RUNNING/PENDING for more than an hour)?
    with self.app.flask_app.app_context():
        jobs = Job.query.filter(
            Job.active.is_(True),
            Job.schedule.isnot(None),
        )
        # Rebuild the entries from scratch, since some jobs might have
        # been deactivated since the last sync.
        self.entries = {job.id: JobEntry.from_job(job) for job in jobs}

#
# Helpers
#
def create_run(self, entry):
    """Create and persist a Run for the given JobEntry."""
    job = Job.query.get(entry.job.id)
    # Run.create renders the job's templated default args and fills in
    # the default queue.
    run = Run.create(job=job, task_id=uuid.uuid4())
    db.session.commit()
    return run
7 changes: 3 additions & 4 deletions invenio_jobs/services/services.py
Original file line number Diff line number Diff line change
Expand Up @@ -236,11 +236,10 @@ def create(self, identity, job_id, data, uow=None):
raise_errors=True,
)

valid_data.setdefault("queue", job.default_queue)
run = Run(
id=str(uuid.uuid4()),
run = Run.create(
job=job,
task_id=uuid.uuid4(),
id=str(uuid.uuid4()),
task_id=str(uuid.uuid4()),
started_by_id=identity.id,
status=RunStatusEnum.QUEUED,
**valid_data,
Expand Down
1 change: 1 addition & 0 deletions invenio_jobs/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
# under the terms of the MIT License; see LICENSE file for more details.

"""Tasks."""

from datetime import datetime, timezone

from celery import shared_task
Expand Down
43 changes: 43 additions & 0 deletions invenio_jobs/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
# -*- coding: utf-8 -*-
#
# Copyright (C) 2024 CERN.
#
# Invenio-Jobs is free software; you can redistribute it and/or modify it
# under the terms of the MIT License; see LICENSE file for more details.

"""Utilities."""

import ast

from jinja2.sandbox import SandboxedEnvironment

jinja_env = SandboxedEnvironment()


def eval_tpl_str(val, ctx):
    """Render *val* as a sandboxed Jinja template string against *ctx*.

    If the rendered text parses as a Python literal (via
    ``ast.literal_eval``), the parsed value is returned; otherwise the
    rendered string is returned unchanged.
    """
    rendered = jinja_env.from_string(val).render(**ctx)
    try:
        return ast.literal_eval(rendered)
    except Exception:
        # Not a Python literal — keep the rendered string as-is.
        return rendered


def walk_values(obj, transform_fn):
    """Recursively apply *transform_fn* to the leaf values of *obj*.

    Dicts and lists are mutated in place (nested dicts/lists are
    recursed into) and returned, so the function can be used both for
    its side effect and for its return value. A non-container *obj*
    cannot be mutated in place, so its transformed value is returned
    instead.
    """
    if isinstance(obj, dict):
        items = obj.items()
    elif isinstance(obj, list):
        items = enumerate(obj)
    else:
        return transform_fn(obj)

    for key, val in items:
        if isinstance(val, (dict, list)):
            walk_values(val, transform_fn)
        else:
            obj[key] = transform_fn(val)
    # Returning the container makes the API consistent with the scalar
    # branch above; callers relying only on in-place mutation are
    # unaffected.
    return obj
125 changes: 124 additions & 1 deletion tests/resources/test_resources.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@

"""Resource tests."""

import pdb
from unittest.mock import patch

from invenio_jobs.tasks import execute_run
Expand Down Expand Up @@ -429,3 +428,127 @@ def test_jobs_delete(db, client, jobs):
assert res.json["hits"]["total"] == 2
hits = res.json["hits"]["hits"]
assert all(j["id"] != jobs.simple.id for j in hits)


@patch.object(execute_run, "apply_async")
def test_job_template_args(mock_apply_async, app, db, client, user):
    """Jinja-templated default args are rendered when a run is created.

    The job stores raw template strings; each triggered run renders them
    (``arg1`` is an arithmetic expression, ``arg2`` uses the job context,
    ``kwarg1`` depends on the job's last run, so it is None on the first
    run and the first run's creation timestamp on the second).
    """
    client = user.login(client)
    job_payload = {
        "title": "Job with template args",
        "task": "tasks.mock_task",
        "default_args": {
            "arg1": "{{ 1 + 1 }}",
            "arg2": "{{ job.title | upper }}",
            "kwarg1": "{{ job.last_run.created.isoformat() if job.last_run else None }}",
        },
    }

    # Create a job
    res = client.post("/jobs", json=job_payload)
    assert res.status_code == 201
    job_id = res.json["id"]
    # The job itself keeps the raw, unrendered template strings.
    expected_job = {
        "id": job_id,
        "title": "Job with template args",
        "description": None,
        "active": True,
        "task": "tasks.mock_task",
        "default_queue": "celery",
        "default_args": {
            "arg1": "{{ 1 + 1 }}",
            "arg2": "{{ job.title | upper }}",
            "kwarg1": "{{ job.last_run.created.isoformat() if job.last_run else None }}",
        },
        "schedule": None,
        "last_run": None,
        "created": res.json["created"],
        "updated": res.json["updated"],
        "links": {
            "self": f"https://127.0.0.1:5000/api/jobs/{job_id}",
            "runs": f"https://127.0.0.1:5000/api/jobs/{job_id}/runs",
        },
    }
    assert res.json == expected_job

    # Create/trigger a run
    res = client.post(f"/jobs/{job_id}/runs")
    assert res.status_code == 201
    run_id = res.json["id"]
    # The run's args are the rendered template results; no previous run
    # exists yet, so kwarg1 renders to None.
    expected_run = {
        "id": run_id,
        "job_id": job_id,
        "task_id": res.json["task_id"],
        "started_by_id": int(user.id),
        "started_by": {
            "id": str(user.id),
            "username": user.username,
            "profile": user._user_profile,
            "links": {
                # "self": f"https://127.0.0.1:5000/api/users/{user.id}",
            },
            "identities": {},
            "is_current_user": True,
            "type": "user",
        },
        "started_at": res.json["started_at"],
        "finished_at": res.json["finished_at"],
        "status": "QUEUED",
        "message": None,
        "title": None,
        "args": {
            "arg1": 2,
            "arg2": "JOB WITH TEMPLATE ARGS",
            "kwarg1": None,
        },
        "queue": "celery",
        "created": res.json["created"],
        "updated": res.json["updated"],
        "links": {
            "self": f"https://127.0.0.1:5000/api/jobs/{job_id}/runs/{run_id}",
            "logs": f"https://127.0.0.1:5000/api/jobs/{job_id}/runs/{run_id}/logs",
            "stop": f"https://127.0.0.1:5000/api/jobs/{job_id}/runs/{run_id}/actions/stop",
        },
    }
    assert res.json == expected_run
    last_run_created = res.json["created"].replace("+00:00", "")

    # Trigger another run to test the kwarg1 template depending on the last run
    res = client.post(f"/jobs/{job_id}/runs")
    assert res.status_code == 201
    run_id = res.json["id"]
    expected_run = {
        "id": run_id,
        "job_id": job_id,
        "task_id": res.json["task_id"],
        "started_by_id": int(user.id),
        "started_by": {
            "id": str(user.id),
            "username": user.username,
            "profile": user._user_profile,
            "links": {
                # "self": f"https://127.0.0.1:5000/api/users/{user.id}",
            },
            "identities": {},
            "is_current_user": True,
            "type": "user",
        },
        "started_at": res.json["started_at"],
        "finished_at": res.json["finished_at"],
        "status": "QUEUED",
        "message": None,
        "title": None,
        "args": {
            "arg1": 2,
            "arg2": "JOB WITH TEMPLATE ARGS",
            # Rendered from the first run's creation timestamp.
            "kwarg1": last_run_created,
        },
        "queue": "celery",
        "created": res.json["created"],
        "updated": res.json["updated"],
        "links": {
            "self": f"https://127.0.0.1:5000/api/jobs/{job_id}/runs/{run_id}",
            "logs": f"https://127.0.0.1:5000/api/jobs/{job_id}/runs/{run_id}/logs",
            "stop": f"https://127.0.0.1:5000/api/jobs/{job_id}/runs/{run_id}/actions/stop",
        },
    }
    assert res.json == expected_run

0 comments on commit 91ec6cc

Please sign in to comment.