from climateqa.engine.embeddings import get_embeddings_function embeddings_function = get_embeddings_function() from climateqa.knowledge.openalex import OpenAlex from sentence_transformers import CrossEncoder # reranker = CrossEncoder("mixedbread-ai/mxbai-rerank-xsmall-v1") oa = OpenAlex() import gradio as gr import pandas as pd import numpy as np import os import time import re import json from gradio import ChatMessage # from gradio_modal import Modal from io import BytesIO import base64 from datetime import datetime from azure.storage.fileshare import ShareServiceClient from utils import create_user_id from langchain_chroma import Chroma from collections import defaultdict # ClimateQ&A imports from climateqa.engine.llm import get_llm from climateqa.engine.vectorstore import get_pinecone_vectorstore from climateqa.knowledge.retriever import ClimateQARetriever from climateqa.engine.reranker import get_reranker from climateqa.engine.embeddings import get_embeddings_function from climateqa.engine.chains.prompts import audience_prompts from climateqa.sample_questions import QUESTIONS from climateqa.constants import POSSIBLE_REPORTS, OWID_CATEGORIES from climateqa.utils import get_image_from_azure_blob_storage from climateqa.engine.keywords import make_keywords_chain # from climateqa.engine.chains.answer_rag import make_rag_papers_chain from climateqa.engine.graph import make_graph_agent,display_graph from climateqa.engine.embeddings import get_embeddings_function from front.utils import make_html_source,parse_output_llm_with_sources,serialize_docs,make_toolbox,generate_html_graphs # Load environment variables in local mode try: from dotenv import load_dotenv load_dotenv() except Exception as e: pass # Set up Gradio Theme theme = gr.themes.Base( primary_hue="blue", secondary_hue="red", font=[gr.themes.GoogleFont("Poppins"), "ui-sans-serif", "system-ui", "sans-serif"], ) init_prompt = "" system_template = { "role": "system", "content": init_prompt, } account_key = os.environ["BLOB_ACCOUNT_KEY"] if len(account_key) == 86: account_key += "==" credential = { "account_key": account_key, "account_name": os.environ["BLOB_ACCOUNT_NAME"], } account_url = os.environ["BLOB_ACCOUNT_URL"] file_share_name = "climateqa" service = ShareServiceClient(account_url=account_url, credential=credential) share_client = service.get_share_client(file_share_name) user_id = create_user_id() embeddings_function = get_embeddings_function() llm = get_llm(provider="openai",max_tokens = 1024,temperature = 0.0) reranker = get_reranker("nano") # Create vectorstore and retriever vectorstore = get_pinecone_vectorstore(embeddings_function) vectorstore_graphs = Chroma(persist_directory="/home/tim/ai4s/climate_qa/climate-question-answering/data/vectorstore_owid", embedding_function=embeddings_function) # agent = make_graph_agent(llm,vectorstore,reranker) agent = make_graph_agent(llm=llm, vectorstore_ipcc=vectorstore, vectorstore_graphs=vectorstore_graphs, reranker=reranker) async def chat(query,history,audience,sources,reports,current_graphs): """taking a query and a message history, use a pipeline (reformulation, retriever, answering) to yield a tuple of: (messages in gradio format, messages in langchain format, source documents)""" date_now = datetime.now().strftime("%Y-%m-%d %H:%M:%S") print(f">> NEW QUESTION ({date_now}) : {query}") if audience == "Children": audience_prompt = audience_prompts["children"] elif audience == "General public": audience_prompt = audience_prompts["general"] elif audience == "Experts": audience_prompt = audience_prompts["experts"] else: audience_prompt = audience_prompts["experts"] # Prepare default values if sources is None or len(sources) == 0: sources = ["IPCC", "IPBES", "IPOS"] if reports is None or len(reports) == 0: reports = [] inputs = {"user_input": query,"audience": audience_prompt,"sources":sources} result = agent.astream_events(inputs,version = "v1") # path_reformulation = "/logs/reformulation/final_output" # path_keywords = "/logs/keywords/final_output" # path_retriever = "/logs/find_documents/final_output" # path_answer = "/logs/answer/streamed_output_str/-" docs = [] docs_used = True docs_html = "" current_graphs = [] output_query = "" output_language = "" output_keywords = "" gallery = [] updates = [] start_streaming = False steps_display = { "categorize_intent":("🔄️ Analyzing user message",True), "transform_query":("🔄️ Thinking step by step to answer the question",True), "retrieve_documents":("🔄️ Searching in the knowledge base",False), } used_documents = [] answer_message_content = "" try: async for event in result: # if event["event"] == "on_chat_model_stream" and event["metadata"]["langgraph_node"] in ["answer_rag", "answer_rag_no_docs", "answer_chitchat", "answer_ai_impact"]: # if start_streaming == False: # start_streaming = True # history[-1] = (query,"") if "langgraph_node" in event["metadata"]: node = event["metadata"]["langgraph_node"] if event["event"] == "on_chain_end" and event["name"] == "retrieve_documents" :# when documents are retrieved try: docs = event["data"]["output"]["documents"] docs_html = [] for i, d in enumerate(docs, 1): docs_html.append(make_html_source(d, i)) used_documents = used_documents + [d.metadata["name"] for d in docs] history[-1].content = "Adding sources :\n\n - " + "\n - ".join(np.unique(used_documents)) docs_html = "".join(docs_html) except Exception as e: print(f"Error getting documents: {e}") print(event) elif event["name"] in steps_display.keys() and event["event"] == "on_chain_start": #display steps event_description,display_output = steps_display[node] if not hasattr(history[-1], 'metadata') or history[-1].metadata["title"] != event_description: # if a new step begins history.append(ChatMessage(role="assistant", content = "", metadata={'title' :event_description})) elif event["name"] != "transform_query" and event["event"] == "on_chat_model_stream" and node in ["answer_rag", "answer_search"]:# if streaming answer if start_streaming == False: start_streaming = True history.append(ChatMessage(role="assistant", content = "")) answer_message_content += event["data"]["chunk"].content answer_message_content = parse_output_llm_with_sources(answer_message_content) history[-1] = ChatMessage(role="assistant", content = answer_message_content) # history.append(ChatMessage(role="assistant", content = new_message_content)) # if docs_used is True and event["metadata"]["langgraph_node"] in ["answer_rag_no_docs", "answer_chitchat", "answer_ai_impact"]: # docs_used = False # elif docs_used is True and event["name"] == "retrieve_documents" and event["event"] == "on_chain_end": # try: # docs = event["data"]["output"]["documents"] # docs_html = [] # for i, d in enumerate(docs, 1): # docs_html.append(make_html_source(d, i)) # docs_html = "".join(docs_html) # except Exception as e: # print(f"Error getting documents: {e}") # print(event) # # elif event["name"] == "retrieve_documents" and event["event"] == "on_chain_start": # # print(event) # # questions = event["data"]["input"]["questions"] # # questions = "\n".join([f"{i+1}. {q['question']} ({q['source']})" for i,q in enumerate(questions)]) # # answer_yet = "🔄️ Searching in the knowledge base\n{questions}" # # history[-1] = (query,answer_yet) # elif event["name"] in ["retrieve_graphs", "retrieve_graphs_ai"] and event["event"] == "on_chain_end": # try: # recommended_content = event["data"]["output"]["recommended_content"] # # graphs = [ # # { # # "embedding": x.metadata["returned_content"], # # "metadata": { # # "source": x.metadata["source"], # # "category": x.metadata["category"] # # } # # } for x in recommended_content if x.metadata["source"] == "OWID" # # ] # unique_graphs = [] # seen_embeddings = set() # for x in recommended_content: # embedding = x.metadata["returned_content"] # # Check if the embedding has already been seen # if embedding not in seen_embeddings: # unique_graphs.append({ # "embedding": embedding, # "metadata": { # "source": x.metadata["source"], # "category": x.metadata["category"] # } # }) # # Add the embedding to the seen set # seen_embeddings.add(embedding) # categories = {} # for graph in unique_graphs: # category = graph['metadata']['category'] # if category not in categories: # categories[category] = [] # categories[category].append(graph['embedding']) # # graphs_html = "" # for category, embeddings in categories.items(): # # graphs_html += f"
{event_description}
" # # answer_yet = make_toolbox(event_description, "", checked = False) # answer_yet = event_description # history[-1] = (query,answer_yet) # # elif event["event"] == "on_chain_end": # # answer_yet = "" # # history[-1] = (query,answer_yet) # # if display_output: # # print(event["data"]["output"]) # # if op['path'] == path_reformulation: # reforulated question # # try: # # output_language = op['value']["language"] # str # # output_query = op["value"]["question"] # # except Exception as e: # # raise gr.Error(f"ClimateQ&A Error: {e} - The error has been noted, try another question and if the error remains, you can contact us :)") # # if op["path"] == path_keywords: # # try: # # output_keywords = op['value']["keywords"] # str # # output_keywords = " AND ".join(output_keywords) # # except Exception as e: # # pass # history = [tuple(x) for x in history] # yield history,docs_html,output_query,output_language,gallery,current_graphs #,output_query,output_keywords if event["name"] == "transform_query" and event["event"] =="on_chain_end": if hasattr(history[-1],"content"): history[-1].content += "Decompose question into sub-questions: \n\n - " + "\n - ".join([q["question"] for q in event["data"]["output"]["remaining_questions"]]) if event["name"] == "categorize_intent" and event["event"] == "on_chain_start": print("X") yield history,docs_html,output_query,output_language,gallery #,output_query,output_keywords except Exception as e: print(event, "has failed") raise gr.Error(f"{e}") try: # Log answer on Azure Blob Storage if os.getenv("GRADIO_ENV") != "local": timestamp = str(datetime.now().timestamp()) file = timestamp + ".json" prompt = history[-1][0] logs = { "user_id": str(user_id), "prompt": prompt, "query": prompt, "question":output_query, "sources":sources, "docs":serialize_docs(docs), "answer": history[-1][1], "time": timestamp, } log_on_azure(file, logs, share_client) except Exception as e: print(f"Error logging on Azure Blob Storage: {e}") raise gr.Error(f"ClimateQ&A Error: {str(e)[:100]} - The error has been noted, try another question and if the error remains, you can contact us :)") image_dict = {} for i,doc in enumerate(docs): if doc.metadata["chunk_type"] == "image": try: key = f"Image {i+1}" image_path = doc.metadata["image_path"].split("documents/")[1] img = get_image_from_azure_blob_storage(image_path) # Convert the image to a byte buffer buffered = BytesIO() img.save(buffered, format="PNG") img_str = base64.b64encode(buffered.getvalue()).decode() # Embedding the base64 string in Markdown markdown_image = f"![Alt text](data:image/png;base64,{img_str})" image_dict[key] = {"img":img,"md":markdown_image,"caption":doc.page_content,"key":key,"figure_code":doc.metadata["figure_code"]} except Exception as e: print(f"Skipped adding image {i} because of {e}") if len(image_dict) > 0: gallery = [x["img"] for x in list(image_dict.values())] img = list(image_dict.values())[0] img_md = img["md"] img_caption = img["caption"] img_code = img["figure_code"] if img_code != "N/A": img_name = f"{img['key']} - {img['figure_code']}" else: img_name = f"{img['key']}" answer_yet = history[-1][1] + f"\n\n{img_md}\n " history[-1] = (history[-1][0],answer_yet) history = [tuple(x) for x in history] # print(f"\n\nImages:\n{gallery}") # # gallery = [x.metadata["image_path"] for x in docs if (len(x.metadata["image_path"]) > 0 and "IAS" in x.metadata["image_path"])] # # if len(gallery) > 0: # # gallery = list(set("|".join(gallery).split("|"))) # # gallery = [get_image_from_azure_blob_storage(x) for x in gallery] # yield history,docs_html,output_query,output_language,gallery,current_graphs #,output_query,output_keywords # # else: # # docs_string = "No relevant passages found in the climate science reports (IPCC and IPBES)" # # complete_response = "**No relevant passages found in the climate science reports (IPCC and IPBES), you may want to ask a more specific question (specifying your question on climate issues).**" # # messages.append({"role": "assistant", "content": complete_response}) # # gradio_format = make_pairs([a["content"] for a in messages[1:]]) # # yield gradio_format, messages, docs_string yield history,docs_html,output_query,output_language,gallery#,output_query,output_keywords def save_feedback(feed: str, user_id): if len(feed) > 1: timestamp = str(datetime.now().timestamp()) file = user_id + timestamp + ".json" logs = { "user_id": user_id, "feedback": feed, "time": timestamp, } log_on_azure(file, logs, share_client) return "Feedback submitted, thank you!" def log_on_azure(file, logs, share_client): logs = json.dumps(logs) file_client = share_client.get_file_client(file) file_client.upload_file(logs) def generate_keywords(query): chain = make_keywords_chain(llm) keywords = chain.invoke(query) keywords = " AND ".join(keywords["keywords"]) return keywords papers_cols_widths = { "doc":50, "id":100, "title":300, "doi":100, "publication_year":100, "abstract":500, "rerank_score":100, "is_oa":50, } papers_cols = list(papers_cols_widths.keys()) papers_cols_widths = list(papers_cols_widths.values()) # -------------------------------------------------------------------- # Gradio # -------------------------------------------------------------------- init_prompt = """ Hello, I am ClimateQ&A, a conversational assistant designed to help you understand climate change and biodiversity loss. I will answer your questions by **sifting through the IPCC and IPBES scientific reports**. ❓ How to use - **Language**: You can ask me your questions in any language. - **Audience**: You can specify your audience (children, general public, experts) to get a more adapted answer. - **Sources**: You can choose to search in the IPCC or IPBES reports, or both. ⚠️ Limitations *Please note that the AI is not perfect and may sometimes give irrelevant answers. If you are not satisfied with the answer, please ask a more specific question or report your feedback to help us improve the system.* What do you want to learn ? """ def vote(data: gr.LikeData): if data.liked: print(data.value) else: print(data) def save_graph(saved_graphs_state, embedding, category): print(f"\nCategory:\n{saved_graphs_state}\n") if category not in saved_graphs_state: saved_graphs_state[category] = [] if embedding not in saved_graphs_state[category]: saved_graphs_state[category].append(embedding) return saved_graphs_state, gr.Button("Graph Saved") # with gr.Blocks(title="Climate Q&A", css="style.css", theme=theme,elem_id = "main-component") as demo: # user_id_state = gr.State([user_id]) # chat_completed_state = gr.State(0) # current_graphs = gr.State([]) # saved_graphs = gr.State({}) with gr.Blocks(title="Climate Q&A", css_paths=os.getcwd()+ "/style.css", theme=theme,elem_id = "main-component") as demo: chat_completed_state = gr.State(0) current_graphs = gr.State([]) saved_graphs = gr.State({}) with gr.Tab("ClimateQ&A"): with gr.Row(elem_id="chatbot-row"): with gr.Column(scale=2): # state = gr.State([system_template]) chatbot = gr.Chatbot( value = [ChatMessage(role="assistant", content=init_prompt)], type = "messages", show_copy_button=True, show_label = False, elem_id="chatbot", layout = "panel", avatar_images = (None,"https://i.ibb.co/YNyd5W2/logo4.png"), ) # bot.like(vote,None,None) with gr.Row(elem_id = "input-message"): textbox=gr.Textbox(placeholder="Ask me anything here!",show_label=False,scale=7,lines = 1,interactive = True,elem_id="input-textbox") with gr.Column(scale=1.5, variant="panel",elem_id = "right-panel"): with gr.Tabs() as tabs: with gr.TabItem("Examples",elem_id = "tab-examples",id = 0): examples_hidden = gr.Textbox(visible = False) first_key = list(QUESTIONS.keys())[0] dropdown_samples = gr.Dropdown(QUESTIONS.keys(),value = first_key,interactive = True,show_label = True,label = "Select a category of sample questions",elem_id = "dropdown-samples") samples = [] for i,key in enumerate(QUESTIONS.keys()): examples_visible = True if i == 0 else False with gr.Row(visible = examples_visible) as group_examples: examples_questions = gr.Examples( QUESTIONS[key], [examples_hidden], examples_per_page=8, run_on_click=False, elem_id=f"examples{i}", api_name=f"examples{i}", # label = "Click on the example question or enter your own", # cache_examples=True, ) samples.append(group_examples) with gr.Tab("Sources",elem_id = "tab-citations",id = 1): sources_textbox = gr.HTML(show_label=False, elem_id="sources-textbox") docs_textbox = gr.State("") # with Modal(visible = False) as config_modal: with gr.Tab("Configuration",elem_id = "tab-config",id = 2): gr.Markdown("Reminders: You can talk in any language, ClimateQ&A is multi-lingual!") dropdown_sources = gr.CheckboxGroup( ["IPCC", "IPBES","IPOS"], label="Select source", value=["IPCC", "IPBES","IPOS"], interactive=True, ) dropdown_reports = gr.Dropdown( POSSIBLE_REPORTS, label="Or select specific reports", multiselect=True, value=None, interactive=True, ) dropdown_audience = gr.Dropdown( ["Children","General public","Experts"], label="Select audience", value="Experts", interactive=True, ) output_query = gr.Textbox(label="Query used for retrieval",show_label = True,elem_id = "reformulated-query",lines = 2,interactive = False) output_language = gr.Textbox(label="Language",show_label = True,elem_id = "language",lines = 1,interactive = False) with gr.Tab("Recommended content", elem_id="tab-recommended_content", id=3) as recommended_content_tab: @gr.render(inputs=[current_graphs]) def display_default_recommended(current_graphs): if len(current_graphs)==0: placeholder_message = gr.HTML("