import logging
from langchain_huggingface import HuggingFaceEmbeddings
from langchain_community.vectorstores import FAISS
from langchain.prompts import PromptTemplate
from langchain_together import Together
from langchain.memory import ConversationBufferMemory
from langchain.chains import ConversationalRetrievalChain
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
import os
from dotenv import load_dotenv
import warnings

# Logging configuration
logging.basicConfig(level=logging.DEBUG)
logger = logging.getLogger(__name__)
logger.debug("Starting FastAPI app...")

# Suppress warnings
warnings.filterwarnings("ignore", message="You are using `torch.load` with `weights_only=False`")
warnings.filterwarnings("ignore", message="Tried to instantiate class '__path__._path'")
warnings.filterwarnings("ignore", category=FutureWarning)
warnings.filterwarnings("ignore", category=DeprecationWarning)

# Load environment variables
load_dotenv()
TOGETHER_AI_API = os.getenv("TOGETHER_AI")

if not TOGETHER_AI_API:
    raise ValueError("Environment variable TOGETHER_AI is missing. Please set it in your .env file.")

# Initialize the embedding model
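# nomic-embed-text-v1 ships custom modeling code, so trust_remote_code=True is
# required; pinning a specific revision keeps the downloaded code reproducible.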
embeddings = HuggingFaceEmbeddings(
    model_name="nomic-ai/nomic-embed-text-v1",
    model_kwargs={"trust_remote_code": True, "revision": "289f532e14dbbbd5a04753fa58739e9ba766f3c7"},
)

# Ensure FAISS vectorstore is loaded properly
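# allow_dangerous_deserialization=True is needed because the index metadata is
# stored as a pickle; only enable it for an index you built and trust yourself.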
try:
    db = FAISS.load_local("ipc_vector_db", embeddings, allow_dangerous_deserialization=True)
    # Retrieve the two most similar chunks per query.
    db_retriever = db.as_retriever(search_type="similarity", search_kwargs={"k": 2})
except Exception as e:
    logger.error(f"Error loading FAISS vectorstore: {e}")
    raise RuntimeError("FAISS vectorstore could not be loaded. Ensure the vector database exists.")

# Define the prompt template
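# Mistral-7B-Instruct expects the instruction wrapped in [INST] ... [/INST]
# tags; the model generates its answer after the closing tag.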
prompt_template = """<s>[INST] As a legal chatbot specializing in the Indian Penal Code, provide a concise and accurate answer based on the given context. Avoid unnecessary details or unrelated content. Only respond if the answer can be derived from the provided context; otherwise, say "The information is not available in the provided context."

CONTEXT: {context}

CHAT HISTORY: {chat_history}

QUESTION: {question}

ANSWER: [/INST]
"""
prompt = PromptTemplate(template=prompt_template, input_variables=["context", "question", "chat_history"])

# Initialize the Together API
try:
    llm = Together(
        model="mistralai/Mistral-7B-Instruct-v0.2",
        temperature=0.5,
        max_tokens=1024,
        together_api_key=TOGETHER_AI_API,
    )
except Exception as e:
    logger.error(f"Error initializing Together API: {e}")
    raise RuntimeError("Together API could not be initialized. Check your API key and network connection.")

# Initialize conversational retrieval chain
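# memory_key="chat_history" must match the {chat_history} variable in the
# prompt above; return_messages=True keeps the history as message objects.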
memory = ConversationBufferMemory(memory_key="chat_history", return_messages=True)
qa = ConversationalRetrievalChain.from_llm(
    llm=llm,
    memory=memory,
    retriever=db_retriever,
    combine_docs_chain_kwargs={"prompt": prompt},
)
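
# On each call the chain condenses the question and chat history into a
# standalone query, retrieves the top-k chunks, and fills the prompt above.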

# Initialize FastAPI app
app = FastAPI()

# Define request and response models
class ChatRequest(BaseModel):
    question: str

class ChatResponse(BaseModel):
    answer: str

# Health check endpoint
@app.get("/")
async def root():
    return {"message": "Hello, World!"}

# Chat endpoint
@app.post("/chat", response_model=ChatResponse)
async def chat(request: ChatRequest):
    try:
        # Run the chain with the user question; memory supplies chat_history.
        result = qa.invoke({"question": request.question})
        answer = result.get("answer", "The chatbot could not generate a response.")
        return ChatResponse(answer=answer)
    except Exception as e:
        logger.error(f"Error during chat invocation: {e}")
        raise HTTPException(status_code=500, detail="Internal server error")
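
# Example request once the server is running (assuming this file is main.py and
# is served with `uvicorn main:app`; adjust the module name as needed):
#   curl -X POST http://localhost:8000/chat \
#        -H "Content-Type: application/json" \
#        -d '{"question": "Which IPC section deals with theft?"}'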