|
from climateqa.engine.embeddings import get_embeddings_function
# Embedding function handed to the vector store created later in this module.
embeddings_function = get_embeddings_function()

from climateqa.papers.openalex import OpenAlex
from sentence_transformers import CrossEncoder

# Cross-encoder passed to `oa.rerank` in `find_papers` to score papers against the query.
reranker = CrossEncoder("mixedbread-ai/mxbai-rerank-xsmall-v1")
# OpenAlex client used by `find_papers` for search, reranking and the citation network.
oa = OpenAlex()
|
|
|
import gradio as gr |
|
import pandas as pd |
|
import numpy as np |
|
import os |
|
import time |
|
import re |
|
import json |
|
|
|
|
|
|
|
from io import BytesIO |
|
import base64 |
|
|
|
from datetime import datetime |
|
from azure.storage.fileshare import ShareServiceClient |
|
|
|
from utils import create_user_id |
|
|
|
|
|
|
|
|
|
from climateqa.engine.llm import get_llm |
|
from climateqa.engine.chains.answer_rag import make_rag_chain |
|
from climateqa.engine.vectorstore import get_pinecone_vectorstore |
|
from climateqa.engine.retriever import ClimateQARetriever |
|
from climateqa.engine.embeddings import get_embeddings_function |
|
from climateqa.engine.chains.prompts import audience_prompts |
|
from climateqa.sample_questions import QUESTIONS |
|
from climateqa.constants import POSSIBLE_REPORTS |
|
from climateqa.utils import get_image_from_azure_blob_storage |
|
from climateqa.engine.keywords import make_keywords_chain |
|
from climateqa.engine.chains.answer_rag import make_rag_papers_chain |
|
|
|
|
|
# Load variables from a local .env file when python-dotenv is installed.
# Any failure is ignored so the app can still start with variables that are
# already present in the process environment.
try:
    from dotenv import load_dotenv
    load_dotenv()
except Exception as e:
    pass


# Gradio theme passed to `gr.Blocks` below.
theme = gr.themes.Base(
    primary_hue="blue",
    secondary_hue="red",
    font=[gr.themes.GoogleFont("Poppins"), "ui-sans-serif", "system-ui", "sans-serif"],
)


# NOTE(review): `init_prompt` is empty at this point and is reassigned further
# down the file, so `system_template["content"]` is the empty string.
init_prompt = ""

system_template = {
    "role": "system",
    "content": init_prompt,
}

# Azure file-share settings are read from the environment; a missing variable
# raises KeyError at import time.
account_key = os.environ["BLOB_ACCOUNT_KEY"]
# NOTE(review): presumably restores base64 "==" padding that was stripped from
# the stored key -- confirm against the deployment secrets.
if len(account_key) == 86:
    account_key += "=="

credential = {
    "account_key": account_key,
    "account_name": os.environ["BLOB_ACCOUNT_NAME"],
}

account_url = os.environ["BLOB_ACCOUNT_URL"]
file_share_name = "climateqa"
service = ShareServiceClient(account_url=account_url, credential=credential)
# Share client handed to `log_on_azure` by `chat` and `save_feedback`.
share_client = service.get_share_client(file_share_name)

# Module-level identifier: the same value is logged by every `chat` call made
# in this process.
user_id = create_user_id()
|
|
|
|
|
|
|
def parse_output_llm_with_sources(output):
    """Replace ``[Doc N]`` citation markers in an LLM answer with HTML links.

    Markers such as ``[Doc 1]`` or ``[Doc 1, Doc 3]`` become one superscript
    anchor per cited document, each pointing at ``#doc<N>``. All other text
    is returned unchanged.

    Args:
        output: Raw answer text produced by the LLM.

    Returns:
        The answer with every citation marker replaced by its HTML anchors.
    """
    # With a single capturing group, re.split alternates plain text (even
    # indices) and captured citation groups (odd indices). Using the index
    # instead of `startswith("Doc")` keeps ordinary text that happens to begin
    # with "Doc" (e.g. "Doctors ...") from being rewritten as a citation.
    content_parts = re.split(r'\[(Doc\s?\d+(?:,\s?Doc\s?\d+)*)\]', output)
    parts = []
    for index, part in enumerate(content_parts):
        if index % 2 == 0:
            parts.append(part)
            continue
        numbers = [ref.lower().replace("doc", "").strip() for ref in part.split(",")]
        parts.append("".join(
            f"""<a href="#doc{number}" class="a-doc-ref" target="_self"><span class='doc-ref'><sup>{number}</sup></span></a>"""
            for number in numbers
        ))
    return "".join(parts)
