Da 24/reading comprehension #74
Merged
74 commits:
Broken reading-comprehension generators and generator training
metric-space ce4a23f
Refactor out shared utils and add domain tokenizer training as a opti…
metric-space 9987c06
Corrections
metric-space b842667
Further corrections
metric-space f6d253a
Add q&a extractor as a util and to the pipeline example script
metric-space 571b41c
Util correction and context addition to output
metric-space 5c349ca
Revert previous corrections
metric-space f11797b
Chatml outputformat for regex based rc and remove domain keyword
metric-space 21b6210
Generator additions
metric-space 3c49970
More generator corrections and additions
metric-space c7b63f6
Add train num epochs as an arg to the generator training script
metric-space 8a9ec67
Add new dependencies
metric-space c0c530f
1. json dump when writing to file
metric-space fcdfa3b
Post pipeline test run corrections
metric-space 0f7b7e0
Remove (explicit) use of ConstantLengthDataset
metric-space 7515b0a
Lift state management to a higer level and some corrections
metric-space 75c7811
Add option to save dataset as huggingface dataset
metric-space 979408c
Training script corrections
metric-space 42f77f5
Trainer script cleanup and corrections
metric-space de7f60c
Add cloud friendly logger
metric-space ead9e7b
Reformatted synthetic dataset generation + corrections
metric-space 3efa61b
More formatting for the synth-gen script
metric-space 1d39491
More corrections to synth-gen
metric-space 3abc635
Regex-gen changes and banner addition
metric-space 980cc6a
Missing comma
metric-space 1868423
Type hint all the functions in utils
metric-space ce9b352
Lightly refactor the pipeline
metric-space 3760fc5
Address proper negation of lags
metric-space fe85703
Switch out generator for iterator when type hinting
metric-space 331eaed
Util typing correction
metric-space 8c1a35a
Correct all linting issues
metric-space e77d273
Pipeline corrections
metric-space 2f91c25
More corrections for output type of generator
metric-space 3b8e270
More corrections to the pipeline
metric-space a37b5d3
Appeasing the linter for the pipeline code
metric-space 280e2dd
Appeasing the linter for llm synth script
metric-space 3facf92
Appeasing the linter for the training script
metric-space 3f50761
Linter based corrections for utils
metric-space 8bb10c4
More appeasing of the linter and work arounds
metric-space 42330eb
Incorporate csv reading and associated changes
metric-space 194cf91
Unicode decoding revisit
metric-space 451015b
More fixes
metric-space 2e3fd08
Forgot to put in replace line
metric-space 9a430e5
Better logging and removal of statefile and more corrections
metric-space 3eb3b44
Add missing general spm input validation line to pipeline script
metric-space 14f65c5
More validation lines for pipeline
metric-space 55f03ef
More corrections
metric-space 4d5802d
Banner correction, corrections
metric-space 3f17238
Start of README.md and add general sentencepiece model to resources
metric-space 88a74c1
Add defaults for cli args
metric-space 4ab1211
Add more detail to README.md
metric-space b6fa0ae
Add defaults to function
metric-space 644c294
Defaults
metric-space 9c2f00d
README.md for rc pipeline
metric-space 1d4ec97
transformers version dependency constraint
metric-space 9171734
alpha -> beta
metric-space 1af81cd
Better warning message
metric-space 3039050
Correct description in README.md
metric-space 4d93d4a
Stream arg correction for trainer
metric-space 5fb1252
Add prompt link to README
metric-space 958f1f4
Add general spm to resources
metric-space 749f11c
- Better input content generator (deals with directory of csv(s))
metric-space 903ba18
Vocab size second-try and key error fix
metric-space bc63323
Correct logging
metric-space a6f0e6e
Update dalm/pipelines/reading_comprehension_pipeline.py
metric-space 0b89e1d
Update dalm/pipelines/reading_comprehension_pipeline.py
metric-space 340b969
Update dalm/datasets/reading_comprehension_generation/synthetic_based.py
metric-space 9ae906d
Update dalm/datasets/reading_comprehension_generation/utils.py
metric-space 1d3ed52
Update dalm/pipelines/reading_comprehension_pipeline.py
metric-space a7ab91c
Update dalm/pipelines/reading_comprehension_pipeline.py
metric-space 9cb16ee
Corrections
metric-space 4025e45
Post linting
metric-space 39b7f1e
Update README with suggested corrections
metric-space 81a2ea0
grammar corrections
metric-space
## A note about reading comprehension

This approach to adapting LLMs is based on this [paper](https://arxiv.org/abs/2309.09530) by Microsoft.
The method espoused by the paper is to generate reading comprehension questions and answers from the raw corpora;
training an LLM on the resulting dataset can enhance its domain adaptiveness.
We have two ways of generating reading comprehension data:

1. Via regex-based methods that comb the input data for matches and align them into questions and answers
2. Via prompting a large language model to come up with questions and answers

To see the prompt behind LLM-based reading-comprehension dataset generation, please go [here](https://github.com/arcee-ai/DALM/blob/4d93d4a198cc64ce5d19ee98786b70f579dbef0c/dalm/datasets/reading_comprehension_generation/synthetic_based.py#L22)
## How to get started

For the input, either a single CSV file or a directory of individual files, each containing raw text, will do.
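That input contract (one CSV column, or one raw-text file per document) can be sketched roughly as follows. This is a simplified, hypothetical stand-in for the repo's `input_generator` utility in `utils.py`; the function name and exact behaviour here are assumptions, not the library's API:

```python
import csv
import os
from pathlib import Path
from typing import Iterator, Optional, Tuple


def iter_input_texts(path: str, csv_column: Optional[str] = None) -> Iterator[Tuple[str, str]]:
    """Yield (name, raw_text) pairs from a CSV file or a directory of text files.

    Simplified stand-in for the pipeline's input_generator utility.
    """
    if os.path.isfile(path):
        if csv_column is None:
            # mirrors the validation the pipeline performs for CSV input
            raise ValueError("a CSV column must be specified if the input is a file")
        with open(path, newline="") as f:
            for i, row in enumerate(csv.DictReader(f)):
                yield f"row_{i}", row[csv_column]
    else:
        # a directory: every file is treated as one raw-text document
        for file in sorted(Path(path).iterdir()):
            if file.is_file():
                yield file.name, file.read_text()
```

Either shape of input ends up as the same stream of named texts, which is what the generation scripts iterate over.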
### LLM based
Assuming you have your dataset as a CSV file with the column `text` containing the raw texts
(chunking based on the context length of the model is enabled by default):

```bash
python dalm/datasets/reading_comprehension_generation/synthetic_based.py \
  --model_name HuggingFaceH4/zephyr-7b-alpha \
  --context_length 4192 \
  --input input_dataset.csv --output_directory synth_data --dataset_name llm_generated_dataset
```
The output directory serves as a temporary holding place for all generated data before it is made into a dataset.
The generation process is time consuming and expensive because it uses an LLM. If using the recommended 13b llama2 model, it takes about 20-30 minutes to produce 10 questions (numbers may vary depending on the content of your dataset and the unpredictability of the model). So every step is taken to ensure that if the process is interrupted, it will pick up where it left off once restarted.
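The resume behaviour works by checkpointing a small state object to disk after each unit of work (the script itself does this with a pickle file controlled by `--state_file`). A minimal sketch of the pattern, with hypothetical names:

```python
import pickle
from pathlib import Path
from typing import Callable, Iterable


def process_with_resume(items: Iterable[str], work: Callable[[str], None], state_file: str) -> None:
    """Skip items already recorded as done in state_file; checkpoint after each item."""
    path = Path(state_file)
    state = pickle.loads(path.read_bytes()) if path.exists() else {"processed_files": []}
    for item in items:
        if item in state["processed_files"]:
            continue  # already handled before the interruption
        work(item)
        state["processed_files"].append(item)
        # checkpoint immediately, so a crash here loses at most one item's work
        path.write_bytes(pickle.dumps(state))
```

Re-running the same command after an interruption therefore repeats only the item that was in flight, not the hours of generation that already completed.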
Chunking of data is enabled by default and requires the context length to be passed, which is why it is passed in the example above.
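The reason the context length is needed: the script first measures how many tokens the fixed instruction prompt consumes, then chunks the raw text into whatever budget remains of the model's context window. A rough sketch of that arithmetic, using whitespace "tokens" as a stand-in for the real tokenizer (the actual `text_chunker` utility in the repo works on tokenizer output):

```python
from typing import List


def chunk_for_context(text: str, context_length: int, prompt_tokens: int) -> List[str]:
    """Split text into pieces that fit the context window left over after the prompt."""
    budget = context_length - prompt_tokens  # tokens available for the raw text
    words = text.split()
    # greedy fixed-size chunks; each chunk fits alongside the instruction prompt
    return [" ".join(words[i : i + budget]) for i in range(0, len(words), budget)]
```

With a 4096-token window and a prompt that costs, say, 700 tokens, each chunk gets roughly 3400 tokens of raw text.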
### Regex based

(Same as above, i.e. assuming you have your dataset as a CSV file with the column `text` containing the raw texts.)

Please note there is the option of additionally passing in a domain sentencepiece model, but this is not required, as
the script will train a domain-specific sentencepiece model on the input corpus.

```bash
python dalm/datasets/reading_comprehension_generation/regex_based.py --input input.csv \
  --csv_column text --general_spm_path resources/general.spm \
  --output_dataset_name regex_dataset
```
New files:

dalm/datasets/reading_comprehension_generation/regex_based.py (1,265 additions; large diff not rendered)

dalm/datasets/reading_comprehension_generation/synthetic_based.py (223 additions):
```python
import argparse
import json
import logging
import os
import pickle
from typing import Any, Dict, Iterator, List, Optional, Tuple

import torch
from datasets import Dataset
from transformers import Pipeline, pipeline

from dalm.datasets.reading_comprehension_generation.utils import (
    input_generator,
    question_and_answer_extractor,
    text_chunker,
)

logger = logging.getLogger(__name__)

# ruff: noqa: B006

PROMPT = (
    "There are 4 types of reading comprehension tasks. "
    "The point of reading comprehension tasks is to be assigned a text and questions to "
    "prompt answers so as to test conceptual and procedural knowledge present in the text. "
    "The four types of reading comprehension tasks are : 1. complete-the-sentence Q&A TASK "
    "2.true/false Q&A TASK (description: a sentence is posed and the user is asked to state "
    "the correctness of the statement) 3. frame a sentence with domain specific keywords"
    "(these keywords are required to be present in the text) Q&A TASK "
    "4. Normal questions and answer Task (description: longform Q&A to test procedural and "
    "conceptual knowledge). An example of all four tasks given an example text is as follows: "
    "\n EXAMPLE TEXT: The insights into the mechanisms of memory consolidation during the sleep "
    "processes in human and animal brain led to other biologically inspired approaches. While "
    "declarative memories are in the classical picture consolidated by hippocampo-neocortical "
    "dialog during NREM phase of sleep, some types of procedural memories were suggested not "
    "to rely on the hippocampus and involve REM phase of the sleep. This inspired models where "
    "internal representations (memories) created by previous learning are spontaneously replayed "
    "during sleep-like periods in the network itself (i.e. without help of secondary network "
    "performed by generative replay approaches mentioned above).\n"
    "Question: [type: true/false] Is the following sentence true? all types of procedural "
    "memories rely on the hippocampus\n"
    "Answer: False. The text clearly states there are some types of procedural memories not "
    "reliant on the hippocampus\n--------\n"
    "Question [type: complete-the-sentence] Complete the following sentence: The insights into "
    "____ in human and animal brain led to other _____ approaches\n"
    "Answer: The insights into the mechanisms of memory consolidation during the sleep processes "
    "in human and animal brain led to other biologically inspired approaches\n------\n"
    "Question [type 3 domain-keywords] Make a sentence with the following keywords "
    "'hippocampo-neocortical', 'declarative' and 'NREM'\n"
    "Answer: declarative memories are in the classical picture consolidated by "
    "hippocampo-neocortical dialog during NREM phase of sleep\n-------\n"
    "Question [type: normal q&a] Some types of procedural memories were suggested not to rely on "
    "the hippocampus and involve REM phase of the sleep. What did this go on to inspire?\n"
    "Answer This inspired models where internal representations (memories) created by previous "
    "learning are spontaneously replayed during sleep-like periods in the network itself [END OF "
    "EXAMPLE]\n\n "
    "Similar to the above, could you craft 4 different reading comprehension tasks (make sure "
    "your output is a list of question answer pairs and each question is labelled QUESTION and "
    "answer is labelled ANSWER and there is one question and answer per task) based solely and "
    "completely focused on the following TEXT: "
)


def gen_prompt(text: str) -> List[Dict[str, str]]:
    prompt = PROMPT + text

    return [
        {
            "role": "system",
            "content": (
                "You are a helpful and meticulous instruction following question and answer making chatbot. "
                "Please refrain from acknowledgments, additions or niceties of any sort"
            ),
        },
        {"role": "user", "content": prompt},
    ]


def generate_synthetic_data(model_pipeline: Pipeline, text: str, generation_params: Dict[str, Any]) -> str:
    prompt = gen_prompt(text)
    prompt = model_pipeline.tokenizer.apply_chat_template(prompt, tokenize=False, add_generation_prompt=True)
    outputs = model_pipeline(prompt, **generation_params)

    return outputs[0]["generated_text"]


def generate_synthetic_dataset(
    model_name: str,
    input_directory_or_file: str,
    csv_column: Optional[str],
    processed_files: List[str],
    chunk: bool,
    context_length: int,
    generation_params: Dict[str, Any] = {
        "max_new_tokens": 600,
        "do_sample": True,
        "temperature": 0.7,
        "top_k": 5,
        "top_p": 0.95,
        "return_full_text": False,
    },
) -> Iterator[Tuple[int, str, str, str]]:
    model_pipeline = pipeline("text-generation", model=model_name, torch_dtype=torch.bfloat16, device_map="auto")

    input_files = input_generator(input_directory_or_file, csv_column)

    if chunk:
        tokenizer = model_pipeline.tokenizer
        tokens = tokenizer.apply_chat_template(gen_prompt(""), tokenize=False, add_generation_prompt=True)
        # number of tokens the fixed instruction prompt consumes on its own
        CONSTANT = len(tokenizer(tokens)["input_ids"])
        k = context_length - CONSTANT

    for file, text in input_files:
        if file in processed_files:
            continue

        if chunk:
            for index, chunk_ in enumerate(text_chunker(text, tokenizer, k)):
                gen_text = generate_synthetic_data(model_pipeline, chunk_, generation_params)
                yield index, file, chunk_, gen_text
        else:
            gen_text = generate_synthetic_data(model_pipeline, text, generation_params)
            yield 0, file, text, gen_text


def parse_args() -> argparse.Namespace:
    parser = argparse.ArgumentParser("Generate synthetic dataset for reading comprehension")
    parser.add_argument("--model_name", type=str, default="HuggingFaceH4/zephyr-7b-beta")
    parser.add_argument("--input", type=str, required=True, help="Directory containing the input files OR a CSV file")
    parser.add_argument("--csv_column", type=str, help="Column to read from the CSV file")
    parser.add_argument(
        "--output_directory",
        type=str,
        required=True,
        help="Directory to save the generated files (serves as intermediate step and for debugging purposes)",
    )
    parser.add_argument(
        "--state_file",
        type=str,
        required=False,
        default="rc_generation_state.pkl",
        help="File to save the state of the generation in order to support resume functionality",
    )
    parser.add_argument("--context_length", type=int, default=4096, help="context length to calculate the chunk size")
    parser.add_argument("--no_chunk", action="store_false")
    parser.add_argument(
        "--dataset_name", type=str, default="synthetic_rc_dataset", help="name of the dataset to be saved"
    )

    return parser.parse_args()


def main() -> None:
    """
    Pipeline here includes chunking, generation and parsing of question and answer into a list of exchanges
    that can be used directly for training.
    """
    args = parse_args()

    if os.path.isfile(args.input) and not args.csv_column:
        raise ValueError("a CSV column must be specified if the input is a file")

    if args.state_file:
        if os.path.exists(args.state_file):
            with open(args.state_file, "rb") as f:
                state = pickle.load(f)
        else:
            state = {"processed_files": []}
            with open(args.state_file, "wb") as f:
                pickle.dump(state, f)

    if not os.path.exists(args.output_directory):
        os.makedirs(args.output_directory)

    files_missed = 0
    total_files = 0

    synth_dataset_generator = generate_synthetic_dataset(
        model_name=args.model_name,
        input_directory_or_file=args.input,
        processed_files=state["processed_files"] if args.state_file else [],
        chunk=args.no_chunk,
        context_length=args.context_length,
        csv_column=args.csv_column,
    )

    for index, filename, context, gen_text in synth_dataset_generator:
        state["processed_files"].append(filename)
        with open(args.state_file, "wb") as f:
            pickle.dump(state, f)
        qanda = question_and_answer_extractor(gen_text, context)
        if qanda:
            output_file = f"{filename}_{index}.json"
            with open(os.path.join(args.output_directory, output_file), "w") as o:
                json.dump(qanda, o)
        else:
            logger.warning(f"No question and answer pairs found for {filename} chunk: {index}")
            files_missed += 1
        total_files += 1

    unit = "files" if args.no_chunk else "chunks"

    logger.info(" Statistics ")
    logger.info(f"Total number of successfully extracted q&a {unit}: {total_files - files_missed}")
    logger.info(f"Total {unit} missed: {files_missed} out of {total_files}")

    in_memory_dataset = []
    for file in os.listdir(args.output_directory):
        with open(os.path.join(args.output_directory, file), "r") as f:
            in_memory_dataset.append({"messages": json.load(f)})

    dataset = Dataset.from_list(in_memory_dataset)
    dataset.save_to_disk(args.dataset_name)

    logger.info("Done generating synthetic dataset")
    logger.info(f"Dataset saved to {args.dataset_name}")

    if args.state_file:
        os.remove(args.state_file)


if __name__ == "__main__":
    main()
```
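The `question_and_answer_extractor` utility called above is imported from `utils.py` and not shown in this diff. As a rough illustration of the kind of parsing it must do (the real implementation likely differs), here is a minimal, hypothetical extractor for the QUESTION/ANSWER labels the prompt requests, emitting chat-style message lists like the JSON files the script saves:

```python
import re
from typing import Dict, List, Optional


def extract_qa_pairs(gen_text: str, context: str) -> Optional[List[Dict[str, str]]]:
    """Parse 'QUESTION ... ANSWER ...' pairs out of generated text into chat messages.

    Hypothetical simplified version of the pipeline's question_and_answer_extractor.
    """
    # each pair runs from a QUESTION label to its ANSWER, ending at the next QUESTION or end of text
    pattern = re.compile(
        r"QUESTION:?\s*(.*?)\s*ANSWER:?\s*(.*?)(?=QUESTION|\Z)",
        re.DOTALL | re.IGNORECASE,
    )
    pairs = pattern.findall(gen_text)
    if not pairs:
        return None  # the caller counts this as a missed file/chunk
    messages: List[Dict[str, str]] = [
        {"role": "system", "content": f"Based on the following text: {context}"}
    ]
    for question, answer in pairs:
        messages.append({"role": "user", "content": question.strip()})
        messages.append({"role": "assistant", "content": answer.strip()})
    return messages
```

Each resulting list is what ends up in a `{"messages": ...}` row of the saved Hugging Face dataset, ready for chat-format fine-tuning.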
Review comment: As a user, how should I decide which approach to use? Here's a stab: