-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtest_scorers.py
125 lines (98 loc) · 4.2 KB
/
test_scorers.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
import pytest
from rxn_negative_learning.models.scorers.ideal_scorer import IdealScorer
from rxn_negative_learning.models.scorers.levenstein_ideal_scorer import LevenshteinIdealScorer
from rxn_negative_learning.models.scorers.svm_scorer import SVMScorer
from rxn_negative_learning.models.scorers.tanimoto_ideal_scorer import TanimotoIdealScorer
from rxn_negative_learning.utils.repo_utils import models_directory
def test_ideal_scorer():
    """Check IdealScorer on reactions taken verbatim from its reference sets.

    Known-positive reactions are expected to score 1 and known-negative ones 0;
    with ``shift=0.5`` every expected score is moved down by 0.5.
    """
    positives = ["A.A>>B", "C.C>>D"]
    negatives = ["A.A>>C", "C.C>>B"]
    queries = ("A.A>>B", "C.C>>D", "A.A>>C", "C.C>>B")

    # Unshifted: positives -> 1, negatives -> 0.
    plain = IdealScorer(positives, negatives)
    assert plain(list(queries)) == [1, 1, 0, 0]

    # With shift
    shifted = IdealScorer(positives, negatives, shift=0.5)
    assert shifted(list(queries)) == [0.5, 0.5, -0.5, -0.5]
def test_levenshtein_ideal_scorer():
    """Check LevenshteinIdealScorer against a single positive reaction.

    Three scenarios are covered: plain scoring, scoring with ``shift=0.5``,
    and scoring grounded by an explicit list of negative reactions (a
    reaction in that list is expected to score 0 even when it would
    otherwise partially match the positive).
    """
    positives = ["ola>>ciao"]
    queries = ("ola>>ciao", "ola>>miao", "cola>>vola", "ola>>")

    plain = LevenshteinIdealScorer(positives)
    assert plain(list(queries)) == [1, pytest.approx(0.5), 0, 0]

    # With shift
    shifted = LevenshteinIdealScorer(positives, shift=0.5)
    assert shifted(list(queries)) == [0.5, pytest.approx(0.5 - 0.5), -0.5, -0.5]

    # With negative reactions grounding: "ola>>miao" drops from 0.5 to 0.
    negatives = ["ola>>miao", "ola>>wow"]
    grounded = LevenshteinIdealScorer(positives, negatives)
    grounded_queries = ["ola>>ciao", "ola>>miao", "cola>>vola", "ola>>wow", "ola>>"]
    assert grounded(grounded_queries) == [1, 0, 0, 0, 0]
def test_tanimoto_ideal_scorer():
    """Check TanimotoIdealScorer without and with ``shift``, and with negative grounding.

    The same four query reactions are scored in every scenario; only the
    scorer configuration changes.
    """
    positives = ["O=C1CCC(=O)N1Br.c1cn[nH]c1>>Brc1cn[nH]c1"]
    queries = (
        "O=C1CCC(=O)N1Br.c1cn[nH]c1>>Brc1cn[nH]c1",  # identical to the positive
        "O=C1CCC(=O)N1Br.c1cn[nH]c1>>",  # same reactants, empty product
        "CO.CCCC>>[Na+]",  # different reactants and product
        "O=C1CCC(=O)N1Br.c1cn[nH]c1>>Brc1ccn[nH]1",  # same reactants, different product
    )
    # Expected partial similarity of the last query to the positive reaction.
    partial = 0.40540540540540543

    scorer = TanimotoIdealScorer(positives)
    assert scorer(list(queries)) == [1, 0, 0, pytest.approx(partial)]

    # With shift
    scorer = TanimotoIdealScorer(positives, shift=0.5)
    assert scorer(list(queries)) == [0.5, -0.5, -0.5, pytest.approx(partial - 0.5)]

    # With negative grounding: the partially similar reaction is pinned to 0.
    negatives = ["O=C1CCC(=O)N1Br.c1cn[nH]c1>>Brc1ccn[nH]1"]
    scorer = TanimotoIdealScorer(positives, negatives)
    assert scorer(list(queries)) == [1, 0, 0, 0]
def test_svm_scorer():
    """Check SVMScorer loaded from the stored ``old/svm_scorer`` model.

    Three scenarios are covered on the same four query reactions:
    plain model scoring, scoring with ``shift=0.5``, and scoring grounded by
    explicit positive/negative reaction lists. In the grounded case the last
    query (listed as negative) is expected to score 0, although the model
    alone scores it 1.

    NOTE(review): requires the model artifact under ``models_directory()``;
    the test errors out if it is missing.
    """
    model_path = models_directory() / "old" / "svm_scorer"

    # Defined once so the grounding lists and the queries cannot drift apart.
    positive = "O=C1CCC(=O)N1Br.O=c1ccc(-c2ccncc2)c[nH]1>>O=c1[nH]cc(-c2ccncc2)cc1Br"
    negative = "O=C1CCC(=O)N1Br.O=c1ccc(-c2ccncc2)c[nH]1>>O=c1ccc(-c2ccncc2Br)c[nH]1"
    queries = (
        positive,
        "O=C1CCC(=O)N1Br.c1cn[nH]c1>>",
        "CO.CCCC>>[Na+]",
        negative,
    )

    svm_scorer = SVMScorer(model_path)
    assert svm_scorer(list(queries)) == [1, 0, 0, 1]

    # With shift: every expected score is moved down by 0.5.
    svm_scorer = SVMScorer(model_path, shift=0.5)
    assert svm_scorer(list(queries)) == [0.5, -0.5, -0.5, 0.5]

    # With grounding: the reaction listed as negative now scores 0.
    svm_scorer = SVMScorer(
        model_path=model_path,
        positive_reactions=[positive],
        negative_reactions=[negative],
    )
    assert svm_scorer(list(queries)) == [1, 0, 0, 0]