@@ -46,19 +46,19 @@ def eval_lp_performance(self, dataset=List[Tuple[str, str, str]], filtered=True)
        return evaluate_lp(model=self.model, triple_idx=idx_dataset, num_entities=len(self.entity_to_idx),
                           er_vocab=None, re_vocab=None)

-    def predict_missing_head_entity(self, relation: List[str], tail_entity: List[str]) -> Tuple:
+    def predict_missing_head_entity(self, relation: Union[List[str], str], tail_entity: Union[List[str], str]) -> Tuple:
        """
        Given a relation and a tail entity, return top k ranked head entities.

        argmax_{e \in E} f(e, r, t), where r \in R, t \in E.

        Parameter
        ---------
-        relation: List[str]
+        relation: Union[List[str], str]

        String representation of selected relations.

-        tail_entity: List[str]
+        tail_entity: Union[List[str], str]

        String representation of selected entities.
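With this signature change, a single relation and tail entity can be passed as plain strings instead of single-element lists. A minimal usage sketch, assuming the class in this diff is the `KGE` wrapper from the dicee package and using a placeholder experiment path and labels:

    from dicee import KGE  # assumption: package and class name

    pre_trained_kge = KGE(path="Experiments/2023-01-01-12-00-00")  # placeholder path to a trained model
    # Previously both arguments had to be single-element lists:
    scores = pre_trained_kge.predict_missing_head_entity(relation=["bornIn"], tail_entity=["Berlin"])
    # Now plain strings work as well:
    scores = pre_trained_kge.predict_missing_head_entity(relation="bornIn", tail_entity="Berlin")
    # Either way, one raw score per entity in the vocabulary is returned.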
@@ -74,14 +74,22 @@ def predict_missing_head_entity(self, relation: List[str], tail_entity: List[str
        """
        head_entity = torch.arange(0, len(self.entity_to_idx))
-        relation = torch.LongTensor([self.relation_to_idx[i] for i in relation])
-        tail_entity = torch.LongTensor([self.entity_to_idx[i] for i in tail_entity])
+        if isinstance(relation, list):
+            relation = torch.LongTensor([self.relation_to_idx[i] for i in relation])
+        else:
+            relation = torch.LongTensor([self.relation_to_idx[relation]])
+        if isinstance(tail_entity, list):
+            tail_entity = torch.LongTensor([self.entity_to_idx[i] for i in tail_entity])
+        else:
+            tail_entity = torch.LongTensor([self.entity_to_idx[tail_entity]])
+
        x = torch.stack((head_entity,
                         relation.repeat(self.num_entities, ),
                         tail_entity.repeat(self.num_entities, )), dim=1)
        return self.model.forward(x)

-    def predict_missing_relations(self, head_entity: List[str], tail_entity: List[str]) -> Tuple:
+    def predict_missing_relations(self, head_entity: Union[List[str], str],
+                                  tail_entity: Union[List[str], str]) -> Tuple:
        """
        Given a head entity and a tail entity, return top k ranked relations.
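The rewritten body scores all heads in one forward pass: every candidate head index is stacked with the repeated relation and tail indices into a (num_entities, 3) tensor. A toy illustration of what `torch.stack` produces, assuming 4 entities, relation index 2 and tail index 3:

    import torch
    head_entity = torch.arange(0, 4)      # tensor([0, 1, 2, 3])
    relation = torch.LongTensor([2])
    tail_entity = torch.LongTensor([3])
    x = torch.stack((head_entity,
                     relation.repeat(4, ),
                     tail_entity.repeat(4, )), dim=1)
    # x = tensor([[0, 2, 3],
    #             [1, 2, 3],
    #             [2, 2, 3],
    #             [3, 2, 3]])  -> one (head, relation, tail) row per candidate head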
@@ -109,19 +117,23 @@ def predict_missing_relations(self, head_entity: List[str], tail_entity: List[st
        Highest K scores and entities
        """
-        head_entity = torch.LongTensor([self.entity_to_idx[i] for i in head_entity])
        relation = torch.arange(0, len(self.relation_to_idx))
-        tail_entity = torch.LongTensor([self.entity_to_idx[i] for i in tail_entity])

+        if isinstance(head_entity, list):
+            head_entity = torch.LongTensor([self.entity_to_idx[i] for i in head_entity])
+        else:
+            head_entity = torch.LongTensor([self.entity_to_idx[head_entity]])
+        if isinstance(tail_entity, list):
+            tail_entity = torch.LongTensor([self.entity_to_idx[i] for i in tail_entity])
+        else:
+            tail_entity = torch.LongTensor([self.entity_to_idx[tail_entity]])
        x = torch.stack((head_entity.repeat(self.num_relations, ),
                         relation,
                         tail_entity.repeat(self.num_relations, )), dim=1)
        return self.model(x)
-        # scores = self.model(x)
-        # sort_scores, sort_idxs = torch.topk(scores, topk)
-        # return sort_scores, [self.idx_to_relations[i] for i in sort_idxs.tolist()]

-    def predict_missing_tail_entity(self, head_entity: List[str], relation: List[str]) -> torch.FloatTensor:
+    def predict_missing_tail_entity(self, head_entity: Union[List[str], str],
+                                    relation: Union[List[str], str]) -> torch.FloatTensor:
        """
        Given a head entity and a relation, return top k ranked entities
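The commented-out lines removed in the hunk above show how top-k relations were once read off the raw scores; the same post-processing can still be done by the caller. A sketch continuing the earlier `pre_trained_kge` example, where `idx_to_relations` is a hypothetical index-to-label mapping:

    scores = pre_trained_kge.predict_missing_relations(head_entity="Berlin", tail_entity="Germany")
    sort_scores, sort_idxs = torch.topk(scores, k=10)
    top_relations = [idx_to_relations[i] for i in sort_idxs.tolist()]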
@@ -143,21 +155,38 @@ def predict_missing_tail_entity(self, head_entity: List[str], relation: List[str

        scores
        """
-        x = torch.cat((torch.LongTensor([self.entity_to_idx[i] for i in head_entity]).unsqueeze(-1),
-                       torch.LongTensor([self.relation_to_idx[i] for i in relation]).unsqueeze(-1)), dim=1)
+        tail_entity = torch.arange(0, len(self.entity_to_idx))
+
+        if isinstance(head_entity, list):
+            head_entity = torch.LongTensor([self.entity_to_idx[i] for i in head_entity])
+        else:
+            head_entity = torch.LongTensor([self.entity_to_idx[head_entity]])
+        if isinstance(relation, list):
+            relation = torch.LongTensor([self.relation_to_idx[i] for i in relation])
+        else:
+            relation = torch.LongTensor([self.relation_to_idx[relation]])
+
+        x = torch.stack((head_entity.repeat(self.num_entities, ),
+                         relation.repeat(self.num_entities, ),
+                         tail_entity), dim=1)
        return self.model.forward(x)

-    def predict(self, *, h: List[str] = None, r: List[str] = None, t: List[str] = None):
+    def predict(self, *, h: Union[List[str], str] = None, r: Union[List[str], str] = None,
+                t: Union[List[str], str] = None) -> torch.FloatTensor:
+        """
+        Predict missing triples, i.e. return scores over all candidate heads, relations, or tails given the other two arguments, or the score of a fully specified triple.
+        """
        # (1) Sanity checking.
        if h is not None:
-            assert isinstance(h, list)
+            assert isinstance(h, list) or isinstance(h, str)
            assert isinstance(h[0], str)
        if r is not None:
-            assert isinstance(r, list)
+            assert isinstance(r, list) or isinstance(r, str)
            assert isinstance(r[0], str)
        if t is not None:
-            assert isinstance(t, list)
+            assert isinstance(t, list) or isinstance(t, str)
            assert isinstance(t[0], str)
+
        # (2) Predict missing head entity given a relation and a tail entity.
        if h is None:
            assert r is not None
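In the hunk above, the body of `predict_missing_tail_entity` now mirrors the head-prediction case: instead of feeding only (head, relation) index pairs to the model, it builds an explicit (head, relation, tail) row for every candidate tail. Usage is symmetric to the earlier sketches (placeholder labels):

    tail_scores = pre_trained_kge.predict_missing_tail_entity(head_entity="Alice", relation="bornIn")
    # one raw score per entity in the vocabulary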
@@ -177,7 +206,6 @@ def predict(self, *, h: List[str] = None, r: List[str] = None, t: List[str] = No
            # h r ?
            scores = self.predict_missing_tail_entity(h, r)
        else:
-            assert len(h) == len(r) == len(t)
            scores = self.triple_score(h, r, t)
        return torch.sigmoid(scores)
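Taken together, `predict` dispatches on which argument is missing and squashes the resulting scores with a sigmoid: no head scores all heads, no relation all relations, no tail all tails, and a fully specified triple is scored via `triple_score`. A short usage sketch continuing the earlier `pre_trained_kge` example (all labels are placeholders):

    p_heads = pre_trained_kge.predict(r="bornIn", t="Berlin")              # over all head entities
    p_relations = pre_trained_kge.predict(h="Alice", t="Berlin")           # over all relations
    p_tails = pre_trained_kge.predict(h="Alice", r="bornIn")               # over all tail entities
    p_triple = pre_trained_kge.predict(h="Alice", r="bornIn", t="Berlin")  # one fully specified triple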
@@ -261,8 +289,8 @@ def predict_topk(self, *, h: List[str] = None, r: List[str] = None, t: List[str]
        else:
            raise AttributeError('Use triple_score method')

-    def triple_score(self, h: List[str] = None, r: List[str] = None,
-                     t: List[str] = None, logits=False) -> torch.FloatTensor:
+    def triple_score(self, h: Union[List[str], str] = None, r: Union[List[str], str] = None,
+                     t: Union[List[str], str] = None, logits=False) -> torch.FloatTensor:
        """
        Predict triple score
@@ -289,9 +317,14 @@ def triple_score(self, h: List[str] = None, r: List[str] = None,

        pytorch tensor of triple score
        """
-        h = torch.LongTensor([self.entity_to_idx[i] for i in h]).reshape(len(h), 1)
-        r = torch.LongTensor([self.relation_to_idx[i] for i in r]).reshape(len(r), 1)
-        t = torch.LongTensor([self.entity_to_idx[i] for i in t]).reshape(len(t), 1)
+        if isinstance(h, list) and isinstance(r, list) and isinstance(t, list):
+            h = torch.LongTensor([self.entity_to_idx[i] for i in h]).reshape(len(h), 1)
+            r = torch.LongTensor([self.relation_to_idx[i] for i in r]).reshape(len(r), 1)
+            t = torch.LongTensor([self.entity_to_idx[i] for i in t]).reshape(len(t), 1)
+        else:
+            h = torch.LongTensor([self.entity_to_idx[h]]).reshape(1, 1)
+            r = torch.LongTensor([self.relation_to_idx[r]]).reshape(1, 1)
+            t = torch.LongTensor([self.entity_to_idx[t]]).reshape(1, 1)

        x = torch.hstack((h, r, t))
        if self.apply_semantic_constraint:
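With the added branch, `triple_score` accepts a single triple given as three plain strings, while equal-length lists still score a batch of triples row by row. A brief sketch (placeholder labels, continuing the earlier example):

    s_one = pre_trained_kge.triple_score(h="Alice", r="bornIn", t="Berlin")
    s_batch = pre_trained_kge.triple_score(h=["Alice", "Bob"], r=["bornIn", "livesIn"], t=["Berlin", "Paris"])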
@@ -343,7 +376,8 @@ def negnorm(self, tens_1: torch.Tensor, lambda_: float, neg_norm: str = 'standar
    def __single_hop_query_answering(self, query: Tuple[str, Tuple[str, ...]]):
        head, relation = query
        assert len(relation) == 1
-        return self.predict(h=[head], r=[relation[0]])
+        # scores for all entities
+        return self.predict(h=head, r=relation[0])

    def __return_answers_and_scores(self, query_score_of_all_entities, k: int):
        query_score_of_all_entities = [(ei, s) for ei, s in zip(self.entity_to_idx.keys(), query_score_of_all_entities)]
@@ -443,10 +477,10 @@ def answer_multi_hop_query(self, query_type: str = None, query: Tuple[Union[str,
                    tnorm=tnorm,
                    k=k):
                top_k_scores1.append(score_of_e_r1_a)
-                # () Scores for all entities E
-                atom2_scores.append(self.predict(h=[top_k_entity], r=[relation2]))
+                # (.) Scores for all entities E
+                atom2_scores.append(self.predict(h=top_k_entity, r=relation2))
            # k by E tensor
-            atom2_scores = torch.cat(atom2_scores, dim=0)
+            atom2_scores = torch.vstack(atom2_scores)
            topk_scores1_expanded = torch.FloatTensor(top_k_scores1).view(-1, 1).repeat(1, atom2_scores.shape[1])
            query_scores, _ = torch.max(self.t_norm(topk_scores1_expanded, atom2_scores, tnorm), dim=0)
            if only_scores:
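The switch from `torch.cat(..., dim=0)` to `torch.vstack` matters here because `predict(h=..., r=...)` returns a 1-D vector of scores over the E entities for each of the k intermediate entities: `cat` along dim 0 would flatten them into one length k*E vector, while `vstack` yields the k-by-E matrix that the subsequent `repeat` and t-norm step expect. A toy check with E = 3 and k = 2:

    import torch
    a = torch.tensor([0.1, 0.7, 0.2])   # scores over 3 entities for the 1st intermediate entity
    b = torch.tensor([0.5, 0.3, 0.9])   # scores over 3 entities for the 2nd intermediate entity
    torch.cat([a, b], dim=0).shape      # torch.Size([6])    -> per-entity alignment lost
    torch.vstack([a, b]).shape          # torch.Size([2, 3]) -> k by E, as expected downstream

The same replacement is applied in the following hunk.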
@@ -468,7 +502,7 @@ def answer_multi_hop_query(self, query_type: str = None, query: Tuple[Union[str,
                # () Scores for all entities E
                atom2_scores.append(self.predict(h=[top_k_entity], r=[relation3]))
            # k by E tensor
-            atom2_scores = torch.cat(atom2_scores, dim=0)
+            atom2_scores = torch.vstack(atom2_scores)
            topk_scores1_expanded = torch.FloatTensor(top_k_scores1).view(-1, 1).repeat(1, atom2_scores.shape[1])
            query_scores, _ = torch.max(self.t_norm(topk_scores1_expanded, atom2_scores, tnorm), dim=0)
            if only_scores: