wondervictor commited on
Commit
ef02327
·
verified ·
1 Parent(s): 29ec12a

Update autoregressive/models/gpt_t2i.py

Browse files
Files changed (1) hide show
  1. autoregressive/models/gpt_t2i.py +1 -0
autoregressive/models/gpt_t2i.py CHANGED
@@ -429,6 +429,7 @@ class Transformer(nn.Module):
429
  self.freqs_cis = self.freqs_cis.to(h.device)
430
  else:
431
  if cond_idx is not None: # prefill in inference
 
432
  token_embeddings = self.cls_embedding(cond_idx, train=self.training)
433
  token_embeddings = token_embeddings[:,:self.cls_token_num]
434
  if condition is not None:
 
429
  self.freqs_cis = self.freqs_cis.to(h.device)
430
  else:
431
  if cond_idx is not None: # prefill in inference
432
+ self.control_strength = control_strength
433
  token_embeddings = self.cls_embedding(cond_idx, train=self.training)
434
  token_embeddings = token_embeddings[:,:self.cls_token_num]
435
  if condition is not None: