Skip to content

Commit

Permalink
Period, Periods, TransitGraph, and TransitGraphBuilder tests
Browse files Browse the repository at this point in the history
  • Loading branch information
Jake-Moss committed Dec 15, 2023
1 parent fced80b commit cb22e11
Show file tree
Hide file tree
Showing 6 changed files with 422 additions and 3 deletions.
6 changes: 3 additions & 3 deletions aequilibrae/project/basic_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,8 @@ def extent(self) -> Polygon:
Returns:
*model extent* (:obj:`Polygon`): Shapely polygon with the bounding box of the layer.
"""
self.__curr.execute(f'Select ST_asBinary(GetLayerExtent("{self.__table_type__}"))')
poly = shapely.wkb.loads(self.__curr.fetchone()[0])
self._curr.execute(f'Select ST_asBinary(GetLayerExtent("{self.__table_type__}"))')
poly = shapely.wkb.loads(self._curr.fetchone()[0])
return poly

@property
Expand All @@ -33,7 +33,7 @@ def fields(self) -> FieldEditor:
def refresh_connection(self):
"""Opens a new database connection to avoid thread conflict"""
self.conn = self.project.connect()
self.__curr = self.conn.cursor()
self._curr = self.conn.cursor()

def __copy__(self):
raise Exception(f"{self.__table_type__} object cannot be copied")
Expand Down
4 changes: 4 additions & 0 deletions aequilibrae/project/network/period.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,10 @@ def renumber(self, new_id: int):
"""

new_id = int(new_id)

if new_id == 1 or self.period_id == 1:
raise ValueError("You cannot renumber, or renumber another period to the default period.")

if new_id == self.period_id:
self._logger.warning("This is already the period number")
return
Expand Down
52 changes: 52 additions & 0 deletions tests/aequilibrae/paths/test_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,17 @@
import os
import tempfile
import numpy as np
import pandas as pd
from aequilibrae.paths import Graph
from os.path import join
from uuid import uuid4
from .parameters_test import centroids
from aequilibrae.project import Project
from ...data import siouxfalls_project
from aequilibrae.paths.results import PathResults
from aequilibrae.utils.create_example import create_example
from aequilibrae.transit import Transit


# Adds the folder with the data to the path and collects the paths to the files
# lib_path = os.path.abspath(os.path.join('..', '../tests'))
Expand Down Expand Up @@ -88,3 +92,51 @@ def test_exclude_links(self):
r1.prepare(self.graph)
r1.compute_path(20, 21)
self.assertEqual(list(r1.path), [63, 69])


class TestTransitGraph(TestCase):
def setUp(self) -> None:
os.environ["PATH"] = os.path.join(tempfile.gettempdir(), "temp_data") + ";" + os.environ["PATH"]
self.temp_proj_folder = os.path.join(tempfile.gettempdir(), uuid4().hex)

self.project = create_example(self.temp_proj_folder, "coquimbo")

#### Patch
os.remove(os.path.join(self.temp_proj_folder, "public_transport.sqlite"))
patches = [
"/home/jake/Software/aequilibrae/aequilibrae/project/database_specification/network/tables/periods.sql",
"/home/jake/Software/aequilibrae/aequilibrae/project/database_specification/network/triggers/periods_triggers.sql",
"/home/jake/Software/aequilibrae/aequilibrae/project/database_specification/network/tables/transit_graph_configs.sql",
]
for patch in patches:
with open(patch) as f:
for statement in f.read().split("--#"):
self.project.conn.execute(statement)

self.project.conn.commit()
self.project.network.periods.refresh_fields()
self.data = Transit(self.project)
dest_path = join(self.temp_proj_folder, "gtfs_coquimbo.zip")
self.transit = self.data.new_gtfs_builder(agency="LISANCO", file_path=dest_path)
self.transit.load_date("2016-04-13")
self.transit.save_to_disk()
#### Patch end

self.graph = self.data.create_graph(
with_outer_stop_transfers=False,
with_walking_edges=False,
blocking_centroid_flows=False,
connector_method="nearest neighbour",
)

self.transit_graph = self.graph.to_transit_graph()

def tearDown(self) -> None:
self.project.close()

def test_transit_graph_config(self):
self.assertEqual(self.graph.config, self.transit_graph._config)

