Skip to content

Commit

Permalink
cleanup: simplify middleware and remove redis caching
Browse files Browse the repository at this point in the history
  • Loading branch information
densumesh committed Apr 11, 2024
1 parent ad3c565 commit d82277b
Show file tree
Hide file tree
Showing 14 changed files with 224 additions and 591 deletions.
129 changes: 35 additions & 94 deletions server/src/af_middleware/auth_middleware.rs
Original file line number Diff line number Diff line change
@@ -1,12 +1,11 @@
use crate::{
data::models::{DatasetAndOrgWithSubAndPlan, Pool, RedisPool, UserRole},
data::models::{Pool, UserRole},
errors::ServiceError,
handlers::auth_handler::{LoggedUser, OrganizationRole},
operators::{
dataset_operator::get_dataset_by_id_query,
dataset_operator::get_dataset_and_organization_from_dataset_id_query,
organization_operator::{
get_arbitrary_org_owner_from_dataset_id, get_arbitrary_org_owner_from_org_id,
get_organization_by_key_query,
},
user_operator::get_user_from_api_key_query,
},
Expand Down Expand Up @@ -47,102 +46,46 @@ where
let org_id_span = transaction.start_child("orgid", "Getting organization id");

let pool = req.app_data::<web::Data<Pool>>().unwrap().to_owned();
let redis_pool = req.app_data::<web::Data<RedisPool>>().unwrap().to_owned();

let org_id = match req.headers().get("TR-Organization") {
Some(org_header) => {
let orgid_result = org_header
let (http_req, pl) = req.parts_mut();
let user = get_user(http_req, pl).await;

let org_id = match req.headers().get("TR-Dataset") {
Some(dataset_header) => {
let dataset_id = dataset_header
.to_str()
.map_err(|_| {
Into::<Error>::into(ServiceError::BadRequest(
"Could not convert Organization to str".to_string(),
))
ServiceError::BadRequest("Dataset must be valid string".to_string())
})?
.parse::<uuid::Uuid>();

if let Some(dataset_header) = req.headers().get("TR-Dataset") {
let dataset_id = dataset_header
.to_str()
.map_err(|_| {
ServiceError::BadRequest("Dataset must be valid string".to_string())
})?
.parse::<uuid::Uuid>()
.map_err(|_| {
ServiceError::BadRequest("Dataset must be valid UUID".to_string())
})?;

let dataset =
get_dataset_by_id_query(dataset_id, redis_pool.clone(), pool.clone())
.await?;
let org_plan_sub = get_organization_by_key_query(
dataset.organization_id.into(),
redis_pool.clone(),
pool.clone(),
)
.await?;
.parse::<uuid::Uuid>()
.map_err(|_| {
ServiceError::BadRequest("Dataset must be valid UUID".to_string())
})?;

let dataset_org_plan_sub =
DatasetAndOrgWithSubAndPlan::from_components(dataset, org_plan_sub);
let dataset_org_plan_sub = get_dataset_and_organization_from_dataset_id_query(
dataset_id,
pool.clone(),
)
.await?;

req.extensions_mut().insert(dataset_org_plan_sub.clone());
}
req.extensions_mut().insert(dataset_org_plan_sub.clone());

match orgid_result {
Ok(org_id) => org_id,
Err(_) => {
let pool = req.app_data::<web::Data<Pool>>().unwrap().to_owned();
let organization = get_organization_by_key_query(
org_header
.to_str()
.map_err(|_| {
Into::<Error>::into(ServiceError::InternalServerError(
"Could not convert Organization to str".to_string(),
))
})?
.to_string()
.into(),
redis_pool,
pool,
)
.await
.map_err(|_| {
Into::<Error>::into(ServiceError::InternalServerError(
"Could not get org id".into(),
))
})?;
organization.id
}
}
dataset_org_plan_sub.organization.organization.id
}

None => match req.headers().get("TR-Dataset") {
Some(dataset_header) => {
let dataset_id = dataset_header
.to_str()
.map_err(|_| {
ServiceError::BadRequest("Dataset must be valid string".to_string())
})?
.parse::<uuid::Uuid>()
.map_err(|_| {
ServiceError::BadRequest("Dataset must be valid UUID".to_string())
})?;

let dataset =
get_dataset_by_id_query(dataset_id, redis_pool.clone(), pool.clone())
.await?;
let org_plan_sub = get_organization_by_key_query(
dataset.organization_id.into(),
redis_pool.clone(),
pool.clone(),
)
.await?;
let dataset_org_plan_sub =
DatasetAndOrgWithSubAndPlan::from_components(dataset, org_plan_sub);

req.extensions_mut().insert(dataset_org_plan_sub.clone());

dataset_org_plan_sub.organization.id
}
None => match req.headers().get("TR-Organization") {
Some(org_header) => org_header
.to_str()
.map_err(|_| {
Into::<Error>::into(ServiceError::BadRequest(
"Could not convert Organization to str".to_string(),
))
})?
.parse::<uuid::Uuid>()
.map_err(|_| {
Into::<Error>::into(ServiceError::BadRequest(
"Could not convert Organization to UUID".to_string(),
))
})?,
None => {
let (http_req, pl) = req.parts_mut();
let user = get_user(http_req, pl).await;
Expand All @@ -159,9 +102,6 @@ where
},
};

let (http_req, pl) = req.parts_mut();
let user = get_user(http_req, pl).await;

if let Some(user) = user {
req.extensions_mut().insert(user.clone());

Expand Down Expand Up @@ -238,6 +178,7 @@ async fn get_user(req: &HttpRequest, pl: &mut Payload) -> Option<LoggedUser> {
}
}

//TODO: Cache the api key in redis
if let Some(pool) = req.app_data::<web::Data<Pool>>() {
if let Ok(user) = get_user_from_api_key_query(authen_header, pool).await {
return Some(user);
Expand Down
38 changes: 11 additions & 27 deletions server/src/data/models.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1578,13 +1578,7 @@ impl Organization {
}

pub fn from_org_with_plan_sub(org_plan_sub: OrganizationWithSubAndPlan) -> Self {
Organization {
id: org_plan_sub.id,
name: org_plan_sub.name,
created_at: org_plan_sub.created_at,
updated_at: org_plan_sub.updated_at,
registerable: org_plan_sub.registerable,
}
org_plan_sub.organization.clone()
}
}

Expand Down Expand Up @@ -1757,11 +1751,13 @@ impl StripeSubscription {

#[derive(Serialize, Deserialize, Debug, Clone, ToSchema)]
#[schema(example = json!({
"id": "e3e3e3e3-e3e3-e3e3-e3e3-e3e3e3e3e3e3",
"name": "Trieve",
"created_at": "2021-01-01T00:00:00",
"updated_at": "2021-01-01T00:00:00",
"registerable": true,
"organization": {
"id": "e3e3e3e3-e3e3-e3e3-e3e3-e3e3e3e3e3e3",
"name": "Trieve",
"created_at": "2021-01-01T00:00:00",
"updated_at": "2021-01-01T00:00:00",
"registerable": true,
},
"plan": {
"id": "e3e3e3e3-e3e3-e3e3-e3e3-e3e3e3e3e3e3",
"stripe_id": "plan_123",
Expand All @@ -1786,11 +1782,7 @@ impl StripeSubscription {
}
}))]
pub struct OrganizationWithSubAndPlan {
pub id: uuid::Uuid,
pub name: String,
pub created_at: chrono::NaiveDateTime,
pub updated_at: chrono::NaiveDateTime,
pub registerable: Option<bool>,
pub organization: Organization,
pub plan: Option<StripePlan>,
pub subscription: Option<StripeSubscription>,
}
Expand All @@ -1802,23 +1794,15 @@ impl OrganizationWithSubAndPlan {
subscription: Option<StripeSubscription>,
) -> Self {
OrganizationWithSubAndPlan {
id: organization.id,
name: organization.name,
registerable: organization.registerable,
created_at: organization.created_at,
updated_at: organization.updated_at,
organization: organization.clone(),
plan,
subscription,
}
}

pub fn with_defaults(&self) -> Self {
OrganizationWithSubAndPlan {
id: self.id,
name: self.name.clone(),
registerable: self.registerable,
created_at: self.created_at,
updated_at: self.updated_at,
organization: self.organization.clone(),
plan: Some(self.plan.clone().unwrap_or_default()),
subscription: self.subscription.clone(),
}
Expand Down
25 changes: 10 additions & 15 deletions server/src/handlers/auth_handler.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,7 @@
use crate::data::models::{Organization, RedisPool, StripePlan, UserRole};
use crate::data::models::{Organization, StripePlan, UserRole};
use crate::get_env;
use crate::operators::invitation_operator::check_inv_valid;
use crate::operators::organization_operator::{
get_org_from_id_query, get_organization_by_key_query, get_user_org_count,
};
use crate::operators::organization_operator::{get_org_from_id_query, get_user_org_count};
use crate::operators::user_operator::{add_user_to_organization, create_user_query};
use crate::{
data::models::{Pool, SlimUser, User, UserOrganization},
Expand Down Expand Up @@ -162,36 +160,35 @@ pub async fn build_oidc_client() -> CoreClient {
)
}

#[tracing::instrument(skip(pool, redis_pool))]
#[tracing::instrument(skip(pool))]
pub async fn create_account(
email: String,
name: String,
user_id: uuid::Uuid,
organization_id: Option<uuid::Uuid>,
inv_code: Option<uuid::Uuid>,
redis_pool: web::Data<RedisPool>,
pool: web::Data<Pool>,
) -> Result<(User, Vec<UserOrganization>, Vec<Organization>), ServiceError> {
let (mut role, org) = match organization_id {
Some(organization_id) => (
UserRole::User,
get_org_from_id_query(organization_id, pool.clone()).await?,
get_org_from_id_query(organization_id, pool.clone())
.await?
.organization,
),
None => {
let org_name = email.split('@').collect::<Vec<&str>>()[0]
.to_string()
.replace(' ', "-");
(
UserRole::Owner,
create_organization_query(org_name.as_str(), redis_pool.clone(), pool.clone())
.await?,
create_organization_query(org_name.as_str(), pool.clone()).await?,
)
}
};
let org_id = org.id;

let org_plan_sub =
get_organization_by_key_query(org_id.into(), redis_pool.clone(), pool.clone()).await?;
let org_plan_sub = get_org_from_id_query(org_id.into(), pool.clone()).await?;
let user_org_count_pool = pool.clone();
let user_org_count = get_user_org_count(org_id, user_org_count_pool).await?;
if user_org_count
Expand Down Expand Up @@ -381,13 +378,12 @@ pub async fn login(
(status = 400, description = "Email or password empty or incorrect", body = ErrorResponseBody),
)
)]
#[tracing::instrument(skip(session, oidc_client, pool, redis_pool))]
#[tracing::instrument(skip(session, oidc_client, pool))]
pub async fn callback(
req: HttpRequest,
session: Session,
oidc_client: web::Data<CoreClient>,
pool: web::Data<Pool>,
redis_pool: web::Data<RedisPool>,
query: web::Query<OpCallback>,
) -> Result<HttpResponse, Error> {
let state: OpenIdConnectState = session
Expand Down Expand Up @@ -474,7 +470,6 @@ pub async fn callback(
user_id,
login_state.organization_id,
login_state.inv_code,
redis_pool.clone(),
pool.clone(),
)
.await?
Expand All @@ -501,7 +496,7 @@ pub async fn callback(
invitation.organization_id,
invitation.role.into(),
);
add_user_to_organization(None, None, user_org, redis_pool, pool).await?;
add_user_to_organization(None, None, user_org, pool).await?;
}
}

Expand Down
Loading

0 comments on commit d82277b

Please sign in to comment.