Commit

1. Use json.dump when writing generated output to file
2. The regex-based generator now creates a domain tokenizer from the input files when neither a domain sentencepiece model nor explicit domain tokenizer training text is given
3. Attempt at a pipeline
metric-space committed Nov 18, 2023
1 parent 8a9ec67 commit c0c530f
Showing 4 changed files with 260 additions and 76 deletions.
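
For reference, point 2 of the commit message boils down to the following precedence when choosing a domain tokenizer. This is a minimal sketch of the logic added to regex_based.py below; resolve_domain_tokenizer is a hypothetical wrapper name, not part of the commit, and the helper imports come from utils.py in this diff.

import sentencepiece as spm

from dalm.datasets.reading_comprehension_generation.utils import (
    create_domain_tokenizer,
    create_domain_tokenizer_from_files,
)


def resolve_domain_tokenizer(domain_spm_path, domain_tokenizer_training_text, input_dir):
    # Hypothetical wrapper mirroring the if/elif/else added to regex_based.py's __main__.
    if domain_tokenizer_training_text:
        # Explicit training text wins: train a fresh domain tokenizer from it.
        return create_domain_tokenizer(domain_tokenizer_training_text)
    elif domain_spm_path:
        # Otherwise load the pre-trained domain sentencepiece model.
        return spm.SentencePieceProcessor(model_file=domain_spm_path)
    else:
        # New in this commit: fall back to training a tokenizer from the raw input files.
        return create_domain_tokenizer_from_files(input_dir)
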
99 changes: 53 additions & 46 deletions dalm/datasets/reading_comprehension_generation/regex_based.py
@@ -10,9 +10,9 @@
from tqdm.contrib.concurrent import process_map
from pysbd import Segmenter
import copy
from functools import partial
from dalm.datasets.reading_comprehension_generation.utils import create_domain_tokenizer
from dalm.datasets.reading_comprehension_generation.utils import create_domain_tokenizer, create_domain_tokenizer_from_files
import pprint
from warnings import warn

TYPES = ["nli", "common_reason", "paraphrase", "word2text", "summarize", "text_completion"]

@@ -82,7 +82,6 @@ def fill_in_the_template(self, template, kw_dic):
4. length 2 template and qa_demos
"""
print(template)
qa_demos = kw_dic.get("qa_demos", [])

if "qa_demos" in kw_dic.keys():
@@ -1025,8 +1024,6 @@ def format_recomprehension(self, overall_entry, insert_types=TYPES):
else:
qa_demos = []

print("qa_demos", qa_demos)

def summaize_only(count_dict):
count_dict["summarize"] = 1
count_dict["text_completion"] = 0
@@ -1083,13 +1080,11 @@ def no_summarize_or_completion(count_dict):
# read_func = np.random.choice([summaize_only, no_summarize_or_completion], p=[0.5, 0.5])
if "text_completion" in insert_types and len(overall_entry["text_completion"]["sents"]) >= 2:
np.random.seed(seed)
print("text_completion")
if len(qa_demos) == 0:
read_func = completion_only
else:
read_func = np.random.choice([completion_only, no_summarize_or_completion], p=[0.5, 0.5])
else:
print("no_summarize_or_completion")
read_func = no_summarize_or_completion

return read_func(count_dict)
@@ -1168,76 +1163,88 @@ def generate(self, entry):

return {"read_compre": read_compre_list, "file_name": entry["file_name"]}

def create_dataset(self, input_dir, output_dir, workers=1):
print("loading raw texts in the input folder...")
paths = glob.glob(f"{input_dir}/*")
# print(f'paths: {paths}')

raw_texts = []
# NOTE: use generator here
# NOTE: do I really need TQDM here?
for text_id, path in tqdm(enumerate(paths)):
file_name = path.split("/")[-1]

try:
with open(path, "r", encoding="utf8") as f:
text = f.read().strip()

except UnicodeDecodeError:
with open(path, "r", encoding="utf8", errors="replace") as f:
text = f.read().strip()

raw_texts.append({"text": text, "text_id": text_id, "file_name": file_name})

print("transferring raw texts into reading comprehension...")
read_compre = list(process_map(self.generate, raw_texts, max_workers=workers, chunksize=8192))

print("saving reading comprehension texts...")
# sort by text_id to align with the order of raw texts
for entry in read_compre:
for index, read_compre_example in enumerate(entry["read_compre"]):
file_name = entry["file_name"]
path = os.path.join(output_dir, f"{file_name}_{str(index)}")

with open(path, "w") as f:
f.write(pprint.pformat(read_compre_example))
f.close()


if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--input_dir", type=str, help="directory of the input raw texts", default="./data")
parser.add_argument(
"--output_dir", type=str, help="directory of the output reading comprehension texts", default="./output"
"--output_dir", type=str, help="directory of the output reading comprehension texts", default="./output2"
)
parser.add_argument(
"--ori_spm_path", type=str, help="path of the original sentencepiece model", default="./tokenizers/general.spm"
)
parser.add_argument(
"--domain_spm_path", type=str, help="path of the domain sentencepiece model", default="./tokenizers/domain.spm"
"--domain_spm_path", type=str, help="path of the domain sentencepiece model", # default="./tokenizers/domain.spm"
)
parser.add_argument(
"--domain_tokenizer_training_text",
type=str,
help="path of the domain sentencepiece model",
default="./data/domain_tokenizer_training_text.txt",
)

args = parser.parse_args()

if not (args.domain_spm_path or args.domain_tokenizer_training_text):
raise ValueError("domain_spm_path or domain_tokenizer_training_text should be provided")
# warn user that the domain tokenizer will be created from the input files
warn(
"No domain tokenizer is provided nor explicit file for training domain tokenizer is provided, "
"the domain tokenizer will be created from the input files, "
)

if not args.domain_spm_path:
if args.domain_tokenizer_training_text:
# train domain tokenizer
domain_spm = create_domain_tokenizer(args.domain_tokenizer_training_text)
else:
elif args.domain_spm_path:
domain_spm = spm.SentencePieceProcessor(model_file=args.domain_spm_path)
else:
domain_spm = create_domain_tokenizer_from_files(args.input_dir)

ori_spm = spm.SentencePieceProcessor(model_file=args.ori_spm_path)

# get max workers for multi-processing
max_workers = get_max_workers()
print(f"max_workers: {max_workers}")

# load sentences in the input file
print("loading raw texts in the input folder...")
paths = glob.glob(f"{args.input_dir}/*")
# print(f'paths: {paths}')

raw_texts = []
for text_id, path in tqdm(enumerate(paths)):
file_name = path.split("/")[-1]

try:
with open(path, "r", encoding="utf8") as f:
text = f.read().strip()

except UnicodeDecodeError:
with open(path, "r", encoding="utf8", errors="replace") as f:
text = f.read().strip()

raw_texts.append({"text": text, "text_id": text_id, "file_name": file_name})

rc = RC(ori_spm, domain_spm)
# side effect warning
if not os.path.exists(args.output_dir):
os.makedirs(args.output_dir)

print("transferring raw texts into reading comprehension...")
read_compre = list(process_map(rc.generate, raw_texts, max_workers=max_workers, chunksize=8192))

print("saving reading comprehension texts...")
# sort by text_id to align with the order of raw texts
for entry in read_compre:
for index, read_compre_example in enumerate(entry["read_compre"]):
file_name = entry["file_name"]
path = os.path.join(args.output_dir, f"{file_name}_{str(index)}")

with open(path, "w") as f:
f.write(pprint.pformat(read_compre_example))
f.close()
rc.create_dataset(args.input_dir, args.output_dir, workers=max_workers)

print(f"saved to {args.output_dir}")
print(f"saved to {args.output_dir}")
39 changes: 20 additions & 19 deletions dalm/datasets/reading_comprehension_generation/synthetic_based.py
@@ -3,9 +3,8 @@
from transformers import pipeline, AutoTokenizer
import torch
import pickle
from typing import Optional
from dalm.datasets.reading_comprehension_generation.utils import list_dir, text_chunker, question_and_answer_extractor
import pprint
import json


def gen_prompt(text):
@@ -30,7 +29,7 @@ def generate_synthetic_data(model_pipeline, text, generation_params):
def generate_synthetic_dataset(
model_name,
input_directory,
state_file,
state_file=None,
chunk=False,
context_length=2048,
generation_params={
@@ -52,25 +51,27 @@ def generate_synthetic_dataset(
CONSTANT = len(tokenizer(tokens)["input_ids"])
k = context_length - CONSTANT

if os.path.exists(state_file):
with open(state_file, "rb") as f:
state = pickle.load(f)
else:
state = {"processed_files": []}
pickle.dump(state, open(state_file, "wb"))
if state_file:
if os.path.exists(state_file):
with open(state_file, "rb") as f:
state = pickle.load(f)
else:
state = {"processed_files": []}
pickle.dump(state, open(state_file, "wb"))

for file, text in input_files:
if file in state["processed_files"]:
if state and file in state["processed_files"]:
continue

if chunk:
for chunk_ in text_chunker(text, tokenizer, k):
gen_text = generate_synthetic_data(model_pipeline, chunk_, generation_params)
yield gen_text, chunk_
else:
if chunk:
for chunk_ in text_chunker(text, tokenizer, k):
gen_text = generate_synthetic_data(model_pipeline, chunk_, generation_params)
yield gen_text, chunk_
else:
gen_text = generate_synthetic_data(model_pipeline, text, generation_params)
yield gen_text, text
gen_text = generate_synthetic_data(model_pipeline, text, generation_params)
yield gen_text, text

if state:
state["processed_files"].append(file)
pickle.dump(state, open(state_file, "wb"))

@@ -80,7 +81,7 @@ def generate_synthetic_dataset(
parser.add_argument("--model_name", type=str, required=True)
parser.add_argument("--input_directory", type=str, required=True)
parser.add_argument("--output_directory", type=str, required=True)
parser.add_argument("--state_file", type=str, required=True)
parser.add_argument("--state_file", type=str, required=False, default="rc_generation_state.pkl")
parser.add_argument("--context_length", type=int, default=2048)
parser.add_argument("--chunk", action="store_true")

@@ -100,4 +101,4 @@ def generate_synthetic_dataset(
if qanda:
output_file = f"gen_{index}.txt"
with open(os.path.join(args.output_directory, output_file), "w") as o:
o.write(pprint.pformat(qanda))
json.dump(qanda, o)
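
With the state file optional and the output switched from pprint to JSON, driving the generator looks roughly like the sketch below. The line that builds qanda is not visible in the hunk above, so pairing each generated text with its source chunk via question_and_answer_extractor is an assumption; the model name and paths are placeholders. Note that the generator only initialises its resume state when a state file is given, so the sketch always passes one.

import json
import os

from dalm.datasets.reading_comprehension_generation.synthetic_based import generate_synthetic_dataset
from dalm.datasets.reading_comprehension_generation.utils import question_and_answer_extractor

output_directory = "./rc_output"  # placeholder
os.makedirs(output_directory, exist_ok=True)

generator = generate_synthetic_dataset(
    model_name="HuggingFaceH4/zephyr-7b-beta",  # placeholder model id
    input_directory="./data",
    state_file="rc_generation_state.pkl",  # optional now, but needed to resume across runs
    chunk=True,
    context_length=2048,
)

for index, (gen_text, context) in enumerate(generator):
    # Assumption: the hidden __main__ line extracts Q&A pairs from the generated text.
    qanda = question_and_answer_extractor(gen_text, context)
    if qanda:
        with open(os.path.join(output_directory, f"gen_{index}.txt"), "w") as o:
            # json.dump replaces pprint.pformat so the files are machine-readable.
            json.dump(qanda, o)
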
46 changes: 35 additions & 11 deletions dalm/datasets/reading_comprehension_generation/utils.py
@@ -1,6 +1,7 @@
import os
import tempfile
import re
import sentencepiece as spm

from transformers import AutoTokenizer

@@ -39,7 +40,7 @@ def files_chunker(input_directory, model, context_length, output_directory, prompt):
o.write(chunk)


def create_domain_tokenizer(text):
def create_domain_tokenizer(text_file):
"""
train and return domain tokenizer
"""
@@ -48,12 +49,43 @@ def create_domain_tokenizer(text):
model_prefix = f"{temp_dir}/domain"

# Train the SentencePiece model, the model is saved in the temporary directory
spm.SentencePieceTrainer.train(input=text, model_prefix=model_prefix, vocab_size=32000, character_coverage=1.0)
spm.SentencePieceTrainer.train(input=text_file, model_prefix=model_prefix, vocab_size=32000, character_coverage=1.0)

sp_model_file = f"{model_prefix}.model"
return spm.SentencePieceProcessor(model_file=sp_model_file)


def split_to_sentences(infile):
text = infile.read()
sentences = re.split(r"[.?!]\s+", text)

return sentences


# TODO: revisit the errors part
def create_domain_tokenizer_from_files(directory_with_files):
# open a tempfile and add sentences from files in directory_with_files to it
with tempfile.TemporaryDirectory() as temp_dir:
temp_file = open(os.path.join(temp_dir, "temp.txt"), "w", encoding="utf-8")
for filename in os.listdir(directory_with_files):
try:
with open(os.path.join(directory_with_files, filename), "r", encoding="utf-8") as infile:
sentences = split_to_sentences(infile)

except UnicodeDecodeError:
with open(
os.path.join(directory_with_files, filename), "r", encoding="utf-8", errors="replace"
) as infile:
sentences = split_to_sentences(infile)

for sentence in sentences:
sentence = sentence.strip()
if sentence and sentence != "":
temp_file.write(sentence + "\n")

return create_domain_tokenizer(os.path.join(temp_dir, "temp.txt"))


def fix_first_prompt(text, chat_chain):
# remove the first prompt
first_prompt = chat_chain.pop(0)
@@ -69,6 +101,7 @@ def fix_first_prompt(text, chat_chain):

# TODO: type hinting is very necessary here
# TODO: add test
# TODO: refactor this as a state machine?
def question_and_answer_extractor(whole_text, context):
whole_text = whole_text.split("\n")
question = []
@@ -94,15 +127,6 @@ def question_and_answer_extractor(whole_text, context):
if text == "":
continue

# task regex to match Task 1, task 1 , task
task_regex = r"^\*?\*?task\s*\d*"

# question regex
question_regex = r"^question\s*\d*"

# answer regex
answer_regex = r"^answer\s*\d*"

# if the line start matches the question regex or the task regex
if re.match(question_regex, text) or re.match(task_regex, text):
if answer_context:
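
The utils.py additions give a second way to obtain a domain tokenizer. The sketch below shows how the new sentence splitter behaves and how the directory-level helper is called, assuming the signatures in the diff above; the sample text and directory are placeholders.

import io

from dalm.datasets.reading_comprehension_generation.utils import (
    create_domain_tokenizer_from_files,
    split_to_sentences,
)

# split_to_sentences takes a file-like object and splits on sentence-final
# punctuation followed by whitespace.
sample = io.StringIO("First sentence. Second one! Is this the third? Yes")
print(split_to_sentences(sample))
# -> ['First sentence', 'Second one', 'Is this the third', 'Yes']

# The directory-level helper writes every non-empty sentence from every readable file
# in the directory to a temporary text file, then trains a 32k-vocab SentencePiece
# model on it and returns the resulting SentencePieceProcessor.
domain_spm = create_domain_tokenizer_from_files("./data")  # placeholder directory
print(domain_spm.vocab_size())
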