# timeki's picture
# add_poc_french_local_insights (#18)
# 2fc727e verified
import os
from datetime import datetime
import gradio as gr
# from .agent import agent
from gradio import ChatMessage
from langgraph.graph.state import CompiledStateGraph
import json
from .handle_stream_events import (
init_audience,
handle_retrieved_documents,
convert_to_docs_to_html,
stream_answer,
handle_retrieved_owid_graphs,
serialize_docs,
)
# Function to log data on Azure
def log_on_azure(file, logs, share_client):
    """Serialize ``logs`` to JSON and upload it through ``share_client``.

    Args:
        file: Name passed to ``share_client.get_file_client``.
        logs: JSON-serializable object to upload.
        share_client: Object exposing ``get_file_client(name)``; the client it
            returns must expose ``upload_file(data)``.
    """
    # Serialize first so that nothing is requested from the share client
    # when the payload cannot be encoded.
    payload = json.dumps(logs)
    share_client.get_file_client(file).upload_file(payload)
# Chat functions
def start_chat(query, history, search_only):
    """Append the user's query to the history and build the start-of-chat updates.

    Args:
        query: Text of the user message.
        history: Existing list of chat messages; it is not mutated, a new list
            is returned instead.
        search_only: Selects which ``selected`` index is emitted (2 when
            truthy, 1 otherwise).

    Returns:
        A 4-tuple: ``gr.update(interactive=False)``,
        ``gr.update(selected=<1 or 2>)``, the extended history and an empty
        list.
    """
    # NOTE(review): `selected` presumably addresses a tab in the UI -- confirm
    # against the layout that wires this callback.
    target_index = 2 if search_only else 1
    extended_history = history + [ChatMessage(role="user", content=query)]
    return (
        gr.update(interactive=False),
        gr.update(selected=target_index),
        extended_history,
        [],
    )
def finish_chat():
    """Return a Gradio update with ``interactive=True`` and an empty ``value``."""
    reset_properties = {"interactive": True, "value": ""}
    return gr.update(**reset_properties)
def _message_content(message):
    """Return the text of a chat message stored either as a dict or as an object."""
    if isinstance(message, dict):
        return message["content"]
    return message.content


def log_interaction_to_azure(history, output_query, sources, docs, share_client, user_id):
    """Upload a JSON record of the finished interaction, unless running locally.

    Nothing is logged when the ``GRADIO_ENV`` environment variable equals
    ``"local"``.

    Args:
        history: Chat history. The prompt is read from index 1 and the answer
            from the last item; each may be a dict with a ``"content"`` key or
            an object with a ``content`` attribute.
        output_query: Stored under the ``"question"`` key of the record.
        sources: Stored under the ``"sources"`` key of the record.
        docs: Passed through ``serialize_docs`` before being stored.
        share_client: Forwarded to ``log_on_azure``.
        user_id: Stored as a string under the ``"user_id"`` key.

    Raises:
        gr.Error: If building or uploading the record fails for any reason.
    """
    # Local runs are never logged.
    if os.getenv("GRADIO_ENV") == "local":
        return

    try:
        timestamp = str(datetime.now().timestamp())
        # NOTE(review): index 1 presumably is the first user message (after an
        # initial greeting at index 0) -- confirm against the caller.
        prompt = _message_content(history[1])
        logs = {
            "user_id": str(user_id),
            "prompt": prompt,
            "query": prompt,
            "question": output_query,
            "sources": sources,
            "docs": serialize_docs(docs),
            "answer": _message_content(history[-1]),
            "time": timestamp,
        }
        log_on_azure(f"{timestamp}.json", logs, share_client)
    except Exception as e:
        print(f"Error logging on Azure Blob Storage: {e}")
        error_msg = f"ClimateQ&A Error: {str(e)[:100]} - The error has been noted, try another question and if the error remains, you can contact us :)"
        # Chain the cause so the original traceback is kept in the server logs.
        raise gr.Error(error_msg) from e
