Replies: 3 comments 2 replies
-
I myself, as a user of the Reflex ecosystem, haven't dived deep into this yet. But you need to know who made the request to your endpoints so that you can get their specific state. State isn't global per se; it is bound to Reflex browser sessions. Said differently, there are multiple state "scopes", one per user. A quick and potentially dirty workaround is to copy the states that you want to access from your endpoints (they are endpoints, not a "middleware" as you call them!) into global variables in your module. For example, that could be a dictionary mapping each user ID to that user's state. Then, in your endpoints, pass the user ID as a request parameter to the route, and use that user ID to look up the global state. Do you see the strategy?
Beta Was this translation helpful? Give feedback.
-
Understood :) Thank you!
Beta Was this translation helpful? Give feedback.
-
The key here is that you need to get the state instance associated with the user session that is making the request. If you're triggering the request from other Reflex code, you can take the client token from the frontend's `getToken()` helper and send it along in a request header. The following code implements a FastAPI dependency that reads that header and loads the matching state: from collections.abc import AsyncGenerator
import datetime
from typing import Annotated, Any, Literal
import uuid
from fastapi import Depends, FastAPI, HTTPException, Request
import reflex as rx
from reflex.event import EventChainVar, EventHandler, EventSpec, passthrough_event_spec
class PromiseVar(rx.Var):
    """A Var representing a JavaScript Promise.

    Chaining is done on the JavaScript *source text*: ``then``/``catch`` append
    ``.then(<callback>)`` / ``.catch(<callback>)`` to this Var's expression.
    """

    def _chain(
        self, js_expr_chain: Literal["then", "catch"], callback: rx.Var
    ) -> "PromiseVar":
        """Append ``.<js_expr_chain>(<callback>)`` to this promise expression.

        The suffix is concatenated textually. If this Var's expression is an
        arrow function such as ``(resp) => resp.json()``, the call therefore
        lands inside the arrow's body (``(resp) => resp.json().then(...)``),
        and the arrow's parameters stay in scope for the callback.

        Args:
            js_expr_chain: The promise method to call (``"then"`` or ``"catch"``).
            callback: The callback to chain. Non-Var values are converted with
                ``rx.Var.create``.

        Returns:
            A new PromiseVar whose expression includes the chained call and
            whose var data is merged with the callback's var data.
        """
        callback_var = rx.Var.create(callback)
        return self._replace(
            # Render the converted Var (not the raw argument) so the emitted JS
            # always corresponds to the var data merged below.
            _js_expr=f"{self!s}.{js_expr_chain}({callback_var!s})",
            merge_var_data=callback_var._get_all_var_data(),
        )

    def then(self, callback: rx.Var) -> "PromiseVar":
        """Chain a success callback to the promise.

        Args:
            callback: The callback to chain.

        Returns:
            A new PromiseVar with the callback chained via ``.then``.
        """
        return self._chain("then", callback)

    def catch(self, callback: rx.Var) -> "PromiseVar":
        """Chain a callback to handle errors.

        Args:
            callback: The callback to handle errors.

        Returns:
            A new PromiseVar with the error handler chained via ``.catch``.
        """
        return self._chain("catch", callback)
class ArgsFunctionOperationPromise(rx.vars.function.ArgsFunctionOperation):
    """Function operation whose call result is wrapped as a ``PromiseVar``.

    This lets Python code chain ``.then`` / ``.catch`` onto the result of
    calling a JavaScript function that returns a promise.
    """

    def __call__(self, *args, **kwargs) -> PromiseVar:
        """Invoke the function and re-wrap the resulting Var as a PromiseVar.

        Args:
            *args: Positional arguments forwarded to the function call.
            **kwargs: Keyword arguments forwarded to the function call.

        Returns:
            A PromiseVar with the same JS expression and var data as the
            plain call result.
        """
        invoked = super().__call__(*args, **kwargs)
        js_expression = invoked._js_expr
        collected_data = invoked._get_all_var_data()
        return PromiseVar(_js_expr=js_expression, _var_data=collected_data)
