# test_biocad / app.py
# Author: Belemort. Last commit: "Update app.py" (410ba66, verified)
# Hugging Face file view: raw | history | blame | 18.4 kB
import gradio as gr
from mistralai import Mistral
from langchain_community.tools import TavilySearchResults, JinaSearch
import concurrent.futures
import json
import os
import arxiv
from PIL import Image
import io
import base64
from langchain.chains import MapReduceDocumentsChain, ReduceDocumentsChain
from langchain.text_splitter import CharacterTextSplitter
from langchain_mistralai import ChatMistralAI
from langchain.chains.combine_documents.stuff import StuffDocumentsChain
from langchain.chains.llm import LLMChain
from langchain_core.prompts import PromptTemplate
from transformers import AutoTokenizer
# Hugging Face tokenizer loaded from the "mistral-community/pixtral-12b" repo.
# It is used to count prompt tokens (count_tokens_in_text) and to size the
# chunks of the map-reduce text splitters (from_huggingface_tokenizer).
# Loaded once at import time; from_pretrained may download files on first use.
tokenizer = AutoTokenizer.from_pretrained("mistral-community/pixtral-12b")
def count_tokens_in_text(text):
    """Return the encoded length of *text* under the module-level ``tokenizer``.

    Special tokens are added and truncation is disabled, so the result is the
    full encoded length. The callers compare it against 128_000 to choose
    between the single-request and map-reduce code paths.
    """
    encoding = tokenizer(
        text,
        add_special_tokens=True,
        truncation=False,
        return_tensors="pt",
    )
    first_sequence = encoding["input_ids"][0]
    return len(first_sequence)
# API credentials.
#
# SECURITY: never hard-code API keys in this file. Any key that was ever
# committed here is public and must be revoked/rotated. Configure keys as
# environment variables (e.g. Hugging Face Space secrets) instead.
#
# TAVILY_API_KEY is read from the environment by TavilySearchResults itself.
# If it is missing, the Tavily call fails and setup_search falls back to Jina.


def _mistral_api_key(slot):
    """Return the Mistral API key for client *slot* (1-4).

    ``MISTRAL_API_KEY_<slot>`` is tried first, so each client can keep its own
    key. The shared ``MISTRAL_API_KEY`` is used as a fallback.

    Raises:
        RuntimeError: if neither variable is set. The app then fails at
            start-up with a clear message instead of on the first request.
    """
    key = os.environ.get(f"MISTRAL_API_KEY_{slot}") or os.environ.get("MISTRAL_API_KEY")
    if not key:
        raise RuntimeError(
            f"Missing Mistral API key: set MISTRAL_API_KEY_{slot} or MISTRAL_API_KEY."
        )
    return key


# Direct Mistral SDK clients, one per task.
client_1 = Mistral(api_key=_mistral_api_key(1))  # key-topic extraction
client_2 = Mistral(api_key=_mistral_api_key(2))  # question answering
client_3 = Mistral(api_key=_mistral_api_key(3))  # summarization
api_key_4 = _mistral_api_key(4)
# LangChain chat model used by the map-reduce chains for long texts.
client_4 = ChatMistralAI(api_key=api_key_4, model="pixtral-12b-2409")
def encode_image_bytes(image_bytes):
    """Return *image_bytes* as a base64 ASCII string, without a data-URL header."""
    encoded = base64.b64encode(image_bytes)
    return str(encoded, "utf-8")
def decode_base64_image(base64_str):
    """Decode a base64 string into a PIL image.

    The argument must be the bare base64 payload: a ``data:...;base64,``
    header is not stripped here. ``Image.open`` is lazy, so some decoding
    errors only surface when the pixels are first used (e.g. on ``save``).
    """
    stream = io.BytesIO(base64.b64decode(base64_str))
    return Image.open(stream)
