Skip to content

Commit

Permalink
sso: send_update_ping if groups list changes for a user (#208)
Browse files Browse the repository at this point in the history
This happens if, for instance, the groups are modified from the
Group admin page, and not from the User admin page.
  • Loading branch information
lukegb committed Aug 12, 2018
1 parent ca0637d commit 8bcd444
Show file tree
Hide file tree
Showing 4 changed files with 167 additions and 8 deletions.
32 changes: 31 additions & 1 deletion spongeauth/sso/models.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from django.conf import settings
from django.db.models.signals import post_save
from django.db.models.signals import post_save, m2m_changed
from django.dispatch import receiver

from accounts.models import User, Avatar
Expand All @@ -18,6 +18,36 @@ def on_user_save(sender, instance=None, **kwargs):
send_update_ping(instance)


@receiver(m2m_changed, sender=User.groups.through)
def on_group_change(sender, instance=None, pk_set=None, action=None, reverse=None, **kwargs):
if action not in ('post_add', 'post_remove'):
return
if not _can_ping():
return # do nothing, again
if reverse:
instances = User.objects.filter(pk__in=pk_set)
else:
instances = [instance]
for instance in instances:
send_update_ping(instance)


@receiver(m2m_changed, sender=User.groups.through)
def on_group_clear(sender, instance=None, pk_set=None, action=None, reverse=None, **kwargs):
if action != 'pre_clear':
return
if not _can_ping():
return # do nothing, again
if reverse:
instances = list(instance.user_set.all())
groups = [instance.id]
else:
instances = [instance]
groups = list(instance.groups.values_list('id', flat=True))
for instance in instances:
send_update_ping(instance, exclude_groups=groups)


@receiver(post_save, sender=Avatar)
def on_avatar_save(sender, instance=None, **kwargs):
if not _can_ping():
Expand Down
52 changes: 51 additions & 1 deletion spongeauth/sso/tests/test_ping_on_save.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import faker
import pytest

from accounts.tests.factories import UserFactory, AvatarFactory
from accounts.tests.factories import UserFactory, GroupFactory, AvatarFactory
import sso.models

TEST_SSO_ENDPOINTS = {
Expand Down Expand Up @@ -39,6 +39,56 @@ def test_pings_on_user_save(fake_send_update_ping, settings):
fake_send_update_ping.assert_called_once_with(user)


@unittest.mock.patch('sso.models.send_update_ping')
@pytest.mark.django_db
def test_pings_on_group_save_forward(fake_send_update_ping, settings):
user = UserFactory.create()
group = GroupFactory.create()
settings.SSO_ENDPOINTS = TEST_SSO_ENDPOINTS
fake_send_update_ping.assert_not_called()

user.groups.add(group)
fake_send_update_ping.assert_called_once_with(user)


@unittest.mock.patch('sso.models.send_update_ping')
@pytest.mark.django_db
def test_pings_on_group_save(fake_send_update_ping, settings):
user = UserFactory.create()
group = GroupFactory.create()
settings.SSO_ENDPOINTS = TEST_SSO_ENDPOINTS
fake_send_update_ping.assert_not_called()

group.user_set.add(user)
fake_send_update_ping.assert_called_once_with(user)


@unittest.mock.patch('sso.models.send_update_ping')
@pytest.mark.django_db
def test_pings_on_group_clear_forward(fake_send_update_ping, settings):
user = UserFactory.create()
group = GroupFactory.create()
user.groups.set([group])
settings.SSO_ENDPOINTS = TEST_SSO_ENDPOINTS
fake_send_update_ping.assert_not_called()

user.groups.clear()
assert list(fake_send_update_ping.call_args[1]['exclude_groups']) == [group.id]


@unittest.mock.patch('sso.models.send_update_ping')
@pytest.mark.django_db
def test_pings_on_group_clear(fake_send_update_ping, settings):
user = UserFactory.create()
group = GroupFactory.create()
group.user_set.set([user])
settings.SSO_ENDPOINTS = TEST_SSO_ENDPOINTS
fake_send_update_ping.assert_not_called()

group.user_set.clear()
fake_send_update_ping.assert_called_once_with(user, exclude_groups=[group.id])


@unittest.mock.patch('sso.models.send_update_ping')
@pytest.mark.django_db
def test_no_pings_on_avatar_save_not_current(fake_send_update_ping, settings):
Expand Down
75 changes: 74 additions & 1 deletion spongeauth/sso/tests/test_send_update_ping.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import pytest

from accounts.tests.factories import UserFactory
from accounts.tests.factories import UserFactory, GroupFactory
from .. import discourse_sso
from ..utils import send_update_ping

Expand Down Expand Up @@ -65,3 +65,76 @@ def test_send_update_ping(settings):
'add_groups': 'aardvark,banana,carrot',
'remove_groups': 'gingerbread,horseradish,indigo',
})


