|
from climateqa.engine.embeddings import get_embeddings_function |
|
# Embedding model built once at import time; both Pinecone vector stores created
# further down in this module are given this same instance.
embeddings_function = get_embeddings_function()
|
|
|
from sentence_transformers import CrossEncoder |
|
|
|
|
|
|
|
import gradio as gr |
|
from gradio_modal import Modal |
|
import pandas as pd |
|
import numpy as np |
|
import os |
|
import time |
|
import re |
|
import json |
|
|
|
from gradio import ChatMessage |
|
|
|
|
|
|
|
from io import BytesIO |
|
import base64 |
|
|
|
from datetime import datetime |
|
from azure.storage.fileshare import ShareServiceClient |
|
|
|
from utils import create_user_id |
|
|
|
from gradio_modal import Modal |
|
|
|
from PIL import Image |
|
|
|
from langchain_core.runnables.schema import StreamEvent |
|
|
|
|
|
from climateqa.engine.llm import get_llm |
|
from climateqa.engine.vectorstore import get_pinecone_vectorstore |
|
|
|
from climateqa.engine.reranker import get_reranker |
|
from climateqa.engine.embeddings import get_embeddings_function |
|
from climateqa.engine.chains.prompts import audience_prompts |
|
from climateqa.sample_questions import QUESTIONS |
|
from climateqa.constants import POSSIBLE_REPORTS, OWID_CATEGORIES |
|
from climateqa.utils import get_image_from_azure_blob_storage |
|
from climateqa.engine.graph import make_graph_agent |
|
from climateqa.engine.embeddings import get_embeddings_function |
|
from climateqa.engine.chains.retrieve_papers import find_papers |
|
|
|
from front.utils import serialize_docs,process_figures,make_html_df |
|
|
|
from climateqa.event_handler import init_audience, handle_retrieved_documents, stream_answer,handle_retrieved_owid_graphs |
|
|
|
|
|
# Best-effort loading of a local .env file. Any failure (python-dotenv not
# installed, unreadable file, ...) is deliberately ignored; the os.environ
# lookups below will fail loudly if a required variable is really missing.
try:
    from dotenv import load_dotenv

    load_dotenv()
except Exception:  # intentionally broad: this step is optional
    pass
|
|
|
import requests |
|
|
|
|
|
# Shared Gradio look & feel: blue primary / red secondary hues, Poppins font
# with generic sans-serif fallbacks.
theme = gr.themes.Base(
    font=[gr.themes.GoogleFont("Poppins"), "ui-sans-serif", "system-ui", "sans-serif"],
    primary_hue="blue",
    secondary_hue="red",
)

# Empty system prompt. NOTE(review): `init_prompt` is reassigned later in this
# module with the chatbot welcome text, but `system_template` keeps this empty
# string as its content.
init_prompt = ""
system_template = dict(role="system", content=init_prompt)
|
|
|
# Azure file share used by `log_on_azure` to persist chat logs and feedback.
account_key = os.environ["BLOB_ACCOUNT_KEY"]
if len(account_key) == 86:
    # 86 characters is an 88-character base64 key missing its "==" padding
    # (presumably stripped when the secret was stored) -- put it back.
    account_key = f"{account_key}=="

credential = dict(
    account_key=account_key,
    account_name=os.environ["BLOB_ACCOUNT_NAME"],
)
account_url = os.environ["BLOB_ACCOUNT_URL"]
file_share_name = "climateqa"

service = ShareServiceClient(account_url=account_url, credential=credential)
share_client = service.get_share_client(file_share_name)

