# Source snapshot: commit 0c4d82b by timeki,
# "Merge branch 'main' into feature/graph_recommandation" (31.6 kB).
from climateqa.engine.embeddings import get_embeddings_function
embeddings_function = get_embeddings_function()
from climateqa.knowledge.openalex import OpenAlex
from sentence_transformers import CrossEncoder
# reranker = CrossEncoder("mixedbread-ai/mxbai-rerank-xsmall-v1")
oa = OpenAlex()
import gradio as gr
import pandas as pd
import numpy as np
import os
import time
import re
import json
from gradio import ChatMessage
# from gradio_modal import Modal
from io import BytesIO
import base64
from datetime import datetime
from azure.storage.fileshare import ShareServiceClient
from utils import create_user_id
from langchain_chroma import Chroma
from collections import defaultdict
# ClimateQ&A imports
from climateqa.engine.llm import get_llm
from climateqa.engine.vectorstore import get_pinecone_vectorstore
# from climateqa.knowledge.retriever import ClimateQARetriever
from climateqa.engine.reranker import get_reranker
from climateqa.engine.embeddings import get_embeddings_function
from climateqa.engine.chains.prompts import audience_prompts
from climateqa.sample_questions import QUESTIONS
from climateqa.constants import POSSIBLE_REPORTS, OWID_CATEGORIES
from climateqa.utils import get_image_from_azure_blob_storage
from climateqa.engine.keywords import make_keywords_chain
# from climateqa.engine.chains.answer_rag import make_rag_papers_chain
from climateqa.engine.graph import make_graph_agent,display_graph
from climateqa.engine.embeddings import get_embeddings_function
from front.utils import make_html_source,parse_output_llm_with_sources,serialize_docs,make_toolbox,generate_html_graphs
from front.utils import make_html_source, make_html_figure_sources,parse_output_llm_with_sources,serialize_docs,make_toolbox
# Load environment variables from a local .env file, if python-dotenv is
# available. This is best-effort: any failure while importing or running it
# is ignored and the process environment is used as-is.
try:
    from dotenv import load_dotenv

    load_dotenv()
except Exception:
    pass
# Set up Gradio Theme
theme = gr.themes.Base(
    primary_hue="blue",
    secondary_hue="red",
    font=[gr.themes.GoogleFont("Poppins"), "ui-sans-serif", "system-ui", "sans-serif"],
)

init_prompt = ""

system_template = {
    "role": "system",
    "content": init_prompt,
}

# Azure file share used to store question and feedback logs.
account_key = os.environ["BLOB_ACCOUNT_KEY"]
if len(account_key) == 86:
    # Presumably a base64 account key whose trailing "==" padding was lost
    # (e.g. when stored as a secret); restore it.
    account_key += "=="

credential = {
    "account_key": account_key,
    "account_name": os.environ["BLOB_ACCOUNT_NAME"],
}

account_url = os.environ["BLOB_ACCOUNT_URL"]
file_share_name = "climateqa"
service = ShareServiceClient(account_url=account_url, credential=credential)
share_client = service.get_share_client(file_share_name)

user_id = create_user_id()

# Create vectorstores, LLM, reranker and the LangGraph agent.
vectorstore = get_pinecone_vectorstore(embeddings_function)

# Vectorstore of recommendable graphs (presumably OWID charts) used by the
# agent's graph-retrieval nodes. The agent construction below needs it, and it
# was never defined before.
# NOTE(review): the persist directory (and default "langchain" collection) are
# assumptions -- set VECTORSTORE_GRAPHS_PATH to the real Chroma store.
vectorstore_graphs = Chroma(
    persist_directory=os.getenv(
        "VECTORSTORE_GRAPHS_PATH",
        os.path.join(os.getcwd(), "data", "vectorstore_owid"),
    ),
    embedding_function=embeddings_function,
)

llm = get_llm(provider="openai", max_tokens=1024, temperature=0.0)
reranker = get_reranker("large")

# Build the agent once, with keyword arguments matching its signature.
agent = make_graph_agent(
    llm=llm,
    vectorstore_ipcc=vectorstore,
    vectorstore_graphs=vectorstore_graphs,
    reranker=reranker,
)
# Maps the audience label selected in the UI to its key in `audience_prompts`;
# any other label falls back to the expert prompt.
_AUDIENCE_PROMPT_KEYS = {
    "Children": "children",
    "General public": "general",
    "Experts": "experts",
}


def _message_content(message):
    """Return the text content of a chat history entry.

    The history may mix plain dicts (messages handed in by the Chatbot
    component, indexed as ``message["content"]``) with ``ChatMessage`` objects
    appended by ``chat``; both shapes are accepted.
    """
    if isinstance(message, dict):
        return message.get("content")
    return message.content


