SunderAli17 commited on
Commit
f061925
·
verified ·
1 Parent(s): fab6dd2

Update toonmage/fluxpipeline.py

Browse files
Files changed (1) hide show
  1. toonmage/fluxpipeline.py +3 -3
toonmage/fluxpipeline.py CHANGED
@@ -27,7 +27,7 @@ class ToonMagePipeline(nn.Module):
27
  single_interval = 4
28
 
29
  # init encoder
30
- self.toonmage_encoder = IDFormer().to(self.device, self.weight_dtype)
31
 
32
  num_ca = 19 // double_interval + 38 // single_interval
33
  if 19 % double_interval != 0:
@@ -174,7 +174,7 @@ class ToonMagePipeline(nn.Module):
174
 
175
  id_cond = torch.cat([id_ante_embedding, id_cond_vit], dim=-1)
176
 
177
- id_embedding = self.toonmage_encoder(id_cond, id_vit_hidden)
178
 
179
  if not cal_uncond:
180
  return id_embedding, None
@@ -183,6 +183,6 @@ class ToonMagePipeline(nn.Module):
183
  id_vit_hidden_uncond = []
184
  for layer_idx in range(0, len(id_vit_hidden)):
185
  id_vit_hidden_uncond.append(torch.zeros_like(id_vit_hidden[layer_idx]))
186
- uncond_id_embedding = self.toonmage_encoder(id_uncond, id_vit_hidden_uncond)
187
 
188
  return id_embedding, uncond_id_embedding
 
27
  single_interval = 4
28
 
29
  # init encoder
30
+ self.pulid_encoder = IDFormer().to(self.device, self.weight_dtype)
31
 
32
  num_ca = 19 // double_interval + 38 // single_interval
33
  if 19 % double_interval != 0:
 
174
 
175
  id_cond = torch.cat([id_ante_embedding, id_cond_vit], dim=-1)
176
 
177
+ id_embedding = self.pulid_encoder(id_cond, id_vit_hidden)
178
 
179
  if not cal_uncond:
180
  return id_embedding, None
 
183
  id_vit_hidden_uncond = []
184
  for layer_idx in range(0, len(id_vit_hidden)):
185
  id_vit_hidden_uncond.append(torch.zeros_like(id_vit_hidden[layer_idx]))
186
+ uncond_id_embedding = self.pulid_encoder(id_uncond, id_vit_hidden_uncond)
187
 
188
  return id_embedding, uncond_id_embedding