Skip to content

Commit

Permalink
Fix: the ComplEx score could not be computed correctly.
Browse files Browse the repository at this point in the history
  • Loading branch information
yjcyxky committed Jul 2, 2024
1 parent 10bcc67 commit 7fb14f9
Show file tree
Hide file tree
Showing 11 changed files with 303 additions and 91 deletions.
12 changes: 9 additions & 3 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -89,14 +89,20 @@ changelog:

deploy: deploy-biomedgps

deploy-biomedgps: build-biomedgps-studio
build-biomedgps-cross-compile:
@docker run --rm -it -v "$(CURDIR)":/home/rust/src messense/rust-musl-cross:x86_64-musl cargo build --release
@rsync -avP target/x86_64-unknown-linux-musl/release/biomedgps target/x86_64-unknown-linux-musl/release/biomedgps-cli [email protected]:/data/biomedgps/bin
@rsync -avP --delete assets/index.html [email protected]:/var/www/html/biomedgps/index.html
@rsync -avP --delete assets [email protected]:/var/www/html/biomedgps/

deploy-rapex: build-rapex-studio
deploy-biomedgps: build-biomedgps-studio build-biomedgps-cross-compile
@printf "\nDone!\n"

build-rapex-cross-compile:
@docker run --rm -it -v "$(CURDIR)":/home/rust/src messense/rust-musl-cross:x86_64-musl cargo build --release
@rsync -avP target/x86_64-unknown-linux-musl/release/biomedgps target/x86_64-unknown-linux-musl/release/biomedgps-cli [email protected]:/data/rapex/bin
@rsync -avP --delete assets/index.html [email protected]:/var/www/html/rapex/index.html
@rsync -avP --delete assets [email protected]:/var/www/html/rapex/
@rsync -avP --delete assets [email protected]:/var/www/html/rapex/

deploy-rapex: build-rapex-studio build-rapex-cross-compile
@printf "\nDone!\n"
19 changes: 4 additions & 15 deletions src/bin/biomedgps-cli.rs
Original file line number Diff line number Diff line change
Expand Up @@ -200,7 +200,7 @@ pub struct CacheTableArguments {
batch_size: usize,
}

/// Import embedding files into a database. The embedding files are generated by KGE models. If you have multiple models for different cases or datasets, you need to import them all and with different parameters, such as table_name, model_name, model_type, dataset, description etc. More details about these parameters can be found in their descriptions.
/// Import embedding files into a database. The embedding files are generated by KGE models. If you have multiple models for different cases or datasets, you need to import each of them with its own parameters, such as table_name, model_type, dataset, description, etc. More details about these parameters can be found in their descriptions.
#[derive(StructOpt, PartialEq, Debug)]
#[structopt(setting=structopt::clap::AppSettings::ColoredHelp, name="BioMedGPS - importkge", author="Jingcheng Yang <[email protected]>")]
pub struct ImportKGEArguments {
Expand Down Expand Up @@ -228,24 +228,15 @@ pub struct ImportKGEArguments {
#[structopt(name = "metadata_file", short = "f", long = "metadata-file")]
metadata_file: String,

/// [Optional] The table name you want to name. e.g. biomedgps, mecfs, etc. This feature is used to distinguish different dataset combinations matched with your model. If not set, we will use the biomedgps as default. But in this case, the dimension of the embedding should be 400.
/// [Optional] The name to give the dataset, e.g. biomedgps, mecfs, etc. It is used to distinguish the different dataset combinations matched with your model. If not set, biomedgps is used as the default; in that case, the dimension of the embedding should be 400.
#[structopt(
name = "table_name",
name = "dataset_name",
short = "t",
long = "table-name",
long = "dataset-name",
default_value = DEFAULT_MODEL_NAME
)]
table_name: String,

/// [Optional] The model name you want to name. e.g. mecfs_transe, mecfs_distmult, etc. You need to specify the model name when you import the embedding files. This feature is used to distinguish different models. Users can choose the model for their own purpose.
#[structopt(
name = "model_name",
short = "m",
long = "model-name",
default_value = DEFAULT_MODEL_NAME
)]
model_name: String,

/// [Required] The model type of generated embedding files. e.g. TransE_l1, TransE_l2, DistMult, ComplEx, etc. This feature is used to distinguish different models. Users can choose the model for their own purpose.
#[structopt(name = "model_type", short = "M", long = "model-type", default_value = "TransE_l2", possible_values = &["TransE_l1", "TransE_l2", "TransH", "TransR", "TransD", "RotatE", "DistMult", "ComplEx"])]
model_type: String,
Expand Down Expand Up @@ -699,7 +690,6 @@ async fn main() {
}