|
|
|
|
|
|
|
# Pinecone vector store handed to the retriever that `chat` builds per request.
vectorstore = get_pinecone_vectorstore(embeddings_function)
# Shared LLM (OpenAI provider, max_tokens=1024, temperature=0.0) used by the
# RAG, keywords and papers chains in this file.
llm = get_llm(provider="openai",max_tokens = 1024,temperature = 0.0)
|
|
|
|
|
def make_pairs(lst):
    """Group consecutive items of an even-length sequence into 2-tuples.

    ``[a, b, c, d]`` becomes ``[(a, b), (c, d)]``. An odd-length input raises
    ``IndexError`` when the last item has no partner.
    """
    pairs = []
    for start in range(0, len(lst), 2):
        first, second = lst[start], lst[start + 1]
        pairs.append((first, second))
    return pairs
|
|
|
|
|
def serialize_docs(docs):
    """Convert document objects into plain dicts.

    Each item must expose ``page_content`` and ``metadata`` attributes; the
    result holds one ``{"page_content": ..., "metadata": ...}`` dict per
    document, in the same order.
    """
    return [
        {"page_content": doc.page_content, "metadata": doc.metadata}
        for doc in docs
    ]
|
|
|
|
|
|
|
async def chat(query,history,audience,sources,reports):
    """Stream an answer to ``query`` through the RAG chain.

    Builds a retriever and the chain returned by ``make_rag_chain``, consumes
    its streamed log operations (reformulation, keywords, retrieved documents,
    answer tokens) and yields after each relevant one so the UI refreshes
    progressively. After streaming, the interaction is logged (unless
    ``GRADIO_ENV`` is ``"local"``) and the first retrieved image, if any, is
    appended to the answer before a final yield.

    Args:
        query: The user's question.
        history: Chat history as (user, bot) pairs; the last pair is the
            pending turn whose answer is filled in as tokens arrive.
        audience: "Children", "General public" or "Experts"; any other value
            falls back to the experts prompt.
        sources: Selected sources; empty or None falls back to ``["IPCC"]``.
        reports: Selected reports; empty or None is passed on as ``[]``.

    Yields:
        ``(history, docs_html, output_query, output_language, gallery,
        output_query, output_keywords)``.

    Raises:
        gr.Error: If streaming the chain fails or logging the interaction fails.
    """

    print(f">> NEW QUESTION : {query}")

    audience_keys = {"Children": "children", "General public": "general", "Experts": "experts"}
    audience_prompt = audience_prompts[audience_keys.get(audience, "experts")]

    # `or` also covers None, for which the previous `len(...)` check raised TypeError.
    sources = sources or ["IPCC"]
    reports = reports or []

    retriever = ClimateQARetriever(vectorstore=vectorstore,sources = sources,min_size = 200,reports = reports,k_summary = 3,k_total = 15,threshold=0.5)
    rag_chain = make_rag_chain(retriever,llm)

    inputs = {"query": query,"audience": audience_prompt}
    result = rag_chain.astream_log(inputs)

    # Log paths of the chain steps whose output is surfaced in the UI.
    path_reformulation = "/logs/reformulation/final_output"
    path_keywords = "/logs/keywords/final_output"
    path_retriever = "/logs/find_documents/final_output"
    path_answer = "/logs/answer/streamed_output_str/-"

    # Initialised up front so the logging and image steps below still work when
    # no retrieval operation is ever received.
    docs = []
    docs_html = ""
    output_query = ""
    output_language = ""
    output_keywords = ""
    gallery = []

    try:
        async for op in result:

            op = op.ops[0]

            if op['path'] == path_reformulation:
                try:
                    output_language = op['value']["language"]
                    output_query = op["value"]["question"]
                except Exception as e:
                    raise gr.Error(f"ClimateQ&A Error: {e} - The error has been noted, try another question and if the error remains, you can contact us :)") from e

            # `elif` (previously a separate `if`) so that a reformulation
            # update reaches the yield below instead of the final `continue`.
            elif op["path"] == path_keywords:
                try:
                    output_keywords = " AND ".join(op['value']["keywords"])
                except Exception as e:
                    # Keywords only pre-fill the papers tab; keep streaming without them.
                    print(f"Could not read keywords: {e}")

            elif op['path'] == path_retriever:
                try:
                    retrieved = op['value']['docs']
                    rendered = "".join(make_html_source(d, i) for i, d in enumerate(retrieved, 1))
                except TypeError:
                    print("No documents found")
                    print("op: ",op)
                    continue
                # Commit only once rendering succeeded, so a failure leaves the
                # previous (valid) values in place.
                docs, docs_html = retrieved, rendered

            elif op['path'] == path_answer:
                new_token = op['value']

                previous_answer = history[-1][1]
                previous_answer = previous_answer if previous_answer is not None else ""
                # The accumulated answer is re-parsed so a citation marker
                # split across tokens is converted once it is complete.
                answer_yet = parse_output_llm_with_sources(previous_answer + new_token)
                history[-1] = (query,answer_yet)

            else:
                continue

            history = [tuple(x) for x in history]
            yield history,docs_html,output_query,output_language,gallery,output_query,output_keywords

    except Exception as e:
        raise gr.Error(f"{e}") from e

    try:
        if os.getenv("GRADIO_ENV") != "local":
            timestamp = str(datetime.now().timestamp())
            file = timestamp + ".json"
            prompt = history[-1][0]
            logs = {
                "user_id": str(user_id),
                "prompt": prompt,
                "query": prompt,
                "question":output_query,
                "sources":sources,
                "docs":serialize_docs(docs),
                "answer": history[-1][1],
                "time": timestamp,
            }
            log_on_azure(file, logs, share_client)
    except Exception as e:
        print(f"Error logging on Azure Blob Storage: {e}")
        raise gr.Error(f"ClimateQ&A Error: {str(e)[:100]} - The error has been noted, try another question and if the error remains, you can contact us :)") from e

    # Collect every retrieved image chunk; a failure on one image only skips it.
    image_dict = {}
    for i,doc in enumerate(docs):

        if doc.metadata["chunk_type"] == "image":
            try:
                key = f"Image {i+1}"
                image_path = doc.metadata["image_path"].split("documents/")[1]
                img = get_image_from_azure_blob_storage(image_path)

                # Embed the image inline as a base64 PNG data URI.
                buffered = BytesIO()
                img.save(buffered, format="PNG")
                img_str = base64.b64encode(buffered.getvalue()).decode()

                markdown_image = f"![Alt text](data:image/png;base64,{img_str})"
                image_dict[key] = {"img":img,"md":markdown_image,"caption":doc.page_content,"key":key,"figure_code":doc.metadata["figure_code"]}
            except Exception as e:
                print(f"Skipped adding image {i} because of {e}")

    if image_dict:

        gallery = [x["img"] for x in image_dict.values()]
        # Only the first image is appended to the answer; all go to the gallery.
        img = next(iter(image_dict.values()))
        img_md = img["md"]
        img_caption = img["caption"]
        img_code = img["figure_code"]
        if img_code != "N/A":
            img_name = f"{img['key']} - {img['figure_code']}"
        else:
            img_name = f"{img['key']}"

        # The answer is still None when no answer token was streamed.
        answer_yet = (history[-1][1] or "") + f"\n\n{img_md}\n<p class='chatbot-caption'><b>{img_name}</b> - {img_caption}</p>"
        history[-1] = (history[-1][0],answer_yet)
        history = [tuple(x) for x in history]

    yield history,docs_html,output_query,output_language,gallery,output_query,output_keywords
