Commit 9331183
Split graphml_class/citation.py into graphml_class/citation/networkx.py and graphml_class/citation/dgl.py
1 parent 7919eb1 commit 9331183
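
With citation.py split in two, the DGL dataset class now lives in graphml_class/citation/dgl.py, so downstream code would presumably import it from the new module path (a sketch based on this commit's layout, assuming the package exposes the module as laid out here):

    from graphml_class.citation.dgl import CitationGraphDataset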

File tree: 3 files changed, +203 -194 lines changed


graphml_class/citation/dgl.py (+192 lines)

@@ -0,0 +1,192 @@
import gzip
import os
import tarfile
from typing import Dict, List

import dgl
import numpy as np
import torch
from dgl.data import DGLDataset
from dgl.data.utils import download as dgl_download
from dgl.data.utils import load_graphs, save_graphs
from dgl.sampling import global_uniform_negative_sampling
from sentence_transformers import SentenceTransformer


class CitationGraphDataset(DGLDataset):
    """CitationGraphDataset DGLDataset sub-class for loading our citation network dataset.

    Parameters
    ----------
    url : str
        URL of the base path at SNAP from which to download the raw dataset(s)
    raw_dir : str
        Directory that will store the downloaded data or that already
        stores the input data. Default: "data"
    save_dir : str
        Directory to save the processed dataset. Default: the value of `raw_dir`
    force_reload : bool
        Whether to reload the dataset. Default: False
    verbose : bool
        Whether to print out progress information
    """

    name = "cit-HepTh"

    def __init__(
        self,
        url="https://snap.stanford.edu/data/",
        raw_dir="data",
        save_dir="data",
        force_reload=False,
        verbose=False,
    ):
        self._save_path = os.path.join(save_dir, f"{CitationGraphDataset.name}.bin")
        self._paraphrase_model = SentenceTransformer(
            "sentence-transformers/paraphrase-MiniLM-L6-v2"
        )
        super(CitationGraphDataset, self).__init__(
            name=CitationGraphDataset.name,
            url=url,
            raw_dir=raw_dir,
            save_dir=save_dir,
            force_reload=force_reload,
            verbose=verbose,
        )

    def download(self):
        """download Download all three files: edges, abstracts, and publishing dates."""
        file_names: List[str] = [
            "cit-HepTh.txt.gz",
            "cit-HepTh-abstracts.tar.gz",
            "cit-HepTh-dates.txt.gz",
        ]
        for file_name in file_names:
            file_path = os.path.join(self.raw_dir, file_name)
            dgl_download(self.url + file_name, path=file_path)
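        # With the default raw_dir="data", these downloads would land at
        # (paths assumed, derived from the file names above):
        #   data/cit-HepTh.txt.gz            the citation edge list
        #   data/cit-HepTh-abstracts.tar.gz  the paper abstracts
        #   data/cit-HepTh-dates.txt.gz      the publication dates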

    def featurize(self, file_to_net: Dict[int, int]) -> torch.Tensor:
        """featurize Sentence encode abstracts into 384-dimension embeddings via paraphrase-MiniLM-L6-v2.

        Parameters
        ----------
        file_to_net : Dict[int, int]
            Mapping from raw paper IDs to contiguous node IDs

        Returns
        -------
        torch.Tensor
            (node_count, 384) tensor of abstract embeddings
        """
        abstracts: Dict[int, str] = {}

        # Decompress the gzip content, then work through the abstract files in the tarball
        abstract_path = os.path.join(self.raw_dir, "cit-HepTh-abstracts.tar.gz")
        with gzip.GzipFile(filename=abstract_path) as f:
            with tarfile.open(fileobj=f, mode="r|") as tar:
                for member in tar:
                    paper_id = int(os.path.basename(member.name).split(".")[0])

                    abstract_file = tar.extractfile(member)
                    if abstract_file and paper_id in file_to_net:
                        content = abstract_file.read().decode("utf-8")
                        abstracts[file_to_net[paper_id]] = content

        # Embed all the abstracts at once with the paraphrase model from __init__
        node_ids = list(abstracts.keys())
        contents = list(abstracts.values())

        all_embeddings = self._paraphrase_model.encode(contents, convert_to_tensor=False)

        # Size the feature matrix to cover every node in the graph, not just
        # the nodes that have abstracts
        max_node_id = max(file_to_net.values())

        # Pre-allocate an array filled with zeros for all node embeddings;
        # nodes without an abstract keep the zero vector
        embeddings = np.zeros((max_node_id + 1, 384))

        # Fill in the array with the embeddings at the corresponding node IDs
        for idx, node_id in enumerate(node_ids):
            embeddings[node_id] = all_embeddings[idx]

        return torch.tensor(embeddings)
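    # A minimal sketch of the encoder call used above (values illustrative,
    # not from the commit): paraphrase-MiniLM-L6-v2 maps each string to a
    # 384-dimension vector.
    #   model = SentenceTransformer("sentence-transformers/paraphrase-MiniLM-L6-v2")
    #   vecs = model.encode(["an abstract", "another abstract"])
    #   vecs.shape  # -> (2, 384)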

    def process(self):
        """process Build the graph and its SBERT node features from the raw data."""

        #
        # Build the graph's U/V edge Tensors
        #
        u, v = [], []
        current_idx = 0
        edge_path = os.path.join(self.raw_dir, "cit-HepTh.txt.gz")
        file_to_net: Dict[int, int] = {}
        with gzip.GzipFile(filename=edge_path) as f:
            for line_number, line in enumerate(f):
                line = line.decode("utf-8")

                # Ignore comment lines that start with '#'
                if not line.startswith("#"):
                    # Source (citing) and destination (cited) papers
                    citing_key, cited_key = line.strip().split("\t")

                    # The edge list makes the paper ID an int, stripping 0001001 to 1001, for example
                    citing_key, cited_key = int(citing_key), int(cited_key)
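                    # e.g. a (hypothetical) raw line "0001001\t9402002" parses
                    # to citing_key=1001, cited_key=9402002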

                    # If either paper ID hasn't been seen before, assign it the next node ID
                    for key in [citing_key, cited_key]:
                        if key not in file_to_net:
                            # Build up an index that maps paper IDs to node IDs
                            file_to_net[key] = current_idx

                            # Bump the current ID
                            current_idx += 1

                    u.append(file_to_net[citing_key])
                    v.append(file_to_net[cited_key])

        # Build our DGLGraph from the edge list :)
        self.graph = dgl.graph((torch.tensor(u), torch.tensor(v)))
        self.graph.ndata["feat"] = self.featurize(file_to_net)

        #
        # Build label set for link prediction: balanced positive / negative, count = number of edges
        #

        # The real edges are the positive labels
        pos_src, pos_dst = self.graph.edges()

        # Negative sample labels sized by the number of actual, positive edges
        neg_src, neg_dst = global_uniform_negative_sampling(
            self.graph, self.graph.number_of_edges()
        )
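        # Note: DGL may return fewer than the requested number of negative
        # pairs if it cannot find enough distinct non-edges; zeros_like()
        # below sizes the negative labels to whatever came back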

        # Combine positive and negative samples
        self.src_nodes = torch.cat([pos_src, neg_src])
        self.dst_nodes = torch.cat([pos_dst, neg_dst])

        # Generate labels: 1 for positive and 0 for negative samples. These are
        # per-sample labels (one per candidate edge, twice the edge count), so
        # they are kept as an attribute rather than in ndata, whose first
        # dimension must match the node count
        self.labels = torch.cat(
            [torch.ones_like(pos_src), torch.zeros_like(neg_src)]
        ).float()

    def __getitem__(self, idx):
        assert idx == 0, "This dataset has only one graph"
        return self.graph

    def __len__(self):
        """__len__ we are one graph long"""
        return 1

    def save(self):
        """save Save our one graph to the file at `self._save_path`"""
        save_graphs(self._save_path, [self.graph])

    def load(self):
        """load Load the processed graph from the file at `self._save_path`"""
        # load_graphs returns a (graph_list, labels_dict) tuple
        graphs, _ = load_graphs(self._save_path)
        self.graph = graphs[0]

    def has_cache(self):
        """has_cache Check whether processed data exists at `self._save_path`"""
        return os.path.exists(self._save_path)
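
A minimal usage sketch (it assumes the SNAP downloads succeed and reflects the link-prediction labels being kept in dataset.labels; this example is not part of the commit):

    from graphml_class.citation.dgl import CitationGraphDataset

    dataset = CitationGraphDataset(raw_dir="data", save_dir="data", verbose=True)
    graph = dataset[0]

    print(graph.num_nodes(), graph.num_edges())
    print(graph.ndata["feat"].shape)  # (num_nodes, 384) abstract embeddings

    # Candidate (src, dst) pairs with 1/0 labels for training a link predictor.
    # These are only populated when process() runs: save()/load() persist the
    # graph alone.
    src, dst, labels = dataset.src_nodes, dataset.dst_nodes, dataset.labels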

0 commit comments
