Skip to content

bring event functions at parity with procedures #145

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jun 26, 2025
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions examples/app/main_async.py
Original file line number Diff line number Diff line change
@@ -72,9 +72,9 @@ async def echo(inv: Invocation) -> Result:


@app.subscribe("io.xconn.yo")
async def login(event: Event) -> None:
async def login(name: str, city: str) -> None:
print(app.session)
print(event.args)
print(name, city)


@app.register("io.xconn.dynamic")
4 changes: 2 additions & 2 deletions examples/app/main_sync.py
Original file line number Diff line number Diff line change
@@ -53,9 +53,9 @@ def echo(inv: Invocation) -> Result:


@app.subscribe("io.xconn.yo")
def login(event: Event) -> None:
def login(name: str, city: str) -> None:
print(app.session)
print(event.args)
print(name, city)


@app.register("io.xconn.not_allowed", allowed_roles=["test"])
38 changes: 25 additions & 13 deletions xconn/_client/async_.py
Original file line number Diff line number Diff line change
@@ -19,6 +19,7 @@
MAX_WAIT,
ProcedureMetadata,
assemble_call_details,
assemble_event_details,
)
from xconn._client.types import ClientConfig
from xconn.client import AsyncClient
@@ -146,22 +147,33 @@ async def subscribe_async(session: AsyncSession, topic: str, func: callable):
if not inspect.iscoroutinefunction(func):
raise RuntimeError(f"function {func.__name__} for topic '{topic}' must be a coroutine")

model, positional_args, options = _validate_topic_function(func, topic)
meta = _validate_topic_function(func, topic)

async def _handle_event(event: Event) -> None:
if model is not None:
kwargs = _sanitize_incoming_data(event.args, event.kwargs, positional_args)
try:
await func(model(**kwargs))
except Exception as e:
print(e)
details = assemble_event_details(topic, meta, event)

return
if meta.dynamic_model:
kwargs = _sanitize_incoming_data(event.args, event.kwargs, meta.request_args)
handle_model_validation(meta.request_model, **kwargs)

try:
await func(event)
except Exception as e:
print(e)
async with resolve_dependencies(meta) as deps:
await func(**kwargs, **deps, **details)

elif meta.request_model is not None:
kwargs = _sanitize_incoming_data(event.args, event.kwargs, meta.request_args)
model = handle_model_validation(meta.request_model, **kwargs)

async with resolve_dependencies(meta) as deps:
input_data = {meta.positional_field_name: model}
await func(**input_data, **deps, **details)

elif meta.no_args:
async with resolve_dependencies(meta) as deps:
await func(**deps, **details)
else:
async with resolve_dependencies(meta) as deps:
input_data = {meta.positional_field_name: event}
await func(**input_data, **deps, **details)

await session.subscribe(topic, _handle_event, options=options)
await session.subscribe(topic, _handle_event, options=meta.subscribe_options)
print(f"Subscribed topic {topic}")
177 changes: 110 additions & 67 deletions xconn/_client/helpers.py
Original file line number Diff line number Diff line change
@@ -33,39 +33,46 @@
from xconn import Router, Server
from xconn._client.types import ClientConfig
from xconn.exception import ApplicationError
from xconn.types import Event, Invocation, Result, Depends, CallDetails, RegisterOptions
from xconn.types import Event, Invocation, Result, Depends, CallDetails, RegisterOptions, SubscribeOptions, EventDetails

MAX_WAIT = 300
INITIAL_WAIT = 1


@dataclass
class ProcedureMetadata:
class BaseMetadata:
request_model: Type[BaseModel] | None
response_model: Type[BaseModel] | None

request_args: list[str]
response_args: list[str]

request_kwargs: list[str]
response_kwargs: list[str]

no_args: bool
dynamic_model: bool

allowed_roles: list[str]

dependencies: dict[str, Callable]
ctx_dependencies: dict[str, ContextManager]
async_dependencies: dict[str, Awaitable]
async_ctx_dependencies: dict[str, AsyncContextManager]

