Spaces:
Runtime error
Runtime error
import gradio as gr | |
from transformers import pipeline | |
from haystack.document_stores import FAISSDocumentStore | |
from haystack.nodes import EmbeddingRetriever | |
import numpy as np | |
import openai | |
# Zero-shot classifier used by is_climate_change_related() to label a question.
classifier = pipeline("zero-shot-classification", model="facebook/bart-large-mnli")

# System message placed at the start of every conversation sent to the chat model.
system_template = {
    "role": "system",
    "content": "You have been a climate change expert for 30 years. You answer questions about climate change in an educationnal and concise manner.",
}

# Pre-built FAISS index and its configuration, loaded from the local
# ./documents directory (plain string literals: the paths have no placeholders).
document_store = FAISSDocumentStore.load(
    index_path="./documents/climate_gpt.faiss",
    config_path="./documents/climate_gpt.json",
)

# Embedding-based retriever over the loaded document store, using a
# sentence-transformers model to embed queries.
dense = EmbeddingRetriever(
    document_store=document_store,
    embedding_model="sentence-transformers/multi-qa-mpnet-base-dot-v1",
    model_format="sentence_transformers",
)
def is_climate_change_related(sentence: str) -> bool:
    """Return True when the classifier's best label for *sentence* is the climate one.

    The sentence is scored against two candidate labels; the label with the
    highest score decides the result.
    """
    positive_label = "climate change related"
    prediction = classifier(
        sequences=sentence,
        candidate_labels=[positive_label, "non climate change related"],
    )
    # Pick the label sitting at the position of the highest score.
    best_index = np.argmax(prediction["scores"])
    top_label = prediction["labels"][best_index]
    return top_label == positive_label
def make_pairs(lst):
    """Group a flat, even-length sequence into consecutive 2-tuples.

    ``[a, b, c, d]`` becomes ``[(a, b), (c, d)]``. An odd-length list raises
    ``IndexError`` because its last element has no partner.
    """
    pairs = []
    for start in range(0, len(lst), 2):
        first, second = lst[start], lst[start + 1]
        pairs.append((first, second))
    return pairs
def _format_extract(doc, content):
    """Render one retrieved document as a ``location:`` header line followed by *content*."""
    return f"{doc.meta['path']} Page {doc.meta['page_id']} paragraph {doc.meta['paragraph_id']}:\n{content}"


def gen_conv(query: str, history=None, ipcc=True):
    """Answer *query* with the chat model, optionally grounded on retrieved extracts.

    Args:
        query: The user's new message.
        history: Messages exchanged so far, as a list of
            ``{"role": ..., "content": ...}`` dicts beginning with the system
            message. ``None`` (the default) starts a fresh conversation
            containing only ``system_template``. The list is not modified.
        ipcc: When true, extracts are retrieved for queries that the
            classifier labels as climate change related.

    Returns:
        A 3-tuple ``(chat_pairs, messages, sources)``:
        ``chat_pairs`` is a list of ``(user, assistant)`` content tuples for
        the chatbot widget, ``messages`` is the updated history including the
        new exchange, and ``sources`` is a text preview of the retrieved
        extracts (empty string when nothing was retrieved).
    """
    # A None default avoids sharing one mutable list between calls.
    if history is None:
        history = [system_template]

    retrieve = ipcc and is_climate_change_related(query)
    sources = ""
    messages = history + [{"role": "user", "content": query}]

    # The extracts are sent to the model for this turn only; they live in a
    # separate prompt list so they never enter the stored history.
    prompt = messages
    if retrieve:
        docs = dense.retrieve(query=query, top_k=5)
        context = "\n\n".join(
            ["If relevant, use those extracts from IPCC reports in your answer"]
            + [_format_extract(d, d.content) for d in docs]
        )
        prompt = messages + [{"role": "system", "content": context}]

    completion = openai.ChatCompletion.create(
        model="gpt-3.5-turbo",
        messages=prompt,
        temperature=0.2,
    )
    answer = completion["choices"][0]["message"]["content"]

    if retrieve:
        answer = "(top 5 documents retrieved) " + answer
        # Only the first 100 characters of each extract are shown in the UI.
        sources = "\n\n".join(
            _format_extract(d, f"{d.content[:100]} [...]") for d in docs
        )

    messages.append({"role": "assistant", "content": answer})
    # Skip the leading system message: what remains alternates user/assistant.
    gradio_format = make_pairs([m["content"] for m in messages[1:]])
    return gradio_format, messages, sources
def connect(text):
    """Register *text* as the OpenAI API key and return a confirmation message."""
    confirmation = "You're all set"
    # The key is stored on the openai module itself.
    openai.api_key = text
    return confirmation
# Gradio interface: an API-key form on top, then the chat area next to a
# read-only panel listing the retrieved sources.
# NOTE(review): the nesting below was reconstructed from a copy of the file
# that had lost its indentation -- confirm against the original layout.
with gr.Blocks(title="Eki IPCC Explorer") as demo:
    # Row 1: API key entry (left column) and connection status (right column).
    with gr.Row():
        with gr.Column():
            api_key = gr.Textbox(label="Open AI api key")
            connect_btn = gr.Button(value="Connect")
        with gr.Column():
            result = gr.Textbox(label="Connection")
    # Clicking the button passes the typed key to connect() and shows its
    # confirmation message in the status textbox.
    connect_btn.click(connect, inputs=api_key, outputs=result, api_name="Connection")
    gr.Markdown(
        """
# Ask me anything, I'm an IPCC report
"""
    )
    # Row 2: chat (wider column) and sources panel (narrower column).
    with gr.Row():
        with gr.Column(scale=2):
            chatbot = gr.Chatbot()
            # Conversation history kept in state, seeded with the system message.
            state = gr.State([system_template])
            with gr.Row():
                ask = gr.Textbox(
                    show_label=False, placeholder="Enter text and press enter"
                ).style(container=False)
        with gr.Column(scale=1, variant="panel"):
            gr.Markdown("### Sources")
            sources_textbox = gr.Textbox(
                interactive=False, show_label=False, max_lines=50
            )
    # Submitting the textbox calls gen_conv(query, history); its three return
    # values update the chatbot, the stored history and the sources panel.
    ask.submit(
        fn=gen_conv, inputs=[ask, state], outputs=[chatbot, state, sources_textbox]
    )
# Start the app; share=True requests a public share link from Gradio.
demo.launch(share=True)