Merge pull request #42 from arcee-ai/split-by-title

split dataset by title
Ben-Epstein authored Sep 19, 2023
2 parents d0e5a6c + 441f10d commit 0cf02bc
Showing 3 changed files with 37 additions and 17 deletions.
21 changes: 20 additions & 1 deletion dalm/datasets/qa_gen/question_answer_generation.py
@@ -2,6 +2,7 @@

import datasets
import torch
+from sklearn.model_selection import train_test_split
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

device = "cuda:0" if torch.cuda.is_available() else "cpu"
@@ -83,6 +84,24 @@ def filter_malformed_questions(record: dict) -> bool:
    return question != "-" and answer != "-"


+def split_dataset(shuffled_dataset: datasets.Dataset, test_size: float = TEST_SIZE) -> datasets.DatasetDict:
+    unique_titles = set(shuffled_dataset[args.title_column_name])
+
+    train_titles, test_titles = train_test_split(list(unique_titles), test_size=test_size, random_state=42)
+
+    train_dataset = shuffled_dataset.filter(
+        lambda example: example[args.title_column_name] in train_titles, num_proc=128
+    )
+    test_dataset = shuffled_dataset.filter(lambda example: example[args.title_column_name] in test_titles, num_proc=128)
+
+    return datasets.DatasetDict(
+        {
+            "train": train_dataset,
+            "test": test_dataset,
+        }
+    )
+
+
dataset = datasets.load_dataset("csv", data_files={"data": args.dataset_path})["data"]

# shuffle data
@@ -92,7 +111,7 @@ def filter_malformed_questions(record: dict) -> bool:
small_dataset = dataset.select(range(args.sample_size))

# train-test split
-small_dataset_splits = small_dataset.train_test_split(test_size=TEST_SIZE)
+small_dataset_splits = split_dataset(small_dataset)

print(
    f"Train dataset size: {len(small_dataset_splits['train'])}, Test dataset size: {len(small_dataset_splits['test'])}"
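
Why split on titles: the QA pairs are generated from passages of source documents, and every row carries its document's title. A plain row-level train_test_split can therefore place near-duplicate passages from the same document in both train and test, inflating evaluation scores. split_dataset instead splits the set of unique titles and filters rows by membership, so each document lands wholly on one side. Note that test_size now applies to titles rather than rows, so the realized row fraction can drift when documents contribute unequal passage counts. A minimal sketch of the idea (not part of the diff; the toy data and "title" column stand in for the real CSV and args.title_column_name):

import datasets
from sklearn.model_selection import train_test_split

toy = datasets.Dataset.from_dict(
    {
        "title": ["doc_a", "doc_a", "doc_b", "doc_b", "doc_c", "doc_c"],
        "text": [f"passage {i}" for i in range(6)],
    }
)

# Row-level split: passages from the same document can land on both sides.
row_splits = toy.train_test_split(test_size=0.33, seed=42)
leaked = set(row_splits["train"]["title"]) & set(row_splits["test"]["title"])
print("titles leaked across the row-level split:", leaked)

# Title-level split: every document's passages stay on one side.
train_titles, test_titles = train_test_split(sorted(set(toy["title"])), test_size=0.33, random_state=42)
train_title_set, test_title_set = set(train_titles), set(test_titles)
title_splits = datasets.DatasetDict(
    {
        "train": toy.filter(lambda ex: ex["title"] in train_title_set),
        "test": toy.filter(lambda ex: ex["title"] in test_title_set),
    }
)
assert not set(title_splits["train"]["title"]) & set(title_splits["test"]["title"])

One small design note: train_titles in the new split_dataset is a list, so each filter row does a linear membership scan; converting it to a set first (as the sketch does) keeps the lookup constant-time on large corpora.
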
28 changes: 12 additions & 16 deletions dalm/training/utils/rag_e2e_dataloader_utils.py
@@ -13,7 +13,7 @@ def preprocess_dataset(
    dataset_answer_col_name: str,
    query_max_len: int,
    passage_max_len: int,
-    generator_max_len: int
+    generator_max_len: int,
) -> Dict[str, Any]:
    querie_list = examples[dataset_query_col_name]
    passage_list = examples[dataset_passage_col_name]
@@ -23,34 +23,30 @@
    passages = [f"#passage# {passage}" for passage in passage_list]

    # Tokenization for the retriever
-    retriever_query_tokens = retriever_tokenizer(queries, padding="max_length",
-                                                 max_length=query_max_len,
-                                                 truncation=True)
-    retriever_passage_tokens = retriever_tokenizer(passages,
-                                                   padding="max_length",
-                                                   max_length=passage_max_len,
-                                                   truncation=True)
+    retriever_query_tokens = retriever_tokenizer(
+        queries, padding="max_length", max_length=query_max_len, truncation=True
+    )
+    retriever_passage_tokens = retriever_tokenizer(
+        passages, padding="max_length", max_length=passage_max_len, truncation=True
+    )

    # Tokenize for causal model
    # Here, we need to combine the query, passage, and the answer as the input, and the answer as the output
    casual_input_text = [
        f"#query# {query} #passage# {passage} #answer# {answer}"
        for passage, query, answer in zip(passages, queries, answers, strict=True)
    ]
-    causal_input_tokens = generator_tokenizer(casual_input_text,
-                                              padding="max_length",
-                                              max_length=generator_max_len,
-                                              truncation=True)
+    causal_input_tokens = generator_tokenizer(
+        casual_input_text, padding="max_length", max_length=generator_max_len, truncation=True
+    )

    query_passage_text = [
-        f"#query# {query} #passage# {passage} #answer#"
-        for passage, query in zip(passages, queries, strict=True)
+        f"#query# {query} #passage# {passage} #answer#" for passage, query in zip(passages, queries, strict=True)
    ]

    query_passage_lengths = []

-    query_passage_tokens = generator_tokenizer(query_passage_text,
-                                               padding=False)
+    query_passage_tokens = generator_tokenizer(query_passage_text, padding=False)

    for single_query_passage in query_passage_tokens["input_ids"]:
        query_passage_lengths.append(len(single_query_passage))
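
For context, preprocess_dataset is a batched map function: it receives a dict of columns (examples), tokenizes queries and passages for the retriever, builds "#query# ... #passage# ... #answer# ..." strings for the causal generator, and records the unpadded query+passage lengths, presumably so the prefix can be masked out of the generator loss later. A hedged sketch of how such a preprocessor is typically wired up with datasets.map — the model names, column names, and max lengths below are illustrative, and it assumes the two tokenizers are passed as keyword arguments, as the visible body suggests:

from functools import partial

from datasets import Dataset
from transformers import AutoTokenizer

from dalm.training.utils.rag_e2e_dataloader_utils import preprocess_dataset

retriever_tokenizer = AutoTokenizer.from_pretrained("BAAI/bge-small-en")
generator_tokenizer = AutoTokenizer.from_pretrained("gpt2")
generator_tokenizer.pad_token = generator_tokenizer.eos_token  # GPT-2 ships without a pad token

raw = Dataset.from_dict(
    {
        "question": ["What does DALM train?"],
        "passage": ["DALM trains a retriever and a generator end to end."],
        "answer": ["A retriever and a generator."],
    }
)

processed = raw.map(
    partial(
        preprocess_dataset,
        retriever_tokenizer=retriever_tokenizer,
        generator_tokenizer=generator_tokenizer,
        dataset_query_col_name="question",
        dataset_passage_col_name="passage",
        dataset_answer_col_name="answer",
        query_max_len=64,
        passage_max_len=128,
        generator_max_len=256,
    ),
    batched=True,  # preprocess_dataset expects a batch of examples
    remove_columns=raw.column_names,
)
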
5 changes: 5 additions & 0 deletions pyproject.toml
@@ -14,6 +14,7 @@ packages = [
    { include = "dalm" }
]
dependencies = [
+    "scikit-learn",
    "transformers",
    "peft",
    "accelerate",
@@ -128,3 +129,7 @@ ignore_missing_imports = true
[[tool.mypy.overrides]]
module = "hnswlib.*"
ignore_missing_imports = true

+[[tool.mypy.overrides]]
+module = "sklearn.*"
+ignore_missing_imports = true
