Commit

Get demo/chat_web_selects.py & demo/chat_pdf_streamlit_ui.py working again
uogbuji committed Nov 24, 2023
1 parent 2c7edb4 commit dfbc3e4
Showing 4 changed files with 34 additions and 20 deletions.
demo/chat_pdf_streamlit_ui.py (4 changes: 2 additions & 2 deletions)
@@ -36,7 +36,7 @@

from ogbujipt.llm_wrapper import openai_chat_api, prompt_to_chat
from ogbujipt.text_helper import text_splitter
from ogbujipt.embedding_helper import qdrant_collection
from ogbujipt.embedding.qdrant import collection

from sentence_transformers import SentenceTransformer

@@ -104,7 +104,7 @@ def prep_pdf():

# Vectorizes chunks for sLLM lookup
# XXX: Look up rules around uploaded object names
kb = qdrant_collection(pdf.name, emb_model) # in-memory vector DB instance
kb = collection(pdf.name, emb_model) # in-memory vector DB instance

# Show throbber, embed the PDF, and get ready for similarity search
embedding_placeholder = st.container()
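For orientation, a minimal sketch of the renamed class as the demo now uses it. The constructor, update, search, and count calls follow the signatures shown elsewhere in this commit; the collection name, chunk texts, and metadata keys here are illustrative only.

from sentence_transformers import SentenceTransformer  # pip install sentence_transformers
from ogbujipt.embedding.qdrant import collection  # pip install qdrant_client

# In-memory Qdrant vector DB instance, named after the (hypothetical) uploaded PDF
emb_model = SentenceTransformer('all-MiniLM-L6-v2')
kb = collection('uploaded-demo.pdf', emb_model)

# In the demo the chunks come from text_splitter() over the extracted PDF text
chunks = ['First chunk of PDF text', 'Second chunk of PDF text']
kb.update(texts=chunks, metas=[{'chunk-index': i} for (i, _) in enumerate(chunks)])

# Similarity search over the embedded chunks, ready for use as LLM context
docs = kb.search('What is this document about?', limit=2)
print(kb.count(), 'chunks in the collection')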
demo/chat_web_selects.py (31 changes: 18 additions & 13 deletions)
@@ -8,7 +8,7 @@
Works especially well with airoboros self-hosted LLM.
Vector store: Qdrant - https://qdrant.tech/
Alternatives: pgvector, Chroma, Faiss, Weaviate, etc.
Alternatives: pgvector & Chroma (built-in support from OgbujiPT), Faiss, Weaviate, etc.
Text to vector (embedding) model:
Alternatives: https://www.sbert.net/docs/pretrained_models.html / OpenAI ada002
@@ -26,6 +26,13 @@
```
An example question might be "Who are the neighbors of the Igbo people?"
You can tweak it with the following command line options:
--verbose - print more information while processing (for debugging)
--limit (max number of chunks to retrieve for use as context)
--chunk-size (characters per chunk, while prepping to create embeddings)
--chunk-overlap (character overlap between chunks, while prepping to create embeddings)
--question (The user question; if None (the default), prompt the user interactively)
'''
# en.wikipedia.org/wiki/Igbo_people|ahiajoku.igbonet.com/2000/|en.wikivoyage.org/wiki/Igbo_phrasebook"
import asyncio
@@ -37,7 +44,7 @@

from ogbujipt.llm_wrapper import openai_chat_api, prompt_to_chat
from ogbujipt.text_helper import text_splitter
from ogbujipt.embedding_helper import qdrant_collection
from ogbujipt.embedding.qdrant import collection


# Avoid re-entrance complaints from huggingface/tokenizers
@@ -49,9 +56,6 @@
COLLECTION_NAME = 'chat-web-selects'
USER_PROMPT = 'What do you want to know from the site(s)?\n'

# Hard-code for demo
EMBED_CHUNK_SIZE = 200
EMBED_CHUNK_OVERLAP = 20
DOTS_SPACING = 0.2 # Number of seconds between each dot printed to console


@@ -61,7 +65,7 @@ async def indicate_progress(pause=DOTS_SPACING):
await asyncio.sleep(pause)


async def read_site(url, collection, chunk_size, chunk_overlap):
async def read_site(url, coll, chunk_size, chunk_overlap):
# Crude check; good enough for demo
if not url.startswith('http'): url = 'https://' + url # noqa E701
print('Downloading & processing', url)
Expand All @@ -78,8 +82,8 @@ async def read_site(url, collection, chunk_size, chunk_overlap):
# Crude—for demo. Set URL metadata for all chunks to doc URL
metas = [{'url': url}]*len(chunks)
# Add the text to the collection
collection.update(texts=chunks, metas=metas)
print(f'{collection.count()} chunks added to collection')
coll.update(texts=chunks, metas=metas)
print(f'{coll.count()} chunks added to collection')


async def async_main(oapi, sites, verbose, limit, chunk_size, chunk_overlap, question):
@@ -88,11 +92,11 @@ async def async_main(oapi, sites, verbose, limit, chunk_size, chunk_overlap, que
from sentence_transformers import SentenceTransformer
embedding_model = SentenceTransformer(DOC_EMBEDDINGS_LLM)
# Sites fuel in-memory Qdrant vector DB instance
collection = qdrant_collection(COLLECTION_NAME, embedding_model)
coll = collection(COLLECTION_NAME, embedding_model)

# Download & process sites in parallel, loading their content & metadata into a knowledgebase
url_task_group = asyncio.gather(*[
asyncio.create_task(read_site(site, collection, chunk_size, chunk_overlap)) for site in sites.split('|')])
asyncio.create_task(read_site(site, coll, chunk_size, chunk_overlap)) for site in sites.split('|')])
indicator_task = asyncio.create_task(indicate_progress())
tasks = [indicator_task, url_task_group]
done, _ = await asyncio.wait(
@@ -109,7 +113,7 @@ async def async_main(oapi, sites, verbose, limit, chunk_size, chunk_overlap, que
if user_question.strip() == 'done':
break

docs = collection.search(user_question, limit=limit)
docs = coll.search(user_question, limit=limit)
if verbose:
print(docs)
if docs:
@@ -160,8 +164,9 @@ async def async_main(oapi, sites, verbose, limit, chunk_size, chunk_overlap, que
# Command line arguments defined in click decorators
@click.command()
@click.option('--verbose/--no-verbose', default=False)
@click.option('--chunk-size', default=EMBED_CHUNK_SIZE, type=int, help='Number of characters to include per chunk')
@click.option('--chunk-overlap', default=EMBED_CHUNK_OVERLAP, type=int,
@click.option('--chunk-size', type=int, default=200,
help='Number of characters to include per chunk')
@click.option('--chunk-overlap', type=int, default=20,
help='Number of characters to overlap at the edges of chunks')
@click.option('--limit', default=4, type=int,
help='Maximum number of chunks matched against the posed question to use as context for the LLM')
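Since the commit drops the hard-coded EMBED_CHUNK_SIZE / EMBED_CHUNK_OVERLAP constants, the chunking defaults now live directly in the click decorators. A standalone sketch of that pattern, using the option names, defaults, and help text from the demo; the command body is illustrative only.

import click

@click.command()
@click.option('--verbose/--no-verbose', default=False)
@click.option('--chunk-size', type=int, default=200,
              help='Number of characters to include per chunk')
@click.option('--chunk-overlap', type=int, default=20,
              help='Number of characters to overlap at the edges of chunks')
@click.option('--limit', type=int, default=4,
              help='Maximum number of chunks matched against the posed question to use as context for the LLM')
@click.option('--question', default=None, help='User question; prompt interactively if omitted')
def main(verbose, chunk_size, chunk_overlap, limit, question):
    # Illustrative body; the real demo passes these values through to async_main()
    click.echo(f'chunk_size={chunk_size} chunk_overlap={chunk_overlap} limit={limit} question={question!r}')


if __name__ == '__main__':
    main()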
pylib/embedding/qdrant.py (15 changes: 12 additions & 3 deletions)
@@ -24,7 +24,7 @@
MEMORY_QDRANT_CONNECTION_PARAMS = {'location': ':memory:'}


class qdrant_collection:
class collection:
def __init__(self, name, embedding_model, db=None,
distance_function=None, **conn_params):
'''
@@ -48,12 +48,12 @@ def __init__(self, name, embedding_model, db=None,
Example:
>>> from ogbujipt.text_helper import text_splitter
>>> from ogbujipt.embedding_helper import qdrant_collection # pip install qdrant_client
>>> from ogbujipt.embedding.qdrant import collection # pip install qdrant_client
>>> from sentence_transformers import SentenceTransformer # pip install sentence_transformers
>>> text = 'The quick brown fox\njumps over the lazy dog,\nthen hides under a log\nwith a frog.\n'
>>> text += 'Should the hound wake up,\nall jumpers beware\nin a log, in a bog\nhe\'ll search everywhere.\n'
>>> embedding_model = SentenceTransformer('all-MiniLM-L6-v2')
>>> collection = qdrant_collection('my-text', embedding_model)
>>> collection = collection('my-text', embedding_model)
>>> chunks = text_splitter(text, chunk_size=20, chunk_overlap=4, separator='\n')
>>> collection.update(texts=chunks, metas=[{'seq-index': i} for (i, _) in enumerate(chunks)])
>>> retval = collection.search('what does the fox say?', limit=1)
@@ -187,3 +187,12 @@ def count(self):
# This ugly declaration just gets the count as an integer
current_count = int(str(self.db.count(self.name)).partition('=')[-1])
return current_count


# The module name already disambiguates this class; anyone who needs more can use `import ... as ...`
# Deprecating the old name
class qdrant_collection(collection):
def __init__(self, name, embedding_model, db=None,
distance_function=None, **conn_params):
warnings.warn('qdrant_collection is deprecated. Use collection instead.', DeprecationWarning)
super().__init__(name, embedding_model, db=db, distance_function=distance_function, **conn_params)
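The shim keeps old imports working while steering callers to the new name: constructing qdrant_collection now emits a DeprecationWarning (this assumes warnings is imported at the top of qdrant.py, outside the hunk shown). A minimal sketch of the migration, reusing the SentenceTransformer model from the docstring example:

import warnings
from sentence_transformers import SentenceTransformer
from ogbujipt.embedding.qdrant import collection, qdrant_collection

embedding_model = SentenceTransformer('all-MiniLM-L6-v2')

# Old name: still constructs a working collection, but warns
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter('always')
    legacy_kb = qdrant_collection('legacy-demo', embedding_model)
    assert any(issubclass(w.category, DeprecationWarning) for w in caught)

# New name going forward
kb = collection('new-demo', embedding_model)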
test/embedding/test_qdrant.py (4 changes: 2 additions & 2 deletions)
@@ -10,7 +10,7 @@
# import pytest

from ogbujipt.embedding import qdrant
from ogbujipt.embedding.qdrant import qdrant_collection
from ogbujipt.embedding.qdrant import collection
from ogbujipt.text_helper import text_splitter

qdrant.QDRANT_AVAILABLE = True
@@ -59,7 +59,7 @@ def test_qdrant_embed_poem(mocker, COME_THUNDER_POEM, CORRECT_STRING):
qdrant.models.VectorParams.side_effect = [mock_vparam]
mocker.patch('ogbujipt.embedding.qdrant.QdrantClient')

coll = qdrant_collection(name=collection_name, embedding_model=embedding_model)
coll = collection(name=collection_name, embedding_model=embedding_model)

# client.count.side_effect = ['count=0']
coll.db.count.side_effect = lambda collection_name: 'count=0'
