Spaces:
No application file
No application file
Upload 8 files
Browse files- .gitignore +3 -0
- DockerFile +16 -0
- app.py +177 -0
- custom_message.py +29 -0
- db.py +43 -0
- postgres.py +379 -0
- prompt.py +15 -0
- requirements.txt +14 -0
.gitignore
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
__pycache__
|
2 |
+
rag_implementation
|
3 |
+
.env
|
DockerFile
ADDED
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Read the doc: https://huggingface.co/docs/hub/spaces-sdks-docker
|
2 |
+
# you will also find guides on how best to write your Dockerfile
|
3 |
+
|
4 |
+
FROM python:3.11
|
5 |
+
|
6 |
+
RUN useradd -m -u 1000 user
|
7 |
+
USER user
|
8 |
+
ENV PATH="/home/user/.local/bin:$PATH"
|
9 |
+
|
10 |
+
WORKDIR /app
|
11 |
+
|
12 |
+
COPY --chown=user ./requirements.txt requirements.txt
|
13 |
+
RUN pip install --no-cache-dir --upgrade -r requirements.txt
|
14 |
+
|
15 |
+
COPY --chown=user . /app
|
16 |
+
CMD ["uvicorn", "app:app", "--host", "0.0.0.0", "--port", "7860"]
|
app.py
ADDED
@@ -0,0 +1,177 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
'''
|
2 |
+
Necessary Imports
|
3 |
+
'''
|
4 |
+
|
5 |
+
from fastapi import FastAPI, UploadFile, File, HTTPException,Form
|
6 |
+
from fastapi.middleware.cors import CORSMiddleware
|
7 |
+
from langchain.text_splitter import RecursiveCharacterTextSplitter
|
8 |
+
from postgres import PostgresChatMessageHistory
|
9 |
+
|
10 |
+
from langchain_community.document_loaders import PyPDFLoader
|
11 |
+
from langchain_postgres.vectorstores import PGVector
|
12 |
+
from langchain_google_genai import ChatGoogleGenerativeAI
|
13 |
+
from langchain.chains import create_retrieval_chain
|
14 |
+
from langchain.chains.combine_documents import create_stuff_documents_chain
|
15 |
+
from langchain_google_genai import GoogleGenerativeAIEmbeddings
|
16 |
+
|
17 |
+
from typing import Dict
|
18 |
+
from langchain_openai import ChatOpenAI
|
19 |
+
from prompt import prompt,system_prompt
|
20 |
+
import psycopg
|
21 |
+
import uuid
|
22 |
+
import os
|
23 |
+
from custom_message import CustomMessage
|
24 |
+
|
25 |
+
from dotenv import load_dotenv
|
26 |
+
import os
|
27 |
+
from io import BytesIO
|
28 |
+
from pypdf import PdfReader
|
29 |
+
from langchain.docstore.document import Document
|
30 |
+
|
31 |
+
vector_store = None
|
32 |
+
|
33 |
+
# LOADING ENVIRONMENT VARIABLES
|
34 |
+
|
35 |
+
load_dotenv()
|
36 |
+
|
37 |
+
# INSTANTIATING THE APP
|
38 |
+
app = FastAPI()
|
39 |
+
llm = ChatOpenAI(model="gpt-4o",
|
40 |
+
temperature=0.2,
|
41 |
+
max_tokens=None,
|
42 |
+
timeout=None,
|
43 |
+
max_retries=1)
|
44 |
+
|
45 |
+
# ALLOWING CORS
|
46 |
+
app.add_middleware(
|
47 |
+
CORSMiddleware,
|
48 |
+
allow_origins=["*"],
|
49 |
+
allow_credentials=True,
|
50 |
+
allow_methods=["*"],
|
51 |
+
allow_headers=["*"],
|
52 |
+
)
|
53 |
+
|
54 |
+
# INITIALIZING THE EMBEDDING MODEL
|
55 |
+
embeddings = GoogleGenerativeAIEmbeddings(model="models/embedding-001")
|
56 |
+
text_splitter = RecursiveCharacterTextSplitter(
|
57 |
+
chunk_size=1000,
|
58 |
+
chunk_overlap=300,
|
59 |
+
length_function=len,
|
60 |
+
)
|
61 |
+
|
62 |
+
|
63 |
+
@app.get("/")
def greeting():
    """Health-check route: return a fixed success payload."""
    payload = {'response': 'success', 'status code': 200}
    return payload
