diff --git a/dalm/__init__.py b/dalm/__init__.py index 27fdca4..81f0fde 100644 --- a/dalm/__init__.py +++ b/dalm/__init__.py @@ -1 +1 @@ -__version__ = "0.0.3" +__version__ = "0.0.4" diff --git a/dalm/cli.py b/dalm/cli.py index 6b5ceae..fcda96b 100644 --- a/dalm/cli.py +++ b/dalm/cli.py @@ -51,11 +51,9 @@ def train_rag_e2e( help="Path to pretrained (causal) generator or identifier from huggingface.co/models.", show_default=False ), ], - dataset_passage_col_name: Annotated[ - str, typer.Option(help="Name of the column containing the passage") - ] = "Abstract", - dataset_query_col_name: Annotated[str, typer.Option(help="Name of the column containing the query")] = "Question", - dataset_answer_col_name: Annotated[str, typer.Option(help="Name of the column containing the Answer")] = "Answer", + passage_column_name: Annotated[str, typer.Option(help="Name of the column containing the passage")] = "Abstract", + query_column_name: Annotated[str, typer.Option(help="Name of the column containing the query")] = "Question", + answer_column_name: Annotated[str, typer.Option(help="Name of the column containing the Answer")] = "Answer", query_max_len: Annotated[ int, typer.Option(help="The max query sequence length during tokenization. Longer sequences are truncated") ] = 50, @@ -129,9 +127,9 @@ def train_rag_e2e( dataset_or_path=dataset_path, retriever_name_or_path=retriever_name_or_path, generator_name_or_path=generator_name_or_path, - dataset_passage_col_name=dataset_passage_col_name, - dataset_query_col_name=dataset_query_col_name, - dataset_answer_col_name=dataset_answer_col_name, + passage_column_name=passage_column_name, + query_column_name=query_column_name, + answer_column_name=answer_column_name, query_max_len=query_max_len, passage_max_len=passage_max_len, generator_max_len=generator_max_len, @@ -169,10 +167,8 @@ def train_retriever_only( show_default=False, ), ], - dataset_passage_col_name: Annotated[ - str, typer.Option(help="Name of the column containing the passage") - ] = "Abstract", - dataset_query_col_name: Annotated[str, typer.Option(help="Name of the column containing the query")] = "Question", + passage_column_name: Annotated[str, typer.Option(help="Name of the column containing the passage")] = "Abstract", + query_column_name: Annotated[str, typer.Option(help="Name of the column containing the query")] = "Question", query_max_len: Annotated[ int, typer.Option(help="The max query sequence length during tokenization. Longer sequences are truncated") ] = 50, @@ -239,8 +235,8 @@ def train_retriever_only( train_retriever( dataset_or_path=dataset_path, retriever_name_or_path=retriever_name_or_path, - dataset_passage_col_name=dataset_passage_col_name, - dataset_query_col_name=dataset_query_col_name, + passage_column_name=passage_column_name, + query_column_name=query_column_name, query_max_len=query_max_len, passage_max_len=passage_max_len, per_device_train_batch_size=per_device_train_batch_size, diff --git a/dalm/eval/eval_rag.py b/dalm/eval/eval_rag.py index 895161c..c6065a3 100644 --- a/dalm/eval/eval_rag.py +++ b/dalm/eval/eval_rag.py @@ -162,8 +162,8 @@ def main() -> None: lambda example: preprocess_function( example, retriever_tokenizer, - query_col_name=args.query_column_name, - passage_col_name=args.passage_column_name, + query_column_name=args.query_column_name, + passage_column_name=args.passage_column_name, ), batched=True, # remove_columns=test_dataset.column_names, diff --git a/dalm/eval/eval_retriever_only.py b/dalm/eval/eval_retriever_only.py index d100114..c80093f 100644 --- a/dalm/eval/eval_retriever_only.py +++ b/dalm/eval/eval_retriever_only.py @@ -121,8 +121,8 @@ def main() -> None: lambda example: preprocess_function( example, retriever_tokenizer, - query_col_name=args.query_column_name, - passage_col_name=args.passage_column_name, + query_column_name=args.query_column_name, + passage_column_name=args.passage_column_name, ), batched=True, # remove_columns=test_dataset.column_names, diff --git a/dalm/eval/utils.py b/dalm/eval/utils.py index 4d2d1f6..d24a378 100644 --- a/dalm/eval/utils.py +++ b/dalm/eval/utils.py @@ -78,11 +78,11 @@ def calculate_precision_recall(retrieved_items: List, correct_items: List) -> Tu def preprocess_function( examples: LazyBatch, retriever_tokenizer: PreTrainedTokenizer, - query_col_name: str = "query", - passage_col_name: str = "passage", + query_column_name: str = "query", + passage_column_name: str = "passage", ) -> Dict[str, torch.Tensor]: - queries = examples[query_col_name] - passages = examples[passage_col_name] + queries = examples[query_column_name] + passages = examples[passage_column_name] # Tokenization for the retriever retriever_query_tokens = retriever_tokenizer(queries, padding="max_length", max_length=128, truncation=True) diff --git a/dalm/training/rag_e2e/train_rage2e.py b/dalm/training/rag_e2e/train_rage2e.py index 160c562..264b4da 100644 --- a/dalm/training/rag_e2e/train_rage2e.py +++ b/dalm/training/rag_e2e/train_rage2e.py @@ -60,13 +60,13 @@ def parse_args() -> Namespace: help=("Dataset path. Can be a huggingface dataset directory or a csv file."), ) parser.add_argument( - "--dataset_passage_col_name", type=str, default="Abstract", help="Name of the column containing the passage" + "--passage_column_name", type=str, default="Abstract", help="Name of the column containing the passage" ) parser.add_argument( - "--dataset_query_col_name", type=str, default="Question", help="Name of the column containing the query" + "--query_column_name", type=str, default="Question", help="Name of the column containing the query" ) parser.add_argument( - "--dataset_answer_col_name", type=str, default="Answer", help="Name of the column containing the answer" + "--answer_column_name", type=str, default="Answer", help="Name of the column containing the answer" ) parser.add_argument( "--query_max_len", @@ -217,9 +217,9 @@ def train_e2e( dataset_or_path: str | Dataset, retriever_name_or_path: str, generator_name_or_path: str, - dataset_passage_col_name: str = "Abstract", - dataset_query_col_name: str = "Question", - dataset_answer_col_name: str = "Answer", + passage_column_name: str = "Abstract", + query_column_name: str = "Question", + answer_column_name: str = "Answer", query_max_len: int = 50, passage_max_len: int = 128, generator_max_len: int = 256, @@ -295,9 +295,9 @@ def train_e2e( example, retriever_tokenizer=rag_model.retriever_tokenizer, generator_tokenizer=rag_model.generator_tokenizer, - dataset_query_col_name=dataset_query_col_name, - dataset_passage_col_name=dataset_passage_col_name, - dataset_answer_col_name=dataset_answer_col_name, + query_column_name=query_column_name, + passage_column_name=passage_column_name, + answer_column_name=answer_column_name, query_max_len=query_max_len, passage_max_len=passage_max_len, generator_max_len=generator_max_len, @@ -523,9 +523,9 @@ def main() -> None: dataset_or_path=args.dataset_path, retriever_name_or_path=args.retriever_name_or_path, generator_name_or_path=args.generator_name_or_path, - dataset_passage_col_name=args.dataset_passage_col_name, - dataset_query_col_name=args.dataset_query_col_name, - dataset_answer_col_name=args.dataset_answer_col_name, + passage_column_name=args.passage_column_name, + query_column_name=args.query_column_name, + answer_column_name=args.answer_column_name, query_max_len=args.query_max_len, passage_max_len=args.passage_max_len, generator_max_len=args.generator_max_len, diff --git a/dalm/training/retriever_only/train_retriever_only.py b/dalm/training/retriever_only/train_retriever_only.py index fafca84..ad78261 100644 --- a/dalm/training/retriever_only/train_retriever_only.py +++ b/dalm/training/retriever_only/train_retriever_only.py @@ -44,10 +44,10 @@ def parse_args() -> Namespace: parser = argparse.ArgumentParser(description="training a PEFT model for Sematic Search task") parser.add_argument("--dataset_path", type=str, default=None, help="dataset path in the local dir") parser.add_argument( - "--dataset_query_col_name", type=str, default="Question", help="Name of the query column in the dataset" + "--query_column_name", type=str, default="Question", help="Name of the query column in the dataset" ) parser.add_argument( - "--dataset_passage_col_name", type=str, default="Abstract", help="Name of the passage column in the dataset" + "--passage_column_name", type=str, default="Abstract", help="Name of the passage column in the dataset" ) parser.add_argument( "--query_max_len", @@ -165,8 +165,8 @@ def parse_args() -> Namespace: def train_retriever( retriever_name_or_path: str, dataset_or_path: str | Dataset, - dataset_passage_col_name: str = "Abstract", - dataset_query_col_name: str = "Question", + passage_column_name: str = "Abstract", + query_column_name: str = "Question", query_max_len: int = 50, passage_max_len: int = 128, per_device_train_batch_size: int = 32, @@ -236,8 +236,8 @@ def train_retriever( lambda example: preprocess_dataset( example, tokenizer, - query_col_name=dataset_query_col_name, - passage_col_name=dataset_passage_col_name, + query_column_name=query_column_name, + passage_column_name=passage_column_name, query_max_len=query_max_len, passage_max_len=passage_max_len, ), @@ -418,8 +418,8 @@ def main() -> None: train_retriever( dataset_or_path=args.dataset_path, retriever_name_or_path=args.retriever_name_or_path, - dataset_passage_col_name=args.dataset_passage_col_name, - dataset_query_col_name=args.dataset_query_col_name, + passage_column_name=args.passage_column_name, + query_column_name=args.query_column_name, query_max_len=args.query_max_len, passage_max_len=args.passage_max_len, per_device_train_batch_size=args.per_device_train_batch_size, diff --git a/dalm/training/utils/rag_e2e_dataloader_utils.py b/dalm/training/utils/rag_e2e_dataloader_utils.py index bf58011..6f62ad4 100644 --- a/dalm/training/utils/rag_e2e_dataloader_utils.py +++ b/dalm/training/utils/rag_e2e_dataloader_utils.py @@ -8,16 +8,16 @@ def preprocess_dataset( examples: LazyBatch, retriever_tokenizer: PreTrainedTokenizer | PreTrainedTokenizerFast, generator_tokenizer: PreTrainedTokenizer | PreTrainedTokenizerFast, - dataset_query_col_name: str, - dataset_passage_col_name: str, - dataset_answer_col_name: str, + query_column_name: str, + passage_column_name: str, + answer_column_name: str, query_max_len: int, passage_max_len: int, generator_max_len: int, ) -> Dict[str, Any]: - querie_list = examples[dataset_query_col_name] - passage_list = examples[dataset_passage_col_name] - answers = examples[dataset_answer_col_name] + querie_list = examples[query_column_name] + passage_list = examples[passage_column_name] + answers = examples[answer_column_name] queries = [f"#query# {query}" for query in querie_list] passages = [f"#passage# {passage}" for passage in passage_list] diff --git a/dalm/training/utils/retriever_only_dataloader_utils.py b/dalm/training/utils/retriever_only_dataloader_utils.py index 9e2da0d..9f21601 100644 --- a/dalm/training/utils/retriever_only_dataloader_utils.py +++ b/dalm/training/utils/retriever_only_dataloader_utils.py @@ -8,17 +8,17 @@ def preprocess_dataset( examples: LazyBatch, tokenizer: PreTrainedTokenizer | PreTrainedTokenizerFast, - query_col_name: str, - passage_col_name: str, + query_column_name: str, + passage_column_name: str, query_max_len: int, passage_max_len: int, ) -> Dict[str, torch.Tensor]: - query_list = examples[query_col_name] + query_list = examples[query_column_name] queries = [f"#query# {query}" for query in query_list] result_ = tokenizer(queries, padding="max_length", max_length=query_max_len, truncation=True) result_ = {f"query_{k}": v for k, v in result_.items()} - passage_list = examples[passage_col_name] + passage_list = examples[passage_column_name] passages = [f"#passage# {passage}" for passage in passage_list] result_passage = tokenizer(passages, padding="max_length", max_length=passage_max_len, truncation=True) for k, v in result_passage.items(): diff --git a/experiments/llama-index-10k/README.md b/experiments/llama-index-10k/README.md index 85141f2..a33c019 100644 --- a/experiments/llama-index-10k/README.md +++ b/experiments/llama-index-10k/README.md @@ -48,7 +48,7 @@ dalm train-rag-e2e \ "qa-outputs/question_answer_pairs.csv" \ "BAAI/bge-small-en" \ "meta-llama/Llama-2-7b-hf" \ ---dataset-passage-col-name text \ +--passage-column-name text \ --output-dir "rag_e2e_checkpoints_bgsmall" \ --no-with-tracking \ --per-device-train-batch-size 12 @@ -74,7 +74,7 @@ Train the retriever only dalm train-retriever-only "BAAI/bge-small-en" "qa-outputs/question_answer_pairs.csv" \ --output-dir "retriever_only_checkpoints_bgsmall" \ --use-peft \ ---dataset-passage-col-name text \ +--passage-column-name text \ --per-device-train-batch-size 150 ```