CultriX commited on
Commit
30dafcd
1 Parent(s): b805c45

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +75 -5
app.py CHANGED
@@ -82,8 +82,12 @@ def update_selection(evt: gr.SelectData, width, height):
82
 
83
  @spaces.GPU(duration=30)
84
  def generate_image(prompt_mash, steps, seed, cfg_scale, width, height, lora_scale, progress):
85
- pipe.to("cuda")
86
- generator = torch.Generator(device="cuda").manual_seed(seed)
 
 
 
 
87
  with calculateDuration("Generating image"):
88
  # Generate image
89
  for img in pipe.flux_pipe_call_that_returns_an_iterable_of_images(
@@ -128,6 +132,71 @@ def generate_image_to_image(prompt_mash, image_input_path, image_strength, steps
128
 
129
  return final_image, temp_file.name, f"data:image/png;base64,{img_base64}"
130
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
131
  @spaces.GPU(duration=30)
132
  def run_lora(prompt, image_input, image_strength, cfg_scale, steps, selected_index, randomize_seed, seed, width, height, lora_scale, progress=gr.Progress(track_tqdm=True)):
133
  if selected_index is None:
@@ -135,7 +204,7 @@ def run_lora(prompt, image_input, image_strength, cfg_scale, steps, selected_ind
135
  selected_lora = loras[selected_index]
136
  lora_path = selected_lora["repo"]
137
  trigger_word = selected_lora["trigger_word"]
138
- if(trigger_word):
139
  if "trigger_position" in selected_lora:
140
  if selected_lora["trigger_position"] == "prepend":
141
  prompt_mash = f"{trigger_word} {prompt}"
@@ -166,7 +235,7 @@ def run_lora(prompt, image_input, image_strength, cfg_scale, steps, selected_ind
166
  if randomize_seed:
167
  seed = random.randint(0, MAX_SEED)
168
 
169
- if(image_input is not None):
170
  final_image, file_path, base64_str = generate_image_to_image(prompt_mash, image_input, image_strength, steps, cfg_scale, width, height, lora_scale, seed)
171
  yield final_image, seed, file_path, base64_str, gr.update(visible=False)
172
  else:
@@ -176,7 +245,7 @@ def run_lora(prompt, image_input, image_strength, cfg_scale, steps, selected_ind
176
  final_image = None
177
  step_counter = 0
178
  for image in image_generator:
179
- step_counter+=1
180
  final_image = image
181
  progress_bar = f'<div class="progress-container"><div class="progress-bar" style="--current: {step_counter}; --total: {steps};"></div></div>'
182
  yield image, seed, None, None, gr.update(value=progress_bar, visible=True)
@@ -279,3 +348,4 @@ with gr.Blocks(theme=gr.themes.Soft(font=font), css=css, delete_cache=(60, 60))
279
 
280
  app.queue()
281
  app.launch()
 
 
82
 
83
  @spaces.GPU(duration=30)
84
  def generate_image(prompt_mash, steps, seed, cfg_scale, width, height, lora_scale, progress):
85
+ if torch.cuda.is_available():
86
+ pipe.to("cuda")
87
+ else:
88
+ pipe.to("cpu")
89
+
90
+ generator = torch.Generator(device=device).manual_seed(seed)
91
  with calculateDuration("Generating image"):
92
  # Generate image
93
  for img in pipe.flux_pipe_call_that_returns_an_iterable_of_images(
 
132
 
133
  return final_image, temp_file.name, f"data:image/png;base64,{img_base64}"
134
 
135
+ def get_huggingface_safetensors(link):
136
+ split_link = link.split("/")
137
+ if len(split_link) == 2:
138
+ model_card = ModelCard.load(link)
139
+ base_model = model_card.data.get("base_model")
140
+ if base_model not in ["black-forest-labs/FLUX.1-dev", "black-forest-labs/FLUX.1-schnell"]:
141
+ raise Exception("Not a FLUX LoRA!")
142
+ image_path = model_card.data.get("widget", [{}])[0].get("output", {}).get("url", None)
143
+ trigger_word = model_card.data.get("instance_prompt", "")
144
+ image_url = f"https://huggingface.co/{link}/resolve/main/{image_path}" if image_path else None
145
+ fs = HfFileSystem()
146
+ try:
147
+ list_of_files = fs.ls(link, detail=False)
148
+ for file in list_of_files:
149
+ if file.endswith(".safetensors"):
150
+ safetensors_name = file.split("/")[-1]
151
+ if not image_url and file.lower().endswith((".jpg", ".jpeg", ".png", ".webp")):
152
+ image_elements = file.split("/")
153
+ image_url = f"https://huggingface.co/{link}/resolve/main/{image_elements[-1]}"
154
+ except Exception as e:
155
+ raise Exception(f"Invalid Hugging Face repository: {e}")
156
+ return split_link[1], link, safetensors_name, trigger_word, image_url
157
+
158
+ def check_custom_model(link):
159
+ if link.startswith("https://"):
160
+ if link.startswith("https://huggingface.co") or link.startswith("https://www.huggingface.co"):
161
+ link_split = link.split("huggingface.co/")
162
+ return get_huggingface_safetensors(link_split[1])
163
+ else:
164
+ return get_huggingface_safetensors(link)
165
+
166
+ def add_custom_lora(custom_lora):
167
+ global loras
168
+ if custom_lora:
169
+ try:
170
+ title, repo, path, trigger_word, image = check_custom_model(custom_lora)
171
+ new_lora = {
172
+ "image": image,
173
+ "title": title,
174
+ "repo": repo,
175
+ "weights": path,
176
+ "trigger_word": trigger_word
177
+ }
178
+ loras.append(new_lora)
179
+ card = f'''
180
+ <div class="custom_lora_card">
181
+ <span>Loaded custom LoRA:</span>
182
+ <div class="card_internal">
183
+ <img src="{image}" />
184
+ <div>
185
+ <h3>{title}</h3>
186
+ <small>{"Using: <code><b"+trigger_word+"</code></b> as the trigger word" if trigger_word else "No trigger word found. If there's a trigger word, include it in your prompt"}<br></small>
187
+ </div>
188
+ </div>
189
+ </div>
190
+ '''
191
+ return gr.update(visible=True, value=card), gr.update(visible=True), gr.Gallery(selected_index=None), f"Custom: {path}", len(loras) - 1, trigger_word
192
+ except Exception as e:
193
+ return gr.update(visible=True, value=f"Error: {e}"), gr.update(visible=False), gr.update(), "", None, ""
194
+ else:
195
+ return gr.update(visible=False), gr.update(visible=False), gr.update(), "", None, ""
196
+
197
+ def remove_custom_lora():
198
+ return gr.update(visible=False), gr.update(visible=False), gr.update(), "", None, ""
199
+
200
  @spaces.GPU(duration=30)
201
  def run_lora(prompt, image_input, image_strength, cfg_scale, steps, selected_index, randomize_seed, seed, width, height, lora_scale, progress=gr.Progress(track_tqdm=True)):
202
  if selected_index is None:
 
204
  selected_lora = loras[selected_index]
205
  lora_path = selected_lora["repo"]
206
  trigger_word = selected_lora["trigger_word"]
207
+ if trigger_word:
208
  if "trigger_position" in selected_lora:
209
  if selected_lora["trigger_position"] == "prepend":
210
  prompt_mash = f"{trigger_word} {prompt}"
 
235
  if randomize_seed:
236
  seed = random.randint(0, MAX_SEED)
237
 
238
+ if image_input is not None:
239
  final_image, file_path, base64_str = generate_image_to_image(prompt_mash, image_input, image_strength, steps, cfg_scale, width, height, lora_scale, seed)
240
  yield final_image, seed, file_path, base64_str, gr.update(visible=False)
241
  else:
 
245
  final_image = None
246
  step_counter = 0
247
  for image in image_generator:
248
+ step_counter += 1
249
  final_image = image
250
  progress_bar = f'<div class="progress-container"><div class="progress-bar" style="--current: {step_counter}; --total: {steps};"></div></div>'
251
  yield image, seed, None, None, gr.update(value=progress_bar, visible=True)
 
348
 
349
  app.queue()
350
  app.launch()
351
+