Adding command line arguments for demo/chat_web_selects.py #56

Merged: 5 commits, Nov 24, 2023
demo/chat_pdf_streamlit_ui.py (4 changes: 2 additions & 2 deletions)
@@ -36,7 +36,7 @@
 
 from ogbujipt.llm_wrapper import openai_chat_api, prompt_to_chat
 from ogbujipt.text_helper import text_splitter
-from ogbujipt.embedding_helper import qdrant_collection
+from ogbujipt.embedding.qdrant import collection
 
 from sentence_transformers import SentenceTransformer
@@ -104,7 +104,7 @@ def prep_pdf():
 
     # Vectorizes chunks for sLLM lookup
     # XXX: Look up rules around uploaded object names
-    kb = qdrant_collection(pdf.name, emb_model)  # in-memory vector DB instance
+    kb = collection(pdf.name, emb_model)  # in-memory vector DB instance
 
     # Show throbber, embed the PDF, and get ready for similarity search
     embedding_placeholder = st.container()
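For downstream code that must run both before and after this rename, a small import shim is one option. This is an illustrative sketch, not part of the PR; it assumes only the two module paths shown in the diff above.

```python
# Hypothetical compatibility shim; not part of this PR.
try:
    from ogbujipt.embedding.qdrant import collection  # new layout (this PR onward)
except ImportError:
    # Older OgbujiPT layout, where the class lived in embedding_helper
    from ogbujipt.embedding_helper import qdrant_collection as collection
```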
demo/chat_web_selects.py (64 changes: 40 additions & 24 deletions; mode 100644 → 100755)
@@ -8,7 +8,7 @@
 Works especially well with airoboros self-hosted LLM.
 
 Vector store: Qdrant - https://qdrant.tech/
-    Alternatives: pgvector, Chroma, Faiss, Weaviate, etc.
+    Alternatives: pgvector & Chroma (built-in support from OgbujiPT), Faiss, Weaviate, etc.
 Text to vector (embedding) model:
     Alternatives: https://www.sbert.net/docs/pretrained_models.html / OpenAI ada002
 
@@ -26,6 +26,13 @@
 ```
 
 An example question might be "Who are the neighbors of the Igbo people?"
+
+You can tweak it with the following command line options:
+--verbose - print more information while processing (for debugging)
+--limit (max number of chunks to retrieve for use as context)
+--chunk-size (characters per chunk, while prepping to create embeddings)
+--chunk-overlap (character overlap between chunks, while prepping to create embeddings)
+--question (The user question; if None (the default), prompt the user interactively)
 '''
 # en.wikipedia.org/wiki/Igbo_people|ahiajoku.igbonet.com/2000/|en.wikivoyage.org/wiki/Igbo_phrasebook"
 import asyncio
@@ -37,7 +44,7 @@

 from ogbujipt.llm_wrapper import openai_chat_api, prompt_to_chat
 from ogbujipt.text_helper import text_splitter
-from ogbujipt.embedding_helper import qdrant_collection
+from ogbujipt.embedding.qdrant import collection
 
 
 # Avoid re-entrace complaints from huggingface/tokenizers
@@ -49,9 +56,6 @@
 COLLECTION_NAME = 'chat-web-selects'
 USER_PROMPT = 'What do you want to know from the site(s)?\n'
 
-# Hard-code for demo
-EMBED_CHUNK_SIZE = 200
-EMBED_CHUNK_OVERLAP = 20
 DOTS_SPACING = 0.2  # Number of seconds between each dot printed to console

@@ -61,7 +65,7 @@ async def indicate_progress(pause=DOTS_SPACING):
     await asyncio.sleep(pause)
 
 
-async def read_site(url, collection):
+async def read_site(url, coll, chunk_size, chunk_overlap):
     # Crude check; good enough for demo
     if not url.startswith('http'): url = 'https://' + url  # noqa E701
     print('Downloading & processing', url)
@@ -72,28 +76,27 @@ async def read_site(url, collection):
     text = html2text.html2text(html)
 
     # Split text into chunks
-    chunks = text_splitter(text, chunk_size=EMBED_CHUNK_SIZE,
-                           chunk_overlap=EMBED_CHUNK_OVERLAP, separator='\n')
+    chunks = text_splitter(text, chunk_size=chunk_size, chunk_overlap=chunk_overlap, separator='\n')
 
     # print('\n\n'.join([ch[:100] for ch in chunks]))
     # Crude—for demo. Set URL metadata for all chunks to doc URL
     metas = [{'url': url}]*len(chunks)
     # Add the text to the collection
-    collection.update(texts=chunks, metas=metas)
-    print(f'{collection.count()} chunks added to collection')
+    coll.update(texts=chunks, metas=metas)
+    print(f'{coll.count()} chunks added to collection')


-async def async_main(oapi, sites):
+async def async_main(oapi, sites, verbose, limit, chunk_size, chunk_overlap, question):
     # Automatic download of embedding model from HuggingFace
     # Seem to be reentrancy issues with HuggingFace; defer import
     from sentence_transformers import SentenceTransformer
     embedding_model = SentenceTransformer(DOC_EMBEDDINGS_LLM)
     # Sites fuel in-memory Qdrant vector DB instance
-    collection = qdrant_collection(COLLECTION_NAME, embedding_model)
+    coll = collection(COLLECTION_NAME, embedding_model)
 
     # Download & process sites in parallel, loading their content & metadata into a knowledgebase
     url_task_group = asyncio.gather(*[
-        asyncio.create_task(read_site(site, collection)) for site in sites.split('|')])
+        asyncio.create_task(read_site(site, coll, chunk_size, chunk_overlap)) for site in sites.split('|')])
     indicator_task = asyncio.create_task(indicate_progress())
     tasks = [indicator_task, url_task_group]
     done, _ = await asyncio.wait(
@@ -103,13 +106,16 @@ async def async_main(oapi, sites):
     done = False
     while not done:
         print()
-        user_question = input(USER_PROMPT)
+        if question:
+            user_question = question
+        else:
+            user_question = input(USER_PROMPT)
         if user_question.strip() == 'done':
             break
 
-        docs = collection.search(user_question, limit=4)
-
-        print(docs)
+        docs = coll.search(user_question, limit=limit)
+        if verbose:
+            print(docs)
         if docs:
             # Collects "chunked_doc" into "gathered_chunks"
             gathered_chunks = '\n\n'.join(
@@ -123,8 +129,8 @@
                 If you cannot answer with the given context, just say so.\n\n'''
             sys_prompt += gathered_chunks + '\n\n'
             messages = prompt_to_chat(user_question, system=sys_prompt)
-
-            print('-'*80, '\n', messages, '\n', '-'*80)
+            if verbose:
+                print('-'*80, '\n', messages, '\n', '-'*80)
 
             # The rest is much like in demo/alpaca_multitask_fix_xml.py
             model_params = dict(
@@ -143,9 +149,11 @@

             # Instance of openai.openai_object.OpenAIObject, with lots of useful info
             retval = next(iter(done)).result()
-            print(type(retval))
+            if verbose:
+                print(type(retval))
             # Response is a json-like object; extract the text
-            print('\nFull response data from LLM:\n', retval)
+            if verbose:
+                print('\nFull response data from LLM:\n', retval)
 
             # response is a json-like object;
             # just get back the text of the response
@@ -155,21 +163,29 @@

 # Command line arguments defined in click decorators
 @click.command()
-@click.option('--apibase', default='http://127.0.0.1:8000', help='OpenAI API base URL')
+@click.option('--verbose/--no-verbose', default=False)
+@click.option('--chunk-size', type=int, default=200,
+              help='Number of characters to include per chunk')
+@click.option('--chunk-overlap', type=int, default=20,
+              help='Number of characters to overlap at the edges of chunks')
+@click.option('--limit', default=4, type=int,
+              help='Maximum number of chunks matched against the posed question to use as context for the LLM')
 @click.option('--openai-key',
               help='OpenAI API key. Leave blank to specify self-hosted model via --host & --port')
+@click.option('--apibase', default='http://127.0.0.1:8000', help='OpenAI API base URL')
 @click.option('--model', default='', type=str,
               help='OpenAI model to use (see https://platform.openai.com/docs/models).'
               'Use only with --openai-key')
+@click.option('--question', default=None, help='The question to ask (or prompt for one)')
 @click.argument('sites')
-def main(apibase, openai_key, model, sites):
+def main(verbose, chunk_size, chunk_overlap, limit, openai_key, apibase, model, question, sites):
     # Use OpenAI API if specified, otherwise emulate with supplied URL info
     if openai_key:
         oapi = openai_chat_api(api_key=openai_key, model=(model or 'gpt-3.5-turbo'))
     else:
         oapi = openai_chat_api(model=model, base_url=apibase)
 
-    asyncio.run(async_main(oapi, sites))
+    asyncio.run(async_main(oapi, sites, verbose, limit, chunk_size, chunk_overlap, question))
 
 
 if __name__ == '__main__':
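With the options above wired through to async_main(), a typical run of the demo might look like this (values and URL are illustrative; the flags are exactly those defined in the click decorators):

```
python demo/chat_web_selects.py --verbose --chunk-size 300 --chunk-overlap 30 --limit 6 \
    --question 'Who are the neighbors of the Igbo people?' \
    'en.wikipedia.org/wiki/Igbo_people|en.wikivoyage.org/wiki/Igbo_phrasebook'
```

Leaving off --question drops into the interactive prompt as before, and without --verbose the debugging dumps that used to print unconditionally stay suppressed.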
pylib/embedding/qdrant.py (15 changes: 12 additions & 3 deletions)
@@ -24,7 +24,7 @@
 MEMORY_QDRANT_CONNECTION_PARAMS = {'location': ':memory:'}
 
 
-class qdrant_collection:
+class collection:
     def __init__(self, name, embedding_model, db=None,
                  distance_function=None, **conn_params):
         '''
@@ -48,12 +48,12 @@ def __init__(self, name, embedding_model, db=None,
         Example:
         >>> from ogbujipt.text_helper import text_splitter
-        >>> from ogbujipt.embedding_helper import qdrant_collection  # pip install qdrant_client
+        >>> from ogbujipt.embedding.qdrant import collection  # pip install qdrant_client
         >>> from sentence_transformers import SentenceTransformer  # pip install sentence_transformers
         >>> text = 'The quick brown fox\njumps over the lazy dog,\nthen hides under a log\nwith a frog.\n'
         >>> text += 'Should the hound wake up,\nall jumpers beware\nin a log, in a bog\nhe\'ll search everywhere.\n'
         >>> embedding_model = SentenceTransformer('all-MiniLM-L6-v2')
-        >>> collection = qdrant_collection('my-text', embedding_model)
+        >>> collection = collection('my-text', embedding_model)
         >>> chunks = text_splitter(text, chunk_size=20, chunk_overlap=4, separator='\n')
         >>> collection.update(texts=chunks, metas=[{'seq-index': i} for (i, _) in enumerate(chunks)])
         >>> retval = collection.search('what does the fox say?', limit=1)
@@ -187,3 +187,12 @@ def count(self):
         # This ugly declaration just gets the count as an integer
         current_count = int(str(self.db.count(self.name)).partition('=')[-1])
         return current_count
+
+
+# Already disambiguated by the module name. Anyone can use import as if that's not enough
+# Deprecating the old name
+class qdrant_collection(collection):
+    def __init__(self, name, embedding_model, db=None,
+                 distance_function=None, **conn_params):
+        warnings.warn('qdrant_collection is deprecated. Use collection instead.', DeprecationWarning)
+        super().__init__(name, embedding_model, db=db, distance_function=distance_function, **conn_params)
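A quick sketch of how the shim behaves from user code (assumes warnings is imported in qdrant.py, sentence_transformers is installed, and reuses the model name from the docstring example above):

```python
import warnings
from sentence_transformers import SentenceTransformer
from ogbujipt.embedding.qdrant import qdrant_collection  # old name, via the shim

emb_model = SentenceTransformer('all-MiniLM-L6-v2')
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter('always')
    kb = qdrant_collection('my-text', emb_model)  # still constructs a collection...
# ...but the old name now raises a DeprecationWarning pointing at the new one
assert any(issubclass(w.category, DeprecationWarning) for w in caught)
```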
test/embedding/test_qdrant.py (4 changes: 2 additions & 2 deletions)
@@ -10,7 +10,7 @@
 # import pytest
 
 from ogbujipt.embedding import qdrant
-from ogbujipt.embedding.qdrant import qdrant_collection
+from ogbujipt.embedding.qdrant import collection
 from ogbujipt.text_helper import text_splitter
 
 qdrant.QDRANT_AVAILABLE = True
@@ -59,7 +59,7 @@ def test_qdrant_embed_poem(mocker, COME_THUNDER_POEM, CORRECT_STRING):
     qdrant.models.VectorParams.side_effect = [mock_vparam]
     mocker.patch('ogbujipt.embedding.qdrant.QdrantClient')
 
-    coll = qdrant_collection(name=collection_name, embedding_model=embedding_model)
+    coll = collection(name=collection_name, embedding_model=embedding_model)
 
     # client.count.side_effect = ['count=0']
     coll.db.count.side_effect = lambda collection_name: 'count=0'
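The rename is covered by the existing test; the deprecation shim itself could be checked with pytest.warns. A hypothetical addition (not in this PR), reusing the mocked-client setup from the test above and assuming construction succeeds with a dummy embedding model:

```python
import pytest

def test_qdrant_collection_deprecation(mocker):
    mocker.patch('ogbujipt.embedding.qdrant.QdrantClient')
    from ogbujipt.embedding.qdrant import qdrant_collection
    with pytest.warns(DeprecationWarning):
        # object() is a throwaway stand-in for a real embedding model
        qdrant_collection(name='deprecated-name', embedding_model=object())
```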