multimodalart HF staff commited on
Commit
c4cd17d
·
1 Parent(s): 8fe2fce

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +73 -55
app.py CHANGED
@@ -9,16 +9,17 @@ import copy
9
  import json
10
 
11
  with open("sdxl_loras.json", "r") as file:
 
12
  sdxl_loras = [
13
- (
14
- item["image"],
15
- item["title"],
16
- item["repo"],
17
- item["trigger_word"],
18
- item["weights"],
19
- item["is_compatible"],
20
- )
21
- for item in json.load(file)
22
  ]
23
 
24
  saved_names = [
@@ -43,10 +44,10 @@ last_merged = False
43
 
44
 
45
  def update_selection(selected_state: gr.SelectData):
46
- lora_repo = sdxl_loras[selected_state.index][2]
47
- instance_prompt = sdxl_loras[selected_state.index][3]
48
  new_placeholder = "Type a prompt! This style works for all prompts without a trigger word" if instance_prompt == "" else "Type a prompt to use your selected LoRA"
49
- weight_name = sdxl_loras[selected_state.index][4]
50
  updated_text = f"### Selected: [{lora_repo}](https://huggingface.co/{lora_repo}) ✨"
51
  use_with_diffusers = f'''
52
  ## Using [`{lora_repo}`](https://huggingface.co/{lora_repo})
@@ -93,52 +94,49 @@ def check_selected(selected_state):
93
  if not selected_state:
94
  raise gr.Error("You must select a LoRA")
95
 
96
- def run_lora(prompt, negative, lora_scale, selected_state):
97
- global last_lora, last_merged, pipe
98
-
99
- if not selected_state:
100
- raise gr.Error("You must select a LoRA")
101
-
102
- if negative == "":
103
- negative = None
104
 
 
 
 
105
 
106
- repo_name = sdxl_loras[selected_state.index][2]
107
- weight_name = sdxl_loras[selected_state.index][4]
108
- full_path_lora = saved_names[selected_state.index]
109
- cross_attention_kwargs = None
110
- if last_lora != repo_name:
111
- if last_merged:
112
- pipe = copy.deepcopy(original_pipe)
113
- pipe.to(device)
114
- else:
115
- pipe.unload_lora_weights()
116
- is_compatible = sdxl_loras[selected_state.index][5]
117
- if is_compatible:
118
- pipe.load_lora_weights(full_path_lora)
119
- cross_attention_kwargs = {"scale": lora_scale}
 
 
 
120
  else:
121
- for weights_file in [full_path_lora]:
122
- if ";" in weights_file:
123
- weights_file, multiplier = weights_file.split(";")
124
- multiplier = float(multiplier)
125
- else:
126
- multiplier = lora_scale
127
 
128
- lora_model, weights_sd = lora.create_network_from_weights(
129
- multiplier,
130
- full_path_lora,
131
- pipe.vae,
132
- pipe.text_encoder,
133
- pipe.unet,
134
- for_inference=True,
135
- )
136
- lora_model.merge_to(
137
- pipe.text_encoder, pipe.unet, weights_sd, torch.float16, "cuda"
138
- )
139
- last_merged = True
140
 
141
- image = pipe(
 
142
  prompt=prompt,
143
  negative_prompt=negative,
144
  width=768,
@@ -147,6 +145,26 @@ def run_lora(prompt, negative, lora_scale, selected_state):
147
  guidance_scale=7.5,
148
  cross_attention_kwargs=cross_attention_kwargs,
149
  ).images[0]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
150
  last_lora = repo_name
151
  return image, gr.update(visible=True)
152
 
@@ -235,7 +253,7 @@ with gr.Blocks(css="custom.css") as demo:
235
  inputs=[selected_state],
236
  queue=False,
237
  show_progress=False
238
- ).then(
239
  fn=run_lora,
240
  inputs=[prompt, negative, weight, selected_state],
241
  outputs=[result, share_group],
@@ -245,7 +263,7 @@ with gr.Blocks(css="custom.css") as demo:
245
  inputs=[selected_state],
246
  queue=False,
247
  show_progress=False
248
- ).then(
249
  fn=run_lora,
250
  inputs=[prompt, negative, weight, selected_state],
251
  outputs=[result, share_group],
 
9
  import json
10
 
11
  with open("sdxl_loras.json", "r") as file:
12
+ data = json.load(file)
13
  sdxl_loras = [
14
+ {
15
+ "image": item["image"],
16
+ "title": item["title"],
17
+ "repo": item["repo"],
18
+ "trigger_word": item["trigger_word"],
19
+ "weights": item["weights"],
20
+ "is_compatible": item["is_compatible"],
21
+ }
22
+ for item in data
23
  ]
24
 
25
  saved_names = [
 
44
 
45
 
46
  def update_selection(selected_state: gr.SelectData):
47
+ lora_repo = sdxl_loras[selected_state.index]["repo"]
48
+ instance_prompt = sdxl_loras[selected_state.index]["trigger_word"]
49
  new_placeholder = "Type a prompt! This style works for all prompts without a trigger word" if instance_prompt == "" else "Type a prompt to use your selected LoRA"
50
+ weight_name = sdxl_loras[selected_state.index]["weights"]
51
  updated_text = f"### Selected: [{lora_repo}](https://huggingface.co/{lora_repo}) ✨"
52
  use_with_diffusers = f'''
53
  ## Using [`{lora_repo}`](https://huggingface.co/{lora_repo})
 
94
  if not selected_state:
95
  raise gr.Error("You must select a LoRA")
96
 
97
+ def get_cross_attention_kwargs(scale, repo_name, is_compatible):
98
+ if repo_name != last_lora and is_compatible:
99
+ return {"scale": scale}
100
+ return None
 
 
 
 
101
 
102
+ def load_lora_model(pipe, repo_name, full_path_lora, lora_scale):
103
+ if repo_name == last_lora:
104
+ return
105
 
106
+ if last_merged:
107
+ pipe = copy.deepcopy(original_pipe)
108
+ pipe.to(device)
109
+ else:
110
+ pipe.unload_lora_weights()
111
+
112
+ is_compatible = sdxl_loras[selected_state.index]["is_compatible"]
113
+ if is_compatible:
114
+ pipe.load_lora_weights(full_path_lora)
115
+ else:
116
+ load_incompatible_lora(pipe, full_path_lora, lora_scale)
117
+
118
+ def load_incompatible_lora(pipe, full_path_lora, lora_scale):
119
+ for weights_file in [full_path_lora]:
120
+ if ";" in weights_file:
121
+ weights_file, multiplier = weights_file.split(";")
122
+ multiplier = float(multiplier)
123
  else:
124
+ multiplier = lora_scale
 
 
 
 
 
125
 
126
+ lora_model, weights_sd = lora.create_network_from_weights(
127
+ multiplier,
128
+ full_path_lora,
129
+ pipe.vae,
130
+ pipe.text_encoder,
131
+ pipe.unet,
132
+ for_inference=True,
133
+ )
134
+ lora_model.merge_to(
135
+ pipe.text_encoder, pipe.unet, weights_sd, torch.float16, "cuda"
136
+ )
 
137
 
138
+ def generate_image(pipe, prompt, negative, cross_attention_kwargs):
139
+ return pipe(
140
  prompt=prompt,
141
  negative_prompt=negative,
142
  width=768,
 
145
  guidance_scale=7.5,
146
  cross_attention_kwargs=cross_attention_kwargs,
147
  ).images[0]
148
+
149
+ def run_lora(prompt, negative, lora_scale, selected_state):
150
+ global last_lora, last_merged, pipe
151
+
152
+ if not selected_state:
153
+ raise gr.Error("You must select a LoRA")
154
+
155
+ if negative == "":
156
+ negative = None
157
+
158
+ repo_name = sdxl_loras[selected_state.index]["repo"]
159
+ full_path_lora = saved_names[selected_state.index]
160
+
161
+ cross_attention_kwargs = get_cross_attention_kwargs(
162
+ lora_scale, repo_name, sdxl_loras[selected_state.index]["is_compatible"])
163
+
164
+ load_lora_model(pipe, repo_name, full_path_lora, lora_scale)
165
+
166
+ image = generate_image(pipe, prompt, negative, cross_attention_kwargs)
167
+
168
  last_lora = repo_name
169
  return image, gr.update(visible=True)
170
 
 
253
  inputs=[selected_state],
254
  queue=False,
255
  show_progress=False
256
+ ).success(
257
  fn=run_lora,
258
  inputs=[prompt, negative, weight, selected_state],
259
  outputs=[result, share_group],
 
263
  inputs=[selected_state],
264
  queue=False,
265
  show_progress=False
266
+ ).success(
267
  fn=run_lora,
268
  inputs=[prompt, negative, weight, selected_state],
269
  outputs=[result, share_group],