Skip to content

Commit

Permalink
User.enable_protocol: send welcome DM
Browse files Browse the repository at this point in the history
via new Protocol.bot_dm method

for #1024, #966, etc
  • Loading branch information
snarfed committed Aug 9, 2024
1 parent 37c781a commit 23aa24e
Show file tree
Hide file tree
Showing 4 changed files with 93 additions and 14 deletions.
17 changes: 14 additions & 3 deletions models.py
Original file line number Diff line number Diff line change
Expand Up @@ -479,24 +479,35 @@ def is_enabled(self, to_proto, explicit=False):

return False

@ndb.transactional()
def enable_protocol(self, to_proto):
"""Adds ``to_proto` to :attr:`enabled_protocols`.
Also sends a welcome DM to the user (via a send task) if their protocol
supports DMs.
Args:
to_proto (:class:`protocol.Protocol` subclass)
"""
added = False

@ndb.transactional()
def enable():
user = self.key.get()
add(user.enabled_protocols, to_proto.LABEL)
if to_proto.LABEL not in user.enabled_protocols:
user.enabled_protocols.append(to_proto.LABEL)
user.put()
nonlocal added
added = True

if to_proto.LABEL in ids.COPIES_PROTOCOLS and not user.get_copy(to_proto):
to_proto.create_for(user)
user.put()

enable()
add(self.enabled_protocols, to_proto.LABEL)

if added:
to_proto.bot_dm(to_user=self, text='hello world')

msg = f'Enabled {to_proto.LABEL} for {self.key.id()} : {self.user_page_path()}'
logger.info(msg)

Expand Down
32 changes: 31 additions & 1 deletion protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -1121,6 +1121,36 @@ def bot_follow(bot_cls, user):
url=target, protocol=user.LABEL,
user=bot.key.urlsafe())

@classmethod
def bot_dm(bot_cls, to_user, text):
"""Sends a DM from this protocol's bot user.
Creates a task to send the DM asynchronously.
Args:
to_user (models.User)
text (str)
Returns
bool: True if the DM was successfully sent, False otherwise
"""
from web import Web
bot = Web.get_by_id(bot_cls.bot_user_id())
logger.info(f'Sending DM from {bot.key.id()} to {to_user.key.id()}: {text[:100]}')

id = f'{bot.profile_id()}#welcome-dm-{to_user.key.id()}-{util.now().isoformat()}'
target = Target(protocol=to_user.LABEL, uri=to_user.target_for(to_user.obj))
obj_key = Object(id=id, source_protocol='web', undelivered=[target], our_as1={
'objectType': 'note',
'id': id,
'actor': bot.key.id(),
'content': text,
'to': [to_user.key.id()],
}).put()

common.create_task(queue='send', obj=obj_key.urlsafe(), protocol=to_user.LABEL,
url=target.uri, user=bot.key.urlsafe())

@classmethod
def delete_user_copy(copy_cls, user):
"""Deletes a user's copy actor in a given protocol.
Expand Down Expand Up @@ -1587,7 +1617,7 @@ def check_supported(cls, obj):
and inner_type not in cls.SUPPORTED_AS1_TYPES)):
error(f"Bridgy Fed for {cls.LABEL} doesn't support {obj.type} {inner_type} yet", status=204)

if as1.is_dm(obj.as1):
if as1.is_dm(obj.as1) and as1.get_owner(obj.as1) not in PROTOCOL_DOMAINS:
error(f"Bridgy Fed doesn't support DMs", status=204)


Expand Down
35 changes: 34 additions & 1 deletion tests/test_integrations.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import app
from atproto import ATProto, Cursor
from atproto_firehose import handle, new_commits, Op
import common
from models import Follower, Object, Target
from web import Web

Expand Down Expand Up @@ -454,12 +455,23 @@ def test_activitypub_follow_bsky_bot_bad_username_error(self, mock_get):
self.assertEqual(0, len(user.copies))


@patch('requests.post', return_value=requests_response({ # sendMessage
'id': 'chat456',
'rev': '22222222tef2d',
# ...
}))
@patch('requests.get', side_effect=[
requests_response(DID_DOC), # alice DID
requests_response(PROFILE_GETRECORD), # alice profile
requests_response(PROFILE_GETRECORD), # ...
requests_response({ # getConvoForMembers
'convo': {
'id': 'convo123',
'rev': '22222222fuozt',
},
}),
])
def test_atproto_follow_ap_bot_user_enables_protocol(self, mock_get):
def test_atproto_follow_ap_bot_user_enables_protocol(self, mock_get, mock_post):
"""ATProto follow of @ap.brid.gy enables the ActivityPub protocol.
ATProto user alice.com, did:plc:alice
Expand All @@ -479,6 +491,27 @@ def test_atproto_follow_ap_bot_user_enables_protocol(self, mock_get):
user = ATProto.get_by_id('did:plc:alice')
self.assertTrue(user.is_enabled(ActivityPub))

headers = {
'Content-Type': 'application/json',
'User-Agent': common.USER_AGENT,
'Authorization': ANY,
}
mock_get.assert_any_call(
'https://chat.service.local/xrpc/chat.bsky.convo.getConvoForMembers?members=did%3Aplc%3Aalice',
json=None, data=None, headers=headers)
mock_post.assert_called_with(
'https://chat.service.local/xrpc/chat.bsky.convo.sendMessage',
json={
'convoId': 'convo123',
'message': {
'$type': 'chat.bsky.convo.defs#messageInput',
'text': 'hello world',
'createdAt': '2022-01-02T03:04:05.000Z',
'bridgyOriginalText': 'hello world',
'bridgyOriginalUrl': 'https://ap.brid.gy/#welcome-dm-did:plc:alice-2022-01-02T03:04:05+00:00',
},
}, data=None, headers=headers)


@patch('requests.post')
@patch('requests.get')
Expand Down
23 changes: 14 additions & 9 deletions tests/test_protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -725,6 +725,13 @@ def test_check_supported(self):
{'objectType': 'activity', 'verb': 'undo',
'object': {'objectType': 'activity', 'verb': 'share'}},
{'objectType': 'activity', 'verb': 'flag'},
# DM from protocol bot
{
'objectType': 'note',
'actor': 'ap.brid.gy',
'to': ['did:bob'],
'content': 'hello world',
},
):
with self.subTest(obj=obj):
Fake.check_supported(Object(our_as1=obj))
Expand All @@ -733,19 +740,17 @@ def test_check_supported(self):
{'objectType': 'event'},
{'objectType': 'activity', 'verb': 'post',
'object': {'objectType': 'event'}},
):
with self.subTest(obj=obj):
with self.assertRaises(NoContent):
Fake.check_supported(Object(our_as1=obj))

# DM
with self.assertRaises(NoContent):
Fake.check_supported(Object(our_as1={
# DM
{
'objectType': 'note',
'actor': 'did:alice',
'to': ['did:bob'],
'content': 'hello world',
}))
},
):
with self.subTest(obj=obj):
with self.assertRaises(NoContent):
Fake.check_supported(Object(our_as1=obj))

class ProtocolReceiveTest(TestCase):

Expand Down

0 comments on commit 23aa24e

Please sign in to comment.