"""LangGraph workflow nodes for the LegalAlly legal-assistant RAG pipeline."""
from typing import Dict, List

from typing_extensions import TypedDict

from langchain.schema import AIMessage, Document
from langgraph.errors import GraphRecursionError

from src.llm import *
from src.websearch import *
class GraphState(TypedDict):
    """Shared LangGraph state passed between workflow nodes.

    Attributes:
        question: Current user question (possibly rewritten).
        generation: LLM-produced answer text.
        documents: Retrieved/filtered context documents.
        chat_history: Conversation turns as {"role": ..., "content": ...} dicts.
        intent: Intent label produced by understand_intent.
        route_question: Routing decision ("web_search" or "vectorstore").
        grade_generation: Generation verdict ("useful" / "not useful" / "not supported").
    """

    question: str
    generation: str
    documents: List[str]
    chat_history: List[Dict[str, str]]
    # FIX: the keys below are written by nodes (understand_intent,
    # route_question, grade_generation_v_documents_and_question) but were
    # missing from the schema, so LangGraph would not carry them in state.
    intent: str
    route_question: str
    grade_generation: str
def understand_intent(state):
    """Classify the user's intent from the (lower-cased) question.

    Args:
        state (dict): Graph state containing at least "question".

    Returns:
        dict: {"intent": classifier output, "question": normalized question}
    """
    print("---UNDERSTAND INTENT---")
    # Normalize casing so the classifier sees consistent input.
    # (FIX: removed an unused chat_history local and commented-out dead code.)
    question = state["question"].lower()
    intent = intent_classifier.invoke({"question": question})
    print(f"Intent: {intent}")  # Debug print
    return {"intent": intent, "question": question}
def intent_aware_response(state):
    """Map the classified intent to the name of the next workflow node.

    Args:
        state (dict): Graph state; reads "intent", which may be a structured
            object with an ``intent`` attribute or a repr-like string such as
            ``"intent='greeting'"``.

    Returns:
        str: One of "greeting", "off_topic", or "route_question".
    """
    print("---INTENT-AWARE RESPONSE---")
    intent = state.get("intent", "")
    print(f"Responding to intent: {intent}")  # Debug print

    # Check if intent is an IntentClassifier object
    if hasattr(intent, "intent"):
        intent = intent.intent.lower()
    elif isinstance(intent, str):
        # FIX: str.strip("intent='") strips any of those *characters* from
        # both ends, not the literal prefix — a label like "netting" would
        # collapse to "g". removeprefix/removesuffix peel the exact repr
        # wrapper only.
        intent = intent.lower().removeprefix("intent='").removesuffix("'")
    else:
        print(f"Unexpected intent type: {type(intent)}")
        intent = "unknown"

    if intent == "greeting":
        return "greeting"
    if intent == "off_topic":
        return "off_topic"
    if intent in ("legal_query", "follow_up"):
        return "route_question"
    print(f"Unknown intent '{intent}', treating as off-topic")
    return "off_topic"
def retrieve(state):
    """Fetch candidate documents for the current question from the retriever."""
    print("---RETRIEVE---")
    query = state["question"]
    return {"documents": retriever.invoke(query), "question": query}
def generate(state):
    """Produce an answer grounded in the retrieved documents.

    Builds a prompt from the document context plus the last five chat turns,
    invokes the LLM, and appends the new human/ai exchange to chat history.
    """
    print("---GENERATE---")
    query = state["question"]
    docs = state["documents"]
    history = state.get("chat_history", [])

    doc_context = "\n".join(doc.page_content for doc in docs)
    chat_context = "\n".join(
        f"{turn['role']}: {turn['content']}" for turn in history[-5:]
    )

    prompt = f"""
As LegalAlly, an AI assistant specializing in the Indian Penal Code, provide a helpful and informative response to the following question. Use the given context and chat history for reference.
Responses are concise and answer user's queries directly. They are not verbose. The answer feels natural and not robotic.
Make sure the answer is grounded in the documents and is not hallucination.
Context:
{doc_context}
Chat History:
{chat_context}
Question: {query}
Response:
"""
    raw = llm.invoke(prompt)
    answer = raw.content if hasattr(raw, "content") else str(raw)

    new_turns = [
        {"role": "human", "content": query},
        {"role": "ai", "content": answer},
    ]
    return {
        "documents": docs,
        "question": query,
        "generation": answer,
        "chat_history": history + new_turns,
    }
def grade_documents(state):
    """Keep only the documents the grader marks relevant to the question."""
    print("---CHECK DOCUMENT RELEVANCE TO QUESTION---")
    query = state["question"]
    relevant = []
    for doc in state["documents"]:
        verdict = retrieval_grader.invoke(
            {"question": query, "document": doc.page_content}
        )
        if verdict.binary_score == "yes":
            print("---GRADE: DOCUMENT RELEVANT---")
            relevant.append(doc)
        else:
            print("---GRADE: DOCUMENT NOT RELEVANT---")
    return {"documents": relevant, "question": query}