def _step_title(message):
    """Return the step title stored in a ``ChatMessage``'s metadata, or None.

    Plain dict entries (which have no ``metadata`` attribute) are treated as
    having no step title. Missing metadata or a missing ``"title"`` key also
    yields None instead of raising.
    """
    if isinstance(message, dict):
        return None
    metadata = getattr(message, "metadata", None)
    if isinstance(metadata, dict):
        return metadata.get("title")
    return getattr(metadata, "title", None)


def _graphs_from_recommended_content(recommended_content):
    """Turn graph-retrieval output into a list of ``[embedding, category]`` pairs.

    Items are deduplicated on ``metadata["returned_content"]`` (first
    occurrence wins) and grouped by ``metadata["category"]``; categories are
    listed in order of first appearance.
    """
    embeddings_by_category = defaultdict(list)
    seen_embeddings = set()
    for item in recommended_content:
        embedding = item.metadata["returned_content"]
        if embedding in seen_embeddings:
            continue
        seen_embeddings.add(embedding)
        embeddings_by_category[item.metadata["category"]].append(embedding)
    return [
        [embedding, category]
        for category, embeddings in embeddings_by_category.items()
        for embedding in embeddings
    ]


def _figures_from_docs(docs):
    """Fetch the image chunks in *docs* and render them as figure cards.

    Returns ``(html, images)``: the concatenated HTML cards and the loaded
    images (for the gallery), in document order. Images that fail to load or
    render are skipped with a printed message.
    """
    html_parts = []
    images = []
    image_docs = [d for d in docs if d.metadata["chunk_type"] == "image"]
    for i, doc in enumerate(image_docs):
        try:
            image_path = doc.metadata["image_path"].split("documents/")[1]
            img = get_image_from_azure_blob_storage(image_path)
            # Encode as base64 PNG; presumably inlined by make_html_figure_sources.
            buffered = BytesIO()
            img.save(buffered, format="PNG")
            img_str = base64.b64encode(buffered.getvalue()).decode()
            html_parts.append(make_html_figure_sources(doc, i, img_str))
            images.append(img)
        except Exception as e:
            print(f"Skipped adding image {i} because of {e}")
    return "".join(html_parts), images


