Spaces: Running on Zero
PommesPeter committed: Update models/model.py

models/model.py  (+16 -19)  CHANGED
@@ -14,8 +14,8 @@ import torch.nn.functional as F
 logger = logging.getLogger(__name__)


-def modulate(x, shift, scale):
-    return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
+def modulate(x, scale):
+    return x * (1 + scale.unsqueeze(1))


 #############################################################################
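Note (annotation, not part of the diff): the commit switches `modulate` from the DiT-style scale-and-shift to a scale-only form. A minimal, self-contained sketch of the two behaviours, with purely illustrative tensor sizes:

import torch

def modulate_old(x, shift, scale):
    # previous form: multiplicative scale plus additive shift (DiT-style adaLN)
    return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)

def modulate_new(x, scale):
    # updated form: multiplicative scale only
    return x * (1 + scale.unsqueeze(1))

B, N, D = 2, 16, 8                 # hypothetical batch, tokens, width
x = torch.randn(B, N, D)           # token features
scale = torch.randn(B, D)          # per-sample modulation vector
zero_shift = torch.zeros(B, D)

# The new form equals the old form with the shift fixed to zero.
assert torch.allclose(modulate_new(x, scale), modulate_old(x, zero_shift, scale))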
@@ -533,16 +533,17 @@ class TransformerBlock(nn.Module):
             ffn_dim_multiplier=ffn_dim_multiplier,
         )
         self.layer_id = layer_id
-        self.attention_norm = RMSNorm(dim, eps=norm_eps)
         self.attention_norm1 = RMSNorm(dim, eps=norm_eps)
-        self.ffn_norm = RMSNorm(dim, eps=norm_eps)
+        self.attention_norm2 = RMSNorm(dim, eps=norm_eps)
+
         self.ffn_norm1 = RMSNorm(dim, eps=norm_eps)
+        self.ffn_norm2 = RMSNorm(dim, eps=norm_eps)

         self.adaLN_modulation = nn.Sequential(
             nn.SiLU(),
             nn.Linear(
                 min(dim, 1024),
-                6 * dim,
+                4 * dim,
                 bias=True,
             ),
         )
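Note (annotation, not part of the diff): with four modulation vectors (`scale_msa`, `gate_msa`, `scale_mlp`, `gate_mlp`) each of width `dim`, the modulation MLP has to project the conditioning vector to `4 * dim` so that `.chunk(4, dim=1)` splits cleanly. The shape check below is a standalone sketch under that assumption, with illustrative sizes only:

import torch
from torch import nn

dim = 64                                       # hypothetical model width
cond_dim = min(dim, 1024)                      # width of adaln_input in this sketch
B = 2

adaLN_modulation = nn.Sequential(
    nn.SiLU(),
    nn.Linear(cond_dim, 4 * dim, bias=True),   # 4 * dim -> four dim-wide chunks
)

adaln_input = torch.randn(B, cond_dim)
scale_msa, gate_msa, scale_mlp, gate_mlp = adaLN_modulation(adaln_input).chunk(4, dim=1)
assert scale_msa.shape == gate_msa.shape == scale_mlp.shape == gate_mlp.shape == (B, dim)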
@@ -571,14 +572,11 @@ class TransformerBlock(nn.Module):

         """
         if adaln_input is not None:
-            shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (
-                self.adaLN_modulation(adaln_input).chunk(6, dim=1)
-            )
+            scale_msa, gate_msa, scale_mlp, gate_mlp = self.adaLN_modulation(adaln_input).chunk(4, dim=1)

-            x = x + self.attention_norm1(
-                gate_msa.unsqueeze(1)
-                * self.attention(
-                    modulate(self.attention_norm(x), shift_msa, scale_msa),
+            x = x + gate_msa.unsqueeze(1).tanh() * self.attention_norm2(
+                self.attention(
+                    modulate(self.attention_norm1(x), scale_msa),
                     x_mask,
                     freqs_cis,
                     self.attention_y_norm(y),
@@ -586,10 +584,9 @@ class TransformerBlock(nn.Module):
                 )
             )
             d = x.shape[-1]
-            x = x + self.ffn_norm1(
-                gate_mlp.unsqueeze(1)
-                * self.feed_forward(
-                    modulate(self.ffn_norm(x), shift_mlp, scale_mlp).view(-1, d),
+            x = x + gate_mlp.unsqueeze(1).tanh() * self.ffn_norm2(
+                self.feed_forward(
+                    modulate(self.ffn_norm1(x), scale_mlp).view(-1, d),
                 ).view(*x.shape)
             )

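Note (annotation, not part of the diff): the two hunks above replace the shift/scale/gate update with a scale-modulated input norm and a tanh-gated output norm around each sub-layer (a sandwich-norm arrangement). The toy block below sketches only that wiring; `nn.LayerNorm` stands in for the model's `RMSNorm` and a plain `nn.Linear` stands in for the real attention, so the names and sizes here are illustrative, not the model's API:

import torch
from torch import nn

def modulate(x, scale):
    return x * (1 + scale.unsqueeze(1))

class GatedSandwichBlock(nn.Module):
    # Toy residual update mirroring the committed pattern:
    #   x = x + gate.tanh() * post_norm(sublayer(modulate(pre_norm(x), scale)))
    def __init__(self, dim, cond_dim):
        super().__init__()
        self.attention_norm1 = nn.LayerNorm(dim)   # pre-norm on the sub-layer input
        self.attention_norm2 = nn.LayerNorm(dim)   # post-norm on the sub-layer output
        self.attention = nn.Linear(dim, dim)       # placeholder for the real attention
        self.adaLN_modulation = nn.Sequential(
            nn.SiLU(),
            nn.Linear(cond_dim, 2 * dim, bias=True),   # one scale + one gate in this toy
        )

    def forward(self, x, adaln_input):
        scale_msa, gate_msa = self.adaLN_modulation(adaln_input).chunk(2, dim=1)
        return x + gate_msa.unsqueeze(1).tanh() * self.attention_norm2(
            self.attention(modulate(self.attention_norm1(x), scale_msa))
        )

block = GatedSandwichBlock(dim=64, cond_dim=64)
x = torch.randn(2, 16, 64)
cond = torch.randn(2, 64)
assert block(x, cond).shape == x.shape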
@@ -633,14 +630,14 @@ class ParallelFinalLayer(nn.Module):
             nn.SiLU(),
             nn.Linear(
                 min(hidden_size, 1024),
-                2 * hidden_size,
+                hidden_size,
                 bias=True,
             ),
         )

     def forward(self, x, c):
-        shift, scale = self.adaLN_modulation(c).chunk(2, dim=1)
-        x = modulate(self.norm_final(x), shift, scale)
+        scale = self.adaLN_modulation(c)
+        x = modulate(self.norm_final(x), scale)
         x = self.linear(x)
         return x

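Note (annotation, not part of the diff): the final-layer hunk drops the shift/scale chunking and uses the modulation MLP output directly as a single scale. The sketch below mirrors only that flow; the placeholder sizes, the `nn.LayerNorm`, and the linear output width are assumptions for illustration, not the real `ParallelFinalLayer` definition:

import torch
from torch import nn

def modulate(x, scale):
    return x * (1 + scale.unsqueeze(1))

hidden_size, out_dim, B, N = 64, 32, 2, 16     # illustrative sizes only

norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False)
linear = nn.Linear(hidden_size, out_dim, bias=True)
adaLN_modulation = nn.Sequential(
    nn.SiLU(),
    nn.Linear(min(hidden_size, 1024), hidden_size, bias=True),
)

x = torch.randn(B, N, hidden_size)
c = torch.randn(B, min(hidden_size, 1024))

scale = adaLN_modulation(c)                    # one scale vector per sample, no shift
x = modulate(norm_final(x), scale)
x = linear(x)
assert x.shape == (B, N, out_dim)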