Skip to content

Commit

Permalink
align param names
Browse files Browse the repository at this point in the history
  • Loading branch information
Ben-Epstein committed Sep 21, 2023
1 parent 865cff3 commit ec6c627
Show file tree
Hide file tree
Showing 10 changed files with 51 additions and 55 deletions.
2 changes: 1 addition & 1 deletion dalm/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = "0.0.3"
__version__ = "0.0.4"
24 changes: 10 additions & 14 deletions dalm/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,11 +51,9 @@ def train_rag_e2e(
help="Path to pretrained (causal) generator or identifier from huggingface.co/models.", show_default=False
),
],
dataset_passage_col_name: Annotated[
str, typer.Option(help="Name of the column containing the passage")
] = "Abstract",
dataset_query_col_name: Annotated[str, typer.Option(help="Name of the column containing the query")] = "Question",
dataset_answer_col_name: Annotated[str, typer.Option(help="Name of the column containing the Answer")] = "Answer",
passage_column_name: Annotated[str, typer.Option(help="Name of the column containing the passage")] = "Abstract",
query_column_name: Annotated[str, typer.Option(help="Name of the column containing the query")] = "Question",
answer_column_name: Annotated[str, typer.Option(help="Name of the column containing the Answer")] = "Answer",
query_max_len: Annotated[
int, typer.Option(help="The max query sequence length during tokenization. Longer sequences are truncated")
] = 50,
Expand Down Expand Up @@ -129,9 +127,9 @@ def train_rag_e2e(
dataset_or_path=dataset_path,
retriever_name_or_path=retriever_name_or_path,
generator_name_or_path=generator_name_or_path,
dataset_passage_col_name=dataset_passage_col_name,
dataset_query_col_name=dataset_query_col_name,
dataset_answer_col_name=dataset_answer_col_name,
passage_column_name=passage_column_name,
query_column_name=query_column_name,
answer_column_name=answer_column_name,
query_max_len=query_max_len,
passage_max_len=passage_max_len,
generator_max_len=generator_max_len,
Expand Down Expand Up @@ -169,10 +167,8 @@ def train_retriever_only(
show_default=False,
),
],
dataset_passage_col_name: Annotated[
str, typer.Option(help="Name of the column containing the passage")
] = "Abstract",
dataset_query_col_name: Annotated[str, typer.Option(help="Name of the column containing the query")] = "Question",
passage_column_name: Annotated[str, typer.Option(help="Name of the column containing the passage")] = "Abstract",
query_column_name: Annotated[str, typer.Option(help="Name of the column containing the query")] = "Question",
query_max_len: Annotated[
int, typer.Option(help="The max query sequence length during tokenization. Longer sequences are truncated")
] = 50,
Expand Down Expand Up @@ -239,8 +235,8 @@ def train_retriever_only(
train_retriever(
dataset_or_path=dataset_path,
retriever_name_or_path=retriever_name_or_path,
dataset_passage_col_name=dataset_passage_col_name,
dataset_query_col_name=dataset_query_col_name,
passage_column_name=passage_column_name,
query_column_name=query_column_name,
query_max_len=query_max_len,
passage_max_len=passage_max_len,
per_device_train_batch_size=per_device_train_batch_size,
Expand Down
4 changes: 2 additions & 2 deletions dalm/eval/eval_rag.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,8 +162,8 @@ def main() -> None:
lambda example: preprocess_function(
example,
retriever_tokenizer,
query_col_name=args.query_column_name,
passage_col_name=args.passage_column_name,
query_column_name=args.query_column_name,
passage_column_name=args.passage_column_name,
),
batched=True,
# remove_columns=test_dataset.column_names,
Expand Down
4 changes: 2 additions & 2 deletions dalm/eval/eval_retriever_only.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,8 +121,8 @@ def main() -> None:
lambda example: preprocess_function(
example,
retriever_tokenizer,
query_col_name=args.query_column_name,
passage_col_name=args.passage_column_name,
query_column_name=args.query_column_name,
passage_column_name=args.passage_column_name,
),
batched=True,
# remove_columns=test_dataset.column_names,
Expand Down
8 changes: 4 additions & 4 deletions dalm/eval/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,11 +78,11 @@ def calculate_precision_recall(retrieved_items: List, correct_items: List) -> Tu
def preprocess_function(
examples: LazyBatch,
retriever_tokenizer: PreTrainedTokenizer,
query_col_name: str = "query",
passage_col_name: str = "passage",
query_column_name: str = "query",
passage_column_name: str = "passage",
) -> Dict[str, torch.Tensor]:
queries = examples[query_col_name]
passages = examples[passage_col_name]
queries = examples[query_column_name]
passages = examples[passage_column_name]

# Tokenization for the retriever
retriever_query_tokens = retriever_tokenizer(queries, padding="max_length", max_length=128, truncation=True)
Expand Down
24 changes: 12 additions & 12 deletions dalm/training/rag_e2e/train_rage2e.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,13 +60,13 @@ def parse_args() -> Namespace:
help=("Dataset path. Can be a huggingface dataset directory or a csv file."),
)
parser.add_argument(
"--dataset_passage_col_name", type=str, default="Abstract", help="Name of the column containing the passage"
"--passage_column_name", type=str, default="Abstract", help="Name of the column containing the passage"
)
parser.add_argument(
"--dataset_query_col_name", type=str, default="Question", help="Name of the column containing the query"
"--query_column_name", type=str, default="Question", help="Name of the column containing the query"
)
parser.add_argument(
"--dataset_answer_col_name", type=str, default="Answer", help="Name of the column containing the answer"
"--answer_column_name", type=str, default="Answer", help="Name of the column containing the answer"
)
parser.add_argument(
"--query_max_len",
Expand Down Expand Up @@ -217,9 +217,9 @@ def train_e2e(
dataset_or_path: str | Dataset,
retriever_name_or_path: str,
generator_name_or_path: str,
dataset_passage_col_name: str = "Abstract",
dataset_query_col_name: str = "Question",
dataset_answer_col_name: str = "Answer",
passage_column_name: str = "Abstract",
query_column_name: str = "Question",
answer_column_name: str = "Answer",
query_max_len: int = 50,
passage_max_len: int = 128,
generator_max_len: int = 256,
Expand Down Expand Up @@ -295,9 +295,9 @@ def train_e2e(
example,
retriever_tokenizer=rag_model.retriever_tokenizer,
generator_tokenizer=rag_model.generator_tokenizer,
dataset_query_col_name=dataset_query_col_name,
dataset_passage_col_name=dataset_passage_col_name,
dataset_answer_col_name=dataset_answer_col_name,
query_column_name=query_column_name,
passage_column_name=passage_column_name,
answer_column_name=answer_column_name,
query_max_len=query_max_len,
passage_max_len=passage_max_len,
generator_max_len=generator_max_len,
Expand Down Expand Up @@ -523,9 +523,9 @@ def main() -> None:
dataset_or_path=args.dataset_path,
retriever_name_or_path=args.retriever_name_or_path,
generator_name_or_path=args.generator_name_or_path,
dataset_passage_col_name=args.dataset_passage_col_name,
dataset_query_col_name=args.dataset_query_col_name,
dataset_answer_col_name=args.dataset_answer_col_name,
passage_column_name=args.passage_column_name,
query_column_name=args.query_column_name,
answer_column_name=args.answer_column_name,
query_max_len=args.query_max_len,
passage_max_len=args.passage_max_len,
generator_max_len=args.generator_max_len,
Expand Down
16 changes: 8 additions & 8 deletions dalm/training/retriever_only/train_retriever_only.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,10 +44,10 @@ def parse_args() -> Namespace:
parser = argparse.ArgumentParser(description="training a PEFT model for Sematic Search task")
parser.add_argument("--dataset_path", type=str, default=None, help="dataset path in the local dir")
parser.add_argument(
"--dataset_query_col_name", type=str, default="Question", help="Name of the query column in the dataset"
"--query_column_name", type=str, default="Question", help="Name of the query column in the dataset"
)
parser.add_argument(
"--dataset_passage_col_name", type=str, default="Abstract", help="Name of the passage column in the dataset"
"--passage_column_name", type=str, default="Abstract", help="Name of the passage column in the dataset"
)
parser.add_argument(
"--query_max_len",
Expand Down Expand Up @@ -165,8 +165,8 @@ def parse_args() -> Namespace:
def train_retriever(
retriever_name_or_path: str,
dataset_or_path: str | Dataset,
dataset_passage_col_name: str = "Abstract",
dataset_query_col_name: str = "Question",
passage_column_name: str = "Abstract",
query_column_name: str = "Question",
query_max_len: int = 50,
passage_max_len: int = 128,
per_device_train_batch_size: int = 32,
Expand Down Expand Up @@ -236,8 +236,8 @@ def train_retriever(
lambda example: preprocess_dataset(
example,
tokenizer,
query_col_name=dataset_query_col_name,
passage_col_name=dataset_passage_col_name,
query_column_name=query_column_name,
passage_column_name=passage_column_name,
query_max_len=query_max_len,
passage_max_len=passage_max_len,
),
Expand Down Expand Up @@ -418,8 +418,8 @@ def main() -> None:
train_retriever(
dataset_or_path=args.dataset_path,
retriever_name_or_path=args.retriever_name_or_path,
dataset_passage_col_name=args.dataset_passage_col_name,
dataset_query_col_name=args.dataset_query_col_name,
passage_column_name=args.passage_column_name,
query_column_name=args.query_column_name,
query_max_len=args.query_max_len,
passage_max_len=args.passage_max_len,
per_device_train_batch_size=args.per_device_train_batch_size,
Expand Down
12 changes: 6 additions & 6 deletions dalm/training/utils/rag_e2e_dataloader_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,16 +8,16 @@ def preprocess_dataset(
examples: LazyBatch,
retriever_tokenizer: PreTrainedTokenizer | PreTrainedTokenizerFast,
generator_tokenizer: PreTrainedTokenizer | PreTrainedTokenizerFast,
dataset_query_col_name: str,
dataset_passage_col_name: str,
dataset_answer_col_name: str,
query_column_name: str,
passage_column_name: str,
answer_column_name: str,
query_max_len: int,
passage_max_len: int,
generator_max_len: int,
) -> Dict[str, Any]:
querie_list = examples[dataset_query_col_name]
passage_list = examples[dataset_passage_col_name]
answers = examples[dataset_answer_col_name]
querie_list = examples[query_column_name]
passage_list = examples[passage_column_name]
answers = examples[answer_column_name]

queries = [f"#query# {query}" for query in querie_list]
passages = [f"#passage# {passage}" for passage in passage_list]
Expand Down
8 changes: 4 additions & 4 deletions dalm/training/utils/retriever_only_dataloader_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,17 +8,17 @@
def preprocess_dataset(
examples: LazyBatch,
tokenizer: PreTrainedTokenizer | PreTrainedTokenizerFast,
query_col_name: str,
passage_col_name: str,
query_column_name: str,
passage_column_name: str,
query_max_len: int,
passage_max_len: int,
) -> Dict[str, torch.Tensor]:
query_list = examples[query_col_name]
query_list = examples[query_column_name]
queries = [f"#query# {query}" for query in query_list]
result_ = tokenizer(queries, padding="max_length", max_length=query_max_len, truncation=True)
result_ = {f"query_{k}": v for k, v in result_.items()}

passage_list = examples[passage_col_name]
passage_list = examples[passage_column_name]
passages = [f"#passage# {passage}" for passage in passage_list]
result_passage = tokenizer(passages, padding="max_length", max_length=passage_max_len, truncation=True)
for k, v in result_passage.items():
Expand Down
4 changes: 2 additions & 2 deletions experiments/llama-index-10k/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ dalm train-rag-e2e \
"qa-outputs/question_answer_pairs.csv" \
"BAAI/bge-small-en" \
"meta-llama/Llama-2-7b-hf" \
--dataset-passage-col-name text \
--passage-column-name text \
--output-dir "rag_e2e_checkpoints_bgsmall" \
--no-with-tracking \
--per-device-train-batch-size 12
Expand All @@ -74,7 +74,7 @@ Train the retriever only
dalm train-retriever-only "BAAI/bge-small-en" "qa-outputs/question_answer_pairs.csv" \
--output-dir "retriever_only_checkpoints_bgsmall" \
--use-peft \
--dataset-passage-col-name text \
--passage-column-name text \
--per-device-train-batch-size 150
```

Expand Down

0 comments on commit ec6c627

Please sign in to comment.