def _split_base64_images(images_base64):
    """Turn the image field into a list of bare base64 payload strings.

    Accepts a single comma-separated string (the Gradio textbox format), an
    iterable of such strings, or a falsy value meaning "no images". A data URL
    contains a comma itself (``data:image/jpeg;base64,<payload>``), so
    splitting on commas separates each header from its payload. Header parts
    are dropped, which is safe because ``:`` never occurs in base64.
    """
    if not images_base64:
        return []
    raw_items = [images_base64] if isinstance(images_base64, str) else list(images_base64)
    payloads = []
    for item in raw_items:
        for part in item.split(","):
            part = part.strip()
            if part and not part.startswith("data:"):
                payloads.append(part)
    return payloads


def process_input(text_input, images_base64):
    """Normalize user input into ``(text, image content items)`` for Mistral.

    Args:
        text_input: User text, returned unchanged.
        images_base64: Comma-separated base64 images and/or data URLs (as
            typed into the Gradio textbox), an iterable of such strings, or a
            falsy value for "no images".

    Returns:
        ``(text_input, images)``. ``images`` holds one
        ``{"type": "image_url", "image_url": "data:image/jpeg;base64,..."}``
        dict per image that decoded successfully. Images that fail are
        reported with ``print`` and skipped.
    """
    images = []
    for payload in _split_base64_images(images_base64):
        try:
            img = decode_base64_image(payload)
            # JPEG cannot store alpha or palette modes (RGBA, LA, P, ...), so
            # convert them first; otherwise save() raises for e.g. most PNGs.
            if img.mode not in ("RGB", "L"):
                img = img.convert("RGB")
            buffered = io.BytesIO()
            img.save(buffered, format="JPEG")
            image_base64 = encode_image_bytes(buffered.getvalue())
        except Exception as e:  # best effort: one bad image must not fail the request
            print(f"Error decoding image: {e}")
            continue
        images.append({"type": "image_url", "image_url": f"data:image/jpeg;base64,{image_base64}"})
    return text_input, images
def setup_search(question):
    """Web-search *question*: Tavily first, Jina as the fallback.

    Returns:
        ``(results, tool_name)``. ``results`` is the list returned by the first
        backend that succeeds and ``tool_name`` is ``'tavily_tool'`` or
        ``'jina_tool'``. Returns ``([], '')`` when both backends raise or
        return something other than a list.
    """
    try:
        tavily_results = TavilySearchResults(max_results=20).invoke({"query": f"{question}"})
    except Exception as e:
        print("Error with TavilySearchResults:", e)
    else:
        if isinstance(tavily_results, list):
            return tavily_results, 'tavily_tool'

    try:
        # JinaSearch replies with text; parse it as JSON to get the result list.
        jina_raw = JinaSearch().invoke({"query": f"{question}"})
        jina_results = json.loads(str(jina_raw))
    except Exception as e:
        print("Error with JinaSearch:", e)
    else:
        if isinstance(jina_results, list):
            return jina_results, 'jina_tool'

    return [], ''
def extract_key_topics(content, images=None):
    """Ask Pixtral (via ``client_1``) for the main themes of *content*.

    Args:
        content: Text to analyse (callers route texts under 128k tokens here).
        images: Optional list of ``image_url`` content items sent with the
            prompt. At most the first 7 are sent, matching the cap applied by
            the summarization and question-answering calls.

    Returns:
        The raw model reply. The prompt asks for a ``-`` bulleted list, but
        the format is not validated here.
    """
    prompt = f"""
Extract the primary themes from the text below. List each theme in as few words as possible, focusing on essential concepts only. Format as a concise, unordered list with no extraneous words.
```{content}```
LIST IN ENGLISH:
-
"""
    attached_images = list(images or [])[:7]
    message_content = [{"type": "text", "text": prompt}] + attached_images
    response = client_1.chat.complete(
        model="pixtral-12b-2409",
        messages=[{"role": "user", "content": message_content}],
    )
    return response.choices[0].message.content
