Fixes #78 Reading comprehension synthetic data regex improvements #80

Merged · 13 commits · Dec 5, 2023
209 changes: 164 additions & 45 deletions dalm/datasets/reading_comprehension_generation/utils.py
@@ -140,38 +140,87 @@ def create_domain_tokenizer_from_files(directory_or_file: str, csv_column: Optional[str]
return create_domain_tokenizer(os.path.join(temp_dir, "temp.txt"))


def fix_first_prompt(text: str, chat_chain: List[Dict[str, str]]) -> List[Dict[str, str]]:
# remove the first prompt
first_prompt = chat_chain.pop(0)
fixed_first_prompt = [
{
"content": f"Based on the following text: \n {text}, \n I'd like you to answer a few questions\n"
+ first_prompt["content"],
"role": "user",
}
]
return fixed_first_prompt + chat_chain
def wrap_context_with_rag_instruction(context: str) -> str:
return f"Based on the following text: \n {context}, \n I'd like you to answer a few questions\n"


def extract_question(text: str) -> Tuple[bool, str]:
"""
Extracts a question from a line of text.
Returns a tuple of (is_question, question_text)
"""
# question regex
return extract_question_or_answer(text, extract_type="question")


def extract_answer(text: str) -> Tuple[bool, str]:
"""
Extracts an answer from a line of text.
Returns a tuple of (is_answer, answer_text)
"""
    # answer regex
return extract_question_or_answer(text, extract_type="answer")


def extract_question_or_answer(text: str, extract_type: str = "question") -> Tuple[bool, str]:
# Match a line that starts with any number of junk characters, followed by either "question:"
# or "answer:", followed by any number of spaces (ignored), followed by any number of characters
# that will be captured in a group as the question or answer.
# extraction_regex = rf".*{extract_type}:\s*(.*)"

# Update above to handle the case where the question or answer is in brackets, with
# other text to be ignored inside the brackets
extraction_regex = rf".*\[?{extract_type}[:\]]*(?:.*?\])?\s*(.*)"

match = re.match(extraction_regex, text, re.IGNORECASE)
extracted_text = match.group(1) if match else ""
    found_extracted = bool(extracted_text)
return found_extracted, extracted_text
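
As a rough illustration, here is how the bracket-tolerant regex behaves on a few sample lines (the inputs below are invented for this sketch; expected outputs are derived from the regex above):

# Illustrative inputs only.
print(extract_question("1. QUESTION: Can you summarize the text?"))
# (True, 'Can you summarize the text?')  -- leading junk and "QUESTION:" are stripped
print(extract_question("2. [QUESTION type: true/false] Is water wet?"))
# (True, 'Is water wet?')  -- bracketed metadata before the question is skipped
print(extract_question("[QUESTION: How can machine learning be used?]"))
# (False, '')  -- the optional bracket group swallows a fully bracketed question,
#                leaving an empty capture (an unsupported format; see docstring below)
print(extract_answer("ANSWER: Reductions in risk factors."))
# (True, 'Reductions in risk factors.')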


def _raw_question_and_answer_extractor(whole_text: str) -> List[Dict[str, str]] | None:
"""
Extracts questions and answers from the raw text generated by the large language model.

@param whole_text: the raw questions and answers generated by the large language model, eg:
"1. QUESTION: Can you summarize the .. ?
ANSWER: Population imaging studies generated .."

Algorithm overview:

1. Loop over all lines in the text.
2. When we find a question, capture the question into a variable and set a state flag
3. When we find an answer, capture the answer into a variable and save the QA pair
4. When we run out of lines, return the list of QA pairs

Supported formats are documented in the unit tests for this function:
tests/datasets/reading_comprehension_generation/test_utils.py

Unsupported formats:

Format 1: A question where the full question is in brackets

[QUESTION: How can machine learning used in radiation oncology?]

Format 2: A question which does not contain the keyword "Question":

2. [type: true/false] Is the following sentence true? ... clinical trials.

Format 3: A question and answer on the same line:

2. [QUESTION: What benefits of physical activity? ANSWER: Reductions in risk factors.]
"""

task_regex = r"^\*?\*?task\s*\d*"

cur_qa_pair = {}
qa_pairs = []

state_waiting_for_question = "waiting_for_question"
state_waiting_for_answer = "waiting_for_answer"
state = state_waiting_for_question

text_lines = whole_text.split("\n")
for i in text_lines:
raw_text = i.strip()
text = raw_text.lower()
@@ -180,31 +229,101 @@ def question_and_answer_extractor(whole_text: str, context: str) -> List[Dict[str, str]] | None:
if text == "":
continue

# If the line matches the task regex, print a warning. The old code handled
# "tasks", but this new code does not. Need to inspect where these come into play
if re.match(task_regex, text):
logger.warning(f"Found a task line: {text}")

if state == state_waiting_for_question:
is_question, question_text = extract_question(text)
if is_question:
state = state_waiting_for_answer
cur_qa_pair = {"question": question_text, "answer": "TBD"}
continue
elif state == state_waiting_for_answer:
is_answer, answer_text = extract_answer(text)
if is_answer:
state = state_waiting_for_question
cur_qa_pair["answer"] = answer_text
if not cur_qa_pair["question"] or not cur_qa_pair["answer"]:
logger.warning(f"Found a QA pair with an empty question or answer: {cur_qa_pair}. Skipping.")
else:
qa_pairs.append(cur_qa_pair)
else:
# If we're expecting an answer, but the next non-empty line is not an answer,
# something probably went wrong. Print a warning and skip this QA pair.
logger.warning(f"Found a question with no answer: {cur_qa_pair}. Skipping.")
state = state_waiting_for_question

continue
else:
raise ValueError(f"Unknown state while extracting Q&A pairs: {state}")

return qa_pairs
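
A minimal sketch of the state machine in action (the raw text below is invented for illustration; note that matching happens on the lowercased line, so the captured text comes back lowercase):

sample = (
    "1. QUESTION: What is X?\n"
    "ANSWER: X is a thing.\n"
    "\n"
    "2. QUESTION: What is Y?\n"
    "ANSWER: Y is another thing.\n"
)
print(_raw_question_and_answer_extractor(sample))
# [{'question': 'what is x?', 'answer': 'x is a thing.'},
#  {'question': 'what is y?', 'answer': 'y is another thing.'}]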

def convert_qa_pairs_to_chat_completions(qa_pairs: List[Dict[str, str]]) -> List[Dict[str, str]]:
"""
Convert a list of QA pairs into a list of chat completions that can be fed into the large language model.
"""
chat_completions = []
for qa_pair in qa_pairs:
question = qa_pair["question"]
answer = qa_pair["answer"]

question_chat_completion = {
"content": question,
"role": "user",
}

answer_chat_completion = {
"content": answer,
"role": "assistant",
}

chat_completions.append(question_chat_completion)
chat_completions.append(answer_chat_completion)

return chat_completions
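
Each extracted pair becomes two chat messages; for example (pair invented for illustration):

pairs = [{"question": "what is x?", "answer": "x is a thing."}]
print(convert_qa_pairs_to_chat_completions(pairs))
# [{'content': 'what is x?', 'role': 'user'},
#  {'content': 'x is a thing.', 'role': 'assistant'}]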


def question_and_answer_extractor(whole_text: str, context: str) -> List[Dict[str, str]] | None:
"""

Extracts questions and answers from the raw text generated by the large language model.

@param whole_text: the raw questions and answers generated by the large language model, eg:
"1. QUESTION: Can you summarize the .. ?
ANSWER: Population imaging studies generated .."
@param context: the full dataset text that was used to generate the questions and answers, eg:
"Population imaging studies generate data for developing and implementing..."

"""

chat_completion_inputs = []

# Wrap the context with a RAG instruction
context_instruction = wrap_context_with_rag_instruction(context)

# The first chat completion input is the context instruction
first_chat_completion_input = {
"content": context_instruction,
"role": "user",
}
chat_completion_inputs.append(first_chat_completion_input)

# Extract the qa pairs from whole_text
qa_pairs = _raw_question_and_answer_extractor(whole_text)

# If there are no qa pairs, return None
if not qa_pairs:
logger.warning(f"No QA pairs could be generated from whole_text: {whole_text} \n\n and context: {context}")
return None

# Convert the qa pairs to chat completion inputs
qa_pairs_chat_completions = convert_qa_pairs_to_chat_completions(qa_pairs)

# Add the qa pairs chat completions to the result
chat_completion_inputs.extend(qa_pairs_chat_completions)

return chat_completion_inputs
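
End to end, the function returns the RAG-wrapped context followed by alternating user/assistant turns; a sketch with made-up inputs:

gen_text = "1. QUESTION: What is population imaging?\nANSWER: Large-scale imaging of cohorts."
context = "Population imaging studies generate data for developing and implementing..."
completions = question_and_answer_extractor(gen_text, context)
# completions[0]: {'role': 'user', 'content': 'Based on the following text: \n <context>, ...'}
# completions[1]: {'role': 'user', 'content': 'what is population imaging?'}
# completions[2]: {'role': 'assistant', 'content': 'large-scale imaging of cohorts.'}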
30 changes: 30 additions & 0 deletions dalm/pipelines/reading_comprehension_pipeline.py
@@ -34,6 +34,7 @@ class LLMKwargs:
context_length: Optional[int]
dataset_output_path: str
chunk: bool
unprocessed_dataset_output_path: Optional[str] = None

def __post_init__(self) -> None:
if self.chunk and not self.context_length:
@@ -127,6 +128,11 @@ def pipeline(
if not os.path.exists(llm_kwargs.dataset_output_path):
os.makedirs(llm_kwargs.dataset_output_path)

if llm_kwargs.unprocessed_dataset_output_path and not os.path.exists(
llm_kwargs.unprocessed_dataset_output_path
):
os.makedirs(llm_kwargs.unprocessed_dataset_output_path)

llm_rc_dataset_generator = generate_synthetic_dataset(
model_name=llm_kwargs.model_name,
input_directory_or_file=input,
@@ -138,9 +144,26 @@

# generate llm based reading comprehension dataset
for index, text_identifier, context, gen_text in llm_rc_dataset_generator:
logger.info(
f"LLM RC dataset generated text of length {len(gen_text)} from context of length {len(context)}"
)
qanda = question_and_answer_extractor(gen_text, context)
if llm_kwargs.unprocessed_dataset_output_path:
output_file = f"{text_identifier}_{index}.json"
logger.info(f"Writing unprocessed LLM output to {output_file}")
unprocessed = {
"context": context,
"gen_text": gen_text,
"qanda": qanda,
"index": index,
"text_identifier": text_identifier,
}
with open(os.path.join(llm_kwargs.unprocessed_dataset_output_path, output_file), "w") as o:
json.dump(unprocessed, o)

if qanda:
output_file = f"{text_identifier}_{index}.json"
logger.info(f"Writing Q & A chat completions of length {len(qanda)} to {output_file}")
with open(os.path.join(llm_kwargs.dataset_output_path, output_file), "w") as o:
json.dump(qanda, o)
else:
@@ -247,6 +270,12 @@ def parse_args() -> argparse.Namespace:
default="llm_dataset",
help="path to save the generated LLM based dataset",
)
parser.add_argument(
"--llm_unprocessed_dataset_output_path",
type=str,
default=None,
help="path to save the raw unprocessed LLM based dataset for debugging purposes",
)
parser.add_argument(
"--general_spm_path",
type=str,
@@ -332,6 +361,7 @@ def main() -> None:
model_name=args.llm_synth_model_name,
context_length=args.llm_synth_model_context_length,
dataset_output_path=args.llm_dataset_output_path,
unprocessed_dataset_output_path=args.llm_unprocessed_dataset_output_path,
chunk=not args.no_chunk,
)

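With the new flag, a pipeline run can keep the raw LLM output alongside the processed dataset. A hypothetical invocation (the output paths here are placeholders, and the model/input arguments the pipeline also requires are omitted):

python dalm/pipelines/reading_comprehension_pipeline.py \
    --llm_dataset_output_path llm_dataset \
    --llm_unprocessed_dataset_output_path llm_dataset_raw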
8 changes: 6 additions & 2 deletions dalm/training/generator_only/trainer.py
@@ -220,8 +220,12 @@ def train_generator(

def prepare_sample_text(example: Dict[str, Any]) -> str:
"""Prepare the text from a sample of the dataset."""
try:
text = tokenizer.apply_chat_template(example["messages"], tokenize=False)
return text
except Exception as e:
logger.exception(f"Error while preparing the text: {e}. Skipping this example: {example}")
return ""

train_dataset, eval_dataset = create_datasets(
dataset_name=dataset_name,
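
For context, a record that prepare_sample_text can render looks like the chat-completion lists produced by question_and_answer_extractor; a minimal sketch (contents invented, and assuming the tokenizer has a chat template configured):

example = {
    "messages": [
        {"role": "user", "content": "Based on the following text: ... answer a few questions"},
        {"role": "user", "content": "what is population imaging?"},
        {"role": "assistant", "content": "large-scale imaging of cohorts."},
    ]
}
# prepare_sample_text(example) applies the tokenizer's chat template and returns
# the rendered string; a malformed record is now logged and skipped
# (prepare_sample_text returns "" instead of raising).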