Skip to content

Commit 6012bff

Browse files
committed
feat(models): validate active TransitAgency
when active=True, validate that: - there are values for user-facing info fields like names, phone, etc. - templates exist
1 parent f57345e commit 6012bff

File tree

2 files changed

+121
-29
lines changed

2 files changed

+121
-29
lines changed

benefits/core/models.py

Lines changed: 68 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,10 @@
55
from functools import cached_property
66
import importlib
77
import logging
8+
from pathlib import Path
89
import uuid
910

11+
from django import template
1012
from django.conf import settings
1113
from django.core.exceptions import ValidationError
1214
from django.contrib.auth.models import Group, User
@@ -24,6 +26,22 @@
2426
logger = logging.getLogger(__name__)
2527

2628

29+
def template_path(template_name: str) -> Path:
30+
"""Get a `pathlib.Path` for the named template, or None if it can't be found.
31+
32+
A `template_name` is the app-local name, e.g. `enrollment/success.html`.
33+
34+
Adapted from https://stackoverflow.com/a/75863472.
35+
"""
36+
for engine in template.engines.all():
37+
for loader in engine.engine.template_loaders:
38+
for origin in loader.get_template_sources(template_name):
39+
path = Path(origin.name)
40+
if path.exists() and path.is_file():
41+
return path
42+
return None
43+
44+
2745
class SecretNameField(models.SlugField):
2846
"""Field that stores the name of a secret held in a secret store.
2947
@@ -264,6 +282,30 @@ def transit_processor_client_secret(self):
264282
def enrollment_flows(self):
265283
return self.enrollmentflow_set
266284

285+
def clean(self):
286+
if self.active:
287+
errors = {}
288+
message = "This field is required for active transit agencies."
289+
needed = dict(
290+
short_name=self.short_name,
291+
long_name=self.long_name,
292+
phone=self.phone,
293+
info_url=self.info_url,
294+
)
295+
for k, v in needed.items():
296+
if not v:
297+
errors[k] = ValidationError(message)
298+
299+
if not template_path(self.index_template):
300+
errors["index_template"] = ValidationError(f"Template not found: {self.index_template}")
301+
if not template_path(self.eligibility_index_template):
302+
errors["eligibility_index_template"] = ValidationError(
303+
f"Template not found: {self.eligibility_index_template}"
304+
)
305+
306+
if errors:
307+
raise ValidationError(errors)
308+
267309
@staticmethod
268310
def by_id(id):
269311
"""Get a TransitAgency instance by its ID."""
@@ -493,6 +535,17 @@ def uses_claims_verification(self):
493535
"""True if this flow verifies via the claims provider and has a scope and claim. False otherwise."""
494536
return self.claims_provider is not None and bool(self.claims_scope) and bool(self.claims_eligibility_claim)
495537

538+
@property
539+
def claims_scheme(self):
540+
return self.claims_scheme_override or self.claims_provider.scheme
541+
542+
@property
543+
def claims_all_claims(self):
544+
claims = [self.claims_eligibility_claim]
545+
if self.claims_extra_claims is not None:
546+
claims.extend(self.claims_extra_claims.split())
547+
return claims
548+
496549
@property
497550
def eligibility_verifier(self):
498551
"""A str representing the entity that verifies eligibility for this flow.
@@ -520,23 +573,6 @@ def enrollment_success_template(self):
520573
else:
521574
return self.enrollment_success_template_override or f"{prefix}--{self.agency_card_name}.html"
522575

523-
def eligibility_form_instance(self, *args, **kwargs):
524-
"""Return an instance of this flow's EligibilityForm, or None."""
525-
if not bool(self.eligibility_form_class):
526-
return None
527-
528-
# inspired by https://stackoverflow.com/a/30941292
529-
module_name, class_name = self.eligibility_form_class.rsplit(".", 1)
530-
FormClass = getattr(importlib.import_module(module_name), class_name)
531-
532-
return FormClass(*args, **kwargs)
533-
534-
@staticmethod
535-
def by_id(id):
536-
"""Get an EnrollmentFlow instance by its ID."""
537-
logger.debug(f"Get {EnrollmentFlow.__name__} by id: {id}")
538-
return EnrollmentFlow.objects.get(id=id)
539-
540576
def clean(self):
541577
supports_expiration = self.supports_expiration
542578
expiration_days = self.expiration_days
@@ -556,18 +592,22 @@ def clean(self):
556592
if errors:
557593
raise ValidationError(errors)
558594

