Da 24/reading comprehension (#74)
* Broken reading-comprehension generators and generator training

* Refactor out shared utils and add domain tokenizer training as an option in direct script run

* Corrections

* Further corrections

* Add q&a extractor as a util and to the pipeline example script

* Util correction and context addition to output

* Revert previous corrections

* ChatML output format for regex-based RC and remove domain keyword

* Generator additions

* More generator corrections and additions

* Add train num epochs as an arg to the generator training script

* Add new dependencies

* 1. json dump when writing to file
2. regex-based-gen now creates a domain tokenizer if neither a domain sentencepiece model nor domain text is explicitly given
3. attempt at pipeline

* Post pipeline test run corrections

* Remove (explicit) use of ConstantLengthDataset

* Lift state management to a higher level and some corrections

* Add option to save dataset as huggingface dataset

* Training script corrections

* Trainer script cleanup and corrections

* Add cloud friendly logger

* Reformatted synthetic dataset generation + corrections

* More formatting for the synth-gen script

* More corrections to synth-gen

* Regex-gen changes and banner addition

* Missing comma

* Type hint all the functions in utils

* Lightly refactor the pipeline

* Address proper negation of flags

* Switch out generator for iterator when type hinting

* Util typing correction

* Correct all linting issues

* Pipeline corrections

* More corrections for output type of generator

* More corrections to the pipeline

* Appeasing the linter for the pipeline code

* Appeasing the linter for llm synth script

* Appeasing the linter for the training script

* Linter based corrections for utils

* More appeasing of the linter and workarounds

* Incorporate csv reading and associated changes

* Unicode decoding revisit

* More fixes

* Forgot to put in replace line

* Better logging and removal of statefile and more corrections

* Add missing general spm input validation line to pipeline script

* More validation lines for pipeline

* More corrections

* Banner correction, corrections

* Start of README.md and add general sentencepiece model to resources

* Add defaults for cli args

* Add more detail to README.md

* Add defaults to function

* Defaults

* README.md for rc pipeline

* transformers version dependency constraint

* alpha -> beta

* Better warning message

* Correct description in README.md

* Stream arg correction for trainer

* Add prompt link to README

* Add general spm to resources

* - Better input content generator (deals with directory of csv(s))
- remove default wandb value
- new logging statement
- change of terminology for logging (files -> texts)

* Vocab size second-try and key error fix

* Correct logging

* Update dalm/pipelines/reading_comprehension_pipeline.py

Co-authored-by: Traun Leyden <[email protected]>

* Update dalm/pipelines/reading_comprehension_pipeline.py

Co-authored-by: Traun Leyden <[email protected]>

* Update dalm/datasets/reading_comprehension_generation/synthetic_based.py

Co-authored-by: Traun Leyden <[email protected]>

* Update dalm/datasets/reading_comprehension_generation/utils.py

Co-authored-by: Traun Leyden <[email protected]>

* Update dalm/pipelines/reading_comprehension_pipeline.py

Co-authored-by: Traun Leyden <[email protected]>

* Update dalm/pipelines/reading_comprehension_pipeline.py

Co-authored-by: Traun Leyden <[email protected]>

* Corrections

* Post linting

* Update README with suggested corrections

* grammar corrections

---------

Co-authored-by: Traun Leyden <[email protected]>
metric-space and tleyden authored Dec 4, 2023
1 parent e6c3d29 commit 567c910
Showing 9 changed files with 2,455 additions and 2 deletions.
49 changes: 49 additions & 0 deletions dalm/datasets/reading_comprehension_generation/README.md
## A note about reading comprehension

This approach to adapting LLMs is based on this [paper](https://arxiv.org/abs/2309.09530) by Microsoft.
The idea espoused by the paper is that generating reading comprehension questions and answers based on the raw corpora,
and training an LLM on said generated dataset, can enhance its domain adaptiveness.

We have two ways of generating reading comprehension data

1. Via regex-based methods that comb the input data for matches and align them into questions and answers
2. Via prompting a large language model to come up with questions and answers

To see the prompt behind LLM-based reading-comprehension dataset generation, please go [here](https://github.com/arcee-ai/DALM/blob/4d93d4a198cc64ce5d19ee98786b70f579dbef0c/dalm/datasets/reading_comprehension_generation/synthetic_based.py#L22)

## How to get started

For the input, either a single CSV file or a directory of individual files, each containing raw text, will do.
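
The iteration over these inputs is handled by `input_generator` in `utils.py`, which is not shown in this diff. Purely as an illustrative approximation (function name, CSV handling, and encoding here are assumptions, not the library's actual implementation), it behaves roughly like this:

```python
import os
from typing import Iterator, Tuple

import pandas as pd


def iter_raw_texts(path: str, csv_column: str = "text") -> Iterator[Tuple[str, str]]:
    """Yield (name, text) pairs from a CSV file or a directory of plain-text files.

    Rough sketch of what dalm's `input_generator` does; the real helper may differ.
    """
    if os.path.isfile(path):
        # Single CSV file: one text per row, taken from the given column.
        for i, row in pd.read_csv(path).iterrows():
            yield f"{os.path.basename(path)}_{i}", str(row[csv_column])
    else:
        # Directory: each file is treated as one raw text.
        for name in os.listdir(path):
            with open(os.path.join(path, name), encoding="utf-8") as f:
                yield name, f.read()
```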


### LLM based

Assuming you have your dataset as a CSV file with a column `text` containing the raw texts:

(Chunking based on the context length of the model is enabled by default.)

```bash
python dalm/datasets/reading_comprehension_generation/synthetic_based.py \
--model_name HuggingFaceH4/zephyr-7b-alpha \
--context_length 4192 \
--input input_dataset.csv --csv_column text \
--output_directory synth_data --dataset_name llm_generated_dataset
```

The output directory serves as a temporary holding place for all generated data before it is made into a dataset.
The generation process is time-consuming and expensive. Because the process uses an LLM, it takes on average about 20-30 minutes to produce 10 questions if using the recommended 13B Llama 2 model (numbers may vary depending on the content of your dataset and the unpredictability of the model). So every step is taken to ensure that, if the process is interrupted, it will pick up where it left off once restarted: progress is tracked in a state file (`rc_generation_state.pkl` by default, configurable via `--state_file`).

Chunking of data is enabled by default and requires the context length to be passed, which is why it is supplied in the example above.
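
To give a sense of how the per-chunk budget is derived, here is a minimal sketch mirroring what `generate_synthetic_dataset` in this commit does: the token count of the empty prompt template is subtracted from the context length, and the remainder is available for raw text. The tokenizer name and system message below are illustrative only; the script builds this from its own text-generation pipeline.

```python
from transformers import AutoTokenizer

# Illustrative tokenizer; the script uses the tokenizer attached to its generation pipeline.
tokenizer = AutoTokenizer.from_pretrained("HuggingFaceH4/zephyr-7b-beta")

# Render the chat prompt with empty user text to measure the fixed prompt overhead in tokens.
empty_prompt = tokenizer.apply_chat_template(
    [{"role": "system", "content": "..."}, {"role": "user", "content": ""}],
    tokenize=False,
    add_generation_prompt=True,
)
prompt_overhead = len(tokenizer(empty_prompt)["input_ids"])

context_length = 4096
chunk_budget = context_length - prompt_overhead  # max tokens of raw text per chunk
print(chunk_budget)
```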

### Regex based

(Same as above, i.e. assuming you have your dataset as a CSV file with a column `text` containing the raw texts.)

Please note that you may additionally pass in a domain sentencepiece model, but this is not required, as
the script will otherwise train a domain-specific sentencepiece model on the input corpus.
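
If you are curious what training such a domain sentencepiece model involves, it is roughly the following (a sketch only; the vocabulary size, file names, and model type here are illustrative assumptions, not the actual parameters used by `regex_based.py`):

```python
import sentencepiece as spm

# Train a domain-specific sentencepiece model on the raw corpus.
# vocab_size, model_prefix and model_type are illustrative, not the script's defaults.
spm.SentencePieceTrainer.train(
    input="domain_corpus.txt",   # one sentence/document per line
    model_prefix="domain",       # writes domain.model and domain.vocab
    vocab_size=8000,
    model_type="unigram",
)

# Load the trained model and tokenize a sample sentence.
sp = spm.SentencePieceProcessor(model_file="domain.model")
print(sp.encode("an example domain sentence", out_type=str))
```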

```bash

python dalm/datasets/reading_comprehension_generation/regex_based.py --input input.csv \
--csv_column text --general_spm_path resources/general.spm \
--output_dataset_name regex_dataset
```
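
The LLM-based script saves a Hugging Face dataset to disk under the name given by `--dataset_name` (see `Dataset.save_to_disk` in the diff below); assuming the regex-based script does the same with `--output_dataset_name` (its code is not shown here), you can load the result for inspection or training roughly like this:

```python
from datasets import load_from_disk

# "llm_generated_dataset" is the --dataset_name used in the LLM-based example above.
dataset = load_from_disk("llm_generated_dataset")
print(dataset)

# For the LLM-generated dataset, each record holds a "messages" list of chat turns.
print(dataset[0]["messages"][:2])
```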
1,265 changes: 1,265 additions & 0 deletions dalm/datasets/reading_comprehension_generation/regex_based.py

Large diffs are not rendered by default.

223 changes: 223 additions & 0 deletions dalm/datasets/reading_comprehension_generation/synthetic_based.py
import argparse
import json
import logging
import os
import pickle
from typing import Any, Dict, Iterator, List, Optional, Tuple

import torch
from datasets import Dataset
from transformers import Pipeline, pipeline

from dalm.datasets.reading_comprehension_generation.utils import (
input_generator,
question_and_answer_extractor,
text_chunker,
)

logger = logging.getLogger(__name__)

# ruff: noqa: B006

PROMPT = (
"There are 4 types of reading comprehension tasks. "
"The point of reading comprehension tasks is to be assigned a text and questions to "
"prompt answers so as to test conceptual and procedural knowledge present in the text. "
"The four types of reading comprehension tasks are : 1. complete-the-sentence Q&A TASK "
"2.true/false Q&A TASK (description: a sentence is posed and the user is asked to state "
"the correctness of the statement) 3. frame a sentence with domain specific keywords"
"(these keywords are required to be present in the text) Q&A TASK "
"4. Normal questions and answer Task (description: longform Q&A to test procedural and "
"conceptual knowledge). An example of all four tasks given an example text is as follows: "
"\n EXAMPLE TEXT: The insights into the mechanisms of memory consolidation during the sleep "
"processes in human and animal brain led to other biologically inspired approaches. While "
"declarative memories are in the classical picture consolidated by hippocampo-neocortical "
"dialog during NREM phase of sleep, some types of procedural memories were suggested not "
"to rely on the hippocampus and involve REM phase of the sleep. This inspired models where "
"internal representations (memories) created by previous learning are spontaneously replayed "
"during sleep-like periods in the network itself (i.e. without help of secondary network "
"performed by generative replay approaches mentioned above).\n"
"Question: [type: true/false] Is the following sentence true? all types of procedural "
"memories rely on the hippocampus\n"
"Answer: False. The text clearly states there are some types of procedural memories not "
"reliant on the hippocampus\n--------\n"
"Question [type: complete-the-sentence] Complete the following sentence: The insights into "
"____ in human and animal brain led to other _____ approaches\n"
"Answer: The insights into the mechanisms of memory consolidation during the sleep processes "
"in human and animal brain led to other biologically inspired approaches\n------\n"
"Question [type 3 domain-keywords] Make a sentence with the following keywords "
"'hippocampo-neocortical', 'declarative' and 'NREM'\n"
"Answer: declarative memories are in the classical picture consolidated by "
"hippocampo-neocortical dialog during NREM phase of sleep\n-------\n"
"Question [type: normal q&a] Some types of procedural memories were suggested not to rely on "
"the hippocampus and involve REM phase of the sleep. What did this go on to inspire?\n"
"Answer This inspired models where internal representations (memories) created by previous "
"learning are spontaneously replayed during sleep-like periods in the network itself [END OF "
"EXAMPLE]\n\n "
"Similar to the above, could you craft 4 different reading comprehension tasks (make sure "
"your output is a list of question answer pairs and each question is labelled QUESTION and "
"answer is labelled ANSWER and there is one question and answer per task) based solely and "
"completely focused on the following TEXT: "
)


def gen_prompt(text: str) -> List[Dict[str, str]]:
prompt = PROMPT + text

return [
{
"role": "system",
"content": (
"You are a helpful and meticulous instruction following question and answer making chatbot. "
"Please refrain from acknowledgments, additions or niceties of any sort"
),
},
{"role": "user", "content": prompt},
]


def generate_synthetic_data(model_pipeline: Pipeline, text: str, generation_params: Dict[str, Any]) -> str:
    # Build the chat messages, render them with the model's chat template, and generate.
    messages = gen_prompt(text)
    prompt = model_pipeline.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    outputs = model_pipeline(prompt, **generation_params)

    return outputs[0]["generated_text"]


def generate_synthetic_dataset(
model_name: str,
input_directory_or_file: str,
csv_column: Optional[str],
processed_files: List[str],
chunk: bool,
context_length: int,
generation_params: Dict[str, Any] = {
"max_new_tokens": 600,
"do_sample": True,
"temperature": 0.7,
"top_k": 5,
"top_p": 0.95,
"return_full_text": False,
},
) -> Iterator[Tuple[int, str, str, str]]:
model_pipeline = pipeline("text-generation", model=model_name, torch_dtype=torch.bfloat16, device_map="auto")

input_files = input_generator(input_directory_or_file, csv_column)

    if chunk:
        tokenizer = model_pipeline.tokenizer
        # Measure the fixed token overhead of the prompt template (rendered with empty text),
        # so the remaining context budget k can be spent on the raw text chunk itself.
        tokens = tokenizer.apply_chat_template(gen_prompt(""), tokenize=False, add_generation_prompt=True)
        CONSTANT = len(tokenizer(tokens)["input_ids"])
        k = context_length - CONSTANT

for file, text in input_files:
if file in processed_files:
continue

if chunk:
for index, chunk_ in enumerate(text_chunker(text, tokenizer, k)):
gen_text = generate_synthetic_data(model_pipeline, chunk_, generation_params)
yield index, file, chunk_, gen_text
else:
gen_text = generate_synthetic_data(model_pipeline, text, generation_params)
yield 0, file, text, gen_text


def parse_args() -> argparse.Namespace:
parser = argparse.ArgumentParser("Generate synthetic dataset for reading comprehension")
parser.add_argument("--model_name", type=str, default="HuggingFaceH4/zephyr-7b-beta")
parser.add_argument("--input", type=str, required=True, help="Directory containing the input files OR a CSV file")
parser.add_argument("--csv_column", type=str, help="Column to read from the CSV file")
parser.add_argument(
"--output_directory",
type=str,
required=True,
help="Directory to save the generated files (serves as intermediate step and for debugging purposes)",
)
parser.add_argument(
"--state_file",
type=str,
required=False,
default="rc_generation_state.pkl",
help="File to save the state of the generation in order to support resume functionality",
)
parser.add_argument("--context_length", type=int, default=4096, help="context length to calculate the chunk size")
parser.add_argument("--no_chunk", action="store_false")
parser.add_argument(
"--dataset_name", type=str, default="synthetic_rc_dataset", help="name of the dataset to be saved"
)

return parser.parse_args()


def main() -> None:
    """
    The pipeline here includes chunking, generation and parsing of questions and answers into a list of
    exchanges that can be used directly for training.
    """
    args = parse_args()

if os.path.isfile(args.input) and not args.csv_column:
raise ValueError("a CSV column must be specified if the input is a file")

    if args.state_file:
        if os.path.exists(args.state_file):
            with open(args.state_file, "rb") as f:
                state = pickle.load(f)
        else:
            state = {"processed_files": []}
            with open(args.state_file, "wb") as f:
                pickle.dump(state, f)

if not os.path.exists(args.output_directory):
os.makedirs(args.output_directory)

files_missed = 0
total_files = 0

synth_dataset_generator = generate_synthetic_dataset(
model_name=args.model_name,
input_directory_or_file=args.input,
processed_files=state["processed_files"] if args.state_file else [],
chunk=args.no_chunk,
context_length=args.context_length,
csv_column=args.csv_column,
)

for index, filename, context, gen_text in synth_dataset_generator:
state["processed_files"].append(filename)
pickle.dump(state, open(args.state_file, "wb"))
qanda = question_and_answer_extractor(gen_text, context)
if qanda:
output_file = f"{filename}_{index}.json"
with open(os.path.join(args.output_directory, output_file), "w") as o:
json.dump(qanda, o)
else:
            logger.warning(
                f"No question and answer pairs found for {filename}"
                + (f" chunk: {index}" if args.no_chunk else "")
            )
files_missed += 1
total_files += 1

unit = "files" if args.no_chunk else "chunks"

logger.info(" Statistics ")
logger.info(f"Total number of successfully extracted q&a {unit}: {total_files - files_missed}")
logger.info(f"Total {unit} missed: {files_missed} out of {total_files}")

in_memory_dataset = []
for file in os.listdir(args.output_directory):
with open(os.path.join(args.output_directory, file), "r") as f:
in_memory_dataset.append({"messages": json.load(f)})

dataset = Dataset.from_list(in_memory_dataset)
dataset.save_to_disk(args.dataset_name)

logger.info("Done generating synthetic dataset")
logger.info(f"Dataset saved to {args.dataset_name}")

if args.state_file:
os.remove(args.state_file)


if __name__ == "__main__":
main()