call_details_field: str | None
details_field: str | None
positional_field_name: str | None


@dataclass
class ProcedureMetadata(BaseMetadata):
response_model: Type[BaseModel] | None
response_args: list[str]
response_kwargs: list[str]

allowed_roles: list[str]

register_options: dict[str, Any] | RegisterOptions | None


@dataclass
class EventMetadata(BaseMetadata):
subscribe_options: dict[str, Any] | SubscribeOptions | None


def create_model_from_func(func):
signature = inspect.signature(func)
type_hints = get_type_hints(func)
@@ -97,14 +104,29 @@ def is_subclass_of_any(type_, base_class: Any) -> bool:
return isinstance(type_, type) and issubclass(type_, base_class)


def _validate_procedure_function(func: callable, uri: str) -> ProcedureMetadata:
def _do_it(
func: callable,
uri: str,
incoming_class: type(Invocation) | type(Event),
details_class: type(CallDetails) | type(EventDetails),
) -> BaseMetadata:
sig = inspect.signature(func)
for name, param in sig.parameters.items():
if param.annotation is inspect._empty:
raise RuntimeError(f"Missing type hint for parameter: '{name}' in function '{func.__name__}'")

if issubclass(incoming_class, Event):
context = "topic"
elif issubclass(incoming_class, Invocation):
context = "procedure"
else:
raise RuntimeError(f"incoming_class can be with Invocation or Event got='{incoming_class.__name__}'")

if not issubclass(details_class, CallDetails) and not issubclass(details_class, EventDetails):
raise RuntimeError(f"details_class can be either CallDetails or EventDetails got='{details_class.__name__}'")

hints = get_type_hints(func)
hints.pop("return") if "return" in hints else None
hints.pop("return", None)

request_model = None
request_args = []
@@ -147,39 +169,43 @@ def _validate_procedure_function(func: callable, uri: str) -> ProcedureMetadata:
del hints[key]

# check if CallDetails are in the function
call_details_field = None
details_field = None
for name, type_ in hints.items():
if is_subclass_of_any(type_, CallDetails):
if call_details_field is not None:
raise RuntimeError(f"Duplicate call details in function '{func.__name__}'")
if is_subclass_of_any(type_, details_class):
if details_field is not None:
raise RuntimeError(f"Duplicate {details_class.__name__} in function '{func.__name__}'")

call_details_field = name
details_field = name

if call_details_field is not None:
del hints[call_details_field]
if details_field is not None:
del hints[details_field]

positional_field_name: str | None = None

has_invocation_in_sig = False
has_incoming_class_in_sig = False
for name, type_ in hints.items():
if issubclass(type_, BaseModel):
if len(hints) != 1:
raise RuntimeError(f"Cannot mix pydantic BaseModel with other types in signature of procedure '{uri}'")
raise RuntimeError(f"Cannot mix pydantic BaseModel with other types in signature of {context} '{uri}'")

request_model = type_
positional_field_name = name
break

if is_subclass_of_any(type_, Invocation):
if has_invocation_in_sig:
raise RuntimeError(f"Cannot use other types than 'Invocation' as arguments in procedure '{uri}'")
if is_subclass_of_any(type_, incoming_class):
if has_incoming_class_in_sig:
raise RuntimeError(
f"Cannot use other types than '{incoming_class.__name__}' as arguments in {context} '{uri}'"
)

has_invocation_in_sig = True
has_incoming_class_in_sig = True
positional_field_name = name

if Invocation in hints.values():
if len(hints) != 1:
raise RuntimeError(f"Cannot use other types than 'Invocation' as arguments in procedure '{uri}'")
raise RuntimeError(
f"Cannot use other types than '{incoming_class.__name__}' as arguments in {context} '{uri}'"
)
elif request_model is not None:
for key, value in request_model.model_fields.items():
if value.is_required:
@@ -199,6 +225,24 @@ def _validate_procedure_function(func: callable, uri: str) -> ProcedureMetadata:

dynamic_model = True

return BaseMetadata(
request_model=request_model,
request_args=request_args,
request_kwargs=request_kwargs,
no_args=no_args,
dynamic_model=dynamic_model,
dependencies=dependencies,
ctx_dependencies=ctx_dependencies,
async_dependencies=async_dependencies,
async_ctx_dependencies=async_ctx_dependencies,
details_field=details_field,
positional_field_name=positional_field_name,
)


def _validate_procedure_function(func: callable, uri: str) -> ProcedureMetadata:
meta = _do_it(func, uri, Invocation, CallDetails)

response_model = getattr(func, "__xconn_response_model__", None)
response_args = []
response_kwargs = []
@@ -213,21 +257,21 @@ def _validate_procedure_function(func: callable, uri: str) -> ProcedureMetadata:
register_options = getattr(func, "__xconn_register_options__", None)

return ProcedureMetadata(
request_model=request_model,
request_model=meta.request_model,
response_model=response_model,
request_args=request_args,
request_args=meta.request_args,
response_args=response_args,
request_kwargs=request_kwargs,
request_kwargs=meta.request_kwargs,
response_kwargs=response_kwargs,
no_args=no_args,
dynamic_model=dynamic_model,
no_args=meta.no_args,
dynamic_model=meta.dynamic_model,
allowed_roles=allowed_roles,
dependencies=dependencies,
ctx_dependencies=ctx_dependencies,
async_dependencies=async_dependencies,
async_ctx_dependencies=async_ctx_dependencies,
call_details_field=call_details_field,
positional_field_name=positional_field_name,
dependencies=meta.dependencies,
ctx_dependencies=meta.ctx_dependencies,
async_dependencies=meta.async_dependencies,
async_ctx_dependencies=meta.async_ctx_dependencies,
details_field=meta.details_field,
positional_field_name=meta.positional_field_name,
register_options=register_options,
)

