Skip to content

Commit 104df44

Browse files
Author: Simona Rabinovici-Cohen
Commit message: pre-commit fixes
Commit: 104df44 (1 parent: be69eb1)

File tree

1 file changed: +19 additions, -19 deletions

mammal/examples/molnet/molnet_infer.py

Lines changed: 19 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,16 @@
1-
import numpy as np
21
import click
2+
import numpy as np
33
import torch
44
from fuse.data.tokenizers.modular_tokenizer.op import ModularTokenizerOp
55

6-
from mammal.model import Mammal
76
from mammal.keys import (
87
CLS_PRED,
98
ENCODER_INPUTS_ATTENTION_MASK,
109
ENCODER_INPUTS_STR,
1110
ENCODER_INPUTS_TOKENS,
1211
SCORES,
1312
)
13+
from mammal.model import Mammal
1414

1515
TASK_NAMES = ["BBBP", "TOXICITY", "FDA_APPR"]
1616

@@ -39,22 +39,22 @@ def load_model(task_name: str, device: str) -> dict:
3939
case "TOXICITY":
4040
path = "ibm/biomed.omics.bl.sm.ma-ted-458m.moleculenet_clintox_tox"
4141
case "FDA_APPR":
42-
path = "ibm/biomed.omics.bl.sm.ma-ted-458m.moleculenet_clintox_fda"
43-
case _:
44-
print(f"The {task_name=} is incorrect")
45-
42+
path = "ibm/biomed.omics.bl.sm.ma-ted-458m.moleculenet_clintox_fda"
43+
case _:
44+
print(f"The {task_name=} is incorrect")
45+
4646
# Load Model and set to evaluation mode
4747
model = Mammal.from_pretrained(path)
4848
model.eval()
4949
model.to(device=device)
50-
50+
5151
# Load Tokenizer
5252
tokenizer_op = ModularTokenizerOp.from_pretrained(path)
53-
53+
5454
task_dict = dict(
5555
task_name=task_name,
5656
model=model,
57-
tokenizer_op=tokenizer_op,
57+
tokenizer_op=tokenizer_op,
5858
)
5959
return task_dict
6060

@@ -65,7 +65,7 @@ def process_model_output(
6565
decoder_output_scores: np.ndarray,
6666
) -> dict:
6767
"""
68-
Extract predicted class and scores
68+
Extract predicted class and scores
6969
"""
7070
negative_token_id = tokenizer_op.get_token_id("<0>")
7171
positive_token_id = tokenizer_op.get_token_id("<1>")
@@ -79,10 +79,10 @@ def process_model_output(
7979
scores = decoder_output_scores[
8080
classification_position, positive_token_id
8181
]
82-
82+
8383
ans = dict(
8484
pred=label_id_to_int.get(int(decoder_output[classification_position]), -1),
85-
score=scores.item(),
85+
score=scores.item(),
8686
)
8787
return ans
8888

@@ -91,15 +91,15 @@ def task_infer(task_dict: dict, smiles_seq: str) -> dict:
9191
task_name = task_dict["task_name"]
9292
model = task_dict["model"]
9393
tokenizer_op = task_dict["tokenizer_op"]
94-
94+
9595
if task_name not in TASK_NAMES:
96-
print(f"The {task_name=} is incorrect. Valid names are {TASK_NAMES}")
97-
96+
print(f"The {task_name=} is incorrect. Valid names are {TASK_NAMES}")
97+
9898
# Create and load sample
9999
sample_dict = dict()
100100
# Formatting prompt to match pre-training syntax
101101
sample_dict[ENCODER_INPUTS_STR] = f"<@TOKENIZER-TYPE=SMILES><MOLECULAR_ENTITY><MOLECULAR_ENTITY_SMALL_MOLECULE><{task_name}><SENTINEL_ID_0><@TOKENIZER-TYPE=SMILES@MAX-LEN=2100><SEQUENCE_NATURAL_START>{smiles_seq}<SEQUENCE_NATURAL_END><EOS>"
102-
102+
103103
# Tokenize
104104
tokenizer_op(
105105
sample_dict=sample_dict,
@@ -116,14 +116,14 @@ def task_infer(task_dict: dict, smiles_seq: str) -> dict:
116116
output_scores=True,
117117
return_dict_in_generate=True,
118118
max_new_tokens=5,
119-
)
120-
119+
)
120+
121121
# Post-process the model's output
122122
result = process_model_output(
123123
tokenizer_op=tokenizer_op,
124124
decoder_output=batch_dict[CLS_PRED][0],
125125
decoder_output_scores=batch_dict[SCORES][0],
126-
)
126+
)
127127
return result
128128

129129

Comments: 0