from climateqa.engine.embeddings import get_embeddings_function embeddings_function = get_embeddings_function() from climateqa.knowledge.openalex import OpenAlex from sentence_transformers import CrossEncoder # reranker = CrossEncoder("mixedbread-ai/mxbai-rerank-xsmall-v1") oa = OpenAlex() import gradio as gr from gradio_modal import Modal import pandas as pd import numpy as np import os import time import re import json from gradio import ChatMessage # from gradio_modal import Modal from io import BytesIO import base64 from datetime import datetime from azure.storage.fileshare import ShareServiceClient from utils import create_user_id from gradio_modal import Modal from PIL import Image from langchain_core.runnables.schema import StreamEvent # ClimateQ&A imports from climateqa.engine.llm import get_llm from climateqa.engine.vectorstore import get_pinecone_vectorstore # from climateqa.knowledge.retriever import ClimateQARetriever from climateqa.engine.reranker import get_reranker from climateqa.engine.embeddings import get_embeddings_function from climateqa.engine.chains.prompts import audience_prompts from climateqa.sample_questions import QUESTIONS from climateqa.constants import POSSIBLE_REPORTS, OWID_CATEGORIES from climateqa.utils import get_image_from_azure_blob_storage from climateqa.engine.keywords import make_keywords_chain from climateqa.engine.chains.answer_rag import make_rag_papers_chain from climateqa.engine.graph import make_graph_agent from climateqa.engine.embeddings import get_embeddings_function from front.utils import serialize_docs,process_figures,make_html_df from climateqa.event_handler import init_audience, handle_retrieved_documents, stream_answer,handle_retrieved_owid_graphs # Load environment variables in local mode try: from dotenv import load_dotenv load_dotenv() except Exception as e: pass # Set up Gradio Theme theme = gr.themes.Base( primary_hue="blue", secondary_hue="red", font=[gr.themes.GoogleFont("Poppins"), "ui-sans-serif", "system-ui", "sans-serif"], ) init_prompt = "" system_template = { "role": "system", "content": init_prompt, } account_key = os.environ["BLOB_ACCOUNT_KEY"] if len(account_key) == 86: account_key += "==" credential = { "account_key": account_key, "account_name": os.environ["BLOB_ACCOUNT_NAME"], } account_url = os.environ["BLOB_ACCOUNT_URL"] file_share_name = "climateqa" service = ShareServiceClient(account_url=account_url, credential=credential) share_client = service.get_share_client(file_share_name) user_id = create_user_id() CITATION_LABEL = "BibTeX citation for ClimateQ&A" CITATION_TEXT = r"""@misc{climateqa, author={Théo Alves Da Costa, Timothée Bohe}, title={ClimateQ&A, AI-powered conversational assistant for climate change and biodiversity loss}, year={2024}, howpublished= {\url{https://climateqa.com}}, } @software{climateqa, author = {Théo Alves Da Costa, Timothée Bohe}, publisher = {ClimateQ&A}, title = {ClimateQ&A, AI-powered conversational assistant for climate change and biodiversity loss}, } """ # Create vectorstore and retriever vectorstore = get_pinecone_vectorstore(embeddings_function, index_name = os.getenv("PINECONE_API_INDEX")) vectorstore_graphs = get_pinecone_vectorstore(embeddings_function, index_name = os.getenv("PINECONE_API_INDEX_OWID"), text_key="title") llm = get_llm(provider="openai",max_tokens = 1024,temperature = 0.0) reranker = get_reranker("nano") agent = make_graph_agent(llm=llm, vectorstore_ipcc=vectorstore, vectorstore_graphs=vectorstore_graphs, reranker=reranker) async def chat(query,history,audience,sources,reports,current_graphs): """taking a query and a message history, use a pipeline (reformulation, retriever, answering) to yield a tuple of: (messages in gradio format, messages in langchain format, source documents)""" date_now = datetime.now().strftime("%Y-%m-%d %H:%M:%S") print(f">> NEW QUESTION ({date_now}) : {query}") audience_prompt = init_audience(audience) # Prepare default values if sources is None or len(sources) == 0: sources = ["IPCC", "IPBES", "IPOS"] if reports is None or len(reports) == 0: reports = [] inputs = {"user_input": query,"audience": audience_prompt,"sources_input":sources} result = agent.astream_events(inputs,version = "v1") docs = [] used_figures=[] docs_html = "" output_query = "" output_language = "" output_keywords = "" start_streaming = False graphs_html = "" figures = '

