import argparse
- import json
import logging
- from collections import Counter
from pathlib import Path
from typing import Any

import numpy as np
import torch
from model2vec import StaticModel
from model2vec.distill import distill
- from model2vec.distill.distillation import _post_process_embeddings
- from reach import Reach
from sklearn.decomposition import PCA
- from tqdm import tqdm

from tokenlearn.train import TextDataset, train_supervised
+ from tokenlearn.utils import calculate_token_probabilities, collect_means_and_texts

logging.basicConfig(level=logging.INFO)


- def collect_means_and_texts(paths: list[Path]) -> tuple[list[str], np.ndarray]:
-     """Collect means and texts from a list of reach paths."""
-     txts = []
-     v = []
-     for path in tqdm(paths, desc="Collecting means and texts"):
-         if not path.name.endswith(".json"):
-             continue
-         try:
-             r = Reach.load(path)
-         except KeyError:
-             # Workaround for old format reach
-             vectors_path = str(path).replace("_items.json", "_vectors.npy")
-             items = json.load(open(path))["items"]
-             vectors = np.load(open(vectors_path, "rb"))
-             r = Reach(vectors, items)
-         # Filter out any NaN vectors before appending
-         non_nan_indices = ~np.isnan(r.vectors).any(axis=1)
-         valid_vectors = r.vectors[non_nan_indices]
-         valid_items = np.array(r.sorted_items)[non_nan_indices]
-         txts.extend(valid_items)
-         v.append(valid_vectors)
-
-     return txts, np.concatenate(v)
-
-
- def train_model(
-     model_name: str, data_path: str, save_path: str, device: str = "cpu", random_embeddings: bool = False
- ) -> StaticModel:
+ def train_model(model_name: str, train_txt: list[str], train_vec: np.ndarray, device: str = "cpu") -> StaticModel:
    """
    Train a tokenlearn model.

    :param model_name: The sentence transformer model name for distillation.
-     :param data_path: Path to the directory containing the dataset.
-     :param save_path: Path to save the trained model.
+     :param train_txt: The texts to train on.
+     :param train_vec: The target vectors to train on, as a single array.
    :param device: Device to run the training on.
-     :param random_embeddings: Use random embeddings instead of distilling the model.
    :return: The trained model.
    """
-     if random_embeddings:
-         logging.info("Using random embeddings.")
-         s = distill(model_name)
-         v = np.random.randn(*s.embedding.shape)  # noqa NPY002
-         v = _post_process_embeddings(v, 256, False).astype(np.float32)
-         s = StaticModel(v, s.tokenizer)
-     else:
-         s = distill(model_name)
-
-     # Collect paths for training
-     paths = sorted(Path(data_path).glob("*.json"))
-     train_txt, train_vec = collect_means_and_texts(paths)
-     train_data = TextDataset(train_txt, torch.from_numpy(train_vec), s.tokenizer)
+     model = distill(model_name)
+     train_data = TextDataset(train_txt, torch.from_numpy(train_vec), model.tokenizer)

    # Train the model
-     model, _ = train_supervised(train_dataset=train_data, model=s, device=device)
-
-     # Save the trained model
-     model.save_pretrained(save_path)
+     model, _ = train_supervised(train_dataset=train_data, model=model, device=device)

    return model
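# Usage sketch (illustrative, not part of the diff): with the texts and mean
# vectors collected up front, training reduces to
#   texts, vectors = collect_means_and_texts(sorted(Path("data").glob("*.json")))
#   model = train_model("some-sentence-transformer", texts, vectors, device="cpu")
# where "some-sentence-transformer" stands in for any model name distill() accepts.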


- def weight_model(model_name: str, data_path: str, pca_dims: int) -> StaticModel:
+ def weight_model(model: StaticModel, text: list[str], pca_dims: int, alpha: float = 1e-3) -> StaticModel:
    """
    Weight the model's embeddings using SIF reweighting and PCA.

-     :param model_name: The model name to weight.
-     :param data_path: Path to the directory containing the dataset.
+     :param model: The model to weight.
+     :param text: The texts to use for weighting.
    :param pca_dims: The number of PCA dimensions to use.
+     :param alpha: The alpha value for SIF weighting. Tokens with probabilities above this value are downweighted.
    :return: The weighted model.
    """
-     # Load the trained model
-     model = StaticModel.from_pretrained(model_name)
-
    logging.info("Applying reweighting and PCA to the model.")
-
-     # Collect data for counting
-     paths = sorted(Path(data_path).glob("*.json"))
-     txt, _ = collect_means_and_texts(paths)
-
-     counts: Counter[str] = Counter()
-     for t in tqdm(txt):
-         counts.update(model.tokenizer.encode(t, add_special_tokens=False).ids)
-
-     sum_id = sum(counts.values()) + len(model.tokens)
-     x = np.full(len(model.embedding), 1 / sum_id)
-
-     # Weight the embeddings based on frequency
-     for word_id, count in counts.items():
-         x[word_id] = (count + 1) / sum_id
+     probas = calculate_token_probabilities(model.tokenizer, text)
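    # `probas` is presumably a vocabulary-sized array of token probabilities
    # estimated from `text`, aligned row-for-row with the embedding matrix.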

    w = model.embedding
    w = np.nan_to_num(w)
@@ -117,23 +55,25 @@ def weight_model(model_name: str, data_path: str, pca_dims: int) -> StaticModel:
    w = p.fit_transform(w)
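    # The lines constructing `p` are elided by the hunk above; judging from the
    # signature, it is presumably a PCA(n_components=pca_dims) fitted to `w` here.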

    # Apply SIF weighting
-     alpha = 1e-3
-     f = alpha / (alpha + x)
+     f = alpha / (alpha + probas)
    w *= f[:, None]
    model.embedding = w
    model.normalize = True
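    # A quick check of the SIF factor f = alpha / (alpha + p) with alpha = 1e-3:
    # a rare token with p = 1e-5 keeps f ≈ 0.99 of its norm, while a frequent
    # token with p = 1e-2 is scaled to f ≈ 0.09, so common tokens contribute far
    # less to the normalized embeddings.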

-     model.save_pretrained(f"{model_name}_weighted")
-
    return model


def main(args: Any) -> None:
    """Main function."""
-     train_model(
-         args.model_name, args.data_path, args.save_path, device=args.device, random_embeddings=args.random_embeddings
-     )
-     weight_model(args.save_path, args.data_path, 256)
+     # Collect paths for training
+     paths = sorted(Path(args.data_path).glob("*.json"))
+     train_txt, train_vec = collect_means_and_texts(paths)
+
+     model = train_model(args.model_name, train_txt, train_vec, device=args.device)
+     model.save_pretrained(args.save_path)
+     model = weight_model(model, train_txt, 256)
+     weighted_name = f"{args.save_path}_weighted"
+     model.save_pretrained(weighted_name)


if __name__ == "__main__":
@@ -151,13 +91,9 @@ def main(args: Any) -> None:
        "--data-path", type=str, default="data/fineweb_bgebase", help="Path to the directory containing the dataset."
    )
    parser.add_argument("--save-path", type=str, help="Path to save the trained model.")
-
    parser.add_argument(
        "--device", type=str, default="cpu", help="Device to run the training on (e.g., 'cpu', 'cuda')."
    )
-     parser.add_argument(
-         "--random-embeddings", action="store_true", help="Use random embeddings instead of distilling the model."
-     )

    args = parser.parse_args()
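# A hypothetical invocation of this script (the script name and the --model-name
# flag are assumptions inferred from args.model_name; only --data-path,
# --save-path, and --device appear in this hunk):
#   python train.py --model-name BAAI/bge-base-en-v1.5 \
#       --data-path data/fineweb_bgebase --save-path my_model --device cuda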