IAMJB commited on
Commit
b1815fe
·
1 Parent(s): 8c5c31d
Files changed (1) hide show
  1. paper_chat_tab.py +16 -37
paper_chat_tab.py CHANGED
@@ -73,7 +73,6 @@ def fetch_paper_info_neurips(paper_id):
73
  else:
74
  abstract = 'Abstract not found'
75
 
76
- # Construct preamble
77
  link = f"https://openreview.net/forum?id={paper_id}"
78
  return title, author_list, f"**Abstract:** {abstract}\n\n[View on OpenReview]({link})"
79
 
@@ -110,12 +109,9 @@ def fetch_paper_content_arxiv(paper_id):
110
 
111
 
112
  def fetch_paper_info_paperpage(paper_id_value):
113
- # Extract paper_id from paper_page link or input
114
  def extract_paper_id(input_string):
115
- # Already in correct form?
116
  if re.fullmatch(r'\d+\.\d+', input_string.strip()):
117
  return input_string.strip()
118
- # If URL
119
  match = re.search(r'https://huggingface\.co/papers/(\d+\.\d+)', input_string)
120
  if match:
121
  return match.group(1)
@@ -141,7 +137,6 @@ def fetch_paper_info_paperpage(paper_id_value):
141
 
142
 
143
  def fetch_paper_content_paperpage(paper_id_value):
144
- # Extract paper_id
145
  def extract_paper_id(input_string):
146
  if re.fullmatch(r'\d+\.\d+', input_string.strip()):
147
  return input_string.strip()
@@ -155,7 +150,6 @@ def fetch_paper_content_paperpage(paper_id_value):
155
  return text
156
 
157
 
