test_biocad / app.py
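# -----------------------------------------------------------------------------
# Gradio app: upload a PDF/DOCX/TXT document, extract its key topics, look up
# related arXiv articles, and either summarize the document or answer a
# question about it with Mistral's pixtral-12b-2409 model, optionally enriched
# with Tavily/Jina web-search context.
# -----------------------------------------------------------------------------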
import gradio as gr
from mistralai import Mistral
from langchain_community.tools import TavilySearchResults, JinaSearch
import concurrent.futures
import json
import os
import arxiv
import fitz # PyMuPDF
from docx import Document
from PIL import Image
import io
import base64
import mimetypes
# Set environment variables for Tavily API
os.environ["TAVILY_API_KEY"] = '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'
# Mistral client API keys
client_1 = Mistral(api_key='eLES5HrVqduOE1OSWG6C5XyEUeR7qpXQ')
client_2 = Mistral(api_key='VPqG8sCy3JX5zFkpdiZ7bRSnTLKwngFJ')
client_3 = Mistral(api_key='cvyu5Rdk2lS026epqL4VB6BMPUcUMSgt')
# Function to encode images in base64
def encode_image_bytes(image_bytes):
    return base64.b64encode(image_bytes).decode('utf-8')
# Functions to process various file types
def process_file(file_path):
    mime_type, _ = mimetypes.guess_type(file_path)
    if mime_type == 'application/pdf':
        return process_pdf(file_path)
    elif mime_type == 'application/vnd.openxmlformats-officedocument.wordprocessingml.document':
        return process_docx(file_path)
    elif mime_type == 'text/plain':
        return process_txt(file_path)
    else:
        print(f"Unsupported file type: {mime_type}")
        return None, []
def process_pdf(file_path):
    text = ""
    images = []
    pdf_document = fitz.open(file_path)
    for page_num in range(len(pdf_document)):
        text += pdf_document[page_num].get_text("text")
        for img in pdf_document.get_page_images(page_num, full=True):
            xref = img[0]
            base_image = pdf_document.extract_image(xref)
            image_bytes = base_image["image"]
            image_ext = base_image["ext"]
            base64_image = encode_image_bytes(image_bytes)
            image_data = f"data:image/{image_ext};base64,{base64_image}"
            images.append({"type": "image_url", "image_url": image_data})
    return text, images
def process_docx(file_path):
    doc = Document(file_path)
    text = ""
    images = []
    for paragraph in doc.paragraphs:
        text += paragraph.text + "\n"
    for rel in doc.part.rels.values():
        if "image" in rel.target_ref:
            img_data = rel.target_part.blob
            img = Image.open(io.BytesIO(img_data))
            # Convert to RGB first: JPEG cannot store an alpha channel (e.g. RGBA PNGs).
            img = img.convert("RGB")
            buffered = io.BytesIO()
            img.save(buffered, format="JPEG")
            image_base64 = encode_image_bytes(buffered.getvalue())
            images.append({"type": "image_url", "image_url": f"data:image/jpeg;base64,{image_base64}"})
    return text, images
def process_txt(file_path):
    with open(file_path, "r", encoding="utf-8") as file:
        text = file.read()
    return text, []
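# Example (sketch): processing a local file returns plain text plus a list of
# base64 image payloads ready to be appended to a Mistral message. The file
# name below is hypothetical.
#
#   text, images = process_file("paper.pdf")
#   # text   -> concatenated page/paragraph text
#   # images -> [{"type": "image_url", "image_url": "data:image/png;base64,..."}]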
# Search setup function
def setup_search(question):
    try:
        tavily_tool = TavilySearchResults(max_results=20)
        results = tavily_tool.invoke({"query": f"{question}"})
        if isinstance(results, list):
            return results, 'tavily_tool'
    except Exception as e:
        print("Error with TavilySearchResults:", e)
    try:
        jina_tool = JinaSearch()
        results = json.loads(str(jina_tool.invoke({"query": f"{question}"})))
        if isinstance(results, list):
            return results, 'jina_tool'
    except Exception as e:
        print("Error with JinaSearch:", e)
    return [], ''
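# Expected result shapes (as consumed in ask_question_to_mistral below):
#   tavily_tool -> [{"url": ..., "content": ...}, ...]
#   jina_tool   -> [{"link": ..., "snippet": ..., "content": ...}, ...]
# If both providers fail, an empty list and an empty tool name are returned.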
# Function to extract key topics
def extract_key_topics(content, images=[]):
    prompt = f"""
    Extract the primary themes from the text below. List each theme in as few words as possible, focusing on essential concepts only. Format as a concise, unordered list with no extraneous words.
    ```{content}```
    LIST IN ENGLISH:
    -
    """
    message_content = [{"type": "text", "text": prompt}] + images
    response = client_1.chat.complete(
        model="pixtral-12b-2409",
        messages=[{"role": "user", "content": message_content}]
    )
    return response.choices[0].message.content
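# The model is prompted to return a plain "-"-prefixed bullet list (illustrative
# output):
#   - protein folding
#   - molecular dynamics
# init() below splits the response line by line and strips the leading bullets.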
def search_relevant_articles_arxiv(key_topics, max_articles=100):
    articles_by_topic = {}
    final_topics = []

    def fetch_articles_for_topic(topic):
        topic_articles = []
        try:
            # Fetch articles from arXiv (via the arxiv package) for this topic
            search = arxiv.Search(
                query=topic,
                max_results=max_articles,
                sort_by=arxiv.SortCriterion.Relevance
            )
            for result in search.results():
                article_data = {
                    "title": result.title,
                    "doi": result.doi,
                    "summary": result.summary,
                    "url": result.entry_id,
                    "pdf_url": result.pdf_url
                }
                topic_articles.append(article_data)
            final_topics.append(topic)
        except Exception as e:
            print(f"Error fetching articles for topic '{topic}': {e}")
        return topic, topic_articles

    with concurrent.futures.ThreadPoolExecutor(max_workers=10) as executor:
        # Use threads to fetch articles for each topic in parallel
        futures = {executor.submit(fetch_articles_for_topic, topic): topic for topic in key_topics}
        for future in concurrent.futures.as_completed(futures):
            topic, articles = future.result()
            if articles:
                articles_by_topic[topic] = articles

    return articles_by_topic, list(set(final_topics))
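# Example return value (sketch; keys match article_data above, values illustrative):
#   articles_by_topic = {
#       "protein folding": [
#           {"title": "...", "doi": None, "summary": "...",
#            "url": "http://arxiv.org/abs/xxxx.xxxxx", "pdf_url": "..."},
#       ],
#   }
#   final_topics = ["protein folding"]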
# Initialize process for text analysis
def init(content, images=[]):
    key_topics = extract_key_topics(content, images)
    key_topics = [topic.strip("- ") for topic in key_topics.split("\n") if topic]
    articles_by_topic, final_topics = search_relevant_articles_arxiv(key_topics)
    result_json = json.dumps(articles_by_topic, indent=4)
    return final_topics, result_json
# Summarization function
def process_article_for_summary(text, images=[], compression_percentage=30):
    prompt = f"""
    You are a commentator.
    # article:
    {text}
    # Instructions:
    ## Summarize:
    In clear and concise language, summarize the key points and themes of the article, shortening it by {compression_percentage} percent, and format the summary in Markdown.
    """
    message_content = [{"type": "text", "text": prompt}] + images
    response = client_3.chat.complete(
        model="pixtral-12b-2409",
        messages=[{"role": "user", "content": message_content}]
    )
    return response.choices[0].message.content
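# Example call (sketch): summary = process_article_for_summary(text, images, 30)
# The returned string is Markdown produced by the model; compression_percentage
# is passed straight into the prompt and is not enforced programmatically.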
# Question answering function
def ask_question_to_mistral(text, question, images=[]):
    prompt = f"Answer the following question in Markdown, without restating the question or repeating the original text it refers to. ANSWER IN RUSSIAN:\nQuestion: {question}\n\nText:\n{text}"
    search_tool, tool = setup_search(question)
    context = ''
    if search_tool:
        if tool == 'tavily_tool':
            for result in search_tool:
                context += f"{result.get('url', 'N/A')} : {result.get('content', 'No content')} \n"
        elif tool == 'jina_tool':
            for result in search_tool:
                context += f"{result.get('link', 'N/A')} : {result.get('snippet', 'No snippet')} : {result.get('content', 'No content')} \n"
    # Append the web-search context to the text part of the message instead of
    # stringifying the whole list, so images are still sent as image_url parts.
    message_content = [{"type": "text", "text": f"{prompt}\n\nAdditional Context from Web Search:\n{context}"}] + images
    response = client_2.chat.complete(
        model="pixtral-12b-2409",
        messages=[{"role": "user", "content": message_content}]
    )
    return response.choices[0].message.content
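# Example call (sketch): answer = ask_question_to_mistral(text, "What is studied?", images)
# Any web-search snippets are appended to the prompt as extra context; the model
# is instructed to answer in Russian, in Markdown.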
# Gradio interface
def gradio_interface(file, task, question, compression_percentage):
    if file:
        text, images = process_file(file.name)
    else:
        text, images = "", []
    topics, articles_json = init(text, images)
    if task == "Summarization":
        summary = process_article_for_summary(text, images, compression_percentage)
        return {"Topics": topics, "Summary": summary, "Articles": articles_json}
    elif task == "Question Answering":
        if question:
            answer = ask_question_to_mistral(text, question, images)
            return {"Topics": topics, "Answer": answer, "Articles": articles_json}
        else:
            return {"Topics": topics, "Answer": "No question provided.", "Articles": articles_json}
with gr.Blocks() as demo:
    gr.Markdown("## Text Analysis: Summarization or Question Answering")
    with gr.Row():
        file_input = gr.File(label="Upload File")
        task_choice = gr.Radio(["Summarization", "Question Answering"], label="Select Task")
        question_input = gr.Textbox(label="Question (for Question Answering)", visible=False)
        compression_input = gr.Slider(label="Compression Percentage (for Summarization)", minimum=10, maximum=90, value=30, visible=False)
    task_choice.change(lambda choice: (gr.update(visible=choice == "Question Answering"),
                                       gr.update(visible=choice == "Summarization")),
                       inputs=task_choice, outputs=[question_input, compression_input])
    with gr.Row():
        result_output = gr.JSON(label="Results")
        submit_button = gr.Button("Submit")
    submit_button.click(gradio_interface, [file_input, task_choice, question_input, compression_input], result_output)

demo.launch()
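# To run locally (sketch; package names inferred from the imports above, and the
# Tavily/Jina tools may pull in additional dependencies):
#   pip install gradio mistralai langchain-community arxiv PyMuPDF python-docx Pillow
#   python app.py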