"""LegalAlly graph definition (src/graph.py).

Author: Rohil Bansal
"""
from typing import List, Dict
from typing_extensions import TypedDict
from src.websearch import *
from src.llm import *
from langchain.schema import Document, AIMessage
from langgraph.errors import GraphRecursionError
class GraphState(TypedDict):
    """Shared state dictionary passed between LangGraph nodes.

    Keys:
        question: Current user question (may be rewritten by transform_query).
        generation: LLM-produced answer text.
        documents: Retrieved/filtered documents for the question.
            NOTE(review): at runtime these appear to be Document objects
            (generate reads doc.page_content), not plain strings — the
            List[str] annotation predates that; confirm before tightening.
        chat_history: Prior turns as {"role": ..., "content": ...} dicts.
        intent: Classified intent set by understand_intent and read by
            intent_aware_response. This key was written/read by the nodes
            but missing from the original declaration; added for consistency.
            NOTE(review): may also hold a classifier object with an
            ``.intent`` attribute — see intent_aware_response.
    """

    question: str
    generation: str
    documents: List[str]
    chat_history: List[Dict[str, str]]
    intent: str
def understand_intent(state):
    """Classify the user's question into an intent via the intent classifier.

    Returns a partial state update carrying the classified intent and the
    lower-cased question.
    """
    print("---UNDERSTAND INTENT---")
    normalized_question = state["question"].lower()
    # History is pulled from state for potential context use; currently unused.
    _history = state.get("chat_history", [])
    detected = intent_classifier.invoke({"question": normalized_question})
    print(f"Intent: {detected}")  # Debug print
    return {"intent": detected, "question": normalized_question}
def intent_aware_response(state):
    """Map the classified intent onto the next graph node name.

    Returns:
        One of "greeting", "off_topic", or "route_question".
    """
    print("---INTENT-AWARE RESPONSE---")
    question = state["question"]
    intent = state.get("intent", "")
    print(f"Responding to intent: {intent}")  # Debug print

    # Normalise: the classifier may return an object exposing `.intent`,
    # or a plain string (possibly a repr such as "intent='greeting'").
    if hasattr(intent, "intent"):
        intent = intent.intent.lower()
    elif isinstance(intent, str):
        intent = intent.lower()
        # BUG FIX: the original used .strip("intent='"), which treats its
        # argument as a *character set* and can over-strip intents whose
        # edges share characters with it (e.g. "intent='intent'").
        # Slice off the exact prefix/suffix instead.
        if intent.startswith("intent='") and intent.endswith("'"):
            intent = intent[len("intent='"):-1]
    else:
        print(f"Unexpected intent type: {type(intent)}")
        intent = "unknown"

    if intent == "greeting":
        return "greeting"
    if intent == "off_topic":
        return "off_topic"
    if intent in ("legal_query", "follow_up"):
        return "route_question"
    print(f"Unknown intent '{intent}', treating as off-topic")
    return "off_topic"
def retrieve(state):
    """Fetch candidate documents for the current question from the retriever."""
    print("---RETRIEVE---")
    query = state["question"]
    docs = retriever.invoke(query)
    return {"documents": docs, "question": query}
def generate(state):
    """Produce an answer grounded in the retrieved documents and chat history.

    Appends the new human/ai exchange to the chat history in the returned
    state update.
    """
    print("---GENERATE---")
    query = state["question"]
    docs = state["documents"]
    history = state.get("chat_history", [])

    doc_context = "\n".join(doc.page_content for doc in docs)
    # Only the five most recent turns are folded into the prompt.
    history_context = "\n".join(
        f"{turn['role']}: {turn['content']}" for turn in history[-5:]
    )

    prompt = f"""
As LegalAlly, an AI assistant specializing in the Indian Penal Code, provide a helpful and informative response to the following question. Use the given context and chat history for reference.
Responses are concise and answer user's queries directly. They are not verbose. The answer feels natural and not robotic.
Make sure the answer is grounded in the documents and is not hallucination.
Context:
{doc_context}
Chat History:
{history_context}
Question: {query}
Response:
"""
    reply = llm.invoke(prompt)
    # LLM clients may return a message object or a bare string.
    reply = reply.content if hasattr(reply, "content") else str(reply)

    return {
        "documents": docs,
        "question": query,
        "generation": reply,
        "chat_history": history
        + [
            {"role": "human", "content": query},
            {"role": "ai", "content": reply},
        ],
    }
def grade_documents(state):
    """Keep only the documents the grader marks relevant to the question."""
    print("---CHECK DOCUMENT RELEVANCE TO QUESTION---")
    query = state["question"]
    relevant = []
    for doc in state["documents"]:
        verdict = retrieval_grader.invoke(
            {"question": query, "document": doc.page_content}
        )
        if verdict.binary_score == "yes":
            print("---GRADE: DOCUMENT RELEVANT---")
            relevant.append(doc)
        else:
            print("---GRADE: DOCUMENT NOT RELEVANT---")
    return {"documents": relevant, "question": query}