@@ -306,36 +350,23 @@ def _handle_result(


def _validate_topic_function(func: callable, uri: str):
sig = inspect.signature(func)
for name, param in sig.parameters.items():
if param.annotation is inspect._empty:
raise RuntimeError(f"Missing type hint for parameter: '{name}' in function '{func.__name__}'")

hints = get_type_hints(func)
hints.pop("return") if "return" in hints else None

if Event in hints.values():
if len(hints) != 1:
raise RuntimeError(f"Cannot use other types than 'Event' as arguments in subscription '{uri}'")

pydantic_model = None
positional_args = []
for type_ in hints.values():
if issubclass(type_, BaseModel):
if len(hints) != 1:
raise RuntimeError(
f"Cannot mix pydantic dataclass with other types in signature of subscription '{uri}'"
)

pydantic_model = type_

for key, value in pydantic_model.model_fields.items():
if value.is_required:
positional_args.append(key)

meta = _do_it(func, uri, Event, EventDetails)
options = getattr(func, "__xconn_subscribe_options__", None)

return pydantic_model, positional_args, options
return EventMetadata(
request_model=meta.request_model,
request_args=meta.request_args,
request_kwargs=meta.request_kwargs,
no_args=meta.no_args,
dynamic_model=meta.dynamic_model,
dependencies=meta.dependencies,
ctx_dependencies=meta.ctx_dependencies,
async_dependencies=meta.async_dependencies,
async_ctx_dependencies=meta.async_ctx_dependencies,
details_field=meta.details_field,
positional_field_name=meta.positional_field_name,
subscribe_options=options,
)


def _sanitize_incoming_data(args: list, kwargs: dict, model_positional_args: list[str]):
@@ -466,11 +497,23 @@ def ensure_caller_allowed(call_details: dict[str, Any], allowed_roles: list[str]

def assemble_call_details(uri: str, meta: ProcedureMetadata, invocation: Invocation):
details = {}
if meta.call_details_field is not None:
if meta.details_field is not None:
if not invocation.details:
msg = f"Endpoint for procedure {uri} expects CallDetails but router did not send them"
raise ApplicationError("wamp.error.internal_error", msg)

details[meta.call_details_field] = CallDetails(invocation.details)
details[meta.details_field] = CallDetails(invocation.details)

return details


def assemble_event_details(uri: str, meta: EventMetadata, event: Event):
details = {}
if meta.details_field is not None:
if not event.details:
msg = f"Endpoint for topic {uri} expects EventDetails but router did not send them"
raise ApplicationError("wamp.error.internal_error", msg)

details[meta.details_field] = CallDetails(event.details)

return details
37 changes: 24 additions & 13 deletions xconn/_client/sync.py
Original file line number Diff line number Diff line change
@@ -22,6 +22,7 @@
MAX_WAIT,
ProcedureMetadata,
assemble_call_details,
assemble_event_details,
)
from xconn._client.types import ClientConfig
from xconn.client import Client
@@ -152,23 +153,33 @@ def subscribe_sync(session: Session, topic: str, func: callable):
if inspect.iscoroutinefunction(func):
raise RuntimeError(f"function {func.__name__} for topic '{topic}' must not be a coroutine")

model, positional_args, options = _validate_topic_function(func, topic)
meta = _validate_topic_function(func, topic)

def _handle_event(event: Event):
if model is not None:
kwargs = _sanitize_incoming_data(event.args, event.kwargs, positional_args)
details = assemble_event_details(topic, meta, event)

try:
func(model(**kwargs))
except Exception as e:
print(e)
if meta.dynamic_model:
kwargs = _sanitize_incoming_data(event.args, event.kwargs, meta.request_args)
handle_model_validation(meta.request_model, **kwargs)

return
with resolve_dependencies(meta) as deps:
func(**kwargs, **deps, **details)

try:
func(event)
except Exception as e:
print(e)
elif meta.request_model is not None:
kwargs = _sanitize_incoming_data(event.args, event.kwargs, meta.request_args)
model = handle_model_validation(meta.request_model, **kwargs)

with resolve_dependencies(meta) as deps:
input_data = {meta.positional_field_name: model}
func(**input_data, **deps, **details)

elif meta.no_args:
with resolve_dependencies(meta) as deps:
func(**deps, **details)
else:
with resolve_dependencies(meta) as deps:
input_data = {meta.positional_field_name: event}
func(**input_data, **deps, **details)

session.subscribe(topic, _handle_event, options=options)
session.subscribe(topic, _handle_event, options=meta.subscribe_options)
print(f"Subscribed topic {topic}")
5 changes: 4 additions & 1 deletion xconn/async_session.py
Original file line number Diff line number Diff line change
@@ -146,7 +146,10 @@ async def process_incoming_message(self, msg: messages.Message):
request.set_result(None)
elif isinstance(msg, messages.Event):
endpoint = self.subscriptions[msg.subscription_id]
await endpoint(types.Event(msg.args, msg.kwargs, msg.details))
try:
await endpoint(types.Event(msg.args, msg.kwargs, msg.details))
except Exception as e:
print(e)
elif isinstance(msg, messages.Error):
match msg.message_type:
case messages.Call.TYPE:
14 changes: 13 additions & 1 deletion xconn/types.py
Original file line number Diff line number Diff line change
@@ -514,10 +514,22 @@ def authrole(self) -> str | None:
return self.get("caller_authrole")


class PublicationDetails(_IncomingDetails):
class EventDetails(_IncomingDetails):
def __init__(self, details: dict | None = None):
super().__init__(details)

@property
def session_id(self) -> int | None:
return self.get("publisher")

@property
def authid(self) -> str | None:
return self.get("publisher_authid")

@property
def authrole(self) -> str | None:
return self.get("publisher_authrole")


class InvokeOptions(Enum):
SINGLE = "single"