Spaces:
Paused
Paused
Commit
•
a9d2af1
1
Parent(s):
059f3f4
Update app.py
Browse files
app.py
CHANGED
@@ -1,6 +1,6 @@
|
|
1 |
import gradio as gr
|
2 |
import torch
|
3 |
-
from diffusers import StableDiffusionXLPipeline, AutoencoderKL
|
4 |
from huggingface_hub import hf_hub_download
|
5 |
from safetensors.torch import load_file
|
6 |
from share_btn import community_icon_html, loading_icon_html, share_js
|
@@ -11,7 +11,9 @@ import json
|
|
11 |
import gc
|
12 |
import random
|
13 |
from urllib.parse import quote
|
14 |
-
|
|
|
|
|
15 |
data = json.load(file)
|
16 |
sdxl_loras_raw = [
|
17 |
{
|
@@ -62,7 +64,7 @@ pipe = StableDiffusionXLPipeline.from_pretrained(
|
|
62 |
vae=vae,
|
63 |
torch_dtype=torch.float16,
|
64 |
)
|
65 |
-
|
66 |
pipe.to(device)
|
67 |
|
68 |
last_lora = ""
|
@@ -75,6 +77,7 @@ button.addEventListener('click', function() {
|
|
75 |
element.classList.add('selected');
|
76 |
});
|
77 |
'''
|
|
|
78 |
def update_selection(selected_state: gr.SelectData, sdxl_loras, is_new=False):
|
79 |
lora_repo = sdxl_loras[selected_state.index]["repo"]
|
80 |
instance_prompt = sdxl_loras[selected_state.index]["trigger_word"]
|
@@ -139,29 +142,6 @@ def check_selected(selected_state):
|
|
139 |
if not selected_state:
|
140 |
raise gr.Error("You must select a LoRA")
|
141 |
|
142 |
-
def merge_incompatible_lora(full_path_lora, lora_scale):
|
143 |
-
for weights_file in [full_path_lora]:
|
144 |
-
if ";" in weights_file:
|
145 |
-
weights_file, multiplier = weights_file.split(";")
|
146 |
-
multiplier = float(multiplier)
|
147 |
-
else:
|
148 |
-
multiplier = lora_scale
|
149 |
-
|
150 |
-
lora_model, weights_sd = lora.create_network_from_weights(
|
151 |
-
multiplier,
|
152 |
-
full_path_lora,
|
153 |
-
pipe.vae,
|
154 |
-
pipe.text_encoder,
|
155 |
-
pipe.unet,
|
156 |
-
for_inference=True,
|
157 |
-
)
|
158 |
-
lora_model.merge_to(
|
159 |
-
pipe.text_encoder, pipe.unet, weights_sd, torch.float16, "cuda"
|
160 |
-
)
|
161 |
-
del weights_sd
|
162 |
-
del lora_model
|
163 |
-
gc.collect()
|
164 |
-
|
165 |
def run_lora(prompt, negative, lora_scale, selected_state, sdxl_loras, sdxl_loras_new, progress=gr.Progress(track_tqdm=True)):
|
166 |
global last_lora, last_merged, last_fused, pipe
|
167 |
print("Index when running ", selected_state.index)
|
@@ -187,11 +167,9 @@ def run_lora(prompt, negative, lora_scale, selected_state, sdxl_loras, sdxl_lora
|
|
187 |
if last_lora != repo_name:
|
188 |
if(last_fused):
|
189 |
pipe.unfuse_lora()
|
190 |
-
|
191 |
-
|
192 |
-
pipe.
|
193 |
-
#pipe.load_lora_weights(lcm_lora_id, adapter_name="lcm_lora")
|
194 |
-
#pipe.set_adapters(["loaded_lora", "lcm_lora"], adapter_weights=[0.8, 1.0])
|
195 |
pipe.fuse_lora()
|
196 |
last_fused = True
|
197 |
is_pivotal = sdxl_loras[selected_state.index]["is_pivotal"]
|
@@ -207,11 +185,10 @@ def run_lora(prompt, negative, lora_scale, selected_state, sdxl_loras, sdxl_lora
|
|
207 |
image = pipe(
|
208 |
prompt=prompt,
|
209 |
negative_prompt=negative,
|
210 |
-
|
211 |
-
|
212 |
-
num_inference_steps=20,
|
213 |
-
guidance_scale=7.5,
|
214 |
).images[0]
|
|
|
215 |
last_lora = repo_name
|
216 |
gc.collect()
|
217 |
return image, gr.update(visible=True)
|
|
|
1 |
import gradio as gr
|
2 |
import torch
|
3 |
+
from diffusers import StableDiffusionXLPipeline, AutoencoderKL, LCMScheduler
|
4 |
from huggingface_hub import hf_hub_download
|
5 |
from safetensors.torch import load_file
|
6 |
from share_btn import community_icon_html, loading_icon_html, share_js
|
|
|
11 |
import gc
|
12 |
import random
|
13 |
from urllib.parse import quote
|
14 |
+
|
15 |
+
lora_list = hf_hub_download(repo_id="multimodalart/LoraTheExplorer", filename="sdxl_loras.json", repo_type="space")
|
16 |
+
with open(lora_list, "r") as file:
|
17 |
data = json.load(file)
|
18 |
sdxl_loras_raw = [
|
19 |
{
|
|
|
64 |
vae=vae,
|
65 |
torch_dtype=torch.float16,
|
66 |
)
|
67 |
+
pipe.scheduler = LCMScheduler.from_config(pipe.scheduler.config)
|
68 |
pipe.to(device)
|
69 |
|
70 |
last_lora = ""
|
|
|
77 |
element.classList.add('selected');
|
78 |
});
|
79 |
'''
|
80 |
+
|
81 |
def update_selection(selected_state: gr.SelectData, sdxl_loras, is_new=False):
|
82 |
lora_repo = sdxl_loras[selected_state.index]["repo"]
|
83 |
instance_prompt = sdxl_loras[selected_state.index]["trigger_word"]
|
|
|
142 |
if not selected_state:
|
143 |
raise gr.Error("You must select a LoRA")
|
144 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
145 |
def run_lora(prompt, negative, lora_scale, selected_state, sdxl_loras, sdxl_loras_new, progress=gr.Progress(track_tqdm=True)):
|
146 |
global last_lora, last_merged, last_fused, pipe
|
147 |
print("Index when running ", selected_state.index)
|
|
|
167 |
if last_lora != repo_name:
|
168 |
if(last_fused):
|
169 |
pipe.unfuse_lora()
|
170 |
+
pipe.load_lora_weights(loaded_state_dict, adapter_name="loaded_lora")
|
171 |
+
pipe.load_lora_weights(lcm_lora_id, weight_name="lcm_sdxl_lora.safetensors", adapter_name="lcm_lora", use_auth_token=True)
|
172 |
+
pipe.set_adapters(["loaded_lora", "lcm_lora"], adapter_weights=[0.8, 1.0])
|
|
|
|
|
173 |
pipe.fuse_lora()
|
174 |
last_fused = True
|
175 |
is_pivotal = sdxl_loras[selected_state.index]["is_pivotal"]
|
|
|
185 |
image = pipe(
|
186 |
prompt=prompt,
|
187 |
negative_prompt=negative,
|
188 |
+
num_inference_steps=4,
|
189 |
+
guidance_scale=0.5,
|
|
|
|
|
190 |
).images[0]
|
191 |
+
|
192 |
last_lora = repo_name
|
193 |
gc.collect()
|
194 |
return image, gr.update(visible=True)
|