PommesPeter committed (verified)
Commit 15222c4 · Parent(s): 8a566c0

Update models/model.py

Files changed (1): models/model.py (+16 -19)
models/model.py CHANGED

@@ -14,8 +14,8 @@ import torch.nn.functional as F
 logger = logging.getLogger(__name__)
 
 
-def modulate(x, shift, scale):
-    return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
+def modulate(x, scale):
+    return x * (1 + scale.unsqueeze(1))
 
 
 #############################################################################
@@ -533,16 +533,17 @@ class TransformerBlock(nn.Module):
             ffn_dim_multiplier=ffn_dim_multiplier,
         )
         self.layer_id = layer_id
-        self.attention_norm = RMSNorm(dim, eps=norm_eps)
         self.attention_norm1 = RMSNorm(dim, eps=norm_eps)
-        self.ffn_norm = RMSNorm(dim, eps=norm_eps)
+        self.attention_norm2 = RMSNorm(dim, eps=norm_eps)
+
         self.ffn_norm1 = RMSNorm(dim, eps=norm_eps)
+        self.ffn_norm2 = RMSNorm(dim, eps=norm_eps)
 
         self.adaLN_modulation = nn.Sequential(
             nn.SiLU(),
             nn.Linear(
                 min(dim, 1024),
-                6 * dim,
+                4 * dim,
                 bias=True,
             ),
         )
@@ -571,14 +572,11 @@ class TransformerBlock(nn.Module):
 
         """
         if adaln_input is not None:
-            shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (
-                self.adaLN_modulation(adaln_input).chunk(6, dim=1)
-            )
+            scale_msa, gate_msa, scale_mlp, gate_mlp = self.adaLN_modulation(adaln_input).chunk(4, dim=1)
 
-            x = x + self.attention_norm1(
-                gate_msa.unsqueeze(1)
-                * self.attention(
-                    modulate(self.attention_norm(x), shift_msa, scale_msa),
+            x = x + gate_msa.unsqueeze(1).tanh() * self.attention_norm2(
+                self.attention(
+                    modulate(self.attention_norm1(x), scale_msa),
                     x_mask,
                     freqs_cis,
                     self.attention_y_norm(y),
@@ -586,10 +584,9 @@ class TransformerBlock(nn.Module):
                 )
             )
             d = x.shape[-1]
-            x = x + self.ffn_norm1(
-                gate_mlp.unsqueeze(1)
-                * self.feed_forward(
-                    modulate(self.ffn_norm(x), shift_mlp, scale_mlp).view(-1, d),
+            x = x + gate_mlp.unsqueeze(1).tanh() * self.ffn_norm2(
+                self.feed_forward(
+                    modulate(self.ffn_norm1(x), scale_mlp).view(-1, d),
                 ).view(*x.shape)
             )
 
@@ -633,14 +630,14 @@ class ParallelFinalLayer(nn.Module):
             nn.SiLU(),
             nn.Linear(
                 min(hidden_size, 1024),
-                2 * hidden_size,
+                hidden_size,
                 bias=True,
             ),
         )
 
     def forward(self, x, c):
-        shift, scale = self.adaLN_modulation(c).chunk(2, dim=1)
-        x = modulate(self.norm_final(x), shift, scale)
+        scale = self.adaLN_modulation(c)
+        x = modulate(self.norm_final(x), scale)
         x = self.linear(x)
         return x
 
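For reference, the block-level change replaces DiT-style shift/scale/gate modulation with scale-only modulation, tanh-clamped gates, and a second ("sandwich") RMSNorm applied to each sub-layer's output. Below is a minimal, self-contained sketch of that pattern; `SimpleRMSNorm` and the `nn.Linear` stand-ins for `attention` / `feed_forward` are illustrative placeholders, not the modules defined in models/model.py.

```python
import torch
import torch.nn as nn


class SimpleRMSNorm(nn.Module):
    """Placeholder RMSNorm (no learnable weight), standing in for the repo's RMSNorm."""

    def __init__(self, dim: int, eps: float = 1e-5):
        super().__init__()
        self.eps = eps

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)


def modulate(x, scale):
    # Scale-only modulation, matching the updated helper: no shift term.
    return x * (1 + scale.unsqueeze(1))


class GatedSandwichBlock(nn.Module):
    """Illustrative block: x + tanh(gate) * norm2(sublayer(modulate(norm1(x), scale)))."""

    def __init__(self, dim: int, cond_dim: int):
        super().__init__()
        self.attention = nn.Linear(dim, dim)     # placeholder for the real attention sub-layer
        self.feed_forward = nn.Linear(dim, dim)  # placeholder for the real feed-forward sub-layer
        self.attention_norm1 = SimpleRMSNorm(dim)
        self.attention_norm2 = SimpleRMSNorm(dim)
        self.ffn_norm1 = SimpleRMSNorm(dim)
        self.ffn_norm2 = SimpleRMSNorm(dim)
        # Four chunks (scale_msa, gate_msa, scale_mlp, gate_mlp) => 4 * dim outputs.
        self.adaLN_modulation = nn.Sequential(
            nn.SiLU(),
            nn.Linear(cond_dim, 4 * dim, bias=True),
        )

    def forward(self, x, adaln_input):
        scale_msa, gate_msa, scale_mlp, gate_mlp = self.adaLN_modulation(adaln_input).chunk(4, dim=1)
        # Pre-norm with scale-only modulation going in, RMSNorm plus a tanh-bounded gate coming out.
        x = x + gate_msa.unsqueeze(1).tanh() * self.attention_norm2(
            self.attention(modulate(self.attention_norm1(x), scale_msa))
        )
        x = x + gate_mlp.unsqueeze(1).tanh() * self.ffn_norm2(
            self.feed_forward(modulate(self.ffn_norm1(x), scale_mlp))
        )
        return x


block = GatedSandwichBlock(dim=64, cond_dim=64)
out = block(torch.randn(2, 16, 64), torch.randn(2, 64))  # (batch, seq_len, dim), (batch, cond_dim)
print(out.shape)  # torch.Size([2, 16, 64])
```

The tanh keeps each gate in (-1, 1), so a residual branch can be attenuated but never amplified by the gate itself.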
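The final-layer change follows the same scheme: the conditioning MLP now emits a single scale (hidden_size outputs instead of 2 * hidden_size) and modulate is called without a shift. A short sketch under the same caveats, with nn.LayerNorm standing in for the layer's norm_final:

```python
import torch
import torch.nn as nn


def modulate(x, scale):
    # Same scale-only helper as in the sketch above.
    return x * (1 + scale.unsqueeze(1))


class ScaleOnlyFinalLayer(nn.Module):
    """Illustrative final layer: scale-only adaLN followed by the output projection."""

    def __init__(self, hidden_size: int, patch_size: int, out_channels: int):
        super().__init__()
        self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False)
        self.linear = nn.Linear(hidden_size, patch_size * patch_size * out_channels)
        self.adaLN_modulation = nn.Sequential(
            nn.SiLU(),
            nn.Linear(min(hidden_size, 1024), hidden_size, bias=True),
        )

    def forward(self, x, c):
        scale = self.adaLN_modulation(c)          # single scale vector, no shift and no chunk
        x = modulate(self.norm_final(x), scale)
        x = self.linear(x)
        return x


layer = ScaleOnlyFinalLayer(hidden_size=64, patch_size=2, out_channels=4)
print(layer(torch.randn(2, 16, 64), torch.randn(2, 64)).shape)  # torch.Size([2, 16, 16])
```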