Skip to content

Commit

Permalink
feat(models): validate active TransitAgency
Browse files Browse the repository at this point in the history
when active=True, validate that:

- there are values for user-facing info fields like names, phone, etc.
- templates exist
  • Loading branch information
thekaveman committed Nov 5, 2024
1 parent 4ec2c38 commit 318e307
Show file tree
Hide file tree
Showing 2 changed files with 92 additions and 28 deletions.
96 changes: 68 additions & 28 deletions benefits/core/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,10 @@
from functools import cached_property
import importlib
import logging
from pathlib import Path
import uuid

from django import template
from django.conf import settings
from django.core.exceptions import ValidationError
from django.contrib.auth.models import Group, User
Expand All @@ -24,6 +26,22 @@
logger = logging.getLogger(__name__)


def template_path(template_name: str) -> Path:
"""Get a `pathlib.Path` for the named template, or None if it can't be found.
A `template_name` is the app-local name, e.g. `enrollment/success.html`.
Adapted from https://stackoverflow.com/a/75863472.
"""
for engine in template.engines.all():
for loader in engine.engine.template_loaders:
for origin in loader.get_template_sources(template_name):
path = Path(origin.name)
if path.exists():
return path
return None


class SecretNameField(models.SlugField):
"""Field that stores the name of a secret held in a secret store.
Expand Down Expand Up @@ -264,6 +282,30 @@ def transit_processor_client_secret(self):
def enrollment_flows(self):
return self.enrollmentflow_set

def clean(self):
if self.active:
errors = {}
message = "This field is required for active transit agencies."
needed = dict(
short_name=self.short_name,
long_name=self.long_name,
phone=self.phone,
info_url=self.info_url,
)
for k, v in needed.items():
if not v:
errors[k] = ValidationError(message)

if not template_path(self.index_template):
errors["index_template"] = ValidationError(f"Template not found: {self.index_template}")
if not template_path(self.eligibility_index_template):
errors["eligibility_index_template"] = ValidationError(
f"Template not found: {self.eligibility_index_template}"
)

if errors:
raise ValidationError(errors)

@staticmethod
def by_id(id):
"""Get a TransitAgency instance by its ID."""
Expand Down Expand Up @@ -493,6 +535,17 @@ def uses_claims_verification(self):
"""True if this flow verifies via the claims provider and has a scope and claim. False otherwise."""
return self.claims_provider is not None and bool(self.claims_scope) and bool(self.claims_eligibility_claim)

@property
def claims_scheme(self):
return self.claims_scheme_override or self.claims_provider.scheme

@property
def claims_all_claims(self):
claims = [self.claims_eligibility_claim]
if self.claims_extra_claims is not None:
claims.extend(self.claims_extra_claims.split())
return claims

@property
def eligibility_verifier(self):
"""A str representing the entity that verifies eligibility for this flow.
Expand Down Expand Up @@ -520,23 +573,6 @@ def enrollment_success_template(self):
else:
return self.enrollment_success_template_override or f"{prefix}--{self.agency_card_name}.html"

def eligibility_form_instance(self, *args, **kwargs):
"""Return an instance of this flow's EligibilityForm, or None."""
if not bool(self.eligibility_form_class):
return None

# inspired by https://stackoverflow.com/a/30941292
module_name, class_name = self.eligibility_form_class.rsplit(".", 1)
FormClass = getattr(importlib.import_module(module_name), class_name)

return FormClass(*args, **kwargs)

@staticmethod
def by_id(id):
"""Get an EnrollmentFlow instance by its ID."""
logger.debug(f"Get {EnrollmentFlow.__name__} by id: {id}")
return EnrollmentFlow.objects.get(id=id)

def clean(self):
supports_expiration = self.supports_expiration
expiration_days = self.expiration_days
Expand All @@ -556,18 +592,22 @@ def clean(self):
if errors:
raise ValidationError(errors)

@property
def claims_scheme(self):
if not self.claims_scheme_override:
return self.claims_provider.scheme
return self.claims_scheme_override
def eligibility_form_instance(self, *args, **kwargs):
"""Return an instance of this flow's EligibilityForm, or None."""
if not bool(self.eligibility_form_class):
return None

@property
def claims_all_claims(self):
claims = [self.claims_eligibility_claim]
if self.claims_extra_claims is not None:
claims.extend(self.claims_extra_claims.split())
return claims
# inspired by https://stackoverflow.com/a/30941292
module_name, class_name = self.eligibility_form_class.rsplit(".", 1)
FormClass = getattr(importlib.import_module(module_name), class_name)

return FormClass(*args, **kwargs)

@staticmethod
def by_id(id):
"""Get an EnrollmentFlow instance by its ID."""
logger.debug(f"Get {EnrollmentFlow.__name__} by id: {id}")
return EnrollmentFlow.objects.get(id=id)


class EnrollmentEvent(models.Model):
Expand Down
24 changes: 24 additions & 0 deletions tests/pytest/core/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -518,6 +518,30 @@ def test_TransitAgency_for_user_in_group_not_linked_to_any_agency():
assert TransitAgency.for_user(user) is None


@pytest.mark.django_db
def test_TransitAgency_clean(model_TransitAgency_inactive):
model_TransitAgency_inactive.short_name = ""
model_TransitAgency_inactive.long_name = ""
model_TransitAgency_inactive.phone = ""
model_TransitAgency_inactive.info_url = ""
model_TransitAgency_inactive.index_template_override = "does/not/exist.html"
model_TransitAgency_inactive.eligibility_index_template_override = "does/not/exist.html"
# agency is inactive, OK to have incomplete fields
model_TransitAgency_inactive.clean()

# now mark it active and expect failure on clean()
model_TransitAgency_inactive.active = True
with pytest.raises(ValidationError) as e:
model_TransitAgency_inactive.clean()

assert "short_name" in e.value.error_dict
assert "long_name" in e.value.error_dict
assert "phone" in e.value.error_dict
assert "info_url" in e.value.error_dict
assert "index_template" in e.value.error_dict
assert "eligibility_index_template" in e.value.error_dict


@pytest.mark.django_db
def test_EnrollmentEvent_create(model_TransitAgency, model_EnrollmentFlow):
ts = timezone.now()
Expand Down

0 comments on commit 318e307

Please sign in to comment.