diff --git a/agentex/database/migrations/alembic/versions/2026_05_21_1508_add_task_creator_columns_a1f73ada66c5.py b/agentex/database/migrations/alembic/versions/2026_05_21_1508_add_task_creator_columns_a1f73ada66c5.py new file mode 100644 index 00000000..50b21028 --- /dev/null +++ b/agentex/database/migrations/alembic/versions/2026_05_21_1508_add_task_creator_columns_a1f73ada66c5.py @@ -0,0 +1,68 @@ +"""add_task_creator_columns + +Revision ID: a1f73ada66c5 +Revises: 6c942325c828 +Create Date: 2026-05-21 15:08:51.441535 + +""" + +from collections.abc import Sequence + +import sqlalchemy as sa +from alembic import op + +# revision identifiers, used by Alembic. +revision: str = "a1f73ada66c5" +down_revision: str | None = "6c942325c828" +branch_labels: str | Sequence[str] | None = None +depends_on: str | Sequence[str] | None = None + + +def upgrade() -> None: + op.add_column("tasks", sa.Column("creator_user_id", sa.String(), nullable=True)) + op.add_column( + "tasks", sa.Column("creator_service_account_id", sa.String(), nullable=True) + ) + with op.get_context().autocommit_block(): + # Partial indexes — the columns are NULL for all pre-migration rows and + # remain majority-NULL for legacy traffic indefinitely. A WHERE clause + # keeps the indexes scoped to populated rows so we don't pay storage or + # write amplification for NULL entries. + op.execute( + "CREATE INDEX CONCURRENTLY IF NOT EXISTS ix_tasks_creator_user_id " + "ON tasks (creator_user_id) WHERE creator_user_id IS NOT NULL" + ) + op.execute( + "CREATE INDEX CONCURRENTLY IF NOT EXISTS ix_tasks_creator_service_account_id " + "ON tasks (creator_service_account_id) " + "WHERE creator_service_account_id IS NOT NULL" + ) + # Add the CHECK as NOT VALID so the brief ACCESS EXCLUSIVE lock doesn't have to + # wait on an existence scan, then VALIDATE under SHARE UPDATE EXCLUSIVE which + # doesn't block concurrent reads/writes. `tasks` is a high-write table; even the + # short ACCESS EXCLUSIVE held during a vanilla CHECK addition queues behind + # in-flight transactions and blocks readers until it releases. + # + # Each ALTER must commit before the next one runs — otherwise Alembic's + # default single-transaction wrapper holds the ACCESS EXCLUSIVE from + # `NOT VALID` straight through the `VALIDATE` scan, collapsing the + # two-statement split into one long blocking window. Use autocommit_block + # to release locks between statements. + with op.get_context().autocommit_block(): + op.execute( + "ALTER TABLE tasks ADD CONSTRAINT ck_tasks_at_most_one_creator " + "CHECK ((creator_user_id IS NULL) OR (creator_service_account_id IS NULL)) " + "NOT VALID" + ) + op.execute("ALTER TABLE tasks VALIDATE CONSTRAINT ck_tasks_at_most_one_creator") + + +def downgrade() -> None: + op.drop_constraint("ck_tasks_at_most_one_creator", "tasks", type_="check") + with op.get_context().autocommit_block(): + op.execute( + "DROP INDEX CONCURRENTLY IF EXISTS ix_tasks_creator_service_account_id" + ) + op.execute("DROP INDEX CONCURRENTLY IF EXISTS ix_tasks_creator_user_id") + op.drop_column("tasks", "creator_service_account_id") + op.drop_column("tasks", "creator_user_id") diff --git a/agentex/database/migrations/migration_history.txt b/agentex/database/migrations/migration_history.txt index b18d86d8..5f495069 100644 --- a/agentex/database/migrations/migration_history.txt +++ b/agentex/database/migrations/migration_history.txt @@ -1,4 +1,5 @@ -a9959ebcbe98 -> 6c942325c828 (head), adding task cleaned at +6c942325c828 -> a1f73ada66c5 (head), add_task_creator_columns +a9959ebcbe98 -> 6c942325c828, adding task cleaned at e9c4ff9e6542 -> a9959ebcbe98, finalize_spans_task_id 9ff3ee32c81b -> e9c4ff9e6542, add_tasks_metadata_gin_index 57c5ed4f59ae -> 9ff3ee32c81b, uppercase deployment status enum labels diff --git a/agentex/src/adapters/authorization/adapter_agentex_authz_proxy.py b/agentex/src/adapters/authorization/adapter_agentex_authz_proxy.py index 0e44479d..3ea465a8 100644 --- a/agentex/src/adapters/authorization/adapter_agentex_authz_proxy.py +++ b/agentex/src/adapters/authorization/adapter_agentex_authz_proxy.py @@ -85,6 +85,35 @@ async def list_resources( ) return response["items"] + async def register_resource( + self, + principal: AgentexAuthPrincipalContext, + resource: AgentexResource, + parent: AgentexResource | None = None, + ) -> None: + payload: dict = { + "principal": principal, + "resource": resource.model_dump(), + } + if parent is not None: + payload["parent"] = parent.model_dump() + await HttpRequestHandler.post_with_error_handling( + self.agentex_auth_url, "/v1/authz/register", json=payload + ) + + async def deregister_resource( + self, + principal: AgentexAuthPrincipalContext, + resource: AgentexResource, + ) -> None: + payload = { + "principal": principal, + "resource": resource.model_dump(), + } + await HttpRequestHandler.post_with_error_handling( + self.agentex_auth_url, "/v1/authz/deregister", json=payload + ) + DAgentexAuthorization = Annotated[ AgentexAuthorizationProxy, Depends(AgentexAuthorizationProxy) diff --git a/agentex/src/adapters/authorization/port.py b/agentex/src/adapters/authorization/port.py index 80c05200..44b1d776 100644 --- a/agentex/src/adapters/authorization/port.py +++ b/agentex/src/adapters/authorization/port.py @@ -49,3 +49,31 @@ async def list_resources( filter_operation: AuthorizedOperationType = AuthorizedOperationType.read, ) -> Iterable[str]: """List resource_ids for a given principal""" + + @abstractmethod + async def register_resource( + self, + principal: PrincipalT, + resource: AgentexResource, + parent: AgentexResource | None = None, + ) -> None: + """Register a newly-created resource in the authorization graph. + + Atomically writes the relation tuples the schema requires (tenant + + owner, plus an optional typed parent like ``task.parent_agent``). + Distinct from ``grant`` because ``grant`` writes a single role + relation, which is insufficient for schemas that gate access on + ``tenant->membership``. + """ + + @abstractmethod + async def deregister_resource( + self, + principal: PrincipalT, + resource: AgentexResource, + ) -> None: + """Deregister a resource being deleted from the authorization graph. + + Removes every relation tuple written for the resource — keeps the + graph in sync with the application database on row delete. + """ diff --git a/agentex/src/adapters/orm.py b/agentex/src/adapters/orm.py index 42a66c1a..f2f62442 100644 --- a/agentex/src/adapters/orm.py +++ b/agentex/src/adapters/orm.py @@ -2,6 +2,7 @@ JSON, BigInteger, Boolean, + CheckConstraint, Column, DateTime, ForeignKey, @@ -75,6 +76,12 @@ class TaskORM(BaseORM): cleaned_at = Column(DateTime(timezone=True), nullable=True) params = Column(JSONB, nullable=True) task_metadata = Column(JSONB, nullable=True) + # NB: the runtime DB indexes are partial (`WHERE … IS NOT NULL`) — see + # migration a1f73ada66c5. SQLAlchemy's declarative `index=True` can only + # express a full index, so the ORM and migration intentionally differ on + # the WHERE clause; the migration's index wins. + creator_user_id = Column(String, nullable=True, index=True) + creator_service_account_id = Column(String, nullable=True, index=True) # Many-to-Many relationship with agents agents = relationship("AgentORM", secondary="task_agents", back_populates="tasks") @@ -82,6 +89,10 @@ class TaskORM(BaseORM): __table_args__ = ( # Index for filtering tasks by status (used in list queries) Index("ix_tasks_status", "status"), + CheckConstraint( + "creator_user_id IS NULL OR creator_service_account_id IS NULL", + name="ck_tasks_at_most_one_creator", + ), ) diff --git a/agentex/src/config/environment_variables.py b/agentex/src/config/environment_variables.py index 2d41740a..861ae9f6 100644 --- a/agentex/src/config/environment_variables.py +++ b/agentex/src/config/environment_variables.py @@ -58,6 +58,7 @@ class EnvVarKeys(str, Enum): AGENTEX_SERVER_TASK_QUEUE = "AGENTEX_SERVER_TASK_QUEUE" ENABLE_HEALTH_CHECK_WORKFLOW = "ENABLE_HEALTH_CHECK_WORKFLOW" WEBHOOK_REQUEST_TIMEOUT = "WEBHOOK_REQUEST_TIMEOUT" + FGAC_TASKS_DUAL_WRITE = "FGAC_TASKS_DUAL_WRITE" class Environment(str, Enum): @@ -114,6 +115,10 @@ class EnvironmentVariables(BaseModel): AGENTEX_SERVER_TASK_QUEUE: str | None = None ENABLE_HEALTH_CHECK_WORKFLOW: bool = False WEBHOOK_REQUEST_TIMEOUT: float = 15.0 # Webhook request timeout in seconds + # AGX1-274: gate the task FGAC dual-write call sites. Off by default so + # rollout is operator-controlled per environment. Mirrors KB's + # ``FGAC_KNOWLEDGE_BASES_DUAL_WRITE`` shape. + FGAC_TASKS_DUAL_WRITE: bool = False @classmethod def refresh(cls, force_refresh: bool = False) -> EnvironmentVariables | None: @@ -203,6 +208,10 @@ def refresh(cls, force_refresh: bool = False) -> EnvironmentVariables | None: WEBHOOK_REQUEST_TIMEOUT=float( os.environ.get(EnvVarKeys.WEBHOOK_REQUEST_TIMEOUT, "15.0") ), + FGAC_TASKS_DUAL_WRITE=( + os.environ.get(EnvVarKeys.FGAC_TASKS_DUAL_WRITE, "false").lower() + == "true" + ), ) refreshed_environment_variables = environment_variables return refreshed_environment_variables diff --git a/agentex/src/domain/entities/tasks.py b/agentex/src/domain/entities/tasks.py index 6a1ffce7..7b40d07f 100644 --- a/agentex/src/domain/entities/tasks.py +++ b/agentex/src/domain/entities/tasks.py @@ -62,6 +62,14 @@ class TaskEntity(BaseModel): None, title="Task metadata", ) + creator_user_id: str | None = Field( + None, + title="Identity ID of the user who created this task", + ) + creator_service_account_id: str | None = Field( + None, + title="Service identity ID of the service account that created this task", + ) # allow extra fields for agents relationships model_config = ConfigDict(extra="allow") diff --git a/agentex/src/domain/services/authorization_service.py b/agentex/src/domain/services/authorization_service.py index 8400603b..c1b798fb 100644 --- a/agentex/src/domain/services/authorization_service.py +++ b/agentex/src/domain/services/authorization_service.py @@ -1,5 +1,5 @@ from collections.abc import Iterable -from typing import Annotated +from typing import Annotated, Any from fastapi import Depends, Request @@ -17,6 +17,11 @@ logger = make_logger(__name__) +# Sentinel for "caller did not pass an explicit principal_context" — None is a +# valid value (the bypass path logs ``for principal None``) so it can't be the +# default. Using a named object reads better at the call site than ``...``. +_UNSET: Any = object() + class AuthorizationService: def __init__( @@ -40,7 +45,11 @@ def is_enabled(self) -> bool: return self.enabled async def grant( - self, resource: AgentexResource, *, commit: bool = True, principal_context=... + self, + resource: AgentexResource, + *, + commit: bool = True, + principal_context: Any = _UNSET, ) -> None: if self._bypass(): logger.info( @@ -57,7 +66,7 @@ async def grant( ) result = await self.gateway.grant( principal_context - if principal_context is not ... + if principal_context is not _UNSET else self.principal_context, resource, AuthorizedOperationType.create, @@ -65,7 +74,11 @@ async def grant( return result async def revoke( - self, resource: AgentexResource, *, commit: bool = True, principal_context=... + self, + resource: AgentexResource, + *, + commit: bool = True, + principal_context: Any = _UNSET, ) -> None: if self._bypass(): logger.info("Authorization bypassed for revoke operation") @@ -81,7 +94,7 @@ async def revoke( result = await self.gateway.revoke( principal_context - if principal_context is not ... + if principal_context is not _UNSET else self.principal_context, resource, AuthorizedOperationType.delete, @@ -96,7 +109,7 @@ async def check( resource: AgentexResource, operation: AuthorizedOperationType, *, - principal_context=..., + principal_context: Any = _UNSET, ) -> bool: if self._bypass(): logger.info("Authorization bypassed for check operation") @@ -105,7 +118,7 @@ async def check( # Determine which principal context to use effective_principal = ( principal_context - if principal_context is not ... + if principal_context is not _UNSET else self.principal_context ) @@ -157,12 +170,75 @@ async def check( ) return result + async def register_resource( + self, + resource: AgentexResource, + *, + parent: AgentexResource | None = None, + principal_context: Any = _UNSET, + ) -> None: + """Register a freshly-created resource with the authorization graph. + + Used immediately after persisting a new row to write the tenant + + owner (and optionally typed parent) relation tuples atomically. + Distinct from ``grant`` because ``grant`` only writes a single + role relation, which is insufficient for schemas (e.g. ``task``) + that require a ``tenant->membership`` gate. + """ + if self._bypass(): + logger.info( + f"Authorization bypassed for register_resource on {resource.type}:{resource.selector}" + ) + return None + + logger.info( + "[authorization_service] Registering resource %s:%s for principal %s (parent=%s)", + resource.type, + resource.selector, + self.principal_context, + parent, + ) + await self.gateway.register_resource( + principal_context + if principal_context is not _UNSET + else self.principal_context, + resource, + parent, + ) + + async def deregister_resource( + self, + resource: AgentexResource, + *, + principal_context: Any = _UNSET, + ) -> None: + """Remove every relation tuple written for the resource — used when + deleting the underlying database row.""" + if self._bypass(): + logger.info( + f"Authorization bypassed for deregister_resource on {resource.type}:{resource.selector}" + ) + return None + + logger.info( + "[authorization_service] Deregistering resource %s:%s for principal %s", + resource.type, + resource.selector, + self.principal_context, + ) + await self.gateway.deregister_resource( + principal_context + if principal_context is not _UNSET + else self.principal_context, + resource, + ) + async def list_resources( self, filter_resource: AgentexResourceType, filter_operation: AuthorizedOperationType = AuthorizedOperationType.read, *, - principal_context=..., + principal_context: Any = _UNSET, ) -> Iterable[str] | None: """List resource identifiers for which the current principal has *filter_operation* permission.""" @@ -178,7 +254,7 @@ async def list_resources( ) result = await self.gateway.list_resources( principal_context - if principal_context is not ... + if principal_context is not _UNSET else self.principal_context, filter_resource, filter_operation, diff --git a/agentex/src/domain/services/task_service.py b/agentex/src/domain/services/task_service.py index 013c6903..9bcbccdc 100644 --- a/agentex/src/domain/services/task_service.py +++ b/agentex/src/domain/services/task_service.py @@ -1,9 +1,21 @@ -from collections.abc import AsyncIterator +import asyncio +import os +import random +from collections.abc import AsyncIterator, Awaitable, Callable from typing import Annotated, Any +from datadog import statsd from fastapi import Depends +from src.adapters.authentication.exceptions import ( + AuthenticationGatewayError, + AuthenticationServiceUnavailableError, +) +from src.adapters.crud_store.exceptions import ItemDoesNotExist from src.adapters.streams.adapter_redis import DRedisStreamRepository +from src.api.schemas.authorization_types import AgentexResource +from src.config.dependencies import DEnvironmentVariable +from src.config.environment_variables import EnvVarKeys from src.domain.entities.agents import ACPType, AgentEntity from src.domain.entities.events import EventEntity from src.domain.entities.task_message_updates import TaskMessageUpdateEntity @@ -14,12 +26,47 @@ from src.domain.repositories.task_repository import DTaskRepository from src.domain.repositories.task_state_repository import DTaskStateRepository from src.domain.services.agent_acp_service import DAgentACPService +from src.domain.services.authorization_service import DAuthorizationService from src.utils.ids import orm_id from src.utils.logging import make_logger from src.utils.stream_topics import get_task_event_stream_topic logger = make_logger(__name__) +# Retry shape mirrored from agents_acp_use_case.grant_with_retry — keeps the +# task FGAC dual-write resilient to transient agentex-auth flakiness without +# pulling in a heavier retry framework. +_REGISTER_MAX_RETRIES = 3 +_REGISTER_BASE_BACKOFF_SECONDS = 0.2 + +# Dual-write rollout observability. Counters fire only when the Datadog +# agent host is configured (mirrors src/utils/db_metrics.py). Use these to +# drive the FGAC_TASKS_DUAL_WRITE rollout dashboard — without them the only +# signal is logger.error lines. +_STATSD_ENABLED = bool(os.environ.get("DD_AGENT_HOST")) + + +def _emit_dual_write_metric( + metric: str, *, op: str, exception_class: str | None = None +) -> None: + if not _STATSD_ENABLED: + return + tags = [f"op:{op}"] + if exception_class is not None: + tags.append(f"exception_class:{exception_class}") + statsd.increment(metric, tags=tags) + + +def _principal_field(principal_context: Any, key: str) -> str | None: + """Read an attribute from the principal context, which may be either a + Pydantic-style object or a plain dict. The authn proxy returns a dict + (``response.json()`` shape), so ``getattr`` alone silently yields None.""" + if principal_context is None: + return None + if isinstance(principal_context, dict): + return principal_context.get(key) + return getattr(principal_context, key, None) + class AgentTaskService: """ @@ -33,12 +80,22 @@ def __init__( task_repository: DTaskRepository, event_repository: DEventRepository, stream_repository: DRedisStreamRepository, + authorization_service: DAuthorizationService, + dual_write_enabled: DEnvironmentVariable( + EnvVarKeys.FGAC_TASKS_DUAL_WRITE + ) = False, ): self.acp_client = acp_client self.task_state_repository = task_state_repository self.task_repository = task_repository self.event_repository = event_repository self.stream_repository = stream_repository + self.authorization_service = authorization_service + # Read once from EnvironmentVariables at DI-resolve time (process + # singleton) — operator rollout assumes a redeploy cycles pods, so a + # mid-process flag flip is intentionally invisible. Tests construct + # the service directly and pass the desired bool. + self.dual_write_enabled = dual_write_enabled async def create_task( self, @@ -59,6 +116,11 @@ async def create_task( Returns: Task containing the created task info """ + principal_context = self.authorization_service.principal_context + creator_user_id = _principal_field(principal_context, "user_id") + creator_service_account_id = _principal_field( + principal_context, "service_account_id" + ) task_entity = await self.task_repository.create( agent_id=agent.id, @@ -69,10 +131,107 @@ async def create_task( status_reason="Task created, forwarding to ACP server", params=task_params, task_metadata=task_metadata, + creator_user_id=creator_user_id, + creator_service_account_id=creator_service_account_id, ), ) + + # AGX1-274: dual-write the task into the authorization graph so route + # enforcement (AGX1-275) can resolve tenant + owner + parent_agent + # tuples for this task. Gated behind ``FGAC_TASKS_DUAL_WRITE`` for + # operator-controlled rollout. Retries cover network-layer transients + # and named 502/503 responses from agentex-auth; bare 500s (e.g. an + # unhandled upstream exception surfacing as FastAPI's default) are + # NOT retried and propagate on the first attempt. The + # ``task_fgac_dual_write.failure{op=register}`` metric is the rollout + # signal — when it fires, the AGX1-291 operator runbook re-runs + # register from a script against the orphan rows identified via the + # ``creator_user_id`` / ``creator_service_account_id`` audit columns. + # Unlike ``agents_acp_use_case.grant_with_retry`` there is no outer + # ``fail_task`` fallback here — the Postgres row is the durable + # record and orphan tuples in the auth graph are preferable to + # losing the task. + if self.dual_write_enabled: + await self._dual_write_with_retry( + op_name="register", + do_call=lambda: self.authorization_service.register_resource( + AgentexResource.task(task_entity.id), + parent=AgentexResource.agent(agent.id), + ), + task_id=task_entity.id, + ) + return task_entity + async def _dual_write_with_retry( + self, + *, + op_name: str, + do_call: Callable[[], Awaitable[None]], + task_id: str, + attempts: int = 0, + ) -> None: + # Caller-retry edge case: when this exhausts retries on `register`, the + # Postgres row is already committed by `create_task`. A retrying HTTP + # client will call `create_task` again, persisting a *second* row + # before `register` succeeds — both rows then show up in the + # `creator_user_id`/`creator_service_account_id` audit scan, + # indistinguishable from a single orphan. The AGX1-291 reconciliation + # must dedup by `(account_id, name)` and discard the older-by- + # `created_at` row when a duplicate pair is found. + _emit_dual_write_metric("task_fgac_dual_write.attempt", op=op_name) + try: + await do_call() + except ( + AuthenticationServiceUnavailableError, + AuthenticationGatewayError, + ) as exc: + if attempts < _REGISTER_MAX_RETRIES: + delay = _REGISTER_BASE_BACKOFF_SECONDS * (2**attempts) + random.uniform( + 0, 0.1 + ) + logger.error( + "task FGAC %s transient failure for task %s: %s. " + "Retrying in %.2fs (attempt %d/%d).", + op_name, + task_id, + exc, + delay, + attempts + 1, + _REGISTER_MAX_RETRIES, + ) + _emit_dual_write_metric("task_fgac_dual_write.retry", op=op_name) + await asyncio.sleep(delay) + return await self._dual_write_with_retry( + op_name=op_name, + do_call=do_call, + task_id=task_id, + attempts=attempts + 1, + ) + logger.error( + "task FGAC %s exhausted retries for task %s: %s", + op_name, + task_id, + exc, + ) + _emit_dual_write_metric( + "task_fgac_dual_write.failure", + op=op_name, + exception_class=type(exc).__name__, + ) + raise + except Exception as exc: + # Non-transient errors are not retried — emit the failure metric so + # the rollout dashboard still sees them, then propagate. + _emit_dual_write_metric( + "task_fgac_dual_write.failure", + op=op_name, + exception_class=type(exc).__name__, + ) + raise + else: + _emit_dual_write_metric("task_fgac_dual_write.success", op=op_name) + async def create_task_and_forward_to_acp( self, agent: AgentEntity, @@ -91,7 +250,9 @@ async def create_task_and_forward_to_acp( Task containing the created task info """ task_entity = await self.create_task( - agent=agent, task_name=task_name, task_params=task_params + agent=agent, + task_name=task_name, + task_params=task_params, ) if agent.acp_type == ACPType.SYNC: @@ -214,8 +375,42 @@ async def delete_task(self, id: str | None = None, name: str | None = None) -> N """ Delete a task from the repository. """ + # Resolve the task id before the Postgres delete so we can deregister + # by id afterwards. Looking up by name post-delete would race. If the + # name doesn't resolve, swallow ItemDoesNotExist here and let the + # subsequent delete() raise its own native error — flipping the flag + # must not change the error contract callers see for missing tasks. + task_id_for_deregister: str | None = id + if ( + task_id_for_deregister is None + and name is not None + and self.dual_write_enabled + ): + try: + task = await self.task_repository.get(name=name) + task_id_for_deregister = task.id + except ItemDoesNotExist: + task_id_for_deregister = None + await self.task_repository.delete(id=id, name=name) + # AGX1-274: deregister the task from the authorization graph after the + # row is gone, mirroring the agent delete ordering at + # ``delete_agent_by_id``. Transient failures share the dual-write + # retry budget with register; a retry-exhausted deregister still + # leaves orphan SpiceDB tuples for a deleted task — the + # ``task_fgac_dual_write.failure{op=deregister}`` metric is the rollout + # signal, and the AGX1-291 operator runbook covers cleanup. + if self.dual_write_enabled and task_id_for_deregister is not None: + deregister_id = task_id_for_deregister + await self._dual_write_with_retry( + op_name="deregister", + do_call=lambda: self.authorization_service.deregister_resource( + AgentexResource.task(deregister_id), + ), + task_id=deregister_id, + ) + async def list_tasks( self, *, diff --git a/agentex/tests/fixtures/services.py b/agentex/tests/fixtures/services.py index c30c06c8..04ba18a5 100644 --- a/agentex/tests/fixtures/services.py +++ b/agentex/tests/fixtures/services.py @@ -3,7 +3,7 @@ Provides factory functions and specific fixtures for creating services with test repositories. """ -from unittest.mock import MagicMock, Mock +from unittest.mock import AsyncMock, MagicMock, Mock import pytest @@ -12,6 +12,24 @@ # ============================================================================= +def make_noop_authorization_service() -> Mock: + """Shared noop AuthorizationService mock for tests that don't exercise authz. + + ``principal_context`` is ``None`` (so creator-audit columns resolve to NULL), + and ``grant``/``revoke``/``register_resource``/``deregister_resource`` are + async no-ops returning ``None`` — matching the real service signature. + Use this anywhere a test just needs to construct ``AgentTaskService`` + without caring about authorization behavior. + """ + svc = Mock() + svc.principal_context = None + svc.grant = AsyncMock(return_value=None) + svc.revoke = AsyncMock(return_value=None) + svc.register_resource = AsyncMock(return_value=None) + svc.deregister_resource = AsyncMock(return_value=None) + return svc + + def create_task_message_service(task_message_repository): """Factory function to create TaskMessageService with given repository""" from src.domain.services.task_message_service import TaskMessageService @@ -52,16 +70,23 @@ def create_task_service( event_repository, agent_acp_service, redis_stream_repository, + authorization_service=None, + dual_write_enabled: bool = False, ): - """Factory function to create AgentTaskService with given repositories and services""" + """Factory function to create AgentTaskService with given repositories and services.""" from src.domain.services.task_service import AgentTaskService + if authorization_service is None: + authorization_service = make_noop_authorization_service() + return AgentTaskService( task_repository=task_repository, task_state_repository=task_state_repository, event_repository=event_repository, acp_client=agent_acp_service, stream_repository=redis_stream_repository, + authorization_service=authorization_service, + dual_write_enabled=dual_write_enabled, ) diff --git a/agentex/tests/integration/fixtures/integration_client.py b/agentex/tests/integration/fixtures/integration_client.py index f1396d57..34479547 100644 --- a/agentex/tests/integration/fixtures/integration_client.py +++ b/agentex/tests/integration/fixtures/integration_client.py @@ -22,6 +22,8 @@ from src.config.dependencies import GlobalDependencies from src.config.environment_variables import EnvironmentVariables +from tests.fixtures.services import make_noop_authorization_service + @pytest.fixture(scope="session") def event_loop(): @@ -448,6 +450,7 @@ async def send_message(self, *args, **kwargs): task_repository=isolated_repositories["task_repository"], event_repository=isolated_repositories["event_repository"], stream_repository=isolated_repositories["redis_stream_repository"], + authorization_service=make_noop_authorization_service(), ) return TasksUseCase(task_service=task_service) diff --git a/agentex/tests/integration/test_task_stream.py b/agentex/tests/integration/test_task_stream.py index 289010ee..f2a0d762 100644 --- a/agentex/tests/integration/test_task_stream.py +++ b/agentex/tests/integration/test_task_stream.py @@ -7,6 +7,8 @@ from src.domain.use_cases.tasks_use_case import TasksUseCase from src.utils.ids import orm_id +from tests.fixtures.services import make_noop_authorization_service + @pytest.mark.asyncio @pytest.mark.integration @@ -76,6 +78,7 @@ async def send_message(self, *args, **kwargs): task_repository=isolated_repositories["task_repository"], event_repository=isolated_repositories["event_repository"], stream_repository=isolated_repositories["redis_stream_repository"], + authorization_service=make_noop_authorization_service(), ) return TasksUseCase(task_service=task_service) @@ -103,6 +106,7 @@ async def send_message(self, *args, **kwargs): task_repository=isolated_repositories["task_repository"], event_repository=isolated_repositories["event_repository"], stream_repository=isolated_repositories["redis_stream_repository"], + authorization_service=make_noop_authorization_service(), ) environment_variables = EnvironmentVariables.refresh() @@ -194,17 +198,17 @@ async def collect_stream_events(): pass # Then - Verify the stream event was received - assert ( - len(stream_events) >= 1 - ), f"Expected at least 1 stream event, got {len(stream_events)}" + assert len(stream_events) >= 1, ( + f"Expected at least 1 stream event, got {len(stream_events)}" + ) # Find the task_updated event task_updated_events = [ e for e in stream_events if e.get("type") == "task_updated" ] - assert ( - len(task_updated_events) >= 1 - ), f"Expected task_updated event, got events: {[e.get('type') for e in stream_events]}" + assert len(task_updated_events) >= 1, ( + f"Expected task_updated event, got events: {[e.get('type') for e in stream_events]}" + ) task_updated_event = task_updated_events[0] @@ -389,9 +393,9 @@ async def collect_stream_events(): task_updated_events = [ e for e in stream_events if e.get("type") == "task_updated" ] - assert ( - len(task_updated_events) >= 3 - ), f"Expected at least 3 task_updated events, got {len(task_updated_events)}" + assert len(task_updated_events) >= 3, ( + f"Expected at least 3 task_updated events, got {len(task_updated_events)}" + ) # Verify each event has the correct metadata for its update versions = [ @@ -599,8 +603,8 @@ async def collect_stream_data(): pass # Then - Verify we received at least 2 pings - assert ( - ping_count >= 2 - ), f"Expected at least 2 ping messages during idle period, got {ping_count}" + assert ping_count >= 2, ( + f"Expected at least 2 ping messages during idle period, got {ping_count}" + ) print(f"✅ Stream sent {ping_count} keepalive pings during idle period") diff --git a/agentex/tests/integration/use_cases/__init__.py b/agentex/tests/integration/use_cases/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/agentex/tests/integration/use_cases/test_task_audit_columns.py b/agentex/tests/integration/use_cases/test_task_audit_columns.py new file mode 100644 index 00000000..069974cb --- /dev/null +++ b/agentex/tests/integration/use_cases/test_task_audit_columns.py @@ -0,0 +1,179 @@ +"""Tests for task creator-audit column population in ``task_service.create_task``. + +Asserts that: +- A user principal populates ``creator_user_id`` and leaves ``creator_service_account_id`` NULL. +- A service-account principal populates ``creator_service_account_id`` and leaves ``creator_user_id`` NULL. +- A principal with no resolvable creator leaves both columns NULL; the row is still inserted. +- The ``ck_tasks_at_most_one_creator`` CHECK constraint rejects rows where both columns are set. +""" + +from __future__ import annotations + +from types import SimpleNamespace +from typing import Any, Literal +from unittest.mock import AsyncMock, Mock +from uuid import uuid4 + +import pytest +from src.domain.entities.agents import ACPType, AgentEntity, AgentStatus +from src.domain.entities.tasks import TaskEntity, TaskStatus +from src.domain.exceptions import ServiceError +from src.domain.services.task_service import AgentTaskService + +# In production the authn proxy returns ``response.json()`` (a plain dict), +# so the test must lock in the dict shape as well as the SimpleNamespace shape. +PrincipalShape = Literal["namespace", "dict"] +PRINCIPAL_SHAPES: list[PrincipalShape] = ["namespace", "dict"] + + +def _principal( + shape: PrincipalShape = "namespace", + *, + user_id: str | None = None, + service_account_id: str | None = None, + account_id: str | None = "acct-1", +) -> Any: + if shape == "dict": + return { + "user_id": user_id, + "service_account_id": service_account_id, + "account_id": account_id, + } + return SimpleNamespace( + user_id=user_id, + service_account_id=service_account_id, + account_id=account_id, + ) + + +def _build_service( + *, + task_repository, + principal: Any, +) -> AgentTaskService: + authorization_service = Mock() + authorization_service.principal_context = principal + authorization_service.grant = AsyncMock(return_value={}) + authorization_service.revoke = AsyncMock(return_value=None) + # Audit-column tests run with FGAC_TASKS_DUAL_WRITE off (default), so + # these are wired as no-ops to satisfy the AgentTaskService surface + # without exercising the dual-write path. + authorization_service.register_resource = AsyncMock(return_value=None) + authorization_service.deregister_resource = AsyncMock(return_value=None) + + return AgentTaskService( + acp_client=Mock(), + task_state_repository=Mock(), + task_repository=task_repository, + event_repository=Mock(), + stream_repository=Mock(), + authorization_service=authorization_service, + dual_write_enabled=False, + ) + + +async def _persist_agent(agent_repository) -> AgentEntity: + agent = AgentEntity( + id=str(uuid4()), + name=f"audit-agent-{uuid4().hex[:8]}", + description="audit-column test agent", + status=AgentStatus.READY, + acp_type=ACPType.SYNC, + acp_url="http://test-acp", + ) + return await agent_repository.create(agent) + + +@pytest.mark.integration +@pytest.mark.asyncio +class TestTaskAuditColumns: + # Uses ``isolated_repositories`` (per-test schema, commit-safe) rather than + # the unit-style ``task_repository`` fixture (shared session wrapped in + # rollback). Repository ``create`` methods commit internally, which would + # close the wrapping transaction and break teardown. + + @pytest.mark.parametrize("shape", PRINCIPAL_SHAPES) + async def test_user_principal_populates_user_id( + self, + isolated_repositories, + shape: PrincipalShape, + ): + task_repo = isolated_repositories["task_repository"] + agent_repo = isolated_repositories["agent_repository"] + agent = await _persist_agent(agent_repo) + service = _build_service( + task_repository=task_repo, + principal=_principal(shape, user_id="user-A"), + ) + + task = await service.create_task( + agent=agent, task_name=f"audit-user-{uuid4().hex[:8]}" + ) + + assert task.creator_user_id == "user-A" + assert task.creator_service_account_id is None + + @pytest.mark.parametrize("shape", PRINCIPAL_SHAPES) + async def test_service_account_principal_populates_service_account_id( + self, + isolated_repositories, + shape: PrincipalShape, + ): + task_repo = isolated_repositories["task_repository"] + agent_repo = isolated_repositories["agent_repository"] + agent = await _persist_agent(agent_repo) + service = _build_service( + task_repository=task_repo, + principal=_principal(shape, service_account_id="svc-1"), + ) + + task = await service.create_task( + agent=agent, task_name=f"audit-svc-{uuid4().hex[:8]}" + ) + + assert task.creator_user_id is None + assert task.creator_service_account_id == "svc-1" + + @pytest.mark.parametrize("shape", PRINCIPAL_SHAPES) + async def test_no_resolvable_creator_leaves_both_columns_null( + self, + isolated_repositories, + shape: PrincipalShape, + ): + task_repo = isolated_repositories["task_repository"] + agent_repo = isolated_repositories["agent_repository"] + agent = await _persist_agent(agent_repo) + service = _build_service( + task_repository=task_repo, + principal=_principal(shape), + ) + + task = await service.create_task( + agent=agent, task_name=f"audit-noop-{uuid4().hex[:8]}" + ) + + assert task.creator_user_id is None + assert task.creator_service_account_id is None + # Row was still inserted — fetch back to confirm. + fetched = await task_repo.get(id=task.id) + assert fetched is not None + + async def test_check_constraint_rejects_both_creators(self, isolated_repositories): + task_repo = isolated_repositories["task_repository"] + agent_repo = isolated_repositories["agent_repository"] + agent = await _persist_agent(agent_repo) + # The repository wraps IntegrityError in ServiceError; the underlying + # CheckViolationError + constraint name still appears in the message. + with pytest.raises(ServiceError) as excinfo: + await task_repo.create( + agent_id=agent.id, + task=TaskEntity( + id=str(uuid4()), + name=f"audit-both-{uuid4().hex[:8]}", + status=TaskStatus.RUNNING, + status_reason="should be rejected", + creator_user_id="user-A", + creator_service_account_id="svc-1", + ), + ) + assert "ck_tasks_at_most_one_creator" in str(excinfo.value) diff --git a/agentex/tests/integration/use_cases/test_task_fgac_dual_write.py b/agentex/tests/integration/use_cases/test_task_fgac_dual_write.py new file mode 100644 index 00000000..4302d0f4 --- /dev/null +++ b/agentex/tests/integration/use_cases/test_task_fgac_dual_write.py @@ -0,0 +1,284 @@ +"""Tests for the AGX1-274 task FGAC dual-write call sites. + +Asserts that when ``FGAC_TASKS_DUAL_WRITE`` is on: + +- ``task_service.create_task`` calls ``register_resource`` with the new + task and the agent as ``parent``. +- ``task_service.delete_task`` calls ``deregister_resource`` with the task + after the Postgres row is gone. +- A transient ``AuthenticationServiceUnavailableError`` is retried with + backoff and ultimately succeeds on a later attempt. + +And when the flag is off: + +- Neither ``register_resource`` nor ``deregister_resource`` is invoked. + +These run as integration tests because ``create_task`` / ``delete_task`` +touch the real task repository through the ``isolated_repositories`` +fixture; only the authorization service is stubbed so the call shapes are +asserted directly. +""" + +from __future__ import annotations + +from types import SimpleNamespace +from typing import Any +from unittest.mock import AsyncMock, Mock +from uuid import uuid4 + +import pytest +from src.adapters.authentication.exceptions import ( + AuthenticationServiceUnavailableError, +) +from src.api.schemas.authorization_types import AgentexResource, AgentexResourceType +from src.domain.entities.agents import ACPType, AgentEntity, AgentStatus +from src.domain.services.task_service import AgentTaskService + + +def _principal() -> Any: + return SimpleNamespace( + user_id="user-A", + service_account_id=None, + account_id="acct-1", + ) + + +def _build_service( + *, + task_repository, + dual_write_principal: Any, + dual_write_enabled: bool, +) -> tuple[AgentTaskService, Mock]: + authorization_service = Mock() + authorization_service.principal_context = dual_write_principal + authorization_service.grant = AsyncMock(return_value={}) + authorization_service.revoke = AsyncMock(return_value=None) + authorization_service.register_resource = AsyncMock(return_value=None) + authorization_service.deregister_resource = AsyncMock(return_value=None) + + service = AgentTaskService( + acp_client=Mock(), + task_state_repository=Mock(), + task_repository=task_repository, + event_repository=Mock(), + stream_repository=Mock(), + authorization_service=authorization_service, + dual_write_enabled=dual_write_enabled, + ) + return service, authorization_service + + +async def _persist_agent(agent_repository) -> AgentEntity: + agent = AgentEntity( + id=str(uuid4()), + name=f"dual-write-agent-{uuid4().hex[:8]}", + description="dual-write test agent", + status=AgentStatus.READY, + acp_type=ACPType.SYNC, + acp_url="http://test-acp", + ) + return await agent_repository.create(agent) + + +async def _clear_task_agent_links(task_repository, task_id: str) -> None: + """Delete the task_agents / agent_task_tracker join rows for a task. + + ``create_task`` writes both join rows, and ``task_repository.delete`` + issues a raw ``DELETE FROM tasks`` that the ``task_agents_task_id_fkey`` + FK rejects while those rows exist — this is the established, intentional + contract (see ``test_task_repository.test_delete_task`` and + ``test_task_service.test_delete_task_with_cleanup``). Tests that exercise + the hard-delete deregister path must clear the join rows first, exactly as + a real cascading delete would, otherwise the delete FK-fails before the + deregister dual-write is ever reached. + """ + from sqlalchemy import delete as sql_delete + from src.adapters.orm import AgentTaskTrackerORM, TaskAgentORM + + async with task_repository.start_async_db_session(True) as session: + await session.execute( + sql_delete(AgentTaskTrackerORM).where( + AgentTaskTrackerORM.task_id == task_id + ) + ) + await session.execute( + sql_delete(TaskAgentORM).where(TaskAgentORM.task_id == task_id) + ) + await session.commit() + + +@pytest.mark.integration +@pytest.mark.asyncio +class TestTaskDualWrite: + async def test_create_task_registers_with_agent_as_parent( + self, isolated_repositories + ): + task_repo = isolated_repositories["task_repository"] + agent_repo = isolated_repositories["agent_repository"] + agent = await _persist_agent(agent_repo) + service, authorization_service = _build_service( + task_repository=task_repo, + dual_write_principal=_principal(), + dual_write_enabled=True, + ) + + task = await service.create_task( + agent=agent, task_name=f"dw-create-{uuid4().hex[:8]}" + ) + + authorization_service.register_resource.assert_awaited_once() + call = authorization_service.register_resource.call_args + registered_resource: AgentexResource = call.args[0] + assert registered_resource.type == AgentexResourceType.task + assert registered_resource.selector == task.id + + parent: AgentexResource | None = call.kwargs.get("parent") + assert parent is not None + assert parent.type == AgentexResourceType.agent + assert parent.selector == agent.id + + async def test_delete_task_deregisters(self, isolated_repositories): + task_repo = isolated_repositories["task_repository"] + agent_repo = isolated_repositories["agent_repository"] + agent = await _persist_agent(agent_repo) + service, authorization_service = _build_service( + task_repository=task_repo, + dual_write_principal=_principal(), + dual_write_enabled=True, + ) + task = await service.create_task( + agent=agent, task_name=f"dw-delete-{uuid4().hex[:8]}" + ) + authorization_service.deregister_resource.reset_mock() + await _clear_task_agent_links(task_repo, task.id) + + await service.delete_task(id=task.id) + + authorization_service.deregister_resource.assert_awaited_once() + deregistered: AgentexResource = ( + authorization_service.deregister_resource.call_args.args[0] + ) + assert deregistered.type == AgentexResourceType.task + assert deregistered.selector == task.id + + async def test_flag_off_skips_register_and_deregister(self, isolated_repositories): + task_repo = isolated_repositories["task_repository"] + agent_repo = isolated_repositories["agent_repository"] + agent = await _persist_agent(agent_repo) + service, authorization_service = _build_service( + task_repository=task_repo, + dual_write_principal=_principal(), + dual_write_enabled=False, + ) + + task = await service.create_task( + agent=agent, task_name=f"dw-off-{uuid4().hex[:8]}" + ) + await _clear_task_agent_links(task_repo, task.id) + await service.delete_task(id=task.id) + + authorization_service.register_resource.assert_not_awaited() + authorization_service.deregister_resource.assert_not_awaited() + + async def test_transient_unavailable_retries_then_succeeds( + self, isolated_repositories + ): + task_repo = isolated_repositories["task_repository"] + agent_repo = isolated_repositories["agent_repository"] + agent = await _persist_agent(agent_repo) + service, authorization_service = _build_service( + task_repository=task_repo, + dual_write_principal=_principal(), + dual_write_enabled=True, + ) + # First two calls fail with transient AuthenticationServiceUnavailable, + # then the third succeeds. The retry shape (3 attempts) is mirrored + # from agents_acp_use_case.grant_with_retry. + authorization_service.register_resource.side_effect = [ + AuthenticationServiceUnavailableError("transient"), + AuthenticationServiceUnavailableError("transient"), + None, + ] + + await service.create_task(agent=agent, task_name=f"dw-retry-{uuid4().hex[:8]}") + + assert authorization_service.register_resource.await_count == 3 + + async def test_transient_unavailable_on_deregister_retries_then_succeeds( + self, isolated_repositories + ): + # Mirrors test_transient_unavailable_retries_then_succeeds for the + # deregister path: both call sites share _dual_write_with_retry. + task_repo = isolated_repositories["task_repository"] + agent_repo = isolated_repositories["agent_repository"] + agent = await _persist_agent(agent_repo) + service, authorization_service = _build_service( + task_repository=task_repo, + dual_write_principal=_principal(), + dual_write_enabled=True, + ) + task = await service.create_task( + agent=agent, task_name=f"dw-dereg-retry-{uuid4().hex[:8]}" + ) + await _clear_task_agent_links(task_repo, task.id) + authorization_service.deregister_resource.reset_mock() + authorization_service.deregister_resource.side_effect = [ + AuthenticationServiceUnavailableError("transient"), + AuthenticationServiceUnavailableError("transient"), + None, + ] + + await service.delete_task(id=task.id) + + assert authorization_service.deregister_resource.await_count == 3 + + async def test_delete_task_by_missing_name_with_flag_on_does_not_change_error_contract( + self, isolated_repositories + ): + # Flag-OFF behavior: task_repository.delete on a missing name is a + # no-op (no exception). Flag-ON must preserve that contract — the + # pre-lookup catches ItemDoesNotExist and lets the underlying delete + # produce its native (non-)error. + task_repo = isolated_repositories["task_repository"] + service, authorization_service = _build_service( + task_repository=task_repo, + dual_write_principal=_principal(), + dual_write_enabled=True, + ) + + await service.delete_task(name=f"missing-{uuid4().hex[:8]}") + + authorization_service.deregister_resource.assert_not_awaited() + + async def test_register_failure_after_retries_propagates_and_task_row_persists( + self, isolated_repositories + ): + # All four register attempts fail with transient unavailables, so + # the retry budget exhausts. The exception must propagate, the + # Postgres row must remain (it is the durable record; orphan auth + # tuples are cleaned up out of band via the AGX1-291 operator + # runbook using the creator-audit columns), and deregister_resource + # must NOT be called as part of any compensating delete. + task_repo = isolated_repositories["task_repository"] + agent_repo = isolated_repositories["agent_repository"] + agent = await _persist_agent(agent_repo) + service, authorization_service = _build_service( + task_repository=task_repo, + dual_write_principal=_principal(), + dual_write_enabled=True, + ) + authorization_service.register_resource.side_effect = [ + AuthenticationServiceUnavailableError("transient"), + AuthenticationServiceUnavailableError("transient"), + AuthenticationServiceUnavailableError("transient"), + AuthenticationServiceUnavailableError("transient"), + ] + task_name = f"dw-exhaust-{uuid4().hex[:8]}" + + with pytest.raises(AuthenticationServiceUnavailableError): + await service.create_task(agent=agent, task_name=task_name) + + assert authorization_service.register_resource.await_count == 4 + persisted = await task_repo.get(name=task_name) + assert persisted is not None + authorization_service.deregister_resource.assert_not_awaited() diff --git a/agentex/tests/unit/services/test_task_service.py b/agentex/tests/unit/services/test_task_service.py index eb096eb1..2e1f19c3 100644 --- a/agentex/tests/unit/services/test_task_service.py +++ b/agentex/tests/unit/services/test_task_service.py @@ -19,6 +19,8 @@ from src.domain.repositories.task_state_repository import TaskStateRepository from src.domain.services.task_service import AgentTaskService +from tests.fixtures.services import make_noop_authorization_service + async def create_or_get_agent(agent_repository, agent): """Helper to create agent or get existing one if name already exists""" @@ -84,6 +86,7 @@ def task_service( task_state_repository=task_state_repository, event_repository=event_repository, stream_repository=redis_stream_repository, + authorization_service=make_noop_authorization_service(), ) diff --git a/agentex/tests/unit/use_cases/test_acp_type_backwards_compatibility_use_case.py b/agentex/tests/unit/use_cases/test_acp_type_backwards_compatibility_use_case.py index 60914adc..e6d13949 100644 --- a/agentex/tests/unit/use_cases/test_acp_type_backwards_compatibility_use_case.py +++ b/agentex/tests/unit/use_cases/test_acp_type_backwards_compatibility_use_case.py @@ -23,6 +23,7 @@ from src.domain.services.task_service import AgentTaskService from src.domain.use_cases.agents_acp_use_case import AgentsACPUseCase from src.domain.use_cases.agents_use_case import AgentsUseCase +from tests.fixtures.services import make_noop_authorization_service @pytest.mark.unit @@ -35,9 +36,9 @@ async def test_both_agentic_and_async_have_same_allowed_methods(self): agentic_methods = set(ACP_TYPE_TO_ALLOWED_RPC_METHODS[ACPType.AGENTIC]) async_methods = set(ACP_TYPE_TO_ALLOWED_RPC_METHODS[ACPType.ASYNC]) - assert ( - agentic_methods == async_methods - ), "AGENTIC and ASYNC should have identical allowed RPC methods" + assert agentic_methods == async_methods, ( + "AGENTIC and ASYNC should have identical allowed RPC methods" + ) # Verify they include the expected methods expected_methods = { @@ -95,6 +96,7 @@ async def test_agentic_agent_forwards_task_to_acp(self): task_repository=task_repo, event_repository=event_repo, stream_repository=stream_repo, + authorization_service=make_noop_authorization_service(), ) # Create AGENTIC agent @@ -148,6 +150,7 @@ async def test_sync_agent_does_not_forward_task_to_acp(self): task_repository=task_repo, event_repository=event_repo, stream_repository=stream_repo, + authorization_service=make_noop_authorization_service(), ) # Create SYNC agent @@ -195,6 +198,7 @@ async def test_async_agent_forwards_task_to_acp(self): task_repository=task_repo, event_repository=event_repo, stream_repository=stream_repo, + authorization_service=make_noop_authorization_service(), ) # Create ASYNC agent @@ -355,6 +359,6 @@ async def test_agentic_and_async_agents_both_use_not_sync_logic(self): # Both AGENTIC and ASYNC should pass the same conditional checks agentic_is_not_sync = agentic_agent.acp_type != ACPType.SYNC async_is_not_sync = async_agent.acp_type != ACPType.SYNC - assert ( - agentic_is_not_sync == async_is_not_sync - ), "AGENTIC and ASYNC should have identical behavior in != SYNC checks" + assert agentic_is_not_sync == async_is_not_sync, ( + "AGENTIC and ASYNC should have identical behavior in != SYNC checks" + ) diff --git a/agentex/tests/unit/use_cases/test_agents_acp_use_case.py b/agentex/tests/unit/use_cases/test_agents_acp_use_case.py index b48751a4..abb7af0b 100644 --- a/agentex/tests/unit/use_cases/test_agents_acp_use_case.py +++ b/agentex/tests/unit/use_cases/test_agents_acp_use_case.py @@ -36,6 +36,7 @@ from src.domain.services.task_message_service import TaskMessageService from src.domain.services.task_service import AgentTaskService from src.domain.use_cases.agents_acp_use_case import AgentsACPUseCase +from tests.fixtures.services import make_noop_authorization_service # UTC timezone constant UTC = ZoneInfo("UTC") @@ -135,6 +136,7 @@ def task_service( event_repository=event_repository, acp_client=agent_acp_service, stream_repository=redis_stream_repository, + authorization_service=make_noop_authorization_service(), ) @@ -552,9 +554,9 @@ def create_mock_stream(*args, **kwargs): # Verify database interactions - should have created messages final_message_count = len(await task_message_repository.list()) - assert ( - final_message_count > initial_message_count - ), "New messages should have been created in database" + assert final_message_count > initial_message_count, ( + "New messages should have been created in database" + ) # Get all messages from database to verify content all_messages = await task_message_repository.list() @@ -563,33 +565,33 @@ def create_mock_stream(*args, **kwargs): ] # Only the newly created messages # Should have at least 2 new messages (input + response) - assert ( - len(new_messages) >= 2 - ), f"Expected at least 2 new messages (input + response), got {len(new_messages)}" + assert len(new_messages) >= 2, ( + f"Expected at least 2 new messages (input + response), got {len(new_messages)}" + ) # Verify we have both user input and agent response messages content_authors = {msg.content.author.value for msg in new_messages} assert "user" in content_authors, "Should have user input message in database" - assert ( - "agent" in content_authors - ), "Should have agent response message in database" + assert "agent" in content_authors, ( + "Should have agent response message in database" + ) # Find the agent response message and verify its final accumulated content agent_messages = [ msg for msg in new_messages if msg.content.author == MessageAuthor.AGENT ] - assert ( - len(agent_messages) >= 1 - ), "Should have at least one agent response message" + assert len(agent_messages) >= 1, ( + "Should have at least one agent response message" + ) # Verify the final accumulated content includes both deltas response_message = agent_messages[0] # First agent response - assert ( - "Hello" in response_message.content.content - ), f"Expected 'Hello' in final content, got '{response_message.content.content}'" - assert ( - "world!" in response_message.content.content - ), f"Expected 'world!' in final content, got '{response_message.content.content}'" + assert "Hello" in response_message.content.content, ( + f"Expected 'Hello' in final content, got '{response_message.content.content}'" + ) + assert "world!" in response_message.content.content, ( + f"Expected 'world!' in final content, got '{response_message.content.content}'" + ) async def test_handle_message_send_stream_full_message( self, @@ -662,9 +664,9 @@ def create_mock_stream(*args, **kwargs): # Verify database interactions - should have created messages final_message_count = len(await task_message_repository.list()) - assert ( - final_message_count > initial_message_count - ), "New messages should have been created in database" + assert final_message_count > initial_message_count, ( + "New messages should have been created in database" + ) # Get all messages from database to verify content all_messages = await task_message_repository.list() @@ -673,30 +675,30 @@ def create_mock_stream(*args, **kwargs): ] # Only the newly created messages # Should have at least 2 new messages (input + response) - assert ( - len(new_messages) >= 2 - ), f"Expected at least 2 new messages (input + response), got {len(new_messages)}" + assert len(new_messages) >= 2, ( + f"Expected at least 2 new messages (input + response), got {len(new_messages)}" + ) # Verify we have both user input and agent response messages content_authors = {msg.content.author.value for msg in new_messages} assert "user" in content_authors, "Should have user input message in database" - assert ( - "agent" in content_authors - ), "Should have agent response message in database" + assert "agent" in content_authors, ( + "Should have agent response message in database" + ) # Find the agent response message and verify its content agent_messages = [ msg for msg in new_messages if msg.content.author == MessageAuthor.AGENT ] - assert ( - len(agent_messages) >= 1 - ), "Should have at least one agent response message" + assert len(agent_messages) >= 1, ( + "Should have at least one agent response message" + ) # Verify the FULL message content is correctly stored response_message = agent_messages[0] # First agent response - assert ( - response_message.content.content == "Complete message in one chunk" - ), f"Expected complete message content, got '{response_message.content.content}'" + assert response_message.content.content == "Complete message in one chunk", ( + f"Expected complete message content, got '{response_message.content.content}'" + ) async def test_handle_message_send_stream_multiple_indexes( self, @@ -820,9 +822,9 @@ def create_mock_stream(*args, **kwargs): # Verify database interactions - should have created messages final_message_count = len(await task_message_repository.list()) - assert ( - final_message_count > initial_message_count - ), "New messages should have been created in database" + assert final_message_count > initial_message_count, ( + "New messages should have been created in database" + ) # Get all messages from database to verify content all_messages = await task_message_repository.list() @@ -831,30 +833,30 @@ def create_mock_stream(*args, **kwargs): ] # Only the newly created messages # Should have at least 3 new messages (input + 2 response messages for different indexes) - assert ( - len(new_messages) >= 3 - ), f"Expected at least 3 new messages (input + 2 responses), got {len(new_messages)}" + assert len(new_messages) >= 3, ( + f"Expected at least 3 new messages (input + 2 responses), got {len(new_messages)}" + ) # Verify we have both user input and agent response messages content_authors = {msg.content.author.value for msg in new_messages} assert "user" in content_authors, "Should have user input message in database" - assert ( - "agent" in content_authors - ), "Should have agent response messages in database" + assert "agent" in content_authors, ( + "Should have agent response messages in database" + ) # Find the agent response messages - should have multiple for different indexes agent_messages = [ msg for msg in new_messages if msg.content.author == MessageAuthor.AGENT ] - assert ( - len(agent_messages) >= 2 - ), f"Should have at least 2 agent response messages for different indexes, got {len(agent_messages)}" + assert len(agent_messages) >= 2, ( + f"Should have at least 2 agent response messages for different indexes, got {len(agent_messages)}" + ) # Verify the content includes expected text from both indexes agent_content = " ".join([msg.content.content for msg in agent_messages]) - assert ( - "First" in agent_content or "Second" in agent_content - ), f"Expected content from multiple indexes, got '{agent_content}'" + assert "First" in agent_content or "Second" in agent_content, ( + f"Expected content from multiple indexes, got '{agent_content}'" + ) async def test_handle_task_create_error( self, @@ -1160,9 +1162,9 @@ def create_mock_stream(*args, **kwargs): # Verify database interactions - should have created messages final_message_count = len(await task_message_repository.list()) - assert ( - final_message_count > initial_message_count - ), "New messages should have been created in database" + assert final_message_count > initial_message_count, ( + "New messages should have been created in database" + ) # Get all messages from database to verify content all_messages = await task_message_repository.list() @@ -1171,33 +1173,33 @@ def create_mock_stream(*args, **kwargs): ] # Only the newly created messages # Should have at least 2 new messages (input + response) - assert ( - len(new_messages) >= 2 - ), f"Expected at least 2 new messages (input + response), got {len(new_messages)}" + assert len(new_messages) >= 2, ( + f"Expected at least 2 new messages (input + response), got {len(new_messages)}" + ) # Verify we have both user input and agent response messages content_authors = {msg.content.author.value for msg in new_messages} assert "user" in content_authors, "Should have user input message in database" - assert ( - "agent" in content_authors - ), "Should have agent response message in database" + assert "agent" in content_authors, ( + "Should have agent response message in database" + ) # Find the agent response message and verify accumulated content was flushed agent_messages = [ msg for msg in new_messages if msg.content.author == MessageAuthor.AGENT ] - assert ( - len(agent_messages) >= 1 - ), "Should have at least one agent response message" + assert len(agent_messages) >= 1, ( + "Should have at least one agent response message" + ) # Verify the deltas were properly accumulated and flushed to database response_message = agent_messages[0] # First agent response - assert ( - "Incomplete" in response_message.content.content - ), f"Expected 'Incomplete' in flushed content, got '{response_message.content.content}'" - assert ( - "message" in response_message.content.content - ), f"Expected 'message' in flushed content, got '{response_message.content.content}'" + assert "Incomplete" in response_message.content.content, ( + f"Expected 'Incomplete' in flushed content, got '{response_message.content.content}'" + ) + assert "message" in response_message.content.content, ( + f"Expected 'message' in flushed content, got '{response_message.content.content}'" + ) async def test_handle_message_send_stream_complex_mixed_content_types( self, @@ -1432,9 +1434,9 @@ def create_mock_stream(*args, **kwargs): # Verify database interactions - should have created messages final_message_count = len(await task_message_repository.list()) - assert ( - final_message_count > initial_message_count - ), "New messages should have been created in database" + assert final_message_count > initial_message_count, ( + "New messages should have been created in database" + ) # Get all messages from database to verify content all_messages = await task_message_repository.list() @@ -1443,18 +1445,18 @@ def create_mock_stream(*args, **kwargs): ] # Only the newly created messages # Should have at least 3 new messages (one for each index) plus deltas potentially stored - assert ( - len(new_messages) >= 3 - ), f"Expected at least 3 new messages, got {len(new_messages)}" + assert len(new_messages) >= 3, ( + f"Expected at least 3 new messages, got {len(new_messages)}" + ) # Verify we have messages with different content types content_types_found = {msg.content.type.value for msg in new_messages} expected_types = {"tool_request", "tool_response", "text"} # At least some of the expected types should be present (depends on how deltas vs full messages are stored) - assert ( - len(content_types_found.intersection(expected_types)) > 0 - ), f"Expected some of {expected_types}, got {content_types_found}" + assert len(content_types_found.intersection(expected_types)) > 0, ( + f"Expected some of {expected_types}, got {content_types_found}" + ) # Verify index distribution - should have messages for different indexes indexes_found = {getattr(update, "index", None) for update in updates} @@ -1478,18 +1480,18 @@ def create_mock_stream(*args, **kwargs): u for u in updates if isinstance(u, StreamTaskMessageDoneEntity) ] - assert ( - len(start_updates) == 3 - ), f"Expected 3 START updates, got {len(start_updates)}" - assert ( - len(delta_updates) == 6 - ), f"Expected 6 DELTA updates, got {len(delta_updates)}" - assert ( - len(full_updates) == 1 - ), f"Expected 1 FULL update, got {len(full_updates)}" - assert ( - len(done_updates) == 2 - ), f"Expected 2 DONE updates, got {len(done_updates)} (index 0 completed with FULL message)" + assert len(start_updates) == 3, ( + f"Expected 3 START updates, got {len(start_updates)}" + ) + assert len(delta_updates) == 6, ( + f"Expected 6 DELTA updates, got {len(delta_updates)}" + ) + assert len(full_updates) == 1, ( + f"Expected 1 FULL update, got {len(full_updates)}" + ) + assert len(done_updates) == 2, ( + f"Expected 2 DONE updates, got {len(done_updates)} (index 0 completed with FULL message)" + ) # Verify content types in START messages start_content_types = {update.content.type.value for update in start_updates} @@ -1641,27 +1643,27 @@ async def mock_async_call(*args, **kwargs): # Verify database interactions - should have created an event final_event_count = len(await event_repository.list()) - assert ( - final_event_count > initial_event_count - ), "New event should have been created in database" + assert final_event_count > initial_event_count, ( + "New event should have been created in database" + ) # Get all events from database to verify content all_events = await event_repository.list() new_events = all_events[initial_event_count:] # Only the newly created events # Should have exactly 1 new event - assert ( - len(new_events) == 1 - ), f"Expected exactly 1 new event, got {len(new_events)}" + assert len(new_events) == 1, ( + f"Expected exactly 1 new event, got {len(new_events)}" + ) # Verify the event was properly stored created_event = new_events[0] - assert ( - created_event.task_id == created_task.id - ), f"Expected task_id {created_task.id}, got {created_event.task_id}" - assert ( - created_event.content == sample_text_content - ), "Expected event content to match input" + assert created_event.task_id == created_task.id, ( + f"Expected task_id {created_task.id}, got {created_event.task_id}" + ) + assert created_event.content == sample_text_content, ( + "Expected event content to match input" + ) async def test_handle_event_send_with_task_name( self, @@ -1713,27 +1715,27 @@ async def mock_async_call(*args, **kwargs): # Verify database interactions - should have created an event final_event_count = len(await event_repository.list()) - assert ( - final_event_count > initial_event_count - ), "New event should have been created in database" + assert final_event_count > initial_event_count, ( + "New event should have been created in database" + ) # Get all events from database to verify content all_events = await event_repository.list() new_events = all_events[initial_event_count:] # Only the newly created events # Should have exactly 1 new event - assert ( - len(new_events) == 1 - ), f"Expected exactly 1 new event, got {len(new_events)}" + assert len(new_events) == 1, ( + f"Expected exactly 1 new event, got {len(new_events)}" + ) # Verify the event was properly stored created_event = new_events[0] - assert ( - created_event.task_id == created_task.id - ), f"Expected task_id {created_task.id}, got {created_event.task_id}" - assert ( - created_event.content == sample_text_content - ), "Expected event content to match input" + assert created_event.task_id == created_task.id, ( + f"Expected task_id {created_task.id}, got {created_event.task_id}" + ) + assert created_event.content == sample_text_content, ( + "Expected event content to match input" + ) async def test_handle_event_send_with_request_headers( self, @@ -1811,9 +1813,9 @@ async def mock_async_call(*args, **kwargs): # Verify database interactions - should have created an event final_event_count = len(await event_repository.list()) - assert ( - final_event_count > initial_event_count - ), "New event should have been created in database" + assert final_event_count > initial_event_count, ( + "New event should have been created in database" + ) # Verify HTTP call was made (mock_async_call will assert headers) mock_http_gateway.async_call.assert_called_once() @@ -1871,9 +1873,9 @@ async def mock_async_call(*args, **kwargs): # Verify database interactions - should have created an event final_event_count = len(await event_repository.list()) - assert ( - final_event_count > initial_event_count - ), "New event should have been created in database" + assert final_event_count > initial_event_count, ( + "New event should have been created in database" + ) # Verify HTTP call was made (mock_async_call will assert no headers) mock_http_gateway.async_call.assert_called_once() @@ -2056,9 +2058,9 @@ def create_mock_stream(*args, **kwargs): # Verify database interactions - should have created messages final_message_count = len(await task_message_repository.list()) - assert ( - final_message_count > initial_message_count - ), "New messages should have been created in database" + assert final_message_count > initial_message_count, ( + "New messages should have been created in database" + ) # Get all messages from database to verify content all_messages = await task_message_repository.list() @@ -2067,26 +2069,26 @@ def create_mock_stream(*args, **kwargs): # Verify we have both user input and agent response messages content_authors = {msg.content.author.value for msg in new_messages} assert "user" in content_authors, "Should have user input message in database" - assert ( - "agent" in content_authors - ), "Should have agent response message in database" + assert "agent" in content_authors, ( + "Should have agent response message in database" + ) # Find the agent response message and verify its final accumulated content agent_messages = [ msg for msg in new_messages if msg.content.author == MessageAuthor.AGENT ] - assert ( - len(agent_messages) >= 1 - ), "Should have at least one agent response message" + assert len(agent_messages) >= 1, ( + "Should have at least one agent response message" + ) # Verify the final accumulated content includes both deltas response_message = agent_messages[0] - assert ( - "Stream response" in response_message.content.content - ), f"Expected 'Stream response' in final content, got '{response_message.content.content}'" - assert ( - "to named task" in response_message.content.content - ), f"Expected 'to named task' in final content, got '{response_message.content.content}'" + assert "Stream response" in response_message.content.content, ( + f"Expected 'Stream response' in final content, got '{response_message.content.content}'" + ) + assert "to named task" in response_message.content.content, ( + f"Expected 'to named task' in final content, got '{response_message.content.content}'" + ) async def test_handle_message_send_sync_with_task_params( self,