diff --git a/gerrychain/graph/graph.py b/gerrychain/graph/graph.py index 8fd3ee6e..da03ec10 100644 --- a/gerrychain/graph/graph.py +++ b/gerrychain/graph/graph.py @@ -1,8 +1,11 @@ +import functools import json +from typing import Any import warnings import geopandas as gp import networkx +from networkx.classes.function import frozen from networkx.readwrite import json_graph from shapely.ops import unary_union from shapely.prepared import prep @@ -18,10 +21,14 @@ class Graph(networkx.Graph): to save and load graphs as JSON files. """ - def __repr__(self): return "".format(len(self.nodes), len(self.edges)) + @classmethod + def from_networkx(cls, graph: networkx.Graph): + g = cls(graph) + return g + @classmethod def from_json(cls, json_file): """Load a graph from a JSON file in the NetworkX json_graph format. @@ -31,7 +38,7 @@ def from_json(cls, json_file): with open(json_file) as f: data = json.load(f) g = json_graph.adjacency_graph(data) - graph = cls(g) + graph = cls.from_networkx(g) graph.issue_warnings() return graph @@ -165,6 +172,22 @@ def from_geodataframe( graph.add_data(df, columns=cols_to_add) return graph + def lookup(self, node, field): + """ + Lookup a node/field attribute. + :param node: Node to look up. + :param field: Field to look up. + """ + return self.nodes[node][field] + + @property + def node_indices(self): + return set(self.nodes) + + @property + def edge_indices(self): + return set(self.edges) + def add_data(self, df, columns=None): """Add columns of a DataFrame to a graph as node attributes using by matching the DataFrame's index to node ids. @@ -310,3 +333,61 @@ def convert_geometries_to_geojson(data): # This is what :func:`geopandas.GeoSeries.to_json` uses under # the hood. node[key] = node[key].__geo_interface__ + + +class FrozenGraph: + """ Represents an immutable graph to be partitioned. It is based off :class:`Graph`. + + This speeds up chain runs and prevents having to deal with cache invalidation issues. + This class behaves slightly differently than :class:`Graph` or :class:`networkx.Graph`. + """ + __slots__ = [ + "graph", + "size" + ] + + def __init__(self, graph: Graph): + self.graph = networkx.classes.function.freeze(graph) + self.graph.join = frozen + self.graph.add_data = frozen + self.graph.add_data = frozen + + self.size = len(self.graph) + + def __len__(self): + return self.size + + def __getattribute__(self, __name: str) -> Any: + try: + return object.__getattribute__(self, __name) + except AttributeError: + return object.__getattribute__(self.graph, __name) + + def __getitem__(self, __name: str) -> Any: + return self.graph[__name] + + def __iter__(self): + yield from self.node_indices + + @functools.lru_cache(16384) + def neighbors(self, n): + return tuple(self.graph.neighbors(n)) + + @functools.cached_property + def node_indices(self): + return self.graph.node_indices + + @functools.cached_property + def edge_indices(self): + return self.graph.edge_indices + + @functools.lru_cache(16384) + def degree(self, n): + return self.graph.degree(n) + + @functools.lru_cache(65536) + def lookup(self, node, field): + return self.graph.nodes[node][field] + + def subgraph(self, nodes): + return FrozenGraph(self.graph.subgraph(nodes)) diff --git a/gerrychain/grid.py b/gerrychain/grid.py index aac5d3b8..cb23fbe2 100644 --- a/gerrychain/grid.py +++ b/gerrychain/grid.py @@ -3,6 +3,7 @@ import networkx from gerrychain.partition import Partition +from gerrychain.graph import Graph from gerrychain.updaters import ( Tally, boundary_nodes, @@ -63,7 +64,7 @@ def __init__( """ if dimensions: self.dimensions = dimensions - graph = create_grid_graph(dimensions, with_diagonals) + graph = Graph.from_networkx(create_grid_graph(dimensions, with_diagonals)) if not assignment: thresholds = tuple(math.floor(n / 2) for n in self.dimensions) diff --git a/gerrychain/partition/partition.py b/gerrychain/partition/partition.py index caf8309e..b9204c1f 100644 --- a/gerrychain/partition/partition.py +++ b/gerrychain/partition/partition.py @@ -1,5 +1,8 @@ import json import geopandas +import networkx + +from gerrychain.graph.graph import FrozenGraph, Graph from ..updaters import compute_edge_flows, flows_from_changes, cut_edges from .assignment import get_assignment from .subgraphs import SubgraphView @@ -50,7 +53,16 @@ def __init__( self.subgraphs = SubgraphView(self.graph, self.parts) def _first_time(self, graph, assignment, updaters, use_cut_edges): - self.graph = graph + if isinstance(graph, Graph): + self.graph = FrozenGraph(graph) + elif isinstance(graph, networkx.Graph): + graph = Graph.from_networkx(graph) + self.graph = FrozenGraph(graph) + elif isinstance(graph, FrozenGraph): + self.graph = graph + else: + raise TypeError("Unsupported Graph object") + self.assignment = get_assignment(assignment, graph) if set(self.assignment) != set(graph): diff --git a/gerrychain/proposals/spectral_proposals.py b/gerrychain/proposals/spectral_proposals.py index 5e361557..09048955 100644 --- a/gerrychain/proposals/spectral_proposals.py +++ b/gerrychain/proposals/spectral_proposals.py @@ -15,7 +15,7 @@ def spectral_cut(graph, part_labels, weight_type, lap_type): n = len(nlist) if weight_type == "random": - for edge in graph.edges(): + for edge in graph.edge_indices: graph.edges[edge]["weight"] = random.random() if lap_type == "normalized": diff --git a/gerrychain/proposals/tree_proposals.py b/gerrychain/proposals/tree_proposals.py index 5e39324e..ab2f09af 100644 --- a/gerrychain/proposals/tree_proposals.py +++ b/gerrychain/proposals/tree_proposals.py @@ -46,7 +46,7 @@ def recom( ) flips = recursive_tree_part( - subgraph, + subgraph.graph, parts_to_merge, pop_col=pop_col, pop_target=pop_target, diff --git a/gerrychain/tree.py b/gerrychain/tree.py index 11962f90..88dd023f 100644 --- a/gerrychain/tree.py +++ b/gerrychain/tree.py @@ -15,7 +15,7 @@ def successors(h, root): def random_spanning_tree(graph): """ Builds a spanning tree chosen by Kruskal's method using random weights. - :param graph: Networkx Graph + :param graph: FrozenGraph Important Note: The key is specifically labelled "random_weight" instead of the previously @@ -24,7 +24,7 @@ def random_spanning_tree(graph): This meant that the laplacian would change for the graph step to step, something that we do not intend!! """ - for edge in graph.edges: + for edge in graph.edge_indices: graph.edges[edge]["random_weight"] = random.random() spanning_tree = tree.maximum_spanning_tree( @@ -39,14 +39,14 @@ def uniform_spanning_tree(graph, choice=random.choice): :param graph: Networkx Graph :param choice: :func:`random.choice` """ - root = choice(list(graph.nodes)) + root = choice(graph.node_indices) tree_nodes = set([root]) next_node = {root: None} - for node in graph.nodes: + for node in graph.node_indices: u = node while u not in tree_nodes: - next_node[u] = choice(list(nx.neighbors(graph, u))) + next_node[u] = choice(graph.neighbors(u)) u = next_node[u] u = node @@ -65,12 +65,12 @@ def uniform_spanning_tree(graph, choice=random.choice): class PopulatedGraph: def __init__(self, graph, populations, ideal_pop, epsilon): self.graph = graph - self.subsets = {node: {node} for node in graph} + self.subsets = {node: {node} for node in graph.node_indices} self.population = populations.copy() self.tot_pop = sum(self.population.values()) self.ideal_pop = ideal_pop self.epsilon = epsilon - self._degrees = {node: graph.degree(node) for node in graph} + self._degrees = {node: graph.degree(node) for node in graph.node_indices} def __iter__(self): return iter(self.graph) @@ -194,7 +194,7 @@ def bipartition_tree( tree is not provided :param choice: :func:`random.choice`. Can be substituted for testing. """ - populations = {node: graph.nodes[node][pop_col] for node in graph} + populations = {node: graph.nodes[node][pop_col] for node in graph.node_indices} possible_cuts = [] if spanning_tree is None: @@ -224,7 +224,7 @@ def _bipartition_tree_random_all( choice=random.choice, ): """Randomly bipartitions a graph and returns all cuts.""" - populations = {node: graph.nodes[node][pop_col] for node in graph} + populations = {node: graph.nodes[node][pop_col] for node in graph.node_indices} possible_cuts = [] if spanning_tree is None: @@ -303,11 +303,12 @@ def recursive_tree_part( :param epsilon: How far (as a percentage of ``pop_target``) from ``pop_target`` the parts of the partition can be :param node_repeats: Parameter for :func:`~gerrychain.tree_methods.bipartition_tree` to use. + :param method: The partition method to use. :return: New assignments for the nodes of ``graph``. :rtype: dict """ flips = {} - remaining_nodes = set(graph.nodes) + remaining_nodes = graph.node_indices # We keep a running tally of deviation from ``epsilon`` at each partition # and use it to tighten the population constraints on a per-partition # basis such that every partition, including the last partition, has a @@ -376,7 +377,7 @@ def get_seed_chunks( new_epsilon = epsilon chunk_pop = 0 - for node in graph.nodes: + for node in graph.node_indices: chunk_pop += graph.nodes[node][pop_col] while True: diff --git a/gerrychain/updaters/county_splits.py b/gerrychain/updaters/county_splits.py index b2215826..63d5984e 100644 --- a/gerrychain/updaters/county_splits.py +++ b/gerrychain/updaters/county_splits.py @@ -35,8 +35,8 @@ def compute_county_splits(partition, county_field, partition_field): if not partition.parent: county_dict = dict() - for node in partition.graph: - county = partition.graph.nodes[node][county_field] + for node in partition.graph.node_indices: + county = partition.graph.lookup(node, county_field) if county in county_dict: split, nodes, seen = county_dict[county] else: diff --git a/gerrychain/updaters/tally.py b/gerrychain/updaters/tally.py index c521ba80..ccba1344 100644 --- a/gerrychain/updaters/tally.py +++ b/gerrychain/updaters/tally.py @@ -9,6 +9,11 @@ class DataTally: """An updater for tallying numerical data that is not necessarily stored as node attributes """ + __slots__ = [ + "data", + "alias", + "_call" + ] def __init__(self, data, alias): """ @@ -54,6 +59,11 @@ def __call__(self, partition, previous=None): class Tally: """An updater for keeping a tally of one or more node attributes. """ + __slots__ = [ + "fields", + "alias", + "dtype" + ] def __init__(self, fields, alias=None, dtype=int): """ @@ -116,12 +126,12 @@ def _update_tally(self, partition): return new_tally def _get_tally_from_node(self, partition, node): - return sum(partition.graph.nodes[node][field] for field in self.fields) + return sum(partition.graph.lookup(node, field) for field in self.fields) def compute_out_flow(graph, fields, flow): - return sum(graph.nodes[node][field] for node in flow["out"] for field in fields) + return sum(graph.lookup(node, field) for node in flow["out"] for field in fields) def compute_in_flow(graph, fields, flow): - return sum(graph.nodes[node][field] for node in flow["in"] for field in fields) + return sum(graph.lookup(node, field) for node in flow["in"] for field in fields) diff --git a/setup.py b/setup.py index f530b660..0004e757 100644 --- a/setup.py +++ b/setup.py @@ -31,7 +31,6 @@ install_requires=requirements, keywords="GerryChain", classifiers=[ - "Programming Language :: Python :: 3.7", "Programming Language :: Python :: 3.8", "Programming Language :: Python :: 3.9", "Programming Language :: Python :: 3.10", diff --git a/tests/conftest.py b/tests/conftest.py index 5170c05a..96d72dc0 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -56,7 +56,7 @@ def graph(three_by_three_grid): @pytest.fixture def example_partition(): - graph = networkx.complete_graph(3) + graph = Graph.from_networkx(networkx.complete_graph(3)) assignment = {0: 1, 1: 1, 2: 2} partition = Partition(graph, assignment, {"cut_edges": cut_edges}) return partition diff --git a/tests/constraints/test_validity.py b/tests/constraints/test_validity.py index 00653151..54e61ab2 100644 --- a/tests/constraints/test_validity.py +++ b/tests/constraints/test_validity.py @@ -11,6 +11,7 @@ single_flip_contiguous) from gerrychain.partition import Partition from gerrychain.partition.partition import get_assignment +from gerrychain.graph import Graph @pytest.fixture @@ -18,7 +19,7 @@ def contiguous_partition_with_flips(): graph = nx.Graph() graph.add_nodes_from(range(4)) graph.add_edges_from([(0, 1), (1, 2), (2, 3), (3, 0)]) - partition = Partition(graph, {0: 0, 1: 1, 2: 1, 3: 0}) + partition = Partition(Graph.from_networkx(graph), {0: 0, 1: 1, 2: 1, 3: 0}) # This flip will maintain contiguity. return partition, {0: 1} @@ -29,7 +30,7 @@ def discontiguous_partition_with_flips(): graph = nx.Graph() graph.add_nodes_from(range(4)) graph.add_edges_from([(0, 1), (1, 2), (2, 3)]) - partition = Partition(graph, {0: 0, 1: 1, 2: 1, 3: 0}) + partition = Partition(Graph.from_networkx(graph), {0: 0, 1: 1, 2: 1, 3: 0}) # This flip will maintain discontiguity. return partition, {1: 0} diff --git a/tests/partition/test_partition.py b/tests/partition/test_partition.py index 79889492..1a20d6f7 100644 --- a/tests/partition/test_partition.py +++ b/tests/partition/test_partition.py @@ -7,6 +7,7 @@ from gerrychain.partition import GeographicPartition, Partition from gerrychain.proposals import propose_random_flip +from gerrychain.graph import Graph from gerrychain.updaters import cut_edges @@ -17,14 +18,14 @@ def test_Partition_can_be_flipped(example_partition): def test_Partition_misnamed_vertices_raises_keyerror(): - graph = networkx.complete_graph(3) + graph = Graph.from_networkx(networkx.complete_graph(3)) assignment = {"0": 1, "1": 1, "2": 2} with pytest.raises(KeyError): Partition(graph, assignment, {"cut_edges": cut_edges}) def test_Partition_unlabelled_vertices_raises_keyerror(): - graph = networkx.complete_graph(3) + graph = Graph.from_networkx(networkx.complete_graph(3)) assignment = {0: 1, 2: 2} with pytest.raises(KeyError): Partition(graph, assignment, {"cut_edges": cut_edges}) @@ -44,7 +45,7 @@ def test_propose_random_flip_proposes_a_partition(example_partition): @pytest.fixture def example_geographic_partition(): - graph = networkx.complete_graph(3) + graph = Graph.from_networkx(networkx.complete_graph(3)) assignment = {0: 1, 1: 1, 2: 2} for node in graph.nodes: graph.nodes[node]["boundary_node"] = False diff --git a/tests/test_reproducibility.py b/tests/test_reproducibility.py index 26334c65..c539d8a3 100644 --- a/tests/test_reproducibility.py +++ b/tests/test_reproducibility.py @@ -107,4 +107,4 @@ def test_pa_freeze(): result += str(len(partition.cut_edges)) result += str(count) + "\n" - assert hashlib.sha256(result.encode()).hexdigest() == "3bef9ac8c0bfa025fb75e32aea3847757a8fba56b2b2be6f9b3b952088ae3b3c" + assert hashlib.sha256(result.encode()).hexdigest() == "309316e6ca5685c8b3601268b1814a966771e00715a6c69973a8ede810f4c8cf" diff --git a/tests/test_tree.py b/tests/test_tree.py index a518d4a9..d752ee45 100644 --- a/tests/test_tree.py +++ b/tests/test_tree.py @@ -5,6 +5,7 @@ from gerrychain import MarkovChain from gerrychain.constraints import contiguous, within_percent_of_ideal_population +from gerrychain.graph import Graph from gerrychain.partition import Partition from gerrychain.proposals import recom from gerrychain.tree import ( @@ -22,7 +23,7 @@ def graph_with_pop(three_by_three_grid): for node in three_by_three_grid: three_by_three_grid.nodes[node]["pop"] = 1 - return three_by_three_grid + return Graph.from_networkx(three_by_three_grid) @pytest.fixture @@ -41,7 +42,7 @@ def twelve_by_twelve_with_pop(): grid = networkx.relabel_nodes(xy_grid, nodes) for node in grid: grid.nodes[node]["pop"] = 1 - return grid + return Graph.from_networkx(grid) def test_bipartition_tree_returns_a_subset_of_nodes(graph_with_pop): @@ -92,9 +93,9 @@ def test_random_spanning_tree_returns_tree_with_pop_attribute(graph_with_pop): def test_bipartition_tree_returns_a_tree(graph_with_pop): ideal_pop = sum(graph_with_pop.nodes[node]["pop"] for node in graph_with_pop) / 2 - tree = networkx.Graph( + tree = Graph.from_networkx(networkx.Graph( [(0, 1), (1, 2), (1, 4), (3, 4), (4, 5), (3, 6), (6, 7), (6, 8)] - ) + )) for node in tree: tree.nodes[node]["pop"] = graph_with_pop.nodes[node]["pop"] @@ -123,9 +124,9 @@ def test_recom_works_as_a_proposal(partition_with_pop): def test_find_balanced_cuts_contraction(): - tree = networkx.Graph( + tree = Graph.from_networkx(networkx.Graph( [(0, 1), (1, 2), (1, 4), (3, 4), (4, 5), (3, 6), (6, 7), (6, 8)] - ) + )) # 0 - 1 - 2 # || diff --git a/tests/updaters/test_updaters.py b/tests/updaters/test_updaters.py index d179896b..20e2136a 100644 --- a/tests/updaters/test_updaters.py +++ b/tests/updaters/test_updaters.py @@ -5,6 +5,7 @@ from gerrychain import MarkovChain from gerrychain.constraints import Validator, no_vanishing_districts +from gerrychain.graph import Graph from gerrychain.partition import Partition from gerrychain.proposals import propose_random_flip from gerrychain.random import random @@ -62,7 +63,7 @@ def test_Partition_can_update_stats(): updaters = {"total_stat": Tally("stat", alias="total_stat")} - partition = Partition(graph, assignment, updaters) + partition = Partition(Graph.from_networkx(graph), assignment, updaters) assert partition["total_stat"][2] == 3 flip = {1: 2}