multimodalart HF staff committed on
Commit
a9d2af1
1 Parent(s): 059f3f4

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +12 -35
app.py CHANGED
@@ -1,6 +1,6 @@
1
  import gradio as gr
2
  import torch
3
- from diffusers import StableDiffusionXLPipeline, AutoencoderKL
4
  from huggingface_hub import hf_hub_download
5
  from safetensors.torch import load_file
6
  from share_btn import community_icon_html, loading_icon_html, share_js
@@ -11,7 +11,9 @@ import json
11
  import gc
12
  import random
13
  from urllib.parse import quote
14
- with open("sdxl_loras.json", "r") as file:
 
 
15
  data = json.load(file)
16
  sdxl_loras_raw = [
17
  {
@@ -62,7 +64,7 @@ pipe = StableDiffusionXLPipeline.from_pretrained(
62
  vae=vae,
63
  torch_dtype=torch.float16,
64
  )
65
- original_pipe = copy.deepcopy(pipe)
66
  pipe.to(device)
67
 
68
  last_lora = ""
@@ -75,6 +77,7 @@ button.addEventListener('click', function() {
75
  element.classList.add('selected');
76
  });
77
  '''
 
78
  def update_selection(selected_state: gr.SelectData, sdxl_loras, is_new=False):
79
  lora_repo = sdxl_loras[selected_state.index]["repo"]
80
  instance_prompt = sdxl_loras[selected_state.index]["trigger_word"]
@@ -139,29 +142,6 @@ def check_selected(selected_state):
139
  if not selected_state:
140
  raise gr.Error("You must select a LoRA")
141
 
142
- def merge_incompatible_lora(full_path_lora, lora_scale):
143
- for weights_file in [full_path_lora]:
144
- if ";" in weights_file:
145
- weights_file, multiplier = weights_file.split(";")
146
- multiplier = float(multiplier)
147
- else:
148
- multiplier = lora_scale
149
-
150
- lora_model, weights_sd = lora.create_network_from_weights(
151
- multiplier,
152
- full_path_lora,
153
- pipe.vae,
154
- pipe.text_encoder,
155
- pipe.unet,
156
- for_inference=True,
157
- )
158
- lora_model.merge_to(
159
- pipe.text_encoder, pipe.unet, weights_sd, torch.float16, "cuda"
160
- )
161
- del weights_sd
162
- del lora_model
163
- gc.collect()
164
-
165
  def run_lora(prompt, negative, lora_scale, selected_state, sdxl_loras, sdxl_loras_new, progress=gr.Progress(track_tqdm=True)):
166
  global last_lora, last_merged, last_fused, pipe
167
  print("Index when running ", selected_state.index)
@@ -187,11 +167,9 @@ def run_lora(prompt, negative, lora_scale, selected_state, sdxl_loras, sdxl_lora
187
  if last_lora != repo_name:
188
  if(last_fused):
189
  pipe.unfuse_lora()
190
- pipe.unload_lora_weights()
191
- #is_compatible = sdxl_loras[selected_state.index]["is_compatible"]
192
- pipe.load_lora_weights(loaded_state_dict)#, adapter_name="loaded_lora")
193
- #pipe.load_lora_weights(lcm_lora_id, adapter_name="lcm_lora")
194
- #pipe.set_adapters(["loaded_lora", "lcm_lora"], adapter_weights=[0.8, 1.0])
195
  pipe.fuse_lora()
196
  last_fused = True
197
  is_pivotal = sdxl_loras[selected_state.index]["is_pivotal"]
@@ -207,11 +185,10 @@ def run_lora(prompt, negative, lora_scale, selected_state, sdxl_loras, sdxl_lora
207
  image = pipe(
208
  prompt=prompt,
209
  negative_prompt=negative,
210
- width=1024,
211
- height=1024,
212
- num_inference_steps=20,
213
- guidance_scale=7.5,
214
  ).images[0]
 
215
  last_lora = repo_name
216
  gc.collect()
217
  return image, gr.update(visible=True)
 
1
  import gradio as gr
2
  import torch
3
+ from diffusers import StableDiffusionXLPipeline, AutoencoderKL, LCMScheduler
4
  from huggingface_hub import hf_hub_download
5
  from safetensors.torch import load_file
6
  from share_btn import community_icon_html, loading_icon_html, share_js
 
11
  import gc
12
  import random
13
  from urllib.parse import quote
14
+
15
+ lora_list = hf_hub_download(repo_id="multimodalart/LoraTheExplorer", filename="sdxl_loras.json", repo_type="space")
16
+ with open(lora_list, "r") as file:
17
  data = json.load(file)
18
  sdxl_loras_raw = [
19
  {
 
64
  vae=vae,
65
  torch_dtype=torch.float16,
66
  )
67
+ pipe.scheduler = LCMScheduler.from_config(pipe.scheduler.config)
68
  pipe.to(device)
69
 
70
  last_lora = ""
 
77
  element.classList.add('selected');
78
  });
79
  '''
80
+
81
  def update_selection(selected_state: gr.SelectData, sdxl_loras, is_new=False):
82
  lora_repo = sdxl_loras[selected_state.index]["repo"]
83
  instance_prompt = sdxl_loras[selected_state.index]["trigger_word"]
 
142
  if not selected_state:
143
  raise gr.Error("You must select a LoRA")
144
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
145
  def run_lora(prompt, negative, lora_scale, selected_state, sdxl_loras, sdxl_loras_new, progress=gr.Progress(track_tqdm=True)):
146
  global last_lora, last_merged, last_fused, pipe
147
  print("Index when running ", selected_state.index)
 
167
  if last_lora != repo_name:
168
  if(last_fused):
169
  pipe.unfuse_lora()
170
+ pipe.load_lora_weights(loaded_state_dict, adapter_name="loaded_lora")
171
+ pipe.load_lora_weights(lcm_lora_id, weight_name="lcm_sdxl_lora.safetensors", adapter_name="lcm_lora", use_auth_token=True)
172
+ pipe.set_adapters(["loaded_lora", "lcm_lora"], adapter_weights=[0.8, 1.0])
 
 
173
  pipe.fuse_lora()
174
  last_fused = True
175
  is_pivotal = sdxl_loras[selected_state.index]["is_pivotal"]
 
185
  image = pipe(
186
  prompt=prompt,
187
  negative_prompt=negative,
188
+ num_inference_steps=4,
189
+ guidance_scale=0.5,
 
 
190
  ).images[0]
191
+
192
  last_lora = repo_name
193
  gc.collect()
194
  return image, gr.update(visible=True)