# NOTE(review): created once at import time, so every session of this process
# logs under the same id -- confirm this is intended.
user_id = create_user_id()
|
|
|
|
|
# Accordion label and BibTeX entries shown (read-only, with a copy button) in the
# "About" tab. Raw string so the `\url{...}` backslash is kept literally.
CITATION_LABEL = "BibTeX citation for ClimateQ&A"
CITATION_TEXT = r"""@misc{climateqa,
author={Théo Alves Da Costa, Timothée Bohe},
title={ClimateQ&A, AI-powered conversational assistant for climate change and biodiversity loss},
year={2024},
howpublished= {\url{https://climateqa.com}},
}
@software{climateqa,
author = {Théo Alves Da Costa, Timothée Bohe},
publisher = {ClimateQ&A},
title = {ClimateQ&A, AI-powered conversational assistant for climate change and biodiversity loss},
}
"""
|
|
|
|
|
|
|
|
|
# Two Pinecone indexes sharing one embedding model: the main report index and
# the OWID graphs index, whose searchable text lives in the "description" field.
vectorstore = get_pinecone_vectorstore(
    embeddings_function,
    index_name=os.getenv("PINECONE_API_INDEX"),
)
vectorstore_graphs = get_pinecone_vectorstore(
    embeddings_function,
    index_name=os.getenv("PINECONE_API_INDEX_OWID"),
    text_key="description",
)

# Deterministic OpenAI LLM (temperature 0, at most 1024 output tokens) and a "nano" reranker.
llm = get_llm(provider="openai", max_tokens=1024, temperature=0.0)
reranker = get_reranker("nano")

# LangGraph agent that `chat` streams events from.
agent = make_graph_agent(
    llm=llm,
    vectorstore_ipcc=vectorstore,
    vectorstore_graphs=vectorstore_graphs,
    reranker=reranker,
)
|
|
|
|
|
async def chat(query, history, audience, sources, reports, relevant_content_sources, search_only):
    """Answer ``query`` by streaming the LangGraph agent's events into the UI.

    Async generator: every agent event is folded into the UI state, and after
    each event the current state is yielded as
    ``(history, docs_html, output_query, output_language, related_contents, graphs_html)``.
    Once the stream ends, the exchange is logged to Azure (unless
    ``GRADIO_ENV == "local"``) and the final state is yielded once more.

    Args:
        query: The user's question.
        history: Chatbot messages. Entries coming back from Gradio are presumably
            plain dicts, while entries appended here are ``ChatMessage`` objects;
            hence the ``hasattr``/``isinstance`` checks below.
        audience: Audience label, turned into a prompt by ``init_audience``.
        sources: Selected report families; defaults to IPCC, IPBES and IPOS when empty.
        reports: Specific reports selected in the UI (not forwarded to the agent here).
        relevant_content_sources: Extra databases to search (figures, OpenAlex, OWID).
        search_only: Forwarded to the agent as ``search_only``.

    Raises:
        gr.Error: If the agent stream fails or if logging to Azure fails.
    """
    date_now = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
    print(f">> NEW QUESTION ({date_now}) : {query}")

    audience_prompt = init_audience(audience)

    if not sources:
        sources = ["IPCC", "IPBES", "IPOS"]
    if not reports:
        # NOTE(review): `reports` is normalised but never passed to the agent -- confirm intent.
        reports = []

    inputs = {
        "user_input": query,
        "audience": audience_prompt,
        "sources_input": sources,
        "relevant_content_sources": relevant_content_sources,
        "search_only": search_only,
    }
    result = agent.astream_events(inputs, version="v1")

    docs = []
    related_contents = []
    used_documents = []
    docs_html = ""
    output_query = ""
    output_language = ""
    graphs_html = ""
    answer_message_content = ""
    start_streaming = False

    # Agent steps announced in the chat as a titled assistant message.
    steps_display = {
        "categorize_intent": ("🔄️ Analyzing user message", True),
        "transform_query": ("🔄️ Thinking step by step to answer the question", True),
        "retrieve_documents": ("🔄️ Searching in the knowledge base", False),
    }

    # Bound up front so that an event without "langgraph_node" metadata, or a
    # stream that fails before its first event, cannot raise NameError /
    # UnboundLocalError (the latter used to mask the real error in the handler).
    node = None
    event = None
    try:
        async for event in result:
            if "langgraph_node" in event["metadata"]:
                node = event["metadata"]["langgraph_node"]

            if event["event"] == "on_chain_end" and event["name"] == "retrieve_documents":
                docs, docs_html, history, used_documents, related_contents = handle_retrieved_documents(event, history, used_documents)

            elif event["event"] == "on_chain_end" and node == "categorize_intent" and event["name"] == "_write":
                output = event["data"]["output"]
                intent = output["intent"]
                output_language = output.get("language", "English")
                history[-1].content = f"Language identified : {output_language} \n Intent identified : {intent}"

            elif event["name"] in steps_display and event["event"] == "on_chain_start":
                # Index by the name that matched the membership test; `node` may
                # differ from it, which made `steps_display[node]` raise KeyError.
                event_description, _display_output = steps_display[event["name"]]
                # Dict entries have no `.metadata`; ChatMessage metadata may lack a title.
                last_metadata = getattr(history[-1], "metadata", None) or {}
                if last_metadata.get("title") != event_description:
                    history.append(ChatMessage(role="assistant", content="", metadata={"title": event_description}))

            elif event["name"] != "transform_query" and event["event"] == "on_chat_model_stream" and node in ["answer_rag", "answer_search", "answer_chitchat"]:
                history, start_streaming, answer_message_content = stream_answer(history, event, start_streaming, answer_message_content)

            elif event["name"] in ["retrieve_graphs", "retrieve_graphs_ai"] and event["event"] == "on_chain_end":
                graphs_html = handle_retrieved_owid_graphs(event, graphs_html)

            if event["name"] == "transform_query" and event["event"] == "on_chain_end":
                if hasattr(history[-1], "content"):
                    sub_questions = (q["question"] for q in event["data"]["output"]["remaining_questions"])
                    history[-1].content += "Decompose question into sub-questions: \n\n - " + "\n - ".join(sub_questions)

            yield history, docs_html, output_query, output_language, related_contents, graphs_html

    except Exception as e:
        print(event, "has failed")
        raise gr.Error(f"{e}") from e

    try:
        if os.getenv("GRADIO_ENV") != "local":
            timestamp = str(datetime.now().timestamp())
            file = timestamp + ".json"
            last_message = history[-1]
            answer = last_message["content"] if isinstance(last_message, dict) else last_message.content
            logs = {
                "user_id": str(user_id),
                # Log the question actually answered in this turn; `history[1]`
                # is the first question of the whole conversation.
                "prompt": query,
                "query": query,
                "question": output_query,
                "sources": sources,
                "docs": serialize_docs(docs),
                "answer": answer,
                "time": timestamp,
            }
            log_on_azure(file, logs, share_client)
    except Exception as e:
        print(f"Error logging on Azure Blob Storage: {e}")
        raise gr.Error(f"ClimateQ&A Error: {str(e)[:100]} - The error has been noted, try another question and if the error remains, you can contact us :)") from e

    yield history, docs_html, output_query, output_language, related_contents, graphs_html
