@@ -285,7 +285,11 @@ def __init__(self, fs: gcsfs.GCSFileSystem, path: str, base64_validation_url: st
     def download(self, date: datetime.datetime) -> str:
         for day in reversed(list(date - date.subtract(days=7))):
             try:
-                schedule_extract = ScheduleStorage().get_day(day).get_url_schedule(self.base64_validation_url)
+                schedule_extract = (
+                    ScheduleStorage()
+                    .get_day(day)
+                    .get_url_schedule(self.base64_validation_url)
+                )
             except KeyError:
                 print(
                     f"no schedule data found for {self.base64_validation_url} on day {day}"
@@ -312,10 +316,14 @@ def get_local_path(self) -> str:
         return os.path.join(self.path, self.extract.timestamped_filename)

     def get_results_path(self) -> str:
-        return os.path.join(self.path, f"{self.extract.timestamped_filename}.results.json")
+        return os.path.join(
+            self.path, f"{self.extract.timestamped_filename}.results.json"
+        )

     def hash(self) -> str:
-        with open(os.path.join(self.path, self.extract.timestamped_filename), "rb") as f:
+        with open(
+            os.path.join(self.path, self.extract.timestamped_filename), "rb"
+        ) as f:
             file_hash = hashlib.md5()
             while chunk := f.read(8192):
                 file_hash.update(chunk)
@@ -330,7 +338,9 @@ def has_results(self) -> bool:


 class AggregationExtracts:
-    def __init__(self, fs: gcsfs.GCSFileSystem, path: str, aggregation: RTHourlyAggregation):
+    def __init__(
+        self, fs: gcsfs.GCSFileSystem, path: str, aggregation: RTHourlyAggregation
+    ):
         self.fs = fs
         self.path = path
         self.aggregation = aggregation
@@ -339,7 +349,9 @@ def get_path(self):
         return f"{self.path}/rt_{self.aggregation.name_hash}/"

     def get_extracts(self) -> List[AggregationExtract]:
-        return [AggregationExtract(self.get_path(), e) for e in self.aggregation.extracts]
+        return [
+            AggregationExtract(self.get_path(), e) for e in self.aggregation.extracts
+        ]

     def get_local_paths(self) -> Dict[str, GTFSRTFeedExtract]:
         return {e.get_local_path(): e.extract for e in self.get_extracts()}
@@ -362,38 +374,50 @@ def get_hashes(self) -> Dict[str, List[GTFSRTFeedExtract]]:

     def download(self):
         self.fs.get(
-            rpath=[
-                extract.path
-                for extract in self.get_local_paths().values()
-            ],
+            rpath=[extract.path for extract in self.get_local_paths().values()],
             lpath=list(self.get_local_paths().keys()),
         )

     def download_most_recent_schedule(self) -> str:
         first_extract = self.aggregation.extracts[0]
-        schedule = MostRecentSchedule(self.fs, self.path, first_extract.config.base64_validation_url)
+        schedule = MostRecentSchedule(
+            self.fs, self.path, first_extract.config.base64_validation_url
+        )
         return schedule.download(first_extract.dt)


 class HourlyFeedQuery:
-    def __init__(self, step: RTProcessingStep, feed_type: GTFSFeedType, files: List[GTFSRTFeedExtract], limit: int = 0, base64_url: Optional[str] = None):
+    def __init__(
+        self,
+        step: RTProcessingStep,
+        feed_type: GTFSFeedType,
+        files: List[GTFSRTFeedExtract],
+        limit: int = 0,
+        base64_url: Optional[str] = None,
+    ):
         self.step = step
         self.feed_type = feed_type
         self.files = files
         self.limit = limit
         self.base64_url = base64_url

     def set_limit(self, limit: int):
-        return HourlyFeedQuery(self.step, self.feed_type, self.files, limit, self.base64_url)
+        return HourlyFeedQuery(
+            self.step, self.feed_type, self.files, limit, self.base64_url
+        )

     def where_base64url(self, base64_url: str):
-        return HourlyFeedQuery(self.step, self.feed_type, self.files, self.limit, base64_url)
-
-    def get_aggregates(self) -> Dict[Tuple[pendulum.DateTime, str], List[GTFSRTFeedExtract]]:
-        aggregates: Dict[Tuple[pendulum.DateTime, str], List[GTFSRTFeedExtract]] = defaultdict(
-            list
+        return HourlyFeedQuery(
+            self.step, self.feed_type, self.files, self.limit, base64_url
         )

+    def get_aggregates(
+        self,
+    ) -> Dict[Tuple[pendulum.DateTime, str], List[GTFSRTFeedExtract]]:
+        aggregates: Dict[
+            Tuple[pendulum.DateTime, str], List[GTFSRTFeedExtract]
+        ] = defaultdict(list)
+
         for file in self.files:
             if self.base64_url is None or file.base64_url == self.base64_url:
                 aggregates[(file.hour, file.base64_url)].append(file)
@@ -416,18 +440,29 @@ def total(self) -> int:


 class HourlyFeedFiles:
-    def __init__(self, files: List[GTFSRTFeedExtract], files_missing_metadata: List[Blob], files_invalid_metadata: List[Blob]):
+    def __init__(
+        self,
+        files: List[GTFSRTFeedExtract],
+        files_missing_metadata: List[Blob],
+        files_invalid_metadata: List[Blob],
+    ):
         self.files = files
         self.files_missing_metadata = files_missing_metadata
         self.files_invalid_metadata = files_invalid_metadata

     def total(self) -> int:
-        return len(self.files) + len(self.files_missing_metadata) + len(self.files_invalid_metadata)
+        return (
+            len(self.files)
+            + len(self.files_missing_metadata)
+            + len(self.files_invalid_metadata)
+        )

     def valid(self) -> bool:
         return not self.files or len(self.files) / self.total() > 0.99

-    def get_query(self, step: RTProcessingStep, feed_type: GTFSFeedType) -> HourlyFeedQuery:
+    def get_query(
+        self, step: RTProcessingStep, feed_type: GTFSFeedType
+    ) -> HourlyFeedQuery:
         return HourlyFeedQuery(step, feed_type, self.files)

@@ -451,12 +486,19 @@ def get_hour(self, hour: datetime.datetime) -> HourlyFeedFiles:


 class ValidationProcessor:
-    def __init__(self, aggregation: RTHourlyAggregation, validator: RtValidator, verbose: bool = False):
+    def __init__(
+        self,
+        aggregation: RTHourlyAggregation,
+        validator: RtValidator,
+        verbose: bool = False,
+    ):
         self.aggregation = aggregation
         self.validator = validator
         self.verbose = verbose

-    def process(self, tmp_dir: tempfile.TemporaryDirectory, scope) -> List[RTFileProcessingOutcome]:
+    def process(
+        self, tmp_dir: tempfile.TemporaryDirectory, scope
+    ) -> List[RTFileProcessingOutcome]:
         outcomes: List[RTFileProcessingOutcome] = []
         fs = get_fs()

@@ -498,7 +540,9 @@ def process(self, tmp_dir: tempfile.TemporaryDirectory, scope) -> List[RTFilePro
             fingerprint: List[Any] = [
                 type(e),
                 # convert back to url manually, I don't want to mess around with the hourly class
-                base64.urlsafe_b64decode(self.aggregation.base64_url.encode()).decode(),
+                base64.urlsafe_b64decode(
+                    self.aggregation.base64_url.encode()
+                ).decode(),
             ]
             fingerprint.append(e.returncode)

@@ -509,9 +553,7 @@ def process(self, tmp_dir: tempfile.TemporaryDirectory, scope) -> List[RTFilePro
             scope.fingerprint = fingerprint

             # get the end of stderr, just enough to fit in MAX_STRING_LENGTH defined above
-            scope.set_context(
-                "Process", {"stderr": stderr[-2000:]}
-            )
+            scope.set_context("Process", {"stderr": stderr[-2000:]})

             sentry_sdk.capture_exception(e, scope=scope)

@@ -581,10 +623,13 @@ def process(self, tmp_dir: tempfile.TemporaryDirectory, scope) -> List[RTFilePro
             typer.secho(
                 f"writing {len(records_to_upload)} lines to {self.aggregation.path}",
             )
-            with tempfile.NamedTemporaryFile(mode="wb", delete=False, dir=tmp_dir) as f:
+            with tempfile.NamedTemporaryFile(
+                mode="wb", delete=False, dir=tmp_dir
+            ) as f:
                 gzipfile = gzip.GzipFile(mode="wb", fileobj=f)
                 encoded = (
-                    r.json() if isinstance(r, BaseModel) else json.dumps(r) for r in records_to_upload
+                    r.json() if isinstance(r, BaseModel) else json.dumps(r)
+                    for r in records_to_upload
                 )
                 gzipfile.write("\n".join(encoded).encode("utf-8"))
                 gzipfile.close()
@@ -604,14 +649,18 @@ def __init__(self, aggregation: RTHourlyAggregation, verbose: bool = False):
         self.aggregation = aggregation
         self.verbose = verbose

-    def process(self, tmp_dir: tempfile.TemporaryDirectory, scope) -> List[RTFileProcessingOutcome]:
+    def process(
+        self, tmp_dir: tempfile.TemporaryDirectory, scope
+    ) -> List[RTFileProcessingOutcome]:
         outcomes: List[RTFileProcessingOutcome] = []
         fs = get_fs()
         dst_path_rt = f"{tmp_dir}/rt_{self.aggregation.name_hash}/"
         fs.get(
             rpath=[
                 extract.path
-                for extract in self.aggregation.local_paths_to_extract(dst_path_rt).values()
+                for extract in self.aggregation.local_paths_to_extract(
+                    dst_path_rt
+                ).values()
             ],
             lpath=list(self.aggregation.local_paths_to_extract(dst_path_rt).keys()),
         )
@@ -738,15 +787,23 @@ def parse_and_validate(
     outcomes = []
     with tempfile.TemporaryDirectory() as tmp_dir:
         with sentry_sdk.push_scope() as scope:
-            scope.set_tag("config_feed_type", aggregation.first_extract.config.feed_type)
+            scope.set_tag(
+                "config_feed_type", aggregation.first_extract.config.feed_type
+            )
             scope.set_tag("config_name", aggregation.first_extract.config.name)
             scope.set_tag("config_url", aggregation.first_extract.config.url)
             scope.set_context("RT Hourly Aggregation", json.loads(aggregation.json()))

-            if aggregation.step != RTProcessingStep.validate and aggregation.step != RTProcessingStep.parse:
+            if (
+                aggregation.step != RTProcessingStep.validate
+                and aggregation.step != RTProcessingStep.parse
+            ):
                 raise RuntimeError("we should not be here")

-            if aggregation.step == RTProcessingStep.validate and not aggregation.extracts[0].config.schedule_url_for_validation:
+            if (
+                aggregation.step == RTProcessingStep.validate
+                and not aggregation.extracts[0].config.schedule_url_for_validation
+            ):
                 outcomes = [
                     RTFileProcessingOutcome(
                         step=aggregation.step,
@@ -758,7 +815,9 @@ def parse_and_validate(
                 ]

             if aggregation.step == RTProcessingStep.validate:
-                outcomes = ValidationProcessor(aggregation, validator, verbose).process(tmp_dir, scope)
+                outcomes = ValidationProcessor(aggregation, validator, verbose).process(
+                    tmp_dir, scope
+                )

             if aggregation.step == RTProcessingStep.parse:
                 outcomes = ParseProcessor(aggregation, verbose).process(tmp_dir, scope)
@@ -801,7 +860,9 @@ def main(
             f"too many files have missing/invalid metadata; {total - len(files)} of {total}"  # noqa: E702
         )
     aggregated_feed = hourly_feed_files.get_query(step, feed_type)
-    aggregations_to_process = aggregated_feed.where_base64url(base64url).set_limit(limit).get_aggregates()
+    aggregations_to_process = (
+        aggregated_feed.where_base64url(base64url).set_limit(limit).get_aggregates()
+    )

     typer.secho(
         f"found {len(hourly_feed_files.files)} {feed_type} files in {len(aggregated_feed.get_aggregates())} aggregations to process",
@@ -892,7 +953,8 @@ def main(
     )

     assert (
-        len(outcomes) == aggregated_feed.where_base64url(base64url).set_limit(limit).total()
+        len(outcomes)
+        == aggregated_feed.where_base64url(base64url).set_limit(limit).total()
     ), f"we ended up with {len(outcomes)} outcomes from {aggregated_feed.where_base64url(base64url).set_limit(limit).total()}"

     if exceptions: