Skip to content

Commit

Permalink
Merge pull request #325 from StampyAI/stampy_chat_module
Browse files Browse the repository at this point in the history
Stampy chat module
  • Loading branch information
mruwnik authored Nov 12, 2023
2 parents b4b9e61 + a26d2fa commit 84cfd1c
Show file tree
Hide file tree
Showing 10 changed files with 292 additions and 75 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -61,14 +61,14 @@ You'll need at least these:
- `DISCORD_GUILD`: your server ID
- `DATABASE_PATH`: the path to the Q&A database (normally in `./database/stampy.db`).
- `STAMPY_MODULES`: list of your desired modules, or leave unset to load all modules in the `./modules/` directory. You probably don't want all, as some of them aren't applicable to servers other than Rob's.
- `BOT_PRIVATE_CHANNEL_ID`: single channel where private Stampy status updates and info are sent

Not required:

- `BOT_VIP_IDS`: list of user IDs. VIPs have full access and some special permissions.
- `BOT_DEV_ROLES`: list of roles representing bot devs.
- `BOT_DEV_IDS`: list of user ids of bot devs. You may want to include `BOT_VIP_IDS` here.
- `BOT_CONTROL_CHANNEL_IDS`: list of channels where control commands are accepted.
- `BOT_PRIVATE_CHANNEL_ID`: single channel where private Stampy status updates are sent
- `BOT_ERROR_CHANNEL_ID`: (defaults to private channel) low-level error tracebacks from Python. With this variable set, they can be shunted to a separate channel.
- `CODA_API_TOKEN`: token to access Coda. Without it, modules `Questions` and `QuestionSetter` will not be available and `StampyControls` will have limited functionality.
- `BOT_REBOOT`: how Stampy reboots himself. Unset, he only quits, expecting an external `while true` loop (like in `runstampy`/Dockerfile). Set to `exec` he will try to relaunch himself from his own CLI arguments.
Expand Down
34 changes: 19 additions & 15 deletions config.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import os
from typing import Literal, TypeVar, Optional, Union, cast, get_args, overload, Any, Tuple
from pathlib import Path

import dotenv
from structlog import get_logger
Expand All @@ -10,16 +11,15 @@
dotenv.load_dotenv()
NOT_PROVIDED = "__NOT_PROVIDED__"

module_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), "modules")
module_dir = Path(__file__).parent / 'modules'


def get_all_modules() -> frozenset[str]:
modules = set()
for file_name in os.listdir(module_dir):
if file_name.endswith(".py") and file_name not in ("__init__.py", "module.py"):
modules.add(file_name[:-3])

return frozenset(modules)
return frozenset({
filename.stem
for filename in module_dir.glob('*.py')
if filename.suffix == '.py' and filename.name not in ('__init__.py', 'module.py')
})


ALL_STAMPY_MODULES = get_all_modules()
Expand Down Expand Up @@ -47,8 +47,7 @@ def getenv(env_var: str, default = NOT_PROVIDED) -> str:


def getenv_bool(env_var: str) -> bool:
e = getenv(env_var, default="UNDEFINED")
return e != "UNDEFINED"
return getenv(env_var, default="UNDEFINED") != "UNDEFINED"


# fmt:off
Expand All @@ -64,12 +63,12 @@ def getenv_unique_set(var_name: str, default: T) -> Union[frozenset[str], T]:...


def getenv_unique_set(var_name: str, default: T = frozenset()) -> Union[frozenset, T]:
l = getenv(var_name, default="EMPTY_SET").split(" ")
if l == ["EMPTY_SET"]:
var = getenv(var_name, default='')
if not var.strip():
return default
s = frozenset(l)
assert len(l) == len(s), f"{var_name} has duplicate members! {l}"
return s
items = var.split()
assert len(items) == len(set(items)), f"{var_name} has duplicate members! {sorted(items)}"
return frozenset(items)