async def chat(query, history, audience, sources, reports, current_graphs):
    """Answer *query* with the LangGraph agent while streaming UI updates.

    Streams the agent's events (intent categorization, query transformation,
    document and graph retrieval, answer generation). After every event, and
    once more at the end, it yields the 7-tuple
    ``(history, docs_html, output_query, output_language, gallery, figures,
    current_graphs)`` expected by the outputs wired to this handler.

    Args:
        query: The user question.
        history: Chat history; assistant messages are appended/updated in place.
        audience: "Children", "General public" or "Experts"; any other value
            uses the expert prompt.
        sources: Report families to search; defaults to IPCC, IPBES and IPOS.
        reports: Specific reports selected in the UI. NOTE(review): not
            forwarded to the agent, so it currently has no effect.
        current_graphs: Ignored; recommendations are rebuilt for every question.

    Raises:
        gr.Error: If the agent stream fails or the answer cannot be logged.
    """
    date_now = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
    print(f">> NEW QUESTION ({date_now}) : {query}")

    audience_prompt = audience_prompts[_AUDIENCE_PROMPT_KEYS.get(audience, "experts")]

    # Prepare default values
    if not sources:
        sources = ["IPCC", "IPBES", "IPOS"]

    inputs = {"user_input": query, "audience": audience_prompt, "sources": sources}
    result = agent.astream_events(inputs, version="v1")

    docs = []
    docs_html = ""
    current_graphs = []  # recommendations are rebuilt for every question
    output_query = ""
    output_language = ""
    gallery = []
    start_streaming = False
    figures = '<div class="figures-container"> <p> Go to the "Figures" tab at the top of the page to see full size images </p> </div>'
    # Agent nodes announced as a "step" message in the chat when they start:
    # node name -> (step title, display-output flag; the flag is unused here).
    steps_display = {
        "categorize_intent": ("🔄️ Analyzing user message", True),
        "transform_query": ("🔄️ Thinking step by step to answer the question", True),
        "retrieve_documents": ("🔄️ Searching in the knowledge base", False),
    }
    used_documents = []
    raw_answer = ""  # unparsed LLM output accumulated across streamed chunks
    event = None  # lets the error handler report even if no event was received

    try:
        async for event in result:
            node = event["metadata"].get("langgraph_node")
            name = event["name"]
            kind = event["event"]

            if kind == "on_chain_end" and name == "retrieve_documents":
                # Documents retrieved: render the textual ones as sources.
                try:
                    docs = event["data"]["output"]["documents"]
                    textual_docs = [d for d in docs if d.metadata["chunk_type"] == "text"]
                    docs_html = "".join(
                        make_html_source(d, i) for i, d in enumerate(textual_docs, 1)
                    )
                    used_documents.extend(
                        f"{d.metadata['short_name']} - {d.metadata['name']}" for d in docs
                    )
                    history[-1].content = "Adding sources :\n\n - " + "\n - ".join(np.unique(used_documents))
                except Exception as e:
                    print(f"Error getting documents: {e}")
                    print(event)

            elif name in steps_display and kind == "on_chain_start":
                # A displayed step begins: open a new step message unless the
                # last message already is this step.
                event_description, _display_output = steps_display[name]
                if _step_title(history[-1]) != event_description:
                    history.append(ChatMessage(role="assistant", content="", metadata={"title": event_description}))

            elif (
                name != "transform_query"
                and kind == "on_chat_model_stream"
                and node in ("answer_rag", "answer_search", "answer_chitchat")
            ):
                # Streaming answer tokens.
                if not start_streaming:
                    start_streaming = True
                    history.append(ChatMessage(role="assistant", content=""))
                raw_answer += event["data"]["chunk"].content
                # Always parse the accumulated raw text, never already-parsed output.
                history[-1] = ChatMessage(role="assistant", content=parse_output_llm_with_sources(raw_answer))

            elif name in ("retrieve_graphs", "retrieve_graphs_ai") and kind == "on_chain_end":
                try:
                    recommended_content = event["data"]["output"]["recommended_content"]
                    current_graphs.extend(_graphs_from_recommended_content(recommended_content))
                except Exception as e:
                    print(f"Error getting graphs: {e}")

            if name == "transform_query" and kind == "on_chain_end":
                if hasattr(history[-1], "content"):
                    sub_questions = [q["question"] for q in event["data"]["output"]["remaining_questions"]]
                    history[-1].content += "Decompose question into sub-questions: \n\n - " + "\n - ".join(sub_questions)

            yield history, docs_html, output_query, output_language, gallery, figures, current_graphs

    except Exception as e:
        print(event, "has failed")
        raise gr.Error(f"{e}") from e

    try:
        # Log the answer on the Azure file share (skipped when running locally).
        if os.getenv("GRADIO_ENV") != "local":
            timestamp = str(datetime.now().timestamp())
            logs = {
                "user_id": str(user_id),
                # Log the question being answered now, not the first one of the conversation.
                "prompt": query,
                "query": query,
                "question": output_query,
                "sources": sources,
                "docs": serialize_docs(docs),
                "answer": _message_content(history[-1]),
                "time": timestamp,
            }
            log_on_azure(timestamp + ".json", logs, share_client)
    except Exception as e:
        print(f"Error logging on Azure Blob Storage: {e}")
        raise gr.Error(f"ClimateQ&A Error: {str(e)[:100]} - The error has been noted, try another question and if the error remains, you can contact us :)") from e

    figures_html, images = _figures_from_docs(docs)
    figures += figures_html
    gallery.extend(images)

    # Same 7 values as the in-loop yield: the UI wires 7 outputs to this handler.
    yield history, docs_html, output_query, output_language, gallery, figures, current_graphs
def save_feedback(feed: str, user_id):
    """Store user feedback on the Azure file share.

    Feedback that is empty, None, or a single character is ignored and None is
    returned. Otherwise the feedback is uploaded as
    ``<user_id><timestamp>.json`` and a confirmation message is returned.
    """
    if not feed or len(feed) <= 1:
        return None
    # User ids are logged as strings elsewhere in this app; they may not be str
    # (assumption, e.g. UUID objects), which would break "+" and json.dumps.
    user_id = str(user_id)
    timestamp = str(datetime.now().timestamp())
    logs = {
        "user_id": user_id,
        "feedback": feed,
        "time": timestamp,
    }
    log_on_azure(user_id + timestamp + ".json", logs, share_client)
    return "Feedback submitted, thank you!"
def log_on_azure(file, logs, share_client):
    """Serialize *logs* to JSON and upload it as *file* through *share_client*."""
    payload = json.dumps(logs)
    share_client.get_file_client(file).upload_file(payload)
def generate_keywords(query):
    """Ask the LLM keywords chain for keywords in *query* and join them with " AND "."""
    keywords_chain = make_keywords_chain(llm)
    extracted = keywords_chain.invoke(query)["keywords"]
    return " AND ".join(extracted)