def transform_query(state):
    """Rewrite the question into a form better suited for retrieval."""
    print("---TRANSFORM QUERY---")
    rewritten = question_rewriter.invoke({"question": state["question"]})
    return {"documents": state["documents"], "question": rewritten}
def web_search(state):
    """Run a web search and wrap all result text in a single Document.

    Args:
        state (dict): Graph state; reads "question".

    Returns:
        dict: {"documents": [Document with concatenated results], "question": question}
    """
    print("---WEB SEARCH---")
    question = state["question"]
    web_results = web_search_tool.invoke({"query": question})

    # Normalize the tool output to a list of {"content": str} dicts.
    if isinstance(web_results, str):
        web_results = [{"content": web_results}]
    elif isinstance(web_results, list):
        normalized = []
        for result in web_results:
            if isinstance(result, str):
                normalized.append({"content": result})
            elif isinstance(result, dict) and "content" in result:
                # FIX: Tavily-style search tools return a list of dicts with
                # a "content" key; these were silently dropped before,
                # leaving an empty web document.
                normalized.append({"content": str(result["content"])})
        web_results = normalized
    else:
        web_results = []

    web_content = "\n".join(d["content"] for d in web_results)
    return {"documents": [Document(page_content=web_content)], "question": question}
def route_question(state):
    """
    Route question to web search or RAG.
    Args:
        state (dict): The current graph state
    Returns:
        dict: Updated state with routing information
    """
    print("---ROUTE QUESTION---")
    question = state["question"]
    source = question_router.invoke({"question": question})

    if source.datasource == "web_search":
        print("---ROUTE QUESTION TO WEB SEARCH---")
        target = "web_search"
    elif source.datasource == "vectorstore":
        print("---ROUTE QUESTION TO RAG---")
        target = "vectorstore"
    else:
        # Anything unrecognized falls back to the vector store.
        print("---UNKNOWN ROUTE, DEFAULTING TO RAG---")
        target = "vectorstore"

    # Maintain the current question alongside the routing decision.
    return {"route_question": target, "question": question}
def decide_to_generate(state):
    """Decide whether to generate an answer or rewrite the query.

    Args:
        state (dict): Graph state; reads the graded "documents" list.

    Returns:
        str: "generate" when at least one relevant document survived
        grading, otherwise "transform_query".
    """
    print("---ASSESS GRADED DOCUMENTS---")
    # FIX: removed the no-op bare expression `state["question"]`.
    filtered_documents = state["documents"]
    if not filtered_documents:
        print(
            "---DECISION: ALL DOCUMENTS ARE NOT RELEVANT TO QUESTION, TRANSFORM QUERY---"
        )
        return "transform_query"
    print("---DECISION: GENERATE---")
    return "generate"
def grade_generation_v_documents_and_question(state):
    """Grade the generation for groundedness, then for answering the question."""
    print("---CHECK HALLUCINATIONS---")
    question = state["question"]
    documents = state["documents"]
    generation = state["generation"]
    chat_history = state.get("chat_history", [])

    def _result(verdict):
        # Shared payload; only the grade_generation label differs per branch.
        return {
            "grade_generation": verdict,
            "question": question,
            "generation": generation,
            "documents": documents,
            "chat_history": chat_history,
        }

    grounded = hallucination_grader.invoke(
        {"documents": documents, "generation": generation}
    ).binary_score
    if grounded != "yes":
        print("---DECISION: GENERATION IS NOT GROUNDED IN DOCUMENTS, RE-TRY---")
        return _result("not supported")

    print("---DECISION: GENERATION IS GROUNDED IN DOCUMENTS---")
    answers = answer_grader.invoke(
        {"question": question, "generation": generation}
    ).binary_score
    if answers == "yes":
        print("---DECISION: GENERATION ADDRESSES QUESTION---")
        return _result("useful")
    print("---DECISION: GENERATION DOES NOT ADDRESS QUESTION---")
    return _result("not useful")
def greeting(state):
    """Return the canned greeting response as a generation update."""
    print("---GREETING---")
    message = (
        "Hello! I'm LegalAlly, an AI assistant specializing in Indian law, "
        "particularly the Indian Penal Code and Indian Constitution. "
        "How can I assist you today?"
    )
    return {"generation": message}
def off_topic(state):
    """Return the canned redirect for questions outside the legal domain."""
    print("---OFF-TOPIC---")
    message = (
        "I apologize, but I specialize in matters related to the Indian "
        "Penal Code. Could you please ask a question about Indian law or "
        "legal matters?"
    )
    return {"generation": message}
# conditional edges for recursion limit
def check_recursion_limit(state):
    """Placeholder conditional edge; always proceeds.

    LangGraph raises GraphRecursionError on its own when the recursion
    limit is reached, so no explicit check is required here.
    """
    return "continue"