Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow partial reads and align with GTFS specifications for stop_lat, stop_lon, and transfer_type #68

Open
wants to merge 10 commits into
base: master
Choose a base branch
from
Open
62 changes: 35 additions & 27 deletions pygtfs/gtfs_entities.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ def _validate_date(*field_names):
@validates(*field_names)
def make_date(self, key, value):
return datetime.datetime.strptime(value, '%Y%m%d').date()

return make_date


Expand All @@ -37,6 +38,7 @@ def time_delta(self, key, value):
(hours, minutes, seconds) = map(int, value.split(":"))
return datetime.timedelta(hours=hours, minutes=minutes,
seconds=seconds)

return time_delta


Expand All @@ -47,6 +49,7 @@ def int_bool(self, key, value):
raise PygtfsValidationError("{0} must be 0 or 1, "
"was {1}".format(key, value))
return value == "1"

return int_bool


Expand All @@ -64,18 +67,22 @@ def in_range(self, key, value):
raise PygtfsValidationError(
"{0} must be in range {1}, was {2}".format(key, int_choice, value))
return int_value

return in_range


def _validate_float_range(float_min, float_max, *field_names):
def _validate_float_range(float_min, float_max, nullable, *field_names):
    """Build a SQLAlchemy validator that coerces fields to a bounded float.

    Args:
        float_min: Smallest accepted value (inclusive).
        float_max: Largest accepted value (inclusive).
        nullable: When true, a missing value (``None`` or an empty string)
            is stored as ``None`` instead of being parsed.
        *field_names: Names of the mapped attributes to validate.

    Returns:
        The validator function, registered for ``field_names`` through
        ``@validates``.

    Raises:
        PygtfsValidationError: (from the returned validator) when the parsed
            value lies outside ``[float_min, float_max]``.
    """
    @validates(*field_names)
    def in_range(self, key, value):
        # Use the same "missing" test as _validate_float_none: an empty CSV
        # cell may arrive as "" rather than None, and float("") would raise.
        if nullable and (value is None or value == ""):
            return None
        float_value = float(value)
        if not (float_min <= float_value <= float_max):
            # {3} is the offending value; the previous "{2}" repeated the
            # upper bound and hid what was actually rejected.
            raise PygtfsValidationError(
                "{0} must be in range [{1}, {2}],"
                " was {3}".format(key, float_min, float_max, value))
        return float_value

    return in_range


Expand All @@ -85,6 +92,7 @@ def is_float_none(self, key, value):
if value is None or value == "":
return None
return float(value)

return is_float_none


Expand Down Expand Up @@ -145,8 +153,9 @@ class Stop(Base):
stop_code = Column(Unicode, nullable=True, index=True)
stop_name = Column(Unicode)
stop_desc = Column(Unicode, nullable=True)
stop_lat = Column(Float)
stop_lon = Column(Float)
nullable_lat_long = True
stop_lat = Column(Float, nullable=nullable_lat_long)
stop_lon = Column(Float, nullable=nullable_lat_long)
zone_id = Column(Unicode, nullable=True)
stop_url = Column(Unicode, nullable=True)
location_type = Column(Integer, nullable=True)
Expand All @@ -162,8 +171,8 @@ class Stop(Base):
_validate_location = _validate_int_choice([None, 0, 1, 2, 3, 4], 'location_type')
_validate_wheelchair = _validate_int_choice([None, 0, 1, 2],
'wheelchair_boarding')
_validate_lon_lat = _validate_float_range(-180, 180, 'stop_lon',
'stop_lat')
_validate_lon_lat = _validate_float_range(-180, 180, nullable_lat_long,
'stop_lon', 'stop_lat')

def __repr__(self):
return '<Stop %s: %s>' % (self.stop_id, self.stop_name)
Expand All @@ -189,8 +198,8 @@ class Route(Base):
)

agency = relationship(Agency, backref="routes",
primaryjoin=and_(Agency.agency_id==foreign(agency_id),
Agency.feed_id==feed_id))
primaryjoin=and_(Agency.agency_id == foreign(agency_id),
Agency.feed_id == feed_id))