' steps_display = { "categorize_intent":("🔄️ Analyzing user message",True), "transform_query":("🔄️ Thinking step by step to answer the question",True), "retrieve_documents":("🔄️ Searching in the knowledge base",False), } used_documents = [] answer_message_content = "" try: async for event in result: if "langgraph_node" in event["metadata"]: node = event["metadata"]["langgraph_node"] if event["event"] == "on_chain_end" and event["name"] == "retrieve_documents" :# when documents are retrieved docs, docs_html, history, used_documents = handle_retrieved_documents(event, history, used_documents) elif event["name"] in steps_display.keys() and event["event"] == "on_chain_start": #display steps event_description, display_output = steps_display[node] if not hasattr(history[-1], 'metadata') or history[-1].metadata["title"] != event_description: # if a new step begins history.append(ChatMessage(role="assistant", content = "", metadata={'title' :event_description})) elif event["name"] != "transform_query" and event["event"] == "on_chat_model_stream" and node in ["answer_rag", "answer_search","answer_chitchat"]:# if streaming answer history, start_streaming, answer_message_content = stream_answer(history, event, start_streaming, answer_message_content) elif event["name"] in ["retrieve_graphs", "retrieve_graphs_ai"] and event["event"] == "on_chain_end": graphs_html = handle_retrieved_owid_graphs(event, graphs_html) if event["name"] == "transform_query" and event["event"] =="on_chain_end": if hasattr(history[-1],"content"): history[-1].content += "Decompose question into sub-questions: \n\n - " + "\n - ".join([q["question"] for q in event["data"]["output"]["remaining_questions"]]) if event["name"] == "categorize_intent" and event["event"] == "on_chain_start": print("X") yield history, docs_html, output_query, output_language, docs , graphs_html #,output_query,output_keywords except Exception as e: print(event, "has failed") raise gr.Error(f"{e}") try: # Log answer on Azure Blob Storage if os.getenv("GRADIO_ENV") != "local": timestamp = str(datetime.now().timestamp()) file = timestamp + ".json" prompt = history[1]["content"] logs = { "user_id": str(user_id), "prompt": prompt, "query": prompt, "question":output_query, "sources":sources, "docs":serialize_docs(docs), "answer": history[-1].content, "time": timestamp, } log_on_azure(file, logs, share_client) except Exception as e: print(f"Error logging on Azure Blob Storage: {e}") raise gr.Error(f"ClimateQ&A Error: {str(e)[:100]} - The error has been noted, try another question and if the error remains, you can contact us :)") yield history, docs_html, output_query, output_language, docs, graphs_html def save_feedback(feed: str, user_id): if len(feed) > 1: timestamp = str(datetime.now().timestamp()) file = user_id + timestamp + ".json" logs = { "user_id": user_id, "feedback": feed, "time": timestamp, } log_on_azure(file, logs, share_client) return "Feedback submitted, thank you!" def log_on_azure(file, logs, share_client): logs = json.dumps(logs) file_client = share_client.get_file_client(file) file_client.upload_file(logs) def generate_keywords(query): chain = make_keywords_chain(llm) keywords = chain.invoke(query) keywords = " AND ".join(keywords["keywords"]) return keywords papers_cols_widths = { "id":100, "title":300, "doi":100, "publication_year":100, "abstract":500, "is_oa":50, } papers_cols = list(papers_cols_widths.keys()) papers_cols_widths = list(papers_cols_widths.values()) async def find_papers(query,after): summary = "" keywords = generate_keywords(query) df_works = oa.search(keywords,after = after) df_works = df_works.dropna(subset=["abstract"]) df_works = oa.rerank(query,df_works,reranker) df_works = df_works.sort_values("rerank_score",ascending=False) docs_html = [] for i in range(10): docs_html.append(make_html_df(df_works, i)) docs_html = "".join(docs_html) print(docs_html) G = oa.make_network(df_works) height = "750px" network = oa.show_network(G,color_by = "rerank_score",notebook=False,height = height) network_html = network.generate_html() network_html = network_html.replace("'", "\"") css_to_inject = "" network_html = network_html + css_to_inject network_html = f"""""" docs = df_works["content"].head(10).tolist() df_works = df_works.reset_index(drop = True).reset_index().rename(columns = {"index":"doc"}) df_works["doc"] = df_works["doc"] + 1 df_works = df_works[papers_cols] yield docs_html, network_html, summary chain = make_rag_papers_chain(llm) result = chain.astream_log({"question": query,"docs": docs,"language":"English"}) path_answer = "/logs/StrOutputParser/streamed_output/-" async for op in result: op = op.ops[0] if op['path'] == path_answer: # reforulated question new_token = op['value'] # str summary += new_token else: continue yield docs_html, network_html, summary # -------------------------------------------------------------------- # Gradio # -------------------------------------------------------------------- init_prompt = """ Hello, I am ClimateQ&A, a conversational assistant designed to help you understand climate change and biodiversity loss. I will answer your questions by **sifting through the IPCC and IPBES scientific reports**. ❓ How to use - **Language**: You can ask me your questions in any language. - **Audience**: You can specify your audience (children, general public, experts) to get a more adapted answer. - **Sources**: You can choose to search in the IPCC or IPBES reports, or both. ⚠️ Limitations *Please note that the AI is not perfect and may sometimes give irrelevant answers. If you are not satisfied with the answer, please ask a more specific question or report your feedback to help us improve the system.* 🛈 Information Please note that we log your questions for meta-analysis purposes, so avoid sharing any sensitive or personal information. What do you want to learn ? """ def vote(data: gr.LikeData): if data.liked: print(data.value) else: print(data) def save_graph(saved_graphs_state, embedding, category): print(f"\nCategory:\n{saved_graphs_state}\n") if category not in saved_graphs_state: saved_graphs_state[category] = [] if embedding not in saved_graphs_state[category]: saved_graphs_state[category].append(embedding) return saved_graphs_state, gr.Button("Graph Saved") with gr.Blocks(title="Climate Q&A", css_paths=os.getcwd()+ "/style.css", theme=theme,elem_id = "main-component") as demo: chat_completed_state = gr.State(0) current_graphs = gr.State([]) saved_graphs = gr.State({}) with gr.Tab("ClimateQ&A"): with gr.Row(elem_id="chatbot-row"): with gr.Column(scale=2): chatbot = gr.Chatbot( value = [ChatMessage(role="assistant", content=init_prompt)], type = "messages", show_copy_button=True, show_label = False, elem_id="chatbot", layout = "panel", avatar_images = (None,"https://i.ibb.co/YNyd5W2/logo4.png"), max_height="80vh", height="100vh" ) # bot.like(vote,None,None) with gr.Row(elem_id = "input-message"): textbox=gr.Textbox(placeholder="Ask me anything here!",show_label=False,scale=7,lines = 1,interactive = True,elem_id="input-textbox") with gr.Column(scale=2, variant="panel",elem_id = "right-panel"): with gr.Tabs() as tabs: with gr.TabItem("Examples",elem_id = "tab-examples",id = 0): examples_hidden = gr.Textbox(visible = False) first_key = list(QUESTIONS.keys())[0] dropdown_samples = gr.Dropdown(QUESTIONS.keys(),value = first_key,interactive = True,show_label = True,label = "Select a category of sample questions",elem_id = "dropdown-samples") samples = [] for i,key in enumerate(QUESTIONS.keys()): examples_visible = True if i == 0 else False with gr.Row(visible = examples_visible) as group_examples: examples_questions = gr.Examples( QUESTIONS[key], [examples_hidden], examples_per_page=8, run_on_click=False, elem_id=f"examples{i}", api_name=f"examples{i}", # label = "Click on the example question or enter your own", # cache_examples=True, ) samples.append(group_examples) with gr.Tab("Sources",elem_id = "tab-sources",id = 1): sources_textbox = gr.HTML(show_label=False, elem_id="sources-textbox") docs_textbox = gr.State("") with gr.Tab("Figures",elem_id = "tab-figures",id = 3) as tab_figures: sources_raw = gr.State() with Modal(visible=False, elem_id="modal_figure_galery") as modal: gallery_component = gr.Gallery(object_fit='scale-down',elem_id="gallery-component", height="80vh") show_full_size_figures = gr.Button("Show figures in full size",elem_id="show-figures",interactive=True) show_full_size_figures.click(lambda : Modal(visible=True),None,modal) figures_cards = gr.HTML(show_label=False, elem_id="sources-figures") with gr.Tab("Papers",elem_id = "tab-citations",id = 5): btn_summary = gr.Button("Summary") # Fenêtre simulée pour le Summary with gr.Group(visible=False, elem_id="papers-summary-popup") as summary_popup: papers_summary = gr.Markdown("### Summary Content", visible=True, elem_id="papers-summary") btn_relevant_papers = gr.Button("Relevant papers") # Fenêtre simulée pour les Relevant Papers with gr.Group(visible=False, elem_id="papers-relevant-popup") as relevant_popup: papers_html = gr.HTML(show_label=False, elem_id="sources-textbox") docs_textbox = gr.State("") btn_citations_network = gr.Button("Citations network") # Fenêtre simulée pour le Citations Network with Modal(visible=False) as modal: citations_network = gr.HTML("

