diff --git a/benefits/core/models.py b/benefits/core/models.py index 54493d86d1..7c69072490 100644 --- a/benefits/core/models.py +++ b/benefits/core/models.py @@ -6,8 +6,10 @@ import importlib import logging import os +from pathlib import Path import uuid +from django import template from django.conf import settings from django.core.exceptions import ValidationError from django.contrib.auth.models import Group, User @@ -25,6 +27,23 @@ logger = logging.getLogger(__name__) +def template_path(template_name: str) -> Path: + """Get a `pathlib.Path` for the named template, or None if it can't be found. + + A `template_name` is the app-local name, e.g. `enrollment/success.html`. + + Adapted from https://stackoverflow.com/a/75863472. + """ + if template_name: + for engine in template.engines.all(): + for loader in engine.engine.template_loaders: + for origin in loader.get_template_sources(template_name): + path = Path(origin.name) + if path.exists() and path.is_file(): + return path + return None + + class SecretNameField(models.SlugField): """Field that stores the name of a secret held in a secret store. @@ -292,6 +311,31 @@ def transit_processor_client_secret(self): def enrollment_flows(self): return self.enrollmentflow_set + def clean(self): + if self.active: + errors = {} + message = "This field is required for active transit agencies." + needed = dict( + short_name=self.short_name, + long_name=self.long_name, + phone=self.phone, + info_url=self.info_url, + logo_large=self.logo_large, + logo_small=self.logo_small, + ) + for k, v in needed.items(): + if not v: + errors[k] = ValidationError(message) + if errors: + raise ValidationError(errors) + + # since templates are calculated from the pattern or the override field + # we can't add a field-level validation error + # so just raise directly for a missing template + for t in [self.index_template, self.eligibility_index_template]: + if not template_path(t): + raise ValidationError(f"Template not found: {t}") + @staticmethod def by_id(id): """Get a TransitAgency instance by its ID.""" @@ -521,6 +565,17 @@ def uses_claims_verification(self): """True if this flow verifies via the claims provider and has a scope and claim. False otherwise.""" return self.claims_provider is not None and bool(self.claims_scope) and bool(self.claims_eligibility_claim) + @property + def claims_scheme(self): + return self.claims_scheme_override or self.claims_provider.scheme + + @property + def claims_all_claims(self): + claims = [self.claims_eligibility_claim] + if self.claims_extra_claims is not None: + claims.extend(self.claims_extra_claims.split()) + return claims + @property def eligibility_verifier(self): """A str representing the entity that verifies eligibility for this flow. @@ -548,23 +603,6 @@ def enrollment_success_template(self): else: return self.enrollment_success_template_override or f"{prefix}--{self.agency_card_name}.html" - def eligibility_form_instance(self, *args, **kwargs): - """Return an instance of this flow's EligibilityForm, or None.""" - if not bool(self.eligibility_form_class): - return None - - # inspired by https://stackoverflow.com/a/30941292 - module_name, class_name = self.eligibility_form_class.rsplit(".", 1) - FormClass = getattr(importlib.import_module(module_name), class_name) - - return FormClass(*args, **kwargs) - - @staticmethod - def by_id(id): - """Get an EnrollmentFlow instance by its ID.""" - logger.debug(f"Get {EnrollmentFlow.__name__} by id: {id}") - return EnrollmentFlow.objects.get(id=id) - def clean(self): supports_expiration = self.supports_expiration expiration_days = self.expiration_days @@ -584,18 +622,22 @@ def clean(self): if errors: raise ValidationError(errors) - @property - def claims_scheme(self): - if not self.claims_scheme_override: - return self.claims_provider.scheme - return self.claims_scheme_override + def eligibility_form_instance(self, *args, **kwargs): + """Return an instance of this flow's EligibilityForm, or None.""" + if not bool(self.eligibility_form_class): + return None - @property - def claims_all_claims(self): - claims = [self.claims_eligibility_claim] - if self.claims_extra_claims is not None: - claims.extend(self.claims_extra_claims.split()) - return claims + # inspired by https://stackoverflow.com/a/30941292 + module_name, class_name = self.eligibility_form_class.rsplit(".", 1) + FormClass = getattr(importlib.import_module(module_name), class_name) + + return FormClass(*args, **kwargs) + + @staticmethod + def by_id(id): + """Get an EnrollmentFlow instance by its ID.""" + logger.debug(f"Get {EnrollmentFlow.__name__} by id: {id}") + return EnrollmentFlow.objects.get(id=id) class EnrollmentEvent(models.Model): diff --git a/tests/pytest/conftest.py b/tests/pytest/conftest.py index f60533b270..97d84c8c13 100644 --- a/tests/pytest/conftest.py +++ b/tests/pytest/conftest.py @@ -195,6 +195,8 @@ def model_TransitAgency(model_PemData, model_TransitProcessor): eligibility_api_jws_signing_alg="alg", index_template_override="core/agency-index.html", eligibility_index_template_override="eligibility/index.html", + logo_large="agencies/cst-lg.png", + logo_small="agencies/cst-sm.png", ) return agency diff --git a/tests/pytest/core/test_models.py b/tests/pytest/core/test_models.py index 9f7a3e8766..8916109339 100644 --- a/tests/pytest/core/test_models.py +++ b/tests/pytest/core/test_models.py @@ -1,4 +1,6 @@ from datetime import timedelta +from pathlib import Path + from django.conf import settings from django.contrib.auth.models import Group, User from django.core.exceptions import ValidationError @@ -7,6 +9,7 @@ import pytest from benefits.core.models import ( + template_path, SecretNameField, EnrollmentFlow, TransitAgency, @@ -24,6 +27,25 @@ def mock_requests_get_pem_data(mocker): return mocker.patch("benefits.core.models.requests.get", return_value=mocker.Mock(text="PEM text")) +@pytest.mark.django_db +@pytest.mark.parametrize( + "input_template,expected_path", + [ + ("error.html", f"{settings.BASE_DIR}/benefits/templates/error.html"), + ("core/index.html", f"{settings.BASE_DIR}/benefits/core/templates/core/index.html"), + ("eligibility/start.html", f"{settings.BASE_DIR}/benefits/eligibility/templates/eligibility/start.html"), + ("", None), + ("nope.html", None), + ("core/not-there.html", None), + ], +) +def test_template_path(input_template, expected_path): + if expected_path: + assert template_path(input_template) == Path(expected_path) + else: + assert template_path(input_template) is None + + def test_SecretNameField_init(): field = SecretNameField() @@ -538,6 +560,45 @@ def test_agency_logo_large(model_TransitAgency): assert agency_logo_large(model_TransitAgency, "local_filename.png") == "agencies/test-lg.png" +@pytest.mark.django_db +def test_TransitAgency_clean(model_TransitAgency_inactive): + model_TransitAgency_inactive.short_name = "" + model_TransitAgency_inactive.long_name = "" + model_TransitAgency_inactive.phone = "" + model_TransitAgency_inactive.info_url = "" + model_TransitAgency_inactive.logo_large = "" + model_TransitAgency_inactive.logo_small = "" + # agency is inactive, OK to have incomplete fields + model_TransitAgency_inactive.clean() + + # now mark it active and expect failure on clean() + model_TransitAgency_inactive.active = True + with pytest.raises(ValidationError) as e: + model_TransitAgency_inactive.clean() + + errors = e.value.error_dict + + assert "short_name" in errors + assert "long_name" in errors + assert "phone" in errors + assert "info_url" in errors + assert "logo_large" in errors + assert "logo_small" in errors + + +@pytest.mark.django_db +@pytest.mark.parametrize("template_attribute", ["index_template_override", "eligibility_index_template_override"]) +def test_TransitAgency_clean_templates(model_TransitAgency_inactive, template_attribute): + setattr(model_TransitAgency_inactive, template_attribute, "does/not/exist.html") + # agency is inactive, OK to have missing template + model_TransitAgency_inactive.clean() + + # now mark it active and expect failure on clean() + model_TransitAgency_inactive.active = True + with pytest.raises(ValidationError, match="Template not found: does/not/exist.html"): + model_TransitAgency_inactive.clean() + + @pytest.mark.django_db def test_EnrollmentEvent_create(model_TransitAgency, model_EnrollmentFlow): ts = timezone.now()