|
|
|
|
|
def save_feedback(feed: str, user_id):
    """Upload a piece of user feedback to the Azure file share.

    Feedback that is missing (``None``/empty) or a single character long is not
    uploaded. The confirmation message is returned in every case.

    Args:
        feed: Feedback text typed by the user.
        user_id: Identifier of the user. It is converted with ``str()``, as ``chat``
            does, because it is used in the file name and in the JSON payload.

    Returns:
        The confirmation message shown in the UI.
    """
    if feed and len(feed) > 1:
        timestamp = str(datetime.now().timestamp())
        user_id = str(user_id)
        file = user_id + timestamp + ".json"
        logs = {
            "user_id": user_id,
            "feedback": feed,
            "time": timestamp,
        }
        log_on_azure(file, logs, share_client)
    return "Feedback submitted, thank you!"
|
|
|
|
|
|
|
|
|
def log_on_azure(file, logs, share_client):
    """Serialise ``logs`` as JSON and upload it as ``file`` through ``share_client``.

    Args:
        file: Name of the file to create on the share.
        logs: JSON-serialisable payload.
        share_client: Object exposing ``get_file_client(name)``; the returned
            client's ``upload_file`` receives the JSON text.
    """
    payload = json.dumps(logs)
    share_client.get_file_client(file).upload_file(payload)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Welcome message shown as the chatbot's first assistant message. It lists the