|
|
|
|
|
def make_html_source(source,i):
    """Render one retrieved document as an HTML card.

    Text chunks get a card with ``id="doc<i>"``; every other chunk type gets
    an image card labelled as an AI-generated description. The footer shows
    the table-of-contents levels (when present) followed by the document
    name, plus a link to the source page.

    Args:
        source: Object exposing ``metadata`` (mapping) and ``page_content``.
        i: 1-based position of the document, shown in the card title.

    Returns:
        The HTML of the card as a string.
    """
    meta = source.metadata
    content = source.page_content.strip()

    # Keep table-of-contents levels up to (not including) the first "N/A".
    breadcrumb = []
    for toc_key in ("toc_level0", "toc_level1"):
        level = meta[toc_key]
        if level == "N/A":
            break
        breadcrumb.append(level)
    toc_levels = " > ".join(breadcrumb)

    name = f"<b>{toc_levels}</b><br/>{meta['name']}" if toc_levels else meta['name']

    if meta["chunk_type"] == "text":
        return f"""
    <div class="card" id="doc{i}">
        <div class="card-content">
            <h2>Doc {i} - {meta['short_name']} - Page {int(meta['page_number'])}</h2>
            <p>{content}</p>
        </div>
        <div class="card-footer">
            <span>{name}</span>
            <a href="{meta['url']}#page={int(meta['page_number'])}" target="_blank" class="pdf-link">
                <span role="img" aria-label="Open PDF">🔗</span>
            </a>
        </div>
    </div>
    """

    # Non-text chunk: prefix the title with the figure code when one exists.
    if meta["figure_code"] == "N/A":
        title = f"{meta['short_name']}"
    else:
        title = f"{meta['figure_code']} - {meta['short_name']}"

    return f"""
    <div class="card card-image">
        <div class="card-content">
            <h2>Image {i} - {title} - Page {int(meta['page_number'])}</h2>
            <p>{content}</p>
            <p class='ai-generated'>AI-generated description</p>
        </div>
        <div class="card-footer">
            <span>{name}</span>
            <a href="{meta['url']}#page={int(meta['page_number'])}" target="_blank" class="pdf-link">
                <span role="img" aria-label="Open PDF">🔗</span>
            </a>
        </div>
    </div>
    """
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def save_feedback(feed: str, user_id):
    """Store user feedback as a JSON file on the Azure file share.

    Args:
        feed: Feedback text. Missing feedback or text of at most one
            character is ignored.
        user_id: Identifier of the user; converted with ``str()`` so that
            non-string ids can be used in the file name and serialized,
            matching how ``chat`` logs the id.

    Returns:
        A confirmation message when the feedback was uploaded, otherwise None.
    """
    # Guard clause: nothing to store for None, empty or one-character input.
    if not feed or len(feed) <= 1:
        return None

    user_id = str(user_id)
    timestamp = str(datetime.now().timestamp())
    file = user_id + timestamp + ".json"
    logs = {
        "user_id": user_id,
        "feedback": feed,
        "time": timestamp,
    }
    log_on_azure(file, logs, share_client)
    return "Feedback submitted, thank you!"
|
|
|
|
|
|
|
|
|
def log_on_azure(file, logs, share_client):
    """Upload ``logs`` as JSON text to ``file`` through ``share_client``.

    Args:
        file: Name of the target file on the share.
        logs: JSON-serializable object to store.
        share_client: Object providing ``get_file_client(name)``, whose result
            provides ``upload_file(data)``.
    """
    payload = json.dumps(logs)
    share_client.get_file_client(file).upload_file(payload)
|
|
|
|
|
def generate_keywords(query):
    """Return the keywords produced for ``query``, joined with `` AND ``.

    Invokes the chain built by ``make_keywords_chain`` with the shared LLM and
    reads the ``"keywords"`` entry of its result.
    """
    result = make_keywords_chain(llm).invoke(query)
    return " AND ".join(result["keywords"])
|
|
|
|
|
|
|
# Columns of the papers table, in display order, each paired with its width
# (the unit is not shown in this file).
_papers_layout = (
    ("doc", 50),
    ("id", 100),
    ("title", 300),
    ("doi", 100),
    ("publication_year", 100),
    ("abstract", 500),
    ("rerank_score", 100),
    ("is_oa", 50),
)

# Parallel lists: column names and their widths.
papers_cols = [column for column, _width in _papers_layout]
papers_cols_widths = [width for _column, width in _papers_layout]
|
|
|
async def find_papers(query, keywords,after):
    """Search papers, build their citation network and stream a summary.

    Searches OpenAlex with ``keywords`` (restricted by ``after``), drops works
    without an abstract, reranks them against ``query`` and renders the
    citation network inside an iframe. The table and network are yielded once
    with an empty summary, then again each time a summary token arrives.

    Yields:
        ``(papers_table, network_html, summary)``.
    """
    summary = ""

    works = oa.search(keywords,after = after).dropna(subset=["abstract"])
    works = oa.rerank(query,works,reranker).sort_values("rerank_score",ascending=False)
    graph = oa.make_network(works)

    height = "750px"
    network = oa.show_network(graph,color_by = "rerank_score",notebook=False,height = height)

    # Single quotes are swapped for double quotes because the page is embedded
    # in a single-quoted `srcdoc` attribute below.
    page = network.generate_html().replace("'", "\"")
    page = page + "<style>#mynetwork { border: none !important; } .card { border: none !important; }</style>"

    network_html = f"""<iframe style="width: 100%; height: {height};margin:0 auto" name="result" allow="midi; geolocation; microphone; camera;
    display-capture; encrypted-media;" sandbox="allow-modals allow-forms
    allow-scripts allow-same-origin allow-popups
    allow-top-navigation-by-user-activation allow-downloads" allowfullscreen=""
    allowpaymentrequest="" frameborder="0" srcdoc='{page}'></iframe>"""

    # The 15 best-ranked papers feed the summary chain.
    docs = works["content"].head(15).tolist()

    # Add a 1-based "doc" rank column and keep only the displayed columns.
    works = works.reset_index(drop = True).reset_index().rename(columns = {"index":"doc"})
    works["doc"] = works["doc"] + 1
    works = works[papers_cols]

    yield works,network_html,summary

    stream = make_rag_papers_chain(llm).astream_log({"question": query,"docs": docs,"language":"English"})
    path_answer = "/logs/StrOutputParser/streamed_output/-"

    async for op in stream:
        op = op.ops[0]
        if op['path'] != path_answer:
            continue
        summary += op['value']
        yield works,network_html,summary
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Welcome message (Markdown) shown as the first bot message of the chatbot.
# This replaces the empty `init_prompt` assigned earlier in the file.
init_prompt = """
Hello, I am ClimateQ&A, a conversational assistant designed to help you understand climate change and biodiversity loss. I will answer your questions by **sifting through the IPCC and IPBES scientific reports**.

❓ How to use
- **Language**: You can ask me your questions in any language.
- **Audience**: You can specify your audience (children, general public, experts) to get a more adapted answer.
- **Sources**: You can choose to search in the IPCC or IPBES reports, or both.

⚠️ Limitations
*Please note that the AI is not perfect and may sometimes give irrelevant answers. If you are not satisfied with the answer, please ask a more specific question or report your feedback to help us improve the system.*

What do you want to learn ?
"""
|
|
|
|
|
def vote(data: gr.LikeData):
    """Print the liked message's value, or the whole event when it was disliked."""
    print(data.value if data.liked else data)
