Skip to content

Commit

Permalink
feat(models): validate active TransitAgency
Browse files Browse the repository at this point in the history
when active=True, validate that:

- there are values for user-facing info fields like names, logos, phone, etc.
- templates exist
  • Loading branch information
thekaveman committed Nov 19, 2024
1 parent e9f0e5a commit d7da0da
Show file tree
Hide file tree
Showing 3 changed files with 133 additions and 28 deletions.
98 changes: 70 additions & 28 deletions benefits/core/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,10 @@
import importlib
import logging
import os
from pathlib import Path
import uuid

from django import template
from django.conf import settings
from django.core.exceptions import ValidationError
from django.contrib.auth.models import Group, User
Expand All @@ -25,6 +27,23 @@
logger = logging.getLogger(__name__)


def template_path(template_name: str) -> Path:
"""Get a `pathlib.Path` for the named template, or None if it can't be found.
A `template_name` is the app-local name, e.g. `enrollment/success.html`.
Adapted from https://stackoverflow.com/a/75863472.
"""
if template_name:
for engine in template.engines.all():
for loader in engine.engine.template_loaders:
for origin in loader.get_template_sources(template_name):
path = Path(origin.name)
if path.exists() and path.is_file():
return path
return None


class SecretNameField(models.SlugField):
"""Field that stores the name of a secret held in a secret store.
Expand Down Expand Up @@ -292,6 +311,31 @@ def transit_processor_client_secret(self):
def enrollment_flows(self):
return self.enrollmentflow_set

def clean(self):
if self.active:
errors = {}
message = "This field is required for active transit agencies."
needed = dict(
short_name=self.short_name,
long_name=self.long_name,
phone=self.phone,
info_url=self.info_url,
logo_large=self.logo_large,
logo_small=self.logo_small,
)
for k, v in needed.items():
if not v:
errors[k] = ValidationError(message)
if errors:
raise ValidationError(errors)

# since templates are calculated from the pattern or the override field
# we can't add a field-level validation error
# so just raise directly for a missing template
for t in [self.index_template, self.eligibility_index_template]:
if not template_path(t):
raise ValidationError(f"Template not found: {t}")

@staticmethod
def by_id(id):
"""Get a TransitAgency instance by its ID."""
Expand Down Expand Up @@ -520,6 +564,17 @@ def uses_claims_verification(self):
"""True if this flow verifies via the claims provider and has a scope and claim. False otherwise."""
return self.claims_provider is not None and bool(self.claims_scope) and bool(self.claims_eligibility_claim)

@property
def claims_scheme(self):
return self.claims_scheme_override or self.claims_provider.scheme

@property
def claims_all_claims(self):
claims = [self.claims_eligibility_claim]
if self.claims_extra_claims is not None:
claims.extend(self.claims_extra_claims.split())
return claims

@property
def eligibility_verifier(self):
"""A str representing the entity that verifies eligibility for this flow.
Expand Down Expand Up @@ -547,23 +602,6 @@ def enrollment_success_template(self):
else:
return self.enrollment_success_template_override or f"{prefix}--{self.agency_card_name}.html"

def eligibility_form_instance(self, *args, **kwargs):
"""Return an instance of this flow's EligibilityForm, or None."""
if not bool(self.eligibility_form_class):
return None

# inspired by https://stackoverflow.com/a/30941292
module_name, class_name = self.eligibility_form_class.rsplit(".", 1)
FormClass = getattr(importlib.import_module(module_name), class_name)

return FormClass(*args, **kwargs)

@staticmethod
def by_id(id):
"""Get an EnrollmentFlow instance by its ID."""
logger.debug(f"Get {EnrollmentFlow.__name__} by id: {id}")
return EnrollmentFlow.objects.get(id=id)

def clean(self):
supports_expiration = self.supports_expiration
expiration_days = self.expiration_days
Expand All @@ -583,18 +621,22 @@ def clean(self):
if errors:
raise ValidationError(errors)

@property
def claims_scheme(self):
if not self.claims_scheme_override:
return self.claims_provider.scheme
return self.claims_scheme_override
def eligibility_form_instance(self, *args, **kwargs):
"""Return an instance of this flow's EligibilityForm, or None."""
if not bool(self.eligibility_form_class):
return None

@property
def claims_all_claims(self):
claims = [self.claims_eligibility_claim]
if self.claims_extra_claims is not None:
claims.extend(self.claims_extra_claims.split())
return claims
# inspired by https://stackoverflow.com/a/30941292
module_name, class_name = self.eligibility_form_class.rsplit(".", 1)
FormClass = getattr(importlib.import_module(module_name), class_name)

return FormClass(*args, **kwargs)

@staticmethod
def by_id(id):
"""Get an EnrollmentFlow instance by its ID."""
logger.debug(f"Get {EnrollmentFlow.__name__} by id: {id}")
return EnrollmentFlow.objects.get(id=id)


class EnrollmentEvent(models.Model):
Expand Down
2 changes: 2 additions & 0 deletions tests/pytest/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,6 +195,8 @@ def model_TransitAgency(model_PemData, model_TransitProcessor):
eligibility_api_jws_signing_alg="alg",
index_template_override="core/agency-index.html",
eligibility_index_template_override="eligibility/index.html",
logo_large="agencies/cst-lg.png",
logo_small="agencies/cst-sm.png",
)

return agency
Expand Down
61 changes: 61 additions & 0 deletions tests/pytest/core/test_models.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
from datetime import timedelta
from pathlib import Path

from django.conf import settings
from django.contrib.auth.models import Group, User
from django.core.exceptions import ValidationError
Expand All @@ -7,6 +9,7 @@
import pytest

from benefits.core.models import (
template_path,
SecretNameField,
EnrollmentFlow,
TransitAgency,
Expand All @@ -24,6 +27,25 @@ def mock_requests_get_pem_data(mocker):
return mocker.patch("benefits.core.models.requests.get", return_value=mocker.Mock(text="PEM text"))


@pytest.mark.django_db
@pytest.mark.parametrize(
"input_template,expected_path",
[
("error.html", f"{settings.BASE_DIR}/benefits/templates/error.html"),
("core/index.html", f"{settings.BASE_DIR}/benefits/core/templates/core/index.html"),
("eligibility/start.html", f"{settings.BASE_DIR}/benefits/eligibility/templates/eligibility/start.html"),
("", None),
("nope.html", None),
("core/not-there.html", None),
],
)
def test_template_path(input_template, expected_path):
if expected_path:
assert template_path(input_template) == Path(expected_path)
else:
assert template_path(input_template) is None


def test_SecretNameField_init():
field = SecretNameField()

Expand Down Expand Up @@ -538,6 +560,45 @@ def test_agency_logo_large(model_TransitAgency):
assert agency_logo_large(model_TransitAgency, "local_filename.png") == "agencies/test-lg.png"


@pytest.mark.django_db
def test_TransitAgency_clean(model_TransitAgency_inactive):
model_TransitAgency_inactive.short_name = ""
model_TransitAgency_inactive.long_name = ""
model_TransitAgency_inactive.phone = ""
model_TransitAgency_inactive.info_url = ""
model_TransitAgency_inactive.logo_large = ""
model_TransitAgency_inactive.logo_small = ""
# agency is inactive, OK to have incomplete fields
model_TransitAgency_inactive.clean()

# now mark it active and expect failure on clean()
model_TransitAgency_inactive.active = True
with pytest.raises(ValidationError) as e:
model_TransitAgency_inactive.clean()

errors = e.value.error_dict

assert "short_name" in errors
assert "long_name" in errors
assert "phone" in errors
assert "info_url" in errors
assert "logo_large" in errors
assert "logo_small" in errors


@pytest.mark.django_db
@pytest.mark.parametrize("template_attribute", ["index_template_override", "eligibility_index_template_override"])
def test_TransitAgency_clean_templates(model_TransitAgency_inactive, template_attribute):
setattr(model_TransitAgency_inactive, template_attribute, "does/not/exist.html")
# agency is inactive, OK to have missing template
model_TransitAgency_inactive.clean()

# now mark it active and expect failure on clean()
model_TransitAgency_inactive.active = True
with pytest.raises(ValidationError, match="Template not found: does/not/exist.html"):
model_TransitAgency_inactive.clean()


@pytest.mark.django_db
def test_EnrollmentEvent_create(model_TransitAgency, model_EnrollmentFlow):
ts = timezone.now()
Expand Down

0 comments on commit d7da0da

Please sign in to comment.