Skip to content

Commit

Permalink
Merge pull request #52 from arcee-ai/align-param-names
Browse files Browse the repository at this point in the history
align names to rage2e params
  • Loading branch information
Ben-Epstein authored Sep 21, 2023
2 parents f938906 + ec6c627 commit 73fc83e
Show file tree
Hide file tree
Showing 12 changed files with 69 additions and 73 deletions.
4 changes: 2 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -123,13 +123,13 @@ To run retriever only eval
(make sure you have the checkpoints in the project root)

```bash
python dalm/eval/eval_retriever_only.py --dataset_path qa_pairs_test.csv --retriever_model_name_or_path "BAAI/bge-large-en" --passage_column_name Abstract --query_column_name Question --retriever_peft_model_path retriever_only_checkpoints
python dalm/eval/eval_retriever_only.py --dataset_path qa_pairs_test.csv --retriever_name_or_path "BAAI/bge-large-en" --passage_column_name Abstract --query_column_name Question --retriever_peft_model_path retriever_only_checkpoints
```

For the e2e eval

```bash
python dalm/eval/eval_rag.py --dataset_path qa_pairs_test_2.csv --retriever_model_name_or_path "BAAI/bge-large-en" --generator_model_name_or_path "meta-llama/Llama-2-7b-hf" --passage_column_name Abstract --query_column_name Question --answer_column_name Answer --evaluate_generator --query_batch_size 5 --retriever_peft_model_path rag_e2e_checkpoints/retriever --generator_peft_model_path rag_e2e_checkpoints/generator
python dalm/eval/eval_rag.py --dataset_path qa_pairs_test_2.csv --retriever_name_or_path "BAAI/bge-large-en" --generator_model_name_or_path "meta-llama/Llama-2-7b-hf" --passage_column_name Abstract --query_column_name Question --answer_column_name Answer --evaluate_generator --query_batch_size 5 --retriever_peft_model_path rag_e2e_checkpoints/retriever --generator_peft_model_path rag_e2e_checkpoints/generator
```