def test_transit_graph_od_node_mapping(self):
pd.testing.assert_frame_equal(self.graph.od_node_mapping, self.transit_graph.od_node_mapping)

143 changes: 143 additions & 0 deletions tests/aequilibrae/paths/test_transit_graph_builder.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,143 @@
from unittest import TestCase
import os
import tempfile
import numpy as np
from aequilibrae.paths import TransitGraph
from os.path import join
from uuid import uuid4
from aequilibrae.project import Project
from ...data import siouxfalls_project
from aequilibrae.paths.results import PathResults
from aequilibrae.utils.create_example import create_example
from aequilibrae.transit import Transit


# Adds the folder with the data to the path and collects the paths to the files
# lib_path = os.path.abspath(os.path.join('..', '../tests'))
# sys.path.append(lib_path)
from ...data import path_test, test_graph, test_network
from shutil import copytree, rmtree


class TestTransitGraphBuilder(TestCase):
def setUp(self) -> None:
os.environ["PATH"] = os.path.join(tempfile.gettempdir(), "temp_data") + ";" + os.environ["PATH"]
self.temp_proj_folder = os.path.join(tempfile.gettempdir(), uuid4().hex)

self.project = create_example(self.temp_proj_folder, "coquimbo")

os.remove(os.path.join(self.temp_proj_folder, "public_transport.sqlite"))

#### Patch
patches = [
"/home/jake/Software/aequilibrae/aequilibrae/project/database_specification/network/tables/periods.sql",
"/home/jake/Software/aequilibrae/aequilibrae/project/database_specification/network/triggers/periods_triggers.sql",
"/home/jake/Software/aequilibrae/aequilibrae/project/database_specification/network/tables/transit_graph_configs.sql",
]
for patch in patches:
with open(patch) as f:
for statement in f.read().split("--#"):
self.project.conn.execute(statement)

self.project.conn.commit()
self.project.network.periods.refresh_fields()
#### Patch end

self.data = Transit(self.project)
dest_path = join(self.temp_proj_folder, "gtfs_coquimbo.zip")
self.transit = self.data.new_gtfs_builder(agency="LISANCO", file_path=dest_path)

self.transit.load_date("2016-04-13")
self.transit.save_to_disk()

def tearDown(self) -> None:
self.project.close()

def test_create_line_gemoetry(self):
self.project.network.build_graphs()
for connector_method in ["overlapping_regions", "nearest_neighbour"]:
for method in ["connector project match", "direct"]:
with self.subTest(connector_method=connector_method, method=method):
graph = self.data.create_graph(
with_outer_stop_transfers=False,
with_walking_edges=False,
blocking_centroid_flows=False,
connector_method=connector_method,
)

self.assertNotIn("geometry", graph.edges.columns)

graph.create_line_geometry(method=method, graph="c")

self.assertIn("geometry", graph.edges.columns)
self.assertTrue(graph.edges.geometry.all())

def test_connector_methods(self):
connector_method = "nearest_neighbour"
graph = self.data.create_graph(
with_outer_stop_transfers=False,
with_walking_edges=False,
blocking_centroid_flows=False,
connector_method=connector_method,
)

nearest_neighbour_connector_count = len(graph.edges[graph.edges.link_type == "access_connector"])
self.assertEqual(nearest_neighbour_connector_count, len(graph.edges[graph.edges.link_type == "egress_connector"]))
self.assertEqual(
nearest_neighbour_connector_count,
len(graph.vertices[graph.vertices.node_type == "stop"]),
)

connector_method = "overlapping_regions"
graph = self.data.create_graph(
with_outer_stop_transfers=False,
with_walking_edges=False,
blocking_centroid_flows=False,
connector_method=connector_method,
)

self.assertLessEqual(nearest_neighbour_connector_count, len(graph.edges[graph.edges.link_type == "access_connector"]))
self.assertEqual(
len(graph.edges[graph.edges.link_type == "access_connector"]),
len(graph.edges[graph.edges.link_type == "egress_connector"]),
)

def test_connector_method_exception(self):
connector_method = "something not right"
with self.assertRaises(ValueError):
self.data.create_graph(
with_outer_stop_transfers=False,
with_walking_edges=False,
blocking_centroid_flows=False,
connector_method=connector_method,
)