# same three report families (IPCC, IPBES, IPOS) offered in the Configuration
# tab and used as the default sources in `chat`.
init_prompt = """
Hello, I am ClimateQ&A, a conversational assistant designed to help you understand climate change and biodiversity loss. I will answer your questions by **sifting through the IPCC, IPBES and IPOS scientific reports**.

❓ How to use
- **Language**: You can ask me your questions in any language.
- **Audience**: You can specify your audience (children, general public, experts) to get a more adapted answer.
- **Sources**: You can choose to search in the IPCC, IPBES or IPOS reports, or any combination of them.

⚠️ Limitations
*Please note that the AI is not perfect and may sometimes give irrelevant answers. If you are not satisfied with the answer, please ask a more specific question or report your feedback to help us improve the system.*

🛈 Information
Please note that we log your questions for meta-analysis purposes, so avoid sharing any sensitive or personal information.


What do you want to learn?
"""
|
|
|
|
|
def vote(data: gr.LikeData):
    """Print chatbot like/dislike events to the console.

    For a like, only the liked value is printed; otherwise the whole event payload is printed.
    """
    print(data.value if data.liked else data)
|
|
|
def save_graph(saved_graphs_state, embedding, category):
    """Record ``embedding`` under ``category`` in the saved-graphs state.

    The state dict is mutated in place. Each embedding is stored at most once per category.

    Returns:
        The updated state and a button relabelled "Graph Saved".
    """
    print(f"\nCategory:\n{saved_graphs_state}\n")
    bucket = saved_graphs_state.setdefault(category, [])
    if embedding not in bucket:
        bucket.append(embedding)
    return saved_graphs_state, gr.Button("Graph Saved")
