1
- import numpy as np
2
1
import click
2
+ import numpy as np
3
3
import torch
4
4
from fuse .data .tokenizers .modular_tokenizer .op import ModularTokenizerOp
5
5
6
- from mammal .model import Mammal
7
6
from mammal .keys import (
8
7
CLS_PRED ,
9
8
ENCODER_INPUTS_ATTENTION_MASK ,
10
9
ENCODER_INPUTS_STR ,
11
10
ENCODER_INPUTS_TOKENS ,
12
11
SCORES ,
13
12
)
13
+ from mammal .model import Mammal
14
14
15
15
# Valid task identifiers for this script. load_model maps each name to a
# pretrained checkpoint path (see its match/case), and task_infer warns when
# given a name outside this list. NOTE(review): the names appear to be
# MoleculeNet benchmark tasks (e.g. ClinTox toxicity/FDA approval, per the
# checkpoint paths) — confirm against the model cards.
TASK_NAMES = ["BBBP" , "TOXICITY" , "FDA_APPR" ]
16
16
@@ -39,22 +39,22 @@ def load_model(task_name: str, device: str) -> dict:
39
39
case "TOXICITY" :
40
40
path = "ibm/biomed.omics.bl.sm.ma-ted-458m.moleculenet_clintox_tox"
41
41
case "FDA_APPR" :
42
- path = "ibm/biomed.omics.bl.sm.ma-ted-458m.moleculenet_clintox_fda"
43
- case _:
44
- print (f"The { task_name = } is incorrect" )
45
-
42
+ path = "ibm/biomed.omics.bl.sm.ma-ted-458m.moleculenet_clintox_fda"
43
+ case _:
44
+ print (f"The { task_name = } is incorrect" )
45
+
46
46
# Load Model and set to evaluation mode
47
47
model = Mammal .from_pretrained (path )
48
48
model .eval ()
49
49
model .to (device = device )
50
-
50
+
51
51
# Load Tokenizer
52
52
tokenizer_op = ModularTokenizerOp .from_pretrained (path )
53
-
53
+
54
54
task_dict = dict (
55
55
task_name = task_name ,
56
56
model = model ,
57
- tokenizer_op = tokenizer_op ,
57
+ tokenizer_op = tokenizer_op ,
58
58
)
59
59
return task_dict
60
60
@@ -65,7 +65,7 @@ def process_model_output(
65
65
decoder_output_scores : np .ndarray ,
66
66
) -> dict :
67
67
"""
68
- Extract predicted class and scores
68
+ Extract predicted class and scores
69
69
"""
70
70
negative_token_id = tokenizer_op .get_token_id ("<0>" )
71
71
positive_token_id = tokenizer_op .get_token_id ("<1>" )
@@ -79,10 +79,10 @@ def process_model_output(
79
79
scores = decoder_output_scores [
80
80
classification_position , positive_token_id
81
81
]
82
-
82
+
83
83
ans = dict (
84
84
pred = label_id_to_int .get (int (decoder_output [classification_position ]), - 1 ),
85
- score = scores .item (),
85
+ score = scores .item (),
86
86
)
87
87
return ans
88
88
@@ -91,15 +91,15 @@ def task_infer(task_dict: dict, smiles_seq: str) -> dict:
91
91
task_name = task_dict ["task_name" ]
92
92
model = task_dict ["model" ]
93
93
tokenizer_op = task_dict ["tokenizer_op" ]
94
-
94
+
95
95
if task_name not in TASK_NAMES :
96
- print (f"The { task_name = } is incorrect. Valid names are { TASK_NAMES } " )
97
-
96
+ print (f"The { task_name = } is incorrect. Valid names are { TASK_NAMES } " )
97
+
98
98
# Create and load sample
99
99
sample_dict = dict ()
100
100
# Formatting prompt to match pre-training syntax
101
101
sample_dict [ENCODER_INPUTS_STR ] = f"<@TOKENIZER-TYPE=SMILES><MOLECULAR_ENTITY><MOLECULAR_ENTITY_SMALL_MOLECULE><{ task_name } ><SENTINEL_ID_0><@TOKENIZER-TYPE=SMILES@MAX-LEN=2100><SEQUENCE_NATURAL_START>{ smiles_seq } <SEQUENCE_NATURAL_END><EOS>"
102
-
102
+
103
103
# Tokenize
104
104
tokenizer_op (
105
105
sample_dict = sample_dict ,
@@ -116,14 +116,14 @@ def task_infer(task_dict: dict, smiles_seq: str) -> dict:
116
116
output_scores = True ,
117
117
return_dict_in_generate = True ,
118
118
max_new_tokens = 5 ,
119
- )
120
-
119
+ )
120
+
121
121
# Post-process the model's output
122
122
result = process_model_output (
123
123
tokenizer_op = tokenizer_op ,
124
124
decoder_output = batch_dict [CLS_PRED ][0 ],
125
125
decoder_output_scores = batch_dict [SCORES ][0 ],
126
- )
126
+ )
127
127
return result
128
128
129
129
0 commit comments