Spaces:
No application file
No application file
"""Client for persisting chat message history in a Postgres database. | |
This client provides support for both sync and async via psycopg 3. | |
""" | |
from __future__ import annotations | |
import json | |
import logging | |
import re | |
import uuid | |
from typing import List, Optional, Sequence | |
import psycopg | |
from langchain_core.chat_history import BaseChatMessageHistory | |
from langchain_core.messages import BaseMessage, message_to_dict, messages_from_dict | |
from psycopg import sql | |
logger = logging.getLogger(__name__) | |
def _create_table_and_index(table_name: str) -> List[sql.Composed]:
    """Build the DDL statements that set up the chat-history table.

    Args:
        table_name: Name of the table to create. It is quoted as an SQL
            identifier instead of being spliced into the statement as text.

    Returns:
        Two composed statements in execution order: a ``CREATE TABLE IF NOT
        EXISTS`` for the table itself, followed by a ``CREATE INDEX IF NOT
        EXISTS`` on ``session_id`` named ``idx_<table_name>_session_id``.
    """
    table_ident = sql.Identifier(table_name)
    index_ident = sql.Identifier(f"idx_{table_name}_session_id")

    create_table = sql.SQL(
        """
        CREATE TABLE IF NOT EXISTS {table_name} (
            id SERIAL PRIMARY KEY,
            username VARCHAR(255) NOT NULL, -- Add the username field
            session_id UUID NOT NULL,
            message JSONB NOT NULL,
            created_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
        );
        """
    ).format(table_name=table_ident)

    create_index = sql.SQL(
        """
        CREATE INDEX IF NOT EXISTS {index_name} ON {table_name} (session_id);
        """
    ).format(table_name=table_ident, index_name=index_ident)

    return [create_table, create_index]
def _get_messages_query(
    table_name: str, last_messages: Optional[int] = None
) -> sql.Composed:
    """Make a SQL query to get the last N messages for a session and username.

    The statement uses named placeholders, so the caller must supply
    ``session_id``, ``username`` and ``last_messages`` in the parameter
    mapping passed to ``cursor.execute``.

    Args:
        table_name: Name of the table to read from; quoted as an identifier.
        last_messages: Not embedded in the query text -- the limit is bound at
            execution time through the ``%(last_messages)s`` placeholder. The
            argument is optional so that callers which pass only the table
            name do not fail with a ``TypeError``.

    Returns:
        A composed ``SELECT`` returning the ``message`` column, newest row
        first (``ORDER BY id DESC``), limited by the bound ``last_messages``.
    """
    # NOTE: `last_messages` is intentionally unused here; interpolating it
    # into the SQL text would bypass parameter binding.
    return sql.SQL(
        "SELECT message "
        "FROM {table_name} "
        "WHERE session_id = %(session_id)s AND username = %(username)s "
        "ORDER BY id DESC "
        "LIMIT %(last_messages)s;"
    ).format(table_name=sql.Identifier(table_name))
def _delete_by_session_id_query(table_name: str) -> sql.Composed:
    """Build the DELETE statement that removes one user's session history.

    The statement expects the named parameters ``session_id`` and
    ``username`` to be bound when it is executed.

    Args:
        table_name: Name of the table to delete from; quoted as an identifier.

    Returns:
        The composed ``DELETE`` statement.
    """
    template = sql.SQL(
        "DELETE FROM {table_name} WHERE session_id = %(session_id)s AND username = %(username)s;"
    )
    target = sql.Identifier(table_name)
    return template.format(table_name=target)
def _delete_table_query(table_name: str) -> sql.Composed:
    """Build the ``DROP TABLE IF EXISTS`` statement for ``table_name``.

    Args:
        table_name: Name of the table to drop; quoted as an identifier.

    Returns:
        The composed ``DROP TABLE`` statement.
    """
    template = sql.SQL("DROP TABLE IF EXISTS {table_name};")
    target = sql.Identifier(table_name)
    return template.format(table_name=target)
