|
| 1 | +import inspect |
| 2 | +import types |
| 3 | +from dataclasses import dataclass, field |
| 4 | +from enum import Enum |
| 5 | +from functools import wraps |
| 6 | +from typing import ( |
| 7 | + Any, |
| 8 | + Awaitable, |
| 9 | + Callable, |
| 10 | + Generic, |
| 11 | + Optional, |
| 12 | + Protocol, |
| 13 | + Type, |
| 14 | + TypeVar, |
| 15 | + Union, |
| 16 | + get_origin, |
| 17 | +) |
| 18 | + |
| 19 | +from typing_extensions import ParamSpec |
| 20 | + |
| 21 | +import nexusrpc.interface |
| 22 | + |
| 23 | +# TODO(dan) |
| 24 | +_ServiceImpl = Any |
| 25 | +_Params = ParamSpec("_Params") |
| 26 | +I = TypeVar("I", contravariant=True) |
| 27 | +O = TypeVar("O", covariant=True) |
| 28 | +S = TypeVar("S", bound=_ServiceImpl) |
| 29 | +ST = TypeVar("ST", bound=Type[_ServiceImpl]) |
| 30 | + |
| 31 | + |
| 32 | +@dataclass |
| 33 | +class Link: |
| 34 | + """ |
| 35 | + Link contains a URL and a Type that can be used to decode the URL. |
| 36 | + Links can contain any arbitrary information as a percent-encoded URL. |
| 37 | + It can be used to pass information about the caller to the handler, or vice versa. |
| 38 | + """ |
| 39 | + |
| 40 | + # The URL must be percent-encoded. |
| 41 | + url: str |
| 42 | + # Can describe an actual data type for decoding the URL. Valid chars: alphanumeric, '_', '.', |
| 43 | + # '/' |
| 44 | + type: str |
| 45 | + |
| 46 | + |
| 47 | +@dataclass |
| 48 | +class StartOperationOptions: |
| 49 | + """Options passed by the Nexus caller when starting an operation.""" |
| 50 | + |
| 51 | + headers: dict[str, str] = field(default_factory=dict) |
| 52 | + |
| 53 | + # A callback URL is required to deliver the completion of an async operation. This URL should be |
| 54 | + # called by a handler upon completion if the started operation is async. |
| 55 | + callback_url: Optional[str] = None |
| 56 | + |
| 57 | + # Optional header fields set by the caller to be attached to the callback request when an |
| 58 | + # asynchronous operation completes. |
| 59 | + callback_header: dict[str, str] = field(default_factory=dict) |
| 60 | + |
| 61 | + # Request ID that may be used by the server handler to dedupe a start request. |
| 62 | + # By default a v4 UUID will be generated by the client. |
| 63 | + request_id: Optional[str] = None |
| 64 | + |
| 65 | + # Links contain arbitrary caller information. Handlers may use these links as |
| 66 | + # metadata on resources associated with an operation. |
| 67 | + links: list[Link] = field(default_factory=list) |
| 68 | + |
| 69 | + |
| 70 | +@dataclass |
| 71 | +class CancelOperationOptions: |
| 72 | + """Options passed by the Nexus caller when cancelling an operation.""" |
| 73 | + |
| 74 | + headers: dict[str, str] = field(default_factory=dict) |
| 75 | + |
| 76 | + |
| 77 | +@dataclass |
| 78 | +class FetchOperationInfoOptions: |
| 79 | + """Options passed by the Nexus caller when fetching information about an operation.""" |
| 80 | + |
| 81 | + headers: dict[str, str] = field(default_factory=dict) |
| 82 | + |
| 83 | + |
| 84 | +@dataclass |
| 85 | +class FetchOperationResultOptions: |
| 86 | + """Options passed by the Nexus caller when fetching the result of an operation.""" |
| 87 | + |
| 88 | + headers: dict[str, str] = field(default_factory=dict) |
| 89 | + |
| 90 | + |
| 91 | +class OperationState(Enum): |
| 92 | + """ |
| 93 | + Describes the current state of an operation. |
| 94 | + """ |
| 95 | + |
| 96 | + SUCCEEDED = "succeeded" |
| 97 | + FAILED = "failed" |
| 98 | + CANCELED = "canceled" |
| 99 | + RUNNING = "running" |
| 100 | + |
| 101 | + |
| 102 | +@dataclass |
| 103 | +class OperationInfo: |
| 104 | + """ |
| 105 | + Information about an operation. |
| 106 | + """ |
| 107 | + |
| 108 | + # Token identifying the operation (returned on operation start). |
| 109 | + token: str |
| 110 | + |
| 111 | + # The operation's status. |
| 112 | + status: OperationState |
| 113 | + |
| 114 | + |
| 115 | +@dataclass |
| 116 | +class AsyncOperationResult: |
| 117 | + token: str |
| 118 | + |
| 119 | + |
| 120 | +# TODO(dan): TBD what we'll call this. It's public so should not have an |
| 121 | +# underscore prefix. |
| 122 | +@dataclass |
| 123 | +class _NexusServiceDefinition: |
| 124 | + name: str |
| 125 | + operation_factories: dict[str, Callable[[_ServiceImpl], "Operation[Any, Any]"]] |
| 126 | + |
| 127 | + # TODO(dan): TBD what the name of this look method should be. |
| 128 | + @classmethod |
| 129 | + def from_implementation( |
| 130 | + cls, implementation: Type[_ServiceImpl] |
| 131 | + ) -> "_NexusServiceDefinition": |
| 132 | + """ |
| 133 | + Retrieve the service definition that was set by the decorator on a service implementation class. |
| 134 | + """ |
| 135 | + if defn := getattr(implementation, "__nexus_service__", None): |
| 136 | + return defn |
| 137 | + raise ValueError( |
| 138 | + f"Service implementation {implementation} does not have a service definition. " |
| 139 | + f"Use the @nexusrpc.handler.service decorator on your class to define a Nexus service implementation." |
| 140 | + ) |
| 141 | + |
| 142 | + |
| 143 | +@dataclass |
| 144 | +class _NexusOperationDefinition: |
| 145 | + name: str |
| 146 | + |
| 147 | + |
| 148 | +@dataclass |
| 149 | +class service(Generic[ST]): |
| 150 | + """Decorator that marks a class as a Nexus service implementation. |
| 151 | +
|
| 152 | + Args: |
| 153 | + interface: The service interface that the service implements. |
| 154 | + name: The name of the service. If not provided, the class name will be used. |
| 155 | +
|
| 156 | + Example: |
| 157 | + ```python |
| 158 | + @nexusrpc.handler.service(MyServiceInterface, name="my-service") |
| 159 | + class MyService: |
| 160 | + ... |
| 161 | + ``` |
| 162 | + """ |
| 163 | + |
| 164 | + interface: Type[Any] |
| 165 | + name: Optional[str] = None |
| 166 | + |
| 167 | + def __call__(self, cls: ST) -> ST: |
| 168 | + cls.__nexus_service__ = _NexusServiceDefinition( |
| 169 | + # The name by which the service is addressed in Nexus requests is |
| 170 | + # the name of the interface, not that of the implementation class. |
| 171 | + name=self.interface.__name__, |
| 172 | + operation_factories=self._get_operation_factories(cls, self.interface), |
| 173 | + ) |
| 174 | + return cls |
| 175 | + |
| 176 | + @staticmethod |
| 177 | + def _get_operation_factories( |
| 178 | + service_cls: ST, interface: Type[Any] |
| 179 | + ) -> dict[str, Callable[[_ServiceImpl], "Operation[Any, Any]"]]: |
| 180 | + """ |
| 181 | + Get the operation factories from the class. |
| 182 | + """ |
| 183 | + # TODO(dan): gather all errors, check type parameters, type variance, etc |
| 184 | + interface_op_names = set() |
| 185 | + for name, _type in interface.__annotations__.items(): |
| 186 | + # TODO(dan): use get_args to check type parameters |
| 187 | + if get_origin(_type) == nexusrpc.interface.NexusOperation: |
| 188 | + interface_op_names.add(name) |
| 189 | + |
| 190 | + op_factories = { |
| 191 | + name: method |
| 192 | + for name, method in inspect.getmembers(service_cls, inspect.isfunction) |
| 193 | + if hasattr(method, "__nexus_operation__") and name in interface_op_names |
| 194 | + } |
| 195 | + if len(op_factories) < len(interface_op_names): |
| 196 | + raise ValueError( |
| 197 | + f"Service {service_cls} does not implement all operations in interface {interface}. " |
| 198 | + f"Missing operations: {interface_op_names - op_factories.keys()}" |
| 199 | + ) |
| 200 | + return op_factories |
| 201 | + |
| 202 | + |
| 203 | +class Operation(Protocol, Generic[I, O]): |
| 204 | + """ |
| 205 | + Interface that must be implemented by an operation in a Nexus service implementation. |
| 206 | + """ |
| 207 | + |
| 208 | + # start either returns the result synchronously, or returns an operation token. Which path is |
| 209 | + # taken may be decided at operation handling time. |
| 210 | + async def start( |
| 211 | + self, input: I, options: StartOperationOptions |
| 212 | + ) -> Union[O, AsyncOperationResult]: ... |
| 213 | + |
| 214 | + async def fetch_info( |
| 215 | + self, token: str, options: FetchOperationInfoOptions |
| 216 | + ) -> OperationInfo: ... |
| 217 | + |
| 218 | + async def fetch_result( |
| 219 | + self, token: str, options: FetchOperationResultOptions |
| 220 | + ) -> O: ... |
| 221 | + |
| 222 | + async def cancel(self, token: str, options: CancelOperationOptions) -> None: ... |
| 223 | + |
| 224 | + |
| 225 | +_OpFactory = Callable[[_ServiceImpl], Operation[Any, Any]] |
| 226 | +F = TypeVar("F", bound=_OpFactory) |
| 227 | + |
| 228 | + |
| 229 | +# TODO(dan): This is following workflow.defn but check that invalid decorator |
| 230 | +# usage is prevented by this implementation style. |
| 231 | +def operation( |
| 232 | + method: Optional[F] = None, |
| 233 | + *, |
| 234 | + name: Optional[str] = None, |
| 235 | +) -> Union[F, Callable[[F], F]]: |
| 236 | + """ |
| 237 | + Decorator that marks a method as an operation in a Nexus service implementation. |
| 238 | +
|
| 239 | + Args: |
| 240 | + method: The method to decorate. |
| 241 | + name: The name of the operation. If not provided, the method name will be used. |
| 242 | +
|
| 243 | + Examples: |
| 244 | + ``` |
| 245 | + @nexusrpc.handler.operation |
| 246 | + def my_operation(self) -> Operation[MyInput, MyOutput]: |
| 247 | + ... |
| 248 | + ``` |
| 249 | +
|
| 250 | + ``` |
| 251 | + @nexusrpc.handler.operation(name="my-operation") |
| 252 | + def my_operation(self) -> Operation[MyInput, MyOutput]: |
| 253 | + ... |
| 254 | + ``` |
| 255 | + """ |
| 256 | + |
| 257 | + def decorator(method: F) -> F: |
| 258 | + method.__nexus_operation__ = _NexusOperationDefinition( |
| 259 | + name=name or method.__name__ |
| 260 | + ) |
| 261 | + return method |
| 262 | + |
| 263 | + if method is None: |
| 264 | + return decorator |
| 265 | + |
| 266 | + return decorator(method) |
| 267 | + |
| 268 | + |
| 269 | +# abc? require start? |
| 270 | +class AbstractOperation(Operation[I, O]): |
| 271 | + async def start(self, input: I, options: StartOperationOptions) -> O: |
| 272 | + raise NotImplementedError |
| 273 | + |
| 274 | + async def fetch_info( |
| 275 | + self, token: str, options: FetchOperationInfoOptions |
| 276 | + ) -> OperationInfo: |
| 277 | + raise NotImplementedError |
| 278 | + |
| 279 | + async def fetch_result(self, token: str, options: FetchOperationResultOptions) -> O: |
| 280 | + raise NotImplementedError |
| 281 | + |
| 282 | + async def cancel(self, token: str, options: CancelOperationOptions) -> None: |
| 283 | + raise NotImplementedError |
| 284 | + |
| 285 | + |
| 286 | +# TODO(dan): support overriding op name |
| 287 | +def sync_operation( |
| 288 | + start_method: Callable[[S, I, StartOperationOptions], Awaitable[O]], |
| 289 | +) -> Callable[[S], AbstractOperation[I, O]]: |
| 290 | + def factory(service: S) -> AbstractOperation[I, O]: |
| 291 | + # A start method defined in this way was written by the user as a method |
| 292 | + # on a service. We convert it into a method on an operation, but the |
| 293 | + # method will not access the operation. |
| 294 | + @wraps(start_method) |
| 295 | + async def start(_, input: I, options: StartOperationOptions) -> O: |
| 296 | + return await start_method(service, input, options) |
| 297 | + |
| 298 | + op = AbstractOperation() |
| 299 | + op.start = types.MethodType(start, op) |
| 300 | + return op |
| 301 | + |
| 302 | + factory.__nexus_operation__ = _NexusOperationDefinition(name=start_method.__name__) |
| 303 | + |
| 304 | + return factory |
0 commit comments