|
|
|
|
|
|
|
# Gradio interface: the main chat tab (chatbot plus a side panel with
# examples, sources and configuration), the "Figures", "Papers (beta)" and
# "About" tabs, followed by the event wiring.
with gr.Blocks(title="Climate Q&A", css="style.css", theme=theme,elem_id = "main-component") as demo:

    with gr.Tab("ClimateQ&A"):

        with gr.Row(elem_id="chatbot-row"):
            with gr.Column(scale=2):

                # The conversation starts with the welcome message as a bot turn.
                chatbot = gr.Chatbot(
                    value=[(None,init_prompt)],
                    show_copy_button=True,show_label = False,elem_id="chatbot",layout = "panel",
                    avatar_images = (None,"https://i.ibb.co/YNyd5W2/logo4.png"),
                )

                with gr.Row(elem_id = "input-message"):
                    textbox=gr.Textbox(placeholder="Ask me anything here!",show_label=False,scale=7,lines = 1,interactive = True,elem_id="input-textbox")

            with gr.Column(scale=1, variant="panel",elem_id = "right-panel"):

                with gr.Tabs() as tabs:
                    with gr.TabItem("Examples",elem_id = "tab-examples",id = 0):

                        # Hidden textbox filled by the example widgets; its change
                        # event starts a chat (wired below).
                        examples_hidden = gr.Textbox(visible = False)
                        first_key = list(QUESTIONS.keys())[0]
                        dropdown_samples = gr.Dropdown(QUESTIONS.keys(),value = first_key,interactive = True,show_label = True,label = "Select a category of sample questions",elem_id = "dropdown-samples")

                        # One row of examples per category; only the first is
                        # visible until `change_sample_questions` switches it.
                        samples = []
                        for i,key in enumerate(QUESTIONS.keys()):

                            examples_visible = True if i == 0 else False

                            with gr.Row(visible = examples_visible) as group_examples:

                                examples_questions = gr.Examples(
                                    QUESTIONS[key],
                                    [examples_hidden],
                                    examples_per_page=8,
                                    run_on_click=False,
                                    elem_id=f"examples{i}",
                                    api_name=f"examples{i}",
                                )

                            samples.append(group_examples)

                    with gr.Tab("Sources",elem_id = "tab-citations",id = 1):
                        sources_textbox = gr.HTML(show_label=False, elem_id="sources-textbox")
                        # NOTE(review): not referenced anywhere else in the visible source.
                        docs_textbox = gr.State("")

                    with gr.Tab("Configuration",elem_id = "tab-config",id = 2):

                        gr.Markdown("Reminder: You can talk in any language, ClimateQ&A is multi-lingual!")

                        dropdown_sources = gr.CheckboxGroup(
                            ["IPCC", "IPBES","IPOS"],
                            label="Select source",
                            value=["IPCC"],
                            interactive=True,
                        )

                        dropdown_reports = gr.Dropdown(
                            POSSIBLE_REPORTS,
                            label="Or select specific reports",
                            multiselect=True,
                            value=None,
                            interactive=True,
                        )

                        dropdown_audience = gr.Dropdown(
                            ["Children","General public","Experts"],
                            label="Select audience",
                            value="Experts",
                            interactive=True,
                        )

                        # Read-only fields filled from the values yielded by `chat`.
                        output_query = gr.Textbox(label="Query used for retrieval",show_label = True,elem_id = "reformulated-query",lines = 2,interactive = False)
                        output_language = gr.Textbox(label="Language",show_label = True,elem_id = "language",lines = 1,interactive = False)

    with gr.Tab("Figures",elem_id = "tab-images",elem_classes = "max-height other-tabs"):
        gallery_component = gr.Gallery()

    with gr.Tab("Papers (beta)",elem_id = "tab-papers",elem_classes = "max-height other-tabs"):

        with gr.Row():
            with gr.Column(scale=1):
                # Both textboxes are also outputs of `chat`, which pre-fills
                # them with the reformulated question and extracted keywords.
                query_papers = gr.Textbox(placeholder="Question",show_label=False,lines = 1,interactive = True,elem_id="query-papers")
                keywords_papers = gr.Textbox(placeholder="Keywords",show_label=False,lines = 1,interactive = True,elem_id="keywords-papers")
                after = gr.Slider(minimum=1950,maximum=2023,step=1,value=1960,label="Publication date",show_label=True,interactive=True,elem_id="date-papers")
                search_papers = gr.Button("Search",elem_id="search-papers",interactive=True)

            with gr.Column(scale=7):

                with gr.Tab("Summary",elem_id="papers-summary-tab"):
                    papers_summary = gr.Markdown(visible=True,elem_id="papers-summary")

                with gr.Tab("Relevant papers",elem_id="papers-results-tab"):
                    papers_dataframe = gr.Dataframe(visible=True,elem_id="papers-table",headers = papers_cols)

                with gr.Tab("Citations network",elem_id="papers-network-tab"):
                    citations_network = gr.HTML(visible=True,elem_id="papers-citations-network")

    with gr.Tab("About",elem_classes = "max-height other-tabs"):
        with gr.Row():
            with gr.Column(scale=1):
                gr.Markdown("See more info at [https://climateqa.com](https://climateqa.com/docs/intro/)")

    def start_chat(query,history):
        # Append the pending (query, None) turn, lock the textbox and select
        # the tab with id 1 (the "Sources" tab).
        history = history + [(query,None)]
        history = [tuple(x) for x in history]
        return (gr.update(interactive = False),gr.update(selected=1),history)

    def finish_chat():
        # Unlock and clear the textbox once the chat step has run.
        return (gr.update(interactive = True,value = ""))

    # Question typed in the textbox: start_chat -> chat (streaming) -> finish_chat.
    (textbox
        .submit(start_chat, [textbox,chatbot], [textbox,tabs,chatbot],queue = False,api_name = "start_chat_textbox")
        .then(chat, [textbox,chatbot,dropdown_audience, dropdown_sources,dropdown_reports], [chatbot,sources_textbox,output_query,output_language,gallery_component,query_papers,keywords_papers],concurrency_limit = 8,api_name = "chat_textbox")
        .then(finish_chat, None, [textbox],api_name = "finish_chat_textbox")
    )

    # Same pipeline, triggered when an example question is selected.
    (examples_hidden
        .change(start_chat, [examples_hidden,chatbot], [textbox,tabs,chatbot],queue = False,api_name = "start_chat_examples")
        .then(chat, [examples_hidden,chatbot,dropdown_audience, dropdown_sources,dropdown_reports], [chatbot,sources_textbox,output_query,output_language,gallery_component,query_papers,keywords_papers],concurrency_limit = 8,api_name = "chat_examples")
        .then(finish_chat, None, [textbox],api_name = "finish_chat_examples")
    )

    def change_sample_questions(key):
        # Show only the examples row that matches the selected category.
        index = list(QUESTIONS.keys()).index(key)
        visible_bools = [False] * len(samples)
        visible_bools[index] = True
        return [gr.update(visible=visible_bools[i]) for i in range(len(samples))]

    dropdown_samples.change(change_sample_questions,dropdown_samples,samples)

    # Papers tab: submitting the question generates keywords; the button runs the search.
    query_papers.submit(generate_keywords,[query_papers], [keywords_papers])
    search_papers.click(find_papers,[query_papers,keywords_papers,after], [papers_dataframe,citations_network,papers_summary])
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Enable Gradio's request queue, then start the app server.
demo.queue()

demo.launch()
|
|