Skip to content

Commit 18655ae

Browse files
committed
fix: use threads for callbacks and avoid session concurrency
1 parent 44c5972 commit 18655ae

7 files changed

Lines changed: 192 additions & 15 deletions

File tree

src/pypsa_app/backend/api/routes/runs.py

Lines changed: 41 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
import logging
44
import re
5+
import threading
56
import urllib.parse
67
import uuid
78
from pathlib import PurePosixPath
@@ -33,8 +34,15 @@
3334
RunSummary,
3435
)
3536
from pypsa_app.backend.services.backend_registry import backend_registry
37+
from pypsa_app.backend.services.callback import (
38+
_build_payload,
39+
post_callback_sync,
40+
)
3641
from pypsa_app.backend.services.run import SnakedispatchClient, SnakedispatchError
37-
from pypsa_app.backend.services.sync import SYNCED_STATUSES, sync_run_from_job
42+
from pypsa_app.backend.services.sync import (
43+
SYNCED_STATUSES,
44+
sync_run_from_job,
45+
)
3846
from pypsa_app.backend.settings import settings
3947

4048
router = APIRouter()
@@ -125,7 +133,7 @@ def create_run(
125133

126134
payload = body.model_dump(
127135
exclude_none=True,
128-
exclude={"backend_id", "import_networks", "cache"},
136+
exclude={"backend_id", "import_networks", "cache", "callback_url"},
129137
)
130138
if body.cache:
131139
payload["cache_key"] = body.cache.key
@@ -142,6 +150,7 @@ def create_run(
142150
extra_files=body.extra_files,
143151
cache=body.cache.model_dump() if body.cache else None,
144152
import_networks=body.import_networks,
153+
callback_url=str(body.callback_url) if body.callback_url else None,
145154
status=RunStatus(result.get("status", "PENDING")),
146155
)
147156
db.add(run)
@@ -281,8 +290,18 @@ def get_run(
281290
if client:
282291
try:
283292
job = client.get_job(str(run_id))
284-
sync_run_from_job(run, job, db)
293+
needs_callback = sync_run_from_job(run, job, db)
285294
db.commit()
295+
if needs_callback and run.callback_url:
296+
# TODO: replace with proper async callback or
297+
# FastAPI BackgroundTasks.
298+
url = str(run.callback_url)
299+
payload = _build_payload(run)
300+
threading.Thread(
301+
target=post_callback_sync,
302+
args=(url, payload),
303+
daemon=True,
304+
).start()
286305
except SnakedispatchError:
287306
pass
288307

@@ -378,13 +397,31 @@ def cancel_run(
378397
sd_client = _get_client_for_run(run)
379398
try:
380399
result = sd_client.cancel_job(str(run_id))
381-
sync_run_from_job(run, result, db)
400+
needs_callback = sync_run_from_job(run, result, db)
382401
db.commit()
402+
if needs_callback and run.callback_url:
403+
# TODO: replace with proper async callback or
404+
# FastAPI BackgroundTasks.
405+
url = str(run.callback_url)
406+
payload = _build_payload(run)
407+
threading.Thread(
408+
target=post_callback_sync, args=(url, payload), daemon=True
409+
).start()
383410
except SnakedispatchError as e:
384411
if e.status_code in (404, 409):
385412
if run.status not in SYNCED_STATUSES:
386413
run.status = RunStatus.CANCELLED
387414
db.commit()
415+
if run.callback_url:
416+
# TODO: replace with proper async callback or
417+
# FastAPI BackgroundTasks.
418+
url = str(run.callback_url)
419+
payload = _build_payload(run)
420+
threading.Thread(
421+
target=post_callback_sync,
422+
args=(url, payload),
423+
daemon=True,
424+
).start()
388425
else:
389426
raise
390427

src/pypsa_app/backend/models.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -294,6 +294,8 @@ class Run(Base):
294294
extra_files: Mapped[Any | None] = mapped_column(JSON)
295295
cache: Mapped[Any | None] = mapped_column(JSON)
296296

297+
callback_url: Mapped[str | None] = mapped_column(String(512))
298+
297299
# Job metadata (synced from Snakedispatch)
298300
git_ref: Mapped[str | None] = mapped_column(String(255))
299301
git_sha: Mapped[str | None] = mapped_column(String(40))

src/pypsa_app/backend/schemas/run.py

Lines changed: 28 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,13 +2,15 @@
22

33
import uuid
44
from datetime import datetime
5+
from urllib.parse import urlparse
56

6-
from pydantic import BaseModel, ConfigDict, Field
7+
from pydantic import BaseModel, ConfigDict, Field, HttpUrl, field_validator
78

89
from pypsa_app.backend.models import RunStatus
910
from pypsa_app.backend.schemas.auth import UserPublicResponse
1011
from pypsa_app.backend.schemas.backend import BackendPublicResponse
1112
from pypsa_app.backend.schemas.common import PaginationMeta
13+
from pypsa_app.backend.settings import settings
1214

1315

1416
class RunCache(BaseModel):
@@ -46,6 +48,22 @@ class RunCreate(BaseModel):
4648
cache: RunCache | None = None
4749
import_networks: list[str] | None = None
4850
backend_id: uuid.UUID | None = None
51+
callback_url: HttpUrl | None = None
52+
53+
@field_validator("callback_url")
@classmethod
def _validate_callback_domain(cls, v: HttpUrl | None) -> HttpUrl | None:
    """Reject callback URLs whose host is outside the configured allow-list.

    A host is accepted when it equals an allowed domain or is a subdomain of
    one. ``None`` (no callback requested) passes through untouched.

    Raises:
        ValueError: If no domains are configured, or the host matches none.
    """
    if v is not None:
        permitted = settings.resolved_callback_domains
        if not permitted:
            msg = "Callbacks are not enabled on this server"
            raise ValueError(msg)

        host = v.host or ""
        # A dotted suffix match means "subdomain of an allowed domain".
        subdomain_suffixes = tuple(f".{domain}" for domain in permitted)
        if host not in permitted and not host.endswith(subdomain_suffixes):
            msg = f"callback_url host '{host}' is not in the allowed domains"
            raise ValueError(msg)
    return v
4967

5068

5169
class RunSummary(BaseModel):
@@ -75,9 +93,18 @@ class RunResponse(RunSummary):
7593
extra_files: dict[str, str] | None = None
7694
cache: RunCache | None = None
7795
import_networks: list[str] | None = None
96+
callback_url: str | None = Field(None, validation_alias="callback_url")
7897
exit_code: int | None = None
7998
networks: list[RunNetworkSummary] = []
8099

100+
@field_validator("callback_url", mode="before")
@classmethod
def _redact_callback_url(cls, v: str | None) -> str | None:
    """Reduce a stored callback URL to ``scheme://host/***`` for API output.

    Only the scheme and hostname are kept; everything after the host is
    replaced by ``***``.

    Returns:
        ``None`` for an empty value, ``"***"`` when no scheme or hostname can
        be parsed out of the value, otherwise the redacted URL.
    """
    if not v:
        return None
    # Runs in "before" mode, so the raw input is not guaranteed to be a str.
    parsed = urlparse(str(v))
    if not parsed.scheme or not parsed.hostname:
        # ``hostname`` is None for values without a network location; avoid
        # rendering the literal text "None" into the response.
        return "***"
    return f"{parsed.scheme}://{parsed.hostname}/***"
107+
81108

82109
class RunListMeta(PaginationMeta):
83110
"""Extended pagination meta with run-specific filter options."""
src/pypsa_app/backend/services/callback.py

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
"""Callback helpers for notifying external systems of run status changes."""
2+
3+
import logging
4+
5+
import httpx
6+
7+
from pypsa_app.backend.models import Run
8+
9+
logger = logging.getLogger(__name__)
10+
11+
12+
def _build_payload(run: Run) -> dict:
    """Build the JSON body describing ``run`` for a callback request.

    Returns:
        A dict with the run's job id (as a string) under ``run_id`` and the
        value of its status enum under ``status``.
    """
    run_id = str(run.job_id)
    status = run.status.value
    return {"run_id": run_id, "status": status}
14+
15+
16+
def post_callback_sync(url: str, payload: dict) -> None:
    """POST ``payload`` as JSON to a callback URL (blocking, best-effort).

    Redirects are not followed and the request times out after 5 seconds.
    Any exception, including an HTTP error status from the receiver, is
    caught and logged as a warning rather than propagated.

    Args:
        url: Destination of the callback request.
        payload: JSON-serialisable body; ``run_id`` is read from it for logging.
    """
    try:
        response = httpx.post(
            url, json=payload, timeout=5.0, follow_redirects=False
        )
        # Without this check a receiver answering with an error status would
        # count as a delivered callback and leave no trace in the logs.
        response.raise_for_status()
    except Exception:
        logger.warning(
            "Callback failed for run %s to %s",
            payload.get("run_id"),
            url,
            exc_info=True,
        )
27+
28+
29+
def fire_callback_sync(run: Run) -> None:
    """Send the status callback for ``run`` (blocking).

    Does nothing when the run has no callback URL.
    """
    if run.callback_url:
        destination = str(run.callback_url)
        post_callback_sync(destination, _build_payload(run))
34+
35+
36+
async def fire_callback_async(url: str, payload: dict) -> None:
    """POST ``payload`` as JSON to a callback URL (async, best-effort).

    Used by the background sync loop. Redirects are not followed and the
    request times out after 5 seconds. Any exception, including an HTTP error
    status from the receiver, is caught and logged as a warning rather than
    propagated.

    Args:
        url: Destination of the callback request.
        payload: JSON-serialisable body; ``run_id`` is read from it for logging.
    """
    try:
        async with httpx.AsyncClient() as client:
            response = await client.post(
                url, json=payload, timeout=5.0, follow_redirects=False
            )
            # Without this check a receiver answering with an error status
            # would count as a delivered callback and never be logged.
            response.raise_for_status()
    except Exception:
        logger.warning(
            "Callback failed for run %s to %s",
            payload.get("run_id"),
            url,
            exc_info=True,
        )

src/pypsa_app/backend/services/sync.py

Lines changed: 48 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -12,10 +12,14 @@
1212
from pypsa_app.backend.database import SessionLocal
1313
from pypsa_app.backend.models import Run, RunStatus
1414
from pypsa_app.backend.services.backend_registry import backend_registry
15+
from pypsa_app.backend.services.callback import fire_callback_async
1516
from pypsa_app.backend.tasks import import_run_outputs_task
1617

1718
logger = logging.getLogger(__name__)
1819

20+
# Hold references to fire-and-forget callback tasks to prevent garbage collection.
21+
_background_tasks: set[asyncio.Task] = set()
22+
1923
# Statuses where the remote executor is done, no need to sync from Snakedispatch
2024
SYNCED_STATUSES = {
2125
RunStatus.UPLOADING,
@@ -41,8 +45,16 @@
4145
]
4246

4347

44-
def sync_run_from_job(run: Run, job: dict, db: Session) -> None:
45-
"""Update a Run record from a Snakedispatch response dict."""
48+
_CALLBACK_STATUSES = SYNCED_STATUSES - {RunStatus.UPLOADING}
49+
50+
51+
def sync_run_from_job(run: Run, job: dict, db: Session) -> bool:
52+
"""Update a Run record from a Snakedispatch response dict.
53+
54+
Returns:
55+
True if a callback should be fired after the transaction commits.
56+
"""
57+
old_status = run.status
4658
changed = False
4759
for field in _SYNC_FIELDS:
4860
new_val = job.get(field)
@@ -65,7 +77,7 @@ def sync_run_from_job(run: Run, job: dict, db: Session) -> None:
6577
run.status = RunStatus.UPLOADING
6678
db.flush()
6779
import_run_outputs_task.apply_async(args=(str(run.job_id),))
68-
return
80+
return False
6981
if completed_with_import_pending:
7082
run.status = RunStatus.COMPLETED
7183
changed = True
@@ -76,38 +88,65 @@ def sync_run_from_job(run: Run, job: dict, db: Session) -> None:
7688
if changed:
7789
db.flush()
7890

91+
return run.status in _CALLBACK_STATUSES and old_status not in _CALLBACK_STATUSES
92+
93+
94+
def sync_non_terminal_runs() -> list[dict]:
    """Refresh every run whose status is not in ``SYNCED_STATUSES``.

    Each registered backend is polled once; its runs are updated and committed
    together. A failure for one backend is rolled back and logged, and the
    remaining backends are still processed.

    Returns:
        List of callback dicts ``{"url": ..., "payload": ...}`` to be fired
        by the async caller after the DB session is closed.
    """
    pending_callbacks: list[dict] = []
    session = SessionLocal()
    try:
        open_runs = (
            session.query(Run).filter(Run.status.notin_(SYNCED_STATUSES)).all()
        )
        if not open_runs:
            return pending_callbacks

        for backend_id, client in backend_registry.all_clients().items():
            owned_runs = [r for r in open_runs if r.backend_id == backend_id]
            if not owned_runs:
                continue
            try:
                remote_jobs = {j["job_id"]: j for j in client.list_jobs()}
                to_notify: list[Run] = []
                for run in owned_runs:
                    job = remote_jobs.get(str(run.job_id))
                    if job and sync_run_from_job(run, job, session):
                        to_notify.append(run)
                session.commit()
                # Callbacks are only queued once the new status is committed.
                for run in to_notify:
                    if not run.callback_url:
                        continue
                    pending_callbacks.append(
                        {
                            "url": str(run.callback_url),
                            "payload": {
                                "run_id": str(run.job_id),
                                "status": run.status.value,
                            },
                        }
                    )
            except Exception:
                session.rollback()
                logger.warning("Sync failed for backend %s", backend_id, exc_info=True)
    finally:
        session.close()
    return pending_callbacks
104137

105138

106139
async def run_sync_loop(interval: float = 10.0) -> None:
    """Run the run-status sync forever, sleeping ``interval`` seconds per cycle.

    The blocking sync executes in a worker thread; the callbacks it returns
    are then scheduled as fire-and-forget tasks. A failing cycle is logged and
    the loop continues.
    """
    while True:
        await asyncio.sleep(interval)
        try:
            pending = await asyncio.to_thread(sync_non_terminal_runs)
            for entry in pending:
                request = fire_callback_async(entry["url"], entry["payload"])
                job = asyncio.create_task(request)
                # Keep a reference until the task finishes, then drop it.
                _background_tasks.add(job)
                job.add_done_callback(_background_tasks.discard)
        except Exception:
            logger.warning("Background run sync failed", exc_info=True)

src/pypsa_app/backend/settings.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -97,6 +97,26 @@ def networks_path(self) -> Path:
9797
description="Interval in seconds between background Snakedispatch sync cycles",
9898
json_schema_extra={"category": "Runs"},
9999
)
100+
callback_url_allowed_domains: str = Field(
101+
default="",
102+
description=(
103+
"Comma-separated list of allowed domains for run callback URLs "
104+
"(e.g. hooks.myorg.dev,example.com). "
105+
"Callbacks are rejected unless the host matches. "
106+
"Empty disables callbacks entirely."
107+
),
108+
json_schema_extra={"category": "Runs"},
109+
)
110+
111+
@property
def resolved_callback_domains(self) -> list[str]:
    """Parse CALLBACK_URL_ALLOWED_DOMAINS into a list of domain strings.

    Entries are split on commas, stripped of surrounding whitespace and
    lower-cased; blank entries are dropped. Lower-casing makes the allow-list
    case-insensitive, so an entry like ``Example.com`` matches a lower-case
    host.

    Returns:
        The configured domains in their original order, or an empty list when
        the setting is empty.
    """
    if not self.callback_url_allowed_domains:
        return []
    return [
        name.lower()
        for entry in self.callback_url_allowed_domains.split(",")
        if (name := entry.strip())
    ]
119+
100120
snakedispatch_backends: str | None = Field(
101121
default=None,
102122
description=(

0 commit comments

Comments
 (0)