Commit

wip: bulk create chunks
cdxker committed Mar 21, 2024
1 parent 492b835 commit 5d6269f
Showing 4 changed files with 119 additions and 74 deletions.
182 changes: 113 additions & 69 deletions server/src/handlers/chunk_handler.rs
@@ -1,3 +1,5 @@
use std::iter::zip;

use super::auth_handler::{AdminOnly, LoggedUser};
use crate::data::models::{
ChatMessageProxy, ChunkMetadata, ChunkMetadataWithFileData, DatasetAndOrgWithSubAndPlan,
@@ -45,7 +47,7 @@ use utoipa::ToSchema;
"weight": 0.5,
"split_avg": false
}))]
pub struct CreateChunkData {
pub struct ChunkData {
/// HTML content of the chunk. This can also be plaintext. The innerText of the HTML will be used to create the embedding vector. The point of using HTML is for convienience, as some users have applications where users submit HTML content.
pub chunk_html: Option<String>,
/// Link to the chunk. This can also be any string. Frequently, this is a link to the source of the chunk. The link value will not affect the embedding creation.
@@ -76,7 +78,7 @@ pub struct CreateChunkData {

#[derive(Serialize, Deserialize, Clone, ToSchema)]
#[schema(example = json!({
"chunk_metadata": {
"chunk_metadata": [{
"content": "Some content",
"link": "https://example.com",
"tag_set": ["tag1", "tag2"],
@@ -86,23 +88,45 @@ pub struct CreateChunkData
"tracking_id": "tracking_id",
"time_stamp": "2021-01-01T00:00:00",
"weight": 0.5
},
}],
"pos_in_queue": 1
}))]
pub struct ReturnQueuedChunk {
pub chunk_metadata: ChunkMetadata,
pub chunk_metadatas: Vec<ChunkMetadata>,
pub pos_in_queue: i32,
}

#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct UploadIngestionMessage {
pub chunk_metadata: ChunkMetadata,
pub chunk: CreateChunkData,
pub chunk: ChunkData,
pub dataset_id: uuid::Uuid,
pub dataset_config: ServerDatasetConfiguration,
pub upsert_by_tracking_id: bool,
}

#[derive(Serialize, Deserialize, Debug, Clone, ToSchema)]
#[serde(untagged)]
#[schema(example = json!({
"chunk_html": "<p>Some HTML content</p>",
"link": "https://example.com",
"tag_set": ["tag1", "tag2"],
"file_id": "d290f1ee-6c54-4b01-90e6-d701748f0851",
"metadata": {"key1": "value1", "key2": "value2"},
"chunk_vector": [0.1, 0.2, 0.3],
"tracking_id": "tracking_id",
"upsert_by_tracking_id": true,
"group_ids": ["d290f1ee-6c54-4b01-90e6-d701748f0851"],
"group_tracking_ids": ["group_tracking_id"],
"time_stamp": "2021-01-01T00:00:00",
"weight": 0.5,
"split_avg": false
}))]
pub enum CreateChunkData {
Single(ChunkData),
Batch(Vec<ChunkData>),
}

/// Create Chunk
///
/// Create a new chunk. If the chunk has the same tracking_id as an existing chunk, the request will fail. Once a chunk is created, it can be searched for using the search endpoint.
@@ -131,98 +155,117 @@ pub async fn create_chunk(
dataset_org_plan_sub: DatasetAndOrgWithSubAndPlan,
redis_pool: web::Data<RedisPool>,
) -> Result<HttpResponse, actix_web::Error> {
let chunks = match chunk.into_inner() {
CreateChunkData::Single(chunk) => vec![chunk],
CreateChunkData::Batch(chunks) => chunks,
};

let count_dataset_id = dataset_org_plan_sub.dataset.id;

let mut timer = Timer::new();
let chunk_count = get_row_count_for_dataset_id_query(count_dataset_id, pool.clone())
.await
.map_err(|err| ServiceError::BadRequest(err.message.into()))?;
timer.add("get daataset count");

if chunk_count
if chunk_count + chunks.len()
>= dataset_org_plan_sub
.organization
.plan
.unwrap_or_default()
.chunk_count
.chunk_count as usize
{
return Ok(HttpResponse::UpgradeRequired()
.json(json!({"message": "Must upgrade your plan to add more chunks"})));
}

let chunk_tracking_id = chunk
.tracking_id
.clone()
.filter(|chunk_tracking| !chunk_tracking.is_empty());

let content = convert_html_to_text(chunk.chunk_html.as_ref().unwrap_or(&"".to_string()));
let server_dataset_configuration = ServerDatasetConfiguration::from_json(
dataset_org_plan_sub.dataset.server_configuration.clone(),
);

let chunk_tag_set = chunk.tag_set.clone().map(|tag_set| tag_set.join(","));

let chunk_metadata = ChunkMetadata::from_details(
content,
&chunk.chunk_html,
&chunk.link,
&chunk_tag_set,
None,
chunk.metadata.clone(),
chunk_tracking_id,
chunk
.time_stamp
.clone()
.map(|ts| -> Result<NaiveDateTime, ServiceError> {
Ok(ts
.parse::<DateTimeUtc>()
.map_err(|_| ServiceError::BadRequest("Invalid timestamp format".to_string()))?
.0
.with_timezone(&chrono::Local)
.naive_local())
})
.transpose()?,
dataset_org_plan_sub.dataset.id,
chunk.weight.unwrap_or(0.0),
);
let ingestion_messages = chunks
.iter()
.map(|chunk| async {
let content =
convert_html_to_text(chunk.chunk_html.as_ref().unwrap_or(&"".to_string()));
let chunk_tag_set = chunk.tag_set.clone().map(|tag_set| tag_set.join(","));

let group_ids_from_group_tracking_ids =
if let Some(group_tracking_ids) = chunk.group_tracking_ids.clone() {
get_groups_from_tracking_ids_query(group_tracking_ids, count_dataset_id, pool)
.await
.map_err(|err| ServiceError::BadRequest(err.message.into()))?
let chunk_tracking_id = chunk
.tracking_id
.clone()
.filter(|chunk_tracking| !chunk_tracking.is_empty());

let chunk_metadata = ChunkMetadata::from_details(
content,
&chunk.chunk_html,
&chunk.link,
&chunk_tag_set,
None,
chunk.metadata.clone(),
chunk_tracking_id,
chunk
.time_stamp
.clone()
.map(|ts| -> Result<NaiveDateTime, ServiceError> {
Ok(ts
.parse::<DateTimeUtc>()
.map_err(|_| {
ServiceError::BadRequest("Invalid timestamp format".to_string())
})?
.0
.with_timezone(&chrono::Local)
.naive_local())
})
.transpose()?,
dataset_org_plan_sub.dataset.id,
chunk.weight.unwrap_or(0.0),
);


let group_ids_from_group_tracking_ids = if let Some(group_tracking_ids) = chunk.group_tracking_ids.clone() {
get_groups_from_tracking_ids_query(group_tracking_ids, count_dataset_id, pool)
.await
.map_err(|err| ServiceError::BadRequest(err.message.into()))?
.into_iter()
.map(|group| group.id)
.collect::<Vec<uuid::Uuid>>()
} else {
vec![]
};

let initial_group_ids = chunk.group_ids.clone().unwrap_or_default();
let mut chunk_only_group_ids = chunk.clone();
let deduped_group_ids = group_ids_from_group_tracking_ids
.into_iter()
.map(|group| group.id)
.collect::<Vec<uuid::Uuid>>()
} else {
vec![]
};

let initial_group_ids = chunk.group_ids.clone().unwrap_or_default();
let mut chunk_only_group_ids = chunk.clone();
let deduped_group_ids = group_ids_from_group_tracking_ids
.into_iter()
.chain(initial_group_ids.into_iter())
.unique()
.collect::<Vec<uuid::Uuid>>();
chunk_only_group_ids.group_ids = Some(deduped_group_ids.clone());
chunk_only_group_ids.group_tracking_ids = None;
.chain(initial_group_ids.into_iter())
.unique()
.collect::<Vec<uuid::Uuid>>();

let server_dataset_configuration = ServerDatasetConfiguration::from_json(
dataset_org_plan_sub.dataset.server_configuration.clone(),
);
chunk_only_group_ids.group_ids = Some(deduped_group_ids);
chunk_only_group_ids.group_tracking_ids = None;

let ingestion_message = UploadIngestionMessage {
chunk_metadata: chunk_metadata.clone(),
chunk: chunk_only_group_ids.clone(),
dataset_id: count_dataset_id,
dataset_config: server_dataset_configuration.clone(),
upsert_by_tracking_id: chunk.upsert_by_tracking_id.unwrap_or(false),
};

Ok(UploadIngestionMessage {
chunk_metadata: chunk_metadata.clone(),
chunk: chunk_only_group_ids.clone(),
dataset_id: count_dataset_id,
dataset_config: server_dataset_configuration.clone(),
upsert_by_tracking_id: chunk.upsert_by_tracking_id.unwrap_or(false),
})

})
.collect();

let mut redis_conn = redis_pool
.get()
.await
.map_err(|err| ServiceError::BadRequest(err.to_string()))?;
timer.add("Got redis conn");

deadpool_redis::redis::cmd("lpush")
.arg("ingestion")
.arg(serde_json::to_string(&ingestion_message)?)
.arg(serde_json::to_string(&ingestion_messages)?)
.query_async(&mut redis_conn)
.await
.map_err(|err| ServiceError::BadRequest(err.to_string()))?;
@@ -232,9 +275,10 @@ pub async fn create_chunk(
.query_async(&mut redis_conn)
.await
.map_err(|err| ServiceError::BadRequest(err.to_string()))?;
timer.add("Push to reedis");

Ok(HttpResponse::Ok().json(ReturnQueuedChunk {
chunk_metadata: chunk_metadata.clone(),
chunk_metadatas,
pos_in_queue,
}))
}
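The diff above turns `CreateChunkData` into an untagged enum, so the same `create_chunk` route accepts either one chunk object or an array of them in the request body. A minimal, standalone sketch of how serde resolves the two payload shapes — the `ChunkData` stand-in below is trimmed to a few of the optional fields and is not the full struct from the handler; `serde` (with derive) and `serde_json` are assumed as dependencies:

use serde::{Deserialize, Serialize};
use serde_json::json;

// Trimmed stand-in for the real ChunkData; every field is optional, as in the handler.
#[derive(Serialize, Deserialize, Debug, Clone)]
struct ChunkData {
    chunk_html: Option<String>,
    link: Option<String>,
    weight: Option<f64>,
}

// Untagged: serde tries Single first, then Batch, based purely on the JSON shape.
#[derive(Serialize, Deserialize, Debug)]
#[serde(untagged)]
enum CreateChunkData {
    Single(ChunkData),
    Batch(Vec<ChunkData>),
}

fn main() {
    // A single JSON object still deserializes into the Single variant...
    let single: CreateChunkData =
        serde_json::from_value(json!({ "chunk_html": "<p>one chunk</p>" })).unwrap();

    // ...while a JSON array selects the Batch variant.
    let batch: CreateChunkData = serde_json::from_value(json!([
        { "chunk_html": "<p>first</p>", "weight": 0.5 },
        { "chunk_html": "<p>second</p>", "link": "https://example.com" }
    ]))
    .unwrap();

    println!("{single:?}\n{batch:?}");
}

Existing single-chunk clients keep working unchanged, which appears to be the point of using `#[serde(untagged)]` rather than adding a separate bulk endpoint.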
2 changes: 1 addition & 1 deletion server/src/lib.rs
@@ -370,7 +370,7 @@ pub async fn main() -> std::io::Result<()> {
let redis_pool = deadpool_redis::Config::from_url(redis_url)
.create_pool(Some(deadpool_redis::Runtime::Tokio1))
.unwrap();
redis_pool.resize(30);
redis_pool.resize(200);

let oidc_client = build_oidc_client().await;

4 changes: 2 additions & 2 deletions server/src/operators/chunk_operator.rs
@@ -973,7 +973,7 @@ pub fn find_relevant_sentence(
pub async fn get_row_count_for_dataset_id_query(
dataset_id: uuid::Uuid,
pool: web::Data<Pool>,
) -> Result<i32, DefaultError> {
) -> Result<usize, DefaultError> {
use crate::data::schema::dataset_usage_counts::dsl as dataset_usage_counts_columns;

let mut conn = pool.get().await.expect("Failed to get connection to db");
@@ -987,5 +987,5 @@ pub async fn get_row_count_for_dataset_id_query(
message: "Failed to get chunk count for dataset",
})?;

Ok(chunk_metadata_count)
Ok(chunk_metadata_count as usize)
}
5 changes: 3 additions & 2 deletions server/src/operators/file_operator.rs
@@ -6,6 +6,7 @@ use crate::data::models::{
ChunkMetadata, Dataset, DatasetAndOrgWithSubAndPlan, EventType, ServerDatasetConfiguration,
};
use crate::handlers::auth_handler::AdminOnly;
use crate::handlers::chunk_handler::ChunkData;
use crate::operators::chunk_operator::delete_chunk_metadata_query;
use crate::{data::models::ChunkGroup, handlers::chunk_handler::ReturnQueuedChunk};
use crate::{data::models::Event, get_env};
@@ -318,7 +319,7 @@ pub async fn create_chunks_with_handler(
})?;

for chunk_html in chunk_htmls {
let create_chunk_data = CreateChunkData {
let create_chunk_data = ChunkData {
chunk_html: Some(chunk_html.clone()),
link: link.clone(),
tag_set: split_tag_set.clone(),
@@ -333,7 +334,7 @@ pub async fn create_chunks_with_handler(
weight: None,
split_avg: None,
};
let web_json_create_chunk_data = web::Json(create_chunk_data);
let web_json_create_chunk_data = web::Json(CreateChunkData::Single(create_chunk_data));

match create_chunk(
web_json_create_chunk_data,
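The file handler above still wraps each HTML split in `CreateChunkData::Single` and calls `create_chunk` once per split. A hypothetical follow-up, not part of this commit, would collect every split into a single `Batch` request; a rough sketch under that assumption, again with trimmed stand-in types rather than the real structs:

// Stand-ins for illustration only; the real structs carry more fields.
#[derive(Clone, Debug)]
struct ChunkData {
    chunk_html: Option<String>,
    link: Option<String>,
    weight: Option<f64>,
}

#[derive(Debug)]
enum CreateChunkData {
    Single(ChunkData),
    Batch(Vec<ChunkData>),
}

// Build one Batch payload from all splits instead of queueing them one at a time.
fn batch_from_splits(chunk_htmls: Vec<String>, link: Option<String>) -> CreateChunkData {
    let chunks = chunk_htmls
        .into_iter()
        .map(|chunk_html| ChunkData {
            chunk_html: Some(chunk_html),
            link: link.clone(),
            weight: None,
        })
        .collect::<Vec<_>>();
    CreateChunkData::Batch(chunks)
}

fn main() {
    let payload = batch_from_splits(
        vec!["<p>first split</p>".into(), "<p>second split</p>".into()],
        Some("https://example.com".into()),
    );
    if let CreateChunkData::Batch(chunks) = &payload {
        println!("queueing {} chunks in one request", chunks.len());
    }
}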
