Skip to content

Commit

Permalink
feature: redis pooling
Browse files Browse the repository at this point in the history
  • Loading branch information
cdxker committed Mar 20, 2024
1 parent e7b667e commit ad96641
Show file tree
Hide file tree
Showing 24 changed files with 302 additions and 262 deletions.
47 changes: 46 additions & 1 deletion server/Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions server/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,7 @@ tokio-postgres = "0.7.10"
postgres-openssl = "0.5.0"
openssl = "0.10.64"
utoipa-swagger-ui = { version = "6.0.0", features = ["actix-web"] }
deadpool-redis = { version = "0.14.0", features = ["rt_tokio_1"] }


[build-dependencies]
Expand Down
20 changes: 13 additions & 7 deletions server/src/af_middleware/auth_middleware.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use crate::{
data::models::{DatasetAndOrgWithSubAndPlan, Pool, UserRole},
data::models::{DatasetAndOrgWithSubAndPlan, Pool, RedisPool, UserRole},
errors::ServiceError,
handlers::auth_handler::{LoggedUser, OrganizationRole},
operators::{
Expand Down Expand Up @@ -43,6 +43,9 @@ where
let transaction = sentry::start_transaction(tx_ctx);
let org_id_span = transaction.start_child("orgid", "Getting organization id");

let pool = req.app_data::<web::Data<Pool>>().unwrap().to_owned();
let redis_pool = req.app_data::<web::Data<RedisPool>>().unwrap().to_owned();

let org_id = match req.headers().get("TR-Organization") {
Some(org_header) => {
let orgid_result = org_header
Expand All @@ -55,8 +58,6 @@ where
.parse::<uuid::Uuid>();

if let Some(dataset_header) = req.headers().get("TR-Dataset") {
let pool = req.app_data::<web::Data<Pool>>().unwrap().to_owned();

let dataset_id = dataset_header
.to_str()
.map_err(|_| {
Expand All @@ -67,9 +68,12 @@ where
ServiceError::BadRequest("Dataset must be valid UUID".to_string())
})?;

let dataset = get_dataset_by_id_query(dataset_id, pool.clone()).await?;
let dataset =
get_dataset_by_id_query(dataset_id, redis_pool.clone(), pool.clone())
.await?;
let org_plan_sub = get_organization_by_key_query(
dataset.organization_id.into(),
redis_pool.clone(),
pool.clone(),
)
.await
Expand All @@ -95,6 +99,7 @@ where
})?
.to_string()
.into(),
redis_pool,
pool,
)
.await
Expand All @@ -110,8 +115,6 @@ where

None => match req.headers().get("TR-Dataset") {
Some(dataset_header) => {
let pool = req.app_data::<web::Data<Pool>>().unwrap().to_owned();

let dataset_id = dataset_header
.to_str()
.map_err(|_| {
Expand All @@ -122,9 +125,12 @@ where
ServiceError::BadRequest("Dataset must be valid UUID".to_string())
})?;

let dataset = get_dataset_by_id_query(dataset_id, pool.clone()).await?;
let dataset =
get_dataset_by_id_query(dataset_id, redis_pool.clone(), pool.clone())
.await?;
let org_plan_sub = get_organization_by_key_query(
dataset.organization_id.into(),
redis_pool.clone(),
pool.clone(),
)
.await
Expand Down
42 changes: 25 additions & 17 deletions server/src/bin/ingestion-microservice.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
use diesel_async::pooled_connection::{AsyncDieselConnectionManager, ManagerConfig};
use redis::AsyncCommands;
use sentry::{Hub, SentryFutureExt};
use tracing_subscriber::{prelude::*, EnvFilter, Layer};
use trieve_server::data::models::{self, Event, ServerDatasetConfiguration};
Expand Down Expand Up @@ -71,6 +70,7 @@ fn main() {
};

let database_url = get_env!("DATABASE_URL", "DATABASE_URL is not set");
let redis_url = get_env!("REDIS_URL", "REDIS_URL is not set");

let mut config = ManagerConfig::default();
config.custom_setup = Box::new(establish_connection);
Expand All @@ -87,26 +87,25 @@ fn main() {

let web_pool = actix_web::web::Data::new(pool.clone());

let redis_pool = deadpool_redis::Config::from_url(redis_url)
.create_pool(Some(deadpool_redis::Runtime::Tokio1))
.unwrap();
redis_pool.resize(30);
let web_redis_pool = actix_web::web::Data::new(redis_pool);

tokio::runtime::Builder::new_multi_thread()
.enable_all()
.build()
.unwrap()
.block_on(
async move {
let redis_url = get_env!("REDIS_URL", "REDIS_URL is not set");
let redis_client = redis::Client::open(redis_url).unwrap();
let redis_connection = redis_client
.get_multiplexed_tokio_connection()
.await
.unwrap();

let threads: Vec<_> = (0..thread_num)
.map(|i| {
let web_pool = web_pool.clone();
let redis_connection = redis_connection.clone();
tokio::spawn(async move {
ingestion_service(i, redis_connection, web_pool).await
})
let web_redis_pool = web_redis_pool.clone();
tokio::spawn(
async move { ingestion_service(i, web_redis_pool, web_pool).await },
)
})
.collect();

Expand All @@ -116,21 +115,30 @@ fn main() {
);
}

#[tracing::instrument(skip(web_pool, redis_connection))]
#[tracing::instrument(skip(web_pool, redis_pool))]
async fn ingestion_service(
thread: usize,
mut redis_connection: redis::aio::MultiplexedConnection,
redis_pool: actix_web::web::Data<models::RedisPool>,
web_pool: actix_web::web::Data<models::Pool>,
) {
log::info!("Starting ingestion service thread");
loop {
let payload_result = redis_connection
.brpop::<&str, Vec<String>>("ingestion", 0.0)
.await;
let mut redis_connection = redis_pool
.get()
.await
.expect("Failed to fetch from redis pool");

let payload_result: Result<Vec<String>, deadpool_redis::redis::RedisError> =
deadpool_redis::redis::cmd("brpop")
.arg("ingestion")
.arg(0.0)
.query_async(&mut redis_connection)
.await;

let payload = if let Ok(payload) = payload_result {
payload
} else {
log::error!("Unable to process {:?}", payload_result);
continue;
};

Expand Down
1 change: 1 addition & 0 deletions server/src/data/models.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ use utoipa::ToSchema;

// type alias to use in multiple places
pub type Pool = diesel_async::pooled_connection::deadpool::Pool<diesel_async::AsyncPgConnection>;
pub type RedisPool = deadpool_redis::Pool;

#[derive(Debug, Serialize, Deserialize, Queryable, Insertable, Selectable, Clone, ToSchema)]
#[schema(example = json!({
Expand Down
20 changes: 12 additions & 8 deletions server/src/handlers/auth_handler.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use crate::data::models::{Organization, StripePlan, UserRole};
use crate::data::models::{Organization, RedisPool, StripePlan, UserRole};
use crate::get_env;
use crate::operators::invitation_operator::check_inv_valid;
use crate::operators::organization_operator::{
Expand Down Expand Up @@ -162,13 +162,14 @@ pub async fn build_oidc_client() -> CoreClient {
)
}

#[tracing::instrument(skip(pool))]
#[tracing::instrument(skip(pool, redis_pool))]
pub async fn create_account(
email: String,
name: String,
user_id: uuid::Uuid,
organization_id: Option<uuid::Uuid>,
inv_code: Option<uuid::Uuid>,
redis_pool: web::Data<RedisPool>,
pool: web::Data<Pool>,
) -> Result<(User, Vec<UserOrganization>, Vec<Organization>), ServiceError> {
let (mut role, org) = match organization_id {
Expand All @@ -184,7 +185,7 @@ pub async fn create_account(
.replace(' ', "-");
(
UserRole::Owner,
create_organization_query(org_name.as_str(), pool.clone())
create_organization_query(org_name.as_str(), redis_pool.clone(), pool.clone())
.await
.map_err(|error| {
ServiceError::InternalServerError(error.message.to_string())
Expand All @@ -194,9 +195,10 @@ pub async fn create_account(
};
let org_id = org.id;

let org_plan_sub = get_organization_by_key_query(org_id.into(), pool.clone())
.await
.map_err(|error| ServiceError::InternalServerError(error.message.to_string()))?;
let org_plan_sub =
get_organization_by_key_query(org_id.into(), redis_pool.clone(), pool.clone())
.await
.map_err(|error| ServiceError::InternalServerError(error.message.to_string()))?;
let user_org_count_pool = pool.clone();
let user_org_count = get_user_org_count(org_id, user_org_count_pool)
.await
Expand Down Expand Up @@ -390,12 +392,13 @@ pub async fn login(
(status = 400, description = "Email or password empty or incorrect", body = ErrorResponseBody),
)
)]
#[tracing::instrument(skip(session, oidc_client, pool))]
#[tracing::instrument(skip(session, oidc_client, pool, redis_pool))]
pub async fn callback(
req: HttpRequest,
session: Session,
oidc_client: web::Data<CoreClient>,
pool: web::Data<Pool>,
redis_pool: web::Data<RedisPool>,
query: web::Query<OpCallback>,
) -> Result<HttpResponse, Error> {
let state: OpenIdConnectState = session
Expand Down Expand Up @@ -482,6 +485,7 @@ pub async fn callback(
user_id,
login_state.organization_id,
login_state.inv_code,
redis_pool.clone(),
pool.clone(),
)
.await?
Expand All @@ -508,7 +512,7 @@ pub async fn callback(
invitation.organization_id,
invitation.role.into(),
);
add_user_to_organization(None, None, user_org, pool).await?;
add_user_to_organization(None, None, user_org, redis_pool, pool).await?;
}
}

Expand Down
Loading

0 comments on commit ad96641

Please sign in to comment.