Automated pipeline for evaluating pre-computed fingerprints #2

Open · wants to merge 48 commits into main from bb/finetuning-on-fingerprints

Changes from all commits · 48 commits
884f98e
download LargeMix datasets in a single command
blazejba Dec 11, 2023
8a35f54
add a script for finetuning on fingerprints on CPU
blazejba Dec 21, 2023
1c80129
fix target types for regression
blazejba Dec 21, 2023
63e54ad
add lr scheduling
blazejba Dec 22, 2023
8ab2887
sweeping script ready
blazejba Dec 22, 2023
e521b52
Delete sweep_finetunning_on_fingerprints.sh
blazejba Dec 22, 2023
e4fc5fa
various fixes
blazejba Dec 22, 2023
43495d8
Merge branch 'bb/finetuning-on-fingerprints' of github.com:graphcore-…
blazejba Dec 22, 2023
a501ad2
full TDC eval for a checkpoint implemented
blazejba Dec 22, 2023
b5d0fb6
add config + small fixes
blazejba Jan 3, 2024
20485bd
allow disabling wandb + drop last batch if not full
blazejba Jan 3, 2024
65cc2da
log number of trainable params to wandb
blazejba Jan 3, 2024
dee7d2e
add spearmanr metric to regression tasks
blazejba Jan 3, 2024
cbf847f
add fingerprint extraction
blazejba Jan 3, 2024
71dfc69
small bug fix + allow to add a suffix to the fingerprint filenames
blazejba Jan 4, 2024
a1e8e45
refactoring
blazejba Jan 4, 2024
5722964
auto extraction of scores from the fingerprinting sweeps
blazejba Jan 4, 2024
30b0a0f
fix a bug with regression eval + add worker to unfinished sweeps inst…
blazejba Jan 4, 2024
c3733f2
use args instead of consts for the sweeping script
blazejba Jan 4, 2024
e0fc687
fix small bugs
blazejba Jan 4, 2024
16c0b71
analyze_results now produces a table in a csv file that can be copied…
blazejba Jan 4, 2024
656410b
simplify the script by removing repeating stuff
blazejba Jan 5, 2024
9ba63c9
fix a bug in the sweeper
blazejba Jan 5, 2024
c4d2109
dump a csv with extracted scores from the sweeps
blazejba Jan 6, 2024
4237502
add weight decay + filter out nans
blazejba Jan 6, 2024
d5163cb
combine all results correctly
blazejba Jan 6, 2024
7a2fcdc
randomly choose which dataset to sweep to avoid worker collisions
blazejba Jan 6, 2024
98fb812
filter out samples with NaNs in training
blazejba Jan 8, 2024
f9451ba
order the csv table and include empty rows to match the excel template
blazejba Jan 8, 2024
d2e14d0
fix a bug with casting targets
blazejba Jan 8, 2024
c41a606
run finetuning with 5-fold scaffold split and report min,max,mean and…
blazejba Jan 8, 2024
6e54218
analyze 5-fold cross-validation results
blazejba Jan 9, 2024
a8f5876
script for analyzing the best hparams
blazejba Jan 11, 2024
282a944
fixing the metric for the metabolism datasets
blazejba Jan 11, 2024
0352b96
load mup in validation/fingerprint extraction
blazejba Jan 16, 2024
8880d0d
extract node and edge level features
blazejba Jan 17, 2024
41afd4e
yolo.py extends fingerprint training with node and edge level features
blazejba Jan 17, 2024
69a0a2c
data analysis
blazejba Jan 25, 2024
8b5c2a6
smaller sweeps + improved analytics
blazejba Jan 28, 2024
69f1fc1
add "fair" test score selection
blazejba Jan 30, 2024
074e009
ensemble evaluation
blazejba Jan 30, 2024
051e5a1
allow extracting node features as fingerprints
blazejba Jan 30, 2024
c5cbf5f
change program name
blazejba Jan 30, 2024
a06c802
remove Gradient link from README
kerstink-GC Mar 12, 2024
3fabe6f
update
blazejba Mar 26, 2024
b2cbe3d
cleaning
blazejba Apr 8, 2024
76aa26d
move files
blazejba Apr 8, 2024
abc4a4b
Merge branch 'bb/finetuning-on-fingerprints' of github.com:graphcore-…
blazejba Apr 8, 2024
5 changes: 0 additions & 5 deletions README.md
@@ -5,7 +5,6 @@
 
 ---
 
-[![Run on Gradient](https://assets.paperspace.io/img/gradient-badge.svg)](https://ipu.dev/sdGggS)
 [![PyPI](https://img.shields.io/pypi/v/graphium)](https://pypi.org/project/graphium/)
 [![Conda](https://img.shields.io/conda/v/conda-forge/graphium?label=conda&color=success)](https://anaconda.org/conda-forge/graphium)
 [![PyPI - Downloads](https://img.shields.io/pypi/dm/graphium)](https://pypi.org/project/graphium/)
@@ -34,10 +33,6 @@ A deep learning library focused on graph representation learning for real-world
 
 Visit https://graphium-docs.datamol.io/.
 
-[![Run on Gradient](https://assets.paperspace.io/img/gradient-badge.svg)](https://ipu.dev/sdGggS)
-
-You can try running Graphium on Graphcore IPUs for free on Gradient by clicking on the button above.
-
 ## Installation for developers
 
 ### For CPU and GPU developers
182 changes: 182 additions & 0 deletions analyze_best_hparams.py
@@ -0,0 +1,182 @@
import os
import json
from collections import Counter

import wandb
import pandas as pd
import matplotlib.pyplot as plt


# Specify your project and entity (username or team name)
project_name = 'scale_mol_gnns_fingerprinting'
entity_name = 'ogb-lsc-comp'
pickle_path = 'results/sweep_results_dict.pickle'
csv_path = 'results/sweep_results_table.csv'

# Whether a metric is better when minimized (mae) or maximized (all others)
DEFINITION_OF_BETTER = {
    'mae': min,
    'r2': max,
    'spearman': max,
    'auroc': max,
    'avpr': max
}

# Benchmark datasets (TDC ADMET group) mapped to the test metric used for ranking
BENCHMARKS = {
    'Caco2_Wang': 'test_mae',
    'Bioavailability_Ma': 'test_auroc',
    'Lipophilicity_AstraZeneca': 'test_mae',
    'Solubility_AqSolDB': 'test_mae',
    'HIA_Hou': 'test_auroc',
    'Pgp_Broccatelli': 'test_auroc',
    'BBB_Martins': 'test_auroc',
    'PPBR_AZ': 'test_mae',
    'VDss_Lombardo': 'test_spearman',
    'CYP2C9_Veith': 'test_auroc',
    'CYP2D6_Veith': 'test_auroc',
    'CYP3A4_Veith': 'test_auroc',
    'CYP2C9_Substrate_CarbonMangels': 'test_auroc',
    'CYP2D6_Substrate_CarbonMangels': 'test_auroc',
    'CYP3A4_Substrate_CarbonMangels': 'test_auroc',
    'Half_Life_Obach': 'test_spearman',
    'Clearance_Hepatocyte_AZ': 'test_spearman',
    'Clearance_Microsome_AZ': 'test_spearman',
    'LD50_Zhu': 'test_mae',
    'hERG': 'test_auroc',
    'AMES': 'test_auroc',
    'DILI': 'test_auroc'
}
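
# Illustrative sketch (hypothetical values): how the two tables above combine
# to decide which of two scores is better for a given dataset:
#   metric = BENCHMARKS['Caco2_Wang']                     # 'test_mae'
#   better = DEFINITION_OF_BETTER[metric.split('_')[-1]]  # min
#   better(0.30, 0.35)                                    # -> 0.30 (lower MAE wins)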

# Only runs/sweeps in a 'finished' state are included in the analysis
WANDB_STATES = {
    'running': False,
    'crashed': False,
    'finished': True
}


def plot_hparam_distribution(top1_hparams, topn_hparams, n=5, save_dir='hparam_plots'):
    """Plot side-by-side frequency bars for each hyperparameter (top-1 vs top-n runs)."""
    os.makedirs(save_dir, exist_ok=True)

    for key in top1_hparams.keys():
        top1_values = top1_hparams[key]
        topn_values = topn_hparams[key]

        # Create one figure with two side-by-side bar charts
        fig, axes = plt.subplots(nrows=1, ncols=2, figsize=(12, 5))
        fig.suptitle(f'Distribution of {key}')

        # Top-1 plot
        axes[0].bar(top1_values.keys(), top1_values.values())
        axes[0].set_title('Top 1')
        axes[0].set_xlabel(key)
        axes[0].set_ylabel('Frequency')
        axes[0].tick_params(axis='x', rotation=45)

        # Top-N plot
        axes[1].bar(topn_values.keys(), topn_values.values())
        axes[1].set_title(f'Top {n}')
        axes[1].set_xlabel(key)
        axes[1].set_ylabel('Frequency')
        axes[1].tick_params(axis='x', rotation=45)

        plt.tight_layout(rect=[0, 0.03, 1, 0.95])

        # Save the figure
        fig.savefig(os.path.join(save_dir, f'{key}_distribution.png'))
        plt.close(fig)


def initialize_hparams(sweeps):
    """Initialize hparams dictionaries with all possible parameter options set to zero."""
    all_params = {}
    for sweep in sweeps:
        params = sweep.config['parameters']
        for key, param in params.items():
            all_values = param.get('values', [])
            if key in all_params:
                all_params[key].update(all_values)
            else:
                all_params[key] = set(all_values)

    initialized_hparams = {key: Counter({val: 0 for val in values}) for key, values in all_params.items()}
    return initialized_hparams


def get_sweep_parameters(sweep):
    """Extract the set of parameters being swept from the sweep configuration."""
    return set(sweep.config['parameters'].keys())
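
# For reference (hypothetical shape): sweep.config['parameters'] looks like
#   {'lr': {'values': [1e-3, 1e-4]}, 'hidden_dim': {'values': [64, 128]}}
# so initialize_hparams seeds one Counter per key with every listed value at 0.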

def find_top_n_runs_for_sweep(sweep, n=5):
    """Return the n best finished runs of a sweep as (mean_score, swept_config) tuples."""
    swept_params = get_sweep_parameters(sweep)
    runs_data = []
    def_of_better = max  # default; overwritten per run below

    for run in sweep.runs:
        # Skip runs that are still running or crashed
        if not WANDB_STATES.get(run.state, False):
            continue

        metric = BENCHMARKS[run.config['dataset']]
        def_of_better = DEFINITION_OF_BETTER[metric.split('_')[-1]]

        run_statistics = run.summary_metrics['statistics']
        if metric in run_statistics:
            mean_score = run_statistics[metric]['mean']
            # Keep only the hyperparameters that were actually swept
            filtered_config = {k: v for k, v in run.config.items() if k in swept_params}
            runs_data.append((mean_score, filtered_config))

    # Sort and pick the top n; sort descending when higher is better
    runs_data.sort(key=lambda x: x[0], reverse=def_of_better is max)
    return runs_data[:n]

def update_hparams(hparams, runs):
    """Increment the frequency counters with the hyperparameter values of the given runs."""
    for _, config in runs:
        for key, value in config.items():
            if key in hparams:
                hparams[key][value] += 1
            else:
                hparams[key] = Counter({value: 1})
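
# For example (hypothetical values): after two top runs with {'lr': 1e-3} and
# {'lr': 1e-4}, hparams['lr'] would hold Counter({0.001: 1, 0.0001: 1}).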

# Only analyze sweeps whose names contain one of these substrings
keywords = ['40M-MPNN-easy-th', '11M-easy-th']  # Replace with your actual keywords


def any_keywords_present(sweep_name, keywords):
    """Check whether any of the keywords occurs in the sweep name."""
    return any(keyword in sweep_name for keyword in keywords)

if __name__ == "__main__":
    api = wandb.Api()
    project = api.project(name=project_name, entity=entity_name)

    n = 5
    sweeps = project.sweeps()

    # Initialize hparams counters with every value that appears in any sweep
    topn_hparams = initialize_hparams(sweeps)
    top1_hparams = initialize_hparams(sweeps)

    # Keep only sweeps named '<model>|<dataset>' that match one of the keywords
    filtered_sweeps = [sweep for sweep in sweeps if "|" in sweep.name and any_keywords_present(sweep.name, keywords)]
    for idx, sweep in enumerate(filtered_sweeps):
        model_name, dataset = sweep.name.split('|')
        print(f"Sweep {idx + 1} / {len(filtered_sweeps)} - {model_name} - {dataset}")

        _ = sweep.load(force=True)
        if not WANDB_STATES.get(sweep.state.lower(), False):
            print(f"Sweep state - {sweep.state.lower()} - continuing to the next one")
            continue

        top_n_runs = find_top_n_runs_for_sweep(sweep, n=n)
        update_hparams(topn_hparams, top_n_runs)

        if top_n_runs:
            top_1_run = [top_n_runs[0]]  # Taking the top 1 run
            update_hparams(top1_hparams, top_1_run)

    print("top1")
    print(json.dumps(top1_hparams, indent=5))
    print(f"top{n}")
    print(json.dumps(topn_hparams, indent=5))

    plot_hparam_distribution(top1_hparams, topn_hparams, n=n)
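
For context, a minimal standalone sketch of the ranking rule used in find_top_n_runs_for_sweep, with hypothetical runs and scores:

# Mirrors the sort direction logic with invented data; no wandb access needed.
DEFINITION_OF_BETTER = {'mae': min, 'auroc': max}
runs_data = [(0.81, {'lr': 1e-3}), (0.78, {'lr': 1e-4})]

def_of_better = DEFINITION_OF_BETTER['auroc']  # higher AUROC is better
runs_data.sort(key=lambda x: x[0], reverse=def_of_better is max)
print(runs_data[0])  # (0.81, {'lr': 0.001})

def_of_better = DEFINITION_OF_BETTER['mae']    # lower MAE is better
runs_data.sort(key=lambda x: x[0], reverse=def_of_better is max)
print(runs_data[0])  # (0.78, {'lr': 0.0001})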