Commit

Merge pull request #16 from BiomedSciAI/hpa_using_utils
HPA using utils
yoavkt authored Jun 26, 2024
2 parents cd9b6ea + d8a8328 commit 1b92f99
Showing 3 changed files with 157 additions and 67 deletions.
103 changes: 39 additions & 64 deletions scripts/tasks_retrival/HPA_tasks_creation.py
@@ -1,11 +1,17 @@
-import os
-
 import click
-import numpy as np
 import pandas as pd
-from yaml import safe_load
 
-CELL_LINE = "Cell line expression cluster"
+from gene_benchmark.tasks import dump_task_definitions
+from scripts.tasks_retrival.task_retrieval import (
+    check_data_type,
+    create_single_label_task,
+    load_yaml_file,
+    print_numerical_task_report,
+    report_task_single_col,
+    tag_list_to_multi_label,
+)
+
+COLUMN_TO_CLEAR_SEMICOLON = "Cell line expression cluster"
 DATA_URL = "https://v23.proteinatlas.org/download/proteinatlas.tsv.zip"


@@ -14,36 +20,21 @@ def import_data(url):
     return data
 
 
-def load_yaml_file(yaml_path):
-    with open(yaml_path) as f:
-        loaded_yaml = safe_load(f)
-    return loaded_yaml
-
-
-def format_pathology_columns(data):
+def format_hpa_columns(data, clear_semicolon=None):
     pathology_columns = list(
         filter(lambda x: "Pathology prognostics" in x, data.columns)
    )
     data[pathology_columns] = data[pathology_columns].replace(
         r"\s*\(\d+\.?\d*e?-?\d*\)", "", regex=True
     )
+    if clear_semicolon is not None and clear_semicolon in data.columns:
+        data[clear_semicolon] = (
+            data[clear_semicolon].astype(str).apply(lambda x: x.replace(";", ""))
+        )
     return data
 
 
-def check_data_type(data_col):
-    if data_col.nunique() == 2:
-        return "binary"
-    elif (data_col.nunique() > 2) & (data_col.dtypes == object):
-        if data_col.astype(str).str.contains("[,;]").any():
-            return "multi_class"
-        return "categorical"
-    elif (data_col.nunique() > 2) & (
-        (data_col.dtypes == "int64") | (data_col.dtypes == "float64")
-    ):
-        return "numerical"
-
-
-def create_tasks(data, main_task_directory):
+def create_tasks(data, main_task_directory, verbose=False):
     for col in data:
         current_col_data = data[col]
         current_col_data = current_col_data.replace("", pd.NA)
@@ -52,40 +43,20 @@ def create_tasks(data, main_task_directory):
             current_col_data[current_col_data.index.str.contains("ENSG")].index
         )
         data_type = check_data_type(current_col_data)
+        task_name = col.replace("/", "|")
         if data_type == "multi_class":
-            entities, outcomes = create_multi_label_task(current_col_data)
+            entities, outcomes = tag_list_to_multi_label(current_col_data)
+            if verbose:
+                print(
+                    f"Create task {task_name} at {main_task_directory} outcomes shaped {outcomes.shape}"
+                )
         else:
             entities, outcomes = create_single_label_task(current_col_data)
-
-        save_task_to_dir(main_task_directory, col, entities, outcomes)
-
-
-def create_single_label_task(current_col_data):
-    entities = pd.Series(current_col_data.index, name="symbol")
-    outcomes = pd.Series(current_col_data.values, name="Outcomes")
-    return entities, outcomes
-
-
-def create_multi_label_task(current_col_data):
-    split_values_df = current_col_data.apply(
-        lambda x: [item.strip() for item in x.split(",")]
-    )
-    vocab = list(set(np.concatenate(split_values_df.values)))
-    outcome_df = pd.DataFrame(0, index=split_values_df.index, columns=vocab)
-    for index in range(split_values_df.shape[0]):
-        outcome_df.iloc[index][split_values_df.iloc[index]] = 1
-    entities = pd.Series(outcome_df.index, name="symbol")
-    return entities, outcome_df
-
-
-def save_task_to_dir(main_task_directory, task_name, entities, outcomes):
-    task_name = task_name.replace("/", "|")
-    task_dir = main_task_directory + f"/{task_name}"
-    os.makedirs(task_dir, exist_ok=True)
-    entities_path = task_dir + "/entities.csv"
-    outcomes_path = task_dir + "/outcomes.csv"
-    entities.to_csv(entities_path, index=False, header="symbol")
-    outcomes.to_csv(outcomes_path, index=False, header="Outcomes")
+        if data_type == "numerical" and verbose:
+            print_numerical_task_report(outcomes, main_task_directory, task_name)
+        elif verbose:
+            report_task_single_col(outcomes, main_task_directory, task_name)
+        dump_task_definitions(entities, outcomes, main_task_directory, task_name)
 
 
 @click.command()
@@ -112,7 +83,15 @@ def save_task_to_dir(main_task_directory, task_name, entities, outcomes):
 @click.option(
     "--input-file", type=click.STRING, help="The path to the data file", default=None
 )
-def main(columns_to_use_yaml, main_task_directory, allow_downloads, input_file):
+@click.option(
+    "--verbose",
+    "-v",
+    is_flag=True,
+    default=True,
+)
+def main(
+    columns_to_use_yaml, main_task_directory, allow_downloads, input_file, verbose
+):
     if allow_downloads:
         data = import_data(DATA_URL)
     else:
@@ -122,13 +101,9 @@ def main(columns_to_use_yaml, main_task_directory, allow_downloads, input_file):
     data = data.set_index("Gene")
     data = data[columns_to_use]
 
-    data = format_pathology_columns(data)
-    if CELL_LINE in data.columns:
-        data[CELL_LINE] = (
-            data[CELL_LINE].astype(str).apply(lambda x: x.replace(";", ""))
-        )
+    data = format_hpa_columns(data, clear_semicolon=COLUMN_TO_CLEAR_SEMICOLON)
 
-    create_tasks(data, main_task_directory)
+    create_tasks(data, main_task_directory, verbose)
 
 
 if __name__ == "__main__":
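For context (not part of the commit): a minimal sketch of what the new format_hpa_columns does, on a hypothetical two-column frame, assuming the function is importable from the script above. It strips the trailing p-value parentheses from "Pathology prognostics" columns and removes semicolons from the configured cluster column:

import pandas as pd

from scripts.tasks_retrival.HPA_tasks_creation import format_hpa_columns

df = pd.DataFrame(
    {
        "Pathology prognostics - Liver cancer": ["unprognostic (4.1e-2)"],
        "Cell line expression cluster": ["25;53;64"],
    },
    index=["GENE1"],
)
df = format_hpa_columns(df, clear_semicolon="Cell line expression cluster")
# "unprognostic (4.1e-2)" -> "unprognostic"; "25;53;64" -> "255364"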
@@ -163,9 +163,7 @@ def main(
     )
 
     if verbose:
-        print_numerical_task_report(
-            downloaded_dataframe, main_task_directory, task_name
-        )
+        print_numerical_task_report(outcomes, main_task_directory, task_name)
 
 
 if __name__ == "__main__":
117 changes: 117 additions & 0 deletions scripts/tasks_retrival/task_retrieval.py
@@ -1,11 +1,14 @@
 import pickle
 from io import BytesIO
+from itertools import chain
 from pathlib import Path
 from urllib.parse import urlparse
 
 import mygene
+import numpy as np
 import pandas as pd
 import requests
+from yaml import safe_load
 
 
 def verify_source_of_data(
@@ -225,3 +228,117 @@ def list_form_to_onehot_form(
     path_genes = list_df.loc[pathway_idx, participant_col_name].split(delimiter)
     onehot_df.loc[path_genes, pathway_idx] = True
     return onehot_df
+
+
+def check_data_type(
+    data_col: pd.Series,
+    binary_name: str = "binary",
+    category_name: str = "categorical",
+    numerical_name: str = "numerical",
+    multi_class_name: str = "multi_class",
+    multi_regex: str = "[,;]",
+) -> str:
+    """
+    Determines the column data type.
+
+    Args:
+    ----
+        data_col (pd.Series): The series that is evaluated
+        binary_name (str, optional): The string name to be used if the data is binary. Defaults to "binary".
+        category_name (str, optional): The string name to be used if the data is categorical (multi-class). Defaults to "categorical".
+        numerical_name (str, optional): The string name to be used if the data is numerical. Defaults to "numerical".
+        multi_class_name (str, optional): The string name to be used if the data is multi label. Defaults to "multi_class".
+        multi_regex (str, optional): If the values of the series contain this regex, the column is treated as multi label. Defaults to "[,;]".
+
+    Returns:
+    -------
+        str: the string type of the column
+    """
+    if data_col.nunique() == 2:
+        return binary_name
+    elif (data_col.nunique() > 2) & (data_col.dtypes == object):
+        if data_col.astype(str).str.contains(multi_regex).any():
+            return multi_class_name
+        return category_name
+    elif (data_col.nunique() > 2) & (
+        (data_col.dtypes == "int64") | (data_col.dtypes == "float64")
+    ):
+        return numerical_name
+
+
+def load_yaml_file(yaml_path: str):
+    """
+    Loads a yaml file into an object.
+
+    Args:
+    ----
+        yaml_path (str): the path to a yaml file
+
+    Returns:
+    -------
+        object: the loaded yaml file
+    """
+    with open(yaml_path) as f:
+        loaded_yaml = safe_load(f)
+    return loaded_yaml
+
+
+def create_single_label_task(
+    current_col_data: pd.Series,
+    entities_name: str = "symbol",
+    outcomes_name: str = "Outcomes",
+) -> tuple[pd.Series, pd.Series]:
+    """
+    Takes a series and creates a task from its index and values.
+
+    Args:
+    ----
+        current_col_data (pd.Series): the series to turn into a task
+        entities_name (str, optional): the name to be used for the entities. Defaults to "symbol".
+        outcomes_name (str, optional): the name to be used for the outcomes. Defaults to "Outcomes".
+
+    Returns:
+    -------
+        tuple[pd.Series, pd.Series]: entities and outcomes series
+    """
+    entities = pd.Series(current_col_data.index, name=entities_name)
+    outcomes = pd.Series(current_col_data.values, name=outcomes_name)
+    return entities, outcomes
+
+
+def tag_list_to_multi_label(
+    current_col_data: pd.Series, entities_name: str = "symbol", delimiter: str = ","
+) -> tuple[pd.Series, pd.DataFrame]:
+    """
+    Takes a series with entities in the index and a tag cloud of attributes as values
+    and converts it into a multi label task.
+
+    Args:
+    ----
+        current_col_data (pd.Series): Series with entities as index and the attribute list as values
+        entities_name (str, optional): Type of the entities. Defaults to "symbol".
+        delimiter (str, optional): The delimiter in the attribute cloud. Defaults to ",".
+
+    Returns:
+    -------
+        tuple[pd.Series, pd.DataFrame]: A tuple with the entities and a data frame where each column
+        is an attribute and the values represent the assignment to each attribute
+    """
+    split_values_df = current_col_data.apply(
+        lambda x: [item.strip() for item in x.split(delimiter)]
+    )
+    vocab = list(set(np.concatenate(split_values_df.values)))
+    outcome_df = pd.DataFrame(0, index=split_values_df.index, columns=vocab)
+    for index in split_values_df.index:
+        true_cat = split_values_df[index]
+        if not isinstance(true_cat, list):
+            true_cat = list(set(chain(*true_cat.values)))
+        else:
+            true_cat = list(set(true_cat))
+        outcome_df.loc[index, true_cat] = 1
+    entities = pd.Series(outcome_df.index, name=entities_name)
+    return entities, outcome_df
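A minimal usage sketch (not part of the commit) of how the relocated helpers compose; the gene symbols and tag values below are made up for illustration:

import pandas as pd

from scripts.tasks_retrival.task_retrieval import (
    check_data_type,
    create_single_label_task,
    tag_list_to_multi_label,
)

# An object column whose values contain "," or ";" is detected as multi label.
tags = pd.Series(
    ["Cilium,Flagellum", "Nucleus", "Cilium"],
    index=["GENE1", "GENE2", "GENE3"],
)
assert check_data_type(tags) == "multi_class"
entities, outcomes = tag_list_to_multi_label(tags)
# outcomes is a 0/1 frame with one column per tag: Cilium, Flagellum, Nucleus.

# Numeric columns take the single-label path instead.
expr = pd.Series([1.5, 2.0, 3.7], index=["GENE1", "GENE2", "GENE3"])
assert check_data_type(expr) == "numerical"
entities, outcomes = create_single_label_task(expr)
# dump_task_definitions(entities, outcomes, main_task_directory, task_name)
# would then write the entities/outcomes pair to disk, as create_tasks does.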