def extract_key_topics_with_large_text(content, images=None):
    """Extract the main themes of a text too long for one Pixtral request.

    The text is split into chunks of roughly 100k tokens (14k-token overlap,
    measured with the module ``tokenizer``). ``client_4`` turns each chunk
    into a theme list (map step), and a second prompt merges the partial
    lists (reduce step).

    Args:
        content: Text to analyse (callers route texts of 128k+ tokens here).
        images: Optional list of ``image_url`` content items. This chain sends
            plain text only, so each image is represented by a short
            placeholder label rather than its data.

    Returns:
        The theme list produced by the reduce step, as a string.
    """
    images = images or []

    # Map prompt: extract themes from a single chunk.
    map_template = f"""
Текст: {{docs}}
Изображения: {{images}}
Extract the primary themes from the text below. List each theme in as few words as possible, focusing on essential concepts only. Format as a concise, unordered list with no extraneous words.
LIST IN ENGLISH:
-
:"""
    map_prompt = PromptTemplate.from_template(map_template)
    map_chain = LLMChain(llm=client_4, prompt=map_prompt)

    # Reduce prompt: merge the per-chunk theme lists into one.
    reduce_template = f"""Следующий текст состоит из нескольких кратких итогов:
{{docs}}
Extract the primary themes from the text below. List each theme in as few words as possible, focusing on essential concepts only. Format as a concise, unordered list with no extraneous words.
LIST IN ENGLISH:
-
:"""
    reduce_prompt = PromptTemplate.from_template(reduce_template)
    reduce_chain = LLMChain(llm=client_4, prompt=reduce_prompt)

    # The same "stuff" chain both collapses intermediate results and produces
    # the final reduction.
    combine_documents_chain = StuffDocumentsChain(
        llm_chain=reduce_chain, document_variable_name="docs"
    )
    reduce_documents_chain = ReduceDocumentsChain(
        combine_documents_chain=combine_documents_chain,
        collapse_documents_chain=combine_documents_chain,
        token_max=128000,
    )
    map_reduce_chain = MapReduceDocumentsChain(
        llm_chain=map_chain,
        reduce_documents_chain=reduce_documents_chain,
        document_variable_name="docs",
        return_intermediate_steps=False,
    )

    text_splitter = CharacterTextSplitter.from_huggingface_tokenizer(
        tokenizer,
        chunk_size=100000,
        chunk_overlap=14000,
    )
    split_docs = text_splitter.create_documents([content])

    # Only label the images. Pasting their base64 data URLs into a text prompt
    # can add hundreds of thousands of characters per image, overflowing the
    # model context, and a text-only prompt cannot show the pixels anyway.
    image_descriptions = "\n".join(
        f"Изображение {i + 1}: [image data omitted]" for i in range(len(images))
    )

    key_topics = map_reduce_chain.run({"input_documents": split_docs, "images": image_descriptions})
    return key_topics
