"""FastAPI application exposing a RAG chain and an agent over HTTP.

Defines two POST endpoints (``/rag`` and ``/agent``) and mounts the Gradio
interface built by ``create_gradio_interface`` at ``/``.

Environment variables read by this module (after ``load_dotenv()``):

* ``SOURCE_DATA``   - passed to ``load_pdf`` when a new vector store is built.
* ``VECTOR_STORE``  - directory checked for an existing vector store.
* ``UVICORN_HOST``  - host passed to ``uvicorn.run`` when run as a script.
* ``UVICORN_PORT``  - port passed to ``uvicorn.run``; defaults to 8000.
"""

import os

import gradio as gr
from dotenv import load_dotenv
from fastapi import FastAPI
from pydantic import BaseModel

from utils.agent import get_agent_response, init_agent
from utils.document_loader import create_unique_ids, load_pdf
from utils.embeddings import get_embeddings
from utils.gradio_interface import create_gradio_interface
from utils.rag_chain import (
    create_rag_chain,
    get_conversational_rag_chain,
    get_model,
)
from utils.vector_store import (
    create_vector_store,
    get_retriever,
    load_vector_store,
)

load_dotenv()

app = FastAPI()


class QuestionRequest(BaseModel):
    """Request body shared by the ``/rag`` and ``/agent`` endpoints."""

    question: str


class AnswerResponse(BaseModel):
    """Response body shared by the ``/rag`` and ``/agent`` endpoints."""

    answer: str


def init_rag_system():
    """Build the conversational RAG chain.

    If the directory named by ``VECTOR_STORE`` exists and is non-empty, the
    store is loaded with ``load_vector_store``; otherwise the PDF named by
    ``SOURCE_DATA`` is loaded and a new store is created from it.

    Returns:
        The value returned by ``get_conversational_rag_chain``.

    Raises:
        RuntimeError: If the ``VECTOR_STORE`` environment variable is not set.
    """
    vector_store_path = os.getenv("VECTOR_STORE")
    if vector_store_path is None:
        # Without this guard os.path would fail with an opaque TypeError.
        raise RuntimeError("VECTOR_STORE environment variable is not set")

    # Load embeddings
    embeddings = get_embeddings()

    # isdir (rather than exists) keeps os.listdir from raising when the path
    # points at a regular file; such a path is treated as "no store yet".
    if os.path.isdir(vector_store_path) and os.listdir(vector_store_path):
        print("Loading existing vector store...")
        vector_store = load_vector_store(embeddings)
    else:
        print("Creating new vector store...")
        # NOTE(review): SOURCE_DATA may be unset here, in which case None is
        # handed to load_pdf - confirm how load_pdf treats that.
        pdf_path = os.getenv("SOURCE_DATA")
        documents = load_pdf(pdf_path)
        unique_ids = create_unique_ids(documents)
        vector_store = create_vector_store(documents, unique_ids, embeddings)

    retriever = get_retriever(vector_store)
    model = get_model()
    rag_chain = create_rag_chain(model, retriever)
    return get_conversational_rag_chain(rag_chain)


# Initialize conversational RAG chain
conversational_rag_chain = init_rag_system()

# Initialize agent
agent = init_agent()


@app.post("/rag", response_model=AnswerResponse)
async def ask_rag_question(request: QuestionRequest):
    """Answer a question with the conversational RAG chain.

    Every request uses the same ``session_id`` ("default_session").
    """
    print(f"RAG Question: {request.question}")
    # NOTE(review): invoke() is called synchronously inside an async handler;
    # if it blocks, other requests wait on it. Offloading it to a worker
    # thread would need the chain's session handling to be thread-safe -
    # confirm before changing.
    response = conversational_rag_chain.invoke(
        {"input": request.question},
        config={"configurable": {"session_id": "default_session"}},
    )
    return AnswerResponse(answer=response["answer"])


@app.post("/agent", response_model=AnswerResponse)
async def ask_agent_question(request: QuestionRequest):
    """Answer a question by delegating to ``get_agent_response``."""
    print(f"Agent Question: {request.question}")
    # NOTE(review): synchronous call inside an async handler, as in /rag.
    response = get_agent_response(agent, request.question)
    return AnswerResponse(answer=response)


interface = create_gradio_interface(app, conversational_rag_chain, agent)
# Rebinds `app` to the application returned by Gradio's mount helper.
app = gr.mount_gradio_app(app, interface, path="/")


if __name__ == "__main__":
    import uvicorn

    uvicorn.run(
        app,
        host=os.getenv("UVICORN_HOST"),
        # Fall back to uvicorn's default port instead of crashing on int(None).
        port=int(os.getenv("UVICORN_PORT", "8000")),
    )