def get_backend_url(relative_url: str | rx.Var[str]) -> rx.Var[str]:
    """Turn a backend-relative path into an absolute backend URL Var.

    The generated JavaScript takes the origin of ``getBackendURL(env.UPLOAD)``
    and appends ``relative_url`` with any leading slashes stripped. Intended to
    be passed to ``fetch`` for calling backend API endpoints.

    Args:
        relative_url: Path (or string Var) relative to the backend origin.

    Returns:
        A string Var holding the absolute backend URL.
    """
    url_builder = rx.vars.function.ArgsFunctionOperation.create(
        args_names=("url",),
        return_expr=rx.Var(
            r"`${getBackendURL(env.UPLOAD).origin}/${url.replace(/^\/+/, '')}`"
        ),
        _var_data=rx.vars.VarData(
            imports={
                "$/utils/state": ["getBackendURL"],
                "$/env.json": rx.ImportVar(tag="env", is_default=True),
            }
        ),
    )
    absolute_url = url_builder(relative_url)
    return absolute_url.to(str)
def fetch(
    url: str | rx.Var[str],
    options: dict[str, Any] | rx.Var[dict[str, Any]] | None = None,
) -> PromiseVar:
    """Build a JavaScript ``fetch`` call that carries the Reflex client token.

    The ``X-Reflex-Client-Token`` header is set from the frontend's
    ``getToken()``, so backend endpoints can look up the caller's session
    state. Any ``headers`` in ``options`` are kept. The token header is added
    last, so it wins on a name clash.

    Args:
        url: The URL to fetch.
        options: The fetch options (a dict or a dict Var). ``None`` means no
            extra options.

    Returns:
        A PromiseVar representing the eventual result of the fetch.
    """
    # Compare against None explicitly instead of `options or {}`: Reflex Vars
    # cannot be used in a Python boolean context, and `options` may be a Var.
    fetch_options = {} if options is None else options
    return ArgsFunctionOperationPromise.create(
        args_names=("url", "options"),
        return_expr=rx.Var(
            "fetch(url, {...options, headers: {...options.headers, 'X-Reflex-Client-Token': getToken()}})"
        ),
        _var_data=rx.vars.VarData(
            imports={
                "$/utils/state": [
                    "getToken",
                ],
            }
        ),
    )(url, fetch_options)
class State(rx.State):
    """The main app state."""

    # Text bound to the input box on the index page; returned by /api/value.
    value: str = "initial"
    # Set by the /api/value endpoint; None until the first API request.
    last_request_time: datetime.datetime | None = None
class FetchResultState(rx.State):
    """State to hold the result of the fetch operation."""

    # Decoded JSON body of the last API response (set via set_response).
    # Explicit empty default: the page renders this before any fetch happens.
    response: dict[str, str] = {}
    # Headers of the last API response (set via set_headers).
    headers: dict[str, str] = {}
# Custom FastAPI application holding the API routes. It is handed to Reflex
# through `rx.App(api_transformer=...)` further down in this module.
fastapi_app = FastAPI(
    title="My API",
)
async def reflex_state(request: Request) -> AsyncGenerator[State, None]:
    """FastAPI dependency that loads the calling session's ``State``.

    The frontend sends its client token in the ``X-Reflex-Client-Token``
    header (see ``fetch`` above). The token selects which session's state
    instance to load.

    Args:
        request: The incoming FastAPI request.

    Yields:
        The ``State`` instance associated with the request's client token.
        It is obtained inside ``app.modify_state``, so changes made by the
        endpoint are saved when the dependency finishes.

    Raises:
        HTTPException: 401 if the token header is missing or empty.
    """
    token = request.headers.get("X-Reflex-Client-Token")
    if not token:
        raise HTTPException(
            status_code=401,
            detail="X-Reflex-Client-Token header is required",
        )
    # `app` is the module-level rx.App defined at the bottom of this file; it is
    # looked up at request time, so the forward reference is fine.
    async with app.modify_state(token) as root_state:
        # Fetch and yield the app's main state from the root state.
        yield await root_state.get_state(State)
        # NOTE: Any changes made to the state by the route endpoint will be saved,
        # but the delta will only be sent to the frontend IF THE REQUEST IS HANDLED
        # BY THE SAME INSTANCE THAT HAS THE WEBSOCKET! Big caveat that makes this
        # approach mostly useful for only reading the state.