def search_relevant_articles_arxiv(key_topics, max_articles=100):
    """Search arXiv for every topic concurrently (up to 10 searches at once).

    Args:
        key_topics: Iterable of topic strings. Surrounding whitespace is
            stripped, blank topics are skipped, and duplicates are searched
            only once.
        max_articles: Maximum number of results requested per topic.

    Returns:
        ``(articles_by_topic, final_topics)``.
        - ``articles_by_topic`` maps each topic with at least one result to a
          list of dicts with keys ``title``, ``doi``, ``summary``, ``url`` and
          ``pdf_url``.
        - ``final_topics`` lists every topic whose search finished without an
          exception.
        Both follow the first-seen order of *key_topics*.
    """
    # dict.fromkeys de-duplicates while keeping first-seen order. A set would
    # return topics in an arbitrary, run-dependent order.
    topics = list(dict.fromkeys(
        topic.strip() for topic in key_topics if topic and topic.strip()
    ))

    def fetch_articles_for_topic(topic):
        """Return ``(topic, articles, completed)`` for one arXiv query.

        Errors are printed, not raised. Articles collected before a failure
        are still returned, with ``completed`` set to False.
        """
        topic_articles = []
        try:
            search = arxiv.Search(
                query=topic,
                max_results=max_articles,
                sort_by=arxiv.SortCriterion.Relevance,
            )
            # Search.results() is deprecated in arxiv>=2.0; Client.results()
            # is the supported replacement.
            for result in arxiv.Client().results(search):
                topic_articles.append({
                    "title": result.title,
                    "doi": result.doi,
                    "summary": result.summary,
                    "url": result.entry_id,
                    "pdf_url": result.pdf_url,
                })
        except Exception as e:
            print(f"Error fetching articles for topic '{topic}': {e}")
            return topic, topic_articles, False
        return topic, topic_articles, True

    found = {}
    completed = set()
    with concurrent.futures.ThreadPoolExecutor(max_workers=10) as executor:
        futures = [executor.submit(fetch_articles_for_topic, topic) for topic in topics]
        for future in concurrent.futures.as_completed(futures):
            topic, articles, ok = future.result()
            if articles:
                found[topic] = articles
            if ok:
                completed.add(topic)

    # Rebuild both results in input order; completion order is nondeterministic.
    articles_by_topic = {topic: found[topic] for topic in topics if topic in found}
    final_topics = [topic for topic in topics if topic in completed]
    return articles_by_topic, final_topics
def _parse_key_topics(raw_topics):
    """Split a bulleted model reply into clean topic strings.

    Bullet markers (``-``, ``*``, ``•``) and surrounding whitespace are stripped
    from both ends of each line. Lines left empty, such as a bare ``-``, are
    dropped.
    """
    topics = []
    for line in raw_topics.split("\n"):
        topic = line.strip("-*• \t\r")
        if topic:
            topics.append(topic)
    return topics


def init(content, images=None):
    """Extract key topics from *content* and look them up on arXiv.

    Texts under 128k tokens use a single Pixtral request
    (``extract_key_topics``). Longer texts use the map-reduce chain
    (``extract_key_topics_with_large_text``).

    Args:
        content: Text to analyse.
        images: Optional list of ``image_url`` content items.

    Returns:
        ``(final_topics, result_json)``. ``final_topics`` is the topic list
        returned by ``search_relevant_articles_arxiv``. ``result_json`` is its
        topic-to-articles mapping serialized with ``json.dumps(indent=4)``.
    """
    images = images or []
    if count_tokens_in_text(text=content) < 128_000:
        raw_topics = extract_key_topics(content, images)
    else:
        raw_topics = extract_key_topics_with_large_text(content, images)

    # Filter after stripping, so bare "-" lines do not become empty topics.
    key_topics = _parse_key_topics(raw_topics)
    articles_by_topic, final_topics = search_relevant_articles_arxiv(key_topics)
    return final_topics, json.dumps(articles_by_topic, indent=4)
def process_article_for_summary(text, images=None, compression_percentage=30):
    """Summarize *text* in Russian (Markdown) with a single Pixtral request.

    Args:
        text: Article text (callers route texts under 128k tokens here).
        images: Optional list of ``image_url`` content items. At most the
            first 7 are sent.
        compression_percentage: The percentage by which the prompt asks the
            model to shorten the article.

    Returns:
        The model's summary text.
    """
    prompt = f"""
You are a commentator.
# article:
{text}
# Instructions:
## Summarize IN RUSSIAN:
In clear and concise language, summarize the key points and themes presented in the article by cutting it by {compression_percentage} percent in the markdown format.
"""
    attached_images = list(images or [])[:7]
    message_content = [{"type": "text", "text": prompt}] + attached_images
    response = client_3.chat.complete(
        model="pixtral-12b-2409",
        messages=[{"role": "user", "content": message_content}],
    )
    return response.choices[0].message.content
