
Da 24/reading comprehension #74

Merged: 74 commits merged into main from DA-24/reading-comprehension on Dec 4, 2023

Conversation

@metric-space metric-space (Collaborator) commented Nov 17, 2023

What this is

Reading comprehension is based on the idea that we can enhance the domain adaptiveness of the generator by synthetically augmenting the input data.
I want to emphasize that this targets the generator and not the retriever.

Integration of Reading Comprehension

  1. Regex based dataset generation
  2. LLM based dataset generation
  3. SFT based generator tuning

Notes:

Why not stream the dataset by default?
I have no idea what the behavior is when streaming actually works as intended, so it's hard for me to recommend a default whose benefits I don't know.

Commands to get things running

This assumes you have a CSV file with the raw texts in a certain column (for this example, let's say text).

  1. LLM-based generation

python dalm/datasets/reading_comprehension_generation/synthetic_based.py \
            --model HuggingFaceH4/zephyr-7b-alpha \
            --context-length 4192 \
            --csv_column text \
            --input input_dataset.csv \
            --output_directory synth_data \
            --dataset_name llm_generated_dataset

  2. Regex-based generation

python dalm/datasets/reading_comprehension_generation/regex_based.py \
            --input input.csv \
            --csv_column text \
            --output_dataset_name regex_dataset

  3. End-to-end pipeline

python dalm/pipelines/reading_comprehension_pipeline.py \
    --model_name HuggingFaceH4/zephyr-7b-alpha \
    --input input.csv --csv_column text \
    --output_dataset_name combined \
    --llm_synth_model_name meta-llama/Llama-2-13b \
    --llm_synth_model_context_length 4096

  4. Training

python dalm/training/generator_only/trainer.py \
    --dataset_name combined --local_dataset \
    --model_name HuggingFaceH4/zephyr-7b-beta

@tleyden tleyden (Collaborator) left a comment

Looking good! A few things that I think would be useful to include in the README:

  • Examples of downloading and processing a public or local dataset with both approaches.
  • Example of calling trainer.py
  • Reference to upstream codebase where the regex code was originally sourced and modified

model_prefix = f"{temp_dir}/domain"

# Train the SentencePiece model, the model is saved in the temporary directory
spm.SentencePieceTrainer.train(input=text, model_prefix=model_prefix, vocab_size=32000, character_coverage=1.0)
Collaborator:

Is this missing an import for spm? When I opened this in VS Code, it flagged this line.
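For reference, the missing import would presumably just be the usual sentencepiece alias (assuming it isn't already imported elsewhere in the file):

    import sentencepiece as spm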


if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--model_name", type=str, required=True)
Collaborator:

It would make this code a bit easier to run if a default model is provided
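A minimal sketch of what that could look like, assuming zephyr-7b-alpha (already used elsewhere in this PR) is an acceptable default:

    parser.add_argument("--model_name", type=str, default="HuggingFaceH4/zephyr-7b-alpha", help="the model name")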

Collaborator:

We should also document the types of models that are supported. I tried running:

python -m dalm.datasets.reading_comprehension_generation.synthetic_based --model_name "meta-llama/Llama-2-7b-hf" --input_directory "datasets" --output_directory generated_datasets --state_file "processing_state.txt"

and got error:

AttributeError: 'LlamaTokenizerFast' object has no attribute 'apply_chat_template'

Collaborator:

Actually after digging in a bit more, it seems like a problem with the version of the transformers library I was using rather than the model.

Running pip install transformers --upgrade got past the no attribute 'apply_chat_template' error. In my case, it went from transformers version 4.33.1 -> 4.35.2

As part of this PR, should we add a constraint on the dependencies to require a newer version of the transformers library? (> 4.35?)
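If so, it would presumably be a one-line change to the dependency list in pyproject.toml (sketch; exact lower bound to be confirmed):

    dependencies = [
        "transformers>=4.35.0",  # older versions lack tokenizer.apply_chat_template
    ]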

"""

for index, (gen_text, context) in enumerate(
generate_synthetic_dataset(
Collaborator:

With this many args, named arguments would help readability and prevent future errors around arg order:

generate_synthetic_dataset(
   model_name = args.model_name,
   .. etc

@tleyden tleyden (Collaborator) commented Nov 20, 2023:

Also I think an approach like this would reduce the cognitive load of this code block:


# Create a synthetic dataset generator
dataset_gen = generate_synthetic_dataset(
    args.model_name, args.input_directory, args.state_file, args.chunk, args.context_length
 )
        
# Loop over synthetic dataset and extract question and answers
for index, (gen_text, context) in enumerate(dataset_gen):
   q_and_a = question_and_answer_extractor(gen_text, context)
   .. etc

parser.add_argument("--output_directory", type=str, required=True)
parser.add_argument("--state_file", type=str, required=True)
parser.add_argument("--context_length", type=int, default=2048)
parser.add_argument("--chunk", action="store_true")
Collaborator:

Could a default=False be added to make this code easier to run? Eg, fewer arguments for users to have to think about.

Collaborator:

Actually on second glance, maybe chunk should default to true. When I ran it on this dataset I got this warning:

This is a friendly reminder - the current text generation call will exceed the model's predefined maximum length (4096). Depending on the model, you may observe exceptions, performance degradation, or nothing at all.

and the model generated complete gibberish:

outputs: [{'generated_text': 'ЋЪЋ.ЪЪЉЁЉЋЋЏЪЉЏЉЉЉЋЋЪЪЋЋЪЉЉЉЋЉЉЏЉЋЪЋЏЋЋЉЉЪЪЏЉЪЪ...'}]

Collaborator:

After adding --chunk it managed to generate a reading comprehension dataset! It definitely seems like we should default to chunk=True.
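One way to flip the default while still letting users opt out would be an inverted flag (sketch; the --no_chunk name is hypothetical). With action="store_false", the parsed args.chunk then defaults to True:

    parser.add_argument("--no_chunk", dest="chunk", action="store_false", help="disable chunking of long input texts")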

parser.add_argument("--model_name", type=str, required=True)
parser.add_argument("--input_directory", type=str, required=True)
parser.add_argument("--output_directory", type=str, required=True)
parser.add_argument("--state_file", type=str, required=True)
Collaborator:

Any chance state_file could be made optional? Meaning if the user didn't pass anything, the code would just create a state file as a temp file somewhere?

State file = something to track processing state?
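A sketch of that fallback, assuming the pickle-based state file used elsewhere in this PR (the temp filename is hypothetical):

    import argparse
    import os
    import tempfile

    parser = argparse.ArgumentParser()
    parser.add_argument("--state_file", type=str, default=None)
    args = parser.parse_args()
    if args.state_file is None:
        # no state file given: keep the processing state in a throwaway file under the system temp dir
        args.state_file = os.path.join(tempfile.gettempdir(), "rc_generation_state.pkl")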

Collaborator:

Also, is a separate processing state file even needed? If the file has been processed, won't it be present in the output directory? (and likewise, if not processed, not present in the output directory)

Collaborator Author:

I was about to remove the state file.
I think that because the function has no control over the naming of the output file, keeping a record of the files already processed is the most reliable way of keeping track.

Also made a minor change giving the user the ability to switch off this behaviour.

Collaborator Author:

🤔 That said, it would be cleaner to give the caller of the function the reins over state keeping of this nature.

Collaborator:

I think that because the function has no control over the naming of the output file, keeping a record of the files already processed is the most reliable way of keeping track.

Yeah for this to work:

If the file has been processed, won't it be present in the output directory? (and likewise, if not processed, not present in the output directory)

The output file would need to have the same name as the input file, or use a "content addressable" scheme of some sort (hash of content). Currently it uses an index in the filename.

If there's a 1:1 mapping between input and output file, then using the same filename in both input and output directories should make it easy to track the processing state without needing an extra file.
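A sketch of that check, assuming identical filenames in the input and output directories (the helper name is hypothetical):

    import os

    def already_processed(input_file: str, output_dir: str) -> bool:
        # the presence of a same-named output file *is* the processing state
        return os.path.exists(os.path.join(output_dir, os.path.basename(input_file)))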

Collaborator:

give the caller of the function the reins to state keeping of this nature

I think only a small percentage of users would want any control here. So it would be better to have the default behavior "just work" while allowing power users to override it if needed.

Collaborator:

I think the code is too eagerly adding files to the state file, because I was able to end up in a state where:

  1. My output directory is empty
  2. My only input file (pubmed25.csv) has been added to the state file
root@04e369b03ae7:/# ls -alh generated_datasets/
total 4.0K
drwxr-xr-x 2 root root   10 Nov 20 13:14 .
drwxr-xr-x 1 root root 4.0K Nov 20 13:26 ..
root@04e369b03ae7:/# strings processing_state.txt
processed_files
pubmed25.csv

Eliminating the separate state file (as suggested above) would eliminate this drift. Or alternatively, a file should only be added to the state file after it successfully generated the corresponding output file.

@metric-space metric-space (Collaborator Author) commented Nov 23, 2023:

At the time you reported this, a CSV file was not considered valid input. It's very likely the processing was correct and the one file didn't make it through because the Q&A parser didn't parse anything out of the LLM's output.

Collaborator:

@metric-space Ok I will try this again and see if I observe the same issue

from tqdm import tqdm
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, TrainingArguments

from trl import SFTTrainer
Collaborator:

I think trl needs to get added to the pyproject.toml dependencies

return first_prompt + chat_chain


# TODO: type hinting is very necessary here
Collaborator:

Yeah type hinting would be very helpful


def parse_args():
parser = argparse.ArgumentParser()
parser.add_argument("--model_name", type=str, default="HuggingFaceH4/zephyr-7b-alpha", help="the model name")
Collaborator:

Should this be zephyr-7b-beta?

"--dataset_name", type=str, default="arcee-ai/azure-reading-comprehension-dataset", help="the dataset name"
)
parser.add_argument("--split", type=str, default="train", help="the split to use")
parser.add_argument("--size_valid_set", type=int, default=4000, help="the size of the validation set")
Collaborator:

It seems easier for users to express this in terms of percentage of training set rather than absolute size. Eg, 20%

Collaborator Author:

I believe this argument is for streaming; in that scenario a preset size is required.

Collaborator:

Ah got it, for streaming we don't know the full size of the dataset until we load it. I think it's confusing and misleading, though, to advertise this parameter without mentioning the fact that it's streaming-only.

WDYT of changing to this?

    parser.add_argument(
      "--size_valid_set_streaming", 
      type=int, 
      default=4000, 
      help="the size of the validation set when used in streaming mode, ignored otherwise"
    )

return train_dataset, valid_dataset


def chars_token_ratio(dataset, tokenizer, formatting_func, nb_examples=400):
Collaborator:

What does nb stand for? If "number", then I think num_examples is a bit clearer

)
parser.add_argument("--split", type=str, default="train", help="the split to use")
parser.add_argument("--size_valid_set", type=int, default=4000, help="the size of the validation set")
parser.add_argument("--streaming", type=bool, default=False, help="whether to stream the dataset")
Collaborator:

Any reason not to just default to True to keep memory footprint low by default? What are the downsides of streaming the dataset?

Collaborator Author:

I haven't tested the advantages of streaming over just vanilla loading, so it's hard for me to set that as the default for the user.

train_data = dataset.skip(size_valid_set)
train_data = train_data.shuffle(buffer_size=shuffle_buffer, seed=None)
else:
dataset = dataset.train_test_split(test_size=0.05, seed=None)
Collaborator:

Looks like it ignores size_valid_set in this code path - is that a bug?
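If it is, the fix might be as simple as passing the argument through; datasets' train_test_split accepts an absolute count as well as a fraction (sketch):

    dataset = dataset.train_test_split(test_size=size_valid_set, seed=None)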

dalm/training/generator_only/trainer.py (resolved comment thread)
parser.add_argument("--weight_decay", type=float, default=0.05, help="the weight decay")
parser.add_argument("--optimizer_type", type=str, default="paged_adamw_32bit", help="the optimizer type")

parser.add_argument("--output_dir", type=str, default="./generator_finetuned_model", help="the output directory")
Collaborator:

I think we should add this directory to .gitignore so that it doesn't show up in the list of changed files.

Collaborator Author:

It wouldn't show up in the list of changed files. It would show up in the list of untracked files, no?

Collaborator:

Yeah, but that's just as bad. Users who want to contribute would be confused by the untracked files being created when they run git status. Also, it would be a nuisance for core committers to have to ignore them when running git add.

streaming=streaming,
)
if streaming:
print("Loading the dataset in streaming mode")
Collaborator:

We should use loggers instead of raw print. When running in cloud training environments, only the logger output is captured and the print statements get lost.

Here's an example from the DALM repo of creating a logger:

https://github.com/arcee-ai/DALM/blob/main/dalm/datasets/qa_gen/question_answer_generation.py#L14

and using it:

https://github.com/arcee-ai/DALM/blob/main/dalm/datasets/qa_gen/question_answer_generation.py#L119
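Applied here, that would look roughly like this at module level, with logger.info(...) replacing the print(...) calls (sketch):

    import logging

    logger = logging.getLogger(__name__)

    logger.info("Loading the dataset in streaming mode")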

"--output_dir", type=str, help="directory of the output reading comprehension texts", default="./output"
)
parser.add_argument(
"--ori_spm_path", type=str, help="path of the original sentencepiece model", default="./tokenizers/general.spm"
Collaborator:

Could the code auto-download that somehow? If not, it seems like something that is good to add to the README example. Ditto for --domain_spm_path and --domain_tokenizer_training_text

Member:

Good one. But this domain tokenizer needs to get trained with user data.

Collaborator:

Is the ori in ori_spm_path referring to "original"? If so we should definitely change to orig instead of ori, which makes the meaning a lot clearer at the cost of one char

2. regex-based-gen now creates a domain tokenizer if both domain sentencepiece model and domain text (explicitly) is not given
3. attempt at pipeline
"--output_dir", type=str, help="directory of the output reading comprehension texts", default="./output"
)
parser.add_argument(
"--ori_spm_path", type=str, help="path of the original sentencepiece model", default="./tokenizers/general.spm"
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good one. But this domain tokenizer needs to get trained with user data.

model_output_dir: Optional[str] = "model_output_dir",
log_freq: Optional[int] = 100,
neftune_noise_alpha: Optional[int] = 5,
log_with: Optional[str] = "wandb",
generation_state_file: Optional[str] = "generation_state.pkl",
):
domain_spm = spm.SentencePieceProcessor(model_file=domain_spm_path)
ori_spm = spm.SentencePieceProcessor(model_file=general_spm_path)

# generate regex based reading comprehension dataset
if comprehension_type in [SynthMode.REGEX, SynthMode.BOTH]:
Member:

Can we use both if possible?


dataset.save_to_disk("reading_comprehension_dataset") # TODO: change name from

del dataset, a1, a2 # TODO: change name
# del dataset # TODO: change name

train_generator(
Member:

I guess we are using DeepSpeed during the prod training?

Collaborator:

Using del is suspect and makes me think the code isn't optimally structured (functions too big).

Is there a way to refactor the code so that these automatically go out of scope after they are no longer useful?

Collaborator:

Oh sorry this is just removing some del statements and adding a comment.

To have it automatically deleted, you could factor out a separate function:

def write_dataset(list_of_data, dataset_name):
  dataset = datasets.Dataset.from_list(list_of_data)
  dataset.save_to_disk(dataset_name)

and in the caller:

write_dataset(list_of_data, "reading_comprehension_dataset")
train_generator(
        model_name=model_name,
        dataset_name="reading_comprehension_dataset",
        .. etc
)

As soon as write_dataset() returns, the dataset local variable will go out of scope and be GC'd.

Collaborator:

Also I just noticed something suspect in this code block:

    if comprehension_type == SynthMode.BOTH:
        dataset = datasets.Dataset.from_list(list_of_data)

    dataset.save_to_disk("reading_comprehension_dataset")  # TODO: change name from

I don't see dataset ever being created if comprehension_type != SynthMode.BOTH, and it looks like that code will throw an exception unless it is set to SynthMode.BOTH.
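A sketch of one possible fix: build the dataset unconditionally from whatever list_of_data the enabled mode(s) produced, rather than only under SynthMode.BOTH:

    # list_of_data is accumulated by the regex and/or LLM generation steps above
    dataset = datasets.Dataset.from_list(list_of_data)
    dataset.save_to_disk("reading_comprehension_dataset")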

@shamanez shamanez (Member) left a comment

If we use packing = True in the SFT trainer, do we still need to use the ConstantLengthDataset object?

https://huggingface.co/docs/trl/sft_trainer#packing-dataset--constantlengthdataset-

@tleyden tleyden (Collaborator) commented Nov 20, 2023

If we use packing = True in the SFT trainer, do we still need to use the ConstantLengthDataset object?

https://huggingface.co/docs/trl/sft_trainer#packing-dataset--constantlengthdataset-

It defaults to True:

parser.add_argument("--packing", type=bool, default=True, help="whether to use packing for SFTTrainer")

but the user can override the value.

If it makes the code too complex (eg, dynamically using a ConstantLengthDataset depending on the arg value of packing), then maybe it's better to remove the argument and just hardcode a particular value for packing.



def gen_prompt(text):
prompt = f"There are 4 types of reading comprehension tasks. The point of reading comprehension tasks is to be assigned a text and questions to prompt answers so as to test conceptual and procedural knowledge present in the text. The four types of reading comprehension tasks are : 1. complete-the-sentence Q&A TASK 2.true/false Q&A TASK (description: a sentence is posed and the user is asked to state the correctness of the statement)3. frame a sentence with domain specific keywords(these keywords are required to be present in the text) Q&A TASK 4. Normal questions and answer Task (description: longform Q&A to test procedural and conceptual knowledge). An example of all four tasks given an example text is as follows: \n EXAMPLE TEXT: The insights into the mechanisms of memory consolidation during the sleep processes in human and animal brain led to other biologically inspired approaches. While declarative memories are in the classical picture consolidated by hippocampo-neocortical dialog during NREM phase of sleep, some types of procedural memories were suggested not to rely on the hippocampus and involve REM phase of the sleep. This inspired models where internal representations (memories) created by previous learning are spontaneously replayed during sleep-like periods in the network itself (i.e. without help of secondary network performed by generative replay approaches mentioned above).\nQuestion: [type: true/false] Is the following sentence true? all types of procedural memories rely on the hippocampus\nAnswer: False. The text clearly states there are some types of procedural memories not reliant on the hippocampus\n--------\nQuestion [type: complete-the-sentence] Complete the following sentence: The insights into ____ in human and animal brain led to other _____ approaches\nAnswer: The insights into the mechanisms of memory consolidation during the sleep processes in human and animal brain led to other biologically inspired approaches\n------\nQuestion [type 3 domain-keywords] Make a sentence with the following keywords 'hippocampo-neocortical', 'declarative' and 'NREM'\nAnswer: declarative memories are in the classical picture consolidated by hippocampo-neocortical dialog during NREM phase of sleep\n-------\nQuestion [type: normal q&a] Some types of procedural memories were suggested not to rely on the hippocampus and involve REM phase of the sleep. What did this go on to inspire?\nAnswer This inspired models where internal representations (memories) created by previous learning are spontaneously replayed during sleep-like periods in the network itself [END OF EXAMPLE]\n\n Similar to the above, could you craft 4 different reading comprehension tasks (make sure your output is a list of question answer pairs and each question is labelled QUESTION and answer is labelled ANSWER and there is one question and answer per task) based solely and completely focused on the following TEXT: {text}"
Collaborator:

This string should be word wrapped so the line isn't so long
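e.g. via Python's implicit string-literal concatenation, which keeps the prompt a single runtime string while the source lines stay short (sketch; most of the prompt elided):

    def gen_prompt(text):
        return (
            "There are 4 types of reading comprehension tasks. "
            "The point of reading comprehension tasks is to be assigned a text and questions "
            "to prompt answers so as to test conceptual and procedural knowledge present in the text. "
            # ... remaining sentences of the prompt continue as adjacent string literals ...
            f"could you craft 4 different reading comprehension tasks based solely and completely focused on the following TEXT: {text}"
        )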

if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--model_name", type=str, required=True)
parser.add_argument("--input_directory", type=str, required=True)
Collaborator:

I think --input-directory would be slightly easier to read, and would match the same convention used by https://huggingface.co/docs/text-generation-inference/basic_tutorials/launcher

Collaborator Author:

While the suggestion makes sense, every other place in our codebase uses snake case; perhaps deferring this until we schedule a codebase-wide change would make more sense.

Collaborator:

Ok good call. Let's try to do it for the entire codebase at some point to be consistent.

parser.add_argument("--context_length", type=int, default=2048)
parser.add_argument("--chunk", action="store_true")

args = parser.parse_args()
Collaborator:

To fail-fast, it would be good to check if all required args were passed:

if not args.model_name:
    parser.error("--model_name is a required argument")
... etc

Otherwise what might happen is that it takes several minutes to run, then throws an exception because the output_directory argument was missing. This would be frustrating for the user.

Collaborator:

Nevermind, argparse should take care of this for us as long as we mark all args as required=True (see other comment regarding model_name)

parser.add_argument("--model_name", type=str, default="HuggingFaceH4/zephyr-7b-alpha", help="the model name")
parser.add_argument("--log_with", type=str, default="wandb", help="use 'wandb' to log with wandb")
parser.add_argument(
"--dataset_name", type=str, default="arcee-ai/azure-reading-comprehension-dataset", help="the dataset name"
Collaborator:

Should the help message be more descriptive and say something like:

The Hugging Face dataset repo name, or the local directory where the dataset is stored. Must be in reading comprehension format.

@tleyden tleyden (Collaborator) commented Nov 27, 2023

I tried running the pipeline with the following command:

python -m dalm.pipelines.reading_comprehension_pipeline --model_name "/zephyr-7b-beta" --input arcee-ai_azure-reading-comprehension-dataset.csv --csv_column messages --output_dataset_name combined --llm_synth_model_name "/zephyr-7b-beta"

but hit this error:

RuntimeError: Internal: src/trainer_interface.cc(661) [(trainer_spec_.vocab_size()) == (model_proto->pieces_size())] Vocabulary size too high (32000). Please set it to a value <= 404.
Full error stacktrace
# python -m dalm.pipelines.reading_comprehension_pipeline --model_name "/zephyr-7b-beta" --input arcee-ai_azure-reading-comprehension-dataset.csv --csv_column messages --output_dataset_name combined --llm_synth_model_name "/zephyr-7b-beta"
/opt/conda/lib/python3.10/site-packages/trl/trainer/ppo_config.py:141: UserWarning: The `optimize_cuda_cache` arguement will be deprecated soon, please use `optimize_device_cache` instead.
  warnings.warn(
11/27/2023 15:40:45 - WARNING - __main__ - No domain tokenizer provided. The domain tokenizer will be created from the input files
sentencepiece_trainer.cc(77) LOG(INFO) Starts training with :
trainer_spec {
  input: /tmp/tmpl1we0632/temp.txt
  input_format:
  model_prefix: /tmp/tmpaokdrc8l/domain
  model_type: UNIGRAM
  vocab_size: 32000
  self_test_sample_size: 0
  character_coverage: 1
  input_sentence_size: 0
  shuffle_input_sentence: 1
  seed_sentencepiece_size: 1000000
  shrinking_factor: 0.75
  max_sentence_length: 4192
  num_threads: 16
  num_sub_iterations: 2
  max_sentencepiece_length: 16
  split_by_unicode_script: 1
  split_by_number: 1
  split_by_whitespace: 1
  split_digits: 0
  pretokenization_delimiter:
  treat_whitespace_as_suffix: 0
  allow_whitespace_only_pieces: 0
  required_chars:
  byte_fallback: 0
  vocabulary_output_piece_score: 1
  train_extremely_large_corpus: 0
  hard_vocab_limit: 1
  use_all_vocab: 0
  unk_id: 0
  bos_id: 1
  eos_id: 2
  pad_id: -1
  unk_piece: <unk>
  bos_piece: <s>
  eos_piece: </s>
  pad_piece: <pad>
  unk_surface:  ⁇
  enable_differential_privacy: 0
  differential_privacy_noise_level: 0
  differential_privacy_clipping_threshold: 0
}
normalizer_spec {
  name: nmt_nfkc
  add_dummy_prefix: 1
  remove_extra_whitespaces: 1
  escape_whitespaces: 1
  normalization_rule_tsv:
}
denormalizer_spec {}
trainer_interface.cc(351) LOG(INFO) SentenceIterator is not specified. Using MultiFileSentenceIterator.
trainer_interface.cc(183) LOG(INFO) Loading corpus: /tmp/tmpl1we0632/temp.txt
trainer_interface.cc(407) LOG(INFO) Loaded all 67 sentences
trainer_interface.cc(423) LOG(INFO) Adding meta_piece: <unk>
trainer_interface.cc(423) LOG(INFO) Adding meta_piece: <s>
trainer_interface.cc(423) LOG(INFO) Adding meta_piece: </s>
trainer_interface.cc(428) LOG(INFO) Normalizing sentences...
trainer_interface.cc(537) LOG(INFO) all chars count=4883
trainer_interface.cc(558) LOG(INFO) Alphabet size=75
trainer_interface.cc(559) LOG(INFO) Final character coverage=1
trainer_interface.cc(591) LOG(INFO) Done! preprocessed 67 sentences.
unigram_model_trainer.cc(222) LOG(INFO) Making suffix array...
unigram_model_trainer.cc(226) LOG(INFO) Extracting frequent sub strings... node_num=2576
unigram_model_trainer.cc(274) LOG(INFO) Initialized 747 seed sentencepieces
trainer_interface.cc(597) LOG(INFO) Tokenizing input sentences with whitespace: 67
trainer_interface.cc(608) LOG(INFO) Done! 326
unigram_model_trainer.cc(564) LOG(INFO) Using 326 sentences for EM training
unigram_model_trainer.cc(580) LOG(INFO) EM sub_iter=0 size=412 obj=13.4434 num_tokens=883 num_tokens/piece=2.1432
unigram_model_trainer.cc(580) LOG(INFO) EM sub_iter=1 size=379 obj=12.2767 num_tokens=901 num_tokens/piece=2.37731
trainer_interface.cc(686) LOG(INFO) Saving model: /tmp/tmpaokdrc8l/domain.model
Traceback (most recent call last):
  File "/opt/conda/lib/python3.10/runpy.py", line 196, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "/opt/conda/lib/python3.10/runpy.py", line 86, in _run_code
    exec(code, run_globals)
  File "/opt/conda/lib/python3.10/site-packages/dalm/pipelines/reading_comprehension_pipeline.py", line 382, in <module>
    main()
  File "/opt/conda/lib/python3.10/site-packages/dalm/pipelines/reading_comprehension_pipeline.py", line 340, in main
    pipeline(
  File "/opt/conda/lib/python3.10/site-packages/dalm/pipelines/reading_comprehension_pipeline.py", line 99, in pipeline
    domain_spm = create_domain_tokenizer_from_files(input, csv_column=csv_column)
  File "/opt/conda/lib/python3.10/site-packages/dalm/datasets/reading_comprehension_generation/utils.py", line 104, in create_domain_tokenizer_from_files
    return create_domain_tokenizer(os.path.join(temp_dir, "temp.txt"))
  File "/opt/conda/lib/python3.10/site-packages/dalm/datasets/reading_comprehension_generation/utils.py", line 77, in create_domain_tokenizer
    spm.SentencePieceTrainer.train(
  File "/opt/conda/lib/python3.10/site-packages/sentencepiece/__init__.py", line 989, in Train
    SentencePieceTrainer._Train(arg=arg, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/sentencepiece/__init__.py", line 982, in _Train
    return SentencePieceTrainer._TrainFromMap(new_kwargs)
  File "/opt/conda/lib/python3.10/site-packages/sentencepiece/__init__.py", line 927, in _TrainFromMap
    return _sentencepiece.SentencePieceTrainer__TrainFromMap(args)
RuntimeError: Internal: src/trainer_interface.cc(661) [(trainer_spec_.vocab_size()) == (model_proto->pieces_size())] Vocabulary size too high (32000). Please set it to a value <= 404.
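For what it's worth, one possible mitigation (purely a sketch of an idea, not current behavior) would be to retry the domain tokenizer training with a smaller vocabulary when the corpus can't support 32k pieces; the fallback size below is arbitrary:

    try:
        spm.SentencePieceTrainer.train(input=text, model_prefix=model_prefix, vocab_size=32000, character_coverage=1.0)
    except RuntimeError:
        # tiny corpora cannot support a 32k vocabulary; fall back to a much smaller one
        spm.SentencePieceTrainer.train(input=text, model_prefix=model_prefix, vocab_size=400, character_coverage=1.0)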

Comment on lines +9 to +10
1. Via regex based methods that combs the input data for match and aligns them into questions and answers
2. Via prompting a large language model to come up with questions and answers
Collaborator:

As a user, how should I decide which approach to use? Here's a stab:

  • Use regex based reading comprehension dataset generation when it works on that dataset
  • Otherwise fall back to the slower synthetic data generation approach

output_dataset_name: str,
input: str,
model_output_dir: str,
generation_state_file: str,
Collaborator:

Can you add a default value here? This has a default value for the CLI, but not for the API.

"bitsandbytes",
"typer>=0.9.0,<1.0",
"pydantic==1.10.9", # Sync w/ other platform components
"pysbd",
"sentencepiece"
]
Collaborator:

There is also a dependency on wandb due to this line:

    log_with: str = "wandb",

Can you add that to the list of dependencies?

Collaborator Author:

I think the best way forward is to remove this default; given the range of tracker options for the accelerator, it makes sense not to assume anything about the user.

What do you think?

Collaborator:

Yeah agreed

dalm/training/generator_only/trainer.py (resolved comment thread)
Comment on lines 19 to 30
if os.path.isdir(directory_or_file):
for file in os.listdir(directory_or_file):
file_path = os.path.join(directory_or_file, file)
if os.path.isfile(file_path): # Ensures that we are reading files
try:
with open(file_path, "r", encoding="utf-8") as file_contents:
contents = file_contents.read()
except UnicodeDecodeError:
with open(file_path, "r", encoding="utf-8", errors="replace") as file_contents:
contents = file_contents.read()

yield file, contents
Collaborator:

There's a very nasty bug here: if you pass a CSV file, the CSV-aware code below (directory_or_file.endswith(".csv") and csv_column:) kicks in and will work as expected and emit rows of the CSV.

However, if you instead pass a directory that contains CSV files, it will treat the CSV file as raw text and return the entire text (or an entire chunk) rather than rows from the CSV file.

The fix is to treat CSV files the same whether a single CSV is passed in or a directory containing CSV files is passed in.
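A sketch of that unified handling (the helper name is hypothetical; it reuses the existing csv_column argument):

    import csv
    import os

    def iter_texts(directory_or_file, csv_column=None):
        # yield (name, text) pairs, treating CSVs identically whether passed directly or found in a directory
        if os.path.isdir(directory_or_file):
            paths = [os.path.join(directory_or_file, f) for f in os.listdir(directory_or_file)]
        else:
            paths = [directory_or_file]
        for file_path in filter(os.path.isfile, paths):
            if file_path.endswith(".csv") and csv_column:
                with open(file_path, newline="", encoding="utf-8") as f:
                    for i, row in enumerate(csv.DictReader(f)):
                        yield f"{os.path.basename(file_path)}_{i}", row[csv_column]
            else:
                with open(file_path, encoding="utf-8", errors="replace") as f:
                    yield os.path.basename(file_path), f.read()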

Comment on lines 174 to 175
if not os.path.exists(output_dataset_name):
dataset.save_to_disk(output_dataset_name)
Collaborator:

This is a bit surprising behavior. If the caller passes in an empty directory that exists on disk, instead of saving the generated dataset to that directory, it will discard the dataset (not save it) and then later fail with an error:

FileNotFoundError: Directory output_dataset is neither a `Dataset` directory nor a `DatasetDict` directory.

A few ideas on how to fix:

  1. Just overwrite the existing dataset with the new dataset (remove the if not os.path.exists() check)
  2. Instead of asking the user to pass the name of the dataset, just generate a unique dataset directory and use that, then inform the user where the dataset was generated.

Comment on lines 162 to 163
logger.info(f"Total files missed: {generation_state['files_missed']} out of {generation_state['total_files']}")
logger.info(f"Total files processed: {generation_state['total_files']}")
Collaborator:

These stats are really helpful! The term "files" is a bit misleading, but it's clear from the output that it can also represent rows.

If the dataset is empty however, we should throw an exception.

if not dataset:
    raise Exception("Failed to generate dataset")

This can happen when you have a small dataset and each chunk fails to generate a question/answer pair.

Collaborator Author:

I believe it will fail on its own:

>>> datasets.Dataset.from_list([])
Dataset({
    features: [],
    num_rows: 0
})
>>> a = datasets.Dataset.from_list([])
>>> a.save_to_disk('./my_dataset')
Saving the dataset (0/1 shards): 0 examples [00:00, ? examples/s]
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/home/xxx/Projects/trial/lib/python3.11/site-packages/datasets/arrow_dataset.py", line 1530, in save_to_disk
    for job_id, done, content in Dataset._save_to_disk_single(**kwargs):
  File "/home/xxx/Projects/trial/lib/python3.11/site-packages/datasets/arrow_dataset.py", line 1571, in _save_to_disk_single
    num_examples, num_bytes = writer.finalize()
                              ^^^^^^^^^^^^^^^^^
  File "/home/xxx/Projects/trial/lib/python3.11/site-packages/datasets/arrow_writer.py", line 599, in finalize
    raise SchemaInferenceError("Please pass `features` or at least one example when writing data")
datasets.arrow_writer.SchemaInferenceError: Please pass `features` or at least one example when writing data

but the message won't be clear so your recommendation stands

Collaborator:

Yeah I hit that error testing with a small dataset and it was a pretty cryptic failure message. The idea is just to make it a lot clearer when that happens.

TrainingArguments,
)
from trl import SFTTrainer # type: ignore[import]

Collaborator:

Can you add import wandb here? This will fail fast when wandb is not installed, rather than failing in the middle of a pipeline with this stacktrace:

  File "/opt/conda/lib/python3.10/site-packages/dalm/pipelines/reading_comprehension_pipeline.py", line 180, in pipeline
    train_generator(
  File "/opt/conda/lib/python3.10/site-packages/dalm/training/generator_only/trainer.py", line 240, in train_generator
    trainer = SFTTrainer(
  File "/opt/conda/lib/python3.10/site-packages/trl/trainer/sft_trainer.py", line 252, in __init__
    super().__init__(
  File "/opt/conda/lib/python3.10/site-packages/transformers/trainer.py", line 525, in __init__
    self.callback_handler = CallbackHandler(
  File "/opt/conda/lib/python3.10/site-packages/transformers/trainer_callback.py", line 306, in __init__
    self.add_callback(cb)
  File "/opt/conda/lib/python3.10/site-packages/transformers/trainer_callback.py", line 323, in add_callback
    cb = callback() if isinstance(callback, type) else callback
  File "/opt/conda/lib/python3.10/site-packages/transformers/integrations/integration_utils.py", line 669, in __init__
    raise RuntimeError("WandbCallback requires wandb to be installed. Run `pip install wandb`.")
RuntimeError: WandbCallback requires wandb to be installed. Run `pip install wandb`.

Collaborator Author:

I think that, given the number of tracker options that exist, separating ourselves from any particular tracker may be the best way forward.

Collaborator:

Yeah, can we just default to no tracking? Then users can add it as needed, including libs

- remove default wandb value
- new logging statement
- change of terminology for logging (files -> texts)
@tleyden tleyden (Collaborator) left a comment

Looking great overall! In my testing I'm close to having it working end-to-end via calling the python API functions.

The main change I think we still need is to disable wandb by default (as discussed)

A few other changes I'd suggest (but aren't necessarily blockers)

  1. Default to SynthMode.LLM - but that's up for debate. What do folks think? @shamanez @Jacobsolawetz
  2. It does seem odd to me to not mention any of this new code in the main README of the repo. We can always add this later in a follow-up PR, but it seems like a good time to add it. Thoughts?

@@ -0,0 +1,1266 @@
# Modified version of code from https://github.com/microsoft/LMOps/blob/main/adaptllm/utils/read.py
Collaborator:

👍

(Same, as above i.e assuming you have your dataset as a csv file with the column `text` containing the raw texts)

Please note there is the choice of passing in a domain sentence model in addition, but this is not required as
the script will train a domain speicifc sentencepiece model on the input corpus
Collaborator:

Typo: should be "specific"

```

the output directory serves as a temporary holding place of all generated data before it can be made a dataset.
The generation process usually takes time. SO every step is taken to ensure if the process is interrupted, once back running
@tleyden tleyden (Collaborator) commented Nov 29, 2023:

The generation process usually takes time

I would change this to a concrete estimate like: "when using synthetic data generation, it takes approximately 10 minutes for each 100 rows of data you have". Even if it's not 100% accurate it's a lot more helpful than "takes some time".

SO every step is taken to ensure if the process is interrupted, once back running

A more succinct way to state this: "The script was designed to be idempotent"

Collaborator Author:

A more succinct way to state this: "The script was designed to be idempotent"

I think idempotency is a property of the current script/function, while state tracking to ensure we can resume where we left off is a connected but distinct concept. I am unconvinced that the recommended statement works here.

The way espoused by the paper is generating reading comprehension questions and answers based on the raw corpora
and training a llm on said generated dataset can enhance its domain adaptiveness

We have two ways of generating reading comprehension data
Collaborator:

There is an extra space between generating and reading

dalm/pipelines/reading_comprehension_pipeline.py (two outdated, resolved comment threads)


if __name__ == "__main__":
logger.setLevel(logging.INFO)
Collaborator:

Should already be being set (see earlier comment)

"If local, be sure to set the local_dataset flag"
),
)
parser.add_argument("--local_dataset", action="store_true", help="whether to use a local dataset")
Collaborator:

Just curious: is this the normal way of doing things in HF?

I thought the normal hugging face approach is to "just figure it out". Eg, try to load a local dataset, then fallback to loading from HF hub (or vice-versa).

Collaborator Author:

Somehow that isn't the case, and I don't understand why. HF datasets errors out and asks you to switch methods to load a local dataset. This param is just surfacing that.



if __name__ == "__main__":
logging.basicConfig(level=logging.INFO)
Collaborator:

Should not be needed (see earlier comments)

--input input.csv --csv_column text \
--output_dataset_name combined \
--general_spm_path tokenizers/general.spm \
--llm_synth_model_name meta-llama/Llama-2-13b \
Collaborator:

Shouldn't this be meta-llama/Llama-2-13b-chat-hf? In order to:

  • Use the RLHF'd chat model instead of the raw pretrained model
  • Use the HF variation to avoid errors like this

@tleyden tleyden (Collaborator) commented Nov 30, 2023

The main change I think we still need is to disable wandb by default (as discussed)

@metric-space I just realized the problem. In the HF TrainingArgs docs:

        report_to (`str` or `List[str]`, *optional*, defaults to `"all"`):
            The list of integrations to report the results and logs to. Supported platforms are `"azure_ml"`,
            `"clearml"`, `"codecarbon"`, `"comet_ml"`, `"dagshub"`, `"flyte"`, `"mlflow"`, `"neptune"`,
            `"tensorboard"`, and `"wandb"`. Use `"all"` to report to all integrations installed, `"none"` for no
            integrations.

And the training container currently has wandb installed, so it picks it up and tries to use it.

I think what's happening is that log_with defaults to an empty string:

    parser.add_argument(
        "--log_with",
        type=str,
        help="tracker backend to be used",
    )

And then we pass an empty string to report_to, which defaults to all. It's a bit surprising behavior. I would suggest we just default --log_with to "none" and let users explicitly enable the tracker they want, if any.
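i.e. something like this (sketch), so that report_to only gets a real backend when the user explicitly asks for one:

    parser.add_argument(
        "--log_with",
        type=str,
        default="none",
        help="tracker backend to report to ('wandb', 'tensorboard', ...); 'none' disables reporting",
    )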

@metric-space metric-space (Collaborator Author) commented:

@tleyden @Jacobsolawetz the main README.md is already long enough and has a certain theme to it, and a new deluge of reading-comprehension material would clash with it. I think a re-org is necessary, and that seems like a task in itself.

@tleyden tleyden (Collaborator) left a comment

LGTM! 🚀

I retested the latest round of changes and everything is working.

I will drive a follow-up PR to address #78

@tleyden tleyden merged commit 567c910 into main Dec 4, 2023
1 check passed
@tleyden tleyden deleted the DA-24/reading-comprehension branch December 4, 2023 11:39
@tleyden tleyden (Collaborator) commented Dec 4, 2023

I think a re-org is necessary and to me seems like a task in itself

Yeah, agreed. I filed a tracking ticket for that: #79