# Columns of the papers table and their display widths, in display order.
_PAPERS_COLUMNS = [
    ("doc", 50),
    ("id", 100),
    ("title", 300),
    ("doi", 100),
    ("publication_year", 100),
    ("abstract", 500),
    ("rerank_score", 100),
    ("is_oa", 50),
]
papers_cols = [column for column, _width in _PAPERS_COLUMNS]
papers_cols_widths = [width for _column, width in _PAPERS_COLUMNS]
# --------------------------------------------------------------------
# Gradio
# --------------------------------------------------------------------

# Welcome message displayed as the chatbot's first assistant message.
init_prompt = """
Hello, I am ClimateQ&A, a conversational assistant designed to help you understand climate change and biodiversity loss. I will answer your questions by **sifting through the IPCC and IPBES scientific reports**.
❓ How to use
- **Language**: You can ask me your questions in any language.
- **Audience**: You can specify your audience (children, general public, experts) to get a more adapted answer.
- **Sources**: You can choose to search in the IPCC or IPBES reports, or both.
⚠️ Limitations
*Please note that the AI is not perfect and may sometimes give irrelevant answers. If you are not satisfied with the answer, please ask a more specific question or report your feedback to help us improve the system.*
🛈 Information
Please note that we log your questions for meta-analysis purposes, so avoid sharing any sensitive or personal information.
What do you want to learn ?
"""
def vote(data: gr.LikeData):
    """Print the liked message's value, or the whole event data for a dislike."""
    print(data.value if data.liked else data)
def save_graph(saved_graphs_state, embedding, category):
    """Add *embedding* under *category* in the saved-graphs state (no duplicates).

    The state dict is updated in place and returned together with a
    "Graph Saved" button to replace the clicked one.
    """
    print(f"\nCategory:\n{saved_graphs_state}\n")
    graphs_in_category = saved_graphs_state.setdefault(category, [])
    if embedding not in graphs_in_category:
        graphs_in_category.append(embedding)
    return saved_graphs_state, gr.Button("Graph Saved")
