diff --git a/examples/app/main_async.py b/examples/app/main_async.py index 39d1a4d..2fc6e4b 100644 --- a/examples/app/main_async.py +++ b/examples/app/main_async.py @@ -72,9 +72,9 @@ async def echo(inv: Invocation) -> Result: @app.subscribe("io.xconn.yo") -async def login(event: Event) -> None: +async def login(name: str, city: str) -> None: print(app.session) - print(event.args) + print(name, city) @app.register("io.xconn.dynamic") diff --git a/examples/app/main_sync.py b/examples/app/main_sync.py index cbfa78d..96f698f 100644 --- a/examples/app/main_sync.py +++ b/examples/app/main_sync.py @@ -53,9 +53,9 @@ def echo(inv: Invocation) -> Result: @app.subscribe("io.xconn.yo") -def login(event: Event) -> None: +def login(name: str, city: str) -> None: print(app.session) - print(event.args) + print(name, city) @app.register("io.xconn.not_allowed", allowed_roles=["test"]) diff --git a/xconn/_client/async_.py b/xconn/_client/async_.py index d8e3241..e409ecb 100644 --- a/xconn/_client/async_.py +++ b/xconn/_client/async_.py @@ -19,6 +19,7 @@ MAX_WAIT, ProcedureMetadata, assemble_call_details, + assemble_event_details, ) from xconn._client.types import ClientConfig from xconn.client import AsyncClient @@ -146,22 +147,33 @@ async def subscribe_async(session: AsyncSession, topic: str, func: callable): if not inspect.iscoroutinefunction(func): raise RuntimeError(f"function {func.__name__} for topic '{topic}' must be a coroutine") - model, positional_args, options = _validate_topic_function(func, topic) + meta = _validate_topic_function(func, topic) async def _handle_event(event: Event) -> None: - if model is not None: - kwargs = _sanitize_incoming_data(event.args, event.kwargs, positional_args) - try: - await func(model(**kwargs)) - except Exception as e: - print(e) + details = assemble_event_details(topic, meta, event) - return + if meta.dynamic_model: + kwargs = _sanitize_incoming_data(event.args, event.kwargs, meta.request_args) + handle_model_validation(meta.request_model, **kwargs) - try: - await func(event) - except Exception as e: - print(e) + async with resolve_dependencies(meta) as deps: + await func(**kwargs, **deps, **details) + + elif meta.request_model is not None: + kwargs = _sanitize_incoming_data(event.args, event.kwargs, meta.request_args) + model = handle_model_validation(meta.request_model, **kwargs) + + async with resolve_dependencies(meta) as deps: + input_data = {meta.positional_field_name: model} + await func(**input_data, **deps, **details) + + elif meta.no_args: + async with resolve_dependencies(meta) as deps: + await func(**deps, **details) + else: + async with resolve_dependencies(meta) as deps: + input_data = {meta.positional_field_name: event} + await func(**input_data, **deps, **details) - await session.subscribe(topic, _handle_event, options=options) + await session.subscribe(topic, _handle_event, options=meta.subscribe_options) print(f"Subscribed topic {topic}") diff --git a/xconn/_client/helpers.py b/xconn/_client/helpers.py index d74106b..31a8391 100644 --- a/xconn/_client/helpers.py +++ b/xconn/_client/helpers.py @@ -33,39 +33,46 @@ from xconn import Router, Server from xconn._client.types import ClientConfig from xconn.exception import ApplicationError -from xconn.types import Event, Invocation, Result, Depends, CallDetails, RegisterOptions +from xconn.types import Event, Invocation, Result, Depends, CallDetails, RegisterOptions, SubscribeOptions, EventDetails MAX_WAIT = 300 INITIAL_WAIT = 1 @dataclass -class ProcedureMetadata: +class BaseMetadata: request_model: Type[BaseModel] | None - response_model: Type[BaseModel] | None - request_args: list[str] - response_args: list[str] - request_kwargs: list[str] - response_kwargs: list[str] no_args: bool dynamic_model: bool - allowed_roles: list[str] - dependencies: dict[str, Callable] ctx_dependencies: dict[str, ContextManager] async_dependencies: dict[str, Awaitable] async_ctx_dependencies: dict[str, AsyncContextManager] - call_details_field: str | None + details_field: str | None positional_field_name: str | None + +@dataclass +class ProcedureMetadata(BaseMetadata): + response_model: Type[BaseModel] | None + response_args: list[str] + response_kwargs: list[str] + + allowed_roles: list[str] + register_options: dict[str, Any] | RegisterOptions | None +@dataclass +class EventMetadata(BaseMetadata): + subscribe_options: dict[str, Any] | SubscribeOptions | None + + def create_model_from_func(func): signature = inspect.signature(func) type_hints = get_type_hints(func) @@ -97,14 +104,29 @@ def is_subclass_of_any(type_, base_class: Any) -> bool: return isinstance(type_, type) and issubclass(type_, base_class) -def _validate_procedure_function(func: callable, uri: str) -> ProcedureMetadata: +def _do_it( + func: callable, + uri: str, + incoming_class: type(Invocation) | type(Event), + details_class: type(CallDetails) | type(EventDetails), +) -> BaseMetadata: sig = inspect.signature(func) for name, param in sig.parameters.items(): if param.annotation is inspect._empty: raise RuntimeError(f"Missing type hint for parameter: '{name}' in function '{func.__name__}'") + if issubclass(incoming_class, Event): + context = "topic" + elif issubclass(incoming_class, Invocation): + context = "procedure" + else: + raise RuntimeError(f"incoming_class can be with Invocation or Event got='{incoming_class.__name__}'") + + if not issubclass(details_class, CallDetails) and not issubclass(details_class, EventDetails): + raise RuntimeError(f"details_class can be either CallDetails or EventDetails got='{details_class.__name__}'") + hints = get_type_hints(func) - hints.pop("return") if "return" in hints else None + hints.pop("return", None) request_model = None request_args = [] @@ -147,39 +169,43 @@ def _validate_procedure_function(func: callable, uri: str) -> ProcedureMetadata: del hints[key] # check if CallDetails are in the function - call_details_field = None + details_field = None for name, type_ in hints.items(): - if is_subclass_of_any(type_, CallDetails): - if call_details_field is not None: - raise RuntimeError(f"Duplicate call details in function '{func.__name__}'") + if is_subclass_of_any(type_, details_class): + if details_field is not None: + raise RuntimeError(f"Duplicate {details_class.__name__} in function '{func.__name__}'") - call_details_field = name + details_field = name - if call_details_field is not None: - del hints[call_details_field] + if details_field is not None: + del hints[details_field] positional_field_name: str | None = None - has_invocation_in_sig = False + has_incoming_class_in_sig = False for name, type_ in hints.items(): if issubclass(type_, BaseModel): if len(hints) != 1: - raise RuntimeError(f"Cannot mix pydantic BaseModel with other types in signature of procedure '{uri}'") + raise RuntimeError(f"Cannot mix pydantic BaseModel with other types in signature of {context} '{uri}'") request_model = type_ positional_field_name = name break - if is_subclass_of_any(type_, Invocation): - if has_invocation_in_sig: - raise RuntimeError(f"Cannot use other types than 'Invocation' as arguments in procedure '{uri}'") + if is_subclass_of_any(type_, incoming_class): + if has_incoming_class_in_sig: + raise RuntimeError( + f"Cannot use other types than '{incoming_class.__name__}' as arguments in {context} '{uri}'" + ) - has_invocation_in_sig = True + has_incoming_class_in_sig = True positional_field_name = name if Invocation in hints.values(): if len(hints) != 1: - raise RuntimeError(f"Cannot use other types than 'Invocation' as arguments in procedure '{uri}'") + raise RuntimeError( + f"Cannot use other types than '{incoming_class.__name__}' as arguments in {context} '{uri}'" + ) elif request_model is not None: for key, value in request_model.model_fields.items(): if value.is_required: @@ -199,6 +225,24 @@ def _validate_procedure_function(func: callable, uri: str) -> ProcedureMetadata: dynamic_model = True + return BaseMetadata( + request_model=request_model, + request_args=request_args, + request_kwargs=request_kwargs, + no_args=no_args, + dynamic_model=dynamic_model, + dependencies=dependencies, + ctx_dependencies=ctx_dependencies, + async_dependencies=async_dependencies, + async_ctx_dependencies=async_ctx_dependencies, + details_field=details_field, + positional_field_name=positional_field_name, + ) + + +def _validate_procedure_function(func: callable, uri: str) -> ProcedureMetadata: + meta = _do_it(func, uri, Invocation, CallDetails) + response_model = getattr(func, "__xconn_response_model__", None) response_args = [] response_kwargs = [] @@ -213,21 +257,21 @@ def _validate_procedure_function(func: callable, uri: str) -> ProcedureMetadata: register_options = getattr(func, "__xconn_register_options__", None) return ProcedureMetadata( - request_model=request_model, + request_model=meta.request_model, response_model=response_model, - request_args=request_args, + request_args=meta.request_args, response_args=response_args, - request_kwargs=request_kwargs, + request_kwargs=meta.request_kwargs, response_kwargs=response_kwargs, - no_args=no_args, - dynamic_model=dynamic_model, + no_args=meta.no_args, + dynamic_model=meta.dynamic_model, allowed_roles=allowed_roles, - dependencies=dependencies, - ctx_dependencies=ctx_dependencies, - async_dependencies=async_dependencies, - async_ctx_dependencies=async_ctx_dependencies, - call_details_field=call_details_field, - positional_field_name=positional_field_name, + dependencies=meta.dependencies, + ctx_dependencies=meta.ctx_dependencies, + async_dependencies=meta.async_dependencies, + async_ctx_dependencies=meta.async_ctx_dependencies, + details_field=meta.details_field, + positional_field_name=meta.positional_field_name, register_options=register_options, ) @@ -306,36 +350,23 @@ def _handle_result( def _validate_topic_function(func: callable, uri: str): - sig = inspect.signature(func) - for name, param in sig.parameters.items(): - if param.annotation is inspect._empty: - raise RuntimeError(f"Missing type hint for parameter: '{name}' in function '{func.__name__}'") - - hints = get_type_hints(func) - hints.pop("return") if "return" in hints else None - - if Event in hints.values(): - if len(hints) != 1: - raise RuntimeError(f"Cannot use other types than 'Event' as arguments in subscription '{uri}'") - - pydantic_model = None - positional_args = [] - for type_ in hints.values(): - if issubclass(type_, BaseModel): - if len(hints) != 1: - raise RuntimeError( - f"Cannot mix pydantic dataclass with other types in signature of subscription '{uri}'" - ) - - pydantic_model = type_ - - for key, value in pydantic_model.model_fields.items(): - if value.is_required: - positional_args.append(key) - + meta = _do_it(func, uri, Event, EventDetails) options = getattr(func, "__xconn_subscribe_options__", None) - return pydantic_model, positional_args, options + return EventMetadata( + request_model=meta.request_model, + request_args=meta.request_args, + request_kwargs=meta.request_kwargs, + no_args=meta.no_args, + dynamic_model=meta.dynamic_model, + dependencies=meta.dependencies, + ctx_dependencies=meta.ctx_dependencies, + async_dependencies=meta.async_dependencies, + async_ctx_dependencies=meta.async_ctx_dependencies, + details_field=meta.details_field, + positional_field_name=meta.positional_field_name, + subscribe_options=options, + ) def _sanitize_incoming_data(args: list, kwargs: dict, model_positional_args: list[str]): @@ -466,11 +497,23 @@ def ensure_caller_allowed(call_details: dict[str, Any], allowed_roles: list[str] def assemble_call_details(uri: str, meta: ProcedureMetadata, invocation: Invocation): details = {} - if meta.call_details_field is not None: + if meta.details_field is not None: if not invocation.details: msg = f"Endpoint for procedure {uri} expects CallDetails but router did not send them" raise ApplicationError("wamp.error.internal_error", msg) - details[meta.call_details_field] = CallDetails(invocation.details) + details[meta.details_field] = CallDetails(invocation.details) + + return details + + +def assemble_event_details(uri: str, meta: EventMetadata, event: Event): + details = {} + if meta.details_field is not None: + if not event.details: + msg = f"Endpoint for topic {uri} expects EventDetails but router did not send them" + raise ApplicationError("wamp.error.internal_error", msg) + + details[meta.details_field] = CallDetails(event.details) return details diff --git a/xconn/_client/sync.py b/xconn/_client/sync.py index ad63789..fd8f37a 100644 --- a/xconn/_client/sync.py +++ b/xconn/_client/sync.py @@ -22,6 +22,7 @@ MAX_WAIT, ProcedureMetadata, assemble_call_details, + assemble_event_details, ) from xconn._client.types import ClientConfig from xconn.client import Client @@ -152,23 +153,33 @@ def subscribe_sync(session: Session, topic: str, func: callable): if inspect.iscoroutinefunction(func): raise RuntimeError(f"function {func.__name__} for topic '{topic}' must not be a coroutine") - model, positional_args, options = _validate_topic_function(func, topic) + meta = _validate_topic_function(func, topic) def _handle_event(event: Event): - if model is not None: - kwargs = _sanitize_incoming_data(event.args, event.kwargs, positional_args) + details = assemble_event_details(topic, meta, event) - try: - func(model(**kwargs)) - except Exception as e: - print(e) + if meta.dynamic_model: + kwargs = _sanitize_incoming_data(event.args, event.kwargs, meta.request_args) + handle_model_validation(meta.request_model, **kwargs) - return + with resolve_dependencies(meta) as deps: + func(**kwargs, **deps, **details) - try: - func(event) - except Exception as e: - print(e) + elif meta.request_model is not None: + kwargs = _sanitize_incoming_data(event.args, event.kwargs, meta.request_args) + model = handle_model_validation(meta.request_model, **kwargs) + + with resolve_dependencies(meta) as deps: + input_data = {meta.positional_field_name: model} + func(**input_data, **deps, **details) + + elif meta.no_args: + with resolve_dependencies(meta) as deps: + func(**deps, **details) + else: + with resolve_dependencies(meta) as deps: + input_data = {meta.positional_field_name: event} + func(**input_data, **deps, **details) - session.subscribe(topic, _handle_event, options=options) + session.subscribe(topic, _handle_event, options=meta.subscribe_options) print(f"Subscribed topic {topic}") diff --git a/xconn/async_session.py b/xconn/async_session.py index 8d5a3b6..3614e76 100644 --- a/xconn/async_session.py +++ b/xconn/async_session.py @@ -146,7 +146,10 @@ async def process_incoming_message(self, msg: messages.Message): request.set_result(None) elif isinstance(msg, messages.Event): endpoint = self.subscriptions[msg.subscription_id] - await endpoint(types.Event(msg.args, msg.kwargs, msg.details)) + try: + await endpoint(types.Event(msg.args, msg.kwargs, msg.details)) + except Exception as e: + print(e) elif isinstance(msg, messages.Error): match msg.message_type: case messages.Call.TYPE: diff --git a/xconn/types.py b/xconn/types.py index e60a28d..48a40f8 100644 --- a/xconn/types.py +++ b/xconn/types.py @@ -514,10 +514,22 @@ def authrole(self) -> str | None: return self.get("caller_authrole") -class PublicationDetails(_IncomingDetails): +class EventDetails(_IncomingDetails): def __init__(self, details: dict | None = None): super().__init__(details) + @property + def session_id(self) -> int | None: + return self.get("publisher") + + @property + def authid(self) -> str | None: + return self.get("publisher_authid") + + @property + def authrole(self) -> str | None: + return self.get("publisher_authrole") + class InvokeOptions(Enum): SINGLE = "single"