Spaces:
Running
on
Zero
Running
on
Zero
Update app.py
Browse files
app.py
CHANGED
@@ -189,83 +189,6 @@ def run_lora(prompt, image_input, image_strength, cfg_scale, steps, selected_ind
|
|
189 |
|
190 |
yield final_image, seed, temp_file.name, f"data:image/png;base64,{img_base64}", gr.update(value=progress_bar, visible=False)
|
191 |
|
192 |
-
def get_huggingface_safetensors(link):
|
193 |
-
split_link = link.split("/")
|
194 |
-
if(len(split_link) == 2):
|
195 |
-
model_card = ModelCard.load(link)
|
196 |
-
base_model = model_card.data.get("base_model")
|
197 |
-
print(base_model)
|
198 |
-
if((base_model != "black-forest-labs/FLUX.1-dev") and (base_model != "black-forest-labs/FLUX.1-schnell")):
|
199 |
-
raise Exception("Not a FLUX LoRA!")
|
200 |
-
image_path = model_card.data.get("widget", [{}])[0].get("output", {}).get("url", None)
|
201 |
-
trigger_word = model_card.data.get("instance_prompt", "")
|
202 |
-
image_url = f"https://huggingface.co/{link}/resolve/main/{image_path}" if image_path else None
|
203 |
-
fs = HfFileSystem()
|
204 |
-
try:
|
205 |
-
list_of_files = fs.ls(link, detail=False)
|
206 |
-
for file in list_of_files:
|
207 |
-
if(file.endswith(".safetensors")):
|
208 |
-
safetensors_name = file.split("/")[-1]
|
209 |
-
if (not image_url and file.lower().endswith((".jpg", ".jpeg", ".png", ".webp"))):
|
210 |
-
image_elements = file.split("/")
|
211 |
-
image_url = f"https://huggingface.co/{link}/resolve/main/{image_elements[-1]}"
|
212 |
-
except Exception as e:
|
213 |
-
print(e)
|
214 |
-
gr.Warning(f"You didn't include a link neither a valid Hugging Face repository with a *.safetensors LoRA")
|
215 |
-
raise Exception(f"You didn't include a link neither a valid Hugging Face repository with a *.safetensors LoRA")
|
216 |
-
return split_link[1], link, safetensors_name, trigger_word, image_url
|
217 |
-
|
218 |
-
def check_custom_model(link):
|
219 |
-
if(link.startswith("https://")):
|
220 |
-
if(link.startswith("https://huggingface.co") or link.startswith("https://www.huggingface.co")):
|
221 |
-
link_split = link.split("huggingface.co/")
|
222 |
-
return get_huggingface_safetensors(link_split[1])
|
223 |
-
else:
|
224 |
-
return get_huggingface_safetensors(link)
|
225 |
-
|
226 |
-
def add_custom_lora(custom_lora):
|
227 |
-
global loras
|
228 |
-
if(custom_lora):
|
229 |
-
try:
|
230 |
-
title, repo, path, trigger_word, image = check_custom_model(custom_lora)
|
231 |
-
print(f"Loaded custom LoRA: {repo}")
|
232 |
-
card = f'''
|
233 |
-
<div class="custom_lora_card">
|
234 |
-
<span>Loaded custom LoRA:</span>
|
235 |
-
<div class="card_internal">
|
236 |
-
<img src="{image}" />
|
237 |
-
<div>
|
238 |
-
<h3>{title}</h3>
|
239 |
-
<small>{"Using: <code><b"+trigger_word+"</code></b> as the trigger word" if trigger_word else "No trigger word found. If there's a trigger word, include it in your prompt"}<br></small>
|
240 |
-
</div>
|
241 |
-
</div>
|
242 |
-
</div>
|
243 |
-
'''
|
244 |
-
existing_item_index = next((index for (index, item) in enumerate(loras) if item['repo'] == repo), None)
|
245 |
-
if(not existing_item_index):
|
246 |
-
new_item = {
|
247 |
-
"image": image,
|
248 |
-
"title": title,
|
249 |
-
"repo": repo,
|
250 |
-
"weights": path,
|
251 |
-
"trigger_word": trigger_word
|
252 |
-
}
|
253 |
-
print(new_item)
|
254 |
-
existing_item_index = len(loras)
|
255 |
-
loras.append(new_item)
|
256 |
-
|
257 |
-
return gr.update(visible=True, value=card), gr.update(visible=True), gr.Gallery(selected_index=None), f"Custom: {path}", existing_item_index, trigger_word
|
258 |
-
except Exception as e:
|
259 |
-
gr.Warning(f"Invalid LoRA: either you entered an invalid link, or a non-FLUX LoRA")
|
260 |
-
return gr.update(visible=True, value=f"Invalid LoRA: either you entered an invalid link, a non-FLUX LoRA"), gr.update(visible=True), gr.update(), "", None, ""
|
261 |
-
else:
|
262 |
-
return gr.update(visible=False), gr.update(visible=False), gr.update(), "", None, ""
|
263 |
-
|
264 |
-
def remove_custom_lora():
|
265 |
-
return gr.update(visible=False), gr.update(visible=False), gr.update(), "", None, ""
|
266 |
-
|
267 |
-
run_lora.zerogpu = True
|
268 |
-
|
269 |
css = '''
|
270 |
#gen_btn{height: 100%}
|
271 |
#gen_column{align-self: stretch}
|
@@ -332,4 +255,27 @@ with gr.Blocks(theme=gr.themes.Soft(font=font), css=css, delete_cache=(60, 60))
|
|
332 |
|
333 |
with gr.Row():
|
334 |
randomize_seed = gr.Checkbox(True, label="Randomize seed")
|
335 |
-
seed = gr.Slider(label="Seed", minimum=0, maximum
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
189 |
|
190 |
yield final_image, seed, temp_file.name, f"data:image/png;base64,{img_base64}", gr.update(value=progress_bar, visible=False)
|
191 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
192 |
css = '''
|
193 |
#gen_btn{height: 100%}
|
194 |
#gen_column{align-self: stretch}
|
|
|
255 |
|
256 |
with gr.Row():
|
257 |
randomize_seed = gr.Checkbox(True, label="Randomize seed")
|
258 |
+
seed = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=0)
|
259 |
+
|
260 |
+
gallery.select(
|
261 |
+
update_selection,
|
262 |
+
inputs=[width, height],
|
263 |
+
outputs=[prompt, selected_info, selected_index, width, height]
|
264 |
+
)
|
265 |
+
custom_lora.input(
|
266 |
+
add_custom_lora,
|
267 |
+
inputs=[custom_lora],
|
268 |
+
outputs=[custom_lora_info, custom_lora_button, gallery, selected_info, selected_index, prompt]
|
269 |
+
)
|
270 |
+
custom_lora_button.click(
|
271 |
+
remove_custom_lora,
|
272 |
+
outputs=[custom_lora_info, custom_lora_button, gallery, selected_info, selected_index, custom_lora]
|
273 |
+
)
|
274 |
+
generate_button.click(
|
275 |
+
run_lora,
|
276 |
+
inputs=[prompt, input_image, image_strength, cfg_scale, steps, selected_index, randomize_seed, seed, width, height, lora_scale],
|
277 |
+
outputs=[result, seed, download_link, base64_output, progress_bar]
|
278 |
+
)
|
279 |
+
|
280 |
+
app.queue()
|
281 |
+
app.launch()
|