def process_large_article_for_summary(text, images=None, compression_percentage=30):
    """Summarize a text too long for one Pixtral request, and find its topics.

    The text is split into chunks of roughly 100k tokens (14k-token overlap).
    ``client_4`` compresses each chunk (map step), and a second prompt merges
    the partial summaries (reduce step). Topic extraction plus the arXiv
    lookup (``init``) run concurrently in a worker thread.

    Args:
        text: Article text (callers route texts of 128k+ tokens here).
        images: Optional list of ``image_url`` content items. The map-reduce
            chain is text-only, so it sees only placeholder labels; ``init``
            still receives the full list.
        compression_percentage: Target compression in percent, stated in both
            prompts.

    Returns:
        ``(summary, key_topics, result_article_json)``. The last two come
        from ``init``.
    """
    images = images or []

    # Map prompt: compress one chunk. compression_percentage is baked in by
    # the f-string; {{docs}}/{{images}} stay as template placeholders.
    map_template = f"""Следующий текст состоит из текста и изображений:
Текст: {{docs}}
Изображения: {{images}}
На основе приведенного материала, выполните сжатие текста, выделяя основные темы и важные моменты.
Уровень сжатия: {compression_percentage}%.
Ответ предоставьте на русском языке в формате Markdown.
Полезный ответ:"""
    map_prompt = PromptTemplate.from_template(map_template)
    map_chain = LLMChain(llm=client_4, prompt=map_prompt)

    # Reduce prompt: merge the per-chunk summaries.
    reduce_template = f"""Следующий текст состоит из нескольких кратких итогов:
{{docs}}
На основе этих кратких итогов, выполните финальное сжатие текста, объединяя основные темы и ключевые моменты.
Уровень сжатия: {compression_percentage}%.
Результат предоставьте на русском языке в формате Markdown.
Полезный ответ:"""
    reduce_prompt = PromptTemplate.from_template(reduce_template)
    reduce_chain = LLMChain(llm=client_4, prompt=reduce_prompt)

    combine_documents_chain = StuffDocumentsChain(
        llm_chain=reduce_chain, document_variable_name="docs"
    )
    reduce_documents_chain = ReduceDocumentsChain(
        combine_documents_chain=combine_documents_chain,
        collapse_documents_chain=combine_documents_chain,
        token_max=128000,
    )
    map_reduce_chain = MapReduceDocumentsChain(
        llm_chain=map_chain,
        reduce_documents_chain=reduce_documents_chain,
        document_variable_name="docs",
        return_intermediate_steps=False,
    )

    text_splitter = CharacterTextSplitter.from_huggingface_tokenizer(
        tokenizer,
        chunk_size=100000,
        chunk_overlap=14000,
    )
    split_docs = text_splitter.create_documents([text])

    # Only label the images. Their base64 data URLs can add hundreds of
    # thousands of characters each to every map prompt, and a text-only
    # prompt cannot show the pixels to the model anyway.
    image_descriptions = "\n".join(
        f"Изображение {i + 1}: [image data omitted]" for i in range(len(images))
    )

    # Topic/article lookup runs in parallel with the summarization chain.
    with concurrent.futures.ThreadPoolExecutor() as executor:
        extract_future = executor.submit(init, text, images)
        summary = map_reduce_chain.run({"input_documents": split_docs, "images": image_descriptions})
        key_topics, result_article_json = extract_future.result()
    return summary, key_topics, result_article_json
