-
Notifications
You must be signed in to change notification settings - Fork 39
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* Broken reading-comprehension generators and generator training * Refactor out shared utils and add domain tokenizer training as a option in direct script tun * Corrections * Further corrections * Add q&a extractor as a util and to the pipeline example script * Util correction and context addition to output * Revert previous corrections * Chatml outputformat for regex based rc and remove domain keyword * Generator additions * More generator corrections and additions * Add train num epochs as an arg to the generator training script * Add new dependencies * 1. json dump when writing to file 2. regex-based-gen now creates a domain tokenizer if both domain sentencepiece model and domain text (explicitly) is not given 3. attempt at pipeline * Post pipeline test run corrections * Remove (explicit) use of ConstantLengthDataset * Lift state management to a higer level and some corrections * Add option to save dataset as huggingface dataset * Training script corrections * Trainer script cleanup and corrections * Add cloud friendly logger * Reformatted synthetic dataset generation + corrections * More formatting for the synth-gen script * More corrections to synth-gen * Regex-gen changes and banner addition * Missing comma * Type hint all the functions in utils * Lightly refactor the pipeline * Address proper negation of lags * Switch out generator for iterator when type hinting * Util typing correction * Correct all linting issues * Pipeline corrections * More corrections for output type of generator * More corrections to the pipeline * Appeasing the linter for the pipeline code * Appeasing the linter for llm synth script * Appeasing the linter for the training script * Linter based corrections for utils * More appeasing of the linter and work arounds * Incorporate csv reading and associated changes * Unicode decoding revisit * More fixes * Forgot to put in replace line * Better logging and removal of statefile and more corrections * Add missing general spm input validation line to pipeline script * More validation lines for pipeline * More corrections * Banner correction, corrections * Start of README.md and add general sentencepiece model to resources * Add defaults for cli args * Add more detail to README.md * Add defaults to function * Defaults * README.md for rc pipeline * transformers version dependency constraint * alpha -> beta * Better warning message * Correct description in README.md * Stream arg correction for trainer * Add prompt link to README * Add general spm to resources * - Better input content generator (deals with directory of csv(s)) - remove default wandb value - new logging statement - change of terminology for logging (files -> texts) * Vocab size second-try and key error fix * Correct logging * Update dalm/pipelines/reading_comprehension_pipeline.py Co-authored-by: Traun Leyden <[email protected]> * Update dalm/pipelines/reading_comprehension_pipeline.py Co-authored-by: Traun Leyden <[email protected]> * Update dalm/datasets/reading_comprehension_generation/synthetic_based.py Co-authored-by: Traun Leyden <[email protected]> * Update dalm/datasets/reading_comprehension_generation/utils.py Co-authored-by: Traun Leyden <[email protected]> * Update dalm/pipelines/reading_comprehension_pipeline.py Co-authored-by: Traun Leyden <[email protected]> * Update dalm/pipelines/reading_comprehension_pipeline.py Co-authored-by: Traun Leyden <[email protected]> * Corrections * Post linting * Update README with suggested corrections * grammar corrections --------- Co-authored-by: Traun Leyden <[email protected]>
- Loading branch information
1 parent
e6c3d29
commit 567c910
Showing
9 changed files
with
2,455 additions
and
2 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,49 @@ | ||
## A note about reading comprehension | ||
|
||
This aproach of adapting LLMs is based on this [paper](https://arxiv.org/abs/2309.09530) by Microsoft | ||
The way espoused by the paper is generating reading comprehension questions and answers based on the raw corpora | ||
and training a llm on said generated dataset can enhance its domain adaptiveness | ||
|
||
We have two ways of generating reading comprehension data | ||
|
||
1. Via regex based methods that combs the input data for match and aligns them into questions and answers | ||
2. Via prompting a large language model to come up with questions and answers | ||
|
||
To see the prompt behind LLM based reading-comprehension dataset generation please go [here](https://github.com/arcee-ai/DALM/blob/4d93d4a198cc64ce5d19ee98786b70f579dbef0c/dalm/datasets/reading_comprehension_generation/synthetic_based.py#L22) | ||
|
||
## How to get started | ||
|
||
For the input, either a single csv file or a directory of individual files each containing raw text will do. | ||
|
||
|
||
### LLM based | ||
|
||
Assuming you have your dataset as a csv file with the column `text` containing the raw texts | ||
|
||
(chunking based on context length of model is enabled by default) | ||
|
||
```bash | ||
python dalm/datasets/reading_comprehension_generation/synthetic_based.py \ | ||
--model HuggingFaceH4/zephyr-7b-alpha \ | ||
--context-length 4192 | ||
--input input_dataset.csv --output_directory synth_data --dataset_name llm_generated_dataset | ||
``` | ||
|
||
the output directory serves as a temporary holding place of all generated data before it can be made a dataset. | ||
The generation process is time consuming and expensive. On average, because the process uses a LLM (if using the recommended 13b llama2 model), it takes about 20-30 minutes to produce 10 questions (numbers may vary depending on the content of your dataset and the unpredictability of the model). So every step is taken to ensure that if the process is interrupted, once back running will pick up where it left off. | ||
|
||
Chunking of data is enabled by default and requires the context length to be passed which is why it passed in in the example | ||
|
||
### Regex based | ||
|
||
(Same, as above i.e assuming you have your dataset as a csv file with the column `text` containing the raw texts) | ||
|
||
Please note there is the choice of passing in a domain sentence model in addition, but this is not required as | ||
the script will train a domain specific sentencepiece model on the input corpus | ||
|
||
```bash | ||
|
||
python dalm/datasets/reading_comprehension_generation/regex_based.py --input input.csv \ | ||
--csv_column text --general_spm_path resources/general.spm \ | ||
--output_dataset_name regex_dataset | ||
``` |
1,265 changes: 1,265 additions & 0 deletions
1,265
dalm/datasets/reading_comprehension_generation/regex_based.py
Large diffs are not rendered by default.
Oops, something went wrong.
223 changes: 223 additions & 0 deletions
223
dalm/datasets/reading_comprehension_generation/synthetic_based.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,223 @@ | ||
import argparse | ||
import json | ||
import logging | ||
import os | ||
import pickle | ||
from typing import Any, Dict, Iterator, List, Optional, Tuple | ||
|
||
import torch | ||
from datasets import Dataset | ||
from transformers import Pipeline, pipeline | ||
|
||
from dalm.datasets.reading_comprehension_generation.utils import ( | ||
input_generator, | ||
question_and_answer_extractor, | ||
text_chunker, | ||
) | ||
|
||
logger = logging.getLogger(__name__) | ||
|
||
# ruff: noqa: B006 | ||
|
||
PROMPT = ( | ||
"There are 4 types of reading comprehension tasks. " | ||
"The point of reading comprehension tasks is to be assigned a text and questions to " | ||
"prompt answers so as to test conceptual and procedural knowledge present in the text. " | ||
"The four types of reading comprehension tasks are : 1. complete-the-sentence Q&A TASK " | ||
"2.true/false Q&A TASK (description: a sentence is posed and the user is asked to state " | ||
"the correctness of the statement) 3. frame a sentence with domain specific keywords" | ||
"(these keywords are required to be present in the text) Q&A TASK " | ||
"4. Normal questions and answer Task (description: longform Q&A to test procedural and " | ||
"conceptual knowledge). An example of all four tasks given an example text is as follows: " | ||
"\n EXAMPLE TEXT: The insights into the mechanisms of memory consolidation during the sleep " | ||
"processes in human and animal brain led to other biologically inspired approaches. While " | ||
"declarative memories are in the classical picture consolidated by hippocampo-neocortical " | ||
"dialog during NREM phase of sleep, some types of procedural memories were suggested not " | ||
"to rely on the hippocampus and involve REM phase of the sleep. This inspired models where " | ||
"internal representations (memories) created by previous learning are spontaneously replayed " | ||
"during sleep-like periods in the network itself (i.e. without help of secondary network " | ||
"performed by generative replay approaches mentioned above).\n" | ||
"Question: [type: true/false] Is the following sentence true? all types of procedural " | ||
"memories rely on the hippocampus\n" | ||
"Answer: False. The text clearly states there are some types of procedural memories not " | ||
"reliant on the hippocampus\n--------\n" | ||
"Question [type: complete-the-sentence] Complete the following sentence: The insights into " | ||
"____ in human and animal brain led to other _____ approaches\n" | ||
"Answer: The insights into the mechanisms of memory consolidation during the sleep processes " | ||
"in human and animal brain led to other biologically inspired approaches\n------\n" | ||
"Question [type 3 domain-keywords] Make a sentence with the following keywords " | ||
"'hippocampo-neocortical', 'declarative' and 'NREM'\n" | ||
"Answer: declarative memories are in the classical picture consolidated by " | ||
"hippocampo-neocortical dialog during NREM phase of sleep\n-------\n" | ||
"Question [type: normal q&a] Some types of procedural memories were suggested not to rely on " | ||
"the hippocampus and involve REM phase of the sleep. What did this go on to inspire?\n" | ||
"Answer This inspired models where internal representations (memories) created by previous " | ||
"learning are spontaneously replayed during sleep-like periods in the network itself [END OF " | ||
"EXAMPLE]\n\n " | ||
"Similar to the above, could you craft 4 different reading comprehension tasks (make sure " | ||
"your output is a list of question answer pairs and each question is labelled QUESTION and " | ||
"answer is labelled ANSWER and there is one question and answer per task) based solely and " | ||
"completely focused on the following TEXT: " | ||
) | ||
|
||
|
||
def gen_prompt(text: str) -> List[Dict[str, str]]: | ||
prompt = PROMPT + text | ||
|
||
return [ | ||
{ | ||
"role": "system", | ||
"content": ( | ||
"You are a helpful and meticulous instruction following question and answer making chatbot. " | ||
"Please refrain from acknowledgments, additions or niceties of any sort" | ||
), | ||
}, | ||
{"role": "user", "content": prompt}, | ||
] | ||
|
||
|
||
def generate_synthetic_data(model_pipeline: Pipeline, text: str, generation_params: Dict[str, Any]) -> str: | ||
prompt = gen_prompt(text) | ||
prompt = model_pipeline.tokenizer.apply_chat_template(prompt, tokenize=False, add_generation_prompt=True) | ||
outputs = model_pipeline(prompt, **generation_params) | ||
|
||
return outputs[0]["generated_text"] | ||
|
||
|
||
def generate_synthetic_dataset( | ||
model_name: str, | ||
input_directory_or_file: str, | ||
csv_column: Optional[str], | ||
processed_files: List[str], | ||
chunk: bool, | ||
context_length: int, | ||
generation_params: Dict[str, Any] = { | ||
"max_new_tokens": 600, | ||
"do_sample": True, | ||
"temperature": 0.7, | ||
"top_k": 5, | ||
"top_p": 0.95, | ||
"return_full_text": False, | ||
}, | ||
) -> Iterator[Tuple[int, str, str, str]]: | ||
model_pipeline = pipeline("text-generation", model=model_name, torch_dtype=torch.bfloat16, device_map="auto") | ||
|
||
input_files = input_generator(input_directory_or_file, csv_column) | ||
|
||
if chunk: | ||
tokenizer = model_pipeline.tokenizer | ||
tokens = tokenizer.apply_chat_template(gen_prompt(""), tokenize=False, add_generation_prompt=True) | ||
CONSTANT = len(tokenizer(tokens)["input_ids"]) | ||
k = context_length - CONSTANT | ||
|
||
for file, text in input_files: | ||
if file in processed_files: | ||
continue | ||
|
||
if chunk: | ||
for index, chunk_ in enumerate(text_chunker(text, tokenizer, k)): | ||
gen_text = generate_synthetic_data(model_pipeline, chunk_, generation_params) | ||
yield index, file, chunk_, gen_text | ||
else: | ||
gen_text = generate_synthetic_data(model_pipeline, text, generation_params) | ||
yield 0, file, text, gen_text | ||
|
||
|
||
def parse_args() -> argparse.Namespace: | ||
parser = argparse.ArgumentParser("Generate synthetic dataset for reading comprehension") | ||
parser.add_argument("--model_name", type=str, default="HuggingFaceH4/zephyr-7b-beta") | ||
parser.add_argument("--input", type=str, required=True, help="Directory containing the input files OR a CSV file") | ||
parser.add_argument("--csv_column", type=str, help="Column to read from the CSV file") | ||
parser.add_argument( | ||
"--output_directory", | ||
type=str, | ||
required=True, | ||
help="Directory to save the generated files (serves as intermediate step and for debugging purposes)", | ||
) | ||
parser.add_argument( | ||
"--state_file", | ||
type=str, | ||
required=False, | ||
default="rc_generation_state.pkl", | ||
help="File to save the state of the generation in order to support resume functionality", | ||
) | ||
parser.add_argument("--context_length", type=int, default=4096, help="context length to calculate the chunk size") | ||
parser.add_argument("--no_chunk", action="store_false") | ||
parser.add_argument( | ||
"--dataset_name", type=str, default="synthetic_rc_dataset", help="name of the dataset to be saved" | ||
) | ||
|
||
return parser.parse_args() | ||
|
||
|
||
def main() -> None: | ||
args = parse_args() | ||
""" | ||
Pipeline here includes chunking, generation and parsing of question and answer into a list of exchanges | ||
that can be used directly for training | ||
""" | ||
|
||
if os.path.isfile(args.input) and not args.csv_column: | ||
raise ValueError("a CSV column must be specified if the input is a file") | ||
|
||
if args.state_file: | ||
if os.path.exists(args.state_file): | ||
with open(args.state_file, "rb") as f: | ||
state = pickle.load(f) | ||
else: | ||
state = {"processed_files": []} | ||
pickle.dump(state, open(args.state_file, "wb")) | ||
|
||
if not os.path.exists(args.output_directory): | ||
os.makedirs(args.output_directory) | ||
|
||
files_missed = 0 | ||
total_files = 0 | ||
|
||
synth_dataset_generator = generate_synthetic_dataset( | ||
model_name=args.model_name, | ||
input_directory_or_file=args.input, | ||
processed_files=state["processed_files"] if args.state_file else [], | ||
chunk=args.no_chunk, | ||
context_length=args.context_length, | ||
csv_column=args.csv_column, | ||
) | ||
|
||
for index, filename, context, gen_text in synth_dataset_generator: | ||
state["processed_files"].append(filename) | ||
pickle.dump(state, open(args.state_file, "wb")) | ||
qanda = question_and_answer_extractor(gen_text, context) | ||
if qanda: | ||
output_file = f"{filename}_{index}.json" | ||
with open(os.path.join(args.output_directory, output_file), "w") as o: | ||
json.dump(qanda, o) | ||
else: | ||
logger.warning( | ||
(f"No question and answer pairs found for {filename} " f"chunk: {index}" if not args.no_chunk else "") | ||
) | ||
files_missed += 1 | ||
total_files += 1 | ||
|
||
unit = "files" if args.no_chunk else "chunks" | ||
|
||
logger.info(" Statistics ") | ||
logger.info(f"Total number of successfully extracted q&a {unit}: {total_files - files_missed}") | ||
logger.info(f"Total {unit} missed: {files_missed} out of {total_files}") | ||
|
||
in_memory_dataset = [] | ||
for file in os.listdir(args.output_directory): | ||
with open(os.path.join(args.output_directory, file), "r") as f: | ||
in_memory_dataset.append({"messages": json.load(f)}) | ||
|
||
dataset = Dataset.from_list(in_memory_dataset) | ||
dataset.save_to_disk(args.dataset_name) | ||
|
||
logger.info("Done generating synthetic dataset") | ||
logger.info(f"Dataset saved to {args.dataset_name}") | ||
|
||
if args.state_file: | ||
os.remove(args.state_file) | ||
|
||
|
||
if __name__ == "__main__": | ||
main() |
Oops, something went wrong.