def transform_query(state):
    """Rewrite the question into a better-phrased retrieval query."""
    print("---TRANSFORM QUERY---")
    docs = state["documents"]
    improved = question_rewriter.invoke({"question": state["question"]})
    return {"documents": docs, "question": improved}
def web_search(state):
    """Run a web search and wrap all results into a single Document.

    NOTE(review): list results that are not plain strings are silently
    dropped here — confirm the search tool never returns dict entries.
    """
    print("---WEB SEARCH---")
    query = state["question"]
    raw = web_search_tool.invoke({"query": query})

    # Normalise the tool output to a list of {"content": str} entries.
    if isinstance(raw, str):
        entries = [{"content": raw}]
    elif isinstance(raw, list):
        entries = [{"content": item} for item in raw if isinstance(item, str)]
    else:
        entries = []

    combined = "\n".join(entry["content"] for entry in entries)
    return {"documents": [Document(page_content=combined)], "question": query}
def route_question(state):
    """Route the question to web search or the vector store (RAG).

    Args:
        state (dict): The current graph state.

    Returns:
        dict: State update whose "route_question" key is either
        "web_search" or "vectorstore" (the default for unknown routes).
    """
    print("---ROUTE QUESTION---")
    query = state["question"]
    decision = question_router.invoke({"question": query})

    if decision.datasource == "web_search":
        print("---ROUTE QUESTION TO WEB SEARCH---")
        destination = "web_search"
    elif decision.datasource == "vectorstore":
        print("---ROUTE QUESTION TO RAG---")
        destination = "vectorstore"
    else:
        # Anything unrecognised falls back to the vector store.
        print("---UNKNOWN ROUTE, DEFAULTING TO RAG---")
        destination = "vectorstore"

    return {"route_question": destination, "question": query}
def decide_to_generate(state):
    """Decide the next node after document grading.

    Returns:
        "transform_query" when every retrieved document was filtered out
        as irrelevant, otherwise "generate".
    """
    print("---ASSESS GRADED DOCUMENTS---")
    # BUG FIX: the original evaluated state["question"] as a bare expression
    # and discarded the result; the dead lookup has been removed.
    if not state["documents"]:
        print(
            "---DECISION: ALL DOCUMENTS ARE NOT RELEVANT TO QUESTION, TRANSFORM QUERY---"
        )
        return "transform_query"
    print("---DECISION: GENERATE---")
    return "generate"
def grade_generation_v_documents_and_question(state):
    """Grade the generation for groundedness, then for answering the question.

    Returns a state update whose "grade_generation" key is "useful",
    "not useful", or "not supported".
    """
    print("---CHECK HALLUCINATIONS---")
    query = state["question"]
    docs = state["documents"]
    answer = state["generation"]
    history = state.get("chat_history", [])

    # All three outcomes return the same state; only the verdict differs.
    update = {
        "question": query,
        "generation": answer,
        "documents": docs,
        "chat_history": history,
    }

    grounded = hallucination_grader.invoke(
        {"documents": docs, "generation": answer}
    ).binary_score
    if grounded != "yes":
        print("---DECISION: GENERATION IS NOT GROUNDED IN DOCUMENTS, RE-TRY---")
        update["grade_generation"] = "not supported"
        return update

    print("---DECISION: GENERATION IS GROUNDED IN DOCUMENTS---")
    addresses = answer_grader.invoke(
        {"question": query, "generation": answer}
    ).binary_score
    if addresses == "yes":
        print("---DECISION: GENERATION ADDRESSES QUESTION---")
        update["grade_generation"] = "useful"
    else:
        print("---DECISION: GENERATION DOES NOT ADDRESS QUESTION---")
        update["grade_generation"] = "not useful"
    return update
def greeting(state):
    """Return the canned welcome message (the state argument is unused)."""
    print("---GREETING---")
    message = (
        "Hello! I'm LegalAlly, an AI assistant specializing in Indian law, "
        "particularly the Indian Penal Code and Indian Constitution. "
        "How can I assist you today?"
    )
    return {"generation": message}
def off_topic(state):
    """Return the canned off-topic apology (the state argument is unused)."""
    print("---OFF-TOPIC---")
    message = (
        "I apologize, but I specialize in matters related to the Indian Penal "
        "Code. Could you please ask a question about Indian law or legal "
        "matters?"
    )
    return {"generation": message}
# conditional edges for recursion limit
def check_recursion_limit(state):
    """Always signal "continue".

    LangGraph raises GraphRecursionError on its own when the recursion
    limit is hit, so no explicit check is performed here.
    """
    return "continue"