def ask_question_to_mistral(text, question, images=None):
    """Answer *question* about *text* (in Russian, Markdown) with one Pixtral request.

    The question is also sent to ``setup_search`` (Tavily, falling back to
    Jina). The hits are appended to the prompt as extra context.

    Args:
        text: Source text (callers route texts under 128k tokens here).
        question: The user's question.
        images: Optional list of ``image_url`` content items. At most the
            first 7 are sent.

    Returns:
        The model's answer text.
    """
    prompt = f"Answer the following question without mentioning it or repeating the original text on which the question is asked in style markdown.IN RUSSIAN:\nQuestion: {question}\n\nText:\n{text}"
    attached_images = list(images or [])[:7]

    search_results, tool = setup_search(question)
    if tool == 'tavily_tool':
        context = "".join(
            f"{result.get('url', 'N/A')} : {result.get('content', 'No content')} \n"
            for result in search_results
        )
    elif tool == 'jina_tool':
        context = "".join(
            f"{result.get('link', 'N/A')} : {result.get('snippet', 'No snippet')} : {result.get('content', 'No content')} \n"
            for result in search_results
        )
    else:
        context = ''

    # Send one multimodal message: a text part (prompt + web context) followed
    # by the image parts. Formatting the whole content list into a single
    # string would send its Python repr, so images would arrive as base64 text.
    full_text = f"{prompt}\n\nAdditional Context from Web Search:\n{context}"
    message_content = [{"type": "text", "text": full_text}] + attached_images
    response = client_2.chat.complete(
        model="pixtral-12b-2409",
        messages=[{"role": "user", "content": message_content}],
    )
    return response.choices[0].message.content
def ask_question_to_mistral_with_large_text(text, question, images=None):
    """Answer a question about a text too long for one Pixtral request.

    The text is split into chunks of roughly 100k tokens (14k-token overlap).
    ``client_4`` answers the question for each chunk (map step), and a second
    prompt merges the partial answers (reduce step). Topic extraction plus the
    arXiv lookup (``init``) run concurrently in a worker thread.

    Args:
        text: Source text (callers route texts of 128k+ tokens here).
        question: The user's question.
        images: Optional list of ``image_url`` content items. The chain is
            text-only, so it sees only placeholder labels; ``init`` still
            receives the full list.

    Returns:
        ``(answer, key_topics, result_article_json)``. The last two come from
        ``init``.
    """
    images = images or []

    # These are plain (non-f) strings, so template placeholders take SINGLE
    # braces. With ``{{docs}}`` the braces are escaped: ``docs`` is then not
    # an input variable, and LangChain's combine-documents chains reject it
    # as document_variable_name.
    map_template = """Следующий текст содержит статью/произведение:
Текст: {docs}
Изображения: {images}
На основе приведенного текста, ответьте на следующий вопрос:
Вопрос: {question}
Ответ должен быть точным. Пожалуйста, ответьте на русском языке в формате Markdown.
Полезный ответ:"""
    # The reduce step also gets the question so it can merge answers purposefully.
    reduce_template = """Следующий текст содержит несколько кратких ответов на вопрос:
Вопрос: {question}
{docs}
Объедините их в финальный ответ. Ответ предоставьте на русском языке в формате Markdown.
Полезный ответ:"""

    map_prompt = PromptTemplate.from_template(map_template)
    map_chain = LLMChain(llm=client_4, prompt=map_prompt)
    reduce_prompt = PromptTemplate.from_template(reduce_template)
    reduce_chain = LLMChain(llm=client_4, prompt=reduce_prompt)

    combine_documents_chain = StuffDocumentsChain(
        llm_chain=reduce_chain, document_variable_name="docs"
    )
    reduce_documents_chain = ReduceDocumentsChain(
        combine_documents_chain=combine_documents_chain,
        collapse_documents_chain=combine_documents_chain,
        token_max=128000,
    )
    map_reduce_chain = MapReduceDocumentsChain(
        llm_chain=map_chain,
        reduce_documents_chain=reduce_documents_chain,
        document_variable_name="docs",
        return_intermediate_steps=False,
    )

    text_splitter = CharacterTextSplitter.from_huggingface_tokenizer(
        tokenizer,
        chunk_size=100000,
        chunk_overlap=14000,
    )
    split_docs = text_splitter.create_documents([text])

    # Only label the images. Their base64 data URLs can add hundreds of
    # thousands of characters each to every map prompt, and a text-only
    # prompt cannot show the pixels to the model anyway.
    image_descriptions = "\n".join(
        f"Изображение {i + 1}: [image data omitted]" for i in range(len(images))
    )

    with concurrent.futures.ThreadPoolExecutor() as executor:
        extract_future = executor.submit(init, text, images)
        answer = map_reduce_chain.run({"input_documents": split_docs, "question": question, "images": image_descriptions})
        key_topics, result_article_json = extract_future.result()
    return answer, key_topics, result_article_json
