-
Notifications
You must be signed in to change notification settings - Fork 5
/
stateful_saga.py
120 lines (91 loc) · 4.85 KB
/
stateful_saga.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
# Names exported by `from <this module> import *`: the public API of this module.
__all__ = ['AbstractSagaStateRepository', 'StatefulSaga']
import abc
from celery import Celery, Task
from .utils import success_task_name, failure_task_name
from .base_saga import BaseSaga, BaseStep
from .async_saga import AsyncSaga, AsyncStep
class AbstractSagaStateRepository(abc.ABC):
    """
    Storage interface that StatefulSaga uses to read and record saga state.

    Every method is abstract, so the class cannot be instantiated until a
    subclass overrides all of them.
    """

    @abc.abstractmethod
    def get_saga_state_by_id(self, saga_id: int) -> object:
        """Return the state object for saga ``saga_id`` (StatefulSaga caches the result)."""
        raise NotImplementedError()

    @abc.abstractmethod
    def update_status(self, saga_id: int, status: str) -> object:
        """Record ``status`` (e.g. ``'<step>.running'``) as the status of saga ``saga_id``."""
        raise NotImplementedError()

    @abc.abstractmethod
    def update(self, saga_id: int, **fields_to_update: str) -> object:
        """Record the given ``fields_to_update`` on the state of saga ``saga_id``."""
        raise NotImplementedError()

    @abc.abstractmethod
    def on_step_failure(self, saga_id: int, failed_step: BaseStep, initial_failure_payload: dict) -> object:
        """
        Hook called when ``failed_step`` of saga ``saga_id`` fails, before compensation.

        Unlike the other abstract methods, the base body does nothing and
        returns ``None``, so overrides may safely call ``super()``.
        """
