Qa gen with end point #101

Open · wants to merge 2 commits into base: main
97 changes: 92 additions & 5 deletions dalm/datasets/qa_gen/question_answer_generation.py
@@ -5,11 +5,15 @@
from functools import partial
from pathlib import Path

import math
import multiprocessing
import os  # used below for os.getenv("ARCEE_END_POINT"); harmless if already imported near the top of the file
import datasets
import torch
from datasets import Dataset, DatasetDict
from sklearn.model_selection import train_test_split
from transformers import AutoModelForCausalLM, AutoTokenizer, PreTrainedModel, PreTrainedTokenizer
from langchain_core.prompts.chat import SystemMessagePromptTemplate, HumanMessagePromptTemplate, ChatPromptTemplate
from langchain_openai import ChatOpenAI

# Configure logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
@@ -73,6 +77,89 @@ def parse_args() -> argparse.Namespace:
    args = parser.parse_args()
    return args

def generate_question_answer_pairs_langchain_parallel(
documents: dict, passage_column_name: str, max_input_tokens: int
) -> dict:
    texts = documents[passage_column_name]
    if not texts:
        # Guard against an empty batch: Pool(0) and the ceil division below would both fail.
        return {"Question": [], "Answer": []}
    num_processors = min(multiprocessing.cpu_count(), len(texts))
    chunk_size = math.ceil(len(texts) / num_processors)
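    # e.g. 10 passages on a 4-core machine -> chunk_size = ceil(10 / 4) = 3,
    # so pool.map receives chunks of 3, 3, 3 and 1 passages.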

with multiprocessing.Pool(num_processors) as pool:
results = pool.map(
partial(generate_qa_chunk, passage_column_name=passage_column_name, max_input_tokens=max_input_tokens),
[texts[i:i + chunk_size] for i in range(0, len(texts), chunk_size)]
)

# Flatten the results
flattened_results = [item for sublist in results for item in sublist]

return {
"Question": [r["Question"] for r in flattened_results],
"Answer": [r["Answer"] for r in flattened_results]
}

def generate_qa_chunk(texts: list, passage_column_name: str, max_input_tokens: int) -> list:
    # NOTE: passage_column_name and max_input_tokens are accepted to match the
    # partial() call above, but they are not used inside this function yet.
example_passage = (
"Dense retrieval models are essential for embedding-based information "
"retrieval systems, as they map queries and documents into a shared "
"embedding space where their relevance can be computed. By using in-batch "
"negative contrastive learning, these models can be trained more efficiently, "
"as each batch contains not only positive examples but also negative samples "
"from unrelated queries or documents. This approach helps optimize the model's "
"ability to retrieve the most relevant information in real-world applications, "
"such as question-answering systems, where precision is critical."
)
example_question = (
"What role does in-batch negative contrastive learning play in training dense "
"retrieval models, particularly in optimizing the retrieval of relevant information "
"across different applications?"
)
prompt_template = (
"Read the following passage and generate a single, relevant question based "
"on its content. The question should be less than 100 words and more than 10 "
"words. Do not generate anything other than the question itself. Avoid any tokens, "
"explanations, or formatting. Do not use words like 'Question:', 'Answer:', 'Example:', or 'Passage:'. "
"Ensure there are no line breaks in the output. The output should be the question only, nothing more.\n\n"
"Example:\nPassage: {example_passage}\n{example_question}\n\nNow, do the same for the next "
"passage:\n{passage}\n"
)

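    # The chat client is built inside the worker function because client objects
    # are generally not picklable and so cannot be shared across multiprocessing
    # workers. ChatOpenAI reads its API key from the OPENAI_API_KEY environment
    # variable; only the base URL is overridden here via ARCEE_END_POINT.
    # temperature=0 with top_p=0.75 keeps the generated questions near-deterministic.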
arcee_llm = ChatOpenAI(
model_name="arcee-spark",
temperature=0,
max_tokens=1024,
openai_api_base=os.getenv("ARCEE_END_POINT"),
timeout=None,
max_retries=2,
top_p=0.75
)
    system_template = (
        "You are Arcee Spark, created by Arcee Inc. You are a helpful assistant whose "
        "task is to generate a question for a given piece of text, not the answer."
    )

system_message_prompt = SystemMessagePromptTemplate.from_template(system_template)
human_message_prompt = HumanMessagePromptTemplate.from_template(prompt_template)

chat_prompt = ChatPromptTemplate.from_messages([system_message_prompt, human_message_prompt])

input_chain = chat_prompt | arcee_llm
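    # chat_prompt | arcee_llm composes an LCEL runnable: invoking it renders the
    # system and human messages with the supplied variables and sends them to the
    # endpoint as a single chat completion call.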

results = []

    for text in texts:
        result = input_chain.invoke({
            "example_passage": example_passage,
            "example_question": example_question,
            "passage": text,
        })

        # The chain returns an AIMessage; the generated text is in .content.
        if isinstance(result, dict) and "output" in result:
            output = result["output"]
        else:
            output = str(result.content)
        logger.info("Generated question: %s", output)
        logger.debug("Source passage: %s", text)
        # Only the question is generated here; the Answer field is left empty.
        results.append({"Question": output, "Answer": ""})

return results
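
# A minimal usage sketch for the new endpoint-backed path (hypothetical values;
# assumes ARCEE_END_POINT and OPENAI_API_KEY are exported in the environment):
#
#   batch = {"passage": ["Dense retrieval models map queries and documents ..."]}
#   pairs = generate_question_answer_pairs_langchain_parallel(
#       batch, passage_column_name="passage", max_input_tokens=512
#   )
#   print(pairs["Question"][0])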

def generate_question_answer_pairs(
    documents: dict, model: PreTrainedModel, tokenizer: PreTrainedTokenizer, passage_column_name: str, max_input_tokens: int
@@ -95,7 +95,7 @@ def generate_question_answer_pairs(
"What role does in-batch negative contrastive learning play in training dense "
"retrieval models, particularly in optimizing the retrieval of relevant information "
"across different applications?"
)
)

prompt_template = (
"Read the following passage and generate a single, relevant question based "
@@ -188,8 +275,8 @@ def generate_qa_from_dataset(
    dataset: Dataset, passage_column_name: str, title_column_name: str, sample_size: int, batch_size: int, max_input_tokens: int, load_in_8bit: bool = True
) -> DatasetDict:
    logger.info(f"Generating question answer pairs with batch size: {batch_size}")
-    tokenizer = AutoTokenizer.from_pretrained(QA_MODEL)
-    model = AutoModelForCausalLM.from_pretrained(QA_MODEL, torch_dtype="auto", device_map="auto")
+    # tokenizer = AutoTokenizer.from_pretrained(QA_MODEL)
+    # model = AutoModelForCausalLM.from_pretrained(QA_MODEL, torch_dtype="auto", device_map="auto")

    # shuffle data (Dataset.shuffle returns a new dataset, so the result must be reassigned)
    dataset = dataset.shuffle(seed=42)
@@ -203,7 +290,7 @@
        f"Test dataset size: {len(small_dataset_splits['test'])}"
    )
    qa_gen_map = partial(
-        generate_question_answer_pairs, model=model, tokenizer=tokenizer, passage_column_name=passage_column_name, max_input_tokens=max_input_tokens
+        generate_question_answer_pairs_langchain_parallel, passage_column_name=passage_column_name, max_input_tokens=max_input_tokens
    )
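    # With batched=True below, map() hands each batch to the mapped function as a
    # dict of lists (one list per column), which is why the new parallel function
    # indexes documents[passage_column_name] to get all passages at once.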
    processed_data = small_dataset_splits.map(qa_gen_map, batched=True, batch_size=batch_size)
    # Print all questions from the test split before filtering
@@ -300,4 +387,4 @@ def main() -> None:
--sample_size=50 \
--output_dir=out \
--max_input_tokens=512
"""
"""