guoyww commited on
Commit
e1c7172
·
1 Parent(s): e9be54a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +9 -5
app.py CHANGED
@@ -1,8 +1,10 @@
1
 
 
2
  import os
3
  import json
4
  import torch
5
  import random
 
6
 
7
  import gradio as gr
8
  from glob import glob
@@ -147,8 +149,8 @@ class AnimateController:
147
  raise gr.Error(f"Please select a pretrained model path.")
148
  if motion_module_dropdown == "":
149
  raise gr.Error(f"Please select a motion module.")
150
- if base_model_dropdown == "":
151
- raise gr.Error(f"Please select a base DreamBooth model.")
152
 
153
  if is_xformers_available(): self.unet.enable_xformers_memory_efficient_attention()
154
 
@@ -158,11 +160,13 @@ class AnimateController:
158
  ).to("cuda")
159
 
160
  if self.lora_model_state_dict != {}:
161
- pipeline = convert_lora(pipeline, self.lora_model_state_dict, alpha=lora_alpha_slider)
 
162
 
163
  pipeline.to("cuda")
164
 
165
- if seed_textbox != -1 and seed_textbox != "": torch.manual_seed(int(seed_textbox))
 
166
  else: torch.seed()
167
  seed = torch.initial_seed()
168
 
@@ -259,7 +263,7 @@ def ui():
259
  )
260
  lora_model_dropdown.change(fn=controller.update_lora_model, inputs=[lora_model_dropdown], outputs=[lora_model_dropdown])
261
 
262
- lora_alpha_slider = gr.Slider(label="LoRA alpha", value=0.8, minimum=0, maximum=2, interactive=True)
263
 
264
  personalized_refresh_button = gr.Button(value="\U0001F503", elem_classes="toolbutton")
265
  def update_personalized_model():
 
1
 
2
+
3
  import os
4
  import json
5
  import torch
6
  import random
7
+ import copy
8
 
9
  import gradio as gr
10
  from glob import glob
 
149
  raise gr.Error(f"Please select a pretrained model path.")
150
  if motion_module_dropdown == "":
151
  raise gr.Error(f"Please select a motion module.")
152
+ # if base_model_dropdown == "":
153
+ # raise gr.Error(f"Please select a base DreamBooth model.")
154
 
155
  if is_xformers_available(): self.unet.enable_xformers_memory_efficient_attention()
156
 
 
160
  ).to("cuda")
161
 
162
  if self.lora_model_state_dict != {}:
163
+ print(f"Lora alpha: {lora_alpha_slider}")
164
+ pipeline = convert_lora(copy.deepcopy(pipeline), self.lora_model_state_dict, alpha=lora_alpha_slider)
165
 
166
  pipeline.to("cuda")
167
 
168
+ seed_textbox = int(seed_textbox)
169
+ if seed_textbox != -1 and seed_textbox != "": torch.manual_seed(seed_textbox)
170
  else: torch.seed()
171
  seed = torch.initial_seed()
172
 
 
263
  )
264
  lora_model_dropdown.change(fn=controller.update_lora_model, inputs=[lora_model_dropdown], outputs=[lora_model_dropdown])
265
 
266
+ lora_alpha_slider = gr.Slider(label="LoRA alpha", value=0.7, minimum=0, maximum=2, interactive=True)
267
 
268
  personalized_refresh_button = gr.Button(value="\U0001F503", elem_classes="toolbutton")
269
  def update_personalized_model():