def _insert_message_query(table_name: str) -> sql.Composed:
    """Build the INSERT statement that stores one message row.

    The statement takes three positional parameters, in this order:
    ``username``, ``session_id``, ``message``.

    Args:
        table_name: Name of the table to insert into; quoted as an identifier.

    Returns:
        The composed ``INSERT`` statement.
    """
    template = sql.SQL(
        "INSERT INTO {table_name} (username, session_id, message) VALUES (%s, %s, %s)"
    )
    target = sql.Identifier(table_name)
    return template.format(table_name=target)
class PostgresChatMessageHistory(BaseChatMessageHistory):
    """Chat message history stored in Postgres, scoped by session and username."""

    def __init__(
        self,
        table_name: str,
        session_id: str,
        username: str,
        /,
        *,
        sync_connection: Optional[psycopg.Connection] = None,
        async_connection: Optional[psycopg.AsyncConnection] = None,
    ) -> None:
        """Client for persisting chat message history in a Postgres database,

        This client provides support for both sync and async via psycopg >=3.

        The client can create schema in the database and provides methods to
        add messages, get messages, and clear the chat message history.

        The schema has the following columns:

        - id: A serial primary key.
        - username: The username the message belongs to.
        - session_id: The session ID for the chat message history.
        - message: The JSONB message content.
        - created_at: The timestamp of when the message was created.

        Messages are retrieved for a given session_id and username and are
        returned sorted by the id (which should be increasing monotonically),
        i.e. in the order in which the messages were added to the history.

        The "created_at" column is not returned by the interface, but
        has been added for the schema so the information is available in the
        database.

        A session_id can be used to separate different chat histories in the
        same table, the session_id should be provided when initializing the
        client.

        This chat history client takes in a psycopg connection object (either
        Connection or AsyncConnection) and uses it to interact with the
        database.

        This design allows to reuse the underlying connection object across
        multiple instantiations of this class, making instantiation fast.

        This chat history client is designed for prototyping applications that
        involve chat and are based on Postgres.

        As your application grows, you will likely need to extend the schema to
        handle more complex queries. For example, a chat application
        may involve multiple tables like a user table, a table for storing
        chat sessions / conversations, and this table for storing chat messages
        for a given session. The application will require access to additional
        endpoints like deleting messages by user id, listing conversations by
        user id or ordering them based on last message time, etc.

        Feel free to adapt this implementation to suit your application's needs.

        Args:
            table_name: The name of the database table to use
            session_id: The session ID to use for the chat message history;
                must be a valid UUID string
            username: The username whose messages are stored and retrieved
            sync_connection: An existing psycopg connection instance
            async_connection: An existing psycopg async connection instance

        Raises:
            ValueError: If neither connection is provided, if ``session_id``
                is not a valid UUID, or if ``table_name`` contains characters
                other than alphanumerics and underscores.

        Usage:
            - Use the create_tables or acreate_tables method to set up the table
              schema in the database.
            - Initialize the class with the appropriate table name, session ID,
              username and database connection.
            - Add messages to the database using add_messages or aadd_messages.
            - Retrieve messages with get_messages or aget_messages.
            - Clear the session history with clear or aclear when needed.

        Note:
            - At least one of sync_connection or async_connection must be
              provided.

        Examples:

        .. code-block:: python

            import uuid

            from langchain_core.messages import SystemMessage, AIMessage, HumanMessage
            from langchain_postgres import PostgresChatMessageHistory
            import psycopg

            # Establish a synchronous connection to the database
            # (or use psycopg.AsyncConnection for async)
            sync_connection = psycopg.connect(conn_info)

            # Create the table schema (only needs to be done once)
            table_name = "chat_history"
            PostgresChatMessageHistory.create_tables(sync_connection, table_name)

            session_id = str(uuid.uuid4())

            # Initialize the chat history manager
            chat_history = PostgresChatMessageHistory(
                table_name,
                session_id,
                "alice",
                sync_connection=sync_connection
            )

            # Add messages to the chat history
            chat_history.add_messages([
                SystemMessage(content="Meow"),
                AIMessage(content="woof"),
                HumanMessage(content="bark"),
            ])

            print(chat_history.messages)
        """
        if not sync_connection and not async_connection:
            raise ValueError("Must provide sync_connection or async_connection")

        self._connection = sync_connection
        self._aconnection = async_connection

        # Validate that session id is a UUID
        try:
            uuid.UUID(session_id)
        except ValueError as err:
            raise ValueError(
                f"Invalid session id. Session id must be a valid UUID. Got {session_id}"
            ) from err

        self._session_id = session_id
        self._username = username

        # fullmatch (rather than match with "$") also rejects a trailing newline.
        if not re.fullmatch(r"\w+", table_name):
            raise ValueError(
                "Invalid table name. Table name must contain only alphanumeric "
                "characters and underscores."
            )
        self._table_name = table_name

    @staticmethod
    def create_tables(
        connection: psycopg.Connection,
        table_name: str,
        /,
    ) -> None:
        """Create the table schema in the database and create relevant indexes.

        Args:
            connection: The database connection.
            table_name: The name of the table to create.
        """
        queries = _create_table_and_index(table_name)
        logger.info("Creating schema for table %s", table_name)
        with connection.cursor() as cursor:
            for query in queries:
                cursor.execute(query)
        connection.commit()

    @staticmethod
    async def acreate_tables(
        connection: psycopg.AsyncConnection, table_name: str, /
    ) -> None:
        """Create the table schema in the database and create relevant indexes.

        Args:
            connection: Async database connection.
            table_name: The name of the table to create.
        """
        queries = _create_table_and_index(table_name)
        logger.info("Creating schema for table %s", table_name)
        async with connection.cursor() as cur:
            for query in queries:
                await cur.execute(query)
        await connection.commit()

    @staticmethod
    def drop_table(connection: psycopg.Connection, table_name: str, /) -> None:
        """Delete the table schema in the database.

        WARNING:
            This will delete the given table from the database including
            all the data in the table and the schema of the table.

        Args:
            connection: The database connection.
            table_name: The name of the table to drop.
        """
        query = _delete_table_query(table_name)
        logger.info("Dropping table %s", table_name)
        with connection.cursor() as cursor:
            cursor.execute(query)
        connection.commit()

    @staticmethod
    async def adrop_table(
        connection: psycopg.AsyncConnection, table_name: str, /
    ) -> None:
        """Delete the table schema in the database.

        WARNING:
            This will delete the given table from the database including
            all the data in the table and the schema of the table.

        Args:
            connection: Async database connection.
            table_name: The name of the table to drop.
        """
        query = _delete_table_query(table_name)
        logger.info("Dropping table %s", table_name)
        async with connection.cursor() as acur:
            await acur.execute(query)
        await connection.commit()

    def add_messages(self, messages: Sequence[BaseMessage]) -> None:
        """Add messages to the chat message history.

        Args:
            messages: Messages to store for this session and username.

        Raises:
            ValueError: If the client was created without a sync connection.
        """
        if self._connection is None:
            raise ValueError(
                "Please initialize the PostgresChatMessageHistory "
                "with a sync connection or use the aadd_messages method instead."
            )

        # Tuple order must match the INSERT column list:
        # (username, session_id, message).
        values = [
            (self._username, self._session_id, json.dumps(message_to_dict(message)))
            for message in messages
        ]

        query = _insert_message_query(self._table_name)

        with self._connection.cursor() as cursor:
            cursor.executemany(query, values)
        self._connection.commit()

    async def aadd_messages(self, messages: Sequence[BaseMessage]) -> None:
        """Add messages to the chat message history.

        Args:
            messages: Messages to store for this session and username.

        Raises:
            ValueError: If the client was created without an async connection.
        """
        if self._aconnection is None:
            raise ValueError(
                "Please initialize the PostgresChatMessageHistory "
                "with an async connection or use the sync add_messages method instead."
            )

        # Tuple order must match the INSERT column list:
        # (username, session_id, message).
        values = [
            (self._username, self._session_id, json.dumps(message_to_dict(message)))
            for message in messages
        ]

        query = _insert_message_query(self._table_name)
        async with self._aconnection.cursor() as cursor:
            await cursor.executemany(query, values)
        await self._aconnection.commit()

    def get_messages(self, last_messages: Optional[int] = None) -> List[BaseMessage]:
        """Retrieve messages from the chat message history.

        Args:
            last_messages: Maximum number of most recent messages to return.
                ``None`` is bound as SQL ``NULL``, which PostgreSQL treats as
                "no limit".

        Returns:
            The selected messages, oldest first.

        Raises:
            ValueError: If the client was created without a sync connection.
        """
        if self._connection is None:
            raise ValueError(
                "Please initialize the PostgresChatMessageHistory "
                "with a sync connection or use the async aget_messages method instead."
            )

        query = _get_messages_query(self._table_name, last_messages)
        params = {
            "session_id": self._session_id,
            "username": self._username,
            "last_messages": last_messages,
        }

        with self._connection.cursor() as cursor:
            cursor.execute(query, params)
            # The query returns newest first so LIMIT keeps the latest rows;
            # reverse to hand them back in insertion order.
            items = [record[0] for record in reversed(cursor.fetchall())]

        messages = messages_from_dict(items)
        return messages

    async def aget_messages(
        self, last_messages: Optional[int] = None
    ) -> List[BaseMessage]:
        """Retrieve messages from the chat message history.

        Args:
            last_messages: Maximum number of most recent messages to return.
                ``None`` is bound as SQL ``NULL``, which PostgreSQL treats as
                "no limit".

        Returns:
            The selected messages, oldest first.

        Raises:
            ValueError: If the client was created without an async connection.
        """
        if self._aconnection is None:
            raise ValueError(
                "Please initialize the PostgresChatMessageHistory "
                "with an async connection or use the sync get_messages method instead."
            )

        query = _get_messages_query(self._table_name, last_messages)
        params = {
            "session_id": self._session_id,
            "username": self._username,
            "last_messages": last_messages,
        }

        async with self._aconnection.cursor() as cursor:
            await cursor.execute(query, params)
            # Newest-first from the database; reverse into insertion order.
            items = [record[0] for record in reversed(await cursor.fetchall())]

        messages = messages_from_dict(items)
        return messages

    @property  # type: ignore[override]
    def messages(self) -> List[BaseMessage]:
        """The abstraction required a property."""
        return self.get_messages()

    def clear(self) -> None:
        """Clear the chat message history for the GIVEN session and username.

        Raises:
            ValueError: If the client was created without a sync connection.
        """
        if self._connection is None:
            raise ValueError(
                "Please initialize the PostgresChatMessageHistory "
                "with a sync connection or use the async clear method instead."
            )

        query = _delete_by_session_id_query(self._table_name)
        params = {"session_id": self._session_id, "username": self._username}
        with self._connection.cursor() as cursor:
            cursor.execute(query, params)
        self._connection.commit()

    async def aclear(self) -> None:
        """Clear the chat message history for the GIVEN session and username.

        Raises:
            ValueError: If the client was created without an async connection.
        """
        if self._aconnection is None:
            raise ValueError(
                "Please initialize the PostgresChatMessageHistory "
                "with an async connection or use the sync clear method instead."
            )

        query = _delete_by_session_id_query(self._table_name)
        params = {"session_id": self._session_id, "username": self._username}
        async with self._aconnection.cursor() as cursor:
            await cursor.execute(query, params)
        await self._aconnection.commit()