diff --git a/src/MLDatasets.jl b/src/MLDatasets.jl
index 412b79c6..b24df06d 100644
--- a/src/MLDatasets.jl
+++ b/src/MLDatasets.jl
@@ -110,6 +110,7 @@ include("graph.jl")
 
 include("datasets/graphs/planetoid.jl")
 include("datasets/graphs/traffic.jl")
+include("datasets/graphs/graphsage.jl")
 # export read_planetoid_data
 include("datasets/graphs/cora.jl")
 export Cora
@@ -137,6 +138,8 @@ include("datasets/graphs/pemsbay.jl")
 export PEMSBAY
 include("datasets/graphs/temporalbrains.jl")
 export TemporalBrains
+include("datasets/graphs/ppi.jl")
+export PPI
 
 # Meshes
 
@@ -159,6 +162,7 @@ function __init__()
     __init__metrla()
     __init__pemsbay()
     __init__temporalbrains()
+    __init__ppi()
 
     # misc
     __init__iris()
diff --git a/src/datasets/graphs/graphsage.jl b/src/datasets/graphs/graphsage.jl
new file mode 100644
index 00000000..d7e73b68
--- /dev/null
+++ b/src/datasets/graphs/graphsage.jl
@@ -0,0 +1,61 @@
+"""
+Read the “Reddit” or “PPI” dataset
+from the “Inductive Representation Learning on Large Graphs” paper.
+
+The data was collected from
+http://snap.stanford.edu/graphsage/
+"""
+
+function read_graphsage_data(graph_json, class_map_json, id_map_json, feat_path)
+    # Read the json files
+    graph = read_json(graph_json)
+    class_map = read_json(class_map_json)
+    id_map = read_json(id_map_json)
+
+    # Metadata
+    directed = graph["directed"]
+    multigraph = graph["multigraph"]
+    links = graph["links"]
+    nodes = graph["nodes"]
+    num_edges = directed ? length(links) : length(links) * 2
+    num_nodes = length(nodes)
+    @assert length(graph["graph"]) == 0 # no graph-level attributes expected
+
+    # edges
+    s = get.(links, "source", nothing) .+ 1
+    t = get.(links, "target", nothing) .+ 1
+    if !directed
+        s, t = [s; t], [t; s]
+    end
+
+    # labels
+    node_keys = get.(nodes, "id", nothing)
+    node_idx = [id_map[key] for key in node_keys] .+ 1
+
+    sort_order = sortperm(node_idx)
+    node_idx = node_idx[sort_order]
+    labels = [class_map[key] for key in node_keys][sort_order]
+    @assert length(node_idx) == length(labels)
+    # if `labels` is an array of arrays, this turns it into a matrix
+    # if `labels` is an array of numbers, this does nothing
+    labels = stack(labels)
+
+    # features
+    features = read_npz(feat_path)[node_idx, :]
+    features = permutedims(features, (2, 1))
+
+    # split
+    test_mask = get.(nodes, "test", nothing)[sort_order]
+    val_mask = get.(nodes, "val", nothing)[sort_order]
+    # A node should not be both test and validation
+    @assert sum(val_mask .& test_mask) == 0
+    train_mask = nor.(test_mask, val_mask)
+
+    metadata = Dict{String, Any}("directed" => directed, "multigraph" => multigraph,
+                                 "num_edges" => num_edges, "num_nodes" => num_nodes)
+    g = Graph(; num_nodes,
+              edge_index = (s, t),
+              node_data = (; labels, features, train_mask, val_mask, test_mask))
+
+    return metadata, g
+end
diff --git a/src/datasets/graphs/ppi.jl b/src/datasets/graphs/ppi.jl
new file mode 100644
index 00000000..953c8006
--- /dev/null
+++ b/src/datasets/graphs/ppi.jl
@@ -0,0 +1,57 @@
+function __init__ppi()
+    DEPNAME = "PPI"
+    LINK = "http://snap.stanford.edu/graphsage/ppi.zip"
+    DOCS = "http://snap.stanford.edu/graphsage/"
+
+    register(DataDep(DEPNAME,
+                     """
+                     Dataset: The $DEPNAME Dataset
+                     Website: $DOCS
+                     Download Link: $LINK
+                     """,
+                     LINK,
+                     "53aeb76e54fd41b645e7edb48b62929240b89839495396b048086fd212503fbd",
+                     post_fetch_method = unpack))
+end
+
+"""
+    PPI(; dir=nothing)
+
+The PPI dataset was introduced in Ref. [1].
+The task is to classify protein roles, in terms of their cellular
+functions from gene ontology, across multiple protein-protein
+interaction (PPI) graphs, each corresponding to a different human tissue.
+Positional gene sets, motif gene sets, and immunological signatures
+are used as node features, and gene ontology sets (121 in total),
+collected from the Molecular Signatures Database, serve as labels.
+The average graph contains 2373 nodes, with an average degree of 28.8.
+
+# References
+[1]: [Inductive Representation Learning on Large Graphs](https://arxiv.org/abs/1706.02216)
+"""
+struct PPI <: AbstractDataset
+    metadata::Dict{String, Any}
+    graphs::Vector{Graph}
+end
+
+function PPI(; dir = nothing)
+    DATAFILES = [
+        "ppi-G.json", "ppi-walks.txt",
+        "ppi-class_map.json", "ppi-feats.npy", "ppi-id_map.json",
+    ]
+    DATA = joinpath.("ppi", DATAFILES)
+    DEPNAME = "PPI"
+
+    graph_json = datafile(DEPNAME, DATA[1], dir)
+    class_map_json = datafile(DEPNAME, DATA[3], dir)
+    id_map_json = datafile(DEPNAME, DATA[5], dir)
+    feat_path = datafile(DEPNAME, DATA[4], dir)
+
+    metadata, g = read_graphsage_data(graph_json, class_map_json, id_map_json, feat_path)
+
+    return PPI(metadata, [g])
+end
+
+Base.length(d::PPI) = length(d.graphs)
+Base.getindex(d::PPI, ::Colon) = d.graphs
+Base.getindex(d::PPI, i) = getindex(d.graphs, i)
diff --git a/src/datasets/graphs/reddit.jl b/src/datasets/graphs/reddit.jl
index 22f48fdb..4f7f5b7b 100644
--- a/src/datasets/graphs/reddit.jl
+++ b/src/datasets/graphs/reddit.jl
@@ -56,52 +56,8 @@ function Reddit(; full = true, dir = nothing)
     id_map_json = datafile(DEPNAME, DATA[6], dir)
     feat_path = datafile(DEPNAME, DATA[5], dir)
 
-    # Read the json files
-    graph = read_json(graph_json)
-    class_map = read_json(class_map_json)
-    id_map = read_json(id_map_json)
+    metadata, g = read_graphsage_data(graph_json, class_map_json, id_map_json, feat_path)
 
-    # Metadata
-    directed = graph["directed"]
-    multigraph = graph["multigraph"]
-    links = graph["links"]
-    nodes = graph["nodes"]
-    num_edges = directed ? length(links) : length(links) * 2
-    num_nodes = length(nodes)
-    @assert length(graph["graph"]) == 0 # should be zero
-
-    # edges
-    s = get.(links, "source", nothing) .+ 1
-    t = get.(links, "target", nothing) .+ 1
-    if !directed
-        s, t = [s; t], [t; s]
-    end
-
-    # labels
-    node_keys = get.(nodes, "id", nothing)
-    node_idx = [id_map[key] for key in node_keys] .+ 1
-
-    sort_order = sortperm(node_idx)
-    node_idx = node_idx[sort_order]
-    labels = [class_map[key] for key in node_keys][sort_order]
-    @assert length(node_idx) == length(labels)
-
-    # features
-    features = read_npz(feat_path)[node_idx, :]
-    features = permutedims(features, (2, 1))
-
-    # split
-    test_mask = get.(nodes, "test", nothing)[sort_order]
-    val_mask = get.(nodes, "val", nothing)[sort_order]
-    # A node should not be both test and validation
-    @assert sum(val_mask .& test_mask) == 0
-    train_mask = nor.(test_mask, val_mask)
-
-    metadata = Dict{String, Any}("directed" => directed, "multigraph" => multigraph,
-                                 "num_edges" => num_edges, "num_nodes" => num_nodes)
-    g = Graph(; num_nodes,
-              edge_index = (s, t),
-              node_data = (; labels, features, train_mask, val_mask, test_mask))
 
     return Reddit(metadata, [g])
 end
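Not part of the diff: a minimal usage sketch of the new `PPI` loader, assuming the DataDeps download prompt is accepted and that, as the `permutedims`/`stack` calls in `read_graphsage_data` suggest, features and labels come out with one column per node.

```julia
using MLDatasets

# The first call downloads ppi.zip via DataDeps; accept the prompt interactively
# or set ENV["DATADEPS_ALWAYS_ACCEPT"] = "true" beforehand.
data = PPI()

data.metadata["num_nodes"], data.metadata["num_edges"]   # graph size from the metadata Dict

g = data[1]                      # the dataset wraps a single Graph
X = g.node_data.features         # node features, one column per node
Y = g.node_data.labels           # multi-label targets, one column per node
train = g.node_data.train_mask   # Bool masks defining the train/val/test split
val, test = g.node_data.val_mask, g.node_data.test_mask
```

Since `stack` leaves a vector of scalar labels unchanged (per the comment in `read_graphsage_data`), routing `Reddit()` through the shared helper should not alter what it returns.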