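"""Train a DDPM on one-hot Mario level segments and periodically sample full
levels by concatenating a GAN-decoded initial segment with DDPM samples."""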
import os
import numpy as np
import torch
import torch.optim as optim
import torch.nn as nn
import logging
from src.ddpm.diffusion import Diffusion
from src.ddpm.modules import UNet
from src.ddpm.dataset import create_dataloader
from pathlib import Path
import argparse
import datetime
from src.gan.gankits import process_onehot, get_decoder
from src.smb.level import MarioLevel, lvlhcat, save_batch
from src.utils.filesys import getpath
from src.utils.img import make_img_sheet
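# Occurrence counts of the 21 tile types (presumably tallied over the training
# levels), raised to the 1/4 power to soften the common-vs-rare imbalance.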
sprite_counts = np.power(np.array([
    74977, 15252, 572591, 5826, 1216, 7302, 237, 237, 2852, 1074, 235, 304,
    48, 96, 160, 1871, 936, 186, 428, 80, 428
]), 1/4)
min_count = np.min(sprite_counts)


def setup_logging(run_name, beta_schedule):
    model_path = os.path.join("models", beta_schedule, run_name)
    result_path = os.path.join("results", beta_schedule, run_name)
    os.makedirs(model_path, exist_ok=True)
    os.makedirs(result_path, exist_ok=True)
    return model_path, result_path


# Train the DDPM model.
def train(args):
    path = getpath(args.res_path)
    os.makedirs(path, exist_ok=True)
    dataloader = create_dataloader(batch_size=args.batch_size, shuffle=True, num_workers=0)
    device = 'cpu' if args.gpuid < 0 else f'cuda:{args.gpuid}'
    model = UNet().to(device)
    optimizer = optim.AdamW(model.parameters(), lr=args.lr)
    mse = nn.MSELoss()
    diffusion = Diffusion(device=device, schedule=args.beta_schedule)
    # Per-tile sampling temperatures: the rarest tile gets 1.0, common tiles
    # get proportionally smaller values.
    temperatures = torch.tensor(min_count / sprite_counts, dtype=torch.float32).to(device)
    n_batches = len(dataloader)
    for epoch in range(args.resume_from + 1, args.resume_from + args.epochs + 1):
        logging.info(f"Starting epoch {epoch}:")
        epoch_loss = {'rec_loss': 0, 'mse': 0, 'loss': 0}
        for images in dataloader:
            images = images.to(device)
            t = diffusion.sample_timesteps(images.shape[0]).to(device)  # random timesteps in 1~1000
            x_t, noise = diffusion.noise_images(images, t)  # x_t: noised image at step t; noise: the Gaussian noise added
            predicted_noise = model(x_t.float(), t.float())  # predicted noise eps_theta
            original_img = images.argmax(dim=1)  # tile indices, batch x 14 x 14
            reconstructed_img = diffusion.sample_only_final(x_t, t, predicted_noise, temperatures)
            # Negative log-likelihood of the true tiles under the reconstructed
            # categorical distribution, summed per segment, averaged over the batch.
            rec_loss = -reconstructed_img.log_prob(original_img).sum(dim=(1, 2)).mean()
            mse_loss = mse(noise.float(), predicted_noise.float())
            loss = 0.001 * rec_loss + mse_loss  # small reconstruction term plus the standard DDPM noise MSE
            epoch_loss['rec_loss'] += rec_loss.item()
            epoch_loss['mse'] += mse_loss.item()
            epoch_loss['loss'] += loss.item()
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        print(
            '\nEpoch: %d' % epoch,
            'rec_loss: %.5g' % (epoch_loss['rec_loss'] / n_batches),
            'mse: %.5g' % (epoch_loss['mse'] / n_batches),
            'loss: %.5g' % (epoch_loss['loss'] / n_batches)
        )
        # Every 1000 epochs, checkpoint the model and sample some full levels.
        if epoch % 1000 == 0:
            itpath = getpath(path, f'it{epoch}')
            os.makedirs(itpath, exist_ok=True)
            model.save(getpath(itpath, 'ddpm.pth'))
            lvls = []
            init_latvecs = torch.tensor(np.load(getpath('analysis/initial_seg.npy')))
            gan = get_decoder()
            # Decode the stored latent vectors into one-hot initial segments.
            init_seg_onehots = gan(init_latvecs.view(*init_latvecs.shape, 1, 1))
            for init_seg_onehot in init_seg_onehots:
                # Prepend the GAN-decoded initial segment to 25 DDPM-sampled segments,
                # then concatenate the segments horizontally into one level.
                seg_onehots = diffusion.sample(model, n=25)[-1]
                a = init_seg_onehot.view(1, *init_seg_onehot.shape)
                b = seg_onehots.detach().cpu()
                segs = process_onehot(torch.cat([a, b], dim=0))
                level = lvlhcat(segs)
                lvls.append(level)
            save_batch(lvls, getpath(path, 'samples.lvls'))
    model.save(getpath(path, 'ddpm.pth'))


def launch():
    parser = argparse.ArgumentParser()
    parser.add_argument("--epochs", type=int, default=10000)
    parser.add_argument("--batch_size", type=int, default=256)
    parser.add_argument("--res_path", type=str, default='exp_data/DDPM')
    parser.add_argument("--gpuid", type=int, default=0)
    parser.add_argument("--lr", type=float, default=3e-4)
    parser.add_argument("--beta_schedule", type=str, default="quadratic", choices=['linear', 'quadratic', 'sigmoid'])
    parser.add_argument("--run_name", type=str, default=datetime.datetime.now().strftime('%Y-%m-%d_%H-%M-%S'))
    parser.add_argument("--resume_from", type=int, default=0)
    args = parser.parse_args()
    train(args)


if __name__ == "__main__":
    launch()
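
# Example invocation (the script name "train_ddpm.py" is a placeholder for
# wherever this file lives; flags and defaults come from launch() above):
#   python train_ddpm.py --gpuid 0 --epochs 10000 --beta_schedule quadratic --res_path exp_data/DDPM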