Spaces:
Running
on
A100
Running
on
A100
transformer3d: init mode xora never happens because lower case needed.
Browse files
xora/models/transformers/transformer3d.py
CHANGED
@@ -186,14 +186,14 @@ class Transformer3DModel(ModelMixin, ConfigMixin):
|
|
186 |
|
187 |
# Zero-out adaLN modulation layers in PixArt blocks:
|
188 |
for block in self.transformer_blocks:
|
189 |
-
if mode == "xora":
|
190 |
nn.init.constant_(block.attn1.to_out[0].weight, 0)
|
191 |
nn.init.constant_(block.attn1.to_out[0].bias, 0)
|
192 |
|
193 |
nn.init.constant_(block.attn2.to_out[0].weight, 0)
|
194 |
nn.init.constant_(block.attn2.to_out[0].bias, 0)
|
195 |
|
196 |
-
if mode == "xora":
|
197 |
nn.init.constant_(block.ff.net[2].weight, 0)
|
198 |
nn.init.constant_(block.ff.net[2].bias, 0)
|
199 |
|
|
|
186 |
|
187 |
# Zero-out adaLN modulation layers in PixArt blocks:
|
188 |
for block in self.transformer_blocks:
|
189 |
+
if mode.lower() == "xora":
|
190 |
nn.init.constant_(block.attn1.to_out[0].weight, 0)
|
191 |
nn.init.constant_(block.attn1.to_out[0].bias, 0)
|
192 |
|
193 |
nn.init.constant_(block.attn2.to_out[0].weight, 0)
|
194 |
nn.init.constant_(block.attn2.to_out[0].bias, 0)
|
195 |
|
196 |
+
if mode.lower() == "xora":
|
197 |
nn.init.constant_(block.ff.net[2].weight, 0)
|
198 |
nn.init.constant_(block.ff.net[2].bias, 0)
|
199 |
|