def test_connector_method_without_missing(self):
connector_method = "nearest_neighbour"
graph = self.data.create_graph(
with_outer_stop_transfers=False,
with_walking_edges=False,
blocking_centroid_flows=False,
connector_method=connector_method,
)

nearest_neighbour_connector_count = len(graph.edges[graph.edges.link_type == "access_connector"])
self.assertEqual(nearest_neighbour_connector_count, len(graph.edges[graph.edges.link_type == "egress_connector"]))
self.assertEqual(
nearest_neighbour_connector_count,
len(graph.vertices[graph.vertices.node_type == "stop"]),
)

connector_method = "overlapping_regions"
graph = self.data.create_graph(
with_outer_stop_transfers=False,
with_walking_edges=False,
blocking_centroid_flows=False,
connector_method=connector_method,
)

self.assertLessEqual(nearest_neighbour_connector_count, len(graph.edges[graph.edges.link_type == "access_connector"]))
self.assertEqual(
len(graph.edges[graph.edges.link_type == "access_connector"]),
len(graph.edges[graph.edges.link_type == "egress_connector"]),
)
113 changes: 113 additions & 0 deletions tests/aequilibrae/project/test_period.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,113 @@
from unittest import TestCase
from sqlite3 import IntegrityError
import os
from shutil import copytree, rmtree
from random import randint, random
import uuid
from tempfile import gettempdir
from shapely.geometry import Point
import shapely.wkb
from aequilibrae.project import Project
import pandas as pd

from ...data import siouxfalls_project


class TestPeriod(TestCase):
def setUp(self) -> None:
os.environ["PATH"] = os.path.join(gettempdir(), "temp_data") + ";" + os.environ["PATH"]

self.proj_dir = os.path.join(gettempdir(), uuid.uuid4().hex)
copytree(siouxfalls_project, self.proj_dir)

self.project = Project()
self.project.open(self.proj_dir)
self.network = self.project.network
self.curr = self.project.conn.cursor()

#### Patch
patches = [
"/home/jake/Software/aequilibrae/aequilibrae/project/database_specification/network/tables/periods.sql",
"/home/jake/Software/aequilibrae/aequilibrae/project/database_specification/network/triggers/periods_triggers.sql",
"/home/jake/Software/aequilibrae/aequilibrae/project/database_specification/network/tables/transit_graph_configs.sql",
]
for patch in patches:
with open(patch) as f:
for statement in f.read().split("--#"):
self.project.conn.execute(statement)

self.project.conn.commit()
self.project.network.periods.refresh_fields()
#### Patch end

for num in range(2, 6):
self.project.network.periods.new_period(num, num, num, "test")

def tearDown(self) -> None:
self.curr.close()
self.project.close()
try:
rmtree(self.proj_dir)
except Exception as e:
print(f"Failed to remove at {e.args}")

def test_save_and_assignment(self):
periods = self.network.periods
nd = randint(2, 5)
period = periods.get(nd)

with self.assertRaises(AttributeError):
period.modes = "abc"

with self.assertRaises(AttributeError):
period.link_types = "default"

with self.assertRaises(AttributeError):
period.period_id = 2

period.period_description = "test"
self.assertEqual(period.period_description, "test")

period.save()

expected = pd.DataFrame(
{
"period_id": [1, nd],
"period_start": [0, nd],
"period_end": [86400, nd],
}
)
expected["period_description"] = "test"
expected.at[0, "period_description"] = "Default time period, whole day"

pd.testing.assert_frame_equal(periods.data, expected)

def test_data_fields(self):
periods = self.network.periods

period = periods.get(1)

fields = sorted(period.data_fields())
self.curr.execute("pragma table_info(periods)")
dt = self.curr.fetchall()

actual_fields = sorted([x[1] for x in dt if x[1] != "ogc_fid"])

self.assertEqual(fields, actual_fields, "Period has unexpected set of fields")

def test_renumber(self):
periods = self.network.periods

period = periods.get(1)

with self.assertRaises(ValueError):
period.renumber(1)

num = randint(25, 2000)
with self.assertRaises(ValueError):
period.renumber(num)

new_period = periods.new_period(num, 0, 0, "test")
new_period.renumber(num + 1)

self.assertEqual(new_period.period_id, num + 1)
Loading

0 comments on commit cb22e11

Please sign in to comment.