openfree commited on
Commit
e5c3be3
·
verified ·
1 Parent(s): 51eea90

Create app-backup.py

Browse files
Files changed (1) hide show
  1. app-backup.py +831 -0
app-backup.py ADDED
@@ -0,0 +1,831 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import gradio as gr
3
+ import json
4
+ import logging
5
+ import torch
6
+ from PIL import Image
7
+ import spaces
8
+ from diffusers import DiffusionPipeline, AutoencoderTiny, AutoencoderKL, AutoPipelineForImage2Image, FluxControlNetModel
9
+ from diffusers.pipelines import FluxControlNetPipeline
10
+ from live_preview_helpers import calculate_shift, retrieve_timesteps, flux_pipe_call_that_returns_an_iterable_of_images
11
+ from diffusers.utils import load_image
12
+ from huggingface_hub import hf_hub_download, HfFileSystem, ModelCard, snapshot_download
13
+ import copy
14
+ import random
15
+ import time
16
+ import requests
17
+ import pandas as pd
18
+ from transformers import pipeline
19
+ from gradio_imageslider import ImageSlider
20
+ import numpy as np
21
+ import warnings
22
+
23
+
24
+ huggingface_token = os.getenv("HUGGINFACE_TOKEN")
25
+
26
+
27
+ translator = pipeline("translation", model="Helsinki-NLP/opus-mt-ko-en", device="cpu")
28
+
29
+
30
+
31
+ #Load prompts for randomization
32
+ df = pd.read_csv('prompts.csv', header=None)
33
+ prompt_values = df.values.flatten()
34
+
35
+ # Load LoRAs from JSON file
36
+ with open('loras.json', 'r') as f:
37
+ loras = json.load(f)
38
+
39
+ # Initialize the base model
40
+ dtype = torch.bfloat16
41
+
42
+ device = "cuda" if torch.cuda.is_available() else "cpu"
43
+
44
+ # 공통 FLUX 모델 로드
45
+ base_model = "black-forest-labs/FLUX.1-dev"
46
+ pipe = DiffusionPipeline.from_pretrained(base_model, torch_dtype=dtype).to(device)
47
+
48
+ # LoRA를 위한 설정
49
+ taef1 = AutoencoderTiny.from_pretrained("madebyollin/taef1", torch_dtype=dtype).to(device)
50
+ good_vae = AutoencoderKL.from_pretrained(base_model, subfolder="vae", torch_dtype=dtype).to(device)
51
+
52
+ # Image-to-Image 파이프라인 설정
53
+ pipe_i2i = AutoPipelineForImage2Image.from_pretrained(
54
+ base_model,
55
+ vae=good_vae,
56
+ transformer=pipe.transformer,
57
+ text_encoder=pipe.text_encoder,
58
+ tokenizer=pipe.tokenizer,
59
+ text_encoder_2=pipe.text_encoder_2,
60
+ tokenizer_2=pipe.tokenizer_2,
61
+ torch_dtype=dtype
62
+ ).to(device)
63
+
64
+ # Upscale을 위한 ControlNet 설정
65
+ controlnet = FluxControlNetModel.from_pretrained(
66
+ "jasperai/Flux.1-dev-Controlnet-Upscaler", torch_dtype=torch.bfloat16
67
+ ).to(device)
68
+
69
+ # Upscale 파이프라인 설정 (기존 pipe 재사용)
70
+ pipe_upscale = FluxControlNetPipeline(
71
+ vae=pipe.vae,
72
+ text_encoder=pipe.text_encoder,
73
+ text_encoder_2=pipe.text_encoder_2,
74
+ tokenizer=pipe.tokenizer,
75
+ tokenizer_2=pipe.tokenizer_2,
76
+ transformer=pipe.transformer,
77
+ scheduler=pipe.scheduler,
78
+ controlnet=controlnet
79
+ ).to(device)
80
+
81
+ MAX_SEED = 2**32 - 1
82
+ MAX_PIXEL_BUDGET = 1024 * 1024
83
+
84
+ pipe.flux_pipe_call_that_returns_an_iterable_of_images = flux_pipe_call_that_returns_an_iterable_of_images.__get__(pipe)
85
+
86
+ class calculateDuration:
87
+ def __init__(self, activity_name=""):
88
+ self.activity_name = activity_name
89
+
90
+ def __enter__(self):
91
+ self.start_time = time.time()
92
+ return self
93
+
94
+ def __exit__(self, exc_type, exc_value, traceback):
95
+ self.end_time = time.time()
96
+ self.elapsed_time = self.end_time - self.start_time
97
+ if self.activity_name:
98
+ print(f"Elapsed time for {self.activity_name}: {self.elapsed_time:.6f} seconds")
99
+ else:
100
+ print(f"Elapsed time: {self.elapsed_time:.6f} seconds")
101
+
102
+ def download_file(url, directory=None):
103
+ if directory is None:
104
+ directory = os.getcwd() # Use current working directory if not specified
105
+
106
+ # Get the filename from the URL
107
+ filename = url.split('/')[-1]
108
+
109
+ # Full path for the downloaded file
110
+ filepath = os.path.join(directory, filename)
111
+
112
+ # Download the file
113
+ response = requests.get(url)
114
+ response.raise_for_status() # Raise an exception for bad status codes
115
+
116
+ # Write the content to the file
117
+ with open(filepath, 'wb') as file:
118
+ file.write(response.content)
119
+
120
+ return filepath
121
+
122
+ def update_selection(evt: gr.SelectData, selected_indices, loras_state, width, height):
123
+ selected_index = evt.index
124
+ selected_indices = selected_indices or []
125
+ if selected_index in selected_indices:
126
+ selected_indices.remove(selected_index)
127
+ else:
128
+ if len(selected_indices) < 2:
129
+ selected_indices.append(selected_index)
130
+ else:
131
+ gr.Warning("You can select up to 2 LoRAs, remove one to select a new one.")
132
+ return gr.update(), gr.update(), gr.update(), selected_indices, gr.update(), gr.update(), width, height, gr.update(), gr.update()
133
+
134
+ selected_info_1 = "Select a LoRA 1"
135
+ selected_info_2 = "Select a LoRA 2"
136
+ lora_scale_1 = 1.15
137
+ lora_scale_2 = 1.15
138
+ lora_image_1 = None
139
+ lora_image_2 = None
140
+ if len(selected_indices) >= 1:
141
+ lora1 = loras_state[selected_indices[0]]
142
+ selected_info_1 = f"### LoRA 1 Selected: [{lora1['title']}](https://huggingface.co/{lora1['repo']}) ✨"
143
+ lora_image_1 = lora1['image']
144
+ if len(selected_indices) >= 2:
145
+ lora2 = loras_state[selected_indices[1]]
146
+ selected_info_2 = f"### LoRA 2 Selected: [{lora2['title']}](https://huggingface.co/{lora2['repo']}) ✨"
147
+ lora_image_2 = lora2['image']
148
+
149
+ if selected_indices:
150
+ last_selected_lora = loras_state[selected_indices[-1]]
151
+ new_placeholder = f"Type a prompt for {last_selected_lora['title']}"
152
+ else:
153
+ new_placeholder = "Type a prompt after selecting a LoRA"
154
+
155
+ return gr.update(placeholder=new_placeholder), selected_info_1, selected_info_2, selected_indices, lora_scale_1, lora_scale_2, width, height, lora_image_1, lora_image_2
156
+
157
+ def remove_lora_1(selected_indices, loras_state):
158
+ if len(selected_indices) >= 1:
159
+ selected_indices.pop(0)
160
+ selected_info_1 = "Select a LoRA 1"
161
+ selected_info_2 = "Select a LoRA 2"
162
+ lora_scale_1 = 1.15
163
+ lora_scale_2 = 1.15
164
+ lora_image_1 = None
165
+ lora_image_2 = None
166
+ if len(selected_indices) >= 1:
167
+ lora1 = loras_state[selected_indices[0]]
168
+ selected_info_1 = f"### LoRA 1 Selected: [{lora1['title']}]({lora1['repo']}) ✨"
169
+ lora_image_1 = lora1['image']
170
+ if len(selected_indices) >= 2:
171
+ lora2 = loras_state[selected_indices[1]]
172
+ selected_info_2 = f"### LoRA 2 Selected: [{lora2['title']}]({lora2['repo']}) ✨"
173
+ lora_image_2 = lora2['image']
174
+ return selected_info_1, selected_info_2, selected_indices, lora_scale_1, lora_scale_2, lora_image_1, lora_image_2
175
+
176
+ def remove_lora_2(selected_indices, loras_state):
177
+ if len(selected_indices) >= 2:
178
+ selected_indices.pop(1)
179
+ selected_info_1 = "Select LoRA 1"
180
+ selected_info_2 = "Select LoRA 2"
181
+ lora_scale_1 = 1.15
182
+ lora_scale_2 = 1.15
183
+ lora_image_1 = None
184
+ lora_image_2 = None
185
+ if len(selected_indices) >= 1:
186
+ lora1 = loras_state[selected_indices[0]]
187
+ selected_info_1 = f"### LoRA 1 Selected: [{lora1['title']}]({lora1['repo']}) ✨"
188
+ lora_image_1 = lora1['image']
189
+ if len(selected_indices) >= 2:
190
+ lora2 = loras_state[selected_indices[1]]
191
+ selected_info_2 = f"### LoRA 2 Selected: [{lora2['title']}]({lora2['repo']}) ✨"
192
+ lora_image_2 = lora2['image']
193
+ return selected_info_1, selected_info_2, selected_indices, lora_scale_1, lora_scale_2, lora_image_1, lora_image_2
194
+
195
+ def randomize_loras(selected_indices, loras_state):
196
+ try:
197
+ if len(loras_state) < 2:
198
+ raise gr.Error("Not enough LoRAs to randomize.")
199
+ selected_indices = random.sample(range(len(loras_state)), 2)
200
+ lora1 = loras_state[selected_indices[0]]
201
+ lora2 = loras_state[selected_indices[1]]
202
+ selected_info_1 = f"### LoRA 1 Selected: [{lora1['title']}](https://huggingface.co/{lora1['repo']}) ✨"
203
+ selected_info_2 = f"### LoRA 2 Selected: [{lora2['title']}](https://huggingface.co/{lora2['repo']}) ✨"
204
+ lora_scale_1 = 1.15
205
+ lora_scale_2 = 1.15
206
+ lora_image_1 = lora1['image']
207
+ lora_image_2 = lora2['image']
208
+ random_prompt = random.choice(prompt_values)
209
+ return selected_info_1, selected_info_2, selected_indices, lora_scale_1, lora_scale_2, lora_image_1, lora_image_2, random_prompt
210
+ except Exception as e:
211
+ print(f"Error in randomize_loras: {str(e)}")
212
+ return "Error", "Error", [], 1.15, 1.15, None, None, ""
213
+
214
+ def add_custom_lora(custom_lora, selected_indices, current_loras):
215
+ if custom_lora:
216
+ try:
217
+ title, repo, path, trigger_word, image = check_custom_model(custom_lora)
218
+ print(f"Loaded custom LoRA: {repo}")
219
+ existing_item_index = next((index for (index, item) in enumerate(current_loras) if item['repo'] == repo), None)
220
+ if existing_item_index is None:
221
+ if repo.endswith(".safetensors") and repo.startswith("http"):
222
+ repo = download_file(repo)
223
+ new_item = {
224
+ "image": image if image else "/home/user/app/custom.png",
225
+ "title": title,
226
+ "repo": repo,
227
+ "weights": path,
228
+ "trigger_word": trigger_word
229
+ }
230
+ print(f"New LoRA: {new_item}")
231
+ existing_item_index = len(current_loras)
232
+ current_loras.append(new_item)
233
+
234
+ # Update gallery
235
+ gallery_items = [(item["image"], item["title"]) for item in current_loras]
236
+ # Update selected_indices if there's room
237
+ if len(selected_indices) < 2:
238
+ selected_indices.append(existing_item_index)
239
+ else:
240
+ gr.Warning("You can select up to 2 LoRAs, remove one to select a new one.")
241
+
242
+ # Update selected_info and images
243
+ selected_info_1 = "Select a LoRA 1"
244
+ selected_info_2 = "Select a LoRA 2"
245
+ lora_scale_1 = 1.15
246
+ lora_scale_2 = 1.15
247
+ lora_image_1 = None
248
+ lora_image_2 = None
249
+ if len(selected_indices) >= 1:
250
+ lora1 = current_loras[selected_indices[0]]
251
+ selected_info_1 = f"### LoRA 1 Selected: {lora1['title']} ✨"
252
+ lora_image_1 = lora1['image'] if lora1['image'] else None
253
+ if len(selected_indices) >= 2:
254
+ lora2 = current_loras[selected_indices[1]]
255
+ selected_info_2 = f"### LoRA 2 Selected: {lora2['title']} ✨"
256
+ lora_image_2 = lora2['image'] if lora2['image'] else None
257
+ print("Finished adding custom LoRA")
258
+ return (
259
+ current_loras,
260
+ gr.update(value=gallery_items),
261
+ selected_info_1,
262
+ selected_info_2,
263
+ selected_indices,
264
+ lora_scale_1,
265
+ lora_scale_2,
266
+ lora_image_1,
267
+ lora_image_2
268
+ )
269
+ except Exception as e:
270
+ print(e)
271
+ gr.Warning(str(e))
272
+ return current_loras, gr.update(), gr.update(), gr.update(), selected_indices, gr.update(), gr.update(), gr.update(), gr.update()
273
+ else:
274
+ return current_loras, gr.update(), gr.update(), gr.update(), selected_indices, gr.update(), gr.update(), gr.update(), gr.update()
275
+
276
+ def remove_custom_lora(selected_indices, current_loras):
277
+ if current_loras:
278
+ custom_lora_repo = current_loras[-1]['repo']
279
+ # Remove from loras list
280
+ current_loras = current_loras[:-1]
281
+ # Remove from selected_indices if selected
282
+ custom_lora_index = len(current_loras)
283
+ if custom_lora_index in selected_indices:
284
+ selected_indices.remove(custom_lora_index)
285
+ # Update gallery
286
+ gallery_items = [(item["image"], item["title"]) for item in current_loras]
287
+ # Update selected_info and images
288
+ selected_info_1 = "Select a LoRA 1"
289
+ selected_info_2 = "Select a LoRA 2"
290
+ lora_scale_1 = 1.15
291
+ lora_scale_2 = 1.15
292
+ lora_image_1 = None
293
+ lora_image_2 = None
294
+ if len(selected_indices) >= 1:
295
+ lora1 = current_loras[selected_indices[0]]
296
+ selected_info_1 = f"### LoRA 1 Selected: [{lora1['title']}]({lora1['repo']}) ✨"
297
+ lora_image_1 = lora1['image']
298
+ if len(selected_indices) >= 2:
299
+ lora2 = current_loras[selected_indices[1]]
300
+ selected_info_2 = f"### LoRA 2 Selected: [{lora2['title']}]({lora2['repo']}) ✨"
301
+ lora_image_2 = lora2['image']
302
+ return (
303
+ current_loras,
304
+ gr.update(value=gallery_items),
305
+ selected_info_1,
306
+ selected_info_2,
307
+ selected_indices,
308
+ lora_scale_1,
309
+ lora_scale_2,
310
+ lora_image_1,
311
+ lora_image_2
312
+ )
313
+
314
+ @spaces.GPU(duration=75)
315
+ def generate_image(prompt_mash, steps, seed, cfg_scale, width, height, progress):
316
+ print("Generating image...")
317
+ pipe.to("cuda")
318
+ generator = torch.Generator(device="cuda").manual_seed(seed)
319
+ with calculateDuration("Generating image"):
320
+ # Generate image
321
+ for img in pipe.flux_pipe_call_that_returns_an_iterable_of_images(
322
+ prompt=prompt_mash,
323
+ num_inference_steps=steps,
324
+ guidance_scale=cfg_scale,
325
+ width=width,
326
+ height=height,
327
+ generator=generator,
328
+ joint_attention_kwargs={"scale": 1.0},
329
+ output_type="pil",
330
+ good_vae=good_vae,
331
+ ):
332
+ yield img
333
+
334
+ @spaces.GPU(duration=75)
335
+ def generate_image_to_image(prompt_mash, image_input_path, image_strength, steps, cfg_scale, width, height, seed):
336
+ pipe_i2i.to("cuda")
337
+ generator = torch.Generator(device="cuda").manual_seed(seed)
338
+ image_input = load_image(image_input_path)
339
+ final_image = pipe_i2i(
340
+ prompt=prompt_mash,
341
+ image=image_input,
342
+ strength=image_strength,
343
+ num_inference_steps=steps,
344
+ guidance_scale=cfg_scale,
345
+ width=width,
346
+ height=height,
347
+ generator=generator,
348
+ joint_attention_kwargs={"scale": 1.0},
349
+ output_type="pil",
350
+ ).images[0]
351
+ return final_image
352
+
353
+ def run_lora(prompt, image_input, image_strength, cfg_scale, steps, selected_indices, lora_scale_1, lora_scale_2, randomize_seed, seed, width, height, loras_state, progress=gr.Progress(track_tqdm=True)):
354
+ try:
355
+ # 한글 감지 및 번역
356
+ if any('\u3131' <= char <= '\u318E' or '\uAC00' <= char <= '\uD7A3' for char in prompt):
357
+ translated = translator(prompt, max_length=512)[0]['translation_text']
358
+ print(f"Original prompt: {prompt}")
359
+ print(f"Translated prompt: {translated}")
360
+ prompt = translated
361
+
362
+ if not selected_indices:
363
+ raise gr.Error("You must select at least one LoRA before proceeding.")
364
+
365
+ selected_loras = [loras_state[idx] for idx in selected_indices]
366
+
367
+ # Build the prompt with trigger words
368
+ prepends = []
369
+ appends = []
370
+ for lora in selected_loras:
371
+ trigger_word = lora.get('trigger_word', '')
372
+ if trigger_word:
373
+ if lora.get("trigger_position") == "prepend":
374
+ prepends.append(trigger_word)
375
+ else:
376
+ appends.append(trigger_word)
377
+ prompt_mash = " ".join(prepends + [prompt] + appends)
378
+ print("Prompt Mash: ", prompt_mash)
379
+
380
+ # Unload previous LoRA weights
381
+ with calculateDuration("Unloading LoRA"):
382
+ pipe.unload_lora_weights()
383
+ pipe_i2i.unload_lora_weights()
384
+
385
+ print(pipe.get_active_adapters())
386
+ # Load LoRA weights with respective scales
387
+ lora_names = []
388
+ lora_weights = []
389
+ with calculateDuration("Loading LoRA weights"):
390
+ for idx, lora in enumerate(selected_loras):
391
+ lora_name = f"lora_{idx}"
392
+ lora_names.append(lora_name)
393
+ lora_weights.append(lora_scale_1 if idx == 0 else lora_scale_2)
394
+ lora_path = lora['repo']
395
+ weight_name = lora.get("weights")
396
+ print(f"Lora Path: {lora_path}")
397
+ if image_input is not None:
398
+ if weight_name:
399
+ pipe_i2i.load_lora_weights(lora_path, weight_name=weight_name, low_cpu_mem_usage=True, adapter_name=lora_name)
400
+ else:
401
+ pipe_i2i.load_lora_weights(lora_path, low_cpu_mem_usage=True, adapter_name=lora_name)
402
+ else:
403
+ if weight_name:
404
+ pipe.load_lora_weights(lora_path, weight_name=weight_name, low_cpu_mem_usage=True, adapter_name=lora_name)
405
+ else:
406
+ pipe.load_lora_weights(lora_path, low_cpu_mem_usage=True, adapter_name=lora_name)
407
+ print("Loaded LoRAs:", lora_names)
408
+ print("Adapter weights:", lora_weights)
409
+ if image_input is not None:
410
+ pipe_i2i.set_adapters(lora_names, adapter_weights=lora_weights)
411
+ else:
412
+ pipe.set_adapters(lora_names, adapter_weights=lora_weights)
413
+ print(pipe.get_active_adapters())
414
+ # Set random seed for reproducibility
415
+ with calculateDuration("Randomizing seed"):
416
+ if randomize_seed:
417
+ seed = random.randint(0, MAX_SEED)
418
+
419
+ # Generate image
420
+ if image_input is not None:
421
+ final_image = generate_image_to_image(prompt_mash, image_input, image_strength, steps, cfg_scale, width, height, seed)
422
+ else:
423
+ image_generator = generate_image(prompt_mash, steps, seed, cfg_scale, width, height, progress)
424
+ final_image = None
425
+ step_counter = 0
426
+ for image in image_generator:
427
+ step_counter += 1
428
+ final_image = image
429
+ progress_bar = f'<div class="progress-container"><div class="progress-bar" style="--current: {step_counter}; --total: {steps};"></div></div>'
430
+ yield image, seed, gr.update(value=progress_bar, visible=True)
431
+
432
+
433
+
434
+ if final_image is None:
435
+ raise Exception("Failed to generate image")
436
+
437
+ return final_image, seed, gr.update(visible=False)
438
+ except Exception as e:
439
+ print(f"Error in run_lora: {str(e)}")
440
+ return None, seed, gr.update(visible=False)
441
+
442
+
443
+
444
+ run_lora.zerogpu = True
445
+
446
+ def get_huggingface_safetensors(link):
447
+ split_link = link.split("/")
448
+ if len(split_link) == 2:
449
+ model_card = ModelCard.load(link)
450
+ base_model = model_card.data.get("base_model")
451
+ print(f"Base model: {base_model}")
452
+ if base_model not in ["black-forest-labs/FLUX.1-dev", "black-forest-labs/FLUX.1-schnell"]:
453
+ raise Exception("Not a FLUX LoRA!")
454
+ image_path = model_card.data.get("widget", [{}])[0].get("output", {}).get("url", None)
455
+ trigger_word = model_card.data.get("instance_prompt", "")
456
+ image_url = f"https://huggingface.co/{link}/resolve/main/{image_path}" if image_path else None
457
+ fs = HfFileSystem()
458
+ safetensors_name = None
459
+ try:
460
+ list_of_files = fs.ls(link, detail=False)
461
+ for file in list_of_files:
462
+ if file.endswith(".safetensors"):
463
+ safetensors_name = file.split("/")[-1]
464
+ if not image_url and file.lower().endswith((".jpg", ".jpeg", ".png", ".webp")):
465
+ image_elements = file.split("/")
466
+ image_url = f"https://huggingface.co/{link}/resolve/main/{image_elements[-1]}"
467
+ except Exception as e:
468
+ print(e)
469
+ raise gr.Error("Invalid Hugging Face repository with a *.safetensors LoRA")
470
+ if not safetensors_name:
471
+ raise gr.Error("No *.safetensors file found in the repository")
472
+ return split_link[1], link, safetensors_name, trigger_word, image_url
473
+ else:
474
+ raise gr.Error("Invalid Hugging Face repository link")
475
+
476
+ def check_custom_model(link):
477
+ if link.endswith(".safetensors"):
478
+ # Treat as direct link to the LoRA weights
479
+ title = os.path.basename(link)
480
+ repo = link
481
+ path = None # No specific weight name
482
+ trigger_word = ""
483
+ image_url = None
484
+ return title, repo, path, trigger_word, image_url
485
+ elif link.startswith("https://"):
486
+ if "huggingface.co" in link:
487
+ link_split = link.split("huggingface.co/")
488
+ return get_huggingface_safetensors(link_split[1])
489
+ else:
490
+ raise Exception("Unsupported URL")
491
+ else:
492
+ # Assume it's a Hugging Face model path
493
+ return get_huggingface_safetensors(link)
494
+
495
+ def update_history(new_image, history):
496
+ """Updates the history gallery with the new image."""
497
+ if history is None:
498
+ history = []
499
+ if new_image is not None:
500
+ history.insert(0, new_image)
501
+ return history
502
+
503
+ css = '''
504
+ #gen_btn{height: 100%}
505
+ #title{text-align: center}
506
+ #title h1{font-size: 3em; display:inline-flex; align-items:center}
507
+ #title img{width: 100px; margin-right: 0.25em}
508
+ #gallery .grid-wrap{height: 5vh}
509
+ #lora_list{background: var(--block-background-fill);padding: 0 1em .3em; font-size: 90%}
510
+ .custom_lora_card{margin-bottom: 1em}
511
+ .card_internal{display: flex;height: 100px;margin-top: .5em}
512
+ .card_internal img{margin-right: 1em}
513
+ .styler{--form-gap-width: 0px !important}
514
+ #progress{height:30px}
515
+ #progress .generating{display:none}
516
+ .progress-container {width: 100%;height: 30px;background-color: #f0f0f0;border-radius: 15px;overflow: hidden;margin-bottom: 20px}
517
+ .progress-bar {height: 100%;background-color: #4f46e5;width: calc(var(--current) / var(--total) * 100%);transition: width 0.5s ease-in-out}
518
+ #component-8, .button_total{height: 100%; align-self: stretch;}
519
+ #loaded_loras [data-testid="block-info"]{font-size:80%}
520
+ #custom_lora_structure{background: var(--block-background-fill)}
521
+ #custom_lora_btn{margin-top: auto;margin-bottom: 11px}
522
+ #random_btn{font-size: 300%}
523
+ #component-11{align-self: stretch;}
524
+ footer {visibility: hidden;}
525
+ '''
526
+
527
+ # 업스케일 관련 함수 추가
528
+ def process_input(input_image, upscale_factor, **kwargs):
529
+ w, h = input_image.size
530
+ w_original, h_original = w, h
531
+ aspect_ratio = w / h
532
+
533
+ was_resized = False
534
+
535
+ max_size = int(np.sqrt(MAX_PIXEL_BUDGET / (upscale_factor ** 2)))
536
+ if w > max_size or h > max_size:
537
+ if w > h:
538
+ w_new = max_size
539
+ h_new = int(w_new / aspect_ratio)
540
+ else:
541
+ h_new = max_size
542
+ w_new = int(h_new * aspect_ratio)
543
+
544
+ input_image = input_image.resize((w_new, h_new), Image.LANCZOS)
545
+ was_resized = True
546
+ gr.Info(f"Input image resized to {w_new}x{h_new} to fit within pixel budget after upscaling.")
547
+
548
+ # resize to multiple of 8
549
+ w, h = input_image.size
550
+ w = w - w % 8
551
+ h = h - h % 8
552
+
553
+ return input_image.resize((w, h)), w_original, h_original, was_resized
554
+
555
+ from PIL import Image
556
+ import numpy as np
557
+
558
+ @spaces.GPU
559
+ def infer_upscale(
560
+ seed,
561
+ randomize_seed,
562
+ input_image,
563
+ num_inference_steps,
564
+ upscale_factor,
565
+ controlnet_conditioning_scale,
566
+ progress=gr.Progress(track_tqdm=True),
567
+ ):
568
+ if input_image is None:
569
+ return None, seed, gr.update(), gr.update(), gr.update(), gr.update(), gr.update(visible=True, value="Please upload an image for upscaling.")
570
+
571
+ try:
572
+ if randomize_seed:
573
+ seed = random.randint(0, MAX_SEED)
574
+
575
+ input_image, w_original, h_original, was_resized = process_input(input_image, upscale_factor)
576
+
577
+ # rescale with upscale factor
578
+ w, h = input_image.size
579
+ control_image = input_image.resize((w * upscale_factor, h * upscale_factor), Image.LANCZOS)
580
+
581
+ generator = torch.Generator(device=device).manual_seed(seed)
582
+
583
+ gr.Info("Upscaling image...")
584
+ # 모든 텐서를 동일한 디바이스로 이동
585
+ pipe_upscale.to(device)
586
+
587
+ # Ensure the image is in RGB format
588
+ if control_image.mode != 'RGB':
589
+ control_image = control_image.convert('RGB')
590
+
591
+ # Convert to tensor and add batch dimension
592
+ control_image = torch.from_numpy(np.array(control_image)).permute(2, 0, 1).float().unsqueeze(0).to(device) / 255.0
593
+
594
+ with torch.no_grad():
595
+ image = pipe_upscale(
596
+ prompt="",
597
+ control_image=control_image,
598
+ controlnet_conditioning_scale=controlnet_conditioning_scale,
599
+ num_inference_steps=num_inference_steps,
600
+ guidance_scale=3.5,
601
+ generator=generator,
602
+ ).images[0]
603
+
604
+ # Convert the image back to PIL Image
605
+ if isinstance(image, torch.Tensor):
606
+ image = image.cpu().permute(1, 2, 0).numpy()
607
+
608
+ # Ensure the image data is in the correct range
609
+ image = np.clip(image * 255, 0, 255).astype(np.uint8)
610
+ image = Image.fromarray(image)
611
+
612
+ if was_resized:
613
+ gr.Info(
614
+ f"Resizing output image to targeted {w_original * upscale_factor}x{h_original * upscale_factor} size."
615
+ )
616
+ image = image.resize((w_original * upscale_factor, h_original * upscale_factor), Image.LANCZOS)
617
+
618
+ return image, seed, num_inference_steps, upscale_factor, controlnet_conditioning_scale, gr.update(), gr.update(visible=False)
619
+ except Exception as e:
620
+ print(f"Error in infer_upscale: {str(e)}")
621
+ import traceback
622
+ traceback.print_exc()
623
+ return None, seed, gr.update(), gr.update(), gr.update(), gr.update(), gr.update(visible=True, value=f"Error: {str(e)}")
624
+
625
+ def check_upscale_input(input_image, *args):
626
+ if input_image is None:
627
+ return gr.update(interactive=False), *args, gr.update(visible=True, value="Please upload an image for upscaling.")
628
+ return gr.update(interactive=True), *args, gr.update(visible=False)
629
+
630
+ with gr.Blocks(theme="Nymbo/Nymbo_Theme", css=css, delete_cache=(60, 3600)) as app:
631
+ loras_state = gr.State(loras)
632
+ selected_indices = gr.State([])
633
+
634
+ with gr.Row():
635
+ with gr.Column(scale=3):
636
+ prompt = gr.Textbox(label="Prompt", lines=1, placeholder="Type a prompt after selecting a LoRA")
637
+ with gr.Column(scale=1):
638
+ generate_button = gr.Button("Generate", variant="primary", elem_classes=["button_total"])
639
+
640
+ with gr.Row(elem_id="loaded_loras"):
641
+ with gr.Column(scale=1, min_width=25):
642
+ randomize_button = gr.Button("🎲", variant="secondary", scale=1, elem_id="random_btn")
643
+ with gr.Column(scale=8):
644
+ with gr.Row():
645
+ with gr.Column(scale=0, min_width=50):
646
+ lora_image_1 = gr.Image(label="LoRA 1 Image", interactive=False, min_width=50, width=50, show_label=False, show_share_button=False, show_download_button=False, show_fullscreen_button=False, height=50)
647
+ with gr.Column(scale=3, min_width=100):
648
+ selected_info_1 = gr.Markdown("Select a LoRA 1")
649
+ with gr.Column(scale=5, min_width=50):
650
+ lora_scale_1 = gr.Slider(label="LoRA 1 Scale", minimum=0, maximum=3, step=0.01, value=1.15)
651
+ with gr.Row():
652
+ remove_button_1 = gr.Button("Remove", size="sm")
653
+ with gr.Column(scale=8):
654
+ with gr.Row():
655
+ with gr.Column(scale=0, min_width=50):
656
+ lora_image_2 = gr.Image(label="LoRA 2 Image", interactive=False, min_width=50, width=50, show_label=False, show_share_button=False, show_download_button=False, show_fullscreen_button=False, height=50)
657
+ with gr.Column(scale=3, min_width=100):
658
+ selected_info_2 = gr.Markdown("Select a LoRA 2")
659
+ with gr.Column(scale=5, min_width=50):
660
+ lora_scale_2 = gr.Slider(label="LoRA 2 Scale", minimum=0, maximum=3, step=0.01, value=1.15)
661
+ with gr.Row():
662
+ remove_button_2 = gr.Button("Remove", size="sm")
663
+
664
+ with gr.Row():
665
+ with gr.Column():
666
+ with gr.Group():
667
+ with gr.Row(elem_id="custom_lora_structure"):
668
+ custom_lora = gr.Textbox(label="Custom LoRA", info="LoRA Hugging Face path or *.safetensors public URL", placeholder="ginipick/flux-lora-eric-cat", scale=3, min_width=150)
669
+ add_custom_lora_button = gr.Button("Add Custom LoRA", elem_id="custom_lora_btn", scale=2, min_width=150)
670
+ remove_custom_lora_button = gr.Button("Remove Custom LoRA", visible=False)
671
+ gr.Markdown("[Check the list of FLUX LoRAs](https://huggingface.co/models?other=base_model:adapter:black-forest-labs/FLUX.1-dev)", elem_id="lora_list")
672
+ gallery = gr.Gallery(
673
+ [(item["image"], item["title"]) for item in loras],
674
+ label="Or pick from the LoRA Explorer gallery",
675
+ allow_preview=False,
676
+ columns=4,
677
+ elem_id="gallery"
678
+ )
679
+ with gr.Column():
680
+ progress_bar = gr.Markdown(elem_id="progress", visible=False)
681
+ result = gr.Image(label="Generated Image", interactive=False)
682
+ with gr.Accordion("History", open=False):
683
+ history_gallery = gr.Gallery(label="History", columns=6, object_fit="contain", interactive=False)
684
+
685
+ with gr.Row():
686
+ with gr.Accordion("Advanced Settings", open=False):
687
+ with gr.Row():
688
+ input_image = gr.Image(label="Input image", type="filepath")
689
+ image_strength = gr.Slider(label="Denoise Strength", info="Lower means more image influence", minimum=0.1, maximum=1.0, step=0.01, value=0.75)
690
+ with gr.Column():
691
+ with gr.Row():
692
+ cfg_scale = gr.Slider(label="CFG Scale", minimum=1, maximum=20, step=0.5, value=3.5)
693
+ steps = gr.Slider(label="Steps", minimum=1, maximum=50, step=1, value=28)
694
+ with gr.Row():
695
+ width = gr.Slider(label="Width", minimum=256, maximum=1536, step=64, value=1024)
696
+ height = gr.Slider(label="Height", minimum=256, maximum=1536, step=64, value=1024)
697
+ with gr.Row():
698
+ randomize_seed = gr.Checkbox(True, label="Randomize seed")
699
+ seed = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=0, randomize=True)
700
+
701
+ # 업스케일 관련 UI 추가
702
+ with gr.Row():
703
+ upscale_button = gr.Button("Upscale", interactive=False)
704
+
705
+ with gr.Row():
706
+ with gr.Column(scale=4):
707
+ upscale_input = gr.Image(label="Input Image for Upscaling", type="pil")
708
+ with gr.Column(scale=1):
709
+ upscale_steps = gr.Slider(
710
+ label="Number of Inference Steps for Upscaling",
711
+ minimum=8,
712
+ maximum=50,
713
+ step=1,
714
+ value=28,
715
+ )
716
+ upscale_factor = gr.Slider(
717
+ label="Upscale Factor",
718
+ minimum=1,
719
+ maximum=4,
720
+ step=1,
721
+ value=4,
722
+ )
723
+ controlnet_conditioning_scale = gr.Slider(
724
+ label="Controlnet Conditioning Scale",
725
+ minimum=0.1,
726
+ maximum=1.0,
727
+ step=0.05,
728
+ value=0.5, # 기본값을 0.5로 낮춤
729
+ )
730
+ upscale_seed = gr.Slider(
731
+ label="Seed for Upscaling",
732
+ minimum=0,
733
+ maximum=MAX_SEED,
734
+ step=1,
735
+ value=42,
736
+ )
737
+ upscale_randomize_seed = gr.Checkbox(label="Randomize seed for Upscaling", value=True)
738
+ upscale_error = gr.Markdown(visible=False, value="Please provide an input image for upscaling.")
739
+
740
+ with gr.Row():
741
+ upscale_result = gr.Image(label="Upscaled Image", type="pil")
742
+ upscale_seed_output = gr.Number(label="Seed Used", precision=0)
743
+
744
+
745
+ gallery.select(
746
+ update_selection,
747
+ inputs=[selected_indices, loras_state, width, height],
748
+ outputs=[prompt, selected_info_1, selected_info_2, selected_indices, lora_scale_1, lora_scale_2, width, height, lora_image_1, lora_image_2]
749
+ )
750
+ remove_button_1.click(
751
+ remove_lora_1,
752
+ inputs=[selected_indices, loras_state],
753
+ outputs=[selected_info_1, selected_info_2, selected_indices, lora_scale_1, lora_scale_2, lora_image_1, lora_image_2]
754
+ )
755
+ remove_button_2.click(
756
+ remove_lora_2,
757
+ inputs=[selected_indices, loras_state],
758
+ outputs=[selected_info_1, selected_info_2, selected_indices, lora_scale_1, lora_scale_2, lora_image_1, lora_image_2]
759
+ )
760
+ randomize_button.click(
761
+ randomize_loras,
762
+ inputs=[selected_indices, loras_state],
763
+ outputs=[selected_info_1, selected_info_2, selected_indices, lora_scale_1, lora_scale_2, lora_image_1, lora_image_2, prompt]
764
+ )
765
+ add_custom_lora_button.click(
766
+ add_custom_lora,
767
+ inputs=[custom_lora, selected_indices, loras_state],
768
+ outputs=[loras_state, gallery, selected_info_1, selected_info_2, selected_indices, lora_scale_1, lora_scale_2, lora_image_1, lora_image_2]
769
+ )
770
+ remove_custom_lora_button.click(
771
+ remove_custom_lora,
772
+ inputs=[selected_indices, loras_state],
773
+ outputs=[loras_state, gallery, selected_info_1, selected_info_2, selected_indices, lora_scale_1, lora_scale_2, lora_image_1, lora_image_2]
774
+ )
775
+
776
+ gr.on(
777
+ triggers=[generate_button.click, prompt.submit],
778
+ fn=run_lora,
779
+ inputs=[prompt, input_image, image_strength, cfg_scale, steps, selected_indices, lora_scale_1, lora_scale_2, randomize_seed, seed, width, height, loras_state],
780
+ outputs=[result, seed, progress_bar]
781
+ ).then(
782
+ fn=lambda x, history: update_history(x, history) if x is not None else history,
783
+ inputs=[result, history_gallery],
784
+ outputs=history_gallery,
785
+ )
786
+
787
+ upscale_input.upload(
788
+ lambda x: gr.update(interactive=x is not None),
789
+ inputs=[upscale_input],
790
+ outputs=[upscale_button]
791
+ )
792
+
793
+ upscale_error = gr.Markdown(visible=False, value="")
794
+
795
+ upscale_button.click(
796
+ infer_upscale,
797
+ inputs=[
798
+ upscale_seed,
799
+ upscale_randomize_seed,
800
+ upscale_input,
801
+ upscale_steps,
802
+ upscale_factor,
803
+ controlnet_conditioning_scale,
804
+ ],
805
+ outputs=[
806
+ upscale_result,
807
+ upscale_seed_output,
808
+ upscale_steps,
809
+ upscale_factor,
810
+ controlnet_conditioning_scale,
811
+ upscale_randomize_seed,
812
+ upscale_error
813
+ ],
814
+
815
+ ).then(
816
+ infer_upscale,
817
+ inputs=[
818
+ upscale_seed,
819
+ upscale_randomize_seed,
820
+ upscale_input,
821
+ upscale_steps,
822
+ upscale_factor,
823
+ controlnet_conditioning_scale,
824
+ ],
825
+ outputs=[upscale_result, upscale_seed_output]
826
+ )
827
+
828
+
829
+ if __name__ == "__main__":
830
+ app.queue(max_size=20)
831
+ app.launch(debug=True)