File size: 4,509 Bytes
eaf2e33
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3582c8a
eaf2e33
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
import os
import numpy as np
import torch
import torch.optim as optim
import torch.nn as nn
import logging
from src.ddpm.diffusion import Diffusion
from src.ddpm.modules import UNet
from src.ddpm.dataset import create_dataloader
from pathlib import Path
import argparse
import datetime

from src.gan.gankits import process_onehot, get_decoder
from src.smb.level import MarioLevel, lvlhcat, save_batch
from src.utils.filesys import getpath
from src.utils.img import make_img_sheet
# Fixed table of 21 counts, damped by taking the fourth root of each entry.
# NOTE(review): presumably one occurrence count per tile/sprite type in the
# training levels -- confirm against the dataset's tile ordering.
sprite_counts = np.power(
    np.array([
        74977, 15252, 572591, 5826, 1216, 7302, 237,
        237, 2852, 1074, 235, 304, 48, 96,
        160, 1871, 936, 186, 428, 80, 428,
    ]),
    0.25,
)
# Smallest damped count; train() divides it by sprite_counts to build its
# per-entry temperature vector.
min_count = sprite_counts.min()


def setup_logging(run_name, beta_schedule):
    """Create the per-run model and result directories and return their paths.

    The directories are ``models/<beta_schedule>/<run_name>`` and
    ``results/<beta_schedule>/<run_name>``, relative to the current working
    directory. Existing directories are left untouched.

    Args:
        run_name: Name of the run; used as the innermost directory name.
        beta_schedule: Name of the beta schedule; used as the middle directory.

    Returns:
        A ``(model_path, result_path)`` tuple of path strings.
    """
    model_path, result_path = (
        os.path.join(root, beta_schedule, run_name)
        for root in ("models", "results")
    )
    for directory in (model_path, result_path):
        os.makedirs(directory, exist_ok=True)
    return model_path, result_path

# Test DDPM model training
def train(args):
    """Train the DDPM noise-prediction UNet, then sample and save levels.

    Training phase: for every batch a random timestep is drawn, the batch is
    noised, and the UNet predicts the noise. The optimised loss is
    ``mse(noise, predicted_noise) + 0.001 * rec_loss``, where ``rec_loss`` is
    the negative log-probability of the original tile indices under the
    object returned by ``diffusion.sample_only_final``. Per-epoch mean losses
    are printed, and every 1000th epoch a checkpoint is written to
    ``<res_path>/it<epoch>/ddpm.pth``.

    Sampling phase: the initial latent vectors in ``analysis/initial_seg.npy``
    are decoded with the GAN decoder; for each decoded initial segment, 25
    segments are sampled from the diffusion model, joined after the initial
    segment and concatenated into one level. All levels are saved to
    ``<res_path>/samples.lvls`` and the final model to ``<res_path>/ddpm.pth``.

    Args:
        args: Parsed command-line namespace. Reads ``res_path``,
            ``batch_size``, ``gpuid`` (negative selects the CPU), ``lr``,
            ``beta_schedule``, ``resume_from`` and ``epochs``. Note that
            ``resume_from`` only offsets the epoch numbering here; no
            checkpoint is loaded in this function.
    """
    path = getpath(args.res_path)
    os.makedirs(path, exist_ok=True)

    dataloader = create_dataloader(batch_size=args.batch_size, shuffle=True, num_workers=0)
    device = 'cpu' if args.gpuid < 0 else f'cuda:{args.gpuid}'
    model = UNet().to(device)
    optimizer = optim.AdamW(model.parameters(), lr=args.lr)
    mse = nn.MSELoss()
    diffusion = Diffusion(device=device, schedule=args.beta_schedule)
    # One temperature per entry of sprite_counts: the entry with the smallest
    # count gets 1.0, entries with larger counts get proportionally less.
    temperatures = torch.tensor(min_count / sprite_counts, dtype=torch.float32).to(device)
    num_batches = len(dataloader)

    for epoch in range(args.resume_from+1, args.resume_from+args.epochs+1):
        logging.info(f"Starting epoch {epoch}:")
        epoch_loss = {'rec_loss': 0, 'mse': 0, 'loss': 0}
        for images in dataloader:
            images = images.to(device)
            # One random timestep per sample in the batch.
            t = diffusion.sample_timesteps(images.shape[0]).to(device)
            # x_t: batch noised to timestep t; noise: the noise that was added.
            x_t, noise = diffusion.noise_images(images, t)
            # The UNet predicts the added noise (eps_theta).
            predicted_noise = model(x_t.float(), t.float())

            # Collapse dim 1 (assumed to be the one-hot tile channel -- TODO
            # confirm) into per-cell class indices.
            original_img = images.argmax(dim=1)
            # NOTE(review): presumably a torch.distributions object over tile
            # classes (it exposes log_prob) -- confirm in Diffusion.
            reconstructed_img = diffusion.sample_only_final(x_t, t, predicted_noise, temperatures)
            # Negative log-likelihood summed over dims 1 and 2, averaged over the batch.
            rec_loss = -reconstructed_img.log_prob(original_img).sum(dim=(1,2)).mean()
            mse_loss = mse(noise.float(), predicted_noise.float())
            loss = 0.001 * rec_loss + mse_loss
            epoch_loss['rec_loss'] += rec_loss.item()
            epoch_loss['mse'] += mse_loss.item()
            epoch_loss['loss'] += loss.item()
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        print(
            '\nIteration: %d' % epoch,
            'rec_loss: %.5g' % (epoch_loss['rec_loss']/num_batches),
            'mse: %.5g' % (epoch_loss['mse']/num_batches)
        )

        if epoch % 1000 == 0:
            itpath = getpath(path, f'it{epoch}')
            os.makedirs(itpath, exist_ok=True)
            # NOTE(review): itpath already includes path; this only resolves to
            # itpath/ddpm.pth if getpath discards components that precede an
            # absolute one (os.path.join semantics) -- confirm in getpath.
            model.save(getpath(path, itpath, 'ddpm.pth'))

    lvls = []
    init_latents = torch.tensor(np.load(getpath('analysis/initial_seg.npy')))
    gan = get_decoder()
    # init_latents is already a tensor; re-wrapping it in torch.tensor() only
    # made a redundant copy and triggered PyTorch's copy-construct UserWarning.
    # Two trailing singleton dims are appended before decoding.
    init_seg_onehots = gan(init_latents.view(*init_latents.shape, 1, 1))
    for init_seg_onehot in init_seg_onehots:
        # Last element of the value returned by sample(); presumably the final
        # denoised batch -- TODO confirm in Diffusion.sample.
        seg_onehots = diffusion.sample(model, n=25)[-1]
        first = init_seg_onehot.view(1, *init_seg_onehot.shape)
        sampled = seg_onehots.detach().cpu()
        # NOTE(review): debug output kept from the original implementation.
        print(first.shape, sampled.shape)
        # Initial segment first, followed by the 25 sampled segments.
        segs = process_onehot(torch.cat([first, sampled], dim=0))
        level = lvlhcat(segs)
        lvls.append(level)
    save_batch(lvls, getpath(path, 'samples.lvls'))
    model.save(getpath(path, 'ddpm.pth'))

def launch():
    """Parse the command-line options and start a training run.

    Options: ``--epochs``, ``--batch_size``, ``--res_path``, ``--gpuid``,
    ``--lr``, ``--beta_schedule`` (one of linear/quadratic/sigmoid),
    ``--run_name`` (defaults to the current timestamp) and ``--resume_from``.
    The parsed namespace is handed to ``train``.
    """
    timestamp = datetime.datetime.now().strftime('%Y-%m-%d_%H-%M-%S')
    option_specs = (
        ("--epochs", dict(type=int, default=10000)),
        ("--batch_size", dict(type=int, default=256)),
        ("--res_path", dict(type=str, default='exp_data/DDPM')),
        ("--gpuid", dict(type=int, default=0)),
        ("--lr", dict(type=float, default=3e-4)),
        (
            "--beta_schedule",
            dict(type=str, default="quadratic", choices=['linear', 'quadratic', 'sigmoid']),
        ),
        ("--run_name", dict(type=str, default=timestamp)),
        ("--resume_from", dict(type=int, default=0)),
    )

    cli = argparse.ArgumentParser()
    for flag, spec in option_specs:
        cli.add_argument(flag, **spec)
    train(cli.parse_args())

# Run the command-line entry point only when this file is executed as a
# script, not when it is imported as a module.
if __name__ == "__main__":
    launch()