diff --git a/.basedpyright/baseline.json b/.basedpyright/baseline.json index 5f9c40b5da..0bc3271131 100644 --- a/.basedpyright/baseline.json +++ b/.basedpyright/baseline.json @@ -1070,24 +1070,6 @@ } } ], - "./monitoring/mock_uss/dynamic_configuration/routes.py": [ - { - "code": "reportReturnType", - "range": { - "startColumn": 11, - "endColumn": 5, - "lineCount": 3 - } - }, - { - "code": "reportReturnType", - "range": { - "startColumn": 11, - "endColumn": 5, - "lineCount": 3 - } - } - ], "./monitoring/mock_uss/f3548v21/flight_planning.py": [ { "code": "reportOperatorIssue", @@ -1552,24 +1534,6 @@ } } ], - "./monitoring/mock_uss/riddp/routes_behavior.py": [ - { - "code": "reportReturnType", - "range": { - "startColumn": 11, - "endColumn": 37, - "lineCount": 1 - } - }, - { - "code": "reportReturnType", - "range": { - "startColumn": 11, - "endColumn": 43, - "lineCount": 1 - } - } - ], "./monitoring/mock_uss/riddp/routes_observation.py": [ { "code": "reportOptionalMemberAccess", @@ -1802,14 +1766,6 @@ "endColumn": 87, "lineCount": 1 } - }, - { - "code": "reportReturnType", - "range": { - "startColumn": 11, - "endColumn": 32, - "lineCount": 1 - } } ], "./monitoring/mock_uss/ridsp/behavior.py": [ @@ -1846,24 +1802,6 @@ } } ], - "./monitoring/mock_uss/ridsp/routes_behavior.py": [ - { - "code": "reportReturnType", - "range": { - "startColumn": 11, - "endColumn": 37, - "lineCount": 1 - } - }, - { - "code": "reportReturnType", - "range": { - "startColumn": 11, - "endColumn": 43, - "lineCount": 1 - } - } - ], "./monitoring/mock_uss/ridsp/routes_injection.py": [ { "code": "reportOperatorIssue", @@ -1889,14 +1827,6 @@ "lineCount": 1 } }, - { - "code": "reportReturnType", - "range": { - "startColumn": 15, - "endColumn": 38, - "lineCount": 1 - } - }, { "code": "reportOptionalMemberAccess", "range": { @@ -1928,62 +1858,6 @@ "endColumn": 120, "lineCount": 1 } - }, - { - "code": "reportReturnType", - "range": { - "startColumn": 19, - "endColumn": 42, - "lineCount": 1 - } - }, - { - "code": "reportReturnType", - "range": { - "startColumn": 11, - "endColumn": 5, - "lineCount": 3 - } - }, - { - "code": "reportReturnType", - "range": { - "startColumn": 12, - "endColumn": 86, - "lineCount": 1 - } - }, - { - "code": "reportReturnType", - "range": { - "startColumn": 12, - "endColumn": 77, - "lineCount": 1 - } - }, - { - "code": "reportReturnType", - "range": { - "startColumn": 16, - "endColumn": 82, - "lineCount": 1 - } - }, - { - "code": "reportReturnType", - "range": { - "startColumn": 12, - "endColumn": 13, - "lineCount": 3 - } - }, - { - "code": "reportReturnType", - "range": { - "startColumn": 11, - "endColumn": 27, - "lineCount": 1 - } } ], "./monitoring/mock_uss/ridsp/routes_ridsp_v19.py": [ diff --git a/monitoring/mock_uss/database.py b/monitoring/mock_uss/database.py index 575e27360f..6df90b1eda 100644 --- a/monitoring/mock_uss/database.py +++ b/monitoring/mock_uss/database.py @@ -47,7 +47,7 @@ class Database(ImplicitDict): """Timestamp of most recent time periodic task loop iterated""" -db = SynchronizedValue( +db = SynchronizedValue[Database]( Database( one_time_tasks=[], task_errors=[], diff --git a/monitoring/mock_uss/dynamic_configuration/configuration.py b/monitoring/mock_uss/dynamic_configuration/configuration.py index e444c7fe76..80df1fa898 100644 --- a/monitoring/mock_uss/dynamic_configuration/configuration.py +++ b/monitoring/mock_uss/dynamic_configuration/configuration.py @@ -14,7 +14,7 @@ class DynamicConfiguration(ImplicitDict): locale: LocalityCode -db = SynchronizedValue( +db = 
SynchronizedValue[DynamicConfiguration]( DynamicConfiguration(locale=LocalityCode(webapp.config[KEY_BEHAVIOR_LOCALITY])), decoder=lambda b: ImplicitDict.parse( json.loads(b.decode("utf-8")), DynamicConfiguration @@ -24,6 +24,6 @@ class DynamicConfiguration(ImplicitDict): def get_locality() -> Locality: - with db as tx: - code = tx.locale + with db.transact() as tx: + code = tx.value.locale return Locality.from_locale(code) diff --git a/monitoring/mock_uss/dynamic_configuration/routes.py b/monitoring/mock_uss/dynamic_configuration/routes.py index 80e92e756e..496b836d39 100644 --- a/monitoring/mock_uss/dynamic_configuration/routes.py +++ b/monitoring/mock_uss/dynamic_configuration/routes.py @@ -12,7 +12,7 @@ @webapp.route("/configuration/locality", methods=["GET"]) -def locality_get() -> tuple[str, int]: +def locality_get() -> flask.Response: return flask.jsonify( GetLocalityResponse(locality_code=get_locality().locality_code()) ) @@ -20,7 +20,7 @@ def locality_get() -> tuple[str, int]: @webapp.route("/configuration/locality", methods=["PUT"]) @requires_scope(MOCK_USS_CONFIG_SCOPE) # TODO: use separate public key for this -def locality_set() -> tuple[str, int]: +def locality_set() -> tuple[str, int] | flask.Response: """Set the locality of the mock_uss.""" try: json = flask.request.json @@ -38,8 +38,8 @@ def locality_set() -> tuple[str, int]: msg = f"Invalid locality_code: {str(e)}" return msg, 400 - with db as tx: - tx.locale = req.locality_code + with db.transact() as tx: + tx.value.locale = req.locality_code return flask.jsonify( GetLocalityResponse(locality_code=get_locality().locality_code()) diff --git a/monitoring/mock_uss/geoawareness/check.py b/monitoring/mock_uss/geoawareness/check.py index c278d3bd52..97ca2d4e2b 100644 --- a/monitoring/mock_uss/geoawareness/check.py +++ b/monitoring/mock_uss/geoawareness/check.py @@ -7,7 +7,8 @@ GeozoneSourceResponseResult, ) -from monitoring.mock_uss.geoawareness.database import Database, SourceRecord, db +from monitoring.mock_uss.geoawareness import database +from monitoring.mock_uss.geoawareness.database import SourceRecord, db from monitoring.mock_uss.geoawareness.ed269 import evaluate_source logger = logging.getLogger(__name__) @@ -27,7 +28,7 @@ def combine_results( def check_geozones(req: GeozonesCheckRequest) -> list[GeozonesCheckResultGeozone]: - sources: dict[str, SourceRecord] = Database.get_sources(db) + sources: dict[str, SourceRecord] = database.get_sources(db) results: list[GeozonesCheckResultGeozone] = [ GeozonesCheckResultGeozone.Absent @@ -44,7 +45,11 @@ def check_geozones(req: GeozonesCheckRequest) -> list[GeozonesCheckResultGeozone ) continue - fmt = source.definition.https_source.format + fmt = ( + source.definition.https_source.format + if source.definition.https_source + else None + ) if fmt == GeozoneHttpsSourceFormat.ED_269: logger.debug(f" {j + 1}. 
ED269 source {source_id} ready.") result = combine_results( diff --git a/monitoring/mock_uss/geoawareness/database.py b/monitoring/mock_uss/geoawareness/database.py index 8c4b233579..7a4d139a32 100644 --- a/monitoring/mock_uss/geoawareness/database.py +++ b/monitoring/mock_uss/geoawareness/database.py @@ -26,60 +26,63 @@ class Database(ImplicitDict): sources: dict[str, SourceRecord] = {} - @staticmethod - def get_source(db: SynchronizedValue, id: str) -> SourceRecord: - return db.value.sources.get(id, None) - - @staticmethod - def get_sources(db: SynchronizedValue) -> SourceRecord: - return db.value.sources - - @staticmethod - def insert_source( - db: SynchronizedValue, - id: str, - definition: CreateGeozoneSourceRequest, - state: GeozoneSourceResponseResult, - message: str | None = None, - ) -> SourceRecord: - with db as tx: - if id in tx.sources.keys(): - raise ExistingRecordException() - tx.sources[id] = SourceRecord( - definition=definition, state=state, message=message - ) - result = tx.sources[id] - return result - - @staticmethod - def update_source_state( - db: SynchronizedValue, - id: str, - state: GeozoneSourceResponseResult, - message: str | None = None, - ): - with db as tx: - tx.sources[id]["state"] = state - tx.sources[id]["message"] = message - result = tx.sources[id] - return result - - @staticmethod - def update_source_geozone_ed269( - db: SynchronizedValue, id: str, geozone: ED269Schema - ): - with db as tx: - tx.sources[id]["geozone_ed269"] = geozone - result = tx.sources[id] - return result - - @staticmethod - def delete_source(db: SynchronizedValue, id: str): - with db as tx: - return tx.sources.pop(id, None) - - -db = SynchronizedValue( + +def get_source(geo_db: SynchronizedValue[Database], source_id: str) -> SourceRecord: + result = geo_db.value.sources.get(source_id, None) + if result is None: + raise KeyError(f"No source exists with id '{source_id}'") + return result + + +def get_sources(geo_db: SynchronizedValue[Database]) -> dict[str, SourceRecord]: + return geo_db.value.sources + + +def insert_source( + geo_db: SynchronizedValue[Database], + source_id: str, + definition: CreateGeozoneSourceRequest, + state: GeozoneSourceResponseResult, + message: str | None = None, +) -> SourceRecord: + with geo_db.transact() as tx: + if source_id in tx.value.sources.keys(): + raise ExistingRecordException() + tx.value.sources[source_id] = SourceRecord( + definition=definition, state=state, message=message + ) + result = tx.value.sources[source_id] + return result + + +def update_source_state( + geo_db: SynchronizedValue[Database], + source_id: str, + state: GeozoneSourceResponseResult, + message: str | None = None, +): + with geo_db.transact() as tx: + tx.value.sources[source_id]["state"] = state + tx.value.sources[source_id]["message"] = message + result = tx.value.sources[source_id] + return result + + +def update_source_geozone_ed269( + geo_db: SynchronizedValue[Database], source_id: str, geozone: ED269Schema +): + with geo_db.transact() as tx: + tx.value.sources[source_id]["geozone_ed269"] = geozone + result = tx.value.sources[source_id] + return result + + +def delete_source(geo_db: SynchronizedValue[Database], source_id: str): + with geo_db.transact() as tx: + return tx.value.sources.pop(source_id, None) + + +db = SynchronizedValue[Database]( Database(), decoder=lambda b: ImplicitDict.parse(json.loads(b.decode("utf-8")), Database), ) diff --git a/monitoring/mock_uss/geoawareness/geozone_sources.py b/monitoring/mock_uss/geoawareness/geozone_sources.py index 
e363f23e8f..dd7f673269 100644 --- a/monitoring/mock_uss/geoawareness/geozone_sources.py +++ b/monitoring/mock_uss/geoawareness/geozone_sources.py @@ -7,8 +7,8 @@ GeozoneSourceResponseResult, ) +from monitoring.mock_uss.geoawareness import database from monitoring.mock_uss.geoawareness.database import ( - Database, ExistingRecordException, db, ) @@ -17,7 +17,7 @@ def get_geozone_source(geozone_source_id: str): """This handler returns the state of a geozone source""" - source = Database.get_source(db, geozone_source_id) + source = database.get_source(db, geozone_source_id) if source is None: return f"source {geozone_source_id} not found or deleted", 404 return ( @@ -30,7 +30,7 @@ def create_geozone_source(id, source_definition: CreateGeozoneSourceRequest): """This handler creates and activates a geozone source""" try: - source = Database.insert_source( + source = database.insert_source( db, id, source_definition, GeozoneSourceResponseResult.Activating ) except ExistingRecordException: @@ -41,12 +41,12 @@ def create_geozone_source(id, source_definition: CreateGeozoneSourceRequest): raw_data = requests.get(source.definition.https_source.url).json() if source.definition.https_source.format == GeozoneHttpsSourceFormat.ED_269: geozones = ED269Schema.from_dict(raw_data) - Database.update_source_geozone_ed269(db, id, geozones) - source = Database.update_source_state( + database.update_source_geozone_ed269(db, id, geozones) + source = database.update_source_state( db, id, GeozoneSourceResponseResult.Ready ) except ValueError as e: - source = Database.update_source_state( + source = database.update_source_state( db, id, GeozoneSourceResponseResult.Error, @@ -54,7 +54,7 @@ def create_geozone_source(id, source_definition: CreateGeozoneSourceRequest): ) else: - source = Database.update_source_state( + source = database.update_source_state( db, id, GeozoneSourceResponseResult.Error, @@ -70,7 +70,7 @@ def create_geozone_source(id, source_definition: CreateGeozoneSourceRequest): def delete_geozone_source(geozone_source_id): """This handler deactivates and deletes a geozone source""" - deleted_id = Database.delete_source(db, geozone_source_id) + deleted_id = database.delete_source(db, geozone_source_id) if deleted_id is None: return f"source {geozone_source_id} not found", 404 diff --git a/monitoring/mock_uss/msgsigning/database.py b/monitoring/mock_uss/msgsigning/database.py index b03b4df005..64661f2aa0 100644 --- a/monitoring/mock_uss/msgsigning/database.py +++ b/monitoring/mock_uss/msgsigning/database.py @@ -14,7 +14,7 @@ class Database(ImplicitDict): private_key_name: str = "messagesigning/mock_faa_priv.pem" -db = SynchronizedValue( +db = SynchronizedValue[Database]( Database(), decoder=lambda b: ImplicitDict.parse(json.loads(b.decode("utf-8")), Database), ) diff --git a/monitoring/mock_uss/riddp/database.py b/monitoring/mock_uss/riddp/database.py index 0ab4adcc2f..d62da41fca 100644 --- a/monitoring/mock_uss/riddp/database.py +++ b/monitoring/mock_uss/riddp/database.py @@ -55,7 +55,7 @@ class Database(ImplicitDict): subscriptions: list[ObservationSubscription] -db = SynchronizedValue( +db = SynchronizedValue[Database]( Database(flights={}, subscriptions=[]), decoder=lambda b: ImplicitDict.parse(json.loads(b.decode("utf-8")), Database), ) diff --git a/monitoring/mock_uss/riddp/routes_behavior.py b/monitoring/mock_uss/riddp/routes_behavior.py index df363f055c..a299ef79b7 100644 --- a/monitoring/mock_uss/riddp/routes_behavior.py +++ b/monitoring/mock_uss/riddp/routes_behavior.py @@ -8,7 +8,7 @@ 
@webapp.route("/riddp/behavior", methods=["PUT"]) -def riddp_set_dp_behavior() -> tuple[str, int]: +def riddp_set_dp_behavior() -> tuple[str, int] | flask.Response: """Set the behavior of the mock Display Provider.""" try: json = flask.request.json @@ -19,13 +19,13 @@ def riddp_set_dp_behavior() -> tuple[str, int]: msg = f"Change behavior for Display Provider unable to parse JSON: {e}" return msg, 400 - with db as tx: - tx.behavior = dp_behavior + with db.transact() as tx: + tx.value.behavior = dp_behavior return flask.jsonify(dp_behavior) @webapp.route("/riddp/behavior", methods=["GET"]) -def riddp_get_dp_behavior() -> tuple[str, int]: +def riddp_get_dp_behavior() -> flask.Response: """Get the behavior of the mock Display Provider.""" return flask.jsonify(db.value.behavior) diff --git a/monitoring/mock_uss/riddp/routes_observation.py b/monitoring/mock_uss/riddp/routes_observation.py index 5511b6046d..0a12a4cd9e 100644 --- a/monitoring/mock_uss/riddp/routes_observation.py +++ b/monitoring/mock_uss/riddp/routes_observation.py @@ -135,16 +135,19 @@ def riddp_display_data() -> tuple[flask.Response, int]: 413, ) - with db as tx: + with db.transact() as tx: # Find an existing subscription to serve this request subscription: ObservationSubscription | None = None t_max = ( arrow.utcnow() + timedelta(seconds=1) ).datetime # Don't rely on subscriptions very near their expiration - tx.subscriptions = [ - s for s in tx.subscriptions if s.upsert_result.subscription.time_end > t_max + tx.value.subscriptions = [ + s + for s in tx.value.subscriptions + if s.upsert_result.subscription + and s.upsert_result.subscription.time_end > t_max ] - for existing_subscription in tx.subscriptions: + for existing_subscription in tx.value.subscriptions: assert isinstance(existing_subscription, ObservationSubscription) sub_rect = existing_subscription.bounds.to_latlngrect() if sub_rect.contains(view): @@ -184,7 +187,7 @@ def riddp_display_data() -> tuple[flask.Response, int]: subscription = ObservationSubscription( bounds=sub_bounds, upsert_result=upsert_result, updates=[] ) - tx.subscriptions.append(subscription) + tx.value.subscriptions.append(subscription) # Fetch flights from each unique flights URL validated_flights: list[Flight] = [] @@ -222,9 +225,9 @@ def riddp_display_data() -> tuple[flask.Response, int]: flight_info[flight.id] = database.FlightInfo(flights_url=flights_url) # Update links between flight IDs and flight URLs - with db as tx: + with db.transact() as tx: for k, v in flight_info.items(): - tx.flights[k] = v + tx.value.flights[k] = v # Make and return response flights = [_make_flight_observation(f, view) for f in validated_flights] @@ -243,7 +246,7 @@ def riddp_display_data() -> tuple[flask.Response, int]: @webapp.route("/riddp/observation/display_data/", methods=["GET"]) @requires_scope(Scope.Read) -def riddp_flight_details(flight_id: str) -> tuple[str, int]: +def riddp_flight_details(flight_id: str) -> tuple[str, int] | flask.Response: """Implements get flight details endpoint per automated testing API.""" tx = db.value flight_info = tx.flights.get(flight_id) diff --git a/monitoring/mock_uss/riddp/routes_riddp_v19.py b/monitoring/mock_uss/riddp/routes_riddp_v19.py index 4586d549a1..24f8897c94 100644 --- a/monitoring/mock_uss/riddp/routes_riddp_v19.py +++ b/monitoring/mock_uss/riddp/routes_riddp_v19.py @@ -37,10 +37,12 @@ def riddp_notify_isa_v19(id: str): subscription_ids = [s.subscription_id for s in put_params.subscriptions] if subscription_ids: - with db as tx: + with db.transact() as tx: 
updated = False - for subscription in tx.subscriptions: + for subscription in tx.value.subscriptions: + if not subscription.upsert_result.subscription: + continue if subscription.upsert_result.subscription.id in subscription_ids: query = describe_flask_query(flask.request, flask.jsonify(None), 0) subscription.updates.append(UpdatedISA(v19_query=query)) diff --git a/monitoring/mock_uss/riddp/routes_riddp_v22a.py b/monitoring/mock_uss/riddp/routes_riddp_v22a.py index 0f6a47c938..7d643c0dae 100644 --- a/monitoring/mock_uss/riddp/routes_riddp_v22a.py +++ b/monitoring/mock_uss/riddp/routes_riddp_v22a.py @@ -37,10 +37,12 @@ def riddp_notify_isa_v22a(id: str): subscription_ids = [s.subscription_id for s in put_params.subscriptions] if subscription_ids: - with db as tx: + with db.transact() as tx: updated = False - for subscription in tx.subscriptions: + for subscription in tx.value.subscriptions: + if not subscription.upsert_result.subscription: + continue if subscription.upsert_result.subscription.id in subscription_ids: query = describe_flask_query(flask.request, flask.jsonify(None), 0) subscription.updates.append(UpdatedISA(v22a_query=query)) diff --git a/monitoring/mock_uss/ridsp/database.py b/monitoring/mock_uss/ridsp/database.py index 736b47c54c..54497f1464 100644 --- a/monitoring/mock_uss/ridsp/database.py +++ b/monitoring/mock_uss/ridsp/database.py @@ -34,7 +34,7 @@ class Database(ImplicitDict): notifications: ServiceProviderUserNotifications = ServiceProviderUserNotifications() -db = SynchronizedValue( +db = SynchronizedValue[Database]( Database(), decoder=lambda b: ImplicitDict.parse(json.loads(b.decode("utf-8")), Database), ) diff --git a/monitoring/mock_uss/ridsp/routes_behavior.py b/monitoring/mock_uss/ridsp/routes_behavior.py index 5a45f65bba..f3f131150f 100644 --- a/monitoring/mock_uss/ridsp/routes_behavior.py +++ b/monitoring/mock_uss/ridsp/routes_behavior.py @@ -8,7 +8,7 @@ @webapp.route("/ridsp/behavior", methods=["PUT"]) -def ridsp_set_dp_behavior() -> tuple[str, int]: +def ridsp_set_dp_behavior() -> tuple[str, int] | flask.Response: """Set the behavior of the mock Display Provider.""" try: json = flask.request.json @@ -19,13 +19,13 @@ def ridsp_set_dp_behavior() -> tuple[str, int]: msg = f"Change behavior for Service Provider unable to parse JSON: {e}" return msg, 400 - with db as tx: - tx.behavior = dp_behavior + with db.transact() as tx: + tx.value.behavior = dp_behavior return flask.jsonify(dp_behavior) @webapp.route("/ridsp/behavior", methods=["GET"]) -def ridsp_get_dp_behavior() -> tuple[str, int]: +def ridsp_get_dp_behavior() -> flask.Response: """Get the behavior of the mock Display Provider.""" return flask.jsonify(db.value.behavior) diff --git a/monitoring/mock_uss/ridsp/routes_injection.py b/monitoring/mock_uss/ridsp/routes_injection.py index 1542c3ad13..418bc28679 100644 --- a/monitoring/mock_uss/ridsp/routes_injection.py +++ b/monitoring/mock_uss/ridsp/routes_injection.py @@ -41,7 +41,7 @@ class ErrorResponse(ImplicitDict): @webapp.route("/ridsp/injection/tests/", methods=["PUT"]) @requires_scope(injection_api.SCOPE_RID_QUALIFIER_INJECT) @idempotent_request() -def ridsp_create_test(test_id: str) -> tuple[str, int]: +def ridsp_create_test(test_id: str) -> tuple[str | flask.Response, int]: """Implements test creation in RID automated testing injection API.""" logger.info(f"Create test {test_id}") rid_version = webapp.config[KEY_RID_VERSION] @@ -109,13 +109,13 @@ def ridsp_create_test(test_id: str) -> tuple[str, int]: response["query"] = notification.query return 
flask.jsonify(response), 412 - with db as tx: - tx.tests[test_id] = record - tx.notifications.create_notifications_if_needed(record) + with db.transact() as tx: + tx.value.tests[test_id] = record + tx.value.notifications.create_notifications_if_needed(record) return flask.jsonify( ChangeTestResponse(version=record.version, injected_flights=record.flights) - ) + ), 200 @webapp.route("/ridsp/injection/tests//", methods=["DELETE"]) @@ -165,8 +165,8 @@ def ridsp_delete_test(test_id: str, version: str) -> tuple[str | flask.Response, ) result["query"] = notification.query - with db as tx: - del tx.tests[test_id] + with db.transact() as tx: + del tx.value.tests[test_id] return flask.jsonify(result), 200 @@ -175,7 +175,7 @@ def ridsp_delete_test(test_id: str, version: str) -> tuple[str | flask.Response, methods=["GET"], ) @requires_scope(injection_api.SCOPE_RID_QUALIFIER_INJECT) -def ridsp_get_user_notifications() -> tuple[str, int]: +def ridsp_get_user_notifications() -> tuple[str | flask.Response, int]: """Returns the list of user notifications observed by the virtual user""" if "after" not in flask.request.args: @@ -222,4 +222,4 @@ def ridsp_get_user_notifications() -> tuple[str, int]: r = QueryUserNotificationsResponse(user_notifications=final_list) - return flask.jsonify(r) + return flask.jsonify(r), 200 diff --git a/monitoring/mock_uss/server.py b/monitoring/mock_uss/server.py index 1800d1eb2d..fe2fc7171e 100644 --- a/monitoring/mock_uss/server.py +++ b/monitoring/mock_uss/server.py @@ -14,7 +14,7 @@ from loguru import logger from ..monitorlib.errors import stacktrace_string -from .database import Database, PeriodicTaskStatus, TaskError, db +from .database import PeriodicTaskStatus, TaskError, db MAX_PERIODIC_LATENCY = timedelta(seconds=5) @@ -87,10 +87,10 @@ def shutdown_task_decorator(func): def _run_one_time_tasks(self, trigger: TaskTrigger): tasks: dict[str, OneTimeServerTask] = {} - with db as tx: + with db.transact() as tx: for task_name, task in self._one_time_tasks.items(): - if task.trigger == trigger and task_name not in tx.one_time_tasks: - tx.one_time_tasks.append(task_name) + if task.trigger == trigger and task_name not in tx.value.one_time_tasks: + tx.value.one_time_tasks.append(task_name) tasks[task_name] = task if not tasks: logger.info(f"No {trigger} tasks to initiate from process ID {os.getpid()}") @@ -105,8 +105,8 @@ def _run_one_time_tasks(self, trigger: TaskTrigger): try: setup_task.run() except Exception as e: - with db as tx: - tx.task_errors.append(TaskError.from_exception(trigger, e)) + with db.transact() as tx: + tx.value.task_errors.append(TaskError.from_exception(trigger, e)) if trigger == TaskTrigger.Shutdown: logger.error( f"{type(e).__name__} error in '{task_name}' on process ID {os.getpid()} while shutting down mock_uss: {str(e)}\n{stacktrace_string(e)}" @@ -148,11 +148,10 @@ def set_task_period(self, task_name: str, period: timedelta | None): raise ValueError( f"Periodic task '{task_name}' is not declared, so its period cannot be set" ) - with db as tx: - assert isinstance(tx, Database) - if task_name not in tx.periodic_tasks: - tx.periodic_tasks[task_name] = PeriodicTaskStatus() - tx.periodic_tasks[task_name].period = ( + with db.transact() as tx: + if task_name not in tx.value.periodic_tasks: + tx.value.periodic_tasks[task_name] = PeriodicTaskStatus() + tx.value.periodic_tasks[task_name].period = ( StringBasedTimeDelta(period) if period is not None else None ) @@ -172,21 +171,20 @@ def _periodic_tasks_daemon_loop(self): # Determine what to do on this loop 
(execute task or wait) task_to_execute = None next_check = None - with db as tx: - assert isinstance(tx, Database) - tx.most_recent_periodic_check = StringBasedDateTime( + with db.transact() as tx: + tx.value.most_recent_periodic_check = StringBasedDateTime( datetime.now(UTC) ) # Cancel the loop if we're stopping - if tx.stopping: + if tx.value.stopping: break # Find the earliest scheduled task earliest_task: tuple[str, datetime, PeriodicTaskStatus] | None = ( None ) - for task_name, task in tx.periodic_tasks.items(): + for task_name, task in tx.value.periodic_tasks.items(): if task.executing: # Don't consider executing tasks that are already executing continue @@ -216,7 +214,7 @@ def _periodic_tasks_daemon_loop(self): task_name, t_execute, task = earliest_task if t_execute <= arrow.utcnow().datetime: # We should execute this task immediately - tx.periodic_tasks[task_name] = PeriodicTaskStatus( + tx.value.periodic_tasks[task_name] = PeriodicTaskStatus( last_execution_time=StringBasedDateTime( arrow.utcnow().datetime ), @@ -227,7 +225,7 @@ def _periodic_tasks_daemon_loop(self): else: # We need to wait some time before executing this task next_check = t_execute - # + # if task_to_execute: # Execute the selected task right now @@ -235,8 +233,8 @@ def _periodic_tasks_daemon_loop(self): f"Executing '{task_to_execute}' periodic task from process {os.getpid()}" ) self._periodic_tasks[task_to_execute].run() - with db as tx: - periodic_task = tx.periodic_tasks[task_to_execute] + with db.transact() as tx: + periodic_task = tx.value.periodic_tasks[task_to_execute] periodic_task.executing = False if ( "period" in periodic_task @@ -260,8 +258,10 @@ def _periodic_tasks_daemon_loop(self): logger.error( f"Shutting down mock_uss due to {type(e).__name__} error while executing '{task_to_execute}' periodic task: {str(e)}\n{stacktrace_string(e)}" ) - with db as tx: - tx.task_errors.append(TaskError.from_exception(TaskTrigger.Setup, e)) + with db.transact() as tx: + tx.value.task_errors.append( + TaskError.from_exception(TaskTrigger.Setup, e) + ) self.stop() finally: logger.info(f"Periodic task daemon for process {os.getpid()} exited") @@ -271,10 +271,10 @@ def is_stopping(self) -> bool: def stop(self): send_signal = False - with db as tx: - if not tx.stopping: + with db.transact() as tx: + if not tx.value.stopping: send_signal = True - tx.stopping = True + tx.value.stopping = True if send_signal: logger.info( f"Initiating shutdown of MockUSS process {self._pid} from process {os.getpid()}" diff --git a/monitoring/mock_uss/tracer/database.py b/monitoring/mock_uss/tracer/database.py index 208286256e..07b364961b 100644 --- a/monitoring/mock_uss/tracer/database.py +++ b/monitoring/mock_uss/tracer/database.py @@ -29,7 +29,7 @@ class Database(ImplicitDict): """Interval at which polling of observation areas should occur""" -db = SynchronizedValue( +db = SynchronizedValue[Database]( Database(observation_areas={}), decoder=lambda b: ImplicitDict.parse(json.loads(b.decode("utf-8")), Database), ) diff --git a/monitoring/mock_uss/tracer/routes/observation_areas.py b/monitoring/mock_uss/tracer/routes/observation_areas.py index b05ec1b444..5a1d179409 100644 --- a/monitoring/mock_uss/tracer/routes/observation_areas.py +++ b/monitoring/mock_uss/tracer/routes/observation_areas.py @@ -38,9 +38,11 @@ @webapp.route("/tracer/observation_areas", methods=["GET"]) @ui_auth.login_required() def tracer_list_observation_areas() -> flask.Response: - with db as tx: + with db.transact() as tx: result = ListObservationAreasResponse( - 
areas=[redact_observation_area(a) for a in tx.observation_areas.values()] + areas=[ + redact_observation_area(a) for a in tx.value.observation_areas.values() + ] ) return flask.jsonify(result) @@ -62,24 +64,24 @@ def tracer_upsert_observation_area( msg = f"Upsert observation area for tracer unable to parse JSON: {e}" return msg, 400 - with db as tx: + with db.transact() as tx: # Determine if this observation area triggers the need to start polling - if tx.observation_areas: - poll_interval = tx.polling_interval.timedelta - for a in tx.observation_areas.values(): + if tx.value.observation_areas: + poll_interval = tx.value.polling_interval.timedelta + for a in tx.value.observation_areas.values(): if a.polls: poll_interval = None break else: poll_interval = ( - tx.polling_interval.timedelta if request.area.polls else None + tx.value.polling_interval.timedelta if request.area.polls else None ) - if area_id in tx.observation_areas: + if area_id in tx.value.observation_areas: # Request is to mutate an existing observation area, so we'll first just delete the existing area - delete_observation_area(tx.observation_areas[area_id]) + delete_observation_area(tx.value.observation_areas[area_id]) created = create_observation_area(area_id, request.area) - tx.observation_areas[area_id] = created + tx.value.observation_areas[area_id] = created if poll_interval is not None: webapp.set_task_period(TASK_POLL_OBSERVATION_AREAS, poll_interval) @@ -91,13 +93,13 @@ def tracer_upsert_observation_area( def tracer_delete_observation_area( area_id: str, ) -> tuple[str, int] | flask.Response: - with db as tx: - if area_id not in tx.observation_areas: + with db.transact() as tx: + if area_id not in tx.value.observation_areas: return "Specified observation area not in system", 404 - area = tx.observation_areas.pop(area_id) + area = tx.value.observation_areas.pop(area_id) area = delete_observation_area(area) remaining_polling_areas = sum( - 1 if a.polls else 0 for a in tx.observation_areas.values() + 1 if a.polls else 0 for a in tx.value.observation_areas.values() ) if not remaining_polling_areas: @@ -180,11 +182,13 @@ def tracer_import_observation_areas() -> tuple[str, int] | flask.Response: "Import of F3548 subscriptions into observation areas is not yet implemented" ) - with db as tx: + with db.transact() as tx: new_obs_areas = [] f3411_subscription_ids = { - a.f3411.subscription_id for a in tx.observation_areas.values() if a.f3411 + a.f3411.subscription_id + for a in tx.value.observation_areas.values() + if a.f3411 } new_obs_areas.extend( a @@ -193,7 +197,9 @@ def tracer_import_observation_areas() -> tuple[str, int] | flask.Response: ) f3548_subscription_ids = { - a.f3548.subscription_id for a in tx.observation_areas.values() if a.f3548 + a.f3548.subscription_id + for a in tx.value.observation_areas.values() + if a.f3548 } new_obs_areas.extend( a @@ -202,7 +208,7 @@ def tracer_import_observation_areas() -> tuple[str, int] | flask.Response: ) for obs_area in new_obs_areas: - tx.observation_areas[obs_area.id] = obs_area + tx.value.observation_areas[obs_area.id] = obs_area return flask.jsonify( ListObservationAreasResponse( @@ -217,9 +223,9 @@ def _shutdown(): f"Cleaning up observation areas from PID {os.getpid()} at {datetime.now(UTC)}..." 
) - with db as tx: - observation_areas: list[ObservationArea] = [v for _, v in tx.observation_areas] - tx.observation_areas.clear() + with db.transact() as tx: + observation_areas = list(tx.value.observation_areas.values()) + tx.value.observation_areas.clear() for area in observation_areas: delete_observation_area(area) diff --git a/monitoring/mock_uss/tracer/routes/ui.py b/monitoring/mock_uss/tracer/routes/ui.py index d2706d484f..061550c453 100644 --- a/monitoring/mock_uss/tracer/routes/ui.py +++ b/monitoring/mock_uss/tracer/routes/ui.py @@ -180,11 +180,10 @@ def tracer_kml_historical(): def _get_validated_obs_area(observation_area_id: str) -> ObservationArea: - with db as tx: - if observation_area_id not in tx.observation_areas: - flask.abort(404, "Specified observation area not found") - area: ObservationArea = tx.observation_areas[observation_area_id] - return area + tx = db.value + if observation_area_id not in tx.observation_areas: + flask.abort(404, "Specified observation area not found") + return tx.observation_areas[observation_area_id] @webapp.route("/tracer/observation_areas//ui", methods=["GET"]) diff --git a/monitoring/mock_uss/tracer/tracer_poll.py b/monitoring/mock_uss/tracer/tracer_poll.py index def9c84f64..886ac30aca 100755 --- a/monitoring/mock_uss/tracer/tracer_poll.py +++ b/monitoring/mock_uss/tracer/tracer_poll.py @@ -46,7 +46,7 @@ class PollingStatus(ImplicitDict): started: bool = False -polling_status = SynchronizedValue( +polling_status = SynchronizedValue[PollingStatus]( PollingStatus(), capacity_bytes=1000, decoder=lambda b: ImplicitDict.parse(json.loads(b.decode("utf-8")), PollingStatus), @@ -60,7 +60,7 @@ class PollingValues(ImplicitDict): last_constraints_result: FetchedEntities | None = None -polling_values = SynchronizedValue( +polling_values = SynchronizedValue[PollingValues]( PollingValues(), decoder=lambda b: ImplicitDict.parse(json.loads(b.decode("utf-8")), PollingValues), ) @@ -73,10 +73,10 @@ def print_no_newline(s): def _log_poll_start(logger): init = False - with polling_status as tx: - if not tx.started: + with polling_status.transact() as tx: + if not tx.value.started: init = True - tx.started = True + tx.value.started = True if init: config = { KEY_TRACER_OUTPUT_FOLDER: webapp.config[KEY_TRACER_OUTPUT_FOLDER], @@ -129,18 +129,17 @@ def poll_isas(area: ObservationArea, logger: tracerlog.Logger) -> None: log_new = False last_result = None - with polling_values as tx: - assert isinstance(tx, PollingValues) - if tx.last_isa_result is None or result.has_different_content_than( - tx.last_isa_result + with polling_values.transact() as tx: + if tx.value.last_isa_result is None or result.has_different_content_than( + tx.value.last_isa_result ): - last_result = tx.last_isa_result + last_result = tx.value.last_isa_result log_new = True - tx.need_line_break = False - tx.last_isa_result = result + tx.value.need_line_break = False + tx.value.last_isa_result = result else: - tx.need_line_break = True - need_line_break = tx.need_line_break + tx.value.need_line_break = True + need_line_break = tx.value.need_line_break log_entry = PollISAs(poll=result, recorded_at=StringBasedDateTime(arrow.utcnow())) if log_new: @@ -174,17 +173,17 @@ def poll_ops( log_new = False last_result = None - with polling_values as tx: - if tx.last_ops_result is None or result.has_different_content_than( - tx.last_ops_result + with polling_values.transact() as tx: + if tx.value.last_ops_result is None or result.has_different_content_than( + tx.value.last_ops_result ): - last_result = 
tx.last_ops_result + last_result = tx.value.last_ops_result log_new = True - tx.need_line_break = False - tx.last_ops_result = result + tx.value.need_line_break = False + tx.value.last_ops_result = result else: - tx.need_line_break = True - need_line_break = tx.need_line_break + tx.value.need_line_break = True + need_line_break = tx.value.need_line_break log_entry = PollOperationalIntents( poll=result, recorded_at=StringBasedDateTime(arrow.utcnow()) @@ -220,15 +219,15 @@ def poll_constraints( log_new = False last_result = None - with polling_values as tx: - if result.has_different_content_than(tx.last_constraints_result): - last_result = tx.last_constraints_result + with polling_values.transact() as tx: + if result.has_different_content_than(tx.value.last_constraints_result): + last_result = tx.value.last_constraints_result log_new = True - tx.need_line_break = False - tx.last_constraints_result = result + tx.value.need_line_break = False + tx.value.last_constraints_result = result else: - tx.need_line_break = True - need_line_break = tx.need_line_break + tx.value.need_line_break = True + need_line_break = tx.value.need_line_break log_entry = PollConstraints( poll=result, recorded_at=StringBasedDateTime(arrow.utcnow()) diff --git a/monitoring/monitorlib/idempotency.py b/monitoring/monitorlib/idempotency.py index e5ed3b1b1e..ebff7b6088 100644 --- a/monitoring/monitorlib/idempotency.py +++ b/monitoring/monitorlib/idempotency.py @@ -51,7 +51,7 @@ def _set_responses(responses: dict[str, Response]) -> bytes: return s.encode("utf-8") -_fulfilled_requests = SynchronizedValue[dict]( +_fulfilled_requests = SynchronizedValue[dict[str, Response]]( {}, decoder=_get_responses, encoder=_set_responses, @@ -59,7 +59,7 @@ def _set_responses(responses: dict[str, Response]) -> bytes: ) -def get_hashed_request_id() -> str | None: +def get_hashed_request_id() -> str: """Retrieves an identifier for the request by hashing key characteristics of the request.""" characteristics = flask.request.method + flask.request.url if flask.request.json: @@ -71,7 +71,7 @@ def get_hashed_request_id() -> str | None: ).decode("utf-8") -def idempotent_request(get_request_id: Callable[[], str | None] | None = None): +def idempotent_request(get_request_id: Callable[[], str] | None = None): """Decorator for idempotent Flask view handlers. 
When subsequent requests are received with the same request identifier, this decorator will use a recent cached @@ -104,20 +104,20 @@ def wrapper(*args, **kwargs): request_id, ) response = cached_requests[request_id] - if response["body"] is not None: - return response["body"], response["code"] + if response.body is not None: + return response.body, response.code else: - return flask.jsonify(response["json"]), response["code"] + return flask.jsonify(response.json), response.code result = fn(*args, **kwargs) to_return = result - response = { - "timestamp": arrow.utcnow().isoformat(), - "code": 200, - "body": None, - "json": None, - } + response = Response( + timestamp=arrow.utcnow().isoformat(), + code=200, + body=None, + json=None, + ) keep_code = False if isinstance(result, tuple): if len(result) == 2: @@ -125,7 +125,7 @@ def wrapper(*args, **kwargs): raise NotImplementedError( f"Unable to cache Flask view handler result where the second 2-tuple element is a '{type(result[1]).__name__}'" ) - response["code"] = result[1] + response.code = result[1] keep_code = True result = result[0] else: @@ -134,15 +134,15 @@ def wrapper(*args, **kwargs): ) if isinstance(result, str): - response["body"] = result - response["json"] = None + response.body = result + response.json = None elif isinstance(result, flask.Response): try: - response["json"] = result.get_json() + response.json = result.get_json() except ValueError: - response["body"] = result.get_data(as_text=True) + response.body = result.get_data(as_text=True) if not keep_code: - response["code"] = result.status_code + response.code = result.status_code else: raise NotImplementedError( f"Unable to cache Flask view handler result of type '{type(result).__name__}'" diff --git a/monitoring/monitorlib/multiprocessing.py b/monitoring/monitorlib/multiprocessing.py index ee177f1fe0..d74efc6413 100644 --- a/monitoring/monitorlib/multiprocessing.py +++ b/monitoring/monitorlib/multiprocessing.py @@ -5,8 +5,6 @@ from multiprocessing.synchronize import RLock as RLockT from typing import Generic, TypeVar -import deprecation - TValue = TypeVar("TValue") @@ -175,20 +173,3 @@ def value(self) -> TValue: def transact(self) -> Transaction[TValue]: return Transaction[TValue](self._lock, self._get_value, self._set_value) - - @deprecation.deprecated(details="Use `value` of transact() method instead.") - def __enter__(self): - if self._transaction: - raise RuntimeError( - "SynchronizedValue transaction started when another transaction was in progress" - ) - self._transaction = self.transact() - self._transaction.__enter__() - return self._transaction.value - - @deprecation.deprecated(details="Use `value` of transact() method instead.") - def __exit__(self, exc_type, exc_val, exc_tb): - if not self._transaction: - return - self._transaction.__exit__(exc_type, exc_val, exc_tb) - self._transaction = None
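For reference, the pattern this PR migrates to: the deprecated `with db as tx:` context manager on `SynchronizedValue` is removed, and call sites switch to explicit `db.transact()` transactions with a parameterized `SynchronizedValue[T]` so that `tx.value` is properly typed. Below is a minimal sketch of the resulting usage, assuming the imports resolve as elsewhere in this repository; the `Counter` model and the `bump()`/`peek()` helpers are hypothetical and for illustration only.

# Sketch of the post-migration SynchronizedValue usage (not part of the diff above).
# The Counter model and bump()/peek() helpers are hypothetical illustrations.
import json

from implicitdict import ImplicitDict

from monitoring.monitorlib.multiprocessing import SynchronizedValue


class Counter(ImplicitDict):
    count: int = 0


# Parameterizing SynchronizedValue types both db.value and tx.value as Counter.
db = SynchronizedValue[Counter](
    Counter(),
    decoder=lambda b: ImplicitDict.parse(json.loads(b.decode("utf-8")), Counter),
)


def bump() -> int:
    # Read-modify-write happens inside transact(); changes made via tx.value are
    # committed back to the shared value when the with-block exits.
    with db.transact() as tx:
        tx.value.count += 1
        return tx.value.count


def peek() -> int:
    # Point-in-time reads that need no transaction use the value property, as the
    # read-only handlers in this diff do (e.g. riddp_get_dp_behavior).
    return db.value.count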
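The geoawareness refactor follows the same idea: the static methods on Database become module-level functions that take the shared SynchronizedValue[Database] handle explicitly, and get_source now raises KeyError for unknown ids instead of returning None. A hypothetical caller after this change might look like the sketch below (the source ids, messages, and status codes are placeholders for illustration only).

# Hypothetical caller of the refactored geoawareness database API.
from monitoring.mock_uss.geoawareness import database
from monitoring.mock_uss.geoawareness.database import db


def describe_source(source_id: str) -> tuple[str, int]:
    try:
        # get_source raises KeyError rather than returning None for unknown ids
        source = database.get_source(db, source_id)
    except KeyError:
        return f"source {source_id} not found or deleted", 404
    return f"source {source_id} is {source.state}", 200


def drop_source(source_id: str) -> tuple[str, int]:
    # delete_source still returns None when there was nothing to delete
    if database.delete_source(db, source_id) is None:
        return f"source {source_id} not found", 404
    return f"source {source_id} deleted", 200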