Merge pull request #325 from StampyAI/stampy_chat_module
Stampy chat module
Showing 10 changed files with 292 additions and 75 deletions.
@@ -0,0 +1,172 @@
"""
Queries chat.stampy.ai with the user's question.
"""

import json
import re
from collections import deque, defaultdict
from typing import Iterable, List, Dict, Any
from uuid import uuid4

import requests
from structlog import get_logger

from modules.module import Module, Response
from servicemodules.serviceConstants import italicise
from utilities.serviceutils import ServiceChannel, ServiceMessage
from utilities.utilities import Utilities

log = get_logger()
utils = Utilities.get_instance()


LOG_MAX_MESSAGES = 15  # don't store more than this many messages back per channel
DATA_HEADER = 'data: '  # prefix of each server-sent-events payload line

STAMPY_CHAT_ENDPOINT = "https://chat.stampy.ai:8443/chat"
NLP_SEARCH_ENDPOINT = "https://nlp.stampy.ai"

# Above the first threshold a "Live on site" answer is linked directly; between
# the two thresholds the question is forwarded to the chat bot instead.
STAMPY_ANSWER_MIN_SCORE = 0.75
STAMPY_CHAT_MIN_SCORE = 0.4

def stream_lines(stream: Iterable):
    """Reassemble an iterable of byte chunks into complete text lines."""
    line = ''
    for item in stream:
        item = item.decode('utf8')
        line += item
        if '\n' in line:
            lines = line.split('\n')
            line = lines[-1]  # keep the trailing partial line as the new buffer
            for l in lines[:-1]:
                yield l
    yield line  # flush whatever remains when the stream ends

def parse_data_items(stream: Iterable):
    """Parse server-sent-events lines, yielding each `data: {...}` payload as JSON."""
    for item in stream:
        if item.strip().startswith(DATA_HEADER):
            yield json.loads(item.split(DATA_HEADER)[1])

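# Illustrative only: the chat endpoint streams SSE lines, which these two helpers
# turn into dicts. Judging by how get_chat_response() consumes them, the payloads
# look roughly like (the exact shape is an assumption):
#
#     data: {"state": "citations", "citations": [...]}
#     data: {"state": "streaming", "content": "partial text"}
#     data: {"state": "followups", "followups": [...]}
#
#     for item in parse_data_items(stream_lines(response)):
#         print(item.get("state"))
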
def top_nlp_search(query: str) -> Dict[str, Any]:
    """Return the best-scoring semantic search hit for `query`, or {} if none."""
    resp = requests.get(NLP_SEARCH_ENDPOINT + '/api/search', params={'query': query, 'status': 'all'})
    if not resp:  # requests.Response is falsy for 4xx/5xx statuses
        return {}

    items = resp.json()
    if not items:
        return {}
    return items[0]

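# Illustrative only: process_message() below relies on just a handful of fields
# from the top hit, assumed to look something like
#
#     {"score": 0.82, "status": "Live on site",
#      "title": "What is AI alignment?", "url": "https://..."}
#
# any other fields are ignored.
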
def chunk_text(text: str, chunk_limit=2000, delimiter='.'):
    """Split `text` on sentence boundaries into chunks of at most ~chunk_limit chars."""
    chunk = ''
    for sentence in text.split(delimiter):
        if len(chunk + sentence) + 1 >= chunk_limit and chunk and sentence:
            yield chunk
            chunk = sentence + delimiter
        elif sentence:
            chunk += sentence + delimiter
    if chunk:  # don't emit a trailing empty chunk
        yield chunk

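# Illustrative only: a quick sketch of the chunking, with a tiny limit for
# demonstration (the 2000-character default presumably matches Discord's
# message length cap):
#
#     list(chunk_text('one. two. three.', chunk_limit=12))
#     # -> ['one. two.', ' three.']
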
def filter_citations(text, citations):
    """Keep only the citations whose reference letters actually appear in `text`."""
    # Markers look like "[a]" or "[a, c]"; collect every reference letter used in
    # any marker.
    used_citations = set()
    for marker in re.findall(r'\[[a-z](?:,? ?[a-z])*\]', text):
        used_citations.update(re.findall(r'[a-z]', marker))
    return [c for c in citations if c.get('reference') in used_citations]

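# Illustrative only, assuming references are single letters as in the regex above:
#
#     filter_citations('Hard problem [a, c].',
#                      [{'reference': 'a'}, {'reference': 'b'}, {'reference': 'c'}])
#     # -> [{'reference': 'a'}, {'reference': 'c'}]
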
class StampyChat(Module):

    def __init__(self):
        self.utils = Utilities.get_instance()
        # Per-channel ring buffer of the most recent messages.
        self._messages: dict[ServiceChannel, deque[ServiceMessage]] = defaultdict(lambda: deque(maxlen=LOG_MAX_MESSAGES))
        self.session_id = str(uuid4())
        super().__init__()

    @property
    def class_name(self):
        return 'stampy_chat'

    def format_message(self, message: ServiceMessage):
        return {
            'content': message.content,
            'role': 'assistant' if self.utils.stampy_is_author(message) else 'user',
        }

    def stream_chat_response(self, query: str, history: List[ServiceMessage]):
        response = requests.post(STAMPY_CHAT_ENDPOINT, stream=True, json={
            'query': query,
            'history': [self.format_message(m) for m in history],
            'sessionId': self.session_id,
            'settings': {'mode': 'discord'},
        })
        return parse_data_items(stream_lines(response))

    def get_chat_response(self, query: str, history: List[ServiceMessage]):
        response = {'citations': [], 'content': '', 'followups': []}
        for item in self.stream_chat_response(query, history):
            if item.get('state') == 'citations':
                response['citations'] += item.get('citations', [])
            elif item.get('state') == 'streaming':
                response['content'] += item.get('content', '')
            elif item.get('state') == 'followups':
                response['followups'] += item.get('followups', [])
        # Drop any citations that never got referenced in the generated text.
        response['citations'] = filter_citations(response['content'], response['citations'])
        return response

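    # Illustrative only: once the stream is drained, the accumulated response has
    # the shape consumed by `query` below, roughly (fields beyond the ones used
    # there are an assumption):
    #
    #     {'content': 'Alignment is ... [a]',
    #      'citations': [{'reference': 'a', 'title': '...', 'url': 'https://...'}],
    #      'followups': [{'text': '...', 'pageid': '...'}]}
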
    async def query(self, query: str, history: List[ServiceMessage], message: ServiceMessage):
        log.info('calling %s', query)
        chat_response = self.get_chat_response(query, history)
        content_chunks = list(chunk_text(chat_response['content']))
        citations = [f'[{c["reference"]}] - {c["title"]} ({c["url"]})' for c in chat_response['citations'] if c.get('reference')]
        if citations:
            citations = ['Citations: \n' + '\n'.join(citations)]
        followups = []
        if follows := chat_response['followups']:
            followups = [
                'Check out these articles for more info: \n' + '\n'.join(
                    f'{f["text"]} - https://aisafety.info?state={f["pageid"]}' for f in follows
                )
            ]

        log.info('response: %s', content_chunks + citations + followups)
        return Response(
            confidence=10,
            text=[italicise(text, message) for text in content_chunks + citations + followups],
            why='This is what the chat bot returned'
        )

    def _add_message(self, message: ServiceMessage) -> deque[ServiceMessage]:
        self._messages[message.channel].append(message)
        return self._messages[message.channel]

    def make_query(self, messages):
        """Merge the trailing run of messages by the same author into a single query."""
        if not messages:
            return '', messages

        current = messages[-1]
        query, history = '', list(messages)
        while message := (history and history.pop()):
            if message.author != current.author:
                history.append(message)  # keep the first other-author message in the history
                break
            query = message.content + ' ' + query
            current = message
        return query, history

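    # Illustrative only (authors and contents are made up): consecutive messages
    # from the asker are folded into one question, earlier traffic stays as history:
    #
    #     messages = [<user1: 'Hi'>, <user2: 'What is'>, <user2: 'AI safety?'>]
    #     make_query(messages)
    #     # -> ('What is AI safety? ', [<user1: 'Hi'>])
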
    def process_message(self, message: ServiceMessage) -> Response:
        history = self._add_message(message)

        if not self.is_at_me(message):
            return Response()

        query, history = self.make_query(history)
        nlp = top_nlp_search(query)
        # A high-scoring answer that's already live on the site gets linked
        # directly; a weaker match is handed off to the chat bot.
        if nlp.get('score', 0) > STAMPY_ANSWER_MIN_SCORE and nlp.get('status') == 'Live on site':
            return Response(confidence=5, text=f'Check out {nlp.get("url")} ({nlp.get("title")})')
        if nlp.get('score', 0) > STAMPY_CHAT_MIN_SCORE:
            return Response(confidence=6, callback=self.query, args=[query, history, message])
        return Response()

    def process_message_from_stampy(self, message: ServiceMessage):
        # Record Stampy's own replies so they show up in the chat history too.
        self._add_message(message)