Skip to content

Commit cd3c58e

Browse files
authored
feat: Make dataset selection flexible (#11)
* Made datasets flexible * Made datasets flexible
1 parent e335d4c commit cd3c58e

File tree

2 files changed

+51
-8
lines changed

2 files changed

+51
-8
lines changed

README.md

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,17 @@ Tokenlearn is trained using means from a sentence transformer. To create means,
1919
python3 -m tokenlearn.featurize --model-name "baai/bge-base-en-v1.5" --output-dir "data/c4_features"
2020
```
2121

22+
NOTE: by default, featurization uses the C4 dataset. If you want to use a different dataset, the following command can be used:
23+
24+
```bash
25+
python3 -m tokenlearn.featurize \
26+
--model-name "baai/bge-base-en-v1.5" \
27+
--output-dir "data/c4_features" \
28+
--dataset-path "allenai/c4" \
29+
--dataset-name "en" \
30+
--dataset-split "train"
31+
```
32+
2233
To train a model on the featurized data, the `tokenlearn-train` CLI can be used:
2334
```bash
2435
python3 -m tokenlearn.train --model-name "baai/bge-base-en-v1.5" --data-path "data/c4_features" --save-path "<path-to-save-model>"

tokenlearn/featurize.py

Lines changed: 40 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,6 @@
1111
from tqdm import tqdm
1212

1313
_SAVE_INTERVAL = 10
14-
_MAX_MEANS = 1100000
15-
1614
logger = logging.getLogger(__name__)
1715

1816

@@ -35,13 +33,14 @@ def save_data(means: list[np.ndarray], txts: list[str], base_filepath: str) -> N
3533
logger.info(f"Saved {len(txts)} texts to {items_filepath} and vectors to {vectors_filepath}")
3634

3735

38-
def featurize(texts: Iterable[str], model: SentenceTransformer, output_dir: str) -> None:
36+
def featurize(texts: Iterable[str], model: SentenceTransformer, output_dir: str, max_means: int) -> None:
3937
"""
4038
Featurize text using a sentence transformer.
4139
4240
:param texts: Iterable of texts to featurize.
4341
:param model: SentenceTransformer model to use.
4442
:param output_dir: Directory to save the featurized texts.
43+
:param max_means: Maximum number of mean embeddings to generate.
4544
:raises ValueError: If the model does not have a fixed dimension.
4645
"""
4746
out_path = Path(output_dir)
@@ -58,8 +57,6 @@ def featurize(texts: Iterable[str], model: SentenceTransformer, output_dir: str)
5857
for index, batch in enumerate(tqdm(batched(texts, 32))):
5958
i = index // _SAVE_INTERVAL
6059
base_filename = f"featurized_{i}"
61-
vectors_filepath = out_path / (base_filename + "_vectors.npy")
62-
items_filepath = out_path / (base_filename + "_items.json")
6360
list_batch = [x["text"].strip() for x in batch if x.get("text")]
6461
if not list_batch:
6562
continue # Skip empty batches
@@ -92,7 +89,7 @@ def featurize(texts: Iterable[str], model: SentenceTransformer, output_dir: str)
9289
means.append(mean)
9390
total_means += 1
9491

95-
if total_means >= _MAX_MEANS:
92+
if total_means >= max_means:
9693
save_data(means, txts, str(out_path / base_filename))
9794
return
9895

@@ -121,11 +118,46 @@ def main() -> None:
121118
default="data/c4_bgebase",
122119
help="Directory to save the featurized texts.",
123120
)
121+
parser.add_argument(
122+
"--dataset-path",
123+
type=str,
124+
default="allenai/c4",
125+
help="The dataset path or name (e.g. 'allenai/c4').",
126+
)
127+
parser.add_argument(
128+
"--dataset-name",
129+
type=str,
130+
default="en",
131+
help="The dataset configuration name (e.g., 'en' for C4).",
132+
)
133+
parser.add_argument(
134+
"--dataset-split",
135+
type=str,
136+
default="train",
137+
help="The dataset split (e.g., 'train', 'validation').",
138+
)
139+
parser.add_argument(
140+
"--no-streaming",
141+
action="store_true",
142+
help="Disable streaming mode when loading the dataset.",
143+
)
144+
parser.add_argument(
145+
"--max-means",
146+
type=int,
147+
default=1000000,
148+
help="The maximum number of mean embeddings to generate.",
149+
)
150+
124151
args = parser.parse_args()
125152

126153
model = SentenceTransformer(args.model_name)
127-
dataset = load_dataset("allenai/c4", name="en", split="train", streaming=True)
128-
featurize(dataset, model, args.output_dir)
154+
dataset = load_dataset(
155+
args.dataset_path,
156+
name=args.dataset_name,
157+
split=args.dataset_split,
158+
streaming=not args.no_streaming,
159+
)
160+
featurize(dataset, model, args.output_dir, args.max_means)
129161

130162

131163
if __name__ == "__main__":

0 commit comments

Comments
 (0)