DrChamyoung commited on
Commit
ec19bce
1 Parent(s): 821073f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +226 -133
app.py CHANGED
@@ -1,146 +1,239 @@
1
  import gradio as gr
2
- import numpy as np
 
 
 
 
 
 
 
 
 
 
3
  import random
4
- from diffusers import DiffusionPipeline
5
- import torch
6
-
7
- device = "cuda" if torch.cuda.is_available() else "cpu"
8
-
9
- if torch.cuda.is_available():
10
- torch.cuda.max_memory_allocated(device=device)
11
- pipe = DiffusionPipeline.from_pretrained("stabilityai/sdxl-turbo", torch_dtype=torch.float16, variant="fp16", use_safetensors=True)
12
- pipe.enable_xformers_memory_efficient_attention()
13
- pipe = pipe.to(device)
14
- else:
15
- pipe = DiffusionPipeline.from_pretrained("stabilityai/sdxl-turbo", use_safetensors=True)
16
- pipe = pipe.to(device)
17
-
18
- MAX_SEED = np.iinfo(np.int32).max
19
- MAX_IMAGE_SIZE = 1024
20
-
21
- def infer(prompt, negative_prompt, seed, randomize_seed, width, height, guidance_scale, num_inference_steps):
22
-
23
- if randomize_seed:
24
- seed = random.randint(0, MAX_SEED)
25
-
26
- generator = torch.Generator().manual_seed(seed)
27
-
28
- image = pipe(
29
- prompt = prompt,
30
- negative_prompt = negative_prompt,
31
- guidance_scale = guidance_scale,
32
- num_inference_steps = num_inference_steps,
33
- width = width,
34
- height = height,
35
- generator = generator
36
- ).images[0]
37
-
38
- return image
39
 
40
- examples = [
41
- "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k",
42
- "An astronaut riding a green horse",
43
- "A delicious ceviche cheesecake slice",
44
- ]
45
 
46
- css="""
47
- #col-container {
48
- margin: 0 auto;
49
- max-width: 520px;
 
50
  }
51
- """
52
 
53
- if torch.cuda.is_available():
54
- power_device = "GPU"
55
- else:
56
- power_device = "CPU"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
57
 
58
  with gr.Blocks(css=css) as demo:
59
-
60
- with gr.Column(elem_id="col-container"):
61
- gr.Markdown(f"""
62
- # Text-to-Image Gradio Template
63
- Currently running on {power_device}.
64
- """)
65
-
66
  with gr.Row():
67
-
68
- prompt = gr.Text(
69
- label="Prompt",
70
- show_label=False,
71
- max_lines=1,
72
- placeholder="Enter your prompt",
73
- container=False,
74
- )
75
-
76
- run_button = gr.Button("Run", scale=0)
77
-
78
- result = gr.Image(label="Result", show_label=False)
79
-
80
- with gr.Accordion("Advanced Settings", open=False):
81
-
82
- negative_prompt = gr.Text(
83
- label="Negative prompt",
84
- max_lines=1,
85
- placeholder="Enter a negative prompt",
86
- visible=False,
87
- )
88
-
89
- seed = gr.Slider(
90
- label="Seed",
91
- minimum=0,
92
- maximum=MAX_SEED,
93
- step=1,
94
- value=0,
95
- )
96
-
97
- randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
98
-
99
- with gr.Row():
100
-
101
- width = gr.Slider(
102
- label="Width",
103
- minimum=256,
104
- maximum=MAX_IMAGE_SIZE,
105
- step=32,
106
- value=512,
107
- )
108
-
109
- height = gr.Slider(
110
- label="Height",
111
- minimum=256,
112
- maximum=MAX_IMAGE_SIZE,
113
- step=32,
114
- value=512,
115
- )
116
-
117
- with gr.Row():
118
-
119
- guidance_scale = gr.Slider(
120
- label="Guidance scale",
121
- minimum=0.0,
122
- maximum=10.0,
123
- step=0.1,
124
- value=0.0,
125
- )
126
-
127
- num_inference_steps = gr.Slider(
128
- label="Number of inference steps",
129
- minimum=1,
130
- maximum=12,
131
- step=1,
132
- value=2,
133
- )
134
-
135
  gr.Examples(
136
- examples = examples,
137
- inputs = [prompt]
 
 
 
 
 
 
 
138
  )
139
 