|
66 |
+
|
67 |
+
# PDF UPLOAD ROUTE
|
68 |
+
@app.post("/upload")
async def upload_pdf(file: UploadFile = File(...), collection_name: str = Form(...)):
    """Upload a PDF, split it into chunks and store their embeddings in PGVector.

    Args:
        file: The uploaded file; its name must end with ``.pdf``.
        collection_name: PGVector collection the chunks are written to.

    Returns:
        A dict with a success message and the collection the chunks were
        stored in.

    Raises:
        HTTPException: 400 when the upload is not a ``.pdf`` file; 500 when
            reading the PDF or writing to the vector store fails.
    """
    global vector_store

    # ``filename`` can be absent on a malformed multipart upload; reject it
    # with a 400 instead of crashing on ``None.endswith``.
    if not file.filename or not file.filename.endswith('.pdf'):
        raise HTTPException(status_code=400, detail="Only PDF files are allowed")

    try:
        # Read the PDF fully into memory; nothing is written to disk.
        pdf_content = await file.read()
        pdf_reader = PdfReader(BytesIO(pdf_content))

        # One Document per page, tagged with its 1-based page number and the
        # uploaded file name.
        documents = [
            Document(
                page_content=page.extract_text(),
                metadata={"page": page_num, "source": file.filename},
            )
            for page_num, page in enumerate(pdf_reader.pages, start=1)
        ]

        # Split the pages into overlapping chunks before embedding.
        texts = text_splitter.split_documents(documents)

        try:
            vector_store = PGVector.from_documents(
                documents=texts,
                embedding=embeddings,
                connection=os.environ['CONNECTION_STRING'],
                collection_name=collection_name,
                use_jsonb=True,
            )
        except Exception as e:
            # Raise a real exception carrying the formatted cause; the outer
            # handler turns it into a 500 response.
            raise RuntimeError(
                f"Error in establishing the connection with DB: {e}"
            ) from e

        # Report the collection that was actually written, so the client can
        # pass the same value to the /query route.
        return {"message": "PDF processed successfully", "collection_name": collection_name}

    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
