yupeng.zhou commited on
Commit
bb17730
1 Parent(s): 9b7872b
Files changed (1) hide show
  1. app.py +2 -2
app.py CHANGED
@@ -111,7 +111,7 @@ class SpatialAttnProcessor2_0(torch.nn.Module):
111
  else:
112
  encoder_hidden_states = torch.cat((self.id_bank[cur_step][0].to(self.device),hidden_states[:1],self.id_bank[cur_step][1].to(self.device),hidden_states[1:]))
113
  # 判断随机数是否大于0.5
114
- if cur_step <0.1* num_steps:
115
  hidden_states = self.__call2__(attn, hidden_states,None,attention_mask,temb)
116
  else: # 256 1024 4096
117
  random_number = random.random()
@@ -510,7 +510,7 @@ def process_generation(_sd_type,_model_type,_upload_images, _num_steps,style_nam
510
  if _upload_images is None and _model_type != "original":
511
  raise gr.Error(f"Cannot find any input face image!")
512
  if len(prompt_array.splitlines()) > 6:
513
- raise gr.Error(f"No more than 6 prompts in huggface demo for Speed! But found {len(prompt_array)} prompts!")
514
  global sa32, sa64,id_length,total_length,attn_procs,unet,cur_model_type,device
515
  global num_steps
516
  global write
 
111
  else:
112
  encoder_hidden_states = torch.cat((self.id_bank[cur_step][0].to(self.device),hidden_states[:1],self.id_bank[cur_step][1].to(self.device),hidden_states[1:]))
113
  # 判断随机数是否大于0.5
114
+ if cur_step <=1:
115
  hidden_states = self.__call2__(attn, hidden_states,None,attention_mask,temb)
116
  else: # 256 1024 4096
117
  random_number = random.random()
 
510
  if _upload_images is None and _model_type != "original":
511
  raise gr.Error(f"Cannot find any input face image!")
512
  if len(prompt_array.splitlines()) > 6:
513
+ raise gr.Error(f"No more than 6 prompts in huggface demo for Speed! But found {len(prompt_array.splitlines()) > 6} prompts!")
514
  global sa32, sa64,id_length,total_length,attn_procs,unet,cur_model_type,device
515
  global num_steps
516
  global write