
Commit 9a8763a

feat: Prepared package for release (#8)

* Prepared package for release
* Removed reach dependency
* Removed reach dependency
* Updates
* Updated cli
* Updated readme
* Removed file

1 parent 41d2159 commit 9a8763a

9 files changed: +974 additions, -1774 deletions

README.md

Lines changed: 22 additions & 10 deletions
@@ -3,22 +3,25 @@ Tokenlearn is a method to pre-train [Model2Vec](https://github.com/MinishLab/mod
 
 The method is described in detail in our [Tokenlearn blogpost](https://minishlab.github.io/tokenlearn_blogpost/).
 
-## Usage
+## Quickstart
 
-### Featurizing
-Tokenlearn is trained using means from a sentence transformer. To create means, the `featurize` script can be used:
+Install the package with:
 
 ```bash
-python tokenlearn/featurize.py
+pip install tokenlearn
 ```
 
-This will create means for [C4](https://huggingface.co/datasets/allenai/c4) using [bge-base-en-v1.5](https://huggingface.co/BAAI/bge-base-en-v1.5).
+The basic usage of Tokenlearn consists of two CLI scripts: `featurize` and `train`.
 
-### Training
-The easiest way to train using Tokenlearn is to use the CLI. You can use the following command to train a model:
+Tokenlearn is trained using means from a sentence transformer. To create means, the `tokenlearn-featurize` CLI can be used:
 
 ```bash
-python train.py --data-path <path-to-your-data> --save-path <path-to-save-model>
+python3 -m tokenlearn.featurize --model-name "baai/bge-base-en-v1.5" --output-dir "data/c4_features"
+```
+
+To train a model on the featurized data, the `tokenlearn-train` CLI can be used:
+```bash
+python3 -m tokenlearn.train --model-name "baai/bge-base-en-v1.5" --data-path "data/c4_features" --save-path "<path-to-save-model>"
 ```
 
 Training will create two models:
@@ -27,9 +30,14 @@ Training will create two models:
 
 NOTE: the code assumes that the padding token ID in your tokenizer is 0. If this is not the case, you will need to modify the code.
 
-### Evaluating
+### Evaluation
 
-To evaluate a model, you can use the following command:
+To evaluate a model, you can use the following command after installing the optional evaluation dependencies:
+
+```bash
+pip install evaluation@git+https://github.com/MinishLab/evaluation@main
+
+```
 
 ```python
 from model2vec import StaticModel
@@ -61,3 +69,7 @@ task_scores = summarize_results(parsed_results)
 # Print the results in a leaderboard format
 print(make_leaderboard(task_scores))
 ```
+
+## License
+
+MIT
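
The quickstart above ends with a trained Model2Vec model at `<path-to-save-model>`. As a minimal usage sketch (not part of this commit, and assuming the placeholder save path has been filled in), the result can be loaded and used with `model2vec`'s `StaticModel`:

```python
# Minimal sketch (not from this commit): load a model trained via
# tokenlearn-train and embed a couple of sentences with model2vec.
# "<path-to-save-model>" is the placeholder used in the README above.
from model2vec import StaticModel

model = StaticModel.from_pretrained("<path-to-save-model>")
embeddings = model.encode(["Tokenlearn pre-trains static embedders.", "Hello world."])
print(embeddings.shape)
```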

pyproject.toml

Lines changed: 23 additions & 6 deletions
@@ -1,18 +1,32 @@
 [project]
 name = "tokenlearn"
-description = "Pre-train static embedders."
+description = "Pre-train Static Embedders"
 readme = "README.md"
-version = "0.1.0"
-requires-python = ">=3.10"
+requires-python = ">=3.9"
+authors = [{name = "Thomas van Dongen", email = "[email protected]"}, { name = "Stéphan Tulkens", email = "[email protected]"}]
+dynamic = ["version"]
+
+classifiers = [
+    "Development Status :: 4 - Beta",
+    "Intended Audience :: Developers",
+    "Intended Audience :: Science/Research",
+    "Topic :: Scientific/Engineering :: Artificial Intelligence",
+    "Topic :: Software Development :: Libraries",
+    "License :: OSI Approved :: MIT License",
+    "Programming Language :: Python :: 3 :: Only",
+    "Programming Language :: Python :: 3.9",
+    "Programming Language :: Python :: 3.10",
+    "Programming Language :: Python :: 3.11",
+    "Programming Language :: Python :: 3.12",
+    "Natural Language :: English",
+]
 
 dependencies = [
-    "model2vec>=0.3.0",
+    "model2vec[distill]>=0.3.0",
     "sentence-transformers",
     "torch",
     "datasets",
     "more-itertools>=10.5.0",
-    "reach@git+https://github.com/stephantul/reach@main",
-    "evaluation@git+https://github.com/MinishLab/evaluation@main"
 ]
 
 [build-system]
@@ -77,3 +91,6 @@ packages = ["tokenlearn"]
 
 [tool.setuptools_scm]
 # can be empty if no extra settings are needed, presence enables setuptools_scm
+
+[tool.setuptools.dynamic]
+version = {attr = "tokenlearn.version.__version__"}
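
The new `[tool.setuptools.dynamic]` table resolves the package version from `tokenlearn.version.__version__`, which is why the static `version = "0.1.0"` field was dropped and `dynamic = ["version"]` was added. That module is not shown in this diff; a minimal sketch of what it would need to contain (the version string itself is an assumption) is:

```python
# tokenlearn/version.py -- not shown in this diff; a minimal sketch of the
# module that [tool.setuptools.dynamic] reads `version` from.
# The actual version string here is an assumption for illustration.
__version__ = "0.1.0"
```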

tokenlearn/featurize.py

Lines changed: 66 additions & 23 deletions
@@ -1,17 +1,39 @@
 import argparse
+import json
+import logging
 from pathlib import Path
 from typing import Iterable
 
 import numpy as np
 from datasets import load_dataset
 from more_itertools import batched
-from reach import Reach
 from sentence_transformers import SentenceTransformer
 from tqdm import tqdm
 
 _SAVE_INTERVAL = 10
 _MAX_MEANS = 1100000
 
+logger = logging.getLogger(__name__)
+
+
+def save_data(means: list[np.ndarray], txts: list[str], base_filepath: str) -> None:
+    """
+    Save the means and texts to separate files.
+
+    :param means: List of numpy arrays representing the mean embeddings.
+    :param txts: List of texts corresponding to the embeddings.
+    :param base_filepath: Base path for the output files.
+    """
+    vectors_filepath = base_filepath + "_vectors.npy"
+    items_filepath = base_filepath + "_items.json"
+
+    # Save the embeddings (vectors) to a .npy file
+    np.save(vectors_filepath, np.array(means))
+    # Save the texts to a JSON file
+    with open(items_filepath, "w") as f:
+        json.dump({"items": txts}, f)
+    logger.info(f"Saved {len(txts)} texts to {items_filepath} and vectors to {vectors_filepath}")
+
 
 def featurize(texts: Iterable[str], model: SentenceTransformer, output_dir: str) -> None:
     """
@@ -35,55 +57,76 @@ def featurize(texts: Iterable[str], model: SentenceTransformer, output_dir: str)
 
     for index, batch in enumerate(tqdm(batched(texts, 32))):
        i = index // _SAVE_INTERVAL
-        if (out_path / f"featurized_{i}.json").exists():
-            continue
-        # Consume the generator
+        base_filename = f"featurized_{i}"
+        vectors_filepath = out_path / (base_filename + "_vectors.npy")
+        items_filepath = out_path / (base_filename + "_items.json")
         list_batch = [x["text"].strip() for x in batch if x.get("text")]
+        if not list_batch:
+            continue  # Skip empty batches
+
+        # Encode the batch to get token embeddings
+        token_embeddings = model.encode(
+            list_batch,
+            output_value="token_embeddings",
+            convert_to_tensor=True,
+        )
 
-        # Already truncated to model max_length
+        # Tokenize the batch to get input IDs
         tokenized_ids = model.tokenize(list_batch)["input_ids"]
-        token_embeddings: list[np.ndarray] = [
-            x.cpu().numpy() for x in model.encode(list_batch, output_value="token_embeddings", convert_to_numpy=True)
-        ]
 
-        for tokenized_id, token_embedding in zip(tokenized_ids, token_embeddings, strict=True):
-            # Truncate to actual length of vectors, remove CLS and SEP.
-            text = model.tokenizer.decode(tokenized_id[1 : len(token_embedding) - 1])
+        for tokenized_id, token_embedding in zip(tokenized_ids, token_embeddings):
+            # Convert token IDs to tokens (excluding special tokens)
+            token_ids = tokenized_id[1:-1]
+            # Decode tokens to text
+            text = model.tokenizer.decode(token_ids)
             if text in seen:
                 continue
             seen.add(text)
-            mean = np.mean(token_embedding[1:-1], axis=0)
+            # Get the corresponding token embeddings (excluding special tokens)
+            token_embeds = token_embedding[1:-1]
+            # Convert embeddings to NumPy arrays
+            token_embeds = token_embeds.detach().cpu().numpy()
+            # Compute the mean of the token embeddings
+            mean = np.mean(token_embeds, axis=0)
             txts.append(text)
             means.append(mean)
             total_means += 1
 
            if total_means >= _MAX_MEANS:
-                # Save the final batch and stop
-                r = Reach(means, txts)
-                r.save(out_path / f"featurized_{(index // _SAVE_INTERVAL)}.json")
+                save_data(means, txts, str(out_path / base_filename))
                return
 
        if index > 0 and (index + 1) % _SAVE_INTERVAL == 0:
-            r = Reach(means, txts)
-            r.save(out_path / f"featurized_{(index // _SAVE_INTERVAL)}.json")
+            save_data(means, txts, str(out_path / base_filename))
            txts = []
            means = []
            seen = set()
    else:
-        if means:
-            r = Reach(means, txts)
-            r.save(out_path / f"featurized_{(index // _SAVE_INTERVAL)}.json")
+        if txts and means:
+            save_data(means, txts, str(out_path / base_filename))
 
 
-if __name__ == "__main__":
-    parser = argparse.ArgumentParser(description="Train a Model2Vec using tokenlearn.")
+def main() -> None:
+    """Main function to featurize texts using a sentence transformer."""
+    parser = argparse.ArgumentParser(description="Featurize texts using a sentence transformer.")
     parser.add_argument(
        "--model-name",
        type=str,
        default="baai/bge-base-en-v1.5",
        help="The model name for distillation (e.g., 'baai/bge-base-en-v1.5').",
    )
+    parser.add_argument(
+        "--output-dir",
+        type=str,
+        default="data/c4_bgebase",
+        help="Directory to save the featurized texts.",
+    )
    args = parser.parse_args()
+
    model = SentenceTransformer(args.model_name)
    dataset = load_dataset("allenai/c4", name="en", split="train", streaming=True)
-    featurize(dataset, model, "data/c4_bgebase")
+    featurize(dataset, model, args.output_dir)
+
+
+if __name__ == "__main__":
+    main()
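
The new `save_data` helper replaces the `reach`-based serialization: each chunk is written as a `featurized_{i}_vectors.npy` array of mean embeddings plus a `featurized_{i}_items.json` file of the form `{"items": [...]}`. As a rough illustration of that on-disk format (not part of this commit; the directory name simply mirrors the `--output-dir` example above), the featurized chunks could be read back like this:

```python
# Illustrative loader (not part of this commit) for the files written by
# save_data(): pairs of featurized_{i}_vectors.npy and featurized_{i}_items.json.
import json
from pathlib import Path

import numpy as np

data_dir = Path("data/c4_features")  # directory passed to --output-dir

all_vectors = []
all_texts = []
for vectors_path in sorted(data_dir.glob("featurized_*_vectors.npy")):
    items_path = vectors_path.with_name(vectors_path.name.replace("_vectors.npy", "_items.json"))
    all_vectors.append(np.load(vectors_path))
    with open(items_path) as f:
        all_texts.extend(json.load(f)["items"])

vectors = np.concatenate(all_vectors, axis=0)
assert len(vectors) == len(all_texts)
print(f"Loaded {len(all_texts)} mean embeddings of dimension {vectors.shape[1]}")
```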