maximum_recursion_depth = 30
Expand Down Expand Up @@ -150,6 +149,11 @@ def getenv_unique_set(var_name: str, default: T = frozenset()) -> Union[frozense
channel_whitelist: Optional[frozenset[str]]
disable_prompt_moderation: bool

## Flask settings
if flask_port := getenv('FLASK_PORT', '2300'):
flask_port = int(flask_port)
flask_address = getenv('FLASK_ADDRESS', "0.0.0.0")

is_rob_server = getenv_bool("IS_ROB_SERVER")
if is_rob_server:
# use robmiles server defaults
Expand Down Expand Up @@ -233,7 +237,7 @@ def getenv_unique_set(var_name: str, default: T = frozenset()) -> Union[frozense
bot_dev_roles = getenv_unique_set("BOT_DEV_ROLES", frozenset())
bot_dev_ids = getenv_unique_set("BOT_DEV_IDS", frozenset())
bot_control_channel_ids = getenv_unique_set("BOT_CONTROL_CHANNEL_IDS", frozenset())
bot_private_channel_id = getenv("BOT_PRIVATE_CHANNEL_ID", '')
bot_private_channel_id = getenv("BOT_PRIVATE_CHANNEL_ID")
bot_error_channel_id = getenv("BOT_ERROR_CHANNEL_ID", bot_private_channel_id)
# NOTE: Rob's invite/member management functions, not ported yet
member_role_id = getenv("MEMBER_ROLE_ID", default=None)
Expand Down
7 changes: 3 additions & 4 deletions modules/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -243,6 +243,8 @@ def is_at_me(self, message: ServiceMessage) -> Union[str, Literal[False]]:
r",? @?[sS](tampy)?(?P<punctuation>[.!?]*)$", r"\g<punctuation>", text
)
at_me = True
elif re.search(r'^[sS]tamp[ys]?\?', text):
at_me = True

if message.is_dm:
# DMs are always at you
Expand All @@ -255,10 +257,7 @@ def is_at_me(self, message: ServiceMessage) -> Union[str, Literal[False]]:
)
at_me = True

if at_me:
return text
else:
return False
return at_me and text

def get_guild_and_invite_role(self):
return get_guild_and_invite_role()
Expand Down
172 changes: 172 additions & 0 deletions modules/stampy_chat.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,172 @@
"""
Queries chat.stampy.ai with the user's question.
"""

import json
import re
from collections import deque, defaultdict
from typing import Iterable, List, Dict, Any
from uuid import uuid4

import requests
from structlog import get_logger

from modules.module import Module, Response
from servicemodules.serviceConstants import italicise
from utilities.serviceutils import ServiceChannel, ServiceMessage
from utilities.utilities import Utilities

log = get_logger()
utils = Utilities.get_instance()


LOG_MAX_MESSAGES = 15 # don't store more than X messages back
DATA_HEADER = 'data: '

STAMPY_CHAT_ENDPOINT = "https://chat.stampy.ai:8443/chat"
NLP_SEARCH_ENDPOINT = "https://nlp.stampy.ai"

STAMPY_ANSWER_MIN_SCORE = 0.75
STAMPY_CHAT_MIN_SCORE = 0.4


def stream_lines(stream: Iterable):
    """Re-chunk a raw byte stream into decoded text lines.

    Buffers incoming byte chunks and splits on newlines, yielding each
    complete line as a UTF-8 string. The final remainder after the last
    newline (possibly the empty string) is also yielded.

    Splitting happens at the byte level *before* decoding so that a
    multi-byte UTF-8 character straddling two chunks cannot raise a
    UnicodeDecodeError (decoding chunk-by-chunk, as before, would).
    """
    buffer = b''
    for chunk in stream:
        buffer += chunk
        while b'\n' in buffer:
            line, buffer = buffer.split(b'\n', 1)
            yield line.decode('utf8')
    yield buffer.decode('utf8')


def parse_data_items(stream: Iterable):
    """Extract and deserialize SSE-style payloads from a stream of lines.

    Lines that (after stripping) start with the ``data: `` header have the
    remainder parsed as JSON and yielded; all other lines are ignored.
    """
    for item in stream:
        item = item.strip()
        if item.startswith(DATA_HEADER):
            # maxsplit=1 so a payload that itself contains the literal
            # "data: " substring is not truncated at that point
            yield json.loads(item.split(DATA_HEADER, 1)[1])