Expand Down
2 changes: 1 addition & 1 deletion dalm/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = "0.0.3"
__version__ = "0.0.4"
28 changes: 12 additions & 16 deletions dalm/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,11 +51,9 @@ def train_rag_e2e(
help="Path to pretrained (causal) generator or identifier from huggingface.co/models.", show_default=False
),
],
dataset_passage_col_name: Annotated[
str, typer.Option(help="Name of the column containing the passage")
] = "Abstract",
dataset_query_col_name: Annotated[str, typer.Option(help="Name of the column containing the query")] = "Question",
dataset_answer_col_name: Annotated[str, typer.Option(help="Name of the column containing the Answer")] = "Answer",
passage_column_name: Annotated[str, typer.Option(help="Name of the column containing the passage")] = "Abstract",
query_column_name: Annotated[str, typer.Option(help="Name of the column containing the query")] = "Question",
answer_column_name: Annotated[str, typer.Option(help="Name of the column containing the Answer")] = "Answer",
query_max_len: Annotated[
int, typer.Option(help="The max query sequence length during tokenization. Longer sequences are truncated")
] = 50,
Expand Down Expand Up @@ -129,9 +127,9 @@ def train_rag_e2e(
dataset_or_path=dataset_path,
retriever_name_or_path=retriever_name_or_path,
generator_name_or_path=generator_name_or_path,
dataset_passage_col_name=dataset_passage_col_name,
dataset_query_col_name=dataset_query_col_name,
dataset_answer_col_name=dataset_answer_col_name,
passage_column_name=passage_column_name,
query_column_name=query_column_name,
answer_column_name=answer_column_name,
query_max_len=query_max_len,
passage_max_len=passage_max_len,
generator_max_len=generator_max_len,
Expand Down Expand Up @@ -159,7 +157,7 @@ def train_rag_e2e(

@cli.command()
def train_retriever_only(
model_name_or_path: Annotated[
retriever_name_or_path: Annotated[
str, typer.Argument(help="Path to the model or identifier from huggingface.co/models.", show_default=False)
],
dataset_path: Annotated[
Expand All @@ -169,10 +167,8 @@ def train_retriever_only(
show_default=False,
),
],
dataset_passage_col_name: Annotated[
str, typer.Option(help="Name of the column containing the passage")
] = "Abstract",
dataset_query_col_name: Annotated[str, typer.Option(help="Name of the column containing the query")] = "Question",
passage_column_name: Annotated[str, typer.Option(help="Name of the column containing the passage")] = "Abstract",
query_column_name: Annotated[str, typer.Option(help="Name of the column containing the query")] = "Question",
query_max_len: Annotated[
int, typer.Option(help="The max query sequence length during tokenization. Longer sequences are truncated")
] = 50,
Expand Down Expand Up @@ -238,9 +234,9 @@ def train_retriever_only(
"""End-to-end train an in-domain model, including the retriever and generator"""
train_retriever(
dataset_or_path=dataset_path,
model_name_or_path=model_name_or_path,
dataset_passage_col_name=dataset_passage_col_name,
dataset_query_col_name=dataset_query_col_name,
retriever_name_or_path=retriever_name_or_path,
passage_column_name=passage_column_name,
query_column_name=query_column_name,
query_max_len=query_max_len,
passage_max_len=passage_max_len,
per_device_train_batch_size=per_device_train_batch_size,
Expand Down
4 changes: 2 additions & 2 deletions dalm/eval/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,11 @@ To run retriever only eval
(make sure you have the checkpoints in the project root)

```bash
python dalm/eval/eval_retriever_only.py --dataset_path qa_pairs_test.csv --retriever_model_name_or_path "BAAI/bge-large-en" --passage_column_name Abstract --query_column_name Question --retriever_peft_model_path retriever_only_checkpoints
python dalm/eval/eval_retriever_only.py --dataset_path qa_pairs_test.csv --retriever_name_or_path "BAAI/bge-large-en" --passage_column_name Abstract --query_column_name Question --retriever_peft_model_path retriever_only_checkpoints
```

For the e2e eval

```bash
python dalm/eval/eval_rag.py --dataset_path qa_pairs_test_2.csv --retriever_model_name_or_path "BAAI/bge-large-en" --generator_model_name_or_path "meta-llama/Llama-2-7b-hf" --passage_column_name Abstract --query_column_name Question --answer_column_name Answer --evaluate_generator --query_batch_size 5 --retriever_peft_model_path retriever_only_checkpoints --generator_peft_model_path generator_only_checkpoints
python dalm/eval/eval_rag.py --dataset_path qa_pairs_test_2.csv --retriever_name_or_path "BAAI/bge-large-en" --generator_name_or_path "meta-llama/Llama-2-7b-hf" --passage_column_name Abstract --query_column_name Question --answer_column_name Answer --evaluate_generator --query_batch_size 5 --retriever_peft_model_path retriever_only_checkpoints --generator_peft_model_path generator_only_checkpoints
```
10 changes: 5 additions & 5 deletions dalm/eval/eval_rag.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,13 +55,13 @@ def parse_args() -> Namespace:
),
)
parser.add_argument(
"--retriever_model_name_or_path",
"--retriever_name_or_path",
type=str,
help="Path to pretrained retriever model or model identifier from huggingface.co/models.",
required=True,
)
parser.add_argument(
"--generator_model_name_or_path",
"--generator_name_or_path",
type=str,
help="Path to pretrained generator model or model identifier from huggingface.co/models.",
required=True,
Expand Down Expand Up @@ -141,7 +141,7 @@ def main() -> None:

# rag retriever and the generator (don't load new peft layers no need)
rag_model = AutoModelForRagE2E(
args.retriever_model_name_or_path, args.generator_model_name_or_path, get_peft=False, use_bnb=False
args.retriever_name_or_path, args.generator_name_or_path, get_peft=False, use_bnb=False
)

# load the test dataset
Expand All @@ -162,8 +162,8 @@ def main() -> None:
lambda example: preprocess_function(
example,
retriever_tokenizer,
query_col_name=args.query_column_name,
passage_col_name=args.passage_column_name,
query_column_name=args.query_column_name,
passage_column_name=args.passage_column_name,
),
batched=True,
# remove_columns=test_dataset.column_names,
Expand Down
8 changes: 4 additions & 4 deletions dalm/eval/eval_retriever_only.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ def parse_args() -> Namespace:
),
)
parser.add_argument(
"--retriever_model_name_or_path",
"--retriever_name_or_path",
type=str,
help="Path to pretrained retriever model or model identifier from huggingface.co/models.",
required=True,
Expand Down Expand Up @@ -104,7 +104,7 @@ def main() -> None:
SELECTED_TORCH_DTYPE: Final[torch.dtype] = torch.float16 if args.torch_dtype == "float16" else torch.bfloat16

# rag retriever and the generator (don't load new peft layers no need)
retriever_model = AutoModelForSentenceEmbedding(args.retriever_model_name_or_path, get_peft=False, use_bnb=False)
retriever_model = AutoModelForSentenceEmbedding(args.retriever_name_or_path, get_peft=False, use_bnb=False)

# load the test dataset
test_dataset = (
Expand All @@ -121,8 +121,8 @@ def main() -> None:
lambda example: preprocess_function(
example,
retriever_tokenizer,
query_col_name=args.query_column_name,
passage_col_name=args.passage_column_name,
query_column_name=args.query_column_name,
passage_column_name=args.passage_column_name,
),
batched=True,
# remove_columns=test_dataset.column_names,
Expand Down
8 changes: 4 additions & 4 deletions dalm/eval/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,11 +78,11 @@ def calculate_precision_recall(retrieved_items: List, correct_items: List) -> Tu
def preprocess_function(
examples: LazyBatch,
retriever_tokenizer: PreTrainedTokenizer,
query_col_name: str = "query",
passage_col_name: str = "passage",
query_column_name: str = "query",
passage_column_name: str = "passage",
) -> Dict[str, torch.Tensor]:
queries = examples[query_col_name]
passages = examples[passage_col_name]
queries = examples[query_column_name]
passages = examples[passage_column_name]

# Tokenization for the retriever
retriever_query_tokens = retriever_tokenizer(queries, padding="max_length", max_length=128, truncation=True)
Expand Down
24 changes: 12 additions & 12 deletions dalm/training/rag_e2e/train_rage2e.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,13 +60,13 @@ def parse_args() -> Namespace:
help=("Dataset path. Can be a huggingface dataset directory or a csv file."),
)
parser.add_argument(
"--dataset_passage_col_name", type=str, default="Abstract", help="Name of the column containing the passage"
"--passage_column_name", type=str, default="Abstract", help="Name of the column containing the passage"
)
parser.add_argument(
"--dataset_query_col_name", type=str, default="Question", help="Name of the column containing the query"
"--query_column_name", type=str, default="Question", help="Name of the column containing the query"
)
parser.add_argument(
"--dataset_answer_col_name", type=str, default="Answer", help="Name of the column containing the answer"
"--answer_column_name", type=str, default="Answer", help="Name of the column containing the answer"
)
parser.add_argument(
"--query_max_len",
Expand Down Expand Up @@ -217,9 +217,9 @@ def train_e2e(
dataset_or_path: str | Dataset,
retriever_name_or_path: str,
generator_name_or_path: str,
dataset_passage_col_name: str = "Abstract",
dataset_query_col_name: str = "Question",
dataset_answer_col_name: str = "Answer",
passage_column_name: str = "Abstract",
query_column_name: str = "Question",
answer_column_name: str = "Answer",
query_max_len: int = 50,
passage_max_len: int = 128,
generator_max_len: int = 256,
Expand Down Expand Up @@ -295,9 +295,9 @@ def train_e2e(
example,
retriever_tokenizer=rag_model.retriever_tokenizer,
generator_tokenizer=rag_model.generator_tokenizer,
dataset_query_col_name=dataset_query_col_name,
dataset_passage_col_name=dataset_passage_col_name,
dataset_answer_col_name=dataset_answer_col_name,
query_column_name=query_column_name,
passage_column_name=passage_column_name,
answer_column_name=answer_column_name,
query_max_len=query_max_len,
passage_max_len=passage_max_len,
generator_max_len=generator_max_len,
Expand Down Expand Up @@ -523,9 +523,9 @@ def main() -> None:
dataset_or_path=args.dataset_path,
retriever_name_or_path=args.retriever_name_or_path,
generator_name_or_path=args.generator_name_or_path,
dataset_passage_col_name=args.dataset_passage_col_name,
dataset_query_col_name=args.dataset_query_col_name,
dataset_answer_col_name=args.dataset_answer_col_name,
passage_column_name=args.passage_column_name,
query_column_name=args.query_column_name,
answer_column_name=args.answer_column_name,
query_max_len=args.query_max_len,
passage_max_len=args.passage_max_len,
generator_max_len=args.generator_max_len,
Expand Down
26 changes: 13 additions & 13 deletions dalm/training/retriever_only/train_retriever_only.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,10 +44,10 @@ def parse_args() -> Namespace:
parser = argparse.ArgumentParser(description="training a PEFT model for Sematic Search task")
parser.add_argument("--dataset_path", type=str, default=None, help="dataset path in the local dir")
parser.add_argument(
"--dataset_query_col_name", type=str, default="Question", help="Name of the query column in the dataset"
"--query_column_name", type=str, default="Question", help="Name of the query column in the dataset"
)
parser.add_argument(
"--dataset_passage_col_name", type=str, default="Abstract", help="Name of the passage column in the dataset"
"--passage_column_name", type=str, default="Abstract", help="Name of the passage column in the dataset"
)
parser.add_argument(
"--query_max_len",
Expand All @@ -67,7 +67,7 @@ def parse_args() -> Namespace:
),
)
parser.add_argument(
"--model_name_or_path",
"--retriever_name_or_path",
type=str,
help="Path to pretrained model or model identifier from huggingface.co/models.",
required=True,
Expand Down Expand Up @@ -163,10 +163,10 @@ def parse_args() -> Namespace:


def train_retriever(
model_name_or_path: str,
retriever_name_or_path: str,
dataset_or_path: str | Dataset,
dataset_passage_col_name: str = "Abstract",
dataset_query_col_name: str = "Question",
passage_column_name: str = "Abstract",
query_column_name: str = "Question",
query_max_len: int = 50,
passage_max_len: int = 128,
per_device_train_batch_size: int = 32,
Expand Down Expand Up @@ -220,7 +220,7 @@ def train_retriever(
os.makedirs(output_dir, exist_ok=True)
accelerator.wait_for_everyone()

model = AutoModelForSentenceEmbedding(model_name_or_path, use_bnb=True, get_peft=use_peft)
model = AutoModelForSentenceEmbedding(retriever_name_or_path, use_bnb=True, get_peft=use_peft)
tokenizer = model.tokenizer

# dataset download and preprocessing
Expand All @@ -236,8 +236,8 @@ def train_retriever(
lambda example: preprocess_dataset(
example,
tokenizer,
query_col_name=dataset_query_col_name,
passage_col_name=dataset_passage_col_name,
query_column_name=query_column_name,
passage_column_name=passage_column_name,
query_max_len=query_max_len,
passage_max_len=passage_max_len,
),
Expand Down Expand Up @@ -417,9 +417,9 @@ def main() -> None:
args = parse_args()
train_retriever(
dataset_or_path=args.dataset_path,
model_name_or_path=args.model_name_or_path,
dataset_passage_col_name=args.dataset_passage_col_name,
dataset_query_col_name=args.dataset_query_col_name,
retriever_name_or_path=args.retriever_name_or_path,
passage_column_name=args.passage_column_name,
query_column_name=args.query_column_name,
query_max_len=args.query_max_len,
passage_max_len=args.passage_max_len,
per_device_train_batch_size=args.per_device_train_batch_size,
Expand Down Expand Up @@ -449,5 +449,5 @@ def main() -> None:


# python contrastive_train/peft_lora_constrastive_learning.py --dataset_path "xxxx.csv" \
# --model_name_or_path "BAAI/bge-small-en" --output_dir "./retriever_only_checkpoints" --use_peft \
# --retriever_name_or_path "BAAI/bge-small-en" --output_dir "./retriever_only_checkpoints" --use_peft \
# --with_tracking --report_to all --per_device_train_batch_size 30
12 changes: 6 additions & 6 deletions dalm/training/utils/rag_e2e_dataloader_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,16 +8,16 @@ def preprocess_dataset(
examples: LazyBatch,
retriever_tokenizer: PreTrainedTokenizer | PreTrainedTokenizerFast,
generator_tokenizer: PreTrainedTokenizer | PreTrainedTokenizerFast,
dataset_query_col_name: str,
dataset_passage_col_name: str,
dataset_answer_col_name: str,
query_column_name: str,
passage_column_name: str,
answer_column_name: str,
query_max_len: int,
passage_max_len: int,
generator_max_len: int,
) -> Dict[str, Any]:
querie_list = examples[dataset_query_col_name]
passage_list = examples[dataset_passage_col_name]
answers = examples[dataset_answer_col_name]
querie_list = examples[query_column_name]
passage_list = examples[passage_column_name]
answers = examples[answer_column_name]

queries = [f"#query# {query}" for query in querie_list]
passages = [f"#passage# {passage}" for passage in passage_list]
Expand Down
8 changes: 4 additions & 4 deletions dalm/training/utils/retriever_only_dataloader_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,17 +8,17 @@
def preprocess_dataset(
examples: LazyBatch,
tokenizer: PreTrainedTokenizer | PreTrainedTokenizerFast,
query_col_name: str,
passage_col_name: str,
query_column_name: str,
passage_column_name: str,
query_max_len: int,
passage_max_len: int,
) -> Dict[str, torch.Tensor]:
query_list = examples[query_col_name]
query_list = examples[query_column_name]
queries = [f"#query# {query}" for query in query_list]
result_ = tokenizer(queries, padding="max_length", max_length=query_max_len, truncation=True)
result_ = {f"query_{k}": v for k, v in result_.items()}

passage_list = examples[passage_col_name]
passage_list = examples[passage_column_name]
passages = [f"#passage# {passage}" for passage in passage_list]
result_passage = tokenizer(passages, padding="max_length", max_length=passage_max_len, truncation=True)
for k, v in result_passage.items():
Expand Down
8 changes: 4 additions & 4 deletions experiments/llama-index-10k/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -48,15 +48,15 @@ dalm train-rag-e2e \
"qa-outputs/question_answer_pairs.csv" \
"BAAI/bge-small-en" \
"meta-llama/Llama-2-7b-hf" \
--dataset-passage-col-name text \
--passage-column-name text \
--output-dir "rag_e2e_checkpoints_bgsmall" \
--no-with-tracking \
--per-device-train-batch-size 12
```

And eval
```
python ../../dalm/eval/eval_retriever_only.py --dataset_path qa-outputs/question_answer_pairs_test.csv --retriever_model_name_or_path "BAAI/bge-small-en" --passage_column_name text --query_column_name Question --retriever_peft_model_path rag_e2e_checkpoints_bgsmall/retriever --embed_dim 384
python ../../dalm/eval/eval_retriever_only.py --dataset_path qa-outputs/question_answer_pairs_test.csv --retriever_name_or_path "BAAI/bge-small-en" --passage_column_name text --query_column_name Question --retriever_peft_model_path rag_e2e_checkpoints_bgsmall/retriever --embed_dim 384
*************
Retriever results:
Expand All @@ -74,13 +74,13 @@ Train the retriever only
dalm train-retriever-only "BAAI/bge-small-en" "qa-outputs/question_answer_pairs.csv" \
--output-dir "retriever_only_checkpoints_bgsmall" \
--use-peft \
--dataset-passage-col-name text \
--passage-column-name text \
--per-device-train-batch-size 150
```

and eval
```
python ../../dalm/eval/eval_retriever_only.py --dataset_path qa-outputs/question_answer_pairs_test.csv --retriever_model_name_or_path "BAAI/bge-small-en" --passage_column_name text --query_column_name Question --retriever_peft_model_path retriever_only_checkpoints_bgsmall/ --embed_dim 384
python ../../dalm/eval/eval_retriever_only.py --dataset_path qa-outputs/question_answer_pairs_test.csv --retriever_name_or_path "BAAI/bge-small-en" --passage_column_name text --query_column_name Question --retriever_peft_model_path retriever_only_checkpoints_bgsmall/ --embed_dim 384
*************
Retriever results:
Expand Down

0 comments on commit 73fc83e

Please sign in to comment.