Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 0 additions & 23 deletions api/app/handlers.py

This file was deleted.

14 changes: 9 additions & 5 deletions api/app/pagination.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import base64
import json
from collections import OrderedDict
from typing import Any, List, Optional, Type

from drf_yasg import openapi # type: ignore[import-untyped]
from drf_yasg.inspectors import PaginatorInspector # type: ignore[import-untyped]
Expand All @@ -17,7 +18,7 @@ class CustomPagination(PageNumberPagination):

class EdgeIdentityPaginationInspector(PaginatorInspector): # type: ignore[misc]
def get_paginator_parameters(
self, paginator: BasePagination
self, paginator: Type[BasePagination]
) -> list[openapi.Parameter]:
"""
:param BasePagination paginator: the paginator
Expand All @@ -40,13 +41,14 @@ def get_paginator_parameters(
),
]

def get_paginated_response(self, paginator, response_schema): # type: ignore[no-untyped-def]
def get_paginated_response(
self, paginator: Type[BasePagination], response_schema: openapi.Schema
) -> openapi.Schema:
"""
:param BasePagination paginator: the paginator
:param openapi.Schema response_schema: the response schema that must be paged.
:rtype: openapi.Schema
"""

return openapi.Schema(
type=openapi.TYPE_OBJECT,
properties=OrderedDict(
Expand All @@ -69,7 +71,9 @@ class EdgeIdentityPagination(CustomPagination):
max_page_size = 100
page_size = 100

def paginate_queryset(self, dynamo_queryset, request, view=None): # type: ignore[no-untyped-def]
def paginate_queryset(
self, dynamo_queryset: Any, request: Any, view: Optional[Any] = None
) -> Optional[List[Any]]:
last_evaluated_key = dynamo_queryset.get("LastEvaluatedKey")
if last_evaluated_key:
self.last_evaluated_key = base64.b64encode(
Expand All @@ -81,7 +85,7 @@ def paginate_queryset(self, dynamo_queryset, request, view=None): # type: ignor
for identity_document in dynamo_queryset["Items"]
]

def get_paginated_response(self, data) -> Response: # type: ignore[no-untyped-def]
def get_paginated_response(self, data: Any) -> Response:
"""
Note: "If the size of the Query result set is larger than 1 MB, ScannedCount
and Count represent only a partial count of the total items"
Expand Down
36 changes: 25 additions & 11 deletions api/app/routers.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
import logging
import random
from enum import Enum
from typing import Any, Optional, Type

from django.conf import settings
from django.core.cache import cache
from django.db import connections
from django.db.models import Model
from django_stubs_ext.db.router import TypedDatabaseRouter
Copy link
Member

@khvn26 khvn26 Apr 8, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We need to add django_stubs_ext as an explicit runtime dependency.

Please run poetry add django-stubs-ext and commit the resulting changes to poetry.lock and pyproject.toml.

An alternative to this would be adding a conditional import:

Suggested change
from django_stubs_ext.db.router import TypedDatabaseRouter
if TYPE_CHECKING:
from django_stubs_ext.db.router import TypedDatabaseRouter
else:
TypedDatabaseRouter = object


from .exceptions import ImproperlyConfiguredError

Expand Down Expand Up @@ -46,8 +49,8 @@ def connection_check(database: str) -> bool:
return usable


class PrimaryReplicaRouter:
def db_for_read(self, model, **hints): # type: ignore[no-untyped-def]
class PrimaryReplicaRouter(TypedDatabaseRouter):
def db_for_read(self, model: Type[Model], **hints: Any) -> Optional[str]:
if settings.NUM_DB_REPLICAS == 0:
return "default"

Expand Down Expand Up @@ -75,10 +78,12 @@ def db_for_read(self, model, **hints): # type: ignore[no-untyped-def]
)
return "default"

def db_for_write(self, model, **hints): # type: ignore[no-untyped-def]
def db_for_write(self, model: Type[Model], **hints: Any) -> Optional[str]:
return "default"

