diff --git a/pygtfs/gtfs_entities.py b/pygtfs/gtfs_entities.py index 8244006..d8bd42b 100644 --- a/pygtfs/gtfs_entities.py +++ b/pygtfs/gtfs_entities.py @@ -26,6 +26,7 @@ def _validate_date(*field_names): @validates(*field_names) def make_date(self, key, value): return datetime.datetime.strptime(value, '%Y%m%d').date() + return make_date @@ -37,6 +38,7 @@ def time_delta(self, key, value): (hours, minutes, seconds) = map(int, value.split(":")) return datetime.timedelta(hours=hours, minutes=minutes, seconds=seconds) + return time_delta @@ -47,6 +49,7 @@ def int_bool(self, key, value): raise PygtfsValidationError("{0} must be 0 or 1, " "was {1}".format(key, value)) return value == "1" + return int_bool @@ -64,18 +67,22 @@ def in_range(self, key, value): raise PygtfsValidationError( "{0} must be in range {1}, was {2}".format(key, int_choice, value)) return int_value + return in_range -def _validate_float_range(float_min, float_max, *field_names): +def _validate_float_range(float_min, float_max, nullable, *field_names): @validates(*field_names) def in_range(self, key, value): + if nullable and value is None: + return None float_value = float(value) if not (float_min <= float_value <= float_max): raise PygtfsValidationError( "{0} must be in range [{1}, {2}]," " was {2}".format(key, float_min, float_max, value)) return float_value + return in_range @@ -85,6 +92,7 @@ def is_float_none(self, key, value): if value is None or value == "": return None return float(value) + return is_float_none @@ -145,8 +153,9 @@ class Stop(Base): stop_code = Column(Unicode, nullable=True, index=True) stop_name = Column(Unicode) stop_desc = Column(Unicode, nullable=True) - stop_lat = Column(Float) - stop_lon = Column(Float) + nullable_lat_long = True + stop_lat = Column(Float, nullable=nullable_lat_long) + stop_lon = Column(Float, nullable=nullable_lat_long) zone_id = Column(Unicode, nullable=True) stop_url = Column(Unicode, nullable=True) location_type = Column(Integer, nullable=True) @@ -162,8 +171,8 @@ class Stop(Base): _validate_location = _validate_int_choice([None, 0, 1, 2, 3, 4], 'location_type') _validate_wheelchair = _validate_int_choice([None, 0, 1, 2], 'wheelchair_boarding') - _validate_lon_lat = _validate_float_range(-180, 180, 'stop_lon', - 'stop_lat') + _validate_lon_lat = _validate_float_range(-180, 180, nullable_lat_long, + 'stop_lon', 'stop_lat') def __repr__(self): return '' % (self.stop_id, self.stop_name) @@ -189,8 +198,8 @@ class Route(Base): ) agency = relationship(Agency, backref="routes", - primaryjoin=and_(Agency.agency_id==foreign(agency_id), - Agency.feed_id==feed_id)) + primaryjoin=and_(Agency.agency_id == foreign(agency_id), + Agency.feed_id == feed_id)) # https://developers.google.com/transit/gtfs/reference/extended-route-types valid_extended_route_types = [ @@ -227,6 +236,7 @@ class ShapePoint(Base): _plural_name_ = 'shapes' feed_id = Column(Integer, ForeignKey('_feed.feed_id'), primary_key=True) shape_id = Column(Unicode, primary_key=True) + nullable_lat_long = False shape_pt_lat = Column(Float) shape_pt_lon = Column(Float) shape_pt_sequence = Column(Integer, primary_key=True) @@ -236,7 +246,7 @@ class ShapePoint(Base): Index('idx_shape_for_trips', feed_id, shape_id), ) - _validate_lon_lat = _validate_float_range(-180, 180, + _validate_lon_lat = _validate_float_range(-180, 180, nullable_lat_long, 'shape_pt_lon', 'shape_pt_lat') _validate_shape_dist_traveled = _validate_float_none('shape_dist_traveled') @@ -316,18 +326,17 @@ class Trip(Base): ) route = relationship(Route, backref="trips", - primaryjoin=and_(Route.route_id==foreign(route_id), - Route.feed_id==feed_id)) + primaryjoin=and_(Route.route_id == foreign(route_id), + Route.feed_id == feed_id)) shape_points = relationship(ShapePoint, backref="trips", - secondary="_trip_shapes") + secondary="_trip_shapes") # TODO: The service_id references to calendar or to calendar_dates. # Need to implement this requirement, but not using a simple foreign key. service = relationship(Service, backref='trips', - primaryjoin=and_(foreign(service_id) == Service.service_id, - feed_id == Service.feed_id)) - + primaryjoin=and_(foreign(service_id) == Service.service_id, + feed_id == Service.feed_id)) _validate_direction_id = _validate_int_choice([None, 0, 1], 'direction_id') _validate_wheelchair = _validate_int_choice([None, 0, 1, 2], @@ -374,11 +383,11 @@ class StopTime(Base): ) stop = relationship(Stop, backref='stop_times', - primaryjoin=and_(Stop.stop_id==foreign(stop_id), - Stop.feed_id==feed_id)) + primaryjoin=and_(Stop.stop_id == foreign(stop_id), + Stop.feed_id == feed_id)) trip = relationship(Trip, backref="stop_times", - primaryjoin=and_(Trip.trip_id==foreign(trip_id), - Trip.feed_id==feed_id)) + primaryjoin=and_(Trip.trip_id == foreign(trip_id), + Trip.feed_id == feed_id)) _validate_pickup_drop_off = _validate_int_choice([None, 0, 1, 2, 3], 'pickup_type', @@ -432,8 +441,8 @@ class FareRule(Base): ) route = relationship(Route, backref="fare_rules", - primaryjoin=and_(Route.route_id==foreign(route_id), - Route.feed_id==feed_id)) + primaryjoin=and_(Route.route_id == foreign(route_id), + Route.feed_id == feed_id)) def __repr__(self): return '' % (self.fare_id, @@ -458,8 +467,8 @@ class Frequency(Base): ) trip = relationship(Trip, backref="frequencies", - primaryjoin=and_(Trip.trip_id==foreign(trip_id), - Trip.feed_id==feed_id)) + primaryjoin=and_(Trip.trip_id == foreign(trip_id), + Trip.feed_id == feed_id)) _validate_exact_times = _validate_int_choice([None, 0, 1], 'exact_times') _validate_deltas = _validate_time_delta('start_time', 'end_time') @@ -507,7 +516,7 @@ class Transfer(Base): primaryjoin=and_(Trip.trip_id == foreign(to_trip_id), Trip.feed_id == feed_id)) - _validate_transfer_type = _validate_int_choice([None, 0, 1, 2, 3], + _validate_transfer_type = _validate_int_choice([None, 0, 1, 2, 3, 4, 5], 'transfer_type') def __repr__(self): @@ -543,10 +552,10 @@ def __repr__(self): Column('trans_id', Unicode), Column('lang', Unicode), ForeignKeyConstraint(['stop_feed_id', 'stop_id'], [Stop.feed_id, Stop.stop_id]), - ForeignKeyConstraint(['translation_feed_id', 'trans_id', 'lang'], [Translation.feed_id, Translation.trans_id, Translation.lang]), + ForeignKeyConstraint(['translation_feed_id', 'trans_id', 'lang'], + [Translation.feed_id, Translation.trans_id, Translation.lang]), ) - _trip_shapes = Table( '_trip_shapes', Base.metadata, Column('trip_feed_id', Integer), @@ -556,10 +565,9 @@ def __repr__(self): Column('shape_pt_sequence', Integer), ForeignKeyConstraint(['trip_feed_id', 'trip_id'], [Trip.feed_id, Trip.trip_id]), ForeignKeyConstraint(['shape_feed_id', 'shape_id', 'shape_pt_sequence'], - [ShapePoint.feed_id, ShapePoint.shape_id, ShapePoint.shape_pt_sequence]), + [ShapePoint.feed_id, ShapePoint.shape_id, ShapePoint.shape_pt_sequence]), ) - # a feed can skip Service (calendar) if it has ServiceException(calendar_dates) gtfs_required = {Agency, Stop, Route, Trip, StopTime} gtfs_calendar = {Service, ServiceException} diff --git a/pygtfs/loader.py b/pygtfs/loader.py index 9905ff2..1972bc8 100644 --- a/pygtfs/loader.py +++ b/pygtfs/loader.py @@ -1,17 +1,16 @@ from __future__ import (division, absolute_import, print_function, unicode_literals) -from datetime import date import sys +from datetime import date import six -from sqlalchemy import and_ -from sqlalchemy.sql.expression import select, join -from .gtfs_entities import (Feed, Service, ServiceException, gtfs_required, +from . import feed +from .exceptions import PygtfsException +from .gtfs_entities import (Feed, gtfs_required, Translation, Stop, Trip, ShapePoint, _stop_translations, _trip_shapes, gtfs_calendar, gtfs_all) -from . import feed def list_feeds(schedule): @@ -21,7 +20,6 @@ def list_feeds(schedule): def delete_feed(schedule, feed_filename, interactive=False): - feed_name = feed.derive_feed_name(feed_filename) feeds_with_name = schedule.session.query(Feed).filter(Feed.feed_name == feed_name).all() delete_all = not interactive @@ -46,8 +44,7 @@ def overwrite_feed(schedule, feed_filename, *args, **kwargs): def append_feed(schedule, feed_filename, strip_fields=True, - chunk_size=5000, agency_id_override=None): - + chunk_size=5000, agency_id_override=None, ignore_failures=False): fd = feed.Feed(feed_filename, strip_fields) gtfs_tables = {} @@ -77,7 +74,8 @@ def append_feed(schedule, feed_filename, strip_fields=True, continue gtfs_table = gtfs_tables[gtfs_class] - + skipped_records = 0 + read_records = 0 for i, record in enumerate(gtfs_table): if not record: # Empty row. @@ -85,16 +83,19 @@ def append_feed(schedule, feed_filename, strip_fields=True, try: instance = gtfs_class(feed_id=feed_id, **record._asdict()) + schedule.session.add(instance) + read_records += 1 except: - print("Failure while writing {0}".format(record)) - raise - schedule.session.add(instance) + skipped_records += 1 + print("Failure while writing {}".format(record)) + if not ignore_failures: + raise if i % chunk_size == 0 and i > 0: schedule.session.flush() sys.stdout.write('.') sys.stdout.flush() - print('%d record%s read for %s.' % ((i+1), '' if i == 0 else 's', - gtfs_class)) + print('{0} records read for {1}'.format(read_records, gtfs_class)) + print('{0} records skipped for {1}'.format(skipped_records, gtfs_class)) schedule.session.flush() schedule.session.commit() # load many to many relationships