maxidl committed on
Commit
32a2059
1 Parent(s): 166956f
Files changed (2) hide show
  1. app.py +29 -20
  2. examples/2105.04505v1.pdf +0 -0
app.py CHANGED
@@ -2,6 +2,7 @@ import gradio as gr
2
  import spaces
3
  from transformers import AutoModelForCausalLM, AutoTokenizer
4
  import torch
 
5
 
6
  from marker.convert import convert_single_pdf
7
  from marker.output import markdown_exists, save_markdown, get_markdown_filepath
@@ -33,6 +34,7 @@ model = AutoModelForCausalLM.from_pretrained(
33
  device_map="auto"
34
  )
35
  tokenizer = AutoTokenizer.from_pretrained(model_name)
 
36
 
37
  # Define prompts
38
  SYSTEM_PROMPT_TEMPLATE = """You are an expert reviewer for AI conferences. You follow best practices and review papers according to the reviewer guidelines.
@@ -172,7 +174,7 @@ def process_file(file):
172
  return paper_text
173
 
174
 
175
- @spaces.GPU(duration=120)
176
  def generate(paper_text, review_template):
177
  # messages = [
178
  # {"role": "system", "content": "You are a pirate."},
@@ -185,19 +187,26 @@ def generate(paper_text, review_template):
185
  return_tensors='pt'
186
  ).to(model.device)
187
  print(f"input_ids shape: {input_ids.shape}")
188
- generated_ids = model.generate(
189
- input_ids=input_ids,
190
- max_new_tokens=4096,
191
- do_sample=True,
192
- temperature=0.6,
193
- top_p=0.9
194
- )
195
- generated_ids = [
196
- output_ids[len(input_ids):] for input_ids, output_ids in zip(input_ids, generated_ids)
197
- ]
 
 
 
 
 
 
 
198
 
199
- response = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
200
- return response
201
  # return "Success"
202
 
203
 
@@ -249,13 +258,13 @@ with gr.Blocks(theme=theme) as demo:
249
  review_field = gr.Markdown("\n\n\n\n\n", label="Review")
250
  generate_button.click(fn=lambda: gr.update(interactive=False), inputs=None, outputs=generate_button).then(generate, [paper_text_field, review_template_field], review_field).then(fn=lambda: gr.update(interactive=True), inputs=None, outputs=generate_button)
251
 
252
- gr.Examples([
253
- ["examples/2105.04505v1.pdf", REVIEW_FIELDS]
254
- ],
255
- inputs=[paper_text_field, review_template_field],
256
- outputs=[review_field],
257
- fn=generate,
258
- cache_examples=True)
259
 
260
  demo.title = "Paper Review Generator"
261
 
 
2
  import spaces
3
  from transformers import AutoModelForCausalLM, AutoTokenizer
4
  import torch
5
+ from threading import Thread
6
 
7
  from marker.convert import convert_single_pdf
8
  from marker.output import markdown_exists, save_markdown, get_markdown_filepath
 
34
  device_map="auto"
35
  )
36
  tokenizer = AutoTokenizer.from_pretrained(model_name)
37
+ streamer = TextIteratorStreamer(tokenizer, skip_prompt=True)
38
 
39
  # Define prompts
40
  SYSTEM_PROMPT_TEMPLATE = """You are an expert reviewer for AI conferences. You follow best practices and review papers according to the reviewer guidelines.
 
174
  return paper_text
175
 
176
 
177
+ @spaces.GPU()
178
  def generate(paper_text, review_template):
179
  # messages = [
180
  # {"role": "system", "content": "You are a pirate."},
 
187
  return_tensors='pt'
188
  ).to(model.device)
189
  print(f"input_ids shape: {input_ids.shape}")
190
+ generation_kwargs = dict(input_ids=input_ids, streamer=streamer, max_new_tokens=4096, do_sample=True, temperature=0.6, top_p=0.9)
191
+ thread = Thread(target=model.generate, kwargs=generation_kwargs)
192
+ thread.start()
193
+ generated_text = ""
194
+ for new_text in streamer:
195
+ generated_text += new_text
196
+ yield generated_text
197
+ # generated_ids = model.generate(
198
+ # input_ids=input_ids,
199
+ # max_new_tokens=4096,
200
+ # do_sample=True,
201
+ # temperature=0.6,
202
+ # top_p=0.9
203
+ # )
204
+ # generated_ids = [
205
+ # output_ids[len(input_ids):] for input_ids, output_ids in zip(input_ids, generated_ids)
206
+ # ]
207
 
208
+ # response = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
209
+ # return response
210
  # return "Success"
211
 
212
 
 
258
  review_field = gr.Markdown("\n\n\n\n\n", label="Review")
259
  generate_button.click(fn=lambda: gr.update(interactive=False), inputs=None, outputs=generate_button).then(generate, [paper_text_field, review_template_field], review_field).then(fn=lambda: gr.update(interactive=True), inputs=None, outputs=generate_button)
260
 
261
+ # gr.Examples([
262
+ # ["examples/2105.04505v1.pdf", REVIEW_FIELDS]
263
+ # ],
264
+ # inputs=[paper_text_field, review_template_field],
265
+ # outputs=[review_field],
266
+ # fn=generate,
267
+ # cache_examples=True)
268
 
269
  demo.title = "Paper Review Generator"
270
 
examples/2105.04505v1.pdf DELETED
Binary file (238 kB)