def allow_relation(self, obj1, obj2, **hints): # type: ignore[no-untyped-def]
def allow_relation(
self, obj1: Type[Model], obj2: Type[Model], **hints: Any
) -> Optional[bool]:
"""
Relations between objects are allowed if both objects are
in the primary/replica pool.
Expand All @@ -95,10 +100,12 @@ def allow_relation(self, obj1, obj2, **hints): # type: ignore[no-untyped-def]
return True
return None

def allow_migrate(self, db, app_label, model_name=None, **hints): # type: ignore[no-untyped-def]
def allow_migrate(
self, db: str, app_label: str, model_name: str | None = None, **hints: Any
) -> Optional[bool]:
return db == "default"

def _get_replica(self, replicas: list[str]) -> None | str: # type: ignore[return]
def _get_replica(self, replicas: list[str]) -> None | str:
while replicas:
if settings.REPLICA_READ_STRATEGY == ReplicaReadStrategy.DISTRIBUTED:
database = random.choice(replicas)
Expand All @@ -119,27 +126,32 @@ def _get_replica(self, replicas: list[str]) -> None | str: # type: ignore[retur
if connection_check(database):
return database

# If no replicas are available, return None
return None


class AnalyticsRouter:
class AnalyticsRouter(TypedDatabaseRouter):
route_app_labels = ["app_analytics"]

def db_for_read(self, model, **hints): # type: ignore[no-untyped-def]
def db_for_read(self, model: Type[Model], **hints: Any) -> Optional[str]:
"""
Attempts to read analytics models go to 'analytics' database.
"""
if model._meta.app_label in self.route_app_labels:
return "analytics"
return None

def db_for_write(self, model, **hints): # type: ignore[no-untyped-def]
def db_for_write(self, model: Type[Model], **hints: Any) -> Optional[str]:
"""
Attempts to write analytics models go to 'analytics' database.
"""
if model._meta.app_label in self.route_app_labels:
return "analytics"
return None

def allow_relation(self, obj1, obj2, **hints): # type: ignore[no-untyped-def]
def allow_relation(
self, obj1: Type[Model], obj2: Type[Model], **hints: Any
) -> Optional[bool]:
"""
Relations between objects are allowed if both objects are
in the analytics database.
Expand All @@ -151,7 +163,9 @@ def allow_relation(self, obj1, obj2, **hints): # type: ignore[no-untyped-def]
return True
return None

def allow_migrate(self, db, app_label, model_name=None, **hints): # type: ignore[no-untyped-def]
def allow_migrate(
self, db: str, app_label: str, model_name: str | None = None, **hints: Any
) -> Optional[bool]:
"""
Make sure the analytics app only appears in the 'analytics' database
"""
Expand Down
3 changes: 2 additions & 1 deletion api/poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions api/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,7 @@ tzdata = "^2024.1"
djangorestframework-simplejwt = "^5.3.1"
structlog = "^24.4.0"
prometheus-client = "^0.21.1"
django-stubs-ext = "^5.1.3"

[tool.poetry.group.auth-controller]
optional = true
Expand Down
14 changes: 7 additions & 7 deletions api/tests/unit/app/test_unit_app_routers.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,12 +39,12 @@ def test_replica_router_db_for_read_with_one_offline_replica(
router = PrimaryReplicaRouter()

# When
result = router.db_for_read(FFAdminUser) # type: ignore[no-untyped-call]
result = router.db_for_read(FFAdminUser)

# Then
# Read strategy DISTRIBUTED is random, so just this is a check
# against loading the primary or one of the cross region replicas
assert result.startswith("replica_")
assert result is not None and result.startswith("replica_")

# Check that the number of replica call counts is as expected.
conn_call_count = 2
Expand Down Expand Up @@ -85,12 +85,12 @@ def test_replica_router_db_for_read_with_local_offline_replicas(
router = PrimaryReplicaRouter()

# When
result = router.db_for_read(FFAdminUser) # type: ignore[no-untyped-call]
result = router.db_for_read(FFAdminUser)

# Then
# Read strategy DISTRIBUTED is random, so just this is a check
# against loading the primary or one of the cross region replicas
assert result.startswith("cross_region_replica_")
assert result is not None and result.startswith("cross_region_replica_")

# Check that the number of replica call counts is as expected.
conn_call_count = 6
Expand Down Expand Up @@ -120,7 +120,7 @@ def test_replica_router_db_for_read_with_all_offline_replicas(
router = PrimaryReplicaRouter()

# When
result = router.db_for_read(FFAdminUser) # type: ignore[no-untyped-call]
result = router.db_for_read(FFAdminUser)

# Then
# Fallback to primary database if all replicas are offline.
Expand Down Expand Up @@ -154,7 +154,7 @@ def test_replica_router_db_with_sequential_read(
router = PrimaryReplicaRouter()

# When
result = router.db_for_read(FFAdminUser) # type: ignore[no-untyped-call]
result = router.db_for_read(FFAdminUser)

# Then
# Fallback from first replica to second one.
Expand Down Expand Up @@ -186,7 +186,7 @@ def test_replica_router_db_no_replicas(
router = PrimaryReplicaRouter()

# When
result = router.db_for_read(FFAdminUser) # type: ignore[no-untyped-call]
result = router.db_for_read(FFAdminUser)

# Then
# Should always use primary database.
Expand Down