Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Da 24/reading comprehension #74

Merged
merged 74 commits into from
Dec 4, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
74 commits
Select commit Hold shift + click to select a range
298e100
Broken reading-comprehension generators and generator training
metric-space Nov 15, 2023
ce4a23f
Refactor out shared utils and add domain tokenizer training as a opti…
metric-space Nov 16, 2023
9987c06
Corrections
metric-space Nov 16, 2023
b842667
Further corrections
metric-space Nov 16, 2023
f6d253a
Add q&a extractor as a util and to the pipeline example script
metric-space Nov 16, 2023
571b41c
Util correction and context addition to output
metric-space Nov 16, 2023
5c349ca
Revert previous corrections
metric-space Nov 16, 2023
f11797b
Chatml outputformat for regex based rc and remove domain keyword
metric-space Nov 17, 2023
21b6210
Generator additions
metric-space Nov 17, 2023
3c49970
More generator corrections and additions
metric-space Nov 17, 2023
c7b63f6
Add train num epochs as an arg to the generator training script
metric-space Nov 17, 2023
8a9ec67
Add new dependencies
metric-space Nov 18, 2023
c0c530f
1. json dump when writing to file
metric-space Nov 18, 2023
fcdfa3b
Post pipeline test run corrections
metric-space Nov 19, 2023
0f7b7e0
Remove (explicit) use of ConstantLengthDataset
metric-space Nov 21, 2023
7515b0a
Lift state management to a higer level and some corrections
metric-space Nov 21, 2023
75c7811
Add option to save dataset as huggingface dataset
metric-space Nov 21, 2023
979408c
Training script corrections
metric-space Nov 21, 2023
42f77f5
Trainer script cleanup and corrections
metric-space Nov 21, 2023
de7f60c
Add cloud friendly logger
metric-space Nov 21, 2023
ead9e7b
Reformatted synthetic dataset generation + corrections
metric-space Nov 21, 2023
3efa61b
More formatting for the synth-gen script
metric-space Nov 21, 2023
1d39491
More corrections to synth-gen
metric-space Nov 21, 2023
3abc635
Regex-gen changes and banner addition
metric-space Nov 21, 2023
980cc6a
Missing comma
metric-space Nov 21, 2023
1868423
Type hint all the functions in utils
metric-space Nov 21, 2023
ce9b352
Lightly refactor the pipeline
metric-space Nov 21, 2023
3760fc5
Address proper negation of lags
metric-space Nov 21, 2023
fe85703
Switch out generator for iterator when type hinting
metric-space Nov 21, 2023
331eaed
Util typing correction
metric-space Nov 21, 2023
8c1a35a
Correct all linting issues
metric-space Nov 22, 2023
e77d273
Pipeline corrections
metric-space Nov 22, 2023
2f91c25
More corrections for output type of generator
metric-space Nov 22, 2023
3b8e270
More corrections to the pipeline
metric-space Nov 22, 2023
a37b5d3
Appeasing the linter for the pipeline code
metric-space Nov 22, 2023
280e2dd
Appeasing the linter for llm synth script
metric-space Nov 22, 2023
3facf92
Appeasing the linter for the training script
metric-space Nov 22, 2023
3f50761
Linter based corrections for utils
metric-space Nov 22, 2023
8bb10c4
More appeasing of the linter and work arounds
metric-space Nov 22, 2023
42330eb
Incorporate csv reading and associated changes
metric-space Nov 22, 2023
194cf91
Unicode decoding revisit
metric-space Nov 22, 2023
451015b
More fixes
metric-space Nov 22, 2023
2e3fd08
Forgot to put in replace line
metric-space Nov 22, 2023
9a430e5
Better logging and removal of statefile and more corrections
metric-space Nov 23, 2023
3eb3b44
Add missing general spm input validation line to pipeline script
metric-space Nov 23, 2023
14f65c5
More validation lines for pipeline
metric-space Nov 23, 2023
55f03ef
More corrections
metric-space Nov 23, 2023
4d5802d
Banner correction, corrections
metric-space Nov 23, 2023
3f17238
Start of README.md and add general sentencepiece model to resources
metric-space Nov 23, 2023
88a74c1
Add defaults for cli args
metric-space Nov 23, 2023
4ab1211
Add more detail to README.md
metric-space Nov 23, 2023
b6fa0ae
Add defaults to function
metric-space Nov 23, 2023
644c294
Defaults
metric-space Nov 23, 2023
9c2f00d
README.md for rc pipeline
metric-space Nov 23, 2023
1d4ec97
transformers version dependency constraint
metric-space Nov 23, 2023
9171734
alpha -> beta
metric-space Nov 23, 2023
1af81cd
Better warning message
metric-space Nov 23, 2023
3039050
Correct description in README.md
metric-space Nov 23, 2023
4d93d4a
Stream arg correction for trainer
metric-space Nov 23, 2023
5fb1252
Add prompt link to README
metric-space Nov 23, 2023
958f1f4
Add general spm to resources
metric-space Nov 27, 2023
749f11c
- Better input content generator (deals with directory of csv(s))
metric-space Nov 29, 2023
903ba18
Vocab size second-try and key error fix
metric-space Nov 29, 2023
bc63323
Correct logging
metric-space Nov 29, 2023
a6f0e6e
Update dalm/pipelines/reading_comprehension_pipeline.py
metric-space Dec 2, 2023
0b89e1d
Update dalm/pipelines/reading_comprehension_pipeline.py
metric-space Dec 2, 2023
340b969
Update dalm/datasets/reading_comprehension_generation/synthetic_based.py
metric-space Dec 2, 2023
9ae906d
Update dalm/datasets/reading_comprehension_generation/utils.py
metric-space Dec 2, 2023
1d3ed52
Update dalm/pipelines/reading_comprehension_pipeline.py
metric-space Dec 2, 2023
a7ab91c
Update dalm/pipelines/reading_comprehension_pipeline.py
metric-space Dec 2, 2023
9cb16ee
Corrections
metric-space Dec 2, 2023
4025e45
Post linting
metric-space Dec 2, 2023
39b7f1e
Update README with suggested corrections
metric-space Dec 2, 2023
81a2ea0
grammar corrections
metric-space Dec 2, 2023
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 49 additions & 0 deletions dalm/datasets/reading_comprehension_generation/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
## A note about reading comprehension

