diff --git a/invenio_userprofiles/config.py b/invenio_userprofiles/config.py index fb3e078..7b1cbd6 100644 --- a/invenio_userprofiles/config.py +++ b/invenio_userprofiles/config.py @@ -8,6 +8,8 @@ """Default configuration.""" +from .forms import ProfileForm + USERPROFILES = True """Enable or disable module extensions.""" @@ -31,3 +33,6 @@ USERPROFILES_READ_ONLY = False """Make the user profiles read-only.""" + +USERPROFILES_FORM_CLASS = ProfileForm +"""Default user profiles form class.""" diff --git a/invenio_userprofiles/ext.py b/invenio_userprofiles/ext.py index d0114e9..06aee47 100644 --- a/invenio_userprofiles/ext.py +++ b/invenio_userprofiles/ext.py @@ -8,7 +8,7 @@ # under the terms of the MIT License; see LICENSE file for more details. """User profiles module for Invenio.""" - +from flask import current_app from flask_menu import current_menu from invenio_i18n import LazyString from invenio_i18n import lazy_gettext as _ @@ -95,10 +95,13 @@ def init_common(app): """Post initialization.""" if app.config["USERPROFILES_EXTEND_SECURITY_FORMS"]: security_ext = app.extensions["security"] + UserProfileForm = app.config.get("USERPROFILES_FORM_CLASS") security_ext.confirm_register_form = confirm_register_form_factory( - security_ext.confirm_register_form + security_ext.confirm_register_form, UserProfileForm + ) + security_ext.register_form = register_form_factory( + security_ext.register_form, UserProfileForm ) - security_ext.register_form = register_form_factory(security_ext.register_form) def init_menu(app): diff --git a/invenio_userprofiles/forms.py b/invenio_userprofiles/forms.py index baf49d7..9769eea 100644 --- a/invenio_userprofiles/forms.py +++ b/invenio_userprofiles/forms.py @@ -123,37 +123,42 @@ def populate_obj(self, user): super().populate_obj(user) -class EmailProfileForm(ProfileForm): - """Form to allow editing of email address.""" +def create_email_profile_form(UserProfileForm): + """Factory function to create EmailProfileForm inheriting from UserProfileForm.""" + + class EmailProfileForm(UserProfileForm): + """Form to allow editing of email address.""" + + email = StringField( + # NOTE: Form field label + _("Email address"), + filters=[ + lambda x: x.lower() if x is not None else x, + ], + validators=[ + email_required, + current_user_email, + email_validator, + unique_user_email, + ], + ) - email = StringField( - # NOTE: Form field label - _("Email address"), - filters=[ - lambda x: x.lower() if x is not None else x, - ], - validators=[ - email_required, - current_user_email, - email_validator, - unique_user_email, - ], - ) + email_repeat = StringField( + # NOTE: Form field label + _("Re-enter email address"), + # NOTE: Form field help text + description=_("Please re-enter your email address."), + filters=[ + lambda x: x.lower() if x else x, + ], + validators=[ + email_required, + # NOTE: Form validation error. + EqualTo("email", message=_("Email addresses do not match.")), + ], + ) - email_repeat = StringField( - # NOTE: Form field label - _("Re-enter email address"), - # NOTE: Form field help text - description=_("Please re-enter your email address."), - filters=[ - lambda x: x.lower() if x else x, - ], - validators=[ - email_required, - # NOTE: Form validation error. - EqualTo("email", message=_("Email addresses do not match.")), - ], - ) + return EmailProfileForm class VerificationForm(FlaskForm): @@ -163,10 +168,10 @@ class VerificationForm(FlaskForm): send_verification_email = SubmitField(_("Resend verification email")) -def register_form_factory(Form): +def register_form_factory(Form, UserProfileForm): """Factory for creating an extended user registration form.""" - class CsrfDisabledProfileForm(ProfileForm): + class CsrfDisabledProfileForm(UserProfileForm): """Subclass of ProfileForm to disable CSRF token in the inner form. This class will always be a inner form field of the parent class @@ -252,10 +257,10 @@ def populate_obj(self, user): super().populate_obj(user) -def confirm_register_form_factory(Form): +def confirm_register_form_factory(Form, UserProfileForm): """Factory for creating a confirm register form with UserProfile fields.""" - class CsrfDisabledProfileForm(ProfileForm): + class CsrfDisabledProfileForm(UserProfileForm): """Subclass of ProfileForm to disable CSRF token in the inner form. This class will always be an inner form field of the parent class diff --git a/invenio_userprofiles/views.py b/invenio_userprofiles/views.py index b5ce4ed..ba13d47 100644 --- a/invenio_userprofiles/views.py +++ b/invenio_userprofiles/views.py @@ -26,7 +26,7 @@ from invenio_db import db from invenio_i18n import lazy_gettext as _ -from .forms import EmailProfileForm, PreferencesForm, ProfileForm, VerificationForm +from .forms import PreferencesForm, VerificationForm, create_email_profile_form from .models import UserProfileProxy @@ -92,6 +92,8 @@ def profile(): def profile_form_factory(): """Create a profile form.""" + UserProfileForm = current_app.config.get("USERPROFILES_FORM_CLASS") + EmailProfileForm = create_email_profile_form(UserProfileForm) if current_app.config["USERPROFILES_EMAIL_ENABLED"]: return EmailProfileForm( formdata=None, @@ -99,7 +101,7 @@ def profile_form_factory(): prefix="profile", ) else: - return ProfileForm( + return UserProfileForm( formdata=None, obj=current_user, prefix="profile", diff --git a/tests/test_forms.py b/tests/test_forms.py index 5292ee9..29ca53f 100644 --- a/tests/test_forms.py +++ b/tests/test_forms.py @@ -9,6 +9,7 @@ """Tests for user profile forms.""" from invenio_userprofiles.forms import ( + ProfileForm, _update_with_csrf_disabled, confirm_register_form_factory, confirm_register_form_preferences_factory, @@ -56,8 +57,8 @@ def test_confirm_register_form_preferences_factory_no_csrf(app): """Test CSRF token is not in confirm form and not in inner forms.""" security = app.extensions["security"] - def factory_profile_preferences(Form): - ProfileForm = confirm_register_form_factory(Form) + def factory_profile_preferences(Form, UserProfileForm): + ProfileForm = confirm_register_form_factory(Form, UserProfileForm) return confirm_register_form_preferences_factory(ProfileForm) rf = _get_form(app, security.confirm_register_form, factory_profile_preferences) @@ -116,7 +117,7 @@ class AForm(parent_form): with app.test_request_context(): extra = _update_with_csrf_disabled() if force_disable_csrf else {} - RF = factory_method(AForm) + RF = factory_method(AForm, ProfileForm) rf = RF(**extra) rf.profile.username.data = "my username"