class StatefulSaga(AsyncSaga, abc.ABC):
    """
    Async saga that reports each state transition to a saga state repository.

    Step transitions (``<step>.running``, ``.succeeded``, ``.failed``,
    ``.compensating``, ``.compensated``) and the final outcome of the saga
    (``succeeded`` / ``failed``) are written via
    ``saga_state_repository.update_status``.

    Note: this class assumes the sqlalchemy-mixins library is used; treat it
    as an example rather than a drop-in implementation.
    """

    saga_state_repository: AbstractSagaStateRepository = None
    _saga_state = None  # cached SQLAlchemy instance, filled lazily by `saga_state`

    def __init__(self, saga_state_repository: AbstractSagaStateRepository, celery_app: Celery, saga_id: int):
        """
        :param saga_state_repository: repository used to read and record saga state
        :param celery_app: Celery application, passed on to AsyncSaga
        :param saga_id: identifier of the saga instance being driven
        """
        self.saga_state_repository = saga_state_repository
        super().__init__(celery_app, saga_id)

    @property
    def saga_state(self):
        """Return this saga's state object, fetching it from the repository only once."""
        # Compare with None instead of relying on truthiness: a valid state
        # object that happens to be falsy must be cached, not re-fetched on
        # every access.
        if self._saga_state is None:
            self._saga_state = self.saga_state_repository.get_saga_state_by_id(self.saga_id)
        return self._saga_state

    def run_step(self, step: BaseStep):
        """Record ``<step>.running``, then delegate to the parent implementation."""
        self.saga_state_repository.update_status(self.saga_id, status=f'{step.name}.running')
        super().run_step(step)

    def compensate_step(self, step: BaseStep, initial_failure_payload: dict):
        """
        Compensate ``step``, recording ``.compensating`` before and ``.compensated`` after.

        If the parent's compensation raises, ``.compensated`` is not recorded.
        """
        self.saga_state_repository.update_status(self.saga_id, status=f'{step.name}.compensating')
        super().compensate_step(step, initial_failure_payload)
        self.saga_state_repository.update_status(self.saga_id, status=f'{step.name}.compensated')

    def on_async_step_success(self, step: AsyncStep, *args, **kwargs):
        """
        Record ``<step>.succeeded``, then continue with the parent's success handling.

        The Celery success handlers registered by
        ``register_success_handler_for_step`` call this method directly, so the
        status must be recorded here (not only in ``on_step_success``) to be
        written at all.
        """
        self.saga_state_repository.update_status(self.saga_id, status=f'{step.name}.succeeded')
        return super().on_async_step_success(step, *args, **kwargs)

    def on_async_step_failure(self, failed_step: AsyncStep, payload: dict):
        """
        Record ``<step>.failed``, then continue with the parent's failure handling.

        Called directly by the Celery failure handlers registered by
        ``register_failure_handler_for_step``.
        """
        self.saga_state_repository.update_status(self.saga_id, status=f'{failed_step.name}.failed')
        return super().on_async_step_failure(failed_step, payload)

    def on_step_success(self, step: AsyncStep, *args, **kwargs):
        """Backward-compatible alias for ``on_async_step_success``."""
        return self.on_async_step_success(step, *args, **kwargs)

    def on_step_failure(self, failed_step: AsyncStep, payload: dict):
        """Backward-compatible alias for ``on_async_step_failure``."""
        return self.on_async_step_failure(failed_step, payload)

    def on_saga_success(self):
        """Run the parent's success handling, then record the saga as ``succeeded``."""
        super().on_saga_success()
        self.saga_state_repository.update_status(self.saga_id, 'succeeded')

    def on_saga_failure(self, *args, **kwargs):
        """Run the parent's failure handling, then record the saga as ``failed``."""
        super().on_saga_failure(*args, **kwargs)
        self.saga_state_repository.update_status(self.saga_id, 'failed')

    def compensate(self, failed_step: BaseStep,
                   initial_failure_payload: dict = None):
        """
        Notify the repository of the failure, then run the parent's compensation.

        :param failed_step: the step whose failure triggered compensation
        :param initial_failure_payload: failure details forwarded unchanged
        """
        self.saga_state_repository.on_step_failure(self.saga_id, failed_step, initial_failure_payload)
        super().compensate(failed_step, initial_failure_payload)

    @classmethod
    def register_async_step_handlers(cls,
                                     saga_state_repository: AbstractSagaStateRepository,
                                     celery_app: Celery):
        """Register Celery success and failure handler tasks for every async step of this saga."""
        # A throwaway instance is built only to read `async_steps`; its
        # repository, app and id are never used.
        # noinspection PyTypeChecker
        dummy_saga_instance = cls(None, None, None)
        for step in dummy_saga_instance.async_steps:
            cls.register_success_handler_for_step(saga_state_repository,
                                                  celery_app, step)
            cls.register_failure_handler_for_step(saga_state_repository,
                                                  celery_app, step)

    @classmethod
    def register_success_handler_for_step(cls,
                                          saga_state_repository: AbstractSagaStateRepository,
                                          celery_app: Celery, step: AsyncStep):
        """
        Register a bound Celery task named ``success_task_name(step.base_task_name)``.

        The task rebuilds the saga for ``saga_id``, resolves the step from the
        task's own name and forwards ``payload`` to ``on_async_step_success``.
        """
        def on_success_handler(celery_task: Task, saga_id: int, payload: dict):
            saga = cls(saga_state_repository=saga_state_repository,
                       celery_app=celery_app, saga_id=saga_id)
            step_ = saga.get_async_step_by_success_task_name(celery_task.name)
            saga.on_async_step_success(step_, payload)

        celery_app.task(
            name=success_task_name(step.base_task_name),
            bind=True
        )(on_success_handler)

    @classmethod
    def register_failure_handler_for_step(cls,
                                          saga_state_repository: AbstractSagaStateRepository,
                                          celery_app: Celery, step: AsyncStep):
        """
        Register a bound Celery task named ``failure_task_name(step.base_task_name)``.

        The task rebuilds the saga for ``saga_id``, resolves the step from the
        task's own name and forwards ``payload`` to ``on_async_step_failure``.
        """
        def on_failure_handler(celery_task: Task, saga_id: int, payload: dict):
            saga = cls(saga_state_repository=saga_state_repository,
                       celery_app=celery_app, saga_id=saga_id)
            step_ = saga.get_async_step_by_failure_task_name(celery_task.name)
            saga.on_async_step_failure(step_, payload)

        celery_app.task(
            name=failure_task_name(step.base_task_name),
            bind=True
        )(on_failure_handler)