# vector_database.py — helpers for storing and querying embedding vectors
# in PostgreSQL (pgvector) through psycopg2.
import json
import re

import psycopg2
from psycopg2 import extras
# Keyword arguments passed verbatim to psycopg2.connect() by
# get_db_connection(). Replace the placeholder values before use.
DB_PARAMS = dict(
    host="<your_host>",
    dbname="postgres",
    user="<your_user>",
    password="<your_password>",
    sslmode="require",
)
# Open a fresh database connection from the module-level settings
def get_db_connection():
    """Return a new psycopg2 connection configured from ``DB_PARAMS``.

    Every call opens its own connection. Committing and closing it is left
    to the caller.
    """
    connect_kwargs = DB_PARAMS
    return psycopg2.connect(**connect_kwargs)
# Function to create a pgvector table and its ivfflat index
def create_index(name, dimension_size):
    """Create table *name* for embeddings plus an ivfflat index on it.

    The table has columns ``id bigserial PRIMARY KEY``,
    ``embedding vector(dimension_size)`` and ``metadata JSONB``. The index is
    named ``<name>_idx``, not ``<name>``. Both statements use
    ``IF NOT EXISTS``, so an existing table or index with those names is left
    untouched, even if its dimension differs. Both statements run in one
    transaction that is committed on success and rolled back on error. The
    connection is always closed.

    Assumes the pgvector extension is already installed in the database.
    This function does not run ``CREATE EXTENSION``.

    Args:
        name: Table name. It must be a plain, unquoted SQL identifier because
            it is spliced into the DDL text, where placeholders cannot bind
            identifiers, and because ``<name>_idx`` must also be a valid bare
            index name.
        dimension_size: Embedding dimension. Any value ``int()`` accepts is
            allowed, as long as it yields a positive integer.

    Returns:
        ``name``, unchanged.

    Raises:
        ValueError: if ``name`` or ``dimension_size`` is invalid. Both are
            checked before any connection is opened.
    """
    # A letter or underscore first, then word characters or "$". This is
    # Postgres' unquoted-identifier syntax, and it rules out quotes, spaces,
    # ";" and anything else that could alter the DDL.
    if not isinstance(name, str) or not re.fullmatch(r"[^\W\d][\w$]*", name):
        raise ValueError(f"invalid table name: {name!r}")
    try:
        dims = int(dimension_size)
    except (TypeError, ValueError):
        raise ValueError(
            f"dimension_size must be a positive integer, got {dimension_size!r}"
        ) from None
    if dims <= 0:
        raise ValueError(
            f"dimension_size must be a positive integer, got {dimension_size!r}"
        )

    conn = get_db_connection()
    try:
        # `with conn` commits on success and rolls back on exception, but it
        # does not close the connection. The outer try/finally handles that.
        with conn, conn.cursor() as cur:
            cur.execute(f"""
                CREATE TABLE IF NOT EXISTS {name} (
                    id bigserial PRIMARY KEY,
                    embedding vector({dims}),
                    metadata JSONB
                );
            """)
            # No operator class or WITH (lists = ...) is specified, so
            # pgvector's defaults for ivfflat apply.
            # NOTE(review): pgvector recommends building ivfflat indexes only
            # after the table holds data. When built on an empty table, the
            # list centroids are not trained on real vectors. Consider
            # creating or REINDEXing the index after the initial load.
            cur.execute(f"""
                CREATE INDEX IF NOT EXISTS {name}_idx
                ON {name} USING ivfflat (embedding);
            """)
    finally:
        conn.close()
    return name
