Skip to content

Commit

Permalink
Handle literals
Browse files (browse the repository at this point in the history)
  • Loading branch information
sshivam95 committed Nov 29, 2024
1 parent 85e3f79 commit 4d64660
Showing 1 changed file with 46 additions and 9 deletions.
55 changes: 46 additions & 9 deletions dicee/query_generator.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -6,6 +6,7 @@
import pickle
from copy import deepcopy
from .static_funcs import save_pickle, load_pickle
import re


class QueryGenerator:
Expand Down Expand Up @@ -73,23 +74,59 @@ def construct_graph(self, paths: List[str]) -> Tuple[Dict, Dict]:
"""
Construct graph from triples
Returns dicts with incoming and outgoing edges
"""
"""
# Mapping from tail entity and a relation to heads.
tail_relation_to_heads = defaultdict(lambda: defaultdict(set))
# Mapping from head and relation to tails.
head_relation_to_tails = defaultdict(lambda: defaultdict(set))

for path in paths:
import shlex

with open(path, "r") as f:
for line in f:
try:
h, r, t = map(str, line.strip().split("\t"))
except:
h, r, t, _ = map(str, line.strip().split(" "))
if t.startswith('"'):
continue # Skip literals
tail_relation_to_heads[self.ent2id[t]][self.rel2id[r]].add(self.ent2id[h])
head_relation_to_tails[self.ent2id[h]][self.rel2id[r]].add(self.ent2id[t])
line = line.strip()
# Skip empty lines or comments
if not line or line.startswith('#'):
continue

# Use shlex.split to correctly parse the line into tokens
tokens = shlex.split(line)

# Check that the line ends with a period '.'
if tokens[-1] != '.':
continue # Skip malformed lines

# Remove the period
tokens = tokens[:-1]

# Check that we have exactly 3 tokens: h, r, t
if len(tokens) != 3:
continue # Skip malformed lines

h, r, t = tokens

# Check if t is a literal (starts and ends with double quotes)
if t.startswith('"') and t.endswith('"'):
continue # Skip literals

# Map to IDs
h_id = self.ent2id.get(h)
r_id = self.rel2id.get(r)
t_id = self.ent2id.get(t)

# Skip if any ID is not found
if h_id is None or r_id is None or t_id is None:
continue

# Update the dictionaries
tail_relation_to_heads.setdefault(t_id, {}).setdefault(r_id, set()).add(h_id)
head_relation_to_tails.setdefault(h_id, {}).setdefault(r_id, set()).add(t_id)


# h, r, t = map(str, line.strip().split("\t"))
# tail_relation_to_heads[self.ent2id[t]][self.rel2id[r]].add(self.ent2id[h])
# head_relation_to_tails[self.ent2id[h]][self.rel2id[r]].add(self.ent2id[t])

self.ent_in = tail_relation_to_heads
self.ent_out = head_relation_to_tails
Expand Down

0 comments on commit 4d64660

Please sign in to comment.