158
- # Dictionary for paper sources
159
  PAPER_SOURCES = {
160
  "neurips": {
161
  "fetch_info": fetch_paper_info_neurips,
@@ -170,16 +164,13 @@ PAPER_SOURCES = {
170
 
171
  def create_chat_interface(provider_dropdown, model_dropdown, paper_content, hf_token_input, default_type,
172
  provider_max_total_tokens):
173
- # Define the function to handle the chat
174
  def get_fn(message, history, paper_content_value, hf_token_value, provider_name_value, model_name_value,
175
  max_total_tokens):
176
  provider_info = PROVIDERS[provider_name_value]
177
  endpoint = provider_info['endpoint']
178
  api_key_env_var = provider_info['api_key_env_var']
179
- models = provider_info['models']
180
  max_total_tokens = int(max_total_tokens)
181
 
182
- # Load tokenizer
183
  tokenizer_key = f"{provider_name_value}_{model_name_value}"
184
  if tokenizer_key not in tokenizer_cache:
185
  tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-1B-Instruct",
@@ -188,44 +179,36 @@ def create_chat_interface(provider_dropdown, model_dropdown, paper_content, hf_t
188
  else:
189
  tokenizer = tokenizer_cache[tokenizer_key]
190
 
191
- # Include the paper content as context
192
  if paper_content_value:
193
  context = f"The discussion is about the following paper:\n{paper_content_value}\n\n"
194
  else:
195
  context = ""
196
 
197
- # Tokenize the context
198
  context_tokens = tokenizer.encode(context)
199
  context_token_length = len(context_tokens)
200
 
201
- # Prepare the messages without context
202
  messages = []
203
  message_tokens_list = []
204
- total_tokens = context_token_length # Start with context tokens
205
 
206
  for user_msg, assistant_msg in history:
207
- # Tokenize user message
208
  user_tokens = tokenizer.encode(user_msg)
209
  messages.append({"role": "user", "content": user_msg})
210
  message_tokens_list.append(len(user_tokens))
211
  total_tokens += len(user_tokens)
212
 
213
- # Tokenize assistant message
214
  if assistant_msg:
215
  assistant_tokens = tokenizer.encode(assistant_msg)
216
  messages.append({"role": "assistant", "content": assistant_msg})
217
  message_tokens_list.append(len(assistant_tokens))
218
  total_tokens += len(assistant_tokens)
219
 
220
- # Tokenize the new user message
221
  message_tokens = tokenizer.encode(message)
222
  messages.append({"role": "user", "content": message})
223
  message_tokens_list.append(len(message_tokens))
224
  total_tokens += len(message_tokens)
225
 
226
- # Check if total tokens exceed the maximum allowed tokens
227
  if total_tokens > max_total_tokens:
228
- # Attempt to truncate context
229
  available_tokens = max_total_tokens - (total_tokens - context_token_length)
230
  if available_tokens > 0:
231
  truncated_context_tokens = context_tokens[:available_tokens]
@@ -237,24 +220,20 @@ def create_chat_interface(provider_dropdown, model_dropdown, paper_content, hf_t
237
  total_tokens -= context_token_length
238
  context_token_length = 0
239
 
240
- # Truncate message history if needed
241
  while total_tokens > max_total_tokens and len(messages) > 1:
242
  removed_message = messages.pop(0)
243
  removed_tokens = message_tokens_list.pop(0)
244
  total_tokens -= removed_tokens
245
 
246
- # Rebuild the final messages
247
  final_messages = []
248
  if context:
249
  final_messages.append({"role": "system", "content": f"{context}"})
250
  final_messages.extend(messages)
251
 
252
- # Use the provider's API key
253
  api_key = hf_token_value or os.environ.get(api_key_env_var)
254
  if not api_key:
255
  raise ValueError("API token is not provided.")
256
 
257
- # Initialize the OpenAI client
258
  client = OpenAI(
259
  base_url=endpoint,
260
  api_key=api_key,
@@ -289,6 +268,7 @@ def create_chat_interface(provider_dropdown, model_dropdown, paper_content, hf_t
289
 
290
 
291
  def paper_chat_tab(paper_id, paper_from, paper_central_df):
 
292
  with gr.Row():
293
  # Left column: Paper selection and display
294
  with gr.Column(scale=1):
@@ -316,10 +296,10 @@ def paper_chat_tab(paper_id, paper_from, paper_central_df):
316
  )
317
  select_paper_button = gr.Button("Load this paper")
318
 
319
- # Paper info display - styled card
320
  content = gr.HTML(value="", elem_id="paper_info_card")
321
 
322
- # Right column: Provider and model selection + chat
323
  with gr.Column(scale=1, visible=False) as provider_section:
324
  gr.Markdown("### LLM Provider and Model")
325
  provider_names = list(PROVIDERS.keys())
@@ -354,7 +334,10 @@ def paper_chat_tab(paper_id, paper_from, paper_central_df):
354
 
355
  paper_content = gr.State()
356
 
357
- # Create chat interface
 
 
 
358
  chat_interface, chatbot = create_chat_interface(provider_dropdown, model_dropdown, paper_content,
359
  hf_token_input, default_type, default_max_total_tokens)
360
 
@@ -385,7 +368,6 @@ def paper_chat_tab(paper_id, paper_from, paper_central_df):
385
  )
386
 
387
  def update_paper_info(paper_id_value, paper_from_value, selected_model, old_content):
388
- # Use PAPER_SOURCES to fetch info
389
  source_info = PAPER_SOURCES.get(paper_from_value, {})
390
  fetch_info_fn = source_info.get("fetch_info")
391
  fetch_pdf_fn = source_info.get("fetch_pdf")
@@ -401,7 +383,6 @@ def paper_chat_tab(paper_id, paper_from, paper_central_df):
401
  if text is None:
402
  text = "Paper content could not be retrieved."
403
 
404
- # Create a styled card for the paper info
405
  card_html = f"""
406
  <div style="border:1px solid #ccc; border-radius:6px; background:#f9f9f9; padding:15px; margin-bottom:10px;">
407
  <center><h3 style="margin-top:0; text-decoration:underline;">You are talking with:</h3></center>
@@ -414,7 +395,6 @@ def paper_chat_tab(paper_id, paper_from, paper_central_df):
414
  return gr.update(value=card_html), text, []
415
 
416
  def select_paper(paper_title):
417
- # Find the corresponding paper_page from the title
418
  for t, ppage in paper_choices:
419
  if t == paper_title:
420
  return ppage, "paper_page"
@@ -426,32 +406,34 @@ def paper_chat_tab(paper_id, paper_from, paper_central_df):
426
  outputs=[paper_id, paper_from]
427
  )
428
 
429
- # After updating paper_id, we update paper info
430
  paper_id.change(
431
  fn=update_paper_info,
432
  inputs=[paper_id, paper_from, model_dropdown, content],
433
  outputs=[content, paper_content, chatbot]
434
  )
435
 
436
- # Function to toggle visibility of the right column based on paper_id
437
  def toggle_provider_visibility(paper_id_value):
438
  if paper_id_value and paper_id_value.strip():
439
  return gr.update(visible=True)
440
  else:
441
  return gr.update(visible=False)
442
 
443
- # Chain a then call to toggle visibility of the provider_section after paper info update
444
  paper_id.change(
445
  fn=toggle_provider_visibility,
446
  inputs=[paper_id],
447
  outputs=[provider_section]
448
  )
449
 
 
 
 
 
 
 
 
450
 
451
  def main():
452
- """
453
- Launches the Gradio app.
454
- """
455
  with gr.Blocks(css_paths="style.css") as demo:
456
  paper_id = gr.Textbox(label="Paper ID", value="")
457
  paper_from = gr.Radio(
@@ -460,9 +442,6 @@ def main():
460
  value="neurips"
461
  )
462
 
463
- # Build the paper chat tab
464
- dummy_calendar = gr.State(datetime.now().strftime("%Y-%m-%d"))
465
-
466
  class MockPaperCentral:
467
  def __init__(self):
468
  import pandas as pd
 
73
  else:
74
  abstract = 'Abstract not found'
75
 
 
76
  link = f"https://openreview.net/forum?id={paper_id}"
77
  return title, author_list, f"**Abstract:** {abstract}\n\n[View on OpenReview]({link})"
78
 
 
109
 
110
 
111
  def fetch_paper_info_paperpage(paper_id_value):
 
112
  def extract_paper_id(input_string):
 
113
  if re.fullmatch(r'\d+\.\d+', input_string.strip()):
114
  return input_string.strip()
 
115
  match = re.search(r'https://huggingface\.co/papers/(\d+\.\d+)', input_string)
116
  if match:
117
  return match.group(1)
 
137
 
138
 
139
  def fetch_paper_content_paperpage(paper_id_value):
 
140
  def extract_paper_id(input_string):
141
  if re.fullmatch(r'\d+\.\d+', input_string.strip()):
142
  return input_string.strip()
 
150
  return text
151
 
152
 
 
153
  PAPER_SOURCES = {
154
  "neurips": {
155
  "fetch_info": fetch_paper_info_neurips,
 
164
 
165
  def create_chat_interface(provider_dropdown, model_dropdown, paper_content, hf_token_input, default_type,
166
  provider_max_total_tokens):
 
167
  def get_fn(message, history, paper_content_value, hf_token_value, provider_name_value, model_name_value,
168
  max_total_tokens):
169
  provider_info = PROVIDERS[provider_name_value]
170
  endpoint = provider_info['endpoint']
171
  api_key_env_var = provider_info['api_key_env_var']
 
172
  max_total_tokens = int(max_total_tokens)
173
 
 
174
  tokenizer_key = f"{provider_name_value}_{model_name_value}"
175
  if tokenizer_key not in tokenizer_cache:
176
  tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-1B-Instruct",
 
179
  else:
180
  tokenizer = tokenizer_cache[tokenizer_key]
181
 
 
182
  if paper_content_value:
183
  context = f"The discussion is about the following paper:\n{paper_content_value}\n\n"
184
  else:
185
  context = ""
186
 
 
187
  context_tokens = tokenizer.encode(context)
188
  context_token_length = len(context_tokens)
189
 
 
190
  messages = []
191
  message_tokens_list = []
192
+ total_tokens = context_token_length
193
 
194
  for user_msg, assistant_msg in history:
 
195
  user_tokens = tokenizer.encode(user_msg)
196
  messages.append({"role": "user", "content": user_msg})
197
  message_tokens_list.append(len(user_tokens))
198
  total_tokens += len(user_tokens)
199
 
 
200
  if assistant_msg:
201
  assistant_tokens = tokenizer.encode(assistant_msg)
202
  messages.append({"role": "assistant", "content": assistant_msg})
203
  message_tokens_list.append(len(assistant_tokens))
204
  total_tokens += len(assistant_tokens)
205
 
 
206
  message_tokens = tokenizer.encode(message)
207
  messages.append({"role": "user", "content": message})
208
  message_tokens_list.append(len(message_tokens))
209
  total_tokens += len(message_tokens)
210
 
 
211
  if total_tokens > max_total_tokens:
 
212
  available_tokens = max_total_tokens - (total_tokens - context_token_length)
213
  if available_tokens > 0:
214
  truncated_context_tokens = context_tokens[:available_tokens]
 
220
  total_tokens -= context_token_length
221
  context_token_length = 0
222
 
 
223
  while total_tokens > max_total_tokens and len(messages) > 1:
224
  removed_message = messages.pop(0)
225
  removed_tokens = message_tokens_list.pop(0)
226
  total_tokens -= removed_tokens
227
 
 
228
  final_messages = []
229
  if context:
230
  final_messages.append({"role": "system", "content": f"{context}"})
231
  final_messages.extend(messages)
232
 
 
233
  api_key = hf_token_value or os.environ.get(api_key_env_var)
234
  if not api_key:
235
  raise ValueError("API token is not provided.")
236
 
 
237
  client = OpenAI(
238
  base_url=endpoint,
239
  api_key=api_key,
 
268
 
269
 
270
  def paper_chat_tab(paper_id, paper_from, paper_central_df):
271
+ # First row with two columns
272
  with gr.Row():
273
  # Left column: Paper selection and display
274
  with gr.Column(scale=1):
 
296
  )
297
  select_paper_button = gr.Button("Load this paper")
298
 
299
+ # Paper info display
300
  content = gr.HTML(value="", elem_id="paper_info_card")
301
 
302
+ # Right column: Provider and model selection
303
  with gr.Column(scale=1, visible=False) as provider_section:
304
  gr.Markdown("### LLM Provider and Model")
305
  provider_names = list(PROVIDERS.keys())
 
334
 
335
  paper_content = gr.State()
336
 
337
+ # Now a new row, full width, for the chat
338
+ with gr.Row(visible=False) as chat_row:
339
+ with gr.Column():
340
+ # Create chat interface below the two columns
341
  chat_interface, chatbot = create_chat_interface(provider_dropdown, model_dropdown, paper_content,
342
  hf_token_input, default_type, default_max_total_tokens)
343
 
 
368
  )
369
 
370
  def update_paper_info(paper_id_value, paper_from_value, selected_model, old_content):
 
371
  source_info = PAPER_SOURCES.get(paper_from_value, {})
372
  fetch_info_fn = source_info.get("fetch_info")
373
  fetch_pdf_fn = source_info.get("fetch_pdf")
 
383
  if text is None:
384
  text = "Paper content could not be retrieved."
385
 
 
386
  card_html = f"""
387
  <div style="border:1px solid #ccc; border-radius:6px; background:#f9f9f9; padding:15px; margin-bottom:10px;">
388
  <center><h3 style="margin-top:0; text-decoration:underline;">You are talking with:</h3></center>
 
395
  return gr.update(value=card_html), text, []
396
 
397
  def select_paper(paper_title):
 
398
  for t, ppage in paper_choices:
399
  if t == paper_title:
400
  return ppage, "paper_page"
 
406
  outputs=[paper_id, paper_from]
407
  )
408
 
 
409
  paper_id.change(
410
  fn=update_paper_info,
411
  inputs=[paper_id, paper_from, model_dropdown, content],
412
  outputs=[content, paper_content, chatbot]
413
  )
414
 
 
415
  def toggle_provider_visibility(paper_id_value):
416
  if paper_id_value and paper_id_value.strip():
417
  return gr.update(visible=True)
418
  else:
419
  return gr.update(visible=False)
420
 
421
+ # Toggle provider section visibility
422
  paper_id.change(
423
  fn=toggle_provider_visibility,
424
  inputs=[paper_id],
425
  outputs=[provider_section]
426
  )
427
 
428
+ # Toggle chat row visibility
429
+ paper_id.change(
430
+ fn=toggle_provider_visibility,
431
+ inputs=[paper_id],
432
+ outputs=[chat_row]
433
+ )
434
+
435
 
436
  def main():
 
 
 
437
  with gr.Blocks(css_paths="style.css") as demo:
438
  paper_id = gr.Textbox(label="Paper ID", value="")
439
  paper_from = gr.Radio(
 
442
  value="neurips"
443
  )
444
 
 
 
 
445
  class MockPaperCentral:
446
  def __init__(self):
447
  import pandas as pd