aiqcamp commited on
Commit
efc13cd
โ€ข
1 Parent(s): 9a8bbc6

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +399 -39
app.py CHANGED
@@ -1,25 +1,254 @@
 
 
 
 
 
 
 
 
 
1
  import gradio as gr
 
 
 
 
2
  from transformers import AutoProcessor, AutoModelForCausalLM
3
- import spaces
4
- from PIL import Image
5
-
6
  import subprocess
7
- subprocess.run('pip install flash-attn --no-build-isolation', env={'FLASH_ATTENTION_SKIP_CUDA_BUILD': "TRUE"}, shell=True)
8
 
9
- models = {
10
- 'gokaygokay/Florence-2-Flux-Large': AutoModelForCausalLM.from_pretrained('gokaygokay/Florence-2-Flux-Large', trust_remote_code=True).eval(),
11
- 'gokaygokay/Florence-2-Flux': AutoModelForCausalLM.from_pretrained('gokaygokay/Florence-2-Flux', trust_remote_code=True).eval(),
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
12
  }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
13
 
14
- processors = {
15
- 'gokaygokay/Florence-2-Flux-Large': AutoProcessor.from_pretrained('gokaygokay/Florence-2-Flux-Large', trust_remote_code=True),
16
- 'gokaygokay/Florence-2-Flux': AutoProcessor.from_pretrained('gokaygokay/Florence-2-Flux', trust_remote_code=True),
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
17
  }
 
 
 
 
 
 
18
 
 
 
19
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
20
 
21
  @spaces.GPU
22
- def run_example(image, model_name='gokaygokay/Florence-2-Flux-Large'):
23
  image = Image.fromarray(image)
24
  task_prompt = "<DESCRIPTION>"
25
  prompt = task_prompt + "Describe this image in great detail."
@@ -27,8 +256,8 @@ def run_example(image, model_name='gokaygokay/Florence-2-Flux-Large'):
27
  if image.mode != "RGB":
28
  image = image.convert("RGB")
29
 
30
- model = models[model_name]
31
- processor = processors[model_name]
32
 
33
  inputs = processor(text=prompt, images=image, return_tensors="pt")
