-
-
Notifications
You must be signed in to change notification settings - Fork 43
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Period, Periods, TransitGraph, and TransitGraphBuilder tests
- Loading branch information
Showing
6 changed files
with
422 additions
and
3 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,143 @@ | ||
from unittest import TestCase | ||
import os | ||
import tempfile | ||
import numpy as np | ||
from aequilibrae.paths import TransitGraph | ||
from os.path import join | ||
from uuid import uuid4 | ||
from aequilibrae.project import Project | ||
from ...data import siouxfalls_project | ||
from aequilibrae.paths.results import PathResults | ||
from aequilibrae.utils.create_example import create_example | ||
from aequilibrae.transit import Transit | ||
|
||
|
||
# Adds the folder with the data to the path and collects the paths to the files | ||
# lib_path = os.path.abspath(os.path.join('..', '../tests')) | ||
# sys.path.append(lib_path) | ||
from ...data import path_test, test_graph, test_network | ||
from shutil import copytree, rmtree | ||
|
||
|
||
class TestTransitGraphBuilder(TestCase): | ||
def setUp(self) -> None: | ||
os.environ["PATH"] = os.path.join(tempfile.gettempdir(), "temp_data") + ";" + os.environ["PATH"] | ||
self.temp_proj_folder = os.path.join(tempfile.gettempdir(), uuid4().hex) | ||
|
||
self.project = create_example(self.temp_proj_folder, "coquimbo") | ||
|
||
os.remove(os.path.join(self.temp_proj_folder, "public_transport.sqlite")) | ||
|
||
#### Patch | ||
patches = [ | ||
"/home/jake/Software/aequilibrae/aequilibrae/project/database_specification/network/tables/periods.sql", | ||
"/home/jake/Software/aequilibrae/aequilibrae/project/database_specification/network/triggers/periods_triggers.sql", | ||
"/home/jake/Software/aequilibrae/aequilibrae/project/database_specification/network/tables/transit_graph_configs.sql", | ||
] | ||
for patch in patches: | ||
with open(patch) as f: | ||
for statement in f.read().split("--#"): | ||
self.project.conn.execute(statement) | ||
|
||
self.project.conn.commit() | ||
self.project.network.periods.refresh_fields() | ||
#### Patch end | ||
|
||
self.data = Transit(self.project) | ||
dest_path = join(self.temp_proj_folder, "gtfs_coquimbo.zip") | ||
self.transit = self.data.new_gtfs_builder(agency="LISANCO", file_path=dest_path) | ||
|
||
self.transit.load_date("2016-04-13") | ||
self.transit.save_to_disk() | ||
|
||
def tearDown(self) -> None: | ||
self.project.close() | ||
|
||
def test_create_line_gemoetry(self): | ||
self.project.network.build_graphs() | ||
for connector_method in ["overlapping_regions", "nearest_neighbour"]: | ||
for method in ["connector project match", "direct"]: | ||
with self.subTest(connector_method=connector_method, method=method): | ||
graph = self.data.create_graph( | ||
with_outer_stop_transfers=False, | ||
with_walking_edges=False, | ||
blocking_centroid_flows=False, | ||
connector_method=connector_method, | ||
) | ||
|
||
self.assertNotIn("geometry", graph.edges.columns) | ||
|
||
graph.create_line_geometry(method=method, graph="c") | ||
|
||
self.assertIn("geometry", graph.edges.columns) | ||
self.assertTrue(graph.edges.geometry.all()) | ||
|
||
def test_connector_methods(self): | ||
connector_method = "nearest_neighbour" | ||
graph = self.data.create_graph( | ||
with_outer_stop_transfers=False, | ||
with_walking_edges=False, | ||
blocking_centroid_flows=False, | ||
connector_method=connector_method, | ||
) | ||
|
||
nearest_neighbour_connector_count = len(graph.edges[graph.edges.link_type == "access_connector"]) | ||
self.assertEqual(nearest_neighbour_connector_count, len(graph.edges[graph.edges.link_type == "egress_connector"])) | ||
self.assertEqual( | ||
nearest_neighbour_connector_count, | ||
len(graph.vertices[graph.vertices.node_type == "stop"]), | ||
) | ||
|
||
connector_method = "overlapping_regions" | ||
graph = self.data.create_graph( | ||
with_outer_stop_transfers=False, | ||
with_walking_edges=False, | ||
blocking_centroid_flows=False, | ||
connector_method=connector_method, | ||
) | ||
|
||
self.assertLessEqual(nearest_neighbour_connector_count, len(graph.edges[graph.edges.link_type == "access_connector"])) | ||
self.assertEqual( | ||
len(graph.edges[graph.edges.link_type == "access_connector"]), | ||
len(graph.edges[graph.edges.link_type == "egress_connector"]), | ||
) | ||
|
||
def test_connector_method_exception(self): | ||
connector_method = "something not right" | ||
with self.assertRaises(ValueError): | ||
self.data.create_graph( | ||
with_outer_stop_transfers=False, | ||
with_walking_edges=False, | ||
blocking_centroid_flows=False, | ||
connector_method=connector_method, | ||
) | ||
|
||
def test_connector_method_without_missing(self): | ||
connector_method = "nearest_neighbour" | ||
graph = self.data.create_graph( | ||
with_outer_stop_transfers=False, | ||
with_walking_edges=False, | ||
blocking_centroid_flows=False, | ||
connector_method=connector_method, | ||
) | ||
|
||
nearest_neighbour_connector_count = len(graph.edges[graph.edges.link_type == "access_connector"]) | ||
self.assertEqual(nearest_neighbour_connector_count, len(graph.edges[graph.edges.link_type == "egress_connector"])) | ||
self.assertEqual( | ||
nearest_neighbour_connector_count, | ||
len(graph.vertices[graph.vertices.node_type == "stop"]), | ||
) | ||
|
||
connector_method = "overlapping_regions" | ||
graph = self.data.create_graph( | ||
with_outer_stop_transfers=False, | ||
with_walking_edges=False, | ||
blocking_centroid_flows=False, | ||
connector_method=connector_method, | ||
) | ||
|
||
self.assertLessEqual(nearest_neighbour_connector_count, len(graph.edges[graph.edges.link_type == "access_connector"])) | ||
self.assertEqual( | ||
len(graph.edges[graph.edges.link_type == "access_connector"]), | ||
len(graph.edges[graph.edges.link_type == "egress_connector"]), | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,113 @@ | ||
from unittest import TestCase | ||
from sqlite3 import IntegrityError | ||
import os | ||
from shutil import copytree, rmtree | ||
from random import randint, random | ||
import uuid | ||
from tempfile import gettempdir | ||
from shapely.geometry import Point | ||
import shapely.wkb | ||
from aequilibrae.project import Project | ||
import pandas as pd | ||
|
||
from ...data import siouxfalls_project | ||
|
||
|
||
class TestPeriod(TestCase): | ||
def setUp(self) -> None: | ||
os.environ["PATH"] = os.path.join(gettempdir(), "temp_data") + ";" + os.environ["PATH"] | ||
|
||
self.proj_dir = os.path.join(gettempdir(), uuid.uuid4().hex) | ||
copytree(siouxfalls_project, self.proj_dir) | ||
|
||
self.project = Project() | ||
self.project.open(self.proj_dir) | ||
self.network = self.project.network | ||
self.curr = self.project.conn.cursor() | ||
|
||
#### Patch | ||
patches = [ | ||
"/home/jake/Software/aequilibrae/aequilibrae/project/database_specification/network/tables/periods.sql", | ||
"/home/jake/Software/aequilibrae/aequilibrae/project/database_specification/network/triggers/periods_triggers.sql", | ||
"/home/jake/Software/aequilibrae/aequilibrae/project/database_specification/network/tables/transit_graph_configs.sql", | ||
] | ||
for patch in patches: | ||
with open(patch) as f: | ||
for statement in f.read().split("--#"): | ||
self.project.conn.execute(statement) | ||
|
||
self.project.conn.commit() | ||
self.project.network.periods.refresh_fields() | ||
#### Patch end | ||
|
||
for num in range(2, 6): | ||
self.project.network.periods.new_period(num, num, num, "test") | ||
|
||
def tearDown(self) -> None: | ||
self.curr.close() | ||
self.project.close() | ||
try: | ||
rmtree(self.proj_dir) | ||
except Exception as e: | ||
print(f"Failed to remove at {e.args}") | ||
|
||
def test_save_and_assignment(self): | ||
periods = self.network.periods | ||
nd = randint(2, 5) | ||
period = periods.get(nd) | ||
|
||
with self.assertRaises(AttributeError): | ||
period.modes = "abc" | ||
|
||
with self.assertRaises(AttributeError): | ||
period.link_types = "default" | ||
|
||
with self.assertRaises(AttributeError): | ||
period.period_id = 2 | ||
|
||
period.period_description = "test" | ||
self.assertEqual(period.period_description, "test") | ||
|
||
period.save() | ||
|
||
expected = pd.DataFrame( | ||
{ | ||
"period_id": [1, nd], | ||
"period_start": [0, nd], | ||
"period_end": [86400, nd], | ||
} | ||
) | ||
expected["period_description"] = "test" | ||
expected.at[0, "period_description"] = "Default time period, whole day" | ||
|
||
pd.testing.assert_frame_equal(periods.data, expected) | ||
|
||
def test_data_fields(self): | ||
periods = self.network.periods | ||
|
||
period = periods.get(1) | ||
|
||
fields = sorted(period.data_fields()) | ||
self.curr.execute("pragma table_info(periods)") | ||
dt = self.curr.fetchall() | ||
|
||
actual_fields = sorted([x[1] for x in dt if x[1] != "ogc_fid"]) | ||
|
||
self.assertEqual(fields, actual_fields, "Period has unexpected set of fields") | ||
|
||
def test_renumber(self): | ||
periods = self.network.periods | ||
|
||
period = periods.get(1) | ||
|
||
with self.assertRaises(ValueError): | ||
period.renumber(1) | ||
|
||
num = randint(25, 2000) | ||
with self.assertRaises(ValueError): | ||
period.renumber(num) | ||
|
||
new_period = periods.new_period(num, 0, 0, "test") | ||
new_period.renumber(num + 1) | ||
|
||
self.assertEqual(new_period.period_id, num + 1) |
Oops, something went wrong.