# Add routes to the FastAPI app
@fastapi_app.get("/api/value")
async def get_current_value(
    state: Annotated[State, Depends(reflex_state)],
) -> dict[str, str]:
    """Return the calling session's current ``State.value``.

    Also records the request time on that session's state.

    Args:
        state: The caller's ``State``, injected by the ``reflex_state``
            dependency.

    Returns:
        JSON with the current ``value`` and a random ``txid`` (UUID4) that
        identifies this response.
    """
    # Store an aware UTC timestamp. A naive `now()` is server-local time with
    # no offset, and the browser would display it as if it were its own local
    # time.
    state.last_request_time = datetime.datetime.now(datetime.timezone.utc)
    # Read the value from the instance of the State, returning JSON.
    return {"value": state.value, "txid": str(uuid.uuid4())}
def fetch_with_json(
    endpoint: str,
    callbacks: list[EventHandler | EventSpec],
) -> EventChainVar:
    """Build an event chain that fetches ``endpoint`` and dispatches its JSON.

    When the fetch resolves, the response body is decoded as JSON and passed
    as the single argument to every event in ``callbacks``.

    Args:
        endpoint: Backend-relative path of the API endpoint.
        callbacks: Event handlers/specs to trigger with the decoded JSON.

    Returns:
        An EventChain Var suitable for an event trigger such as ``on_submit``.
    """
    response_promise = fetch(get_backend_url(endpoint))
    decode_json = PromiseVar("(resp) => resp.json()")
    dispatch_events = rx.Var.create(
        rx.EventChain.create(
            callbacks,
            args_spec=passthrough_event_spec(dict[str, str]),
        ),
    )
    # `.then` is appended to the arrow function's source text, which yields
    # `(resp) => resp.json().then(<events>)`. The events therefore run inside
    # the arrow's body, where `resp` (the raw Response) is still in scope for
    # callbacks that reference it.
    decode_and_dispatch = decode_json.then(dispatch_events)
    full_request = response_promise.then(decode_and_dispatch)
    return rx.vars.function.ArgsFunctionOperation.create(
        args_names=(),
        return_expr=full_request,
    ).to(rx.EventChain)
def index() -> rx.Component:
    """Render the demo page.

    The page shows a text input bound to ``State.value``, a button that calls
    ``/api/value``, and the last API response body and headers.
    """
    # Submitting the form calls the API. The JSON body goes to set_response;
    # set_headers builds its argument from `resp`, which fetch_with_json keeps
    # in scope for the callbacks.
    request_form = rx.form(
        rx.input(
            placeholder="Type something...",
            value=State.value,
            on_change=State.set_value,
        ),
        rx.button("Make API Request"),
        on_submit=fetch_with_json(
            endpoint="/api/value",
            callbacks=[
                FetchResultState.set_response,
                FetchResultState.set_headers(
                    rx.Var("Object.fromEntries(resp.headers.entries())")
                ),
            ],
        ),
    )
    # Section title, plus the last request time once one has been recorded.
    response_title_row = rx.hstack(
        rx.heading("API Response", size="6"),
        rx.spacer(),
        rx.cond(
            State.last_request_time,
            rx.moment(State.last_request_time, format="YYYY-MM-DD HH:mm:ss"),
        ),
        align="center",
    )
    return rx.container(
        rx.color_mode.button(position="top-right"),
        rx.vstack(
            rx.heading("Custom API Request Demo", size="9"),
            request_form,
            rx.divider(),
            response_title_row,
            rx.code_block(FetchResultState.response.to_string(), language="json"),
            rx.divider(),
            rx.heading("Response Headers", size="6"),
            rx.code_block(FetchResultState.headers.to_string(), language="json"),
        ),
        rx.logo(),
    )
# The Reflex app. The custom FastAPI app is passed as `api_transformer`, so its
# routes (e.g. /api/value) are reachable on the Reflex backend. `reflex_state`
# looks this name up to load per-session state.
app = rx.App(
    api_transformer=fastapi_app,
)
app.add_page(index) If you're making the request from outside the Reflex app, you need some way to make the session's client_token available to the service making the request. You can get the client token from the base state.
Beta Was this translation helpful? Give feedback.
Uh oh!
There was an error while loading. Please reload this page.
-
Hi!
I have added a FastAPI middleware to get some values:
When I call my API, the value is always "initial", even though it prints out "second", so the function is being called.
What am I missing? Is "State" the right state?
Beta Was this translation helpful? Give feedback.
All reactions