Adding banking77 classification and twitter emotion detection files
hafeezali committed Dec 8, 2023
1 parent b7fdc63 commit d4edaaa
Showing 20 changed files with 17,153 additions and 0 deletions.
113 changes: 113 additions & 0 deletions examples/labeling/banking/banking_classification.py
@@ -0,0 +1,113 @@
from langroid.agent.chat_agent import ChatAgent, ChatAgentConfig
from langroid.utils.logging import setup_colored_logging
from langroid.vector_store.qdrantdb import QdrantDBConfig
from langroid.agent.special.retriever_agent import (
RecordDoc,
RecordMetadata,
RetrieverAgent,
RetrieverAgentConfig,
)
from langroid.parsing.parser import ParsingConfig

import pandas as pd
from typing import Any, Dict, List, Sequence

from sklearn.metrics import accuracy_score


# TODO: Generalize for any single-label classification task and fetch constants from user
class BankingTextRetrieverAgentConfig(RetrieverAgentConfig):
system_message: str = "You are an expert at understanding bank customer support queries."
user_message: str = """
Your task is to match a bank statement to a list of examples in a table based on semantic similarity between the given statement and the examples in the table.
"""
data: List[Dict[str, Any]]
n_matches: int = 10
vecdb: QdrantDBConfig = QdrantDBConfig(
collection_name="banking-classification",
storage_path=":memory:",
)
parsing: ParsingConfig = ParsingConfig(
n_similar_docs=10,
)
    cross_encoder_reranking_model: str = ""  # turn off cross-encoder reranking


# TODO: Logic for get_records can come from user
class BankingTextRetrieverAgent(RetrieverAgent):
def __init__(self, config: BankingTextRetrieverAgentConfig):
super().__init__(config)
self.config = config

def get_records(self) -> Sequence[RecordDoc]:
return [
RecordDoc(
content=", ".join(f"{k}={v}" for k, v in d.items()),
metadata=RecordMetadata(id=i),
)
for i, d in enumerate(self.config.data)
]


def compute_acc(llm_labels, gt_labels):
return accuracy_score(gt_labels, llm_labels)


class BankingTextClassifier:
    def __init__(
        self,
        chat_agent_config: ChatAgentConfig,
        rag_seed_file: str,
        banking_test_file: str,
        base_prompt: str,
    ):
setup_colored_logging()

self.chat_agent_config = chat_agent_config
self.banking_classifier_agent = ChatAgent(chat_agent_config)
self.base_prompt = base_prompt

rag_seed_data = pd.read_csv(rag_seed_file).to_dict('records')
        self.banking_text_retriever_agent = BankingTextRetrieverAgent(
            BankingTextRetrieverAgentConfig(data=rag_seed_data)
        )
self.banking_text_retriever_agent.ingest()

self.test_df = pd.read_csv(banking_test_file)
self.test_df['ID'] = range(1, len(self.test_df) + 1)

self.results_file = "./test_llm_responses.csv"
self.results = {}

# TODO: for debug purposes only, must be removed
self.test_df = self.test_df[self.test_df['ID'] < 25]
self.llm_responses = None

    def run_banking_classification(self):
        # Reuse the chat agent constructed in __init__; llm_response_forget()
        # keeps each query independent of the previous ones.
        agent = self.banking_classifier_agent

        llm_responses = {}
        for _, row in self.test_df.iterrows():
            # Build a few-shot prompt from the semantically closest seed examples.
            prompt = self.base_prompt
            nearest_examples = self.banking_text_retriever_agent.get_relevant_chunks(
                query=row['text']
            )
            for doc in nearest_examples:
                # Each record was stored as "text=<text>, label=<label>" in get_records().
                example = doc.content
                text = example.split("text=")[1].split(", label=")[0]
                label = example.split(", label=")[1]
                prompt += f"Text: {text}\nLabel: {label}\n"
            prompt += f"\nText: {row['text']}\nLabel: "
            llm_responses[row['ID']] = agent.llm_response_forget(prompt).content

result_dict_list = [{'ID': int(key), 'llm_label': value} for key, value in llm_responses.items()]
result_df = pd.DataFrame(result_dict_list)
result_df.to_csv(self.results_file, index=False)

self.llm_responses = result_df

self.compute_results(self.llm_responses)

    def run_banking_classification_async_batch(self):
        # TODO: asynchronous / batched variant, not yet implemented
        pass

    def compute_results(self, llm_responses):
        combined_labels_df = self.test_df.merge(llm_responses, on="ID", how="inner")
        self.results['Accuracy'] = compute_acc(
            combined_labels_df['llm_label'], combined_labels_df['label']
        )
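
The module defines the classifier but never invokes it. A minimal driver might look like the sketch below; it is not part of this commit. ChatAgentConfig() with its defaults, the CSV paths, and the base prompt are illustrative assumptions only.

if __name__ == "__main__":
    # Illustrative prompt; any instruction that frames the few-shot examples works here.
    base_prompt = (
        "Below are bank customer queries paired with their intent labels. "
        "Classify the final query using one of the labels shown in the examples.\n\n"
    )
    classifier = BankingTextClassifier(
        chat_agent_config=ChatAgentConfig(),      # assumes the default LLM settings suffice
        rag_seed_file="banking77_seed.csv",       # hypothetical CSV with text,label columns
        banking_test_file="banking77_test.csv",   # hypothetical CSV with text,label columns
        base_prompt=base_prompt,
    )
    classifier.run_banking_classification()
    print(classifier.results)  # e.g. {'Accuracy': ...}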