File size: 610 Bytes
7f51798
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
import numpy as np
import torch

from modules.diffusionmodules.model import SimpleDecoder, Decoder

# Alternative that was tried earlier: decoder = SimpleDecoder(4, 32)
# Reference config:
# https://github.com/CompVis/stable-diffusion/blob/main/configs/autoencoder/autoencoder_kl_32x32x4.yaml

# Use the GPU when one is available and fall back to the CPU otherwise, so
# this smoke test also runs on machines without CUDA.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

decoder = Decoder(
    ch=64,
    out_ch=32,
    ch_mult=(1, 2),
    num_res_blocks=2,  # num_res_blocks=1 was also tried
    dropout=0.0,
    attn_resolutions=(),
    # NOTE(review): presumably the channel count of the latent input; the
    # test tensor built further down has 4 channels — confirm against Decoder.
    z_channels=4,
    resolution=64,
    in_channels=3,
).to(device)

# Create the random latent directly on the device that holds the decoder's
# weights, so the input and the module can never end up on different devices
# (and no CUDA device is required if the decoder lives on the CPU).
decoder_device = next(decoder.parameters()).device
input_tensor = torch.randn(1, 4, 32, 32, device=decoder_device)
# Alternative input tried earlier: torch.randn(1, 96, 32, 32)

# Only the output shape is inspected, so skip autograd bookkeeping.
with torch.no_grad():
    print(decoder(input_tensor).shape)