Spaces:
Sleeping
Sleeping
File size: 4,509 Bytes
eaf2e33 3582c8a eaf2e33 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 |
import os
import numpy as np
import torch
import torch.optim as optim
import torch.nn as nn
import logging
from src.ddpm.diffusion import Diffusion
from src.ddpm.modules import UNet
from src.ddpm.dataset import create_dataloader
from pathlib import Path
import argparse
import datetime
from src.gan.gankits import process_onehot, get_decoder
from src.smb.level import MarioLevel, lvlhcat, save_batch
from src.utils.filesys import getpath
from src.utils.img import make_img_sheet
# Raw per-class counts (presumably tile/sprite occurrences in the training
# levels -- TODO confirm the class order) damped by a fourth root so that the
# most frequent classes do not dominate.  ``train`` turns these into per-class
# temperatures via ``min_count / sprite_counts``.
sprite_counts = np.array([
    74977, 15252, 572591, 5826, 1216, 7302, 237,
    237, 2852, 1074, 235, 304, 48, 96,
    160, 1871, 936, 186, 428, 80, 428,
]) ** (1 / 4)
# Smallest damped count; the rarest class therefore gets temperature 1.
min_count = sprite_counts.min()
def setup_logging(run_name, beta_schedule):
    """Create the model and result directories of a run and return their paths.

    Both directories follow the layout ``<root>/<beta_schedule>/<run_name>``
    with the roots ``models`` and ``results``, resolved against the current
    working directory.  Directories that already exist are left as they are.

    Args:
        run_name: Name of the run (innermost directory component).
        beta_schedule: Name of the beta schedule (middle directory component).

    Returns:
        A ``(model_path, result_path)`` tuple of path strings.
    """
    model_path, result_path = (
        os.path.join(root, beta_schedule, run_name)
        for root in ("models", "results")
    )
    for directory in (model_path, result_path):
        os.makedirs(directory, exist_ok=True)
    return model_path, result_path
def _generate_sample_levels(diffusion, model, n_segments=25):
    """Build one level per stored initial segment, continued by sampled segments.

    Loads the initial latent vectors from ``analysis/initial_seg.npy``, decodes
    them with the GAN decoder, and for every decoded initial segment draws
    ``n_segments`` fresh segments from the diffusion model and concatenates
    them horizontally behind the initial one.

    Returns:
        A list with one concatenated level per initial segment.
    """
    init_latents = torch.tensor(np.load(getpath('analysis/initial_seg.npy')))
    decoder = get_decoder()
    # Pure inference: no autograd graph is needed for the decoder pass.
    with torch.no_grad():
        init_seg_onehots = decoder(init_latents.view(*init_latents.shape, 1, 1))
    levels = []
    for init_seg_onehot in init_seg_onehots:
        # ``sample`` returns a sequence; only its last entry is used here.
        sampled_onehots = diffusion.sample(model, n=n_segments)[-1]
        first = init_seg_onehot.view(1, *init_seg_onehot.shape)
        rest = sampled_onehots.detach().cpu()
        logging.debug("initial segment %s, sampled segments %s", first.shape, rest.shape)
        segs = process_onehot(torch.cat([first, rest], dim=0))
        levels.append(lvlhcat(segs))
    return levels


