Lift state management to a higher level and some corrections
metric-space committed Nov 21, 2023
1 parent 0f7b7e0 commit 7515b0a
Showing 2 changed files with 52 additions and 32 deletions.
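
In short, the commit moves checkpointing out of generate_synthetic_dataset and into its callers: the generator now takes the list of already-processed files and yields (filename, context, gen_text), while the caller loads, appends to, and re-pickles the state. A minimal sketch of the new calling pattern (illustrative only; load_state and save_state are hypothetical helpers, and the model name and paths are placeholders borrowed from the pipeline defaults):

import os
import pickle

from dalm.datasets.reading_comprehension_generation.synthetic_based import (
    generate_synthetic_dataset,
)
from dalm.datasets.reading_comprehension_generation.utils import question_and_answer_extractor


def load_state(state_file: str) -> dict:
    # Resume from an earlier run if a state file exists, otherwise start fresh.
    if os.path.exists(state_file):
        with open(state_file, "rb") as f:
            return pickle.load(f)
    return {"processed_files": []}


def save_state(state: dict, state_file: str) -> None:
    with open(state_file, "wb") as f:
        pickle.dump(state, f)


state_file = "generation_state.pkl"  # placeholder path
state = load_state(state_file)

# The generator yields (filename, context, gen_text); persistence happens here,
# outside the dataset module.
for filename, context, gen_text in generate_synthetic_dataset(
    model_name="HuggingFaceH4/zephyr-7b-beta",  # model used in the pipeline example
    input_directory="./data_llm",               # placeholder input directory
    processed_files=state["processed_files"],
    chunk=True,
    context_length=2048,
):
    qanda = question_and_answer_extractor(gen_text, context)  # as in the script's __main__
    state["processed_files"].append(filename)
    save_state(state, state_file)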
42 changes: 22 additions & 20 deletions dalm/datasets/reading_comprehension_generation/synthetic_based.py
@@ -29,7 +29,7 @@ def generate_synthetic_data(model_pipeline, text, generation_params):
def generate_synthetic_dataset(
model_name,
input_directory,
state_file=None,
processed_files=[],
chunk=False,
context_length=2048,
generation_params={
@@ -51,29 +51,17 @@ def generate_synthetic_dataset(
CONSTANT = len(tokenizer(tokens)["input_ids"])
k = context_length - CONSTANT

if state_file:
if os.path.exists(state_file):
with open(state_file, "rb") as f:
state = pickle.load(f)
else:
state = {"processed_files": []}
pickle.dump(state, open(state_file, "wb"))

for file, text in input_files:
if state and file in state["processed_files"]:
if file in processed_files:
continue

if chunk:
for chunk_ in text_chunker(text, tokenizer, k):
gen_text = generate_synthetic_data(model_pipeline, chunk_, generation_params)
yield gen_text, chunk_
yield file, chunk_, gen_text
else:
gen_text = generate_synthetic_data(model_pipeline, text, generation_params)
yield gen_text, text

if state:
state["processed_files"].append(file)
pickle.dump(state, open(state_file, "wb"))
yield file, text, gen_text


if __name__ == "__main__":
@@ -92,13 +80,27 @@ def generate_synthetic_dataset(
that can be used directly for training
"""

for index, (gen_text, context) in enumerate(
if args.state_file:
if os.path.exists(args.state_file):
with open(args.state_file, "rb") as f:
state = pickle.load(f)
else:
state = {"processed_files": []}
pickle.dump(state, open(args.state_file, "wb"))

for index, (filename, context, gen_text) in enumerate(
generate_synthetic_dataset(
args.model_name, args.input_directory, args.state_file, args.chunk, args.context_length
model_name=args.model_name,
input_directory=args.input_directory,
processed_files=state["processed_files"] if args.state_file else [],
chunk=args.chunk,
context_length=args.context_length
)
):
state["processed_files"].append(filename)
pickle.dump(state, open(args.state_file, "wb"))
qanda = question_and_answer_extractor(gen_text, context)
if qanda:
output_file = f"gen_{index}.txt"
output_file = f"gen_{index}.json"
with open(os.path.join(args.output_directory, output_file), "w") as o:
json.dump(qanda, o)
json.dump(qanda, o)
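
Two small Python points in the new signature are worth noting as suggestions (not part of the commit): a mutable default like processed_files=[] is shared across calls, and file in processed_files is a linear scan on a list. A toy, self-contained sketch of the same skip-list idea with a None default and a set (hypothetical names, not the committed implementation):

from typing import Iterable, Iterator, Optional, Tuple


def fake_corpus() -> Iterator[Tuple[str, str]]:
    # Stand-in for reading (filename, text) pairs from an input directory.
    yield "a.txt", "first document"
    yield "b.txt", "second document"


def generate_items(processed_files: Optional[Iterable[str]] = None) -> Iterator[Tuple[str, str]]:
    seen = set(processed_files or [])  # None default avoids sharing one list across calls
    for filename, text in fake_corpus():
        if filename in seen:           # O(1) with a set, O(n) with a list
            continue
        yield filename, text.upper()   # stand-in for the LLM generation step


if __name__ == "__main__":
    # "b.txt" was already processed, so only "a.txt" is yielded.
    for filename, generated in generate_items(processed_files=["b.txt"]):
        print(filename, generated)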
42 changes: 30 additions & 12 deletions dalm/pipelines/reading_comprehension_pipeline.py
@@ -4,10 +4,11 @@
import json
import datasets
import sentencepiece as spm
import pickle

from dalm.datasets.reading_comprehension_generation.regex_based import RC
from dalm.datasets.reading_comprehension_generation.synthetic_based import generate_synthetic_dataset
from dalm.datasets.reading_comprehension_generation.utils import question_and_answer_extractor, text_chunker
from dalm.datasets.reading_comprehension_generation.utils import question_and_answer_extractor, create_domain_tokenizer_from_files

from dalm.training.generator_only.trainer import train_generator

@@ -57,14 +58,19 @@ def pipeline(
neftune_noise_alpha: Optional[int] = 5,
log_with: Optional[str] = "wandb",
generation_state_file: Optional[str] = "generation_state.pkl",
):
domain_spm = spm.SentencePieceProcessor(model_file=domain_spm_path)
ori_spm = spm.SentencePieceProcessor(model_file=general_spm_path)
):
if not domain_spm_path:
# warn user that the domain tokenizer will be created from the input files
domain_spm = create_domain_tokenizer_from_files(dataset_path)
else:
domain_spm = spm.SentencePieceProcessor(model_file=domain_spm_path)

general_spm = spm.SentencePieceProcessor(model_file=general_spm_path)

# generate regex based reading comprehension dataset
if comprehension_type in [SynthMode.REGEX, SynthMode.BOTH]:
# generate regex based reading comprehension dataset
regex_rc_gen = RC(ori_spm, domain_spm)
regex_rc_gen = RC(general_spm, domain_spm)

# NOTE: this is a simple check to see if the dataset is already generated
if not (os.path.exists(regex_dataset_output_path) and len(os.listdir(regex_dataset_output_path)) > 0):
@@ -73,10 +79,20 @@
regex_dataset_output_path,
)

generation_state = None
if generation_state_file:
if os.path.exists(generation_state_file):
with open(generation_state_file, "rb") as f:
generation_state = pickle.load(f)
else:
generation_state = {"processed_files": []}
pickle.dump(generation_state, open(generation_state_file, "wb"))


# generate llm based reading comprehension dataset
if comprehension_type in [SynthMode.LLM, SynthMode.BOTH]:
# generate llm based reading comprehension dataset
for index, (gen_text, context) in enumerate(
for index, (filename, context, gen_text) in enumerate(
generate_synthetic_dataset(
model_name=llm_synth_model_name,
input_directory=dataset_path,
@@ -90,17 +106,20 @@
output_file = f"gen_{index}.json"
with open(os.path.join(llm_dataset_output_path, output_file), "w") as o:
json.dump(qanda, o)
else:
pass
generation_state["processed_files"].append(filename)
pickle.dump(generation_state, open(generation_state_file, "wb"))


# mix both and make it a huggingface dataset
list_of_data = []
if comprehension_type in [SynthMode.REGEX, SynthMode.BOTH]:
# a1 = load_dataset("json", data_dir=regex_dataset_output_path) This does not work
for file in os.listdir(regex_dataset_output_path):
text = json.load(open(os.path.join(regex_dataset_output_path, file), "r"))
list_of_data.append({"messages": text})

if comprehension_type in [SynthMode.LLM, SynthMode.BOTH]:
# a2 = load_dataset("json", data_dir=llm_dataset_output_path) This does not work
for file in os.listdir(llm_dataset_output_path):
text = json.load(open(os.path.join(llm_dataset_output_path, file), "r"))
list_of_data.append({"messages": text})
@@ -147,12 +166,11 @@ def pipeline(

if __name__ == "__main__":
pipeline(
"HuggingFaceH4/zephyr-7b-beta",
model_name="HuggingFaceH4/zephyr-7b-beta",
comprehension_type=SynthMode.BOTH,
llm_synth_model_name="HuggingFaceH4/zephyr-7b-beta",
domain_spm_path="./tokenizers/domain.spm",
general_spm_path="./tokenizers/general.spm",
general_spm_path="./tokenizers/general.spm", # no
chunk=True,
dataset_path="./data_llm",
packing=True,
)
)
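
On the pipeline side, the diff keeps the notes that load_dataset("json", data_dir=...) did not work and instead collects each generated JSON file into list_of_data as {"messages": ...} records. The step that turns that list into a Hugging Face dataset is collapsed in this view; one common way to do it, shown purely as a general sketch with placeholder paths (not the committed code), is datasets.Dataset.from_list:

import json
import os

import datasets

list_of_data = []
for output_dir in ("./regex_rc_out", "./llm_rc_out"):  # placeholder output directories
    if not os.path.isdir(output_dir):
        continue
    for file in os.listdir(output_dir):
        with open(os.path.join(output_dir, file), "r") as f:
            list_of_data.append({"messages": json.load(f)})

if list_of_data:
    # Dataset.from_list accepts a list of dicts directly, sidestepping the
    # load_dataset("json", data_dir=...) issue mentioned in the diff comments.
    dataset = datasets.Dataset.from_list(list_of_data)
    print(dataset)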
