From 318e30755f515259b9c1cc659785f2b75ab10315 Mon Sep 17 00:00:00 2001 From: Kegan Maher Date: Thu, 31 Oct 2024 19:11:27 +0000 Subject: [PATCH] feat(models): validate active TransitAgency when active=True, validate that: - there are values for user-facing info fields like names, phone, etc. - templates exist --- benefits/core/models.py | 96 ++++++++++++++++++++++---------- tests/pytest/core/test_models.py | 24 ++++++++ 2 files changed, 92 insertions(+), 28 deletions(-) diff --git a/benefits/core/models.py b/benefits/core/models.py index 0d20bf30ac..8a1ec5d6a4 100644 --- a/benefits/core/models.py +++ b/benefits/core/models.py @@ -5,8 +5,10 @@ from functools import cached_property import importlib import logging +from pathlib import Path import uuid +from django import template from django.conf import settings from django.core.exceptions import ValidationError from django.contrib.auth.models import Group, User @@ -24,6 +26,22 @@ logger = logging.getLogger(__name__) +def template_path(template_name: str) -> Path: + """Get a `pathlib.Path` for the named template, or None if it can't be found. + + A `template_name` is the app-local name, e.g. `enrollment/success.html`. + + Adapted from https://stackoverflow.com/a/75863472. + """ + for engine in template.engines.all(): + for loader in engine.engine.template_loaders: + for origin in loader.get_template_sources(template_name): + path = Path(origin.name) + if path.exists(): + return path + return None + + class SecretNameField(models.SlugField): """Field that stores the name of a secret held in a secret store. @@ -264,6 +282,30 @@ def transit_processor_client_secret(self): def enrollment_flows(self): return self.enrollmentflow_set + def clean(self): + if self.active: + errors = {} + message = "This field is required for active transit agencies." + needed = dict( + short_name=self.short_name, + long_name=self.long_name, + phone=self.phone, + info_url=self.info_url, + ) + for k, v in needed.items(): + if not v: + errors[k] = ValidationError(message) + + if not template_path(self.index_template): + errors["index_template"] = ValidationError(f"Template not found: {self.index_template}") + if not template_path(self.eligibility_index_template): + errors["eligibility_index_template"] = ValidationError( + f"Template not found: {self.eligibility_index_template}" + ) + + if errors: + raise ValidationError(errors) + @staticmethod def by_id(id): """Get a TransitAgency instance by its ID.""" @@ -493,6 +535,17 @@ def uses_claims_verification(self): """True if this flow verifies via the claims provider and has a scope and claim. False otherwise.""" return self.claims_provider is not None and bool(self.claims_scope) and bool(self.claims_eligibility_claim) + @property + def claims_scheme(self): + return self.claims_scheme_override or self.claims_provider.scheme + + @property + def claims_all_claims(self): + claims = [self.claims_eligibility_claim] + if self.claims_extra_claims is not None: + claims.extend(self.claims_extra_claims.split()) + return claims + @property def eligibility_verifier(self): """A str representing the entity that verifies eligibility for this flow. @@ -520,23 +573,6 @@ def enrollment_success_template(self): else: return self.enrollment_success_template_override or f"{prefix}--{self.agency_card_name}.html" - def eligibility_form_instance(self, *args, **kwargs): - """Return an instance of this flow's EligibilityForm, or None.""" - if not bool(self.eligibility_form_class): - return None - - # inspired by https://stackoverflow.com/a/30941292 - module_name, class_name = self.eligibility_form_class.rsplit(".", 1) - FormClass = getattr(importlib.import_module(module_name), class_name) - - return FormClass(*args, **kwargs) - - @staticmethod - def by_id(id): - """Get an EnrollmentFlow instance by its ID.""" - logger.debug(f"Get {EnrollmentFlow.__name__} by id: {id}") - return EnrollmentFlow.objects.get(id=id) - def clean(self): supports_expiration = self.supports_expiration expiration_days = self.expiration_days @@ -556,18 +592,22 @@ def clean(self): if errors: raise ValidationError(errors) - @property - def claims_scheme(self): - if not self.claims_scheme_override: - return self.claims_provider.scheme - return self.claims_scheme_override + def eligibility_form_instance(self, *args, **kwargs): + """Return an instance of this flow's EligibilityForm, or None.""" + if not bool(self.eligibility_form_class): + return None - @property - def claims_all_claims(self): - claims = [self.claims_eligibility_claim] - if self.claims_extra_claims is not None: - claims.extend(self.claims_extra_claims.split()) - return claims + # inspired by https://stackoverflow.com/a/30941292 + module_name, class_name = self.eligibility_form_class.rsplit(".", 1) + FormClass = getattr(importlib.import_module(module_name), class_name) + + return FormClass(*args, **kwargs) + + @staticmethod + def by_id(id): + """Get an EnrollmentFlow instance by its ID.""" + logger.debug(f"Get {EnrollmentFlow.__name__} by id: {id}") + return EnrollmentFlow.objects.get(id=id) class EnrollmentEvent(models.Model): diff --git a/tests/pytest/core/test_models.py b/tests/pytest/core/test_models.py index e2be8c339d..0edb000f7f 100644 --- a/tests/pytest/core/test_models.py +++ b/tests/pytest/core/test_models.py @@ -518,6 +518,30 @@ def test_TransitAgency_for_user_in_group_not_linked_to_any_agency(): assert TransitAgency.for_user(user) is None +@pytest.mark.django_db +def test_TransitAgency_clean(model_TransitAgency_inactive): + model_TransitAgency_inactive.short_name = "" + model_TransitAgency_inactive.long_name = "" + model_TransitAgency_inactive.phone = "" + model_TransitAgency_inactive.info_url = "" + model_TransitAgency_inactive.index_template_override = "does/not/exist.html" + model_TransitAgency_inactive.eligibility_index_template_override = "does/not/exist.html" + # agency is inactive, OK to have incomplete fields + model_TransitAgency_inactive.clean() + + # now mark it active and expect failure on clean() + model_TransitAgency_inactive.active = True + with pytest.raises(ValidationError) as e: + model_TransitAgency_inactive.clean() + + assert "short_name" in e.value.error_dict + assert "long_name" in e.value.error_dict + assert "phone" in e.value.error_dict + assert "info_url" in e.value.error_dict + assert "index_template" in e.value.error_dict + assert "eligibility_index_template" in e.value.error_dict + + @pytest.mark.django_db def test_EnrollmentEvent_create(model_TransitAgency, model_EnrollmentFlow): ts = timezone.now()