Spaces:
Running
Running
fix: causal_mask based on image tokens
Browse files
dalle_mini/model/modeling.py
CHANGED
@@ -52,7 +52,7 @@ logger = logging.get_logger(__name__)
|
|
52 |
class FlaxBartAttention(FlaxBartAttention):
|
53 |
"""
|
54 |
Edits:
|
55 |
-
- causal mask considers
|
56 |
"""
|
57 |
|
58 |
def setup(self) -> None:
|
@@ -77,8 +77,9 @@ class FlaxBartAttention(FlaxBartAttention):
|
|
77 |
self.dropout_layer = nn.Dropout(rate=self.dropout)
|
78 |
|
79 |
if self.causal:
|
|
|
80 |
self.causal_mask = make_causal_mask(
|
81 |
-
jnp.ones((1, self.
|
82 |
)
|
83 |
|
84 |
|
|
|
52 |
class FlaxBartAttention(FlaxBartAttention):
|
53 |
"""
|
54 |
Edits:
|
55 |
+
- causal mask is used only in decoder and considers image_length + 1 (for BOS)
|
56 |
"""
|
57 |
|
58 |
def setup(self) -> None:
|
|
|
77 |
self.dropout_layer = nn.Dropout(rate=self.dropout)
|
78 |
|
79 |
if self.causal:
|
80 |
+
# used only in decoder
|
81 |
self.causal_mask = make_causal_mask(
|
82 |
+
jnp.ones((1, self.config.image_length + 1), dtype="bool"), dtype="bool"
|
83 |
)
|
84 |
|
85 |
|