''' Necessary Imports '''
from fastapi import FastAPI, UploadFile, File, HTTPException, Form
from fastapi.middleware.cors import CORSMiddleware
from langchain.text_splitter import RecursiveCharacterTextSplitter
from postgres import PostgresChatMessageHistory
from langchain_community.document_loaders import PyPDFLoader
from langchain_postgres.vectorstores import PGVector
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain.chains import create_retrieval_chain
from langchain.chains.combine_documents import create_stuff_documents_chain
from langchain_google_genai import GoogleGenerativeAIEmbeddings
from typing import Dict
from langchain_openai import ChatOpenAI
from prompt import prompt, system_prompt
import psycopg
import uuid
import os
from custom_message import CustomMessage
from dotenv import load_dotenv
from io import BytesIO
from pypdf import PdfReader
from langchain.docstore.document import Document

# Module-level cache so /query can reuse the vector store built by /upload
vector_store = None

# LOADING ENVIRONMENT VARIABLES
load_dotenv()

# INSTANTIATING THE APP
app = FastAPI()

llm = ChatOpenAI(
    model="gpt-4o",
    temperature=0.2,
    max_tokens=None,
    timeout=None,
    max_retries=1,
)

# ALLOWING CORS
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# INITIALIZING THE EMBEDDING MODEL
embeddings = GoogleGenerativeAIEmbeddings(model="models/embedding-001")

# SPLITTER FOR CHUNKING PDF TEXT BEFORE EMBEDDING
text_splitter = RecursiveCharacterTextSplitter(
    chunk_size=1000,
    chunk_overlap=300,
    length_function=len,
)


@app.get("/")
def greeting():
    return {'response': 'success', 'status code': 200}


# PDF UPLOAD ROUTE
@app.post("/upload")
async def upload_pdf(file: UploadFile = File(...), collection_name: str = Form(...)):
    """
    Upload and process a PDF file, storing its embeddings in the vector database.
""" if not file.filename.endswith('.pdf'): raise HTTPException(status_code=400, detail="Only PDF files are allowed") try: # Read PDF content directly into memory pdf_content = await file.read() pdf_file = BytesIO(pdf_content) pdf_reader = PdfReader(pdf_file) # Extract text from PDF documents = [] for page_num, page in enumerate(pdf_reader.pages): text = page.extract_text() # Create a Document object with metadata doc = Document( page_content=text, metadata={"page": page_num + 1, "source": file.filename} ) documents.append(doc) # Split documents into chunks texts = text_splitter.split_documents(documents) try: global vector_store vector_store = PGVector.from_documents( documents=texts, embedding=embeddings, connection=os.environ['CONNECTION_STRING'], collection_name=collection_name, use_jsonb=True, ) except Exception as e: raise("Error in establishing the connection with DB: {e}") return {"message": "PDF processed successfully", "collection_name": file.filename.replace('.pdf', '')} except Exception as e: raise HTTPException(status_code=500, detail=str(e)) @app.post("/query") async def upload_pdf(query: str = Form(...),collection_name:str = Form(...),username:str = Form(...),table_name:str = Form(...)): try: global vector_store if vector_store == None : vector_store = PGVector( embeddings=embeddings, connection=os.environ['CONNECTION_STRING'], collection_name=collection_name, use_jsonb=True, ) retriever = vector_store.as_retriever(search_type="similarity", search_kwargs={"k": 3}) question_answer_chain = create_stuff_documents_chain(llm, prompt) rag_chain = create_retrieval_chain(retriever, question_answer_chain) response = rag_chain.invoke({"input":query})['answer'] sync_connection = psycopg.connect(os.environ['CONNECTION_STRING']) session_id = str(uuid.uuid4()) chat_history = PostgresChatMessageHistory( table_name, session_id, username, sync_connection=sync_connection ) try: custom_message = CustomMessage(content=f"SYSTEM_PROMPT:{system_prompt}\n\nHUMAN_MESSAGE:{query}\n\nAI_RESPONSE:{response}") chat_history.add_message(custom_message) except Exception as e: print(e) print("Ended") return { "relevant docs":response, "session_id":session_id } except Exception as e: raise HTTPException(status_code=500, detail=str(e)) # if __name__ == "__main__": # import uvicorn # uvicorn.run(app, host="0.0.0.0", port=8000)