yupeng.zhou commited on
Commit
14f69c4
1 Parent(s): a5df616
Files changed (1) hide show
  1. app.py +1 -1
app.py CHANGED
@@ -110,7 +110,7 @@ class SpatialAttnProcessor2_0(torch.nn.Module):
110
  encoder_hidden_states = torch.cat((self.id_bank[cur_step][0].to(self.device),hidden_states[:1],self.id_bank[cur_step][1].to(self.device),hidden_states[1:]))
111
  # 判断随机数是否大于0.5
112
  if cur_step <5:
113
- hidden_states = self.__call2__(attn, hidden_states,encoder_hidden_states,attention_mask,temb)
114
  else: # 256 1024 4096
115
  random_number = random.random()
116
  if cur_step <20:
 
110
  encoder_hidden_states = torch.cat((self.id_bank[cur_step][0].to(self.device),hidden_states[:1],self.id_bank[cur_step][1].to(self.device),hidden_states[1:]))
111
  # 判断随机数是否大于0.5
112
  if cur_step <5:
113
+ hidden_states = self.__call2__(attn, hidden_states,None,attention_mask,temb)
114
  else: # 256 1024 4096
115
  random_number = random.random()
116
  if cur_step <20: