aifeifei798 commited on
Commit
e25fb5a
β€’
1 Parent(s): e5535b9

Upload 6 files

Browse files
Files changed (6) hide show
  1. README.md +4 -4
  2. app.py +233 -0
  3. image1.jpg +0 -0
  4. image2.jpg +0 -0
  5. pre-requirements.txt +1 -0
  6. requirements.txt +3 -0
README.md CHANGED
@@ -1,8 +1,8 @@
1
  ---
2
- title: Florence 2 Base
3
- emoji: πŸ“Š
4
- colorFrom: blue
5
- colorTo: yellow
6
  sdk: gradio
7
  sdk_version: 4.36.1
8
  app_file: app.py
 
1
  ---
2
+ title: Florence-2-base
3
+ emoji: πŸ“‰
4
+ colorFrom: pink
5
+ colorTo: purple
6
  sdk: gradio
7
  sdk_version: 4.36.1
8
  app_file: app.py
app.py ADDED
@@ -0,0 +1,233 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from transformers import AutoProcessor, AutoModelForCausalLM
3
+ #import spaces
4
+
5
+ import requests
6
+ import copy
7
+
8
+ from PIL import Image, ImageDraw, ImageFont
9
+ import io
10
+ import matplotlib.pyplot as plt
11
+ import matplotlib.patches as patches
12
+
13
+ import random
14
+ import numpy as np
15
+
16
+ import subprocess
17
+ subprocess.run('pip install flash-attn --no-build-isolation', env={'FLASH_ATTENTION_SKIP_CUDA_BUILD': "TRUE"}, shell=True)
18
+
19
+ models = {
20
+ 'microsoft/Florence-2-base': AutoModelForCausalLM.from_pretrained('microsoft/Florence-2-base', trust_remote_code=True).to("cuda").eval(),
21
+ }
22
+
23
+ processors = {
24
+ 'microsoft/Florence-2-base': AutoProcessor.from_pretrained('microsoft/Florence-2-base', trust_remote_code=True),
25
+ }
26
+
27
+
28
+ DESCRIPTION = "# [Florence-2 Demo](https://huggingface.co/microsoft/Florence-2-base)"
29
+
30
+ colormap = ['blue','orange','green','purple','brown','pink','gray','olive','cyan','red',
31
+ 'lime','indigo','violet','aqua','magenta','coral','gold','tan','skyblue']
32
+
33
+ def fig_to_pil(fig):
34
+ buf = io.BytesIO()
35
+ fig.savefig(buf, format='png')
36
+ buf.seek(0)
37
+ return Image.open(buf)
38
+
39
+ #@spaces.GPU
40
+ def run_example(task_prompt, image, text_input=None, model_id='microsoft/Florence-2-base'):
41
+ model = models[model_id]
42
+ processor = processors[model_id]
43
+ if text_input is None:
44
+ prompt = task_prompt
45
+ else:
46
+ prompt = task_prompt + text_input
47
+ inputs = processor(text=prompt, images=image, return_tensors="pt").to("cuda")
48
+ generated_ids = model.generate(
49
+ input_ids=inputs["input_ids"],
50
+ pixel_values=inputs["pixel_values"],
51
+ max_new_tokens=1024,
52
+ early_stopping=False,
53
+ do_sample=False,
54
+ num_beams=3,
55
+ )
56
+ generated_text = processor.batch_decode(generated_ids, skip_special_tokens=False)[0]
57
+ parsed_answer = processor.post_process_generation(
58
+ generated_text,
59
+ task=task_prompt,
60
+ image_size=(image.width, image.height)
61
+ )
62
+ return parsed_answer
63
+
64
+ def plot_bbox(image, data):
65
+ fig, ax = plt.subplots()
66
+ ax.imshow(image)
67
+ for bbox, label in zip(data['bboxes'], data['labels']):
68
+ x1, y1, x2, y2 = bbox
69
+ rect = patches.Rectangle((x1, y1), x2-x1, y2-y1, linewidth=1, edgecolor='r', facecolor='none')
70
+ ax.add_patch(rect)
71
+ plt.text(x1, y1, label, color='white', fontsize=8, bbox=dict(facecolor='red', alpha=0.5))
72
+ ax.axis('off')
73
+ return fig
74
+
75
+ def draw_polygons(image, prediction, fill_mask=False):
76
+
77
+ draw = ImageDraw.Draw(image)
78
+ scale = 1
79
+ for polygons, label in zip(prediction['polygons'], prediction['labels']):
80
+ color = random.choice(colormap)
81
+ fill_color = random.choice(colormap) if fill_mask else None
82
+ for _polygon in polygons:
83
+ _polygon = np.array(_polygon).reshape(-1, 2)
84
+ if len(_polygon) < 3:
85
+ print('Invalid polygon:', _polygon)
86
+ continue
87
+ _polygon = (_polygon * scale).reshape(-1).tolist()
88
+ if fill_mask:
89
+ draw.polygon(_polygon, outline=color, fill=fill_color)
90
+ else:
91
+ draw.polygon(_polygon, outline=color)
92
+ draw.text((_polygon[0] + 8, _polygon[1] + 2), label, fill=color)
93
+ return image
94
+
95
+ def convert_to_od_format(data):
96
+ bboxes = data.get('bboxes', [])
97
+ labels = data.get('bboxes_labels', [])
98
+ od_results = {
99
+ 'bboxes': bboxes,
100
+ 'labels': labels
101
+ }
102
+ return od_results
103
+
104
+ def draw_ocr_bboxes(image, prediction):
105
+ scale = 1
106
+ draw = ImageDraw.Draw(image)
107
+ bboxes, labels = prediction['quad_boxes'], prediction['labels']
108
+ for box, label in zip(bboxes, labels):
109
+ color = random.choice(colormap)
110
+ new_box = (np.array(box) * scale).tolist()
111
+ draw.polygon(new_box, width=3, outline=color)
112
+ draw.text((new_box[0]+8, new_box[1]+2),
113
+ "{}".format(label),
114
+ align="right",
115
+ fill=color)
116
+ return image
117
+
118
+ def process_image(image, task_prompt, text_input=None, model_id='microsoft/Florence-2-base'):
119
+ image = Image.fromarray(image) # Convert NumPy array to PIL Image
120
+ if task_prompt == 'Caption':
121
+ task_prompt = '<CAPTION>'
122
+ results = run_example(task_prompt, image, model_id=model_id)
123
+ return results, None
124
+ elif task_prompt == 'Detailed Caption':
125
+ task_prompt = '<DETAILED_CAPTION>'
126
+ results = run_example(task_prompt, image, model_id=model_id)
127
+ return results, None
128
+ elif task_prompt == 'More Detailed Caption':
129
+ task_prompt = '<MORE_DETAILED_CAPTION>'
130
+ results = run_example(task_prompt, image, model_id=model_id)
131
+ return results, None
132
+ elif task_prompt == 'Object Detection':
133
+ task_prompt = '<OD>'
134
+ results = run_example(task_prompt, image, model_id=model_id)
135
+ fig = plot_bbox(image, results['<OD>'])
136
+ return results, fig_to_pil(fig)
137
+ elif task_prompt == 'Dense Region Caption':
138
+ task_prompt = '<DENSE_REGION_CAPTION>'
139
+ results = run_example(task_prompt, image, model_id=model_id)
140
+ fig = plot_bbox(image, results['<DENSE_REGION_CAPTION>'])
141
+ return results, fig_to_pil(fig)
142
+ elif task_prompt == 'Region Proposal':
143
+ task_prompt = '<REGION_PROPOSAL>'
144
+ results = run_example(task_prompt, image, model_id=model_id)
145
+ fig = plot_bbox(image, results['<REGION_PROPOSAL>'])
146
+ return results, fig_to_pil(fig)
147
+ elif task_prompt == 'Caption to Phrase Grounding':
148
+ task_prompt = '<CAPTION_TO_PHRASE_GROUNDING>'
149
+ results = run_example(task_prompt, image, text_input, model_id)
150
+ fig = plot_bbox(image, results['<CAPTION_TO_PHRASE_GROUNDING>'])
151
+ return results, fig_to_pil(fig)
152
+ elif task_prompt == 'Referring Expression Segmentation':
153
+ task_prompt = '<REFERRING_EXPRESSION_SEGMENTATION>'
154
+ results = run_example(task_prompt, image, text_input, model_id)
155
+ output_image = copy.deepcopy(image)
156
+ output_image = draw_polygons(output_image, results['<REFERRING_EXPRESSION_SEGMENTATION>'], fill_mask=True)
157
+ return results, output_image
158
+ elif task_prompt == 'Region to Segmentation':
159
+ task_prompt = '<REGION_TO_SEGMENTATION>'
160
+ results = run_example(task_prompt, image, text_input, model_id)
161
+ output_image = copy.deepcopy(image)
162
+ output_image = draw_polygons(output_image, results['<REGION_TO_SEGMENTATION>'], fill_mask=True)
163
+ return results, output_image
164
+ elif task_prompt == 'Open Vocabulary Detection':
165
+ task_prompt = '<OPEN_VOCABULARY_DETECTION>'
166
+ results = run_example(task_prompt, image, text_input, model_id)
167
+ bbox_results = convert_to_od_format(results['<OPEN_VOCABULARY_DETECTION>'])
168
+ fig = plot_bbox(image, bbox_results)
169
+ return results, fig_to_pil(fig)
170
+ elif task_prompt == 'Region to Category':
171
+ task_prompt = '<REGION_TO_CATEGORY>'
172
+ results = run_example(task_prompt, image, text_input, model_id)
173
+ return results, None
174
+ elif task_prompt == 'Region to Description':
175
+ task_prompt = '<REGION_TO_DESCRIPTION>'
176
+ results = run_example(task_prompt, image, text_input, model_id)
177
+ return results, None
178
+ elif task_prompt == 'OCR':
179
+ task_prompt = '<OCR>'
180
+ results = run_example(task_prompt, image, model_id=model_id)
181
+ return results, None
182
+ elif task_prompt == 'OCR with Region':
183
+ task_prompt = '<OCR_WITH_REGION>'
184
+ results = run_example(task_prompt, image, model_id=model_id)
185
+ output_image = copy.deepcopy(image)
186
+ output_image = draw_ocr_bboxes(output_image, results['<OCR_WITH_REGION>'])
187
+ return results, output_image
188
+ else:
189
+ return "", None # Return empty string and None for unknown task prompts
190
+
191
+ css = """
192
+ #output {
193
+ height: 500px;
194
+ overflow: auto;
195
+ border: 1px solid #ccc;
196
+ }
197
+ """
198
+
199
+ with gr.Blocks(css=css) as demo:
200
+ gr.Markdown(DESCRIPTION)
201
+ with gr.Tab(label="Florence-2 Image Captioning"):
202
+ with gr.Row():
203
+ with gr.Column():
204
+ input_img = gr.Image(label="Input Picture",height=400)
205
+ model_selector = gr.Dropdown(choices=list(models.keys()), label="Model", value='microsoft/Florence-2-base')
206
+ task_prompt = gr.Dropdown(choices=[
207
+ 'Caption', 'Detailed Caption', 'More Detailed Caption', 'Object Detection',
208
+ 'Dense Region Caption', 'Region Proposal', 'Caption to Phrase Grounding',
209
+ 'Referring Expression Segmentation', 'Region to Segmentation',
210
+ 'Open Vocabulary Detection', 'Region to Category', 'Region to Description',
211
+ 'OCR', 'OCR with Region'
212
+ ], label="Task Prompt", value= 'Caption')
213
+ text_input = gr.Textbox(label="Text Input (optional)")
214
+ submit_btn = gr.Button(value="Submit")
215
+ with gr.Column():
216
+ output_text = gr.Textbox(label="Output Text")
217
+ output_img = gr.Image(label="Output Image")
218
+
219
+ gr.Examples(
220
+ examples=[
221
+ ["image1.jpg", 'Object Detection'],
222
+ ["image2.jpg", 'OCR with Region']
223
+ ],
224
+ inputs=[input_img, task_prompt],
225
+ outputs=[output_text, output_img],
226
+ fn=process_image,
227
+ cache_examples=True,
228
+ label='Try examples'
229
+ )
230
+
231
+ submit_btn.click(process_image, [input_img, task_prompt, text_input, model_selector], [output_text, output_img])
232
+
233
+ demo.launch(debug=True,server_name="0.0.0.0")
image1.jpg ADDED
image2.jpg ADDED
pre-requirements.txt ADDED
@@ -0,0 +1 @@
 
 
1
+ pip>=23.0.0
requirements.txt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ spaces
2
+ transformers
3
+ timm