-
Notifications
You must be signed in to change notification settings - Fork 23
/
Copy pathnpy_to_lmdb.py
116 lines (95 loc) · 4.16 KB
/
npy_to_lmdb.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
# Copyright (c) Facebook, Inc. and its affiliates.
import argparse
import glob
import os
import pickle
import lmdb
import numpy as np
import tqdm
import base64
class LMDBConversion:
    """Command-line tool that converts ``.npy`` feature files to LMDB and back.

    ``convert`` packs every ``<name>.npy`` / ``<name>_info.npy`` pair found
    under ``--features_folder`` into the LMDB at ``--lmdb_path``; ``extract``
    writes such pairs back out of an LMDB into ``--features_folder``.
    """

    # Map size handed to ``lmdb.open`` when writing (2**40 bytes).
    _MAP_SIZE = 1099511627776

    # NOTE(review): convert() base64-encodes the raw array buffers, which drops
    # dtype and shape. float32 is assumed when decoding in extract() -- confirm
    # against whatever produced the .npy feature files.
    _ENCODED_DTYPE = "float32"

    def __init__(self):
        """Parse the command-line arguments and keep them on ``self.args``."""
        self.args = self.get_parser().parse_args()

    def get_parser(self):
        """Build the argument parser (``--mode``, ``--lmdb_path``, ``--features_folder``)."""
        parser = argparse.ArgumentParser(formatter_class=argparse.RawTextHelpFormatter)
        parser.add_argument(
            "--mode",
            required=True,
            type=str,
            help="Mode can either be `convert` (for conversion of \n"
            + "features to an LMDB file) or `extract` (extract \n"
            + "raw features from a LMDB file)",
        )
        parser.add_argument("--lmdb_path", required=True, type=str, help="LMDB file path")
        parser.add_argument("--features_folder", required=True, type=str, help="Features folder")
        return parser

    @staticmethod
    def _lookup(item, *names):
        """Return ``item[name]`` for the first of ``names`` present in ``item``.

        Raises:
            KeyError: if none of ``names`` is a key of ``item``.
        """
        for name in names:
            if name in item:
                return item[name]
        raise KeyError(names[0])

    @classmethod
    def _decode_array(cls, value, num_rows=None):
        """Undo the base64 encoding applied by ``convert``.

        Values that are not ``bytes`` are returned unchanged, so records that
        hold plain arrays pass straight through.  Decoded data is reshaped to
        ``(num_rows, -1)`` when ``num_rows`` evenly divides it (presumably one
        row per box -- verify against the feature extractor); otherwise the
        flat array is returned.
        """
        if not isinstance(value, (bytes, bytearray)):
            return value
        flat = np.frombuffer(base64.b64decode(value), dtype=cls._ENCODED_DTYPE)
        if num_rows and flat.size % num_rows == 0:
            return flat.reshape(num_rows, -1)
        return flat

    @classmethod
    def _info_from_item(cls, img_id, item):
        """Build the ``*_info.npy`` dictionary for one LMDB record.

        Accepts both the key names written by ``convert`` (``img_h``,
        ``img_w``, ``boxes``) and the long names that ``extract`` historically
        read (``image_height``, ``image_width``, ``bbox``).
        """
        num_boxes = item["num_boxes"]
        return {
            "image_id": img_id,
            "bbox": cls._decode_array(cls._lookup(item, "boxes", "bbox"), num_boxes),
            "num_boxes": num_boxes,
            "image_height": cls._lookup(item, "img_h", "image_height"),
            "image_width": cls._lookup(item, "img_w", "image_width"),
            "objects": item["objects"],
            "cls_prob": item.get("cls_prob"),
        }

    def convert(self):
        """Write every feature/info pair under ``features_folder`` into the LMDB.

        Each record is a pickled dict stored under the feature file's path
        relative to ``features_folder`` (without the ``.npy`` suffix).  The
        list of all record keys is stored, pickled, under ``b"keys"``.

        Raises:
            ValueError: if a feature file has no matching ``*_info.npy``.
        """
        pattern = os.path.join(self.args.features_folder, "**", "*.npy")
        feature_files = [
            path
            for path in glob.glob(pattern, recursive=True)
            if not path.endswith("_info.npy")
        ]

        id_list = []
        env = lmdb.open(self.args.lmdb_path, map_size=self._MAP_SIZE, writemap=True)
        try:
            # The transaction context manager is expected to abort on an
            # exception, so a missing info file leaves nothing half-written.
            with env.begin(write=True) as txn:
                for infile in tqdm.tqdm(feature_files):
                    info_file = infile[:-4] + "_info.npy"
                    if not os.path.isfile(info_file):
                        raise ValueError(f"Missing {info_file}")

                    # splitext strips only the trailing ".npy"; splitting on
                    # ".npy" would truncate at the first occurrence in the path.
                    rel_id = os.path.splitext(
                        os.path.relpath(infile, self.args.features_folder)
                    )[0]
                    key = rel_id.encode()
                    id_list.append(key)

                    features = np.load(infile, allow_pickle=True)
                    # The info file holds a pickled dict inside a 0-d array.
                    info = np.load(info_file, allow_pickle=True).item()

                    item = {
                        "feature_path": rel_id,
                        "features": base64.b64encode(features),
                        "img_h": info.get("image_height"),
                        "img_w": info.get("image_width"),
                        "num_boxes": info.get("num_boxes"),
                        "objects": info.get("objects"),
                        "cls_prob": info.get("cls_prob", None),
                        "boxes": base64.b64encode(info.get("bbox")),
                    }
                    txn.put(key, pickle.dumps(item))

                txn.put(b"keys", pickle.dumps(id_list))
        finally:
            env.close()

    def extract(self):
        """Write ``<id>.npy`` and ``<id>_info.npy`` for every record in the LMDB.

        Record ids are read from the pickled list stored under ``b"keys"``.
        Base64-encoded features and boxes are decoded back into arrays.

        Raises:
            ValueError: if the LMDB has no ``b"keys"`` record.
        """
        os.makedirs(self.args.features_folder, exist_ok=True)
        env = lmdb.open(
            self.args.lmdb_path,
            max_readers=1,
            readonly=True,
            lock=False,
            readahead=False,
            meminit=False,
        )
        try:
            with env.begin(write=False) as txn:
                raw_keys = txn.get(b"keys")
                if raw_keys is None:
                    raise ValueError(f"No `keys` record found in {self.args.lmdb_path}")
                # SECURITY: pickle.loads can execute arbitrary code; only run
                # this on LMDB files from a trusted source.
                image_ids = pickle.loads(raw_keys)

                for raw_id in tqdm.tqdm(image_ids):
                    item = pickle.loads(txn.get(raw_id))
                    img_id = raw_id.decode("utf-8")

                    info = self._info_from_item(img_id, item)
                    features = self._decode_array(item["features"], info["num_boxes"])

                    # Ids may contain sub-directories because convert() globs
                    # recursively, so make sure the target directory exists.
                    out_base = os.path.join(self.args.features_folder, img_id)
                    parent = os.path.dirname(out_base)
                    if parent:
                        os.makedirs(parent, exist_ok=True)

                    np.save(out_base + ".npy", features)
                    np.save(out_base + "_info.npy", info)
        finally:
            env.close()

    def execute(self):
        """Run ``convert`` or ``extract`` according to ``--mode``.

        Raises:
            ValueError: if the mode is neither ``convert`` nor ``extract``.
        """
        if self.args.mode == "convert":
            self.convert()
        elif self.args.mode == "extract":
            self.extract()
        else:
            raise ValueError("mode must be either `convert` or `extract` ")
# Script entry point: parse the command line, then run the requested mode.
if __name__ == "__main__":
    converter = LMDBConversion()
    converter.execute()