140
- run_button.click(
141
- fn = infer,
142
- inputs = [prompt, negative_prompt, seed, randomize_seed, width, height, guidance_scale, num_inference_steps],
143
- outputs = [result]
144
- )
145
 
146
- demo.queue().launch()
 
1
  import gradio as gr
2
+ from transformers import AutoProcessor, AutoModelForCausalLM
3
+ import spaces
4
+
5
+ import requests
6
+ import copy
7
+
8
+ from PIL import Image, ImageDraw, ImageFont
9
+ import io
10
+ import matplotlib.pyplot as plt
11
+ import matplotlib.patches as patches
12
+
13
  import random
14
+ import numpy as np
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
15
 
16
+ import subprocess
17
+ subprocess.run('pip install flash-attn --no-build-isolation', env={'FLASH_ATTENTION_SKIP_CUDA_BUILD': "TRUE"}, shell=True)
 
 
 
18
 
19
+ models = {
20
+ 'microsoft/Florence-2-large-ft': AutoModelForCausalLM.from_pretrained('microsoft/Florence-2-large-ft', trust_remote_code=True).to("cuda").eval(),
21
+ 'microsoft/Florence-2-large': AutoModelForCausalLM.from_pretrained('microsoft/Florence-2-large', trust_remote_code=True).to("cuda").eval(),
22
+ 'microsoft/Florence-2-base-ft': AutoModelForCausalLM.from_pretrained('microsoft/Florence-2-base-ft', trust_remote_code=True).to("cuda").eval(),
23
+ 'microsoft/Florence-2-base': AutoModelForCausalLM.from_pretrained('microsoft/Florence-2-base', trust_remote_code=True).to("cuda").eval(),
24
  }
 
25
 