|
|
|
|
|
|
|
with gr.Blocks(title="Climate Q&A", css_paths=os.getcwd()+ "/style.css", theme=theme, elem_id="main-component") as demo:
    # Per-session state.
    chat_completed_state = gr.State(0)
    current_graphs = gr.State([])  # receives the graphs HTML yielded by `chat`
    saved_graphs = gr.State({})

    with gr.Tab("ClimateQ&A"):
        with gr.Row(elem_id="chatbot-row"):
            # Left column: the conversation.
            with gr.Column(scale=2):
                chatbot = gr.Chatbot(
                    value=[ChatMessage(role="assistant", content=init_prompt)],
                    type="messages",
                    show_copy_button=True,
                    show_label=False,
                    elem_id="chatbot",
                    layout="panel",
                    avatar_images=(None, "https://i.ibb.co/YNyd5W2/logo4.png"),
                    max_height="80vh",
                    height="100vh",
                )

                with gr.Row(elem_id="input-message"):
                    textbox = gr.Textbox(placeholder="Ask me anything here!", show_label=False, scale=7, lines=1, interactive=True, elem_id="input-textbox")

            # Right column: examples, configuration, sources and recommended content.
            with gr.Column(scale=2, variant="panel", elem_id="right-panel"):
                with gr.Tabs(elem_id="right_panel_tab") as tabs:
                    with gr.TabItem("Examples", elem_id="tab-examples", id=0):
                        # Hidden textbox filled by the example buttons; its change event starts a chat.
                        examples_hidden = gr.Textbox(visible=False)
                        first_key = list(QUESTIONS.keys())[0]
                        dropdown_samples = gr.Dropdown(QUESTIONS.keys(), value=first_key, interactive=True, show_label=True, label="Select a category of sample questions", elem_id="dropdown-samples")

                        # One row of examples per category; only the first is visible initially.
                        samples = []
                        for i, key in enumerate(QUESTIONS.keys()):
                            with gr.Row(visible=(i == 0)) as group_examples:
                                examples_questions = gr.Examples(
                                    QUESTIONS[key],
                                    [examples_hidden],
                                    examples_per_page=8,
                                    run_on_click=False,
                                    elem_id=f"examples{i}",
                                    api_name=f"examples{i}",
                                )
                            samples.append(group_examples)

                    with gr.Tab("Configuration", id=10) as tab_config:
                        gr.Markdown("Reminders: You can talk in any language, ClimateQ&A is multi-lingual!")

                        with gr.Row():
                            dropdown_sources = gr.CheckboxGroup(
                                ["IPCC", "IPBES", "IPOS"],
                                label="Select source",
                                value=["IPCC"],
                                interactive=True,
                            )
                            dropdown_external_sources = gr.CheckboxGroup(
                                ["IPCC figures", "OpenAlex", "OurWorldInData"],
                                label="Select database to search for relevant content",
                                value=["IPCC figures"],
                                interactive=True,
                            )

                        dropdown_reports = gr.Dropdown(
                            POSSIBLE_REPORTS,
                            label="Or select specific reports",
                            multiselect=True,
                            value=None,
                            interactive=True,
                        )

                        search_only = gr.Checkbox(label="Search only without chating", value=False, interactive=True, elem_id="checkbox-chat")

                        dropdown_audience = gr.Dropdown(
                            ["Children", "General public", "Experts"],
                            label="Select audience",
                            value="Experts",
                            interactive=True,
                        )

                        # Publication-date filter passed to `find_papers`; shown only when
                        # OpenAlex is selected. The upper bound follows the current year
                        # instead of a hard-coded 2023.
                        after = gr.Slider(minimum=1950, maximum=datetime.now().year, step=1, value=1960, label="Publication date", show_label=True, interactive=True, elem_id="date-papers", visible=False)

                        output_query = gr.Textbox(label="Query used for retrieval", show_label=True, elem_id="reformulated-query", lines=2, interactive=False, visible=False)
                        output_language = gr.Textbox(label="Language", show_label=True, elem_id="language", lines=1, interactive=False, visible=False)

                        dropdown_external_sources.change(
                            lambda selected: gr.update(visible="OpenAlex" in selected),
                            inputs=[dropdown_external_sources],
                            outputs=[after],
                        )

                    with gr.Tab("Sources", elem_id="tab-sources", id=1) as tab_sources:
                        sources_textbox = gr.HTML(show_label=False, elem_id="sources-textbox")
                        docs_textbox = gr.State("")

                    with gr.Tab("Recommended content", elem_id="tab-recommended_content", id=2) as tab_recommended_content:
                        with gr.Tabs(elem_id="group-subtabs") as tabs_recommended_content:
                            with gr.Tab("Figures", elem_id="tab-figures", id=3) as tab_figures:
                                sources_raw = gr.State()

                                with Modal(visible=False, elem_id="modal_figure_galery") as figures_modal:
                                    gallery_component = gr.Gallery(object_fit='scale-down', elem_id="gallery-component", height="80vh")

                                show_full_size_figures = gr.Button("Show figures in full size", elem_id="show-figures", interactive=True)
                                show_full_size_figures.click(lambda: Modal(visible=True), None, figures_modal)

                                figures_cards = gr.HTML(show_label=False, elem_id="sources-figures")

                            with gr.Tab("Papers", elem_id="tab-citations", id=4) as tab_papers:
                                with gr.Accordion(visible=True, elem_id="papers-summary-popup", label="See summary of relevant papers", open=False) as summary_popup:
                                    papers_summary = gr.Markdown("", visible=True, elem_id="papers-summary")

                                with gr.Accordion(visible=True, elem_id="papers-relevant-popup", label="See relevant papers", open=False) as relevant_popup:
                                    papers_html = gr.HTML(show_label=False, elem_id="papers-textbox")

                                btn_citations_network = gr.Button("Explore papers citations network")

                                # Distinct name from the figures modal (both used to be `modal`).
                                with Modal(visible=False) as citations_modal:
                                    citations_network = gr.HTML("<h3>Citations Network Graph</h3>", visible=True, elem_id="papers-citations-network")
                                btn_citations_network.click(lambda: Modal(visible=True), None, citations_modal)

                            with gr.Tab("Graphs", elem_id="tab-graphs", id=5) as tab_graphs:
                                graphs_container = gr.HTML("<h2>There are no graphs to be displayed at the moment. Try asking another question.</h2>", elem_id="graphs-container")
                                current_graphs.change(lambda x: x, inputs=[current_graphs], outputs=[graphs_container])

    with gr.Tab("About", elem_classes="max-height other-tabs"):
        with gr.Row():
            with gr.Column(scale=1):
                gr.Markdown(
                    """
                    ### More info
                    - See more info at [https://climateqa.com](https://climateqa.com/docs/intro/)
                    - Feedbacks on this [form](https://forms.office.com/e/1Yzgxm6jbp)

                    ### Citation
                    """
                )
                with gr.Accordion(CITATION_LABEL, elem_id="citation", open=False):
                    gr.Textbox(
                        value=CITATION_TEXT,
                        label="",
                        interactive=False,
                        show_copy_button=True,
                        lines=len(CITATION_TEXT.split('\n')),
                    )

    def start_chat(query, history):
        """Append the user's message, lock the textbox and switch to the Sources tab (id=1)."""
        history = history + [ChatMessage(role="user", content=query)]
        return (gr.update(interactive=False), gr.update(selected=1), history)

    def finish_chat():
        """Re-enable and clear the textbox once the answer has been streamed."""
        return gr.update(interactive=True, value="")

    # NOTE(review): these flags are module globals and therefore shared by every
    # session; the toggles below are not wired to any event in the code visible here.
    summary_visible = False
    relevant_visible = False

    def toggle_summary_visibility():
        """Flip the global summary-visibility flag and return the matching update."""
        global summary_visible
        summary_visible = not summary_visible
        return gr.update(visible=summary_visible)

    def toggle_relevant_visibility():
        """Flip the global relevant-papers-visibility flag and return the matching update."""
        global relevant_visible
        relevant_visible = not relevant_visible
        return gr.update(visible=relevant_visible)

    def change_completion_status(current_state):
        """Toggle a 0/1 completion flag."""
        return 1 - current_state

    def update_sources_number_display(sources_textbox, figures_cards, current_graphs, papers_html):
        """Refresh the tab labels with the number of items each tab currently shows.

        Items are counted by HTML marker (``<h2>`` for sources, figures and papers,
        ``<iframe`` for graphs). A component without a value yet (``None``) counts as 0.
        """
        def count(value, marker):
            return (value or "").count(marker)

        sources_number = count(sources_textbox, "<h2>")
        figures_number = count(figures_cards, "<h2>")
        graphs_number = count(current_graphs, "<iframe")
        papers_number = count(papers_html, "<h2>")
        recommended_number = figures_number + graphs_number + papers_number

        return (
            gr.update(label=f"Recommended content ({recommended_number})"),
            gr.update(label=f"Sources ({sources_number})"),
            gr.update(label=f"Figures ({figures_number})"),
            gr.update(label=f"Graphs ({graphs_number})"),
            gr.update(label=f"Papers ({papers_number})"),
        )

    # Inputs/outputs of the `chat` step, shared by the textbox and example triggers.
    chat_settings = [dropdown_audience, dropdown_sources, dropdown_reports, dropdown_external_sources, search_only]
    chat_outputs = [chatbot, sources_textbox, output_query, output_language, sources_raw, current_graphs]

    (textbox
        .submit(start_chat, [textbox, chatbot], [textbox, tabs, chatbot], queue=False, api_name="start_chat_textbox")
        .then(chat, [textbox, chatbot] + chat_settings, chat_outputs, concurrency_limit=8, api_name="chat_textbox")
        .then(finish_chat, None, [textbox], api_name="finish_chat_textbox")
    )

    (examples_hidden
        .change(start_chat, [examples_hidden, chatbot], [textbox, tabs, chatbot], queue=False, api_name="start_chat_examples")
        # Own endpoint name; it previously reused "chat_textbox" and was silently renamed by Gradio.
        .then(chat, [examples_hidden, chatbot] + chat_settings, chat_outputs, concurrency_limit=8, api_name="chat_examples")
        .then(finish_chat, None, [textbox], api_name="finish_chat_examples")
    )

    def change_sample_questions(key):
        """Show only the example row that belongs to the selected category."""
        index = list(QUESTIONS.keys()).index(key)
        return [gr.update(visible=(i == index)) for i in range(len(samples))]

    sources_raw.change(process_figures, inputs=[sources_raw], outputs=[figures_cards, gallery_component])

    # Recount the tab labels whenever any counted component changes.
    counted_components = [sources_textbox, figures_cards, current_graphs, papers_html]
    counter_tabs = [tab_recommended_content, tab_sources, tab_figures, tab_graphs, tab_papers]
    for component in counted_components:
        component.change(update_sources_number_display, counted_components, counter_tabs)

    dropdown_samples.change(change_sample_questions, dropdown_samples, samples)

    textbox.submit(find_papers, [textbox, after, dropdown_external_sources], [papers_html, citations_network, papers_summary])
    examples_hidden.change(find_papers, [examples_hidden, after, dropdown_external_sources], [papers_html, citations_network, papers_summary])

    demo.queue()

demo.launch(ssr_mode=False)
|
|