Skip to content

Commit

Permalink
feat: replace FromContext with FromRequestExtensions
Browse files Browse the repository at this point in the history
  • Loading branch information
andogq committed Jun 18, 2024
1 parent eb9d09b commit 7274cb0
Show file tree
Hide file tree
Showing 6 changed files with 54 additions and 36 deletions.
6 changes: 6 additions & 0 deletions .changes/from-request-extensions.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
---
"qubit": minor:feat
"qubit-macros": minor:feat
---

**BREAKING** replace `FromContext` with `FromRequestExtensions` to build ctx from request information (via tower middleware)
2 changes: 1 addition & 1 deletion crates/qubit-macros/src/handler/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -297,7 +297,7 @@ impl From<Handler> for TokenStream {
#visibility struct #name;
impl<#inner_ctx_ty> ::qubit::Handler<#inner_ctx_ty> for #name
where #inner_ctx_ty: 'static + ::std::marker::Send + ::std::marker::Sync + ::std::clone::Clone,
#ctx_ty: ::qubit::FromContext<#inner_ctx_ty>,
#ctx_ty: ::qubit::FromRequestExtensions<#inner_ctx_ty>,
{
fn get_type() -> ::qubit::HandlerType {
::qubit::HandlerType {
Expand Down
9 changes: 6 additions & 3 deletions examples/authentication/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ use axum::{
};
use cookie::Cookie;
use hyper::{header::SET_COOKIE, StatusCode};
use qubit::{handler, ErrorCode, FromContext, Router, RpcError};
use qubit::{handler, ErrorCode, Extensions, FromRequestExtensions, Router, RpcError};
use serde::Deserialize;
use tokio::net::TcpListener;

Expand Down Expand Up @@ -59,10 +59,13 @@ struct AuthCtx {
user: String,
}

impl FromContext<ReqCtx> for AuthCtx {
impl FromRequestExtensions<ReqCtx> for AuthCtx {
/// Implementation to generate the [`AuthCtx`] from the [`ReqCtx`]. Is falliable, so requests
/// can be blocked at this point.
async fn from_app_ctx(ctx: ReqCtx) -> Result<Self, qubit::RpcError> {
async fn from_request_extensions(
ctx: ReqCtx,
_extensions: Extensions,
) -> Result<Self, qubit::RpcError> {
// Enforce that the auth cookie is present
let Some(cookie) = ctx.auth_cookie else {
// Return an error to cancel the request if it's not
Expand Down
14 changes: 10 additions & 4 deletions examples/chaos/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -58,8 +58,11 @@ mod user {
user: u32,
}

impl FromContext<AppCtx> for UserCtx {
async fn from_app_ctx(ctx: AppCtx) -> Result<Self, RpcError> {
impl FromRequestExtensions<AppCtx> for UserCtx {
async fn from_request_extensions(
ctx: AppCtx,
_extensions: Extensions,
) -> Result<Self, RpcError> {
Ok(UserCtx {
app_ctx: ctx,
user: 0,
Expand Down Expand Up @@ -128,8 +131,11 @@ struct CountCtx {
count: Arc<AtomicUsize>,
}

impl FromContext<AppCtx> for CountCtx {
async fn from_app_ctx(ctx: AppCtx) -> Result<Self, RpcError> {
impl FromRequestExtensions<AppCtx> for CountCtx {
async fn from_request_extensions(
ctx: AppCtx,
_extensions: Extensions,
) -> Result<Self, RpcError> {
Ok(Self {
count: ctx.count.clone(),
})
Expand Down
39 changes: 21 additions & 18 deletions src/builder/rpc_builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ use jsonrpsee::{
use serde::Serialize;
use serde_json::json;

use crate::FromContext;
use crate::FromRequestExtensions;

/// Builder to construct the RPC module. Handlers can be registered using the [`RpcBuilder::query`]
/// and [`RpcBuilder::subscription`] methods. It tracks an internally mutable [`RpcModule`] and
Expand Down Expand Up @@ -43,31 +43,32 @@ where

/// Register a new query handler with the provided name.
///
/// The `handler` can take its own `Ctx`, so long as it implements [`FromContext`]. It must
/// return a future which outputs a serializable value.
/// The `handler` can take its own `Ctx`, so long as it implements [`FromRequestExtensions`]. It
/// must return a future which outputs a serializable value.
pub fn query<T, C, F, Fut>(mut self, name: &'static str, handler: F) -> Self
where
T: Serialize + Clone + 'static,
C: FromContext<Ctx>,
C: FromRequestExtensions<Ctx>,
F: Fn(C, Params<'static>) -> Fut + Send + Sync + Clone + 'static,
Fut: Future<Output = T> + Send + 'static,
{
self.module
.register_async_method(self.namespace_str(name), move |params, ctx, _extensions| {
.register_async_method(self.namespace_str(name), move |params, ctx, extensions| {
// NOTE: Handler has to be cloned in since `register_async_method` takes `Fn`, not
// `FnOnce`. Not sure if it's better to be an `Rc`/leaked/???
let handler = handler.clone();

async move {
// Build the context
let ctx = match C::from_app_ctx(ctx.deref().clone()).await {
Ok(ctx) => ctx,
Err(e) => {
// Handle any error building the context by turning it into a response
// payload.
return ResponsePayload::Error(e.into());
}
};
let ctx =
match C::from_request_extensions(ctx.deref().clone(), extensions).await {
Ok(ctx) => ctx,
Err(e) => {
// Handle any error building the context by turning it into a response
// payload.
return ResponsePayload::Error(e.into());
}
};

// Run the actual handler
ResponsePayload::success(handler(ctx, params).await)
Expand All @@ -80,8 +81,8 @@ where

/// Register a new subscription handler with the provided name.
///
/// The `handler` can take its own `Ctx`, so long as it implements [`FromContext`]. It must
/// return a future that outputs a stream of serializable values.
/// The `handler` can take its own `Ctx`, so long as it implements [`FromRequestExtensions`]. It
/// must return a future that outputs a stream of serializable values.
pub fn subscription<T, C, F, Fut, S>(
mut self,
name: &'static str,
Expand All @@ -91,7 +92,7 @@ where
) -> Self
where
T: Serialize + Send + Clone + 'static,
C: FromContext<Ctx>,
C: FromRequestExtensions<Ctx>,
F: Fn(C, Params<'static>) -> Fut + Send + Sync + Clone + 'static,
Fut: Future<Output = S> + Send + 'static,
S: Stream<Item = T> + Send + 'static,
Expand All @@ -101,7 +102,7 @@ where
self.namespace_str(name),
self.namespace_str(notification_name),
self.namespace_str(unsubscribe_name),
move |params, subscription, ctx, _extensions| {
move |params, subscription, ctx, extensions| {
// NOTE: Same deal here with cloning the handler as in the query registration.
let handler = handler.clone();

Expand Down Expand Up @@ -135,7 +136,9 @@ where
// Build the context
// NOTE: It won't be held across await so that `C` doesn't have to be
// `Send`
let ctx = match C::from_app_ctx(ctx.deref().clone()).await {
let ctx = match C::from_request_extensions(ctx.deref().clone(), extensions)
.await
{
Ok(ctx) => ctx,
Err(e) => {
// Handle any error building the context by turning it into a
Expand Down
20 changes: 10 additions & 10 deletions src/server/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,23 +2,23 @@ mod error;
mod router;

pub use error::*;
pub use http::Extensions;
pub use router::{Router, ServerHandle};

/// Router context variation that can derived from `Ctx`.
#[trait_variant::make(FromContext: Send)]
pub trait LocalFromContext<Ctx>
/// Context can be built from request information by implementing the following trait. The
/// extensions are passed in from the request (see [`Extensions`]), which can be added using tower
/// middleware.
#[trait_variant::make(Send)]
pub trait FromRequestExtensions<Ctx>
where
Self: Sized,
{
/// Create a new instance from the provided context.
///
/// This is falliable, so any errors must produce a [`RpcError`], which will be returned to the
/// client.
async fn from_app_ctx(ctx: Ctx) -> Result<Self, RpcError>;
/// Using the provided context and extensions, build a new extension.
async fn from_request_extensions(ctx: Ctx, extensions: Extensions) -> Result<Self, RpcError>;
}

impl<Ctx: Send> FromContext<Ctx> for Ctx {
async fn from_app_ctx(ctx: Ctx) -> Result<Self, RpcError> {
impl<Ctx: Send> FromRequestExtensions<Ctx> for Ctx {
async fn from_request_extensions(ctx: Ctx, _extensions: Extensions) -> Result<Self, RpcError> {
Ok(ctx)
}
}

0 comments on commit 7274cb0

Please sign in to comment.