import os
from datetime import datetime
import gradio as gr
# from .agent import agent
from gradio import ChatMessage
from langgraph.graph.state import CompiledStateGraph
import json
from .handle_stream_events import (
    init_audience,
    handle_retrieved_documents,
    convert_to_docs_to_html,
    stream_answer,
    handle_retrieved_owid_graphs,
    serialize_docs,
)
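
# Chat orchestration for the ClimateQ&A Gradio front end: wires the chat UI to
# the LangGraph agent's event stream and logs interactions to Azure.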

# Function to log data on Azure
def log_on_azure(file, logs, share_client):
    logs = json.dumps(logs)
    file_client = share_client.get_file_client(file)
    file_client.upload_file(logs)
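
# A minimal usage sketch for log_on_azure, assuming `share_client` is an
# azure.storage.fileshare.ShareClient already bound to the logs share
# (the connection string and share name below are hypothetical):
#
#     from azure.storage.fileshare import ShareClient
#     share_client = ShareClient.from_connection_string(conn_str, share_name="logs")
#     log_on_azure("1700000000.123.json", {"prompt": "Hello"}, share_client)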

# Chat functions
def start_chat(query, history, search_only):
    history = history + [ChatMessage(role="user", content=query)]
    if not search_only:
        return (gr.update(interactive=False), gr.update(selected=1), history, [])
    else:
        return (gr.update(interactive=False), gr.update(selected=2), history, [])

def finish_chat():
    return gr.update(interactive=True, value="")
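
# A sketch of how these callbacks are typically wired in a gr.Blocks layout;
# the component names (textbox, tabs, chatbot, sources) and the chained
# chat handler are hypothetical:
#
#     textbox.submit(start_chat, [textbox, chatbot, search_only],
#                    [textbox, tabs, chatbot, sources]
#     ).then(chat_stream_handler, ...
#     ).then(finish_chat, None, [textbox])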

def log_interaction_to_azure(history, output_query, sources, docs, share_client, user_id):
    try:
        # Log interaction to Azure if not in local environment
        if os.getenv("GRADIO_ENV") != "local":
            timestamp = str(datetime.now().timestamp())
            # ChatMessage is a dataclass, so use attribute access rather than subscripting
            prompt = history[1].content
            logs = {
                "user_id": str(user_id),
                "prompt": prompt,
                "query": prompt,
                "question": output_query,
                "sources": sources,
                "docs": serialize_docs(docs),
                "answer": history[-1].content,
                "time": timestamp,
            }
            log_on_azure(f"{timestamp}.json", logs, share_client)
    except Exception as e:
        print(f"Error logging on Azure Blob Storage: {e}")
        error_msg = f"ClimateQ&A Error: {str(e)[:100]} - The error has been noted, try another question and if the error remains, you can contact us :)"
        raise gr.Error(error_msg)
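
# For reference, a record produced by the logging above has this shape
# (values illustrative):
#
#     {
#         "user_id": "anon-1234",
#         "prompt": "Is sea level rising?",
#         "query": "Is sea level rising?",
#         "question": "<reformulated query>",
#         "sources": ["IPCC", "IPBES"],
#         "docs": [...],
#         "answer": "<assistant answer>",
#         "time": "1700000000.123456"
#     }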

# Main chat function
async def chat_stream(
    agent: CompiledStateGraph,
    query: str,
    history: list[ChatMessage],
    audience: str,
    sources: list[str],
    reports: list[str],
    relevant_content_sources_selection: list[str],
    search_only: bool,
    share_client,
    user_id: str,
) -> tuple[list, str, str, str, list, str]:
"""Process a chat query and return response with relevant sources and visualizations.
Args:
query (str): The user's question
history (list): Chat message history
audience (str): Target audience type
sources (list): Knowledge base sources to search
reports (list): Specific reports to search within sources
relevant_content_sources_selection (list): Types of content to retrieve (figures, papers, etc)
search_only (bool): Whether to only search without generating answer
Yields:
tuple: Contains:
- history: Updated chat history
- docs_html: HTML of retrieved documents
- output_query: Processed query
- output_language: Detected language
- related_contents: Related content
- graphs_html: HTML of relevant graphs
"""
    # Log incoming question
    date_now = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
    print(f">> NEW QUESTION ({date_now}) : {query}")

    audience_prompt = init_audience(audience)
    sources = sources or ["IPCC", "IPBES"]
    reports = reports or []

    # Prepare inputs for agent
    inputs = {
        "user_input": query,
        "audience": audience_prompt,
        "sources_input": sources,
        "relevant_content_sources_selection": relevant_content_sources_selection,
        "search_only": search_only,
        "reports": reports,
    }

    # Get streaming events from agent
    result = agent.astream_events(inputs, version="v1")

    # Initialize state variables
    docs = []
    related_contents = []
    docs_html = ""
    new_docs_html = ""
    output_query = ""
    output_language = ""
    output_keywords = ""
    start_streaming = False
    graphs_html = ""
    used_documents = []
    retrieved_contents = []
    answer_message_content = ""
    event = None  # keeps the error message below well-defined if the stream fails before the first event

    # Define processing steps
    steps_display = {
        "categorize_intent": ("🔄️ Analyzing user message", True),
        "transform_query": ("🔄️ Thinking step by step to answer the question", True),
        "retrieve_documents": ("🔄️ Searching in the knowledge base", False),
        "retrieve_local_data": ("🔄️ Searching in the knowledge base", False),
    }
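
    # Each steps_display entry is a (title, display_output) pair: the title is
    # rendered as the step's header in the chat, while the boolean is unpacked
    # below but not otherwise used in this function.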

    try:
        # Process streaming events
        async for event in result:
            if "langgraph_node" in event["metadata"]:
                node = event["metadata"]["langgraph_node"]

                # Handle document retrieval
                if event["event"] == "on_chain_end" and event["name"] in ["retrieve_documents", "retrieve_local_data"] and event["data"]["output"] is not None:
                    history, used_documents, retrieved_contents = handle_retrieved_documents(
                        event, history, used_documents, retrieved_contents
                    )
                if event["event"] == "on_chain_end" and event["name"] == "answer_search":
                    docs = event["data"]["input"]["documents"]
                    docs_html = convert_to_docs_to_html(docs)
                    related_contents = event["data"]["input"]["related_contents"]

                # Handle intent categorization
                elif (event["event"] == "on_chain_end" and
                      node == "categorize_intent" and
                      event["name"] == "_write"):
                    intent = event["data"]["output"]["intent"]
                    output_language = event["data"]["output"].get("language", "English")
                    history[-1].content = f"Language identified: {output_language}\nIntent identified: {intent}"

                # Handle processing steps display
                elif event["name"] in steps_display and event["event"] == "on_chain_start":
                    event_description, display_output = steps_display[node]
                    if (not hasattr(history[-1], "metadata") or
                            history[-1].metadata.get("title") != event_description):
                        history.append(ChatMessage(
                            role="assistant",
                            content="",
                            metadata={"title": event_description},
                        ))

                # Handle answer streaming
                elif (event["name"] != "transform_query" and
                      event["event"] == "on_chat_model_stream" and
                      node in ["answer_rag", "answer_rag_no_docs", "answer_search", "answer_chitchat"]):
                    history, start_streaming, answer_message_content = stream_answer(
                        history, event, start_streaming, answer_message_content
                    )

                # Handle graph retrieval
                elif event["name"] in ["retrieve_graphs", "retrieve_graphs_ai"] and event["event"] == "on_chain_end":
                    graphs_html = handle_retrieved_owid_graphs(event, graphs_html)

                # Handle query transformation
                if event["name"] == "transform_query" and event["event"] == "on_chain_end":
                    if hasattr(history[-1], "content"):
                        sub_questions = [q["question"] + " -> relevant sources: " + str(q["sources"]) for q in event["data"]["output"]["questions_list"]]
                        history[-1].content += "Decomposed the question into sub-questions:\n\n - " + "\n - ".join(sub_questions)

            yield history, docs_html, output_query, output_language, related_contents, graphs_html
    except Exception as e:
        print(f"Event {event} has failed")
        raise gr.Error(str(e))

    # Call the function to log interaction
    log_interaction_to_azure(history, output_query, sources, docs, share_client, user_id)

    yield history, docs_html, output_query, output_language, related_contents, graphs_html
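
# A minimal consumption sketch for chat_stream, assuming a compiled LangGraph
# agent and a configured Azure share client (both placeholders here), run from
# an async context:
#
#     history = [ChatMessage(role="user", content="Is sea level rising?")]
#     async for history, docs_html, query, lang, related, graphs in chat_stream(
#         agent, "Is sea level rising?", history, "expert",
#         ["IPCC"], [], ["figures"], False, share_client, "anon-1234",
#     ):
#         ...  # push the latest state to the UI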