def top_nlp_search(query: str) -> Dict[str, Any]:
    """Return the top-scoring semantic-search hit for `query`, or {} if none.

    Queries the aisafety.info NLP search endpoint. Network failures,
    non-2xx responses and empty result sets all degrade to an empty dict
    so callers can treat "no match" and "search unavailable" uniformly.
    """
    try:
        resp = requests.get(
            NLP_SEARCH_ENDPOINT + '/api/search',
            params={'query': query, 'status': 'all'},
            timeout=10,  # don't let a stalled endpoint hang the whole bot
        )
    except requests.RequestException:
        log.warning('NLP search request failed')
        return {}
    if not resp:  # requests.Response is falsy for 4xx/5xx status codes
        return {}

    items = resp.json()
    if not items:
        return {}
    return items[0]


def chunk_text(text: str, chunk_limit=2000, delimiter='.'):
    """Lazily split `text` on `delimiter` into pieces shorter than `chunk_limit`.

    Sentences are re-joined (with the delimiter re-appended) into a running
    buffer; whenever adding the next sentence would reach the limit, the
    buffer is flushed. The final buffer is always yielded — even when it is
    empty, as happens for empty input.
    """
    buffer = ''
    for sentence in text.split(delimiter):
        if not sentence:
            continue
        would_overflow = len(buffer) + len(sentence) + 1 >= chunk_limit
        if would_overflow and buffer:
            yield buffer
            buffer = sentence + delimiter
        else:
            buffer += sentence + delimiter
    yield buffer


def filter_citations(text, citations):
    """Keep only the citations whose reference letter actually appears in `text`.

    Citation markers in the generated text look like ``[a]`` or ``[a, c]``.
    The previous implementation used `re.findall` with a repeated capture
    group, which returns only the *last* repetition per match — dropping all
    but one letter of multi-citation markers. Here every letter inside each
    bracketed marker is collected instead.
    """
    used_references = set()
    for marker in re.findall(r'\[([a-z](?:,? ?[a-z])*)\]', text):
        used_references.update(re.findall(r'[a-z]', marker))
    return [c for c in citations if c.get('reference') in used_references]


