Spaces:
Running
Running
File size: 4,116 Bytes
5f4a7cf |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 |
import logging
from langchain_community.document_loaders import PyPDFLoader, DirectoryLoader
from langchain_huggingface import HuggingFaceEmbeddings
from sentence_transformers import SentenceTransformer
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.vectorstores import FAISS
from langchain.prompts import PromptTemplate
from langchain_together import Together
from langchain.memory import ConversationBufferMemory
from langchain.chains import ConversationalRetrievalChain
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
import os
from dotenv import load_dotenv
import warnings
# Logging configuration: request DEBUG-level root logging and create the
# module-level logger used throughout this file.
logging.basicConfig(level=logging.DEBUG)
logger = logging.getLogger(__name__)
logger.debug("Starting FastAPI app...")

# Suppress warnings: two specific messages first, then two whole categories
# (registered in the same order as before).
for _ignored_message in (
    "You are using `torch.load` with `weights_only=False`",
    "Tried to instantiate class '__path__._path'",
):
    warnings.filterwarnings("ignore", message=_ignored_message)
for _ignored_category in (FutureWarning, DeprecationWarning):
    warnings.filterwarnings("ignore", category=_ignored_category)
# Load environment variables (python-dotenv) before reading the API key.
load_dotenv()
# The Together AI key is read from the TOGETHER_AI environment variable.
TOGETHER_AI_API = os.getenv("TOGETHER_AI")
if not TOGETHER_AI_API:
    # Name the variable that is actually read above so the message tells the
    # operator exactly what to set.
    raise ValueError("Environment variable TOGETHER_AI is missing. Please set it in your .env file.")
# Initialize the embedding model that is later passed to the FAISS loader.
# The model revision is pinned to a fixed commit hash and
# trust_remote_code is enabled for it.
_embedding_model_kwargs = {
    "trust_remote_code": True,
    "revision": "289f532e14dbbbd5a04753fa58739e9ba766f3c7",
}
embeddings = HuggingFaceEmbeddings(
    model_kwargs=_embedding_model_kwargs,
    model_name="nomic-ai/nomic-embed-text-v1",
)
# Load the prebuilt FAISS index from the local "ipc_vector_db" path and wrap
# it in a similarity retriever returning the top 2 matches. Any failure is
# logged with its traceback and aborts startup with a RuntimeError.
try:
    # NOTE(review): allow_dangerous_deserialization=True opts in to loading
    # pickled data from disk; only load index files from a trusted source.
    db = FAISS.load_local("ipc_vector_db", embeddings, allow_dangerous_deserialization=True)
    # NOTE(review): "max_length" does not look like a FAISS similarity-search
    # argument and is presumably ignored — confirm before relying on it.
    db_retriever = db.as_retriever(search_type="similarity", search_kwargs={"k": 2, "max_length": 512})
except Exception as e:
    logger.exception("Error loading FAISS vectorstore: %s", e)
    # Chain the original exception so the root cause stays in the traceback.
    raise RuntimeError("FAISS vectorstore could not be loaded. Ensure the vector database exists.") from e
# Define the prompt template.
# The placeholders {context}, {chat_history} and {question} are declared as
# the template's input variables in the PromptTemplate call below.
# NOTE(review): the template opens with "<s>[INST]" but ends with
# "</s>[INST]" instead of a closing "[/INST]" tag — confirm against the
# instruct prompt format expected by the configured model before changing it.
prompt_template = """<s>[INST]As a legal chatbot specializing in the Indian Penal Code, provide a concise and accurate answer based on the given context. Avoid unnecessary details or unrelated content. Only respond if the answer can be derived from the provided context; otherwise, say "The information is not available in the provided context."
CONTEXT: {context}
CHAT HISTORY: {chat_history}
QUESTION: {question}
ANSWER:
</s>[INST]
"""
prompt = PromptTemplate(template=prompt_template, input_variables=["context", "question", "chat_history"])
# Initialize the Together LLM client with the API key read from the
# environment. Any failure is logged with its traceback and aborts startup
# with a RuntimeError.
try:
    llm = Together(
        model="mistralai/Mistral-7B-Instruct-v0.2",
        temperature=0.5,
        max_tokens=1024,
        together_api_key=TOGETHER_AI_API,
    )
except Exception as e:
    logger.exception("Error initializing Together API: %s", e)
    # Chain the original exception so the root cause stays in the traceback.
    raise RuntimeError("Together API could not be initialized. Check your API key and network connection.") from e
# Initialize conversational retrieval chain.
# NOTE: `memory` is one module-level buffer handed to the one `qa` chain, so
# every call that goes through `qa` reads and writes the same chat history.
memory = ConversationBufferMemory(return_messages=True, memory_key="chat_history")
_combine_docs_kwargs = {"prompt": prompt}
qa = ConversationalRetrievalChain.from_llm(
    llm=llm,
    retriever=db_retriever,
    memory=memory,
    combine_docs_chain_kwargs=_combine_docs_kwargs,
)
# Initialize the FastAPI app; the route decorators further down this file
# (@app.get, @app.post) register their handlers on this instance.
app = FastAPI()
# Request model: JSON body accepted by the POST /chat endpoint.
class ChatRequest(BaseModel):
    # The user's question; passed as the input to the retrieval chain.
    question: str
# Response model: JSON body returned by the POST /chat endpoint.
class ChatResponse(BaseModel):
    # The chain's answer text, or a fallback message when none was produced.
    answer: str
# Health check endpoint: GET / answers with a fixed greeting payload.
@app.get("/")
async def root():
    greeting = {"message": "Hello, World!"}
    return greeting
# Chat endpoint
@app.post("/chat", response_model=ChatResponse)
async def chat(request: ChatRequest):
    """Answer a question using the conversational retrieval chain.

    Args:
        request: Request body carrying the user's ``question``.

    Returns:
        A ``ChatResponse`` holding the chain's ``answer``, or a fallback
        message when the chain result has no ``answer`` key.

    Raises:
        HTTPException: Status 500 when anything in the chain call fails.
    """
    try:
        # Await the chain's async entry point instead of calling the blocking
        # `invoke` inside this coroutine; a blocking call here would stall the
        # event loop (including the health check) for the whole LLM round trip.
        result = await qa.ainvoke({"question": request.question})
        answer = result.get("answer", "The chatbot could not generate a response.")
        return ChatResponse(answer=answer)
    except Exception as e:
        # logger.exception keeps the traceback; the client only sees a
        # generic 500 so internal details are not leaked.
        logger.exception("Error during chat invocation: %s", e)
        raise HTTPException(status_code=500, detail="Internal server error") from e
|