with gr.Blocks(title="Climate Q&A", css_paths=os.getcwd() + "/style.css", theme=theme, elem_id="main-component") as demo:
    # Per-session state.
    chat_completed_state = gr.State(0)  # flipped 0 <-> 1 after each chat from the textbox and examples
    current_graphs = gr.State([])  # [embedding, category] pairs produced by chat()
    saved_graphs = gr.State({})  # category -> saved embeddings (see save_graph)

    with gr.Tab("ClimateQ&A"):
        with gr.Row(elem_id="chatbot-row"):
            with gr.Column(scale=2):
                chatbot = gr.Chatbot(
                    value=[ChatMessage(role="assistant", content=init_prompt)],
                    type="messages",
                    show_copy_button=True,
                    show_label=False,
                    elem_id="chatbot",
                    layout="panel",
                    avatar_images=(None, "https://i.ibb.co/YNyd5W2/logo4.png"),
                    max_height="80vh",
                    height="100vh",
                )

                with gr.Row(elem_id="input-message"):
                    textbox = gr.Textbox(
                        placeholder="Ask me anything here!",
                        show_label=False,
                        scale=7,
                        lines=1,
                        interactive=True,
                        elem_id="input-textbox",
                    )

            with gr.Column(scale=1.5, variant="panel", elem_id="right-panel"):
                with gr.Tabs() as tabs:
                    with gr.TabItem("Examples", elem_id="tab-examples", id=0):
                        examples_hidden = gr.Textbox(visible=False)
                        first_key = list(QUESTIONS.keys())[0]
                        dropdown_samples = gr.Dropdown(
                            QUESTIONS.keys(),
                            value=first_key,
                            interactive=True,
                            show_label=True,
                            label="Select a category of sample questions",
                            elem_id="dropdown-samples",
                        )

                        # One group of example questions per category; only the
                        # first category's group is visible initially.
                        samples = []
                        for i, key in enumerate(QUESTIONS.keys()):
                            examples_visible = i == 0
                            with gr.Row(visible=examples_visible) as group_examples:
                                examples_questions = gr.Examples(
                                    QUESTIONS[key],
                                    [examples_hidden],
                                    examples_per_page=8,
                                    run_on_click=False,
                                    elem_id=f"examples{i}",
                                    api_name=f"examples{i}",
                                )
                            samples.append(group_examples)

                    with gr.Tab("Sources", elem_id="tab-citations", id=1):
                        sources_textbox = gr.HTML(show_label=False, elem_id="sources-textbox")
                        docs_textbox = gr.State("")

                    with gr.Tab("Configuration", elem_id="tab-config", id=2):
                        gr.Markdown("Reminders: You can talk in any language, ClimateQ&A is multi-lingual!")

                        dropdown_sources = gr.CheckboxGroup(
                            ["IPCC", "IPBES", "IPOS"],
                            label="Select source",
                            value=["IPCC"],
                            interactive=True,
                        )

                        dropdown_reports = gr.Dropdown(
                            POSSIBLE_REPORTS,
                            label="Or select specific reports",
                            multiselect=True,
                            value=None,
                            interactive=True,
                        )

                        dropdown_audience = gr.Dropdown(
                            ["Children", "General public", "Experts"],
                            label="Select audience",
                            value="Experts",
                            interactive=True,
                        )

                        output_query = gr.Textbox(label="Query used for retrieval", show_label=True, elem_id="reformulated-query", lines=2, interactive=False)
                        output_language = gr.Textbox(label="Language", show_label=True, elem_id="language", lines=1, interactive=False)

                    with gr.Tab("Recommended content", elem_id="tab-recommended_content", id=3) as recommended_content_tab:
                        graphs_container = gr.HTML("<h2>There are no graphs to be displayed at the moment. Try asking another question.</h2>")
                        # NOTE(review): current_graphs holds a list of
                        # [embedding, category] pairs, passed unchanged to this
                        # HTML component. front.utils.generate_html_graphs (imported,
                        # used by an older commented-out wiring) may be the intended
                        # renderer -- TODO confirm.
                        current_graphs.change(lambda x: x, inputs=[current_graphs], outputs=[graphs_container])

                    # Must not reuse id=3: finish_chat selects "Recommended content" by that id.
                    with gr.Tab("Figures", elem_id="tab-figures", id=4):
                        figures_cards = gr.HTML(show_label=False, elem_id="sources-figures")

    # ------------------------------------------------------------------
    # OTHER TABS
    # ------------------------------------------------------------------
    with gr.Tab("Figures", elem_id="tab-images", elem_classes="max-height other-tabs"):
        gallery_component = gr.Gallery()

    with gr.Tab("About", elem_classes="max-height other-tabs"):
        with gr.Row():
            with gr.Column(scale=1):
                gr.Markdown("See more info at [https://climateqa.com](https://climateqa.com/docs/intro/)")

    def start_chat(query, history):
        """Append the user question, lock the textbox and show the Sources tab (id 1)."""
        history = history + [ChatMessage(role="user", content=query)]
        return (gr.update(interactive=False), gr.update(selected=1), history)

    def finish_chat():
        """Clear and unlock the textbox, then show the Recommended content tab (id 3)."""
        return (gr.update(interactive=True, value=""), gr.update(selected=3))

    def change_completion_status(current_state):
        """Flip the 0/1 chat-completion flag."""
        return 1 - current_state

    (
        textbox
        .submit(start_chat, [textbox, chatbot], [textbox, tabs, chatbot], queue=False, api_name="start_chat_textbox")
        .then(
            chat,
            [textbox, chatbot, dropdown_audience, dropdown_sources, dropdown_reports, current_graphs],
            [chatbot, sources_textbox, output_query, output_language, gallery_component, figures_cards, current_graphs],
            concurrency_limit=8,
            api_name="chat_textbox",
        )
        .then(finish_chat, None, [textbox, tabs], api_name="finish_chat_textbox")
        .then(change_completion_status, [chat_completed_state], [chat_completed_state])
    )

    (
        examples_hidden
        .change(start_chat, [examples_hidden, chatbot], [textbox, tabs, chatbot], queue=False, api_name="start_chat_examples")
        .then(
            chat,
            [examples_hidden, chatbot, dropdown_audience, dropdown_sources, dropdown_reports, current_graphs],
            [chatbot, sources_textbox, output_query, output_language, gallery_component, figures_cards, current_graphs],
            concurrency_limit=8,
            api_name="chat_examples",
        )
        # finish_chat returns two updates, so both outputs must be listed.
        .then(finish_chat, None, [textbox, tabs], api_name="finish_chat_examples")
        .then(change_completion_status, [chat_completed_state], [chat_completed_state])
    )

    def change_sample_questions(key):
        """Show only the example group of the selected category."""
        index = list(QUESTIONS.keys()).index(key)
        return [gr.update(visible=(i == index)) for i in range(len(samples))]

    dropdown_samples.change(change_sample_questions, dropdown_samples, samples)
# Enable the request queue (returns the same Blocks) and start the server.
demo.queue().launch(ssr_mode=False)