Skip to content

Commit

Permalink
Merge branch 'main' of https://github.com/OATML-Markslab/ProteinGym into main
Browse files Browse the repository at this point in the history

Merging small filepath updates with updates for MIF scoring and ProtSSN
  • Loading branch information
danieldritter committed Mar 19, 2024
2 parents afd8157 + a653cea commit 495cc30
Show file tree
Hide file tree
Showing 9 changed files with 323 additions and 232 deletions.
4 changes: 3 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -196,4 +196,6 @@ If you use ProteinGym in your work, please cite the following paper:
```

## Links
Website: https://www.proteingym.org/
- Website: https://www.proteingym.org/
- NeurIPS proceedings: [link to abstract](https://papers.nips.cc/paper_files/paper/2023/hash/cac723e5ff29f65e3fcbb0739ae91bee-Abstract-Datasets_and_Benchmarks.html)
- Preprint: [link to abstract](https://www.biorxiv.org/content/10.1101/2023.12.07.570727v1)
1 change: 1 addition & 0 deletions proteingym/baselines/carp_mif/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from . import carp_mif_utils
38 changes: 38 additions & 0 deletions proteingym/baselines/carp_mif/carp_mif_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
import torch
from sequence_models.collaters import SimpleCollater, StructureCollater, BGCCollater
from sequence_models.pretrained import load_carp,load_gnn,MIF
from sequence_models.constants import PROTEIN_ALPHABET

CARP_URL = 'https://zenodo.org/record/6564798/files/'
MIF_URL = 'https://zenodo.org/record/6573779/files/'
BIG_URL = 'https://zenodo.org/record/6857704/files/'

def load_model_and_alphabet(model_name, model_dir=None):
if not model_name.endswith(".pt"):
if 'big' in model_name:
url = BIG_URL + '%s.pt?download=1' %model_name
elif 'carp' in model_name:
url = CARP_URL + '%s.pt?download=1' %model_name
elif 'mif' in model_name:
url = MIF_URL + '%s.pt?download=1' %model_name
model_data = torch.hub.load_state_dict_from_url(url, progress=False, map_location="cpu", model_dir=model_dir)
else:
model_data = torch.load(model_name, map_location="cpu")
if 'big' in model_data['model']:
pfam_to_domain = model_data['pfam_to_domain']
tokens = model_data['tokens']
collater = BGCCollater(tokens, pfam_to_domain)
else:
collater = SimpleCollater(PROTEIN_ALPHABET, pad=True)
if 'carp' in model_data['model']:
model = load_carp(model_data)
elif model_data['model'] in ['mif', 'mif-st']:
gnn = load_gnn(model_data)
cnn = None
if model_data['model'] == 'mif-st':
url = CARP_URL + '%s.pt?download=1' % 'carp_640M'
cnn_data = torch.hub.load_state_dict_from_url(url, progress=False, map_location="cpu")
cnn = load_carp(cnn_data)
collater = StructureCollater(collater, n_connections=30)
model = MIF(gnn, cnn=cnn)
return model, collater
40 changes: 27 additions & 13 deletions proteingym/baselines/carp_mif/compute_fitness.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,15 +10,17 @@
import torch
from torch.nn import CrossEntropyLoss

from sequence_models.pretrained import load_model_and_alphabet
from sequence_models.constants import PROTEIN_ALPHABET, PAD, MASK
from sequence_models.pdb_utils import parse_PDB, process_coords

from proteingym.baselines.carp_mif.carp_mif_utils import load_model_and_alphabet

def label_row(rows, sequence, token_probs, alphabet, offset_idx=1):
rows = rows.split(":")
score = 0
for row in rows:
wt, idx, mt = row[0], int(row[1:-1]) - offset_idx, row[-1]

assert sequence[idx] == wt, "The listed wildtype does not match the provided sequence"

wt_encoded, mt_encoded = alphabet.index(wt), alphabet.index(mt)
Expand Down Expand Up @@ -46,7 +48,7 @@ def process_batch_mif(prot,pdb_file,tokenizer,device='cuda:0'):
edge_mask = edge_mask.to(device)
return input_ids,nodes,edges,connections,edge_mask

def calc_fitness(model, DMS_data, tokenizer, device='cuda:0', model_context_len=1024, mode="masked_marginals", alphabet=PROTEIN_ALPHABET, mutation_col='mutant', target_seq=None, pdb_file=None, model_name=None):
def calc_fitness(model, DMS_data, tokenizer, device='cuda:0', model_context_len=1024, mode="masked_marginals", alphabet=PROTEIN_ALPHABET, mutation_col='mutant', target_seq=None, pdb_file=None, model_name=None, offset_idx=1):
if mode=="pseudo_likelihood":
prots=np.array(DMS_data['mutated_sequence'])
loss_fn = CrossEntropyLoss()
Expand Down Expand Up @@ -85,6 +87,7 @@ def calc_fitness(model, DMS_data, tokenizer, device='cuda:0', model_context_len=
target_seq,
token_probs,
PROTEIN_ALPHABET,
offset_idx
),
axis=1,
)
Expand Down Expand Up @@ -130,24 +133,35 @@ def main():
mapping_protein_seq_DMS = pd.read_csv(args.DMS_reference_file_path)
list_DMS = mapping_protein_seq_DMS["DMS_id"]
DMS_id=list_DMS[args.DMS_index]
if not os.path.exists(args.output_scores_folder): os.mkdir(args.output_scores_folder)
args.output_scores_folder = args.output_scores_folder + os.sep + args.model_name
if not os.path.exists(args.output_scores_folder): os.mkdir(args.output_scores_folder)
scoring_filename = args.output_scores_folder+os.sep+DMS_id+'.csv'
print("Computing scores for: {} with model: {}".format(DMS_id, args.model_name))

DMS_file_name = mapping_protein_seq_DMS["DMS_filename"][mapping_protein_seq_DMS["DMS_id"]==DMS_id].values[0]
target_seq = mapping_protein_seq_DMS["target_seq"][mapping_protein_seq_DMS["DMS_id"]==DMS_id].values[0].upper()
pdb_file = args.structure_data_folder + os.sep + mapping_protein_seq_DMS["pdb_file"][mapping_protein_seq_DMS["DMS_id"]==DMS_id].values[0]


DMS_data = pd.read_csv(args.DMS_data_folder + os.sep + DMS_file_name, low_memory=False)
DMS_data['mutated_sequence'] = DMS_data['mutant'].apply(lambda x: get_mutated_sequence(target_seq, x)) if not args.indel_mode else DMS_data['mutant']

model_scores = calc_fitness(model=model, DMS_data=DMS_data, tokenizer=tokenizer, mode=args.fitness_computation_mode, target_seq=target_seq, pdb_file=pdb_file, model_name=args.model_name)

if 'mif' in args.model_name:
pdb_filenames = mapping_protein_seq_DMS["pdb_file"][mapping_protein_seq_DMS["DMS_id"]==DMS_id].values[0].split('|') #if sequence is large (eg., BRCA2_HUMAN) the structure is split in several chunks
pdb_ranges = mapping_protein_seq_DMS["pdb_range"][mapping_protein_seq_DMS["DMS_id"]==DMS_id].values[0].split('|')
model_scores=[]
for pdb_index, pdb_filename in enumerate(pdb_filenames):
pdb_file = args.structure_data_folder + os.sep + pdb_filename
pdb_range = [int(x) for x in pdb_ranges[pdb_index].split("-")]
target_seq_split = target_seq[pdb_range[0]-1:pdb_range[1]] #pdb_range is 1-indexed
DMS_data["mutated_position"] = DMS_data['mutant'].apply(lambda x: int(x.split(':')[0][1:-1])) #if multiple mutant, will extract position of first mutant
filtered_DMS_data = DMS_data[(DMS_data["mutated_position"] >= pdb_range[0]) & (DMS_data["mutated_position"] <= pdb_range[1])]
model_scores.append(calc_fitness(model=model, DMS_data=filtered_DMS_data, tokenizer=tokenizer, mode=args.fitness_computation_mode, target_seq=target_seq_split, pdb_file=pdb_file, model_name=args.model_name, offset_idx=pdb_range[0]))
model_scores = np.concatenate(model_scores)
else:
model_scores = calc_fitness(model=model, DMS_data=DMS_data, tokenizer=tokenizer, mode=args.fitness_computation_mode, target_seq=target_seq, pdb_file=None, model_name=args.model_name)

DMS_data[args.model_name+'_score']=model_scores

if not os.path.exists(args.output_scores_folder): os.mkdir(args.output_scores_folder)
args.output_scores_folder = args.output_scores_folder + os.sep + args.model_name
if not os.path.exists(args.output_scores_folder): os.mkdir(args.output_scores_folder)
scoring_filename = args.output_scores_folder+os.sep+DMS_id+'.csv'
DMS_data[['mutant',args.model_name+'_score','DMS_score']].to_csv(scoring_filename, index=False)

spearman, _ = spearmanr(DMS_data[args.model_name+'_score'], DMS_data['DMS_score'])

if not os.path.exists(args.performance_file) or os.stat(args.performance_file).st_size==0:
Expand All @@ -157,4 +171,4 @@ def main():
performance_file.write(",".join([DMS_id,str(spearman)])+"\n")

if __name__ == '__main__':
main()
main()
3 changes: 3 additions & 0 deletions proteingym/utils/scoring_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,9 @@
unusual_AA ="OU" #Pyrrolysine O and selenocysteine U
indeterminate_AA = "BJXZ" #B = Asparagine or Aspartic acid; J = leucine or isoleucine; X = Any/Unknown ; Z = Glutamine or glutamic acid

def standardize(x, epsilon = 1e-10):
return (x - x.mean()) / (x.std() + epsilon)

def nanmean(v, *args, inplace=False, **kwargs):
if not inplace:
v = v.clone()
Expand Down
Loading

0 comments on commit 495cc30

Please sign in to comment.