559-
@property
560-
def claims_scheme(self):
561-
if not self.claims_scheme_override:
562-
return self.claims_provider.scheme
563-
return self.claims_scheme_override
595+
def eligibility_form_instance(self, *args, **kwargs):
596+
"""Return an instance of this flow's EligibilityForm, or None."""
597+
if not bool(self.eligibility_form_class):
598+
return None
564599

565-
@property
566-
def claims_all_claims(self):
567-
claims = [self.claims_eligibility_claim]
568-
if self.claims_extra_claims is not None:
569-
claims.extend(self.claims_extra_claims.split())
570-
return claims
600+
# inspired by https://stackoverflow.com/a/30941292
601+
module_name, class_name = self.eligibility_form_class.rsplit(".", 1)
602+
FormClass = getattr(importlib.import_module(module_name), class_name)
603+
604+
return FormClass(*args, **kwargs)
605+
606+
@staticmethod
607+
def by_id(id):
608+
"""Get an EnrollmentFlow instance by its ID."""
609+
logger.debug(f"Get {EnrollmentFlow.__name__} by id: {id}")
610+
return EnrollmentFlow.objects.get(id=id)
571611

572612

573613
class EnrollmentEvent(models.Model):

tests/pytest/core/test_models.py

Lines changed: 53 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,21 @@
11
from datetime import timedelta
2+
from pathlib import Path
3+
24
from django.conf import settings
35
from django.contrib.auth.models import Group, User
46
from django.core.exceptions import ValidationError
57
from django.utils import timezone
68

79
import pytest
810

9-
from benefits.core.models import SecretNameField, EnrollmentFlow, TransitAgency, EnrollmentEvent, EnrollmentMethods
11+
from benefits.core.models import (
12+
template_path,
13+
SecretNameField,
14+
EnrollmentFlow,
15+
TransitAgency,
16+
EnrollmentEvent,
17+
EnrollmentMethods,
18+
)
1019
import benefits.secrets
1120

1221

@@ -16,6 +25,25 @@ def mock_requests_get_pem_data(mocker):
1625
return mocker.patch("benefits.core.models.requests.get", return_value=mocker.Mock(text="PEM text"))
1726

1827

28+
@pytest.mark.django_db
29+
@pytest.mark.parametrize(
30+
"input_template,expected_path",
31+
[
32+
("error.html", f"{settings.BASE_DIR}/benefits/templates/error.html"),
33+
("core/index.html", f"{settings.BASE_DIR}/benefits/core/templates/core/index.html"),
34+
("eligibility/start.html", f"{settings.BASE_DIR}/benefits/eligibility/templates/eligibility/start.html"),
35+
("", None),
36+
("nope.html", None),
37+
("core/not-there.html", None),
38+
],
39+
)
40+
def test_template_path(input_template, expected_path):
41+
if expected_path:
42+
assert template_path(input_template) == Path(expected_path)
43+
else:
44+
assert template_path(input_template) is None
45+
46+
1947
def test_SecretNameField_init():
2048
field = SecretNameField()
2149

@@ -518,6 +546,30 @@ def test_TransitAgency_for_user_in_group_not_linked_to_any_agency():
518546
assert TransitAgency.for_user(user) is None
519547

520548

549+
@pytest.mark.django_db
550+
def test_TransitAgency_clean(model_TransitAgency_inactive):
551+
model_TransitAgency_inactive.short_name = ""
552+
model_TransitAgency_inactive.long_name = ""
553+
model_TransitAgency_inactive.phone = ""
554+
model_TransitAgency_inactive.info_url = ""
555+
model_TransitAgency_inactive.index_template_override = "does/not/exist.html"
556+
model_TransitAgency_inactive.eligibility_index_template_override = "does/not/exist.html"
557+
# agency is inactive, OK to have incomplete fields
558+
model_TransitAgency_inactive.clean()
559+
560+
# now mark it active and expect failure on clean()
561+
model_TransitAgency_inactive.active = True
562+
with pytest.raises(ValidationError) as e:
563+
model_TransitAgency_inactive.clean()
564+
565+
assert "short_name" in e.value.error_dict
566+
assert "long_name" in e.value.error_dict
567+
assert "phone" in e.value.error_dict
568+
assert "info_url" in e.value.error_dict
569+
assert "index_template" in e.value.error_dict
570+
assert "eligibility_index_template" in e.value.error_dict
571+
572+
521573
@pytest.mark.django_db
522574
def test_EnrollmentEvent_create(model_TransitAgency, model_EnrollmentFlow):
523575
ts = timezone.now()

0 commit comments

Comments
 (0)