26
+ processors = {
27
+ 'microsoft/Florence-2-large-ft': AutoProcessor.from_pretrained('microsoft/Florence-2-large-ft', trust_remote_code=True),
28
+ 'microsoft/Florence-2-large': AutoProcessor.from_pretrained('microsoft/Florence-2-large', trust_remote_code=True),
29
+ 'microsoft/Florence-2-base-ft': AutoProcessor.from_pretrained('microsoft/Florence-2-base-ft', trust_remote_code=True),
30
+ 'microsoft/Florence-2-base': AutoProcessor.from_pretrained('microsoft/Florence-2-base', trust_remote_code=True),
31
+ }
32
+
33
+
34
+ DESCRIPTION = "# [Florence-2 Demo](https://huggingface.co/microsoft/Florence-2-large)"
35
+
36
+ colormap = ['blue','orange','green','purple','brown','pink','gray','olive','cyan','red',
37
+ 'lime','indigo','violet','aqua','magenta','coral','gold','tan','skyblue']
38
+
39
+ def fig_to_pil(fig):
40
+ buf = io.BytesIO()
41
+ fig.savefig(buf, format='png')
42
+ buf.seek(0)
43
+ return Image.open(buf)
44
+
45
+ @spaces.GPU
46
+ def run_example(task_prompt, image, text_input=None, model_id='microsoft/Florence-2-large'):
47
+ model = models[model_id]
48
+ processor = processors[model_id]
49
+ if text_input is None:
50
+ prompt = task_prompt
51
+ else:
52
+ prompt = task_prompt + text_input
53
+ inputs = processor(text=prompt, images=image, return_tensors="pt").to("cuda")
54
+ generated_ids = model.generate(
55
+ input_ids=inputs["input_ids"],
56
+ pixel_values=inputs["pixel_values"],
57
+ max_new_tokens=1024,
58
+ early_stopping=False,
59
+ do_sample=False,
60
+ num_beams=3,
61
+ )
62
+ generated_text = processor.batch_decode(generated_ids, skip_special_tokens=False)[0]
63
+ parsed_answer = processor.post_process_generation(
64
+ generated_text,
65
+ task=task_prompt,
66
+ image_size=(image.width, image.height)
67
+ )
68
+ return parsed_answer
69
+
70
+ def plot_bbox(image, data):
71
+ fig, ax = plt.subplots()
72
+ ax.imshow(image)
73
+ for bbox, label in zip(data['bboxes'], data['labels']):
74
+ x1, y1, x2, y2 = bbox
75
+ rect = patches.Rectangle((x1, y1), x2-x1, y2-y1, linewidth=1, edgecolor='r', facecolor='none')
76
+ ax.add_patch(rect)
77
+ plt.text(x1, y1, label, color='white', fontsize=8, bbox=dict(facecolor='red', alpha=0.5))
78
+ ax.axis('off')
79
+ return fig
80
+
81
+ def draw_polygons(image, prediction, fill_mask=False):
82
+
83
+ draw = ImageDraw.Draw(image)
84
+ scale = 1
85
+ for polygons, label in zip(prediction['polygons'], prediction['labels']):
86
+ color = random.choice(colormap)
87
+ fill_color = random.choice(colormap) if fill_mask else None
88
+ for _polygon in polygons:
89
+ _polygon = np.array(_polygon).reshape(-1, 2)
90
+ if len(_polygon) < 3:
91
+ print('Invalid polygon:', _polygon)
92
+ continue
93
+ _polygon = (_polygon * scale).reshape(-1).tolist()
94
+ if fill_mask:
95
+ draw.polygon(_polygon, outline=color, fill=fill_color)
96
+ else:
97
+ draw.polygon(_polygon, outline=color)
98
+ draw.text((_polygon[0] + 8, _polygon[1] + 2), label, fill=color)
99
+ return image
100
+
101
+ def convert_to_od_format(data):
102
+ bboxes = data.get('bboxes', [])
103
+ labels = data.get('bboxes_labels', [])
104
+ od_results = {
105
+ 'bboxes': bboxes,
106
+ 'labels': labels
107
+ }
108
+ return od_results
109
+
110
+ def draw_ocr_bboxes(image, prediction):
111
+ scale = 1
112
+ draw = ImageDraw.Draw(image)
113
+ bboxes, labels = prediction['quad_boxes'], prediction['labels']
114
+ for box, label in zip(bboxes, labels):
115
+ color = random.choice(colormap)
116
+ new_box = (np.array(box) * scale).tolist()
117
+ draw.polygon(new_box, width=3, outline=color)
118
+ draw.text((new_box[0]+8, new_box[1]+2),
119
+ "{}".format(label),
120
+ align="right",
121
+ fill=color)
122
+ return image
123
+
124
+ def process_image(image, task_prompt, text_input=None, model_id='microsoft/Florence-2-large'):
125
+ image = Image.fromarray(image) # Convert NumPy array to PIL Image
126
+ if task_prompt == 'Caption':
127
+ task_prompt = '<CAPTION>'
128
+ results = run_example(task_prompt, image, model_id=model_id)
129
+ return results, None
130
+ elif task_prompt == 'Detailed Caption':
131
+ task_prompt = '<DETAILED_CAPTION>'
132
+ results = run_example(task_prompt, image, model_id=model_id)
133
+ return results, None
134
+ elif task_prompt == 'More Detailed Caption':
135
+ task_prompt = '<MORE_DETAILED_CAPTION>'
136
+ results = run_example(task_prompt, image, model_id=model_id)
137
+ return results, None
138
+ elif task_prompt == 'Object Detection':
139
+ task_prompt = '<OD>'
140
+ results = run_example(task_prompt, image, model_id=model_id)
141
+ fig = plot_bbox(image, results['<OD>'])
142
+ return results, fig_to_pil(fig)
143
+ elif task_prompt == 'Dense Region Caption':
144
+ task_prompt = '<DENSE_REGION_CAPTION>'
145
+ results = run_example(task_prompt, image, model_id=model_id)
146
+ fig = plot_bbox(image, results['<DENSE_REGION_CAPTION>'])
147
+ return results, fig_to_pil(fig)
148
+ elif task_prompt == 'Region Proposal':
149
+ task_prompt = '<REGION_PROPOSAL>'
150
+ results = run_example(task_prompt, image, model_id=model_id)
151
+ fig = plot_bbox(image, results['<REGION_PROPOSAL>'])
152
+ return results, fig_to_pil(fig)
153
+ elif task_prompt == 'Caption to Phrase Grounding':
154
+ task_prompt = '<CAPTION_TO_PHRASE_GROUNDING>'
155
+ results = run_example(task_prompt, image, text_input, model_id)
156
+ fig = plot_bbox(image, results['<CAPTION_TO_PHRASE_GROUNDING>'])
157
+ return results, fig_to_pil(fig)
158
+ elif task_prompt == 'Referring Expression Segmentation':
159
+ task_prompt = '<REFERRING_EXPRESSION_SEGMENTATION>'
160
+ results = run_example(task_prompt, image, text_input, model_id)
161
+ output_image = copy.deepcopy(image)
162
+ output_image = draw_polygons(output_image, results['<REFERRING_EXPRESSION_SEGMENTATION>'], fill_mask=True)
163
+ return results, output_image
164
+ elif task_prompt == 'Region to Segmentation':
165
+ task_prompt = '<REGION_TO_SEGMENTATION>'
166
+ results = run_example(task_prompt, image, text_input, model_id)
167
+ output_image = copy.deepcopy(image)
168
+ output_image = draw_polygons(output_image, results['<REGION_TO_SEGMENTATION>'], fill_mask=True)
169
+ return results, output_image
170
+ elif task_prompt == 'Open Vocabulary Detection':
171
+ task_prompt = '<OPEN_VOCABULARY_DETECTION>'
172
+ results = run_example(task_prompt, image, text_input, model_id)
173
+ bbox_results = convert_to_od_format(results['<OPEN_VOCABULARY_DETECTION>'])
174
+ fig = plot_bbox(image, bbox_results)
175
+ return results, fig_to_pil(fig)
176
+ elif task_prompt == 'Region to Category':
177
+ task_prompt = '<REGION_TO_CATEGORY>'
178
+ results = run_example(task_prompt, image, text_input, model_id)
179
+ return results, None
180
+ elif task_prompt == 'Region to Description':
181
+ task_prompt = '<REGION_TO_DESCRIPTION>'
182
+ results = run_example(task_prompt, image, text_input, model_id)
183
+ return results, None
184
+ elif task_prompt == 'OCR':
185
+ task_prompt = '<OCR>'
186
+ results = run_example(task_prompt, image, model_id=model_id)
187
+ return results, None
188
+ elif task_prompt == 'OCR with Region':
189
+ task_prompt = '<OCR_WITH_REGION>'
190
+ results = run_example(task_prompt, image, model_id=model_id)
191
+ output_image = copy.deepcopy(image)
192
+ output_image = draw_ocr_bboxes(output_image, results['<OCR_WITH_REGION>'])
193
+ return results, output_image
194
+ else:
195
+ return "", None # Return empty string and None for unknown task prompts
196
+
197
+ css = """
198
+ #output {
199
+ height: 500px;
200
+ overflow: auto;
201
+ border: 1px solid #ccc;
202
+ }
203
+ """
204
 
