Skip to content

Commit

Permalink
cleans code and eliminates redundant data in memory
Browse files Browse the repository at this point in the history
  • Loading branch information
pveigadecamargo committed May 28, 2024
1 parent f06e407 commit cd3c223
Show file tree
Hide file tree
Showing 4 changed files with 10 additions and 47 deletions.
1 change: 1 addition & 0 deletions aequilibrae/project/network/links.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ def __init__(self, net):
self.__table_type__ = "links"
self.__fields = []
self.__items = {}
self.__data = None

if self.sql == "":
self.refresh_fields()
Expand Down
47 changes: 5 additions & 42 deletions aequilibrae/transit/lib_gtfs.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,25 +17,10 @@
from .map_matching_graph import MMGraph
from ..utils.worker_thread import WorkerThread

spec = iutil.find_spec("PyQt5")
pyqt = spec is not None
if pyqt:
from PyQt5.QtCore import pyqtSignal as SignalImpl
else:

class SignalImpl:
def __init__(self, *args, **kwargs):
pass

def emit(*args, **kwargs):
pass


class GTFSRouteSystemBuilder(WorkerThread):
"""Container for GTFS feeds providing data retrieval for the importer"""

signal = SignalImpl(object)

def __init__(self, network, agency_identifier, file_path, day="", description="", capacities={}): # noqa: B006
"""Instantiates a transit class for the network
Expand All @@ -50,7 +35,8 @@ def __init__(self, network, agency_identifier, file_path, day="", description=""
WorkerThread.__init__(self, None)

self.__network = network
self.geotool = get_active_project(False)
self.project = get_active_project(False)
self.geo_links = self.project.network.links.data
self.archive_dir = None # type: str
self.day = day
self.logger = logger
Expand All @@ -70,7 +56,7 @@ def __init__(self, network, agency_identifier, file_path, day="", description=""
self.__do_execute_map_matching = False
self.__target_date__ = None
self.__outside_zones = 0
self.__has_taz = 1 if len(self.geotool.zoning.all_zones()) > 0 else 0
self.__has_taz = 1 if len(self.project.zoning.all_zones()) > 0 else 0

if file_path is not None:
self.logger.info(f"Creating GTFS feed object for {file_path}")
Expand Down Expand Up @@ -140,10 +126,7 @@ def map_match(self, route_types=[3]) -> None: # noqa: B006
if any(not isinstance(item, int) for item in route_types):
raise TypeError("All route types must be integers")

mt = f"Map-matching routes for {self.gtfs_data.agency.agency}"
self.signal.emit(["start", "secondary", len(self.select_patterns.keys()), "Map-matching", mt])
for i, pat in enumerate(self.select_patterns.values()):
self.signal.emit(["update", "secondary", i + 1, "Map-matching", mt])
if pat.route_type in route_types:
pat.map_match()
msg = pat.get_error("stop_from_pattern")
Expand Down Expand Up @@ -202,7 +185,6 @@ def load_date(self, service_date: str) -> None:
def doWork(self):
"""Alias for execute_import"""
self.execute_import()
self.finished()

def execute_import(self):
self.logger.debug("Starting execute_import")
Expand All @@ -215,35 +197,25 @@ def execute_import(self):

self.logger.info(f" Importing feed for agency {self.gtfs_data.agency.agency} on {self.day}")
self.__mt = f"Importing {self.gtfs_data.agency.agency} to supply"
self.signal.emit(["start", "master", 1, self.day, self.__mt])

self.save_to_disk()

def save_to_disk(self):
"""Saves all transit elements built in memory to disk"""

with closing(database_connection("transit")) as conn:
st = f"Importing routes for {self.gtfs_data.agency.agency}"
self.signal.emit(["start", "secondary", len(self.select_routes.keys()), st, self.__mt])
for counter, (_, pattern) in enumerate(self.select_patterns.items()):
pattern.save_to_database(conn, commit=False)
self.signal.emit(["update", "secondary", counter + 1, st, self.__mt])
conn.commit()

self.gtfs_data.agency.save_to_database(conn)

st = f"Importing trips for {self.gtfs_data.agency.agency}"
self.signal.emit(["start", "secondary", len(self.select_trips), st, self.__mt])
for counter, trip in enumerate(self.select_trips):
trip.save_to_database(conn, commit=False)
self.signal.emit(["update", "secondary", counter + 1, st, self.__mt])
conn.commit()

st = f"Importing links for {self.gtfs_data.agency.agency}"
self.signal.emit(["start", "secondary", len(self.select_links.keys()), st, self.__mt])
for counter, (_, link) in enumerate(self.select_links.items()):
link.save_to_database(conn, commit=False)
self.signal.emit(["update", "secondary", counter + 1, st, self.__mt])
conn.commit()

self.__outside_zones = 0
Expand All @@ -263,27 +235,21 @@ def save_to_disk(self):
for fare_rule in self.gtfs_data.fare_rules:
fare_rule.save_to_database(conn)

st = f"Importing stops for {self.gtfs_data.agency.agency}"
self.signal.emit(["start", "secondary", len(self.select_stops.keys()), st, self.__mt])
for counter, (_, stop) in enumerate(self.select_stops.items()):
if stop.zone in zone_ids:
stop.zone_id = zone_ids[stop.zone]
if self.__has_taz:
closest_zone = self.geotool.zoning.get_closest_zone(stop.geo)
if stop.geo.within(self.geotool.zoning.get(closest_zone).geometry):
closest_zone = self.project.zoning.get_closest_zone(stop.geo)
if stop.geo.within(self.project.zoning.get(closest_zone).geometry):
stop.taz = closest_zone
stop.save_to_database(conn, commit=False)
self.signal.emit(["update", "secondary", counter + 1, st, self.__mt])
conn.commit()

self.__outside_zones = None in [x.taz for x in self.select_stops.values()]
if self.__outside_zones:
msg = " Some stops are outside the zoning system. Check the result on a map and see the log for info"
self.logger.warning(msg)

def finished(self):
self.signal.emit(["finished_static_gtfs_procedure"])

def __build_data(self):
self.logger.debug("Starting __build_data")
self.__get_routes_by_date()
Expand All @@ -296,10 +262,7 @@ def __build_data(self):
self.builds_link_graphs_with_broken_stops()

c = Constants()
msg_txt = f"Building data for {self.gtfs_data.agency.agency}"
self.signal.emit(["start", "secondary", len(self.select_routes), msg_txt, self.__mt])
for counter, (route_id, route) in enumerate(self.select_routes.items()):
self.signal.emit(["update", "secondary", counter + 1, msg_txt, self.__mt])
new_trips = self._get_trips_by_date_and_route(route_id, self.day)

all_pats = [trip.pattern_hash for trip in new_trips]
Expand Down
2 changes: 1 addition & 1 deletion aequilibrae/transit/map_matching_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ class MMGraph(WorkerThread):

def __init__(self, lib_gtfs, mtmm):
WorkerThread.__init__(self, None)
self.geotool = lib_gtfs.geotool
self.geotool = lib_gtfs.project
self.stops = lib_gtfs.gtfs_data.stops
self.lib_gtfs = lib_gtfs
self.__mtmm = mtmm
Expand Down
7 changes: 3 additions & 4 deletions aequilibrae/transit/transit_elements/pattern.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,9 +27,9 @@ class Pattern(BasicPTElement):
def __init__(self, route_id, gtfs_feed) -> None:
"""
:Arguments:
*pattern_id* (:obj:`str`): Pre-computed ID for this pattern
*route_id* (:obj:`str`): route ID for which this stop pattern belongs
*geotool* (:obj:`Geo`): Suite of geographic utilities.
*gtfs_feed* (:obj:`Geo`): Parent feed object
"""
self.pattern_hash = ""
self.pattern_id = -1
Expand All @@ -42,8 +42,7 @@ def __init__(self, route_id, gtfs_feed) -> None:
self.seated_capacity = None
self.total_capacity = None
self.__srid = get_srid()
self.__geotool = gtfs_feed.geotool
self.__geolinks = self.__geotool.network.links.data
self.__geolinks = gtfs_feed.geo_links
self.__logger = logger

self.__feed = gtfs_feed
Expand Down

0 comments on commit cd3c223

Please sign in to comment.