def gradio_interface(text_input, images_base64, task, question, compression_percentage):
    """Gradio callback: run the selected task and return a JSON-friendly dict.

    Texts under 128k tokens use single Pixtral requests. Longer texts use the
    map-reduce LangChain pipelines.

    Args:
        text_input: Text to analyse.
        images_base64: Raw value of the base64-images textbox.
        task: ``"Summarization"`` or ``"Question Answering"``.
        question: The question (Question Answering only).
        compression_percentage: Slider value (Summarization only).

    Returns:
        - ``{"Topics", "Summary", "Articles"}`` for summarization.
        - ``{"Topics", "Answer", "Articles"}`` for question answering.
        - ``None`` when no task is selected.
        "Articles" is a JSON string.
    """
    text, images = process_input(text_input, images_base64)

    if task == "Summarization":
        if count_tokens_in_text(text=text) < 128_000:
            topics, articles_json = init(text, images)
            summary = process_article_for_summary(text, images, compression_percentage)
            return {"Topics": topics, "Summary": summary, "Articles": articles_json}
        summary, key_topics, result_article_json = process_large_article_for_summary(text, images, compression_percentage)
        return {"Topics": key_topics, "Summary": summary, "Articles": result_article_json}

    if task == "Question Answering":
        if not question:
            # No topics/articles were computed on this path. Return empty
            # values instead of referencing undefined names (a NameError).
            return {"Topics": [], "Answer": "No question provided.", "Articles": "{}"}
        if count_tokens_in_text(text=text) < 128_000:
            topics, articles_json = init(text, images)
            answer = ask_question_to_mistral(text, question, images)
            return {"Topics": topics, "Answer": answer, "Articles": articles_json}
        # The first tuple element is the answer. Binding it to `answer` fixes
        # the NameError raised by the returned dict.
        answer, key_topics, result_article_json = ask_question_to_mistral_with_large_text(text, question, images)
        return {"Topics": key_topics, "Answer": answer, "Articles": result_article_json}

    # No task selected yet: nothing to compute.
    return None
with gr.Blocks() as demo:
    gr.Markdown("## Text Analysis: Summarization or Question Answering")

    # Inputs: free text plus optional base64-encoded images.
    with gr.Row():
        text_input = gr.Textbox(label="Input Text")
        images_base64 = gr.Textbox(
            label="Base64 Images (comma-separated, if any)",
            placeholder="data:image/jpeg;base64,...",
            lines=2,
        )

    task_choice = gr.Radio(["Summarization", "Question Answering"], label="Select Task")

    # Task-specific controls start hidden; choosing a task reveals the matching one.
    question_input = gr.Textbox(label="Question (for Question Answering)", visible=False)
    compression_input = gr.Slider(
        label="Compression Percentage (for Summarization)",
        minimum=10,
        maximum=90,
        value=30,
        visible=False,
    )
    task_choice.change(
        lambda choice: (
            gr.update(visible=choice == "Question Answering"),
            gr.update(visible=choice == "Summarization"),
        ),
        inputs=task_choice,
        outputs=[question_input, compression_input],
    )

    with gr.Row():
        result_output = gr.JSON(label="Results")

    submit_button = gr.Button("Submit")
    submit_button.click(
        gradio_interface,
        [text_input, images_base64, task_choice, question_input, compression_input],
        result_output,
    )

demo.launch(show_error=True)