-
Notifications
You must be signed in to change notification settings - Fork 0
/
preprocess.py
65 lines (51 loc) · 2.03 KB
/
preprocess.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
from Embeddings.EmbeddingModelOpenAI import EmbeddingModelOpenAI
from data import data_utils
import pickle
import time
import os
import logging
# Configure the root logger when the module is loaded: every record is
# prefixed with a timestamp and its level name, and INFO and above are emitted.
logging.basicConfig(format="%(asctime)s %(levelname)s:%(message)s", level=logging.INFO)
# Module-level logger, named after this module, used by the script below.
logger = logging.getLogger(__name__)
def _dump_pickle(obj, path):
    """Serialize *obj* to the file at *path* using pickle protocol 3."""
    # Protocol 3 is kept on purpose: the original author chose it for
    # compatibility with colab / python 3.6, and it is readable by any
    # Python 3 interpreter.
    with open(path, "wb") as f:
        pickle.dump(obj, f, protocol=3)


def main():
    """Chunk the token corpus, embed every chunk and pickle both stages.

    Steps:
      1. Load the token data via ``data_utils.load_token_data()`` and join
         each document's tokens into a single space-separated string.
      2. Split every document into chunk dicts with
         ``data_utils.create_chunk_dict_list``, tagging each chunk with the
         document id (the frame index) and its author.
      3. Pickle the chunk list to
         ``<script dir>/out/gutenberg_chunked/test01_<timestamp>.pickle``.
      4. Embed all chunk texts, store each vector under the ``"embedding"``
         key of its chunk dict, and pickle the enriched list to
         ``test01_embeddings_<timestamp>.pickle`` in the same directory.

    Raises:
        ValueError: If the embedding model returns a different number of
            embeddings than there are chunks.
    """
    embeddings_model = EmbeddingModelOpenAI()
    data = data_utils.load_token_data()

    # join tokens to text
    data["text"] = data["text"].apply(" ".join).tolist()

    dict_chunk_list = []
    for doc_id, text, author in zip(data.index, data["text"], data["author"]):
        dict_chunk_list.extend(
            data_utils.create_chunk_dict_list(text, {"id": doc_id, "author": author})
        )

    # Create the output folder next to this script -- the same location the
    # pickles are written to -- so the run works from any working directory.
    # exist_ok avoids the check-then-create race of a separate exists() test.
    out_dir = os.path.join(os.path.dirname(__file__), "out", "gutenberg_chunked")
    os.makedirs(out_dir, exist_ok=True)

    # A single timestamp is shared by both output files so they can be paired.
    timestamp = time.strftime("%Y%m%d_%H%M%S")

    # Persist the raw chunks before embedding, so they survive a failure in
    # the embedding step.
    data_dump_path = os.path.join(out_dir, f"test01_{timestamp}.pickle")
    _dump_pickle(dict_chunk_list, data_dump_path)
    logger.info("Successfully saved chunked data to %s.", data_dump_path)

    chunk_list = [chunk_dict["text"] for chunk_dict in dict_chunk_list]
    corpus_embeddings = embeddings_model.embed_document_list(chunk_list)

    # Explicit raise instead of ``assert`` so the check also runs under
    # ``python -O``; embeddings are paired with chunks purely by position.
    if len(corpus_embeddings) != len(dict_chunk_list):
        raise ValueError("Number of embeddings and number of chunks do not match.")

    for chunk_dict, embedding in zip(dict_chunk_list, corpus_embeddings):
        chunk_dict["embedding"] = embedding

    data_dump_path = os.path.join(out_dir, f"test01_embeddings_{timestamp}.pickle")
    _dump_pickle(dict_chunk_list, data_dump_path)
    logger.info("Successfully saved embedding data to %s.", data_dump_path)

    logger.info("Done")


if __name__ == "__main__":
    main()