diff --git a/analyze_results.py b/analyze_results.py
new file mode 100644
index 000000000..d8084531b
--- /dev/null
+++ b/analyze_results.py
@@ -0,0 +1,154 @@
+import os
+import pickle
+import wandb
+import pandas as pd
+
+# Specify your project and entity (username or team name)
+project_name = 'scale_mol_gnns_fingerprinting'
+entity_name = 'ogb-lsc-comp'
+pickle_path = 'ogb-results/sweep_results_dict.pickle'
+csv_path = 'ogb-results/sweep_results_table.csv'
+
+# Reduction that defines the "best" value for each metric
+DEFINITION_OF_BETTER = {
+    'mae': min,
+    'r2': max,
+    'spearman': max,
+    'auroc': max,
+    'avpr': max
+}
+
+BENCHMARKS = {
+    "ogbg-molbace": "test_auroc",
+    "ogbg-molbbbp": "test_auroc",
+    "ogbg-moltox21": "test_auroc",
+    "ogbg-molclintox": "test_auroc",
+    "ogbg-moltoxcast": "test_auroc"
+}
+
+# Whether a run/sweep in a given state should be included in the analysis
+WANDB_STATES = {
+    'running': False,
+    'crashed': False,
+    'killed': False,
+    'failed': False,
+    'finished': True,
+}
+
+# If you want to order the columns in the table, specify the order here
+MODELS = [
+    'ogb_20240125_003747_10M',
+    'ogb_20240125_115530_10M'
+]
+
+
+def find_best_score_for_sweep(sweep):
+    mean_test_scores, std_test_scores = [], []
+    metric, def_of_better = None, None
+
+    for run in sweep.runs:
+
+        if WANDB_STATES[run.state] is False:
+            continue  # skip if crashed or unfinished
+
+        if metric is None or def_of_better is None:  # the dataset can't be extracted from a sweep, so get it from a run
+            metric = BENCHMARKS[run.config['dataset']]
+            def_of_better = DEFINITION_OF_BETTER[metric.split('_')[-1]]
+
+        if "statistics" in run.summary_metrics:
+            run_statistics = run.summary_metrics['statistics']
+            if metric in run_statistics:
+                mean_test_scores.append(run_statistics[metric]['mean'])
+                std_test_scores.append(run_statistics[metric]['std'])
+
+    # Use the appropriate reduction for the metric to get the best score in the sweep
+    best_mean_test_score = def_of_better(mean_test_scores) if len(mean_test_scores) else 'NaN'
+
+    # Get the index of best_mean_test_score to find the matching std_test_score
+    if best_mean_test_score != 'NaN':
+        index_of_best_score = mean_test_scores.index(best_mean_test_score)
+        best_std_test_score = std_test_scores[index_of_best_score]
+    else:
+        best_std_test_score = 'NaN'
+
+    return best_mean_test_score, best_std_test_score
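+
+# Illustrative sketch (hypothetical values, not called anywhere): how the
+# DEFINITION_OF_BETTER reduction picks the best sweep entry. Error metrics
+# such as MAE are minimised, ranking metrics such as AUROC are maximised.
+def _demo_definition_of_better():
+    mae_scores = [0.31, 0.28, 0.35]
+    auroc_scores = [0.71, 0.83, 0.79]
+    assert DEFINITION_OF_BETTER['mae'](mae_scores) == 0.28
+    assert DEFINITION_OF_BETTER['auroc'](auroc_scores) == 0.83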
+
+
+def load_results(file_path):
+    if os.path.exists(file_path):
+        with open(file_path, 'rb') as file:
+            return pickle.load(file)
+    return {}
+
+
+def save_results(results, file_path):
+    with open(file_path, 'wb') as file:
+        pickle.dump(results, file)
+
+
+def save_to_csv(results, csv_path=None):
+    # Prepare a list for DataFrame rows
+    data = []
+
+    # Iterate through each dataset in BENCHMARKS
+    for dataset in BENCHMARKS.keys():
+        mean_row = {'Metric': 'Mean', 'Dataset': dataset}
+        std_row = {'Metric': 'Std', 'Dataset': dataset}
+
+        # Iterate through results to fill the rows
+        for (model_name, result_dataset), values in results.items():
+            if dataset == result_dataset:
+                mean_row[model_name] = values['mean']
+                std_row[model_name] = values['std']
+
+        data.append(mean_row)
+        data.append(std_row)
+
+    # Convert the list of rows to a DataFrame
+    df = pd.DataFrame(data)
+
+    # Set the 'Metric' and 'Dataset' columns as a multi-index
+    df.set_index(['Metric', 'Dataset'], inplace=True)
+
+    # Handle unspecified order in MODELS or additional columns
+    ordered_columns = [model for model in MODELS if model in df.columns]
+    additional_columns = [model for model in df.columns if model not in MODELS]
+    final_columns_order = ordered_columns + additional_columns
+
+    df = df[final_columns_order]
+
+    if csv_path is not None:
+        df.to_csv(csv_path)
+
+    return df
+
+
+if __name__ == "__main__":
+
+    api = wandb.Api()
+
+    project = api.project(name=project_name, entity=entity_name)
+
+    results = load_results(pickle_path)
+    sweeps = project.sweeps()
+
+    # Only sweeps named "<model>|<dataset>" belong to this analysis
+    filtered_sweeps = [sweep for sweep in sweeps if "|" in sweep.name]
+    for idx, sweep in enumerate(filtered_sweeps):
+        model_name, dataset = sweep.name.split('|')
+        print(f"Sweep {idx + 1} / {len(filtered_sweeps)} - {model_name} - {dataset}")
+
+        if model_name not in MODELS:
+            print(f"Model {model_name} not selected for analysis. Skipping...")
+            continue
+
+        if (model_name, dataset) in results and model_name != 'SUPER':
+            print(f"Combination of ({model_name}, {dataset}) already exists in results. Skipping...")
+            continue
+
+        _ = sweep.load(force=True)  # this is needed, otherwise sweep.runs is an empty list
+        if WANDB_STATES[sweep.state.lower()] is False and model_name != 'SUPER':
+            print(f"Sweep state - {sweep.state.lower()} - continuing to the next one")
+            continue
+
+        mean_score, std_score = find_best_score_for_sweep(sweep)
+        results[(model_name, dataset)] = {"mean": mean_score, "std": std_score}
+        print(f"{mean_score=}, {std_score=}")
+
+        save_results(results, pickle_path)  # checkpoint after every sweep
+
+    save_to_csv(results, csv_path)
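+
+# Sketch of the resulting CSV layout (numbers are placeholders; the model
+# columns follow the MODELS list above):
+#
+# Metric,Dataset,ogb_20240125_003747_10M,ogb_20240125_115530_10M
+# Mean,ogbg-molbace,0.78,0.79
+# Std,ogbg-molbace,0.012,0.010
+# Mean,ogbg-molbbbp,...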
diff --git a/finetune_on_ogb_config.yaml b/finetune_on_ogb_config.yaml
new file mode 100644
index 000000000..fe1208e54
--- /dev/null
+++ b/finetune_on_ogb_config.yaml
@@ -0,0 +1,46 @@
+command:
+  - python
+  - finetune_on_ogb_fingerprints.py
+  - --model-name=${envvar:SWEEP_MODEL_NAME}
+  - --dataset=${envvar:SWEEP_DATASET}
+  - --fingerprints-path=${envvar:SWEEP_FINGERPRINTS_PATH}
+  - --num-cross-validation-folds=${envvar:SWEEP_CROSS_VALIDATION_FOLDS}
+  - --epochs=20
+  - ${args}
+entity: lmueller
+method: grid
+metric:
+  goal: minimize
+  name: statistics.test_loss.mean
+parameters:
+  combine-input:
+    values:
+      - concat
+      - none
+  depth:
+    values:
+      - 3
+      - 4
+  dropout-rate:
+    values:
+      - 0
+      - 0.1
+  hidden-dim:
+    values:
+      - 512
+      - 1024
+      - 2048
+  lr:
+    values:
+      - 0.00075
+      - 0.0003
+      - 0.0001
+  warmup-epochs:
+    values:
+      - 0
+      - 5
+  lr-schedule:
+    values:
+      - constant
+      - linear
+      - cosine
\ No newline at end of file
diff --git a/finetune_on_ogb_fingerprints.py b/finetune_on_ogb_fingerprints.py
new file mode 100644
index 000000000..93503b291
--- /dev/null
+++ b/finetune_on_ogb_fingerprints.py
@@ -0,0 +1,444 @@
+import math
+import wandb
+import argparse
+import json
+from copy import deepcopy
+import numpy as np
+
+import datamol as dm
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.optim as optim
+from torchsummary import summary
+from torch.utils.data import DataLoader, Dataset
+from sklearn.metrics import roc_auc_score, average_precision_score, r2_score, mean_absolute_error
+from scipy.stats import spearmanr
+
+from ogb.graphproppred import PygGraphPropPredDataset, Evaluator
+import pandas as pd
+
+SEEDS = [345374, 467039, 986009, 916060, 641316, 798438, 665204, 373079, 228395, 935414]
+
+
+# model stuff
+def train_one_epoch(model, dataloader, loss_fn, optimizer, task_type, epoch, fold):
+    model.train()
+    total_loss = 0
+    for inputs, targets in dataloader:
+
+        # Filter out samples whose fingerprints contain NaNs
+        nan_mask = ~torch.isnan(inputs).any(dim=1)
+        filtered_inputs = inputs[nan_mask]
+        filtered_targets = targets[nan_mask]
+
+        if len(filtered_inputs) > 0:
+            optimizer.zero_grad()
+            outputs = model(filtered_inputs.float())
+
+            # CrossEntropyLoss expects integer class targets; the other losses expect floats
+            if task_type == "classification":
+                filtered_targets = filtered_targets.long()
+            else:
+                filtered_targets = filtered_targets.float()
+
+            if task_type == "multi-class":
+                # Mask out missing labels (NaN != NaN)
+                nan_mask = (filtered_targets == filtered_targets)
+                loss = loss_fn(outputs[nan_mask].squeeze(), filtered_targets[nan_mask].squeeze())
+            else:
+                loss = loss_fn(outputs.squeeze(), filtered_targets)
+            loss.backward()
+            optimizer.step()
+            total_loss += loss.item()
+
+    loss = total_loss / len(dataloader)
+    wandb.log({'epoch': epoch + fold, 'train_loss': loss})
+    print(f"## Epoch {epoch+1} - Train Loss: {loss:.4f}")
+    return model
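+
+# Illustrative check (not called anywhere): the multi-task branch above masks
+# missing labels via the IEEE-754 rule that NaN != NaN, so (t == t) is False
+# exactly at the NaN positions.
+def _demo_nan_label_mask():
+    t = torch.tensor([1.0, float('nan'), 0.0])
+    assert torch.equal(t == t, ~torch.isnan(t))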
+
+
+def evaluate(model, dataloader, loss_fn, task_type, evaluation_type, epoch, fold, evaluator):
+    model.eval()
+    total_loss = 0
+    all_outputs = []  # For regression, store raw outputs
+    all_probs = []    # For classification, store probabilities
+    all_targets = []
+
+    with torch.no_grad():
+        for inputs, targets in dataloader:
+            outputs = model(inputs.float())
+
+            if task_type == "classification":
+                loss_targets = targets.long()
+            else:
+                loss_targets = targets.float()
+
+            if task_type == "multi-class":
+                nan_mask = (targets == targets)  # mask out missing labels
+                loss = loss_fn(outputs[nan_mask].squeeze(), loss_targets[nan_mask].squeeze())
+            else:
+                loss = loss_fn(outputs, loss_targets)
+
+            total_loss += loss.item()
+
+            if task_type == 'classification':
+                probs = torch.softmax(outputs, dim=1)[:, 1]
+                all_probs.extend(probs.tolist())
+                all_targets.extend(targets.tolist())
+
+            elif task_type == "multi-class":
+                all_targets.append(targets.view(outputs.shape))
+                all_outputs.append(outputs)
+            else:
+                # Ensure outputs are always in list format
+                outputs = outputs.squeeze()
+                if outputs.dim() == 0:  # outputs is a scalar
+                    all_outputs.extend(outputs.unsqueeze(0).tolist())
+                else:
+                    all_outputs.extend(outputs.tolist())
+
+                all_targets.extend(targets.tolist())
+
+    loss = total_loss / len(dataloader)
+    metrics = {f'{evaluation_type}_loss': loss}
+
+    if task_type == 'classification':
+        # Filter out NaNs
+        clean_indices = [i for i, x in enumerate(all_probs) if not np.isnan(x)]
+        all_probs = [all_probs[i] for i in clean_indices]
+        all_targets = [all_targets[i] for i in clean_indices]
+
+        auroc = roc_auc_score(all_targets, all_probs)
+        avpr = average_precision_score(all_targets, all_probs)  # equivalent to AUPRC
+        metrics.update({
+            f'{evaluation_type}_auroc': auroc,
+            f'{evaluation_type}_avpr': avpr,
+        })
+    elif task_type == "multi-class":
+        auroc = evaluator.eval(dict(
+            y_true=torch.cat(all_targets),
+            y_pred=torch.cat(all_outputs)
+        ))["rocauc"]
+        metrics.update({
+            f'{evaluation_type}_auroc': auroc,
+        })
+    else:
+        # Filter out NaNs
+        clean_indices = [i for i, x in enumerate(all_outputs) if not np.isnan(x)]
+        all_outputs = [all_outputs[i] for i in clean_indices]
+        all_targets = [all_targets[i] for i in clean_indices]
+
+        r2 = r2_score(all_targets, all_outputs)
+        mae = mean_absolute_error(all_targets, all_outputs)
+        spearman_corr, _ = spearmanr(all_targets, all_outputs)
+        metrics.update({
+            f'{evaluation_type}_r2': r2,
+            f'{evaluation_type}_mae': mae,
+            f'{evaluation_type}_spearman': spearman_corr,
+        })
+
+    if evaluation_type == 'val':
+        wandb.log({**metrics, 'epoch': epoch + fold})
+    else:
+        wandb.log({**metrics, 'fold': fold})
+
+    print(json.dumps(metrics, indent=5))
+
+    return metrics
+
+
+class Model(nn.Module):
+    def __init__(self, input_dim, depth=3, hidden_dim=512, activation_fn='relu', combine_input='concat',
+                 num_classes=None, num_tasks=0, dropout_rate=0.1, **kwargs):
+        super(Model, self).__init__()
+
+        if depth < 2:
+            raise ValueError("Depth must be at least 2")
+
+        if depth == 2 and combine_input == 'concat' and hidden_dim != input_dim:
+            raise ValueError("When depth is 2 and combine_input is 'concat', hidden_dim must match input_dim")
+
+        self.depth = depth
+        self.hidden_dim = hidden_dim
+        self.combine_input = combine_input
+        self.dropout = nn.Dropout(dropout_rate)
+        self.layers = nn.ModuleList()
+        self.batch_norms = nn.ModuleList()  # Batch normalization layers
+
+        # Determine the activation function
+        if activation_fn == 'relu':
+            self.activation_fn = F.relu
+        else:
+            raise NotImplementedError(f"Activation function {activation_fn} not implemented.")
+
+        # Create layers and batch normalization layers
+        for i in range(depth):
+            if i == 0:  # first layer
+                in_dim = input_dim
+                out_dim = hidden_dim
+            elif i == depth - 1:  # last layer
+                in_dim = input_dim + hidden_dim if self.combine_input == 'concat' else hidden_dim
+
+                if num_tasks == 0:
+                    out_dim = num_classes if num_classes is not None else 1
+                else:
+                    out_dim = num_tasks
+            else:  # intermediate layers
+                in_dim = hidden_dim
+                out_dim = hidden_dim
+            self.layers.append(nn.Linear(in_dim, out_dim))
+            if i != depth - 1:
+                self.batch_norms.append(nn.BatchNorm1d(hidden_dim))
+
+    def forward(self, x):
+        original_x = x
+        for i in range(self.depth):
+            x = self.layers[i](x)
+            if i < self.depth - 1:
+                x = self.batch_norms[i](x)
+                x = self.activation_fn(x)
+                x = self.dropout(x)
+
+            # Re-inject the raw fingerprint before the final layer
+            if self.combine_input == 'concat' and i == self.depth - 2:
+                x = torch.cat((x, original_x), dim=1)
+
+        if x.shape[1] == 1:  # if the final output dimension is 1, squeeze it for regression
+            x = x.squeeze(1)
+
+        return x
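+
+# Illustrative sketch (hypothetical dimensions, not called anywhere): a quick
+# shape check of the MLP above. With combine_input='concat', the raw
+# fingerprint is concatenated back in before the final layer, so the last
+# nn.Linear sees input_dim + hidden_dim features.
+def _demo_model_shapes():
+    model = Model(input_dim=128, depth=3, hidden_dim=64, num_classes=2)
+    model.eval()  # use BatchNorm running stats and disable dropout for the check
+    out = model(torch.randn(4, 128))
+    assert out.shape == (4, 2)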
+
+
+# factories
+def dataloader_factory(split_name, benchmark, split_idx, i2v, args, seed=42):
+    # NOTE: `seed` is currently unused; the splits come from OGB's fixed split indices
+    assert split_name == 'train+val' or split_name == 'test', "Wrong value for `split_name` argument passed to dataloader_factory"
+
+    def match_and_replace_input_column(samples_df):
+        # Replace each SMILES string with its pre-computed fingerprint vector
+        transformed_df = samples_df.copy()
+        transformed_df["smiles"] = transformed_df["smiles"].apply(
+            lambda s: i2v[dm.unique_id(s)].detach().numpy())
+        return transformed_df
+
+    class SingleInstancePredictionDataset(Dataset):
+        def __init__(self, samples_df, task_type):
+            self.samples = samples_df["smiles"].tolist()
+
+            # All OGB benchmarks used here are (multi-task) classification
+            target_columns = [col for col in samples_df.columns.values if col not in ["smiles", "mol_id"]]
+            self.targets = zip(*[samples_df[col].tolist() for col in target_columns])
+            if task_type == "multi-class":
+                self.targets = [[float(tgt) for tgt in target] for target in self.targets]
+                self.out_dim = len(target_columns)
+            elif task_type == "classification":
+                assert len(target_columns) == 1
+                self.targets = samples_df[target_columns[0]].tolist()
+                self.targets = [float(target) for target in self.targets]
+                self.out_dim = 0
+
+        def __len__(self):
+            return len(self.samples)
+
+        def __getitem__(self, idx):
+            sample = torch.tensor(self.samples[idx])
+            target = torch.tensor(self.targets[idx])
+            return sample, target
+
+    train_loader, val_loader, test_loader, input_dim, output_dim = None, None, None, None, None
+
+    if split_name == 'train+val':
+        train_split, val_split = benchmark.iloc[split_idx["train"].tolist()], benchmark.iloc[split_idx["valid"].tolist()]
+
+        train_samples = match_and_replace_input_column(train_split)
+        val_samples = match_and_replace_input_column(val_split)
+
+        train_dataset = SingleInstancePredictionDataset(train_samples, args.task_type)
+        val_dataset = SingleInstancePredictionDataset(val_samples, args.task_type)
+
+        train_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True, drop_last=True)
+        val_loader = DataLoader(val_dataset, batch_size=args.batch_size, shuffle=False)
+    else:
+        test_split = benchmark.iloc[split_idx["test"].tolist()]
+        test_samples = match_and_replace_input_column(test_split)
+
+        test_dataset = SingleInstancePredictionDataset(test_samples, args.task_type)
+
+        test_loader = DataLoader(test_dataset, batch_size=args.batch_size, shuffle=False)
+
+        input_dim = test_samples["smiles"].iloc[0].shape[0]
+        output_dim = test_dataset.out_dim
+
+    return train_loader, val_loader, test_loader, input_dim, output_dim
+
+
+def model_factory(args):
+    model = Model(**vars(args))
+
+    if args.task_type == "classification":
+        loss_fn = nn.CrossEntropyLoss()
+    elif args.task_type == "multi-class":
+        loss_fn = F.binary_cross_entropy_with_logits
+    else:
+        loss_fn = nn.MSELoss()
+
+    optimizer = optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
+    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
+    summary(model, input_size=(args.input_dim,), batch_size=args.batch_size)
+    return model, loss_fn, optimizer, trainable_params
+
+
+# optimiser stuff
+def l1_regularization(model, scale):
+    l1_loss = torch.tensor(0.0, requires_grad=True)
+    for param in model.parameters():
+        l1_loss = l1_loss + torch.norm(param, 1)  # out-of-place; += on a leaf tensor would raise
+    return scale * l1_loss
+
+
+def adjust_learning_rate(optimizer, epoch, args):
+    if epoch < args.warmup_epochs:
+        # Linear warmup
+        lr = args.lr * (epoch + 1) / args.warmup_epochs
+    elif args.lr_schedule == 'constant':
+        lr = args.lr
+    elif args.lr_schedule == 'linear':
+        # Linear decay
+        lr = args.lr * (1 - (epoch - args.warmup_epochs) / (args.epochs - args.warmup_epochs))
+    elif args.lr_schedule == 'cosine':
+        # Cosine decay
+        lr = args.lr * (1 + math.cos(math.pi * (epoch - args.warmup_epochs) / (args.epochs - args.warmup_epochs))) / 2
+
+    for param_group in optimizer.param_groups:
+        param_group['lr'] = lr
+
+    current_lr = optimizer.param_groups[0]['lr']
+    wandb.log({'epoch': epoch, 'lr_at_epoch': current_lr})
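+
+# Sketch of the schedule above (standalone arithmetic, no wandb; hypothetical
+# hyperparameters): with lr=0.1, epochs=10, warmup_epochs=2 the first two
+# epochs ramp linearly to 0.05 and 0.1, then the cosine branch decays the rate
+# from 0.1 towards 0.
+def _demo_cosine_schedule(lr=0.1, epochs=10, warmup_epochs=2):
+    values = []
+    for epoch in range(epochs):
+        if epoch < warmup_epochs:
+            values.append(lr * (epoch + 1) / warmup_epochs)
+        else:
+            values.append(lr * (1 + math.cos(math.pi * (epoch - warmup_epochs) / (epochs - warmup_epochs))) / 2)
+    return values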
+
+
+def aggregate_dicts(dicts):
+    aggr_dict = {}
+    for d in dicts:
+        for key, value in d.items():
+            # Scalar values are collected into lists
+            if not isinstance(value, list):
+                if key in aggr_dict:
+                    aggr_dict[key].append(value)
+                else:
+                    aggr_dict[key] = [value]
+            else:
+                # Values that are already lists are merged to avoid nesting
+                if key in aggr_dict:
+                    aggr_dict[key].extend(value)
+                else:
+                    aggr_dict[key] = value
+    return aggr_dict
+
+
+def calculate_statistics(aggr_dict):
+    result = {}
+    for key, values in aggr_dict.items():
+        min_val = min(values)
+        max_val = max(values)
+        mean_val = sum(values) / len(values) if values else 0
+        variance = sum((x - mean_val) ** 2 for x in values) / len(values) if len(values) > 1 else 0
+        std_val = variance ** 0.5
+        result[key] = {'min': min_val, 'max': max_val, 'mean': mean_val, 'std': std_val}
+    return result
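+
+# Illustrative sketch (not called anywhere): aggregating per-fold results and
+# computing the population statistics stored in the W&B run summary.
+def _demo_fold_statistics():
+    folds = [{'test_auroc': 0.7}, {'test_auroc': 0.9}]
+    aggregated = aggregate_dicts(folds)
+    assert aggregated == {'test_auroc': [0.7, 0.9]}
+    stats = calculate_statistics(aggregated)
+    assert abs(stats['test_auroc']['mean'] - 0.8) < 1e-9
+    assert abs(stats['test_auroc']['std'] - 0.1) < 1e-9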
+
+
+def main():
+    parser = argparse.ArgumentParser()
+    parser.add_argument('--model-name', type=str, default='default-model', help='Name of the model; used to construct the wandb run name')
+    parser.add_argument('--fingerprints-path', type=str, default='ogb-results/ids_to_fingerprint.pt', help='Path to ids_to_fingerprint.pt')
+    parser.add_argument('--dataset', type=str, default='ogbg-molbace', help='Name of the OGB benchmark dataset')
+    parser.add_argument('--epochs', type=int, default=5, help='Number of training epochs')
+    parser.add_argument('--split', type=float, default=0.1, help='Ratio of the validation set split')
+    parser.add_argument('--batch-size', type=int, default=32, help='Batch size for training and evaluation')
+    parser.add_argument('--num-cross-validation-folds', type=int, default=1, help='Number of cross-validation folds (one seed per fold)')
+    # Optimisation
+    parser.add_argument('--weight-decay', type=float, default=0.0001, help='Weight decay for the optimizer')
+    parser.add_argument('--lr', type=float, default=0.001, help='Learning rate for training')
+    parser.add_argument('--warmup-epochs', type=int, default=2, help='Number of warmup epochs')
+    parser.add_argument('--lr-schedule', type=str, default='constant', choices=['constant', 'linear', 'cosine'], help='Learning rate scheduling strategy')
+    # Model architecture
+    parser.add_argument('--depth', type=int, default=3, help='Depth of the model. Minimum 2. If 2, hidden_dim must equal the input dim.')
+    parser.add_argument('--hidden-dim', type=int, default=512, help='Dimension of hidden layers')
+    parser.add_argument('--activation-fn', type=str, default='relu', choices=['relu'], help='Activation function')
+    parser.add_argument('--combine-input', type=str, default='concat', choices=['concat', 'none'], help='Method to combine the raw input with the hidden representation')
+    parser.add_argument('--dropout-rate', type=float, default=0.1, help='Dropout rate')
+    # W&B
+    parser.add_argument('--wandb-off', action='store_false', help='Disable W&B logging')
+    parser.add_argument('--wandb-entity', type=str, default='ogb-lsc-comp', help='W&B entity')
+    parser.add_argument('--wandb-project', type=str, default='scaling_mol_gnns', help='W&B project')
+
+    args = parser.parse_args()
+    print(json.dumps(vars(args), indent=5))
+
+    # Load the id-to-fingerprint mapping
+    i2v = torch.load(args.fingerprints_path)
+
+    dataset = PygGraphPropPredDataset(
+        name=args.dataset,
+        root="ogb-data",
+    )
+
+    split_idx = dataset.get_idx_split()
+    benchmark = pd.read_csv(f"ogb-data/{args.dataset.replace('-', '_')}/mapping/mol.csv.gz")
+
+    evaluator = Evaluator(args.dataset)
+    metric = evaluator.eval_metric
+    assert metric == "rocauc"
+
+    # Determine task type and number of classes
+    if dataset.num_tasks == 1:
+        args.task_type, args.num_classes = "classification", dataset.num_classes
+    else:
+        args.task_type, args.num_classes = "multi-class", dataset.num_classes
+
+    _, _, test_dl, args.input_dim, args.num_tasks = dataloader_factory("test", benchmark, split_idx, i2v, args)
+
+    results = {}
+
+    for seed, fold in zip(SEEDS, range(args.num_cross_validation_folds)):
+        # Construct dataloaders
+        train_dl, val_dl, _, _, _ = dataloader_factory("train+val", benchmark, split_idx, i2v, args, seed=seed)
+
+        # Define a model
+        model, loss_fn, optimizer, args.trainable_params = model_factory(args)
+
+        # Initialize wandb (--wandb-off stores False and disables logging)
+        run_name = f"{args.model_name}_{args.dataset}"
+        mode = 'disabled' if args.wandb_off is False else None
+        wandb.init(project=args.wandb_project, entity=args.wandb_entity, name=run_name, mode=mode)
+        wandb.config.update(args)
+
+        # Test random model
+        epoch = 0
+        # evaluate(model, test_dl, loss_fn, args.task_type, evaluation_type='test', epoch=epoch)
+
+        best_epoch = {'val_results': None, 'model': None}
+        # Training and validation loop
+        for epoch in range(args.epochs):
+            print(f"## Fold {fold+1}/{args.num_cross_validation_folds} | Epoch {epoch+1}/{args.epochs}")
+            adjust_learning_rate(optimizer, epoch, args)
+            model = train_one_epoch(model, train_dl, loss_fn, optimizer, args.task_type, epoch, fold)
+            val_results = evaluate(model, val_dl, loss_fn, args.task_type, evaluation_type='val', epoch=epoch, fold=fold, evaluator=evaluator)
+
+            # Keep the model with the lowest validation loss (ties keep the earlier epoch)
+            if best_epoch['model'] is None or val_results['val_loss'] < best_epoch['val_results']['val_loss']:
+                best_epoch['model'] = deepcopy(model)
+                best_epoch['val_results'] = deepcopy(val_results)
+
+        # Test the trained model
+        eval_results = evaluate(best_epoch['model'], test_dl, loss_fn, args.task_type, evaluation_type='test', epoch=epoch, fold=fold, evaluator=evaluator)
+        results = aggregate_dicts(dicts=[results, best_epoch['val_results'], eval_results])
+
+    print(json.dumps(results, indent=5))
+    print(json.dumps(calculate_statistics(results), indent=5))
+    wandb.run.summary['statistics'] = calculate_statistics(results)
+    wandb.finish()
+
+
+if __name__ == "__main__":
+    main()
diff --git a/graphium/cli/get_ogb_fingerprints.py b/graphium/cli/get_ogb_fingerprints.py
new file mode 100644
index 000000000..e409f04f6
--- /dev/null
+++ b/graphium/cli/get_ogb_fingerprints.py
@@ -0,0 +1,170 @@
+import os
+
+import hydra
+import torch
+from lightning.pytorch.utilities.model_summary import ModelSummary
+from loguru import logger
+from omegaconf import DictConfig, OmegaConf
+
+from graphium.config._loader import (
+    load_accelerator,
+    load_mup,
+    load_datamodule,
+    get_checkpoint_path,
+)
+from graphium.trainer.predictor import PredictorModule
+
+from tqdm import tqdm
+from copy import deepcopy
+import datamol as dm
+import sys
+from torch_geometric.data import Batch
+from ogb.graphproppred import PygGraphPropPredDataset
+import pandas as pd
+
+TESTING_ONLY_CONFIG_KEY = "testing_only"
+
+DATASETS = [
+    "ogbg-molbace",
+    "ogbg-molbbbp",
+    "ogbg-moltox21",
+    "ogbg-molclintox",
+    "ogbg-moltoxcast",
+]
+
+
+@hydra.main(version_base=None, config_path="../../expts/hydra-configs", config_name="main")
+def cli(cfg: DictConfig) -> None:
+    """
+    The CLI endpoint for extracting fingerprints from a pre-trained Graphium model.
+    """
+    return get_final_fingerprints(cfg)
+
+
+def get_final_fingerprints(cfg: DictConfig) -> None:
+    """
+    Featurize the OGB molecules, run the pre-trained model on them and save the
+    resulting fingerprints.
+    """
+    # Get OGB SMILES strings
+    if not os.path.exists("saved_ogb_smiles.pt"):
+        ogb_mol_ids = set()
+        for dataset in DATASETS:
+            PygGraphPropPredDataset(root="ogb-data", name=dataset)
+            data = pd.read_csv(f"ogb-data/{dataset.replace('-', '_')}/mapping/mol.csv.gz")
+            ogb_mol_ids |= set(data["smiles"].apply(dm.unique_id))
+
+        smiles_to_process = []
+        ogb_mol_ids_to_find = deepcopy(ogb_mol_ids)
+
+        for dataset in tqdm(DATASETS, desc="Matching molecules to IDs", file=sys.stdout):
+            data = pd.read_csv(f"ogb-data/{dataset.replace('-', '_')}/mapping/mol.csv.gz")
+            mols = set(data["smiles"])
+            for smiles in mols:
+                mol_id = dm.unique_id(smiles)
+                if mol_id in ogb_mol_ids_to_find:
+                    smiles_to_process.append(smiles)
+                    ogb_mol_ids_to_find.remove(mol_id)
+
+        assert set(dm.unique_id(s) for s in smiles_to_process) == ogb_mol_ids
+        torch.save(smiles_to_process, "saved_ogb_smiles.pt")
+    else:
+        smiles_to_process = torch.load("saved_ogb_smiles.pt")
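+
+    # Note: dm.unique_id(smiles) is the join key between molecules and
+    # fingerprints throughout this patch. Illustrative check (assumes datamol
+    # maps equivalent SMILES to the same id):
+    #   assert dm.unique_id("CCO") == dm.unique_id("OCC")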
+
+    unresolved_cfg = OmegaConf.to_container(cfg, resolve=False)
+    cfg = OmegaConf.to_container(cfg, resolve=True)
+
+    ## == Instantiate all required objects from their respective configs ==
+    # Accelerator
+    cfg, accelerator_type = load_accelerator(cfg)
+    assert accelerator_type == "cpu", "get_ogb_fingerprints script only runs on CPU for now"
+
+    ## Data-module
+    datamodule = load_datamodule(cfg, accelerator_type)
+
+    # Featurize SMILES strings
+    input_features_save_path = "ogb_input_features.pt"
+    idx_none_save_path = "ogb_idx_none.pt"
+    if not os.path.exists(input_features_save_path):
+        input_features, idx_none = datamodule._featurize_molecules(smiles_to_process)
+
+        torch.save(input_features, input_features_save_path)
+        torch.save(idx_none, idx_none_save_path)
+    else:
+        input_features = torch.load(input_features_save_path)
+
+    failures = 0
+
+    # Cast to FP32 (failed featurizations are returned as strings and counted as failures)
+    for input_feature in tqdm(input_features, desc="Casting to FP32"):
+        try:
+            if not isinstance(input_feature, str):
+                for k, v in input_feature.items():
+                    if isinstance(v, torch.Tensor):
+                        if v.dtype == torch.half:
+                            input_feature[k] = v.float()
+                        elif v.dtype == torch.int32:
+                            input_feature[k] = v.long()
+            else:
+                failures += 1
+        except Exception as e:
+            print(f"{input_feature = }")
+            raise e
+
+    print(f"{failures = }")
+
+    # Load the pre-trained model
+    predictor = PredictorModule.load_pretrained_model(
+        name_or_path=get_checkpoint_path(cfg), device=accelerator_type
+    )
+    predictor = load_mup(mup_base_path=cfg['architecture']['mup_base_path'], predictor=predictor)
+
+    logger.info(predictor.model)
+    logger.info(ModelSummary(predictor, max_depth=4))
+
+    batch_size = 100
+
+    # Run the model to get fingerprints
+    results_folder = "ogb-results"
+    if not os.path.exists(results_folder):
+        os.makedirs(results_folder)
+    model_fp32 = predictor.model.float()  # cast once, outside the loop
+    for i, index in tqdm(enumerate(range(0, len(input_features), batch_size))):
+        batch = Batch.from_data_list(input_features[index:(index + batch_size)])
+        output, extras = model_fp32.forward(batch, extra_return_names=["pre_task_heads"])
+        fingerprint = extras['pre_task_heads']['graph_feat']
+        num_molecules = min(batch_size, fingerprint.shape[0])
+        results = [fingerprint[mol_idx] for mol_idx in range(num_molecules)]
+
+        torch.save(results, f'{results_folder}/res-{i:04}.pt')
+
+        if index == 0:
+            print(fingerprint.shape)
+
+    # Combine the per-batch results
+    all_results = []
+
+    for i, index in tqdm(enumerate(range(0, len(input_features), batch_size))):
+        results = torch.load(f'{results_folder}/res-{i:04}.pt')
+        all_results.extend(results)
+
+    del input_features
+
+    # Save the .pt files
+    suffix = '_' + unresolved_cfg['run_name_suffix'] if 'run_name_suffix' in unresolved_cfg else ''
+
+    torch.save(all_results, f"{results_folder}/results{suffix}.pt")
+
+    # Generate dictionary SMILES -> fingerprint vector
+    smiles_to_fingerprint = dict(zip(smiles_to_process, all_results))
+    torch.save(smiles_to_fingerprint, f"{results_folder}/smiles_to_fingerprint{suffix}.pt")
+
+    # Generate dictionary unique IDs -> fingerprint vector
+    ids = [dm.unique_id(smiles) for smiles in smiles_to_process]
+    ids_to_fingerprint = dict(zip(ids, all_results))
+    torch.save(ids_to_fingerprint, f"{results_folder}/ids_to_fingerprint{suffix}.pt")
+
+
+if __name__ == "__main__":
+    cli()
diff --git a/sweep_checkpoint.py b/sweep_checkpoint.py
new file mode 100644
index 000000000..da6ea3788
--- /dev/null
+++ b/sweep_checkpoint.py
@@ -0,0 +1,102 @@
+import os
+import wandb
+import yaml
+import time
+import argparse
+import random
+
+# CONST
+TDC_BENCHMARKS = [
+    'Caco2_Wang',
+    'Bioavailability_Ma',
+    'Lipophilicity_AstraZeneca',
+    'Solubility_AqSolDB',
+    'HIA_Hou',
+    'Pgp_Broccatelli',
+    'BBB_Martins',
+    'PPBR_AZ',
+    'VDss_Lombardo',
+    'CYP2C9_Veith',
+    'CYP2D6_Veith',
+    'CYP3A4_Veith',
+    'CYP2C9_Substrate_CarbonMangels',
+    'CYP2D6_Substrate_CarbonMangels',
+    'CYP3A4_Substrate_CarbonMangels',
+    'Half_Life_Obach',
+    'Clearance_Hepatocyte_AZ',
+    'Clearance_Microsome_AZ',
+    'LD50_Zhu',
+    'hERG',
+    'AMES',
+    'DILI'
+]
+
+OGB_BENCHMARKS = [
+    "ogbg-molbace",
+    "ogbg-molbbbp",
+    "ogbg-molclintox",
+    "ogbg-moltox21",
+    "ogbg-moltoxcast",
+]
+
+
+def create_sweep_and_get_id(sweep_name, yaml_file_path):
+    with open(yaml_file_path, 'r') as file:
+        sweep_config = yaml.safe_load(file)
+    sweep_config['name'] = sweep_name
+    return wandb.sweep(sweep=sweep_config, entity=os.getenv('WANDB_ENTITY'), project=os.getenv('WANDB_PROJECT'))
+
+
+def get_sweep_status(api, sweep_id):
+    # sweep_id is passed explicitly rather than read from the global scope
+    return api.sweep(f"{os.getenv('WANDB_ENTITY')}/{os.getenv('WANDB_PROJECT')}/{sweep_id}").state
+
+
+def get_sweep_id_by_name(api, sweep_name):
+    project = api.project(name=os.getenv('WANDB_PROJECT'), entity=os.getenv('WANDB_ENTITY'))
+    sweeps = project.sweeps()
+    for sweep in sweeps:
+        if sweep.name == sweep_name:
+            return sweep.id
+    return None
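+
+# Note: sweep names follow the "<model>|<dataset>" convention (e.g.
+# "10M|ogbg-molbace"); analyze_results.py filters sweeps on the same "|"
+# separator, so the two scripts must stay in sync.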
+
+
+# Example: python sweep_checkpoint.py --fingerprints-path ogb-results/ids_to_fingerprint.pt --benchmark ogb --wandb-project biomol-ogb
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser(description='Sweep Configuration')
+    parser.add_argument('--model-name', type=str, default='10M', help='Name of the sweep model')
+    parser.add_argument('--fingerprints-path', type=str, default='ogb-results/ids_to_fingerprint.pt', help='Path to the fingerprints file')
+    parser.add_argument('--benchmark', type=str, default='ogb', help='Benchmark suite (tdc or ogb)')
+    parser.add_argument('--wandb-entity', type=str, help='W&B entity')
+    parser.add_argument('--wandb-project', type=str, help='W&B project')
+    args = parser.parse_args()
+
+    os.environ['SWEEP_MODEL_NAME'] = args.model_name
+    os.environ['SWEEP_FINGERPRINTS_PATH'] = args.fingerprints_path
+    os.environ['WANDB_ENTITY'] = args.wandb_entity
+    os.environ['WANDB_PROJECT'] = args.wandb_project
+    os.environ['SWEEP_CROSS_VALIDATION_FOLDS'] = str(5)
+
+    if args.benchmark == "tdc":
+        yaml_file_path = "finetune_on_fingerprints_config.yaml"
+        benchmarks = TDC_BENCHMARKS
+    else:
+        yaml_file_path = "finetune_on_ogb_config.yaml"
+        benchmarks = OGB_BENCHMARKS
+
+    api = wandb.Api()
+    #random.shuffle(benchmarks)
+    for dataset in benchmarks:
+        os.environ['SWEEP_DATASET'] = dataset
+        sweep_name = f"{os.getenv('SWEEP_MODEL_NAME')}|{dataset}"
+
+        sweep_id = get_sweep_id_by_name(api, sweep_name)
+        if sweep_id is not None:
+            status = get_sweep_status(api, sweep_id)
+            if status == 'FINISHED':
+                print(f"Sweep '{sweep_name}' is already finished, moving to the next benchmark.")
+                continue
+        else:
+            sweep_id = create_sweep_and_get_id(sweep_name, yaml_file_path)
+            print(f"Created sweep with ID {sweep_id} for dataset {dataset}")
+
+        wandb.agent(sweep_id)
+        while get_sweep_status(api, sweep_id) != 'FINISHED':
+            time.sleep(100)  # check every 100 seconds whether the sweep has finished
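+
+# Assumed end-to-end order of the scripts (inferred from this patch, not
+# enforced anywhere):
+#   1. graphium/cli/get_ogb_fingerprints.py  -> extract fingerprints into ogb-results/
+#   2. sweep_checkpoint.py                   -> launch one W&B sweep per OGB dataset
+#   3. analyze_results.py                    -> collect the best sweep scores into a CSV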