Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion durabletask/entities/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,9 @@
from durabletask.entities.entity_lock import EntityLock
from durabletask.entities.entity_context import EntityContext
from durabletask.entities.entity_metadata import EntityMetadata
from durabletask.entities.entity_operation_failed_exception import EntityOperationFailedException

__all__ = ["EntityInstanceId", "DurableEntity", "EntityLock", "EntityContext", "EntityMetadata"]
__all__ = ["EntityInstanceId", "DurableEntity", "EntityLock", "EntityContext", "EntityMetadata",
"EntityOperationFailedException"]

PACKAGE_NAME = "durabletask.entities"
12 changes: 9 additions & 3 deletions durabletask/entities/entity_instance_id.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
class EntityInstanceId:
def __init__(self, entity: str, key: str):
self.entity = entity
if not entity or not key:
raise ValueError("Entity name and key cannot be empty.")
if "@" in key:
raise ValueError("Entity key cannot contain '@' symbol.")
Comment thread
halspang marked this conversation as resolved.
Outdated
self.entity = entity.lower()
Comment thread
andystaples marked this conversation as resolved.
self.key = key

def __str__(self) -> str:
Expand Down Expand Up @@ -35,8 +39,10 @@ def parse(entity_id: str) -> "EntityInstanceId":
ValueError
If the input string is not in the correct format.
"""
if not entity_id.startswith("@"):
raise ValueError("Entity ID must start with '@'.")
try:
_, entity, key = entity_id.split("@", 2)
return EntityInstanceId(entity=entity, key=key)
except ValueError as ex:
raise ValueError(f"Invalid entity ID format: {entity_id}", ex)
raise ValueError(f"Invalid entity ID format: {entity_id}") from ex
return EntityInstanceId(entity=entity, key=key)
15 changes: 15 additions & 0 deletions durabletask/entities/entity_operation_failed_exception.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
from durabletask.internal.orchestrator_service_pb2 import TaskFailureDetails
from durabletask.entities.entity_instance_id import EntityInstanceId


class EntityOperationFailedException(Exception):
"""Exception raised when an operation on an Entity Function fails."""

def __init__(self, entity_instance_id: EntityInstanceId, operation_name: str, failure_details: TaskFailureDetails) -> None:
super().__init__()
self.entity_instance_id = entity_instance_id
self.operation_name = operation_name
self.failure_details = failure_details

def __str__(self) -> str:
return f"Operation '{self.operation_name}' on entity '{self.entity_instance_id}' failed with error: {self.failure_details.errorMessage}"
12 changes: 12 additions & 0 deletions durabletask/internal/json_encode_output_exception.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
from typing import Any


class JsonEncodeOutputException(Exception):
"""Custom exception type used to indicate that an orchestration result could not be JSON-encoded."""

def __init__(self, problem_object: Any):
super().__init__()
self.problem_object = problem_object

def __str__(self) -> str:
return f"The orchestration result could not be encoded. Object details: {self.problem_object}"
Comment thread
halspang marked this conversation as resolved.
54 changes: 37 additions & 17 deletions durabletask/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,12 @@
import grpc
from google.protobuf import empty_pb2

from durabletask.entities.entity_operation_failed_exception import EntityOperationFailedException
from durabletask.internal import helpers
from durabletask.internal.entity_state_shim import StateShim
from durabletask.internal.helpers import new_timestamp
from durabletask.entities import DurableEntity, EntityLock, EntityInstanceId, EntityContext
from durabletask.internal.json_encode_output_exception import JsonEncodeOutputException
from durabletask.internal.orchestration_entity_context import OrchestrationEntityContext
from durabletask.internal.proto_task_hub_sidecar_service_stub import ProtoTaskHubSidecarServiceStub
import durabletask.internal.helpers as ph
Expand Down Expand Up @@ -141,14 +143,12 @@ class _Registry:
orchestrators: dict[str, task.Orchestrator]
activities: dict[str, task.Activity]
entities: dict[str, task.Entity]
entity_instances: dict[str, DurableEntity]
Comment thread
halspang marked this conversation as resolved.
versioning: Optional[VersioningOptions] = None

def __init__(self):
self.orchestrators = {}
self.activities = {}
self.entities = {}
self.entity_instances = {}

def add_orchestrator(self, fn: task.Orchestrator[TInput, TOutput]) -> str:
if fn is None:
Expand Down Expand Up @@ -201,6 +201,7 @@ def add_entity(self, fn: task.Entity, name: Optional[str] = None) -> str:
def add_named_entity(self, name: str, fn: task.Entity) -> None:
Comment thread
halspang marked this conversation as resolved.
if not name:
raise ValueError("A non-empty entity name is required.")
name = name.lower()
Comment thread
andystaples marked this conversation as resolved.
if name in self.entities:
raise ValueError(f"A '{name}' entity already exists.")

Expand Down Expand Up @@ -829,7 +830,7 @@ def __init__(self, instance_id: str, registry: _Registry):
self._pending_actions: dict[int, pb.OrchestratorAction] = {}
self._pending_tasks: dict[int, task.CompletableTask] = {}
# Maps entity ID to task ID
self._entity_task_id_map: dict[str, tuple[EntityInstanceId, int]] = {}
self._entity_task_id_map: dict[str, tuple[EntityInstanceId, str, int]] = {}
self._entity_lock_task_id_map: dict[str, tuple[EntityInstanceId, int]] = {}
# Maps criticalSectionId to task ID
self._entity_lock_id_map: dict[str, int] = {}
Expand Down Expand Up @@ -902,7 +903,10 @@ def set_complete(
self._result = result
result_json: Optional[str] = None
if result is not None:
result_json = result if is_result_encoded else shared.to_json(result)
try:
result_json = result if is_result_encoded else shared.to_json(result)
except (ValueError, TypeError):
result_json = shared.to_json(str(JsonEncodeOutputException(result)))
action = ph.new_complete_orchestration_action(
self.next_sequence_number(), status, result_json
)
Expand Down Expand Up @@ -1606,7 +1610,7 @@ def process_event(
raise TypeError("Unexpected sub-orchestration task type")
elif event.HasField("eventRaised"):
if event.eventRaised.name in ctx._entity_task_id_map:
entity_id, task_id = ctx._entity_task_id_map.get(event.eventRaised.name, (None, None))
entity_id, operation, task_id = ctx._entity_task_id_map.get(event.eventRaised.name, (None, None, None))
self._handle_entity_event_raised(ctx, event, entity_id, task_id, False)
elif event.eventRaised.name in ctx._entity_lock_task_id_map:
entity_id, task_id = ctx._entity_lock_task_id_map.get(event.eventRaised.name, (None, None))
Expand Down Expand Up @@ -1680,9 +1684,10 @@ def process_event(
)
try:
entity_id = EntityInstanceId.parse(event.entityOperationCalled.targetInstanceId.value)
operation = event.entityOperationCalled.operation
except ValueError:
raise RuntimeError(f"Could not parse entity ID from targetInstanceId '{event.entityOperationCalled.targetInstanceId.value}'")
ctx._entity_task_id_map[event.entityOperationCalled.requestId] = (entity_id, entity_call_id)
ctx._entity_task_id_map[event.entityOperationCalled.requestId] = (entity_id, operation, entity_call_id)
elif event.HasField("entityOperationSignaled"):
# This history event confirms that the entity signal was successfully scheduled.
# Remove the entityOperationSignaled event from the pending action list so we don't schedule it
Expand Down Expand Up @@ -1743,7 +1748,7 @@ def process_event(
ctx.resume()
elif event.HasField("entityOperationCompleted"):
request_id = event.entityOperationCompleted.requestId
entity_id, task_id = ctx._entity_task_id_map.pop(request_id, (None, None))
entity_id, operation, task_id = ctx._entity_task_id_map.pop(request_id, (None, None, None))
if not entity_id:
raise RuntimeError(f"Could not parse entity ID from request ID '{request_id}'")
if not task_id:
Expand All @@ -1762,10 +1767,29 @@ def process_event(
entity_task.complete(result)
ctx.resume()
elif event.HasField("entityOperationFailed"):
if not ctx.is_replaying:
self._logger.info(f"{ctx.instance_id}: Entity operation failed.")
self._logger.info(f"Data: {json.dumps(event.entityOperationFailed)}")
pass
request_id = event.entityOperationFailed.requestId
entity_id, operation, task_id = ctx._entity_task_id_map.pop(request_id, (None, None, None))
if not entity_id:
raise RuntimeError(f"Could not parse entity ID from request ID '{request_id}'")
if operation is None:
raise RuntimeError(f"Could not parse operation name from request ID '{request_id}'")
if not task_id:
raise RuntimeError(f"Could not find matching task ID for entity operation with request ID '{request_id}'")
entity_task = ctx._pending_tasks.pop(task_id, None)
if not entity_task:
if not ctx.is_replaying:
self._logger.warning(
f"{ctx.instance_id}: Ignoring unexpected entityOperationFailed event with request ID = {request_id}."
)
return
failure = EntityOperationFailedException(
entity_id,
operation,
event.entityOperationFailed.failureDetails
)
ctx._entity_context.recover_lock_after_call(entity_id)
entity_task.fail(str(failure), failure)
ctx.resume()
elif event.HasField("orchestratorCompleted"):
# Added in Functions only (for some reason) and does not affect orchestrator flow
pass
Expand All @@ -1777,7 +1801,7 @@ def process_event(
if action and action.HasField("sendEntityMessage"):
if action.sendEntityMessage.HasField("entityOperationCalled"):
entity_id, event_id = self._parse_entity_event_sent_input(event)
ctx._entity_task_id_map[event_id] = (entity_id, event.eventId)
ctx._entity_task_id_map[event_id] = (entity_id, action.sendEntityMessage.entityOperationCalled.operation, event.eventId)
elif action.sendEntityMessage.HasField("entityLockRequested"):
entity_id, event_id = self._parse_entity_event_sent_input(event)
ctx._entity_lock_task_id_map[event_id] = (entity_id, event.eventId)
Expand Down Expand Up @@ -1937,11 +1961,7 @@ def execute(
ctx = EntityContext(orchestration_id, operation, state, entity_id)

if isinstance(fn, type) and issubclass(fn, DurableEntity):
if self._registry.entity_instances.get(str(entity_id), None):
entity_instance = self._registry.entity_instances[str(entity_id)]
else:
entity_instance = fn()
self._registry.entity_instances[str(entity_id)] = entity_instance
entity_instance = fn()
if not hasattr(entity_instance, operation):
raise AttributeError(f"Entity '{entity_id}' does not have operation '{operation}'")
method = getattr(entity_instance, operation)
Expand Down
Empty file.
Loading
Loading