Fix mypy issues
ohrite committed Nov 23, 2024
1 parent 4a98da5 commit 15cb151
Showing 1 changed file with 22 additions and 28 deletions.
50 changes: 22 additions & 28 deletions jobs/gtfs-rt-parser-v2/gtfs_rt_parser.py
@@ -42,12 +42,7 @@
from google.transit import gtfs_realtime_pb2 # type: ignore
from pydantic import BaseModel, Field, validator

RT_VALIDATOR_JAR_LOCATION_ENV_KEY = "GTFS_RT_VALIDATOR_JAR"
JAR_DEFAULT = typer.Option(
os.environ.get(RT_VALIDATOR_JAR_LOCATION_ENV_KEY),
help="Path to the GTFS RT Validator JAR",
)

JAR_DEFAULT = os.environ["GTFS_RT_VALIDATOR_JAR"]
RT_PARSED_BUCKET = os.environ["CALITP_BUCKET__GTFS_RT_PARSED"]
RT_VALIDATION_BUCKET = os.environ["CALITP_BUCKET__GTFS_RT_VALIDATION"]
GTFS_RT_VALIDATOR_VERSION = os.environ["GTFS_RT_VALIDATOR_VERSION"]
@@ -218,7 +213,7 @@ def dt(self) -> pendulum.Date:


class RtValidator:
def __init__(self, jar_path: Path):
def __init__(self, jar_path: str = JAR_DEFAULT):
self.jar_path = jar_path

def execute(self, gtfs_file: str, rt_path: str):
@@ -291,12 +286,13 @@ def download(self, date: datetime.datetime) -> str:
try:
gtfs_zip = "/".join([self.path, schedule_extract.filename])
self.fs.get(schedule_extract.path, gtfs_zip)
return gtfs_zip
break
except FileNotFoundError:
print(
f"no schedule file found for {self.base64_validation_url} on day {day}"
)
continue
return gtfs_zip


class AggregationExtract:
@@ -312,7 +308,7 @@ def get_results_path(self) -> str:
self.path, f"{self.extract.timestamped_filename}.results.json"
)

def hash(self) -> str:
def hash(self) -> bytes:
with open(
os.path.join(self.path, self.extract.timestamped_filename), "rb"
) as f:
@@ -351,15 +347,15 @@ def get_local_paths(self) -> Dict[str, GTFSRTFeedExtract]:
def get_results_paths(self) -> Dict[str, GTFSRTFeedExtract]:
return {e.get_results_path(): e.extract for e in self.get_extracts()}

def get_hashed_results(self) -> Dict[str, List[Dict[str, str]]]:
hashed: Dict[str, str] = {}
def get_hashed_results(self):
hashed = {}
for e in self.get_extracts():
if e.has_results():
hashed[e.hash()] = e.get_results()
return hashed

def get_hashes(self) -> Dict[str, List[GTFSRTFeedExtract]]:
hashed: Dict[str, List[GTFSRTFeedExtract]] = defaultdict(list)
def get_hashes(self) -> Dict[bytes, List[GTFSRTFeedExtract]]:
hashed: Dict[bytes, List[GTFSRTFeedExtract]] = defaultdict(list)
for e in self.get_extracts():
hashed[e.hash()].append(e.extract)
return hashed
@@ -398,14 +394,14 @@ def set_limit(self, limit: int):
self.step, self.feed_type, self.files, limit, self.base64_url
)

def where_base64url(self, base64_url: str):
def where_base64url(self, base64_url: Optional[str]):
return HourlyFeedQuery(
self.step, self.feed_type, self.files, self.limit, base64_url
)

def get_aggregates(
self,
) -> Dict[Tuple[pendulum.DateTime, str], List[GTFSRTFeedExtract]]:
) -> List[RTHourlyAggregation]:
aggregates: Dict[
Tuple[pendulum.DateTime, str], List[GTFSRTFeedExtract]
] = defaultdict(list)
@@ -481,15 +477,16 @@ class ValidationProcessor:
def __init__(
self,
aggregation: RTHourlyAggregation,
validator: RtValidator,
verbose: bool = False,
):
self.aggregation = aggregation
self.validator = validator
self.verbose = verbose

def validator(self):
return RtValidator()

def process(
self, tmp_dir: tempfile.TemporaryDirectory, scope
self, tmp_dir: str, scope
) -> List[RTFileProcessingOutcome]:
outcomes: List[RTFileProcessingOutcome] = []
fs = get_fs()
@@ -534,7 +531,7 @@ def process(

if not outcomes:
try:
self.validator.execute(gtfs_zip, aggregation_extracts.get_path())
self.validator().execute(gtfs_zip, aggregation_extracts.get_path())

# these are the only two types of errors we expect; let any others bubble up
except subprocess.CalledProcessError as e:
@@ -612,12 +609,12 @@ def process(
]
)

for e in extracts:
for extract in extracts:
outcomes.append(
RTFileProcessingOutcome(
step=self.aggregation.step,
success=True,
extract=e,
extract=extract,
aggregation=self.aggregation,
)
)
@@ -654,7 +651,7 @@ def __init__(self, aggregation: RTHourlyAggregation, verbose: bool = False):
self.verbose = verbose

def process(
self, tmp_dir: tempfile.TemporaryDirectory, scope
self, tmp_dir: str, scope
) -> List[RTFileProcessingOutcome]:
outcomes: List[RTFileProcessingOutcome] = []
fs = get_fs()
@@ -784,10 +781,8 @@ def process(
# exceptions in backoff's context, which ruins things
def parse_and_validate(
aggregation: RTHourlyAggregation,
jar_path: Path,
verbose: bool = False,
) -> List[RTFileProcessingOutcome]:
validator = RtValidator(jar_path)
with tempfile.TemporaryDirectory() as tmp_dir:
with sentry_sdk.push_scope() as scope:
scope.set_tag(
@@ -804,7 +799,7 @@ def parse_and_validate(
raise RuntimeError("we should not be here")

if aggregation.step == RTProcessingStep.validate:
return ValidationProcessor(aggregation, validator, verbose).process(
return ValidationProcessor(aggregation, verbose).process(
tmp_dir, scope
)

@@ -835,16 +830,16 @@ def main(
hour: datetime.datetime,
limit: int = 0,
threads: int = 4,
jar_path: Path = JAR_DEFAULT,
verbose: bool = False,
base64url: Optional[str] = None,
):
hourly_feed_files = FeedStorage(feed_type).get_hour(hour)
if not hourly_feed_files.valid():
typer.secho(f"missing: {hourly_feed_files.files_missing_metadata}")
typer.secho(f"invalid: {hourly_feed_files.files_invalid_metadata}")
error_count = hourly_feed_files.total() - len(hourly_feed_files.files)
raise RuntimeError(
f"too many files have missing/invalid metadata; {hourly_feed_files.total - len(hourly_feed_files.files)} of {hourly_feed_files.total}" # noqa: E702
f"too many files have missing/invalid metadata; {error_count} of {hourly_feed_files.total()}" # noqa: E702
)
aggregated_feed = hourly_feed_files.get_query(step, feed_type)
aggregations_to_process = (
@@ -894,7 +889,6 @@ def main(
pool.submit(
parse_and_validate,
aggregation=aggregation,
jar_path=jar_path,
verbose=verbose,
): aggregation
for aggregation in aggregations_to_process
