diff --git a/aequilibrae/transit/transit_elements/pattern.py b/aequilibrae/transit/transit_elements/pattern.py index b5a3b0dd2..539d62a08 100644 --- a/aequilibrae/transit/transit_elements/pattern.py +++ b/aequilibrae/transit/transit_elements/pattern.py @@ -56,7 +56,7 @@ def __init__(self, route_id, gtfs_feed) -> None: self.network_candidates = [] self.full_path: List[int] = [] self.fpath_dir: List[int] = [] - self.pattern_mapping = [] + self.pattern_mapping = pd.DataFrame([]) self.stops = [] self.__map_matching_error = {} @@ -94,13 +94,14 @@ def save_to_database(self, conn: Connection, commit=True) -> None: total_capacity, geometry) values (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ST_Multi(GeomFromWKB(?, ?)));""" conn.execute(sql, data) - if self.pattern_mapping and self.shape: + if self.pattern_mapping.shape[0]: sqlgeo = """insert into pattern_mapping (pattern_id, seq, link, dir, geometry) values (?, ?, ?, ?, GeomFromWKB(?, ?));""" sql = "insert into pattern_mapping (pattern_id, seq, link, dir) values (?, ?, ?, ?);" if "wkb" in self.pattern_mapping.columns: - data = self.pattern_mapping[["pattern_id", "seq", "link_id", "dir", "wkb", "srid"]].to_records() + cols = ["pattern_id", "seq", "link_id", "dir", "wkb", "srid"] + data = self.pattern_mapping[cols].to_records(index=False) conn.executemany(sqlgeo, data) else: data = self.pattern_mapping[["pattern_id", "seq", "link_id", "dir"]].to_records()