Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Expose event types, improve on/emit signature, allow parameterless listeners #800

Merged
merged 6 commits into from
Dec 16, 2024
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/crawlee/_utils/docs.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from typing import Callable, Literal

GroupName = Literal['Classes', 'Abstract classes', 'Data structures', 'Errors', 'Functions']
GroupName = Literal['Classes', 'Abstract classes', 'Data structures', 'Event payloads', 'Errors', 'Functions']


def docs_group(group_name: GroupName) -> Callable: # noqa: ARG001
Expand Down
23 changes: 22 additions & 1 deletion src/crawlee/events/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,25 @@
from ._event_manager import EventManager
from ._local_event_manager import LocalEventManager
from ._types import (
Event,
EventAbortingData,
EventData,
EventExitData,
EventListener,
EventMigratingData,
EventPersistStateData,
EventSystemInfoData,
)

__all__ = ['EventManager', 'LocalEventManager']
__all__ = [
'Event',
'EventAbortingData',
'EventData',
'EventExitData',
'EventListener',
'EventManager',
'EventMigratingData',
'EventPersistStateData',
'EventSystemInfoData',
'LocalEventManager',
]
janbuchar marked this conversation as resolved.
Show resolved Hide resolved
62 changes: 52 additions & 10 deletions src/crawlee/events/_event_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,26 +3,36 @@
from __future__ import annotations

import asyncio
import inspect
from collections import defaultdict
from collections.abc import Awaitable, Callable
from datetime import timedelta
from functools import wraps
from logging import getLogger
from typing import TYPE_CHECKING, TypedDict
from typing import TYPE_CHECKING, Any, Literal, TypedDict, Union, cast, overload

from pyee.asyncio import AsyncIOEventEmitter

from crawlee._utils.context import ensure_context
from crawlee._utils.docs import docs_group
from crawlee._utils.recurring_task import RecurringTask
from crawlee._utils.wait import wait_for_all_tasks_for_finish
from crawlee.events._types import Event, EventPersistStateData
from crawlee.events._types import (
Event,
EventAbortingData,
EventExitData,
EventListener,
EventMigratingData,
EventPersistStateData,
EventSystemInfoData,
)

if TYPE_CHECKING:
from types import TracebackType

from typing_extensions import NotRequired

from crawlee.events._types import EventData, Listener, WrappedListener
from crawlee.events._types import EventData, WrappedListener

logger = getLogger(__name__)

