Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

added the PPI dataset #228

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions src/MLDatasets.jl
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,7 @@ include("graph.jl")

include("datasets/graphs/planetoid.jl")
include("datasets/graphs/traffic.jl")
include("datasets/graphs/graphsage.jl")
# export read_planetoid_data
include("datasets/graphs/cora.jl")
export Cora
Expand Down Expand Up @@ -137,6 +138,8 @@ include("datasets/graphs/pemsbay.jl")
export PEMSBAY
include("datasets/graphs/temporalbrains.jl")
export TemporalBrains
include("datasets/graphs/ppi.jl")
export PPI

# Meshes

Expand All @@ -159,6 +162,7 @@ function __init__()
__init__metrla()
__init__pemsbay()
__init__temporalbrains()
__init__ppi()

# misc
__init__iris()
Expand Down
61 changes: 61 additions & 0 deletions src/datasets/graphs/graphsage.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
"""
Read any of the datasets “Reddit” and “PPI”
from the “Inductive Representation Learning on Large Graphs” paper.

Data collected from
http://snap.stanford.edu/graphsage/
"""

function read_graphsage_data(graph_json, class_map_json, id_map_json, feat_path)
# Read the json files
graph = read_json(graph_json)
class_map = read_json(class_map_json)
id_map = read_json(id_map_json)

# Metadata
directed = graph["directed"]
multigraph = graph["multigraph"]
links = graph["links"]
nodes = graph["nodes"]
num_edges = directed ? length(links) : length(links) * 2
num_nodes = length(nodes)
@assert length(graph["graph"]) == 0 # should be zero

# edges
s = get.(links, "source", nothing) .+ 1
t = get.(links, "target", nothing) .+ 1
if !directed
s, t = [s; t], [t; s]
end

# labels
node_keys = get.(nodes, "id", nothing)
node_idx = [id_map[key] for key in node_keys] .+ 1

sort_order = sortperm(node_idx)
node_idx = node_idx[sort_order]
labels = [class_map[key] for key in node_keys][sort_order]
@assert length(node_idx) == length(labels)
# if `labels` is an array of array, this turns it into a matrix
# if `labels` is an array of numbers, this dose nothing
labels = stack(labels)

# features
features = read_npz(feat_path)[node_idx, :]
features = permutedims(features, (2, 1))

# split
test_mask = get.(nodes, "test", nothing)[sort_order]
val_mask = get.(nodes, "val", nothing)[sort_order]
# A node should not be both test and validation
@assert sum(val_mask .& test_mask) == 0
train_mask = nor.(test_mask, val_mask)

metadata = Dict{String, Any}("directed" => directed, "multigraph" => multigraph,
"num_edges" => num_edges, "num_nodes" => num_nodes)
g = Graph(; num_nodes,
edge_index = (s, t),
node_data = (; labels, features, train_mask, val_mask, test_mask))

return metadata, g
end
60 changes: 60 additions & 0 deletions src/datasets/graphs/ppi.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
function __init__ppi()
DEPNAME = "PPI"
LINK = "http://snap.stanford.edu/graphsage/ppi.zip"
DOCS = "http://snap.stanford.edu/graphsage/"

register(DataDep(DEPNAME,
"""
Dataset: The $DEPNAME Dataset
Website: $DOCS
Download Link: $LINK
""",
LINK,
"53aeb76e54fd41b645e7edb48b62929240b89839495396b048086fd212503fbd",
post_fetch_method = unpack))
end

"""
PPI(; full=true, dir=nothing)

The PPI dataset was introduced in Ref [1].
Protein roles—in terms of their cellular functions from gene
ontology—in various protein-protein interaction (PPI) graphs,
with each graph corresponding to a different human tissue.
Positional gene sets are used, motif gene sets and immunological
signatures as features and gene ontology sets as labels (121 in total),
collected from the Molecular Signatures Database.
The average graph contains 2373 nodes, with an average degree of 28.8.


# References
[1]: [Inductive Representation Learning on Large Graphs](https://arxiv.org/abs/1706.02216)
"""
struct PPI <: AbstractDataset
metadata::Dict{String, Any}
graphs::Vector{Graph}
end

function PPI(; dir = nothing)
DATAFILES = [
"ppi-G.json", "ppi-walks.txt",
"ppi-class_map.json", "ppi-feats.npy", "ppi-id_map.json",
]
DATA = joinpath.("ppi", DATAFILES)
DEPNAME = "PPI"


graph_json = datafile(DEPNAME, DATA[1], dir)

class_map_json = datafile(DEPNAME, DATA[3], dir)
id_map_json = datafile(DEPNAME, DATA[5], dir)
feat_path = datafile(DEPNAME, DATA[4], dir)

metadata, g = read_graphsage_data(graph_json, class_map_json, id_map_json, feat_path)

return PPI(metadata, [g])
end

Base.length(d::PPI) = length(d.graphs)
Base.getindex(d::PPI, ::Colon) = d.graphs
Base.getindex(d::PPI, i) = getindex(d.graphs, i)
46 changes: 1 addition & 45 deletions src/datasets/graphs/reddit.jl
Original file line number Diff line number Diff line change
Expand Up @@ -56,52 +56,8 @@ function Reddit(; full = true, dir = nothing)
id_map_json = datafile(DEPNAME, DATA[6], dir)
feat_path = datafile(DEPNAME, DATA[5], dir)

# Read the json files
graph = read_json(graph_json)
class_map = read_json(class_map_json)
id_map = read_json(id_map_json)
metadata, g = read_graphsage_data(graph_json, class_map_json, id_map_json, feat_path)

# Metadata
directed = graph["directed"]
multigraph = graph["multigraph"]
links = graph["links"]
nodes = graph["nodes"]
num_edges = directed ? length(links) : length(links) * 2
num_nodes = length(nodes)
@assert length(graph["graph"]) == 0 # should be zero

# edges
s = get.(links, "source", nothing) .+ 1
t = get.(links, "target", nothing) .+ 1
if !directed
s, t = [s; t], [t; s]
end

# labels
node_keys = get.(nodes, "id", nothing)
node_idx = [id_map[key] for key in node_keys] .+ 1

sort_order = sortperm(node_idx)
node_idx = node_idx[sort_order]
labels = [class_map[key] for key in node_keys][sort_order]
@assert length(node_idx) == length(labels)

# features
features = read_npz(feat_path)[node_idx, :]
features = permutedims(features, (2, 1))

# split
test_mask = get.(nodes, "test", nothing)[sort_order]
val_mask = get.(nodes, "val", nothing)[sort_order]
# A node should not be both test and validation
@assert sum(val_mask .& test_mask) == 0
train_mask = nor.(test_mask, val_mask)

metadata = Dict{String, Any}("directed" => directed, "multigraph" => multigraph,
"num_edges" => num_edges, "num_nodes" => num_nodes)
g = Graph(; num_nodes,
edge_index = (s, t),
node_data = (; labels, features, train_mask, val_mask, test_mask))
return Reddit(metadata, [g])
end

Expand Down
Loading