diff --git a/src/crawlee/_utils/docs.py b/src/crawlee/_utils/docs.py index f09587c926..f9de21dfc8 100644 --- a/src/crawlee/_utils/docs.py +++ b/src/crawlee/_utils/docs.py @@ -2,7 +2,7 @@ from typing import Callable, Literal -GroupName = Literal['Classes', 'Abstract classes', 'Data structures', 'Errors', 'Functions'] +GroupName = Literal['Classes', 'Abstract classes', 'Data structures', 'Event payloads', 'Errors', 'Functions'] def docs_group(group_name: GroupName) -> Callable: # noqa: ARG001 diff --git a/src/crawlee/events/__init__.py b/src/crawlee/events/__init__.py index f0d986db93..1c2cda0173 100644 --- a/src/crawlee/events/__init__.py +++ b/src/crawlee/events/__init__.py @@ -1,4 +1,25 @@ from ._event_manager import EventManager from ._local_event_manager import LocalEventManager +from ._types import ( + Event, + EventAbortingData, + EventData, + EventExitData, + EventListener, + EventMigratingData, + EventPersistStateData, + EventSystemInfoData, +) -__all__ = ['EventManager', 'LocalEventManager'] +__all__ = [ + 'Event', + 'EventAbortingData', + 'EventData', + 'EventExitData', + 'EventListener', + 'EventManager', + 'EventMigratingData', + 'EventPersistStateData', + 'EventSystemInfoData', + 'LocalEventManager', +] diff --git a/src/crawlee/events/_event_manager.py b/src/crawlee/events/_event_manager.py index ad2f3e82ab..9beec81843 100644 --- a/src/crawlee/events/_event_manager.py +++ b/src/crawlee/events/_event_manager.py @@ -3,11 +3,13 @@ from __future__ import annotations import asyncio +import inspect from collections import defaultdict +from collections.abc import Awaitable, Callable from datetime import timedelta from functools import wraps from logging import getLogger -from typing import TYPE_CHECKING, TypedDict +from typing import TYPE_CHECKING, Any, Literal, TypedDict, Union, cast, overload from pyee.asyncio import AsyncIOEventEmitter @@ -15,14 +17,22 @@ from crawlee._utils.docs import docs_group from crawlee._utils.recurring_task import RecurringTask from crawlee._utils.wait import wait_for_all_tasks_for_finish -from crawlee.events._types import Event, EventPersistStateData +from crawlee.events._types import ( + Event, + EventAbortingData, + EventExitData, + EventListener, + EventMigratingData, + EventPersistStateData, + EventSystemInfoData, +) if TYPE_CHECKING: from types import TracebackType from typing_extensions import NotRequired - from crawlee.events._types import EventData, Listener, WrappedListener + from crawlee.events._types import EventData, WrappedListener logger = getLogger(__name__) @@ -71,7 +81,7 @@ def __init__( # Store the mapping between events, listeners and their wrappers in the following way: # event -> listener -> [wrapped_listener_1, wrapped_listener_2, ...] - self._listeners_to_wrappers: dict[Event, dict[Listener, list[WrappedListener]]] = defaultdict( + self._listeners_to_wrappers: dict[Event, dict[EventListener[Any], list[WrappedListener]]] = defaultdict( lambda: defaultdict(list), ) @@ -125,22 +135,41 @@ async def __aexit__( await self._emit_persist_state_event_rec_task.stop() self._active = False - def on(self, *, event: Event, listener: Listener) -> None: + @overload + def on(self, *, event: Literal[Event.PERSIST_STATE], listener: EventListener[EventPersistStateData]) -> None: ... + @overload + def on(self, *, event: Literal[Event.SYSTEM_INFO], listener: EventListener[EventSystemInfoData]) -> None: ... + @overload + def on(self, *, event: Literal[Event.MIGRATING], listener: EventListener[EventMigratingData]) -> None: ... + @overload + def on(self, *, event: Literal[Event.ABORTING], listener: EventListener[EventAbortingData]) -> None: ... + @overload + def on(self, *, event: Literal[Event.EXIT], listener: EventListener[EventExitData]) -> None: ... + @overload + def on(self, *, event: Event, listener: EventListener[None]) -> None: ... + + def on(self, *, event: Event, listener: EventListener[Any]) -> None: """Add an event listener to the event manager. Args: - event: The Actor event for which to listen to. + event: The event for which to listen to. listener: The function (sync or async) which is to be called when the event is emitted. """ + signature = inspect.signature(listener) - @wraps(listener) + @wraps(cast(Callable[..., Union[None, Awaitable[None]]], listener)) async def listener_wrapper(event_data: EventData) -> None: + try: + bound_args = signature.bind(event_data) + except TypeError: # Parameterless listener + bound_args = signature.bind() + # If the listener is a coroutine function, just call it, otherwise, run it in a separate thread # to avoid blocking the event loop coro = ( - listener(event_data) + listener(*bound_args.args, **bound_args.kwargs) if asyncio.iscoroutinefunction(listener) - else asyncio.to_thread(listener, event_data) + else asyncio.to_thread(cast(Callable[..., None], listener), *bound_args.args, **bound_args.kwargs) ) # Note: use `asyncio.iscoroutinefunction` rather then `inspect.iscoroutinefunction` since it works with # unittests.mock.AsyncMock. See https://github.com/python/cpython/issues/84753. @@ -165,7 +194,7 @@ async def listener_wrapper(event_data: EventData) -> None: self._listeners_to_wrappers[event][listener].append(listener_wrapper) self._event_emitter.add_listener(event.value, listener_wrapper) - def off(self, *, event: Event, listener: Listener | None = None) -> None: + def off(self, *, event: Event, listener: EventListener[Any] | None = None) -> None: """Remove a listener, or all listeners, from an Actor event. Args: @@ -181,6 +210,19 @@ def off(self, *, event: Event, listener: Listener | None = None) -> None: self._listeners_to_wrappers[event] = defaultdict(list) self._event_emitter.remove_all_listeners(event.value) + @overload + def emit(self, *, event: Literal[Event.PERSIST_STATE], event_data: EventPersistStateData) -> None: ... + @overload + def emit(self, *, event: Literal[Event.SYSTEM_INFO], event_data: EventSystemInfoData) -> None: ... + @overload + def emit(self, *, event: Literal[Event.MIGRATING], event_data: EventMigratingData) -> None: ... + @overload + def emit(self, *, event: Literal[Event.ABORTING], event_data: EventAbortingData) -> None: ... + @overload + def emit(self, *, event: Literal[Event.EXIT], event_data: EventExitData) -> None: ... + @overload + def emit(self, *, event: Event, event_data: Any) -> None: ... + @ensure_context def emit(self, *, event: Event, event_data: EventData) -> None: """Emit an event. diff --git a/src/crawlee/events/_types.py b/src/crawlee/events/_types.py index 7aedfd270b..22c571c58d 100644 --- a/src/crawlee/events/_types.py +++ b/src/crawlee/events/_types.py @@ -2,15 +2,16 @@ from collections.abc import Callable, Coroutine from enum import Enum -from typing import Annotated, Any, Union +from typing import Annotated, Any, TypeVar, Union from pydantic import BaseModel, ConfigDict, Field +from crawlee._utils.docs import docs_group from crawlee._utils.system import CpuInfo, MemoryUsageInfo class Event(str, Enum): - """Enum of all possible events that can be emitted.""" + """Names of all possible events that can be emitted using an `EventManager`.""" # Core events PERSIST_STATE = 'persistState' @@ -30,6 +31,7 @@ class Event(str, Enum): PAGE_CLOSED = 'pageClosed' +@docs_group('Event payloads') class EventPersistStateData(BaseModel): """Data for the persist state event.""" @@ -38,6 +40,7 @@ class EventPersistStateData(BaseModel): is_migrating: Annotated[bool, Field(alias='isMigrating')] +@docs_group('Event payloads') class EventSystemInfoData(BaseModel): """Data for the system info event.""" @@ -50,18 +53,21 @@ class EventSystemInfoData(BaseModel): ] +@docs_group('Event payloads') class EventMigratingData(BaseModel): """Data for the migrating event.""" model_config = ConfigDict(populate_by_name=True) +@docs_group('Event payloads') class EventAbortingData(BaseModel): """Data for the aborting event.""" model_config = ConfigDict(populate_by_name=True) +@docs_group('Event payloads') class EventExitData(BaseModel): """Data for the exit event.""" @@ -69,7 +75,19 @@ class EventExitData(BaseModel): EventData = Union[EventPersistStateData, EventSystemInfoData, EventMigratingData, EventAbortingData, EventExitData] -SyncListener = Callable[..., None] -AsyncListener = Callable[..., Coroutine[Any, Any, None]] -Listener = Union[SyncListener, AsyncListener] +"""A helper type for all possible event payloads""" + WrappedListener = Callable[..., Coroutine[Any, Any, None]] + +TEvent = TypeVar('TEvent') +EventListener = Union[ + Callable[ + [TEvent], + Union[None, Coroutine[Any, Any, None]], + ], + Callable[ + [], + Union[None, Coroutine[Any, Any, None]], + ], +] +"""An event listener function - it can be both sync and async and may accept zero or one argument.""" diff --git a/tests/unit/events/test_event_manager.py b/tests/unit/events/test_event_manager.py index 516501811e..853c773f5d 100644 --- a/tests/unit/events/test_event_manager.py +++ b/tests/unit/events/test_event_manager.py @@ -3,13 +3,13 @@ import asyncio import logging from datetime import timedelta +from functools import update_wrapper from typing import TYPE_CHECKING, Any from unittest.mock import AsyncMock, MagicMock import pytest -from crawlee.events import EventManager -from crawlee.events._types import Event, EventSystemInfoData +from crawlee.events import Event, EventManager, EventSystemInfoData if TYPE_CHECKING: from collections.abc import AsyncGenerator @@ -28,15 +28,21 @@ def event_system_info_data() -> EventSystemInfoData: @pytest.fixture def async_listener() -> AsyncMock: + async def async_listener(payload: Any) -> None: + pass + al = AsyncMock() - al.__name__ = 'async_listener' # To avoid issues with the function name + update_wrapper(al, async_listener) return al @pytest.fixture def sync_listener() -> MagicMock: + def sync_listener(payload: Any) -> None: + pass + sl = MagicMock() - sl.__name__ = 'sync_listener' # To avoid issues with the function name + update_wrapper(sl, sync_listener) return sl @@ -83,11 +89,36 @@ async def test_emit_event_with_no_listeners( # Attempt to emit an event for which no listeners are registered, it should not fail event_manager.emit(event=Event.SYSTEM_INFO, event_data=event_system_info_data) + await asyncio.sleep(0.1) # Allow some time for the event to be processed # Ensure the listener for the other event was not called assert async_listener.call_count == 0 +async def test_emit_invokes_parameterless_listener( + event_manager: EventManager, + event_system_info_data: EventSystemInfoData, +) -> None: + sync_mock = MagicMock() + + def sync_listener() -> None: + sync_mock() + + async_mock = MagicMock() + + async def async_listener() -> None: + async_mock() + + event_manager.on(event=Event.SYSTEM_INFO, listener=sync_listener) + event_manager.on(event=Event.SYSTEM_INFO, listener=async_listener) + + event_manager.emit(event=Event.SYSTEM_INFO, event_data=event_system_info_data) + await asyncio.sleep(0.1) # Allow some time for the event to be processed + + assert sync_mock.call_count == 1 + assert async_mock.call_count == 1 + + async def test_remove_nonexistent_listener_does_not_fail( async_listener: AsyncMock, event_manager: EventManager, diff --git a/tests/unit/events/test_local_event_manager.py b/tests/unit/events/test_local_event_manager.py index 65a22d43cb..481c7da16b 100644 --- a/tests/unit/events/test_local_event_manager.py +++ b/tests/unit/events/test_local_event_manager.py @@ -2,6 +2,8 @@ import asyncio from datetime import timedelta +from functools import update_wrapper +from typing import Any from unittest.mock import AsyncMock import pytest @@ -12,8 +14,11 @@ @pytest.fixture def listener() -> AsyncMock: + async def async_listener(payload: Any) -> None: + pass + al = AsyncMock() - al.__name__ = 'listener' # To avoid issues with the function name + update_wrapper(al, async_listener) return al