Commit c8e9e26

Merge pull request #6 from MinishLab/add_code_to_utils
fix: add code to utils. clean up
2 parents 71f91b4 + fb2a58b commit c8e9e26

File tree

tokenlearn/utils.py
train.py
uv.lock

3 files changed: +130 additions, -129 deletions

tokenlearn/utils.py

Lines changed: 57 additions & 0 deletions
@@ -0,0 +1,57 @@
+import json
+from collections import Counter
+from pathlib import Path
+
+import numpy as np
+from more_itertools import batched
+from reach import Reach
+from tokenizers import Tokenizer
+from tqdm import tqdm
+
+
+def collect_means_and_texts(paths: list[Path]) -> tuple[list[str], np.ndarray]:
+    """Collect means and texts from a list of reach paths."""
+    txts = []
+    v = []
+    for path in tqdm(paths, desc="Collecting means and texts"):
+        if not path.name.endswith(".json"):
+            continue
+        try:
+            r = Reach.load(path)
+        except KeyError:
+            # Workaround for old format reach
+            vectors_path = str(path).replace("_items.json", "_vectors.npy")
+            items = json.load(open(path))["items"]
+            vectors = np.load(open(vectors_path, "rb"))
+            r = Reach(vectors, items)
+        # Filter out any NaN vectors before appending
+        non_nan_indices = ~np.isnan(r.vectors).any(axis=1)
+        valid_vectors = r.vectors[non_nan_indices]
+        valid_items = np.array(r.sorted_items)[non_nan_indices]
+        txts.extend(valid_items)
+        v.append(valid_vectors)
+
+    return txts, np.concatenate(v)
+
+
+def calculate_token_probabilities(tokenizer: Tokenizer, txt: list[str]) -> np.ndarray:
+    """Count tokens in a set of texts."""
+    vocab_size = tokenizer.get_vocab_size()
+    counts: Counter[int] = Counter()
+    for t in tqdm(batched(txt, 1024)):
+        encodings = tokenizer.encode_batch_fast(t, add_special_tokens=False)
+        for e in encodings:
+            counts.update(e.ids)
+
+    # Add the number of tokens to account for smoothing
+    sum_id = sum(counts.values()) + vocab_size
+    # Start with ones for smoothing
+    x = np.ones(vocab_size)
+
+    for word_id, count in counts.items():
+        x[word_id] += count
+
+    # Turn into probabilities
+    x /= sum_id
+
+    return x
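A rough usage sketch (not part of the commit) showing how the two new helpers fit together, mirroring what the updated train.py does; the tokenizer checkpoint below is an illustrative assumption.

from pathlib import Path

from tokenizers import Tokenizer

from tokenlearn.utils import calculate_token_probabilities, collect_means_and_texts

# Gather the featurized shards (default data directory used by train.py).
paths = sorted(Path("data/fineweb_bgebase").glob("*.json"))
texts, mean_vectors = collect_means_and_texts(paths)

# Laplace-smoothed token probabilities: every id starts with a count of 1,
# so tokens never seen in the corpus still get probability 1 / (total + vocab_size).
tokenizer = Tokenizer.from_pretrained("bert-base-uncased")  # assumed checkpoint
probas = calculate_token_probabilities(tokenizer, texts)
assert probas.shape == (tokenizer.get_vocab_size(),)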

train.py

Lines changed: 22 additions & 86 deletions
@@ -1,113 +1,51 @@
 import argparse
-import json
 import logging
-from collections import Counter
 from pathlib import Path
 from typing import Any

 import numpy as np
 import torch
 from model2vec import StaticModel
 from model2vec.distill import distill
-from model2vec.distill.distillation import _post_process_embeddings
-from reach import Reach
 from sklearn.decomposition import PCA
-from tqdm import tqdm

 from tokenlearn.train import TextDataset, train_supervised
+from tokenlearn.utils import calculate_token_probabilities, collect_means_and_texts

 logging.basicConfig(level=logging.INFO)


-def collect_means_and_texts(paths: list[Path]) -> tuple[list[str], np.ndarray]:
-    """Collect means and texts from a list of reach paths."""
-    txts = []
-    v = []
-    for path in tqdm(paths, desc="Collecting means and texts"):
-        if not path.name.endswith(".json"):
-            continue
-        try:
-            r = Reach.load(path)
-        except KeyError:
-            # Workaround for old format reach
-            vectors_path = str(path).replace("_items.json", "_vectors.npy")
-            items = json.load(open(path))["items"]
-            vectors = np.load(open(vectors_path, "rb"))
-            r = Reach(vectors, items)
-        # Filter out any NaN vectors before appending
-        non_nan_indices = ~np.isnan(r.vectors).any(axis=1)
-        valid_vectors = r.vectors[non_nan_indices]
-        valid_items = np.array(r.sorted_items)[non_nan_indices]
-        txts.extend(valid_items)
-        v.append(valid_vectors)
-
-    return txts, np.concatenate(v)
-
-
-def train_model(
-    model_name: str, data_path: str, save_path: str, device: str = "cpu", random_embeddings: bool = False
-) -> StaticModel:
+def train_model(model_name: str, train_txt: list[str], train_vec: np.ndarray, device: str = "cpu") -> StaticModel:
     """
     Train a tokenlearn model.

     :param model_name: The sentence transformer model name for distillation.
-    :param data_path: Path to the directory containing the dataset.
-    :param save_path: Path to save the trained model.
+    :param train_txt: List of texts to train on.
+    :param train_vec: List of vectors to train on.
     :param device: Device to run the training on.
-    :param random_embeddings: Use random embeddings instead of distilling the model.
     :return: The trained model.
     """
-    if random_embeddings:
-        logging.info("Using random embeddings.")
-        s = distill(model_name)
-        v = np.random.randn(*s.embedding.shape)  # noqa NPY002
-        v = _post_process_embeddings(v, 256, False).astype(np.float32)
-        s = StaticModel(v, s.tokenizer)
-    else:
-        s = distill(model_name)
-
-    # Collect paths for training
-    paths = sorted(Path(data_path).glob("*.json"))
-    train_txt, train_vec = collect_means_and_texts(paths)
-    train_data = TextDataset(train_txt, torch.from_numpy(train_vec), s.tokenizer)
+    model = distill(model_name)
+    train_data = TextDataset(train_txt, torch.from_numpy(train_vec), model.tokenizer)

     # Train the model
-    model, _ = train_supervised(train_dataset=train_data, model=s, device=device)
-
-    # Save the trained model
-    model.save_pretrained(save_path)
+    model, _ = train_supervised(train_dataset=train_data, model=model, device=device)

     return model


-def weight_model(model_name: str, data_path: str, pca_dims: int) -> StaticModel:
+def weight_model(model: StaticModel, text: list[str], pca_dims: int, alpha: float = 1e-3) -> StaticModel:
     """
     Function to weight the model.

-    :param model_name: The model name to weight.
-    :param data_path: Path to the directory containing the dataset.
+    :param model: The model to weight.
+    :param text: The text to use for weighting.
     :param pca_dims: The number of PCA dimensions to use.
+    :param alpha: The alpha value for SIF weighting. Words with probabilities above this value will be downweighted.
     :return: The weighted model.
     """
-    # Load the trained model
-    model = StaticModel.from_pretrained(model_name)
-
     logging.info("Applying reweighting and PCA to the model.")
-
-    # Collect data for counting
-    paths = sorted(Path(data_path).glob("*.json"))
-    txt, _ = collect_means_and_texts(paths)
-
-    counts: Counter[str] = Counter()
-    for t in tqdm(txt):
-        counts.update(model.tokenizer.encode(t, add_special_tokens=False).ids)
-
-    sum_id = sum(counts.values()) + len(model.tokens)
-    x = np.full(len(model.embedding), 1 / sum_id)
-
-    # Weight the embeddings based on frequency
-    for word_id, count in counts.items():
-        x[word_id] = (count + 1) / sum_id
+    probas = calculate_token_probabilities(model.tokenizer, text)

     w = model.embedding
     w = np.nan_to_num(w)
@@ -117,23 +55,25 @@ def weight_model(model_name: str, data_path: str, pca_dims: int) -> StaticModel:
     w = p.fit_transform(w)

     # Apply SIF weighting
-    alpha = 1e-3
-    f = alpha / (alpha + x)
+    f = alpha / (alpha + probas)
     w *= f[:, None]
     model.embedding = w
     model.normalize = True

-    model.save_pretrained(f"{model_name}_weighted")
-
     return model


 def main(args: Any) -> None:
     """Main function."""
-    train_model(
-        args.model_name, args.data_path, args.save_path, device=args.device, random_embeddings=args.random_embeddings
-    )
-    weight_model(args.save_path, args.data_path, 256)
+    # Collect paths for training
+    paths = sorted(Path(args.data_path).glob("*.json"))
+    train_txt, train_vec = collect_means_and_texts(paths)
+
+    model = train_model(args.model_name, train_txt, train_vec, device=args.device)
+    model.save_pretrained(args.save_path)
+    model = weight_model(model, train_txt, 256)
+    weighted_name = f"{args.save_path}_weighted"
+    model.save_pretrained(weighted_name)


 if __name__ == "__main__":
@@ -151,13 +91,9 @@ def main(args: Any) -> None:
         "--data-path", type=str, default="data/fineweb_bgebase", help="Path to the directory containing the dataset."
     )
     parser.add_argument("--save-path", type=str, help="Path to save the trained model.")
-
     parser.add_argument(
         "--device", type=str, default="cpu", help="Device to run the training on (e.g., 'cpu', 'cuda')."
     )
-    parser.add_argument(
-        "--random-embeddings", action="store_true", help="Use random embeddings instead of distilling the model."
-    )

     args = parser.parse_args()
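For reference, the SIF reweighting that weight_model now applies is simply f = alpha / (alpha + p) per token, with alpha = 1e-3 by default. A minimal numeric sketch with made-up probabilities (not from the commit):

import numpy as np

# Toy token probabilities for a four-token vocabulary (illustrative only).
probas = np.array([0.5, 0.1, 1e-3, 1e-5])
alpha = 1e-3  # default alpha in weight_model

# Frequent tokens (p >> alpha) are pushed towards alpha / p, i.e. heavily downweighted;
# rare tokens (p << alpha) keep weights close to 1.
f = alpha / (alpha + probas)
print(f)  # approx [0.002, 0.0099, 0.5, 0.990]

# Applied to a (vocab_size, dim) embedding matrix the same way weight_model does:
w = np.ones((4, 2), dtype=np.float32)
w *= f[:, None]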

uv.lock

Lines changed: 51 additions & 43 deletions
Some generated files are not rendered by default.