class StampyChat(Module):
    """Module that forwards questions to the chat.stampy.ai LLM service.

    Keeps a short rolling per-channel message log so queries can be sent
    with conversational context, first checks the NLP search endpoint for
    an existing answer, and otherwise streams the chat service's response
    back in Discord-sized chunks with citations and follow-up links.
    """

    def __init__(self):
        self.utils = Utilities.get_instance()
        # Rolling per-channel history; deque(maxlen=...) silently drops the
        # oldest entries once LOG_MAX_MESSAGES is reached.
        self._messages: dict[ServiceChannel, deque[ServiceMessage]] = defaultdict(lambda: deque(maxlen=LOG_MAX_MESSAGES))
        # One session id per module instance, sent with every chat request.
        self.session_id = str(uuid4())
        super().__init__()

    @property
    def class_name(self):
        return 'stampy_chat'

    def format_message(self, message: ServiceMessage):
        """Convert a service message to the chat API's {content, role} shape."""
        return {
            'content': message.content,
            'role': 'assistant' if self.utils.stampy_is_author(message) else 'user',
        }

    def stream_chat_response(self, query: str, history: List[ServiceMessage]):
        """POST the query + history to the chat endpoint and yield parsed
        JSON data items from the streamed (SSE-style) response."""
        return parse_data_items(stream_lines(requests.post(STAMPY_CHAT_ENDPOINT, stream=True, json={
            'query': query,
            'history': [self.format_message(m) for m in history],
            'sessionId': self.session_id,
            'settings': {'mode': 'discord'},
        })))

    def get_chat_response(self, query: str, history: List[ServiceMessage]):
        """Drain the response stream into a single dict.

        Accumulates the streamed items by their 'state' field into
        citations, content and followups, then drops citations that are
        never referenced in the final content text.
        """
        response = {'citations': [], 'content': '', 'followups': []}
        for item in self.stream_chat_response(query, history):
            if item.get('state') == 'citations':
                response['citations'] += item.get('citations', [])
            elif item.get('state') == 'streaming':
                response['content'] += item.get('content', '')
            elif item.get('state') == 'followups':
                response['followups'] += item.get('followups', [])
        response['citations'] = filter_citations(response['content'], response['citations'])
        return response

    async def query(self, query: str, history: List[ServiceMessage], message: ServiceMessage):
        """Callback: fetch the chat service's answer and format it for the service.

        Returns a high-confidence Response whose text is a list of
        italicised chunks: the answer split to message-size pieces, then an
        optional citations block, then an optional follow-ups block.
        """
        log.info('calling %s', query)
        chat_response = self.get_chat_response(query, history)
        content_chunks = list(chunk_text(chat_response['content']))
        # Only citations that carry a reference letter are shown.
        citations = [f'[{c["reference"]}] - {c["title"]} ({c["url"]})' for c in chat_response['citations'] if c.get('reference')]
        if citations:
            citations = ['Citations: \n' + '\n'.join(citations)]
        followups = []
        if follows := chat_response['followups']:
            followups = [
                'Checkout these articles for more info: \n' + '\n'.join(
                    f'{f["text"]} - https://aisafety.info?state={f["pageid"]}' for f in follows
                )
            ]

        log.info('response: %s', content_chunks + citations + followups)
        return Response(
            confidence=10,
            text=[italicise(text, message) for text in content_chunks + citations + followups],
            why='This is what the chat bot returned'
        )

    def _add_message(self, message: ServiceMessage) -> deque[ServiceMessage]:
        """Append `message` to its channel's rolling log and return that log."""
        self._messages[message.channel].append(message)
        return self._messages[message.channel]

    def make_query(self, messages):
        """Build (query, history) from a channel log.

        Walks backwards from the newest message, concatenating consecutive
        messages by the same author into one query string; everything before
        that run is returned as history.

        NOTE(review): the message that terminates the run (the first one by
        a different author) is popped off `history` and not restored, so it
        is missing from the history handed to the chat API — confirm this
        is intended.
        """
        if not messages:
            return '', messages

        current = messages[-1]
        query, history = '', list(messages)
        # `history and history.pop()` evaluates to [] (falsy) when empty,
        # which also ends the loop.
        while message := (history and history.pop()):
            if message.author != current.author:
                break
            query = message.content + ' ' + query
            current = message
        return query, history

    def process_message(self, message: ServiceMessage) -> Response:
        """Log every message; answer only those addressed to Stampy.

        Messages are logged before the is_at_me check so that context
        accumulates even from messages that are not addressed to the bot.
        A strong NLP-search hit that is live on site short-circuits to a
        link; a weaker hit defers to the chat service via callback.
        """
        history = self._add_message(message)

        if not self.is_at_me(message):
            return Response()

        query, history = self.make_query(history)
        nlp = top_nlp_search(query)
        if nlp.get('score', 0) > STAMPY_ANSWER_MIN_SCORE and nlp.get('status') == 'Live on site':
            return Response(confidence=5, text=f'Check out {nlp.get("url")} ({nlp.get("title")})')
        if nlp.get('score', 0) > STAMPY_CHAT_MIN_SCORE:
            return Response(confidence=6, callback=self.query, args=[query, history, message])
        return Response()

    def process_message_from_stampy(self, message: ServiceMessage):
        """Record Stampy's own replies so they appear in future context."""
        self._add_message(message)
4 changes: 3 additions & 1 deletion servicemodules/discord.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,13 +47,15 @@


# TODO: store long responses temporarily for viewing outside of discord
def limit_text_and_notify(response: Response, why_traceback: list[str]) -> str:
def limit_text_and_notify(response: Response, why_traceback: list[str]) -> Union[str, Iterable]:
    """Trim a response's text to Discord's message-length limit.

    String payloads are trimmed via `limit_text`, appending a note to
    `why_traceback` when truncation occurred. List/tuple payloads (already
    chunked by the producing module) are passed through untouched; any
    other payload type becomes the empty string.
    """
    if isinstance(response.text, str):
        # (removed a dead `wastrimmed = False` that was immediately
        # overwritten by the tuple unpacking below)
        wastrimmed, text_to_return = limit_text(response.text, discordLimit)
        if wastrimmed:
            why_traceback.append(f"I had to trim the output from {response.module}")
        return text_to_return
    elif isinstance(response.text, (list, tuple)):
        return response.text
    return ""


Expand Down
Loading

0 comments on commit 84cfd1c

Please sign in to comment.