Skip to content

Commit

Permalink
feature: jina-code support
Browse files (browse the repository at this point in the history)
  • Loading branch information
cdxker authored and skeptrunedev committed May 9, 2024
1 parent cd92638 commit 0db9eae
Show file tree
Hide file tree
Showing 3 changed files with 43 additions and 59 deletions.
6 changes: 6 additions & 0 deletions dashboard/src/types/apiTypes.ts
Original file line number Diff line number Diff line change
Expand Up @@ -260,6 +260,12 @@ export const availableEmbeddingModels = [
url: "https://embedding.trieve.ai/bge-m3",
dimension: 1024,
},
{
id: "jina-embeddings-v2-base-code",
name: "jina-embeddings-v2-base-code (securely hosted by Trieve)",
url: "https://embedding.trieve.ai/jina-code",
dimension: 1024,
},
{
id: "text-embedding-3-small",
name: "text-embedding-3-small (hosted by OpenAI)",
Expand Down
5 changes: 2 additions & 3 deletions server/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -334,8 +334,6 @@ pub fn main() -> std::io::Result<()> {

let sentry_url = std::env::var("SENTRY_URL");
let _guard = if let Ok(sentry_url) = sentry_url {
log::info!("Sentry monitoring enabled");

let guard = sentry::init((
sentry_url,
sentry::ClientOptions {
Expand All @@ -356,6 +354,7 @@ pub fn main() -> std::io::Result<()> {
.init();

std::env::set_var("RUST_BACKTRACE", "1");
log::info!("Sentry monitoring enabled");
Some(guard)
} else {
tracing_subscriber::Registry::default()
Expand Down Expand Up @@ -447,7 +446,7 @@ pub fn main() -> std::io::Result<()> {
.app_data(web::Data::new(pool.clone()))
.app_data(web::Data::new(oidc_client.clone()))
.app_data(web::Data::new(redis_pool.clone()))
.wrap(sentry_actix::Sentry::with_transaction())
.wrap(sentry_actix::Sentry::new())
.wrap(af_middleware::auth_middleware::AuthMiddlewareFactory)
.wrap(
IdentityMiddleware::builder()
Expand Down
91 changes: 35 additions & 56 deletions server/src/operators/model_operator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -60,24 +60,17 @@ pub async fn create_embedding(
"https://embedding.trieve.ai" => std::env::var("EMBEDDING_SERVER_ORIGIN")
.ok()
.filter(|s| !s.is_empty())
.unwrap_or(
get_env!(
"GPU_SERVER_ORIGIN",
"GPU_SERVER_ORIGIN should be set if this is called"
)
.to_string(),
),
.unwrap_or("https://embedding.trieve.ai".to_string()),
"https://embedding.trieve.ai/bge-m3" => std::env::var("EMBEDDING_SERVER_ORIGIN_BGEM3")
.ok()
.filter(|s| !s.is_empty())
.unwrap_or(
get_env!(
"GPU_SERVER_ORIGIN",
"GPU_SERVER_ORIGIN should be set if this is called"
)
.to_string(),
)
.to_string(),
.unwrap_or("https://embedding.trieve.ai/bge-m3".to_string()),
"https://embedding.trieve.ai/jinaai-code" => {
std::env::var("EMBEDDING_SERVER_ORIGIN_JINA_CODE")
.ok()
.filter(|s| !s.is_empty())
.unwrap_or("https://embedding.trieve.ai/jinaai-code".to_string())
}
_ => config_embedding_base_url,
};

Expand Down Expand Up @@ -167,14 +160,13 @@ pub async fn get_sparse_vector(
_ => unreachable!("Invalid embed_type passed"),
};

let server_origin = match std::env::var(origin_key).ok().filter(|s| !s.is_empty()) {
Some(origin) => origin,
None => get_env!(
"GPU_SERVER_ORIGIN",
"GPU_SERVER_ORIGIN should be set if this is called"
)
.to_string(),
};
let server_origin = std::env::var(origin_key)
.ok()
.filter(|s| !s.is_empty())
.ok_or(ServiceError::BadRequest(format!(
"{} does not exist",
origin_key
)))?;

let embedding_server_call = format!("{}/embed_sparse", server_origin);

Expand Down Expand Up @@ -256,23 +248,16 @@ pub async fn create_embeddings(
"https://embedding.trieve.ai" => std::env::var("EMBEDDING_SERVER_ORIGIN")
.ok()
.filter(|s| !s.is_empty())
.unwrap_or(
get_env!(
"GPU_SERVER_ORIGIN",
"GPU_SERVER_ORIGIN should be set if this is called"
)
.to_string(),
),
.unwrap_or("https://embedding.trieve.ai".to_string()),
"https://embedding.trieve.ai/bge-m3" => std::env::var("EMBEDDING_SERVER_ORIGIN_BGEM3")
.ok()
.filter(|s| !s.is_empty())
.unwrap_or(
get_env!(
"GPU_SERVER_ORIGIN",
"GPU_SERVER_ORIGIN should be set if this is called"
)
.to_string(),
)
.unwrap_or("https://embedding.trieve.ai/bge-m3".to_string())
.to_string(),
"https://embedding.trieve.ai/jina-code" => std::env::var("EMBEDDING_SERVER_ORIGIN_BGEM3")
.ok()
.filter(|s| !s.is_empty())
.unwrap_or("https://embedding.trieve.ai/jina-code".to_string())
.to_string(),
_ => config_embedding_base_url,
};
Expand Down Expand Up @@ -432,17 +417,16 @@ pub async fn get_sparse_vectors(
_ => unreachable!("Invalid embed_type passed"),
};

let server_origin = match std::env::var(origin_key).ok().filter(|s| !s.is_empty()) {
Some(origin) => origin,
None => get_env!(
"GPU_SERVER_ORIGIN",
"GPU_SERVER_ORIGIN should be set if this is called"
)
.to_string(),
};
let embedding_server_call = format!("{}/embed_sparse", server_origin);

async move {
let server_origin = std::env::var(origin_key)
.ok()
.filter(|s| !s.is_empty())
.ok_or(ServiceError::BadRequest(format!(
"env flag {} is not set",
origin_key
)))?;
let embedding_server_call = format!("{}/embed_sparse", server_origin);

let sparse_embed_req = CustomSparseEmbedData {
inputs: thirty_messages.to_vec(),
encode_type: embed_type.to_string(),
Expand Down Expand Up @@ -548,17 +532,12 @@ pub async fn cross_encoder(
};
sentry::configure_scope(|scope| scope.set_span(Some(transaction.clone())));

let server_origin: String = match std::env::var("RERANKER_SERVER_ORIGIN")
let server_origin: String = std::env::var("RERANKER_SERVER_ORIGIN")
.ok()
.filter(|s| !s.is_empty())
{
Some(origin) => origin,
None => get_env!(
"GPU_SERVER_ORIGIN",
"GPU_SERVER_ORIGIN should be set if this is called"
)
.to_string(),
};
.ok_or(ServiceError::BadRequest(
"env flag RERANKER_SERVER_ORIGIN is not set".to_string(),
))?;

let embedding_server_call = format!("{}/rerank", server_origin);

Expand Down

0 comments on commit 0db9eae

Please sign in to comment.