From 56f76e4e0e7d9f0a82fca34f194551286d3d20cc Mon Sep 17 00:00:00 2001 From: Disha Prakash Date: Tue, 22 Apr 2025 13:26:34 +0000 Subject: [PATCH 1/6] feat: Add PGChatMessageHistory --- langchain_postgres/__init__.py | 2 + .../v2/async_chat_message_history.py | 182 ++++++++++++++++++ langchain_postgres/v2/chat_message_history.py | 129 +++++++++++++ langchain_postgres/v2/engine.py | 96 +++++++++ .../v2/test_async_chat_message_history.py | 128 ++++++++++++ .../v2/test_chat_message_history.py | 175 +++++++++++++++++ 6 files changed, 712 insertions(+) create mode 100644 langchain_postgres/v2/async_chat_message_history.py create mode 100644 langchain_postgres/v2/chat_message_history.py create mode 100644 tests/unit_tests/v2/test_async_chat_message_history.py create mode 100644 tests/unit_tests/v2/test_chat_message_history.py diff --git a/langchain_postgres/__init__.py b/langchain_postgres/__init__.py index 15a9230..6b73876 100644 --- a/langchain_postgres/__init__.py +++ b/langchain_postgres/__init__.py @@ -2,6 +2,7 @@ from langchain_postgres.chat_message_histories import PostgresChatMessageHistory from langchain_postgres.translator import PGVectorTranslator +from langchain_postgres.v2.chat_message_history import PGChatMessageHistory from langchain_postgres.v2.engine import Column, ColumnDict, PGEngine from langchain_postgres.v2.vectorstores import PGVectorStore from langchain_postgres.vectorstores import PGVector @@ -18,6 +19,7 @@ "ColumnDict", "PGEngine", "PostgresChatMessageHistory", + "PGChatMessageHistory", "PGVector", "PGVectorStore", "PGVectorTranslator", diff --git a/langchain_postgres/v2/async_chat_message_history.py b/langchain_postgres/v2/async_chat_message_history.py new file mode 100644 index 0000000..abb8202 --- /dev/null +++ b/langchain_postgres/v2/async_chat_message_history.py @@ -0,0 +1,182 @@ +from __future__ import annotations + +import json +from typing import Sequence + +from langchain_core.chat_history import BaseChatMessageHistory +from langchain_core.messages import BaseMessage, message_to_dict, messages_from_dict +from sqlalchemy import RowMapping, text +from sqlalchemy.ext.asyncio import AsyncEngine + +from .engine import PGEngine + + +class AsyncPGChatMessageHistory(BaseChatMessageHistory): + """Chat message history stored in a PostgreSQL database.""" + + __create_key = object() + + def __init__( + self, + key: object, + pool: AsyncEngine, + session_id: str, + table_name: str, + store_message: bool, + schema_name: str = "public", + ): + """AsyncPGChatMessageHistory constructor. + + Args: + key (object): Key to prevent direct constructor usage. + engine (PGEngine): Database connection pool. + session_id (str): Retrieve the table content with this session ID. + table_name (str): Table name that stores the chat message history. + store_message (bool): Whether to store the whole message or store data & type seperately + schema_name (str): The schema name where the table is located (default: "public"). + + Raises: + Exception: If constructor is directly called by the user. + """ + if key != AsyncPGChatMessageHistory.__create_key: + raise Exception( + "Only create class through 'create' or 'create_sync' methods!" + ) + self.pool = pool + self.session_id = session_id + self.table_name = table_name + self.schema_name = schema_name + self.store_message = store_message + + @classmethod + async def create( + cls, + engine: PGEngine, + session_id: str, + table_name: str, + schema_name: str = "public", + ) -> AsyncPGChatMessageHistory: + """Create a new AsyncPGChatMessageHistory instance. + + Args: + engine (PGEngine): PGEngine to use. + session_id (str): Retrieve the table content with this session ID. + table_name (str): Table name that stores the chat message history. + schema_name (str): The schema name where the table is located (default: "public"). + + Raises: + IndexError: If the table provided does not contain required schema. + + Returns: + AsyncPGChatMessageHistory: A newly created instance of AsyncPGChatMessageHistory. + """ + column_names = await engine._aload_table_schema(table_name, schema_name) + + required_columns = ["id", "session_id", "data", "type"] + supported_columns = ["id", "session_id", "message", "created_at"] + + if not (all(x in column_names for x in required_columns)): + if not (all(x in column_names for x in supported_columns)): + raise IndexError( + f"Table '{schema_name}'.'{table_name}' has incorrect schema. Got " + f"column names '{column_names}' but required column names " + f"'{required_columns}'.\nPlease create table with following schema:" + f"\nCREATE TABLE {schema_name}.{table_name} (" + "\n id INT AUTO_INCREMENT PRIMARY KEY," + "\n session_id TEXT NOT NULL," + "\n data JSON NOT NULL," + "\n type TEXT NOT NULL" + "\n);" + ) + + store_message = True if "message" in column_names else False + + return cls( + cls.__create_key, + engine._pool, + session_id, + table_name, + store_message, + schema_name, + ) + + def _insert_query(self, message: BaseMessage) -> tuple[str, dict]: + if self.store_message: + query = f"""INSERT INTO "{self.schema_name}"."{self.table_name}"(session_id, message) VALUES (:session_id, :message)""" + params = { + "message": json.dumps(message_to_dict(message)), + "session_id": self.session_id, + } + else: + query = f"""INSERT INTO "{self.schema_name}"."{self.table_name}"(session_id, data, type) VALUES (:session_id, :data, :type)""" + params = { + "data": json.dumps(message.model_dump()), + "session_id": self.session_id, + "type": message.type, + } + + return query, params + + async def aadd_message(self, message: BaseMessage) -> None: + """Append the message to the record in Postgres""" + query, params = self._insert_query(message) + async with self.pool.connect() as conn: + await conn.execute(text(query), params) + await conn.commit() + + async def aadd_messages(self, messages: Sequence[BaseMessage]) -> None: + """Append a list of messages to the record in Postgres""" + for message in messages: + await self.aadd_message(message) + + async def aclear(self) -> None: + """Clear session memory from Postgres""" + query = f"""DELETE FROM "{self.schema_name}"."{self.table_name}" WHERE session_id = :session_id;""" + async with self.pool.connect() as conn: + await conn.execute(text(query), {"session_id": self.session_id}) + await conn.commit() + + def _select_query(self) -> str: + if self.store_message: + return f"""SELECT message FROM "{self.schema_name}"."{self.table_name}" WHERE session_id = :session_id ORDER BY id;""" + else: + return f"""SELECT data, type FROM "{self.schema_name}"."{self.table_name}" WHERE session_id = :session_id ORDER BY id;""" + + def _convert_to_messages(self, rows: Sequence[RowMapping]) -> list[BaseMessage]: + if self.store_message: + items = [row["message"] for row in rows] + messages = messages_from_dict(items) + else: + items = [{"data": row["data"], "type": row["type"]} for row in rows] + messages = messages_from_dict(items) + return messages + + async def _aget_messages(self) -> list[BaseMessage]: + """Retrieve the messages from Postgres.""" + + query = self._select_query() + + async with self.pool.connect() as conn: + result = await conn.execute(text(query), {"session_id": self.session_id}) + result_map = result.mappings() + results = result_map.fetchall() + if not results: + return [] + + messages = self._convert_to_messages(results) + return messages + + def clear(self) -> None: + raise NotImplementedError( + "Sync methods are not implemented for AsyncPGChatMessageHistory. Use PGChatMessageHistory interface instead." + ) + + def add_message(self, message: BaseMessage) -> None: + raise NotImplementedError( + "Sync methods are not implemented for AsyncPGChatMessageHistory. Use PGChatMessageHistory interface instead." + ) + + def add_messages(self, messages: Sequence[BaseMessage]) -> None: + raise NotImplementedError( + "Sync methods are not implemented for AsyncPGChatMessageHistory. Use PGChatMessageHistory interface instead." + ) diff --git a/langchain_postgres/v2/chat_message_history.py b/langchain_postgres/v2/chat_message_history.py new file mode 100644 index 0000000..548935b --- /dev/null +++ b/langchain_postgres/v2/chat_message_history.py @@ -0,0 +1,129 @@ +from __future__ import annotations + +from typing import Sequence + +from langchain_core.chat_history import BaseChatMessageHistory +from langchain_core.messages import BaseMessage + +from .async_chat_message_history import AsyncPGChatMessageHistory +from .engine import PGEngine + + +class PGChatMessageHistory(BaseChatMessageHistory): + """Chat message history stored in a PostgreSQL database.""" + + __create_key = object() + + def __init__( + self, + key: object, + engine: PGEngine, + history: AsyncPGChatMessageHistory, + ): + """PGChatMessageHistory constructor. + + Args: + key (object): Key to prevent direct constructor usage. + engine (PGEngine): Database connection pool. + history (AsyncPGChatMessageHistory): Async only implementation. + + Raises: + Exception: If constructor is directly called by the user. + """ + if key != PGChatMessageHistory.__create_key: + raise Exception( + "Only create class through 'create' or 'create_sync' methods!" + ) + self._engine = engine + self.__history = history + + @classmethod + async def create( + cls, + engine: PGEngine, + session_id: str, + table_name: str, + schema_name: str = "public", + ) -> PGChatMessageHistory: + """Create a new PGChatMessageHistory instance. + + Args: + engine (PGEngine): PGEngine to use. + session_id (str): Retrieve the table content with this session ID. + table_name (str): Table name that stores the chat message history. + schema_name (str): The schema name where the table is located (default: "public"). + + Raises: + IndexError: If the table provided does not contain required schema. + + Returns: + PGChatMessageHistory: A newly created instance of PGChatMessageHistory. + """ + coro = AsyncPGChatMessageHistory.create( + engine, session_id, table_name, schema_name + ) + history = await engine._run_as_async(coro) + return cls(cls.__create_key, engine, history) + + @classmethod + def create_sync( + cls, + engine: PGEngine, + session_id: str, + table_name: str, + schema_name: str = "public", + ) -> PGChatMessageHistory: + """Create a new PGChatMessageHistory instance. + + Args: + engine (PGEngine): PGEngine to use. + session_id (str): Retrieve the table content with this session ID. + table_name (str): Table name that stores the chat message history. + schema_name: The schema name where the table is located (default: "public"). + + Raises: + IndexError: If the table provided does not contain required schema. + + Returns: + PGChatMessageHistory: A newly created instance of PGChatMessageHistory. + """ + coro = AsyncPGChatMessageHistory.create( + engine, session_id, table_name, schema_name + ) + history = engine._run_as_sync(coro) + return cls(cls.__create_key, engine, history) + + @property + def messages(self) -> list[BaseMessage]: + """Fetches all messages stored in Postgres.""" + return self._engine._run_as_sync(self.__history._aget_messages()) + + @messages.setter + def messages(self, value: list[BaseMessage]) -> None: + """Clear the stored messages and appends a list of messages to the record in Postgres.""" + self.clear() + self.add_messages(value) + + async def aadd_message(self, message: BaseMessage) -> None: + """Append the message to the record in Postgres""" + await self._engine._run_as_async(self.__history.aadd_message(message)) + + def add_message(self, message: BaseMessage) -> None: + """Append the message to the record in Postgres""" + self._engine._run_as_sync(self.__history.aadd_message(message)) + + async def aadd_messages(self, messages: Sequence[BaseMessage]) -> None: + """Append a list of messages to the record in Postgres""" + await self._engine._run_as_async(self.__history.aadd_messages(messages)) + + def add_messages(self, messages: Sequence[BaseMessage]) -> None: + """Append a list of messages to the record in Postgres""" + self._engine._run_as_sync(self.__history.aadd_messages(messages)) + + async def aclear(self) -> None: + """Clear session memory from Postgres""" + await self._engine._run_as_async(self.__history.aclear()) + + def clear(self) -> None: + """Clear session memory from Postgres""" + self._engine._run_as_sync(self.__history.aclear()) diff --git a/langchain_postgres/v2/engine.py b/langchain_postgres/v2/engine.py index c2a0d93..5c97de3 100644 --- a/langchain_postgres/v2/engine.py +++ b/langchain_postgres/v2/engine.py @@ -347,6 +347,65 @@ def init_vectorstore_table( ) ) + async def _ainit_chat_history_table( + self, table_name: str, schema_name: str = "public" + ) -> None: + """ + Create a postgres table to save chat history messages. + + Args: + table_name (str): The table name to store chat history. + schema_name (str): The schema name to store the chat history table. + Default: "public". + + Returns: + None + """ + create_table_query = f"""CREATE TABLE IF NOT EXISTS "{schema_name}"."{table_name}"( + id SERIAL PRIMARY KEY, + session_id TEXT NOT NULL, + message JSONB NOT NULL, + created_at TIMESTAMPTZ NOT NULL DEFAULT NOW() + );""" + async with self._pool.connect() as conn: + await conn.execute(text(create_table_query)) + await conn.commit() + + async def ainit_chat_history_table( + self, table_name: str, schema_name: str = "public" + ) -> None: + """Create a postgres table to save chat history messages. + + Args: + table_name (str): The table name to store chat history. + schema_name (str): The schema name to store chat history table. + Default: "public". + + Returns: + None + """ + await self._run_as_async( + self._ainit_chat_history_table( + table_name, + schema_name, + ) + ) + + def init_chat_history_table( + self, table_name: str, schema_name: str = "public" + ) -> None: + """Create a postgres table to store chat history. + + Args: + table_name (str): Table name to store chat history. + schema_name (str): The schema name to store chat history table. + Default: "public". + + Returns: + None + """ + self._run_as_sync(self._ainit_chat_history_table(table_name, schema_name)) + async def _adrop_table( self, table_name: str, @@ -378,3 +437,40 @@ async def drop_table( self._run_as_sync( self._adrop_table(table_name=table_name, schema_name=schema_name) ) + + async def _aload_table_schema( + self, table_name: str, schema_name: str = "public" + ) -> list[str]: + """ + Load table schema from an existing table in a PgSQL database, potentially from a specific database schema. + + Args: + table_name: The name of the table to load the table schema from. + schema_name: The name of the database schema where the table resides. + Default: "public". + + Returns: + (lsit[str]: list of all column names in the table.) + """ + + query = """ + SELECT column_name + FROM information_schema.columns + WHERE table_schema = :schema + AND table_name = :table + ORDER BY ordinal_position; + """ + + async with self._pool.connect() as conn: + result = await conn.execute( + text(query), {"schema": schema_name, "table": table_name} + ) + result_map = result.mappings() + results = result_map.fetchall() + + column_names = [row["column_name"] for row in results] + + if column_names: + return column_names + else: + raise ValueError(f'Table, "{schema_name}"."{table_name}", does not exist: ') diff --git a/tests/unit_tests/v2/test_async_chat_message_history.py b/tests/unit_tests/v2/test_async_chat_message_history.py new file mode 100644 index 0000000..8015f17 --- /dev/null +++ b/tests/unit_tests/v2/test_async_chat_message_history.py @@ -0,0 +1,128 @@ +import uuid +from typing import AsyncIterator + +import pytest +import pytest_asyncio +from langchain_core.messages.ai import AIMessage +from langchain_core.messages.human import HumanMessage +from langchain_core.messages.system import SystemMessage +from sqlalchemy import text + +from langchain_postgres import PGEngine, PostgresChatMessageHistory +from langchain_postgres.v2.async_chat_message_history import ( + AsyncPGChatMessageHistory, +) +from tests.utils import VECTORSTORE_CONNECTION_STRING, asyncpg_client + +TABLE_NAME = "message_store" + str(uuid.uuid4()) +TABLE_NAME_ASYNC = "message_store" + str(uuid.uuid4()) + + +async def aexecute(engine: PGEngine, query: str) -> None: + async with engine._pool.connect() as conn: + await conn.execute(text(query)) + await conn.commit() + + +@pytest_asyncio.fixture +async def async_engine() -> AsyncIterator[PGEngine]: + async_engine = PGEngine.from_connection_string(url=VECTORSTORE_CONNECTION_STRING) + await async_engine._ainit_chat_history_table(table_name=TABLE_NAME_ASYNC) + yield async_engine + # use default table for AsyncPGChatMessageHistory + query = f'DROP TABLE IF EXISTS "{TABLE_NAME_ASYNC}"' + await aexecute(async_engine, query) + await async_engine.close() + + +@pytest.mark.asyncio +async def test_chat_message_history_async( + async_engine: PGEngine, +) -> None: + history = await AsyncPGChatMessageHistory.create( + engine=async_engine, session_id="test", table_name=TABLE_NAME_ASYNC + ) + msg1 = HumanMessage(content="hi!") + msg2 = AIMessage(content="whats up?") + await history.aadd_message(msg1) + await history.aadd_message(msg2) + messages = await history._aget_messages() + + # verify messages are correct + assert messages[0].content == "hi!" + assert type(messages[0]) is HumanMessage + assert messages[1].content == "whats up?" + assert type(messages[1]) is AIMessage + + # verify clear() clears message history + await history.aclear() + assert len(await history._aget_messages()) == 0 + + +@pytest.mark.asyncio +async def test_chat_message_history_sync_messages( + async_engine: PGEngine, +) -> None: + history1 = await AsyncPGChatMessageHistory.create( + engine=async_engine, session_id="test", table_name=TABLE_NAME_ASYNC + ) + history2 = await AsyncPGChatMessageHistory.create( + engine=async_engine, session_id="test", table_name=TABLE_NAME_ASYNC + ) + msg1 = HumanMessage(content="hi!") + msg2 = AIMessage(content="whats up?") + await history1.aadd_message(msg1) + await history2.aadd_message(msg2) + + assert len(await history1._aget_messages()) == 2 + assert len(await history2._aget_messages()) == 2 + + # verify clear() clears message history + await history2.aclear() + assert len(await history2._aget_messages()) == 0 + + +@pytest.mark.asyncio +async def test_chat_table_async(async_engine: PGEngine) -> None: + with pytest.raises(ValueError): + await AsyncPGChatMessageHistory.create( + engine=async_engine, session_id="test", table_name="doesnotexist" + ) + + +@pytest.mark.asyncio +async def test_v1_schema_support(async_engine: PGEngine) -> None: + table_name = "chat_history" + session_id = str(uuid.UUID(int=125)) + async with asyncpg_client() as async_connection: + await PostgresChatMessageHistory.adrop_table(async_connection, table_name) + await PostgresChatMessageHistory.acreate_tables(async_connection, table_name) + + chat_history = PostgresChatMessageHistory( + table_name, session_id, async_connection=async_connection + ) + + await chat_history.aadd_messages( + [ + SystemMessage(content="Meow"), + AIMessage(content="woof"), + HumanMessage(content="bark"), + ] + ) + + history = await AsyncPGChatMessageHistory.create( + engine=async_engine, session_id=session_id, table_name=table_name + ) + + messages = await history._aget_messages() + + assert len(messages) == 3 + + msg1 = HumanMessage(content="hi!") + await history.aadd_message(msg1) + + messages = await history._aget_messages() + + assert len(messages) == 4 + + await async_engine._adrop_table(table_name=table_name) diff --git a/tests/unit_tests/v2/test_chat_message_history.py b/tests/unit_tests/v2/test_chat_message_history.py new file mode 100644 index 0000000..f8701e4 --- /dev/null +++ b/tests/unit_tests/v2/test_chat_message_history.py @@ -0,0 +1,175 @@ +import uuid +from typing import Any, AsyncIterator + +import pytest +import pytest_asyncio +from langchain_core.messages.ai import AIMessage +from langchain_core.messages.human import HumanMessage +from sqlalchemy import text + +from langchain_postgres import PGChatMessageHistory, PGEngine +from tests.utils import VECTORSTORE_CONNECTION_STRING + +TABLE_NAME = "message_store" + str(uuid.uuid4()) +TABLE_NAME_ASYNC = "message_store" + str(uuid.uuid4()) + + +async def aexecute( + engine: PGEngine, + query: str, +) -> None: + async def run(engine: PGEngine, query: str) -> None: + async with engine._pool.connect() as conn: + await conn.execute(text(query)) + await conn.commit() + + await engine._run_as_async(run(engine, query)) + + +@pytest_asyncio.fixture +async def engine() -> AsyncIterator[PGEngine]: + engine = PGEngine.from_connection_string(url=VECTORSTORE_CONNECTION_STRING) + engine.init_chat_history_table(table_name=TABLE_NAME) + yield engine + # use default table for PGChatMessageHistory + query = f'DROP TABLE IF EXISTS "{TABLE_NAME}"' + await aexecute(engine, query) + await engine.close() + + +@pytest_asyncio.fixture +async def async_engine() -> AsyncIterator[PGEngine]: + async_engine = PGEngine.from_connection_string(url=VECTORSTORE_CONNECTION_STRING) + await async_engine.ainit_chat_history_table(table_name=TABLE_NAME_ASYNC) + yield async_engine + # use default table for PGChatMessageHistory + query = f'DROP TABLE IF EXISTS "{TABLE_NAME_ASYNC}"' + await aexecute(async_engine, query) + await async_engine.close() + + +def test_chat_message_history(engine: PGEngine) -> None: + history = PGChatMessageHistory.create_sync( + engine=engine, session_id="test", table_name=TABLE_NAME + ) + history.add_user_message("hi!") + history.add_ai_message("whats up?") + messages = history.messages + + # verify messages are correct + assert messages[0].content == "hi!" + assert type(messages[0]) is HumanMessage + assert messages[1].content == "whats up?" + assert type(messages[1]) is AIMessage + + # verify clear() clears message history + history.clear() + assert len(history.messages) == 0 + + +def test_chat_table(engine: Any) -> None: + with pytest.raises(ValueError): + PGChatMessageHistory.create_sync( + engine=engine, session_id="test", table_name="doesnotexist" + ) + + +@pytest.mark.asyncio +async def test_chat_schema(engine: Any) -> None: + doc_table_name = "test_table" + str(uuid.uuid4()) + engine.init_document_table(table_name=doc_table_name) + with pytest.raises(IndexError): + PGChatMessageHistory.create_sync( + engine=engine, session_id="test", table_name=doc_table_name + ) + + query = f'DROP TABLE IF EXISTS "{doc_table_name}"' + await aexecute(engine, query) + + +@pytest.mark.asyncio +async def test_chat_message_history_async( + async_engine: PGEngine, +) -> None: + history = await PGChatMessageHistory.create( + engine=async_engine, session_id="test", table_name=TABLE_NAME_ASYNC + ) + msg1 = HumanMessage(content="hi!") + msg2 = AIMessage(content="whats up?") + await history.aadd_message(msg1) + await history.aadd_message(msg2) + messages = history.messages + + # verify messages are correct + assert messages[0].content == "hi!" + assert type(messages[0]) is HumanMessage + assert messages[1].content == "whats up?" + assert type(messages[1]) is AIMessage + + # verify clear() clears message history + await history.aclear() + assert len(history.messages) == 0 + + +@pytest.mark.asyncio +async def test_chat_message_history_sync_messages( + async_engine: PGEngine, +) -> None: + history1 = await PGChatMessageHistory.create( + engine=async_engine, session_id="test", table_name=TABLE_NAME_ASYNC + ) + history2 = await PGChatMessageHistory.create( + engine=async_engine, session_id="test", table_name=TABLE_NAME_ASYNC + ) + msg1 = HumanMessage(content="hi!") + msg2 = AIMessage(content="whats up?") + await history1.aadd_message(msg1) + await history2.aadd_message(msg2) + + assert len(history1.messages) == 2 + assert len(history2.messages) == 2 + + # verify clear() clears message history + await history2.aclear() + assert len(history2.messages) == 0 + + +@pytest.mark.asyncio +async def test_chat_message_history_set_messages( + async_engine: PGEngine, +) -> None: + history = await PGChatMessageHistory.create( + engine=async_engine, session_id="test", table_name=TABLE_NAME_ASYNC + ) + msg1 = HumanMessage(content="hi!") + msg2 = AIMessage(content="bye -_-") + # verify setting messages property adds to message history + history.messages = [msg1, msg2] + assert len(history.messages) == 2 + + +@pytest.mark.asyncio +async def test_chat_table_async(async_engine: PGEngine) -> None: + with pytest.raises(ValueError): + await PGChatMessageHistory.create( + engine=async_engine, session_id="test", table_name="doesnotexist" + ) + + +@pytest.mark.asyncio +async def test_cross_env_chat_message_history(engine: PGEngine) -> None: + history = PGChatMessageHistory.create_sync( + engine=engine, session_id="test_cross", table_name=TABLE_NAME + ) + await history.aadd_message(HumanMessage(content="hi!")) + messages = history.messages + assert messages[0].content == "hi!" + history.clear() + + history = await PGChatMessageHistory.create( + engine=engine, session_id="test_cross", table_name=TABLE_NAME + ) + history.add_message(HumanMessage(content="hi!")) + messages = history.messages + assert messages[0].content == "hi!" + history.clear() From a4846f3945c46a0065f9a38b2b06719a7bb19ea0 Mon Sep 17 00:00:00 2001 From: Disha Prakash Date: Wed, 23 Apr 2025 02:33:53 +0000 Subject: [PATCH 2/6] Add incorrect schema test --- .../v2/test_async_chat_message_history.py | 11 ++++++++ .../v2/test_chat_message_history.py | 25 +++++++++++++------ 2 files changed, 28 insertions(+), 8 deletions(-) diff --git a/tests/unit_tests/v2/test_async_chat_message_history.py b/tests/unit_tests/v2/test_async_chat_message_history.py index 8015f17..30579a1 100644 --- a/tests/unit_tests/v2/test_async_chat_message_history.py +++ b/tests/unit_tests/v2/test_async_chat_message_history.py @@ -126,3 +126,14 @@ async def test_v1_schema_support(async_engine: PGEngine) -> None: assert len(messages) == 4 await async_engine._adrop_table(table_name=table_name) + + +async def test_incorrect_schema(async_engine: PGEngine) -> None: + table_name = "incorrect_schema_" + str(uuid.uuid4()) + await async_engine._ainit_vectorstore_table(table_name=table_name, vector_size=1024) + with pytest.raises(IndexError): + await AsyncPGChatMessageHistory.create( + engine=async_engine, session_id="test", table_name=table_name + ) + query = f'DROP TABLE IF EXISTS "{table_name}"' + await aexecute(async_engine, query) diff --git a/tests/unit_tests/v2/test_chat_message_history.py b/tests/unit_tests/v2/test_chat_message_history.py index f8701e4..749d914 100644 --- a/tests/unit_tests/v2/test_chat_message_history.py +++ b/tests/unit_tests/v2/test_chat_message_history.py @@ -74,17 +74,26 @@ def test_chat_table(engine: Any) -> None: ) -@pytest.mark.asyncio -async def test_chat_schema(engine: Any) -> None: - doc_table_name = "test_table" + str(uuid.uuid4()) - engine.init_document_table(table_name=doc_table_name) +async def test_incorrect_schema_async(async_engine: PGEngine) -> None: + table_name = "incorrect_schema_" + str(uuid.uuid4()) + await async_engine.ainit_vectorstore_table(table_name=table_name, vector_size=1024) with pytest.raises(IndexError): - PGChatMessageHistory.create_sync( - engine=engine, session_id="test", table_name=doc_table_name + await PGChatMessageHistory.create( + engine=async_engine, session_id="test", table_name=table_name ) + query = f'DROP TABLE IF EXISTS "{table_name}"' + await aexecute(async_engine, query) - query = f'DROP TABLE IF EXISTS "{doc_table_name}"' - await aexecute(engine, query) + +async def test_incorrect_schema_sync(async_engine: PGEngine) -> None: + table_name = "incorrect_schema_" + str(uuid.uuid4()) + async_engine.init_vectorstore_table(table_name=table_name, vector_size=1024) + with pytest.raises(IndexError): + PGChatMessageHistory.create_sync( + engine=async_engine, session_id="test", table_name=table_name + ) + query = f'DROP TABLE IF EXISTS "{table_name}"' + await aexecute(async_engine, query) @pytest.mark.asyncio From 0d78b876bd79ad81eb279b71f6a9e547637a808d Mon Sep 17 00:00:00 2001 From: Disha Prakash Date: Fri, 25 Apr 2025 10:12:03 +0000 Subject: [PATCH 3/6] chore(docs): Add documentation for PGChatMessageHistory --- examples/pg_chat_message_history.ipynb | 330 +++++++++++++++++++++++++ 1 file changed, 330 insertions(+) create mode 100644 examples/pg_chat_message_history.ipynb diff --git a/examples/pg_chat_message_history.ipynb b/examples/pg_chat_message_history.ipynb new file mode 100644 index 0000000..d835f35 --- /dev/null +++ b/examples/pg_chat_message_history.ipynb @@ -0,0 +1,330 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# PGChatMessageHistory\n", + "\n", + "`PGChatMessageHistory` is a an implementation of the the LangChain ChatMessageHistory abstraction using `postgres` as the backend.\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "IR54BmgvdHT_" + }, + "source": [ + "## Install\n", + "\n", + "Install the `langchain-postgres` package." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 1000 + }, + "id": "0ZITIDE160OD", + "outputId": "e184bc0d-6541-4e0a-82d2-1e216db00a2d", + "tags": [] + }, + "outputs": [], + "source": [ + "%pip install --upgrade --quiet langchain-postgres" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "QuQigs4UoFQ2" + }, + "source": [ + "## Create an engine\n", + "\n", + "The first step is to create a `PGEngine` instance, which does the following:\n", + "\n", + "1. Allows you to create tables for storing documents and embeddings.\n", + "2. Maintains a connection pool that manages connections to the database. This allows sharing of the connection pool and helps to reduce latency for database calls." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "from langchain_postgres import PGEngine\n", + "\n", + "# See docker command above to launch a Postgres instance with pgvector enabled.\n", + "# Replace these values with your own configuration.\n", + "POSTGRES_USER = \"langchain\"\n", + "POSTGRES_PASSWORD = \"langchain\"\n", + "POSTGRES_HOST = \"localhost\"\n", + "POSTGRES_PORT = \"6024\"\n", + "POSTGRES_DB = \"langchain\"\n", + "\n", + "CONNECTION_STRING = (\n", + " f\"postgresql+asyncpg://{POSTGRES_USER}:{POSTGRES_PASSWORD}@{POSTGRES_HOST}\"\n", + " f\":{POSTGRES_PORT}/{POSTGRES_DB}\"\n", + ")\n", + "\n", + "pg_engine = PGEngine.from_connection_string(url=CONNECTION_STRING)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "To use psycopg3 driver, set your connection string to `postgresql+psycopg://`" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "D9Xs2qhm6X56" + }, + "source": [ + "### Initialize a table\n", + "The `PGChatMessageHistory` class requires a database table with a specific schema in order to store the chat message history.\n", + "\n", + "The `PGEngine` engine has a helper method `init_chat_history_table()` that can be used to create a table with the proper schema for you." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "TABLE_NAME = \"chat history\"\n", + "\n", + "pg_engine.init_chat_history_table(table_name=TABLE_NAME)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### Optional Tip: 💡\n", + "You can also specify a schema name by passing `schema_name` wherever you pass `table_name`. Eg:\n", + "\n", + "```python\n", + "SCHEMA_NAME=\"my_schema\"\n", + "\n", + "engine.init_chat_history_table(\n", + " table_name=TABLE_NAME,\n", + " schema_name=SCHEMA_NAME # Default: \"public\"\n", + ")\n", + "```" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### PGChatMessageHistory\n", + "\n", + "To initialize the `PGChatMessageHistory` class you need to provide only 3 things:\n", + "\n", + "1. `engine` - An instance of a `PGEngine` engine.\n", + "1. `session_id` - A unique identifier string that specifies an id for the session.\n", + "1. `table_name` : The name of the table within the PG database to store the chat message history.\n", + "1. `schema_name` : The name of the database schema containing the chat message history table." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "z-AZyzAQ7bsf", + "tags": [] + }, + "outputs": [], + "source": [ + "from langchain_postgres import PGChatMessageHistory\n", + "\n", + "history = PGChatMessageHistory.create_sync(\n", + " pg_engine,\n", + " session_id=\"test_session\",\n", + " table_name=TABLE_NAME,\n", + " # schema_name=SCHEMA_NAME,\n", + ")\n", + "history.add_user_message(\"hi!\")\n", + "history.add_ai_message(\"whats up?\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "history.messages" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### Cleaning up\n", + "When the history of a specific session is obsolete and can be deleted, it can be done the following way.\n", + "\n", + "**Note:** Once deleted, the data is no longer stored in Postgres and is gone forever." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "history.clear()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 🔗 Chaining\n", + "\n", + "We can easily combine this message history class with [LCEL Runnables](/docs/expression_language/how_to/message_history)\n", + "\n", + "To do this we will use one of [Google's Vertex AI chat models](https://python.langchain.com/docs/integrations/chat/google_vertex_ai_palm)\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# enable Vertex AI API\n", + "!gcloud services enable aiplatform.googleapis.com" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder\n", + "from langchain_core.runnables.history import RunnableWithMessageHistory\n", + "from langchain_google_vertexai import ChatVertexAI" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "GOOGLE_CLOUD_PROJECT_ID = \"\"\n", + "\n", + "prompt = ChatPromptTemplate.from_messages(\n", + " [\n", + " (\"system\", \"You are a helpful assistant.\"),\n", + " MessagesPlaceholder(variable_name=\"history\"),\n", + " (\"human\", \"{question}\"),\n", + " ]\n", + ")\n", + "\n", + "chain = prompt | ChatVertexAI(\n", + " project=GOOGLE_CLOUD_PROJECT_ID, model_name=\"gemini-2.0-flash-exp\"\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "chain_with_history = RunnableWithMessageHistory(\n", + " chain,\n", + " lambda session_id: PGChatMessageHistory.create_sync(\n", + " pg_engine,\n", + " session_id=session_id,\n", + " table_name=TABLE_NAME,\n", + " # schema_name=SCHEMA_NAME,\n", + " ),\n", + " input_messages_key=\"question\",\n", + " history_messages_key=\"history\",\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "# This is where we configure the session id\n", + "config = {\"configurable\": {\"session_id\": \"test_session\"}}" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "chain_with_history.invoke({\"question\": \"Hi! I'm bob\"}, config=config)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "chain_with_history.invoke({\"question\": \"Whats my name\"}, config=config)" + ] + } + ], + "metadata": { + "colab": { + "provenance": [], + "toc_visible": true + }, + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.4" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} From d713aa8ae0624c088859514885dcda03948fb5c9 Mon Sep 17 00:00:00 2001 From: dishaprakash <57954147+dishaprakash@users.noreply.github.com> Date: Fri, 25 Apr 2025 15:51:23 +0530 Subject: [PATCH 4/6] Update test_imports.py --- tests/unit_tests/test_imports.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tests/unit_tests/test_imports.py b/tests/unit_tests/test_imports.py index 445a4b1..d43a21c 100644 --- a/tests/unit_tests/test_imports.py +++ b/tests/unit_tests/test_imports.py @@ -9,9 +9,11 @@ "PGVectorStore", "PGVectorTranslator", "PostgresChatMessageHistory", + "PGChatMessageHistory", ] def test_all_imports() -> None: """Test that __all__ is correctly defined.""" assert sorted(EXPECTED_ALL) == sorted(__all__) +, From 7e80b4f67f0659e5490b678fd522002e9071e8f4 Mon Sep 17 00:00:00 2001 From: Disha Prakash Date: Sun, 27 Apr 2025 13:56:43 +0000 Subject: [PATCH 5/6] fix test --- tests/unit_tests/test_imports.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/unit_tests/test_imports.py b/tests/unit_tests/test_imports.py index d43a21c..e1c21f4 100644 --- a/tests/unit_tests/test_imports.py +++ b/tests/unit_tests/test_imports.py @@ -16,4 +16,3 @@ def test_all_imports() -> None: """Test that __all__ is correctly defined.""" assert sorted(EXPECTED_ALL) == sorted(__all__) -, From 68741c84163da13fbe27ba9040020ed70745d8b7 Mon Sep 17 00:00:00 2001 From: Disha Prakash Date: Wed, 30 Apr 2025 06:52:02 +0000 Subject: [PATCH 6/6] review changes --- examples/pg_chat_message_history.ipynb | 26 ++++--------------- .../v2/async_chat_message_history.py | 4 +-- 2 files changed, 6 insertions(+), 24 deletions(-) diff --git a/examples/pg_chat_message_history.ipynb b/examples/pg_chat_message_history.ipynb index d835f35..7d1a8ac 100644 --- a/examples/pg_chat_message_history.ipynb +++ b/examples/pg_chat_message_history.ipynb @@ -104,28 +104,11 @@ }, "outputs": [], "source": [ - "TABLE_NAME = \"chat history\"\n", + "TABLE_NAME = \"chat_history\"\n", "\n", "pg_engine.init_chat_history_table(table_name=TABLE_NAME)" ] }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "#### Optional Tip: 💡\n", - "You can also specify a schema name by passing `schema_name` wherever you pass `table_name`. Eg:\n", - "\n", - "```python\n", - "SCHEMA_NAME=\"my_schema\"\n", - "\n", - "engine.init_chat_history_table(\n", - " table_name=TABLE_NAME,\n", - " schema_name=SCHEMA_NAME # Default: \"public\"\n", - ")\n", - "```" - ] - }, { "cell_type": "markdown", "metadata": {}, @@ -195,11 +178,11 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "## 🔗 Chaining\n", + "## Chaining\n", "\n", - "We can easily combine this message history class with [LCEL Runnables](/docs/expression_language/how_to/message_history)\n", + "We can easily combine this message history class with [LCEL Runnables](https://python.langchain.com/docs/concepts/lcel/) such as `RunnableWithMessageHistory`.\n", "\n", - "To do this we will use one of [Google's Vertex AI chat models](https://python.langchain.com/docs/integrations/chat/google_vertex_ai_palm)\n" + "To create an agent or chain, you will need a model. This example will use one of [Google's Vertex AI chat models](https://python.langchain.com/docs/integrations/chat/google_vertex_ai_palm)\n" ] }, { @@ -231,6 +214,7 @@ }, "outputs": [], "source": [ + "## Please update the project id\n", "GOOGLE_CLOUD_PROJECT_ID = \"\"\n", "\n", "prompt = ChatPromptTemplate.from_messages(\n", diff --git a/langchain_postgres/v2/async_chat_message_history.py b/langchain_postgres/v2/async_chat_message_history.py index abb8202..a80440c 100644 --- a/langchain_postgres/v2/async_chat_message_history.py +++ b/langchain_postgres/v2/async_chat_message_history.py @@ -84,7 +84,7 @@ async def create( f"\nCREATE TABLE {schema_name}.{table_name} (" "\n id INT AUTO_INCREMENT PRIMARY KEY," "\n session_id TEXT NOT NULL," - "\n data JSON NOT NULL," + "\n data JSONB NOT NULL," "\n type TEXT NOT NULL" "\n);" ) @@ -153,9 +153,7 @@ def _convert_to_messages(self, rows: Sequence[RowMapping]) -> list[BaseMessage]: async def _aget_messages(self) -> list[BaseMessage]: """Retrieve the messages from Postgres.""" - query = self._select_query() - async with self.pool.connect() as conn: result = await conn.execute(text(query), {"session_id": self.session_id}) result_map = result.mappings()