# Test training of the DDPM model
def train(args):
    """Train the UNet noise predictor with a combined MSE + reconstruction loss.

    Every epoch iterates once over the dataloader.  Every 1000th epoch a
    checkpoint is written to ``<res_path>/it<epoch>/ddpm.pth`` and a batch of
    sample levels to ``<res_path>/samples.lvls``; the final model is written
    to ``<res_path>/ddpm.pth`` when training ends.

    Args:
        args: Parsed options.  Reads ``res_path``, ``batch_size``, ``gpuid``
            (negative selects the CPU), ``lr``, ``beta_schedule``,
            ``resume_from`` and ``epochs``.

    NOTE(review): ``resume_from`` only offsets the epoch counter; no checkpoint
    is loaded in this function, so the model always starts from a fresh
    ``UNet()``.
    """
    path = getpath(args.res_path)
    os.makedirs(path, exist_ok=True)
    dataloader = create_dataloader(batch_size=args.batch_size, shuffle=True, num_workers=0)

    if args.gpuid < 0:
        device = 'cpu'
    elif torch.cuda.is_available():
        device = f'cuda:{args.gpuid}'
    else:
        # The default gpuid is 0; fall back instead of crashing on CPU-only hosts.
        logging.warning("CUDA is not available; falling back to CPU (requested gpuid=%d)", args.gpuid)
        device = 'cpu'

    model = UNet().to(device)
    optimizer = optim.AdamW(model.parameters(), lr=args.lr)
    mse = nn.MSELoss()
    diffusion = Diffusion(device=device, schedule=args.beta_schedule)
    # Per-class temperatures: 1 for the rarest class, smaller for frequent ones.
    temperatures = torch.tensor(min_count / sprite_counts, dtype=torch.float32).to(device)

    # Guard the per-epoch averages against an empty dataloader.
    num_batches = max(len(dataloader), 1)

    for epoch in range(args.resume_from + 1, args.resume_from + args.epochs + 1):
        logging.info(f"Starting epoch {epoch}:")
        epoch_loss = {'rec_loss': 0, 'mse': 0, 'loss': 0}
        for images in dataloader:
            images = images.to(device)
            t = diffusion.sample_timesteps(images.shape[0]).to(device)  # random int from 1~1000
            x_t, noise = diffusion.noise_images(images, t)  # x_t: image with noise at t, noise: gaussian noise
            predicted_noise = model(x_t.float(), t.float())  # returns predicted noise eps_theta

            original_img = images.argmax(dim=1)  # batch x 14 x 14
            reconstructed_img = diffusion.sample_only_final(x_t, t, predicted_noise, temperatures)
            # Negative log-likelihood of the true classes, summed per image, averaged over the batch.
            rec_loss = -reconstructed_img.log_prob(original_img).sum(dim=(1, 2)).mean()
            mse_loss = mse(noise.float(), predicted_noise.float())
            loss = 0.001 * rec_loss + mse_loss

            epoch_loss['rec_loss'] += rec_loss.item()
            epoch_loss['mse'] += mse_loss.item()
            epoch_loss['loss'] += loss.item()

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        print(
            '\nIteration: %d' % epoch,
            'rec_loss: %.5g' % (epoch_loss['rec_loss'] / num_batches),
            'mse: %.5g' % (epoch_loss['mse'] / num_batches)
        )

        if epoch % 1000 == 0:
            itpath = getpath(path, f'it{epoch}')
            os.makedirs(itpath, exist_ok=True)
            # ``itpath`` already contains ``path``; do not join it a second time.
            model.save(getpath(itpath, 'ddpm.pth'))
            lvls = _generate_sample_levels(diffusion, model)
            # Written to the run root, so each checkpoint overwrites the previous samples.
            save_batch(lvls, getpath(path, 'samples.lvls'))

    model.save(getpath(path, 'ddpm.pth'))
def launch(argv=None):
    """Parse the command-line options and start training.

    Args:
        argv: Optional list of argument strings.  ``None`` (the default) makes
            argparse read ``sys.argv[1:]``, which is the original behaviour.
    """
    # Without a configured handler the root logger only emits WARNING and
    # above, so the per-epoch ``logging.info`` progress messages would be lost.
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s: %(message)s",
        level=logging.INFO,
    )

    parser = argparse.ArgumentParser()
    parser.add_argument("--epochs", type=int, default=10000)
    parser.add_argument("--batch_size", type=int, default=256)
    parser.add_argument("--res_path", type=str, default='exp_data/DDPM')
    parser.add_argument("--gpuid", type=int, default=0)
    parser.add_argument("--lr", type=float, default=3e-4)
    parser.add_argument(
        "--beta_schedule", type=str, default="quadratic",
        choices=['linear', 'quadratic', 'sigmoid'],
    )
    parser.add_argument(
        "--run_name", type=str,
        default=datetime.datetime.now().strftime('%Y-%m-%d_%H-%M-%S'),
    )
    parser.add_argument("--resume_from", type=int, default=0)
    args = parser.parse_args(argv)
    train(args)


if __name__ == "__main__":
    launch()
|