import torch
from models import VQVAE, build_vae_var
from dataset.imagenet_dataset import get_train_transforms
from PIL import Image
from torchvision import transforms


device = 'mps'  # Apple-Silicon GPU backend; use 'cuda' or 'cpu' depending on hardware
patch_nums = (1, 2, 3, 4, 5, 6, 8, 10, 13, 16)  # token-map side lengths for each autoregressive scale

# Build the multi-scale VQ-VAE and the VAR transformer (depth-16, ImageNet-1k config).
vae, var = build_vae_var(
    V=4096, Cvae=32, ch=160, share_quant_resi=4,
    device=device, patch_nums=patch_nums,
    num_classes=1000, depth=16, shared_aln=False,
)

# Load the pretrained checkpoints for both models.
var_ckpt = 'var_d16.pth'
vae_ckpt = 'vae_ch160v4096z32.pth'
var.load_state_dict(torch.load(var_ckpt, map_location=device), strict=True)
vae.load_state_dict(torch.load(vae_ckpt, map_location=device), strict=True)
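
# --- Minimal usage sketch (not part of the original script) ------------------
# Why: the loaded models are only useful once switched to inference mode.
# Assumption: `var.autoregressive_infer_cfg(...)` is available as in the VAR
# reference demo; the class id and the exact keyword arguments below are
# illustrative and may differ in your checkout.
vae.eval()
var.eval()

with torch.inference_mode():
    label = torch.tensor([980], device=device)  # hypothetical ImageNet class id
    imgs = var.autoregressive_infer_cfg(
        B=1, label_B=label,
        cfg=1.5, top_k=900, top_p=0.95, more_smooth=False,
    )
    # `imgs` is expected to be a (B, 3, 256, 256) tensor of pixel values in [0, 1].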