let table_name = arguments.table_name;
let model_name = arguments.model_name;
let model_type = arguments.model_type;
let description = arguments.description;
let datasets: Vec<&str> = arguments.dataset.iter().map(|s| s.as_str()).collect();
Expand Down Expand Up @@ -728,7 +718,6 @@ async fn main() {
import_kge(
&database_url,
&table_name,
&model_name,
&model_type,
&datasets,
description.as_deref(),
Expand Down
15 changes: 11 additions & 4 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1254,7 +1254,6 @@ pub async fn connect_db(database_url: &str, max_connections: u32) -> sqlx::PgPoo
pub async fn import_kge(
database_url: &str,
table_name: &str,
model_name: &str,
model_type: &str,
datasets: &Vec<&str>,
description: Option<&str>,
Expand Down Expand Up @@ -1373,7 +1372,6 @@ pub async fn import_kge(
match EmbeddingMetadata::init_embedding_table(
&pool,
table_name,
model_name,
model_type,
&description,
datasets,
Expand All @@ -1388,8 +1386,17 @@ pub async fn import_kge(
}
Err(e) => {
if drop {
info!("The embedding tables already exist, drop their records and reimport the embeddings.");
true
info!("Existing a same model in the database, we will drop it and re-import the model.");
match EmbeddingMetadata::delete(&pool, table_name).await {
Ok(_) => {
info!("Drop the existing model successfully, please rerun the command without the --drop option.");
std::process::exit(0);
}
Err(e) => {
error!("Failed to drop the existing model: {}", e);
std::process::exit(1);
}
}
} else {
error!("Failed to init the embedding tables: {}", e);
std::process::exit(1);
Expand Down
19 changes: 7 additions & 12 deletions src/model/graph.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@ use lazy_static::lazy_static;
use log::{debug, error};
use neo4rs::{Node as NeoNode, Relation as NeoRelation};
use poem_openapi::Object;
use polars::lazy::dsl::Operator;
use regex::Regex;
use serde::{Deserialize, Serialize};
use std::collections::{HashMap, HashSet};
Expand Down Expand Up @@ -667,15 +666,12 @@ impl TargetNode {
}
};

// TODO: We need to allow the user to set the score function, gamma and exp_enabled
let gamma = 12.0;
let sql_str = match Graph::format_score_sql(
&entity_id,
&entity_type,
relation_type,
&embedding_metadata,
topk,
gamma,
topk
) {
Ok(sql_str) => sql_str,
Err(err) => {
Expand Down Expand Up @@ -1085,17 +1081,17 @@ impl Graph {
/// let embedding_metadata = EmbeddingMetadata {
/// id: 1,
/// metadata: None,
/// model_name: "biomedgps_transe_l2".to_string(),
/// model_name: "biomedgps-TransE_l2".to_string(),
/// model_type: "TransE_l2".to_string(),
/// dimension: 400,
/// table_name: "biomedgps".to_string(),
/// created_at: DateTime::from_utc(NaiveDateTime::from_timestamp(0, 0), Utc),
/// datasets: vec!("STRING".to_string()),
/// metadata: Some(r#"{"gamma": 12.0}"#.to_string()),
/// description: "The entity embedding trained by the TransE_l2 model".to_string(),
/// };
/// let topk = 10;
/// let gamma = 12.0;
/// let sql_str = Graph::format_score_sql(source_id, source_type, relation_type, &embedding_metadata, topk, gamma).unwrap();
/// let sql_str = Graph::format_score_sql(source_id, source_type, relation_type, &embedding_metadata, topk).unwrap();
/// let expected_sql_str = "
/// SELECT
/// COALESCE(ee2.entity_type, '') || '::' || COALESCE(ee2.entity_id, '') AS node_id,
Expand Down Expand Up @@ -1134,8 +1130,7 @@ impl Graph {
source_type: &str,
relation_type: &str,
embedding_metadata: &EmbeddingMetadata,
topk: u64,
gamma: f64,
topk: u64
) -> Result<String, ValidationError> {
let source_id = source_id.split(",").collect::<Vec<&str>>().join("', '");
let source_type_vec = source_type
Expand Down Expand Up @@ -1277,7 +1272,7 @@ impl Graph {
vector_to_float4(ee1.embedding, {dimension}, false),
vector_to_float4(rte.embedding, {dimension}, false),
vector_to_float4(ee2.embedding, {dimension}, false),
{gamma},
{gamma_str}
true,
{reverse}
) AS score
Expand Down Expand Up @@ -1307,7 +1302,7 @@ impl Graph {
relation_type_embedding_table =
get_relation_emb_table_name(&embedding_metadata.table_name),
score_function_name = score_function_name,
gamma = gamma
gamma_str = embedding_metadata.get_gamma_string(),
)
};

Expand Down
26 changes: 11 additions & 15 deletions src/model/init_db.rs
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,6 @@ pub fn init_score_sql(
first_second_relation_type: &str,
second_third_relation_type: &str,
table_prefix: Option<&str>,
gamma: f64,
embedding_metadata: &EmbeddingMetadata,
) -> String {
let table_prefix = table_prefix.unwrap_or(DEFAULT_MODEL_NAME);
Expand Down Expand Up @@ -159,15 +158,15 @@ pub fn init_score_sql(
vector_to_float4(tt.first_embedding, {dimension}, false),
vector_to_float4(tt.first_second_embedding, {dimension}, false),
vector_to_float4(tt.second_embedding, {dimension}, false),
{gamma},
{gamma_str}
true,
false
),
{score_function_name}(
vector_to_float4(tt.second_embedding, {dimension}, false),
vector_to_float4(tt.second_third_embedding, {dimension}, false),
vector_to_float4(tt.third_embedding, {dimension}, false),
{gamma},
{gamma_str}
true,
false
)
Expand All @@ -185,6 +184,7 @@ pub fn init_score_sql(
entity_emb_table = get_entity_emb_table_name(table_prefix),
score_function_name = score_function_name,
dimension = embedding_metadata.dimension,
gamma_str = embedding_metadata.get_gamma_string(),
)
}

Expand Down Expand Up @@ -278,15 +278,13 @@ pub async fn create_score_table(
}
};

let gamma = 12.0;
let init_sql = init_score_sql(
first_entity_type,
second_entity_type,
third_entity_type,
first_second_relation_type,
second_third_relation_type,
table_prefix,
gamma,
&embedding_metadata,
);

Expand Down Expand Up @@ -361,7 +359,6 @@ pub fn get_kg_score_table_name(table_prefix: &str) -> String {

pub fn init_kg_score_sql(
table_prefix: Option<&str>,
gamma: f64,
embedding_metadata: &EmbeddingMetadata,
) -> String {
let table_prefix = table_prefix.unwrap_or(DEFAULT_MODEL_NAME);
Expand Down Expand Up @@ -402,7 +399,7 @@ pub fn init_kg_score_sql(
vector_to_float4(tt.source_embedding, {dimension}, false),
vector_to_float4(tt.relation_type_embedding, {dimension}, false),
vector_to_float4(tt.target_embedding, {dimension}, false),
{gamma},
{gamma_str}
true,
false
)::FLOAT8 AS score
Expand All @@ -414,7 +411,7 @@ pub fn init_kg_score_sql(
entity_emb_table = get_entity_emb_table_name(table_prefix),
score_function_name = score_function_name,
dimension = embedding_metadata.dimension,
gamma = gamma
gamma_str = embedding_metadata.get_gamma_string(),
)
}

Expand Down Expand Up @@ -461,8 +458,7 @@ pub async fn create_kg_score_table(
}
};

let gamma = 12.0;
let init_sql = init_kg_score_sql(table_prefix, gamma, &embedding_metadata);
let init_sql = init_kg_score_sql(table_prefix, &embedding_metadata);

debug!("init_sql: {}", init_sql);
let mut tx = pool.begin().await.unwrap();
Expand Down Expand Up @@ -827,15 +823,15 @@ mod tests {
let third_entity_type = "Symptom";
let first_second_relation_type = "treats";
let second_third_relation_type = "causes";
let gamma = 12.0;
let embedding_metadata = EmbeddingMetadata {
id: 1,
metadata: None,
model_name: "biomedgps_transe_l2".to_string(),
model_name: "biomedgps-TransE_l2".to_string(),
model_type: "TransE_l2".to_string(),
dimension: 400,
table_name: "biomedgps".to_string(),
created_at: DateTime::from_utc(NaiveDateTime::from_timestamp(0, 0), Utc),
gamma: Some(12.0),
datasets: vec!["STRING".to_string()],
description: "The entity embedding trained by the TransE_l2 model".to_string(),
};
Expand All @@ -847,7 +843,6 @@ mod tests {
first_second_relation_type,
second_third_relation_type,
Some(table_prefix),
gamma,
&embedding_metadata,
);
println!("sql: {}", sql);
Expand Down Expand Up @@ -889,15 +884,16 @@ mod tests {
let embedding_metadata = EmbeddingMetadata {
id: 1,
metadata: None,
model_name: "biomedgps_transe_l2".to_string(),
model_name: "biomedgps-TransE_l2".to_string(),
model_type: "TransE_l2".to_string(),
dimension: 400,
gamma: Some(12.0),
table_name: "biomedgps".to_string(),
created_at: DateTime::from_utc(NaiveDateTime::from_timestamp(0, 0), Utc),
datasets: vec!["STRING".to_string()],
description: "The entity embedding trained by the TransE_l2 model".to_string(),
};
let sql = init_kg_score_sql(Some(table_prefix), gamma, &embedding_metadata);
let sql = init_kg_score_sql(Some(table_prefix), &embedding_metadata);
println!("sql: {}", sql);
assert!(sql.contains("biomedgps_relation_with_score"));
}
Expand Down
Loading

0 comments on commit 7fb14f9

Please sign in to comment.