|
119 |
+
|
120 |
+
|
121 |
+
@app.post("/query")
async def upload_pdf(query: str = Form(...), collection_name: str = Form(...), username: str = Form(...), table_name: str = Form(...)):
    """Answer ``query`` from a stored collection and record the exchange.

    Args:
        query: The user's question.
        collection_name: PGVector collection to retrieve context from.
        username: User the chat-history row is stored under.
        table_name: Chat-history table the exchange is written to.

    Returns:
        A dict with the model's answer under ``"relevant docs"`` and the newly
        generated ``"session_id"``.

    Raises:
        HTTPException: 500 when retrieval, generation or the database
            connection fails.
    """
    # NOTE(review): this handler has the same function name as the /upload
    # handler and shadows it at module level. Both routes stay registered
    # because the decorators run at definition time; the name is kept here so
    # the module's interface does not change.
    global vector_store
    import logging

    try:
        # The store is cached in a module global. Rebuild it when it is
        # missing or bound to a different collection, otherwise a query would
        # silently search whichever collection was used last.
        if vector_store is None or getattr(vector_store, "collection_name", None) != collection_name:
            vector_store = PGVector(
                embeddings=embeddings,
                connection=os.environ['CONNECTION_STRING'],
                collection_name=collection_name,
                use_jsonb=True,
            )

        retriever = vector_store.as_retriever(search_type="similarity", search_kwargs={"k": 3})
        question_answer_chain = create_stuff_documents_chain(llm, prompt)
        rag_chain = create_retrieval_chain(retriever, question_answer_chain)

        response = rag_chain.invoke({"input": query})['answer']

        # Every request is stored under a fresh session id.
        session_id = str(uuid.uuid4())

        sync_connection = psycopg.connect(os.environ['CONNECTION_STRING'])
        try:
            chat_history = PostgresChatMessageHistory(
                table_name,
                session_id,
                username,
                sync_connection=sync_connection,
            )

            # Persisting the exchange is best-effort: a storage failure is
            # logged but must not prevent the answer from being returned.
            try:
                custom_message = CustomMessage(content=f"SYSTEM_PROMPT:{system_prompt}\n\nHUMAN_MESSAGE:{query}\n\nAI_RESPONSE:{response}")
                chat_history.add_message(custom_message)
            except Exception:
                logging.getLogger(__name__).exception(
                    "Failed to store chat history in table %s", table_name
                )
        finally:
            # Always release the connection opened for this request.
            sync_connection.close()

        return {
            "relevant docs": response,
            "session_id": session_id,
        }
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
|
170 |
+
|
171 |
+
|
172 |
+
|
173 |
+
|
174 |
+
# if __name__ == "__main__":
|
175 |
+
# import uvicorn
|
176 |
+
# uvicorn.run(app, host="0.0.0.0", port=8000)
|
177 |
+
|
custom_message.py
ADDED
@@ -0,0 +1,29 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from langchain_core.messages import BaseMessage
|
2 |
+
from typing import Any, Literal, Union
|
3 |
+
class CustomMessage(BaseMessage):
    """Message whose ``type`` discriminator is ``"custom"``.

    The class adds nothing to ``BaseMessage`` except the fixed ``type`` value
    and a constructor that accepts ``content`` positionally.

    NOTE(review): rows stored with this type are serialized with
    ``"type": "custom"``; confirm that whatever reads them back (for example
    ``messages_from_dict``) recognises that type.
    """

    type: Literal["custom"] = "custom"
    """The type of the message (used for deserialization). Defaults to "custom"."""

    def __init__(
        self, content: Union[str, list[Union[str, dict]]], **kwargs: Any
    ) -> None:
        """Pass in content as positional arg.

        Args:
            content: The content of the message.
            kwargs: Additional arguments to pass to the parent class.
        """
        super().__init__(content=content, **kwargs)