# Function to upsert one or more (id, embedding, metadata) rows
def upsert_vectors(table_name, vectors):
    """Insert or update ``(id, embedding, metadata)`` rows in *table_name*.

    If a row's ``id`` already exists, its embedding and metadata are
    overwritten (``ON CONFLICT (id) DO UPDATE``). All rows are written in
    one transaction, which is committed on success and rolled back on
    error. The connection is always closed.

    Args:
        table_name: Target table. It must be a plain or double-quoted SQL
            identifier, optionally qualified with up to two dotted parts,
            because it is spliced into the statement text.
        vectors: Iterable of ``(id, embedding, metadata)`` triples.
            ``embedding`` is handed to psycopg2 as-is (presumably a list of
            floats or a ``'[...]'`` string; confirm against callers).
            ``metadata`` may be a JSON string, or a ``dict``/``list``, which
            is wrapped in ``psycopg2.extras.Json`` for the JSONB column.

    Raises:
        ValueError: if ``table_name`` is not an acceptable identifier. This
            is checked before any connection is opened.
    """
    ident = r'(?:[^\W\d][\w$]*|"(?:[^"]|"")+")'
    if not isinstance(table_name, str) or not re.fullmatch(
        rf"{ident}(?:\.{ident}){{0,2}}", table_name
    ):
        raise ValueError(f"invalid table name: {table_name!r}")
    # execute_batch treats "%" in the query text as a placeholder marker, so
    # a literal "%" inside a quoted identifier has to be doubled.
    table_sql = table_name.replace("%", "%%")

    # psycopg2 cannot adapt a bare dict (or adapt a list as JSON), so wrap
    # those as Json. Strings, such as pre-serialized JSON, pass through.
    rows = [
        (
            row_id,
            embedding,
            psycopg2.extras.Json(metadata)
            if isinstance(metadata, (dict, list))
            else metadata,
        )
        for row_id, embedding, metadata in vectors
    ]

    conn = get_db_connection()
    try:
        # `with conn` commits on success and rolls back on error; it does not
        # close the connection, hence the finally.
        with conn, conn.cursor() as cur:
            psycopg2.extras.execute_batch(
                cur,
                f"""
                INSERT INTO {table_sql} (id, embedding, metadata) VALUES (%s, %s, %s)
                ON CONFLICT (id) DO UPDATE SET embedding = EXCLUDED.embedding, metadata = EXCLUDED.metadata;
                """,
                rows,
            )
    finally:
        conn.close()
# Function to find the stored rows nearest to a query vector
def query_vector(table_name, input_vector, limit=5):
    """Return up to *limit* rows of *table_name* nearest to *input_vector*.

    Rows are ordered by pgvector's ``<->`` operator (L2 distance according
    to the pgvector docs), closest first. The query vector and the limit are
    sent as bound parameters, not spliced into the SQL text.

    Args:
        table_name: Table to search. It must be a plain or double-quoted SQL
            identifier, optionally qualified with up to two dotted parts.
        input_vector: Iterable of numbers. Each element is rendered with
            ``str()`` into pgvector's ``'[x1,x2,...]'`` text form.
        limit: Maximum number of rows to return.

    Returns:
        List of ``(id, embedding, metadata)`` tuples from ``fetchall()``. The
        Python types of ``embedding``/``metadata`` depend on psycopg2's type
        adapters (presumably text for ``vector`` unless an adapter is
        registered; confirm).

    Raises:
        ValueError: if ``table_name`` is not an acceptable identifier. This
            is checked before any connection is opened.
    """
    ident = r'(?:[^\W\d][\w$]*|"(?:[^"]|"")+")'
    if not isinstance(table_name, str) or not re.fullmatch(
        rf"{ident}(?:\.{ident}){{0,2}}", table_name
    ):
        raise ValueError(f"invalid table name: {table_name!r}")
    # Parameters are passed below, so "%" in the query text must be escaped.
    table_sql = table_name.replace("%", "%%")
    # Same '[...]' text form as before, but psycopg2 now quotes it as a
    # parameter, so non-numeric elements cannot inject SQL.
    vector_literal = "[" + ",".join(map(str, input_vector)) + "]"

    conn = get_db_connection()
    try:
        with conn, conn.cursor() as cur:
            cur.execute(
                f"""
                SELECT id, embedding, metadata FROM {table_sql}
                ORDER BY embedding <-> %s
                LIMIT %s;
                """,
                (vector_literal, limit),
            )
            results = cur.fetchall()
    finally:
        conn.close()
    return results