VictorSanh commited on
Commit
82e8993
1 Parent(s): 1e870d6

fix generation parsing

Browse files
Files changed (1) hide show
  1. app.py +15 -20
app.py CHANGED
@@ -19,7 +19,7 @@ model = Idefics2ForConditionalGeneration.from_pretrained(
19
 
20
  @spaces.GPU(duration=180)
21
  def model_inference(
22
- image, text, decoding_strategy, temperature,
23
  max_new_tokens, repetition_penalty, top_p
24
  ):
25
  if text == "" and not image:
@@ -36,16 +36,16 @@ def model_inference(
36
  ]
37
  }
38
  ]
39
-
40
-
41
  prompt = processor.apply_chat_template(resulting_messages, add_generation_prompt=True)
42
  inputs = processor(text=prompt, images=[image], return_tensors="pt")
43
  inputs = {k: v.to("cuda") for k, v in inputs.items()}
44
-
45
  generation_args = {
46
  "max_new_tokens": max_new_tokens,
47
  "repetition_penalty": repetition_penalty,
48
-
49
  }
50
 
51
  assert decoding_strategy in [
@@ -59,20 +59,15 @@ def model_inference(
59
  generation_args["do_sample"] = True
60
  generation_args["top_p"] = top_p
61
 
62
-
63
  generation_args.update(inputs)
64
 
65
  # Generate
66
  generated_ids = model.generate(**generation_args)
67
-
68
- generated_texts = processor.batch_decode(generated_ids, skip_special_tokens=True)
69
- print(generated_texts)
70
- pattern = r"Assistant: (.*)"
71
 
72
- # Use regular expression to find the desired part
73
- result = re.search(pattern, generated_texts[0]).group(1)
74
-
75
- return result[:-1]
76
 
77
 
78
  with gr.Blocks(fill_height=True) as demo:
@@ -87,7 +82,7 @@ with gr.Blocks(fill_height=True) as demo:
87
  query_input = gr.Textbox(label="Prompt")
88
  submit_btn = gr.Button("Submit")
89
  output = gr.Textbox(label="Output")
90
-
91
  with gr.Accordion(label="Example Inputs and Advanced Generation Parameters"):
92
  examples=[["./example_images/docvqa_example.png", "How many items are sold?", "Greedy", 0.4, 512, 1.2, 0.8],
93
  ["./example_images/example_images_travel_tips.jpg", "I want to go somewhere similar to the one in the photo. Give me destinations and travel tips.", "Greedy", 0.4, 512, 1.2, 0.8],
@@ -95,7 +90,7 @@ with gr.Blocks(fill_height=True) as demo:
95
  ["./example_images/dummy_pdf.png", "How much percent is the order status?", "Greedy", 0.4, 512, 1.2, 0.8],
96
  ["./example_images/art_critic.png", "As an art critic AI assistant, could you describe this painting in details and make a thorough critic?.", "Greedy", 0.4, 512, 1.2, 0.8],
97
  ["./example_images/s2w_example.png", "What is this UI about?", "Greedy", 0.4, 512, 1.2, 0.8]]
98
-
99
  # Hyper-parameters for generation
100
  max_new_tokens = gr.Slider(
101
  minimum=8,
@@ -151,7 +146,7 @@ with gr.Blocks(fill_height=True) as demo:
151
  inputs=decoding_strategy,
152
  outputs=temperature,
153
  )
154
-
155
  decoding_strategy.change(
156
  fn=lambda selection: gr.Slider(
157
  visible=(
@@ -168,13 +163,13 @@ with gr.Blocks(fill_height=True) as demo:
168
  )
169
  gr.Examples(
170
  examples = examples,
171
- inputs=[image_input, query_input, decoding_strategy, temperature,
172
  max_new_tokens, repetition_penalty, top_p],
173
  outputs=output,
174
  fn=model_inference
175
  )
176
-
177
- submit_btn.click(model_inference, inputs = [image_input, query_input, decoding_strategy, temperature,
178
  max_new_tokens, repetition_penalty, top_p], outputs=output)
179
 
180
 
 
19
 
20
  @spaces.GPU(duration=180)
21
  def model_inference(
22
+ image, text, decoding_strategy, temperature,
23
  max_new_tokens, repetition_penalty, top_p
24
  ):
25
  if text == "" and not image:
 
36
  ]
37
  }
38
  ]
39
+
40
+
41
  prompt = processor.apply_chat_template(resulting_messages, add_generation_prompt=True)
42
  inputs = processor(text=prompt, images=[image], return_tensors="pt")
43
  inputs = {k: v.to("cuda") for k, v in inputs.items()}
44
+
45
  generation_args = {
46
  "max_new_tokens": max_new_tokens,
47
  "repetition_penalty": repetition_penalty,
48
+
49
  }
50
 
51
  assert decoding_strategy in [
 
59
  generation_args["do_sample"] = True
60
  generation_args["top_p"] = top_p
61
 
62
+
63
  generation_args.update(inputs)
64
 
65
  # Generate
66
  generated_ids = model.generate(**generation_args)
 
 
 
 
67
 
68
+ generated_texts = processor.batch_decode(generated_ids[:, generation_args["input_ids"].size(1):], skip_special_tokens=True)
69
+ print("INPUT:", prompt, "|OUTPUT:", generated_texts)
70
+ return generated_texts[0]
 
71
 
72
 
73
  with gr.Blocks(fill_height=True) as demo:
 
82
  query_input = gr.Textbox(label="Prompt")
83
  submit_btn = gr.Button("Submit")
84
  output = gr.Textbox(label="Output")
85
+
86
  with gr.Accordion(label="Example Inputs and Advanced Generation Parameters"):
87
  examples=[["./example_images/docvqa_example.png", "How many items are sold?", "Greedy", 0.4, 512, 1.2, 0.8],
88
  ["./example_images/example_images_travel_tips.jpg", "I want to go somewhere similar to the one in the photo. Give me destinations and travel tips.", "Greedy", 0.4, 512, 1.2, 0.8],
 
90
  ["./example_images/dummy_pdf.png", "How much percent is the order status?", "Greedy", 0.4, 512, 1.2, 0.8],
91
  ["./example_images/art_critic.png", "As an art critic AI assistant, could you describe this painting in details and make a thorough critic?.", "Greedy", 0.4, 512, 1.2, 0.8],
92
  ["./example_images/s2w_example.png", "What is this UI about?", "Greedy", 0.4, 512, 1.2, 0.8]]
93
+
94
  # Hyper-parameters for generation
95
  max_new_tokens = gr.Slider(
96
  minimum=8,
 
146
  inputs=decoding_strategy,
147
  outputs=temperature,
148
  )
149
+
150
  decoding_strategy.change(
151
  fn=lambda selection: gr.Slider(
152
  visible=(
 
163
  )
164
  gr.Examples(
165
  examples = examples,
166
+ inputs=[image_input, query_input, decoding_strategy, temperature,
167
  max_new_tokens, repetition_penalty, top_p],
168
  outputs=output,
169
  fn=model_inference
170
  )
171
+
172
+ submit_btn.click(model_inference, inputs = [image_input, query_input, decoding_strategy, temperature,
173
  max_new_tokens, repetition_penalty, top_p], outputs=output)
174
 
175