# https://developers.google.com/transit/gtfs/reference/extended-route-types
valid_extended_route_types = [
Expand Down Expand Up @@ -227,6 +236,7 @@ class ShapePoint(Base):
_plural_name_ = 'shapes'
feed_id = Column(Integer, ForeignKey('_feed.feed_id'), primary_key=True)
shape_id = Column(Unicode, primary_key=True)
nullable_lat_long = False
shape_pt_lat = Column(Float)
shape_pt_lon = Column(Float)
shape_pt_sequence = Column(Integer, primary_key=True)
Expand All @@ -236,7 +246,7 @@ class ShapePoint(Base):
Index('idx_shape_for_trips', feed_id, shape_id),
)

_validate_lon_lat = _validate_float_range(-180, 180,
_validate_lon_lat = _validate_float_range(-180, 180, nullable_lat_long,
'shape_pt_lon', 'shape_pt_lat')
_validate_shape_dist_traveled = _validate_float_none('shape_dist_traveled')

Expand Down Expand Up @@ -316,18 +326,17 @@ class Trip(Base):
)

route = relationship(Route, backref="trips",
primaryjoin=and_(Route.route_id==foreign(route_id),
Route.feed_id==feed_id))
primaryjoin=and_(Route.route_id == foreign(route_id),
Route.feed_id == feed_id))

shape_points = relationship(ShapePoint, backref="trips",
secondary="_trip_shapes")
secondary="_trip_shapes")

# TODO: The service_id references to calendar or to calendar_dates.
# Need to implement this requirement, but not using a simple foreign key.
service = relationship(Service, backref='trips',
primaryjoin=and_(foreign(service_id) == Service.service_id,
feed_id == Service.feed_id))

primaryjoin=and_(foreign(service_id) == Service.service_id,
feed_id == Service.feed_id))

_validate_direction_id = _validate_int_choice([None, 0, 1], 'direction_id')
_validate_wheelchair = _validate_int_choice([None, 0, 1, 2],
Expand Down Expand Up @@ -374,11 +383,11 @@ class StopTime(Base):
)

stop = relationship(Stop, backref='stop_times',
primaryjoin=and_(Stop.stop_id==foreign(stop_id),
Stop.feed_id==feed_id))
primaryjoin=and_(Stop.stop_id == foreign(stop_id),
Stop.feed_id == feed_id))
trip = relationship(Trip, backref="stop_times",
primaryjoin=and_(Trip.trip_id==foreign(trip_id),
Trip.feed_id==feed_id))
primaryjoin=and_(Trip.trip_id == foreign(trip_id),
Trip.feed_id == feed_id))

_validate_pickup_drop_off = _validate_int_choice([None, 0, 1, 2, 3],
'pickup_type',
Expand Down Expand Up @@ -432,8 +441,8 @@ class FareRule(Base):
)

route = relationship(Route, backref="fare_rules",
primaryjoin=and_(Route.route_id==foreign(route_id),
Route.feed_id==feed_id))
primaryjoin=and_(Route.route_id == foreign(route_id),
Route.feed_id == feed_id))

def __repr__(self):
return '<FareRule %s: %s %s %s %s>' % (self.fare_id,
Expand All @@ -458,8 +467,8 @@ class Frequency(Base):
)

trip = relationship(Trip, backref="frequencies",
primaryjoin=and_(Trip.trip_id==foreign(trip_id),
Trip.feed_id==feed_id))
primaryjoin=and_(Trip.trip_id == foreign(trip_id),
Trip.feed_id == feed_id))

_validate_exact_times = _validate_int_choice([None, 0, 1], 'exact_times')
_validate_deltas = _validate_time_delta('start_time', 'end_time')
Expand Down Expand Up @@ -507,7 +516,7 @@ class Transfer(Base):
primaryjoin=and_(Trip.trip_id == foreign(to_trip_id),
Trip.feed_id == feed_id))