# Main chat function
async def chat_stream(
    agent: CompiledStateGraph,
    query: str,
    history: list[ChatMessage],
    audience: str,
    sources: list[str],
    reports: list[str],
    relevant_content_sources_selection: list[str],
    search_only: bool,
    share_client,
    user_id: str
) -> tuple[list, str, str, str, list, str]:
    """Process a chat query and stream the response with sources and visualizations.

    This is an async generator: it yields one tuple after every event received
    from the agent, then logs the interaction and yields a final tuple.

    Args:
        agent (CompiledStateGraph): Graph whose ``astream_events`` is consumed
        query (str): The user's question
        history (list): Chat message history
        audience (str): Target audience type
        sources (list): Knowledge base sources to search; defaults to
            ``["IPCC", "IPBES"]`` when empty
        reports (list): Specific reports to search within sources
        relevant_content_sources_selection (list): Types of content to retrieve (figures, papers, etc)
        search_only (bool): Whether to only search without generating answer
        share_client: Forwarded to ``log_interaction_to_azure``
        user_id (str): Forwarded to ``log_interaction_to_azure``

    Yields:
        tuple: Contains:
            - history: Updated chat history
            - docs_html: HTML of retrieved documents
            - output_query: Processed query (never reassigned in this
              function, so it is always ``""``)
            - output_language: Detected language
            - related_contents: Related content
            - graphs_html: HTML of relevant graphs

    Raises:
        gr.Error: If processing the event stream fails.
    """
    # Log incoming question
    date_now = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
    print(f">> NEW QUESTION ({date_now}) : {query}")

    audience_prompt = init_audience(audience)
    sources = sources or ["IPCC", "IPBES"]
    reports = reports or []

    # Prepare inputs for agent
    inputs = {
        "user_input": query,
        "audience": audience_prompt,
        "sources_input": sources,
        "relevant_content_sources_selection": relevant_content_sources_selection,
        "search_only": search_only,
        "reports": reports,
    }

    # Get streaming events from agent
    result = agent.astream_events(inputs, version="v1")

    # Initialize state variables
    docs = []
    related_contents = []
    docs_html = ""
    output_query = ""
    output_language = ""
    start_streaming = False
    graphs_html = ""
    used_documents = []
    retrieved_contents = []
    answer_message_content = ""
    # Bound up front so the error handler below can always print it, even
    # when the stream fails before delivering its first event.
    event = None

    # Define processing steps: event name -> (title shown in the chat, flag)
    steps_display = {
        "categorize_intent": ("🔄️ Analyzing user message", True),
        "transform_query": ("🔄️ Thinking step by step to answer the question", True),
        "retrieve_documents": ("🔄️ Searching in the knowledge base", False),
        "retrieve_local_data": ("🔄️ Searching in the knowledge base", False),
    }

    try:
        # Process streaming events
        async for event in result:
            if "langgraph_node" in event["metadata"]:
                node = event["metadata"]["langgraph_node"]

                # Handle document retrieval (independent of the chain below)
                if (
                    event["event"] == "on_chain_end"
                    and event["name"] in ["retrieve_documents", "retrieve_local_data"]
                    and event["data"]["output"] is not None
                ):
                    history, used_documents, retrieved_contents = handle_retrieved_documents(
                        event, history, used_documents, retrieved_contents
                    )

                if event["event"] == "on_chain_end" and event["name"] == "answer_search":
                    docs = event["data"]["input"]["documents"]
                    docs_html = convert_to_docs_to_html(docs)
                    related_contents = event["data"]["input"]["related_contents"]

                # Handle intent categorization
                elif (
                    event["event"] == "on_chain_end"
                    and node == "categorize_intent"
                    and event["name"] == "_write"
                ):
                    intent = event["data"]["output"]["intent"]
                    output_language = event["data"]["output"].get("language", "English")
                    history[-1].content = f"Language identified: {output_language}\nIntent identified: {intent}"

                # Handle processing steps display
                elif event["name"] in steps_display and event["event"] == "on_chain_start":
                    # Index with the key that was just tested for membership.
                    event_description, _ = steps_display[event["name"]]
                    # The last message may be a plain dict or carry metadata
                    # without a title; both count as "no step title yet".
                    last_metadata = getattr(history[-1], "metadata", None)
                    last_title = (
                        last_metadata.get("title")
                        if isinstance(last_metadata, dict)
                        else None
                    )
                    # Only open a new step message when the title changes.
                    if last_title != event_description:
                        history.append(ChatMessage(
                            role="assistant",
                            content="",
                            metadata={'title': event_description}
                        ))

                # Handle answer streaming
                elif (
                    event["name"] != "transform_query"
                    and event["event"] == "on_chat_model_stream"
                    and node in ["answer_rag", "answer_rag_no_docs", "answer_search", "answer_chitchat"]
                ):
                    history, start_streaming, answer_message_content = stream_answer(
                        history, event, start_streaming, answer_message_content
                    )

                # Handle graph retrieval
                elif event["name"] in ["retrieve_graphs", "retrieve_graphs_ai"] and event["event"] == "on_chain_end":
                    graphs_html = handle_retrieved_owid_graphs(event, graphs_html)

                # Handle query transformation
                if event["name"] == "transform_query" and event["event"] == "on_chain_end":
                    if hasattr(history[-1], "content"):
                        sub_questions = [
                            q["question"] + "-> relevant sources : " + str(q["sources"])
                            for q in event["data"]["output"]["questions_list"]
                        ]
                        history[-1].content += "Decompose question into sub-questions:\n\n - " + "\n - ".join(sub_questions)

            yield history, docs_html, output_query, output_language, related_contents, graphs_html

    except Exception as e:
        print(f"Event {event} has failed")
        raise gr.Error(str(e)) from e

    # Call the function to log interaction
    log_interaction_to_azure(history, output_query, sources, docs, share_client, user_id)

    yield history, docs_html, output_query, output_language, related_contents, graphs_html