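# NOTE: "Spaces: Running" -- page-status text that appears to have been captured with the file; it is not part of the program.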
import html
import json
import os
import re
import traceback
from datetime import datetime
from io import BytesIO

import google.generativeai as genai
import gradio as gr
import openai
import requests
from bs4 import BeautifulSoup
from openai import OpenAI
from PyPDF2 import PdfReader
from transformers import AutoTokenizer
# Cache for tokenizers to avoid reloading them on every chat message.
tokenizer_cache = {}

# Registry of the LLM providers offered in the UI, keyed by the name shown in
# the provider dropdown. Fields of each entry:
#   name             - provider label
#   logo             - URL of the image rendered beside the model picker
#   endpoint         - base URL passed to the OpenAI-compatible client
#   api_key_env_var  - environment variable read when the user enters no token
#   models           - identifiers listed in the model dropdown; the first one
#                      is preselected
#   type             - value stored in the `default_type` UI state
#   max_total_tokens - token budget; kept as a string and converted with
#                      int() before use
PROVIDERS = {
    "Gemini": {
        "name": "Gemini",
        "logo": "https://www.gstatic.com/lamda/images/gemini_thumbnail_c362e5eadc46ca9f617e2.png",
        "endpoint": "https://example-gemini-endpoint",  # placeholder, not needed
        # Not necessarily needed for Gemini since we use google.generativeai directly
        "api_key_env_var": "GEMINI_API_KEY",  # If using env vars for key storage
        "models": [
            "gemini-2.0-flash-exp",
            "gemini-1.5-flash",
        ],
        "type": "tuples",
        "max_total_tokens": "50000",
    },
    "SambaNova": {
        "name": "SambaNova",
        "logo": "https://venturebeat.com/wp-content/uploads/2020/02/SambaNovaLogo_H_F.jpg",
        "endpoint": "https://api.sambanova.ai/v1/",
        "api_key_env_var": "SAMBANOVA_API_KEY",
        "models": [
            "Meta-Llama-3.1-70B-Instruct",
            "Meta-Llama-3.3-70B-Instruct",
        ],
        "type": "tuples",
        "max_total_tokens": "50000",
    },
    "Hyperbolic": {
        "name": "hyperbolic",
        "logo": "https://www.nftgators.com/wp-content/uploads/2024/07/Hyperbolic.jpg",
        "endpoint": "https://api.hyperbolic.xyz/v1",
        "api_key_env_var": "HYPERBOLIC_API_KEY",
        "models": [
            "meta-llama/Llama-3.3-70B-Instruct",
            "meta-llama/Meta-Llama-3.1-405B-Instruct",
        ],
        "type": "tuples",
        "max_total_tokens": "50000",
    },
}
# Functions for paper fetching
def fetch_paper_info_neurips(paper_id):
    """Fetch title, authors and abstract of a paper from its OpenReview forum page.

    Args:
        paper_id: OpenReview forum id, interpolated into the forum URL.

    Returns:
        A ``(title, author_list, details)`` tuple of strings. ``details`` is a
        Markdown snippet holding the abstract and a link back to the forum
        page. Elements that cannot be located in the page are replaced by a
        "... not found" placeholder. ``(None, None, None)`` is returned when
        the page cannot be fetched (network failure or non-200 status).
    """
    url = f"https://openreview.net/forum?id={paper_id}"
    try:
        # Without a timeout a stalled connection would block the UI callback
        # indefinitely.
        response = requests.get(url, timeout=30)
    except requests.RequestException as e:
        # Treat network failures like a non-200 answer so callers only have
        # one failure shape to handle.
        print(f"Error fetching paper info: {e}")
        return None, None, None
    if response.status_code != 200:
        return None, None, None

    soup = BeautifulSoup(response.content, 'html.parser')

    # Extract title
    title_tag = soup.find('h2', class_='citation_title')
    title = title_tag.get_text(strip=True) if title_tag else 'Title not found'

    # Extract authors
    author_div = soup.find('div', class_='forum-authors')
    authors = [tag.get_text(strip=True) for tag in author_div.find_all('a')] if author_div else []
    author_list = ', '.join(authors) if authors else 'Authors not found'

    # Extract abstract: the text lives in the <div> following the
    # "Abstract:" label. `string=` replaces the deprecated `text=` argument.
    abstract_label = soup.find('strong', string='Abstract:')
    abstract_paragraph = abstract_label.find_next_sibling('div') if abstract_label else None
    abstract = abstract_paragraph.get_text(strip=True) if abstract_paragraph else 'Abstract not found'

    return title, author_list, f"**Abstract:** {abstract}\n\n[View on OpenReview]({url})"
def fetch_paper_content_neurips(paper_id):
    """Download a paper's PDF from OpenReview and return its extracted text.

    Args:
        paper_id: OpenReview id, interpolated into the PDF URL.

    Returns:
        The concatenated text of all pages, or ``None`` when the download or
        the PDF parsing fails.
    """
    try:
        url = f"https://openreview.net/pdf?id={paper_id}"
        # Without a timeout a stalled connection would block the caller
        # indefinitely.
        response = requests.get(url, timeout=30)
        response.raise_for_status()
        reader = PdfReader(BytesIO(response.content))
        # `or ""` keeps a page without extractable text from discarding the
        # whole document.
        return "".join(page.extract_text() or "" for page in reader.pages)
    except Exception as e:
        # Best-effort fetch: report the failure and signal "no content".
        # `Exception` (not a bare except) lets KeyboardInterrupt/SystemExit
        # propagate.
        print(f"Error fetching paper content: {e}")
        return None
def fetch_paper_content_arxiv(paper_id):
    """Download a paper's PDF from arXiv and return its extracted text.

    Args:
        paper_id: arXiv identifier, interpolated into the PDF URL.

    Returns:
        The concatenated text of all pages, or ``None`` when the download or
        the PDF parsing fails (the error is printed).
    """
    try:
        url = f"https://arxiv.org/pdf/{paper_id}.pdf"
        # Without a timeout a stalled connection would block the caller
        # indefinitely.
        response = requests.get(url, timeout=30)
        response.raise_for_status()
        reader = PdfReader(BytesIO(response.content))
        # `or ""` keeps a page without extractable text from discarding the
        # whole document.
        return "".join(page.extract_text() or "" for page in reader.pages)
    except Exception as e:
        print(f"Error fetching paper content: {e}")
        return None
def fetch_paper_info_paperpage(paper_id_value):
    """Fetch title, authors and summary of a paper from the Hugging Face papers API.

    Args:
        paper_id_value: Either a bare ``<digits>.<digits>`` id or a
            ``https://huggingface.co/papers/<id>`` URL.

    Returns:
        A ``(title, authors, details)`` tuple of strings, where ``details`` is
        an HTML snippet with the summary, the upvote and comment counts and a
        link to the paper page. ``(None, None, None)`` is returned when the
        API cannot be reached, answers with a non-200 status, or returns a
        body that is not valid JSON.
    """
    def extract_paper_id(input_string):
        # Accept a bare id as-is, pull the id out of a paper-page URL, and
        # otherwise fall back to the stripped input.
        stripped = input_string.strip()
        if re.fullmatch(r'\d+\.\d+', stripped):
            return stripped
        match = re.search(r'https://huggingface\.co/papers/(\d+\.\d+)', input_string)
        if match:
            return match.group(1)
        return stripped

    paper_id_value = extract_paper_id(paper_id_value)
    url = f"https://huggingface.co/api/papers/{paper_id_value}?field=comments"
    try:
        # Without a timeout a stalled connection would block the UI callback
        # indefinitely.
        response = requests.get(url, timeout=30)
    except requests.RequestException as e:
        print(f"Error fetching paper info: {e}")
        return None, None, None
    if response.status_code != 200:
        return None, None, None
    try:
        paper_info = response.json()
    except ValueError as e:
        # A 200 answer with a non-JSON body is handled like any other failure.
        print(f"Error decoding paper info: {e}")
        return None, None, None

    title = paper_info.get('title', 'No Title')
    authors = ', '.join(author.get('name', 'Unknown') for author in paper_info.get('authors', []))
    summary = paper_info.get('summary', 'No Summary')
    num_comments = len(paper_info.get('comments', []))
    num_upvotes = paper_info.get('upvotes', 0)
    link = f"https://huggingface.co/papers/{paper_id_value}"
    # The thumbs-up shows the upvotes and the speech bubble the comments
    # (the two counts used to be swapped).
    details = f"{summary}<br/>👍{num_upvotes} 💬{num_comments}<br/> <a href='{link}' " \
              f"target='_blank'>View on 🤗 hugging face</a>"
    return title, authors, details
def fetch_paper_content_paperpage(paper_id_value):
    """Return the text of the paper behind a Hugging Face paper page.

    ``paper_id_value`` may be a bare ``<digits>.<digits>`` id or a
    ``https://huggingface.co/papers/<id>`` URL. The id is extracted and the
    text is fetched with ``fetch_paper_content_arxiv``, whose result is
    returned unchanged.
    """
    def extract_paper_id(input_string):
        # A bare id wins; next an id embedded in a paper-page URL; otherwise
        # the stripped input is passed through untouched.
        stripped = input_string.strip()
        if re.fullmatch(r'\d+\.\d+', stripped):
            return stripped
        url_match = re.search(r'https://huggingface\.co/papers/(\d+\.\d+)', input_string)
        return url_match.group(1) if url_match else stripped

    return fetch_paper_content_arxiv(extract_paper_id(paper_id_value))
# Dispatch table from a paper-source identifier (the `paper_from` value) to
# the functions that fetch that source's metadata ("fetch_info") and its full
# text ("fetch_pdf").
PAPER_SOURCES = {
    "neurips": {
        "fetch_info": fetch_paper_info_neurips,
        "fetch_pdf": fetch_paper_content_neurips
    },
    "paper_page": {
        "fetch_info": fetch_paper_info_paperpage,
        "fetch_pdf": fetch_paper_content_paperpage
    }
}
def create_chat_interface(provider_dropdown, model_dropdown, paper_content, hf_token_input, default_type,
                          provider_max_total_tokens):
    """Build the chat widget that answers questions about the loaded paper.

    Args:
        provider_dropdown: Component whose value is a key of ``PROVIDERS``.
        model_dropdown: Component whose value is the model identifier sent to
            the provider.
        paper_content: State holding the paper text used as conversation
            context.
        hf_token_input: Textbox holding an optional user-supplied API token.
        default_type: Not used by this function; kept so existing callers
            keep working.
        provider_max_total_tokens: State holding the token budget for context
            plus conversation.

    Returns:
        A ``(chat_interface, chatbot)`` tuple with the ``gr.ChatInterface``
        and the ``gr.Chatbot`` it renders into.
    """
    # Token counting always uses this one tokenizer, whatever provider/model
    # is selected, so it is cached under its own name and loaded only once.
    tokenizer_name = "meta-llama/Llama-3.2-1B-Instruct"

    def get_fn(message, history, paper_content_value, hf_token_value, provider_name_value, model_name_value,
               max_total_tokens):
        """Stream the model's answer to ``message``, yielding the text accumulated so far."""
        provider_info = PROVIDERS[provider_name_value]
        endpoint = provider_info['endpoint']
        api_key_env_var = provider_info['api_key_env_var']
        max_total_tokens = int(max_total_tokens)
        is_gemini = provider_name_value == "Gemini"

        tokenizer = tokenizer_cache.get(tokenizer_name)
        if tokenizer is None:
            tokenizer = AutoTokenizer.from_pretrained(tokenizer_name, token=os.environ.get("HF_TOKEN"))
            tokenizer_cache[tokenizer_name] = tokenizer

        if paper_content_value:
            context = f"The discussion is about the following paper:\n{paper_content_value}\n\n"
        else:
            context = ""
        context_tokens = tokenizer.encode(context)
        context_token_length = len(context_tokens)

        # `messages` and `message_tokens_list` are kept index-aligned so that
        # dropping a message also drops its token count.
        messages = []
        message_tokens_list = []
        total_tokens = context_token_length

        def add_message(role, text):
            nonlocal total_tokens
            token_count = len(tokenizer.encode(text))
            messages.append({"role": role, "content": text})
            message_tokens_list.append(token_count)
            total_tokens += token_count

        # Reconstruct the conversation from history and current user message
        for user_msg, assistant_msg in history:
            add_message("user", user_msg)
            if assistant_msg:
                add_message("assistant", assistant_msg)
        add_message("user", message)

        # Token truncation logic: shrink the paper context first, and only if
        # the conversation alone is still too long drop the oldest messages
        # (always keeping the newest one).
        if total_tokens > max_total_tokens:
            available_tokens = max_total_tokens - (total_tokens - context_token_length)
            if available_tokens > 0:
                context = tokenizer.decode(context_tokens[:available_tokens])
                total_tokens = total_tokens - context_token_length + available_tokens
            else:
                context = ""
                total_tokens -= context_token_length
            while total_tokens > max_total_tokens and len(messages) > 1:
                messages.pop(0)
                total_tokens -= message_tokens_list.pop(0)

        final_messages = []
        if context:
            # The context goes in as a plain "user" turn for Gemini and as a
            # "system" message for the OpenAI-compatible providers.
            final_messages.append({"role": "user" if is_gemini else "system", "content": f"{context}"})
        final_messages.extend(messages)

        api_key = hf_token_value or os.environ.get(api_key_env_var)
        if not api_key:
            raise ValueError("API token is not provided.")

        # Gemini logic
        if is_gemini:
            genai.configure(api_key=api_key)
            model = genai.GenerativeModel(model_name=model_name_value)
            # Gemini content entries use the roles "user" and "model"; the
            # OpenAI-style "assistant" role is rejected, so earlier answers
            # must be relabelled or every multi-turn chat fails.
            gemini_messages = [
                {"role": "model" if m["role"] == "assistant" else "user", "parts": [m["content"]]}
                for m in final_messages
            ]
            try:
                response = model.generate_content(gemini_messages, stream=True)
                response_text = ""
                for chunk in response:
                    if chunk.text:
                        response_text += chunk.text
                        yield response_text
            except Exception as ex:
                yield f"Error calling Gemini: {ex}"
        else:
            # Default OpenAI-compatible logic
            client = OpenAI(
                base_url=endpoint,
                api_key=api_key,
            )
            try:
                completion = client.chat.completions.create(
                    model=model_name_value,
                    messages=final_messages,
                    stream=True,
                )
                response_text = ""
                for chunk in completion:
                    # A chunk may carry no choices (e.g. a usage-only chunk);
                    # skip it instead of failing on `choices[0]`.
                    if not chunk.choices:
                        continue
                    response_text += chunk.choices[0].delta.content or ""
                    yield response_text
            except json.JSONDecodeError as e:
                yield f"JSON decoding error: {e.msg}"
            except openai.OpenAIError as openai_err:
                yield f"OpenAI error: {openai_err}"
            except Exception as ex:
                yield f"Unexpected error: {ex}"

    chatbot = gr.Chatbot(label="Chatbot", scale=1, height=800, autoscroll=True)
    chat_interface = gr.ChatInterface(
        fn=get_fn,
        chatbot=chatbot,
        additional_inputs=[paper_content, hf_token_input, provider_dropdown, model_dropdown, provider_max_total_tokens],
        type="tuples",
    )
    return chat_interface, chatbot
def paper_chat_tab(paper_id, paper_from, paper_central_df):
    """Build the "chat with a paper" tab inside the enclosing ``gr.Blocks``.

    Args:
        paper_id: Component holding the id of the loaded paper. Its
            ``change`` event loads the paper and shows or hides the provider
            section, the chat row and the "Chat with another paper" button.
        paper_from: Component holding the key into ``PAPER_SOURCES``.
        paper_central_df: Object exposing ``df_prettified`` (read for the
            ``title``, ``paper_page`` and ``date`` columns) and ``df_raw``
            (read for the ``arxiv_id`` and ``paper_page`` columns).
    """
    # A top-level button to "Chat with another paper" (visible only if paper_id is set),
    # placed above everything else.
    chat_another_button = gr.Button("Chat with another paper", variant="primary", visible=False)

    # First row with two columns
    with gr.Row():
        # Left column: Paper selection and display
        with gr.Column(scale=1):
            todays_date = datetime.today().strftime('%Y-%m-%d')
            # Filter papers for today's date and having a paper_page
            selectable_papers = paper_central_df.df_prettified
            selectable_papers = selectable_papers[
                selectable_papers['paper_page'].notna() &
                (selectable_papers['paper_page'] != "") &
                (selectable_papers['date'] == todays_date)
            ]
            paper_choices = sorted(
                zip(selectable_papers['title'], selectable_papers['paper_page']),
                key=lambda choice: choice[0],
            )
            if not paper_choices:
                paper_choices = [("No available papers for today", "")]
            paper_select = gr.Dropdown(
                label="Select a paper to chat with: (from today's 🤗 hugging face paper page)",
                choices=[p[0] for p in paper_choices],
                value=paper_choices[0][0]
            )
            # Add a textbox for user to enter a paper_id (arxiv_id)
            paper_id_input = gr.Textbox(
                label="Or enter a 🤗 paper_id directly",
                placeholder="e.g. 1234.56789"
            )
            select_paper_button = gr.Button("Load this paper")
            # Paper info display
            content = gr.HTML(value="", elem_id="paper_info_card")

        # Right column: Provider and model selection
        with gr.Column(scale=1, visible=False) as provider_section:
            gr.Markdown("### LLM Provider and Model")
            provider_names = list(PROVIDERS.keys())
            default_provider = provider_names[0]
            default_type = gr.State(value=PROVIDERS[default_provider]["type"])
            default_max_total_tokens = gr.State(value=PROVIDERS[default_provider]["max_total_tokens"])
            provider_dropdown = gr.Dropdown(
                label="Select Provider",
                choices=provider_names,
                value=default_provider
            )
            hf_token_input = gr.Textbox(
                label=f"Enter your {default_provider} API token (optional)",
                type="password",
                placeholder=f"Enter your {default_provider} API token to avoid rate limits"
            )
            model_dropdown = gr.Dropdown(
                label="Select Model",
                choices=PROVIDERS[default_provider]['models'],
                value=PROVIDERS[default_provider]['models'][0]
            )
            logo_html = gr.HTML(
                value=f'<img src="{PROVIDERS[default_provider]["logo"]}" width="100px" />'
            )
            note_markdown = gr.Markdown(f"**Note:** This model is supported by {default_provider}.")
            paper_content = gr.State()

    # Now a new row, full width, for the chat
    with gr.Row(visible=False) as chat_row:
        with gr.Column():
            # Create chat interface below the two columns
            chat_interface, chatbot = create_chat_interface(provider_dropdown, model_dropdown, paper_content,
                                                            hf_token_input, default_type, default_max_total_tokens)

    def update_provider(selected_provider):
        """Refresh every provider-dependent widget and clear the chat."""
        provider_info = PROVIDERS[selected_provider]
        models = provider_info['models']
        return (
            gr.update(choices=models, value=models[0]),
            gr.update(value=f'<img src="{provider_info["logo"]}" width="100px" />'),
            gr.update(value=f"**Note:** This model is supported by {selected_provider}."),
            gr.update(
                label=f"Enter your {selected_provider} API token (optional)",
                placeholder=f"Enter your {selected_provider} API token to avoid rate limits"
            ),
            provider_info['type'],
            provider_info['max_total_tokens'],
            [],  # chatbot reset
        )

    provider_dropdown.change(
        fn=update_provider,
        inputs=provider_dropdown,
        outputs=[model_dropdown, logo_html, note_markdown, hf_token_input, default_type, default_max_total_tokens,
                 chatbot],
        queue=False
    )

    def update_paper_info(paper_id_value, paper_from_value, selected_model, old_content):
        """Load the paper's metadata and text; returns (info card, paper text, empty chat).

        ``selected_model`` and ``old_content`` are not used; they are accepted
        because the listener passes them as inputs.
        """
        source_info = PAPER_SOURCES.get(paper_from_value, {})
        fetch_info_fn = source_info.get("fetch_info")
        fetch_pdf_fn = source_info.get("fetch_pdf")
        if not fetch_info_fn or not fetch_pdf_fn:
            return gr.update(value="<div>No information available.</div>"), None, []

        title, authors, details = fetch_info_fn(paper_id_value)
        if title is None and authors is None and details is None:
            return gr.update(value="<div>No information could be retrieved.</div>"), None, []

        text = fetch_pdf_fn(paper_id_value)
        if text is None:
            text = "Paper content could not be retrieved."

        # Title and authors are remote plain text, so they are escaped before
        # being placed in markup. `details` is inserted as-is because the
        # fetchers build it as markup on purpose.
        card_html = f"""
        <div style="border:1px solid #ccc; border-radius:6px; background:#f9f9f9; padding:15px; margin-bottom:10px;">
        <center><h3 style="margin-top:0; text-decoration:underline;">You are talking with:</h3></center>
        <h3>{html.escape(str(title))}</h3>
        <p><strong>Authors:</strong> {html.escape(str(authors))}</p>
        <p>{details}</p>
        </div>
        """
        return gr.update(value=card_html), text, []

    def select_paper(paper_title, paper_id_val):
        """Resolve the user's choice to ``(paper_id, paper_from)``; ``("", "")`` means nothing valid was chosen."""
        # If user provided a paper_id_val (arxiv_id), use that
        if paper_id_val and paper_id_val.strip():
            wanted_id = paper_id_val.strip()
            df = paper_central_df.df_raw
            # Both outputs are plain values (paper_id, paper_from). A dataset
            # without an `arxiv_id` column is reported like any other failed
            # lookup instead of writing an HTML message into `paper_id`.
            if 'arxiv_id' not in df.columns:
                return "", ""
            found = df[
                (df['arxiv_id'] == wanted_id) &
                df['paper_page'].notna() & (df['paper_page'] != "")
            ]
            if len(found) > 0:
                return wanted_id, "paper_page"
            # Not found: the chained error check reports it in `content`.
            return "", ""

        # Fallback to dropdown selection
        for t, ppage in paper_choices:
            if t == paper_title:
                return ppage, "paper_page"
        return "", ""

    # If the selection yielded no paper_id, show an error message
    def check_paper_id_error(p_id, p_from):
        if not p_id:
            return gr.update(value="<div style='color:red;'>No valid paper found for the given input.</div>")
        return gr.update()

    # The error check is chained with `.then()` so that it reads the paper_id
    # written by `select_paper`; as two independent click handlers it could
    # run first and see the previous value.
    select_paper_button.click(
        fn=select_paper,
        inputs=[paper_select, paper_id_input],
        outputs=[paper_id, paper_from]
    ).then(
        fn=check_paper_id_error,
        inputs=[paper_id, paper_from],
        outputs=[content],
        queue=False
    )

    # After the paper_id/paper_from are set, we update paper info
    paper_id.change(
        fn=update_paper_info,
        inputs=[paper_id, paper_from, model_dropdown, content],
        outputs=[content, paper_content, chatbot]
    )

    def toggle_paper_sections(paper_id_value):
        """Show the provider section, chat row and reset button only while a paper is loaded."""
        is_loaded = bool(paper_id_value and paper_id_value.strip())
        return gr.update(visible=is_loaded), gr.update(visible=is_loaded), gr.update(visible=is_loaded)

    paper_id.change(
        fn=toggle_paper_sections,
        inputs=[paper_id],
        outputs=[provider_section, chat_row, chat_another_button]
    )

    # Button action to reset paper_id
    def reset_paper_id():
        # reset paper_id to "", the source to its default and clear the card
        return "", "neurips", gr.update(value="<div></div>")

    # When this button is clicked, we reset the paper_id and content
    chat_another_button.click(
        fn=reset_paper_id,
        outputs=[paper_id, paper_from, content]
    )
def main():
    """Launch a standalone demo of the paper chat tab backed by a one-row mock dataset."""

    class MockPaperCentral:
        """Minimal stand-in for the dataset object that ``paper_chat_tab`` reads."""

        def __init__(self):
            import pandas as pd
            data = {
                'date': [datetime.today().strftime('%Y-%m-%d')],
                'paper_page': ['1234.56789'],
                'arxiv_id': ['1234.56789'],  # adding arxiv_id column as user requested
                'title': ['An Example Paper']
            }
            self.df_prettified = pd.DataFrame(data)
            # `paper_chat_tab` looks up manually entered ids in `df_raw`;
            # without this attribute entering an id raises AttributeError.
            self.df_raw = self.df_prettified.copy()

    with gr.Blocks(css_paths="style.css") as demo:
        paper_id = gr.Textbox(label="Paper ID", value="")
        paper_from = gr.Radio(
            label="Paper Source",
            choices=["neurips", "paper_page"],
            value="neurips"
        )
        paper_central_df = MockPaperCentral()
        paper_chat_tab(paper_id, paper_from, paper_central_df)
    demo.launch(ssr_mode=False)


if __name__ == "__main__":
    main()