_validate_transfer_type = _validate_int_choice([None, 0, 1, 2, 3],
_validate_transfer_type = _validate_int_choice([None, 0, 1, 2, 3, 4, 5],
'transfer_type')

def __repr__(self):
Expand Down Expand Up @@ -543,10 +552,10 @@ def __repr__(self):
Column('trans_id', Unicode),
Column('lang', Unicode),
ForeignKeyConstraint(['stop_feed_id', 'stop_id'], [Stop.feed_id, Stop.stop_id]),
ForeignKeyConstraint(['translation_feed_id', 'trans_id', 'lang'], [Translation.feed_id, Translation.trans_id, Translation.lang]),
ForeignKeyConstraint(['translation_feed_id', 'trans_id', 'lang'],
[Translation.feed_id, Translation.trans_id, Translation.lang]),
)


_trip_shapes = Table(
'_trip_shapes', Base.metadata,
Column('trip_feed_id', Integer),
Expand All @@ -556,10 +565,9 @@ def __repr__(self):
Column('shape_pt_sequence', Integer),
ForeignKeyConstraint(['trip_feed_id', 'trip_id'], [Trip.feed_id, Trip.trip_id]),
ForeignKeyConstraint(['shape_feed_id', 'shape_id', 'shape_pt_sequence'],
[ShapePoint.feed_id, ShapePoint.shape_id, ShapePoint.shape_pt_sequence]),
[ShapePoint.feed_id, ShapePoint.shape_id, ShapePoint.shape_pt_sequence]),
)


# a feed can skip Service (calendar) if it has ServiceException(calendar_dates)
gtfs_required = {Agency, Stop, Route, Trip, StopTime}
gtfs_calendar = {Service, ServiceException}
Expand Down
29 changes: 15 additions & 14 deletions pygtfs/loader.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,16 @@
from __future__ import (division, absolute_import, print_function,
unicode_literals)

from datetime import date
import sys
from datetime import date

import six
from sqlalchemy import and_
from sqlalchemy.sql.expression import select, join

from .gtfs_entities import (Feed, Service, ServiceException, gtfs_required,
from . import feed
from .exceptions import PygtfsException
from .gtfs_entities import (Feed, gtfs_required,
Translation, Stop, Trip, ShapePoint, _stop_translations,
_trip_shapes, gtfs_calendar, gtfs_all)
from . import feed


def list_feeds(schedule):
Expand All @@ -21,7 +20,6 @@ def list_feeds(schedule):


def delete_feed(schedule, feed_filename, interactive=False):

feed_name = feed.derive_feed_name(feed_filename)
feeds_with_name = schedule.session.query(Feed).filter(Feed.feed_name == feed_name).all()
delete_all = not interactive
Expand All @@ -46,8 +44,7 @@ def overwrite_feed(schedule, feed_filename, *args, **kwargs):


def append_feed(schedule, feed_filename, strip_fields=True,
chunk_size=5000, agency_id_override=None):

chunk_size=5000, agency_id_override=None, ignore_failures=False):
fd = feed.Feed(feed_filename, strip_fields)

gtfs_tables = {}
Expand Down Expand Up @@ -77,24 +74,28 @@ def append_feed(schedule, feed_filename, strip_fields=True,
continue
gtfs_table = gtfs_tables[gtfs_class]


skipped_records = 0
read_records = 0
for i, record in enumerate(gtfs_table):
if not record:
# Empty row.
continue

try:
instance = gtfs_class(feed_id=feed_id, **record._asdict())
schedule.session.add(instance)
read_records += 1
except:
print("Failure while writing {0}".format(record))
raise
schedule.session.add(instance)
skipped_records += 1
print("Failure while writing {}".format(record))
if not ignore_failures:
raise
if i % chunk_size == 0 and i > 0:
schedule.session.flush()
sys.stdout.write('.')
sys.stdout.flush()
print('%d record%s read for %s.' % ((i+1), '' if i == 0 else 's',
gtfs_class))
print('{0} records read for {1}'.format(read_records, gtfs_class))
InterferencePattern marked this conversation as resolved.
Show resolved Hide resolved
print('{0} records skipped for {1}'.format(skipped_records, gtfs_class))
schedule.session.flush()
schedule.session.commit()
# load many to many relationships
Expand Down