@pytest.mark.django_db
def test_send_update_ping_better(settings):
with unittest.mock.patch.object(discourse_sso, 'DiscourseSigner') as \
fake_discourse_signer_cls:
fake_send_post = unittest.mock.MagicMock()
fake_discourse_signer = fake_discourse_signer_cls.return_value
fake_discourse_signer.sign.return_value = (
'payload', 'signature')

excluded_group = GroupFactory.create(
internal_only=False, internal_name='1-excluded')
in_group = GroupFactory.create(
internal_only=False, internal_name='2-in')
not_in_group = GroupFactory.create(
internal_only=False, internal_name='3-not-in')
in_internal_group = GroupFactory.create(
internal_only=True, internal_name='4-internal-in')
not_in_internal_group = GroupFactory.create(
internal_only=True, internal_name='5-internal-not-in')

user = UserFactory.create(
email='[email protected]',
username='foo_',
mc_username='meep',
gh_username='meeep',
irc_nick='XxXmeepXxX')
user.groups.set([excluded_group, in_group, in_internal_group])

settings.SSO_ENDPOINTS = TEST_SSO_ENDPOINTS
send_update_ping(user, send_post=fake_send_post,
exclude_groups=[excluded_group.id])

fake_send_post.assert_called_once_with(
'http://discourse.example.com/admin/users/sync_sso',
data={
'sso': 'payload',
'sig': 'signature',
'api_key': 'discourse-api-key',
'api_username': 'system'})
fake_discourse_signer.sign.assert_called_once_with({
'nonce': str(user.id),
'email': '[email protected]',
'require_activation': 'false',
'external_id': user.id,
'username': 'foo_',
'name': 'foo_',
'custom.user_field_1': 'meep',
'custom.user_field_2': 'XxXmeepXxX',
'custom.user_field_3': 'meeep',
'moderator': False,
'admin': False,
'add_groups': '2-in',
'remove_groups': '1-excluded,3-not-in',
})

send_update_ping(user, send_post=fake_send_post)
fake_discourse_signer.sign.assert_called_with({
'nonce': str(user.id),
'email': '[email protected]',
'require_activation': 'false',
'external_id': user.id,
'username': 'foo_',
'name': 'foo_',
'custom.user_field_1': 'meep',
'custom.user_field_2': 'XxXmeepXxX',
'custom.user_field_3': 'meeep',
'moderator': False,
'admin': False,
'add_groups': '1-excluded,2-in',
'remove_groups': '3-not-in',
})
16 changes: 11 additions & 5 deletions spongeauth/sso/utils.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from django.conf import settings
from django.db.models import Q

import requests

Expand All @@ -10,16 +11,18 @@ def _cast_bool(b):
return str(bool(b)).lower()


def make_payload(user, nonce, request=None, group=None):
def make_payload(user, nonce, request=None, group=None, exclude_groups=None):
group = group or Group
exclude_groups = set(exclude_groups or [])
avatar_url = user.avatar.get_absolute_url()
if request is not None:
avatar_url = request.build_absolute_uri(avatar_url)
relevant_groups = group.objects.filter(internal_only=False).order_by(
'internal_name')
add_groups = relevant_groups.filter(user=user).values_list(
filter_q = Q(user=user) & ~Q(pk__in=exclude_groups)
add_groups = relevant_groups.filter(filter_q).values_list(
'internal_name', flat=True)
remove_groups = relevant_groups.exclude(user=user).values_list(
remove_groups = relevant_groups.exclude(filter_q).values_list(
'internal_name', flat=True)
payload = {
'nonce': nonce,
Expand All @@ -39,10 +42,13 @@ def make_payload(user, nonce, request=None, group=None):
return payload


def send_update_ping(user, send_post=None, group=None):
def send_update_ping(user, send_post=None, group=None, exclude_groups=None):
send_post = send_post or requests.post
exclude_groups = exclude_groups or []

payload = make_payload(user, str(user.pk), group=group)
payload = make_payload(
user, str(user.pk), group=group,
exclude_groups=exclude_groups)

resps = []
for endpoint_settings in settings.SSO_ENDPOINTS.values():
Expand Down

0 comments on commit 8bcd444

Please sign in to comment.