Citations Network Graph

", visible=True, elem_id="papers-citations-network") btn_citations_network.click(lambda: Modal(visible=True), None, modal) with gr.Tab("Recommended content", elem_id="tab-recommended_content", id=4) as tab_recommended_content: graphs_container = gr.HTML("

There are no graphs to be displayed at the moment. Try asking another question.

") current_graphs.change(lambda x : x, inputs=[current_graphs], outputs=[graphs_container]) # @gr.render(inputs=[current_graphs]) # def display_default_recommended(current_graphs): # if len(current_graphs)==0: # placeholder_message = gr.HTML("

There are no graphs to be displayed at the moment. Try asking another question.

") # @gr.render(inputs=[current_graphs],triggers=[chat_completed_state.change]) # def render_graphs(current_graph_list): # global saved_graphs # with gr.Column(): # print(f"\ncurrent_graph_list:\n{current_graph_list}") # for (embedding, category) in current_graph_list: # graphs_placeholder = gr.HTML(embedding, elem_id="graphs-placeholder") # save_btn = gr.Button("Save Graph") # save_btn.click( # save_graph, # [saved_graphs, gr.State(embedding), gr.State(category)], # [saved_graphs, save_btn] # ) # # Display current_graphs # with gr.Row(): # for embedding in current_graphs: # with gr.Column(): # gr.HTML(embedding, elem_id="graphs-placeholder") # save_btn = gr.Button("Save Graph") # save_btn.click( # save_graph, # [saved_graphs, gr.State(embedding)], # [saved_graphs, save_btn] # ) with gr.Tab("Configuration") as tab_config: gr.Markdown("Reminders: You can talk in any language, ClimateQ&A is multi-lingual!") dropdown_sources = gr.CheckboxGroup( ["IPCC", "IPBES","IPOS"], label="Select source", value=["IPCC"], interactive=True, ) dropdown_reports = gr.Dropdown( POSSIBLE_REPORTS, label="Or select specific reports", multiselect=True, value=None, interactive=True, ) dropdown_audience = gr.Dropdown( ["Children","General public","Experts"], label="Select audience", value="Experts", interactive=True, ) output_query = gr.Textbox(label="Query used for retrieval",show_label = True,elem_id = "reformulated-query",lines = 2,interactive = False) output_language = gr.Textbox(label="Language",show_label = True,elem_id = "language",lines = 1,interactive = False) #--------------------------------------------------------------------------------------- # OTHER TABS #--------------------------------------------------------------------------------------- # with gr.Tab("Recommended content", elem_id="tab-recommended_content2") as recommended_content_tab2: # @gr.render(inputs=[current_graphs]) # def display_default_recommended_head(current_graphs_list): # if len(current_graphs_list)==0: # gr.HTML("