Expand Down Expand Up @@ -71,7 +81,7 @@ def __init__(

# Store the mapping between events, listeners and their wrappers in the following way:
# event -> listener -> [wrapped_listener_1, wrapped_listener_2, ...]
self._listeners_to_wrappers: dict[Event, dict[Listener, list[WrappedListener]]] = defaultdict(
self._listeners_to_wrappers: dict[Event, dict[EventListener[Any], list[WrappedListener]]] = defaultdict(
lambda: defaultdict(list),
)

Expand Down Expand Up @@ -125,22 +135,41 @@ async def __aexit__(
await self._emit_persist_state_event_rec_task.stop()
self._active = False

def on(self, *, event: Event, listener: Listener) -> None:
@overload
def on(self, *, event: Literal[Event.PERSIST_STATE], listener: EventListener[EventPersistStateData]) -> None: ...
@overload
def on(self, *, event: Literal[Event.SYSTEM_INFO], listener: EventListener[EventSystemInfoData]) -> None: ...
@overload
def on(self, *, event: Literal[Event.MIGRATING], listener: EventListener[EventMigratingData]) -> None: ...
@overload
def on(self, *, event: Literal[Event.ABORTING], listener: EventListener[EventAbortingData]) -> None: ...
@overload
def on(self, *, event: Literal[Event.EXIT], listener: EventListener[EventExitData]) -> None: ...
@overload
def on(self, *, event: Event, listener: EventListener[None]) -> None: ...
Pijukatel marked this conversation as resolved.
Show resolved Hide resolved

def on(self, *, event: Event, listener: EventListener[Any]) -> None:
Pijukatel marked this conversation as resolved.
Show resolved Hide resolved
"""Add an event listener to the event manager.

Args:
event: The Actor event for which to listen to.
event: The event for which to listen to.
listener: The function (sync or async) which is to be called when the event is emitted.
"""
signature = inspect.signature(listener)

@wraps(listener)
@wraps(cast(Callable[..., Union[None, Awaitable[None]]], listener))
async def listener_wrapper(event_data: EventData) -> None:
try:
bound_args = signature.bind(event_data)
except TypeError: # Parameterless listener
bound_args = signature.bind()

# If the listener is a coroutine function, just call it, otherwise, run it in a separate thread
# to avoid blocking the event loop
coro = (
listener(event_data)
listener(*bound_args.args, **bound_args.kwargs)
if asyncio.iscoroutinefunction(listener)
else asyncio.to_thread(listener, event_data)
else asyncio.to_thread(cast(Callable[..., None], listener), *bound_args.args, **bound_args.kwargs)
)
# Note: use `asyncio.iscoroutinefunction` rather then `inspect.iscoroutinefunction` since it works with
# unittests.mock.AsyncMock. See https://github.com/python/cpython/issues/84753.
Expand All @@ -165,7 +194,7 @@ async def listener_wrapper(event_data: EventData) -> None:
self._listeners_to_wrappers[event][listener].append(listener_wrapper)
self._event_emitter.add_listener(event.value, listener_wrapper)

def off(self, *, event: Event, listener: Listener | None = None) -> None:
def off(self, *, event: Event, listener: EventListener[Any] | None = None) -> None:
"""Remove a listener, or all listeners, from an Actor event.

Args:
Expand All @@ -181,6 +210,19 @@ def off(self, *, event: Event, listener: Listener | None = None) -> None:
self._listeners_to_wrappers[event] = defaultdict(list)
self._event_emitter.remove_all_listeners(event.value)

@overload
def emit(self, *, event: Literal[Event.PERSIST_STATE], event_data: EventPersistStateData) -> None: ...
@overload
def emit(self, *, event: Literal[Event.SYSTEM_INFO], event_data: EventSystemInfoData) -> None: ...
@overload
def emit(self, *, event: Literal[Event.MIGRATING], event_data: EventMigratingData) -> None: ...
@overload
def emit(self, *, event: Literal[Event.ABORTING], event_data: EventAbortingData) -> None: ...
@overload
def emit(self, *, event: Literal[Event.EXIT], event_data: EventExitData) -> None: ...
@overload
def emit(self, *, event: Event, event_data: Any) -> None: ...

@ensure_context
def emit(self, *, event: Event, event_data: EventData) -> None:
"""Emit an event.
Expand Down
28 changes: 23 additions & 5 deletions src/crawlee/events/_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,16 @@

from collections.abc import Callable, Coroutine
from enum import Enum
from typing import Annotated, Any, Union
from typing import Annotated, Any, TypeVar, Union

from pydantic import BaseModel, ConfigDict, Field

from crawlee._utils.docs import docs_group
from crawlee._utils.system import CpuInfo, MemoryUsageInfo


class Event(str, Enum):
"""Enum of all possible events that can be emitted."""
"""Names of all possible events that can be emitted using an EventManager."""
vdusek marked this conversation as resolved.
Show resolved Hide resolved

# Core events
PERSIST_STATE = 'persistState'
Expand All @@ -30,6 +31,7 @@ class Event(str, Enum):
PAGE_CLOSED = 'pageClosed'


@docs_group('Event payloads')
class EventPersistStateData(BaseModel):
"""Data for the persist state event."""

Expand All @@ -38,6 +40,7 @@ class EventPersistStateData(BaseModel):
is_migrating: Annotated[bool, Field(alias='isMigrating')]


@docs_group('Event payloads')
class EventSystemInfoData(BaseModel):
"""Data for the system info event."""

Expand All @@ -50,26 +53,41 @@ class EventSystemInfoData(BaseModel):
]


@docs_group('Event payloads')
class EventMigratingData(BaseModel):
"""Data for the migrating event."""

model_config = ConfigDict(populate_by_name=True)


@docs_group('Event payloads')
class EventAbortingData(BaseModel):
"""Data for the aborting event."""

model_config = ConfigDict(populate_by_name=True)


@docs_group('Event payloads')
class EventExitData(BaseModel):
"""Data for the exit event."""

model_config = ConfigDict(populate_by_name=True)


EventData = Union[EventPersistStateData, EventSystemInfoData, EventMigratingData, EventAbortingData, EventExitData]
SyncListener = Callable[..., None]
AsyncListener = Callable[..., Coroutine[Any, Any, None]]
Listener = Union[SyncListener, AsyncListener]
"""A helper type for all possible event payloads"""

WrappedListener = Callable[..., Coroutine[Any, Any, None]]

TEvent = TypeVar('TEvent')
EventListener = Union[
Callable[
[TEvent],
Union[None, Coroutine[Any, Any, None]],
],
Callable[
[],
Union[None, Coroutine[Any, Any, None]],
],
]
"""An event listener function - it can be both sync and async and may accept zero or one argument."""
39 changes: 35 additions & 4 deletions tests/unit/events/test_event_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,13 @@
import asyncio
import logging
from datetime import timedelta
from functools import update_wrapper
from typing import TYPE_CHECKING, Any
from unittest.mock import AsyncMock, MagicMock

import pytest

from crawlee.events import EventManager
from crawlee.events._types import Event, EventSystemInfoData
from crawlee.events import Event, EventManager, EventSystemInfoData

if TYPE_CHECKING:
from collections.abc import AsyncGenerator
Expand All @@ -28,15 +28,21 @@ def event_system_info_data() -> EventSystemInfoData:

@pytest.fixture
def async_listener() -> AsyncMock:
async def async_listener(payload: Any) -> None:
pass

al = AsyncMock()
al.__name__ = 'async_listener' # To avoid issues with the function name
update_wrapper(al, async_listener)
return al


@pytest.fixture
def sync_listener() -> MagicMock:
def sync_listener(payload: Any) -> None:
pass

sl = MagicMock()
sl.__name__ = 'sync_listener' # To avoid issues with the function name
update_wrapper(sl, sync_listener)
return sl


Expand Down Expand Up @@ -83,11 +89,36 @@ async def test_emit_event_with_no_listeners(

# Attempt to emit an event for which no listeners are registered, it should not fail
event_manager.emit(event=Event.SYSTEM_INFO, event_data=event_system_info_data)
await asyncio.sleep(0.1) # Allow some time for the event to be processed

# Ensure the listener for the other event was not called
assert async_listener.call_count == 0


async def test_emit_invokes_parameterless_listener(
event_manager: EventManager,
event_system_info_data: EventSystemInfoData,
) -> None:
sync_mock = MagicMock()

def sync_listener() -> None:
sync_mock()

async_mock = MagicMock()

async def async_listener() -> None:
async_mock()

event_manager.on(event=Event.SYSTEM_INFO, listener=sync_listener)
event_manager.on(event=Event.SYSTEM_INFO, listener=async_listener)

event_manager.emit(event=Event.SYSTEM_INFO, event_data=event_system_info_data)
await asyncio.sleep(0.1) # Allow some time for the event to be processed
janbuchar marked this conversation as resolved.
Show resolved Hide resolved

assert sync_mock.call_count == 1
assert async_mock.call_count == 1


async def test_remove_nonexistent_listener_does_not_fail(
async_listener: AsyncMock,
event_manager: EventManager,
Expand Down
7 changes: 6 additions & 1 deletion tests/unit/events/test_local_event_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@

import asyncio
from datetime import timedelta
from functools import update_wrapper
from typing import Any
from unittest.mock import AsyncMock

import pytest
Expand All @@ -12,8 +14,11 @@

@pytest.fixture
def listener() -> AsyncMock:
async def async_listener(payload: Any) -> None:
pass

al = AsyncMock()
al.__name__ = 'listener' # To avoid issues with the function name
update_wrapper(al, async_listener)
return al


Expand Down
Loading