Skip to content

Commit 56f76e4

Browse files
committed
feat: Add PGChatMessageHistory
1 parent 164810f commit 56f76e4

File tree

6 files changed

+712
-0
lines changed

6 files changed

+712
-0
lines changed

langchain_postgres/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
from langchain_postgres.chat_message_histories import PostgresChatMessageHistory
44
from langchain_postgres.translator import PGVectorTranslator
5+
from langchain_postgres.v2.chat_message_history import PGChatMessageHistory
56
from langchain_postgres.v2.engine import Column, ColumnDict, PGEngine
67
from langchain_postgres.v2.vectorstores import PGVectorStore
78
from langchain_postgres.vectorstores import PGVector
@@ -18,6 +19,7 @@
1819
"ColumnDict",
1920
"PGEngine",
2021
"PostgresChatMessageHistory",
22+
"PGChatMessageHistory",
2123
"PGVector",
2224
"PGVectorStore",
2325
"PGVectorTranslator",
Lines changed: 182 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,182 @@
1+
from __future__ import annotations
2+
3+
import json
4+
from typing import Sequence
5+
6+
from langchain_core.chat_history import BaseChatMessageHistory
7+
from langchain_core.messages import BaseMessage, message_to_dict, messages_from_dict
8+
from sqlalchemy import RowMapping, text
9+
from sqlalchemy.ext.asyncio import AsyncEngine
10+
11+
from .engine import PGEngine
12+
13+
14+
class AsyncPGChatMessageHistory(BaseChatMessageHistory):
15+
"""Chat message history stored in a PostgreSQL database."""
16+
17+
__create_key = object()
18+
19+
def __init__(
20+
self,
21+
key: object,
22+
pool: AsyncEngine,
23+
session_id: str,
24+
table_name: str,
25+
store_message: bool,
26+
schema_name: str = "public",
27+
):
28+
"""AsyncPGChatMessageHistory constructor.
29+
30+
Args:
31+
key (object): Key to prevent direct constructor usage.
32+
engine (PGEngine): Database connection pool.
33+
session_id (str): Retrieve the table content with this session ID.
34+
table_name (str): Table name that stores the chat message history.
35+
store_message (bool): Whether to store the whole message or store data & type seperately
36+
schema_name (str): The schema name where the table is located (default: "public").
37+
38+
Raises:
39+
Exception: If constructor is directly called by the user.
40+
"""
41+
if key != AsyncPGChatMessageHistory.__create_key:
42+
raise Exception(
43+
"Only create class through 'create' or 'create_sync' methods!"
44+
)
45+
self.pool = pool
46+
self.session_id = session_id
47+
self.table_name = table_name
48+
self.schema_name = schema_name
49+
self.store_message = store_message
50+
51+
@classmethod
52+
async def create(
53+
cls,
54+
engine: PGEngine,
55+
session_id: str,
56+
table_name: str,
57+
schema_name: str = "public",
58+
) -> AsyncPGChatMessageHistory:
59+
"""Create a new AsyncPGChatMessageHistory instance.
60+
61+
Args:
62+
engine (PGEngine): PGEngine to use.
63+
session_id (str): Retrieve the table content with this session ID.
64+
table_name (str): Table name that stores the chat message history.
65+
schema_name (str): The schema name where the table is located (default: "public").
66+
67+
Raises:
68+
IndexError: If the table provided does not contain required schema.
69+
70+
Returns:
71+
AsyncPGChatMessageHistory: A newly created instance of AsyncPGChatMessageHistory.
72+
"""
73+
column_names = await engine._aload_table_schema(table_name, schema_name)
74+
75+
required_columns = ["id", "session_id", "data", "type"]
76+
supported_columns = ["id", "session_id", "message", "created_at"]
77+
78+
if not (all(x in column_names for x in required_columns)):
79+
if not (all(x in column_names for x in supported_columns)):
80+
raise IndexError(
81+
f"Table '{schema_name}'.'{table_name}' has incorrect schema. Got "
82+
f"column names '{column_names}' but required column names "
83+
f"'{required_columns}'.\nPlease create table with following schema:"
84+
f"\nCREATE TABLE {schema_name}.{table_name} ("
85+
"\n id INT AUTO_INCREMENT PRIMARY KEY,"
86+
"\n session_id TEXT NOT NULL,"
87+
"\n data JSON NOT NULL,"
88+
"\n type TEXT NOT NULL"
89+
"\n);"
90+
)
91+
92+
store_message = True if "message" in column_names else False
93+
94+
return cls(
95+
cls.__create_key,
96+
engine._pool,
97+
session_id,
98+
table_name,
99+
store_message,
100+
schema_name,
101+
)
102+
103+
def _insert_query(self, message: BaseMessage) -> tuple[str, dict]:
104+
if self.store_message:
105+
query = f"""INSERT INTO "{self.schema_name}"."{self.table_name}"(session_id, message) VALUES (:session_id, :message)"""
106+
params = {
107+
"message": json.dumps(message_to_dict(message)),
108+
"session_id": self.session_id,
109+
}
110+
else:
111+
query = f"""INSERT INTO "{self.schema_name}"."{self.table_name}"(session_id, data, type) VALUES (:session_id, :data, :type)"""
112+
params = {
113+
"data": json.dumps(message.model_dump()),
114+
"session_id": self.session_id,
115+
"type": message.type,
116+
}
117+
118+
return query, params
119+
120+
async def aadd_message(self, message: BaseMessage) -> None:
121+
"""Append the message to the record in Postgres"""
122+
query, params = self._insert_query(message)
123+
async with self.pool.connect() as conn:
124+
await conn.execute(text(query), params)
125+
await conn.commit()
126+
127+
async def aadd_messages(self, messages: Sequence[BaseMessage]) -> None:
128+
"""Append a list of messages to the record in Postgres"""
129+
for message in messages:
130+
await self.aadd_message(message)
131+
132+
async def aclear(self) -> None:
133+
"""Clear session memory from Postgres"""
134+
query = f"""DELETE FROM "{self.schema_name}"."{self.table_name}" WHERE session_id = :session_id;"""
135+
async with self.pool.connect() as conn:
136+
await conn.execute(text(query), {"session_id": self.session_id})
137+
await conn.commit()
138+
139+
def _select_query(self) -> str:
140+
if self.store_message:
141+
return f"""SELECT message FROM "{self.schema_name}"."{self.table_name}" WHERE session_id = :session_id ORDER BY id;"""
142+
else:
143+
return f"""SELECT data, type FROM "{self.schema_name}"."{self.table_name}" WHERE session_id = :session_id ORDER BY id;"""
144+
145+
def _convert_to_messages(self, rows: Sequence[RowMapping]) -> list[BaseMessage]:
146+
if self.store_message:
147+
items = [row["message"] for row in rows]
148+
messages = messages_from_dict(items)
149+
else:
150+
items = [{"data": row["data"], "type": row["type"]} for row in rows]
151+
messages = messages_from_dict(items)
152+
return messages
153+
154+
async def _aget_messages(self) -> list[BaseMessage]:
155+
"""Retrieve the messages from Postgres."""
156+
157+
query = self._select_query()
158+
159+
async with self.pool.connect() as conn:
160+
result = await conn.execute(text(query), {"session_id": self.session_id})
161+
result_map = result.mappings()
162+
results = result_map.fetchall()
163+
if not results:
164+
return []
165+
166+
messages = self._convert_to_messages(results)
167+
return messages
168+
169+
def clear(self) -> None:
170+
raise NotImplementedError(
171+
"Sync methods are not implemented for AsyncPGChatMessageHistory. Use PGChatMessageHistory interface instead."
172+
)
173+
174+
def add_message(self, message: BaseMessage) -> None:
175+
raise NotImplementedError(
176+
"Sync methods are not implemented for AsyncPGChatMessageHistory. Use PGChatMessageHistory interface instead."
177+
)
178+
179+
def add_messages(self, messages: Sequence[BaseMessage]) -> None:
180+
raise NotImplementedError(
181+
"Sync methods are not implemented for AsyncPGChatMessageHistory. Use PGChatMessageHistory interface instead."
182+
)
Lines changed: 129 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,129 @@
1+
from __future__ import annotations
2+
3+
from typing import Sequence
4+
5+
from langchain_core.chat_history import BaseChatMessageHistory
6+
from langchain_core.messages import BaseMessage
7+
8+
from .async_chat_message_history import AsyncPGChatMessageHistory
9+
from .engine import PGEngine
10+
11+
12+
class PGChatMessageHistory(BaseChatMessageHistory):
13+
"""Chat message history stored in a PostgreSQL database."""
14+
15+
__create_key = object()
16+
17+
def __init__(
18+
self,
19+
key: object,
20+
engine: PGEngine,
21+
history: AsyncPGChatMessageHistory,
22+
):
23+
"""PGChatMessageHistory constructor.
24+
25+
Args:
26+
key (object): Key to prevent direct constructor usage.
27+
engine (PGEngine): Database connection pool.
28+
history (AsyncPGChatMessageHistory): Async only implementation.
29+
30+
Raises:
31+
Exception: If constructor is directly called by the user.
32+
"""
33+
if key != PGChatMessageHistory.__create_key:
34+
raise Exception(
35+
"Only create class through 'create' or 'create_sync' methods!"
36+
)
37+
self._engine = engine
38+
self.__history = history
39+
40+
@classmethod
41+
async def create(
42+
cls,
43+
engine: PGEngine,
44+
session_id: str,
45+
table_name: str,
46+
schema_name: str = "public",
47+
) -> PGChatMessageHistory:
48+
"""Create a new PGChatMessageHistory instance.
49+
50+
Args:
51+
engine (PGEngine): PGEngine to use.
52+
session_id (str): Retrieve the table content with this session ID.
53+
table_name (str): Table name that stores the chat message history.
54+
schema_name (str): The schema name where the table is located (default: "public").
55+
56+
Raises:
57+
IndexError: If the table provided does not contain required schema.
58+
59+
Returns:
60+
PGChatMessageHistory: A newly created instance of PGChatMessageHistory.
61+
"""
62+
coro = AsyncPGChatMessageHistory.create(
63+
engine, session_id, table_name, schema_name
64+
)
65+
history = await engine._run_as_async(coro)
66+
return cls(cls.__create_key, engine, history)
67+
68+
@classmethod
69+
def create_sync(
70+
cls,
71+
engine: PGEngine,
72+
session_id: str,
73+
table_name: str,
74+
schema_name: str = "public",
75+
) -> PGChatMessageHistory:
76+
"""Create a new PGChatMessageHistory instance.
77+
78+
Args:
79+
engine (PGEngine): PGEngine to use.
80+
session_id (str): Retrieve the table content with this session ID.
81+
table_name (str): Table name that stores the chat message history.
82+
schema_name: The schema name where the table is located (default: "public").
83+
84+
Raises:
85+
IndexError: If the table provided does not contain required schema.
86+
87+
Returns:
88+
PGChatMessageHistory: A newly created instance of PGChatMessageHistory.
89+
"""
90+
coro = AsyncPGChatMessageHistory.create(
91+
engine, session_id, table_name, schema_name
92+
)
93+
history = engine._run_as_sync(coro)
94+
return cls(cls.__create_key, engine, history)
95+
96+
@property
97+
def messages(self) -> list[BaseMessage]:
98+
"""Fetches all messages stored in Postgres."""
99+
return self._engine._run_as_sync(self.__history._aget_messages())
100+
101+
@messages.setter
102+
def messages(self, value: list[BaseMessage]) -> None:
103+
"""Clear the stored messages and appends a list of messages to the record in Postgres."""
104+
self.clear()
105+
self.add_messages(value)
106+
107+
async def aadd_message(self, message: BaseMessage) -> None:
108+
"""Append the message to the record in Postgres"""
109+
await self._engine._run_as_async(self.__history.aadd_message(message))
110+
111+
def add_message(self, message: BaseMessage) -> None:
112+
"""Append the message to the record in Postgres"""
113+
self._engine._run_as_sync(self.__history.aadd_message(message))
114+
115+
async def aadd_messages(self, messages: Sequence[BaseMessage]) -> None:
116+
"""Append a list of messages to the record in Postgres"""
117+
await self._engine._run_as_async(self.__history.aadd_messages(messages))
118+
119+
def add_messages(self, messages: Sequence[BaseMessage]) -> None:
120+
"""Append a list of messages to the record in Postgres"""
121+
self._engine._run_as_sync(self.__history.aadd_messages(messages))
122+
123+
async def aclear(self) -> None:
124+
"""Clear session memory from Postgres"""
125+
await self._engine._run_as_async(self.__history.aclear())
126+
127+
def clear(self) -> None:
128+
"""Clear session memory from Postgres"""
129+
self._engine._run_as_sync(self.__history.aclear())

0 commit comments

Comments
 (0)