diff --git a/gwr_feed.py b/gwr_feed.py index ee95b29..ed66212 100644 --- a/gwr_feed.py +++ b/gwr_feed.py @@ -163,7 +163,7 @@ def get_item_listing(query): fare_price, f"({fare_type_name})"] # 'availablespaces' appears to be defaulted to 9 so we will ignore that - if remaining_seats and remaining_seats != 9: + if query.seats_left and remaining_seats and remaining_seats != 9: fare_text.insert(2, f"({remaining_seats} left)") result_dict[departure_dt] = ' '.join(fare_text) diff --git a/gwr_feed_data.py b/gwr_feed_data.py index 6d505bb..5d75bf5 100644 --- a/gwr_feed_data.py +++ b/gwr_feed_data.py @@ -60,6 +60,8 @@ class _BaseQuery: timestamp: datetime = None weeks_ahead_str: str = '0' weeks_ahead: int = 0 + seats_left_str: str = 'false' + seats_left: bool = False def init_station_ids(self, feed_config): self.from_id = get_station_id(self.from_code, feed_config) @@ -86,6 +88,11 @@ def validate_departure_time(self): if not all(time_rules): self.status.errors.append('Invalid departure time') + def init_seats_left(self): + if self.seats_left_str: + self.seats_left = bool( + self.seats_left_str.lower() in ('true', 'y', 'yes')) + def validate_departure_date(self): if self.date_str: date_rules = [self.date_str.isnumeric(), len(self.date_str) == 8] @@ -109,6 +116,12 @@ def validate_weeks_ahead(self): if not self.weeks_ahead_str.isnumeric(): self.status.errors.append('Invalid week count') + def validate_seats_left(self): + if self.seats_left_str: + if not self.seats_left_str.isalpha(): + self.status.errors.append( + 'seats_left should be either true or false') + @dataclass() class GwrQuery(_BaseQuery): @@ -118,6 +131,7 @@ def __post_init__(self): self.validate_departure_time() self.validate_departure_date() self.validate_weeks_ahead() + self.validate_seats_left() self.status.refresh() if self.status.ok: @@ -125,4 +139,5 @@ def __post_init__(self): self.init_journey() self.init_timestamp() self.init_weeks_ahead() + self.init_seats_left() self.status.refresh() diff --git a/server.py b/server.py index e8ad887..e42efce 100644 --- a/server.py +++ b/server.py @@ -49,7 +49,8 @@ def process_listing(): 'to_code': rq.args.get('to') or GwrQuery.to_code, 'time_str': rq.args.get('at') or GwrQuery.time_str, 'date_str': rq.args.get('on') or GwrQuery.date_str, - 'weeks_ahead_str': rq.args.get('weeks') or GwrQuery.weeks_ahead_str + 'weeks_ahead_str': rq.args.get('weeks') or GwrQuery.weeks_ahead_str, + 'seats_left_str': rq.args.get('seats_left') or GwrQuery.seats_left_str } # access_token expires after 45 mins, get a new token for each query