Commit 43ed08d committed by bubbliiiing
1 Parent(s): 08038f7
add requirements
Files changed:
- easyanimate/vae/ldm/models/__init__.py (+0, -0)
- easyanimate/vae/ldm/models/autoencoder.py (+337, -0)
- easyanimate/vae/ldm/models/enc_dec_pytorch.py (+234, -0)
- easyanimate/vae/ldm/models/omnigen_casual3dcnn.py (+321, -0)
- easyanimate/vae/ldm/models/omnigen_enc_dec.py (+396, -0)
- easyanimate/video_caption/datasets/put preprocess datasets here.txt (+0, -0)
easyanimate/vae/ldm/models/__init__.py
ADDED
File without changes
easyanimate/vae/ldm/models/autoencoder.py
ADDED
@@ -0,0 +1,337 @@
import time
from contextlib import contextmanager

import pytorch_lightning as pl
import torch
import torch.nn.functional as F

from ..modules.diffusionmodules.model import Decoder, Encoder
from ..modules.distributions.distributions import DiagonalGaussianDistribution
from ..util import instantiate_from_config
from .enc_dec_pytorch import Decoder as Mag_Decoder
from .enc_dec_pytorch import Encoder as Mag_Encoder


class AutoencoderKLMagvit(pl.LightningModule):
    def __init__(self,
                 ddconfig,
                 lossconfig,
                 embed_dim,
                 ckpt_path=None,
                 ignore_keys=[],
                 image_key="image",
                 colorize_nlabels=None,
                 monitor=None,
                 ):
        super().__init__()
        self.image_key = image_key
        self.encoder = Mag_Encoder()
        self.decoder = Mag_Decoder()
        self.loss = instantiate_from_config(lossconfig)
        self.quant_conv = torch.nn.Conv3d(16, 16, 1)
        self.post_quant_conv = torch.nn.Conv3d(8, 8, 1)
        self.embed_dim = embed_dim
        if colorize_nlabels is not None:
            assert type(colorize_nlabels) == int
            self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1))
        if monitor is not None:
            self.monitor = monitor
        if ckpt_path is not None:
            self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)

    def init_from_ckpt(self, path, ignore_keys=list()):
        sd = torch.load(path, map_location="cpu")["state_dict"]
        keys = list(sd.keys())
        for k in keys:
            for ik in ignore_keys:
                if k.startswith(ik):
                    print("Deleting key {} from state_dict.".format(k))
                    del sd[k]
        self.load_state_dict(sd, strict=False)
        print(f"Restored from {path}")

    def encode(self, x):
        h = self.encoder(x)
        moments = self.quant_conv(h)
        posterior = DiagonalGaussianDistribution(moments)
        return posterior

    def decode(self, z):
        z = self.post_quant_conv(z)
        dec = self.decoder(z)
        return dec

    def forward(self, input, sample_posterior=True):
        if input.ndim == 4:
            input = input.unsqueeze(2)
        posterior = self.encode(input)
        if sample_posterior:
            z = posterior.sample()
        else:
            z = posterior.mode()
        dec = self.decode(z)
        return dec, posterior

    def get_input(self, batch, k):
        x = batch[k]
        if x.ndim == 5:
            x = x.permute(0, 4, 1, 2, 3).to(memory_format=torch.contiguous_format).float()
            return x
        if len(x.shape) == 3:
            x = x[..., None]
        x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format).float()
        return x

    def training_step(self, batch, batch_idx, optimizer_idx):
        # tic = time.time()
        inputs = self.get_input(batch, self.image_key)
        # print(f"get_input time {time.time() - tic}")
        # tic = time.time()
        reconstructions, posterior = self(inputs)
        # print(f"model forward time {time.time() - tic}")

        if optimizer_idx == 0:
            # train encoder+decoder+logvar
            aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, optimizer_idx, self.global_step,
                                            last_layer=self.get_last_layer(), split="train")
            self.log("aeloss", aeloss, prog_bar=True, logger=True, on_step=True, on_epoch=True)
            self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=False)
            # print(f"cal loss time {time.time() - tic}")
            return aeloss

        if optimizer_idx == 1:
            # train the discriminator
            discloss, log_dict_disc = self.loss(inputs, reconstructions, posterior, optimizer_idx, self.global_step,
                                                last_layer=self.get_last_layer(), split="train")

            self.log("discloss", discloss, prog_bar=True, logger=True, on_step=True, on_epoch=True)
            self.log_dict(log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=False)
            # print(f"cal loss time {time.time() - tic}")
            return discloss

    def validation_step(self, batch, batch_idx):
        with torch.no_grad():
            inputs = self.get_input(batch, self.image_key)
            reconstructions, posterior = self(inputs)
            aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, 0, self.global_step,
                                            last_layer=self.get_last_layer(), split="val")

            discloss, log_dict_disc = self.loss(inputs, reconstructions, posterior, 1, self.global_step,
                                                last_layer=self.get_last_layer(), split="val")

            self.log("val/rec_loss", log_dict_ae["val/rec_loss"])
            self.log_dict(log_dict_ae)
            self.log_dict(log_dict_disc)
            return self.log_dict

    def configure_optimizers(self):
        lr = self.learning_rate
        opt_ae = torch.optim.Adam(list(self.encoder.parameters()) +
                                  list(self.decoder.parameters()) +
                                  list(self.quant_conv.parameters()) +
                                  list(self.post_quant_conv.parameters()),
                                  lr=lr, betas=(0.5, 0.9))
        opt_disc = torch.optim.Adam(self.loss.discriminator.parameters(),
                                    lr=lr, betas=(0.5, 0.9))
        return [opt_ae, opt_disc], []

    def get_last_layer(self):
        return self.decoder.conv_out.weight

    @torch.no_grad()
    def log_images(self, batch, only_inputs=False, **kwargs):
        log = dict()
        x = self.get_input(batch, self.image_key)
        x = x.to(self.device)
        if not only_inputs:
            xrec, posterior = self(x)
            if x.shape[1] > 3:
                # colorize with random projection
                assert xrec.shape[1] > 3
                x = self.to_rgb(x)
                xrec = self.to_rgb(xrec)
            log["samples"] = self.decode(torch.randn_like(posterior.sample()))
            log["reconstructions"] = xrec
        log["inputs"] = x
        return log

    def to_rgb(self, x):
        assert self.image_key == "segmentation"
        if not hasattr(self, "colorize"):
            self.register_buffer("colorize", torch.randn(3, x.shape[1], 1, 1).to(x))
        x = F.conv2d(x, weight=self.colorize)
        x = 2. * (x - x.min()) / (x.max() - x.min()) - 1.
        return x


class AutoencoderKL(pl.LightningModule):
    def __init__(self,
                 ddconfig,
                 lossconfig,
                 embed_dim,
                 ckpt_path=None,
                 ignore_keys=[],
                 image_key="image",
                 colorize_nlabels=None,
                 monitor=None,
                 ):
        super().__init__()
        self.image_key = image_key
        self.encoder = Encoder(**ddconfig)
        self.decoder = Decoder(**ddconfig)
        self.loss = instantiate_from_config(lossconfig)
        assert ddconfig["double_z"]
        self.quant_conv = torch.nn.Conv2d(2 * ddconfig["z_channels"], 2 * embed_dim, 1)
        self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1)
        self.embed_dim = embed_dim
        if colorize_nlabels is not None:
            assert type(colorize_nlabels) == int
            self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1))
        if monitor is not None:
            self.monitor = monitor
        if ckpt_path is not None:
            self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)

    def init_from_ckpt(self, path, ignore_keys=list()):
        sd = torch.load(path, map_location="cpu")["state_dict"]
        keys = list(sd.keys())
        for k in keys:
            for ik in ignore_keys:
                if k.startswith(ik):
                    print("Deleting key {} from state_dict.".format(k))
                    del sd[k]
        self.load_state_dict(sd, strict=False)
        print(f"Restored from {path}")

    def encode(self, x):
        h = self.encoder(x)
        moments = self.quant_conv(h)
        posterior = DiagonalGaussianDistribution(moments)
        return posterior

    def decode(self, z):
        z = self.post_quant_conv(z)
        dec = self.decoder(z)
        return dec

    def forward(self, input, sample_posterior=True):
        posterior = self.encode(input)
        if sample_posterior:
            z = posterior.sample()
        else:
            z = posterior.mode()
        dec = self.decode(z)
        return dec, posterior

    def get_input(self, batch, k):
        x = batch[k]
        if len(x.shape) == 3:
            x = x[..., None]
        x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format).float()
        return x

    def training_step(self, batch, batch_idx, optimizer_idx):
        # tic = time.time()
        inputs = self.get_input(batch, self.image_key)
        # print(f"get_input time {time.time() - tic}")
        # tic = time.time()
        reconstructions, posterior = self(inputs)
        # print(f"model forward time {time.time() - tic}")
        tic = time.time()

        if optimizer_idx == 0:
            # train encoder+decoder+logvar
            aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, optimizer_idx, self.global_step,
                                            last_layer=self.get_last_layer(), split="train")
            self.log("aeloss", aeloss, prog_bar=True, logger=True, on_step=True, on_epoch=True)
            self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=False)
            # print(f"cal loss time {time.time() - tic}")
            return aeloss

        if optimizer_idx == 1:
            # train the discriminator
            discloss, log_dict_disc = self.loss(inputs, reconstructions, posterior, optimizer_idx, self.global_step,
                                                last_layer=self.get_last_layer(), split="train")

            self.log("discloss", discloss, prog_bar=True, logger=True, on_step=True, on_epoch=True)
            self.log_dict(log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=False)
            # print(f"cal loss time {time.time() - tic}")
            return discloss

    def validation_step(self, batch, batch_idx):
        tic = time.time()
        inputs = self.get_input(batch, self.image_key)
        print(f"get_input time {time.time() - tic}")
        tic = time.time()
        reconstructions, posterior = self(inputs)
        print(f"val forward time {time.time() - tic}")
        tic = time.time()
        aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, 0, self.global_step,
                                        last_layer=self.get_last_layer(), split="val")

        discloss, log_dict_disc = self.loss(inputs, reconstructions, posterior, 1, self.global_step,
                                            last_layer=self.get_last_layer(), split="val")

        self.log("val/rec_loss", log_dict_ae["val/rec_loss"])
        self.log_dict(log_dict_ae)
        self.log_dict(log_dict_disc)
        print(f"val end time {time.time() - tic}")
        return self.log_dict

    def configure_optimizers(self):
        lr = self.learning_rate
        opt_ae = torch.optim.Adam(list(self.encoder.parameters()) +
                                  list(self.decoder.parameters()) +
                                  list(self.quant_conv.parameters()) +
                                  list(self.post_quant_conv.parameters()),
                                  lr=lr, betas=(0.5, 0.9))
        opt_disc = torch.optim.Adam(self.loss.discriminator.parameters(),
                                    lr=lr, betas=(0.5, 0.9))
        return [opt_ae, opt_disc], []

    def get_last_layer(self):
        return self.decoder.conv_out.weight

    @torch.no_grad()
    def log_images(self, batch, only_inputs=False, **kwargs):
        log = dict()
        x = self.get_input(batch, self.image_key)
        x = x.to(self.device)
        if not only_inputs:
            xrec, posterior = self(x)
            if x.shape[1] > 3:
                # colorize with random projection
                assert xrec.shape[1] > 3
                x = self.to_rgb(x)
                xrec = self.to_rgb(xrec)
            log["samples"] = self.decode(torch.randn_like(posterior.sample()))
            log["reconstructions"] = xrec
        log["inputs"] = x
        return log

    def to_rgb(self, x):
        assert self.image_key == "segmentation"
        if not hasattr(self, "colorize"):
            self.register_buffer("colorize", torch.randn(3, x.shape[1], 1, 1).to(x))
        x = F.conv2d(x, weight=self.colorize)
        x = 2. * (x - x.min()) / (x.max() - x.min()) - 1.
        return x


class IdentityFirstStage(torch.nn.Module):
    def __init__(self, *args, vq_interface=False, **kwargs):
        self.vq_interface = vq_interface  # TODO: Should be true by default but check to not break older stuff
        super().__init__()

    def encode(self, x, *args, **kwargs):
        return x

    def decode(self, x, *args, **kwargs):
        return x

    def quantize(self, x, *args, **kwargs):
        if self.vq_interface:
            return x, None, [None, None, None]
        return x

    def forward(self, x, *args, **kwargs):
        return x
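For orientation, AutoencoderKLMagvit above hard-codes a 16-channel moment tensor (8 mean + 8 log-variance channels) between the Magvit encoder and quant_conv, and an 8-channel latent into post_quant_conv. A minimal round-trip sketch of that wiring, assuming enc_dec_pytorch is importable from a flat path and inlining the reparameterized sampling rather than importing ldm's DiagonalGaussianDistribution:

import torch

from enc_dec_pytorch import Decoder, Encoder  # hypothetical flat import path

encoder, decoder = Encoder(), Decoder()
quant_conv = torch.nn.Conv3d(16, 16, 1)
post_quant_conv = torch.nn.Conv3d(8, 8, 1)

x = torch.randn(1, 3, 9, 64, 64)                  # (B, C, T, H, W) video clip
moments = quant_conv(encoder(x))                  # 16 channels = 8 mean + 8 logvar
mean, logvar = moments.chunk(2, dim=1)
z = mean + torch.exp(0.5 * logvar) * torch.randn_like(mean)  # reparameterization trick
rec = decoder(post_quant_conv(z))                 # back to (1, 3, 9, 64, 64)
print(z.shape, rec.shape)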
easyanimate/vae/ldm/models/enc_dec_pytorch.py
ADDED
@@ -0,0 +1,234 @@
import torch
import torch.nn.functional as F
from torch import nn


def cast_tuple(t, length=1):
    return t if isinstance(t, tuple) else ((t,) * length)


def divisible_by(num, den):
    return (num % den) == 0


def is_odd(n):
    return not divisible_by(n, 2)


class CausalConv3d(nn.Module):
    def __init__(
        self,
        chan_in,
        chan_out,
        kernel_size,
        pad_mode='constant',
        **kwargs
    ):
        super().__init__()
        kernel_size = cast_tuple(kernel_size, 3)

        time_kernel_size, height_kernel_size, width_kernel_size = kernel_size

        assert is_odd(height_kernel_size) and is_odd(width_kernel_size)

        dilation = kwargs.pop('dilation', 1)
        stride = kwargs.pop('stride', 1)

        self.pad_mode = pad_mode
        # pad only on the temporal left, so output frame t never sees input frames > t
        time_pad = dilation * (time_kernel_size - 1) + (1 - stride)
        height_pad = height_kernel_size // 2
        width_pad = width_kernel_size // 2

        self.time_pad = time_pad
        self.time_causal_padding = (width_pad, width_pad, height_pad, height_pad, time_pad, 0)

        stride = (stride, 1, 1)
        dilation = (dilation, 1, 1)
        self.conv = nn.Conv3d(chan_in, chan_out, kernel_size, stride=stride, dilation=dilation, **kwargs)

    def forward(self, x):
        x = F.pad(x, self.time_causal_padding, mode='replicate')
        return self.conv(x)


class Swish(nn.Module):
    def __init__(self) -> None:
        super().__init__()

    def forward(self, x):
        return x * torch.sigmoid(x)


class ResBlockX(nn.Module):
    def __init__(self, inchannel) -> None:
        super().__init__()
        self.conv = nn.Sequential(
            nn.GroupNorm(32, inchannel),
            Swish(),
            CausalConv3d(inchannel, inchannel, 3),
            nn.GroupNorm(32, inchannel),
            Swish(),
            CausalConv3d(inchannel, inchannel, 3)
        )

    def forward(self, x):
        return x + self.conv(x)


class ResBlockXY(nn.Module):
    def __init__(self, inchannel, outchannel) -> None:
        super().__init__()
        self.conv = nn.Sequential(
            nn.GroupNorm(32, inchannel),
            Swish(),
            CausalConv3d(inchannel, outchannel, 3),
            nn.GroupNorm(32, outchannel),
            Swish(),
            CausalConv3d(outchannel, outchannel, 3)
        )
        self.conv_1 = nn.Conv3d(inchannel, outchannel, 1)  # 1x1x1 shortcut to match channels

    def forward(self, x):
        return self.conv_1(x) + self.conv(x)


class PoolDown222(nn.Module):
    """Average-pool by 2 in time, height and width; pads one frame on the temporal left."""
    def __init__(self) -> None:
        super().__init__()
        self.pool = nn.AvgPool3d(2, 2)

    def forward(self, x):
        x = F.pad(x, (0, 0, 0, 0, 1, 0), 'replicate')
        return self.pool(x)


class PoolDown122(nn.Module):
    """Average-pool by 2 in height and width only."""
    def __init__(self) -> None:
        super().__init__()
        self.pool = nn.AvgPool3d((1, 2, 2), (1, 2, 2))

    def forward(self, x):
        return self.pool(x)


class Unpool222(nn.Module):
    """Nearest upsample by 2 in time, height and width; drops the first (padded) frame."""
    def __init__(self) -> None:
        super().__init__()
        self.up = nn.Upsample(scale_factor=2, mode='nearest')

    def forward(self, x):
        x = self.up(x)
        return x[:, :, 1:]


class Unpool122(nn.Module):
    """Nearest upsample by 2 in height and width only."""
    def __init__(self) -> None:
        super().__init__()
        self.up = nn.Upsample(scale_factor=(1, 2, 2), mode='nearest')

    def forward(self, x):
        x = self.up(x)
        return x


class ResBlockDown(nn.Module):
    def __init__(self, inchannel, outchannel) -> None:
        super().__init__()
        self.block = nn.Sequential(
            CausalConv3d(inchannel, outchannel, 3),
            nn.LeakyReLU(inplace=True),
            PoolDown222(),
            CausalConv3d(outchannel, outchannel, 3),
            nn.LeakyReLU(inplace=True)
        )
        self.res = nn.Sequential(
            PoolDown222(),
            nn.Conv3d(inchannel, outchannel, 1)
        )

    def forward(self, x):
        return self.res(x) + self.block(x)


class Discriminator(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.block = nn.Sequential(
            CausalConv3d(3, 64, 3),
            nn.LeakyReLU(inplace=True),
            ResBlockDown(64, 128),
            ResBlockDown(128, 256),
            ResBlockDown(256, 256),
            ResBlockDown(256, 256),
            ResBlockDown(256, 256),
            CausalConv3d(256, 256, 3),
            nn.LeakyReLU(inplace=True),
            nn.AdaptiveAvgPool3d(1),
            nn.Flatten(),
            nn.Linear(256, 256),
            nn.LeakyReLU(inplace=True),
            nn.Linear(256, 1)
        )

    def forward(self, x):
        if x.ndim == 4:
            x = x.unsqueeze(2)
        return self.block(x)


class Encoder(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.encoder = nn.Sequential(
            CausalConv3d(3, 64, 3),
            ResBlockX(64),
            ResBlockX(64),
            PoolDown222(),
            ResBlockXY(64, 128),
            ResBlockX(128),
            PoolDown222(),
            ResBlockX(128),
            ResBlockX(128),
            PoolDown122(),
            ResBlockXY(128, 256),
            ResBlockX(256),
            ResBlockX(256),
            ResBlockX(256),
            nn.GroupNorm(32, 256),
            Swish(),
            nn.Conv3d(256, 16, 1)  # 16 channels: 8 mean + 8 logvar
        )

    def forward(self, x):
        return self.encoder(x)


class Decoder(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.decoder = nn.Sequential(
            CausalConv3d(8, 256, 3),
            ResBlockX(256),
            ResBlockX(256),
            ResBlockX(256),
            ResBlockX(256),
            Unpool122(),
            CausalConv3d(256, 256, 3),
            ResBlockXY(256, 128),
            ResBlockX(128),
            Unpool222(),
            CausalConv3d(128, 128, 3),
            ResBlockX(128),
            ResBlockX(128),
            Unpool222(),
            CausalConv3d(128, 128, 3),
            ResBlockXY(128, 64),
            ResBlockX(64),
            nn.GroupNorm(32, 64),
            Swish(),
            CausalConv3d(64, 64, 3)
        )
        self.conv_out = nn.Conv3d(64, 3, 1)

    def forward(self, x):
        return self.conv_out(self.decoder(x))


if __name__ == '__main__':
    encoder = Encoder()
    decoder = Decoder()
    dis = Discriminator()
    x = torch.randn((1, 3, 1, 64, 64))
    embedding = encoder(x)                   # (1, 16, 1, 8, 8): stacked mean and logvar
    mean, logvar = embedding.chunk(2, dim=1)
    y = decoder(mean)                        # decoder expects the 8-channel latent
    score = dis(y)
    print(embedding.shape, y.shape, score.shape)
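A quick property check of the causal padding above (a sketch, assuming the module is importable from a flat path): because CausalConv3d pads only on the temporal left, output frame t depends only on input frames at or before t, so rewriting future frames must leave earlier outputs unchanged.

import torch

from enc_dec_pytorch import CausalConv3d  # hypothetical flat import path

conv = CausalConv3d(3, 3, 3).eval()
a = torch.randn(1, 3, 8, 16, 16)
b = a.clone()
b[:, :, 5:] = torch.randn_like(b[:, :, 5:])  # perturb only the future frames

with torch.no_grad():
    # frames 0..4 agree: the receptive field never leaks into the future
    assert torch.allclose(conv(a)[:, :, :5], conv(b)[:, :, :5], atol=1e-6)
print("causality holds")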
easyanimate/vae/ldm/models/omnigen_casual3dcnn.py
ADDED
@@ -0,0 +1,321 @@
import itertools
from dataclasses import dataclass
from typing import Optional

import numpy as np
import pytorch_lightning as pl
import torch
import torch.nn as nn
import torch.nn.functional as F

from ..util import instantiate_from_config
from .omnigen_enc_dec import Decoder as omnigen_Mag_Decoder
from .omnigen_enc_dec import Encoder as omnigen_Mag_Encoder


class DiagonalGaussianDistribution:
    def __init__(
        self,
        mean: torch.Tensor,
        logvar: torch.Tensor,
        deterministic: bool = False,
    ):
        self.mean = mean
        self.logvar = torch.clamp(logvar, -30.0, 20.0)
        self.deterministic = deterministic

        if deterministic:
            self.var = self.std = torch.zeros_like(self.mean)
        else:
            self.std = torch.exp(0.5 * self.logvar)
            self.var = torch.exp(self.logvar)

    def sample(self, generator=None) -> torch.FloatTensor:
        x = torch.randn(
            self.mean.shape,
            generator=generator,
            device=self.mean.device,
            dtype=self.mean.dtype,
        )
        return self.mean + self.std * x

    def mode(self):
        return self.mean

    def kl(self, other: Optional["DiagonalGaussianDistribution"] = None) -> torch.Tensor:
        dims = list(range(1, self.mean.ndim))

        if self.deterministic:
            return torch.Tensor([0.0])
        else:
            if other is None:
                return 0.5 * torch.sum(
                    torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar,
                    dim=dims,
                )
            else:
                return 0.5 * torch.sum(
                    torch.pow(self.mean - other.mean, 2) / other.var
                    + self.var / other.var
                    - 1.0
                    - self.logvar
                    + other.logvar,
                    dim=dims,
                )

    def nll(self, sample: torch.Tensor) -> torch.Tensor:
        dims = list(range(1, self.mean.ndim))

        if self.deterministic:
            return torch.Tensor([0.0])

        logtwopi = np.log(2.0 * np.pi)
        return 0.5 * torch.sum(
            logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var,
            dim=dims,
        )


@dataclass
class EncoderOutput:
    latent_dist: DiagonalGaussianDistribution


@dataclass
class DecoderOutput:
    sample: torch.Tensor


def str_eval(item):
    if type(item) == str:
        return eval(item)
    else:
        return item


class AutoencoderKLMagvit_fromOmnigen(pl.LightningModule):
    def __init__(
        self,
        in_channels: int = 3,
        out_channels: int = 3,
        ch=128,
        ch_mult=[1, 2, 4, 4],
        use_gc_blocks=None,
        down_block_types: tuple = None,
        up_block_types: tuple = None,
        mid_block_type: str = "MidBlock3D",
        mid_block_use_attention: bool = True,
        mid_block_attention_type: str = "3d",
        mid_block_num_attention_heads: int = 1,
        layers_per_block: int = 2,
        act_fn: str = "silu",
        num_attention_heads: int = 1,
        latent_channels: int = 4,
        norm_num_groups: int = 32,
        image_key="image",
        monitor=None,
        ckpt_path=None,
        lossconfig=None,
        slice_compression_vae=False,
        mini_batch_encoder=9,
        mini_batch_decoder=3,
        train_decoder_only=False,
    ):
        super().__init__()
        self.image_key = image_key
        down_block_types = str_eval(down_block_types)
        up_block_types = str_eval(up_block_types)
        self.encoder = omnigen_Mag_Encoder(
            in_channels=in_channels,
            out_channels=latent_channels,
            down_block_types=down_block_types,
            ch=ch,
            ch_mult=ch_mult,
            use_gc_blocks=use_gc_blocks,
            mid_block_type=mid_block_type,
            mid_block_use_attention=mid_block_use_attention,
            mid_block_attention_type=mid_block_attention_type,
            mid_block_num_attention_heads=mid_block_num_attention_heads,
            layers_per_block=layers_per_block,
            norm_num_groups=norm_num_groups,
            act_fn=act_fn,
            num_attention_heads=num_attention_heads,
            double_z=True,
            slice_compression_vae=slice_compression_vae,
            mini_batch_encoder=mini_batch_encoder,
        )

        self.decoder = omnigen_Mag_Decoder(
            in_channels=latent_channels,
            out_channels=out_channels,
            up_block_types=up_block_types,
            ch=ch,
            ch_mult=ch_mult,
            use_gc_blocks=use_gc_blocks,
            mid_block_type=mid_block_type,
            mid_block_use_attention=mid_block_use_attention,
            mid_block_attention_type=mid_block_attention_type,
            mid_block_num_attention_heads=mid_block_num_attention_heads,
            layers_per_block=layers_per_block,
            norm_num_groups=norm_num_groups,
            act_fn=act_fn,
            num_attention_heads=num_attention_heads,
            slice_compression_vae=slice_compression_vae,
            mini_batch_decoder=mini_batch_decoder,
        )

        self.quant_conv = nn.Conv3d(2 * latent_channels, 2 * latent_channels, kernel_size=1)
        self.post_quant_conv = nn.Conv3d(latent_channels, latent_channels, kernel_size=1)

        self.mini_batch_encoder = mini_batch_encoder
        self.mini_batch_decoder = mini_batch_decoder
        self.train_decoder_only = train_decoder_only
        if train_decoder_only:
            self.encoder.requires_grad_(False)
            self.quant_conv.requires_grad_(False)
        if monitor is not None:
            self.monitor = monitor
        if ckpt_path is not None:
            # drop loss weights from the checkpoint; a list so we match the
            # "loss" prefix rather than iterating the string character by character
            self.init_from_ckpt(ckpt_path, ignore_keys=["loss"])
        if lossconfig is not None:
            self.loss = instantiate_from_config(lossconfig)

    def init_from_ckpt(self, path, ignore_keys=list()):
        if path.endswith("safetensors"):
            from safetensors.torch import load_file, safe_open
            sd = load_file(path)
        else:
            sd = torch.load(path, map_location="cpu")
        if "state_dict" in list(sd.keys()):
            sd = sd["state_dict"]
        keys = list(sd.keys())
        for k in keys:
            for ik in ignore_keys:
                if k.startswith(ik):
                    print("Deleting key {} from state_dict.".format(k))
                    del sd[k]
        self.load_state_dict(sd, strict=False)  # loss.item can be ignored successfully
        print(f"Restored from {path}")

    def encode(self, x: torch.Tensor) -> EncoderOutput:
        h = self.encoder(x)

        moments: torch.Tensor = self.quant_conv(h)
        mean, logvar = moments.chunk(2, dim=1)
        posterior = DiagonalGaussianDistribution(mean, logvar)

        # return EncoderOutput(latent_dist=posterior)
        return posterior

    def decode(self, z: torch.Tensor) -> DecoderOutput:
        z = self.post_quant_conv(z)

        decoded = self.decoder(z)

        # return DecoderOutput(sample=decoded)
        return decoded

    def forward(self, input, sample_posterior=True):
        if input.ndim == 4:
            input = input.unsqueeze(2)
        posterior = self.encode(input)
        if sample_posterior:
            z = posterior.sample()
        else:
            z = posterior.mode()
        # print("stt latent shape", z.shape)
        dec = self.decode(z)
        return dec, posterior

    def get_input(self, batch, k):
        x = batch[k]
        if x.ndim == 5:
            x = x.permute(0, 4, 1, 2, 3).to(memory_format=torch.contiguous_format).float()
            return x
        if len(x.shape) == 3:
            x = x[..., None]
        x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format).float()
        return x

    def training_step(self, batch, batch_idx, optimizer_idx):
        # tic = time.time()
        inputs = self.get_input(batch, self.image_key)
        # print(f"get_input time {time.time() - tic}")
        # tic = time.time()
        reconstructions, posterior = self(inputs)
        # print(f"model forward time {time.time() - tic}")

        if optimizer_idx == 0:
            # train encoder+decoder+logvar
            aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, optimizer_idx, self.global_step,
                                            last_layer=self.get_last_layer(), split="train")
            self.log("aeloss", aeloss, prog_bar=True, logger=True, on_step=True, on_epoch=True)
            self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=False)
            # print(f"cal loss time {time.time() - tic}")
            return aeloss

        if optimizer_idx == 1:
            # train the discriminator
            discloss, log_dict_disc = self.loss(inputs, reconstructions, posterior, optimizer_idx, self.global_step,
                                                last_layer=self.get_last_layer(), split="train")

            self.log("discloss", discloss, prog_bar=True, logger=True, on_step=True, on_epoch=True)
            self.log_dict(log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=False)
            # print(f"cal loss time {time.time() - tic}")
            return discloss

    def validation_step(self, batch, batch_idx):
        with torch.no_grad():
            inputs = self.get_input(batch, self.image_key)
            reconstructions, posterior = self(inputs)
            aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, 0, self.global_step,
                                            last_layer=self.get_last_layer(), split="val")

            discloss, log_dict_disc = self.loss(inputs, reconstructions, posterior, 1, self.global_step,
                                                last_layer=self.get_last_layer(), split="val")

            self.log("val/rec_loss", log_dict_ae["val/rec_loss"])
            self.log_dict(log_dict_ae)
            self.log_dict(log_dict_disc)
            return self.log_dict

    def configure_optimizers(self):
        lr = self.learning_rate
        if self.train_decoder_only:
            opt_ae = torch.optim.Adam(list(self.decoder.parameters()) +
                                      list(self.post_quant_conv.parameters()),
                                      lr=lr, betas=(0.5, 0.9))
        else:
            opt_ae = torch.optim.Adam(list(self.encoder.parameters()) +
                                      list(self.decoder.parameters()) +
                                      list(self.quant_conv.parameters()) +
                                      list(self.post_quant_conv.parameters()),
                                      lr=lr, betas=(0.5, 0.9))
        opt_disc = torch.optim.Adam(list(self.loss.discriminator3d.parameters()) + list(self.loss.discriminator.parameters()),
                                    lr=lr, betas=(0.5, 0.9))
        return [opt_ae, opt_disc], []

    def get_last_layer(self):
        return self.decoder.conv_out.weight

    @torch.no_grad()
    def log_images(self, batch, only_inputs=False, **kwargs):
        log = dict()
        x = self.get_input(batch, self.image_key)
        x = x.to(self.device)
        if not only_inputs:
            xrec, posterior = self(x)
            if x.shape[1] > 3:
                # colorize with random projection
                assert xrec.shape[1] > 3
                x = self.to_rgb(x)
                xrec = self.to_rgb(xrec)
            log["samples"] = self.decode(torch.randn_like(posterior.sample()))
            log["reconstructions"] = xrec
        log["inputs"] = x
        return log

    def to_rgb(self, x):
        assert self.image_key == "segmentation"
        if not hasattr(self, "colorize"):
            self.register_buffer("colorize", torch.randn(3, x.shape[1], 1, 1).to(x))
        x = F.conv2d(x, weight=self.colorize)
        x = 2. * (x - x.min()) / (x.max() - x.min()) - 1.
        return x
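The DiagonalGaussianDistribution defined above takes mean and log-variance as two separate tensors (unlike the ldm version, which takes a single stacked moment tensor). A small sanity sketch of its sampling and KL term, assuming a flat import path:

import torch

from omnigen_casual3dcnn import DiagonalGaussianDistribution  # hypothetical flat import path

mean = torch.zeros(2, 4, 1, 8, 8)      # (B, C, T, H, W) latent moments
logvar = torch.zeros(2, 4, 1, 8, 8)    # log-variance 0 -> unit variance
posterior = DiagonalGaussianDistribution(mean, logvar)

z = posterior.sample()                 # mean + std * eps (reparameterized)
print(z.shape)                         # torch.Size([2, 4, 1, 8, 8])
print(posterior.kl())                  # KL to the standard normal: zeros here
print(posterior.mode().abs().max())    # the mode is the mean: 0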
easyanimate/vae/ldm/models/omnigen_enc_dec.py
ADDED
@@ -0,0 +1,396 @@
import torch
import torch.nn as nn
import numpy as np

from ..modules.vaemodules.activations import get_activation
from ..modules.vaemodules.common import CausalConv3d
from ..modules.vaemodules.down_blocks import get_down_block
from ..modules.vaemodules.mid_blocks import get_mid_block
from ..modules.vaemodules.up_blocks import get_up_block


class Encoder(nn.Module):
    r"""
    The `Encoder` layer of a variational autoencoder that encodes its input into a latent representation.

    Args:
        in_channels (`int`, *optional*, defaults to 3):
            The number of input channels.
        out_channels (`int`, *optional*, defaults to 8):
            The number of output channels.
        down_block_types (`Tuple[str, ...]`, *optional*, defaults to `("SpatialDownBlock3D",)`):
            The types of down blocks to use.
        ch (`int`, *optional*, defaults to 128):
            The base channel count; block `i` outputs `ch * ch_mult[i]` channels.
        ch_mult (`List[int]`, *optional*, defaults to `[1, 2, 4, 4]`):
            The channel multiplier for each block.
        use_gc_blocks (`Tuple[bool, ...]`, *optional*, defaults to `None`):
            Whether to use global context blocks for each down block.
        mid_block_type (`str`, *optional*, defaults to `"MidBlock3D"`):
            The type of mid block to use.
        layers_per_block (`int`, *optional*, defaults to 2):
            The number of layers per block.
        norm_num_groups (`int`, *optional*, defaults to 32):
            The number of groups for normalization.
        act_fn (`str`, *optional*, defaults to `"silu"`):
            The activation function to use. See `~diffusers.models.activations.get_activation` for available options.
        num_attention_heads (`int`, *optional*, defaults to 1):
            The number of attention heads to use.
        double_z (`bool`, *optional*, defaults to `True`):
            Whether to double the number of output channels for the last block.
    """

    def __init__(
        self,
        in_channels: int = 3,
        out_channels: int = 8,
        down_block_types=("SpatialDownBlock3D",),
        ch=128,
        ch_mult=[1, 2, 4, 4],
        use_gc_blocks=None,
        mid_block_type: str = "MidBlock3D",
        mid_block_use_attention: bool = True,
        mid_block_attention_type: str = "3d",
        mid_block_num_attention_heads: int = 1,
        layers_per_block: int = 2,
        norm_num_groups: int = 32,
        act_fn: str = "silu",
        num_attention_heads: int = 1,
        double_z: bool = True,
        slice_compression_vae: bool = False,
        mini_batch_encoder: int = 9,
        verbose=False,
    ):
        super().__init__()
        block_out_channels = [ch * i for i in ch_mult]
        assert len(down_block_types) == len(block_out_channels), (
            "Number of down block types must match number of block output channels."
        )
        if use_gc_blocks is not None:
            assert len(use_gc_blocks) == len(down_block_types), (
                "Number of GC blocks must match number of down block types."
            )
        else:
            use_gc_blocks = [False] * len(down_block_types)
        self.conv_in = CausalConv3d(
            in_channels,
            block_out_channels[0],
            kernel_size=3,
        )

        self.down_blocks = nn.ModuleList([])

        output_channels = block_out_channels[0]
        for i, down_block_type in enumerate(down_block_types):
            input_channels = output_channels
            output_channels = block_out_channels[i]
            is_final_block = (i == len(block_out_channels) - 1)
            down_block = get_down_block(
                down_block_type,
                in_channels=input_channels,
                out_channels=output_channels,
                num_layers=layers_per_block,
                act_fn=act_fn,
                norm_num_groups=norm_num_groups,
                norm_eps=1e-6,
                num_attention_heads=num_attention_heads,
                add_gc_block=use_gc_blocks[i],
                add_downsample=not is_final_block,
            )
            self.down_blocks.append(down_block)

        self.mid_block = get_mid_block(
            mid_block_type,
            in_channels=block_out_channels[-1],
            num_layers=layers_per_block,
            act_fn=act_fn,
            norm_num_groups=norm_num_groups,
            norm_eps=1e-6,
            add_attention=mid_block_use_attention,
            attention_type=mid_block_attention_type,
            num_attention_heads=mid_block_num_attention_heads,
        )

        self.conv_norm_out = nn.GroupNorm(
            num_channels=block_out_channels[-1],
            num_groups=norm_num_groups,
            eps=1e-6,
        )
        self.conv_act = get_activation(act_fn)

        conv_out_channels = 2 * out_channels if double_z else out_channels
        self.conv_out = CausalConv3d(block_out_channels[-1], conv_out_channels, kernel_size=3)

        self.slice_compression_vae = slice_compression_vae
        self.mini_batch_encoder = mini_batch_encoder
        self.features_share = False
        self.verbose = verbose

    def set_padding_one_frame(self):
        def _set_padding_one_frame(name, module):
            if hasattr(module, 'padding_flag'):
                if self.verbose:
                    print('Set pad mode for module[%s] type=%s' % (name, str(type(module))))
                module.padding_flag = 1
            for sub_name, sub_mod in module.named_children():
                _set_padding_one_frame(sub_name, sub_mod)
        for name, module in self.named_children():
            _set_padding_one_frame(name, module)

    def set_padding_more_frame(self):
        def _set_padding_more_frame(name, module):
            if hasattr(module, 'padding_flag'):
                if self.verbose:
                    print('Set pad mode for module[%s] type=%s' % (name, str(type(module))))
                module.padding_flag = 2
            for sub_name, sub_mod in module.named_children():
                _set_padding_more_frame(sub_name, sub_mod)
        for name, module in self.named_children():
            _set_padding_more_frame(name, module)

    def single_forward(self, x: torch.Tensor, previous_features: torch.Tensor, after_features: torch.Tensor) -> torch.Tensor:
        # x: (B, C, T, H, W)
        if self.features_share and previous_features is not None and after_features is None:
            x = torch.concat([previous_features, x], 2)
        elif self.features_share and previous_features is None and after_features is not None:
            x = torch.concat([x, after_features], 2)
        elif self.features_share and previous_features is not None and after_features is not None:
            x = torch.concat([previous_features, x, after_features], 2)

        x = self.conv_in(x)

        for down_block in self.down_blocks:
            x = down_block(x)

        x = self.mid_block(x)

        x = self.conv_norm_out(x)
        x = self.conv_act(x)
        x = self.conv_out(x)

        if self.features_share and previous_features is not None and after_features is None:
            x = x[:, :, 1:]
        elif self.features_share and previous_features is None and after_features is not None:
            x = x[:, :, :2]
        elif self.features_share and previous_features is not None and after_features is not None:
            x = x[:, :, 1:3]
        return x

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if self.slice_compression_vae:
            _, _, f, _, _ = x.size()
            if f % 2 != 0:
                self.set_padding_one_frame()
                first_frames = self.single_forward(x[:, :, 0:1, :, :], None, None)
                self.set_padding_more_frame()

                new_pixel_values = [first_frames]
                start_index = 1
            else:
                self.set_padding_more_frame()
                new_pixel_values = []
                start_index = 0

            previous_features = None
            for i in range(start_index, x.shape[2], self.mini_batch_encoder):
                after_features = x[:, :, i + self.mini_batch_encoder: i + self.mini_batch_encoder + 4, :, :] if i + self.mini_batch_encoder < x.shape[2] else None
                next_frames = self.single_forward(x[:, :, i: i + self.mini_batch_encoder, :, :], previous_features, after_features)
                previous_features = x[:, :, i + self.mini_batch_encoder - 4: i + self.mini_batch_encoder, :, :]
                new_pixel_values.append(next_frames)
            new_pixel_values = torch.cat(new_pixel_values, dim=2)
        else:
            new_pixel_values = self.single_forward(x, None, None)
        return new_pixel_values


class Decoder(nn.Module):
    r"""
    The `Decoder` layer of a variational autoencoder that decodes its latent representation into an output sample.

    Args:
        in_channels (`int`, *optional*, defaults to 8):
            The number of input channels.
        out_channels (`int`, *optional*, defaults to 3):
            The number of output channels.
        up_block_types (`Tuple[str, ...]`, *optional*, defaults to `("SpatialUpBlock3D",)`):
            The types of up blocks to use.
        ch (`int`, *optional*, defaults to 128):
            The base channel count; block `i` outputs `ch * ch_mult[i]` channels.
        ch_mult (`List[int]`, *optional*, defaults to `[1, 2, 4, 4]`):
            The channel multiplier for each block.
        use_gc_blocks (`Tuple[bool, ...]`, *optional*, defaults to `None`):
            Whether to use global context blocks for each up block.
        mid_block_type (`str`, *optional*, defaults to `"MidBlock3D"`):
            The type of mid block to use.
        layers_per_block (`int`, *optional*, defaults to 2):
            The number of layers per block.
        norm_num_groups (`int`, *optional*, defaults to 32):
            The number of groups for normalization.
        act_fn (`str`, *optional*, defaults to `"silu"`):
            The activation function to use. See `~diffusers.models.activations.get_activation` for available options.
        num_attention_heads (`int`, *optional*, defaults to 1):
            The number of attention heads to use.
    """

    def __init__(
        self,
        in_channels: int = 8,
        out_channels: int = 3,
        up_block_types=("SpatialUpBlock3D",),
        ch=128,
        ch_mult=[1, 2, 4, 4],
        use_gc_blocks=None,
        mid_block_type: str = "MidBlock3D",
        mid_block_use_attention: bool = True,
        mid_block_attention_type: str = "3d",
        mid_block_num_attention_heads: int = 1,
        layers_per_block: int = 2,
        norm_num_groups: int = 32,
        act_fn: str = "silu",
        num_attention_heads: int = 1,
        slice_compression_vae: bool = False,
        mini_batch_decoder: int = 3,
        verbose=False,
    ):
        super().__init__()
        block_out_channels = [ch * i for i in ch_mult]
        assert len(up_block_types) == len(block_out_channels), (
            "Number of up block types must match number of block output channels."
        )
        if use_gc_blocks is not None:
            assert len(use_gc_blocks) == len(up_block_types), (
                "Number of GC blocks must match number of up block types."
            )
        else:
            use_gc_blocks = [False] * len(up_block_types)

        self.conv_in = CausalConv3d(
            in_channels,
            block_out_channels[-1],
            kernel_size=3,
        )

        self.mid_block = get_mid_block(
            mid_block_type,
            in_channels=block_out_channels[-1],
            num_layers=layers_per_block,
            act_fn=act_fn,
            norm_num_groups=norm_num_groups,
            norm_eps=1e-6,
            add_attention=mid_block_use_attention,
            attention_type=mid_block_attention_type,
            num_attention_heads=mid_block_num_attention_heads,
        )

        self.up_blocks = nn.ModuleList([])

        reversed_block_out_channels = list(reversed(block_out_channels))
        output_channels = reversed_block_out_channels[0]
        for i, up_block_type in enumerate(up_block_types):
            input_channels = output_channels
            output_channels = reversed_block_out_channels[i]
            # is_first_block = i == 0
            is_final_block = i == len(block_out_channels) - 1

            up_block = get_up_block(
                up_block_type,
                in_channels=input_channels,
                out_channels=output_channels,
                num_layers=layers_per_block + 1,
                act_fn=act_fn,
                norm_num_groups=norm_num_groups,
                norm_eps=1e-6,
                num_attention_heads=num_attention_heads,
                add_gc_block=use_gc_blocks[i],
                add_upsample=not is_final_block,
            )
            self.up_blocks.append(up_block)

        self.conv_norm_out = nn.GroupNorm(
            num_channels=block_out_channels[0],
            num_groups=norm_num_groups,
            eps=1e-6,
        )
        self.conv_act = get_activation(act_fn)

        self.conv_out = CausalConv3d(block_out_channels[0], out_channels, kernel_size=3)

        self.slice_compression_vae = slice_compression_vae
        self.mini_batch_decoder = mini_batch_decoder
        self.features_share = True
        self.verbose = verbose

    def set_padding_one_frame(self):
        def _set_padding_one_frame(name, module):
            if hasattr(module, 'padding_flag'):
                if self.verbose:
                    print('Set pad mode for module[%s] type=%s' % (name, str(type(module))))
                module.padding_flag = 1
            for sub_name, sub_mod in module.named_children():
                _set_padding_one_frame(sub_name, sub_mod)
        for name, module in self.named_children():
            _set_padding_one_frame(name, module)

    def set_padding_more_frame(self):
        def _set_padding_more_frame(name, module):
            if hasattr(module, 'padding_flag'):
                if self.verbose:
                    print('Set pad mode for module[%s] type=%s' % (name, str(type(module))))
                module.padding_flag = 2
            for sub_name, sub_mod in module.named_children():
                _set_padding_more_frame(sub_name, sub_mod)
        for name, module in self.named_children():
            _set_padding_more_frame(name, module)

    def single_forward(self, x: torch.Tensor, previous_features: torch.Tensor, after_features: torch.Tensor) -> torch.Tensor:
        # x: (B, C, T, H, W)
        if self.features_share and previous_features is not None and after_features is None:
            b, c, t, h, w = x.size()
            x = torch.concat([previous_features, x], 2)
            x = self.conv_in(x)
            x = self.mid_block(x)
            x = x[:, :, -t:]
        elif self.features_share and previous_features is None and after_features is not None:
            b, c, t, h, w = x.size()
            x = torch.concat([x, after_features], 2)
            x = self.conv_in(x)
            x = self.mid_block(x)
            x = x[:, :, :t]
        elif self.features_share and previous_features is not None and after_features is not None:
            _, _, t_1, _, _ = previous_features.size()
            _, _, t_2, _, _ = x.size()
            x = torch.concat([previous_features, x, after_features], 2)
            x = self.conv_in(x)
            x = self.mid_block(x)
            x = x[:, :, t_1:(t_1 + t_2)]
        else:
            x = self.conv_in(x)
            x = self.mid_block(x)

        for up_block in self.up_blocks:
            x = up_block(x)

        x = self.conv_norm_out(x)
        x = self.conv_act(x)
        x = self.conv_out(x)

        return x

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if self.slice_compression_vae:
            _, _, f, _, _ = x.size()
            if f % 2 != 0:
                self.set_padding_one_frame()
                first_frames = self.single_forward(x[:, :, 0:1, :, :], None, None)
                self.set_padding_more_frame()
                new_pixel_values = [first_frames]
                start_index = 1
            else:
                self.set_padding_more_frame()
                new_pixel_values = []
                start_index = 0

            previous_features = None
            for i in range(start_index, x.shape[2], self.mini_batch_decoder):
                after_features = x[:, :, i + self.mini_batch_decoder: i + 2 * self.mini_batch_decoder, :, :] if i + self.mini_batch_decoder < x.shape[2] else None
                next_frames = self.single_forward(x[:, :, i: i + self.mini_batch_decoder, :, :], previous_features, after_features)
                previous_features = x[:, :, i: i + self.mini_batch_decoder, :, :]
                new_pixel_values.append(next_frames)
            new_pixel_values = torch.cat(new_pixel_values, dim=2)
        else:
            new_pixel_values = self.single_forward(x, None, None)
        return new_pixel_values
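When slice_compression_vae is enabled, both forward passes above carve the time axis into fixed windows: one lone first frame when the frame count is odd, then window-sized chunks. Neighboring-frame context is threaded through single_forward, though it is only concatenated when features_share is set (True in this Decoder, False in this Encoder). A standalone sketch of the encoder's window arithmetic (a hypothetical helper mirroring the loop above, with mini_batch_encoder=9):

def encoder_windows(num_frames: int, window: int = 9):
    """Yield (start, stop, has_prev_ctx, has_next_ctx) mirroring Encoder.forward."""
    start_index = 0
    if num_frames % 2 != 0:
        yield (0, 1, False, False)  # lone first frame, encoded with one-frame padding
        start_index = 1
    prev = False
    for i in range(start_index, num_frames, window):
        has_next = i + window < num_frames  # up to 4 following frames would be appended
        yield (i, min(i + window, num_frames), prev, has_next)
        prev = True  # the trailing 4 frames of this window become the next prev context

for w in encoder_windows(19):
    print(w)
# (0, 1, False, False)
# (1, 10, False, True)
# (10, 19, True, False)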
easyanimate/video_caption/datasets/put preprocess datasets here.txt
ADDED
File without changes