test_biocad / app.py
import gradio as gr
from mistralai import Mistral
from langchain_community.tools import TavilySearchResults, JinaSearch
import concurrent.futures
import json
import os
import arxiv
from docx import Document
from PIL import Image
import io
import base64
# Set the Tavily API key as an environment variable
os.environ["TAVILY_API_KEY"] = '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'
# Mistral client API keys
client_1 = Mistral(api_key='eLES5HrVqduOE1OSWG6C5XyEUeR7qpXQ')
client_2 = Mistral(api_key='VPqG8sCy3JX5zFkpdiZ7bRSnTLKwngFJ')
client_3 = Mistral(api_key='cvyu5Rdk2lS026epqL4VB6BMPUcUMSgt')
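# NOTE: three separate Mistral clients with their own API keys are used, presumably so that each stage
# (topic extraction via client_1, web-augmented Q&A via client_2, summarization via client_3) draws on its own quota.
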
# Function to encode images in base64
def encode_image_bytes(image_bytes):
    return base64.b64encode(image_bytes).decode('utf-8')

# Function to decode base64 images
def decode_base64_image(base64_str):
    # Accept either raw base64 or a full data URI ("data:image/jpeg;base64,...").
    if base64_str.startswith("data:"):
        base64_str = base64_str.split(",", 1)[-1]
    image_data = base64.b64decode(base64_str)
    return Image.open(io.BytesIO(image_data))

# Process text and images provided by the user
def process_input(text_input, images_base64):
    images = []
    if images_base64:
        # The Gradio textbox delivers one comma-separated string; split it into individual entries.
        if isinstance(images_base64, str):
            images_base64 = [s.strip() for s in images_base64.split(",") if s.strip()]
        for img_data in images_base64:
            try:
                img = decode_base64_image(img_data)
                buffered = io.BytesIO()
                # Convert to RGB so images with an alpha channel can be saved as JPEG.
                img.convert("RGB").save(buffered, format="JPEG")
                image_base64 = encode_image_bytes(buffered.getvalue())
                images.append({"type": "image_url", "image_url": f"data:image/jpeg;base64,{image_base64}"})
            except Exception as e:
                print(f"Error decoding image: {e}")
    return text_input, images

# Search setup function: try Tavily first, then fall back to Jina
def setup_search(question):
    try:
        tavily_tool = TavilySearchResults(max_results=20)
        results = tavily_tool.invoke({"query": f"{question}"})
        if isinstance(results, list):
            return results, 'tavily_tool'
    except Exception as e:
        print("Error with TavilySearchResults:", e)
    try:
        jina_tool = JinaSearch()
        results = json.loads(str(jina_tool.invoke({"query": f"{question}"})))
        if isinstance(results, list):
            return results, 'jina_tool'
    except Exception as e:
        print("Error with JinaSearch:", e)
    return [], ''

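# setup_search returns (results, tool_name): Tavily items expose 'url' and 'content', while Jina items
# expose 'link', 'snippet' and 'content'; ask_question_to_mistral relies on this when building the web context.
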
# Function to extract key topics
def extract_key_topics(content, images=[]):
    prompt = f"""
    Extract the primary themes from the text below. List each theme in as few words as possible, focusing on essential concepts only. Format as a concise, unordered list with no extraneous words.
    ```{content}```
    LIST IN ENGLISH:
    -
    """
    message_content = [{"type": "text", "text": prompt}] + images
    response = client_1.chat.complete(
        model="pixtral-12b-2409",
        messages=[{"role": "user", "content": message_content}]
    )
    return response.choices[0].message.content

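# Search arXiv for articles relevant to each extracted topic. Topics are queried in parallel via a
# thread pool; each hit keeps its title, DOI, abstract, entry URL and PDF link.
# Note: newer releases of the arxiv package prefer arxiv.Client().results(search) over the
# (still functional, but deprecated) search.results() call used below.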
def search_relevant_articles_arxiv(key_topics, max_articles=100):
    articles_by_topic = {}
    final_topics = []

    def fetch_articles_for_topic(topic):
        topic_articles = []
        try:
            # Fetch articles using arxiv.py based on the topic
            search = arxiv.Search(
                query=topic,
                max_results=max_articles,
                sort_by=arxiv.SortCriterion.Relevance
            )
            for result in search.results():
                article_data = {
                    "title": result.title,
                    "doi": result.doi,
                    "summary": result.summary,
                    "url": result.entry_id,
                    "pdf_url": result.pdf_url
                }
                topic_articles.append(article_data)
                final_topics.append(topic)
        except Exception as e:
            print(f"Error fetching articles for topic '{topic}': {e}")
        return topic, topic_articles

    with concurrent.futures.ThreadPoolExecutor(max_workers=10) as executor:
        # Use threads to fetch articles for each topic
        futures = {executor.submit(fetch_articles_for_topic, topic): topic for topic in key_topics}
        for future in concurrent.futures.as_completed(futures):
            topic, articles = future.result()
            if articles:
                articles_by_topic[topic] = articles

    return articles_by_topic, list(set(final_topics))

# Initial analysis: extract key topics and fetch related arXiv articles
def init(content, images=[]):
    key_topics = extract_key_topics(content, images)
    key_topics = [topic.strip("- ") for topic in key_topics.split("\n") if topic]
    articles_by_topic, final_topics = search_relevant_articles_arxiv(key_topics)
    result_json = json.dumps(articles_by_topic, indent=4)
    return final_topics, result_json

# Summarization function
def process_article_for_summary(text, images=[], compression_percentage=30):
    prompt = f"""
    You are a commentator.
    # article:
    {text}
    # Instructions:
    ## Summarize IN RUSSIAN:
    In clear and concise language, summarize the key points and themes presented in the article, cutting it down by {compression_percentage} percent, in Markdown format.
    """
    message_content = [{"type": "text", "text": prompt}] + images
    response = client_3.chat.complete(
        model="pixtral-12b-2409",
        messages=[{"role": "user", "content": message_content}]
    )
    return response.choices[0].message.content

# Question answering function
def ask_question_to_mistral(text, question, images=[]):
    prompt = f"Answer the following question in Markdown style, without mentioning the question or repeating the original text it refers to. IN RUSSIAN:\nQuestion: {question}\n\nText:\n{text}"
    message_content = [{"type": "text", "text": prompt}] + images
    search_tool, tool = setup_search(question)
    context = ''
    if search_tool:
        if tool == 'tavily_tool':
            for result in search_tool:
                context += f"{result.get('url', 'N/A')} : {result.get('content', 'No content')} \n"
        elif tool == 'jina_tool':
            for result in search_tool:
                context += f"{result.get('link', 'N/A')} : {result.get('snippet', 'No snippet')} : {result.get('content', 'No content')} \n"
    # Append the web-search context as an extra text block so the image parts are passed through intact.
    response = client_2.chat.complete(
        model="pixtral-12b-2409",
        messages=[{"role": "user", "content": message_content + [{"type": "text", "text": f"Additional Context from Web Search:\n{context}"}]}]
    )
    return response.choices[0].message.content

# Gradio interface
def gradio_interface(text_input, images_base64, task, question, compression_percentage):
    text, images = process_input(text_input, images_base64)
    topics, articles_json = init(text, images)
    if task == "Summarization":
        summary = process_article_for_summary(text, images, compression_percentage)
        return {"Topics": topics, "Summary": summary, "Articles": articles_json}
    elif task == "Question Answering":
        if question:
            answer = ask_question_to_mistral(text, question, images)
            return {"Topics": topics, "Answer": answer, "Articles": articles_json}
        else:
            return {"Topics": topics, "Answer": "No question provided.", "Articles": articles_json}
    # Fallback when no task is selected
    return {"Topics": topics, "Articles": articles_json}

with gr.Blocks() as demo:
    gr.Markdown("## Text Analysis: Summarization or Question Answering")
    with gr.Row():
        text_input = gr.Textbox(label="Input Text")
        images_base64 = gr.Textbox(label="Base64 Images (comma-separated, if any)", placeholder="data:image/jpeg;base64,...", lines=2)
        task_choice = gr.Radio(["Summarization", "Question Answering"], label="Select Task")
        question_input = gr.Textbox(label="Question (for Question Answering)", visible=False)
        compression_input = gr.Slider(label="Compression Percentage (for Summarization)", minimum=10, maximum=90, value=30, visible=False)

    # Show only the controls relevant to the selected task
    task_choice.change(lambda choice: (gr.update(visible=choice == "Question Answering"),
                                       gr.update(visible=choice == "Summarization")),
                       inputs=task_choice, outputs=[question_input, compression_input])

    with gr.Row():
        result_output = gr.JSON(label="Results")
        submit_button = gr.Button("Submit")

    submit_button.click(gradio_interface, [text_input, images_base64, task_choice, question_input, compression_input], result_output)

demo.launch()
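
# Example (illustrative): building the "Base64 Images" textbox value from local files before pasting it
# into the UI. The file names below are hypothetical.
#
#   import base64
#   paths = ["figure1.jpg", "figure2.jpg"]
#   encoded = [base64.b64encode(open(p, "rb").read()).decode("utf-8") for p in paths]
#   images_field = ",".join(encoded)   # paste this string into the "Base64 Images" textbox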