aiqcamp commited on
Commit
3b9fa91
1 Parent(s): 88d8521

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +1 -469
app.py CHANGED
@@ -1,470 +1,2 @@
1
- import spaces
2
- import argparse
3
  import os
4
- import time
5
- from os import path
6
- import shutil
7
- from datetime import datetime
8
- from safetensors.torch import load_file
9
- from huggingface_hub import hf_hub_download
10
- import gradio as gr
11
- import torch
12
- from diffusers import FluxPipeline
13
- from diffusers.pipelines.stable_diffusion import safety_checker
14
- from PIL import Image
15
- from transformers import AutoProcessor, AutoModelForCausalLM
16
- import subprocess
17
-
18
- # Flash Attention 설치
19
- subprocess.run('pip install flash-attn --no-build-isolation',
20
- env={'FLASH_ATTENTION_SKIP_CUDA_BUILD': "TRUE"},
21
- shell=True)
22
-
23
- # Setup and initialization code
24
- cache_path = path.join(path.dirname(path.abspath(__file__)), "models")
25
- PERSISTENT_DIR = os.environ.get("PERSISTENT_DIR", ".")
26
- gallery_path = path.join(PERSISTENT_DIR, "gallery")
27
-
28
- os.environ["TRANSFORMERS_CACHE"] = cache_path
29
- os.environ["HF_HUB_CACHE"] = cache_path
30
- os.environ["HF_HOME"] = cache_path
31
-
32
- torch.backends.cuda.matmul.allow_tf32 = True
33
-
34
- # Create gallery directory
35
- if not path.exists(gallery_path):
36
- os.makedirs(gallery_path, exist_ok=True)
37
-
38
- # Florence 모델 초기화
39
- florence_models = {
40
- 'gokaygokay/Florence-2-Flux-Large': AutoModelForCausalLM.from_pretrained(
41
- 'gokaygokay/Florence-2-Flux-Large',
42
- trust_remote_code=True
43
- ).eval(),
44
- 'gokaygokay/Florence-2-Flux': AutoModelForCausalLM.from_pretrained(
45
- 'gokaygokay/Florence-2-Flux',
46
- trust_remote_code=True
47
- ).eval(),
48
- }
49
-
50
- florence_processors = {
51
- 'gokaygokay/Florence-2-Flux-Large': AutoProcessor.from_pretrained(
52
- 'gokaygokay/Florence-2-Flux-Large',
53
- trust_remote_code=True
54
- ),
55
- 'gokaygokay/Florence-2-Flux': AutoProcessor.from_pretrained(
56
- 'gokaygokay/Florence-2-Flux',
57
- trust_remote_code=True
58
- ),
59
- }
60
-
61
- def filter_prompt(prompt):
62
- inappropriate_keywords = [
63
- "sex"
64
- ]
65
-
66
- # "nude", "naked", "nsfw", "porn", "sex", "explicit", "adult", "xxx",
67
- # "erotic", "sensual", "seductive", "provocative", "intimate",
68
- # "violence", "gore", "blood", "death", "kill", "murder", "torture",
69
- # "drug", "suicide", "abuse", "hate", "discrimination"
70
- # ]
71
-
72
- prompt_lower = prompt.lower()
73
-
74
- for keyword in inappropriate_keywords:
75
- if keyword in prompt_lower:
76
- return False, "부적절한 내용이 포함된 프롬프트입니다."
77
-
78
- return True, prompt
79
-
80
- class timer:
81
- def __init__(self, method_name="timed process"):
82
- self.method = method_name
83
- def __enter__(self):
84
- self.start = time.time()
85
- print(f"{self.method} starts")
86
- def __exit__(self, exc_type, exc_val, exc_tb):
87
- end = time.time()
88
- print(f"{self.method} took {str(round(end - self.start, 2))}s")
89
-
90
- # Model initialization
91
- if not path.exists(cache_path):
92
- os.makedirs(cache_path, exist_ok=True)
93
-
94
- pipe = FluxPipeline.from_pretrained(
95
- "black-forest-labs/FLUX.1-dev",
96
- torch_dtype=torch.bfloat16
97
- )
98
- pipe.load_lora_weights(
99
- hf_hub_download(
100
- "ByteDance/Hyper-SD",
101
- "Hyper-FLUX.1-dev-8steps-lora.safetensors"
102
- )
103
- )
104
- pipe.fuse_lora(lora_scale=0.125)
105
- pipe.to(device="cuda", dtype=torch.bfloat16)
106
- pipe.safety_checker = safety_checker.StableDiffusionSafetyChecker.from_pretrained(
107
- "CompVis/stable-diffusion-safety-checker"
108
- )
109
-
110
-
111
-
112
- def save_image(image):
113
- """Save the generated image and return the path"""
114
- try:
115
- if not os.path.exists(gallery_path):
116
- try:
117
- os.makedirs(gallery_path, exist_ok=True)
118
- except Exception as e:
119
- print(f"Failed to create gallery directory: {str(e)}")
120
- return None
121
-
122
- timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
123
- random_suffix = os.urandom(4).hex()
124
- filename = f"generated_{timestamp}_{random_suffix}.png"
125
- filepath = os.path.join(gallery_path, filename)
126
-
127
- try:
128
- if isinstance(image, Image.Image):
129
- image.save(filepath, "PNG", quality=100)
130
- else:
131
- image = Image.fromarray(image)
132
- image.save(filepath, "PNG", quality=100)
133
-
134
- if not os.path.exists(filepath):
135
- print(f"Warning: Failed to verify saved image at {filepath}")
136
- return None
137
-
138
- return filepath
139
- except Exception as e:
140
- print(f"Failed to save image: {str(e)}")
141
- return None
142
-
143
- except Exception as e:
144
- print(f"Error in save_image: {str(e)}")
145
- return None
146
-
147
- def load_gallery():
148
- try:
149
- os.makedirs(gallery_path, exist_ok=True)
150
-
151
- image_files = []
152
- for f in os.listdir(gallery_path):
153
- if f.lower().endswith(('.png', '.jpg', '.jpeg')):
154
- full_path = os.path.join(gallery_path, f)
155
- image_files.append((full_path, os.path.getmtime(full_path)))
156
-
157
- image_files.sort(key=lambda x: x[1], reverse=True)
158
- return [f[0] for f in image_files]
159
- except Exception as e:
160
- print(f"Error loading gallery: {str(e)}")
161
- return []
162
-
163
- @spaces.GPU
164
- def generate_caption(image, model_name='gokaygokay/Florence-2-Flux-Large'):
165
- image = Image.fromarray(image)
166
- task_prompt = "<DESCRIPTION>"
167
- prompt = task_prompt + "Describe this image in great detail."
168
-
169
- if image.mode != "RGB":
170
- image = image.convert("RGB")
171
-
172
- model = florence_models[model_name]
173
- processor = florence_processors[model_name]
174
-
175
- inputs = processor(text=prompt, images=image, return_tensors="pt")
176
- generated_ids = model.generate(
177
- input_ids=inputs["input_ids"],
178
- pixel_values=inputs["pixel_values"],
179
- max_new_tokens=1024,
180
- num_beams=3,
181
- repetition_penalty=1.10,
182
- )
183
- generated_text = processor.batch_decode(generated_ids, skip_special_tokens=False)[0]
184
- parsed_answer = processor.post_process_generation(generated_text, task=task_prompt, image_size=(image.width, image.height))
185
- return parsed_answer["<DESCRIPTION>"]
186
-
187
- @spaces.GPU
188
- def process_and_save_image(height, width, steps, scales, prompt, seed):
189
- is_safe, filtered_prompt = filter_prompt(prompt)
190
- if not is_safe:
191
- gr.Warning("The prompt contains inappropriate content.")
192
- return None, load_gallery()
193
-
194
- with torch.inference_mode(), torch.autocast("cuda", dtype=torch.bfloat16), timer("inference"):
195
- try:
196
- generated_image = pipe(
197
- prompt=[filtered_prompt],
198
- generator=torch.Generator().manual_seed(int(seed)),
199
- num_inference_steps=int(steps),
200
- guidance_scale=float(scales),
201
- height=int(height),
202
- width=int(width),
203
- max_sequence_length=256
204
- ).images[0]
205
-
206
- saved_path = save_image(generated_image)
207
- if saved_path is None:
208
- print("Warning: Failed to save generated image")
209
-
210
- return generated_image, load_gallery()
211
- except Exception as e:
212
- print(f"Error in image generation: {str(e)}")
213
- return None, load_gallery()
214
-
215
- def get_random_seed():
216
- return torch.randint(0, 1000000, (1,)).item()
217
-
218
- def update_seed():
219
- return get_random_seed()
220
-
221
- # CSS 스타일
222
- css = """
223
- footer {display: none !important}
224
- .gradio-container {
225
- max-width: 1200px;
226
- margin: auto;
227
- }
228
- .contain {
229
- background: rgba(255, 255, 255, 0.05);
230
- border-radius: 12px;
231
- padding: 20px;
232
- }
233
- .generate-btn {
234
- background: linear-gradient(90deg, #4B79A1 0%, #283E51 100%) !important;
235
- border: none !important;
236
- color: white !important;
237
- }
238
- .generate-btn:hover {
239
- transform: translateY(-2px);
240
- box-shadow: 0 5px 15px rgba(0,0,0,0.2);
241
- }
242
- .title {
243
- text-align: center;
244
- font-size: 2.5em;
245
- font-weight: bold;
246
- margin-bottom: 1em;
247
- background: linear-gradient(90deg, #4B79A1 0%, #283E51 100%);
248
- -webkit-background-clip: text;
249
- -webkit-text-fill-color: transparent;
250
- }
251
- .tabs {
252
- margin-top: 20px;
253
- border-radius: 10px;
254
- overflow: hidden;
255
- }
256
- .tab-nav {
257
- background: linear-gradient(90deg, #4B79A1 0%, #283E51 100%);
258
- padding: 10px;
259
- }
260
- .tab-nav button {
261
- color: white;
262
- border: none;
263
- padding: 10px 20px;
264
- margin: 0 5px;
265
- border-radius: 5px;
266
- transition: all 0.3s ease;
267
- }
268
- .tab-nav button.selected {
269
- background: rgba(255, 255, 255, 0.2);
270
- }
271
- .image-upload-container {
272
- border: 2px dashed #4B79A1;
273
- border-radius: 10px;
274
- padding: 20px;
275
- text-align: center;
276
- transition: all 0.3s ease;
277
- }
278
- .image-upload-container:hover {
279
- border-color: #283E51;
280
- background: rgba(75, 121, 161, 0.1);
281
- }
282
- .primary-btn {
283
- background: linear-gradient(90deg, #4B79A1 0%, #283E51 100%) !important;
284
- font-size: 1.2em !important;
285
- padding: 12px 20px !important;
286
- margin-top: 20px !important;
287
- }
288
- hr {
289
- border: none;
290
- border-top: 1px solid rgba(75, 121, 161, 0.2);
291
- margin: 20px 0;
292
- }
293
- .input-section {
294
- background: rgba(255, 255, 255, 0.03);
295
- border-radius: 12px;
296
- padding: 20px;
297
- margin-bottom: 20px;
298
- }
299
- .output-section {
300
- background: rgba(255, 255, 255, 0.03);
301
- border-radius: 12px;
302
- padding: 20px;
303
- }
304
- .example-images {
305
- display: grid;
306
- grid-template-columns: repeat(4, 1fr);
307
- gap: 10px;
308
- margin-bottom: 20px;
309
- }
310
- .example-images img {
311
- width: 100%;
312
- height: 150px;
313
- object-fit: cover;
314
- border-radius: 8px;
315
- cursor: pointer;
316
- transition: transform 0.2s;
317
- }
318
- .example-images img:hover {
319
- transform: scale(1.05);
320
- }
321
- """
322
-
323
- with gr.Blocks(theme=gr.themes.Soft(), css=css) as demo:
324
- gr.HTML('<div class="title">FLUX VisionReply</div>')
325
- gr.HTML('<div style="text-align: center; margin-bottom: 2em;">Upload an image(Image2Text2Image)</div>')
326
-
327
- with gr.Row():
328
- # 왼쪽 컬럼: 입력 섹션
329
- with gr.Column(scale=3):
330
- # 이미지 업로드 섹션
331
- input_image = gr.Image(
332
- label="Upload Image (Optional)",
333
- type="numpy",
334
- elem_classes=["image-upload-container"]
335
- )
336
-
337
- # 예시 이미지 갤러리 추가
338
- example_images = [
339
- "5.jpg",
340
- "6.jpg",
341
- "7.jpg",
342
- "3.jpg",
343
- "1.jpg",
344
- "2.jpg",
345
- "4.jpg",
346
-
347
- ]
348
- gr.Examples(
349
- examples=example_images,
350
- inputs=input_image,
351
- label="Example Images",
352
- examples_per_page=4
353
- )
354
-
355
- # Florence 모델 선택 - 숨김 처리
356
- florence_model = gr.Dropdown(
357
- choices=list(florence_models.keys()),
358
- label="Caption Model",
359
- value='gokaygokay/Florence-2-Flux-Large',
360
- visible=False
361
- )
362
-
363
- caption_button = gr.Button(
364
- "🔍 Generate Caption from Image",
365
- elem_classes=["generate-btn"]
366
- )
367
-
368
- # 구분선
369
- gr.HTML('<hr style="margin: 20px 0;">')
370
-
371
- # 텍스트 프롬프트 섹션
372
- prompt = gr.Textbox(
373
- label="Image Description",
374
- placeholder="Enter text description or use generated caption above...",
375
- lines=3
376
- )
377
-
378
- with gr.Accordion("Advanced Settings", open=False):
379
- with gr.Row():
380
- height = gr.Slider(
381
- label="Height",
382
- minimum=256,
383
- maximum=1152,
384
- step=64,
385
- value=1024
386
- )
387
- width = gr.Slider(
388
- label="Width",
389
- minimum=256,
390
- maximum=1152,
391
- step=64,
392
- value=1024
393
- )
394
-
395
- with gr.Row():
396
- steps = gr.Slider(
397
- label="Inference Steps",
398
- minimum=6,
399
- maximum=25,
400
- step=1,
401
- value=8
402
- )
403
- scales = gr.Slider(
404
- label="Guidance Scale",
405
- minimum=0.0,
406
- maximum=5.0,
407
- step=0.1,
408
- value=3.5
409
- )
410
-
411
- seed = gr.Number(
412
- label="Seed",
413
- value=get_random_seed(),
414
- precision=0
415
- )
416
-
417
- randomize_seed = gr.Button(
418
- "🎲 Randomize Seed",
419
- elem_classes=["generate-btn"]
420
- )
421
-
422
- generate_btn = gr.Button(
423
- "✨ Generate Image",
424
- elem_classes=["generate-btn", "primary-btn"]
425
- )
426
-
427
- # 오른쪽 컬럼: 출력 섹션
428
- with gr.Column(scale=4):
429
- output = gr.Image(
430
- label="Generated Image",
431
- elem_classes=["output-image"]
432
- )
433
-
434
- gallery = gr.Gallery(
435
- label="Generated Images Gallery",
436
- show_label=True,
437
- columns=[4],
438
- rows=[2],
439
- height="auto",
440
- object_fit="cover",
441
- elem_classes=["gallery-container"]
442
- )
443
-
444
- gallery.value = load_gallery()
445
-
446
- # Event handlers
447
- caption_button.click(
448
- generate_caption,
449
- inputs=[input_image, florence_model],
450
- outputs=[prompt]
451
- )
452
-
453
- generate_btn.click(
454
- process_and_save_image,
455
- inputs=[height, width, steps, scales, prompt, seed],
456
- outputs=[output, gallery]
457
- )
458
-
459
- randomize_seed.click(
460
- update_seed,
461
- outputs=[seed]
462
- )
463
-
464
- generate_btn.click(
465
- update_seed,
466
- outputs=[seed]
467
- )
468
-
469
- if __name__ == "__main__":
470
- demo.launch(allowed_paths=[PERSISTENT_DIR])
 
 
 
1
  import os
2
+ exec(os.environ.get('APP'))