forked from arjunsk/mo-benchmark-test
-
Notifications
You must be signed in to change notification settings - Fork 0
/
main.py
128 lines (99 loc) · 3.92 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
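# Database setup (run once, by hand, before executing this script):
#   CREATE DATABASE a;
#   USE a;
#   CREATE TABLE speedtest (id int, one_k_vector vecf32(1024));
# Alternative definition of the same table with extra columns (use one or the other, not both):
#   CREATE TABLE speedtest (id int, sequence_id int, token_id int, layer_id int, one_k_vector vecf32(1024));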
import binascii
import time
import numpy as np
from sqlalchemy import create_engine, text
from sqlalchemy.orm import sessionmaker
# Name of the benchmark table. NOTE(review): nothing in the visible code reads
# this constant; the SQL strings spell "speedtest" literally.
table_name = "speedtest"
# Number of elements in each randomly generated vector (passed to np.random.rand).
vec_len = 1024
# The two factors below are only ever used as a product: the total number of
# rows insert() writes and the select benchmarks expect to read back.
num_sql_stmt = 1024 * 8
num_vector_per_sql_stmt = 5
# First id written by insert(); stays None until insert() has run. The select
# benchmarks filter on "id >= start_id".
start_id = None
def to_db_hex(value, dim=None):
    """Encode a 1-D vector as the hex text of its little-endian float32 bytes.

    Args:
        value: Array-like of numbers, or None.
        dim: Optional expected number of elements. When given, a vector of
            any other length is rejected.

    Returns:
        A ``bytes`` object with two hex digits per raw byte (eight per
        element), or None when ``value`` is None.

    Raises:
        ValueError: If ``value`` is not one-dimensional, or ``dim`` is given
            and does not match the vector length.
    """
    if value is None:
        return value

    value = np.asarray(value, dtype='<f')

    if value.ndim != 1:
        raise ValueError('expected ndim to be 1')

    if dim is not None and value.shape[0] != dim:
        raise ValueError('expected %d dimensions, not %d' % (dim, value.shape[0]))

    # tobytes() always yields a packed C-order copy, so strided views (for
    # example arr[::2]) encode correctly; handing the array itself to
    # b2a_hex fails for non-contiguous buffers.
    return binascii.b2a_hex(value.tobytes())
def from_db_hex(hex_str):
    """Decode hex text into a 1-D array of little-endian float32 values.

    This is the inverse of ``to_db_hex``.

    Args:
        hex_str: Hex digits as ``bytes`` or ASCII ``str``, or None.

    Returns:
        A read-only ``numpy.ndarray`` of dtype ``'<f'`` viewing the decoded
        bytes, or None when ``hex_str`` is None.

    Raises:
        binascii.Error: If ``hex_str`` is not valid hex.
    """
    # Pass None through, as to_db_hex does, so a NULL column does not crash.
    if hex_str is None:
        return None

    raw = binascii.a2b_hex(hex_str)
    return np.frombuffer(raw, dtype='<f')
def get_max_id(engine):
    """Return the largest ``id`` in the ``speedtest`` table, or 0 if there is none.

    ``MAX(id)`` evaluates to NULL when the table holds no non-NULL ids.
    Mapping that to 0 keeps ``get_max_id(engine) + 1`` usable as the next
    free id on an empty table instead of raising ``TypeError``.

    Args:
        engine: SQLAlchemy engine connected to the benchmark database.

    Returns:
        The maximum id, or 0 when the query yields NULL.
    """
    with engine.connect() as con:
        # scalar() returns the first column of the first row.
        max_id = con.execute(text('SELECT MAX(id) FROM speedtest')).scalar()
    return 0 if max_id is None else max_id
def correctness_test():
    """Round-trip two random vectors through the database and check both read paths.

    Inserts two rows with fresh ids, then reads them back twice: once as the
    bracketed, comma-separated text form and once hex-encoded. Each result is
    compared with the inserted data using ``np.allclose``.

    Raises:
        AssertionError: If the new maximum id is not the second inserted id,
            or either read path returns data that differs from what was
            inserted.
    """
    # NOTE(review): the credential/host part of this URL ("[email protected]")
    # looks like residue of web e-mail obfuscation; restore the real
    # user:password@host before running.
    engine = create_engine("mysql+pymysql://root:[email protected]:6001/a")

    # The maximum is None on an empty table; treat that as 0.
    new_id = (get_max_id(engine) or 0) + 1

    data = [np.random.rand(vec_len), np.random.rand(vec_len)]
    sql_insert = text("insert into speedtest (id, one_k_vector) values(:id, decode(:data,'hex') );")

    # insert 2 new rows; the context manager closes the session afterwards
    Session = sessionmaker(bind=engine)
    with Session() as session:
        for i, arr in enumerate(data):
            session.execute(sql_insert, {"id": new_id + i, "data": to_db_hex(arr)})
        session.commit()

    new_max_id = get_max_id(engine)
    assert new_max_id == new_id + 1

    # select using string output; columns are named explicitly so row[1] is
    # the vector regardless of how many columns the table has
    sql_select_string = text(
        'SELECT id, one_k_vector FROM speedtest WHERE id >= :new_id order by id'
    )
    with engine.connect() as con:
        rs = con.execute(sql_select_string, {"new_id": new_id})
        data_read = []
        for row in rs:
            s = row[1].lstrip().lstrip("[").rstrip().rstrip("]")
            data_read.append(np.fromstring(s, sep=","))
    assert np.allclose(data, data_read)

    # select using hex output
    sql_select_hex = text(
        'SELECT id, encode(one_k_vector, "hex") FROM speedtest WHERE id >= :new_id order by id'
    )
    with engine.connect() as con:
        rs = con.execute(sql_select_hex, {"new_id": new_id})
        data_read = [from_db_hex(row[1]) for row in rs]
    assert np.allclose(data, data_read)
def insert():
    """Insert ``num_sql_stmt * num_vector_per_sql_stmt`` random vectors, one row per statement.

    Stores the first id used in the module-level ``start_id`` so the select
    benchmarks can filter on the rows written here.
    """
    global start_id

    # NOTE(review): the credential/host part of this URL ("[email protected]")
    # looks like residue of web e-mail obfuscation; restore the real
    # user:password@host before running.
    engine = create_engine("mysql+pymysql://root:[email protected]:6001/a")

    # The maximum is None on an empty table; treat that as 0.
    start_id = (get_max_id(engine) or 0) + 1

    sql_insert = text("insert into speedtest (id, one_k_vector) values(:id, decode(:data,'hex') );")
    total_rows = num_sql_stmt * num_vector_per_sql_stmt

    # The context manager closes the session even if an insert fails.
    Session = sessionmaker(bind=engine)
    with Session() as session:
        for i in range(total_rows):
            arr = np.random.rand(vec_len)
            session.execute(sql_insert, {"id": start_id + i, "data": to_db_hex(arr)})
        # NOTE(review): the original indentation was lost, so it is unclear
        # whether the commit was per row or once at the end; a single commit
        # after the loop is used here. Confirm against the intended benchmark.
        session.commit()
def select_string():
    """Read back the rows written by ``insert()`` in text form and parse each vector.

    Raises:
        RuntimeError: If ``insert()`` has not run yet (``start_id`` is None).
        AssertionError: If the number of rows read differs from
            ``num_sql_stmt * num_vector_per_sql_stmt``.
    """
    if start_id is None:
        raise RuntimeError("insert() must run before select_string()")

    # NOTE(review): the credential/host part of this URL ("[email protected]")
    # looks like residue of web e-mail obfuscation; restore the real
    # user:password@host before running.
    engine = create_engine("mysql+pymysql://root:[email protected]:6001/a")

    # Name the columns so row[1] is the vector whatever the table layout is.
    sql_select = text('SELECT id, one_k_vector FROM speedtest WHERE id >= :start_id')
    with engine.connect() as con:
        rs = con.execute(sql_select, {"start_id": start_id})
        total = 0
        for row in rs:
            s = row[1].lstrip().lstrip("[").rstrip().rstrip("]")
            # The parse is part of the measured work; its result is discarded.
            np.fromstring(s, sep=",")
            total += 1
    assert total == num_sql_stmt * num_vector_per_sql_stmt
def select_hex():
    """Read back the rows written by ``insert()`` hex-encoded and decode each vector.

    The query asks the server for ``encode(one_k_vector, "hex")``, the same
    expression ``correctness_test`` uses, so the second column is hex text
    that ``from_db_hex`` can decode. A plain ``SELECT *`` returns the
    bracketed text form, which is not valid hex.

    Raises:
        RuntimeError: If ``insert()`` has not run yet (``start_id`` is None).
        AssertionError: If the number of rows read differs from
            ``num_sql_stmt * num_vector_per_sql_stmt``.
    """
    if start_id is None:
        raise RuntimeError("insert() must run before select_hex()")

    # NOTE(review): the credential/host part of this URL ("[email protected]")
    # looks like residue of web e-mail obfuscation; restore the real
    # user:password@host before running.
    engine = create_engine("mysql+pymysql://root:[email protected]:6001/a")

    sql_select = text(
        'SELECT id, encode(one_k_vector, "hex") FROM speedtest WHERE id >= :start_id'
    )
    with engine.connect() as con:
        rs = con.execute(sql_select, {"start_id": start_id})
        total = 0
        for row in rs:
            # The decode is part of the measured work; its result is discarded.
            from_db_hex(row[1])
            total += 1
    assert total == num_sql_stmt * num_vector_per_sql_stmt


correctness_test()

# Time the insert and then each read path. insert() must come first because it
# sets start_id, which both select benchmarks filter on.
for func in [insert, select_string, select_hex]:
    # perf_counter is monotonic and high-resolution, unlike wall-clock time.time().
    start = time.perf_counter()
    func()
    duration = time.perf_counter() - start
    print(f"{func.__name__} Result: vector dim={vec_len} vectors "
          f"rows={num_sql_stmt * num_vector_per_sql_stmt} "
          f"rows/second={num_sql_stmt * num_vector_per_sql_stmt / duration}")