There are no graphs to be displayed at the moment. Try asking another question.

") # @gr.render(inputs=[current_graphs],triggers=[chat_completed_state.change]) # def render_graphs_head(current_graph_list): # global saved_graphs # category_dict = defaultdict(list) # for (embedding, category) in current_graph_list: # category_dict[category].append(embedding) # for category in category_dict: # with gr.Tab(category): # splits = [category_dict[category][i:i+3] for i in range(0, len(category_dict[category]), 3)] # for row in splits: # with gr.Row(): # for embedding in row: # with gr.Column(): # gr.HTML(embedding, elem_id="graphs-placeholder") # save_btn = gr.Button("Save Graph") # save_btn.click( # save_graph, # [saved_graphs, gr.State(embedding), gr.State(category)], # [saved_graphs, save_btn] # ) # with gr.Tab("Saved Graphs", elem_id="tab-saved-graphs") as saved_graphs_tab: # @gr.render(inputs=[saved_graphs]) # def display_default_save(saved): # if len(saved)==0: # gr.HTML("

You have not saved any graphs yet

") # @gr.render(inputs=[saved_graphs], triggers=[saved_graphs.change]) # def view_saved_graphs(graphs_list): # categories = [category for category in graphs_list] # graphs_list.keys() # for category in categories: # with gr.Tab(category): # splits = [graphs_list[category][i:i+3] for i in range(0, len(graphs_list[category]), 3)] # for row in splits: # with gr.Row(): # for graph in row: # gr.HTML(graph, elem_id="graphs-placeholder") # with gr.Tab("Figures",elem_id = "tab-images",elem_classes = "max-height other-tabs"): # gallery_component = gr.Gallery(object_fit='cover') with gr.Tab("Settings",elem_id = "tab-config",id = 2): gr.Markdown("Reminder: You can talk in any language, ClimateQ&A is multi-lingual!") dropdown_sources = gr.CheckboxGroup( ["IPCC", "IPBES","IPOS", "OpenAlex"], label="Select source", value=["IPCC"], interactive=True, ) dropdown_reports = gr.Dropdown( POSSIBLE_REPORTS, label="Or select specific reports", multiselect=True, value=None, interactive=True, ) dropdown_audience = gr.Dropdown( ["Children","General public","Experts"], label="Select audience", value="Experts", interactive=True, ) after = gr.Slider(minimum=1950,maximum=2023,step=1,value=1960,label="Publication date",show_label=True,interactive=True,elem_id="date-papers") output_query = gr.Textbox(label="Query used for retrieval",show_label = True,elem_id = "reformulated-query",lines = 2,interactive = False) output_language = gr.Textbox(label="Language",show_label = True,elem_id = "language",lines = 1,interactive = False) # with gr.Tab("Papers (beta)",elem_id = "tab-papers",elem_classes = "max-height other-tabs"): # with gr.Row(): # with gr.Column(scale=1): # query_papers = gr.Textbox(placeholder="Question",show_label=False,lines = 1,interactive = True,elem_id="query-papers") # keywords_papers = gr.Textbox(placeholder="Keywords",show_label=False,lines = 1,interactive = True,elem_id="keywords-papers") # after = gr.Slider(minimum=1950,maximum=2023,step=1,value=1960,label="Publication date",show_label=True,interactive=True,elem_id="date-papers") # search_papers = gr.Button("Search",elem_id="search-papers",interactive=True) # with gr.Column(scale=7): # with gr.Tab("Summary",elem_id="papers-summary-tab"): # papers_summary = gr.Markdown(visible=True,elem_id="papers-summary") # with gr.Tab("Relevant papers",elem_id="papers-results-tab"): # papers_dataframe = gr.Dataframe(visible=True,elem_id="papers-table",headers = papers_cols) # with gr.Tab("Citations network",elem_id="papers-network-tab"): # citations_network = gr.HTML(visible=True,elem_id="papers-citations-network") # with gr.Tab("Saved Graphs", elem_id="tab-saved-graphs", id=4) as saved_graphs_tab: # @gr.render(inputs=[saved_graphs], triggers=[saved_graphs.change]) # def view_saved_graphs(graphs_list): # for graph in graphs_list: # gr.HTML(graph, elem_id="graphs-placeholder") with gr.Tab("About",elem_classes = "max-height other-tabs"): with gr.Row(): with gr.Column(scale=1): gr.Markdown( """ ### More info - See more info at [https://climateqa.com](https://climateqa.com/docs/intro/) - Feedbacks on this [form](https://forms.office.com/e/1Yzgxm6jbp) ### Citation """ ) with gr.Accordion(CITATION_LABEL,elem_id="citation", open = False,): # # Display citation label and text) gr.Textbox( value=CITATION_TEXT, label="", interactive=False, show_copy_button=True, lines=len(CITATION_TEXT.split('\n')), ) def start_chat(query,history): history = history + [ChatMessage(role="user", content=query)] return (gr.update(interactive = False),gr.update(selected=1),history) def finish_chat(): return gr.update(interactive = True,value = "") # Initialize visibility states summary_visible = False relevant_visible = False # Functions to toggle visibility def toggle_summary_visibility(): global summary_visible summary_visible = not summary_visible return gr.update(visible=summary_visible) def toggle_relevant_visibility(): global relevant_visible relevant_visible = not relevant_visible return gr.update(visible=relevant_visible) def change_completion_status(current_state): current_state = 1 - current_state return current_state def update_sources_number_display(sources_textbox, figures_cards, current_graphs): sources_number = sources_textbox.count("

") figures_number = figures_cards.count("

") graphs_number = current_graphs.count("