CustomMessage.model_rebuild()
|
29 |
+
|
db.py
ADDED
@@ -0,0 +1,43 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from langchain_community.vectorstores.pgvector import PGVector
|
2 |
+
import os
|
3 |
+
# try:
|
4 |
+
# CONNECTION_STRING = PGVector.connection_string_from_db_params(
|
5 |
+
# driver="psycopg2",
|
6 |
+
# host=os.getenv("POSTGRES_HOST"),
|
7 |
+
# port=int(os.getenv("POSTGRES_PORT", 5432)),
|
8 |
+
# database=os.getenv("POSTGRES_DB"),
|
9 |
+
# user=os.getenv("POSTGRES_USER"),
|
10 |
+
# password=os.getenv("POSTGRES_PASSWORD"),
|
11 |
+
# )
|
12 |
+
# print("Successfully established the connection")
|
13 |
+
# except Exception as e:
|
14 |
+
# print("Error in establishing the connection with DB: {e}")
|
15 |
+
print("Entered here")
|
16 |
+
from dotenv import load_dotenv
|
17 |
+
load_dotenv()
|
18 |
+
from langchain_google_genai import GoogleGenerativeAIEmbeddings
|
19 |
+
embeddings = GoogleGenerativeAIEmbeddings(model="models/embedding-001")
|
20 |
+
from sqlalchemy import create_engine
|
21 |
+
from langchain_postgres.vectorstores import PGVector
|
22 |
+
|
23 |
+
from langchain_core.documents import Document
|
24 |
+
document_1 = Document(page_content="fddsdfoo", metadata={"baz": "bar"})
|
25 |
+
document_2 = Document(page_content="thufeed", metadata={"bar": "baz"})
|
26 |
+
document_3 = Document(page_content="i wefsill be deleted :(")
|
27 |
+
|
28 |
+
documents = [document_1, document_2, document_3]
|
29 |
+
ids = ["1", "2", "3"]
|
30 |
+
engine = create_engine(os.environ['CONNECTION_STRING'])
|
31 |
+
vector_store = PGVector.from_documents(
|
32 |
+
documents=documents,
|
33 |
+
embedding=embeddings,
|
34 |
+
connection=os.environ['CONNECTION_STRING'],
|
35 |
+
collection_name="collection_name",
|
36 |
+
use_jsonb=True,
|
37 |
+
|
38 |
+
|
39 |
+
)
|
40 |
+
# vector_store.add_documents(documents=documents, ids=ids)
|
41 |
+
|
42 |
+
|
43 |
+
print("Stored babe")
|
postgres.py
ADDED
@@ -0,0 +1,379 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""Client for persisting chat message history in a Postgres database.
|
2 |
+
|
3 |
+
This client provides support for both sync and async via psycopg 3.
|
4 |
+
"""
|
5 |
+
from __future__ import annotations
|
6 |
+
|
7 |
+
import json
|
8 |
+
import logging
|
9 |
+
import re
|
10 |
+
import uuid
|
11 |
+
from typing import List, Optional, Sequence
|
12 |
+
|
13 |
+
import psycopg
|
14 |
+
from langchain_core.chat_history import BaseChatMessageHistory
|
15 |
+
from langchain_core.messages import BaseMessage, message_to_dict, messages_from_dict
|
16 |
+
from psycopg import sql
|
17 |
+
|
18 |
+
logger = logging.getLogger(__name__)
|
19 |
+
|
20 |
+
|
21 |
+
def _create_table_and_index(table_name: str) -> List[sql.Composed]:
    """Build the statements that create the history table and its index.

    The result holds two composed statements in execution order: the
    ``CREATE TABLE IF NOT EXISTS`` for ``table_name``, then a
    ``CREATE INDEX IF NOT EXISTS`` on its ``session_id`` column.
    """
    table = sql.Identifier(table_name)
    index = sql.Identifier(f"idx_{table_name}_session_id")

    create_table = sql.SQL(
        """
        CREATE TABLE IF NOT EXISTS {table_name} (
            id SERIAL PRIMARY KEY,
            username VARCHAR(255) NOT NULL,  -- Add the username field
            session_id UUID NOT NULL,
            message JSONB NOT NULL,
            created_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
        );

        """
    ).format(table_name=table)

    create_index = sql.SQL(
        """
        CREATE INDEX IF NOT EXISTS {index_name} ON {table_name} (session_id);
        """
    ).format(table_name=table, index_name=index)

    return [create_table, create_index]
|
46 |
+
|
47 |
+
|
48 |
+
def _get_messages_query(table_name: str, last_messages: Optional[int] = None) -> sql.Composed:
    """Make a SQL query to get the last N messages for a given session and username.

    Args:
        table_name: Table to read from; quoted as an SQL identifier.
        last_messages: Accepted for call-site compatibility only. The limit is
            not interpolated into the SQL text; the caller binds it at execute
            time, so it is optional here.

    Returns:
        A composed ``SELECT`` with the named placeholders ``session_id``,
        ``username`` and ``last_messages``. Rows are ordered by ``id``
        descending, i.e. newest first.
    """
    return sql.SQL(
        "SELECT message "
        "FROM {table_name} "
        "WHERE session_id = %(session_id)s AND username = %(username)s "
        "ORDER BY id DESC "
        "LIMIT %(last_messages)s;"
    ).format(table_name=sql.Identifier(table_name))
|
57 |
+
|
58 |
+
|
59 |
+
def _delete_by_session_id_query(table_name: str) -> sql.Composed:
    """Build the DELETE that removes one session's rows for one username.

    The statement uses the named placeholders ``session_id`` and ``username``.
    """
    template = sql.SQL(
        "DELETE FROM {table_name} WHERE session_id = %(session_id)s AND username = %(username)s;"
    )
    table = sql.Identifier(table_name)
    return template.format(table_name=table)
|
64 |
+
|
65 |
+
|
66 |
+
def _delete_table_query(table_name: str) -> sql.Composed:
    """Build the ``DROP TABLE IF EXISTS`` statement for ``table_name``."""
    template = sql.SQL("DROP TABLE IF EXISTS {table_name};")
    table = sql.Identifier(table_name)
    return template.format(table_name=table)
|
71 |
+
|
72 |
+
|
73 |
+
|
74 |
+
def _insert_message_query(table_name: str) -> sql.Composed:
    """Build the INSERT for one message row.

    The three positional placeholders are, in order: username, session_id,
    message.
    """
    template = sql.SQL(
        "INSERT INTO {table_name} (username, session_id, message) VALUES (%s, %s, %s)"
    )
    table = sql.Identifier(table_name)
    return template.format(table_name=table)
|
79 |
+
|
80 |
+
|
81 |
+
|
82 |
+
class PostgresChatMessageHistory(BaseChatMessageHistory):
    """Chat message history stored in a Postgres table, scoped by session and username."""

    def __init__(
        self,
        table_name: str,
        session_id: str,
        username: str,
        /,
        *,
        sync_connection: Optional[psycopg.Connection] = None,
        async_connection: Optional[psycopg.AsyncConnection] = None,
    ) -> None:
        """Client for persisting chat message history in a Postgres database.

        This client provides support for both sync and async via psycopg >=3.

        The client can create schema in the database and provides methods to
        add messages, get messages, and clear the chat message history.

        The schema has the following columns:

        - id: A serial primary key.
        - username: The username the message belongs to.
        - session_id: The session ID for the chat message history.
        - message: The JSONB message content.
        - created_at: The timestamp of when the message was created.

        Messages are retrieved for a given session_id and username and are
        returned in the order in which they were added (ascending id).

        The "created_at" column is not returned by the interface, but
        has been added for the schema so the information is available in the
        database.

        A session_id can be used to separate different chat histories in the
        same table; the session_id should be provided when initializing the
        client.

        This chat history client takes in a psycopg connection object (either
        Connection or AsyncConnection) and uses it to interact with the
        database. This design allows reusing the underlying connection object
        across multiple instantiations of this class, making instantiation
        fast.

        This chat history client is designed for prototyping applications that
        involve chat and are based on Postgres. As your application grows, you
        will likely need to extend the schema to handle more complex queries
        (a user table, a conversations table, deleting messages by user id,
        listing conversations by last message time, etc.).

        Args:
            table_name: The name of the database table to use. Must contain
                only alphanumeric characters and underscores.
            session_id: The session ID to use for the chat message history.
                Must be a valid UUID string.
            username: The username the messages are stored under.
            sync_connection: An existing psycopg connection instance.
            async_connection: An existing psycopg async connection instance.

        Raises:
            ValueError: If neither connection is given, if ``session_id`` is
                not a valid UUID, or if ``table_name`` contains characters
                other than alphanumerics and underscores.

        Usage:
            - Use the create_tables or acreate_tables method to set up the
              table schema in the database.
            - Initialize the class with the table name, session ID, username
              and database connection.
            - Add messages to the database using add_messages or
              aadd_messages.
            - Retrieve messages with get_messages or aget_messages.
            - Clear the session history with clear or aclear when needed.

        Note:
            - At least one of sync_connection or async_connection must be
              provided.

        Examples:

            .. code-block:: python

                import uuid

                from langchain_core.messages import SystemMessage, AIMessage, HumanMessage
                from postgres import PostgresChatMessageHistory
                import psycopg

                # Establish a synchronous connection to the database
                # (or use psycopg.AsyncConnection for async)
                sync_connection = psycopg.connect(conn_info)

                # Create the table schema (only needs to be done once)
                table_name = "chat_history"
                PostgresChatMessageHistory.create_tables(sync_connection, table_name)

                session_id = str(uuid.uuid4())

                # Initialize the chat history manager
                chat_history = PostgresChatMessageHistory(
                    table_name,
                    session_id,
                    "alice",
                    sync_connection=sync_connection
                )

                # Add messages to the chat history
                chat_history.add_messages([
                    SystemMessage(content="Meow"),
                    AIMessage(content="woof"),
                    HumanMessage(content="bark"),
                ])

                print(chat_history.messages)
        """
        if not sync_connection and not async_connection:
            raise ValueError("Must provide sync_connection or async_connection")

        self._connection = sync_connection
        self._aconnection = async_connection

        # Validate that session id is a UUID
        try:
            uuid.UUID(session_id)
        except ValueError:
            raise ValueError(
                f"Invalid session id. Session id must be a valid UUID. Got {session_id}"
            )

        self._session_id = session_id
        self._username = username
        if not re.match(r"^\w+$", table_name):
            raise ValueError(
                "Invalid table name. Table name must contain only alphanumeric "
                "characters and underscores."
            )
        self._table_name = table_name

    def _scope_params(self) -> dict:
        """Return the named parameters that scope a query to this session and user."""
        return {"session_id": self._session_id, "username": self._username}

    def _insert_rows(self, messages: Sequence[BaseMessage]) -> list:
        """Serialize messages into rows matching the INSERT column order.

        The INSERT statement lists its columns as (username, session_id,
        message), so every row must bind its values in exactly that order.
        """
        return [
            (self._username, self._session_id, json.dumps(message_to_dict(message)))
            for message in messages
        ]

    @staticmethod
    def create_tables(
        connection: psycopg.Connection,
        table_name: str,
        /,
    ) -> None:
        """Create the table schema in the database and create relevant indexes."""
        queries = _create_table_and_index(table_name)
        logger.info("Creating schema for table %s", table_name)
        with connection.cursor() as cursor:
            for query in queries:
                cursor.execute(query)
        connection.commit()

    @staticmethod
    async def acreate_tables(
        connection: psycopg.AsyncConnection, table_name: str, /
    ) -> None:
        """Create the table schema in the database and create relevant indexes."""
        queries = _create_table_and_index(table_name)
        logger.info("Creating schema for table %s", table_name)
        async with connection.cursor() as cur:
            for query in queries:
                await cur.execute(query)
        await connection.commit()

    @staticmethod
    def drop_table(connection: psycopg.Connection, table_name: str, /) -> None:
        """Delete the table schema in the database.

        WARNING:
            This will delete the given table from the database including
            all the data in the table and the schema of the table.

        Args:
            connection: The database connection.
            table_name: The name of the table to drop.
        """
        query = _delete_table_query(table_name)
        logger.info("Dropping table %s", table_name)
        with connection.cursor() as cursor:
            cursor.execute(query)
        connection.commit()

    @staticmethod
    async def adrop_table(
        connection: psycopg.AsyncConnection, table_name: str, /
    ) -> None:
        """Delete the table schema in the database.

        WARNING:
            This will delete the given table from the database including
            all the data in the table and the schema of the table.

        Args:
            connection: Async database connection.
            table_name: The name of the table to drop.
        """
        query = _delete_table_query(table_name)
        logger.info("Dropping table %s", table_name)

        async with connection.cursor() as acur:
            await acur.execute(query)
        await connection.commit()

    def add_messages(self, messages: Sequence[BaseMessage]) -> None:
        """Add messages to the chat message history."""
        if self._connection is None:
            raise ValueError(
                "Please initialize the PostgresChatMessageHistory "
                "with a sync connection or use the aadd_messages method instead."
            )

        values = self._insert_rows(messages)
        query = _insert_message_query(self._table_name)
        with self._connection.cursor() as cursor:
            cursor.executemany(query, values)
        self._connection.commit()

    async def aadd_messages(self, messages: Sequence[BaseMessage]) -> None:
        """Add messages to the chat message history."""
        if self._aconnection is None:
            raise ValueError(
                "Please initialize the PostgresChatMessageHistory "
                "with an async connection or use the sync add_messages method instead."
            )

        # Same row layout as the sync path: (username, session_id, message).
        values = self._insert_rows(messages)
        query = _insert_message_query(self._table_name)
        async with self._aconnection.cursor() as cursor:
            await cursor.executemany(query, values)
        await self._aconnection.commit()

    def get_messages(self, last_messages: Optional[int] = None) -> List[BaseMessage]:
        """Retrieve messages from the chat message history.

        Args:
            last_messages: Maximum number of most recent messages to return.
                ``None`` binds a NULL limit, which PostgreSQL treats as
                "no limit".

        Returns:
            The selected messages in the order they were added (oldest first).
        """
        if self._connection is None:
            raise ValueError(
                "Please initialize the PostgresChatMessageHistory "
                "with a sync connection or use the async aget_messages method instead."
            )

        query = _get_messages_query(self._table_name, last_messages=last_messages)
        # The query has three named placeholders; all of them must be bound.
        params = {**self._scope_params(), "last_messages": last_messages}

        with self._connection.cursor() as cursor:
            cursor.execute(query, params)
            items = [record[0] for record in cursor.fetchall()]

        # The query selects newest-first so LIMIT keeps the latest rows;
        # flip the result back to insertion order.
        items.reverse()
        return messages_from_dict(items)

    async def aget_messages(self, last_messages: Optional[int] = None) -> List[BaseMessage]:
        """Retrieve messages from the chat message history.

        Args:
            last_messages: Maximum number of most recent messages to return.
                ``None`` binds a NULL limit, which PostgreSQL treats as
                "no limit".

        Returns:
            The selected messages in the order they were added (oldest first).
        """
        if self._aconnection is None:
            raise ValueError(
                "Please initialize the PostgresChatMessageHistory "
                "with an async connection or use the sync get_messages method instead."
            )

        query = _get_messages_query(self._table_name, last_messages=last_messages)
        params = {**self._scope_params(), "last_messages": last_messages}

        async with self._aconnection.cursor() as cursor:
            await cursor.execute(query, params)
            items = [record[0] for record in await cursor.fetchall()]

        # Newest-first from the query; return in insertion order.
        items.reverse()
        return messages_from_dict(items)

    @property  # type: ignore[override]
    def messages(self) -> List[BaseMessage]:
        """The abstraction required a property; returns the full history."""
        return self.get_messages()

    def clear(self) -> None:
        """Clear the chat message history for the GIVEN session and username."""
        if self._connection is None:
            raise ValueError(
                "Please initialize the PostgresChatMessageHistory "
                "with a sync connection or use the async clear method instead."
            )

        query = _delete_by_session_id_query(self._table_name)
        with self._connection.cursor() as cursor:
            # The DELETE filters on both session_id and username.
            cursor.execute(query, self._scope_params())
        self._connection.commit()

    async def aclear(self) -> None:
        """Clear the chat message history for the GIVEN session and username."""
        if self._aconnection is None:
            raise ValueError(
                "Please initialize the PostgresChatMessageHistory "
                "with an async connection or use the sync clear method instead."
            )

        query = _delete_by_session_id_query(self._table_name)
        async with self._aconnection.cursor() as cursor:
            # The DELETE filters on both session_id and username.
            await cursor.execute(query, self._scope_params())
        await self._aconnection.commit()
|
prompt.py
ADDED
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from langchain_core.prompts import ChatPromptTemplate
|
2 |
+
# System instructions for the chat prompt. The adjacent literals are joined
# into one string, so each piece carries its own trailing separator;
# ``{context}`` and ``{input}`` are template variables filled in at run time.
system_prompt = (
    "You are a helpful assistant. "
    "If you don't know the answer, just say you don't know the answer."
    "\n\n"
    "Only use the following context to answer the question:\n"
    "{context}"
)

# Two-message chat prompt: the system instructions, then the user's input.
prompt = ChatPromptTemplate.from_messages(
    [
        ("system", system_prompt),
        ("human", "{input}"),
    ]
)
|
requirements.txt
ADDED
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
langchain-google-genai==2.0.1
|
2 |
+
fastapi==0.104.1
|
3 |
+
python-multipart==0.0.6
|
4 |
+
langchain
|
5 |
+
pypdf==3.17.1
|
6 |
+
psycopg2-binary==2.9.9
|
7 |
+
pgvector==0.2.3
|
8 |
+
python-dotenv==1.0.0
|
9 |
+
uvicorn==0.24.0
|
10 |
+
langchain-postgres
|
11 |
+
langchain-openai
|
12 |
+
psycopg
|
13 |
+
psycopg2
|
14 |
+
psycopg[binary,pool]
|