This aproach of adapting LLMs is based on this [paper](https://arxiv.org/abs/2309.09530) by Microsoft
The way espoused by the paper is generating reading comprehension questions and answers based on the raw corpora
and training a llm on said generated dataset can enhance its domain adaptiveness

We have two ways of generating reading comprehension data

1. Via regex based methods that combs the input data for match and aligns them into questions and answers
2. Via prompting a large language model to come up with questions and answers
Comment on lines +9 to +10
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As a user, how should I decide which approach to use? Here's a stab:

  • Use regex based reading comprehension dataset generation when it works on that dataset
  • Otherwise fallback to the slower synthetic data generation approach


To see the prompt behind LLM based reading-comprehension dataset generation please go [here](https://github.com/arcee-ai/DALM/blob/4d93d4a198cc64ce5d19ee98786b70f579dbef0c/dalm/datasets/reading_comprehension_generation/synthetic_based.py#L22)

## How to get started

For the input, either a single csv file or a directory of individual files each containing raw text will do.


### LLM based
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can we please add a one line saying where the users can prompts related to this!


Assuming you have your dataset as a csv file with the column `text` containing the raw texts

(chunking based on context length of model is enabled by default)

```bash
python dalm/datasets/reading_comprehension_generation/synthetic_based.py \
--model HuggingFaceH4/zephyr-7b-alpha \
--context-length 4192
--input input_dataset.csv --output_directory synth_data --dataset_name llm_generated_dataset
```

the output directory serves as a temporary holding place of all generated data before it can be made a dataset.
The generation process is time consuming and expensive. On average, because the process uses a LLM (if using the recommended 13b llama2 model), it takes about 20-30 minutes to produce 10 questions (numbers may vary depending on the content of your dataset and the unpredictability of the model). So every step is taken to ensure that if the process is interrupted, once back running will pick up where it left off.

Chunking of data is enabled by default and requires the context length to be passed which is why it passed in in the example

### Regex based

(Same, as above i.e assuming you have your dataset as a csv file with the column `text` containing the raw texts)

Please note there is the choice of passing in a domain sentence model in addition, but this is not required as
the script will train a domain specific sentencepiece model on the input corpus

```bash

python dalm/datasets/reading_comprehension_generation/regex_based.py --input input.csv \
--csv_column text --general_spm_path resources/general.spm \
--output_dataset_name regex_dataset
```
1,265 changes: 1,265 additions & 0 deletions dalm/datasets/reading_comprehension_generation/regex_based.py

Large diffs are not rendered by default.

223 changes: 223 additions & 0 deletions dalm/datasets/reading_comprehension_generation/synthetic_based.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,223 @@
import argparse
import json
import logging
import os
import pickle
from typing import Any, Dict, Iterator, List, Optional, Tuple

import torch
from datasets import Dataset
from transformers import Pipeline, pipeline

from dalm.datasets.reading_comprehension_generation.utils import (
input_generator,
question_and_answer_extractor,
text_chunker,
)

logger = logging.getLogger(__name__)

# ruff: noqa: B006

PROMPT = (
"There are 4 types of reading comprehension tasks. "
"The point of reading comprehension tasks is to be assigned a text and questions to "
"prompt answers so as to test conceptual and procedural knowledge present in the text. "
"The four types of reading comprehension tasks are : 1. complete-the-sentence Q&A TASK "
"2.true/false Q&A TASK (description: a sentence is posed and the user is asked to state "
"the correctness of the statement) 3. frame a sentence with domain specific keywords"
"(these keywords are required to be present in the text) Q&A TASK "
"4. Normal questions and answer Task (description: longform Q&A to test procedural and "
"conceptual knowledge). An example of all four tasks given an example text is as follows: "
"\n EXAMPLE TEXT: The insights into the mechanisms of memory consolidation during the sleep "
"processes in human and animal brain led to other biologically inspired approaches. While "
"declarative memories are in the classical picture consolidated by hippocampo-neocortical "
"dialog during NREM phase of sleep, some types of procedural memories were suggested not "
"to rely on the hippocampus and involve REM phase of the sleep. This inspired models where "
"internal representations (memories) created by previous learning are spontaneously replayed "
"during sleep-like periods in the network itself (i.e. without help of secondary network "
"performed by generative replay approaches mentioned above).\n"
"Question: [type: true/false] Is the following sentence true? all types of procedural "
"memories rely on the hippocampus\n"
"Answer: False. The text clearly states there are some types of procedural memories not "
"reliant on the hippocampus\n--------\n"
"Question [type: complete-the-sentence] Complete the following sentence: The insights into "
"____ in human and animal brain led to other _____ approaches\n"
"Answer: The insights into the mechanisms of memory consolidation during the sleep processes "
"in human and animal brain led to other biologically inspired approaches\n------\n"
"Question [type 3 domain-keywords] Make a sentence with the following keywords "
"'hippocampo-neocortical', 'declarative' and 'NREM'\n"
"Answer: declarative memories are in the classical picture consolidated by "
"hippocampo-neocortical dialog during NREM phase of sleep\n-------\n"
"Question [type: normal q&a] Some types of procedural memories were suggested not to rely on "
"the hippocampus and involve REM phase of the sleep. What did this go on to inspire?\n"
"Answer This inspired models where internal representations (memories) created by previous "
"learning are spontaneously replayed during sleep-like periods in the network itself [END OF "
"EXAMPLE]\n\n "
"Similar to the above, could you craft 4 different reading comprehension tasks (make sure "
"your output is a list of question answer pairs and each question is labelled QUESTION and "
"answer is labelled ANSWER and there is one question and answer per task) based solely and "
"completely focused on the following TEXT: "
)


def gen_prompt(text: str) -> List[Dict[str, str]]:
prompt = PROMPT + text

return [
{
"role": "system",
"content": (
"You are a helpful and meticulous instruction following question and answer making chatbot. "
"Please refrain from acknowledgments, additions or niceties of any sort"
),
},
{"role": "user", "content": prompt},
]


def generate_synthetic_data(model_pipeline: Pipeline, text: str, generation_params: Dict[str, Any]) -> str:
prompt = gen_prompt(text)
prompt = model_pipeline.tokenizer.apply_chat_template(prompt, tokenize=False, add_generation_prompt=True)
outputs = model_pipeline(prompt, **generation_params)

return outputs[0]["generated_text"]


def generate_synthetic_dataset(
model_name: str,
input_directory_or_file: str,
csv_column: Optional[str],
processed_files: List[str],
chunk: bool,
context_length: int,
tleyden marked this conversation as resolved.
Show resolved Hide resolved
generation_params: Dict[str, Any] = {
"max_new_tokens": 600,
"do_sample": True,
"temperature": 0.7,
"top_k": 5,
"top_p": 0.95,
"return_full_text": False,
},
) -> Iterator[Tuple[int, str, str, str]]:
model_pipeline = pipeline("text-generation", model=model_name, torch_dtype=torch.bfloat16, device_map="auto")

input_files = input_generator(input_directory_or_file, csv_column)

if chunk:
tokenizer = model_pipeline.tokenizer
tokens = tokenizer.apply_chat_template(gen_prompt(""), tokenize=False, add_generation_prompt=True)
CONSTANT = len(tokenizer(tokens)["input_ids"])
k = context_length - CONSTANT

for file, text in input_files:
if file in processed_files:
continue

if chunk:
for index, chunk_ in enumerate(text_chunker(text, tokenizer, k)):
gen_text = generate_synthetic_data(model_pipeline, chunk_, generation_params)
yield index, file, chunk_, gen_text
else:
gen_text = generate_synthetic_data(model_pipeline, text, generation_params)
yield 0, file, text, gen_text


def parse_args() -> argparse.Namespace:
parser = argparse.ArgumentParser("Generate synthetic dataset for reading comprehension")
parser.add_argument("--model_name", type=str, default="HuggingFaceH4/zephyr-7b-beta")
tleyden marked this conversation as resolved.
Show resolved Hide resolved
parser.add_argument("--input", type=str, required=True, help="Directory containing the input files OR a CSV file")
parser.add_argument("--csv_column", type=str, help="Column to read from the CSV file")
parser.add_argument(
"--output_directory",
type=str,
required=True,
help="Directory to save the generated files (serves as intermediate step and for debugging purposes)",
)
parser.add_argument(
"--state_file",
type=str,
required=False,
default="rc_generation_state.pkl",
help="File to save the state of the generation in order to support resume functionality",
)
parser.add_argument("--context_length", type=int, default=4096, help="context length to calculate the chunk size")
parser.add_argument("--no_chunk", action="store_false")
parser.add_argument(
"--dataset_name", type=str, default="synthetic_rc_dataset", help="name of the dataset to be saved"
)

return parser.parse_args()


def main() -> None:
args = parse_args()
"""
Pipeline here includes chunking, generation and parsing of question and answer into a list of exchanges
that can be used directly for training
tleyden marked this conversation as resolved.
Show resolved Hide resolved
"""

if os.path.isfile(args.input) and not args.csv_column:
raise ValueError("a CSV column must be specified if the input is a file")

if args.state_file:
if os.path.exists(args.state_file):
with open(args.state_file, "rb") as f:
state = pickle.load(f)
else:
state = {"processed_files": []}
pickle.dump(state, open(args.state_file, "wb"))

if not os.path.exists(args.output_directory):
os.makedirs(args.output_directory)

files_missed = 0
total_files = 0

synth_dataset_generator = generate_synthetic_dataset(
model_name=args.model_name,
input_directory_or_file=args.input,
processed_files=state["processed_files"] if args.state_file else [],
chunk=args.no_chunk,
context_length=args.context_length,
csv_column=args.csv_column,
)

for index, filename, context, gen_text in synth_dataset_generator:
state["processed_files"].append(filename)
pickle.dump(state, open(args.state_file, "wb"))
qanda = question_and_answer_extractor(gen_text, context)
if qanda:
output_file = f"{filename}_{index}.json"
with open(os.path.join(args.output_directory, output_file), "w") as o:
json.dump(qanda, o)
else:
logger.warning(
(f"No question and answer pairs found for {filename} " f"chunk: {index}" if not args.no_chunk else "")
)
files_missed += 1
total_files += 1

unit = "files" if args.no_chunk else "chunks"

logger.info(" Statistics ")
logger.info(f"Total number of successfully extracted q&a {unit}: {total_files - files_missed}")
logger.info(f"Total {unit} missed: {files_missed} out of {total_files}")

in_memory_dataset = []
for file in os.listdir(args.output_directory):
with open(os.path.join(args.output_directory, file), "r") as f:
in_memory_dataset.append({"messages": json.load(f)})

dataset = Dataset.from_list(in_memory_dataset)
dataset.save_to_disk(args.dataset_name)

logger.info("Done generating synthetic dataset")
logger.info(f"Dataset saved to {args.dataset_name}")

if args.state_file:
os.remove(args.state_file)


if __name__ == "__main__":
main()
Loading