Post pipeline test run corrections
metric-space committed Nov 19, 2023
1 parent c0c530f commit fcdfa3b
Showing 3 changed files with 51 additions and 34 deletions.
20 changes: 12 additions & 8 deletions dalm/datasets/reading_comprehension_generation/regex_based.py
@@ -10,8 +10,11 @@
from tqdm.contrib.concurrent import process_map
from pysbd import Segmenter
import copy
-from dalm.datasets.reading_comprehension_generation.utils import create_domain_tokenizer, create_domain_tokenizer_from_files
-import pprint
+from dalm.datasets.reading_comprehension_generation.utils import (
+    create_domain_tokenizer,
+    create_domain_tokenizer_from_files,
+)
+import json
from warnings import warn

TYPES = ["nli", "common_reason", "paraphrase", "word2text", "summarize", "text_completion"]
@@ -1163,13 +1166,13 @@ def generate(self, entry):

return {"read_compre": read_compre_list, "file_name": entry["file_name"]}

-def create_dataset(self, input_dir, output_dir ,workers=1):
+def create_dataset(self, input_dir, output_dir, workers=1):
print("loading raw texts in the input folder...")
paths = glob.glob(f"{input_dir}/*")
# print(f'paths: {paths}')

raw_texts = []
# NOTE: use generator here
# NOTE: do I really need TQDM here?
for text_id, path in tqdm(enumerate(paths)):
file_name = path.split("/")[-1]
@@ -1195,8 +1198,7 @@ def create_dataset(self, input_dir, output_dir ,workers=1):
path = os.path.join(output_dir, f"{file_name}_{str(index)}")

with open(path, "w") as f:
-f.write(pprint.pformat(read_compre_example))
-f.close()
+json.dump(read_compre_example, f)


if __name__ == "__main__":
@@ -1209,7 +1211,9 @@ def create_dataset(self, input_dir, output_dir ,workers=1):
"--ori_spm_path", type=str, help="path of the original sentencepiece model", default="./tokenizers/general.spm"
)
parser.add_argument(
"--domain_spm_path", type=str, help="path of the domain sentencepiece model", # default="./tokenizers/domain.spm"
"--domain_spm_path",
type=str,
help="path of the domain sentencepiece model", # default="./tokenizers/domain.spm"
)
parser.add_argument(
"--domain_tokenizer_training_text",
@@ -1247,4 +1251,4 @@ def create_dataset(self, input_dir, output_dir ,workers=1):

rc.create_dataset(args.input_dir, args.output_dir, workers=max_workers)

print(f"saved to {args.output_dir}")
print(f"saved to {args.output_dir}")
44 changes: 25 additions & 19 deletions dalm/pipelines/reading_comprehension_pipeline.py
@@ -2,7 +2,8 @@
from enum import Enum
import os
import json
-from datasets import load_dataset, concatenate_datasets
+import datasets
+import sentencepiece as spm

from dalm.datasets.reading_comprehension_generation.regex_based import RC
from dalm.datasets.reading_comprehension_generation.synthetic_based import generate_synthetic_dataset
@@ -50,21 +51,24 @@ def pipeline(
lr_scheduler_type: Optional[str] = "linear",
num_warmup_steps: Optional[int] = 0,
weight_decay: Optional[float] = 0.0,
optimizer_type: Optional[str] = "adamw",
optimizer_type: Optional[str] = "paged_adamw_32bit",
model_output_dir: Optional[str] = "model_output_dir",
log_freq: Optional[int] = 100,
neftune_noise_alpha: Optional[int] = 5,
log_with: Optional[str] = "wandb",
generation_state_file: Optional[str] = "generation_state.pkl",
):
+domain_spm = spm.SentencePieceProcessor(model_file=domain_spm_path)
+ori_spm = spm.SentencePieceProcessor(model_file=general_spm_path)

# generate regex based reading comprehension dataset
if comprehension_type in [SynthMode.REGEX, SynthMode.BOTH]:
-# generate regex based reading comprehension dataset
-regex_rc_gen = RC(model_name, general_spm_path, domain_spm_path)
+regex_rc_gen = RC(ori_spm, domain_spm)

# NOTE: this is a simple check to see if the dataset is already generated
if not (os.path.exists(regex_dataset_output_path) and len(os.listdir(regex_dataset_output_path)) > 0):
-regex_rc_gen.generate_dataset(
+regex_rc_gen.create_dataset(
dataset_path,
regex_dataset_output_path,
)
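For reference, the SentencePiece processors are now constructed once up front and the objects (not the raw paths) are handed to RC. A small sketch of that construction, assuming the default tokenizer paths used elsewhere in this diff:

import sentencepiece as spm

# Paths match the defaults in this commit; adjust as needed
domain_spm = spm.SentencePieceProcessor(model_file="./tokenizers/domain.spm")
ori_spm = spm.SentencePieceProcessor(model_file="./tokenizers/general.spm")

# Quick sanity check that both models load and tokenize
print(ori_spm.encode("reading comprehension", out_type=str))
print(domain_spm.encode("reading comprehension", out_type=str))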
@@ -75,7 +79,7 @@
for index, (gen_text, context) in enumerate(
generate_synthetic_dataset(
model_name=llm_synth_model_name,
-input_directory=llm_dataset_output_path,
+input_directory=dataset_path,
state_file=generation_state_file,
chunk=chunk,
context_length=llm_synth_model_context_length,
@@ -90,28 +94,28 @@
# mix both and make it a huggingface dataset
list_of_data = []
if comprehension_type in [SynthMode.REGEX, SynthMode.BOTH]:
a1 = load_dataset("json", data_files=regex_dataset_output_path)
# for file in os.listdir(regex_dataset_output_path):
# text = open(file, 'r').read()
# list_of_data.append(text)
# a1 = load_dataset("json", data_dir=regex_dataset_output_path) This does not work
for file in os.listdir(regex_dataset_output_path):
text = json.load(open(os.path.join(regex_dataset_output_path, file), "r"))
list_of_data.append({"messages": text})

if comprehension_type in [SynthMode.LLM, SynthMode.BOTH]:
a2 = load_dataset("json", data_files=llm_dataset_output_path)
# for file in os.listdir(llm_dataset_output_path):
# text = open(file, 'r').read()
# list_of_data.append(text)
# a2 = load_dataset("json", data_dir=llm_dataset_output_path) This does not work
for file in os.listdir(llm_dataset_output_path):
text = json.load(open(os.path.join(llm_dataset_output_path, file), "r"))
list_of_data.append({"messages": text})

if comprehension_type == SynthMode.BOTH:
-dataset = concatenate_datasets([a1, a2])
+dataset = datasets.Dataset.from_list(list_of_data)

dataset.save_to_disk("reading_comprehension_dataset") # TODO: change name from

-del dataset, a1, a2 # TODO: change name
+# del dataset # TODO: change name

train_generator(
model_name=model_name,
dataset_name="reading_comprehension_dataset",
num_train_epochs=num_train_epochs,
split=split,
size_valid_set=size_valid_set,
streaming=streaming,
@@ -134,7 +138,7 @@
num_warmup_steps=num_warmup_steps,
weight_decay=weight_decay,
optimizer_type=optimizer_type,
-model_output_dir=model_output_dir,
+output_dir=model_output_dir,
log_freq=log_freq,
neftune_noise_alpha=neftune_noise_alpha,
log_with=log_with,
@@ -146,7 +150,9 @@
"HuggingFaceH4/zephyr-7b-beta",
comprehension_type=SynthMode.BOTH,
llm_synth_model_name="HuggingFaceH4/zephyr-7b-beta",
domain_spm_path="domain.spm",
domain_spm_path="./tokenizers/domain.spm",
general_spm_path="./tokenizers/general.spm",
chunk=True,
dataset_path="./data",
dataset_path="./data_llm",
packing=True,
)
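The dataset-assembly change above replaces load_dataset/concatenate_datasets with an explicit in-memory build: each generated file is read back with json.load, wrapped as a "messages" record, and the resulting list becomes a single datasets.Dataset. A minimal sketch of that pattern, with a hypothetical output directory name:

import json
import os

import datasets

output_dir = "regex_dataset_output"  # hypothetical; the pipeline uses regex_dataset_output_path / llm_dataset_output_path
list_of_data = []
for file_name in os.listdir(output_dir):
    with open(os.path.join(output_dir, file_name), "r") as f:
        list_of_data.append({"messages": json.load(f)})

dataset = datasets.Dataset.from_list(list_of_data)
dataset.save_to_disk("reading_comprehension_dataset")  # the directory name the trainer loads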
21 changes: 14 additions & 7 deletions dalm/training/generator_only/trainer.py
@@ -3,7 +3,7 @@

import torch
from accelerate import Accelerator
-from datasets import load_dataset
+from datasets import load_dataset, load_from_disk
from peft import LoraConfig
from tqdm import tqdm
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, TrainingArguments
@@ -24,13 +24,20 @@ def create_datasets(
num_workers: int,
tokenizer: AutoTokenizer,
formatting_func: callable,
+local_dataset: bool = True,
):
-dataset = load_dataset(
-    dataset_name,
-    split=split,
-    num_proc=num_workers if not streaming else None,
-    streaming=streaming,
-)
+if local_dataset:
+    dataset = load_from_disk(
+        dataset_name,
+    )
+    streaming = False
+else:
+    dataset = load_dataset(
+        dataset_name,
+        split=split,
+        num_proc=num_workers if not streaming else None,
+        streaming=streaming,
+    )
if streaming:
print("Loading the dataset in streaming mode")
valid_data = dataset.take(size_valid_set)
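The new local_dataset branch pairs with the pipeline's save_to_disk call: save_to_disk writes an Arrow dataset directory, which load_from_disk can read back but the load_dataset("json", ...) path cannot. A small sketch of that save/reload pairing, with hypothetical contents:

from datasets import Dataset, load_from_disk

# Hypothetical single-example dataset, saved the way the pipeline saves it
Dataset.from_list([{"messages": [{"role": "user", "content": "hi"}]}]).save_to_disk(
    "reading_comprehension_dataset"
)

# What create_datasets now does when local_dataset=True (streaming is forced off)
dataset = load_from_disk("reading_comprehension_dataset")
print(dataset[0]["messages"])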
