@@ -1,17 +1,39 @@
 import argparse
+import json
+import logging
 from pathlib import Path
 from typing import Iterable
 
 import numpy as np
 from datasets import load_dataset
 from more_itertools import batched
-from reach import Reach
 from sentence_transformers import SentenceTransformer
 from tqdm import tqdm
 
 _SAVE_INTERVAL = 10
 _MAX_MEANS = 1100000
 
+logger = logging.getLogger(__name__)
+
+
+def save_data(means: list[np.ndarray], txts: list[str], base_filepath: str) -> None:
+    """
+    Save the means and texts to separate files.
+
+    :param means: List of numpy arrays representing the mean embeddings.
+    :param txts: List of texts corresponding to the embeddings.
+    :param base_filepath: Base path for the output files.
+    """
+    vectors_filepath = base_filepath + "_vectors.npy"
+    items_filepath = base_filepath + "_items.json"
+
+    # Save the embeddings (vectors) to a .npy file
+    np.save(vectors_filepath, np.array(means))
+    # Save the texts to a JSON file
+    with open(items_filepath, "w") as f:
+        json.dump({"items": txts}, f)
+    logger.info(f"Saved {len(txts)} texts to {items_filepath} and vectors to {vectors_filepath}")
+
 
 def featurize(texts: Iterable[str], model: SentenceTransformer, output_dir: str) -> None:
     """
@@ -35,55 +57,76 @@ def featurize(texts: Iterable[str], model: SentenceTransformer, output_dir: str)
 
     for index, batch in enumerate(tqdm(batched(texts, 32))):
         i = index // _SAVE_INTERVAL
-        if (out_path / f"featurized_{i}.json").exists():
-            continue
-        # Consume the generator
+        base_filename = f"featurized_{i}"
+        vectors_filepath = out_path / (base_filename + "_vectors.npy")
+        items_filepath = out_path / (base_filename + "_items.json")
         list_batch = [x["text"].strip() for x in batch if x.get("text")]
+        if not list_batch:
+            continue  # Skip empty batches
+
+        # Encode the batch to get token embeddings
+        token_embeddings = model.encode(
+            list_batch,
+            output_value="token_embeddings",
+            convert_to_tensor=True,
+        )
 
-        # Already truncated to model max_length
+        # Tokenize the batch to get input IDs
         tokenized_ids = model.tokenize(list_batch)["input_ids"]
-        token_embeddings: list[np.ndarray] = [
-            x.cpu().numpy() for x in model.encode(list_batch, output_value="token_embeddings", convert_to_numpy=True)
-        ]
 
-        for tokenized_id, token_embedding in zip(tokenized_ids, token_embeddings, strict=True):
-            # Truncate to actual length of vectors, remove CLS and SEP.
-            text = model.tokenizer.decode(tokenized_id[1 : len(token_embedding) - 1])
+        for tokenized_id, token_embedding in zip(tokenized_ids, token_embeddings):
+            # Convert token IDs to tokens (excluding special tokens)
+            token_ids = tokenized_id[1:-1]
+            # Decode tokens to text
+            text = model.tokenizer.decode(token_ids)
             if text in seen:
                 continue
             seen.add(text)
-            mean = np.mean(token_embedding[1:-1], axis=0)
+            # Get the corresponding token embeddings (excluding special tokens)
+            token_embeds = token_embedding[1:-1]
+            # Convert embeddings to NumPy arrays
+            token_embeds = token_embeds.detach().cpu().numpy()
+            # Compute the mean of the token embeddings
+            mean = np.mean(token_embeds, axis=0)
             txts.append(text)
             means.append(mean)
            total_means += 1
 
             if total_means >= _MAX_MEANS:
-                # Save the final batch and stop
-                r = Reach(means, txts)
-                r.save(out_path / f"featurized_{(index // _SAVE_INTERVAL)}.json")
+                save_data(means, txts, str(out_path / base_filename))
                 return
 
         if index > 0 and (index + 1) % _SAVE_INTERVAL == 0:
-            r = Reach(means, txts)
-            r.save(out_path / f"featurized_{(index // _SAVE_INTERVAL)}.json")
+            save_data(means, txts, str(out_path / base_filename))
             txts = []
             means = []
             seen = set()
     else:
-        if means:
-            r = Reach(means, txts)
-            r.save(out_path / f"featurized_{(index // _SAVE_INTERVAL)}.json")
+        if txts and means:
+            save_data(means, txts, str(out_path / base_filename))
 
 
-if __name__ == "__main__":
-    parser = argparse.ArgumentParser(description="Train a Model2Vec using tokenlearn.")
+def main() -> None:
+    """Main function to featurize texts using a sentence transformer."""
+    parser = argparse.ArgumentParser(description="Featurize texts using a sentence transformer.")
     parser.add_argument(
         "--model-name",
         type=str,
         default="baai/bge-base-en-v1.5",
         help="The model name for distillation (e.g., 'baai/bge-base-en-v1.5').",
     )
+    parser.add_argument(
+        "--output-dir",
+        type=str,
+        default="data/c4_bgebase",
+        help="Directory to save the featurized texts.",
+    )
     args = parser.parse_args()
+
    model = SentenceTransformer(args.model_name)
     dataset = load_dataset("allenai/c4", name="en", split="train", streaming=True)
-    featurize(dataset, model, "data/c4_bgebase")
+    featurize(dataset, model, args.output_dir)
+
+
+if __name__ == "__main__":
+    main()
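
On the input side, featurize expects an iterable of records with a "text" field, which is exactly what the streaming C4 split yields. A minimal programmatic call might look like the sketch below; the import path featurize is a placeholder, since this diff does not show the script's file name, and the tiny in-memory dataset is purely illustrative:

from sentence_transformers import SentenceTransformer

from featurize import featurize  # hypothetical import path for this script

model = SentenceTransformer("baai/bge-base-en-v1.5")
docs = [{"text": "A short example document."}, {"text": "Another example."}]
# With fewer than _SAVE_INTERVAL batches, the for/else branch saves everything once
# at the end, writing featurized_0_vectors.npy and featurized_0_items.json.
featurize(docs, model, "data/smoke_test")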
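
On the output side, each shard written by save_data is a plain NumPy array plus a JSON list of the corresponding texts, so it can be read back without any extra dependency. A minimal sketch, assuming a shard named featurized_0 under the default data/c4_bgebase output directory:

import json

import numpy as np

# The paths below are assumptions based on this script's defaults; adjust to your run.
vectors = np.load("data/c4_bgebase/featurized_0_vectors.npy")
with open("data/c4_bgebase/featurized_0_items.json") as f:
    items = json.load(f)["items"]

assert len(items) == len(vectors)  # one mean token embedding per unique text
print(vectors.shape)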