205
  with gr.Blocks(css=css) as demo:
206
+ gr.Markdown(DESCRIPTION)
207
+ with gr.Tab(label="Florence-2 Image Captioning"):
 
 
 
 
 
208
  with gr.Row():
209
+ with gr.Column():
210
+ input_img = gr.Image(label="Input Picture")
211
+ model_selector = gr.Dropdown(choices=list(models.keys()), label="Model", value='microsoft/Florence-2-large')
212
+ task_prompt = gr.Dropdown(choices=[
213
+ 'Caption', 'Detailed Caption', 'More Detailed Caption', 'Object Detection',
214
+ 'Dense Region Caption', 'Region Proposal', 'Caption to Phrase Grounding',
215
+ 'Referring Expression Segmentation', 'Region to Segmentation',
216
+ 'Open Vocabulary Detection', 'Region to Category', 'Region to Description',
217
+ 'OCR', 'OCR with Region'
218
+ ], label="Task Prompt", value= 'Caption')
219
+ text_input = gr.Textbox(label="Text Input (optional)")
220
+ submit_btn = gr.Button(value="Submit")
221
+ with gr.Column():
222
+ output_text = gr.Textbox(label="Output Text")
223
+ output_img = gr.Image(label="Output Image")
224
+
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
225
  gr.Examples(
226
+ examples=[
227
+ ["image1.jpg", 'Object Detection'],
228
+ ["image2.jpg", 'OCR with Region']
229
+ ],
230
+ inputs=[input_img, task_prompt],
231
+ outputs=[output_text, output_img],
232
+ fn=process_image,
233
+ cache_examples=True,
234
+ label='Try examples'
235
  )
236
 
237
+ submit_btn.click(process_image, [input_img, task_prompt, text_input, model_selector], [output_text, output_img])
 
 
 
 
238
 
239
+ demo.launch(debug=True)