-
-
Notifications
You must be signed in to change notification settings - Fork 1.3k
/
Copy pathcheck_dump_bin.py
158 lines (144 loc) · 5.65 KB
/
check_dump_bin.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from concurrent.futures import ProcessPoolExecutor
from pathlib import Path
import datacompy
import fire
import pandas as pd
import qlib
from loguru import logger
from qlib.data import D
from tqdm import tqdm
class CheckBin:
    """Check that the ``.bin`` files produced by ``dump_bin.py`` match the original csv data.

    For every csv file, the corresponding symbol is loaded from the qlib
    provider and compared field-by-field (within float tolerances) using
    ``datacompy``. Results are bucketed into the four class-level outcome
    constants below.
    """

    # Per-symbol outcomes returned by ``_compare`` and tallied by ``check``.
    NOT_IN_FEATURES = "not in features"
    COMPARE_FALSE = "compare False"
    COMPARE_TRUE = "compare True"
    COMPARE_ERROR = "compare error"

    def __init__(
        self,
        qlib_dir: str,
        csv_path: str,
        check_fields: str = None,
        freq: str = "day",
        symbol_field_name: str = "symbol",
        date_field_name: str = "date",
        file_suffix: str = ".csv",
        max_workers: int = 16,
    ):
        """
        Parameters
        ----------
        qlib_dir : str
            qlib dir
        csv_path : str
            origin csv path; a directory of csv files or a single file
        check_fields : str, optional
            check fields, by default None, check qlib_dir/features/<first_dir>/*.<freq>.bin;
            may also be a comma-separated string or a list of field names
        freq : str, optional
            freq, value from ["day", "1m"]
        symbol_field_name: str, optional
            symbol field name, by default "symbol"
        date_field_name: str, optional
            date field name, by default "date"
        file_suffix: str, optional
            csv file suffix, by default ".csv"
        max_workers: int, optional
            max workers, by default 16
        """
        self.qlib_dir = Path(qlib_dir).expanduser()
        bin_path_list = list(self.qlib_dir.joinpath("features").iterdir())
        # features/ holds one sub-directory per instrument; lower-cased for
        # case-insensitive membership tests in ``_compare``.
        self.qlib_symbols = sorted(map(lambda x: x.name.lower(), bin_path_list))
        qlib.init(
            provider_uri=str(self.qlib_dir.resolve()),
            mount_path=str(self.qlib_dir.resolve()),
            auto_mount=False,
            redis_port=-1,  # disable the redis-backed cache for this one-shot check
        )
        csv_path = Path(csv_path).expanduser()
        self.csv_files = sorted(
            csv_path.glob(f"*{file_suffix}") if csv_path.is_dir() else [csv_path]
        )
        if check_fields is None:
            # Derive the field list from the bin files of the first instrument,
            # e.g. "close.day.bin" -> "close".
            check_fields = list(
                map(lambda x: x.name.split(".")[0], bin_path_list[0].glob("*.bin"))
            )
        else:
            check_fields = (
                check_fields.split(",")
                if isinstance(check_fields, str)
                else check_fields
            )
        self.check_fields = list(map(lambda x: x.strip(), check_fields))
        # qlib expression names are the raw field names prefixed with "$".
        self.qlib_fields = list(map(lambda x: f"${x}", self.check_fields))
        self.max_workers = max_workers
        self.symbol_field_name = symbol_field_name
        self.date_field_name = date_field_name
        self.freq = freq
        self.file_suffix = file_suffix

    def _symbol_from_path(self, file_path: Path) -> str:
        """Return the symbol encoded in ``file_path``'s file name.

        BUGFIX: the previous code used ``file_path.name.strip(self.file_suffix)``,
        but ``str.strip`` removes any characters belonging to the argument *set*
        from both ends (e.g. ``"sc.csv".strip(".csv")`` -> ``""``), mangling
        symbols whose names start or end with suffix characters. Slicing the
        exact suffix off is the correct operation.
        """
        name = file_path.name
        if self.file_suffix and name.endswith(self.file_suffix):
            name = name[: -len(self.file_suffix)]
        return name

    def _compare(self, file_path: Path):
        """Compare one csv file against its qlib bin data.

        Returns one of ``NOT_IN_FEATURES`` / ``COMPARE_TRUE`` /
        ``COMPARE_FALSE`` / ``COMPARE_ERROR``.
        """
        symbol = self._symbol_from_path(file_path)
        if symbol.lower() not in self.qlib_symbols:
            return self.NOT_IN_FEATURES
        # qlib data
        qlib_df = D.features([symbol], self.qlib_fields, freq=self.freq)
        # Drop the leading "$" so column names line up with the csv's columns.
        qlib_df.rename(
            columns={_c: _c.lstrip("$") for _c in qlib_df.columns}, inplace=True
        )
        # csv data
        origin_df = pd.read_csv(file_path)
        origin_df[self.date_field_name] = pd.to_datetime(
            origin_df[self.date_field_name]
        )
        if self.symbol_field_name not in origin_df.columns:
            origin_df[self.symbol_field_name] = symbol
        origin_df.set_index(
            [self.symbol_field_name, self.date_field_name], inplace=True
        )
        # Align index names and calendar with the qlib frame so datacompy can
        # join the two frames on the (instrument, datetime) index.
        origin_df.index.names = qlib_df.index.names
        origin_df = origin_df.reindex(qlib_df.index)
        try:
            compare = datacompy.Compare(
                origin_df,
                qlib_df,
                on_index=True,
                abs_tol=1e-08,  # Optional, defaults to 0
                rel_tol=1e-05,  # Optional, defaults to 0
                df1_name="Original",  # Optional, defaults to 'df1'
                df2_name="New",  # Optional, defaults to 'df2'
            )
            _r = compare.matches(ignore_extra_columns=True)
            return self.COMPARE_TRUE if _r else self.COMPARE_FALSE
        except Exception as e:
            logger.warning(f"{symbol} compare error: {e}")
            return self.COMPARE_ERROR

    def check(self):
        """Check whether the bin file after ``dump_bin.py`` is executed is consistent with the original csv file data"""
        logger.info("start check......")
        error_list = []
        not_in_features = []
        compare_false = []
        with tqdm(total=len(self.csv_files)) as p_bar:
            with ProcessPoolExecutor(max_workers=self.max_workers) as executor:
                for file_path, _check_res in zip(
                    self.csv_files, executor.map(self._compare, self.csv_files)
                ):
                    symbol = self._symbol_from_path(file_path)
                    if _check_res == self.NOT_IN_FEATURES:
                        not_in_features.append(symbol)
                    elif _check_res == self.COMPARE_ERROR:
                        error_list.append(symbol)
                    elif _check_res == self.COMPARE_FALSE:
                        compare_false.append(symbol)
                    p_bar.update()
        logger.info("end of check......")
        if error_list:
            logger.warning(f"compare error: {error_list}")
        if not_in_features:
            logger.warning(f"not in features: {not_in_features}")
        if compare_false:
            logger.warning(f"compare False: {compare_false}")
        logger.info(
            f"total {len(self.csv_files)}, {len(error_list)} errors, {len(not_in_features)} not in features, {len(compare_false)} compare false"
        )
if __name__ == "__main__":
    # CLI entry point: fire exposes CheckBin's constructor arguments as
    # command-line flags and its public methods (``check``) as subcommands.
    fire.Fire(CheckBin)