Update app.py
Browse files
app.py
CHANGED
@@ -2,21 +2,25 @@ from fastapi import APIRouter, HTTPException
|
|
2 |
from pydantic import BaseModel
|
3 |
from pathlib import Path
|
4 |
import os
|
|
|
5 |
from langchain_community.vectorstores import FAISS
|
6 |
from langchain_community.embeddings import HuggingFaceEmbeddings
|
7 |
from langchain.prompts import PromptTemplate
|
8 |
from langchain_together import Together
|
9 |
from langchain.memory import ConversationBufferWindowMemory
|
10 |
from langchain.chains import ConversationalRetrievalChain
|
|
|
|
|
11 |
|
12 |
# Set the API key for Together.ai
|
13 |
-
|
14 |
|
15 |
# Ensure proper cache directory is available for models
|
16 |
os.environ['TRANSFORMERS_CACHE'] = '/tmp/cache'
|
17 |
|
18 |
# Initialize FastAPI Router
|
19 |
-
|
|
|
20 |
|
21 |
# Lazy loading of large models (only load embeddings and index when required)
|
22 |
embeddings = HuggingFaceEmbeddings(
|
@@ -26,7 +30,7 @@ embeddings = HuggingFaceEmbeddings(
|
|
26 |
|
27 |
index_path = Path("models/index.faiss")
|
28 |
if not index_path.exists():
|
29 |
-
raise FileNotFoundError("FAISS index not found. Please generate it and place it in '
|
30 |
|
31 |
# Load the FAISS index
|
32 |
db = FAISS.load_local("models", embeddings, allow_dangerous_deserialization=True)
|
@@ -46,7 +50,7 @@ llm = Together(
|
|
46 |
model="mistralai/Mistral-7B-Instruct-v0.2",
|
47 |
temperature=0.5,
|
48 |
max_tokens=1024,
|
49 |
-
together_api_key=TOGETHER_AI_API,
|
50 |
)
|
51 |
|
52 |
# Set up memory for conversational context
|
@@ -60,25 +64,85 @@ qa_chain = ConversationalRetrievalChain.from_llm(
|
|
60 |
combine_docs_chain_kwargs={"prompt": prompt},
|
61 |
)
|
62 |
|
|
|
|
|
|
|
63 |
# Input schema for chat requests
|
64 |
class ChatRequest(BaseModel):
|
65 |
question: str
|
66 |
chat_history: str
|
67 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
68 |
# POST endpoint to handle chat requests
|
69 |
-
@
|
70 |
async def chat(request: ChatRequest):
|
71 |
try:
|
72 |
-
#
|
73 |
-
|
74 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
75 |
result = qa_chain(inputs)
|
76 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
77 |
except Exception as e:
|
78 |
-
# Return an error if something goes wrong
|
79 |
raise HTTPException(status_code=500, detail=f"Error processing request: {str(e)}")
|
80 |
|
81 |
# GET endpoint to check if the API is running
|
82 |
-
@
|
83 |
async def root():
|
84 |
-
return {"message": "LawGPT API is running."}
|
|
|
2 |
from pydantic import BaseModel
|
3 |
from pathlib import Path
|
4 |
import os
|
5 |
+
import re
|
6 |
from langchain_community.vectorstores import FAISS
|
7 |
from langchain_community.embeddings import HuggingFaceEmbeddings
|
8 |
from langchain.prompts import PromptTemplate
|
9 |
from langchain_together import Together
|
10 |
from langchain.memory import ConversationBufferWindowMemory
|
11 |
from langchain.chains import ConversationalRetrievalChain
|
12 |
+
from langdetect import detect
|
13 |
+
from googletrans import Translator, LANGUAGES
|
14 |
|
15 |
# Set the API key for Together.ai
|
16 |
+
os.environ["TOGETHER_AI_API"] = "1c27fe0df51a29edee1bec6b4b648b436cc80cf4ccc36f56de17272d9e663cbd"
|
17 |
|
18 |
# Ensure proper cache directory is available for models
|
19 |
os.environ['TRANSFORMERS_CACHE'] = '/tmp/cache'
|
20 |
|
21 |
# Initialize FastAPI Router
|
22 |
+
router = APIRouter()
|
23 |
+
bot_name = "LawGPT"
|
24 |
|
25 |
# Lazy loading of large models (only load embeddings and index when required)
|
26 |
embeddings = HuggingFaceEmbeddings(
|
|
|
30 |
|
31 |
# Fail fast if the FAISS index has not been generated yet.
index_path = Path("models") / "index.faiss"
if not index_path.exists():
    raise FileNotFoundError("FAISS index not found. Please generate it and place it in 'models'.")

# Load the FAISS index from the "models" directory.
# NOTE(review): allow_dangerous_deserialization=True trusts the pickled index
# on disk — only safe while the index file is produced by this project itself.
db = FAISS.load_local("models", embeddings, allow_dangerous_deserialization=True)
|
|
|
50 |
model="mistralai/Mistral-7B-Instruct-v0.2",
|
51 |
temperature=0.5,
|
52 |
max_tokens=1024,
|
53 |
+
together_api_key=os.getenv("TOGETHER_AI_API"),
|
54 |
)
|
55 |
|
56 |
# Set up memory for conversational context
|
|
|
64 |
combine_docs_chain_kwargs={"prompt": prompt},
|
65 |
)
|
66 |
|
67 |
+
# Shared Google Translate client used by the chat endpoints.
# NOTE(review): the chat endpoint awaits this client's methods, so this must
# be the async-capable googletrans Translator (4.x) — confirm the installed
# version; awaiting the legacy sync Translator would fail at runtime.
translator = Translator()
|
69 |
+
|
70 |
# Input schema for chat requests
class ChatRequest(BaseModel):
    """Request body for the /chat/ endpoint."""

    # The user's question; the endpoint translates it to English itself.
    question: str
    # Prior conversation turns, passed through verbatim to the chain.
    chat_history: str
|
74 |
|
75 |
+
# Function to validate the input question.
# Patterns are compiled once at import time rather than on every call.
_DIGITS_ONLY_RE = re.compile(r'^\d+$')
_SYMBOLS_ONLY_RE = re.compile(r'^[^a-zA-Z0-9\s]+$')

# Keywords related to Indian law and crime used to screen incoming questions.
_LEGAL_KEYWORDS = [
    "IPC", "CrPC", "section", "law", "crime", "penalty", "punishment",
    "legal", "court", "justice", "offense", "fraud", "murder", "theft",
    "bail", "arrest", "FIR", "judgment", "act", "contract", "constitutional",
    "habeas corpus", "petition", "rights", "lawyer", "advocate", "accused",
    "penal", "conviction", "sentence", "appeal", "trial", "witness",
]

# Whole-word, case-insensitive match so e.g. "law" does not fire on
# "flawless" and "act" does not fire on "react" (the previous substring
# check accepted such questions).
_LEGAL_KEYWORD_RE = re.compile(
    r'\b(?:' + '|'.join(re.escape(keyword) for keyword in _LEGAL_KEYWORDS) + r')\b',
    re.IGNORECASE,
)


def is_valid_question(question: str) -> bool:
    """
    Validate the input question to ensure it is meaningful and related to
    Indian law or crime.

    Returns True only when the stripped question is at least 3 characters
    long, is not composed solely of digits or solely of symbols, and
    contains at least one legal keyword as a whole word (case-insensitive).
    """
    question = question.strip()

    # Reject if the question is too short to be meaningful.
    if len(question) < 3:
        return False

    # Reject if the question contains only numbers or only symbols.
    if _DIGITS_ONLY_RE.match(question) or _SYMBOLS_ONLY_RE.match(question):
        return False

    # Require at least one whole-word legal keyword.
    return bool(_LEGAL_KEYWORD_RE.search(question))
|
104 |
+
|
105 |
# POST endpoint to handle chat requests
@router.post("/chat/")
async def chat(request: ChatRequest):
    """
    Answer a legal question in the language it was asked in.

    Flow: detect the question's language, translate it to English, screen it
    with is_valid_question, run the retrieval chain, then translate the
    answer back to the detected language.

    Raises:
        HTTPException: status 500 if any step of the pipeline fails.
    """
    try:
        # Detect the language of the incoming question.
        # NOTE(review): these translator calls are awaited, which matches
        # googletrans' async API (4.x); confirm the installed version —
        # awaiting the legacy sync Translator would fail at runtime.
        detected_lang = await translator.detect(request.question)
        detected_language = detected_lang.lang

        # Translate question to English for validation and retrieval.
        question_translation = await translator.translate(request.question, src=detected_language, dest="en")
        question_in_english = question_translation.text

        # Validate the translated question before spending an LLM call on it.
        if not is_valid_question(question_in_english):
            return {
                "answer": "Please provide a valid legal question related to Indian laws.",
                "language": LANGUAGES.get(detected_language, "unknown")
            }

        # Prepare input for LLM
        inputs = {"question": question_in_english, "chat_history": request.chat_history}

        # Run LLM retrieval chain.
        # NOTE(review): qa_chain(inputs) is a synchronous call inside an async
        # endpoint — it blocks the event loop for the duration of the LLM
        # request; consider run_in_executor or an async chain API.
        result = qa_chain(inputs)

        # Ensure response contains an answer
        if 'answer' not in result:
            raise ValueError("Missing 'answer' key in the result from qa_chain")

        # Translate response back to the original language.
        answer_translation = await translator.translate(result["answer"], src="en", dest=detected_language)
        answer_in_original_language = answer_translation.text

        return {
            "answer": answer_in_original_language,
            "language": LANGUAGES.get(detected_language, "unknown")
        }
    except Exception as e:
        # Surface any failure as a 500; chain the cause so the original
        # traceback is preserved for server-side debugging.
        raise HTTPException(status_code=500, detail=f"Error processing request: {str(e)}") from e
|
144 |
|
145 |
# GET endpoint to check if the API is running
@router.get("/")
async def root():
    """Health check: confirm the LawGPT API is up."""
    return {"message": "LawGPT API is running."}
|