34
  generated_ids = model.generate(
@@ -42,35 +271,166 @@ def run_example(image, model_name='gokaygokay/Florence-2-Flux-Large'):
42
  parsed_answer = processor.post_process_generation(generated_text, task=task_prompt, image_size=(image.width, image.height))
43
  return parsed_answer["<DESCRIPTION>"]
44
 
45
- css = """
46
- footer {
47
- visibility: hidden;
48
- }
49
- """
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
50
 
51
- with gr.Blocks(theme="Nymbo/Nymbo_Theme", css=css) as demo:
 
52
 
 
 
53
 
 
 
 
 
54
  with gr.Row():
55
- with gr.Column():
56
- input_img = gr.Image(label="Input Picture")
57
- model_selector = gr.Dropdown(choices=list(models.keys()), label="Model", value='gokaygokay/Florence-2-Flux-Large')
58
- submit_btn = gr.Button(value="Submit")
59
- with gr.Column():
60
- output_text = gr.Textbox(label="Output Text")
61
-
62
- gr.Examples(
63
- [["image1.jpg"],
64
- ["image2.jpg"],
65
- ["image3.png"],
66
- ["image5.jpg"]],
67
- inputs=[input_img, model_selector],
68
- outputs=[output_text],
69
- fn=run_example,
70
- label='Try captioning on below examples',
71
- cache_examples=True
72
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
73
 
74
- submit_btn.click(run_example, [input_img, model_selector], [output_text])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
75
 
76
- demo.launch(debug=True)
 
 
1
+ import spaces
2
+ import argparse
3
+ import os
4
+ import time
5
+ from os import path
6
+ import shutil
7
+ from datetime import datetime
8
+ from safetensors.torch import load_file
9
+ from huggingface_hub import hf_hub_download
10
  import gradio as gr
11
+ import torch
12
+ from diffusers import FluxPipeline
13
+ from diffusers.pipelines.stable_diffusion import safety_checker
14
+ from PIL import Image
15
  from transformers import AutoProcessor, AutoModelForCausalLM
 
 
 
16
  import subprocess
 
17
 
18
+ # Flash Attention ์„ค์น˜
19
+ subprocess.run('pip install flash-attn --no-build-isolation',
20
+ env={'FLASH_ATTENTION_SKIP_CUDA_BUILD': "TRUE"},
21
+ shell=True)
22
+
23
+ # Setup and initialization code
24
+ cache_path = path.join(path.dirname(path.abspath(__file__)), "models")
25
+ PERSISTENT_DIR = os.environ.get("PERSISTENT_DIR", ".")
26
+ gallery_path = path.join(PERSISTENT_DIR, "gallery")
27
+
28
+ os.environ["TRANSFORMERS_CACHE"] = cache_path
29
+ os.environ["HF_HUB_CACHE"] = cache_path
30
+ os.environ["HF_HOME"] = cache_path
31
+
32
+ torch.backends.cuda.matmul.allow_tf32 = True
33
+
34
+ # Create gallery directory
35
+ if not path.exists(gallery_path):
36
+ os.makedirs(gallery_path, exist_ok=True)
37
+
38
+ # Florence ๋ชจ๋ธ ์ดˆ๊ธฐํ™”
39
+ florence_models = {
40
+ 'gokaygokay/Florence-2-Flux-Large': AutoModelForCausalLM.from_pretrained(
41
+ 'gokaygokay/Florence-2-Flux-Large',
42
+ trust_remote_code=True
43
+ ).eval(),
44
+ 'gokaygokay/Florence-2-Flux': AutoModelForCausalLM.from_pretrained(
45
+ 'gokaygokay/Florence-2-Flux',
46
+ trust_remote_code=True
47
+ ).eval(),
48
+ }
49
+
50
+ florence_processors = {
51
+ 'gokaygokay/Florence-2-Flux-Large': AutoProcessor.from_pretrained(
52
+ 'gokaygokay/Florence-2-Flux-Large',
53
+ trust_remote_code=True
54
+ ),
55
+ 'gokaygokay/Florence-2-Flux': AutoProcessor.from_pretrained(
56
+ 'gokaygokay/Florence-2-Flux',
57
+ trust_remote_code=True
58
+ ),
59
+ }
60
+
61
+ def filter_prompt(prompt):
62
+ inappropriate_keywords = [
63
+ "nude", "naked", "nsfw", "porn", "sex", "explicit", "adult", "xxx",
64
+ "erotic", "sensual", "seductive", "provocative", "intimate",
65
+ "violence", "gore", "blood", "death", "kill", "murder", "torture",
66
+ "drug", "suicide", "abuse", "hate", "discrimination"
67
+ ]
68
+
69
+ prompt_lower = prompt.lower()
70
+
71
+ for keyword in inappropriate_keywords:
72
+ if keyword in prompt_lower:
73
+ return False, "๋ถ€์ ์ ˆํ•œ ๋‚ด์šฉ์ด ํฌํ•จ๋œ ํ”„๋กฌํ”„ํŠธ์ž…๋‹ˆ๋‹ค."
74
+
75
+ return True, prompt
76
+
77
+ class timer:
78
+ def __init__(self, method_name="timed process"):
79
+ self.method = method_name
80
+ def __enter__(self):
81
+ self.start = time.time()
82
+ print(f"{self.method} starts")
83
+ def __exit__(self, exc_type, exc_val, exc_tb):
84
+ end = time.time()
85
+ print(f"{self.method} took {str(round(end - self.start, 2))}s")
86
+
87
+ # Model initialization
88
+ if not path.exists(cache_path):
89
+ os.makedirs(cache_path, exist_ok=True)
90
+
91
+ pipe = FluxPipeline.from_pretrained(
92
+ "black-forest-labs/FLUX.1-dev",
93
+ torch_dtype=torch.bfloat16
94
+ )
95
+ pipe.load_lora_weights(
96
+ hf_hub_download(
97
+ "ByteDance/Hyper-SD",
98
+ "Hyper-FLUX.1-dev-8steps-lora.safetensors"
99
+ )
100
+ )
101
+ pipe.fuse_lora(lora_scale=0.125)
102
+ pipe.to(device="cuda", dtype=torch.bfloat16)
103
+ pipe.safety_checker = safety_checker.StableDiffusionSafetyChecker.from_pretrained(
104
+ "CompVis/stable-diffusion-safety-checker"
105
+ )
106
+
107
+ # CSS ์Šคํƒ€์ผ
108
+ css = """
109
+ footer {display: none !important}
110
+ .gradio-container {
111
+ max-width: 1200px;
112
+ margin: auto;
113
+ }
114
+ .contain {
115
+ background: rgba(255, 255, 255, 0.05);
116
+ border-radius: 12px;
117
+ padding: 20px;
118
+ }
119
+ .generate-btn {
120
+ background: linear-gradient(90deg, #4B79A1 0%, #283E51 100%) !important;
121
+ border: none !important;
122
+ color: white !important;
123
+ }
124
+ .generate-btn:hover {
125
+ transform: translateY(-2px);
126
+ box-shadow: 0 5px 15px rgba(0,0,0,0.2);
127
+ }
128
+ .title {
129
+ text-align: center;
130
+ font-size: 2.5em;
131
+ font-weight: bold;
132
+ margin-bottom: 1em;
133
+ background: linear-gradient(90deg, #4B79A1 0%, #283E51 100%);
134
+ -webkit-background-clip: text;
135
+ -webkit-text-fill-color: transparent;
136
  }
137
+ .tabs {
138
+ margin-top: 20px;
139
+ border-radius: 10px;
140
+ overflow: hidden;
141
+ }
142
+ .tab-nav {
143
+ background: linear-gradient(90deg, #4B79A1 0%, #283E51 100%);
144
+ padding: 10px;
145
+ }
146
+ .tab-nav button {
147
+ color: white;
148
+ border: none;
149
+ padding: 10px 20px;
150
+ margin: 0 5px;
151
+ border-radius: 5px;
152
+ transition: all 0.3s ease;
153
+ }
154
+ .tab-nav button.selected {
155
+ background: rgba(255, 255, 255, 0.2);
156
+ }
157
+ .image-upload-container {
158
+ border: 2px dashed #4B79A1;
159
+ border-radius: 10px;
160
+ padding: 20px;
161
+ text-align: center;
162
+ transition: all 0.3s ease;
163
+ }
164
+ .image-upload-container:hover {
165
+ border-color: #283E51;
166
+ background: rgba(75, 121, 161, 0.1);
167
+ }
168
+ """
169
 
170
+ # CSS์— ์ถ”๊ฐ€ํ•  ์Šคํƒ€์ผ
171
+ additional_css = """
172
+ .primary-btn {
173
+ background: linear-gradient(90deg, #4B79A1 0%, #283E51 100%) !important;
174
+ font-size: 1.2em !important;
175
+ padding: 12px 20px !important;
176
+ margin-top: 20px !important;
177
+ }
178
+ hr {
179
+ border: none;
180
+ border-top: 1px solid rgba(75, 121, 161, 0.2);
181
+ margin: 20px 0;
182
+ }
183
+ .input-section {
184
+ background: rgba(255, 255, 255, 0.03);
185
+ border-radius: 12px;
186
+ padding: 20px;
187
+ margin-bottom: 20px;
188
  }
189
+ .output-section {
190
+ background: rgba(255, 255, 255, 0.03);
191
+ border-radius: 12px;
192
+ padding: 20px;
193
+ }
194
+ """
195
 
196
+ # ๊ธฐ์กด CSS์— ์ƒˆ๋กœ์šด ์Šคํƒ€์ผ ์ถ”๊ฐ€
197
+ css = css + additional_css
198
 
199
+ def save_image(image):
200
+ """Save the generated image and return the path"""
201
+ try:
202
+ if not os.path.exists(gallery_path):
203
+ try:
204
+ os.makedirs(gallery_path, exist_ok=True)
205
+ except Exception as e:
206
+ print(f"Failed to create gallery directory: {str(e)}")
207
+ return None
208
+
209
+ timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
210
+ random_suffix = os.urandom(4).hex()
211
+ filename = f"generated_{timestamp}_{random_suffix}.png"
212
+ filepath = os.path.join(gallery_path, filename)
213
+
214
+ try:
215
+ if isinstance(image, Image.Image):
216
+ image.save(filepath, "PNG", quality=100)
217
+ else:
218
+ image = Image.fromarray(image)
219
+ image.save(filepath, "PNG", quality=100)
220
+
221
+ if not os.path.exists(filepath):
222
+ print(f"Warning: Failed to verify saved image at {filepath}")
223
+ return None
224
+
225
+ return filepath
226
+ except Exception as e:
227
+ print(f"Failed to save image: {str(e)}")
228
+ return None
229
+
230
+ except Exception as e:
231
+ print(f"Error in save_image: {str(e)}")
232
+ return None
233
+
234
+ def load_gallery():
235
+ try:
236
+ os.makedirs(gallery_path, exist_ok=True)
237
+
238
+ image_files = []
239
+ for f in os.listdir(gallery_path):
240
+ if f.lower().endswith(('.png', '.jpg', '.jpeg')):
241
+ full_path = os.path.join(gallery_path, f)
242
+ image_files.append((full_path, os.path.getmtime(full_path)))
243
+
244
+ image_files.sort(key=lambda x: x[1], reverse=True)
245
+ return [f[0] for f in image_files]
246
+ except Exception as e:
247
+ print(f"Error loading gallery: {str(e)}")
248
+ return []
249
 
250
  @spaces.GPU
251
+ def generate_caption(image, model_name='gokaygokay/Florence-2-Flux-Large'):
252
  image = Image.fromarray(image)
253
  task_prompt = "<DESCRIPTION>"
254
  prompt = task_prompt + "Describe this image in great detail."
 
256
  if image.mode != "RGB":
257
  image = image.convert("RGB")
258
 
259
+ model = florence_models[model_name]
260
+ processor = florence_processors[model_name]
261
 
262
  inputs = processor(text=prompt, images=image, return_tensors="pt")
263
  generated_ids = model.generate(
 
271
  parsed_answer = processor.post_process_generation(generated_text, task=task_prompt, image_size=(image.width, image.height))
272
  return parsed_answer["<DESCRIPTION>"]
273
 
274
+ @spaces.GPU
275
+ def process_and_save_image(height, width, steps, scales, prompt, seed):
276
+ is_safe, filtered_prompt = filter_prompt(prompt)
277
+ if not is_safe:
278
+ gr.Warning("๋ถ€์ ์ ˆํ•œ ๋‚ด์šฉ์ด ํฌํ•จ๋œ ํ”„๋กฌํ”„ํŠธ์ž…๋‹ˆ๋‹ค.")
279
+ return None, load_gallery()
280
+
281
+ with torch.inference_mode(), torch.autocast("cuda", dtype=torch.bfloat16), timer("inference"):
282
+ try:
283
+ generated_image = pipe(
284
+ prompt=[filtered_prompt],
285
+ generator=torch.Generator().manual_seed(int(seed)),
286
+ num_inference_steps=int(steps),
287
+ guidance_scale=float(scales),
288
+ height=int(height),
289
+ width=int(width),
290
+ max_sequence_length=256
291
+ ).images[0]
292
+
293
+ saved_path = save_image(generated_image)
294
+ if saved_path is None:
295
+ print("Warning: Failed to save generated image")
296
+
297
+ return generated_image, load_gallery()
298
+ except Exception as e:
299
+ print(f"Error in image generation: {str(e)}")
300
+ return None, load_gallery()
301
 
302
+ def get_random_seed():
303
+ return torch.randint(0, 1000000, (1,)).item()
304
 
305
+ def update_seed():
306
+ return get_random_seed()
307
 
308
+ with gr.Blocks(theme=gr.themes.Soft(), css=css) as demo:
309
+ gr.HTML('<div class="title">AI Image Generator & Caption</div>')
310
+ gr.HTML('<div style="text-align: center; margin-bottom: 2em;">Upload an image for caption or create from text description</div>')
311
+
312
  with gr.Row():
313
+ # ์™ผ์ชฝ ์ปฌ๋Ÿผ: ์ž…๋ ฅ ์„น์…˜
314
+ with gr.Column(scale=3):
315
+ # ์ด๋ฏธ์ง€ ์—…๋กœ๋“œ ์„น์…˜
316
+ input_image = gr.Image(
317
+ label="Upload Image (Optional)",
318
+ type="numpy",
319
+ elem_classes=["image-upload-container"]
320
+ )
321
+
322
+ florence_model = gr.Dropdown(
323
+ choices=list(florence_models.keys()),
324
+ label="Caption Model",
325
+ value='gokaygokay/Florence-2-Flux-Large',
326
+ visible=True
327
+ )
328
+
329
+ caption_button = gr.Button(
330
+ "๐Ÿ” Generate Caption from Image",
331
+ elem_classes=["generate-btn"]
332
+ )
333
+
334
+ # ๊ตฌ๋ถ„์„ 
335
+ gr.HTML('<hr style="margin: 20px 0;">')
336
+
337
+ # ํ…์ŠคํŠธ ํ”„๋กฌํ”„ํŠธ ์„น์…˜
338
+ prompt = gr.Textbox(
339
+ label="Image Description",
340
+ placeholder="Enter text description or use generated caption above...",
341
+ lines=3
342
+ )
343
+
344
+ with gr.Accordion("Advanced Settings", open=False):
345
+ with gr.Row():
346
+ height = gr.Slider(
347
+ label="Height",
348
+ minimum=256,
349
+ maximum=1152,
350
+ step=64,
351
+ value=1024
352
+ )
353
+ width = gr.Slider(
354
+ label="Width",
355
+ minimum=256,
356
+ maximum=1152,
357
+ step=64,
358
+ value=1024
359
+ )
360
+
361
+ with gr.Row():
362
+ steps = gr.Slider(
363
+ label="Inference Steps",
364
+ minimum=6,
365
+ maximum=25,
366
+ step=1,
367
+ value=8
368
+ )
369
+ scales = gr.Slider(
370
+ label="Guidance Scale",
371
+ minimum=0.0,
372
+ maximum=5.0,
373
+ step=0.1,
374
+ value=3.5
375
+ )
376
+
377
+ seed = gr.Number(
378
+ label="Seed",
379
+ value=get_random_seed(),
380
+ precision=0
381
+ )
382
+
383
+ randomize_seed = gr.Button(
384
+ "๐ŸŽฒ Randomize Seed",
385
+ elem_classes=["generate-btn"]
386
+ )
387
+
388
+ generate_btn = gr.Button(
389
+ "โœจ Generate Image",
390
+ elem_classes=["generate-btn", "primary-btn"]
391
+ )
392
 
393
+ # ์˜ค๋ฅธ์ชฝ ์ปฌ๋Ÿผ: ์ถœ๋ ฅ ์„น์…˜
394
+ with gr.Column(scale=4):
395
+ output = gr.Image(
396
+ label="Generated Image",
397
+ elem_classes=["output-image"]
398
+ )
399
+
400
+ gallery = gr.Gallery(
401
+ label="Generated Images Gallery",
402
+ show_label=True,
403
+ columns=[4],
404
+ rows=[2],
405
+ height="auto",
406
+ object_fit="cover",
407
+ elem_classes=["gallery-container"]
408
+ )
409
+
410
+ gallery.value = load_gallery()
411
+
412
+ # Event handlers
413
+ caption_button.click(
414
+ generate_caption,
415
+ inputs=[input_image, florence_model],
416
+ outputs=[prompt]
417
+ )
418
+
419
+ generate_btn.click(
420
+ process_and_save_image,
421
+ inputs=[height, width, steps, scales, prompt, seed],
422
+ outputs=[output, gallery]
423
+ )
424
+
425
+ randomize_seed.click(
426
+ update_seed,
427
+ outputs=[seed]
428
+ )
429
+
430
+ generate_btn.click(
431
+ update_seed,
432
+ outputs=[seed]
433
+ )
434
 
435
+ if __name__ == "__main__":
436
+ demo.launch(allowed_paths=[PERSISTENT_DIR])