Commit e0f8fa5

Merge pull request #16 from lpm0073/next
bug fix to rag() and stronger typing
lpm0073 authored Dec 2, 2023
2 parents 80a0897 + 2aaf8b0 commit e0f8fa5
Showing 10 changed files with 64 additions and 45 deletions.
3 changes: 1 addition & 2 deletions CHANGELOG.md
@@ -1,9 +1,8 @@
 ## [1.1.2](https://github.com/lpm0073/hybrid-search-retriever/compare/v1.1.1...v1.1.2) (2023-12-01)
 
-
 ### Bug Fixes
 
-* syntax error in examples.prompt ([230b709](https://github.com/lpm0073/hybrid-search-retriever/commit/230b7090c96bdd4d7d8757b182f891ab1b82c6f4))
+- syntax error in examples.prompt ([230b709](https://github.com/lpm0073/hybrid-search-retriever/commit/230b7090c96bdd4d7d8757b182f891ab1b82c6f4))
 
 ## [1.1.1](https://github.com/lpm0073/netec-llm/compare/v1.1.0...v1.1.1) (2023-12-01)
2 changes: 1 addition & 1 deletion README.md
@@ -46,7 +46,7 @@ python3 -m models.examples.training_services_oracle "Oracle database administrat
 python3 -m models.examples.load "./data/"
 
 # example 6 - Retrieval Augmented Generation
-python3 -m models.examples.rag "What is Accounting Based Valuation?"
+python3 -m models.examples.rag "What analytics and accounting courses does Wharton offer?"
 ```
 
 ## Setup
2 changes: 1 addition & 1 deletion models/__version__.py
@@ -1,2 +1,2 @@
 # -*- coding: utf-8 -*-
-__version__ = "1.1.2"
+__version__ = "1.1.3"
2 changes: 1 addition & 1 deletion models/const.py
@@ -21,7 +21,7 @@
     OPENAI_CHAT_TEMPERATURE = float(os.environ.get("OPENAI_CHAT_TEMPERATURE", 0.0))
     OPENAI_CHAT_MAX_RETRIES = int(os.environ.get("OPENAI_CHAT_MAX_RETRIES", 3))
     OPENAI_CHAT_CACHE = bool(os.environ.get("OPENAI_CHAT_CACHE", True))
-    DEBUG_MODE = bool(os.environ.get("DEBUG_MODE", False))
+    DEBUG_MODE = os.environ.get("DEBUG_MODE", "False") == "True"
 else:
     raise FileNotFoundError("No .env file found in root directory of repository")
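Why the DEBUG_MODE change matters: environment variables are always strings, and bool() of any non-empty string is True, so the old expression evaluated to True even when DEBUG_MODE was set to "False". A minimal sketch of the pitfall and the fix:

import os

os.environ["DEBUG_MODE"] = "False"
print(bool(os.environ.get("DEBUG_MODE", False)))        # True: any non-empty string is truthy
print(os.environ.get("DEBUG_MODE", "False") == "True")  # False: the behavior the new code intends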
12 changes: 8 additions & 4 deletions models/examples/prompt.py
@@ -2,6 +2,8 @@
 """Sales Support Model (hsr)"""
 import argparse
 
+from langchain.schema import HumanMessage, SystemMessage
+
 from models.hybrid_search_retreiver import HybridSearchRetriever


@@ -10,9 +12,11 @@
 
 if __name__ == "__main__":
     parser = argparse.ArgumentParser(description="hsr examples")
-    parser.add_argument("system_prompt", type=str, help="A system prompt to send to the model.")
-    parser.add_argument("human_prompt", type=str, help="A human prompt to send to the model.")
+    parser.add_argument("system_message", type=str, help="A system prompt to send to the model.")
+    parser.add_argument("human_message", type=str, help="A human prompt to send to the model.")
     args = parser.parse_args()
 
-    result = hsr.cached_chat_request(args.system_prompt, args.human_prompt)
-    print(result)
+    system_message = SystemMessage(text=args.system_message)
+    human_message = HumanMessage(text=args.human_message)
+    result = hsr.cached_chat_request(system_message=system_message, human_message=human_message)
+    print(result.content)
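For context, a minimal sketch of driving the new interface directly from Python instead of the CLI. It assumes HybridSearchRetriever can be instantiated without arguments, and it constructs the messages with content=, the keyword langchain's message classes accept:

from langchain.schema import HumanMessage, SystemMessage

from models.hybrid_search_retreiver import HybridSearchRetriever

hsr = HybridSearchRetriever()
system_message = SystemMessage(content="you are a helpful assistant")
human_message = HumanMessage(content='return the word "SUCCESS" in upper case.')
result = hsr.cached_chat_request(system_message=system_message, human_message=human_message)
print(result.content)  # the request now returns the full message object, so unwrap .content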
5 changes: 4 additions & 1 deletion models/examples/rag.py
@@ -2,6 +2,8 @@
 """Sales Support Model (hsr) Retrieval Augmented Generation (RAG)"""
 import argparse
 
+from langchain.schema import HumanMessage
+
 from models.hybrid_search_retreiver import HybridSearchRetriever


@@ -12,5 +14,6 @@
     parser.add_argument("prompt", type=str, help="A question about the PDF contents")
     args = parser.parse_args()
 
-    result = hsr.rag(prompt=args.prompt)
+    human_message = HumanMessage(text=args.prompt)
+    result = hsr.rag(human_message=human_message)
     print(result)
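The equivalent call from Python, again as a sketch under the same no-argument constructor assumption. Note that rag() also accepts a plain string and coerces it to a HumanMessage itself:

from langchain.schema import HumanMessage

from models.hybrid_search_retreiver import HybridSearchRetriever

hsr = HybridSearchRetriever()
question = HumanMessage(content="What analytics and accounting courses does Wharton offer?")
print(hsr.rag(human_message=question))  # rag() returns response.content, already a plain string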
56 changes: 32 additions & 24 deletions models/hybrid_search_retreiver.py
@@ -18,9 +18,12 @@
 
 # document loading
 import glob
+
+# general purpose imports
 import logging
 import os
 import textwrap
+from typing import Union
 
 # pinecone integration
 import pinecone
@@ -38,7 +41,7 @@
 
 # hybrid search capability
 from langchain.retrievers import PineconeHybridSearchRetriever
-from langchain.schema import HumanMessage, SystemMessage
+from langchain.schema import BaseMessage, HumanMessage, SystemMessage
 from langchain.text_splitter import Document
 from langchain.vectorstores.pinecone import Pinecone
 from pinecone_text.sparse import BM25Encoder
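BaseMessage is the common parent of langchain's chat message classes, which is why it works as the return annotation for a method that hands back whatever self.chat() produces (an AIMessage in practice). A quick illustration:

from langchain.schema import AIMessage, BaseMessage, HumanMessage, SystemMessage

assert issubclass(SystemMessage, BaseMessage)
assert issubclass(HumanMessage, BaseMessage)
assert issubclass(AIMessage, BaseMessage)  # the type a chat model response actually carries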
@@ -95,14 +98,20 @@ class HybridSearchRetriever:
     text_splitter = TextSplitter()
     bm25_encoder = BM25Encoder().default()
 
-    def cached_chat_request(self, system_message: str, human_message: str) -> SystemMessage:
+    def cached_chat_request(
+        self, system_message: Union[str, SystemMessage], human_message: Union[str, HumanMessage]
+    ) -> BaseMessage:
         """Cached chat request."""
-        messages = [
-            SystemMessage(content=system_message),
-            HumanMessage(content=human_message),
-        ]
+        if not isinstance(system_message, SystemMessage):
+            logging.debug("Converting system message to SystemMessage")
+            system_message = SystemMessage(content=str(system_message))
+
+        if not isinstance(human_message, HumanMessage):
+            logging.debug("Converting human message to HumanMessage")
+            human_message = HumanMessage(content=str(human_message))
+        messages = [system_message, human_message]
         # pylint: disable=not-callable
-        retval = self.chat(messages).content
+        retval = self.chat(messages)
         return retval
 
     def prompt_with_template(self, prompt: PromptTemplate, concept: str, model: str = DEFAULT_MODEL_NAME) -> str:
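Both call styles the widened Union signature permits, sketched against an existing hsr instance (the prompts here are hypothetical):

reply = hsr.cached_chat_request("you are a terse assistant", "say hello")  # bare strings are coerced
reply = hsr.cached_chat_request(
    system_message=SystemMessage(content="you are a terse assistant"),
    human_message=HumanMessage(content="say hello"),
)
print(reply.content)  # the method now returns the message object, so callers unwrap .content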
@@ -158,10 +167,10 @@ def load(self, filepath: str):

logging.debug("Finished loading PDFs")

def rag(self, prompt: str):
def rag(self, human_message: Union[str, HumanMessage]):
"""
Embedded prompt.
1. Retrieve prompt: Given a user input, relevant splits are retrieved
1. Retrieve human message prompt: Given a user input, relevant splits are retrieved
from storage using a Retriever.
2. Generate: A ChatModel / LLM produces an answer using a prompt that includes
the question and the retrieved data
@@ -174,33 +183,32 @@ def rag(self, prompt: str):
         The typical workflow is to use the embeddings to retrieve relevant documents,
         and then use the text of these documents as part of the prompt for GPT-3.
         """
+        if not isinstance(human_message, HumanMessage):
+            logging.debug("Converting human_message to HumanMessage")
+            human_message = HumanMessage(content=human_message)
+
         retriever = PineconeHybridSearchRetriever(
             embeddings=self.openai_embeddings, sparse_encoder=self.bm25_encoder, index=self.pinecone_index
         )
-        documents = retriever.get_relevant_documents(query=prompt)
+        documents = retriever.get_relevant_documents(query=human_message.content)
         logging.debug("Retrieved %i related documents from Pinecone", len(documents))
 
         # Extract the text from the documents
         document_texts = [doc.page_content for doc in documents]
         leader = textwrap.dedent(
-            """\
-            \n\nYou can assume that the following is true.
+            """You are a helpful assistant.
+            You can assume that all of the following is true.
             You should attempt to incorporate these facts
-            into your response:\n\n
+            into your responses:\n\n
             """
         )
+        system_message = f"{leader} {'. '.join(document_texts)}"
 
-        # Create a prompt that includes the document texts
-        prompt_with_relevant_documents = f"{prompt + leader} {'. '.join(document_texts)}"
-
-        logging.debug("Prompt contains %i words", len(prompt_with_relevant_documents.split()))
-        logging.debug("Prompt: %s", prompt_with_relevant_documents)
-
-        # Get a response from the GPT-3.5-turbo model
-        response = self.cached_chat_request(
-            system_message="You are a helpful assistant.", human_message=prompt_with_relevant_documents
-        )
+        logging.debug("System messages contains %i words", len(system_message.split()))
+        logging.debug("Prompt: %s", system_message)
+        system_message = SystemMessage(content=system_message)
+        response = self.cached_chat_request(system_message=system_message, human_message=human_message)
 
         logging.debug("Response:")
         logging.debug("------------------------------------------------------")
-        return response
+        return response.content
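Net effect of the rag() rework: the retrieved document texts now travel in the SystemMessage while the user's question stays in its own HumanMessage, instead of everything being concatenated into a single human prompt. A condensed sketch with hypothetical retrieval results standing in for the Pinecone output:

from langchain.schema import HumanMessage, SystemMessage

document_texts = [
    "Wharton offers an MBA major in Business Analytics",
    "ACCT 613 covers financial accounting and valuation",
]  # hypothetical stand-ins for retriever.get_relevant_documents() output
leader = (
    "You are a helpful assistant. "
    "You can assume that all of the following is true. "
    "You should attempt to incorporate these facts into your responses:\n\n"
)
system_message = SystemMessage(content=f"{leader} {'. '.join(document_texts)}")
human_message = HumanMessage(content="What analytics and accounting courses does Wharton offer?")
# with an hsr instance as above: print(hsr.cached_chat_request(system_message, human_message).content)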
24 changes: 14 additions & 10 deletions models/tests/test_examples.py
@@ -6,6 +6,7 @@
 from unittest.mock import MagicMock, patch
 
 import pytest  # pylint: disable=unused-import
+from langchain.schema import HumanMessage, SystemMessage
 
 from models.examples.prompt import hsr as prompt_hrs
 from models.examples.rag import hsr as rag_hsr
@@ -14,7 +15,7 @@
 from models.prompt_templates import NetecPromptTemplates
 
 
-HUMAN_PROMPT = 'return the word "SUCCESS" in upper case.'
+HUMAN_MESSAGE = 'return the word "SUCCESS" in upper case.'
 
 
 class TestExamples:
@@ -25,44 +26,47 @@ def test_prompt(self, mock_parse_args):
         """Test prompt example."""
         mock_args = MagicMock()
         mock_args.system_prompt = "you are a helpful assistant"
-        mock_args.human_prompt = HUMAN_PROMPT
+        mock_args.human_prompt = HUMAN_MESSAGE
         mock_parse_args.return_value = mock_args
 
-        result = prompt_hrs.cached_chat_request(mock_args.system_prompt, mock_args.human_prompt)
-        assert result == "SUCCESS"
+        system_message = SystemMessage(content="you are a helpful assistant")
+        human_message = HumanMessage(content=HUMAN_MESSAGE)
+        result = prompt_hrs.cached_chat_request(system_message=system_message, human_message=human_message)
+        assert result.content == "SUCCESS"
 
     @patch("argparse.ArgumentParser.parse_args")
     def test_rag(self, mock_parse_args):
         """Test RAG example."""
         mock_args = MagicMock()
-        mock_args.human_prompt = HUMAN_PROMPT
+        mock_args.human_message = HUMAN_MESSAGE
         mock_parse_args.return_value = mock_args
 
-        result = rag_hsr.rag(mock_args.human_prompt)
+        human_message = HumanMessage(content=mock_args.human_message)
+        result = rag_hsr.rag(human_message=human_message)
         assert result == "SUCCESS"
 
     @patch("argparse.ArgumentParser.parse_args")
     def test_training_services(self, mock_parse_args):
         """Test training services templates."""
         mock_args = MagicMock()
-        mock_args.human_prompt = HUMAN_PROMPT
+        mock_args.human_message = HUMAN_MESSAGE
         mock_parse_args.return_value = mock_args
 
         templates = NetecPromptTemplates()
         prompt = templates.training_services
 
-        result = training_services_hsr.prompt_with_template(prompt=prompt, concept=mock_args.human_prompt)
+        result = training_services_hsr.prompt_with_template(prompt=prompt, concept=mock_args.human_message)
         assert "SUCCESS" in result
 
     @patch("argparse.ArgumentParser.parse_args")
     def test_oracle_training_services(self, mock_parse_args):
         """Test oracle training services."""
         mock_args = MagicMock()
-        mock_args.human_prompt = HUMAN_PROMPT
+        mock_args.human_message = HUMAN_MESSAGE
         mock_parse_args.return_value = mock_args
 
         templates = NetecPromptTemplates()
         prompt = templates.oracle_training_services
 
-        result = training_services_oracle_hsr.prompt_with_template(prompt=prompt, concept=mock_args.human_prompt)
+        result = training_services_oracle_hsr.prompt_with_template(prompt=prompt, concept=mock_args.human_message)
         assert "SUCCESS" in result
2 changes: 1 addition & 1 deletion models/tests/test_openai.py
@@ -19,4 +19,4 @@ def test_03_test_openai_connectivity(self):
         retval = hsr.cached_chat_request(
             "your are a helpful assistant", "please return the value 'CORRECT' in all upper case."
         )
-        assert retval == "CORRECT"
+        assert retval.content == "CORRECT"
1 change: 1 addition & 0 deletions requirements.txt
@@ -21,6 +21,7 @@ codespell==2.2.6
 # ------------
 langchain==0.0.343
 langchainhub==0.1.14
+langchain-experimental==0.0.43
 openai==1.3.5
 pinecone-client==2.2.4
 pinecone-text==0.7.0
