
Commit 352936e

Add Lint corrections for GTFS RT validator
1 parent 1f2d510 commit 352936e

File tree

1 file changed: +98 -36 lines


jobs/gtfs-rt-parser-v2/gtfs_rt_parser.py

Lines changed: 98 additions & 36 deletions
@@ -285,7 +285,11 @@ def __init__(self, fs: gcsfs.GCSFileSystem, path: str, base64_validation_url: st
     def download(self, date: datetime.datetime) -> str:
         for day in reversed(list(date - date.subtract(days=7))):
             try:
-                schedule_extract = ScheduleStorage().get_day(day).get_url_schedule(self.base64_validation_url)
+                schedule_extract = (
+                    ScheduleStorage()
+                    .get_day(day)
+                    .get_url_schedule(self.base64_validation_url)
+                )
             except KeyError:
                 print(
                     f"no schedule data found for {self.base64_validation_url} on day {day}"
@@ -312,10 +316,14 @@ def get_local_path(self) -> str:
         return os.path.join(self.path, self.extract.timestamped_filename)

     def get_results_path(self) -> str:
-        return os.path.join(self.path, f"{self.extract.timestamped_filename}.results.json")
+        return os.path.join(
+            self.path, f"{self.extract.timestamped_filename}.results.json"
+        )

     def hash(self) -> str:
-        with open(os.path.join(self.path, self.extract.timestamped_filename), "rb") as f:
+        with open(
+            os.path.join(self.path, self.extract.timestamped_filename), "rb"
+        ) as f:
             file_hash = hashlib.md5()
             while chunk := f.read(8192):
                 file_hash.update(chunk)
@@ -330,7 +338,9 @@ def has_results(self) -> bool:


 class AggregationExtracts:
-    def __init__(self, fs: gcsfs.GCSFileSystem, path: str, aggregation: RTHourlyAggregation):
+    def __init__(
+        self, fs: gcsfs.GCSFileSystem, path: str, aggregation: RTHourlyAggregation
+    ):
         self.fs = fs
         self.path = path
         self.aggregation = aggregation
@@ -339,7 +349,9 @@ def get_path(self):
         return f"{self.path}/rt_{self.aggregation.name_hash}/"

     def get_extracts(self) -> List[AggregationExtract]:
-        return [AggregationExtract(self.get_path(), e) for e in self.aggregation.extracts]
+        return [
+            AggregationExtract(self.get_path(), e) for e in self.aggregation.extracts
+        ]

     def get_local_paths(self) -> Dict[str, GTFSRTFeedExtract]:
         return {e.get_local_path(): e.extract for e in self.get_extracts()}
@@ -362,38 +374,50 @@ def get_hashes(self) -> Dict[str, List[GTFSRTFeedExtract]]:

     def download(self):
         self.fs.get(
-            rpath=[
-                extract.path
-                for extract in self.get_local_paths().values()
-            ],
+            rpath=[extract.path for extract in self.get_local_paths().values()],
             lpath=list(self.get_local_paths().keys()),
         )

     def download_most_recent_schedule(self) -> str:
         first_extract = self.aggregation.extracts[0]
-        schedule = MostRecentSchedule(self.fs, self.path, first_extract.config.base64_validation_url)
+        schedule = MostRecentSchedule(
+            self.fs, self.path, first_extract.config.base64_validation_url
+        )
         return schedule.download(first_extract.dt)


 class HourlyFeedQuery:
-    def __init__(self, step: RTProcessingStep, feed_type: GTFSFeedType, files: List[GTFSRTFeedExtract], limit: int = 0, base64_url: Optional[str] = None):
+    def __init__(
+        self,
+        step: RTProcessingStep,
+        feed_type: GTFSFeedType,
+        files: List[GTFSRTFeedExtract],
+        limit: int = 0,
+        base64_url: Optional[str] = None,
+    ):
         self.step = step
         self.feed_type = feed_type
         self.files = files
         self.limit = limit
         self.base64_url = base64_url

     def set_limit(self, limit: int):
-        return HourlyFeedQuery(self.step, self.feed_type, self.files, limit, self.base64_url)
+        return HourlyFeedQuery(
+            self.step, self.feed_type, self.files, limit, self.base64_url
+        )

     def where_base64url(self, base64_url: str):
-        return HourlyFeedQuery(self.step, self.feed_type, self.files, self.limit, base64_url)
-
-    def get_aggregates(self) -> Dict[Tuple[pendulum.DateTime, str], List[GTFSRTFeedExtract]]:
-        aggregates: Dict[Tuple[pendulum.DateTime, str], List[GTFSRTFeedExtract]] = defaultdict(
-            list
+        return HourlyFeedQuery(
+            self.step, self.feed_type, self.files, self.limit, base64_url
         )

+    def get_aggregates(
+        self,
+    ) -> Dict[Tuple[pendulum.DateTime, str], List[GTFSRTFeedExtract]]:
+        aggregates: Dict[
+            Tuple[pendulum.DateTime, str], List[GTFSRTFeedExtract]
+        ] = defaultdict(list)
+
         for file in self.files:
             if self.base64_url is None or file.base64_url == self.base64_url:
                 aggregates[(file.hour, file.base64_url)].append(file)
@@ -416,18 +440,29 @@ def total(self) -> int:


 class HourlyFeedFiles:
-    def __init__(self, files: List[GTFSRTFeedExtract], files_missing_metadata: List[Blob], files_invalid_metadata: List[Blob]):
+    def __init__(
+        self,
+        files: List[GTFSRTFeedExtract],
+        files_missing_metadata: List[Blob],
+        files_invalid_metadata: List[Blob],
+    ):
         self.files = files
         self.files_missing_metadata = files_missing_metadata
         self.files_invalid_metadata = files_invalid_metadata

     def total(self) -> int:
-        return len(self.files) + len(self.files_missing_metadata) + len(self.files_invalid_metadata)
+        return (
+            len(self.files)
+            + len(self.files_missing_metadata)
+            + len(self.files_invalid_metadata)
+        )

     def valid(self) -> bool:
         return not self.files or len(self.files) / self.total() > 0.99

-    def get_query(self, step: RTProcessingStep, feed_type: GTFSFeedType) -> HourlyFeedQuery:
+    def get_query(
+        self, step: RTProcessingStep, feed_type: GTFSFeedType
+    ) -> HourlyFeedQuery:
         return HourlyFeedQuery(step, feed_type, self.files)

@@ -451,12 +486,19 @@ def get_hour(self, hour: datetime.datetime) -> HourlyFeedFiles:


 class ValidationProcessor:
-    def __init__(self, aggregation: RTHourlyAggregation, validator: RtValidator, verbose: bool = False):
+    def __init__(
+        self,
+        aggregation: RTHourlyAggregation,
+        validator: RtValidator,
+        verbose: bool = False,
+    ):
         self.aggregation = aggregation
         self.validator = validator
         self.verbose = verbose

-    def process(self, tmp_dir: tempfile.TemporaryDirectory, scope) -> List[RTFileProcessingOutcome]:
+    def process(
+        self, tmp_dir: tempfile.TemporaryDirectory, scope
+    ) -> List[RTFileProcessingOutcome]:
         outcomes: List[RTFileProcessingOutcome] = []
         fs = get_fs()

@@ -498,7 +540,9 @@ def process(self, tmp_dir: tempfile.TemporaryDirectory, scope) -> List[RTFilePro
             fingerprint: List[Any] = [
                 type(e),
                 # convert back to url manually, I don't want to mess around with the hourly class
-                base64.urlsafe_b64decode(self.aggregation.base64_url.encode()).decode(),
+                base64.urlsafe_b64decode(
+                    self.aggregation.base64_url.encode()
+                ).decode(),
             ]
             fingerprint.append(e.returncode)

@@ -509,9 +553,7 @@ def process(self, tmp_dir: tempfile.TemporaryDirectory, scope) -> List[RTFilePro

             scope.fingerprint = fingerprint

             # get the end of stderr, just enough to fit in MAX_STRING_LENGTH defined above
-            scope.set_context(
-                "Process", {"stderr": stderr[-2000:]}
-            )
+            scope.set_context("Process", {"stderr": stderr[-2000:]})

             sentry_sdk.capture_exception(e, scope=scope)

@@ -581,10 +623,13 @@ def process(self, tmp_dir: tempfile.TemporaryDirectory, scope) -> List[RTFilePro
                 typer.secho(
                     f"writing {len(records_to_upload)} lines to {self.aggregation.path}",
                 )
-                with tempfile.NamedTemporaryFile(mode="wb", delete=False, dir=tmp_dir) as f:
+                with tempfile.NamedTemporaryFile(
+                    mode="wb", delete=False, dir=tmp_dir
+                ) as f:
                     gzipfile = gzip.GzipFile(mode="wb", fileobj=f)
                     encoded = (
-                        r.json() if isinstance(r, BaseModel) else json.dumps(r) for r in records_to_upload
+                        r.json() if isinstance(r, BaseModel) else json.dumps(r)
+                        for r in records_to_upload
                     )
                     gzipfile.write("\n".join(encoded).encode("utf-8"))
                     gzipfile.close()
@@ -604,14 +649,18 @@ def __init__(self, aggregation: RTHourlyAggregation, verbose: bool = False):
         self.aggregation = aggregation
         self.verbose = verbose

-    def process(self, tmp_dir: tempfile.TemporaryDirectory, scope) -> List[RTFileProcessingOutcome]:
+    def process(
+        self, tmp_dir: tempfile.TemporaryDirectory, scope
+    ) -> List[RTFileProcessingOutcome]:
         outcomes: List[RTFileProcessingOutcome] = []
         fs = get_fs()
         dst_path_rt = f"{tmp_dir}/rt_{self.aggregation.name_hash}/"
         fs.get(
             rpath=[
                 extract.path
-                for extract in self.aggregation.local_paths_to_extract(dst_path_rt).values()
+                for extract in self.aggregation.local_paths_to_extract(
+                    dst_path_rt
+                ).values()
             ],
             lpath=list(self.aggregation.local_paths_to_extract(dst_path_rt).keys()),
         )
@@ -738,15 +787,23 @@ def parse_and_validate(
     outcomes = []
     with tempfile.TemporaryDirectory() as tmp_dir:
         with sentry_sdk.push_scope() as scope:
-            scope.set_tag("config_feed_type", aggregation.first_extract.config.feed_type)
+            scope.set_tag(
+                "config_feed_type", aggregation.first_extract.config.feed_type
+            )
             scope.set_tag("config_name", aggregation.first_extract.config.name)
             scope.set_tag("config_url", aggregation.first_extract.config.url)
             scope.set_context("RT Hourly Aggregation", json.loads(aggregation.json()))

-            if aggregation.step != RTProcessingStep.validate and aggregation.step != RTProcessingStep.parse:
+            if (
+                aggregation.step != RTProcessingStep.validate
+                and aggregation.step != RTProcessingStep.parse
+            ):
                 raise RuntimeError("we should not be here")

-            if aggregation.step == RTProcessingStep.validate and not aggregation.extracts[0].config.schedule_url_for_validation:
+            if (
+                aggregation.step == RTProcessingStep.validate
+                and not aggregation.extracts[0].config.schedule_url_for_validation
+            ):
                 outcomes = [
                     RTFileProcessingOutcome(
                         step=aggregation.step,
@@ -758,7 +815,9 @@ def parse_and_validate(
             ]

             if aggregation.step == RTProcessingStep.validate:
-                outcomes = ValidationProcessor(aggregation, validator, verbose).process(tmp_dir, scope)
+                outcomes = ValidationProcessor(aggregation, validator, verbose).process(
+                    tmp_dir, scope
+                )

             if aggregation.step == RTProcessingStep.parse:
                 outcomes = ParseProcessor(aggregation, verbose).process(tmp_dir, scope)
@@ -801,7 +860,9 @@ def main(
             f"too many files have missing/invalid metadata; {total - len(files)} of {total}" # noqa: E702
         )
     aggregated_feed = hourly_feed_files.get_query(step, feed_type)
-    aggregations_to_process = aggregated_feed.where_base64url(base64url).set_limit(limit).get_aggregates()
+    aggregations_to_process = (
+        aggregated_feed.where_base64url(base64url).set_limit(limit).get_aggregates()
+    )

     typer.secho(
         f"found {len(hourly_feed_files.files)} {feed_type} files in {len(aggregated_feed.get_aggregates())} aggregations to process",
@@ -892,7 +953,8 @@ def main(
         )

     assert (
-        len(outcomes) == aggregated_feed.where_base64url(base64url).set_limit(limit).total()
+        len(outcomes)
+        == aggregated_feed.where_base64url(base64url).set_limit(limit).total()
     ), f"we ended up with {len(outcomes)} outcomes from {aggregated_feed.where_base64url(base64url).set_limit(limit).total()}"

     if exceptions:
