wondervictor committed (verified)
Commit fbb8b6f · Parent: c6196a6

Update autoregressive/models/gpt_t2i.py

Files changed (1):
  1. autoregressive/models/gpt_t2i.py +10 -3
autoregressive/models/gpt_t2i.py CHANGED
@@ -367,6 +367,7 @@ class Transformer(nn.Module):
         self.mask = get_causal_mask(256)
         self.global_token = None
 
+        self.control_strength = 1
 
     def initialize_weights(self):
         # Initialize nn.Linear and nn.Embedding
@@ -411,7 +412,8 @@ class Transformer(nn.Module):
         targets: Optional[torch.Tensor] = None,
         mask: Optional[torch.Tensor] = None,
         valid: Optional[torch.Tensor] = None,
-        condition: Optional[torch.Tensor] = None
+        condition: Optional[torch.Tensor] = None,
+        control_strength: Optional[int] = 1
     ):
         if idx is not None and cond_idx is not None: # training or naive inference
             cond_embeddings,drop_ids = self.cls_embedding(cond_idx, train=self.training)
@@ -432,6 +434,9 @@ class Transformer(nn.Module):
             if condition is not None:
                 condition_embeddings = self.condition_mlp(condition,train=self.training)#.to(torch.bfloat16),train=self.training)
                 self.condition_token = condition_embeddings
+                self.condition_token = [self.condition_layer[0](self.condition_token),
+                                        self.condition_layer[1](self.condition_token),
+                                        self.condition_layer[2](self.condition_token)]
 
         else: # decode_n_tokens(kv cache) in inference
             token_embeddings = self.tok_embeddings(idx)
@@ -451,9 +456,11 @@ class Transformer(nn.Module):
                 h[:, self.cls_token_num-1:] = h[:, self.cls_token_num-1:] + self.condition_layers[i//self.layer_internal](self.condition_token)
             else:
                 if len(input_pos)>1:
-                    h[:, -1:] = h[:, -1:] + self.condition_layers[i//self.layer_internal](self.condition_token[:,0:1])
+                    # h[:, -1:] = h[:, -1:] + self.condition_layers[i//self.layer_internal](self.condition_token[:,0:1])
+                    h[:,-1:] = h[:, -1:] + self.control_strength*self.condition_token[i//self.layer_internal][:,0:1]
                 else:
-                    h = h + self.condition_layers[i//self.layer_internal](self.condition_token[:,input_pos-self.cls_token_num+1])
+                    # h = h + self.condition_layers[i//self.layer_internal](self.condition_token[:,input_pos-self.cls_token_num+1])
+                    h = h + self.control_strength*self.condition_token[i//self.layer_internal][:,input_pos-self.cls_token_num+1]
             h = layer(h, freqs_cis, input_pos, mask)
         # output layers
         h = self.norm(h)
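In short, the commit pre-projects the condition tokens once through self.condition_layer[0..2] instead of re-running a projection inside the per-layer loop, and scales the injected tokens by a new control_strength knob during KV-cache decoding. The following is a minimal sketch of that pattern, not the repo's actual code; the class name ConditionInjector and the dim, n_layers, and layer_internal values are illustrative assumptions.

import torch
import torch.nn as nn

class ConditionInjector(nn.Module):
    """Sketch of the scaled condition-injection pattern in this commit.
    All sizes below are assumptions for illustration, not the repo's config."""

    def __init__(self, dim: int = 64, n_layers: int = 6, layer_internal: int = 2):
        super().__init__()
        self.layer_internal = layer_internal
        # One projection per group of `layer_internal` transformer layers,
        # mirroring self.condition_layer[0..2] in the diff.
        self.condition_layer = nn.ModuleList(
            [nn.Linear(dim, dim) for _ in range(n_layers // layer_internal)]
        )
        self.control_strength = 1.0  # corresponds to self.control_strength

    def precompute(self, condition_embeddings: torch.Tensor) -> list:
        # Project the condition tokens once per layer group, as the commit
        # now does in forward() rather than inside the decode loop.
        return [proj(condition_embeddings) for proj in self.condition_layer]

    def inject(self, h: torch.Tensor, condition_token: list,
               layer_idx: int, pos: int) -> torch.Tensor:
        # During KV-cache decoding, add the pre-projected condition token for
        # the current position, scaled by control_strength.
        group = layer_idx // self.layer_internal
        return h + self.control_strength * condition_token[group][:, pos:pos + 1]

# Usage: precompute once per image, then reuse the scaled tokens at every layer.
inj = ConditionInjector()
cond = torch.randn(2, 256, 64)   # (batch, condition tokens, dim)
tokens = inj.precompute(cond)
h = torch.randn(2, 1, 64)        # hidden state for one decoded token
h = inj.inject(h, tokens, layer_idx=3, pos=0)
print(h.shape)                   # torch.Size([2, 1, 64])

Precomputing the per-group projections moves work out of the per-token decode loop, and control_strength gives a runtime dial over how strongly the spatial condition steers generation without retraining.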