11
11
from tqdm import tqdm
12
12
13
13
_SAVE_INTERVAL = 10
14
- _MAX_MEANS = 1100000
15
-
16
14
logger = logging .getLogger (__name__ )
17
15
18
16
@@ -35,13 +33,14 @@ def save_data(means: list[np.ndarray], txts: list[str], base_filepath: str) -> N
35
33
logger .info (f"Saved { len (txts )} texts to { items_filepath } and vectors to { vectors_filepath } " )
36
34
37
35
38
- def featurize (texts : Iterable [str ], model : SentenceTransformer , output_dir : str ) -> None :
36
+ def featurize (texts : Iterable [str ], model : SentenceTransformer , output_dir : str , max_means : int ) -> None :
39
37
"""
40
38
Featurize text using a sentence transformer.
41
39
42
40
:param texts: Iterable of texts to featurize.
43
41
:param model: SentenceTransformer model to use.
44
42
:param output_dir: Directory to save the featurized texts.
43
+ :param max_means: Maximum number of mean embeddings to generate.
45
44
:raises ValueError: If the model does not have a fixed dimension.
46
45
"""
47
46
out_path = Path (output_dir )
@@ -58,8 +57,6 @@ def featurize(texts: Iterable[str], model: SentenceTransformer, output_dir: str)
58
57
for index , batch in enumerate (tqdm (batched (texts , 32 ))):
59
58
i = index // _SAVE_INTERVAL
60
59
base_filename = f"featurized_{ i } "
61
- vectors_filepath = out_path / (base_filename + "_vectors.npy" )
62
- items_filepath = out_path / (base_filename + "_items.json" )
63
60
list_batch = [x ["text" ].strip () for x in batch if x .get ("text" )]
64
61
if not list_batch :
65
62
continue # Skip empty batches
@@ -92,7 +89,7 @@ def featurize(texts: Iterable[str], model: SentenceTransformer, output_dir: str)
92
89
means .append (mean )
93
90
total_means += 1
94
91
95
- if total_means >= _MAX_MEANS :
92
+ if total_means >= max_means :
96
93
save_data (means , txts , str (out_path / base_filename ))
97
94
return
98
95
@@ -121,11 +118,46 @@ def main() -> None:
121
118
default = "data/c4_bgebase" ,
122
119
help = "Directory to save the featurized texts." ,
123
120
)
121
+ parser .add_argument (
122
+ "--dataset-path" ,
123
+ type = str ,
124
+ default = "allenai/c4" ,
125
+ help = "The dataset path or name (e.g. 'allenai/c4')." ,
126
+ )
127
+ parser .add_argument (
128
+ "--dataset-name" ,
129
+ type = str ,
130
+ default = "en" ,
131
+ help = "The dataset configuration name (e.g., 'en' for C4)." ,
132
+ )
133
+ parser .add_argument (
134
+ "--dataset-split" ,
135
+ type = str ,
136
+ default = "train" ,
137
+ help = "The dataset split (e.g., 'train', 'validation')." ,
138
+ )
139
+ parser .add_argument (
140
+ "--no-streaming" ,
141
+ action = "store_true" ,
142
+ help = "Disable streaming mode when loading the dataset." ,
143
+ )
144
+ parser .add_argument (
145
+ "--max-means" ,
146
+ type = int ,
147
+ default = 1000000 ,
148
+ help = "The maximum number of mean embeddings to generate." ,
149
+ )
150
+
124
151
args = parser .parse_args ()
125
152
126
153
model = SentenceTransformer (args .model_name )
127
- dataset = load_dataset ("allenai/c4" , name = "en" , split = "train" , streaming = True )
128
- featurize (dataset , model , args .output_dir )
154
+ dataset = load_dataset (
155
+ args .dataset_path ,
156
+ name = args .dataset_name ,
157
+ split = args .dataset_split ,
158
+ streaming = not args .no_streaming ,
159
+ )
160
+ featurize (dataset , model , args .output_dir , args .max_means )
129
161
130
162
131
163
if __name__ == "__main__" :
0 commit comments