Spaces:
Running
Running
minor
Browse files- paper_chat_tab.py +16 -37
paper_chat_tab.py
CHANGED
@@ -73,7 +73,6 @@ def fetch_paper_info_neurips(paper_id):
|
|
73 |
else:
|
74 |
abstract = 'Abstract not found'
|
75 |
|
76 |
-
# Construct preamble
|
77 |
link = f"https://openreview.net/forum?id={paper_id}"
|
78 |
return title, author_list, f"**Abstract:** {abstract}\n\n[View on OpenReview]({link})"
|
79 |
|
@@ -110,12 +109,9 @@ def fetch_paper_content_arxiv(paper_id):
|
|
110 |
|
111 |
|
112 |
def fetch_paper_info_paperpage(paper_id_value):
|
113 |
-
# Extract paper_id from paper_page link or input
|
114 |
def extract_paper_id(input_string):
|
115 |
-
# Already in correct form?
|
116 |
if re.fullmatch(r'\d+\.\d+', input_string.strip()):
|
117 |
return input_string.strip()
|
118 |
-
# If URL
|
119 |
match = re.search(r'https://huggingface\.co/papers/(\d+\.\d+)', input_string)
|
120 |
if match:
|
121 |
return match.group(1)
|
@@ -141,7 +137,6 @@ def fetch_paper_info_paperpage(paper_id_value):
|
|
141 |
|
142 |
|
143 |
def fetch_paper_content_paperpage(paper_id_value):
|
144 |
-
# Extract paper_id
|
145 |
def extract_paper_id(input_string):
|
146 |
if re.fullmatch(r'\d+\.\d+', input_string.strip()):
|
147 |
return input_string.strip()
|
@@ -155,7 +150,6 @@ def fetch_paper_content_paperpage(paper_id_value):
|
|
155 |
return text
|
156 |
|
157 |
|
158 |
-
# Dictionary for paper sources
|
159 |
PAPER_SOURCES = {
|
160 |
"neurips": {
|
161 |
"fetch_info": fetch_paper_info_neurips,
|
@@ -170,16 +164,13 @@ PAPER_SOURCES = {
|
|
170 |
|
171 |
def create_chat_interface(provider_dropdown, model_dropdown, paper_content, hf_token_input, default_type,
|
172 |
provider_max_total_tokens):
|
173 |
-
# Define the function to handle the chat
|
174 |
def get_fn(message, history, paper_content_value, hf_token_value, provider_name_value, model_name_value,
|
175 |
max_total_tokens):
|
176 |
provider_info = PROVIDERS[provider_name_value]
|
177 |
endpoint = provider_info['endpoint']
|
178 |
api_key_env_var = provider_info['api_key_env_var']
|
179 |
-
models = provider_info['models']
|
180 |
max_total_tokens = int(max_total_tokens)
|
181 |
|
182 |
-
# Load tokenizer
|
183 |
tokenizer_key = f"{provider_name_value}_{model_name_value}"
|
184 |
if tokenizer_key not in tokenizer_cache:
|
185 |
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-1B-Instruct",
|
@@ -188,44 +179,36 @@ def create_chat_interface(provider_dropdown, model_dropdown, paper_content, hf_t
|
|
188 |
else:
|
189 |
tokenizer = tokenizer_cache[tokenizer_key]
|
190 |
|
191 |
-
# Include the paper content as context
|
192 |
if paper_content_value:
|
193 |
context = f"The discussion is about the following paper:\n{paper_content_value}\n\n"
|
194 |
else:
|
195 |
context = ""
|
196 |
|
197 |
-
# Tokenize the context
|
198 |
context_tokens = tokenizer.encode(context)
|
199 |
context_token_length = len(context_tokens)
|
200 |
|
201 |
-
# Prepare the messages without context
|
202 |
messages = []
|
203 |
message_tokens_list = []
|
204 |
-
total_tokens = context_token_length
|
205 |
|
206 |
for user_msg, assistant_msg in history:
|
207 |
-
# Tokenize user message
|
208 |
user_tokens = tokenizer.encode(user_msg)
|
209 |
messages.append({"role": "user", "content": user_msg})
|
210 |
message_tokens_list.append(len(user_tokens))
|
211 |
total_tokens += len(user_tokens)
|
212 |
|
213 |
-
# Tokenize assistant message
|
214 |
if assistant_msg:
|
215 |
assistant_tokens = tokenizer.encode(assistant_msg)
|
216 |
messages.append({"role": "assistant", "content": assistant_msg})
|
217 |
message_tokens_list.append(len(assistant_tokens))
|
218 |
total_tokens += len(assistant_tokens)
|
219 |
|
220 |
-
# Tokenize the new user message
|
221 |
message_tokens = tokenizer.encode(message)
|
222 |
messages.append({"role": "user", "content": message})
|
223 |
message_tokens_list.append(len(message_tokens))
|
224 |
total_tokens += len(message_tokens)
|
225 |
|
226 |
-
# Check if total tokens exceed the maximum allowed tokens
|
227 |
if total_tokens > max_total_tokens:
|
228 |
-
# Attempt to truncate context
|
229 |
available_tokens = max_total_tokens - (total_tokens - context_token_length)
|
230 |
if available_tokens > 0:
|
231 |
truncated_context_tokens = context_tokens[:available_tokens]
|
@@ -237,24 +220,20 @@ def create_chat_interface(provider_dropdown, model_dropdown, paper_content, hf_t
|
|
237 |
total_tokens -= context_token_length
|
238 |
context_token_length = 0
|
239 |
|
240 |
-
# Truncate message history if needed
|
241 |
while total_tokens > max_total_tokens and len(messages) > 1:
|
242 |
removed_message = messages.pop(0)
|
243 |
removed_tokens = message_tokens_list.pop(0)
|
244 |
total_tokens -= removed_tokens
|
245 |
|
246 |
-
# Rebuild the final messages
|
247 |
final_messages = []
|
248 |
if context:
|
249 |
final_messages.append({"role": "system", "content": f"{context}"})
|
250 |
final_messages.extend(messages)
|
251 |
|
252 |
-
# Use the provider's API key
|
253 |
api_key = hf_token_value or os.environ.get(api_key_env_var)
|
254 |
if not api_key:
|
255 |
raise ValueError("API token is not provided.")
|
256 |
|
257 |
-
# Initialize the OpenAI client
|
258 |
client = OpenAI(
|
259 |
base_url=endpoint,
|
260 |
api_key=api_key,
|
@@ -289,6 +268,7 @@ def create_chat_interface(provider_dropdown, model_dropdown, paper_content, hf_t
|
|
289 |
|
290 |
|
291 |
def paper_chat_tab(paper_id, paper_from, paper_central_df):
|
|
|
292 |
with gr.Row():
|
293 |
# Left column: Paper selection and display
|
294 |
with gr.Column(scale=1):
|
@@ -316,10 +296,10 @@ def paper_chat_tab(paper_id, paper_from, paper_central_df):
|
|
316 |
)
|
317 |
select_paper_button = gr.Button("Load this paper")
|
318 |
|
319 |
-
# Paper info display
|
320 |
content = gr.HTML(value="", elem_id="paper_info_card")
|
321 |
|
322 |
-
# Right column: Provider and model selection
|
323 |
with gr.Column(scale=1, visible=False) as provider_section:
|
324 |
gr.Markdown("### LLM Provider and Model")
|
325 |
provider_names = list(PROVIDERS.keys())
|
@@ -354,7 +334,10 @@ def paper_chat_tab(paper_id, paper_from, paper_central_df):
|
|
354 |
|
355 |
paper_content = gr.State()
|
356 |
|
357 |
-
|
|
|
|
|
|
|
358 |
chat_interface, chatbot = create_chat_interface(provider_dropdown, model_dropdown, paper_content,
|
359 |
hf_token_input, default_type, default_max_total_tokens)
|
360 |
|
@@ -385,7 +368,6 @@ def paper_chat_tab(paper_id, paper_from, paper_central_df):
|
|
385 |
)
|
386 |
|
387 |
def update_paper_info(paper_id_value, paper_from_value, selected_model, old_content):
|
388 |
-
# Use PAPER_SOURCES to fetch info
|
389 |
source_info = PAPER_SOURCES.get(paper_from_value, {})
|
390 |
fetch_info_fn = source_info.get("fetch_info")
|
391 |
fetch_pdf_fn = source_info.get("fetch_pdf")
|
@@ -401,7 +383,6 @@ def paper_chat_tab(paper_id, paper_from, paper_central_df):
|
|
401 |
if text is None:
|
402 |
text = "Paper content could not be retrieved."
|
403 |
|
404 |
-
# Create a styled card for the paper info
|
405 |
card_html = f"""
|
406 |
<div style="border:1px solid #ccc; border-radius:6px; background:#f9f9f9; padding:15px; margin-bottom:10px;">
|
407 |
<center><h3 style="margin-top:0; text-decoration:underline;">You are talking with:</h3></center>
|
@@ -414,7 +395,6 @@ def paper_chat_tab(paper_id, paper_from, paper_central_df):
|
|
414 |
return gr.update(value=card_html), text, []
|
415 |
|
416 |
def select_paper(paper_title):
|
417 |
-
# Find the corresponding paper_page from the title
|
418 |
for t, ppage in paper_choices:
|
419 |
if t == paper_title:
|
420 |
return ppage, "paper_page"
|
@@ -426,32 +406,34 @@ def paper_chat_tab(paper_id, paper_from, paper_central_df):
|
|
426 |
outputs=[paper_id, paper_from]
|
427 |
)
|
428 |
|
429 |
-
# After updating paper_id, we update paper info
|
430 |
paper_id.change(
|
431 |
fn=update_paper_info,
|
432 |
inputs=[paper_id, paper_from, model_dropdown, content],
|
433 |
outputs=[content, paper_content, chatbot]
|
434 |
)
|
435 |
|
436 |
-
# Function to toggle visibility of the right column based on paper_id
|
437 |
def toggle_provider_visibility(paper_id_value):
|
438 |
if paper_id_value and paper_id_value.strip():
|
439 |
return gr.update(visible=True)
|
440 |
else:
|
441 |
return gr.update(visible=False)
|
442 |
|
443 |
-
#
|
444 |
paper_id.change(
|
445 |
fn=toggle_provider_visibility,
|
446 |
inputs=[paper_id],
|
447 |
outputs=[provider_section]
|
448 |
)
|
449 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
450 |
|
451 |
def main():
|
452 |
-
"""
|
453 |
-
Launches the Gradio app.
|
454 |
-
"""
|
455 |
with gr.Blocks(css_paths="style.css") as demo:
|
456 |
paper_id = gr.Textbox(label="Paper ID", value="")
|
457 |
paper_from = gr.Radio(
|
@@ -460,9 +442,6 @@ def main():
|
|
460 |
value="neurips"
|
461 |
)
|
462 |
|
463 |
-
# Build the paper chat tab
|
464 |
-
dummy_calendar = gr.State(datetime.now().strftime("%Y-%m-%d"))
|
465 |
-
|
466 |
class MockPaperCentral:
|
467 |
def __init__(self):
|
468 |
import pandas as pd
|
|
|
73 |
else:
|
74 |
abstract = 'Abstract not found'
|
75 |
|
|
|
76 |
link = f"https://openreview.net/forum?id={paper_id}"
|
77 |
return title, author_list, f"**Abstract:** {abstract}\n\n[View on OpenReview]({link})"
|
78 |
|
|
|
109 |
|
110 |
|
111 |
def fetch_paper_info_paperpage(paper_id_value):
|
|
|
112 |
def extract_paper_id(input_string):
|
|
|
113 |
if re.fullmatch(r'\d+\.\d+', input_string.strip()):
|
114 |
return input_string.strip()
|
|
|
115 |
match = re.search(r'https://huggingface\.co/papers/(\d+\.\d+)', input_string)
|
116 |
if match:
|
117 |
return match.group(1)
|
|
|
137 |
|
138 |
|
139 |
def fetch_paper_content_paperpage(paper_id_value):
|
|
|
140 |
def extract_paper_id(input_string):
|
141 |
if re.fullmatch(r'\d+\.\d+', input_string.strip()):
|
142 |
return input_string.strip()
|
|
|
150 |
return text
|
151 |
|
152 |
|
|
|
153 |
PAPER_SOURCES = {
|
154 |
"neurips": {
|
155 |
"fetch_info": fetch_paper_info_neurips,
|
|
|
164 |
|
165 |
def create_chat_interface(provider_dropdown, model_dropdown, paper_content, hf_token_input, default_type,
|
166 |
provider_max_total_tokens):
|
|
|
167 |
def get_fn(message, history, paper_content_value, hf_token_value, provider_name_value, model_name_value,
|
168 |
max_total_tokens):
|
169 |
provider_info = PROVIDERS[provider_name_value]
|
170 |
endpoint = provider_info['endpoint']
|
171 |
api_key_env_var = provider_info['api_key_env_var']
|
|
|
172 |
max_total_tokens = int(max_total_tokens)
|
173 |
|
|
|
174 |
tokenizer_key = f"{provider_name_value}_{model_name_value}"
|
175 |
if tokenizer_key not in tokenizer_cache:
|
176 |
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-1B-Instruct",
|
|
|
179 |
else:
|
180 |
tokenizer = tokenizer_cache[tokenizer_key]
|
181 |
|
|
|
182 |
if paper_content_value:
|
183 |
context = f"The discussion is about the following paper:\n{paper_content_value}\n\n"
|
184 |
else:
|
185 |
context = ""
|
186 |
|
|
|
187 |
context_tokens = tokenizer.encode(context)
|
188 |
context_token_length = len(context_tokens)
|
189 |
|
|
|
190 |
messages = []
|
191 |
message_tokens_list = []
|
192 |
+
total_tokens = context_token_length
|
193 |
|
194 |
for user_msg, assistant_msg in history:
|
|
|
195 |
user_tokens = tokenizer.encode(user_msg)
|
196 |
messages.append({"role": "user", "content": user_msg})
|
197 |
message_tokens_list.append(len(user_tokens))
|
198 |
total_tokens += len(user_tokens)
|
199 |
|
|
|
200 |
if assistant_msg:
|
201 |
assistant_tokens = tokenizer.encode(assistant_msg)
|
202 |
messages.append({"role": "assistant", "content": assistant_msg})
|
203 |
message_tokens_list.append(len(assistant_tokens))
|
204 |
total_tokens += len(assistant_tokens)
|
205 |
|
|
|
206 |
message_tokens = tokenizer.encode(message)
|
207 |
messages.append({"role": "user", "content": message})
|
208 |
message_tokens_list.append(len(message_tokens))
|
209 |
total_tokens += len(message_tokens)
|
210 |
|
|
|
211 |
if total_tokens > max_total_tokens:
|
|
|
212 |
available_tokens = max_total_tokens - (total_tokens - context_token_length)
|
213 |
if available_tokens > 0:
|
214 |
truncated_context_tokens = context_tokens[:available_tokens]
|
|
|
220 |
total_tokens -= context_token_length
|
221 |
context_token_length = 0
|
222 |
|
|
|
223 |
while total_tokens > max_total_tokens and len(messages) > 1:
|
224 |
removed_message = messages.pop(0)
|
225 |
removed_tokens = message_tokens_list.pop(0)
|
226 |
total_tokens -= removed_tokens
|
227 |
|
|
|
228 |
final_messages = []
|
229 |
if context:
|
230 |
final_messages.append({"role": "system", "content": f"{context}"})
|
231 |
final_messages.extend(messages)
|
232 |
|
|
|
233 |
api_key = hf_token_value or os.environ.get(api_key_env_var)
|
234 |
if not api_key:
|
235 |
raise ValueError("API token is not provided.")
|
236 |
|
|
|
237 |
client = OpenAI(
|
238 |
base_url=endpoint,
|
239 |
api_key=api_key,
|
|
|
268 |
|
269 |
|
270 |
def paper_chat_tab(paper_id, paper_from, paper_central_df):
|
271 |
+
# First row with two columns
|
272 |
with gr.Row():
|
273 |
# Left column: Paper selection and display
|
274 |
with gr.Column(scale=1):
|
|
|
296 |
)
|
297 |
select_paper_button = gr.Button("Load this paper")
|
298 |
|
299 |
+
# Paper info display
|
300 |
content = gr.HTML(value="", elem_id="paper_info_card")
|
301 |
|
302 |
+
# Right column: Provider and model selection
|
303 |
with gr.Column(scale=1, visible=False) as provider_section:
|
304 |
gr.Markdown("### LLM Provider and Model")
|
305 |
provider_names = list(PROVIDERS.keys())
|
|
|
334 |
|
335 |
paper_content = gr.State()
|
336 |
|
337 |
+
# Now a new row, full width, for the chat
|
338 |
+
with gr.Row(visible=False) as chat_row:
|
339 |
+
with gr.Column():
|
340 |
+
# Create chat interface below the two columns
|
341 |
chat_interface, chatbot = create_chat_interface(provider_dropdown, model_dropdown, paper_content,
|
342 |
hf_token_input, default_type, default_max_total_tokens)
|
343 |
|
|
|
368 |
)
|
369 |
|
370 |
def update_paper_info(paper_id_value, paper_from_value, selected_model, old_content):
|
|
|
371 |
source_info = PAPER_SOURCES.get(paper_from_value, {})
|
372 |
fetch_info_fn = source_info.get("fetch_info")
|
373 |
fetch_pdf_fn = source_info.get("fetch_pdf")
|
|
|
383 |
if text is None:
|
384 |
text = "Paper content could not be retrieved."
|
385 |
|
|
|
386 |
card_html = f"""
|
387 |
<div style="border:1px solid #ccc; border-radius:6px; background:#f9f9f9; padding:15px; margin-bottom:10px;">
|
388 |
<center><h3 style="margin-top:0; text-decoration:underline;">You are talking with:</h3></center>
|
|
|
395 |
return gr.update(value=card_html), text, []
|
396 |
|
397 |
def select_paper(paper_title):
|
|
|
398 |
for t, ppage in paper_choices:
|
399 |
if t == paper_title:
|
400 |
return ppage, "paper_page"
|
|
|
406 |
outputs=[paper_id, paper_from]
|
407 |
)
|
408 |
|
|
|
409 |
paper_id.change(
|
410 |
fn=update_paper_info,
|
411 |
inputs=[paper_id, paper_from, model_dropdown, content],
|
412 |
outputs=[content, paper_content, chatbot]
|
413 |
)
|
414 |
|
|
|
415 |
def toggle_provider_visibility(paper_id_value):
|
416 |
if paper_id_value and paper_id_value.strip():
|
417 |
return gr.update(visible=True)
|
418 |
else:
|
419 |
return gr.update(visible=False)
|
420 |
|
421 |
+
# Toggle provider section visibility
|
422 |
paper_id.change(
|
423 |
fn=toggle_provider_visibility,
|
424 |
inputs=[paper_id],
|
425 |
outputs=[provider_section]
|
426 |
)
|
427 |
|
428 |
+
# Toggle chat row visibility
|
429 |
+
paper_id.change(
|
430 |
+
fn=toggle_provider_visibility,
|
431 |
+
inputs=[paper_id],
|
432 |
+
outputs=[chat_row]
|
433 |
+
)
|
434 |
+
|
435 |
|
436 |
def main():
|
|
|
|
|
|
|
437 |
with gr.Blocks(css_paths="style.css") as demo:
|
438 |
paper_id = gr.Textbox(label="Paper ID", value="")
|
439 |
paper_from = gr.Radio(
|
|
|
442 |
value="neurips"
|
443 |
)
|
444 |
|
|
|
|
|
|
|
445 |
class MockPaperCentral:
|
446 |
def __init__(self):
|
447 |
import pandas as pd
|