andito (HF staff) committed
Commit 50fae8a
Parent: beec895

Changed model for DocVQA and added task

Files changed (1):
  1. app.py (+28, -34)
app.py CHANGED
@@ -16,22 +16,12 @@ import numpy as np
 import subprocess
 subprocess.run('pip install flash-attn --no-build-isolation', env={'FLASH_ATTENTION_SKIP_CUDA_BUILD': "TRUE"}, shell=True)
 
-models = {
-    'microsoft/Florence-2-large-ft': AutoModelForCausalLM.from_pretrained('microsoft/Florence-2-large-ft', trust_remote_code=True).to("cuda").eval(),
-    'microsoft/Florence-2-large': AutoModelForCausalLM.from_pretrained('microsoft/Florence-2-large', trust_remote_code=True).to("cuda").eval(),
-    'microsoft/Florence-2-base-ft': AutoModelForCausalLM.from_pretrained('microsoft/Florence-2-base-ft', trust_remote_code=True).to("cuda").eval(),
-    'microsoft/Florence-2-base': AutoModelForCausalLM.from_pretrained('microsoft/Florence-2-base', trust_remote_code=True).to("cuda").eval(),
-}
+model = AutoModelForCausalLM.from_pretrained('HuggingFaceM4/Florence-2-DocVQA', trust_remote_code=True).to("cuda").eval()
 
-processors = {
-    'microsoft/Florence-2-large-ft': AutoProcessor.from_pretrained('microsoft/Florence-2-large-ft', trust_remote_code=True),
-    'microsoft/Florence-2-large': AutoProcessor.from_pretrained('microsoft/Florence-2-large', trust_remote_code=True),
-    'microsoft/Florence-2-base-ft': AutoProcessor.from_pretrained('microsoft/Florence-2-base-ft', trust_remote_code=True),
-    'microsoft/Florence-2-base': AutoProcessor.from_pretrained('microsoft/Florence-2-base', trust_remote_code=True),
-}
+processor = AutoProcessor.from_pretrained('HuggingFaceM4/Florence-2-DocVQA', trust_remote_code=True)
 
 
-DESCRIPTION = "# [Florence-2 Demo](https://huggingface.co/microsoft/Florence-2-large)"
+DESCRIPTION = "# [Florence-2-DocVQA Demo](https://huggingface.co/HuggingFaceM4/Florence-2-DocVQA)"
 
 colormap = ['blue','orange','green','purple','brown','pink','gray','olive','cyan','red',
             'lime','indigo','violet','aqua','magenta','coral','gold','tan','skyblue']
@@ -43,9 +33,9 @@ def fig_to_pil(fig):
     return Image.open(buf)
 
 @spaces.GPU
-def run_example(task_prompt, image, text_input=None, model_id='microsoft/Florence-2-large'):
-    model = models[model_id]
-    processor = processors[model_id]
+def run_example(task_prompt, image, text_input=None):
+    model = model
+    processor = processor
     if text_input is None:
         prompt = task_prompt
     else:
@@ -123,71 +113,75 @@ def draw_ocr_bboxes(image, prediction):
 
 def process_image(image, task_prompt, text_input=None, model_id='microsoft/Florence-2-large'):
     image = Image.fromarray(image)  # Convert NumPy array to PIL Image
-    if task_prompt == 'Caption':
+    if task_prompt == 'Document Visual Question Answering':
+        task_prompt = '<DocVQA>'
+        results = run_example(task_prompt, image)
+        return results, None
+    elif task_prompt == 'Caption':
         task_prompt = '<CAPTION>'
-        results = run_example(task_prompt, image, model_id=model_id)
+        results = run_example(task_prompt, image)
         return results, None
     elif task_prompt == 'Detailed Caption':
         task_prompt = '<DETAILED_CAPTION>'
-        results = run_example(task_prompt, image, model_id=model_id)
+        results = run_example(task_prompt, image)
         return results, None
     elif task_prompt == 'More Detailed Caption':
         task_prompt = '<MORE_DETAILED_CAPTION>'
-        results = run_example(task_prompt, image, model_id=model_id)
+        results = run_example(task_prompt, image)
        return results, None
     elif task_prompt == 'Object Detection':
         task_prompt = '<OD>'
-        results = run_example(task_prompt, image, model_id=model_id)
+        results = run_example(task_prompt, image)
         fig = plot_bbox(image, results['<OD>'])
         return results, fig_to_pil(fig)
     elif task_prompt == 'Dense Region Caption':
         task_prompt = '<DENSE_REGION_CAPTION>'
-        results = run_example(task_prompt, image, model_id=model_id)
+        results = run_example(task_prompt, image)
         fig = plot_bbox(image, results['<DENSE_REGION_CAPTION>'])
         return results, fig_to_pil(fig)
     elif task_prompt == 'Region Proposal':
         task_prompt = '<REGION_PROPOSAL>'
-        results = run_example(task_prompt, image, model_id=model_id)
+        results = run_example(task_prompt, image)
         fig = plot_bbox(image, results['<REGION_PROPOSAL>'])
         return results, fig_to_pil(fig)
     elif task_prompt == 'Caption to Phrase Grounding':
         task_prompt = '<CAPTION_TO_PHRASE_GROUNDING>'
-        results = run_example(task_prompt, image, text_input, model_id)
+        results = run_example(task_prompt, image, text_input)
         fig = plot_bbox(image, results['<CAPTION_TO_PHRASE_GROUNDING>'])
         return results, fig_to_pil(fig)
     elif task_prompt == 'Referring Expression Segmentation':
         task_prompt = '<REFERRING_EXPRESSION_SEGMENTATION>'
-        results = run_example(task_prompt, image, text_input, model_id)
+        results = run_example(task_prompt, image, text_input)
         output_image = copy.deepcopy(image)
         output_image = draw_polygons(output_image, results['<REFERRING_EXPRESSION_SEGMENTATION>'], fill_mask=True)
         return results, output_image
     elif task_prompt == 'Region to Segmentation':
         task_prompt = '<REGION_TO_SEGMENTATION>'
-        results = run_example(task_prompt, image, text_input, model_id)
+        results = run_example(task_prompt, image, text_input)
         output_image = copy.deepcopy(image)
         output_image = draw_polygons(output_image, results['<REGION_TO_SEGMENTATION>'], fill_mask=True)
         return results, output_image
     elif task_prompt == 'Open Vocabulary Detection':
         task_prompt = '<OPEN_VOCABULARY_DETECTION>'
-        results = run_example(task_prompt, image, text_input, model_id)
+        results = run_example(task_prompt, image, text_input)
         bbox_results = convert_to_od_format(results['<OPEN_VOCABULARY_DETECTION>'])
         fig = plot_bbox(image, bbox_results)
         return results, fig_to_pil(fig)
     elif task_prompt == 'Region to Category':
         task_prompt = '<REGION_TO_CATEGORY>'
-        results = run_example(task_prompt, image, text_input, model_id)
+        results = run_example(task_prompt, image, text_input)
         return results, None
     elif task_prompt == 'Region to Description':
         task_prompt = '<REGION_TO_DESCRIPTION>'
-        results = run_example(task_prompt, image, text_input, model_id)
+        results = run_example(task_prompt, image, text_input)
         return results, None
     elif task_prompt == 'OCR':
         task_prompt = '<OCR>'
-        results = run_example(task_prompt, image, model_id=model_id)
+        results = run_example(task_prompt, image)
         return results, None
     elif task_prompt == 'OCR with Region':
         task_prompt = '<OCR_WITH_REGION>'
-        results = run_example(task_prompt, image, model_id=model_id)
+        results = run_example(task_prompt, image)
         output_image = copy.deepcopy(image)
         output_image = draw_ocr_bboxes(output_image, results['<OCR_WITH_REGION>'])
         return results, output_image
@@ -208,14 +202,14 @@ with gr.Blocks(css=css) as demo:
     with gr.Row():
         with gr.Column():
             input_img = gr.Image(label="Input Picture")
-            model_selector = gr.Dropdown(choices=list(models.keys()), label="Model", value='microsoft/Florence-2-large')
             task_prompt = gr.Dropdown(choices=[
+                'Document Visual Question Answering',
                 'Caption', 'Detailed Caption', 'More Detailed Caption', 'Object Detection',
                 'Dense Region Caption', 'Region Proposal', 'Caption to Phrase Grounding',
                 'Referring Expression Segmentation', 'Region to Segmentation',
                 'Open Vocabulary Detection', 'Region to Category', 'Region to Description',
                 'OCR', 'OCR with Region'
-            ], label="Task Prompt", value= 'Caption')
+            ], label="Task Prompt", value= 'Document Visual Question Answering')
             text_input = gr.Textbox(label="Text Input (optional)")
             submit_btn = gr.Button(value="Submit")
         with gr.Column():
@@ -234,6 +228,6 @@ with gr.Blocks(css=css) as demo:
         label='Try examples'
     )
 
-    submit_btn.click(process_image, [input_img, task_prompt, text_input, model_selector], [output_text, output_img])
+    submit_btn.click(process_image, [input_img, task_prompt, text_input], [output_text, output_img])
 
 demo.launch(debug=True)
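
After this commit the Space serves a single checkpoint, HuggingFaceM4/Florence-2-DocVQA, and exposes the new 'Document Visual Question Answering' task through the '<DocVQA>' prompt. Below is a minimal standalone sketch of querying that checkpoint outside Gradio. It is an assumption-laden illustration, not code from the commit: it follows the standard Florence-2 generate / post_process_generation flow that run_example relies on, assumes the checkpoint's processor accepts the '<DocVQA>' task in post_process_generation (as the Space's own call path does), and uses a placeholder image path and question.

import torch
from PIL import Image
from transformers import AutoModelForCausalLM, AutoProcessor

device = "cuda" if torch.cuda.is_available() else "cpu"

# Load the single DocVQA checkpoint introduced by this commit.
model = AutoModelForCausalLM.from_pretrained(
    'HuggingFaceM4/Florence-2-DocVQA', trust_remote_code=True
).to(device).eval()
processor = AutoProcessor.from_pretrained(
    'HuggingFaceM4/Florence-2-DocVQA', trust_remote_code=True
)

# Placeholder document image and question (hypothetical example inputs).
image = Image.open("document.png").convert("RGB")
question = "What is the invoice total?"

# The task token is prepended to the user question, mirroring run_example.
prompt = "<DocVQA>" + question
inputs = processor(text=prompt, images=image, return_tensors="pt").to(device)
generated_ids = model.generate(
    input_ids=inputs["input_ids"],
    pixel_values=inputs["pixel_values"],
    max_new_tokens=1024,
    num_beams=3,
    do_sample=False,
)
generated_text = processor.batch_decode(generated_ids, skip_special_tokens=False)[0]
# post_process_generation strips the task and special tokens from the raw output.
answer = processor.post_process_generation(
    generated_text, task="<DocVQA>", image_size=(image.width, image.height)
)
print(answer)

In the Space itself the same path is wired through process_image, which maps the dropdown label 'Document Visual Question Answering' to the '<DocVQA>' task prompt before delegating to run_example.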