aniudupa committed on
Commit
4786462
·
verified ·
1 Parent(s): 91d6566

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +76 -12
app.py CHANGED
@@ -2,21 +2,25 @@ from fastapi import APIRouter, HTTPException
2
  from pydantic import BaseModel
3
  from pathlib import Path
4
  import os
 
5
  from langchain_community.vectorstores import FAISS
6
  from langchain_community.embeddings import HuggingFaceEmbeddings
7
  from langchain.prompts import PromptTemplate
8
  from langchain_together import Together
9
  from langchain.memory import ConversationBufferWindowMemory
10
  from langchain.chains import ConversationalRetrievalChain
 
 
11
 
12
  # Set the API key for Together.ai
13
- TOGETHER_AI_API = os.getenv("TOGETHER_AI_API", "1c27fe0df51a29edee1bec6b4b648b436cc80cf4ccc36f56de17272d9e663cbd")
14
 
15
  # Ensure proper cache directory is available for models
16
  os.environ['TRANSFORMERS_CACHE'] = '/tmp/cache'
17
 
18
  # Initialize FastAPI Router
19
- app = APIRouter()
 
20
 
21
  # Lazy loading of large models (only load embeddings and index when required)
22
  embeddings = HuggingFaceEmbeddings(
@@ -26,7 +30,7 @@ embeddings = HuggingFaceEmbeddings(
26
 
27
  index_path = Path("models/index.faiss")
28
  if not index_path.exists():
29
- raise FileNotFoundError("FAISS index not found. Please generate it and place it in 'ipc_vector_db'.")
30
 
31
  # Load the FAISS index
32
  db = FAISS.load_local("models", embeddings, allow_dangerous_deserialization=True)
@@ -46,7 +50,7 @@ llm = Together(
46
  model="mistralai/Mistral-7B-Instruct-v0.2",
47
  temperature=0.5,
48
  max_tokens=1024,
49
- together_api_key=TOGETHER_AI_API,
50
  )
51
 
52
  # Set up memory for conversational context
@@ -60,25 +64,85 @@ qa_chain = ConversationalRetrievalChain.from_llm(
60
  combine_docs_chain_kwargs={"prompt": prompt},
61
  )
62
 
 
 
 
63
  # Input schema for chat requests
64
  class ChatRequest(BaseModel):
65
  question: str
66
  chat_history: str
67
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
68
  # POST endpoint to handle chat requests
69
- @app.post("/chat/")
70
  async def chat(request: ChatRequest):
71
  try:
72
- # Prepare the input data
73
- inputs = {"question": request.question, "chat_history": request.chat_history}
74
- # Run the chain to get the answer
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
75
  result = qa_chain(inputs)
76
- return {"answer": result["answer"]}
 
 
 
 
 
 
 
 
 
 
 
 
77
  except Exception as e:
78
- # Return an error if something goes wrong
79
  raise HTTPException(status_code=500, detail=f"Error processing request: {str(e)}")
80
 
81
  # GET endpoint to check if the API is running
82
- @app.get("/")
83
  async def root():
84
- return {"message": "LawGPT API is running."}
 
2
  from pydantic import BaseModel
3
  from pathlib import Path
4
  import os
5
+ import re
6
  from langchain_community.vectorstores import FAISS
7
  from langchain_community.embeddings import HuggingFaceEmbeddings
8
  from langchain.prompts import PromptTemplate
9
  from langchain_together import Together
10
  from langchain.memory import ConversationBufferWindowMemory
11
  from langchain.chains import ConversationalRetrievalChain
12
+ from langdetect import detect
13
+ from googletrans import Translator, LANGUAGES
14
 
15
# Set the API key for Together.ai
# SECURITY(review): a live-looking API key is hardcoded and committed to source control,
# and this assignment overwrites any value already set in the environment. Rotate this
# key and inject it via the deployment environment / secrets manager instead.
os.environ["TOGETHER_AI_API"] = "1c27fe0df51a29edee1bec6b4b648b436cc80cf4ccc36f56de17272d9e663cbd"
17
 
18
  # Ensure proper cache directory is available for models
19
  os.environ['TRANSFORMERS_CACHE'] = '/tmp/cache'
20
 
21
# Initialize FastAPI Router (expected to be mounted by the main application elsewhere)
router = APIRouter()
# Display name for the assistant; not referenced anywhere in the code visible here — TODO confirm use.
bot_name = "LawGPT"
 
25
  # Lazy loading of large models (only load embeddings and index when required)
26
  embeddings = HuggingFaceEmbeddings(
 
30
 
31
# Fail fast at import time if the serialized FAISS index is missing from the 'models' dir.
index_path = Path("models/index.faiss")
if not index_path.exists():
    raise FileNotFoundError("FAISS index not found. Please generate it and place it in 'models'.")

# Load the FAISS index
# NOTE(review): allow_dangerous_deserialization=True unpickles the index file; this is
# only acceptable if the 'models' directory ships with the app and is never user-supplied.
db = FAISS.load_local("models", embeddings, allow_dangerous_deserialization=True)
 
50
  model="mistralai/Mistral-7B-Instruct-v0.2",
51
  temperature=0.5,
52
  max_tokens=1024,
53
+ together_api_key=os.getenv("TOGETHER_AI_API"),
54
  )
55
 
56
  # Set up memory for conversational context
 
64
  combine_docs_chain_kwargs={"prompt": prompt},
65
  )
66
 
67
# Shared googletrans Translator used by the /chat endpoint.
# NOTE(review): despite the original "sync version" comment, chat() awaits
# translator.detect()/translate(), which requires the async googletrans API —
# confirm the installed googletrans version returns awaitables.
translator = Translator()
69
+
70
# Input schema for chat requests
class ChatRequest(BaseModel):
    """Request body for POST /chat/.

    Both fields are required; the question may be in any language (the endpoint
    detects and translates it before querying the retrieval chain).
    """

    # Free-text user question.
    question: str
    # Prior conversation context; passed verbatim into the retrieval chain.
    chat_history: str
74
 
75
# Function to validate the input question
def is_valid_question(question: str) -> bool:
    """
    Validate the input question to ensure it is meaningful and related to Indian law or crime.

    A question passes when it is at least 3 characters long, is not purely numeric
    or purely symbolic noise, and mentions at least one legal keyword. Keywords are
    matched as whole words/phrases (with an optional plural "s"), because the previous
    substring check accepted false positives such as "act" inside "exact" or
    "fir" (FIR) inside "first".

    Args:
        question: The (already English-translated) user question.

    Returns:
        True if the question looks like a genuine legal query, False otherwise.
    """
    question = question.strip()

    # Reject if the question is too short to be meaningful
    if len(question) < 3:
        return False

    # Reject if the question contains only numbers or only symbols
    if re.fullmatch(r'\d+', question) or re.fullmatch(r'[^a-zA-Z0-9\s]+', question):
        return False

    # Define keywords related to Indian law and crime
    legal_keywords = [
        "IPC", "CrPC", "section", "law", "crime", "penalty", "punishment",
        "legal", "court", "justice", "offense", "fraud", "murder", "theft",
        "bail", "arrest", "FIR", "judgment", "act", "contract", "constitutional",
        "habeas corpus", "petition", "rights", "lawyer", "advocate", "accused",
        "penal", "conviction", "sentence", "appeal", "trial", "witness"
    ]

    # Require at least one keyword as a whole word/phrase (case-insensitive).
    # The trailing "s?" keeps plural forms (e.g. "laws", "acts") matching, as the
    # old substring check did, without matching unrelated words like "lawn".
    lowered = question.lower()
    return any(
        re.search(r'\b' + re.escape(keyword.lower()) + r's?\b', lowered)
        for keyword in legal_keywords
    )
104
+
105
# POST endpoint to handle chat requests
@router.post("/chat/")
async def chat(request: ChatRequest):
    """Answer a legal question in the user's own language.

    Pipeline: detect the question's language, translate it to English, validate it
    with is_valid_question, run the ConversationalRetrievalChain, then translate
    the answer back to the detected language.

    Raises:
        HTTPException: 500 for any failure in the pipeline (the broad except below
            converts every exception, including translation errors, into a 500).
    """
    try:
        # Detect the language of the incoming question
        detected_lang = await translator.detect(request.question)
        detected_language = detected_lang.lang

        # Translate question to English (the retrieval chain/prompt is English-only)
        question_translation = await translator.translate(request.question, src=detected_language, dest="en")
        question_in_english = question_translation.text

        # Validate translated question; reject non-legal input with a canned reply
        if not is_valid_question(question_in_english):
            return {
                "answer": "Please provide a valid legal question related to Indian laws.",
                "language": LANGUAGES.get(detected_language, "unknown")
            }

        # Prepare input for LLM
        inputs = {"question": question_in_english, "chat_history": request.chat_history}

        # Run LLM retrieval chain
        # NOTE(review): qa_chain(inputs) is a synchronous call inside an async handler,
        # so it blocks the event loop while the LLM responds — consider run_in_executor.
        result = qa_chain(inputs)

        # Ensure response contains an answer
        if 'answer' not in result:
            raise ValueError("Missing 'answer' key in the result from qa_chain")

        # Translate response back to original language
        answer_translation = await translator.translate(result["answer"], src="en", dest=detected_language)
        answer_in_original_language = answer_translation.text

        return {
            "answer": answer_in_original_language,
            "language": LANGUAGES.get(detected_language, "unknown")
        }
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error processing request: {str(e)}")
144
 
145
# GET endpoint to check if the API is running
@router.get("/")
async def root():
    """Health-check endpoint confirming the LawGPT API is up."""
    status_message = "LawGPT API is running."
    return {"message": status_message}